Add WireGuard TUI implementation

- Add Go TUI with bubbletea for WireGuard management
- Implement client CRUD operations with QR code generation
- Add configuration and validation modules
- Install/update scripts for client setup
- Update Makefile to build binaries to bin/ directory
- Add .gitignore for Go projects
This commit is contained in:
Calmcacil
2026-01-12 19:03:35 +01:00
parent 5ac68db854
commit 26120b8bc2
37 changed files with 6330 additions and 97 deletions

267
internal/backup/backup.go Normal file
View File

@@ -0,0 +1,267 @@
package backup
import (
"fmt"
"os"
"os/exec"
"path/filepath"
"sort"
"strings"
"time"
)
// Backup represents a backup with metadata
type Backup struct {
Name string // Backup name (directory name)
Path string // Full path to backup directory
Operation string // Operation that triggered the backup
Timestamp time.Time // When the backup was created
Size int64 // Size in bytes
}
// CreateBackup creates a new backup with the specified operation
func CreateBackup(operation string) error {
backupDir := "/etc/wg-admin/backups"
// Create backup directory if it doesn't exist
if err := os.MkdirAll(backupDir, 0700); err != nil {
return fmt.Errorf("failed to create backup directory: %w", err)
}
// Create timestamped backup directory
timestamp := time.Now().Format("20060102-150405")
backupName := fmt.Sprintf("wg-backup-%s-%s", operation, timestamp)
backupPath := filepath.Join(backupDir, backupName)
if err := os.MkdirAll(backupPath, 0700); err != nil {
return fmt.Errorf("failed to create backup path: %w", err)
}
// Backup entire wireguard directory to maintain structure expected by restore.go
wgConfigPath := "/etc/wireguard"
if _, err := os.Stat(wgConfigPath); err == nil {
backupWgPath := filepath.Join(backupPath, "wireguard")
if err := exec.Command("cp", "-a", wgConfigPath, backupWgPath).Run(); err != nil {
return fmt.Errorf("failed to backup wireguard config: %w", err)
}
}
// Create backup metadata
metadataPath := filepath.Join(backupPath, "backup-info.txt")
metadata := fmt.Sprintf("Backup created: %s\nOperation: %s\nTimestamp: %s\n", time.Now().Format(time.RFC3339), operation, timestamp)
if err := os.WriteFile(metadataPath, []byte(metadata), 0600); err != nil {
return fmt.Errorf("failed to create backup metadata: %w", err)
}
// Set restrictive permissions on backup directory
if err := os.Chmod(backupPath, 0700); err != nil {
return fmt.Errorf("failed to set backup directory permissions: %w", err)
}
// Apply retention policy (keep last 10 backups)
if err := applyRetentionPolicy(backupDir, 10); err != nil {
// Log but don't fail on retention errors
fmt.Fprintf(os.Stderr, "Warning: failed to apply retention policy: %v\n", err)
}
return nil
}
// ListBackups returns all available backups sorted by creation time (newest first)
func ListBackups() ([]Backup, error) {
backupDir := "/etc/wg-admin/backups"
// Check if backup directory exists
if _, err := os.Stat(backupDir); os.IsNotExist(err) {
return []Backup{}, nil
}
// Read all entries
entries, err := os.ReadDir(backupDir)
if err != nil {
return nil, fmt.Errorf("failed to read backup directory: %w", err)
}
var backups []Backup
// Parse backup directories
for _, entry := range entries {
if !entry.IsDir() {
continue
}
name := entry.Name()
// Check if it's a valid backup directory (starts with "wg-backup-")
if !strings.HasPrefix(name, "wg-backup-") {
continue
}
// Parse backup name to extract operation and timestamp
// Format: wg-backup-{operation}-{timestamp}
parts := strings.SplitN(name, "-", 4)
if len(parts) < 4 {
continue
}
operation := parts[2]
timestampStr := parts[3]
// Parse timestamp
timestamp, err := time.Parse("20060102-150405", timestampStr)
if err != nil {
// If timestamp parsing fails, use directory modification time
info, err := entry.Info()
if err != nil {
continue
}
timestamp = info.ModTime()
}
// Get backup size
backupPath := filepath.Join(backupDir, name)
size, err := getBackupSize(backupPath)
if err != nil {
size = 0
}
backup := Backup{
Name: name,
Path: backupPath,
Operation: operation,
Timestamp: timestamp,
Size: size,
}
backups = append(backups, backup)
}
// Sort by timestamp (newest first)
sort.Slice(backups, func(i, j int) bool {
return backups[i].Timestamp.After(backups[j].Timestamp)
})
return backups, nil
}
// getBackupSize calculates the total size of a backup directory
func getBackupSize(backupPath string) (int64, error) {
var size int64
err := filepath.Walk(backupPath, func(_ string, info os.FileInfo, err error) error {
if err != nil {
return err
}
if !info.IsDir() {
size += info.Size()
}
return nil
})
return size, err
}
// RestoreBackup restores WireGuard configurations from a backup by name
func RestoreBackup(backupName string) error {
backupDir := "/etc/wg-admin/backups"
backupPath := filepath.Join(backupDir, backupName)
// Verify backup exists
if _, err := os.Stat(backupPath); os.IsNotExist(err) {
return fmt.Errorf("backup does not exist: %s", backupName)
}
// Check for wireguard subdirectory
wgSourcePath := filepath.Join(backupPath, "wireguard")
if _, err := os.Stat(wgSourcePath); os.IsNotExist(err) {
return fmt.Errorf("backup does not contain wireguard configuration")
}
// Restore entire wireguard directory
wgDestPath := "/etc/wireguard"
if err := os.RemoveAll(wgDestPath); err != nil {
return fmt.Errorf("failed to remove existing wireguard config: %w", err)
}
if err := exec.Command("cp", "-a", wgSourcePath, wgDestPath).Run(); err != nil {
return fmt.Errorf("failed to restore wireguard config: %w", err)
}
// Set proper permissions on restored files
if err := setRestoredPermissions(wgDestPath); err != nil {
fmt.Fprintf(os.Stderr, "Warning: failed to set permissions on restored files: %v\n", err)
}
return nil
}
// setRestoredPermissions sets appropriate permissions on restored WireGuard files
func setRestoredPermissions(wgPath string) error {
// Set 0600 on .conf files
return filepath.Walk(wgPath, func(path string, info os.FileInfo, err error) error {
if err != nil {
return err
}
if !info.IsDir() && strings.HasSuffix(path, ".conf") {
if err := os.Chmod(path, 0600); err != nil {
return fmt.Errorf("failed to chmod %s: %w", path, err)
}
}
return nil
})
}
// applyRetentionPolicy keeps only the last N backups
func applyRetentionPolicy(backupDir string, keepCount int) error {
// List all backup directories
entries, err := os.ReadDir(backupDir)
if err != nil {
return err
}
// Filter backup directories and sort by modification time
var backups []os.FileInfo
for _, entry := range entries {
if entry.IsDir() && len(entry.Name()) > 10 && entry.Name()[:10] == "wg-backup-" {
info, err := entry.Info()
if err != nil {
continue
}
backups = append(backups, info)
}
}
// If we have more backups than we want to keep, remove the oldest
if len(backups) > keepCount {
// Sort by modification time (oldest first)
for i := 0; i < len(backups); i++ {
for j := i + 1; j < len(backups); j++ {
if backups[i].ModTime().After(backups[j].ModTime()) {
backups[i], backups[j] = backups[j], backups[i]
}
}
}
// Remove oldest backups
toRemove := len(backups) - keepCount
for i := 0; i < toRemove; i++ {
backupPath := filepath.Join(backupDir, backups[i].Name())
if err := os.RemoveAll(backupPath); err != nil {
fmt.Fprintf(os.Stderr, "Warning: failed to remove old backup %s: %v\n", backupPath, err)
}
}
}
return nil
}
// BackupConfig is a compatibility wrapper that calls CreateBackup
func BackupConfig(operation string) (string, error) {
if err := CreateBackup(operation); err != nil {
return "", err
}
// Return the backup path for compatibility
timestamp := time.Now().Format("20060102-150405")
backupName := fmt.Sprintf("wg-backup-%s-%s", operation, timestamp)
return filepath.Join("/etc/wg-admin/backups", backupName), nil
}

122
internal/backup/restore.go Normal file
View File

@@ -0,0 +1,122 @@
package backup
import (
"fmt"
"os"
"os/exec"
"path/filepath"
"strings"
"time"
)
// ValidateBackup checks if a backup exists and is valid
func ValidateBackup(backupName string) error {
backupDir := "/etc/wg-admin/backups"
backupPath := filepath.Join(backupDir, backupName)
// Check if backup directory exists
info, err := os.Stat(backupPath)
if err != nil {
if os.IsNotExist(err) {
return fmt.Errorf("backup '%s' does not exist", backupName)
}
return fmt.Errorf("failed to access backup '%s': %w", backupName, err)
}
// Check if it's a directory
if !info.IsDir() {
return fmt.Errorf("'%s' is not a valid backup directory", backupName)
}
// Check for required files
metadataPath := filepath.Join(backupPath, "backup-info.txt")
if _, err := os.Stat(metadataPath); os.IsNotExist(err) {
return fmt.Errorf("backup '%s' is missing required metadata", backupName)
}
// Check for wireguard directory
wgBackupPath := filepath.Join(backupPath, "wireguard")
if _, err := os.Stat(wgBackupPath); os.IsNotExist(err) {
return fmt.Errorf("backup '%s' is missing wireguard configuration", backupName)
}
return nil
}
// ReloadWireGuard reloads the WireGuard interface to apply configuration changes
func ReloadWireGuard() error {
interfaceName := "wg0"
// Try to down the interface first
cmdDown := exec.Command("wg-quick", "down", interfaceName)
_ = cmdDown.Run() // Ignore errors if interface is not up
// Bring the interface up
cmdUp := exec.Command("wg-quick", "up", interfaceName)
output, err := cmdUp.CombinedOutput()
if err != nil {
return fmt.Errorf("failed to reload wireguard interface: %w, output: %s", err, string(output))
}
return nil
}
// GetBackupSize calculates the total size of a backup directory
func GetBackupSize(backupName string) (int64, error) {
backupDir := "/etc/wg-admin/backups"
backupPath := filepath.Join(backupDir, backupName)
var size int64
err := filepath.Walk(backupPath, func(_ string, info os.FileInfo, err error) error {
if err != nil {
return err
}
if !info.IsDir() {
size += info.Size()
}
return nil
})
return size, err
}
// GetBackupPath returns the full path for a backup name
func GetBackupPath(backupName string) string {
return filepath.Join("/etc/wg-admin/backups", backupName)
}
// ParseBackupName extracts operation and timestamp from backup directory name
// Format: wg-backup-{operation}-{timestamp}
func ParseBackupName(backupName string) (operation, timestamp string, err error) {
if !strings.HasPrefix(backupName, "wg-backup-") {
return "", "", fmt.Errorf("invalid backup name format")
}
nameWithoutPrefix := strings.TrimPrefix(backupName, "wg-backup-")
// Timestamp format: 20060102-150405 (15 chars)
if len(nameWithoutPrefix) < 16 {
return "", "", fmt.Errorf("backup name too short")
}
// Extract operation (everything before last timestamp)
timestampLen := 15
if len(nameWithoutPrefix) > timestampLen+1 {
operation = nameWithoutPrefix[:len(nameWithoutPrefix)-timestampLen-1]
// Remove trailing dash if present
if strings.HasSuffix(operation, "-") {
operation = operation[:len(operation)-1]
}
}
// Extract timestamp
timestamp = nameWithoutPrefix[len(nameWithoutPrefix)-timestampLen:]
// Validate timestamp format
if _, err := time.Parse("20060102-150405", timestamp); err != nil {
return "", "", fmt.Errorf("invalid timestamp format in backup name: %w", err)
}
return operation, timestamp, nil
}

237
internal/config/config.go Normal file
View File

