package dynamicproxy

import (
	"errors"
	"net"
	"net/http"
	"strings"
	"sync"
	"sync/atomic"
	"time"
)

// RequestCountPerIpTable counts requests per client IP. It is backed by a
// sync.Map whose values are *int64 counters updated with atomic operations.
type RequestCountPerIpTable struct {
	table sync.Map
}

// Increment adds one to the request count recorded for ip, creating the
// counter on first use.
func (t *RequestCountPerIpTable) Increment(ip string) {
	// Look the counter up first so that a fresh int64 is only allocated
	// when the IP has not been seen since the last Clear.
	v, ok := t.table.Load(ip)
	if !ok {
		v, _ = t.table.LoadOrStore(ip, new(int64))
	}
	atomic.AddInt64(v.(*int64), 1)
}

// Exceeded reports whether the count recorded for ip is greater than or
// equal to limit. An IP with no recorded requests is never exceeded.
func (t *RequestCountPerIpTable) Exceeded(ip string, limit int64) bool {
	v, ok := t.table.Load(ip)
	if !ok {
		return false
	}
	count := atomic.LoadInt64(v.(*int64))
	return count >= limit
}

// GetCount returns the request count recorded for ip, or 0 when the IP is
// not in the table.
func (t *RequestCountPerIpTable) GetCount(ip string) int64 {
	v, ok := t.table.Load(ip)
	if !ok {
		return 0
	}
	return atomic.LoadInt64(v.(*int64))
}

// Clear removes every entry from the table.
func (t *RequestCountPerIpTable) Clear() {
	t.table.Range(func(key, value interface{}) bool {
		t.table.Delete(key)
		return true
	})
}

// handleRateLimitRouting runs the parent router's rate limit check for the
// request and, when the check returns an error, records the rejection through
// the parent's logRequest before returning that error to the caller.
func (h *ProxyHandler) handleRateLimitRouting(w http.ResponseWriter, r *http.Request, pe *ProxyEndpoint) error {
	err := h.Parent.handleRateLimit(w, r, pe)
	if err != nil {
		h.Parent.logRequest(r, false, 429, "ratelimit", r.URL.Hostname())
	}
	return err
}

// handleRateLimit counts the request against its client IP and rejects it
// with HTTP 429 once the count for that IP has reached pe.RateLimit.
//
// The client address is r.RemoteAddr unless the request has no X-Real-Ip
// header and carries CF-Connecting-IP or Fastly-Client-IP, in which case that
// header value is used instead (CF-Connecting-IP takes precedence).
//
// It returns nil when the request may proceed, including when no usable IP
// can be extracted from the client address. It returns a non-nil error after
// writing the 429 status when the limit has been reached. Because the counter
// is incremented before it is compared with the limit, the request that
// brings the count up to pe.RateLimit is itself rejected.
func (router *Router) handleRateLimit(w http.ResponseWriter, r *http.Request, pe *ProxyEndpoint) error {
	// Get the real client IP from the request headers
	clientIP := r.RemoteAddr
	if r.Header.Get("X-Real-Ip") == "" {
		cfConnectingIP := r.Header.Get("CF-Connecting-IP")
		fastlyClientIP := r.Header.Get("Fastly-Client-IP")
		switch {
		case cfConnectingIP != "":
			// Use CF Connecting IP
			clientIP = cfConnectingIP
		case fastlyClientIP != "":
			// Use Fastly Client IP
			clientIP = fastlyClientIP
		default:
			ips := strings.Split(clientIP, ",")
			if len(ips) > 0 {
				clientIP = strings.TrimSpace(ips[0])
			}
		}
	}

	ip, _, err := net.SplitHostPort(clientIP)
	if err != nil {
		// The CDN headers carry a bare address without a port, which
		// SplitHostPort rejects. Accept such a value when it parses as an IP
		// so these requests are still counted instead of bypassing the limit.
		bare := strings.Trim(strings.TrimSpace(clientIP), "[]")
		parsed := net.ParseIP(bare)
		if parsed == nil {
			// Default allow passthrough when no IP can be determined
			return nil
		}
		ip = parsed.String()
	}

	router.rateLimitCounter.Increment(ip)
	if router.rateLimitCounter.Exceeded(ip, int64(pe.RateLimit)) {
		w.WriteHeader(http.StatusTooManyRequests)
		return errors.New("rate limit exceeded")
	}

	// log.Println("Rate limit check", ip, ipTable.GetCount(ip))
	return nil
}

// startRateLimterCounterResetTicker starts a goroutine that clears the rate
// limit counter once per second. It returns an error if r.rateLimterStop is
// already set, which indicates a ticker routine is running. The goroutine
// exits, stops its ticker and resets r.rateLimterStop to nil when a receive
// on the stop channel stored in r.rateLimterStop succeeds.
func (r *Router) startRateLimterCounterResetTicker() error {
	if r.rateLimterStop != nil {
		return errors.New("another rate limiter ticker already running")
	}

	tickerStopChan := make(chan bool)
	r.rateLimterStop = tickerStopChan

	counterResetTicker := time.NewTicker(1 * time.Second)

	go func() {
		// Release the ticker's resources once the routine is asked to stop.
		defer counterResetTicker.Stop()
		for {
			select {
			case <-tickerStopChan:
				r.rateLimterStop = nil
				return
			case <-counterResetTicker.C:
				r.rateLimitCounter.Clear()
			}
		}
	}()

	return nil
}