diff --git a/client/cmd/ssh.go b/client/cmd/ssh.go index d0f506ff8..f97ce5f90 100644 --- a/client/cmd/ssh.go +++ b/client/cmd/ssh.go @@ -9,7 +9,6 @@ import ( "strings" "syscall" - log "github.com/sirupsen/logrus" "github.com/spf13/cobra" "github.com/netbirdio/netbird/client/internal" @@ -73,7 +72,8 @@ var sshCmd = &cobra.Command{ go func() { // blocking if err := runSSH(sshctx, host, []byte(config.SSHKey), cmd); err != nil { - log.Print(err) + os.Exit(1) + // log.Print(err) } cancel() }() @@ -92,11 +92,9 @@ func runSSH(ctx context.Context, addr string, pemKey []byte, cmd *cobra.Command) c, err := nbssh.DialWithKey(fmt.Sprintf("%s:%d", addr, port), user, pemKey) if err != nil { cmd.Printf("Error: %v\n", err) - cmd.Printf("Couldn't connect. " + - "You might be disconnected from the NetBird network, or the NetBird agent isn't running.\n" + - "Run the status command: \n\n" + - " netbird status\n\n" + - "It might also be that the SSH server is disabled on the agent you are trying to connect to.\n") + cmd.Printf("Couldn't connect. Please check the connection status or if the ssh server is enabled on the other peer" + + "You can verify the connection by running:\n\n" + + " netbird status\n\n") return nil } go func() { diff --git a/client/ssh/server.go b/client/ssh/server.go index 5d63362b9..f08c5a2f1 100644 --- a/client/ssh/server.go +++ b/client/ssh/server.go @@ -2,9 +2,6 @@ package ssh import ( "fmt" - "github.com/creack/pty" - "github.com/gliderlabs/ssh" - log "github.com/sirupsen/logrus" "io" "net" "os" @@ -13,6 +10,11 @@ import ( "runtime" "strings" "sync" + "time" + + "github.com/creack/pty" + "github.com/gliderlabs/ssh" + log "github.com/sirupsen/logrus" ) // DefaultSSHPort is the default SSH port of the NetBird's embedded SSH server @@ -137,6 +139,8 @@ func (srv *DefaultServer) sessionHandler(session ssh.Session) { } }() + log.Infof("Establishing SSH session for %s from host %s", session.User(), session.RemoteAddr().String()) + localUser, err := userNameLookup(session.User()) if err != nil { _, err = fmt.Fprintf(session, "remote SSH server couldn't find local user %s\n", session.User()) //nolint @@ -172,6 +176,7 @@ func (srv *DefaultServer) sessionHandler(session ssh.Session) { } } + log.Debugf("Login command: %s", cmd.String()) file, err := pty.Start(cmd) if err != nil { log.Errorf("failed starting SSH server %v", err) @@ -199,24 +204,49 @@ func (srv *DefaultServer) sessionHandler(session ssh.Session) { return } } + log.Debugf("SSH session ended") } func (srv *DefaultServer) stdInOut(file *os.File, session ssh.Session) { go func() { // stdin - _, err := io.Copy(file, session) - if err != nil { - return - } + io.Copy(file, session) }() + // For nodes on AWS the terminal takes a while to be ready so we need to wait + terminalIsReady := make(chan bool) go func() { - // stdout - _, err := io.Copy(session, file) - if err != nil { - return + for { + log.Debugf("Checking if terminal is ready") + if checkIfFileIsReady(file) { + terminalIsReady <- true + } + time.Sleep(100 * time.Millisecond) } }() + timer := time.NewTimer(30 * time.Second) + for { + select { + case <-timer.C: + session.Write([]byte("Reached timeout while opening connection\n")) + session.Exit(1) + case <-terminalIsReady: + // stdout + io.Copy(session, file) + session.Exit(0) + } + } +} + +func checkIfFileIsReady(file *os.File) bool { + buffer := make([]byte, 0) + _, err := file.Read(buffer) + // _, err := file.Stat() + // log.Infof("file stat: %v", err) + if err == nil { + return true + } + return false } // Start starts SSH server. Blocking