package sigv4 import ( "bytes" "context" "crypto/hmac" "crypto/sha256" "encoding/hex" "fmt" "io" "log" "net/http" "sort" "strings" "time" ) type contextKey string const ( CredentialsContextKey contextKey = "credentials" ) type AWSCredentials struct { AccessKeyID string SecretAccessKey string SessionToken string AccountID string } // Mock credential store - in production, this would be a database or external service var mockCredentials = map[string]AWSCredentials{ "AKIAIOSFODNN7EXAMPLE": { AccessKeyID: "AKIAIOSFODNN7EXAMPLE", SecretAccessKey: "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY", AccountID: "123456789012", }, "ASIAUIJXACK3L66H7KB4": { AccessKeyID: "ASIAUIJXACK3L66H7KB4", SecretAccessKey: "test-secret-key", SessionToken: "test-session-token", AccountID: "292709995190", }, } // ValidateSigV4Middleware validates AWS Signature Version 4 func ValidateSigV4Middleware(next http.HandlerFunc) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { // Read the body bodyBytes, err := io.ReadAll(r.Body) if err != nil { writeAuthError(w, "InvalidRequest", "Failed to read request body") return } r.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) // Parse authorization header authHeader := r.Header.Get("Authorization") if authHeader == "" { writeAuthError(w, "MissingAuthenticationToken", "Authorization header cannot be empty") return } // Parse SigV4 components sigV4, err := parseSigV4Header(authHeader) if err != nil { writeAuthError(w, "IncompleteSignature", err.Error()) return } // Validate credential creds, ok := mockCredentials[sigV4.AccessKeyID] if !ok { writeAuthError(w, "InvalidClientTokenId", "The security token included in the request is invalid") return } // Validate date dateHeader := r.Header.Get("X-Amz-Date") if dateHeader == "" { writeAuthError(w, "InvalidRequest", "X-Amz-Date header is required") return } reqTime, err := time.Parse("20060102T150405Z", dateHeader) if err != nil { writeAuthError(w, "InvalidRequest", "Invalid X-Amz-Date format") return } // Check if request is within 15 minutes now := time.Now().UTC() if now.Sub(reqTime) > 15*time.Minute { writeAuthError(w, "RequestExpired", "Request has expired") return } if reqTime.Sub(now) > 15*time.Minute { writeAuthError(w, "SignatureDoesNotMatch", "Signature not yet current") return } // Calculate expected signature expectedSig, err := calculateSignature(r, bodyBytes, creds, sigV4) if err != nil { writeAuthError(w, "InternalError", fmt.Sprintf("Failed to calculate signature: %v", err)) return } // Compare signatures //log.Printf("Expected signature: %s", expectedSig) //log.Printf("Provided signature: %s", sigV4.Signature) if expectedSig != sigV4.Signature { writeAuthError(w, "SignatureDoesNotMatch", "The request signature we calculated does not match the signature you provided") return } // Validate session token if present if creds.SessionToken != "" { reqToken := r.Header.Get("X-Amz-Security-Token") if reqToken != creds.SessionToken { writeAuthError(w, "InvalidClientTokenId", "The security token included in the request is invalid") return } } // Add credentials to context ctx := context.WithValue(r.Context(), CredentialsContextKey, creds) next.ServeHTTP(w, r.WithContext(ctx)) } } type sigV4Components struct { AccessKeyID string CredentialScope string SignedHeaders []string Signature string Date string Region string Service string } func parseSigV4Header(authHeader string) (*sigV4Components, error) { if !strings.HasPrefix(authHeader, "AWS4-HMAC-SHA256 ") { return nil, fmt.Errorf("Authorization header must start with 'AWS4-HMAC-SHA256'") } authHeader = strings.TrimPrefix(authHeader, "AWS4-HMAC-SHA256 ") parts := strings.Split(authHeader, ", ") sig := &sigV4Components{} for _, part := range parts { kv := strings.SplitN(part, "=", 2) if len(kv) != 2 { return nil, fmt.Errorf("Invalid key=value pair in Authorization header") } key := strings.TrimSpace(kv[0]) value := strings.TrimSpace(kv[1]) switch key { case "Credential": credParts := strings.Split(value, "/") if len(credParts) != 5 { return nil, fmt.Errorf("Invalid credential format") } sig.AccessKeyID = strings.ReplaceAll(credParts[0], "\"", "") sig.Date = credParts[1] sig.Region = credParts[2] sig.Service = credParts[3] sig.CredentialScope = strings.Join(credParts[1:], "/") case "SignedHeaders": sig.SignedHeaders = strings.Split(value, ";") case "Signature": sig.Signature = value } } if sig.AccessKeyID == "" { return nil, fmt.Errorf("Authorization header requires 'Credential' parameter") } if sig.Signature == "" { return nil, fmt.Errorf("Authorization header requires 'Signature' parameter") } return sig, nil } func calculateSignature(r *http.Request, body []byte, creds AWSCredentials, sigV4 *sigV4Components) (string, error) { amzDate := r.Header.Get("X-Amz-Date") canonicalRequest := buildCanonicalRequest(r, body, sigV4.SignedHeaders) stringToSign := buildStringToSign(canonicalRequest, sigV4, amzDate) log.Printf("String to sign:\n%s", stringToSign) signingKey := deriveSigningKey(creds.SecretAccessKey, sigV4.Date, sigV4.Region, sigV4.Service) signature := hmacSHA256(signingKey, stringToSign) finalSig := hex.EncodeToString(signature) return finalSig, nil } func buildCanonicalRequest(r *http.Request, body []byte, signedHeaders []string) string { // Method method := r.Method // Canonical URI - Extract from RequestURI to get the path before any routing modifications // RequestURI includes the full path, e.g., "/my-test-bucket/4.psd?uploads" canonicalURI := r.URL.Path if canonicalURI == "" { canonicalURI = "/" } // Canonical query string - MUST BE SORTED canonicalQueryString := buildCanonicalQueryString(r) // Canonical headers (already includes trailing newlines) canonicalHeaders := buildCanonicalHeaders(r, signedHeaders) // Signed headers signedHeadersStr := strings.Join(signedHeaders, ";") // Payload hash payloadHash := sha256Hash(body) canonicalRequest := fmt.Sprintf("%s\n%s\n%s\n%s\n%s\n%s", method, canonicalURI, canonicalQueryString, canonicalHeaders, signedHeadersStr, payloadHash, ) /* log.Printf("=== Canonical Request ===") log.Printf("Method: %s", method) log.Printf("Canonical URI: %s", canonicalURI) log.Printf("Canonical Query String: %s", canonicalQueryString) log.Printf("Canonical Headers:\n%s", canonicalHeaders) log.Printf("Signed Headers: %s", signedHeadersStr) log.Printf("Payload Hash: %s", payloadHash) log.Printf("Full Canonical Request:\n%s", canonicalRequest) log.Printf("========================") */ return canonicalRequest } func buildCanonicalQueryString(r *http.Request) string { if r.URL.RawQuery == "" { return "" } // Parse the query parameters properly query := r.URL.Query() // Get all keys and sort them keys := make([]string, 0, len(query)) for k := range query { keys = append(keys, k) } sort.Strings(keys) // Build the canonical query string var params []string for _, key := range keys { values := query[key] // Sort values for this key sort.Strings(values) for _, value := range values { if value == "" { // Empty value - add just key with equals sign params = append(params, key+"=") } else { // Non-empty value params = append(params, key+"="+value) } } } return strings.Join(params, "&") } func buildCanonicalHeaders(r *http.Request, signedHeaders []string) string { headers := make(map[string]string) for _, header := range signedHeaders { headerLower := strings.ToLower(strings.TrimSpace(header)) // Special handling for Host header if headerLower == "host" { headers[headerLower] = r.Host continue } values := r.Header[http.CanonicalHeaderKey(header)] if len(values) > 0 { trimmedValues := make([]string, len(values)) for i, v := range values { trimmedValues[i] = strings.TrimSpace(v) } headers[headerLower] = strings.Join(trimmedValues, ",") } } // Sort headers keys := make([]string, 0, len(headers)) for k := range headers { keys = append(keys, k) } sort.Strings(keys) // Build canonical headers string var canonical strings.Builder for _, k := range keys { canonical.WriteString(k) canonical.WriteString(":") canonical.WriteString(headers[k]) canonical.WriteString("\n") } return canonical.String() } func buildStringToSign(canonicalRequest string, sigV4 *sigV4Components, amzDate string) string { canonicalRequestHash := sha256Hash([]byte(canonicalRequest)) stringToSign := fmt.Sprintf("AWS4-HMAC-SHA256\n%s\n%s\n%s", amzDate, sigV4.CredentialScope, canonicalRequestHash, ) return stringToSign } func deriveSigningKey(secretKey, date, region, service string) []byte { kDate := hmacSHA256([]byte("AWS4"+secretKey), date) kRegion := hmacSHA256(kDate, region) kService := hmacSHA256(kRegion, service) kSigning := hmacSHA256(kService, "aws4_request") return kSigning } func hmacSHA256(key []byte, data string) []byte { h := hmac.New(sha256.New, key) h.Write([]byte(data)) return h.Sum(nil) } func sha256Hash(data []byte) string { hash := sha256.Sum256(data) return hex.EncodeToString(hash[:]) } func writeAuthError(w http.ResponseWriter, code, message string) { w.Header().Set("Content-Type", "application/xml") w.WriteHeader(http.StatusForbidden) errorXML := fmt.Sprintf(` Sender %s %s %d `, code, message, time.Now().Unix()) w.Write([]byte(errorXML)) }