sts.go 2.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104
  1. package handler
  2. import (
  3. "encoding/xml"
  4. "fmt"
  5. "log"
  6. "net/http"
  7. "time"
  8. "aws-sts-mock/internal/service"
  9. "aws-sts-mock/pkg/sigv4"
  10. "aws-sts-mock/pkg/sts"
  11. )
  12. // STSHandler handles STS HTTP requests
  13. type STSHandler struct {
  14. service *service.STSService
  15. }
  16. // NewSTSHandler creates a new STS handler
  17. func NewSTSHandler(service *service.STSService) *STSHandler {
  18. return &STSHandler{
  19. service: service,
  20. }
  21. }
  22. // Handle processes STS requests
  23. func (h *STSHandler) Handle(w http.ResponseWriter, r *http.Request) {
  24. if err := r.ParseForm(); err != nil {
  25. h.writeError(w, "InvalidRequest", "Failed to parse request")
  26. return
  27. }
  28. action := r.FormValue("Action")
  29. version := r.FormValue("Version")
  30. // Validate version
  31. if err := h.service.ValidateAPIVersion(version); err != nil {
  32. h.writeError(w, "InvalidVersion", err.Error())
  33. return
  34. }
  35. switch action {
  36. case "GetCallerIdentity":
  37. h.handleGetCallerIdentity(w, r)
  38. default:
  39. h.writeError(w, "InvalidAction", fmt.Sprintf("Unknown action: %s", action))
  40. }
  41. }
  42. func (h *STSHandler) handleGetCallerIdentity(w http.ResponseWriter, r *http.Request) {
  43. creds, ok := r.Context().Value(sigv4.CredentialsContextKey).(sigv4.AWSCredentials)
  44. if !ok {
  45. h.writeError(w, "AccessDenied", "Failed to retrieve credentials")
  46. return
  47. }
  48. result, err := h.service.GetCallerIdentity(creds)
  49. if err != nil {
  50. log.Printf("Error getting caller identity: %v", err)
  51. h.writeError(w, "InternalError", "Failed to get caller identity")
  52. return
  53. }
  54. response := sts.GetCallerIdentityResponse{
  55. XMLName: xml.Name{Space: "https://sts.amazonaws.com/doc/2011-06-15/", Local: "GetCallerIdentityResponse"},
  56. GetCallerIdentityResult: *result,
  57. ResponseMetadata: sts.ResponseMetadata{
  58. RequestId: generateRequestId(),
  59. },
  60. }
  61. w.Header().Set("Content-Type", "text/xml")
  62. w.Header().Set("x-amzn-RequestId", response.ResponseMetadata.RequestId)
  63. w.Header().Set("Date", time.Now().UTC().Format(http.TimeFormat))
  64. w.WriteHeader(http.StatusOK)
  65. encoder := xml.NewEncoder(w)
  66. encoder.Indent("", " ")
  67. if err := encoder.Encode(response); err != nil {
  68. log.Printf("Error encoding response: %v", err)
  69. }
  70. }
  71. func (h *STSHandler) writeError(w http.ResponseWriter, code, message string) {
  72. errorResp := sts.ErrorResponse{
  73. XMLName: xml.Name{Space: "https://sts.amazonaws.com/doc/2011-06-15/", Local: "ErrorResponse"},
  74. Error: sts.Error{
  75. Type: "Sender",
  76. Code: code,
  77. Message: message,
  78. },
  79. RequestId: generateRequestId(),
  80. }
  81. w.Header().Set("Content-Type", "text/xml")
  82. w.WriteHeader(http.StatusBadRequest)
  83. encoder := xml.NewEncoder(w)
  84. encoder.Indent("", " ")
  85. if err := encoder.Encode(errorResp); err != nil {
  86. log.Printf("Error encoding error response: %v", err)
  87. }
  88. }