reverse_test.go 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452
  1. package reverseproxy
  2. import (
  3. "bytes"
  4. "io/ioutil"
  5. "log"
  6. "net/http"
  7. "net/http/httptest"
  8. "net/url"
  9. "reflect"
  10. "strings"
  11. "testing"
  12. "time"
  13. )
  // fakeHopHeader is a test-only header name that init appends to hopHeaders;
// TestReverseProxy has the backend set it and expects the proxy to drop it
// from the response.
const (
	fakeHopHeader = "X-Fake-Hop-Header-For-Test"
)
  15. func init() {
  16. hopHeaders = append(hopHeaders, fakeHopHeader)
  17. }
  18. func TestReverseProxy(t *testing.T) {
  19. backendResponse := "I am the backend"
  20. backendStatus := 404
  21. backend := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
  22. if len(req.TransferEncoding) > 0 {
  23. t.Errorf("backend got unexpected TransferEncoding: %v", req.TransferEncoding)
  24. }
  25. if req.Header.Get("X-Forwarded-For") == "" {
  26. t.Errorf("didn't get X-Forwarded-For header")
  27. }
  28. if c := req.Header.Get("Connection"); c != "" {
  29. t.Errorf("handler got Connection header value %q", c)
  30. }
  31. if c := req.Header.Get("Upgrade"); c != "" {
  32. t.Errorf("handler got Upgrade header value %q", c)
  33. }
  34. if c := req.Header.Get("Proxy-Connection"); c != "" {
  35. t.Errorf("handler got Proxy-Connection header value %q", c)
  36. }
  37. if c := req.Host; c == "" {
  38. t.Errorf("backend got Host header %q", c)
  39. }
  40. rw.Header().Set("X-Foo", "bar")
  41. rw.Header().Set(fakeHopHeader, "foo")
  42. rw.Header().Set("Trailers", "not a special header field name")
  43. rw.Header().Set("Trailer", "X-Trailer")
  44. rw.Header().Set("Upgrade", "foo")
  45. rw.Header().Add("X-Multi-Value", "foo")
  46. rw.Header().Add("X-Multi-Value", "bar")
  47. http.SetCookie(rw, &http.Cookie{Name: "flavor", Value: "chocolateChip"})
  48. rw.WriteHeader(backendStatus)
  49. rw.Write([]byte(backendResponse))
  50. rw.Header().Set("X-Trailer", "trailer_value")
  51. }))
  52. defer backend.Close()
  53. backendURL, err := url.Parse(backend.URL)
  54. if err != nil {
  55. t.Fatal(err)
  56. }
  57. proxyHandler := NewReverseProxy(backendURL)
  58. proxyHandler.ErrorLog = log.New(ioutil.Discard, "", 0)
  59. frontend := httptest.NewServer(proxyHandler)
  60. defer frontend.Close()
  61. getReq, _ := http.NewRequest("GET", frontend.URL, nil)
  62. getReq.Host = "some host"
  63. getReq.Header.Set("Connection", "close")
  64. getReq.Header.Set("Proxy-Connection", "should be deleted")
  65. getReq.Header.Set("Upgrade", "foo")
  66. getReq.Close = true
  67. res, err := http.DefaultClient.Do(getReq)
  68. if err != nil {
  69. t.Fatalf("Get: %v", err)
  70. }
  71. if g, e := res.StatusCode, backendStatus; g != e {
  72. t.Errorf("got res.StatusCode %d; expected %d", g, e)
  73. }
  74. if g, e := res.Header.Get("X-Foo"), "bar"; g != e {
  75. t.Errorf("got X-Foo %q; expected %q", g, e)
  76. }
  77. if c := res.Header.Get(fakeHopHeader); c != "" {
  78. t.Errorf("got %s header value %q", fakeHopHeader, c)
  79. }
  80. if g, e := res.Header.Get("Trailers"), "not a special header field name"; g != e {
  81. t.Errorf("header Trailers = %q; want %q", g, e)
  82. }
  83. if g, e := len(res.Header["X-Multi-Value"]), 2; g != e {
  84. t.Errorf("got %d X-Multi-Value header values; expected %d", g, e)
  85. }
  86. if g, e := len(res.Header["Set-Cookie"]), 1; g != e {
  87. t.Fatalf("got %d SetCookies, want %d", g, e)
  88. }
  89. if g, e := res.Trailer, (http.Header{"X-Trailer": nil}); !reflect.DeepEqual(g, e) {
  90. t.Errorf("before reading body, Trailer = %#v; want %#v", g, e)
  91. }
  92. if cookie := res.Cookies()[0]; cookie.Name != "flavor" {
  93. t.Errorf("unexpected cookie %q", cookie.Name)
  94. }
  95. bodyBytes, _ := ioutil.ReadAll(res.Body)
  96. if g, e := string(bodyBytes), backendResponse; g != e {
  97. t.Errorf("got body %q; expected %q", g, e)
  98. }
  99. if g, e := res.Trailer.Get("X-Trailer"), "trailer_value"; g != e {
  100. t.Errorf("Trailer(X-Trailer) = %q ; want %q", g, e)
  101. }
  102. }
  103. func TestReverseProxyStripHeadersPresentInConnection(t *testing.T) {
  104. const fakeConnectionToken = "X-Fake-Connection-Token"
  105. const backendResponse = "I am the backend"
  106. backend := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
  107. if c := req.Header.Get(fakeConnectionToken); c != "" {
  108. t.Errorf("handler got header %q = %q; want empty", fakeConnectionToken, c)
  109. }
  110. if c := req.Header.Get("Upgrade"); c != "" {
  111. t.Errorf("handler got header %q = %q; want empty", "Upgrade", c)
  112. }
  113. rw.Header().Set("Connection", "Upgrade, "+fakeConnectionToken)
  114. rw.Header().Set("Upgrade", "should be deleted")
  115. rw.Header().Set(fakeConnectionToken, "should be deleted")
  116. rw.Write([]byte(backendResponse))
  117. }))
  118. defer backend.Close()
  119. backendURL, err := url.Parse(backend.URL)
  120. if err != nil {
  121. t.Fatal(err)
  122. }
  123. proxyHandler := NewReverseProxy(backendURL)
  124. frontend := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
  125. proxyHandler.ServeHTTP(rw, req)
  126. if c := req.Header.Get("Upgrade"); c != "original value" {
  127. t.Errorf("handler modified header %q = %q; want %q", "Upgrade", c, "original value")
  128. }
  129. }))
  130. defer frontend.Close()
  131. getReq, _ := http.NewRequest("GET", frontend.URL, nil)
  132. getReq.Header.Set("Connection", "Upgrade, "+fakeConnectionToken)
  133. getReq.Header.Set("Upgrade", "original value")
  134. getReq.Header.Set(fakeConnectionToken, "should be deleted")
  135. res, err := http.DefaultClient.Do(getReq)
  136. if err != nil {
  137. t.Fatalf("Get: %v", err)
  138. }
  139. defer res.Body.Close()
  140. bodyBytes, err := ioutil.ReadAll(res.Body)
  141. if err != nil {
  142. t.Fatalf("reading body: %v", err)
  143. }
  144. if g, e := string(bodyBytes), backendResponse; g != e {
  145. t.Errorf("got body %q; want %q", g, e)
  146. }
  147. if c := res.Header.Get("Upgrade"); c != "" {
  148. t.Errorf("handler got header %q = %q; want empty", "Upgrade", c)
  149. }
  150. if c := res.Header.Get(fakeConnectionToken); c != "" {
  151. t.Errorf("handler got header %q = %q; want empty", fakeConnectionToken, c)
  152. }
  153. }
  154. func TestXForwardedFor(t *testing.T) {
  155. const prevForwardedFor = "client ip"
  156. const backendResponse = "I am the backend"
  157. const backendStatus = 404
  158. backend := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
  159. if req.Header.Get("X-Forwarded-For") == "" {
  160. t.Errorf("didn't get X-Forwarded-For header")
  161. }
  162. if !strings.Contains(req.Header.Get("X-Forwarded-For"), prevForwardedFor) {
  163. t.Errorf("X-Forwarded-For didn't contain prior data")
  164. }
  165. rw.WriteHeader(backendStatus)
  166. rw.Write([]byte(backendResponse))
  167. }))
  168. defer backend.Close()
  169. backendURL, err := url.Parse(backend.URL)
  170. if err != nil {
  171. t.Fatal(err)
  172. }
  173. proxyHandler := NewReverseProxy(backendURL)
  174. frontend := httptest.NewServer(proxyHandler)
  175. defer frontend.Close()
  176. getReq, _ := http.NewRequest("GET", frontend.URL, nil)
  177. getReq.Header.Set("X-Forwarded-For", prevForwardedFor)
  178. getReq.Close = true
  179. res, err := http.DefaultClient.Do(getReq)
  180. if err != nil {
  181. t.Fatalf("Get: %v", err)
  182. }
  183. defer res.Body.Close()
  184. if g, e := res.StatusCode, backendStatus; g != e {
  185. t.Errorf("got res.StatusCode %d; expected %d", g, e)
  186. }
  187. bodyBytes, _ := ioutil.ReadAll(res.Body)
  188. if g, e := string(bodyBytes), backendResponse; g != e {
  189. t.Errorf("got body %q; expected %q", g, e)
  190. }
  191. }
  // proxyQueryTests is the table driving TestReverseProxyQuery. Each entry
