sshler 0.3.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sshler/__init__.py +10 -0
- sshler/cli.py +161 -0
- sshler/config.py +425 -0
- sshler/scripts/install-sshler-task.ps1 +28 -0
- sshler/scripts/remove-sshler-task.ps1 +15 -0
- sshler/scripts/run-sshler.ps1 +24 -0
- sshler/ssh.py +329 -0
- sshler/ssh_config.py +134 -0
- sshler/state.py +230 -0
- sshler/static/base.js +305 -0
- sshler/static/favicon-terminal.svg +8 -0
- sshler/static/favicon.svg +8 -0
- sshler/static/file-edit.js +81 -0
- sshler/static/file-view.js +60 -0
- sshler/static/style.css +158 -0
- sshler/static/term.js +304 -0
- sshler/templates/base.html +41 -0
- sshler/templates/box.html +53 -0
- sshler/templates/docs.html +40 -0
- sshler/templates/file_edit.html +30 -0
- sshler/templates/file_view.html +31 -0
- sshler/templates/index.html +49 -0
- sshler/templates/new_box.html +42 -0
- sshler/templates/partials/dir_listing.html +91 -0
- sshler/templates/term.html +67 -0
- sshler/webapp.py +1897 -0
- sshler-0.3.2.dist-info/METADATA +245 -0
- sshler-0.3.2.dist-info/RECORD +31 -0
- sshler-0.3.2.dist-info/WHEEL +5 -0
- sshler-0.3.2.dist-info/entry_points.txt +2 -0
- sshler-0.3.2.dist-info/top_level.txt +1 -0
sshler/ssh.py
ADDED
|
@@ -0,0 +1,329 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
import logging
|
|
5
|
+
import os
|
|
6
|
+
import shlex
|
|
7
|
+
import socket
|
|
8
|
+
from asyncio.subprocess import Process
|
|
9
|
+
from pathlib import Path
|
|
10
|
+
|
|
11
|
+
import asyncssh
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class SSHError(Exception):
    """Project-level error for failed SSH connections or remote commands.

    sshler code catches this type rather than ``asyncssh`` exceptions, so
    callers stay independent of the transport library's internals.
    """
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
async def connect(
    host: str,
    user: str,
    port: int = 22,
    keyfile: str | None = None,
    known_hosts: str | None = None,
    ssh_config_path: str | None = None,
    ssh_alias: str | None = None,
    allow_alias: bool = True,
) -> asyncssh.SSHClientConnection:
    """Establish an SSH connection using asyncssh.

    When ``allow_alias`` is true, ``ssh_alias`` is given and ``host`` does not
    resolve through the system resolver, the alias is expanded with
    ``ssh -G``. Its hostname, port and identity file then replace the supplied
    values; its user is used only if ``user`` is empty.

    Only options the caller actually set are forwarded to
    :func:`asyncssh.connect`. asyncssh treats an explicit ``None``
    differently from "not specified": ``client_keys=None`` turns public-key
    authentication off and ``known_hosts=None`` turns host-key verification
    off.

    Args:
        host: Target host to reach.
        user: Username used for the SSH session.
        port: SSH port exposed by the host.
        keyfile: Optional explicit private-key path. If omitted, asyncssh's
            default key discovery is used.
        known_hosts: Known-hosts file to trust, or ``"ignore"`` (any case) to
            disable host-key checks. If omitted, asyncssh's default
            known-hosts lookup is used.
        ssh_config_path: Optional OpenSSH config file for asyncssh to apply.
        ssh_alias: Alias to expand via ``ssh -G`` when ``host`` does not resolve.
        allow_alias: Whether alias expansion is permitted.

    Returns:
        asyncssh.SSHClientConnection: Live SSH connection instance.

    Raises:
        SSHError: If the connection attempt fails (OS/network or SSH errors).
    """
    connect_host = host
    connect_user = user
    connect_port = port
    connect_keyfile = keyfile

    # getaddrinfo() blocks, so probe resolvability off the event loop.
    if (
        allow_alias
        and ssh_alias
        and not await asyncio.to_thread(_is_resolvable, connect_host)
    ):
        alias_data = await _expand_alias(ssh_alias)
        resolved_host = alias_data.get("hostname")
        if resolved_host:
            connect_host = resolved_host
            if alias_data.get("user") and not connect_user:
                connect_user = alias_data["user"]
            try:
                connect_port = int(alias_data.get("port") or connect_port)
            except (TypeError, ValueError):
                pass  # keep the current port if ssh -G printed something odd
            if not connect_keyfile and alias_data.get("identityfile"):
                connect_keyfile = alias_data["identityfile"]
            LOGGER.info(
                "Resolved SSH alias %s -> host=%s port=%s user=%s",
                ssh_alias,
                connect_host,
                connect_port,
                connect_user,
            )
        else:
            LOGGER.warning("Failed to resolve SSH alias %s; falling back to %s", ssh_alias, host)

    options: dict[str, object] = {
        "host": connect_host,
        "port": connect_port,
        "username": connect_user,
        # As before: None stops asyncssh from loading ~/.ssh/config by itself.
        "config": [ssh_config_path] if ssh_config_path else None,
    }
    if connect_keyfile:
        options["client_keys"] = [connect_keyfile]
    if isinstance(known_hosts, str) and known_hosts.lower() == "ignore":
        options["known_hosts"] = None  # caller explicitly opted out of checks
    elif known_hosts:
        options["known_hosts"] = known_hosts

    try:
        connection = await asyncssh.connect(**options)
    except (OSError, asyncssh.Error) as exc:
        raise SSHError(str(exc)) from exc
    return connection
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
async def open_tmux(
    connection: asyncssh.SSHClientConnection,
    working_directory: str,
    session: str,
    terminal_type: str = "xterm-256color",
    columns: int = 120,
    rows: int = 32,
    environment: dict[str, str] | None = None,
) -> asyncssh.SSHClientProcess:
    """Launch or attach to a tmux session on the remote host.

    Runs ``tmux new -As <session> -c <working_directory>`` in a PTY of the
    requested type and size. With ``-A``, tmux attaches to ``session`` when it
    already exists instead of creating a new one.

    The session name is reduced to letters, digits, ``-`` and ``_``. Any other
    character becomes ``_``, and an empty result becomes ``"sshler"``. Both
    command arguments are shell-quoted.

    Args:
        connection: Active SSH connection.
        working_directory: Directory passed to tmux via ``-c``.
        session: Desired session name (sanitised as described above).
        terminal_type: Terminal type requested for the PTY.
        columns: PTY width in characters.
        rows: PTY height in lines.
        environment: Environment variables passed to ``create_process``.

    Returns:
        asyncssh.SSHClientProcess: Process running the tmux client, opened
        with ``encoding=None`` so its I/O is bytes.
    """
    # '.' and ':' are tmux target separators (session:window.pane). Older tmux
    # rejects them in session names and newer tmux rewrites them to '_', so
    # map them to '_' here to get a name every version accepts unchanged.
    safe_session = (
        "".join(ch if ch.isalnum() or ch in "-_" else "_" for ch in session) or "sshler"
    )
    command = f"tmux new -As {shlex.quote(safe_session)} -c {shlex.quote(working_directory)}"
    process = await connection.create_process(
        command=command,
        term_type=terminal_type,
        term_size=(columns, rows),
        encoding=None,  # raw bytes in both directions
        env=environment,
    )
    return process
