sshler 0.3.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
sshler/ssh.py ADDED
@@ -0,0 +1,329 @@
1
+ from __future__ import annotations
2
+
3
+ import asyncio
4
+ import logging
5
+ import os
6
+ import shlex
7
+ import socket
8
+ from asyncio.subprocess import Process
9
+ from pathlib import Path
10
+
11
+ import asyncssh
12
+
13
+
14
class SSHError(Exception):
    """Project-level exception for failed SSH connections or commands.

    sshler raises this in place of the underlying ``OSError`` /
    ``asyncssh`` exceptions so callers can handle SSH failures without
    depending on ``asyncssh`` internals.
    """
25
+
26
+
27
async def connect(
    host: str,
    user: str,
    port: int = 22,
    keyfile: str | None = None,
    known_hosts: str | None = None,
    ssh_config_path: str | None = None,
    ssh_alias: str | None = None,
    allow_alias: bool = True,
) -> asyncssh.SSHClientConnection:
    """Establish an SSH connection using asyncssh.

    Opens a connection to ``host``. Alias expansion happens only when all of
    these hold: ``host`` does not resolve, an ``ssh_alias`` is given, and
    ``allow_alias`` is true. The alias is then expanded via ``ssh -G``. Its
    hostname, port and identity file (when no ``keyfile`` was given) replace
    the supplied values. Its user is used only if ``user`` is empty. Errors
    are normalised to :class:`SSHError`.

    Optional settings that were not supplied are *omitted* from the
    ``asyncssh.connect`` call rather than passed as ``None``. asyncssh
    treats an explicit ``None`` as "disable": no host-key validation, no
    public-key authentication, no ``~/.ssh/config``. Omitting them keeps
    asyncssh's defaults.

    Args:
        host: Target host to reach.
        user: Username used for the SSH session.
        port: SSH port exposed by the host.
        keyfile: Optional explicit private-key path.
        known_hosts: Known-hosts file to validate against, or ``"ignore"``
            (case-insensitive) to disable host-key validation. When omitted,
            asyncssh's default known-hosts validation applies.
        ssh_config_path: Optional SSH config location.
        ssh_alias: Alias name to expand via ``ssh -G`` when DNS fails.
        allow_alias: Whether alias expansion is permitted.

    Returns:
        asyncssh.SSHClientConnection: Live SSH connection instance.

    Raises:
        SSHError: Wraps any ``OSError`` or ``asyncssh.Error`` raised while
            connecting.
    """
    connect_host = host
    connect_user = user
    connect_port = port
    connect_keyfile = keyfile

    # getaddrinfo() blocks, so run the resolvability probe off the event loop.
    if allow_alias and ssh_alias and not await asyncio.to_thread(_is_resolvable, connect_host):
        alias_data = await _expand_alias(ssh_alias)
        resolved_host = alias_data.get("hostname")
        if resolved_host:
            connect_host = resolved_host
            if alias_data.get("user") and not connect_user:
                connect_user = alias_data["user"]
            try:
                connect_port = int(alias_data.get("port") or connect_port)
            except (TypeError, ValueError):
                pass
            if not connect_keyfile and alias_data.get("identityfile"):
                connect_keyfile = alias_data["identityfile"]
            LOGGER.info(
                "Resolved SSH alias %s -> host=%s port=%s user=%s",
                ssh_alias,
                connect_host,
                connect_port,
                connect_user,
            )
        else:
            LOGGER.warning("Failed to resolve SSH alias %s; falling back to %s", ssh_alias, host)

    # Only pass options that were actually configured (see docstring).
    options: dict[str, object] = {}
    if connect_keyfile:
        options["client_keys"] = [connect_keyfile]
    if known_hosts:
        if isinstance(known_hosts, str) and known_hosts.lower() == "ignore":
            # An explicit None is asyncssh's documented "skip host-key validation".
            options["known_hosts"] = None
        else:
            options["known_hosts"] = known_hosts
    if ssh_config_path:
        options["config"] = [ssh_config_path]

    try:
        return await asyncssh.connect(
            host=connect_host,
            port=connect_port,
            username=connect_user,
            **options,
        )
    except (OSError, asyncssh.Error) as exc:
        raise SSHError(str(exc)) from exc
