ratelimit.go 2.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119
  1. package dynamicproxy
  2. import (
  3. "errors"
  4. "net"
  5. "net/http"
  6. "strings"
  7. "sync"
  8. "sync/atomic"
  9. "time"
  10. )
  // RequestCountPerIpTable counts requests per client IP for rate limiting.
// It is backed by a sync.Map whose values are *int64 counters that are only
// read and written through sync/atomic, so its methods may be called from
// multiple goroutines. The zero value is an empty, ready-to-use table.
type RequestCountPerIpTable struct {
	table sync.Map
}

// Increment adds one to the request count for ip, creating the counter on
// first use.
func (t *RequestCountPerIpTable) Increment(ip string) {
	// Fast path: after the first request from an IP the counter already
	// exists, so avoid allocating a throw-away *int64 for LoadOrStore on
	// every call. LoadOrStore still resolves the race when two goroutines
	// create the counter at the same time.
	v, ok := t.table.Load(ip)
	if !ok {
		v, _ = t.table.LoadOrStore(ip, new(int64))
	}
	atomic.AddInt64(v.(*int64), 1)
}

// Exceeded reports whether the recorded count for ip is greater than or
// equal to limit. An IP with no recorded requests is never reported as
// exceeded, whatever the limit.
func (t *RequestCountPerIpTable) Exceeded(ip string, limit int64) bool {
	v, ok := t.table.Load(ip)
	if !ok {
		return false
	}
	count := atomic.LoadInt64(v.(*int64))
	return count >= limit
}

// GetCount returns the recorded request count for ip, or 0 if ip has no
// counter in the table.
func (t *RequestCountPerIpTable) GetCount(ip string) int64 {
	v, ok := t.table.Load(ip)
	if !ok {
		return 0
	}
	return atomic.LoadInt64(v.(*int64))
}

// Clear removes every counter from the table. Because sync.Map.Range is not
// a consistent snapshot, counters created concurrently with Clear may or may
// not be removed.
func (t *RequestCountPerIpTable) Clear() {
	t.table.Range(func(key, value interface{}) bool {
		t.table.Delete(key)
		return true
	})
}
  44. func (h *ProxyHandler) handleRateLimitRouting(w http.ResponseWriter, r *http.Request, pe *ProxyEndpoint) error {
  45. err := h.Parent.handleRateLimit(w, r, pe)
  46. if err != nil {
  47. h.Parent.logRequest(r, false, 429, "ratelimit", r.URL.Hostname())
  48. }
  49. return err
  50. }
  51. func (router *Router) handleRateLimit(w http.ResponseWriter, r *http.Request, pe *ProxyEndpoint) error {
  52. //Get the real client-ip from request header
  53. clientIP := r.RemoteAddr
  54. if r.Header.Get("X-Real-Ip") == "" {
  55. CF_Connecting_IP := r.Header.Get("CF-Connecting-IP")
  56. Fastly_Client_IP := r.Header.Get("Fastly-Client-IP")
  57. if CF_Connecting_IP != "" {
  58. //Use CF Connecting IP
  59. clientIP = CF_Connecting_IP
  60. } else if Fastly_Client_IP != "" {
  61. //Use Fastly Client IP
  62. clientIP = Fastly_Client_IP
  63. } else {
  64. ips := strings.Split(clientIP, ",")
  65. if len(ips) > 0 {
  66. clientIP = strings.TrimSpace(ips[0])
  67. }
  68. }
  69. }
  70. ip, _, err := net.SplitHostPort(clientIP)
  71. if err != nil {
  72. //Default allow passthrough on error
  73. return nil
  74. }
  75. router.rateLimitCounter.Increment(ip)
  76. if router.rateLimitCounter.Exceeded(ip, int64(pe.RateLimit)) {
  77. w.WriteHeader(429)
  78. return errors.New("rate limit exceeded")
  79. }
  80. // log.Println("Rate limit check", ip, ipTable.GetCount(ip))
  81. return nil
  82. }
  83. // Start the ticker routine for reseting the rate limit counter every seconds
  84. func (r *Router) startRateLimterCounterResetTicker() error {
  85. if r.rateLimterStop != nil {
  86. return errors.New("another rate limiter ticker already running")
  87. }
  88. tickerStopChan := make(chan bool)
  89. r.rateLimterStop = tickerStopChan
  90. counterResetTicker := time.NewTicker(1 * time.Second)
  91. go func() {
  92. for {
  93. select {
  94. case <-tickerStopChan:
  95. r.rateLimterStop = nil
  96. return
  97. case <-counterResetTicker.C:
  98. r.rateLimitCounter.Clear()
  99. }
  100. }
  101. }()
  102. return nil
  103. }