|
|
154
|
+
|
|
155
|
+
|
|
156
|
+
async def sftp_list_directory(
    connection: asyncssh.SSHClientConnection, path: str
) -> list[dict[str, object]]:
    """Return directory entries for ``path`` via SFTP.

    Each child name returned by ``listdir`` is ``stat``-ed (which follows
    symlinks) and reported as ``{"name", "is_directory", "size"}``. An entry
    whose ``stat`` fails, such as a dangling symlink or a permission error, is
    skipped rather than failing the whole listing.

    Args:
        connection: Active SSH connection used to start SFTP.
        path: Remote directory to enumerate.

    Returns:
        list[dict[str, object]]: Entries with directories first, then sorted
        case-insensitively by name.
    """
    base = path.rstrip("/")
    sftp_client = await connection.start_sftp_client()
    entries: list[dict[str, object]] = []
    try:
        for filename in await sftp_client.listdir(path):
            try:
                stats = await sftp_client.stat(f"{base}/{filename}")
            except Exception:  # best effort: one unreadable entry must not hide the rest
                LOGGER.debug("Skipping %s/%s: stat failed", base, filename, exc_info=True)
                continue
            mode = stats.permissions or 0  # the server may omit permissions
            entries.append(
                {
                    "name": filename,
                    # Compare the whole file-type field (S_IFMT = 0o170000). The
                    # lone S_IFDIR bit is also set for sockets and block devices.
                    "is_directory": (mode & 0o170000) == 0o040000,
                    "size": stats.size,
                }
            )
    finally:
        try:
            sftp_client.exit()  # SFTPClient.exit() is synchronous
            await sftp_client.wait_closed()
        except Exception:  # best-effort cleanup; never mask the listing result/error
            pass
    entries.sort(key=lambda entry: (not entry["is_directory"], str(entry["name"]).lower()))
    return entries
