From 131136439764a67187b7fbfadef1215b5cc66780 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Thu, 20 Nov 2025 17:09:22 +0100 Subject: [PATCH] [client] Increase ssh detection timeout (#4827) --- client/cmd/ssh.go | 17 +++++++++++++---- client/ssh/client/client.go | 9 ++++++--- client/ssh/config/manager.go | 7 +------ client/ssh/detection/detection.go | 18 +++++++++--------- client/ssh/server/jwt_test.go | 6 +++--- client/wasm/cmd/main.go | 24 ++++++++++++++---------- 6 files changed, 46 insertions(+), 35 deletions(-) diff --git a/client/cmd/ssh.go b/client/cmd/ssh.go index 70c7dbcff..92857c637 100644 --- a/client/cmd/ssh.go +++ b/client/cmd/ssh.go @@ -749,7 +749,9 @@ func sshProxyFn(cmd *cobra.Command, args []string) error { if firstLogFile := util.FindFirstLogPath(logFiles); firstLogFile != "" && firstLogFile != defaultLogFile { logOutput = firstLogFile } - if err := util.InitLog(logLevel, logOutput); err != nil { + + proxyLogLevel := getEnvOrDefault("LOG_LEVEL", logLevel) + if err := util.InitLog(proxyLogLevel, logOutput); err != nil { return fmt.Errorf("init log: %w", err) } @@ -788,7 +790,8 @@ var sshDetectCmd = &cobra.Command{ } func sshDetectFn(cmd *cobra.Command, args []string) error { - if err := util.InitLog(logLevel, "console"); err != nil { + detectLogLevel := getEnvOrDefault("LOG_LEVEL", logLevel) + if err := util.InitLog(detectLogLevel, "console"); err != nil { os.Exit(detection.ServerTypeRegular.ExitCode()) } @@ -797,15 +800,21 @@ func sshDetectFn(cmd *cobra.Command, args []string) error { port, err := strconv.Atoi(portStr) if err != nil { + log.Debugf("invalid port %q: %v", portStr, err) os.Exit(detection.ServerTypeRegular.ExitCode()) } - dialer := &net.Dialer{Timeout: detection.Timeout} - serverType, err := detection.DetectSSHServerType(cmd.Context(), dialer, host, port) + ctx, cancel := context.WithTimeout(cmd.Context(), detection.DefaultTimeout) + + dialer := &net.Dialer{} + serverType, err := detection.DetectSSHServerType(ctx, dialer, host, port) if err != nil { + log.Debugf("SSH server detection failed: %v", err) + cancel() os.Exit(detection.ServerTypeRegular.ExitCode()) } + cancel() os.Exit(serverType.ExitCode()) return nil } diff --git a/client/ssh/client/client.go b/client/ssh/client/client.go index 882056374..31b80317a 100644 --- a/client/ssh/client/client.go +++ b/client/ssh/client/client.go @@ -343,10 +343,13 @@ func dialWithJWT(ctx context.Context, network, addr string, config *ssh.ClientCo return nil, fmt.Errorf("parse port %s: %w", portStr, err) } - dialer := &net.Dialer{Timeout: detection.Timeout} - serverType, err := detection.DetectSSHServerType(ctx, dialer, host, port) + detectionCtx, cancel := context.WithTimeout(ctx, config.Timeout) + defer cancel() + + dialer := &net.Dialer{} + serverType, err := detection.DetectSSHServerType(detectionCtx, dialer, host, port) if err != nil { - return nil, fmt.Errorf("SSH server detection failed: %w", err) + return nil, fmt.Errorf("SSH server detection: %w", err) } if !serverType.RequiresJWT() { diff --git a/client/ssh/config/manager.go b/client/ssh/config/manager.go index 03a136de3..cc47fd2d2 100644 --- a/client/ssh/config/manager.go +++ b/client/ssh/config/manager.go @@ -189,12 +189,7 @@ func (m *Manager) buildPeerConfig(allHostPatterns []string) (string, error) { hostLine := strings.Join(deduplicatedPatterns, " ") config := fmt.Sprintf("Host %s\n", hostLine) - - if runtime.GOOS == "windows" { - config += fmt.Sprintf(" Match exec \"%s ssh detect %%h %%p\"\n", execPath) - } else { - config += fmt.Sprintf(" Match exec \"%s ssh detect %%h %%p 2>/dev/null\"\n", execPath) - } + config += fmt.Sprintf(" Match exec \"%s ssh detect %%h %%p\"\n", execPath) config += " PreferredAuthentications password,publickey,keyboard-interactive\n" config += " PasswordAuthentication yes\n" config += " PubkeyAuthentication yes\n" diff --git a/client/ssh/detection/detection.go b/client/ssh/detection/detection.go index 487f4665a..f23ea4c37 100644 --- a/client/ssh/detection/detection.go +++ b/client/ssh/detection/detection.go @@ -3,6 +3,7 @@ package detection import ( "bufio" "context" + "fmt" "net" "strconv" "strings" @@ -19,8 +20,8 @@ const ( // JWTRequiredMarker is appended to responses when JWT is required JWTRequiredMarker = "NetBird-JWT-Required" - // Timeout is the timeout for SSH server detection - Timeout = 5 * time.Second + // DefaultTimeout is the default timeout for SSH server detection + DefaultTimeout = 5 * time.Second ) type ServerType string @@ -61,21 +62,20 @@ func DetectSSHServerType(ctx context.Context, dialer Dialer, host string, port i conn, err := dialer.DialContext(ctx, "tcp", targetAddr) if err != nil { - log.Debugf("SSH connection failed for detection: %v", err) - return ServerTypeRegular, nil + return ServerTypeRegular, fmt.Errorf("connect to %s: %w", targetAddr, err) } defer conn.Close() - if err := conn.SetReadDeadline(time.Now().Add(Timeout)); err != nil { - log.Debugf("set read deadline: %v", err) - return ServerTypeRegular, nil + if deadline, ok := ctx.Deadline(); ok { + if err := conn.SetReadDeadline(deadline); err != nil { + return ServerTypeRegular, fmt.Errorf("set read deadline: %w", err) + } } reader := bufio.NewReader(conn) serverBanner, err := reader.ReadString('\n') if err != nil { - log.Debugf("read SSH banner: %v", err) - return ServerTypeRegular, nil + return ServerTypeRegular, fmt.Errorf("read SSH banner: %w", err) } serverBanner = strings.TrimSpace(serverBanner) diff --git a/client/ssh/server/jwt_test.go b/client/ssh/server/jwt_test.go index e22bdfb06..1f3bac76d 100644 --- a/client/ssh/server/jwt_test.go +++ b/client/ssh/server/jwt_test.go @@ -58,7 +58,7 @@ func TestJWTEnforcement(t *testing.T) { require.NoError(t, err) port, err := strconv.Atoi(portStr) require.NoError(t, err) - dialer := &net.Dialer{Timeout: detection.Timeout} + dialer := &net.Dialer{} serverType, err := detection.DetectSSHServerType(context.Background(), dialer, host, port) if err != nil { t.Logf("Detection failed: %v", err) @@ -93,7 +93,7 @@ func TestJWTEnforcement(t *testing.T) { portNoJWT, err := strconv.Atoi(portStrNoJWT) require.NoError(t, err) - dialer := &net.Dialer{Timeout: detection.Timeout} + dialer := &net.Dialer{} serverType, err := detection.DetectSSHServerType(context.Background(), dialer, hostNoJWT, portNoJWT) require.NoError(t, err) assert.Equal(t, detection.ServerTypeNetBirdNoJWT, serverType) @@ -218,7 +218,7 @@ func TestJWTDetection(t *testing.T) { port, err := strconv.Atoi(portStr) require.NoError(t, err) - dialer := &net.Dialer{Timeout: detection.Timeout} + dialer := &net.Dialer{} serverType, err := detection.DetectSSHServerType(context.Background(), dialer, host, port) require.NoError(t, err) assert.Equal(t, detection.ServerTypeNetBirdJWT, serverType) diff --git a/client/wasm/cmd/main.go b/client/wasm/cmd/main.go index 4dc14a1ca..238e272fa 100644 --- a/client/wasm/cmd/main.go +++ b/client/wasm/cmd/main.go @@ -19,9 +19,10 @@ import ( ) const ( - clientStartTimeout = 30 * time.Second - clientStopTimeout = 10 * time.Second - defaultLogLevel = "warn" + clientStartTimeout = 30 * time.Second + clientStopTimeout = 10 * time.Second + defaultLogLevel = "warn" + defaultSSHDetectionTimeout = 20 * time.Second ) func main() { @@ -207,11 +208,19 @@ func createDetectSSHServerMethod(client *netbird.Client) js.Func { host := args[0].String() port := args[1].Int() + timeoutMs := int(defaultSSHDetectionTimeout.Milliseconds()) + if len(args) >= 3 && !args[2].IsNull() && !args[2].IsUndefined() { + timeoutMs = args[2].Int() + if timeoutMs <= 0 { + return js.ValueOf("error: timeout must be positive") + } + } + return createPromise(func(resolve, reject js.Value) { - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + ctx, cancel := context.WithTimeout(context.Background(), time.Duration(timeoutMs)*time.Millisecond) defer cancel() - serverType, err := detectSSHServerType(ctx, client, host, port) + serverType, err := sshdetection.DetectSSHServerType(ctx, client, host, port) if err != nil { reject.Invoke(err.Error()) return @@ -222,11 +231,6 @@ func createDetectSSHServerMethod(client *netbird.Client) js.Func { }) } -// detectSSHServerType detects SSH server type using NetBird network connection -func detectSSHServerType(ctx context.Context, client *netbird.Client, host string, port int) (sshdetection.ServerType, error) { - return sshdetection.DetectSSHServerType(ctx, client, host, port) -} - // createClientObject wraps the NetBird client in a JavaScript object func createClientObject(client *netbird.Client) js.Value { obj := make(map[string]interface{})