109
+
110
+
111
async def open_tmux(
    connection: asyncssh.SSHClientConnection,
    working_directory: str,
    session: str,
    terminal_type: str = "xterm-256color",
    columns: int = 120,
    rows: int = 32,
    environment: dict[str, str] | None = None,
) -> asyncssh.SSHClientProcess:
    """Start, or attach to, a tmux session on the remote host.

    Runs ``tmux new -As <session> -c <working_directory>`` in a
    pseudo-terminal and returns the process handle. Characters other than
    alphanumerics and ``-``, ``_`` and ``.`` in ``session`` are replaced with
    ``_``. An empty result falls back to ``"sshler"``.

    Args:
        connection: Active SSH connection.
        working_directory: Start directory passed to tmux via ``-c``.
        session: Desired session name (sanitised as described above).
        terminal_type: Terminal type requested for the PTY.
        columns: PTY width.
        rows: PTY height.
        environment: Environment variables forwarded to the remote process.

    Returns:
        asyncssh.SSHClientProcess: Process whose I/O is raw bytes.
    """
    allowed_punctuation = "-_."
    cleaned = "".join(
        char if (char.isalnum() or char in allowed_punctuation) else "_"
        for char in session
    )
    session_name = cleaned if cleaned else "sshler"

    # Quote every argument so the remote shell sees them verbatim.
    tmux_argv = ["tmux", "new", "-As", session_name, "-c", working_directory]
    command = " ".join(shlex.quote(argument) for argument in tmux_argv)

    return await connection.create_process(
        command=command,
        term_type=terminal_type,
        term_size=(columns, rows),
        encoding=None,  # deliver raw bytes to the caller
        env=environment,
    )
154
+
155
+
156
async def sftp_list_directory(
    connection: asyncssh.SSHClientConnection, path: str
) -> list[dict[str, object]]:
    """Return directory entries for ``path`` via SFTP.

    Each entry is a dict with ``name``, ``is_directory`` and ``size``.
    ``is_directory`` is taken from the file-type bits returned by SFTP
    ``stat`` (not ``lstat``). Entries whose ``stat`` fails with an SFTP error
    are skipped. Entries reporting no permissions are treated as
    non-directories. Directories are listed first, then names are compared
    case-insensitively.

    Args:
        connection: Active SSH connection used to start SFTP.
        path: Remote directory to enumerate.

    Returns:
        list[dict[str, object]]: Metadata entries sorted with directories first.
    """
    import stat  # S_ISDIR checks the whole S_IFMT field, not a single bit

    entries: list[dict[str, object]] = []
    prefix = path.rstrip("/")
    # ``async with`` runs SFTPClient.exit() and awaits wait_closed() on the way out.
    async with connection.start_sftp_client() as sftp_client:
        for filename in await sftp_client.listdir(path):
            try:
                attrs = await sftp_client.stat(f"{prefix}/{filename}")
            except asyncssh.SFTPError as exc:
                # Best effort: skip entries we cannot stat (e.g. dangling symlinks).
                LOGGER.debug("Skipping %s/%s: %s", prefix, filename, exc)
                continue
            entries.append(
                {
                    "name": filename,
                    # Masking only 0o40000 would also match sockets and block devices.
                    "is_directory": stat.S_ISDIR(attrs.permissions or 0),
                    "size": attrs.size,
                }
            )
    entries.sort(key=lambda entry: (not entry["is_directory"], entry["name"].lower()))
    return entries
