sigv4.go 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434
  1. package sigv4
  import (
	"bytes"
	"context"
	"crypto/hmac"
	"crypto/sha256"
	"encoding/hex"
	"fmt"
	"io"
	"log"
	"net/http"
	"net/url"
	"sort"
	"strings"
	"time"
)
  // contextKey is an unexported key type for context values set by this
// package, so its keys cannot collide with keys defined in other packages.
type contextKey string

// CredentialsContextKey is the context key under which
// ValidateSigV4Middleware stores the authenticated AWSCredentials.
const CredentialsContextKey contextKey = "credentials"
  // AWSCredentials is one entry of the credential store: an access key pair,
// an optional session token and the owning account.
type AWSCredentials struct {
	// AccessKeyID identifies the key; it is the first element of the
	// Credential parameter in the Authorization header.
	AccessKeyID string
	// SecretAccessKey is the secret the SigV4 signing key is derived from.
	SecretAccessKey string
	// SessionToken, when non-empty, must be sent in X-Amz-Security-Token.
	SessionToken string
	// AccountID is the account the credentials belong to.
	AccountID string
}
  26. // Mock credential store - in production, this would be a database or external service
  27. var mockCredentials = map[string]AWSCredentials{
  28. "AKIAIOSFODNN7EXAMPLE": {
  29. AccessKeyID: "AKIAIOSFODNN7EXAMPLE",
  30. SecretAccessKey: "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY",
  31. AccountID: "123456789012",
  32. },
  33. "ASIAUIJXACK3L66H7KB4": {
  34. AccessKeyID: "ASIAUIJXACK3L66H7KB4",
  35. SecretAccessKey: "test-secret-key",
  36. SessionToken: "test-session-token",
  37. AccountID: "292709995190",
  38. },
  39. }
  // ValidateSigV4Middleware returns a handler that authenticates each request