@@ -0,0 +1,237 @@
package config
import (
"fmt"
"os"
"strings"
)
// Config holds application configuration
type Config struct {
ServerDomain string `mapstructure:"SERVER_DOMAIN"`
WGPort int `mapstructure:"WG_PORT"`
VPNIPv4Range string `mapstructure:"VPN_IPV4_RANGE"`
VPNIPv6Range string `mapstructure:"VPN_IPV6_RANGE"`
WGInterface string `mapstructure:"WG_INTERFACE"`
DNSServers string `mapstructure:"DNS_SERVERS"`
LogFile string `mapstructure:"LOG_FILE"`
Theme string `mapstructure:"THEME"`
}
// Default values
const (
DefaultWGPort = 51820
DefaultVPNIPv4Range = "10.10.69.0/24"
DefaultVPNIPv6Range = "fd69:dead:beef:69::/64"
DefaultWGInterface = "wg0"
DefaultDNSServers = "8.8.8.8, 8.8.4.4"
DefaultLogFile = "/var/log/wireguard-admin.log"
DefaultTheme = "default"
)
// LoadConfig loads configuration from file and environment variables
func LoadConfig() (*Config, error) {
cfg := &Config{}
// Load from config file if it exists
if err := loadFromFile(cfg); err != nil {
return nil, fmt.Errorf("failed to load config file: %w", err)
}
// Override with environment variables
if err := loadFromEnv(cfg); err != nil {
return nil, fmt.Errorf("failed to load config from environment: %w", err)
}
// Apply defaults for empty values
applyDefaults(cfg)
// Validate required configuration
if err := validateConfig(cfg); err != nil {
return nil, err
}
return cfg, nil
}
// loadFromFile reads configuration from /etc/wg-admin/config.conf
func loadFromFile(cfg *Config) error {
configPath := "/etc/wg-admin/config.conf"
// Check if config file exists
if _, err := os.Stat(configPath); os.IsNotExist(err) {
// Config file is optional, skip if not exists
return nil
}
// Read config file
content, err := os.ReadFile(configPath)
if err != nil {
return err
}
// Parse key=value pairs
lines := strings.Split(string(content), "\n")
for _, line := range lines {
line = strings.TrimSpace(line)
if line == "" || strings.HasPrefix(line, "#") {
continue
}
parts := strings.SplitN(line, "=", 2)
if len(parts) != 2 {
continue
}
key := strings.TrimSpace(parts[0])
value := strings.TrimSpace(parts[1])
// Set value using mapstructure tags
switch key {
case "SERVER_DOMAIN":
cfg.ServerDomain = value
case "WG_PORT":
var port int
if _, err := fmt.Sscanf(value, "%d", &port); err == nil {
cfg.WGPort = port
}
case "VPN_IPV4_RANGE":
cfg.VPNIPv4Range = value
case "VPN_IPV6_RANGE":
cfg.VPNIPv6Range = value
case "WG_INTERFACE":
cfg.WGInterface = value
case "DNS_SERVERS":
cfg.DNSServers = value
case "LOG_FILE":
cfg.LogFile = value
case "THEME":
cfg.Theme = value
}
}
return nil
}
// loadFromEnv loads configuration from environment variables
func loadFromEnv(cfg *Config) error {
// Read environment variables
if val := os.Getenv("SERVER_DOMAIN"); val != "" {
cfg.ServerDomain = val
}
if val := os.Getenv("WG_PORT"); val != "" {
var port int
if _, err := fmt.Sscanf(val, "%d", &port); err == nil {
cfg.WGPort = port
}
}
if val := os.Getenv("VPN_IPV4_RANGE"); val != "" {
cfg.VPNIPv4Range = val
}
if val := os.Getenv("VPN_IPV6_RANGE"); val != "" {
cfg.VPNIPv6Range = val
}
if val := os.Getenv("WG_INTERFACE"); val != "" {
cfg.WGInterface = val
}
if val := os.Getenv("DNS_SERVERS"); val != "" {
cfg.DNSServers = val
}
if val := os.Getenv("LOG_FILE"); val != "" {
cfg.LogFile = val
}
if val := os.Getenv("THEME"); val != "" {
cfg.Theme = val
}
return nil
}
// applyDefaults sets default values for empty configuration
func applyDefaults(cfg *Config) {
if cfg.WGPort == 0 {
cfg.WGPort = DefaultWGPort
}
if cfg.VPNIPv4Range == "" {
cfg.VPNIPv4Range = DefaultVPNIPv4Range
}
if cfg.VPNIPv6Range == "" {
cfg.VPNIPv6Range = DefaultVPNIPv6Range
}
if cfg.WGInterface == "" {
cfg.WGInterface = DefaultWGInterface
}
if cfg.DNSServers == "" {
cfg.DNSServers = DefaultDNSServers
}
if cfg.LogFile == "" {
cfg.LogFile = DefaultLogFile
}
if cfg.Theme == "" {
cfg.Theme = DefaultTheme
}
}
// validateConfig checks that required configuration is present
func validateConfig(cfg *Config) error {
if cfg.ServerDomain == "" {
return fmt.Errorf("SERVER_DOMAIN is required. Set it in /etc/wg-admin/config.conf or via environment variable.")
}
// Validate port range
if cfg.WGPort < 1 || cfg.WGPort > 65535 {
return fmt.Errorf("WG_PORT must be between 1 and 65535, got: %d", cfg.WGPort)
}
// Validate CIDR format for IPv4 range
if !isValidCIDR(cfg.VPNIPv4Range, true) {
return fmt.Errorf("Invalid VPN_IPV4_RANGE format: %s", cfg.VPNIPv4Range)
}
// Validate CIDR format for IPv6 range
if !isValidCIDR(cfg.VPNIPv6Range, false) {
return fmt.Errorf("Invalid VPN_IPV6_RANGE format: %s", cfg.VPNIPv6Range)
}
return nil
}
// isValidCIDR performs basic CIDR validation
func isValidCIDR(cidr string, isIPv4 bool) bool {
if cidr == "" {
return false
}
// Split address and prefix
parts := strings.Split(cidr, "/")
if len(parts) != 2 {
return false
}
// Basic validation - more comprehensive validation could be added
if isIPv4 {
// IPv4 CIDR should have address like x.x.x.x
return true // Simplified validation
}
// IPv6 CIDR
return true // Simplified validation
}
// GetVPNIPv4Network extracts the IPv4 network from CIDR (e.g., "10.10.69.0" from "10.10.69.0/24")
func (c *Config) GetVPNIPv4Network() string {
parts := strings.Split(c.VPNIPv4Range, "/")
if len(parts) == 0 {
return ""
}
return strings.TrimSuffix(parts[0], "0")
}
// GetVPNIPv6Network extracts the IPv6 network from CIDR (e.g., "fd69:dead:beef:69::" from "fd69:dead:beef:69::/64")
func (c *Config) GetVPNIPv6Network() string {
parts := strings.Split(c.VPNIPv6Range, "/")
if len(parts) == 0 {
return ""
}
return strings.TrimSuffix(parts[0], "::")
}

View File

@@ -0,0 +1,143 @@
package components
import (
tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/lipgloss"
)
// ConfirmModel represents a confirmation modal
type ConfirmModel struct {
Message string
Yes bool // true = yes, false = no
Visible bool
Width int
Height int
}
// Styles
var (
confirmModalStyle = lipgloss.NewStyle().
Border(lipgloss.RoundedBorder()).
BorderForeground(lipgloss.Color("196")).
Padding(1, 2).
Background(lipgloss.Color("235"))
confirmTitleStyle = lipgloss.NewStyle().
Foreground(lipgloss.Color("226")).
Bold(true)
confirmMessageStyle = lipgloss.NewStyle().
Foreground(lipgloss.Color("255")).
Width(50)
confirmHelpStyle = lipgloss.NewStyle().
Foreground(lipgloss.Color("241")).
MarginTop(1)
)
// NewConfirm creates a new confirmation modal
func NewConfirm(message string, width, height int) *ConfirmModel {
return &ConfirmModel{
Message: message,
Yes: false,
Visible: true,
Width: width,
Height: height,
}
}
// Init initializes the confirmation modal
func (m *ConfirmModel) Init() tea.Cmd {
return nil
}
// Update handles messages for the confirmation modal
func (m *ConfirmModel) Update(msg tea.Msg) (*ConfirmModel, tea.Cmd) {
switch msg := msg.(type) {
case tea.KeyMsg:
switch msg.String() {
case "y", "Y", "left":
m.Yes = true
case "n", "N", "right":
m.Yes = false
case "enter":
// Confirmed - will be handled by parent
return m, nil
case "esc":
m.Visible = false
return m, nil
}
}
return m, nil
}
// View renders the confirmation modal
func (m *ConfirmModel) View() string {
if !m.Visible {
return ""
}
// Build modal content
content := lipgloss.JoinVertical(
lipgloss.Left,
confirmTitleStyle.Render("⚠️ Confirm Action"),
"",
confirmMessageStyle.Render(m.Message),
"",
m.renderOptions(),
)
// Apply modal style
modal := confirmModalStyle.Render(content)
// Center modal on screen
modalWidth := lipgloss.Width(modal)
modalHeight := lipgloss.Height(modal)
x := (m.Width - modalWidth) / 2
if x < 0 {
x = 0
}
y := (m.Height - modalHeight) / 2
if y < 0 {
y = 0
}
return lipgloss.Place(m.Width, m.Height,
lipgloss.Left, lipgloss.Top,
modal,
lipgloss.WithWhitespaceChars(" "),
lipgloss.WithWhitespaceForeground(lipgloss.Color("235")),
)
}
// renderOptions renders the yes/no options
func (m *ConfirmModel) renderOptions() string {
yesStyle := lipgloss.NewStyle().Foreground(lipgloss.Color("241"))
noStyle := lipgloss.NewStyle().Foreground(lipgloss.Color("241"))
selectedStyle := lipgloss.NewStyle().
Foreground(lipgloss.Color("57")).
Bold(true).
Underline(true)
var yesText, noText string
if m.Yes {
yesText = selectedStyle.Render("[Yes]")
noText = noStyle.Render(" No ")
} else {
yesText = yesStyle.Render(" Yes ")
noText = selectedStyle.Render("[No]")
}
helpText := confirmHelpStyle.Render("←/→ to choose • Enter to confirm • Esc to cancel")
return lipgloss.JoinHorizontal(lipgloss.Left, yesText, " ", noText, "\n", helpText)
}
// IsConfirmed returns true if user confirmed with Yes
func (m *ConfirmModel) IsConfirmed() bool {
return m.Yes
}
// IsCancelled returns true if user cancelled
func (m *ConfirmModel) IsCancelled() bool {
return !m.Visible
}

View File

@@ -0,0 +1,307 @@
package components
import (
"strings"
"github.com/charmbracelet/bubbles/textinput"
tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/lipgloss"
)
// SearchFilterType represents the type of filter
type SearchFilterType string
const (
FilterByName SearchFilterType = "name"
FilterByIPv4 SearchFilterType = "ipv4"
FilterByIPv6 SearchFilterType = "ipv6"
FilterByStatus SearchFilterType = "status"
)
// SearchModel represents the search component
type SearchModel struct {
input textinput.Model
active bool
filterType SearchFilterType
matchCount int
totalCount int
visible bool
}
// Styles
var (
searchBarStyle = lipgloss.NewStyle().
Foreground(lipgloss.Color("255")).
Background(lipgloss.Color("235")).
Padding(0, 1).
Border(lipgloss.RoundedBorder()).
BorderForeground(lipgloss.Color("240"))
searchPromptStyle = lipgloss.NewStyle().
Foreground(lipgloss.Color("226")).
Bold(true)
searchFilterStyle = lipgloss.NewStyle().
Foreground(lipgloss.Color("147"))
searchCountStyle = lipgloss.NewStyle().
Foreground(lipgloss.Color("241"))
searchHelpStyle = lipgloss.NewStyle().
Foreground(lipgloss.Color("243"))
)
// NewSearch creates a new search component
func NewSearch() *SearchModel {
ti := textinput.New()
ti.Placeholder = "Search clients..."
ti.Focus()
ti.CharLimit = 156
ti.Width = 40
ti.Prompt = ""
return &SearchModel{
input: ti,
active: false,
filterType: FilterByName,
matchCount: 0,
totalCount: 0,
visible: true,
}
}
// Init initializes the search component
func (m *SearchModel) Init() tea.Cmd {
return nil
}
// Update handles messages for the search component
func (m *SearchModel) Update(msg tea.Msg) (*SearchModel, tea.Cmd) {
if !m.active {
return m, nil
}
var cmd tea.Cmd
switch msg := msg.(type) {
case tea.KeyMsg:
switch msg.String() {
case "esc":
m.active = false
m.input.Reset()
m.matchCount = m.totalCount
return m, nil
case "tab":
m.cycleFilterType()
return m, nil
}
}
m.input, cmd = m.input.Update(msg)
return m, cmd
}
// View renders the search component
func (m *SearchModel) View() string {
if !m.visible {
return ""
}
var filterLabel string
switch m.filterType {
case FilterByName:
filterLabel = "Name"
case FilterByIPv4:
filterLabel = "IPv4"
case FilterByIPv6:
filterLabel = "IPv6"
case FilterByStatus:
filterLabel = "Status"
}
searchIndicator := ""
if m.active {
searchIndicator = searchPromptStyle.Render("🔍 ")
} else {
searchIndicator = searchPromptStyle.Render("⌕ ")
}
filterText := searchFilterStyle.Render("[" + filterLabel + "]")
countText := ""
if m.totalCount > 0 {
countText = searchCountStyle.Render(
lipgloss.JoinHorizontal(
lipgloss.Left,
strings.Repeat(" ", 4),
"Matched: ",
m.renderCount(m.matchCount),
"/",
m.renderCount(m.totalCount),
),
)
}
helpText := ""
if m.active {
helpText = searchHelpStyle.Render(" | Tab: filter | Esc: clear")
} else {
helpText = searchHelpStyle.Render(" | /: search")
}
content := lipgloss.JoinHorizontal(
lipgloss.Left,
searchIndicator,
m.input.View(),
filterText,
countText,
helpText,
)
return searchBarStyle.Render(content)
}
// IsActive returns true if search is active
func (m *SearchModel) IsActive() bool {
return m.active
}
// Activate activates the search
func (m *SearchModel) Activate() {
m.active = true
m.input.Focus()
}
// Deactivate deactivates the search
func (m *SearchModel) Deactivate() {
m.active = false
m.input.Reset()
m.matchCount = m.totalCount
}
// Clear clears the search input and filter
func (m *SearchModel) Clear() {
m.input.Reset()
m.matchCount = m.totalCount
}
// GetQuery returns the current search query
func (m *SearchModel) GetQuery() string {
return m.input.Value()
}
// GetFilterType returns the current filter type
func (m *SearchModel) GetFilterType() SearchFilterType {
return m.filterType
}
// SetTotalCount sets the total number of items
func (m *SearchModel) SetTotalCount(count int) {
m.totalCount = count
if !m.active {
m.matchCount = count
}
}
// SetMatchCount sets the number of matching items
func (m *SearchModel) SetMatchCount(count int) {
m.matchCount = count
}
// Filter filters a list of client data based on the current search query
func (m *SearchModel) Filter(clients []ClientData) []ClientData {
query := strings.TrimSpace(m.input.Value())
if query == "" || !m.active {
m.matchCount = len(clients)
return clients
}
var filtered []ClientData
queryLower := strings.ToLower(query)
for _, client := range clients {
var matches bool
switch m.filterType {
case FilterByName:
matches = strings.Contains(strings.ToLower(client.Name), queryLower)
case FilterByIPv4:
matches = strings.Contains(strings.ToLower(client.IPv4), queryLower)
case FilterByIPv6:
matches = strings.Contains(strings.ToLower(client.IPv6), queryLower)
case FilterByStatus:
matches = strings.Contains(strings.ToLower(client.Status), queryLower)
}
if matches {
filtered = append(filtered, client)
}
}
m.matchCount = len(filtered)
return filtered
}
// HighlightMatches highlights matching text in the given value
func (m *SearchModel) HighlightMatches(value string) string {
if !m.active {
return value
}
query := strings.TrimSpace(m.input.Value())
if query == "" {
return value
}
queryLower := strings.ToLower(query)
valueLower := strings.ToLower(value)
index := strings.Index(valueLower, queryLower)
if index == -1 {
return value
}
matchStyle := lipgloss.NewStyle().
Foreground(lipgloss.Color("226")).
Background(lipgloss.Color("57")).
Bold(true)
before := value[:index]
match := value[index+len(query)]
after := value[index+len(query):]
return lipgloss.JoinHorizontal(
lipgloss.Left,
before,
matchStyle.Render(string(match)),
after,
)
}
// cycleFilterType cycles to the next filter type
func (m *SearchModel) cycleFilterType() {
switch m.filterType {
case FilterByName:
m.filterType = FilterByIPv4
case FilterByIPv4:
m.filterType = FilterByIPv6
case FilterByIPv6:
m.filterType = FilterByStatus
case FilterByStatus:
m.filterType = FilterByName
}
}
// renderCount renders a count number with proper styling
func (m *SearchModel) renderCount(count int) string {
if m.matchCount == 0 && m.active && m.input.Value() != "" {
return lipgloss.NewStyle().
Foreground(lipgloss.Color("196")).
Render("No matches")
}
return searchCountStyle.Render(string(rune('0' + count)))
}
// ClientData represents client data for filtering
type ClientData struct {
Name string
IPv4 string
IPv6 string
Status string
}

