acme.go 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531
  1. package acme
  2. import (
  3. "crypto"
  4. "crypto/ecdsa"
  5. "crypto/elliptic"
  6. "crypto/rand"
  7. "crypto/tls"
  8. "crypto/x509"
  9. "encoding/json"
  10. "encoding/pem"
  11. "errors"
  12. "fmt"
  13. "net"
  14. "net/http"
  15. "os"
  16. "path/filepath"
  17. "strconv"
  18. "strings"
  19. "time"
  20. "github.com/go-acme/lego/v4/certcrypto"
  21. "github.com/go-acme/lego/v4/certificate"
  22. "github.com/go-acme/lego/v4/challenge/http01"
  23. "github.com/go-acme/lego/v4/lego"
  24. "github.com/go-acme/lego/v4/registration"
  25. "imuslab.com/zoraxy/mod/database"
  26. "imuslab.com/zoraxy/mod/info/logger"
  27. "imuslab.com/zoraxy/mod/utils"
  28. )
  // CertificateInfoJSON is the renewal metadata that ObtainCert writes next to
// each certificate (as <name>.json) so the same ACME settings can be reused
// when the certificate is renewed.
type CertificateInfoJSON struct {
	// AcmeName is the ACME provider name.
	AcmeName string `json:"acme_name"`
	// AcmeUrl is the custom ACME directory URL, if any.
	AcmeUrl string `json:"acme_url"`
	// SkipTLS disables TLS verification of the upstream ACME server.
	SkipTLS bool `json:"skip_tls"`
	// UseDNS selects the DNS-01 challenge instead of HTTP-01.
	UseDNS bool `json:"dns"`
	// PropTimeout is the DNS propagation timeout in seconds.
	PropTimeout int `json:"prop_time"`
}
  36. // ACMEUser represents a user in the ACME system.
  37. type ACMEUser struct {
  38. Email string
  39. Registration *registration.Resource
  40. key crypto.PrivateKey
  41. }
  // EABConfig holds External Account Binding credentials: the key identifier
