123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104 |
- package handler
- import (
- "encoding/xml"
- "fmt"
- "log"
- "net/http"
- "time"
- "aws-sts-mock/internal/service"
- "aws-sts-mock/pkg/sigv4"
- "aws-sts-mock/pkg/sts"
- )
- // STSHandler handles STS HTTP requests
- type STSHandler struct {
- service *service.STSService
- }
- // NewSTSHandler creates a new STS handler
- func NewSTSHandler(service *service.STSService) *STSHandler {
- return &STSHandler{
- service: service,
- }
- }
- // Handle processes STS requests
- func (h *STSHandler) Handle(w http.ResponseWriter, r *http.Request) {
- if err := r.ParseForm(); err != nil {
- h.writeError(w, "InvalidRequest", "Failed to parse request")
- return
- }
- action := r.FormValue("Action")
- version := r.FormValue("Version")
- // Validate version
- if err := h.service.ValidateAPIVersion(version); err != nil {
- h.writeError(w, "InvalidVersion", err.Error())
- return
- }
- switch action {
- case "GetCallerIdentity":
- h.handleGetCallerIdentity(w, r)
- default:
- h.writeError(w, "InvalidAction", fmt.Sprintf("Unknown action: %s", action))
- }
- }
- func (h *STSHandler) handleGetCallerIdentity(w http.ResponseWriter, r *http.Request) {
- creds, ok := r.Context().Value(sigv4.CredentialsContextKey).(sigv4.AWSCredentials)
- if !ok {
- h.writeError(w, "AccessDenied", "Failed to retrieve credentials")
- return
- }
- result, err := h.service.GetCallerIdentity(creds)
- if err != nil {
- log.Printf("Error getting caller identity: %v", err)
- h.writeError(w, "InternalError", "Failed to get caller identity")
- return
- }
- response := sts.GetCallerIdentityResponse{
- XMLName: xml.Name{Space: "https://sts.amazonaws.com/doc/2011-06-15/", Local: "GetCallerIdentityResponse"},
- GetCallerIdentityResult: *result,
- ResponseMetadata: sts.ResponseMetadata{
- RequestId: generateRequestId(),
- },
- }
- w.Header().Set("Content-Type", "text/xml")
- w.Header().Set("x-amzn-RequestId", response.ResponseMetadata.RequestId)
- w.Header().Set("Date", time.Now().UTC().Format(http.TimeFormat))
- w.WriteHeader(http.StatusOK)
- encoder := xml.NewEncoder(w)
- encoder.Indent("", " ")
- if err := encoder.Encode(response); err != nil {
- log.Printf("Error encoding response: %v", err)
- }
- }
- func (h *STSHandler) writeError(w http.ResponseWriter, code, message string) {
- errorResp := sts.ErrorResponse{
- XMLName: xml.Name{Space: "https://sts.amazonaws.com/doc/2011-06-15/", Local: "ErrorResponse"},
- Error: sts.Error{
- Type: "Sender",
- Code: code,
- Message: message,
- },
- RequestId: generateRequestId(),
- }
- w.Header().Set("Content-Type", "text/xml")
- w.WriteHeader(http.StatusBadRequest)
- encoder := xml.NewEncoder(w)
- encoder.Indent("", " ")
- if err := encoder.Encode(errorResp); err != nil {
- log.Printf("Error encoding error response: %v", err)
- }
- }
|