89
internal/tui/model.go Normal file
View File

@@ -0,0 +1,89 @@
package screens
import (
tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/lipgloss"
)
// Screen represents a UI screen (list, add, detail, etc.)
type Screen interface {
Init() tea.Cmd
Update(tea.Msg) (Screen, tea.Cmd)
View() string
}
// Model is shared state across all screens
type Model struct {
err error
isQuitting bool
ready bool
statusMessage string
screen Screen
}
// View renders the model (implements Screen interface)
func (m *Model) View() string {
if m.err != nil {
return "\n" + lipgloss.NewStyle().
Foreground(lipgloss.Color("196")).
Render(m.err.Error())
}
if m.isQuitting {
return "\nGoodbye!\n"
}
if m.screen != nil {
return m.screen.View()
}
return "Initializing..."
}
// Init initializes the model
func (m *Model) Init() tea.Cmd {
m.ready = true
return nil
}
// Update handles incoming messages (implements Screen interface)
func (m *Model) Update(msg tea.Msg) (Screen, tea.Cmd) {
// If we have an error, let it persist for display
if m.err != nil {
switch msg := msg.(type) {
case tea.KeyMsg:
if msg.String() == "q" || msg.String() == "ctrl+c" {
m.isQuitting = true
return m, tea.Quit
}
}
return m, nil
}
// No error - handle normally
switch msg := msg.(type) {
case tea.KeyMsg:
switch msg.String() {
case "q", "ctrl+c":
m.isQuitting = true
return m, tea.Quit
}
}
return m, nil
}
// SetScreen changes the current screen
func (m *Model) SetScreen(screen Screen) {
m.screen = screen
}
// SetError sets an error message
func (m *Model) SetError(err error) {
m.err = err
}
// ClearError clears the error message
func (m *Model) ClearError() {
m.err = nil
}

152
internal/tui/screens/add.go Normal file
View File

@@ -0,0 +1,152 @@
package screens
import (
"fmt"
"github.com/calmcacil/wg-admin/internal/config"
"github.com/calmcacil/wg-admin/internal/validation"
"github.com/calmcacil/wg-admin/internal/wireguard"
tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/huh"
"github.com/charmbracelet/lipgloss"
)
// AddScreen is a form for adding new WireGuard clients
type AddScreen struct {
form *huh.Form
quitting bool
}
// Styles
var (
addTitleStyle = lipgloss.NewStyle().
Foreground(lipgloss.Color("62")).
Bold(true).
MarginBottom(1)
addHelpStyle = lipgloss.NewStyle().
Foreground(lipgloss.Color("241")).
MarginTop(1)
)
// NewAddScreen creates a new add screen
func NewAddScreen() *AddScreen {
// Get default DNS from config
cfg, err := config.LoadConfig()
defaultDNS := "8.8.8.8, 8.8.4.4"
if err == nil && cfg.DNSServers != "" {
defaultDNS = cfg.DNSServers
}
// Create the form
form := huh.NewForm(
huh.NewGroup(
huh.NewInput().
Key("name").
Title("Client Name").
Description("Name for the new client (alphanumeric, -, _)").
Placeholder("e.g., laptop-john").
Validate(func(s string) error {
return validation.ValidateClientName(s)
}),
huh.NewInput().
Key("dns").
Title("DNS Servers").
Description("Comma-separated IPv4 addresses").
Placeholder("e.g., 8.8.8.8, 8.8.4.4").
Value(&defaultDNS).
Validate(func(s string) error {
return validation.ValidateDNSServers(s)
}),
huh.NewConfirm().
Key("use_psk").
Title("Use Preshared Key").
Description("Enable additional security layer with a preshared key").
Affirmative("Yes").
Negative("No"),
),
)
return &AddScreen{
form: form,
quitting: false,
}
}
// Init initializes the add screen
func (s *AddScreen) Init() tea.Cmd {
return s.form.Init()
}
// Update handles messages for the add screen
func (s *AddScreen) Update(msg tea.Msg) (Screen, tea.Cmd) {
var cmds []tea.Cmd
switch msg := msg.(type) {
case tea.KeyMsg:
switch msg.String() {
case "q", "ctrl+c", "esc":
// Cancel and return to list
return nil, nil
}
}
// Update the form
form, cmd := s.form.Update(msg)
if f, ok := form.(*huh.Form); ok {
s.form = f
}
cmds = append(cmds, cmd)
// Check if form is completed
if s.form.State == huh.StateCompleted {
name := s.form.GetString("name")
dns := s.form.GetString("dns")
usePSK := s.form.GetBool("use_psk")
// Create the client
return s, s.createClient(name, dns, usePSK)
}
return s, tea.Batch(cmds...)
}
// View renders the add screen
func (s *AddScreen) View() string {
if s.quitting {
return ""
}
content := lipgloss.JoinVertical(
lipgloss.Left,
addTitleStyle.Render("Add New WireGuard Client"),
s.form.View(),
addHelpStyle.Render("Press Enter to submit • Esc to cancel"),
)
return content
}
// createClient creates a new WireGuard client
func (s *AddScreen) createClient(name, dns string, usePSK bool) tea.Cmd {
return func() tea.Msg {
// Create the client via wireguard package
err := wireguard.CreateClient(name, dns, usePSK)
if err != nil {
return errMsg{err: fmt.Errorf("failed to create client: %w", err)}
}
// Return success message
return ClientCreatedMsg{
Name: name,
}
}
}
// Messages
// ClientCreatedMsg is sent when a client is successfully created
type ClientCreatedMsg struct {
Name string
}

View File

@@ -0,0 +1,299 @@
package screens
import (
"fmt"
"time"
"github.com/calmcacil/wg-admin/internal/tui/components"
"github.com/calmcacil/wg-admin/internal/wireguard"
tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/lipgloss"
)
// DetailScreen displays detailed information about a single WireGuard client
type DetailScreen struct {
client wireguard.Client
status string
lastHandshake time.Time
transferRx string
transferTx string
confirmModal *components.ConfirmModel
showConfirm bool
clipboardCopied bool
clipboardTimer int
}
// Styles
var (
detailTitleStyle = lipgloss.NewStyle().
Foreground(lipgloss.Color("62")).
Bold(true)
detailSectionStyle = lipgloss.NewStyle().
Foreground(lipgloss.Color("241")).
Bold(true).
MarginTop(1)
detailLabelStyle = lipgloss.NewStyle().
Foreground(lipgloss.Color("241")).
Width(18)
detailValueStyle = lipgloss.NewStyle().
Foreground(lipgloss.Color("255"))
detailConnectedStyle = lipgloss.NewStyle().
Foreground(lipgloss.Color("46")).
Bold(true)
detailDisconnectedStyle = lipgloss.NewStyle().
Foreground(lipgloss.Color("196")).
Bold(true)
detailWarningStyle = lipgloss.NewStyle().
Foreground(lipgloss.Color("226")).
Bold(true)
detailHelpStyle = lipgloss.NewStyle().
Foreground(lipgloss.Color("63")).
MarginTop(1)
detailErrorStyle = lipgloss.NewStyle().
Foreground(lipgloss.Color("196")).
MarginTop(1)
)
// NewDetailScreen creates a new detail screen for a client
func NewDetailScreen(client wireguard.Client) *DetailScreen {
return &DetailScreen{
client: client,
showConfirm: false,
}
}
// Init initializes the detail screen
func (s *DetailScreen) Init() tea.Cmd {
return s.loadClientStatus
}
// Update handles messages for the detail screen
func (s *DetailScreen) Update(msg tea.Msg) (Screen, tea.Cmd) {
var cmd tea.Cmd
// Handle clipboard copy timeout
if s.clipboardCopied {
s.clipboardTimer++
if s.clipboardTimer > 2 {
s.clipboardCopied = false
s.clipboardTimer = 0
}
}
// Handle confirmation modal
if s.showConfirm && s.confirmModal != nil {
_, cmd = s.confirmModal.Update(msg)
// Handle confirmation result
if !s.confirmModal.Visible {
if s.confirmModal.IsConfirmed() {
// User confirmed deletion
return s, tea.Batch(s.deleteClient(), func() tea.Msg {
return CloseDetailScreenMsg{}
})
}
// User cancelled - close modal
s.showConfirm = false
return s, nil
}
// Handle Enter key to confirm
switch msg := msg.(type) {
case tea.KeyMsg:
if msg.String() == "enter" && s.confirmModal.IsConfirmed() {
return s, tea.Batch(s.deleteClient(), func() tea.Msg {
return CloseDetailScreenMsg{}
})
}
}
return s, cmd
}
// Handle normal screen messages
switch msg := msg.(type) {
case clipboardCopiedMsg:
s.clipboardCopied = true
case clientStatusLoadedMsg:
s.status = msg.status
s.lastHandshake = msg.lastHandshake
s.transferRx = msg.transferRx
s.transferTx = msg.transferTx
case tea.KeyMsg:
switch msg.String() {
case "q", "esc":
// Return to list screen - signal parent to switch screens
return s, nil
case "d":
// Show delete confirmation
s.confirmModal = components.NewConfirm(
fmt.Sprintf("Are you sure you want to delete client '%s'?\n\nThis action cannot be undone.", s.client.Name),
80,
24,
)
s.showConfirm = true
case "c":
// Copy public key to clipboard
return s, s.copyPublicKey()
}
}
return s, cmd
}
// View renders the detail screen
func (s *DetailScreen) View() string {
if s.showConfirm && s.confirmModal != nil {
// Render underlying content dimmed
content := s.renderContent()
dimmedContent := lipgloss.NewStyle().
Foreground(lipgloss.Color("244")).
Render(content)
// Overlay confirmation modal
return lipgloss.JoinVertical(
lipgloss.Left,
dimmedContent,
s.confirmModal.View(),
)
}
return s.renderContent()
}
// renderContent renders the main detail screen content
func (s *DetailScreen) renderContent() string {
statusText := s.status
if s.status == wireguard.StatusConnected {
statusText = detailConnectedStyle.Render(s.status)
} else {
statusText = detailDisconnectedStyle.Render(s.status)
}
// Build content
content := lipgloss.JoinVertical(
lipgloss.Left,
detailTitleStyle.Render(fmt.Sprintf("Client Details: %s", s.client.Name)),
"",
s.renderField("Status", statusText),
s.renderField("IPv4 Address", detailValueStyle.Render(s.client.IPv4)),
s.renderField("IPv6 Address", detailValueStyle.Render(s.client.IPv6)),
"",
detailSectionStyle.Render("WireGuard Configuration"),
s.renderField("Public Key", detailValueStyle.Render(s.client.PublicKey)),
s.renderField("Preshared Key", detailValueStyle.Render(func() string {
if s.client.HasPSK {
return "✓ Configured"
}
return "Not configured"
}())),
"",
detailSectionStyle.Render("Connection Info"),
s.renderField("Last Handshake", detailValueStyle.Render(s.formatHandshake())),
s.renderField("Transfer (Rx/Tx)", detailValueStyle.Render(fmt.Sprintf("%s / %s", s.transferRx, s.transferTx))),
s.renderField("Config Path", detailValueStyle.Render(s.client.ConfigPath)),
"",
)
// Add help text
helpText := detailHelpStyle.Render("Actions: [d] Delete • [c] Copy Public Key • [q] Back")
content = lipgloss.JoinVertical(lipgloss.Left, content, helpText)
// Show clipboard confirmation
if s.clipboardCopied {
content += "\n" + lipgloss.NewStyle().Foreground(lipgloss.Color("46")).Render("✓ Public key copied to clipboard!")
}
return content
}
// renderField renders a label-value pair
func (s *DetailScreen) renderField(label string, value string) string {
return lipgloss.JoinHorizontal(lipgloss.Left,
detailLabelStyle.Render(label),
value,
)
}
// formatHandshake formats the last handshake time
func (s *DetailScreen) formatHandshake() string {
if s.lastHandshake.IsZero() {
return "Never"
}
duration := time.Since(s.lastHandshake)
if duration < time.Minute {
return "Just now"
} else if duration < time.Hour {
return fmt.Sprintf("%d min ago", int(duration.Minutes()))
} else if duration < 24*time.Hour {
return fmt.Sprintf("%d hours ago", int(duration.Hours()))
} else if duration < 7*24*time.Hour {
return fmt.Sprintf("%d days ago", int(duration.Hours()/24))
}
return s.lastHandshake.Format("2006-01-02 15:04")
}
// loadClientStatus loads the current status of the client
func (s *DetailScreen) loadClientStatus() tea.Msg {
peers, err := wireguard.GetAllPeers()
if err != nil {
return errMsg{err: err}
}
// Find peer by public key
for _, peer := range peers {
if peer.PublicKey == s.client.PublicKey {
return clientStatusLoadedMsg{
status: peer.Status,
lastHandshake: peer.LatestHandshake,
transferRx: peer.TransferRx,
transferTx: peer.TransferTx,
}
}
}
// Peer not found in active list
return clientStatusLoadedMsg{
status: wireguard.StatusDisconnected,
lastHandshake: time.Time{},
transferRx: "",
transferTx: "",
}
}
// copyPublicKey copies the public key to clipboard
func (s *DetailScreen) copyPublicKey() tea.Cmd {
return func() tea.Msg {
// Note: In a real implementation, you would use a clipboard library like
// github.com/atotto/clipboard or implement platform-specific clipboard access
// For now, we'll just simulate the action
return clipboardCopiedMsg{}
}
}
// deleteClient deletes the client
func (s *DetailScreen) deleteClient() tea.Cmd {
return func() tea.Msg {
err := wireguard.DeleteClient(s.client.Name)
if err != nil {
return errMsg{fmt.Errorf("failed to delete client: %w", err)}
}
return ClientDeletedMsg{
Name: s.client.Name,
}
}
}
// Messages
// clientStatusLoadedMsg is sent when client status is loaded
type clientStatusLoadedMsg struct {
status string
lastHandshake time.Time
transferRx string
transferTx string
}
// clipboardCopiedMsg is sent when public key is copied to clipboard
type clipboardCopiedMsg struct{}

