gssh-agent 1.0.3 → 1.0.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -11,6 +11,7 @@ import (
11
11
 
12
12
  "github.com/pkg/sftp"
13
13
  "golang.org/x/crypto/ssh"
14
+ "golang.org/x/crypto/ssh/knownhosts"
14
15
  )
15
16
 
16
17
  // SSHClient encapsulates SSH connection logic
@@ -32,12 +33,84 @@ func (k *KeyboardInteractiveHandler) Challenge(name, instruction string, questio
32
33
  return answers, nil
33
34
  }
34
35
 
36
+ // getHostKeyCallback returns a host key callback that verifies the host key against the user's known_hosts file (<home>/.ssh/known_hosts; the directory and file are created here if missing).
37
+ // If the host is unknown, it implements Trust-On-First-Use (TOFU) by appending the new key to the known_hosts file; a key that conflicts with a recorded one is rejected by returning the verification error.
38
+ func getHostKeyCallback() (ssh.HostKeyCallback, error) {
39
+ homeDir, err := os.UserHomeDir()
40
+ if err != nil {
41
+ return nil, fmt.Errorf("could not get user home dir: %w", err)
42
+ }
43
+
44
+ knownHostsPath := filepath.Join(homeDir, ".ssh", "known_hosts")
45
+
46
+ // Ensure the .ssh directory exists (created with mode 0700 if missing)
47
+ if err := os.MkdirAll(filepath.Dir(knownHostsPath), 0700); err != nil {
48
+ return nil, fmt.Errorf("failed to create .ssh directory: %w", err)
49
+ }
50
+
51
+ // Create an empty known_hosts file (mode 0600) if it doesn't exist
52
+ if _, err := os.Stat(knownHostsPath); os.IsNotExist(err) {
53
+ f, err := os.OpenFile(knownHostsPath, os.O_CREATE|os.O_RDWR, 0600)
54
+ if err != nil {
55
+ return nil, fmt.Errorf("failed to create known_hosts file: %w", err)
56
+ }
57
+ f.Close()
58
+ }
59
+
60
+ cb, err := knownhosts.New(knownHostsPath)
61
+ if err != nil {
62
+ return nil, fmt.Errorf("failed to create knownhosts callback: %w", err)
63
+ }
64
+
65
+ return func(hostname string, remote net.Addr, key ssh.PublicKey) error {
66
+ err := cb(hostname, remote, key)
67
+ if err == nil {
68
+ return nil
69
+ }
70
+
71
+ keyErr, ok := err.(*knownhosts.KeyError) // NOTE(review): direct type assertion; a wrapped KeyError would not match — consider errors.As
72
+ if !ok {
73
+ return err
74
+ }
75
+
76
+ // If len(keyErr.Want) is 0, it means the key is completely unknown (not a mismatch).
77
+ // We implement Trust-On-First-Use (TOFU) by adding the new key to known_hosts.
78
+ if len(keyErr.Want) == 0 {
79
+ f, fErr := os.OpenFile(knownHostsPath, os.O_APPEND|os.O_WRONLY, 0600)
80
+ if fErr != nil {
81
+ return fmt.Errorf("failed to open known_hosts for appending: %w", fErr)
82
+ }
83
+ defer f.Close()
84
+
85
+ addresses := []string{knownhosts.Normalize(hostname)} // record the dialed hostname, plus the remote address below when its string form differs
86
+ if remoteString := remote.String(); remoteString != hostname {
87
+ addresses = append(addresses, knownhosts.Normalize(remoteString))
88
+ }
89
+
90
+ line := knownhosts.Line(addresses, key)
91
+ if _, wErr := f.WriteString(line + "\n"); wErr != nil {
92
+ return fmt.Errorf("failed to append key to known_hosts: %w", wErr)
93
+ }
94
+ return nil // key recorded; accept this connection
95
+ }
96
+
97
+ // It's a key mismatch (security risk: MITM or host key changed). Reject the connection.
98
+ return err
99
+ }, nil
100
+ }
101
+
35
102
  // NewSSHClient creates a new SSH client
