123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214 |
- package main
- import (
- "bytes"
- "context"
- "crypto/hmac"
- "crypto/sha256"
- "encoding/hex"
- "encoding/xml"
- "fmt"
- "net/http"
- "net/http/httptest"
- "testing"
- "time"
- "aws-sts-mock/pkg/sigv4"
- "aws-sts-mock/pkg/sts"
- "github.com/stretchr/testify/assert"
- "github.com/stretchr/testify/require"
- )
- func TestGetCallerIdentityEndpoint(t *testing.T) {
- tests := []struct {
- name string
- setupRequest func() *http.Request
- expectError bool
- expectedStatus int
- validateResp func(*testing.T, *sts.GetCallerIdentityResponse)
- }{
- {
- name: "Valid GetCallerIdentity request",
- setupRequest: func() *http.Request {
- body := []byte("Action=GetCallerIdentity&Version=2011-06-15")
- creds := sigv4.AWSCredentials{
- AccessKeyID: "AKIAIOSFODNN7EXAMPLE",
- SecretAccessKey: "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY",
- AccountID: "123456789012",
- }
- req := createTestSignedRequest(t, "POST", "/", body, creds)
- return req
- },
- expectError: false,
- expectedStatus: http.StatusOK,
- validateResp: func(t *testing.T, resp *sts.GetCallerIdentityResponse) {
- assert.Equal(t, "123456789012", resp.GetCallerIdentityResult.Account)
- assert.Equal(t, "123456789012", resp.GetCallerIdentityResult.UserId)
- assert.Equal(t, "arn:aws:iam::123456789012:root", resp.GetCallerIdentityResult.Arn)
- assert.NotEmpty(t, resp.ResponseMetadata.RequestId)
- },
- },
- {
- name: "Invalid action",
- setupRequest: func() *http.Request {
- body := []byte("Action=InvalidAction&Version=2011-06-15")
- creds := sigv4.AWSCredentials{
- AccessKeyID: "AKIAIOSFODNN7EXAMPLE",
- SecretAccessKey: "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY",
- AccountID: "123456789012",
- }
- req := createTestSignedRequest(t, "POST", "/", body, creds)
- return req
- },
- expectError: true,
- expectedStatus: http.StatusBadRequest,
- },
- {
- name: "Invalid version",
- setupRequest: func() *http.Request {
- body := []byte("Action=GetCallerIdentity&Version=2099-01-01")
- creds := sigv4.AWSCredentials{
- AccessKeyID: "AKIAIOSFODNN7EXAMPLE",
- SecretAccessKey: "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY",
- AccountID: "123456789012",
- }
- req := createTestSignedRequest(t, "POST", "/", body, creds)
- return req
- },
- expectError: true,
- expectedStatus: http.StatusBadRequest,
- },
- }
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- // Create test server with middleware
- handler := sigv4.ValidateSigV4Middleware(handleSTSRequest)
-
- req := tt.setupRequest()
- rr := httptest.NewRecorder()
-
- handler.ServeHTTP(rr, req)
-
- assert.Equal(t, tt.expectedStatus, rr.Code)
-
- if !tt.expectError && tt.validateResp != nil {
- var response sts.GetCallerIdentityResponse
- err := xml.Unmarshal(rr.Body.Bytes(), &response)
- require.NoError(t, err)
- tt.validateResp(t, &response)
- }
- })
- }
- }
- func TestHealthEndpoint(t *testing.T) {
- req := httptest.NewRequest("GET", "/health", nil)
- rr := httptest.NewRecorder()
-
- handleHealth(rr, req)
-
- assert.Equal(t, http.StatusOK, rr.Code)
- assert.Contains(t, rr.Body.String(), "healthy")
- }
- func TestErrorResponse(t *testing.T) {
- rr := httptest.NewRecorder()
- writeErrorResponse(rr, "TestError", "This is a test error")
-
- assert.Equal(t, http.StatusBadRequest, rr.Code)
- assert.Contains(t, rr.Body.String(), "TestError")
- assert.Contains(t, rr.Body.String(), "This is a test error")
-
- var errorResp sts.ErrorResponse
- err := xml.Unmarshal(rr.Body.Bytes(), &errorResp)
- require.NoError(t, err)
- assert.Equal(t, "TestError", errorResp.Error.Code)
- }
- // Helper function to create test signed requests
- func createTestSignedRequest(t *testing.T, method, path string, body []byte, creds sigv4.AWSCredentials) *http.Request {
- req := httptest.NewRequest(method, path, bytes.NewReader(body))
-
- // Set credentials in context (simulating successful auth)
- ctx := context.WithValue(req.Context(), sigv4.CredentialsContextKey, creds)
- req = req.WithContext(ctx)
-
- now := time.Now().UTC()
- dateStamp := now.Format("20060102")
- amzDate := now.Format("20060102T150405Z")
-
- req.Header.Set("Content-Type", "application/x-www-form-urlencoded; charset=utf-8")
- req.Header.Set("Host", "sts.us-east-1.amazonaws.com")
- req.Header.Set("X-Amz-Date", amzDate)
-
- if creds.SessionToken != "" {
- req.Header.Set("X-Amz-Security-Token", creds.SessionToken)
- }
-
- canonicalURI := path
- if canonicalURI == "" {
- canonicalURI = "/"
- }
-
- signedHeaders := "content-type;host;x-amz-date"
- canonicalHeaders := fmt.Sprintf("content-type:%s\nhost:%s\nx-amz-date:%s\n",
- req.Header.Get("Content-Type"),
- req.Header.Get("Host"),
- amzDate,
- )
-
- if creds.SessionToken != "" {
- canonicalHeaders += fmt.Sprintf("x-amz-security-token:%s\n", creds.SessionToken)
- signedHeaders += ";x-amz-security-token"
- }
-
- hash := sha256.Sum256(body)
- payloadHash := hex.EncodeToString(hash[:])
-
- canonicalRequest := fmt.Sprintf("%s\n%s\n\n%s\n%s\n%s",
- method,
- canonicalURI,
- canonicalHeaders,
- signedHeaders,
- payloadHash,
- )
-
- credentialScope := fmt.Sprintf("%s/%s/%s/aws4_request", dateStamp, "us-east-1", "sts")
-
- hash2 := sha256.Sum256([]byte(canonicalRequest))
- hashedCanonicalRequest := hex.EncodeToString(hash2[:])
-
- stringToSign := fmt.Sprintf("AWS4-HMAC-SHA256\n%s\n%s\n%s",
- amzDate,
- credentialScope,
- hashedCanonicalRequest,
- )
-
- // Derive signing key
- kDate := hmacSHA256([]byte("AWS4"+creds.SecretAccessKey), dateStamp)
- kRegion := hmacSHA256(kDate, "us-east-1")
- kService := hmacSHA256(kRegion, "sts")
- kSigning := hmacSHA256(kService, "aws4_request")
-
- signature := hmacSHA256(kSigning, stringToSign)
- signatureHex := hex.EncodeToString(signature)
-
- authorization := fmt.Sprintf("AWS4-HMAC-SHA256 Credential=%s/%s, SignedHeaders=%s, Signature=%s",
- creds.AccessKeyID,
- credentialScope,
- signedHeaders,
- signatureHex,
- )
-
- req.Header.Set("Authorization", authorization)
-
- return req
- }
// hmacSHA256 returns the 32-byte HMAC-SHA256 of data keyed with key.
func hmacSHA256(key []byte, data string) []byte {
	mac := hmac.New(sha256.New, key)
	// hash.Hash documents that Write never returns an error.
	_, _ = mac.Write([]byte(data))
	digest := make([]byte, 0, sha256.Size)
	return mac.Sum(digest)
}
|