Source file src/net/http/csrf_test.go

     1  // Copyright 2025 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package http_test
     6  
     7  import (
     8  	"io"
     9  	"net/http"
    10  	"net/http/httptest"
    11  	"strings"
    12  	"testing"
    13  )
    14  
    15  // httptestNewRequest works around https://go.dev/issue/73151.
    16  func httptestNewRequest(method, target string) *http.Request {
    17  	req := httptest.NewRequest(method, target, nil)
    18  	req.URL.Scheme = ""
    19  	req.URL.Host = ""
    20  	return req
    21  }
    22  
    23  var okHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    24  	w.WriteHeader(http.StatusOK)
    25  })
    26  
    27  func TestCrossOriginProtectionSecFetchSite(t *testing.T) {
    28  	protection := http.NewCrossOriginProtection()
    29  	handler := protection.Handler(okHandler)
    30  
    31  	tests := []struct {
    32  		name           string
    33  		method         string
    34  		secFetchSite   string
    35  		origin         string
    36  		expectedStatus int
    37  	}{
    38  		{"same-origin allowed", "POST", "same-origin", "", http.StatusOK},
    39  		{"none allowed", "POST", "none", "", http.StatusOK},
    40  		{"cross-site blocked", "POST", "cross-site", "", http.StatusForbidden},
    41  		{"same-site blocked", "POST", "same-site", "", http.StatusForbidden},
    42  
    43  		{"no header with no origin", "POST", "", "", http.StatusOK},
    44  		{"no header with matching origin", "POST", "", "https://example.com", http.StatusOK},
    45  		{"no header with mismatched origin", "POST", "", "https://attacker.example", http.StatusForbidden},
    46  		{"no header with null origin", "POST", "", "null", http.StatusForbidden},
    47  
    48  		{"GET allowed", "GET", "cross-site", "", http.StatusOK},
    49  		{"HEAD allowed", "HEAD", "cross-site", "", http.StatusOK},
    50  		{"OPTIONS allowed", "OPTIONS", "cross-site", "", http.StatusOK},
    51  		{"PUT blocked", "PUT", "cross-site", "", http.StatusForbidden},
    52  	}
    53  
    54  	for _, tc := range tests {
    55  		t.Run(tc.name, func(t *testing.T) {
    56  			req := httptestNewRequest(tc.method, "https://example.com/")
    57  			if tc.secFetchSite != "" {
    58  				req.Header.Set("Sec-Fetch-Site", tc.secFetchSite)
    59  			}
    60  			if tc.origin != "" {
    61  				req.Header.Set("Origin", tc.origin)
    62  			}
    63  
    64  			w := httptest.NewRecorder()
    65  			handler.ServeHTTP(w, req)
    66  
    67  			if w.Code != tc.expectedStatus {
    68  				t.Errorf("got status %d, want %d", w.Code, tc.expectedStatus)
    69  			}
    70  		})
    71  	}
    72  }
    73  
    74  func TestCrossOriginProtectionTrustedOriginBypass(t *testing.T) {
    75  	protection := http.NewCrossOriginProtection()
    76  	err := protection.AddTrustedOrigin("https://trusted.example")
    77  	if err != nil {
    78  		t.Fatalf("AddTrustedOrigin: %v", err)
    79  	}
    80  	handler := protection.Handler(okHandler)
    81  
    82  	tests := []struct {
    83  		name           string
    84  		origin         string
    85  		secFetchSite   string
    86  		expectedStatus int
    87  	}{
    88  		{"trusted origin without sec-fetch-site", "https://trusted.example", "", http.StatusOK},
    89  		{"trusted origin with cross-site", "https://trusted.example", "cross-site", http.StatusOK},
    90  		{"untrusted origin without sec-fetch-site", "https://attacker.example", "", http.StatusForbidden},
    91  		{"untrusted origin with cross-site", "https://attacker.example", "cross-site", http.StatusForbidden},
    92  	}
    93  
    94  	for _, tc := range tests {
    95  		t.Run(tc.name, func(t *testing.T) {
    96  			req := httptestNewRequest("POST", "https://example.com/")
    97  			req.Header.Set("Origin", tc.origin)
    98  			if tc.secFetchSite != "" {
    99  				req.Header.Set("Sec-Fetch-Site", tc.secFetchSite)
   100  			}
   101  
   102  			w := httptest.NewRecorder()
   103  			handler.ServeHTTP(w, req)
   104  
   105  			if w.Code != tc.expectedStatus {
   106  				t.Errorf("got status %d, want %d", w.Code, tc.expectedStatus)
   107  			}
   108  		})
   109  	}
   110  }
   111  
   112  func TestCrossOriginProtectionPatternBypass(t *testing.T) {
   113  	protection := http.NewCrossOriginProtection()
   114  	protection.AddInsecureBypassPattern("/bypass/")
   115  	protection.AddInsecureBypassPattern("/only/{foo}")
   116  	protection.AddInsecureBypassPattern("/no-trailing")
   117  	protection.AddInsecureBypassPattern("/yes-trailing/")
   118  	protection.AddInsecureBypassPattern("PUT /put-only/")
   119  	protection.AddInsecureBypassPattern("GET /get-only/")
   120  	protection.AddInsecureBypassPattern("POST /post-only/")
   121  	handler := protection.Handler(okHandler)
   122  
   123  	tests := []struct {
   124  		name           string
   125  		path           string
   126  		secFetchSite   string
   127  		expectedStatus int
   128  	}{
   129  		{"bypass path without sec-fetch-site", "/bypass/", "", http.StatusOK},
   130  		{"bypass path with cross-site", "/bypass/", "cross-site", http.StatusOK},
   131  		{"non-bypass path without sec-fetch-site", "/api/", "", http.StatusForbidden},
   132  		{"non-bypass path with cross-site", "/api/", "cross-site", http.StatusForbidden},
   133  
   134  		{"redirect to bypass path without ..", "/foo/../bypass/bar", "", http.StatusForbidden},
   135  		{"redirect to bypass path with trailing slash", "/bypass", "", http.StatusForbidden},
   136  		{"redirect to non-bypass path with ..", "/foo/../api/bar", "", http.StatusForbidden},
   137  		{"redirect to non-bypass path with trailing slash", "/api", "", http.StatusForbidden},
   138  
   139  		{"wildcard bypass", "/only/123", "", http.StatusOK},
   140  		{"non-wildcard", "/only/123/foo", "", http.StatusForbidden},
   141  
   142  		// https://go.dev/issue/75054
   143  		{"no trailing slash exact match", "/no-trailing", "", http.StatusOK},
   144  		{"no trailing slash with slash", "/no-trailing/", "", http.StatusForbidden},
   145  		{"yes trailing slash exact match", "/yes-trailing/", "", http.StatusOK},
   146  		{"yes trailing slash without slash", "/yes-trailing", "", http.StatusForbidden},
   147  
   148  		{"method-specific hit", "/post-only/", "", http.StatusOK},
   149  		{"method-specific miss (PUT)", "/put-only/", "", http.StatusForbidden},
   150  		{"method-specific miss (GET)", "/get-only/", "", http.StatusForbidden},
   151  	}
   152  
   153  	for _, tc := range tests {
   154  		t.Run(tc.name, func(t *testing.T) {
   155  			req := httptestNewRequest("POST", "https://example.com"+tc.path)
   156  			req.Header.Set("Origin", "https://attacker.example")
   157  			if tc.secFetchSite != "" {
   158  				req.Header.Set("Sec-Fetch-Site", tc.secFetchSite)
   159  			}
   160  
   161  			w := httptest.NewRecorder()
   162  			handler.ServeHTTP(w, req)
   163  
   164  			if w.Code != tc.expectedStatus {
   165  				t.Errorf("got status %d, want %d", w.Code, tc.expectedStatus)
   166  			}
   167  		})
   168  	}
   169  }
   170  
   171  func TestCrossOriginProtectionSetDenyHandler(t *testing.T) {
   172  	protection := http.NewCrossOriginProtection()
   173  
   174  	handler := protection.Handler(okHandler)
   175  
   176  	req := httptestNewRequest("POST", "https://example.com/")
   177  	req.Header.Set("Sec-Fetch-Site", "cross-site")
   178  
   179  	w := httptest.NewRecorder()
   180  	handler.ServeHTTP(w, req)
   181  
   182  	if w.Code != http.StatusForbidden {
   183  		t.Errorf("got status %d, want %d", w.Code, http.StatusForbidden)
   184  	}
   185  
   186  	customErrHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   187  		w.WriteHeader(http.StatusTeapot)
   188  		io.WriteString(w, "custom error")
   189  	})
   190  	protection.SetDenyHandler(customErrHandler)
   191  
   192  	w = httptest.NewRecorder()
   193  	handler.ServeHTTP(w, req)
   194  
   195  	if w.Code != http.StatusTeapot {
   196  		t.Errorf("got status %d, want %d", w.Code, http.StatusTeapot)
   197  	}
   198  
   199  	if !strings.Contains(w.Body.String(), "custom error") {
   200  		t.Errorf("expected custom error message, got: %q", w.Body.String())
   201  	}
   202  
   203  	req = httptestNewRequest("GET", "https://example.com/")
   204  
   205  	w = httptest.NewRecorder()
   206  	handler.ServeHTTP(w, req)
   207  
   208  	if w.Code != http.StatusOK {
   209  		t.Errorf("got status %d, want %d", w.Code, http.StatusOK)
   210  	}
   211  
   212  	protection.SetDenyHandler(nil)
   213  
   214  	req = httptestNewRequest("POST", "https://example.com/")
   215  	req.Header.Set("Sec-Fetch-Site", "cross-site")
   216  
   217  	w = httptest.NewRecorder()
   218  	handler.ServeHTTP(w, req)
   219  
   220  	if w.Code != http.StatusForbidden {
   221  		t.Errorf("got status %d, want %d", w.Code, http.StatusForbidden)
   222  	}
   223  }
   224  
   225  func TestCrossOriginProtectionAddTrustedOriginErrors(t *testing.T) {
   226  	protection := http.NewCrossOriginProtection()
   227  
   228  	tests := []struct {
   229  		name    string
   230  		origin  string
   231  		wantErr bool
   232  	}{
   233  		{"valid origin", "https://example.com", false},
   234  		{"valid origin with port", "https://example.com:8080", false},
   235  		{"http origin", "http://example.com", false},
   236  		{"missing scheme", "example.com", true},
   237  		{"missing host", "https://", true},
   238  		{"trailing slash", "https://example.com/", true},
   239  		{"with path", "https://example.com/path", true},
   240  		{"with query", "https://example.com?query=value", true},
   241  		{"with fragment", "https://example.com#fragment", true},
   242  		{"invalid url", "https://ex ample.com", true},
   243  		{"empty string", "", true},
   244  		{"null", "null", true},
   245  	}
   246  
   247  	for _, tc := range tests {
   248  		t.Run(tc.name, func(t *testing.T) {
   249  			err := protection.AddTrustedOrigin(tc.origin)
   250  			if (err != nil) != tc.wantErr {
   251  				t.Errorf("AddTrustedOrigin(%q) error = %v, wantErr %v", tc.origin, err, tc.wantErr)
   252  			}
   253  		})
   254  	}
   255  }
   256  
   257  func TestCrossOriginProtectionAddingBypassesConcurrently(t *testing.T) {
   258  	protection := http.NewCrossOriginProtection()
   259  	handler := protection.Handler(okHandler)
   260  
   261  	req := httptestNewRequest("POST", "https://example.com/")
   262  	req.Header.Set("Origin", "https://concurrent.example")
   263  	req.Header.Set("Sec-Fetch-Site", "cross-site")
   264  
   265  	w := httptest.NewRecorder()
   266  	handler.ServeHTTP(w, req)
   267  
   268  	if w.Code != http.StatusForbidden {
   269  		t.Errorf("got status %d, want %d", w.Code, http.StatusForbidden)
   270  	}
   271  
   272  	start := make(chan struct{})
   273  	done := make(chan struct{})
   274  	go func() {
   275  		close(start)
   276  		defer close(done)
   277  		for range 10 {
   278  			w := httptest.NewRecorder()
   279  			handler.ServeHTTP(w, req)
   280  		}
   281  	}()
   282  
   283  	// Add bypasses while the requests are in flight.
   284  	<-start
   285  	protection.AddTrustedOrigin("https://concurrent.example")
   286  	protection.AddInsecureBypassPattern("/foo/")
   287  	<-done
   288  
   289  	w = httptest.NewRecorder()
   290  	handler.ServeHTTP(w, req)
   291  
   292  	if w.Code != http.StatusOK {
   293  		t.Errorf("After concurrent bypass addition, got status %d, want %d", w.Code, http.StatusOK)
   294  	}
   295  }
   296  
   297  func TestCrossOriginProtectionServer(t *testing.T) {
   298  	protection := http.NewCrossOriginProtection()
   299  	protection.AddTrustedOrigin("https://trusted.example")
   300  	protection.AddInsecureBypassPattern("/bypass/")
   301  	handler := protection.Handler(okHandler)
   302  
   303  	ts := httptest.NewServer(handler)
   304  	defer ts.Close()
   305  
   306  	tests := []struct {
   307  		name           string
   308  		method         string
   309  		url            string
   310  		origin         string
   311  		secFetchSite   string
   312  		expectedStatus int
   313  	}{
   314  		{"cross-site", "POST", ts.URL, "https://attacker.example", "cross-site", http.StatusForbidden},
   315  		{"same-origin", "POST", ts.URL, "", "same-origin", http.StatusOK},
   316  		{"origin matches host", "POST", ts.URL, ts.URL, "", http.StatusOK},
   317  		{"trusted origin", "POST", ts.URL, "https://trusted.example", "", http.StatusOK},
   318  		{"untrusted origin", "POST", ts.URL, "https://attacker.example", "", http.StatusForbidden},
   319  		{"bypass path", "POST", ts.URL + "/bypass/", "https://attacker.example", "", http.StatusOK},
   320  	}
   321  
   322  	for _, tc := range tests {
   323  		t.Run(tc.name, func(t *testing.T) {
   324  			req, err := http.NewRequest(tc.method, tc.url, nil)
   325  			if err != nil {
   326  				t.Fatalf("NewRequest: %v", err)
   327  			}
   328  			if tc.origin != "" {
   329  				req.Header.Set("Origin", tc.origin)
   330  			}
   331  			if tc.secFetchSite != "" {
   332  				req.Header.Set("Sec-Fetch-Site", tc.secFetchSite)
   333  			}
   334  			client := &http.Client{}
   335  			resp, err := client.Do(req)
   336  			if err != nil {
   337  				t.Fatalf("Do: %v", err)
   338  			}
   339  			defer resp.Body.Close()
   340  			if resp.StatusCode != tc.expectedStatus {
   341  				t.Errorf("got status %d, want %d", resp.StatusCode, tc.expectedStatus)
   342  			}
   343  		})
   344  	}
   345  }
   346  

View as plain text