123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318 |
- package sigv4
- import (
- "bytes"
- "context"
- "crypto/hmac"
- "crypto/sha256"
- "encoding/hex"
- "fmt"
- "io"
- "net/http"
- "net/http/httptest"
- "testing"
- "time"
- "github.com/stretchr/testify/assert"
- "github.com/stretchr/testify/require"
- )
- func TestValidateSigV4Middleware(t *testing.T) {
- tests := []struct {
- name string
- setupRequest func() *http.Request
- expectError bool
- expectedStatus int
- expectedCode string
- }{
- {
- name: "Valid signature",
- setupRequest: func() *http.Request {
- body := []byte("Action=GetCallerIdentity&Version=2011-06-15")
- req := createSignedRequest(t, "POST", "/", body, mockCredentials["AKIAIOSFODNN7EXAMPLE"])
- return req
- },
- expectError: false,
- expectedStatus: http.StatusOK,
- },
- {
- name: "Missing authorization header",
- setupRequest: func() *http.Request {
- req := httptest.NewRequest("POST", "/", bytes.NewReader([]byte("test")))
- req.Header.Set("X-Amz-Date", time.Now().UTC().Format("20060102T150405Z"))
- return req
- },
- expectError: true,
- expectedStatus: http.StatusForbidden,
- expectedCode: "MissingAuthenticationToken",
- },
- {
- name: "Invalid access key",
- setupRequest: func() *http.Request {
- body := []byte("Action=GetCallerIdentity&Version=2011-06-15")
- invalidCreds := AWSCredentials{
- AccessKeyID: "INVALID_KEY",
- SecretAccessKey: "invalid-secret",
- AccountID: "123456789012",
- }
- req := createSignedRequest(t, "POST", "/", body, invalidCreds)
- return req
- },
- expectError: true,
- expectedStatus: http.StatusForbidden,
- expectedCode: "InvalidClientTokenId",
- },
- {
- name: "Expired request",
- setupRequest: func() *http.Request {
- body := []byte("Action=GetCallerIdentity&Version=2011-06-15")
- req := createSignedRequest(t, "POST", "/", body, mockCredentials["AKIAIOSFODNN7EXAMPLE"])
- // Set date to 20 minutes ago
- expiredDate := time.Now().UTC().Add(-20 * time.Minute).Format("20060102T150405Z")
- req.Header.Set("X-Amz-Date", expiredDate)
- return req
- },
- expectError: true,
- expectedStatus: http.StatusForbidden,
- expectedCode: "RequestExpired",
- },
- {
- name: "Invalid signature",
- setupRequest: func() *http.Request {
- body := []byte("Action=GetCallerIdentity&Version=2011-06-15")
- req := createSignedRequest(t, "POST", "/", body, mockCredentials["AKIAIOSFODNN7EXAMPLE"])
- // Tamper with signature
- authHeader := req.Header.Get("Authorization")
- authHeader = authHeader[:len(authHeader)-10] + "0000000000"
- req.Header.Set("Authorization", authHeader)
- return req
- },
- expectError: true,
- expectedStatus: http.StatusForbidden,
- expectedCode: "SignatureDoesNotMatch",
- },
- {
- name: "Missing X-Amz-Date header",
- setupRequest: func() *http.Request {
- body := []byte("Action=GetCallerIdentity&Version=2011-06-15")
- req := createSignedRequest(t, "POST", "/", body, mockCredentials["AKIAIOSFODNN7EXAMPLE"])
- req.Header.Del("X-Amz-Date")
- return req
- },
- expectError: true,
- expectedStatus: http.StatusForbidden,
- expectedCode: "InvalidRequest",
- },
- }
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- // Create test handler
- testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- w.WriteHeader(http.StatusOK)
- w.Write([]byte("OK"))
- })
- // Wrap with middleware
- handler := ValidateSigV4Middleware(testHandler)
- // Create test request and recorder
- req := tt.setupRequest()
- rr := httptest.NewRecorder()
- // Execute request
- handler.ServeHTTP(rr, req)
- // Assert status code
- assert.Equal(t, tt.expectedStatus, rr.Code, "Status code mismatch")
- if tt.expectError {
- body := rr.Body.String()
- assert.Contains(t, body, tt.expectedCode, "Expected error code not found in response")
- }
- })
- }
- }
- func TestParseSigV4Header(t *testing.T) {
- tests := []struct {
- name string
- authHeader string
- expectError bool
- validate func(*testing.T, *sigV4Components)
- }{
- {
- name: "Valid SigV4 header",
- authHeader: "AWS4-HMAC-SHA256 Credential=AKIAIOSFODNN7EXAMPLE/20231016/us-east-1/sts/aws4_request, " +
- "SignedHeaders=content-type;host;x-amz-date, Signature=abcd1234",
- expectError: false,
- validate: func(t *testing.T, sig *sigV4Components) {
- assert.Equal(t, "AKIAIOSFODNN7EXAMPLE", sig.AccessKeyID)
- assert.Equal(t, "20231016", sig.Date)
- assert.Equal(t, "us-east-1", sig.Region)
- assert.Equal(t, "sts", sig.Service)
- assert.Equal(t, "abcd1234", sig.Signature)
- assert.Equal(t, []string{"content-type", "host", "x-amz-date"}, sig.SignedHeaders)
- },
- },
- {
- name: "Missing algorithm prefix",
- authHeader: "Credential=AKIAIOSFODNN7EXAMPLE/20231016/us-east-1/sts/aws4_request",
- expectError: true,
- },
- {
- name: "Missing credential",
- authHeader: "AWS4-HMAC-SHA256 SignedHeaders=host, Signature=abcd1234",
- expectError: true,
- },
- {
- name: "Missing signature",
- authHeader: "AWS4-HMAC-SHA256 Credential=AKIAIOSFODNN7EXAMPLE/20231016/us-east-1/sts/aws4_request",
- expectError: true,
- },
- }
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- sig, err := parseSigV4Header(tt.authHeader)
- if tt.expectError {
- assert.Error(t, err)
- } else {
- require.NoError(t, err)
- if tt.validate != nil {
- tt.validate(t, sig)
- }
- }
- })
- }
- }
- func TestCalculateSignature(t *testing.T) {
- creds := mockCredentials["AKIAIOSFODNN7EXAMPLE"]
- body := []byte("Action=GetCallerIdentity&Version=2011-06-15")
- req := httptest.NewRequest("POST", "/", bytes.NewReader(body))
- req.Header.Set("Content-Type", "application/x-www-form-urlencoded; charset=utf-8")
- req.Header.Set("Host", "sts.us-east-1.amazonaws.com")
- req.Header.Set("X-Amz-Date", "20231016T120000Z")
- sigV4 := &sigV4Components{
- AccessKeyID: creds.AccessKeyID,
- Date: "20231016",
- Region: "us-east-1",
- Service: "sts",
- CredentialScope: "20231016/us-east-1/sts/aws4_request",
- SignedHeaders: []string{"content-type", "host", "x-amz-date"},
- }
- signature, err := calculateSignature(req, body, creds, sigV4)
- require.NoError(t, err)
- assert.NotEmpty(t, signature)
- assert.Len(t, signature, 64) // SHA256 hex string length
- }
- func TestHMACFunctions(t *testing.T) {
- t.Run("hmacSHA256", func(t *testing.T) {
- key := []byte("test-key")
- data := "test-data"
- result := hmacSHA256(key, data)
- assert.NotNil(t, result)
- assert.Greater(t, len(result), 0)
- })
- t.Run("sha256Hash", func(t *testing.T) {
- data := []byte("test-data")
- result := sha256Hash(data)
- assert.NotEmpty(t, result)
- assert.Len(t, result, 64) // Hex encoded SHA256
- })
- t.Run("deriveSigningKey", func(t *testing.T) {
- secretKey := "test-secret"
- date := "20231016"
- region := "us-east-1"
- service := "sts"
- signingKey := deriveSigningKey(secretKey, date, region, service)
- assert.NotNil(t, signingKey)
- assert.Len(t, signingKey, 32) // SHA256 output length
- })
- }
- // Helper function to create a properly signed request
- func createSignedRequest(t *testing.T, method, path string, body []byte, creds AWSCredentials) *http.Request {
- req := httptest.NewRequest(method, path, bytes.NewReader(body))
-
- now := time.Now().UTC()
- dateStamp := now.Format("20060102")
- amzDate := now.Format("20060102T150405Z")
-
- req.Header.Set("Content-Type", "application/x-www-form-urlencoded; charset=utf-8")
- req.Header.Set("Host", "sts.us-east-1.amazonaws.com")
- req.Header.Set("X-Amz-Date", amzDate)
-
- if creds.SessionToken != "" {
- req.Header.Set("X-Amz-Security-Token", creds.SessionToken)
- }
- // Calculate canonical request
- canonicalURI := path
- if canonicalURI == "" {
- canonicalURI = "/"
- }
-
- canonicalHeaders := fmt.Sprintf("content-type:%s\nhost:%s\nx-amz-date:%s\n",
- req.Header.Get("Content-Type"),
- req.Header.Get("Host"),
- amzDate,
- )
-
- signedHeaders := "content-type;host;x-amz-date"
- if creds.SessionToken != "" {
- canonicalHeaders += fmt.Sprintf("x-amz-security-token:%s\n", creds.SessionToken)
- signedHeaders += ";x-amz-security-token"
- }
-
- payloadHash := sha256Hash(body)
-
- canonicalRequest := fmt.Sprintf("%s\n%s\n\n%s\n%s\n%s",
- method,
- canonicalURI,
- canonicalHeaders,
- signedHeaders,
- payloadHash,
- )
-
- // Create string to sign
- credentialScope := fmt.Sprintf("%s/%s/%s/aws4_request", dateStamp, "us-east-1", "sts")
- hashedCanonicalRequest := sha256Hash([]byte(canonicalRequest))
-
- stringToSign := fmt.Sprintf("AWS4-HMAC-SHA256\n%s\n%s\n%s",
- amzDate,
- credentialScope,
- hashedCanonicalRequest,
- )
-
- // Calculate signature
- signingKey := deriveSigningKey(creds.SecretAccessKey, dateStamp, "us-east-1", "sts")
- signature := hmacSHA256(signingKey, stringToSign)
- signatureHex := hex.EncodeToString(signature)
-
- // Build authorization header
- authorization := fmt.Sprintf("AWS4-HMAC-SHA256 Credential=%s/%s, SignedHeaders=%s, Signature=%s",
- creds.AccessKeyID,
- credentialScope,
- signedHeaders,
- signatureHex,
- )
-
- req.Header.Set("Authorization", authorization)
-
- return req
- }
// sha256Hash returns the lowercase hexadecimal SHA-256 digest of data.
func sha256Hash(data []byte) string {
	digest := sha256.New()
	digest.Write(data) // hash.Hash.Write never returns an error
	return hex.EncodeToString(digest.Sum(nil))
}
|