36
103
  func NewSSHClient(user, host string, port int, authMethods ...ssh.AuthMethod) (*SSHClient, error) {
104
+ hostKeyCallback, err := getHostKeyCallback()
105
+ if err != nil {
106
+ return nil, fmt.Errorf("failed to get host key callback: %w", err)
107
+ }
108
+
37
109
  config := &ssh.ClientConfig{
38
110
  User: user,
39
111
  Auth: authMethods,
40
- HostKeyCallback: ssh.InsecureIgnoreHostKey(),
112
+ HostKeyCallback: hostKeyCallback,
113
+ Timeout: 10 * time.Second,
41
114
  }
42
115
 
43
116
  addr := fmt.Sprintf("%s:%d", host, port)
@@ -169,8 +242,56 @@ func (c *SSHClient) NewSFTPClient() (*SFTPClient, error) {
169
242
  return &SFTPClient{Client: sftpClient}, nil
170
243
  }
171
244
 
172
- // Upload uploads a local file to remote
245
+ // Upload uploads a local file or directory to remote. A single file is passed straight to uploadFile; a directory is walked, remote directories are created, and files that differ from the remote copy are uploaded. It returns the accumulated count reported by uploadFile.
173
246
  func (s *SFTPClient) Upload(localPath, remotePath string) (int64, error) {
247
+ localInfo, err := os.Stat(localPath)
248
+ if err != nil {
249
+ return 0, fmt.Errorf("failed to stat local path: %w", err)
250
+ }
251
+
252
+ if !localInfo.IsDir() {
253
+ return s.uploadFile(localPath, remotePath)
254
+ }
255
+
256
+ var totalWritten int64
257
+ err = filepath.Walk(localPath, func(path string, info os.FileInfo, err error) error {
258
+ if err != nil {
259
+ return err
260
+ }
261
+
262
+ relPath, err := filepath.Rel(localPath, path)
263
+ if err != nil {
264
+ return err
265
+ }
266
+
267
+ targetPath := filepath.ToSlash(filepath.Join(remotePath, relPath)) // convert OS-specific separators to forward slashes for the remote side
268
+
269
+ if info.IsDir() {
270
+ s.Client.MkdirAll(targetPath) // NOTE(review): returned error is ignored
271
+ return nil
272
+ }
273
+
274
+ // Sync logic: skip the upload when the remote file exists with the same size and the same modification time (compared in whole seconds)
275
+ remoteInfo, err := s.Client.Stat(targetPath)
276
+ if err == nil && remoteInfo.Size() == info.Size() && remoteInfo.ModTime().Unix() == info.ModTime().Unix() {
277
+ return nil
278
+ }
279
+
280
+ written, err := s.uploadFile(path, targetPath)
281
+ if err != nil {
282
+ return err
283
+ }
284
+
285
+ totalWritten += written
286
+ s.Client.Chtimes(targetPath, info.ModTime(), info.ModTime()) // stamp the remote file with the local mtime so the sync check above can match next time; NOTE(review): returned error is ignored
287
+
288
+ return nil
289
+ })
290
+
291
+ return totalWritten, err
292
+ }
293
+
294
+ func (s *SFTPClient) uploadFile(localPath, remotePath string) (int64, error) {
174
295
  localFile, err := os.Open(localPath)
175
296
  if err != nil {
176
297
  return 0, fmt.Errorf("failed to open local file: %w", err)
@@ -199,8 +320,62 @@ func (s *SFTPClient) Upload(localPath, remotePath string) (int64, error) {
199
320
  return written, nil
200
321
  }
201
322
 
202
- // Download downloads a remote file to local
323
+ // Download downloads a remote file or directory to local. A single file is passed straight to downloadFile; a directory is walked, entries whose relative path fails checkPathTraversal abort the download, and files that differ from the local copy are fetched. It returns the accumulated count reported by downloadFile.
203
324
  func (s *SFTPClient) Download(remotePath, localPath string) (int64, error) {
325
+ remoteInfo, err := s.Client.Stat(remotePath)
326
+ if err != nil {
327
+ return 0, fmt.Errorf("failed to stat remote path: %w", err)
328
+ }
329
+
330
+ if !remoteInfo.IsDir() {
331
+ return s.downloadFile(remotePath, localPath)
332
+ }
333
+
334
+ var totalWritten int64
335
+ walker := s.Client.Walk(remotePath)
336
+ for walker.Step() {
337
+ if walker.Err() != nil {
338
+ return totalWritten, walker.Err()
339
+ }
340
+
341
+ path := walker.Path()
342
+ info := walker.Stat()
343
+
344
+ relPath, err := filepath.Rel(remotePath, path) // NOTE(review): applies local OS path rules to remote paths — confirm this is intended on non-POSIX hosts
345
+ if err != nil {
346
+ return totalWritten, err
347
+ }
348
+
349
+ if err := checkPathTraversal(relPath); err != nil {
350
+ return totalWritten, err
351
+ }
352
+
353
+ targetPath := filepath.Join(localPath, relPath)
354
+
355
+ if info.IsDir() {
356
+ os.MkdirAll(targetPath, info.Mode()) // NOTE(review): returned error is ignored; the directory is created with the remote entry's mode
357
+ continue
358
+ }
359
+
360
+ // Sync logic: skip the download when the local file exists with the same size and the same modification time (compared in whole seconds)
361
+ localInfo, err := os.Stat(targetPath)
362
+ if err == nil && localInfo.Size() == info.Size() && localInfo.ModTime().Unix() == info.ModTime().Unix() {
363
+ continue
364
+ }
365
+
366
+ written, err := s.downloadFile(path, targetPath)
367
+ if err != nil {
368
+ return totalWritten, err
369
+ }
370
+
371
+ totalWritten += written
372
+ os.Chtimes(targetPath, info.ModTime(), info.ModTime()) // set the local file's atime and mtime to the remote mtime so the sync check above can match next time; NOTE(review): returned error is ignored
373
+ }
374
+
375
+ return totalWritten, nil
376
+ }
377
+
378
+ func (s *SFTPClient) downloadFile(remotePath, localPath string) (int64, error) {
204
379
  remoteFile, err := s.Client.Open(remotePath)
205
380
  if err != nil {
206
381
  return 0, fmt.Errorf("failed to open remote file: %w", err)
@@ -380,6 +555,14 @@ func (pw *progressWriter) Write(p []byte) (n int, err error) {
380
555
  return
381
556
  }
382
557
 
558
+ // checkPathTraversal ensures that relPath does not escape the destination directory: it returns an error when filepath.IsLocal(relPath) is false and nil otherwise.
559
+ func checkPathTraversal(relPath string) error {
560
+ if !filepath.IsLocal(relPath) {
561
+ return fmt.Errorf("path traversal detected: %s escapes destination", relPath)
562
+ }
563
+ return nil
564
+ }
565
+
383
566
  // expandPath expands ~ to home directory
384
567
  func expandPath(path string) string {
385
568
  if len(path) > 0 && path[0] == '~' {
@@ -31,3 +31,46 @@ func TestNewAuthMethodsFromKeyPath(t *testing.T) {
31
31
  t.Error("expected error for non-existent key")
32
32
  }
33
33
  }
34
+
35
+ func TestCheckPathTraversal(t *testing.T) { // table-driven: each case states whether checkPathTraversal must return an error for relPath
36
+ tests := []struct {
37
+ name string
38
+ relPath string
39
+ wantErr bool
40
+ }{
41
+ {
42
+ name: "safe path within directory",
43
+ relPath: "safe_file.txt",
44
+ wantErr: false,
45
+ },
46
+ {
47
+ name: "safe path same as directory",
48
+ relPath: ".",
49
+ wantErr: false,
50
+ },
51
+ {
52
+ name: "path traversal attempt with ../",
53
+ relPath: "../etc/passwd",
54
+ wantErr: true,
55
+ },
56
+ {
57
+ name: "absolute path traversal",
58
+ relPath: "/etc/passwd",
59
+ wantErr: true,
60
+ },
61
+ {
62
+ name: "deep nested safe path",
63
+ relPath: "a/b/c/d.txt",
64
+ wantErr: false,
65
+ },
66
+ }
67
+
68
+ for _, tt := range tests { // each case runs as a named subtest
69
+ t.Run(tt.name, func(t *testing.T) {
70
+ err := checkPathTraversal(tt.relPath)
71
+ if (err != nil) != tt.wantErr { // fail when the presence of an error differs from wantErr
72
+ t.Errorf("checkPathTraversal() error = %v, wantErr %v", err, tt.wantErr)
73
+ }
74
+ })
75
+ }
76
+ }
@@ -6,6 +6,7 @@ import (
6
6
  "log"
7
7
  "net"
8
8
  "sync"
9
+ "time"
9
10
 
10
11
  "golang.org/x/crypto/ssh"
11
12
  )
@@ -72,16 +73,32 @@ func (f *Forwarder) startLocalForward() {
72
73
  f.mu.RUnlock()
73
74
  return
74
75
  }
76
+ listener := f.listener
75
77
  f.mu.RUnlock()
76
78
 
77
- conn, err := f.listener.Accept()
79
+ if listener == nil {
80
+ time.Sleep(1 * time.Second)
81
+ continue
82
+ }
83
+
84
+ // Set deadline to prevent blocking forever
85
+ if tcpListener, ok := listener.(*net.TCPListener); ok {
86
+ tcpListener.SetDeadline(time.Now().Add(5 * time.Second))
87
+ }
88
+
89
+ conn, err := listener.Accept()
78
90
  if err != nil {
91
+ // Check if it's a timeout
92
+ if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
93
+ continue
94
+ }
79
95
  f.mu.RLock()
80
96
  closed := f.closed
81
97
  f.mu.RUnlock()
82
98
  if closed {
83
99
  return
84
100
  }
101
+ log.Printf("[portforward] Accept error: %v", err)
85
102
  continue
86
103
  }
87
104
 
@@ -105,15 +122,27 @@ func (f *Forwarder) handleLocalConnection(localConn net.Conn) {
105
122
 
106
123
  log.Printf("[portforward] Connection accepted from %s", localConn.RemoteAddr())
107
124
 
108
- remoteAddr := fmt.Sprintf("localhost:%d", f.RemotePort)
109
- remoteConn, err := f.sshClient.Dial("tcp", remoteAddr)
125
+ // Use 127.0.0.1 instead of localhost to avoid IPv6 issues
126
+ remoteAddr := fmt.Sprintf("127.0.0.1:%d", f.RemotePort)
127
+
128
+ // Dial the remote endpoint *through* the SSH tunnel
129
+ f.mu.RLock()
130
+ client := f.sshClient
131
+ f.mu.RUnlock()
132
+
133
+ if client == nil {
134
+ log.Printf("[portforward] SSH client is nil, cannot connect to remote %s", remoteAddr)
135
+ return
136
+ }
137
+
138
+ remoteConn, err := client.Dial("tcp", remoteAddr)
110
139
  if err != nil {
111
- log.Printf("[portforward] Failed to connect to remote %s: %v", remoteAddr, err)
140
+ log.Printf("[portforward] Failed to connect to remote %s via SSH: %v", remoteAddr, err)
112
141
  return
113
142
  }
114
143
  defer remoteConn.Close()
115
144
 
116
- log.Printf("[portforward] Tunnel established to remote %s", remoteAddr)
145
+ log.Printf("[portforward] Tunnel established to remote %s via SSH", remoteAddr)
117
146
 
118
147
  // Bidirectional copy
119
148
  done := make(chan struct{})
@@ -135,38 +164,48 @@ func (f *Forwarder) startRemoteForward() {
135
164
  go func() {
136
165
  defer f.wg.Done()
137
166
 
138
- // Request the SSH server to listen on remote:remotePort and forward connections to us
139
- // Payload format: string (address) + uint32 (port)
140
- addr := fmt.Sprintf(":%d", f.RemotePort)
141
- payload := ssh.Marshal(struct {
142
- Addr string
143
- Port uint32
144
- }{Addr: addr, Port: uint32(f.RemotePort)})
145
-
146
- ok, _, err := f.sshClient.SendRequest("tcpip-forward", true, payload)
147
- if err != nil {
148
- log.Printf("[portforward] Failed to send tcpip-forward request: %v", err)
167
+ f.mu.RLock()
168
+ client := f.sshClient
169
+ f.mu.RUnlock()
170
+
171
+ if client == nil {
172
+ log.Printf("[portforward] SSH client is nil, cannot start remote forward")
149
173
  return
150
174
  }
151
- if !ok {
152
- log.Printf("[portforward] SSH server rejected tcpip-forward request for port %d", f.RemotePort)
175
+
176
+ // Request the SSH server to listen on remote:remotePort and forward connections to us
177
+ remoteAddr := fmt.Sprintf("0.0.0.0:%d", f.RemotePort)
178
+
179
+ // Note: The ssh.Client.Listen internally sends the tcpip-forward request
180
+ listener, err := client.Listen("tcp", remoteAddr)
181
+ if err != nil {
182
+ log.Printf("[portforward] Failed to listen on remote port %d via SSH: %v", f.RemotePort, err)
153
183
  return
154
184
  }
155
185
 
186
+ f.mu.Lock()
187
+ f.listener = listener
188
+ f.mu.Unlock()
189
+
156
190
  log.Printf("[portforward] Remote forward: SSH server listening on port %d", f.RemotePort)
157
191
 
158
- // Now we need to accept forwarded connections from the SSH server
159
- // The SSH server will open a "forwarded-tcpip" channel for each connection
192
+ // Accept forwarded connections in a loop
160
193
  for {
161
194
  f.mu.RLock()
162
195
  if f.closed {
163
196
  f.mu.RUnlock()
164
197
  return
165
198
  }
199
+ currentListener := f.listener
166
200
  f.mu.RUnlock()
167
201
 
202
+ if currentListener == nil {
203
+ time.Sleep(1 * time.Second)
204
+ continue
205
+ }
206
+
168
207
  // Wait for a forwarded connection
169
- ch, reqs, err := f.sshClient.OpenChannel("forwarded-tcpip", nil)
208
+ remoteConn, err := currentListener.Accept()
170
209
  if err != nil {
171
210
  f.mu.RLock()
172
211
  closed := f.closed
@@ -175,18 +214,17 @@ func (f *Forwarder) startRemoteForward() {
175
214
  return
176
215
  }
177
216
  log.Printf("[portforward] Error accepting forwarded connection: %v", err)
178
- continue
217
+ return // Listener is broken or closed
179
218
  }
180
219
 
181
- // Discard any requests
182
- go ssh.DiscardRequests(reqs)
220
+ log.Printf("[portforward] Accepted remote connection from %v", remoteConn.RemoteAddr())
183
221
 
184
- // Connect to local port
185
- localAddr := fmt.Sprintf("localhost:%d", f.LocalPort)
186
- localConn, err := net.Dial("tcp", localAddr)
222
+ // Connect to local port to forward the traffic
223
+ localAddr := fmt.Sprintf("127.0.0.1:%d", f.LocalPort)
224
+ localConn, err := net.DialTimeout("tcp", localAddr, 5*time.Second)
187
225
  if err != nil {
188
226
  log.Printf("[portforward] Failed to connect to local port %d: %v", f.LocalPort, err)
189
- ch.Close()
227
+ remoteConn.Close()
190
228
  continue
191
229
  }
192
230
 
@@ -198,12 +236,20 @@ func (f *Forwarder) startRemoteForward() {
198
236
 
199
237
  // Bidirectional copy
200
238
  go func() {
201
- io.Copy(ch, localConn)
202
- ch.Close()
203
- localConn.Close()
204
- }()
205
- go func() {
206
- io.Copy(localConn, ch)
239
+ defer func() {
240
+ remoteConn.Close()
241
+ localConn.Close()
242
+ f.mu.Lock()
243
+ delete(f.conns, localConn)
244
+ f.mu.Unlock()
245
+ }()
246
+ done := make(chan struct{})
247
+ go func() {
248
+ io.Copy(remoteConn, localConn)
249
+ close(done)
250
+ }()
251
+ io.Copy(localConn, remoteConn)
252
+ <-done
207
253
  }()
208
254
  }
209
255
  }()
@@ -217,7 +263,6 @@ func (f *Forwarder) Close() {
217
263
  return
218
264
  }
219
265
  f.closed = true
220
- f.mu.Unlock()
221
266
 
222
267
  if f.listener != nil {
223
268
  f.listener.Close()
@@ -226,6 +271,7 @@ func (f *Forwarder) Close() {
226
271
  for conn := range f.conns {
227
272
  conn.Close()
228
273
  }
274
+ f.mu.Unlock()
229
275
 
230
276
  f.wg.Wait()
231
277
  }
@@ -235,8 +281,6 @@ func (f *Forwarder) Restart(sshClient *ssh.Client) {
235
281
  f.mu.Lock()
236
282
  f.sshClient = sshClient
237
283
  f.closed = false
238
- f.conns = make(map[net.Conn]bool)
239
- f.mu.Unlock()
240
284
 
241
285
  if f.listener != nil {
242
286
  f.listener.Close()
@@ -246,11 +290,21 @@ func (f *Forwarder) Restart(sshClient *ssh.Client) {
246
290
  // Create new listener
247
291
  addr := fmt.Sprintf("localhost:%d", f.LocalPort)
248
292
  listener, err := net.Listen("tcp", addr)
249
- if err != nil {
250
- return
293
+ if err == nil {
294
+ f.listener = listener
295
+ } else {
296
+ f.listener = nil
297
+ log.Printf("[portforward] Failed to recreate listener on %s: %v", addr, err)
251
298
  }
252
- f.listener = listener
299
+ } else {
300
+ f.listener = nil // Will be re-initialized in startRemoteForward
301
+ }
302
+
303
+ for conn := range f.conns {
304
+ conn.Close()
253
305
  }
306
+ f.conns = make(map[net.Conn]bool)
307
+ f.mu.Unlock()
254
308
 
255
309
  f.Start()
256
310
  }
@@ -72,6 +72,7 @@ type ReconnectParams struct {
72
72
  type ExecParams struct {
73
73
  SessionID string `json:"session_id,omitempty"`
74
74
  Command string `json:"command"`
75
+ Timeout int `json:"timeout,omitempty"` // Timeout in seconds; 0 means no timeout
75
76
  }
76
77
 
77
78
  type ForwardParams struct {
@@ -108,6 +109,7 @@ type SCPResult struct {
108
109
  // SFTPParams represents SFTP command parameters
109
110
  type SFTPParams struct {
110
111
  SessionID string `json:"session_id,omitempty"`
111
- Command string `json:"command"` // "ls", "cd", "pwd", "mkdir", "rm", "rmdir"
112
+ Command string `json:"command"` // "ls", "cd", "pwd", "mkdir", "rm", "rmdir"
113
+ Timeout int `json:"timeout,omitempty"` // Timeout in seconds; 0 means no timeout
112
114
  Path string `json:"path"`
113
115
  }