gssh-agent 1.0.4 → 1.0.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -6,6 +6,7 @@ import (
6
6
  "log"
7
7
  "net"
8
8
  "sync"
9
+ "time"
9
10
 
10
11
  "golang.org/x/crypto/ssh"
11
12
  )
@@ -72,16 +73,32 @@ func (f *Forwarder) startLocalForward() {
72
73
  f.mu.RUnlock()
73
74
  return
74
75
  }
76
+ listener := f.listener
75
77
  f.mu.RUnlock()
76
78
 
77
- conn, err := f.listener.Accept()
79
+ if listener == nil {
80
+ time.Sleep(1 * time.Second)
81
+ continue
82
+ }
83
+
84
+ // Set deadline to prevent blocking forever
85
+ if tcpListener, ok := listener.(*net.TCPListener); ok {
86
+ tcpListener.SetDeadline(time.Now().Add(5 * time.Second))
87
+ }
88
+
89
+ conn, err := listener.Accept()
78
90
  if err != nil {
91
+ // Check if it's a timeout
92
+ if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
93
+ continue
94
+ }
79
95
  f.mu.RLock()
80
96
  closed := f.closed
81
97
  f.mu.RUnlock()
82
98
  if closed {
83
99
  return
84
100
  }
101
+ log.Printf("[portforward] Accept error: %v", err)
85
102
  continue
86
103
  }
87
104
 
