originPicker.go 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154
  1. package loadbalance
  2. import (
  3. "errors"
  4. "fmt"
  5. "log"
  6. "math/rand"
  7. "net/http"
  8. )
/*
Origin Picker
This file contains the code that picks the best origin (upstream)
for each incoming request.
*/
  14. // GetRequestUpstreamTarget return the upstream target where this
  15. // request should be routed
  16. func (m *RouteManager) GetRequestUpstreamTarget(w http.ResponseWriter, r *http.Request, origins []*Upstream, useStickySession bool) (*Upstream, error) {
  17. if len(origins) == 0 {
  18. return nil, errors.New("no upstream is defined for this host")
  19. }
  20. var targetOrigin = origins[0]
  21. if useStickySession {
  22. //Use stick session, check which origins this request previously used
  23. targetOriginId, err := m.getSessionHandler(r, origins)
  24. if err != nil {
  25. //No valid session found. Assign a new upstream
  26. targetOrigin, index, err := getRandomUpstreamByWeight(origins)
  27. if err != nil {
  28. fmt.Println("Oops. Unable to get random upstream")
  29. targetOrigin = origins[0]
  30. index = 0
  31. }
  32. m.setSessionHandler(w, r, targetOrigin.OriginIpOrDomain, index)
  33. return targetOrigin, nil
  34. }
  35. //Valid session found. Resume the previous session
  36. return origins[targetOriginId], nil
  37. } else {
  38. //Do not use stick session. Get a random one
  39. var err error
  40. targetOrigin, _, err = getRandomUpstreamByWeight(origins)
  41. if err != nil {
  42. log.Println(err)
  43. targetOrigin = origins[0]
  44. }
  45. }
  46. //fmt.Println("DEBUG: Picking origin " + targetOrigin.OriginIpOrDomain)
  47. return targetOrigin, nil
  48. }
  49. /* Features related to session access */
  50. //Set a new origin for this connection by session
  51. func (m *RouteManager) setSessionHandler(w http.ResponseWriter, r *http.Request, originIpOrDomain string, index int) error {
  52. session, err := m.SessionStore.Get(r, "STICKYSESSION")
  53. if err != nil {
  54. return err
  55. }
  56. session.Values["zr_sid_origin"] = originIpOrDomain
  57. session.Values["zr_sid_index"] = index
  58. session.Options.MaxAge = 86400 //1 day
  59. session.Options.Path = "/"
  60. err = session.Save(r, w)
  61. if err != nil {
  62. return err
  63. }
  64. return nil
  65. }
  66. // Get the previous connected origin from session
  67. func (m *RouteManager) getSessionHandler(r *http.Request, upstreams []*Upstream) (int, error) {
  68. // Get existing session
  69. session, err := m.SessionStore.Get(r, "STICKYSESSION")
  70. if err != nil {
  71. return -1, err
  72. }
  73. // Retrieve session values for origin
  74. originDomainRaw := session.Values["zr_sid_origin"]
  75. originIDRaw := session.Values["zr_sid_index"]
  76. if originDomainRaw == nil || originIDRaw == nil {
  77. return -1, errors.New("no session has been set")
  78. }
  79. originDomain := originDomainRaw.(string)
  80. originID := originIDRaw.(int)
  81. //Check if it has been modified
  82. if len(upstreams) < originID || upstreams[originID].OriginIpOrDomain != originDomain {
  83. //Mismatch or upstreams has been updated
  84. return -1, errors.New("upstreams has been changed")
  85. }
  86. return originID, nil
  87. }
  88. /* Functions related to random upstream picking */
  89. // Get a random upstream by the weights defined in Upstream struct, return the upstream, index value and any error
  90. func getRandomUpstreamByWeight(upstreams []*Upstream) (*Upstream, int, error) {
  91. var ret *Upstream
  92. sum := 0
  93. for _, c := range upstreams {
  94. sum += c.Weight
  95. }
  96. r, err := intRange(0, sum)
  97. if err != nil {
  98. return ret, -1, err
  99. }
  100. counter := 0
  101. for _, c := range upstreams {
  102. r -= c.Weight
  103. if r < 0 {
  104. return c, counter, nil
  105. }
  106. counter++
  107. }
  108. if ret == nil {
  109. //All fallback
  110. //use the first one that is with weight = 0
  111. fallbackUpstreams := []*Upstream{}
  112. fallbackUpstreamsOriginalID := []int{}
  113. for ix, upstream := range upstreams {
  114. if upstream.Weight == 0 {
  115. fallbackUpstreams = append(fallbackUpstreams, upstream)
  116. fallbackUpstreamsOriginalID = append(fallbackUpstreamsOriginalID, ix)
  117. }
  118. }
  119. upstreamID := rand.Intn(len(fallbackUpstreams))
  120. return fallbackUpstreams[upstreamID], fallbackUpstreamsOriginalID[upstreamID], nil
  121. }
  122. return ret, -1, errors.New("failed to pick an upstream origin server")
  123. }
  124. // IntRange returns a random integer in the range from min to max.
  125. func intRange(min, max int) (int, error) {
  126. var result int
  127. switch {
  128. case min > max:
  129. // Fail with error
  130. return result, errors.New("min is greater than max")
  131. case max == min:
  132. result = max
  133. case max > min:
  134. b := rand.Intn(max-min) + min
  135. result = min + int(b)
  136. }
  137. return result, nil
  138. }