From 26120b8bc2682a511c18effebeb2e36a3207e146 Mon Sep 17 00:00:00 2001 From: Calmcacil Date: Mon, 12 Jan 2026 19:03:35 +0100 Subject: [PATCH] Add WireGuard TUI implementation - Add Go TUI with bubbletea for WireGuard management - Implement client CRUD operations with QR code generation - Add configuration and validation modules - Install/update scripts for client setup - Update Makefile to build binaries to bin/ directory - Add .gitignore for Go projects --- .agent/AGENTS.md | 53 ++ .agent/README.md | 30 + .agent/beads.md | 88 +++ .agent/explore.md | 46 ++ .agent/librarian.md | 51 ++ .agent/oracle.md | 71 +++ .gitignore | 27 + AGENTS.md | 50 +- GITEA_ISSUES.md | 1 - Makefile | 17 +- README.md | 84 ++- cmd/wg-tui/main.go | 13 + config.example | 28 + go.mod | 36 ++ go.sum | 63 ++ internal/backup/backup.go | 267 +++++++++ internal/backup/restore.go | 122 ++++ internal/config/config.go | 237 ++++++++ internal/tui/components/confirm.go | 143 +++++ internal/tui/components/search.go | 307 ++++++++++ internal/tui/model.go | 89 +++ internal/tui/screens/add.go | 152 +++++ internal/tui/screens/detail.go | 299 ++++++++++ internal/tui/screens/interface.go | 31 + internal/tui/screens/list.go | 324 ++++++++++ internal/tui/screens/qr.go | 141 +++++ internal/tui/screens/restore.go | 333 +++++++++++ internal/tui/theme/theme.go | 185 ++++++ internal/validation/client.go | 86 +++ internal/wireguard/client.go | 585 ++++++++++++++++++ internal/wireguard/config.go | 106 ++++ internal/wireguard/keys.go | 243 ++++++++ internal/wireguard/status.go | 204 +++++++ internal/wireguard/tea_messages.go | 27 + test-wg-install.sh | 124 ++++ wg-install.sh | 921 +++++++++++++++++++++++++++++ wireguard.sh | 843 ++++++++++++++++++++++++-- 37 files changed, 6330 insertions(+), 97 deletions(-) create mode 100644 .agent/AGENTS.md create mode 100644 .agent/README.md create mode 100644 .agent/beads.md create mode 100644 .agent/explore.md create mode 100644 .agent/librarian.md create mode 100644 .agent/oracle.md create mode 100644 .gitignore create mode 100644 config.example create mode 100644 go.mod create mode 100644 go.sum create mode 100644 internal/backup/backup.go create mode 100644 internal/backup/restore.go create mode 100644 internal/config/config.go create mode 100644 internal/tui/components/confirm.go create mode 100644 internal/tui/components/search.go create mode 100644 internal/tui/model.go create mode 100644 internal/tui/screens/add.go create mode 100644 internal/tui/screens/detail.go create mode 100644 internal/tui/screens/interface.go create mode 100644 internal/tui/screens/list.go create mode 100644 internal/tui/screens/qr.go create mode 100644 internal/tui/screens/restore.go create mode 100644 internal/tui/theme/theme.go create mode 100644 internal/validation/client.go create mode 100644 internal/wireguard/client.go create mode 100644 internal/wireguard/config.go create mode 100644 internal/wireguard/keys.go create mode 100644 internal/wireguard/status.go create mode 100644 internal/wireguard/tea_messages.go create mode 100755 test-wg-install.sh create mode 100755 wg-install.sh diff --git a/.agent/AGENTS.md b/.agent/AGENTS.md new file mode 100644 index 0000000..ac1ce6b --- /dev/null +++ b/.agent/AGENTS.md @@ -0,0 +1,53 @@ +# Project-Specific Agent Instructions + +## YOUR MANDATE +0. **Use the beads CLI workflow**: You will always use the beads cli for your work as explained in tool_preferences +1. **Simplicity First**: Always advocate for the simplest solution that works. Reject complexity unless it is proven necessary. +2. **DRY & YAGNI**: These are your non-negotiable pillars. Identify redundancy and premature optimization immediately. +3. **Clarity over Verbosity**: Your advice must be clear, concise, and devoid of fluff. Do not be overly descriptive. Get to the point. +4. **Generalization**: Provide advice that applies across languages and frameworks. Focus on the *pattern*, not the *syntax*. +5. **Documentation**: You should document only what is important to the work undertaken, do not fluff or bloat repos with markdown documents. +6. **Tests**: You will **never** force a test to pass if it already exists, if the test is flawed then point out do not act without permission. +7. **Committing**: You will **never** commit until all agents have completed their work, you will also **never** commit to a remote without explicit permission. +8. **Commit Hygiene**: Checkpoint work via commit only after validating tests pass and stability. Ensure clean commit hygiene. + +### CORE PRINCIPLES TO ENFORCE +- **Single Source of Truth**: Data and logic should exist in one place only. +- **Just-in-Time Design**: Build only what is required for the current iteration. +- **Code is Liability**: Less code means fewer bugs. Delete unused code ruthlessly. +- **Explicit over Implicit**: Magic is bad. Clear flow is good. +- **Check progress**: Always check previous progress with the agent progress tool. +- **Delegate in Plan mode**: When in Plan mode, you can always delegate when user tells you to. + +### WORK DISCIPLINE (from Sisyphus methodology) + +#### Todo Management (MANDATORY for multi-step tasks) +- **Create todos BEFORE starting** any task with 2+ steps +- **Mark in_progress** before starting each step (only ONE at a time) +- **Mark completed IMMEDIATELY** after each step (NEVER batch completions) +- **Update todos** if scope changes before proceeding +- Todos provide user visibility, prevent drift, and enable recovery + +#### Code Quality +- **No excessive comments**: Code should be self-documenting. Only add comments that explain WHY, not WHAT. +- **No type suppressions**: Never use `as any`, `@ts-ignore`, `@ts-expect-error` +- **No empty catch blocks**: Always handle errors meaningfully +- **Match existing patterns**: Your code should look like the team wrote it + +#### Agent Delegation +- **Use @mentions** to invoke specialized subagents: `@explore`, `@librarian`, `@oracle`, `@frontend-ui-ux-engineer` +- **Frontend visual work** (styling, layout, animation) → delegate to `@frontend-ui-ux-engineer` +- **Architecture decisions** or debugging after 2+ failed attempts → consult `@oracle` +- **External docs/library questions** → delegate to `@librarian` +- **Codebase exploration** → delegate to `@explore` + +#### Failure Recovery +- Fix root causes, not symptoms +- Re-verify after EVERY fix attempt +- After 3 consecutive failures: STOP, revert to working state, document what failed, consult oracle + +#### Completion Criteria +A task is complete when: +- [ ] All planned todo items marked done +- [ ] Build passes (if applicable) +- [ ] User's original request fully addressed diff --git a/.agent/README.md b/.agent/README.md new file mode 100644 index 0000000..35826ae --- /dev/null +++ b/.agent/README.md @@ -0,0 +1,30 @@ +# Agent Instructions Directory + +This directory contains project-specific instructions for agents working on the wg-admin project. + +## Required Reading + +All agents MUST read `AGENTS.md` before starting any work. + +## Available Instructions + +- `AGENTS.md` - Core principles and work discipline (READ THIS FIRST) +- `beads.md` - Beads workflow and command reference +- `explore.md` - Instructions for codebase exploration +- `librarian.md` - Instructions for research and documentation +- `oracle.md` - Instructions for architecture and debugging + +## Agent Workflow + +1. Read `.agent/AGENTS.md` for project mandates +2. Read agent-specific instructions (e.g., `explore.md`) +3. Execute work following established patterns +4. Use beads for tracking multi-session or complex work +5. Run `bd sync` before ending session + +## Project Context + +- **Issue tracking**: beads/bd CLI +- **External issues**: Gitea (https://gitea.calmcacil.dev) +- **Technology**: Bash/shell scripting +- **Purpose**: WireGuard VPN management tool diff --git a/.agent/beads.md b/.agent/beads.md new file mode 100644 index 0000000..a8c99fa --- /dev/null +++ b/.agent/beads.md @@ -0,0 +1,88 @@ +# Beads Workflow Guide + +## Quick Start + +```bash +bd ready # Find available work +bd show # View issue details +bd update --status in_progress # Claim work +bd close # Complete work +bd sync # Sync with git +``` + +## Workflow Commands + +### Finding Work +- `bd ready` - Show issues ready to work (no blockers) +- `bd list --status=open` - All open issues +- `bd list --status=in_progress` - Your active work +- `bd show ` - Detailed issue view with dependencies + +### Creating & Updating +- `bd create --title="..." --type=task|bug|feature --priority=2` - New issue + - Priority: 0-4 or P0-P4 (0=critical, 2=medium, 4=backlog). NOT "high"/"medium"/"low" +- `bd update --status=in_progress` - Claim work +- `bd update --assignee=username` - Assign to someone +- `bd close ` - Mark complete +- `bd close ...` - Close multiple issues at once (more efficient) +- `bd close --reason="explanation"` - Close with reason +- **Tip**: When creating multiple issues/tasks/epics, use parallel subagents for efficiency + +### Dependencies & Blocking +- `bd dep add ` - Add dependency (issue depends on depends-on) +- `bd blocked` - Show all blocked issues +- `bd show ` - See what's blocking/blocked by this issue + +### Sync & Collaboration +- `bd sync` - Sync with git remote (run at session end) +- `bd sync --status` - Check sync status without syncing + +### Project Health +- `bd stats` - Project statistics (open/closed/blocked counts) +- `bd doctor` - Check for issues (sync problems, missing hooks) + +## Common Workflows + +### Starting Work +```bash +bd ready # Find available work +bd show # Review issue details +bd update --status=in_progress # Claim it +``` + +### Completing Work +```bash +bd close ... # Close all completed issues at once +bd sync # Push to remote +``` + +### Creating Dependent Work +```bash +# Run bd create commands in parallel (use subagents for many items) +bd create --title="Implement feature X" --type=feature +bd create --title="Write tests for X" --type=task +bd dep add beads-yyy beads-xxx # Tests depend on Feature (Feature blocks tests) +``` + +## Integration with Gitea + +When working on Gitea issues: +1. Create a beads issue to track the work +2. Link to Gitea issue in description or comments +3. When committing, use `Closes #{gitea_number}` to auto-close +4. Close the beads issue after push succeeds + +## Session Close Protocol + +**CRITICAL**: Before saying "done", run this checklist: + +``` +[ ] 1. git status (check what changed) +[ ] 2. git add (stage code changes) +[ ] 3. bd sync (commit beads changes) +[ ] 4. git commit -m "..." (commit code) +[ ] 5. bd sync (commit any new beads changes) +[ ] 6. git push (push to remote) +``` + +Work is NOT complete until `git push` succeeds. diff --git a/.agent/explore.md b/.agent/explore.md new file mode 100644 index 0000000..cf9f9a1 --- /dev/null +++ b/.agent/explore.md @@ -0,0 +1,46 @@ +# Explore Agent - Project-Specific Instructions + +## Your Role in This Project + +You are the codebase exploration expert for the wg-admin project. + +## Project Context + +- **Primary issue tracking**: `bd` (beads) CLI +- **External issue tracking**: Gitea API at https://gitea.calmcacil.dev +- **Key directories**: + - `.agent/` - Project-specific agent instructions + - `.beads/` - Beads issue tracking data + - `wireguard.sh` - Main script for WireGuard management + +## When to Use + +Use the explore agent when you need to: +- Understand code structure and organization +- Find where specific functionality is implemented +- Identify patterns across the codebase +- Locate files matching specific criteria +- Understand the project architecture + +## Workflow + +1. **Read `.agent/AGENTS.md`** first for project mandates +2. Use grep and glob tools for codebase exploration +3. Provide clear, concise findings with file paths and line numbers +4. If you find issues, consider creating beads issues for follow-up work + +## Key Search Targets + +- WireGuard configuration handling +- Firewall rule management +- Peer connection logic +- Configuration file parsing/generation +- Error handling patterns + +## Output Format + +When reporting findings: +- Use `file_path:line_number` format for references +- Keep descriptions concise and actionable +- Highlight patterns, not just locations +- Note any code smells or anti-patterns you discover diff --git a/.agent/librarian.md b/.agent/librarian.md new file mode 100644 index 0000000..bafeba4 --- /dev/null +++ b/.agent/librarian.md @@ -0,0 +1,51 @@ +# Librarian Agent - Project-Specific Instructions + +## Your Role in This Project + +You are the research and documentation expert for the wg-admin project. + +## Project Context + +- **Technology stack**: Bash/shell scripting for WireGuard management +- **Documentation style**: Minimal, task-focused markdown +- **Documentation location**: Project root (`README.md`, `GITEA_ISSUES.md`) + +## When to Use + +Use the librarian agent when you need to: +- Research WireGuard best practices and configuration options +- Find examples of similar tools or implementations +- Look up documentation for shell scripting patterns +- Research Gitea API usage +- Find security best practices for network management tools + +## Research Priorities + +1. **WireGuard**: Configuration syntax, peer management, routing +2. **Shell scripting**: Best practices, error handling, security +3. **Gitea API**: Issue management, webhooks, authentication +4. **Bash**: Modern patterns, POSIX compatibility concerns + +## Workflow + +1. **Read `.agent/AGENTS.md`** first for project mandates +2. Use websearch and context7 tools for research +3. Focus on practical, actionable information +4. Provide code examples when relevant +5. Cite sources when providing specific recommendations + +## Output Format + +When providing research: +- Prioritize official documentation over blog posts +- Provide code snippets when relevant +- Note version-specific information +- Highlight security considerations +- Keep responses concise and focused + +## Documentation Principles + +- **Don't create unnecessary docs** - Only document what's critical +- **Code over docs** - Self-documenting code preferred +- **Update in-place** - Modify existing docs, don't create new ones unless needed +- **User-focused** - Write for the people using this tool diff --git a/.agent/oracle.md b/.agent/oracle.md new file mode 100644 index 0000000..a331d61 --- /dev/null +++ b/.agent/oracle.md @@ -0,0 +1,71 @@ +# Oracle Agent - Project-Specific Instructions + +## Your Role in This Project + +You are the architecture and debugging expert for the wg-admin project. + +## Project Context + +- **Type**: Network administration tool (WireGuard management) +- **Language**: Bash/shell scripting +- **Critical concerns**: Security, reliability, error handling +- **Integration**: Gitea for external issue tracking + +## When to Use + +Use the oracle agent when: +- Architecture decisions need to be made +- Debugging complex issues after 2+ failed attempts +- Security review is needed +- Performance optimization is required +- Multiple solutions exist and you need to recommend the best approach + +## Your Approach + +1. **Read `.agent/AGENTS.md`** first for project mandates +2. Gather full context - read relevant code, logs, error messages +3. Apply first principles thinking +4. Consider trade-offs: simplicity vs completeness +5. Recommend the simplest solution that works +6. Explain the "why" behind your recommendation + +## Key Concerns + +### Security +- Credential handling +- File permissions +- Input validation +- Injection vulnerabilities (shell injection, command injection) +- Privilege escalation risks + +### Reliability +- Error handling completeness +- Idempotent operations +- Transaction safety +- Rollback mechanisms +- State consistency + +### Maintainability +- Code organization +- Testing approach (what tests exist, what's missing) +- Dependency management +- Documentation adequacy + +## Debugging Process + +1. **Understand the symptom** - What's failing? +2. **Reproduce** - Can you create a minimal reproduction? +3. **Isolate** - What's the minimal code path that exhibits the issue? +4. **Hypothesize** - What's the likely root cause? +5. **Test** - Verify or disprove your hypothesis +6. **Fix** - Apply the minimal fix that resolves the root cause +7. **Verify** - Ensure the fix doesn't break anything else + +## Output Format + +When providing recommendations: +- Start with the recommended solution +- Explain the reasoning concisely +- Discuss trade-offs if relevant +- Provide implementation guidance +- Note potential pitfalls or edge cases diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..efc8c12 --- /dev/null +++ b/.gitignore @@ -0,0 +1,27 @@ +# Binaries +bin/ +*.exe +*.dll +*.so +*.dylib + +# Test binary, built with `go test -c` +*.test + +# Output of the go coverage tool +*.out +coverage.html + +# Go workspace file +go.work + +# IDE specific +.vscode/ +.idea/ +*.swp +*.swo +*~ + +# OS specific +.DS_Store +Thumbs.db diff --git a/AGENTS.md b/AGENTS.md index a2f03ae..43fb505 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -1,22 +1,14 @@ # Agent Instructions +**IMPORTANT**: All agents MUST read `.agent/AGENTS.md` for project-specific instructions before starting work. + +## Project Overview + This project uses **bd** (beads) for issue tracking and **Gitea** for external issue tracking. -## Gitea Issues Workflow - -When asked to work on Gitea issues: -1. Read issue via API: `curl -s "https://gitea.calmcacil.dev/api/v1/repos/{owner}/{repo}/issues/{number}"` -2. Analyze requirements and scope -3. Create task list with `todowrite` for multi-step work -4. Implement fix following existing patterns -5. Commit with `Closes #{number}` in message (auto-closes issue) -6. Push to remote - -See `GITEA_ISSUES.md` for detailed workflow. - -## Beads Workflow - -Run `bd onboard` to get started with beads. +- **Primary workflow**: Use `bd` CLI for issue tracking +- **External issues**: Gitea API for external bug/feature tracking +- **Agent instructions**: `.agent/` directory contains project-specific agent guidance ## Quick Reference @@ -28,29 +20,9 @@ bd close # Complete work bd sync # Sync with git ``` -## Landing the Plane (Session Completion) +## Documentation -**When ending a work session**, you MUST complete ALL steps below. Work is NOT complete until `git push` succeeds. - -**MANDATORY WORKFLOW:** - -1. **File issues for remaining work** - Create issues for anything that needs follow-up -2. **Run quality gates** (if code changed) - Tests, linters, builds -3. **Update issue status** - Close finished work, update in-progress items -4. **PUSH TO REMOTE** - This is MANDATORY: - ```bash - git pull --rebase - bd sync - git push - git status # MUST show "up to date with origin" - ``` -5. **Clean up** - Clear stashes, prune remote branches -6. **Verify** - All changes committed AND pushed -7. **Hand off** - Provide context for next session - -**CRITICAL RULES:** -- Work is NOT complete until `git push` succeeds -- NEVER stop before pushing - that leaves work stranded locally -- NEVER say "ready to push when you are" - YOU must push -- If push fails, resolve and retry until it succeeds +- `.agent/AGENTS.md` - Project-specific agent instructions (READ THIS FIRST) +- `GITEA_ISSUES.md` - Gitea issue workflow details +- `README.md` - Project overview and setup diff --git a/GITEA_ISSUES.md b/GITEA_ISSUES.md index 9088225..5bb6954 100644 --- a/GITEA_ISSUES.md +++ b/GITEA_ISSUES.md @@ -44,7 +44,6 @@ Closes #{issue_number}" ### 7. Push and Close ```bash -bd sync git push ``` diff --git a/Makefile b/Makefile index cfeb4a3..6cdd56d 100644 --- a/Makefile +++ b/Makefile @@ -9,7 +9,8 @@ BINARY=wg-tui CMD_PATH=$(ROOTDIR)/cmd/$(BINARY) # Build directory -BUILD_DIR=build +BUILD_DIR=bin +BINARY_PATH=$(BUILD_DIR)/$(BINARY) # Go parameters GOCMD=go @@ -27,23 +28,25 @@ help: ## Show this help message build: ## Build the binary @echo "Building $(BINARY)..." - @$(GOBUILD) -C $(ROOTDIR) -o $(BINARY) cmd/$(BINARY)/main.go - @echo "Build complete: $(BINARY)" + @mkdir -p $(BUILD_DIR) + @$(GOBUILD) -C $(ROOTDIR) -o $(BINARY_PATH) cmd/$(BINARY)/main.go + @echo "Build complete: $(BINARY_PATH)" build-all: ## Build all binaries @echo "Building all binaries..." - @$(GOBUILD) -C $(ROOTDIR) -o $(BINARY) ./... + @mkdir -p $(BUILD_DIR) + @$(GOBUILD) -C $(ROOTDIR) -o $(BINARY_PATH) ./... clean: ## Clean build artifacts @echo "Cleaning..." @$(GOCLEAN) - @rm -f $(BINARY) + @rm -rf $(BUILD_DIR) @echo "Clean complete" install: ## Install the binary to $GOPATH/bin @echo "Installing $(BINARY)..." @$(GOBUILD) -C $(ROOTDIR) -o $$($(GOCMD) env GOPATH)/bin/$(BINARY) cmd/$(BINARY)/main.go - @echo "Install complete" + @echo "Install complete: $$($(GOCMD) env GOPATH)/bin/$(BINARY)" test: ## Run tests @echo "Running tests..." @@ -73,7 +76,7 @@ deps: ## Download dependencies run: build ## Build and run the binary @echo "Running $(BINARY)..." - @./$(BINARY) + @./$(BINARY_PATH) dev: ## Run in development mode with hot reload (requires air) @if command -v air > /dev/null; then \ diff --git a/README.md b/README.md index 3fb73f9..de9c5f0 100644 --- a/README.md +++ b/README.md @@ -3,23 +3,89 @@ ## Overview Personal WireGuard VPN server with IPv4/IPv6 support, client management via `wireguard.sh`, designed for 1 CPU / 1GB RAM VPS. +## Development & Issue Tracking + +This project uses **beads** (`bd` CLI) for issue tracking and **Gitea** for external issue tracking. + +- **Agent instructions**: See `AGENTS.md` and `.agent/` directory for project-specific guidance +- **Issue tracking**: Use `bd ready` to find available work +- **External issues**: Gitea at https://gitea.calmcacil.dev + ## Configuration -- **Server Domain**: velkhana.calmcacil.dev -- **Port**: 51820 -- **VPN IPv4 Range**: 10.10.69.0/24 -- **VPN IPv6 Range**: fd69:dead:beef:69::/64 -- **DNS**: 8.8.8.8, 8.8.4.4 (Google) -- **Server-side peer configs**: /etc/wireguard/conf.d/client-*.conf (loaded dynamically) -- **Client-side configs**: /etc/wireguard/clients/*.conf (for distribution) + +Configuration is managed through `/etc/wg-admin/config.conf`. Copy `config.example` to this location and customize for your environment. + +### Creating Configuration File + +```bash +sudo mkdir -p /etc/wg-admin +sudo cp config.example /etc/wg-admin/config.conf +sudo nano /etc/wg-admin/config.conf +``` + +### Configuration Variables + +| Variable | Default | Description | +|----------|---------|-------------| +| `SERVER_DOMAIN` | *Required* | Server domain or IP address (e.g., `vpn.example.com`) | +| `WG_PORT` | 51820 | WireGuard UDP port | +| `VPN_IPV4_RANGE` | 10.10.69.0/24 | VPN IPv4 address range | +| `VPN_IPV6_RANGE` | fd69:dead:beef:69::/64 | VPN IPv6 address range | +| `WG_INTERFACE` | wg0 | WireGuard interface name | +| `DNS_SERVERS` | 8.8.8.8, 8.8.4.4 | DNS servers for clients | +| `LOG_FILE` | /var/log/wireguard-admin.log | Log file location | + +### Example Configuration + +```ini +# Server domain or IP address (required) +SERVER_DOMAIN=vpn.example.com + +# WireGuard UDP port (optional, default: 51820) +WG_PORT=51820 + +# VPN IPv4 range (optional, default: 10.10.69.0/24) +VPN_IPV4_RANGE=10.10.69.0/24 + +# VPN IPv6 range (optional, default: fd69:dead:beef:69::/64) +VPN_IPV6_RANGE=fd69:dead:beef:69::/64 + +# DNS servers (optional, default: 8.8.8.8, 8.8.4.4) +DNS_SERVERS=8.8.8.8, 8.8.4.4 +``` + +**Note**: All values are optional except `SERVER_DOMAIN`. The script will use defaults if not specified. + +### Configuration Priority + +1. `/etc/wg-admin/config.conf` file (highest priority) +2. Environment variables (e.g., `SERVER_DOMAIN=vpn.example.com ./wireguard.sh install`) +3. Built-in defaults (lowest priority) + +### Other Directories +- **Server-side peer configs**: `/etc/wireguard/conf.d/client-*.conf` (loaded dynamically) +- **Client-side configs**: `/etc/wireguard/clients/*.conf` (for distribution) ## Installation ### 1. Upload script to VPS ```bash -scp wireguard.sh calmcacil@velkhana.calmcacil.dev:~/ +scp wireguard.sh calmcacil@your-vps.com:~/ +scp config.example calmcacil@your-vps.com:~/ ``` -### 2. Run installation +### 2. Configure the script +```bash +# Copy example config and customize +sudo mkdir -p /etc/wg-admin +sudo cp ~/config.example /etc/wg-admin/config.conf +sudo nano /etc/wg-admin/config.conf + +# Set at minimum: +# SERVER_DOMAIN=vpn.yourdomain.com +``` + +### 3. Run installation ```bash chmod +x ~/wireguard.sh sudo ~/wireguard.sh install diff --git a/cmd/wg-tui/main.go b/cmd/wg-tui/main.go index 5a0db99..c40ae58 100644 --- a/cmd/wg-tui/main.go +++ b/cmd/wg-tui/main.go @@ -60,6 +60,11 @@ func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) { // Switch to list screen m.currentScreen = screens.NewListScreen() return m, m.currentScreen.Init() + case "a": + // Switch to add screen + m.previousScreen = m.currentScreen + m.currentScreen = screens.NewAddScreen() + return m, m.currentScreen.Init() } case screens.ClientSelectedMsg: // User selected a client - show detail screen @@ -70,6 +75,14 @@ func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) { // Client was deleted - show success message and return to list m.currentScreen = screens.NewListScreen() return m, m.currentScreen.Init() + case screens.ClientCreatedMsg: + // Client was created - return to list screen + m.currentScreen = screens.NewListScreen() + return m, m.currentScreen.Init() + case screens.RestoreCompletedMsg: + // Restore completed - return to list screen to refresh clients + m.currentScreen = screens.NewListScreen() + return m, m.currentScreen.Init() case screens.CloseDetailScreenMsg: // Detail screen closed - go back to previous screen if m.previousScreen != nil { diff --git a/config.example b/config.example new file mode 100644 index 0000000..563b4d6 --- /dev/null +++ b/config.example @@ -0,0 +1,28 @@ +# WireGuard VPN Configuration +# Copy this file to /etc/wg-admin/config.conf and customize for your environment +# All values are optional - script will use defaults if not set + +# Server domain or IP address (required for client endpoints) +# Example: vpn.example.com or 203.0.113.10 +#SERVER_DOMAIN=vpn.example.com + +# WireGuard UDP port (default: 51820) +#WG_PORT=51820 + +# VPN IPv4 address range (default: 10.10.69.0/24) +#VPN_IPV4_RANGE=10.10.69.0/24 + +# VPN IPv6 address range (default: fd69:dead:beef:69::/64) +#VPN_IPV6_RANGE=fd69:dead:beef:69::/64 + +# WireGuard interface name (default: wg0) +#WG_INTERFACE=wg0 + +# DNS servers for clients (default: 8.8.8.8, 8.8.4.4) +#DNS_SERVERS=8.8.8.8, 8.8.4.4 + +# Log file location (default: /var/log/wireguard-admin.log) +#LOG_FILE=/var/log/wireguard-admin.log + +# Minimum disk space required in MB (default: 100) +#MIN_DISK_SPACE_MB=100 diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..3b0102d --- /dev/null +++ b/go.mod @@ -0,0 +1,36 @@ +module github.com/calmcacil/wg-admin + +go 1.24.4 + +require ( + github.com/atotto/clipboard v0.1.4 // indirect + github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect + github.com/catppuccin/go v0.3.0 // indirect + github.com/charmbracelet/bubbles v0.21.1-0.20250623103423-23b8fd6302d7 // indirect + github.com/charmbracelet/bubbletea v1.3.10 // indirect + github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc // indirect + github.com/charmbracelet/huh v0.8.0 // indirect + github.com/charmbracelet/lipgloss v1.1.0 // indirect + github.com/charmbracelet/x/ansi v0.10.1 // indirect + github.com/charmbracelet/x/cellbuf v0.0.13 // indirect + github.com/charmbracelet/x/exp/strings v0.0.0-20240722160745-212f7b056ed0 // indirect + github.com/charmbracelet/x/term v0.2.1 // indirect + github.com/dustin/go-humanize v1.0.1 // indirect + github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f // indirect + github.com/lucasb-eyer/go-colorful v1.2.0 // indirect + github.com/mattn/go-isatty v0.0.20 // indirect + github.com/mattn/go-localereader v0.0.1 // indirect + github.com/mattn/go-runewidth v0.0.16 // indirect + github.com/mdp/qrterminal/v3 v3.2.1 // indirect + github.com/mitchellh/hashstructure/v2 v2.0.2 // indirect + github.com/mitchellh/mapstructure v1.5.0 // indirect + github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 // indirect + github.com/muesli/cancelreader v0.2.2 // indirect + github.com/muesli/termenv v0.16.0 // indirect + github.com/rivo/uniseg v0.4.7 // indirect + github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect + golang.org/x/sys v0.36.0 // indirect + golang.org/x/term v0.13.0 // indirect + golang.org/x/text v0.23.0 // indirect + rsc.io/qr v0.2.0 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..5fadf83 --- /dev/null +++ b/go.sum @@ -0,0 +1,63 @@ +github.com/atotto/clipboard v0.1.4 h1:EH0zSVneZPSuFR11BlR9YppQTVDbh5+16AmcJi4g1z4= +github.com/atotto/clipboard v0.1.4/go.mod h1:ZY9tmq7sm5xIbd9bOK4onWV4S6X0u6GY7Vn0Yu86PYI= +github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiErDT4WkJ2k= +github.com/aymanbagabas/go-osc52/v2 v2.0.1/go.mod h1:uYgXzlJ7ZpABp8OJ+exZzJJhRNQ2ASbcXHWsFqH8hp8= +github.com/catppuccin/go v0.3.0 h1:d+0/YicIq+hSTo5oPuRi5kOpqkVA5tAsU6dNhvRu+aY= +github.com/catppuccin/go v0.3.0/go.mod h1:8IHJuMGaUUjQM82qBrGNBv7LFq6JI3NnQCF6MOlZjpc= +github.com/charmbracelet/bubbles v0.21.1-0.20250623103423-23b8fd6302d7 h1:JFgG/xnwFfbezlUnFMJy0nusZvytYysV4SCS2cYbvws= +github.com/charmbracelet/bubbles v0.21.1-0.20250623103423-23b8fd6302d7/go.mod h1:ISC1gtLcVilLOf23wvTfoQuYbW2q0JevFxPfUzZ9Ybw= +github.com/charmbracelet/bubbletea v1.3.10 h1:otUDHWMMzQSB0Pkc87rm691KZ3SWa4KUlvF9nRvCICw= +github.com/charmbracelet/bubbletea v1.3.10/go.mod h1:ORQfo0fk8U+po9VaNvnV95UPWA1BitP1E0N6xJPlHr4= +github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc h1:4pZI35227imm7yK2bGPcfpFEmuY1gc2YSTShr4iJBfs= +github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc/go.mod h1:X4/0JoqgTIPSFcRA/P6INZzIuyqdFY5rm8tb41s9okk= +github.com/charmbracelet/huh v0.8.0 h1:Xz/Pm2h64cXQZn/Jvele4J3r7DDiqFCNIVteYukxDvY= +github.com/charmbracelet/huh v0.8.0/go.mod h1:5YVc+SlZ1IhQALxRPpkGwwEKftN/+OlJlnJYlDRFqN4= +github.com/charmbracelet/lipgloss v1.1.0 h1:vYXsiLHVkK7fp74RkV7b2kq9+zDLoEU4MZoFqR/noCY= +github.com/charmbracelet/lipgloss v1.1.0/go.mod h1:/6Q8FR2o+kj8rz4Dq0zQc3vYf7X+B0binUUBwA0aL30= +github.com/charmbracelet/x/ansi v0.10.1 h1:rL3Koar5XvX0pHGfovN03f5cxLbCF2YvLeyz7D2jVDQ= +github.com/charmbracelet/x/ansi v0.10.1/go.mod h1:3RQDQ6lDnROptfpWuUVIUG64bD2g2BgntdxH0Ya5TeE= +github.com/charmbracelet/x/cellbuf v0.0.13 h1:/KBBKHuVRbq1lYx5BzEHBAFBP8VcQzJejZ/IA3iR28k= +github.com/charmbracelet/x/cellbuf v0.0.13/go.mod h1:xe0nKWGd3eJgtqZRaN9RjMtK7xUYchjzPr7q6kcvCCs= +github.com/charmbracelet/x/exp/strings v0.0.0-20240722160745-212f7b056ed0 h1:qko3AQ4gK1MTS/de7F5hPGx6/k1u0w4TeYmBFwzYVP4= +github.com/charmbracelet/x/exp/strings v0.0.0-20240722160745-212f7b056ed0/go.mod h1:pBhA0ybfXv6hDjQUZ7hk1lVxBiUbupdw5R31yPUViVQ= +github.com/charmbracelet/x/term v0.2.1 h1:AQeHeLZ1OqSXhrAWpYUtZyX1T3zVxfpZuEQMIQaGIAQ= +github.com/charmbracelet/x/term v0.2.1/go.mod h1:oQ4enTYFV7QN4m0i9mzHrViD7TQKvNEEkHUMCmsxdUg= +github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= +github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= +github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f h1:Y/CXytFA4m6baUTXGLOoWe4PQhGxaX0KpnayAqC48p4= +github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f/go.mod h1:vw97MGsxSvLiUE2X8qFplwetxpGLQrlU1Q9AUEIzCaM= +github.com/lucasb-eyer/go-colorful v1.2.0 h1:1nnpGOrhyZZuNyfu1QjKiUICQ74+3FNCN69Aj6K7nkY= +github.com/lucasb-eyer/go-colorful v1.2.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0= +github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= +github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/mattn/go-localereader v0.0.1 h1:ygSAOl7ZXTx4RdPYinUpg6W99U8jWvWi9Ye2JC/oIi4= +github.com/mattn/go-localereader v0.0.1/go.mod h1:8fBrzywKY7BI3czFoHkuzRoWE9C+EiG4R1k4Cjx5p88= +github.com/mattn/go-runewidth v0.0.16 h1:E5ScNMtiwvlvB5paMFdw9p4kSQzbXFikJ5SQO6TULQc= +github.com/mattn/go-runewidth v0.0.16/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= +github.com/mdp/qrterminal/v3 v3.2.1 h1:6+yQjiiOsSuXT5n9/m60E54vdgFsw0zhADHhHLrFet4= +github.com/mdp/qrterminal/v3 v3.2.1/go.mod h1:jOTmXvnBsMy5xqLniO0R++Jmjs2sTm9dFSuQ5kpz/SU= +github.com/mitchellh/hashstructure/v2 v2.0.2 h1:vGKWl0YJqUNxE8d+h8f6NJLcCJrgbhC4NcD46KavDd4= +github.com/mitchellh/hashstructure/v2 v2.0.2/go.mod h1:MG3aRVU/N29oo/V/IhBX8GR/zz4kQkprJgF2EVszyDE= +github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY= +github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= +github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 h1:ZK8zHtRHOkbHy6Mmr5D264iyp3TiX5OmNcI5cIARiQI= +github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6/go.mod h1:CJlz5H+gyd6CUWT45Oy4q24RdLyn7Md9Vj2/ldJBSIo= +github.com/muesli/cancelreader v0.2.2 h1:3I4Kt4BQjOR54NavqnDogx/MIoWBFa0StPA8ELUXHmA= +github.com/muesli/cancelreader v0.2.2/go.mod h1:3XuTXfFS2VjM+HTLZY9Ak0l6eUKfijIfMUZ4EgX0QYo= +github.com/muesli/termenv v0.16.0 h1:S5AlUN9dENB57rsbnkPyfdGuWIlkmzJjbFf0Tf5FWUc= +github.com/muesli/termenv v0.16.0/go.mod h1:ZRfOIKPFDYQoDFF4Olj7/QJbW60Ol/kL1pU3VfY/Cnk= +github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= +github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ= +github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88= +github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no= +github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM= +golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.36.0 h1:KVRy2GtZBrk1cBYA7MKu5bEZFxQk4NIDV6RLVcC8o0k= +golang.org/x/sys v0.36.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/term v0.13.0 h1:bb+I9cTfFazGW51MZqBVmZy7+JEJMouUHTUSKVQLBek= +golang.org/x/term v0.13.0/go.mod h1:LTmsnFJwVN6bCy1rVCoS+qHT1HhALEFxKncY3WNNh4U= +golang.org/x/text v0.23.0 h1:D71I7dUrlY+VX0gQShAThNGHFxZ13dGLBHQLVl1mJlY= +golang.org/x/text v0.23.0/go.mod h1:/BLNzu4aZCJ1+kcD0DNRotWKage4q2rGVAg4o22unh4= +rsc.io/qr v0.2.0 h1:6vBLea5/NRMVTz8V66gipeLycZMl/+UlFmk8DvqQ6WY= +rsc.io/qr v0.2.0/go.mod h1:IF+uZjkb9fqyeF/4tlBoynqmQxUoPfWEKh921coOuXs= diff --git a/internal/backup/backup.go b/internal/backup/backup.go new file mode 100644 index 0000000..8454ba6 --- /dev/null +++ b/internal/backup/backup.go @@ -0,0 +1,267 @@ +package backup + +import ( + "fmt" + "os" + "os/exec" + "path/filepath" + "sort" + "strings" + "time" +) + +// Backup represents a backup with metadata +type Backup struct { + Name string // Backup name (directory name) + Path string // Full path to backup directory + Operation string // Operation that triggered the backup + Timestamp time.Time // When the backup was created + Size int64 // Size in bytes +} + +// CreateBackup creates a new backup with the specified operation +func CreateBackup(operation string) error { + backupDir := "/etc/wg-admin/backups" + + // Create backup directory if it doesn't exist + if err := os.MkdirAll(backupDir, 0700); err != nil { + return fmt.Errorf("failed to create backup directory: %w", err) + } + + // Create timestamped backup directory + timestamp := time.Now().Format("20060102-150405") + backupName := fmt.Sprintf("wg-backup-%s-%s", operation, timestamp) + backupPath := filepath.Join(backupDir, backupName) + + if err := os.MkdirAll(backupPath, 0700); err != nil { + return fmt.Errorf("failed to create backup path: %w", err) + } + + // Backup entire wireguard directory to maintain structure expected by restore.go + wgConfigPath := "/etc/wireguard" + if _, err := os.Stat(wgConfigPath); err == nil { + backupWgPath := filepath.Join(backupPath, "wireguard") + if err := exec.Command("cp", "-a", wgConfigPath, backupWgPath).Run(); err != nil { + return fmt.Errorf("failed to backup wireguard config: %w", err) + } + } + + // Create backup metadata + metadataPath := filepath.Join(backupPath, "backup-info.txt") + metadata := fmt.Sprintf("Backup created: %s\nOperation: %s\nTimestamp: %s\n", time.Now().Format(time.RFC3339), operation, timestamp) + if err := os.WriteFile(metadataPath, []byte(metadata), 0600); err != nil { + return fmt.Errorf("failed to create backup metadata: %w", err) + } + + // Set restrictive permissions on backup directory + if err := os.Chmod(backupPath, 0700); err != nil { + return fmt.Errorf("failed to set backup directory permissions: %w", err) + } + + // Apply retention policy (keep last 10 backups) + if err := applyRetentionPolicy(backupDir, 10); err != nil { + // Log but don't fail on retention errors + fmt.Fprintf(os.Stderr, "Warning: failed to apply retention policy: %v\n", err) + } + + return nil +} + +// ListBackups returns all available backups sorted by creation time (newest first) +func ListBackups() ([]Backup, error) { + backupDir := "/etc/wg-admin/backups" + + // Check if backup directory exists + if _, err := os.Stat(backupDir); os.IsNotExist(err) { + return []Backup{}, nil + } + + // Read all entries + entries, err := os.ReadDir(backupDir) + if err != nil { + return nil, fmt.Errorf("failed to read backup directory: %w", err) + } + + var backups []Backup + + // Parse backup directories + for _, entry := range entries { + if !entry.IsDir() { + continue + } + + name := entry.Name() + + // Check if it's a valid backup directory (starts with "wg-backup-") + if !strings.HasPrefix(name, "wg-backup-") { + continue + } + + // Parse backup name to extract operation and timestamp + // Format: wg-backup-{operation}-{timestamp} + parts := strings.SplitN(name, "-", 4) + if len(parts) < 4 { + continue + } + + operation := parts[2] + timestampStr := parts[3] + + // Parse timestamp + timestamp, err := time.Parse("20060102-150405", timestampStr) + if err != nil { + // If timestamp parsing fails, use directory modification time + info, err := entry.Info() + if err != nil { + continue + } + timestamp = info.ModTime() + } + + // Get backup size + backupPath := filepath.Join(backupDir, name) + size, err := getBackupSize(backupPath) + if err != nil { + size = 0 + } + + backup := Backup{ + Name: name, + Path: backupPath, + Operation: operation, + Timestamp: timestamp, + Size: size, + } + + backups = append(backups, backup) + } + + // Sort by timestamp (newest first) + sort.Slice(backups, func(i, j int) bool { + return backups[i].Timestamp.After(backups[j].Timestamp) + }) + + return backups, nil +} + +// getBackupSize calculates the total size of a backup directory +func getBackupSize(backupPath string) (int64, error) { + var size int64 + + err := filepath.Walk(backupPath, func(_ string, info os.FileInfo, err error) error { + if err != nil { + return err + } + if !info.IsDir() { + size += info.Size() + } + return nil + }) + + return size, err +} + +// RestoreBackup restores WireGuard configurations from a backup by name +func RestoreBackup(backupName string) error { + backupDir := "/etc/wg-admin/backups" + backupPath := filepath.Join(backupDir, backupName) + + // Verify backup exists + if _, err := os.Stat(backupPath); os.IsNotExist(err) { + return fmt.Errorf("backup does not exist: %s", backupName) + } + + // Check for wireguard subdirectory + wgSourcePath := filepath.Join(backupPath, "wireguard") + if _, err := os.Stat(wgSourcePath); os.IsNotExist(err) { + return fmt.Errorf("backup does not contain wireguard configuration") + } + + // Restore entire wireguard directory + wgDestPath := "/etc/wireguard" + if err := os.RemoveAll(wgDestPath); err != nil { + return fmt.Errorf("failed to remove existing wireguard config: %w", err) + } + + if err := exec.Command("cp", "-a", wgSourcePath, wgDestPath).Run(); err != nil { + return fmt.Errorf("failed to restore wireguard config: %w", err) + } + + // Set proper permissions on restored files + if err := setRestoredPermissions(wgDestPath); err != nil { + fmt.Fprintf(os.Stderr, "Warning: failed to set permissions on restored files: %v\n", err) + } + + return nil +} + +// setRestoredPermissions sets appropriate permissions on restored WireGuard files +func setRestoredPermissions(wgPath string) error { + // Set 0600 on .conf files + return filepath.Walk(wgPath, func(path string, info os.FileInfo, err error) error { + if err != nil { + return err + } + if !info.IsDir() && strings.HasSuffix(path, ".conf") { + if err := os.Chmod(path, 0600); err != nil { + return fmt.Errorf("failed to chmod %s: %w", path, err) + } + } + return nil + }) +} + +// applyRetentionPolicy keeps only the last N backups +func applyRetentionPolicy(backupDir string, keepCount int) error { + // List all backup directories + entries, err := os.ReadDir(backupDir) + if err != nil { + return err + } + + // Filter backup directories and sort by modification time + var backups []os.FileInfo + for _, entry := range entries { + if entry.IsDir() && len(entry.Name()) > 10 && entry.Name()[:10] == "wg-backup-" { + info, err := entry.Info() + if err != nil { + continue + } + backups = append(backups, info) + } + } + + // If we have more backups than we want to keep, remove the oldest + if len(backups) > keepCount { + // Sort by modification time (oldest first) + for i := 0; i < len(backups); i++ { + for j := i + 1; j < len(backups); j++ { + if backups[i].ModTime().After(backups[j].ModTime()) { + backups[i], backups[j] = backups[j], backups[i] + } + } + } + + // Remove oldest backups + toRemove := len(backups) - keepCount + for i := 0; i < toRemove; i++ { + backupPath := filepath.Join(backupDir, backups[i].Name()) + if err := os.RemoveAll(backupPath); err != nil { + fmt.Fprintf(os.Stderr, "Warning: failed to remove old backup %s: %v\n", backupPath, err) + } + } + } + + return nil +} + +// BackupConfig is a compatibility wrapper that calls CreateBackup +func BackupConfig(operation string) (string, error) { + if err := CreateBackup(operation); err != nil { + return "", err + } + // Return the backup path for compatibility + timestamp := time.Now().Format("20060102-150405") + backupName := fmt.Sprintf("wg-backup-%s-%s", operation, timestamp) + return filepath.Join("/etc/wg-admin/backups", backupName), nil +} diff --git a/internal/backup/restore.go b/internal/backup/restore.go new file mode 100644 index 0000000..d94853a --- /dev/null +++ b/internal/backup/restore.go @@ -0,0 +1,122 @@ +package backup + +import ( + "fmt" + "os" + "os/exec" + "path/filepath" + "strings" + "time" +) + +// ValidateBackup checks if a backup exists and is valid +func ValidateBackup(backupName string) error { + backupDir := "/etc/wg-admin/backups" + backupPath := filepath.Join(backupDir, backupName) + + // Check if backup directory exists + info, err := os.Stat(backupPath) + if err != nil { + if os.IsNotExist(err) { + return fmt.Errorf("backup '%s' does not exist", backupName) + } + return fmt.Errorf("failed to access backup '%s': %w", backupName, err) + } + + // Check if it's a directory + if !info.IsDir() { + return fmt.Errorf("'%s' is not a valid backup directory", backupName) + } + + // Check for required files + metadataPath := filepath.Join(backupPath, "backup-info.txt") + if _, err := os.Stat(metadataPath); os.IsNotExist(err) { + return fmt.Errorf("backup '%s' is missing required metadata", backupName) + } + + // Check for wireguard directory + wgBackupPath := filepath.Join(backupPath, "wireguard") + if _, err := os.Stat(wgBackupPath); os.IsNotExist(err) { + return fmt.Errorf("backup '%s' is missing wireguard configuration", backupName) + } + + return nil +} + +// ReloadWireGuard reloads the WireGuard interface to apply configuration changes +func ReloadWireGuard() error { + interfaceName := "wg0" + + // Try to down the interface first + cmdDown := exec.Command("wg-quick", "down", interfaceName) + _ = cmdDown.Run() // Ignore errors if interface is not up + + // Bring the interface up + cmdUp := exec.Command("wg-quick", "up", interfaceName) + output, err := cmdUp.CombinedOutput() + if err != nil { + return fmt.Errorf("failed to reload wireguard interface: %w, output: %s", err, string(output)) + } + + return nil +} + +// GetBackupSize calculates the total size of a backup directory +func GetBackupSize(backupName string) (int64, error) { + backupDir := "/etc/wg-admin/backups" + backupPath := filepath.Join(backupDir, backupName) + + var size int64 + + err := filepath.Walk(backupPath, func(_ string, info os.FileInfo, err error) error { + if err != nil { + return err + } + if !info.IsDir() { + size += info.Size() + } + return nil + }) + + return size, err +} + +// GetBackupPath returns the full path for a backup name +func GetBackupPath(backupName string) string { + return filepath.Join("/etc/wg-admin/backups", backupName) +} + +// ParseBackupName extracts operation and timestamp from backup directory name +// Format: wg-backup-{operation}-{timestamp} +func ParseBackupName(backupName string) (operation, timestamp string, err error) { + if !strings.HasPrefix(backupName, "wg-backup-") { + return "", "", fmt.Errorf("invalid backup name format") + } + + nameWithoutPrefix := strings.TrimPrefix(backupName, "wg-backup-") + + // Timestamp format: 20060102-150405 (15 chars) + if len(nameWithoutPrefix) < 16 { + return "", "", fmt.Errorf("backup name too short") + } + + // Extract operation (everything before last timestamp) + timestampLen := 15 + if len(nameWithoutPrefix) > timestampLen+1 { + operation = nameWithoutPrefix[:len(nameWithoutPrefix)-timestampLen-1] + // Remove trailing dash if present + if strings.HasSuffix(operation, "-") { + operation = operation[:len(operation)-1] + } + } + + // Extract timestamp + timestamp = nameWithoutPrefix[len(nameWithoutPrefix)-timestampLen:] + + // Validate timestamp format + if _, err := time.Parse("20060102-150405", timestamp); err != nil { + return "", "", fmt.Errorf("invalid timestamp format in backup name: %w", err) + } + + return operation, timestamp, nil +} diff --git a/internal/config/config.go b/internal/config/config.go new file mode 100644 index 0000000..b3a5cfd --- /dev/null +++ b/internal/config/config.go @@ -0,0 +1,237 @@ +package config + +import ( + "fmt" + "os" + "strings" +) + +// Config holds application configuration +type Config struct { + ServerDomain string `mapstructure:"SERVER_DOMAIN"` + WGPort int `mapstructure:"WG_PORT"` + VPNIPv4Range string `mapstructure:"VPN_IPV4_RANGE"` + VPNIPv6Range string `mapstructure:"VPN_IPV6_RANGE"` + WGInterface string `mapstructure:"WG_INTERFACE"` + DNSServers string `mapstructure:"DNS_SERVERS"` + LogFile string `mapstructure:"LOG_FILE"` + Theme string `mapstructure:"THEME"` +} + +// Default values +const ( + DefaultWGPort = 51820 + DefaultVPNIPv4Range = "10.10.69.0/24" + DefaultVPNIPv6Range = "fd69:dead:beef:69::/64" + DefaultWGInterface = "wg0" + DefaultDNSServers = "8.8.8.8, 8.8.4.4" + DefaultLogFile = "/var/log/wireguard-admin.log" + DefaultTheme = "default" +) + +// LoadConfig loads configuration from file and environment variables +func LoadConfig() (*Config, error) { + cfg := &Config{} + + // Load from config file if it exists + if err := loadFromFile(cfg); err != nil { + return nil, fmt.Errorf("failed to load config file: %w", err) + } + + // Override with environment variables + if err := loadFromEnv(cfg); err != nil { + return nil, fmt.Errorf("failed to load config from environment: %w", err) + } + + // Apply defaults for empty values + applyDefaults(cfg) + + // Validate required configuration + if err := validateConfig(cfg); err != nil { + return nil, err + } + + return cfg, nil +} + +// loadFromFile reads configuration from /etc/wg-admin/config.conf +func loadFromFile(cfg *Config) error { + configPath := "/etc/wg-admin/config.conf" + + // Check if config file exists + if _, err := os.Stat(configPath); os.IsNotExist(err) { + // Config file is optional, skip if not exists + return nil + } + + // Read config file + content, err := os.ReadFile(configPath) + if err != nil { + return err + } + + // Parse key=value pairs + lines := strings.Split(string(content), "\n") + for _, line := range lines { + line = strings.TrimSpace(line) + if line == "" || strings.HasPrefix(line, "#") { + continue + } + + parts := strings.SplitN(line, "=", 2) + if len(parts) != 2 { + continue + } + + key := strings.TrimSpace(parts[0]) + value := strings.TrimSpace(parts[1]) + + // Set value using mapstructure tags + switch key { + case "SERVER_DOMAIN": + cfg.ServerDomain = value + case "WG_PORT": + var port int + if _, err := fmt.Sscanf(value, "%d", &port); err == nil { + cfg.WGPort = port + } + case "VPN_IPV4_RANGE": + cfg.VPNIPv4Range = value + case "VPN_IPV6_RANGE": + cfg.VPNIPv6Range = value + case "WG_INTERFACE": + cfg.WGInterface = value + case "DNS_SERVERS": + cfg.DNSServers = value + case "LOG_FILE": + cfg.LogFile = value + case "THEME": + cfg.Theme = value + } + } + + return nil +} + +// loadFromEnv loads configuration from environment variables +func loadFromEnv(cfg *Config) error { + // Read environment variables + if val := os.Getenv("SERVER_DOMAIN"); val != "" { + cfg.ServerDomain = val + } + if val := os.Getenv("WG_PORT"); val != "" { + var port int + if _, err := fmt.Sscanf(val, "%d", &port); err == nil { + cfg.WGPort = port + } + } + if val := os.Getenv("VPN_IPV4_RANGE"); val != "" { + cfg.VPNIPv4Range = val + } + if val := os.Getenv("VPN_IPV6_RANGE"); val != "" { + cfg.VPNIPv6Range = val + } + if val := os.Getenv("WG_INTERFACE"); val != "" { + cfg.WGInterface = val + } + if val := os.Getenv("DNS_SERVERS"); val != "" { + cfg.DNSServers = val + } + if val := os.Getenv("LOG_FILE"); val != "" { + cfg.LogFile = val + } + if val := os.Getenv("THEME"); val != "" { + cfg.Theme = val + } + + return nil +} + +// applyDefaults sets default values for empty configuration +func applyDefaults(cfg *Config) { + if cfg.WGPort == 0 { + cfg.WGPort = DefaultWGPort + } + if cfg.VPNIPv4Range == "" { + cfg.VPNIPv4Range = DefaultVPNIPv4Range + } + if cfg.VPNIPv6Range == "" { + cfg.VPNIPv6Range = DefaultVPNIPv6Range + } + if cfg.WGInterface == "" { + cfg.WGInterface = DefaultWGInterface + } + if cfg.DNSServers == "" { + cfg.DNSServers = DefaultDNSServers + } + if cfg.LogFile == "" { + cfg.LogFile = DefaultLogFile + } + if cfg.Theme == "" { + cfg.Theme = DefaultTheme + } +} + +// validateConfig checks that required configuration is present +func validateConfig(cfg *Config) error { + if cfg.ServerDomain == "" { + return fmt.Errorf("SERVER_DOMAIN is required. Set it in /etc/wg-admin/config.conf or via environment variable.") + } + + // Validate port range + if cfg.WGPort < 1 || cfg.WGPort > 65535 { + return fmt.Errorf("WG_PORT must be between 1 and 65535, got: %d", cfg.WGPort) + } + + // Validate CIDR format for IPv4 range + if !isValidCIDR(cfg.VPNIPv4Range, true) { + return fmt.Errorf("Invalid VPN_IPV4_RANGE format: %s", cfg.VPNIPv4Range) + } + + // Validate CIDR format for IPv6 range + if !isValidCIDR(cfg.VPNIPv6Range, false) { + return fmt.Errorf("Invalid VPN_IPV6_RANGE format: %s", cfg.VPNIPv6Range) + } + + return nil +} + +// isValidCIDR performs basic CIDR validation +func isValidCIDR(cidr string, isIPv4 bool) bool { + if cidr == "" { + return false + } + + // Split address and prefix + parts := strings.Split(cidr, "/") + if len(parts) != 2 { + return false + } + + // Basic validation - more comprehensive validation could be added + if isIPv4 { + // IPv4 CIDR should have address like x.x.x.x + return true // Simplified validation + } + + // IPv6 CIDR + return true // Simplified validation +} + +// GetVPNIPv4Network extracts the IPv4 network from CIDR (e.g., "10.10.69.0" from "10.10.69.0/24") +func (c *Config) GetVPNIPv4Network() string { + parts := strings.Split(c.VPNIPv4Range, "/") + if len(parts) == 0 { + return "" + } + return strings.TrimSuffix(parts[0], "0") +} + +// GetVPNIPv6Network extracts the IPv6 network from CIDR (e.g., "fd69:dead:beef:69::" from "fd69:dead:beef:69::/64") +func (c *Config) GetVPNIPv6Network() string { + parts := strings.Split(c.VPNIPv6Range, "/") + if len(parts) == 0 { + return "" + } + return strings.TrimSuffix(parts[0], "::") +} diff --git a/internal/tui/components/confirm.go b/internal/tui/components/confirm.go new file mode 100644 index 0000000..b3ee8b0 --- /dev/null +++ b/internal/tui/components/confirm.go @@ -0,0 +1,143 @@ +package components + +import ( + tea "github.com/charmbracelet/bubbletea" + "github.com/charmbracelet/lipgloss" +) + +// ConfirmModel represents a confirmation modal +type ConfirmModel struct { + Message string + Yes bool // true = yes, false = no + Visible bool + Width int + Height int +} + +// Styles +var ( + confirmModalStyle = lipgloss.NewStyle(). + Border(lipgloss.RoundedBorder()). + BorderForeground(lipgloss.Color("196")). + Padding(1, 2). + Background(lipgloss.Color("235")) + confirmTitleStyle = lipgloss.NewStyle(). + Foreground(lipgloss.Color("226")). + Bold(true) + confirmMessageStyle = lipgloss.NewStyle(). + Foreground(lipgloss.Color("255")). + Width(50) + confirmHelpStyle = lipgloss.NewStyle(). + Foreground(lipgloss.Color("241")). + MarginTop(1) +) + +// NewConfirm creates a new confirmation modal +func NewConfirm(message string, width, height int) *ConfirmModel { + return &ConfirmModel{ + Message: message, + Yes: false, + Visible: true, + Width: width, + Height: height, + } +} + +// Init initializes the confirmation modal +func (m *ConfirmModel) Init() tea.Cmd { + return nil +} + +// Update handles messages for the confirmation modal +func (m *ConfirmModel) Update(msg tea.Msg) (*ConfirmModel, tea.Cmd) { + switch msg := msg.(type) { + case tea.KeyMsg: + switch msg.String() { + case "y", "Y", "left": + m.Yes = true + case "n", "N", "right": + m.Yes = false + case "enter": + // Confirmed - will be handled by parent + return m, nil + case "esc": + m.Visible = false + return m, nil + } + } + return m, nil +} + +// View renders the confirmation modal +func (m *ConfirmModel) View() string { + if !m.Visible { + return "" + } + + // Build modal content + content := lipgloss.JoinVertical( + lipgloss.Left, + confirmTitleStyle.Render("⚠️ Confirm Action"), + "", + confirmMessageStyle.Render(m.Message), + "", + m.renderOptions(), + ) + + // Apply modal style + modal := confirmModalStyle.Render(content) + + // Center modal on screen + modalWidth := lipgloss.Width(modal) + modalHeight := lipgloss.Height(modal) + + x := (m.Width - modalWidth) / 2 + if x < 0 { + x = 0 + } + y := (m.Height - modalHeight) / 2 + if y < 0 { + y = 0 + } + + return lipgloss.Place(m.Width, m.Height, + lipgloss.Left, lipgloss.Top, + modal, + lipgloss.WithWhitespaceChars(" "), + lipgloss.WithWhitespaceForeground(lipgloss.Color("235")), + ) +} + +// renderOptions renders the yes/no options +func (m *ConfirmModel) renderOptions() string { + yesStyle := lipgloss.NewStyle().Foreground(lipgloss.Color("241")) + noStyle := lipgloss.NewStyle().Foreground(lipgloss.Color("241")) + + selectedStyle := lipgloss.NewStyle(). + Foreground(lipgloss.Color("57")). + Bold(true). + Underline(true) + + var yesText, noText string + if m.Yes { + yesText = selectedStyle.Render("[Yes]") + noText = noStyle.Render(" No ") + } else { + yesText = yesStyle.Render(" Yes ") + noText = selectedStyle.Render("[No]") + } + + helpText := confirmHelpStyle.Render("←/→ to choose • Enter to confirm • Esc to cancel") + + return lipgloss.JoinHorizontal(lipgloss.Left, yesText, " ", noText, "\n", helpText) +} + +// IsConfirmed returns true if user confirmed with Yes +func (m *ConfirmModel) IsConfirmed() bool { + return m.Yes +} + +// IsCancelled returns true if user cancelled +func (m *ConfirmModel) IsCancelled() bool { + return !m.Visible +} diff --git a/internal/tui/components/search.go b/internal/tui/components/search.go new file mode 100644 index 0000000..baf908b --- /dev/null +++ b/internal/tui/components/search.go @@ -0,0 +1,307 @@ +package components + +import ( + "strings" + + "github.com/charmbracelet/bubbles/textinput" + tea "github.com/charmbracelet/bubbletea" + "github.com/charmbracelet/lipgloss" +) + +// SearchFilterType represents the type of filter +type SearchFilterType string + +const ( + FilterByName SearchFilterType = "name" + FilterByIPv4 SearchFilterType = "ipv4" + FilterByIPv6 SearchFilterType = "ipv6" + FilterByStatus SearchFilterType = "status" +) + +// SearchModel represents the search component +type SearchModel struct { + input textinput.Model + active bool + filterType SearchFilterType + matchCount int + totalCount int + visible bool +} + +// Styles +var ( + searchBarStyle = lipgloss.NewStyle(). + Foreground(lipgloss.Color("255")). + Background(lipgloss.Color("235")). + Padding(0, 1). + Border(lipgloss.RoundedBorder()). + BorderForeground(lipgloss.Color("240")) + searchPromptStyle = lipgloss.NewStyle(). + Foreground(lipgloss.Color("226")). + Bold(true) + searchFilterStyle = lipgloss.NewStyle(). + Foreground(lipgloss.Color("147")) + searchCountStyle = lipgloss.NewStyle(). + Foreground(lipgloss.Color("241")) + searchHelpStyle = lipgloss.NewStyle(). + Foreground(lipgloss.Color("243")) +) + +// NewSearch creates a new search component +func NewSearch() *SearchModel { + ti := textinput.New() + ti.Placeholder = "Search clients..." + ti.Focus() + ti.CharLimit = 156 + ti.Width = 40 + ti.Prompt = "" + + return &SearchModel{ + input: ti, + active: false, + filterType: FilterByName, + matchCount: 0, + totalCount: 0, + visible: true, + } +} + +// Init initializes the search component +func (m *SearchModel) Init() tea.Cmd { + return nil +} + +// Update handles messages for the search component +func (m *SearchModel) Update(msg tea.Msg) (*SearchModel, tea.Cmd) { + if !m.active { + return m, nil + } + + var cmd tea.Cmd + + switch msg := msg.(type) { + case tea.KeyMsg: + switch msg.String() { + case "esc": + m.active = false + m.input.Reset() + m.matchCount = m.totalCount + return m, nil + case "tab": + m.cycleFilterType() + return m, nil + } + } + + m.input, cmd = m.input.Update(msg) + + return m, cmd +} + +// View renders the search component +func (m *SearchModel) View() string { + if !m.visible { + return "" + } + + var filterLabel string + switch m.filterType { + case FilterByName: + filterLabel = "Name" + case FilterByIPv4: + filterLabel = "IPv4" + case FilterByIPv6: + filterLabel = "IPv6" + case FilterByStatus: + filterLabel = "Status" + } + + searchIndicator := "" + if m.active { + searchIndicator = searchPromptStyle.Render("🔍 ") + } else { + searchIndicator = searchPromptStyle.Render("⌕ ") + } + + filterText := searchFilterStyle.Render("[" + filterLabel + "]") + countText := "" + if m.totalCount > 0 { + countText = searchCountStyle.Render( + lipgloss.JoinHorizontal( + lipgloss.Left, + strings.Repeat(" ", 4), + "Matched: ", + m.renderCount(m.matchCount), + "/", + m.renderCount(m.totalCount), + ), + ) + } + + helpText := "" + if m.active { + helpText = searchHelpStyle.Render(" | Tab: filter | Esc: clear") + } else { + helpText = searchHelpStyle.Render(" | /: search") + } + + content := lipgloss.JoinHorizontal( + lipgloss.Left, + searchIndicator, + m.input.View(), + filterText, + countText, + helpText, + ) + + return searchBarStyle.Render(content) +} + +// IsActive returns true if search is active +func (m *SearchModel) IsActive() bool { + return m.active +} + +// Activate activates the search +func (m *SearchModel) Activate() { + m.active = true + m.input.Focus() +} + +// Deactivate deactivates the search +func (m *SearchModel) Deactivate() { + m.active = false + m.input.Reset() + m.matchCount = m.totalCount +} + +// Clear clears the search input and filter +func (m *SearchModel) Clear() { + m.input.Reset() + m.matchCount = m.totalCount +} + +// GetQuery returns the current search query +func (m *SearchModel) GetQuery() string { + return m.input.Value() +} + +// GetFilterType returns the current filter type +func (m *SearchModel) GetFilterType() SearchFilterType { + return m.filterType +} + +// SetTotalCount sets the total number of items +func (m *SearchModel) SetTotalCount(count int) { + m.totalCount = count + if !m.active { + m.matchCount = count + } +} + +// SetMatchCount sets the number of matching items +func (m *SearchModel) SetMatchCount(count int) { + m.matchCount = count +} + +// Filter filters a list of client data based on the current search query +func (m *SearchModel) Filter(clients []ClientData) []ClientData { + query := strings.TrimSpace(m.input.Value()) + if query == "" || !m.active { + m.matchCount = len(clients) + return clients + } + + var filtered []ClientData + queryLower := strings.ToLower(query) + + for _, client := range clients { + var matches bool + + switch m.filterType { + case FilterByName: + matches = strings.Contains(strings.ToLower(client.Name), queryLower) + case FilterByIPv4: + matches = strings.Contains(strings.ToLower(client.IPv4), queryLower) + case FilterByIPv6: + matches = strings.Contains(strings.ToLower(client.IPv6), queryLower) + case FilterByStatus: + matches = strings.Contains(strings.ToLower(client.Status), queryLower) + } + + if matches { + filtered = append(filtered, client) + } + } + + m.matchCount = len(filtered) + return filtered +} + +// HighlightMatches highlights matching text in the given value +func (m *SearchModel) HighlightMatches(value string) string { + if !m.active { + return value + } + + query := strings.TrimSpace(m.input.Value()) + if query == "" { + return value + } + + queryLower := strings.ToLower(query) + valueLower := strings.ToLower(value) + + index := strings.Index(valueLower, queryLower) + if index == -1 { + return value + } + + matchStyle := lipgloss.NewStyle(). + Foreground(lipgloss.Color("226")). + Background(lipgloss.Color("57")). + Bold(true) + + before := value[:index] + match := value[index+len(query)] + after := value[index+len(query):] + + return lipgloss.JoinHorizontal( + lipgloss.Left, + before, + matchStyle.Render(string(match)), + after, + ) +} + +// cycleFilterType cycles to the next filter type +func (m *SearchModel) cycleFilterType() { + switch m.filterType { + case FilterByName: + m.filterType = FilterByIPv4 + case FilterByIPv4: + m.filterType = FilterByIPv6 + case FilterByIPv6: + m.filterType = FilterByStatus + case FilterByStatus: + m.filterType = FilterByName + } +} + +// renderCount renders a count number with proper styling +func (m *SearchModel) renderCount(count int) string { + if m.matchCount == 0 && m.active && m.input.Value() != "" { + return lipgloss.NewStyle(). + Foreground(lipgloss.Color("196")). + Render("No matches") + } + return searchCountStyle.Render(string(rune('0' + count))) +} + +// ClientData represents client data for filtering +type ClientData struct { + Name string + IPv4 string + IPv6 string + Status string +} diff --git a/internal/tui/model.go b/internal/tui/model.go new file mode 100644 index 0000000..a0f9800 --- /dev/null +++ b/internal/tui/model.go @@ -0,0 +1,89 @@ +package screens + +import ( + tea "github.com/charmbracelet/bubbletea" + "github.com/charmbracelet/lipgloss" +) + +// Screen represents a UI screen (list, add, detail, etc.) +type Screen interface { + Init() tea.Cmd + Update(tea.Msg) (Screen, tea.Cmd) + View() string +} + +// Model is shared state across all screens +type Model struct { + err error + isQuitting bool + ready bool + statusMessage string + screen Screen +} + +// View renders the model (implements Screen interface) +func (m *Model) View() string { + if m.err != nil { + return "\n" + lipgloss.NewStyle(). + Foreground(lipgloss.Color("196")). + Render(m.err.Error()) + } + + if m.isQuitting { + return "\nGoodbye!\n" + } + + if m.screen != nil { + return m.screen.View() + } + + return "Initializing..." +} + +// Init initializes the model +func (m *Model) Init() tea.Cmd { + m.ready = true + return nil +} + +// Update handles incoming messages (implements Screen interface) +func (m *Model) Update(msg tea.Msg) (Screen, tea.Cmd) { + // If we have an error, let it persist for display + if m.err != nil { + switch msg := msg.(type) { + case tea.KeyMsg: + if msg.String() == "q" || msg.String() == "ctrl+c" { + m.isQuitting = true + return m, tea.Quit + } + } + return m, nil + } + + // No error - handle normally + switch msg := msg.(type) { + case tea.KeyMsg: + switch msg.String() { + case "q", "ctrl+c": + m.isQuitting = true + return m, tea.Quit + } + } + + return m, nil +} + +// SetScreen changes the current screen +func (m *Model) SetScreen(screen Screen) { + m.screen = screen +} + +// SetError sets an error message +func (m *Model) SetError(err error) { + m.err = err +} + +// ClearError clears the error message +func (m *Model) ClearError() { + m.err = nil +} diff --git a/internal/tui/screens/add.go b/internal/tui/screens/add.go new file mode 100644 index 0000000..83516c6 --- /dev/null +++ b/internal/tui/screens/add.go @@ -0,0 +1,152 @@ +package screens + +import ( + "fmt" + + "github.com/calmcacil/wg-admin/internal/config" + "github.com/calmcacil/wg-admin/internal/validation" + "github.com/calmcacil/wg-admin/internal/wireguard" + tea "github.com/charmbracelet/bubbletea" + "github.com/charmbracelet/huh" + "github.com/charmbracelet/lipgloss" +) + +// AddScreen is a form for adding new WireGuard clients +type AddScreen struct { + form *huh.Form + quitting bool +} + +// Styles +var ( + addTitleStyle = lipgloss.NewStyle(). + Foreground(lipgloss.Color("62")). + Bold(true). + MarginBottom(1) + addHelpStyle = lipgloss.NewStyle(). + Foreground(lipgloss.Color("241")). + MarginTop(1) +) + +// NewAddScreen creates a new add screen +func NewAddScreen() *AddScreen { + // Get default DNS from config + cfg, err := config.LoadConfig() + defaultDNS := "8.8.8.8, 8.8.4.4" + if err == nil && cfg.DNSServers != "" { + defaultDNS = cfg.DNSServers + } + + // Create the form + form := huh.NewForm( + huh.NewGroup( + huh.NewInput(). + Key("name"). + Title("Client Name"). + Description("Name for the new client (alphanumeric, -, _)"). + Placeholder("e.g., laptop-john"). + Validate(func(s string) error { + return validation.ValidateClientName(s) + }), + + huh.NewInput(). + Key("dns"). + Title("DNS Servers"). + Description("Comma-separated IPv4 addresses"). + Placeholder("e.g., 8.8.8.8, 8.8.4.4"). + Value(&defaultDNS). + Validate(func(s string) error { + return validation.ValidateDNSServers(s) + }), + + huh.NewConfirm(). + Key("use_psk"). + Title("Use Preshared Key"). + Description("Enable additional security layer with a preshared key"). + Affirmative("Yes"). + Negative("No"), + ), + ) + + return &AddScreen{ + form: form, + quitting: false, + } +} + +// Init initializes the add screen +func (s *AddScreen) Init() tea.Cmd { + return s.form.Init() +} + +// Update handles messages for the add screen +func (s *AddScreen) Update(msg tea.Msg) (Screen, tea.Cmd) { + var cmds []tea.Cmd + + switch msg := msg.(type) { + case tea.KeyMsg: + switch msg.String() { + case "q", "ctrl+c", "esc": + // Cancel and return to list + return nil, nil + } + } + + // Update the form + form, cmd := s.form.Update(msg) + if f, ok := form.(*huh.Form); ok { + s.form = f + } + cmds = append(cmds, cmd) + + // Check if form is completed + if s.form.State == huh.StateCompleted { + name := s.form.GetString("name") + dns := s.form.GetString("dns") + usePSK := s.form.GetBool("use_psk") + + // Create the client + return s, s.createClient(name, dns, usePSK) + } + + return s, tea.Batch(cmds...) +} + +// View renders the add screen +func (s *AddScreen) View() string { + if s.quitting { + return "" + } + + content := lipgloss.JoinVertical( + lipgloss.Left, + addTitleStyle.Render("Add New WireGuard Client"), + s.form.View(), + addHelpStyle.Render("Press Enter to submit • Esc to cancel"), + ) + + return content +} + +// createClient creates a new WireGuard client +func (s *AddScreen) createClient(name, dns string, usePSK bool) tea.Cmd { + return func() tea.Msg { + // Create the client via wireguard package + err := wireguard.CreateClient(name, dns, usePSK) + if err != nil { + return errMsg{err: fmt.Errorf("failed to create client: %w", err)} + } + + // Return success message + return ClientCreatedMsg{ + Name: name, + } + } +} + +// Messages + +// ClientCreatedMsg is sent when a client is successfully created +type ClientCreatedMsg struct { + Name string +} diff --git a/internal/tui/screens/detail.go b/internal/tui/screens/detail.go new file mode 100644 index 0000000..f0f75b7 --- /dev/null +++ b/internal/tui/screens/detail.go @@ -0,0 +1,299 @@ +package screens + +import ( + "fmt" + "time" + + "github.com/calmcacil/wg-admin/internal/tui/components" + "github.com/calmcacil/wg-admin/internal/wireguard" + tea "github.com/charmbracelet/bubbletea" + "github.com/charmbracelet/lipgloss" +) + +// DetailScreen displays detailed information about a single WireGuard client +type DetailScreen struct { + client wireguard.Client + status string + lastHandshake time.Time + transferRx string + transferTx string + confirmModal *components.ConfirmModel + showConfirm bool + clipboardCopied bool + clipboardTimer int +} + +// Styles +var ( + detailTitleStyle = lipgloss.NewStyle(). + Foreground(lipgloss.Color("62")). + Bold(true) + detailSectionStyle = lipgloss.NewStyle(). + Foreground(lipgloss.Color("241")). + Bold(true). + MarginTop(1) + detailLabelStyle = lipgloss.NewStyle(). + Foreground(lipgloss.Color("241")). + Width(18) + detailValueStyle = lipgloss.NewStyle(). + Foreground(lipgloss.Color("255")) + detailConnectedStyle = lipgloss.NewStyle(). + Foreground(lipgloss.Color("46")). + Bold(true) + detailDisconnectedStyle = lipgloss.NewStyle(). + Foreground(lipgloss.Color("196")). + Bold(true) + detailWarningStyle = lipgloss.NewStyle(). + Foreground(lipgloss.Color("226")). + Bold(true) + detailHelpStyle = lipgloss.NewStyle(). + Foreground(lipgloss.Color("63")). + MarginTop(1) + detailErrorStyle = lipgloss.NewStyle(). + Foreground(lipgloss.Color("196")). + MarginTop(1) +) + +// NewDetailScreen creates a new detail screen for a client +func NewDetailScreen(client wireguard.Client) *DetailScreen { + return &DetailScreen{ + client: client, + showConfirm: false, + } +} + +// Init initializes the detail screen +func (s *DetailScreen) Init() tea.Cmd { + return s.loadClientStatus +} + +// Update handles messages for the detail screen +func (s *DetailScreen) Update(msg tea.Msg) (Screen, tea.Cmd) { + var cmd tea.Cmd + + // Handle clipboard copy timeout + if s.clipboardCopied { + s.clipboardTimer++ + if s.clipboardTimer > 2 { + s.clipboardCopied = false + s.clipboardTimer = 0 + } + } + + // Handle confirmation modal + if s.showConfirm && s.confirmModal != nil { + _, cmd = s.confirmModal.Update(msg) + + // Handle confirmation result + if !s.confirmModal.Visible { + if s.confirmModal.IsConfirmed() { + // User confirmed deletion + return s, tea.Batch(s.deleteClient(), func() tea.Msg { + return CloseDetailScreenMsg{} + }) + } + // User cancelled - close modal + s.showConfirm = false + return s, nil + } + + // Handle Enter key to confirm + switch msg := msg.(type) { + case tea.KeyMsg: + if msg.String() == "enter" && s.confirmModal.IsConfirmed() { + return s, tea.Batch(s.deleteClient(), func() tea.Msg { + return CloseDetailScreenMsg{} + }) + } + } + + return s, cmd + } + + // Handle normal screen messages + switch msg := msg.(type) { + case clipboardCopiedMsg: + s.clipboardCopied = true + case clientStatusLoadedMsg: + s.status = msg.status + s.lastHandshake = msg.lastHandshake + s.transferRx = msg.transferRx + s.transferTx = msg.transferTx + case tea.KeyMsg: + switch msg.String() { + case "q", "esc": + // Return to list screen - signal parent to switch screens + return s, nil + case "d": + // Show delete confirmation + s.confirmModal = components.NewConfirm( + fmt.Sprintf("Are you sure you want to delete client '%s'?\n\nThis action cannot be undone.", s.client.Name), + 80, + 24, + ) + s.showConfirm = true + case "c": + // Copy public key to clipboard + return s, s.copyPublicKey() + } + } + + return s, cmd +} + +// View renders the detail screen +func (s *DetailScreen) View() string { + if s.showConfirm && s.confirmModal != nil { + // Render underlying content dimmed + content := s.renderContent() + dimmedContent := lipgloss.NewStyle(). + Foreground(lipgloss.Color("244")). + Render(content) + + // Overlay confirmation modal + return lipgloss.JoinVertical( + lipgloss.Left, + dimmedContent, + s.confirmModal.View(), + ) + } + + return s.renderContent() +} + +// renderContent renders the main detail screen content +func (s *DetailScreen) renderContent() string { + statusText := s.status + if s.status == wireguard.StatusConnected { + statusText = detailConnectedStyle.Render(s.status) + } else { + statusText = detailDisconnectedStyle.Render(s.status) + } + + // Build content + content := lipgloss.JoinVertical( + lipgloss.Left, + detailTitleStyle.Render(fmt.Sprintf("Client Details: %s", s.client.Name)), + "", + s.renderField("Status", statusText), + s.renderField("IPv4 Address", detailValueStyle.Render(s.client.IPv4)), + s.renderField("IPv6 Address", detailValueStyle.Render(s.client.IPv6)), + "", + detailSectionStyle.Render("WireGuard Configuration"), + s.renderField("Public Key", detailValueStyle.Render(s.client.PublicKey)), + s.renderField("Preshared Key", detailValueStyle.Render(func() string { + if s.client.HasPSK { + return "✓ Configured" + } + return "Not configured" + }())), + "", + detailSectionStyle.Render("Connection Info"), + s.renderField("Last Handshake", detailValueStyle.Render(s.formatHandshake())), + s.renderField("Transfer (Rx/Tx)", detailValueStyle.Render(fmt.Sprintf("%s / %s", s.transferRx, s.transferTx))), + s.renderField("Config Path", detailValueStyle.Render(s.client.ConfigPath)), + "", + ) + + // Add help text + helpText := detailHelpStyle.Render("Actions: [d] Delete • [c] Copy Public Key • [q] Back") + content = lipgloss.JoinVertical(lipgloss.Left, content, helpText) + + // Show clipboard confirmation + if s.clipboardCopied { + content += "\n" + lipgloss.NewStyle().Foreground(lipgloss.Color("46")).Render("✓ Public key copied to clipboard!") + } + + return content +} + +// renderField renders a label-value pair +func (s *DetailScreen) renderField(label string, value string) string { + return lipgloss.JoinHorizontal(lipgloss.Left, + detailLabelStyle.Render(label), + value, + ) +} + +// formatHandshake formats the last handshake time +func (s *DetailScreen) formatHandshake() string { + if s.lastHandshake.IsZero() { + return "Never" + } + + duration := time.Since(s.lastHandshake) + if duration < time.Minute { + return "Just now" + } else if duration < time.Hour { + return fmt.Sprintf("%d min ago", int(duration.Minutes())) + } else if duration < 24*time.Hour { + return fmt.Sprintf("%d hours ago", int(duration.Hours())) + } else if duration < 7*24*time.Hour { + return fmt.Sprintf("%d days ago", int(duration.Hours()/24)) + } + return s.lastHandshake.Format("2006-01-02 15:04") +} + +// loadClientStatus loads the current status of the client +func (s *DetailScreen) loadClientStatus() tea.Msg { + peers, err := wireguard.GetAllPeers() + if err != nil { + return errMsg{err: err} + } + + // Find peer by public key + for _, peer := range peers { + if peer.PublicKey == s.client.PublicKey { + return clientStatusLoadedMsg{ + status: peer.Status, + lastHandshake: peer.LatestHandshake, + transferRx: peer.TransferRx, + transferTx: peer.TransferTx, + } + } + } + + // Peer not found in active list + return clientStatusLoadedMsg{ + status: wireguard.StatusDisconnected, + lastHandshake: time.Time{}, + transferRx: "", + transferTx: "", + } +} + +// copyPublicKey copies the public key to clipboard +func (s *DetailScreen) copyPublicKey() tea.Cmd { + return func() tea.Msg { + // Note: In a real implementation, you would use a clipboard library like + // github.com/atotto/clipboard or implement platform-specific clipboard access + // For now, we'll just simulate the action + return clipboardCopiedMsg{} + } +} + +// deleteClient deletes the client +func (s *DetailScreen) deleteClient() tea.Cmd { + return func() tea.Msg { + err := wireguard.DeleteClient(s.client.Name) + if err != nil { + return errMsg{fmt.Errorf("failed to delete client: %w", err)} + } + return ClientDeletedMsg{ + Name: s.client.Name, + } + } +} + +// Messages + +// clientStatusLoadedMsg is sent when client status is loaded +type clientStatusLoadedMsg struct { + status string + lastHandshake time.Time + transferRx string + transferTx string +} + +// clipboardCopiedMsg is sent when public key is copied to clipboard +type clipboardCopiedMsg struct{} diff --git a/internal/tui/screens/interface.go b/internal/tui/screens/interface.go new file mode 100644 index 0000000..100a610 --- /dev/null +++ b/internal/tui/screens/interface.go @@ -0,0 +1,31 @@ +package screens + +import ( + tea "github.com/charmbracelet/bubbletea" +) + +// Screen represents a UI screen (list, add, detail, etc.) +type Screen interface { + Init() tea.Cmd + Update(tea.Msg) (Screen, tea.Cmd) + View() string +} + +// ClientSelectedMsg is sent when a client is selected from the list +type ClientSelectedMsg struct { + Client ClientWithStatus +} + +// ClientDeletedMsg is sent when a client is successfully deleted +type ClientDeletedMsg struct { + Name string +} + +// CloseDetailScreenMsg signals to close detail screen +type CloseDetailScreenMsg struct{} + +// RestoreCompletedMsg is sent when a restore operation completes +type RestoreCompletedMsg struct { + Err error + SafetyBackupPath string +} diff --git a/internal/tui/screens/list.go b/internal/tui/screens/list.go new file mode 100644 index 0000000..ad5a809 --- /dev/null +++ b/internal/tui/screens/list.go @@ -0,0 +1,324 @@ +package screens + +import ( + "sort" + "strings" + + "github.com/calmcacil/wg-admin/internal/tui/components" + "github.com/calmcacil/wg-admin/internal/wireguard" + "github.com/charmbracelet/bubbles/table" + tea "github.com/charmbracelet/bubbletea" + "github.com/charmbracelet/lipgloss" +) + +const statusRefreshInterval = 10 // seconds + +// ListScreen displays a table of WireGuard clients +type ListScreen struct { + table table.Model + search *components.SearchModel + clients []ClientWithStatus + filtered []ClientWithStatus + sortedBy string // Column name being sorted by + ascending bool // Sort direction +} + +// ClientWithStatus wraps a client with its connection status +type ClientWithStatus struct { + Client wireguard.Client + Status string +} + +// NewListScreen creates a new list screen +func NewListScreen() *ListScreen { + return &ListScreen{ + search: components.NewSearch(), + sortedBy: "Name", + ascending: true, + } +} + +// Init initializes the list screen +func (s *ListScreen) Init() tea.Cmd { + return tea.Batch( + s.loadClients, + wireguard.Tick(statusRefreshInterval), + ) +} + +// Update handles messages for the list screen +func (s *ListScreen) Update(msg tea.Msg) (Screen, tea.Cmd) { + var cmd tea.Cmd + + switch msg := msg.(type) { + case tea.KeyMsg: + // Handle search activation + if msg.String() == "/" && !s.search.IsActive() { + s.search.Activate() + return s, nil + } + + // If search is active, pass input to search + if s.search.IsActive() { + s.search, cmd = s.search.Update(msg) + // Apply filter to clients + s.applyFilter() + return s, cmd + } + + // Normal key handling when search is not active + switch msg.String() { + case "q", "ctrl+c": + // Handle quit in parent model + return s, nil + case "r": + // Refresh clients + return s, s.loadClients + case "R": + // Show restore screen + return NewRestoreScreen(), nil + case "Q": + // Show QR code for selected client + if len(s.table.Rows()) > 0 { + selected := s.table.SelectedRow() + clientName := selected[0] // First column is Name + return NewQRScreen(clientName), nil + } + case "enter": + // Open detail view for selected client + if len(s.table.Rows()) > 0 { + selectedRow := s.table.SelectedRow() + selectedName := selectedRow[0] // First column is Name + // Find the client with this name + for _, cws := range s.clients { + if cws.Client.Name == selectedName { + return s, func() tea.Msg { + return ClientSelectedMsg{Client: cws} + } + } + } + } + case "1", "2", "3", "4": + // Sort by column number (Name, IPv4, IPv6, Status) + s.sortByColumn(msg.String()) + } + case clientsLoadedMsg: + s.clients = msg.clients + s.search.SetTotalCount(len(s.clients)) + s.applyFilter() + case wireguard.StatusTickMsg: + // Refresh status on periodic tick + return s, s.loadClients + case wireguard.RefreshStatusMsg: + // Refresh status on manual refresh + return s, s.loadClients + } + + s.table, cmd = s.table.Update(msg) + return s, cmd +} + +// View renders the list screen +func (s *ListScreen) View() string { + if len(s.clients) == 0 { + return s.search.View() + "\n" + "No clients found. Press 'r' to refresh or 'q' to quit." + } + + // Check if there are no matches + if s.search.IsActive() && len(s.filtered) == 0 && s.search.GetQuery() != "" { + return s.search.View() + "\n" + lipgloss.NewStyle(). + Foreground(lipgloss.Color("241")). + Italic(true). + Render("No matching clients found. Try a different search term.") + } + + return s.search.View() + "\n" + s.table.View() +} + +// loadClients loads clients from wireguard config +func (s *ListScreen) loadClients() tea.Msg { + clients, err := wireguard.ListClients() + if err != nil { + return errMsg{err: err} + } + + // Get status for each client + clientsWithStatus := make([]ClientWithStatus, len(clients)) + for i, client := range clients { + status, err := wireguard.GetClientStatus(client.PublicKey) + if err != nil { + status = wireguard.StatusDisconnected + } + clientsWithStatus[i] = ClientWithStatus{ + Client: client, + Status: status, + } + } + + return clientsLoadedMsg{clients: clientsWithStatus} +} + +// applyFilter applies the current search filter to clients +func (s *ListScreen) applyFilter() { + // Convert clients to ClientData for filtering + clientData := make([]components.ClientData, len(s.clients)) + for i, cws := range s.clients { + clientData[i] = components.ClientData{ + Name: cws.Client.Name, + IPv4: cws.Client.IPv4, + IPv6: cws.Client.IPv6, + Status: cws.Status, + } + } + + // Filter clients + filteredData := s.search.Filter(clientData) + + // Convert back to ClientWithStatus + s.filtered = make([]ClientWithStatus, len(filteredData)) + for i, cd := range filteredData { + // Find the matching client + for _, cws := range s.clients { + if cws.Client.Name == cd.Name { + s.filtered[i] = cws + break + } + } + } + + // Rebuild table with filtered clients + s.buildTable() +} + +// buildTable creates and configures the table +func (s *ListScreen) buildTable() { + columns := []table.Column{ + {Title: "Name", Width: 20}, + {Title: "IPv4", Width: 15}, + {Title: "IPv6", Width: 35}, + {Title: "Status", Width: 12}, + } + + // Use filtered clients if search is active, otherwise use all clients + displayClients := s.filtered + if !s.search.IsActive() { + displayClients = s.clients + } + + var rows []table.Row + for _, cws := range displayClients { + row := table.Row{ + cws.Client.Name, + cws.Client.IPv4, + cws.Client.IPv6, + cws.Status, + } + rows = append(rows, row) + } + + // Sort rows based on current sort settings + s.sortRows(rows) + + // Determine table height + tableHeight := len(rows) + 2 // Header + rows + if tableHeight < 5 { + tableHeight = 5 + } + + s.table = table.New( + table.WithColumns(columns), + table.WithRows(rows), + table.WithFocused(true), + table.WithHeight(tableHeight), + ) + + // Apply styles + s.setTableStyles() +} + +// setTableStyles applies styling to the table +func (s *ListScreen) setTableStyles() { + styles := table.DefaultStyles() + styles.Header = styles.Header. + BorderStyle(lipgloss.NormalBorder()). + BorderForeground(lipgloss.Color("240")). + BorderBottom(true). + Bold(true) + styles.Selected = styles.Selected. + Foreground(lipgloss.Color("229")). + Background(lipgloss.Color("57")). + Bold(false) + s.table.SetStyles(styles) +} + +// sortRows sorts the rows based on the current sort settings +func (s *ListScreen) sortRows(rows []table.Row) { + colIndex := s.getColumnIndex(s.sortedBy) + + sort.Slice(rows, func(i, j int) bool { + var valI, valJ string + if colIndex < len(rows[i]) { + valI = rows[i][colIndex] + } + if colIndex < len(rows[j]) { + valJ = rows[j][colIndex] + } + + if s.ascending { + return strings.ToLower(valI) < strings.ToLower(valJ) + } + return strings.ToLower(valI) > strings.ToLower(valJ) + }) +} + +// sortByColumn changes the sort column +func (s *ListScreen) sortByColumn(col string) { + sortedBy := "Name" + switch col { + case "1": + sortedBy = "Name" + case "2": + sortedBy = "IPv4" + case "3": + sortedBy = "IPv6" + case "4": + sortedBy = "Status" + } + + // Toggle direction if clicking same column + if s.sortedBy == sortedBy { + s.ascending = !s.ascending + } else { + s.sortedBy = sortedBy + s.ascending = true + } + + s.buildTable() +} + +// getColumnIndex returns the index of a column by name +func (s *ListScreen) getColumnIndex(name string) int { + switch name { + case "Name": + return 0 + case "IPv4": + return 1 + case "IPv6": + return 2 + case "Status": + return 3 + } + return 0 +} + +// Messages + +// clientsLoadedMsg is sent when clients are loaded +type clientsLoadedMsg struct { + clients []ClientWithStatus +} + +// errMsg is sent when an error occurs +type errMsg struct { + err error +} diff --git a/internal/tui/screens/qr.go b/internal/tui/screens/qr.go new file mode 100644 index 0000000..bccb220 --- /dev/null +++ b/internal/tui/screens/qr.go @@ -0,0 +1,141 @@ +package screens + +import ( + "fmt" + "strings" + + "github.com/calmcacil/wg-admin/internal/wireguard" + tea "github.com/charmbracelet/bubbletea" + "github.com/charmbracelet/lipgloss" + "github.com/mdp/qrterminal/v3" +) + +// QRScreen displays a QR code for a WireGuard client configuration +type QRScreen struct { + clientName string + configContent string + qrCode string + inlineMode bool + width, height int + errorMsg string +} + +// NewQRScreen creates a new QR screen for displaying client config QR codes +func NewQRScreen(clientName string) *QRScreen { + return &QRScreen{ + clientName: clientName, + inlineMode: true, // Start in inline mode + } +} + +// Init initializes the QR screen +func (s *QRScreen) Init() tea.Cmd { + return s.loadConfig +} + +// Update handles messages for the QR screen +func (s *QRScreen) Update(msg tea.Msg) (Screen, tea.Cmd) { + switch msg := msg.(type) { + case tea.KeyMsg: + switch msg.String() { + case "q", "escape": + // Return to list screen (parent should handle this) + return nil, nil + case "f": + // Toggle between inline and fullscreen mode + s.inlineMode = !s.inlineMode + s.generateQRCode() + } + case tea.WindowSizeMsg: + // Handle terminal resize + s.width = msg.Width + s.height = msg.Height + s.generateQRCode() + case configLoadedMsg: + s.configContent = msg.content + s.generateQRCode() + case errMsg: + s.errorMsg = msg.err.Error() + } + + return s, nil +} + +// View renders the QR screen +func (s *QRScreen) View() string { + if s.errorMsg != "" { + return s.renderError() + } + if s.qrCode == "" { + return "Loading QR code..." + } + return s.renderQR() +} + +// loadConfig loads the client configuration +func (s *QRScreen) loadConfig() tea.Msg { + content, err := wireguard.GetClientConfigContent(s.clientName) + if err != nil { + return errMsg{err: err} + } + return configLoadedMsg{content: content} +} + +// generateQRCode generates the QR code based on current mode and terminal size +func (s *QRScreen) generateQRCode() { + if s.configContent == "" { + return + } + + // Generate QR code and capture output + var builder strings.Builder + + // Generate ANSI QR code using half-block characters + qrterminal.GenerateHalfBlock(s.configContent, qrterminal.L, &builder) + + s.qrCode = builder.String() +} + +// renderQR renders the QR code with styling +func (s *QRScreen) renderQR() string { + styleTitle := lipgloss.NewStyle(). + Foreground(lipgloss.Color("62")). + Bold(true). + MarginBottom(1) + + styleHelp := lipgloss.NewStyle(). + Foreground(lipgloss.Color("241")). + MarginTop(1) + + styleQR := lipgloss.NewStyle(). + MarginLeft(2) + + title := styleTitle.Render(fmt.Sprintf("QR Code: %s", s.clientName)) + help := "Press [f] to toggle fullscreen • Press [q/Escape] to return" + + return title + "\n\n" + styleQR.Render(s.qrCode) + "\n" + styleHelp.Render(help) +} + +// renderError renders an error message +func (s *QRScreen) renderError() string { + styleError := lipgloss.NewStyle(). + Foreground(lipgloss.Color("196")). + Bold(true) + + styleHelp := lipgloss.NewStyle(). + Foreground(lipgloss.Color("241")). + MarginTop(1) + + title := styleError.Render("Error") + message := s.errorMsg + help := "Press [q/Escape] to return" + + return title + "\n\n" + message + "\n" + styleHelp.Render(help) +} + +// Messages + +// configLoadedMsg is sent when the client configuration is loaded +type configLoadedMsg struct { + content string +} diff --git a/internal/tui/screens/restore.go b/internal/tui/screens/restore.go new file mode 100644 index 0000000..26daff9 --- /dev/null +++ b/internal/tui/screens/restore.go @@ -0,0 +1,333 @@ +package screens + +import ( + "fmt" + "strings" + + "github.com/calmcacil/wg-admin/internal/backup" + "github.com/calmcacil/wg-admin/internal/tui/components" + "github.com/charmbracelet/bubbles/table" + tea "github.com/charmbracelet/bubbletea" + "github.com/charmbracelet/lipgloss" +) + +// RestoreScreen displays a list of available backups for restoration +type RestoreScreen struct { + table table.Model + backups []backup.Backup + selectedBackup *backup.Backup + confirmModal *components.ConfirmModel + showConfirm bool + isRestoring bool + restoreError error + restoreSuccess bool + message string +} + +// Styles +var ( + restoreTitleStyle = lipgloss.NewStyle(). + Foreground(lipgloss.Color("62")). + Bold(true) + restoreHelpStyle = lipgloss.NewStyle(). + Foreground(lipgloss.Color("63")). + MarginTop(1) + restoreSuccessStyle = lipgloss.NewStyle(). + Foreground(lipgloss.Color("46")). + Bold(true) + restoreErrorStyle = lipgloss.NewStyle(). + Foreground(lipgloss.Color("196")). + Bold(true) + restoreInfoStyle = lipgloss.NewStyle(). + Foreground(lipgloss.Color("241")). + MarginTop(1) +) + +// NewRestoreScreen creates a new restore screen +func NewRestoreScreen() *RestoreScreen { + return &RestoreScreen{ + showConfirm: false, + } +} + +// Init initializes the restore screen +func (s *RestoreScreen) Init() tea.Cmd { + return s.loadBackups +} + +// Update handles messages for the restore screen +func (s *RestoreScreen) Update(msg tea.Msg) (Screen, tea.Cmd) { + var cmd tea.Cmd + + // Handle confirmation modal + if s.showConfirm && s.confirmModal != nil { + _, cmd = s.confirmModal.Update(msg) + + // Handle confirmation result + if !s.confirmModal.Visible { + if s.confirmModal.IsConfirmed() && s.selectedBackup != nil { + // User confirmed restore + s.isRestoring = true + s.showConfirm = false + return s, s.performRestore() + } + // User cancelled - close modal + s.showConfirm = false + return s, nil + } + + // Handle Enter key to confirm + switch msg := msg.(type) { + case tea.KeyMsg: + if msg.String() == "enter" && s.confirmModal.IsConfirmed() && s.selectedBackup != nil { + s.isRestoring = true + s.showConfirm = false + return s, s.performRestore() + } + } + + return s, cmd + } + + // Handle normal screen messages + switch msg := msg.(type) { + case backupsLoadedMsg: + s.backups = msg.backups + s.buildTable() + case tea.KeyMsg: + switch msg.String() { + case "q", "esc": + // Return to list screen - signal parent to switch screens + return s, nil + case "enter": + // Show confirmation for selected backup + if len(s.table.Rows()) > 0 { + selected := s.table.SelectedRow() + if len(selected) > 0 { + // Find the backup by name + for _, b := range s.backups { + if b.Name == selected[0] { + s.selectedBackup = &b + s.confirmModal = components.NewConfirm( + fmt.Sprintf( + "Are you sure you want to restore from backup '%s'?\n\nOperation: %s\nDate: %s\n\nThis will replace current WireGuard configuration.\nA safety backup will be created first.", + b.Name, + b.Operation, + b.Timestamp.Format("2006-01-02 15:04:05"), + ), + 80, + 24, + ) + s.showConfirm = true + break + } + } + } + } + } + case restoreCompletedMsg: + s.isRestoring = false + if msg.err != nil { + s.restoreError = msg.err + s.message = fmt.Sprintf("Restore failed: %v", msg.err) + } else { + s.restoreSuccess = true + s.message = fmt.Sprintf("Restore successful! Safety backup created at: %s", msg.safetyBackupPath) + } + } + + if !s.showConfirm && s.confirmModal != nil { + s.table, cmd = s.table.Update(msg) + } + return s, cmd +} + +// View renders the restore screen +func (s *RestoreScreen) View() string { + if s.showConfirm && s.confirmModal != nil { + // Render underlying content dimmed + content := s.renderContent() + dimmedContent := lipgloss.NewStyle(). + Foreground(lipgloss.Color("244")). + Render(content) + + // Overlay confirmation modal + return lipgloss.JoinVertical( + lipgloss.Left, + dimmedContent, + s.confirmModal.View(), + ) + } + + return s.renderContent() +} + +// renderContent renders the main restore screen content +func (s *RestoreScreen) renderContent() string { + var content strings.Builder + + content.WriteString(restoreTitleStyle.Render("Restore WireGuard Configuration")) + content.WriteString("\n\n") + + if len(s.backups) == 0 && !s.isRestoring && s.message == "" { + content.WriteString("No backups found. Press 'q' to return.") + return content.String() + } + + if s.isRestoring { + content.WriteString("Restoring from backup, please wait...") + return content.String() + } + + if s.restoreSuccess { + content.WriteString(restoreSuccessStyle.Render("✓ " + s.message)) + content.WriteString("\n\n") + content.WriteString(restoreInfoStyle.Render("Press 'q' to return to client list.")) + return content.String() + } + + if s.restoreError != nil { + content.WriteString(restoreErrorStyle.Render("✗ " + s.message)) + content.WriteString("\n\n") + content.WriteString(s.table.View()) + content.WriteString("\n\n") + content.WriteString(restoreHelpStyle.Render("Actions: [Enter] Restore Selected • [↑/↓] Navigate • [q] Back")) + return content.String() + } + + // Show backup list + content.WriteString(s.table.View()) + content.WriteString("\n\n") + + // Show selected backup details + if len(s.table.Rows()) > 0 && s.selectedBackup != nil { + content.WriteString(restoreInfoStyle.Render( + fmt.Sprintf( + "Selected: %s (%s) - %s\nSize: %s", + s.selectedBackup.Operation, + s.selectedBackup.Timestamp.Format("2006-01-02 15:04:05"), + s.selectedBackup.Name, + formatBytes(s.selectedBackup.Size), + ), + )) + content.WriteString("\n") + } + + content.WriteString(restoreHelpStyle.Render("Actions: [Enter] Restore Selected • [↑/↓] Navigate • [q] Back")) + + return content.String() +} + +// loadBackups loads the list of available backups +func (s *RestoreScreen) loadBackups() tea.Msg { + backups, err := backup.ListBackups() + if err != nil { + return errMsg{err: err} + } + return backupsLoadedMsg{backups: backups} +} + +// buildTable creates and configures the backup list table +func (s *RestoreScreen) buildTable() { + columns := []table.Column{ + {Title: "Name", Width: 40}, + {Title: "Operation", Width: 15}, + {Title: "Date", Width: 20}, + {Title: "Size", Width: 12}, + } + + var rows []table.Row + for _, b := range s.backups { + row := table.Row{ + b.Name, + b.Operation, + b.Timestamp.Format("2006-01-02 15:04"), + formatBytes(b.Size), + } + rows = append(rows, row) + } + + s.table = table.New( + table.WithColumns(columns), + table.WithRows(rows), + table.WithFocused(true), + table.WithHeight(len(rows)+2), // Header + rows + ) + + // Apply styles + s.setTableStyles() +} + +// setTableStyles applies styling to the table +func (s *RestoreScreen) setTableStyles() { + styles := table.DefaultStyles() + styles.Header = styles.Header. + BorderStyle(lipgloss.NormalBorder()). + BorderForeground(lipgloss.Color("240")). + BorderBottom(true). + Bold(true) + styles.Selected = styles.Selected. + Foreground(lipgloss.Color("229")). + Background(lipgloss.Color("57")). + Bold(false) + s.table.SetStyles(styles) +} + +// performRestore performs the restore operation +func (s *RestoreScreen) performRestore() tea.Cmd { + return func() tea.Msg { + if s.selectedBackup == nil { + return restoreCompletedMsg{ + err: fmt.Errorf("no backup selected"), + } + } + + // Get safety backup path from backup.BackupConfig + safetyBackupPath, err := backup.BackupConfig(fmt.Sprintf("pre-restore-from-%s", s.selectedBackup.Name)) + if err != nil { + return restoreCompletedMsg{ + err: fmt.Errorf("failed to create safety backup: %w", err), + } + } + + // Perform restore + if err := backup.RestoreBackup(s.selectedBackup.Name); err != nil { + return restoreCompletedMsg{ + err: err, + safetyBackupPath: safetyBackupPath, + } + } + + // Restore succeeded - trigger client list refresh + return restoreCompletedMsg{ + safetyBackupPath: safetyBackupPath, + } + } +} + +// formatBytes formats a byte count into human-readable format +func formatBytes(bytes int64) string { + const unit = 1024 + if bytes < unit { + return fmt.Sprintf("%d B", bytes) + } + div, exp := int64(unit), 0 + for n := bytes / unit; n >= unit; n /= unit { + div *= unit + exp++ + } + return fmt.Sprintf("%.1f %ciB", float64(bytes)/float64(div), "KMGTPE"[exp]) +} + +// Messages + +// backupsLoadedMsg is sent when backups are loaded +type backupsLoadedMsg struct { + backups []backup.Backup +} + +// restoreCompletedMsg is sent when a restore operation completes +type restoreCompletedMsg struct { + err error + safetyBackupPath string +} diff --git a/internal/tui/theme/theme.go b/internal/tui/theme/theme.go new file mode 100644 index 0000000..d5bfba5 --- /dev/null +++ b/internal/tui/theme/theme.go @@ -0,0 +1,185 @@ +package theme + +import ( + "fmt" + "os" + "sync" + + "github.com/charmbracelet/lipgloss" +) + +// ColorScheme defines the color palette for the theme +type ColorScheme struct { + Primary lipgloss.Color + Success lipgloss.Color + Warning lipgloss.Color + Error lipgloss.Color + Muted lipgloss.Color + Background lipgloss.Color +} + +// Theme represents a color theme with its name and color scheme +type Theme struct { + Name string + Scheme ColorScheme +} + +// Global variables for current theme and styles +var ( + currentTheme *Theme + once sync.Once + + // Global styles that can be used throughout the application + StylePrimary lipgloss.Style + StyleSuccess lipgloss.Style + StyleWarning lipgloss.Style + StyleError lipgloss.Style + StyleMuted lipgloss.Style + StyleTitle lipgloss.Style + StyleSubtitle lipgloss.Style + StyleHelpKey lipgloss.Style +) + +// DefaultTheme is the standard blue-based theme +var DefaultTheme = &Theme{ + Name: "default", + Scheme: ColorScheme{ + Primary: lipgloss.Color("62"), // Blue + Success: lipgloss.Color("46"), // Green + Warning: lipgloss.Color("208"), // Orange + Error: lipgloss.Color("196"), // Red + Muted: lipgloss.Color("241"), // Gray + Background: lipgloss.Color(""), // Default terminal background + }, +} + +// DarkTheme is a purple-based dark theme +var DarkTheme = &Theme{ + Name: "dark", + Scheme: ColorScheme{ + Primary: lipgloss.Color("141"), // Purple + Success: lipgloss.Color("51"), // Cyan + Warning: lipgloss.Color("226"), // Yellow + Error: lipgloss.Color("196"), // Red + Muted: lipgloss.Color("245"), // Light gray + Background: lipgloss.Color(""), // Default terminal background + }, +} + +// LightTheme is a green-based light theme +var LightTheme = &Theme{ + Name: "light", + Scheme: ColorScheme{ + Primary: lipgloss.Color("34"), // Green + Success: lipgloss.Color("36"), // Teal + Warning: lipgloss.Color("214"), // Amber + Error: lipgloss.Color("196"), // Red + Muted: lipgloss.Color("244"), // Gray + Background: lipgloss.Color(""), // Default terminal background + }, +} + +// ThemeRegistry holds all available themes +var ThemeRegistry = map[string]*Theme{ + "default": DefaultTheme, + "dark": DarkTheme, + "light": LightTheme, +} + +// GetTheme loads the theme from config or environment variable +// Returns the default theme if no theme is specified +func GetTheme() (*Theme, error) { + once.Do(func() { + // Try to get theme from environment variable first + themeName := os.Getenv("THEME") + if themeName == "" { + themeName = "default" + } + + // Look up the theme in the registry + if theme, ok := ThemeRegistry[themeName]; ok { + currentTheme = theme + } else { + // If theme not found, use default + currentTheme = DefaultTheme + } + + // Apply the theme to initialize styles + ApplyTheme(currentTheme) + }) + + return currentTheme, nil +} + +// ApplyTheme applies the given theme to all global styles +func ApplyTheme(theme *Theme) { + currentTheme = theme + + // Primary style + StylePrimary = lipgloss.NewStyle(). + Foreground(theme.Scheme.Primary) + + // Success style + StyleSuccess = lipgloss.NewStyle(). + Foreground(theme.Scheme.Success) + + // Warning style + StyleWarning = lipgloss.NewStyle(). + Foreground(theme.Scheme.Warning) + + // Error style + StyleError = lipgloss.NewStyle(). + Foreground(theme.Scheme.Error) + + // Muted style + StyleMuted = lipgloss.NewStyle(). + Foreground(theme.Scheme.Muted) + + // Title style (bold primary) + StyleTitle = lipgloss.NewStyle(). + Foreground(theme.Scheme.Primary). + Bold(true) + + // Subtitle style (muted) + StyleSubtitle = lipgloss.NewStyle(). + Foreground(theme.Scheme.Muted) + + // Help key style (bold primary, slightly different shade) + StyleHelpKey = lipgloss.NewStyle(). + Foreground(theme.Scheme.Primary). + Bold(true) +} + +// GetThemeNames returns a list of available theme names +func GetThemeNames() []string { + names := make([]string, 0, len(ThemeRegistry)) + for name := range ThemeRegistry { + names = append(names, name) + } + return names +} + +// SetTheme changes the current theme by name +func SetTheme(name string) error { + theme, ok := ThemeRegistry[name] + if !ok { + return fmt.Errorf("theme '%s' not found. Available themes: %v", name, GetThemeNames()) + } + + ApplyTheme(theme) + return nil +} + +// GetCurrentTheme returns the currently active theme +func GetCurrentTheme() *Theme { + if currentTheme == nil { + currentTheme = DefaultTheme + ApplyTheme(currentTheme) + } + return currentTheme +} + +// GetColorScheme returns the color scheme of the current theme +func GetColorScheme() ColorScheme { + return GetCurrentTheme().Scheme +} diff --git a/internal/validation/client.go b/internal/validation/client.go new file mode 100644 index 0000000..ced92d9 --- /dev/null +++ b/internal/validation/client.go @@ -0,0 +1,86 @@ +package validation + +import ( + "fmt" + "regexp" + "strings" +) + +// ValidateClientName validates the client name format +// Requirements: alphanumeric, hyphens, underscores only, max 64 characters +func ValidateClientName(name string) error { + if name == "" { + return fmt.Errorf("client name cannot be empty") + } + + if len(name) > 64 { + return fmt.Errorf("client name must be 64 characters or less (got %d)", len(name)) + } + + // Check for valid characters: a-z, A-Z, 0-9, hyphen, underscore + matched, err := regexp.MatchString(`^[a-zA-Z0-9_-]+$`, name) + if err != nil { + return fmt.Errorf("failed to validate client name: %w", err) + } + + if !matched { + return fmt.Errorf("client name contains invalid characters (allowed: a-z, A-Z, 0-9, -, _)") + } + + return nil +} + +// ValidateIPAvailability checks if an IP address is already assigned to a client +// This function is meant to be called after getting next available IPs +func ValidateIPAvailability(ipv4, ipv6 string, existingIPs map[string]bool) error { + if ipv4 != "" && existingIPs[ipv4] { + return fmt.Errorf("IPv4 address %s is already assigned", ipv4) + } + + if ipv6 != "" && existingIPs[ipv6] { + return fmt.Errorf("IPv6 address %s is already assigned", ipv6) + } + + return nil +} + +// ValidateDNSServers validates DNS server addresses +// Requirements: comma-separated list of IPv4 addresses +func ValidateDNSServers(dns string) error { + if dns == "" { + return fmt.Errorf("DNS servers cannot be empty") + } + + // Split by comma and trim whitespace + servers := strings.Split(dns, ",") + for _, server := range servers { + server = strings.TrimSpace(server) + if server == "" { + continue + } + + // Basic IPv4 validation (4 octets, each 0-255) + matched, err := regexp.MatchString(`^(\d{1,3}\.){3}\d{1,3}$`, server) + if err != nil { + return fmt.Errorf("failed to validate DNS server: %w", err) + } + + if !matched { + return fmt.Errorf("invalid DNS server format: %s (expected IPv4 address)", server) + } + + // Validate each octet is in range 0-255 + parts := strings.Split(server, ".") + for _, part := range parts { + var num int + if _, err := fmt.Sscanf(part, "%d", &num); err != nil { + return fmt.Errorf("invalid DNS server: %s", server) + } + if num < 0 || num > 255 { + return fmt.Errorf("DNS server octet out of range (0-255) in: %s", server) + } + } + } + + return nil +} diff --git a/internal/wireguard/client.go b/internal/wireguard/client.go new file mode 100644 index 0000000..e7e67f1 --- /dev/null +++ b/internal/wireguard/client.go @@ -0,0 +1,585 @@ +package wireguard + +import ( + "bytes" + "fmt" + "log" + "os" + "os/exec" + "path/filepath" + "strings" + + "github.com/calmcacil/wg-admin/internal/backup" + "github.com/calmcacil/wg-admin/internal/config" +) + +// Client represents a WireGuard peer configuration +type Client struct { + Name string // Client name extracted from filename + IPv4 string // IPv4 address from AllowedIPs + IPv6 string // IPv6 address from AllowedIPs + PublicKey string // WireGuard public key + HasPSK bool // Whether PresharedKey is configured + ConfigPath string // Path to the client config file +} + +// ParseClientConfig parses a single WireGuard client configuration file +func ParseClientConfig(path string) (*Client, error) { + // Read the file + content, err := os.ReadFile(path) + if err != nil { + return nil, fmt.Errorf("failed to read config file %s: %w", path, err) + } + + // Extract client name from filename + base := filepath.Base(path) + name := strings.TrimPrefix(base, "client-") + name = strings.TrimSuffix(name, ".conf") + + if name == "" || name == "client-" { + return nil, fmt.Errorf("invalid client filename: %s", base) + } + + client := &Client{ + Name: name, + ConfigPath: path, + } + + // Parse the INI-style config + inPeerSection := false + hasPublicKey := false + + for i, line := range strings.Split(string(content), "\n") { + line = strings.TrimSpace(line) + + // Skip empty lines and comments + if line == "" || strings.HasPrefix(line, "#") || strings.HasPrefix(line, ";") { + continue + } + + // Check for Peer section + if line == "[Peer]" { + inPeerSection = true + continue + } + + // Parse key-value pairs within Peer section + if inPeerSection { + parts := strings.SplitN(line, "=", 2) + if len(parts) != 2 { + log.Printf("Warning: malformed line %d in %s: %s", i+1, path, line) + continue + } + + key := strings.TrimSpace(parts[0]) + value := strings.TrimSpace(parts[1]) + + switch key { + case "PublicKey": + client.PublicKey = value + hasPublicKey = true + case "PresharedKey": + client.HasPSK = true + case "AllowedIPs": + if err := parseAllowedIPs(client, value); err != nil { + log.Printf("Warning: %v (file: %s, line: %d)", err, path, i+1) + } + } + } + } + + // Validate required fields + if !hasPublicKey { + return nil, fmt.Errorf("missing required PublicKey in %s", path) + } + + if client.IPv4 == "" && client.IPv6 == "" { + return nil, fmt.Errorf("no valid IP addresses found in AllowedIPs in %s", path) + } + + return client, nil +} + +// parseAllowedIPs extracts IPv4 and IPv6 addresses from AllowedIPs value +func parseAllowedIPs(client *Client, allowedIPs string) error { + // AllowedIPs format: "ipv4/32, ipv6/128" + addresses := strings.Split(allowedIPs, ",") + + for _, addr := range addresses { + addr = strings.TrimSpace(addr) + if addr == "" { + continue + } + + // Split IP from CIDR suffix + parts := strings.Split(addr, "/") + if len(parts) != 2 { + return fmt.Errorf("invalid AllowedIP format: %s", addr) + } + + ip := strings.TrimSpace(parts[0]) + + // Detect if IPv4 or IPv6 based on presence of colon + if strings.Contains(ip, ":") { + client.IPv6 = ip + } else { + client.IPv4 = ip + } + } + + return nil +} + +// ListClients finds and parses all client configurations from /etc/wireguard/conf.d/ +func ListClients() ([]Client, error) { + configDir := "/etc/wireguard/conf.d" + + // Check if directory exists + if _, err := os.Stat(configDir); os.IsNotExist(err) { + return nil, fmt.Errorf("wireguard config directory does not exist: %s", configDir) + } + + // Find all client-*.conf files + pattern := filepath.Join(configDir, "client-*.conf") + matches, err := filepath.Glob(pattern) + if err != nil { + return nil, fmt.Errorf("failed to find client config files: %w", err) + } + + if len(matches) == 0 { + return []Client{}, nil // No clients found, return empty slice + } + + // Parse each config file + var clients []Client + var parseErrors []string + + for _, match := range matches { + client, err := ParseClientConfig(match) + if err != nil { + parseErrors = append(parseErrors, err.Error()) + log.Printf("Warning: failed to parse %s: %v", match, err) + continue + } + clients = append(clients, *client) + } + + // If all files failed to parse, return an error + if len(clients) == 0 && len(parseErrors) > 0 { + return nil, fmt.Errorf("failed to parse any client configs: %s", strings.Join(parseErrors, "; ")) + } + + return clients, nil +} + +// GetClientConfigContent reads the raw configuration content for a client +func GetClientConfigContent(name string) (string, error) { + configDir := "/etc/wireguard/clients" + configPath := filepath.Join(configDir, fmt.Sprintf("%s.conf", name)) + + content, err := os.ReadFile(configPath) + if err != nil { + if os.IsNotExist(err) { + return "", fmt.Errorf("client config not found: %s", configPath) + } + return "", fmt.Errorf("failed to read client config %s: %w", configPath, err) + } + + return string(content), nil +} + +// DeleteClient removes a WireGuard client configuration and associated files +func DeleteClient(name string) error { + // First, find the client config to get public key for removal from interface + configDir := "/etc/wireguard/conf.d" + configPath := filepath.Join(configDir, fmt.Sprintf("client-%s.conf", name)) + + client, err := ParseClientConfig(configPath) + if err != nil { + return fmt.Errorf("failed to parse client config for deletion: %w", err) + } + + log.Printf("Deleting client: %s (public key: %s)", name, client.PublicKey) + + // Create backup before deletion + backupPath, err := backup.BackupConfig(fmt.Sprintf("delete-%s", name)) + if err != nil { + log.Printf("Warning: failed to create backup before deletion: %v", err) + } else { + log.Printf("Created backup: %s", backupPath) + } + + // Remove peer from WireGuard interface using wg command + if err := removePeerFromInterface(client.PublicKey); err != nil { + log.Printf("Warning: failed to remove peer from interface: %v", err) + } + + // Remove client config from /etc/wireguard/conf.d/ + if err := os.Remove(configPath); err != nil && !os.IsNotExist(err) { + return fmt.Errorf("failed to remove client config %s: %w", configPath, err) + } + log.Printf("Removed client config: %s", configPath) + + // Remove client files from /etc/wireguard/clients/ + clientsDir := "/etc/wireguard/clients" + clientFile := filepath.Join(clientsDir, fmt.Sprintf("%s.conf", name)) + if err := os.Remove(clientFile); err != nil && !os.IsNotExist(err) { + log.Printf("Warning: failed to remove client file %s: %v", clientFile, err) + } else { + log.Printf("Removed client file: %s", clientFile) + } + + // Remove QR code PNG if it exists + qrFile := filepath.Join(clientsDir, fmt.Sprintf("%s.png", name)) + if err := os.Remove(qrFile); err != nil && !os.IsNotExist(err) { + log.Printf("Warning: failed to remove QR code %s: %v", qrFile, err) + } else { + log.Printf("Removed QR code: %s", qrFile) + } + + log.Printf("Successfully deleted client: %s", name) + return nil +} + +// removePeerFromInterface removes a peer from the WireGuard interface +func removePeerFromInterface(publicKey string) error { + // Use wg command to remove peer + cmd := exec.Command("wg", "set", "wg0", "peer", publicKey, "remove") + output, err := cmd.CombinedOutput() + if err != nil { + return fmt.Errorf("wg set peer remove failed: %w, output: %s", err, string(output)) + } + return nil +} + +// CreateClient creates a new WireGuard client configuration +func CreateClient(name, dns string, usePSK bool) error { + log.Printf("Creating client: %s (PSK: %v)", name, usePSK) + + // Create backup before creating client + backupPath, err := backup.BackupConfig(fmt.Sprintf("create-%s", name)) + if err != nil { + log.Printf("Warning: failed to create backup before creating client: %v", err) + } else { + log.Printf("Created backup: %s", backupPath) + } + + // Generate keys + privateKey, publicKey, err := generateKeyPair() + if err != nil { + return fmt.Errorf("failed to generate key pair: %w", err) + } + + var psk string + if usePSK { + psk, err = generatePSK() + if err != nil { + return fmt.Errorf("failed to generate PSK: %w", err) + } + } + + // Get next available IP addresses + cfg, err := config.LoadConfig() + if err != nil { + return fmt.Errorf("failed to load config: %w", err) + } + + clients, err := ListClients() + if err != nil { + return fmt.Errorf("failed to list existing clients: %w", err) + } + + ipv4, err := getNextAvailableIP(cfg.VPNIPv4Range, clients) + if err != nil { + return fmt.Errorf("failed to get IPv4 address: %w", err) + } + + ipv6, err := getNextAvailableIP(cfg.VPNIPv6Range, clients) + if err != nil { + return fmt.Errorf("failed to get IPv6 address: %w", err) + } + + // Create server config + serverConfigPath := fmt.Sprintf("/etc/wireguard/conf.d/client-%s.conf", name) + serverConfig, err := generateServerConfig(name, publicKey, ipv4, ipv6, psk, cfg) + if err != nil { + return fmt.Errorf("failed to generate server config: %w", err) + } + + if err := os.WriteFile(serverConfigPath, []byte(serverConfig), 0600); err != nil { + return fmt.Errorf("failed to write server config: %w", err) + } + log.Printf("Created server config: %s", serverConfigPath) + + // Create client config + clientsDir := "/etc/wireguard/clients" + if err := os.MkdirAll(clientsDir, 0700); err != nil { + return fmt.Errorf("failed to create clients directory: %w", err) + } + + clientConfigPath := filepath.Join(clientsDir, fmt.Sprintf("%s.conf", name)) + clientConfig, err := generateClientConfig(name, privateKey, ipv4, ipv6, dns, cfg) + if err != nil { + return fmt.Errorf("failed to generate client config: %w", err) + } + + if err := os.WriteFile(clientConfigPath, []byte(clientConfig), 0600); err != nil { + return fmt.Errorf("failed to write client config: %w", err) + } + log.Printf("Created client config: %s", clientConfigPath) + + // Add peer to WireGuard interface + if err := addPeerToInterface(publicKey, ipv4, ipv6, psk); err != nil { + log.Printf("Warning: failed to add peer to interface: %v", err) + } else { + log.Printf("Added peer to WireGuard interface") + } + + // Generate QR code + qrPath := filepath.Join(clientsDir, fmt.Sprintf("%s.png", name)) + if err := generateQRCode(clientConfigPath, qrPath); err != nil { + log.Printf("Warning: failed to generate QR code: %v", err) + } else { + log.Printf("Generated QR code: %s", qrPath) + } + + log.Printf("Successfully created client: %s", name) + return nil +} + +// generateKeyPair generates a WireGuard private and public key pair +func generateKeyPair() (privateKey, publicKey string, err error) { + // Generate private key + privateKeyBytes, err := exec.Command("wg", "genkey").Output() + if err != nil { + return "", "", fmt.Errorf("wg genkey failed: %w", err) + } + privateKey = strings.TrimSpace(string(privateKeyBytes)) + + // Derive public key + pubKeyCmd := exec.Command("wg", "pubkey") + pubKeyCmd.Stdin = strings.NewReader(privateKey) + publicKeyBytes, err := pubKeyCmd.Output() + if err != nil { + return "", "", fmt.Errorf("wg pubkey failed: %w", err) + } + publicKey = strings.TrimSpace(string(publicKeyBytes)) + + return privateKey, publicKey, nil +} + +// generatePSK generates a WireGuard preshared key +func generatePSK() (string, error) { + psk, err := exec.Command("wg", "genpsk").Output() + if err != nil { + return "", fmt.Errorf("wg genpsk failed: %w", err) + } + return strings.TrimSpace(string(psk)), nil +} + +// getNextAvailableIP finds the next available IP address in the given CIDR range +func getNextAvailableIP(cidr string, existingClients []Client) (string, error) { + // Parse CIDR to get network + parts := strings.Split(cidr, "/") + if len(parts) != 2 { + return "", fmt.Errorf("invalid CIDR format: %s", cidr) + } + + network := strings.TrimSpace(parts[0]) + + // For IPv4, extract base network and assign next available host + if !strings.Contains(network, ":") { + // IPv4: Simple implementation - use .1, .2, etc. + // In production, this would parse the CIDR properly + usedHosts := make(map[string]bool) + for _, client := range existingClients { + if client.IPv4 != "" { + ipParts := strings.Split(client.IPv4, ".") + if len(ipParts) == 4 { + usedHosts[ipParts[3]] = true + } + } + } + + // Find next available host (skip 0 and 1 as they may be reserved) + for i := 2; i < 255; i++ { + host := fmt.Sprintf("%d", i) + if !usedHosts[host] { + return fmt.Sprintf("%s.%s", network, host), nil + } + } + + return "", fmt.Errorf("no available IPv4 addresses in range: %s", cidr) + } + + // IPv6: Similar simplified approach + usedHosts := make(map[string]bool) + for _, client := range existingClients { + if client.IPv6 != "" { + // Extract last segment for IPv6 + lastColon := strings.LastIndex(client.IPv6, ":") + if lastColon > 0 { + host := client.IPv6[lastColon+1:] + usedHosts[host] = true + } + } + } + + // Find next available host + for i := 1; i < 65536; i++ { + host := fmt.Sprintf("%x", i) + if !usedHosts[host] { + return fmt.Sprintf("%s:%s", network, host), nil + } + } + + return "", fmt.Errorf("no available IPv6 addresses in range: %s", cidr) +} + +// generateServerConfig generates the server-side configuration for a client +func generateServerConfig(name, publicKey, ipv4, ipv6, psk string, cfg *config.Config) (string, error) { + var builder strings.Builder + + builder.WriteString(fmt.Sprintf("# Client: %s\n", name)) + builder.WriteString(fmt.Sprintf("[Peer]\n")) + builder.WriteString(fmt.Sprintf("PublicKey = %s\n", publicKey)) + + allowedIPs := "" + if ipv4 != "" { + allowedIPs = ipv4 + "/32" + } + if ipv6 != "" { + if allowedIPs != "" { + allowedIPs += ", " + } + allowedIPs += ipv6 + "/128" + } + builder.WriteString(fmt.Sprintf("AllowedIPs = %s\n", allowedIPs)) + + if psk != "" { + builder.WriteString(fmt.Sprintf("PresharedKey = %s\n", psk)) + } + + return builder.String(), nil +} + +// generateClientConfig generates the client-side configuration +func generateClientConfig(name, privateKey, ipv4, ipv6, dns string, cfg *config.Config) (string, error) { + // Get server's public key from the main config + serverConfigPath := "/etc/wireguard/wg0.conf" + serverPublicKey, serverEndpoint, err := getServerConfig(serverConfigPath) + if err != nil { + return "", fmt.Errorf("failed to read server config: %w", err) + } + + var builder strings.Builder + + builder.WriteString(fmt.Sprintf("# WireGuard client configuration for %s\n", name)) + builder.WriteString("[Interface]\n") + builder.WriteString(fmt.Sprintf("PrivateKey = %s\n", privateKey)) + builder.WriteString(fmt.Sprintf("Address = %s/32", ipv4)) + if ipv6 != "" { + builder.WriteString(fmt.Sprintf(", %s/128", ipv6)) + } + builder.WriteString("\n") + builder.WriteString(fmt.Sprintf("DNS = %s\n", dns)) + + builder.WriteString("\n") + builder.WriteString("[Peer]\n") + builder.WriteString(fmt.Sprintf("PublicKey = %s\n", serverPublicKey)) + builder.WriteString(fmt.Sprintf("Endpoint = %s:%d\n", serverEndpoint, cfg.WGPort)) + builder.WriteString("AllowedIPs = 0.0.0.0/0, ::/0\n") + + return builder.String(), nil +} + +// getServerConfig reads the server's public key and endpoint from the main config +func getServerConfig(path string) (publicKey, endpoint string, err error) { + content, err := os.ReadFile(path) + if err != nil { + return "", "", fmt.Errorf("failed to read server config: %w", err) + } + + inInterfaceSection := false + for _, line := range strings.Split(string(content), "\n") { + line = strings.TrimSpace(line) + + if line == "[Interface]" { + inInterfaceSection = true + continue + } + if line == "[Peer]" { + inInterfaceSection = false + continue + } + + if inInterfaceSection { + if strings.HasPrefix(line, "PublicKey") { + parts := strings.SplitN(line, "=", 2) + if len(parts) == 2 { + publicKey = strings.TrimSpace(parts[1]) + } + } + } + } + + // Use SERVER_DOMAIN as endpoint if available, otherwise fallback + cfg, err := config.LoadConfig() + if err == nil && cfg.ServerDomain != "" { + endpoint = cfg.ServerDomain + } else { + endpoint = "0.0.0.0" + } + + return publicKey, endpoint, nil +} + +// addPeerToInterface adds a peer to the WireGuard interface +func addPeerToInterface(publicKey, ipv4, ipv6, psk string) error { + args := []string{"set", "wg0", "peer", publicKey} + + if ipv4 != "" { + args = append(args, "allowed-ips", ipv4+"/32") + } + if ipv6 != "" { + args = append(args, ipv6+"/128") + } + + if psk != "" { + args = append(args, "preshared-key", "/dev/stdin") + cmd := exec.Command("wg", args...) + cmd.Stdin = strings.NewReader(psk) + output, err := cmd.CombinedOutput() + if err != nil { + return fmt.Errorf("wg set peer failed: %w, output: %s", err, string(output)) + } + } else { + cmd := exec.Command("wg", args...) + output, err := cmd.CombinedOutput() + if err != nil { + return fmt.Errorf("wg set peer failed: %w, output: %s", err, string(output)) + } + } + + return nil +} + +// generateQRCode generates a QR code from the client config +func generateQRCode(configPath, qrPath string) error { + // Read config file + content, err := os.ReadFile(configPath) + if err != nil { + return fmt.Errorf("failed to read config file: %w", err) + } + + // Generate QR code using qrencode or similar + // For now, use a simple approach with qrencode if available + cmd := exec.Command("qrencode", "-o", qrPath, "-t", "PNG") + cmd.Stdin = bytes.NewReader(content) + if err := cmd.Run(); err != nil { + // qrencode not available, try alternative method + return fmt.Errorf("qrencode not available: %w", err) + } + + return nil +} diff --git a/internal/wireguard/config.go b/internal/wireguard/config.go new file mode 100644 index 0000000..775c8e4 --- /dev/null +++ b/internal/wireguard/config.go @@ -0,0 +1,106 @@ +package wireguard + +import ( + "fmt" + "log" + "os" + "path/filepath" + "strings" +) + +// GenerateServerConfig generates a server-side WireGuard [Peer] configuration +// and writes it to /etc/wireguard/conf.d/client-.conf +func GenerateServerConfig(name, publicKey string, hasPSK bool, ipv4, ipv6, psk string) (string, error) { + var builder strings.Builder + + builder.WriteString(fmt.Sprintf("[Peer]\n")) + builder.WriteString(fmt.Sprintf("# %s\n", name)) + builder.WriteString(fmt.Sprintf("PublicKey = %s\n", publicKey)) + + if hasPSK { + builder.WriteString(fmt.Sprintf("PresharedKey = %s\n", psk)) + } + + builder.WriteString(fmt.Sprintf("AllowedIPs = %s/32", ipv4)) + if ipv6 != "" { + builder.WriteString(fmt.Sprintf(", %s/128", ipv6)) + } + builder.WriteString("\n") + + configContent := builder.String() + + configDir := "/etc/wireguard/conf.d" + configPath := filepath.Join(configDir, fmt.Sprintf("client-%s.conf", name)) + + if err := atomicWrite(configPath, configContent); err != nil { + return "", fmt.Errorf("failed to write server config: %w", err) + } + + log.Printf("Generated server config: %s", configPath) + return configPath, nil +} + +// GenerateClientConfig generates a client-side WireGuard configuration +// with [Interface] and [Peer] sections and writes it to /etc/wireguard/clients/.conf +func GenerateClientConfig(name, privateKey, ipv4, ipv6, dns, serverPublicKey, endpoint string, port int, hasPSK bool, psk string) (string, error) { + var builder strings.Builder + + // [Interface] section + builder.WriteString("[Interface]\n") + builder.WriteString(fmt.Sprintf("PrivateKey = %s\n", privateKey)) + builder.WriteString(fmt.Sprintf("Address = %s/24", ipv4)) + if ipv6 != "" { + builder.WriteString(fmt.Sprintf(", %s/64", ipv6)) + } + builder.WriteString("\n") + builder.WriteString(fmt.Sprintf("DNS = %s\n", dns)) + + // [Peer] section + builder.WriteString("\n[Peer]\n") + builder.WriteString(fmt.Sprintf("PublicKey = %s\n", serverPublicKey)) + + if hasPSK { + builder.WriteString(fmt.Sprintf("PresharedKey = %s\n", psk)) + } + + builder.WriteString(fmt.Sprintf("Endpoint = %s:%d\n", endpoint, port)) + builder.WriteString("AllowedIPs = 0.0.0.0/0, ::/0\n") + builder.WriteString("PersistentKeepalive = 25\n") + + configContent := builder.String() + + clientsDir := "/etc/wireguard/clients" + configPath := filepath.Join(clientsDir, fmt.Sprintf("%s.conf", name)) + + if err := atomicWrite(configPath, configContent); err != nil { + return "", fmt.Errorf("failed to write client config: %w", err) + } + + log.Printf("Generated client config: %s", configPath) + return configPath, nil +} + +// atomicWrite writes content to a file atomically +// Uses temp file + rename pattern for atomicity and sets permissions to 0600 +func atomicWrite(path, content string) error { + // Ensure parent directory exists + dir := filepath.Dir(path) + if err := os.MkdirAll(dir, 0755); err != nil { + return fmt.Errorf("failed to create directory %s: %w", dir, err) + } + + // Write to temp file first + tempPath := path + ".tmp" + if err := os.WriteFile(tempPath, []byte(content), 0600); err != nil { + return fmt.Errorf("failed to write temp file: %w", err) + } + + // Atomic rename + if err := os.Rename(tempPath, path); err != nil { + // Clean up temp file if rename fails + _ = os.Remove(tempPath) + return fmt.Errorf("failed to rename temp file: %w", err) + } + + return nil +} diff --git a/internal/wireguard/keys.go b/internal/wireguard/keys.go new file mode 100644 index 0000000..638005a --- /dev/null +++ b/internal/wireguard/keys.go @@ -0,0 +1,243 @@ +package wireguard + +import ( + "encoding/base64" + "fmt" + "log" + "os" + "os/exec" + "path/filepath" + "strings" + "sync" +) + +const ( + // KeyLength is the standard WireGuard key length in bytes (32 bytes = 256 bits) + KeyLength = 32 + // Base64KeyLength is the length of a base64-encoded WireGuard key (44 characters) + Base64KeyLength = 44 + // TempDir is the directory for temporary key storage + TempDir = "/tmp/wg-admin" +) + +var ( + tempKeys = make(map[string]bool) + tempKeysMutex sync.Mutex +) + +// KeyPair represents a WireGuard key pair +type KeyPair struct { + PrivateKey string // Base64-encoded private key + PublicKey string // Base64-encoded public key +} + +// GeneratePrivateKey generates a new WireGuard private key using wg genkey +func GeneratePrivateKey() (string, error) { + cmd := exec.Command("wg", "genkey") + output, err := cmd.CombinedOutput() + if err != nil { + return "", fmt.Errorf("wg genkey failed: %w, output: %s", err, string(output)) + } + + key := string(output) + if err := ValidateKey(key); err != nil { + return "", fmt.Errorf("generated invalid private key: %w", err) + } + + // Store as temporary key + tempKeyPath := filepath.Join(TempDir, "private.key") + if err := storeTempKey(tempKeyPath, key); err != nil { + log.Printf("Warning: failed to store temporary private key: %v", err) + } + + return key, nil +} + +// GeneratePublicKey generates a public key from a private key using wg pubkey +func GeneratePublicKey(privateKey string) (string, error) { + if err := ValidateKey(privateKey); err != nil { + return "", fmt.Errorf("invalid private key: %w", err) + } + + cmd := exec.Command("wg", "pubkey") + cmd.Stdin = strings.NewReader(privateKey) + output, err := cmd.CombinedOutput() + if err != nil { + return "", fmt.Errorf("wg pubkey failed: %w, output: %s", err, string(output)) + } + + key := string(output) + if err := ValidateKey(key); err != nil { + return "", fmt.Errorf("generated invalid public key: %w", err) + } + + // Store as temporary key + tempKeyPath := filepath.Join(TempDir, "public.key") + if err := storeTempKey(tempKeyPath, key); err != nil { + log.Printf("Warning: failed to store temporary public key: %v", err) + } + + return key, nil +} + +// GeneratePSK generates a new pre-shared key using wg genpsk +func GeneratePSK() (string, error) { + cmd := exec.Command("wg", "genpsk") + output, err := cmd.CombinedOutput() + if err != nil { + return "", fmt.Errorf("wg genpsk failed: %w, output: %s", err, string(output)) + } + + key := string(output) + if err := ValidateKey(key); err != nil { + return "", fmt.Errorf("generated invalid PSK: %w", err) + } + + // Store as temporary key + tempKeyPath := filepath.Join(TempDir, "psk.key") + if err := storeTempKey(tempKeyPath, key); err != nil { + log.Printf("Warning: failed to store temporary PSK: %v", err) + } + + return key, nil +} + +// GenerateKeyPair generates a complete WireGuard key pair (private + public) +func GenerateKeyPair() (*KeyPair, error) { + privateKey, err := GeneratePrivateKey() + if err != nil { + return nil, err + } + + publicKey, err := GeneratePublicKey(privateKey) + if err != nil { + return nil, fmt.Errorf("failed to generate public key: %w", err) + } + + return &KeyPair{ + PrivateKey: privateKey, + PublicKey: publicKey, + }, nil +} + +// ValidateKey validates that a key is properly formatted (44 base64 characters) +func ValidateKey(key string) error { + // Trim whitespace + key = strings.TrimSpace(key) + + // Check length (44 base64 characters for 32 bytes) + if len(key) != Base64KeyLength { + return fmt.Errorf("invalid key length: expected %d characters, got %d", Base64KeyLength, len(key)) + } + + // Verify it's valid base64 + decoded, err := base64.StdEncoding.DecodeString(key) + if err != nil { + return fmt.Errorf("invalid base64 encoding: %w", err) + } + + // Verify decoded length is 32 bytes + if len(decoded) != KeyLength { + return fmt.Errorf("invalid decoded key length: expected %d bytes, got %d", KeyLength, len(decoded)) + } + + return nil +} + +// StoreKey atomically writes a key to a file with 0600 permissions +func StoreKey(path string, key string) error { + // Validate key before storing + if err := ValidateKey(key); err != nil { + return fmt.Errorf("invalid key: %w", err) + } + + // Trim whitespace + key = strings.TrimSpace(key) + + // Create parent directories if needed + dir := filepath.Dir(path) + if err := os.MkdirAll(dir, 0755); err != nil { + return fmt.Errorf("failed to create directory %s: %w", dir, err) + } + + // Write to temporary file + tempPath := path + ".tmp" + if err := os.WriteFile(tempPath, []byte(key), 0600); err != nil { + return fmt.Errorf("failed to write temp file %s: %w", tempPath, err) + } + + // Atomic rename + if err := os.Rename(tempPath, path); err != nil { + os.Remove(tempPath) // Clean up temp file on failure + return fmt.Errorf("failed to rename temp file to %s: %w", path, err) + } + + return nil +} + +// LoadKey reads a key from a file and validates it +func LoadKey(path string) (string, error) { + data, err := os.ReadFile(path) + if err != nil { + return "", fmt.Errorf("failed to read key file %s: %w", path, err) + } + + key := strings.TrimSpace(string(data)) + if err := ValidateKey(key); err != nil { + return "", fmt.Errorf("invalid key in file %s: %w", path, err) + } + + return key, nil +} + +// storeTempKey stores a temporary key and tracks it for cleanup +func storeTempKey(path string, key string) error { + // Create temp directory if needed + if err := os.MkdirAll(TempDir, 0700); err != nil { + return fmt.Errorf("failed to create temp directory %s: %w", TempDir, err) + } + + // Trim whitespace + key = strings.TrimSpace(key) + + // Write to file with 0600 permissions + if err := os.WriteFile(path, []byte(key), 0600); err != nil { + return fmt.Errorf("failed to write temp key to %s: %w", path, err) + } + + // Track for cleanup + tempKeysMutex.Lock() + tempKeys[path] = true + tempKeysMutex.Unlock() + + return nil +} + +// CleanupTempKeys removes all temporary keys +func CleanupTempKeys() error { + tempKeysMutex.Lock() + defer tempKeysMutex.Unlock() + + var cleanupErrors []string + + for path := range tempKeys { + if err := os.Remove(path); err != nil && !os.IsNotExist(err) { + cleanupErrors = append(cleanupErrors, fmt.Sprintf("%s: %v", path, err)) + log.Printf("Warning: failed to remove temp key %s: %v", path, err) + } + delete(tempKeys, path) + } + + // Also attempt to clean up the temp directory if empty + if _, err := os.ReadDir(TempDir); err == nil { + if err := os.Remove(TempDir); err != nil && !os.IsNotExist(err) { + log.Printf("Warning: failed to remove temp directory %s: %v", TempDir, err) + } + } + + if len(cleanupErrors) > 0 { + return fmt.Errorf("cleanup errors: %s", strings.Join(cleanupErrors, "; ")) + } + + return nil +} diff --git a/internal/wireguard/status.go b/internal/wireguard/status.go new file mode 100644 index 0000000..4cba2cb --- /dev/null +++ b/internal/wireguard/status.go @@ -0,0 +1,204 @@ +package wireguard + +import ( + "bytes" + "fmt" + "os/exec" + "regexp" + "strconv" + "strings" + "time" +) + +const ( + // StatusConnected indicates a peer has an active connection + StatusConnected = "Connected" + // StatusDisconnected indicates a peer is not connected + StatusDisconnected = "Disconnected" +) + +// PeerStatus represents the status of a WireGuard peer +type PeerStatus struct { + PublicKey string `json:"public_key"` + Endpoint string `json:"endpoint"` + AllowedIPs string `json:"allowed_ips"` + LatestHandshake time.Time `json:"latest_handshake"` + TransferRx string `json:"transfer_rx"` + TransferTx string `json:"transfer_tx"` + Status string `json:"status"` // "Connected" or "Disconnected" +} + +// GetClientStatus checks if a specific client is connected +// Returns "Connected" if the peer appears in the active peers list, "Disconnected" otherwise +func GetClientStatus(publicKey string) (string, error) { + peers, err := GetAllPeers() + if err != nil { + return StatusDisconnected, fmt.Errorf("failed to get peer status: %w", err) + } + + for _, peer := range peers { + if peer.PublicKey == publicKey { + return peer.Status, nil + } + } + + return StatusDisconnected, nil +} + +// GetAllPeers retrieves all peers with their current status from WireGuard +func GetAllPeers() ([]PeerStatus, error) { + output, err := exec.Command("wg", "show", "wg0").Output() + if err != nil { + return nil, fmt.Errorf("failed to execute wg show: %w", err) + } + + return parsePeersOutput(string(output)), nil +} + +// parsePeersOutput parses the output of 'wg show wg0' command +func parsePeersOutput(output string) []PeerStatus { + var peers []PeerStatus + var currentPeer *PeerStatus + var handshake string + var transfer string + + lines := strings.Split(output, "\n") + peerLineRegex := regexp.MustCompile(`^peer:\s*(.+)$`) + handshakeRegex := regexp.MustCompile(`^latest handshake:\s*(.+)\s+ago$`) + transferRegex := regexp.MustCompile(`^transfer:\s*(.+),\s+(.+)$`) + + for _, line := range lines { + line = strings.TrimSpace(line) + + // Check for new peer + if match := peerLineRegex.FindStringSubmatch(line); match != nil { + // Save previous peer if exists + if currentPeer != nil { + peers = append(peers, finalizePeerStatus(currentPeer, handshake, transfer)) + } + + // Start new peer + currentPeer = &PeerStatus{ + PublicKey: match[1], + } + handshake = "" + transfer = "" + continue + } + + if currentPeer == nil { + continue + } + + // Parse endpoint + if strings.HasPrefix(line, "endpoint:") { + currentPeer.Endpoint = strings.TrimSpace(strings.TrimPrefix(line, "endpoint:")) + } + + // Parse allowed ips + if strings.HasPrefix(line, "allowed ips:") { + currentPeer.AllowedIPs = strings.TrimSpace(strings.TrimPrefix(line, "allowed ips:")) + } + + // Parse latest handshake + if match := handshakeRegex.FindStringSubmatch(line); match != nil { + handshake = match[1] + } + + // Parse transfer + if match := transferRegex.FindStringSubmatch(line); match != nil { + transfer = fmt.Sprintf("%s, %s", match[1], match[2]) + } + } + + // Don't forget the last peer + if currentPeer != nil { + peers = append(peers, finalizePeerStatus(currentPeer, handshake, transfer)) + } + + return peers +} + +// finalizePeerStatus determines the peer's status based on handshake time +func finalizePeerStatus(peer *PeerStatus, handshake string, transfer string) PeerStatus { + peer.TransferRx = "" + peer.TransferTx = "" + + // Parse transfer + if transfer != "" { + parts := strings.Split(transfer, ", ") + if len(parts) == 2 { + // Extract received and sent values + rxParts := strings.Fields(parts[0]) + if len(rxParts) >= 2 { + peer.TransferRx = strings.Join(rxParts[:2], " ") + } + txParts := strings.Fields(parts[1]) + if len(txParts) >= 2 { + peer.TransferTx = strings.Join(txParts[:2], " ") + } + } + } + + // Determine status based on handshake + if handshake != "" { + peer.LatestHandshake = parseHandshake(handshake) + // Peer is considered connected if handshake is recent (within 3 minutes) + if time.Since(peer.LatestHandshake) < 3*time.Minute { + peer.Status = StatusConnected + } else { + peer.Status = StatusDisconnected + } + } else { + peer.Status = StatusDisconnected + } + + return *peer +} + +// parseHandshake converts handshake string to time.Time +func parseHandshake(handshake string) time.Time { + now := time.Now() + parts := strings.Fields(handshake) + + for i, part := range parts { + if strings.HasSuffix(part, "second") || strings.HasSuffix(part, "seconds") { + if val, err := strconv.Atoi(strings.TrimSuffix(part, "s")); err == nil { + return now.Add(-time.Duration(val) * time.Second) + } + } + if strings.HasSuffix(part, "minute") || strings.HasSuffix(part, "minutes") { + if val, err := strconv.Atoi(strings.TrimSuffix(part, "s")); err == nil { + return now.Add(-time.Duration(val) * time.Minute) + } + } + if strings.HasSuffix(part, "hour") || strings.HasSuffix(part, "hours") { + if val, err := strconv.Atoi(strings.TrimSuffix(part, "s")); err == nil { + return now.Add(-time.Duration(val) * time.Hour) + } + } + if strings.HasSuffix(part, "day") || strings.HasSuffix(part, "days") { + if val, err := strconv.Atoi(strings.TrimSuffix(part, "s")); err == nil { + return now.Add(-time.Duration(val) * 24 * time.Hour) + } + } + // Handle "ago" word + if i > 0 && (part == "ago" || part == "ago,") { + // Continue parsing time units + } + } + + return now.Add(-time.Hour) // Default to 1 hour ago if parsing fails +} + +// CheckInterface verifies if the WireGuard interface exists and is accessible +func CheckInterface(interfaceName string) error { + cmd := exec.Command("wg", "show", interfaceName) + var out bytes.Buffer + cmd.Stdout = &out + err := cmd.Run() + if err != nil { + return fmt.Errorf("wireguard interface '%s' not accessible: %w", interfaceName, err) + } + return nil +} diff --git a/internal/wireguard/tea_messages.go b/internal/wireguard/tea_messages.go new file mode 100644 index 0000000..243e6d9 --- /dev/null +++ b/internal/wireguard/tea_messages.go @@ -0,0 +1,27 @@ +package wireguard + +import ( + "time" + + tea "github.com/charmbracelet/bubbletea" +) + +// StatusTickMsg is sent when it's time to refresh the status +type StatusTickMsg struct{} + +// RefreshStatusMsg is sent when manual refresh is triggered +type RefreshStatusMsg struct{} + +// Tick returns a tea.Cmd that will send StatusTickMsg at the specified interval +func Tick(interval int) tea.Cmd { + return tea.Tick(time.Duration(interval)*time.Second, func(t time.Time) tea.Msg { + return StatusTickMsg{} + }) +} + +// ManualRefresh returns a tea.Cmd to trigger immediate status refresh +func ManualRefresh() tea.Cmd { + return func() tea.Msg { + return RefreshStatusMsg{} + } +} diff --git a/test-wg-install.sh b/test-wg-install.sh new file mode 100755 index 0000000..b70f9b1 --- /dev/null +++ b/test-wg-install.sh @@ -0,0 +1,124 @@ +#!/usr/bin/env bash +# Test script for wg-install.sh validation functions + +set -euo pipefail + +# Override log location for tests +export WGI_LOG_FILE="/tmp/wg-admin-install-test.log" + +# Source the main script to get functions +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +source "${SCRIPT_DIR}/wg-install.sh" + +# Test counter +TESTS_RUN=0 +TESTS_PASSED=0 +TESTS_FAILED=0 + +# Test function +run_test() { + local test_name="$1" + local expected="$2" + shift 2 + local result + + TESTS_RUN=$((TESTS_RUN + 1)) + + # Run function and capture output and exit code + if result=$("$@" 2>&1); then + local exit_code=$? + else + local exit_code=$? + fi + + if [[ "$expected" == "success" && "$exit_code" -eq 0 ]]; then + TESTS_PASSED=$((TESTS_PASSED + 1)) + echo "✓ PASS: $test_name" + return 0 + elif [[ "$expected" == "failure" && "$exit_code" -ne 0 ]]; then + TESTS_PASSED=$((TESTS_PASSED + 1)) + echo "✓ PASS: $test_name" + return 0 + else + TESTS_FAILED=$((TESTS_FAILED + 1)) + echo "✗ FAIL: $test_name" + echo " Expected: $expected" + echo " Exit code: $exit_code" + echo " Output: $result" + return 1 + fi +} + +echo "=== Running wg-install.sh Tests ===" +echo "" + +# Test validate_server_domain +echo "Testing validate_server_domain..." +run_test "Valid domain" success validate_server_domain "vpn.example.com" +run_test "Valid subdomain" success validate_server_domain "wg.vpn.example.com" +run_test "Valid domain with hyphen" success validate_server_domain "my-vpn.example.com" +run_test "Empty domain" failure validate_server_domain "" +run_test "Invalid domain - spaces" failure validate_server_domain "vpn example.com" +run_test "Invalid domain - special chars" failure validate_server_domain "vpn@example.com" +echo "" + +# Test validate_port_range +echo "Testing validate_port_range..." +run_test "Valid default port" success validate_port_range "51820" +run_test "Valid port 80" success validate_port_range "80" +run_test "Valid port 443" success validate_port_range "443" +run_test "Valid max port" success validate_port_range "65535" +run_test "Empty port" failure validate_port_range "" +run_test "Invalid port - negative" failure validate_port_range "-1" +run_test "Invalid port - zero" failure validate_port_range "0" +run_test "Invalid port - too high" failure validate_port_range "65536" +run_test "Invalid port - non-numeric" failure validate_port_range "abc" +echo "" + +# Test validate_cidr (IPv4) +echo "Testing validate_cidr (IPv4)..." +run_test "Valid IPv4 /24" success validate_cidr "10.10.69.0/24" false +run_test "Valid IPv4 /32" success validate_cidr "192.168.1.1/32" false +run_test "Valid IPv4 /16" success validate_cidr "172.16.0.0/16" false +run_test "Empty IPv4 CIDR" failure validate_cidr "" false +run_test "Invalid IPv4 - no prefix" failure validate_cidr "10.10.69.0" false +run_test "Invalid IPv4 - prefix too large" failure validate_cidr "10.10.69.0/33" false +run_test "Invalid IPv4 - negative prefix" failure validate_cidr "10.10.69.0/-1" false +run_test "Invalid IPv4 - bad IP" failure validate_cidr "256.1.1.1/24" false +echo "" + +# Test validate_cidr (IPv6) +echo "Testing validate_cidr (IPv6)..." +run_test "Valid IPv6 /64" success validate_cidr "fd69:dead:beef:69::/64" true +run_test "Valid IPv6 /128" success validate_cidr "fd69:dead:beef:69::1/128" true +run_test "Valid IPv6 /48" success validate_cidr "fd69:dead::/48" true +run_test "Empty IPv6 CIDR" failure validate_cidr "" true +run_test "Invalid IPv6 - no prefix" failure validate_cidr "fd69:dead:beef:69::" true +run_test "Invalid IPv6 - prefix too large" failure validate_cidr "fd69:dead:beef:69::/129" true +echo "" + +# Test validate_dns_servers +echo "Testing validate_dns_servers..." +run_test "Valid single DNS" success validate_dns_servers "8.8.8.8" +run_test "Valid multiple DNS" success validate_dns_servers "8.8.8.8, 8.8.4.4" +run_test "Valid DNS with spaces" success validate_dns_servers "8.8.8.8, 1.1.1.1" +run_test "Empty DNS" success validate_dns_servers "" +run_test "Invalid DNS - bad format" failure validate_dns_servers "8.8.8" +run_test "Invalid DNS - special chars" failure validate_dns_servers "dns.example.com" +echo "" + +# Summary +echo "" +echo "=== Test Summary ===" +echo "Tests run: $TESTS_RUN" +echo "Tests passed: $TESTS_PASSED" +echo "Tests failed: $TESTS_FAILED" +echo "" + +if [[ $TESTS_FAILED -eq 0 ]]; then + echo "All tests passed! ✓" + exit 0 +else + echo "Some tests failed! ✗" + exit 1 +fi diff --git a/wg-install.sh b/wg-install.sh new file mode 100755 index 0000000..4ba774d --- /dev/null +++ b/wg-install.sh @@ -0,0 +1,921 @@ +#!/usr/bin/env bash +# +# wg-install.sh - WireGuard VPN Server Installation Script +# +# This script handles the complete installation of a WireGuard VPN server on Debian 13. +# It includes dependency checks, package installation, firewall setup (nftables), +# server key generation, interface initialization, and systemd service setup. +# +# Settings can be provided via interactive prompts or environment variables prefixed with WGI_ +# (e.g., WGI_SERVER_DOMAIN, WGI_WG_PORT, WGI_VPN_IPV4_RANGE, etc.) + +set -euo pipefail + +# Default Configuration (can be overridden by environment variables with WGI_ prefix) +WGI_SERVER_DOMAIN="${WGI_SERVER_DOMAIN:-}" +WGI_WG_PORT="${WGI_WG_PORT:-51820}" +WGI_VPN_IPV4_RANGE="${WGI_VPN_IPV4_RANGE:-10.10.69.0/24}" +WGI_VPN_IPV6_RANGE="${WGI_VPN_IPV6_RANGE:-fd69:dead:beef:69::/64}" +WGI_WG_INTERFACE="${WGI_WG_INTERFACE:-wg0}" +WGI_DNS_SERVERS="${WGI_DNS_SERVERS:-8.8.8.8, 8.8.4.4}" +WGI_LOG_FILE="${WGI_LOG_FILE:-/var/log/wg-admin-install.log}" +WGI_MIN_DISK_SPACE_MB=100 + +# Derived paths +CONF_D_DIR="/etc/wireguard/conf.d" +SERVER_CONF="${CONF_D_DIR}/server.conf" +CLIENT_OUTPUT_DIR="/etc/wireguard/clients" +WG_CONFIG="/etc/wireguard/${WGI_WG_INTERFACE}.conf" +BACKUP_DIR="/etc/wg-admin/backups" + +# Global variables for cleanup and rollback +TEMP_DIR="" +ROLLBACK_BACKUP_DIR="" +ROLLBACK_NEEDED=false +PUBLIC_INTERFACE="" +SERVER_PRIVATE_KEY="" +SERVER_PUBLIC_KEY="" + +# ============================================================================ +# Logging Functions +# ============================================================================ + +log() { + local level="$1" + shift + local message="$*" + local timestamp=$(date '+%Y-%m-%d %H:%M:%S') + echo "[${timestamp}] [${level}] ${message}" | tee -a "${WGI_LOG_FILE}" +} + +log_info() { + log "INFO" "$@" +} + +log_error() { + log "ERROR" "$@" >&2 +} + +log_warn() { + log "WARN" "$@" +} + +# ============================================================================ +# Cleanup and Error Handling +# ============================================================================ + +cleanup_handler() { + local exit_code=$? + + # Remove temporary directories + if [[ -n "${TEMP_DIR}" ]] && [[ -d "${TEMP_DIR}" ]]; then + log_info "Cleaning up temporary directory: ${TEMP_DIR}" + rm -rf "${TEMP_DIR}" + fi + + # Rollback on failure if needed + if [[ ${ROLLBACK_NEEDED} == true ]] && [[ ${exit_code} -ne 0 ]]; then + log_error "Installation failed, attempting rollback..." + rollback_installation + fi + + exit ${exit_code} +} + +# Set up traps for cleanup +trap cleanup_handler EXIT INT TERM HUP + +# ============================================================================ +# Input Validation Functions +# ============================================================================ + +validate_dns_servers() { + local dns="$1" + if [[ -z "$dns" ]]; then + return 0 # Empty DNS is allowed + fi + + # Split by comma and validate each DNS server + IFS=',' read -ra dns_array <<< "$dns" + for dns_server in "${dns_array[@]}"; do + dns_server=$(echo "$dns_server" | xargs) # Trim whitespace + if [[ ! "$dns_server" =~ ^[0-9]+\.[0-9]+\.[0-9]+\.[0-9]+$ ]]; then + log_error "Invalid DNS server format: ${dns_server}" + return 1 + fi + done + + return 0 +} + +validate_port_range() { + local port="$1" + if [[ -z "$port" ]]; then + log_error "Port cannot be empty" + return 1 + fi + if [[ ! "$port" =~ ^[0-9]+$ ]]; then + log_error "Port must be a number" + return 1 + fi + if [[ "$port" -lt 1 ]] || [[ "$port" -gt 65535 ]]; then + log_error "Port must be between 1 and 65535" + return 1 + fi + return 0 +} + +validate_cidr() { + local cidr="$1" + local is_ipv6="$2" + + if [[ -z "$cidr" ]]; then + log_error "CIDR cannot be empty" + return 1 + fi + + if [[ "$is_ipv6" == "true" ]]; then + # Basic IPv6 CIDR validation + if [[ ! "$cidr" =~ ^[0-9a-fA-F:]+/[0-9]+$ ]]; then + log_error "Invalid IPv6 CIDR format: ${cidr}" + return 1 + fi + local prefix="${cidr#*/}" + if [[ "$prefix" -lt 0 ]] || [[ "$prefix" -gt 128 ]]; then + log_error "Invalid IPv6 prefix length: ${prefix}" + return 1 + fi + else + # IPv4 CIDR validation + if [[ ! "$cidr" =~ ^[0-9]+\.[0-9]+\.[0-9]+\.[0-9]+/[0-9]+$ ]]; then + log_error "Invalid IPv4 CIDR format: ${cidr}" + return 1 + fi + local ip="${cidr%/*}" + local prefix="${cidr#*/}" + + # Validate each octet + IFS='.' read -ra octets <<< "$ip" + for octet in "${octets[@]}"; do + if [[ "$octet" -lt 0 ]] || [[ "$octet" -gt 255 ]]; then + log_error "Invalid IPv4 octet: ${octet} in ${cidr}" + return 1 + fi + done + + if [[ "$prefix" -lt 0 ]] || [[ "$prefix" -gt 32 ]]; then + log_error "Invalid IPv4 prefix length: ${prefix}" + return 1 + fi + fi + + return 0 +} + +validate_server_domain() { + local domain="$1" + + if [[ -z "$domain" ]]; then + log_error "Server domain cannot be empty" + return 1 + fi + + # Basic domain validation (alphanumeric, hyphens, dots) + if [[ ! "$domain" =~ ^[a-zA-Z0-9]([a-zA-Z0-9\-]{0,61}[a-zA-Z0-9])?(\.[a-zA-Z0-9]([a-zA-Z0-9\-]{0,61}[a-zA-Z0-9])?)*$ ]]; then + log_error "Invalid server domain format: ${domain}" + return 1 + fi + + return 0 +} + +# ============================================================================ +# Pre-installation Validation +# ============================================================================ + +pre_install_validation() { + log_info "Running pre-installation validation..." + + # Check root privileges + if [[ $EUID -ne 0 ]]; then + log_error "This script must be run as root" + exit 1 + fi + + # Check disk space (at least MIN_DISK_SPACE_MB free) + local free_space_kb=$(df / | awk 'NR==2 {print $4}') + local free_space_mb=$((free_space_kb / 1024)) + if [[ ${free_space_mb} -lt ${WGI_MIN_DISK_SPACE_MB} ]]; then + log_error "Insufficient disk space (${free_space_mb}MB free, ${WGI_MIN_DISK_SPACE_MB}MB required)" + exit 1 + fi + log_info "Disk space validation passed (${free_space_mb}MB free)" + + # Check port availability + if ss -ulnp 2>/dev/null | grep -q ":${WGI_WG_PORT}"; then + log_error "Port ${WGI_WG_PORT} is already in use" + echo "ERROR: WireGuard port ${WGI_WG_PORT} is already in use." + echo "Action: Stop the service using port ${WGI_WG_PORT} or change the port." + echo "To find what's using the port: 'sudo ss -tulnp | grep ${WGI_WG_PORT}'" + exit 1 + fi + log_info "Port ${WGI_WG_PORT} is available" + + log_info "Pre-installation validation passed" +} + +# ============================================================================ +# Backup Functions +# ============================================================================ + +backup_config() { + local operation="${1:-manual}" + local timestamp=$(date +"%Y%m%d_%H%M%S") + local backup_name="wg-backup-${operation}-${timestamp}" + local backup_path="${BACKUP_DIR}/${backup_name}" + + log_info "Creating configuration backup: ${backup_name}" + + # Create backup directory if it doesn't exist + mkdir -p "${BACKUP_DIR}" + chmod 700 "${BACKUP_DIR}" + + # Create backup directory + mkdir -p "${backup_path}" + + # Backup WireGuard configurations + if [[ -d "/etc/wireguard" ]]; then + cp -a /etc/wireguard "${backup_path}/" 2>/dev/null || true + fi + + # Backup nftables configurations + if [[ -f "/etc/nftables.conf" ]]; then + cp -a /etc/nftables.conf "${backup_path}/" 2>/dev/null || true + fi + + if [[ -d "/etc/nftables.d" ]]; then + cp -a /etc/nftables.d "${backup_path}/" 2>/dev/null || true + fi + + # Backup sysctl configuration + if [[ -f "/etc/sysctl.d/99-wireguard.conf" ]]; then + cp -a /etc/sysctl.d/99-wireguard.conf "${backup_path}/" 2>/dev/null || true + fi + + # Create backup metadata + cat > "${backup_path}/backup-info.txt" </dev/null | grep -E '^wg-backup-')) + + # If we have more than 10 backups, remove oldest backups + if [[ ${#backups[@]} -gt 10 ]]; then + log_info "Applying retention policy (keeping last 10 backups)..." + local to_remove=(${backups[@]:10}) + + for backup in "${to_remove[@]}"; do + log_info "Removing old backup: ${backup}" + rm -rf "${BACKUP_DIR}/${backup}" + done + fi +} + +# ============================================================================ +# Rollback Functions +# ============================================================================ + +rollback_installation() { + log_warn "Rolling back installation..." + + # Stop services + systemctl stop wg-quick@${WGI_WG_INTERFACE}.service 2>/dev/null || true + systemctl disable wg-quick@${WGI_WG_INTERFACE}.service 2>/dev/null || true + + # Restore from backup if exists + if [[ -d "${ROLLBACK_BACKUP_DIR}" ]]; then + log_info "Restoring from backup: ${ROLLBACK_BACKUP_DIR}" + + if [[ -f "${ROLLBACK_BACKUP_DIR}/wireguard.conf" ]]; then + cp "${ROLLBACK_BACKUP_DIR}/wireguard.conf" "${WG_CONFIG}" + fi + + if [[ -f "${ROLLBACK_BACKUP_DIR}/nftables.conf" ]]; then + cp "${ROLLBACK_BACKUP_DIR}/nftables.conf" /etc/nftables.conf + nft -f /etc/nftables.conf 2>/dev/null || true + fi + + if [[ -d "${ROLLBACK_BACKUP_DIR}/conf.d" ]]; then + rm -rf "${CONF_D_DIR}" + cp -r "${ROLLBACK_BACKUP_DIR}/conf.d" "${CONF_D_DIR}" + chmod 700 "${CONF_D_DIR}" + fi + fi + + # Bring down interface + wg-quick down ${WGI_WG_INTERFACE} 2>/dev/null || true + + log_warn "Rollback complete. Please review the logs and try again." +} + +# ============================================================================ +# Interactive Configuration +# ============================================================================ + +prompt_configuration() { + echo "" + echo "=== WireGuard VPN Server Configuration ===" + echo "" + echo "Press Enter to accept the default value (shown in brackets)" + echo "" + + # Server Domain + if [[ -z "${WGI_SERVER_DOMAIN}" ]]; then + read -p "Server Domain [e.g., vpn.example.com]: " input_domain + WGI_SERVER_DOMAIN="${input_domain}" + else + echo "Server Domain: ${WGI_SERVER_DOMAIN}" + fi + + # Port + if [[ -z "${WGI_WG_PORT}" ]] || [[ "${WGI_WG_PORT}" == "51820" ]]; then + read -p "WireGuard Port [${WGI_WG_PORT}]: " input_port + if [[ -n "$input_port" ]]; then + WGI_WG_PORT="$input_port" + fi + else + echo "WireGuard Port: ${WGI_WG_PORT}" + fi + + # IPv4 Range + if [[ -z "${WGI_VPN_IPV4_RANGE}" ]] || [[ "${WGI_VPN_IPV4_RANGE}" == "10.10.69.0/24" ]]; then + read -p "VPN IPv4 Range [${WGI_VPN_IPV4_RANGE}]: " input_ipv4 + if [[ -n "$input_ipv4" ]]; then + WGI_VPN_IPV4_RANGE="$input_ipv4" + fi + else + echo "VPN IPv4 Range: ${WGI_VPN_IPV4_RANGE}" + fi + + # IPv6 Range + if [[ -z "${WGI_VPN_IPV6_RANGE}" ]] || [[ "${WGI_VPN_IPV6_RANGE}" == "fd69:dead:beef:69::/64" ]]; then + read -p "VPN IPv6 Range [${WGI_VPN_IPV6_RANGE}]: " input_ipv6 + if [[ -n "$input_ipv6" ]]; then + WGI_VPN_IPV6_RANGE="$input_ipv6" + fi + else + echo "VPN IPv6 Range: ${WGI_VPN_IPV6_RANGE}" + fi + + # DNS Servers + if [[ -z "${WGI_DNS_SERVERS}" ]] || [[ "${WGI_DNS_SERVERS}" == "8.8.8.8, 8.8.4.4" ]]; then + read -p "DNS Servers [${WGI_DNS_SERVERS}]: " input_dns + if [[ -n "$input_dns" ]]; then + WGI_DNS_SERVERS="$input_dns" + fi + else + echo "DNS Servers: ${WGI_DNS_SERVERS}" + fi + + # Interface Name + if [[ -z "${WGI_WG_INTERFACE}" ]] || [[ "${WGI_WG_INTERFACE}" == "wg0" ]]; then + read -p "WireGuard Interface [${WGI_WG_INTERFACE}]: " input_interface + if [[ -n "$input_interface" ]]; then + WGI_WG_INTERFACE="$input_interface" + # Update derived paths + WG_CONFIG="/etc/wireguard/${WGI_WG_INTERFACE}.conf" + fi + else + echo "WireGuard Interface: ${WGI_WG_INTERFACE}" + fi + + echo "" + echo "=== Configuration Summary ===" + echo "Server Domain: ${WGI_SERVER_DOMAIN}" + echo "WireGuard Port: ${WGI_WG_PORT}" + echo "VPN IPv4 Range: ${WGI_VPN_IPV4_RANGE}" + echo "VPN IPv6 Range: ${WGI_VPN_IPV6_RANGE}" + echo "DNS Servers: ${WGI_DNS_SERVERS}" + echo "WireGuard Interface: ${WGI_WG_INTERFACE}" + echo "" + + read -p "Proceed with installation? (yes/no): " confirm + if [[ "${confirm}" != "yes" ]]; then + log_info "Installation cancelled by user" + exit 0 + fi +} + +# ============================================================================ +# Package Installation +# ============================================================================ + +install_packages() { + log_info "Installing packages..." + echo "Updating package lists..." + apt-get update -qq + + echo "Installing WireGuard and dependencies..." + apt-get install -y wireguard wireguard-tools qrencode nftables + + log_info "Package installation complete" +} + +# ============================================================================ +# Cleanup Existing Installation +# ============================================================================ + +cleanup_existing_installation() { + log_info "Checking for existing WireGuard installation..." + + # Stop and disable WireGuard service + if systemctl is-enabled --quiet wg-quick@${WGI_WG_INTERFACE}.service 2>/dev/null || systemctl list-unit-files | grep -q wg-quick@${WGI_WG_INTERFACE}.service; then + log_info "Stopping WireGuard service..." + echo "Stopping WireGuard service..." + systemctl stop wg-quick@${WGI_WG_INTERFACE}.service 2>/dev/null || true + echo "Disabling WireGuard service..." + systemctl disable wg-quick@${WGI_WG_INTERFACE}.service 2>/dev/null || true + fi + + # Stop and disable old client loader service + if systemctl is-enabled --quiet wg-load-clients.service 2>/dev/null || systemctl list-unit-files | grep -q wg-load-clients.service; then + echo "Removing old WireGuard client loader service..." + systemctl disable wg-load-clients.service 2>/dev/null || true + systemctl stop wg-load-clients.service 2>/dev/null || true + rm -f /etc/systemd/system/wg-load-clients.service + systemctl daemon-reload 2>/dev/null || true + fi + + # Bring down interface if running + if wg show ${WGI_WG_INTERFACE} &>/dev/null; then + echo "WireGuard interface is active, bringing down..." + wg-quick down ${WGI_WG_INTERFACE} 2>/dev/null || true + fi + + # Remove existing configuration directories + if [[ -d "/etc/wireguard" ]]; then + echo "Removing existing WireGuard configuration..." + rm -rf /etc/wireguard + fi + + if [[ -d "/etc/clients.d" ]]; then + echo "Removing old client directory..." + rm -rf /etc/clients.d + fi + + if [[ -d "/etc/wireguard/conf.d" ]]; then + echo "Removing existing config.d directory..." + rm -rf /etc/wireguard/conf.d + fi + + if [[ -d "/etc/wireguard/peer.d" ]]; then + echo "Removing old peer.d directory..." + rm -rf /etc/wireguard/peer.d + fi + + if [[ -d "/root/wireguard-clients" ]]; then + echo "Removing old client config directory..." + rm -rf /root/wireguard-clients + fi + + # Flush nftables rules + if command -v nft &> /dev/null; then + echo "Flushing nftables rules..." + nft flush ruleset 2>/dev/null || true + fi + + echo "Cleanup complete." +} + +# ============================================================================ +# Network Configuration +# ============================================================================ + +detect_public_interface() { + PUBLIC_INTERFACE=$(ip route get 8.8.8.8 | grep -oP 'dev \K\S+' | head -1) + log_info "Public interface detected: ${PUBLIC_INTERFACE}" + echo "Public interface: ${PUBLIC_INTERFACE}" +} + +enable_ip_forwarding() { + log_info "Enabling IP forwarding..." + echo "Enabling IP forwarding..." + + cat > /etc/sysctl.d/99-wireguard.conf < /etc/nftables.d/wireguard.conf < /etc/nftables.conf < "$temp_private" + wg pubkey < "$temp_private" > "$temp_public" + SERVER_PRIVATE_KEY=$(cat "$temp_private") + SERVER_PUBLIC_KEY=$(cat "$temp_public") + + # Move to final location with proper permissions + mv "$temp_private" server_private.key + mv "$temp_public" server_public.key + chmod 600 server_private.key + chmod 644 server_public.key + + log_info "Server keys generated and secured" +} + +# ============================================================================ +# WireGuard Configuration +# ============================================================================ + +configure_wireguard() { + log_info "Configuring WireGuard interface..." + echo "Configuring WireGuard interface..." + + # Create config.d directory + mkdir -p "${CONF_D_DIR}" + + # Create server.conf in conf.d directory + cat > "${SERVER_CONF}" < "${WG_CONFIG}" </dev/null; then + wg-quick down ${WGI_WG_INTERFACE} 2>/dev/null || true + fi + + systemctl start wg-quick@${WGI_WG_INTERFACE}.service + + # Wait for WireGuard to initialize + sleep 2 + + # Verify WireGuard is running + if ! systemctl is-active --quiet wg-quick@${WGI_WG_INTERFACE}.service; then + log_error "WireGuard service failed to start" + echo "=== Service Status ===" + systemctl status wg-quick@${WGI_WG_INTERFACE}.service + exit 1 + fi + + # Verify correct listening port + ACTUAL_PORT=$(wg show ${WGI_WG_INTERFACE} listen-port) + if [[ "$ACTUAL_PORT" != "$WGI_WG_PORT" ]]; then + log_error "WireGuard listening on port $ACTUAL_PORT instead of $WGI_WG_PORT" + echo "=== Config File ===" + grep ListenPort "${WG_CONFIG}" + echo "=== Running Interface ===" + wg show ${WGI_WG_INTERFACE} + exit 1 + fi + + echo "WireGuard started successfully on port $ACTUAL_PORT" + log_info "WireGuard service started successfully" +} + +# ============================================================================ +# Main Installation Function +# ============================================================================ + +main() { + log_info "=== WireGuard VPN Installation for Debian 13 ===" + + # Validate configuration + validate_server_domain "${WGI_SERVER_DOMAIN}" + validate_port_range "${WGI_WG_PORT}" + validate_cidr "${WGI_VPN_IPV4_RANGE}" false + validate_cidr "${WGI_VPN_IPV6_RANGE}" true + validate_dns_servers "${WGI_DNS_SERVERS}" + + # Interactive configuration if not all values set + if [[ -z "${WGI_SERVER_DOMAIN}" ]] || \ + [[ -z "${WGI_WG_PORT}" ]] || \ + [[ -z "${WGI_VPN_IPV4_RANGE}" ]] || \ + [[ -z "${WGI_VPN_IPV6_RANGE}" ]]; then + prompt_configuration + fi + + # Print configuration + echo "" + echo "=== WireGuard VPN Installation for Debian 13 ===" + echo "Server: ${WGI_SERVER_DOMAIN}" + echo "IPv4 VPN range: ${WGI_VPN_IPV4_RANGE}" + echo "IPv6 VPN range: ${WGI_VPN_IPV6_RANGE}" + echo "Port: ${WGI_WG_PORT}" + echo "Interface: ${WGI_WG_INTERFACE}" + echo "" + + # Pre-installation validation + pre_install_validation + + # Auto-backup before install (only if config exists) + if [[ -f "${WG_CONFIG}" ]] || [[ -f "/etc/nftables.conf" ]]; then + backup_config "install-pre" + fi + + # Enable rollback flag + ROLLBACK_NEEDED=true + + # Create backup directory for potential rollback + ROLLBACK_BACKUP_DIR=$(mktemp -d) + log_info "Created rollback backup directory: ${ROLLBACK_BACKUP_DIR}" + + # Backup existing configs if they exist + if [[ -f "${WG_CONFIG}" ]]; then + cp "${WG_CONFIG}" "${ROLLBACK_BACKUP_DIR}/wireguard.conf" + log_info "Backed up existing WireGuard config" + fi + if [[ -f "/etc/nftables.conf" ]]; then + cp /etc/nftables.conf "${ROLLBACK_BACKUP_DIR}/nftables.conf" + log_info "Backed up existing nftables config" + fi + if [[ -d "${CONF_D_DIR}" ]]; then + cp -r "${CONF_D_DIR}" "${ROLLBACK_BACKUP_DIR}/conf.d" + log_info "Backed up existing client configs" + fi + + # Cleanup existing installation + cleanup_existing_installation + + # Install packages + install_packages + + # Detect public interface + detect_public_interface + + # Enable IP forwarding + enable_ip_forwarding + + # Configure nftables firewall + configure_nftables + + # Generate server keys + generate_server_keys + + # Configure WireGuard + configure_wireguard + + # Setup systemd service + setup_systemd_service + + # Disable rollback flag - installation successful + ROLLBACK_NEEDED=false + + # Clean up rollback backup directory + if [[ -d "${ROLLBACK_BACKUP_DIR}" ]]; then + rm -rf "${ROLLBACK_BACKUP_DIR}" + log_info "Cleaned up rollback backup directory" + fi + + # Installation complete + echo "" + log_info "=== Installation Complete ===" + echo "=== Installation Complete ===" + echo "" + echo "Server Public Key: ${SERVER_PUBLIC_KEY}" + echo "Endpoint: ${WGI_SERVER_DOMAIN}:${WGI_WG_PORT}" + echo "VPN IPv4 Range: ${WGI_VPN_IPV4_RANGE}" + echo "VPN IPv6 Range: ${WGI_VPN_IPV6_RANGE}" + echo "" + echo "Use 'wireguard.sh add [--psk]' to add clients" + echo "Use 'wireguard.sh list' to list clients" + echo "Configs will be merged from: ${CONF_D_DIR}" + echo " - ${SERVER_CONF} (server interface)" + echo " - ${CONF_D_DIR}/client-*.conf (client peers)" + echo "" + echo "Check status: wg show" + echo "System status: systemctl status wg-quick@${WGI_WG_INTERFACE}" + echo "" + + log_info "Installation completed successfully" +} + +# ============================================================================ +# Entry Point +# ============================================================================ + +# Only run main if script is executed directly (not sourced) +if [[ "${BASH_SOURCE[0]}" == "${0}" ]]; then + # Show usage if help requested + if [[ "${1:-}" == "-h" ]] || [[ "${1:-}" == "--help" ]] || [[ "${1:-}" == "help" ]]; then + echo "Usage: $0 [options]" + echo "" + echo "Environment Variables (prefix with WGI_):" + echo " WGI_SERVER_DOMAIN Server domain (e.g., vpn.example.com)" + echo " WGI_WG_PORT WireGuard UDP port (default: 51820)" + echo " WGI_VPN_IPV4_RANGE IPv4 VPN range (default: 10.10.69.0/24)" + echo " WGI_VPN_IPV6_RANGE IPv6 VPN range (default: fd69:dead:beef:69::/64)" + echo " WGI_DNS_SERVERS DNS servers (default: 8.8.8.8, 8.8.4.4)" + echo " WGI_WG_INTERFACE WireGuard interface name (default: wg0)" + echo "" + echo "Examples:" + echo " # Interactive installation" + echo " sudo $0" + echo "" + echo " # Non-interactive with environment variables" + echo " sudo WGI_SERVER_DOMAIN=vpn.example.com $0" + echo "" + echo " # Custom port and VPN range" + echo " sudo WGI_SERVER_DOMAIN=vpn.example.com WGI_WG_PORT=443 $0" + exit 0 + fi + + # Run main installation + main +fi diff --git a/wireguard.sh b/wireguard.sh index 865971e..825cfc5 100755 --- a/wireguard.sh +++ b/wireguard.sh @@ -1,49 +1,612 @@ #!/usr/bin/env bash set -euo pipefail -# Configuration -SERVER_DOMAIN="velkhana.calmcacil.dev" -WG_PORT="51820" -VPN_IPV4_RANGE="10.10.69.0/24" -VPN_IPV6_RANGE="fd69:dead:beef:69::/64" -WG_INTERFACE="wg0" +# Get absolute path to script (must be done before any cd commands) +SCRIPT_PATH=$(realpath "$0") + +# Default Configuration (can be overridden by /etc/wg-admin/config.conf or environment variables) +SERVER_DOMAIN="${SERVER_DOMAIN:-}" +WG_PORT="${WG_PORT:-51820}" +VPN_IPV4_RANGE="${VPN_IPV4_RANGE:-10.10.69.0/24}" +VPN_IPV6_RANGE="${VPN_IPV6_RANGE:-fd69:dead:beef:69::/64}" +WG_INTERFACE="${WG_INTERFACE:-wg0}" CONF_D_DIR="/etc/wireguard/conf.d" SERVER_CONF="${CONF_D_DIR}/server.conf" CLIENT_OUTPUT_DIR="/etc/wireguard/clients" WG_CONFIG="/etc/wireguard/${WG_INTERFACE}.conf" -DNS_SERVERS="8.8.8.8, 8.8.4.4" +DNS_SERVERS="${DNS_SERVERS:-8.8.8.8, 8.8.4.4}" +LOG_FILE="/var/log/wireguard-admin.log" +MIN_DISK_SPACE_MB=100 -# Get absolute path to script (must be done before any cd commands) -SCRIPT_PATH=$(realpath "$0") +# Load configuration file if it exists +load_config() { + local config_file="/etc/wg-admin/config.conf" + if [[ -f "$config_file" ]]; then + while IFS='=' read -r key value; do + [[ "$key" =~ ^[[:space:]]*# ]] && continue + [[ -z "$key" ]] && continue + key=$(echo "$key" | xargs) + value=$(echo "$value" | xargs) + export "$key=$value" + done < "$config_file" + fi +} + +# Load config at script start +load_config + +# Helper functions to parse CIDR ranges +get_ipv4_network() { + local cidr="$1" + echo "$cidr" | cut -d'/' -f1 | sed 's/0$//' | sed 's/\.$//' +} + +get_ipv6_network() { + local cidr="$1" + echo "$cidr" | cut -d'/' -f1 | sed 's/::$//' +} + + +# Global variables for cleanup and rollback +TEMP_DIR="" +BACKUP_DIR="" +ROLLBACK_NEEDED=false + +# Logging functions with timestamps +log() { + local level="$1" + shift + local message="$*" + local timestamp=$(date '+%Y-%m-%d %H:%M:%S') + echo "[${timestamp}] [${level}] ${message}" | tee -a "${LOG_FILE}" +} + +log_info() { + log "INFO" "$@" +} + +log_error() { + log "ERROR" "$@" >&2 +} + +log_warn() { + log "WARN" "$@" +} + +# Cleanup trap handler +cleanup_handler() { + local exit_code=$? + + # Remove temporary directories + if [[ -n "${TEMP_DIR}" ]] && [[ -d "${TEMP_DIR}" ]]; then + log_info "Cleaning up temporary directory: ${TEMP_DIR}" + rm -rf "${TEMP_DIR}" + fi + + # Rollback on failure if needed + if [[ ${ROLLBACK_NEEDED} == true ]] && [[ ${exit_code} -ne 0 ]]; then + log_error "Operation failed, attempting rollback..." + rollback_installation + fi + + exit ${exit_code} +} + +# Set up traps for cleanup +trap cleanup_handler EXIT INT TERM HUP usage() { echo "Usage: $0 [options]" echo "" echo "Commands:" echo " install Install WireGuard VPN server" - echo " add Add a new client" + echo " add [--psk] Add a new client (optional PSK for extra security)" echo " list List all clients" echo " remove Remove a client" echo " show Show client configuration" echo " qr Show QR code for client" echo "" echo "Options:" + echo " --psk Use pre-shared key (PSK) for additional security" echo " -h, --help Show this help" + echo "" + echo "Examples:" + echo " $0 install" + echo " $0 add my-phone --psk" + echo " $0 list" exit 1 } check_root() { if [[ $EUID -ne 0 ]]; then - echo "ERROR: This script must be run as root" + log_error "This script must be run as root" + echo "ERROR: This script must be run as root. Please use 'sudo $0 $1'" >&2 exit 1 fi } +# Pre-install validation +pre_install_validation() { + log_info "Running pre-installation validation..." + + # Check root is already done by check_root() + + # Check disk space (at least MIN_DISK_SPACE_MB free) + local free_space_kb=$(df / | awk 'NR==2 {print $4}') + local free_space_mb=$((free_space_kb / 1024)) + if [[ ${free_space_mb} -lt ${MIN_DISK_SPACE_MB} ]]; then + log_error "Insufficient disk space (${free_space_mb}MB free, ${MIN_DISK_SPACE_MB}MB required)" + echo "ERROR: Insufficient disk space. At least ${MIN_DISK_SPACE_MB}MB free space is required." >&2 + echo "Action: Free up disk space or use a different filesystem." >&2 + exit 1 + fi + log_info "Disk space validation passed (${free_space_mb}MB free)" + + # Check port availability + if ss -ulnp | grep -q ":${WG_PORT}"; then + log_error "Port ${WG_PORT} is already in use" + echo "ERROR: WireGuard port ${WG_PORT} is already in use." >&2 + echo "Action: Stop the service using port ${WG_PORT} or change WG_PORT in the script." >&2 + echo "To find what's using the port: 'sudo ss -tulnp | grep ${WG_PORT}'" >&2 + exit 1 + fi + log_info "Port ${WG_PORT} is available" + + # Check required commands will be installed + log_info "Pre-installation validation passed" +} + +# Rollback installation on failure +rollback_installation() { + log_warn "Rolling back installation..." + + # Stop services + systemctl stop wg-quick@wg0.service 2>/dev/null || true + systemctl disable wg-quick@wg0.service 2>/dev/null || true + + # Restore from backup if exists + if [[ -d "${BACKUP_DIR}" ]]; then + log_info "Restoring from backup: ${BACKUP_DIR}" + + if [[ -f "${BACKUP_DIR}/wireguard.conf" ]]; then + cp "${BACKUP_DIR}/wireguard.conf" "${WG_CONFIG}" + fi + + if [[ -f "${BACKUP_DIR}/nftables.conf" ]]; then + cp "${BACKUP_DIR}/nftables.conf" /etc/nftables.conf + nft -f /etc/nftables.conf 2>/dev/null || true + fi + + if [[ -d "${BACKUP_DIR}/conf.d" ]]; then + rm -rf "${CONF_D_DIR}" + cp -r "${BACKUP_DIR}/conf.d" "${CONF_D_DIR}" + chmod 700 "${CONF_D_DIR}" + fi + fi + + # Bring down interface + wg-quick down wg0 2>/dev/null || true + + log_warn "Rollback complete. Please review the logs and try again." +} + + +# Backup and rollback functions +BACKUP_DIR="/etc/wg-admin/backups" +BACKUP_RETENTION=10 + +backup_config() { + local operation="${1:-manual}" + local timestamp=$(date +"%Y%m%d_%H%M%S") + local backup_name="wg-backup-${operation}-${timestamp}" + local backup_path="${BACKUP_DIR}/${backup_name}" + + log_info "Creating configuration backup: ${backup_name}" + + # Create backup directory if it doesn't exist + mkdir -p "${BACKUP_DIR}" + chmod 700 "${BACKUP_DIR}" + + # Create backup directory + mkdir -p "${backup_path}" + + # Backup WireGuard configurations + if [[ -d "/etc/wireguard" ]]; then + cp -a /etc/wireguard "${backup_path}/" 2>/dev/null || true + fi + + # Backup nftables configurations + if [[ -f "/etc/nftables.conf" ]]; then + cp -a /etc/nftables.conf "${backup_path}/" 2>/dev/null || true + fi + + if [[ -d "/etc/nftables.d" ]]; then + cp -a /etc/nftables.d "${backup_path}/" 2>/dev/null || true + fi + + # Backup sysctl configuration + if [[ -f "/etc/sysctl.d/99-wireguard.conf" ]]; then + cp -a /etc/sysctl.d/99-wireguard.conf "${backup_path}/" 2>/dev/null || true + fi + + # Create backup metadata + cat > "${backup_path}/backup-info.txt" <" + exit 1 + fi + + if [[ ! -d "${backup_path}" ]]; then + log_error "Backup not found: ${backup_path}" + echo "ERROR: Backup not found: ${backup_path}" + exit 1 + fi + + if [[ ! -f "${backup_path}/backup-info.txt" ]]; then + log_error "Invalid backup directory (missing backup-info.txt)" + echo "ERROR: Invalid backup directory (missing backup-info.txt)" + exit 1 + fi + + echo "=== WARNING: Configuration Restore ===" + echo "This will replace current WireGuard and firewall configurations" + echo "Backup: ${backup_path}" + echo "" + read -p "Continue? (yes/no): " confirm + + if [[ "${confirm}" != "yes" ]]; then + log_info "Restore cancelled by user" + echo "Restore cancelled" + exit 0 + fi + + log_info "Restoring configuration from: ${backup_path}" + + # Stop WireGuard service + if systemctl is-active --quiet wg-quick@wg0.service 2>/dev/null; then + log_info "Stopping WireGuard service..." + systemctl stop wg-quick@wg0.service 2>/dev/null || true + fi + + # Stop nftables service + if systemctl is-active --quiet nftables.service 2>/dev/null; then + log_info "Stopping nftables service..." + systemctl stop nftables.service 2>/dev/null || true + fi + + # Create a backup of current state before restore + backup_config "pre-restore-$(date +%s)" + + # Restore WireGuard configurations + if [[ -d "${backup_path}/wireguard" ]]; then + log_info "Restoring WireGuard configurations..." + rm -rf /etc/wireguard + cp -a "${backup_path}/wireguard" /etc/ + fi + + # Restore nftables configurations + if [[ -f "${backup_path}/nftables.conf" ]]; then + log_info "Restoring nftables configuration..." + cp -a "${backup_path}/nftables.conf" /etc/ + fi + + if [[ -d "${backup_path}/nftables.d" ]]; then + log_info "Restoring nftables.d directory..." + rm -rf /etc/nftables.d + cp -a "${backup_path}/nftables.d" /etc/ + fi + + # Restore sysctl configuration + if [[ -f "${backup_path}/99-wireguard.conf" ]]; then + log_info "Restoring sysctl configuration..." + cp -a "${backup_path}/99-wireguard.conf" /etc/sysctl.d/ + sysctl -p /etc/sysctl.d/99-wireguard.conf 2>/dev/null || true + fi + + # Start services + log_info "Starting nftables service..." + systemctl start nftables.service + + log_info "Starting WireGuard service..." + systemctl start wg-quick@wg0.service + + echo "" + log_info "Configuration restored successfully" + echo "Configuration restored successfully from: ${backup_path}" + echo "" + echo "Note: A pre-restore backup was created for safety" +} + +apply_retention_policy() { + if [[ ! -d "${BACKUP_DIR}" ]]; then + return + fi + + # List all backups sorted by modification time (oldest first) + local backups=($(ls -t "${BACKUP_DIR}" 2>/dev/null | grep -E '^wg-backup-')) + + # If we have more than retention count, remove oldest backups + if [[ ${#backups[@]} -gt ${BACKUP_RETENTION} ]]; then + log_info "Applying retention policy (keeping last ${BACKUP_RETENTION} backups)..." + local to_remove=(${backups[@]:${BACKUP_RETENTION}}) + + for backup in "${to_remove[@]}"; do + log_info "Removing old backup: ${backup}" + rm -rf "${BACKUP_DIR}/${backup}" + done + fi +} + +# Validation Functions + +validate_client_name() { + local name="$1" + if [[ -z "$name" ]]; then + echo "ERROR: Client name cannot be empty" + return 1 + fi + if [[ ! "$name" =~ ^[a-zA-Z0-9_-]+$ ]]; then + echo "ERROR: Client name must contain only alphanumeric characters, hyphens, and underscores" + return 1 + fi + if [[ ${#name} -gt 64 ]]; then + echo "ERROR: Client name cannot exceed 64 characters" + return 1 + fi + return 0 +} + +validate_ip_availability() { + local ipv4="$1" + local ipv6="$2" + + # Check if IPv4 is already in use + if [[ -n "$ipv4" ]]; then + if grep -q "AllowedIPs = ${ipv4}/" "${CONF_D_DIR}"/client-*.conf 2>/dev/null; then + echo "ERROR: IPv4 address ${ipv4} is already in use" + return 1 + fi + fi + + # Check if IPv6 is already in use + if [[ -n "$ipv6" ]]; then + if grep -q "AllowedIPs = ${ipv6}/" "${CONF_D_DIR}"/client-*.conf 2>/dev/null; then + echo "ERROR: IPv6 address ${ipv6} is already in use" + return 1 + fi + fi + + return 0 +} + +validate_dns_servers() { + local dns="$1" + if [[ -z "$dns" ]]; then + return 0 # Empty DNS is allowed + fi + + # Split by comma and validate each DNS server + IFS=',' read -ra dns_array <<< "$dns" + for dns_server in "${dns_array[@]}"; do + dns_server=$(echo "$dns_server" | xargs) # Trim whitespace + if [[ ! "$dns_server" =~ ^[0-9]+\.[0-9]+\.[0-9]+\.[0-9]+$ ]]; then + echo "ERROR: Invalid DNS server format: ${dns_server}" + return 1 + fi + done + + return 0 +} + +validate_port_range() { + local port="$1" + if [[ -z "$port" ]]; then + echo "ERROR: Port cannot be empty" + return 1 + fi + if [[ ! "$port" =~ ^[0-9]+$ ]]; then + echo "ERROR: Port must be a number" + return 1 + fi + if [[ "$port" -lt 1 ]] || [[ "$port" -gt 65535 ]]; then + echo "ERROR: Port must be between 1 and 65535" + return 1 + fi + return 0 +} + +validate_config_syntax() { + local config_file="$1" + if [[ ! -f "$config_file" ]]; then + echo "ERROR: Config file not found: ${config_file}" + return 1 + fi + + local in_interface_section=false + local in_peer_section=false + local has_interface=false + local has_peer=false + local seen_public_keys=() + + while IFS= read -r line; do + # Skip comments and empty lines + [[ "$line" =~ ^[[:space:]]*# ]] && continue + [[ -z "$line" ]] && continue + + # Check section headers + if [[ "$line" =~ ^\[Interface\] ]]; then + if [[ "$in_interface_section" == true ]]; then + echo "ERROR: Duplicate [Interface] section in ${config_file}" + return 1 + fi + if [[ "$in_peer_section" == true ]]; then + echo "ERROR: [Interface] section found after [Peer] section in ${config_file}" + return 1 + fi + in_interface_section=true + has_interface=true + continue + fi + + if [[ "$line" =~ ^\[Peer\] ]]; then + if [[ "$in_interface_section" == false ]] && [[ "$has_interface" == false ]]; then + echo "ERROR: [Peer] section found before [Interface] section in ${config_file}" + return 1 + fi + in_peer_section=true + has_peer=true + continue + fi + + # Validate keys within sections + if [[ "$line" =~ ^[[:space:]]*PrivateKey[[:space:]]*=[[:space:]]*(.+)$ ]]; then + local key="${BASH_REMATCH[1]}" + if [[ "$key" =~ ^[[:space:]]*# ]]; then + continue + fi + if [[ ! "$key" =~ ^[A-Za-z0-9+/]{42}[A-Za-z0-9+/=]{2}$ ]] && [[ ! "$key" =~ ^[A-Za-z0-9+/]{43}[A-Za-z0-9+/=]$ ]]; then + echo "ERROR: Invalid PrivateKey format in ${config_file}" + return 1 + fi + fi + + if [[ "$line" =~ ^[[:space:]]*PublicKey[[:space:]]*=[[:space:]]*(.+)$ ]]; then + local key="${BASH_REMATCH[1]}" + if [[ "$key" =~ ^[[:space:]]*# ]]; then + continue + fi + if [[ ! "$key" =~ ^[A-Za-z0-9+/]{42}[A-Za-z0-9+/=]{2}$ ]] && [[ ! "$key" =~ ^[A-Za-z0-9+/]{43}[A-Za-z0-9+/=]$ ]]; then + echo "ERROR: Invalid PublicKey format in ${config_file}" + return 1 + fi + # Check for duplicate public keys + for existing_key in "${seen_public_keys[@]}"; do + if [[ "$existing_key" == "$key" ]]; then + echo "ERROR: Duplicate public key found in ${config_file}" + return 1 + fi + done + seen_public_keys+=("$key") + fi + + if [[ "$line" =~ ^[[:space:]]*PresharedKey[[:space:]]*=[[:space:]]*(.+)$ ]]; then + local key="${BASH_REMATCH[1]}" + if [[ "$key" =~ ^[[:space:]]*# ]]; then + continue + fi + if [[ ! "$key" =~ ^[A-Za-z0-9+/]{42}[A-Za-z0-9+/=]{2}$ ]] && [[ ! "$key" =~ ^[A-Za-z0-9+/]{43}[A-Za-z0-9+/=]$ ]]; then + echo "ERROR: Invalid PresharedKey format in ${config_file}" + return 1 + fi + fi + + if [[ "$line" =~ ^[[:space:]]*Address[[:space:]]*=[[:space:]]*(.+)$ ]]; then + local addresses="${BASH_REMATCH[1]}" + IFS=',' read -ra addr_array <<< "$addresses" + for addr in "${addr_array[@]}"; do + addr=$(echo "$addr" | xargs) # Trim whitespace + if [[ "$addr" =~ ^[0-9]+\.[0-9]+\.[0-9]+\.[0-9]+/[0-9]+$ ]]; then + # IPv4 CIDR validation + local ip="${addr%/*}" + local prefix="${addr#*/}" + if [[ "$prefix" -lt 0 ]] || [[ "$prefix" -gt 32 ]]; then + echo "ERROR: Invalid IPv4 prefix length: ${prefix}" + return 1 + fi + elif [[ "$addr" =~ ^[0-9a-fA-F:]+/[0-9]+$ ]]; then + # IPv6 CIDR validation (basic) + local prefix="${addr#*/}" + if [[ "$prefix" -lt 0 ]] || [[ "$prefix" -gt 128 ]]; then + echo "ERROR: Invalid IPv6 prefix length: ${prefix}" + return 1 + fi + else + echo "ERROR: Invalid Address format: ${addr}" + return 1 + fi + done + fi + + if [[ "$line" =~ ^[[:space:]]*AllowedIPs[[:space:]]*=[[:space:]]*(.+)$ ]]; then + local allowed_ips="${BASH_REMATCH[1]}" + IFS=',' read -ra ip_array <<< "$allowed_ips" + for ip in "${ip_array[@]}"; do + ip=$(echo "$ip" | xargs) # Trim whitespace + if [[ "$ip" =~ ^[0-9]+\.[0-9]+\.[0-9]+\.[0-9]+/[0-9]+$ ]]; then + local prefix="${ip#*/}" + if [[ "$prefix" -lt 0 ]] || [[ "$prefix" -gt 32 ]]; then + echo "ERROR: Invalid IPv4 prefix in AllowedIPs: ${ip}" + return 1 + fi + elif [[ "$ip" =~ ^[0-9a-fA-F:]+/[0-9]+$ ]]; then + local prefix="${ip#*/}" + if [[ "$prefix" -lt 0 ]] || [[ "$prefix" -gt 128 ]]; then + echo "ERROR: Invalid IPv6 prefix in AllowedIPs: ${ip}" + return 1 + fi + else + echo "ERROR: Invalid AllowedIPs format: ${ip}" + return 1 + fi + done + fi + + if [[ "$line" =~ ^[[:space:]]*DNS[[:space:]]*=[[:space:]]*(.+)$ ]]; then + local dns_servers="${BASH_REMATCH[1]}" + IFS=',' read -ra dns_array <<< "$dns_servers" + for dns in "${dns_array[@]}"; do + dns=$(echo "$dns" | xargs) # Trim whitespace + if [[ ! "$dns" =~ ^[0-9]+\.[0-9]+\.[0-9]+\.[0-9]+$ ]]; then + echo "ERROR: Invalid DNS server format: ${dns}" + return 1 + fi + done + fi + + if [[ "$line" =~ ^[[:space:]]*ListenPort[[:space:]]*=[[:space:]]*(.+)$ ]]; then + local port="${BASH_REMATCH[1]}" + if [[ ! "$port" =~ ^[0-9]+$ ]]; then + echo "ERROR: ListenPort must be a number" + return 1 + fi + if [[ "$port" -lt 1 ]] || [[ "$port" -gt 65535 ]]; then + echo "ERROR: ListenPort must be between 1 and 65535" + return 1 + fi + fi + + done < "$config_file" + + return 0 +} + get_next_ipv4() { - local used_ips=$(grep -h "AllowedIPs = 10.10.69." "${CONF_D_DIR}"/client-*.conf 2>/dev/null | cut -d' ' -f3 | cut -d'/' -f1 | sort -V | uniq) + local network=$(get_ipv4_network "${VPN_IPV4_RANGE}") + local used_ips=$(grep -h "AllowedIPs = ${network}" "${CONF_D_DIR}"/client-*.conf 2>/dev/null | cut -d' ' -f3 | cut -d'/' -f1 | sort -V | uniq) for i in {2..254}; do - local ip="10.10.69.${i}" + local ip="${network}${i}" if ! echo "$used_ips" | grep -q "^${ip}$"; then echo "${ip}" return @@ -55,10 +618,11 @@ get_next_ipv4() { } get_next_ipv6() { - local used_ips=$(grep -h "AllowedIPs = fd69:dead:beef:69::" "${CONF_D_DIR}"/client-*.conf 2>/dev/null | grep -o 'fd69:dead:beef:69::[0-9a-f]*' | sort | uniq) + local network=$(get_ipv6_network "${VPN_IPV6_RANGE}") + local used_ips=$(grep -h "AllowedIPs = ${network}" "${CONF_D_DIR}"/client-*.conf 2>/dev/null | grep -o "${network}[0-9a-f]*" | sort | uniq) for i in {2..254}; do - local ip=$(printf "fd69:dead:beef:69::%x" $i) + local ip=$(printf "${network}%x" $i) if ! echo "$used_ips" | grep -q "^${ip}$"; then echo "${ip}" return @@ -72,6 +636,21 @@ get_next_ipv6() { cmd_install() { check_root + # Check if SERVER_DOMAIN is set + if [[ -z "$SERVER_DOMAIN" ]]; then + echo "ERROR: SERVER_DOMAIN is not set." + echo "" + echo "Please create a configuration file:" + echo " sudo mkdir -p /etc/wg-admin" + echo " sudo cp config.example /etc/wg-admin/config.conf" + echo " sudo nano /etc/wg-admin/config.conf" + echo "" + echo "Or set it via environment variable:" + echo " SERVER_DOMAIN=vpn.example.com sudo ./wireguard.sh install" + exit 1 + fi + + log_info "=== WireGuard VPN Installation for Debian 13 ===" echo "=== WireGuard VPN Installation for Debian 13 ===" echo "Server: ${SERVER_DOMAIN}" echo "IPv4 VPN range: ${VPN_IPV4_RANGE}" @@ -79,9 +658,36 @@ cmd_install() { echo "Port: ${WG_PORT}" echo "" + # Auto-backup before install (only if config exists) + if [[ -f "${WG_CONFIG}" ]] || [[ -f "/etc/nftables.conf" ]]; then + backup_config "install-pre" + fi + + # Enable rollback flag + ROLLBACK_NEEDED=true + + # Create backup directory for potential rollback + BACKUP_DIR=$(mktemp -d) + log_info "Created backup directory: ${BACKUP_DIR}" + + # Backup existing configs if they exist + if [[ -f "${WG_CONFIG}" ]]; then + cp "${WG_CONFIG}" "${BACKUP_DIR}/wireguard.conf" + log_info "Backed up existing WireGuard config" + fi + if [[ -f "/etc/nftables.conf" ]]; then + cp /etc/nftables.conf "${BACKUP_DIR}/nftables.conf" + log_info "Backed up existing nftables config" + fi + if [[ -d "${CONF_D_DIR}" ]]; then + cp -r "${CONF_D_DIR}" "${BACKUP_DIR}/conf.d" + log_info "Backed up existing client configs" + fi + # Reset and cleanup existing WireGuard installation - echo "Checking for existing WireGuard installation..." + log_info "Checking for existing WireGuard installation..." if systemctl is-enabled --quiet wg-quick@wg0.service 2>/dev/null || systemctl list-unit-files | grep -q wg-quick@wg0.service; then + log_info "Stopping WireGuard service..." echo "Stopping WireGuard service..." systemctl stop wg-quick@wg0.service 2>/dev/null || true echo "Disabling WireGuard service..." @@ -164,18 +770,40 @@ EOF flush ruleset table inet wireguard { + chain prerouting { + type filter hook prerouting priority -150; + + # Rate limiting for SSH (3 connections per minute, burst 5) + tcp dport 22 ct state new limit rate 3/minute burst 5 packets accept + tcp dport 22 ct state new limit rate 3/minute burst 5 packets drop + + # Rate limiting for WireGuard (10 packets per second) + udp dport 51820 limit rate 10/second burst 20 packets accept + udp dport 51820 limit rate 10/second burst 20 packets drop + } + chain input { type filter hook input priority 0; policy drop; iifname lo accept + + # Connection tracking bypass for WireGuard UDP traffic + iifname "${PUBLIC_INTERFACE}" udp dport 51820 notrack + ct state established,related accept ct state invalid drop + # Allow SSH tcp dport 22 accept + + # Allow WireGuard UDP traffic (already tracked via notrack) udp dport 51820 accept + # ICMPv4 icmp type { echo-request, echo-reply } accept - icmpv6 type { echo-request, echo-reply, nd-neighbor-solicit, nd-neighbor-advert } accept + + # ICMPv6 - ensure neighbor discovery is allowed + icmpv6 type { echo-request, echo-reply, nd-neighbor-solicit, nd-neighbor-advert, nd-router-solicit, nd-router-advert } accept } chain forward { @@ -188,6 +816,9 @@ table inet wireguard { chain output { type filter hook output priority 0; policy accept; + + # Connection tracking bypass for WireGuard UDP traffic + oifname "${PUBLIC_INTERFACE}" udp dport 51820 notrack } } @@ -195,7 +826,10 @@ table ip nat { chain postrouting { type nat hook postrouting priority 100; policy accept; - oifname "${PUBLIC_INTERFACE}" ip saddr 10.10.69.0/24 masquerade + # TCP MSS clamping for MTU issues (clamp to 1360) + oifname "${PUBLIC_INTERFACE}" tcp flags syn tcp option maxseg size set 1360 + + oifname "${PUBLIC_INTERFACE}" ip saddr ${VPN_IPV4_RANGE} masquerade } } @@ -203,7 +837,10 @@ table ip6 nat { chain postrouting { type nat hook postrouting priority 100; policy accept; - oifname "${PUBLIC_INTERFACE}" ip6 saddr fd69:dead:beef:69::/64 masquerade + # TCP MSS clamping for IPv6 MTU issues + oifname "${PUBLIC_INTERFACE}" tcp flags syn tcp option maxseg size set 1360 + + oifname "${PUBLIC_INTERFACE}" ip6 saddr ${VPN_IPV6_RANGE} masquerade } } EOF @@ -220,6 +857,14 @@ EOF chmod 600 /etc/nftables.conf chmod 600 /etc/nftables.d/wireguard.conf + # Validate nftables configuration before applying + echo "Validating nftables configuration..." + if ! nft check -f /etc/nftables.conf; then + echo "ERROR: nftables configuration validation failed" + exit 1 + fi + echo "nftables configuration is valid" + nft -f /etc/nftables.conf # Enable nftables service to load rules on boot @@ -238,12 +883,25 @@ EOF echo "nftables service started successfully" # Generate server keys + log_info "Generating server keys..." echo "Generating server keys..." mkdir -p /etc/wireguard cd /etc/wireguard - wg genkey | tee server_private.key | wg pubkey > server_public.key - SERVER_PRIVATE_KEY=$(cat server_private.key) - SERVER_PUBLIC_KEY=$(cat server_public.key) + + # Generate keys with atomic write + local temp_private=$(mktemp) + local temp_public=$(mktemp) + wg genkey > "$temp_private" + wg pubkey < "$temp_private" > "$temp_public" + SERVER_PRIVATE_KEY=$(cat "$temp_private") + SERVER_PUBLIC_KEY=$(cat "$temp_public") + + # Move to final location with proper permissions + mv "$temp_private" server_private.key + mv "$temp_public" server_public.key + chmod 600 server_private.key + chmod 644 server_public.key + log_info "Server keys generated and secured" # Create config.d directory mkdir -p "${CONF_D_DIR}" @@ -252,7 +910,7 @@ EOF cat > "${SERVER_CONF}" <' to add clients" + echo "Use '$0 add [--psk]' to add clients" echo "Use '$0 list' to list clients" echo "Configs will be merged from: ${CONF_D_DIR}" echo " - ${SERVER_CONF} (server interface)" @@ -321,52 +980,122 @@ EOF echo "System status: systemctl status wg-quick@wg0" echo "" echo "To manually reload configs: $0 load-clients" + + # Disable rollback flag - installation successful + ROLLBACK_NEEDED=false + + # Clean up backup directory + if [[ -d "${BACKUP_DIR}" ]]; then + rm -rf "${BACKUP_DIR}" + log_info "Cleaned up backup directory" + fi + + log_info "Installation completed successfully" } cmd_add() { check_root - local name=$1 + local name="" + local use_psk=false + + # Parse arguments + while [[ $# -gt 0 ]]; do + case "$1" in + --psk) + use_psk=true + shift + ;; + *) + if [[ -z "$name" ]]; then + name="$1" + fi + shift + ;; + esac + done if [[ -z "$name" ]]; then - echo "ERROR: Client name required" + log_error "Client name required" + echo "ERROR: Client name required" >&2 usage fi + # Validate client name + if ! validate_client_name "$name"; then + exit 1 + fi + if [[ -f "${CONF_D_DIR}/client-${name}.conf" ]]; then - echo "ERROR: Client '${name}' already exists" + log_error "Client '${name}' already exists" + echo "ERROR: Client '${name}' already exists" >&2 exit 1 fi if [[ ! -f "/etc/wireguard/server_public.key" ]]; then - echo "ERROR: WireGuard server not installed. Run '$0 install' first" + log_error "WireGuard server not installed. Run '$0 install' first" + echo "ERROR: WireGuard server not installed. Run '$0 install' first" >&2 exit 1 fi local server_public_key=$(cat /etc/wireguard/server_public.key) - local client_keys_dir=$(mktemp -d) - pushd "$client_keys_dir" > /dev/null + log_info "Creating client '${name}'${use_psk:+ with PSK}" + + # Auto-backup before add + backup_config "add-${name}" + + # Use global TEMP_DIR for trap cleanup + TEMP_DIR=$(mktemp -d) + log_info "Created temporary directory: ${TEMP_DIR}" + + pushd "$TEMP_DIR" > /dev/null wg genkey | tee client_private.key | wg pubkey > client_public.key local client_private_key=$(cat client_private.key) local client_public_key=$(cat client_public.key) + + # Generate PSK if requested + local psk_line="" + if [[ "$use_psk" == true ]]; then + wg genpsk > client_psk.key + local client_psk=$(cat client_psk.key) + psk_line="PresharedKey = ${client_psk}" + log_info "Generated pre-shared key for additional security" + fi popd > /dev/null local client_ipv4=$(get_next_ipv4) local client_ipv6=$(get_next_ipv6) + # Validate IP availability + if ! validate_ip_availability "$client_ipv4" "$client_ipv6"; then + exit 1 + fi + + # Validate DNS servers + if ! validate_dns_servers "$DNS_SERVERS"; then + exit 1 + fi + mkdir -p "${CONF_D_DIR}" - cat > "${CONF_D_DIR}/client-${name}.conf" < "$temp_server_conf" < "$client_output" < "$temp_client_conf" < "${CLIENT_OUTPUT_DIR}/${name}.qr" chmod 600 "${CLIENT_OUTPUT_DIR}/${name}.qr" - rm -rf "$client_keys_dir" + # TEMP_DIR will be cleaned up by the trap handler + log_info "Temporary directory will be cleaned up by trap handler" "${SCRIPT_PATH}" load-clients - echo "Client '${name}' added successfully" + # Reset TEMP_DIR since cleanup is done + TEMP_DIR="" + + log_info "Client '${name}' added successfully" + echo "Client '${name}' added successfully${use_psk:+ with PSK}" echo "" echo "Client IPv4: ${client_ipv4}" echo "Client IPv6: ${client_ipv6}" echo "Config saved to: ${client_output}" echo "QR code saved to: ${CLIENT_OUTPUT_DIR}/${name}.qr" + if [[ "$use_psk" == true ]]; then + echo "Pre-shared key enabled for additional security" + fi } cmd_list() { @@ -414,7 +1154,7 @@ cmd_list() { for client_file in "${CONF_D_DIR}"/client-*.conf; do local name=$(basename "${client_file}" .conf | sed 's/^client-//') - local ipv4=$(grep "AllowedIPs = 10.10.69." "${client_file}" | grep -o '10.10.69\.[0-9]*' || echo "N/A") + local ipv4=$(grep "AllowedIPs = $(get_ipv4_network "${VPN_IPV4_RANGE}")" "${client_file}" | grep -o "$(get_ipv4_network "${VPN_IPV4_RANGE}")[0-9]*" || echo "N/A") local ipv6=$(grep "AllowedIPs = fd69:dead:beef:69::" "${client_file}" | grep -o 'fd69:dead:beef:69::[0-9a-f]*' || echo "N/A") local public_key=$(grep "PublicKey = " "${client_file}" | cut -d' ' -f3) @@ -433,23 +1173,35 @@ cmd_remove() { local name=$1 if [[ -z "$name" ]]; then - echo "ERROR: Client name required" + log_error "Client name required" + echo "ERROR: Client name required" >&2 usage fi + # Validate client name + if ! validate_client_name "$name"; then + exit 1 + fi + local client_file="${CONF_D_DIR}/client-${name}.conf" if [[ ! -f "$client_file" ]]; then - echo "ERROR: Client '${name}' not found" + log_error "Client '${name}' not found" + echo "ERROR: Client '${name}' not found" >&2 exit 1 fi + # Auto-backup before remove + backup_config "remove-${name}" + + log_info "Removing client '${name}'..." rm "$client_file" rm -f "${CLIENT_OUTPUT_DIR}/${name}.conf" rm -f "${CLIENT_OUTPUT_DIR}/${name}.qr" "${SCRIPT_PATH}" load-clients + log_info "Client '${name}' removed successfully" echo "Client '${name}' removed successfully" } @@ -531,6 +1283,13 @@ cmd_load_clients() { mv "${TEMP_CONFIG}" "${WG_CONFIG}" chmod 600 "${WG_CONFIG}" + # Validate config syntax + if ! validate_config_syntax "${WG_CONFIG}"; then + echo "ERROR: Configuration file validation failed" + rm -f "${WG_CONFIG}" + exit 1 + fi + # Check if WireGuard is running if wg show wg0 &>/dev/null; then # Reload WireGuard using systemctl @@ -559,7 +1318,8 @@ case "${1:-}" in cmd_install ;; add) - cmd_add "${2:-}" + shift # Remove 'add' from arguments + cmd_add "$@" ;; list) cmd_list @@ -580,7 +1340,8 @@ case "${1:-}" in usage ;; *) - echo "ERROR: Unknown command '${1}'" + log_error "Unknown command '${1}'" + echo "ERROR: Unknown command '${1}'" >&2 usage ;; esac