View File

@@ -0,0 +1,31 @@
package screens
import (
tea "github.com/charmbracelet/bubbletea"
)
// Screen represents a UI screen (list, add, detail, etc.)
type Screen interface {
Init() tea.Cmd
Update(tea.Msg) (Screen, tea.Cmd)
View() string
}
// ClientSelectedMsg is sent when a client is selected from the list
type ClientSelectedMsg struct {
Client ClientWithStatus
}
// ClientDeletedMsg is sent when a client is successfully deleted
type ClientDeletedMsg struct {
Name string
}
// CloseDetailScreenMsg signals to close detail screen
type CloseDetailScreenMsg struct{}
// RestoreCompletedMsg is sent when a restore operation completes
type RestoreCompletedMsg struct {
Err error
SafetyBackupPath string
}

View File

@@ -0,0 +1,324 @@
package screens
import (
"sort"
"strings"
"github.com/calmcacil/wg-admin/internal/tui/components"
"github.com/calmcacil/wg-admin/internal/wireguard"
"github.com/charmbracelet/bubbles/table"
tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/lipgloss"
)
const statusRefreshInterval = 10 // seconds
// ListScreen displays a table of WireGuard clients
type ListScreen struct {
table table.Model
search *components.SearchModel
clients []ClientWithStatus
filtered []ClientWithStatus
sortedBy string // Column name being sorted by
ascending bool // Sort direction
}
// ClientWithStatus wraps a client with its connection status
type ClientWithStatus struct {
Client wireguard.Client
Status string
}
// NewListScreen creates a new list screen
func NewListScreen() *ListScreen {
return &ListScreen{
search: components.NewSearch(),
sortedBy: "Name",
ascending: true,
}
}
// Init initializes the list screen
func (s *ListScreen) Init() tea.Cmd {
return tea.Batch(
s.loadClients,
wireguard.Tick(statusRefreshInterval),
)
}
// Update handles messages for the list screen
func (s *ListScreen) Update(msg tea.Msg) (Screen, tea.Cmd) {
var cmd tea.Cmd
switch msg := msg.(type) {
case tea.KeyMsg:
// Handle search activation
if msg.String() == "/" && !s.search.IsActive() {
s.search.Activate()
return s, nil
}
// If search is active, pass input to search
if s.search.IsActive() {
s.search, cmd = s.search.Update(msg)
// Apply filter to clients
s.applyFilter()
return s, cmd
}
// Normal key handling when search is not active
switch msg.String() {
case "q", "ctrl+c":
// Handle quit in parent model
return s, nil
case "r":
// Refresh clients
return s, s.loadClients
case "R":
// Show restore screen
return NewRestoreScreen(), nil
case "Q":
// Show QR code for selected client
if len(s.table.Rows()) > 0 {
selected := s.table.SelectedRow()
clientName := selected[0] // First column is Name
return NewQRScreen(clientName), nil
}
case "enter":
// Open detail view for selected client
if len(s.table.Rows()) > 0 {
selectedRow := s.table.SelectedRow()
selectedName := selectedRow[0] // First column is Name
// Find the client with this name
for _, cws := range s.clients {
if cws.Client.Name == selectedName {
return s, func() tea.Msg {
return ClientSelectedMsg{Client: cws}
}
}
}
}
case "1", "2", "3", "4":
// Sort by column number (Name, IPv4, IPv6, Status)
s.sortByColumn(msg.String())
}
case clientsLoadedMsg:
s.clients = msg.clients
s.search.SetTotalCount(len(s.clients))
s.applyFilter()
case wireguard.StatusTickMsg:
// Refresh status on periodic tick
return s, s.loadClients
case wireguard.RefreshStatusMsg:
// Refresh status on manual refresh
return s, s.loadClients
}
s.table, cmd = s.table.Update(msg)
return s, cmd
}
// View renders the list screen
func (s *ListScreen) View() string {
if len(s.clients) == 0 {
return s.search.View() + "\n" + "No clients found. Press 'r' to refresh or 'q' to quit."
}
// Check if there are no matches
if s.search.IsActive() && len(s.filtered) == 0 && s.search.GetQuery() != "" {
return s.search.View() + "\n" + lipgloss.NewStyle().
Foreground(lipgloss.Color("241")).
Italic(true).
Render("No matching clients found. Try a different search term.")
}
return s.search.View() + "\n" + s.table.View()
}
// loadClients loads clients from wireguard config
func (s *ListScreen) loadClients() tea.Msg {
clients, err := wireguard.ListClients()
if err != nil {
return errMsg{err: err}
}
// Get status for each client
clientsWithStatus := make([]ClientWithStatus, len(clients))
for i, client := range clients {
status, err := wireguard.GetClientStatus(client.PublicKey)
if err != nil {
status = wireguard.StatusDisconnected
}
clientsWithStatus[i] = ClientWithStatus{
Client: client,
Status: status,
}
}
return clientsLoadedMsg{clients: clientsWithStatus}
}
// applyFilter applies the current search filter to clients
func (s *ListScreen) applyFilter() {
// Convert clients to ClientData for filtering
clientData := make([]components.ClientData, len(s.clients))
for i, cws := range s.clients {
clientData[i] = components.ClientData{
Name: cws.Client.Name,
IPv4: cws.Client.IPv4,
IPv6: cws.Client.IPv6,
Status: cws.Status,
}
}
// Filter clients
filteredData := s.search.Filter(clientData)
// Convert back to ClientWithStatus
s.filtered = make([]ClientWithStatus, len(filteredData))
for i, cd := range filteredData {
// Find the matching client
for _, cws := range s.clients {
if cws.Client.Name == cd.Name {
s.filtered[i] = cws
break
}
}
}
// Rebuild table with filtered clients
s.buildTable()
}
// buildTable creates and configures the table
func (s *ListScreen) buildTable() {
columns := []table.Column{
{Title: "Name", Width: 20},
{Title: "IPv4", Width: 15},
{Title: "IPv6", Width: 35},
{Title: "Status", Width: 12},
}
// Use filtered clients if search is active, otherwise use all clients
displayClients := s.filtered
if !s.search.IsActive() {
displayClients = s.clients
}
var rows []table.Row
for _, cws := range displayClients {
row := table.Row{
cws.Client.Name,
cws.Client.IPv4,
cws.Client.IPv6,
cws.Status,
}
rows = append(rows, row)
}
// Sort rows based on current sort settings
s.sortRows(rows)
// Determine table height
tableHeight := len(rows) + 2 // Header + rows
if tableHeight < 5 {
tableHeight = 5
}
s.table = table.New(
table.WithColumns(columns),
table.WithRows(rows),
table.WithFocused(true),
table.WithHeight(tableHeight),
)
// Apply styles
s.setTableStyles()
}
// setTableStyles applies styling to the table
func (s *ListScreen) setTableStyles() {
styles := table.DefaultStyles()
styles.Header = styles.Header.
BorderStyle(lipgloss.NormalBorder()).
BorderForeground(lipgloss.Color("240")).
BorderBottom(true).
Bold(true)
styles.Selected = styles.Selected.
Foreground(lipgloss.Color("229")).
Background(lipgloss.Color("57")).
Bold(false)
s.table.SetStyles(styles)
}
// sortRows sorts the rows based on the current sort settings
func (s *ListScreen) sortRows(rows []table.Row) {
colIndex := s.getColumnIndex(s.sortedBy)
sort.Slice(rows, func(i, j int) bool {
var valI, valJ string
if colIndex < len(rows[i]) {
valI = rows[i][colIndex]
}
if colIndex < len(rows[j]) {
valJ = rows[j][colIndex]
}
if s.ascending {
return strings.ToLower(valI) < strings.ToLower(valJ)
}
return strings.ToLower(valI) > strings.ToLower(valJ)
})
}
// sortByColumn changes the sort column
func (s *ListScreen) sortByColumn(col string) {
sortedBy := "Name"
switch col {
case "1":
sortedBy = "Name"
case "2":
sortedBy = "IPv4"
case "3":
sortedBy = "IPv6"
case "4":
sortedBy = "Status"
}
// Toggle direction if clicking same column
if s.sortedBy == sortedBy {
s.ascending = !s.ascending
} else {
s.sortedBy = sortedBy
s.ascending = true
}
s.buildTable()
}
// getColumnIndex returns the index of a column by name
func (s *ListScreen) getColumnIndex(name string) int {
switch name {
case "Name":
return 0
case "IPv4":
return 1
case "IPv6":
return 2
case "Status":
return 3
}
return 0
}
// Messages
// clientsLoadedMsg is sent when clients are loaded
type clientsLoadedMsg struct {
clients []ClientWithStatus
}
// errMsg is sent when an error occurs
type errMsg struct {
err error
}

141
internal/tui/screens/qr.go Normal file
View File

@@ -0,0 +1,141 @@
package screens
import (
"fmt"
"strings"
"github.com/calmcacil/wg-admin/internal/wireguard"
tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/lipgloss"
"github.com/mdp/qrterminal/v3"
)
// QRScreen displays a QR code for a WireGuard client configuration
type QRScreen struct {
clientName string
configContent string
qrCode string
inlineMode bool
width, height int
errorMsg string
}
// NewQRScreen creates a new QR screen for displaying client config QR codes
func NewQRScreen(clientName string) *QRScreen {
return &QRScreen{
clientName: clientName,
inlineMode: true, // Start in inline mode
}
}
// Init initializes the QR screen
func (s *QRScreen) Init() tea.Cmd {
return s.loadConfig
}
// Update handles messages for the QR screen
func (s *QRScreen) Update(msg tea.Msg) (Screen, tea.Cmd) {
switch msg := msg.(type) {
case tea.KeyMsg:
switch msg.String() {
case "q", "escape":
// Return to list screen (parent should handle this)
return nil, nil
case "f":
// Toggle between inline and fullscreen mode
s.inlineMode = !s.inlineMode
s.generateQRCode()
}
case tea.WindowSizeMsg:
// Handle terminal resize
s.width = msg.Width
s.height = msg.Height
s.generateQRCode()
case configLoadedMsg:
s.configContent = msg.content
s.generateQRCode()
case errMsg:
s.errorMsg = msg.err.Error()
}
return s, nil
}
// View renders the QR screen
func (s *QRScreen) View() string {
if s.errorMsg != "" {
return s.renderError()
}
if s.qrCode == "" {
return "Loading QR code..."
}
return s.renderQR()
}
// loadConfig loads the client configuration
func (s *QRScreen) loadConfig() tea.Msg {
content, err := wireguard.GetClientConfigContent(s.clientName)
if err != nil {
return errMsg{err: err}
}
return configLoadedMsg{content: content}
}
// generateQRCode generates the QR code based on current mode and terminal size
func (s *QRScreen) generateQRCode() {
if s.configContent == "" {
return
}
// Generate QR code and capture output
var builder strings.Builder
// Generate ANSI QR code using half-block characters
qrterminal.GenerateHalfBlock(s.configContent, qrterminal.L, &builder)
s.qrCode = builder.String()
}
// renderQR renders the QR code with styling
func (s *QRScreen) renderQR() string {
styleTitle := lipgloss.NewStyle().
Foreground(lipgloss.Color("62")).
Bold(true).
MarginBottom(1)
styleHelp := lipgloss.NewStyle().
Foreground(lipgloss.Color("241")).
MarginTop(1)
styleQR := lipgloss.NewStyle().
MarginLeft(2)
title := styleTitle.Render(fmt.Sprintf("QR Code: %s", s.clientName))
help := "Press [f] to toggle fullscreen • Press [q/Escape] to return"
return title + "\n\n" + styleQR.Render(s.qrCode) + "\n" + styleHelp.Render(help)
}
// renderError renders an error message
func (s *QRScreen) renderError() string {
styleError := lipgloss.NewStyle().
Foreground(lipgloss.Color("196")).
Bold(true)
styleHelp := lipgloss.NewStyle().
Foreground(lipgloss.Color("241")).
MarginTop(1)
title := styleError.Render("Error")
message := s.errorMsg
help := "Press [q/Escape] to return"
return title + "\n\n" + message + "\n" + styleHelp.Render(help)
}
// Messages
// configLoadedMsg is sent when the client configuration is loaded
type configLoadedMsg struct {
content string
}

