1
2
3
4
5 package httptest
6
7 import (
8 "bufio"
9 "io"
10 "net"
11 "net/http"
12 "testing"
13 )
14
15 type newServerFunc func(http.Handler) *Server
16
17 var newServers = map[string]newServerFunc{
18 "NewServer": NewServer,
19 "NewTLSServer": NewTLSServer,
20
21
22
23 "NewServerManual": func(h http.Handler) *Server {
24 ts := &Server{Listener: newLocalListener(), Config: &http.Server{Handler: h}}
25 ts.Start()
26 return ts
27 },
28 "NewTLSServerManual": func(h http.Handler) *Server {
29 ts := &Server{Listener: newLocalListener(), Config: &http.Server{Handler: h}}
30 ts.StartTLS()
31 return ts
32 },
33 }
34
35 func TestServer(t *testing.T) {
36 for _, name := range []string{"NewServer", "NewServerManual"} {
37 t.Run(name, func(t *testing.T) {
38 newServer := newServers[name]
39 t.Run("Server", func(t *testing.T) { testServer(t, newServer) })
40 t.Run("GetAfterClose", func(t *testing.T) { testGetAfterClose(t, newServer) })
41 t.Run("ServerCloseBlocking", func(t *testing.T) { testServerCloseBlocking(t, newServer) })
42 t.Run("ServerCloseClientConnections", func(t *testing.T) { testServerCloseClientConnections(t, newServer) })
43 t.Run("ServerClientTransportType", func(t *testing.T) { testServerClientTransportType(t, newServer) })
44 })
45 }
46 for _, name := range []string{"NewTLSServer", "NewTLSServerManual"} {
47 t.Run(name, func(t *testing.T) {
48 newServer := newServers[name]
49 t.Run("ServerClient", func(t *testing.T) { testServerClient(t, newServer) })
50 t.Run("TLSServerClientTransportType", func(t *testing.T) { testTLSServerClientTransportType(t, newServer) })
51 })
52 }
53 }
54
55 func testServer(t *testing.T, newServer newServerFunc) {
56 ts := newServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
57 w.Write([]byte("hello"))
58 }))
59 defer ts.Close()
60 res, err := http.Get(ts.URL)
61 if err != nil {
62 t.Fatal(err)
63 }
64 got, err := io.ReadAll(res.Body)
65 res.Body.Close()
66 if err != nil {
67 t.Fatal(err)
68 }
69 if string(got) != "hello" {
70 t.Errorf("got %q, want hello", string(got))
71 }
72 }
73
74
75 func testGetAfterClose(t *testing.T, newServer newServerFunc) {
76 ts := newServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
77 w.Write([]byte("hello"))
78 }))
79
80 res, err := http.Get(ts.URL)
81 if err != nil {
82 t.Fatal(err)
83 }
84 got, err := io.ReadAll(res.Body)
85 if err != nil {
86 t.Fatal(err)
87 }
88 if string(got) != "hello" {
89 t.Fatalf("got %q, want hello", string(got))
90 }
91
92 ts.Close()
93
94 res, err = http.Get(ts.URL)
95 if err == nil {
96 body, _ := io.ReadAll(res.Body)
97 t.Fatalf("Unexpected response after close: %v, %v, %s", res.Status, res.Header, body)
98 }
99 }
100
101 func testServerCloseBlocking(t *testing.T, newServer newServerFunc) {
102 ts := newServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
103 w.Write([]byte("hello"))
104 }))
105 dial := func() net.Conn {
106 c, err := net.Dial("tcp", ts.Listener.Addr().String())
107 if err != nil {
108 t.Fatal(err)
109 }
110 return c
111 }
112
113
114 cnew := dial()
115 defer cnew.Close()
116
117
118 cidle := dial()
119 defer cidle.Close()
120 cidle.Write([]byte("HEAD / HTTP/1.1\r\nHost: foo\r\n\r\n"))
121 _, err := http.ReadResponse(bufio.NewReader(cidle), nil)
122 if err != nil {
123 t.Fatal(err)
124 }
125
126 ts.Close()
127 }
128
129
130 func testServerCloseClientConnections(t *testing.T, newServer newServerFunc) {
131 var s *Server
132 s = newServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
133 s.CloseClientConnections()
134 }))
135 defer s.Close()
136 res, err := http.Get(s.URL)
137 if err == nil {
138 res.Body.Close()
139 t.Fatalf("Unexpected response: %#v", res)
140 }
141 }
142
143
144
145 func testServerClient(t *testing.T, newTLSServer newServerFunc) {
146 ts := newTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
147 w.Write([]byte("hello"))
148 }))
149 defer ts.Close()
150 client := ts.Client()
151 res, err := client.Get(ts.URL)
152 if err != nil {
153 t.Fatal(err)
154 }
155 got, err := io.ReadAll(res.Body)
156 res.Body.Close()
157 if err != nil {
158 t.Fatal(err)
159 }
160 if string(got) != "hello" {
161 t.Errorf("got %q, want hello", string(got))
162 }
163 }
164
165
166
167 func testServerClientTransportType(t *testing.T, newServer newServerFunc) {
168 ts := newServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
169 }))
170 defer ts.Close()
171 client := ts.Client()
172 if _, ok := client.Transport.(*http.Transport); !ok {
173 t.Errorf("got %T, want *http.Transport", client.Transport)
174 }
175 }
176
177
178
179 func testTLSServerClientTransportType(t *testing.T, newTLSServer newServerFunc) {
180 ts := newTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
181 }))
182 defer ts.Close()
183 client := ts.Client()
184 if _, ok := client.Transport.(*http.Transport); !ok {
185 t.Errorf("got %T, want *http.Transport", client.Transport)
186 }
187 }
188
// onlyCloseListener embeds a net.Listener and overrides Close with a no-op.
// Only Close is safe to call when the embedded Listener is left nil.
type onlyCloseListener struct {
	net.Listener
}

// Close does nothing and always returns nil; it never reaches the embedded
// Listener.
func (onlyCloseListener) Close() error {
	return nil
}
194
195
196
197 func TestServerZeroValueClose(t *testing.T) {
198 ts := &Server{
199 Listener: onlyCloseListener{},
200 Config: &http.Server{},
201 }
202
203 ts.Close()
204 }
205
206 func TestTLSServerWithHTTP2(t *testing.T) {
207 modes := []struct {
208 name string
209 wantProto string
210 }{
211 {"http1", "HTTP/1.1"},
212 {"http2", "HTTP/2.0"},
213 }
214
215 for _, tt := range modes {
216 t.Run(tt.name, func(t *testing.T) {
217 cst := NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
218 w.Header().Set("X-Proto", r.Proto)
219 }))
220
221 switch tt.name {
222 case "http2":
223 cst.EnableHTTP2 = true
224 cst.StartTLS()
225 default:
226 cst.Start()
227 }
228
229 defer cst.Close()
230
231 res, err := cst.Client().Get(cst.URL)
232 if err != nil {
233 t.Fatalf("Failed to make request: %v", err)
234 }
235 if g, w := res.Header.Get("X-Proto"), tt.wantProto; g != w {
236 t.Fatalf("X-Proto header mismatch:\n\tgot: %q\n\twant: %q", g, w)
237 }
238 })
239 }
240 }
241
View as plain text