1
0
Эх сурвалжийг харах

auto update script executed

Toby Chui 1 жил өмнө
parent
commit
0079b4ab3e
8 өөрчлөгдсөн 245 нэмэгдсэн , 48 устгасан
  1. 2 2
      acme.go
  2. 1 0
      cert.go
  3. 1 1
      main.go
  4. 24 16
      mod/acme/acme.go
  5. 148 23
      mod/acme/autorenew.go
  6. 45 3
      mod/acme/utils.go
  7. 1 1
      start.go
  8. 23 2
      web/snippet/acme.html

+ 2 - 2
acme.go

@@ -28,7 +28,7 @@ func getRandomPort(minPort int) int {
 
 // init the new ACME instance
 func initACME() *acme.ACMEHandler {
-	log.Println("Start initializing ACME")
+	log.Println("Starting ACME handler")
 	rand.Seed(time.Now().UnixNano())
 	// Generate a random port above 30000
 	port := getRandomPort(30000)
@@ -38,7 +38,7 @@ func initACME() *acme.ACMEHandler {
 		port = getRandomPort(30000)
 	}
 
-	return acme.NewACME("[email protected]", "https://acme-staging-v02.api.letsencrypt.org/directory", strconv.Itoa(port))
+	return acme.NewACME("https://acme-staging-v02.api.letsencrypt.org/directory", strconv.Itoa(port))
 }
 
 // create the special routing rule for ACME

+ 1 - 0
cert.go

@@ -128,6 +128,7 @@ func handleListDomains(w http.ResponseWriter, r *http.Request) {
 		certBtyes, err := os.ReadFile(certFilepath)
 		if err != nil {
 			// Unable to load this file
+			log.Println("Unable to load certificate: " + certFilepath)
 			continue
 		} else {
 			// Cert loaded. Check its expiry time

+ 1 - 1
main.go

@@ -38,7 +38,7 @@ var showver = flag.Bool("version", false, "Show version of this server")
 var allowSshLoopback = flag.Bool("sshlb", false, "Allow loopback web ssh connection (DANGER)")
 var ztAuthToken = flag.String("ztauth", "", "ZeroTier authtoken for the local node")
 var ztAPIPort = flag.Int("ztport", 9993, "ZeroTier controller API port")
-var acmeAutoRenewInterval = flag.Int("autorenew", 86400, "ACME auto TLS/SSL certificate renew check interval")
+var acmeAutoRenewInterval = flag.Int("autorenew", 86400, "ACME auto TLS/SSL certificate renew check interval (seconds)")
 var (
 	name        = "Zoraxy"
 	version     = "2.6.5"

+ 24 - 16
mod/acme/acme.go

@@ -8,6 +8,7 @@ import (
 	"crypto/x509"
 	"encoding/json"
 	"encoding/pem"
+	"errors"
 	"fmt"
 	"io/ioutil"
 	"log"
@@ -51,23 +52,21 @@ func (u *ACMEUser) GetPrivateKey() crypto.PrivateKey {
 
 // ACMEHandler handles ACME-related operations.
 type ACMEHandler struct {
-	email      string
-	acmeServer string
-	port       string
+	DefaultAcmeServer string
+	Port              string
 }
 
 // NewACME creates a new ACMEHandler instance.
-func NewACME(email string, acmeServer string, port string) *ACMEHandler {
+func NewACME(acmeServer string, port string) *ACMEHandler {
 	return &ACMEHandler{
-		email:      email,
-		acmeServer: acmeServer,
-		port:       port,
+		DefaultAcmeServer: acmeServer,
+		Port:              port,
 	}
 }
 
 // ObtainCert obtains a certificate for the specified domains.
-func (a *ACMEHandler) ObtainCert(domains []string, certificateName string, ca string) (bool, error) {
-	log.Println("Obtaining certificate...")
+func (a *ACMEHandler) ObtainCert(domains []string, certificateName string, email string, ca string) (bool, error) {
+	log.Println("[ACME] Obtaining certificate...")
 
 	// generate private key
 	privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
@@ -78,7 +77,7 @@ func (a *ACMEHandler) ObtainCert(domains []string, certificateName string, ca st
 
 	// create a admin user for our new generation
 	adminUser := ACMEUser{
-		Email: a.email,
+		Email: email,
 		key:   privateKey,
 	}
 
@@ -86,7 +85,7 @@ func (a *ACMEHandler) ObtainCert(domains []string, certificateName string, ca st
 	config := lego.NewConfig(&adminUser)
 
 	// setup who is the issuer and the key type
-	config.CADirURL = a.acmeServer
+	config.CADirURL = a.DefaultAcmeServer
 
 	//Overwrite the CADir URL if set
 	if ca != "" {
@@ -94,6 +93,8 @@ func (a *ACMEHandler) ObtainCert(domains []string, certificateName string, ca st
 		if err == nil {
 			config.CADirURL = caLinkOverwrite
 			log.Println("[INFO] Using " + caLinkOverwrite + " for CA Directory URL")
+		} else {
+			return false, errors.New("CA " + ca + " is not supported. Please contribute to the source code and add this CA's directory link.")
 		}
 	}
 
@@ -106,7 +107,7 @@ func (a *ACMEHandler) ObtainCert(domains []string, certificateName string, ca st
 	}
 
 	// setup how to receive challenge
-	err = client.Challenge.SetHTTP01Provider(http01.NewProviderServer("", a.port))
+	err = client.Challenge.SetHTTP01Provider(http01.NewProviderServer("", a.Port))
 	if err != nil {
 		log.Println(err)
 		return false, err
@@ -165,13 +166,13 @@ func (a *ACMEHandler) CheckCertificate() []string {
 	for _, filename := range filenames {
 		certFilepath := filepath.Join("./certs/", filename.Name())
 
-		certBtyes, err := os.ReadFile(certFilepath)
+		certBytes, err := os.ReadFile(certFilepath)
 		if err != nil {
 			// Unable to load this file
 			continue
 		} else {
 			// Cert loaded. Check its expiry time
-			block, _ := pem.Decode(certBtyes)
+			block, _ := pem.Decode(certBytes)
 			if block != nil {
 				cert, err := x509.ParseCertificate(block.Bytes)
 				if err == nil {
@@ -198,7 +199,7 @@ func (a *ACMEHandler) CheckCertificate() []string {
 
 // return the current port number
 func (a *ACMEHandler) Getport() string {
-	return a.port
+	return a.Port
 }
 
 // contains checks if a string is present in a slice.
@@ -236,12 +237,19 @@ func (a *ACMEHandler) HandleRenewCertificate(w http.ResponseWriter, r *http.Requ
 		utils.SendErrorResponse(w, jsonEscape(err.Error()))
 		return
 	}
+
 	filename, err := utils.PostPara(r, "filename")
 	if err != nil {
 		utils.SendErrorResponse(w, jsonEscape(err.Error()))
 		return
 	}
 
+	email, err := utils.PostPara(r, "email")
+	if err != nil {
+		utils.SendErrorResponse(w, jsonEscape(err.Error()))
+		return
+	}
+
 	ca, err := utils.PostPara(r, "ca")
 	if err != nil {
 		log.Println("CA not set. Using default (Let's Encrypt)")
@@ -249,7 +257,7 @@ func (a *ACMEHandler) HandleRenewCertificate(w http.ResponseWriter, r *http.Requ
 	}
 
 	domains := strings.Split(domainPara, ",")
-	result, err := a.ObtainCert(domains, filename, ca)
+	result, err := a.ObtainCert(domains, filename, email, ca)
 	if err != nil {
 		utils.SendErrorResponse(w, jsonEscape(err.Error()))
 		return

+ 148 - 23
mod/acme/autorenew.go

@@ -3,12 +3,12 @@ package acme
 import (
 	"encoding/json"
 	"errors"
-	"fmt"
 	"log"
 	"net/http"
 	"net/mail"
 	"os"
 	"path/filepath"
+	"strings"
 	"time"
 
 	"imuslab.com/zoraxy/mod/utils"
@@ -28,22 +28,27 @@ type AutoRenewConfig struct {
 }
 
 type AutoRenewer struct {
-	ConfigFilePath string
-	CertFolder     string
-	RenewerConfig  *AutoRenewConfig
-	TickerstopChan chan bool
+	ConfigFilePath    string
+	CertFolder        string
+	AcmeHandler       *ACMEHandler
+	RenewerConfig     *AutoRenewConfig
+	RenewTickInterval int64
+	TickerstopChan    chan bool
+}
+
+type ExpiredCerts struct {
+	Domains  []string
+	Filepath string
+	CA       string
 }
 
 // Create an auto renew agent, require config filepath and auto scan & renew interval (seconds)
 // Set renew check interval to 0 for auto (1 day)
-func NewAutoRenewer(config string, certFolder string, renewCheckInterval int64) (*AutoRenewer, error) {
+func NewAutoRenewer(config string, certFolder string, renewCheckInterval int64, AcmeHandler *ACMEHandler) (*AutoRenewer, error) {
 	if renewCheckInterval == 0 {
 		renewCheckInterval = 86400 //1 day
 	}
 
-	ticker := time.NewTicker(time.Duration(renewCheckInterval) * time.Second)
-	done := make(chan bool)
-
 	//Load the config file. If not found, create one
 	if !utils.FileExists(config) {
 		//Create one
@@ -72,29 +77,57 @@ func NewAutoRenewer(config string, certFolder string, renewCheckInterval int64)
 
 	//Create an Auto renew object
 	thisRenewer := AutoRenewer{
-		ConfigFilePath: config,
-		CertFolder:     certFolder,
-		RenewerConfig:  &renewerConfig,
-		TickerstopChan: done,
+		ConfigFilePath:    config,
+		CertFolder:        certFolder,
+		AcmeHandler:       AcmeHandler,
+		RenewerConfig:     &renewerConfig,
+		RenewTickInterval: renewCheckInterval,
 	}
 
-	//Check and renew certificate on startup
-	thisRenewer.CheckAndRenewCertificates()
+	if thisRenewer.RenewerConfig.Enabled {
+		//Start the renew ticker
+		thisRenewer.StartAutoRenewTicker()
+
+		//Check and renew certificate on startup
+		go thisRenewer.CheckAndRenewCertificates()
+	}
+
+	return &thisRenewer, nil
+}
+
+func (a *AutoRenewer) StartAutoRenewTicker() {
+	//Stop the previous ticker if still running
+	if a.TickerstopChan != nil {
+		a.TickerstopChan <- true
+	}
+
+	time.Sleep(1 * time.Second)
+
+	ticker := time.NewTicker(time.Duration(a.RenewTickInterval) * time.Second)
+	done := make(chan bool)
 
 	//Start the ticker to check and renew every x seconds
-	go func() {
+	go func(a *AutoRenewer) {
 		for {
 			select {
 			case <-done:
 				return
 			case <-ticker.C:
 				log.Println("Check and renew certificates in progress")
-				thisRenewer.CheckAndRenewCertificates()
+				a.CheckAndRenewCertificates()
 			}
 		}
-	}()
+	}(a)
 
-	return &thisRenewer, nil
+	a.TickerstopChan = done
+}
+
+func (a *AutoRenewer) StopAutoRenewTicker() {
+	if a.TickerstopChan != nil {
+		a.TickerstopChan <- true
+	}
+
+	a.TickerstopChan = nil
 }
 
 // Handle update auto renew domains
@@ -170,11 +203,13 @@ func (a *AutoRenewer) HandleAutoRenewEnable(w http.ResponseWriter, r *http.Reque
 
 			a.RenewerConfig.Enabled = true
 			a.saveRenewConfigToFile()
-			log.Println("[AutoRenew] ACME auto renew enabled")
+			log.Println("[ACME] ACME auto renew enabled")
+			a.StartAutoRenewTicker()
 		} else {
 			a.RenewerConfig.Enabled = false
 			a.saveRenewConfigToFile()
-			log.Println("[AutoRenew] ACME auto renew disabled")
+			log.Println("[ACME] ACME auto renew disabled")
+			a.StopAutoRenewTicker()
 		}
 	}
 }
@@ -212,8 +247,98 @@ func (a *AutoRenewer) CheckAndRenewCertificates() ([]string, error) {
 		return []string{}, err
 	}
 
-	fmt.Println("[ACME DEBUG] Cert found: ", files)
-	return []string{}, nil
+	expiredCertList := []*ExpiredCerts{}
+	if a.RenewerConfig.RenewAll {
+		//Scan and renew all
+		for _, file := range files {
+			if filepath.Ext(file.Name()) == ".crt" || filepath.Ext(file.Name()) == ".pem" {
+				//This is a public key file
+				certBytes, err := os.ReadFile(filepath.Join(certFolder, file.Name()))
+				if err != nil {
+					continue
+				}
+				if CertExpireSoon(certBytes) || CertIsExpired(certBytes) {
+					//This cert is expired
+					CAName, err := ExtractIssuerName(certBytes)
+					if err != nil {
+						//Maybe self signed. Ignore this
+						log.Println("Unable to extract issuer name for cert " + file.Name())
+						continue
+					}
+
+					DNSName, err := ExtractDomains(certBytes)
+					if err != nil {
+						//Maybe self signed. Ignore this
+						log.Println("Encounted error when trying to resolve DNS name for cert " + file.Name())
+						continue
+					}
+
+					expiredCertList = append(expiredCertList, &ExpiredCerts{
+						Filepath: filepath.Join(certFolder, file.Name()),
+						CA:       CAName,
+						Domains:  DNSName,
+					})
+				}
+			}
+		}
+	} else {
+		//Only renew those in the list
+		for _, file := range files {
+			fileName := file.Name()
+			certName := fileName[:len(fileName)-len(filepath.Ext(fileName))]
+			if contains(a.RenewerConfig.FilesToRenew, certName) {
+				//This is the one to auto renew
+				certBytes, err := os.ReadFile(filepath.Join(certFolder, file.Name()))
+				if err != nil {
+					continue
+				}
+				if CertExpireSoon(certBytes) || CertIsExpired(certBytes) {
+					//This cert is expired
+					CAName, err := ExtractIssuerName(certBytes)
+					if err != nil {
+						//Maybe self signed. Ignore this
+						log.Println("Unable to extract issuer name for cert " + file.Name())
+						continue
+					}
+
+					DNSName, err := ExtractDomains(certBytes)
+					if err != nil {
+						//Maybe self signed. Ignore this
+						log.Println("Encounted error when trying to resolve DNS name for cert " + file.Name())
+						continue
+					}
+
+					expiredCertList = append(expiredCertList, &ExpiredCerts{
+						Filepath: filepath.Join(certFolder, file.Name()),
+						CA:       CAName,
+						Domains:  DNSName,
+					})
+				}
+			}
+		}
+	}
+
+	return a.renewExpiredDomains(expiredCertList)
+}
+
+// Renew the certificate by filename extract all DNS name from the
+// certificate and renew them one by one by calling to the acmeHandler
+func (a *AutoRenewer) renewExpiredDomains(certs []*ExpiredCerts) ([]string, error) {
+	renewedCertFiles := []string{}
+	for _, expiredCert := range certs {
+		log.Println("Renewing " + expiredCert.Filepath + " (Might take a few minutes)")
+		fileName := filepath.Base(expiredCert.Filepath)
+		certName := fileName[:len(fileName)-len(filepath.Ext(fileName))]
+		_, err := a.AcmeHandler.ObtainCert(expiredCert.Domains, certName, a.RenewerConfig.Email, expiredCert.CA)
+		if err != nil {
+			log.Println("Renew " + fileName + "(" + strings.Join(expiredCert.Domains, ",") + ") failed: " + err.Error())
+		} else {
+			log.Println("Successfully renewed " + filepath.Base(expiredCert.Filepath))
+			renewedCertFiles = append(renewedCertFiles, filepath.Base(expiredCert.Filepath))
+		}
+	}
+
+	return renewedCertFiles, nil
 }
 
 // Write the current renewer config to file

+ 45 - 3
mod/acme/utils.go

@@ -3,6 +3,7 @@ package acme
 import (
 	"crypto/x509"
 	"encoding/pem"
+	"errors"
 	"fmt"
 	"io/ioutil"
 	"time"
@@ -16,8 +17,32 @@ func ExtractIssuerNameFromPEM(pemFilePath string) (string, error) {
 		return "", err
 	}
 
+	return ExtractIssuerName(pemData)
+}
+
+// Get the DNSName in the cert
+func ExtractDomains(certBytes []byte) ([]string, error) {
+	domains := []string{}
+	block, _ := pem.Decode(certBytes)
+	if block != nil {
+		cert, err := x509.ParseCertificate(block.Bytes)
+		if err != nil {
+			return []string{}, err
+		}
+		for _, dnsName := range cert.DNSNames {
+			if !contains(domains, dnsName) {
+				domains = append(domains, dnsName)
+			}
+		}
+
+		return domains, nil
+	}
+	return []string{}, errors.New("decode cert bytes failed")
+}
+
+func ExtractIssuerName(certBytes []byte) (string, error) {
 	// Parse the PEM block
-	block, _ := pem.Decode(pemData)
+	block, _ := pem.Decode(certBytes)
 	if block == nil || block.Type != "CERTIFICATE" {
 		return "", fmt.Errorf("failed to decode PEM block containing certificate")
 	}
@@ -35,8 +60,8 @@ func ExtractIssuerNameFromPEM(pemFilePath string) (string, error) {
 }
 
 // Check if a cert is expired by public key
-func CertIsExpired(certBtyes []byte) bool {
-	block, _ := pem.Decode(certBtyes)
+func CertIsExpired(certBytes []byte) bool {
+	block, _ := pem.Decode(certBytes)
 	if block != nil {
 		cert, err := x509.ParseCertificate(block.Bytes)
 		if err == nil {
@@ -50,3 +75,20 @@ func CertIsExpired(certBtyes []byte) bool {
 	}
 	return false
 }
+
+func CertExpireSoon(certBytes []byte) bool {
+	block, _ := pem.Decode(certBytes)
+	if block != nil {
+		cert, err := x509.ParseCertificate(block.Bytes)
+		if err == nil {
+			expirationDate := cert.NotAfter
+			threshold := 14 * 24 * time.Hour // 14 days
+
+			timeRemaining := time.Until(expirationDate)
+			if timeRemaining <= threshold {
+				return true
+			}
+		}
+	}
+	return false
+}

+ 1 - 1
start.go

@@ -197,7 +197,7 @@ func startupSequence() {
 		Obtaining certificates from ACME Server
 	*/
 	acmeHandler = initACME()
-	acmeAutoRenewer, err = acme.NewAutoRenewer("./rules/acme_conf.json", "./certs/", int64(*acmeAutoRenewInterval))
+	acmeAutoRenewer, err = acme.NewAutoRenewer("./rules/acme_conf.json", "./certs/", int64(*acmeAutoRenewInterval), acmeHandler)
 	if err != nil {
 		log.Fatal(err)
 	}

+ 23 - 2
web/snippet/acme.html

@@ -41,7 +41,7 @@
       <p>Email is required by many CAs for renewing via ACME protocol</p>
       <div class="ui fluid action input">
         <input id="caRegisterEmail" type="text" placeholder="[email protected]">
-        <button class="ui icon basic button" onclick="saveEmailToConfig();">
+        <button class="ui icon basic button" onclick="saveEmailToConfig(this);">
             <i class="blue save icon"></i>
         </button>
       </div>
@@ -122,6 +122,7 @@
 
   <script>
     let expiredDomains = [];
+    let enableTrigerOnChangeEvent = true;
     $(".accordion").accordion();
     $(".dropdown").dropdown();
 
@@ -144,6 +145,9 @@
         }
 
         $("#enableCertAutoRenew").on("change", function(){
+          if (!enableTrigerOnChangeEvent){
+            return;
+          }
           toggleAutoRenew();
         })
       });
@@ -156,7 +160,7 @@
     }
     initRenewerConfigFromFile();
 
-    function saveEmailToConfig(){
+    function saveEmailToConfig(btn){
       $.ajax({
         url: "/api/acme/autoRenew/email",
         data: {set: $("#caRegisterEmail").val()},
@@ -165,6 +169,12 @@
             parent.msgbox(data.error, false, 5000);
           }else{
             parent.msgbox("Email updated");
+            $(btn).html(`<i class="green check icon"></i>`);
+            $(btn).addClass("disabled");
+            setTimeout(function(){
+              $(btn).html(`<i class="blue save icon"></i>`);
+              $(btn).removeClass("disabled");
+            }, 3000);
           }
         }
       });
@@ -175,6 +185,11 @@
       $.post("/api/acme/autoRenew/enable?enable=" + enabled, function(data){
         if (data.error){
           parent.msgbox(data.error, false, 5000);
+          if (enabled){
+            enableTrigerOnChangeEvent = false;
+            $("#enableCertAutoRenew").parent().checkbox("set unchecked");
+            enableTrigerOnChangeEvent = true;
+          }
         }else{
           $("#enableToggleSucc").stop().finish().fadeIn("fast").delay(3000).fadeOut("fast");
         }
@@ -265,6 +280,11 @@
     function obtainCertificate() {
       var domains = $("#domainsInput").val();
       var filename = $("#filenameInput").val();
+      var email = $("#caRegisterEmail").val();
+      if (email == ""){
+        parent.msgbox("ACME renew email is not set")
+        return;
+      }
       if (filename.trim() == "" && !domains.includes(",")){
         //Zoraxy filename are the matching name for domains.
         //Use the same as domains
@@ -284,6 +304,7 @@
         data: {
           domains: domains,
           filename: filename,
+          email: email,
           ca: ca,
         },
         success: function(response) {