200
+
201
+
202
async def sftp_is_directory(connection: asyncssh.SSHClientConnection, path: str) -> bool:
    """Return whether ``path`` resolves to a directory via SFTP.

    Performs an SFTP ``stat`` and tests the file-type bits with
    ``stat.S_ISDIR``. A result without permissions is treated as "not a
    directory". Errors from ``stat`` (e.g. a missing path) propagate.

    Args:
        connection: Active SSH connection used to start SFTP.
        path: Remote path to probe.

    Returns:
        bool: ``True`` when ``path`` is a directory, otherwise ``False``.
    """
    import stat  # S_ISDIR checks the whole S_IFMT field, not a single bit

    # ``async with`` runs SFTPClient.exit() and awaits wait_closed() on the way out.
    async with connection.start_sftp_client() as sftp_client:
        attrs = await sftp_client.stat(path)
    # Masking only 0o40000 would also match sockets and block devices.
    return stat.S_ISDIR(attrs.permissions or 0)
228
+
229
+
230
async def sftp_read_file(
    connection: asyncssh.SSHClientConnection,
    path: str,
    max_bytes: int = 65536,
) -> str:
    """Read a text file over SFTP, truncated to ``max_bytes``.

    Reads up to ``max_bytes`` raw bytes and decodes them as UTF-8, replacing
    undecodable bytes with U+FFFD. A multi-byte character cut by the
    truncation therefore shows up as a replacement character instead of
    raising.

    Args:
        connection: Active SSH connection used to start SFTP.
        path: Remote file to read.
        max_bytes: Maximum number of bytes to read.

    Returns:
        str: Decoded (possibly truncated) file contents.
    """
    # ``async with`` runs SFTPClient.exit() and awaits wait_closed() on the way out.
    async with connection.start_sftp_client() as sftp_client:
        # Binary mode: asyncssh's text mode decodes with errors="strict",
        # which fails on invalid UTF-8 or a character split at max_bytes.
        async with await sftp_client.open(path, "rb", encoding=None) as remote_file:
            data = await remote_file.read(max_bytes)
    if isinstance(data, str):
        return data
    return data.decode("utf-8", errors="replace")
258
+
259
+
260
async def _expand_alias(alias: str) -> dict[str, str | None]:
    """Expand an OpenSSH config alias with ``ssh -G``.

    ``ssh -G`` may print several ``identityfile`` lines (including OpenSSH's
    built-in defaults). The first one that exists locally, with ``~``
    expanded, is returned.

    Args:
        alias: Host alias as it would be passed to ``ssh``.

    Returns:
        dict: Keys ``hostname``, ``user``, ``port`` and ``identityfile``
        (values may be ``None``). The dict is empty when the alias is
        rejected, ``ssh`` cannot be started, or it exits non-zero.
    """
    # An alias starting with "-" would be parsed by ssh as a command-line option.
    if not alias or alias.startswith("-"):
        return {}

    try:
        process = await asyncio.create_subprocess_exec(
            _ssh_command(),
            "-G",
            alias,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
        )
    except (OSError, ValueError, NotImplementedError):
        # OSError: ssh missing or not executable. ValueError: e.g. a NUL byte
        # in the alias. NotImplementedError: event loop without subprocess
        # support (e.g. a selector loop on Windows).
        return {}

    stdout, _ = await process.communicate()
    if process.returncode != 0:
        return {}

    settings: dict[str, str] = {}
    identity_files: list[str] = []
    for line in stdout.decode("utf-8", errors="replace").splitlines():
        key, _, value = line.partition(" ")
        key = key.strip().lower()
        value = value.strip()
        if not key or not value:
            continue
        if key == "identityfile":
            identity_files.append(value)
        else:
            settings[key] = value

    # ssh tries identity files in order; keep the first usable one, not the last.
    identity_file: str | None = None
    for candidate in identity_files:
        expanded = os.path.expanduser(candidate)
        if os.path.isfile(expanded):
            identity_file = expanded
            break

    return {
        "hostname": settings.get("hostname"),
        "user": settings.get("user"),
        "port": settings.get("port"),
        "identityfile": identity_file,
    }