// with AWS Signature Version 4 before delegating to next.
//
// The checks run in this order:
//   - read the whole body and restore it for next;
//   - parse the Authorization header and look up the access key in
//     mockCredentials;
//   - require X-Amz-Date, with a credential-scope date on the same day and
//     within 15 minutes of server time in either direction;
//   - recompute the signature and compare it in constant time;
//   - check the session token, if the credentials have one;
//   - check a concrete X-Amz-Content-SHA256 against the body.
//
// Any failure is reported with writeAuthError and next is not called. On
// success the matched AWSCredentials are stored in the request context under
// CredentialsContextKey.
func ValidateSigV4Middleware(next http.HandlerFunc) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		// Read the body before anything else: its hash may be part of what
		// was signed.
		// NOTE(review): the read is unbounded; consider http.MaxBytesReader
		// if large uploads must be limited.
		bodyBytes, err := io.ReadAll(r.Body)
		if err != nil {
			writeAuthError(w, "InvalidRequest", "Failed to read request body")
			return
		}
		// Restore the body for downstream handlers.
		r.Body = io.NopCloser(bytes.NewReader(bodyBytes))
		log.Printf("Request body length: %d bytes", len(bodyBytes))

		authHeader := r.Header.Get("Authorization")
		if authHeader == "" {
			writeAuthError(w, "MissingAuthenticationToken", "Authorization header cannot be empty")
			return
		}
		sigV4, err := parseSigV4Header(authHeader)
		if err != nil {
			writeAuthError(w, "IncompleteSignature", err.Error())
			return
		}
		creds, ok := mockCredentials[sigV4.AccessKeyID]
		if !ok {
			writeAuthError(w, "InvalidClientTokenId", "The security token included in the request is invalid")
			return
		}

		dateHeader := r.Header.Get("X-Amz-Date")
		if dateHeader == "" {
			writeAuthError(w, "InvalidRequest", "X-Amz-Date header is required")
			return
		}
		reqTime, err := time.Parse("20060102T150405Z", dateHeader)
		if err != nil {
			writeAuthError(w, "InvalidRequest", "Invalid X-Amz-Date format")
			return
		}
		// The signing key is derived from the credential-scope date, so that
		// date must be the same calendar day as X-Amz-Date.
		if sigV4.Date != reqTime.Format("20060102") {
			writeAuthError(w, "SignatureDoesNotMatch", "Date in Credential scope does not match X-Amz-Date")
			return
		}
		// Accept at most 15 minutes of clock skew in either direction.
		now := time.Now().UTC()
		if now.Sub(reqTime) > 15*time.Minute {
			writeAuthError(w, "RequestExpired", "Request has expired")
			return
		}
		if reqTime.Sub(now) > 15*time.Minute {
			writeAuthError(w, "SignatureDoesNotMatch", "Signature not yet current")
			return
		}

		expectedSig, err := calculateSignature(r, bodyBytes, creds, sigV4)
		if err != nil {
			writeAuthError(w, "InternalError", fmt.Sprintf("Failed to calculate signature: %v", err))
			return
		}
		// Compare in constant time, and never log the expected signature:
		// it would let anyone with log access forge this request.
		if !hmac.Equal([]byte(expectedSig), []byte(sigV4.Signature)) {
			log.Printf("Signature mismatch for access key %s", sigV4.AccessKeyID)
			writeAuthError(w, "SignatureDoesNotMatch",
				"The request signature we calculated does not match the signature you provided")
			return
		}

		// Temporary credentials must present their session token
		// (constant-time compare).
		if creds.SessionToken != "" {
			reqToken := r.Header.Get("X-Amz-Security-Token")
			if !hmac.Equal([]byte(reqToken), []byte(creds.SessionToken)) {
				writeAuthError(w, "InvalidClientTokenId", "The security token included in the request is invalid")
				return
			}
		}

		// A concrete X-Amz-Content-SHA256 is signed verbatim instead of the
		// body, so the signature only vouches for the body if that hash
		// matches it. "UNSIGNED-PAYLOAD" and STREAMING-* values carry no
		// whole-body hash and are skipped.
		if declared := r.Header.Get("X-Amz-Content-SHA256"); declared != "" &&
			declared != "UNSIGNED-PAYLOAD" &&
			!strings.HasPrefix(declared, "STREAMING-") {
			sum := sha256.Sum256(bodyBytes)
			if !strings.EqualFold(declared, hex.EncodeToString(sum[:])) {
				writeAuthError(w, "XAmzContentSHA256Mismatch",
					"The provided 'x-amz-content-sha256' header does not match what was computed.")
				return
			}
		}

		// Expose the authenticated credentials to downstream handlers.
		ctx := context.WithValue(r.Context(), CredentialsContextKey, creds)
		next.ServeHTTP(w, r.WithContext(ctx))
	}
}
  122. type sigV4Components struct {
  123. AccessKeyID string
  124. CredentialScope string
  125. SignedHeaders []string
  126. Signature string
  127. Date string
  128. Region string
  129. Service string
  130. }
  131. func parseSigV4Header(authHeader string) (*sigV4Components, error) {
  132. if !strings.HasPrefix(authHeader, "AWS4-HMAC-SHA256 ") {
  133. return nil, fmt.Errorf("Authorization header must start with 'AWS4-HMAC-SHA256'")
  134. }
  135. authHeader = strings.TrimPrefix(authHeader, "AWS4-HMAC-SHA256 ")
  136. parts := strings.Split(authHeader, ", ")
  137. sig := &sigV4Components{}
  138. for _, part := range parts {
  139. kv := strings.SplitN(part, "=", 2)
  140. if len(kv) != 2 {
  141. return nil, fmt.Errorf("Invalid key=value pair in Authorization header")
  142. }
  143. key := strings.TrimSpace(kv[0])
  144. value := strings.TrimSpace(kv[1])
  145. switch key {
  146. case "Credential":
  147. credParts := strings.Split(value, "/")
  148. if len(credParts) != 5 {
  149. return nil, fmt.Errorf("Invalid credential format")
  150. }
  151. sig.AccessKeyID = strings.ReplaceAll(credParts[0], "\"", "")
  152. sig.Date = credParts[1]
  153. sig.Region = credParts[2]
  154. sig.Service = credParts[3]
  155. sig.CredentialScope = strings.Join(credParts[1:], "/")
  156. case "SignedHeaders":
  157. sig.SignedHeaders = strings.Split(value, ";")
  158. case "Signature":
  159. sig.Signature = value
  160. }
  161. }
  162. if sig.AccessKeyID == "" {
  163. return nil, fmt.Errorf("Authorization header requires 'Credential' parameter")
  164. }
  165. if sig.Signature == "" {
  166. return nil, fmt.Errorf("Authorization header requires 'Signature' parameter")
  167. }
  168. return sig, nil
  169. }
  // calculateSignature recomputes the signature the client should have sent
