1
0
Alan Yeung 4 өдөр өмнө
parent
commit
9ed05327fa
4 өөрчлөгдсөн 994 нэмэгдсэн , 42 устгасан
  1. 43 4
      aws/internal/kvdb/bolt.go
  2. 31 19
      aws/main.go
  3. 892 0
      aws/output.txt
  4. 28 19
      aws/pkg/sigv4/sigv4.go

+ 43 - 4
aws/internal/kvdb/bolt.go

@@ -3,6 +3,7 @@ package kvdb
 import (
 	"encoding/json"
 	"errors"
+	"fmt"
 	"time"
 
 	"go.etcd.io/bbolt"
@@ -181,20 +182,26 @@ func (db *BoltKVDB) SetBucketConfig(config *BucketConfig) error {
 // GetBucketConfig gets the configuration for a bucket
 func (db *BoltKVDB) GetBucketConfig(accountID, bucketID string) (*BucketConfig, error) {
 	var config BucketConfig
+
 	err := db.db.View(func(tx *bbolt.Tx) error {
-		b := tx.Bucket(bucketConfigBucket)
+		bucket := tx.Bucket([]byte("buckets"))
+		if bucket == nil {
+			return fmt.Errorf("buckets bucket not found")
+		}
 
-		key := accountID + ":" + bucketID
-		data := b.Get([]byte(key))
+		key := fmt.Sprintf("%s:%s", accountID, bucketID)
+		data := bucket.Get([]byte(key))
 		if data == nil {
-			return errors.New("bucket config not found")
+			return fmt.Errorf("bucket config not found")
 		}
 
 		return json.Unmarshal(data, &config)
 	})
+
 	if err != nil {
 		return nil, err
 	}
+
 	return &config, nil
 }
 
@@ -229,6 +236,38 @@ func (db *BoltKVDB) ListBucketConfigs(accountID string) ([]*BucketConfig, error)
 	return configs, err
 }
 
+func (db *BoltKVDB) ResolveBucketName(bucketName string) (accountID string, bucketID string, errr error) {
+	err := db.db.View(func(tx *bbolt.Tx) error {
+		bucket := tx.Bucket([]byte("buckets"))
+		if bucket == nil {
+			return fmt.Errorf("buckets bucket not found")
+		}
+
+		// Iterate through all bucket configs to find matching bucket name
+		cursor := bucket.Cursor()
+		for k, v := cursor.First(); k != nil; k, v = cursor.Next() {
+			var config BucketConfig
+			if err := json.Unmarshal(v, &config); err != nil {
+				continue
+			}
+
+			if config.BucketName == bucketName {
+				accountID = config.AccountID
+				bucketID = config.BucketID
+				return nil
+			}
+		}
+
+		return fmt.Errorf("bucket not found: %s", bucketName)
+	})
+
+	if err != nil {
+		return "", "", err
+	}
+
+	return accountID, bucketID, nil
+}
+
 // InitializeDefaultUsers initializes hardcoded users
 func (db *BoltKVDB) InitializeDefaultUsers() error {
 	defaultUsers := []*User{

+ 31 - 19
aws/main.go

@@ -3,6 +3,7 @@ package main
 import (
 	"log"
 	"net/http"
+	"strings"
 
 	"aws-sts-mock/internal/config"
 	"aws-sts-mock/internal/handler"
@@ -61,24 +62,35 @@ func main() {
 	// Health check endpoint - no authentication
 	mux.HandleFunc("/health", healthHandler.Handle)
 
-	// Public viewing endpoint - no authentication required
-	mux.HandleFunc("/s3/", publicViewHandler.Handle)
-
-	// AWS S3/STS endpoints - with SigV4 and user validation
+	// Main endpoint - handles both authenticated and public requests
 	mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
-		// Apply SigV4 middleware
-		sigv4Handler := sigv4Middleware(func(w http.ResponseWriter, r *http.Request) {
-			// Apply user validation middleware
-			userValidationMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
-				// Route to appropriate handler
-				if r.Method == "POST" && r.FormValue("Action") != "" {
-					stsHandler.Handle(w, r)
-				} else {
-					s3Handler.Handle(w, r)
-				}
-			})).ServeHTTP(w, r)
-		})
-		sigv4Handler(w, r)
+		// Check if Authorization header exists (indicates authenticated request)
+		authHeader := r.Header.Get("Authorization")
+
+		if authHeader != "" {
+			// Authenticated request - determine if STS or S3 based on Content-Type
+			// STS requests use application/x-www-form-urlencoded
+			// S3 requests typically don't, or use application/xml
+			contentType := r.Header.Get("Content-Type")
+
+			// If it's a POST with form-urlencoded content type, it's likely STS
+			if r.Method == "POST" && strings.Contains(contentType, "application/x-www-form-urlencoded") {
+				// STS request
+				sigv4Handler := sigv4Middleware(func(w http.ResponseWriter, r *http.Request) {
+					userValidationMiddleware(http.HandlerFunc(stsHandler.Handle)).ServeHTTP(w, r)
+				})
+				sigv4Handler(w, r)
+			} else {
+				// S3 request
+				sigv4Handler := sigv4Middleware(func(w http.ResponseWriter, r *http.Request) {
+					userValidationMiddleware(http.HandlerFunc(s3Handler.Handle)).ServeHTTP(w, r)
+				})
+				sigv4Handler(w, r)
+			}
+		} else {
+			// Unauthenticated request - try public viewing
+			publicViewHandler.Handle(w, r)
+		}
 	})
 
 	// Start server
