oauth2server.go 7.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295
  1. package oauth2server
  2. import (
  3. "context"
  4. _ "embed"
  5. "encoding/json"
  6. "log"
  7. "net/http"
  8. "net/url"
  9. "time"
  10. "github.com/go-oauth2/oauth2/v4/errors"
  11. "github.com/go-oauth2/oauth2/v4/generates"
  12. "github.com/go-oauth2/oauth2/v4/manage"
  13. "github.com/go-oauth2/oauth2/v4/models"
  14. "github.com/go-oauth2/oauth2/v4/server"
  15. "github.com/go-oauth2/oauth2/v4/store"
  16. "github.com/go-session/session"
  17. "imuslab.com/zoraxy/mod/utils"
  18. )
  19. const (
  20. SSO_SESSION_NAME = "ZoraxySSO"
  21. )
  22. type OAuth2Server struct {
  23. srv *server.Server //oAuth server instance
  24. config *SSOConfig
  25. parent *SSOHandler
  26. }
  27. //go:embed static/auth.html
  28. var authHtml []byte
  29. //go:embed static/login.html
  30. var loginHtml []byte
  31. // NewOAuth2Server creates a new OAuth2 server instance
  32. func NewOAuth2Server(config *SSOConfig, parent *SSOHandler) (*OAuth2Server, error) {
  33. manager := manage.NewDefaultManager()
  34. manager.SetAuthorizeCodeTokenCfg(manage.DefaultAuthorizeCodeTokenCfg)
  35. // token store
  36. manager.MustTokenStorage(store.NewFileTokenStore("./conf/sso.db"))
  37. // generate jwt access token
  38. manager.MapAccessGenerate(generates.NewAccessGenerate())
  39. //Load the information of registered app within the OAuth2 server
  40. clientStore := store.NewClientStore()
  41. clientStore.Set("myapp", &models.Client{
  42. ID: "myapp",
  43. Secret: "verysecurepassword",
  44. Domain: "localhost:9094",
  45. })
  46. //TODO: LOAD THIS DYNAMICALLY FROM DATABASE
  47. manager.MapClientStorage(clientStore)
  48. thisServer := OAuth2Server{
  49. config: config,
  50. parent: parent,
  51. }
  52. //Create a new oauth server
  53. srv := server.NewServer(server.NewConfig(), manager)
  54. srv.SetPasswordAuthorizationHandler(thisServer.PasswordAuthorizationHandler)
  55. srv.SetUserAuthorizationHandler(thisServer.UserAuthorizeHandler)
  56. srv.SetInternalErrorHandler(func(err error) (re *errors.Response) {
  57. log.Println("Internal Error:", err.Error())
  58. return
  59. })
  60. srv.SetResponseErrorHandler(func(re *errors.Response) {
  61. log.Println("Response Error:", re.Error.Error())
  62. })
  63. //Set the access scope handler
  64. srv.SetAuthorizeScopeHandler(thisServer.AuthorizationScopeHandler)
  65. //Set the access token expiration handler based on requesting domain / hostname
  66. srv.SetAccessTokenExpHandler(thisServer.ExpireHandler)
  67. thisServer.srv = srv
  68. return &thisServer, nil
  69. }
  70. // Password handler, validate if the given username and password are correct
  71. func (oas *OAuth2Server) PasswordAuthorizationHandler(ctx context.Context, clientID, username, password string) (userID string, err error) {
  72. //TODO: LOAD THIS DYNAMICALLY FROM DATABASE
  73. if username == "test" && password == "test" {
  74. userID = "test"
  75. }
  76. return
  77. }
  78. // User Authorization Handler, handle auth request from user
  79. func (oas *OAuth2Server) UserAuthorizeHandler(w http.ResponseWriter, r *http.Request) (userID string, err error) {
  80. store, err := session.Start(r.Context(), w, r)
  81. if err != nil {
  82. return
  83. }
  84. uid, ok := store.Get(SSO_SESSION_NAME)
  85. if !ok {
  86. if r.Form == nil {
  87. r.ParseForm()
  88. }
  89. store.Set("ReturnUri", r.Form)
  90. store.Save()
  91. w.Header().Set("Location", "/oauth2/login")
  92. w.WriteHeader(http.StatusFound)
  93. return
  94. }
  95. userID = uid.(string)
  96. store.Delete(SSO_SESSION_NAME)
  97. store.Save()
  98. return
  99. }
  100. // AccessTokenExpHandler, set the SSO session length default value
  101. func (oas *OAuth2Server) ExpireHandler(w http.ResponseWriter, r *http.Request) (exp time.Duration, err error) {
  102. requestHostname := r.Host
  103. if requestHostname == "" {
  104. //Use default value
  105. return time.Hour, nil
  106. }
  107. //Get the Registered App Config from parent
  108. appConfig, ok := oas.parent.Apps[requestHostname]
  109. if !ok {
  110. //Use default value
  111. return time.Hour, nil
  112. }
  113. //Use the app's session length
  114. return time.Second * time.Duration(appConfig.SessionDuration), nil
  115. }
  116. // AuthorizationScopeHandler, handle the scope of the request
  117. func (oas *OAuth2Server) AuthorizationScopeHandler(w http.ResponseWriter, r *http.Request) (scope string, err error) {
  118. //Get the scope from post or GEt request
  119. if r.Form == nil {
  120. if err := r.ParseForm(); err != nil {
  121. return "none", err
  122. }
  123. }
  124. //Get the hostname of the request
  125. requestHostname := r.Host
  126. if requestHostname == "" {
  127. //No rule set. Use default
  128. return "none", nil
  129. }
  130. //Get the Registered App Config from parent
  131. appConfig, ok := oas.parent.Apps[requestHostname]
  132. if !ok {
  133. //No rule set. Use default
  134. return "none", nil
  135. }
  136. //Check if the scope is set in the request
  137. if v, ok := r.Form["scope"]; ok {
  138. //Check if the requested scope is in the appConfig scope
  139. if utils.StringInArray(appConfig.Scopes, v[0]) {
  140. return v[0], nil
  141. }
  142. return "none", nil
  143. }
  144. return "none", nil
  145. }
  146. /* SSO Web Server Toggle Functions */
  147. func (oas *OAuth2Server) RegisterOauthEndpoints(primaryMux *http.ServeMux) {
  148. primaryMux.HandleFunc("/oauth2/login", oas.loginHandler)
  149. primaryMux.HandleFunc("/oauth2/auth", oas.authHandler)
  150. primaryMux.HandleFunc("/oauth2/authorize", func(w http.ResponseWriter, r *http.Request) {
  151. store, err := session.Start(r.Context(), w, r)
  152. if err != nil {
  153. http.Error(w, err.Error(), http.StatusInternalServerError)
  154. return
  155. }
  156. var form url.Values
  157. if v, ok := store.Get("ReturnUri"); ok {
  158. form = v.(url.Values)
  159. }
  160. r.Form = form
  161. store.Delete("ReturnUri")
  162. store.Save()
  163. err = oas.srv.HandleAuthorizeRequest(w, r)
  164. if err != nil {
  165. http.Error(w, err.Error(), http.StatusBadRequest)
  166. }
  167. })
  168. primaryMux.HandleFunc("/oauth2/token", func(w http.ResponseWriter, r *http.Request) {
  169. err := oas.srv.HandleTokenRequest(w, r)
  170. if err != nil {
  171. http.Error(w, err.Error(), http.StatusInternalServerError)
  172. }
  173. })
  174. primaryMux.HandleFunc("/test", func(w http.ResponseWriter, r *http.Request) {
  175. token, err := oas.srv.ValidationBearerToken(r)
  176. if err != nil {
  177. http.Error(w, err.Error(), http.StatusBadRequest)
  178. return
  179. }
  180. data := map[string]interface{}{
  181. "expires_in": int64(time.Until(token.GetAccessCreateAt().Add(token.GetAccessExpiresIn())).Seconds()),
  182. "client_id": token.GetClientID(),
  183. "user_id": token.GetUserID(),
  184. }
  185. e := json.NewEncoder(w)
  186. e.SetIndent("", " ")
  187. e.Encode(data)
  188. })
  189. }
  190. func (oas *OAuth2Server) loginHandler(w http.ResponseWriter, r *http.Request) {
  191. store, err := session.Start(r.Context(), w, r)
  192. if err != nil {
  193. http.Error(w, err.Error(), http.StatusInternalServerError)
  194. return
  195. }
  196. if r.Method == "POST" {
  197. if r.Form == nil {
  198. if err := r.ParseForm(); err != nil {
  199. http.Error(w, err.Error(), http.StatusInternalServerError)
  200. return
  201. }
  202. }
  203. //Load username and password from form post
  204. username, err := utils.PostPara(r, "username")
  205. if err != nil {
  206. w.Write([]byte("invalid username or password"))
  207. return
  208. }
  209. password, err := utils.PostPara(r, "password")
  210. if err != nil {
  211. w.Write([]byte("invalid username or password"))
  212. return
  213. }
  214. //Validate the user
  215. if !oas.parent.ValidateUsernameAndPassword(username, password) {
  216. //Wrong password
  217. w.Write([]byte("invalid username or password"))
  218. return
  219. }
  220. store.Set(SSO_SESSION_NAME, r.Form.Get("username"))
  221. store.Save()
  222. w.Header().Set("Location", "/oauth2/auth")
  223. w.WriteHeader(http.StatusFound)
  224. return
  225. } else if r.Method == "GET" {
  226. //Check if the user is logged in
  227. if _, ok := store.Get(SSO_SESSION_NAME); ok {
  228. w.Header().Set("Location", "/oauth2/auth")
  229. w.WriteHeader(http.StatusFound)
  230. return
  231. }
  232. }
  233. //User not logged in. Show login page
  234. w.Write(loginHtml)
  235. }
  236. func (oas *OAuth2Server) authHandler(w http.ResponseWriter, r *http.Request) {
  237. store, err := session.Start(context.TODO(), w, r)
  238. if err != nil {
  239. http.Error(w, err.Error(), http.StatusInternalServerError)
  240. return
  241. }
  242. if _, ok := store.Get(SSO_SESSION_NAME); !ok {
  243. w.Header().Set("Location", "/oauth2/login")
  244. w.WriteHeader(http.StatusFound)
  245. return
  246. }
  247. //User logged in. Check if this user have previously authorized the app
  248. //TODO: Check if the user have previously authorized the app
  249. //User have not authorized the app. Show the authorization page
  250. w.Write(authHtml)
  251. }