View File

@@ -0,0 +1,333 @@
package screens
import (
"fmt"
"strings"
"github.com/calmcacil/wg-admin/internal/backup"
"github.com/calmcacil/wg-admin/internal/tui/components"
"github.com/charmbracelet/bubbles/table"
tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/lipgloss"
)
// RestoreScreen displays a list of available backups for restoration
type RestoreScreen struct {
table table.Model
backups []backup.Backup
selectedBackup *backup.Backup
confirmModal *components.ConfirmModel
showConfirm bool
isRestoring bool
restoreError error
restoreSuccess bool
message string
}
// Styles
var (
restoreTitleStyle = lipgloss.NewStyle().
Foreground(lipgloss.Color("62")).
Bold(true)
restoreHelpStyle = lipgloss.NewStyle().
Foreground(lipgloss.Color("63")).
MarginTop(1)
restoreSuccessStyle = lipgloss.NewStyle().
Foreground(lipgloss.Color("46")).
Bold(true)
restoreErrorStyle = lipgloss.NewStyle().
Foreground(lipgloss.Color("196")).
Bold(true)
restoreInfoStyle = lipgloss.NewStyle().
Foreground(lipgloss.Color("241")).
MarginTop(1)
)
// NewRestoreScreen creates a new restore screen
func NewRestoreScreen() *RestoreScreen {
return &RestoreScreen{
showConfirm: false,
}
}
// Init initializes the restore screen
func (s *RestoreScreen) Init() tea.Cmd {
return s.loadBackups
}
// Update handles messages for the restore screen
func (s *RestoreScreen) Update(msg tea.Msg) (Screen, tea.Cmd) {
var cmd tea.Cmd
// Handle confirmation modal
if s.showConfirm && s.confirmModal != nil {
_, cmd = s.confirmModal.Update(msg)
// Handle confirmation result
if !s.confirmModal.Visible {
if s.confirmModal.IsConfirmed() && s.selectedBackup != nil {
// User confirmed restore
s.isRestoring = true
s.showConfirm = false
return s, s.performRestore()
}
// User cancelled - close modal
s.showConfirm = false
return s, nil
}
// Handle Enter key to confirm
switch msg := msg.(type) {
case tea.KeyMsg:
if msg.String() == "enter" && s.confirmModal.IsConfirmed() && s.selectedBackup != nil {
s.isRestoring = true
s.showConfirm = false
return s, s.performRestore()
}
}
return s, cmd
}
// Handle normal screen messages
switch msg := msg.(type) {
case backupsLoadedMsg:
s.backups = msg.backups
s.buildTable()
case tea.KeyMsg:
switch msg.String() {
case "q", "esc":
// Return to list screen - signal parent to switch screens
return s, nil
case "enter":
// Show confirmation for selected backup
if len(s.table.Rows()) > 0 {
selected := s.table.SelectedRow()
if len(selected) > 0 {
// Find the backup by name
for _, b := range s.backups {
if b.Name == selected[0] {
s.selectedBackup = &b
s.confirmModal = components.NewConfirm(
fmt.Sprintf(
"Are you sure you want to restore from backup '%s'?\n\nOperation: %s\nDate: %s\n\nThis will replace current WireGuard configuration.\nA safety backup will be created first.",
b.Name,
b.Operation,
b.Timestamp.Format("2006-01-02 15:04:05"),
),
80,
24,
)
s.showConfirm = true
break
}
}
}
}
}
case restoreCompletedMsg:
s.isRestoring = false
if msg.err != nil {
s.restoreError = msg.err
s.message = fmt.Sprintf("Restore failed: %v", msg.err)
} else {
s.restoreSuccess = true
s.message = fmt.Sprintf("Restore successful! Safety backup created at: %s", msg.safetyBackupPath)
}
}
if !s.showConfirm && s.confirmModal != nil {
s.table, cmd = s.table.Update(msg)
}
return s, cmd
}
// View renders the restore screen
func (s *RestoreScreen) View() string {
if s.showConfirm && s.confirmModal != nil {
// Render underlying content dimmed
content := s.renderContent()
dimmedContent := lipgloss.NewStyle().
Foreground(lipgloss.Color("244")).
Render(content)
// Overlay confirmation modal
return lipgloss.JoinVertical(
lipgloss.Left,
dimmedContent,
s.confirmModal.View(),
)
}
return s.renderContent()
}
// renderContent renders the main restore screen content
func (s *RestoreScreen) renderContent() string {
var content strings.Builder
content.WriteString(restoreTitleStyle.Render("Restore WireGuard Configuration"))
content.WriteString("\n\n")
if len(s.backups) == 0 && !s.isRestoring && s.message == "" {
content.WriteString("No backups found. Press 'q' to return.")
return content.String()
}
if s.isRestoring {
content.WriteString("Restoring from backup, please wait...")
return content.String()
}
if s.restoreSuccess {
content.WriteString(restoreSuccessStyle.Render("✓ " + s.message))
content.WriteString("\n\n")
content.WriteString(restoreInfoStyle.Render("Press 'q' to return to client list."))
return content.String()
}
if s.restoreError != nil {
content.WriteString(restoreErrorStyle.Render("✗ " + s.message))
content.WriteString("\n\n")
content.WriteString(s.table.View())
content.WriteString("\n\n")
content.WriteString(restoreHelpStyle.Render("Actions: [Enter] Restore Selected • [↑/↓] Navigate • [q] Back"))
return content.String()
}
// Show backup list
content.WriteString(s.table.View())
content.WriteString("\n\n")
// Show selected backup details
if len(s.table.Rows()) > 0 && s.selectedBackup != nil {
content.WriteString(restoreInfoStyle.Render(
fmt.Sprintf(
"Selected: %s (%s) - %s\nSize: %s",
s.selectedBackup.Operation,
s.selectedBackup.Timestamp.Format("2006-01-02 15:04:05"),
s.selectedBackup.Name,
formatBytes(s.selectedBackup.Size),
),
))
content.WriteString("\n")
}
content.WriteString(restoreHelpStyle.Render("Actions: [Enter] Restore Selected • [↑/↓] Navigate • [q] Back"))
return content.String()
}
// loadBackups loads the list of available backups
func (s *RestoreScreen) loadBackups() tea.Msg {
backups, err := backup.ListBackups()
if err != nil {
return errMsg{err: err}
}
return backupsLoadedMsg{backups: backups}
}
// buildTable creates and configures the backup list table
func (s *RestoreScreen) buildTable() {
columns := []table.Column{
{Title: "Name", Width: 40},
{Title: "Operation", Width: 15},
{Title: "Date", Width: 20},
{Title: "Size", Width: 12},
}
var rows []table.Row
for _, b := range s.backups {
row := table.Row{
b.Name,
b.Operation,
b.Timestamp.Format("2006-01-02 15:04"),
formatBytes(b.Size),
}
rows = append(rows, row)
}
s.table = table.New(
table.WithColumns(columns),
table.WithRows(rows),
table.WithFocused(true),
table.WithHeight(len(rows)+2), // Header + rows
)
// Apply styles
s.setTableStyles()
}
// setTableStyles applies styling to the table
func (s *RestoreScreen) setTableStyles() {
styles := table.DefaultStyles()
styles.Header = styles.Header.
BorderStyle(lipgloss.NormalBorder()).
BorderForeground(lipgloss.Color("240")).
BorderBottom(true).
Bold(true)
styles.Selected = styles.Selected.
Foreground(lipgloss.Color("229")).
Background(lipgloss.Color("57")).
Bold(false)
s.table.SetStyles(styles)
}
// performRestore performs the restore operation
func (s *RestoreScreen) performRestore() tea.Cmd {
return func() tea.Msg {
if s.selectedBackup == nil {
return restoreCompletedMsg{
err: fmt.Errorf("no backup selected"),
}
}
// Get safety backup path from backup.BackupConfig
safetyBackupPath, err := backup.BackupConfig(fmt.Sprintf("pre-restore-from-%s", s.selectedBackup.Name))
if err != nil {
return restoreCompletedMsg{
err: fmt.Errorf("failed to create safety backup: %w", err),
}
}
// Perform restore
if err := backup.RestoreBackup(s.selectedBackup.Name); err != nil {
return restoreCompletedMsg{
err: err,
safetyBackupPath: safetyBackupPath,
}
}
// Restore succeeded - trigger client list refresh
return restoreCompletedMsg{
safetyBackupPath: safetyBackupPath,
}
}
}
// formatBytes formats a byte count into human-readable format
func formatBytes(bytes int64) string {
const unit = 1024
if bytes < unit {
return fmt.Sprintf("%d B", bytes)
}
div, exp := int64(unit), 0
for n := bytes / unit; n >= unit; n /= unit {
div *= unit
exp++
}
return fmt.Sprintf("%.1f %ciB", float64(bytes)/float64(div), "KMGTPE"[exp])
}
// Messages
// backupsLoadedMsg is sent when backups are loaded
type backupsLoadedMsg struct {
backups []backup.Backup
}
// restoreCompletedMsg is sent when a restore operation completes
type restoreCompletedMsg struct {
err error
safetyBackupPath string
}

185
internal/tui/theme/theme.go Normal file
View File

@@ -0,0 +1,185 @@
package theme
import (
"fmt"
"os"
"sync"
"github.com/charmbracelet/lipgloss"
)
// ColorScheme defines the color palette for the theme
type ColorScheme struct {
Primary lipgloss.Color
Success lipgloss.Color
Warning lipgloss.Color
Error lipgloss.Color
Muted lipgloss.Color
Background lipgloss.Color
}
// Theme represents a color theme with its name and color scheme
type Theme struct {
Name string
Scheme ColorScheme
}
// Global variables for current theme and styles
var (
currentTheme *Theme
once sync.Once
// Global styles that can be used throughout the application
StylePrimary lipgloss.Style
StyleSuccess lipgloss.Style
StyleWarning lipgloss.Style
StyleError lipgloss.Style
StyleMuted lipgloss.Style
StyleTitle lipgloss.Style
StyleSubtitle lipgloss.Style
StyleHelpKey lipgloss.Style
)
// DefaultTheme is the standard blue-based theme
var DefaultTheme = &Theme{
Name: "default",
Scheme: ColorScheme{
Primary: lipgloss.Color("62"), // Blue
Success: lipgloss.Color("46"), // Green
Warning: lipgloss.Color("208"), // Orange
Error: lipgloss.Color("196"), // Red
Muted: lipgloss.Color("241"), // Gray
Background: lipgloss.Color(""), // Default terminal background
},
}
// DarkTheme is a purple-based dark theme
var DarkTheme = &Theme{
Name: "dark",
Scheme: ColorScheme{
Primary: lipgloss.Color("141"), // Purple
Success: lipgloss.Color("51"), // Cyan
Warning: lipgloss.Color("226"), // Yellow
Error: lipgloss.Color("196"), // Red
Muted: lipgloss.Color("245"), // Light gray
Background: lipgloss.Color(""), // Default terminal background
},
}
// LightTheme is a green-based light theme
var LightTheme = &Theme{
Name: "light",
Scheme: ColorScheme{
Primary: lipgloss.Color("34"), // Green
Success: lipgloss.Color("36"), // Teal
Warning: lipgloss.Color("214"), // Amber
Error: lipgloss.Color("196"), // Red
Muted: lipgloss.Color("244"), // Gray
Background: lipgloss.Color(""), // Default terminal background
},
}
// ThemeRegistry holds all available themes
var ThemeRegistry = map[string]*Theme{
"default": DefaultTheme,
"dark": DarkTheme,
"light": LightTheme,
}
// GetTheme loads the theme from config or environment variable
// Returns the default theme if no theme is specified
func GetTheme() (*Theme, error) {
once.Do(func() {
// Try to get theme from environment variable first
themeName := os.Getenv("THEME")
if themeName == "" {
themeName = "default"
}
// Look up the theme in the registry
if theme, ok := ThemeRegistry[themeName]; ok {
currentTheme = theme
} else {
// If theme not found, use default
currentTheme = DefaultTheme
}
// Apply the theme to initialize styles
ApplyTheme(currentTheme)
})
return currentTheme, nil
}
// ApplyTheme applies the given theme to all global styles
func ApplyTheme(theme *Theme) {
currentTheme = theme
// Primary style
StylePrimary = lipgloss.NewStyle().
Foreground(theme.Scheme.Primary)
// Success style
StyleSuccess = lipgloss.NewStyle().
Foreground(theme.Scheme.Success)
// Warning style
StyleWarning = lipgloss.NewStyle().
Foreground(theme.Scheme.Warning)
// Error style
StyleError = lipgloss.NewStyle().
Foreground(theme.Scheme.Error)
// Muted style
StyleMuted = lipgloss.NewStyle().
Foreground(theme.Scheme.Muted)
// Title style (bold primary)
StyleTitle = lipgloss.NewStyle().
Foreground(theme.Scheme.Primary).
Bold(true)
// Subtitle style (muted)
StyleSubtitle = lipgloss.NewStyle().
Foreground(theme.Scheme.Muted)
// Help key style (bold primary, slightly different shade)
StyleHelpKey = lipgloss.NewStyle().
Foreground(theme.Scheme.Primary).
Bold(true)
}
// GetThemeNames returns a list of available theme names
func GetThemeNames() []string {
names := make([]string, 0, len(ThemeRegistry))
for name := range ThemeRegistry {
names = append(names, name)
}
return names
}
// SetTheme changes the current theme by name
func SetTheme(name string) error {
theme, ok := ThemeRegistry[name]
if !ok {
return fmt.Errorf("theme '%s' not found. Available themes: %v", name, GetThemeNames())
}
ApplyTheme(theme)
return nil
}
// GetCurrentTheme returns the currently active theme
func GetCurrentTheme() *Theme {
if currentTheme == nil {
currentTheme = DefaultTheme
ApplyTheme(currentTheme)
}
return currentTheme
}
// GetColorScheme returns the color scheme of the current theme
func GetColorScheme() ColorScheme {
return GetCurrentTheme().Scheme
}

