wireguard-ui/store/mysqldb/mysqldb.go
Matthew Nickson bc6f0f491f
Added MySQL as a datastore
The specific datastore backend to use can now be set by using
command line options or by using environment variables. The default
datastore backend is still jsondb but mysql can now also be used as a
backend. Environment variables have also been added to control settings
relevant to the database. SQL queries are made by directly accessing the
database/sql API. TLS is also supported.

Signed-off-by: Matthew Nickson <mnickson@sidingsmedia.com>
2022-03-16 22:49:18 +00:00

514 lines
12 KiB
Go

// Package mysqldb provides a MySQL storage backend for Wireguard UI
package mysqldb
import (
"database/sql"
"encoding/base64"
"fmt"
"strings"
"time"
rice "github.com/GeertJohan/go.rice"
"github.com/go-sql-driver/mysql"
"github.com/skip2/go-qrcode"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/ngoduykhanh/wireguard-ui/model"
"github.com/ngoduykhanh/wireguard-ui/util"
)
// MySQLDB is the wireguard-ui storage backend that keeps its data in MySQL.
type MySQLDB struct {
	// conn is the database/sql connection pool opened by New.
	conn *sql.DB

	// schema holds the SQL script Init executes when the database has no
	// tables; dbName names the database whose tables Init counts first.
	schema, dbName string
}
// arrayDelimiter separates the items of a list that is stored in a single
// text column (for example clients.allowed_ips or
// global_settings.dns_servers). Lists are joined with it on write and split
// on it on read. It is part of the on-disk format, so it is a constant:
// changing it at runtime would make existing rows decode incorrectly.
const arrayDelimiter = ","
// New returns pointer to MySQL database
//
// It opens a connection pool to the MySQL server at host:port as uname/pwd,
// selects database, verifies the connection with a ping and loads the schema
// script "mysql.sql" from templateBox for Init to use later.
//
// host may be a hostname, an IPv4 address or an IPv6 address, with or
// without surrounding brackets. tls is passed through unchanged as the
// driver's TLSConfig value (see the go-sql-driver/mysql `tls` DSN parameter).
//
// On any error the connection pool is closed again and a nil *MySQLDB is
// returned.
func New(uname string, pwd string, host string, port int, database string, tls string, templateBox *rice.Box) (*MySQLDB, error) {
	// Bracket IPv6 literals so the ":port" suffix stays unambiguous
	// ("::1" -> "[::1]"). This mirrors net.JoinHostPort, but leaves a host
	// the user already bracketed alone instead of double-bracketing it.
	if strings.Contains(host, ":") && !strings.HasPrefix(host, "[") {
		host = "[" + host + "]"
	}

	// Set connection config
	config := mysql.NewConfig()
	config.User = uname
	config.Passwd = pwd
	config.Net = "tcp"
	config.Addr = fmt.Sprintf("%s:%d", host, port)
	config.DBName = database
	// Init executes the whole schema script with a single Exec call.
	config.MultiStatements = true
	// Timestamp columns are scanned into time.Time values.
	config.ParseTime = true
	config.TLSConfig = tls

	// Open connection pool. sql.Open does not contact the server; the ping
	// below is what verifies connectivity and credentials.
	conn, err := sql.Open("mysql", config.FormatDSN())
	if err != nil {
		return nil, err
	}
	conn.SetConnMaxLifetime(3 * time.Minute)
	conn.SetMaxOpenConns(10)
	conn.SetMaxIdleConns(10)

	// Test the connection
	if err := conn.Ping(); err != nil {
		_ = conn.Close() // best effort; the ping error is the one worth reporting
		return nil, fmt.Errorf("connecting to MySQL at %s: %w", config.Addr, err)
	}

	// Load DB schema
	schema, err := templateBox.String("mysql.sql")
	if err != nil {
		_ = conn.Close() // best effort; the load error is the one worth reporting
		return nil, fmt.Errorf("loading MySQL schema mysql.sql: %w", err)
	}

	return &MySQLDB{
		conn:   conn,
		schema: schema,
		dbName: database,
	}, nil
}
// Init initializes the database
//
// When the configured database contains no tables, Init runs the schema
// script loaded by New and seeds it with the default server interface, a
// freshly generated server keypair, the default global settings and the
// initial user. A database that already contains at least one table is left
// untouched.
func (o *MySQLDB) Init() error {
	// Any existing table is taken to mean the database was already
	// initialized.
	var tableCount int
	if err := o.conn.QueryRow(
		"SELECT COUNT(DISTINCT `table_name`) FROM `information_schema`.`columns` WHERE `table_schema` = ?",
		o.dbName,
	).Scan(&tableCount); err != nil {
		return err
	}
	if tableCount > 0 {
		return nil
	}

	// Tell the user what we're doing as this could take a while
	fmt.Println("Initializing database")

	// Do the fallible work that does not touch the database first. Once the
	// schema exists, the table check above skips initialization on every
	// later start, so failing after creating it would leave the tables
	// permanently unseeded.
	key, err := wgtypes.GeneratePrivateKey()
	if err != nil {
		return err
	}
	publicInterface, err := util.GetPublicIP()
	if err != nil {
		return err
	}

	// Create database schema. This runs outside the seeding transaction
	// because MySQL commits DDL statements implicitly.
	if _, err := o.conn.Exec(o.schema); err != nil {
		return err
	}

	// Seed the default rows in a single transaction so they are written
	// together or not at all.
	now := time.Now().UTC()
	seed := []struct {
		query string
		args  []interface{}
	}{
		{ // server's interface
			"INSERT INTO interfaces (addresses, listen_port, updated_at) VALUES (?, ?, ?);",
			[]interface{}{util.DefaultServerAddress, util.DefaultServerPort, now},
		},
		{ // server's keypair
			"INSERT INTO keypair (private_key, public_key, updated_at) VALUES (?, ?, ?);",
			[]interface{}{key.String(), key.PublicKey().String(), now},
		},
		{ // global settings
			"INSERT INTO global_settings (endpoint_address, dns_servers, mtu, persistent_keepalive, config_file_path, updated_at) VALUES (?, ?, ?, ?, ?, ?);",
			[]interface{}{
				publicInterface.IPAddress,
				util.DefaultDNS,
				util.DefaultMTU,
				util.DefaultPersistentKeepalive,
				util.DefaultConfigFilePath,
				now,
			},
		},
		{ // user info
			"INSERT INTO users (username, password) VALUES (?, ?);",
			[]interface{}{
				util.GetCredVar(util.UsernameEnvVar, util.DefaultUsername),
				util.GetCredVar(util.PasswordEnvVar, util.DefaultPassword),
			},
		},
	}

	tx, err := o.conn.Begin()
	if err != nil {
		return err
	}
	for _, s := range seed {
		if _, err := tx.Exec(s.query, s.args...); err != nil {
			if rbErr := tx.Rollback(); rbErr != nil {
				return fmt.Errorf("tx err: %v, rb err: %v", err, rbErr)
			}
			return err
		}
	}
	return tx.Commit()
}
// GetUser func to query user info from the database
//
// It returns the username and password from the users table. QueryRow only
// scans the first row the server returns. If that scan fails, the error is
// returned alongside whatever was scanned so far.
func (o *MySQLDB) GetUser() (model.User, error) {
	var u model.User
	err := o.conn.
		QueryRow("SELECT username, password FROM users;").
		Scan(&u.Username, &u.Password)
	return u, err
}
// GetGlobalSettings func to query global settings from the database
//
// The dns_servers column holds the DNS server list joined with
// arrayDelimiter (see SaveGlobalSettings). It is decoded back into a slice,
// and an empty column becomes an empty list.
func (o *MySQLDB) GetGlobalSettings() (model.GlobalSetting, error) {
	settings := model.GlobalSetting{}
	// dns_servers is a joined list, so it is scanned into a plain string and
	// split below rather than scanned straight into the struct.
	var dnsServers string
	err := o.conn.QueryRow("SELECT endpoint_address, dns_servers, mtu, persistent_keepalive, config_file_path, updated_at FROM global_settings;").Scan(
		&settings.EndpointAddress,
		&dnsServers,
		&settings.MTU,
		&settings.PersistentKeepalive,
		&settings.ConfigFilePath,
		&settings.UpdatedAt,
	)
	if err != nil {
		return settings, err
	}

	// strings.Split("", sep) returns [""], which would turn an empty list
	// into a single blank DNS server; decode an empty column as no servers.
	settings.DNSServers = []string{}
	if dnsServers != "" {
		settings.DNSServers = strings.Split(dnsServers, arrayDelimiter)
	}
	return settings, nil
}
// GetServer func to query Server setting from the database
//
// It reads the server interface from the interfaces table and the server
// keypair from the keypair table. On error the returned Server's Interface
// and KeyPair pointers are nil.
func (o *MySQLDB) GetServer() (model.Server, error) {
	server := model.Server{}

	// Get interface
	serverInterface := model.ServerInterface{}
	var (
		addresses        string
		postUp, postDown sql.NullString
	)
	// Init's INSERT does not set post_up/post_down, so depending on the
	// schema's column defaults they may be NULL. Scanning NULL into a plain
	// string fails; NullString reads it as "" instead.
	if err := o.conn.QueryRow("SELECT addresses, listen_port, updated_at, post_up, post_down FROM interfaces;").Scan(
		&addresses,
		&serverInterface.ListenPort,
		&serverInterface.UpdatedAt,
		&postUp,
		&postDown,
	); err != nil {
		return server, err
	}
	serverInterface.PostUp = postUp.String
	serverInterface.PostDown = postDown.String

	// addresses is stored joined with arrayDelimiter. strings.Split("", sep)
	// returns [""], so an empty column is decoded as an empty list.
	serverInterface.Addresses = []string{}
	if addresses != "" {
		serverInterface.Addresses = strings.Split(addresses, arrayDelimiter)
	}

	// Get keypair
	serverKeyPair := model.ServerKeypair{}
	if err := o.conn.QueryRow("SELECT private_key, public_key, updated_at FROM keypair;").Scan(
		&serverKeyPair.PrivateKey,
		&serverKeyPair.PublicKey,
		&serverKeyPair.UpdatedAt,
	); err != nil {
		return server, err
	}

	// create Server object and return
	server.Interface = &serverInterface
	server.KeyPair = &serverKeyPair
	return server, nil
}
// GetClients func to query Client settings from the database
func (o *MySQLDB) GetClients(hasQRCode bool) ([]model.ClientData, error) {
var clients []model.ClientData
rows, err := o.conn.Query("SELECT * FROM clients;")
if err != nil {
return clients, err
}
for rows.Next() {
client := model.Client{}
clientData := model.ClientData{}
var allocatedIPs string
var allowedIPs string
var extraAllowedIPs string
// Get client info
if err := rows.Scan(
&client.ID,
&client.PrivateKey,
&client.PublicKey,
&client.PresharedKey,
&client.Name,
&client.Email,
&allocatedIPs,
&allowedIPs,
&extraAllowedIPs,
&client.UseServerDNS,
&client.Enabled,
&client.CreatedAt,
&client.UpdatedAt,
); err != nil {
return clients, err
}
client.AllocatedIPs = strings.Split(allocatedIPs, arrayDelimiter)
client.AllowedIPs = strings.Split(allowedIPs, arrayDelimiter)
client.ExtraAllowedIPs = strings.Split(extraAllowedIPs, arrayDelimiter)
// generate client qrcode image in base64
if hasQRCode && client.PrivateKey != "" {
server, _ := o.GetServer()
globalSettings, _ := o.GetGlobalSettings()
png, err := qrcode.Encode(util.BuildClientConfig(client, server, globalSettings), qrcode.Medium, 256)
if err == nil {
clientData.QRCode = "data:image/png;base64," + base64.StdEncoding.EncodeToString([]byte(png))
} else {
fmt.Print("Cannot generate QR code: ", err)
}
}
// create the list of clients and their qrcode data
clientData.Client = &client
clients = append(clients, clientData)
}
return clients, nil
}
// GetClientByID func to query Clients by ID from the database
func (o *MySQLDB) GetClientByID(clientID string, hasQRCode bool) (model.ClientData, error) {
client := model.Client{}
clientData := model.ClientData{}
var allocatedIPs string
var allowedIPs string
var extraAllowedIPs string
// read client info
if err := o.conn.QueryRow("SELECT * FROM clients WHERE id = ?;", clientID).Scan(
&client.ID,
&client.PrivateKey,
&client.PublicKey,
&client.PresharedKey,
&client.Name,
&client.Email,
&allocatedIPs,
&allowedIPs,
&extraAllowedIPs,
&client.UseServerDNS,
&client.Enabled,
&client.CreatedAt,
&client.UpdatedAt,
); err != nil {
return clientData, err
}
client.AllocatedIPs = strings.Split(allocatedIPs, arrayDelimiter)
client.AllowedIPs = strings.Split(allowedIPs, arrayDelimiter)
client.ExtraAllowedIPs = strings.Split(extraAllowedIPs, arrayDelimiter)
// generate client qrcode image in base64
if hasQRCode && client.PrivateKey != "" {
server, _ := o.GetServer()
globalSettings, _ := o.GetGlobalSettings()
png, err := qrcode.Encode(util.BuildClientConfig(client, server, globalSettings), qrcode.Medium, 256)
if err == nil {
clientData.QRCode = "data:image/png;base64," + base64.StdEncoding.EncodeToString([]byte(png))
} else {
fmt.Print("Cannot generate QR code: ", err)
}
}
clientData.Client = &client
return clientData, nil
}
// SaveClient func saves client to database
//
// The client is inserted into the clients table. If the insert collides with
// an existing key, every column of that existing row is overwritten with the
// client's values instead. List fields are stored joined with
// arrayDelimiter.
func (o *MySQLDB) SaveClient(client model.Client) error {
	// A single INSERT ... ON DUPLICATE KEY UPDATE statement is atomic by
	// itself, so no explicit transaction is needed.
	//
	// The values are bound twice, once for the INSERT half and once for the
	// UPDATE half. Staging them in @session variables would cost extra round
	// trips and leave the client's keys in @private_key/@preshared_key on the
	// pooled connection. VALUES(col) is avoided because MySQL 8.0.20+
	// deprecates it.
	const query = `
INSERT INTO clients (
	id, private_key, public_key, preshared_key, name, email,
	allocated_ips, allowed_ips, extra_allowed_ips,
	use_server_dns, enabled, created_at, updated_at
)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
ON DUPLICATE KEY UPDATE
	id = ?,
	private_key = ?,
	public_key = ?,
	preshared_key = ?,
	name = ?,
	email = ?,
	allocated_ips = ?,
	allowed_ips = ?,
	extra_allowed_ips = ?,
	use_server_dns = ?,
	enabled = ?,
	created_at = ?,
	updated_at = ?;`

	values := []interface{}{
		client.ID,
		client.PrivateKey,
		client.PublicKey,
		client.PresharedKey,
		client.Name,
		client.Email,
		strings.Join(client.AllocatedIPs, arrayDelimiter),
		strings.Join(client.AllowedIPs, arrayDelimiter),
		strings.Join(client.ExtraAllowedIPs, arrayDelimiter),
		client.UseServerDNS,
		client.Enabled,
		client.CreatedAt,
		client.UpdatedAt,
	}
	// Same values, in the same column order, for both halves of the upsert.
	args := make([]interface{}, 0, 2*len(values))
	args = append(args, values...)
	args = append(args, values...)

	_, err := o.conn.Exec(query, args...)
	return err
}
// DeleteClient func deletes client from the database
//
// It removes the clients row whose id equals clientID. The statement's
// result is not inspected, so deleting an ID that does not exist is not
// reported as an error.
func (o *MySQLDB) DeleteClient(clientID string) error {
	_, err := o.conn.Exec("DELETE FROM clients WHERE id=?;", clientID)
	return err
}
// SaveServerInterface func saves a server interface to database
//
// It overwrites the interfaces row with id = 1. Addresses are stored joined
// with arrayDelimiter. The affected-row count is not checked, so if that row
// is missing the UPDATE silently changes nothing.
func (o *MySQLDB) SaveServerInterface(serverInterface model.ServerInterface) error {
	// There is only ever one interface row, so a plain UPDATE suffices and
	// ON DUPLICATE KEY UPDATE is not needed.
	const query = `
UPDATE
interfaces
SET
addresses = ?,
listen_port = ?,
updated_at = ?,
post_up = ?,
post_down = ?
WHERE
id = 1;`

	joinedAddresses := strings.Join(serverInterface.Addresses, arrayDelimiter)
	if _, err := o.conn.Exec(
		query,
		joinedAddresses,
		serverInterface.ListenPort,
		serverInterface.UpdatedAt,
		serverInterface.PostUp,
		serverInterface.PostDown,
	); err != nil {
		return err
	}
	return nil
}
// SaveServerKeyPair func saves a server keypair to database
//
// It overwrites the private key, public key and timestamp of the keypair row
// with id = 1. The affected-row count is not checked, so if that row is
// missing the UPDATE silently changes nothing.
func (o *MySQLDB) SaveServerKeyPair(serverKeyPair model.ServerKeypair) error {
	const query = `
UPDATE
keypair
SET
private_key = ?,
public_key = ?,
updated_at = ?
WHERE
id = 1;`

	args := []interface{}{
		serverKeyPair.PrivateKey,
		serverKeyPair.PublicKey,
		serverKeyPair.UpdatedAt,
	}
	if _, err := o.conn.Exec(query, args...); err != nil {
		return err
	}
	return nil
}
// SaveGlobalSettings saves global settings to database
//
// It overwrites the global_settings row with id = 1. The DNS server list is
// stored joined with arrayDelimiter. The affected-row count is not checked,
// so if that row is missing the UPDATE silently changes nothing.
func (o *MySQLDB) SaveGlobalSettings(globalSettings model.GlobalSetting) error {
	const query = `
UPDATE
global_settings
SET
endpoint_address = ?,
dns_servers = ?,
mtu = ?,
persistent_keepalive = ?,
config_file_path = ?,
updated_at = ?
WHERE
id = 1;`

	joinedDNS := strings.Join(globalSettings.DNSServers, arrayDelimiter)
	args := []interface{}{
		globalSettings.EndpointAddress,
		joinedDNS,
		globalSettings.MTU,
		globalSettings.PersistentKeepalive,
		globalSettings.ConfigFilePath,
		globalSettings.UpdatedAt,
	}
	if _, err := o.conn.Exec(query, args...); err != nil {
		return err
	}
	return nil
}