tlscert.go 7.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288
  1. package tlscert
  2. import (
  3. "crypto/tls"
  4. "crypto/x509"
  5. "embed"
  6. "encoding/pem"
  7. "fmt"
  8. "io"
  9. "log"
  10. "os"
  11. "path/filepath"
  12. "strings"
  13. "imuslab.com/zoraxy/mod/utils"
  14. )
  15. type CertCache struct {
  16. Cert *x509.Certificate
  17. PubKey string
  18. PriKey string
  19. }
// Manager keeps track of the TLS certificates stored in a directory on disk
// and holds a parsed, in-memory copy of them for hostname matching.
type Manager struct {
	CertStore   string       // Path of the directory where all the certs are stored
	LoadedCerts []*CertCache // Certificates currently loaded in memory; replaced as a whole by UpdateLoadedCertList
	verbal      bool         // When true, GetCert logs which fallback certificate is being served
}
// buildinCertStore embeds the localhost.pem / localhost.key pair into the
// binary. NewManager copies these files to ./tmp when they are missing there.
//
//go:embed localhost.pem localhost.key
var buildinCertStore embed.FS
  27. func NewManager(certStore string, verbal bool) (*Manager, error) {
  28. if !utils.FileExists(certStore) {
  29. os.MkdirAll(certStore, 0775)
  30. }
  31. pubKey := "./tmp/localhost.pem"
  32. priKey := "./tmp/localhost.key"
  33. //Check if this is initial setup
  34. if !utils.FileExists(pubKey) {
  35. buildInPubKey, _ := buildinCertStore.ReadFile(filepath.Base(pubKey))
  36. os.WriteFile(pubKey, buildInPubKey, 0775)
  37. }
  38. if !utils.FileExists(priKey) {
  39. buildInPriKey, _ := buildinCertStore.ReadFile(filepath.Base(priKey))
  40. os.WriteFile(priKey, buildInPriKey, 0775)
  41. }
  42. thisManager := Manager{
  43. CertStore: certStore,
  44. LoadedCerts: []*CertCache{},
  45. verbal: verbal,
  46. }
  47. err := thisManager.UpdateLoadedCertList()
  48. if err != nil {
  49. return nil, err
  50. }
  51. return &thisManager, nil
  52. }
  53. // Update domain mapping from file
  54. func (m *Manager) UpdateLoadedCertList() error {
  55. //Get a list of certificates from file
  56. domainList, err := m.ListCertDomains()
  57. if err != nil {
  58. return err
  59. }
  60. //Load each of the certificates into memory
  61. certList := []*CertCache{}
  62. for _, certname := range domainList {
  63. //Read their certificate into memory
  64. pubKey := filepath.Join(m.CertStore, certname+".pem")
  65. priKey := filepath.Join(m.CertStore, certname+".key")
  66. certificate, err := tls.LoadX509KeyPair(pubKey, priKey)
  67. if err != nil {
  68. log.Println("Certificate loaded failed: " + certname)
  69. continue
  70. }
  71. for _, thisCert := range certificate.Certificate {
  72. loadedCert, err := x509.ParseCertificate(thisCert)
  73. if err != nil {
  74. //Error pasring cert, skip this byte segment
  75. continue
  76. }
  77. thisCacheEntry := CertCache{
  78. Cert: loadedCert,
  79. PubKey: pubKey,
  80. PriKey: priKey,
  81. }
  82. certList = append(certList, &thisCacheEntry)
  83. }
  84. }
  85. //Replace runtime cert array
  86. m.LoadedCerts = certList
  87. return nil
  88. }
  89. // Match cert by CN
  90. func (m *Manager) CertMatchExists(serverName string) bool {
  91. for _, certCacheEntry := range m.LoadedCerts {
  92. if certCacheEntry.Cert.VerifyHostname(serverName) == nil || certCacheEntry.Cert.Issuer.CommonName == serverName {
  93. return true
  94. }
  95. }
  96. return false
  97. }
  98. // Get cert entry by matching server name, return pubKey and priKey if found
  99. // check with CertMatchExists before calling to the load function
  100. func (m *Manager) GetCertByX509CNHostname(serverName string) (string, string) {
  101. for _, certCacheEntry := range m.LoadedCerts {
  102. if certCacheEntry.Cert.VerifyHostname(serverName) == nil || certCacheEntry.Cert.Issuer.CommonName == serverName {
  103. return certCacheEntry.PubKey, certCacheEntry.PriKey
  104. }
  105. }
  106. return "", ""
  107. }
  108. // Return a list of domains by filename
  109. func (m *Manager) ListCertDomains() ([]string, error) {
  110. filenames, err := m.ListCerts()
  111. if err != nil {
  112. return []string{}, err
  113. }
  114. //Remove certificates where there are missing public key or private key
  115. filenames = getCertPairs(filenames)
  116. return filenames, nil
  117. }
  118. // Return a list of cert files (public and private keys)
  119. func (m *Manager) ListCerts() ([]string, error) {
  120. certs, err := os.ReadDir(m.CertStore)
  121. if err != nil {
  122. return []string{}, err
  123. }
  124. filenames := make([]string, 0, len(certs))
  125. for _, cert := range certs {
  126. if !cert.IsDir() {
  127. filenames = append(filenames, cert.Name())
  128. }
  129. }
  130. return filenames, nil
  131. }
  132. // Get a certificate from disk where its certificate matches with the helloinfo
  133. func (m *Manager) GetCert(helloInfo *tls.ClientHelloInfo) (*tls.Certificate, error) {
  134. //Check if the domain corrisponding cert exists
  135. pubKey := "./tmp/localhost.pem"
  136. priKey := "./tmp/localhost.key"
  137. if utils.FileExists(filepath.Join(m.CertStore, helloInfo.ServerName+".pem")) && utils.FileExists(filepath.Join(m.CertStore, helloInfo.ServerName+".key")) {
  138. //Direct hit
  139. pubKey = filepath.Join(m.CertStore, helloInfo.ServerName+".pem")
  140. priKey = filepath.Join(m.CertStore, helloInfo.ServerName+".key")
  141. } else if m.CertMatchExists(helloInfo.ServerName) {
  142. //Use x509
  143. pubKey, priKey = m.GetCertByX509CNHostname(helloInfo.ServerName)
  144. fmt.Println(pubKey, priKey)
  145. } else {
  146. //Fallback to legacy method of matching certificates
  147. /*
  148. domainCerts, _ := m.ListCertDomains()
  149. cloestDomainCert := matchClosestDomainCertificate(helloInfo.ServerName, domainCerts)
  150. if cloestDomainCert != "" {
  151. //There is a matching parent domain for this subdomain. Use this instead.
  152. pubKey = filepath.Join(m.CertStore, cloestDomainCert+".pem")
  153. priKey = filepath.Join(m.CertStore, cloestDomainCert+".key")
  154. } else if m.DefaultCertExists() {
  155. //Use default.pem and default.key
  156. pubKey = filepath.Join(m.CertStore, "default.pem")
  157. priKey = filepath.Join(m.CertStore, "default.key")
  158. if m.verbal {
  159. log.Println("No matching certificate found. Serving with default")
  160. }
  161. } else {
  162. if m.verbal {
  163. log.Println("Matching certificate not found. Serving with build-in certificate. Requesting server name: ", helloInfo.ServerName)
  164. }
  165. }*/
  166. if m.DefaultCertExists() {
  167. //Use default.pem and default.key
  168. pubKey = filepath.Join(m.CertStore, "default.pem")
  169. priKey = filepath.Join(m.CertStore, "default.key")
  170. if m.verbal {
  171. log.Println("No matching certificate found. Serving with default")
  172. }
  173. } else {
  174. if m.verbal {
  175. log.Println("Matching certificate not found. Serving with build-in certificate. Requesting server name: ", helloInfo.ServerName)
  176. }
  177. }
  178. }
  179. //Load the cert and serve it
  180. cer, err := tls.LoadX509KeyPair(pubKey, priKey)
  181. if err != nil {
  182. log.Println(err)
  183. return nil, nil
  184. }
  185. return &cer, nil
  186. }
  187. // Check if both the default cert public key and private key exists
  188. func (m *Manager) DefaultCertExists() bool {
  189. return utils.FileExists(filepath.Join(m.CertStore, "default.pem")) && utils.FileExists(filepath.Join(m.CertStore, "default.key"))
  190. }
  191. // Check if the default cert exists returning seperate results for pubkey and prikey
  192. func (m *Manager) DefaultCertExistsSep() (bool, bool) {
  193. return utils.FileExists(filepath.Join(m.CertStore, "default.pem")), utils.FileExists(filepath.Join(m.CertStore, "default.key"))
  194. }
  195. // Delete the cert if exists
  196. func (m *Manager) RemoveCert(domain string) error {
  197. pubKey := filepath.Join(m.CertStore, domain+".pem")
  198. priKey := filepath.Join(m.CertStore, domain+".key")
  199. if utils.FileExists(pubKey) {
  200. err := os.Remove(pubKey)
  201. if err != nil {
  202. return err
  203. }
  204. }
  205. if utils.FileExists(priKey) {
  206. err := os.Remove(priKey)
  207. if err != nil {
  208. return err
  209. }
  210. }
  211. //Update the cert list
  212. m.UpdateLoadedCertList()
  213. return nil
  214. }
  215. // Check if the given file is a valid TLS file
  216. func IsValidTLSFile(file io.Reader) bool {
  217. // Read the contents of the uploaded file
  218. contents, err := io.ReadAll(file)
  219. if err != nil {
  220. // Handle the error
  221. return false
  222. }
  223. // Parse the contents of the file as a PEM-encoded certificate or key
  224. block, _ := pem.Decode(contents)
  225. if block == nil {
  226. // The file is not a valid PEM-encoded certificate or key
  227. return false
  228. }
  229. // Parse the certificate or key
  230. if strings.Contains(block.Type, "CERTIFICATE") {
  231. // The file contains a certificate
  232. cert, err := x509.ParseCertificate(block.Bytes)
  233. if err != nil {
  234. // Handle the error
  235. return false
  236. }
  237. // Check if the certificate is a valid TLS/SSL certificate
  238. return !cert.IsCA && cert.KeyUsage&x509.KeyUsageDigitalSignature != 0 && cert.KeyUsage&x509.KeyUsageKeyEncipherment != 0
  239. } else if strings.Contains(block.Type, "PRIVATE KEY") {
  240. // The file contains a private key
  241. _, err := x509.ParsePKCS1PrivateKey(block.Bytes)
  242. return err == nil
  243. } else {
  244. return false
  245. }
  246. }