aws-bootstrap-g4dn 0.1.0__py3-none-any.whl → 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aws_bootstrap/cli.py +109 -12
- aws_bootstrap/config.py +2 -0
- aws_bootstrap/ec2.py +3 -3
- aws_bootstrap/gpu.py +27 -0
- aws_bootstrap/resources/remote_setup.sh +7 -1
- aws_bootstrap/ssh.py +47 -47
- aws_bootstrap/tests/test_cli.py +315 -7
- aws_bootstrap/tests/test_gpu.py +98 -0
- aws_bootstrap/tests/test_ssh_config.py +36 -0
- aws_bootstrap/tests/test_ssh_gpu.py +1 -95
- {aws_bootstrap_g4dn-0.1.0.dist-info → aws_bootstrap_g4dn-0.3.0.dist-info}/METADATA +27 -5
- aws_bootstrap_g4dn-0.3.0.dist-info/RECORD +24 -0
- aws_bootstrap_g4dn-0.1.0.dist-info/RECORD +0 -22
- {aws_bootstrap_g4dn-0.1.0.dist-info → aws_bootstrap_g4dn-0.3.0.dist-info}/WHEEL +0 -0
- {aws_bootstrap_g4dn-0.1.0.dist-info → aws_bootstrap_g4dn-0.3.0.dist-info}/entry_points.txt +0 -0
- {aws_bootstrap_g4dn-0.1.0.dist-info → aws_bootstrap_g4dn-0.3.0.dist-info}/licenses/LICENSE +0 -0
- {aws_bootstrap_g4dn-0.1.0.dist-info → aws_bootstrap_g4dn-0.3.0.dist-info}/top_level.txt +0 -0
aws_bootstrap/cli.py
CHANGED
|
@@ -5,6 +5,7 @@ from datetime import UTC, datetime
|
|
|
5
5
|
from pathlib import Path
|
|
6
6
|
|
|
7
7
|
import boto3
|
|
8
|
+
import botocore.exceptions
|
|
8
9
|
import click
|
|
9
10
|
|
|
10
11
|
from .config import LaunchConfig
|
|
@@ -56,7 +57,39 @@ def warn(msg: str) -> None:
|
|
|
56
57
|
click.secho(f" WARNING: {msg}", fg="yellow", err=True)
|
|
57
58
|
|
|
58
59
|
|
|
59
|
-
|
|
60
|
+
class _AWSGroup(click.Group):
|
|
61
|
+
"""Click group that catches common AWS credential/auth errors."""
|
|
62
|
+
|
|
63
|
+
def invoke(self, ctx):
|
|
64
|
+
try:
|
|
65
|
+
return super().invoke(ctx)
|
|
66
|
+
except botocore.exceptions.NoCredentialsError:
|
|
67
|
+
raise CLIError(
|
|
68
|
+
"Unable to locate AWS credentials.\n\n"
|
|
69
|
+
" Make sure you have configured AWS credentials using one of:\n"
|
|
70
|
+
" - Set the AWS_PROFILE environment variable: export AWS_PROFILE=<profile-name>\n"
|
|
71
|
+
" - Pass --profile to the command: aws-bootstrap <command> --profile <profile-name>\n"
|
|
72
|
+
" - Configure a default profile: aws configure\n\n"
|
|
73
|
+
" See: https://docs.aws.amazon.com/cli/latest/userguide/cli-configure-files.html"
|
|
74
|
+
) from None
|
|
75
|
+
except botocore.exceptions.ProfileNotFound as e:
|
|
76
|
+
raise CLIError(f"{e}\n\n List available profiles with: aws configure list-profiles") from None
|
|
77
|
+
except botocore.exceptions.PartialCredentialsError as e:
|
|
78
|
+
raise CLIError(
|
|
79
|
+
f"Incomplete AWS credentials: {e}\n\n Check your AWS configuration with: aws configure list"
|
|
80
|
+
) from None
|
|
81
|
+
except botocore.exceptions.ClientError as e:
|
|
82
|
+
code = e.response["Error"]["Code"]
|
|
83
|
+
if code in ("AuthFailure", "UnauthorizedOperation", "ExpiredTokenException", "ExpiredToken"):
|
|
84
|
+
raise CLIError(
|
|
85
|
+
f"AWS authorization failed: {e.response['Error']['Message']}\n\n"
|
|
86
|
+
" Your credentials may be expired or lack the required permissions.\n"
|
|
87
|
+
" Check your AWS configuration with: aws configure list"
|
|
88
|
+
) from None
|
|
89
|
+
raise
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
@click.group(cls=_AWSGroup)
|
|
60
93
|
@click.version_option(package_name="aws-bootstrap-g4dn")
|
|
61
94
|
def main():
|
|
62
95
|
"""Bootstrap AWS EC2 GPU instances for hybrid local-remote development."""
|
|
@@ -80,6 +113,12 @@ def main():
|
|
|
80
113
|
@click.option("--no-setup", is_flag=True, default=False, help="Skip running the remote setup script.")
|
|
81
114
|
@click.option("--dry-run", is_flag=True, default=False, help="Show what would be done without executing.")
|
|
82
115
|
@click.option("--profile", default=None, help="AWS profile override (defaults to AWS_PROFILE env var).")
|
|
116
|
+
@click.option(
|
|
117
|
+
"--python-version",
|
|
118
|
+
default=None,
|
|
119
|
+
help="Python version for the remote venv (e.g. 3.13, 3.14.2). Passed to uv during setup.",
|
|
120
|
+
)
|
|
121
|
+
@click.option("--ssh-port", default=22, show_default=True, type=int, help="SSH port on the remote instance.")
|
|
83
122
|
def launch(
|
|
84
123
|
instance_type,
|
|
85
124
|
ami_filter,
|
|
@@ -92,6 +131,8 @@ def launch(
|
|
|
92
131
|
no_setup,
|
|
93
132
|
dry_run,
|
|
94
133
|
profile,
|
|
134
|
+
python_version,
|
|
135
|
+
ssh_port,
|
|
95
136
|
):
|
|
96
137
|
"""Launch a GPU-accelerated EC2 instance."""
|
|
97
138
|
config = LaunchConfig(
|
|
@@ -104,6 +145,8 @@ def launch(
|
|
|
104
145
|
volume_size=volume_size,
|
|
105
146
|
run_setup=not no_setup,
|
|
106
147
|
dry_run=dry_run,
|
|
148
|
+
ssh_port=ssh_port,
|
|
149
|
+
python_version=python_version,
|
|
107
150
|
)
|
|
108
151
|
if ami_filter:
|
|
109
152
|
config.ami_filter = ami_filter
|
|
@@ -130,7 +173,7 @@ def launch(
|
|
|
130
173
|
|
|
131
174
|
# Step 3: Security group
|
|
132
175
|
step(3, 6, "Ensuring security group...")
|
|
133
|
-
sg_id = ensure_security_group(ec2, config.security_group, config.tag_value)
|
|
176
|
+
sg_id = ensure_security_group(ec2, config.security_group, config.tag_value, ssh_port=config.ssh_port)
|
|
134
177
|
|
|
135
178
|
pricing = "spot" if config.spot else "on-demand"
|
|
136
179
|
|
|
@@ -145,6 +188,10 @@ def launch(
|
|
|
145
188
|
val("Volume", f"{config.volume_size} GB gp3")
|
|
146
189
|
val("Region", config.region)
|
|
147
190
|
val("Remote setup", "yes" if config.run_setup else "no")
|
|
191
|
+
if config.ssh_port != 22:
|
|
192
|
+
val("SSH port", str(config.ssh_port))
|
|
193
|
+
if config.python_version:
|
|
194
|
+
val("Python version", config.python_version)
|
|
148
195
|
click.echo()
|
|
149
196
|
click.secho("No resources launched (dry-run mode).", fg="yellow")
|
|
150
197
|
return
|
|
@@ -169,9 +216,13 @@ def launch(
|
|
|
169
216
|
# Step 6: SSH and remote setup
|
|
170
217
|
step(6, 6, "Waiting for SSH access...")
|
|
171
218
|
private_key = private_key_path(config.key_path)
|
|
172
|
-
if not wait_for_ssh(public_ip, config.ssh_user, config.key_path):
|
|
219
|
+
if not wait_for_ssh(public_ip, config.ssh_user, config.key_path, port=config.ssh_port):
|
|
173
220
|
warn("SSH did not become available within the timeout.")
|
|
174
|
-
|
|
221
|
+
port_flag = f" -p {config.ssh_port}" if config.ssh_port != 22 else ""
|
|
222
|
+
info(
|
|
223
|
+
f"Instance is running — try connecting manually:"
|
|
224
|
+
f" ssh -i {private_key}{port_flag} {config.ssh_user}@{public_ip}"
|
|
225
|
+
)
|
|
175
226
|
return
|
|
176
227
|
|
|
177
228
|
if config.run_setup:
|
|
@@ -179,7 +230,9 @@ def launch(
|
|
|
179
230
|
warn(f"Setup script not found at {SETUP_SCRIPT}, skipping.")
|
|
180
231
|
else:
|
|
181
232
|
info("Running remote setup...")
|
|
182
|
-
if run_remote_setup(
|
|
233
|
+
if run_remote_setup(
|
|
234
|
+
public_ip, config.ssh_user, config.key_path, SETUP_SCRIPT, config.python_version, port=config.ssh_port
|
|
235
|
+
):
|
|
183
236
|
success("Remote setup completed successfully.")
|
|
184
237
|
else:
|
|
185
238
|
warn("Remote setup failed. Instance is still running.")
|
|
@@ -191,6 +244,7 @@ def launch(
|
|
|
191
244
|
user=config.ssh_user,
|
|
192
245
|
key_path=config.key_path,
|
|
193
246
|
alias_prefix=config.alias_prefix,
|
|
247
|
+
port=config.ssh_port,
|
|
194
248
|
)
|
|
195
249
|
success(f"Added SSH config alias: {alias}")
|
|
196
250
|
|
|
@@ -206,18 +260,27 @@ def launch(
|
|
|
206
260
|
val("Pricing", pricing)
|
|
207
261
|
val("SSH alias", alias)
|
|
208
262
|
|
|
263
|
+
port_flag = f" -p {config.ssh_port}" if config.ssh_port != 22 else ""
|
|
264
|
+
|
|
209
265
|
click.echo()
|
|
210
266
|
click.secho(" SSH:", fg="cyan")
|
|
211
|
-
click.secho(f" ssh {alias}", bold=True)
|
|
212
|
-
info(f"or: ssh -i {private_key} {config.ssh_user}@{public_ip}")
|
|
267
|
+
click.secho(f" ssh{port_flag} {alias}", bold=True)
|
|
268
|
+
info(f"or: ssh -i {private_key}{port_flag} {config.ssh_user}@{public_ip}")
|
|
213
269
|
|
|
214
270
|
click.echo()
|
|
215
271
|
click.secho(" Jupyter (via SSH tunnel):", fg="cyan")
|
|
216
|
-
click.secho(f" ssh -NL 8888:localhost:8888 {alias}", bold=True)
|
|
217
|
-
info(f"or: ssh -i {private_key} -NL 8888:localhost:8888 {config.ssh_user}@{public_ip}")
|
|
272
|
+
click.secho(f" ssh -NL 8888:localhost:8888{port_flag} {alias}", bold=True)
|
|
273
|
+
info(f"or: ssh -i {private_key} -NL 8888:localhost:8888{port_flag} {config.ssh_user}@{public_ip}")
|
|
218
274
|
info("Then open: http://localhost:8888")
|
|
219
275
|
info("Notebook: ~/gpu_smoke_test.ipynb (GPU smoke test)")
|
|
220
276
|
|
|
277
|
+
click.echo()
|
|
278
|
+
click.secho(" VSCode Remote SSH:", fg="cyan")
|
|
279
|
+
click.secho(
|
|
280
|
+
f" code --folder-uri vscode-remote://ssh-remote+{alias}/home/{config.ssh_user}",
|
|
281
|
+
bold=True,
|
|
282
|
+
)
|
|
283
|
+
|
|
221
284
|
click.echo()
|
|
222
285
|
click.secho(" GPU Benchmark:", fg="cyan")
|
|
223
286
|
click.secho(f" ssh {alias} 'python ~/gpu_benchmark.py'", bold=True)
|
|
@@ -233,7 +296,14 @@ def launch(
|
|
|
233
296
|
@click.option("--region", default="us-west-2", show_default=True, help="AWS region.")
|
|
234
297
|
@click.option("--profile", default=None, help="AWS profile override.")
|
|
235
298
|
@click.option("--gpu", is_flag=True, default=False, help="Query GPU info (CUDA, driver) via SSH.")
|
|
236
|
-
|
|
299
|
+
@click.option(
|
|
300
|
+
"--instructions/--no-instructions",
|
|
301
|
+
"-I",
|
|
302
|
+
default=True,
|
|
303
|
+
show_default=True,
|
|
304
|
+
help="Show connection commands (SSH, Jupyter, VSCode) for each running instance.",
|
|
305
|
+
)
|
|
306
|
+
def status(region, profile, gpu, instructions):
|
|
237
307
|
"""Show running instances created by aws-bootstrap."""
|
|
238
308
|
session = boto3.Session(profile_name=profile, region_name=region)
|
|
239
309
|
ec2 = session.client("ec2")
|
|
@@ -272,11 +342,15 @@ def status(region, profile, gpu):
|
|
|
272
342
|
if inst["PublicIp"]:
|
|
273
343
|
val(" IP", inst["PublicIp"])
|
|
274
344
|
|
|
345
|
+
# Look up SSH config details once (used by --gpu and --with-instructions)
|
|
346
|
+
details = None
|
|
347
|
+
if (gpu or instructions) and state == "running" and inst["PublicIp"]:
|
|
348
|
+
details = get_ssh_host_details(inst["InstanceId"])
|
|
349
|
+
|
|
275
350
|
# GPU info (opt-in, only for running instances with a public IP)
|
|
276
351
|
if gpu and state == "running" and inst["PublicIp"]:
|
|
277
|
-
details = get_ssh_host_details(inst["InstanceId"])
|
|
278
352
|
if details:
|
|
279
|
-
gpu_info = query_gpu_info(details.hostname, details.user, details.identity_file)
|
|
353
|
+
gpu_info = query_gpu_info(details.hostname, details.user, details.identity_file, port=details.port)
|
|
280
354
|
else:
|
|
281
355
|
gpu_info = query_gpu_info(
|
|
282
356
|
inst["PublicIp"],
|
|
@@ -320,6 +394,29 @@ def status(region, profile, gpu):
|
|
|
320
394
|
val(" Est. cost", f"~${est_cost:.4f}")
|
|
321
395
|
|
|
322
396
|
val(" Launched", str(inst["LaunchTime"]))
|
|
397
|
+
|
|
398
|
+
# Connection instructions (opt-in, only for running instances with a public IP and alias)
|
|
399
|
+
if instructions and state == "running" and inst["PublicIp"] and alias:
|
|
400
|
+
user = details.user if details else "ubuntu"
|
|
401
|
+
port = details.port if details else 22
|
|
402
|
+
port_flag = f" -p {port}" if port != 22 else ""
|
|
403
|
+
|
|
404
|
+
click.echo()
|
|
405
|
+
click.secho(" SSH:", fg="cyan")
|
|
406
|
+
click.secho(f" ssh{port_flag} {alias}", bold=True)
|
|
407
|
+
|
|
408
|
+
click.secho(" Jupyter (via SSH tunnel):", fg="cyan")
|
|
409
|
+
click.secho(f" ssh -NL 8888:localhost:8888{port_flag} {alias}", bold=True)
|
|
410
|
+
|
|
411
|
+
click.secho(" VSCode Remote SSH:", fg="cyan")
|
|
412
|
+
click.secho(
|
|
413
|
+
f" code --folder-uri vscode-remote://ssh-remote+{alias}/home/{user}",
|
|
414
|
+
bold=True,
|
|
415
|
+
)
|
|
416
|
+
|
|
417
|
+
click.secho(" GPU Benchmark:", fg="cyan")
|
|
418
|
+
click.secho(f" ssh {alias} 'python ~/gpu_benchmark.py'", bold=True)
|
|
419
|
+
|
|
323
420
|
click.echo()
|
|
324
421
|
first_id = instances[0]["InstanceId"]
|
|
325
422
|
click.echo(" To terminate: " + click.style(f"aws-bootstrap terminate {first_id}", bold=True))
|
aws_bootstrap/config.py
CHANGED
aws_bootstrap/ec2.py
CHANGED
|
@@ -59,7 +59,7 @@ def get_latest_ami(ec2_client, ami_filter: str) -> dict:
|
|
|
59
59
|
return images[0]
|
|
60
60
|
|
|
61
61
|
|
|
62
|
-
def ensure_security_group(ec2_client, name: str, tag_value: str) -> str:
|
|
62
|
+
def ensure_security_group(ec2_client, name: str, tag_value: str, ssh_port: int = 22) -> str:
|
|
63
63
|
"""Find or create a security group with SSH ingress in the default VPC."""
|
|
64
64
|
# Find default VPC
|
|
65
65
|
vpcs = ec2_client.describe_vpcs(Filters=[{"Name": "isDefault", "Values": ["true"]}])
|
|
@@ -103,8 +103,8 @@ def ensure_security_group(ec2_client, name: str, tag_value: str) -> str:
|
|
|
103
103
|
IpPermissions=[
|
|
104
104
|
{
|
|
105
105
|
"IpProtocol": "tcp",
|
|
106
|
-
"FromPort":
|
|
107
|
-
"ToPort":
|
|
106
|
+
"FromPort": ssh_port,
|
|
107
|
+
"ToPort": ssh_port,
|
|
108
108
|
"IpRanges": [{"CidrIp": "0.0.0.0/0", "Description": "SSH access"}],
|
|
109
109
|
}
|
|
110
110
|
],
|
aws_bootstrap/gpu.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
"""GPU architecture mapping and GPU info dataclass."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
from dataclasses import dataclass
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
_GPU_ARCHITECTURES: dict[str, str] = {
|
|
8
|
+
"7.0": "Volta",
|
|
9
|
+
"7.5": "Turing",
|
|
10
|
+
"8.0": "Ampere",
|
|
11
|
+
"8.6": "Ampere",
|
|
12
|
+
"8.7": "Ampere",
|
|
13
|
+
"8.9": "Ada Lovelace",
|
|
14
|
+
"9.0": "Hopper",
|
|
15
|
+
}
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
@dataclass
|
|
19
|
+
class GpuInfo:
|
|
20
|
+
"""GPU information retrieved via nvidia-smi and nvcc."""
|
|
21
|
+
|
|
22
|
+
driver_version: str
|
|
23
|
+
cuda_driver_version: str # max CUDA version supported by driver (from nvidia-smi)
|
|
24
|
+
cuda_toolkit_version: str | None # actual CUDA toolkit installed (from nvcc), None if unavailable
|
|
25
|
+
gpu_name: str
|
|
26
|
+
compute_capability: str
|
|
27
|
+
architecture: str
|
|
@@ -34,7 +34,13 @@ if ! command -v uv &>/dev/null; then
|
|
|
34
34
|
fi
|
|
35
35
|
export PATH="$HOME/.local/bin:$PATH"
|
|
36
36
|
|
|
37
|
-
|
|
37
|
+
if [ -n "${PYTHON_VERSION:-}" ]; then
|
|
38
|
+
echo " Installing Python ${PYTHON_VERSION}..."
|
|
39
|
+
uv python install "$PYTHON_VERSION"
|
|
40
|
+
uv venv --python "$PYTHON_VERSION" ~/venv
|
|
41
|
+
else
|
|
42
|
+
uv venv ~/venv
|
|
43
|
+
fi
|
|
38
44
|
|
|
39
45
|
# --- CUDA-aware PyTorch installation ---
|
|
40
46
|
# Known PyTorch CUDA wheel tags (ascending order).
|
aws_bootstrap/ssh.py
CHANGED
|
@@ -12,6 +12,8 @@ from pathlib import Path
|
|
|
12
12
|
|
|
13
13
|
import click
|
|
14
14
|
|
|
15
|
+
from .gpu import _GPU_ARCHITECTURES, GpuInfo
|
|
16
|
+
|
|
15
17
|
|
|
16
18
|
# ---------------------------------------------------------------------------
|
|
17
19
|
# SSH config markers
|
|
@@ -72,17 +74,18 @@ def import_key_pair(ec2_client, key_name: str, key_path: Path) -> str:
|
|
|
72
74
|
return key_name
|
|
73
75
|
|
|
74
76
|
|
|
75
|
-
def wait_for_ssh(host: str, user: str, key_path: Path, retries: int = 30, delay: int = 10) -> bool:
|
|
77
|
+
def wait_for_ssh(host: str, user: str, key_path: Path, retries: int = 30, delay: int = 10, port: int = 22) -> bool:
|
|
76
78
|
"""Wait for SSH to become available on the instance.
|
|
77
79
|
|
|
78
|
-
Tries a TCP connection to port
|
|
80
|
+
Tries a TCP connection to the SSH port first, then an actual SSH command.
|
|
79
81
|
"""
|
|
80
82
|
base_opts = _ssh_opts(key_path)
|
|
83
|
+
port_opts = ["-p", str(port)] if port != 22 else []
|
|
81
84
|
|
|
82
85
|
for attempt in range(1, retries + 1):
|
|
83
|
-
# First check if port
|
|
86
|
+
# First check if the SSH port is open
|
|
84
87
|
try:
|
|
85
|
-
sock = socket.create_connection((host,
|
|
88
|
+
sock = socket.create_connection((host, port), timeout=5)
|
|
86
89
|
sock.close()
|
|
87
90
|
except (TimeoutError, ConnectionRefusedError, OSError):
|
|
88
91
|
click.echo(" SSH not ready " + click.style(f"(attempt {attempt}/{retries})", dim=True) + ", waiting...")
|
|
@@ -90,11 +93,18 @@ def wait_for_ssh(host: str, user: str, key_path: Path, retries: int = 30, delay:
|
|
|
90
93
|
continue
|
|
91
94
|
|
|
92
95
|
# Port is open, try actual SSH
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
96
|
+
cmd = [
|
|
97
|
+
"ssh",
|
|
98
|
+
*base_opts,
|
|
99
|
+
*port_opts,
|
|
100
|
+
"-o",
|
|
101
|
+
"ConnectTimeout=10",
|
|
102
|
+
"-o",
|
|
103
|
+
"BatchMode=yes",
|
|
104
|
+
f"{user}@{host}",
|
|
105
|
+
"echo ok",
|
|
106
|
+
]
|
|
107
|
+
result = subprocess.run(cmd, capture_output=True, text=True)
|
|
98
108
|
if result.returncode == 0:
|
|
99
109
|
click.secho(" SSH connection established.", fg="green")
|
|
100
110
|
return True
|
|
@@ -105,15 +115,19 @@ def wait_for_ssh(host: str, user: str, key_path: Path, retries: int = 30, delay:
|
|
|
105
115
|
return False
|
|
106
116
|
|
|
107
117
|
|
|
108
|
-
def run_remote_setup(
|
|
118
|
+
def run_remote_setup(
|
|
119
|
+
host: str, user: str, key_path: Path, script_path: Path, python_version: str | None = None, port: int = 22
|
|
120
|
+
) -> bool:
|
|
109
121
|
"""SCP the setup script and requirements.txt to the instance and execute."""
|
|
110
122
|
ssh_opts = _ssh_opts(key_path)
|
|
123
|
+
scp_port_opts = ["-P", str(port)] if port != 22 else []
|
|
124
|
+
ssh_port_opts = ["-p", str(port)] if port != 22 else []
|
|
111
125
|
requirements_path = script_path.parent / "requirements.txt"
|
|
112
126
|
|
|
113
127
|
# SCP the requirements file
|
|
114
128
|
click.echo(" Uploading requirements.txt...")
|
|
115
129
|
req_result = subprocess.run(
|
|
116
|
-
["scp", *ssh_opts, str(requirements_path), f"{user}@{host}:/tmp/requirements.txt"],
|
|
130
|
+
["scp", *ssh_opts, *scp_port_opts, str(requirements_path), f"{user}@{host}:/tmp/requirements.txt"],
|
|
117
131
|
capture_output=True,
|
|
118
132
|
text=True,
|
|
119
133
|
)
|
|
@@ -125,7 +139,7 @@ def run_remote_setup(host: str, user: str, key_path: Path, script_path: Path) ->
|
|
|
125
139
|
benchmark_path = script_path.parent / "gpu_benchmark.py"
|
|
126
140
|
click.echo(" Uploading gpu_benchmark.py...")
|
|
127
141
|
bench_result = subprocess.run(
|
|
128
|
-
["scp", *ssh_opts, str(benchmark_path), f"{user}@{host}:/tmp/gpu_benchmark.py"],
|
|
142
|
+
["scp", *ssh_opts, *scp_port_opts, str(benchmark_path), f"{user}@{host}:/tmp/gpu_benchmark.py"],
|
|
129
143
|
capture_output=True,
|
|
130
144
|
text=True,
|
|
131
145
|
)
|
|
@@ -137,7 +151,7 @@ def run_remote_setup(host: str, user: str, key_path: Path, script_path: Path) ->
|
|
|
137
151
|
notebook_path = script_path.parent / "gpu_smoke_test.ipynb"
|
|
138
152
|
click.echo(" Uploading gpu_smoke_test.ipynb...")
|
|
139
153
|
nb_result = subprocess.run(
|
|
140
|
-
["scp", *ssh_opts, str(notebook_path), f"{user}@{host}:/tmp/gpu_smoke_test.ipynb"],
|
|
154
|
+
["scp", *ssh_opts, *scp_port_opts, str(notebook_path), f"{user}@{host}:/tmp/gpu_smoke_test.ipynb"],
|
|
141
155
|
capture_output=True,
|
|
142
156
|
text=True,
|
|
143
157
|
)
|
|
@@ -148,7 +162,7 @@ def run_remote_setup(host: str, user: str, key_path: Path, script_path: Path) ->
|
|
|
148
162
|
# SCP the script
|
|
149
163
|
click.echo(" Uploading remote_setup.sh...")
|
|
150
164
|
scp_result = subprocess.run(
|
|
151
|
-
["scp", *ssh_opts, str(script_path), f"{user}@{host}:/tmp/remote_setup.sh"],
|
|
165
|
+
["scp", *ssh_opts, *scp_port_opts, str(script_path), f"{user}@{host}:/tmp/remote_setup.sh"],
|
|
152
166
|
capture_output=True,
|
|
153
167
|
text=True,
|
|
154
168
|
)
|
|
@@ -156,10 +170,14 @@ def run_remote_setup(host: str, user: str, key_path: Path, script_path: Path) ->
|
|
|
156
170
|
click.secho(f" SCP failed: {scp_result.stderr}", fg="red", err=True)
|
|
157
171
|
return False
|
|
158
172
|
|
|
159
|
-
# Execute the script
|
|
173
|
+
# Execute the script, passing PYTHON_VERSION as an inline env var if specified
|
|
160
174
|
click.echo(" Running remote_setup.sh on instance...")
|
|
175
|
+
remote_cmd = "chmod +x /tmp/remote_setup.sh && "
|
|
176
|
+
if python_version:
|
|
177
|
+
remote_cmd += f"PYTHON_VERSION={python_version} "
|
|
178
|
+
remote_cmd += "/tmp/remote_setup.sh"
|
|
161
179
|
ssh_result = subprocess.run(
|
|
162
|
-
["ssh", *ssh_opts, f"{user}@{host}",
|
|
180
|
+
["ssh", *ssh_opts, *ssh_port_opts, f"{user}@{host}", remote_cmd],
|
|
163
181
|
capture_output=False,
|
|
164
182
|
)
|
|
165
183
|
return ssh_result.returncode == 0
|
|
@@ -222,15 +240,17 @@ def _next_alias(content: str, prefix: str = "aws-gpu") -> str:
|
|
|
222
240
|
return f"{prefix}{max_n + 1}"
|
|
223
241
|
|
|
224
242
|
|
|
225
|
-
def _build_stanza(instance_id: str, alias: str, hostname: str, user: str, key_path: Path) -> str:
|
|
243
|
+
def _build_stanza(instance_id: str, alias: str, hostname: str, user: str, key_path: Path, port: int = 22) -> str:
|
|
226
244
|
"""Build a complete SSH config stanza with markers."""
|
|
227
245
|
priv_key = private_key_path(key_path)
|
|
246
|
+
port_line = f" Port {port}\n" if port != 22 else ""
|
|
228
247
|
return (
|
|
229
248
|
f"{_BEGIN_MARKER.format(instance_id=instance_id)}\n"
|
|
230
249
|
f"Host {alias}\n"
|
|
231
250
|
f" HostName {hostname}\n"
|
|
232
251
|
f" User {user}\n"
|
|
233
252
|
f" IdentityFile {priv_key}\n"
|
|
253
|
+
f"{port_line}"
|
|
234
254
|
f" StrictHostKeyChecking no\n"
|
|
235
255
|
f" UserKnownHostsFile /dev/null\n"
|
|
236
256
|
f"{_END_MARKER.format(instance_id=instance_id)}\n"
|
|
@@ -244,6 +264,7 @@ def add_ssh_host(
|
|
|
244
264
|
key_path: Path,
|
|
245
265
|
config_path: Path | None = None,
|
|
246
266
|
alias_prefix: str = "aws-gpu",
|
|
267
|
+
port: int = 22,
|
|
247
268
|
) -> str:
|
|
248
269
|
"""Add (or update) an SSH host stanza for *instance_id*.
|
|
249
270
|
|
|
@@ -257,7 +278,7 @@ def add_ssh_host(
|
|
|
257
278
|
content = _remove_block(content, instance_id)
|
|
258
279
|
|
|
259
280
|
alias = existing_alias or _next_alias(content, alias_prefix)
|
|
260
|
-
stanza = _build_stanza(instance_id, alias, hostname, user, key_path)
|
|
281
|
+
stanza = _build_stanza(instance_id, alias, hostname, user, key_path, port=port)
|
|
261
282
|
|
|
262
283
|
# Ensure a blank line before our block if file has content
|
|
263
284
|
if content and not content.endswith("\n\n") and not content.endswith("\n"):
|
|
@@ -317,21 +338,6 @@ def list_ssh_hosts(config_path: Path | None = None) -> dict[str, str]:
|
|
|
317
338
|
return result
|
|
318
339
|
|
|
319
340
|
|
|
320
|
-
# ---------------------------------------------------------------------------
|
|
321
|
-
# GPU info via SSH
|
|
322
|
-
# ---------------------------------------------------------------------------
|
|
323
|
-
|
|
324
|
-
_GPU_ARCHITECTURES: dict[str, str] = {
|
|
325
|
-
"7.0": "Volta",
|
|
326
|
-
"7.5": "Turing",
|
|
327
|
-
"8.0": "Ampere",
|
|
328
|
-
"8.6": "Ampere",
|
|
329
|
-
"8.7": "Ampere",
|
|
330
|
-
"8.9": "Ada Lovelace",
|
|
331
|
-
"9.0": "Hopper",
|
|
332
|
-
}
|
|
333
|
-
|
|
334
|
-
|
|
335
341
|
@dataclass
|
|
336
342
|
class SSHHostDetails:
|
|
337
343
|
"""Connection details parsed from an SSH config stanza."""
|
|
@@ -339,18 +345,7 @@ class SSHHostDetails:
|
|
|
339
345
|
hostname: str
|
|
340
346
|
user: str
|
|
341
347
|
identity_file: Path
|
|
342
|
-
|
|
343
|
-
|
|
344
|
-
@dataclass
|
|
345
|
-
class GpuInfo:
|
|
346
|
-
"""GPU information retrieved via nvidia-smi and nvcc."""
|
|
347
|
-
|
|
348
|
-
driver_version: str
|
|
349
|
-
cuda_driver_version: str # max CUDA version supported by driver (from nvidia-smi)
|
|
350
|
-
cuda_toolkit_version: str | None # actual CUDA toolkit installed (from nvcc), None if unavailable
|
|
351
|
-
gpu_name: str
|
|
352
|
-
compute_capability: str
|
|
353
|
-
architecture: str
|
|
348
|
+
port: int = 22
|
|
354
349
|
|
|
355
350
|
|
|
356
351
|
def get_ssh_host_details(instance_id: str, config_path: Path | None = None) -> SSHHostDetails | None:
|
|
@@ -371,6 +366,7 @@ def get_ssh_host_details(instance_id: str, config_path: Path | None = None) -> S
|
|
|
371
366
|
hostname: str | None = None
|
|
372
367
|
user: str | None = None
|
|
373
368
|
identity_file: str | None = None
|
|
369
|
+
port: int = 22
|
|
374
370
|
|
|
375
371
|
for line in content.splitlines():
|
|
376
372
|
if line == begin_marker:
|
|
@@ -378,7 +374,7 @@ def get_ssh_host_details(instance_id: str, config_path: Path | None = None) -> S
|
|
|
378
374
|
continue
|
|
379
375
|
if line == end_marker and in_block:
|
|
380
376
|
if hostname and user and identity_file:
|
|
381
|
-
return SSHHostDetails(hostname=hostname, user=user, identity_file=Path(identity_file))
|
|
377
|
+
return SSHHostDetails(hostname=hostname, user=user, identity_file=Path(identity_file), port=port)
|
|
382
378
|
return None
|
|
383
379
|
if in_block:
|
|
384
380
|
stripped = line.strip()
|
|
@@ -388,17 +384,20 @@ def get_ssh_host_details(instance_id: str, config_path: Path | None = None) -> S
|
|
|
388
384
|
user = stripped.removeprefix("User ").strip()
|
|
389
385
|
elif stripped.startswith("IdentityFile "):
|
|
390
386
|
identity_file = stripped.removeprefix("IdentityFile ").strip()
|
|
387
|
+
elif stripped.startswith("Port "):
|
|
388
|
+
port = int(stripped.removeprefix("Port ").strip())
|
|
391
389
|
|
|
392
390
|
return None
|
|
393
391
|
|
|
394
392
|
|
|
395
|
-
def query_gpu_info(host: str, user: str, key_path: Path, timeout: int = 10) -> GpuInfo | None:
|
|
393
|
+
def query_gpu_info(host: str, user: str, key_path: Path, timeout: int = 10, port: int = 22) -> GpuInfo | None:
|
|
396
394
|
"""SSH into a host and query GPU info via ``nvidia-smi``.
|
|
397
395
|
|
|
398
396
|
Returns ``GpuInfo`` on success, or ``None`` if the SSH connection fails,
|
|
399
397
|
``nvidia-smi`` is unavailable, or the output is malformed.
|
|
400
398
|
"""
|
|
401
399
|
ssh_opts = _ssh_opts(key_path)
|
|
400
|
+
port_opts = ["-p", str(port)] if port != 22 else []
|
|
402
401
|
remote_cmd = (
|
|
403
402
|
"nvidia-smi --query-gpu=driver_version,name,compute_cap --format=csv,noheader,nounits"
|
|
404
403
|
" && nvidia-smi | grep -oP 'CUDA Version: \\K[\\d.]+'"
|
|
@@ -407,6 +406,7 @@ def query_gpu_info(host: str, user: str, key_path: Path, timeout: int = 10) -> G
|
|
|
407
406
|
cmd = [
|
|
408
407
|
"ssh",
|
|
409
408
|
*ssh_opts,
|
|
409
|
+
*port_opts,
|
|
410
410
|
"-o",
|
|
411
411
|
f"ConnectTimeout={timeout}",
|
|
412
412
|
"-o",
|