// for r. It builds the canonical request, turns it into the string to sign,
// signs that with the key derived from creds and the credential scope, and
// returns the MAC hex-encoded. The error result is currently always nil.
func calculateSignature(r *http.Request, body []byte, creds AWSCredentials, sigV4 *sigV4Components) (string, error) {
	canonical := buildCanonicalRequest(r, body, sigV4.SignedHeaders)
	toSign := buildStringToSign(canonical, sigV4, r.Header.Get("X-Amz-Date"))
	log.Printf("String to sign:\n%s", toSign)

	key := deriveSigningKey(creds.SecretAccessKey, sigV4.Date, sigV4.Region, sigV4.Service)
	return hex.EncodeToString(hmacSHA256(key, toSign)), nil
}
  // isVirtualHostedStyle reports whether host names a bucket in
// virtual-hosted style ("photosbucket.s3.alanyeung.co") rather than path
// style ("s3.alanyeung.co").
//
// A host counts as virtual-hosted when some label after the first is "s3" or
// starts with "s3-" (regional endpoint forms such as "s3-us-west-2"). The
// first label is excluded because in path style it is the S3 endpoint
// itself. Labels are compared case-insensitively, and a ":port" suffix is
// ignored.
func isVirtualHostedStyle(host string) bool {
	labels := strings.Split(host, ".") // always has at least one element
	for _, label := range labels[1:] {
		label = strings.ToLower(label)
		if i := strings.IndexByte(label, ':'); i != -1 {
			label = label[:i] // drop ":port"
		}
		if label == "s3" || strings.HasPrefix(label, "s3-") {
			return true
		}
	}
	return false
}
  // extractBucketFromHost returns the bucket name from a virtual-hosted-style
// host: every label before the first S3 endpoint label, for example
// "photosbucket.s3.alanyeung.co" -> "photosbucket" and
// "my.photos.s3.alanyeung.co" -> "my.photos".
//
// An S3 endpoint label is a label after the first that equals "s3" or starts
// with "s3-", ignoring case and any ":port". When the host has no such label,
// the first '.'-separated label is returned unchanged ("" for an empty host).
func extractBucketFromHost(host string) string {
	labels := strings.Split(host, ".") // always has at least one element
	for i := 1; i < len(labels); i++ {
		label := strings.ToLower(labels[i])
		if j := strings.IndexByte(label, ':'); j != -1 {
			label = label[:j] // drop ":port"
		}
		if label == "s3" || strings.HasPrefix(label, "s3-") {
			// Bucket names may contain dots, so keep every preceding label.
			return strings.Join(labels[:i], ".")
		}
	}
	return labels[0]
}
  199. func buildCanonicalRequest(r *http.Request, body []byte, signedHeaders []string) string {
  200. // Method
  201. method := r.Method
  202. // Canonical URI - Handle both virtual-hosted and path-style
  203. canonicalURI := r.RequestURI
  204. if idx := strings.Index(canonicalURI, "?"); idx != -1 {
  205. canonicalURI = canonicalURI[:idx]
  206. }
  207. // NEW: Check if this is a virtual-hosted-style request
  208. if isVirtualHostedStyle(r.Host) {
  209. // For virtual-hosted-style, strip the bucket name from the URI
  210. bucketName := extractBucketFromHost(r.Host)
  211. if bucketName != "" {
  212. // Remove /bucketname/ prefix from the URI
  213. prefix := "/" + bucketName + "/"
  214. if strings.HasPrefix(canonicalURI, prefix) {
  215. canonicalURI = "/" + strings.TrimPrefix(canonicalURI, prefix)
  216. } else if canonicalURI == "/"+bucketName {
  217. canonicalURI = "/"
  218. }
  219. }
  220. }
  221. if canonicalURI == "" {
  222. canonicalURI = "/"
  223. }
  224. // ... rest of the function remains the same
  225. canonicalQueryString := buildCanonicalQueryString(r)
  226. canonicalHeaders := buildCanonicalHeaders(r, signedHeaders)
  227. signedHeadersStr := strings.Join(signedHeaders, ";")
  228. var payloadHash string
  229. amzContentSha256 := r.Header.Get("X-Amz-Content-SHA256")
  230. if amzContentSha256 == "UNSIGNED-PAYLOAD" {
  231. payloadHash = "UNSIGNED-PAYLOAD"
  232. } else if amzContentSha256 != "" {
  233. payloadHash = amzContentSha256
  234. } else {
  235. payloadHash = sha256Hash(body)
  236. }
  237. canonicalRequest := fmt.Sprintf("%s\n%s\n%s\n%s\n%s\n%s",
  238. method,
  239. canonicalURI,
  240. canonicalQueryString,
  241. canonicalHeaders,
  242. signedHeadersStr,
  243. payloadHash,
  244. )
  245. log.Printf("=== Canonical Request ===")
  246. log.Printf("Method: %s", method)
  247. log.Printf("Host: %s (Virtual-hosted: %v)", r.Host, isVirtualHostedStyle(r.Host))
  248. log.Printf("Original RequestURI: %s", r.RequestURI)
  249. log.Printf("Canonical URI: %s", canonicalURI)
  250. log.Printf("Canonical Query String: %s", canonicalQueryString)
  251. log.Printf("Canonical Headers:\n%s", canonicalHeaders)
  252. log.Printf("Signed Headers: %s", signedHeadersStr)
  253. log.Printf("Payload Hash: %s", payloadHash)
  254. log.Printf("Full Canonical Request:\n%s", canonicalRequest)
  255. log.Printf("========================")
  256. return canonicalRequest
  257. }
  // buildCanonicalQueryString returns the SigV4 canonical query string for r.
