diff --git a/build/Dockerfile b/build/Dockerfile index 4cd456e..d18cfa0 100644 --- a/build/Dockerfile +++ b/build/Dockerfile @@ -36,3 +36,4 @@ ENV DATABASE_PATH=/data/zdravko.db ENV TEMPORAL_DATABASE_PATH=/data/temporal.db ENTRYPOINT ["/bin/zdravko"] +CMD ["--server", "--temporal", "--worker"] diff --git a/cmd/zdravko/main.go b/cmd/zdravko/main.go index 8f257e8..fa920b1 100644 --- a/cmd/zdravko/main.go +++ b/cmd/zdravko/main.go @@ -40,13 +40,12 @@ func main() { log.Fatal("At least one of the following must be set: --server, --worker, --temporal") } - cfg := config.NewConfig() - var servers [3]StartableAndStoppable var wg sync.WaitGroup if startTemporal { log.Println("Setting up Temporal") + cfg := config.NewTemporalConfig() temporal, err := temporal.NewTemporal(cfg) if err != nil { log.Fatalf("Unable to create temporal: %v", err) @@ -56,6 +55,7 @@ func main() { if startServer { log.Println("Setting up Server") + cfg := config.NewServerConfig() server, err := server.NewServer(cfg) if err != nil { log.Fatalf("Unable to create server: %v", err) @@ -65,6 +65,7 @@ func main() { if startWorker { log.Println("Setting up Worker") + cfg := config.NewWorkerConfig() worker, err := worker.NewWorker(cfg) if err != nil { log.Fatalf("Unable to create worker: %v", err) diff --git a/deploy/docker-compose.yaml b/deploy/docker-compose.yaml new file mode 100644 index 0000000..da8cce6 --- /dev/null +++ b/deploy/docker-compose.yaml @@ -0,0 +1,41 @@ +version: '3.8' + +volumes: + server_data: + temporal_data: + +services: + server: + image: ghcr.io/mentos1386/zdravko:main + command: ["--server"] + volumes: + - server_data:/data + ports: + - 8000:8000 + environment: + - ROOT_URL=http://localhost:8000 + - SESSION_SECRET=change-me + - JWT_PUBLIC_KEY=change-me + - JWT_PRIVATE_KEY=change-me + - OAUTH2_CLIENT_ID=change-me + - OAUTH2_CLIENT_SECRET=change-me + - OAUTH2_ENDPOINT_TOKEN_URL=change-me + - OAUTH2_ENDPOINT_AUTH_URL=change-me + - OAUTH2_ENDPOINT_USER_INFO_URL=change-me + - TEMPORAL_UI_HOST=temporal:8223 + - TEMPORAL_SERVER_HOST=temporal:7233 + + temporal: + image: ghcr.io/mentos1386/zdravko:main + command: ["--temporal"] + volumes: + - temporal_data:/data + environment: + - JWT_PUBLIC_KEY=change-me + + worker: + image: ghcr.io/mentos1386/zdravko:main + command: ["--worker"] + environment: + - WORKER_TOKEN=change-me + - WORKER_API_URL=http://server:8000 diff --git a/deploy/fly.toml b/deploy/fly.toml index 834d0d3..294b39d 100644 --- a/deploy/fly.toml +++ b/deploy/fly.toml @@ -30,7 +30,7 @@ primary_region = 'waw' force_https = true auto_stop_machines = true auto_start_machines = true - min_machines_running = 0 + min_machines_running = 1 processes = ['server'] [[services]] diff --git a/internal/config/config.go b/internal/config/config.go index ae84f34..0fb24db 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -1,54 +1,11 @@ package config import ( - "log" "os" - "strings" - "github.com/go-playground/validator/v10" "github.com/spf13/viper" ) -type Config struct { - Port string `validate:"required"` - RootUrl string `validate:"required,url"` - DatabasePath string `validate:"required"` - SessionSecret string `validate:"required"` - - Jwt Jwt `validate:"required"` - OAuth2 OAuth2 `validate:"required"` - - Temporal Temporal `validate:"required"` - - Worker Worker `validate:"required"` -} - -type Jwt struct { - PrivateKey string `validate:"required"` - PublicKey string `validate:"required"` -} - -type OAuth2 struct { - ClientID string `validate:"required"` - ClientSecret string `validate:"required"` - Scopes []string `validate:"required"` - EndpointTokenURL string `validate:"required"` - EndpointAuthURL string `validate:"required"` - EndpointUserInfoURL string `validate:"required"` - EndpointLogoutURL string // Optional as not all SSO support this. -} - -type Temporal struct { - DatabasePath string `validate:"required"` - ListenAddress string `validate:"required"` - UIHost string `validate:"required"` - ServerHost string `validate:"required"` -} - -type Worker struct { - Token string `validate:"required"` -} - func GetEnvOrDefault(key, def string) string { value := os.Getenv(key) if value == "" { @@ -57,60 +14,14 @@ func GetEnvOrDefault(key, def string) string { return value } -func NewConfig() *Config { - viper.SetConfigName("zdravko") - viper.SetConfigType("yaml") - viper.AddConfigPath("/etc/zdravko/") - viper.AddConfigPath("$HOME/.zdravko") - viper.AddConfigPath("$HOME/.config/zdravko") - viper.AddConfigPath("$XDG_CONFIG_HOME/zdravko") - viper.AddConfigPath(".") - - // Set defaults - viper.SetDefault("port", GetEnvOrDefault("PORT", "8000")) - viper.SetDefault("rooturl", GetEnvOrDefault("ROOT_URL", "http://localhost:8000")) - viper.SetDefault("databasepath", GetEnvOrDefault("DATABASE_PATH", "zdravko.db")) - viper.SetDefault("sessionsecret", os.Getenv("SESSION_SECRET")) - viper.SetDefault("temporal.databasepath", GetEnvOrDefault("TEMPORAL_DATABASE_PATH", "temporal.db")) - viper.SetDefault("temporal.listenaddress", GetEnvOrDefault("TEMPORAL_LISTEN_ADDRESS", "0.0.0.0")) - viper.SetDefault("temporal.uihost", GetEnvOrDefault("TEMPORAL_UI_HOST", "127.0.0.1:8223")) - viper.SetDefault("temporal.serverhost", GetEnvOrDefault("TEMPORAL_SERVER_HOST", "127.0.0.1:7233")) - viper.SetDefault("jwt.privatekey", os.Getenv("JWT_PRIVATE_KEY")) - viper.SetDefault("jwt.publickey", os.Getenv("JWT_PUBLIC_KEY")) - viper.SetDefault("oauth2.clientid", os.Getenv("OAUTH2_CLIENT_ID")) - viper.SetDefault("oauth2.clientsecret", os.Getenv("OAUTH2_CLIENT_SECRET")) - viper.SetDefault("oauth2.scopes", GetEnvOrDefault("OAUTH2_ENDPOINT_SCOPES", "openid profile email")) - viper.SetDefault("oauth2.endpointtokenurl", os.Getenv("OAUTH2_ENDPOINT_TOKEN_URL")) - viper.SetDefault("oauth2.endpointauthurl", os.Getenv("OAUTH2_ENDPOINT_AUTH_URL")) - viper.SetDefault("oauth2.endpointuserinfourl", os.Getenv("OAUTH2_ENDPOINT_USER_INFO_URL")) - viper.SetDefault("oauth2.endpointlogouturl", GetEnvOrDefault("OAUTH2_ENDPOINT_LOGOUT_URL", "")) - viper.SetDefault("worker.token", os.Getenv("WORKER_TOKEN")) - - err := viper.ReadInConfig() - if err != nil { - if _, ok := err.(viper.ConfigFileNotFoundError); ok { - // ignore - } else { - log.Fatalf("Error reading config file, %s", err) - } - } - log.Println("Config file used: ", viper.ConfigFileUsed()) - - config := &Config{} - err = viper.Unmarshal(config) - if err != nil { - log.Fatalf("Error unmarshalling config, %s", err) - } - - // OAuth2 scopes are space separated - config.OAuth2.Scopes = strings.Split(viper.GetString("oauth2.scopes"), " ") - - // Validate config - validate := validator.New(validator.WithRequiredStructEnabled()) - err = validate.Struct(config) - if err != nil { - log.Fatalf("Error validating config, %s", err) - } - - return config +func newViper() *viper.Viper { + v := viper.New() + v.SetConfigName("zdravko") + v.SetConfigType("yaml") + v.AddConfigPath("/etc/zdravko/") + v.AddConfigPath("$HOME/.zdravko") + v.AddConfigPath("$HOME/.config/zdravko") + v.AddConfigPath("$XDG_CONFIG_HOME/zdravko") + v.AddConfigPath(".") + return v } diff --git a/internal/config/server.go b/internal/config/server.go new file mode 100644 index 0000000..e37ad0d --- /dev/null +++ b/internal/config/server.go @@ -0,0 +1,91 @@ +package config + +import ( + "log" + "os" + "strings" + + "github.com/go-playground/validator/v10" + "github.com/spf13/viper" +) + +type ServerConfig struct { + Port string `validate:"required"` + RootUrl string `validate:"required,url"` + DatabasePath string `validate:"required"` + SessionSecret string `validate:"required"` + + Jwt ServerJwt `validate:"required"` + OAuth2 ServerOAuth2 `validate:"required"` + + Temporal ServerTemporal `validate:"required"` +} + +type ServerJwt struct { + PrivateKey string `validate:"required"` + PublicKey string `validate:"required"` +} + +type ServerOAuth2 struct { + ClientID string `validate:"required"` + ClientSecret string `validate:"required"` + Scopes []string `validate:"required"` + EndpointTokenURL string `validate:"required"` + EndpointAuthURL string `validate:"required"` + EndpointUserInfoURL string `validate:"required"` + EndpointLogoutURL string // Optional as not all SSO support this. +} + +type ServerTemporal struct { + UIHost string `validate:"required"` + ServerHost string `validate:"required"` +} + +func NewServerConfig() *ServerConfig { + v := newViper() + + // Set defaults + v.SetDefault("port", GetEnvOrDefault("PORT", "8000")) + v.SetDefault("rooturl", GetEnvOrDefault("ROOT_URL", "http://localhost:8000")) + v.SetDefault("databasepath", GetEnvOrDefault("DATABASE_PATH", "zdravko.db")) + v.SetDefault("sessionsecret", os.Getenv("SESSION_SECRET")) + v.SetDefault("temporal.uihost", GetEnvOrDefault("TEMPORAL_UI_HOST", "127.0.0.1:8223")) + v.SetDefault("temporal.serverhost", GetEnvOrDefault("TEMPORAL_SERVER_HOST", "127.0.0.1:7233")) + v.SetDefault("jwt.privatekey", os.Getenv("JWT_PRIVATE_KEY")) + v.SetDefault("jwt.publickey", os.Getenv("JWT_PUBLIC_KEY")) + v.SetDefault("oauth2.clientid", os.Getenv("OAUTH2_CLIENT_ID")) + v.SetDefault("oauth2.clientsecret", os.Getenv("OAUTH2_CLIENT_SECRET")) + v.SetDefault("oauth2.scopes", GetEnvOrDefault("OAUTH2_ENDPOINT_SCOPES", "openid profile email")) + v.SetDefault("oauth2.endpointtokenurl", os.Getenv("OAUTH2_ENDPOINT_TOKEN_URL")) + v.SetDefault("oauth2.endpointauthurl", os.Getenv("OAUTH2_ENDPOINT_AUTH_URL")) + v.SetDefault("oauth2.endpointuserinfourl", os.Getenv("OAUTH2_ENDPOINT_USER_INFO_URL")) + v.SetDefault("oauth2.endpointlogouturl", GetEnvOrDefault("OAUTH2_ENDPOINT_LOGOUT_URL", "")) + + err := v.ReadInConfig() + if err != nil { + if _, ok := err.(viper.ConfigFileNotFoundError); ok { + // ignore + } else { + log.Fatalf("Error reading config file, %s", err) + } + } + log.Println("Config file used: ", v.ConfigFileUsed()) + + config := &ServerConfig{} + err = v.Unmarshal(config) + if err != nil { + log.Fatalf("Error unmarshalling config, %s", err) + } + + // OAuth2 scopes are space separated + config.OAuth2.Scopes = strings.Split(v.GetString("oauth2.scopes"), " ") + + // Validate config + validate := validator.New(validator.WithRequiredStructEnabled()) + err = validate.Struct(config) + if err != nil { + log.Fatalf("Error validating config, %s", err) + } + + return config +} diff --git a/internal/config/temporal.go b/internal/config/temporal.go new file mode 100644 index 0000000..0442d5e --- /dev/null +++ b/internal/config/temporal.go @@ -0,0 +1,54 @@ +package config + +import ( + "log" + "os" + + "github.com/go-playground/validator/v10" + "github.com/spf13/viper" +) + +type TemporalConfig struct { + DatabasePath string `validate:"required"` + ListenAddress string `validate:"required"` + + Jwt TemporalJwt `validate:"required"` +} + +type TemporalJwt struct { + PublicKey string `validate:"required"` +} + +func NewTemporalConfig() *TemporalConfig { + v := newViper() + + // Set defaults + v.SetDefault("databasepath", GetEnvOrDefault("TEMPORAL_DATABASE_PATH", "temporal.db")) + v.SetDefault("listenaddress", GetEnvOrDefault("TEMPORAL_LISTEN_ADDRESS", "0.0.0.0")) + v.SetDefault("jwt.publickey", os.Getenv("JWT_PUBLIC_KEY")) + + err := v.ReadInConfig() + if err != nil { + if _, ok := err.(viper.ConfigFileNotFoundError); ok { + // ignore + } else { + log.Fatalf("Error reading config file, %s", err) + } + } + log.Println("Config file used: ", v.ConfigFileUsed()) + + config := &TemporalConfig{} + err = v.Unmarshal(config) + if err != nil { + log.Fatalf("Error unmarshalling config, %s", err) + } + + // Validate config + validate := validator.New(validator.WithRequiredStructEnabled()) + err = validate.Struct(config) + if err != nil { + log.Fatalf("Error validating config, %s", err) + } + + return config +} diff --git a/internal/config/worker.go b/internal/config/worker.go new file mode 100644 index 0000000..c3f2a80 --- /dev/null +++ b/internal/config/worker.go @@ -0,0 +1,47 @@ +package config + +import ( + "log" + "os" + + "github.com/go-playground/validator/v10" + "github.com/spf13/viper" +) + +type WorkerConfig struct { + Token string `validate:"required"` + ApiUrl string `validate:"required"` +} + +func NewWorkerConfig() *WorkerConfig { + v := newViper() + + // Set defaults + v.SetDefault("token", os.Getenv("WORKER_TOKEN")) + v.SetDefault("apiurl", os.Getenv("WORKER_API_URL")) + + err := v.ReadInConfig() + if err != nil { + if _, ok := err.(viper.ConfigFileNotFoundError); ok { + // ignore + } else { + log.Fatalf("Error reading config file, %s", err) + } + } + log.Println("Config file used: ", v.ConfigFileUsed()) + + config := &WorkerConfig{} + err = v.Unmarshal(config) + if err != nil { + log.Fatalf("Error unmarshalling config, %s", err) + } + + // Validate config + validate := validator.New(validator.WithRequiredStructEnabled()) + err = validate.Struct(config) + if err != nil { + log.Fatalf("Error validating config, %s", err) + } + + return config +} diff --git a/internal/handlers/api.go b/internal/handlers/api.go new file mode 100644 index 0000000..c6cf332 --- /dev/null +++ b/internal/handlers/api.go @@ -0,0 +1,35 @@ +package handlers + +import ( + "encoding/json" + "net/http" +) + +type ApiV1WorkersConnectGETResponse struct { + Endpoint string `json:"endpoint"` + Group string `json:"group"` + Slug string `json:"slug"` +} + +func (h *BaseHandler) ApiV1WorkersConnectGET(w http.ResponseWriter, r *http.Request, principal *AuthenticatedPrincipal) { + // Json response containing temporal endpoint + w.Header().Set("Content-Type", "application/json") + + response := ApiV1WorkersConnectGETResponse{ + Endpoint: h.config.Temporal.ServerHost, + Group: principal.Worker.Group, + Slug: principal.Worker.Slug, + } + + responseJson, err := json.Marshal(response) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + _, err = w.Write(responseJson) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } +} diff --git a/internal/handlers/session.go b/internal/handlers/authentication.go similarity index 56% rename from internal/handlers/session.go rename to internal/handlers/authentication.go index 1010c27..e85544b 100644 --- a/internal/handlers/session.go +++ b/internal/handlers/authentication.go @@ -3,12 +3,21 @@ package handlers import ( "context" "fmt" + "log" "net/http" + "strings" "time" + + jwtInternal "code.tjo.space/mentos1386/zdravko/internal/jwt" ) const sessionName = "zdravko-hey" +type AuthenticatedPrincipal struct { + User *AuthenticatedUser + Worker *AuthenticatedWorker +} + type AuthenticatedUser struct { ID string Email string @@ -18,6 +27,11 @@ type AuthenticatedUser struct { OAuth2Expiry time.Time } +type AuthenticatedWorker struct { + Slug string + Group string +} + type authenticatedUserKeyType string const authenticatedUserKey authenticatedUserKeyType = "authenticatedUser" @@ -34,7 +48,7 @@ func GetUser(ctx context.Context) *AuthenticatedUser { return user } -func (h *BaseHandler) GetAuthenticatedUserForRequest(r *http.Request) (*AuthenticatedUser, error) { +func (h *BaseHandler) AuthenticateRequestWithCookies(r *http.Request) (*AuthenticatedUser, error) { session, err := h.store.Get(r, sessionName) if err != nil { return nil, err @@ -60,6 +74,47 @@ func (h *BaseHandler) GetAuthenticatedUserForRequest(r *http.Request) (*Authenti return user, nil } +func (h *BaseHandler) AuthenticateRequestWithToken(r *http.Request) (*AuthenticatedPrincipal, error) { + authorization := r.Header.Get("Authorization") + + splitAuthorization := strings.Split(authorization, " ") + if len(splitAuthorization) != 2 { + return nil, fmt.Errorf("invalid authorization header") + } + if splitAuthorization[0] != "Bearer" { + return nil, fmt.Errorf("invalid authorization header") + } + + _, claims, err := jwtInternal.ParseToken(splitAuthorization[1], h.config.Jwt.PublicKey) + if err != nil { + return nil, err + } + + splitSubject := strings.Split(claims.Subject, ":") + if len(splitSubject) != 2 { + return nil, fmt.Errorf("invalid subject") + } + + var worker *AuthenticatedWorker + var user *AuthenticatedUser + + if splitSubject[0] == "user" { + user = &AuthenticatedUser{} + } else if splitSubject[0] == "worker" { + worker = &AuthenticatedWorker{ + Slug: splitSubject[1], + Group: claims.WorkerGroup, + } + } + + principal := &AuthenticatedPrincipal{ + User: user, + Worker: worker, + } + + return principal, nil +} + func (h *BaseHandler) SetAuthenticatedUserForRequest(w http.ResponseWriter, r *http.Request, user *AuthenticatedUser) error { session, err := h.store.Get(r, sessionName) if err != nil { @@ -91,22 +146,32 @@ func (h *BaseHandler) ClearAuthenticatedUserForRequest(w http.ResponseWriter, r return nil } -type AuthenticatedHandler func(http.ResponseWriter, *http.Request, *AuthenticatedUser) +type AuthenticatedHandler func(http.ResponseWriter, *http.Request, *AuthenticatedPrincipal) func (h *BaseHandler) Authenticated(next AuthenticatedHandler) func(http.ResponseWriter, *http.Request) { return func(w http.ResponseWriter, r *http.Request) { - user, err := h.GetAuthenticatedUserForRequest(r) - if err != nil { - http.Redirect(w, r, "/oauth2/login", http.StatusTemporaryRedirect) + // First try cookie authentication + user, err := h.AuthenticateRequestWithCookies(r) + if err == nil { + if user.OAuth2Expiry.Before(time.Now()) { + user, err = h.RefreshToken(w, r, user) + if err != nil { + http.Redirect(w, r, "/oauth2/login", http.StatusTemporaryRedirect) + return + } + } + next(w, r, &AuthenticatedPrincipal{user, nil}) return } - if user.OAuth2Expiry.Before(time.Now()) { - user, err = h.RefreshToken(w, r, user) - if err != nil { - http.Redirect(w, r, "/oauth2/login", http.StatusTemporaryRedirect) - return - } + // Then try token based authentication + principal, err := h.AuthenticateRequestWithToken(r) + if err == nil { + next(w, r, principal) + return } - next(w, r, user) + + log.Println("err: ", err) + + http.Redirect(w, r, "/oauth2/login", http.StatusTemporaryRedirect) } } diff --git a/internal/handlers/handlers.go b/internal/handlers/handlers.go index 06a7c0c..14cba9e 100644 --- a/internal/handlers/handlers.go +++ b/internal/handlers/handlers.go @@ -27,14 +27,14 @@ func GetPageByTitle(pages []*components.Page, title string) *components.Page { type BaseHandler struct { db *gorm.DB query *query.Query - config *config.Config + config *config.ServerConfig temporal client.Client store *sessions.CookieStore } -func NewBaseHandler(db *gorm.DB, q *query.Query, temporal client.Client, config *config.Config) *BaseHandler { +func NewBaseHandler(db *gorm.DB, q *query.Query, temporal client.Client, config *config.ServerConfig) *BaseHandler { store := sessions.NewCookieStore([]byte(config.SessionSecret)) return &BaseHandler{db, q, config, temporal, store} diff --git a/internal/handlers/oauth2.go b/internal/handlers/oauth2.go index 06ad738..893f702 100644 --- a/internal/handlers/oauth2.go +++ b/internal/handlers/oauth2.go @@ -31,7 +31,7 @@ func newRandomState() string { return hex.EncodeToString(b) } -func newOAuth2(config *config.Config) *oauth2.Config { +func newOAuth2(config *config.ServerConfig) *oauth2.Config { return &oauth2.Config{ ClientID: config.OAuth2.ClientID, ClientSecret: config.OAuth2.ClientSecret, @@ -160,9 +160,9 @@ func (h *BaseHandler) OAuth2CallbackGET(w http.ResponseWriter, r *http.Request) http.Redirect(w, r, "/settings", http.StatusTemporaryRedirect) } -func (h *BaseHandler) OAuth2LogoutGET(w http.ResponseWriter, r *http.Request, user *AuthenticatedUser) { +func (h *BaseHandler) OAuth2LogoutGET(w http.ResponseWriter, r *http.Request, principal *AuthenticatedPrincipal) { if h.config.OAuth2.EndpointLogoutURL != "" { - tok := h.AuthenticatedUserToOAuth2Token(user) + tok := h.AuthenticatedUserToOAuth2Token(principal.User) client := oauth2.NewClient(context.Background(), oauth2.StaticTokenSource(tok)) _, err := client.Get(h.config.OAuth2.EndpointLogoutURL) if err != nil { diff --git a/internal/handlers/settings.go b/internal/handlers/settings.go index b85ab5e..86a1731 100644 --- a/internal/handlers/settings.go +++ b/internal/handlers/settings.go @@ -49,7 +49,7 @@ var SettingsNavbar = []*components.Page{ GetPageByTitle(SettingsPages, "Logout"), } -func (h *BaseHandler) SettingsOverviewGET(w http.ResponseWriter, r *http.Request, user *AuthenticatedUser) { +func (h *BaseHandler) SettingsOverviewGET(w http.ResponseWriter, r *http.Request, principal *AuthenticatedPrincipal) { ts, err := template.ParseFS(templates.Templates, "components/base.tmpl", "components/settings.tmpl", @@ -61,7 +61,7 @@ func (h *BaseHandler) SettingsOverviewGET(w http.ResponseWriter, r *http.Request } err = ts.ExecuteTemplate(w, "base", NewSettings( - user, + principal.User, GetPageByTitle(SettingsPages, "Overview"), []*components.Page{GetPageByTitle(SettingsPages, "Overview")}, )) diff --git a/internal/handlers/settingshealthchecks.go b/internal/handlers/settingshealthchecks.go index 19aa08f..735d41e 100644 --- a/internal/handlers/settingshealthchecks.go +++ b/internal/handlers/settingshealthchecks.go @@ -26,7 +26,7 @@ type SettingsHealthcheck struct { Healthcheck *models.HealthcheckHttp } -func (h *BaseHandler) SettingsHealthchecksGET(w http.ResponseWriter, r *http.Request, user *AuthenticatedUser) { +func (h *BaseHandler) SettingsHealthchecksGET(w http.ResponseWriter, r *http.Request, principal *AuthenticatedPrincipal) { ts, err := template.ParseFS(templates.Templates, "components/base.tmpl", "components/settings.tmpl", @@ -44,7 +44,7 @@ func (h *BaseHandler) SettingsHealthchecksGET(w http.ResponseWriter, r *http.Req err = ts.ExecuteTemplate(w, "base", &SettingsHealthchecks{ Settings: NewSettings( - user, + principal.User, GetPageByTitle(SettingsPages, "Healthchecks"), []*components.Page{GetPageByTitle(SettingsPages, "Healthchecks")}, ), @@ -56,7 +56,7 @@ func (h *BaseHandler) SettingsHealthchecksGET(w http.ResponseWriter, r *http.Req } } -func (h *BaseHandler) SettingsHealthchecksDescribeGET(w http.ResponseWriter, r *http.Request, user *AuthenticatedUser) { +func (h *BaseHandler) SettingsHealthchecksDescribeGET(w http.ResponseWriter, r *http.Request, principal *AuthenticatedPrincipal) { vars := mux.Vars(r) slug := vars["slug"] @@ -77,7 +77,7 @@ func (h *BaseHandler) SettingsHealthchecksDescribeGET(w http.ResponseWriter, r * err = ts.ExecuteTemplate(w, "base", &SettingsHealthcheck{ Settings: NewSettings( - user, + principal.User, GetPageByTitle(SettingsPages, "Healthchecks"), []*components.Page{ GetPageByTitle(SettingsPages, "Healthchecks"), @@ -94,7 +94,7 @@ func (h *BaseHandler) SettingsHealthchecksDescribeGET(w http.ResponseWriter, r * } } -func (h *BaseHandler) SettingsHealthchecksCreateGET(w http.ResponseWriter, r *http.Request, user *AuthenticatedUser) { +func (h *BaseHandler) SettingsHealthchecksCreateGET(w http.ResponseWriter, r *http.Request, principal *AuthenticatedPrincipal) { ts, err := template.ParseFS(templates.Templates, "components/base.tmpl", "components/settings.tmpl", @@ -106,7 +106,7 @@ func (h *BaseHandler) SettingsHealthchecksCreateGET(w http.ResponseWriter, r *ht } err = ts.ExecuteTemplate(w, "base", NewSettings( - user, + principal.User, GetPageByTitle(SettingsPages, "Healthchecks"), []*components.Page{ GetPageByTitle(SettingsPages, "Healthchecks"), @@ -118,7 +118,7 @@ func (h *BaseHandler) SettingsHealthchecksCreateGET(w http.ResponseWriter, r *ht } } -func (h *BaseHandler) SettingsHealthchecksCreatePOST(w http.ResponseWriter, r *http.Request, user *AuthenticatedUser) { +func (h *BaseHandler) SettingsHealthchecksCreatePOST(w http.ResponseWriter, r *http.Request, principal *AuthenticatedPrincipal) { ctx := context.Background() healthcheckHttp := &models.HealthcheckHttp{ diff --git a/internal/handlers/settingsworkers.go b/internal/handlers/settingsworkers.go index ee986c6..26a07a6 100644 --- a/internal/handlers/settingsworkers.go +++ b/internal/handlers/settingsworkers.go @@ -27,7 +27,7 @@ type SettingsWorker struct { Worker *models.Worker } -func (h *BaseHandler) SettingsWorkersGET(w http.ResponseWriter, r *http.Request, user *AuthenticatedUser) { +func (h *BaseHandler) SettingsWorkersGET(w http.ResponseWriter, r *http.Request, principal *AuthenticatedPrincipal) { ts, err := template.ParseFS(templates.Templates, "components/base.tmpl", "components/settings.tmpl", @@ -45,7 +45,7 @@ func (h *BaseHandler) SettingsWorkersGET(w http.ResponseWriter, r *http.Request, err = ts.ExecuteTemplate(w, "base", &SettingsWorkers{ Settings: NewSettings( - user, + principal.User, GetPageByTitle(SettingsPages, "Workers"), []*components.Page{GetPageByTitle(SettingsPages, "Workers")}, ), @@ -57,7 +57,7 @@ func (h *BaseHandler) SettingsWorkersGET(w http.ResponseWriter, r *http.Request, } } -func (h *BaseHandler) SettingsWorkersDescribeGET(w http.ResponseWriter, r *http.Request, user *AuthenticatedUser) { +func (h *BaseHandler) SettingsWorkersDescribeGET(w http.ResponseWriter, r *http.Request, principal *AuthenticatedPrincipal) { vars := mux.Vars(r) slug := vars["slug"] @@ -78,7 +78,7 @@ func (h *BaseHandler) SettingsWorkersDescribeGET(w http.ResponseWriter, r *http. err = ts.ExecuteTemplate(w, "base", &SettingsWorker{ Settings: NewSettings( - user, + principal.User, GetPageByTitle(SettingsPages, "Workers"), []*components.Page{ GetPageByTitle(SettingsPages, "Workers"), @@ -95,7 +95,7 @@ func (h *BaseHandler) SettingsWorkersDescribeGET(w http.ResponseWriter, r *http. } } -func (h *BaseHandler) SettingsWorkersCreateGET(w http.ResponseWriter, r *http.Request, user *AuthenticatedUser) { +func (h *BaseHandler) SettingsWorkersCreateGET(w http.ResponseWriter, r *http.Request, principal *AuthenticatedPrincipal) { ts, err := template.ParseFS(templates.Templates, "components/base.tmpl", "components/settings.tmpl", @@ -107,7 +107,7 @@ func (h *BaseHandler) SettingsWorkersCreateGET(w http.ResponseWriter, r *http.Re } err = ts.ExecuteTemplate(w, "base", NewSettings( - user, + principal.User, GetPageByTitle(SettingsPages, "Workers"), []*components.Page{ GetPageByTitle(SettingsPages, "Workers"), @@ -119,12 +119,13 @@ func (h *BaseHandler) SettingsWorkersCreateGET(w http.ResponseWriter, r *http.Re } } -func (h *BaseHandler) SettingsWorkersCreatePOST(w http.ResponseWriter, r *http.Request, user *AuthenticatedUser) { +func (h *BaseHandler) SettingsWorkersCreatePOST(w http.ResponseWriter, r *http.Request, principal *AuthenticatedPrincipal) { ctx := context.Background() worker := &models.Worker{ - Name: r.FormValue("name"), - Slug: slug.Make(r.FormValue("name")), + Name: r.FormValue("name"), + Slug: slug.Make(r.FormValue("name")), + Group: r.FormValue("group"), } err := validator.New(validator.WithRequiredStructEnabled()).Struct(worker) @@ -144,7 +145,7 @@ func (h *BaseHandler) SettingsWorkersCreatePOST(w http.ResponseWriter, r *http.R http.Redirect(w, r, "/settings/workers", http.StatusSeeOther) } -func (h *BaseHandler) SettingsWorkersTokenGET(w http.ResponseWriter, r *http.Request, user *AuthenticatedUser) { +func (h *BaseHandler) SettingsWorkersTokenGET(w http.ResponseWriter, r *http.Request, principal *AuthenticatedPrincipal) { vars := mux.Vars(r) slug := vars["slug"] @@ -154,7 +155,7 @@ func (h *BaseHandler) SettingsWorkersTokenGET(w http.ResponseWriter, r *http.Req } // Allow write access to default namespace - token, err := jwt.NewToken(h.config, []string{"default:write"}, worker.Slug) + token, err := jwt.NewTokenForWorker(h.config.Jwt.PrivateKey, h.config.Jwt.PublicKey, worker) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) } diff --git a/internal/handlers/temporal.go b/internal/handlers/temporal.go index 4f8dc7e..440ab1d 100644 --- a/internal/handlers/temporal.go +++ b/internal/handlers/temporal.go @@ -4,22 +4,28 @@ import ( "net/http" "net/http/httputil" "net/url" + + "code.tjo.space/mentos1386/zdravko/internal/jwt" ) -func (h *BaseHandler) Temporal(w http.ResponseWriter, r *http.Request, user *AuthenticatedUser) { +func (h *BaseHandler) Temporal(w http.ResponseWriter, r *http.Request, principal *AuthenticatedPrincipal) { proxy := httputil.NewSingleHostReverseProxy(&url.URL{ Host: h.config.Temporal.UIHost, Scheme: "http", }) - // TODO: Maybe add a "navbar" in html to go back to Zdravko? - proxy.ModifyResponse = func(response *http.Response) error { - // Read and update the response here + originalDirector := proxy.Director - // The response here is response from server (proxy B if this is at proxy A) - // It is a pointer, so can be modified to update in place - // It will not be called if Proxy B is unreachable - return nil + proxy.Director = func(r *http.Request) { + originalDirector(r) + // Add authentication token to be able to access temporal. + // FIXME: Maybe cache it somehow so we don't generate it on every request? + token, _ := jwt.NewTokenForUser( + h.config.Jwt.PrivateKey, + h.config.Jwt.PublicKey, + principal.User.Email, + ) + r.Header.Add("Authorization", "Bearer "+token) } proxy.ServeHTTP(w, r) diff --git a/internal/jwt/jwt.go b/internal/jwt/jwt.go index 8ac86e1..9a05065 100644 --- a/internal/jwt/jwt.go +++ b/internal/jwt/jwt.go @@ -6,7 +6,7 @@ import ( "encoding/hex" "time" - "code.tjo.space/mentos1386/zdravko/internal/config" + "code.tjo.space/mentos1386/zdravko/internal/models" "github.com/golang-jwt/jwt/v5" "github.com/pkg/errors" ) @@ -16,58 +16,115 @@ func JwtPublicKeyID(key *rsa.PublicKey) string { return hex.EncodeToString(hash[:]) } -func JwtPrivateKey(c *config.Config) (*rsa.PrivateKey, error) { - key, err := jwt.ParseRSAPrivateKeyFromPEM([]byte(c.Jwt.PrivateKey)) +func JwtPrivateKey(privateKey string) (*rsa.PrivateKey, error) { + key, err := jwt.ParseRSAPrivateKeyFromPEM([]byte(privateKey)) if err != nil { return nil, errors.Wrap(err, "failed to parse private key") } return key, nil } -func JwtPublicKey(c *config.Config) (*rsa.PublicKey, error) { - key, err := jwt.ParseRSAPublicKeyFromPEM([]byte(c.Jwt.PublicKey)) +func JwtPublicKey(publicKey string) (*rsa.PublicKey, error) { + key, err := jwt.ParseRSAPublicKeyFromPEM([]byte(publicKey)) if err != nil { return nil, errors.Wrap(err, "failed to parse public key") } return key, nil } -// Ref: https://docs.temporal.io/self-hosted-guide/security#authorization -func NewToken(config *config.Config, permissions []string, subject string) (string, error) { - privateKey, err := JwtPrivateKey(config) - if err != nil { - return "", err - } - - publicKey, err := JwtPublicKey(config) - if err != nil { - return "", err - } - - type WorkerClaims struct { - jwt.RegisteredClaims - Permissions []string `json:"permissions"` - } +type Claims struct { + jwt.RegisteredClaims + Permissions []string `json:"permissions"` + WorkerGroup string `json:"group"` +} +func NewTokenForUser(privateKey string, publicKey string, email string) (string, error) { // Create claims with multiple fields populated - claims := WorkerClaims{ + claims := Claims{ jwt.RegisteredClaims{ ExpiresAt: jwt.NewNumericDate(time.Now().Add(12 * 30 * 24 * time.Hour)), IssuedAt: jwt.NewNumericDate(time.Now()), NotBefore: jwt.NewNumericDate(time.Now()), Issuer: "zdravko", - Subject: subject, + Subject: "user:" + email, }, - permissions, + // Ref: https://docs.temporal.io/self-hosted-guide/security#authorization + []string{"temporal-system:admin", "default:admin"}, + "", + } + + return NewToken(privateKey, publicKey, claims) +} + +func NewTokenForServer(privateKey string, publicKey string) (string, error) { + // Create claims with multiple fields populated + claims := Claims{ + jwt.RegisteredClaims{ + ExpiresAt: jwt.NewNumericDate(time.Now().Add(12 * 30 * 24 * time.Hour)), + IssuedAt: jwt.NewNumericDate(time.Now()), + NotBefore: jwt.NewNumericDate(time.Now()), + Issuer: "zdravko", + Subject: "server", + }, + // Ref: https://docs.temporal.io/self-hosted-guide/security#authorization + []string{"temporal-system:admin", "default:admin"}, + "", + } + + return NewToken(privateKey, publicKey, claims) +} + +func NewTokenForWorker(privateKey string, publicKey string, worker *models.Worker) (string, error) { + // Create claims with multiple fields populated + claims := Claims{ + jwt.RegisteredClaims{ + ExpiresAt: jwt.NewNumericDate(time.Now().Add(12 * 30 * 24 * time.Hour)), + IssuedAt: jwt.NewNumericDate(time.Now()), + NotBefore: jwt.NewNumericDate(time.Now()), + Issuer: "zdravko", + Subject: "worker:" + worker.Slug, + }, + // Ref: https://docs.temporal.io/self-hosted-guide/security#authorization + []string{"default:read", "default:write", "default:worker"}, + worker.Group, + } + + return NewToken(privateKey, publicKey, claims) +} + +func NewToken(privateKey string, publicKey string, claims Claims) (string, error) { + privKey, err := JwtPrivateKey(privateKey) + if err != nil { + return "", err + } + + pubKey, err := JwtPublicKey(publicKey) + if err != nil { + return "", err } token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) - token.Header["kid"] = JwtPublicKeyID(publicKey) + token.Header["kid"] = JwtPublicKeyID(pubKey) - signedToken, err := token.SignedString(privateKey) + signedToken, err := token.SignedString(privKey) if err != nil { return "", err } return signedToken, nil } + +func ParseToken(tokenString string, publicKey string) (*jwt.Token, *Claims, error) { + claims := &Claims{} + token, err := jwt.ParseWithClaims(tokenString, claims, func(token *jwt.Token) (interface{}, error) { + if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok { + return nil, errors.New("unexpected signing method") + } + return JwtPublicKey(publicKey) + }) + if err != nil { + return nil, nil, err + } + + return token, claims, nil +} diff --git a/internal/models/models.go b/internal/models/models.go index af13212..e68b967 100644 --- a/internal/models/models.go +++ b/internal/models/models.go @@ -15,6 +15,7 @@ type Worker struct { gorm.Model Name string `gorm:"unique" validate:"required"` Slug string `gorm:"unique"` + Group string `validate:"required"` Status string } diff --git a/internal/models/query/workers.gen.go b/internal/models/query/workers.gen.go index ecdb88f..046d81a 100644 --- a/internal/models/query/workers.gen.go +++ b/internal/models/query/workers.gen.go @@ -33,6 +33,7 @@ func newWorker(db *gorm.DB, opts ...gen.DOOption) worker { _worker.DeletedAt = field.NewField(tableName, "deleted_at") _worker.Name = field.NewString(tableName, "name") _worker.Slug = field.NewString(tableName, "slug") + _worker.Group = field.NewString(tableName, "group") _worker.Status = field.NewString(tableName, "status") _worker.fillFieldMap() @@ -50,6 +51,7 @@ type worker struct { DeletedAt field.Field Name field.String Slug field.String + Group field.String Status field.String fieldMap map[string]field.Expr @@ -73,6 +75,7 @@ func (w *worker) updateTableName(table string) *worker { w.DeletedAt = field.NewField(table, "deleted_at") w.Name = field.NewString(table, "name") w.Slug = field.NewString(table, "slug") + w.Group = field.NewString(table, "group") w.Status = field.NewString(table, "status") w.fillFieldMap() @@ -98,13 +101,14 @@ func (w *worker) GetFieldByName(fieldName string) (field.OrderExpr, bool) { } func (w *worker) fillFieldMap() { - w.fieldMap = make(map[string]field.Expr, 7) + w.fieldMap = make(map[string]field.Expr, 8) w.fieldMap["id"] = w.ID w.fieldMap["created_at"] = w.CreatedAt w.fieldMap["updated_at"] = w.UpdatedAt w.fieldMap["deleted_at"] = w.DeletedAt w.fieldMap["name"] = w.Name w.fieldMap["slug"] = w.Slug + w.fieldMap["group"] = w.Group w.fieldMap["status"] = w.Status } diff --git a/internal/temporal/temporal.go b/internal/temporal/temporal.go index e134423..4cf8365 100644 --- a/internal/temporal/temporal.go +++ b/internal/temporal/temporal.go @@ -7,6 +7,7 @@ import ( "code.tjo.space/mentos1386/zdravko/internal/config" "code.tjo.space/mentos1386/zdravko/internal/jwt" "code.tjo.space/mentos1386/zdravko/pkg/retry" + "github.com/pkg/errors" "go.temporal.io/sdk/client" ) @@ -20,9 +21,9 @@ func (p *AuthHeadersProvider) GetHeaders(ctx context.Context) (map[string]string }, nil } -func ConnectServerToTemporal(cfg *config.Config) (client.Client, error) { +func ConnectServerToTemporal(cfg *config.ServerConfig) (client.Client, error) { // For server we generate new token with admin permissions - token, err := jwt.NewToken(cfg, []string{"temporal-system:admin", "default:admin"}, "server") + token, err := jwt.NewTokenForServer(cfg.Jwt.PrivateKey, cfg.Jwt.PublicKey) if err != nil { return nil, err } @@ -38,15 +39,20 @@ func ConnectServerToTemporal(cfg *config.Config) (client.Client, error) { }) } -func ConnectWorkerToTemporal(cfg *config.Config) (client.Client, error) { - provider := &AuthHeadersProvider{cfg.Worker.Token} +func ConnectWorkerToTemporal(token string, temporalHost string, identity string) (client.Client, error) { + provider := &AuthHeadersProvider{token} // Try to connect to the Temporal Server return retry.Retry(5, 6*time.Second, func() (client.Client, error) { - return client.Dial(client.Options{ - HostPort: cfg.Temporal.ServerHost, + client, err := client.Dial(client.Options{ + HostPort: temporalHost, HeadersProvider: provider, Namespace: "default", + Identity: identity, }) + if err != nil { + return nil, errors.Wrap(err, "failed to connect to Temporal Server: "+temporalHost) + } + return client, nil }) } diff --git a/justfile b/justfile index 7ffcce0..7b46b7e 100644 --- a/justfile +++ b/justfile @@ -11,32 +11,20 @@ GIT_SHA := `git rev-parse --short HEAD` DOCKER_IMAGE := "ghcr.io/mentos1386/zdravko:sha-"+GIT_SHA STATIC_DIR := "./web/static" -# Build the application -build: - docker build -f build/Dockerfile -t {{DOCKER_IMAGE}} . - -# Run Docker application. -run-docker: - docker run -p 8080:8080 \ - -e SESSION_SECRET \ - -e OAUTH2_CLIENT_ID \ - -e OAUTH2_CLIENT_SECRET \ - -e OAUTH2_ENDPOINT_TOKEN_URL \ - -e OAUTH2_ENDPOINT_AUTH_URL \ - -e OAUTH2_ENDPOINT_USER_INFO_URL \ - -e OAUTH2_ENDPOINT_LOGOUT_URL \ - {{DOCKER_IMAGE}} +_default: + @just --list # Run full development environment run: devbox services up +# Start worker run-worker: go build -o dist/zdravko cmd/zdravko/main.go ./dist/zdravko --worker -# Start zdravko -run-zdravko: +# Start server +run-server: go build -o dist/zdravko cmd/zdravko/main.go ./dist/zdravko --server --temporal @@ -49,11 +37,34 @@ generate-jwt-key: deploy: fly deploy --ha=false -c deploy/fly.toml -i {{DOCKER_IMAGE}} - +# Read local jwt key and set it as fly secret deploy-set-jwt-key-secrets: - @fly secrets set -c deploy/fly.toml \ - "JWT_PRIVATE_KEY={{JWT_PRIVATE_KEY}}" \ - "JWT_PUBLIC_KEY={{JWT_PUBLIC_KEY}}" + #!/bin/bash + # https://github.com/superfly/flyctl/issues/589 + cat < Name + + Group + Status @@ -53,6 +56,9 @@ {{.Name}} + + {{.Group}} + OK diff --git a/web/templates/pages/settings_workers_create.tmpl b/web/templates/pages/settings_workers_create.tmpl index b652a1d..4250deb 100644 --- a/web/templates/pages/settings_workers_create.tmpl +++ b/web/templates/pages/settings_workers_create.tmpl @@ -6,7 +6,11 @@
- + +
+
+ +