package sigv4

import (
	"bytes"
	"crypto/hmac"
	"crypto/sha256"
	"encoding/hex"
	"fmt"
	"net/http"
	"net/http/httptest"
	"testing"
	"time"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// TestValidateSigV4Middleware drives ValidateSigV4Middleware with a table of
// well-formed and malformed requests. For each case it checks the HTTP status
// and, for error cases, that the expected error code appears in the body.
func TestValidateSigV4Middleware(t *testing.T) {
	tests := []struct {
		name           string
		setupRequest   func() *http.Request
		expectError    bool
		expectedStatus int
		expectedCode   string
	}{
		{
			name: "Valid signature",
			setupRequest: func() *http.Request {
				body := []byte("Action=GetCallerIdentity&Version=2011-06-15")
				return createSignedRequest(t, "POST", "/", body, mockCredentials["AKIAIOSFODNN7EXAMPLE"])
			},
			expectError:    false,
			expectedStatus: http.StatusOK,
		},
		{
			name: "Missing authorization header",
			setupRequest: func() *http.Request {
				req := httptest.NewRequest("POST", "/", bytes.NewReader([]byte("test")))
				req.Header.Set("X-Amz-Date", time.Now().UTC().Format("20060102T150405Z"))
				return req
			},
			expectError:    true,
			expectedStatus: http.StatusForbidden,
			expectedCode:   "MissingAuthenticationToken",
		},
		{
			name: "Invalid access key",
			setupRequest: func() *http.Request {
				body := []byte("Action=GetCallerIdentity&Version=2011-06-15")
				invalidCreds := AWSCredentials{
					AccessKeyID:     "INVALID_KEY",
					SecretAccessKey: "invalid-secret",
					AccountID:       "123456789012",
				}
				return createSignedRequest(t, "POST", "/", body, invalidCreds)
			},
			expectError:    true,
			expectedStatus: http.StatusForbidden,
			expectedCode:   "InvalidClientTokenId",
		},
		{
			name: "Expired request",
			setupRequest: func() *http.Request {
				body := []byte("Action=GetCallerIdentity&Version=2011-06-15")
				// Sign with a timestamp 20 minutes in the past so X-Amz-Date,
				// the credential scope and the signature all agree and the
				// request's age is its only defect. Rewriting X-Amz-Date after
				// signing would also break the signature, making the outcome
				// depend on which check the middleware performs first.
				signedAt := time.Now().Add(-20 * time.Minute)
				return createSignedRequestAt(t, "POST", "/", body, mockCredentials["AKIAIOSFODNN7EXAMPLE"], signedAt)
			},
			expectError:    true,
			expectedStatus: http.StatusForbidden,
			expectedCode:   "RequestExpired",
		},
		{
			name: "Invalid signature",
			setupRequest: func() *http.Request {
				body := []byte("Action=GetCallerIdentity&Version=2011-06-15")
				req := createSignedRequest(t, "POST", "/", body, mockCredentials["AKIAIOSFODNN7EXAMPLE"])
				// Tamper with the trailing characters of the hex signature.
				authHeader := req.Header.Get("Authorization")
				authHeader = authHeader[:len(authHeader)-10] + "0000000000"
				req.Header.Set("Authorization", authHeader)
				return req
			},
			expectError:    true,
			expectedStatus: http.StatusForbidden,
			expectedCode:   "SignatureDoesNotMatch",
		},
		{
			name: "Missing X-Amz-Date header",
			setupRequest: func() *http.Request {
				body := []byte("Action=GetCallerIdentity&Version=2011-06-15")
				req := createSignedRequest(t, "POST", "/", body, mockCredentials["AKIAIOSFODNN7EXAMPLE"])
				req.Header.Del("X-Amz-Date")
				return req
			},
			expectError:    true,
			expectedStatus: http.StatusForbidden,
			expectedCode:   "InvalidRequest",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// Downstream handler that only runs if the middleware lets the request through.
			testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				w.WriteHeader(http.StatusOK)
				w.Write([]byte("OK"))
			})

			handler := ValidateSigV4Middleware(testHandler)

			req := tt.setupRequest()
			rr := httptest.NewRecorder()

			handler.ServeHTTP(rr, req)

			assert.Equal(t, tt.expectedStatus, rr.Code, "Status code mismatch")

			if tt.expectError {
				body := rr.Body.String()
				assert.Contains(t, body, tt.expectedCode, "Expected error code not found in response")
			}
		})
	}
}