|
|
200
|
+
|
|
201
|
+
|
|
202
|
+
async def sftp_is_directory(connection: asyncssh.SSHClientConnection, path: str) -> bool:
    """Return whether ``path`` resolves to a directory via SFTP.

    Performs an SFTP ``stat`` (following symlinks) and checks the file-type
    bits of the returned permissions. Errors from ``stat`` propagate to the
    caller.

    Args:
        connection: Active SSH connection used to start SFTP.
        path: Remote path to probe.

    Returns:
        bool: ``True`` when ``path`` is a directory, otherwise ``False``
        (including when the server reports no permissions).
    """
    sftp_client = await connection.start_sftp_client()
    try:
        stats = await sftp_client.stat(path)
        # Compare the full S_IFMT field (0o170000). The lone 0o040000 bit is
        # also set for sockets (0o140000) and block devices (0o060000).
        return ((stats.permissions or 0) & 0o170000) == 0o040000
    finally:
        try:
            sftp_client.exit()  # SFTPClient.exit() is synchronous
            await sftp_client.wait_closed()
        except Exception:  # best-effort cleanup; never mask the primary result/error
            pass
|
|
228
|
+
|
|
229
|
+
|
|
230
|
+
async def sftp_read_file(
    connection: asyncssh.SSHClientConnection,
    path: str,
    max_bytes: int = 65536,
) -> str:
    """Read a text file over SFTP, truncated to ``max_bytes``.

    Opens the remote file in binary mode, reads up to ``max_bytes`` bytes and
    decodes them as UTF-8. Undecodable bytes, including a multi-byte
    character cut in half by the size limit, are replaced with U+FFFD
    instead of raising.

    Args:
        connection: Active SSH connection used to start SFTP.
        path: Remote file to read.
        max_bytes: Maximum number of bytes to read.

    Returns:
        str: The decoded (possibly truncated) file contents.
    """
    sftp_client = await connection.start_sftp_client()
    try:
        # Read bytes and decode ourselves. In text mode asyncssh decodes
        # strictly (errors="strict"), so bad UTF-8 would raise.
        async with await sftp_client.open(path, "rb", encoding=None) as remote_file:
            data = await remote_file.read(max_bytes)
        if isinstance(data, str):
            return data
        return data.decode("utf-8", errors="replace")
    finally:
        try:
            sftp_client.exit()  # SFTPClient.exit() is synchronous
            await sftp_client.wait_closed()
        except Exception:  # best-effort cleanup; never mask the primary result/error
            pass
