sigv4.go 9.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376
  1. package sigv4
  2. import (
  3. "bytes"
  4. "context"
  5. "crypto/hmac"
  6. "crypto/sha256"
  7. "encoding/hex"
  8. "fmt"
  9. "io"
  10. "log"
  11. "net/http"
  12. "sort"
  13. "strings"
  14. "time"
  15. )
  16. type contextKey string
  17. const (
  18. CredentialsContextKey contextKey = "credentials"
  19. )
  20. type AWSCredentials struct {
  21. AccessKeyID string
  22. SecretAccessKey string
  23. SessionToken string
  24. AccountID string
  25. }
  26. // Mock credential store - in production, this would be a database or external service
  27. var mockCredentials = map[string]AWSCredentials{
  28. "AKIAIOSFODNN7EXAMPLE": {
  29. AccessKeyID: "AKIAIOSFODNN7EXAMPLE",
  30. SecretAccessKey: "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY",
  31. AccountID: "123456789012",
  32. },
  33. "ASIAUIJXACK3L66H7KB4": {
  34. AccessKeyID: "ASIAUIJXACK3L66H7KB4",
  35. SecretAccessKey: "test-secret-key",
  36. SessionToken: "test-session-token",
  37. AccountID: "292709995190",
  38. },
  39. }
  40. // ValidateSigV4Middleware validates AWS Signature Version 4
  41. func ValidateSigV4Middleware(next http.HandlerFunc) http.HandlerFunc {
  42. return func(w http.ResponseWriter, r *http.Request) {
  43. // Read the body
  44. bodyBytes, err := io.ReadAll(r.Body)
  45. if err != nil {
  46. writeAuthError(w, "InvalidRequest", "Failed to read request body")
  47. return
  48. }
  49. r.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
  50. // Parse authorization header
  51. authHeader := r.Header.Get("Authorization")
  52. if authHeader == "" {
  53. writeAuthError(w, "MissingAuthenticationToken", "Authorization header cannot be empty")
  54. return
  55. }
  56. // Parse SigV4 components
  57. sigV4, err := parseSigV4Header(authHeader)
  58. if err != nil {
  59. writeAuthError(w, "IncompleteSignature", err.Error())
  60. return
  61. }
  62. // Validate credential
  63. creds, ok := mockCredentials[sigV4.AccessKeyID]
  64. if !ok {
  65. writeAuthError(w, "InvalidClientTokenId", "The security token included in the request is invalid")
  66. return
  67. }
  68. // Validate date
  69. dateHeader := r.Header.Get("X-Amz-Date")
  70. if dateHeader == "" {
  71. writeAuthError(w, "InvalidRequest", "X-Amz-Date header is required")
  72. return
  73. }
  74. reqTime, err := time.Parse("20060102T150405Z", dateHeader)
  75. if err != nil {
  76. writeAuthError(w, "InvalidRequest", "Invalid X-Amz-Date format")
  77. return
  78. }
  79. // Check if request is within 15 minutes
  80. now := time.Now().UTC()
  81. if now.Sub(reqTime) > 15*time.Minute {
  82. writeAuthError(w, "RequestExpired", "Request has expired")
  83. return
  84. }
  85. if reqTime.Sub(now) > 15*time.Minute {
  86. writeAuthError(w, "SignatureDoesNotMatch", "Signature not yet current")
  87. return
  88. }
  89. // Calculate expected signature
  90. expectedSig, err := calculateSignature(r, bodyBytes, creds, sigV4)
  91. if err != nil {
  92. writeAuthError(w, "InternalError", fmt.Sprintf("Failed to calculate signature: %v", err))
  93. return
  94. }
  95. // Compare signatures
  96. //log.Printf("Expected signature: %s", expectedSig)
  97. //log.Printf("Provided signature: %s", sigV4.Signature)
  98. if expectedSig != sigV4.Signature {
  99. writeAuthError(w, "SignatureDoesNotMatch",
  100. "The request signature we calculated does not match the signature you provided")
  101. return
  102. }
  103. // Validate session token if present
  104. if creds.SessionToken != "" {
  105. reqToken := r.Header.Get("X-Amz-Security-Token")
  106. if reqToken != creds.SessionToken {
  107. writeAuthError(w, "InvalidClientTokenId", "The security token included in the request is invalid")
  108. return
  109. }
  110. }
  111. // Add credentials to context
  112. ctx := context.WithValue(r.Context(), CredentialsContextKey, creds)
  113. next.ServeHTTP(w, r.WithContext(ctx))
  114. }
  115. }
  116. type sigV4Components struct {
  117. AccessKeyID string
  118. CredentialScope string
  119. SignedHeaders []string
  120. Signature string
  121. Date string
  122. Region string
  123. Service string
  124. }
  125. func parseSigV4Header(authHeader string) (*sigV4Components, error) {
  126. if !strings.HasPrefix(authHeader, "AWS4-HMAC-SHA256 ") {
  127. return nil, fmt.Errorf("Authorization header must start with 'AWS4-HMAC-SHA256'")
  128. }
  129. authHeader = strings.TrimPrefix(authHeader, "AWS4-HMAC-SHA256 ")
  130. parts := strings.Split(authHeader, ", ")
  131. sig := &sigV4Components{}
  132. for _, part := range parts {
  133. kv := strings.SplitN(part, "=", 2)
  134. if len(kv) != 2 {
  135. return nil, fmt.Errorf("Invalid key=value pair in Authorization header")
  136. }
  137. key := strings.TrimSpace(kv[0])
  138. value := strings.TrimSpace(kv[1])
  139. switch key {
  140. case "Credential":
  141. credParts := strings.Split(value, "/")
  142. if len(credParts) != 5 {
  143. return nil, fmt.Errorf("Invalid credential format")
  144. }
  145. sig.AccessKeyID = strings.ReplaceAll(credParts[0], "\"", "")
  146. sig.Date = credParts[1]
  147. sig.Region = credParts[2]
  148. sig.Service = credParts[3]
  149. sig.CredentialScope = strings.Join(credParts[1:], "/")
  150. case "SignedHeaders":
  151. sig.SignedHeaders = strings.Split(value, ";")
  152. case "Signature":
  153. sig.Signature = value
  154. }
  155. }
  156. if sig.AccessKeyID == "" {
  157. return nil, fmt.Errorf("Authorization header requires 'Credential' parameter")
  158. }
  159. if sig.Signature == "" {
  160. return nil, fmt.Errorf("Authorization header requires 'Signature' parameter")
  161. }
  162. return sig, nil
  163. }
  164. func calculateSignature(r *http.Request, body []byte, creds AWSCredentials, sigV4 *sigV4Components) (string, error) {
  165. amzDate := r.Header.Get("X-Amz-Date")
  166. canonicalRequest := buildCanonicalRequest(r, body, sigV4.SignedHeaders)
  167. stringToSign := buildStringToSign(canonicalRequest, sigV4, amzDate)
  168. log.Printf("String to sign:\n%s", stringToSign)
  169. signingKey := deriveSigningKey(creds.SecretAccessKey, sigV4.Date, sigV4.Region, sigV4.Service)
  170. signature := hmacSHA256(signingKey, stringToSign)
  171. finalSig := hex.EncodeToString(signature)
  172. return finalSig, nil
  173. }
  174. func buildCanonicalRequest(r *http.Request, body []byte, signedHeaders []string) string {
  175. // Method
  176. method := r.Method
  177. // Canonical URI - Extract from RequestURI to get the path before any routing modifications
  178. // RequestURI includes the full path, e.g., "/my-test-bucket/4.psd?uploads"
  179. canonicalURI := r.URL.Path
  180. if canonicalURI == "" {
  181. canonicalURI = "/"
  182. }
  183. // Canonical query string - MUST BE SORTED
  184. canonicalQueryString := buildCanonicalQueryString(r)
  185. // Canonical headers (already includes trailing newlines)
  186. canonicalHeaders := buildCanonicalHeaders(r, signedHeaders)
  187. // Signed headers
  188. signedHeadersStr := strings.Join(signedHeaders, ";")
  189. // Payload hash
  190. payloadHash := sha256Hash(body)
  191. canonicalRequest := fmt.Sprintf("%s\n%s\n%s\n%s\n%s\n%s",
  192. method,
  193. canonicalURI,
  194. canonicalQueryString,
  195. canonicalHeaders,
  196. signedHeadersStr,
  197. payloadHash,
  198. )
  199. /*
  200. log.Printf("=== Canonical Request ===")
  201. log.Printf("Method: %s", method)
  202. log.Printf("Canonical URI: %s", canonicalURI)
  203. log.Printf("Canonical Query String: %s", canonicalQueryString)
  204. log.Printf("Canonical Headers:\n%s", canonicalHeaders)
  205. log.Printf("Signed Headers: %s", signedHeadersStr)
  206. log.Printf("Payload Hash: %s", payloadHash)
  207. log.Printf("Full Canonical Request:\n%s", canonicalRequest)
  208. log.Printf("========================")
  209. */
  210. return canonicalRequest
  211. }
  212. func buildCanonicalQueryString(r *http.Request) string {
  213. if r.URL.RawQuery == "" {
  214. return ""
  215. }
  216. // Parse the query parameters properly
  217. query := r.URL.Query()
  218. // Get all keys and sort them
  219. keys := make([]string, 0, len(query))
  220. for k := range query {
  221. keys = append(keys, k)
  222. }
  223. sort.Strings(keys)
  224. // Build the canonical query string
  225. var params []string
  226. for _, key := range keys {
  227. values := query[key]
  228. // Sort values for this key
  229. sort.Strings(values)
  230. for _, value := range values {
  231. if value == "" {
  232. // Empty value - add just key with equals sign
  233. params = append(params, key+"=")
  234. } else {
  235. // Non-empty value
  236. params = append(params, key+"="+value)
  237. }
  238. }
  239. }
  240. return strings.Join(params, "&")
  241. }
  242. func buildCanonicalHeaders(r *http.Request, signedHeaders []string) string {
  243. headers := make(map[string]string)
  244. for _, header := range signedHeaders {
  245. headerLower := strings.ToLower(strings.TrimSpace(header))
  246. // Special handling for Host header
  247. if headerLower == "host" {
  248. headers[headerLower] = r.Host
  249. continue
  250. }
  251. values := r.Header[http.CanonicalHeaderKey(header)]
  252. if len(values) > 0 {
  253. trimmedValues := make([]string, len(values))
  254. for i, v := range values {
  255. trimmedValues[i] = strings.TrimSpace(v)
  256. }
  257. headers[headerLower] = strings.Join(trimmedValues, ",")
  258. }
  259. }
  260. // Sort headers
  261. keys := make([]string, 0, len(headers))
  262. for k := range headers {
  263. keys = append(keys, k)
  264. }
  265. sort.Strings(keys)
  266. // Build canonical headers string
  267. var canonical strings.Builder
  268. for _, k := range keys {
  269. canonical.WriteString(k)
  270. canonical.WriteString(":")
  271. canonical.WriteString(headers[k])
  272. canonical.WriteString("\n")
  273. }
  274. return canonical.String()
  275. }
  276. func buildStringToSign(canonicalRequest string, sigV4 *sigV4Components, amzDate string) string {
  277. canonicalRequestHash := sha256Hash([]byte(canonicalRequest))
  278. stringToSign := fmt.Sprintf("AWS4-HMAC-SHA256\n%s\n%s\n%s",
  279. amzDate,
  280. sigV4.CredentialScope,
  281. canonicalRequestHash,
  282. )
  283. return stringToSign
  284. }
  285. func deriveSigningKey(secretKey, date, region, service string) []byte {
  286. kDate := hmacSHA256([]byte("AWS4"+secretKey), date)
  287. kRegion := hmacSHA256(kDate, region)
  288. kService := hmacSHA256(kRegion, service)
  289. kSigning := hmacSHA256(kService, "aws4_request")
  290. return kSigning
  291. }
  292. func hmacSHA256(key []byte, data string) []byte {
  293. h := hmac.New(sha256.New, key)
  294. h.Write([]byte(data))
  295. return h.Sum(nil)
  296. }
  297. func sha256Hash(data []byte) string {
  298. hash := sha256.Sum256(data)
  299. return hex.EncodeToString(hash[:])
  300. }
  301. func writeAuthError(w http.ResponseWriter, code, message string) {
  302. w.Header().Set("Content-Type", "application/xml")
  303. w.WriteHeader(http.StatusForbidden)
  304. errorXML := fmt.Sprintf(`<?xml version="1.0" encoding="UTF-8"?>
  305. <ErrorResponse xmlns="https://sts.amazonaws.com/doc/2011-06-15/">
  306. <Error>
  307. <Type>Sender</Type>
  308. <Code>%s</Code>
  309. <Message>%s</Message>
  310. </Error>
  311. <RequestId>%d</RequestId>
  312. </ErrorResponse>`, code, message, time.Now().Unix())
  313. w.Write([]byte(errorXML))
  314. }