aws-bootstrap-g4dn 0.2.0__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
aws_bootstrap/ssh.py CHANGED
@@ -12,6 +12,8 @@ from pathlib import Path
12
12
 
13
13
  import click
14
14
 
15
+ from .gpu import _GPU_ARCHITECTURES, GpuInfo
16
+
15
17
 
16
18
  # ---------------------------------------------------------------------------
17
19
  # SSH config markers
@@ -72,17 +74,18 @@ def import_key_pair(ec2_client, key_name: str, key_path: Path) -> str:
72
74
  return key_name
73
75
 
74
76
 
75
- def wait_for_ssh(host: str, user: str, key_path: Path, retries: int = 30, delay: int = 10) -> bool:
77
+ def wait_for_ssh(host: str, user: str, key_path: Path, retries: int = 30, delay: int = 10, port: int = 22) -> bool:
76
78
  """Wait for SSH to become available on the instance.
77
79
 
78
- Tries a TCP connection to port 22 first, then an actual SSH command.
80
+ Tries a TCP connection to the SSH port first, then an actual SSH command.
79
81
  """
80
82
  base_opts = _ssh_opts(key_path)
83
+ port_opts = ["-p", str(port)] if port != 22 else []
81
84
 
82
85
  for attempt in range(1, retries + 1):
83
- # First check if port 22 is open
86
+ # First check if the SSH port is open
84
87
  try:
85
- sock = socket.create_connection((host, 22), timeout=5)
88
+ sock = socket.create_connection((host, port), timeout=5)
86
89
  sock.close()
87
90
  except (TimeoutError, ConnectionRefusedError, OSError):
88
91
  click.echo(" SSH not ready " + click.style(f"(attempt {attempt}/{retries})", dim=True) + ", waiting...")