// and the HMAC key.
type EABConfig struct {
	Kid     string `json:"kid"`     // EAB key identifier
	HmacKey string `json:"HmacKey"` // EAB HMAC key (secret)
}
  46. // GetEmail returns the email of the ACMEUser.
  47. func (u *ACMEUser) GetEmail() string {
  48. return u.Email
  49. }
  50. // GetRegistration returns the registration resource of the ACMEUser.
  51. func (u ACMEUser) GetRegistration() *registration.Resource {
  52. return u.Registration
  53. }
  54. // GetPrivateKey returns the private key of the ACMEUser.
  55. func (u *ACMEUser) GetPrivateKey() crypto.PrivateKey {
  56. return u.key
  57. }
  58. // ACMEHandler handles ACME-related operations.
  59. type ACMEHandler struct {
  60. DefaultAcmeServer string
  61. Port string
  62. Database *database.Database
  63. Logger *logger.Logger
  64. }
  65. // NewACME creates a new ACMEHandler instance.
  66. func NewACME(defaultAcmeServer string, port string, database *database.Database, logger *logger.Logger) *ACMEHandler {
  67. return &ACMEHandler{
  68. DefaultAcmeServer: defaultAcmeServer,
  69. Port: port,
  70. Database: database,
  71. Logger: logger,
  72. }
  73. }
  74. func (a *ACMEHandler) Logf(message string, err error) {
  75. a.Logger.PrintAndLog("ACME", message, err)
  76. }
  77. // Close closes the ACMEHandler.
  78. // ACME Handler does not need to close anything
  79. // Function defined for future compatibility
  80. func (a *ACMEHandler) Close() error {
  81. return nil
  82. }
  83. // ObtainCert obtains a certificate for the specified domains.
  84. func (a *ACMEHandler) ObtainCert(domains []string, certificateName string, email string, caName string, caUrl string, skipTLS bool, useDNS bool, propagationTimeout int) (bool, error) {
  85. a.Logf("Obtaining certificate for: "+strings.Join(domains, ", "), nil)
  86. // generate private key
  87. privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
  88. if err != nil {
  89. a.Logf("Private key generation failed", err)
  90. return false, err
  91. }
  92. // create a admin user for our new generation
  93. adminUser := ACMEUser{
  94. Email: email,
  95. key: privateKey,
  96. }
  97. // create config
  98. config := lego.NewConfig(&adminUser)
  99. // skip TLS verify if need
  100. // Ref: https://github.com/go-acme/lego/blob/6af2c756ac73a9cb401621afca722d0f4112b1b8/lego/client_config.go#L74
  101. if skipTLS {
  102. a.Logf("Ignoring TLS/SSL Verification Error for ACME Server", nil)
  103. config.HTTPClient.Transport = &http.Transport{
  104. Proxy: http.ProxyFromEnvironment,
  105. DialContext: (&net.Dialer{
  106. Timeout: 30 * time.Second,
  107. KeepAlive: 30 * time.Second,
  108. }).DialContext,
  109. TLSHandshakeTimeout: 30 * time.Second,
  110. ResponseHeaderTimeout: 30 * time.Second,
  111. TLSClientConfig: &tls.Config{
  112. InsecureSkipVerify: true,
  113. },
  114. }
  115. }
  116. //Fallback to Let's Encrypt if it is not set
  117. if caName == "" {
  118. caName = "Let's Encrypt"
  119. }
  120. // setup the custom ACME url endpoint.
  121. if caUrl != "" {
  122. config.CADirURL = caUrl
  123. }
  124. // if not custom ACME url, load it from ca.json
  125. if caName == "custom" {
  126. a.Logf("Using Custom ACME "+caUrl+" for CA Directory URL", nil)
  127. } else {
  128. caLinkOverwrite, err := loadCAApiServerFromName(caName)
  129. if err == nil {
  130. config.CADirURL = caLinkOverwrite
  131. a.Logf("Using "+caLinkOverwrite+" for CA Directory URL", nil)
  132. } else {
  133. // (caName == "" || caUrl == "") will use default acme
  134. config.CADirURL = a.DefaultAcmeServer
  135. a.Logf("Using Default ACME "+a.DefaultAcmeServer+" for CA Directory URL", nil)
  136. }
  137. }
  138. config.Certificate.KeyType = certcrypto.RSA2048
  139. client, err := lego.NewClient(config)
  140. if err != nil {
  141. a.Logf("Failed to spawn new ACME client from current config", err)
  142. return false, err
  143. }
  144. // setup how to receive challenge
  145. if useDNS {
  146. if !a.Database.TableExists("acme") {
  147. a.Database.NewTable("acme")
  148. return false, errors.New("DNS Provider and DNS Credenital configuration required for ACME Provider (Error -1)")
  149. }
  150. if !a.Database.KeyExists("acme", certificateName+"_dns_provider") || !a.Database.KeyExists("acme", certificateName+"_dns_credentials") {
  151. return false, errors.New("DNS Provider and DNS Credenital configuration required for ACME Provider (Error -2)")
  152. }
  153. var dnsCredentials string
  154. err := a.Database.Read("acme", certificateName+"_dns_credentials", &dnsCredentials)
  155. if err != nil {
  156. a.Logf("Read DNS credential failed", err)
  157. return false, err
  158. }
  159. var dnsProvider string
  160. err = a.Database.Read("acme", certificateName+"_dns_provider", &dnsProvider)
  161. if err != nil {
  162. a.Logf("Read DNS Provider failed", err)
  163. return false, err
  164. }
  165. provider, err := GetDnsChallengeProviderByName(dnsProvider, dnsCredentials, propagationTimeout)
  166. if err != nil {
  167. a.Logf("Unable to resolve DNS challenge provider", err)
  168. return false, err
  169. }
  170. err = client.Challenge.SetDNS01Provider(provider)
  171. if err != nil {
  172. a.Logf("Failed to resolve DNS01 Provider", err)
  173. return false, err
  174. }
  175. } else {
  176. err = client.Challenge.SetHTTP01Provider(http01.NewProviderServer("", a.Port))
  177. if err != nil {
  178. a.Logf("Failed to resolve HTTP01 Provider", err)
  179. return false, err
  180. }
  181. }
  182. // New users will need to register
  183. /*
  184. reg, err := client.Registration.Register(registration.RegisterOptions{TermsOfServiceAgreed: true})
  185. if err != nil {
  186. log.Println(err)
  187. return false, err
  188. }
  189. */
  190. var reg *registration.Resource
  191. // New users will need to register
  192. if client.GetExternalAccountRequired() {
  193. a.Logf("External Account Required for this ACME Provider", nil)
  194. // IF KID and HmacEncoded is overidden
  195. if !a.Database.TableExists("acme") {
  196. a.Database.NewTable("acme")
  197. return false, errors.New("kid and HmacEncoded configuration required for ACME Provider (Error -1)")
  198. }
  199. if !a.Database.KeyExists("acme", config.CADirURL+"_kid") || !a.Database.KeyExists("acme", config.CADirURL+"_hmacEncoded") {
  200. return false, errors.New("kid and HmacEncoded configuration required for ACME Provider (Error -2)")
  201. }
  202. var kid string
  203. var hmacEncoded string
  204. err := a.Database.Read("acme", config.CADirURL+"_kid", &kid)
  205. if err != nil {
  206. a.Logf("Failed to read kid from database", err)
  207. return false, err
  208. }
  209. err = a.Database.Read("acme", config.CADirURL+"_hmacEncoded", &hmacEncoded)
  210. if err != nil {
  211. a.Logf("Failed to read HMAC from database", err)
  212. return false, err
  213. }
  214. a.Logf("EAB Credential retrieved: "+kid+" / "+hmacEncoded, nil)
  215. if kid != "" && hmacEncoded != "" {
  216. reg, err = client.Registration.RegisterWithExternalAccountBinding(registration.RegisterEABOptions{
  217. TermsOfServiceAgreed: true,
  218. Kid: kid,
  219. HmacEncoded: hmacEncoded,
  220. })
  221. }
  222. if err != nil {
  223. a.Logf("Register with external account binder failed", err)
  224. return false, err
  225. }
  226. //return false, errors.New("External Account Required for this ACME Provider.")
  227. } else {
  228. reg, err = client.Registration.Register(registration.RegisterOptions{TermsOfServiceAgreed: true})
  229. if err != nil {
  230. a.Logf("Unable to register client", err)
  231. return false, err
  232. }
  233. }
  234. adminUser.Registration = reg
  235. // obtain the certificate
  236. request := certificate.ObtainRequest{
  237. Domains: domains,
  238. Bundle: true,
  239. }
  240. certificates, err := client.Certificate.Obtain(request)
  241. if err != nil {
  242. a.Logf("Obtain certificate failed", err)
  243. return false, err
  244. }
  245. // Each certificate comes back with the cert bytes, the bytes of the client's
  246. // private key, and a certificate URL.
  247. err = os.WriteFile("./conf/certs/"+certificateName+".pem", certificates.Certificate, 0777)
  248. if err != nil {
  249. a.Logf("Failed to write public key to disk", err)
  250. return false, err
  251. }
  252. err = os.WriteFile("./conf/certs/"+certificateName+".key", certificates.PrivateKey, 0777)
  253. if err != nil {
  254. a.Logf("Failed to write private key to disk", err)
  255. return false, err
  256. }
  257. // Save certificate's ACME info for renew usage
  258. certInfo := &CertificateInfoJSON{
  259. AcmeName: caName,
  260. AcmeUrl: caUrl,
  261. SkipTLS: skipTLS,
  262. UseDNS: useDNS,
  263. PropTimeout: propagationTimeout,
  264. }
  265. certInfoBytes, err := json.Marshal(certInfo)
  266. if err != nil {
  267. a.Logf("Marshal certificate renew config failed", err)
  268. return false, err
  269. }
  270. err = os.WriteFile("./conf/certs/"+certificateName+".json", certInfoBytes, 0777)
  271. if err != nil {
  272. a.Logf("Failed to write certificate renew config to file", err)
  273. return false, err
  274. }
  275. return true, nil
  276. }
  277. // CheckCertificate returns a list of domains that are in expired certificates.
  278. // It will return all domains that is in expired certificates
  279. // *** if there is a vaild certificate contains the domain and there is a expired certificate contains the same domain
  280. // it will said expired as well!
  281. func (a *ACMEHandler) CheckCertificate() []string {
  282. // read from dir
  283. filenames, err := os.ReadDir("./conf/certs/")
  284. expiredCerts := []string{}
  285. if err != nil {
  286. a.Logf("Failed to load certificate folder", err)
  287. return []string{}
  288. }
  289. for _, filename := range filenames {
  290. certFilepath := filepath.Join("./conf/certs/", filename.Name())
  291. certBytes, err := os.ReadFile(certFilepath)
  292. if err != nil {
  293. // Unable to load this file
  294. continue
  295. } else {
  296. // Cert loaded. Check its expiry time
  297. block, _ := pem.Decode(certBytes)
  298. if block != nil {
  299. cert, err := x509.ParseCertificate(block.Bytes)
  300. if err == nil {
  301. elapsed := time.Since(cert.NotAfter)
  302. if elapsed > 0 {
  303. // if it is expired then add it in
  304. // make sure it's uniqueless
  305. for _, dnsName := range cert.DNSNames {
  306. if !contains(expiredCerts, dnsName) {
  307. expiredCerts = append(expiredCerts, dnsName)
  308. }
  309. }
  310. if !contains(expiredCerts, cert.Subject.CommonName) {
  311. expiredCerts = append(expiredCerts, cert.Subject.CommonName)
  312. }
  313. }
  314. }
  315. }
  316. }
  317. }
  318. return expiredCerts
  319. }
  320. // return the current port number
  321. func (a *ACMEHandler) Getport() string {
  322. return a.Port
  323. }
  // contains reports whether str equals any element of slice.