// TestParseSigV4Header checks that parseSigV4Header extracts the credential
// scope parts, signed-header list and signature from a well-formed
// Authorization header, and that it rejects headers missing the algorithm
// prefix, the Credential component or the Signature component.
func TestParseSigV4Header(t *testing.T) {
	tests := []struct {
		name        string
		authHeader  string
		expectError bool
		validate    func(*testing.T, *sigV4Components)
	}{
		{
			name: "Valid SigV4 header",
			authHeader: "AWS4-HMAC-SHA256 Credential=AKIAIOSFODNN7EXAMPLE/20231016/us-east-1/sts/aws4_request, " +
				"SignedHeaders=content-type;host;x-amz-date, Signature=abcd1234",
			expectError: false,
			validate: func(t *testing.T, sig *sigV4Components) {
				assert.Equal(t, "AKIAIOSFODNN7EXAMPLE", sig.AccessKeyID)
				assert.Equal(t, "20231016", sig.Date)
				assert.Equal(t, "us-east-1", sig.Region)
				assert.Equal(t, "sts", sig.Service)
				assert.Equal(t, "abcd1234", sig.Signature)
				assert.Equal(t, []string{"content-type", "host", "x-amz-date"}, sig.SignedHeaders)
			},
		},
		{
			name:        "Missing algorithm prefix",
			authHeader:  "Credential=AKIAIOSFODNN7EXAMPLE/20231016/us-east-1/sts/aws4_request",
			expectError: true,
		},
		{
			name:        "Missing credential",
			authHeader:  "AWS4-HMAC-SHA256 SignedHeaders=host, Signature=abcd1234",
			expectError: true,
		},
		{
			name:        "Missing signature",
			authHeader:  "AWS4-HMAC-SHA256 Credential=AKIAIOSFODNN7EXAMPLE/20231016/us-east-1/sts/aws4_request",
			expectError: true,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			sig, err := parseSigV4Header(tt.authHeader)
			if tt.expectError {
				assert.Error(t, err)
				return
			}
			require.NoError(t, err)
			if tt.validate != nil {
				tt.validate(t, sig)
			}
		})
	}
}

// TestCalculateSignature checks that calculateSignature succeeds on a fully
// populated request and returns a hex-encoded SHA-256-sized signature.
func TestCalculateSignature(t *testing.T) {
	creds := mockCredentials["AKIAIOSFODNN7EXAMPLE"]
	body := []byte("Action=GetCallerIdentity&Version=2011-06-15")

	req := httptest.NewRequest("POST", "/", bytes.NewReader(body))
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded; charset=utf-8")
	req.Header.Set("Host", "sts.us-east-1.amazonaws.com")
	// httptest.NewRequest defaults req.Host to "example.com"; keep it in sync
	// with the Host header so either source yields the same value.
	req.Host = "sts.us-east-1.amazonaws.com"
	req.Header.Set("X-Amz-Date", "20231016T120000Z")

	sigV4 := &sigV4Components{
		AccessKeyID:     creds.AccessKeyID,
		Date:            "20231016",
		Region:          "us-east-1",
		Service:         "sts",
		CredentialScope: "20231016/us-east-1/sts/aws4_request",
		SignedHeaders:   []string{"content-type", "host", "x-amz-date"},
	}

	signature, err := calculateSignature(req, body, creds, sigV4)
	require.NoError(t, err)
	assert.NotEmpty(t, signature)
	assert.Len(t, signature, 64) // SHA256 hex string length
}

// TestHMACFunctions pins the crypto primitives to published known-answer
// vectors so a subtle change (wrong key prefix, wrong encoding) is caught,
// not just a wrong output length.
func TestHMACFunctions(t *testing.T) {
	t.Run("hmacSHA256", func(t *testing.T) {
		// RFC 4231, test case 2.
		key := []byte("Jefe")
		data := "what do ya want for nothing?"

		result := hmacSHA256(key, data)
		assert.Equal(t,
			"5bdcc146bf60754e6a042426089575c75a003f089d2739839dec58b964ec3843",
			hex.EncodeToString(result))

		// Cross-check against the standard library implementation.
		mac := hmac.New(sha256.New, key)
		mac.Write([]byte(data))
		assert.Equal(t, mac.Sum(nil), result)
	})

	t.Run("sha256Hash", func(t *testing.T) {
		assert.Equal(t,
			"e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855",
			sha256Hash([]byte("")))
		result := sha256Hash([]byte("abc"))
		assert.Equal(t,
			"ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad",
			result)
		assert.Len(t, result, 64) // Hex encoded SHA256
	})

	t.Run("deriveSigningKey", func(t *testing.T) {
		// Example from the AWS SigV4 documentation on deriving a signing key.
		signingKey := deriveSigningKey("wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY", "20120215", "us-east-1", "iam")
		assert.Len(t, signingKey, 32) // SHA256 output length
		assert.Equal(t,
			"f4780e2d9f65fa895f9c67b32ce1baf0b0d8a43505a000a1a9e090d414db404d",
			hex.EncodeToString(signingKey))
	})
}

