Browse Source

auto update script executed

Toby Chui 1 year ago
parent
commit
062dfb32c2
5 changed files with 145 additions and 272 deletions
  1. 2 0
      mod/geodb/geodb.go
  2. 0 261
      mod/geodb/geodb.go_backup
  3. 33 10
      mod/geodb/geodb_test.go
  4. 15 1
      mod/geodb/geoloader.go
  5. 95 0
      mod/geodb/trie.go

+ 2 - 0
mod/geodb/geodb.go

@@ -18,6 +18,7 @@ type Store struct {
 	Enabled    bool
 	geodb      [][]string //Parsed geodb list
 	geoipCache sync.Map
+	geotrie    *trie
 	sysdb      *database.Database
 }
 
@@ -57,6 +58,7 @@ func NewGeoDb(sysdb *database.Database) (*Store, error) {
 		Enabled:    blacklistEnabled,
 		geodb:      parsedGeoData,
 		geoipCache: sync.Map{},
+		geotrie:    constrctTrieTree(parsedGeoData),
 		sysdb:      sysdb,
 	}, nil
 }

+ 0 - 261
mod/geodb/geodb.go_backup

@@ -1,261 +0,0 @@
-package geodb
-
-import (
-	"embed"
-	"net"
-	"net/http"
-	"os"
-	"strings"
-
-	"github.com/oschwald/geoip2-golang"
-	"imuslab.com/zoraxy/mod/database"
-	"imuslab.com/zoraxy/mod/utils"
-)
-
-//go:embed GeoLite2-Country.mmdb
-var geodb embed.FS
-
-type Store struct {
-	Enabled bool
-	geodb   *geoip2.Reader
-	sysdb   *database.Database
-}
-
-type CountryInfo struct {
-	CountryIsoCode string
-	ContinetCode   string
-}
-
-func NewGeoDb(sysdb *database.Database, dbfile string) (*Store, error) {
-	if !utils.FileExists(dbfile) {
-		//Unzip it from binary
-		geodbFile, err := geodb.ReadFile("GeoLite2-Country.mmdb")
-		if err != nil {
-			return nil, err
-		}
-		err = os.WriteFile(dbfile, geodbFile, 0775)
-		if err != nil {
-			return nil, err
-		}
-	}
-	db, err := geoip2.Open(dbfile)
-	if err != nil {
-		return nil, err
-	}
-
-	err = sysdb.NewTable("blacklist-cn")
-	if err != nil {
-		return nil, err
-	}
-
-	err = sysdb.NewTable("blacklist-ip")
-	if err != nil {
-		return nil, err
-	}
-
-	err = sysdb.NewTable("blacklist")
-	if err != nil {
-		return nil, err
-	}
-
-	blacklistEnabled := false
-	sysdb.Read("blacklist", "enabled", &blacklistEnabled)
-
-	return &Store{
-		Enabled: blacklistEnabled,
-		geodb:   db,
-		sysdb:   sysdb,
-	}, nil
-}
-
-func (s *Store) ToggleBlacklist(enabled bool) {
-	s.sysdb.Write("blacklist", "enabled", enabled)
-	s.Enabled = enabled
-}
-
-func (s *Store) ResolveCountryCodeFromIP(ipstring string) (*CountryInfo, error) {
-	// If you are using strings that may be invalid, check that ip is not nil
-	ip := net.ParseIP(ipstring)
-	record, err := s.geodb.City(ip)
-	if err != nil {
-		return nil, err
-	}
-	return &CountryInfo{
-		record.Country.IsoCode,
-		record.Continent.Code,
-	}, nil
-}
-
-func (s *Store) Close() {
-	s.geodb.Close()
-}
-
-func (s *Store) AddCountryCodeToBlackList(countryCode string) {
-	countryCode = strings.ToLower(countryCode)
-	s.sysdb.Write("blacklist-cn", countryCode, true)
-}
-
-func (s *Store) RemoveCountryCodeFromBlackList(countryCode string) {
-	countryCode = strings.ToLower(countryCode)
-	s.sysdb.Delete("blacklist-cn", countryCode)
-}
-
-func (s *Store) IsCountryCodeBlacklisted(countryCode string) bool {
-	countryCode = strings.ToLower(countryCode)
-	var isBlacklisted bool = false
-	s.sysdb.Read("blacklist-cn", countryCode, &isBlacklisted)
-	return isBlacklisted
-}
-
-func (s *Store) GetAllBlacklistedCountryCode() []string {
-	bannedCountryCodes := []string{}
-	entries, err := s.sysdb.ListTable("blacklist-cn")
-	if err != nil {
-		return bannedCountryCodes
-	}
-	for _, keypairs := range entries {
-		ip := string(keypairs[0])
-		bannedCountryCodes = append(bannedCountryCodes, ip)
-	}
-
-	return bannedCountryCodes
-}
-
-func (s *Store) AddIPToBlackList(ipAddr string) {
-	s.sysdb.Write("blacklist-ip", ipAddr, true)
-}
-
-func (s *Store) RemoveIPFromBlackList(ipAddr string) {
-	s.sysdb.Delete("blacklist-ip", ipAddr)
-}
-
-func (s *Store) IsIPBlacklisted(ipAddr string) bool {
-	var isBlacklisted bool = false
-	s.sysdb.Read("blacklist-ip", ipAddr, &isBlacklisted)
-	if isBlacklisted {
-		return true
-	}
-
-	//Check for IP wildcard and CIRD rules
-	AllBlacklistedIps := s.GetAllBlacklistedIp()
-	for _, blacklistRule := range AllBlacklistedIps {
-		wildcardMatch := MatchIpWildcard(ipAddr, blacklistRule)
-		if wildcardMatch {
-			return true
-		}
-
-		cidrMatch := MatchIpCIDR(ipAddr, blacklistRule)
-		if cidrMatch {
-			return true
-		}
-	}
-
-	return false
-}
-
-func (s *Store) GetAllBlacklistedIp() []string {
-	bannedIps := []string{}
-	entries, err := s.sysdb.ListTable("blacklist-ip")
-	if err != nil {
-		return bannedIps
-	}
-
-	for _, keypairs := range entries {
-		ip := string(keypairs[0])
-		bannedIps = append(bannedIps, ip)
-	}
-
-	return bannedIps
-}
-
-//Check if a IP address is blacklisted, in either country or IP blacklist
-func (s *Store) IsBlacklisted(ipAddr string) bool {
-	if !s.Enabled {
-		//Blacklist not enabled. Always return false
-		return false
-	}
-
-	if ipAddr == "" {
-		//Unable to get the target IP address
-		return false
-	}
-
-	countryCode, err := s.ResolveCountryCodeFromIP(ipAddr)
-	if err != nil {
-		return false
-	}
-
-	if s.IsCountryCodeBlacklisted(countryCode.CountryIsoCode) {
-		return true
-	}
-
-	if s.IsIPBlacklisted(ipAddr) {
-		return true
-	}
-
-	return false
-}
-
-func (s *Store) GetRequesterCountryISOCode(r *http.Request) string {
-	ipAddr := GetRequesterIP(r)
-	if ipAddr == "" {
-		return ""
-	}
-	countryCode, err := s.ResolveCountryCodeFromIP(ipAddr)
-	if err != nil {
-		return ""
-	}
-
-	return countryCode.CountryIsoCode
-}
-
-//Utilities function
-func GetRequesterIP(r *http.Request) string {
-	ip := r.Header.Get("X-Forwarded-For")
-	if ip == "" {
-		ip = r.Header.Get("X-Real-IP")
-		if ip == "" {
-			ip = strings.Split(r.RemoteAddr, ":")[0]
-		}
-	}
-	return ip
-}
-
-//Match the IP address with a wildcard string
-func MatchIpWildcard(ipAddress, wildcard string) bool {
-	// Split IP address and wildcard into octets
-	ipOctets := strings.Split(ipAddress, ".")
-	wildcardOctets := strings.Split(wildcard, ".")
-
-	// Check that both have 4 octets
-	if len(ipOctets) != 4 || len(wildcardOctets) != 4 {
-		return false
-	}
-
-	// Check each octet to see if it matches the wildcard or is an exact match
-	for i := 0; i < 4; i++ {
-		if wildcardOctets[i] == "*" {
-			continue
-		}
-		if ipOctets[i] != wildcardOctets[i] {
-			return false
-		}
-	}
-
-	return true
-}
-
-//Match ip address with CIDR
-func MatchIpCIDR(ip string, cidr string) bool {
-	// parse the CIDR string
-	_, cidrnet, err := net.ParseCIDR(cidr)
-	if err != nil {
-		return false
-	}
-
-	// parse the IP address
-	ipAddr := net.ParseIP(ip)
-
-	// check if the IP address is within the CIDR range
-	return cidrnet.Contains(ipAddr)
-}