// pairs a query attached to the backend target URL with one attached to the
// client's request, and gives the raw query string the backend should end up
// seeing.
var proxyQueryTests = []struct {
	baseSuffix string // suffix to add to backend URL
	reqSuffix  string // suffix to add to frontend's request URL
	want       string // what backend should see for final request URL (without ?)
}{
	{}, // neither side carries a query
	{baseSuffix: "?sta=tic", reqSuffix: "?us=er", want: "sta=tic&us=er"},
	{reqSuffix: "?us=er", want: "us=er"},
	{baseSuffix: "?sta=tic", want: "sta=tic"},
}
  202. func TestReverseProxyQuery(t *testing.T) {
  203. backend := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
  204. rw.Header().Set("X-Got-Query", req.URL.RawQuery)
  205. rw.Write([]byte("hi"))
  206. }))
  207. defer backend.Close()
  208. for i, tt := range proxyQueryTests {
  209. backendURL, err := url.Parse(backend.URL + tt.baseSuffix)
  210. if err != nil {
  211. t.Fatal(err)
  212. }
  213. frontend := httptest.NewServer(NewReverseProxy(backendURL))
  214. req, _ := http.NewRequest("GET", frontend.URL+tt.reqSuffix, nil)
  215. req.Close = true
  216. res, err := http.DefaultClient.Do(req)
  217. if err != nil {
  218. t.Fatalf("%d. Get: %v", i, err)
  219. }
  220. if g, e := res.Header.Get("X-Got-Query"), tt.want; g != e {
  221. t.Errorf("%d. got query %q; expected %q", i, g, e)
  222. }
  223. res.Body.Close()
  224. frontend.Close()
  225. }
  226. }
  227. func TestReverseProxyFlushInterval(t *testing.T) {
  228. const expected = "hi"
  229. backend := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
  230. rw.Write([]byte(expected))
  231. }))
  232. defer backend.Close()
  233. backendURL, err := url.Parse(backend.URL)
  234. if err != nil {
  235. t.Fatal(err)
  236. }
  237. proxyHandler := NewReverseProxy(backendURL)
  238. proxyHandler.FlushInterval = time.Microsecond
  239. done := make(chan bool)
  240. onExitFlushLoop = func() { done <- true }
  241. frontend := httptest.NewServer(proxyHandler)
  242. defer frontend.Close()
  243. getReq, _ := http.NewRequest("GET", frontend.URL, nil)
  244. getReq.Close = true
  245. res, err := http.DefaultClient.Do(getReq)
  246. if err != nil {
  247. t.Fatalf("Get: %v", err)
  248. }
  249. defer res.Body.Close()
  250. bodyBytes, _ := ioutil.ReadAll(res.Body)
  251. if g, e := string(bodyBytes), expected; g != e {
  252. t.Errorf("got body %q; expected %q", g, e)
  253. }
  254. select {
  255. case <-done:
  256. // do nothing
  257. case <-time.After(3 * time.Second):
  258. t.Errorf("maxLatencyWriter flushLoop() never exited")
  259. }
  260. }
  261. func TestReverseProxyCancelation(t *testing.T) {
  262. const backendResponse = "I am the backend"
  263. reqInFlight := make(chan bool)
  264. backend := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
  265. close(reqInFlight)
  266. select {
  267. case <-time.After(time.Second * 3):
  268. t.Errorf("Handler never saw CloseNotify")
  269. case <-rw.(http.CloseNotifier).CloseNotify():
  270. // do nothing
  271. }
  272. rw.WriteHeader(http.StatusOK)
  273. rw.Write([]byte(backendResponse))
  274. }))
  275. defer backend.Close()
  276. backendURL, err := url.Parse(backend.URL)
  277. if err != nil {
  278. t.Fatal(err)
  279. }
  280. proxyHandler := NewReverseProxy(backendURL)
  281. proxyHandler.ErrorLog = log.New(ioutil.Discard, "", 0)
  282. frontend := httptest.NewServer(proxyHandler)
  283. defer frontend.Close()
  284. getReq, err := http.NewRequest("GET", frontend.URL, nil)
  285. if err != nil {
  286. t.Fatal(err)
  287. }
  288. go func() {
  289. <-reqInFlight
  290. http.DefaultTransport.(*http.Transport).CancelRequest(getReq)
  291. }()
  292. res, err := http.DefaultClient.Do(getReq)
  293. if res != nil {
  294. t.Errorf("got response %v; want nil", res.Status)
  295. }
  296. if err == nil {
  297. t.Error("DefaultClient.Do() returned nil error; want non-nil error")
  298. }
  299. }
  300. func TestReverProxyPost(t *testing.T) {
  301. const backendResponse = "I am the backend"
  302. const backendStatus = 200
  303. var requestBody = bytes.Repeat([]byte("a"), 1<<20)
  304. backend := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
  305. requestData, err := ioutil.ReadAll(req.Body)
  306. if err != nil {
  307. t.Errorf("Backend body read = %v", err)
  308. }
  309. if len(requestData) != len(requestBody) {
  310. t.Errorf("Backend read %d request body bytes; want %d", len(requestData), len(requestBody))
  311. }
  312. if !bytes.Equal(requestData, requestBody) {
  313. t.Error("Backend read wrong request body.")
  314. }
  315. rw.Write([]byte(backendResponse))
  316. }))
  317. defer backend.Close()
  318. backendURL, err := url.Parse(backend.URL)
  319. if err != nil {
  320. t.Fatal(err)
  321. }
  322. proxyHandler := NewReverseProxy(backendURL)
  323. frontend := httptest.NewServer(proxyHandler)
  324. defer frontend.Close()
  325. res, err := http.Post(frontend.URL, "", bytes.NewReader(requestBody))
  326. if err != nil {
  327. t.Fatal(err)
  328. }
  329. defer res.Body.Close()
  330. bodyBytes, err := ioutil.ReadAll(res.Body)
  331. if err != nil {
  332. t.Fatal(err)
  333. }
  334. if g, e := string(bodyBytes), backendResponse; g != e {
  335. t.Errorf("got response %v, want %v", g, e)
  336. }
  337. }
  338. func TestHTTPTunnel(t *testing.T) {
  339. const backendResponse = "I am the backend"
  340. backend := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
  341. rw.Write([]byte(backendResponse))
  342. }))
  343. defer backend.Close()
  344. backendURL, err := url.Parse(backend.URL)
  345. if err != nil {
  346. t.Fatal(err)
  347. }
  348. proxyHandler := NewReverseProxy(backendURL)
  349. frontend := httptest.NewServer(proxyHandler)
  350. defer frontend.Close()
  351. frontendURL, err := url.Parse(frontend.URL)
  352. if err != nil {
  353. t.Fatal(err)
  354. }
  355. getReq := &http.Request{
  356. Method: "CONNECT",
  357. URL: &url.URL{
  358. Host: frontendURL.Host,
  359. Scheme: frontendURL.Scheme,
  360. Path: "google.com:80",
  361. },
  362. Header: http.Header{},
  363. }
  364. res, err := http.DefaultTransport.(*http.Transport).RoundTrip(getReq)
  365. if err != nil {
  366. t.Fatal(err)
  367. }
  368. defer res.Body.Close()
  369. if res.Status != "200 OK" {
  370. t.Errorf("got response status %v, want %v", res.Status, "200 OK")
  371. }
  372. }