@@ -90,11 +93,18 @@ def wait_for_ssh(host: str, user: str, key_path: Path, retries: int = 30, delay:
90
93
  continue
91
94
 
92
95
  # Port is open, try actual SSH
93
- result = subprocess.run(
94
- ["ssh", *base_opts, "-o", "ConnectTimeout=10", "-o", "BatchMode=yes", f"{user}@{host}", "echo ok"],
95
- capture_output=True,
96
- text=True,
97
- )
96
+ cmd = [
97
+ "ssh",
98
+ *base_opts,
99
+ *port_opts,
100
+ "-o",
101
+ "ConnectTimeout=10",
102
+ "-o",
103
+ "BatchMode=yes",
104
+ f"{user}@{host}",
105
+ "echo ok",
106
+ ]
107
+ result = subprocess.run(cmd, capture_output=True, text=True)
98
108
  if result.returncode == 0:
99
109
  click.secho(" SSH connection established.", fg="green")
100
110
  return True
@@ -105,15 +115,19 @@ def wait_for_ssh(host: str, user: str, key_path: Path, retries: int = 30, delay:
105
115
  return False
106
116
 
107
117
 
108
- def run_remote_setup(host: str, user: str, key_path: Path, script_path: Path) -> bool:
118
+ def run_remote_setup(
119
+ host: str, user: str, key_path: Path, script_path: Path, python_version: str | None = None, port: int = 22
120
+ ) -> bool:
109
121
  """SCP the setup script and requirements.txt to the instance and execute."""
110
122
  ssh_opts = _ssh_opts(key_path)
123
+ scp_port_opts = ["-P", str(port)] if port != 22 else []
124
+ ssh_port_opts = ["-p", str(port)] if port != 22 else []
111
125
  requirements_path = script_path.parent / "requirements.txt"
112
126
 
113
127
  # SCP the requirements file
114
128
  click.echo(" Uploading requirements.txt...")
115
129
  req_result = subprocess.run(
116
- ["scp", *ssh_opts, str(requirements_path), f"{user}@{host}:/tmp/requirements.txt"],
130
+ ["scp", *ssh_opts, *scp_port_opts, str(requirements_path), f"{user}@{host}:/tmp/requirements.txt"],
117
131
  capture_output=True,
118
132
  text=True,
119
133
  )
@@ -125,7 +139,7 @@ def run_remote_setup(host: str, user: str, key_path: Path, script_path: Path) ->
125
139
  benchmark_path = script_path.parent / "gpu_benchmark.py"
126
140
  click.echo(" Uploading gpu_benchmark.py...")
127
141
  bench_result = subprocess.run(
128
- ["scp", *ssh_opts, str(benchmark_path), f"{user}@{host}:/tmp/gpu_benchmark.py"],
142
+ ["scp", *ssh_opts, *scp_port_opts, str(benchmark_path), f"{user}@{host}:/tmp/gpu_benchmark.py"],
129
143
  capture_output=True,
130
144
  text=True,
131
145
  )
@@ -137,7 +151,7 @@ def run_remote_setup(host: str, user: str, key_path: Path, script_path: Path) ->
137
151
  notebook_path = script_path.parent / "gpu_smoke_test.ipynb"
138
152
  click.echo(" Uploading gpu_smoke_test.ipynb...")
139
153
  nb_result = subprocess.run(
140
- ["scp", *ssh_opts, str(notebook_path), f"{user}@{host}:/tmp/gpu_smoke_test.ipynb"],
154
+ ["scp", *ssh_opts, *scp_port_opts, str(notebook_path), f"{user}@{host}:/tmp/gpu_smoke_test.ipynb"],
141
155
  capture_output=True,
142
156
  text=True,
143
157
  )
@@ -145,10 +159,46 @@ def run_remote_setup(host: str, user: str, key_path: Path, script_path: Path) ->
145
159
  click.secho(f" SCP failed: {nb_result.stderr}", fg="red", err=True)
146
160
  return False
147
161
 
162
+ # SCP the CUDA example source
163
+ saxpy_path = script_path.parent / "saxpy.cu"
164
+ click.echo(" Uploading saxpy.cu...")
165
+ saxpy_result = subprocess.run(
166
+ ["scp", *ssh_opts, *scp_port_opts, str(saxpy_path), f"{user}@{host}:/tmp/saxpy.cu"],
167
+ capture_output=True,
168
+ text=True,
169
+ )
170
+ if saxpy_result.returncode != 0:
171
+ click.secho(f" SCP failed: {saxpy_result.stderr}", fg="red", err=True)
172
+ return False
173
+
174
+ # SCP the VSCode launch.json
175
+ launch_json_path = script_path.parent / "launch.json"
176
+ click.echo(" Uploading launch.json...")
177
+ launch_result = subprocess.run(
178
+ ["scp", *ssh_opts, *scp_port_opts, str(launch_json_path), f"{user}@{host}:/tmp/launch.json"],
179
+ capture_output=True,
180
+ text=True,
181
+ )
182
+ if launch_result.returncode != 0:
183
+ click.secho(f" SCP failed: {launch_result.stderr}", fg="red", err=True)
184
+ return False
185
+
186
+ # SCP the VSCode tasks.json
187
+ tasks_json_path = script_path.parent / "tasks.json"
188
+ click.echo(" Uploading tasks.json...")
189
+ tasks_result = subprocess.run(
190
+ ["scp", *ssh_opts, *scp_port_opts, str(tasks_json_path), f"{user}@{host}:/tmp/tasks.json"],
191
+ capture_output=True,
192
+ text=True,
193
+ )
194
+ if tasks_result.returncode != 0:
195
+ click.secho(f" SCP failed: {tasks_result.stderr}", fg="red", err=True)
196
+ return False
197
+
148
198
  # SCP the script
149
199
  click.echo(" Uploading remote_setup.sh...")
150
200
  scp_result = subprocess.run(
151
- ["scp", *ssh_opts, str(script_path), f"{user}@{host}:/tmp/remote_setup.sh"],
201
+ ["scp", *ssh_opts, *scp_port_opts, str(script_path), f"{user}@{host}:/tmp/remote_setup.sh"],
152
202
  capture_output=True,
153
203
  text=True,
154
204
  )
@@ -156,10 +206,14 @@ def run_remote_setup(host: str, user: str, key_path: Path, script_path: Path) ->
156
206
  click.secho(f" SCP failed: {scp_result.stderr}", fg="red", err=True)
157
207
  return False
158
208
 
159
- # Execute the script
209
+ # Execute the script, passing PYTHON_VERSION as an inline env var if specified
160
210
  click.echo(" Running remote_setup.sh on instance...")
211
+ remote_cmd = "chmod +x /tmp/remote_setup.sh && "
212
+ if python_version:
213
+ remote_cmd += f"PYTHON_VERSION={python_version} "
214
+ remote_cmd += "/tmp/remote_setup.sh"
161
215
  ssh_result = subprocess.run(
162
- ["ssh", *ssh_opts, f"{user}@{host}", "chmod +x /tmp/remote_setup.sh && /tmp/remote_setup.sh"],
216
+ ["ssh", *ssh_opts, *ssh_port_opts, f"{user}@{host}", remote_cmd],
163
217
  capture_output=False,
164
218
  )
165
219
  return ssh_result.returncode == 0
@@ -222,15 +276,17 @@ def _next_alias(content: str, prefix: str = "aws-gpu") -> str:
222
276
  return f"{prefix}{max_n + 1}"
223
277
 
224
278
 
225
- def _build_stanza(instance_id: str, alias: str, hostname: str, user: str, key_path: Path) -> str:
279
+ def _build_stanza(instance_id: str, alias: str, hostname: str, user: str, key_path: Path, port: int = 22) -> str:
226
280
  """Build a complete SSH config stanza with markers."""
227
281
  priv_key = private_key_path(key_path)
282
+ port_line = f" Port {port}\n" if port != 22 else ""
228
283
  return (
229
284
  f"{_BEGIN_MARKER.format(instance_id=instance_id)}\n"
230
285
  f"Host {alias}\n"
231
286
  f" HostName {hostname}\n"
232
287
  f" User {user}\n"
233
288
  f" IdentityFile {priv_key}\n"
289
+ f"{port_line}"
234
290
  f" StrictHostKeyChecking no\n"
235
291
  f" UserKnownHostsFile /dev/null\n"
236
292
  f"{_END_MARKER.format(instance_id=instance_id)}\n"
@@ -244,6 +300,7 @@ def add_ssh_host(
244
300
  key_path: Path,
245
301
  config_path: Path | None = None,
246
302
  alias_prefix: str = "aws-gpu",
303
+ port: int = 22,
247
304
  ) -> str:
248
305
  """Add (or update) an SSH host stanza for *instance_id*.
249
306
 
@@ -257,7 +314,7 @@ def add_ssh_host(
257
314
  content = _remove_block(content, instance_id)
258
315
 
259
316
  alias = existing_alias or _next_alias(content, alias_prefix)
260
- stanza = _build_stanza(instance_id, alias, hostname, user, key_path)
317
+ stanza = _build_stanza(instance_id, alias, hostname, user, key_path, port=port)
261
318
 
262
319
  # Ensure a blank line before our block if file has content
263
320
  if content and not content.endswith("\n\n") and not content.endswith("\n"):
@@ -317,21 +374,6 @@ def list_ssh_hosts(config_path: Path | None = None) -> dict[str, str]:
317
374
  return result
318
375
 
319
376
 
320
- # ---------------------------------------------------------------------------
321
- # GPU info via SSH
322
- # ---------------------------------------------------------------------------
323
-
324
- _GPU_ARCHITECTURES: dict[str, str] = {
325
- "7.0": "Volta",
326
- "7.5": "Turing",
327
- "8.0": "Ampere",
328
- "8.6": "Ampere",
329
- "8.7": "Ampere",
330
- "8.9": "Ada Lovelace",
331
- "9.0": "Hopper",
332
- }
333
-
334
-
335
377
  @dataclass
336
378
  class SSHHostDetails:
337
379
  """Connection details parsed from an SSH config stanza."""
@@ -339,18 +381,7 @@ class SSHHostDetails:
339
381
  hostname: str
340
382
  user: str
341
383
  identity_file: Path
342
-
343
-
344
- @dataclass
345
- class GpuInfo:
346
- """GPU information retrieved via nvidia-smi and nvcc."""
347
-
348
- driver_version: str
349
- cuda_driver_version: str # max CUDA version supported by driver (from nvidia-smi)
350
- cuda_toolkit_version: str | None # actual CUDA toolkit installed (from nvcc), None if unavailable
351
- gpu_name: str
352
- compute_capability: str
353
- architecture: str
384
+ port: int = 22
354
385
 
355
386
 
356
387
  def get_ssh_host_details(instance_id: str, config_path: Path | None = None) -> SSHHostDetails | None:
@@ -371,6 +402,7 @@ def get_ssh_host_details(instance_id: str, config_path: Path | None = None) -> S
371
402
  hostname: str | None = None
372
403
  user: str | None = None
373
404
  identity_file: str | None = None
405
+ port: int = 22
374
406
 
375
407
  for line in content.splitlines():
376
408
  if line == begin_marker:
@@ -378,7 +410,7 @@ def get_ssh_host_details(instance_id: str, config_path: Path | None = None) -> S
378
410
  continue
379
411
  if line == end_marker and in_block:
380
412
  if hostname and user and identity_file:
381
- return SSHHostDetails(hostname=hostname, user=user, identity_file=Path(identity_file))
413
+ return SSHHostDetails(hostname=hostname, user=user, identity_file=Path(identity_file), port=port)
382
414
  return None
383
415
  if in_block:
384
416
  stripped = line.strip()
@@ -388,17 +420,20 @@ def get_ssh_host_details(instance_id: str, config_path: Path | None = None) -> S
388
420
  user = stripped.removeprefix("User ").strip()
389
421
  elif stripped.startswith("IdentityFile "):
390
422
  identity_file = stripped.removeprefix("IdentityFile ").strip()
423
+ elif stripped.startswith("Port "):
424
+ port = int(stripped.removeprefix("Port ").strip())
391
425
 
392
426
  return None
393
427
 
394
428
 
395
- def query_gpu_info(host: str, user: str, key_path: Path, timeout: int = 10) -> GpuInfo | None:
429
+ def query_gpu_info(host: str, user: str, key_path: Path, timeout: int = 10, port: int = 22) -> GpuInfo | None:
396
430
  """SSH into a host and query GPU info via ``nvidia-smi``.
397
431
 
398
432
  Returns ``GpuInfo`` on success, or ``None`` if the SSH connection fails,
399
433
  ``nvidia-smi`` is unavailable, or the output is malformed.
400
434
  """
401
435
  ssh_opts = _ssh_opts(key_path)
436
+ port_opts = ["-p", str(port)] if port != 22 else []
402
437
  remote_cmd = (
403
438
  "nvidia-smi --query-gpu=driver_version,name,compute_cap --format=csv,noheader,nounits"
404
439
  " && nvidia-smi | grep -oP 'CUDA Version: \\K[\\d.]+'"
@@ -407,6 +442,7 @@ def query_gpu_info(host: str, user: str, key_path: Path, timeout: int = 10) -> G
407
442
  cmd = [
408
443
  "ssh",
409
444
  *ssh_opts,
445
+ *port_opts,
410
446
  "-o",
411
447
  f"ConnectTimeout={timeout}",
412
448
  "-o",
@@ -9,7 +9,8 @@ import botocore.exceptions
9
9
  from click.testing import CliRunner
10
10
 
11
11
  from aws_bootstrap.cli import main
12
- from aws_bootstrap.ssh import GpuInfo, SSHHostDetails
12
+ from aws_bootstrap.gpu import GpuInfo
13
+ from aws_bootstrap.ssh import SSHHostDetails
13
14
 
14
15
 
15
16
  def test_help():
@@ -73,11 +74,12 @@ def test_status_no_instances(mock_find, mock_session):
73
74
  assert "No active" in result.output
74
75
 
75
76
 
77
+ @patch("aws_bootstrap.cli.get_ssh_host_details", return_value=None)
76
78
  @patch("aws_bootstrap.cli.list_ssh_hosts", return_value={})
77
79
  @patch("aws_bootstrap.cli.boto3.Session")
78
80
  @patch("aws_bootstrap.cli.get_spot_price")
79
81
  @patch("aws_bootstrap.cli.find_tagged_instances")
80
- def test_status_shows_instances(mock_find, mock_spot_price, mock_session, mock_ssh_hosts):
82
+ def test_status_shows_instances(mock_find, mock_spot_price, mock_session, mock_ssh_hosts, mock_details):
81
83
  mock_find.return_value = [
82
84
  {
83
85
  "InstanceId": "i-abc123",
@@ -101,11 +103,12 @@ def test_status_shows_instances(mock_find, mock_spot_price, mock_session, mock_s
101
103
  assert "Est. cost" in result.output
102
104
 
103
105
 
106
+ @patch("aws_bootstrap.cli.get_ssh_host_details", return_value=None)
104
107
  @patch("aws_bootstrap.cli.list_ssh_hosts", return_value={})
105
108
  @patch("aws_bootstrap.cli.boto3.Session")
106
109
  @patch("aws_bootstrap.cli.get_spot_price")
107
110
  @patch("aws_bootstrap.cli.find_tagged_instances")
108
- def test_status_on_demand_no_cost(mock_find, mock_spot_price, mock_session, mock_ssh_hosts):
111
+ def test_status_on_demand_no_cost(mock_find, mock_spot_price, mock_session, mock_ssh_hosts, mock_details):
109
112
  mock_find.return_value = [
110
113
  {
111
114
  "InstanceId": "i-ondemand",
@@ -351,11 +354,12 @@ def test_terminate_removes_ssh_config(mock_terminate, mock_find, mock_session, m
351
354
  mock_remove_ssh.assert_called_once_with("i-abc123")
352
355
 
353
356
 
357
+ @patch("aws_bootstrap.cli.get_ssh_host_details", return_value=None)
354
358
  @patch("aws_bootstrap.cli.list_ssh_hosts")
355
359
  @patch("aws_bootstrap.cli.boto3.Session")
356
360
  @patch("aws_bootstrap.cli.get_spot_price")
357
361
  @patch("aws_bootstrap.cli.find_tagged_instances")
358
- def test_status_shows_alias(mock_find, mock_spot_price, mock_session, mock_ssh_hosts):
362
+ def test_status_shows_alias(mock_find, mock_spot_price, mock_session, mock_ssh_hosts, mock_details):
359
363
  mock_find.return_value = [
360
364
  {
361
365
  "InstanceId": "i-abc123",
@@ -376,11 +380,12 @@ def test_status_shows_alias(mock_find, mock_spot_price, mock_session, mock_ssh_h
376
380
  assert "aws-gpu1" in result.output
377
381
 
378
382
 
383
+ @patch("aws_bootstrap.cli.get_ssh_host_details", return_value=None)
379
384
  @patch("aws_bootstrap.cli.list_ssh_hosts", return_value={})
380
385
  @patch("aws_bootstrap.cli.boto3.Session")
381
386
  @patch("aws_bootstrap.cli.get_spot_price")
382
387
  @patch("aws_bootstrap.cli.find_tagged_instances")
383
- def test_status_no_alias_graceful(mock_find, mock_spot_price, mock_session, mock_ssh_hosts):
388
+ def test_status_no_alias_graceful(mock_find, mock_spot_price, mock_session, mock_ssh_hosts, mock_details):
384
389
  mock_find.return_value = [
385
390
  {
386
391
  "InstanceId": "i-old999",
@@ -520,13 +525,98 @@ def test_status_gpu_skips_non_running(mock_find, mock_session, mock_ssh_hosts, m
520
525
  @patch("aws_bootstrap.cli.boto3.Session")
521
526
  @patch("aws_bootstrap.cli.get_spot_price", return_value=0.15)
522
527
  @patch("aws_bootstrap.cli.find_tagged_instances")
523
- def test_status_without_gpu_flag_no_ssh(mock_find, mock_spot, mock_session, mock_ssh_hosts, mock_details, mock_gpu):
528
+ def test_status_without_gpu_flag_no_gpu_query(
529
+ mock_find, mock_spot, mock_session, mock_ssh_hosts, mock_details, mock_gpu
530
+ ):
524
531
  mock_find.return_value = [_RUNNING_INSTANCE]
525
532
  runner = CliRunner()
526
533
  result = runner.invoke(main, ["status"])
527
534
  assert result.exit_code == 0
528
535
  mock_gpu.assert_not_called()
529
- mock_details.assert_not_called()
536
+
537
+
538
+ # ---------------------------------------------------------------------------
539
+ # --instructions / --no-instructions / -I flag tests
540
+ # ---------------------------------------------------------------------------
541
+
542
+
543
+ def test_status_help_shows_instructions_flag():
544
+ runner = CliRunner()
545
+ result = runner.invoke(main, ["status", "--help"])
546
+ assert result.exit_code == 0
547
+ assert "--instructions" in result.output
548
+ assert "--no-instructions" in result.output
549
+ assert "-I" in result.output
550
+
551
+
552
+ @patch("aws_bootstrap.cli.get_ssh_host_details")
553
+ @patch("aws_bootstrap.cli.list_ssh_hosts", return_value={"i-abc123": "aws-gpu1"})
554
+ @patch("aws_bootstrap.cli.boto3.Session")
555
+ @patch("aws_bootstrap.cli.get_spot_price", return_value=0.15)
556
+ @patch("aws_bootstrap.cli.find_tagged_instances")
557
+ def test_status_instructions_shown_by_default(mock_find, mock_spot, mock_session, mock_ssh_hosts, mock_details):
558
+ """Instructions are shown by default (no flag needed)."""
559
+ mock_find.return_value = [_RUNNING_INSTANCE]
560
+ mock_details.return_value = SSHHostDetails(
561
+ hostname="1.2.3.4", user="ubuntu", identity_file=Path("/home/user/.ssh/id_ed25519")
562
+ )
563
+ runner = CliRunner()
564
+ result = runner.invoke(main, ["status"])
565
+ assert result.exit_code == 0
566
+ assert "ssh aws-gpu1" in result.output
567
+ assert "ssh -NL 8888:localhost:8888 aws-gpu1" in result.output
568
+ assert "vscode-remote://ssh-remote+aws-gpu1/home/ubuntu/workspace" in result.output
569
+ assert "python ~/gpu_benchmark.py" in result.output
570
+
571
+
572
+ @patch("aws_bootstrap.cli.get_ssh_host_details")
573
+ @patch("aws_bootstrap.cli.list_ssh_hosts", return_value={"i-abc123": "aws-gpu1"})
574
+ @patch("aws_bootstrap.cli.boto3.Session")
575
+ @patch("aws_bootstrap.cli.get_spot_price", return_value=0.15)
576
+ @patch("aws_bootstrap.cli.find_tagged_instances")
577
+ def test_status_no_instructions_suppresses_commands(mock_find, mock_spot, mock_session, mock_ssh_hosts, mock_details):
578
+ """--no-instructions suppresses connection commands."""
579
+ mock_find.return_value = [_RUNNING_INSTANCE]
580
+ mock_details.return_value = SSHHostDetails(
581
+ hostname="1.2.3.4", user="ubuntu", identity_file=Path("/home/user/.ssh/id_ed25519")
582
+ )
583
+ runner = CliRunner()
584
+ result = runner.invoke(main, ["status", "--no-instructions"])
585
+ assert result.exit_code == 0
586
+ assert "vscode-remote" not in result.output
587
+ assert "Jupyter" not in result.output
588
+
589
+
590
+ @patch("aws_bootstrap.cli.get_ssh_host_details")
591
+ @patch("aws_bootstrap.cli.list_ssh_hosts", return_value={})
592
+ @patch("aws_bootstrap.cli.boto3.Session")
593
+ @patch("aws_bootstrap.cli.get_spot_price", return_value=0.15)
594
+ @patch("aws_bootstrap.cli.find_tagged_instances")
595
+ def test_status_instructions_no_alias_skips(mock_find, mock_spot, mock_session, mock_ssh_hosts, mock_details):
596
+ """Instances without an SSH alias don't get connection instructions."""
597
+ mock_find.return_value = [_RUNNING_INSTANCE]
598
+ runner = CliRunner()
599
+ result = runner.invoke(main, ["status"])
600
+ assert result.exit_code == 0
601
+ assert "ssh aws-gpu" not in result.output
602
+ assert "vscode-remote" not in result.output
603
+
604
+
605
+ @patch("aws_bootstrap.cli.get_ssh_host_details")
606
+ @patch("aws_bootstrap.cli.list_ssh_hosts", return_value={"i-abc123": "aws-gpu1"})
607
+ @patch("aws_bootstrap.cli.boto3.Session")
608
+ @patch("aws_bootstrap.cli.get_spot_price", return_value=0.15)
609
+ @patch("aws_bootstrap.cli.find_tagged_instances")
610
+ def test_status_instructions_non_default_port(mock_find, mock_spot, mock_session, mock_ssh_hosts, mock_details):
611
+ mock_find.return_value = [_RUNNING_INSTANCE]
612
+ mock_details.return_value = SSHHostDetails(
613
+ hostname="1.2.3.4", user="ubuntu", identity_file=Path("/home/user/.ssh/id_ed25519"), port=2222
614
+ )
615
+ runner = CliRunner()
616
+ result = runner.invoke(main, ["status"])
617
+ assert result.exit_code == 0
618
+ assert "ssh -p 2222 aws-gpu1" in result.output
619
+ assert "ssh -NL 8888:localhost:8888 -p 2222 aws-gpu1" in result.output
530
620
 
531
621
 
532
622
  # ---------------------------------------------------------------------------
@@ -636,3 +726,111 @@ def test_no_credentials_caught_on_list(mock_session, mock_list):
636
726
  result = runner.invoke(main, ["list", "instance-types"])
637
727
  assert result.exit_code != 0
638
728
  assert "Unable to locate AWS credentials" in result.output
729
+
730
+
731
+ # ---------------------------------------------------------------------------
732
+ # --python-version tests
733
+ # ---------------------------------------------------------------------------
734
+
735
+
736
+ @patch("aws_bootstrap.cli.add_ssh_host", return_value="aws-gpu1")
737
+ @patch("aws_bootstrap.cli.run_remote_setup", return_value=True)
738
+ @patch("aws_bootstrap.cli.wait_for_ssh", return_value=True)
739
+ @patch("aws_bootstrap.cli.wait_instance_ready")
740
+ @patch("aws_bootstrap.cli.launch_instance")
741
+ @patch("aws_bootstrap.cli.ensure_security_group", return_value="sg-123")
742
+ @patch("aws_bootstrap.cli.import_key_pair", return_value="aws-bootstrap-key")
743
+ @patch("aws_bootstrap.cli.get_latest_ami")
744
+ @patch("aws_bootstrap.cli.boto3.Session")
745
+ def test_launch_python_version_passed_to_setup(
746
+ mock_session, mock_ami, mock_import, mock_sg, mock_launch, mock_wait, mock_ssh, mock_setup, mock_add_ssh, tmp_path
747
+ ):
748
+ mock_ami.return_value = {"ImageId": "ami-123", "Name": "TestAMI"}
749
+ mock_launch.return_value = {"InstanceId": "i-test123"}
750
+ mock_wait.return_value = {"PublicIpAddress": "1.2.3.4"}
751
+
752
+ key_path = tmp_path / "id_ed25519.pub"
753
+ key_path.write_text("ssh-ed25519 AAAA test@host")
754
+
755
+ runner = CliRunner()
756
+ result = runner.invoke(main, ["launch", "--key-path", str(key_path), "--python-version", "3.13"])
757
+ assert result.exit_code == 0
758
+ mock_setup.assert_called_once()
759
+ assert mock_setup.call_args[0][4] == "3.13"
760
+
761
+
762
+ @patch("aws_bootstrap.cli.boto3.Session")
763
+ @patch("aws_bootstrap.cli.get_latest_ami")
764
+ @patch("aws_bootstrap.cli.import_key_pair", return_value="aws-bootstrap-key")
765
+ @patch("aws_bootstrap.cli.ensure_security_group", return_value="sg-123")
766
+ def test_launch_dry_run_shows_python_version(mock_sg, mock_import, mock_ami, mock_session, tmp_path):
767
+ mock_ami.return_value = {"ImageId": "ami-123", "Name": "TestAMI"}
768
+
769
+ key_path = tmp_path / "id_ed25519.pub"
770
+ key_path.write_text("ssh-ed25519 AAAA test@host")
771
+
772
+ runner = CliRunner()
773
+ result = runner.invoke(main, ["launch", "--key-path", str(key_path), "--dry-run", "--python-version", "3.14.2"])
774
+ assert result.exit_code == 0
775
+ assert "3.14.2" in result.output
776
+ assert "Python version" in result.output
777
+
778
+
779
+ @patch("aws_bootstrap.cli.boto3.Session")
780
+ @patch("aws_bootstrap.cli.get_latest_ami")
781
+ @patch("aws_bootstrap.cli.import_key_pair", return_value="aws-bootstrap-key")
782
+ @patch("aws_bootstrap.cli.ensure_security_group", return_value="sg-123")
783
+ def test_launch_dry_run_omits_python_version_when_unset(mock_sg, mock_import, mock_ami, mock_session, tmp_path):
784
+ mock_ami.return_value = {"ImageId": "ami-123", "Name": "TestAMI"}
785
+
786
+ key_path = tmp_path / "id_ed25519.pub"
787
+ key_path.write_text("ssh-ed25519 AAAA test@host")
788
+
789
+ runner = CliRunner()
790
+ result = runner.invoke(main, ["launch", "--key-path", str(key_path), "--dry-run"])
791
+ assert result.exit_code == 0
792
+ assert "Python version" not in result.output
793
+
794
+
795
+ # ---------------------------------------------------------------------------
796
+ # --ssh-port tests
797
+ # ---------------------------------------------------------------------------
798
+
799
+
800
+ def test_launch_help_shows_ssh_port():
801
+ runner = CliRunner()
802
+ result = runner.invoke(main, ["launch", "--help"])
803
+ assert result.exit_code == 0
804
+ assert "--ssh-port" in result.output
805
+
806
+
807
+ @patch("aws_bootstrap.cli.boto3.Session")
808
+ @patch("aws_bootstrap.cli.get_latest_ami")
809
+ @patch("aws_bootstrap.cli.import_key_pair", return_value="aws-bootstrap-key")
810
+ @patch("aws_bootstrap.cli.ensure_security_group", return_value="sg-123")
811
+ def test_launch_dry_run_shows_ssh_port_when_non_default(mock_sg, mock_import, mock_ami, mock_session, tmp_path):
812
+ mock_ami.return_value = {"ImageId": "ami-123", "Name": "TestAMI"}
813
+
814
+ key_path = tmp_path / "id_ed25519.pub"
815
+ key_path.write_text("ssh-ed25519 AAAA test@host")
816
+
817
+ runner = CliRunner()
818
+ result = runner.invoke(main, ["launch", "--key-path", str(key_path), "--dry-run", "--ssh-port", "2222"])
819
+ assert result.exit_code == 0
820
+ assert "2222" in result.output
821
+
822
+
823
+ @patch("aws_bootstrap.cli.boto3.Session")
824
+ @patch("aws_bootstrap.cli.get_latest_ami")
825
+ @patch("aws_bootstrap.cli.import_key_pair", return_value="aws-bootstrap-key")
826
+ @patch("aws_bootstrap.cli.ensure_security_group", return_value="sg-123")
827
+ def test_launch_dry_run_omits_ssh_port_when_default(mock_sg, mock_import, mock_ami, mock_session, tmp_path):
828
+ mock_ami.return_value = {"ImageId": "ami-123", "Name": "TestAMI"}
829
+
830
+ key_path = tmp_path / "id_ed25519.pub"
831
+ key_path.write_text("ssh-ed25519 AAAA test@host")
832
+
833
+ runner = CliRunner()
834
+ result = runner.invoke(main, ["launch", "--key-path", str(key_path), "--dry-run"])
835
+ assert result.exit_code == 0
836
+ assert "SSH port" not in result.output
@@ -0,0 +1,98 @@
1
+ """Tests for GPU info queries via SSH (query_gpu_info, GPU architecture mapping)."""
2
+
3
+ from __future__ import annotations
4
+ import subprocess
5
+ from pathlib import Path
6
+ from unittest.mock import patch
7
+
8
+ from aws_bootstrap.gpu import _GPU_ARCHITECTURES, GpuInfo
9
+ from aws_bootstrap.ssh import query_gpu_info
10
+
11
+
12
+ # ---------------------------------------------------------------------------
13
+ # query_gpu_info
14
+ # ---------------------------------------------------------------------------
15
+
16
+ NVIDIA_SMI_OUTPUT = "560.35.03, Tesla T4, 7.5\n12.8\n12.6\n"
17
+
18
+
19
+ @patch("aws_bootstrap.ssh.subprocess.run")
20
+ def test_query_gpu_info_success(mock_run):
21
+ """Successful nvidia-smi + nvcc output returns a valid GpuInfo."""
22
+ mock_run.return_value = subprocess.CompletedProcess(args=[], returncode=0, stdout=NVIDIA_SMI_OUTPUT, stderr="")
23
+
24
+ info = query_gpu_info("1.2.3.4", "ubuntu", Path("/home/user/.ssh/id_ed25519"))
25
+ assert info is not None
26
+ assert isinstance(info, GpuInfo)
27
+ assert info.driver_version == "560.35.03"
28
+ assert info.cuda_driver_version == "12.8"
29
+ assert info.cuda_toolkit_version == "12.6"
30
+ assert info.gpu_name == "Tesla T4"
31
+ assert info.compute_capability == "7.5"
32
+ assert info.architecture == "Turing"
33
+
34
+
35
+ @patch("aws_bootstrap.ssh.subprocess.run")
36
+ def test_query_gpu_info_no_nvcc(mock_run):
37
+ """When nvcc is unavailable, cuda_toolkit_version is None."""
38
+ output = "560.35.03, Tesla T4, 7.5\n12.8\nN/A\n"
39
+ mock_run.return_value = subprocess.CompletedProcess(args=[], returncode=0, stdout=output, stderr="")
40
+
41
+ info = query_gpu_info("1.2.3.4", "ubuntu", Path("/home/user/.ssh/id_ed25519"))
42
+ assert info is not None
43
+ assert info.cuda_driver_version == "12.8"
44
+ assert info.cuda_toolkit_version is None
45
+
46
+
47
+ @patch("aws_bootstrap.ssh.subprocess.run")
48
+ def test_query_gpu_info_ssh_failure(mock_run):
49
+ """Non-zero exit code returns None."""
50
+ mock_run.return_value = subprocess.CompletedProcess(args=[], returncode=255, stdout="", stderr="Connection refused")
51
+
52
+ info = query_gpu_info("1.2.3.4", "ubuntu", Path("/home/user/.ssh/id_ed25519"))
53
+ assert info is None
54
+
55
+
56
+ @patch("aws_bootstrap.ssh.subprocess.run", side_effect=subprocess.TimeoutExpired(cmd="ssh", timeout=15))
57
+ def test_query_gpu_info_timeout(mock_run):
58
+ """TimeoutExpired returns None."""
59
+ info = query_gpu_info("1.2.3.4", "ubuntu", Path("/home/user/.ssh/id_ed25519"))
60
+ assert info is None
61
+
62
+
63
+ @patch("aws_bootstrap.ssh.subprocess.run")
64
+ def test_query_gpu_info_malformed_output(mock_run):
65
+ """Garbage output returns None."""
66
+ mock_run.return_value = subprocess.CompletedProcess(
67
+ args=[], returncode=0, stdout="not valid gpu output\n", stderr=""
68
+ )
69
+
70
+ info = query_gpu_info("1.2.3.4", "ubuntu", Path("/home/user/.ssh/id_ed25519"))
71
+ assert info is None
72
+
73
+
74
+ # ---------------------------------------------------------------------------
75
+ # GPU architecture mapping
76
+ # ---------------------------------------------------------------------------
77
+
78
+
79
+ def test_gpu_architecture_mapping():
80
+ """Known compute capabilities map to correct architecture names."""
81
+ assert _GPU_ARCHITECTURES["7.5"] == "Turing"
82
+ assert _GPU_ARCHITECTURES["8.0"] == "Ampere"
83
+ assert _GPU_ARCHITECTURES["8.6"] == "Ampere"
84
+ assert _GPU_ARCHITECTURES["8.9"] == "Ada Lovelace"
85
+ assert _GPU_ARCHITECTURES["9.0"] == "Hopper"
86
+ assert _GPU_ARCHITECTURES["7.0"] == "Volta"
87
+
88
+
89
+ @patch("aws_bootstrap.ssh.subprocess.run")
90
+ def test_query_gpu_info_unknown_architecture(mock_run):
91
+ """Unknown compute capability produces a fallback architecture string."""
92
+ mock_run.return_value = subprocess.CompletedProcess(
93
+ args=[], returncode=0, stdout="550.00.00, Future GPU, 10.0\n13.0\n13.0\n", stderr=""
94
+ )
95
+
96
+ info = query_gpu_info("1.2.3.4", "ubuntu", Path("/home/user/.ssh/id_ed25519"))
97
+ assert info is not None
98
+ assert info.architecture == "Unknown (10.0)"