func contains(slice []string, str string) bool {
	for i := 0; i < len(slice); i++ {
		if slice[i] == str {
			return true
		}
	}
	return false
}
  333. // HandleGetExpiredDomains handles the HTTP GET request to retrieve the list of expired domains.
  334. // It calls the CheckCertificate method to obtain the expired domains and sends a JSON response
  335. // containing the list of expired domains.
  336. func (a *ACMEHandler) HandleGetExpiredDomains(w http.ResponseWriter, r *http.Request) {
  337. type ExpiredDomains struct {
  338. Domain []string `json:"domain"`
  339. }
  340. info := ExpiredDomains{
  341. Domain: a.CheckCertificate(),
  342. }
  343. js, _ := json.MarshalIndent(info, "", " ")
  344. utils.SendJSONResponse(w, string(js))
  345. }
  346. // HandleRenewCertificate handles the HTTP GET request to renew a certificate for the provided domains.
  347. // It retrieves the domains and filename parameters from the request, calls the ObtainCert method
  348. // to renew the certificate, and sends a JSON response indicating the result of the renewal process.
  349. func (a *ACMEHandler) HandleRenewCertificate(w http.ResponseWriter, r *http.Request) {
  350. domainPara, err := utils.PostPara(r, "domains")
  351. if err != nil {
  352. utils.SendErrorResponse(w, jsonEscape(err.Error()))
  353. return
  354. }
  355. filename, err := utils.PostPara(r, "filename")
  356. if err != nil {
  357. utils.SendErrorResponse(w, jsonEscape(err.Error()))
  358. return
  359. }
  360. //Make sure the wildcard * do not goes into the filename
  361. filename = strings.ReplaceAll(filename, "*", "_")
  362. email, err := utils.PostPara(r, "email")
  363. if err != nil {
  364. utils.SendErrorResponse(w, jsonEscape(err.Error()))
  365. return
  366. }
  367. var caUrl string
  368. ca, err := utils.PostPara(r, "ca")
  369. if err != nil {
  370. a.Logf("CA not set. Using default", nil)
  371. ca, caUrl = "", ""
  372. }
  373. if ca == "custom" {
  374. caUrl, err = utils.PostPara(r, "caURL")
  375. if err != nil {
  376. a.Logf("Custom CA set but no URL provide, Using default", nil)
  377. ca, caUrl = "", ""
  378. }
  379. }
  380. if ca == "" {
  381. //default. Use Let's Encrypt
  382. ca = "Let's Encrypt"
  383. }
  384. var skipTLS bool
  385. if skipTLSString, err := utils.PostPara(r, "skipTLS"); err != nil {
  386. skipTLS = false
  387. } else if skipTLSString != "true" {
  388. skipTLS = false
  389. } else {
  390. skipTLS = true
  391. }
  392. var dns bool
  393. if dnsString, err := utils.PostPara(r, "dns"); err != nil {
  394. dns = false
  395. } else if dnsString != "true" {
  396. dns = false
  397. } else {
  398. dns = true
  399. }
  400. domains := strings.Split(domainPara, ",")
  401. // Default propagation timeout is 300 seconds
  402. propagationTimeout := 300
  403. if dns {
  404. ppgTimeout, err := utils.PostPara(r, "ppgTimeout")
  405. if err == nil {
  406. propagationTimeout, err = strconv.Atoi(ppgTimeout)
  407. if err != nil {
  408. utils.SendErrorResponse(w, "Invalid propagation timeout value")
  409. return
  410. }
  411. if propagationTimeout < 60 {
  412. //Minimum propagation timeout is 60 seconds
  413. propagationTimeout = 60
  414. }
  415. }
  416. }
  417. //Clean spaces in front or behind each domain
  418. cleanedDomains := []string{}
  419. for _, domain := range domains {
  420. cleanedDomains = append(cleanedDomains, strings.TrimSpace(domain))
  421. }
  422. result, err := a.ObtainCert(cleanedDomains, filename, email, ca, caUrl, skipTLS, dns, propagationTimeout)
  423. if err != nil {
  424. utils.SendErrorResponse(w, jsonEscape(err.Error()))
  425. return
  426. }
  427. utils.SendJSONResponse(w, strconv.FormatBool(result))
  428. }
  // jsonEscape returns i escaped for use inside a JSON string literal: the
// json.Marshal encoding of i without its surrounding quotes. If marshalling
// fails, i is returned unchanged.
func jsonEscape(i string) string {
	quoted, err := json.Marshal(i)
	if err != nil {
		return i
	}
	return string(quoted[1 : len(quoted)-1])
}
  // IsPortInUse reports whether a TCP listener cannot be opened on the given
// port on all interfaces. Any bind failure counts as "in use". On success the
// probe listener is closed again immediately.
func IsPortInUse(port int) bool {
	ln, err := net.Listen("tcp", fmt.Sprintf(":%d", port))
	if err == nil {
		ln.Close()
		return false
	}
	return true
}
  449. // Load cert information from json file
  450. func LoadCertInfoJSON(filename string) (*CertificateInfoJSON, error) {
  451. certInfoBytes, err := os.ReadFile(filename)
  452. if err != nil {
  453. return nil, err
  454. }
  455. certInfo := &CertificateInfoJSON{}
  456. if err = json.Unmarshal(certInfoBytes, certInfo); err != nil {
  457. return nil, err
  458. }
  459. return certInfo, nil
  460. }