package middleware import ( "context" "log" "net/http" "aws-sts-mock/internal/kvdb" "aws-sts-mock/pkg/sigv4" ) // UserValidationMiddleware creates a middleware that validates users against the KV database func UserValidationMiddleware(db kvdb.KVDB) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // Extract credentials from context (set by SigV4 middleware) creds, ok := r.Context().Value(sigv4.CredentialsContextKey).(sigv4.AWSCredentials) if !ok { http.Error(w, "Unauthorized", http.StatusUnauthorized) return } // Validate user exists in database user, err := db.GetUser(creds.AccessKeyID) if err != nil { if err == kvdb.ErrUserNotFound { log.Printf("User not found: %s", creds.AccessKeyID) http.Error(w, "Access Denied: User not found", http.StatusForbidden) return } log.Printf("Error validating user: %v", err) http.Error(w, "Internal Server Error", http.StatusInternalServerError) return } // Validate secret key matches if user.SecretAccessKey != creds.SecretAccessKey { log.Printf("Invalid credentials for user: %s", creds.AccessKeyID) http.Error(w, "Access Denied: Invalid credentials", http.StatusForbidden) return } // Update context with validated user information ctx := context.WithValue(r.Context(), sigv4.CredentialsContextKey, sigv4.AWSCredentials{ AccessKeyID: user.AccessKeyID, SecretAccessKey: user.SecretAccessKey, AccountID: user.AccountID, }) log.Printf("User validated: %s (Account: %s)", user.Username, user.AccountID) // Continue to next handler next.ServeHTTP(w, r.WithContext(ctx)) }) } } // ChainMiddleware chains multiple middleware functions func ChainMiddleware(middlewares ...func(http.Handler) http.Handler) func(http.Handler) http.Handler { return func(final http.Handler) http.Handler { for i := len(middlewares) - 1; i >= 0; i-- { final = middlewares[i](final) } return final } }