View File

@@ -0,0 +1,86 @@
package validation
import (
"fmt"
"regexp"
"strings"
)
// ValidateClientName validates the client name format
// Requirements: alphanumeric, hyphens, underscores only, max 64 characters
func ValidateClientName(name string) error {
if name == "" {
return fmt.Errorf("client name cannot be empty")
}
if len(name) > 64 {
return fmt.Errorf("client name must be 64 characters or less (got %d)", len(name))
}
// Check for valid characters: a-z, A-Z, 0-9, hyphen, underscore
matched, err := regexp.MatchString(`^[a-zA-Z0-9_-]+$`, name)
if err != nil {
return fmt.Errorf("failed to validate client name: %w", err)
}
if !matched {
return fmt.Errorf("client name contains invalid characters (allowed: a-z, A-Z, 0-9, -, _)")
}
return nil
}
// ValidateIPAvailability checks if an IP address is already assigned to a client
// This function is meant to be called after getting next available IPs
func ValidateIPAvailability(ipv4, ipv6 string, existingIPs map[string]bool) error {
if ipv4 != "" && existingIPs[ipv4] {
return fmt.Errorf("IPv4 address %s is already assigned", ipv4)
}
if ipv6 != "" && existingIPs[ipv6] {
return fmt.Errorf("IPv6 address %s is already assigned", ipv6)
}
return nil
}
// ValidateDNSServers validates DNS server addresses
// Requirements: comma-separated list of IPv4 addresses
func ValidateDNSServers(dns string) error {
if dns == "" {
return fmt.Errorf("DNS servers cannot be empty")
}
// Split by comma and trim whitespace
servers := strings.Split(dns, ",")
for _, server := range servers {
server = strings.TrimSpace(server)
if server == "" {
continue
}
// Basic IPv4 validation (4 octets, each 0-255)
matched, err := regexp.MatchString(`^(\d{1,3}\.){3}\d{1,3}$`, server)
if err != nil {
return fmt.Errorf("failed to validate DNS server: %w", err)
}
if !matched {
return fmt.Errorf("invalid DNS server format: %s (expected IPv4 address)", server)
}
// Validate each octet is in range 0-255
parts := strings.Split(server, ".")
for _, part := range parts {
var num int
if _, err := fmt.Sscanf(part, "%d", &num); err != nil {
return fmt.Errorf("invalid DNS server: %s", server)
}
if num < 0 || num > 255 {
return fmt.Errorf("DNS server octet out of range (0-255) in: %s", server)
}
}
}
return nil
}

View File

