aws-bootstrap-g4dn 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
aws_bootstrap/ssh.py ADDED
@@ -0,0 +1,513 @@
1
+ """SSH key pair management and SSH config management for EC2 instances."""
2
+
3
+ from __future__ import annotations
4
+ import os
5
+ import re
6
+ import socket
7
+ import subprocess
8
+ import tempfile
9
+ import time
10
+ from dataclasses import dataclass
11
+ from pathlib import Path
12
+
13
+ import click
14
+
15
+
16
+ # ---------------------------------------------------------------------------
17
+ # SSH config markers
18
+ # ---------------------------------------------------------------------------
19
+
20
# Comment lines fencing every SSH config stanza managed by this tool;
# ``{instance_id}`` is filled in with the EC2 instance ID.
_BEGIN_MARKER = "# >>> aws-bootstrap [{instance_id}] >>>"
_END_MARKER = "# <<< aws-bootstrap [{instance_id}] <<<"

# Whole-line patterns for the markers above. The instance ID (``i-`` followed
# by lowercase hex digits) is captured as the named group ``iid``.
_BEGIN_RE = re.compile(r"^# >>> aws-bootstrap \[(?P<iid>i-[a-f0-9]+)\] >>>$")
_END_RE = re.compile(r"^# <<< aws-bootstrap \[(?P<iid>i-[a-f0-9]+)\] <<<$")

# The user's OpenSSH client config; computed once, when the module is imported.
_DEFAULT_SSH_CONFIG = Path.home().joinpath(".ssh", "config")
26
+
27
+
28
def private_key_path(key_path: Path) -> Path:
    """Return the private-key counterpart of *key_path*.

    A path whose final suffix is ``.pub`` has that suffix dropped
    (``id_ed25519.pub`` -> ``id_ed25519``); any other path is returned as-is.
    """
    if key_path.suffix != ".pub":
        return key_path
    return key_path.with_suffix("")
31
+
32
+
33
def _ssh_opts(key_path: Path) -> list[str]:
    """Return the option list shared by every ``ssh``/``scp`` invocation.

    Host keys are neither verified nor recorded (``known_hosts`` is
    ``/dev/null``), so connecting to a new instance never prompts. The
    identity is the private key derived from *key_path*.
    """
    host_key_opts = ["-o", "StrictHostKeyChecking=no", "-o", "UserKnownHostsFile=/dev/null"]
    identity_opts = ["-i", str(private_key_path(key_path))]
    return host_key_opts + identity_opts
43
+
44
+
45
def import_key_pair(ec2_client, key_name: str, key_path: Path) -> str:
    """Import a local SSH public key to AWS, reusing if it already exists.

    Args:
        ec2_client: boto3 EC2 client used for the lookup and the import.
        key_name: Name of the EC2 key pair.
        key_path: Local public key file; its bytes are uploaded when no key
            pair named *key_name* exists yet.

    Returns:
        The key pair name.

    Raises:
        ClientError: for any ``describe_key_pairs`` failure other than
            ``InvalidKeyPair.NotFound``, or any ``import_key_pair`` failure.
        OSError: if *key_path* cannot be read.

    Note:
        An existing key pair is matched by name only; its key material is not
        compared with *key_path*.
    """
    # Read first: an unreadable key file raises even if the key pair exists.
    pub_key = key_path.read_bytes()

    try:
        existing = ec2_client.describe_key_pairs(KeyNames=[key_name])
    except ec2_client.exceptions.ClientError as e:
        # Decide on the structured AWS error code rather than substring-matching
        # the formatted exception message.
        error_code = e.response.get("Error", {}).get("Code")
        if error_code != "InvalidKeyPair.NotFound":
            raise
    else:
        click.echo(" Key pair " + click.style(f"'{key_name}'", fg="bright_white") + " already exists, reusing.")
        return existing["KeyPairs"][0]["KeyName"]

    ec2_client.import_key_pair(
        KeyName=key_name,
        PublicKeyMaterial=pub_key,
        TagSpecifications=[
            {
                "ResourceType": "key-pair",
                "Tags": [{"Key": "created-by", "Value": "aws-bootstrap-g4dn"}],
            }
        ],
    )
    click.secho(f" Imported key pair '{key_name}' from {key_path}", fg="green")
    return key_name
