From 7085f0e4d69f25efc3cdd1e7f30dcea88ac226d7 Mon Sep 17 00:00:00 2001 From: Tine Date: Mon, 12 Feb 2024 09:25:11 +0100 Subject: [PATCH] feat(oauth2): use and store state --- internal/database.go | 5 +- internal/handlers/oauth2.go | 38 +- internal/handlers/session.go | 3 +- internal/models/models.go | 7 + internal/models/query/gen.go | 8 + internal/models/query/o_auth2_states.gen.go | 394 ++++++++++++++++++++ justfile | 2 +- tailwind.config.js | 2 +- tools/generate/main.go | 6 +- 9 files changed, 458 insertions(+), 7 deletions(-) create mode 100644 internal/models/query/o_auth2_states.gen.go diff --git a/internal/database.go b/internal/database.go index 8467f5d..c858554 100644 --- a/internal/database.go +++ b/internal/database.go @@ -14,7 +14,10 @@ func ConnectToDatabase(path string) (*gorm.DB, *query.Query, error) { return nil, nil, err } - db.AutoMigrate(&models.Healthcheck{}) + err = db.AutoMigrate(&models.Healthcheck{}, &models.OAuth2State{}) + if err != nil { + return nil, nil, err + } q := query.Use(db) diff --git a/internal/handlers/oauth2.go b/internal/handlers/oauth2.go index 99ebfe5..7e789fa 100644 --- a/internal/handlers/oauth2.go +++ b/internal/handlers/oauth2.go @@ -2,11 +2,16 @@ package handlers import ( "context" + "crypto/rand" + "encoding/hex" "encoding/json" + "fmt" "io" "net/http" + "time" "code.tjo.space/mentos1386/zdravko/internal" + "code.tjo.space/mentos1386/zdravko/internal/models" "golang.org/x/oauth2" ) @@ -15,6 +20,15 @@ type UserInfo struct { Email string `json:"email"` } +func newRandomState() string { + b := make([]byte, 16) + _, err := rand.Read(b) + if err != nil { + panic(err) + } + return hex.EncodeToString(b) +} + func newOAuth2(config *internal.Config) *oauth2.Config { return &oauth2.Config{ ClientID: config.OAUTH2_CLIENT_ID, @@ -42,6 +56,7 @@ func (h *BaseHandler) RefreshToken(w http.ResponseWriter, r *http.Request, user conf := newOAuth2(h.config) refreshed, err := conf.TokenSource(context.Background(), tok).Token() if err != nil { + fmt.Println("Error: ", err) return nil, err } @@ -65,7 +80,13 @@ func (h *BaseHandler) RefreshToken(w http.ResponseWriter, r *http.Request, user func (h *BaseHandler) OAuth2LoginGET(w http.ResponseWriter, r *http.Request) { conf := newOAuth2(h.config) - url := conf.AuthCodeURL("state", oauth2.AccessTypeOffline) + state := newRandomState() + result := h.db.Create(&models.OAuth2State{State: state, Expiry: time.Now().Add(5 * time.Minute)}) + if result.Error != nil { + http.Error(w, result.Error.Error(), http.StatusInternalServerError) + } + + url := conf.AuthCodeURL(state, oauth2.AccessTypeOffline) http.Redirect(w, r, url, http.StatusTemporaryRedirect) } @@ -74,6 +95,21 @@ func (h *BaseHandler) OAuth2CallbackGET(w http.ResponseWriter, r *http.Request) ctx := context.Background() conf := newOAuth2(h.config) + state := r.URL.Query().Get("state") + + result, err := h.query.OAuth2State.WithContext(ctx).Where( + h.query.OAuth2State.State.Eq(state), + h.query.OAuth2State.Expiry.Gt(time.Now()), + ).Delete() + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + if result.RowsAffected != 1 { + http.Error(w, "Invalid state", http.StatusUnauthorized) + return + } + // Exchange the code for a new token. tok, err := conf.Exchange(r.Context(), r.URL.Query().Get("code")) if err != nil { diff --git a/internal/handlers/session.go b/internal/handlers/session.go index 0d5207b..1010c27 100644 --- a/internal/handlers/session.go +++ b/internal/handlers/session.go @@ -103,7 +103,8 @@ func (h *BaseHandler) Authenticated(next AuthenticatedHandler) func(http.Respons if user.OAuth2Expiry.Before(time.Now()) { user, err = h.RefreshToken(w, r, user) if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) + http.Redirect(w, r, "/oauth2/login", http.StatusTemporaryRedirect) + return } } next(w, r, user) diff --git a/internal/models/models.go b/internal/models/models.go index 27ab8bd..033a80f 100644 --- a/internal/models/models.go +++ b/internal/models/models.go @@ -1,5 +1,12 @@ package models +import "time" + +type OAuth2State struct { + State string `gorm:"primary_key"` + Expiry time.Time +} + type Healthcheck struct { ID uint `gorm:"primary_key"` Name string diff --git a/internal/models/query/gen.go b/internal/models/query/gen.go index ccee5be..34d8b2b 100644 --- a/internal/models/query/gen.go +++ b/internal/models/query/gen.go @@ -18,17 +18,20 @@ import ( var ( Q = new(Query) Healthcheck *healthcheck + OAuth2State *oAuth2State ) func SetDefault(db *gorm.DB, opts ...gen.DOOption) { *Q = *Use(db, opts...) Healthcheck = &Q.Healthcheck + OAuth2State = &Q.OAuth2State } func Use(db *gorm.DB, opts ...gen.DOOption) *Query { return &Query{ db: db, Healthcheck: newHealthcheck(db, opts...), + OAuth2State: newOAuth2State(db, opts...), } } @@ -36,6 +39,7 @@ type Query struct { db *gorm.DB Healthcheck healthcheck + OAuth2State oAuth2State } func (q *Query) Available() bool { return q.db != nil } @@ -44,6 +48,7 @@ func (q *Query) clone(db *gorm.DB) *Query { return &Query{ db: db, Healthcheck: q.Healthcheck.clone(db), + OAuth2State: q.OAuth2State.clone(db), } } @@ -59,16 +64,19 @@ func (q *Query) ReplaceDB(db *gorm.DB) *Query { return &Query{ db: db, Healthcheck: q.Healthcheck.replaceDB(db), + OAuth2State: q.OAuth2State.replaceDB(db), } } type queryCtx struct { Healthcheck IHealthcheckDo + OAuth2State IOAuth2StateDo } func (q *Query) WithContext(ctx context.Context) *queryCtx { return &queryCtx{ Healthcheck: q.Healthcheck.WithContext(ctx), + OAuth2State: q.OAuth2State.WithContext(ctx), } } diff --git a/internal/models/query/o_auth2_states.gen.go b/internal/models/query/o_auth2_states.gen.go new file mode 100644 index 0000000..56aa48a --- /dev/null +++ b/internal/models/query/o_auth2_states.gen.go @@ -0,0 +1,394 @@ +// Code generated by gorm.io/gen. DO NOT EDIT. +// Code generated by gorm.io/gen. DO NOT EDIT. +// Code generated by gorm.io/gen. DO NOT EDIT. + +package query + +import ( + "context" + + "gorm.io/gorm" + "gorm.io/gorm/clause" + "gorm.io/gorm/schema" + + "gorm.io/gen" + "gorm.io/gen/field" + + "gorm.io/plugin/dbresolver" + + "code.tjo.space/mentos1386/zdravko/internal/models" +) + +func newOAuth2State(db *gorm.DB, opts ...gen.DOOption) oAuth2State { + _oAuth2State := oAuth2State{} + + _oAuth2State.oAuth2StateDo.UseDB(db, opts...) + _oAuth2State.oAuth2StateDo.UseModel(&models.OAuth2State{}) + + tableName := _oAuth2State.oAuth2StateDo.TableName() + _oAuth2State.ALL = field.NewAsterisk(tableName) + _oAuth2State.State = field.NewString(tableName, "state") + _oAuth2State.Expiry = field.NewTime(tableName, "expiry") + + _oAuth2State.fillFieldMap() + + return _oAuth2State +} + +type oAuth2State struct { + oAuth2StateDo oAuth2StateDo + + ALL field.Asterisk + State field.String + Expiry field.Time + + fieldMap map[string]field.Expr +} + +func (o oAuth2State) Table(newTableName string) *oAuth2State { + o.oAuth2StateDo.UseTable(newTableName) + return o.updateTableName(newTableName) +} + +func (o oAuth2State) As(alias string) *oAuth2State { + o.oAuth2StateDo.DO = *(o.oAuth2StateDo.As(alias).(*gen.DO)) + return o.updateTableName(alias) +} + +func (o *oAuth2State) updateTableName(table string) *oAuth2State { + o.ALL = field.NewAsterisk(table) + o.State = field.NewString(table, "state") + o.Expiry = field.NewTime(table, "expiry") + + o.fillFieldMap() + + return o +} + +func (o *oAuth2State) WithContext(ctx context.Context) IOAuth2StateDo { + return o.oAuth2StateDo.WithContext(ctx) +} + +func (o oAuth2State) TableName() string { return o.oAuth2StateDo.TableName() } + +func (o oAuth2State) Alias() string { return o.oAuth2StateDo.Alias() } + +func (o oAuth2State) Columns(cols ...field.Expr) gen.Columns { return o.oAuth2StateDo.Columns(cols...) } + +func (o *oAuth2State) GetFieldByName(fieldName string) (field.OrderExpr, bool) { + _f, ok := o.fieldMap[fieldName] + if !ok || _f == nil { + return nil, false + } + _oe, ok := _f.(field.OrderExpr) + return _oe, ok +} + +func (o *oAuth2State) fillFieldMap() { + o.fieldMap = make(map[string]field.Expr, 2) + o.fieldMap["state"] = o.State + o.fieldMap["expiry"] = o.Expiry +} + +func (o oAuth2State) clone(db *gorm.DB) oAuth2State { + o.oAuth2StateDo.ReplaceConnPool(db.Statement.ConnPool) + return o +} + +func (o oAuth2State) replaceDB(db *gorm.DB) oAuth2State { + o.oAuth2StateDo.ReplaceDB(db) + return o +} + +type oAuth2StateDo struct{ gen.DO } + +type IOAuth2StateDo interface { + gen.SubQuery + Debug() IOAuth2StateDo + WithContext(ctx context.Context) IOAuth2StateDo + WithResult(fc func(tx gen.Dao)) gen.ResultInfo + ReplaceDB(db *gorm.DB) + ReadDB() IOAuth2StateDo + WriteDB() IOAuth2StateDo + As(alias string) gen.Dao + Session(config *gorm.Session) IOAuth2StateDo + Columns(cols ...field.Expr) gen.Columns + Clauses(conds ...clause.Expression) IOAuth2StateDo + Not(conds ...gen.Condition) IOAuth2StateDo + Or(conds ...gen.Condition) IOAuth2StateDo + Select(conds ...field.Expr) IOAuth2StateDo + Where(conds ...gen.Condition) IOAuth2StateDo + Order(conds ...field.Expr) IOAuth2StateDo + Distinct(cols ...field.Expr) IOAuth2StateDo + Omit(cols ...field.Expr) IOAuth2StateDo + Join(table schema.Tabler, on ...field.Expr) IOAuth2StateDo + LeftJoin(table schema.Tabler, on ...field.Expr) IOAuth2StateDo + RightJoin(table schema.Tabler, on ...field.Expr) IOAuth2StateDo + Group(cols ...field.Expr) IOAuth2StateDo + Having(conds ...gen.Condition) IOAuth2StateDo + Limit(limit int) IOAuth2StateDo + Offset(offset int) IOAuth2StateDo + Count() (count int64, err error) + Scopes(funcs ...func(gen.Dao) gen.Dao) IOAuth2StateDo + Unscoped() IOAuth2StateDo + Create(values ...*models.OAuth2State) error + CreateInBatches(values []*models.OAuth2State, batchSize int) error + Save(values ...*models.OAuth2State) error + First() (*models.OAuth2State, error) + Take() (*models.OAuth2State, error) + Last() (*models.OAuth2State, error) + Find() ([]*models.OAuth2State, error) + FindInBatch(batchSize int, fc func(tx gen.Dao, batch int) error) (results []*models.OAuth2State, err error) + FindInBatches(result *[]*models.OAuth2State, batchSize int, fc func(tx gen.Dao, batch int) error) error + Pluck(column field.Expr, dest interface{}) error + Delete(...*models.OAuth2State) (info gen.ResultInfo, err error) + Update(column field.Expr, value interface{}) (info gen.ResultInfo, err error) + UpdateSimple(columns ...field.AssignExpr) (info gen.ResultInfo, err error) + Updates(value interface{}) (info gen.ResultInfo, err error) + UpdateColumn(column field.Expr, value interface{}) (info gen.ResultInfo, err error) + UpdateColumnSimple(columns ...field.AssignExpr) (info gen.ResultInfo, err error) + UpdateColumns(value interface{}) (info gen.ResultInfo, err error) + UpdateFrom(q gen.SubQuery) gen.Dao + Attrs(attrs ...field.AssignExpr) IOAuth2StateDo + Assign(attrs ...field.AssignExpr) IOAuth2StateDo + Joins(fields ...field.RelationField) IOAuth2StateDo + Preload(fields ...field.RelationField) IOAuth2StateDo + FirstOrInit() (*models.OAuth2State, error) + FirstOrCreate() (*models.OAuth2State, error) + FindByPage(offset int, limit int) (result []*models.OAuth2State, count int64, err error) + ScanByPage(result interface{}, offset int, limit int) (count int64, err error) + Scan(result interface{}) (err error) + Returning(value interface{}, columns ...string) IOAuth2StateDo + UnderlyingDB() *gorm.DB + schema.Tabler +} + +func (o oAuth2StateDo) Debug() IOAuth2StateDo { + return o.withDO(o.DO.Debug()) +} + +func (o oAuth2StateDo) WithContext(ctx context.Context) IOAuth2StateDo { + return o.withDO(o.DO.WithContext(ctx)) +} + +func (o oAuth2StateDo) ReadDB() IOAuth2StateDo { + return o.Clauses(dbresolver.Read) +} + +func (o oAuth2StateDo) WriteDB() IOAuth2StateDo { + return o.Clauses(dbresolver.Write) +} + +func (o oAuth2StateDo) Session(config *gorm.Session) IOAuth2StateDo { + return o.withDO(o.DO.Session(config)) +} + +func (o oAuth2StateDo) Clauses(conds ...clause.Expression) IOAuth2StateDo { + return o.withDO(o.DO.Clauses(conds...)) +} + +func (o oAuth2StateDo) Returning(value interface{}, columns ...string) IOAuth2StateDo { + return o.withDO(o.DO.Returning(value, columns...)) +} + +func (o oAuth2StateDo) Not(conds ...gen.Condition) IOAuth2StateDo { + return o.withDO(o.DO.Not(conds...)) +} + +func (o oAuth2StateDo) Or(conds ...gen.Condition) IOAuth2StateDo { + return o.withDO(o.DO.Or(conds...)) +} + +func (o oAuth2StateDo) Select(conds ...field.Expr) IOAuth2StateDo { + return o.withDO(o.DO.Select(conds...)) +} + +func (o oAuth2StateDo) Where(conds ...gen.Condition) IOAuth2StateDo { + return o.withDO(o.DO.Where(conds...)) +} + +func (o oAuth2StateDo) Order(conds ...field.Expr) IOAuth2StateDo { + return o.withDO(o.DO.Order(conds...)) +} + +func (o oAuth2StateDo) Distinct(cols ...field.Expr) IOAuth2StateDo { + return o.withDO(o.DO.Distinct(cols...)) +} + +func (o oAuth2StateDo) Omit(cols ...field.Expr) IOAuth2StateDo { + return o.withDO(o.DO.Omit(cols...)) +} + +func (o oAuth2StateDo) Join(table schema.Tabler, on ...field.Expr) IOAuth2StateDo { + return o.withDO(o.DO.Join(table, on...)) +} + +func (o oAuth2StateDo) LeftJoin(table schema.Tabler, on ...field.Expr) IOAuth2StateDo { + return o.withDO(o.DO.LeftJoin(table, on...)) +} + +func (o oAuth2StateDo) RightJoin(table schema.Tabler, on ...field.Expr) IOAuth2StateDo { + return o.withDO(o.DO.RightJoin(table, on...)) +} + +func (o oAuth2StateDo) Group(cols ...field.Expr) IOAuth2StateDo { + return o.withDO(o.DO.Group(cols...)) +} + +func (o oAuth2StateDo) Having(conds ...gen.Condition) IOAuth2StateDo { + return o.withDO(o.DO.Having(conds...)) +} + +func (o oAuth2StateDo) Limit(limit int) IOAuth2StateDo { + return o.withDO(o.DO.Limit(limit)) +} + +func (o oAuth2StateDo) Offset(offset int) IOAuth2StateDo { + return o.withDO(o.DO.Offset(offset)) +} + +func (o oAuth2StateDo) Scopes(funcs ...func(gen.Dao) gen.Dao) IOAuth2StateDo { + return o.withDO(o.DO.Scopes(funcs...)) +} + +func (o oAuth2StateDo) Unscoped() IOAuth2StateDo { + return o.withDO(o.DO.Unscoped()) +} + +func (o oAuth2StateDo) Create(values ...*models.OAuth2State) error { + if len(values) == 0 { + return nil + } + return o.DO.Create(values) +} + +func (o oAuth2StateDo) CreateInBatches(values []*models.OAuth2State, batchSize int) error { + return o.DO.CreateInBatches(values, batchSize) +} + +// Save : !!! underlying implementation is different with GORM +// The method is equivalent to executing the statement: db.Clauses(clause.OnConflict{UpdateAll: true}).Create(values) +func (o oAuth2StateDo) Save(values ...*models.OAuth2State) error { + if len(values) == 0 { + return nil + } + return o.DO.Save(values) +} + +func (o oAuth2StateDo) First() (*models.OAuth2State, error) { + if result, err := o.DO.First(); err != nil { + return nil, err + } else { + return result.(*models.OAuth2State), nil + } +} + +func (o oAuth2StateDo) Take() (*models.OAuth2State, error) { + if result, err := o.DO.Take(); err != nil { + return nil, err + } else { + return result.(*models.OAuth2State), nil + } +} + +func (o oAuth2StateDo) Last() (*models.OAuth2State, error) { + if result, err := o.DO.Last(); err != nil { + return nil, err + } else { + return result.(*models.OAuth2State), nil + } +} + +func (o oAuth2StateDo) Find() ([]*models.OAuth2State, error) { + result, err := o.DO.Find() + return result.([]*models.OAuth2State), err +} + +func (o oAuth2StateDo) FindInBatch(batchSize int, fc func(tx gen.Dao, batch int) error) (results []*models.OAuth2State, err error) { + buf := make([]*models.OAuth2State, 0, batchSize) + err = o.DO.FindInBatches(&buf, batchSize, func(tx gen.Dao, batch int) error { + defer func() { results = append(results, buf...) }() + return fc(tx, batch) + }) + return results, err +} + +func (o oAuth2StateDo) FindInBatches(result *[]*models.OAuth2State, batchSize int, fc func(tx gen.Dao, batch int) error) error { + return o.DO.FindInBatches(result, batchSize, fc) +} + +func (o oAuth2StateDo) Attrs(attrs ...field.AssignExpr) IOAuth2StateDo { + return o.withDO(o.DO.Attrs(attrs...)) +} + +func (o oAuth2StateDo) Assign(attrs ...field.AssignExpr) IOAuth2StateDo { + return o.withDO(o.DO.Assign(attrs...)) +} + +func (o oAuth2StateDo) Joins(fields ...field.RelationField) IOAuth2StateDo { + for _, _f := range fields { + o = *o.withDO(o.DO.Joins(_f)) + } + return &o +} + +func (o oAuth2StateDo) Preload(fields ...field.RelationField) IOAuth2StateDo { + for _, _f := range fields { + o = *o.withDO(o.DO.Preload(_f)) + } + return &o +} + +func (o oAuth2StateDo) FirstOrInit() (*models.OAuth2State, error) { + if result, err := o.DO.FirstOrInit(); err != nil { + return nil, err + } else { + return result.(*models.OAuth2State), nil + } +} + +func (o oAuth2StateDo) FirstOrCreate() (*models.OAuth2State, error) { + if result, err := o.DO.FirstOrCreate(); err != nil { + return nil, err + } else { + return result.(*models.OAuth2State), nil + } +} + +func (o oAuth2StateDo) FindByPage(offset int, limit int) (result []*models.OAuth2State, count int64, err error) { + result, err = o.Offset(offset).Limit(limit).Find() + if err != nil { + return + } + + if size := len(result); 0 < limit && 0 < size && size < limit { + count = int64(size + offset) + return + } + + count, err = o.Offset(-1).Limit(-1).Count() + return +} + +func (o oAuth2StateDo) ScanByPage(result interface{}, offset int, limit int) (count int64, err error) { + count, err = o.Count() + if err != nil { + return + } + + err = o.Offset(offset).Limit(limit).Scan(result) + return +} + +func (o oAuth2StateDo) Scan(result interface{}) (err error) { + return o.DO.Scan(result) +} + +func (o oAuth2StateDo) Delete(models ...*models.OAuth2State) (result gen.ResultInfo, err error) { + return o.DO.Delete(models) +} + +func (o *oAuth2StateDo) withDO(do gen.Dao) *oAuth2StateDo { + o.DO = *do.(*gen.DO) + return o +} diff --git a/justfile b/justfile index e8d2b88..efc226e 100644 --- a/justfile +++ b/justfile @@ -3,7 +3,7 @@ set shell := ["devbox", "run"] # Load dotenv set dotenv-load -STATIC_DIR := "./internal/static" +STATIC_DIR := "./web/static" # Run full development environment run: diff --git a/tailwind.config.js b/tailwind.config.js index 2a46255..9c185b7 100644 --- a/tailwind.config.js +++ b/tailwind.config.js @@ -1,5 +1,5 @@ module.exports = { - content: ["./internal/ui/**/*.{tmpl,go}"], + content: ["./web/templates/**/*.{tmpl,go}"], theme: { container: { center: true, diff --git a/tools/generate/main.go b/tools/generate/main.go index 2bc25da..35b0868 100644 --- a/tools/generate/main.go +++ b/tools/generate/main.go @@ -7,6 +7,8 @@ import ( ) func main() { + config := internal.NewConfig() + // Initialize the generator with configuration g := gen.NewGenerator(gen.Config{ OutPath: "internal/models/query", @@ -14,14 +16,14 @@ func main() { FieldNullable: true, }) - db, _ := internal.ConnectToDatabase() + db, _, _ := internal.ConnectToDatabase(config.SQLITE_DB_PATH) // Use the above `*gorm.DB` instance to initialize the generator, // which is required to generate structs from db when using `GenerateModel/GenerateModelAs` g.UseDB(db) // Generate default DAO interface for those specified structs - g.ApplyBasic(models.Healthcheck{}) + g.ApplyBasic(models.Healthcheck{}, models.OAuth2State{}) // Execute the generator g.Execute()