diff --git a/docker-compose.yaml b/docker-compose.yaml index dae14ce..1494a72 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -9,6 +9,7 @@ services: - SENDGRID_API_KEY - EMAIL_FROM - EMAIL_FROM_NAME + - SESSION_SECRET ports: - 5000:5000 logging: diff --git a/handler/routes.go b/handler/routes.go index f4f4b18..c16c14b 100644 --- a/handler/routes.go +++ b/handler/routes.go @@ -79,8 +79,6 @@ func Logout() echo.HandlerFunc { // WireGuardClients handler func WireGuardClients() echo.HandlerFunc { return func(c echo.Context) error { - // access validation - validSession(c) clientDataList, err := util.GetClients(true) if err != nil { @@ -99,8 +97,6 @@ func WireGuardClients() echo.HandlerFunc { // GetClients handler return a list of Wireguard client data func GetClients() echo.HandlerFunc { return func(c echo.Context) error { - // access validation - validSession(c) clientDataList, err := util.GetClients(true) if err != nil { @@ -116,8 +112,6 @@ func GetClients() echo.HandlerFunc { // GetClient handler return a of Wireguard client data func GetClient() echo.HandlerFunc { return func(c echo.Context) error { - // access validation - validSession(c) clientID := c.Param("id") clientData, err := util.GetClientByID(clientID, true) @@ -132,8 +126,6 @@ func GetClient() echo.HandlerFunc { // NewClient handler func NewClient() echo.HandlerFunc { return func(c echo.Context) error { - // access validation - validSession(c) client := new(model.Client) c.Bind(client) @@ -204,8 +196,6 @@ func EmailClient(mailer emailer.Emailer) echo.HandlerFunc { } return func(c echo.Context) error { - // access validation - validSession(c) var payload clientIdEmailPayload c.Bind(&payload) // TODO validate email @@ -245,8 +235,6 @@ func EmailClient(mailer emailer.Emailer) echo.HandlerFunc { // UpdateClient handler to update client information func UpdateClient() echo.HandlerFunc { return func(c echo.Context) error { - // access validation - validSession(c) _client := new(model.Client) c.Bind(_client) @@ -305,8 +293,6 @@ func UpdateClient() echo.HandlerFunc { // SetClientStatus handler to enable / disable a client func SetClientStatus() echo.HandlerFunc { return func(c echo.Context) error { - // access validation - validSession(c) data := make(map[string]interface{}) err := json.NewDecoder(c.Request().Body).Decode(&data) @@ -368,8 +354,6 @@ func DownloadClient() echo.HandlerFunc { // RemoveClient handler func RemoveClient() echo.HandlerFunc { return func(c echo.Context) error { - // access validation - validSession(c) client := new(model.Client) c.Bind(client) @@ -394,8 +378,6 @@ func RemoveClient() echo.HandlerFunc { // WireGuardServer handler func WireGuardServer() echo.HandlerFunc { return func(c echo.Context) error { - // access validation - validSession(c) server, err := util.GetServer() if err != nil { @@ -413,8 +395,6 @@ func WireGuardServer() echo.HandlerFunc { // WireGuardServerInterfaces handler func WireGuardServerInterfaces() echo.HandlerFunc { return func(c echo.Context) error { - // access validation - validSession(c) serverInterface := new(model.ServerInterface) c.Bind(serverInterface) @@ -444,8 +424,6 @@ func WireGuardServerInterfaces() echo.HandlerFunc { // WireGuardServerKeyPair handler to generate private and public keys func WireGuardServerKeyPair() echo.HandlerFunc { return func(c echo.Context) error { - // access validation - validSession(c) // gen Wireguard key pair key, err := wgtypes.GeneratePrivateKey() @@ -476,8 +454,6 @@ func WireGuardServerKeyPair() echo.HandlerFunc { // GlobalSettings handler func GlobalSettings() echo.HandlerFunc { return func(c echo.Context) error { - // access validation - validSession(c) globalSettings, err := util.GetGlobalSettings() if err != nil { @@ -494,8 +470,6 @@ func GlobalSettings() echo.HandlerFunc { // GlobalSettingSubmit handler to update the global settings func GlobalSettingSubmit() echo.HandlerFunc { return func(c echo.Context) error { - // access validation - validSession(c) globalSettings := new(model.GlobalSetting) c.Bind(globalSettings) @@ -525,8 +499,6 @@ func GlobalSettingSubmit() echo.HandlerFunc { // MachineIPAddresses handler to get local interface ip addresses func MachineIPAddresses() echo.HandlerFunc { return func(c echo.Context) error { - // access validation - validSession(c) // get private ip addresses interfaceList, err := util.GetInterfaceIPs() @@ -551,8 +523,6 @@ func MachineIPAddresses() echo.HandlerFunc { // SuggestIPAllocation handler to get the list of ip address for client func SuggestIPAllocation() echo.HandlerFunc { return func(c echo.Context) error { - // access validation - validSession(c) server, err := util.GetServer() if err != nil { @@ -589,8 +559,6 @@ func SuggestIPAllocation() echo.HandlerFunc { // ApplyServerConfig handler to write config file and restart Wireguard server func ApplyServerConfig(tmplBox *rice.Box) echo.HandlerFunc { return func(c echo.Context) error { - // access validation - validSession(c) server, err := util.GetServer() if err != nil { diff --git a/handler/session.go b/handler/session.go index 6985327..10042ac 100644 --- a/handler/session.go +++ b/handler/session.go @@ -9,22 +9,32 @@ import ( "github.com/ngoduykhanh/wireguard-ui/util" ) -// validSession to redirect user to the login page if they are not authenticated or session expired. -func validSession(c echo.Context) { - if !util.DisableLogin { - sess, _ := session.Get("session", c) - cookie, err := c.Cookie("session_token") - if err != nil || sess.Values["session_token"] != cookie.Value { +func ValidSession(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + if !isValidSession(c) { nextURL := c.Request().URL - if nextURL != nil { - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("/login?next=%s", c.Request().URL)) + if nextURL != nil && c.Request().Method == http.MethodGet { + return c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("/login?next=%s", c.Request().URL)) } else { - c.Redirect(http.StatusTemporaryRedirect, "/login") + return c.Redirect(http.StatusTemporaryRedirect, "/login") } } + return next(c) } } +func isValidSession(c echo.Context) bool { + if util.DisableLogin { + return true + } + sess, _ := session.Get("session", c) + cookie, err := c.Cookie("session_token") + if err != nil || sess.Values["session_token"] != cookie.Value { + return false + } + return true +} + // currentUser to get username of logged in user func currentUser(c echo.Context) string { if util.DisableLogin { diff --git a/main.go b/main.go index e7a1ab8..47fdb15 100644 --- a/main.go +++ b/main.go @@ -36,6 +36,7 @@ func init() { util.SendgridApiKey = os.Getenv("SENDGRID_API_KEY") util.EmailFrom = os.Getenv("EMAIL_FROM") util.EmailFromName = os.Getenv("EMAIL_FROM_NAME") + util.SessionSecret = []byte(os.Getenv("SESSION_SECRET")) // print app information fmt.Println("Wireguard UI") @@ -66,9 +67,9 @@ func main() { assetHandler := http.FileServer(rice.MustFindBox("assets").HTTPBox()) // register routes - app := router.New(tmplBox, extraData) + app := router.New(tmplBox, extraData, util.SessionSecret) - app.GET("/", handler.WireGuardClients()) + app.GET("/", handler.WireGuardClients(), handler.ValidSession) if !util.DisableLogin { app.GET("/login", handler.LoginPage()) @@ -77,23 +78,23 @@ func main() { sendmail := emailer.NewSendgridApiMail(util.SendgridApiKey, util.EmailFromName, util.EmailFrom) - app.GET("/logout", handler.Logout()) - app.POST("/new-client", handler.NewClient()) - app.POST("/update-client", handler.UpdateClient()) - app.POST("/email-client", handler.EmailClient(sendmail)) - app.POST("/client/set-status", handler.SetClientStatus()) - app.POST("/remove-client", handler.RemoveClient()) - app.GET("/download", handler.DownloadClient()) - app.GET("/wg-server", handler.WireGuardServer()) - app.POST("wg-server/interfaces", handler.WireGuardServerInterfaces()) - app.POST("wg-server/keypair", handler.WireGuardServerKeyPair()) - app.GET("/global-settings", handler.GlobalSettings()) - app.POST("/global-settings", handler.GlobalSettingSubmit()) - app.GET("/api/clients", handler.GetClients()) - app.GET("/api/client/:id", handler.GetClient()) - app.GET("/api/machine-ips", handler.MachineIPAddresses()) - app.GET("/api/suggest-client-ips", handler.SuggestIPAllocation()) - app.GET("/api/apply-wg-config", handler.ApplyServerConfig(tmplBox)) + app.GET("/logout", handler.Logout(), handler.ValidSession) + app.POST("/new-client", handler.NewClient(), handler.ValidSession) + app.POST("/update-client", handler.UpdateClient(), handler.ValidSession) + app.POST("/email-client", handler.EmailClient(sendmail), handler.ValidSession) + app.POST("/client/set-status", handler.SetClientStatus(), handler.ValidSession) + app.POST("/remove-client", handler.RemoveClient(), handler.ValidSession) + app.GET("/download", handler.DownloadClient(), handler.ValidSession) + app.GET("/wg-server", handler.WireGuardServer(), handler.ValidSession) + app.POST("wg-server/interfaces", handler.WireGuardServerInterfaces(), handler.ValidSession) + app.POST("wg-server/keypair", handler.WireGuardServerKeyPair(), handler.ValidSession) + app.GET("/global-settings", handler.GlobalSettings(), handler.ValidSession) + app.POST("/global-settings", handler.GlobalSettingSubmit(), handler.ValidSession) + app.GET("/api/clients", handler.GetClients(), handler.ValidSession) + app.GET("/api/client/:id", handler.GetClient(), handler.ValidSession) + app.GET("/api/machine-ips", handler.MachineIPAddresses(), handler.ValidSession) + app.GET("/api/suggest-client-ips", handler.SuggestIPAllocation(), handler.ValidSession) + app.GET("/api/apply-wg-config", handler.ApplyServerConfig(tmplBox), handler.ValidSession) // servers other static files app.GET("/static/*", echo.WrapHandler(http.StripPrefix("/static/", assetHandler))) diff --git a/router/router.go b/router/router.go index 2bd634e..bb14431 100644 --- a/router/router.go +++ b/router/router.go @@ -6,7 +6,7 @@ import ( "reflect" "text/template" - "github.com/GeertJohan/go.rice" + rice "github.com/GeertJohan/go.rice" "github.com/gorilla/sessions" "github.com/labstack/echo-contrib/session" "github.com/labstack/echo/v4" @@ -44,9 +44,9 @@ func (t *TemplateRegistry) Render(w io.Writer, name string, data interface{}, c } // New function -func New(tmplBox *rice.Box, extraData map[string]string) *echo.Echo { +func New(tmplBox *rice.Box, extraData map[string]string, secret []byte) *echo.Echo { e := echo.New() - e.Use(session.Middleware(sessions.NewCookieStore([]byte("secret")))) + e.Use(session.Middleware(sessions.NewCookieStore(secret))) // read html template file to string tmplBaseString, err := tmplBox.String("base.html") diff --git a/util/config.go b/util/config.go index 24d7595..6ec77a4 100644 --- a/util/config.go +++ b/util/config.go @@ -7,4 +7,5 @@ var ( SendgridApiKey string EmailFrom string EmailFromName string + SessionSecret []byte )