123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404 |
- package sigv4
- import (
- "bytes"
- "context"
- "crypto/hmac"
- "crypto/sha256"
- "encoding/hex"
- "fmt"
- "io"
- "log"
- "net/http"
- "sort"
- "strings"
- "time"
- )
// contextKey is an unexported type for request-context keys set by this
// package, so they cannot collide with context keys defined elsewhere.
type contextKey string

// CredentialsContextKey is the context key under which
// ValidateSigV4Middleware stores the authenticated AWSCredentials.
const CredentialsContextKey = contextKey("credentials")
// AWSCredentials is one set of SigV4 credentials known to this server.
// SessionToken is non-empty only for temporary credentials; when it is set,
// ValidateSigV4Middleware requires a matching X-Amz-Security-Token header.
type AWSCredentials struct {
	AccessKeyID     string
	SecretAccessKey string
	SessionToken    string
	AccountID       string
}

// String implements fmt.Stringer with the secret key and session token
// redacted, so credentials printed with %v or %s (for example in log lines)
// do not leak key material.
func (c AWSCredentials) String() string {
	return fmt.Sprintf("AWSCredentials{AccessKeyID: %s, SecretAccessKey: %s, SessionToken: %s, AccountID: %s}",
		c.AccessKeyID, redactSecret(c.SecretAccessKey), redactSecret(c.SessionToken), c.AccountID)
}

// GoString implements fmt.GoStringer so that %#v is redacted as well.
func (c AWSCredentials) GoString() string {
	return c.String()
}

// redactSecret hides a non-empty secret while still showing whether one is set.
func redactSecret(s string) string {
	if s == "" {
		return ""
	}
	return "[REDACTED]"
}

// mockCredentials is the in-memory credential store, keyed by access key ID.
// In production this would be a database or external service.
//
// The map is built from a list so that each key is always the entry's own
// AccessKeyID. A duplicate ID panics during package initialisation, which
// replaces the compile-time duplicate-key check a map literal would give.
var mockCredentials = func() map[string]AWSCredentials {
	entries := []AWSCredentials{
		{
			AccessKeyID:     "AKIAIOSFODNN7EXAMPLE",
			SecretAccessKey: "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY",
			AccountID:       "123456789012",
		},
		{
			AccessKeyID:     "ASIAUIJXACK3L66H7KB4",
			SecretAccessKey: "test-secret-key",
			SessionToken:    "test-session-token",
			AccountID:       "292709995190",
		},
	}
	byKey := make(map[string]AWSCredentials, len(entries))
	for _, c := range entries {
		if _, dup := byKey[c.AccessKeyID]; dup {
			panic("sigv4: duplicate mock access key " + c.AccessKeyID)
		}
		byKey[c.AccessKeyID] = c
	}
	return byKey
}()
- // ValidateSigV4Middleware validates AWS Signature Version 4
- func ValidateSigV4Middleware(next http.HandlerFunc) http.HandlerFunc {
- return func(w http.ResponseWriter, r *http.Request) {
- // Read the body FIRST before any processing
- bodyBytes, err := io.ReadAll(r.Body)
- if err != nil {
- writeAuthError(w, "InvalidRequest", "Failed to read request body")
- return
- }
- // Restore the body for downstream handlers
- r.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
- // Log the body for debugging
- log.Printf("Request body length: %d bytes", len(bodyBytes))
- //if len(bodyBytes) > 0 {
- //log.Printf("Request body: %s", string(bodyBytes))
- //}
- // Parse authorization header
- authHeader := r.Header.Get("Authorization")
- if authHeader == "" {
- writeAuthError(w, "MissingAuthenticationToken", "Authorization header cannot be empty")
- return
- }
- // Parse SigV4 components
- sigV4, err := parseSigV4Header(authHeader)
- if err != nil {
- writeAuthError(w, "IncompleteSignature", err.Error())
- return
- }
- // Validate credential
- creds, ok := mockCredentials[sigV4.AccessKeyID]
- if !ok {
- writeAuthError(w, "InvalidClientTokenId", "The security token included in the request is invalid")
- return
- }
- // Validate date
- dateHeader := r.Header.Get("X-Amz-Date")
- if dateHeader == "" {
- writeAuthError(w, "InvalidRequest", "X-Amz-Date header is required")
- return
- }
- reqTime, err := time.Parse("20060102T150405Z", dateHeader)
- if err != nil {
- writeAuthError(w, "InvalidRequest", "Invalid X-Amz-Date format")
- return
- }
- // Check if request is within 15 minutes
- now := time.Now().UTC()
- if now.Sub(reqTime) > 15*time.Minute {
- writeAuthError(w, "RequestExpired", "Request has expired")
- return
- }
- if reqTime.Sub(now) > 15*time.Minute {
- writeAuthError(w, "SignatureDoesNotMatch", "Signature not yet current")
- return
- }
- // Calculate expected signature with the actual body bytes
- expectedSig, err := calculateSignature(r, bodyBytes, creds, sigV4)
- if err != nil {
- writeAuthError(w, "InternalError", fmt.Sprintf("Failed to calculate signature: %v", err))
- return
- }
- // Compare signatures
- log.Printf("Expected signature: %s", expectedSig)
- log.Printf("Provided signature: %s", sigV4.Signature)
- if expectedSig != sigV4.Signature {
- writeAuthError(w, "SignatureDoesNotMatch",
- "The request signature we calculated does not match the signature you provided")
- return
- }
- // Validate session token if present
- if creds.SessionToken != "" {
- reqToken := r.Header.Get("X-Amz-Security-Token")
- if reqToken != creds.SessionToken {
- writeAuthError(w, "InvalidClientTokenId", "The security token included in the request is invalid")
- return
- }
- }
- // Add credentials to context
- ctx := context.WithValue(r.Context(), CredentialsContextKey, creds)
- next.ServeHTTP(w, r.WithContext(ctx))
- }
- }
// sigV4Components holds the pieces of an "AWS4-HMAC-SHA256" Authorization
// header as parsed by parseSigV4Header.
type sigV4Components struct {
	AccessKeyID     string   // first Credential element, with '"' characters removed
	CredentialScope string   // Credential elements after the access key, joined by "/"
	SignedHeaders   []string // SignedHeaders value split on ";", in the order sent
	Signature       string   // client-supplied signature
	Date            string   // second Credential element (scope date)
	Region          string   // third Credential element
	Service         string   // fourth Credential element
}

