|
@@ -46,14 +46,21 @@ var mockCredentials = map[string]AWSCredentials{
|
|
|
// ValidateSigV4Middleware validates AWS Signature Version 4
|
|
|
func ValidateSigV4Middleware(next http.HandlerFunc) http.HandlerFunc {
|
|
|
return func(w http.ResponseWriter, r *http.Request) {
|
|
|
- // Read the body
|
|
|
+ // Read the body FIRST before any processing
|
|
|
bodyBytes, err := io.ReadAll(r.Body)
|
|
|
if err != nil {
|
|
|
writeAuthError(w, "InvalidRequest", "Failed to read request body")
|
|
|
return
|
|
|
}
|
|
|
+ // Restore the body for downstream handlers
|
|
|
r.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
|
|
|
|
|
|
+ // Log the body for debugging
|
|
|
+ log.Printf("Request body length: %d bytes", len(bodyBytes))
|
|
|
+ //if len(bodyBytes) > 0 {
|
|
|
+ //log.Printf("Request body: %s", string(bodyBytes))
|
|
|
+ //}
|
|
|
+
|
|
|
// Parse authorization header
|
|
|
authHeader := r.Header.Get("Authorization")
|
|
|
if authHeader == "" {
|
|
@@ -99,7 +106,7 @@ func ValidateSigV4Middleware(next http.HandlerFunc) http.HandlerFunc {
|
|
|
return
|
|
|
}
|
|
|
|
|
|
- // Calculate expected signature
|
|
|
+ // Calculate expected signature with the actual body bytes
|
|
|
expectedSig, err := calculateSignature(r, bodyBytes, creds, sigV4)
|
|
|
if err != nil {
|
|
|
writeAuthError(w, "InternalError", fmt.Sprintf("Failed to calculate signature: %v", err))
|
|
@@ -107,8 +114,8 @@ func ValidateSigV4Middleware(next http.HandlerFunc) http.HandlerFunc {
|
|
|
}
|
|
|
|
|
|
// Compare signatures
|
|
|
- //log.Printf("Expected signature: %s", expectedSig)
|
|
|
- //log.Printf("Provided signature: %s", sigV4.Signature)
|
|
|
+ log.Printf("Expected signature: %s", expectedSig)
|
|
|
+ log.Printf("Provided signature: %s", sigV4.Signature)
|
|
|
|
|
|
if expectedSig != sigV4.Signature {
|
|
|
writeAuthError(w, "SignatureDoesNotMatch",
|
|
@@ -207,9 +214,12 @@ func buildCanonicalRequest(r *http.Request, body []byte, signedHeaders []string)
|
|
|
// Method
|
|
|
method := r.Method
|
|
|
|
|
|
- // Canonical URI - Extract from RequestURI to get the path before any routing modifications
|
|
|
- // RequestURI includes the full path, e.g., "/my-test-bucket/4.psd?uploads"
|
|
|
- canonicalURI := r.URL.Path
|
|
|
+ // Canonical URI - Use RequestURI which preserves URL encoding
|
|
|
+ // Split RequestURI to get just the path (before the query string)
|
|
|
+ canonicalURI := r.RequestURI
|
|
|
+ if idx := strings.Index(canonicalURI, "?"); idx != -1 {
|
|
|
+ canonicalURI = canonicalURI[:idx]
|
|
|
+ }
|
|
|
if canonicalURI == "" {
|
|
|
canonicalURI = "/"
|
|
|
}
|
|
@@ -223,7 +233,8 @@ func buildCanonicalRequest(r *http.Request, body []byte, signedHeaders []string)
|
|
|
// Signed headers
|
|
|
signedHeadersStr := strings.Join(signedHeaders, ";")
|
|
|
|
|
|
- // Payload hash
|
|
|
+ // Payload hash - THIS IS THE KEY FIX
|
|
|
+ // Use the actual body bytes passed in, not an empty body
|
|
|
payloadHash := sha256Hash(body)
|
|
|
|
|
|
canonicalRequest := fmt.Sprintf("%s\n%s\n%s\n%s\n%s\n%s",
|
|
@@ -235,17 +246,15 @@ func buildCanonicalRequest(r *http.Request, body []byte, signedHeaders []string)
|
|
|
payloadHash,
|
|
|
)
|
|
|
|
|
|
- /*
|
|
|
- log.Printf("=== Canonical Request ===")
|
|
|
- log.Printf("Method: %s", method)
|
|
|
- log.Printf("Canonical URI: %s", canonicalURI)
|
|
|
- log.Printf("Canonical Query String: %s", canonicalQueryString)
|
|
|
- log.Printf("Canonical Headers:\n%s", canonicalHeaders)
|
|
|
- log.Printf("Signed Headers: %s", signedHeadersStr)
|
|
|
- log.Printf("Payload Hash: %s", payloadHash)
|
|
|
- log.Printf("Full Canonical Request:\n%s", canonicalRequest)
|
|
|
- log.Printf("========================")
|
|
|
- */
|
|
|
+ log.Printf("=== Canonical Request ===")
|
|
|
+ log.Printf("Method: %s", method)
|
|
|
+ log.Printf("Canonical URI: %s", canonicalURI)
|
|
|
+ log.Printf("Canonical Query String: %s", canonicalQueryString)
|
|
|
+ log.Printf("Canonical Headers:\n%s", canonicalHeaders)
|
|
|
+ log.Printf("Signed Headers: %s", signedHeadersStr)
|
|
|
+ log.Printf("Payload Hash: %s", payloadHash)
|
|
|
+ log.Printf("Full Canonical Request:\n%s", canonicalRequest)
|
|
|
+ log.Printf("========================")
|
|
|
|
|
|
return canonicalRequest
|
|
|
}
|