wireguard-ui/util/util.go
0xCA bee5c54127 Further session protections and fixes
Use MaxAge instead of Expires
Verify if the cookie is not too old and not from the future
Verify if the user exists and unchanged
Refresh not sooner than 24h
Do not refresh temporary sessions
Delete cookies on logout
2023-12-29 15:08:50 +05:00

856 lines
21 KiB
Go

package util
import (
"bufio"
"encoding/json"
"errors"
"fmt"
"hash/crc32"
"io"
"io/fs"
"math/rand"
"net"
"os"
"path"
"path/filepath"
"strconv"
"strings"
"text/template"
"time"
"github.com/ngoduykhanh/wireguard-ui/store"
"github.com/ngoduykhanh/wireguard-ui/telegram"
"github.com/skip2/go-qrcode"
"golang.org/x/mod/sumdb/dirhash"
externalip "github.com/glendc/go-external-ip"
"github.com/labstack/gommon/log"
"github.com/ngoduykhanh/wireguard-ui/model"
"github.com/sdomino/scribble"
)
var qrCodeSettings = model.QRCodeSettings{
Enabled: true,
IncludeDNS: true,
IncludeMTU: true,
}
// BuildClientConfig to create wireguard client config string
func BuildClientConfig(client model.Client, server model.Server, setting model.GlobalSetting) string {
// Interface section
clientAddress := fmt.Sprintf("Address = %s\n", strings.Join(client.AllocatedIPs, ","))
clientPrivateKey := fmt.Sprintf("PrivateKey = %s\n", client.PrivateKey)
clientDNS := ""
if client.UseServerDNS {
clientDNS = fmt.Sprintf("DNS = %s\n", strings.Join(setting.DNSServers, ","))
}
clientMTU := ""
if setting.MTU > 0 {
clientMTU = fmt.Sprintf("MTU = %d\n", setting.MTU)
}
// Peer section
peerPublicKey := fmt.Sprintf("PublicKey = %s\n", server.KeyPair.PublicKey)
peerPresharedKey := ""
if client.PresharedKey != "" {
peerPresharedKey = fmt.Sprintf("PresharedKey = %s\n", client.PresharedKey)
}
peerAllowedIPs := fmt.Sprintf("AllowedIPs = %s\n", strings.Join(client.AllowedIPs, ","))
desiredHost := setting.EndpointAddress
desiredPort := server.Interface.ListenPort
if strings.Contains(desiredHost, ":") {
split := strings.Split(desiredHost, ":")
desiredHost = split[0]
if n, err := strconv.Atoi(split[1]); err == nil {
desiredPort = n
} else {
log.Error("Endpoint appears to be incorrectly formatted: ", err)
}
}
peerEndpoint := fmt.Sprintf("Endpoint = %s:%d\n", desiredHost, desiredPort)
peerPersistentKeepalive := ""
if setting.PersistentKeepalive > 0 {
peerPersistentKeepalive = fmt.Sprintf("PersistentKeepalive = %d\n", setting.PersistentKeepalive)
}
// build the config as string
strConfig := "[Interface]\n" +
clientAddress +
clientPrivateKey +
clientDNS +
clientMTU +
"\n[Peer]\n" +
peerPublicKey +
peerPresharedKey +
peerAllowedIPs +
peerEndpoint +
peerPersistentKeepalive
return strConfig
}
// ClientDefaultsFromEnv to read the default values for creating a new client from the environment or use sane defaults
func ClientDefaultsFromEnv() model.ClientDefaults {
clientDefaults := model.ClientDefaults{}
clientDefaults.AllowedIps = LookupEnvOrStrings(DefaultClientAllowedIpsEnvVar, []string{"0.0.0.0/0"})
clientDefaults.ExtraAllowedIps = LookupEnvOrStrings(DefaultClientExtraAllowedIpsEnvVar, []string{})
clientDefaults.UseServerDNS = LookupEnvOrBool(DefaultClientUseServerDNSEnvVar, true)
clientDefaults.EnableAfterCreation = LookupEnvOrBool(DefaultClientEnableAfterCreationEnvVar, true)
return clientDefaults
}
// ContainsCIDR to check if ipnet1 contains ipnet2
// https://stackoverflow.com/a/40406619/6111641
// https://go.dev/play/p/Q4J-JEN3sF
func ContainsCIDR(ipnet1, ipnet2 *net.IPNet) bool {
ones1, _ := ipnet1.Mask.Size()
ones2, _ := ipnet2.Mask.Size()
return ones1 <= ones2 && ipnet1.Contains(ipnet2.IP)
}
// ValidateCIDR to validate a network CIDR
func ValidateCIDR(cidr string) bool {
_, _, err := net.ParseCIDR(cidr)
if err != nil {
return false
}
return true
}
// ValidateCIDRList to validate a list of network CIDR
func ValidateCIDRList(cidrs []string, allowEmpty bool) bool {
for _, cidr := range cidrs {
if allowEmpty {
if len(cidr) > 0 {
if ValidateCIDR(cidr) == false {
return false
}
}
} else {
if ValidateCIDR(cidr) == false {
return false
}
}
}
return true
}
// ValidateAllowedIPs to validate allowed ip addresses in CIDR format
func ValidateAllowedIPs(cidrs []string) bool {
if ValidateCIDRList(cidrs, false) == false {
return false
}
return true
}
// ValidateExtraAllowedIPs to validate extra Allowed ip addresses, allowing empty strings
func ValidateExtraAllowedIPs(cidrs []string) bool {
if ValidateCIDRList(cidrs, true) == false {
return false
}
return true
}
// ValidateServerAddresses to validate allowed ip addresses in CIDR format
func ValidateServerAddresses(cidrs []string) bool {
if ValidateCIDRList(cidrs, false) == false {
return false
}
return true
}
// ValidateIPAddress to validate the IPv4 and IPv6 address
func ValidateIPAddress(ip string) bool {
if net.ParseIP(ip) == nil {
return false
}
return true
}
// ValidateIPAddressList to validate a list of IPv4 and IPv6 addresses
func ValidateIPAddressList(ips []string) bool {
for _, ip := range ips {
if ValidateIPAddress(ip) == false {
return false
}
}
return true
}
// GetInterfaceIPs to get local machine's interface ip addresses
func GetInterfaceIPs() ([]model.Interface, error) {
// get machine's interfaces
ifaces, err := net.Interfaces()
if err != nil {
return nil, err
}
var interfaceList []model.Interface
// get interface's ip addresses
for _, i := range ifaces {
addrs, err := i.Addrs()
if err != nil {
return nil, err
}
for _, addr := range addrs {
var ip net.IP
switch v := addr.(type) {
case *net.IPNet:
ip = v.IP
case *net.IPAddr:
ip = v.IP
}
if ip == nil || ip.IsLoopback() {
continue
}
ip = ip.To4()
if ip == nil {
continue
}
iface := model.Interface{}
iface.Name = i.Name
iface.IPAddress = ip.String()
interfaceList = append(interfaceList, iface)
}
}
return interfaceList, err
}
// GetPublicIP to get machine's public ip address
func GetPublicIP() (model.Interface, error) {
// set time out to 5 seconds
cfg := externalip.ConsensusConfig{}
cfg.Timeout = time.Second * 5
consensus := externalip.NewConsensus(&cfg, nil)
// add trusted voters
consensus.AddVoter(externalip.NewHTTPSource("https://checkip.amazonaws.com/"), 1)
consensus.AddVoter(externalip.NewHTTPSource("http://whatismyip.akamai.com"), 1)
consensus.AddVoter(externalip.NewHTTPSource("https://ifconfig.top"), 1)
publicInterface := model.Interface{}
publicInterface.Name = "Public Address"
ip, err := consensus.ExternalIP()
if err != nil {
publicInterface.IPAddress = "N/A"
} else {
publicInterface.IPAddress = ip.String()
}
// error handling happened above, no need to pass it through
return publicInterface, nil
}
// GetIPFromCIDR get ip from CIDR
func GetIPFromCIDR(cidr string) (string, error) {
ip, _, err := net.ParseCIDR(cidr)
if err != nil {
return "", err
}
return ip.String(), nil
}
// GetAllocatedIPs to get all ip addresses allocated to clients and server
func GetAllocatedIPs(ignoreClientID string) ([]string, error) {
allocatedIPs := make([]string, 0)
// initialize database directory
dir := "./db"
db, err := scribble.New(dir, nil)
if err != nil {
return nil, err
}
// read server information
serverInterface := model.ServerInterface{}
if err := db.Read("server", "interfaces", &serverInterface); err != nil {
return nil, err
}
// append server's addresses to the result
for _, cidr := range serverInterface.Addresses {
ip, err := GetIPFromCIDR(cidr)
if err != nil {
return nil, err
}
allocatedIPs = append(allocatedIPs, ip)
}
// read client information
records, err := db.ReadAll("clients")
if err != nil {
return nil, err
}
// append client's addresses to the result
for _, f := range records {
client := model.Client{}
if err := json.Unmarshal(f, &client); err != nil {
return nil, err
}
if client.ID != ignoreClientID {
for _, cidr := range client.AllocatedIPs {
ip, err := GetIPFromCIDR(cidr)
if err != nil {
return nil, err
}
allocatedIPs = append(allocatedIPs, ip)
}
}
}
return allocatedIPs, nil
}
// inc from https://play.golang.org/p/m8TNTtygK0
func inc(ip net.IP) {
for j := len(ip) - 1; j >= 0; j-- {
ip[j]++
if ip[j] > 0 {
break
}
}
}
// GetBroadcastIP func to get the broadcast ip address of a network
func GetBroadcastIP(n *net.IPNet) net.IP {
var broadcast net.IP
if len(n.IP) == 4 {
broadcast = net.ParseIP("0.0.0.0").To4()
} else {
broadcast = net.ParseIP("::")
}
for i := 0; i < len(n.IP); i++ {
broadcast[i] = n.IP[i] | ^n.Mask[i]
}
return broadcast
}
// GetBroadcastAndNetworkAddrsLookup get the ip address that can't be used with current server interfaces
func GetBroadcastAndNetworkAddrsLookup(interfaceAddresses []string) map[string]bool {
list := make(map[string]bool)
for _, ifa := range interfaceAddresses {
_, netAddr, err := net.ParseCIDR(ifa)
if err != nil {
continue
}
broadcastAddr := GetBroadcastIP(netAddr).String()
networkAddr := netAddr.IP.String()
list[broadcastAddr] = true
list[networkAddr] = true
}
return list
}
// GetAvailableIP get the ip address that can be allocated from an CIDR
// We need interfaceAddresses to find real broadcast and network addresses
func GetAvailableIP(cidr string, allocatedList, interfaceAddresses []string) (string, error) {
ip, netAddr, err := net.ParseCIDR(cidr)
if err != nil {
return "", err
}
unavailableIPs := GetBroadcastAndNetworkAddrsLookup(interfaceAddresses)
for ip := ip.Mask(netAddr.Mask); netAddr.Contains(ip); inc(ip) {
available := true
suggestedAddr := ip.String()
for _, allocatedAddr := range allocatedList {
if suggestedAddr == allocatedAddr {
available = false
break
}
}
if available && !unavailableIPs[suggestedAddr] {
return suggestedAddr, nil
}
}
return "", errors.New("no more available ip address")
}
// ValidateIPAllocation to validate the list of client's ip allocation
// They must have a correct format and available in serverAddresses space
func ValidateIPAllocation(serverAddresses []string, ipAllocatedList []string, ipAllocationList []string) (bool, error) {
for _, clientCIDR := range ipAllocationList {
ip, _, _ := net.ParseCIDR(clientCIDR)
// clientCIDR must be in CIDR format
if ip == nil {
return false, fmt.Errorf("invalid ip allocation input %s. Must be in CIDR format", clientCIDR)
}
// return false immediately if the ip is already in use (in ipAllocatedList)
for _, item := range ipAllocatedList {
if item == ip.String() {
return false, fmt.Errorf("IP %s already allocated", ip)
}
}
// even if it is not in use, we still need to check if it
// belongs to a network of the server.
var isValid = false
for _, serverCIDR := range serverAddresses {
_, serverNet, _ := net.ParseCIDR(serverCIDR)
if serverNet.Contains(ip) {
isValid = true
break
}
}
// current ip allocation is valid, check the next one
if isValid {
continue
} else {
return false, fmt.Errorf("IP %s does not belong to any network addresses of WireGuard server", ip)
}
}
return true, nil
}
// findSubnetRangeForIP to find first SR for IP, and cache the match
func findSubnetRangeForIP(cidr string) (uint16, error) {
ip, _, err := net.ParseCIDR(cidr)
if err != nil {
return 0, err
}
if srName, ok := IPToSubnetRange[ip.String()]; ok {
return srName, nil
}
for srIndex, sr := range SubnetRangesOrder {
for _, srCIDR := range SubnetRanges[sr] {
if srCIDR.Contains(ip) {
IPToSubnetRange[ip.String()] = uint16(srIndex)
return uint16(srIndex), nil
}
}
}
return 0, fmt.Errorf("subnet range not found for this IP")
}
// FillClientSubnetRange to fill subnet ranges client belongs to, does nothing if SRs are not found
func FillClientSubnetRange(client model.ClientData) model.ClientData {
cl := *client.Client
for _, ip := range cl.AllocatedIPs {
sr, err := findSubnetRangeForIP(ip)
if err != nil {
continue
}
cl.SubnetRanges = append(cl.SubnetRanges, SubnetRangesOrder[sr])
}
return model.ClientData{
Client: &cl,
QRCode: client.QRCode,
}
}
// ValidateAndFixSubnetRanges to check if subnet ranges are valid for the server configuration
// Removes all non-valid CIDRs
func ValidateAndFixSubnetRanges(db store.IStore) error {
if len(SubnetRangesOrder) == 0 {
return nil
}
server, err := db.GetServer()
if err != nil {
return err
}
var serverSubnets []*net.IPNet
for _, addr := range server.Interface.Addresses {
addr = strings.TrimSpace(addr)
_, netAddr, err := net.ParseCIDR(addr)
if err != nil {
return err
}
serverSubnets = append(serverSubnets, netAddr)
}
for _, rng := range SubnetRangesOrder {
cidrs := SubnetRanges[rng]
if len(cidrs) > 0 {
newCIDRs := make([]*net.IPNet, 0)
for _, cidr := range cidrs {
valid := false
for _, serverSubnet := range serverSubnets {
if ContainsCIDR(serverSubnet, cidr) {
valid = true
break
}
}
if valid {
newCIDRs = append(newCIDRs, cidr)
} else {
log.Warnf("[%v] CIDR is outside of all server subnets: %v. Removed.", rng, cidr)
}
}
if len(newCIDRs) > 0 {
SubnetRanges[rng] = newCIDRs
} else {
delete(SubnetRanges, rng)
log.Warnf("[%v] No valid CIDRs in this subnet range. Removed.", rng)
}
}
}
return nil
}
// GetSubnetRangesString to get a formatted string, representing active subnet ranges
func GetSubnetRangesString() string {
if len(SubnetRangesOrder) == 0 {
return ""
}
strB := strings.Builder{}
for _, rng := range SubnetRangesOrder {
cidrs := SubnetRanges[rng]
if len(cidrs) > 0 {
strB.WriteString(rng)
strB.WriteString(":[")
first := true
for _, cidr := range cidrs {
if !first {
strB.WriteString(", ")
}
strB.WriteString(cidr.String())
first = false
}
strB.WriteString("] ")
}
}
return strings.TrimSpace(strB.String())
}
// WriteWireGuardServerConfig to write Wireguard server config. e.g. wg0.conf
func WriteWireGuardServerConfig(tmplDir fs.FS, serverConfig model.Server, clientDataList []model.ClientData, usersList []model.User, globalSettings model.GlobalSetting) error {
var tmplWireguardConf string
// if set, read wg.conf template from WgConfTemplate
if len(WgConfTemplate) > 0 {
fileContentBytes, err := os.ReadFile(WgConfTemplate)
if err != nil {
return err
}
tmplWireguardConf = string(fileContentBytes)
} else {
// read default wg.conf template file to string
fileContent, err := StringFromEmbedFile(tmplDir, "wg.conf")
if err != nil {
return err
}
tmplWireguardConf = fileContent
}
// parse the template
t, err := template.New("wg_config").Parse(tmplWireguardConf)
if err != nil {
return err
}
// write config file to disk
f, err := os.Create(globalSettings.ConfigFilePath)
if err != nil {
return err
}
config := map[string]interface{}{
"serverConfig": serverConfig,
"clientDataList": clientDataList,
"globalSettings": globalSettings,
"usersList": usersList,
}
err = t.Execute(f, config)
if err != nil {
return err
}
f.Close()
return nil
}
// SendRequestedConfigsToTelegram to send client all their configs. Returns failed configs list.
func SendRequestedConfigsToTelegram(db store.IStore, userid int64) []string {
failedList := make([]string, 0)
TgUseridToClientIDMutex.RLock()
if clids, found := TgUseridToClientID[userid]; found && len(clids) > 0 {
TgUseridToClientIDMutex.RUnlock()
for _, clid := range clids {
clientData, err := db.GetClientByID(clid, qrCodeSettings)
if err != nil {
// return fmt.Errorf("unable to get client")
failedList = append(failedList, clid)
continue
}
// build config
server, _ := db.GetServer()
globalSettings, _ := db.GetGlobalSettings()
config := BuildClientConfig(*clientData.Client, server, globalSettings)
configData := []byte(config)
var qrData []byte
if clientData.Client.PrivateKey != "" {
qrData, err = qrcode.Encode(config, qrcode.Medium, 512)
if err != nil {
// return fmt.Errorf("unable to encode qr")
failedList = append(failedList, clientData.Client.Name)
continue
}
}
userid, err := strconv.ParseInt(clientData.Client.TgUserid, 10, 64)
if err != nil {
// return fmt.Errorf("tg usrid is unreadable")
failedList = append(failedList, clientData.Client.Name)
continue
}
err = telegram.SendConfig(userid, clientData.Client.Name, configData, qrData, true)
if err != nil {
failedList = append(failedList, clientData.Client.Name)
continue
}
time.Sleep(2 * time.Second)
}
} else {
TgUseridToClientIDMutex.RUnlock()
}
return failedList
}
func LookupEnvOrString(key string, defaultVal string) string {
if val, ok := os.LookupEnv(key); ok {
return val
}
return defaultVal
}
func LookupEnvOrBool(key string, defaultVal bool) bool {
if val, ok := os.LookupEnv(key); ok {
v, err := strconv.ParseBool(val)
if err != nil {
fmt.Fprintf(os.Stderr, "LookupEnvOrBool[%s]: %v\n", key, err)
}
return v
}
return defaultVal
}
func LookupEnvOrInt(key string, defaultVal int) int {
if val, ok := os.LookupEnv(key); ok {
v, err := strconv.Atoi(val)
if err != nil {
fmt.Fprintf(os.Stderr, "LookupEnvOrInt[%s]: %v\n", key, err)
}
return v
}
return defaultVal
}
func LookupEnvOrStrings(key string, defaultVal []string) []string {
if val, ok := os.LookupEnv(key); ok {
return strings.Split(val, ",")
}
return defaultVal
}
func LookupEnvOrFile(key string, defaultVal string) string {
if val, ok := os.LookupEnv(key); ok {
if file, err := os.Open(val); err == nil {
var content string
scanner := bufio.NewScanner(file)
for scanner.Scan() {
content += scanner.Text()
}
return content
}
}
return defaultVal
}
func StringFromEmbedFile(embed fs.FS, filename string) (string, error) {
file, err := embed.Open(filename)
if err != nil {
return "", err
}
content, err := io.ReadAll(file)
if err != nil {
return "", err
}
return string(content), nil
}
func ParseLogLevel(lvl string) (log.Lvl, error) {
switch strings.ToLower(lvl) {
case "debug":
return log.DEBUG, nil
case "info":
return log.INFO, nil
case "warn":
return log.WARN, nil
case "error":
return log.ERROR, nil
case "off":
return log.OFF, nil
default:
return log.DEBUG, fmt.Errorf("not a valid log level: %s", lvl)
}
}
// GetCurrentHash returns current hashes
func GetCurrentHash(db store.IStore) (string, string) {
hashClients, _ := dirhash.HashDir(path.Join(db.GetPath(), "clients"), "prefix", dirhash.Hash1)
files := append([]string(nil), "prefix/global_settings.json", "prefix/interfaces.json", "prefix/keypair.json")
osOpen := func(name string) (io.ReadCloser, error) {
return os.Open(filepath.Join(path.Join(db.GetPath(), "server"), strings.TrimPrefix(name, "prefix")))
}
hashServer, _ := dirhash.Hash1(files, osOpen)
return hashClients, hashServer
}
func HashesChanged(db store.IStore) bool {
old, _ := db.GetHashes()
oldClient := old.Client
oldServer := old.Server
newClient, newServer := GetCurrentHash(db)
if oldClient != newClient {
//fmt.Println("Hash for client differs")
return true
}
if oldServer != newServer {
//fmt.Println("Hash for server differs")
return true
}
return false
}
func UpdateHashes(db store.IStore) error {
var clientServerHashes model.ClientServerHashes
clientServerHashes.Client, clientServerHashes.Server = GetCurrentHash(db)
return db.SaveHashes(clientServerHashes)
}
func RandomString(length int) string {
var seededRand = rand.New(rand.NewSource(time.Now().UnixNano()))
charset := "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
b := make([]byte, length)
for i := range b {
b[i] = charset[seededRand.Intn(len(charset))]
}
return string(b)
}
func ManagePerms(path string) error {
err := os.Chmod(path, 0600)
return err
}
func AddTgToClientID(userid int64, clientID string) {
TgUseridToClientIDMutex.Lock()
defer TgUseridToClientIDMutex.Unlock()
if _, ok := TgUseridToClientID[userid]; ok && TgUseridToClientID[userid] != nil {
TgUseridToClientID[userid] = append(TgUseridToClientID[userid], clientID)
} else {
TgUseridToClientID[userid] = []string{clientID}
}
}
func UpdateTgToClientID(userid int64, clientID string) {
TgUseridToClientIDMutex.Lock()
defer TgUseridToClientIDMutex.Unlock()
// Detach clientID from any existing userid
for uid, cls := range TgUseridToClientID {
if cls != nil {
filtered := filterStringSlice(cls, clientID)
if len(filtered) > 0 {
TgUseridToClientID[uid] = filtered
} else {
delete(TgUseridToClientID, uid)
}
}
}
// Attach it to the new one
if _, ok := TgUseridToClientID[userid]; ok && TgUseridToClientID[userid] != nil {
TgUseridToClientID[userid] = append(TgUseridToClientID[userid], clientID)
} else {
TgUseridToClientID[userid] = []string{clientID}
}
}
func RemoveTgToClientID(clientID string) {
TgUseridToClientIDMutex.Lock()
defer TgUseridToClientIDMutex.Unlock()
// Detach clientID from any existing userid
for uid, cls := range TgUseridToClientID {
if cls != nil {
filtered := filterStringSlice(cls, clientID)
if len(filtered) > 0 {
TgUseridToClientID[uid] = filtered
} else {
delete(TgUseridToClientID, uid)
}
}
}
}
func filterStringSlice(s []string, excludedStr string) []string {
filtered := s[:0]
for _, v := range s {
if v != excludedStr {
filtered = append(filtered, v)
}
}
return filtered
}
func GetDBUserCRC32(dbuser model.User) uint32 {
var isAdmin byte = 0
if dbuser.Admin {
isAdmin = 1
}
return crc32.ChecksumIEEE(ConcatMultipleSlices([]byte(dbuser.Username), []byte{isAdmin}, []byte(dbuser.PasswordHash), []byte(dbuser.Password)))
}
func ConcatMultipleSlices(slices ...[]byte) []byte {
var totalLen int
for _, s := range slices {
totalLen += len(s)
}
result := make([]byte, totalLen)
var i int
for _, s := range slices {
i += copy(result[i:], s)
}
return result
}