+ 33 - 10
mod/geodb/geodb_test.go

@@ -6,6 +6,39 @@ import (
 	"imuslab.com/zoraxy/mod/geodb"
 )
 
+/*
+func TestTrieConstruct(t *testing.T) {
+	tt := geodb.NewTrie()
+	data := [][]string{
+		{"1.0.16.0", "1.0.31.255", "JP"},
+		{"1.0.32.0", "1.0.63.255", "CN"},
+		{"1.0.64.0", "1.0.127.255", "JP"},
+		{"1.0.128.0", "1.0.255.255", "TH"},
+		{"1.1.0.0", "1.1.0.255", "CN"},
+		{"1.1.1.0", "1.1.1.255", "AU"},
+		{"1.1.2.0", "1.1.63.255", "CN"},
+		{"1.1.64.0", "1.1.127.255", "JP"},
+		{"1.1.128.0", "1.1.255.255", "TH"},
+		{"1.2.0.0", "1.2.2.255", "CN"},
+		{"1.2.3.0", "1.2.3.255", "AU"},
+	}
+
+	for _, entry := range data {
+		startIp := entry[0]
+		endIp := entry[1]
+		cc := entry[2]
+		tt.Insert(startIp, cc)
+		tt.Insert(endIp, cc)
+	}
+
+	t.Log(tt.Search("1.0.16.20"), "== JP")  //JP
+	t.Log(tt.Search("1.2.0.122"), "== CN")  //CN
+	t.Log(tt.Search("1.2.1.0"), "== CN")    //CN
+	t.Log(tt.Search("1.0.65.243"), "== JP") //JP
+	t.Log(tt.Search("1.0.62.243"), "== CN") //CN
+}
+*/
+
 func TestResolveCountryCodeFromIP(t *testing.T) {
 	// Create a new store
 	store, err := geodb.NewGeoDb(nil)
@@ -14,16 +47,6 @@ func TestResolveCountryCodeFromIP(t *testing.T) {
 		return
 	}
 
-	/*
-		result, err := store.Search("7.8.8.8")
-		if err != nil {
-			t.Error(err.Error())
-			return
-		}
-		fmt.Println(">> ", result, err)
-		return
-	*/
-
 	// Test an IP address that should return a valid country code
 	ip := "8.8.8.8"
 	expected := "US"

+ 15 - 1
mod/geodb/geoloader.go

@@ -37,7 +37,21 @@ func (s *Store) search(ip string) string {
 	return ""
 }
 
-//Parse the embedded csv as ipstart, ipend and country code entries
+// Construct the trie data structure for quick lookup
+func constrctTrieTree(data [][]string) *trie {
+	tt := newTrie()
+	for _, entry := range data {
+		startIp := entry[0]
+		endIp := entry[1]
+		cc := entry[2]
+		tt.insert(startIp, cc)
+		tt.insert(endIp, cc)
+	}
+
+	return tt
+}
+
+// Parse the embedded csv as ipstart, ipend and country code entries
 func parseCSV(content []byte) ([][]string, error) {
 	var records [][]string
 	r := csv.NewReader(bytes.NewReader(content))

+ 95 - 0
mod/geodb/trie.go

@@ -0,0 +1,95 @@
+package geodb
+
+import (
+	"fmt"
+	"net"
+	"strconv"
+	"strings"
+)
+
+type trie_Node struct {
+	childrens [2]*trie_Node
+	ends      bool
+	cc        string
+}
+
+// Initializing the root of the trie
+type trie struct {
+	root *trie_Node
+}
+
+func ipToBitString(ip string) string {
+	// Parse the IP address string into a net.IP object
+	parsedIP := net.ParseIP(ip)
+
+	// Convert the IP address to a 4-byte slice
+	ipBytes := parsedIP.To4()
+
+	// Convert each byte in the IP address to its 8-bit binary representation
+	var result []string
+	for _, b := range ipBytes {
+		result = append(result, fmt.Sprintf("%08b", b))
+	}
+
+	// Join the binary representation of each byte with dots to form the final bit string
+	return strings.Join(result, "")
+}
+
+func bitStringToIp(bitString string) string {
+	// Split the bit string into four 8-bit segments
+	segments := []string{
+		bitString[:8],
+		bitString[8:16],
+		bitString[16:24],
+		bitString[24:32],
+	}
+
+	// Convert each segment to its decimal equivalent
+	var decimalSegments []int
+	for _, s := range segments {
+		i, _ := strconv.ParseInt(s, 2, 64)
+		decimalSegments = append(decimalSegments, int(i))
+	}
+
+	// Join the decimal segments with dots to form the IP address string
+	return fmt.Sprintf("%d.%d.%d.%d", decimalSegments[0], decimalSegments[1], decimalSegments[2], decimalSegments[3])
+}
+
+// inititlaizing a new trie
+func newTrie() *trie {
+	t := new(trie)
+	t.root = new(trie_Node)
+	return t
+}
+
+// Passing words to trie
+func (t *trie) insert(ipAddr string, cc string) {
+	word := ipToBitString(ipAddr)
+	current := t.root
+	for _, wr := range word {
+		index := wr - '0'
+		if current.childrens[index] == nil {
+			current.childrens[index] = &trie_Node{
+				childrens: [2]*trie_Node{},
+				ends:      false,
+				cc:        cc,
+			}
+		}
+		current = current.childrens[index]
+	}
+	current.ends = true
+}
+
+// Initializing the search for word in node
+func (t *trie) search(ipAddr string) string {
+	word := ipToBitString(ipAddr)
+	current := t.root
+	for _, wr := range word {
+		index := wr - '0'
+		if current.childrens[index] == nil {
+			return current.cc
+		}
+		current = current.childrens[index]
+	}
+	return current.cc
+}