// parseSigV4Header parses an Authorization header of the form
//
//	AWS4-HMAC-SHA256 Credential=<akid>/<date>/<region>/<service>/<terminator>, SignedHeaders=<h1;h2>, Signature=<sig>
//
// Components may be separated by "," with or without surrounding
// whitespace; empty components and unknown keys are ignored.
//
// It returns an error in any of these cases:
//   - the algorithm prefix is wrong,
//   - a component is not key=value,
//   - the Credential does not have exactly five "/"-separated parts,
//   - any of Credential, SignedHeaders or Signature is missing.
//
// Error texts are sent to clients verbatim, so they keep AWS-style
// capitalisation.
func parseSigV4Header(authHeader string) (*sigV4Components, error) {
	const algorithm = "AWS4-HMAC-SHA256 "
	if !strings.HasPrefix(authHeader, algorithm) {
		return nil, fmt.Errorf("Authorization header must start with 'AWS4-HMAC-SHA256'")
	}

	sig := &sigV4Components{}
	// Split on "," rather than ", ": AWS documents a space after each comma,
	// but not every client sends one. None of the three values can itself
	// contain a comma.
	for _, part := range strings.Split(strings.TrimPrefix(authHeader, algorithm), ",") {
		part = strings.TrimSpace(part)
		if part == "" {
			continue // tolerate stray or trailing commas
		}
		kv := strings.SplitN(part, "=", 2)
		if len(kv) != 2 {
			return nil, fmt.Errorf("Invalid key=value pair in Authorization header")
		}
		key, value := strings.TrimSpace(kv[0]), strings.TrimSpace(kv[1])
		switch key {
		case "Credential":
			credParts := strings.Split(value, "/")
			if len(credParts) != 5 {
				return nil, fmt.Errorf("Invalid credential format")
			}
			sig.AccessKeyID = strings.ReplaceAll(credParts[0], "\"", "")
			sig.Date = credParts[1]
			sig.Region = credParts[2]
			sig.Service = credParts[3]
			sig.CredentialScope = strings.Join(credParts[1:], "/")
		case "SignedHeaders":
			// An empty value would otherwise become [""] and slip past the
			// required-parameter check below.
			if value != "" {
				sig.SignedHeaders = strings.Split(value, ";")
			}
		case "Signature":
			sig.Signature = value
		}
	}

	if sig.AccessKeyID == "" {
		return nil, fmt.Errorf("Authorization header requires 'Credential' parameter")
	}
	// Without SignedHeaders the canonical request would cover no headers at
	// all, so even Host would be unsigned.
	if len(sig.SignedHeaders) == 0 {
		return nil, fmt.Errorf("Authorization header requires 'SignedHeaders' parameter")
	}
	if sig.Signature == "" {
		return nil, fmt.Errorf("Authorization header requires 'Signature' parameter")
	}
	return sig, nil
}
- func calculateSignature(r *http.Request, body []byte, creds AWSCredentials, sigV4 *sigV4Components) (string, error) {
- amzDate := r.Header.Get("X-Amz-Date")
- canonicalRequest := buildCanonicalRequest(r, body, sigV4.SignedHeaders)
- stringToSign := buildStringToSign(canonicalRequest, sigV4, amzDate)
- log.Printf("String to sign:\n%s", stringToSign)
- signingKey := deriveSigningKey(creds.SecretAccessKey, sigV4.Date, sigV4.Region, sigV4.Service)
- signature := hmacSHA256(signingKey, stringToSign)
- finalSig := hex.EncodeToString(signature)
- return finalSig, nil
- }
- func buildCanonicalRequest(r *http.Request, body []byte, signedHeaders []string) string {
- // Method
- method := r.Method
- // Canonical URI - Use RequestURI which preserves URL encoding
- // Split RequestURI to get just the path (before the query string)
- canonicalURI := r.RequestURI
- if idx := strings.Index(canonicalURI, "?"); idx != -1 {
- canonicalURI = canonicalURI[:idx]
- }
- if canonicalURI == "" {
- canonicalURI = "/"
- }
- // Canonical query string - MUST BE SORTED
- canonicalQueryString := buildCanonicalQueryString(r)
- // Canonical headers (already includes trailing newlines)
- canonicalHeaders := buildCanonicalHeaders(r, signedHeaders)
- // Signed headers
- signedHeadersStr := strings.Join(signedHeaders, ";")
- // Payload hash - Check if client sent UNSIGNED-PAYLOAD
- var payloadHash string
- amzContentSha256 := r.Header.Get("X-Amz-Content-SHA256")
- if amzContentSha256 == "UNSIGNED-PAYLOAD" {
- // Use the literal string for streaming/multipart uploads
- payloadHash = "UNSIGNED-PAYLOAD"
- } else if amzContentSha256 != "" {
- // Use the hash provided by the client
- payloadHash = amzContentSha256
- } else {
- // Calculate hash from the actual body
- payloadHash = sha256Hash(body)
- }
- canonicalRequest := fmt.Sprintf("%s\n%s\n%s\n%s\n%s\n%s",
- method,
- canonicalURI,
- canonicalQueryString,
- canonicalHeaders,
- signedHeadersStr,
- payloadHash,
- )
- log.Printf("=== Canonical Request ===")
- log.Printf("Method: %s", method)
- log.Printf("Canonical URI: %s", canonicalURI)
- log.Printf("Canonical Query String: %s", canonicalQueryString)
- log.Printf("Canonical Headers:\n%s", canonicalHeaders)
- log.Printf("Signed Headers: %s", signedHeadersStr)
- log.Printf("Payload Hash: %s", payloadHash)
- log.Printf("Full Canonical Request:\n%s", canonicalRequest)
- log.Printf("========================")
- return canonicalRequest
- }
// buildCanonicalQueryString returns the SigV4 CanonicalQueryString for r.
//
// The raw query is split on "&" and each parameter on its first "=". A
// parameter without "=" gets an empty value. Pairs are sorted by key, then
// by value, and rejoined as "key=value" with "&". Keys and values are kept
// exactly as received: nothing is decoded or re-encoded. Clients must
// therefore send their query already encoded the way they signed it (AWS
// SDKs do). Returns "" when there is no query string.
func buildCanonicalQueryString(r *http.Request) string {
	raw := r.URL.RawQuery
	if raw == "" {
		return ""
	}

	type param struct{ key, value string }
	pairs := strings.Split(raw, "&")
	params := make([]param, len(pairs))
	for i, pair := range pairs {
		kv := strings.SplitN(pair, "=", 2)
		params[i].key = kv[0]
		if len(kv) == 2 {
			params[i].value = kv[1]
		}
	}

	sort.Slice(params, func(i, j int) bool {
		if params[i].key != params[j].key {
			return params[i].key < params[j].key
		}
		return params[i].value < params[j].value
	})

	// Every parameter is emitted as "key=value", including an empty value
	// ("key="), so no special case is needed.
	var b strings.Builder
	b.Grow(len(raw) + len(params))
	for i, p := range params {
		if i > 0 {
			b.WriteByte('&')
		}
		b.WriteString(p.key)
		b.WriteByte('=')
		b.WriteString(p.value)
	}
	return b.String()
}
// buildCanonicalHeaders returns the SigV4 CanonicalHeaders block for the
// given signed header names: one "lowercase-name:value\n" line per header,
// sorted by name.
//
// Each value has leading and trailing whitespace removed, and internal runs
// of whitespace collapsed to a single space (the spec's "Trimall"). Repeated
// headers are joined with ",". The host entry comes from r.Host, because
// net/http removes Host from r.Header. Signed names that are absent from the
// request are omitted, and duplicate names produce a single line.
func buildCanonicalHeaders(r *http.Request, signedHeaders []string) string {
	headers := make(map[string]string, len(signedHeaders))
	for _, name := range signedHeaders {
		lower := strings.ToLower(strings.TrimSpace(name))
		if lower == "host" {
			headers[lower] = r.Host
			continue
		}
		values := r.Header.Values(lower) // Values canonicalises the key itself
		if len(values) == 0 {
			continue
		}
		normalized := make([]string, len(values))
		for i, v := range values {
			normalized[i] = strings.Join(strings.Fields(v), " ")
		}
		headers[lower] = strings.Join(normalized, ",")
	}

	names := make([]string, 0, len(headers))
	for k := range headers {
		names = append(names, k)
	}
	sort.Strings(names)

	var canonical strings.Builder
	for _, k := range names {
		canonical.WriteString(k)
		canonical.WriteByte(':')
		canonical.WriteString(headers[k])
		canonical.WriteByte('\n')
	}
	return canonical.String()
}
- func buildStringToSign(canonicalRequest string, sigV4 *sigV4Components, amzDate string) string {
- canonicalRequestHash := sha256Hash([]byte(canonicalRequest))
- stringToSign := fmt.Sprintf("AWS4-HMAC-SHA256\n%s\n%s\n%s",
- amzDate,
- sigV4.CredentialScope,
- canonicalRequestHash,
- )
- return stringToSign
- }
// deriveSigningKey derives the SigV4 signing key. It chains HMAC-SHA256
// starting from "AWS4"+secretKey, keying in turn on the date, region,
// service and the literal "aws4_request"; each step's output becomes the
// next step's key.
func deriveSigningKey(secretKey, date, region, service string) []byte {
	key := []byte("AWS4" + secretKey)
	for _, scopePart := range []string{date, region, service, "aws4_request"} {
		key = hmacSHA256(key, scopePart)
	}
	return key
}

