Forráskód Böngészése

Added custom header to websocket

Toby Chui 3 hónapja
szülő
commit
8822bbeadc
6 módosított fájl, 118 hozzáadás és 53 törlés
  1. 1 1
      def.go
  2. 1 42
      main.go
  3. 10 6
      mod/dynamicproxy/proxyRequestHandler.go
  4. 52 4
      mod/websocketproxy/websocketproxy.go
  5. 43 0
      start.go
  6. 11 0
      web/components/status.html

+ 1 - 1
def.go

@@ -43,7 +43,7 @@ const (
 	/* Build Constants */
 	SYSTEM_NAME       = "Zoraxy"
 	SYSTEM_VERSION    = "3.1.5"
-	DEVELOPMENT_BUILD = true /* Development: Set to false to use embedded web fs */
+	DEVELOPMENT_BUILD = false /* Development: Set to false to use embedded web fs */
 
 	/* System Constants */
 	DATABASE_PATH              = "sys.db"

+ 1 - 42
main.go

@@ -57,47 +57,6 @@ func SetupCloseHandler() {
 	}()
 }
 
-func ShutdownSeq() {
-	SystemWideLogger.Println("Shutting down " + SYSTEM_NAME)
-	SystemWideLogger.Println("Closing Netstats Listener")
-	if netstatBuffers != nil {
-		netstatBuffers.Close()
-	}
-
-	SystemWideLogger.Println("Closing Statistic Collector")
-	if statisticCollector != nil {
-		statisticCollector.Close()
-	}
-
-	if mdnsTickerStop != nil {
-		SystemWideLogger.Println("Stopping mDNS Discoverer (might take a few minutes)")
-		// Stop the mdns service
-		mdnsTickerStop <- true
-	}
-	if mdnsScanner != nil {
-		mdnsScanner.Close()
-	}
-	SystemWideLogger.Println("Shutting down load balancer")
-	if loadBalancer != nil {
-		loadBalancer.Close()
-	}
-	SystemWideLogger.Println("Closing Certificates Auto Renewer")
-	if acmeAutoRenewer != nil {
-		acmeAutoRenewer.Close()
-	}
-	//Remove the tmp folder
-	SystemWideLogger.Println("Cleaning up tmp files")
-	os.RemoveAll("./tmp")
-
-	//Close database
-	SystemWideLogger.Println("Stopping system database")
-	sysdb.Close()
-
-	//Close logger
-	SystemWideLogger.Println("Closing system wide logger")
-	SystemWideLogger.Close()
-}
-
 func main() {
 	//Parse startup flags
 	flag.Parse()
@@ -141,7 +100,7 @@ func main() {
 		csrf.SameSite(csrf.SameSiteLaxMode),
 	)
 
-	//Startup all modules
+	//Startup all modules, see start.go
 	startupSequence()
 
 	//Initiate management interface APIs

+ 10 - 6
mod/dynamicproxy/proxyRequestHandler.go

@@ -143,9 +143,11 @@ func (h *ProxyHandler) hostRequest(w http.ResponseWriter, r *http.Request, targe
 		}
 		h.Parent.logRequest(r, true, 101, "host-websocket", selectedUpstream.OriginIpOrDomain)
 		wspHandler := websocketproxy.NewProxy(u, websocketproxy.Options{
-			SkipTLSValidation: selectedUpstream.SkipCertValidations,
-			SkipOriginCheck:   selectedUpstream.SkipWebSocketOriginCheck,
-			Logger:            h.Parent.Option.Logger,
+			SkipTLSValidation:  selectedUpstream.SkipCertValidations,
+			SkipOriginCheck:    selectedUpstream.SkipWebSocketOriginCheck,
+			CopyAllHeaders:     true,
+			UserDefinedHeaders: target.HeaderRewriteRules.UserDefinedHeaders,
+			Logger:             h.Parent.Option.Logger,
 		})
 		wspHandler.ServeHTTP(w, r)
 		return
@@ -221,9 +223,11 @@ func (h *ProxyHandler) vdirRequest(w http.ResponseWriter, r *http.Request, targe
 		}
 		h.Parent.logRequest(r, true, 101, "vdir-websocket", target.Domain)
 		wspHandler := websocketproxy.NewProxy(u, websocketproxy.Options{
-			SkipTLSValidation: target.SkipCertValidations,
-			SkipOriginCheck:   true, //You should not use websocket via virtual directory. But keep this to true for compatibility
-			Logger:            h.Parent.Option.Logger,
+			SkipTLSValidation:  target.SkipCertValidations,
+			SkipOriginCheck:    true, //You should not use websocket via virtual directory. But keep this to true for compatibility
+			CopyAllHeaders:     true,
+			UserDefinedHeaders: target.parent.HeaderRewriteRules.UserDefinedHeaders,
+			Logger:             h.Parent.Option.Logger,
 		})
 		wspHandler.ServeHTTP(w, r)
 		return

+ 52 - 4
mod/websocketproxy/websocketproxy.go

@@ -13,6 +13,7 @@ import (
 	"strings"
 
 	"github.com/gorilla/websocket"
+	"imuslab.com/zoraxy/mod/dynamicproxy/rewrite"
 	"imuslab.com/zoraxy/mod/info/logger"
 )
 
@@ -56,9 +57,11 @@ type WebsocketProxy struct {
 
 // Additional options for websocket proxy runtime
 type Options struct {
-	SkipTLSValidation bool           //Skip backend TLS validation
-	SkipOriginCheck   bool           //Skip origin check
-	Logger            *logger.Logger //Logger, can be nil
+	SkipTLSValidation  bool                         //Skip backend TLS validation
+	SkipOriginCheck    bool                         //Skip origin check
+	CopyAllHeaders     bool                         //Copy all headers from incoming request to backend request
+	UserDefinedHeaders []*rewrite.UserDefinedHeader //User defined headers
+	Logger             *logger.Logger               //Logger, can be nil
 }
 
 // ProxyHandler returns a new http.Handler interface that reverse proxies the
@@ -78,7 +81,14 @@ func NewProxy(target *url.URL, options Options) *WebsocketProxy {
 		u.RawQuery = r.URL.RawQuery
 		return &u
 	}
-	return &WebsocketProxy{Backend: backend, Verbal: false, Options: options}
+
+	// Create a new websocket proxy
+	wsprox := &WebsocketProxy{Backend: backend, Verbal: false, Options: options}
+	if options.CopyAllHeaders {
+		wsprox.Director = DefaultDirector
+	}
+
+	return wsprox
 }
 
 // Utilities function for log printing
@@ -90,6 +100,35 @@ func (w *WebsocketProxy) Println(messsage string, err error) {
 	log.Println("[websocketproxy] [system:info]"+messsage, err)
 }
 
+// DefaultDirector is the default implementation of Director, which copies
+// all headers from the incoming request to the outgoing request.
+func DefaultDirector(r *http.Request, h http.Header) {
+	//Copy all header values from request to target header
+	for k, vv := range r.Header {
+		for _, v := range vv {
+			h.Set(k, v)
+		}
+	}
+
+	// Remove hop-by-hop headers
+	for _, removePendingHeader := range []string{
+		"Connection",
+		"Keep-Alive",
+		"Proxy-Authenticate",
+		"Proxy-Authorization",
+		"Te",
+		"Trailers",
+		"Transfer-Encoding",
+		"Sec-WebSocket-Extensions",
+		"Sec-WebSocket-Key",
+		"Sec-WebSocket-Protocol",
+		"Sec-WebSocket-Version",
+		"Upgrade",
+	} {
+		h.Del(removePendingHeader)
+	}
+}
+
 // ServeHTTP implements the http.Handler that proxies WebSocket connections.
 func (w *WebsocketProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
 	if w.Backend == nil {
@@ -162,6 +201,15 @@ func (w *WebsocketProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
 		w.Director(req, requestHeader)
 	}
 
+	// Replace header variables and copy user-defined headers
+	rewrittenUserDefinedHeaders := rewrite.PopulateRequestHeaderVariables(req, w.Options.UserDefinedHeaders)
+	upstreamHeaders, _ := rewrite.SplitUpDownStreamHeaders(&rewrite.HeaderRewriteOptions{
+		UserDefinedHeaders: rewrittenUserDefinedHeaders,
+	})
+	for _, headerValuePair := range upstreamHeaders {
+		requestHeader.Set(headerValuePair[0], headerValuePair[1])
+	}
+
 	// Connect to the backend URL, also pass the headers we get from the requst
 	// together with the Forwarded headers we prepared above.
 	// TODO: support multiplexing on the same backend connection instead of

+ 43 - 0
start.go

@@ -331,6 +331,7 @@ func startupSequence() {
 
 }
 
+/* Finalize Startup Sequence */
 // This sequence start after everything is initialized
 func finalSequence() {
 	//Start ACME renew agent
@@ -339,3 +340,45 @@ func finalSequence() {
 	//Inject routing rules
 	registerBuildInRoutingRules()
 }
+
+/* Shutdown Sequence */
+func ShutdownSeq() {
+	SystemWideLogger.Println("Shutting down " + SYSTEM_NAME)
+	SystemWideLogger.Println("Closing Netstats Listener")
+	if netstatBuffers != nil {
+		netstatBuffers.Close()
+	}
+
+	SystemWideLogger.Println("Closing Statistic Collector")
+	if statisticCollector != nil {
+		statisticCollector.Close()
+	}
+
+	if mdnsTickerStop != nil {
+		SystemWideLogger.Println("Stopping mDNS Discoverer (might take a few minutes)")
+		// Stop the mdns service
+		mdnsTickerStop <- true
+	}
+	if mdnsScanner != nil {
+		mdnsScanner.Close()
+	}
+	SystemWideLogger.Println("Shutting down load balancer")
+	if loadBalancer != nil {
+		loadBalancer.Close()
+	}
+	SystemWideLogger.Println("Closing Certificates Auto Renewer")
+	if acmeAutoRenewer != nil {
+		acmeAutoRenewer.Close()
+	}
+	//Remove the tmp folder
+	SystemWideLogger.Println("Cleaning up tmp files")
+	os.RemoveAll("./tmp")
+
+	//Close database
+	SystemWideLogger.Println("Stopping system database")
+	sysdb.Close()
+
+	//Close logger
+	SystemWideLogger.Println("Closing system wide logger")
+	SystemWideLogger.Close()
+}

+ 11 - 0
web/components/status.html

@@ -1,3 +1,10 @@
+<style>
+    #redirect.disabled{
+        opacity: 0.7;
+        pointer-events: none;
+        user-select: none;
+    }
+</style>
 <div class="ui stackable grid">
     <div class="ten wide column serverstatusWrapper">
         <div id="serverstatus" class="ui statustab inverted segment">
@@ -362,9 +369,11 @@
                 }
                 if (enabled){
                     //$("#redirect").show();
+                    $("#redirect").removeClass("disabled");
                     msgbox("Port 80 listener enabled");
                 }else{
                     //$("#redirect").hide();
+                    $("#redirect").addClass("disabled");
                     msgbox("Port 80 listener disabled");
                 }
             }
@@ -402,9 +411,11 @@
         $.get("/api/proxy/listenPort80", function(data){
             if (data){
                 $("#listenP80").checkbox("set checked");
+                $("#redirect").removeClass("disabled");
                 //$("#redirect").show();
             }else{
                 $("#listenP80").checkbox("set unchecked");
+                $("#redirect").addClass("disabled");
                 //$("#redirect").hide();
             }