aws-bootstrap-g4dn 0.1.0__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -5,10 +5,12 @@ from datetime import UTC, datetime
5
5
  from pathlib import Path
6
6
  from unittest.mock import patch
7
7
 
8
+ import botocore.exceptions
8
9
  from click.testing import CliRunner
9
10
 
10
11
  from aws_bootstrap.cli import main
11
- from aws_bootstrap.ssh import GpuInfo, SSHHostDetails
12
+ from aws_bootstrap.gpu import GpuInfo
13
+ from aws_bootstrap.ssh import SSHHostDetails
12
14
 
13
15
 
14
16
  def test_help():
@@ -72,11 +74,12 @@ def test_status_no_instances(mock_find, mock_session):
72
74
  assert "No active" in result.output
73
75
 
74
76
 
77
+ @patch("aws_bootstrap.cli.get_ssh_host_details", return_value=None)
75
78
  @patch("aws_bootstrap.cli.list_ssh_hosts", return_value={})
76
79
  @patch("aws_bootstrap.cli.boto3.Session")
77
80
  @patch("aws_bootstrap.cli.get_spot_price")
78
81
  @patch("aws_bootstrap.cli.find_tagged_instances")
79
- def test_status_shows_instances(mock_find, mock_spot_price, mock_session, mock_ssh_hosts):
82
+ def test_status_shows_instances(mock_find, mock_spot_price, mock_session, mock_ssh_hosts, mock_details):
80
83
  mock_find.return_value = [
81
84
  {
82
85
  "InstanceId": "i-abc123",
@@ -100,11 +103,12 @@ def test_status_shows_instances(mock_find, mock_spot_price, mock_session, mock_s
100
103
  assert "Est. cost" in result.output
101
104
 
102
105
 
106
+ @patch("aws_bootstrap.cli.get_ssh_host_details", return_value=None)
103
107
  @patch("aws_bootstrap.cli.list_ssh_hosts", return_value={})
104
108
  @patch("aws_bootstrap.cli.boto3.Session")
105
109
  @patch("aws_bootstrap.cli.get_spot_price")
106
110
  @patch("aws_bootstrap.cli.find_tagged_instances")
107
- def test_status_on_demand_no_cost(mock_find, mock_spot_price, mock_session, mock_ssh_hosts):
111
+ def test_status_on_demand_no_cost(mock_find, mock_spot_price, mock_session, mock_ssh_hosts, mock_details):
108
112
  mock_find.return_value = [
109
113
  {
110
114
  "InstanceId": "i-ondemand",
@@ -350,11 +354,12 @@ def test_terminate_removes_ssh_config(mock_terminate, mock_find, mock_session, m
350
354
  mock_remove_ssh.assert_called_once_with("i-abc123")
351
355
 
352
356
 
357
+ @patch("aws_bootstrap.cli.get_ssh_host_details", return_value=None)
353
358
  @patch("aws_bootstrap.cli.list_ssh_hosts")
354
359
  @patch("aws_bootstrap.cli.boto3.Session")
355
360
  @patch("aws_bootstrap.cli.get_spot_price")
356
361
  @patch("aws_bootstrap.cli.find_tagged_instances")
357
- def test_status_shows_alias(mock_find, mock_spot_price, mock_session, mock_ssh_hosts):
362
+ def test_status_shows_alias(mock_find, mock_spot_price, mock_session, mock_ssh_hosts, mock_details):
358
363
  mock_find.return_value = [
359
364
  {
360
365
  "InstanceId": "i-abc123",
@@ -375,11 +380,12 @@ def test_status_shows_alias(mock_find, mock_spot_price, mock_session, mock_ssh_h
375
380
  assert "aws-gpu1" in result.output
376
381
 
377
382
 
383
+ @patch("aws_bootstrap.cli.get_ssh_host_details", return_value=None)
378
384
  @patch("aws_bootstrap.cli.list_ssh_hosts", return_value={})
379
385
  @patch("aws_bootstrap.cli.boto3.Session")
380
386
  @patch("aws_bootstrap.cli.get_spot_price")
381
387
  @patch("aws_bootstrap.cli.find_tagged_instances")
382
- def test_status_no_alias_graceful(mock_find, mock_spot_price, mock_session, mock_ssh_hosts):
388
+ def test_status_no_alias_graceful(mock_find, mock_spot_price, mock_session, mock_ssh_hosts, mock_details):
383
389
  mock_find.return_value = [
384
390
  {
385
391
  "InstanceId": "i-old999",
@@ -519,10 +525,312 @@ def test_status_gpu_skips_non_running(mock_find, mock_session, mock_ssh_hosts, m
519
525
  @patch("aws_bootstrap.cli.boto3.Session")
520
526
  @patch("aws_bootstrap.cli.get_spot_price", return_value=0.15)
521
527
  @patch("aws_bootstrap.cli.find_tagged_instances")
522
- def test_status_without_gpu_flag_no_ssh(mock_find, mock_spot, mock_session, mock_ssh_hosts, mock_details, mock_gpu):
528
+ def test_status_without_gpu_flag_no_gpu_query(
529
+ mock_find, mock_spot, mock_session, mock_ssh_hosts, mock_details, mock_gpu
530
+ ):
523
531
  mock_find.return_value = [_RUNNING_INSTANCE]
524
532
  runner = CliRunner()
525
533
  result = runner.invoke(main, ["status"])
526
534
  assert result.exit_code == 0
527
535
  mock_gpu.assert_not_called()
528
- mock_details.assert_not_called()
536
+
537
+
538
+ # ---------------------------------------------------------------------------
539
+ # --instructions / --no-instructions / -I flag tests
540
+ # ---------------------------------------------------------------------------
541
+
542
+
543
+ def test_status_help_shows_instructions_flag():
544
+ runner = CliRunner()
545
+ result = runner.invoke(main, ["status", "--help"])
546
+ assert result.exit_code == 0
547
+ assert "--instructions" in result.output
548
+ assert "--no-instructions" in result.output
549
+ assert "-I" in result.output
550
+
551
+
552
+ @patch("aws_bootstrap.cli.get_ssh_host_details")
553
+ @patch("aws_bootstrap.cli.list_ssh_hosts", return_value={"i-abc123": "aws-gpu1"})
554
+ @patch("aws_bootstrap.cli.boto3.Session")
555
+ @patch("aws_bootstrap.cli.get_spot_price", return_value=0.15)
556
+ @patch("aws_bootstrap.cli.find_tagged_instances")
557
+ def test_status_instructions_shown_by_default(mock_find, mock_spot, mock_session, mock_ssh_hosts, mock_details):
558
+ """Instructions are shown by default (no flag needed)."""
559
+ mock_find.return_value = [_RUNNING_INSTANCE]
560
+ mock_details.return_value = SSHHostDetails(
561
+ hostname="1.2.3.4", user="ubuntu", identity_file=Path("/home/user/.ssh/id_ed25519")
562
+ )
563
+ runner = CliRunner()
564
+ result = runner.invoke(main, ["status"])
565
+ assert result.exit_code == 0
566
+ assert "ssh aws-gpu1" in result.output
567
+ assert "ssh -NL 8888:localhost:8888 aws-gpu1" in result.output
568
+ assert "vscode-remote://ssh-remote+aws-gpu1/home/ubuntu" in result.output
569
+ assert "python ~/gpu_benchmark.py" in result.output
570
+
571
+
572
+ @patch("aws_bootstrap.cli.get_ssh_host_details")
573
+ @patch("aws_bootstrap.cli.list_ssh_hosts", return_value={"i-abc123": "aws-gpu1"})
574
+ @patch("aws_bootstrap.cli.boto3.Session")
575
+ @patch("aws_bootstrap.cli.get_spot_price", return_value=0.15)
576
+ @patch("aws_bootstrap.cli.find_tagged_instances")
577
+ def test_status_no_instructions_suppresses_commands(mock_find, mock_spot, mock_session, mock_ssh_hosts, mock_details):
578
+ """--no-instructions suppresses connection commands."""
579
+ mock_find.return_value = [_RUNNING_INSTANCE]
580
+ mock_details.return_value = SSHHostDetails(
581
+ hostname="1.2.3.4", user="ubuntu", identity_file=Path("/home/user/.ssh/id_ed25519")
582
+ )
583
+ runner = CliRunner()
584
+ result = runner.invoke(main, ["status", "--no-instructions"])
585
+ assert result.exit_code == 0
586
+ assert "vscode-remote" not in result.output
587
+ assert "Jupyter" not in result.output
588
+
589
+
590
+ @patch("aws_bootstrap.cli.get_ssh_host_details")
591
+ @patch("aws_bootstrap.cli.list_ssh_hosts", return_value={})
592
+ @patch("aws_bootstrap.cli.boto3.Session")
593
+ @patch("aws_bootstrap.cli.get_spot_price", return_value=0.15)
594
+ @patch("aws_bootstrap.cli.find_tagged_instances")
595
+ def test_status_instructions_no_alias_skips(mock_find, mock_spot, mock_session, mock_ssh_hosts, mock_details):
596
+ """Instances without an SSH alias don't get connection instructions."""
597
+ mock_find.return_value = [_RUNNING_INSTANCE]
598
+ runner = CliRunner()
599
+ result = runner.invoke(main, ["status"])
600
+ assert result.exit_code == 0
601
+ assert "ssh aws-gpu" not in result.output
602
+ assert "vscode-remote" not in result.output
603
+
604
+
605
+ @patch("aws_bootstrap.cli.get_ssh_host_details")
606
+ @patch("aws_bootstrap.cli.list_ssh_hosts", return_value={"i-abc123": "aws-gpu1"})
607
+ @patch("aws_bootstrap.cli.boto3.Session")
608
+ @patch("aws_bootstrap.cli.get_spot_price", return_value=0.15)
609
+ @patch("aws_bootstrap.cli.find_tagged_instances")
610
+ def test_status_instructions_non_default_port(mock_find, mock_spot, mock_session, mock_ssh_hosts, mock_details):
611
+ mock_find.return_value = [_RUNNING_INSTANCE]
612
+ mock_details.return_value = SSHHostDetails(
613
+ hostname="1.2.3.4", user="ubuntu", identity_file=Path("/home/user/.ssh/id_ed25519"), port=2222
614
+ )
615
+ runner = CliRunner()
616
+ result = runner.invoke(main, ["status"])
617
+ assert result.exit_code == 0
618
+ assert "ssh -p 2222 aws-gpu1" in result.output
619
+ assert "ssh -NL 8888:localhost:8888 -p 2222 aws-gpu1" in result.output
620
+
621
+
622
+ # ---------------------------------------------------------------------------
623
+ # AWS credential / auth error handling tests
624
+ # ---------------------------------------------------------------------------
625
+
626
+
627
+ @patch("aws_bootstrap.cli.find_tagged_instances")
628
+ @patch("aws_bootstrap.cli.boto3.Session")
629
+ def test_no_credentials_shows_friendly_error(mock_session, mock_find):
630
+ """NoCredentialsError should show a helpful message, not a raw traceback."""
631
+ mock_find.side_effect = botocore.exceptions.NoCredentialsError()
632
+ runner = CliRunner()
633
+ result = runner.invoke(main, ["status"])
634
+ assert result.exit_code != 0
635
+ assert "Unable to locate AWS credentials" in result.output
636
+ assert "AWS_PROFILE" in result.output
637
+ assert "--profile" in result.output
638
+ assert "aws configure" in result.output
639
+
640
+
641
+ @patch("aws_bootstrap.cli.boto3.Session")
642
+ def test_profile_not_found_shows_friendly_error(mock_session):
643
+ """ProfileNotFound should show the missing profile name and list command."""
644
+ mock_session.side_effect = botocore.exceptions.ProfileNotFound(profile="nonexistent")
645
+ runner = CliRunner()
646
+ result = runner.invoke(main, ["status", "--profile", "nonexistent"])
647
+ assert result.exit_code != 0
648
+ assert "nonexistent" in result.output
649
+ assert "aws configure list-profiles" in result.output
650
+
651
+
652
+ @patch("aws_bootstrap.cli.find_tagged_instances")
653
+ @patch("aws_bootstrap.cli.boto3.Session")
654
+ def test_partial_credentials_shows_friendly_error(mock_session, mock_find):
655
+ """PartialCredentialsError should mention incomplete credentials."""
656
+ mock_find.side_effect = botocore.exceptions.PartialCredentialsError(
657
+ provider="env", cred_var="AWS_SECRET_ACCESS_KEY"
658
+ )
659
+ runner = CliRunner()
660
+ result = runner.invoke(main, ["status"])
661
+ assert result.exit_code != 0
662
+ assert "Incomplete AWS credentials" in result.output
663
+ assert "aws configure list" in result.output
664
+
665
+
666
+ @patch("aws_bootstrap.cli.find_tagged_instances")
667
+ @patch("aws_bootstrap.cli.boto3.Session")
668
+ def test_expired_token_shows_friendly_error(mock_session, mock_find):
669
+ """ExpiredTokenException should show authorization failure with context."""
670
+ mock_find.side_effect = botocore.exceptions.ClientError(
671
+ {"Error": {"Code": "ExpiredTokenException", "Message": "The security token is expired"}},
672
+ "DescribeInstances",
673
+ )
674
+ runner = CliRunner()
675
+ result = runner.invoke(main, ["status"])
676
+ assert result.exit_code != 0
677
+ assert "AWS authorization failed" in result.output
678
+ assert "expired" in result.output.lower()
679
+
680
+
681
+ @patch("aws_bootstrap.cli.find_tagged_instances")
682
+ @patch("aws_bootstrap.cli.boto3.Session")
683
+ def test_auth_failure_shows_friendly_error(mock_session, mock_find):
684
+ """AuthFailure ClientError should show authorization failure message."""
685
+ mock_find.side_effect = botocore.exceptions.ClientError(
686
+ {"Error": {"Code": "AuthFailure", "Message": "credentials are invalid"}},
687
+ "DescribeInstances",
688
+ )
689
+ runner = CliRunner()
690
+ result = runner.invoke(main, ["status"])
691
+ assert result.exit_code != 0
692
+ assert "AWS authorization failed" in result.output
693
+
694
+
695
+ @patch("aws_bootstrap.cli.find_tagged_instances")
696
+ @patch("aws_bootstrap.cli.boto3.Session")
697
+ def test_unhandled_client_error_propagates(mock_session, mock_find):
698
+ """Non-auth ClientErrors should propagate without being caught."""
699
+ mock_find.side_effect = botocore.exceptions.ClientError(
700
+ {"Error": {"Code": "UnknownError", "Message": "something else"}},
701
+ "DescribeInstances",
702
+ )
703
+ runner = CliRunner()
704
+ result = runner.invoke(main, ["status"])
705
+ assert result.exit_code != 0
706
+ assert isinstance(result.exception, botocore.exceptions.ClientError)
707
+
708
+
709
+ @patch("aws_bootstrap.cli.find_tagged_instances")
710
+ @patch("aws_bootstrap.cli.boto3.Session")
711
+ def test_no_credentials_caught_on_terminate(mock_session, mock_find):
712
+ """Credential errors are caught for all subcommands, not just status."""
713
+ mock_find.side_effect = botocore.exceptions.NoCredentialsError()
714
+ runner = CliRunner()
715
+ result = runner.invoke(main, ["terminate"])
716
+ assert result.exit_code != 0
717
+ assert "Unable to locate AWS credentials" in result.output
718
+
719
+
720
+ @patch("aws_bootstrap.cli.list_instance_types")
721
+ @patch("aws_bootstrap.cli.boto3.Session")
722
+ def test_no_credentials_caught_on_list(mock_session, mock_list):
723
+ """Credential errors are caught for nested subcommands (list instance-types)."""
724
+ mock_list.side_effect = botocore.exceptions.NoCredentialsError()
725
+ runner = CliRunner()
726
+ result = runner.invoke(main, ["list", "instance-types"])
727
+ assert result.exit_code != 0
728
+ assert "Unable to locate AWS credentials" in result.output
729
+
730
+
731
+ # ---------------------------------------------------------------------------
732
+ # --python-version tests
733
+ # ---------------------------------------------------------------------------
734
+
735
+
736
+ @patch("aws_bootstrap.cli.add_ssh_host", return_value="aws-gpu1")
737
+ @patch("aws_bootstrap.cli.run_remote_setup", return_value=True)
738
+ @patch("aws_bootstrap.cli.wait_for_ssh", return_value=True)
739
+ @patch("aws_bootstrap.cli.wait_instance_ready")
740
+ @patch("aws_bootstrap.cli.launch_instance")
741
+ @patch("aws_bootstrap.cli.ensure_security_group", return_value="sg-123")
742
+ @patch("aws_bootstrap.cli.import_key_pair", return_value="aws-bootstrap-key")
743
+ @patch("aws_bootstrap.cli.get_latest_ami")
744
+ @patch("aws_bootstrap.cli.boto3.Session")
745
+ def test_launch_python_version_passed_to_setup(
746
+ mock_session, mock_ami, mock_import, mock_sg, mock_launch, mock_wait, mock_ssh, mock_setup, mock_add_ssh, tmp_path
747
+ ):
748
+ mock_ami.return_value = {"ImageId": "ami-123", "Name": "TestAMI"}
749
+ mock_launch.return_value = {"InstanceId": "i-test123"}
750
+ mock_wait.return_value = {"PublicIpAddress": "1.2.3.4"}
751
+
752
+ key_path = tmp_path / "id_ed25519.pub"
753
+ key_path.write_text("ssh-ed25519 AAAA test@host")
754
+
755
+ runner = CliRunner()
756
+ result = runner.invoke(main, ["launch", "--key-path", str(key_path), "--python-version", "3.13"])
757
+ assert result.exit_code == 0
758
+ mock_setup.assert_called_once()
759
+ assert mock_setup.call_args[0][4] == "3.13"
760
+
761
+
762
+ @patch("aws_bootstrap.cli.boto3.Session")
763
+ @patch("aws_bootstrap.cli.get_latest_ami")
764
+ @patch("aws_bootstrap.cli.import_key_pair", return_value="aws-bootstrap-key")
765
+ @patch("aws_bootstrap.cli.ensure_security_group", return_value="sg-123")
766
+ def test_launch_dry_run_shows_python_version(mock_sg, mock_import, mock_ami, mock_session, tmp_path):
767
+ mock_ami.return_value = {"ImageId": "ami-123", "Name": "TestAMI"}
768
+
769
+ key_path = tmp_path / "id_ed25519.pub"
770
+ key_path.write_text("ssh-ed25519 AAAA test@host")
771
+
772
+ runner = CliRunner()
773
+ result = runner.invoke(main, ["launch", "--key-path", str(key_path), "--dry-run", "--python-version", "3.14.2"])
774
+ assert result.exit_code == 0
775
+ assert "3.14.2" in result.output
776
+ assert "Python version" in result.output
777
+
778
+
779
+ @patch("aws_bootstrap.cli.boto3.Session")
780
+ @patch("aws_bootstrap.cli.get_latest_ami")
781
+ @patch("aws_bootstrap.cli.import_key_pair", return_value="aws-bootstrap-key")
782
+ @patch("aws_bootstrap.cli.ensure_security_group", return_value="sg-123")
783
+ def test_launch_dry_run_omits_python_version_when_unset(mock_sg, mock_import, mock_ami, mock_session, tmp_path):
784
+ mock_ami.return_value = {"ImageId": "ami-123", "Name": "TestAMI"}
785
+
786
+ key_path = tmp_path / "id_ed25519.pub"
787
+ key_path.write_text("ssh-ed25519 AAAA test@host")
788
+
789
+ runner = CliRunner()
790
+ result = runner.invoke(main, ["launch", "--key-path", str(key_path), "--dry-run"])
791
+ assert result.exit_code == 0
792
+ assert "Python version" not in result.output
793
+
794
+
795
+ # ---------------------------------------------------------------------------
796
+ # --ssh-port tests
797
+ # ---------------------------------------------------------------------------
798
+
799
+
800
+ def test_launch_help_shows_ssh_port():
801
+ runner = CliRunner()
802
+ result = runner.invoke(main, ["launch", "--help"])
803
+ assert result.exit_code == 0
804
+ assert "--ssh-port" in result.output
805
+
806
+
807
+ @patch("aws_bootstrap.cli.boto3.Session")
808
+ @patch("aws_bootstrap.cli.get_latest_ami")
809
+ @patch("aws_bootstrap.cli.import_key_pair", return_value="aws-bootstrap-key")
810
+ @patch("aws_bootstrap.cli.ensure_security_group", return_value="sg-123")
811
+ def test_launch_dry_run_shows_ssh_port_when_non_default(mock_sg, mock_import, mock_ami, mock_session, tmp_path):
812
+ mock_ami.return_value = {"ImageId": "ami-123", "Name": "TestAMI"}
813
+
814
+ key_path = tmp_path / "id_ed25519.pub"
815
+ key_path.write_text("ssh-ed25519 AAAA test@host")
816
+
817
+ runner = CliRunner()
818
+ result = runner.invoke(main, ["launch", "--key-path", str(key_path), "--dry-run", "--ssh-port", "2222"])
819
+ assert result.exit_code == 0
820
+ assert "2222" in result.output
821
+
822
+
823
+ @patch("aws_bootstrap.cli.boto3.Session")
824
+ @patch("aws_bootstrap.cli.get_latest_ami")
825
+ @patch("aws_bootstrap.cli.import_key_pair", return_value="aws-bootstrap-key")
826
+ @patch("aws_bootstrap.cli.ensure_security_group", return_value="sg-123")
827
+ def test_launch_dry_run_omits_ssh_port_when_default(mock_sg, mock_import, mock_ami, mock_session, tmp_path):
828
+ mock_ami.return_value = {"ImageId": "ami-123", "Name": "TestAMI"}
829
+
830
+ key_path = tmp_path / "id_ed25519.pub"
831
+ key_path.write_text("ssh-ed25519 AAAA test@host")
832
+
833
+ runner = CliRunner()
834
+ result = runner.invoke(main, ["launch", "--key-path", str(key_path), "--dry-run"])
835
+ assert result.exit_code == 0
836
+ assert "SSH port" not in result.output
@@ -0,0 +1,98 @@
1
+ """Tests for GPU info queries via SSH (query_gpu_info, GPU architecture mapping)."""
2
+
3
+ from __future__ import annotations
4
+ import subprocess
5
+ from pathlib import Path
6
+ from unittest.mock import patch
7
+
8
+ from aws_bootstrap.gpu import _GPU_ARCHITECTURES, GpuInfo
9
+ from aws_bootstrap.ssh import query_gpu_info
10
+
11
+
12
+ # ---------------------------------------------------------------------------
13
+ # query_gpu_info
14
+ # ---------------------------------------------------------------------------
15
+
16
+ NVIDIA_SMI_OUTPUT = "560.35.03, Tesla T4, 7.5\n12.8\n12.6\n"
17
+
18
+
19
+ @patch("aws_bootstrap.ssh.subprocess.run")
20
+ def test_query_gpu_info_success(mock_run):
21
+ """Successful nvidia-smi + nvcc output returns a valid GpuInfo."""
22
+ mock_run.return_value = subprocess.CompletedProcess(args=[], returncode=0, stdout=NVIDIA_SMI_OUTPUT, stderr="")
23
+
24
+ info = query_gpu_info("1.2.3.4", "ubuntu", Path("/home/user/.ssh/id_ed25519"))
25
+ assert info is not None
26
+ assert isinstance(info, GpuInfo)
27
+ assert info.driver_version == "560.35.03"
28
+ assert info.cuda_driver_version == "12.8"
29
+ assert info.cuda_toolkit_version == "12.6"
30
+ assert info.gpu_name == "Tesla T4"
31
+ assert info.compute_capability == "7.5"
32
+ assert info.architecture == "Turing"
33
+
34
+
35
+ @patch("aws_bootstrap.ssh.subprocess.run")
36
+ def test_query_gpu_info_no_nvcc(mock_run):
37
+ """When nvcc is unavailable, cuda_toolkit_version is None."""
38
+ output = "560.35.03, Tesla T4, 7.5\n12.8\nN/A\n"
39
+ mock_run.return_value = subprocess.CompletedProcess(args=[], returncode=0, stdout=output, stderr="")
40
+
41
+ info = query_gpu_info("1.2.3.4", "ubuntu", Path("/home/user/.ssh/id_ed25519"))
42
+ assert info is not None
43
+ assert info.cuda_driver_version == "12.8"
44
+ assert info.cuda_toolkit_version is None
45
+
46
+
47
+ @patch("aws_bootstrap.ssh.subprocess.run")
48
+ def test_query_gpu_info_ssh_failure(mock_run):
49
+ """Non-zero exit code returns None."""
50
+ mock_run.return_value = subprocess.CompletedProcess(args=[], returncode=255, stdout="", stderr="Connection refused")
51
+
52
+ info = query_gpu_info("1.2.3.4", "ubuntu", Path("/home/user/.ssh/id_ed25519"))
53
+ assert info is None
54
+
55
+
56
+ @patch("aws_bootstrap.ssh.subprocess.run", side_effect=subprocess.TimeoutExpired(cmd="ssh", timeout=15))
57
+ def test_query_gpu_info_timeout(mock_run):
58
+ """TimeoutExpired returns None."""
59
+ info = query_gpu_info("1.2.3.4", "ubuntu", Path("/home/user/.ssh/id_ed25519"))
60
+ assert info is None
61
+
62
+
63
+ @patch("aws_bootstrap.ssh.subprocess.run")
64
+ def test_query_gpu_info_malformed_output(mock_run):
65
+ """Garbage output returns None."""
66
+ mock_run.return_value = subprocess.CompletedProcess(
67
+ args=[], returncode=0, stdout="not valid gpu output\n", stderr=""
68
+ )
69
+
70
+ info = query_gpu_info("1.2.3.4", "ubuntu", Path("/home/user/.ssh/id_ed25519"))
71
+ assert info is None
72
+
73
+
74
+ # ---------------------------------------------------------------------------
75
+ # GPU architecture mapping
76
+ # ---------------------------------------------------------------------------
77
+
78
+
79
+ def test_gpu_architecture_mapping():
80
+ """Known compute capabilities map to correct architecture names."""
81
+ assert _GPU_ARCHITECTURES["7.5"] == "Turing"
82
+ assert _GPU_ARCHITECTURES["8.0"] == "Ampere"
83
+ assert _GPU_ARCHITECTURES["8.6"] == "Ampere"
84
+ assert _GPU_ARCHITECTURES["8.9"] == "Ada Lovelace"
85
+ assert _GPU_ARCHITECTURES["9.0"] == "Hopper"
86
+ assert _GPU_ARCHITECTURES["7.0"] == "Volta"
87
+
88
+
89
+ @patch("aws_bootstrap.ssh.subprocess.run")
90
+ def test_query_gpu_info_unknown_architecture(mock_run):
91
+ """Unknown compute capability produces a fallback architecture string."""
92
+ mock_run.return_value = subprocess.CompletedProcess(
93
+ args=[], returncode=0, stdout="550.00.00, Future GPU, 10.0\n13.0\n13.0\n", stderr=""
94
+ )
95
+
96
+ info = query_gpu_info("1.2.3.4", "ubuntu", Path("/home/user/.ssh/id_ed25519"))
97
+ assert info is not None
98
+ assert info.architecture == "Unknown (10.0)"
@@ -10,6 +10,7 @@ from aws_bootstrap.ssh import (
10
10
  _read_ssh_config,
11
11
  add_ssh_host,
12
12
  find_ssh_alias,
13
+ get_ssh_host_details,
13
14
  list_ssh_hosts,
14
15
  remove_ssh_host,
15
16
  )
@@ -295,3 +296,38 @@ def test_list_hosts_nonexistent_file(tmp_path):
295
296
  def test_remove_nonexistent_file(tmp_path):
296
297
  cfg = tmp_path / "no_such_file"
297
298
  assert remove_ssh_host("i-abc123", config_path=cfg) is None
299
+
300
+
301
+ # ---------------------------------------------------------------------------
302
+ # Port in stanza / details
303
+ # ---------------------------------------------------------------------------
304
+
305
+
306
+ def test_stanza_includes_port_when_non_default(tmp_path):
307
+ cfg = _config_path(tmp_path)
308
+ add_ssh_host("i-abc123", "1.2.3.4", "ubuntu", KEY_PATH, config_path=cfg, port=2222)
309
+ content = cfg.read_text()
310
+ assert "Port 2222" in content
311
+
312
+
313
+ def test_stanza_omits_port_when_default(tmp_path):
314
+ cfg = _config_path(tmp_path)
315
+ add_ssh_host("i-abc123", "1.2.3.4", "ubuntu", KEY_PATH, config_path=cfg)
316
+ content = cfg.read_text()
317
+ assert "Port" not in content
318
+
319
+
320
+ def test_get_ssh_host_details_parses_port(tmp_path):
321
+ cfg = _config_path(tmp_path)
322
+ add_ssh_host("i-abc123", "1.2.3.4", "ubuntu", KEY_PATH, config_path=cfg, port=2222)
323
+ details = get_ssh_host_details("i-abc123", config_path=cfg)
324
+ assert details is not None
325
+ assert details.port == 2222
326
+
327
+
328
+ def test_get_ssh_host_details_default_port(tmp_path):
329
+ cfg = _config_path(tmp_path)
330
+ add_ssh_host("i-abc123", "1.2.3.4", "ubuntu", KEY_PATH, config_path=cfg)
331
+ details = get_ssh_host_details("i-abc123", config_path=cfg)
332
+ assert details is not None
333
+ assert details.port == 22
@@ -1,16 +1,11 @@
1
- """Tests for GPU info queries via SSH (get_ssh_host_details, query_gpu_info)."""
1
+ """Tests for get_ssh_host_details (SSH config parsing)."""
2
2
 
3
3
  from __future__ import annotations
4
- import subprocess
5
4
  from pathlib import Path
6
- from unittest.mock import patch
7
5
 
8
6
  from aws_bootstrap.ssh import (
9
- _GPU_ARCHITECTURES,
10
- GpuInfo,
11
7
  add_ssh_host,
12
8
  get_ssh_host_details,
13
- query_gpu_info,
14
9
  )
15
10
 
16
11
 
@@ -47,92 +42,3 @@ def test_get_ssh_host_details_nonexistent_file(tmp_path):
47
42
  """Returns None when the SSH config file doesn't exist."""
48
43
  cfg = tmp_path / "no_such_file"
49
44
  assert get_ssh_host_details("i-abc123", config_path=cfg) is None
50
-
51
-
52
- # ---------------------------------------------------------------------------
53
- # query_gpu_info
54
- # ---------------------------------------------------------------------------
55
-
56
- NVIDIA_SMI_OUTPUT = "560.35.03, Tesla T4, 7.5\n12.8\n12.6\n"
57
-
58
-
59
- @patch("aws_bootstrap.ssh.subprocess.run")
60
- def test_query_gpu_info_success(mock_run):
61
- """Successful nvidia-smi + nvcc output returns a valid GpuInfo."""
62
- mock_run.return_value = subprocess.CompletedProcess(args=[], returncode=0, stdout=NVIDIA_SMI_OUTPUT, stderr="")
63
-
64
- info = query_gpu_info("1.2.3.4", "ubuntu", Path("/home/user/.ssh/id_ed25519"))
65
- assert info is not None
66
- assert isinstance(info, GpuInfo)
67
- assert info.driver_version == "560.35.03"
68
- assert info.cuda_driver_version == "12.8"
69
- assert info.cuda_toolkit_version == "12.6"
70
- assert info.gpu_name == "Tesla T4"
71
- assert info.compute_capability == "7.5"
72
- assert info.architecture == "Turing"
73
-
74
-
75
- @patch("aws_bootstrap.ssh.subprocess.run")
76
- def test_query_gpu_info_no_nvcc(mock_run):
77
- """When nvcc is unavailable, cuda_toolkit_version is None."""
78
- output = "560.35.03, Tesla T4, 7.5\n12.8\nN/A\n"
79
- mock_run.return_value = subprocess.CompletedProcess(args=[], returncode=0, stdout=output, stderr="")
80
-
81
- info = query_gpu_info("1.2.3.4", "ubuntu", Path("/home/user/.ssh/id_ed25519"))
82
- assert info is not None
83
- assert info.cuda_driver_version == "12.8"
84
- assert info.cuda_toolkit_version is None
85
-
86
-
87
- @patch("aws_bootstrap.ssh.subprocess.run")
88
- def test_query_gpu_info_ssh_failure(mock_run):
89
- """Non-zero exit code returns None."""
90
- mock_run.return_value = subprocess.CompletedProcess(args=[], returncode=255, stdout="", stderr="Connection refused")
91
-
92
- info = query_gpu_info("1.2.3.4", "ubuntu", Path("/home/user/.ssh/id_ed25519"))
93
- assert info is None
94
-
95
-
96
- @patch("aws_bootstrap.ssh.subprocess.run", side_effect=subprocess.TimeoutExpired(cmd="ssh", timeout=15))
97
- def test_query_gpu_info_timeout(mock_run):
98
- """TimeoutExpired returns None."""
99
- info = query_gpu_info("1.2.3.4", "ubuntu", Path("/home/user/.ssh/id_ed25519"))
100
- assert info is None
101
-
102
-
103
- @patch("aws_bootstrap.ssh.subprocess.run")
104
- def test_query_gpu_info_malformed_output(mock_run):
105
- """Garbage output returns None."""
106
- mock_run.return_value = subprocess.CompletedProcess(
107
- args=[], returncode=0, stdout="not valid gpu output\n", stderr=""
108
- )
109
-
110
- info = query_gpu_info("1.2.3.4", "ubuntu", Path("/home/user/.ssh/id_ed25519"))
111
- assert info is None
112
-
113
-
114
- # ---------------------------------------------------------------------------
115
- # GPU architecture mapping
116
- # ---------------------------------------------------------------------------
117
-
118
-
119
- def test_gpu_architecture_mapping():
120
- """Known compute capabilities map to correct architecture names."""
121
- assert _GPU_ARCHITECTURES["7.5"] == "Turing"
122
- assert _GPU_ARCHITECTURES["8.0"] == "Ampere"
123
- assert _GPU_ARCHITECTURES["8.6"] == "Ampere"
124
- assert _GPU_ARCHITECTURES["8.9"] == "Ada Lovelace"
125
- assert _GPU_ARCHITECTURES["9.0"] == "Hopper"
126
- assert _GPU_ARCHITECTURES["7.0"] == "Volta"
127
-
128
-
129
- @patch("aws_bootstrap.ssh.subprocess.run")
130
- def test_query_gpu_info_unknown_architecture(mock_run):
131
- """Unknown compute capability produces a fallback architecture string."""
132
- mock_run.return_value = subprocess.CompletedProcess(
133
- args=[], returncode=0, stdout="550.00.00, Future GPU, 10.0\n13.0\n13.0\n", stderr=""
134
- )
135
-
136
- info = query_gpu_info("1.2.3.4", "ubuntu", Path("/home/user/.ssh/id_ed25519"))
137
- assert info is not None
138
- assert info.architecture == "Unknown (10.0)"