1
0

main_test.go 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214
  1. package main
  2. import (
  3. "bytes"
  4. "context"
  5. "crypto/hmac"
  6. "crypto/sha256"
  7. "encoding/hex"
  8. "encoding/xml"
  9. "fmt"
  10. "net/http"
  11. "net/http/httptest"
  12. "testing"
  13. "time"
  14. "aws-sts-mock/pkg/sigv4"
  15. "aws-sts-mock/pkg/sts"
  16. "github.com/stretchr/testify/assert"
  17. "github.com/stretchr/testify/require"
  18. )
  19. func TestGetCallerIdentityEndpoint(t *testing.T) {
  20. tests := []struct {
  21. name string
  22. setupRequest func() *http.Request
  23. expectError bool
  24. expectedStatus int
  25. validateResp func(*testing.T, *sts.GetCallerIdentityResponse)
  26. }{
  27. {
  28. name: "Valid GetCallerIdentity request",
  29. setupRequest: func() *http.Request {
  30. body := []byte("Action=GetCallerIdentity&Version=2011-06-15")
  31. creds := sigv4.AWSCredentials{
  32. AccessKeyID: "AKIAIOSFODNN7EXAMPLE",
  33. SecretAccessKey: "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY",
  34. AccountID: "123456789012",
  35. }
  36. req := createTestSignedRequest(t, "POST", "/", body, creds)
  37. return req
  38. },
  39. expectError: false,
  40. expectedStatus: http.StatusOK,
  41. validateResp: func(t *testing.T, resp *sts.GetCallerIdentityResponse) {
  42. assert.Equal(t, "123456789012", resp.GetCallerIdentityResult.Account)
  43. assert.Equal(t, "123456789012", resp.GetCallerIdentityResult.UserId)
  44. assert.Equal(t, "arn:aws:iam::123456789012:root", resp.GetCallerIdentityResult.Arn)
  45. assert.NotEmpty(t, resp.ResponseMetadata.RequestId)
  46. },
  47. },
  48. {
  49. name: "Invalid action",
  50. setupRequest: func() *http.Request {
  51. body := []byte("Action=InvalidAction&Version=2011-06-15")
  52. creds := sigv4.AWSCredentials{
  53. AccessKeyID: "AKIAIOSFODNN7EXAMPLE",
  54. SecretAccessKey: "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY",
  55. AccountID: "123456789012",
  56. }
  57. req := createTestSignedRequest(t, "POST", "/", body, creds)
  58. return req
  59. },
  60. expectError: true,
  61. expectedStatus: http.StatusBadRequest,
  62. },
  63. {
  64. name: "Invalid version",
  65. setupRequest: func() *http.Request {
  66. body := []byte("Action=GetCallerIdentity&Version=2099-01-01")
  67. creds := sigv4.AWSCredentials{
  68. AccessKeyID: "AKIAIOSFODNN7EXAMPLE",
  69. SecretAccessKey: "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY",
  70. AccountID: "123456789012",
  71. }
  72. req := createTestSignedRequest(t, "POST", "/", body, creds)
  73. return req
  74. },
  75. expectError: true,
  76. expectedStatus: http.StatusBadRequest,
  77. },
  78. }
  79. for _, tt := range tests {
  80. t.Run(tt.name, func(t *testing.T) {
  81. // Create test server with middleware
  82. handler := sigv4.ValidateSigV4Middleware(handleSTSRequest)
  83. req := tt.setupRequest()
  84. rr := httptest.NewRecorder()
  85. handler.ServeHTTP(rr, req)
  86. assert.Equal(t, tt.expectedStatus, rr.Code)
  87. if !tt.expectError && tt.validateResp != nil {
  88. var response sts.GetCallerIdentityResponse
  89. err := xml.Unmarshal(rr.Body.Bytes(), &response)
  90. require.NoError(t, err)
  91. tt.validateResp(t, &response)
  92. }
  93. })
  94. }
  95. }
  96. func TestHealthEndpoint(t *testing.T) {
  97. req := httptest.NewRequest("GET", "/health", nil)
  98. rr := httptest.NewRecorder()
  99. handleHealth(rr, req)
  100. assert.Equal(t, http.StatusOK, rr.Code)
  101. assert.Contains(t, rr.Body.String(), "healthy")
  102. }
  103. func TestErrorResponse(t *testing.T) {
  104. rr := httptest.NewRecorder()
  105. writeErrorResponse(rr, "TestError", "This is a test error")
  106. assert.Equal(t, http.StatusBadRequest, rr.Code)
  107. assert.Contains(t, rr.Body.String(), "TestError")
  108. assert.Contains(t, rr.Body.String(), "This is a test error")
  109. var errorResp sts.ErrorResponse
  110. err := xml.Unmarshal(rr.Body.Bytes(), &errorResp)
  111. require.NoError(t, err)
  112. assert.Equal(t, "TestError", errorResp.Error.Code)
  113. }
  114. // Helper function to create test signed requests
  115. func createTestSignedRequest(t *testing.T, method, path string, body []byte, creds sigv4.AWSCredentials) *http.Request {
  116. req := httptest.NewRequest(method, path, bytes.NewReader(body))
  117. // Set credentials in context (simulating successful auth)
  118. ctx := context.WithValue(req.Context(), sigv4.CredentialsContextKey, creds)
  119. req = req.WithContext(ctx)
  120. now := time.Now().UTC()
  121. dateStamp := now.Format("20060102")
  122. amzDate := now.Format("20060102T150405Z")
  123. req.Header.Set("Content-Type", "application/x-www-form-urlencoded; charset=utf-8")
  124. req.Header.Set("Host", "sts.us-east-1.amazonaws.com")
  125. req.Header.Set("X-Amz-Date", amzDate)
  126. if creds.SessionToken != "" {
  127. req.Header.Set("X-Amz-Security-Token", creds.SessionToken)
  128. }
  129. canonicalURI := path
  130. if canonicalURI == "" {
  131. canonicalURI = "/"
  132. }
  133. signedHeaders := "content-type;host;x-amz-date"
  134. canonicalHeaders := fmt.Sprintf("content-type:%s\nhost:%s\nx-amz-date:%s\n",
  135. req.Header.Get("Content-Type"),
  136. req.Header.Get("Host"),
  137. amzDate,
  138. )
  139. if creds.SessionToken != "" {
  140. canonicalHeaders += fmt.Sprintf("x-amz-security-token:%s\n", creds.SessionToken)
  141. signedHeaders += ";x-amz-security-token"
  142. }
  143. hash := sha256.Sum256(body)
  144. payloadHash := hex.EncodeToString(hash[:])
  145. canonicalRequest := fmt.Sprintf("%s\n%s\n\n%s\n%s\n%s",
  146. method,
  147. canonicalURI,
  148. canonicalHeaders,
  149. signedHeaders,
  150. payloadHash,
  151. )
  152. credentialScope := fmt.Sprintf("%s/%s/%s/aws4_request", dateStamp, "us-east-1", "sts")
  153. hash2 := sha256.Sum256([]byte(canonicalRequest))
  154. hashedCanonicalRequest := hex.EncodeToString(hash2[:])
  155. stringToSign := fmt.Sprintf("AWS4-HMAC-SHA256\n%s\n%s\n%s",
  156. amzDate,
  157. credentialScope,
  158. hashedCanonicalRequest,
  159. )
  160. // Derive signing key
  161. kDate := hmacSHA256([]byte("AWS4"+creds.SecretAccessKey), dateStamp)
  162. kRegion := hmacSHA256(kDate, "us-east-1")
  163. kService := hmacSHA256(kRegion, "sts")
  164. kSigning := hmacSHA256(kService, "aws4_request")
  165. signature := hmacSHA256(kSigning, stringToSign)
  166. signatureHex := hex.EncodeToString(signature)
  167. authorization := fmt.Sprintf("AWS4-HMAC-SHA256 Credential=%s/%s, SignedHeaders=%s, Signature=%s",
  168. creds.AccessKeyID,
  169. credentialScope,
  170. signedHeaders,
  171. signatureHex,
  172. )
  173. req.Header.Set("Authorization", authorization)
  174. return req
  175. }
  176. func hmacSHA256(key []byte, data string) []byte {
  177. h := hmac.New(sha256.New, key)
  178. h.Write([]byte(data))
  179. return h.Sum(nil)
  180. }