From b59351294669dcfa936cd93b77fdcee38292486f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bj=C3=B6rn=20Fj=C3=A4llstr=C3=B6m?= Date: Wed, 2 Oct 2024 09:33:28 +0200 Subject: [PATCH] Add SSOauth middleware --- handler/session.go | 84 ++++++++++++++++++++++++++++++++++++++++++++++ main.go | 2 +- 2 files changed, 85 insertions(+), 1 deletion(-) diff --git a/handler/session.go b/handler/session.go index b660d9c..86be792 100644 --- a/handler/session.go +++ b/handler/session.go @@ -8,7 +8,11 @@ import ( "github.com/gorilla/sessions" "github.com/labstack/echo-contrib/session" "github.com/labstack/echo/v4" + "github.com/labstack/gommon/log" + "github.com/ngoduykhanh/wireguard-ui/model" + "github.com/ngoduykhanh/wireguard-ui/store/jsondb" "github.com/ngoduykhanh/wireguard-ui/util" + "github.com/rs/xid" ) func ValidSession(next echo.HandlerFunc) echo.HandlerFunc { @@ -43,6 +47,86 @@ func NeedsAdmin(next echo.HandlerFunc) echo.HandlerFunc { } } +// SSOauth uses external authentication (usually by reverseproxy) in the form of HTTP header REMOTE_USER +func SSOauth(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + if !util.RemoteUser { + return next(c) + } + if !isValidSession(c) { + remoteUser := c.Request().Header.Get("REMOTE_USER") + if remoteUser == "" { + // TODO: Better error handling + log.Infof("No REMOTE_USER in reqest. Bailing out.") + return c.Redirect(http.StatusTemporaryRedirect, util.BasePath+"/") + } + log.Debugf("No valid session for REMOTE_USER: %s", remoteUser) + + db := c.Get("db").(*jsondb.JsonDB) + dbuser, err := db.GetUserByName(remoteUser) + if err != nil { + log.Infof("User %s not in database, creating user", remoteUser) + newUser := model.User{ + Username: remoteUser, + Admin: false, + } + err = db.SaveUser(newUser) + if err != nil { + // TODO: Better error handling + return c.Redirect(http.StatusTemporaryRedirect, util.BasePath+"/") + } + // Update dbuser from database + dbuser, err = db.GetUserByName(remoteUser) + if err != nil { + // TODO: Better error handling + return c.Redirect(http.StatusTemporaryRedirect, util.BasePath+"/") + } + + } else { + log.Debugf("Got user from db: %s", dbuser.Username) + } + + // Set session for REMOTE_USER + ageMax := 0 + + cookiePath := util.GetCookiePath() + + sess, _ := session.Get("session", c) + sess.Options = &sessions.Options{ + Path: cookiePath, + MaxAge: ageMax, + HttpOnly: true, + SameSite: http.SameSiteLaxMode, + } + + // set session_token + tokenUID := xid.New().String() + now := time.Now().UTC().Unix() + sess.Values["username"] = dbuser.Username + sess.Values["user_hash"] = util.GetDBUserCRC32(dbuser) + sess.Values["admin"] = dbuser.Admin + sess.Values["session_token"] = tokenUID + sess.Values["max_age"] = ageMax + sess.Values["created_at"] = now + sess.Values["updated_at"] = now + sess.Save(c.Request(), c.Response()) + + // set session_token in cookie + cookie := new(http.Cookie) + cookie.Name = "session_token" + cookie.Path = cookiePath + cookie.Value = tokenUID + cookie.MaxAge = ageMax + cookie.HttpOnly = true + cookie.SameSite = http.SameSiteLaxMode + c.SetCookie(cookie) + + return c.Redirect(http.StatusTemporaryRedirect, util.BasePath) + } + return next(c) + } +} + func isValidSession(c echo.Context) bool { if util.DisableLogin { return true diff --git a/main.go b/main.go index c9d4180..b3aee92 100644 --- a/main.go +++ b/main.go @@ -212,7 +212,7 @@ func main() { // register routes app := router.New(tmplDir, extraData, util.SessionSecret, db) - app.GET(util.BasePath, handler.WireGuardClients(db), handler.ValidSession, handler.RefreshSession) + app.GET(util.BasePath, handler.WireGuardClients(db), handler.SSOauth, handler.ValidSession, handler.RefreshSession) // Important: Make sure that all non-GET routes check the request content type using handler.ContentTypeJson to // mitigate CSRF attacks. This is effective, because browsers don't allow setting the Content-Type header on