73
+
74
+
75
def wait_for_ssh(host: str, user: str, key_path: Path, retries: int = 30, delay: int = 10) -> bool:
    """Wait for SSH to become available on the instance.

    Each attempt first probes TCP port 22 (5 s timeout). Once the port is
    open, it runs ``echo ok`` over a non-interactive (``BatchMode``) SSH
    session. That session is killed if it has not finished within 30 s, so a
    stalled handshake cannot block forever. The function sleeps *delay*
    seconds between failed attempts, but not after the last one.

    Returns:
        ``True`` as soon as the SSH command succeeds, ``False`` after
        *retries* failed attempts.
    """
    base_opts = _ssh_opts(key_path)
    ssh_cmd = ["ssh", *base_opts, "-o", "ConnectTimeout=10", "-o", "BatchMode=yes", f"{user}@{host}", "echo ok"]
    # Upper bound for one probe session: ConnectTimeout only covers the TCP
    # connect, not authentication or the remote command.
    probe_timeout = 30

    for attempt in range(1, retries + 1):
        try:
            # TimeoutError and ConnectionRefusedError are both OSError subclasses.
            with socket.create_connection((host, 22), timeout=5):
                port_open = True
        except OSError:
            port_open = False

        if port_open:
            try:
                result = subprocess.run(ssh_cmd, capture_output=True, text=True, timeout=probe_timeout)
            except subprocess.TimeoutExpired:
                result = None
            if result is not None and result.returncode == 0:
                click.secho(" SSH connection established.", fg="green")
                return True

        click.echo(" SSH not ready " + click.style(f"(attempt {attempt}/{retries})", dim=True) + ", waiting...")
        if attempt < retries:
            time.sleep(delay)

    return False
106
+
107
+
108
def _scp_upload(ssh_opts: list[str], local_path: Path, destination: str) -> bool:
    """Copy *local_path* to the scp *destination*; report and return False on failure."""
    result = subprocess.run(
        ["scp", *ssh_opts, str(local_path), destination],
        capture_output=True,
        text=True,
    )
    if result.returncode != 0:
        click.secho(f" SCP failed: {result.stderr}", fg="red", err=True)
        return False
    return True


def run_remote_setup(host: str, user: str, key_path: Path, script_path: Path) -> bool:
    """SCP the setup assets to the instance and execute the setup script.

    Uploads, in order, ``requirements.txt``, ``gpu_benchmark.py`` and
    ``gpu_smoke_test.ipynb`` (all taken from the directory containing
    *script_path*), and then *script_path* itself. Each file goes to
    ``/tmp/<name>`` on the instance; the script always lands at
    ``/tmp/remote_setup.sh``. It stops at the first failed upload.

    The script is then made executable and run. Its output is not captured,
    so it goes to this process's stdout/stderr.

    Returns:
        ``True`` if every upload succeeded and the script exited with status
        0, otherwise ``False``.
    """
    ssh_opts = _ssh_opts(key_path)
    assets_dir = script_path.parent

    # (remote file name under /tmp, local source), uploaded in this order.
    uploads = [
        ("requirements.txt", assets_dir / "requirements.txt"),
        ("gpu_benchmark.py", assets_dir / "gpu_benchmark.py"),
        ("gpu_smoke_test.ipynb", assets_dir / "gpu_smoke_test.ipynb"),
        ("remote_setup.sh", script_path),
    ]
    for remote_name, local_path in uploads:
        click.echo(f" Uploading {remote_name}...")
        if not _scp_upload(ssh_opts, local_path, f"{user}@{host}:/tmp/{remote_name}"):
            return False

    click.echo(" Running remote_setup.sh on instance...")
    ssh_result = subprocess.run(
        ["ssh", *ssh_opts, f"{user}@{host}", "chmod +x /tmp/remote_setup.sh && /tmp/remote_setup.sh"],
        capture_output=False,
    )
    return ssh_result.returncode == 0
