123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134 |
- package websocketproxy
- import (
- "log"
- "net/http"
- "net/url"
- "testing"
- "time"
- "github.com/gorilla/websocket"
- )
- var (
- serverURL = "ws://127.0.0.1:7777"
- backendURL = "ws://127.0.0.1:8888"
- )
- func TestProxy(t *testing.T) {
- // websocket proxy
- supportedSubProtocols := []string{"test-protocol"}
- upgrader := &websocket.Upgrader{
- ReadBufferSize: 4096,
- WriteBufferSize: 4096,
- CheckOrigin: func(r *http.Request) bool {
- return true
- },
- Subprotocols: supportedSubProtocols,
- }
- u, _ := url.Parse(backendURL)
- proxy := NewProxy(u, Options{
- SkipTLSValidation: false,
- SkipOriginCheck: false,
- Logger: nil,
- })
- proxy.Upgrader = upgrader
- mux := http.NewServeMux()
- mux.Handle("/proxy", proxy)
- go func() {
- if err := http.ListenAndServe(":7777", mux); err != nil {
- t.Fatal("ListenAndServe: ", err)
- }
- }()
- time.Sleep(time.Millisecond * 100)
- // backend echo server
- go func() {
- mux2 := http.NewServeMux()
- mux2.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
- // Don't upgrade if original host header isn't preserved
- if r.Host != "127.0.0.1:7777" {
- log.Printf("Host header set incorrectly. Expecting 127.0.0.1:7777 got %s", r.Host)
- return
- }
- conn, err := upgrader.Upgrade(w, r, nil)
- if err != nil {
- log.Println(err)
- return
- }
- messageType, p, err := conn.ReadMessage()
- if err != nil {
- return
- }
- if err = conn.WriteMessage(messageType, p); err != nil {
- return
- }
- })
- err := http.ListenAndServe(":8888", mux2)
- if err != nil {
- t.Fatal("ListenAndServe: ", err)
- }
- }()
- time.Sleep(time.Millisecond * 100)
- // let's us define two subprotocols, only one is supported by the server
- clientSubProtocols := []string{"test-protocol", "test-notsupported"}
- h := http.Header{}
- for _, subprot := range clientSubProtocols {
- h.Add("Sec-WebSocket-Protocol", subprot)
- }
- // frontend server, dial now our proxy, which will reverse proxy our
- // message to the backend websocket server.
- conn, resp, err := websocket.DefaultDialer.Dial(serverURL+"/proxy", h)
- if err != nil {
- t.Fatal(err)
- }
- // check if the server really accepted only the first one
- in := func(desired string) bool {
- for _, prot := range resp.Header[http.CanonicalHeaderKey("Sec-WebSocket-Protocol")] {
- if desired == prot {
- return true
- }
- }
- return false
- }
- if !in("test-protocol") {
- t.Error("test-protocol should be available")
- }
- if in("test-notsupported") {
- t.Error("test-notsupported should be not recevied from the server.")
- }
- // now write a message and send it to the backend server (which goes trough
- // proxy..)
- msg := "hello kite"
- err = conn.WriteMessage(websocket.TextMessage, []byte(msg))
- if err != nil {
- t.Error(err)
- }
- messageType, p, err := conn.ReadMessage()
- if err != nil {
- t.Error(err)
- }
- if messageType != websocket.TextMessage {
- t.Error("incoming message type is not Text")
- }
- if msg != string(p) {
- t.Errorf("expecting: %s, got: %s", msg, string(p))
- }
- }
|