aws-bootstrap-g4dn 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aws_bootstrap/__init__.py +1 -0
- aws_bootstrap/cli.py +438 -0
- aws_bootstrap/config.py +24 -0
- aws_bootstrap/ec2.py +341 -0
- aws_bootstrap/resources/__init__.py +0 -0
- aws_bootstrap/resources/gpu_benchmark.py +839 -0
- aws_bootstrap/resources/gpu_smoke_test.ipynb +340 -0
- aws_bootstrap/resources/remote_setup.sh +188 -0
- aws_bootstrap/resources/requirements.txt +8 -0
- aws_bootstrap/ssh.py +513 -0
- aws_bootstrap/tests/__init__.py +0 -0
- aws_bootstrap/tests/test_cli.py +528 -0
- aws_bootstrap/tests/test_config.py +35 -0
- aws_bootstrap/tests/test_ec2.py +313 -0
- aws_bootstrap/tests/test_ssh_config.py +297 -0
- aws_bootstrap/tests/test_ssh_gpu.py +138 -0
- aws_bootstrap_g4dn-0.1.0.dist-info/METADATA +308 -0
- aws_bootstrap_g4dn-0.1.0.dist-info/RECORD +22 -0
- aws_bootstrap_g4dn-0.1.0.dist-info/WHEEL +5 -0
- aws_bootstrap_g4dn-0.1.0.dist-info/entry_points.txt +2 -0
- aws_bootstrap_g4dn-0.1.0.dist-info/licenses/LICENSE +21 -0
- aws_bootstrap_g4dn-0.1.0.dist-info/top_level.txt +1 -0
aws_bootstrap/ssh.py
ADDED
|
@@ -0,0 +1,513 @@
|
|
|
1
|
+
"""SSH key pair management and SSH config management for EC2 instances."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
import os
|
|
5
|
+
import re
|
|
6
|
+
import socket
|
|
7
|
+
import subprocess
|
|
8
|
+
import tempfile
|
|
9
|
+
import time
|
|
10
|
+
from dataclasses import dataclass
|
|
11
|
+
from pathlib import Path
|
|
12
|
+
|
|
13
|
+
import click
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
# ---------------------------------------------------------------------------
|
|
17
|
+
# SSH config markers
|
|
18
|
+
# ---------------------------------------------------------------------------
|
|
19
|
+
|
|
20
|
+
# Comment lines that fence each managed stanza in the SSH config; the
# ``{instance_id}`` placeholder is filled in with ``str.format``.
_BEGIN_MARKER = "# >>> aws-bootstrap [{instance_id}] >>>"
_END_MARKER = "# <<< aws-bootstrap [{instance_id}] <<<"

# Recognise any managed begin/end marker line; the named group ``iid``
# captures the instance ID (``i-`` followed by lowercase hex digits).
_BEGIN_RE = re.compile(r"^# >>> aws-bootstrap \[(?P<iid>i-[a-f0-9]+)\] >>>$")
_END_RE = re.compile(r"^# <<< aws-bootstrap \[(?P<iid>i-[a-f0-9]+)\] <<<$")

# Config file used when a caller does not pass an explicit path
# (computed once, when the module is imported).
_DEFAULT_SSH_CONFIG = Path.home().joinpath(".ssh", "config")
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def private_key_path(key_path: Path) -> Path:
    """Return the private-key counterpart of *key_path*.

    A path whose final suffix is ``.pub`` has that suffix dropped; any other
    path is handed back unchanged.
    """
    if key_path.suffix != ".pub":
        return key_path
    return key_path.with_suffix("")
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def _ssh_opts(key_path: Path) -> list[str]:
    """Return the option list shared by every ssh/scp call in this module.

    Host-key verification is switched off (``StrictHostKeyChecking=no`` with
    known hosts sent to ``/dev/null``), and the private key derived from
    *key_path* is supplied via ``-i``.
    """
    identity = str(private_key_path(key_path))
    opts: list[str] = []
    for setting in ("StrictHostKeyChecking=no", "UserKnownHostsFile=/dev/null"):
        opts += ["-o", setting]
    opts += ["-i", identity]
    return opts
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def import_key_pair(ec2_client, key_name: str, key_path: Path) -> str:
    """Import a local SSH public key to AWS, reusing if it already exists.

    Args:
        ec2_client: EC2 client exposing ``describe_key_pairs``,
            ``import_key_pair`` and ``exceptions.ClientError``.
        key_name: Name of the EC2 key pair to look up or create.
        key_path: Local public-key file uploaded when the key pair is missing.

    Returns:
        The key pair name.

    Raises:
        ClientError: For any ``describe_key_pairs`` failure other than
            ``InvalidKeyPair.NotFound``, and for any ``import_key_pair`` failure.

    Note: an existing key pair is reused by name only; its key material is not
    compared with the local file.
    """
    pub_key = key_path.read_bytes()

    # Check if key pair already exists
    try:
        existing = ec2_client.describe_key_pairs(KeyNames=[key_name])
    except ec2_client.exceptions.ClientError as e:
        # Decide on the structured AWS error code, not on the rendered message text.
        error_code = e.response.get("Error", {}).get("Code")
        if error_code != "InvalidKeyPair.NotFound":
            raise
    else:
        click.echo(" Key pair " + click.style(f"'{key_name}'", fg="bright_white") + " already exists, reusing.")
        return existing["KeyPairs"][0]["KeyName"]

    ec2_client.import_key_pair(
        KeyName=key_name,
        PublicKeyMaterial=pub_key,
        TagSpecifications=[
            {
                "ResourceType": "key-pair",
                "Tags": [{"Key": "created-by", "Value": "aws-bootstrap-g4dn"}],
            }
        ],
    )
    click.secho(f" Imported key pair '{key_name}' from {key_path}", fg="green")
    return key_name
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
def wait_for_ssh(host: str, user: str, key_path: Path, retries: int = 30, delay: int = 10) -> bool:
    """Wait for SSH to become available on the instance.

    Each attempt first opens a plain TCP connection to port 22 (5 s timeout).
    Only if that succeeds is ``echo ok`` run over SSH in batch mode with a
    10 s connect timeout. Each failed attempt is reported and followed by a
    *delay*-second sleep, except after the final attempt.

    Args:
        host: Address of the instance.
        user: Remote login user.
        key_path: Public or private key path (``.pub`` is stripped for ``-i``).
        retries: Maximum number of attempts.
        delay: Seconds to sleep between attempts.

    Returns:
        ``True`` as soon as the SSH command exits 0, ``False`` once every
        attempt has failed.
    """
    base_opts = _ssh_opts(key_path)

    for attempt in range(1, retries + 1):
        # Cheap TCP probe first: no point starting ssh while port 22 is closed.
        # OSError already covers TimeoutError, ConnectionRefusedError and DNS errors.
        try:
            with socket.create_connection((host, 22), timeout=5):
                port_open = True
        except OSError:
            port_open = False

        if port_open:
            # Port is open, try actual SSH
            result = subprocess.run(
                ["ssh", *base_opts, "-o", "ConnectTimeout=10", "-o", "BatchMode=yes", f"{user}@{host}", "echo ok"],
                capture_output=True,
                text=True,
            )
            if result.returncode == 0:
                click.secho(" SSH connection established.", fg="green")
                return True

        click.echo(" SSH not ready " + click.style(f"(attempt {attempt}/{retries})", dim=True) + ", waiting...")
        if attempt < retries:
            # Sleeping after the last attempt would only delay the False result.
            time.sleep(delay)

    return False
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
def run_remote_setup(host: str, user: str, key_path: Path, script_path: Path) -> bool:
    """SCP the setup script and its companion files to the instance and execute it.

    The companion files ``requirements.txt``, ``gpu_benchmark.py`` and
    ``gpu_smoke_test.ipynb`` are read from the directory containing
    *script_path*. Every file is copied into ``/tmp`` on the remote host. The
    script is always stored as ``/tmp/remote_setup.sh``, whatever its local
    name, and is then run with its output streamed to the local terminal.

    Returns:
        ``True`` if every upload succeeded and the remote script exited with
        status 0. ``False`` on the first failed upload (after printing scp's
        stderr), or if the script exits non-zero.
    """
    ssh_opts = _ssh_opts(key_path)
    resource_dir = script_path.parent

    # (local file, remote file name) in upload order; the setup script goes last.
    uploads = [
        (resource_dir / "requirements.txt", "requirements.txt"),
        (resource_dir / "gpu_benchmark.py", "gpu_benchmark.py"),
        (resource_dir / "gpu_smoke_test.ipynb", "gpu_smoke_test.ipynb"),
        (script_path, "remote_setup.sh"),
    ]
    for local_path, remote_name in uploads:
        click.echo(f" Uploading {remote_name}...")
        result = subprocess.run(
            ["scp", *ssh_opts, str(local_path), f"{user}@{host}:/tmp/{remote_name}"],
            capture_output=True,
            text=True,
        )
        if result.returncode != 0:
            click.secho(f" SCP failed: {result.stderr}", fg="red", err=True)
            return False

    # Execute the script; its output is not captured so the user sees progress live.
    click.echo(" Running remote_setup.sh on instance...")
    ssh_result = subprocess.run(
        ["ssh", *ssh_opts, f"{user}@{host}", "chmod +x /tmp/remote_setup.sh && /tmp/remote_setup.sh"],
        capture_output=False,
    )
    return ssh_result.returncode == 0
|
|
166
|
+
|
|
167
|
+
|
|
168
|
+
# ---------------------------------------------------------------------------
|
|
169
|
+
# SSH config management
|
|
170
|
+
# ---------------------------------------------------------------------------
|
|
171
|
+
|
|
172
|
+
|
|
173
|
+
def _read_ssh_config(config_path: Path) -> str:
    """Read SSH config content. Returns ``""`` if file doesn't exist.

    The file is decoded as UTF-8, matching the ``content.encode()`` (UTF-8)
    used when this module writes the config. A path whose parent component is
    not a directory also counts as "doesn't exist". Other errors, such as a
    permission error, propagate.
    """
    # EAFP: avoids the race between an exists() check and the actual read.
    try:
        return config_path.read_text(encoding="utf-8")
    except (FileNotFoundError, NotADirectoryError):
        return ""
|
|
178
|
+
|
|
179
|
+
|
|
180
|
+
def _write_ssh_config(config_path: Path, content: str) -> None:
    """Atomically write *content* to *config_path*.

    Creates the parent directory if needed (the final directory with mode
    0700). *content* is UTF-8 encoded into a temporary file in that directory
    and given mode 0600. The temporary file is then moved over *config_path*
    with ``os.replace``. On any failure the temporary file is removed and the
    original exception is re-raised.
    """
    ssh_dir = config_path.parent
    ssh_dir.mkdir(mode=0o700, parents=True, exist_ok=True)

    fd, tmp = tempfile.mkstemp(dir=str(ssh_dir), prefix=".ssh_config_tmp_")
    try:
        # os.fdopen takes ownership of fd: the ``with`` closes it exactly once,
        # even if the write fails, and write() handles partial writes.
        with os.fdopen(fd, "wb") as fh:
            fh.write(content.encode())
        os.chmod(tmp, 0o600)
        os.replace(tmp, str(config_path))
    except BaseException:
        # fd is already closed here; only the temp file may need cleaning up.
        if os.path.exists(tmp):
            os.unlink(tmp)
        raise
|
|
199
|
+
|
|
200
|
+
|
|
201
|
+
def _next_alias(content: str, prefix: str = "aws-gpu") -> str:
    """Return the alias after the highest managed ``<prefix><N>`` in *content*.

    Only ``Host`` lines inside aws-bootstrap marker blocks are counted, so a
    user-defined host that happens to share the prefix does not affect the
    numbering. When no numbered managed alias exists the result is
    ``<prefix>1``.
    """
    numbers = [0]
    inside = False
    for raw in content.splitlines():
        if _BEGIN_RE.match(raw):
            inside = True
        elif _END_RE.match(raw):
            inside = False
        elif inside:
            text = raw.strip()
            if not text.startswith("Host "):
                continue
            name = text.removeprefix("Host ").strip()
            tail = name[len(prefix) :] if name.startswith(prefix) else ""
            if tail.isdigit():
                numbers.append(int(tail))
    return f"{prefix}{max(numbers) + 1}"
|
|
223
|
+
|
|
224
|
+
|
|
225
|
+
def _build_stanza(instance_id: str, alias: str, hostname: str, user: str, key_path: Path) -> str:
    """Render the managed SSH config block for one instance.

    The block is fenced by begin/end marker comments keyed on *instance_id*.
    ``IdentityFile`` points at the private key derived from *key_path*, and
    host-key checking is disabled. Every line, the end marker included, is
    newline-terminated.
    """
    rows = [
        _BEGIN_MARKER.format(instance_id=instance_id),
        f"Host {alias}",
        f" HostName {hostname}",
        f" User {user}",
        f" IdentityFile {private_key_path(key_path)}",
        " StrictHostKeyChecking no",
        " UserKnownHostsFile /dev/null",
        _END_MARKER.format(instance_id=instance_id),
    ]
    return "\n".join(rows) + "\n"
|
|
238
|
+
|
|
239
|
+
|
|
240
|
+
def add_ssh_host(
    instance_id: str,
    hostname: str,
    user: str,
    key_path: Path,
    config_path: Path | None = None,
    alias_prefix: str = "aws-gpu",
) -> str:
    """Add (or update) an SSH host stanza for *instance_id*.

    If a complete managed block for the instance already exists, it is
    replaced and its alias is kept. Otherwise the next ``<alias_prefix><N>``
    is allocated. The stanza is appended to the end of the file, separated
    from any preceding content by one blank line.

    Args:
        instance_id: EC2 instance ID used to key the marker block.
        hostname: Value for ``HostName``.
        user: Value for ``User``.
        key_path: Public or private key path; the private key goes in ``IdentityFile``.
        config_path: SSH config file; defaults to ``~/.ssh/config``.
        alias_prefix: Prefix for newly allocated aliases.

    Returns:
        The alias that was created or kept (e.g. ``aws-gpu1``).
    """
    config_path = config_path or _DEFAULT_SSH_CONFIG
    content = _read_ssh_config(config_path)

    # Idempotent: if this instance already has a stanza, remember its alias
    existing_alias = _find_alias_in_content(content, instance_id)
    content = _remove_block(content, instance_id)

    alias = existing_alias or _next_alias(content, alias_prefix)
    stanza = _build_stanza(instance_id, alias, hostname, user, key_path)

    # Leave exactly one blank line before our block: terminate an unterminated
    # last line and add the blank line; content already ending in a blank line
    # is left alone.
    if content and not content.endswith("\n"):
        content += "\n\n"
    elif content and not content.endswith("\n\n"):
        content += "\n"

    content += stanza
    _write_ssh_config(config_path, content)
    return alias
|
|
271
|
+
|
|
272
|
+
|
|
273
|
+
def remove_ssh_host(instance_id: str, config_path: Path | None = None) -> str | None:
    """Delete the managed SSH stanza belonging to *instance_id*.

    The config file is rewritten only when a complete managed block with a
    ``Host`` line is found for the instance. Returns that block's alias.
    Returns ``None``, and writes nothing, when the file is missing or empty
    or holds no such block.
    """
    path = config_path or _DEFAULT_SSH_CONFIG
    text = _read_ssh_config(path)
    alias = _find_alias_in_content(text, instance_id) if text else None
    if alias is not None:
        _write_ssh_config(path, _remove_block(text, instance_id))
    return alias
|
|
290
|
+
|
|
291
|
+
|
|
292
|
+
def find_ssh_alias(instance_id: str, config_path: Path | None = None) -> str | None:
    """Return the alias recorded for *instance_id* without modifying the config.

    Yields ``None`` when the config file is missing or has no complete
    managed block for the instance.
    """
    text = _read_ssh_config(config_path if config_path else _DEFAULT_SSH_CONFIG)
    return _find_alias_in_content(text, instance_id)
|
|
297
|
+
|
|
298
|
+
|
|
299
|
+
def list_ssh_hosts(config_path: Path | None = None) -> dict[str, str]:
    """Map every aws-bootstrap-managed instance ID to its SSH alias.

    Each ``Host`` line seen after a begin marker is recorded under that
    marker's instance ID. A later ``Host`` line in the same block overwrites
    an earlier one. A block missing its end marker keeps collecting until the
    next begin or end marker.
    """
    text = _read_ssh_config(config_path or _DEFAULT_SSH_CONFIG)
    hosts: dict[str, str] = {}
    owner: str | None = None
    for raw in text.splitlines():
        opened = _BEGIN_RE.match(raw)
        if opened is not None:
            owner = opened.group("iid")
        elif _END_RE.match(raw) is not None:
            owner = None
        elif owner:
            stripped = raw.strip()
            if stripped.startswith("Host "):
                hosts[owner] = stripped.removeprefix("Host ").strip()
    return hosts
|
|
318
|
+
|
|
319
|
+
|
|
320
|
+
# ---------------------------------------------------------------------------
|
|
321
|
+
# GPU info via SSH
|
|
322
|
+
# ---------------------------------------------------------------------------
|
|
323
|
+
|
|
324
|
+
# Maps the ``compute_cap`` string reported by ``nvidia-smi`` (e.g. "7.5") to
# the NVIDIA architecture name. Capabilities not listed here are reported as
# ``Unknown (<cap>)`` by the lookup in ``query_gpu_info``.
_GPU_ARCHITECTURES: dict[str, str] = {
    "3.5": "Kepler",
    "3.7": "Kepler",
    "5.0": "Maxwell",
    "5.2": "Maxwell",
    "6.0": "Pascal",
    "6.1": "Pascal",
    "7.0": "Volta",
    "7.2": "Volta",
    "7.5": "Turing",
    "8.0": "Ampere",
    "8.6": "Ampere",
    "8.7": "Ampere",
    "8.9": "Ada Lovelace",
    "9.0": "Hopper",
    "10.0": "Blackwell",
    "12.0": "Blackwell",
}
|
|
333
|
+
|
|
334
|
+
|
|
335
|
+
@dataclass
class SSHHostDetails:
    """Host, login user and key file read from one managed SSH config stanza."""

    # ``HostName`` directive: the address ssh connects to.
    hostname: str
    # ``User`` directive: the remote login name.
    user: str
    # ``IdentityFile`` directive, converted to a Path.
    identity_file: Path
|
|
342
|
+
|
|
343
|
+
|
|
344
|
+
@dataclass
|
|
345
|
+
class GpuInfo:
|
|
346
|
+
"""GPU information retrieved via nvidia-smi and nvcc."""
|
|
347
|
+
|
|
348
|
+
driver_version: str
|
|
349
|
+
cuda_driver_version: str # max CUDA version supported by driver (from nvidia-smi)
|
|
350
|
+
cuda_toolkit_version: str | None # actual CUDA toolkit installed (from nvcc), None if unavailable
|
|
351
|
+
gpu_name: str
|
|
352
|
+
compute_capability: str
|
|
353
|
+
architecture: str
|
|
354
|
+
|
|
355
|
+
|
|
356
|
+
def get_ssh_host_details(instance_id: str, config_path: Path | None = None) -> SSHHostDetails | None:
    """Collect HostName, User and IdentityFile from *instance_id*'s managed block.

    Returns ``None`` in three cases: the file is missing or empty; no begin
    marker for the instance is later followed by its end marker; or any of
    the three directives is absent from the block. When a directive appears
    more than once, the last occurrence wins.
    """
    text = _read_ssh_config(config_path or _DEFAULT_SSH_CONFIG)
    if not text:
        return None

    opening = _BEGIN_MARKER.format(instance_id=instance_id)
    closing = _END_MARKER.format(instance_id=instance_id)
    found: dict[str, str] = {}
    inside = False

    for raw in text.splitlines():
        if raw == opening:
            inside = True
            continue
        if not inside:
            continue
        if raw == closing:
            hostname = found.get("HostName")
            user = found.get("User")
            identity = found.get("IdentityFile")
            if hostname and user and identity:
                return SSHHostDetails(hostname=hostname, user=user, identity_file=Path(identity))
            return None
        stripped = raw.strip()
        for keyword in ("HostName", "User", "IdentityFile"):
            lead = keyword + " "
            if stripped.startswith(lead):
                found[keyword] = stripped.removeprefix(lead).strip()
                break

    return None
|
|
393
|
+
|
|
394
|
+
|
|
395
|
+
def query_gpu_info(host: str, user: str, key_path: Path, timeout: int = 10) -> GpuInfo | None:
    """SSH into a host and query GPU info via ``nvidia-smi``.

    The remote command prints, in order:

    * one CSV row (``driver_version, name, compute_cap``) per GPU;
    * the CUDA version from the ``nvidia-smi`` banner;
    * the ``nvcc`` release, or ``N/A``.

    Only the first GPU row is described in the result.

    Args:
        host: Address of the instance.
        user: Remote login user.
        key_path: Public or private key path (``.pub`` is stripped for ``-i``).
        timeout: SSH ``ConnectTimeout`` in seconds; the whole command is
            abandoned after ``timeout + 5`` seconds.

    Returns:
        ``GpuInfo`` on success. ``None`` if the SSH connection fails
        (including a missing ``ssh`` executable), ``nvidia-smi`` is
        unavailable, or the output is malformed.
    """
    ssh_opts = _ssh_opts(key_path)
    remote_cmd = (
        "nvidia-smi --query-gpu=driver_version,name,compute_cap --format=csv,noheader,nounits"
        " && nvidia-smi | grep -oP 'CUDA Version: \\K[\\d.]+'"
        " && (nvcc --version 2>/dev/null | grep -oP 'release \\K[\\d.]+' || echo 'N/A')"
    )
    cmd = [
        "ssh",
        *ssh_opts,
        "-o",
        f"ConnectTimeout={timeout}",
        "-o",
        "BatchMode=yes",
        f"{user}@{host}",
        remote_cmd,
    ]

    try:
        result = subprocess.run(cmd, capture_output=True, text=True, timeout=timeout + 5)
    except (subprocess.TimeoutExpired, OSError):
        # OSError: the local ``ssh`` executable could not be started.
        return None

    if result.returncode != 0:
        return None

    lines = [line.strip() for line in result.stdout.strip().splitlines()]

    # ``--query-gpu`` emits one CSV row per GPU, so on multi-GPU hosts the CUDA
    # version is not lines[1]; it is the first line after the CSV rows.
    n_rows = 0
    while n_rows < len(lines) and "," in lines[n_rows]:
        n_rows += 1
    if n_rows == 0 or n_rows >= len(lines):
        return None

    parts = [p.strip() for p in lines[0].split(",")]  # first GPU only
    if len(parts) != 3:
        return None
    driver_version, gpu_name, compute_cap = parts

    cuda_driver_version = lines[n_rows]
    toolkit_line = lines[n_rows + 1] if n_rows + 1 < len(lines) else ""
    cuda_toolkit_version = toolkit_line if toolkit_line and toolkit_line != "N/A" else None

    architecture = _GPU_ARCHITECTURES.get(compute_cap, f"Unknown ({compute_cap})")
    return GpuInfo(
        driver_version=driver_version,
        cuda_driver_version=cuda_driver_version,
        cuda_toolkit_version=cuda_toolkit_version,
        gpu_name=gpu_name,
        compute_capability=compute_cap,
        architecture=architecture,
    )
|
|
452
|
+
|
|
453
|
+
|
|
454
|
+
# ---------------------------------------------------------------------------
|
|
455
|
+
# Internal helpers
|
|
456
|
+
# ---------------------------------------------------------------------------
|
|
457
|
+
|
|
458
|
+
|
|
459
|
+
def _find_alias_in_content(content: str, instance_id: str) -> str | None:
    """Return the ``Host`` alias from *instance_id*'s managed block in *content*.

    The first ``Host`` line after the begin marker is used. The result is
    ``None`` unless that begin marker is later followed by the matching end
    marker, which guards against half-written blocks. A complete block
    without any ``Host`` line also yields ``None``.
    """
    opening = _BEGIN_MARKER.format(instance_id=instance_id)
    closing = _END_MARKER.format(instance_id=instance_id)
    inside = False
    found: str | None = None
    for raw in content.splitlines():
        if raw == opening:
            # A repeated begin marker restarts the search.
            inside, found = True, None
        elif inside and raw == closing:
            return found
        elif inside and found is None:
            stripped = raw.strip()
            if stripped.startswith("Host "):
                found = stripped.removeprefix("Host ").strip()
    return None
|
|
478
|
+
|
|
479
|
+
|
|
480
|
+
def _remove_block(content: str, instance_id: str) -> str:
    """Return *content* with *instance_id*'s managed block cut out.

    The removed span runs from the most recent begin marker up to and
    including the first end marker that follows it. If no such pair exists,
    *content* is returned unchanged, so a dangling begin marker is never
    treated as a block.

    After the cut, blank lines at the seam are dropped for as long as the
    line just before the seam is also blank.
    """
    opening = _BEGIN_MARKER.format(instance_id=instance_id)
    closing = _END_MARKER.format(instance_id=instance_id)

    chunks = content.splitlines(keepends=True)
    start: int | None = None
    stop: int | None = None
    for idx, chunk in enumerate(chunks):
        bare = chunk.rstrip("\n")
        if bare == opening:
            start = idx
        elif start is not None and bare == closing:
            stop = idx
            break

    if start is None or stop is None:
        return content

    kept = chunks[:start] + chunks[stop + 1 :]

    # Collapse a double blank line left behind at the removal site.
    if start > 0:
        while start < len(kept) and not kept[start].strip() and not kept[start - 1].strip():
            del kept[start]

    return "".join(kept)
|
|
File without changes
|