166
+
167
+
168
+ # ---------------------------------------------------------------------------
169
+ # SSH config management
170
+ # ---------------------------------------------------------------------------
171
+
172
+
173
def _read_ssh_config(config_path: Path) -> str:
    """Read SSH config content. Returns ``""`` if file doesn't exist.

    The file is decoded as UTF-8, the encoding ``_write_ssh_config`` uses
    (``str.encode()`` default). The locale's preferred encoding is not used,
    so non-ASCII content survives a read/modify/write round-trip on systems
    whose locale is not UTF-8.
    """
    if not config_path.exists():
        return ""
    return config_path.read_text(encoding="utf-8")
178
+
179
+
180
def _write_ssh_config(config_path: Path, content: str) -> None:
    """Atomically write *content* to *config_path*.

    The text is written as UTF-8 to a temporary file in the destination's
    directory. That file is flushed and fsynced, set to mode 0600, then moved
    over the destination with ``os.replace``, so readers never see a
    half-written config. The parent directory is created (mode 0700) if
    needed.

    If *config_path* is a symlink, the file it points to is replaced rather
    than the link, so a symlinked config (e.g. one kept in a dotfiles
    repository) stays linked.

    On any failure the temporary file is removed and the exception re-raised.
    """
    target = config_path.resolve() if config_path.is_symlink() else config_path
    ssh_dir = target.parent
    ssh_dir.mkdir(mode=0o700, parents=True, exist_ok=True)

    fd, tmp = tempfile.mkstemp(dir=str(ssh_dir), prefix=".ssh_config_tmp_")
    try:
        # Once fdopen succeeds, the file object owns fd and the ``with`` closes
        # it exactly once. write() also handles short writes, unlike os.write.
        with os.fdopen(fd, "wb") as fh:
            fh.write(content.encode())
            fh.flush()
            os.fsync(fh.fileno())
        os.chmod(tmp, 0o600)
        os.replace(tmp, str(target))
    except BaseException:
        # Do not touch fd here: probing an already-closed descriptor would
        # raise and mask the original error. Only the temp file needs cleanup.
        if os.path.exists(tmp):
            os.unlink(tmp)
        raise
199
+
200
+
201
def _next_alias(content: str, prefix: str = "aws-gpu") -> str:
    """Pick the next numbered alias (``<prefix>N``) for a new managed host.

    N is one more than the largest all-digit suffix found on a ``Host <prefix>N``
    line inside an aws-bootstrap marker block (1 when there is none). ``Host``
    lines outside marker blocks are deliberately not considered, so
    hand-written hosts with coincidentally similar names do not affect the
    numbering.
    """
    numbers = [0]
    inside = False
    for raw in content.splitlines():
        if _BEGIN_RE.match(raw):
            inside = True
        elif _END_RE.match(raw):
            inside = False
        elif inside:
            text = raw.strip()
            if not text.startswith("Host "):
                continue
            name = text.removeprefix("Host ").strip()
            tail = name[len(prefix) :]
            if name.startswith(prefix) and tail.isdigit():
                numbers.append(int(tail))
    return f"{prefix}{max(numbers) + 1}"
223
+
224
+
225
def _build_stanza(instance_id: str, alias: str, hostname: str, user: str, key_path: Path) -> str:
    """Render the marker-fenced ``Host`` stanza for one managed instance.

    The stanza connects *alias* to *hostname* as *user*, with the private key
    derived from *key_path*. Host-key checking is disabled. Every line,
    including the closing marker, ends with a newline.
    """
    identity = private_key_path(key_path)
    lines = [
        f"{_BEGIN_MARKER.format(instance_id=instance_id)}\n",
        f"Host {alias}\n",
        f" HostName {hostname}\n",
        f" User {user}\n",
        f" IdentityFile {identity}\n",
        f" StrictHostKeyChecking no\n",
        f" UserKnownHostsFile /dev/null\n",
        f"{_END_MARKER.format(instance_id=instance_id)}\n",
    ]
    return "".join(lines)
