gssh-agent 1.0.3 → 1.0.5
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/.github/workflows/ci.yml +27 -0
- package/.github/workflows/publish.yml +104 -0
- package/README.md +31 -61
- package/bin/gssh +0 -0
- package/cmd/gssh/main.go +144 -113
- package/fix_manager.patch +79 -0
- package/internal/client/ssh.go +186 -3
- package/internal/client/ssh_test.go +43 -0
- package/internal/portforward/forwarder.go +94 -40
- package/internal/protocol/types.go +3 -1
- package/internal/session/manager.go +164 -39
- package/internal/session/manager_test.go +324 -0
- package/package.json +3 -4
- package/pkg/rpc/handler.go +1 -1
- package/pkg/rpc/handler_test.go +36 -0
- package/plan.md +4 -0
- package/skill.md +34 -8
- package/bin/daemon +0 -0
- package/bin/gssh-daemon +0 -0
package/internal/client/ssh.go
CHANGED
|
@@ -11,6 +11,7 @@ import (
|
|
|
11
11
|
|
|
12
12
|
"github.com/pkg/sftp"
|
|
13
13
|
"golang.org/x/crypto/ssh"
|
|
14
|
+
"golang.org/x/crypto/ssh/knownhosts"
|
|
14
15
|
)
|
|
15
16
|
|
|
16
17
|
// SSHClient encapsulates SSH connection logic
|
|
@@ -32,12 +33,84 @@ func (k *KeyboardInteractiveHandler) Challenge(name, instruction string, questio
|
|
|
32
33
|
return answers, nil
|
|
33
34
|
}
|
|
34
35
|
|
|
36
|
+
// getHostKeyCallback returns a host key callback that verifies the host key against the user's known_hosts file.
|
|
37
|
+
// If the host is unknown, it implements Trust-On-First-Use (TOFU) by adding the new key to the known_hosts file.
|
|
38
|
+
func getHostKeyCallback() (ssh.HostKeyCallback, error) {
|
|
39
|
+
homeDir, err := os.UserHomeDir()
|
|
40
|
+
if err != nil {
|
|
41
|
+
return nil, fmt.Errorf("could not get user home dir: %w", err)
|
|
42
|
+
}
|
|
43
|
+
|
|
44
|
+
knownHostsPath := filepath.Join(homeDir, ".ssh", "known_hosts")
|
|
45
|
+
|
|
46
|
+
// Ensure the .ssh directory exists
|
|
47
|
+
if err := os.MkdirAll(filepath.Dir(knownHostsPath), 0700); err != nil {
|
|
48
|
+
return nil, fmt.Errorf("failed to create .ssh directory: %w", err)
|
|
49
|
+
}
|
|
50
|
+
|
|
51
|
+
// Create the known_hosts file if it doesn't exist
|
|
52
|
+
if _, err := os.Stat(knownHostsPath); os.IsNotExist(err) {
|
|
53
|
+
f, err := os.OpenFile(knownHostsPath, os.O_CREATE|os.O_RDWR, 0600)
|
|
54
|
+
if err != nil {
|
|
55
|
+
return nil, fmt.Errorf("failed to create known_hosts file: %w", err)
|
|
56
|
+
}
|
|
57
|
+
f.Close()
|
|
58
|
+
}
|
|
59
|
+
|
|
60
|
+
cb, err := knownhosts.New(knownHostsPath)
|
|
61
|
+
if err != nil {
|
|
62
|
+
return nil, fmt.Errorf("failed to create knownhosts callback: %w", err)
|
|
63
|
+
}
|
|
64
|
+
|
|
65
|
+
return func(hostname string, remote net.Addr, key ssh.PublicKey) error {
|
|
66
|
+
err := cb(hostname, remote, key)
|
|
67
|
+
if err == nil {
|
|
68
|
+
return nil
|
|
69
|
+
}
|
|
70
|
+
|
|
71
|
+
keyErr, ok := err.(*knownhosts.KeyError)
|
|
72
|
+
if !ok {
|
|
73
|
+
return err
|
|
74
|
+
}
|
|
75
|
+
|
|
76
|
+
// If len(keyErr.Want) is 0, it means the key is completely unknown (not a mismatch).
|
|
77
|
+
// We implement Trust-On-First-Use (TOFU) by adding the new key to known_hosts.
|
|
78
|
+
if len(keyErr.Want) == 0 {
|
|
79
|
+
f, fErr := os.OpenFile(knownHostsPath, os.O_APPEND|os.O_WRONLY, 0600)
|
|
80
|
+
if fErr != nil {
|
|
81
|
+
return fmt.Errorf("failed to open known_hosts for appending: %w", fErr)
|
|
82
|
+
}
|
|
83
|
+
defer f.Close()
|
|
84
|
+
|
|
85
|
+
addresses := []string{knownhosts.Normalize(hostname)}
|
|
86
|
+
if remoteString := remote.String(); remoteString != hostname {
|
|
87
|
+
addresses = append(addresses, knownhosts.Normalize(remoteString))
|
|
88
|
+
}
|
|
89
|
+
|
|
90
|
+
line := knownhosts.Line(addresses, key)
|
|
91
|
+
if _, wErr := f.WriteString(line + "\n"); wErr != nil {
|
|
92
|
+
return fmt.Errorf("failed to append key to known_hosts: %w", wErr)
|
|
93
|
+
}
|
|
94
|
+
return nil
|
|
95
|
+
}
|
|
96
|
+
|
|
97
|
+
// It's a key mismatch (security risk: MITM or host key changed). Reject the connection.
|
|
98
|
+
return err
|
|
99
|
+
}, nil
|
|
100
|
+
}
|
|
101
|
+
|
|
35
102
|
// NewSSHClient creates a new SSH client
|
|
36
103
|
func NewSSHClient(user, host string, port int, authMethods ...ssh.AuthMethod) (*SSHClient, error) {
|
|
104
|
+
hostKeyCallback, err := getHostKeyCallback()
|
|
105
|
+
if err != nil {
|
|
106
|
+
return nil, fmt.Errorf("failed to get host key callback: %w", err)
|
|
107
|
+
}
|
|
108
|
+
|
|
37
109
|
config := &ssh.ClientConfig{
|
|
38
110
|
User: user,
|
|
39
111
|
Auth: authMethods,
|
|
40
|
-
HostKeyCallback:
|
|
112
|
+
HostKeyCallback: hostKeyCallback,
|
|
113
|
+
Timeout: 10 * time.Second,
|
|
41
114
|
}
|
|
42
115
|
|
|
43
116
|
addr := fmt.Sprintf("%s:%d", host, port)
|
|
@@ -169,8 +242,56 @@ func (c *SSHClient) NewSFTPClient() (*SFTPClient, error) {
|
|
|
169
242
|
return &SFTPClient{Client: sftpClient}, nil
|
|
170
243
|
}
|
|
171
244
|
|
|
172
|
-
// Upload uploads a local file to remote
|
|
245
|
+
// Upload uploads a local file or directory to remote
|
|
173
246
|
func (s *SFTPClient) Upload(localPath, remotePath string) (int64, error) {
|
|
247
|
+
localInfo, err := os.Stat(localPath)
|
|
248
|
+
if err != nil {
|
|
249
|
+
return 0, fmt.Errorf("failed to stat local path: %w", err)
|
|
250
|
+
}
|
|
251
|
+
|
|
252
|
+
if !localInfo.IsDir() {
|
|
253
|
+
return s.uploadFile(localPath, remotePath)
|
|
254
|
+
}
|
|
255
|
+
|
|
256
|
+
var totalWritten int64
|
|
257
|
+
err = filepath.Walk(localPath, func(path string, info os.FileInfo, err error) error {
|
|
258
|
+
if err != nil {
|
|
259
|
+
return err
|
|
260
|
+
}
|
|
261
|
+
|
|
262
|
+
relPath, err := filepath.Rel(localPath, path)
|
|
263
|
+
if err != nil {
|
|
264
|
+
return err
|
|
265
|
+
}
|
|
266
|
+
|
|
267
|
+
targetPath := filepath.ToSlash(filepath.Join(remotePath, relPath))
|
|
268
|
+
|
|
269
|
+
if info.IsDir() {
|
|
270
|
+
s.Client.MkdirAll(targetPath)
|
|
271
|
+
return nil
|
|
272
|
+
}
|
|
273
|
+
|
|
274
|
+
// Sync logic
|
|
275
|
+
remoteInfo, err := s.Client.Stat(targetPath)
|
|
276
|
+
if err == nil && remoteInfo.Size() == info.Size() && remoteInfo.ModTime().Unix() == info.ModTime().Unix() {
|
|
277
|
+
return nil
|
|
278
|
+
}
|
|
279
|
+
|
|
280
|
+
written, err := s.uploadFile(path, targetPath)
|
|
281
|
+
if err != nil {
|
|
282
|
+
return err
|
|
283
|
+
}
|
|
284
|
+
|
|
285
|
+
totalWritten += written
|
|
286
|
+
s.Client.Chtimes(targetPath, info.ModTime(), info.ModTime())
|
|
287
|
+
|
|
288
|
+
return nil
|
|
289
|
+
})
|
|
290
|
+
|
|
291
|
+
return totalWritten, err
|
|
292
|
+
}
|
|
293
|
+
|
|
294
|
+
func (s *SFTPClient) uploadFile(localPath, remotePath string) (int64, error) {
|
|
174
295
|
localFile, err := os.Open(localPath)
|
|
175
296
|
if err != nil {
|
|
176
297
|
return 0, fmt.Errorf("failed to open local file: %w", err)
|
|
@@ -199,8 +320,62 @@ func (s *SFTPClient) Upload(localPath, remotePath string) (int64, error) {
|
|
|
199
320
|
return written, nil
|
|
200
321
|
}
|
|
201
322
|
|
|
202
|
-
// Download downloads a remote file to local
|
|
323
|
+
// Download downloads a remote file or directory to local
|
|
203
324
|
func (s *SFTPClient) Download(remotePath, localPath string) (int64, error) {
|
|
325
|
+
remoteInfo, err := s.Client.Stat(remotePath)
|
|
326
|
+
if err != nil {
|
|
327
|
+
return 0, fmt.Errorf("failed to stat remote path: %w", err)
|
|
328
|
+
}
|
|
329
|
+
|
|
330
|
+
if !remoteInfo.IsDir() {
|
|
331
|
+
return s.downloadFile(remotePath, localPath)
|
|
332
|
+
}
|
|
333
|
+
|
|
334
|
+
var totalWritten int64
|
|
335
|
+
walker := s.Client.Walk(remotePath)
|
|
336
|
+
for walker.Step() {
|
|
337
|
+
if walker.Err() != nil {
|
|
338
|
+
return totalWritten, walker.Err()
|
|
339
|
+
}
|
|
340
|
+
|
|
341
|
+
path := walker.Path()
|
|
342
|
+
info := walker.Stat()
|
|
343
|
+
|
|
344
|
+
relPath, err := filepath.Rel(remotePath, path)
|
|
345
|
+
if err != nil {
|
|
346
|
+
return totalWritten, err
|
|
347
|
+
}
|
|
348
|
+
|
|
349
|
+
if err := checkPathTraversal(relPath); err != nil {
|
|
350
|
+
return totalWritten, err
|
|
351
|
+
}
|
|
352
|
+
|
|
353
|
+
targetPath := filepath.Join(localPath, relPath)
|
|
354
|
+
|
|
355
|
+
if info.IsDir() {
|
|
356
|
+
os.MkdirAll(targetPath, info.Mode())
|
|
357
|
+
continue
|
|
358
|
+
}
|
|
359
|
+
|
|
360
|
+
// Sync logic
|
|
361
|
+
localInfo, err := os.Stat(targetPath)
|
|
362
|
+
if err == nil && localInfo.Size() == info.Size() && localInfo.ModTime().Unix() == info.ModTime().Unix() {
|
|
363
|
+
continue
|
|
364
|
+
}
|
|
365
|
+
|
|
366
|
+
written, err := s.downloadFile(path, targetPath)
|
|
367
|
+
if err != nil {
|
|
368
|
+
return totalWritten, err
|
|
369
|
+
}
|
|
370
|
+
|
|
371
|
+
totalWritten += written
|
|
372
|
+
os.Chtimes(targetPath, info.ModTime(), info.ModTime())
|
|
373
|
+
}
|
|
374
|
+
|
|
375
|
+
return totalWritten, nil
|
|
376
|
+
}
|
|
377
|
+
|
|
378
|
+
func (s *SFTPClient) downloadFile(remotePath, localPath string) (int64, error) {
|
|
204
379
|
remoteFile, err := s.Client.Open(remotePath)
|
|
205
380
|
if err != nil {
|
|
206
381
|
return 0, fmt.Errorf("failed to open remote file: %w", err)
|
|
@@ -380,6 +555,14 @@ func (pw *progressWriter) Write(p []byte) (n int, err error) {
|
|
|
380
555
|
return
|
|
381
556
|
}
|
|
382
557
|
|
|
558
|
+
// checkPathTraversal returns an error when relPath is not safe to join onto a
// destination directory.
//
// The check is filepath.IsLocal, so it rejects:
//   - absolute paths;
//   - empty paths;
//   - paths whose cleaned form starts with "..".
//
// It returns nil otherwise.
func checkPathTraversal(relPath string) error {
	if filepath.IsLocal(relPath) {
		return nil
	}
	return fmt.Errorf("path traversal detected: %s escapes destination", relPath)
}
|
|
565
|
+
|
|
383
566
|
// expandPath expands ~ to home directory
|
|
384
567
|
func expandPath(path string) string {
|
|
385
568
|
if len(path) > 0 && path[0] == '~' {
|
|
@@ -31,3 +31,46 @@ func TestNewAuthMethodsFromKeyPath(t *testing.T) {
|
|
|
31
31
|
t.Error("expected error for non-existent key")
|
|
32
32
|
}
|
|
33
33
|
}
|
|
34
|
+
|
|
35
|
+
func TestCheckPathTraversal(t *testing.T) {
|
|
36
|
+
tests := []struct {
|
|
37
|
+
name string
|
|
38
|
+
relPath string
|
|
39
|
+
wantErr bool
|
|
40
|
+
}{
|
|
41
|
+
{
|
|
42
|
+
name: "safe path within directory",
|
|
43
|
+
relPath: "safe_file.txt",
|
|
44
|
+
wantErr: false,
|
|
45
|
+
},
|
|
46
|
+
{
|
|
47
|
+
name: "safe path same as directory",
|
|
48
|
+
relPath: ".",
|
|
49
|
+
wantErr: false,
|
|
50
|
+
},
|
|
51
|
+
{
|
|
52
|
+
name: "path traversal attempt with ../",
|
|
53
|
+
relPath: "../etc/passwd",
|
|
54
|
+
wantErr: true,
|
|
55
|
+
},
|
|
56
|
+
{
|
|
57
|
+
name: "absolute path traversal",
|
|
58
|
+
relPath: "/etc/passwd",
|
|
59
|
+
wantErr: true,
|
|
60
|
+
},
|
|
61
|
+
{
|
|
62
|
+
name: "deep nested safe path",
|
|
63
|
+
relPath: "a/b/c/d.txt",
|
|
64
|
+
wantErr: false,
|
|
65
|
+
},
|
|
66
|
+
}
|
|
67
|
+
|
|
68
|
+
for _, tt := range tests {
|
|
69
|
+
t.Run(tt.name, func(t *testing.T) {
|
|
70
|
+
err := checkPathTraversal(tt.relPath)
|
|
71
|
+
if (err != nil) != tt.wantErr {
|
|
72
|
+
t.Errorf("checkPathTraversal() error = %v, wantErr %v", err, tt.wantErr)
|
|
73
|
+
}
|
|
74
|
+
})
|
|
75
|
+
}
|
|
76
|
+
}
|
|
@@ -6,6 +6,7 @@ import (
|
|
|
6
6
|
"log"
|
|
7
7
|
"net"
|
|
8
8
|
"sync"
|
|
9
|
+
"time"
|
|
9
10
|
|
|
10
11
|
"golang.org/x/crypto/ssh"
|
|
11
12
|
)
|
|
@@ -72,16 +73,32 @@ func (f *Forwarder) startLocalForward() {
|
|
|
72
73
|
f.mu.RUnlock()
|
|
73
74
|
return
|
|
74
75
|
}
|
|
76
|
+
listener := f.listener
|
|
75
77
|
f.mu.RUnlock()
|
|
76
78
|
|
|
77
|
-
|
|
79
|
+
if listener == nil {
|
|
80
|
+
time.Sleep(1 * time.Second)
|
|
81
|
+
continue
|
|
82
|
+
}
|
|
83
|
+
|
|
84
|
+
// Set deadline to prevent blocking forever
|
|
85
|
+
if tcpListener, ok := listener.(*net.TCPListener); ok {
|
|
86
|
+
tcpListener.SetDeadline(time.Now().Add(5 * time.Second))
|
|
87
|
+
}
|
|
88
|
+
|
|
89
|
+
conn, err := listener.Accept()
|
|
78
90
|
if err != nil {
|
|
91
|
+
// Check if it's a timeout
|
|
92
|
+
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
|
|
93
|
+
continue
|
|
94
|
+
}
|
|
79
95
|
f.mu.RLock()
|
|
80
96
|
closed := f.closed
|
|
81
97
|
f.mu.RUnlock()
|
|
82
98
|
if closed {
|
|
83
99
|
return
|
|
84
100
|
}
|
|
101
|
+
log.Printf("[portforward] Accept error: %v", err)
|
|
85
102
|
continue
|
|
86
103
|
}
|
|
87
104
|
|
|
@@ -105,15 +122,27 @@ func (f *Forwarder) handleLocalConnection(localConn net.Conn) {
|
|
|
105
122
|
|
|
106
123
|
log.Printf("[portforward] Connection accepted from %s", localConn.RemoteAddr())
|
|
107
124
|
|
|
108
|
-
|
|
109
|
-
|
|
125
|
+
// Use 127.0.0.1 instead of localhost to avoid IPv6 issues
|
|
126
|
+
remoteAddr := fmt.Sprintf("127.0.0.1:%d", f.RemotePort)
|
|
127
|
+
|
|
128
|
+
// Dial the remote endpoint *through* the SSH tunnel
|
|
129
|
+
f.mu.RLock()
|
|
130
|
+
client := f.sshClient
|
|
131
|
+
f.mu.RUnlock()
|
|
132
|
+
|
|
133
|
+
if client == nil {
|
|
134
|
+
log.Printf("[portforward] SSH client is nil, cannot connect to remote %s", remoteAddr)
|
|
135
|
+
return
|
|
136
|
+
}
|
|
137
|
+
|
|
138
|
+
remoteConn, err := client.Dial("tcp", remoteAddr)
|
|
110
139
|
if err != nil {
|
|
111
|
-
log.Printf("[portforward] Failed to connect to remote %s: %v", remoteAddr, err)
|
|
140
|
+
log.Printf("[portforward] Failed to connect to remote %s via SSH: %v", remoteAddr, err)
|
|
112
141
|
return
|
|
113
142
|
}
|
|
114
143
|
defer remoteConn.Close()
|
|
115
144
|
|
|
116
|
-
log.Printf("[portforward] Tunnel established to remote %s", remoteAddr)
|
|
145
|
+
log.Printf("[portforward] Tunnel established to remote %s via SSH", remoteAddr)
|
|
117
146
|
|
|
118
147
|
// Bidirectional copy
|
|
119
148
|
done := make(chan struct{})
|
|
@@ -135,38 +164,48 @@ func (f *Forwarder) startRemoteForward() {
|
|
|
135
164
|
go func() {
|
|
136
165
|
defer f.wg.Done()
|
|
137
166
|
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
}{Addr: addr, Port: uint32(f.RemotePort)})
|
|
145
|
-
|
|
146
|
-
ok, _, err := f.sshClient.SendRequest("tcpip-forward", true, payload)
|
|
147
|
-
if err != nil {
|
|
148
|
-
log.Printf("[portforward] Failed to send tcpip-forward request: %v", err)
|
|
167
|
+
f.mu.RLock()
|
|
168
|
+
client := f.sshClient
|
|
169
|
+
f.mu.RUnlock()
|
|
170
|
+
|
|
171
|
+
if client == nil {
|
|
172
|
+
log.Printf("[portforward] SSH client is nil, cannot start remote forward")
|
|
149
173
|
return
|
|
150
174
|
}
|
|
151
|
-
|
|
152
|
-
|
|
175
|
+
|
|
176
|
+
// Request the SSH server to listen on remote:remotePort and forward connections to us
|
|
177
|
+
remoteAddr := fmt.Sprintf("0.0.0.0:%d", f.RemotePort)
|
|
178
|
+
|
|
179
|
+
// Note: The ssh.Client.Listen internally sends the tcpip-forward request
|
|
180
|
+
listener, err := client.Listen("tcp", remoteAddr)
|
|
181
|
+
if err != nil {
|
|
182
|
+
log.Printf("[portforward] Failed to listen on remote port %d via SSH: %v", f.RemotePort, err)
|
|
153
183
|
return
|
|
154
184
|
}
|
|
155
185
|
|
|
186
|
+
f.mu.Lock()
|
|
187
|
+
f.listener = listener
|
|
188
|
+
f.mu.Unlock()
|
|
189
|
+
|
|
156
190
|
log.Printf("[portforward] Remote forward: SSH server listening on port %d", f.RemotePort)
|
|
157
191
|
|
|
158
|
-
//
|
|
159
|
-
// The SSH server will open a "forwarded-tcpip" channel for each connection
|
|
192
|
+
// Accept forwarded connections in a loop
|
|
160
193
|
for {
|
|
161
194
|
f.mu.RLock()
|
|
162
195
|
if f.closed {
|
|
163
196
|
f.mu.RUnlock()
|
|
164
197
|
return
|
|
165
198
|
}
|
|
199
|
+
currentListener := f.listener
|
|
166
200
|
f.mu.RUnlock()
|
|
167
201
|
|
|
202
|
+
if currentListener == nil {
|
|
203
|
+
time.Sleep(1 * time.Second)
|
|
204
|
+
continue
|
|
205
|
+
}
|
|
206
|
+
|
|
168
207
|
// Wait for a forwarded connection
|
|
169
|
-
|
|
208
|
+
remoteConn, err := currentListener.Accept()
|
|
170
209
|
if err != nil {
|
|
171
210
|
f.mu.RLock()
|
|
172
211
|
closed := f.closed
|
|
@@ -175,18 +214,17 @@ func (f *Forwarder) startRemoteForward() {
|
|
|
175
214
|
return
|
|
176
215
|
}
|
|
177
216
|
log.Printf("[portforward] Error accepting forwarded connection: %v", err)
|
|
178
|
-
|
|
217
|
+
return // Listener is broken or closed
|
|
179
218
|
}
|
|
180
219
|
|
|
181
|
-
|
|
182
|
-
go ssh.DiscardRequests(reqs)
|
|
220
|
+
log.Printf("[portforward] Accepted remote connection from %v", remoteConn.RemoteAddr())
|
|
183
221
|
|
|
184
|
-
// Connect to local port
|
|
185
|
-
localAddr := fmt.Sprintf("
|
|
186
|
-
localConn, err := net.
|
|
222
|
+
// Connect to local port to forward the traffic
|
|
223
|
+
localAddr := fmt.Sprintf("127.0.0.1:%d", f.LocalPort)
|
|
224
|
+
localConn, err := net.DialTimeout("tcp", localAddr, 5*time.Second)
|
|
187
225
|
if err != nil {
|
|
188
226
|
log.Printf("[portforward] Failed to connect to local port %d: %v", f.LocalPort, err)
|
|
189
|
-
|
|
227
|
+
remoteConn.Close()
|
|
190
228
|
continue
|
|
191
229
|
}
|
|
192
230
|
|
|
@@ -198,12 +236,20 @@ func (f *Forwarder) startRemoteForward() {
|
|
|
198
236
|
|
|
199
237
|
// Bidirectional copy
|
|
200
238
|
go func() {
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
|
|
239
|
+
defer func() {
|
|
240
|
+
remoteConn.Close()
|
|
241
|
+
localConn.Close()
|
|
242
|
+
f.mu.Lock()
|
|
243
|
+
delete(f.conns, localConn)
|
|
244
|
+
f.mu.Unlock()
|
|
245
|
+
}()
|
|
246
|
+
done := make(chan struct{})
|
|
247
|
+
go func() {
|
|
248
|
+
io.Copy(remoteConn, localConn)
|
|
249
|
+
close(done)
|
|
250
|
+
}()
|
|
251
|
+
io.Copy(localConn, remoteConn)
|
|
252
|
+
<-done
|
|
207
253
|
}()
|
|
208
254
|
}
|
|
209
255
|
}()
|
|
@@ -217,7 +263,6 @@ func (f *Forwarder) Close() {
|
|
|
217
263
|
return
|
|
218
264
|
}
|
|
219
265
|
f.closed = true
|
|
220
|
-
f.mu.Unlock()
|
|
221
266
|
|
|
222
267
|
if f.listener != nil {
|
|
223
268
|
f.listener.Close()
|
|
@@ -226,6 +271,7 @@ func (f *Forwarder) Close() {
|
|
|
226
271
|
for conn := range f.conns {
|
|
227
272
|
conn.Close()
|
|
228
273
|
}
|
|
274
|
+
f.mu.Unlock()
|
|
229
275
|
|
|
230
276
|
f.wg.Wait()
|
|
231
277
|
}
|
|
@@ -235,8 +281,6 @@ func (f *Forwarder) Restart(sshClient *ssh.Client) {
|
|
|
235
281
|
f.mu.Lock()
|
|
236
282
|
f.sshClient = sshClient
|
|
237
283
|
f.closed = false
|
|
238
|
-
f.conns = make(map[net.Conn]bool)
|
|
239
|
-
f.mu.Unlock()
|
|
240
284
|
|
|
241
285
|
if f.listener != nil {
|
|
242
286
|
f.listener.Close()
|
|
@@ -246,11 +290,21 @@ func (f *Forwarder) Restart(sshClient *ssh.Client) {
|
|
|
246
290
|
// Create new listener
|
|
247
291
|
addr := fmt.Sprintf("localhost:%d", f.LocalPort)
|
|
248
292
|
listener, err := net.Listen("tcp", addr)
|
|
249
|
-
if err
|
|
250
|
-
|
|
293
|
+
if err == nil {
|
|
294
|
+
f.listener = listener
|
|
295
|
+
} else {
|
|
296
|
+
f.listener = nil
|
|
297
|
+
log.Printf("[portforward] Failed to recreate listener on %s: %v", addr, err)
|
|
251
298
|
}
|
|
252
|
-
|
|
299
|
+
} else {
|
|
300
|
+
f.listener = nil // Will be re-initialized in startRemoteForward
|
|
301
|
+
}
|
|
302
|
+
|
|
303
|
+
for conn := range f.conns {
|
|
304
|
+
conn.Close()
|
|
253
305
|
}
|
|
306
|
+
f.conns = make(map[net.Conn]bool)
|
|
307
|
+
f.mu.Unlock()
|
|
254
308
|
|
|
255
309
|
f.Start()
|
|
256
310
|
}
|
|
@@ -72,6 +72,7 @@ type ReconnectParams struct {
|
|
|
72
72
|
type ExecParams struct {
|
|
73
73
|
SessionID string `json:"session_id,omitempty"`
|
|
74
74
|
Command string `json:"command"`
|
|
75
|
+
Timeout int `json:"timeout,omitempty"` // 超时时间(秒),0 表示无超时
|
|
75
76
|
}
|
|
76
77
|
|
|
77
78
|
type ForwardParams struct {
|
|
@@ -108,6 +109,7 @@ type SCPResult struct {
|
|
|
108
109
|
// SFTPParams represents SFTP command parameters
|
|
109
110
|
type SFTPParams struct {
|
|
110
111
|
SessionID string `json:"session_id,omitempty"`
|
|
111
|
-
Command string `json:"command"`
|
|
112
|
+
Command string `json:"command"`
|
|
113
|
+
Timeout int `json:"timeout,omitempty"` // 超时时间(秒),0 表示无超时 // "ls", "cd", "pwd", "mkdir", "rm", "rmdir"
|
|
112
114
|
Path string `json:"path"`
|
|
113
115
|
}
|