feat(oauth2): use and store state

This commit is contained in:
Tine 2024-02-12 09:25:11 +01:00
parent 0a323c79e6
commit 7085f0e4d6
Signed by: mentos1386
SSH key fingerprint: SHA256:MNtTsLbihYaWF8j1fkOHfkKNlnN1JQfxEU/rBU8nCGw
9 changed files with 458 additions and 7 deletions

View file

@ -14,7 +14,10 @@ func ConnectToDatabase(path string) (*gorm.DB, *query.Query, error) {
return nil, nil, err
}
db.AutoMigrate(&models.Healthcheck{})
err = db.AutoMigrate(&models.Healthcheck{}, &models.OAuth2State{})
if err != nil {
return nil, nil, err
}
q := query.Use(db)

View file

@ -2,11 +2,16 @@ package handlers
import (
"context"
"crypto/rand"
"encoding/hex"
"encoding/json"
"fmt"
"io"
"net/http"
"time"
"code.tjo.space/mentos1386/zdravko/internal"
"code.tjo.space/mentos1386/zdravko/internal/models"
"golang.org/x/oauth2"
)
@ -15,6 +20,15 @@ type UserInfo struct {
Email string `json:"email"`
}
func newRandomState() string {
b := make([]byte, 16)
_, err := rand.Read(b)
if err != nil {
panic(err)
}
return hex.EncodeToString(b)
}
func newOAuth2(config *internal.Config) *oauth2.Config {
return &oauth2.Config{
ClientID: config.OAUTH2_CLIENT_ID,
@ -42,6 +56,7 @@ func (h *BaseHandler) RefreshToken(w http.ResponseWriter, r *http.Request, user
conf := newOAuth2(h.config)
refreshed, err := conf.TokenSource(context.Background(), tok).Token()
if err != nil {
fmt.Println("Error: ", err)
return nil, err
}
@ -65,7 +80,13 @@ func (h *BaseHandler) RefreshToken(w http.ResponseWriter, r *http.Request, user
func (h *BaseHandler) OAuth2LoginGET(w http.ResponseWriter, r *http.Request) {
conf := newOAuth2(h.config)
url := conf.AuthCodeURL("state", oauth2.AccessTypeOffline)
state := newRandomState()
result := h.db.Create(&models.OAuth2State{State: state, Expiry: time.Now().Add(5 * time.Minute)})
if result.Error != nil {
http.Error(w, result.Error.Error(), http.StatusInternalServerError)
}
url := conf.AuthCodeURL(state, oauth2.AccessTypeOffline)
http.Redirect(w, r, url, http.StatusTemporaryRedirect)
}
@ -74,6 +95,21 @@ func (h *BaseHandler) OAuth2CallbackGET(w http.ResponseWriter, r *http.Request)
ctx := context.Background()
conf := newOAuth2(h.config)
state := r.URL.Query().Get("state")
result, err := h.query.OAuth2State.WithContext(ctx).Where(
h.query.OAuth2State.State.Eq(state),
h.query.OAuth2State.Expiry.Gt(time.Now()),
).Delete()
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
if result.RowsAffected != 1 {
http.Error(w, "Invalid state", http.StatusUnauthorized)
return
}
// Exchange the code for a new token.
tok, err := conf.Exchange(r.Context(), r.URL.Query().Get("code"))
if err != nil {

View file

@ -103,7 +103,8 @@ func (h *BaseHandler) Authenticated(next AuthenticatedHandler) func(http.Respons
if user.OAuth2Expiry.Before(time.Now()) {
user, err = h.RefreshToken(w, r, user)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
http.Redirect(w, r, "/oauth2/login", http.StatusTemporaryRedirect)
return
}
}
next(w, r, user)

View file

@ -1,5 +1,12 @@
package models
import "time"
type OAuth2State struct {
State string `gorm:"primary_key"`
Expiry time.Time
}
type Healthcheck struct {
ID uint `gorm:"primary_key"`
Name string

View file

@ -18,17 +18,20 @@ import (
var (
Q = new(Query)
Healthcheck *healthcheck
OAuth2State *oAuth2State
)
func SetDefault(db *gorm.DB, opts ...gen.DOOption) {
*Q = *Use(db, opts...)
Healthcheck = &Q.Healthcheck
OAuth2State = &Q.OAuth2State
}
func Use(db *gorm.DB, opts ...gen.DOOption) *Query {
return &Query{
db: db,
Healthcheck: newHealthcheck(db, opts...),
OAuth2State: newOAuth2State(db, opts...),
}
}
@ -36,6 +39,7 @@ type Query struct {
db *gorm.DB
Healthcheck healthcheck
OAuth2State oAuth2State
}
func (q *Query) Available() bool { return q.db != nil }
@ -44,6 +48,7 @@ func (q *Query) clone(db *gorm.DB) *Query {
return &Query{
db: db,
Healthcheck: q.Healthcheck.clone(db),
OAuth2State: q.OAuth2State.clone(db),
}
}
@ -59,16 +64,19 @@ func (q *Query) ReplaceDB(db *gorm.DB) *Query {
return &Query{
db: db,
Healthcheck: q.Healthcheck.replaceDB(db),
OAuth2State: q.OAuth2State.replaceDB(db),
}
}
type queryCtx struct {
Healthcheck IHealthcheckDo
OAuth2State IOAuth2StateDo
}
func (q *Query) WithContext(ctx context.Context) *queryCtx {
return &queryCtx{
Healthcheck: q.Healthcheck.WithContext(ctx),
OAuth2State: q.OAuth2State.WithContext(ctx),
}
}

View file

@ -0,0 +1,394 @@
// Code generated by gorm.io/gen. DO NOT EDIT.
// Code generated by gorm.io/gen. DO NOT EDIT.
// Code generated by gorm.io/gen. DO NOT EDIT.
package query
import (
"context"
"gorm.io/gorm"
"gorm.io/gorm/clause"
"gorm.io/gorm/schema"
"gorm.io/gen"
"gorm.io/gen/field"
"gorm.io/plugin/dbresolver"
"code.tjo.space/mentos1386/zdravko/internal/models"
)
func newOAuth2State(db *gorm.DB, opts ...gen.DOOption) oAuth2State {
_oAuth2State := oAuth2State{}
_oAuth2State.oAuth2StateDo.UseDB(db, opts...)
_oAuth2State.oAuth2StateDo.UseModel(&models.OAuth2State{})
tableName := _oAuth2State.oAuth2StateDo.TableName()
_oAuth2State.ALL = field.NewAsterisk(tableName)
_oAuth2State.State = field.NewString(tableName, "state")
_oAuth2State.Expiry = field.NewTime(tableName, "expiry")
_oAuth2State.fillFieldMap()
return _oAuth2State
}
type oAuth2State struct {
oAuth2StateDo oAuth2StateDo
ALL field.Asterisk
State field.String
Expiry field.Time
fieldMap map[string]field.Expr
}
func (o oAuth2State) Table(newTableName string) *oAuth2State {
o.oAuth2StateDo.UseTable(newTableName)
return o.updateTableName(newTableName)
}
func (o oAuth2State) As(alias string) *oAuth2State {
o.oAuth2StateDo.DO = *(o.oAuth2StateDo.As(alias).(*gen.DO))
return o.updateTableName(alias)
}
func (o *oAuth2State) updateTableName(table string) *oAuth2State {
o.ALL = field.NewAsterisk(table)
o.State = field.NewString(table, "state")
o.Expiry = field.NewTime(table, "expiry")
o.fillFieldMap()
return o
}
func (o *oAuth2State) WithContext(ctx context.Context) IOAuth2StateDo {
return o.oAuth2StateDo.WithContext(ctx)
}
func (o oAuth2State) TableName() string { return o.oAuth2StateDo.TableName() }
func (o oAuth2State) Alias() string { return o.oAuth2StateDo.Alias() }
func (o oAuth2State) Columns(cols ...field.Expr) gen.Columns { return o.oAuth2StateDo.Columns(cols...) }
func (o *oAuth2State) GetFieldByName(fieldName string) (field.OrderExpr, bool) {
_f, ok := o.fieldMap[fieldName]
if !ok || _f == nil {
return nil, false
}
_oe, ok := _f.(field.OrderExpr)
return _oe, ok
}
func (o *oAuth2State) fillFieldMap() {
o.fieldMap = make(map[string]field.Expr, 2)
o.fieldMap["state"] = o.State
o.fieldMap["expiry"] = o.Expiry
}
func (o oAuth2State) clone(db *gorm.DB) oAuth2State {
o.oAuth2StateDo.ReplaceConnPool(db.Statement.ConnPool)
return o
}
func (o oAuth2State) replaceDB(db *gorm.DB) oAuth2State {
o.oAuth2StateDo.ReplaceDB(db)
return o
}
type oAuth2StateDo struct{ gen.DO }
type IOAuth2StateDo interface {
gen.SubQuery
Debug() IOAuth2StateDo
WithContext(ctx context.Context) IOAuth2StateDo
WithResult(fc func(tx gen.Dao)) gen.ResultInfo
ReplaceDB(db *gorm.DB)
ReadDB() IOAuth2StateDo
WriteDB() IOAuth2StateDo
As(alias string) gen.Dao
Session(config *gorm.Session) IOAuth2StateDo
Columns(cols ...field.Expr) gen.Columns
Clauses(conds ...clause.Expression) IOAuth2StateDo
Not(conds ...gen.Condition) IOAuth2StateDo
Or(conds ...gen.Condition) IOAuth2StateDo
Select(conds ...field.Expr) IOAuth2StateDo
Where(conds ...gen.Condition) IOAuth2StateDo
Order(conds ...field.Expr) IOAuth2StateDo
Distinct(cols ...field.Expr) IOAuth2StateDo
Omit(cols ...field.Expr) IOAuth2StateDo
Join(table schema.Tabler, on ...field.Expr) IOAuth2StateDo
LeftJoin(table schema.Tabler, on ...field.Expr) IOAuth2StateDo
RightJoin(table schema.Tabler, on ...field.Expr) IOAuth2StateDo
Group(cols ...field.Expr) IOAuth2StateDo
Having(conds ...gen.Condition) IOAuth2StateDo
Limit(limit int) IOAuth2StateDo
Offset(offset int) IOAuth2StateDo
Count() (count int64, err error)
Scopes(funcs ...func(gen.Dao) gen.Dao) IOAuth2StateDo
Unscoped() IOAuth2StateDo
Create(values ...*models.OAuth2State) error
CreateInBatches(values []*models.OAuth2State, batchSize int) error
Save(values ...*models.OAuth2State) error
First() (*models.OAuth2State, error)
Take() (*models.OAuth2State, error)
Last() (*models.OAuth2State, error)
Find() ([]*models.OAuth2State, error)
FindInBatch(batchSize int, fc func(tx gen.Dao, batch int) error) (results []*models.OAuth2State, err error)
FindInBatches(result *[]*models.OAuth2State, batchSize int, fc func(tx gen.Dao, batch int) error) error
Pluck(column field.Expr, dest interface{}) error
Delete(...*models.OAuth2State) (info gen.ResultInfo, err error)
Update(column field.Expr, value interface{}) (info gen.ResultInfo, err error)
UpdateSimple(columns ...field.AssignExpr) (info gen.ResultInfo, err error)
Updates(value interface{}) (info gen.ResultInfo, err error)
UpdateColumn(column field.Expr, value interface{}) (info gen.ResultInfo, err error)
UpdateColumnSimple(columns ...field.AssignExpr) (info gen.ResultInfo, err error)
UpdateColumns(value interface{}) (info gen.ResultInfo, err error)
UpdateFrom(q gen.SubQuery) gen.Dao
Attrs(attrs ...field.AssignExpr) IOAuth2StateDo
Assign(attrs ...field.AssignExpr) IOAuth2StateDo
Joins(fields ...field.RelationField) IOAuth2StateDo
Preload(fields ...field.RelationField) IOAuth2StateDo
FirstOrInit() (*models.OAuth2State, error)
FirstOrCreate() (*models.OAuth2State, error)
FindByPage(offset int, limit int) (result []*models.OAuth2State, count int64, err error)
ScanByPage(result interface{}, offset int, limit int) (count int64, err error)
Scan(result interface{}) (err error)
Returning(value interface{}, columns ...string) IOAuth2StateDo
UnderlyingDB() *gorm.DB
schema.Tabler
}
func (o oAuth2StateDo) Debug() IOAuth2StateDo {
return o.withDO(o.DO.Debug())
}
func (o oAuth2StateDo) WithContext(ctx context.Context) IOAuth2StateDo {
return o.withDO(o.DO.WithContext(ctx))
}
func (o oAuth2StateDo) ReadDB() IOAuth2StateDo {
return o.Clauses(dbresolver.Read)
}
func (o oAuth2StateDo) WriteDB() IOAuth2StateDo {
return o.Clauses(dbresolver.Write)
}
func (o oAuth2StateDo) Session(config *gorm.Session) IOAuth2StateDo {
return o.withDO(o.DO.Session(config))
}
func (o oAuth2StateDo) Clauses(conds ...clause.Expression) IOAuth2StateDo {
return o.withDO(o.DO.Clauses(conds...))
}
func (o oAuth2StateDo) Returning(value interface{}, columns ...string) IOAuth2StateDo {
return o.withDO(o.DO.Returning(value, columns...))
}
func (o oAuth2StateDo) Not(conds ...gen.Condition) IOAuth2StateDo {
return o.withDO(o.DO.Not(conds...))
}
func (o oAuth2StateDo) Or(conds ...gen.Condition) IOAuth2StateDo {
return o.withDO(o.DO.Or(conds...))
}
func (o oAuth2StateDo) Select(conds ...field.Expr) IOAuth2StateDo {
return o.withDO(o.DO.Select(conds...))
}
func (o oAuth2StateDo) Where(conds ...gen.Condition) IOAuth2StateDo {
return o.withDO(o.DO.Where(conds...))
}
func (o oAuth2StateDo) Order(conds ...field.Expr) IOAuth2StateDo {
return o.withDO(o.DO.Order(conds...))
}
func (o oAuth2StateDo) Distinct(cols ...field.Expr) IOAuth2StateDo {
return o.withDO(o.DO.Distinct(cols...))
}
func (o oAuth2StateDo) Omit(cols ...field.Expr) IOAuth2StateDo {
return o.withDO(o.DO.Omit(cols...))
}
func (o oAuth2StateDo) Join(table schema.Tabler, on ...field.Expr) IOAuth2StateDo {
return o.withDO(o.DO.Join(table, on...))
}
func (o oAuth2StateDo) LeftJoin(table schema.Tabler, on ...field.Expr) IOAuth2StateDo {
return o.withDO(o.DO.LeftJoin(table, on...))
}
func (o oAuth2StateDo) RightJoin(table schema.Tabler, on ...field.Expr) IOAuth2StateDo {
return o.withDO(o.DO.RightJoin(table, on...))
}
func (o oAuth2StateDo) Group(cols ...field.Expr) IOAuth2StateDo {
return o.withDO(o.DO.Group(cols...))
}
func (o oAuth2StateDo) Having(conds ...gen.Condition) IOAuth2StateDo {
return o.withDO(o.DO.Having(conds...))
}
func (o oAuth2StateDo) Limit(limit int) IOAuth2StateDo {
return o.withDO(o.DO.Limit(limit))
}
func (o oAuth2StateDo) Offset(offset int) IOAuth2StateDo {
return o.withDO(o.DO.Offset(offset))
}
func (o oAuth2StateDo) Scopes(funcs ...func(gen.Dao) gen.Dao) IOAuth2StateDo {
return o.withDO(o.DO.Scopes(funcs...))
}
func (o oAuth2StateDo) Unscoped() IOAuth2StateDo {
return o.withDO(o.DO.Unscoped())
}
func (o oAuth2StateDo) Create(values ...*models.OAuth2State) error {
if len(values) == 0 {
return nil
}
return o.DO.Create(values)
}
func (o oAuth2StateDo) CreateInBatches(values []*models.OAuth2State, batchSize int) error {
return o.DO.CreateInBatches(values, batchSize)
}
// Save : !!! underlying implementation is different with GORM
// The method is equivalent to executing the statement: db.Clauses(clause.OnConflict{UpdateAll: true}).Create(values)
func (o oAuth2StateDo) Save(values ...*models.OAuth2State) error {
if len(values) == 0 {
return nil
}
return o.DO.Save(values)
}
func (o oAuth2StateDo) First() (*models.OAuth2State, error) {
if result, err := o.DO.First(); err != nil {
return nil, err
} else {
return result.(*models.OAuth2State), nil
}
}
func (o oAuth2StateDo) Take() (*models.OAuth2State, error) {
if result, err := o.DO.Take(); err != nil {
return nil, err
} else {
return result.(*models.OAuth2State), nil
}
}
func (o oAuth2StateDo) Last() (*models.OAuth2State, error) {
if result, err := o.DO.Last(); err != nil {
return nil, err
} else {
return result.(*models.OAuth2State), nil
}
}
func (o oAuth2StateDo) Find() ([]*models.OAuth2State, error) {
result, err := o.DO.Find()
return result.([]*models.OAuth2State), err
}
func (o oAuth2StateDo) FindInBatch(batchSize int, fc func(tx gen.Dao, batch int) error) (results []*models.OAuth2State, err error) {
buf := make([]*models.OAuth2State, 0, batchSize)
err = o.DO.FindInBatches(&buf, batchSize, func(tx gen.Dao, batch int) error {
defer func() { results = append(results, buf...) }()
return fc(tx, batch)
})
return results, err
}
func (o oAuth2StateDo) FindInBatches(result *[]*models.OAuth2State, batchSize int, fc func(tx gen.Dao, batch int) error) error {
return o.DO.FindInBatches(result, batchSize, fc)
}
func (o oAuth2StateDo) Attrs(attrs ...field.AssignExpr) IOAuth2StateDo {
return o.withDO(o.DO.Attrs(attrs...))
}
func (o oAuth2StateDo) Assign(attrs ...field.AssignExpr) IOAuth2StateDo {
return o.withDO(o.DO.Assign(attrs...))
}
func (o oAuth2StateDo) Joins(fields ...field.RelationField) IOAuth2StateDo {
for _, _f := range fields {
o = *o.withDO(o.DO.Joins(_f))
}
return &o
}
func (o oAuth2StateDo) Preload(fields ...field.RelationField) IOAuth2StateDo {
for _, _f := range fields {
o = *o.withDO(o.DO.Preload(_f))
}
return &o
}
func (o oAuth2StateDo) FirstOrInit() (*models.OAuth2State, error) {
if result, err := o.DO.FirstOrInit(); err != nil {
return nil, err
} else {
return result.(*models.OAuth2State), nil
}
}
func (o oAuth2StateDo) FirstOrCreate() (*models.OAuth2State, error) {
if result, err := o.DO.FirstOrCreate(); err != nil {
return nil, err
} else {
return result.(*models.OAuth2State), nil
}
}
func (o oAuth2StateDo) FindByPage(offset int, limit int) (result []*models.OAuth2State, count int64, err error) {
result, err = o.Offset(offset).Limit(limit).Find()
if err != nil {
return
}
if size := len(result); 0 < limit && 0 < size && size < limit {
count = int64(size + offset)
return
}
count, err = o.Offset(-1).Limit(-1).Count()
return
}
func (o oAuth2StateDo) ScanByPage(result interface{}, offset int, limit int) (count int64, err error) {
count, err = o.Count()
if err != nil {
return
}
err = o.Offset(offset).Limit(limit).Scan(result)
return
}
func (o oAuth2StateDo) Scan(result interface{}) (err error) {
return o.DO.Scan(result)
}
func (o oAuth2StateDo) Delete(models ...*models.OAuth2State) (result gen.ResultInfo, err error) {
return o.DO.Delete(models)
}
func (o *oAuth2StateDo) withDO(do gen.Dao) *oAuth2StateDo {
o.DO = *do.(*gen.DO)
return o
}

View file

@ -3,7 +3,7 @@ set shell := ["devbox", "run"]
# Load dotenv
set dotenv-load
STATIC_DIR := "./internal/static"
STATIC_DIR := "./web/static"
# Run full development environment
run:

View file

@ -1,5 +1,5 @@
module.exports = {
content: ["./internal/ui/**/*.{tmpl,go}"],
content: ["./web/templates/**/*.{tmpl,go}"],
theme: {
container: {
center: true,

View file

@ -7,6 +7,8 @@ import (
)
func main() {
config := internal.NewConfig()
// Initialize the generator with configuration
g := gen.NewGenerator(gen.Config{
OutPath: "internal/models/query",
@ -14,14 +16,14 @@ func main() {
FieldNullable: true,
})
db, _ := internal.ConnectToDatabase()
db, _, _ := internal.ConnectToDatabase(config.SQLITE_DB_PATH)
// Use the above `*gorm.DB` instance to initialize the generator,
// which is required to generate structs from db when using `GenerateModel/GenerateModelAs`
g.UseDB(db)
// Generate default DAO interface for those specified structs
g.ApplyBasic(models.Healthcheck{})
g.ApplyBasic(models.Healthcheck{}, models.OAuth2State{})
// Execute the generator
g.Execute()