238
+
239
+
240
def add_ssh_host(
    instance_id: str,
    hostname: str,
    user: str,
    key_path: Path,
    config_path: Path | None = None,
    alias_prefix: str = "aws-gpu",
) -> str:
    """Add (or update) an SSH host stanza for *instance_id*.

    If the config already has a complete managed block for the instance, that
    block is removed and its alias reused. Otherwise the next free
    ``<alias_prefix>N`` alias is allocated. The new block is appended at the
    end of the config (default ``~/.ssh/config``), with one blank line
    separating it from any preceding content.

    Returns the alias that was created (e.g. ``aws-gpu1``).
    """
    config_path = config_path or _DEFAULT_SSH_CONFIG
    content = _read_ssh_config(config_path)

    # Idempotent: if this instance already has a stanza, remember its alias
    existing_alias = _find_alias_in_content(content, instance_id)
    content = _remove_block(content, instance_id)

    alias = existing_alias or _next_alias(content, alias_prefix)
    stanza = _build_stanza(instance_id, alias, hostname, user, key_path)

    # Make non-empty content end in "\n\n" so exactly one blank line precedes
    # the new block; empty content needs no separator.
    if content and not content.endswith("\n\n"):
        content += "\n" if content.endswith("\n") else "\n\n"

    content += stanza
    _write_ssh_config(config_path, content)
    return alias
271
+
272
+
273
def remove_ssh_host(instance_id: str, config_path: Path | None = None) -> str | None:
    """Delete the managed SSH stanza belonging to *instance_id*.

    The config (default ``~/.ssh/config``) is rewritten only when it contains
    a complete managed block for the instance that carries an alias.

    Returns that block's alias, or ``None`` when nothing was removed (missing
    or empty file, or no such block).
    """
    path = config_path or _DEFAULT_SSH_CONFIG
    text = _read_ssh_config(path)
    alias = _find_alias_in_content(text, instance_id) if text else None
    if alias is not None:
        _write_ssh_config(path, _remove_block(text, instance_id))
    return alias
290
+
291
+
292
def find_ssh_alias(instance_id: str, config_path: Path | None = None) -> str | None:
    """Look up the alias of *instance_id* without modifying the config.

    Reads *config_path* (default ``~/.ssh/config``). Returns the alias from
    the instance's complete managed block, or ``None`` if there is no such
    block.
    """
    text = _read_ssh_config(config_path or _DEFAULT_SSH_CONFIG)
    return _find_alias_in_content(text, instance_id)
297
+
298
+
299
def list_ssh_hosts(config_path: Path | None = None) -> dict[str, str]:
    """Return ``{instance_id: alias}`` for all aws-bootstrap-managed hosts.

    Applies the same rules as ``find_ssh_alias``:
    - A block counts only when its begin marker is followed by the matching
      end marker for the same instance ID.
    - The alias is the first ``Host`` line inside the block.
    - Blocks with no ``Host`` line, or never closed, are omitted.
    """
    config_path = config_path or _DEFAULT_SSH_CONFIG
    content = _read_ssh_config(config_path)
    result: dict[str, str] = {}
    current_iid: str | None = None
    current_alias: str | None = None
    for line in content.splitlines():
        begin = _BEGIN_RE.match(line)
        if begin:
            current_iid = begin.group("iid")
            current_alias = None
            continue
        end = _END_RE.match(line)
        if end:
            # Record only blocks closed by their own end marker.
            if current_iid is not None and end.group("iid") == current_iid and current_alias is not None:
                result[current_iid] = current_alias
            current_iid = None
            current_alias = None
            continue
        if current_iid and current_alias is None and line.strip().startswith("Host "):
            current_alias = line.strip().removeprefix("Host ").strip()
    return result
318
+
319
+
320
+ # ---------------------------------------------------------------------------
321
+ # GPU info via SSH
322
+ # ---------------------------------------------------------------------------
323
+
324
# CUDA compute capability ("major.minor", as returned by
# ``nvidia-smi --query-gpu=compute_cap``) -> NVIDIA architecture name.
# Capabilities missing here are shown as "Unknown (<cap>)" by query_gpu_info.
_GPU_ARCHITECTURES: dict[str, str] = {
    "3.0": "Kepler",
    "3.5": "Kepler",
    "3.7": "Kepler",
    "5.0": "Maxwell",
    "5.2": "Maxwell",
    "5.3": "Maxwell",
    "6.0": "Pascal",
    "6.1": "Pascal",
    "6.2": "Pascal",
    "7.0": "Volta",
    "7.2": "Volta",
    "7.5": "Turing",
    "8.0": "Ampere",
    "8.6": "Ampere",
    "8.7": "Ampere",
    "8.9": "Ada Lovelace",
    "9.0": "Hopper",
    "10.0": "Blackwell",
    "12.0": "Blackwell",
}
333
+
334
+
335
@dataclass
class SSHHostDetails:
    """Where and how to reach a managed host.

    Holds the ``HostName``, ``User`` and ``IdentityFile`` values of an
    aws-bootstrap SSH config block.
    """

    # Remote host address from the ``HostName`` line.
    hostname: str
    # Login name from the ``User`` line.
    user: str
    # Private key path from the ``IdentityFile`` line.
    identity_file: Path
