Add WireGuard TUI implementation

- Add Go TUI with bubbletea for WireGuard management
- Implement client CRUD operations with QR code generation
- Add configuration and validation modules
- Install/update scripts for client setup
- Update Makefile to build binaries to bin/ directory
- Add .gitignore for Go projects
This commit is contained in:
Calmcacil
2026-01-12 19:03:35 +01:00
parent 5ac68db854
commit 26120b8bc2
37 changed files with 6330 additions and 97 deletions

View File

@@ -0,0 +1,585 @@
package wireguard
import (
"bytes"
"fmt"
"log"
"os"
"os/exec"
"path/filepath"
"strings"
"github.com/calmcacil/wg-admin/internal/backup"
"github.com/calmcacil/wg-admin/internal/config"
)
// Client represents a WireGuard peer configuration
type Client struct {
Name string // Client name extracted from filename
IPv4 string // IPv4 address from AllowedIPs
IPv6 string // IPv6 address from AllowedIPs
PublicKey string // WireGuard public key
HasPSK bool // Whether PresharedKey is configured
ConfigPath string // Path to the client config file
}
// ParseClientConfig parses a single WireGuard client configuration file
func ParseClientConfig(path string) (*Client, error) {
// Read the file
content, err := os.ReadFile(path)
if err != nil {
return nil, fmt.Errorf("failed to read config file %s: %w", path, err)
}
// Extract client name from filename
base := filepath.Base(path)
name := strings.TrimPrefix(base, "client-")
name = strings.TrimSuffix(name, ".conf")
if name == "" || name == "client-" {
return nil, fmt.Errorf("invalid client filename: %s", base)
}
client := &Client{
Name: name,
ConfigPath: path,
}
// Parse the INI-style config
inPeerSection := false
hasPublicKey := false
for i, line := range strings.Split(string(content), "\n") {
line = strings.TrimSpace(line)
// Skip empty lines and comments
if line == "" || strings.HasPrefix(line, "#") || strings.HasPrefix(line, ";") {
continue
}
// Check for Peer section
if line == "[Peer]" {
inPeerSection = true
continue
}
// Parse key-value pairs within Peer section
if inPeerSection {
parts := strings.SplitN(line, "=", 2)
if len(parts) != 2 {
log.Printf("Warning: malformed line %d in %s: %s", i+1, path, line)
continue
}
key := strings.TrimSpace(parts[0])
value := strings.TrimSpace(parts[1])
switch key {
case "PublicKey":
client.PublicKey = value
hasPublicKey = true
case "PresharedKey":
client.HasPSK = true
case "AllowedIPs":
if err := parseAllowedIPs(client, value); err != nil {
log.Printf("Warning: %v (file: %s, line: %d)", err, path, i+1)
}
}
}
}
// Validate required fields
if !hasPublicKey {
return nil, fmt.Errorf("missing required PublicKey in %s", path)
}
if client.IPv4 == "" && client.IPv6 == "" {
return nil, fmt.Errorf("no valid IP addresses found in AllowedIPs in %s", path)
}
return client, nil
}
// parseAllowedIPs extracts IPv4 and IPv6 addresses from AllowedIPs value
func parseAllowedIPs(client *Client, allowedIPs string) error {
// AllowedIPs format: "ipv4/32, ipv6/128"
addresses := strings.Split(allowedIPs, ",")
for _, addr := range addresses {
addr = strings.TrimSpace(addr)
if addr == "" {
continue
}
// Split IP from CIDR suffix
parts := strings.Split(addr, "/")
if len(parts) != 2 {
return fmt.Errorf("invalid AllowedIP format: %s", addr)
}
ip := strings.TrimSpace(parts[0])
// Detect if IPv4 or IPv6 based on presence of colon
if strings.Contains(ip, ":") {
client.IPv6 = ip
} else {
client.IPv4 = ip
}
}
return nil
}
// ListClients finds and parses all client configurations from /etc/wireguard/conf.d/
func ListClients() ([]Client, error) {
configDir := "/etc/wireguard/conf.d"
// Check if directory exists
if _, err := os.Stat(configDir); os.IsNotExist(err) {
return nil, fmt.Errorf("wireguard config directory does not exist: %s", configDir)
}
// Find all client-*.conf files
pattern := filepath.Join(configDir, "client-*.conf")
matches, err := filepath.Glob(pattern)
if err != nil {
return nil, fmt.Errorf("failed to find client config files: %w", err)
}
if len(matches) == 0 {
return []Client{}, nil // No clients found, return empty slice
}
// Parse each config file
var clients []Client
var parseErrors []string
for _, match := range matches {
client, err := ParseClientConfig(match)
if err != nil {
parseErrors = append(parseErrors, err.Error())
log.Printf("Warning: failed to parse %s: %v", match, err)
continue
}
clients = append(clients, *client)
}
// If all files failed to parse, return an error
if len(clients) == 0 && len(parseErrors) > 0 {
return nil, fmt.Errorf("failed to parse any client configs: %s", strings.Join(parseErrors, "; "))
}
return clients, nil
}
// GetClientConfigContent reads the raw configuration content for a client
func GetClientConfigContent(name string) (string, error) {
configDir := "/etc/wireguard/clients"
configPath := filepath.Join(configDir, fmt.Sprintf("%s.conf", name))
content, err := os.ReadFile(configPath)
if err != nil {
if os.IsNotExist(err) {
return "", fmt.Errorf("client config not found: %s", configPath)
}
return "", fmt.Errorf("failed to read client config %s: %w", configPath, err)
}
return string(content), nil
}
// DeleteClient removes a WireGuard client configuration and associated files
func DeleteClient(name string) error {
// First, find the client config to get public key for removal from interface
configDir := "/etc/wireguard/conf.d"
configPath := filepath.Join(configDir, fmt.Sprintf("client-%s.conf", name))
client, err := ParseClientConfig(configPath)
if err != nil {
return fmt.Errorf("failed to parse client config for deletion: %w", err)
}
log.Printf("Deleting client: %s (public key: %s)", name, client.PublicKey)
// Create backup before deletion
backupPath, err := backup.BackupConfig(fmt.Sprintf("delete-%s", name))
if err != nil {
log.Printf("Warning: failed to create backup before deletion: %v", err)
} else {
log.Printf("Created backup: %s", backupPath)
}
// Remove peer from WireGuard interface using wg command
if err := removePeerFromInterface(client.PublicKey); err != nil {
log.Printf("Warning: failed to remove peer from interface: %v", err)
}
// Remove client config from /etc/wireguard/conf.d/
if err := os.Remove(configPath); err != nil && !os.IsNotExist(err) {
return fmt.Errorf("failed to remove client config %s: %w", configPath, err)
}
log.Printf("Removed client config: %s", configPath)
// Remove client files from /etc/wireguard/clients/
clientsDir := "/etc/wireguard/clients"
clientFile := filepath.Join(clientsDir, fmt.Sprintf("%s.conf", name))
if err := os.Remove(clientFile); err != nil && !os.IsNotExist(err) {
log.Printf("Warning: failed to remove client file %s: %v", clientFile, err)
} else {
log.Printf("Removed client file: %s", clientFile)
}
// Remove QR code PNG if it exists
qrFile := filepath.Join(clientsDir, fmt.Sprintf("%s.png", name))
if err := os.Remove(qrFile); err != nil && !os.IsNotExist(err) {
log.Printf("Warning: failed to remove QR code %s: %v", qrFile, err)
} else {
log.Printf("Removed QR code: %s", qrFile)
}
log.Printf("Successfully deleted client: %s", name)
return nil
}
// removePeerFromInterface removes a peer from the WireGuard interface
func removePeerFromInterface(publicKey string) error {
// Use wg command to remove peer
cmd := exec.Command("wg", "set", "wg0", "peer", publicKey, "remove")
output, err := cmd.CombinedOutput()
if err != nil {
return fmt.Errorf("wg set peer remove failed: %w, output: %s", err, string(output))
}
return nil
}
// CreateClient creates a new WireGuard client configuration
func CreateClient(name, dns string, usePSK bool) error {
log.Printf("Creating client: %s (PSK: %v)", name, usePSK)
// Create backup before creating client
backupPath, err := backup.BackupConfig(fmt.Sprintf("create-%s", name))
if err != nil {
log.Printf("Warning: failed to create backup before creating client: %v", err)
} else {
log.Printf("Created backup: %s", backupPath)
}
// Generate keys
privateKey, publicKey, err := generateKeyPair()
if err != nil {
return fmt.Errorf("failed to generate key pair: %w", err)
}
var psk string
if usePSK {
psk, err = generatePSK()
if err != nil {
return fmt.Errorf("failed to generate PSK: %w", err)
}
}
// Get next available IP addresses
cfg, err := config.LoadConfig()
if err != nil {
return fmt.Errorf("failed to load config: %w", err)
}
clients, err := ListClients()
if err != nil {
return fmt.Errorf("failed to list existing clients: %w", err)
}
ipv4, err := getNextAvailableIP(cfg.VPNIPv4Range, clients)
if err != nil {
return fmt.Errorf("failed to get IPv4 address: %w", err)
}
ipv6, err := getNextAvailableIP(cfg.VPNIPv6Range, clients)
if err != nil {
return fmt.Errorf("failed to get IPv6 address: %w", err)
}
// Create server config
serverConfigPath := fmt.Sprintf("/etc/wireguard/conf.d/client-%s.conf", name)
serverConfig, err := generateServerConfig(name, publicKey, ipv4, ipv6, psk, cfg)
if err != nil {
return fmt.Errorf("failed to generate server config: %w", err)
}
if err := os.WriteFile(serverConfigPath, []byte(serverConfig), 0600); err != nil {
return fmt.Errorf("failed to write server config: %w", err)
}
log.Printf("Created server config: %s", serverConfigPath)
// Create client config
clientsDir := "/etc/wireguard/clients"
if err := os.MkdirAll(clientsDir, 0700); err != nil {
return fmt.Errorf("failed to create clients directory: %w", err)
}
clientConfigPath := filepath.Join(clientsDir, fmt.Sprintf("%s.conf", name))
clientConfig, err := generateClientConfig(name, privateKey, ipv4, ipv6, dns, cfg)
if err != nil {
return fmt.Errorf("failed to generate client config: %w", err)
}
if err := os.WriteFile(clientConfigPath, []byte(clientConfig), 0600); err != nil {
return fmt.Errorf("failed to write client config: %w", err)
}
log.Printf("Created client config: %s", clientConfigPath)
// Add peer to WireGuard interface
if err := addPeerToInterface(publicKey, ipv4, ipv6, psk); err != nil {
log.Printf("Warning: failed to add peer to interface: %v", err)
} else {
log.Printf("Added peer to WireGuard interface")
}
// Generate QR code
qrPath := filepath.Join(clientsDir, fmt.Sprintf("%s.png", name))
if err := generateQRCode(clientConfigPath, qrPath); err != nil {
log.Printf("Warning: failed to generate QR code: %v", err)
} else {
log.Printf("Generated QR code: %s", qrPath)
}
log.Printf("Successfully created client: %s", name)
return nil
}
// generateKeyPair generates a WireGuard private and public key pair
func generateKeyPair() (privateKey, publicKey string, err error) {
// Generate private key
privateKeyBytes, err := exec.Command("wg", "genkey").Output()
if err != nil {
return "", "", fmt.Errorf("wg genkey failed: %w", err)
}
privateKey = strings.TrimSpace(string(privateKeyBytes))
// Derive public key
pubKeyCmd := exec.Command("wg", "pubkey")
pubKeyCmd.Stdin = strings.NewReader(privateKey)
publicKeyBytes, err := pubKeyCmd.Output()
if err != nil {
return "", "", fmt.Errorf("wg pubkey failed: %w", err)
}
publicKey = strings.TrimSpace(string(publicKeyBytes))
return privateKey, publicKey, nil
}
// generatePSK generates a WireGuard preshared key
func generatePSK() (string, error) {
psk, err := exec.Command("wg", "genpsk").Output()
if err != nil {
return "", fmt.Errorf("wg genpsk failed: %w", err)
}
return strings.TrimSpace(string(psk)), nil
}
// getNextAvailableIP finds the next available IP address in the given CIDR range
func getNextAvailableIP(cidr string, existingClients []Client) (string, error) {
// Parse CIDR to get network
parts := strings.Split(cidr, "/")
if len(parts) != 2 {
return "", fmt.Errorf("invalid CIDR format: %s", cidr)
}
network := strings.TrimSpace(parts[0])
// For IPv4, extract base network and assign next available host
if !strings.Contains(network, ":") {
// IPv4: Simple implementation - use .1, .2, etc.
// In production, this would parse the CIDR properly
usedHosts := make(map[string]bool)
for _, client := range existingClients {
if client.IPv4 != "" {
ipParts := strings.Split(client.IPv4, ".")
if len(ipParts) == 4 {
usedHosts[ipParts[3]] = true
}
}
}
// Find next available host (skip 0 and 1 as they may be reserved)
for i := 2; i < 255; i++ {
host := fmt.Sprintf("%d", i)
if !usedHosts[host] {
return fmt.Sprintf("%s.%s", network, host), nil
}
}
return "", fmt.Errorf("no available IPv4 addresses in range: %s", cidr)
}
// IPv6: Similar simplified approach
usedHosts := make(map[string]bool)
for _, client := range existingClients {
if client.IPv6 != "" {
// Extract last segment for IPv6
lastColon := strings.LastIndex(client.IPv6, ":")
if lastColon > 0 {
host := client.IPv6[lastColon+1:]
usedHosts[host] = true
}
}
}
// Find next available host
for i := 1; i < 65536; i++ {
host := fmt.Sprintf("%x", i)
if !usedHosts[host] {
return fmt.Sprintf("%s:%s", network, host), nil
}
}
return "", fmt.Errorf("no available IPv6 addresses in range: %s", cidr)
}
// generateServerConfig generates the server-side configuration for a client
func generateServerConfig(name, publicKey, ipv4, ipv6, psk string, cfg *config.Config) (string, error) {
var builder strings.Builder
builder.WriteString(fmt.Sprintf("# Client: %s\n", name))
builder.WriteString(fmt.Sprintf("[Peer]\n"))
builder.WriteString(fmt.Sprintf("PublicKey = %s\n", publicKey))
allowedIPs := ""
if ipv4 != "" {
allowedIPs = ipv4 + "/32"
}
if ipv6 != "" {
if allowedIPs != "" {
allowedIPs += ", "
}
allowedIPs += ipv6 + "/128"
}
builder.WriteString(fmt.Sprintf("AllowedIPs = %s\n", allowedIPs))
if psk != "" {
builder.WriteString(fmt.Sprintf("PresharedKey = %s\n", psk))
}
return builder.String(), nil
}
// generateClientConfig generates the client-side configuration
func generateClientConfig(name, privateKey, ipv4, ipv6, dns string, cfg *config.Config) (string, error) {
// Get server's public key from the main config
serverConfigPath := "/etc/wireguard/wg0.conf"
serverPublicKey, serverEndpoint, err := getServerConfig(serverConfigPath)
if err != nil {
return "", fmt.Errorf("failed to read server config: %w", err)
}
var builder strings.Builder
builder.WriteString(fmt.Sprintf("# WireGuard client configuration for %s\n", name))
builder.WriteString("[Interface]\n")
builder.WriteString(fmt.Sprintf("PrivateKey = %s\n", privateKey))
builder.WriteString(fmt.Sprintf("Address = %s/32", ipv4))
if ipv6 != "" {
builder.WriteString(fmt.Sprintf(", %s/128", ipv6))
}
builder.WriteString("\n")
builder.WriteString(fmt.Sprintf("DNS = %s\n", dns))
builder.WriteString("\n")
builder.WriteString("[Peer]\n")
builder.WriteString(fmt.Sprintf("PublicKey = %s\n", serverPublicKey))
builder.WriteString(fmt.Sprintf("Endpoint = %s:%d\n", serverEndpoint, cfg.WGPort))
builder.WriteString("AllowedIPs = 0.0.0.0/0, ::/0\n")
return builder.String(), nil
}
// getServerConfig reads the server's public key and endpoint from the main config
func getServerConfig(path string) (publicKey, endpoint string, err error) {
content, err := os.ReadFile(path)
if err != nil {
return "", "", fmt.Errorf("failed to read server config: %w", err)
}
inInterfaceSection := false
for _, line := range strings.Split(string(content), "\n") {
line = strings.TrimSpace(line)
if line == "[Interface]" {
inInterfaceSection = true
continue
}
if line == "[Peer]" {
inInterfaceSection = false
continue
}
if inInterfaceSection {
if strings.HasPrefix(line, "PublicKey") {
parts := strings.SplitN(line, "=", 2)
if len(parts) == 2 {
publicKey = strings.TrimSpace(parts[1])
}
}
}
}
// Use SERVER_DOMAIN as endpoint if available, otherwise fallback
cfg, err := config.LoadConfig()
if err == nil && cfg.ServerDomain != "" {
endpoint = cfg.ServerDomain
} else {
endpoint = "0.0.0.0"
}
return publicKey, endpoint, nil
}
// addPeerToInterface adds a peer to the WireGuard interface
func addPeerToInterface(publicKey, ipv4, ipv6, psk string) error {
args := []string{"set", "wg0", "peer", publicKey}
if ipv4 != "" {
args = append(args, "allowed-ips", ipv4+"/32")
}
if ipv6 != "" {
args = append(args, ipv6+"/128")
}
if psk != "" {
args = append(args, "preshared-key", "/dev/stdin")
cmd := exec.Command("wg", args...)
cmd.Stdin = strings.NewReader(psk)
output, err := cmd.CombinedOutput()
if err != nil {
return fmt.Errorf("wg set peer failed: %w, output: %s", err, string(output))
}
} else {
cmd := exec.Command("wg", args...)
output, err := cmd.CombinedOutput()
if err != nil {
return fmt.Errorf("wg set peer failed: %w, output: %s", err, string(output))
}
}
return nil
}
// generateQRCode generates a QR code from the client config
func generateQRCode(configPath, qrPath string) error {
// Read config file
content, err := os.ReadFile(configPath)
if err != nil {
return fmt.Errorf("failed to read config file: %w", err)
}
// Generate QR code using qrencode or similar
// For now, use a simple approach with qrencode if available
cmd := exec.Command("qrencode", "-o", qrPath, "-t", "PNG")
cmd.Stdin = bytes.NewReader(content)
if err := cmd.Run(); err != nil {
// qrencode not available, try alternative method
return fmt.Errorf("qrencode not available: %w", err)
}
return nil
}

