From bc6f0f491f32dc90eed03c75c3528d78b4da7c42 Mon Sep 17 00:00:00 2001 From: Matthew Nickson Date: Wed, 16 Mar 2022 22:49:18 +0000 Subject: [PATCH] Added MySQL as a datastore The specific datastore backend to use can now be set by using command line options or by using environment variables. The default datastore backend is still jsondb but mysql can now also be used as a backend. Environment variables have also been added to control settings relevant to the database. SQL queries are made by directly accessing the database/sql API. TLS is also supported. Signed-off-by: Matthew Nickson --- go.mod | 1 + go.sum | 2 + main.go | 32 ++- store/mysqldb/mysqldb.go | 514 +++++++++++++++++++++++++++++++++++++++ templates/mysql.sql | 2 +- util/config.go | 1 + 6 files changed, 543 insertions(+), 9 deletions(-) create mode 100644 store/mysqldb/mysqldb.go diff --git a/go.mod b/go.mod index 918d704..a358cd8 100644 --- a/go.mod +++ b/go.mod @@ -6,6 +6,7 @@ require ( github.com/GeertJohan/go.rice v1.0.0 github.com/glendc/go-external-ip v0.0.0-20170425150139-139229dcdddd github.com/go-playground/universal-translator v0.17.0 // indirect + github.com/go-sql-driver/mysql v1.6.0 github.com/gorilla/sessions v1.2.0 github.com/jcelliott/lumber v0.0.0-20160324203708-dd349441af25 // indirect github.com/labstack/echo-contrib v0.9.0 diff --git a/go.sum b/go.sum index 206090f..6ebb75f 100644 --- a/go.sum +++ b/go.sum @@ -27,6 +27,8 @@ github.com/go-playground/locales v0.13.0 h1:HyWk6mgj5qFqCT5fjGBuRArbVDfE4hi8+e8c github.com/go-playground/locales v0.13.0/go.mod h1:taPMhCMXrRLJO55olJkUXHZBHCxTMfnGwq/HNwmWNS8= github.com/go-playground/universal-translator v0.17.0 h1:icxd5fm+REJzpZx7ZfpaD876Lmtgy7VtROAbHHXk8no= github.com/go-playground/universal-translator v0.17.0/go.mod h1:UkSxE5sNxxRwHyU+Scu5vgOQjsIJAF8j9muTVoKLVtA= +github.com/go-sql-driver/mysql v1.6.0 h1:BCTh4TKNUYmOmMUcQ3IipzF5prigylS7XXjEkfCHuOE= +github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= diff --git a/main.go b/main.go index a63148d..8a44b43 100644 --- a/main.go +++ b/main.go @@ -12,7 +12,9 @@ import ( "github.com/ngoduykhanh/wireguard-ui/emailer" "github.com/ngoduykhanh/wireguard-ui/handler" "github.com/ngoduykhanh/wireguard-ui/router" + "github.com/ngoduykhanh/wireguard-ui/store" "github.com/ngoduykhanh/wireguard-ui/store/jsondb" + "github.com/ngoduykhanh/wireguard-ui/store/mysqldb" "github.com/ngoduykhanh/wireguard-ui/util" ) @@ -41,6 +43,7 @@ var ( flagDBDatabase string = "wireguard-ui" flagDBUsername string flagDBPassword string + flagDBTLS string = "false" ) const ( @@ -67,12 +70,13 @@ func init() { flag.StringVar(&flagEmailFrom, "email-from", util.LookupEnvOrString("EMAIL_FROM_ADDRESS", flagEmailFrom), "'From' email address.") flag.StringVar(&flagEmailFromName, "email-from-name", util.LookupEnvOrString("EMAIL_FROM_NAME", flagEmailFromName), "'From' email name.") flag.StringVar(&flagSessionSecret, "session-secret", util.LookupEnvOrString("SESSION_SECRET", flagSessionSecret), "The key used to encrypt session cookies.") - flag.StringVar(&flagDBType, "db-type", util.LookupEnvOrString("DB_TYPE", flagDBType), "Type of database to use. One of: `jsondb`|`mysql`.") + flag.StringVar(&flagDBType, "db-type", util.LookupEnvOrString("DB_TYPE", flagDBType), "Type of database to use. [jsondb, mysql]") flag.StringVar(&flagDBHost, "db-host", util.LookupEnvOrString("DB_HOST", flagDBHost), "Database host") flag.IntVar(&flagDBPort, "db-port", util.LookupEnvOrInt("DB_PORT", flagDBPort), "Database port") flag.StringVar(&flagDBDatabase, "db-database", util.LookupEnvOrString("DB_DATABASE", flagDBDatabase), "Database name") flag.StringVar(&flagDBUsername, "db-username", util.LookupEnvOrString("DB_USERNAME", flagDBUsername), "Database username") flag.StringVar(&flagDBPassword, "db-password", util.LookupEnvOrString("DB_PASSWORD", flagDBPassword), "Database password") + flag.StringVar(&flagDBTLS, "db-tls", util.LookupEnvOrString("DB_TLS", flagDBTLS), "TLS mode. [true, false, skip-verify, preferred]") flag.Parse() // update runtime config @@ -94,6 +98,7 @@ func init() { util.DBDatabase = flagDBDatabase util.DBUsername = flagDBUsername util.DBPassword = flagDBPassword + util.DBTLS = flagDBTLS // print app information fmt.Println("Wireguard UI") @@ -107,18 +112,12 @@ func init() { //fmt.Println("Sendgrid key\t:", util.SendgridApiKey) fmt.Println("Email from\t:", util.EmailFrom) fmt.Println("Email from name\t:", util.EmailFromName) + fmt.Println("Datastore\t:", util.DBType) //fmt.Println("Session secret\t:", util.SessionSecret) } func main() { - db, err := jsondb.New("./db") - if err != nil { - panic(err) - } - if err := db.Init(); err != nil { - panic(err) - } // set app extra data extraData := make(map[string]string) extraData["appVersion"] = appVersion @@ -129,6 +128,23 @@ func main() { // rice file server for assets. "assets" is the folder where the files come from. assetHandler := http.FileServer(rice.MustFindBox("assets").HTTPBox()) + // Configure database + var db store.IStore + var err error + switch util.DBType { + case "jsondb": + db, err = jsondb.New("./db") + case "mysql": + db, err = mysqldb.New(util.DBUsername, util.DBPassword, util.DBHost, util.DBPort, util.DBDatabase, util.DBTLS, tmplBox) + } + + if err != nil { + panic(err) + } + if err := db.Init(); err != nil { + panic(err) + } + // register routes app := router.New(tmplBox, extraData, util.SessionSecret) diff --git a/store/mysqldb/mysqldb.go b/store/mysqldb/mysqldb.go new file mode 100644 index 0000000..58290b7 --- /dev/null +++ b/store/mysqldb/mysqldb.go @@ -0,0 +1,514 @@ +// Package mysqldb provides a MySQL storage backend for Wireguard UI +package mysqldb + +import ( + "database/sql" + "encoding/base64" + "fmt" + "strings" + "time" + + rice "github.com/GeertJohan/go.rice" + "github.com/go-sql-driver/mysql" + "github.com/skip2/go-qrcode" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + + "github.com/ngoduykhanh/wireguard-ui/model" + "github.com/ngoduykhanh/wireguard-ui/util" +) + +// MySQLDB - Representation of MySQL database backend +type MySQLDB struct { + conn *sql.DB + schema string + dbName string +} + +// String to split each item in array +var arrayDelimiter = "," + +// New returns pointer to MySQL database +func New(uname string, pwd string, host string, port int, database string, tls string, templateBox *rice.Box) (*MySQLDB, error) { + // Set connection config + config := mysql.NewConfig() + config.User = uname + config.Passwd = pwd + config.Net = "tcp" + config.Addr = fmt.Sprintf("%s:%d", host, port) + config.DBName = database + config.MultiStatements = true + config.ParseTime = true + config.TLSConfig = tls + + // Open connection pool + conn, err := sql.Open("mysql", config.FormatDSN()) + if err != nil { + return nil, err + } + conn.SetConnMaxLifetime(time.Minute * 3) + conn.SetMaxOpenConns(10) + conn.SetMaxIdleConns(10) + + // Test the connection + if err := conn.Ping(); err != nil { + return nil, err + } + + // Load DB schema + schema, err := templateBox.String("mysql.sql") + if err != nil { + return nil, err + } + + ans := MySQLDB{ + conn: conn, + schema: schema, + dbName: database, + } + return &ans, nil +} + +// Init initializes the database +func (o *MySQLDB) Init() error { + // Check if database is empty + var databaseEmpty int + err := o.conn.QueryRow( + "SELECT COUNT(DISTINCT `table_name`) FROM `information_schema`.`columns` WHERE `table_schema` = ?", + o.dbName, + ).Scan(&databaseEmpty) + if err != nil { + return err + } + + if !(databaseEmpty > 0) { + // Initialize database + // Tell the user what we're doing as this could take a while + fmt.Println("Initializing database") + + // Create database schema + if _, err := o.conn.Exec(o.schema); err != nil { + return err + } + + // servers's interface + if _, err := o.conn.Exec( + "INSERT INTO interfaces (addresses, listen_port, updated_at) VALUES (?, ?, ?);", + util.DefaultServerAddress, + util.DefaultServerPort, + time.Now().UTC(), + ); err != nil { + return err + } + + // server's keypair + key, err := wgtypes.GeneratePrivateKey() + if err != nil { + return err + } + + if _, err := o.conn.Exec( + "INSERT INTO keypair (private_key, public_key, updated_at) VALUES (?, ?, ?);", + key.String(), + key.PublicKey().String(), + time.Now().UTC(), + ); err != nil { + return err + } + + // global settings + publicInterface, err := util.GetPublicIP() + if err != nil { + return err + } + + if _, err := o.conn.Exec( + "INSERT INTO global_settings (endpoint_address, dns_servers, mtu, persistent_keepalive, config_file_path, updated_at) VALUES (?, ?, ?, ?, ?, ?);", + publicInterface.IPAddress, + util.DefaultDNS, + util.DefaultMTU, + util.DefaultPersistentKeepalive, + util.DefaultConfigFilePath, + time.Now().UTC(), + ); err != nil { + return err + } + + // user info + if _, err := o.conn.Exec( + "INSERT INTO users (username, password) VALUES (?, ?);", + util.GetCredVar(util.UsernameEnvVar, util.DefaultUsername), + util.GetCredVar(util.PasswordEnvVar, util.DefaultPassword), + ); err != nil { + return err + } + } + + return nil +} + +// GetUser func to query user info from the database +func (o *MySQLDB) GetUser() (model.User, error) { + user := model.User{} + row := o.conn.QueryRow("SELECT username, password FROM users;") + err := row.Scan( + &user.Username, + &user.Password, + ) + return user, err +} + +// GetGlobalSettings func to query global settings from the database +func (o *MySQLDB) GetGlobalSettings() (model.GlobalSetting, error) { + settings := model.GlobalSetting{} + var dnsServers string + + row := o.conn.QueryRow("SELECT endpoint_address, dns_servers, mtu, persistent_keepalive, config_file_path, updated_at FROM global_settings;") + // Can't use ScanStruct here as doesn't know how to handle + // dns_servers list. Instead we must populate struct it manually. + err := row.Scan( + &settings.EndpointAddress, + &dnsServers, + &settings.MTU, + &settings.PersistentKeepalive, + &settings.ConfigFilePath, + &settings.UpdatedAt, + ) + settings.DNSServers = strings.Split(dnsServers, arrayDelimiter) + return settings, err +} + +// GetServer func to query Server setting from the database +func (o *MySQLDB) GetServer() (model.Server, error) { + server := model.Server{} + + // Get interface + serverInterface := model.ServerInterface{} + var addresses string + + row := o.conn.QueryRow("SELECT addresses, listen_port, updated_at, post_up, post_down FROM interfaces;") + err := row.Scan( + &addresses, + &serverInterface.ListenPort, + &serverInterface.UpdatedAt, + &serverInterface.PostUp, + &serverInterface.PostDown, + ) + serverInterface.Addresses = strings.Split(addresses, arrayDelimiter) + if err != nil { + return server, err + } + + // Get keypair + serverKeyPair := model.ServerKeypair{} + if err := o.conn.QueryRow("SELECT private_key, public_key, updated_at FROM keypair;"). + Scan( + &serverKeyPair.PrivateKey, + &serverKeyPair.PublicKey, + &serverKeyPair.UpdatedAt, + ); err != nil { + return server, err + } + + // create Server object and return + server.Interface = &serverInterface + server.KeyPair = &serverKeyPair + return server, nil +} + +// GetClients func to query Client settings from the database +func (o *MySQLDB) GetClients(hasQRCode bool) ([]model.ClientData, error) { + var clients []model.ClientData + + rows, err := o.conn.Query("SELECT * FROM clients;") + if err != nil { + return clients, err + } + + for rows.Next() { + client := model.Client{} + clientData := model.ClientData{} + var allocatedIPs string + var allowedIPs string + var extraAllowedIPs string + + // Get client info + if err := rows.Scan( + &client.ID, + &client.PrivateKey, + &client.PublicKey, + &client.PresharedKey, + &client.Name, + &client.Email, + &allocatedIPs, + &allowedIPs, + &extraAllowedIPs, + &client.UseServerDNS, + &client.Enabled, + &client.CreatedAt, + &client.UpdatedAt, + ); err != nil { + return clients, err + } + client.AllocatedIPs = strings.Split(allocatedIPs, arrayDelimiter) + client.AllowedIPs = strings.Split(allowedIPs, arrayDelimiter) + client.ExtraAllowedIPs = strings.Split(extraAllowedIPs, arrayDelimiter) + + // generate client qrcode image in base64 + if hasQRCode && client.PrivateKey != "" { + server, _ := o.GetServer() + globalSettings, _ := o.GetGlobalSettings() + + png, err := qrcode.Encode(util.BuildClientConfig(client, server, globalSettings), qrcode.Medium, 256) + if err == nil { + clientData.QRCode = "data:image/png;base64," + base64.StdEncoding.EncodeToString([]byte(png)) + } else { + fmt.Print("Cannot generate QR code: ", err) + } + } + + // create the list of clients and their qrcode data + clientData.Client = &client + clients = append(clients, clientData) + } + + return clients, nil +} + +// GetClientByID func to query Clients by ID from the database +func (o *MySQLDB) GetClientByID(clientID string, hasQRCode bool) (model.ClientData, error) { + client := model.Client{} + clientData := model.ClientData{} + var allocatedIPs string + var allowedIPs string + var extraAllowedIPs string + + // read client info + if err := o.conn.QueryRow("SELECT * FROM clients WHERE id = ?;", clientID).Scan( + &client.ID, + &client.PrivateKey, + &client.PublicKey, + &client.PresharedKey, + &client.Name, + &client.Email, + &allocatedIPs, + &allowedIPs, + &extraAllowedIPs, + &client.UseServerDNS, + &client.Enabled, + &client.CreatedAt, + &client.UpdatedAt, + ); err != nil { + return clientData, err + } + client.AllocatedIPs = strings.Split(allocatedIPs, arrayDelimiter) + client.AllowedIPs = strings.Split(allowedIPs, arrayDelimiter) + client.ExtraAllowedIPs = strings.Split(extraAllowedIPs, arrayDelimiter) + + // generate client qrcode image in base64 + if hasQRCode && client.PrivateKey != "" { + server, _ := o.GetServer() + globalSettings, _ := o.GetGlobalSettings() + + png, err := qrcode.Encode(util.BuildClientConfig(client, server, globalSettings), qrcode.Medium, 256) + if err == nil { + clientData.QRCode = "data:image/png;base64," + base64.StdEncoding.EncodeToString([]byte(png)) + } else { + fmt.Print("Cannot generate QR code: ", err) + } + } + + clientData.Client = &client + + return clientData, nil +} + +// SaveClient func saves client to database +func (o *MySQLDB) SaveClient(client model.Client) error { + // If client doesn't exist, create a record, else update existing record + querySet := ` + SET + @id = ?, + @private_key = ?, + @public_key = ?, + @preshared_key = ?, + @name = ?, + @email = ?, + @allocated_ips = ?, + @allowed_ips = ?, + @extra_allowed_ips = ?, + @use_server_dns = ?, + @enabled = ?, + @created_at = ?, + @updated_at = ?;` + queryInsert := ` + INSERT INTO clients( + id, + private_key, + public_key, + preshared_key, + NAME, + email, + allocated_ips, + allowed_ips, + extra_allowed_ips, + use_server_dns, + enabled, + created_at, + updated_at + ) + VALUES( + @id, + @private_key, + @public_key, + @preshared_key, + @name, + @email, + @allocated_ips, + @allowed_ips, + @extra_allowed_ips, + @use_server_dns, + @enabled, + @created_at, + @updated_at + ) + ON DUPLICATE KEY + UPDATE + id = @id, + private_key = @private_key, + public_key = @public_key, + preshared_key = @preshared_key, + NAME = @name, + email = @email, + allocated_ips = @allocated_ips, + allowed_ips = @allowed_ips, + extra_allowed_ips = @extra_allowed_ips, + use_server_dns = @use_server_dns, + enabled = @enabled, + created_at = @created_at, + updated_at = @updated_at;` + + tx, err := o.conn.Begin() + if err != nil { + return err + } + // set values + if _, err := tx.Exec( + querySet, + client.ID, + client.PrivateKey, + client.PublicKey, + client.PresharedKey, + client.Name, + client.Email, + strings.Join(client.AllocatedIPs, arrayDelimiter), + strings.Join(client.AllowedIPs, arrayDelimiter), + strings.Join(client.ExtraAllowedIPs, arrayDelimiter), + client.UseServerDNS, + client.Enabled, + client.CreatedAt, + client.UpdatedAt, + ); err != nil { + if rbErr := tx.Rollback(); rbErr != nil { + return fmt.Errorf("tx err: %v, rb err: %v", err, rbErr) + } + + return err + } + + // insert or update row + if _, err := tx.Exec(queryInsert); err != nil { + if rbErr := tx.Rollback(); rbErr != nil { + return fmt.Errorf("tx err: %v, rb err: %v", err, rbErr) + } + + return err + } + + return tx.Commit() +} + +// DeleteClient func deletes client from the database +func (o *MySQLDB) DeleteClient(clientID string) error { + if _, err := o.conn.Exec("DELETE FROM clients WHERE id=?;", clientID); err != nil { + return err + } + + return nil +} + +// SaveServerInterface func saves a server interface to database +func (o *MySQLDB) SaveServerInterface(serverInterface model.ServerInterface) error { + // No need for ON DUPLICATE KEY UPDATE as only ever 1 record + query := ` + UPDATE + interfaces + SET + addresses = ?, + listen_port = ?, + updated_at = ?, + post_up = ?, + post_down = ? + WHERE + id = 1;` + + _, err := o.conn.Exec( + query, + strings.Join(serverInterface.Addresses, arrayDelimiter), + serverInterface.ListenPort, + serverInterface.UpdatedAt, + serverInterface.PostUp, + serverInterface.PostDown, + ) + + return err +} + +// SaveServerKeyPair func saves a server keypair to database +func (o *MySQLDB) SaveServerKeyPair(serverKeyPair model.ServerKeypair) error { + query := ` + UPDATE + keypair + SET + private_key = ?, + public_key = ?, + updated_at = ? + WHERE + id = 1;` + + _, err := o.conn.Exec( + query, + serverKeyPair.PrivateKey, + serverKeyPair.PublicKey, + serverKeyPair.UpdatedAt, + ) + + return err +} + +// SaveGlobalSettings saves global settings to database +func (o *MySQLDB) SaveGlobalSettings(globalSettings model.GlobalSetting) error { + query := ` + UPDATE + global_settings + SET + endpoint_address = ?, + dns_servers = ?, + mtu = ?, + persistent_keepalive = ?, + config_file_path = ?, + updated_at = ? + WHERE + id = 1;` + + _, err := o.conn.Exec( + query, + globalSettings.EndpointAddress, + strings.Join(globalSettings.DNSServers, arrayDelimiter), + globalSettings.MTU, + globalSettings.PersistentKeepalive, + globalSettings.ConfigFilePath, + globalSettings.UpdatedAt, + ) + + return err +} diff --git a/templates/mysql.sql b/templates/mysql.sql index 3af0e39..ffeb26f 100644 --- a/templates/mysql.sql +++ b/templates/mysql.sql @@ -77,4 +77,4 @@ ALTER TABLE `keypair` ALTER TABLE `users` MODIFY `id` INT(11) NOT NULL AUTO_INCREMENT; -COMMIT; \ No newline at end of file +COMMIT; diff --git a/util/config.go b/util/config.go index 8423d17..9dc4948 100644 --- a/util/config.go +++ b/util/config.go @@ -22,6 +22,7 @@ var ( DBDatabase string DBUsername string DBPassword string + DBTLS string ) const (