342
+
343
+
344
@dataclass
class GpuInfo:
    """Snapshot of a remote GPU as reported by ``nvidia-smi`` and ``nvcc``."""

    # NVIDIA driver version string.
    driver_version: str
    # Highest CUDA version the driver supports (from the nvidia-smi banner).
    cuda_driver_version: str
    # Installed CUDA toolkit version from nvcc; None when it could not be determined.
    cuda_toolkit_version: str | None
    # Marketing name of the GPU.
    gpu_name: str
    # CUDA compute capability as "major.minor".
    compute_capability: str
    # Architecture name derived from the compute capability.
    architecture: str
354
+
355
+
356
def get_ssh_host_details(instance_id: str, config_path: Path | None = None) -> SSHHostDetails | None:
    """Extract connection details from the managed block of *instance_id*.

    Scans the SSH config (default ``~/.ssh/config``) for the instance's begin
    marker. It collects the ``HostName``, ``User`` and ``IdentityFile``
    values that follow (a later line overrides an earlier one) until the
    matching end marker.

    Returns ``SSHHostDetails`` when all three values are present and non-empty
    in a closed block. Returns ``None`` for a missing or empty file, no block,
    a block without its end marker, or a missing field.
    """
    text = _read_ssh_config(config_path or _DEFAULT_SSH_CONFIG)
    if not text:
        return None

    opening = _BEGIN_MARKER.format(instance_id=instance_id)
    closing = _END_MARKER.format(instance_id=instance_id)
    # Line prefix -> SSHHostDetails field it populates.
    prefixes = {"HostName ": "hostname", "User ": "user", "IdentityFile ": "identity_file"}
    found: dict[str, str] = {}
    inside = False

    for line in text.splitlines():
        if line == opening:
            inside = True
            continue
        if not inside:
            continue
        if line == closing:
            if all(found.get(field) for field in ("hostname", "user", "identity_file")):
                return SSHHostDetails(
                    hostname=found["hostname"],
                    user=found["user"],
                    identity_file=Path(found["identity_file"]),
                )
            return None
        stripped = line.strip()
        for prefix, field in prefixes.items():
            if stripped.startswith(prefix):
                found[field] = stripped.removeprefix(prefix).strip()
                break

    return None
