Source file
src/net/http/csrf_test.go
1
2
3
4
5 package http_test
6
7 import (
8 "io"
9 "net/http"
10 "net/http/httptest"
11 "strings"
12 "testing"
13 )
14
15
// httptestNewRequest builds a body-less request with httptest.NewRequest and
// then blanks the URL's Scheme and Host, leaving only the path and query in
// r.URL. r.Host (filled in by httptest.NewRequest from an absolute target) is
// not touched, so the request still carries the host it was addressed to.
func httptestNewRequest(method, target string) *http.Request {
	r := httptest.NewRequest(method, target, nil)
	r.URL.Scheme, r.URL.Host = "", ""
	return r
}
22
// okHandler is the downstream handler wrapped by CrossOriginProtection in these
// tests. It always replies 200 OK with an empty body, so any other status seen
// by a test came from the protection layer or its deny handler.
var okHandler = http.HandlerFunc(func(rw http.ResponseWriter, _ *http.Request) {
	rw.WriteHeader(http.StatusOK)
})
26
27 func TestCrossOriginProtectionSecFetchSite(t *testing.T) {
28 protection := http.NewCrossOriginProtection()
29 handler := protection.Handler(okHandler)
30
31 tests := []struct {
32 name string
33 method string
34 secFetchSite string
35 origin string
36 expectedStatus int
37 }{
38 {"same-origin allowed", "POST", "same-origin", "", http.StatusOK},
39 {"none allowed", "POST", "none", "", http.StatusOK},
40 {"cross-site blocked", "POST", "cross-site", "", http.StatusForbidden},
41 {"same-site blocked", "POST", "same-site", "", http.StatusForbidden},
42
43 {"no header with no origin", "POST", "", "", http.StatusOK},
44 {"no header with matching origin", "POST", "", "https://example.com", http.StatusOK},
45 {"no header with mismatched origin", "POST", "", "https://attacker.example", http.StatusForbidden},
46 {"no header with null origin", "POST", "", "null", http.StatusForbidden},
47
48 {"GET allowed", "GET", "cross-site", "", http.StatusOK},
49 {"HEAD allowed", "HEAD", "cross-site", "", http.StatusOK},
50 {"OPTIONS allowed", "OPTIONS", "cross-site", "", http.StatusOK},
51 {"PUT blocked", "PUT", "cross-site", "", http.StatusForbidden},
52 }
53
54 for _, tc := range tests {
55 t.Run(tc.name, func(t *testing.T) {
56 req := httptestNewRequest(tc.method, "https://example.com/")
57 if tc.secFetchSite != "" {
58 req.Header.Set("Sec-Fetch-Site", tc.secFetchSite)
59 }
60 if tc.origin != "" {
61 req.Header.Set("Origin", tc.origin)
62 }
63
64 w := httptest.NewRecorder()
65 handler.ServeHTTP(w, req)
66
67 if w.Code != tc.expectedStatus {
68 t.Errorf("got status %d, want %d", w.Code, tc.expectedStatus)
69 }
70 })
71 }
72 }
73
74 func TestCrossOriginProtectionTrustedOriginBypass(t *testing.T) {
75 protection := http.NewCrossOriginProtection()
76 err := protection.AddTrustedOrigin("https://trusted.example")
77 if err != nil {
78 t.Fatalf("AddTrustedOrigin: %v", err)
79 }
80 handler := protection.Handler(okHandler)
81
82 tests := []struct {
83 name string
84 origin string
85 secFetchSite string
86 expectedStatus int
87 }{
88 {"trusted origin without sec-fetch-site", "https://trusted.example", "", http.StatusOK},
89 {"trusted origin with cross-site", "https://trusted.example", "cross-site", http.StatusOK},
90 {"untrusted origin without sec-fetch-site", "https://attacker.example", "", http.StatusForbidden},
91 {"untrusted origin with cross-site", "https://attacker.example", "cross-site", http.StatusForbidden},
92 }
93
94 for _, tc := range tests {
95 t.Run(tc.name, func(t *testing.T) {
96 req := httptestNewRequest("POST", "https://example.com/")
97 req.Header.Set("Origin", tc.origin)
98 if tc.secFetchSite != "" {
99 req.Header.Set("Sec-Fetch-Site", tc.secFetchSite)
100 }
101
102 w := httptest.NewRecorder()
103 handler.ServeHTTP(w, req)
104
105 if w.Code != tc.expectedStatus {
106 t.Errorf("got status %d, want %d", w.Code, tc.expectedStatus)
107 }
108 })
109 }
110 }
111
112 func TestCrossOriginProtectionPatternBypass(t *testing.T) {
113 protection := http.NewCrossOriginProtection()
114 protection.AddInsecureBypassPattern("/bypass/")
115 protection.AddInsecureBypassPattern("/only/{foo}")
116 protection.AddInsecureBypassPattern("/no-trailing")
117 protection.AddInsecureBypassPattern("/yes-trailing/")
118 protection.AddInsecureBypassPattern("PUT /put-only/")
119 protection.AddInsecureBypassPattern("GET /get-only/")
120 protection.AddInsecureBypassPattern("POST /post-only/")
121 handler := protection.Handler(okHandler)
122
123 tests := []struct {
124 name string
125 path string
126 secFetchSite string
127 expectedStatus int
128 }{
129 {"bypass path without sec-fetch-site", "/bypass/", "", http.StatusOK},
130 {"bypass path with cross-site", "/bypass/", "cross-site", http.StatusOK},
131 {"non-bypass path without sec-fetch-site", "/api/", "", http.StatusForbidden},
132 {"non-bypass path with cross-site", "/api/", "cross-site", http.StatusForbidden},
133
134 {"redirect to bypass path without ..", "/foo/../bypass/bar", "", http.StatusForbidden},
135 {"redirect to bypass path with trailing slash", "/bypass", "", http.StatusForbidden},
136 {"redirect to non-bypass path with ..", "/foo/../api/bar", "", http.StatusForbidden},
137 {"redirect to non-bypass path with trailing slash", "/api", "", http.StatusForbidden},
138
139 {"wildcard bypass", "/only/123", "", http.StatusOK},
140 {"non-wildcard", "/only/123/foo", "", http.StatusForbidden},
141
142
143 {"no trailing slash exact match", "/no-trailing", "", http.StatusOK},
144 {"no trailing slash with slash", "/no-trailing/", "", http.StatusForbidden},
145 {"yes trailing slash exact match", "/yes-trailing/", "", http.StatusOK},
146 {"yes trailing slash without slash", "/yes-trailing", "", http.StatusForbidden},
147
148 {"method-specific hit", "/post-only/", "", http.StatusOK},
149 {"method-specific miss (PUT)", "/put-only/", "", http.StatusForbidden},
150 {"method-specific miss (GET)", "/get-only/", "", http.StatusForbidden},
151 }
152
153 for _, tc := range tests {
154 t.Run(tc.name, func(t *testing.T) {
155 req := httptestNewRequest("POST", "https://example.com"+tc.path)
156 req.Header.Set("Origin", "https://attacker.example")
157 if tc.secFetchSite != "" {
158 req.Header.Set("Sec-Fetch-Site", tc.secFetchSite)
159 }
160
161 w := httptest.NewRecorder()
162 handler.ServeHTTP(w, req)
163
164 if w.Code != tc.expectedStatus {
165 t.Errorf("got status %d, want %d", w.Code, tc.expectedStatus)
166 }
167 })
168 }
169 }
170
171 func TestCrossOriginProtectionSetDenyHandler(t *testing.T) {
172 protection := http.NewCrossOriginProtection()
173
174 handler := protection.Handler(okHandler)
175
176 req := httptestNewRequest("POST", "https://example.com/")
177 req.Header.Set("Sec-Fetch-Site", "cross-site")
178
179 w := httptest.NewRecorder()
180 handler.ServeHTTP(w, req)
181
182 if w.Code != http.StatusForbidden {
183 t.Errorf("got status %d, want %d", w.Code, http.StatusForbidden)
184 }
185
186 customErrHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
187 w.WriteHeader(http.StatusTeapot)
188 io.WriteString(w, "custom error")
189 })
190 protection.SetDenyHandler(customErrHandler)
191
192 w = httptest.NewRecorder()
193 handler.ServeHTTP(w, req)
194
195 if w.Code != http.StatusTeapot {
196 t.Errorf("got status %d, want %d", w.Code, http.StatusTeapot)
197 }
198
199 if !strings.Contains(w.Body.String(), "custom error") {
200 t.Errorf("expected custom error message, got: %q", w.Body.String())
201 }
202
203 req = httptestNewRequest("GET", "https://example.com/")
204
205 w = httptest.NewRecorder()
206 handler.ServeHTTP(w, req)
207
208 if w.Code != http.StatusOK {
209 t.Errorf("got status %d, want %d", w.Code, http.StatusOK)
210 }
211
212 protection.SetDenyHandler(nil)
213
214 req = httptestNewRequest("POST", "https://example.com/")
215 req.Header.Set("Sec-Fetch-Site", "cross-site")
216
217 w = httptest.NewRecorder()
218 handler.ServeHTTP(w, req)
219
220 if w.Code != http.StatusForbidden {
221 t.Errorf("got status %d, want %d", w.Code, http.StatusForbidden)
222 }
223 }
224
225 func TestCrossOriginProtectionAddTrustedOriginErrors(t *testing.T) {
226 protection := http.NewCrossOriginProtection()
227
228 tests := []struct {
229 name string
230 origin string
231 wantErr bool
232 }{
233 {"valid origin", "https://example.com", false},
234 {"valid origin with port", "https://example.com:8080", false},
235 {"http origin", "http://example.com", false},
236 {"missing scheme", "example.com", true},
237 {"missing host", "https://", true},
238 {"trailing slash", "https://example.com/", true},
239 {"with path", "https://example.com/path", true},
240 {"with query", "https://example.com?query=value", true},
241 {"with fragment", "https://example.com#fragment", true},
242 {"invalid url", "https://ex ample.com", true},
243 {"empty string", "", true},
244 {"null", "null", true},
245 }
246
247 for _, tc := range tests {
248 t.Run(tc.name, func(t *testing.T) {
249 err := protection.AddTrustedOrigin(tc.origin)
250 if (err != nil) != tc.wantErr {
251 t.Errorf("AddTrustedOrigin(%q) error = %v, wantErr %v", tc.origin, err, tc.wantErr)
252 }
253 })
254 }
255 }
256
257 func TestCrossOriginProtectionAddingBypassesConcurrently(t *testing.T) {
258 protection := http.NewCrossOriginProtection()
259 handler := protection.Handler(okHandler)
260
261 req := httptestNewRequest("POST", "https://example.com/")
262 req.Header.Set("Origin", "https://concurrent.example")
263 req.Header.Set("Sec-Fetch-Site", "cross-site")
264
265 w := httptest.NewRecorder()
266 handler.ServeHTTP(w, req)
267
268 if w.Code != http.StatusForbidden {
269 t.Errorf("got status %d, want %d", w.Code, http.StatusForbidden)
270 }
271
272 start := make(chan struct{})
273 done := make(chan struct{})
274 go func() {
275 close(start)
276 defer close(done)
277 for range 10 {
278 w := httptest.NewRecorder()
279 handler.ServeHTTP(w, req)
280 }
281 }()
282
283
284 <-start
285 protection.AddTrustedOrigin("https://concurrent.example")
286 protection.AddInsecureBypassPattern("/foo/")
287 <-done
288
289 w = httptest.NewRecorder()
290 handler.ServeHTTP(w, req)
291
292 if w.Code != http.StatusOK {
293 t.Errorf("After concurrent bypass addition, got status %d, want %d", w.Code, http.StatusOK)
294 }
295 }
296
297 func TestCrossOriginProtectionServer(t *testing.T) {
298 protection := http.NewCrossOriginProtection()
299 protection.AddTrustedOrigin("https://trusted.example")
300 protection.AddInsecureBypassPattern("/bypass/")
301 handler := protection.Handler(okHandler)
302
303 ts := httptest.NewServer(handler)
304 defer ts.Close()
305
306 tests := []struct {
307 name string
308 method string
309 url string
310 origin string
311 secFetchSite string
312 expectedStatus int
313 }{
314 {"cross-site", "POST", ts.URL, "https://attacker.example", "cross-site", http.StatusForbidden},
315 {"same-origin", "POST", ts.URL, "", "same-origin", http.StatusOK},
316 {"origin matches host", "POST", ts.URL, ts.URL, "", http.StatusOK},
317 {"trusted origin", "POST", ts.URL, "https://trusted.example", "", http.StatusOK},
318 {"untrusted origin", "POST", ts.URL, "https://attacker.example", "", http.StatusForbidden},
319 {"bypass path", "POST", ts.URL + "/bypass/", "https://attacker.example", "", http.StatusOK},
320 }
321
322 for _, tc := range tests {
323 t.Run(tc.name, func(t *testing.T) {
324 req, err := http.NewRequest(tc.method, tc.url, nil)
325 if err != nil {
326 t.Fatalf("NewRequest: %v", err)
327 }
328 if tc.origin != "" {
329 req.Header.Set("Origin", tc.origin)
330 }
331 if tc.secFetchSite != "" {
332 req.Header.Set("Sec-Fetch-Site", tc.secFetchSite)
333 }
334 client := &http.Client{}
335 resp, err := client.Do(req)
336 if err != nil {
337 t.Fatalf("Do: %v", err)
338 }
339 defer resp.Body.Close()
340 if resp.StatusCode != tc.expectedStatus {
341 t.Errorf("got status %d, want %d", resp.StatusCode, tc.expectedStatus)
342 }
343 })
344 }
345 }
346
View as plain text