|
|
258
|
+
|
|
259
|
+
|
|
260
|
+
async def _expand_alias(alias: str) -> dict[str, str | None]:
    """Expand an OpenSSH host alias by running ``ssh -G <alias>``.

    Returns:
        A dict with keys ``hostname``, ``user``, ``port`` and ``identityfile``.
        Values are ``None`` for anything ``ssh -G`` did not report.
        ``identityfile`` is the first listed identity file that exists
        locally after ``~`` expansion, or ``None`` if none of them exists.
        An empty dict is returned if ``ssh`` could not be started or exited
        with a non-zero status.
    """
    try:
        process = await asyncio.create_subprocess_exec(
            _ssh_command(),
            "-G",
            alias,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
        )
    except Exception:
        # Best effort: ssh may be missing, or the running event loop may not
        # support subprocesses. An empty dict means "alias not expanded".
        LOGGER.debug("Could not run ssh -G for alias %s", alias, exc_info=True)
        return {}

    stdout, _ = await process.communicate()
    if process.returncode != 0:
        return {}

    data: dict[str, str] = {}
    identity_files: list[str] = []
    for line in stdout.decode("utf-8", errors="replace").splitlines():
        key, _, value = line.partition(" ")
        key, value = key.strip().lower(), value.strip()
        if not key or not value:
            continue
        if key == "identityfile":
            # ssh -G prints every IdentityFile in order, including OpenSSH's
            # built-in defaults when none is configured. Keeping only the last
            # line would usually select a key that does not exist.
            identity_files.append(value)
        else:
            data[key] = value

    identity_file = next(
        (
            expanded
            for expanded in (os.path.expanduser(item) for item in identity_files)
            if Path(expanded).is_file()
        ),
        None,
    )
    return {
        "hostname": data.get("hostname"),
        "user": data.get("user"),
        "port": data.get("port"),
        "identityfile": identity_file,
    }