393
+
394
+
395
def query_gpu_info(host: str, user: str, key_path: Path, timeout: int = 10) -> GpuInfo | None:
    """SSH into a host and query GPU info via ``nvidia-smi``.

    The remote command prints three lines:
    1. The CSV row ``driver_version, name, compute_cap`` for GPU 0.
    2. The highest CUDA version supported by the driver, taken from the
       ``nvidia-smi`` banner.
    3. The CUDA toolkit version from ``nvcc``, or ``N/A``.

    Only GPU 0 is queried (``--id=0``). Without it, a multi-GPU instance
    prints one CSV row per GPU, pushing the CUDA version off line 2.

    Args:
        host: Address to SSH to.
        user: Remote login name.
        key_path: Public or private key path; the private key is used.
        timeout: SSH ``ConnectTimeout`` in seconds. The whole command is
            abandoned after ``timeout + 5`` seconds.

    Returns ``GpuInfo`` on success, or ``None`` if the SSH connection fails,
    ``nvidia-smi`` is unavailable, or the output is malformed.
    """
    ssh_opts = _ssh_opts(key_path)
    remote_cmd = (
        "nvidia-smi --query-gpu=driver_version,name,compute_cap --format=csv,noheader,nounits --id=0"
        " && nvidia-smi | grep -oP 'CUDA Version: \\K[\\d.]+'"
        " && (nvcc --version 2>/dev/null | grep -oP 'release \\K[\\d.]+' || echo 'N/A')"
    )
    cmd = [
        "ssh",
        *ssh_opts,
        "-o",
        f"ConnectTimeout={timeout}",
        "-o",
        "BatchMode=yes",
        f"{user}@{host}",
        remote_cmd,
    ]

    try:
        result = subprocess.run(cmd, capture_output=True, text=True, timeout=timeout + 5)
    except subprocess.TimeoutExpired:
        return None

    if result.returncode != 0:
        return None

    lines = result.stdout.strip().splitlines()
    if len(lines) < 2:
        return None

    try:
        parts = [p.strip() for p in lines[0].split(",")]
        if len(parts) != 3:
            return None
        driver_version, gpu_name, compute_cap = parts
        cuda_driver_version = lines[1].strip()
        cuda_toolkit_version: str | None = None
        if len(lines) >= 3:
            toolkit_line = lines[2].strip()
            # The remote command prints the literal "N/A" when nvcc is absent.
            if toolkit_line and toolkit_line != "N/A":
                cuda_toolkit_version = toolkit_line
        architecture = _GPU_ARCHITECTURES.get(compute_cap, f"Unknown ({compute_cap})")
        return GpuInfo(
            driver_version=driver_version,
            cuda_driver_version=cuda_driver_version,
            cuda_toolkit_version=cuda_toolkit_version,
            gpu_name=gpu_name,
            compute_capability=compute_cap,
            architecture=architecture,
        )
    except (ValueError, IndexError):
        return None
452
+
453
+
454
+ # ---------------------------------------------------------------------------
455
+ # Internal helpers
456
+ # ---------------------------------------------------------------------------
457
+
458
+
459
def _find_alias_in_content(content: str, instance_id: str) -> str | None:
    """Return the alias declared in the managed block of *instance_id*.

    The alias is the first ``Host`` line after the instance's begin marker.
    The result is returned only once the matching end marker is reached, so
    an unterminated block yields ``None``, as does a complete block with no
    ``Host`` line. A repeated begin marker restarts the search.
    """
    opening = _BEGIN_MARKER.format(instance_id=instance_id)
    closing = _END_MARKER.format(instance_id=instance_id)
    inside = False
    found: str | None = None

    for line in content.splitlines():
        if line == opening:
            inside, found = True, None
        elif inside and line == closing:
            return found
        elif inside and found is None:
            stripped = line.strip()
            if stripped.startswith("Host "):
                found = stripped.removeprefix("Host ").strip()

    # Either no begin marker was seen or it was never closed.
    return None
478
+
479
+
480
def _remove_block(content: str, instance_id: str) -> str:
    """Remove the marker block for *instance_id* from *content*.

    Marker lines are compared with their terminator removed, exactly as
    ``str.splitlines`` removes it. CRLF (and other ``splitlines``) line
    endings are therefore recognised the same way ``_find_alias_in_content``
    recognises them.

    The span runs from the last begin marker before the first matching end
    marker, through that end marker. If the begin marker is found without a
    matching end marker, content is returned unchanged (safety measure).

    After removal, blank lines at the removal site that would directly
    follow another blank line are dropped.
    """
    begin_marker = _BEGIN_MARKER.format(instance_id=instance_id)
    end_marker = _END_MARKER.format(instance_id=instance_id)

    lines = content.splitlines(keepends=True)
    # Same line boundaries as ``lines`` (splitlines with and without keepends
    # yields the same number of items), but without line terminators.
    bare_lines = content.splitlines()
    begin_idx: int | None = None
    end_idx: int | None = None

    for i, bare in enumerate(bare_lines):
        if bare == begin_marker:
            begin_idx = i
        elif bare == end_marker and begin_idx is not None:
            end_idx = i
            break

    if begin_idx is None or end_idx is None:
        return content

    # Remove block lines
    del lines[begin_idx : end_idx + 1]

    # Avoid leaving two blank lines back to back where the block used to be.
    while begin_idx < len(lines) and lines[begin_idx].strip() == "":
        if begin_idx > 0 and lines[begin_idx - 1].strip() == "":
            del lines[begin_idx]
        else:
            break

    return "".join(lines)
File without changes