123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566 |
- package middleware
import (
	"context"
	"crypto/subtle"
	"errors"
	"log"
	"net/http"

	"aws-sts-mock/internal/kvdb"
	"aws-sts-mock/pkg/sigv4"
)
- // UserValidationMiddleware creates a middleware that validates users against the KV database
- func UserValidationMiddleware(db kvdb.KVDB) func(http.Handler) http.Handler {
- return func(next http.Handler) http.Handler {
- return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- // Extract credentials from context (set by SigV4 middleware)
- creds, ok := r.Context().Value(sigv4.CredentialsContextKey).(sigv4.AWSCredentials)
- if !ok {
- http.Error(w, "Unauthorized", http.StatusUnauthorized)
- return
- }
- // Validate user exists in database
- user, err := db.GetUser(creds.AccessKeyID)
- if err != nil {
- if err == kvdb.ErrUserNotFound {
- log.Printf("User not found: %s", creds.AccessKeyID)
- http.Error(w, "Access Denied: User not found", http.StatusForbidden)
- return
- }
- log.Printf("Error validating user: %v", err)
- http.Error(w, "Internal Server Error", http.StatusInternalServerError)
- return
- }
- // Validate secret key matches
- if user.SecretAccessKey != creds.SecretAccessKey {
- log.Printf("Invalid credentials for user: %s", creds.AccessKeyID)
- http.Error(w, "Access Denied: Invalid credentials", http.StatusForbidden)
- return
- }
- // Update context with validated user information
- ctx := context.WithValue(r.Context(), sigv4.CredentialsContextKey, sigv4.AWSCredentials{
- AccessKeyID: user.AccessKeyID,
- SecretAccessKey: user.SecretAccessKey,
- AccountID: user.AccountID,
- })
- log.Printf("User validated: %s (Account: %s)", user.Username, user.AccountID)
- // Continue to next handler
- next.ServeHTTP(w, r.WithContext(ctx))
- })
- }
- }
// ChainMiddleware composes middlewares into a single middleware.
//
// The first element becomes the outermost wrapper, so for
// ChainMiddleware(a, b, c)(h) a request flows a -> b -> c -> h. The
// middleware constructors are applied from last to first each time the
// returned function is called. With no middlewares the handler is returned
// unchanged.
func ChainMiddleware(middlewares ...func(http.Handler) http.Handler) func(http.Handler) http.Handler {
	return func(final http.Handler) http.Handler {
		wrapped := final
		n := len(middlewares)
		for step := 0; step < n; step++ {
			// Wrap innermost-first: the last middleware sits closest to final.
			wrapped = middlewares[n-1-step](wrapped)
		}
		return wrapped
	}
}
|