// createSignedRequest builds an httptest request for method/path/body and signs
// it with AWS Signature Version 4 (AWS4-HMAC-SHA256) for service "sts" in
// region "us-east-1", using the current time as the signing time.
func createSignedRequest(t *testing.T, method, path string, body []byte, creds AWSCredentials) *http.Request {
	t.Helper()
	return createSignedRequestAt(t, method, path, body, creds, time.Now())
}

// createSignedRequestAt is createSignedRequest with an explicit signing time.
// It lets tests build requests that are correctly signed but stale.
//
// The signed headers are content-type, host, x-amz-date and, when
// creds.SessionToken is non-empty, x-amz-security-token. path is used verbatim
// as the canonical URI, and the canonical query string is always empty. Callers
// must therefore pass a path without a query component.
func createSignedRequestAt(t *testing.T, method, path string, body []byte, creds AWSCredentials, signTime time.Time) *http.Request {
	t.Helper()

	const (
		region  = "us-east-1"
		service = "sts"
		host    = "sts.us-east-1.amazonaws.com"
	)

	req := httptest.NewRequest(method, path, bytes.NewReader(body))

	signTime = signTime.UTC()
	dateStamp := signTime.Format("20060102")
	amzDate := signTime.Format("20060102T150405Z")

	req.Header.Set("Content-Type", "application/x-www-form-urlencoded; charset=utf-8")
	req.Header.Set("Host", host)
	// httptest.NewRequest sets req.Host to "example.com", and Go's server
	// moves an incoming Host header into req.Host. Set both so the middleware
	// sees the signed host whichever one it reads.
	req.Host = host
	req.Header.Set("X-Amz-Date", amzDate)
	if creds.SessionToken != "" {
		req.Header.Set("X-Amz-Security-Token", creds.SessionToken)
	}

	canonicalURI := path
	if canonicalURI == "" {
		canonicalURI = "/"
	}

	// Canonical headers are lower-case, sorted by name, and newline-terminated.
	// x-amz-security-token sorts after x-amz-date, so it is appended last.
	canonicalHeaders := fmt.Sprintf("content-type:%s\nhost:%s\nx-amz-date:%s\n",
		req.Header.Get("Content-Type"),
		host,
		amzDate,
	)
	signedHeaders := "content-type;host;x-amz-date"
	if creds.SessionToken != "" {
		canonicalHeaders += fmt.Sprintf("x-amz-security-token:%s\n", creds.SessionToken)
		signedHeaders += ";x-amz-security-token"
	}

	payloadHash := sha256Hash(body)
	// Fields: method, URI, (empty) query string, header block, signed-header list, payload hash.
	canonicalRequest := fmt.Sprintf("%s\n%s\n\n%s\n%s\n%s",
		method,
		canonicalURI,
		canonicalHeaders,
		signedHeaders,
		payloadHash,
	)

	credentialScope := fmt.Sprintf("%s/%s/%s/aws4_request", dateStamp, region, service)
	stringToSign := fmt.Sprintf("AWS4-HMAC-SHA256\n%s\n%s\n%s",
		amzDate,
		credentialScope,
		sha256Hash([]byte(canonicalRequest)),
	)

	signingKey := deriveSigningKey(creds.SecretAccessKey, dateStamp, region, service)
	signatureHex := hex.EncodeToString(hmacSHA256(signingKey, stringToSign))

	authorization := fmt.Sprintf("AWS4-HMAC-SHA256 Credential=%s/%s, SignedHeaders=%s, Signature=%s",
		creds.AccessKeyID,
		credentialScope,
		signedHeaders,
		signatureHex,
	)
	req.Header.Set("Authorization", authorization)
	return req
}

// sha256Hash returns the lower-case hex encoding of the SHA-256 digest of data.
//
// NOTE(review): TestHMACFunctions exercises sha256Hash as if it were package
// code. If the non-test sources of package sigv4 also declare it, this
// declaration is a redeclaration and one of the two must go. Confirm.
func sha256Hash(data []byte) string {
	hash := sha256.Sum256(data)
	return hex.EncodeToString(hash[:])
}