@@ -0,0 +1,585 @@
package wireguard
import (
"bytes"
"fmt"
"log"
"os"
"os/exec"
"path/filepath"
"strings"
"github.com/calmcacil/wg-admin/internal/backup"
"github.com/calmcacil/wg-admin/internal/config"
)
// Client represents a WireGuard peer configuration
type Client struct {
Name string // Client name extracted from filename
IPv4 string // IPv4 address from AllowedIPs
IPv6 string // IPv6 address from AllowedIPs
PublicKey string // WireGuard public key
HasPSK bool // Whether PresharedKey is configured
ConfigPath string // Path to the client config file
}
// ParseClientConfig parses a single WireGuard client configuration file
func ParseClientConfig(path string) (*Client, error) {
// Read the file
content, err := os.ReadFile(path)
if err != nil {
return nil, fmt.Errorf("failed to read config file %s: %w", path, err)
}
// Extract client name from filename
base := filepath.Base(path)
name := strings.TrimPrefix(base, "client-")
name = strings.TrimSuffix(name, ".conf")
if name == "" || name == "client-" {
return nil, fmt.Errorf("invalid client filename: %s", base)
}
client := &Client{
Name: name,
ConfigPath: path,
}
// Parse the INI-style config
inPeerSection := false
hasPublicKey := false
for i, line := range strings.Split(string(content), "\n") {
line = strings.TrimSpace(line)
// Skip empty lines and comments
if line == "" || strings.HasPrefix(line, "#") || strings.HasPrefix(line, ";") {
continue
}
// Check for Peer section
if line == "[Peer]" {
inPeerSection = true
continue
}
// Parse key-value pairs within Peer section
if inPeerSection {
parts := strings.SplitN(line, "=", 2)
if len(parts) != 2 {
log.Printf("Warning: malformed line %d in %s: %s", i+1, path, line)
continue
}
key := strings.TrimSpace(parts[0])
value := strings.TrimSpace(parts[1])
switch key {
case "PublicKey":
client.PublicKey = value
hasPublicKey = true
case "PresharedKey":
client.HasPSK = true
case "AllowedIPs":
if err := parseAllowedIPs(client, value); err != nil {
log.Printf("Warning: %v (file: %s, line: %d)", err, path, i+1)
}
}
}
}
// Validate required fields
if !hasPublicKey {
return nil, fmt.Errorf("missing required PublicKey in %s", path)
}
if client.IPv4 == "" && client.IPv6 == "" {
return nil, fmt.Errorf("no valid IP addresses found in AllowedIPs in %s", path)
}
return client, nil
}
// parseAllowedIPs extracts IPv4 and IPv6 addresses from AllowedIPs value
func parseAllowedIPs(client *Client, allowedIPs string) error {
// AllowedIPs format: "ipv4/32, ipv6/128"
addresses := strings.Split(allowedIPs, ",")
for _, addr := range addresses {
addr = strings.TrimSpace(addr)
if addr == "" {
continue
}
// Split IP from CIDR suffix
parts := strings.Split(addr, "/")
if len(parts) != 2 {
return fmt.Errorf("invalid AllowedIP format: %s", addr)
}
ip := strings.TrimSpace(parts[0])
// Detect if IPv4 or IPv6 based on presence of colon
if strings.Contains(ip, ":") {
client.IPv6 = ip
} else {
client.IPv4 = ip
}
}
return nil
}
// ListClients finds and parses all client configurations from /etc/wireguard/conf.d/
func ListClients() ([]Client, error) {
configDir := "/etc/wireguard/conf.d"
// Check if directory exists
if _, err := os.Stat(configDir); os.IsNotExist(err) {
return nil, fmt.Errorf("wireguard config directory does not exist: %s", configDir)
}
// Find all client-*.conf files
pattern := filepath.Join(configDir, "client-*.conf")
matches, err := filepath.Glob(pattern)
if err != nil {
return nil, fmt.Errorf("failed to find client config files: %w", err)
}
if len(matches) == 0 {
return []Client{}, nil // No clients found, return empty slice
}
// Parse each config file
var clients []Client
var parseErrors []string
for _, match := range matches {
client, err := ParseClientConfig(match)
if err != nil {
parseErrors = append(parseErrors, err.Error())
log.Printf("Warning: failed to parse %s: %v", match, err)
continue
}
clients = append(clients, *client)
}
// If all files failed to parse, return an error
if len(clients) == 0 && len(parseErrors) > 0 {
return nil, fmt.Errorf("failed to parse any client configs: %s", strings.Join(parseErrors, "; "))
}
return clients, nil
}
// GetClientConfigContent reads the raw configuration content for a client
func GetClientConfigContent(name string) (string, error) {
configDir := "/etc/wireguard/clients"
configPath := filepath.Join(configDir, fmt.Sprintf("%s.conf", name))
content, err := os.ReadFile(configPath)
if err != nil {
if os.IsNotExist(err) {
return "", fmt.Errorf("client config not found: %s", configPath)
}
return "", fmt.Errorf("failed to read client config %s: %w", configPath, err)
}
return string(content), nil
}
// DeleteClient removes a WireGuard client configuration and associated files
func DeleteClient(name string) error {
// First, find the client config to get public key for removal from interface
configDir := "/etc/wireguard/conf.d"
configPath := filepath.Join(configDir, fmt.Sprintf("client-%s.conf", name))
client, err := ParseClientConfig(configPath)
if err != nil {
return fmt.Errorf("failed to parse client config for deletion: %w", err)
}
log.Printf("Deleting client: %s (public key: %s)", name, client.PublicKey)
// Create backup before deletion
backupPath, err := backup.BackupConfig(fmt.Sprintf("delete-%s", name))
if err != nil {
log.Printf("Warning: failed to create backup before deletion: %v", err)
} else {
log.Printf("Created backup: %s", backupPath)
}
// Remove peer from WireGuard interface using wg command
if err := removePeerFromInterface(client.PublicKey); err != nil {
log.Printf("Warning: failed to remove peer from interface: %v", err)
}
// Remove client config from /etc/wireguard/conf.d/
if err := os.Remove(configPath); err != nil && !os.IsNotExist(err) {
return fmt.Errorf("failed to remove client config %s: %w", configPath, err)
}
log.Printf("Removed client config: %s", configPath)
// Remove client files from /etc/wireguard/clients/
clientsDir := "/etc/wireguard/clients"
clientFile := filepath.Join(clientsDir, fmt.Sprintf("%s.conf", name))
if err := os.Remove(clientFile); err != nil && !os.IsNotExist(err) {
log.Printf("Warning: failed to remove client file %s: %v", clientFile, err)
} else {
log.Printf("Removed client file: %s", clientFile)
}
// Remove QR code PNG if it exists
qrFile := filepath.Join(clientsDir, fmt.Sprintf("%s.png", name))
if err := os.Remove(qrFile); err != nil && !os.IsNotExist(err) {
log.Printf("Warning: failed to remove QR code %s: %v", qrFile, err)
} else {
log.Printf("Removed QR code: %s", qrFile)
}
log.Printf("Successfully deleted client: %s", name)
return nil
}
// removePeerFromInterface removes a peer from the WireGuard interface
func removePeerFromInterface(publicKey string) error {
// Use wg command to remove peer
cmd := exec.Command("wg", "set", "wg0", "peer", publicKey, "remove")
output, err := cmd.CombinedOutput()
if err != nil {
return fmt.Errorf("wg set peer remove failed: %w, output: %s", err, string(output))
}
return nil
}
// CreateClient creates a new WireGuard client configuration
func CreateClient(name, dns string, usePSK bool) error {
log.Printf("Creating client: %s (PSK: %v)", name, usePSK)
// Create backup before creating client
backupPath, err := backup.BackupConfig(fmt.Sprintf("create-%s", name))
if err != nil {
log.Printf("Warning: failed to create backup before creating client: %v", err)
} else {
log.Printf("Created backup: %s", backupPath)
}
// Generate keys
privateKey, publicKey, err := generateKeyPair()
if err != nil {
return fmt.Errorf("failed to generate key pair: %w", err)
}
var psk string
if usePSK {
psk, err = generatePSK()
if err != nil {
return fmt.Errorf("failed to generate PSK: %w", err)
}
}
// Get next available IP addresses
cfg, err := config.LoadConfig()
if err != nil {
return fmt.Errorf("failed to load config: %w", err)
}
clients, err := ListClients()
if err != nil {
return fmt.Errorf("failed to list existing clients: %w", err)
}
ipv4, err := getNextAvailableIP(cfg.VPNIPv4Range, clients)
if err != nil {
return fmt.Errorf("failed to get IPv4 address: %w", err)
}
ipv6, err := getNextAvailableIP(cfg.VPNIPv6Range, clients)
if err != nil {
return fmt.Errorf("failed to get IPv6 address: %w", err)
}
// Create server config
serverConfigPath := fmt.Sprintf("/etc/wireguard/conf.d/client-%s.conf", name)
serverConfig, err := generateServerConfig(name, publicKey, ipv4, ipv6, psk, cfg)
if err != nil {
return fmt.Errorf("failed to generate server config: %w", err)
}
if err := os.WriteFile(serverConfigPath, []byte(serverConfig), 0600); err != nil {
return fmt.Errorf("failed to write server config: %w", err)
}
log.Printf("Created server config: %s", serverConfigPath)
// Create client config
clientsDir := "/etc/wireguard/clients"
if err := os.MkdirAll(clientsDir, 0700); err != nil {
return fmt.Errorf("failed to create clients directory: %w", err)
}
clientConfigPath := filepath.Join(clientsDir, fmt.Sprintf("%s.conf", name))
clientConfig, err := generateClientConfig(name, privateKey, ipv4, ipv6, dns, cfg)
if err != nil {
return fmt.Errorf("failed to generate client config: %w", err)
}
if err := os.WriteFile(clientConfigPath, []byte(clientConfig), 0600); err != nil {
return fmt.Errorf("failed to write client config: %w", err)
}
log.Printf("Created client config: %s", clientConfigPath)
// Add peer to WireGuard interface
if err := addPeerToInterface(publicKey, ipv4, ipv6, psk); err != nil {
log.Printf("Warning: failed to add peer to interface: %v", err)
} else {
log.Printf("Added peer to WireGuard interface")
}
// Generate QR code
qrPath := filepath.Join(clientsDir, fmt.Sprintf("%s.png", name))
if err := generateQRCode(clientConfigPath, qrPath); err != nil {
log.Printf("Warning: failed to generate QR code: %v", err)
} else {
log.Printf("Generated QR code: %s", qrPath)
}
log.Printf("Successfully created client: %s", name)
return nil
}
// generateKeyPair generates a WireGuard private and public key pair
func generateKeyPair() (privateKey, publicKey string, err error) {
// Generate private key
privateKeyBytes, err := exec.Command("wg", "genkey").Output()
if err != nil {
return "", "", fmt.Errorf("wg genkey failed: %w", err)
}
privateKey = strings.TrimSpace(string(privateKeyBytes))
// Derive public key
pubKeyCmd := exec.Command("wg", "pubkey")
pubKeyCmd.Stdin = strings.NewReader(privateKey)
publicKeyBytes, err := pubKeyCmd.Output()
if err != nil {
return "", "", fmt.Errorf("wg pubkey failed: %w", err)
}
publicKey = strings.TrimSpace(string(publicKeyBytes))
return privateKey, publicKey, nil
}
// generatePSK generates a WireGuard preshared key
func generatePSK() (string, error) {
psk, err := exec.Command("wg", "genpsk").Output()
if err != nil {
return "", fmt.Errorf("wg genpsk failed: %w", err)
}
return strings.TrimSpace(string(psk)), nil
}
// getNextAvailableIP finds the next available IP address in the given CIDR range
func getNextAvailableIP(cidr string, existingClients []Client) (string, error) {
// Parse CIDR to get network
parts := strings.Split(cidr, "/")
if len(parts) != 2 {
return "", fmt.Errorf("invalid CIDR format: %s", cidr)
}
network := strings.TrimSpace(parts[0])
// For IPv4, extract base network and assign next available host
if !strings.Contains(network, ":") {
// IPv4: Simple implementation - use .1, .2, etc.
// In production, this would parse the CIDR properly
usedHosts := make(map[string]bool)
for _, client := range existingClients {
if client.IPv4 != "" {
ipParts := strings.Split(client.IPv4, ".")
if len(ipParts) == 4 {
usedHosts[ipParts[3]] = true
}
}
}
// Find next available host (skip 0 and 1 as they may be reserved)
for i := 2; i < 255; i++ {
host := fmt.Sprintf("%d", i)
if !usedHosts[host] {
return fmt.Sprintf("%s.%s", network, host), nil
}
}
return "", fmt.Errorf("no available IPv4 addresses in range: %s", cidr)
}
// IPv6: Similar simplified approach
usedHosts := make(map[string]bool)
for _, client := range existingClients {
if client.IPv6 != "" {
// Extract last segment for IPv6
lastColon := strings.LastIndex(client.IPv6, ":")
if lastColon > 0 {
host := client.IPv6[lastColon+1:]
usedHosts[host] = true
}
}
}
// Find next available host
for i := 1; i < 65536; i++ {
host := fmt.Sprintf("%x", i)
if !usedHosts[host] {
return fmt.Sprintf("%s:%s", network, host), nil
}
}
return "", fmt.Errorf("no available IPv6 addresses in range: %s", cidr)
}
// generateServerConfig generates the server-side configuration for a client
func generateServerConfig(name, publicKey, ipv4, ipv6, psk string, cfg *config.Config) (string, error) {
var builder strings.Builder
builder.WriteString(fmt.Sprintf("# Client: %s\n", name))
builder.WriteString(fmt.Sprintf("[Peer]\n"))
builder.WriteString(fmt.Sprintf("PublicKey = %s\n", publicKey))
allowedIPs := ""
if ipv4 != "" {
allowedIPs = ipv4 + "/32"
}
if ipv6 != "" {
if allowedIPs != "" {
allowedIPs += ", "
}
allowedIPs += ipv6 + "/128"
}
builder.WriteString(fmt.Sprintf("AllowedIPs = %s\n", allowedIPs))
if psk != "" {
builder.WriteString(fmt.Sprintf("PresharedKey = %s\n", psk))
}
return builder.String(), nil
}
// generateClientConfig generates the client-side configuration
func generateClientConfig(name, privateKey, ipv4, ipv6, dns string, cfg *config.Config) (string, error) {
// Get server's public key from the main config
serverConfigPath := "/etc/wireguard/wg0.conf"
serverPublicKey, serverEndpoint, err := getServerConfig(serverConfigPath)
if err != nil {
return "", fmt.Errorf("failed to read server config: %w", err)
}
var builder strings.Builder
builder.WriteString(fmt.Sprintf("# WireGuard client configuration for %s\n", name))
builder.WriteString("[Interface]\n")
builder.WriteString(fmt.Sprintf("PrivateKey = %s\n", privateKey))
builder.WriteString(fmt.Sprintf("Address = %s/32", ipv4))
if ipv6 != "" {
builder.WriteString(fmt.Sprintf(", %s/128", ipv6))
}
builder.WriteString("\n")
builder.WriteString(fmt.Sprintf("DNS = %s\n", dns))
builder.WriteString("\n")
builder.WriteString("[Peer]\n")
builder.WriteString(fmt.Sprintf("PublicKey = %s\n", serverPublicKey))
builder.WriteString(fmt.Sprintf("Endpoint = %s:%d\n", serverEndpoint, cfg.WGPort))
builder.WriteString("AllowedIPs = 0.0.0.0/0, ::/0\n")
return builder.String(), nil
}
// getServerConfig reads the server's public key and endpoint from the main config
func getServerConfig(path string) (publicKey, endpoint string, err error) {
content, err := os.ReadFile(path)
if err != nil {
return "", "", fmt.Errorf("failed to read server config: %w", err)
}
inInterfaceSection := false
for _, line := range strings.Split(string(content), "\n") {
line = strings.TrimSpace(line)
if line == "[Interface]" {
inInterfaceSection = true
continue
}
if line == "[Peer]" {
inInterfaceSection = false
continue
}
if inInterfaceSection {
if strings.HasPrefix(line, "PublicKey") {
parts := strings.SplitN(line, "=", 2)
if len(parts) == 2 {
publicKey = strings.TrimSpace(parts[1])
}
}
}
}
// Use SERVER_DOMAIN as endpoint if available, otherwise fallback
cfg, err := config.LoadConfig()
if err == nil && cfg.ServerDomain != "" {
endpoint = cfg.ServerDomain
} else {
endpoint = "0.0.0.0"
}
return publicKey, endpoint, nil
}
// addPeerToInterface adds a peer to the WireGuard interface
func addPeerToInterface(publicKey, ipv4, ipv6, psk string) error {
args := []string{"set", "wg0", "peer", publicKey}
if ipv4 != "" {
args = append(args, "allowed-ips", ipv4+"/32")
}
if ipv6 != "" {
args = append(args, ipv6+"/128")
}
if psk != "" {
args = append(args, "preshared-key", "/dev/stdin")
cmd := exec.Command("wg", args...)
cmd.Stdin = strings.NewReader(psk)
output, err := cmd.CombinedOutput()
if err != nil {
return fmt.Errorf("wg set peer failed: %w, output: %s", err, string(output))
}
} else {
cmd := exec.Command("wg", args...)
output, err := cmd.CombinedOutput()
if err != nil {
return fmt.Errorf("wg set peer failed: %w, output: %s", err, string(output))
}
}
return nil
}
// generateQRCode generates a QR code from the client config
func generateQRCode(configPath, qrPath string) error {
// Read config file
content, err := os.ReadFile(configPath)
if err != nil {
return fmt.Errorf("failed to read config file: %w", err)
}
// Generate QR code using qrencode or similar
// For now, use a simple approach with qrencode if available
cmd := exec.Command("qrencode", "-o", qrPath, "-t", "PNG")
cmd.Stdin = bytes.NewReader(content)
if err := cmd.Run(); err != nil {
// qrencode not available, try alternative method
return fmt.Errorf("qrencode not available: %w", err)
}
return nil
}

View File

@@ -0,0 +1,106 @@
package wireguard
import (
"fmt"
"log"
"os"
"path/filepath"
"strings"
)
// GenerateServerConfig generates a server-side WireGuard [Peer] configuration
// and writes it to /etc/wireguard/conf.d/client-<name>.conf
func GenerateServerConfig(name, publicKey string, hasPSK bool, ipv4, ipv6, psk string) (string, error) {
var builder strings.Builder
builder.WriteString(fmt.Sprintf("[Peer]\n"))
builder.WriteString(fmt.Sprintf("# %s\n", name))
builder.WriteString(fmt.Sprintf("PublicKey = %s\n", publicKey))
if hasPSK {
builder.WriteString(fmt.Sprintf("PresharedKey = %s\n", psk))
}
builder.WriteString(fmt.Sprintf("AllowedIPs = %s/32", ipv4))
if ipv6 != "" {
builder.WriteString(fmt.Sprintf(", %s/128", ipv6))
}
builder.WriteString("\n")
configContent := builder.String()
configDir := "/etc/wireguard/conf.d"
configPath := filepath.Join(configDir, fmt.Sprintf("client-%s.conf", name))
if err := atomicWrite(configPath, configContent); err != nil {
return "", fmt.Errorf("failed to write server config: %w", err)
}
log.Printf("Generated server config: %s", configPath)
return configPath, nil
}
// GenerateClientConfig generates a client-side WireGuard configuration
// with [Interface] and [Peer] sections and writes it to /etc/wireguard/clients/<name>.conf
func GenerateClientConfig(name, privateKey, ipv4, ipv6, dns, serverPublicKey, endpoint string, port int, hasPSK bool, psk string) (string, error) {
var builder strings.Builder
// [Interface] section
builder.WriteString("[Interface]\n")
builder.WriteString(fmt.Sprintf("PrivateKey = %s\n", privateKey))
builder.WriteString(fmt.Sprintf("Address = %s/24", ipv4))
if ipv6 != "" {
builder.WriteString(fmt.Sprintf(", %s/64", ipv6))
}
builder.WriteString("\n")
builder.WriteString(fmt.Sprintf("DNS = %s\n", dns))
// [Peer] section
builder.WriteString("\n[Peer]\n")
builder.WriteString(fmt.Sprintf("PublicKey = %s\n", serverPublicKey))
if hasPSK {
builder.WriteString(fmt.Sprintf("PresharedKey = %s\n", psk))
}
builder.WriteString(fmt.Sprintf("Endpoint = %s:%d\n", endpoint, port))
builder.WriteString("AllowedIPs = 0.0.0.0/0, ::/0\n")
builder.WriteString("PersistentKeepalive = 25\n")
configContent := builder.String()
clientsDir := "/etc/wireguard/clients"
configPath := filepath.Join(clientsDir, fmt.Sprintf("%s.conf", name))
if err := atomicWrite(configPath, configContent); err != nil {
return "", fmt.Errorf("failed to write client config: %w", err)
}
log.Printf("Generated client config: %s", configPath)
return configPath, nil
}
// atomicWrite writes content to a file atomically
// Uses temp file + rename pattern for atomicity and sets permissions to 0600
func atomicWrite(path, content string) error {
// Ensure parent directory exists
dir := filepath.Dir(path)
if err := os.MkdirAll(dir, 0755); err != nil {
return fmt.Errorf("failed to create directory %s: %w", dir, err)
}
// Write to temp file first
tempPath := path + ".tmp"
if err := os.WriteFile(tempPath, []byte(content), 0600); err != nil {
return fmt.Errorf("failed to write temp file: %w", err)
}
// Atomic rename
if err := os.Rename(tempPath, path); err != nil {
// Clean up temp file if rename fails
_ = os.Remove(tempPath)
return fmt.Errorf("failed to rename temp file: %w", err)
}
return nil
}

243
internal/wireguard/keys.go Normal file
View File

