|
@@ -3,6 +3,8 @@ package loadbalance
|
|
|
import (
|
|
|
"errors"
|
|
|
"fmt"
|
|
|
+ "log"
|
|
|
+ "math/rand"
|
|
|
"net/http"
|
|
|
)
|
|
|
|
|
@@ -15,12 +17,122 @@ import (
|
|
|
|
|
|
// GetRequestUpstreamTarget return the upstream target where this
|
|
|
// request should be routed
|
|
|
-func (m *RouteManager) GetRequestUpstreamTarget(r *http.Request, origins []*Upstream) (*Upstream, error) {
|
|
|
+func (m *RouteManager) GetRequestUpstreamTarget(w http.ResponseWriter, r *http.Request, origins []*Upstream, useStickySession bool) (*Upstream, error) {
|
|
|
if len(origins) == 0 {
|
|
|
return nil, errors.New("no upstream is defined for this host")
|
|
|
}
|
|
|
+ var targetOrigin = origins[0]
|
|
|
+ if useStickySession {
|
|
|
+ //Use stick session, check which origins this request previously used
|
|
|
+ targetOriginId, err := m.getSessionHandler(r, origins)
|
|
|
+ if err != nil {
|
|
|
+ //No valid session found. Assign a new upstream
|
|
|
+ targetOrigin, index, err := getRandomUpstreamByWeight(origins)
|
|
|
+ if err != nil {
|
|
|
+ fmt.Println("Oops. Unable to get random upstream")
|
|
|
+ targetOrigin = origins[0]
|
|
|
+ index = 0
|
|
|
+ }
|
|
|
+ m.setSessionHandler(w, r, targetOrigin.OriginIpOrDomain, index)
|
|
|
+ return targetOrigin, nil
|
|
|
+ }
|
|
|
|
|
|
- //TODO: Add upstream picking algorithm here
|
|
|
- fmt.Println("DEBUG: Picking origin " + origins[0].OriginIpOrDomain)
|
|
|
+ //Valid session found. Resume the previous session
|
|
|
+ return origins[targetOriginId], nil
|
|
|
+ } else {
|
|
|
+ //Do not use stick session. Get a random one
|
|
|
+ var err error
|
|
|
+ targetOrigin, _, err = getRandomUpstreamByWeight(origins)
|
|
|
+ if err != nil {
|
|
|
+ log.Println(err)
|
|
|
+ targetOrigin = origins[0]
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ fmt.Println("DEBUG: Picking origin " + targetOrigin.OriginIpOrDomain)
|
|
|
return origins[0], nil
|
|
|
}
|
|
|
+
|
|
|
+/* Features related to session access */
|
|
|
+//Set a new origin for this connection by session
|
|
|
+func (m *RouteManager) setSessionHandler(w http.ResponseWriter, r *http.Request, originIpOrDomain string, index int) error {
|
|
|
+ session, err := m.SessionStore.Get(r, "STICKYSESSION")
|
|
|
+ if err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
+ session.Values["zr_sid_origin"] = originIpOrDomain
|
|
|
+ session.Values["zr_sid_index"] = index
|
|
|
+ session.Options.MaxAge = 86400 //1 day
|
|
|
+ session.Options.Path = "/"
|
|
|
+ err = session.Save(r, w)
|
|
|
+ if err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
+ return nil
|
|
|
+}
|
|
|
+
|
|
|
+// Get the previous connected origin from session
|
|
|
+func (m *RouteManager) getSessionHandler(r *http.Request, upstreams []*Upstream) (int, error) {
|
|
|
+ // Get existing session
|
|
|
+ session, err := m.SessionStore.Get(r, "STICKYSESSION")
|
|
|
+ if err != nil {
|
|
|
+ return -1, err
|
|
|
+ }
|
|
|
+
|
|
|
+ // Retrieve session values for origin
|
|
|
+ originDomainRaw := session.Values["zr_sid_origin"]
|
|
|
+ originIDRaw := session.Values["zr_sid_index"]
|
|
|
+
|
|
|
+ if originDomainRaw == nil || originIDRaw == nil {
|
|
|
+ return -1, errors.New("no session has been set")
|
|
|
+ }
|
|
|
+ originDomain := originDomainRaw.(string)
|
|
|
+ originID := originIDRaw.(int)
|
|
|
+
|
|
|
+ //Check if it has been modified
|
|
|
+ if len(upstreams) < originID || upstreams[originID].OriginIpOrDomain != originDomain {
|
|
|
+ //Mismatch or upstreams has been updated
|
|
|
+ return -1, errors.New("upstreams has been changed")
|
|
|
+ }
|
|
|
+
|
|
|
+ return originID, nil
|
|
|
+}
|
|
|
+
|
|
|
+/* Functions related to random upstream picking */
|
|
|
+// Get a random upstream by the weights defined in Upstream struct, return the upstream, index value and any error
|
|
|
+func getRandomUpstreamByWeight(upstreams []*Upstream) (*Upstream, int, error) {
|
|
|
+ var ret *Upstream
|
|
|
+ sum := 0
|
|
|
+ for _, c := range upstreams {
|
|
|
+ sum += c.Weight
|
|
|
+ }
|
|
|
+ r, err := intRange(0, sum)
|
|
|
+ if err != nil {
|
|
|
+ return ret, -1, err
|
|
|
+ }
|
|
|
+ counter := 0
|
|
|
+ for _, c := range upstreams {
|
|
|
+ r -= c.Weight
|
|
|
+ if r < 0 {
|
|
|
+ return c, counter, nil
|
|
|
+ }
|
|
|
+ counter++
|
|
|
+ }
|
|
|
+ return ret, -1, err
|
|
|
+}
|
|
|
+
|
|
|
+// IntRange returns a random integer in the range from min to max.
|
|
|
+func intRange(min, max int) (int, error) {
|
|
|
+ var result int
|
|
|
+ switch {
|
|
|
+ case min > max:
|
|
|
+ // Fail with error
|
|
|
+ return result, errors.New("min is greater than max")
|
|
|
+ case max == min:
|
|
|
+ result = max
|
|
|
+ case max > min:
|
|
|
+ b := rand.Intn(max-min) + min
|
|
|
+ result = min + int(b)
|
|
|
+ }
|
|
|
+ return result, nil
|
|
|
+}
|