Black Lives Matter. Support the Equal Justice Initiative.

Source file src/net/http/httputil/reverseproxy_test.go

Documentation: net/http/httputil

     1  // Copyright 2011 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  // Reverse proxy tests.
     6  
     7  package httputil
     8  
     9  import (
    10  	"bufio"
    11  	"bytes"
    12  	"context"
    13  	"errors"
    14  	"fmt"
    15  	"io"
    16  	"log"
    17  	"net/http"
    18  	"net/http/httptest"
    19  	"net/http/internal/ascii"
    20  	"net/url"
    21  	"os"
    22  	"reflect"
    23  	"sort"
    24  	"strconv"
    25  	"strings"
    26  	"sync"
    27  	"testing"
    28  	"time"
    29  )
    30  
    31  const fakeHopHeader = "X-Fake-Hop-Header-For-Test"
    32  
    33  func init() {
    34  	inOurTests = true
    35  	hopHeaders = append(hopHeaders, fakeHopHeader)
    36  }
    37  
    38  func TestReverseProxy(t *testing.T) {
    39  	const backendResponse = "I am the backend"
    40  	const backendStatus = 404
    41  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    42  		if r.Method == "GET" && r.FormValue("mode") == "hangup" {
    43  			c, _, _ := w.(http.Hijacker).Hijack()
    44  			c.Close()
    45  			return
    46  		}
    47  		if len(r.TransferEncoding) > 0 {
    48  			t.Errorf("backend got unexpected TransferEncoding: %v", r.TransferEncoding)
    49  		}
    50  		if r.Header.Get("X-Forwarded-For") == "" {
    51  			t.Errorf("didn't get X-Forwarded-For header")
    52  		}
    53  		if c := r.Header.Get("Connection"); c != "" {
    54  			t.Errorf("handler got Connection header value %q", c)
    55  		}
    56  		if c := r.Header.Get("Te"); c != "trailers" {
    57  			t.Errorf("handler got Te header value %q; want 'trailers'", c)
    58  		}
    59  		if c := r.Header.Get("Upgrade"); c != "" {
    60  			t.Errorf("handler got Upgrade header value %q", c)
    61  		}
    62  		if c := r.Header.Get("Proxy-Connection"); c != "" {
    63  			t.Errorf("handler got Proxy-Connection header value %q", c)
    64  		}
    65  		if g, e := r.Host, "some-name"; g != e {
    66  			t.Errorf("backend got Host header %q, want %q", g, e)
    67  		}
    68  		w.Header().Set("Trailers", "not a special header field name")
    69  		w.Header().Set("Trailer", "X-Trailer")
    70  		w.Header().Set("X-Foo", "bar")
    71  		w.Header().Set("Upgrade", "foo")
    72  		w.Header().Set(fakeHopHeader, "foo")
    73  		w.Header().Add("X-Multi-Value", "foo")
    74  		w.Header().Add("X-Multi-Value", "bar")
    75  		http.SetCookie(w, &http.Cookie{Name: "flavor", Value: "chocolateChip"})
    76  		w.WriteHeader(backendStatus)
    77  		w.Write([]byte(backendResponse))
    78  		w.Header().Set("X-Trailer", "trailer_value")
    79  		w.Header().Set(http.TrailerPrefix+"X-Unannounced-Trailer", "unannounced_trailer_value")
    80  	}))
    81  	defer backend.Close()
    82  	backendURL, err := url.Parse(backend.URL)
    83  	if err != nil {
    84  		t.Fatal(err)
    85  	}
    86  	proxyHandler := NewSingleHostReverseProxy(backendURL)
    87  	proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests
    88  	frontend := httptest.NewServer(proxyHandler)
    89  	defer frontend.Close()
    90  	frontendClient := frontend.Client()
    91  
    92  	getReq, _ := http.NewRequest("GET", frontend.URL, nil)
    93  	getReq.Host = "some-name"
    94  	getReq.Header.Set("Connection", "close, TE")
    95  	getReq.Header.Add("Te", "foo")
    96  	getReq.Header.Add("Te", "bar, trailers")
    97  	getReq.Header.Set("Proxy-Connection", "should be deleted")
    98  	getReq.Header.Set("Upgrade", "foo")
    99  	getReq.Close = true
   100  	res, err := frontendClient.Do(getReq)
   101  	if err != nil {
   102  		t.Fatalf("Get: %v", err)
   103  	}
   104  	if g, e := res.StatusCode, backendStatus; g != e {
   105  		t.Errorf("got res.StatusCode %d; expected %d", g, e)
   106  	}
   107  	if g, e := res.Header.Get("X-Foo"), "bar"; g != e {
   108  		t.Errorf("got X-Foo %q; expected %q", g, e)
   109  	}
   110  	if c := res.Header.Get(fakeHopHeader); c != "" {
   111  		t.Errorf("got %s header value %q", fakeHopHeader, c)
   112  	}
   113  	if g, e := res.Header.Get("Trailers"), "not a special header field name"; g != e {
   114  		t.Errorf("header Trailers = %q; want %q", g, e)
   115  	}
   116  	if g, e := len(res.Header["X-Multi-Value"]), 2; g != e {
   117  		t.Errorf("got %d X-Multi-Value header values; expected %d", g, e)
   118  	}
   119  	if g, e := len(res.Header["Set-Cookie"]), 1; g != e {
   120  		t.Fatalf("got %d SetCookies, want %d", g, e)
   121  	}
   122  	if g, e := res.Trailer, (http.Header{"X-Trailer": nil}); !reflect.DeepEqual(g, e) {
   123  		t.Errorf("before reading body, Trailer = %#v; want %#v", g, e)
   124  	}
   125  	if cookie := res.Cookies()[0]; cookie.Name != "flavor" {
   126  		t.Errorf("unexpected cookie %q", cookie.Name)
   127  	}
   128  	bodyBytes, _ := io.ReadAll(res.Body)
   129  	if g, e := string(bodyBytes), backendResponse; g != e {
   130  		t.Errorf("got body %q; expected %q", g, e)
   131  	}
   132  	if g, e := res.Trailer.Get("X-Trailer"), "trailer_value"; g != e {
   133  		t.Errorf("Trailer(X-Trailer) = %q ; want %q", g, e)
   134  	}
   135  	if g, e := res.Trailer.Get("X-Unannounced-Trailer"), "unannounced_trailer_value"; g != e {
   136  		t.Errorf("Trailer(X-Unannounced-Trailer) = %q ; want %q", g, e)
   137  	}
   138  
   139  	// Test that a backend failing to be reached or one which doesn't return
   140  	// a response results in a StatusBadGateway.
   141  	getReq, _ = http.NewRequest("GET", frontend.URL+"/?mode=hangup", nil)
   142  	getReq.Close = true
   143  	res, err = frontendClient.Do(getReq)
   144  	if err != nil {
   145  		t.Fatal(err)
   146  	}
   147  	res.Body.Close()
   148  	if res.StatusCode != http.StatusBadGateway {
   149  		t.Errorf("request to bad proxy = %v; want 502 StatusBadGateway", res.Status)
   150  	}
   151  
   152  }
   153  
   154  // Issue 16875: remove any proxied headers mentioned in the "Connection"
   155  // header value.
   156  func TestReverseProxyStripHeadersPresentInConnection(t *testing.T) {
   157  	const fakeConnectionToken = "X-Fake-Connection-Token"
   158  	const backendResponse = "I am the backend"
   159  
   160  	// someConnHeader is some arbitrary header to be declared as a hop-by-hop header
   161  	// in the Request's Connection header.
   162  	const someConnHeader = "X-Some-Conn-Header"
   163  
   164  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   165  		if c := r.Header.Get("Connection"); c != "" {
   166  			t.Errorf("handler got header %q = %q; want empty", "Connection", c)
   167  		}
   168  		if c := r.Header.Get(fakeConnectionToken); c != "" {
   169  			t.Errorf("handler got header %q = %q; want empty", fakeConnectionToken, c)
   170  		}
   171  		if c := r.Header.Get(someConnHeader); c != "" {
   172  			t.Errorf("handler got header %q = %q; want empty", someConnHeader, c)
   173  		}
   174  		w.Header().Add("Connection", "Upgrade, "+fakeConnectionToken)
   175  		w.Header().Add("Connection", someConnHeader)
   176  		w.Header().Set(someConnHeader, "should be deleted")
   177  		w.Header().Set(fakeConnectionToken, "should be deleted")
   178  		io.WriteString(w, backendResponse)
   179  	}))
   180  	defer backend.Close()
   181  	backendURL, err := url.Parse(backend.URL)
   182  	if err != nil {
   183  		t.Fatal(err)
   184  	}
   185  	proxyHandler := NewSingleHostReverseProxy(backendURL)
   186  	frontend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   187  		proxyHandler.ServeHTTP(w, r)
   188  		if c := r.Header.Get(someConnHeader); c != "should be deleted" {
   189  			t.Errorf("handler modified header %q = %q; want %q", someConnHeader, c, "should be deleted")
   190  		}
   191  		if c := r.Header.Get(fakeConnectionToken); c != "should be deleted" {
   192  			t.Errorf("handler modified header %q = %q; want %q", fakeConnectionToken, c, "should be deleted")
   193  		}
   194  		c := r.Header["Connection"]
   195  		var cf []string
   196  		for _, f := range c {
   197  			for _, sf := range strings.Split(f, ",") {
   198  				if sf = strings.TrimSpace(sf); sf != "" {
   199  					cf = append(cf, sf)
   200  				}
   201  			}
   202  		}
   203  		sort.Strings(cf)
   204  		expectedValues := []string{"Upgrade", someConnHeader, fakeConnectionToken}
   205  		sort.Strings(expectedValues)
   206  		if !reflect.DeepEqual(cf, expectedValues) {
   207  			t.Errorf("handler modified header %q = %q; want %q", "Connection", cf, expectedValues)
   208  		}
   209  	}))
   210  	defer frontend.Close()
   211  
   212  	getReq, _ := http.NewRequest("GET", frontend.URL, nil)
   213  	getReq.Header.Add("Connection", "Upgrade, "+fakeConnectionToken)
   214  	getReq.Header.Add("Connection", someConnHeader)
   215  	getReq.Header.Set(someConnHeader, "should be deleted")
   216  	getReq.Header.Set(fakeConnectionToken, "should be deleted")
   217  	res, err := frontend.Client().Do(getReq)
   218  	if err != nil {
   219  		t.Fatalf("Get: %v", err)
   220  	}
   221  	defer res.Body.Close()
   222  	bodyBytes, err := io.ReadAll(res.Body)
   223  	if err != nil {
   224  		t.Fatalf("reading body: %v", err)
   225  	}
   226  	if got, want := string(bodyBytes), backendResponse; got != want {
   227  		t.Errorf("got body %q; want %q", got, want)
   228  	}
   229  	if c := res.Header.Get("Connection"); c != "" {
   230  		t.Errorf("handler got header %q = %q; want empty", "Connection", c)
   231  	}
   232  	if c := res.Header.Get(someConnHeader); c != "" {
   233  		t.Errorf("handler got header %q = %q; want empty", someConnHeader, c)
   234  	}
   235  	if c := res.Header.Get(fakeConnectionToken); c != "" {
   236  		t.Errorf("handler got header %q = %q; want empty", fakeConnectionToken, c)
   237  	}
   238  }
   239  
   240  func TestReverseProxyStripEmptyConnection(t *testing.T) {
   241  	// See Issue 46313.
   242  	const backendResponse = "I am the backend"
   243  
   244  	// someConnHeader is some arbitrary header to be declared as a hop-by-hop header
   245  	// in the Request's Connection header.
   246  	const someConnHeader = "X-Some-Conn-Header"
   247  
   248  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   249  		if c := r.Header.Values("Connection"); len(c) != 0 {
   250  			t.Errorf("handler got header %q = %v; want empty", "Connection", c)
   251  		}
   252  		if c := r.Header.Get(someConnHeader); c != "" {
   253  			t.Errorf("handler got header %q = %q; want empty", someConnHeader, c)
   254  		}
   255  		w.Header().Add("Connection", "")
   256  		w.Header().Add("Connection", someConnHeader)
   257  		w.Header().Set(someConnHeader, "should be deleted")
   258  		io.WriteString(w, backendResponse)
   259  	}))
   260  	defer backend.Close()
   261  	backendURL, err := url.Parse(backend.URL)
   262  	if err != nil {
   263  		t.Fatal(err)
   264  	}
   265  	proxyHandler := NewSingleHostReverseProxy(backendURL)
   266  	frontend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   267  		proxyHandler.ServeHTTP(w, r)
   268  		if c := r.Header.Get(someConnHeader); c != "should be deleted" {
   269  			t.Errorf("handler modified header %q = %q; want %q", someConnHeader, c, "should be deleted")
   270  		}
   271  	}))
   272  	defer frontend.Close()
   273  
   274  	getReq, _ := http.NewRequest("GET", frontend.URL, nil)
   275  	getReq.Header.Add("Connection", "")
   276  	getReq.Header.Add("Connection", someConnHeader)
   277  	getReq.Header.Set(someConnHeader, "should be deleted")
   278  	res, err := frontend.Client().Do(getReq)
   279  	if err != nil {
   280  		t.Fatalf("Get: %v", err)
   281  	}
   282  	defer res.Body.Close()
   283  	bodyBytes, err := io.ReadAll(res.Body)
   284  	if err != nil {
   285  		t.Fatalf("reading body: %v", err)
   286  	}
   287  	if got, want := string(bodyBytes), backendResponse; got != want {
   288  		t.Errorf("got body %q; want %q", got, want)
   289  	}
   290  	if c := res.Header.Get("Connection"); c != "" {
   291  		t.Errorf("handler got header %q = %q; want empty", "Connection", c)
   292  	}
   293  	if c := res.Header.Get(someConnHeader); c != "" {
   294  		t.Errorf("handler got header %q = %q; want empty", someConnHeader, c)
   295  	}
   296  }
   297  
   298  func TestXForwardedFor(t *testing.T) {
   299  	const prevForwardedFor = "client ip"
   300  	const backendResponse = "I am the backend"
   301  	const backendStatus = 404
   302  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   303  		if r.Header.Get("X-Forwarded-For") == "" {
   304  			t.Errorf("didn't get X-Forwarded-For header")
   305  		}
   306  		if !strings.Contains(r.Header.Get("X-Forwarded-For"), prevForwardedFor) {
   307  			t.Errorf("X-Forwarded-For didn't contain prior data")
   308  		}
   309  		w.WriteHeader(backendStatus)
   310  		w.Write([]byte(backendResponse))
   311  	}))
   312  	defer backend.Close()
   313  	backendURL, err := url.Parse(backend.URL)
   314  	if err != nil {
   315  		t.Fatal(err)
   316  	}
   317  	proxyHandler := NewSingleHostReverseProxy(backendURL)
   318  	frontend := httptest.NewServer(proxyHandler)
   319  	defer frontend.Close()
   320  
   321  	getReq, _ := http.NewRequest("GET", frontend.URL, nil)
   322  	getReq.Host = "some-name"
   323  	getReq.Header.Set("Connection", "close")
   324  	getReq.Header.Set("X-Forwarded-For", prevForwardedFor)
   325  	getReq.Close = true
   326  	res, err := frontend.Client().Do(getReq)
   327  	if err != nil {
   328  		t.Fatalf("Get: %v", err)
   329  	}
   330  	if g, e := res.StatusCode, backendStatus; g != e {
   331  		t.Errorf("got res.StatusCode %d; expected %d", g, e)
   332  	}
   333  	bodyBytes, _ := io.ReadAll(res.Body)
   334  	if g, e := string(bodyBytes), backendResponse; g != e {
   335  		t.Errorf("got body %q; expected %q", g, e)
   336  	}
   337  }
   338  
   339  // Issue 38079: don't append to X-Forwarded-For if it's present but nil
   340  func TestXForwardedFor_Omit(t *testing.T) {
   341  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   342  		if v := r.Header.Get("X-Forwarded-For"); v != "" {
   343  			t.Errorf("got X-Forwarded-For header: %q", v)
   344  		}
   345  		w.Write([]byte("hi"))
   346  	}))
   347  	defer backend.Close()
   348  	backendURL, err := url.Parse(backend.URL)
   349  	if err != nil {
   350  		t.Fatal(err)
   351  	}
   352  	proxyHandler := NewSingleHostReverseProxy(backendURL)
   353  	frontend := httptest.NewServer(proxyHandler)
   354  	defer frontend.Close()
   355  
   356  	oldDirector := proxyHandler.Director
   357  	proxyHandler.Director = func(r *http.Request) {
   358  		r.Header["X-Forwarded-For"] = nil
   359  		oldDirector(r)
   360  	}
   361  
   362  	getReq, _ := http.NewRequest("GET", frontend.URL, nil)
   363  	getReq.Host = "some-name"
   364  	getReq.Close = true
   365  	res, err := frontend.Client().Do(getReq)
   366  	if err != nil {
   367  		t.Fatalf("Get: %v", err)
   368  	}
   369  	res.Body.Close()
   370  }
   371  
   372  var proxyQueryTests = []struct {
   373  	baseSuffix string // suffix to add to backend URL
   374  	reqSuffix  string // suffix to add to frontend's request URL
   375  	want       string // what backend should see for final request URL (without ?)
   376  }{
   377  	{"", "", ""},
   378  	{"?sta=tic", "?us=er", "sta=tic&us=er"},
   379  	{"", "?us=er", "us=er"},
   380  	{"?sta=tic", "", "sta=tic"},
   381  }
   382  
   383  func TestReverseProxyQuery(t *testing.T) {
   384  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   385  		w.Header().Set("X-Got-Query", r.URL.RawQuery)
   386  		w.Write([]byte("hi"))
   387  	}))
   388  	defer backend.Close()
   389  
   390  	for i, tt := range proxyQueryTests {
   391  		backendURL, err := url.Parse(backend.URL + tt.baseSuffix)
   392  		if err != nil {
   393  			t.Fatal(err)
   394  		}
   395  		frontend := httptest.NewServer(NewSingleHostReverseProxy(backendURL))
   396  		req, _ := http.NewRequest("GET", frontend.URL+tt.reqSuffix, nil)
   397  		req.Close = true
   398  		res, err := frontend.Client().Do(req)
   399  		if err != nil {
   400  			t.Fatalf("%d. Get: %v", i, err)
   401  		}
   402  		if g, e := res.Header.Get("X-Got-Query"), tt.want; g != e {
   403  			t.Errorf("%d. got query %q; expected %q", i, g, e)
   404  		}
   405  		res.Body.Close()
   406  		frontend.Close()
   407  	}
   408  }
   409  
   410  func TestReverseProxyFlushInterval(t *testing.T) {
   411  	const expected = "hi"
   412  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   413  		w.Write([]byte(expected))
   414  	}))
   415  	defer backend.Close()
   416  
   417  	backendURL, err := url.Parse(backend.URL)
   418  	if err != nil {
   419  		t.Fatal(err)
   420  	}
   421  
   422  	proxyHandler := NewSingleHostReverseProxy(backendURL)
   423  	proxyHandler.FlushInterval = time.Microsecond
   424  
   425  	frontend := httptest.NewServer(proxyHandler)
   426  	defer frontend.Close()
   427  
   428  	req, _ := http.NewRequest("GET", frontend.URL, nil)
   429  	req.Close = true
   430  	res, err := frontend.Client().Do(req)
   431  	if err != nil {
   432  		t.Fatalf("Get: %v", err)
   433  	}
   434  	defer res.Body.Close()
   435  	if bodyBytes, _ := io.ReadAll(res.Body); string(bodyBytes) != expected {
   436  		t.Errorf("got body %q; expected %q", bodyBytes, expected)
   437  	}
   438  }
   439  
   440  func TestReverseProxyFlushIntervalHeaders(t *testing.T) {
   441  	const expected = "hi"
   442  	stopCh := make(chan struct{})
   443  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   444  		w.Header().Add("MyHeader", expected)
   445  		w.WriteHeader(200)
   446  		w.(http.Flusher).Flush()
   447  		<-stopCh
   448  	}))
   449  	defer backend.Close()
   450  	defer close(stopCh)
   451  
   452  	backendURL, err := url.Parse(backend.URL)
   453  	if err != nil {
   454  		t.Fatal(err)
   455  	}
   456  
   457  	proxyHandler := NewSingleHostReverseProxy(backendURL)
   458  	proxyHandler.FlushInterval = time.Microsecond
   459  
   460  	frontend := httptest.NewServer(proxyHandler)
   461  	defer frontend.Close()
   462  
   463  	req, _ := http.NewRequest("GET", frontend.URL, nil)
   464  	req.Close = true
   465  
   466  	ctx, cancel := context.WithTimeout(req.Context(), 10*time.Second)
   467  	defer cancel()
   468  	req = req.WithContext(ctx)
   469  
   470  	res, err := frontend.Client().Do(req)
   471  	if err != nil {
   472  		t.Fatalf("Get: %v", err)
   473  	}
   474  	defer res.Body.Close()
   475  
   476  	if res.Header.Get("MyHeader") != expected {
   477  		t.Errorf("got header %q; expected %q", res.Header.Get("MyHeader"), expected)
   478  	}
   479  }
   480  
   481  func TestReverseProxyCancellation(t *testing.T) {
   482  	const backendResponse = "I am the backend"
   483  
   484  	reqInFlight := make(chan struct{})
   485  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   486  		close(reqInFlight) // cause the client to cancel its request
   487  
   488  		select {
   489  		case <-time.After(10 * time.Second):
   490  			// Note: this should only happen in broken implementations, and the
   491  			// closenotify case should be instantaneous.
   492  			t.Error("Handler never saw CloseNotify")
   493  			return
   494  		case <-w.(http.CloseNotifier).CloseNotify():
   495  		}
   496  
   497  		w.WriteHeader(http.StatusOK)
   498  		w.Write([]byte(backendResponse))
   499  	}))
   500  
   501  	defer backend.Close()
   502  
   503  	backend.Config.ErrorLog = log.New(io.Discard, "", 0)
   504  
   505  	backendURL, err := url.Parse(backend.URL)
   506  	if err != nil {
   507  		t.Fatal(err)
   508  	}
   509  
   510  	proxyHandler := NewSingleHostReverseProxy(backendURL)
   511  
   512  	// Discards errors of the form:
   513  	// http: proxy error: read tcp 127.0.0.1:44643: use of closed network connection
   514  	proxyHandler.ErrorLog = log.New(io.Discard, "", 0)
   515  
   516  	frontend := httptest.NewServer(proxyHandler)
   517  	defer frontend.Close()
   518  	frontendClient := frontend.Client()
   519  
   520  	getReq, _ := http.NewRequest("GET", frontend.URL, nil)
   521  	go func() {
   522  		<-reqInFlight
   523  		frontendClient.Transport.(*http.Transport).CancelRequest(getReq)
   524  	}()
   525  	res, err := frontendClient.Do(getReq)
   526  	if res != nil {
   527  		t.Errorf("got response %v; want nil", res.Status)
   528  	}
   529  	if err == nil {
   530  		// This should be an error like:
   531  		// Get "http://127.0.0.1:58079": read tcp 127.0.0.1:58079:
   532  		//    use of closed network connection
   533  		t.Error("Server.Client().Do() returned nil error; want non-nil error")
   534  	}
   535  }
   536  
   537  func req(t *testing.T, v string) *http.Request {
   538  	req, err := http.ReadRequest(bufio.NewReader(strings.NewReader(v)))
   539  	if err != nil {
   540  		t.Fatal(err)
   541  	}
   542  	return req
   543  }
   544  
   545  // Issue 12344
   546  func TestNilBody(t *testing.T) {
   547  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   548  		w.Write([]byte("hi"))
   549  	}))
   550  	defer backend.Close()
   551  
   552  	frontend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
   553  		backURL, _ := url.Parse(backend.URL)
   554  		rp := NewSingleHostReverseProxy(backURL)
   555  		r := req(t, "GET / HTTP/1.0\r\n\r\n")
   556  		r.Body = nil // this accidentally worked in Go 1.4 and below, so keep it working
   557  		rp.ServeHTTP(w, r)
   558  	}))
   559  	defer frontend.Close()
   560  
   561  	res, err := http.Get(frontend.URL)
   562  	if err != nil {
   563  		t.Fatal(err)
   564  	}
   565  	defer res.Body.Close()
   566  	slurp, err := io.ReadAll(res.Body)
   567  	if err != nil {
   568  		t.Fatal(err)
   569  	}
   570  	if string(slurp) != "hi" {
   571  		t.Errorf("Got %q; want %q", slurp, "hi")
   572  	}
   573  }
   574  
   575  // Issue 15524
   576  func TestUserAgentHeader(t *testing.T) {
   577  	const explicitUA = "explicit UA"
   578  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   579  		if r.URL.Path == "/noua" {
   580  			if c := r.Header.Get("User-Agent"); c != "" {
   581  				t.Errorf("handler got non-empty User-Agent header %q", c)
   582  			}
   583  			return
   584  		}
   585  		if c := r.Header.Get("User-Agent"); c != explicitUA {
   586  			t.Errorf("handler got unexpected User-Agent header %q", c)
   587  		}
   588  	}))
   589  	defer backend.Close()
   590  	backendURL, err := url.Parse(backend.URL)
   591  	if err != nil {
   592  		t.Fatal(err)
   593  	}
   594  	proxyHandler := NewSingleHostReverseProxy(backendURL)
   595  	proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests
   596  	frontend := httptest.NewServer(proxyHandler)
   597  	defer frontend.Close()
   598  	frontendClient := frontend.Client()
   599  
   600  	getReq, _ := http.NewRequest("GET", frontend.URL, nil)
   601  	getReq.Header.Set("User-Agent", explicitUA)
   602  	getReq.Close = true
   603  	res, err := frontendClient.Do(getReq)
   604  	if err != nil {
   605  		t.Fatalf("Get: %v", err)
   606  	}
   607  	res.Body.Close()
   608  
   609  	getReq, _ = http.NewRequest("GET", frontend.URL+"/noua", nil)
   610  	getReq.Header.Set("User-Agent", "")
   611  	getReq.Close = true
   612  	res, err = frontendClient.Do(getReq)
   613  	if err != nil {
   614  		t.Fatalf("Get: %v", err)
   615  	}
   616  	res.Body.Close()
   617  }
   618  
   619  type bufferPool struct {
   620  	get func() []byte
   621  	put func([]byte)
   622  }
   623  
   624  func (bp bufferPool) Get() []byte  { return bp.get() }
   625  func (bp bufferPool) Put(v []byte) { bp.put(v) }
   626  
   627  func TestReverseProxyGetPutBuffer(t *testing.T) {
   628  	const msg = "hi"
   629  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   630  		io.WriteString(w, msg)
   631  	}))
   632  	defer backend.Close()
   633  
   634  	backendURL, err := url.Parse(backend.URL)
   635  	if err != nil {
   636  		t.Fatal(err)
   637  	}
   638  
   639  	var (
   640  		mu  sync.Mutex
   641  		log []string
   642  	)
   643  	addLog := func(event string) {
   644  		mu.Lock()
   645  		defer mu.Unlock()
   646  		log = append(log, event)
   647  	}
   648  	rp := NewSingleHostReverseProxy(backendURL)
   649  	const size = 1234
   650  	rp.BufferPool = bufferPool{
   651  		get: func() []byte {
   652  			addLog("getBuf")
   653  			return make([]byte, size)
   654  		},
   655  		put: func(p []byte) {
   656  			addLog("putBuf-" + strconv.Itoa(len(p)))
   657  		},
   658  	}
   659  	frontend := httptest.NewServer(rp)
   660  	defer frontend.Close()
   661  
   662  	req, _ := http.NewRequest("GET", frontend.URL, nil)
   663  	req.Close = true
   664  	res, err := frontend.Client().Do(req)
   665  	if err != nil {
   666  		t.Fatalf("Get: %v", err)
   667  	}
   668  	slurp, err := io.ReadAll(res.Body)
   669  	res.Body.Close()
   670  	if err != nil {
   671  		t.Fatalf("reading body: %v", err)
   672  	}
   673  	if string(slurp) != msg {
   674  		t.Errorf("msg = %q; want %q", slurp, msg)
   675  	}
   676  	wantLog := []string{"getBuf", "putBuf-" + strconv.Itoa(size)}
   677  	mu.Lock()
   678  	defer mu.Unlock()
   679  	if !reflect.DeepEqual(log, wantLog) {
   680  		t.Errorf("Log events = %q; want %q", log, wantLog)
   681  	}
   682  }
   683  
   684  func TestReverseProxy_Post(t *testing.T) {
   685  	const backendResponse = "I am the backend"
   686  	const backendStatus = 200
   687  	var requestBody = bytes.Repeat([]byte("a"), 1<<20)
   688  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   689  		slurp, err := io.ReadAll(r.Body)
   690  		if err != nil {
   691  			t.Errorf("Backend body read = %v", err)
   692  		}
   693  		if len(slurp) != len(requestBody) {
   694  			t.Errorf("Backend read %d request body bytes; want %d", len(slurp), len(requestBody))
   695  		}
   696  		if !bytes.Equal(slurp, requestBody) {
   697  			t.Error("Backend read wrong request body.") // 1MB; omitting details
   698  		}
   699  		w.Write([]byte(backendResponse))
   700  	}))
   701  	defer backend.Close()
   702  	backendURL, err := url.Parse(backend.URL)
   703  	if err != nil {
   704  		t.Fatal(err)
   705  	}
   706  	proxyHandler := NewSingleHostReverseProxy(backendURL)
   707  	frontend := httptest.NewServer(proxyHandler)
   708  	defer frontend.Close()
   709  
   710  	postReq, _ := http.NewRequest("POST", frontend.URL, bytes.NewReader(requestBody))
   711  	res, err := frontend.Client().Do(postReq)
   712  	if err != nil {
   713  		t.Fatalf("Do: %v", err)
   714  	}
   715  	if g, e := res.StatusCode, backendStatus; g != e {
   716  		t.Errorf("got res.StatusCode %d; expected %d", g, e)
   717  	}
   718  	bodyBytes, _ := io.ReadAll(res.Body)
   719  	if g, e := string(bodyBytes), backendResponse; g != e {
   720  		t.Errorf("got body %q; expected %q", g, e)
   721  	}
   722  }
   723  
   724  type RoundTripperFunc func(*http.Request) (*http.Response, error)
   725  
   726  func (fn RoundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) {
   727  	return fn(req)
   728  }
   729  
   730  // Issue 16036: send a Request with a nil Body when possible
   731  func TestReverseProxy_NilBody(t *testing.T) {
   732  	backendURL, _ := url.Parse("http://fake.tld/")
   733  	proxyHandler := NewSingleHostReverseProxy(backendURL)
   734  	proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests
   735  	proxyHandler.Transport = RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
   736  		if req.Body != nil {
   737  			t.Error("Body != nil; want a nil Body")
   738  		}
   739  		return nil, errors.New("done testing the interesting part; so force a 502 Gateway error")
   740  	})
   741  	frontend := httptest.NewServer(proxyHandler)
   742  	defer frontend.Close()
   743  
   744  	res, err := frontend.Client().Get(frontend.URL)
   745  	if err != nil {
   746  		t.Fatal(err)
   747  	}
   748  	defer res.Body.Close()
   749  	if res.StatusCode != 502 {
   750  		t.Errorf("status code = %v; want 502 (Gateway Error)", res.Status)
   751  	}
   752  }
   753  
   754  // Issue 33142: always allocate the request headers
   755  func TestReverseProxy_AllocatedHeader(t *testing.T) {
   756  	proxyHandler := new(ReverseProxy)
   757  	proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests
   758  	proxyHandler.Director = func(*http.Request) {}     // noop
   759  	proxyHandler.Transport = RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
   760  		if req.Header == nil {
   761  			t.Error("Header == nil; want a non-nil Header")
   762  		}
   763  		return nil, errors.New("done testing the interesting part; so force a 502 Gateway error")
   764  	})
   765  
   766  	proxyHandler.ServeHTTP(httptest.NewRecorder(), &http.Request{
   767  		Method:     "GET",
   768  		URL:        &url.URL{Scheme: "http", Host: "fake.tld", Path: "/"},
   769  		Proto:      "HTTP/1.0",
   770  		ProtoMajor: 1,
   771  	})
   772  }
   773  
   774  // Issue 14237. Test ModifyResponse and that an error from it
   775  // causes the proxy to return StatusBadGateway, or StatusOK otherwise.
   776  func TestReverseProxyModifyResponse(t *testing.T) {
   777  	backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   778  		w.Header().Add("X-Hit-Mod", fmt.Sprintf("%v", r.URL.Path == "/mod"))
   779  	}))
   780  	defer backendServer.Close()
   781  
   782  	rpURL, _ := url.Parse(backendServer.URL)
   783  	rproxy := NewSingleHostReverseProxy(rpURL)
   784  	rproxy.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests
   785  	rproxy.ModifyResponse = func(resp *http.Response) error {
   786  		if resp.Header.Get("X-Hit-Mod") != "true" {
   787  			return fmt.Errorf("tried to by-pass proxy")
   788  		}
   789  		return nil
   790  	}
   791  
   792  	frontendProxy := httptest.NewServer(rproxy)
   793  	defer frontendProxy.Close()
   794  
   795  	tests := []struct {
   796  		url      string
   797  		wantCode int
   798  	}{
   799  		{frontendProxy.URL + "/mod", http.StatusOK},
   800  		{frontendProxy.URL + "/schedule", http.StatusBadGateway},
   801  	}
   802  
   803  	for i, tt := range tests {
   804  		resp, err := http.Get(tt.url)
   805  		if err != nil {
   806  			t.Fatalf("failed to reach proxy: %v", err)
   807  		}
   808  		if g, e := resp.StatusCode, tt.wantCode; g != e {
   809  			t.Errorf("#%d: got res.StatusCode %d; expected %d", i, g, e)
   810  		}
   811  		resp.Body.Close()
   812  	}
   813  }
   814  
   815  type failingRoundTripper struct{}
   816  
   817  func (failingRoundTripper) RoundTrip(*http.Request) (*http.Response, error) {
   818  	return nil, errors.New("some error")
   819  }
   820  
   821  type staticResponseRoundTripper struct{ res *http.Response }
   822  
   823  func (rt staticResponseRoundTripper) RoundTrip(*http.Request) (*http.Response, error) {
   824  	return rt.res, nil
   825  }
   826  
   827  func TestReverseProxyErrorHandler(t *testing.T) {
   828  	tests := []struct {
   829  		name           string
   830  		wantCode       int
   831  		errorHandler   func(http.ResponseWriter, *http.Request, error)
   832  		transport      http.RoundTripper // defaults to failingRoundTripper
   833  		modifyResponse func(*http.Response) error
   834  	}{
   835  		{
   836  			name:     "default",
   837  			wantCode: http.StatusBadGateway,
   838  		},
   839  		{
   840  			name:         "errorhandler",
   841  			wantCode:     http.StatusTeapot,
   842  			errorHandler: func(rw http.ResponseWriter, req *http.Request, err error) { rw.WriteHeader(http.StatusTeapot) },
   843  		},
   844  		{
   845  			name: "modifyresponse_noerr",
   846  			transport: staticResponseRoundTripper{
   847  				&http.Response{StatusCode: 345, Body: http.NoBody},
   848  			},
   849  			modifyResponse: func(res *http.Response) error {
   850  				res.StatusCode++
   851  				return nil
   852  			},
   853  			errorHandler: func(rw http.ResponseWriter, req *http.Request, err error) { rw.WriteHeader(http.StatusTeapot) },
   854  			wantCode:     346,
   855  		},
   856  		{
   857  			name: "modifyresponse_err",
   858  			transport: staticResponseRoundTripper{
   859  				&http.Response{StatusCode: 345, Body: http.NoBody},
   860  			},
   861  			modifyResponse: func(res *http.Response) error {
   862  				res.StatusCode++
   863  				return errors.New("some error to trigger errorHandler")
   864  			},
   865  			errorHandler: func(rw http.ResponseWriter, req *http.Request, err error) { rw.WriteHeader(http.StatusTeapot) },
   866  			wantCode:     http.StatusTeapot,
   867  		},
   868  	}
   869  
   870  	for _, tt := range tests {
   871  		t.Run(tt.name, func(t *testing.T) {
   872  			target := &url.URL{
   873  				Scheme: "http",
   874  				Host:   "dummy.tld",
   875  				Path:   "/",
   876  			}
   877  			rproxy := NewSingleHostReverseProxy(target)
   878  			rproxy.Transport = tt.transport
   879  			rproxy.ModifyResponse = tt.modifyResponse
   880  			if rproxy.Transport == nil {
   881  				rproxy.Transport = failingRoundTripper{}
   882  			}
   883  			rproxy.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests
   884  			if tt.errorHandler != nil {
   885  				rproxy.ErrorHandler = tt.errorHandler
   886  			}
   887  			frontendProxy := httptest.NewServer(rproxy)
   888  			defer frontendProxy.Close()
   889  
   890  			resp, err := http.Get(frontendProxy.URL + "/test")
   891  			if err != nil {
   892  				t.Fatalf("failed to reach proxy: %v", err)
   893  			}
   894  			if g, e := resp.StatusCode, tt.wantCode; g != e {
   895  				t.Errorf("got res.StatusCode %d; expected %d", g, e)
   896  			}
   897  			resp.Body.Close()
   898  		})
   899  	}
   900  }
   901  
   902  // Issue 16659: log errors from short read
   903  func TestReverseProxy_CopyBuffer(t *testing.T) {
   904  	backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   905  		out := "this call was relayed by the reverse proxy"
   906  		// Coerce a wrong content length to induce io.UnexpectedEOF
   907  		w.Header().Set("Content-Length", fmt.Sprintf("%d", len(out)*2))
   908  		fmt.Fprintln(w, out)
   909  	}))
   910  	defer backendServer.Close()
   911  
   912  	rpURL, err := url.Parse(backendServer.URL)
   913  	if err != nil {
   914  		t.Fatal(err)
   915  	}
   916  
   917  	var proxyLog bytes.Buffer
   918  	rproxy := NewSingleHostReverseProxy(rpURL)
   919  	rproxy.ErrorLog = log.New(&proxyLog, "", log.Lshortfile)
   920  	donec := make(chan bool, 1)
   921  	frontendProxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   922  		defer func() { donec <- true }()
   923  		rproxy.ServeHTTP(w, r)
   924  	}))
   925  	defer frontendProxy.Close()
   926  
   927  	if _, err = frontendProxy.Client().Get(frontendProxy.URL); err == nil {
   928  		t.Fatalf("want non-nil error")
   929  	}
   930  	// The race detector complains about the proxyLog usage in logf in copyBuffer
   931  	// and our usage below with proxyLog.Bytes() so we're explicitly using a
   932  	// channel to ensure that the ReverseProxy's ServeHTTP is done before we
   933  	// continue after Get.
   934  	<-donec
   935  
   936  	expected := []string{
   937  		"EOF",
   938  		"read",
   939  	}
   940  	for _, phrase := range expected {
   941  		if !bytes.Contains(proxyLog.Bytes(), []byte(phrase)) {
   942  			t.Errorf("expected log to contain phrase %q", phrase)
   943  		}
   944  	}
   945  }
   946  
   947  type staticTransport struct {
   948  	res *http.Response
   949  }
   950  
   951  func (t *staticTransport) RoundTrip(r *http.Request) (*http.Response, error) {
   952  	return t.res, nil
   953  }
   954  
   955  func BenchmarkServeHTTP(b *testing.B) {
   956  	res := &http.Response{
   957  		StatusCode: 200,
   958  		Body:       io.NopCloser(strings.NewReader("")),
   959  	}
   960  	proxy := &ReverseProxy{
   961  		Director:  func(*http.Request) {},
   962  		Transport: &staticTransport{res},
   963  	}
   964  
   965  	w := httptest.NewRecorder()
   966  	r := httptest.NewRequest("GET", "/", nil)
   967  
   968  	b.ReportAllocs()
   969  	for i := 0; i < b.N; i++ {
   970  		proxy.ServeHTTP(w, r)
   971  	}
   972  }
   973  
   974  func TestServeHTTPDeepCopy(t *testing.T) {
   975  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   976  		w.Write([]byte("Hello Gopher!"))
   977  	}))
   978  	defer backend.Close()
   979  	backendURL, err := url.Parse(backend.URL)
   980  	if err != nil {
   981  		t.Fatal(err)
   982  	}
   983  
   984  	type result struct {
   985  		before, after string
   986  	}
   987  
   988  	resultChan := make(chan result, 1)
   989  	proxyHandler := NewSingleHostReverseProxy(backendURL)
   990  	frontend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   991  		before := r.URL.String()
   992  		proxyHandler.ServeHTTP(w, r)
   993  		after := r.URL.String()
   994  		resultChan <- result{before: before, after: after}
   995  	}))
   996  	defer frontend.Close()
   997  
   998  	want := result{before: "/", after: "/"}
   999  
  1000  	res, err := frontend.Client().Get(frontend.URL)
  1001  	if err != nil {
  1002  		t.Fatalf("Do: %v", err)
  1003  	}
  1004  	res.Body.Close()
  1005  
  1006  	got := <-resultChan
  1007  	if got != want {
  1008  		t.Errorf("got = %+v; want = %+v", got, want)
  1009  	}
  1010  }
  1011  
  1012  // Issue 18327: verify we always do a deep copy of the Request.Header map
  1013  // before any mutations.
  1014  func TestClonesRequestHeaders(t *testing.T) {
  1015  	log.SetOutput(io.Discard)
  1016  	defer log.SetOutput(os.Stderr)
  1017  	req, _ := http.NewRequest("GET", "http://foo.tld/", nil)
  1018  	req.RemoteAddr = "1.2.3.4:56789"
  1019  	rp := &ReverseProxy{
  1020  		Director: func(req *http.Request) {
  1021  			req.Header.Set("From-Director", "1")
  1022  		},
  1023  		Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) {
  1024  			if v := req.Header.Get("From-Director"); v != "1" {
  1025  				t.Errorf("From-Directory value = %q; want 1", v)
  1026  			}
  1027  			return nil, io.EOF
  1028  		}),
  1029  	}
  1030  	rp.ServeHTTP(httptest.NewRecorder(), req)
  1031  
  1032  	if req.Header.Get("From-Director") == "1" {
  1033  		t.Error("Director header mutation modified caller's request")
  1034  	}
  1035  	if req.Header.Get("X-Forwarded-For") != "" {
  1036  		t.Error("X-Forward-For header mutation modified caller's request")
  1037  	}
  1038  
  1039  }
  1040  
  1041  type roundTripperFunc func(req *http.Request) (*http.Response, error)
  1042  
  1043  func (fn roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) {
  1044  	return fn(req)
  1045  }
  1046  
  1047  func TestModifyResponseClosesBody(t *testing.T) {
  1048  	req, _ := http.NewRequest("GET", "http://foo.tld/", nil)
  1049  	req.RemoteAddr = "1.2.3.4:56789"
  1050  	closeCheck := new(checkCloser)
  1051  	logBuf := new(bytes.Buffer)
  1052  	outErr := errors.New("ModifyResponse error")
  1053  	rp := &ReverseProxy{
  1054  		Director: func(req *http.Request) {},
  1055  		Transport: &staticTransport{&http.Response{
  1056  			StatusCode: 200,
  1057  			Body:       closeCheck,
  1058  		}},
  1059  		ErrorLog: log.New(logBuf, "", 0),
  1060  		ModifyResponse: func(*http.Response) error {
  1061  			return outErr
  1062  		},
  1063  	}
  1064  	rec := httptest.NewRecorder()
  1065  	rp.ServeHTTP(rec, req)
  1066  	res := rec.Result()
  1067  	if g, e := res.StatusCode, http.StatusBadGateway; g != e {
  1068  		t.Errorf("got res.StatusCode %d; expected %d", g, e)
  1069  	}
  1070  	if !closeCheck.closed {
  1071  		t.Errorf("body should have been closed")
  1072  	}
  1073  	if g, e := logBuf.String(), outErr.Error(); !strings.Contains(g, e) {
  1074  		t.Errorf("ErrorLog %q does not contain %q", g, e)
  1075  	}
  1076  }
  1077  
  1078  type checkCloser struct {
  1079  	closed bool
  1080  }
  1081  
  1082  func (cc *checkCloser) Close() error {
  1083  	cc.closed = true
  1084  	return nil
  1085  }
  1086  
  1087  func (cc *checkCloser) Read(b []byte) (int, error) {
  1088  	return len(b), nil
  1089  }
  1090  
  1091  // Issue 23643: panic on body copy error
  1092  func TestReverseProxy_PanicBodyError(t *testing.T) {
  1093  	log.SetOutput(io.Discard)
  1094  	defer log.SetOutput(os.Stderr)
  1095  	backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  1096  		out := "this call was relayed by the reverse proxy"
  1097  		// Coerce a wrong content length to induce io.ErrUnexpectedEOF
  1098  		w.Header().Set("Content-Length", fmt.Sprintf("%d", len(out)*2))
  1099  		fmt.Fprintln(w, out)
  1100  	}))
  1101  	defer backendServer.Close()
  1102  
  1103  	rpURL, err := url.Parse(backendServer.URL)
  1104  	if err != nil {
  1105  		t.Fatal(err)
  1106  	}
  1107  
  1108  	rproxy := NewSingleHostReverseProxy(rpURL)
  1109  
  1110  	// Ensure that the handler panics when the body read encounters an
  1111  	// io.ErrUnexpectedEOF
  1112  	defer func() {
  1113  		err := recover()
  1114  		if err == nil {
  1115  			t.Fatal("handler should have panicked")
  1116  		}
  1117  		if err != http.ErrAbortHandler {
  1118  			t.Fatal("expected ErrAbortHandler, got", err)
  1119  		}
  1120  	}()
  1121  	req, _ := http.NewRequest("GET", "http://foo.tld/", nil)
  1122  	rproxy.ServeHTTP(httptest.NewRecorder(), req)
  1123  }
  1124  
  1125  // Issue #46866: panic without closing incoming request body causes a panic
  1126  func TestReverseProxy_PanicClosesIncomingBody(t *testing.T) {
  1127  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  1128  		out := "this call was relayed by the reverse proxy"
  1129  		// Coerce a wrong content length to induce io.ErrUnexpectedEOF
  1130  		w.Header().Set("Content-Length", fmt.Sprintf("%d", len(out)*2))
  1131  		fmt.Fprintln(w, out)
  1132  	}))
  1133  	defer backend.Close()
  1134  	backendURL, err := url.Parse(backend.URL)
  1135  	if err != nil {
  1136  		t.Fatal(err)
  1137  	}
  1138  	proxyHandler := NewSingleHostReverseProxy(backendURL)
  1139  	proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests
  1140  	frontend := httptest.NewServer(proxyHandler)
  1141  	defer frontend.Close()
  1142  	frontendClient := frontend.Client()
  1143  
  1144  	var wg sync.WaitGroup
  1145  	for i := 0; i < 2; i++ {
  1146  		wg.Add(1)
  1147  		go func() {
  1148  			defer wg.Done()
  1149  			for j := 0; j < 10; j++ {
  1150  				const reqLen = 6 * 1024 * 1024
  1151  				req, _ := http.NewRequest("POST", frontend.URL, &io.LimitedReader{R: neverEnding('x'), N: reqLen})
  1152  				req.ContentLength = reqLen
  1153  				resp, _ := frontendClient.Transport.RoundTrip(req)
  1154  				if resp != nil {
  1155  					io.Copy(io.Discard, resp.Body)
  1156  					resp.Body.Close()
  1157  				}
  1158  			}
  1159  		}()
  1160  	}
  1161  	wg.Wait()
  1162  }
  1163  
  1164  func TestSelectFlushInterval(t *testing.T) {
  1165  	tests := []struct {
  1166  		name string
  1167  		p    *ReverseProxy
  1168  		res  *http.Response
  1169  		want time.Duration
  1170  	}{
  1171  		{
  1172  			name: "default",
  1173  			res:  &http.Response{},
  1174  			p:    &ReverseProxy{FlushInterval: 123},
  1175  			want: 123,
  1176  		},
  1177  		{
  1178  			name: "server-sent events overrides non-zero",
  1179  			res: &http.Response{
  1180  				Header: http.Header{
  1181  					"Content-Type": {"text/event-stream"},
  1182  				},
  1183  			},
  1184  			p:    &ReverseProxy{FlushInterval: 123},
  1185  			want: -1,
  1186  		},
  1187  		{
  1188  			name: "server-sent events overrides zero",
  1189  			res: &http.Response{
  1190  				Header: http.Header{
  1191  					"Content-Type": {"text/event-stream"},
  1192  				},
  1193  			},
  1194  			p:    &ReverseProxy{FlushInterval: 0},
  1195  			want: -1,
  1196  		},
  1197  		{
  1198  			name: "Content-Length: -1, overrides non-zero",
  1199  			res: &http.Response{
  1200  				ContentLength: -1,
  1201  			},
  1202  			p:    &ReverseProxy{FlushInterval: 123},
  1203  			want: -1,
  1204  		},
  1205  		{
  1206  			name: "Content-Length: -1, overrides zero",
  1207  			res: &http.Response{
  1208  				ContentLength: -1,
  1209  			},
  1210  			p:    &ReverseProxy{FlushInterval: 0},
  1211  			want: -1,
  1212  		},
  1213  	}
  1214  	for _, tt := range tests {
  1215  		t.Run(tt.name, func(t *testing.T) {
  1216  			got := tt.p.flushInterval(tt.res)
  1217  			if got != tt.want {
  1218  				t.Errorf("flushLatency = %v; want %v", got, tt.want)
  1219  			}
  1220  		})
  1221  	}
  1222  }
  1223  
  1224  func TestReverseProxyWebSocket(t *testing.T) {
  1225  	backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  1226  		if upgradeType(r.Header) != "websocket" {
  1227  			t.Error("unexpected backend request")
  1228  			http.Error(w, "unexpected request", 400)
  1229  			return
  1230  		}
  1231  		c, _, err := w.(http.Hijacker).Hijack()
  1232  		if err != nil {
  1233  			t.Error(err)
  1234  			return
  1235  		}
  1236  		defer c.Close()
  1237  		io.WriteString(c, "HTTP/1.1 101 Switching Protocols\r\nConnection: upgrade\r\nUpgrade: WebSocket\r\n\r\n")
  1238  		bs := bufio.NewScanner(c)
  1239  		if !bs.Scan() {
  1240  			t.Errorf("backend failed to read line from client: %v", bs.Err())
  1241  			return
  1242  		}
  1243  		fmt.Fprintf(c, "backend got %q\n", bs.Text())
  1244  	}))
  1245  	defer backendServer.Close()
  1246  
  1247  	backURL, _ := url.Parse(backendServer.URL)
  1248  	rproxy := NewSingleHostReverseProxy(backURL)
  1249  	rproxy.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests
  1250  	rproxy.ModifyResponse = func(res *http.Response) error {
  1251  		res.Header.Add("X-Modified", "true")
  1252  		return nil
  1253  	}
  1254  
  1255  	handler := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
  1256  		rw.Header().Set("X-Header", "X-Value")
  1257  		rproxy.ServeHTTP(rw, req)
  1258  		if got, want := rw.Header().Get("X-Modified"), "true"; got != want {
  1259  			t.Errorf("response writer X-Modified header = %q; want %q", got, want)
  1260  		}
  1261  	})
  1262  
  1263  	frontendProxy := httptest.NewServer(handler)
  1264  	defer frontendProxy.Close()
  1265  
  1266  	req, _ := http.NewRequest("GET", frontendProxy.URL, nil)
  1267  	req.Header.Set("Connection", "Upgrade")
  1268  	req.Header.Set("Upgrade", "websocket")
  1269  
  1270  	c := frontendProxy.Client()
  1271  	res, err := c.Do(req)
  1272  	if err != nil {
  1273  		t.Fatal(err)
  1274  	}
  1275  	if res.StatusCode != 101 {
  1276  		t.Fatalf("status = %v; want 101", res.Status)
  1277  	}
  1278  
  1279  	got := res.Header.Get("X-Header")
  1280  	want := "X-Value"
  1281  	if got != want {
  1282  		t.Errorf("Header(XHeader) = %q; want %q", got, want)
  1283  	}
  1284  
  1285  	if !ascii.EqualFold(upgradeType(res.Header), "websocket") {
  1286  		t.Fatalf("not websocket upgrade; got %#v", res.Header)
  1287  	}
  1288  	rwc, ok := res.Body.(io.ReadWriteCloser)
  1289  	if !ok {
  1290  		t.Fatalf("response body is of type %T; does not implement ReadWriteCloser", res.Body)
  1291  	}
  1292  	defer rwc.Close()
  1293  
  1294  	if got, want := res.Header.Get("X-Modified"), "true"; got != want {
  1295  		t.Errorf("response X-Modified header = %q; want %q", got, want)
  1296  	}
  1297  
  1298  	io.WriteString(rwc, "Hello\n")
  1299  	bs := bufio.NewScanner(rwc)
  1300  	if !bs.Scan() {
  1301  		t.Fatalf("Scan: %v", bs.Err())
  1302  	}
  1303  	got = bs.Text()
  1304  	want = `backend got "Hello"`
  1305  	if got != want {
  1306  		t.Errorf("got %#q, want %#q", got, want)
  1307  	}
  1308  }
  1309  
  1310  func TestReverseProxyWebSocketCancellation(t *testing.T) {
  1311  	n := 5
  1312  	triggerCancelCh := make(chan bool, n)
  1313  	nthResponse := func(i int) string {
  1314  		return fmt.Sprintf("backend response #%d\n", i)
  1315  	}
  1316  	terminalMsg := "final message"
  1317  
  1318  	cst := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  1319  		if g, ws := upgradeType(r.Header), "websocket"; g != ws {
  1320  			t.Errorf("Unexpected upgrade type %q, want %q", g, ws)
  1321  			http.Error(w, "Unexpected request", 400)
  1322  			return
  1323  		}
  1324  		conn, bufrw, err := w.(http.Hijacker).Hijack()
  1325  		if err != nil {
  1326  			t.Error(err)
  1327  			return
  1328  		}
  1329  		defer conn.Close()
  1330  
  1331  		upgradeMsg := "HTTP/1.1 101 Switching Protocols\r\nConnection: upgrade\r\nUpgrade: WebSocket\r\n\r\n"
  1332  		if _, err := io.WriteString(conn, upgradeMsg); err != nil {
  1333  			t.Error(err)
  1334  			return
  1335  		}
  1336  		if _, _, err := bufrw.ReadLine(); err != nil {
  1337  			t.Errorf("Failed to read line from client: %v", err)
  1338  			return
  1339  		}
  1340  
  1341  		for i := 0; i < n; i++ {
  1342  			if _, err := bufrw.WriteString(nthResponse(i)); err != nil {
  1343  				select {
  1344  				case <-triggerCancelCh:
  1345  				default:
  1346  					t.Errorf("Writing response #%d failed: %v", i, err)
  1347  				}
  1348  				return
  1349  			}
  1350  			bufrw.Flush()
  1351  			time.Sleep(time.Second)
  1352  		}
  1353  		if _, err := bufrw.WriteString(terminalMsg); err != nil {
  1354  			select {
  1355  			case <-triggerCancelCh:
  1356  			default:
  1357  				t.Errorf("Failed to write terminal message: %v", err)
  1358  			}
  1359  		}
  1360  		bufrw.Flush()
  1361  	}))
  1362  	defer cst.Close()
  1363  
  1364  	backendURL, _ := url.Parse(cst.URL)
  1365  	rproxy := NewSingleHostReverseProxy(backendURL)
  1366  	rproxy.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests
  1367  	rproxy.ModifyResponse = func(res *http.Response) error {
  1368  		res.Header.Add("X-Modified", "true")
  1369  		return nil
  1370  	}
  1371  
  1372  	handler := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
  1373  		rw.Header().Set("X-Header", "X-Value")
  1374  		ctx, cancel := context.WithCancel(req.Context())
  1375  		go func() {
  1376  			<-triggerCancelCh
  1377  			cancel()
  1378  		}()
  1379  		rproxy.ServeHTTP(rw, req.WithContext(ctx))
  1380  	})
  1381  
  1382  	frontendProxy := httptest.NewServer(handler)
  1383  	defer frontendProxy.Close()
  1384  
  1385  	req, _ := http.NewRequest("GET", frontendProxy.URL, nil)
  1386  	req.Header.Set("Connection", "Upgrade")
  1387  	req.Header.Set("Upgrade", "websocket")
  1388  
  1389  	res, err := frontendProxy.Client().Do(req)
  1390  	if err != nil {
  1391  		t.Fatalf("Dialing to frontend proxy: %v", err)
  1392  	}
  1393  	defer res.Body.Close()
  1394  	if g, w := res.StatusCode, 101; g != w {
  1395  		t.Fatalf("Switching protocols failed, got: %d, want: %d", g, w)
  1396  	}
  1397  
  1398  	if g, w := res.Header.Get("X-Header"), "X-Value"; g != w {
  1399  		t.Errorf("X-Header mismatch\n\tgot:  %q\n\twant: %q", g, w)
  1400  	}
  1401  
  1402  	if g, w := upgradeType(res.Header), "websocket"; !ascii.EqualFold(g, w) {
  1403  		t.Fatalf("Upgrade header mismatch\n\tgot:  %q\n\twant: %q", g, w)
  1404  	}
  1405  
  1406  	rwc, ok := res.Body.(io.ReadWriteCloser)
  1407  	if !ok {
  1408  		t.Fatalf("Response body type mismatch, got %T, want io.ReadWriteCloser", res.Body)
  1409  	}
  1410  
  1411  	if got, want := res.Header.Get("X-Modified"), "true"; got != want {
  1412  		t.Errorf("response X-Modified header = %q; want %q", got, want)
  1413  	}
  1414  
  1415  	if _, err := io.WriteString(rwc, "Hello\n"); err != nil {
  1416  		t.Fatalf("Failed to write first message: %v", err)
  1417  	}
  1418  
  1419  	// Read loop.
  1420  
  1421  	br := bufio.NewReader(rwc)
  1422  	for {
  1423  		line, err := br.ReadString('\n')
  1424  		switch {
  1425  		case line == terminalMsg: // this case before "err == io.EOF"
  1426  			t.Fatalf("The websocket request was not canceled, unfortunately!")
  1427  
  1428  		case err == io.EOF:
  1429  			return
  1430  
  1431  		case err != nil:
  1432  			t.Fatalf("Unexpected error: %v", err)
  1433  
  1434  		case line == nthResponse(0): // We've gotten the first response back
  1435  			// Let's trigger a cancel.
  1436  			close(triggerCancelCh)
  1437  		}
  1438  	}
  1439  }
  1440  
  1441  func TestUnannouncedTrailer(t *testing.T) {
  1442  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  1443  		w.WriteHeader(http.StatusOK)
  1444  		w.(http.Flusher).Flush()
  1445  		w.Header().Set(http.TrailerPrefix+"X-Unannounced-Trailer", "unannounced_trailer_value")
  1446  	}))
  1447  	defer backend.Close()
  1448  	backendURL, err := url.Parse(backend.URL)
  1449  	if err != nil {
  1450  		t.Fatal(err)
  1451  	}
  1452  	proxyHandler := NewSingleHostReverseProxy(backendURL)
  1453  	proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests
  1454  	frontend := httptest.NewServer(proxyHandler)
  1455  	defer frontend.Close()
  1456  	frontendClient := frontend.Client()
  1457  
  1458  	res, err := frontendClient.Get(frontend.URL)
  1459  	if err != nil {
  1460  		t.Fatalf("Get: %v", err)
  1461  	}
  1462  
  1463  	io.ReadAll(res.Body)
  1464  
  1465  	if g, w := res.Trailer.Get("X-Unannounced-Trailer"), "unannounced_trailer_value"; g != w {
  1466  		t.Errorf("Trailer(X-Unannounced-Trailer) = %q; want %q", g, w)
  1467  	}
  1468  
  1469  }
  1470  
  1471  func TestSingleJoinSlash(t *testing.T) {
  1472  	tests := []struct {
  1473  		slasha   string
  1474  		slashb   string
  1475  		expected string
  1476  	}{
  1477  		{"https://www.google.com/", "/favicon.ico", "https://www.google.com/favicon.ico"},
  1478  		{"https://www.google.com", "/favicon.ico", "https://www.google.com/favicon.ico"},
  1479  		{"https://www.google.com", "favicon.ico", "https://www.google.com/favicon.ico"},
  1480  		{"https://www.google.com", "", "https://www.google.com/"},
  1481  		{"", "favicon.ico", "/favicon.ico"},
  1482  	}
  1483  	for _, tt := range tests {
  1484  		if got := singleJoiningSlash(tt.slasha, tt.slashb); got != tt.expected {
  1485  			t.Errorf("singleJoiningSlash(%q,%q) want %q got %q",
  1486  				tt.slasha,
  1487  				tt.slashb,
  1488  				tt.expected,
  1489  				got)
  1490  		}
  1491  	}
  1492  }
  1493  
  1494  func TestJoinURLPath(t *testing.T) {
  1495  	tests := []struct {
  1496  		a        *url.URL
  1497  		b        *url.URL
  1498  		wantPath string
  1499  		wantRaw  string
  1500  	}{
  1501  		{&url.URL{Path: "/a/b"}, &url.URL{Path: "/c"}, "/a/b/c", ""},
  1502  		{&url.URL{Path: "/a/b", RawPath: "badpath"}, &url.URL{Path: "c"}, "/a/b/c", "/a/b/c"},
  1503  		{&url.URL{Path: "/a/b", RawPath: "/a%2Fb"}, &url.URL{Path: "/c"}, "/a/b/c", "/a%2Fb/c"},
  1504  		{&url.URL{Path: "/a/b", RawPath: "/a%2Fb"}, &url.URL{Path: "/c"}, "/a/b/c", "/a%2Fb/c"},
  1505  		{&url.URL{Path: "/a/b/", RawPath: "/a%2Fb%2F"}, &url.URL{Path: "c"}, "/a/b//c", "/a%2Fb%2F/c"},
  1506  		{&url.URL{Path: "/a/b/", RawPath: "/a%2Fb/"}, &url.URL{Path: "/c/d", RawPath: "/c%2Fd"}, "/a/b/c/d", "/a%2Fb/c%2Fd"},
  1507  	}
  1508  
  1509  	for _, tt := range tests {
  1510  		p, rp := joinURLPath(tt.a, tt.b)
  1511  		if p != tt.wantPath || rp != tt.wantRaw {
  1512  			t.Errorf("joinURLPath(URL(%q,%q),URL(%q,%q)) want (%q,%q) got (%q,%q)",
  1513  				tt.a.Path, tt.a.RawPath,
  1514  				tt.b.Path, tt.b.RawPath,
  1515  				tt.wantPath, tt.wantRaw,
  1516  				p, rp)
  1517  		}
  1518  	}
  1519  }
  1520  

View as plain text