@@ -0,0 +1,243 @@
package wireguard
import (
"encoding/base64"
"fmt"
"log"
"os"
"os/exec"
"path/filepath"
"strings"
"sync"
)
const (
// KeyLength is the standard WireGuard key length in bytes (32 bytes = 256 bits)
KeyLength = 32
// Base64KeyLength is the length of a base64-encoded WireGuard key (44 characters)
Base64KeyLength = 44
// TempDir is the directory for temporary key storage
TempDir = "/tmp/wg-admin"
)
var (
tempKeys = make(map[string]bool)
tempKeysMutex sync.Mutex
)
// KeyPair represents a WireGuard key pair
type KeyPair struct {
PrivateKey string // Base64-encoded private key
PublicKey string // Base64-encoded public key
}
// GeneratePrivateKey generates a new WireGuard private key using wg genkey
func GeneratePrivateKey() (string, error) {
cmd := exec.Command("wg", "genkey")
output, err := cmd.CombinedOutput()
if err != nil {
return "", fmt.Errorf("wg genkey failed: %w, output: %s", err, string(output))
}
key := string(output)
if err := ValidateKey(key); err != nil {
return "", fmt.Errorf("generated invalid private key: %w", err)
}
// Store as temporary key
tempKeyPath := filepath.Join(TempDir, "private.key")
if err := storeTempKey(tempKeyPath, key); err != nil {
log.Printf("Warning: failed to store temporary private key: %v", err)
}
return key, nil
}
// GeneratePublicKey generates a public key from a private key using wg pubkey
func GeneratePublicKey(privateKey string) (string, error) {
if err := ValidateKey(privateKey); err != nil {
return "", fmt.Errorf("invalid private key: %w", err)
}
cmd := exec.Command("wg", "pubkey")
cmd.Stdin = strings.NewReader(privateKey)
output, err := cmd.CombinedOutput()
if err != nil {
return "", fmt.Errorf("wg pubkey failed: %w, output: %s", err, string(output))
}
key := string(output)
if err := ValidateKey(key); err != nil {
return "", fmt.Errorf("generated invalid public key: %w", err)
}
// Store as temporary key
tempKeyPath := filepath.Join(TempDir, "public.key")
if err := storeTempKey(tempKeyPath, key); err != nil {
log.Printf("Warning: failed to store temporary public key: %v", err)
}
return key, nil
}
// GeneratePSK generates a new pre-shared key using wg genpsk
func GeneratePSK() (string, error) {
cmd := exec.Command("wg", "genpsk")
output, err := cmd.CombinedOutput()
if err != nil {
return "", fmt.Errorf("wg genpsk failed: %w, output: %s", err, string(output))
}
key := string(output)
if err := ValidateKey(key); err != nil {
return "", fmt.Errorf("generated invalid PSK: %w", err)
}
// Store as temporary key
tempKeyPath := filepath.Join(TempDir, "psk.key")
if err := storeTempKey(tempKeyPath, key); err != nil {
log.Printf("Warning: failed to store temporary PSK: %v", err)
}
return key, nil
}
// GenerateKeyPair generates a complete WireGuard key pair (private + public)
func GenerateKeyPair() (*KeyPair, error) {
privateKey, err := GeneratePrivateKey()
if err != nil {
return nil, err
}
publicKey, err := GeneratePublicKey(privateKey)
if err != nil {
return nil, fmt.Errorf("failed to generate public key: %w", err)
}
return &KeyPair{
PrivateKey: privateKey,
PublicKey: publicKey,
}, nil
}
// ValidateKey validates that a key is properly formatted (44 base64 characters)
func ValidateKey(key string) error {
// Trim whitespace
key = strings.TrimSpace(key)
// Check length (44 base64 characters for 32 bytes)
if len(key) != Base64KeyLength {
return fmt.Errorf("invalid key length: expected %d characters, got %d", Base64KeyLength, len(key))
}
// Verify it's valid base64
decoded, err := base64.StdEncoding.DecodeString(key)
if err != nil {
return fmt.Errorf("invalid base64 encoding: %w", err)
}
// Verify decoded length is 32 bytes
if len(decoded) != KeyLength {
return fmt.Errorf("invalid decoded key length: expected %d bytes, got %d", KeyLength, len(decoded))
}
return nil
}
// StoreKey atomically writes a key to a file with 0600 permissions
func StoreKey(path string, key string) error {
// Validate key before storing
if err := ValidateKey(key); err != nil {
return fmt.Errorf("invalid key: %w", err)
}
// Trim whitespace
key = strings.TrimSpace(key)
// Create parent directories if needed
dir := filepath.Dir(path)
if err := os.MkdirAll(dir, 0755); err != nil {
return fmt.Errorf("failed to create directory %s: %w", dir, err)
}
// Write to temporary file
tempPath := path + ".tmp"
if err := os.WriteFile(tempPath, []byte(key), 0600); err != nil {
return fmt.Errorf("failed to write temp file %s: %w", tempPath, err)
}
// Atomic rename
if err := os.Rename(tempPath, path); err != nil {
os.Remove(tempPath) // Clean up temp file on failure
return fmt.Errorf("failed to rename temp file to %s: %w", path, err)
}
return nil
}
// LoadKey reads a key from a file and validates it
func LoadKey(path string) (string, error) {
data, err := os.ReadFile(path)
if err != nil {
return "", fmt.Errorf("failed to read key file %s: %w", path, err)
}
key := strings.TrimSpace(string(data))
if err := ValidateKey(key); err != nil {
return "", fmt.Errorf("invalid key in file %s: %w", path, err)
}
return key, nil
}
// storeTempKey stores a temporary key and tracks it for cleanup
func storeTempKey(path string, key string) error {
// Create temp directory if needed
if err := os.MkdirAll(TempDir, 0700); err != nil {
return fmt.Errorf("failed to create temp directory %s: %w", TempDir, err)
}
// Trim whitespace
key = strings.TrimSpace(key)
// Write to file with 0600 permissions
if err := os.WriteFile(path, []byte(key), 0600); err != nil {
return fmt.Errorf("failed to write temp key to %s: %w", path, err)
}
// Track for cleanup
tempKeysMutex.Lock()
tempKeys[path] = true
tempKeysMutex.Unlock()
return nil
}
// CleanupTempKeys removes all temporary keys
func CleanupTempKeys() error {
tempKeysMutex.Lock()
defer tempKeysMutex.Unlock()
var cleanupErrors []string
for path := range tempKeys {
if err := os.Remove(path); err != nil && !os.IsNotExist(err) {
cleanupErrors = append(cleanupErrors, fmt.Sprintf("%s: %v", path, err))
log.Printf("Warning: failed to remove temp key %s: %v", path, err)
}
delete(tempKeys, path)
}
// Also attempt to clean up the temp directory if empty
if _, err := os.ReadDir(TempDir); err == nil {
if err := os.Remove(TempDir); err != nil && !os.IsNotExist(err) {
log.Printf("Warning: failed to remove temp directory %s: %v", TempDir, err)
}
}
if len(cleanupErrors) > 0 {
return fmt.Errorf("cleanup errors: %s", strings.Join(cleanupErrors, "; "))
}
return nil
}

View File

@@ -0,0 +1,204 @@
package wireguard
import (
"bytes"
"fmt"
"os/exec"
"regexp"
"strconv"
"strings"
"time"
)
const (
// StatusConnected indicates a peer has an active connection
StatusConnected = "Connected"
// StatusDisconnected indicates a peer is not connected
StatusDisconnected = "Disconnected"
)
// PeerStatus represents the status of a WireGuard peer
type PeerStatus struct {
PublicKey string `json:"public_key"`
Endpoint string `json:"endpoint"`
AllowedIPs string `json:"allowed_ips"`
LatestHandshake time.Time `json:"latest_handshake"`
TransferRx string `json:"transfer_rx"`
TransferTx string `json:"transfer_tx"`
Status string `json:"status"` // "Connected" or "Disconnected"
}
// GetClientStatus checks if a specific client is connected
// Returns "Connected" if the peer appears in the active peers list, "Disconnected" otherwise
func GetClientStatus(publicKey string) (string, error) {
peers, err := GetAllPeers()
if err != nil {
return StatusDisconnected, fmt.Errorf("failed to get peer status: %w", err)
}
for _, peer := range peers {
if peer.PublicKey == publicKey {
return peer.Status, nil
}
}
return StatusDisconnected, nil
}
// GetAllPeers retrieves all peers with their current status from WireGuard
func GetAllPeers() ([]PeerStatus, error) {
output, err := exec.Command("wg", "show", "wg0").Output()
if err != nil {
return nil, fmt.Errorf("failed to execute wg show: %w", err)
}
return parsePeersOutput(string(output)), nil
}
// parsePeersOutput parses the output of 'wg show wg0' command
func parsePeersOutput(output string) []PeerStatus {
var peers []PeerStatus
var currentPeer *PeerStatus
var handshake string
var transfer string
lines := strings.Split(output, "\n")
peerLineRegex := regexp.MustCompile(`^peer:\s*(.+)$`)
handshakeRegex := regexp.MustCompile(`^latest handshake:\s*(.+)\s+ago$`)
transferRegex := regexp.MustCompile(`^transfer:\s*(.+),\s+(.+)$`)
for _, line := range lines {
line = strings.TrimSpace(line)
// Check for new peer
if match := peerLineRegex.FindStringSubmatch(line); match != nil {
// Save previous peer if exists
if currentPeer != nil {
peers = append(peers, finalizePeerStatus(currentPeer, handshake, transfer))
}
// Start new peer
currentPeer = &PeerStatus{
PublicKey: match[1],
}
handshake = ""
transfer = ""
continue
}
if currentPeer == nil {
continue
}
// Parse endpoint
if strings.HasPrefix(line, "endpoint:") {
currentPeer.Endpoint = strings.TrimSpace(strings.TrimPrefix(line, "endpoint:"))
}
// Parse allowed ips
if strings.HasPrefix(line, "allowed ips:") {
currentPeer.AllowedIPs = strings.TrimSpace(strings.TrimPrefix(line, "allowed ips:"))
}
// Parse latest handshake
if match := handshakeRegex.FindStringSubmatch(line); match != nil {
handshake = match[1]
}
// Parse transfer
if match := transferRegex.FindStringSubmatch(line); match != nil {
transfer = fmt.Sprintf("%s, %s", match[1], match[2])
}
}
// Don't forget the last peer
if currentPeer != nil {
peers = append(peers, finalizePeerStatus(currentPeer, handshake, transfer))
}
return peers
}
// finalizePeerStatus determines the peer's status based on handshake time
func finalizePeerStatus(peer *PeerStatus, handshake string, transfer string) PeerStatus {
peer.TransferRx = ""
peer.TransferTx = ""
// Parse transfer
if transfer != "" {
parts := strings.Split(transfer, ", ")
if len(parts) == 2 {
// Extract received and sent values
rxParts := strings.Fields(parts[0])
if len(rxParts) >= 2 {
peer.TransferRx = strings.Join(rxParts[:2], " ")
}
txParts := strings.Fields(parts[1])
if len(txParts) >= 2 {
peer.TransferTx = strings.Join(txParts[:2], " ")
}
}
}
// Determine status based on handshake
if handshake != "" {
peer.LatestHandshake = parseHandshake(handshake)
// Peer is considered connected if handshake is recent (within 3 minutes)
if time.Since(peer.LatestHandshake) < 3*time.Minute {
peer.Status = StatusConnected
} else {
peer.Status = StatusDisconnected
}
} else {
peer.Status = StatusDisconnected
}
return *peer
}
// parseHandshake converts handshake string to time.Time
func parseHandshake(handshake string) time.Time {
now := time.Now()
parts := strings.Fields(handshake)
for i, part := range parts {
if strings.HasSuffix(part, "second") || strings.HasSuffix(part, "seconds") {
if val, err := strconv.Atoi(strings.TrimSuffix(part, "s")); err == nil {
return now.Add(-time.Duration(val) * time.Second)
}
}
if strings.HasSuffix(part, "minute") || strings.HasSuffix(part, "minutes") {
if val, err := strconv.Atoi(strings.TrimSuffix(part, "s")); err == nil {
return now.Add(-time.Duration(val) * time.Minute)
}
}
if strings.HasSuffix(part, "hour") || strings.HasSuffix(part, "hours") {
if val, err := strconv.Atoi(strings.TrimSuffix(part, "s")); err == nil {
return now.Add(-time.Duration(val) * time.Hour)
}
}
if strings.HasSuffix(part, "day") || strings.HasSuffix(part, "days") {
if val, err := strconv.Atoi(strings.TrimSuffix(part, "s")); err == nil {
return now.Add(-time.Duration(val) * 24 * time.Hour)
}
}
// Handle "ago" word
if i > 0 && (part == "ago" || part == "ago,") {
// Continue parsing time units
}
}
return now.Add(-time.Hour) // Default to 1 hour ago if parsing fails
}
// CheckInterface verifies if the WireGuard interface exists and is accessible
func CheckInterface(interfaceName string) error {
cmd := exec.Command("wg", "show", interfaceName)
var out bytes.Buffer
cmd.Stdout = &out
err := cmd.Run()
if err != nil {
return fmt.Errorf("wireguard interface '%s' not accessible: %w", interfaceName, err)
}
return nil
}

View File

@@ -0,0 +1,27 @@
package wireguard
import (
"time"
tea "github.com/charmbracelet/bubbletea"
)
// StatusTickMsg is sent when it's time to refresh the status
type StatusTickMsg struct{}
// RefreshStatusMsg is sent when manual refresh is triggered
type RefreshStatusMsg struct{}
// Tick returns a tea.Cmd that will send StatusTickMsg at the specified interval
func Tick(interval int) tea.Cmd {
return tea.Tick(time.Duration(interval)*time.Second, func(t time.Time) tea.Msg {
return StatusTickMsg{}
})
}
// ManualRefresh returns a tea.Cmd to trigger immediate status refresh
func ManualRefresh() tea.Cmd {
return func() tea.Msg {
return RefreshStatusMsg{}
}
}