ssh-handler 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ssh_handler/core.py ADDED
@@ -0,0 +1,793 @@
1
+ """
2
+ The main SSH handler: command execution, interactive shells, sudo, full SFTP
3
+ file operations, optional SCP, jump-host chaining, and remote-OS awareness.
4
+
5
+ Designed for three consumers from one object:
6
+ * test-automation framework -> raise-on-error (default)
7
+ * standalone CLI script -> see ssh_handler.cli
8
+ * PyQt5 tool -> safe=True + log_callback (see pyqt_worker)
9
+ """
10
+
11
+ from __future__ import annotations
12
+
13
+ import os
14
+ import socket
15
+ import stat
16
+ import time
17
+ import logging
18
+ import posixpath
19
+ from typing import Callable, Optional, Sequence, Union
20
+
21
+ import paramiko
22
+
23
+ from .config import SSHConfig
24
+ from .credentials import Secret, mask
25
+ from .results import CommandResult, TransferResult, ShellResult, OperationResult
26
+ from .exceptions import (
27
+ SSHError,
28
+ SSHConnectionError,
29
+ SSHAuthenticationError,
30
+ SSHTimeoutError,
31
+ SSHCommandError,
32
+ SSHTransferError,
33
+ SSHNotConnectedError,
34
+ )
35
+
36
+ try: # optional SCP support
37
+ from scp import SCPClient
38
+ _HAS_SCP = True
39
+ except Exception: # pragma: no cover
40
+ SCPClient = None
41
+ _HAS_SCP = False
42
+
43
+
44
+ # --------------------------------------------------------------------------- #
45
+ # Interactive shell session
46
+ # --------------------------------------------------------------------------- #
47
class ShellSession:
    """
    A persistent interactive shell (one PTY channel) for devices/flows that
    need state between commands, prompts, or send/expect interaction.

    Received bytes go through one incremental decoder that lives as long as
    the session. A multi-byte character split across two ``recv`` chunks, or
    across two read calls, is therefore decoded correctly instead of turning
    into replacement characters.
    """

    def __init__(self, channel: paramiko.Channel, encoding: str = "utf-8"):
        self._chan = channel
        self.encoding = encoding
        # Built lazily by _decode(); rebuilt if ``encoding`` is reassigned.
        self._decoder = None
        self._decoder_encoding = None

    def _decode(self, chunk: bytes, final: bool = False) -> str:
        """Decode *chunk* statefully; undecodable bytes become U+FFFD."""
        if self._decoder is None or self._decoder_encoding != self.encoding:
            import codecs  # local import: only the shell session needs it

            self._decoder = codecs.getincrementaldecoder(self.encoding)(
                errors="replace")
            self._decoder_encoding = self.encoding
        return self._decoder.decode(chunk, final)

    def send(self, data: str) -> None:
        """Send *data* to the shell, appending a newline if it lacks one.

        The text is encoded with ``self.encoding``. It is sent with
        ``sendall`` because ``Channel.send`` may transmit only part of the
        buffer.
        """
        if not data.endswith("\n"):
            data += "\n"
        self._chan.sendall(data.encode(self.encoding))

    def read_until(self, marker: str, timeout: float = 30.0) -> ShellResult:
        """Read output until *marker* appears or *timeout* seconds elapse.

        Returns ``ShellResult(buffer, marker, False, elapsed)`` when the marker
        is seen. Otherwise it returns ``ShellResult(buffer, None, True,
        elapsed)``, and this includes the case where the channel reports EOF
        before the marker arrives.
        """
        start = time.monotonic()  # monotonic: immune to wall-clock changes
        buf = ""
        # Short per-recv timeout so the overall deadline is re-checked often.
        self._chan.settimeout(0.5)
        while time.monotonic() - start < timeout:
            try:
                chunk = self._chan.recv(65536)
            except socket.timeout:
                continue
            if not chunk:
                # EOF: flush any partial character the decoder was holding.
                buf += self._decode(b"", final=True)
                break
            # Only the new text (plus an overlap for a straddling marker)
            # needs scanning; earlier text was already searched.
            search_from = max(0, len(buf) - len(marker) + 1)
            buf += self._decode(chunk)
            if buf.find(marker, search_from) != -1:
                return ShellResult(buf, marker, False, time.monotonic() - start)
        return ShellResult(buf, None, True, time.monotonic() - start)

    def read_available(self) -> str:
        """Return whatever output is already buffered, without waiting for more."""
        parts = []
        self._chan.settimeout(0.2)
        while self._chan.recv_ready():
            try:
                parts.append(self._decode(self._chan.recv(65536)))
            except socket.timeout:
                break
        return "".join(parts)

    def close(self) -> None:
        """Close the channel; errors are ignored (best-effort teardown)."""
        try:
            self._chan.close()
        except Exception:
            pass

    def __enter__(self) -> "ShellSession":
        return self

    def __exit__(self, *exc) -> None:
        self.close()
