Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 83 additions & 12 deletions clients/clients.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@ package clients
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net"
"net/netip"
"os"
Expand Down Expand Up @@ -105,6 +107,12 @@ type WireGuardService struct {
netstackListener net.PacketConn
netstackListenerMu sync.Mutex
wgTesterServer *wgtester.Server
// Bandwidth check goroutine lifecycle
bandwidthCheckStop chan struct{}
bandwidthCheckWg sync.WaitGroup
bandwidthCheckMu sync.Mutex
// UAPI listener for native interface mode
uapiListener net.Listener
}

func NewWireGuardService(interfaceName string, port uint16, mtu int, host string, newtId string, wsClient *websocket.Client, dns string, useNativeInterface bool) (*WireGuardService, error) {
Expand Down Expand Up @@ -196,6 +204,9 @@ func (s *WireGuardService) Close() {
s.stopGetConfig = nil
}

// Stop the periodic bandwidth check goroutine
s.stopPeriodicBandwidthCheck()

// Stop the direct UDP relay first
s.StopDirectUDPRelay()

Expand All @@ -204,6 +215,12 @@ func (s *WireGuardService) Close() {
s.holePunchManager.Stop()
}

// Close UAPI listener (native interface mode) - this will cause the Accept goroutine to exit
if s.uapiListener != nil {
s.uapiListener.Close()
s.uapiListener = nil
}

s.mu.Lock()
defer s.mu.Unlock()

Expand Down Expand Up @@ -236,6 +253,20 @@ func (s *WireGuardService) Close() {
}
}

func (s *WireGuardService) startPeriodicBandwidthCheck() {
s.bandwidthCheckMu.Lock()
defer s.bandwidthCheckMu.Unlock()

if s.bandwidthCheckStop != nil {
close(s.bandwidthCheckStop)
s.bandwidthCheckWg.Wait()
}

s.bandwidthCheckStop = make(chan struct{})
s.bandwidthCheckWg.Add(1)
go s.periodicBandwidthCheck(s.bandwidthCheckStop)
}

func (s *WireGuardService) SetToken(token string) {
s.token = token
if s.holePunchManager != nil {
Expand Down Expand Up @@ -378,9 +409,17 @@ func (s *WireGuardService) runDirectUDPRelay(listener net.PacketConn) {

n, remoteAddr, err := listener.ReadFrom(buf)
if err != nil {
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
// Check for timeout first - this is normal operation
var netErr net.Error
if errors.As(err, &netErr) && netErr.Timeout() {
continue // Just a timeout, check for stop and try again
}
// Check for connection closed conditions - exit gracefully
if errors.Is(err, io.EOF) || errors.Is(err, net.ErrClosed) {
logger.Debug("Direct UDP relay connection closed, stopping")
return
}
// Check if we've been asked to stop
if s.directRelayStop != nil {
select {
case <-s.directRelayStop:
Expand Down Expand Up @@ -448,7 +487,9 @@ func (s *WireGuardService) LoadRemoteConfig() error {
}, 2*time.Second)

logger.Debug("Requesting WireGuard configuration from remote server")
go s.periodicBandwidthCheck()

// Restart the periodic bandwidth check for the current device lifecycle.
s.startPeriodicBandwidthCheck()

return nil
}
Expand Down Expand Up @@ -683,6 +724,16 @@ func (s *WireGuardService) ensureWireguardInterface(wgconfig WgConfig) error {
// Parse the IP address and CIDR mask
tunnelIP := netip.MustParseAddr(parts[0])

// Config refreshes can legitimately resend the same config. Reuse the
// existing interface instead of creating a second device/listener stack.
if s.device != nil {
if s.TunnelIP != "" && s.TunnelIP != tunnelIP.String() {
logger.Warn("WireGuard interface already initialized with tunnel IP %s; ignoring re-init request for %s", s.TunnelIP, tunnelIP.String())
}
s.mu.Unlock()
return nil
}

var err error

if s.useNativeInterface {
Expand Down Expand Up @@ -724,22 +775,23 @@ func (s *WireGuardService) ensureWireguardInterface(wgconfig WgConfig) error {
logger.Error("UAPI listen error: %v", err)
}

uapiListener, err := newtDevice.UapiListen(interfaceName, fileUAPI)
listener, err := newtDevice.UapiListen(interfaceName, fileUAPI)
if err != nil {
logger.Error("Failed to listen on uapi socket: %v", err)
os.Exit(1)
}
s.uapiListener = listener

go func() {
go func(listener net.Listener, dev *device.Device) {
for {
conn, err := uapiListener.Accept()
conn, err := listener.Accept()
if err != nil {

// Listener closed, exit goroutine
return
}
go s.device.IpcHandle(conn)
go dev.IpcHandle(conn)
}
}()
}(listener, s.device)
logger.Info("UAPI listener started")

// Configure WireGuard with private key
Expand Down Expand Up @@ -1110,17 +1162,36 @@ func (s *WireGuardService) handleUpdatePeer(msg websocket.WSMessage) {
logger.Info("Peer %s updated successfully", request.PublicKey)
}

func (s *WireGuardService) periodicBandwidthCheck() {
func (s *WireGuardService) periodicBandwidthCheck(stopCh <-chan struct{}) {
defer s.bandwidthCheckWg.Done()
ticker := time.NewTicker(10 * time.Second)
defer ticker.Stop()

for range ticker.C {
if err := s.reportPeerBandwidth(); err != nil {
logger.Info("Failed to report peer bandwidth: %v", err)
for {
select {
case <-stopCh:
logger.Debug("Stopping periodic bandwidth check")
return
case <-ticker.C:
if err := s.reportPeerBandwidth(); err != nil {
logger.Info("Failed to report peer bandwidth: %v", err)
}
}
}
}

// stopPeriodicBandwidthCheck stops the bandwidth check goroutine and waits for it to exit
func (s *WireGuardService) stopPeriodicBandwidthCheck() {
s.bandwidthCheckMu.Lock()
defer s.bandwidthCheckMu.Unlock()

if s.bandwidthCheckStop != nil {
close(s.bandwidthCheckStop)
s.bandwidthCheckWg.Wait()
s.bandwidthCheckStop = nil
}
}

func (s *WireGuardService) calculatePeerBandwidth() ([]PeerBandwidth, error) {
if s.device == nil {
return []PeerBandwidth{}, nil
Expand Down