// hmacSHA256 returns the raw (binary) HMAC-SHA256 of data under key.
func hmacSHA256(key []byte, data string) []byte {
	mac := hmac.New(sha256.New, key)
	// hash.Hash writes never return an error.
	_, _ = io.WriteString(mac, data)
	return mac.Sum(nil)
}

// sha256Hash returns the lowercase hex SHA-256 digest of data.
func sha256Hash(data []byte) string {
	digest := sha256.New()
	digest.Write(data)
	return hex.EncodeToString(digest.Sum(nil))
}
// xmlTextEscaper escapes the five XML special characters so arbitrary text
// can be embedded safely in element content.
var xmlTextEscaper = strings.NewReplacer(
	"&", "&amp;",
	"<", "&lt;",
	">", "&gt;",
	`"`, "&quot;",
	"'", "&apos;",
)

// writeAuthError writes an STS-style XML ErrorResponse with HTTP 403 and
// Content-Type application/xml.
//
// code and message are XML-escaped before being embedded. Callers pass
// arbitrary text here (for example wrapped error strings), and unescaped
// text could break or inject into the document. RequestId is the current
// Unix time in seconds, so it is not unique across requests.
func writeAuthError(w http.ResponseWriter, code, message string) {
	w.Header().Set("Content-Type", "application/xml")
	w.WriteHeader(http.StatusForbidden)
	errorXML := fmt.Sprintf(`<?xml version="1.0" encoding="UTF-8"?>
<ErrorResponse xmlns="https://sts.amazonaws.com/doc/2011-06-15/">
<Error>
<Type>Sender</Type>
<Code>%s</Code>
<Message>%s</Message>
</Error>
<RequestId>%d</RequestId>
</ErrorResponse>`, xmlTextEscaper.Replace(code), xmlTextEscaper.Replace(message), time.Now().Unix())
	// Best effort: the status line has already been sent, so a failed write
	// cannot be reported to the client.
	_, _ = w.Write([]byte(errorXML))
}
|