1
0

sigv4.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400
  1. package sigv4
  2. import (
  3. "bytes"
  4. "context"
  5. "crypto/hmac"
  6. "crypto/sha256"
  7. "encoding/hex"
  8. "fmt"
  9. "io"
  10. "log"
  11. "net/http"
  12. "sort"
  13. "strings"
  14. "time"
  15. )
  // contextKey is an unexported type for this package's context keys. A key of
  // a distinct unexported type cannot compare equal to keys created by other
  // packages, even if their underlying strings match.
  type contextKey string

  // CredentialsContextKey is the context key under which ValidateSigV4Middleware
  // stores the AWSCredentials of an authenticated request.
  const CredentialsContextKey = contextKey("credentials")
  // AWSCredentials is one identity known to the SigV4 validator.
  type AWSCredentials struct {
  	// AccessKeyID identifies the credentials; requests name it as the first
  	// element of the Authorization header's Credential= parameter.
  	AccessKeyID string
  	// SecretAccessKey is the secret the SigV4 signing key is derived from.
  	SecretAccessKey string
  	// SessionToken, when non-empty, must equal the request's
  	// X-Amz-Security-Token header.
  	SessionToken string
  	// AccountID is the account number associated with these credentials.
  	AccountID string
  }

  // mockCredentials is the in-memory credential store, keyed by access key ID.
  // It is a stand-in: a production deployment would look credentials up in a
  // database or an external service instead.
  //
  // The map is built from a list so that every key is, by construction, the
  // AccessKeyID of the entry it maps to.
  var mockCredentials = func() map[string]AWSCredentials {
  	entries := []AWSCredentials{
  		{
  			AccessKeyID:     "AKIAIOSFODNN7EXAMPLE",
  			SecretAccessKey: "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY",
  			AccountID:       "123456789012",
  		},
  		{
  			AccessKeyID:     "ASIAUIJXACK3L66H7KB4",
  			SecretAccessKey: "test-secret-key",
  			SessionToken:    "test-session-token",
  			AccountID:       "292709995190",
  		},
  	}
  	store := make(map[string]AWSCredentials, len(entries))
  	for _, entry := range entries {
  		store[entry.AccessKeyID] = entry
  	}
  	return store
  }()
  40. // ValidateSigV4Middleware validates AWS Signature Version 4
  41. func ValidateSigV4Middleware(next http.HandlerFunc) http.HandlerFunc {
  42. testHMAC()
  43. return func(w http.ResponseWriter, r *http.Request) {
  44. // Read the body
  45. bodyBytes, err := io.ReadAll(r.Body)
  46. if err != nil {
  47. writeAuthError(w, "InvalidRequest", "Failed to read request body")
  48. return
  49. }
  50. r.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
  51. // Parse authorization header
  52. authHeader := r.Header.Get("Authorization")
  53. if authHeader == "" {
  54. writeAuthError(w, "MissingAuthenticationToken", "Authorization header cannot be empty")
  55. return
  56. }
  57. // Parse SigV4 components
  58. sigV4, err := parseSigV4Header(authHeader)
  59. if err != nil {
  60. writeAuthError(w, "IncompleteSignature", err.Error())
  61. return
  62. }
  63. log.Println(sigV4)
  64. log.Println(sigV4.AccessKeyID)
  65. // Validate credential
  66. creds, ok := mockCredentials[sigV4.AccessKeyID]
  67. if !ok {
  68. writeAuthError(w, "InvalidClientTokenId", "The security token included in the request is invalid 1")
  69. return
  70. }
  71. // Validate date
  72. dateHeader := r.Header.Get("X-Amz-Date")
  73. if dateHeader == "" {
  74. writeAuthError(w, "InvalidRequest", "X-Amz-Date header is required")
  75. return
  76. }
  77. reqTime, err := time.Parse("20060102T150405Z", dateHeader)
  78. if err != nil {
  79. writeAuthError(w, "InvalidRequest", "Invalid X-Amz-Date format")
  80. return
  81. }
  82. // Check if request is within 15 minutes
  83. now := time.Now().UTC()
  84. if now.Sub(reqTime) > 15*time.Minute {
  85. writeAuthError(w, "RequestExpired", "Request has expired")
  86. return
  87. }
  88. if reqTime.Sub(now) > 15*time.Minute {
  89. writeAuthError(w, "SignatureDoesNotMatch", "Signature not yet current")
  90. return
  91. }
  92. // Calculate expected signature
  93. expectedSig, err := calculateSignature(r, bodyBytes, creds, sigV4)
  94. if err != nil {
  95. writeAuthError(w, "InternalError", fmt.Sprintf("Failed to calculate signature: %v", err))
  96. return
  97. }
  98. // Compare signatures
  99. log.Println(expectedSig + " " + sigV4.Signature)
  100. if expectedSig != sigV4.Signature {
  101. //writeAuthError(w, "SignatureDoesNotMatch",
  102. //"The request signature we calculated does not match the signature you provided")
  103. //return
  104. }
  105. // Validate session token if present
  106. if creds.SessionToken != "" {
  107. reqToken := r.Header.Get("X-Amz-Security-Token")
  108. if reqToken != creds.SessionToken {
  109. writeAuthError(w, "InvalidClientTokenId", "The security token included in the request is invalid 2")
  110. return
  111. }
  112. }
  113. // Add credentials to context
  114. ctx := context.WithValue(r.Context(), CredentialsContextKey, creds)
  115. next.ServeHTTP(w, r.WithContext(ctx))
  116. }
  117. }
  // sigV4Components holds the fields parsed from a SigV4 Authorization header.
  type sigV4Components struct {
  	AccessKeyID     string   // first '/'-separated element of Credential=
  	CredentialScope string   // remaining elements "<date>/<region>/<service>/<terminator>", used verbatim in the string to sign
  	SignedHeaders   []string // SignedHeaders= split on ';', in the order sent
  	Signature       string   // value of Signature=
  	Date            string   // date element of the credential scope
  	Region          string   // region element of the credential scope
  	Service         string   // service element of the credential scope
  }

  // parseSigV4Header parses an Authorization header of the form
  //
  //	AWS4-HMAC-SHA256 Credential=AKID/20150830/us-east-1/sts/aws4_request, SignedHeaders=host;x-amz-date, Signature=...
  //
  // Parameters may be separated by "," with or without surrounding whitespace.
  // Unknown parameters are ignored, and double quotes are stripped from the access
  // key ID. An error is returned when the algorithm prefix is missing, a
  // parameter is not key=value, Credential does not have exactly five
  // '/'-separated elements, or Credential, Signature or SignedHeaders is absent
  // or empty.
  func parseSigV4Header(authHeader string) (*sigV4Components, error) {
  	const algorithmPrefix = "AWS4-HMAC-SHA256 "
  	if !strings.HasPrefix(authHeader, algorithmPrefix) {
  		return nil, fmt.Errorf("Authorization header must start with 'AWS4-HMAC-SHA256'")
  	}
  	sig := &sigV4Components{}
  	// Split on "," alone and trim each piece, so both ", " (the documented
  	// separator) and a bare "," are accepted.
  	for _, part := range strings.Split(strings.TrimPrefix(authHeader, algorithmPrefix), ",") {
  		kv := strings.SplitN(part, "=", 2)
  		if len(kv) != 2 {
  			return nil, fmt.Errorf("Invalid key=value pair in Authorization header")
  		}
  		key := strings.TrimSpace(kv[0])
  		value := strings.TrimSpace(kv[1])
  		switch key {
  		case "Credential":
  			credParts := strings.Split(value, "/")
  			if len(credParts) != 5 {
  				return nil, fmt.Errorf("Invalid credential format")
  			}
  			sig.AccessKeyID = strings.ReplaceAll(credParts[0], "\"", "")
  			sig.Date = credParts[1]
  			sig.Region = credParts[2]
  			sig.Service = credParts[3]
  			sig.CredentialScope = strings.Join(credParts[1:], "/")
  		case "SignedHeaders":
  			// An empty value would otherwise split into [""] and look present.
  			if value != "" {
  				sig.SignedHeaders = strings.Split(value, ";")
  			}
  		case "Signature":
  			sig.Signature = value
  		}
  	}
  	if sig.AccessKeyID == "" {
  		return nil, fmt.Errorf("Authorization header requires 'Credential' parameter")
  	}
  	if sig.Signature == "" {
  		return nil, fmt.Errorf("Authorization header requires 'Signature' parameter")
  	}
  	if len(sig.SignedHeaders) == 0 {
  		return nil, fmt.Errorf("Authorization header requires 'SignedHeaders' parameter")
  	}
  	return sig, nil
  }
  166. func testHMAC() {
  167. fmt.Println("=== Testing with quotes in secret key ===")
  168. // The secret key WITH quotes as it appears in environment variable
  169. secretKeyWithQuotes := "\"wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY\""
  170. kDate := hmacSHA256([]byte("AWS4"+secretKeyWithQuotes), "20251016")
  171. fmt.Println("kDate with quotes:", hex.EncodeToString(kDate))
  172. kRegion := hmacSHA256(kDate, "us-west-2")
  173. kService := hmacSHA256(kRegion, "sts")
  174. kSigning := hmacSHA256(kService, "aws4_request")
  175. fmt.Println("Signing key with quotes:", hex.EncodeToString(kSigning))
  176. stringToSign := "AWS4-HMAC-SHA256\n20251016T185909Z\n20251016/us-west-2/sts/aws4_request\nb1eade88f71175b2790108f7015d7ded65f982153aad655da6d50d5976ae14df"
  177. signature := hmacSHA256(kSigning, stringToSign)
  178. fmt.Println("Our signature: ", hex.EncodeToString(signature))
  179. fmt.Println("AWS CLI's signature: ", "c434c79f1567cf160a5cb88d2ad710a5bc524bbb6d481ce51b4b63787f9c21a4")
  180. fmt.Println("Match:", hex.EncodeToString(signature) == "c434c79f1567cf160a5cb88d2ad710a5bc524bbb6d481ce51b4b63787f9c21a4")
  181. }
  182. func calculateSignature(r *http.Request, body []byte, creds AWSCredentials, sigV4 *sigV4Components) (string, error) {
  183. amzDate := r.Header.Get("X-Amz-Date")
  184. canonicalRequest := buildCanonicalRequest(r, body, sigV4.SignedHeaders)
  185. stringToSign := buildStringToSign(canonicalRequest, sigV4, amzDate)
  186. signingKey := deriveSigningKey(creds.SecretAccessKey, sigV4.Date, sigV4.Region, sigV4.Service)
  187. log.Println("=== Final HMAC Input ===")
  188. log.Printf("Key (hex): %s\n", hex.EncodeToString(signingKey))
  189. log.Printf("Data (bytes): %v\n", []byte(stringToSign))
  190. log.Printf("Data (hex): %s\n", hex.EncodeToString([]byte(stringToSign)))
  191. signature := hmacSHA256(signingKey, stringToSign)
  192. finalSig := hex.EncodeToString(signature)
  193. log.Printf("Signature (hex): %s\n", finalSig)
  194. log.Println("=== End Final HMAC Input ===")
  195. return finalSig, nil
  196. }
  197. func buildCanonicalRequest(r *http.Request, body []byte, signedHeaders []string) string {
  198. // Method
  199. method := r.Method
  200. // Canonical URI
  201. canonicalURI := r.URL.Path
  202. if canonicalURI == "" {
  203. canonicalURI = "/"
  204. }
  205. // Canonical query string - MUST BE SORTED
  206. canonicalQueryString := buildCanonicalQueryString(r)
  207. // Canonical headers (already includes trailing newlines)
  208. canonicalHeaders := buildCanonicalHeaders(r, signedHeaders)
  209. // Signed headers
  210. signedHeadersStr := strings.Join(signedHeaders, ";")
  211. // Payload hash
  212. payloadHash := sha256Hash(body)
  213. canonicalRequest := fmt.Sprintf("%s\n%s\n%s\n%s\n%s\n%s",
  214. method,
  215. canonicalURI,
  216. canonicalQueryString,
  217. canonicalHeaders,
  218. signedHeadersStr,
  219. payloadHash,
  220. )
  221. return canonicalRequest
  222. }
  // buildCanonicalQueryString returns the SigV4 canonical query string for r.
  //
  // How r.URL.RawQuery is processed:
  //   - it is split on '&', and empty segments (from "&&" or a trailing '&')
  //     are dropped;
  //   - each parameter is split into name and value at its first '=', and a
  //     parameter without '=' gets an empty value;
  //   - parameters are sorted by name, then by value;
  //   - they are re-joined as "name=value" pairs separated by '&'.
  //
  // NOTE(review): names and values are used exactly as the client encoded them,
  // not re-encoded. This assumes the client already sent RFC 3986 encoding
  // (e.g. %20 rather than '+' for a space) — confirm for non-SDK clients.
  func buildCanonicalQueryString(r *http.Request) string {
  	if r.URL.RawQuery == "" {
  		return ""
  	}
  	type param struct{ name, value string }
  	segments := strings.Split(r.URL.RawQuery, "&")
  	params := make([]param, 0, len(segments))
  	for _, seg := range segments {
  		if seg == "" {
  			continue
  		}
  		p := param{name: seg}
  		if i := strings.IndexByte(seg, '='); i >= 0 {
  			p.name, p.value = seg[:i], seg[i+1:]
  		}
  		params = append(params, p)
  	}
  	// Sorting whole "name=value" strings would misorder names that are prefixes
  	// of one another ("a-b=1" sorts before "a=2" because '-' < '='), so compare
  	// names first and fall back to values only for duplicate names.
  	sort.Slice(params, func(i, j int) bool {
  		if params[i].name != params[j].name {
  			return params[i].name < params[j].name
  		}
  		return params[i].value < params[j].value
  	})
  	var b strings.Builder
  	for i, p := range params {
  		if i > 0 {
  			b.WriteByte('&')
  		}
  		b.WriteString(p.name)
  		b.WriteByte('=')
  		b.WriteString(p.value)
  	}
  	return b.String()
  }
  // buildCanonicalHeaders returns the SigV4 canonical headers block for the
  // headers named in signedHeaders. The block has one "name:value\n" line per
  // header, with names lowercased and sorted.
  //
  // "host" is taken from r.Host, because net/http promotes the Host header of an
  // incoming request to r.Host and removes it from r.Header. Other headers are
  // read from r.Header:
  //   - each value has leading and trailing whitespace trimmed;
  //   - runs of spaces inside a value are collapsed to one, as SigV4 requires;
  //   - multiple values are joined with ','.
  //
  // A signed header that is absent from the request is left out of the block.
  //
  // Values are deliberately not logged: the signed set may include
  // X-Amz-Security-Token.
  func buildCanonicalHeaders(r *http.Request, signedHeaders []string) string {
  	headers := make(map[string]string, len(signedHeaders))
  	for _, header := range signedHeaders {
  		name := strings.ToLower(strings.TrimSpace(header))
  		if name == "host" {
  			headers[name] = r.Host
  			continue
  		}
  		values := r.Header[http.CanonicalHeaderKey(name)]
  		if len(values) == 0 {
  			continue
  		}
  		normalized := make([]string, len(values))
  		for i, v := range values {
  			v = strings.TrimSpace(v)
  			for strings.Contains(v, "  ") {
  				v = strings.ReplaceAll(v, "  ", " ")
  			}
  			normalized[i] = v
  		}
  		headers[name] = strings.Join(normalized, ",")
  	}

  	names := make([]string, 0, len(headers))
  	for name := range headers {
  		names = append(names, name)
  	}
  	sort.Strings(names)

  	var b strings.Builder
  	for _, name := range names {
  		b.WriteString(name)
  		b.WriteByte(':')
  		b.WriteString(headers[name])
  		b.WriteByte('\n')
  	}
  	return b.String()
  }
  270. func buildStringToSign(canonicalRequest string, sigV4 *sigV4Components, amzDate string) string {
  271. log.Println("=== Canonical Request (raw) ===")
  272. log.Printf("%q\n", canonicalRequest) // Using %q to show escape characters
  273. log.Println("=== End Canonical Request (raw) ===")
  274. canonicalRequestHash := sha256Hash([]byte(canonicalRequest))
  275. log.Println("Canonical Request Hash:", canonicalRequestHash)
  276. stringToSign := fmt.Sprintf("AWS4-HMAC-SHA256\n%s\n%s\n%s",
  277. amzDate,
  278. sigV4.CredentialScope,
  279. canonicalRequestHash,
  280. )
  281. log.Println("=== String to Sign ===")
  282. log.Println(stringToSign)
  283. log.Println("=== End String to Sign ===")
  284. return stringToSign
  285. }
  286. func deriveSigningKey(secretKey, date, region, service string) []byte {
  287. log.Println("=== Deriving Signing Key ===")
  288. log.Println("Secret Key:", secretKey)
  289. log.Println("Date:", date)
  290. log.Println("Region:", region)
  291. log.Println("Service:", service)
  292. kDate := hmacSHA256([]byte("AWS4"+secretKey), date)
  293. log.Println("kDate:", hex.EncodeToString(kDate))
  294. kRegion := hmacSHA256(kDate, region)
  295. log.Println("kRegion:", hex.EncodeToString(kRegion))
  296. kService := hmacSHA256(kRegion, service)
  297. log.Println("kService:", hex.EncodeToString(kService))
  298. kSigning := hmacSHA256(kService, "aws4_request")
  299. log.Println("kSigning:", hex.EncodeToString(kSigning))
  300. log.Println("=== End Deriving Signing Key ===")
  301. return kSigning
  302. }
  // hmacSHA256 returns the raw 32-byte HMAC-SHA256 of data under key.
  func hmacSHA256(key []byte, data string) []byte {
  	mac := hmac.New(sha256.New, key)
  	// hash.Hash.Write never returns an error.
  	_, _ = mac.Write([]byte(data))
  	return mac.Sum(nil)
  }
  // sha256Hash returns the lowercase hex encoding of the SHA-256 digest of data.
  func sha256Hash(data []byte) string {
  	digest := sha256.New()
  	digest.Write(data) // hash.Hash.Write never returns an error
  	return hex.EncodeToString(digest.Sum(nil))
  }
  // writeAuthError writes an STS-style XML error document with HTTP status 403
  // Forbidden and Content-Type application/xml.
  //
  // code and message fill the <Code> and <Message> elements. They are XML-escaped
  // (&, <, >) first, so a message built from error text cannot break the
  // document. <RequestId> is the current Unix time in seconds.
  func writeAuthError(w http.ResponseWriter, code, message string) {
  	escape := strings.NewReplacer("&", "&amp;", "<", "&lt;", ">", "&gt;").Replace
  	w.Header().Set("Content-Type", "application/xml")
  	w.WriteHeader(http.StatusForbidden)
  	errorXML := fmt.Sprintf("<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n"+
  		"<ErrorResponse xmlns=\"https://sts.amazonaws.com/doc/2011-06-15/\">\n"+
  		"<Error>\n"+
  		"<Type>Sender</Type>\n"+
  		"<Code>%s</Code>\n"+
  		"<Message>%s</Message>\n"+
  		"</Error>\n"+
  		"<RequestId>%d</RequestId>\n"+
  		"</ErrorResponse>",
  		escape(code), escape(message), time.Now().Unix())
  	// Best effort: the status line has already been sent, so a failed body write
  	// cannot be reported to the client.
  	_, _ = w.Write([]byte(errorXML))
  }