Added MySQL as a datastore

The specific datastore backend to use can now be set by using
command line options or by using environment variables. The default
datastore backend is still jsondb but mysql can now also be used as a
backend. Environment variables have also been added to control settings
relevant to the database. SQL queries are made by directly accessing the
database/sql API. TLS is also supported.

Signed-off-by: Matthew Nickson <mnickson@sidingsmedia.com>
This commit is contained in:
Matthew Nickson 2022-03-16 22:49:18 +00:00
parent 4a28486f73
commit bc6f0f491f
No known key found for this signature in database
GPG key ID: BF229DCFD4748E05
6 changed files with 543 additions and 9 deletions

1
go.mod
View file

@ -6,6 +6,7 @@ require (
github.com/GeertJohan/go.rice v1.0.0 github.com/GeertJohan/go.rice v1.0.0
github.com/glendc/go-external-ip v0.0.0-20170425150139-139229dcdddd github.com/glendc/go-external-ip v0.0.0-20170425150139-139229dcdddd
github.com/go-playground/universal-translator v0.17.0 // indirect github.com/go-playground/universal-translator v0.17.0 // indirect
github.com/go-sql-driver/mysql v1.6.0
github.com/gorilla/sessions v1.2.0 github.com/gorilla/sessions v1.2.0
github.com/jcelliott/lumber v0.0.0-20160324203708-dd349441af25 // indirect github.com/jcelliott/lumber v0.0.0-20160324203708-dd349441af25 // indirect
github.com/labstack/echo-contrib v0.9.0 github.com/labstack/echo-contrib v0.9.0

2
go.sum
View file

@ -27,6 +27,8 @@ github.com/go-playground/locales v0.13.0 h1:HyWk6mgj5qFqCT5fjGBuRArbVDfE4hi8+e8c
github.com/go-playground/locales v0.13.0/go.mod h1:taPMhCMXrRLJO55olJkUXHZBHCxTMfnGwq/HNwmWNS8= github.com/go-playground/locales v0.13.0/go.mod h1:taPMhCMXrRLJO55olJkUXHZBHCxTMfnGwq/HNwmWNS8=
github.com/go-playground/universal-translator v0.17.0 h1:icxd5fm+REJzpZx7ZfpaD876Lmtgy7VtROAbHHXk8no= github.com/go-playground/universal-translator v0.17.0 h1:icxd5fm+REJzpZx7ZfpaD876Lmtgy7VtROAbHHXk8no=
github.com/go-playground/universal-translator v0.17.0/go.mod h1:UkSxE5sNxxRwHyU+Scu5vgOQjsIJAF8j9muTVoKLVtA= github.com/go-playground/universal-translator v0.17.0/go.mod h1:UkSxE5sNxxRwHyU+Scu5vgOQjsIJAF8j9muTVoKLVtA=
github.com/go-sql-driver/mysql v1.6.0 h1:BCTh4TKNUYmOmMUcQ3IipzF5prigylS7XXjEkfCHuOE=
github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg=
github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY=
github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ=
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=

32
main.go
View file

@ -12,7 +12,9 @@ import (
"github.com/ngoduykhanh/wireguard-ui/emailer" "github.com/ngoduykhanh/wireguard-ui/emailer"
"github.com/ngoduykhanh/wireguard-ui/handler" "github.com/ngoduykhanh/wireguard-ui/handler"
"github.com/ngoduykhanh/wireguard-ui/router" "github.com/ngoduykhanh/wireguard-ui/router"
"github.com/ngoduykhanh/wireguard-ui/store"
"github.com/ngoduykhanh/wireguard-ui/store/jsondb" "github.com/ngoduykhanh/wireguard-ui/store/jsondb"
"github.com/ngoduykhanh/wireguard-ui/store/mysqldb"
"github.com/ngoduykhanh/wireguard-ui/util" "github.com/ngoduykhanh/wireguard-ui/util"
) )
@ -41,6 +43,7 @@ var (
flagDBDatabase string = "wireguard-ui" flagDBDatabase string = "wireguard-ui"
flagDBUsername string flagDBUsername string
flagDBPassword string flagDBPassword string
flagDBTLS string = "false"
) )
const ( const (
@ -67,12 +70,13 @@ func init() {
flag.StringVar(&flagEmailFrom, "email-from", util.LookupEnvOrString("EMAIL_FROM_ADDRESS", flagEmailFrom), "'From' email address.") flag.StringVar(&flagEmailFrom, "email-from", util.LookupEnvOrString("EMAIL_FROM_ADDRESS", flagEmailFrom), "'From' email address.")
flag.StringVar(&flagEmailFromName, "email-from-name", util.LookupEnvOrString("EMAIL_FROM_NAME", flagEmailFromName), "'From' email name.") flag.StringVar(&flagEmailFromName, "email-from-name", util.LookupEnvOrString("EMAIL_FROM_NAME", flagEmailFromName), "'From' email name.")
flag.StringVar(&flagSessionSecret, "session-secret", util.LookupEnvOrString("SESSION_SECRET", flagSessionSecret), "The key used to encrypt session cookies.") flag.StringVar(&flagSessionSecret, "session-secret", util.LookupEnvOrString("SESSION_SECRET", flagSessionSecret), "The key used to encrypt session cookies.")
flag.StringVar(&flagDBType, "db-type", util.LookupEnvOrString("DB_TYPE", flagDBType), "Type of database to use. One of: `jsondb`|`mysql`.") flag.StringVar(&flagDBType, "db-type", util.LookupEnvOrString("DB_TYPE", flagDBType), "Type of database to use. [jsondb, mysql]")
flag.StringVar(&flagDBHost, "db-host", util.LookupEnvOrString("DB_HOST", flagDBHost), "Database host") flag.StringVar(&flagDBHost, "db-host", util.LookupEnvOrString("DB_HOST", flagDBHost), "Database host")
flag.IntVar(&flagDBPort, "db-port", util.LookupEnvOrInt("DB_PORT", flagDBPort), "Database port") flag.IntVar(&flagDBPort, "db-port", util.LookupEnvOrInt("DB_PORT", flagDBPort), "Database port")
flag.StringVar(&flagDBDatabase, "db-database", util.LookupEnvOrString("DB_DATABASE", flagDBDatabase), "Database name") flag.StringVar(&flagDBDatabase, "db-database", util.LookupEnvOrString("DB_DATABASE", flagDBDatabase), "Database name")
flag.StringVar(&flagDBUsername, "db-username", util.LookupEnvOrString("DB_USERNAME", flagDBUsername), "Database username") flag.StringVar(&flagDBUsername, "db-username", util.LookupEnvOrString("DB_USERNAME", flagDBUsername), "Database username")
flag.StringVar(&flagDBPassword, "db-password", util.LookupEnvOrString("DB_PASSWORD", flagDBPassword), "Database password") flag.StringVar(&flagDBPassword, "db-password", util.LookupEnvOrString("DB_PASSWORD", flagDBPassword), "Database password")
flag.StringVar(&flagDBTLS, "db-tls", util.LookupEnvOrString("DB_TLS", flagDBTLS), "TLS mode. [true, false, skip-verify, preferred]")
flag.Parse() flag.Parse()
// update runtime config // update runtime config
@ -94,6 +98,7 @@ func init() {
util.DBDatabase = flagDBDatabase util.DBDatabase = flagDBDatabase
util.DBUsername = flagDBUsername util.DBUsername = flagDBUsername
util.DBPassword = flagDBPassword util.DBPassword = flagDBPassword
util.DBTLS = flagDBTLS
// print app information // print app information
fmt.Println("Wireguard UI") fmt.Println("Wireguard UI")
@ -107,18 +112,12 @@ func init() {
//fmt.Println("Sendgrid key\t:", util.SendgridApiKey) //fmt.Println("Sendgrid key\t:", util.SendgridApiKey)
fmt.Println("Email from\t:", util.EmailFrom) fmt.Println("Email from\t:", util.EmailFrom)
fmt.Println("Email from name\t:", util.EmailFromName) fmt.Println("Email from name\t:", util.EmailFromName)
fmt.Println("Datastore\t:", util.DBType)
//fmt.Println("Session secret\t:", util.SessionSecret) //fmt.Println("Session secret\t:", util.SessionSecret)
} }
func main() { func main() {
db, err := jsondb.New("./db")
if err != nil {
panic(err)
}
if err := db.Init(); err != nil {
panic(err)
}
// set app extra data // set app extra data
extraData := make(map[string]string) extraData := make(map[string]string)
extraData["appVersion"] = appVersion extraData["appVersion"] = appVersion
@ -129,6 +128,23 @@ func main() {
// rice file server for assets. "assets" is the folder where the files come from. // rice file server for assets. "assets" is the folder where the files come from.
assetHandler := http.FileServer(rice.MustFindBox("assets").HTTPBox()) assetHandler := http.FileServer(rice.MustFindBox("assets").HTTPBox())
// Configure database
var db store.IStore
var err error
switch util.DBType {
case "jsondb":
db, err = jsondb.New("./db")
case "mysql":
db, err = mysqldb.New(util.DBUsername, util.DBPassword, util.DBHost, util.DBPort, util.DBDatabase, util.DBTLS, tmplBox)
}
if err != nil {
panic(err)
}
if err := db.Init(); err != nil {
panic(err)
}
// register routes // register routes
app := router.New(tmplBox, extraData, util.SessionSecret) app := router.New(tmplBox, extraData, util.SessionSecret)

514
store/mysqldb/mysqldb.go Normal file
View file

@ -0,0 +1,514 @@
// Package mysqldb provides a MySQL storage backend for Wireguard UI
package mysqldb
import (
"database/sql"
"encoding/base64"
"fmt"
"strings"
"time"
rice "github.com/GeertJohan/go.rice"
"github.com/go-sql-driver/mysql"
"github.com/skip2/go-qrcode"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/ngoduykhanh/wireguard-ui/model"
"github.com/ngoduykhanh/wireguard-ui/util"
)
// MySQLDB - Representation of MySQL database backend
type MySQLDB struct {
conn *sql.DB
schema string
dbName string
}
// String to split each item in array
var arrayDelimiter = ","
// New returns pointer to MySQL database
func New(uname string, pwd string, host string, port int, database string, tls string, templateBox *rice.Box) (*MySQLDB, error) {
// Set connection config
config := mysql.NewConfig()
config.User = uname
config.Passwd = pwd
config.Net = "tcp"
config.Addr = fmt.Sprintf("%s:%d", host, port)
config.DBName = database
config.MultiStatements = true
config.ParseTime = true
config.TLSConfig = tls
// Open connection pool
conn, err := sql.Open("mysql", config.FormatDSN())
if err != nil {
return nil, err
}
conn.SetConnMaxLifetime(time.Minute * 3)
conn.SetMaxOpenConns(10)
conn.SetMaxIdleConns(10)
// Test the connection
if err := conn.Ping(); err != nil {
return nil, err
}
// Load DB schema
schema, err := templateBox.String("mysql.sql")
if err != nil {
return nil, err
}
ans := MySQLDB{
conn: conn,
schema: schema,
dbName: database,
}
return &ans, nil
}
// Init initializes the database
func (o *MySQLDB) Init() error {
// Check if database is empty
var databaseEmpty int
err := o.conn.QueryRow(
"SELECT COUNT(DISTINCT `table_name`) FROM `information_schema`.`columns` WHERE `table_schema` = ?",
o.dbName,
).Scan(&databaseEmpty)
if err != nil {
return err
}
if !(databaseEmpty > 0) {
// Initialize database
// Tell the user what we're doing as this could take a while
fmt.Println("Initializing database")
// Create database schema
if _, err := o.conn.Exec(o.schema); err != nil {
return err
}
// servers's interface
if _, err := o.conn.Exec(
"INSERT INTO interfaces (addresses, listen_port, updated_at) VALUES (?, ?, ?);",
util.DefaultServerAddress,
util.DefaultServerPort,
time.Now().UTC(),
); err != nil {
return err
}
// server's keypair
key, err := wgtypes.GeneratePrivateKey()
if err != nil {
return err
}
if _, err := o.conn.Exec(
"INSERT INTO keypair (private_key, public_key, updated_at) VALUES (?, ?, ?);",
key.String(),
key.PublicKey().String(),
time.Now().UTC(),
); err != nil {
return err
}
// global settings
publicInterface, err := util.GetPublicIP()
if err != nil {
return err
}
if _, err := o.conn.Exec(
"INSERT INTO global_settings (endpoint_address, dns_servers, mtu, persistent_keepalive, config_file_path, updated_at) VALUES (?, ?, ?, ?, ?, ?);",
publicInterface.IPAddress,
util.DefaultDNS,
util.DefaultMTU,
util.DefaultPersistentKeepalive,
util.DefaultConfigFilePath,
time.Now().UTC(),
); err != nil {
return err
}
// user info
if _, err := o.conn.Exec(
"INSERT INTO users (username, password) VALUES (?, ?);",
util.GetCredVar(util.UsernameEnvVar, util.DefaultUsername),
util.GetCredVar(util.PasswordEnvVar, util.DefaultPassword),
); err != nil {
return err
}
}
return nil
}
// GetUser func to query user info from the database
func (o *MySQLDB) GetUser() (model.User, error) {
user := model.User{}
row := o.conn.QueryRow("SELECT username, password FROM users;")
err := row.Scan(
&user.Username,
&user.Password,
)
return user, err
}
// GetGlobalSettings func to query global settings from the database
func (o *MySQLDB) GetGlobalSettings() (model.GlobalSetting, error) {
settings := model.GlobalSetting{}
var dnsServers string
row := o.conn.QueryRow("SELECT endpoint_address, dns_servers, mtu, persistent_keepalive, config_file_path, updated_at FROM global_settings;")
// Can't use ScanStruct here as doesn't know how to handle
// dns_servers list. Instead we must populate struct it manually.
err := row.Scan(
&settings.EndpointAddress,
&dnsServers,
&settings.MTU,
&settings.PersistentKeepalive,
&settings.ConfigFilePath,
&settings.UpdatedAt,
)
settings.DNSServers = strings.Split(dnsServers, arrayDelimiter)
return settings, err
}
// GetServer func to query Server setting from the database
func (o *MySQLDB) GetServer() (model.Server, error) {
server := model.Server{}
// Get interface
serverInterface := model.ServerInterface{}
var addresses string
row := o.conn.QueryRow("SELECT addresses, listen_port, updated_at, post_up, post_down FROM interfaces;")
err := row.Scan(
&addresses,
&serverInterface.ListenPort,
&serverInterface.UpdatedAt,
&serverInterface.PostUp,
&serverInterface.PostDown,
)
serverInterface.Addresses = strings.Split(addresses, arrayDelimiter)
if err != nil {
return server, err
}
// Get keypair
serverKeyPair := model.ServerKeypair{}
if err := o.conn.QueryRow("SELECT private_key, public_key, updated_at FROM keypair;").
Scan(
&serverKeyPair.PrivateKey,
&serverKeyPair.PublicKey,
&serverKeyPair.UpdatedAt,
); err != nil {
return server, err
}
// create Server object and return
server.Interface = &serverInterface
server.KeyPair = &serverKeyPair
return server, nil
}
// GetClients func to query Client settings from the database
func (o *MySQLDB) GetClients(hasQRCode bool) ([]model.ClientData, error) {
var clients []model.ClientData
rows, err := o.conn.Query("SELECT * FROM clients;")
if err != nil {
return clients, err
}
for rows.Next() {
client := model.Client{}
clientData := model.ClientData{}
var allocatedIPs string
var allowedIPs string
var extraAllowedIPs string
// Get client info
if err := rows.Scan(
&client.ID,
&client.PrivateKey,
&client.PublicKey,
&client.PresharedKey,
&client.Name,
&client.Email,
&allocatedIPs,
&allowedIPs,
&extraAllowedIPs,
&client.UseServerDNS,
&client.Enabled,
&client.CreatedAt,
&client.UpdatedAt,
); err != nil {
return clients, err
}
client.AllocatedIPs = strings.Split(allocatedIPs, arrayDelimiter)
client.AllowedIPs = strings.Split(allowedIPs, arrayDelimiter)
client.ExtraAllowedIPs = strings.Split(extraAllowedIPs, arrayDelimiter)
// generate client qrcode image in base64
if hasQRCode && client.PrivateKey != "" {
server, _ := o.GetServer()
globalSettings, _ := o.GetGlobalSettings()
png, err := qrcode.Encode(util.BuildClientConfig(client, server, globalSettings), qrcode.Medium, 256)
if err == nil {
clientData.QRCode = "data:image/png;base64," + base64.StdEncoding.EncodeToString([]byte(png))
} else {
fmt.Print("Cannot generate QR code: ", err)
}
}
// create the list of clients and their qrcode data
clientData.Client = &client
clients = append(clients, clientData)
}
return clients, nil
}
// GetClientByID func to query Clients by ID from the database
func (o *MySQLDB) GetClientByID(clientID string, hasQRCode bool) (model.ClientData, error) {
client := model.Client{}
clientData := model.ClientData{}
var allocatedIPs string
var allowedIPs string
var extraAllowedIPs string
// read client info
if err := o.conn.QueryRow("SELECT * FROM clients WHERE id = ?;", clientID).Scan(
&client.ID,
&client.PrivateKey,
&client.PublicKey,
&client.PresharedKey,
&client.Name,
&client.Email,
&allocatedIPs,
&allowedIPs,
&extraAllowedIPs,
&client.UseServerDNS,
&client.Enabled,
&client.CreatedAt,
&client.UpdatedAt,
); err != nil {
return clientData, err
}
client.AllocatedIPs = strings.Split(allocatedIPs, arrayDelimiter)
client.AllowedIPs = strings.Split(allowedIPs, arrayDelimiter)
client.ExtraAllowedIPs = strings.Split(extraAllowedIPs, arrayDelimiter)
// generate client qrcode image in base64
if hasQRCode && client.PrivateKey != "" {
server, _ := o.GetServer()
globalSettings, _ := o.GetGlobalSettings()
png, err := qrcode.Encode(util.BuildClientConfig(client, server, globalSettings), qrcode.Medium, 256)
if err == nil {
clientData.QRCode = "data:image/png;base64," + base64.StdEncoding.EncodeToString([]byte(png))
} else {
fmt.Print("Cannot generate QR code: ", err)
}
}
clientData.Client = &client
return clientData, nil
}
// SaveClient func saves client to database
func (o *MySQLDB) SaveClient(client model.Client) error {
// If client doesn't exist, create a record, else update existing record
querySet := `
SET
@id = ?,
@private_key = ?,
@public_key = ?,
@preshared_key = ?,
@name = ?,
@email = ?,
@allocated_ips = ?,
@allowed_ips = ?,
@extra_allowed_ips = ?,
@use_server_dns = ?,
@enabled = ?,
@created_at = ?,
@updated_at = ?;`
queryInsert := `
INSERT INTO clients(
id,
private_key,
public_key,
preshared_key,
NAME,
email,
allocated_ips,
allowed_ips,
extra_allowed_ips,
use_server_dns,
enabled,
created_at,
updated_at
)
VALUES(
@id,
@private_key,
@public_key,
@preshared_key,
@name,
@email,
@allocated_ips,
@allowed_ips,
@extra_allowed_ips,
@use_server_dns,
@enabled,
@created_at,
@updated_at
)
ON DUPLICATE KEY
UPDATE
id = @id,
private_key = @private_key,
public_key = @public_key,
preshared_key = @preshared_key,
NAME = @name,
email = @email,
allocated_ips = @allocated_ips,
allowed_ips = @allowed_ips,
extra_allowed_ips = @extra_allowed_ips,
use_server_dns = @use_server_dns,
enabled = @enabled,
created_at = @created_at,
updated_at = @updated_at;`
tx, err := o.conn.Begin()
if err != nil {
return err
}
// set values
if _, err := tx.Exec(
querySet,
client.ID,
client.PrivateKey,
client.PublicKey,
client.PresharedKey,
client.Name,
client.Email,
strings.Join(client.AllocatedIPs, arrayDelimiter),
strings.Join(client.AllowedIPs, arrayDelimiter),
strings.Join(client.ExtraAllowedIPs, arrayDelimiter),
client.UseServerDNS,
client.Enabled,
client.CreatedAt,
client.UpdatedAt,
); err != nil {
if rbErr := tx.Rollback(); rbErr != nil {
return fmt.Errorf("tx err: %v, rb err: %v", err, rbErr)
}
return err
}
// insert or update row
if _, err := tx.Exec(queryInsert); err != nil {
if rbErr := tx.Rollback(); rbErr != nil {
return fmt.Errorf("tx err: %v, rb err: %v", err, rbErr)
}
return err
}
return tx.Commit()
}
// DeleteClient func deletes client from the database
func (o *MySQLDB) DeleteClient(clientID string) error {
if _, err := o.conn.Exec("DELETE FROM clients WHERE id=?;", clientID); err != nil {
return err
}
return nil
}
// SaveServerInterface func saves a server interface to database
func (o *MySQLDB) SaveServerInterface(serverInterface model.ServerInterface) error {
// No need for ON DUPLICATE KEY UPDATE as only ever 1 record
query := `
UPDATE
interfaces
SET
addresses = ?,
listen_port = ?,
updated_at = ?,
post_up = ?,
post_down = ?
WHERE
id = 1;`
_, err := o.conn.Exec(
query,
strings.Join(serverInterface.Addresses, arrayDelimiter),
serverInterface.ListenPort,
serverInterface.UpdatedAt,
serverInterface.PostUp,
serverInterface.PostDown,
)
return err
}
// SaveServerKeyPair func saves a server keypair to database
func (o *MySQLDB) SaveServerKeyPair(serverKeyPair model.ServerKeypair) error {
query := `
UPDATE
keypair
SET
private_key = ?,
public_key = ?,
updated_at = ?
WHERE
id = 1;`
_, err := o.conn.Exec(
query,
serverKeyPair.PrivateKey,
serverKeyPair.PublicKey,
serverKeyPair.UpdatedAt,
)
return err
}
// SaveGlobalSettings saves global settings to database
func (o *MySQLDB) SaveGlobalSettings(globalSettings model.GlobalSetting) error {
query := `
UPDATE
global_settings
SET
endpoint_address = ?,
dns_servers = ?,
mtu = ?,
persistent_keepalive = ?,
config_file_path = ?,
updated_at = ?
WHERE
id = 1;`
_, err := o.conn.Exec(
query,
globalSettings.EndpointAddress,
strings.Join(globalSettings.DNSServers, arrayDelimiter),
globalSettings.MTU,
globalSettings.PersistentKeepalive,
globalSettings.ConfigFilePath,
globalSettings.UpdatedAt,
)
return err
}

View file

@ -77,4 +77,4 @@ ALTER TABLE `keypair`
ALTER TABLE `users` ALTER TABLE `users`
MODIFY `id` INT(11) NOT NULL AUTO_INCREMENT; MODIFY `id` INT(11) NOT NULL AUTO_INCREMENT;
COMMIT; COMMIT;

View file

@ -22,6 +22,7 @@ var (
DBDatabase string DBDatabase string
DBUsername string DBUsername string
DBPassword string DBPassword string
DBTLS string
) )
const ( const (