originPicker.go 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196
  1. package loadbalance
  2. import (
  3. "errors"
  4. "math/rand"
  5. "net/http"
  6. )
/*
Origin Picker
This file contains the code that picks the best origin
for an incoming request.
*/
  12. // GetRequestUpstreamTarget return the upstream target where this
  13. // request should be routed
  14. func (m *RouteManager) GetRequestUpstreamTarget(w http.ResponseWriter, r *http.Request, origins []*Upstream, useStickySession bool) (*Upstream, error) {
  15. if len(origins) == 0 {
  16. return nil, errors.New("no upstream is defined for this host")
  17. }
  18. //Pick the origin
  19. if useStickySession {
  20. //Use stick session, check which origins this request previously used
  21. targetOriginId, err := m.getSessionHandler(r, origins)
  22. if err != nil || !m.IsTargetOnline(origins[targetOriginId].OriginIpOrDomain) {
  23. // No valid session found or origin is offline
  24. // Filter the offline origins
  25. origins = m.FilterOfflineOrigins(origins)
  26. if len(origins) == 0 {
  27. return nil, errors.New("no online upstream is available for origin: " + r.Host)
  28. }
  29. //Get a random origin
  30. targetOrigin, index, err := getRandomUpstreamByWeight(origins)
  31. if err != nil {
  32. m.println("Unable to get random upstream", err)
  33. targetOrigin = origins[0]
  34. index = 0
  35. }
  36. m.setSessionHandler(w, r, targetOrigin.OriginIpOrDomain, index)
  37. return targetOrigin, nil
  38. }
  39. //Valid session found and origin is online
  40. return origins[targetOriginId], nil
  41. }
  42. //No sticky session, get a random origin
  43. //Filter the offline origins
  44. origins = m.FilterOfflineOrigins(origins)
  45. if len(origins) == 0 {
  46. return nil, errors.New("no online upstream is available for origin: " + r.Host)
  47. }
  48. //Get a random origin
  49. targetOrigin, _, err := getRandomUpstreamByWeight(origins)
  50. if err != nil {
  51. m.println("Failed to get next origin", err)
  52. targetOrigin = origins[0]
  53. }
  54. //fmt.Println("DEBUG: Picking origin " + targetOrigin.OriginIpOrDomain)
  55. return targetOrigin, nil
  56. }
  57. // GetUsableUpstreamCounts return the number of usable upstreams
  58. func (m *RouteManager) GetUsableUpstreamCounts(origins []*Upstream) int {
  59. origins = m.FilterOfflineOrigins(origins)
  60. return len(origins)
  61. }
  62. /* Features related to session access */
  63. //Set a new origin for this connection by session
  64. func (m *RouteManager) setSessionHandler(w http.ResponseWriter, r *http.Request, originIpOrDomain string, index int) error {
  65. session, err := m.SessionStore.Get(r, "STICKYSESSION")
  66. if err != nil {
  67. return err
  68. }
  69. session.Values["zr_sid_origin"] = originIpOrDomain
  70. session.Values["zr_sid_index"] = index
  71. session.Options.MaxAge = 86400 //1 day
  72. session.Options.Path = "/"
  73. err = session.Save(r, w)
  74. if err != nil {
  75. return err
  76. }
  77. return nil
  78. }
  79. // Get the previous connected origin from session
  80. func (m *RouteManager) getSessionHandler(r *http.Request, upstreams []*Upstream) (int, error) {
  81. // Get existing session
  82. session, err := m.SessionStore.Get(r, "STICKYSESSION")
  83. if err != nil {
  84. return -1, err
  85. }
  86. // Retrieve session values for origin
  87. originDomainRaw := session.Values["zr_sid_origin"]
  88. originIDRaw := session.Values["zr_sid_index"]
  89. if originDomainRaw == nil || originIDRaw == nil {
  90. return -1, errors.New("no session has been set")
  91. }
  92. originDomain := originDomainRaw.(string)
  93. originID := originIDRaw.(int)
  94. //Check if it has been modified
  95. if len(upstreams) < originID || upstreams[originID].OriginIpOrDomain != originDomain {
  96. //Mismatch or upstreams has been updated
  97. return -1, errors.New("upstreams has been changed")
  98. }
  99. return originID, nil
  100. }
  101. /* Functions related to random upstream picking */
  102. // Get a random upstream by the weights defined in Upstream struct, return the upstream, index value and any error
  103. func getRandomUpstreamByWeight(upstreams []*Upstream) (*Upstream, int, error) {
  104. // If there is only one upstream, return it
  105. if len(upstreams) == 1 {
  106. return upstreams[0], 0, nil
  107. }
  108. // Preserve the index with upstreams
  109. type upstreamWithIndex struct {
  110. Upstream *Upstream
  111. Index int
  112. }
  113. // Calculate total weight for upstreams with weight > 0
  114. totalWeight := 0
  115. fallbackUpstreams := make([]upstreamWithIndex, 0, len(upstreams))
  116. for index, upstream := range upstreams {
  117. if upstream.Weight > 0 {
  118. totalWeight += upstream.Weight
  119. } else {
  120. // Collect fallback upstreams
  121. fallbackUpstreams = append(fallbackUpstreams, upstreamWithIndex{upstream, index})
  122. }
  123. }
  124. // If there are no upstreams with weight > 0, return a fallback upstream if available
  125. if totalWeight == 0 {
  126. if len(fallbackUpstreams) > 0 {
  127. // Randomly select one of the fallback upstreams
  128. randIndex := rand.Intn(len(fallbackUpstreams))
  129. return fallbackUpstreams[randIndex].Upstream, fallbackUpstreams[randIndex].Index, nil
  130. }
  131. // No upstreams available at all
  132. return nil, -1, errors.New("no valid upstream servers available")
  133. }
  134. // Random weight between 0 and total weight
  135. randomWeight := rand.Intn(totalWeight)
  136. // Select an upstream based on the random weight
  137. for index, upstream := range upstreams {
  138. if upstream.Weight > 0 { // Only consider upstreams with weight > 0
  139. if randomWeight < upstream.Weight {
  140. // Return the selected upstream and its index
  141. return upstream, index, nil
  142. }
  143. randomWeight -= upstream.Weight
  144. }
  145. }
  146. // If we reach here, it means we should return a fallback upstream if available
  147. if len(fallbackUpstreams) > 0 {
  148. randIndex := rand.Intn(len(fallbackUpstreams))
  149. return fallbackUpstreams[randIndex].Upstream, fallbackUpstreams[randIndex].Index, nil
  150. }
  151. return nil, -1, errors.New("failed to pick an upstream origin server")
  152. }
  153. // IntRange returns a random integer in the range from min to max.
  154. /*
  155. func intRange(min, max int) (int, error) {
  156. var result int
  157. switch {
  158. case min > max:
  159. // Fail with error
  160. return result, errors.New("min is greater than max")
  161. case max == min:
  162. result = max
  163. case max > min:
  164. b := rand.Intn(max-min) + min
  165. result = min + int(b)
  166. }
  167. return result, nil
  168. }
  169. */