View File

@@ -0,0 +1,106 @@
package wireguard
import (
"fmt"
"log"
"os"
"path/filepath"
"strings"
)
// GenerateServerConfig generates a server-side WireGuard [Peer] configuration
// and writes it to /etc/wireguard/conf.d/client-<name>.conf
func GenerateServerConfig(name, publicKey string, hasPSK bool, ipv4, ipv6, psk string) (string, error) {
var builder strings.Builder
builder.WriteString(fmt.Sprintf("[Peer]\n"))
builder.WriteString(fmt.Sprintf("# %s\n", name))
builder.WriteString(fmt.Sprintf("PublicKey = %s\n", publicKey))
if hasPSK {
builder.WriteString(fmt.Sprintf("PresharedKey = %s\n", psk))
}
builder.WriteString(fmt.Sprintf("AllowedIPs = %s/32", ipv4))
if ipv6 != "" {
builder.WriteString(fmt.Sprintf(", %s/128", ipv6))
}
builder.WriteString("\n")
configContent := builder.String()
configDir := "/etc/wireguard/conf.d"
configPath := filepath.Join(configDir, fmt.Sprintf("client-%s.conf", name))
if err := atomicWrite(configPath, configContent); err != nil {
return "", fmt.Errorf("failed to write server config: %w", err)
}
log.Printf("Generated server config: %s", configPath)
return configPath, nil
}
// GenerateClientConfig generates a client-side WireGuard configuration
// with [Interface] and [Peer] sections and writes it to /etc/wireguard/clients/<name>.conf
func GenerateClientConfig(name, privateKey, ipv4, ipv6, dns, serverPublicKey, endpoint string, port int, hasPSK bool, psk string) (string, error) {
var builder strings.Builder
// [Interface] section
builder.WriteString("[Interface]\n")
builder.WriteString(fmt.Sprintf("PrivateKey = %s\n", privateKey))
builder.WriteString(fmt.Sprintf("Address = %s/24", ipv4))
if ipv6 != "" {
builder.WriteString(fmt.Sprintf(", %s/64", ipv6))
}
builder.WriteString("\n")
builder.WriteString(fmt.Sprintf("DNS = %s\n", dns))
// [Peer] section
builder.WriteString("\n[Peer]\n")
builder.WriteString(fmt.Sprintf("PublicKey = %s\n", serverPublicKey))
if hasPSK {
builder.WriteString(fmt.Sprintf("PresharedKey = %s\n", psk))
}
builder.WriteString(fmt.Sprintf("Endpoint = %s:%d\n", endpoint, port))
builder.WriteString("AllowedIPs = 0.0.0.0/0, ::/0\n")
builder.WriteString("PersistentKeepalive = 25\n")
configContent := builder.String()
clientsDir := "/etc/wireguard/clients"
configPath := filepath.Join(clientsDir, fmt.Sprintf("%s.conf", name))
if err := atomicWrite(configPath, configContent); err != nil {
return "", fmt.Errorf("failed to write client config: %w", err)
}
log.Printf("Generated client config: %s", configPath)
return configPath, nil
}
// atomicWrite writes content to a file atomically
// Uses temp file + rename pattern for atomicity and sets permissions to 0600
func atomicWrite(path, content string) error {
// Ensure parent directory exists
dir := filepath.Dir(path)
if err := os.MkdirAll(dir, 0755); err != nil {
return fmt.Errorf("failed to create directory %s: %w", dir, err)
}
// Write to temp file first
tempPath := path + ".tmp"
if err := os.WriteFile(tempPath, []byte(content), 0600); err != nil {
return fmt.Errorf("failed to write temp file: %w", err)
}
// Atomic rename
if err := os.Rename(tempPath, path); err != nil {
// Clean up temp file if rename fails
_ = os.Remove(tempPath)
return fmt.Errorf("failed to rename temp file: %w", err)
}
return nil
}

