diff --git a/handler/routes.go b/handler/routes.go index 4dc95c6..1899cfa 100644 --- a/handler/routes.go +++ b/handler/routes.go @@ -93,18 +93,24 @@ func Login(db store.IStore) echo.HandlerFunc { } if userCorrect && passwordCorrect { - // TODO: refresh the token ageMax := 0 expiration := time.Now().Add(24 * time.Hour) if rememberMe { - ageMax = 86400 - expiration.Add(144 * time.Hour) + ageMax = 86400 * 7 + expiration = time.Now().Add(time.Duration(ageMax) * time.Second) } + + cookiePath := util.BasePath + if cookiePath == "" { + cookiePath = "/" + } + sess, _ := session.Get("session", c) sess.Options = &sessions.Options{ - Path: util.BasePath, + Path: cookiePath, MaxAge: ageMax, HttpOnly: true, + SameSite: http.SameSiteLaxMode, } // set session_token @@ -117,8 +123,11 @@ func Login(db store.IStore) echo.HandlerFunc { // set session_token in cookie cookie := new(http.Cookie) cookie.Name = "session_token" + cookie.Path = cookiePath cookie.Value = tokenUID cookie.Expires = expiration + cookie.HttpOnly = true + cookie.SameSite = http.SameSiteLaxMode c.SetCookie(cookie) return c.JSON(http.StatusOK, jsonHTTPResponse{true, "Logged in successfully"}) diff --git a/handler/session.go b/handler/session.go index 4cede6e..bcc44b8 100644 --- a/handler/session.go +++ b/handler/session.go @@ -3,7 +3,9 @@ package handler import ( "fmt" "net/http" + "time" + "github.com/gorilla/sessions" "github.com/labstack/echo-contrib/session" "github.com/labstack/echo/v4" "github.com/ngoduykhanh/wireguard-ui/util" @@ -23,6 +25,13 @@ func ValidSession(next echo.HandlerFunc) echo.HandlerFunc { } } +func RefreshSession(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + doRefreshSession(c) + return next(c) + } +} + func NeedsAdmin(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) error { if !isAdmin(c) { @@ -44,6 +53,40 @@ func isValidSession(c echo.Context) bool { return true } +func doRefreshSession(c echo.Context) { + if util.DisableLogin { + return + } + + sess, _ := session.Get("session", c) + oldCookie, err := c.Cookie("session_token") + if err != nil || sess.Values["session_token"] != oldCookie.Value { + return + } + + cookiePath := util.BasePath + if cookiePath == "" { + cookiePath = "/" + } + + sess.Options = &sessions.Options{ + Path: cookiePath, + MaxAge: sess.Options.MaxAge, + HttpOnly: true, + SameSite: http.SameSiteLaxMode, + } + sess.Save(c.Request(), c.Response()) + + cookie := new(http.Cookie) + cookie.Name = "session_token" + cookie.Path = cookiePath + cookie.Value = oldCookie.Value + cookie.Expires = time.Now().Add(time.Duration(sess.Options.MaxAge) * time.Second) + cookie.HttpOnly = true + cookie.SameSite = http.SameSiteLaxMode + c.SetCookie(cookie) +} + // currentUser to get username of logged in user func currentUser(c echo.Context) string { if util.DisableLogin { diff --git a/main.go b/main.go index e11cf29..c36a66e 100644 --- a/main.go +++ b/main.go @@ -205,7 +205,7 @@ func main() { // register routes app := router.New(tmplDir, extraData, util.SessionSecret) - app.GET(util.BasePath, handler.WireGuardClients(db), handler.ValidSession) + app.GET(util.BasePath, handler.WireGuardClients(db), handler.ValidSession, handler.RefreshSession) // Important: Make sure that all non-GET routes check the request content type using handler.ContentTypeJson to // mitigate CSRF attacks. This is effective, because browsers don't allow setting the Content-Type header on @@ -215,8 +215,8 @@ func main() { app.GET(util.BasePath+"/login", handler.LoginPage()) app.POST(util.BasePath+"/login", handler.Login(db), handler.ContentTypeJson) app.GET(util.BasePath+"/logout", handler.Logout(), handler.ValidSession) - app.GET(util.BasePath+"/profile", handler.LoadProfile(), handler.ValidSession) - app.GET(util.BasePath+"/users-settings", handler.UsersSettings(), handler.ValidSession, handler.NeedsAdmin) + app.GET(util.BasePath+"/profile", handler.LoadProfile(), handler.ValidSession, handler.RefreshSession) + app.GET(util.BasePath+"/users-settings", handler.UsersSettings(), handler.ValidSession, handler.RefreshSession, handler.NeedsAdmin) app.POST(util.BasePath+"/update-user", handler.UpdateUser(db), handler.ValidSession, handler.ContentTypeJson) app.POST(util.BasePath+"/create-user", handler.CreateUser(db), handler.ValidSession, handler.ContentTypeJson, handler.NeedsAdmin) app.POST(util.BasePath+"/remove-user", handler.RemoveUser(db), handler.ValidSession, handler.ContentTypeJson, handler.NeedsAdmin) @@ -242,19 +242,19 @@ func main() { app.POST(util.BasePath+"/client/set-status", handler.SetClientStatus(db), handler.ValidSession, handler.ContentTypeJson) app.POST(util.BasePath+"/remove-client", handler.RemoveClient(db), handler.ValidSession, handler.ContentTypeJson) app.GET(util.BasePath+"/download", handler.DownloadClient(db), handler.ValidSession) - app.GET(util.BasePath+"/wg-server", handler.WireGuardServer(db), handler.ValidSession, handler.NeedsAdmin) + app.GET(util.BasePath+"/wg-server", handler.WireGuardServer(db), handler.ValidSession, handler.RefreshSession, handler.NeedsAdmin) app.POST(util.BasePath+"/wg-server/interfaces", handler.WireGuardServerInterfaces(db), handler.ValidSession, handler.ContentTypeJson, handler.NeedsAdmin) app.POST(util.BasePath+"/wg-server/keypair", handler.WireGuardServerKeyPair(db), handler.ValidSession, handler.ContentTypeJson, handler.NeedsAdmin) - app.GET(util.BasePath+"/global-settings", handler.GlobalSettings(db), handler.ValidSession, handler.NeedsAdmin) + app.GET(util.BasePath+"/global-settings", handler.GlobalSettings(db), handler.ValidSession, handler.RefreshSession, handler.NeedsAdmin) app.POST(util.BasePath+"/global-settings", handler.GlobalSettingSubmit(db), handler.ValidSession, handler.ContentTypeJson, handler.NeedsAdmin) - app.GET(util.BasePath+"/status", handler.Status(db), handler.ValidSession) + app.GET(util.BasePath+"/status", handler.Status(db), handler.ValidSession, handler.RefreshSession) app.GET(util.BasePath+"/api/clients", handler.GetClients(db), handler.ValidSession) app.GET(util.BasePath+"/api/client/:id", handler.GetClient(db), handler.ValidSession) app.GET(util.BasePath+"/api/machine-ips", handler.MachineIPAddresses(), handler.ValidSession) app.GET(util.BasePath+"/api/subnet-ranges", handler.GetOrderedSubnetRanges(), handler.ValidSession) app.GET(util.BasePath+"/api/suggest-client-ips", handler.SuggestIPAllocation(db), handler.ValidSession) app.POST(util.BasePath+"/api/apply-wg-config", handler.ApplyServerConfig(db, tmplDir), handler.ValidSession, handler.ContentTypeJson) - app.GET(util.BasePath+"/wake_on_lan_hosts", handler.GetWakeOnLanHosts(db), handler.ValidSession) + app.GET(util.BasePath+"/wake_on_lan_hosts", handler.GetWakeOnLanHosts(db), handler.ValidSession, handler.RefreshSession) app.POST(util.BasePath+"/wake_on_lan_host", handler.SaveWakeOnLanHost(db), handler.ValidSession, handler.ContentTypeJson) app.DELETE(util.BasePath+"/wake_on_lan_host/:mac_address", handler.DeleteWakeOnHost(db), handler.ValidSession, handler.ContentTypeJson) app.PUT(util.BasePath+"/wake_on_lan_host/:mac_address", handler.WakeOnHost(db), handler.ValidSession, handler.ContentTypeJson)