package loadbalance

import (
	"fmt"
	"math"
	"math/rand"
	"testing"
	"time"
)

// func getRandomUpstreamByWeight(upstreams []*Upstream) (*Upstream, int, error) { ... }
func TestRandomUpstreamSelection(t *testing.T) {
	rand.Seed(time.Now().UnixNano()) // Seed for randomness

	// Define some test upstreams
	upstreams := []*Upstream{
		{
			OriginIpOrDomain:         "192.168.1.1:8080",
			RequireTLS:               false,
			SkipCertValidations:      false,
			SkipWebSocketOriginCheck: false,
			Weight:                   1,
			MaxConn:                  0, // No connection limit for now
		},
		{
			OriginIpOrDomain:         "192.168.1.2:8080",
			RequireTLS:               false,
			SkipCertValidations:      false,
			SkipWebSocketOriginCheck: false,
			Weight:                   1,
			MaxConn:                  0,
		},
		{
			OriginIpOrDomain:         "192.168.1.3:8080",
			RequireTLS:               true,
			SkipCertValidations:      true,
			SkipWebSocketOriginCheck: true,
			Weight:                   1,
			MaxConn:                  0,
		},
		{
			OriginIpOrDomain:         "192.168.1.4:8080",
			RequireTLS:               true,
			SkipCertValidations:      true,
			SkipWebSocketOriginCheck: true,
			Weight:                   1,
			MaxConn:                  0,
		},
	}

	// Track how many times each upstream is selected
	selectionCount := make(map[string]int)
	totalPicks := 10000 // Number of times to call getRandomUpstreamByWeight
	//expectedPickCount := totalPicks / len(upstreams) // Ideal count for each upstream

	// Pick upstreams and record their selection count
	for i := 0; i < totalPicks; i++ {
		upstream, _, err := getRandomUpstreamByWeight(upstreams)
		if err != nil {
			t.Fatalf("Error getting random upstream: %v", err)
		}
		selectionCount[upstream.OriginIpOrDomain]++
	}

	// Condition 1: Ensure every upstream has been picked at least once
	for _, upstream := range upstreams {
		if selectionCount[upstream.OriginIpOrDomain] == 0 {
			t.Errorf("Upstream %s was never selected", upstream.OriginIpOrDomain)
		}
	}

	// Condition 2: Check that the distribution is within 1-2 standard deviations
	counts := make([]float64, len(upstreams))
	for i, upstream := range upstreams {
		counts[i] = float64(selectionCount[upstream.OriginIpOrDomain])
	}

	mean := float64(totalPicks) / float64(len(upstreams))
	stddev := calculateStdDev(counts, mean)

	tolerance := 2 * stddev // Allowing up to 2 standard deviations
	for i, count := range counts {
		if math.Abs(count-mean) > tolerance {
			t.Errorf("Selection of upstream %s is outside acceptable range: %v picks (mean: %v, stddev: %v)", upstreams[i].OriginIpOrDomain, count, mean, stddev)
		}
	}

	fmt.Println("Selection count:", selectionCount)
	fmt.Printf("Mean: %.2f, StdDev: %.2f\n", mean, stddev)
}

// Helper function to calculate standard deviation
func calculateStdDev(data []float64, mean float64) float64 {
	var sumOfSquares float64
	for _, value := range data {
		sumOfSquares += (value - mean) * (value - mean)
	}
	variance := sumOfSquares / float64(len(data))
	return math.Sqrt(variance)
}