From 5559e028e256d28b9be78762184d136fec1be621 Mon Sep 17 00:00:00 2001 From: Giorgos Komninos Date: Sun, 15 Aug 2021 13:02:26 +0300 Subject: [PATCH] Creates a layer of abstraction for database so we can change storage if needed --- handler/routes.go | 191 ++++++++++++--------------- main.go | 47 ++++--- util/db.go => store/jsondb/jsondb.go | 173 ++++++++++-------------- store/store.go | 19 +++ util/config.go | 11 ++ 5 files changed, 210 insertions(+), 231 deletions(-) rename util/db.go => store/jsondb/jsondb.go (54%) create mode 100644 store/store.go diff --git a/handler/routes.go b/handler/routes.go index 22df3d7..3460540 100644 --- a/handler/routes.go +++ b/handler/routes.go @@ -9,16 +9,17 @@ import ( "time" rice "github.com/GeertJohan/go.rice" - "github.com/gorilla/sessions" "github.com/labstack/echo-contrib/session" "github.com/labstack/echo/v4" "github.com/labstack/gommon/log" - "github.com/ngoduykhanh/wireguard-ui/emailer" - "github.com/ngoduykhanh/wireguard-ui/model" - "github.com/ngoduykhanh/wireguard-ui/util" "github.com/rs/xid" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + + "github.com/ngoduykhanh/wireguard-ui/emailer" + "github.com/ngoduykhanh/wireguard-ui/model" + "github.com/ngoduykhanh/wireguard-ui/store" + "github.com/ngoduykhanh/wireguard-ui/util" ) // LoginPage handler @@ -29,12 +30,12 @@ func LoginPage() echo.HandlerFunc { } // Login for signing in handler -func Login() echo.HandlerFunc { +func Login(db store.IStore) echo.HandlerFunc { return func(c echo.Context) error { user := new(model.User) c.Bind(user) - dbuser, err := util.GetUser() + dbuser, err := db.GetUser() if err != nil { return c.JSON(http.StatusInternalServerError, jsonHTTPResponse{false, "Cannot query user from DB"}) } @@ -77,10 +78,10 @@ func Logout() echo.HandlerFunc { } // WireGuardClients handler -func WireGuardClients() echo.HandlerFunc { +func WireGuardClients(db store.IStore) echo.HandlerFunc { return func(c echo.Context) error { - clientDataList, err := util.GetClients(true) + clientDataList, err := db.GetClients(true) if err != nil { return c.JSON(http.StatusInternalServerError, jsonHTTPResponse{ false, fmt.Sprintf("Cannot get client list: %v", err), @@ -95,10 +96,10 @@ func WireGuardClients() echo.HandlerFunc { } // GetClients handler return a list of Wireguard client data -func GetClients() echo.HandlerFunc { +func GetClients(db store.IStore) echo.HandlerFunc { return func(c echo.Context) error { - clientDataList, err := util.GetClients(true) + clientDataList, err := db.GetClients(true) if err != nil { return c.JSON(http.StatusInternalServerError, jsonHTTPResponse{ false, fmt.Sprintf("Cannot get client list: %v", err), @@ -110,11 +111,11 @@ func GetClients() echo.HandlerFunc { } // GetClient handler return a of Wireguard client data -func GetClient() echo.HandlerFunc { +func GetClient(db store.IStore) echo.HandlerFunc { return func(c echo.Context) error { clientID := c.Param("id") - clientData, err := util.GetClientByID(clientID, true) + clientData, err := db.GetClientByID(clientID, true) if err != nil { return c.JSON(http.StatusNotFound, jsonHTTPResponse{false, "Client not found"}) } @@ -124,27 +125,22 @@ func GetClient() echo.HandlerFunc { } // NewClient handler -func NewClient() echo.HandlerFunc { +func NewClient(db store.IStore) echo.HandlerFunc { return func(c echo.Context) error { - client := new(model.Client) - c.Bind(client) - - db, err := util.DBConn() - if err != nil { - log.Error("Cannot initialize database: ", err) - return c.JSON(http.StatusInternalServerError, jsonHTTPResponse{false, "Cannot access database"}) - } + var client model.Client + c.Bind(&client) // read server information - serverInterface := model.ServerInterface{} - if err := db.Read("server", "interfaces", &serverInterface); err != nil { - log.Error("Cannot fetch server interface config from database: ", err) + server, err := db.GetServer() + if err != nil { + log.Error("Cannot fetch server from database: ", err) + return c.JSON(http.StatusInternalServerError, jsonHTTPResponse{false, err.Error()}) } // validate the input Allocation IPs allocatedIPs, err := util.GetAllocatedIPs("") - check, err := util.ValidateIPAllocation(serverInterface.Addresses, allocatedIPs, client.AllocatedIPs) + check, err := util.ValidateIPAllocation(server.Interface.Addresses, allocatedIPs, client.AllocatedIPs) if !check { return c.JSON(http.StatusBadRequest, jsonHTTPResponse{false, fmt.Sprintf("%s", err)}) } @@ -181,7 +177,11 @@ func NewClient() echo.HandlerFunc { client.UpdatedAt = client.CreatedAt // write client to the database - db.Write("clients", client.ID, client) + if err := db.SaveClient(client); err != nil { + return c.JSON(http.StatusInternalServerError, jsonHTTPResponse{ + false, err.Error(), + }) + } log.Infof("Created wireguard client: %v", client) return c.JSON(http.StatusOK, client) @@ -189,7 +189,7 @@ func NewClient() echo.HandlerFunc { } // EmailClient handler to sent the configuration via email -func EmailClient(mailer emailer.Emailer, emailSubject, emailContent string) echo.HandlerFunc { +func EmailClient(db store.IStore, mailer emailer.Emailer, emailSubject, emailContent string) echo.HandlerFunc { type clientIdEmailPayload struct { ID string `json:"id"` Email string `json:"email"` @@ -200,15 +200,15 @@ func EmailClient(mailer emailer.Emailer, emailSubject, emailContent string) echo c.Bind(&payload) // TODO validate email - clientData, err := util.GetClientByID(payload.ID, true) + clientData, err := db.GetClientByID(payload.ID, true) if err != nil { log.Errorf("Cannot generate client id %s config file for downloading: %v", payload.ID, err) return c.JSON(http.StatusNotFound, jsonHTTPResponse{false, "Client not found"}) } // build config - server, _ := util.GetServer() - globalSettings, _ := util.GetGlobalSettings() + server, _ := db.GetServer() + globalSettings, _ := db.GetGlobalSettings() config := util.BuildClientConfig(*clientData.Client, server, globalSettings) cfg_att := emailer.Attachment{"wg0.conf", []byte(config)} @@ -233,36 +233,28 @@ func EmailClient(mailer emailer.Emailer, emailSubject, emailContent string) echo } // UpdateClient handler to update client information -func UpdateClient() echo.HandlerFunc { +func UpdateClient(db store.IStore) echo.HandlerFunc { return func(c echo.Context) error { - _client := new(model.Client) - c.Bind(_client) - - db, err := util.DBConn() - if err != nil { - log.Error("Cannot initialize database: ", err) - return c.JSON(http.StatusInternalServerError, jsonHTTPResponse{false, "Cannot access database"}) - } + var _client model.Client + c.Bind(&_client) // validate client existence - client := model.Client{} - if err := db.Read("clients", _client.ID, &client); err != nil { + clientData, err := db.GetClientByID(_client.ID, false) + if err != nil { return c.JSON(http.StatusNotFound, jsonHTTPResponse{false, "Client not found"}) } - // read server information - serverInterface := model.ServerInterface{} - if err := db.Read("server", "interfaces", &serverInterface); err != nil { - log.Error("Cannot fetch server interface config from database: ", err) + server, err := db.GetServer() + if err != nil { return c.JSON(http.StatusBadRequest, jsonHTTPResponse{ false, fmt.Sprintf("Cannot fetch server config: %s", err), }) } - + client := *clientData.Client // validate the input Allocation IPs allocatedIPs, err := util.GetAllocatedIPs(client.ID) - check, err := util.ValidateIPAllocation(serverInterface.Addresses, allocatedIPs, _client.AllocatedIPs) + check, err := util.ValidateIPAllocation(server.Interface.Addresses, allocatedIPs, _client.AllocatedIPs) if !check { return c.JSON(http.StatusBadRequest, jsonHTTPResponse{false, fmt.Sprintf("%s", err)}) } @@ -283,7 +275,9 @@ func UpdateClient() echo.HandlerFunc { client.UpdatedAt = time.Now().UTC() // write to the database - db.Write("clients", client.ID, &client) + if err := db.SaveClient(client); err != nil { + return c.JSON(http.StatusInternalServerError, jsonHTTPResponse{false, err.Error()}) + } log.Infof("Updated client information successfully => %v", client) return c.JSON(http.StatusOK, jsonHTTPResponse{true, "Updated client successfully"}) @@ -291,7 +285,7 @@ func UpdateClient() echo.HandlerFunc { } // SetClientStatus handler to enable / disable a client -func SetClientStatus() echo.HandlerFunc { +func SetClientStatus(db store.IStore) echo.HandlerFunc { return func(c echo.Context) error { data := make(map[string]interface{}) @@ -304,19 +298,17 @@ func SetClientStatus() echo.HandlerFunc { clientID := data["id"].(string) status := data["status"].(bool) - db, err := util.DBConn() + clientdata, err := db.GetClientByID(clientID, false) if err != nil { - log.Error("Cannot initialize database: ", err) - return c.JSON(http.StatusInternalServerError, jsonHTTPResponse{false, "Cannot access database"}) + return c.JSON(http.StatusNotFound, jsonHTTPResponse{false, err.Error()}) } - client := model.Client{} - if err := db.Read("clients", clientID, &client); err != nil { - log.Error("Cannot get client from database: ", err) - } + client := *clientdata.Client client.Enabled = status - db.Write("clients", clientID, &client) + if err := db.SaveClient(client); err != nil { + return c.JSON(http.StatusInternalServerError, jsonHTTPResponse{false, err.Error()}) + } log.Infof("Changed client %s enabled status to %v", client.ID, status) return c.JSON(http.StatusOK, jsonHTTPResponse{true, "Changed client status successfully"}) @@ -324,22 +316,28 @@ func SetClientStatus() echo.HandlerFunc { } // DownloadClient handler -func DownloadClient() echo.HandlerFunc { +func DownloadClient(db store.IStore) echo.HandlerFunc { return func(c echo.Context) error { clientID := c.QueryParam("clientid") if clientID == "" { return c.JSON(http.StatusNotFound, jsonHTTPResponse{false, "Missing clientid parameter"}) } - clientData, err := util.GetClientByID(clientID, false) + clientData, err := db.GetClientByID(clientID, false) if err != nil { log.Errorf("Cannot generate client id %s config file for downloading: %v", clientID, err) return c.JSON(http.StatusNotFound, jsonHTTPResponse{false, "Client not found"}) } // build config - server, _ := util.GetServer() - globalSettings, _ := util.GetGlobalSettings() + server, err := db.GetServer() + if err != nil { + return c.JSON(http.StatusInternalServerError, jsonHTTPResponse{false, err.Error()}) + } + globalSettings, err := db.GetGlobalSettings() + if err != nil { + return c.JSON(http.StatusInternalServerError, jsonHTTPResponse{false, err.Error()}) + } config := util.BuildClientConfig(*clientData.Client, server, globalSettings) // create io reader from string @@ -352,20 +350,15 @@ func DownloadClient() echo.HandlerFunc { } // RemoveClient handler -func RemoveClient() echo.HandlerFunc { +func RemoveClient(db store.IStore) echo.HandlerFunc { return func(c echo.Context) error { client := new(model.Client) c.Bind(client) // delete client from database - db, err := util.DBConn() - if err != nil { - log.Error("Cannot initialize database: ", err) - return c.JSON(http.StatusInternalServerError, jsonHTTPResponse{false, "Cannot access database"}) - } - if err := db.Delete("clients", client.ID); err != nil { + if err := db.DeleteClient(client.ID); err != nil { log.Error("Cannot delete wireguard client: ", err) return c.JSON(http.StatusInternalServerError, jsonHTTPResponse{false, "Cannot delete client from database"}) } @@ -376,10 +369,10 @@ func RemoveClient() echo.HandlerFunc { } // WireGuardServer handler -func WireGuardServer() echo.HandlerFunc { +func WireGuardServer(db store.IStore) echo.HandlerFunc { return func(c echo.Context) error { - server, err := util.GetServer() + server, err := db.GetServer() if err != nil { log.Error("Cannot get server config: ", err) } @@ -393,11 +386,11 @@ func WireGuardServer() echo.HandlerFunc { } // WireGuardServerInterfaces handler -func WireGuardServerInterfaces() echo.HandlerFunc { +func WireGuardServerInterfaces(db store.IStore) echo.HandlerFunc { return func(c echo.Context) error { - serverInterface := new(model.ServerInterface) - c.Bind(serverInterface) + var serverInterface model.ServerInterface + c.Bind(&serverInterface) // validate the input addresses if util.ValidateServerAddresses(serverInterface.Addresses) == false { @@ -408,13 +401,10 @@ func WireGuardServerInterfaces() echo.HandlerFunc { serverInterface.UpdatedAt = time.Now().UTC() // write config to the database - db, err := util.DBConn() - if err != nil { - log.Error("Cannot initialize database: ", err) - return c.JSON(http.StatusInternalServerError, jsonHTTPResponse{false, "Cannot access database"}) - } - db.Write("server", "interfaces", serverInterface) + if err := db.SaveServerInterface(serverInterface); err != nil { + return c.JSON(http.StatusInternalServerError, jsonHTTPResponse{false, "Interface IP address must be in CIDR format"}) + } log.Infof("Updated wireguard server interfaces settings: %v", serverInterface) return c.JSON(http.StatusOK, jsonHTTPResponse{true, "Updated interface addresses successfully"}) @@ -422,7 +412,7 @@ func WireGuardServerInterfaces() echo.HandlerFunc { } // WireGuardServerKeyPair handler to generate private and public keys -func WireGuardServerKeyPair() echo.HandlerFunc { +func WireGuardServerKeyPair(db store.IStore) echo.HandlerFunc { return func(c echo.Context) error { // gen Wireguard key pair @@ -432,19 +422,14 @@ func WireGuardServerKeyPair() echo.HandlerFunc { return c.JSON(http.StatusInternalServerError, jsonHTTPResponse{false, "Cannot generate Wireguard key pair"}) } - serverKeyPair := new(model.ServerKeypair) + var serverKeyPair model.ServerKeypair serverKeyPair.PrivateKey = key.String() serverKeyPair.PublicKey = key.PublicKey().String() serverKeyPair.UpdatedAt = time.Now().UTC() - // write config to the database - db, err := util.DBConn() - if err != nil { - log.Error("Cannot initialize database: ", err) - return c.JSON(http.StatusInternalServerError, jsonHTTPResponse{false, "Cannot access database"}) + if err := db.SaveServerKeyPair(serverKeyPair); err != nil { + return c.JSON(http.StatusInternalServerError, jsonHTTPResponse{false, "Cannot generate Wireguard key pair"}) } - - db.Write("server", "keypair", serverKeyPair) log.Infof("Updated wireguard server interfaces settings: %v", serverKeyPair) return c.JSON(http.StatusOK, serverKeyPair) @@ -452,10 +437,10 @@ func WireGuardServerKeyPair() echo.HandlerFunc { } // GlobalSettings handler -func GlobalSettings() echo.HandlerFunc { +func GlobalSettings(db store.IStore) echo.HandlerFunc { return func(c echo.Context) error { - globalSettings, err := util.GetGlobalSettings() + globalSettings, err := db.GetGlobalSettings() if err != nil { log.Error("Cannot get global settings: ", err) } @@ -468,11 +453,11 @@ func GlobalSettings() echo.HandlerFunc { } // GlobalSettingSubmit handler to update the global settings -func GlobalSettingSubmit() echo.HandlerFunc { +func GlobalSettingSubmit(db store.IStore) echo.HandlerFunc { return func(c echo.Context) error { - globalSettings := new(model.GlobalSetting) - c.Bind(globalSettings) + var globalSettings model.GlobalSetting + c.Bind(&globalSettings) // validate the input dns server list if util.ValidateIPAddressList(globalSettings.DNSServers) == false { @@ -483,13 +468,10 @@ func GlobalSettingSubmit() echo.HandlerFunc { globalSettings.UpdatedAt = time.Now().UTC() // write config to the database - db, err := util.DBConn() - if err != nil { - log.Error("Cannot initialize database: ", err) - return c.JSON(http.StatusInternalServerError, jsonHTTPResponse{false, "Cannot access database"}) + if err := db.SaveGlobalSettings(globalSettings); err != nil { + return c.JSON(http.StatusInternalServerError, jsonHTTPResponse{false, "Cannot generate Wireguard key pair"}) } - db.Write("server", "global_settings", globalSettings) log.Infof("Updated global settings: %v", globalSettings) return c.JSON(http.StatusOK, jsonHTTPResponse{true, "Updated global settings successfully"}) @@ -521,12 +503,13 @@ func MachineIPAddresses() echo.HandlerFunc { } // SuggestIPAllocation handler to get the list of ip address for client -func SuggestIPAllocation() echo.HandlerFunc { +func SuggestIPAllocation(db store.IStore) echo.HandlerFunc { return func(c echo.Context) error { - server, err := util.GetServer() + server, err := db.GetServer() if err != nil { log.Error("Cannot fetch server config from database: ", err) + return c.JSON(http.StatusBadRequest, jsonHTTPResponse{false, err.Error()}) } // return the list of suggestedIPs @@ -557,22 +540,22 @@ func SuggestIPAllocation() echo.HandlerFunc { } // ApplyServerConfig handler to write config file and restart Wireguard server -func ApplyServerConfig(tmplBox *rice.Box) echo.HandlerFunc { +func ApplyServerConfig(db store.IStore, tmplBox *rice.Box) echo.HandlerFunc { return func(c echo.Context) error { - server, err := util.GetServer() + server, err := db.GetServer() if err != nil { log.Error("Cannot get server config: ", err) return c.JSON(http.StatusInternalServerError, jsonHTTPResponse{false, "Cannot get server config"}) } - clients, err := util.GetClients(false) + clients, err := db.GetClients(false) if err != nil { log.Error("Cannot get client config: ", err) return c.JSON(http.StatusInternalServerError, jsonHTTPResponse{false, "Cannot get client config"}) } - settings, err := util.GetGlobalSettings() + settings, err := db.GetGlobalSettings() if err != nil { log.Error("Cannot get global settings: ", err) return c.JSON(http.StatusInternalServerError, jsonHTTPResponse{false, "Cannot get global settings"}) diff --git a/main.go b/main.go index 2e9e5cd..ea1720b 100644 --- a/main.go +++ b/main.go @@ -13,6 +13,7 @@ import ( "github.com/ngoduykhanh/wireguard-ui/emailer" "github.com/ngoduykhanh/wireguard-ui/handler" "github.com/ngoduykhanh/wireguard-ui/router" + "github.com/ngoduykhanh/wireguard-ui/store/jsondb" "github.com/ngoduykhanh/wireguard-ui/util" ) @@ -57,14 +58,16 @@ func init() { fmt.Println("Authentication\t:", !util.DisableLogin) fmt.Println("Bind address\t:", util.BindAddress) - // initialize DB - err := util.InitDB() - if err != nil { - fmt.Print("Cannot init database: ", err) - } } func main() { + db, err := jsondb.New("./db") + if err != nil { + panic(err) + } + if err := db.Init(); err != nil { + panic(err) + } // set app extra data extraData := make(map[string]string) extraData["appVersion"] = appVersion @@ -78,32 +81,32 @@ func main() { // register routes app := router.New(tmplBox, extraData, util.SessionSecret) - app.GET("/", handler.WireGuardClients(), handler.ValidSession) + app.GET("/", handler.WireGuardClients(db), handler.ValidSession) if !util.DisableLogin { app.GET("/login", handler.LoginPage()) - app.POST("/login", handler.Login()) + app.POST("/login", handler.Login(db)) } sendmail := emailer.NewSendgridApiMail(util.SendgridApiKey, util.EmailFromName, util.EmailFrom) app.GET("/logout", handler.Logout(), handler.ValidSession) - app.POST("/new-client", handler.NewClient(), handler.ValidSession) - app.POST("/update-client", handler.UpdateClient(), handler.ValidSession) - app.POST("/email-client", handler.EmailClient(sendmail, defaultEmailSubject, defaultEmailContent), handler.ValidSession) - app.POST("/client/set-status", handler.SetClientStatus(), handler.ValidSession) - app.POST("/remove-client", handler.RemoveClient(), handler.ValidSession) - app.GET("/download", handler.DownloadClient(), handler.ValidSession) - app.GET("/wg-server", handler.WireGuardServer(), handler.ValidSession) - app.POST("wg-server/interfaces", handler.WireGuardServerInterfaces(), handler.ValidSession) - app.POST("wg-server/keypair", handler.WireGuardServerKeyPair(), handler.ValidSession) - app.GET("/global-settings", handler.GlobalSettings(), handler.ValidSession) - app.POST("/global-settings", handler.GlobalSettingSubmit(), handler.ValidSession) - app.GET("/api/clients", handler.GetClients(), handler.ValidSession) - app.GET("/api/client/:id", handler.GetClient(), handler.ValidSession) + app.POST("/new-client", handler.NewClient(db), handler.ValidSession) + app.POST("/update-client", handler.UpdateClient(db), handler.ValidSession) + app.POST("/email-client", handler.EmailClient(db, sendmail, defaultEmailSubject, defaultEmailContent), handler.ValidSession) + app.POST("/client/set-status", handler.SetClientStatus(db), handler.ValidSession) + app.POST("/remove-client", handler.RemoveClient(db), handler.ValidSession) + app.GET("/download", handler.DownloadClient(db), handler.ValidSession) + app.GET("/wg-server", handler.WireGuardServer(db), handler.ValidSession) + app.POST("wg-server/interfaces", handler.WireGuardServerInterfaces(db), handler.ValidSession) + app.POST("wg-server/keypair", handler.WireGuardServerKeyPair(db), handler.ValidSession) + app.GET("/global-settings", handler.GlobalSettings(db), handler.ValidSession) + app.POST("/global-settings", handler.GlobalSettingSubmit(db), handler.ValidSession) + app.GET("/api/clients", handler.GetClients(db), handler.ValidSession) + app.GET("/api/client/:id", handler.GetClient(db), handler.ValidSession) app.GET("/api/machine-ips", handler.MachineIPAddresses(), handler.ValidSession) - app.GET("/api/suggest-client-ips", handler.SuggestIPAllocation(), handler.ValidSession) - app.GET("/api/apply-wg-config", handler.ApplyServerConfig(tmplBox), handler.ValidSession) + app.GET("/api/suggest-client-ips", handler.SuggestIPAllocation(db), handler.ValidSession) + app.GET("/api/apply-wg-config", handler.ApplyServerConfig(db, tmplBox), handler.ValidSession) // servers other static files app.GET("/static/*", echo.WrapHandler(http.StripPrefix("/static/", assetHandler))) diff --git a/util/db.go b/store/jsondb/jsondb.go similarity index 54% rename from util/db.go rename to store/jsondb/jsondb.go index ded1306..b6ca74e 100644 --- a/util/db.go +++ b/store/jsondb/jsondb.go @@ -1,4 +1,4 @@ -package util +package jsondb import ( "encoding/base64" @@ -9,39 +9,38 @@ import ( "time" "github.com/ngoduykhanh/wireguard-ui/model" + "github.com/ngoduykhanh/wireguard-ui/util" "github.com/sdomino/scribble" "github.com/skip2/go-qrcode" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" ) -const dbPath = "./db" -const defaultUsername = "admin" -const defaultPassword = "admin" -const defaultServerAddress = "10.252.1.0/24" -const defaultServerPort = 51820 -const defaultDNS = "1.1.1.1" -const defaultMTU = 1450 -const defaultPersistentKeepalive = 15 -const defaultConfigFilePath = "/etc/wireguard/wg0.conf" +type JsonDB struct { + conn *scribble.Driver + dbPath string +} -// DBConn to initialize the database connection -func DBConn() (*scribble.Driver, error) { - db, err := scribble.New(dbPath, nil) +// New returns a new pointer JsonDB +func New(dbPath string) (*JsonDB, error) { + conn, err := scribble.New(dbPath, nil) if err != nil { return nil, err } - return db, nil + ans := JsonDB{ + conn: conn, + dbPath: dbPath, + } + return &ans, nil + } -// InitDB to create the default database -func InitDB() error { - var clientPath string = path.Join(dbPath, "clients") - var serverPath string = path.Join(dbPath, "server") +func (o *JsonDB) Init() error { + var clientPath string = path.Join(o.dbPath, "clients") + var serverPath string = path.Join(o.dbPath, "server") var serverInterfacePath string = path.Join(serverPath, "interfaces.json") var serverKeyPairPath string = path.Join(serverPath, "keypair.json") var globalSettingPath string = path.Join(serverPath, "global_settings.json") var userPath string = path.Join(serverPath, "users.json") - // create directories if they do not exist if _, err := os.Stat(clientPath); os.IsNotExist(err) { os.MkdirAll(clientPath, os.ModePerm) @@ -52,24 +51,15 @@ func InitDB() error { // server's interface if _, err := os.Stat(serverInterfacePath); os.IsNotExist(err) { - db, err := DBConn() - if err != nil { - return err - } - serverInterface := new(model.ServerInterface) - serverInterface.Addresses = []string{defaultServerAddress} - serverInterface.ListenPort = defaultServerPort + serverInterface.Addresses = []string{util.DefaultServerAddress} + serverInterface.ListenPort = util.DefaultServerPort serverInterface.UpdatedAt = time.Now().UTC() - db.Write("server", "interfaces", serverInterface) + o.conn.Write("server", "interfaces", serverInterface) } // server's key pair if _, err := os.Stat(serverKeyPairPath); os.IsNotExist(err) { - db, err := DBConn() - if err != nil { - return err - } key, err := wgtypes.GeneratePrivateKey() if err != nil { @@ -79,97 +69,62 @@ func InitDB() error { serverKeyPair.PrivateKey = key.String() serverKeyPair.PublicKey = key.PublicKey().String() serverKeyPair.UpdatedAt = time.Now().UTC() - db.Write("server", "keypair", serverKeyPair) + o.conn.Write("server", "keypair", serverKeyPair) } // global settings if _, err := os.Stat(globalSettingPath); os.IsNotExist(err) { - db, err := DBConn() - if err != nil { - return err - } - publicInterface, err := GetPublicIP() + publicInterface, err := util.GetPublicIP() if err != nil { return err } globalSetting := new(model.GlobalSetting) globalSetting.EndpointAddress = publicInterface.IPAddress - globalSetting.DNSServers = []string{defaultDNS} - globalSetting.MTU = defaultMTU - globalSetting.PersistentKeepalive = defaultPersistentKeepalive - globalSetting.ConfigFilePath = defaultConfigFilePath + globalSetting.DNSServers = []string{util.DefaultDNS} + globalSetting.MTU = util.DefaultMTU + globalSetting.PersistentKeepalive = util.DefaultPersistentKeepalive + globalSetting.ConfigFilePath = util.DefaultConfigFilePath globalSetting.UpdatedAt = time.Now().UTC() - db.Write("server", "global_settings", globalSetting) + o.conn.Write("server", "global_settings", globalSetting) } // user info if _, err := os.Stat(userPath); os.IsNotExist(err) { - db, err := DBConn() - if err != nil { - return err - } - user := new(model.User) - user.Username = defaultUsername - user.Password = defaultPassword - db.Write("server", "users", user) + user.Username = util.DefaultUsername + user.Password = util.DefaultPassword + o.conn.Write("server", "users", user) } return nil } // GetUser func to query user info from the database -func GetUser() (model.User, error) { +func (o *JsonDB) GetUser() (model.User, error) { user := model.User{} - - db, err := DBConn() - if err != nil { - return user, err - } - - if err := db.Read("server", "users", &user); err != nil { - return user, err - } - - return user, nil + return user, o.conn.Read("server", "users", &user) } // GetGlobalSettings func to query global settings from the database -func GetGlobalSettings() (model.GlobalSetting, error) { +func (o *JsonDB) GetGlobalSettings() (model.GlobalSetting, error) { settings := model.GlobalSetting{} - - db, err := DBConn() - if err != nil { - return settings, err - } - - if err := db.Read("server", "global_settings", &settings); err != nil { - return settings, err - } - - return settings, nil + return settings, o.conn.Read("server", "global_settings", &settings) } // GetServer func to query Server setting from the database -func GetServer() (model.Server, error) { +func (o *JsonDB) GetServer() (model.Server, error) { server := model.Server{} - - db, err := DBConn() - if err != nil { - return server, err - } - // read server interface information serverInterface := model.ServerInterface{} - if err := db.Read("server", "interfaces", &serverInterface); err != nil { + if err := o.conn.Read("server", "interfaces", &serverInterface); err != nil { return server, err } // read server key pair information serverKeyPair := model.ServerKeypair{} - if err := db.Read("server", "keypair", &serverKeyPair); err != nil { + if err := o.conn.Read("server", "keypair", &serverKeyPair); err != nil { return server, err } @@ -179,17 +134,11 @@ func GetServer() (model.Server, error) { return server, nil } -// GetClients to get all clients from the database -func GetClients(hasQRCode bool) ([]model.ClientData, error) { +func (o *JsonDB) GetClients(hasQRCode bool) ([]model.ClientData, error) { var clients []model.ClientData - db, err := DBConn() - if err != nil { - return clients, err - } - // read all client json file in "clients" directory - records, err := db.ReadAll("clients") + records, err := o.conn.ReadAll("clients") if err != nil { return clients, err } @@ -206,10 +155,10 @@ func GetClients(hasQRCode bool) ([]model.ClientData, error) { // generate client qrcode image in base64 if hasQRCode { - server, _ := GetServer() - globalSettings, _ := GetGlobalSettings() + server, _ := o.GetServer() + globalSettings, _ := o.GetGlobalSettings() - png, err := qrcode.Encode(BuildClientConfig(client, server, globalSettings), qrcode.Medium, 256) + png, err := qrcode.Encode(util.BuildClientConfig(client, server, globalSettings), qrcode.Medium, 256) if err == nil { clientData.QRCode = "data:image/png;base64," + base64.StdEncoding.EncodeToString([]byte(png)) } else { @@ -225,27 +174,21 @@ func GetClients(hasQRCode bool) ([]model.ClientData, error) { return clients, nil } -// GetClientByID func to query a client from the database -func GetClientByID(clientID string, hasQRCode bool) (model.ClientData, error) { +func (o *JsonDB) GetClientByID(clientID string, hasQRCode bool) (model.ClientData, error) { client := model.Client{} clientData := model.ClientData{} - db, err := DBConn() - if err != nil { - return clientData, err - } - // read client information - if err := db.Read("clients", clientID, &client); err != nil { + if err := o.conn.Read("clients", clientID, &client); err != nil { return clientData, err } // generate client qrcode image in base64 if hasQRCode { - server, _ := GetServer() - globalSettings, _ := GetGlobalSettings() + server, _ := o.GetServer() + globalSettings, _ := o.GetGlobalSettings() - png, err := qrcode.Encode(BuildClientConfig(client, server, globalSettings), qrcode.Medium, 256) + png, err := qrcode.Encode(util.BuildClientConfig(client, server, globalSettings), qrcode.Medium, 256) if err == nil { clientData.QRCode = "data:image/png;base64," + base64.StdEncoding.EncodeToString([]byte(png)) } else { @@ -257,3 +200,23 @@ func GetClientByID(clientID string, hasQRCode bool) (model.ClientData, error) { return clientData, nil } + +func (o *JsonDB) SaveClient(client model.Client) error { + return o.conn.Write("clients", client.ID, client) +} + +func (o *JsonDB) DeleteClient(clientID string) error { + return o.conn.Delete("clients", clientID) +} + +func (o *JsonDB) SaveServerInterface(serverInterface model.ServerInterface) error { + return o.conn.Write("server", "interfaces", serverInterface) +} + +func (o *JsonDB) SaveServerKeyPair(serverKeyPair model.ServerKeypair) error { + return o.conn.Write("server", "keypair", serverKeyPair) +} + +func (o *JsonDB) SaveGlobalSettings(globalSettings model.GlobalSetting) error { + return o.conn.Write("server", "global_settings", globalSettings) +} diff --git a/store/store.go b/store/store.go new file mode 100644 index 0000000..8750cf2 --- /dev/null +++ b/store/store.go @@ -0,0 +1,19 @@ +package store + +import ( + "github.com/ngoduykhanh/wireguard-ui/model" +) + +type IStore interface { + Init() error + GetUser() (model.User, error) + GetGlobalSettings() (model.GlobalSetting, error) + GetServer() (model.Server, error) + GetClients(hasQRCode bool) ([]model.ClientData, error) + GetClientByID(clientID string, hasQRCode bool) (model.ClientData, error) + SaveClient(client model.Client) error + DeleteClient(clientID string) error + SaveServerInterface(serverInterface model.ServerInterface) error + SaveServerKeyPair(serverKeyPair model.ServerKeypair) error + SaveGlobalSettings(globalSettings model.GlobalSetting) error +} diff --git a/util/config.go b/util/config.go index 5f86ea9..a2e0297 100644 --- a/util/config.go +++ b/util/config.go @@ -11,3 +11,14 @@ var ( EmailContent string SessionSecret []byte ) + +const ( + DefaultUsername = "admin" + DefaultPassword = "admin" + DefaultServerAddress = "10.252.1.0/24" + DefaultServerPort = 51820 + DefaultDNS = "1.1.1.1" + DefaultMTU = 1450 + DefaultPersistentKeepalive = 15 + DefaultConfigFilePath = "/etc/wireguard/wg0.conf" +)