acme.go 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439
  1. package acme
  2. import (
  3. "crypto"
  4. "crypto/ecdsa"
  5. "crypto/elliptic"
  6. "crypto/rand"
  7. "crypto/tls"
  8. "crypto/x509"
  9. "encoding/json"
  10. "encoding/pem"
  11. "errors"
  12. "fmt"
  13. "log"
  14. "net"
  15. "net/http"
  16. "os"
  17. "path/filepath"
  18. "strconv"
  19. "strings"
  20. "time"
  21. "github.com/go-acme/lego/v4/certcrypto"
  22. "github.com/go-acme/lego/v4/certificate"
  23. "github.com/go-acme/lego/v4/challenge/http01"
  24. "github.com/go-acme/lego/v4/lego"
  25. "github.com/go-acme/lego/v4/registration"
  26. "imuslab.com/zoraxy/mod/database"
  27. "imuslab.com/zoraxy/mod/utils"
  28. )
// CertificateInfoJSON is the ACME metadata that ObtainCert writes next to an
// issued certificate as <certificateName>.json, recording which CA settings
// were used so the certificate can later be renewed the same way.
type CertificateInfoJSON struct {
	AcmeName string `json:"acme_name"` // CA name passed to ObtainCert ("custom" selects AcmeUrl)
	AcmeUrl  string `json:"acme_url"`  // ACME directory URL given for a custom CA; may be empty
	SkipTLS  bool   `json:"skip_tls"`  // whether TLS verification of the ACME server was skipped
}
// ACMEUser represents a user in the ACME system.
// ObtainCert passes a pointer to one to lego.NewConfig as the ACME account.
type ACMEUser struct {
	Email        string                 // contact e-mail address of the account
	Registration *registration.Resource // set once the account is registered; nil before that
	key          crypto.PrivateKey      // account private key (ObtainCert generates an ECDSA P-256 key)
}
// EABConfig holds External Account Binding (EAB) credentials for ACME CAs
// that require an external account.
// NOTE(review): not referenced in the visible part of this file (ObtainCert
// reads kid/HMAC straight from the database) — confirm where it is used.
type EABConfig struct {
	Kid     string `json:"kid"`     // EAB key identifier
	HmacKey string `json:"HmacKey"` // EAB HMAC key; note the JSON key is capitalized, unlike "kid"
}
  44. // GetEmail returns the email of the ACMEUser.
  45. func (u *ACMEUser) GetEmail() string {
  46. return u.Email
  47. }
  48. // GetRegistration returns the registration resource of the ACMEUser.
  49. func (u ACMEUser) GetRegistration() *registration.Resource {
  50. return u.Registration
  51. }
  52. // GetPrivateKey returns the private key of the ACMEUser.
  53. func (u *ACMEUser) GetPrivateKey() crypto.PrivateKey {
  54. return u.key
  55. }