//
// Normalization rules:
//   - Names and values are percent-decoded and then re-encoded with the
//     SigV4 rules: only A-Z a-z 0-9 - _ . ~ stay literal, and every other
//     byte becomes %XX with uppercase hex. So "a/b", "a%2fb" and "a%2Fb" all
//     become "a%2Fb".
//   - '+' is treated as a literal plus (encoded "%2B"), not as a space.
//   - A component with malformed escapes is kept verbatim.
//   - Parameters are sorted by encoded name, then by encoded value.
//   - A name without "=" gets an empty value ("acl" -> "acl=").
//   - Empty segments (from "&&") are dropped.
//
// Input that is already SigV4-encoded comes out unchanged apart from
// ordering.
func buildCanonicalQueryString(r *http.Request) string {
	if r.URL.RawQuery == "" {
		return ""
	}
	type param struct {
		key   string
		value string
	}
	segments := strings.Split(r.URL.RawQuery, "&")
	params := make([]param, 0, len(segments))
	for _, seg := range segments {
		if seg == "" {
			continue
		}
		kv := strings.SplitN(seg, "=", 2)
		p := param{key: canonicalQueryComponent(kv[0])}
		if len(kv) == 2 {
			p.value = canonicalQueryComponent(kv[1])
		}
		params = append(params, p)
	}
	sort.Slice(params, func(i, j int) bool {
		if params[i].key == params[j].key {
			return params[i].value < params[j].value
		}
		return params[i].key < params[j].key
	})
	pairs := make([]string, len(params))
	for i, p := range params {
		pairs[i] = p.key + "=" + p.value
	}
	return strings.Join(pairs, "&")
}

// canonicalQueryComponent percent-decodes one raw query name or value and
// re-encodes it with sigV4Escape. Malformed input is returned unchanged.
func canonicalQueryComponent(raw string) string {
	decoded, err := url.PathUnescape(raw) // keeps '+' literal
	if err != nil {
		return raw
	}
	return sigV4Escape(decoded)
}

// sigV4Escape percent-encodes every byte of s except the RFC 3986 unreserved
// characters (A-Z a-z 0-9 - _ . ~), using uppercase hex digits.
func sigV4Escape(s string) string {
	const upperHex = "0123456789ABCDEF"
	var b strings.Builder
	b.Grow(len(s))
	for i := 0; i < len(s); i++ {
		c := s[i]
		switch {
		case 'A' <= c && c <= 'Z', 'a' <= c && c <= 'z', '0' <= c && c <= '9',
			c == '-', c == '_', c == '.', c == '~':
			b.WriteByte(c)
		default:
			b.WriteByte('%')
			b.WriteByte(upperHex[c>>4])
			b.WriteByte(upperHex[c&0x0F])
		}
	}
	return b.String()
}
  // buildCanonicalHeaders returns the SigV4 CanonicalHeaders block for the