289
+
290
+
291
+ def _is_resolvable(name: str) -> bool:
292
+ """Return whether ``name`` resolves via DNS/system hosts.
293
+
294
+ English:
295
+ Lightweight guard to decide if ``ssh -G`` alias expansion is needed.
296
+
297
+ 日本語:
298
+ DNS や hosts ファイルで名前解決できるかを確認し、エイリアス解決が必要かどうかを
299
+ 判定します。
300
+ """
301
+
302
+ try:
303
+ socket.getaddrinfo(name, None)
304
+ return True
305
+ except OSError:
306
+ return False
307
+
308
+
309
+ def _ssh_command() -> str:
310
+ """Return the preferred ``ssh`` executable path for the current OS.
311
+
312
+ English:
313
+ On Windows the system OpenSSH path is used to avoid PATH hijacking;
314
+ otherwise ``ssh`` from the user's PATH is returned.
315
+
316
+ 日本語:
317
+ Windows では PATH 乗っ取りを避けるためにシステムの OpenSSH を優先し、
318
+ それ以外ではユーザーの PATH にある ``ssh`` を利用します。
319
+ """
320
+
321
+ if os.name == "nt":
322
+ system_root = os.environ.get("SystemRoot", "C:\\Windows")
323
+ candidate = Path(system_root) / "System32" / "OpenSSH" / "ssh.exe"
324
+ if candidate.exists():
325
+ return str(candidate)
326
+ return "ssh"
327
+
328
+
329
# Module-level logger. It is defined after the functions that use it, which
# is fine because they only look up ``LOGGER`` when called.
LOGGER = logging.getLogger(__name__)
sshler/ssh_config.py ADDED
@@ -0,0 +1,134 @@
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ import shlex
5
+ from dataclasses import dataclass, field
6
+ from pathlib import Path
7
+ from typing import Any
8
+
9
# Environment variable consulted by load_ssh_config() when no explicit path is given.
SSH_CONFIG_ENV = "SSHLER_SSH_CONFIG"
# Config file parsed when neither an explicit path nor SSH_CONFIG_ENV is set.
DEFAULT_SSH_CONFIG = Path.home() / ".ssh" / "config"
11
+
12
+
13
@dataclass
class HostConfig:
    """Subset of SSH host configuration important to sshler."""

    # Host alias exactly as written on the ``Host`` line.
    name: str
    # ``HostName`` value, if the block set one.
    hostname: str | None = None
    # ``User`` value, if the block set one.
    user: str | None = None
    # ``Port`` value converted to int, if the block set one.
    port: int | None = None
    # ``IdentityFile`` values with ``~``/environment variables expanded;
    # relative paths are anchored at the config file's directory by the parser.
    identity_files: list[str] = field(default_factory=list)
    # Every option collected for the host block, keyed by lower-cased keyword.
    raw: dict[str, Any] = field(default_factory=dict)