// ACMEHandler handles ACME-related operations.
type ACMEHandler struct {
	DefaultAcmeServer string             // ACME directory URL used when no CA-specific URL is found
	Port              string             // port on which ObtainCert serves HTTP-01 challenge responses
	Database          *database.Database // holds per-CA EAB credentials in the "acme" table
}
  62. // NewACME creates a new ACMEHandler instance.
  63. func NewACME(acmeServer string, port string, database *database.Database) *ACMEHandler {
  64. return &ACMEHandler{
  65. DefaultAcmeServer: acmeServer,
  66. Port: port,
  67. Database: database,
  68. }
  69. }
  70. // ObtainCert obtains a certificate for the specified domains.
  71. func (a *ACMEHandler) ObtainCert(domains []string, certificateName string, email string, caName string, caUrl string, skipTLS bool) (bool, error) {
  72. log.Println("[ACME] Obtaining certificate...")
  73. // generate private key
  74. privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
  75. if err != nil {
  76. log.Println(err)
  77. return false, err
  78. }
  79. // create a admin user for our new generation
  80. adminUser := ACMEUser{
  81. Email: email,
  82. key: privateKey,
  83. }
  84. // create config
  85. config := lego.NewConfig(&adminUser)
  86. // skip TLS verify if need
  87. // Ref: https://github.com/go-acme/lego/blob/6af2c756ac73a9cb401621afca722d0f4112b1b8/lego/client_config.go#L74
  88. if skipTLS {
  89. log.Println("[INFO] Ignore TLS/SSL Verification Error for ACME Server")
  90. config.HTTPClient.Transport = &http.Transport{
  91. Proxy: http.ProxyFromEnvironment,
  92. DialContext: (&net.Dialer{
  93. Timeout: 30 * time.Second,
  94. KeepAlive: 30 * time.Second,
  95. }).DialContext,
  96. TLSHandshakeTimeout: 30 * time.Second,
  97. ResponseHeaderTimeout: 30 * time.Second,
  98. TLSClientConfig: &tls.Config{
  99. InsecureSkipVerify: true,
  100. },
  101. }
  102. }
  103. // setup the custom ACME url endpoint.
  104. if caUrl != "" {
  105. config.CADirURL = caUrl
  106. }
  107. // if not custom ACME url, load it from ca.json
  108. if caName == "custom" {
  109. log.Println("[INFO] Using Custom ACME " + caUrl + " for CA Directory URL")
  110. } else {
  111. caLinkOverwrite, err := loadCAApiServerFromName(caName)
  112. if err == nil {
  113. config.CADirURL = caLinkOverwrite
  114. log.Println("[INFO] Using " + caLinkOverwrite + " for CA Directory URL")
  115. } else {
  116. // (caName == "" || caUrl == "") will use default acme
  117. config.CADirURL = a.DefaultAcmeServer
  118. log.Println("[INFO] Using Default ACME " + a.DefaultAcmeServer + " for CA Directory URL")
  119. }
  120. }
  121. config.Certificate.KeyType = certcrypto.RSA2048
  122. client, err := lego.NewClient(config)
  123. if err != nil {
  124. log.Println(err)
  125. return false, err
  126. }
  127. // setup how to receive challenge
  128. err = client.Challenge.SetHTTP01Provider(http01.NewProviderServer("", a.Port))
  129. if err != nil {
  130. log.Println(err)
  131. return false, err
  132. }
  133. // New users will need to register
  134. /*
  135. reg, err := client.Registration.Register(registration.RegisterOptions{TermsOfServiceAgreed: true})
  136. if err != nil {
  137. log.Println(err)
  138. return false, err
  139. }
  140. */
  141. var reg *registration.Resource
  142. // New users will need to register
  143. if client.GetExternalAccountRequired() {
  144. log.Println("External Account Required for this ACME Provider.")
  145. // IF KID and HmacEncoded is overidden
  146. if !a.Database.TableExists("acme") {
  147. a.Database.NewTable("acme")
  148. return false, errors.New("kid and HmacEncoded configuration required for ACME Provider (Error -1)")
  149. }
  150. if !a.Database.KeyExists("acme", config.CADirURL+"_kid") || !a.Database.KeyExists("acme", config.CADirURL+"_hmacEncoded") {
  151. return false, errors.New("kid and HmacEncoded configuration required for ACME Provider (Error -2)")
  152. }
  153. var kid string
  154. var hmacEncoded string
  155. err := a.Database.Read("acme", config.CADirURL+"_kid", &kid)
  156. if err != nil {
  157. log.Println(err)
  158. return false, err
  159. }
  160. err = a.Database.Read("acme", config.CADirURL+"_hmacEncoded", &hmacEncoded)
  161. if err != nil {
  162. log.Println(err)
  163. return false, err
  164. }
  165. log.Println("EAB Credential retrieved.", kid, hmacEncoded)
  166. if kid != "" && hmacEncoded != "" {
  167. reg, err = client.Registration.RegisterWithExternalAccountBinding(registration.RegisterEABOptions{
  168. TermsOfServiceAgreed: true,
  169. Kid: kid,
  170. HmacEncoded: hmacEncoded,
  171. })
  172. }
  173. if err != nil {
  174. log.Println(err)
  175. return false, err
  176. }
  177. //return false, errors.New("External Account Required for this ACME Provider.")
  178. } else {
  179. reg, err = client.Registration.Register(registration.RegisterOptions{TermsOfServiceAgreed: true})
  180. if err != nil {
  181. log.Println(err)
  182. return false, err
  183. }
  184. }
  185. adminUser.Registration = reg
  186. // obtain the certificate
  187. request := certificate.ObtainRequest{
  188. Domains: domains,
  189. Bundle: true,
  190. }
  191. certificates, err := client.Certificate.Obtain(request)
  192. if err != nil {
  193. log.Println(err)
  194. return false, err
  195. }
  196. // Each certificate comes back with the cert bytes, the bytes of the client's
  197. // private key, and a certificate URL.
  198. err = os.WriteFile("./conf/certs/"+certificateName+".pem", certificates.Certificate, 0777)
  199. if err != nil {
  200. log.Println(err)
  201. return false, err
  202. }
  203. err = os.WriteFile("./conf/certs/"+certificateName+".key", certificates.PrivateKey, 0777)
  204. if err != nil {
  205. log.Println(err)
  206. return false, err
  207. }
  208. // Save certificate's ACME info for renew usage
  209. certInfo := &CertificateInfoJSON{
  210. AcmeName: caName,
  211. AcmeUrl: caUrl,
  212. SkipTLS: skipTLS,
  213. }
  214. certInfoBytes, err := json.Marshal(certInfo)
  215. if err != nil {
  216. log.Println(err)
  217. return false, err
  218. }
  219. err = os.WriteFile("./conf/certs/"+certificateName+".json", certInfoBytes, 0777)
  220. if err != nil {
  221. log.Println(err)
  222. return false, err
  223. }
  224. return true, nil
  225. }
  226. // CheckCertificate returns a list of domains that are in expired certificates.
  227. // It will return all domains that is in expired certificates
  228. // *** if there is a vaild certificate contains the domain and there is a expired certificate contains the same domain
  229. // it will said expired as well!
  230. func (a *ACMEHandler) CheckCertificate() []string {
  231. // read from dir
  232. filenames, err := os.ReadDir("./conf/certs/")
  233. expiredCerts := []string{}
  234. if err != nil {
  235. log.Println(err)
  236. return []string{}
  237. }
  238. for _, filename := range filenames {
  239. certFilepath := filepath.Join("./conf/certs/", filename.Name())
  240. certBytes, err := os.ReadFile(certFilepath)
  241. if err != nil {
  242. // Unable to load this file
  243. continue
  244. } else {
  245. // Cert loaded. Check its expiry time
  246. block, _ := pem.Decode(certBytes)
  247. if block != nil {
  248. cert, err := x509.ParseCertificate(block.Bytes)
  249. if err == nil {
  250. elapsed := time.Since(cert.NotAfter)
  251. if elapsed > 0 {
  252. // if it is expired then add it in
  253. // make sure it's uniqueless
  254. for _, dnsName := range cert.DNSNames {
  255. if !contains(expiredCerts, dnsName) {
  256. expiredCerts = append(expiredCerts, dnsName)
  257. }
  258. }
  259. if !contains(expiredCerts, cert.Subject.CommonName) {
  260. expiredCerts = append(expiredCerts, cert.Subject.CommonName)
  261. }
  262. }
  263. }
  264. }
  265. }
  266. }
  267. return expiredCerts
  268. }
  269. // return the current port number
  270. func (a *ACMEHandler) Getport() string {
  271. return a.Port
  272. }
  273. // contains checks if a string is present in a slice.
  274. func contains(slice []string, str string) bool {
  275. for _, s := range slice {
  276. if s == str {
  277. return true
  278. }
  279. }
  280. return false
  281. }
  282. // HandleGetExpiredDomains handles the HTTP GET request to retrieve the list of expired domains.
  283. // It calls the CheckCertificate method to obtain the expired domains and sends a JSON response
  284. // containing the list of expired domains.
  285. func (a *ACMEHandler) HandleGetExpiredDomains(w http.ResponseWriter, r *http.Request) {
  286. type ExpiredDomains struct {
  287. Domain []string `json:"domain"`
  288. }
  289. info := ExpiredDomains{
  290. Domain: a.CheckCertificate(),
  291. }
  292. js, _ := json.MarshalIndent(info, "", " ")
  293. utils.SendJSONResponse(w, string(js))
  294. }
  295. // HandleRenewCertificate handles the HTTP GET request to renew a certificate for the provided domains.
  296. // It retrieves the domains and filename parameters from the request, calls the ObtainCert method
  297. // to renew the certificate, and sends a JSON response indicating the result of the renewal process.
  298. func (a *ACMEHandler) HandleRenewCertificate(w http.ResponseWriter, r *http.Request) {
  299. domainPara, err := utils.PostPara(r, "domains")
  300. if err != nil {
  301. utils.SendErrorResponse(w, jsonEscape(err.Error()))
  302. return
  303. }
  304. filename, err := utils.PostPara(r, "filename")
  305. if err != nil {
  306. utils.SendErrorResponse(w, jsonEscape(err.Error()))
  307. return
  308. }
  309. email, err := utils.PostPara(r, "email")
  310. if err != nil {
  311. utils.SendErrorResponse(w, jsonEscape(err.Error()))
  312. return
  313. }
  314. var caUrl string
  315. ca, err := utils.PostPara(r, "ca")
  316. if err != nil {
  317. log.Println("[INFO] CA not set. Using default")
  318. ca, caUrl = "", ""
  319. }
  320. if ca == "custom" {
  321. caUrl, err = utils.PostPara(r, "caURL")
  322. if err != nil {
  323. log.Println("[INFO] Custom CA set but no URL provide, Using default")
  324. ca, caUrl = "", ""
  325. }
  326. }
  327. if ca == "" {
  328. //default. Use Let's Encrypt
  329. ca = "Let's Encrypt"
  330. }
  331. var skipTLS bool
  332. if skipTLSString, err := utils.PostPara(r, "skipTLS"); err != nil {
  333. skipTLS = false
  334. } else if skipTLSString != "true" {
  335. skipTLS = false
  336. } else {
  337. skipTLS = true
  338. }
  339. domains := strings.Split(domainPara, ",")
  340. result, err := a.ObtainCert(domains, filename, email, ca, caUrl, skipTLS)
  341. if err != nil {
  342. utils.SendErrorResponse(w, jsonEscape(err.Error()))
  343. return
  344. }
  345. utils.SendJSONResponse(w, strconv.FormatBool(result))
  346. }
  347. // Escape JSON string
  348. func jsonEscape(i string) string {
  349. b, err := json.Marshal(i)
  350. if err != nil {
  351. log.Println("Unable to escape json data: " + err.Error())
  352. return i
  353. }
  354. s := string(b)
  355. return s[1 : len(s)-1]
  356. }
  357. // Helper function to check if a port is in use
  358. func IsPortInUse(port int) bool {
  359. address := fmt.Sprintf(":%d", port)
  360. listener, err := net.Listen("tcp", address)
  361. if err != nil {
  362. return true // Port is in use
  363. }
  364. defer listener.Close()
  365. return false // Port is not in use
  366. }
  367. // Load cert information from json file
  368. func loadCertInfoJSON(filename string) (*CertificateInfoJSON, error) {
  369. certInfoBytes, err := os.ReadFile(filename)
  370. if err != nil {
  371. return nil, err
  372. }
  373. certInfo := &CertificateInfoJSON{}
  374. if err = json.Unmarshal(certInfoBytes, certInfo); err != nil {
  375. return nil, err
  376. }
  377. return certInfo, nil
  378. }