aws-bootstrap-g4dn 0.2.0__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aws_bootstrap/cli.py +75 -11
- aws_bootstrap/config.py +2 -0
- aws_bootstrap/ec2.py +3 -3
- aws_bootstrap/gpu.py +27 -0
- aws_bootstrap/resources/gpu_benchmark.py +15 -5
- aws_bootstrap/resources/launch.json +42 -0
- aws_bootstrap/resources/remote_setup.sh +90 -6
- aws_bootstrap/resources/saxpy.cu +49 -0
- aws_bootstrap/resources/tasks.json +48 -0
- aws_bootstrap/ssh.py +83 -47
- aws_bootstrap/tests/test_cli.py +205 -7
- aws_bootstrap/tests/test_gpu.py +98 -0
- aws_bootstrap/tests/test_ssh_config.py +36 -0
- aws_bootstrap/tests/test_ssh_gpu.py +1 -95
- {aws_bootstrap_g4dn-0.2.0.dist-info → aws_bootstrap_g4dn-0.4.0.dist-info}/METADATA +41 -6
- aws_bootstrap_g4dn-0.4.0.dist-info/RECORD +27 -0
- aws_bootstrap_g4dn-0.2.0.dist-info/RECORD +0 -22
- {aws_bootstrap_g4dn-0.2.0.dist-info → aws_bootstrap_g4dn-0.4.0.dist-info}/WHEEL +0 -0
- {aws_bootstrap_g4dn-0.2.0.dist-info → aws_bootstrap_g4dn-0.4.0.dist-info}/entry_points.txt +0 -0
- {aws_bootstrap_g4dn-0.2.0.dist-info → aws_bootstrap_g4dn-0.4.0.dist-info}/licenses/LICENSE +0 -0
- {aws_bootstrap_g4dn-0.2.0.dist-info → aws_bootstrap_g4dn-0.4.0.dist-info}/top_level.txt +0 -0
aws_bootstrap/ssh.py
CHANGED
|
@@ -12,6 +12,8 @@ from pathlib import Path
|
|
|
12
12
|
|
|
13
13
|
import click
|
|
14
14
|
|
|
15
|
+
from .gpu import _GPU_ARCHITECTURES, GpuInfo
|
|
16
|
+
|
|
15
17
|
|
|
16
18
|
# ---------------------------------------------------------------------------
|
|
17
19
|
# SSH config markers
|
|
@@ -72,17 +74,18 @@ def import_key_pair(ec2_client, key_name: str, key_path: Path) -> str:
|
|
|
72
74
|
return key_name
|
|
73
75
|
|
|
74
76
|
|
|
75
|
-
def wait_for_ssh(host: str, user: str, key_path: Path, retries: int = 30, delay: int = 10) -> bool:
|
|
77
|
+
def wait_for_ssh(host: str, user: str, key_path: Path, retries: int = 30, delay: int = 10, port: int = 22) -> bool:
|
|
76
78
|
"""Wait for SSH to become available on the instance.
|
|
77
79
|
|
|
78
|
-
Tries a TCP connection to port 22 first, then an actual SSH command.
|
|
80
|
+
Tries a TCP connection to the SSH port first, then an actual SSH command.
|
|
79
81
|
"""
|
|
80
82
|
base_opts = _ssh_opts(key_path)
|
|
83
|
+
port_opts = ["-p", str(port)] if port != 22 else []
|
|
81
84
|
|
|
82
85
|
for attempt in range(1, retries + 1):
|
|
83
|
-
# First check if port 22 is open
|
|
86
|
+
# First check if the SSH port is open
|
|
84
87
|
try:
|
|
85
|
-
sock = socket.create_connection((host,
|
|
88
|
+
sock = socket.create_connection((host, port), timeout=5)
|
|
86
89
|
sock.close()
|
|
87
90
|
except (TimeoutError, ConnectionRefusedError, OSError):
|
|
88
91
|
click.echo(" SSH not ready " + click.style(f"(attempt {attempt}/{retries})", dim=True) + ", waiting...")
|
|
@@ -90,11 +93,18 @@ def wait_for_ssh(host: str, user: str, key_path: Path, retries: int = 30, delay:
|
|
|
90
93
|
continue
|
|
91
94
|
|
|
92
95
|
# Port is open, try actual SSH
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
96
|
+
cmd = [
|
|
97
|
+
"ssh",
|
|
98
|
+
*base_opts,
|
|
99
|
+
*port_opts,
|
|
100
|
+
"-o",
|
|
101
|
+
"ConnectTimeout=10",
|
|
102
|
+
"-o",
|
|
103
|
+
"BatchMode=yes",
|
|
104
|
+
f"{user}@{host}",
|
|
105
|
+
"echo ok",
|
|
106
|
+
]
|
|
107
|
+
result = subprocess.run(cmd, capture_output=True, text=True)
|
|
98
108
|
if result.returncode == 0:
|
|
99
109
|
click.secho(" SSH connection established.", fg="green")
|
|
100
110
|
return True
|
|
@@ -105,15 +115,19 @@ def wait_for_ssh(host: str, user: str, key_path: Path, retries: int = 30, delay:
|
|
|
105
115
|
return False
|
|
106
116
|
|
|
107
117
|
|
|
108
|
-
def run_remote_setup(
|
|
118
|
+
def run_remote_setup(
|
|
119
|
+
host: str, user: str, key_path: Path, script_path: Path, python_version: str | None = None, port: int = 22
|
|
120
|
+
) -> bool:
|
|
109
121
|
"""SCP the setup script and requirements.txt to the instance and execute."""
|
|
110
122
|
ssh_opts = _ssh_opts(key_path)
|
|
123
|
+
scp_port_opts = ["-P", str(port)] if port != 22 else []
|
|
124
|
+
ssh_port_opts = ["-p", str(port)] if port != 22 else []
|
|
111
125
|
requirements_path = script_path.parent / "requirements.txt"
|
|
112
126
|
|
|
113
127
|
# SCP the requirements file
|
|
114
128
|
click.echo(" Uploading requirements.txt...")
|
|
115
129
|
req_result = subprocess.run(
|
|
116
|
-
["scp", *ssh_opts, str(requirements_path), f"{user}@{host}:/tmp/requirements.txt"],
|
|
130
|
+
["scp", *ssh_opts, *scp_port_opts, str(requirements_path), f"{user}@{host}:/tmp/requirements.txt"],
|
|
117
131
|
capture_output=True,
|
|
118
132
|
text=True,
|
|
119
133
|
)
|
|
@@ -125,7 +139,7 @@ def run_remote_setup(host: str, user: str, key_path: Path, script_path: Path) ->
|
|
|
125
139
|
benchmark_path = script_path.parent / "gpu_benchmark.py"
|
|
126
140
|
click.echo(" Uploading gpu_benchmark.py...")
|
|
127
141
|
bench_result = subprocess.run(
|
|
128
|
-
["scp", *ssh_opts, str(benchmark_path), f"{user}@{host}:/tmp/gpu_benchmark.py"],
|
|
142
|
+
["scp", *ssh_opts, *scp_port_opts, str(benchmark_path), f"{user}@{host}:/tmp/gpu_benchmark.py"],
|
|
129
143
|
capture_output=True,
|
|
130
144
|
text=True,
|
|
131
145
|
)
|
|
@@ -137,7 +151,7 @@ def run_remote_setup(host: str, user: str, key_path: Path, script_path: Path) ->
|
|
|
137
151
|
notebook_path = script_path.parent / "gpu_smoke_test.ipynb"
|
|
138
152
|
click.echo(" Uploading gpu_smoke_test.ipynb...")
|
|
139
153
|
nb_result = subprocess.run(
|
|
140
|
-
["scp", *ssh_opts, str(notebook_path), f"{user}@{host}:/tmp/gpu_smoke_test.ipynb"],
|
|
154
|
+
["scp", *ssh_opts, *scp_port_opts, str(notebook_path), f"{user}@{host}:/tmp/gpu_smoke_test.ipynb"],
|
|
141
155
|
capture_output=True,
|
|
142
156
|
text=True,
|
|
143
157
|
)
|
|
@@ -145,10 +159,46 @@ def run_remote_setup(host: str, user: str, key_path: Path, script_path: Path) ->
|
|
|
145
159
|
click.secho(f" SCP failed: {nb_result.stderr}", fg="red", err=True)
|
|
146
160
|
return False
|
|
147
161
|
|
|
162
|
+
# SCP the CUDA example source
|
|
163
|
+
saxpy_path = script_path.parent / "saxpy.cu"
|
|
164
|
+
click.echo(" Uploading saxpy.cu...")
|
|
165
|
+
saxpy_result = subprocess.run(
|
|
166
|
+
["scp", *ssh_opts, *scp_port_opts, str(saxpy_path), f"{user}@{host}:/tmp/saxpy.cu"],
|
|
167
|
+
capture_output=True,
|
|
168
|
+
text=True,
|
|
169
|
+
)
|
|
170
|
+
if saxpy_result.returncode != 0:
|
|
171
|
+
click.secho(f" SCP failed: {saxpy_result.stderr}", fg="red", err=True)
|
|
172
|
+
return False
|
|
173
|
+
|
|
174
|
+
# SCP the VSCode launch.json
|
|
175
|
+
launch_json_path = script_path.parent / "launch.json"
|
|
176
|
+
click.echo(" Uploading launch.json...")
|
|
177
|
+
launch_result = subprocess.run(
|
|
178
|
+
["scp", *ssh_opts, *scp_port_opts, str(launch_json_path), f"{user}@{host}:/tmp/launch.json"],
|
|
179
|
+
capture_output=True,
|
|
180
|
+
text=True,
|
|
181
|
+
)
|
|
182
|
+
if launch_result.returncode != 0:
|
|
183
|
+
click.secho(f" SCP failed: {launch_result.stderr}", fg="red", err=True)
|
|
184
|
+
return False
|
|
185
|
+
|
|
186
|
+
# SCP the VSCode tasks.json
|
|
187
|
+
tasks_json_path = script_path.parent / "tasks.json"
|
|
188
|
+
click.echo(" Uploading tasks.json...")
|
|
189
|
+
tasks_result = subprocess.run(
|
|
190
|
+
["scp", *ssh_opts, *scp_port_opts, str(tasks_json_path), f"{user}@{host}:/tmp/tasks.json"],
|
|
191
|
+
capture_output=True,
|
|
192
|
+
text=True,
|
|
193
|
+
)
|
|
194
|
+
if tasks_result.returncode != 0:
|
|
195
|
+
click.secho(f" SCP failed: {tasks_result.stderr}", fg="red", err=True)
|
|
196
|
+
return False
|
|
197
|
+
|
|
148
198
|
# SCP the script
|
|
149
199
|
click.echo(" Uploading remote_setup.sh...")
|
|
150
200
|
scp_result = subprocess.run(
|
|
151
|
-
["scp", *ssh_opts, str(script_path), f"{user}@{host}:/tmp/remote_setup.sh"],
|
|
201
|
+
["scp", *ssh_opts, *scp_port_opts, str(script_path), f"{user}@{host}:/tmp/remote_setup.sh"],
|
|
152
202
|
capture_output=True,
|
|
153
203
|
text=True,
|
|
154
204
|
)
|
|
@@ -156,10 +206,14 @@ def run_remote_setup(host: str, user: str, key_path: Path, script_path: Path) ->
|
|
|
156
206
|
click.secho(f" SCP failed: {scp_result.stderr}", fg="red", err=True)
|
|
157
207
|
return False
|
|
158
208
|
|
|
159
|
-
# Execute the script
|
|
209
|
+
# Execute the script, passing PYTHON_VERSION as an inline env var if specified
|
|
160
210
|
click.echo(" Running remote_setup.sh on instance...")
|
|
211
|
+
remote_cmd = "chmod +x /tmp/remote_setup.sh && "
|
|
212
|
+
if python_version:
|
|
213
|
+
remote_cmd += f"PYTHON_VERSION={python_version} "
|
|
214
|
+
remote_cmd += "/tmp/remote_setup.sh"
|
|
161
215
|
ssh_result = subprocess.run(
|
|
162
|
-
["ssh", *ssh_opts, f"{user}@{host}",
|
|
216
|
+
["ssh", *ssh_opts, *ssh_port_opts, f"{user}@{host}", remote_cmd],
|
|
163
217
|
capture_output=False,
|
|
164
218
|
)
|
|
165
219
|
return ssh_result.returncode == 0
|
|
@@ -222,15 +276,17 @@ def _next_alias(content: str, prefix: str = "aws-gpu") -> str:
|
|
|
222
276
|
return f"{prefix}{max_n + 1}"
|
|
223
277
|
|
|
224
278
|
|
|
225
|
-
def _build_stanza(instance_id: str, alias: str, hostname: str, user: str, key_path: Path) -> str:
|
|
279
|
+
def _build_stanza(instance_id: str, alias: str, hostname: str, user: str, key_path: Path, port: int = 22) -> str:
|
|
226
280
|
"""Build a complete SSH config stanza with markers."""
|
|
227
281
|
priv_key = private_key_path(key_path)
|
|
282
|
+
port_line = f" Port {port}\n" if port != 22 else ""
|
|
228
283
|
return (
|
|
229
284
|
f"{_BEGIN_MARKER.format(instance_id=instance_id)}\n"
|
|
230
285
|
f"Host {alias}\n"
|
|
231
286
|
f" HostName {hostname}\n"
|
|
232
287
|
f" User {user}\n"
|
|
233
288
|
f" IdentityFile {priv_key}\n"
|
|
289
|
+
f"{port_line}"
|
|
234
290
|
f" StrictHostKeyChecking no\n"
|
|
235
291
|
f" UserKnownHostsFile /dev/null\n"
|
|
236
292
|
f"{_END_MARKER.format(instance_id=instance_id)}\n"
|
|
@@ -244,6 +300,7 @@ def add_ssh_host(
|
|
|
244
300
|
key_path: Path,
|
|
245
301
|
config_path: Path | None = None,
|
|
246
302
|
alias_prefix: str = "aws-gpu",
|
|
303
|
+
port: int = 22,
|
|
247
304
|
) -> str:
|
|
248
305
|
"""Add (or update) an SSH host stanza for *instance_id*.
|
|
249
306
|
|
|
@@ -257,7 +314,7 @@ def add_ssh_host(
|
|
|
257
314
|
content = _remove_block(content, instance_id)
|
|
258
315
|
|
|
259
316
|
alias = existing_alias or _next_alias(content, alias_prefix)
|
|
260
|
-
stanza = _build_stanza(instance_id, alias, hostname, user, key_path)
|
|
317
|
+
stanza = _build_stanza(instance_id, alias, hostname, user, key_path, port=port)
|
|
261
318
|
|
|
262
319
|
# Ensure a blank line before our block if file has content
|
|
263
320
|
if content and not content.endswith("\n\n") and not content.endswith("\n"):
|
|
@@ -317,21 +374,6 @@ def list_ssh_hosts(config_path: Path | None = None) -> dict[str, str]:
|
|
|
317
374
|
return result
|
|
318
375
|
|
|
319
376
|
|
|
320
|
-
# ---------------------------------------------------------------------------
|
|
321
|
-
# GPU info via SSH
|
|
322
|
-
# ---------------------------------------------------------------------------
|
|
323
|
-
|
|
324
|
-
_GPU_ARCHITECTURES: dict[str, str] = {
|
|
325
|
-
"7.0": "Volta",
|
|
326
|
-
"7.5": "Turing",
|
|
327
|
-
"8.0": "Ampere",
|
|
328
|
-
"8.6": "Ampere",
|
|
329
|
-
"8.7": "Ampere",
|
|
330
|
-
"8.9": "Ada Lovelace",
|
|
331
|
-
"9.0": "Hopper",
|
|
332
|
-
}
|
|
333
|
-
|
|
334
|
-
|
|
335
377
|
@dataclass
|
|
336
378
|
class SSHHostDetails:
|
|
337
379
|
"""Connection details parsed from an SSH config stanza."""
|
|
@@ -339,18 +381,7 @@ class SSHHostDetails:
|
|
|
339
381
|
hostname: str
|
|
340
382
|
user: str
|
|
341
383
|
identity_file: Path
|
|
342
|
-
|
|
343
|
-
|
|
344
|
-
@dataclass
|
|
345
|
-
class GpuInfo:
|
|
346
|
-
"""GPU information retrieved via nvidia-smi and nvcc."""
|
|
347
|
-
|
|
348
|
-
driver_version: str
|
|
349
|
-
cuda_driver_version: str # max CUDA version supported by driver (from nvidia-smi)
|
|
350
|
-
cuda_toolkit_version: str | None # actual CUDA toolkit installed (from nvcc), None if unavailable
|
|
351
|
-
gpu_name: str
|
|
352
|
-
compute_capability: str
|
|
353
|
-
architecture: str
|
|
384
|
+
port: int = 22
|
|
354
385
|
|
|
355
386
|
|
|
356
387
|
def get_ssh_host_details(instance_id: str, config_path: Path | None = None) -> SSHHostDetails | None:
|
|
@@ -371,6 +402,7 @@ def get_ssh_host_details(instance_id: str, config_path: Path | None = None) -> S
|
|
|
371
402
|
hostname: str | None = None
|
|
372
403
|
user: str | None = None
|
|
373
404
|
identity_file: str | None = None
|
|
405
|
+
port: int = 22
|
|
374
406
|
|
|
375
407
|
for line in content.splitlines():
|
|
376
408
|
if line == begin_marker:
|
|
@@ -378,7 +410,7 @@ def get_ssh_host_details(instance_id: str, config_path: Path | None = None) -> S
|
|
|
378
410
|
continue
|
|
379
411
|
if line == end_marker and in_block:
|
|
380
412
|
if hostname and user and identity_file:
|
|
381
|
-
return SSHHostDetails(hostname=hostname, user=user, identity_file=Path(identity_file))
|
|
413
|
+
return SSHHostDetails(hostname=hostname, user=user, identity_file=Path(identity_file), port=port)
|
|
382
414
|
return None
|
|
383
415
|
if in_block:
|
|
384
416
|
stripped = line.strip()
|
|
@@ -388,17 +420,20 @@ def get_ssh_host_details(instance_id: str, config_path: Path | None = None) -> S
|
|
|
388
420
|
user = stripped.removeprefix("User ").strip()
|
|
389
421
|
elif stripped.startswith("IdentityFile "):
|
|
390
422
|
identity_file = stripped.removeprefix("IdentityFile ").strip()
|
|
423
|
+
elif stripped.startswith("Port "):
|
|
424
|
+
port = int(stripped.removeprefix("Port ").strip())
|
|
391
425
|
|
|
392
426
|
return None
|
|
393
427
|
|
|
394
428
|
|
|
395
|
-
def query_gpu_info(host: str, user: str, key_path: Path, timeout: int = 10) -> GpuInfo | None:
|
|
429
|
+
def query_gpu_info(host: str, user: str, key_path: Path, timeout: int = 10, port: int = 22) -> GpuInfo | None:
|
|
396
430
|
"""SSH into a host and query GPU info via ``nvidia-smi``.
|
|
397
431
|
|
|
398
432
|
Returns ``GpuInfo`` on success, or ``None`` if the SSH connection fails,
|
|
399
433
|
``nvidia-smi`` is unavailable, or the output is malformed.
|
|
400
434
|
"""
|
|
401
435
|
ssh_opts = _ssh_opts(key_path)
|
|
436
|
+
port_opts = ["-p", str(port)] if port != 22 else []
|
|
402
437
|
remote_cmd = (
|
|
403
438
|
"nvidia-smi --query-gpu=driver_version,name,compute_cap --format=csv,noheader,nounits"
|
|
404
439
|
" && nvidia-smi | grep -oP 'CUDA Version: \\K[\\d.]+'"
|
|
@@ -407,6 +442,7 @@ def query_gpu_info(host: str, user: str, key_path: Path, timeout: int = 10) -> G
|
|
|
407
442
|
cmd = [
|
|
408
443
|
"ssh",
|
|
409
444
|
*ssh_opts,
|
|
445
|
+
*port_opts,
|
|
410
446
|
"-o",
|
|
411
447
|
f"ConnectTimeout={timeout}",
|
|
412
448
|
"-o",
|
aws_bootstrap/tests/test_cli.py
CHANGED
|
@@ -9,7 +9,8 @@ import botocore.exceptions
|
|
|
9
9
|
from click.testing import CliRunner
|
|
10
10
|
|
|
11
11
|
from aws_bootstrap.cli import main
|
|
12
|
-
from aws_bootstrap.
|
|
12
|
+
from aws_bootstrap.gpu import GpuInfo
|
|
13
|
+
from aws_bootstrap.ssh import SSHHostDetails
|
|
13
14
|
|
|
14
15
|
|
|
15
16
|
def test_help():
|
|
@@ -73,11 +74,12 @@ def test_status_no_instances(mock_find, mock_session):
|
|
|
73
74
|
assert "No active" in result.output
|
|
74
75
|
|
|
75
76
|
|
|
77
|
+
@patch("aws_bootstrap.cli.get_ssh_host_details", return_value=None)
|
|
76
78
|
@patch("aws_bootstrap.cli.list_ssh_hosts", return_value={})
|
|
77
79
|
@patch("aws_bootstrap.cli.boto3.Session")
|
|
78
80
|
@patch("aws_bootstrap.cli.get_spot_price")
|
|
79
81
|
@patch("aws_bootstrap.cli.find_tagged_instances")
|
|
80
|
-
def test_status_shows_instances(mock_find, mock_spot_price, mock_session, mock_ssh_hosts):
|
|
82
|
+
def test_status_shows_instances(mock_find, mock_spot_price, mock_session, mock_ssh_hosts, mock_details):
|
|
81
83
|
mock_find.return_value = [
|
|
82
84
|
{
|
|
83
85
|
"InstanceId": "i-abc123",
|
|
@@ -101,11 +103,12 @@ def test_status_shows_instances(mock_find, mock_spot_price, mock_session, mock_s
|
|
|
101
103
|
assert "Est. cost" in result.output
|
|
102
104
|
|
|
103
105
|
|
|
106
|
+
@patch("aws_bootstrap.cli.get_ssh_host_details", return_value=None)
|
|
104
107
|
@patch("aws_bootstrap.cli.list_ssh_hosts", return_value={})
|
|
105
108
|
@patch("aws_bootstrap.cli.boto3.Session")
|
|
106
109
|
@patch("aws_bootstrap.cli.get_spot_price")
|
|
107
110
|
@patch("aws_bootstrap.cli.find_tagged_instances")
|
|
108
|
-
def test_status_on_demand_no_cost(mock_find, mock_spot_price, mock_session, mock_ssh_hosts):
|
|
111
|
+
def test_status_on_demand_no_cost(mock_find, mock_spot_price, mock_session, mock_ssh_hosts, mock_details):
|
|
109
112
|
mock_find.return_value = [
|
|
110
113
|
{
|
|
111
114
|
"InstanceId": "i-ondemand",
|
|
@@ -351,11 +354,12 @@ def test_terminate_removes_ssh_config(mock_terminate, mock_find, mock_session, m
|
|
|
351
354
|
mock_remove_ssh.assert_called_once_with("i-abc123")
|
|
352
355
|
|
|
353
356
|
|
|
357
|
+
@patch("aws_bootstrap.cli.get_ssh_host_details", return_value=None)
|
|
354
358
|
@patch("aws_bootstrap.cli.list_ssh_hosts")
|
|
355
359
|
@patch("aws_bootstrap.cli.boto3.Session")
|
|
356
360
|
@patch("aws_bootstrap.cli.get_spot_price")
|
|
357
361
|
@patch("aws_bootstrap.cli.find_tagged_instances")
|
|
358
|
-
def test_status_shows_alias(mock_find, mock_spot_price, mock_session, mock_ssh_hosts):
|
|
362
|
+
def test_status_shows_alias(mock_find, mock_spot_price, mock_session, mock_ssh_hosts, mock_details):
|
|
359
363
|
mock_find.return_value = [
|
|
360
364
|
{
|
|
361
365
|
"InstanceId": "i-abc123",
|
|
@@ -376,11 +380,12 @@ def test_status_shows_alias(mock_find, mock_spot_price, mock_session, mock_ssh_h
|
|
|
376
380
|
assert "aws-gpu1" in result.output
|
|
377
381
|
|
|
378
382
|
|
|
383
|
+
@patch("aws_bootstrap.cli.get_ssh_host_details", return_value=None)
|
|
379
384
|
@patch("aws_bootstrap.cli.list_ssh_hosts", return_value={})
|
|
380
385
|
@patch("aws_bootstrap.cli.boto3.Session")
|
|
381
386
|
@patch("aws_bootstrap.cli.get_spot_price")
|
|
382
387
|
@patch("aws_bootstrap.cli.find_tagged_instances")
|
|
383
|
-
def test_status_no_alias_graceful(mock_find, mock_spot_price, mock_session, mock_ssh_hosts):
|
|
388
|
+
def test_status_no_alias_graceful(mock_find, mock_spot_price, mock_session, mock_ssh_hosts, mock_details):
|
|
384
389
|
mock_find.return_value = [
|
|
385
390
|
{
|
|
386
391
|
"InstanceId": "i-old999",
|
|
@@ -520,13 +525,98 @@ def test_status_gpu_skips_non_running(mock_find, mock_session, mock_ssh_hosts, m
|
|
|
520
525
|
@patch("aws_bootstrap.cli.boto3.Session")
|
|
521
526
|
@patch("aws_bootstrap.cli.get_spot_price", return_value=0.15)
|
|
522
527
|
@patch("aws_bootstrap.cli.find_tagged_instances")
|
|
523
|
-
def
|
|
528
|
+
def test_status_without_gpu_flag_no_gpu_query(
|
|
529
|
+
mock_find, mock_spot, mock_session, mock_ssh_hosts, mock_details, mock_gpu
|
|
530
|
+
):
|
|
524
531
|
mock_find.return_value = [_RUNNING_INSTANCE]
|
|
525
532
|
runner = CliRunner()
|
|
526
533
|
result = runner.invoke(main, ["status"])
|
|
527
534
|
assert result.exit_code == 0
|
|
528
535
|
mock_gpu.assert_not_called()
|
|
529
|
-
|
|
536
|
+
|
|
537
|
+
|
|
538
|
+
# ---------------------------------------------------------------------------
|
|
539
|
+
# --instructions / --no-instructions / -I flag tests
|
|
540
|
+
# ---------------------------------------------------------------------------
|
|
541
|
+
|
|
542
|
+
|
|
543
|
+
def test_status_help_shows_instructions_flag():
|
|
544
|
+
runner = CliRunner()
|
|
545
|
+
result = runner.invoke(main, ["status", "--help"])
|
|
546
|
+
assert result.exit_code == 0
|
|
547
|
+
assert "--instructions" in result.output
|
|
548
|
+
assert "--no-instructions" in result.output
|
|
549
|
+
assert "-I" in result.output
|
|
550
|
+
|
|
551
|
+
|
|
552
|
+
@patch("aws_bootstrap.cli.get_ssh_host_details")
|
|
553
|
+
@patch("aws_bootstrap.cli.list_ssh_hosts", return_value={"i-abc123": "aws-gpu1"})
|
|
554
|
+
@patch("aws_bootstrap.cli.boto3.Session")
|
|
555
|
+
@patch("aws_bootstrap.cli.get_spot_price", return_value=0.15)
|
|
556
|
+
@patch("aws_bootstrap.cli.find_tagged_instances")
|
|
557
|
+
def test_status_instructions_shown_by_default(mock_find, mock_spot, mock_session, mock_ssh_hosts, mock_details):
|
|
558
|
+
"""Instructions are shown by default (no flag needed)."""
|
|
559
|
+
mock_find.return_value = [_RUNNING_INSTANCE]
|
|
560
|
+
mock_details.return_value = SSHHostDetails(
|
|
561
|
+
hostname="1.2.3.4", user="ubuntu", identity_file=Path("/home/user/.ssh/id_ed25519")
|
|
562
|
+
)
|
|
563
|
+
runner = CliRunner()
|
|
564
|
+
result = runner.invoke(main, ["status"])
|
|
565
|
+
assert result.exit_code == 0
|
|
566
|
+
assert "ssh aws-gpu1" in result.output
|
|
567
|
+
assert "ssh -NL 8888:localhost:8888 aws-gpu1" in result.output
|
|
568
|
+
assert "vscode-remote://ssh-remote+aws-gpu1/home/ubuntu/workspace" in result.output
|
|
569
|
+
assert "python ~/gpu_benchmark.py" in result.output
|
|
570
|
+
|
|
571
|
+
|
|
572
|
+
@patch("aws_bootstrap.cli.get_ssh_host_details")
|
|
573
|
+
@patch("aws_bootstrap.cli.list_ssh_hosts", return_value={"i-abc123": "aws-gpu1"})
|
|
574
|
+
@patch("aws_bootstrap.cli.boto3.Session")
|
|
575
|
+
@patch("aws_bootstrap.cli.get_spot_price", return_value=0.15)
|
|
576
|
+
@patch("aws_bootstrap.cli.find_tagged_instances")
|
|
577
|
+
def test_status_no_instructions_suppresses_commands(mock_find, mock_spot, mock_session, mock_ssh_hosts, mock_details):
|
|
578
|
+
"""--no-instructions suppresses connection commands."""
|
|
579
|
+
mock_find.return_value = [_RUNNING_INSTANCE]
|
|
580
|
+
mock_details.return_value = SSHHostDetails(
|
|
581
|
+
hostname="1.2.3.4", user="ubuntu", identity_file=Path("/home/user/.ssh/id_ed25519")
|
|
582
|
+
)
|
|
583
|
+
runner = CliRunner()
|
|
584
|
+
result = runner.invoke(main, ["status", "--no-instructions"])
|
|
585
|
+
assert result.exit_code == 0
|
|
586
|
+
assert "vscode-remote" not in result.output
|
|
587
|
+
assert "Jupyter" not in result.output
|
|
588
|
+
|
|
589
|
+
|
|
590
|
+
@patch("aws_bootstrap.cli.get_ssh_host_details")
|
|
591
|
+
@patch("aws_bootstrap.cli.list_ssh_hosts", return_value={})
|
|
592
|
+
@patch("aws_bootstrap.cli.boto3.Session")
|
|
593
|
+
@patch("aws_bootstrap.cli.get_spot_price", return_value=0.15)
|
|
594
|
+
@patch("aws_bootstrap.cli.find_tagged_instances")
|
|
595
|
+
def test_status_instructions_no_alias_skips(mock_find, mock_spot, mock_session, mock_ssh_hosts, mock_details):
|
|
596
|
+
"""Instances without an SSH alias don't get connection instructions."""
|
|
597
|
+
mock_find.return_value = [_RUNNING_INSTANCE]
|
|
598
|
+
runner = CliRunner()
|
|
599
|
+
result = runner.invoke(main, ["status"])
|
|
600
|
+
assert result.exit_code == 0
|
|
601
|
+
assert "ssh aws-gpu" not in result.output
|
|
602
|
+
assert "vscode-remote" not in result.output
|
|
603
|
+
|
|
604
|
+
|
|
605
|
+
@patch("aws_bootstrap.cli.get_ssh_host_details")
|
|
606
|
+
@patch("aws_bootstrap.cli.list_ssh_hosts", return_value={"i-abc123": "aws-gpu1"})
|
|
607
|
+
@patch("aws_bootstrap.cli.boto3.Session")
|
|
608
|
+
@patch("aws_bootstrap.cli.get_spot_price", return_value=0.15)
|
|
609
|
+
@patch("aws_bootstrap.cli.find_tagged_instances")
|
|
610
|
+
def test_status_instructions_non_default_port(mock_find, mock_spot, mock_session, mock_ssh_hosts, mock_details):
|
|
611
|
+
mock_find.return_value = [_RUNNING_INSTANCE]
|
|
612
|
+
mock_details.return_value = SSHHostDetails(
|
|
613
|
+
hostname="1.2.3.4", user="ubuntu", identity_file=Path("/home/user/.ssh/id_ed25519"), port=2222
|
|
614
|
+
)
|
|
615
|
+
runner = CliRunner()
|
|
616
|
+
result = runner.invoke(main, ["status"])
|
|
617
|
+
assert result.exit_code == 0
|
|
618
|
+
assert "ssh -p 2222 aws-gpu1" in result.output
|
|
619
|
+
assert "ssh -NL 8888:localhost:8888 -p 2222 aws-gpu1" in result.output
|
|
530
620
|
|
|
531
621
|
|
|
532
622
|
# ---------------------------------------------------------------------------
|
|
@@ -636,3 +726,111 @@ def test_no_credentials_caught_on_list(mock_session, mock_list):
|
|
|
636
726
|
result = runner.invoke(main, ["list", "instance-types"])
|
|
637
727
|
assert result.exit_code != 0
|
|
638
728
|
assert "Unable to locate AWS credentials" in result.output
|
|
729
|
+
|
|
730
|
+
|
|
731
|
+
# ---------------------------------------------------------------------------
|
|
732
|
+
# --python-version tests
|
|
733
|
+
# ---------------------------------------------------------------------------
|
|
734
|
+
|
|
735
|
+
|
|
736
|
+
@patch("aws_bootstrap.cli.add_ssh_host", return_value="aws-gpu1")
|
|
737
|
+
@patch("aws_bootstrap.cli.run_remote_setup", return_value=True)
|
|
738
|
+
@patch("aws_bootstrap.cli.wait_for_ssh", return_value=True)
|
|
739
|
+
@patch("aws_bootstrap.cli.wait_instance_ready")
|
|
740
|
+
@patch("aws_bootstrap.cli.launch_instance")
|
|
741
|
+
@patch("aws_bootstrap.cli.ensure_security_group", return_value="sg-123")
|
|
742
|
+
@patch("aws_bootstrap.cli.import_key_pair", return_value="aws-bootstrap-key")
|
|
743
|
+
@patch("aws_bootstrap.cli.get_latest_ami")
|
|
744
|
+
@patch("aws_bootstrap.cli.boto3.Session")
|
|
745
|
+
def test_launch_python_version_passed_to_setup(
|
|
746
|
+
mock_session, mock_ami, mock_import, mock_sg, mock_launch, mock_wait, mock_ssh, mock_setup, mock_add_ssh, tmp_path
|
|
747
|
+
):
|
|
748
|
+
mock_ami.return_value = {"ImageId": "ami-123", "Name": "TestAMI"}
|
|
749
|
+
mock_launch.return_value = {"InstanceId": "i-test123"}
|
|
750
|
+
mock_wait.return_value = {"PublicIpAddress": "1.2.3.4"}
|
|
751
|
+
|
|
752
|
+
key_path = tmp_path / "id_ed25519.pub"
|
|
753
|
+
key_path.write_text("ssh-ed25519 AAAA test@host")
|
|
754
|
+
|
|
755
|
+
runner = CliRunner()
|
|
756
|
+
result = runner.invoke(main, ["launch", "--key-path", str(key_path), "--python-version", "3.13"])
|
|
757
|
+
assert result.exit_code == 0
|
|
758
|
+
mock_setup.assert_called_once()
|
|
759
|
+
assert mock_setup.call_args[0][4] == "3.13"
|
|
760
|
+
|
|
761
|
+
|
|
762
|
+
@patch("aws_bootstrap.cli.boto3.Session")
|
|
763
|
+
@patch("aws_bootstrap.cli.get_latest_ami")
|
|
764
|
+
@patch("aws_bootstrap.cli.import_key_pair", return_value="aws-bootstrap-key")
|
|
765
|
+
@patch("aws_bootstrap.cli.ensure_security_group", return_value="sg-123")
|
|
766
|
+
def test_launch_dry_run_shows_python_version(mock_sg, mock_import, mock_ami, mock_session, tmp_path):
|
|
767
|
+
mock_ami.return_value = {"ImageId": "ami-123", "Name": "TestAMI"}
|
|
768
|
+
|
|
769
|
+
key_path = tmp_path / "id_ed25519.pub"
|
|
770
|
+
key_path.write_text("ssh-ed25519 AAAA test@host")
|
|
771
|
+
|
|
772
|
+
runner = CliRunner()
|
|
773
|
+
result = runner.invoke(main, ["launch", "--key-path", str(key_path), "--dry-run", "--python-version", "3.14.2"])
|
|
774
|
+
assert result.exit_code == 0
|
|
775
|
+
assert "3.14.2" in result.output
|
|
776
|
+
assert "Python version" in result.output
|
|
777
|
+
|
|
778
|
+
|
|
779
|
+
@patch("aws_bootstrap.cli.boto3.Session")
|
|
780
|
+
@patch("aws_bootstrap.cli.get_latest_ami")
|
|
781
|
+
@patch("aws_bootstrap.cli.import_key_pair", return_value="aws-bootstrap-key")
|
|
782
|
+
@patch("aws_bootstrap.cli.ensure_security_group", return_value="sg-123")
|
|
783
|
+
def test_launch_dry_run_omits_python_version_when_unset(mock_sg, mock_import, mock_ami, mock_session, tmp_path):
|
|
784
|
+
mock_ami.return_value = {"ImageId": "ami-123", "Name": "TestAMI"}
|
|
785
|
+
|
|
786
|
+
key_path = tmp_path / "id_ed25519.pub"
|
|
787
|
+
key_path.write_text("ssh-ed25519 AAAA test@host")
|
|
788
|
+
|
|
789
|
+
runner = CliRunner()
|
|
790
|
+
result = runner.invoke(main, ["launch", "--key-path", str(key_path), "--dry-run"])
|
|
791
|
+
assert result.exit_code == 0
|
|
792
|
+
assert "Python version" not in result.output
|
|
793
|
+
|
|
794
|
+
|
|
795
|
+
# ---------------------------------------------------------------------------
|
|
796
|
+
# --ssh-port tests
|
|
797
|
+
# ---------------------------------------------------------------------------
|
|
798
|
+
|
|
799
|
+
|
|
800
|
+
def test_launch_help_shows_ssh_port():
|
|
801
|
+
runner = CliRunner()
|
|
802
|
+
result = runner.invoke(main, ["launch", "--help"])
|
|
803
|
+
assert result.exit_code == 0
|
|
804
|
+
assert "--ssh-port" in result.output
|
|
805
|
+
|
|
806
|
+
|
|
807
|
+
@patch("aws_bootstrap.cli.boto3.Session")
|
|
808
|
+
@patch("aws_bootstrap.cli.get_latest_ami")
|
|
809
|
+
@patch("aws_bootstrap.cli.import_key_pair", return_value="aws-bootstrap-key")
|
|
810
|
+
@patch("aws_bootstrap.cli.ensure_security_group", return_value="sg-123")
|
|
811
|
+
def test_launch_dry_run_shows_ssh_port_when_non_default(mock_sg, mock_import, mock_ami, mock_session, tmp_path):
|
|
812
|
+
mock_ami.return_value = {"ImageId": "ami-123", "Name": "TestAMI"}
|
|
813
|
+
|
|
814
|
+
key_path = tmp_path / "id_ed25519.pub"
|
|
815
|
+
key_path.write_text("ssh-ed25519 AAAA test@host")
|
|
816
|
+
|
|
817
|
+
runner = CliRunner()
|
|
818
|
+
result = runner.invoke(main, ["launch", "--key-path", str(key_path), "--dry-run", "--ssh-port", "2222"])
|
|
819
|
+
assert result.exit_code == 0
|
|
820
|
+
assert "2222" in result.output
|
|
821
|
+
|
|
822
|
+
|
|
823
|
+
@patch("aws_bootstrap.cli.boto3.Session")
|
|
824
|
+
@patch("aws_bootstrap.cli.get_latest_ami")
|
|
825
|
+
@patch("aws_bootstrap.cli.import_key_pair", return_value="aws-bootstrap-key")
|
|
826
|
+
@patch("aws_bootstrap.cli.ensure_security_group", return_value="sg-123")
|
|
827
|
+
def test_launch_dry_run_omits_ssh_port_when_default(mock_sg, mock_import, mock_ami, mock_session, tmp_path):
|
|
828
|
+
mock_ami.return_value = {"ImageId": "ami-123", "Name": "TestAMI"}
|
|
829
|
+
|
|
830
|
+
key_path = tmp_path / "id_ed25519.pub"
|
|
831
|
+
key_path.write_text("ssh-ed25519 AAAA test@host")
|
|
832
|
+
|
|
833
|
+
runner = CliRunner()
|
|
834
|
+
result = runner.invoke(main, ["launch", "--key-path", str(key_path), "--dry-run"])
|
|
835
|
+
assert result.exit_code == 0
|
|
836
|
+
assert "SSH port" not in result.output
|
|
@@ -0,0 +1,98 @@
|
|
|
1
|
+
"""Tests for GPU info queries via SSH (query_gpu_info, GPU architecture mapping)."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
import subprocess
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
from unittest.mock import patch
|
|
7
|
+
|
|
8
|
+
from aws_bootstrap.gpu import _GPU_ARCHITECTURES, GpuInfo
|
|
9
|
+
from aws_bootstrap.ssh import query_gpu_info
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
# ---------------------------------------------------------------------------
|
|
13
|
+
# query_gpu_info
|
|
14
|
+
# ---------------------------------------------------------------------------
|
|
15
|
+
|
|
16
|
+
NVIDIA_SMI_OUTPUT = "560.35.03, Tesla T4, 7.5\n12.8\n12.6\n"
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
@patch("aws_bootstrap.ssh.subprocess.run")
|
|
20
|
+
def test_query_gpu_info_success(mock_run):
|
|
21
|
+
"""Successful nvidia-smi + nvcc output returns a valid GpuInfo."""
|
|
22
|
+
mock_run.return_value = subprocess.CompletedProcess(args=[], returncode=0, stdout=NVIDIA_SMI_OUTPUT, stderr="")
|
|
23
|
+
|
|
24
|
+
info = query_gpu_info("1.2.3.4", "ubuntu", Path("/home/user/.ssh/id_ed25519"))
|
|
25
|
+
assert info is not None
|
|
26
|
+
assert isinstance(info, GpuInfo)
|
|
27
|
+
assert info.driver_version == "560.35.03"
|
|
28
|
+
assert info.cuda_driver_version == "12.8"
|
|
29
|
+
assert info.cuda_toolkit_version == "12.6"
|
|
30
|
+
assert info.gpu_name == "Tesla T4"
|
|
31
|
+
assert info.compute_capability == "7.5"
|
|
32
|
+
assert info.architecture == "Turing"
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
@patch("aws_bootstrap.ssh.subprocess.run")
|
|
36
|
+
def test_query_gpu_info_no_nvcc(mock_run):
|
|
37
|
+
"""When nvcc is unavailable, cuda_toolkit_version is None."""
|
|
38
|
+
output = "560.35.03, Tesla T4, 7.5\n12.8\nN/A\n"
|
|
39
|
+
mock_run.return_value = subprocess.CompletedProcess(args=[], returncode=0, stdout=output, stderr="")
|
|
40
|
+
|
|
41
|
+
info = query_gpu_info("1.2.3.4", "ubuntu", Path("/home/user/.ssh/id_ed25519"))
|
|
42
|
+
assert info is not None
|
|
43
|
+
assert info.cuda_driver_version == "12.8"
|
|
44
|
+
assert info.cuda_toolkit_version is None
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
@patch("aws_bootstrap.ssh.subprocess.run")
|
|
48
|
+
def test_query_gpu_info_ssh_failure(mock_run):
|
|
49
|
+
"""Non-zero exit code returns None."""
|
|
50
|
+
mock_run.return_value = subprocess.CompletedProcess(args=[], returncode=255, stdout="", stderr="Connection refused")
|
|
51
|
+
|
|
52
|
+
info = query_gpu_info("1.2.3.4", "ubuntu", Path("/home/user/.ssh/id_ed25519"))
|
|
53
|
+
assert info is None
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
@patch("aws_bootstrap.ssh.subprocess.run", side_effect=subprocess.TimeoutExpired(cmd="ssh", timeout=15))
|
|
57
|
+
def test_query_gpu_info_timeout(mock_run):
|
|
58
|
+
"""TimeoutExpired returns None."""
|
|
59
|
+
info = query_gpu_info("1.2.3.4", "ubuntu", Path("/home/user/.ssh/id_ed25519"))
|
|
60
|
+
assert info is None
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
@patch("aws_bootstrap.ssh.subprocess.run")
|
|
64
|
+
def test_query_gpu_info_malformed_output(mock_run):
|
|
65
|
+
"""Garbage output returns None."""
|
|
66
|
+
mock_run.return_value = subprocess.CompletedProcess(
|
|
67
|
+
args=[], returncode=0, stdout="not valid gpu output\n", stderr=""
|
|
68
|
+
)
|
|
69
|
+
|
|
70
|
+
info = query_gpu_info("1.2.3.4", "ubuntu", Path("/home/user/.ssh/id_ed25519"))
|
|
71
|
+
assert info is None
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
# ---------------------------------------------------------------------------
|
|
75
|
+
# GPU architecture mapping
|
|
76
|
+
# ---------------------------------------------------------------------------
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
def test_gpu_architecture_mapping():
|
|
80
|
+
"""Known compute capabilities map to correct architecture names."""
|
|
81
|
+
assert _GPU_ARCHITECTURES["7.5"] == "Turing"
|
|
82
|
+
assert _GPU_ARCHITECTURES["8.0"] == "Ampere"
|
|
83
|
+
assert _GPU_ARCHITECTURES["8.6"] == "Ampere"
|
|
84
|
+
assert _GPU_ARCHITECTURES["8.9"] == "Ada Lovelace"
|
|
85
|
+
assert _GPU_ARCHITECTURES["9.0"] == "Hopper"
|
|
86
|
+
assert _GPU_ARCHITECTURES["7.0"] == "Volta"
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
@patch("aws_bootstrap.ssh.subprocess.run")
|
|
90
|
+
def test_query_gpu_info_unknown_architecture(mock_run):
|
|
91
|
+
"""Unknown compute capability produces a fallback architecture string."""
|
|
92
|
+
mock_run.return_value = subprocess.CompletedProcess(
|
|
93
|
+
args=[], returncode=0, stdout="550.00.00, Future GPU, 10.0\n13.0\n13.0\n", stderr=""
|
|
94
|
+
)
|
|
95
|
+
|
|
96
|
+
info = query_gpu_info("1.2.3.4", "ubuntu", Path("/home/user/.ssh/id_ed25519"))
|
|
97
|
+
assert info is not None
|
|
98
|
+
assert info.architecture == "Unknown (10.0)"
|