diff --git a/internal/server/handlers/authentication.go b/internal/server/handlers/authentication.go index 3ef1513..349daa3 100644 --- a/internal/server/handlers/authentication.go +++ b/internal/server/handlers/authentication.go @@ -11,7 +11,7 @@ import ( jwtInternal "github.com/mentos1386/zdravko/pkg/jwt" ) -const sessionName = "zdravko-hey" +const authenticationSessionName = "zdravko-hey" type AuthenticatedPrincipal struct { User *AuthenticatedUser @@ -48,7 +48,7 @@ func GetUser(ctx context.Context) *AuthenticatedUser { } func (h *BaseHandler) AuthenticateRequestWithCookies(r *http.Request) (*AuthenticatedUser, error) { - session, err := h.store.Get(r, sessionName) + session, err := h.store.Get(r, authenticationSessionName) if err != nil { return nil, err } @@ -114,7 +114,7 @@ func (h *BaseHandler) AuthenticateRequestWithToken(r *http.Request) (*Authentica } func (h *BaseHandler) SetAuthenticatedUserForRequest(w http.ResponseWriter, r *http.Request, user *AuthenticatedUser) error { - session, err := h.store.Get(r, sessionName) + session, err := h.store.Get(r, authenticationSessionName) if err != nil { return err } @@ -124,24 +124,16 @@ func (h *BaseHandler) SetAuthenticatedUserForRequest(w http.ResponseWriter, r *h session.Values["oauth2_refresh_token"] = user.OAuth2RefreshToken session.Values["oauth2_token_type"] = user.OAuth2TokenType session.Values["oauth2_expiry"] = user.OAuth2Expiry.Format(time.RFC3339) - err = h.store.Save(r, w, session) - if err != nil { - return err - } - return nil + return h.store.Save(r, w, session) } func (h *BaseHandler) ClearAuthenticatedUserForRequest(w http.ResponseWriter, r *http.Request) error { - session, err := h.store.Get(r, sessionName) + session, err := h.store.Get(r, authenticationSessionName) if err != nil { return err } session.Options.MaxAge = -1 - err = h.store.Save(r, w, session) - if err != nil { - return err - } - return nil + return h.store.Save(r, w, session) } type AuthenticatedHandler func(http.ResponseWriter, *http.Request, *AuthenticatedPrincipal) @@ -159,7 +151,7 @@ func (h *BaseHandler) Authenticated(next echo.HandlerFunc) echo.HandlerFunc { if user.OAuth2Expiry.Before(time.Now()) { user, err = h.RefreshToken(c.Response(), c.Request(), user) if err != nil { - return c.Redirect(http.StatusTemporaryRedirect, "/oauth2/login") + return c.Redirect(http.StatusTemporaryRedirect, "/oauth2/login?redirect="+c.Request().URL.Path) } } @@ -173,6 +165,6 @@ func (h *BaseHandler) Authenticated(next echo.HandlerFunc) echo.HandlerFunc { return next(cc) } - return c.Redirect(http.StatusTemporaryRedirect, "/oauth2/login") + return c.Redirect(http.StatusTemporaryRedirect, "/oauth2/login?redirect="+c.Request().URL.Path) } } diff --git a/internal/server/handlers/oauth2.go b/internal/server/handlers/oauth2.go index e170654..6032ffe 100644 --- a/internal/server/handlers/oauth2.go +++ b/internal/server/handlers/oauth2.go @@ -19,6 +19,45 @@ import ( "golang.org/x/oauth2" ) +const oauth2RedirectSessionName = "zdravko-hey-oauth2" + +func (h *BaseHandler) setOAuth2Redirect(c echo.Context, redirect string) error { + w := c.Response() + r := c.Request() + + session, err := h.store.Get(r, oauth2RedirectSessionName) + if err != nil { + return err + } + session.Values["redirect"] = redirect + return h.store.Save(r, w, session) +} + +func (h *BaseHandler) getOAuth2Redirect(c echo.Context) (string, error) { + r := c.Request() + + session, err := h.store.Get(r, oauth2RedirectSessionName) + if err != nil { + return "", err + } + if session.IsNew { + return "", nil + } + return session.Values["redirect"].(string), nil +} + +func (h *BaseHandler) clearOAuth2Redirect(c echo.Context) error { + w := c.Response() + r := c.Request() + + session, err := h.store.Get(r, oauth2RedirectSessionName) + if err != nil { + return err + } + session.Options.MaxAge = -1 + return h.store.Save(r, w, session) +} + type UserInfo struct { Id int `json:"id"` // FIXME: This might not always be int? Sub string `json:"sub"` @@ -97,6 +136,14 @@ func (h *BaseHandler) OAuth2LoginGET(c echo.Context) error { url := conf.AuthCodeURL(state, oauth2.AccessTypeOffline) + redirect := c.QueryParam("redirect") + h.logger.Info("OAuth2LoginGET", "redirect", redirect) + + err = h.setOAuth2Redirect(c, redirect) + if err != nil { + return err + } + return c.Redirect(http.StatusTemporaryRedirect, url) } @@ -156,7 +203,21 @@ func (h *BaseHandler) OAuth2CallbackGET(c echo.Context) error { return err } - return c.Redirect(http.StatusTemporaryRedirect, "/settings") + redirect, err := h.getOAuth2Redirect(c) + if err != nil { + return err + } + h.logger.Info("OAuth2CallbackGET", "redirect", redirect) + if redirect == "" { + redirect = "/settings" + } + + err = h.clearOAuth2Redirect(c) + if err != nil { + return err + } + + return c.Redirect(http.StatusTemporaryRedirect, redirect) } func (h *BaseHandler) OAuth2LogoutGET(c echo.Context) error {