From 0a323c79e67aaa5e14f9e82448f1e44a864c0505 Mon Sep 17 00:00:00 2001 From: Tine Date: Sun, 11 Feb 2024 23:48:37 +0100 Subject: [PATCH] feat: fully implemented oauth2 authentication and cookie sessions --- cmd/server/main.go | 5 +- example.env | 1 + fly.toml | 1 + go.mod | 2 + go.sum | 4 ++ internal/config.go | 2 + internal/handlers/handlers.go | 7 +- internal/handlers/oauth2.go | 73 +++++++++++++++++++- internal/handlers/session.go | 111 ++++++++++++++++++++++++++++++ internal/handlers/settings.go | 4 +- web/templates/pages/settings.tmpl | 4 ++ 11 files changed, 209 insertions(+), 5 deletions(-) create mode 100644 internal/handlers/session.go diff --git a/cmd/server/main.go b/cmd/server/main.go index 51226dc..b7b17e4 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -44,11 +44,14 @@ func main() { r.PathPrefix("/static/").Handler(http.StripPrefix("/static/", http.FileServer(http.FS(static.Static)))) r.HandleFunc("/", h.Index).Methods("GET") - r.HandleFunc("/settings", h.Settings).Methods("GET") + + // Authenticated routes + r.HandleFunc("/settings", h.Authenticated(h.Settings)).Methods("GET") // OAuth2 r.HandleFunc("/oauth2/login", h.OAuth2LoginGET).Methods("GET") r.HandleFunc("/oauth2/callback", h.OAuth2CallbackGET).Methods("GET") + r.HandleFunc("/oauth2/logout", h.Authenticated(h.OAuth2LogoutGET)).Methods("GET") log.Println("Server started on", config.PORT) log.Fatal(http.ListenAndServe(":"+config.PORT, r)) diff --git a/example.env b/example.env index 5c2b6db..3b59ffb 100644 --- a/example.env +++ b/example.env @@ -16,3 +16,4 @@ OAUTH2_SCOPES=openid,profile,email OAUTH2_ENDPOINT_TOKEN_URL=https://your_oauth2_provider/token OAUTH2_ENDPOINT_AUTH_URL=https://your_oauth2_provider/auth OAUTH2_ENDPOINT_USER_INFO_URL=https://your_oauth2_provider/userinfo +OAUTH2_ENDPOINT_USER_INFO_URL=https://your_oauth2_provider/logout diff --git a/fly.toml b/fly.toml index dd0dd7e..f87ae35 100644 --- a/fly.toml +++ b/fly.toml @@ -16,6 +16,7 @@ primary_region = 'waw' OAUTH2_ENDPOINT_TOKEN_URL = 'https://id.tjo.space/application/o/token/' OAUTH2_ENDPOINT_AUTH_URL = 'https://id.tjo.space/application/o/authorize/' OAUTH2_ENDPOINT_USER_INFO_URL = 'https://id.tjo.space/application/o/userinfo/' + OAUTH2_ENDPOINT_LOGOUT_URL = 'https://id.tjo.space/application/o/zdravko-development/end-session/' [processes] server = "server" diff --git a/go.mod b/go.mod index 27532ef..cd3c0b4 100644 --- a/go.mod +++ b/go.mod @@ -23,6 +23,8 @@ require ( github.com/golang/mock v1.6.0 // indirect github.com/golang/protobuf v1.5.3 // indirect github.com/google/uuid v1.6.0 // indirect + github.com/gorilla/securecookie v1.1.2 // indirect + github.com/gorilla/sessions v1.2.2 // indirect github.com/grpc-ecosystem/go-grpc-middleware v1.3.0 // indirect github.com/grpc-ecosystem/grpc-gateway v1.16.0 // indirect github.com/jinzhu/inflection v1.0.0 // indirect diff --git a/go.sum b/go.sum index fccf296..eed271b 100644 --- a/go.sum +++ b/go.sum @@ -972,6 +972,10 @@ github.com/googleapis/go-type-adapters v1.0.0/go.mod h1:zHW75FOG2aur7gAO2B+MLby+ github.com/googleapis/google-cloud-go-testing v0.0.0-20200911160855-bcd43fbb19e8/go.mod h1:dvDLG8qkwmyD9a/MJJN3XJcT3xFxOKAvTZGvuZmac9g= github.com/gorilla/mux v1.8.1 h1:TuBL49tXwgrFYWhqrNgrUNEY92u81SPhu7sTdzQEiWY= github.com/gorilla/mux v1.8.1/go.mod h1:AKf9I4AEqPTmMytcMc0KkNouC66V3BtZ4qD5fmWSiMQ= +github.com/gorilla/securecookie v1.1.2 h1:YCIWL56dvtr73r6715mJs5ZvhtnY73hBvEF8kXD8ePA= +github.com/gorilla/securecookie v1.1.2/go.mod h1:NfCASbcHqRSY+3a8tlWJwsQap2VX5pwzwo4h3eOamfo= +github.com/gorilla/sessions v1.2.2 h1:lqzMYz6bOfvn2WriPUjNByzeXIlVzURcPmgMczkmTjY= +github.com/gorilla/sessions v1.2.2/go.mod h1:ePLdVu+jbEgHH+KWw8I1z2wqd0BAdAQh/8LRvBeoNcQ= github.com/grpc-ecosystem/go-grpc-middleware v1.3.0 h1:+9834+KizmvFV7pXQGSXQTsaWhq2GjuNUt0aUU0YBYw= github.com/grpc-ecosystem/go-grpc-middleware v1.3.0/go.mod h1:z0ButlSOZa5vEBq9m2m2hlwIgKw+rp3sdCBRoJY+30Y= github.com/grpc-ecosystem/grpc-gateway v1.16.0 h1:gmcG1KaJ57LophUzW0Hy8NmPhnMZb4M0+kPpLofRdBo= diff --git a/internal/config.go b/internal/config.go index b9a3a40..162864c 100644 --- a/internal/config.go +++ b/internal/config.go @@ -19,6 +19,7 @@ type Config struct { OAUTH2_ENDPOINT_TOKEN_URL string OAUTH2_ENDPOINT_AUTH_URL string OAUTH2_ENDPOINT_USER_INFO_URL string + OAUTH2_ENDPOINT_LOGOUT_URL string } func getEnv(key, fallback string) string { @@ -49,5 +50,6 @@ func NewConfig() *Config { OAUTH2_ENDPOINT_TOKEN_URL: getEnvRequired("OAUTH2_ENDPOINT_TOKEN_URL"), OAUTH2_ENDPOINT_AUTH_URL: getEnvRequired("OAUTH2_ENDPOINT_AUTH_URL"), OAUTH2_ENDPOINT_USER_INFO_URL: getEnvRequired("OAUTH2_ENDPOINT_USER_INFO_URL"), + OAUTH2_ENDPOINT_LOGOUT_URL: getEnvRequired("OAUTH2_ENDPOINT_LOGOUT_URL"), } } diff --git a/internal/handlers/handlers.go b/internal/handlers/handlers.go index 55f2ac2..e6a38ff 100644 --- a/internal/handlers/handlers.go +++ b/internal/handlers/handlers.go @@ -3,6 +3,7 @@ package handlers import ( "code.tjo.space/mentos1386/zdravko/internal" "code.tjo.space/mentos1386/zdravko/internal/models/query" + "github.com/gorilla/sessions" "gorm.io/gorm" ) @@ -10,8 +11,12 @@ type BaseHandler struct { db *gorm.DB query *query.Query config *internal.Config + + store *sessions.CookieStore } func NewBaseHandler(db *gorm.DB, q *query.Query, config *internal.Config) *BaseHandler { - return &BaseHandler{db, q, config} + store := sessions.NewCookieStore([]byte(config.SESSION_SECRET)) + + return &BaseHandler{db, q, config, store} } diff --git a/internal/handlers/oauth2.go b/internal/handlers/oauth2.go index 7f99dfe..99ebfe5 100644 --- a/internal/handlers/oauth2.go +++ b/internal/handlers/oauth2.go @@ -2,6 +2,7 @@ package handlers import ( "context" + "encoding/json" "io" "net/http" @@ -9,6 +10,11 @@ import ( "golang.org/x/oauth2" ) +type UserInfo struct { + Sub string `json:"sub"` + Email string `json:"email"` +} + func newOAuth2(config *internal.Config) *oauth2.Config { return &oauth2.Config{ ClientID: config.OAUTH2_CLIENT_ID, @@ -22,6 +28,40 @@ func newOAuth2(config *internal.Config) *oauth2.Config { } } +func (h *BaseHandler) AuthenticatedUserToOAuth2Token(user *AuthenticatedUser) *oauth2.Token { + return &oauth2.Token{ + AccessToken: user.OAuth2AccessToken, + TokenType: user.OAuth2TokenType, + RefreshToken: user.OAuth2RefreshToken, + Expiry: user.OAuth2Expiry, + } +} + +func (h *BaseHandler) RefreshToken(w http.ResponseWriter, r *http.Request, user *AuthenticatedUser) (*AuthenticatedUser, error) { + tok := h.AuthenticatedUserToOAuth2Token(user) + conf := newOAuth2(h.config) + refreshed, err := conf.TokenSource(context.Background(), tok).Token() + if err != nil { + return nil, err + } + + refreshedUser := &AuthenticatedUser{ + ID: user.ID, + Email: user.Email, + OAuth2AccessToken: refreshed.AccessToken, + OAuth2RefreshToken: refreshed.RefreshToken, + OAuth2TokenType: refreshed.TokenType, + OAuth2Expiry: refreshed.Expiry, + } + + err = h.SetAuthenticatedUserForRequest(w, r, refreshedUser) + if err != nil { + return nil, err + } + + return refreshedUser, nil +} + func (h *BaseHandler) OAuth2LoginGET(w http.ResponseWriter, r *http.Request) { conf := newOAuth2(h.config) @@ -54,10 +94,41 @@ func (h *BaseHandler) OAuth2CallbackGET(w http.ResponseWriter, r *http.Request) http.Error(w, err.Error(), http.StatusInternalServerError) } - _, err = w.Write(body) + var userInfo UserInfo + err = json.Unmarshal(body, &userInfo) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return } + err = h.SetAuthenticatedUserForRequest(w, r, &AuthenticatedUser{ + ID: userInfo.Sub, + Email: userInfo.Email, + OAuth2AccessToken: tok.AccessToken, + OAuth2RefreshToken: tok.RefreshToken, + OAuth2TokenType: tok.TokenType, + OAuth2Expiry: tok.Expiry, + }) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + http.Redirect(w, r, "/settings", http.StatusTemporaryRedirect) +} + +func (h *BaseHandler) OAuth2LogoutGET(w http.ResponseWriter, r *http.Request, user *AuthenticatedUser) { + tok := h.AuthenticatedUserToOAuth2Token(user) + client := oauth2.NewClient(context.Background(), oauth2.StaticTokenSource(tok)) + _, err := client.Get(h.config.OAUTH2_ENDPOINT_USER_INFO_URL) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + err = h.ClearAuthenticatedUserForRequest(w, r) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + } + http.Redirect(w, r, "/", http.StatusTemporaryRedirect) } diff --git a/internal/handlers/session.go b/internal/handlers/session.go new file mode 100644 index 0000000..0d5207b --- /dev/null +++ b/internal/handlers/session.go @@ -0,0 +1,111 @@ +package handlers + +import ( + "context" + "fmt" + "net/http" + "time" +) + +const sessionName = "zdravko-hey" + +type AuthenticatedUser struct { + ID string + Email string + OAuth2AccessToken string + OAuth2RefreshToken string + OAuth2TokenType string + OAuth2Expiry time.Time +} + +type authenticatedUserKeyType string + +const authenticatedUserKey authenticatedUserKeyType = "authenticatedUser" + +func WithUser(ctx context.Context, user *AuthenticatedUser) context.Context { + return context.WithValue(ctx, authenticatedUserKey, user) +} + +func GetUser(ctx context.Context) *AuthenticatedUser { + user, ok := ctx.Value(authenticatedUserKey).(*AuthenticatedUser) + if !ok { + return nil + } + return user +} + +func (h *BaseHandler) GetAuthenticatedUserForRequest(r *http.Request) (*AuthenticatedUser, error) { + session, err := h.store.Get(r, sessionName) + if err != nil { + return nil, err + } + if session.IsNew { + return nil, fmt.Errorf("session is nil") + } + + expiry, err := time.Parse(time.RFC3339, session.Values["oauth2_expiry"].(string)) + if err != nil { + return nil, err + } + + user := &AuthenticatedUser{ + ID: session.Values["id"].(string), + Email: session.Values["email"].(string), + OAuth2AccessToken: session.Values["oauth2_access_token"].(string), + OAuth2RefreshToken: session.Values["oauth2_refresh_token"].(string), + OAuth2TokenType: session.Values["oauth2_token_type"].(string), + OAuth2Expiry: expiry, + } + + return user, nil +} + +func (h *BaseHandler) SetAuthenticatedUserForRequest(w http.ResponseWriter, r *http.Request, user *AuthenticatedUser) error { + session, err := h.store.Get(r, sessionName) + if err != nil { + return err + } + session.Values["id"] = user.ID + session.Values["email"] = user.Email + session.Values["oauth2_access_token"] = user.OAuth2AccessToken + session.Values["oauth2_refresh_token"] = user.OAuth2RefreshToken + session.Values["oauth2_token_type"] = user.OAuth2TokenType + session.Values["oauth2_expiry"] = user.OAuth2Expiry.Format(time.RFC3339) + err = h.store.Save(r, w, session) + if err != nil { + return err + } + return nil +} + +func (h *BaseHandler) ClearAuthenticatedUserForRequest(w http.ResponseWriter, r *http.Request) error { + session, err := h.store.Get(r, sessionName) + if err != nil { + return err + } + session.Options.MaxAge = -1 + err = h.store.Save(r, w, session) + if err != nil { + return err + } + return nil +} + +type AuthenticatedHandler func(http.ResponseWriter, *http.Request, *AuthenticatedUser) + +func (h *BaseHandler) Authenticated(next AuthenticatedHandler) func(http.ResponseWriter, *http.Request) { + return func(w http.ResponseWriter, r *http.Request) { + user, err := h.GetAuthenticatedUserForRequest(r) + if err != nil { + http.Redirect(w, r, "/oauth2/login", http.StatusTemporaryRedirect) + return + } + if user.OAuth2Expiry.Before(time.Now()) { + user, err = h.RefreshToken(w, r, user) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + } + } + next(w, r, user) + } +} diff --git a/internal/handlers/settings.go b/internal/handlers/settings.go index c4e119b..ef8955f 100644 --- a/internal/handlers/settings.go +++ b/internal/handlers/settings.go @@ -7,7 +7,7 @@ import ( "code.tjo.space/mentos1386/zdravko/web/templates" ) -func (h *BaseHandler) Settings(w http.ResponseWriter, r *http.Request) { +func (h *BaseHandler) Settings(w http.ResponseWriter, r *http.Request, user *AuthenticatedUser) { ts, err := template.ParseFS(templates.Templates, "components/base.tmpl", "pages/settings.tmpl", @@ -17,7 +17,7 @@ func (h *BaseHandler) Settings(w http.ResponseWriter, r *http.Request) { return } - err = ts.ExecuteTemplate(w, "base", nil) + err = ts.ExecuteTemplate(w, "base", user) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) } diff --git a/web/templates/pages/settings.tmpl b/web/templates/pages/settings.tmpl index a635aaf..c619330 100644 --- a/web/templates/pages/settings.tmpl +++ b/web/templates/pages/settings.tmpl @@ -2,4 +2,8 @@ {{define "main"}}

The settings!

+

You are logged in as {{.Email}}.

+

Your id is {{.ID}}.

+

Your access expieres at {{.OAuth2Expiry}}.

+ Logout {{end}}