originPicker.go 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177
  1. package loadbalance
  2. import (
  3. "errors"
  4. "math/rand"
  5. "net/http"
  6. )
  7. /*
  8. Origin Picker
  9. This file contains the code that picks the best origin
  10. for each incoming request.
  11. */
  12. // GetRequestUpstreamTarget return the upstream target where this
  13. // request should be routed
  14. func (m *RouteManager) GetRequestUpstreamTarget(w http.ResponseWriter, r *http.Request, origins []*Upstream, useStickySession bool) (*Upstream, error) {
  15. if len(origins) == 0 {
  16. return nil, errors.New("no upstream is defined for this host")
  17. }
  18. var targetOrigin = origins[0]
  19. if useStickySession {
  20. //Use stick session, check which origins this request previously used
  21. targetOriginId, err := m.getSessionHandler(r, origins)
  22. if err != nil {
  23. //No valid session found. Assign a new upstream
  24. targetOrigin, index, err := getRandomUpstreamByWeight(origins)
  25. if err != nil {
  26. m.println("Unable to get random upstream", err)
  27. targetOrigin = origins[0]
  28. index = 0
  29. }
  30. m.setSessionHandler(w, r, targetOrigin.OriginIpOrDomain, index)
  31. return targetOrigin, nil
  32. }
  33. //Valid session found. Resume the previous session
  34. return origins[targetOriginId], nil
  35. } else {
  36. //Do not use stick session. Get a random one
  37. var err error
  38. targetOrigin, _, err = getRandomUpstreamByWeight(origins)
  39. if err != nil {
  40. m.println("Failed to get next origin", err)
  41. targetOrigin = origins[0]
  42. }
  43. }
  44. //fmt.Println("DEBUG: Picking origin " + targetOrigin.OriginIpOrDomain)
  45. return targetOrigin, nil
  46. }
  47. /* Features related to session access */
  48. //Set a new origin for this connection by session
  49. func (m *RouteManager) setSessionHandler(w http.ResponseWriter, r *http.Request, originIpOrDomain string, index int) error {
  50. session, err := m.SessionStore.Get(r, "STICKYSESSION")
  51. if err != nil {
  52. return err
  53. }
  54. session.Values["zr_sid_origin"] = originIpOrDomain
  55. session.Values["zr_sid_index"] = index
  56. session.Options.MaxAge = 86400 //1 day
  57. session.Options.Path = "/"
  58. err = session.Save(r, w)
  59. if err != nil {
  60. return err
  61. }
  62. return nil
  63. }
  64. // Get the previous connected origin from session
  65. func (m *RouteManager) getSessionHandler(r *http.Request, upstreams []*Upstream) (int, error) {
  66. // Get existing session
  67. session, err := m.SessionStore.Get(r, "STICKYSESSION")
  68. if err != nil {
  69. return -1, err
  70. }
  71. // Retrieve session values for origin
  72. originDomainRaw := session.Values["zr_sid_origin"]
  73. originIDRaw := session.Values["zr_sid_index"]
  74. if originDomainRaw == nil || originIDRaw == nil {
  75. return -1, errors.New("no session has been set")
  76. }
  77. originDomain := originDomainRaw.(string)
  78. originID := originIDRaw.(int)
  79. //Check if it has been modified
  80. if len(upstreams) < originID || upstreams[originID].OriginIpOrDomain != originDomain {
  81. //Mismatch or upstreams has been updated
  82. return -1, errors.New("upstreams has been changed")
  83. }
  84. return originID, nil
  85. }
  86. /* Functions related to random upstream picking */
  87. // Get a random upstream by the weights defined in Upstream struct, return the upstream, index value and any error
  88. func getRandomUpstreamByWeight(upstreams []*Upstream) (*Upstream, int, error) {
  89. // If there is only one upstream, return it
  90. if len(upstreams) == 1 {
  91. return upstreams[0], 0, nil
  92. }
  93. // Preserve the index with upstreams
  94. type upstreamWithIndex struct {
  95. Upstream *Upstream
  96. Index int
  97. }
  98. // Calculate total weight for upstreams with weight > 0
  99. totalWeight := 0
  100. fallbackUpstreams := make([]upstreamWithIndex, 0, len(upstreams))
  101. for index, upstream := range upstreams {
  102. if upstream.Weight > 0 {
  103. totalWeight += upstream.Weight
  104. } else {
  105. // Collect fallback upstreams
  106. fallbackUpstreams = append(fallbackUpstreams, upstreamWithIndex{upstream, index})
  107. }
  108. }
  109. // If there are no upstreams with weight > 0, return a fallback upstream if available
  110. if totalWeight == 0 {
  111. if len(fallbackUpstreams) > 0 {
  112. // Randomly select one of the fallback upstreams
  113. randIndex := rand.Intn(len(fallbackUpstreams))
  114. return fallbackUpstreams[randIndex].Upstream, fallbackUpstreams[randIndex].Index, nil
  115. }
  116. // No upstreams available at all
  117. return nil, -1, errors.New("no valid upstream servers available")
  118. }
  119. // Random weight between 0 and total weight
  120. randomWeight := rand.Intn(totalWeight)
  121. // Select an upstream based on the random weight
  122. for index, upstream := range upstreams {
  123. if upstream.Weight > 0 { // Only consider upstreams with weight > 0
  124. if randomWeight < upstream.Weight {
  125. // Return the selected upstream and its index
  126. return upstream, index, nil
  127. }
  128. randomWeight -= upstream.Weight
  129. }
  130. }
  131. // If we reach here, it means we should return a fallback upstream if available
  132. if len(fallbackUpstreams) > 0 {
  133. randIndex := rand.Intn(len(fallbackUpstreams))
  134. return fallbackUpstreams[randIndex].Upstream, fallbackUpstreams[randIndex].Index, nil
  135. }
  136. return nil, -1, errors.New("failed to pick an upstream origin server")
  137. }
  138. // IntRange returns a random integer in the range from min to max.
  139. /*
  140. func intRange(min, max int) (int, error) {
  141. var result int
  142. switch {
  143. case min > max:
  144. // Fail with error
  145. return result, errors.New("min is greater than max")
  146. case max == min:
  147. result = max
  148. case max > min:
  149. b := rand.Intn(max-min) + min
  150. result = min + int(b)
  151. }
  152. return result, nil
  153. }
  154. */