23
+
24
+
25
+ def _expand_path(value: str, *, base_dir: Path) -> str:
26
+ expanded = os.path.expandvars(os.path.expanduser(value))
27
+ expanded_path = Path(expanded)
28
+ if expanded_path.is_absolute():
29
+ return str(expanded_path)
30
+ return str((base_dir / expanded_path).resolve())
31
+
32
+
33
+ def _parse_file(path: Path, *, seen: set[Path]) -> dict[str, HostConfig]:
34
+ hosts: dict[str, HostConfig] = {}
35
+ if not path.exists() or not path.is_file():
36
+ return hosts
37
+
38
+ real_path = path.resolve()
39
+ if real_path in seen:
40
+ return hosts
41
+ seen.add(real_path)
42
+
43
+ current_patterns: list[str] = []
44
+ current_options: dict[str, Any] = {}
45
+
46
+ def flush() -> None:
47
+ nonlocal current_patterns, current_options
48
+ if not current_patterns:
49
+ return
50
+ for pattern in current_patterns:
51
+ if any(ch in pattern for ch in "*? "):
52
+ continue
53
+ entry = HostConfig(
54
+ name=pattern,
55
+ hostname=current_options.get("hostname"),
56
+ user=current_options.get("user"),
57
+ port=int(current_options["port"]) if "port" in current_options else None,
58
+ identity_files=[
59
+ _expand_path(value, base_dir=path.parent)
60
+ for value in current_options.get("identityfile", [])
61
+ ],
62
+ raw=current_options.copy(),
63
+ )
64
+ hosts[pattern] = entry
65
+ current_patterns = []
66
+ current_options = {}
67
+
68
+ with path.open("r", encoding="utf-8") as file_pointer:
69
+ for line in file_pointer:
70
+ stripped = line.strip()
71
+ if not stripped or stripped.startswith("#"):
72
+ continue
73
+ if "#" in stripped:
74
+ before, _, _ = stripped.partition("#")
75
+ stripped = before.strip()
76
+ if not stripped:
77
+ continue
78
+
79
+ try:
80
+ tokens = shlex.split(stripped, comments=False, posix=True)
81
+ except ValueError:
82
+ continue
83
+ if not tokens:
84
+ continue
85
+
86
+ keyword = tokens[0].lower()
87
+ values = tokens[1:]
88
+ if keyword == "host":
89
+ flush()
90
+ current_patterns = values
91
+ current_options = {}
92
+ continue
93
+
94
+ if keyword == "include" and values:
95
+ include_pattern = values[0]
96
+ parent = path.parent
97
+ for include_path in parent.glob(include_pattern):
98
+ hosts.update(_parse_file(include_path, seen=seen))
99
+ continue
100
+
101
+ if keyword == "match":
102
+ flush()
103
+ current_patterns = []
104
+ current_options = {}
105
+ continue
106
+
107
+ if not current_patterns:
108
+ continue
109
+
110
+ if keyword in {"hostname", "user", "port"} and values:
111
+ current_options[keyword] = values[-1]
112
+ elif keyword == "identityfile" and values:
113
+ current_options.setdefault("identityfile", []).extend(values)
114
+ else:
115
+ current_options[keyword] = values[-1] if values else None
116
+
117
+ flush()
118
+ return hosts
119
+
120
+
121
def load_ssh_config(explicit_path: str | None = None) -> dict[str, HostConfig]:
    """Load host entries from the user's SSH client config.

    The config path is chosen in this order:

    1. ``explicit_path``.
    2. The ``SSHLER_SSH_CONFIG`` environment variable (both ``~``-expanded).
    3. ``~/.ssh/config``.

    Args:
        explicit_path: Optional config path overriding the defaults.

    Returns:
        dict[str, HostConfig]: Host entries keyed by alias.
    """
    source = explicit_path or os.getenv(SSH_CONFIG_ENV)
    config_path = Path(source).expanduser() if source else DEFAULT_SSH_CONFIG
    return _parse_file(config_path, seen=set())
134
+
sshler/state.py ADDED
@@ -0,0 +1,230 @@
1
+ from __future__ import annotations
2
+
3
+ import asyncio
4
+ import threading
5
+ import time
6
+ from collections.abc import Iterable, Sequence
7
+ from pathlib import Path
8
+ from typing import TYPE_CHECKING
9
+
10
+ from pydantic import Field
11
+ from sqler import SQLerDB, SQLerModel
12
+ from sqler.adapter import SQLiteAdapter
13
+ from sqler.query import SQLerField as F
14
+
15
# File name of the SQLite state database created inside the config directory.
STATE_FILENAME = "state.sqlite"

# Lock guarding the module-level database handle. It is re-entrant because
# helpers such as migrate_legacy_favorites()/favorites_map() call other
# locked helpers while already holding it.
_DB_LOCK = threading.RLock()
# Active database handle; set by initialize(), cleared by reset_state().
_DB: SQLerDB | None = None
# Database file backing _DB.
_DB_PATH: Path | None = None
# True once initialize() has completed.
_INITIALISED = False
21
+
22
+
23
class Favorite(SQLerModel):
    """Persisted favourite directories per box."""

    # Presumably the table name sqler stores these rows in — confirm against sqler.
    __tablename__ = "favorites"

    # Name of the box the favourite belongs to.
    box: str
    # Favourited remote directory path.
    path: str
    # Ordering key within a box; favourites are listed by ascending position.
    position: int = 0
    # Creation timestamp from time.time() (seconds since the epoch).
    created_at: float = Field(default_factory=time.time)
    # Last-modified timestamp; replace_favorites() bumps it when position changes.
    updated_at: float = Field(default_factory=time.time)
