mirror of
https://github.com/ngoduykhanh/wireguard-ui.git
synced 2025-04-19 19:59:13 +03:00
Merge b593512946
into 2fdafd34ca
This commit is contained in:
commit
b31729a420
4 changed files with 101 additions and 3 deletions
|
@ -8,7 +8,11 @@ import (
|
||||||
"github.com/gorilla/sessions"
|
"github.com/gorilla/sessions"
|
||||||
"github.com/labstack/echo-contrib/session"
|
"github.com/labstack/echo-contrib/session"
|
||||||
"github.com/labstack/echo/v4"
|
"github.com/labstack/echo/v4"
|
||||||
|
"github.com/labstack/gommon/log"
|
||||||
|
"github.com/ngoduykhanh/wireguard-ui/model"
|
||||||
|
"github.com/ngoduykhanh/wireguard-ui/store/jsondb"
|
||||||
"github.com/ngoduykhanh/wireguard-ui/util"
|
"github.com/ngoduykhanh/wireguard-ui/util"
|
||||||
|
"github.com/rs/xid"
|
||||||
)
|
)
|
||||||
|
|
||||||
func ValidSession(next echo.HandlerFunc) echo.HandlerFunc {
|
func ValidSession(next echo.HandlerFunc) echo.HandlerFunc {
|
||||||
|
@ -43,6 +47,86 @@ func NeedsAdmin(next echo.HandlerFunc) echo.HandlerFunc {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SSOauth uses external authentication (usually by reverseproxy) in the form of HTTP header REMOTE_USER
|
||||||
|
func SSOauth(next echo.HandlerFunc) echo.HandlerFunc {
|
||||||
|
return func(c echo.Context) error {
|
||||||
|
if !util.RemoteUser {
|
||||||
|
return next(c)
|
||||||
|
}
|
||||||
|
if !isValidSession(c) {
|
||||||
|
remoteUser := c.Request().Header.Get("REMOTE_USER")
|
||||||
|
if remoteUser == "" {
|
||||||
|
// TODO: Better error handling
|
||||||
|
log.Infof("No REMOTE_USER in reqest. Bailing out.")
|
||||||
|
return c.Redirect(http.StatusTemporaryRedirect, util.BasePath+"/")
|
||||||
|
}
|
||||||
|
log.Debugf("No valid session for REMOTE_USER: %s", remoteUser)
|
||||||
|
|
||||||
|
db := c.Get("db").(*jsondb.JsonDB)
|
||||||
|
dbuser, err := db.GetUserByName(remoteUser)
|
||||||
|
if err != nil {
|
||||||
|
log.Infof("User %s not in database, creating user", remoteUser)
|
||||||
|
newUser := model.User{
|
||||||
|
Username: remoteUser,
|
||||||
|
Admin: false,
|
||||||
|
}
|
||||||
|
err = db.SaveUser(newUser)
|
||||||
|
if err != nil {
|
||||||
|
// TODO: Better error handling
|
||||||
|
return c.Redirect(http.StatusTemporaryRedirect, util.BasePath+"/")
|
||||||
|
}
|
||||||
|
// Update dbuser from database
|
||||||
|
dbuser, err = db.GetUserByName(remoteUser)
|
||||||
|
if err != nil {
|
||||||
|
// TODO: Better error handling
|
||||||
|
return c.Redirect(http.StatusTemporaryRedirect, util.BasePath+"/")
|
||||||
|
}
|
||||||
|
|
||||||
|
} else {
|
||||||
|
log.Debugf("Got user from db: %s", dbuser.Username)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set session for REMOTE_USER
|
||||||
|
ageMax := 0
|
||||||
|
|
||||||
|
cookiePath := util.GetCookiePath()
|
||||||
|
|
||||||
|
sess, _ := session.Get("session", c)
|
||||||
|
sess.Options = &sessions.Options{
|
||||||
|
Path: cookiePath,
|
||||||
|
MaxAge: ageMax,
|
||||||
|
HttpOnly: true,
|
||||||
|
SameSite: http.SameSiteLaxMode,
|
||||||
|
}
|
||||||
|
|
||||||
|
// set session_token
|
||||||
|
tokenUID := xid.New().String()
|
||||||
|
now := time.Now().UTC().Unix()
|
||||||
|
sess.Values["username"] = dbuser.Username
|
||||||
|
sess.Values["user_hash"] = util.GetDBUserCRC32(dbuser)
|
||||||
|
sess.Values["admin"] = dbuser.Admin
|
||||||
|
sess.Values["session_token"] = tokenUID
|
||||||
|
sess.Values["max_age"] = ageMax
|
||||||
|
sess.Values["created_at"] = now
|
||||||
|
sess.Values["updated_at"] = now
|
||||||
|
sess.Save(c.Request(), c.Response())
|
||||||
|
|
||||||
|
// set session_token in cookie
|
||||||
|
cookie := new(http.Cookie)
|
||||||
|
cookie.Name = "session_token"
|
||||||
|
cookie.Path = cookiePath
|
||||||
|
cookie.Value = tokenUID
|
||||||
|
cookie.MaxAge = ageMax
|
||||||
|
cookie.HttpOnly = true
|
||||||
|
cookie.SameSite = http.SameSiteLaxMode
|
||||||
|
c.SetCookie(cookie)
|
||||||
|
|
||||||
|
return c.Redirect(http.StatusTemporaryRedirect, util.BasePath)
|
||||||
|
}
|
||||||
|
return next(c)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func isValidSession(c echo.Context) bool {
|
func isValidSession(c echo.Context) bool {
|
||||||
if util.DisableLogin {
|
if util.DisableLogin {
|
||||||
return true
|
return true
|
||||||
|
|
8
main.go
8
main.go
|
@ -33,6 +33,7 @@ var (
|
||||||
buildTime = fmt.Sprintf(time.Now().UTC().Format("01-02-2006 15:04:05"))
|
buildTime = fmt.Sprintf(time.Now().UTC().Format("01-02-2006 15:04:05"))
|
||||||
// configuration variables
|
// configuration variables
|
||||||
flagDisableLogin = false
|
flagDisableLogin = false
|
||||||
|
flagRemoteUser = false
|
||||||
flagBindAddress = "0.0.0.0:5000"
|
flagBindAddress = "0.0.0.0:5000"
|
||||||
flagSmtpHostname = "127.0.0.1"
|
flagSmtpHostname = "127.0.0.1"
|
||||||
flagSmtpPort = 25
|
flagSmtpPort = 25
|
||||||
|
@ -77,6 +78,7 @@ var embeddedAssets embed.FS
|
||||||
func init() {
|
func init() {
|
||||||
// command-line flags and env variables
|
// command-line flags and env variables
|
||||||
flag.BoolVar(&flagDisableLogin, "disable-login", util.LookupEnvOrBool("DISABLE_LOGIN", flagDisableLogin), "Disable authentication on the app. This is potentially dangerous.")
|
flag.BoolVar(&flagDisableLogin, "disable-login", util.LookupEnvOrBool("DISABLE_LOGIN", flagDisableLogin), "Disable authentication on the app. This is potentially dangerous.")
|
||||||
|
flag.BoolVar(&flagRemoteUser, "remote_user", util.LookupEnvOrBool("REMOTE_USER", flagRemoteUser), "Use HTTP header REMOTE_USER for auth. Commonly used with SSO and a proxy funcion.")
|
||||||
flag.StringVar(&flagBindAddress, "bind-address", util.LookupEnvOrString("BIND_ADDRESS", flagBindAddress), "Address:Port to which the app will be bound.")
|
flag.StringVar(&flagBindAddress, "bind-address", util.LookupEnvOrString("BIND_ADDRESS", flagBindAddress), "Address:Port to which the app will be bound.")
|
||||||
flag.StringVar(&flagSmtpHostname, "smtp-hostname", util.LookupEnvOrString("SMTP_HOSTNAME", flagSmtpHostname), "SMTP Hostname")
|
flag.StringVar(&flagSmtpHostname, "smtp-hostname", util.LookupEnvOrString("SMTP_HOSTNAME", flagSmtpHostname), "SMTP Hostname")
|
||||||
flag.IntVar(&flagSmtpPort, "smtp-port", util.LookupEnvOrInt("SMTP_PORT", flagSmtpPort), "SMTP Port")
|
flag.IntVar(&flagSmtpPort, "smtp-port", util.LookupEnvOrInt("SMTP_PORT", flagSmtpPort), "SMTP Port")
|
||||||
|
@ -126,6 +128,7 @@ func init() {
|
||||||
|
|
||||||
// update runtime config
|
// update runtime config
|
||||||
util.DisableLogin = flagDisableLogin
|
util.DisableLogin = flagDisableLogin
|
||||||
|
util.RemoteUser = flagRemoteUser
|
||||||
util.BindAddress = flagBindAddress
|
util.BindAddress = flagBindAddress
|
||||||
util.SmtpHostname = flagSmtpHostname
|
util.SmtpHostname = flagSmtpHostname
|
||||||
util.SmtpPort = flagSmtpPort
|
util.SmtpPort = flagSmtpPort
|
||||||
|
@ -161,6 +164,7 @@ func init() {
|
||||||
fmt.Println("Build Time\t:", buildTime)
|
fmt.Println("Build Time\t:", buildTime)
|
||||||
fmt.Println("Git Repo\t:", "https://github.com/ngoduykhanh/wireguard-ui")
|
fmt.Println("Git Repo\t:", "https://github.com/ngoduykhanh/wireguard-ui")
|
||||||
fmt.Println("Authentication\t:", !util.DisableLogin)
|
fmt.Println("Authentication\t:", !util.DisableLogin)
|
||||||
|
fmt.Println("Remote_user\t:", util.RemoteUser)
|
||||||
fmt.Println("Bind address\t:", util.BindAddress)
|
fmt.Println("Bind address\t:", util.BindAddress)
|
||||||
//fmt.Println("Sendgrid key\t:", util.SendgridApiKey)
|
//fmt.Println("Sendgrid key\t:", util.SendgridApiKey)
|
||||||
fmt.Println("Email from\t:", util.EmailFrom)
|
fmt.Println("Email from\t:", util.EmailFrom)
|
||||||
|
@ -206,9 +210,9 @@ func main() {
|
||||||
}
|
}
|
||||||
|
|
||||||
// register routes
|
// register routes
|
||||||
app := router.New(tmplDir, extraData, util.SessionSecret)
|
app := router.New(tmplDir, extraData, util.SessionSecret, db)
|
||||||
|
|
||||||
app.GET(util.BasePath, handler.WireGuardClients(db), handler.ValidSession, handler.RefreshSession)
|
app.GET(util.BasePath, handler.WireGuardClients(db), handler.SSOauth, handler.ValidSession, handler.RefreshSession)
|
||||||
|
|
||||||
// Important: Make sure that all non-GET routes check the request content type using handler.ContentTypeJson to
|
// Important: Make sure that all non-GET routes check the request content type using handler.ContentTypeJson to
|
||||||
// mitigate CSRF attacks. This is effective, because browsers don't allow setting the Content-Type header on
|
// mitigate CSRF attacks. This is effective, because browsers don't allow setting the Content-Type header on
|
||||||
|
|
|
@ -13,6 +13,7 @@ import (
|
||||||
"github.com/labstack/echo/v4"
|
"github.com/labstack/echo/v4"
|
||||||
"github.com/labstack/echo/v4/middleware"
|
"github.com/labstack/echo/v4/middleware"
|
||||||
"github.com/labstack/gommon/log"
|
"github.com/labstack/gommon/log"
|
||||||
|
"github.com/ngoduykhanh/wireguard-ui/store/jsondb"
|
||||||
"github.com/ngoduykhanh/wireguard-ui/util"
|
"github.com/ngoduykhanh/wireguard-ui/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -48,7 +49,7 @@ func (t *TemplateRegistry) Render(w io.Writer, name string, data interface{}, c
|
||||||
}
|
}
|
||||||
|
|
||||||
// New function
|
// New function
|
||||||
func New(tmplDir fs.FS, extraData map[string]interface{}, secret [64]byte) *echo.Echo {
|
func New(tmplDir fs.FS, extraData map[string]interface{}, secret [64]byte, db *jsondb.JsonDB) *echo.Echo {
|
||||||
e := echo.New()
|
e := echo.New()
|
||||||
|
|
||||||
cookiePath := util.GetCookiePath()
|
cookiePath := util.GetCookiePath()
|
||||||
|
@ -60,6 +61,14 @@ func New(tmplDir fs.FS, extraData map[string]interface{}, secret [64]byte) *echo
|
||||||
|
|
||||||
e.Use(session.Middleware(cookieStore))
|
e.Use(session.Middleware(cookieStore))
|
||||||
|
|
||||||
|
// Add db to context so middlewares can use it.
|
||||||
|
e.Use(func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||||
|
return func(c echo.Context) error {
|
||||||
|
c.Set("db", db)
|
||||||
|
return next(c)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
// read html template file to string
|
// read html template file to string
|
||||||
tmplBaseString, err := util.StringFromEmbedFile(tmplDir, "base.html")
|
tmplBaseString, err := util.StringFromEmbedFile(tmplDir, "base.html")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -10,6 +10,7 @@ import (
|
||||||
// Runtime config
|
// Runtime config
|
||||||
var (
|
var (
|
||||||
DisableLogin bool
|
DisableLogin bool
|
||||||
|
RemoteUser bool
|
||||||
BindAddress string
|
BindAddress string
|
||||||
SmtpHostname string
|
SmtpHostname string
|
||||||
SmtpPort int
|
SmtpPort int
|
||||||
|
|
Loading…
Add table
Reference in a new issue