// Copyright 2016 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. // White-box tests for transport.go (in package http instead of http_test). package http import ( "bytes" "crypto/tls" "errors" "io" "net" "net/http/internal/testcert" "strings" "testing" ) // Issue 15446: incorrect wrapping of errors when server closes an idle connection. func TestTransportPersistConnReadLoopEOF(t *testing.T) { ln := newLocalListener(t) defer ln.Close() connc := make(chan net.Conn, 1) go func() { defer close(connc) c, err := ln.Accept() if err != nil { t.Error(err) return } connc <- c }() tr := new(Transport) req, _ := NewRequest("GET", "http://"+ln.Addr().String(), nil) req = req.WithT(t) treq := &transportRequest{Request: req} cm := connectMethod{targetScheme: "http", targetAddr: ln.Addr().String()} pc, err := tr.getConn(treq, cm) if err != nil { t.Fatal(err) } defer pc.close(errors.New("test over")) conn := <-connc if conn == nil { // Already called t.Error in the accept goroutine. return } conn.Close() // simulate the server hanging up on the client _, err = pc.roundTrip(treq) if !isTransportReadFromServerError(err) && err != errServerClosedIdle { t.Errorf("roundTrip = %#v, %v; want errServerClosedIdle or transportReadFromServerError", err, err) } <-pc.closech err = pc.closed if !isTransportReadFromServerError(err) && err != errServerClosedIdle { t.Errorf("pc.closed = %#v, %v; want errServerClosedIdle or transportReadFromServerError", err, err) } } func isTransportReadFromServerError(err error) bool { _, ok := err.(transportReadFromServerError) return ok } func newLocalListener(t *testing.T) net.Listener { ln, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { ln, err = net.Listen("tcp6", "[::1]:0") } if err != nil { t.Fatal(err) } return ln } func dummyRequest(method string) *Request { req, err := NewRequest(method, "http://fake.tld/", nil) if err != nil { panic(err) } return req } func dummyRequestWithBody(method string) *Request { req, err := NewRequest(method, "http://fake.tld/", strings.NewReader("foo")) if err != nil { panic(err) } return req } func dummyRequestWithBodyNoGetBody(method string) *Request { req := dummyRequestWithBody(method) req.GetBody = nil return req } // issue22091Error acts like a golang.org/x/net/http2.ErrNoCachedConn. type issue22091Error struct{} func (issue22091Error) IsHTTP2NoCachedConnError() {} func (issue22091Error) Error() string { return "issue22091Error" } func TestTransportShouldRetryRequest(t *testing.T) { tests := []struct { pc *persistConn req *Request err error want bool }{ 0: { pc: &persistConn{reused: false}, req: dummyRequest("POST"), err: nothingWrittenError{}, want: false, }, 1: { pc: &persistConn{reused: true}, req: dummyRequest("POST"), err: nothingWrittenError{}, want: true, }, 2: { pc: &persistConn{reused: true}, req: dummyRequest("POST"), err: http2ErrNoCachedConn, want: true, }, 3: { pc: nil, req: nil, err: issue22091Error{}, // like an external http2ErrNoCachedConn want: true, }, 4: { pc: &persistConn{reused: true}, req: dummyRequest("POST"), err: errMissingHost, want: false, }, 5: { pc: &persistConn{reused: true}, req: dummyRequest("POST"), err: transportReadFromServerError{}, want: false, }, 6: { pc: &persistConn{reused: true}, req: dummyRequest("GET"), err: transportReadFromServerError{}, want: true, }, 7: { pc: &persistConn{reused: true}, req: dummyRequest("GET"), err: errServerClosedIdle, want: true, }, 8: { pc: &persistConn{reused: true}, req: dummyRequestWithBody("POST"), err: nothingWrittenError{}, want: true, }, 9: { pc: &persistConn{reused: true}, req: dummyRequestWithBodyNoGetBody("POST"), err: nothingWrittenError{}, want: false, }, } for i, tt := range tests { got := tt.pc.shouldRetryRequest(tt.req, tt.err) if got != tt.want { t.Errorf("%d. shouldRetryRequest = %v; want %v", i, got, tt.want) } } } type roundTripFunc func(r *Request) (*Response, error) func (f roundTripFunc) RoundTrip(r *Request) (*Response, error) { return f(r) } // Issue 25009 func TestTransportBodyAltRewind(t *testing.T) { cert, err := tls.X509KeyPair(testcert.LocalhostCert, testcert.LocalhostKey) if err != nil { t.Fatal(err) } ln := newLocalListener(t) defer ln.Close() go func() { tln := tls.NewListener(ln, &tls.Config{ NextProtos: []string{"foo"}, Certificates: []tls.Certificate{cert}, }) for i := 0; i < 2; i++ { sc, err := tln.Accept() if err != nil { t.Error(err) return } if err := sc.(*tls.Conn).Handshake(); err != nil { t.Error(err) return } sc.Close() } }() addr := ln.Addr().String() req, _ := NewRequest("POST", "https://example.org/", bytes.NewBufferString("request")) roundTripped := false tr := &Transport{ DisableKeepAlives: true, TLSNextProto: map[string]func(string, *tls.Conn) RoundTripper{ "foo": func(authority string, c *tls.Conn) RoundTripper { return roundTripFunc(func(r *Request) (*Response, error) { n, _ := io.Copy(io.Discard, r.Body) if n == 0 { t.Error("body length is zero") } if roundTripped { return &Response{ Body: NoBody, StatusCode: 200, }, nil } roundTripped = true return nil, http2noCachedConnError{} }) }, }, DialTLS: func(_, _ string) (net.Conn, error) { tc, err := tls.Dial("tcp", addr, &tls.Config{ InsecureSkipVerify: true, NextProtos: []string{"foo"}, }) if err != nil { return nil, err } if err := tc.Handshake(); err != nil { return nil, err } return tc, nil }, } c := &Client{Transport: tr} _, err = c.Do(req) if err != nil { t.Error(err) } }