33
+
34
+
35
+ if TYPE_CHECKING: # pragma: no cover - import for typing only
36
+ from .config import StoredBox
37
+
38
+
39
def initialize(config_dir: Path) -> None:
    """Open, or switch to, the SQLite state store under ``config_dir``.

    The ``state.sqlite`` file lives alongside ``boxes.yaml``. It holds
    favourites and is intended for future history records.

    - Calling again with the same directory is a no-op.
    - Calling with a different directory closes the previous database first.
    - The directory is created if it does not exist.
    """
    global _DB, _DB_PATH, _INITIALISED

    directory = config_dir.expanduser()
    directory.mkdir(parents=True, exist_ok=True)
    db_file = directory / STATE_FILENAME

    with _DB_LOCK:
        already_open = _INITIALISED and _DB_PATH == db_file
        if already_open:
            return

        switching_files = _DB is not None and _DB_PATH != db_file
        if switching_files:
            _DB.close()

        database = SQLerDB(SQLiteAdapter(path=str(db_file)))
        Favorite.set_db(database)
        for column in ("box", "position"):
            Favorite.ensure_index(column)

        _DB, _DB_PATH, _INITIALISED = database, db_file, True
68
+
69
+
70
def reset_state() -> None:
    """Close the cached database handle and forget it (used by tests).

    Errors raised while closing are ignored; the module is always left
    uninitialised afterwards.
    """
    global _DB, _DB_PATH, _INITIALISED
    with _DB_LOCK:
        handle = _DB
        if handle is not None:
            try:
                handle.close()
            except Exception:  # pragma: no cover - best effort cleanup
                pass
        _DB = None
        _DB_PATH = None
        _INITIALISED = False
83
+
84
+
85
def _require_db() -> SQLerDB:
    """Return the active database handle.

    Raises:
        RuntimeError: If initialize() has not been called (or reset_state()
            cleared it).
    """
    handle = _DB
    if _INITIALISED and handle is not None:
        return handle
    raise RuntimeError("State store not initialised")
89
+
90
+
91
def migrate_legacy_favorites(stored: dict[str, StoredBox]) -> bool:
    """Move favourites kept in the YAML config into the sqler-backed store.

    For every box whose ``favorites`` list is non-empty, that list replaces
    the box's stored favourites and is then cleared in place.

    Returns:
        bool: ``True`` if at least one box was migrated.

    Raises:
        RuntimeError: If the store is not initialised (only checked when
            ``stored`` is non-empty).
    """
    if not stored:
        return False

    _require_db()
    boxes_migrated = 0
    with _DB_LOCK:
        for box in stored.values():
            if box.favorites:
                replace_favorites(box.name, box.favorites)
                box.favorites.clear()
                boxes_migrated += 1
    return boxes_migrated > 0
107
+
108
+
109
def list_favorites(box_name: str) -> list[str]:
    """Return the favourite paths of ``box_name`` ordered by position.

    Raises:
        RuntimeError: If the state store has not been initialised.
    """
    _require_db()
    with _DB_LOCK:
        query = Favorite.query().filter(F("box") == box_name).order_by("position")
        favourites = query.all()
        return [favourite.path for favourite in favourites]
121
+
122
+
123
async def list_favorites_async(box_name: str) -> list[str]:
    """Run list_favorites() in a worker thread and return its result."""
    favourites = await asyncio.to_thread(list_favorites, box_name)
    return favourites
125
+
126
+
127
def favorites_map(box_names: Sequence[str] | None = None) -> dict[str, list[str]]:
    """Return a mapping of box name to its ordered favourite paths.

    - With ``box_names``: every requested name appears in the result, with an
      empty list when it has no favourites.
    - Without it: only boxes that have at least one favourite are included.

    Raises:
        RuntimeError: If the state store has not been initialised.
    """
    _require_db()
    with _DB_LOCK:
        if box_names is None:
            grouped: dict[str, list[str]] = {}
            all_rows = Favorite.query().order_by("box").order_by("position").all()
            for favourite in all_rows:
                grouped.setdefault(favourite.box, []).append(favourite.path)
            return grouped
        return {name: list_favorites(name) for name in box_names}