243
internal/wireguard/keys.go Normal file
View File

@@ -0,0 +1,243 @@
package wireguard
import (
"encoding/base64"
"fmt"
"log"
"os"
"os/exec"
"path/filepath"
"strings"
"sync"
)
const (
// KeyLength is the standard WireGuard key length in bytes (32 bytes = 256 bits)
KeyLength = 32
// Base64KeyLength is the length of a base64-encoded WireGuard key (44 characters)
Base64KeyLength = 44
// TempDir is the directory for temporary key storage
TempDir = "/tmp/wg-admin"
)
var (
tempKeys = make(map[string]bool)
tempKeysMutex sync.Mutex
)
// KeyPair represents a WireGuard key pair
type KeyPair struct {
PrivateKey string // Base64-encoded private key
PublicKey string // Base64-encoded public key
}
// GeneratePrivateKey generates a new WireGuard private key using wg genkey
func GeneratePrivateKey() (string, error) {
cmd := exec.Command("wg", "genkey")
output, err := cmd.CombinedOutput()
if err != nil {
return "", fmt.Errorf("wg genkey failed: %w, output: %s", err, string(output))
}
key := string(output)
if err := ValidateKey(key); err != nil {
return "", fmt.Errorf("generated invalid private key: %w", err)
}
// Store as temporary key
tempKeyPath := filepath.Join(TempDir, "private.key")
if err := storeTempKey(tempKeyPath, key); err != nil {
log.Printf("Warning: failed to store temporary private key: %v", err)
}
return key, nil
}
// GeneratePublicKey generates a public key from a private key using wg pubkey
func GeneratePublicKey(privateKey string) (string, error) {
if err := ValidateKey(privateKey); err != nil {
return "", fmt.Errorf("invalid private key: %w", err)
}
cmd := exec.Command("wg", "pubkey")
cmd.Stdin = strings.NewReader(privateKey)
output, err := cmd.CombinedOutput()
if err != nil {
return "", fmt.Errorf("wg pubkey failed: %w, output: %s", err, string(output))
}
key := string(output)
if err := ValidateKey(key); err != nil {
return "", fmt.Errorf("generated invalid public key: %w", err)
}
// Store as temporary key
tempKeyPath := filepath.Join(TempDir, "public.key")
if err := storeTempKey(tempKeyPath, key); err != nil {
log.Printf("Warning: failed to store temporary public key: %v", err)
}
return key, nil
}
// GeneratePSK generates a new pre-shared key using wg genpsk
func GeneratePSK() (string, error) {
cmd := exec.Command("wg", "genpsk")
output, err := cmd.CombinedOutput()
if err != nil {
return "", fmt.Errorf("wg genpsk failed: %w, output: %s", err, string(output))
}
key := string(output)
if err := ValidateKey(key); err != nil {
return "", fmt.Errorf("generated invalid PSK: %w", err)
}
// Store as temporary key
tempKeyPath := filepath.Join(TempDir, "psk.key")
if err := storeTempKey(tempKeyPath, key); err != nil {
log.Printf("Warning: failed to store temporary PSK: %v", err)
}
return key, nil
}
// GenerateKeyPair generates a complete WireGuard key pair (private + public)
func GenerateKeyPair() (*KeyPair, error) {
privateKey, err := GeneratePrivateKey()
if err != nil {
return nil, err
}
publicKey, err := GeneratePublicKey(privateKey)
if err != nil {
return nil, fmt.Errorf("failed to generate public key: %w", err)
}
return &KeyPair{
PrivateKey: privateKey,
PublicKey: publicKey,
}, nil
}
// ValidateKey validates that a key is properly formatted (44 base64 characters)
func ValidateKey(key string) error {
// Trim whitespace
key = strings.TrimSpace(key)
// Check length (44 base64 characters for 32 bytes)
if len(key) != Base64KeyLength {
return fmt.Errorf("invalid key length: expected %d characters, got %d", Base64KeyLength, len(key))
}
// Verify it's valid base64
decoded, err := base64.StdEncoding.DecodeString(key)
if err != nil {
return fmt.Errorf("invalid base64 encoding: %w", err)
}
// Verify decoded length is 32 bytes
if len(decoded) != KeyLength {
return fmt.Errorf("invalid decoded key length: expected %d bytes, got %d", KeyLength, len(decoded))
}
return nil
}
// StoreKey atomically writes a key to a file with 0600 permissions
func StoreKey(path string, key string) error {
// Validate key before storing
if err := ValidateKey(key); err != nil {
return fmt.Errorf("invalid key: %w", err)
}
// Trim whitespace
key = strings.TrimSpace(key)
// Create parent directories if needed
dir := filepath.Dir(path)
if err := os.MkdirAll(dir, 0755); err != nil {
return fmt.Errorf("failed to create directory %s: %w", dir, err)
}
// Write to temporary file
tempPath := path + ".tmp"
if err := os.WriteFile(tempPath, []byte(key), 0600); err != nil {
return fmt.Errorf("failed to write temp file %s: %w", tempPath, err)
}
// Atomic rename
if err := os.Rename(tempPath, path); err != nil {
os.Remove(tempPath) // Clean up temp file on failure
return fmt.Errorf("failed to rename temp file to %s: %w", path, err)
}
return nil
}
// LoadKey reads a key from a file and validates it
func LoadKey(path string) (string, error) {
data, err := os.ReadFile(path)
if err != nil {
return "", fmt.Errorf("failed to read key file %s: %w", path, err)
}
key := strings.TrimSpace(string(data))
if err := ValidateKey(key); err != nil {
return "", fmt.Errorf("invalid key in file %s: %w", path, err)
}
return key, nil
}
// storeTempKey stores a temporary key and tracks it for cleanup
func storeTempKey(path string, key string) error {
// Create temp directory if needed
if err := os.MkdirAll(TempDir, 0700); err != nil {
return fmt.Errorf("failed to create temp directory %s: %w", TempDir, err)
}
// Trim whitespace
key = strings.TrimSpace(key)
// Write to file with 0600 permissions
if err := os.WriteFile(path, []byte(key), 0600); err != nil {
return fmt.Errorf("failed to write temp key to %s: %w", path, err)
}
// Track for cleanup
tempKeysMutex.Lock()
tempKeys[path] = true
tempKeysMutex.Unlock()
return nil
}
// CleanupTempKeys removes all temporary keys
func CleanupTempKeys() error {
tempKeysMutex.Lock()
defer tempKeysMutex.Unlock()
var cleanupErrors []string
for path := range tempKeys {
if err := os.Remove(path); err != nil && !os.IsNotExist(err) {
cleanupErrors = append(cleanupErrors, fmt.Sprintf("%s: %v", path, err))
log.Printf("Warning: failed to remove temp key %s: %v", path, err)
}
delete(tempKeys, path)
}
// Also attempt to clean up the temp directory if empty
if _, err := os.ReadDir(TempDir); err == nil {
if err := os.Remove(TempDir); err != nil && !os.IsNotExist(err) {
log.Printf("Warning: failed to remove temp directory %s: %v", TempDir, err)
}
}
if len(cleanupErrors) > 0 {
return fmt.Errorf("cleanup errors: %s", strings.Join(cleanupErrors, "; "))
}
return nil
}

