gssh-agent 1.0.7 → 1.0.8

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/README.md CHANGED
@@ -14,22 +14,13 @@ gssh 是一个供 Agent 使用的 SSH Session 管理工具。通过 Go 语言实
14
14
 
15
15
  ## 安装
16
16
 
17
- ### 方式一:Homebrew(推荐)
17
+ ### 使用 npm(推荐)
18
18
 
19
19
  ```bash
20
- # 添加 tap(未来会支持)
21
- # brew tap forechoandlook/gssh
22
- # brew install gssh
23
-
24
- # 暂时使用手动安装方式
25
- git clone https://github.com/forechoandlook/gssh.git
26
- cd gssh
27
- go build -o bin/daemon cmd/daemon/main.go
28
- go build -o bin/gssh cmd/gssh/main.go
29
- ./homebrew/install.sh
20
+ npm install -g gssh-agent
30
21
  ```
31
22
 
32
- ### 方式二:直接安装
23
+ ### 手动安装
33
24
 
34
25
  ```bash
35
26
  # 克隆项目
@@ -45,53 +36,11 @@ cp bin/daemon /usr/local/bin/gssh-daemon
45
36
  cp bin/gssh /usr/local/bin/gssh
46
37
  ```
47
38
 
48
- ## 服务管理(macOS)
49
-
50
- ### 使用 launchctl
39
+ ## 启动服务
51
40
 
52
41
  ```bash
53
- # 启动服务
54
- launchctl start com.gssh.daemon
55
-
56
- # 停止服务
57
- launchctl stop com.gssh.daemon
58
-
59
- # 查看状态
60
- launchctl list | grep gssh
61
-
62
- # 卸载服务
63
- launchctl unload ~/Library/LaunchAgents/gssh.plist
64
- rm ~/Library/LaunchAgents/gssh.plist
65
- ```
66
-
67
- ### 使用 Homebrew services
68
-
69
- ```bash
70
- # 启动(首次安装后自动启动)
71
- brew services start gssh
72
-
73
- # 停止
74
- brew services stop gssh
75
-
76
- # 查看状态
77
- brew services list
78
-
79
- # 重启
80
- brew services restart gssh
81
- ```
82
-
83
- ### 使用 PM2(跨平台)
84
-
85
- ```bash
86
- # 安装 PM2
87
- npm install -g pm2
88
-
89
- # 启动 daemon
90
- pm2 start gssh-daemon.sh --name gssh
91
-
92
- # 保存并设置开机自启
93
- pm2 save
94
- pm2 startup
42
+ # 启动 gssh daemon
43
+ gssh-daemon
95
44
  ```
96
45
 
97
46
  ## 使用方法
@@ -99,7 +48,7 @@ pm2 startup
99
48
  ### 连接 SSH(密码认证)
100
49
 
101
50
  ```bash
102
- gssh connect -u admin1 -h 139.196.175.163 -p 7080 -P 1234
51
+ gssh connect -u admin1 -h xxxx -p xxxx -P xxxxx
103
52
  ```
104
53
 
105
54
  ### 连接 SSH(密钥认证)
@@ -116,6 +65,12 @@ gssh exec "ls -la"
116
65
 
117
66
  # 指定 session
118
67
  gssh exec -s <session_id> "pwd"
68
+
69
+ # 带超时命令
70
+ gssh exec -t 10 "ls -la"
71
+
72
+ # sudo 命令
73
+ gssh exec -S password "sudo systemctl restart nginx"
119
74
  ```
120
75
 
121
76
  ### 端口转发
@@ -123,16 +78,27 @@ gssh exec -s <session_id> "pwd"
123
78
  ```bash
124
79
  # 本地端口转发:本地 8080 -> 远程 80
125
80
  gssh forward -l 8080 -r 80