|
|
289
|
+
|
|
290
|
+
|
|
291
|
+
def _is_resolvable(name: str) -> bool:
|
|
292
|
+
"""Return whether ``name`` resolves via DNS/system hosts.
|
|
293
|
+
|
|
294
|
+
English:
|
|
295
|
+
Lightweight guard to decide if ``ssh -G`` alias expansion is needed.
|
|
296
|
+
|
|
297
|
+
日本語:
|
|
298
|
+
DNS や hosts ファイルで名前解決できるかを確認し、エイリアス解決が必要かどうかを
|
|
299
|
+
判定します。
|
|
300
|
+
"""
|
|
301
|
+
|
|
302
|
+
try:
|
|
303
|
+
socket.getaddrinfo(name, None)
|
|
304
|
+
return True
|
|
305
|
+
except OSError:
|
|
306
|
+
return False
|
|
307
|
+
|
|
308
|
+
|
|
309
|
+
def _ssh_command() -> str:
|
|
310
|
+
"""Return the preferred ``ssh`` executable path for the current OS.
|
|
311
|
+
|
|
312
|
+
English:
|
|
313
|
+
On Windows the system OpenSSH path is used to avoid PATH hijacking;
|
|
314
|
+
otherwise ``ssh`` from the user's PATH is returned.
|
|
315
|
+
|
|
316
|
+
日本語:
|
|
317
|
+
Windows では PATH 乗っ取りを避けるためにシステムの OpenSSH を優先し、
|
|
318
|
+
それ以外ではユーザーの PATH にある ``ssh`` を利用します。
|
|
319
|
+
"""
|
|
320
|
+
|
|
321
|
+
if os.name == "nt":
|
|
322
|
+
system_root = os.environ.get("SystemRoot", "C:\\Windows")
|
|
323
|
+
candidate = Path(system_root) / "System32" / "OpenSSH" / "ssh.exe"
|
|
324
|
+
if candidate.exists():
|
|
325
|
+
return str(candidate)
|
|
326
|
+
return "ssh"
|
|
327
|
+
|
|
328
|
+
|
|
329
|
+
# Module logger. The functions above (e.g. connect) look it up only when they
# run, so defining it at the bottom of the module is safe.
LOGGER = logging.getLogger(__name__)
|
sshler/ssh_config.py
ADDED
|
@@ -0,0 +1,134 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import os
|
|
4
|
+
import shlex
|
|
5
|
+
from dataclasses import dataclass, field
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
from typing import Any
|
|
8
|
+
|
|
9
|
+
# Environment variable that can point sshler at an alternative SSH config file.
SSH_CONFIG_ENV = "SSHLER_SSH_CONFIG"
# Per-user OpenSSH client config, used when neither an explicit path nor the
# environment variable is provided.
DEFAULT_SSH_CONFIG = Path.home().joinpath(".ssh", "config")
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
@dataclass
class HostConfig:
    """One ``Host`` alias from an OpenSSH config, reduced to what sshler uses.

    Attributes:
        name: The literal alias taken from the ``Host`` line.
        hostname: ``HostName`` value, if present.
        user: ``User`` value, if present.
        port: ``Port`` as an integer, if present.
        identity_files: ``IdentityFile`` paths after expansion, in file order.
        raw: Every option seen in the block, keyed by lower-cased keyword.
    """

    # Required alias; every other field is optional.
    name: str
    hostname: str | None = None
    user: str | None = None
    port: int | None = None
    # Factories give each instance its own list/dict.
    identity_files: list[str] = field(default_factory=list)
    raw: dict[str, Any] = field(default_factory=dict)
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def _expand_path(value: str, *, base_dir: Path) -> str:
|
|
26
|
+
expanded = os.path.expandvars(os.path.expanduser(value))
|
|
27
|
+
expanded_path = Path(expanded)
|
|
28
|
+
if expanded_path.is_absolute():
|
|
29
|
+
return str(expanded_path)
|
|
30
|
+
return str((base_dir / expanded_path).resolve())
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def _split_directive(line: str) -> tuple[str, list[str]] | None:
    """Split one ssh_config line into ``(lower-cased keyword, arguments)``.

    Returns ``None`` for blank lines, comment lines and lines ``shlex`` cannot
    tokenize. Text after ``#`` is dropped. All three separator styles OpenSSH
    accepts are handled: ``Key value``, ``Key=value`` and ``Key = value``.
    """
    stripped = line.strip()
    if not stripped or stripped.startswith("#"):
        return None
    if "#" in stripped:
        stripped = stripped.partition("#")[0].strip()
        if not stripped:
            return None
    try:
        tokens = shlex.split(stripped, comments=False, posix=True)
    except ValueError:
        return None
    if not tokens:
        return None

    keyword, has_equals, inline_value = tokens[0].partition("=")
    values = tokens[1:]
    if has_equals:
        # "Key=value" or "Key= value"
        if inline_value:
            values.insert(0, inline_value)
    elif values and values[0].startswith("="):
        # "Key = value" or "Key =value"
        remainder = values[0][1:]
        values = ([remainder] if remainder else []) + values[1:]
    if not keyword:
        return None
    return keyword.lower(), values


def _parse_port(value: str | None) -> int | None:
    """Return ``value`` as an int port, or ``None`` if absent or not numeric."""
    if value is None:
        return None
    try:
        return int(value)
    except (TypeError, ValueError):
        return None


def _include_targets(patterns: list[str], *, base_dir: Path) -> list[Path]:
    """Expand ``Include`` arguments into the paths they match.

    ``~`` is expanded and relative patterns are anchored at ``base_dir``;
    absolute patterns are used as-is. Each argument is globbed on its own,
    and its matches are sorted so inclusion order is deterministic.
    """
    import glob

    targets: list[Path] = []
    for pattern in patterns:
        expanded = os.path.expanduser(pattern)
        if not os.path.isabs(expanded):
            expanded = os.path.join(str(base_dir), expanded)
        targets.extend(Path(match) for match in sorted(glob.glob(expanded)))
    return targets


