websocketproxy_test.go 2.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134
  1. package websocketproxy
  2. import (
  3. "log"
  4. "net/http"
  5. "net/url"
  6. "testing"
  7. "time"
  8. "github.com/gorilla/websocket"
  9. )
  10. var (
  11. serverURL = "ws://127.0.0.1:7777"
  12. backendURL = "ws://127.0.0.1:8888"
  13. )
  14. func TestProxy(t *testing.T) {
  15. // websocket proxy
  16. supportedSubProtocols := []string{"test-protocol"}
  17. upgrader := &websocket.Upgrader{
  18. ReadBufferSize: 4096,
  19. WriteBufferSize: 4096,
  20. CheckOrigin: func(r *http.Request) bool {
  21. return true
  22. },
  23. Subprotocols: supportedSubProtocols,
  24. }
  25. u, _ := url.Parse(backendURL)
  26. proxy := NewProxy(u, Options{
  27. SkipTLSValidation: false,
  28. SkipOriginCheck: false,
  29. Logger: nil,
  30. })
  31. proxy.Upgrader = upgrader
  32. mux := http.NewServeMux()
  33. mux.Handle("/proxy", proxy)
  34. go func() {
  35. if err := http.ListenAndServe(":7777", mux); err != nil {
  36. t.Fatal("ListenAndServe: ", err)
  37. }
  38. }()
  39. time.Sleep(time.Millisecond * 100)
  40. // backend echo server
  41. go func() {
  42. mux2 := http.NewServeMux()
  43. mux2.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
  44. // Don't upgrade if original host header isn't preserved
  45. if r.Host != "127.0.0.1:7777" {
  46. log.Printf("Host header set incorrectly. Expecting 127.0.0.1:7777 got %s", r.Host)
  47. return
  48. }
  49. conn, err := upgrader.Upgrade(w, r, nil)
  50. if err != nil {
  51. log.Println(err)
  52. return
  53. }
  54. messageType, p, err := conn.ReadMessage()
  55. if err != nil {
  56. return
  57. }
  58. if err = conn.WriteMessage(messageType, p); err != nil {
  59. return
  60. }
  61. })
  62. err := http.ListenAndServe(":8888", mux2)
  63. if err != nil {
  64. t.Fatal("ListenAndServe: ", err)
  65. }
  66. }()
  67. time.Sleep(time.Millisecond * 100)
  68. // let's us define two subprotocols, only one is supported by the server
  69. clientSubProtocols := []string{"test-protocol", "test-notsupported"}
  70. h := http.Header{}
  71. for _, subprot := range clientSubProtocols {
  72. h.Add("Sec-WebSocket-Protocol", subprot)
  73. }
  74. // frontend server, dial now our proxy, which will reverse proxy our
  75. // message to the backend websocket server.
  76. conn, resp, err := websocket.DefaultDialer.Dial(serverURL+"/proxy", h)
  77. if err != nil {
  78. t.Fatal(err)
  79. }
  80. // check if the server really accepted only the first one
  81. in := func(desired string) bool {
  82. for _, prot := range resp.Header[http.CanonicalHeaderKey("Sec-WebSocket-Protocol")] {
  83. if desired == prot {
  84. return true
  85. }
  86. }
  87. return false
  88. }
  89. if !in("test-protocol") {
  90. t.Error("test-protocol should be available")
  91. }
  92. if in("test-notsupported") {
  93. t.Error("test-notsupported should be not recevied from the server.")
  94. }
  95. // now write a message and send it to the backend server (which goes trough
  96. // proxy..)
  97. msg := "hello kite"
  98. err = conn.WriteMessage(websocket.TextMessage, []byte(msg))
  99. if err != nil {
  100. t.Error(err)
  101. }
  102. messageType, p, err := conn.ReadMessage()
  103. if err != nil {
  104. t.Error(err)
  105. }
  106. if messageType != websocket.TextMessage {
  107. t.Error("incoming message type is not Text")
  108. }
  109. if msg != string(p) {
  110. t.Errorf("expecting: %s, got: %s", msg, string(p))
  111. }
  112. }