100
+
101
+
102
+ # --------------------------------------------------------------------------- #
103
+ # The handler
104
+ # --------------------------------------------------------------------------- #
105
class SSHHandler:
    """High-level SSH + SFTP + SCP handler. See package README for recipes.

    Public operations take a keyword-only ``safe`` flag. When safe mode is
    active, either per call or through the constructor default, the operation
    returns an ``OperationResult``. That result wraps the value or the caught
    exception instead of raising.
    """

    # host_key_policy name -> paramiko missing-host-key policy class.
    _POLICIES = {
        "auto": paramiko.AutoAddPolicy,
        "reject": paramiko.RejectPolicy,
        "warn": paramiko.WarningPolicy,
    }

    # Bytes requested per recv() while draining an exec channel.
    _RECV_CHUNK = 65536
    # Seconds to sleep between polls when an exec channel has nothing to read.
    _POLL_INTERVAL = 0.01

    def __init__(
        self,
        config: SSHConfig,
        *,
        log_callback: Optional[Callable[[str], None]] = None,
        logger: Optional[logging.Logger] = None,
        safe: bool = False,
    ):
        """Create a handler for *config*. No connection is opened here.

        :param config: connection settings for the target host.
        :param log_callback: optional sink that receives every masked log line,
            prefixed with its level name. Exceptions it raises are ignored.
        :param logger: logger to use. If omitted, a per-host stream logger is
            built.
        :param safe: default for the per-call ``safe`` flag.
        """
        self.config = config
        self._safe_default = safe
        self._log_callback = log_callback
        self.log = logger or self._build_logger()

        self._client: Optional[paramiko.SSHClient] = None
        self._sftp: Optional[paramiko.SFTPClient] = None
        self._jump: Optional["SSHHandler"] = None
        # "auto" defers OS detection to the first detect_os() call.
        self._remote_os: Optional[str] = (
            None if config.remote_os == "auto" else config.remote_os
        )

    # ------------------------------------------------------------------ #
    # Logging (secret-aware)
    # ------------------------------------------------------------------ #
    def _build_logger(self) -> logging.Logger:
        """Return the ``ssh_handler.<host>`` logger.

        A stream handler is attached, and the level set to INFO, only if the
        logger has no handlers yet.
        """
        logger = logging.getLogger(f"ssh_handler.{self.config.host}")
        if not logger.handlers:
            handler = logging.StreamHandler()
            handler.setFormatter(
                logging.Formatter("%(asctime)s [%(levelname)s] %(name)s: %(message)s")
            )
            logger.addHandler(handler)
            logger.setLevel(logging.INFO)
        return logger

    def _secrets(self) -> list:
        """Return the configured passwords/passphrases (target and jump) that are set."""
        s = [self.config.password, self.config.key_passphrase]
        if self.config.jump_host:
            s += [self.config.jump_host.password, self.config.jump_host.key_passphrase]
        return [x for x in s if x]

    def _emit(self, level: int, msg: str) -> None:
        """Log *msg* with secrets masked, and forward it to the callback if any."""
        safe_msg = mask(msg, *self._secrets())
        self.log.log(level, safe_msg)
        if self._log_callback:
            try:
                self._log_callback(f"[{logging.getLevelName(level)}] {safe_msg}")
            except Exception:
                # A broken UI sink must never break the SSH operation itself.
                pass

    # ------------------------------------------------------------------ #
    # Safe-mode wrapper
    # ------------------------------------------------------------------ #
    def _guard(self, action: str, fn, *args, safe: Optional[bool] = None, **kwargs):
        """Call ``fn(*args, **kwargs)``, honouring safe mode.

        Outside safe mode, the raw return value is returned and exceptions
        propagate. In safe mode, the result is ``OperationResult(True, action,
        value=...)``, or ``OperationResult(False, action, error=exc)`` after
        the failure is logged.
        """
        use_safe = self._safe_default if safe is None else safe
        if not use_safe:
            return fn(*args, **kwargs)
        try:
            return OperationResult(True, action, value=fn(*args, **kwargs))
        except Exception as exc:
            self._emit(logging.ERROR, f"{action} failed: {exc}")
            return OperationResult(False, action, error=exc)

    # ------------------------------------------------------------------ #
    # State
    # ------------------------------------------------------------------ #
    @property
    def is_connected(self) -> bool:
        """True when a client exists and its transport is active."""
        if self._client is None:
            return False
        t = self._client.get_transport()
        return bool(t and t.is_active())

    @property
    def client(self) -> paramiko.SSHClient:
        """The underlying paramiko SSHClient (for anything not wrapped here)."""
        self._ensure_connected()
        return self._client

    @property
    def transport(self) -> paramiko.Transport:
        """The live paramiko Transport, reconnecting first if needed."""
        return self.client.get_transport()

    def _ensure_connected(self) -> None:
        """Reconnect if the session dropped and auto_reconnect is on.

        :raises SSHNotConnectedError: when not connected and auto_reconnect
            is disabled.
        """
        if self.is_connected:
            return
        if self.config.auto_reconnect:
            self._emit(logging.WARNING, "Session not active; reconnecting…")
            self._connect_with_retries()
        else:
            raise SSHNotConnectedError("Not connected (auto_reconnect disabled).")

    # ------------------------------------------------------------------ #
    # Connect / disconnect
    # ------------------------------------------------------------------ #
    def connect(self, *, safe: Optional[bool] = None):
        """Connect, retrying per config. Returns self, or an OperationResult in safe mode."""
        return self._guard("connect", self._connect_with_retries, safe=safe)

    def _make_client(self) -> paramiko.SSHClient:
        """Build an SSHClient with host keys loaded and the configured policy set."""
        client = paramiko.SSHClient()
        if self.config.known_hosts:
            try:
                client.load_host_keys(os.path.expanduser(self.config.known_hosts))
            except FileNotFoundError:
                self._emit(logging.WARNING,
                           f"known_hosts not found: {self.config.known_hosts}")
        else:
            client.load_system_host_keys()
        policy = self._POLICIES.get(self.config.host_key_policy)
        if policy is None:
            # Fall back to AutoAdd as before. A misspelt name is reported,
            # because otherwise it would silently accept any host key.
            if self.config.host_key_policy:
                self._emit(logging.WARNING,
                           f"Unknown host_key_policy "
                           f"{self.config.host_key_policy!r}; using 'auto'.")
            policy = paramiko.AutoAddPolicy
        client.set_missing_host_key_policy(policy())
        return client

    def _open_tunnel_channel(self):
        """Open a direct-tcpip channel to the target over the connected jump host."""
        return self._jump.transport.open_channel(
            "direct-tcpip",
            dest_addr=(self.config.host, self.config.port),
            src_addr=("127.0.0.1", 0),
        )

    def _open_jump_channel(self):
        """Connect the jump host and return a direct-tcpip channel to target."""
        if self._jump is not None:
            # Drop the jump connection left behind by a failed earlier attempt.
            self._jump._release()
        self._jump = SSHHandler(
            self.config.jump_host,
            log_callback=self._log_callback,
            logger=self.log,
        )
        self._jump._connect_with_retries()
        self._emit(logging.INFO, f"Tunnelling through jump host "
                                 f"{self.config.jump_host.host}.")
        return self._open_tunnel_channel()

    def _build_connect_kwargs(self, *, empty_password: bool, sock) -> dict:
        """Build ``SSHClient.connect`` keyword arguments for one auth strategy.

        Branches, in order:

        * ``passwordless`` forces agent and key lookup on.
        * ``empty_password`` sends ``""`` with agent/key lookup off.
        * Otherwise a configured password is sent. With ``fast_auth`` and no
          explicit key files, agent/key probing is also disabled.

        Explicit key files and a passphrase are added whenever configured.
        """
        cfg = self.config
        pw = cfg.password.reveal() if isinstance(cfg.password, Secret) else cfg.password
        passphrase = (cfg.key_passphrase.reveal()
                      if isinstance(cfg.key_passphrase, Secret) else cfg.key_passphrase)
        has_password = pw is not None and pw != ""

        kwargs: dict = {
            "hostname": cfg.host,
            "port": cfg.port,
            "username": cfg.auth_username,
            "timeout": cfg.connect_timeout,
            "auth_timeout": cfg.auth_timeout,
            "banner_timeout": cfg.banner_timeout,
            "compress": cfg.compress,
            "allow_agent": cfg.allow_agent,
            "look_for_keys": cfg.look_for_keys,
        }
        if sock is not None:
            kwargs["sock"] = sock

        keys = cfg.normalized_key_files()
        if keys:
            kwargs["key_filename"] = keys
        if passphrase:
            kwargs["passphrase"] = passphrase

        if cfg.passwordless:
            kwargs["allow_agent"] = True
            kwargs["look_for_keys"] = True
        elif empty_password:
            kwargs["password"] = ""
            kwargs["look_for_keys"] = False
            kwargs["allow_agent"] = False
        elif has_password:
            kwargs["password"] = pw
            # Performance: with a password and no explicit key, skip the slow
            # agent/key probing that otherwise runs first (and can trip the
            # server's MaxAuthTries -> "Too many authentication failures").
            if cfg.fast_auth and not keys:
                kwargs["allow_agent"] = False
                kwargs["look_for_keys"] = False
        return kwargs

    def _attempt_strategies(self, sock) -> None:
        """Try each authentication strategy on a fresh SSHClient.

        The primary strategy always runs. An empty-password attempt follows
        when ``allow_empty_password`` is set and ``passwordless`` is not.
        An authentication rejection moves on to the next strategy. Any other
        SSH or socket error is re-raised immediately so the retry loop can
        handle it.

        :param sock: tunnel channel from the jump host, or ``None``.
        :raises SSHAuthenticationError: when every strategy is rejected.
        """
        cfg = self.config
        last_auth: Optional[Exception] = None
        strategies = [False]
        if cfg.allow_empty_password and not cfg.passwordless:
            strategies.append(True)

        for index, empty in enumerate(strategies):
            if index and sock is not None:
                # The rejected attempt's tunnel channel was closed with its
                # client, so a fallback strategy needs a new one.
                sock = self._open_tunnel_channel()
            client = self._make_client()
            kwargs = self._build_connect_kwargs(empty_password=empty, sock=sock)
            label = "empty-password" if empty else "primary"
            try:
                self._emit(logging.DEBUG, f"Auth attempt ({label}) as "
                                          f"{cfg.auth_username}…")
                client.connect(**kwargs)
                self._client = client
                self._post_connect()
                self._emit(logging.INFO,
                           f"Connected to {cfg.auth_username}@{cfg.host}:{cfg.port} "
                           f"({label}).")
                return
            except paramiko.AuthenticationException as exc:
                last_auth = exc
                client.close()
                if sock is not None:
                    sock.close()  # idempotent; guarantees the channel is freed
                self._emit(logging.DEBUG, f"Auth strategy '{label}' rejected.")
            except (paramiko.SSHException, OSError):
                client.close()
                raise
        raise SSHAuthenticationError(
            f"All authentication strategies failed for "
            f"{cfg.auth_username}@{cfg.host}: {last_auth}"
        )

    def _release(self) -> None:
        """Silently close SFTP, client and jump host, and clear their references.

        This is best-effort teardown of a possibly dead session. Unlike
        disconnect(), it does not log.
        """
        sftp, client, jump = self._sftp, self._client, self._jump
        self._sftp = None
        self._client = None
        self._jump = None
        for resource in (sftp, client):
            if resource is not None:
                try:
                    resource.close()
                except Exception:
                    pass
        if jump is not None:
            jump._release()

    def _connect_with_retries(self) -> "SSHHandler":
        """(Re)connect with retries and return self.

        Any previous session is closed first, so a reconnect does not leak
        the old client, SFTP channel or jump connection. On failure,
        everything opened during the attempts is closed before the exception
        propagates.

        :raises SSHAuthenticationError: credentials rejected (not retried).
        :raises SSHTimeoutError: the last attempt timed out.
        :raises SSHConnectionError: any other connection failure.
        """
        self._release()
        try:
            return self._retry_connect()
        except BaseException:
            self._release()
            raise

    def _retry_connect(self) -> "SSHHandler":
        """The retry loop behind _connect_with_retries (at least one attempt)."""
        cfg = self.config
        total = max(1, cfg.max_retries)
        delay, last_exc = cfg.retry_backoff, None
        for attempt in range(1, total + 1):
            try:
                sock = self._open_jump_channel() if cfg.jump_host else None
                self._attempt_strategies(sock)
                return self
            except SSHAuthenticationError:
                raise  # retrying won't help
            except paramiko.BadHostKeyException as exc:
                # The presented key will not change between attempts.
                raise SSHConnectionError(
                    f"Host key verification failed for {cfg.host}:{cfg.port}: "
                    f"{exc}") from exc
            except socket.timeout as exc:
                last_exc = exc
                self._emit(logging.WARNING,
                           f"Connect timeout (attempt {attempt}/{total}).")
            except (OSError, paramiko.SSHException) as exc:
                last_exc = exc
                self._emit(logging.WARNING,
                           f"Connect error (attempt {attempt}/{total}): {exc}")
            if attempt < total:
                time.sleep(delay)
                delay *= cfg.retry_backoff

        if isinstance(last_exc, socket.timeout):
            raise SSHTimeoutError(f"Timed out connecting to {cfg.host}:{cfg.port}") \
                from last_exc
        raise SSHConnectionError(f"Could not connect to {cfg.host}:{cfg.port}: "
                                 f"{last_exc}") from last_exc

    def _post_connect(self) -> None:
        """Apply keepalive and reset the cached SFTP client after a new connection."""
        t = self._client.get_transport()
        if t and self.config.keepalive_interval > 0:
            t.set_keepalive(self.config.keepalive_interval)
        self._sftp = None

    def disconnect(self) -> None:
        """Close SFTP, the SSH client and any jump host. Close errors are ignored."""
        if self._sftp is not None:
            try:
                self._sftp.close()
            except Exception:
                pass
            self._sftp = None
        if self._client is not None:
            try:
                self._client.close()
            except Exception:
                pass
            self._emit(logging.INFO, f"Disconnected from {self.config.host}.")
            self._client = None
        if self._jump is not None:
            self._jump.disconnect()
            self._jump = None

    close = disconnect

    # ------------------------------------------------------------------ #
    # Remote OS awareness
    # ------------------------------------------------------------------ #
    def detect_os(self) -> str:
        """Return 'windows' or 'linux' (cached). Runs one probe command."""
        if self._remote_os:
            return self._remote_os
        self._ensure_connected()
        probe = self._run("uname -s || ver", timeout=10, check=False,
                          get_pty=False, environment=None, encoding="utf-8")
        text = (probe.stdout + probe.stderr).lower()
        # Anything not recognisably Windows (including e.g. Darwin) -> "linux".
        self._remote_os = "windows" if ("windows" in text or "microsoft" in text) \
            else "linux"
        self._emit(logging.DEBUG, f"Detected remote OS: {self._remote_os}")
        return self._remote_os

    @property
    def is_windows(self) -> bool:
        return self.detect_os() == "windows"

    # ------------------------------------------------------------------ #
    # Command execution
    # ------------------------------------------------------------------ #
    def run(self, command: str, *, timeout: Optional[float] = None,
            check: bool = False, get_pty: bool = False,
            environment: Optional[dict] = None, encoding: str = "utf-8",
            safe: Optional[bool] = None):
        """Execute a command. Returns CommandResult (or OperationResult if safe)."""
        return self._guard("run", self._run, command, timeout=timeout, check=check,
                           get_pty=get_pty, environment=environment,
                           encoding=encoding, safe=safe)

    def _drain_channel(self, chan, idle_timeout):
        """Read an exec channel's stdout/stderr to completion, then its exit status.

        Both streams are read as data arrives. The exit status is collected
        only after that, because calling ``recv_exit_status()`` first can hang
        forever once the remote output exceeds the channel window. That call
        also has no timeout of its own.

        :param chan: the paramiko Channel of an exec request.
        :param idle_timeout: seconds without any output before giving up, or
            ``None`` to wait indefinitely.
        :returns: ``(exit_code, stdout_bytes, stderr_bytes)``.
        :raises socket.timeout: nothing arrived for *idle_timeout* seconds.
        """
        out, err = bytearray(), bytearray()
        last_activity = time.monotonic()
        while True:
            got_data = False
            if chan.recv_ready():
                out += chan.recv(self._RECV_CHUNK)
                got_data = True
            if chan.recv_stderr_ready():
                err += chan.recv_stderr(self._RECV_CHUNK)
                got_data = True
            if got_data:
                last_activity = time.monotonic()
                continue
            eof = getattr(chan, "eof_received", False) or chan.closed
            if chan.exit_status_ready() and eof:
                # EOF means no further data will arrive. Collect anything
                # buffered between the readiness checks above and here.
                while chan.recv_ready():
                    out += chan.recv(self._RECV_CHUNK)
                while chan.recv_stderr_ready():
                    err += chan.recv_stderr(self._RECV_CHUNK)
                break
            if (idle_timeout is not None
                    and time.monotonic() - last_activity >= idle_timeout):
                raise socket.timeout(f"no output within {idle_timeout}s")
            time.sleep(self._POLL_INTERVAL)
        return chan.recv_exit_status(), bytes(out), bytes(err)

    def _run(self, command, *, timeout, check, get_pty, environment, encoding):
        """Execute *command* and return a CommandResult (the unguarded core of run()).

        *timeout* falls back to ``config.command_timeout`` and acts as an idle
        timeout on output. The exec channel is always closed afterwards,
        including when the command times out.

        :raises SSHTimeoutError: no output or completion within the timeout.
        :raises SSHCommandError: the exec request failed, or *check* is set
            and the exit code is not OK.
        """
        self._ensure_connected()
        eff_timeout = timeout if timeout is not None else self.config.command_timeout
        self._emit(logging.INFO, f"$ {command}")
        start = time.monotonic()
        chan = None
        try:
            _stdin, stdout, _stderr = self._client.exec_command(
                command, timeout=eff_timeout, get_pty=get_pty, environment=environment)
            chan = stdout.channel
            exit_code, out_raw, err_raw = self._drain_channel(chan, eff_timeout)
        except socket.timeout as exc:
            raise SSHTimeoutError(
                f"Command timed out after {eff_timeout}s: {command!r}") from exc
        except paramiko.SSHException as exc:
            res = CommandResult(command, -1, "", str(exc), time.monotonic() - start,
                                host=self.config.host)
            raise SSHCommandError(command, res) from exc
        finally:
            if chan is not None:
                chan.close()  # no-op if already closed; frees it on timeout
        out = out_raw.decode(encoding, errors="replace")
        err = err_raw.decode(encoding, errors="replace")

        result = CommandResult(command, exit_code, out, err, time.monotonic() - start,
                               host=self.config.host)
        lvl = logging.INFO if result.ok else logging.WARNING
        self._emit(lvl, f"exit={result.exit_code} ({result.duration:.2f}s)")
        if check and not result.ok:
            raise SSHCommandError(command, result)
        return result

    def run_many(self, commands: Sequence[str], *, stop_on_error: bool = True,
                 **kwargs) -> list[CommandResult]:
        """Run *commands* sequentially and return their results in order.

        Only ``timeout``, ``get_pty``, ``environment`` and ``encoding`` are
        read from *kwargs*. With *stop_on_error*, the batch stops after the
        first non-OK result, which is still included in the returned list.
        """
        results = []
        for cmd in commands:
            res = self._run(cmd, timeout=kwargs.get("timeout"), check=False,
                            get_pty=kwargs.get("get_pty", False),
                            environment=kwargs.get("environment"),
                            encoding=kwargs.get("encoding", "utf-8"))
            results.append(res)
            if stop_on_error and not res.ok:
                self._emit(logging.WARNING,
                           f"Stopping batch: {cmd!r} exited {res.exit_code}.")
                break
        return results

    def sudo(self, command: str, password: Optional[Union[str, Secret]] = None,
             *, timeout: Optional[float] = None, check: bool = False,
             safe: Optional[bool] = None):
        """Run a command via ``sudo -S`` feeding the password on stdin.

        The password defaults to ``config.password``. The command runs in a
        PTY, so stderr arrives merged into stdout.

        *command* is appended verbatim after ``sudo -S -p ''``. Shell
        operators in it (``&&``, ``;``, pipes) are therefore interpreted by the
        remote shell, and only the first command runs under sudo.

        Because the password is written to a PTY, it may be echoed back.
        Occurrences of it are masked in the returned stdout/stderr.
        *timeout* falls back to ``config.command_timeout``, as for run().
        """
        pw = password if password is not None else self.config.password
        raw = pw.reveal() if isinstance(pw, Secret) else (pw or "")

        def _do():
            self._ensure_connected()
            eff_timeout = timeout if timeout is not None else self.config.command_timeout
            full = f"sudo -S -p '' {command}"
            self._emit(logging.INFO, f"$ sudo {command}")
            start = time.monotonic()
            chan = None
            try:
                stdin, stdout, _stderr = self._client.exec_command(
                    full, timeout=eff_timeout, get_pty=True)
                chan = stdout.channel
                if raw:
                    stdin.write(raw + "\n")
                    stdin.flush()
                exit_code, out_raw, err_raw = self._drain_channel(chan, eff_timeout)
            except socket.timeout as exc:
                raise SSHTimeoutError(
                    f"Command timed out after {eff_timeout}s: "
                    f"{'sudo ' + command!r}") from exc
            except paramiko.SSHException as exc:
                res = CommandResult(f"sudo {command}", -1, "", str(exc),
                                    time.monotonic() - start, host=self.config.host)
                raise SSHCommandError(command, res) from exc
            finally:
                if chan is not None:
                    chan.close()
            out = out_raw.decode("utf-8", errors="replace")
            err = err_raw.decode("utf-8", errors="replace")
            if raw:
                out = mask(out, pw)
                err = mask(err, pw)
            result = CommandResult(f"sudo {command}", exit_code, out, err,
                                   time.monotonic() - start, host=self.config.host)
            if check and not result.ok:
                raise SSHCommandError(command, result)
            return result

        return self._guard("sudo", _do, safe=safe)

    # ------------------------------------------------------------------ #
    # Interactive shell
    # ------------------------------------------------------------------ #
    def open_shell(self, *, term: str = "xterm", width: int = 200,
                   height: int = 50, safe: Optional[bool] = None):
        """Open a persistent interactive ShellSession (PTY)."""
        def _do():
            self._ensure_connected()
            chan = self._client.invoke_shell(term=term, width=width, height=height)
            return ShellSession(chan)
        return self._guard("open_shell", _do, safe=safe)

    # ------------------------------------------------------------------ #
    # SFTP — full operation surface
    # ------------------------------------------------------------------ #
    def sftp(self) -> paramiko.SFTPClient:
        """Return the live SFTPClient (opened lazily, reused)."""
        self._ensure_connected()
        if self._sftp is None:
            self._sftp = self._client.open_sftp()
        return self._sftp

    @staticmethod
    def _rnorm(path: str) -> str:
        """Normalise a remote path to forward slashes."""
        return path.replace("\\", "/")

    def _remote_is_dir(self, sftp, path: str) -> bool:
        """True if *path* stats as a directory; False if missing or inaccessible."""
        try:
            # st_mode can be None when the server omits permissions.
            return stat.S_ISDIR(sftp.stat(path).st_mode or 0)
        except IOError:
            return False

    def _remote_exists(self, sftp, path: str) -> bool:
        """True if *path* can be stat'ed."""
        try:
            sftp.stat(path)
            return True
        except IOError:
            return False

    def _mkdir_p(self, sftp, remote_dir: str) -> None:
        """Create *remote_dir* and any missing parents (like ``mkdir -p``).

        Failures on individual components are tolerated, because a component
        may be created concurrently or be a root the server refuses to create.
        The call succeeds only if the full path is a directory at the end.

        :raises IOError: the directory does not exist after the attempt.
        """
        remote_dir = self._rnorm(remote_dir)
        if remote_dir in ("", "/", "."):
            return
        is_abs = remote_dir.startswith("/")
        parts = [p for p in remote_dir.split("/") if p]
        if not parts:
            return
        current = ""
        for part in parts:
            if current:
                current = f"{current}/{part}"
            else:
                current = f"/{part}" if is_abs else part
            if not self._remote_exists(sftp, current):
                try:
                    sftp.mkdir(current)
                except IOError:
                    pass
        if not self._remote_is_dir(sftp, current):
            raise IOError(f"Could not create remote directory: {remote_dir}")

    # thin pass-throughs (paramiko parity) -----------------------------
    def listdir(self, path: str = ".", *, safe=None):
        return self._guard("listdir", lambda: self.sftp().listdir(self._rnorm(path)),
                           safe=safe)

    def listdir_attr(self, path: str = ".", *, safe=None):
        return self._guard("listdir_attr",
                           lambda: self.sftp().listdir_attr(self._rnorm(path)), safe=safe)

    def stat(self, path: str, *, safe=None):
        return self._guard("stat", lambda: self.sftp().stat(self._rnorm(path)), safe=safe)

    def lstat(self, path: str, *, safe=None):
        return self._guard("lstat", lambda: self.sftp().lstat(self._rnorm(path)),
                           safe=safe)

    def exists(self, path: str, *, safe=None):
        return self._guard("exists",
                           lambda: self._remote_exists(self.sftp(), self._rnorm(path)),
                           safe=safe)

    def isdir(self, path: str, *, safe=None):
        return self._guard("isdir",
                           lambda: self._remote_is_dir(self.sftp(), self._rnorm(path)),
                           safe=safe)

    def mkdir(self, path: str, mode: int = 0o777, *, safe=None):
        return self._guard("mkdir",
                           lambda: self.sftp().mkdir(self._rnorm(path), mode), safe=safe)

    def makedirs(self, path: str, *, safe=None):
        """Create *path* and missing parents. Raises IOError if it cannot be created."""
        return self._guard("makedirs", lambda: self._mkdir_p(self.sftp(), path),
                           safe=safe)

    def rmdir(self, path: str, *, safe=None):
        return self._guard("rmdir", lambda: self.sftp().rmdir(self._rnorm(path)),
                           safe=safe)

    def remove(self, path: str, *, safe=None):
        return self._guard("remove", lambda: self.sftp().remove(self._rnorm(path)),
                           safe=safe)

    def rename(self, old: str, new: str, *, safe=None):
        """Rename via ``posix_rename`` (SFTP extension; overwrites *new* if present)."""
        return self._guard("rename",
                           lambda: self.sftp().posix_rename(self._rnorm(old),
                                                            self._rnorm(new)), safe=safe)

    def chmod(self, path: str, mode: int, *, safe=None):
        return self._guard("chmod", lambda: self.sftp().chmod(self._rnorm(path), mode),
                           safe=safe)

    def chown(self, path: str, uid: int, gid: int, *, safe=None):
        return self._guard("chown",
                           lambda: self.sftp().chown(self._rnorm(path), uid, gid),
                           safe=safe)

    def symlink(self, source: str, dest: str, *, safe=None):
        """Create *dest* pointing at *source*. *source* is passed through unnormalised."""
        return self._guard("symlink",
                           lambda: self.sftp().symlink(source, self._rnorm(dest)),
                           safe=safe)

    def readlink(self, path: str, *, safe=None):
        return self._guard("readlink", lambda: self.sftp().readlink(self._rnorm(path)),
                           safe=safe)

    def open(self, path: str, mode: str = "r", bufsize: int = -1, *, safe=None):
        """Open a remote file object (paramiko SFTPFile)."""
        return self._guard("open",
                           lambda: self.sftp().open(self._rnorm(path), mode, bufsize),
                           safe=safe)

    def read_text(self, path: str, encoding: str = "utf-8", *, safe=None):
        """Return the remote file decoded with *encoding* (errors replaced)."""
        def _do():
            with self.sftp().open(self._rnorm(path), "r") as fh:
                return fh.read().decode(encoding, errors="replace")
        return self._guard("read_text", _do, safe=safe)

    def write_text(self, path: str, data: str, encoding: str = "utf-8", *, safe=None):
        """Write *data* encoded with *encoding*. Returns ``len(data)`` (characters, not bytes)."""
        def _do():
            with self.sftp().open(self._rnorm(path), "w") as fh:
                fh.write(data.encode(encoding))
            return len(data)
        return self._guard("write_text", _do, safe=safe)

    def walk(self, remote_path: str):
        """Generator like os.walk over a remote tree (dirpath, dirs, files).

        Top-down: a caller may prune ``dirs`` in place before descent. An entry
        counts as a directory only if its listed mode says so, so symlinks are
        reported as files and never followed.
        """
        sftp = self.sftp()
        remote_path = self._rnorm(remote_path)
        dirs, files = [], []
        for entry in sftp.listdir_attr(remote_path):
            (dirs if stat.S_ISDIR(entry.st_mode or 0) else files).append(entry.filename)
        yield remote_path, dirs, files
        for d in dirs:
            yield from self.walk(posixpath.join(remote_path, d))

    # ------------------------------------------------------------------ #
    # Push / Pull (SFTP) with TransferResult
    # ------------------------------------------------------------------ #
    def push(self, local_path: str, remote_path: str, *, recursive: bool = False,
             make_dirs: bool = True, callback=None, safe: Optional[bool] = None):
        """Upload a file or directory (SFTP). Returns TransferResult."""
        return self._guard("push", self._push, local_path, remote_path,
                           recursive=recursive, make_dirs=make_dirs,
                           callback=callback, safe=safe)

    def _push(self, local_path, remote_path, *, recursive, make_dirs, callback):
        """Unguarded core of push(). Wraps I/O and SSH errors in SSHTransferError."""
        local_path = os.path.expanduser(local_path)
        remote_path = self._rnorm(remote_path)
        sftp = self.sftp()
        if not os.path.exists(local_path):
            raise SSHTransferError(f"Local path does not exist: {local_path}")
        start = time.monotonic()
        try:
            if os.path.isdir(local_path):
                if not recursive:
                    raise SSHTransferError(f"{local_path} is a directory; "
                                           f"pass recursive=True.")
                size, count = self._push_dir(sftp, local_path, remote_path, callback)
            else:
                if make_dirs:
                    parent = posixpath.dirname(remote_path)
                    if parent:
                        self._mkdir_p(sftp, parent)
                self._emit(logging.INFO, f"PUSH {local_path} -> {remote_path}")
                sftp.put(local_path, remote_path, callback=callback)
                size, count = os.path.getsize(local_path), 1
        except SSHTransferError:
            raise
        except (IOError, OSError, paramiko.SSHException) as exc:
            raise SSHTransferError(f"Failed to push {local_path} -> {remote_path}: "
                                   f"{exc}") from exc
        return TransferResult(local_path, remote_path, "push", "sftp", size,
                              time.monotonic() - start, count)

    def _push_dir(self, sftp, local_dir, remote_dir, callback):
        """Recursively upload *local_dir*. Returns (total_bytes, file_count)."""
        self._mkdir_p(sftp, remote_dir)
        total, count = 0, 0
        for entry in os.listdir(local_dir):
            lpath = os.path.join(local_dir, entry)
            rpath = posixpath.join(remote_dir, entry)
            if os.path.isdir(lpath):
                s, c = self._push_dir(sftp, lpath, rpath, callback)
                total += s
                count += c
            else:
                self._emit(logging.INFO, f"PUSH {lpath} -> {rpath}")
                sftp.put(lpath, rpath, callback=callback)
                total += os.path.getsize(lpath)
                count += 1
        return total, count

    def pull(self, remote_path: str, local_path: str, *, recursive: bool = False,
             make_dirs: bool = True, callback=None, safe: Optional[bool] = None):
        """Download a file or directory (SFTP). Returns TransferResult."""
        return self._guard("pull", self._pull, remote_path, local_path,
                           recursive=recursive, make_dirs=make_dirs,
                           callback=callback, safe=safe)

    def _pull(self, remote_path, local_path, *, recursive, make_dirs, callback):
        """Unguarded core of pull(). Wraps I/O and SSH errors in SSHTransferError."""
        remote_path = self._rnorm(remote_path)
        local_path = os.path.expanduser(local_path)
        sftp = self.sftp()
        start = time.monotonic()
        try:
            try:
                remote_attrs = sftp.stat(remote_path)
            except IOError as exc:
                # Checked up front: SFTPClient.get() creates the local file
                # before fetching, so a missing remote would leave an empty file.
                raise SSHTransferError(
                    f"Remote path not accessible: {remote_path} ({exc})") from exc
            if stat.S_ISDIR(remote_attrs.st_mode or 0):
                if not recursive:
                    raise SSHTransferError(f"{remote_path} is a directory; "
                                           f"pass recursive=True.")
                size, count = self._pull_dir(sftp, remote_path, local_path, callback)
            else:
                if make_dirs:
                    parent = os.path.dirname(local_path)
                    if parent:
                        os.makedirs(parent, exist_ok=True)
                self._emit(logging.INFO, f"PULL {remote_path} -> {local_path}")
                sftp.get(remote_path, local_path, callback=callback)
                size, count = os.path.getsize(local_path), 1
        except SSHTransferError:
            raise
        except (IOError, OSError, paramiko.SSHException) as exc:
            raise SSHTransferError(f"Failed to pull {remote_path} -> {local_path}: "
                                   f"{exc}") from exc
        return TransferResult(remote_path, local_path, "pull", "sftp", size,
                              time.monotonic() - start, count)

    def _pull_dir(self, sftp, remote_dir, local_dir, callback):
        """Recursively download *remote_dir*. Returns (total_bytes, file_count)."""
        os.makedirs(local_dir, exist_ok=True)
        total, count = 0, 0
        for entry in sftp.listdir_attr(remote_dir):
            rpath = posixpath.join(remote_dir, entry.filename)
            lpath = os.path.join(local_dir, entry.filename)
            if stat.S_ISDIR(entry.st_mode or 0):
                s, c = self._pull_dir(sftp, rpath, lpath, callback)
                total += s
                count += c
            else:
                self._emit(logging.INFO, f"PULL {rpath} -> {lpath}")
                sftp.get(rpath, lpath, callback=callback)
                total += entry.st_size or 0
                count += 1
        return total, count

    # ------------------------------------------------------------------ #
    # SCP (optional, via the 'scp' package)
    # ------------------------------------------------------------------ #
    @staticmethod
    def scp_available() -> bool:
        """True if the optional 'scp' package could be imported."""
        return _HAS_SCP

    def _scp_client(self, callback=None):
        """Return an SCPClient on the live transport.

        *callback(sent, size)* uses the same argument order as the SFTP
        callbacks; it is adapted to scp's ``progress(filename, size, sent)``.

        :raises SSHTransferError: the 'scp' package is not installed.
        """
        if not _HAS_SCP:
            raise SSHTransferError(
                "SCP support needs the 'scp' package. Install it with: pip install scp "
                "(SFTP push/pull already works without it).")
        self._ensure_connected()
        progress = (lambda fn, sz, sent: callback(sent, sz)) if callback else None
        return SCPClient(self.transport, progress=progress)

    def scp_push(self, local_path: str, remote_path: str, *, recursive: bool = False,
                 callback=None, safe: Optional[bool] = None):
        """Upload via the SCP protocol (alternative to SFTP push).

        The reported size is the file size for a single file and 0 for a
        directory.
        """
        def _do():
            start = time.monotonic()
            local = os.path.expanduser(local_path)
            with self._scp_client(callback) as scp:
                scp.put(local, remote_path, recursive=recursive)
            size = os.path.getsize(local) if os.path.isfile(local) else 0
            return TransferResult(local_path, remote_path, "push", "scp", size,
                                  time.monotonic() - start)
        return self._guard("scp_push", _do, safe=safe)

    def scp_pull(self, remote_path: str, local_path: str, *, recursive: bool = False,
                 callback=None, safe: Optional[bool] = None):
        """Download via the SCP protocol (alternative to SFTP pull).

        The reported size is the local file size for a single file and 0
        otherwise.
        """
        def _do():
            start = time.monotonic()
            local = os.path.expanduser(local_path)
            with self._scp_client(callback) as scp:
                scp.get(remote_path, local, recursive=recursive)
            size = os.path.getsize(local) if os.path.isfile(local) else 0
            return TransferResult(remote_path, local_path, "pull", "scp", size,
                                  time.monotonic() - start)
        return self._guard("scp_pull", _do, safe=safe)

    # ------------------------------------------------------------------ #
    # Context manager
    # ------------------------------------------------------------------ #
    def __enter__(self) -> "SSHHandler":
        """Connect on entry. Even in safe mode, a failed connect raises."""
        result = self.connect()
        if isinstance(result, OperationResult) and not result.success:
            raise result.error or SSHConnectionError("connect failed")
        return self

    def __exit__(self, *exc) -> None:
        self.disconnect()

    def __repr__(self) -> str:
        state = "connected" if self.is_connected else "disconnected"
        return (f"<SSHHandler {self.config.auth_username}@{self.config.host}:"
                f"{self.config.port} {state}>")