mirror of
https://github.com/ngoduykhanh/wireguard-ui.git
synced 2025-04-20 20:03:39 +03:00

Use MaxAge instead of Expires Verify if the cookie is not too old and not from the future Verify if the user exists and unchanged Refresh not sooner than 24h Do not refresh temporary sessions Delete cookies on logout
412 lines
12 KiB
Go
412 lines
12 KiB
Go
package jsondb
|
|
|
|
import (
|
|
"encoding/base64"
|
|
"encoding/json"
|
|
"fmt"
|
|
"os"
|
|
"path"
|
|
"strconv"
|
|
"time"
|
|
|
|
"github.com/sdomino/scribble"
|
|
"github.com/skip2/go-qrcode"
|
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
|
|
|
"github.com/ngoduykhanh/wireguard-ui/model"
|
|
"github.com/ngoduykhanh/wireguard-ui/util"
|
|
)
|
|
|
|
type JsonDB struct {
|
|
conn *scribble.Driver
|
|
dbPath string
|
|
}
|
|
|
|
// New returns a new pointer JsonDB
|
|
func New(dbPath string) (*JsonDB, error) {
|
|
conn, err := scribble.New(dbPath, nil)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
ans := JsonDB{
|
|
conn: conn,
|
|
dbPath: dbPath,
|
|
}
|
|
return &ans, nil
|
|
|
|
}
|
|
|
|
func (o *JsonDB) Init() error {
|
|
var clientPath = path.Join(o.dbPath, "clients")
|
|
var serverPath = path.Join(o.dbPath, "server")
|
|
var userPath = path.Join(o.dbPath, "users")
|
|
var wakeOnLanHostsPath = path.Join(o.dbPath, "wake_on_lan_hosts")
|
|
var serverInterfacePath = path.Join(serverPath, "interfaces.json")
|
|
var serverKeyPairPath = path.Join(serverPath, "keypair.json")
|
|
var globalSettingPath = path.Join(serverPath, "global_settings.json")
|
|
var hashesPath = path.Join(serverPath, "hashes.json")
|
|
|
|
// create directories if they do not exist
|
|
if _, err := os.Stat(clientPath); os.IsNotExist(err) {
|
|
os.MkdirAll(clientPath, os.ModePerm)
|
|
}
|
|
if _, err := os.Stat(serverPath); os.IsNotExist(err) {
|
|
os.MkdirAll(serverPath, os.ModePerm)
|
|
}
|
|
if _, err := os.Stat(userPath); os.IsNotExist(err) {
|
|
os.MkdirAll(userPath, os.ModePerm)
|
|
}
|
|
if _, err := os.Stat(wakeOnLanHostsPath); os.IsNotExist(err) {
|
|
os.MkdirAll(wakeOnLanHostsPath, os.ModePerm)
|
|
}
|
|
|
|
// server's interface
|
|
if _, err := os.Stat(serverInterfacePath); os.IsNotExist(err) {
|
|
serverInterface := new(model.ServerInterface)
|
|
serverInterface.Addresses = util.LookupEnvOrStrings(util.ServerAddressesEnvVar, []string{util.DefaultServerAddress})
|
|
serverInterface.ListenPort = util.LookupEnvOrInt(util.ServerListenPortEnvVar, util.DefaultServerPort)
|
|
serverInterface.PostUp = util.LookupEnvOrString(util.ServerPostUpScriptEnvVar, "")
|
|
serverInterface.PostDown = util.LookupEnvOrString(util.ServerPostDownScriptEnvVar, "")
|
|
serverInterface.UpdatedAt = time.Now().UTC()
|
|
o.conn.Write("server", "interfaces", serverInterface)
|
|
err := util.ManagePerms(serverInterfacePath)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
// server's key pair
|
|
if _, err := os.Stat(serverKeyPairPath); os.IsNotExist(err) {
|
|
|
|
key, err := wgtypes.GeneratePrivateKey()
|
|
if err != nil {
|
|
return scribble.ErrMissingCollection
|
|
}
|
|
serverKeyPair := new(model.ServerKeypair)
|
|
serverKeyPair.PrivateKey = key.String()
|
|
serverKeyPair.PublicKey = key.PublicKey().String()
|
|
serverKeyPair.UpdatedAt = time.Now().UTC()
|
|
o.conn.Write("server", "keypair", serverKeyPair)
|
|
err = util.ManagePerms(serverKeyPairPath)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
// global settings
|
|
if _, err := os.Stat(globalSettingPath); os.IsNotExist(err) {
|
|
endpointAddress := util.LookupEnvOrString(util.EndpointAddressEnvVar, "")
|
|
if endpointAddress == "" {
|
|
// automatically find an external IP address
|
|
publicInterface, err := util.GetPublicIP()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
endpointAddress = publicInterface.IPAddress
|
|
}
|
|
|
|
globalSetting := new(model.GlobalSetting)
|
|
globalSetting.EndpointAddress = endpointAddress
|
|
globalSetting.DNSServers = util.LookupEnvOrStrings(util.DNSEnvVar, []string{util.DefaultDNS})
|
|
globalSetting.MTU = util.LookupEnvOrInt(util.MTUEnvVar, util.DefaultMTU)
|
|
globalSetting.PersistentKeepalive = util.LookupEnvOrInt(util.PersistentKeepaliveEnvVar, util.DefaultPersistentKeepalive)
|
|
globalSetting.FirewallMark = util.LookupEnvOrString(util.FirewallMarkEnvVar, util.DefaultFirewallMark)
|
|
globalSetting.Table = util.LookupEnvOrString(util.TableEnvVar, util.DefaultTable)
|
|
globalSetting.ConfigFilePath = util.LookupEnvOrString(util.ConfigFilePathEnvVar, util.DefaultConfigFilePath)
|
|
globalSetting.UpdatedAt = time.Now().UTC()
|
|
o.conn.Write("server", "global_settings", globalSetting)
|
|
err := util.ManagePerms(globalSettingPath)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
// hashes
|
|
if _, err := os.Stat(hashesPath); os.IsNotExist(err) {
|
|
clientServerHashes := new(model.ClientServerHashes)
|
|
clientServerHashes.Client = "none"
|
|
clientServerHashes.Server = "none"
|
|
o.conn.Write("server", "hashes", clientServerHashes)
|
|
err := util.ManagePerms(hashesPath)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
// user info
|
|
results, err := o.conn.ReadAll("users")
|
|
if err != nil || len(results) < 1 {
|
|
user := new(model.User)
|
|
user.Username = util.LookupEnvOrString(util.UsernameEnvVar, util.DefaultUsername)
|
|
user.Admin = util.DefaultIsAdmin
|
|
user.PasswordHash = util.LookupEnvOrString(util.PasswordHashEnvVar, "")
|
|
if user.PasswordHash == "" {
|
|
user.PasswordHash = util.LookupEnvOrFile(util.PasswordHashFileEnvVar, "")
|
|
if user.PasswordHash == "" {
|
|
plaintext := util.LookupEnvOrString(util.PasswordEnvVar, util.DefaultPassword)
|
|
if plaintext == util.DefaultPassword {
|
|
plaintext = util.LookupEnvOrFile(util.PasswordFileEnvVar, util.DefaultPassword)
|
|
}
|
|
hash, err := util.HashPassword(plaintext)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
user.PasswordHash = hash
|
|
}
|
|
}
|
|
|
|
o.conn.Write("users", user.Username, user)
|
|
err = util.ManagePerms(path.Join(path.Join(o.dbPath, "users"), user.Username+".json"))
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
// init cache
|
|
for _, i := range results {
|
|
user := model.User{}
|
|
|
|
if err := json.Unmarshal([]byte(i), &user); err == nil {
|
|
util.DBUsersToCRC32[user.Username] = util.GetDBUserCRC32(user)
|
|
}
|
|
}
|
|
|
|
clients, err := o.GetClients(false)
|
|
if err != nil {
|
|
return nil
|
|
}
|
|
for _, cl := range clients {
|
|
client := cl.Client
|
|
if client.Enabled && len(client.TgUserid) > 0 {
|
|
if userid, err := strconv.ParseInt(client.TgUserid, 10, 64); err == nil {
|
|
util.UpdateTgToClientID(userid, client.ID)
|
|
}
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// GetUsers func to get all users from the database
|
|
func (o *JsonDB) GetUsers() ([]model.User, error) {
|
|
var users []model.User
|
|
results, err := o.conn.ReadAll("users")
|
|
if err != nil {
|
|
return users, err
|
|
}
|
|
for _, i := range results {
|
|
user := model.User{}
|
|
|
|
if err := json.Unmarshal(i, &user); err != nil {
|
|
return users, fmt.Errorf("cannot decode user json structure: %v", err)
|
|
}
|
|
users = append(users, user)
|
|
|
|
}
|
|
return users, err
|
|
}
|
|
|
|
// GetUserByName func to get single user from the database
|
|
func (o *JsonDB) GetUserByName(username string) (model.User, error) {
|
|
user := model.User{}
|
|
|
|
if err := o.conn.Read("users", username, &user); err != nil {
|
|
return user, err
|
|
}
|
|
|
|
return user, nil
|
|
}
|
|
|
|
// SaveUser func to save user in the database
|
|
func (o *JsonDB) SaveUser(user model.User) error {
|
|
userPath := path.Join(path.Join(o.dbPath, "users"), user.Username+".json")
|
|
output := o.conn.Write("users", user.Username, user)
|
|
err := util.ManagePerms(userPath)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
util.DBUsersToCRC32[user.Username] = util.GetDBUserCRC32(user)
|
|
return output
|
|
}
|
|
|
|
// DeleteUser func to remove user from the database
|
|
func (o *JsonDB) DeleteUser(username string) error {
|
|
delete(util.DBUsersToCRC32, username)
|
|
return o.conn.Delete("users", username)
|
|
}
|
|
|
|
// GetGlobalSettings func to query global settings from the database
|
|
func (o *JsonDB) GetGlobalSettings() (model.GlobalSetting, error) {
|
|
settings := model.GlobalSetting{}
|
|
return settings, o.conn.Read("server", "global_settings", &settings)
|
|
}
|
|
|
|
// GetServer func to query Server settings from the database
|
|
func (o *JsonDB) GetServer() (model.Server, error) {
|
|
server := model.Server{}
|
|
// read server interface information
|
|
serverInterface := model.ServerInterface{}
|
|
if err := o.conn.Read("server", "interfaces", &serverInterface); err != nil {
|
|
return server, err
|
|
}
|
|
|
|
// read server key pair information
|
|
serverKeyPair := model.ServerKeypair{}
|
|
if err := o.conn.Read("server", "keypair", &serverKeyPair); err != nil {
|
|
return server, err
|
|
}
|
|
|
|
// create Server object and return
|
|
server.Interface = &serverInterface
|
|
server.KeyPair = &serverKeyPair
|
|
return server, nil
|
|
}
|
|
|
|
func (o *JsonDB) GetClients(hasQRCode bool) ([]model.ClientData, error) {
|
|
var clients []model.ClientData
|
|
|
|
// read all client json files in "clients" directory
|
|
records, err := o.conn.ReadAll("clients")
|
|
if err != nil {
|
|
return clients, err
|
|
}
|
|
|
|
// build the ClientData list
|
|
for _, f := range records {
|
|
client := model.Client{}
|
|
clientData := model.ClientData{}
|
|
|
|
// get client info
|
|
if err := json.Unmarshal(f, &client); err != nil {
|
|
return clients, fmt.Errorf("cannot decode client json structure: %v", err)
|
|
}
|
|
|
|
// generate client qrcode image in base64
|
|
if hasQRCode && client.PrivateKey != "" {
|
|
server, _ := o.GetServer()
|
|
globalSettings, _ := o.GetGlobalSettings()
|
|
|
|
png, err := qrcode.Encode(util.BuildClientConfig(client, server, globalSettings), qrcode.Medium, 256)
|
|
if err == nil {
|
|
clientData.QRCode = "data:image/png;base64," + base64.StdEncoding.EncodeToString(png)
|
|
} else {
|
|
fmt.Print("Cannot generate QR code: ", err)
|
|
}
|
|
}
|
|
|
|
// create the list of clients and their qrcode data
|
|
clientData.Client = &client
|
|
clients = append(clients, clientData)
|
|
}
|
|
|
|
return clients, nil
|
|
}
|
|
|
|
func (o *JsonDB) GetClientByID(clientID string, qrCodeSettings model.QRCodeSettings) (model.ClientData, error) {
|
|
client := model.Client{}
|
|
clientData := model.ClientData{}
|
|
|
|
// read client information
|
|
if err := o.conn.Read("clients", clientID, &client); err != nil {
|
|
return clientData, err
|
|
}
|
|
|
|
// generate client qrcode image in base64
|
|
if qrCodeSettings.Enabled && client.PrivateKey != "" {
|
|
server, _ := o.GetServer()
|
|
globalSettings, _ := o.GetGlobalSettings()
|
|
client := client
|
|
if !qrCodeSettings.IncludeDNS {
|
|
globalSettings.DNSServers = []string{}
|
|
}
|
|
if !qrCodeSettings.IncludeMTU {
|
|
globalSettings.MTU = 0
|
|
}
|
|
|
|
png, err := qrcode.Encode(util.BuildClientConfig(client, server, globalSettings), qrcode.Medium, 256)
|
|
if err == nil {
|
|
clientData.QRCode = "data:image/png;base64," + base64.StdEncoding.EncodeToString(png)
|
|
} else {
|
|
fmt.Print("Cannot generate QR code: ", err)
|
|
}
|
|
}
|
|
|
|
clientData.Client = &client
|
|
|
|
return clientData, nil
|
|
}
|
|
|
|
func (o *JsonDB) SaveClient(client model.Client) error {
|
|
clientPath := path.Join(path.Join(o.dbPath, "clients"), client.ID+".json")
|
|
output := o.conn.Write("clients", client.ID, client)
|
|
if output == nil {
|
|
if client.Enabled && len(client.TgUserid) > 0 {
|
|
if userid, err := strconv.ParseInt(client.TgUserid, 10, 64); err == nil {
|
|
util.UpdateTgToClientID(userid, client.ID)
|
|
}
|
|
} else {
|
|
util.RemoveTgToClientID(client.ID)
|
|
}
|
|
} else {
|
|
util.RemoveTgToClientID(client.ID)
|
|
}
|
|
err := util.ManagePerms(clientPath)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return output
|
|
}
|
|
|
|
func (o *JsonDB) DeleteClient(clientID string) error {
|
|
util.RemoveTgToClientID(clientID)
|
|
return o.conn.Delete("clients", clientID)
|
|
}
|
|
|
|
func (o *JsonDB) SaveServerInterface(serverInterface model.ServerInterface) error {
|
|
serverInterfacePath := path.Join(path.Join(o.dbPath, "server"), "interfaces.json")
|
|
output := o.conn.Write("server", "interfaces", serverInterface)
|
|
err := util.ManagePerms(serverInterfacePath)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return output
|
|
}
|
|
|
|
func (o *JsonDB) SaveServerKeyPair(serverKeyPair model.ServerKeypair) error {
|
|
serverKeyPairPath := path.Join(path.Join(o.dbPath, "server"), "keypair.json")
|
|
output := o.conn.Write("server", "keypair", serverKeyPair)
|
|
err := util.ManagePerms(serverKeyPairPath)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return output
|
|
}
|
|
|
|
func (o *JsonDB) SaveGlobalSettings(globalSettings model.GlobalSetting) error {
|
|
globalSettingsPath := path.Join(path.Join(o.dbPath, "server"), "global_settings.json")
|
|
output := o.conn.Write("server", "global_settings", globalSettings)
|
|
err := util.ManagePerms(globalSettingsPath)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return output
|
|
}
|
|
|
|
func (o *JsonDB) GetPath() string {
|
|
return o.dbPath
|
|
}
|
|
|
|
func (o *JsonDB) GetHashes() (model.ClientServerHashes, error) {
|
|
hashes := model.ClientServerHashes{}
|
|
return hashes, o.conn.Read("server", "hashes", &hashes)
|
|
}
|
|
|
|
func (o *JsonDB) SaveHashes(hashes model.ClientServerHashes) error {
|
|
hashesPath := path.Join(path.Join(o.dbPath, "server"), "hashes.json")
|
|
output := o.conn.Write("server", "hashes", hashes)
|
|
err := util.ManagePerms(hashesPath)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return output
|
|
}
|