@@ -92,8 +104,8 @@ func main() {
 	log.Printf("Folder structure: ./uploads/{accountID}/{bucketID}/")
 	log.Printf("")
 	log.Printf("Endpoints:")
-	log.Printf("  - AWS S3/STS: http://localhost:%s/", cfg.Port)
-	log.Printf("  - Public View: http://localhost:%s/s3/{accountID}/{bucketID}/", cfg.Port)
+	log.Printf("  - AWS S3/STS (authenticated): http://localhost:%s/", cfg.Port)
+	log.Printf("  - Public View (unauthenticated): http://localhost:%s/{bucketName}/{objectKey}", cfg.Port)
 	log.Printf("  - Health Check: http://localhost:%s/health", cfg.Port)
 	log.Printf("========================================")
 

Файлын зөрүү хэтэрхий том тул дарагдсан байна
+ 892 - 0
aws/output.txt


+ 28 - 19
aws/pkg/sigv4/sigv4.go

@@ -46,14 +46,21 @@ var mockCredentials = map[string]AWSCredentials{
 // ValidateSigV4Middleware validates AWS Signature Version 4
 func ValidateSigV4Middleware(next http.HandlerFunc) http.HandlerFunc {
 	return func(w http.ResponseWriter, r *http.Request) {
-		// Read the body
+		// Read the body FIRST before any processing
 		bodyBytes, err := io.ReadAll(r.Body)
 		if err != nil {
 			writeAuthError(w, "InvalidRequest", "Failed to read request body")
 			return
 		}
+		// Restore the body for downstream handlers
 		r.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
 
+		// Log the body for debugging
+		log.Printf("Request body length: %d bytes", len(bodyBytes))
+		//if len(bodyBytes) > 0 {
+		//log.Printf("Request body: %s", string(bodyBytes))
+		//}
+
 		// Parse authorization header
 		authHeader := r.Header.Get("Authorization")
 		if authHeader == "" {
@@ -99,7 +106,7 @@ func ValidateSigV4Middleware(next http.HandlerFunc) http.HandlerFunc {
 			return
 		}
 
-		// Calculate expected signature
+		// Calculate expected signature with the actual body bytes
 		expectedSig, err := calculateSignature(r, bodyBytes, creds, sigV4)
 		if err != nil {
 			writeAuthError(w, "InternalError", fmt.Sprintf("Failed to calculate signature: %v", err))
@@ -107,8 +114,8 @@ func ValidateSigV4Middleware(next http.HandlerFunc) http.HandlerFunc {
 		}
 
 		// Compare signatures
-		//log.Printf("Expected signature: %s", expectedSig)
-		//log.Printf("Provided signature: %s", sigV4.Signature)
+		log.Printf("Expected signature: %s", expectedSig)
+		log.Printf("Provided signature: %s", sigV4.Signature)
 
 		if expectedSig != sigV4.Signature {
 			writeAuthError(w, "SignatureDoesNotMatch",
@@ -207,9 +214,12 @@ func buildCanonicalRequest(r *http.Request, body []byte, signedHeaders []string)
 	// Method
 	method := r.Method
 
-	// Canonical URI - Extract from RequestURI to get the path before any routing modifications
-	// RequestURI includes the full path, e.g., "/my-test-bucket/4.psd?uploads"
-	canonicalURI := r.URL.Path
+	// Canonical URI - Use RequestURI which preserves URL encoding
+	// Split RequestURI to get just the path (before the query string)
+	canonicalURI := r.RequestURI
+	if idx := strings.Index(canonicalURI, "?"); idx != -1 {
+		canonicalURI = canonicalURI[:idx]
+	}
 	if canonicalURI == "" {
 		canonicalURI = "/"
 	}
@@ -223,7 +233,8 @@ func buildCanonicalRequest(r *http.Request, body []byte, signedHeaders []string)
 	// Signed headers
 	signedHeadersStr := strings.Join(signedHeaders, ";")
 
-	// Payload hash
+	// Payload hash - THIS IS THE KEY FIX
+	// Use the actual body bytes passed in, not an empty body
 	payloadHash := sha256Hash(body)
 
 	canonicalRequest := fmt.Sprintf("%s\n%s\n%s\n%s\n%s\n%s",
@@ -235,17 +246,15 @@ func buildCanonicalRequest(r *http.Request, body []byte, signedHeaders []string)
 		payloadHash,
 	)
 
-	/*
-		log.Printf("=== Canonical Request ===")
-		log.Printf("Method: %s", method)
-		log.Printf("Canonical URI: %s", canonicalURI)
-		log.Printf("Canonical Query String: %s", canonicalQueryString)
-		log.Printf("Canonical Headers:\n%s", canonicalHeaders)
-		log.Printf("Signed Headers: %s", signedHeadersStr)
-		log.Printf("Payload Hash: %s", payloadHash)
-		log.Printf("Full Canonical Request:\n%s", canonicalRequest)
-		log.Printf("========================")
-	*/
+	log.Printf("=== Canonical Request ===")
+	log.Printf("Method: %s", method)
+	log.Printf("Canonical URI: %s", canonicalURI)
+	log.Printf("Canonical Query String: %s", canonicalQueryString)
+	log.Printf("Canonical Headers:\n%s", canonicalHeaders)
+	log.Printf("Signed Headers: %s", signedHeadersStr)
+	log.Printf("Payload Hash: %s", payloadHash)
+	log.Printf("Full Canonical Request:\n%s", canonicalRequest)
+	log.Printf("========================")
 
 	return canonicalRequest
 }

Энэ ялгаанд хэт олон файл өөрчлөгдсөн тул зарим файлыг харуулаагүй болно