81
+
82
+ # 远程端口转发:远程 9000 -> 本地 3000
83
+ gssh forward -R -l 9000 -r 3000
84
+
85
+ # 列出所有端口转发
86
+ gssh forwards
87
+
88
+ # 关闭端口转发
89
+ gssh forward-close <forward_id>
126
90
  ```
127
91
 
128
- ### 文件传输(SFTP/SCP)
92
+ ### 文件传输与同步(SFTP/SCP/SYNC)
129
93
 
130
94
  ```bash
131
- # 上传文件(本地 -> 远程)
95
+ # 上传文件或文件夹(本地 -> 远程)
132
96
  gssh scp -put /path/to/local/file.txt /path/to/remote/file.txt
97
+ gssh sync -put /path/to/local/dir /path/to/remote/dir
133
98
 
134
- # 下载文件(远程 -> 本地)
99
+ # 下载文件或文件夹(远程 -> 本地)
135
100
  gssh scp -get /path/to/remote/file.txt /path/to/local/file.txt
101
+ gssh sync -get /path/to/remote/dir /path/to/local/dir
136
102
 
137
103
  # 列出远程目录
138
104
  gssh sftp -c ls -p /path/to/remote/dir
@@ -178,6 +144,10 @@ gssh reconnect -s <session_id>
178
144
  | `-get` | 下载模式(远程 -> 本地) |
179
145
  | `-c command` | SFTP 命令(ls/mkdir/rm) |
180
146
  | `-p path` | SFTP 路径 |
147
+ | `-t timeout` | 命令超时时间(秒) |
148
+ | `-S password` | sudo 密码 |
149
+ | `--ask-pass` | 交互输入 SSH 密码 |
150
+ | `--ask-sudo-pass` | 交互输入 sudo 密码 |
181
151
 
182
152
  ## 开发
183
153
 
