浏览代码

Finished rate limit module

Toby Chui 10 月之前
父节点
当前提交
eab071d5a8

+ 1 - 1
mod/dynamicproxy/Server.go

@@ -72,7 +72,7 @@ func (h *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
 			return
 		}
 
-		// Rate Limit Check
+		// Rate Limit
 		if sep.RequireRateLimit {
 			err := h.handleRateLimitRouting(w, r, sep)
 			if err != nil {

+ 0 - 1
mod/dynamicproxy/dpcore/header.go

@@ -91,7 +91,6 @@ func addXForwardedForHeader(req *http.Request) {
 					req.Header.Set("X-Real-Ip", strings.TrimSpace(ips[0]))
 				}
 			}
-
 		}
 
 	}

+ 30 - 11
mod/dynamicproxy/dynamicproxy.go

@@ -23,12 +23,12 @@ import (
 func NewDynamicProxy(option RouterOption) (*Router, error) {
 	proxyMap := sync.Map{}
 	thisRouter := Router{
-		Option:         &option,
-		ProxyEndpoints: &proxyMap,
-		Running:        false,
-		server:         nil,
-		routingRules:   []*RoutingRule{},
-		tldMap:         map[string]int{},
+		Option:           &option,
+		ProxyEndpoints:   &proxyMap,
+		Running:          false,
+		server:           nil,
+		routingRules:     []*RoutingRule{},
+		rateLimitCounter: RequestCountPerIpTable{},
 	}
 
 	thisRouter.mux = &ProxyHandler{
@@ -85,6 +85,12 @@ func (router *Router) StartProxyService() error {
 		MinVersion:     uint16(minVersion),
 	}
 
+	//Start rate limitor
+	err := router.startRateLimterCounterResetTicker()
+	if err != nil {
+		return err
+	}
+
 	if router.Option.UseTls {
 		router.server = &http.Server{
 			Addr:      ":" + strconv.Itoa(router.Option.Port),
@@ -129,12 +135,12 @@ func (router *Router) StartProxyService() error {
 							}
 						}
 
-						// Rate Limit Check
-						// if sep.RequireBasicAuth {
-						if err := handleRateLimit(w, r, sep); err != nil {
-							return
+						// Rate Limit
+						if sep.RequireRateLimit {
+							if err := router.handleRateLimit(w, r, sep); err != nil {
+								return
+							}
 						}
-						// }
 
 						//Validate basic auth
 						if sep.RequireBasicAuth {
@@ -239,10 +245,23 @@ func (router *Router) StopProxyService() error {
 		return err
 	}
 
+	//Stop TLS listener
 	if router.tlsListener != nil {
 		router.tlsListener.Close()
 	}
 
+	//Stop rate limiter
+	if router.rateLimterStop != nil {
+		go func() {
+			// As the rate timer loop has a 1 sec ticker
+			// stop the rate limiter in go routine can prevent
+			// front end from freezing for 1 sec
+			router.rateLimterStop <- true
+		}()
+
+	}
+
+	//Stop TLS redirection (from port 80)
 	if router.tlsRedirectStop != nil {
 		router.tlsRedirectStop <- true
 	}

+ 53 - 20
mod/dynamicproxy/ratelimit.go

@@ -2,27 +2,27 @@ package dynamicproxy
 
 import (
 	"errors"
-	"log"
 	"net"
 	"net/http"
+	"strings"
 	"sync"
 	"sync/atomic"
 	"time"
 )
 
 // IpTable is a rate limiter implementation using sync.Map with atomic int64
-type IpTable struct {
+type RequestCountPerIpTable struct {
 	table sync.Map
 }
 
 // Increment the count of requests for a given IP
-func (t *IpTable) Increment(ip string) {
+func (t *RequestCountPerIpTable) Increment(ip string) {
 	v, _ := t.table.LoadOrStore(ip, new(int64))
 	atomic.AddInt64(v.(*int64), 1)
 }
 
 // Check if the IP is in the table and if it is, check if the count is less than the limit
-func (t *IpTable) Exceeded(ip string, limit int64) bool {
+func (t *RequestCountPerIpTable) Exceeded(ip string, limit int64) bool {
 	v, ok := t.table.Load(ip)
 	if !ok {
 		return false
@@ -32,7 +32,7 @@ func (t *IpTable) Exceeded(ip string, limit int64) bool {
 }
 
 // Get the count of requests for a given IP
-func (t *IpTable) GetCount(ip string) int64 {
+func (t *RequestCountPerIpTable) GetCount(ip string) int64 {
 	v, ok := t.table.Load(ip)
 	if !ok {
 		return 0
@@ -41,34 +41,50 @@ func (t *IpTable) GetCount(ip string) int64 {
 }
 
 // Clear the IP table
-func (t *IpTable) Clear() {
+func (t *RequestCountPerIpTable) Clear() {
 	t.table.Range(func(key, value interface{}) bool {
 		t.table.Delete(key)
 		return true
 	})
 }
 
-var ipTable = IpTable{}
-
 func (h *ProxyHandler) handleRateLimitRouting(w http.ResponseWriter, r *http.Request, pe *ProxyEndpoint) error {
-	err := handleRateLimit(w, r, pe)
+	err := h.Parent.handleRateLimit(w, r, pe)
 	if err != nil {
 		h.logRequest(r, false, 429, "ratelimit", pe.Domain)
 	}
 	return err
 }
 
-func handleRateLimit(w http.ResponseWriter, r *http.Request, pe *ProxyEndpoint) error {
-	ip, _, err := net.SplitHostPort(r.RemoteAddr)
+func (router *Router) handleRateLimit(w http.ResponseWriter, r *http.Request, pe *ProxyEndpoint) error {
+	//Get the real client-ip from request header
+	clientIP := r.RemoteAddr
+	if r.Header.Get("X-Real-Ip") == "" {
+		CF_Connecting_IP := r.Header.Get("CF-Connecting-IP")
+		Fastly_Client_IP := r.Header.Get("Fastly-Client-IP")
+		if CF_Connecting_IP != "" {
+			//Use CF Connecting IP
+			clientIP = CF_Connecting_IP
+		} else if Fastly_Client_IP != "" {
+			//Use Fastly Client IP
+			clientIP = Fastly_Client_IP
+		} else {
+			ips := strings.Split(clientIP, ",")
+			if len(ips) > 0 {
+				clientIP = strings.TrimSpace(ips[0])
+			}
+		}
+	}
+
+	ip, _, err := net.SplitHostPort(clientIP)
 	if err != nil {
-		w.WriteHeader(500)
-		log.Println("Error resolving remote address", r.RemoteAddr, err)
-		return errors.New("internal server error")
+		//Default allow passthrough on error
+		return nil
 	}
 
-	ipTable.Increment(ip)
+	router.rateLimitCounter.Increment(ip)
 
-	if ipTable.Exceeded(ip, int64(pe.RateLimit)) {
+	if router.rateLimitCounter.Exceeded(ip, int64(pe.RateLimit)) {
 		w.WriteHeader(429)
 		return errors.New("rate limit exceeded")
 	}
@@ -78,9 +94,26 @@ func handleRateLimit(w http.ResponseWriter, r *http.Request, pe *ProxyEndpoint)
 	return nil
 }
 
-func InitRateLimit() {
-	for {
-		ipTable.Clear()
-		time.Sleep(time.Second)
+// Start the ticker routine for reseting the rate limit counter every seconds
+func (r *Router) startRateLimterCounterResetTicker() error {
+	if r.rateLimterStop != nil {
+		return errors.New("another rate limiter ticker already running")
 	}
+	tickerStopChan := make(chan bool)
+	r.rateLimterStop = tickerStopChan
+
+	counterResetTicker := time.NewTicker(1 * time.Second)
+	go func() {
+		for {
+			select {
+			case <-tickerStopChan:
+				r.rateLimterStop = nil
+				return
+			case <-counterResetTicker.C:
+				r.rateLimitCounter.Clear()
+			}
+		}
+	}()
+
+	return nil
 }

+ 3 - 2
mod/dynamicproxy/typedef.go

@@ -51,8 +51,9 @@ type Router struct {
 	tlsListener    net.Listener
 	routingRules   []*RoutingRule
 
-	tlsRedirectStop chan bool      //Stop channel for tls redirection server
-	tldMap          map[string]int //Top level domain map, see tld.json
+	tlsRedirectStop  chan bool              //Stop channel for tls redirection server
+	rateLimterStop   chan bool              //Stop channel for rate limiter
+	rateLimitCounter RequestCountPerIpTable //Request counter for rate limter
 }
 
 // Auth credential for basic auth on certain endpoints

+ 0 - 5
reverseproxy.go

@@ -145,11 +145,6 @@ func ReverseProxtInit() {
 		})
 		SystemWideLogger.Println("Uptime Monitor background service started")
 	}()
-
-	// Init Rate Limit
-	go func() {
-		dynamicproxy.InitRateLimit()
-	}()
 }
 
 func ReverseProxyHandleOnOff(w http.ResponseWriter, r *http.Request) {