// Package wireguard provides helpers for generating, validating, storing and
// loading WireGuard keys. Key generation shells out to the `wg` command-line
// tool, which must be available on PATH.
package wireguard

import (
	"encoding/base64"
	"errors"
	"fmt"
	"log"
	"os"
	"os/exec"
	"path/filepath"
	"strings"
	"sync"
)

const (
	// KeyLength is the standard WireGuard key length in bytes (32 bytes = 256 bits).
	KeyLength = 32

	// Base64KeyLength is the length of a base64-encoded WireGuard key (44 characters).
	Base64KeyLength = 44

	// TempDir is the directory used for temporary copies of generated keys.
	TempDir = "/tmp/wg-admin"
)

var (
	// tempKeys records every temporary key file written by storeTempKey so
	// that CleanupTempKeys can remove it later. Guarded by tempKeysMutex.
	tempKeys      = make(map[string]bool)
	tempKeysMutex sync.Mutex
)

// KeyPair represents a WireGuard key pair.
type KeyPair struct {
	PrivateKey string // Base64-encoded private key
	PublicKey  string // Base64-encoded public key
}

// GeneratePrivateKey generates a new WireGuard private key using `wg genkey`.
//
// The returned key is the trimmed base64 string (no trailing newline). A copy
// is also written to TempDir/private.key; failure to write that copy is only
// logged and does not fail the call.
func GeneratePrivateKey() (string, error) {
	return generateKey("", "private.key", "private key", "genkey")
}

// GeneratePublicKey derives the public key for privateKey using `wg pubkey`.
//
// privateKey must pass ValidateKey (surrounding whitespace is tolerated). The
// returned key is trimmed, and a copy is written to TempDir/public.key on a
// best-effort basis.
func GeneratePublicKey(privateKey string) (string, error) {
	if err := ValidateKey(privateKey); err != nil {
		return "", fmt.Errorf("invalid private key: %w", err)
	}
	// ValidateKey accepts leading/trailing whitespace, but wg pubkey expects
	// the key at the very start of stdin, so pass the trimmed form.
	return generateKey(strings.TrimSpace(privateKey)+"\n", "public.key", "public key", "pubkey")
}

// GeneratePSK generates a new pre-shared key using `wg genpsk`.
//
// The returned key is trimmed, and a copy is written to TempDir/psk.key on a
// best-effort basis.
func GeneratePSK() (string, error) {
	return generateKey("", "psk.key", "PSK", "genpsk")
}

// generateKey runs `wg <args...>` (feeding stdin to it when non-empty),
// validates that the command printed a well-formed key, and stores a copy in
// TempDir under tempName. what names the key in error and log messages.
func generateKey(stdin, tempName, what string, args ...string) (string, error) {
	key, err := runWG(stdin, args...)
	if err != nil {
		return "", err
	}
	if err := ValidateKey(key); err != nil {
		return "", fmt.Errorf("generated invalid %s: %w", what, err)
	}
	// The temporary copy is best-effort; the caller still receives the key.
	if err := storeTempKey(filepath.Join(TempDir, tempName), key); err != nil {
		log.Printf("Warning: failed to store temporary %s: %v", what, err)
	}
	return key, nil
}

// runWG executes the wg binary with args and returns its trimmed standard
// output. Standard error is kept separate so diagnostics printed by wg are
// never mistaken for key material; on failure it is included in the error.
func runWG(stdin string, args ...string) (string, error) {
	cmd := exec.Command("wg", args...)
	if stdin != "" {
		cmd.Stdin = strings.NewReader(stdin)
	}
	output, err := cmd.Output()
	if err != nil {
		var stderr string
		var exitErr *exec.ExitError
		if errors.As(err, &exitErr) {
			// Output() captures stderr into ExitError when cmd.Stderr is nil.
			stderr = strings.TrimSpace(string(exitErr.Stderr))
		}
		return "", fmt.Errorf("wg %s failed: %w, output: %s", strings.Join(args, " "), err, stderr)
	}
	return strings.TrimSpace(string(output)), nil
}

// GenerateKeyPair generates a complete WireGuard key pair (private + public).
func GenerateKeyPair() (*KeyPair, error) {
	privateKey, err := GeneratePrivateKey()
	if err != nil {
		return nil, err
	}

	publicKey, err := GeneratePublicKey(privateKey)
	if err != nil {
		return nil, fmt.Errorf("failed to generate public key: %w", err)
	}

	return &KeyPair{
		PrivateKey: privateKey,
		PublicKey:  publicKey,
	}, nil
}

// ValidateKey reports whether key is a well-formed WireGuard key: after
// trimming surrounding whitespace it must be Base64KeyLength characters of
// standard base64 that decode to exactly KeyLength bytes.
func ValidateKey(key string) error {
	key = strings.TrimSpace(key)

	if len(key) != Base64KeyLength {
		return fmt.Errorf("invalid key length: expected %d characters, got %d", Base64KeyLength, len(key))
	}

	decoded, err := base64.StdEncoding.DecodeString(key)
	if err != nil {
		return fmt.Errorf("invalid base64 encoding: %w", err)
	}

	if len(decoded) != KeyLength {
		return fmt.Errorf("invalid decoded key length: expected %d bytes, got %d", KeyLength, len(decoded))
	}

	return nil
}