@@ -105,15 +122,27 @@ func (f *Forwarder) handleLocalConnection(localConn net.Conn) {
105
122
 
106
123
  log.Printf("[portforward] Connection accepted from %s", localConn.RemoteAddr())
107
124
 
108
- remoteAddr := fmt.Sprintf("localhost:%d", f.RemotePort)
109
- remoteConn, err := f.sshClient.Dial("tcp", remoteAddr)
125
+ // Use 127.0.0.1 instead of localhost to avoid IPv6 issues
126
+ remoteAddr := fmt.Sprintf("127.0.0.1:%d", f.RemotePort)
127
+
128
+ // Dial the remote endpoint *through* the SSH tunnel
129
+ f.mu.RLock()
130
+ client := f.sshClient
131
+ f.mu.RUnlock()
132
+
133
+ if client == nil {
134
+ log.Printf("[portforward] SSH client is nil, cannot connect to remote %s", remoteAddr)
135
+ return
136
+ }
137
+
138
+ remoteConn, err := client.Dial("tcp", remoteAddr)
110
139
  if err != nil {
111
- log.Printf("[portforward] Failed to connect to remote %s: %v", remoteAddr, err)
140
+ log.Printf("[portforward] Failed to connect to remote %s via SSH: %v", remoteAddr, err)
112
141
  return
113
142
  }
114
143
  defer remoteConn.Close()
115
144
 
116
- log.Printf("[portforward] Tunnel established to remote %s", remoteAddr)
145
+ log.Printf("[portforward] Tunnel established to remote %s via SSH", remoteAddr)
117
146
 
118
147
  // Bidirectional copy
119
148
  done := make(chan struct{})
@@ -135,38 +164,48 @@ func (f *Forwarder) startRemoteForward() {
135
164
  go func() {
136
165
  defer f.wg.Done()
137
166
 
138
- // Request the SSH server to listen on remote:remotePort and forward connections to us
139
- // Payload format: string (address) + uint32 (port)
140
- addr := fmt.Sprintf(":%d", f.RemotePort)
141
- payload := ssh.Marshal(struct {
142
- Addr string
143
- Port uint32
144
- }{Addr: addr, Port: uint32(f.RemotePort)})
145
-
146
- ok, _, err := f.sshClient.SendRequest("tcpip-forward", true, payload)
147
- if err != nil {
148
- log.Printf("[portforward] Failed to send tcpip-forward request: %v", err)
167
+ f.mu.RLock()
168
+ client := f.sshClient
169
+ f.mu.RUnlock()
170
+
171
+ if client == nil {
172
+ log.Printf("[portforward] SSH client is nil, cannot start remote forward")
149
173
  return
150
174
  }
151
- if !ok {
152
- log.Printf("[portforward] SSH server rejected tcpip-forward request for port %d", f.RemotePort)
175
+
176
+ // Request the SSH server to listen on remote:remotePort and forward connections to us
177
+ remoteAddr := fmt.Sprintf("0.0.0.0:%d", f.RemotePort)
178
+
179
+ // Note: The ssh.Client.Listen internally sends the tcpip-forward request
180
+ listener, err := client.Listen("tcp", remoteAddr)
181
+ if err != nil {
182
+ log.Printf("[portforward] Failed to listen on remote port %d via SSH: %v", f.RemotePort, err)
153
183
  return
154
184
  }
155
185
 
186
+ f.mu.Lock()
187
+ f.listener = listener
188
+ f.mu.Unlock()
189
+
156
190
  log.Printf("[portforward] Remote forward: SSH server listening on port %d", f.RemotePort)
157
191
 
158
- // Now we need to accept forwarded connections from the SSH server
159
- // The SSH server will open a "forwarded-tcpip" channel for each connection
192
+ // Accept forwarded connections in a loop
160
193
  for {
161
194
  f.mu.RLock()
162
195
  if f.closed {
163
196
  f.mu.RUnlock()
164
197
  return
165
198
  }
199
+ currentListener := f.listener
166
200
  f.mu.RUnlock()
167
201
 
202
+ if currentListener == nil {
203
+ time.Sleep(1 * time.Second)
204
+ continue
205
+ }
206
+
168
207
  // Wait for a forwarded connection
169
- ch, reqs, err := f.sshClient.OpenChannel("forwarded-tcpip", nil)
208
+ remoteConn, err := currentListener.Accept()
170
209
  if err != nil {
171
210
  f.mu.RLock()
172
211
  closed := f.closed
@@ -175,18 +214,17 @@ func (f *Forwarder) startRemoteForward() {
175
214
  return
176
215
  }
177
216
  log.Printf("[portforward] Error accepting forwarded connection: %v", err)
178
- continue
217
+ return // Listener is broken or closed
179
218
  }
180
219
 
181
- // Discard any requests
182
- go ssh.DiscardRequests(reqs)
220
+ log.Printf("[portforward] Accepted remote connection from %v", remoteConn.RemoteAddr())
183
221
 
184
- // Connect to local port
185
- localAddr := fmt.Sprintf("localhost:%d", f.LocalPort)
186
- localConn, err := net.Dial("tcp", localAddr)
222
+ // Connect to local port to forward the traffic
223
+ localAddr := fmt.Sprintf("127.0.0.1:%d", f.LocalPort)
224
+ localConn, err := net.DialTimeout("tcp", localAddr, 5*time.Second)
187
225
  if err != nil {
188
226
  log.Printf("[portforward] Failed to connect to local port %d: %v", f.LocalPort, err)
189
- ch.Close()
227
+ remoteConn.Close()
190
228
  continue
191
229
  }
192
230
 
@@ -198,12 +236,20 @@ func (f *Forwarder) startRemoteForward() {
198
236
 
199
237
  // Bidirectional copy
200
238
  go func() {
201
- io.Copy(ch, localConn)
202
- ch.Close()
203
- localConn.Close()
204
- }()
205
- go func() {
206
- io.Copy(localConn, ch)
239
+ defer func() {
240
+ remoteConn.Close()
241
+ localConn.Close()
242
+ f.mu.Lock()
243
+ delete(f.conns, localConn)
244
+ f.mu.Unlock()
245
+ }()
246
+ done := make(chan struct{})
247
+ go func() {
248
+ io.Copy(remoteConn, localConn)
249
+ close(done)
250
+ }()
251
+ io.Copy(localConn, remoteConn)
252
+ <-done
207
253
  }()
208
254
  }
209
255
  }()
@@ -217,7 +263,6 @@ func (f *Forwarder) Close() {
217
263
  return
218
264
  }
219
265
  f.closed = true
220
- f.mu.Unlock()
221
266
 
222
267
  if f.listener != nil {
223
268
  f.listener.Close()
@@ -226,6 +271,7 @@ func (f *Forwarder) Close() {
226
271
  for conn := range f.conns {
227
272
  conn.Close()
228
273
  }
274
+ f.mu.Unlock()
229
275
 
230
276
  f.wg.Wait()
231
277
  }
@@ -235,8 +281,6 @@ func (f *Forwarder) Restart(sshClient *ssh.Client) {
235
281
  f.mu.Lock()
236
282
  f.sshClient = sshClient
237
283
  f.closed = false
238
- f.conns = make(map[net.Conn]bool)
239
- f.mu.Unlock()
240
284
 
241
285
  if f.listener != nil {
242
286
  f.listener.Close()
@@ -246,11 +290,21 @@ func (f *Forwarder) Restart(sshClient *ssh.Client) {
246
290
  // Create new listener
247
291
  addr := fmt.Sprintf("localhost:%d", f.LocalPort)
248
292
  listener, err := net.Listen("tcp", addr)
249
- if err != nil {
250
- return
293
+ if err == nil {
294
+ f.listener = listener
295
+ } else {
296
+ f.listener = nil
297
+ log.Printf("[portforward] Failed to recreate listener on %s: %v", addr, err)
251
298
  }
252
- f.listener = listener
299
+ } else {
300
+ f.listener = nil // Will be re-initialized in startRemoteForward
301
+ }
302
+
303
+ for conn := range f.conns {
304
+ conn.Close()
253
305
  }
306
+ f.conns = make(map[net.Conn]bool)
307
+ f.mu.Unlock()
254
308
 
255
309
  f.Start()
256
310
  }
@@ -72,6 +72,7 @@ type ReconnectParams struct {
72
72
  type ExecParams struct {
73
73
  SessionID string `json:"session_id,omitempty"`
74
74
  Command string `json:"command"`
75
+ Timeout int `json:"timeout,omitempty"` // 超时时间(秒),0 表示无超时
75
76
  }
76
77
 
77
78
  type ForwardParams struct {
@@ -108,6 +109,7 @@ type SCPResult struct {
108
109
  // SFTPParams represents SFTP command parameters
109
110
  type SFTPParams struct {
110
111
  SessionID string `json:"session_id,omitempty"`
111
- Command string `json:"command"` // "ls", "cd", "pwd", "mkdir", "rm", "rmdir"
112
+ Command string `json:"command"`
113
+ Timeout int `json:"timeout,omitempty"` // 超时时间(秒),0 表示无超时 // "ls", "cd", "pwd", "mkdir", "rm", "rmdir"
112
114
  Path string `json:"path"`
113
115
  }
@@ -96,42 +96,65 @@ func needsShell(cmd string) bool {
96
96
  // Connect creates a new SSH session
97
97
  func (m *Manager) Connect(user, host string, port int, password, keyPath string) (*protocol.Session, error) {
98
98
  m.mu.Lock()
99
- defer m.mu.Unlock()
100
99
 
101
100
  // Check if session already exists
102
101
  for _, s := range m.sessions {
103
102
  if s.Host == host && s.User == user && s.Port == port {
104
- if s.Status == "connected" {
105
- return toProtocolSession(s), fmt.Errorf("session already exists")
103
+ s.mu.RLock()
104
+ status := s.Status
105
+ s.mu.RUnlock()
106
+
107
+ if status == "connected" || status == "connecting" || status == "reconnecting" {
108
+ m.defaultID = s.ID
109
+ m.mu.Unlock()
110
+ return toProtocolSession(s), nil
106
111
  }
107
- // Try to reconnect
112
+
113
+ // Mark as connecting to prevent concurrent connect attempts
114
+ s.mu.Lock()
115
+ s.Status = "connecting"
116
+ s.mu.Unlock()
117
+
118
+ // Switch default session
119
+ m.defaultID = s.ID
120
+
121
+ m.mu.Unlock()
122
+
123
+ // Perform network connect outside of Manager lock
108
124
  sshClient, err := client.Connect(user, host, port, password, keyPath)
109
125
  if err != nil {
126
+ s.mu.Lock()
127
+ s.Status = "offline"
128
+ s.mu.Unlock()
110
129
  return nil, err
111
130
  }
131
+
132
+ s.mu.Lock()
133
+ // Check if a disconnect was requested while we were connecting
134
+ if s.Status != "connecting" {
135
+ s.mu.Unlock()
136
+ sshClient.Close()
137
+ return nil, fmt.Errorf("session was disconnected while connecting")
138
+ }
112
139
  s.SSHClient = sshClient
113
140
  s.Status = "connected"
141
+ s.mu.Unlock()
142
+
114
143
  return toProtocolSession(s), nil
115
144
  }
116
145
  }
117
146
 
118
- // Create new session
119
- sshClient, err := client.Connect(user, host, port, password, keyPath)
120
- if err != nil {
121
- return nil, err
122
- }
123
-
124
- id := uuid.New().String()
147
+ // Create new session placeholder
148
+ id := uuid.New().String()[:8]
125
149
  ms := &ManagedSession{
126
- ID: id,
127
- Host: host,
128
- User: user,
129
- Port: port,
130
- Status: "connected",
131
- Password: password,
132
- KeyPath: keyPath,
133
- SSHClient: sshClient,
134
- Forwards: make(map[string]*portforward.Forwarder),
150
+ ID: id,
151
+ Host: host,
152
+ User: user,
153
+ Port: port,
154
+ Status: "connecting",
155
+ Password: password,
156
+ KeyPath: keyPath,
157
+ Forwards: make(map[string]*portforward.Forwarder),
135
158
  }
136
159
 
137
160
  m.sessions[id] = ms
@@ -140,6 +163,31 @@ func (m *Manager) Connect(user, host string, port int, password, keyPath string)
140
163
  if m.defaultID == "" {
141
164
  m.defaultID = id
142
165
  }
166
+ m.mu.Unlock()
167
+
168
+ // Perform network connect outside of Manager lock
169
+ sshClient, err := client.Connect(user, host, port, password, keyPath)
170
+ if err != nil {
171
+ // Clean up the placeholder since initialization failed completely
172
+ m.mu.Lock()
173
+ delete(m.sessions, id)
174
+ if m.defaultID == id {
175
+ m.defaultID = ""
176
+ }
177
+ m.mu.Unlock()
178
+ return nil, err
179
+ }
180
+
181
+ ms.mu.Lock()
182
+ // Check if a disconnect was requested while we were connecting
183
+ if ms.Status != "connecting" {
184
+ ms.mu.Unlock()
185
+ sshClient.Close()
186
+ return nil, fmt.Errorf("session was disconnected while connecting")
187
+ }
188
+ ms.SSHClient = sshClient
189
+ ms.Status = "connected"
190
+ ms.mu.Unlock()
143
191
 
144
192
  // Start reconnect monitor
145
193
  go m.monitorReconnect(ms)
@@ -149,28 +197,32 @@ func (m *Manager) Connect(user, host string, port int, password, keyPath string)
149
197
 
150
198
  // Disconnect closes a session
151
199
  func (m *Manager) Disconnect(sessionID string) error {
200
+ m.mu.Lock()
201
+
152
202
  // Use default session if not specified
153
203
  if sessionID == "" {
154
204
  sessionID = m.defaultID
155
205
  }
156
206
 
157
- m.mu.Lock()
158
- defer m.mu.Unlock()
159
-
160
207
  ms, ok := m.sessions[sessionID]
161
208
  if !ok {
209
+ m.mu.Unlock()
162
210
  return fmt.Errorf("session not found")
163
211
  }
164
212
 
165
- if ms.SSHClient != nil {
166
- ms.SSHClient.Close()
213
+ // Clear default ID when disconnecting
214
+ if m.defaultID == sessionID {
215
+ m.defaultID = ""
167
216
  }
217
+ m.mu.Unlock()
168
218
 
219
+ ms.mu.Lock()
220
+ sshClient := ms.SSHClient
169
221
  ms.Status = "disconnected"
222
+ ms.mu.Unlock()
170
223
 
171
- // Clear default ID when disconnecting
172
- if m.defaultID == sessionID {
173
- m.defaultID = ""
224
+ if sshClient != nil {
225
+ sshClient.Close()
174
226
  }
175
227
 
176
228
  return nil
@@ -191,21 +243,36 @@ func (m *Manager) Reconnect(sessionID string) (*protocol.Session, error) {
191
243
  return nil, fmt.Errorf("session not found")
192
244
  }
193
245
 
246
+ ms.mu.Lock()
247
+ if ms.Status == "connecting" || ms.Status == "reconnecting" {
248
+ ms.mu.Unlock()
249
+ return nil, fmt.Errorf("session is currently connecting")
250
+ }
251
+ existingClient := ms.SSHClient
252
+ ms.Status = "reconnecting"
253
+ ms.mu.Unlock()
254
+
194
255
  // Close existing connection
195
- if ms.SSHClient != nil {
196
- ms.SSHClient.Close()
256
+ if existingClient != nil {
257
+ existingClient.Close()
197
258
  }
198
259
 
199
260
  // Create new connection
200
261
  sshClient, err := client.Connect(ms.User, ms.Host, ms.Port, ms.Password, ms.KeyPath)
201
262
  if err != nil {
202
263
  ms.mu.Lock()
203
- ms.Status = "disconnected"
264
+ ms.Status = "offline"
204
265
  ms.mu.Unlock()
205
266
  return nil, err
206
267
  }
207
268
 
208
269
  ms.mu.Lock()
270
+ // Check if a disconnect was requested while we were reconnecting
271
+ if ms.Status != "reconnecting" {
272
+ ms.mu.Unlock()
273
+ sshClient.Close()
274
+ return nil, fmt.Errorf("session was disconnected while reconnecting")
275
+ }
209
276
  ms.SSHClient = sshClient
210
277
  ms.Status = "connected"
211
278
  ms.mu.Unlock()
@@ -224,7 +291,7 @@ func (m *Manager) Reconnect(sessionID string) (*protocol.Session, error) {
224
291
  }
225
292
 
226
293
  // Exec executes a command on a session
227
- func (m *Manager) Exec(sessionID, command string) (*protocol.ExecResult, error) {
294
+ func (m *Manager) Exec(sessionID, command string, timeout int) (*protocol.ExecResult, error) {
228
295
  m.mu.RLock()
229
296
  var ms *ManagedSession
230
297
  if sessionID != "" {
@@ -259,18 +326,51 @@ func (m *Manager) Exec(sessionID, command string) (*protocol.ExecResult, error)
259
326
  fullCmd = fmt.Sprintf("/bin/sh -c %q", command)
260
327
  }
261
328
 
262
- output, err := session.CombinedOutput(fullCmd)
329
+ // 执行命令,支持超时
330
+ var output []byte
331
+ var exitErr *ssh.ExitError
332
+ var ok bool
333
+
334
+ if timeout > 0 {
335
+ type result struct {
336
+ out []byte
337
+ err error
338
+ }
339
+ done := make(chan result, 1)
340
+ go func() {
341
+ out, err := session.CombinedOutput(fullCmd)
342
+ done <- result{out, err}
343
+ }()
344
+ select {
345
+ case res := <-done:
346
+ output = res.out
347
+ err = res.err
348
+ // 命令执行完成
349
+ case <-time.After(time.Duration(timeout) * time.Second):
350
+ session.Signal(ssh.SIGKILL)
351
+ return &protocol.ExecResult{
352
+ Stdout: "",
353
+ Stderr: "",
354
+ ExitCode: -1,
355
+ }, fmt.Errorf("command timed out after %d seconds", timeout)
356
+ }
357
+ } else {
358
+ output, err = session.CombinedOutput(fullCmd)
359
+ }
360
+
361
+ ms.mu.Lock()
362
+ ms.LastCmd = command
363
+ ms.mu.Unlock()
364
+
263
365
  if err != nil {
264
- exitErr, ok := err.(*ssh.ExitError)
366
+ exitErr, ok = err.(*ssh.ExitError)
265
367
  if ok {
266
- ms.LastCmd = command
267
368
  return &protocol.ExecResult{
268
369
  Stdout: string(output),
269
370
  Stderr: "",
270
371
  ExitCode: exitErr.ExitStatus(),
271
372
  }, nil
272
373
  }
273
- ms.LastCmd = command
274
374
  return &protocol.ExecResult{
275
375
  Stdout: string(output),
276
376
  Stderr: "",
@@ -278,7 +378,6 @@ func (m *Manager) Exec(sessionID, command string) (*protocol.ExecResult, error)
278
378
  }, nil
279
379
  }
280
380
 
281
- ms.LastCmd = command
282
381
  return &protocol.ExecResult{
283
382
  Stdout: string(output),
284
383
  Stderr: "",
@@ -346,7 +445,7 @@ func (m *Manager) AddForward(sessionID, forwardType string, localPort, remotePor
346
445
  return nil, err
347
446
  }
348
447
 
349
- id := uuid.New().String()
448
+ id := uuid.New().String()[:8]
350
449
  forwarder.ID = id
351
450
 
352
451
  m.forwardMu.Lock()
@@ -573,14 +672,30 @@ func (m *Manager) monitorReconnect(ms *ManagedSession) {
573
672
  return
574
673
  }
575
674
 
576
- if sshClient == nil || sshClient.Client == nil {
675
+ isAlive := false
676
+ if sshClient != nil && sshClient.Client != nil {
677
+ // Try to send a keepalive request
678
+ _, _, err := sshClient.Client.SendRequest("keepalive@gssh", true, nil)
679
+ if err == nil {
680
+ isAlive = true
681
+ }
682
+ }
683
+
684
+ if !isAlive {
577
685
  ms.mu.Lock()
578
686
  ms.Status = "reconnecting"
687
+ if sshClient != nil {
688
+ sshClient.Close()
689
+ ms.SSHClient = nil
690
+ }
579
691
  ms.mu.Unlock()
580
692
 
581
693
  // Try to reconnect
582
694
  newClient, err := client.Connect(ms.User, ms.Host, ms.Port, ms.Password, ms.KeyPath)
583
695
  if err != nil {
696
+ ms.mu.Lock()
697
+ ms.Status = "offline"
698
+ ms.mu.Unlock()
584
699
  continue
585
700
  }
586
701