View File

@@ -0,0 +1,204 @@
package wireguard
import (
"bytes"
"fmt"
"os/exec"
"regexp"
"strconv"
"strings"
"time"
)
const (
// StatusConnected indicates a peer has an active connection
StatusConnected = "Connected"
// StatusDisconnected indicates a peer is not connected
StatusDisconnected = "Disconnected"
)
// PeerStatus represents the status of a WireGuard peer
type PeerStatus struct {
PublicKey string `json:"public_key"`
Endpoint string `json:"endpoint"`
AllowedIPs string `json:"allowed_ips"`
LatestHandshake time.Time `json:"latest_handshake"`
TransferRx string `json:"transfer_rx"`
TransferTx string `json:"transfer_tx"`
Status string `json:"status"` // "Connected" or "Disconnected"
}
// GetClientStatus checks if a specific client is connected
// Returns "Connected" if the peer appears in the active peers list, "Disconnected" otherwise
func GetClientStatus(publicKey string) (string, error) {
peers, err := GetAllPeers()
if err != nil {
return StatusDisconnected, fmt.Errorf("failed to get peer status: %w", err)
}
for _, peer := range peers {
if peer.PublicKey == publicKey {
return peer.Status, nil
}
}
return StatusDisconnected, nil
}
// GetAllPeers retrieves all peers with their current status from WireGuard
func GetAllPeers() ([]PeerStatus, error) {
output, err := exec.Command("wg", "show", "wg0").Output()
if err != nil {
return nil, fmt.Errorf("failed to execute wg show: %w", err)
}
return parsePeersOutput(string(output)), nil
}
// parsePeersOutput parses the output of 'wg show wg0' command
func parsePeersOutput(output string) []PeerStatus {
var peers []PeerStatus
var currentPeer *PeerStatus
var handshake string
var transfer string
lines := strings.Split(output, "\n")
peerLineRegex := regexp.MustCompile(`^peer:\s*(.+)$`)
handshakeRegex := regexp.MustCompile(`^latest handshake:\s*(.+)\s+ago$`)
transferRegex := regexp.MustCompile(`^transfer:\s*(.+),\s+(.+)$`)
for _, line := range lines {
line = strings.TrimSpace(line)
// Check for new peer
if match := peerLineRegex.FindStringSubmatch(line); match != nil {
// Save previous peer if exists
if currentPeer != nil {
peers = append(peers, finalizePeerStatus(currentPeer, handshake, transfer))
}
// Start new peer
currentPeer = &PeerStatus{
PublicKey: match[1],
}
handshake = ""
transfer = ""
continue
}
if currentPeer == nil {
continue
}
// Parse endpoint
if strings.HasPrefix(line, "endpoint:") {
currentPeer.Endpoint = strings.TrimSpace(strings.TrimPrefix(line, "endpoint:"))
}
// Parse allowed ips
if strings.HasPrefix(line, "allowed ips:") {
currentPeer.AllowedIPs = strings.TrimSpace(strings.TrimPrefix(line, "allowed ips:"))
}
// Parse latest handshake
if match := handshakeRegex.FindStringSubmatch(line); match != nil {
handshake = match[1]
}
// Parse transfer
if match := transferRegex.FindStringSubmatch(line); match != nil {
transfer = fmt.Sprintf("%s, %s", match[1], match[2])
}
}
// Don't forget the last peer
if currentPeer != nil {
peers = append(peers, finalizePeerStatus(currentPeer, handshake, transfer))
}
return peers
}
// finalizePeerStatus determines the peer's status based on handshake time
func finalizePeerStatus(peer *PeerStatus, handshake string, transfer string) PeerStatus {
peer.TransferRx = ""
peer.TransferTx = ""
// Parse transfer
if transfer != "" {
parts := strings.Split(transfer, ", ")
if len(parts) == 2 {
// Extract received and sent values
rxParts := strings.Fields(parts[0])
if len(rxParts) >= 2 {
peer.TransferRx = strings.Join(rxParts[:2], " ")
}
txParts := strings.Fields(parts[1])
if len(txParts) >= 2 {
peer.TransferTx = strings.Join(txParts[:2], " ")
}
}
}
// Determine status based on handshake
if handshake != "" {
peer.LatestHandshake = parseHandshake(handshake)
// Peer is considered connected if handshake is recent (within 3 minutes)
if time.Since(peer.LatestHandshake) < 3*time.Minute {
peer.Status = StatusConnected
} else {
peer.Status = StatusDisconnected
}
} else {
peer.Status = StatusDisconnected
}
return *peer
}
// parseHandshake converts handshake string to time.Time
func parseHandshake(handshake string) time.Time {
now := time.Now()
parts := strings.Fields(handshake)
for i, part := range parts {
if strings.HasSuffix(part, "second") || strings.HasSuffix(part, "seconds") {
if val, err := strconv.Atoi(strings.TrimSuffix(part, "s")); err == nil {
return now.Add(-time.Duration(val) * time.Second)
}
}
if strings.HasSuffix(part, "minute") || strings.HasSuffix(part, "minutes") {
if val, err := strconv.Atoi(strings.TrimSuffix(part, "s")); err == nil {
return now.Add(-time.Duration(val) * time.Minute)
}
}
if strings.HasSuffix(part, "hour") || strings.HasSuffix(part, "hours") {
if val, err := strconv.Atoi(strings.TrimSuffix(part, "s")); err == nil {
return now.Add(-time.Duration(val) * time.Hour)
}
}
if strings.HasSuffix(part, "day") || strings.HasSuffix(part, "days") {
if val, err := strconv.Atoi(strings.TrimSuffix(part, "s")); err == nil {
return now.Add(-time.Duration(val) * 24 * time.Hour)
}
}
// Handle "ago" word
if i > 0 && (part == "ago" || part == "ago,") {
// Continue parsing time units
}
}
return now.Add(-time.Hour) // Default to 1 hour ago if parsing fails
}
// CheckInterface verifies if the WireGuard interface exists and is accessible
func CheckInterface(interfaceName string) error {
cmd := exec.Command("wg", "show", interfaceName)
var out bytes.Buffer
cmd.Stdout = &out
err := cmd.Run()
if err != nil {
return fmt.Errorf("wireguard interface '%s' not accessible: %w", interfaceName, err)
}
return nil
}

View File

@@ -0,0 +1,27 @@
package wireguard
import (
"time"
tea "github.com/charmbracelet/bubbletea"
)
// StatusTickMsg is sent when it's time to refresh the status
type StatusTickMsg struct{}
// RefreshStatusMsg is sent when manual refresh is triggered
type RefreshStatusMsg struct{}
// Tick returns a tea.Cmd that will send StatusTickMsg at the specified interval
func Tick(interval int) tea.Cmd {
return tea.Tick(time.Duration(interval)*time.Second, func(t time.Time) tea.Msg {
return StatusTickMsg{}
})
}
// ManualRefresh returns a tea.Cmd to trigger immediate status refresh
func ManualRefresh() tea.Cmd {
return func() tea.Msg {
return RefreshStatusMsg{}
}
}