package geodb

import (
	_ "embed"
	"log"
	"net"
	"net/http"
	"strings"

	"imuslab.com/zoraxy/mod/database"
)

// geoipv4 holds the raw bytes of geoipv4.csv, embedded into the binary at
// build time. NewGeoDb passes it to parseCSV to build the lookup data.
//
//go:embed geoipv4.csv
var geoipv4 []byte //Original embedded csv file

// Store bundles the parsed GeoIP data with the database handle that backs the
// country and IP blacklist tables. Create one with NewGeoDb.
type Store struct {
	// Enabled reports whether blacklist checking is active; IsBlacklisted
	// always returns false while it is unset.
	Enabled bool
	geodb   [][]string //Parsed geodb list
	//geoipCache sync.Map
	// geotrie is built from the parsed geodb rows by constrctTrieTree.
	geotrie *trie
	// sysdb stores the "blacklist", "blacklist-cn" and "blacklist-ip" tables.
	// It is nil when NewGeoDb was called without a database (debug mode).
	sysdb   *database.Database
}

// CountryInfo is the result of a GeoIP lookup.
type CountryInfo struct {
	// CountryIsoCode is the country code returned by the lookup.
	CountryIsoCode string
	// ContinetCode (sic) is always left empty by ResolveCountryCodeFromIP.
	ContinetCode   string
}

// NewGeoDb parses the embedded GeoIP CSV and returns a Store built from it.
//
// When sysdb is non-nil, the "blacklist-cn", "blacklist-ip" and "blacklist"
// tables are created (in that order) and the persisted "enabled" flag is read
// from the "blacklist" table. When sysdb is nil, a debug-mode notice is logged
// and the blacklist starts disabled.
//
// An error is returned if the CSV cannot be parsed or a table cannot be
// created.
func NewGeoDb(sysdb *database.Database) (*Store, error) {
	records, err := parseCSV(geoipv4)
	if err != nil {
		return nil, err
	}

	enabled := false
	if sysdb == nil {
		log.Println("Database pointer set to nil: Entering debug mode")
	} else {
		requiredTables := []string{"blacklist-cn", "blacklist-ip", "blacklist"}
		for _, table := range requiredTables {
			if err := sysdb.NewTable(table); err != nil {
				return nil, err
			}
		}
		// The flag keeps its default (false) when the key has never been written.
		sysdb.Read("blacklist", "enabled", &enabled)
	}

	store := &Store{
		Enabled: enabled,
		geodb:   records,
		geotrie: constrctTrieTree(records),
		sysdb:   sysdb,
	}
	return store, nil
}

// ToggleBlacklist persists the enabled flag under the "enabled" key of the
// "blacklist" table and then updates s.Enabled to match.
//
// NOTE(review): s.sysdb is nil when the Store was created in debug mode —
// confirm callers never reach this method in that configuration.
func (s *Store) ToggleBlacklist(enabled bool) {
	s.sysdb.Write("blacklist", "enabled", enabled)
	s.Enabled = enabled
}

// ResolveCountryCodeFromIP looks up ipstring via s.search and wraps the
// resulting country code in a CountryInfo. The continent code is left empty
// and the returned error is always nil.
func (s *Store) ResolveCountryCodeFromIP(ipstring string) (*CountryInfo, error) {
	info := new(CountryInfo)
	info.CountryIsoCode = s.search(ipstring)
	return info, nil
}

// Close is currently a no-op; it performs no cleanup.
func (s *Store) Close() {

}

// AddCountryCodeToBlackList records countryCode in the "blacklist-cn" table.
// The code is lower-cased before being used as the key.
func (s *Store) AddCountryCodeToBlackList(countryCode string) {
	key := strings.ToLower(countryCode)
	s.sysdb.Write("blacklist-cn", key, true)
}

// RemoveCountryCodeFromBlackList deletes countryCode from the "blacklist-cn"
// table. The code is lower-cased before being used as the key.
func (s *Store) RemoveCountryCodeFromBlackList(countryCode string) {
	key := strings.ToLower(countryCode)
	s.sysdb.Delete("blacklist-cn", key)
}

// IsCountryCodeBlacklisted reports whether countryCode (compared in lower
// case) has an entry in the "blacklist-cn" table. The result stays false when
// the read does not populate it.
func (s *Store) IsCountryCodeBlacklisted(countryCode string) bool {
	blocked := false
	s.sysdb.Read("blacklist-cn", strings.ToLower(countryCode), &blocked)
	return blocked
}

// GetAllBlacklistedCountryCode returns every key stored in the "blacklist-cn"
// table. An empty, non-nil slice is returned when the table cannot be listed
// or has no entries.
func (s *Store) GetAllBlacklistedCountryCode() []string {
	codes := make([]string, 0)

	rows, err := s.sysdb.ListTable("blacklist-cn")
	if err != nil {
		return codes
	}

	// The first element of each row is the key, i.e. the country code.
	for _, row := range rows {
		codes = append(codes, string(row[0]))
	}
	return codes
}

// AddIPToBlackList records ipAddr as a key in the "blacklist-ip" table.
// The string is stored as given; IsIPBlacklisted also interprets stored keys
// as wildcard ("1.2.3.*") and CIDR rules.
func (s *Store) AddIPToBlackList(ipAddr string) {
	s.sysdb.Write("blacklist-ip", ipAddr, true)
}