// StoreKey atomically writes a validated, whitespace-trimmed key to path with
// 0600 permissions, creating parent directories (0755) as needed.
//
// The data is written to a freshly created temporary file in the same
// directory, synced, and renamed over path, so readers never observe a
// partially written key and a stale or planted "<path>.tmp" cannot loosen the
// resulting file's permissions.
func StoreKey(path string, key string) error {
	if err := ValidateKey(key); err != nil {
		return fmt.Errorf("invalid key: %w", err)
	}
	key = strings.TrimSpace(key)

	dir := filepath.Dir(path)
	if err := os.MkdirAll(dir, 0755); err != nil {
		return fmt.Errorf("failed to create directory %s: %w", dir, err)
	}

	return writeFileAtomic(path, []byte(key))
}

// writeFileAtomic writes data to path via an exclusively created 0600
// temporary file in path's directory, followed by fsync and rename. The
// temporary file is removed on every failure path.
func writeFileAtomic(path string, data []byte) error {
	tmp, err := os.CreateTemp(filepath.Dir(path), filepath.Base(path)+".tmp*")
	if err != nil {
		return fmt.Errorf("failed to create temp file for %s: %w", path, err)
	}
	tmpPath := tmp.Name()

	committed := false
	defer func() {
		if !committed {
			os.Remove(tmpPath) // best-effort cleanup; the original error is what matters
		}
	}()

	if _, err := tmp.Write(data); err != nil {
		tmp.Close()
		return fmt.Errorf("failed to write temp file %s: %w", tmpPath, err)
	}
	if err := tmp.Sync(); err != nil {
		tmp.Close()
		return fmt.Errorf("failed to sync temp file %s: %w", tmpPath, err)
	}
	if err := tmp.Close(); err != nil {
		return fmt.Errorf("failed to close temp file %s: %w", tmpPath, err)
	}

	// rename replaces whatever is at path (including a symlink) instead of
	// writing through it.
	if err := os.Rename(tmpPath, path); err != nil {
		return fmt.Errorf("failed to rename temp file to %s: %w", path, err)
	}
	committed = true
	return nil
}

// LoadKey reads a key from path, trims surrounding whitespace and validates it.
func LoadKey(path string) (string, error) {
	data, err := os.ReadFile(path)
	if err != nil {
		return "", fmt.Errorf("failed to read key file %s: %w", path, err)
	}

	key := strings.TrimSpace(string(data))
	if err := ValidateKey(key); err != nil {
		return "", fmt.Errorf("invalid key in file %s: %w", path, err)
	}

	return key, nil
}

// storeTempKey writes a trimmed copy of key to path (expected to be inside
// TempDir) with 0600 permissions and records it for CleanupTempKeys.
//
// TempDir lives under /tmp, which is typically shared between users, so the
// directory is rejected if it is a symlink (or otherwise not a directory),
// and the file is replaced via writeFileAtomic rather than opened in place;
// a file or symlink pre-placed at path can therefore neither redirect the
// write nor keep looser permissions.
func storeTempKey(path string, key string) error {
	if err := os.MkdirAll(TempDir, 0700); err != nil {
		return fmt.Errorf("failed to create temp directory %s: %w", TempDir, err)
	}

	info, err := os.Lstat(TempDir)
	if err != nil {
		return fmt.Errorf("failed to stat temp directory %s: %w", TempDir, err)
	}
	if !info.IsDir() {
		return fmt.Errorf("temp directory %s is not a real directory (mode %v)", TempDir, info.Mode())
	}

	if err := writeFileAtomic(path, []byte(strings.TrimSpace(key))); err != nil {
		return fmt.Errorf("failed to write temp key to %s: %w", path, err)
	}

	tempKeysMutex.Lock()
	tempKeys[path] = true
	tempKeysMutex.Unlock()

	return nil
}

// CleanupTempKeys removes every temporary key file recorded by storeTempKey
// and then removes TempDir if it is left empty.
//
// Paths that no longer exist count as cleaned up. Paths whose removal fails
// are logged, reported in the returned error, and kept in the tracking set so
// that a later call can retry them.
func CleanupTempKeys() error {
	tempKeysMutex.Lock()
	defer tempKeysMutex.Unlock()

	var cleanupErrors []string
	for path := range tempKeys {
		if err := os.Remove(path); err != nil && !os.IsNotExist(err) {
			cleanupErrors = append(cleanupErrors, fmt.Sprintf("%s: %v", path, err))
			log.Printf("Warning: failed to remove temp key %s: %v", path, err)
			continue // keep tracking it so a later cleanup can retry
		}
		delete(tempKeys, path) // deleting the current key during range is allowed
	}

	// Only remove the directory when it is actually empty; removing a
	// non-empty directory would fail and log a spurious warning.
	if entries, err := os.ReadDir(TempDir); err == nil && len(entries) == 0 {
		if err := os.Remove(TempDir); err != nil && !os.IsNotExist(err) {
			log.Printf("Warning: failed to remove temp directory %s: %v", TempDir, err)
		}
	}

	if len(cleanupErrors) > 0 {
		return fmt.Errorf("cleanup errors: %s", strings.Join(cleanupErrors, "; "))
	}
	return nil
}