|
@@ -0,0 +1,100 @@
|
|
|
+package loadbalance
|
|
|
+
|
|
|
+import (
|
|
|
+ "fmt"
|
|
|
+ "math"
|
|
|
+ "math/rand"
|
|
|
+ "testing"
|
|
|
+ "time"
|
|
|
+)
|
|
|
+
|
|
|
+// func getRandomUpstreamByWeight(upstreams []*Upstream) (*Upstream, int, error) { ... }
|
|
|
+func TestRandomUpstreamSelection(t *testing.T) {
|
|
|
+ rand.Seed(time.Now().UnixNano()) // Seed for randomness
|
|
|
+
|
|
|
+ // Define some test upstreams
|
|
|
+ upstreams := []*Upstream{
|
|
|
+ {
|
|
|
+ OriginIpOrDomain: "192.168.1.1:8080",
|
|
|
+ RequireTLS: false,
|
|
|
+ SkipCertValidations: false,
|
|
|
+ SkipWebSocketOriginCheck: false,
|
|
|
+ Weight: 1,
|
|
|
+ MaxConn: 0, // No connection limit for now
|
|
|
+ },
|
|
|
+ {
|
|
|
+ OriginIpOrDomain: "192.168.1.2:8080",
|
|
|
+ RequireTLS: false,
|
|
|
+ SkipCertValidations: false,
|
|
|
+ SkipWebSocketOriginCheck: false,
|
|
|
+ Weight: 1,
|
|
|
+ MaxConn: 0,
|
|
|
+ },
|
|
|
+ {
|
|
|
+ OriginIpOrDomain: "192.168.1.3:8080",
|
|
|
+ RequireTLS: true,
|
|
|
+ SkipCertValidations: true,
|
|
|
+ SkipWebSocketOriginCheck: true,
|
|
|
+ Weight: 1,
|
|
|
+ MaxConn: 0,
|
|
|
+ },
|
|
|
+ {
|
|
|
+ OriginIpOrDomain: "192.168.1.4:8080",
|
|
|
+ RequireTLS: true,
|
|
|
+ SkipCertValidations: true,
|
|
|
+ SkipWebSocketOriginCheck: true,
|
|
|
+ Weight: 1,
|
|
|
+ MaxConn: 0,
|
|
|
+ },
|
|
|
+ }
|
|
|
+
|
|
|
+ // Track how many times each upstream is selected
|
|
|
+ selectionCount := make(map[string]int)
|
|
|
+ totalPicks := 10000 // Number of times to call getRandomUpstreamByWeight
|
|
|
+ //expectedPickCount := totalPicks / len(upstreams) // Ideal count for each upstream
|
|
|
+
|
|
|
+ // Pick upstreams and record their selection count
|
|
|
+ for i := 0; i < totalPicks; i++ {
|
|
|
+ upstream, _, err := getRandomUpstreamByWeight(upstreams)
|
|
|
+ if err != nil {
|
|
|
+ t.Fatalf("Error getting random upstream: %v", err)
|
|
|
+ }
|
|
|
+ selectionCount[upstream.OriginIpOrDomain]++
|
|
|
+ }
|
|
|
+
|
|
|
+ // Condition 1: Ensure every upstream has been picked at least once
|
|
|
+ for _, upstream := range upstreams {
|
|
|
+ if selectionCount[upstream.OriginIpOrDomain] == 0 {
|
|
|
+ t.Errorf("Upstream %s was never selected", upstream.OriginIpOrDomain)
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ // Condition 2: Check that the distribution is within 1-2 standard deviations
|
|
|
+ counts := make([]float64, len(upstreams))
|
|
|
+ for i, upstream := range upstreams {
|
|
|
+ counts[i] = float64(selectionCount[upstream.OriginIpOrDomain])
|
|
|
+ }
|
|
|
+
|
|
|
+ mean := float64(totalPicks) / float64(len(upstreams))
|
|
|
+ stddev := calculateStdDev(counts, mean)
|
|
|
+
|
|
|
+ tolerance := 2 * stddev // Allowing up to 2 standard deviations
|
|
|
+ for i, count := range counts {
|
|
|
+ if math.Abs(count-mean) > tolerance {
|
|
|
+ t.Errorf("Selection of upstream %s is outside acceptable range: %v picks (mean: %v, stddev: %v)", upstreams[i].OriginIpOrDomain, count, mean, stddev)
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ fmt.Println("Selection count:", selectionCount)
|
|
|
+ fmt.Printf("Mean: %.2f, StdDev: %.2f\n", mean, stddev)
|
|
|
+}
|
|
|
+
|
|
|
+// Helper function to calculate standard deviation
|
|
|
+func calculateStdDev(data []float64, mean float64) float64 {
|
|
|
+ var sumOfSquares float64
|
|
|
+ for _, value := range data {
|
|
|
+ sumOfSquares += (value - mean) * (value - mean)
|
|
|
+ }
|
|
|
+ variance := sumOfSquares / float64(len(data))
|
|
|
+ return math.Sqrt(variance)
|
|
|
+}
|