def _parse_file(path: Path, *, seen: set[Path]) -> dict[str, HostConfig]:
    """Parse one OpenSSH client config file, following ``Include`` directives.

    Only concrete aliases from ``Host`` lines become entries. Wildcard
    (``*``/``?``) and negated (``!``) patterns are skipped, and ``Match``
    blocks are ignored. Within a block the last value given for a keyword
    wins, and a later block for the same alias replaces an earlier one.
    Unparseable ports become ``None`` instead of aborting the load.

    Args:
        path: Config file to read. A missing path or non-file yields ``{}``.
        seen: Resolved paths already parsed. Updated in place so that include
            cycles terminate.

    Returns:
        Mapping of alias name to :class:`HostConfig`.
    """
    hosts: dict[str, HostConfig] = {}
    if not path.is_file():
        return hosts

    real_path = path.resolve()
    if real_path in seen:
        return hosts
    seen.add(real_path)

    current_patterns: list[str] = []
    current_options: dict[str, Any] = {}

    def flush() -> None:
        """Emit entries for the pending Host block and reset the block state."""
        nonlocal current_patterns, current_options
        if not current_patterns:
            return
        port = _parse_port(current_options.get("port"))
        identity_files = [
            _expand_path(value, base_dir=path.parent)
            for value in current_options.get("identityfile", [])
        ]
        for pattern in current_patterns:
            # Wildcards and negations describe sets of hosts, not aliases.
            if pattern.startswith("!") or any(ch in pattern for ch in "*? "):
                continue
            hosts[pattern] = HostConfig(
                name=pattern,
                hostname=current_options.get("hostname"),
                user=current_options.get("user"),
                port=port,
                identity_files=list(identity_files),
                raw=current_options.copy(),
            )
        current_patterns = []
        current_options = {}

    # errors="replace": a stray non-UTF-8 byte must not abort the whole load.
    with path.open("r", encoding="utf-8", errors="replace") as file_pointer:
        for line in file_pointer:
            directive = _split_directive(line)
            if directive is None:
                continue
            keyword, values = directive

            if keyword == "host":
                flush()
                current_patterns = values
                current_options = {}
                continue

            if keyword == "include":
                for include_path in _include_targets(values, base_dir=path.parent):
                    hosts.update(_parse_file(include_path, seen=seen))
                continue

            if keyword == "match":
                flush()
                current_patterns = []
                current_options = {}
                continue

            if not current_patterns:
                continue

            if keyword in {"hostname", "user", "port"} and values:
                current_options[keyword] = values[-1]
            elif keyword == "identityfile" and values:
                current_options.setdefault("identityfile", []).extend(values)
            else:
                current_options[keyword] = values[-1] if values else None

    flush()
    return hosts
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
def load_ssh_config(explicit_path: str | None = None) -> dict[str, HostConfig]:
    """Load host aliases from an OpenSSH client config file.

    The file used is ``explicit_path`` if given, else the path in the
    ``SSHLER_SSH_CONFIG`` environment variable, else ``~/.ssh/config``. ``~``
    is expanded in the first two. Parsing starts with an empty set of
    visited files.
    """
    override = explicit_path or os.getenv(SSH_CONFIG_ENV)
    config_file = Path(override).expanduser() if override else DEFAULT_SSH_CONFIG
    return dict(_parse_file(config_file, seen=set()))
|
|
134
|
+
|
sshler/state.py
ADDED
|
@@ -0,0 +1,230 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
import threading
|
|
5
|
+
import time
|
|
6
|
+
from collections.abc import Iterable, Sequence
|
|
7
|
+
from pathlib import Path
|
|
8
|
+
from typing import TYPE_CHECKING
|
|
9
|
+
|
|
10
|
+
from pydantic import Field
|
|
11
|
+
from sqler import SQLerDB, SQLerModel
|
|
12
|
+
from sqler.adapter import SQLiteAdapter
|
|
13
|
+
from sqler.query import SQLerField as F
|
|
14
|
+
|
|
15
|
+
# File name of the SQLite state database, created inside the config directory.
STATE_FILENAME = "state.sqlite"