// RemoveIPFromBlackList deletes the ipAddr key from the "blacklist-ip" table.
func (s *Store) RemoveIPFromBlackList(ipAddr string) {
	s.sysdb.Delete("blacklist-ip", ipAddr)
}

// IsIPBlacklisted reports whether ipAddr is covered by the "blacklist-ip"
// table. An exact key match is tried first; otherwise every stored entry is
// evaluated as a wildcard rule and then as a CIDR rule.
func (s *Store) IsIPBlacklisted(ipAddr string) bool {
	exactHit := false
	s.sysdb.Read("blacklist-ip", ipAddr, &exactHit)
	if exactHit {
		return true
	}

	// Fall back to the IP wildcard and CIDR rules.
	for _, rule := range s.GetAllBlacklistedIp() {
		if MatchIpWildcard(ipAddr, rule) || MatchIpCIDR(ipAddr, rule) {
			return true
		}
	}
	return false
}

// GetAllBlacklistedIp returns every key stored in the "blacklist-ip" table.
// An empty, non-nil slice is returned when the table cannot be listed or has
// no entries.
func (s *Store) GetAllBlacklistedIp() []string {
	rules := make([]string, 0)

	rows, err := s.sysdb.ListTable("blacklist-ip")
	if err != nil {
		return rules
	}

	// The first element of each row is the key, i.e. the stored IP rule.
	for _, row := range rows {
		rules = append(rules, string(row[0]))
	}
	return rules
}

// IsBlacklisted reports whether ipAddr is blocked by either the country
// blacklist or the IP blacklist. It always returns false when the blacklist
// is disabled, when ipAddr is empty, or when the country lookup fails.
func (s *Store) IsBlacklisted(ipAddr string) bool {
	// Nothing to check if the feature is off or the requester IP is unknown.
	if !s.Enabled || ipAddr == "" {
		return false
	}

	info, err := s.ResolveCountryCodeFromIP(ipAddr)
	if err != nil {
		return false
	}

	// Country rule is evaluated first; the IP rules only run if it misses.
	return s.IsCountryCodeBlacklisted(info.CountryIsoCode) || s.IsIPBlacklisted(ipAddr)
}

// GetRequesterCountryISOCode resolves the country code of the client behind
// r. It returns an empty string when no requester IP can be determined or the
// lookup fails.
func (s *Store) GetRequesterCountryISOCode(r *http.Request) string {
	requesterIP := GetRequesterIP(r)
	if requesterIP == "" {
		return ""
	}

	if info, err := s.ResolveCountryCodeFromIP(requesterIP); err == nil {
		return info.CountryIsoCode
	}
	return ""
}

// Utility functions

// GetRequesterIP returns the client IP address for r.
//
// Sources are tried in order: the first entry of the X-Forwarded-For header,
// the X-Real-IP header, and finally the host part of r.RemoteAddr. Header
// values are trimmed of surrounding whitespace, and an empty value falls
// through to the next source.
//
// NOTE(review): both headers are supplied by the client unless a trusted
// proxy overwrites them — confirm the deployment before relying on this
// value for access control.
func GetRequesterIP(r *http.Request) string {
	// X-Forwarded-For may carry a comma-separated chain
	// ("client, proxy1, proxy2"); only the first entry is the requester.
	if forwarded := r.Header.Get("X-Forwarded-For"); forwarded != "" {
		if idx := strings.IndexByte(forwarded, ','); idx >= 0 {
			forwarded = forwarded[:idx]
		}
		if forwarded = strings.TrimSpace(forwarded); forwarded != "" {
			return forwarded
		}
	}

	if realIP := strings.TrimSpace(r.Header.Get("X-Real-IP")); realIP != "" {
		return realIP
	}

	// SplitHostPort understands bracketed IPv6 literals such as "[::1]:8080",
	// which a plain split on ":" would mangle.
	host, _, err := net.SplitHostPort(r.RemoteAddr)
	if err != nil {
		// No port present (or unparsable): return the address unchanged.
		return r.RemoteAddr
	}
	return host
}

// MatchIpWildcard reports whether the dotted-quad ipAddress matches wildcard,
// a dotted-quad pattern in which any octet may be "*". Both arguments must
// split into exactly four dot-separated parts; otherwise the result is false.
// Octets are compared as plain strings.
func MatchIpWildcard(ipAddress, wildcard string) bool {
	addrParts := strings.Split(ipAddress, ".")
	ruleParts := strings.Split(wildcard, ".")
	if len(addrParts) != 4 || len(ruleParts) != 4 {
		return false
	}

	// Every pattern octet must be either "*" or identical to the address octet.
	for idx, rulePart := range ruleParts {
		if rulePart != "*" && rulePart != addrParts[idx] {
			return false
		}
	}
	return true
}

// MatchIpCIDR reports whether ip lies inside the network described by cidr.
// It returns false when cidr is not valid CIDR notation or when ip cannot be
// parsed as an IP address.
func MatchIpCIDR(ip string, cidr string) bool {
	_, network, parseErr := net.ParseCIDR(cidr)
	if parseErr != nil {
		return false
	}

	// An unparsable ip yields a nil net.IP, which Contains reports as outside.
	return network.Contains(net.ParseIP(ip))
}