ratelimit.go 1.9 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586
  1. package dynamicproxy
  2. import (
  3. "errors"
  4. "log"
  5. "net"
  6. "net/http"
  7. "sync"
  8. "sync/atomic"
  9. "time"
  10. )
  // IpTable counts requests per client IP for a fixed-window rate limiter.
// Keys are IP strings and values are *int64 counters that are only touched
// through sync/atomic, so all methods may be called concurrently. The zero
// value is an empty table ready for use.
type IpTable struct {
	table sync.Map
}

// Increment adds one to the request count for ip, creating the counter on
// first use.
func (t *IpTable) Increment(ip string) {
	// Fast path: an existing counter needs no allocation. Calling LoadOrStore
	// unconditionally would allocate a fresh int64 on every request even when
	// the key is already present.
	v, ok := t.table.Load(ip)
	if !ok {
		v, _ = t.table.LoadOrStore(ip, new(int64))
	}
	atomic.AddInt64(v.(*int64), 1)
}

// Exceeded reports whether ip has made more than limit requests in the
// current window. An IP with no recorded requests never exceeds the limit.
//
// The comparison is strict because callers increment before checking: with
// ">" exactly limit requests pass per window, whereas ">=" would already
// reject the limit-th request (and a limit of 1 would reject everything).
func (t *IpTable) Exceeded(ip string, limit int64) bool {
	return t.GetCount(ip) > limit
}

// GetCount returns the number of requests recorded for ip in the current
// window, or 0 if ip has no counter.
func (t *IpTable) GetCount(ip string) int64 {
	v, ok := t.table.Load(ip)
	if !ok {
		return 0
	}
	return atomic.LoadInt64(v.(*int64))
}

// Clear removes every counter, starting a new rate-limit window.
//
// sync.Map permits Delete during Range. The sweep is not atomic with respect
// to concurrent Increment calls: an increment that loaded a counter just
// before it was deleted lands on the discarded counter and is not carried
// into the new window.
func (t *IpTable) Clear() {
	t.table.Range(func(key, _ interface{}) bool {
		t.table.Delete(key)
		return true
	})
}
  44. var ipTable = IpTable{}
  45. func (h *ProxyHandler) handleRateLimitRouting(w http.ResponseWriter, r *http.Request, pe *ProxyEndpoint) error {
  46. err := handleRateLimit(w, r, pe)
  47. if err != nil {
  48. h.logRequest(r, false, 429, "ratelimit", pe.Domain)
  49. }
  50. return err
  51. }
  52. func handleRateLimit(w http.ResponseWriter, r *http.Request, pe *ProxyEndpoint) error {
  53. ip, _, err := net.SplitHostPort(r.RemoteAddr)
  54. if err != nil {
  55. w.WriteHeader(500)
  56. log.Println("Error resolving remote address", r.RemoteAddr, err)
  57. return errors.New("internal server error")
  58. }
  59. ipTable.Increment(ip)
  60. if ipTable.Exceeded(ip, int64(pe.RateLimit)) {
  61. w.WriteHeader(429)
  62. return errors.New("rate limit exceeded")
  63. }
  64. // log.Println("Rate limit check", ip, ipTable.GetCount(ip))
  65. return nil
  66. }
  67. func InitRateLimit() {
  68. for {
  69. ipTable.Clear()
  70. time.Sleep(time.Second)
  71. }
  72. }