package/bin/gssh CHANGED
Binary file
package/cmd/gssh/main.go CHANGED
@@ -73,7 +73,7 @@ func main() {
73
73
  err = handleForwards(socketPath)
74
74
  case "forward-close":
75
75
  err = handleForwardClose(subArgs, socketPath)
76
- case "scp":
76
+ case "scp", "sync":
77
77
  err = handleSCP(subArgs, socketPath)
78
78
  case "sftp":
79
79
  err = handleSFTP(subArgs, socketPath)
@@ -625,6 +625,8 @@ Usage:
625
625
  gssh forward-close <forward_id>
626
626
  gssh scp [-s session_id] -put <local> <remote>
627
627
  gssh scp [-s session_id] -get <remote> <local>
628
+ gssh sync [-s session_id] -put <local> <remote>
629
+ gssh sync [-s session_id] -get <remote> <local>
628
630
  gssh sftp [-s session_id] -c <ls|mkdir|rm> -p <path>
629
631
  gssh -v, --version
630
632
 
@@ -654,6 +656,7 @@ Examples:
654
656
  gssh reconnect 549b6eff-f62c-4dae-a7e9-298815233cf4
655
657
  gssh forward -l 8080 -r 80
656
658
  gssh scp -put local.txt /home/user/remote.txt
659
+ gssh sync -put local_dir /home/user/remote_dir
657
660
  `, version)
658
661
  }
659
662
 
@@ -11,6 +11,7 @@ import (
11
11
 
12
12
  "github.com/pkg/sftp"
13
13
  "golang.org/x/crypto/ssh"
14
+ "golang.org/x/crypto/ssh/knownhosts"
14
15
  )
15
16
 
16
17
  // SSHClient encapsulates SSH connection logic
@@ -32,12 +33,84 @@ func (k *KeyboardInteractiveHandler) Challenge(name, instruction string, questio
32
33
  return answers, nil
33
34
  }
34
35
 
36
+ // getHostKeyCallback returns a host key callback that verifies the host key against the user's known_hosts file.
37
+ // If the host is unknown, it implements Trust-On-First-Use (TOFU) by adding the new key to the known_hosts file.
38
+ func getHostKeyCallback() (ssh.HostKeyCallback, error) {
39
+ homeDir, err := os.UserHomeDir()
40
+ if err != nil {
41
+ return nil, fmt.Errorf("could not get user home dir: %w", err)
42
+ }
43
+
44
+ knownHostsPath := filepath.Join(homeDir, ".ssh", "known_hosts")
45
+
46
+ // Ensure the .ssh directory exists
47
+ if err := os.MkdirAll(filepath.Dir(knownHostsPath), 0700); err != nil {
48
+ return nil, fmt.Errorf("failed to create .ssh directory: %w", err)
49
+ }
50
+
51
+ // Create the known_hosts file if it doesn't exist
52
+ if _, err := os.Stat(knownHostsPath); os.IsNotExist(err) {
53
+ f, err := os.OpenFile(knownHostsPath, os.O_CREATE|os.O_RDWR, 0600)
54
+ if err != nil {
55
+ return nil, fmt.Errorf("failed to create known_hosts file: %w", err)
56
+ }
57
+ f.Close()
58
+ }
59
+
60
+ cb, err := knownhosts.New(knownHostsPath)
61
+ if err != nil {
62
+ return nil, fmt.Errorf("failed to create knownhosts callback: %w", err)
63
+ }
64
+
65
+ return func(hostname string, remote net.Addr, key ssh.PublicKey) error {
66
+ err := cb(hostname, remote, key)
67
+ if err == nil {
68
+ return nil
69
+ }
70
+
71
+ keyErr, ok := err.(*knownhosts.KeyError)
72
+ if !ok {
73
+ return err
74
+ }
75
+
76
+ // If len(keyErr.Want) is 0, it means the key is completely unknown (not a mismatch).
77
+ // We implement Trust-On-First-Use (TOFU) by adding the new key to known_hosts.
78
+ if len(keyErr.Want) == 0 {
79
+ f, fErr := os.OpenFile(knownHostsPath, os.O_APPEND|os.O_WRONLY, 0600)
80
+ if fErr != nil {
81
+ return fmt.Errorf("failed to open known_hosts for appending: %w", fErr)
82
+ }
83
+ defer f.Close()
84
+
85
+ addresses := []string{knownhosts.Normalize(hostname)}
86
+ if remoteString := remote.String(); remoteString != hostname {
87
+ addresses = append(addresses, knownhosts.Normalize(remoteString))
88
+ }
89
+
90
+ line := knownhosts.Line(addresses, key)
91
+ if _, wErr := f.WriteString(line + "\n"); wErr != nil {
92
+ return fmt.Errorf("failed to append key to known_hosts: %w", wErr)
93
+ }
94
+ return nil
95
+ }
96
+
97
+ // It's a key mismatch (security risk: MITM or host key changed). Reject the connection.
98
+ return err
99
+ }, nil
100
+ }
101
+
35
102
  // NewSSHClient creates a new SSH client
36
103
  func NewSSHClient(user, host string, port int, authMethods ...ssh.AuthMethod) (*SSHClient, error) {
104
+ hostKeyCallback, err := getHostKeyCallback()
105
+ if err != nil {
106
+ return nil, fmt.Errorf("failed to get host key callback: %w", err)
107
+ }
108
+
37
109
  config := &ssh.ClientConfig{
38
110
  User: user,
39
111
  Auth: authMethods,
40
- HostKeyCallback: ssh.InsecureIgnoreHostKey(),
112
+ HostKeyCallback: hostKeyCallback,
113
+ Timeout: 10 * time.Second,
41
114
  }
42
115
 
43
116
  addr := fmt.Sprintf("%s:%d", host, port)
@@ -169,8 +242,56 @@ func (c *SSHClient) NewSFTPClient() (*SFTPClient, error) {
169
242
  return &SFTPClient{Client: sftpClient}, nil
170
243
  }
171
244
 
172
- // Upload uploads a local file to remote
245
+ // Upload uploads a local file or directory to remote
173
246
  func (s *SFTPClient) Upload(localPath, remotePath string) (int64, error) {
247
+ localInfo, err := os.Stat(localPath)
248
+ if err != nil {
249
+ return 0, fmt.Errorf("failed to stat local path: %w", err)
250
+ }
251
+
252
+ if !localInfo.IsDir() {
253
+ return s.uploadFile(localPath, remotePath)
254
+ }
255
+
256
+ var totalWritten int64
257
+ err = filepath.Walk(localPath, func(path string, info os.FileInfo, err error) error {
258
+ if err != nil {
259
+ return err
260
+ }
261
+
262
+ relPath, err := filepath.Rel(localPath, path)
263
+ if err != nil {
264
+ return err
265
+ }
266
+
267
+ targetPath := filepath.ToSlash(filepath.Join(remotePath, relPath))
268
+
269
+ if info.IsDir() {
270
+ s.Client.MkdirAll(targetPath)
271
+ return nil
272
+ }
273
+
274
+ // Sync logic
275
+ remoteInfo, err := s.Client.Stat(targetPath)
276
+ if err == nil && remoteInfo.Size() == info.Size() && remoteInfo.ModTime().Unix() == info.ModTime().Unix() {
277
+ return nil
278
+ }
279
+
280
+ written, err := s.uploadFile(path, targetPath)
281
+ if err != nil {
282
+ return err
283
+ }
284
+
285
+ totalWritten += written
286
+ s.Client.Chtimes(targetPath, info.ModTime(), info.ModTime())
287
+
288
+ return nil
289
+ })
290
+
291
+ return totalWritten, err
292
+ }
293
+
294
+ func (s *SFTPClient) uploadFile(localPath, remotePath string) (int64, error) {
174
295
  localFile, err := os.Open(localPath)
175
296
  if err != nil {
176
297
  return 0, fmt.Errorf("failed to open local file: %w", err)
@@ -199,8 +320,62 @@ func (s *SFTPClient) Upload(localPath, remotePath string) (int64, error) {
199
320
  return written, nil
200
321
  }
201
322
 
202
- // Download downloads a remote file to local
323
+ // Download downloads a remote file or directory to local
203
324
  func (s *SFTPClient) Download(remotePath, localPath string) (int64, error) {
325
+ remoteInfo, err := s.Client.Stat(remotePath)
326
+ if err != nil {
327
+ return 0, fmt.Errorf("failed to stat remote path: %w", err)
328
+ }
329
+
330
+ if !remoteInfo.IsDir() {
331
+ return s.downloadFile(remotePath, localPath)
332
+ }
333
+
334
+ var totalWritten int64
335
+ walker := s.Client.Walk(remotePath)
336
+ for walker.Step() {
337
+ if walker.Err() != nil {
338
+ return totalWritten, walker.Err()
339
+ }
340
+
341
+ path := walker.Path()
342
+ info := walker.Stat()
343
+
344
+ relPath, err := filepath.Rel(remotePath, path)
345
+ if err != nil {
346
+ return totalWritten, err
347
+ }
348
+
349
+ if err := checkPathTraversal(relPath); err != nil {
350
+ return totalWritten, err
351
+ }
352
+
353
+ targetPath := filepath.Join(localPath, relPath)
354
+
355
+ if info.IsDir() {
356
+ os.MkdirAll(targetPath, info.Mode())
357
+ continue
358
+ }
359
+
360
+ // Sync logic
361
+ localInfo, err := os.Stat(targetPath)
362
+ if err == nil && localInfo.Size() == info.Size() && localInfo.ModTime().Unix() == info.ModTime().Unix() {
363
+ continue
364
+ }
365
+
366
+ written, err := s.downloadFile(path, targetPath)
367
+ if err != nil {
368
+ return totalWritten, err
369
+ }
370
+
371
+ totalWritten += written
372
+ os.Chtimes(targetPath, info.ModTime(), info.ModTime())
373
+ }
374
+
375
+ return totalWritten, nil
376
+ }
377
+
378
+ func (s *SFTPClient) downloadFile(remotePath, localPath string) (int64, error) {
204
379
  remoteFile, err := s.Client.Open(remotePath)
205
380
  if err != nil {
206
381
  return 0, fmt.Errorf("failed to open remote file: %w", err)
@@ -380,6 +555,14 @@ func (pw *progressWriter) Write(p []byte) (n int, err error) {
380
555
  return
381
556
  }
382
557
 
558
+ // checkPathTraversal ensures that relPath does not escape the destination directory
559
+ func checkPathTraversal(relPath string) error {
560
+ if !filepath.IsLocal(relPath) {
561
+ return fmt.Errorf("path traversal detected: %s escapes destination", relPath)
562
+ }
563
+ return nil
564
+ }
565
+
383
566
  // expandPath expands ~ to home directory
384
567
  func expandPath(path string) string {
385
568
  if len(path) > 0 && path[0] == '~' {
@@ -31,3 +31,46 @@ func TestNewAuthMethodsFromKeyPath(t *testing.T) {
31
31
  t.Error("expected error for non-existent key")
32
32
  }
33
33
  }
34
+
35
+ func TestCheckPathTraversal(t *testing.T) {
36
+ tests := []struct {
37
+ name string
38
+ relPath string
39
+ wantErr bool
40
+ }{
41
+ {
42
+ name: "safe path within directory",
43
+ relPath: "safe_file.txt",
44
+ wantErr: false,
45
+ },
46
+ {
47
+ name: "safe path same as directory",
48
+ relPath: ".",
49
+ wantErr: false,
50
+ },
51
+ {
52
+ name: "path traversal attempt with ../",
53
+ relPath: "../etc/passwd",
54
+ wantErr: true,
55
+ },
56
+ {
57
+ name: "absolute path traversal",
58
+ relPath: "/etc/passwd",
59
+ wantErr: true,
60
+ },
61
+ {
62
+ name: "deep nested safe path",
63
+ relPath: "a/b/c/d.txt",
64
+ wantErr: false,
65
+ },
66
+ }
67
+
68
+ for _, tt := range tests {
69
+ t.Run(tt.name, func(t *testing.T) {
70
+ err := checkPathTraversal(tt.relPath)
71
+ if (err != nil) != tt.wantErr {
72
+ t.Errorf("checkPathTraversal() error = %v, wantErr %v", err, tt.wantErr)
73
+ }
74
+ })
75
+ }
76
+ }
@@ -73,12 +73,20 @@ func (f *Forwarder) startLocalForward() {
73
73
  f.mu.RUnlock()
74
74
  return
75
75
  }
76
+ listener := f.listener
76
77
  f.mu.RUnlock()
77
78
 
79
+ if listener == nil {
80
+ time.Sleep(1 * time.Second)
81
+ continue
82
+ }
83
+
78
84
  // Set deadline to prevent blocking forever
79
- f.listener.(*net.TCPListener).SetDeadline(time.Now().Add(5 * time.Second))
85
+ if tcpListener, ok := listener.(*net.TCPListener); ok {
86
+ tcpListener.SetDeadline(time.Now().Add(5 * time.Second))
87
+ }
80
88
 
81
- conn, err := f.listener.Accept()
89
+ conn, err := listener.Accept()
82
90
  if err != nil {
83
91
  // Check if it's a timeout
84
92
  if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
@@ -117,14 +125,24 @@ func (f *Forwarder) handleLocalConnection(localConn net.Conn) {
117
125
  // Use 127.0.0.1 instead of localhost to avoid IPv6 issues
118
126
  remoteAddr := fmt.Sprintf("127.0.0.1:%d", f.RemotePort)
119
127
 
120
- remoteConn, err := net.DialTimeout("tcp", remoteAddr, 5*time.Second)
128
+ // Dial the remote endpoint *through* the SSH tunnel
129
+ f.mu.RLock()
130
+ client := f.sshClient
131
+ f.mu.RUnlock()
132
+
133
+ if client == nil {
134
+ log.Printf("[portforward] SSH client is nil, cannot connect to remote %s", remoteAddr)
135
+ return
136
+ }
137
+
138
+ remoteConn, err := client.Dial("tcp", remoteAddr)
121
139
  if err != nil {
122
- log.Printf("[portforward] Failed to connect to remote %s: %v", remoteAddr, err)
140
+ log.Printf("[portforward] Failed to connect to remote %s via SSH: %v", remoteAddr, err)
123
141
  return
124
142
  }
125
143
  defer remoteConn.Close()
126
144
 
127
- log.Printf("[portforward] Tunnel established to remote %s", remoteAddr)
145
+ log.Printf("[portforward] Tunnel established to remote %s via SSH", remoteAddr)
128
146
 
129
147
  // Bidirectional copy
130
148
  done := make(chan struct{})
@@ -146,39 +164,29 @@ func (f *Forwarder) startRemoteForward() {
146
164
  go func() {
147
165
  defer f.wg.Done()
148
166
 
149
- // Request the SSH server to listen on remote:remotePort and forward connections to us
150
- addr := fmt.Sprintf(":%d", f.RemotePort)
151
- payload := ssh.Marshal(struct {
152
- Addr string
153
- Port uint32
154
- }{Addr: addr, Port: uint32(f.RemotePort)})
155
-
156
- // Use timeout for the request
157
- type result struct {
158
- ok bool
159
- err error
167
+ f.mu.RLock()
168
+ client := f.sshClient
169
+ f.mu.RUnlock()
170
+
171
+ if client == nil {
172
+ log.Printf("[portforward] SSH client is nil, cannot start remote forward")
173
+ return
160
174
  }
161
- resultCh := make(chan result, 1)
162
- go func() {
163
- ok, _, err := f.sshClient.SendRequest("tcpip-forward", true, payload)
164
- resultCh <- result{ok: ok, err: err}
165
- }()
166
-
167
- select {
168
- case r := <-resultCh:
169
- if r.err != nil {
170
- log.Printf("[portforward] Failed to send tcpip-forward request: %v", r.err)
171
- return
172
- }
173
- if !r.ok {
174
- log.Printf("[portforward] SSH server rejected tcpip-forward request for port %d", f.RemotePort)
175
- return
176
- }
177
- case <-time.After(5 * time.Second):
178
- log.Printf("[portforward] tcpip-forward request timed out")
175
+
176
+ // Request the SSH server to listen on remote:remotePort and forward connections to us
177
+ remoteAddr := fmt.Sprintf("0.0.0.0:%d", f.RemotePort)
178
+
179
+ // Note: The ssh.Client.Listen internally sends the tcpip-forward request
180
+ listener, err := client.Listen("tcp", remoteAddr)
181
+ if err != nil {
182
+ log.Printf("[portforward] Failed to listen on remote port %d via SSH: %v", f.RemotePort, err)
179
183
  return
180
184
  }
181
185
 
186
+ f.mu.Lock()
187
+ f.listener = listener
188
+ f.mu.Unlock()
189
+
182
190
  log.Printf("[portforward] Remote forward: SSH server listening on port %d", f.RemotePort)
183
191
 
184
192
  // Accept forwarded connections in a loop
@@ -188,10 +196,16 @@ func (f *Forwarder) startRemoteForward() {
188
196
  f.mu.RUnlock()
189
197
  return
190
198
  }
199
+ currentListener := f.listener
191
200
  f.mu.RUnlock()
192
201
 
193
- // Wait for a forwarded connection with timeout
194
- ch, reqs, err := f.sshClient.OpenChannel("forwarded-tcpip", nil)
202
+ if currentListener == nil {
203
+ time.Sleep(1 * time.Second)
204
+ continue
205
+ }
206
+
207
+ // Wait for a forwarded connection
208
+ remoteConn, err := currentListener.Accept()
195
209
  if err != nil {
196
210
  f.mu.RLock()
197
211
  closed := f.closed
@@ -199,21 +213,18 @@ func (f *Forwarder) startRemoteForward() {
199
213
  if closed {
200
214
  return
201
215
  }
202
- // Log and continue - don't block
203
216
  log.Printf("[portforward] Error accepting forwarded connection: %v", err)
204
- time.Sleep(1 * time.Second)
205
- continue
217
+ return // Listener is broken or closed
206
218
  }
207
219
 
208
- // Discard any requests
209
- go ssh.DiscardRequests(reqs)
220
+ log.Printf("[portforward] Accepted remote connection from %v", remoteConn.RemoteAddr())
210
221
 
211
- // Connect to local port
222
+ // Connect to local port to forward the traffic
212
223
  localAddr := fmt.Sprintf("127.0.0.1:%d", f.LocalPort)
213
224
  localConn, err := net.DialTimeout("tcp", localAddr, 5*time.Second)
214
225
  if err != nil {
215
226
  log.Printf("[portforward] Failed to connect to local port %d: %v", f.LocalPort, err)
216
- ch.Close()
227
+ remoteConn.Close()
217
228
  continue
218
229
  }
219
230
 
@@ -225,12 +236,20 @@ func (f *Forwarder) startRemoteForward() {
225
236
 
226
237
  // Bidirectional copy
227
238
  go func() {
228
- io.Copy(ch, localConn)
229
- ch.Close()
230
- localConn.Close()
231
- }()
232
- go func() {
233
- io.Copy(localConn, ch)
239
+ defer func() {
240
+ remoteConn.Close()
241
+ localConn.Close()
242
+ f.mu.Lock()
243
+ delete(f.conns, localConn)
244
+ f.mu.Unlock()
245
+ }()
246
+ done := make(chan struct{})
247
+ go func() {
248
+ io.Copy(remoteConn, localConn)
249
+ close(done)
250
+ }()
251
+ io.Copy(localConn, remoteConn)
252
+ <-done
234
253
  }()
235
254
  }
236
255
  }()
@@ -244,7 +263,6 @@ func (f *Forwarder) Close() {
244
263
  return
245
264
  }
246
265
  f.closed = true
247
- f.mu.Unlock()
248
266
 
249
267
  if f.listener != nil {
250
268
  f.listener.Close()
@@ -253,6 +271,7 @@ func (f *Forwarder) Close() {
253
271
  for conn := range f.conns {
254
272
  conn.Close()
255
273
  }
274
+ f.mu.Unlock()
256
275
 
257
276
  f.wg.Wait()
258
277
  }
@@ -262,8 +281,6 @@ func (f *Forwarder) Restart(sshClient *ssh.Client) {
262
281
  f.mu.Lock()
263
282
  f.sshClient = sshClient
264
283
  f.closed = false
265
- f.conns = make(map[net.Conn]bool)
266
- f.mu.Unlock()
267
284
 
268
285
  if f.listener != nil {
269
286
  f.listener.Close()
@@ -273,11 +290,21 @@ func (f *Forwarder) Restart(sshClient *ssh.Client) {
273
290
  // Create new listener
274
291
  addr := fmt.Sprintf("localhost:%d", f.LocalPort)
275
292
  listener, err := net.Listen("tcp", addr)
276
- if err != nil {
277
- return
293
+ if err == nil {
294
+ f.listener = listener
295
+ } else {
296
+ f.listener = nil
297
+ log.Printf("[portforward] Failed to recreate listener on %s: %v", addr, err)
278
298
  }
279
- f.listener = listener
299
+ } else {
300
+ f.listener = nil // Will be re-initialized in startRemoteForward
280
301
  }
281
302
 
303
+ for conn := range f.conns {
304
+ conn.Close()
305
+ }
306
+ f.conns = make(map[net.Conn]bool)
307
+ f.mu.Unlock()
308
+
282
309
  f.Start()
283
310
  }