Browse Source

Added sticky session load balancer

Toby Chui 8 months ago
parent
commit
45fe740d29

+ 2 - 2
main.go

@@ -58,8 +58,8 @@ var enableAutoUpdate = flag.Bool("cfgupgrade", true, "Enable auto config upgrade
 var (
 	name        = "Zoraxy"
 	version     = "3.0.8"
-	nodeUUID    = "generic"
-	development = true //Set this to false to use embedded web fs
+	nodeUUID    = "generic" //System uuid, in uuidv4 format
+	development = true      //Set this to false to use embedded web fs
 	bootTime    = time.Now().Unix()
 
 	/*

+ 1 - 0
mod/dynamicproxy/dpcore/dpcore.go

@@ -64,6 +64,7 @@ type ResponseRewriteRuleSet struct {
 	PathPrefix        string //Vdir prefix for root, / will be rewrite to this
 	UpstreamHeaders   [][]string
 	DownstreamHeaders [][]string
+	NoRemoveHopByHop  bool   //Do not remove hop-by-hop headers, dangerous
 	Version           string //Version number of Zoraxy, use for X-Proxy-By
 }
 

+ 7 - 6
mod/dynamicproxy/dynamicproxy.go

@@ -151,18 +151,19 @@ func (router *Router) StartProxyService() error {
 							}
 						}
 
-						selectedUpstream, err := router.loadBalancer.GetRequestUpstreamTarget(r, sep.ActiveOrigins)
+						selectedUpstream, err := router.loadBalancer.GetRequestUpstreamTarget(w, r, sep.ActiveOrigins, sep.UseStickySession)
 						if err != nil {
 							http.ServeFile(w, r, "./web/hosterror.html")
 							log.Println(err.Error())
 							router.logRequest(r, false, 404, "vdir-http", r.Host)
 						}
 						selectedUpstream.ServeHTTP(w, r, &dpcore.ResponseRewriteRuleSet{
-							ProxyDomain:  selectedUpstream.OriginIpOrDomain,
-							OriginalHost: originalHostHeader,
-							UseTLS:       selectedUpstream.RequireTLS,
-							PathPrefix:   "",
-							Version:      sep.parent.Option.HostVersion,
+							ProxyDomain:      selectedUpstream.OriginIpOrDomain,
+							OriginalHost:     originalHostHeader,
+							UseTLS:           selectedUpstream.RequireTLS,
+							NoRemoveHopByHop: sep.DisableHopByHopHeaderRemoval,
+							PathPrefix:       "",
+							Version:          sep.parent.Option.HostVersion,
 						})
 						return
 					}

+ 12 - 0
mod/dynamicproxy/loadbalance/loadbalance.go

@@ -5,6 +5,8 @@ import (
 	"sync"
 	"sync/atomic"
 
+	"github.com/google/uuid"
+	"github.com/gorilla/sessions"
 	"imuslab.com/zoraxy/mod/dynamicproxy/dpcore"
 	"imuslab.com/zoraxy/mod/geodb"
 	"imuslab.com/zoraxy/mod/info/logger"
@@ -17,12 +19,14 @@ import (
 */
 
 type Options struct {
+	SystemUUID           string       //Use for the session store
 	UseActiveHealthCheck bool         //Use active health check, default to false
 	Geodb                *geodb.Store //GeoIP resolver for checking incoming request origin country
 	Logger               *logger.Logger
 }
 
 type RouteManager struct {
+	SessionStore           *sessions.CookieStore
 	LoadBalanceMap         sync.Map  //Sync map to store the last load balance state of a given node
 	OnlineStatusMap        sync.Map  //Sync map to store the online status of a given ip address or domain name
 	onlineStatusTickerStop chan bool //Stopping channel for the online status pinger
@@ -47,7 +51,15 @@ type Upstream struct {
 
 // Create a new load balancer
 func NewLoadBalancer(options *Options) *RouteManager {
+	if options.SystemUUID == "" {
+		//System UUID not passed in. Use random key
+		options.SystemUUID = uuid.New().String()
+	}
+
+	//Generate a session store for stickySession
+	store := sessions.NewCookieStore([]byte("something-very-secret"))
 	return &RouteManager{
+		SessionStore:           store,
 		LoadBalanceMap:         sync.Map{},
 		OnlineStatusMap:        sync.Map{},
 		onlineStatusTickerStop: nil,

+ 115 - 3
mod/dynamicproxy/loadbalance/originPicker.go

@@ -3,6 +3,8 @@ package loadbalance
 import (
 	"errors"
 	"fmt"
+	"log"
+	"math/rand"
 	"net/http"
 )
 
@@ -15,12 +17,122 @@ import (
 
 // GetRequestUpstreamTarget return the upstream target where this
 // request should be routed
-func (m *RouteManager) GetRequestUpstreamTarget(r *http.Request, origins []*Upstream) (*Upstream, error) {
+func (m *RouteManager) GetRequestUpstreamTarget(w http.ResponseWriter, r *http.Request, origins []*Upstream, useStickySession bool) (*Upstream, error) {
 	if len(origins) == 0 {
 		return nil, errors.New("no upstream is defined for this host")
 	}
+	var targetOrigin = origins[0]
+	if useStickySession {
+		//Use stick session, check which origins this request previously used
+		targetOriginId, err := m.getSessionHandler(r, origins)
+		if err != nil {
+			//No valid session found. Assign a new upstream
+			targetOrigin, index, err := getRandomUpstreamByWeight(origins)
+			if err != nil {
+				fmt.Println("Oops. Unable to get random upstream")
+				targetOrigin = origins[0]
+				index = 0
+			}
+			m.setSessionHandler(w, r, targetOrigin.OriginIpOrDomain, index)
+			return targetOrigin, nil
+		}
 
-	//TODO: Add upstream picking algorithm here
-	fmt.Println("DEBUG: Picking origin " + origins[0].OriginIpOrDomain)
+		//Valid session found. Resume the previous session
+		return origins[targetOriginId], nil
+	} else {
+		//Do not use stick session. Get a random one
+		var err error
+		targetOrigin, _, err = getRandomUpstreamByWeight(origins)
+		if err != nil {
+			log.Println(err)
+			targetOrigin = origins[0]
+		}
+	}
+
+	fmt.Println("DEBUG: Picking origin " + targetOrigin.OriginIpOrDomain)
 	return origins[0], nil
 }
+
+/* Features related to session access */
+//Set a new origin for this connection by session
+func (m *RouteManager) setSessionHandler(w http.ResponseWriter, r *http.Request, originIpOrDomain string, index int) error {
+	session, err := m.SessionStore.Get(r, "STICKYSESSION")
+	if err != nil {
+		return err
+	}
+	session.Values["zr_sid_origin"] = originIpOrDomain
+	session.Values["zr_sid_index"] = index
+	session.Options.MaxAge = 86400 //1 day
+	session.Options.Path = "/"
+	err = session.Save(r, w)
+	if err != nil {
+		return err
+	}
+	return nil
+}
+
+// Get the previous connected origin from session
+func (m *RouteManager) getSessionHandler(r *http.Request, upstreams []*Upstream) (int, error) {
+	// Get existing session
+	session, err := m.SessionStore.Get(r, "STICKYSESSION")
+	if err != nil {
+		return -1, err
+	}
+
+	// Retrieve session values for origin
+	originDomainRaw := session.Values["zr_sid_origin"]
+	originIDRaw := session.Values["zr_sid_index"]
+
+	if originDomainRaw == nil || originIDRaw == nil {
+		return -1, errors.New("no session has been set")
+	}
+	originDomain := originDomainRaw.(string)
+	originID := originIDRaw.(int)
+
+	//Check if it has been modified
+	if len(upstreams) < originID || upstreams[originID].OriginIpOrDomain != originDomain {
+		//Mismatch or upstreams has been updated
+		return -1, errors.New("upstreams has been changed")
+	}
+
+	return originID, nil
+}
+
+/* Functions related to random upstream picking */
+// Get a random upstream by the weights defined in Upstream struct, return the upstream, index value and any error
+func getRandomUpstreamByWeight(upstreams []*Upstream) (*Upstream, int, error) {
+	var ret *Upstream
+	sum := 0
+	for _, c := range upstreams {
+		sum += c.Weight
+	}
+	r, err := intRange(0, sum)
+	if err != nil {
+		return ret, -1, err
+	}
+	counter := 0
+	for _, c := range upstreams {
+		r -= c.Weight
+		if r < 0 {
+			return c, counter, nil
+		}
+		counter++
+	}
+	return ret, -1, err
+}
+
+// IntRange returns a random integer in the range from min to max.
+func intRange(min, max int) (int, error) {
+	var result int
+	switch {
+	case min > max:
+		// Fail with error
+		return result, errors.New("min is greater than max")
+	case max == min:
+		result = max
+	case max > min:
+		b := rand.Intn(max-min) + min
+		result = min + int(b)
+	}
+	return result, nil
+}

+ 2 - 1
mod/dynamicproxy/proxyRequestHandler.go

@@ -112,7 +112,7 @@ func (router *Router) rewriteURL(rooturl string, requestURL string) string {
 func (h *ProxyHandler) hostRequest(w http.ResponseWriter, r *http.Request, target *ProxyEndpoint) {
 	r.Header.Set("X-Forwarded-Host", r.Host)
 	r.Header.Set("X-Forwarded-Server", "zoraxy-"+h.Parent.Option.HostUUID)
-	selectedUpstream, err := h.Parent.loadBalancer.GetRequestUpstreamTarget(r, target.ActiveOrigins)
+	selectedUpstream, err := h.Parent.loadBalancer.GetRequestUpstreamTarget(w, r, target.ActiveOrigins, target.UseStickySession)
 	if err != nil {
 		http.ServeFile(w, r, "./web/rperror.html")
 		log.Println(err.Error())
@@ -164,6 +164,7 @@ func (h *ProxyHandler) hostRequest(w http.ResponseWriter, r *http.Request, targe
 		PathPrefix:        "",
 		UpstreamHeaders:   upstreamHeaders,
 		DownstreamHeaders: downstreamHeaders,
+		NoRemoveHopByHop:  target.DisableHopByHopHeaderRemoval,
 		Version:           target.parent.Option.HostVersion,
 	})
 

+ 1 - 0
mod/dynamicproxy/typedef.go

@@ -133,6 +133,7 @@ type ProxyEndpoint struct {
 	HSTSMaxAge                   int64                               //HSTS max age, set to 0 for disable HSTS headers
 	EnablePermissionPolicyHeader bool                                //Enable injection of permission policy header
 	PermissionPolicy             *permissionpolicy.PermissionsPolicy //Permission policy header
+	DisableHopByHopHeaderRemoval bool                                //TODO: Do not remove hop-by-hop headers
 
 	//Authentication
 	RequireBasicAuth        bool                      //Set to true to request basic auth before proxy

+ 3 - 2
start.go

@@ -104,8 +104,9 @@ func startupSequence() {
 
 	//Create a load balancer
 	loadBalancer = loadbalance.NewLoadBalancer(&loadbalance.Options{
-		Geodb:  geodbStore,
-		Logger: SystemWideLogger,
+		SystemUUID: nodeUUID,
+		Geodb:      geodbStore,
+		Logger:     SystemWideLogger,
 	})
 
 	//Create the access controller