Browse Source

Fixed trie tree implementation

Toby Chui 7 months ago
parent
commit
fb34367618
5 changed files with 19 additions and 50 deletions
  1. 4 3
      mod/geodb/geodb.go
  2. 2 1
      mod/geodb/geodb_test.go
  3. 4 6
      mod/geodb/geoloader.go
  4. 8 39
      mod/geodb/trie.go
  5. 1 1
      web/script/utils.js

+ 4 - 3
mod/geodb/geodb.go

@@ -3,6 +3,7 @@ package geodb
 import (
 	_ "embed"
 	"net/http"
+	"sync"
 
 	"imuslab.com/zoraxy/mod/database"
 	"imuslab.com/zoraxy/mod/netutils"
@@ -19,9 +20,9 @@ type Store struct {
 	geodbIpv6   [][]string //Parsed geodb list for ipv6
 	geotrie     *trie
 	geotrieIpv6 *trie
-	//geoipCache sync.Map
-	sysdb  *database.Database
-	option *StoreOptions
+	geoipCache  sync.Map
+	sysdb       *database.Database
+	option      *StoreOptions
 }
 
 type StoreOptions struct {

+ 2 - 1
mod/geodb/geodb_test.go

@@ -43,7 +43,7 @@ func TestResolveCountryCodeFromIP(t *testing.T) {
 	// Create a new store
 	store, err := geodb.NewGeoDb(nil, &geodb.StoreOptions{
 		false,
-		false,
+		true,
 	})
 	if err != nil {
 		t.Errorf("error creating store: %v", err)
@@ -56,6 +56,7 @@ func TestResolveCountryCodeFromIP(t *testing.T) {
 		{"176.113.115.113", "RU"},
 		{"65.21.233.213", "FI"},
 		{"94.23.207.193", "FR"},
+		{"77.131.21.232", "FR"},
 	}
 
 	for _, testcase := range knownIpCountryMap {

+ 4 - 6
mod/geodb/geoloader.go

@@ -17,12 +17,10 @@ func (s *Store) search(ip string) string {
 		ip = strings.TrimSpace(ip)
 	}
 	//See if there are cached country code for this ip
-	/*
-		ccc, ok := s.geoipCache.Load(ip)
-		if ok {
-			return ccc.(string)
-		}
-	*/
+	ccc, ok := s.geoipCache.Load(ip)
+	if ok {
+		return ccc.(string)
+	}
 
 	//Search in geotrie tree
 	cc := ""

+ 8 - 39
mod/geodb/trie.go

@@ -1,7 +1,6 @@
 package geodb
 
 import (
-	"math"
 	"net"
 )
 
@@ -41,14 +40,10 @@ func (t *trie) insert(ipAddr string, cc string) {
 	ipBytes := ipToBytes(ipAddr)
 	current := t.root
 	for _, b := range ipBytes {
-		//For each byte in the ip address
+		//For each byte in the ip address (4 / 16 bytes)
 		//each byte is 8 bit
-		for j := 0; j < 8; j++ {
-			bitwise := (b&uint8(math.Pow(float64(2), float64(j))) > 0)
-			bit := 0b0000
-			if bitwise {
-				bit = 0b0001
-			}
+		for j := 7; j >= 0; j-- {
+			bit := int(b >> j & 1)
 			if current.childrens[bit] == nil {
 				current.childrens[bit] = &trie_Node{
 					childrens: [2]*trie_Node{},
@@ -58,21 +53,9 @@ func (t *trie) insert(ipAddr string, cc string) {
 			current = current.childrens[bit]
 		}
 	}
-
-	/*
-		for i := 63; i >= 0; i-- {
-			bit := (ipInt64 >> uint(i)) & 1
-			if current.childrens[bit] == nil {
-				current.childrens[bit] = &trie_Node{
-					childrens: [2]*trie_Node{},
-					cc:        cc,
-				}
-			}
-			current = current.childrens[bit]
-		}
-	*/
 }
 
+// isReservedIP check if the given ip address is NOT a public ip address
 func isReservedIP(ip string) bool {
 	parsedIP := net.ParseIP(ip)
 	if parsedIP == nil {
@@ -86,12 +69,10 @@ func isReservedIP(ip string) bool {
 	if parsedIP.IsLinkLocalUnicast() || parsedIP.IsLinkLocalMulticast() {
 		return true
 	}
-
+	//Check if the IP is in the reserved private range
 	if parsedIP.IsPrivate() {
 		return true
 	}
-
-	// If the IP address is not a reserved address, return false
 	return false
 }
 
@@ -106,27 +87,15 @@ func (t *trie) search(ipAddr string) string {
 	for _, b := range ipBytes {
 		//For each byte in the ip address
 		//each byte is 8 bit
-		for j := 0; j < 8; j++ {
-			bitwise := (b&uint8(math.Pow(float64(2), float64(j))) > 0)
-			bit := 0b0000
-			if bitwise {
-				bit = 0b0001
-			}
+		for j := 7; j >= 0; j-- {
+			bit := int(b >> j & 1)
 			if current.childrens[bit] == nil {
 				return current.cc
 			}
 			current = current.childrens[bit]
 		}
 	}
-	/*
-		for i := 63; i >= 0; i-- {
-			bit := (ipInt64 >> uint(i)) & 1
-			if current.childrens[bit] == nil {
-				return current.cc
-			}
-			current = current.childrens[bit]
-		}
-	*/
+
 	if len(current.childrens) == 0 {
 		return current.cc
 	}

+ 1 - 1
web/script/utils.js

@@ -30,7 +30,7 @@ Object.defineProperty(String.prototype, 'capitalize', {
 
 //Add a new function to jquery for ajax override with csrf token injected
 $.cjax = function(payload){
-    let requireTokenMethod = ["POST", "PUT", "DELETE"];;
+    let requireTokenMethod = ["POST", "PUT", "DELETE"];
     if (requireTokenMethod.includes(payload.method) || requireTokenMethod.includes(payload.type)){
         //csrf token is required
         let csrfToken = document.getElementsByTagName("meta")["zoraxy.csrf.Token"].getAttribute("content");