// headers named in signedHeaders: one "name:value\n" line per header that is
// present, with lowercase names sorted.
//
// "host" is taken from r.Host, because net/http moves the Host header there.
// Every other header's values are trimmed, internal whitespace runs are
// collapsed to one space (the spec's "trimall"), and multiple values are
// joined with ','. Signed headers absent from the request are omitted.
func buildCanonicalHeaders(r *http.Request, signedHeaders []string) string {
	headers := make(map[string]string, len(signedHeaders))
	for _, header := range signedHeaders {
		name := strings.ToLower(strings.TrimSpace(header))
		if name == "host" {
			headers[name] = r.Host
			continue
		}
		// Look up the trimmed name: CanonicalHeaderKey leaves names that
		// contain spaces untouched, which would miss the header.
		values := r.Header[http.CanonicalHeaderKey(name)]
		if len(values) == 0 {
			continue
		}
		normalized := make([]string, len(values))
		for i, v := range values {
			normalized[i] = strings.Join(strings.Fields(v), " ")
		}
		headers[name] = strings.Join(normalized, ",")
	}

	names := make([]string, 0, len(headers))
	for name := range headers {
		names = append(names, name)
	}
	sort.Strings(names)

	var canonical strings.Builder
	for _, name := range names {
		canonical.WriteString(name)
		canonical.WriteByte(':')
		canonical.WriteString(headers[name])
		canonical.WriteByte('\n')
	}
	return canonical.String()
}
  // buildStringToSign returns the SigV4 string to sign: four lines holding
// the algorithm name, the request timestamp (amzDate), the credential scope
// and the hex SHA-256 of the canonical request.
func buildStringToSign(canonicalRequest string, sigV4 *sigV4Components, amzDate string) string {
	lines := []string{
		"AWS4-HMAC-SHA256",
		amzDate,
		sigV4.CredentialScope,
		sha256Hash([]byte(canonicalRequest)),
	}
	return strings.Join(lines, "\n")
}
  339. func deriveSigningKey(secretKey, date, region, service string) []byte {
  340. kDate := hmacSHA256([]byte("AWS4"+secretKey), date)
  341. kRegion := hmacSHA256(kDate, region)
  342. kService := hmacSHA256(kRegion, service)
  343. kSigning := hmacSHA256(kService, "aws4_request")
  344. return kSigning
  345. }
  // hmacSHA256 returns the raw 32-byte HMAC-SHA256 of data keyed with key.
func hmacSHA256(key []byte, data string) []byte {
	mac := hmac.New(sha256.New, key)
	mac.Write([]byte(data)) // hash.Hash.Write never returns an error
	return mac.Sum(make([]byte, 0, sha256.Size))
}
  // sha256Hash returns the lowercase hex encoding of the SHA-256 digest of
// data.
func sha256Hash(data []byte) string {
	digest := sha256.New()
	digest.Write(data) // hash.Hash.Write never returns an error
	return hex.EncodeToString(digest.Sum(nil))
}
  // xmlTextEscaper escapes the characters that may not appear verbatim in XML
// character data.
var xmlTextEscaper = strings.NewReplacer("&", "&amp;", "<", "&lt;", ">", "&gt;")

// writeAuthError writes an STS-style XML ErrorResponse with HTTP 403
// Forbidden and Content-Type application/xml.
//
// code and message are XML-escaped before they are embedded, so messages
// that carry arbitrary text (such as wrapped error strings) still produce
// well-formed XML. RequestId is the current Unix time in seconds, so it is
// not unique across requests made in the same second.
func writeAuthError(w http.ResponseWriter, code, message string) {
	w.Header().Set("Content-Type", "application/xml")
	w.WriteHeader(http.StatusForbidden)
	errorXML := fmt.Sprintf(`<?xml version="1.0" encoding="UTF-8"?>
<ErrorResponse xmlns="https://sts.amazonaws.com/doc/2011-06-15/">
<Error>
<Type>Sender</Type>
<Code>%s</Code>
<Message>%s</Message>
</Error>
<RequestId>%d</RequestId>
</ErrorResponse>`, xmlTextEscaper.Replace(code), xmlTextEscaper.Replace(message), time.Now().Unix())
	// The status line has already been sent, so a failed body write cannot
	// be reported to the client; ignore it deliberately.
	_, _ = w.Write([]byte(errorXML))
}