# Module-level singleton state, mutated under _DB_LOCK. The lock is reentrant
# so locked helpers can call each other (e.g. favorites_map -> list_favorites).
_DB_LOCK = threading.RLock()
# Active database handle; None until initialize() has run.
_DB: SQLerDB | None = None
# Path of the database file _DB was opened from.
_DB_PATH: Path | None = None
# True once initialize() has completed.
_INITIALISED = False
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class Favorite(SQLerModel):
    """A favourite directory bookmarked for one box, persisted via sqler."""

    __tablename__ = "favorites"

    # Name of the box the favourite belongs to.
    box: str
    # Directory path being bookmarked.
    path: str
    # Sort key within a box; favourites are listed in ascending position.
    position: int = 0
    # Timestamps from time.time() (seconds since the epoch).
    created_at: float = Field(default_factory=time.time)
    updated_at: float = Field(default_factory=time.time)
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
if TYPE_CHECKING: # pragma: no cover - import for typing only
|
|
36
|
+
from .config import StoredBox
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def initialize(config_dir: Path) -> None:
    """Open (or switch to) the SQLite-backed state store using ``sqler``.

    The database is ``config_dir / STATE_FILENAME``, and ``config_dir`` is
    created if needed. It holds favourites and is intended for future history
    records.

    Calling again with the same directory is a no-op. Calling with a
    different directory closes the previous handle before opening the new
    one.
    """
    global _DB, _DB_PATH, _INITIALISED

    config_dir = config_dir.expanduser()
    config_dir.mkdir(parents=True, exist_ok=True)
    target_path = config_dir / STATE_FILENAME

    with _DB_LOCK:
        if _INITIALISED and _DB_PATH == target_path:
            return

        if _DB is not None and _DB_PATH != target_path:
            previous = _DB
            # Forget the old handle before closing it. If anything below
            # fails, a closed database must not still look initialised.
            _DB = None
            _DB_PATH = None
            _INITIALISED = False
            previous.close()

        adapter = SQLiteAdapter(path=str(target_path))
        db = SQLerDB(adapter)
        Favorite.set_db(db)
        Favorite.ensure_index("box")
        Favorite.ensure_index("position")

        _DB = db
        _DB_PATH = target_path
        _INITIALISED = True
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
def reset_state() -> None:
    """Close and forget the current state database (used by tests).

    Errors raised while closing are ignored. ``initialize`` must be called
    again before the store can be used.
    """
    global _DB, _DB_PATH, _INITIALISED
    with _DB_LOCK:
        handle = _DB
        if handle is not None:
            try:
                handle.close()
            except Exception:  # pragma: no cover - best effort cleanup
                pass
        _DB, _DB_PATH, _INITIALISED = None, None, False
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
def _require_db() -> SQLerDB:
    """Return the active database handle, raising if ``initialize`` has not run."""
    db = _DB
    if _INITIALISED and db is not None:
        return db
    raise RuntimeError("State store not initialised")
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
def migrate_legacy_favorites(stored: dict[str, StoredBox]) -> bool:
    """Move favourites still held on YAML box entries into the sqler store.

    For every box with a non-empty ``favorites`` list, the list is written
    with ``replace_favorites`` and then cleared in place. Returns ``True``
    if at least one box was migrated.
    """
    if not stored:
        return False

    _require_db()
    with _DB_LOCK:
        legacy = [item for item in stored.values() if item.favorites]
        for item in legacy:
            replace_favorites(item.name, item.favorites)
            item.favorites.clear()
    return bool(legacy)
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
def list_favorites(box_name: str) -> list[str]:
    """Return the favourite paths of ``box_name`` ordered by ``position``."""
    _require_db()
    with _DB_LOCK:
        query = Favorite.query().filter(F("box") == box_name).order_by("position")
        return [favourite.path for favourite in query.all()]
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
async def list_favorites_async(box_name: str) -> list[str]:
    """Async wrapper that runs :func:`list_favorites` in a worker thread."""
    favourites = await asyncio.to_thread(list_favorites, box_name)
    return favourites
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
def favorites_map(box_names: Sequence[str] | None = None) -> dict[str, list[str]]:
    """Map box names to their ordered favourite paths.

    When ``box_names`` is given, each listed box gets an entry, possibly an
    empty list. When it is ``None``, only boxes with at least one favourite
    appear.
    """
    _require_db()
    with _DB_LOCK:
        if box_names is None:
            grouped: dict[str, list[str]] = {}
            for favourite in Favorite.query().order_by("box").order_by("position").all():
                grouped.setdefault(favourite.box, []).append(favourite.path)
            return grouped
        return {name: list_favorites(name) for name in box_names}
