acme.go 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230
  1. package acme
  2. import (
  3. "crypto"
  4. "crypto/ecdsa"
  5. "crypto/elliptic"
  6. "crypto/rand"
  7. "crypto/x509"
  8. "encoding/json"
  9. "encoding/pem"
  10. "io/ioutil"
  11. "log"
  12. "net/http"
  13. "os"
  14. "path/filepath"
  15. "strconv"
  16. "strings"
  17. "time"
  18. "github.com/go-acme/lego/v4/certcrypto"
  19. "github.com/go-acme/lego/v4/certificate"
  20. "github.com/go-acme/lego/v4/challenge/http01"
  21. "github.com/go-acme/lego/v4/lego"
  22. "github.com/go-acme/lego/v4/registration"
  23. "imuslab.com/zoraxy/mod/utils"
  24. )
  25. // You'll need a user or account type that implements acme.User
  26. type ACMEUser struct {
  27. Email string
  28. Registration *registration.Resource
  29. key crypto.PrivateKey
  30. }
  31. func (u *ACMEUser) GetEmail() string {
  32. return u.Email
  33. }
  34. func (u ACMEUser) GetRegistration() *registration.Resource {
  35. return u.Registration
  36. }
  37. func (u *ACMEUser) GetPrivateKey() crypto.PrivateKey {
  38. return u.key
  39. }
  40. type ACMEHandler struct {
  41. email string
  42. acmeServer string
  43. port string
  44. }
  45. func NewACME(email string, acmeServer string, port string) *ACMEHandler {
  46. return &ACMEHandler{
  47. email: email,
  48. acmeServer: acmeServer,
  49. port: port,
  50. }
  51. }
  52. func (a *ACMEHandler) ObtainCert(domains []string, certificateName string) (bool, error) {
  53. log.Println("Obtaining certificate...")
  54. privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
  55. if err != nil {
  56. log.Println(err)
  57. return false, err
  58. }
  59. log.Println(a.acmeServer)
  60. adminUser := ACMEUser{
  61. Email: a.email,
  62. key: privateKey,
  63. }
  64. config := lego.NewConfig(&adminUser)
  65. config.CADirURL = a.acmeServer
  66. config.Certificate.KeyType = certcrypto.RSA2048
  67. client, err := lego.NewClient(config)
  68. if err != nil {
  69. log.Println(err)
  70. return false, err
  71. }
  72. err = client.Challenge.SetHTTP01Provider(http01.NewProviderServer("", a.port))
  73. if err != nil {
  74. log.Println(err)
  75. return false, err
  76. }
  77. // New users will need to register
  78. reg, err := client.Registration.Register(registration.RegisterOptions{TermsOfServiceAgreed: true})
  79. if err != nil {
  80. log.Println(err)
  81. return false, err
  82. }
  83. adminUser.Registration = reg
  84. request := certificate.ObtainRequest{
  85. Domains: domains,
  86. Bundle: true,
  87. }
  88. certificates, err := client.Certificate.Obtain(request)
  89. if err != nil {
  90. log.Println(err)
  91. return false, err
  92. }
  93. // Each certificate comes back with the cert bytes, the bytes of the client's
  94. // private key, and a certificate URL. SAVE THESE TO DISK.
  95. err = ioutil.WriteFile("./certs/"+certificateName+".crt", certificates.Certificate, 0777)
  96. if err != nil {
  97. log.Println(err)
  98. return false, err
  99. }
  100. err = ioutil.WriteFile("./certs/"+certificateName+".key", certificates.PrivateKey, 0777)
  101. if err != nil {
  102. log.Println(err)
  103. return false, err
  104. }
  105. return true, nil
  106. }
  107. // Return a list of domains that is in expired certificates
  108. func (a *ACMEHandler) CheckCertificate() []string {
  109. filenames, err := os.ReadDir("./certs/")
  110. expiredCerts := []string{}
  111. if err != nil {
  112. log.Println(err)
  113. return []string{}
  114. }
  115. for _, filename := range filenames {
  116. certFilepath := filepath.Join("./certs/", filename.Name())
  117. certBtyes, err := os.ReadFile(certFilepath)
  118. if err != nil {
  119. //Unable to load this file
  120. continue
  121. } else {
  122. //Cert loaded. Check its expire time
  123. block, _ := pem.Decode(certBtyes)
  124. if block != nil {
  125. cert, err := x509.ParseCertificate(block.Bytes)
  126. if err == nil {
  127. elapsed := time.Since(cert.NotAfter)
  128. //approxMonths := -int(elapsed.Hours() / (24 * 30.44))
  129. //approxDays := -int(elapsed.Hours()/24) % 30
  130. if elapsed > 0 {
  131. //log.Println("Certificate", certFilepath, " expired")
  132. for _, dnsName := range cert.DNSNames {
  133. if !contains(expiredCerts, dnsName) {
  134. expiredCerts = append(expiredCerts, dnsName)
  135. }
  136. }
  137. if !contains(expiredCerts, cert.Subject.CommonName) {
  138. expiredCerts = append(expiredCerts, cert.Subject.CommonName)
  139. }
  140. } else {
  141. //log.Println("Certificate", certFilepath, " will still vaild for the next ", approxMonths, "m", approxDays, "d")
  142. }
  143. }
  144. }
  145. }
  146. }
  147. return expiredCerts
  148. }
  149. func (a *ACMEHandler) Getport() string {
  150. return a.port
  151. }
  152. func contains(slice []string, str string) bool {
  153. for _, s := range slice {
  154. if s == str {
  155. return true
  156. }
  157. }
  158. return false
  159. }
  160. func (a *ACMEHandler) HandleGetExpiredDomains(w http.ResponseWriter, r *http.Request) {
  161. type ExpiredDomains struct {
  162. Domain []string `json:"domain"`
  163. }
  164. info := ExpiredDomains{
  165. Domain: a.CheckCertificate(),
  166. }
  167. js, _ := json.MarshalIndent(info, "", " ")
  168. utils.SendJSONResponse(w, string(js))
  169. }
  170. func (a *ACMEHandler) HandleRenewCertificate(w http.ResponseWriter, r *http.Request) {
  171. domainPara, err := utils.GetPara(r, "domains")
  172. if err != nil {
  173. utils.SendErrorResponse(w, jsonEscape(err.Error()))
  174. return
  175. }
  176. filename, err := utils.GetPara(r, "filename")
  177. if err != nil {
  178. utils.SendErrorResponse(w, jsonEscape(err.Error()))
  179. return
  180. }
  181. domains := strings.Split(domainPara, ",")
  182. result, err := a.ObtainCert(domains, filename)
  183. if err != nil {
  184. utils.SendErrorResponse(w, jsonEscape(err.Error()))
  185. return
  186. }
  187. utils.SendJSONResponse(w, strconv.FormatBool(result))
  188. }
  189. func jsonEscape(i string) string {
  190. b, err := json.Marshal(i)
  191. if err != nil {
  192. panic(err)
  193. }
  194. s := string(b)
  195. return s[1 : len(s)-1]
  196. }