package main import ( "bytes" "context" "crypto/hmac" "crypto/sha256" "encoding/hex" "encoding/xml" "fmt" "net/http" "net/http/httptest" "testing" "time" "aws-sts-mock/pkg/sigv4" "aws-sts-mock/pkg/sts" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestGetCallerIdentityEndpoint(t *testing.T) { tests := []struct { name string setupRequest func() *http.Request expectError bool expectedStatus int validateResp func(*testing.T, *sts.GetCallerIdentityResponse) }{ { name: "Valid GetCallerIdentity request", setupRequest: func() *http.Request { body := []byte("Action=GetCallerIdentity&Version=2011-06-15") creds := sigv4.AWSCredentials{ AccessKeyID: "AKIAIOSFODNN7EXAMPLE", SecretAccessKey: "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY", AccountID: "123456789012", } req := createTestSignedRequest(t, "POST", "/", body, creds) return req }, expectError: false, expectedStatus: http.StatusOK, validateResp: func(t *testing.T, resp *sts.GetCallerIdentityResponse) { assert.Equal(t, "123456789012", resp.GetCallerIdentityResult.Account) assert.Equal(t, "123456789012", resp.GetCallerIdentityResult.UserId) assert.Equal(t, "arn:aws:iam::123456789012:root", resp.GetCallerIdentityResult.Arn) assert.NotEmpty(t, resp.ResponseMetadata.RequestId) }, }, { name: "Invalid action", setupRequest: func() *http.Request { body := []byte("Action=InvalidAction&Version=2011-06-15") creds := sigv4.AWSCredentials{ AccessKeyID: "AKIAIOSFODNN7EXAMPLE", SecretAccessKey: "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY", AccountID: "123456789012", } req := createTestSignedRequest(t, "POST", "/", body, creds) return req }, expectError: true, expectedStatus: http.StatusBadRequest, }, { name: "Invalid version", setupRequest: func() *http.Request { body := []byte("Action=GetCallerIdentity&Version=2099-01-01") creds := sigv4.AWSCredentials{ AccessKeyID: "AKIAIOSFODNN7EXAMPLE", SecretAccessKey: "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY", AccountID: "123456789012", } req := createTestSignedRequest(t, "POST", "/", body, creds) return req }, expectError: true, expectedStatus: http.StatusBadRequest, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { // Create test server with middleware handler := sigv4.ValidateSigV4Middleware(handleSTSRequest) req := tt.setupRequest() rr := httptest.NewRecorder() handler.ServeHTTP(rr, req) assert.Equal(t, tt.expectedStatus, rr.Code) if !tt.expectError && tt.validateResp != nil { var response sts.GetCallerIdentityResponse err := xml.Unmarshal(rr.Body.Bytes(), &response) require.NoError(t, err) tt.validateResp(t, &response) } }) } } func TestHealthEndpoint(t *testing.T) { req := httptest.NewRequest("GET", "/health", nil) rr := httptest.NewRecorder() handleHealth(rr, req) assert.Equal(t, http.StatusOK, rr.Code) assert.Contains(t, rr.Body.String(), "healthy") } func TestErrorResponse(t *testing.T) { rr := httptest.NewRecorder() writeErrorResponse(rr, "TestError", "This is a test error") assert.Equal(t, http.StatusBadRequest, rr.Code) assert.Contains(t, rr.Body.String(), "TestError") assert.Contains(t, rr.Body.String(), "This is a test error") var errorResp sts.ErrorResponse err := xml.Unmarshal(rr.Body.Bytes(), &errorResp) require.NoError(t, err) assert.Equal(t, "TestError", errorResp.Error.Code) } // Helper function to create test signed requests func createTestSignedRequest(t *testing.T, method, path string, body []byte, creds sigv4.AWSCredentials) *http.Request { req := httptest.NewRequest(method, path, bytes.NewReader(body)) // Set credentials in context (simulating successful auth) ctx := context.WithValue(req.Context(), sigv4.CredentialsContextKey, creds) req = req.WithContext(ctx) now := time.Now().UTC() dateStamp := now.Format("20060102") amzDate := now.Format("20060102T150405Z") req.Header.Set("Content-Type", "application/x-www-form-urlencoded; charset=utf-8") req.Header.Set("Host", "sts.us-east-1.amazonaws.com") req.Header.Set("X-Amz-Date", amzDate) if creds.SessionToken != "" { req.Header.Set("X-Amz-Security-Token", creds.SessionToken) } canonicalURI := path if canonicalURI == "" { canonicalURI = "/" } signedHeaders := "content-type;host;x-amz-date" canonicalHeaders := fmt.Sprintf("content-type:%s\nhost:%s\nx-amz-date:%s\n", req.Header.Get("Content-Type"), req.Header.Get("Host"), amzDate, ) if creds.SessionToken != "" { canonicalHeaders += fmt.Sprintf("x-amz-security-token:%s\n", creds.SessionToken) signedHeaders += ";x-amz-security-token" } hash := sha256.Sum256(body) payloadHash := hex.EncodeToString(hash[:]) canonicalRequest := fmt.Sprintf("%s\n%s\n\n%s\n%s\n%s", method, canonicalURI, canonicalHeaders, signedHeaders, payloadHash, ) credentialScope := fmt.Sprintf("%s/%s/%s/aws4_request", dateStamp, "us-east-1", "sts") hash2 := sha256.Sum256([]byte(canonicalRequest)) hashedCanonicalRequest := hex.EncodeToString(hash2[:]) stringToSign := fmt.Sprintf("AWS4-HMAC-SHA256\n%s\n%s\n%s", amzDate, credentialScope, hashedCanonicalRequest, ) // Derive signing key kDate := hmacSHA256([]byte("AWS4"+creds.SecretAccessKey), dateStamp) kRegion := hmacSHA256(kDate, "us-east-1") kService := hmacSHA256(kRegion, "sts") kSigning := hmacSHA256(kService, "aws4_request") signature := hmacSHA256(kSigning, stringToSign) signatureHex := hex.EncodeToString(signature) authorization := fmt.Sprintf("AWS4-HMAC-SHA256 Credential=%s/%s, SignedHeaders=%s, Signature=%s", creds.AccessKeyID, credentialScope, signedHeaders, signatureHex, ) req.Header.Set("Authorization", authorization) return req } func hmacSHA256(key []byte, data string) []byte { h := hmac.New(sha256.New, key) h.Write([]byte(data)) return h.Sum(nil) }