From cb83b7c0d3ddbacdd6ff391bdcc405f8b2ad00f0 Mon Sep 17 00:00:00 2001
From: shuuri-labs <61762328+shuuri-labs@users.noreply.github.com>
Date: Fri, 28 Nov 2025 21:53:53 +0100
Subject: [PATCH] [relay] use exposed address for healthcheck TLS validation (#4872)

* fix(relay): use exposed address for healthcheck TLS validation

Healthcheck was using listen address (0.0.0.0) instead of exposed address (domain name) for certificate validation, causing validation to always fail.

Now correctly uses the exposed address where the TLS certificate is valid, matching real client connection behavior.

* - store exposedAddress directly in Relay struct instead of parsing on every call
- remove unused parseHostPort() function
- remove unused ListenAddress() method from ServiceChecker interface
- improve error logging with address context

* [relay/healthcheck] Remove QUIC health check logic, update WebSocket validation flow

Refactored health check logic by removing QUIC-specific connection validation and simplifying logic for WebSocket protocol. Adjusted certificate validation flow and improved handling of exposed addresses.

* [relay/healthcheck] Fix certificate validation status during health check --------- Co-authored-by: Maycon Santos --- relay/healthcheck/healthcheck.go | 44 ++++++++++++-------------------- relay/healthcheck/quic.go | 31 ---------------------- relay/healthcheck/ws.go | 12 +++++++-- relay/server/relay.go | 27 ++++++++++++-------- relay/server/server.go | 8 ++---- 5 files changed, 46 insertions(+), 76 deletions(-) delete mode 100644 relay/healthcheck/quic.go diff --git a/relay/healthcheck/healthcheck.go b/relay/healthcheck/healthcheck.go index eedd62394..6463843eb 100644 --- a/relay/healthcheck/healthcheck.go +++ b/relay/healthcheck/healthcheck.go @@ -6,14 +6,13 @@ import ( "errors" "net" "net/http" + "strings" "sync" "time" log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/relay/protocol" - "github.com/netbirdio/netbird/relay/server/listener/quic" - "github.com/netbirdio/netbird/relay/server/listener/ws" ) const ( @@ -27,7 +26,7 @@ const ( type ServiceChecker interface { ListenerProtocols() []protocol.Protocol - ListenAddress() string + ExposedAddress() string } type HealthStatus struct { @@ -135,7 +134,11 @@ func (s *Server) getHealthStatus(ctx context.Context) (*HealthStatus, bool) { } status.Listeners = listeners - if ok := s.validateCertificate(ctx); !ok { + if !strings.HasPrefix(s.config.ServiceChecker.ExposedAddress(), "rels") { + status.CertificateValid = false + } + + if ok := s.validateConnection(ctx); !ok { status.Status = statusUnhealthy status.CertificateValid = false healthy = false @@ -152,32 +155,18 @@ func (s *Server) validateListeners() ([]protocol.Protocol, bool) { return listeners, true } -func (s *Server) validateCertificate(ctx context.Context) bool { - listenAddress := s.config.ServiceChecker.ListenAddress() - if listenAddress == "" { - log.Warn("listen address is empty") +func (s *Server) validateConnection(ctx context.Context) bool { + exposedAddress := s.config.ServiceChecker.ExposedAddress() + if exposedAddress == "" { + 
log.Error("exposed address is empty, cannot validate certificate") return false } - dAddr := dialAddress(listenAddress) - - for _, proto := range s.config.ServiceChecker.ListenerProtocols() { - switch proto { - case ws.Proto: - if err := dialWS(ctx, dAddr); err != nil { - log.Errorf("failed to dial WebSocket listener: %v", err) - return false - } - case quic.Proto: - if err := dialQUIC(ctx, dAddr); err != nil { - log.Errorf("failed to dial QUIC listener: %v", err) - return false - } - default: - log.Warnf("unknown protocol for healthcheck: %s", proto) - return false - } + if err := dialWS(ctx, exposedAddress); err != nil { + log.Errorf("failed to dial WebSocket listener at %s: %v", exposedAddress, err) + return false } + return true } @@ -187,8 +176,9 @@ func dialAddress(listenAddress string) string { return listenAddress // fallback, might be invalid for dialing } + // When listening on all interfaces, show localhost for better readability if host == "" || host == "::" || host == "0.0.0.0" { - host = "0.0.0.0" + host = "localhost" } return net.JoinHostPort(host, port) diff --git a/relay/healthcheck/quic.go b/relay/healthcheck/quic.go deleted file mode 100644 index 1582edf7b..000000000 --- a/relay/healthcheck/quic.go +++ /dev/null @@ -1,31 +0,0 @@ -package healthcheck - -import ( - "context" - "crypto/tls" - "fmt" - "time" - - "github.com/quic-go/quic-go" - - tlsnb "github.com/netbirdio/netbird/shared/relay/tls" -) - -func dialQUIC(ctx context.Context, address string) error { - tlsConfig := &tls.Config{ - InsecureSkipVerify: false, // Keep certificate validation enabled - NextProtos: []string{tlsnb.NBalpn}, - } - - conn, err := quic.DialAddr(ctx, address, tlsConfig, &quic.Config{ - MaxIdleTimeout: 30 * time.Second, - KeepAlivePeriod: 10 * time.Second, - EnableDatagrams: true, - }) - if err != nil { - return fmt.Errorf("failed to connect to QUIC server: %w", err) - } - - _ = conn.CloseWithError(0, "availability check complete") - return nil -} diff --git 
a/relay/healthcheck/ws.go b/relay/healthcheck/ws.go index 49694356c..badd31219 100644 --- a/relay/healthcheck/ws.go +++ b/relay/healthcheck/ws.go @@ -3,6 +3,7 @@ package healthcheck import ( "context" "fmt" + "strings" "github.com/coder/websocket" @@ -10,12 +11,19 @@ import ( ) func dialWS(ctx context.Context, address string) error { - url := fmt.Sprintf("wss://%s%s", address, relay.WebSocketURLPath) + addressSplit := strings.Split(address, "/") + scheme := "ws" + if addressSplit[0] == "rels:" { + scheme = "wss" + } + url := fmt.Sprintf("%s://%s%s", scheme, addressSplit[2], relay.WebSocketURLPath) conn, resp, err := websocket.Dial(ctx, url, nil) if resp != nil { defer func() { - _ = resp.Body.Close() + if resp.Body != nil { + _ = resp.Body.Close() + } }() } diff --git a/relay/server/relay.go b/relay/server/relay.go index d86684937..aab575bf0 100644 --- a/relay/server/relay.go +++ b/relay/server/relay.go @@ -51,10 +51,11 @@ type Relay struct { metricsCancel context.CancelFunc validator Validator - store *store.Store - notifier *store.PeerNotifier - instanceURL string - preparedMsg *preparedMsg + store *store.Store + notifier *store.PeerNotifier + instanceURL string + exposedAddress string + preparedMsg *preparedMsg closed bool closeMu sync.RWMutex @@ -87,12 +88,13 @@ func NewRelay(config Config) (*Relay, error) { } r := &Relay{ - metrics: m, - metricsCancel: metricsCancel, - validator: config.AuthValidator, - instanceURL: config.instanceURL, - store: store.NewStore(), - notifier: store.NewPeerNotifier(), + metrics: m, + metricsCancel: metricsCancel, + validator: config.AuthValidator, + instanceURL: config.instanceURL, + exposedAddress: config.ExposedAddress, + store: store.NewStore(), + notifier: store.NewPeerNotifier(), } r.preparedMsg, err = newPreparedMsg(r.instanceURL) @@ -178,3 +180,8 @@ func (r *Relay) Shutdown(ctx context.Context) { func (r *Relay) InstanceURL() string { return r.instanceURL } + +// ExposedAddress returns the exposed address (domain:port) 
where clients connect +func (r *Relay) ExposedAddress() string { + return r.exposedAddress +} diff --git a/relay/server/server.go b/relay/server/server.go index 4c30e7fdc..2c9e658d6 100644 --- a/relay/server/server.go +++ b/relay/server/server.go @@ -28,8 +28,6 @@ type ListenerConfig struct { // It is the gate between the WebSocket listener and the Relay server logic. // In a new HTTP connection, the server will accept the connection and pass it to the Relay server via the Accept method. type Server struct { - listenAddr string - relay *Relay listeners []listener.Listener listenerMux sync.Mutex @@ -62,8 +60,6 @@ func NewServer(config Config) (*Server, error) { // Listen starts the relay server. func (r *Server) Listen(cfg ListenerConfig) error { - r.listenAddr = cfg.Address - wSListener := &ws.Listener{ Address: cfg.Address, TLSConfig: cfg.TLSConfig, @@ -139,6 +135,6 @@ func (r *Server) ListenerProtocols() []protocol.Protocol { return result } -func (r *Server) ListenAddress() string { - return r.listenAddr +func (r *Server) ExposedAddress() string { + return r.relay.ExposedAddress() }