|
|
140
|
+
|
|
141
|
+
|
|
142
|
+
def toggle_favorite(box_name: str, path: str) -> bool:
    """Flip whether ``path`` is a favourite of ``box_name``.

    Returns ``True`` when the path was added, placed after the current last
    position. Returns ``False`` when it was removed or when ``path`` is
    empty.
    """
    if not path:
        return False

    _require_db()
    timestamp = time.time()
    with _DB_LOCK:
        match = (
            Favorite.query()
            .filter((F("box") == box_name) & (F("path") == path))
            .first()
        )
        if match:
            match.delete()
            return False

        Favorite(
            box=box_name,
            path=path,
            position=_next_position(box_name),
            created_at=timestamp,
            updated_at=timestamp,
        ).save()
        return True
|
|
160
|
+
|
|
161
|
+
|
|
162
|
+
async def toggle_favorite_async(box_name: str, path: str) -> bool:
    """Async wrapper that runs :func:`toggle_favorite` in a worker thread."""
    added = await asyncio.to_thread(toggle_favorite, box_name, path)
    return added
|
|
164
|
+
|
|
165
|
+
|
|
166
|
+
def replace_favorites(box_name: str, paths: Iterable[str]) -> None:
    """Make ``paths`` the complete, ordered favourites list of ``box_name``.

    Paths are stripped, and empty strings and repeats are dropped (the first
    occurrence wins). Existing rows are reused and re-positioned only when
    their position changes, and missing rows are created. Every other row
    for the box is deleted, including duplicate rows stored for the same
    path.
    """
    _require_db()
    deduped: list[str] = []
    seen: set[str] = set()
    for raw in paths:
        cleaned = raw.strip()
        if not cleaned or cleaned in seen:
            continue
        deduped.append(cleaned)
        seen.add(cleaned)

    now = time.time()
    with _DB_LOCK:
        existing: dict[str, Favorite] = {}
        stale: list[Favorite] = []
        for favourite in Favorite.query().filter(F("box") == box_name).all():
            # Keep one row per path. Extra rows for the same path would
            # otherwise vanish from the lookup and survive the replacement.
            if favourite.path in existing:
                stale.append(favourite)
            else:
                existing[favourite.path] = favourite

        for position, path in enumerate(deduped):
            favourite = existing.pop(path, None)
            if favourite is None:
                Favorite(
                    box=box_name,
                    path=path,
                    position=position,
                    created_at=now,
                    updated_at=now,
                ).save()
                continue

            if favourite.position != position:
                favourite.position = position
                favourite.updated_at = now
                favourite.save()

        stale.extend(existing.values())
        for leftover in stale:
            leftover.delete()
|
|
205
|
+
|
|
206
|
+
|
|
207
|
+
async def replace_favorites_async(box_name: str, paths: Iterable[str]) -> None:
    """Async wrapper that runs :func:`replace_favorites` in a worker thread.

    ``paths`` is copied into a list on the calling thread first.
    """
    snapshot = list(paths)
    await asyncio.to_thread(replace_favorites, box_name, snapshot)
|
|
209
|
+
|
|
210
|
+
|
|
211
|
+
def remove_box(box_name: str) -> None:
    """Delete every favourite stored for ``box_name``."""
    _require_db()
    with _DB_LOCK:
        for favourite in Favorite.query().filter(F("box") == box_name).all():
            favourite.delete()
|
|
219
|
+
|
|
220
|
+
|
|
221
|
+
def _next_position(box_name: str) -> int:
    """Return one past the highest ``position`` used by ``box_name`` (0 if it has none)."""
    last = (
        Favorite.query()
        .filter(F("box") == box_name)
        .order_by("position", desc=True)
        .first()
    )
    return last.position + 1 if last else 0
|