140
+
141
+
142
def toggle_favorite(box_name: str, path: str) -> bool:
    """Add ``path`` to the favourites of ``box_name``, or remove it if present.

    The path is stripped of surrounding whitespace first. This matches how
    replace_favorites() normalises stored paths, so both functions agree on
    what a favourite's path is. Blank or whitespace-only paths are ignored.
    New favourites are appended after the highest existing position.

    Args:
        box_name: Box whose favourites are modified.
        path: Remote directory path to toggle.

    Returns:
        bool: ``True`` if the path was added; ``False`` if it was removed or blank.

    Raises:
        RuntimeError: If the state store has not been initialised.
    """
    cleaned = (path or "").strip()
    if not cleaned:
        return False

    _require_db()
    now = time.time()
    with _DB_LOCK:
        existing = (
            Favorite.query()
            .filter((F("box") == box_name) & (F("path") == cleaned))
            .first()
        )
        if existing:
            existing.delete()
            return False

        position = _next_position(box_name)
        Favorite(
            box=box_name,
            path=cleaned,
            position=position,
            created_at=now,
            updated_at=now,
        ).save()
        return True
160
+
161
+
162
async def toggle_favorite_async(box_name: str, path: str) -> bool:
    """Run toggle_favorite() in a worker thread and return whether it added ``path``."""
    added = await asyncio.to_thread(toggle_favorite, box_name, path)
    return added
164
+
165
+
166
def replace_favorites(box_name: str, paths: Iterable[str]) -> None:
    """Make the stored favourites of ``box_name`` exactly ``paths``, in order.

    Each path is stripped. Blank entries are dropped, and duplicates (after
    stripping) keep their first occurrence.

    - Existing rows are reused and re-saved only when their position changes.
    - Missing paths get new rows.
    - Rows whose path is not in ``paths`` are deleted.

    Raises:
        RuntimeError: If the state store has not been initialised.
    """
    _require_db()
    stripped = (raw.strip() for raw in paths)
    # dict.fromkeys keeps insertion order, so this dedupes while preserving order.
    ordered = list(dict.fromkeys(entry for entry in stripped if entry))

    timestamp = time.time()
    with _DB_LOCK:
        current_rows = Favorite.query().filter(F("box") == box_name).all()
        by_path = {row.path: row for row in current_rows}

        for index, entry in enumerate(ordered):
            row = by_path.pop(entry, None)
            if row is None:
                Favorite(
                    box=box_name,
                    path=entry,
                    position=index,
                    created_at=timestamp,
                    updated_at=timestamp,
                ).save()
            elif row.position != index:
                row.position = index
                row.updated_at = timestamp
                row.save()

        # Anything left was not requested any more.
        for stale in by_path.values():
            stale.delete()
205
+
206
+
207
async def replace_favorites_async(box_name: str, paths: Iterable[str]) -> None:
    """Run replace_favorites() in a worker thread.

    ``paths`` is copied into a list on the calling side, so it is consumed
    here rather than inside the worker thread.
    """
    snapshot = list(paths)
    await asyncio.to_thread(replace_favorites, box_name, snapshot)
209
+
210
+
211
def remove_box(box_name: str) -> None:
    """Delete every favourite stored for ``box_name``.

    Raises:
        RuntimeError: If the state store has not been initialised.
    """
    _require_db()
    with _DB_LOCK:
        doomed = Favorite.query().filter(F("box") == box_name).all()
        for favourite in doomed:
            favourite.delete()
219
+
220
+
221
def _next_position(box_name: str) -> int:
    """Return one past the highest stored position for ``box_name``, or 0 if it has none."""
    highest = (
        Favorite.query()
        .filter(F("box") == box_name)
        .order_by("position", desc=True)
        .first()
    )
    return highest.position + 1 if highest else 0