aws-bootstrap-g4dn 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1 @@
1
+ """aws-bootstrap-g4dn: Bootstrap AWS EC2 GPU instances for hybrid development."""
aws_bootstrap/cli.py ADDED
@@ -0,0 +1,438 @@
1
+ """CLI entry point for aws-bootstrap-g4dn."""
2
+
3
+ from __future__ import annotations
4
+ from datetime import UTC, datetime
5
+ from pathlib import Path
6
+
7
+ import boto3
8
+ import click
9
+
10
+ from .config import LaunchConfig
11
+ from .ec2 import (
12
+ CLIError,
13
+ ensure_security_group,
14
+ find_tagged_instances,
15
+ get_latest_ami,
16
+ get_spot_price,
17
+ launch_instance,
18
+ list_amis,
19
+ list_instance_types,
20
+ terminate_tagged_instances,
21
+ wait_instance_ready,
22
+ )
23
+ from .ssh import (
24
+ add_ssh_host,
25
+ get_ssh_host_details,
26
+ import_key_pair,
27
+ list_ssh_hosts,
28
+ private_key_path,
29
+ query_gpu_info,
30
+ remove_ssh_host,
31
+ run_remote_setup,
32
+ wait_for_ssh,
33
+ )
34
+
35
+
36
# Shell script handed to run_remote_setup() after launch; shipped inside the
# package's resources/ directory next to this module.
SETUP_SCRIPT = Path(__file__).parent.joinpath("resources", "remote_setup.sh")
37
+
38
+
39
def step(number: int, total: int, msg: str) -> None:
    """Print a bold cyan progress header ``[number/total] msg``, preceded by a blank line."""
    click.secho("\n[{}/{}] {}".format(number, total, msg), fg="cyan", bold=True)
41
+
42
+
43
def info(msg: str) -> None:
    """Echo an informational line to stdout, prefixed with a single space."""
    click.echo(" {}".format(msg))
45
+
46
+
47
def val(label: str, value: str) -> None:
    """Echo a ``label: value`` line with the value rendered in bright white."""
    highlighted = click.style(str(value), fg="bright_white")
    click.echo(" {}: {}".format(label, highlighted))
49
+
50
+
51
def success(msg: str) -> None:
    """Echo a green status line to stdout."""
    click.secho(" {}".format(msg), fg="green")
53
+
54
+
55
def warn(msg: str) -> None:
    """Echo a yellow ``WARNING:`` line to stderr."""
    click.secho(" WARNING: {}".format(msg), fg="yellow", err=True)
57
+
58
+
59
@click.group()
@click.version_option(package_name="aws-bootstrap-g4dn")
def main() -> None:
    """Bootstrap AWS EC2 GPU instances for hybrid local-remote development."""
    # Root command group: the callback does no work itself. The docstring above is
    # the --help text, and subcommands attach through @main.command()/@main.group().
63
+
64
+
65
def _print_dry_run_summary(config: LaunchConfig, ami: dict, pricing: str) -> None:
    """Print what ``launch`` would create, without creating anything."""
    click.echo()
    click.secho("--- Dry Run Summary ---", bold=True, fg="yellow")
    val("Instance type", config.instance_type)
    val("AMI", f"{ami['ImageId']} ({ami['Name']})")
    val("Pricing", pricing)
    val("Key pair", config.key_name)
    # The security group is not created in dry-run mode, so only its name is known.
    val("Security group", config.security_group)
    val("Volume", f"{config.volume_size} GB gp3")
    val("Region", config.region)
    val("Remote setup", "yes" if config.run_setup else "no")
    click.echo()
    click.secho("No resources launched (dry-run mode).", fg="yellow")


def _print_connection_info(
    config: LaunchConfig,
    instance_id: str,
    public_ip: str,
    alias: str,
    private_key,
    pricing: str,
) -> None:
    """Print the post-launch banner with SSH, Jupyter, benchmark and terminate hints."""
    click.echo()
    click.secho("=" * 60, fg="green")
    click.secho(" Instance ready!", bold=True, fg="green")
    click.secho("=" * 60, fg="green")
    click.echo()
    val("Instance ID", instance_id)
    val("Public IP", public_ip)
    val("Instance", config.instance_type)
    val("Pricing", pricing)
    val("SSH alias", alias)

    click.echo()
    click.secho(" SSH:", fg="cyan")
    click.secho(f" ssh {alias}", bold=True)
    info(f"or: ssh -i {private_key} {config.ssh_user}@{public_ip}")

    click.echo()
    click.secho(" Jupyter (via SSH tunnel):", fg="cyan")
    click.secho(f" ssh -NL 8888:localhost:8888 {alias}", bold=True)
    info(f"or: ssh -i {private_key} -NL 8888:localhost:8888 {config.ssh_user}@{public_ip}")
    info("Then open: http://localhost:8888")
    info("Notebook: ~/gpu_smoke_test.ipynb (GPU smoke test)")

    click.echo()
    click.secho(" GPU Benchmark:", fg="cyan")
    click.secho(f" ssh {alias} 'python ~/gpu_benchmark.py'", bold=True)
    info("Runs CNN (MNIST) and Transformer benchmarks with tqdm progress")

    click.echo()
    click.secho(" Terminate:", fg="cyan")
    click.secho(f" aws-bootstrap terminate {instance_id} --region {config.region}", bold=True)
    click.echo()


@main.command()
@click.option("--instance-type", default="g4dn.xlarge", show_default=True, help="EC2 instance type.")
@click.option("--ami-filter", default=None, help="AMI name pattern filter (auto-detected if omitted).")
@click.option("--spot/--on-demand", default=True, show_default=True, help="Use spot or on-demand pricing.")
@click.option(
    "--key-path",
    default="~/.ssh/id_ed25519.pub",
    show_default=True,
    type=click.Path(),
    help="Path to local SSH public key.",
)
@click.option("--key-name", default="aws-bootstrap-key", show_default=True, help="AWS key pair name.")
@click.option("--region", default="us-west-2", show_default=True, help="AWS region.")
@click.option("--security-group", default="aws-bootstrap-ssh", show_default=True, help="Security group name.")
@click.option("--volume-size", default=100, show_default=True, type=int, help="Root EBS volume size in GB (gp3).")
@click.option("--no-setup", is_flag=True, default=False, help="Skip running the remote setup script.")
@click.option("--dry-run", is_flag=True, default=False, help="Show what would be done without executing.")
@click.option("--profile", default=None, help="AWS profile override (defaults to AWS_PROFILE env var).")
def launch(
    instance_type,
    ami_filter,
    spot,
    key_path,
    key_name,
    region,
    security_group,
    volume_size,
    no_setup,
    dry_run,
    profile,
):
    """Launch a GPU-accelerated EC2 instance."""
    config = LaunchConfig(
        instance_type=instance_type,
        spot=spot,
        key_path=Path(key_path).expanduser(),
        key_name=key_name,
        region=region,
        security_group=security_group,
        volume_size=volume_size,
        run_setup=not no_setup,
        dry_run=dry_run,
    )
    # Only override LaunchConfig's defaults when the user actually supplied a value.
    if ami_filter:
        config.ami_filter = ami_filter
    if profile:
        config.profile = profile

    # Fail fast on a missing key before making any AWS calls.
    if not config.key_path.exists():
        raise CLIError(f"SSH public key not found: {config.key_path}")

    session = boto3.Session(profile_name=config.profile, region_name=config.region)
    ec2 = session.client("ec2")
    pricing = "spot" if config.spot else "on-demand"

    # Step 1: AMI lookup. This only reads from AWS, so it also runs in dry-run mode.
    step(1, 6, "Looking up AMI...")
    ami = get_latest_ami(ec2, config.ami_filter)
    info(f"Found: {ami['Name']}")
    val("AMI ID", ami["ImageId"])

    # Steps 2-3 create AWS resources (key pair, security group). --dry-run promises
    # not to execute anything, so these steps are only reported in that mode.
    if config.dry_run:
        step(2, 6, "Importing SSH key pair...")
        info(f"Skipped (dry-run): would import {config.key_path} as key pair '{config.key_name}'.")
        step(3, 6, "Ensuring security group...")
        info(f"Skipped (dry-run): would ensure security group '{config.security_group}'.")
        _print_dry_run_summary(config, ami, pricing)
        return

    # Step 2: SSH key pair
    step(2, 6, "Importing SSH key pair...")
    import_key_pair(ec2, config.key_name, config.key_path)

    # Step 3: Security group
    step(3, 6, "Ensuring security group...")
    sg_id = ensure_security_group(ec2, config.security_group, config.tag_value)

    # Step 4: Launch instance
    step(4, 6, f"Launching {config.instance_type} instance ({pricing})...")
    instance = launch_instance(ec2, config, ami["ImageId"], sg_id)
    instance_id = instance["InstanceId"]
    val("Instance ID", instance_id)

    # Step 5: Wait for ready
    step(5, 6, "Waiting for instance to be ready...")
    instance = wait_instance_ready(ec2, instance_id)
    public_ip = instance.get("PublicIpAddress")
    if not public_ip:
        warn(f"No public IP assigned. Instance ID: {instance_id}")
        info("You may need to assign an Elastic IP or check your VPC settings.")
        return

    val("Public IP", public_ip)

    # Step 6: SSH and remote setup
    step(6, 6, "Waiting for SSH access...")
    private_key = private_key_path(config.key_path)
    # NOTE(review): the ssh helpers receive the *public* key path; presumably they
    # derive the private key themselves (as private_key_path does) — confirm in ssh.py.
    if not wait_for_ssh(public_ip, config.ssh_user, config.key_path):
        warn("SSH did not become available within the timeout.")
        info(f"Instance is running — try connecting manually: ssh -i {private_key} {config.ssh_user}@{public_ip}")
        return

    if config.run_setup:
        if not SETUP_SCRIPT.exists():
            warn(f"Setup script not found at {SETUP_SCRIPT}, skipping.")
        else:
            info("Running remote setup...")
            if run_remote_setup(public_ip, config.ssh_user, config.key_path, SETUP_SCRIPT):
                success("Remote setup completed successfully.")
            else:
                # A failed setup is non-fatal: the instance is left running for manual fixing.
                warn("Remote setup failed. Instance is still running.")

    alias = add_ssh_host(
        instance_id=instance_id,
        hostname=public_ip,
        user=config.ssh_user,
        key_path=config.key_path,
        alias_prefix=config.alias_prefix,
    )
    success(f"Added SSH config alias: {alias}")

    _print_connection_info(config, instance_id, public_ip, alias, private_key, pricing)
230
+
231
+
232
# Colour used for each EC2 instance state in `status` output; other states render white.
_STATE_COLORS = {
    "running": "green",
    "pending": "yellow",
    "stopping": "yellow",
    "stopped": "red",
    "shutting-down": "red",
}


def _print_gpu_info(inst: dict) -> None:
    """Query an instance's GPU details over SSH and print GPU, CUDA and driver lines.

    Uses the SSH host details returned by get_ssh_host_details() when available;
    otherwise falls back to user ``ubuntu`` with ``~/.ssh/id_ed25519``.
    """
    details = get_ssh_host_details(inst["InstanceId"])
    if details:
        gpu_info = query_gpu_info(details.hostname, details.user, details.identity_file)
    else:
        gpu_info = query_gpu_info(
            inst["PublicIp"],
            "ubuntu",
            Path("~/.ssh/id_ed25519").expanduser(),
        )

    if not gpu_info:
        click.echo(" GPU: " + click.style("unavailable", dim=True))
        return

    val(" GPU", f"{gpu_info.gpu_name} ({gpu_info.architecture})")
    if gpu_info.cuda_toolkit_version:
        cuda_str = gpu_info.cuda_toolkit_version
        # Only mention the driver's CUDA version when it differs from the toolkit's.
        if gpu_info.cuda_driver_version != gpu_info.cuda_toolkit_version:
            cuda_str += f" (driver supports up to {gpu_info.cuda_driver_version})"
    else:
        cuda_str = f"{gpu_info.cuda_driver_version} (driver max, toolkit unknown)"
    val(" CUDA", cuda_str)
    val(" Driver", gpu_info.driver_version)


@main.command()
@click.option("--region", default="us-west-2", show_default=True, help="AWS region.")
@click.option("--profile", default=None, help="AWS profile override.")
@click.option("--gpu", is_flag=True, default=False, help="Query GPU info (CUDA, driver) via SSH.")
def status(region, profile, gpu):
    """Show running instances created by aws-bootstrap."""
    session = boto3.Session(profile_name=profile, region_name=region)
    ec2 = session.client("ec2")

    instances = find_tagged_instances(ec2, "aws-bootstrap-g4dn")
    if not instances:
        click.secho("No active aws-bootstrap instances found.", fg="yellow")
        return

    ssh_hosts = list_ssh_hosts()

    click.secho(f"\n Found {len(instances)} instance(s):\n", bold=True, fg="cyan")
    if gpu:
        click.echo(" " + click.style("Querying GPU info via SSH...", dim=True))
        click.echo()

    for inst in instances:
        state = inst["State"]
        alias = ssh_hosts.get(inst["InstanceId"])
        alias_str = f" ({alias})" if alias else ""
        click.echo(
            " "
            + click.style(inst["InstanceId"], fg="bright_white")
            + click.style(alias_str, fg="cyan")
            + " "
            + click.style(state, fg=_STATE_COLORS.get(state, "white"))
        )
        val(" Type", inst["InstanceType"])
        if inst["PublicIp"]:
            val(" IP", inst["PublicIp"])

        # GPU info is opt-in and needs SSH, so only running instances with a public IP qualify.
        if gpu and state == "running" and inst["PublicIp"]:
            _print_gpu_info(inst)

        is_spot = inst["Lifecycle"] == "spot"
        spot_price = None  # only looked up for spot instances
        if is_spot:
            spot_price = get_spot_price(ec2, inst["InstanceType"], inst["AvailabilityZone"])
            if spot_price is not None:
                val(" Pricing", f"spot (${spot_price:.4f}/hr)")
            else:
                val(" Pricing", "spot")
        else:
            val(" Pricing", "on-demand")

        if state == "running" and is_spot:
            uptime = datetime.now(UTC) - inst["LaunchTime"]
            hours, remainder = divmod(int(uptime.total_seconds()), 3600)
            val(" Uptime", f"{hours}h {remainder // 60:02d}m")
            if spot_price is not None:
                est_cost = uptime.total_seconds() / 3600 * spot_price
                val(" Est. cost", f"~${est_cost:.4f}")

        val(" Launched", str(inst["LaunchTime"]))
        click.echo()

    # Carry --region/--profile into the hint so it works outside the default
    # region/profile (matching the terminate hint printed by `launch`).
    terminate_cmd = f"aws-bootstrap terminate {instances[0]['InstanceId']} --region {region}"
    if profile:
        terminate_cmd += f" --profile {profile}"
    click.echo(" To terminate: " + click.style(terminate_cmd, bold=True))
    click.echo()
327
+
328
+
329
@main.command()
@click.option("--region", default="us-west-2", show_default=True, help="AWS region.")
@click.option("--profile", default=None, help="AWS profile override.")
@click.option("--yes", "-y", is_flag=True, default=False, help="Skip confirmation prompt.")
@click.argument("instance_ids", nargs=-1)
def terminate(region, profile, yes, instance_ids):
    """Terminate instances created by aws-bootstrap.

    Pass specific instance IDs to terminate, or omit to terminate all
    aws-bootstrap instances in the region.
    """
    session = boto3.Session(profile_name=profile, region_name=region)
    ec2 = session.client("ec2")

    if instance_ids:
        # Drop repeated IDs (keeping first-seen order) so the confirmation count
        # matches the number of distinct instances being terminated.
        targets = list(dict.fromkeys(instance_ids))
    else:
        instances = find_tagged_instances(ec2, "aws-bootstrap-g4dn")
        if not instances:
            click.secho("No active aws-bootstrap instances found.", fg="yellow")
            return
        targets = [inst["InstanceId"] for inst in instances]
        click.secho(f"\n Found {len(targets)} instance(s) to terminate:\n", bold=True, fg="cyan")
        for inst in instances:
            iid = click.style(inst["InstanceId"], fg="bright_white")
            click.echo(f" {iid} {inst['State']} {inst['InstanceType']}")

    if not yes:
        click.echo()
        if not click.confirm(f" Terminate {len(targets)} instance(s)?"):
            click.secho(" Cancelled.", fg="yellow")
            return

    changes = terminate_tagged_instances(ec2, targets)
    click.echo()
    for change in changes:
        prev = change["PreviousState"]["Name"]
        curr = change["CurrentState"]["Name"]
        click.echo(
            " " + click.style(change["InstanceId"], fg="bright_white") + f" {prev} -> " + click.style(curr, fg="red")
        )
        # Clean up the local SSH config alias, if one was recorded for this instance.
        removed_alias = remove_ssh_host(change["InstanceId"])
        if removed_alias:
            info(f"Removed SSH config alias: {removed_alias}")
    click.echo()
    success(f"Terminated {len(changes)} instance(s).")
375
+
376
+
377
+ # ---------------------------------------------------------------------------
378
+ # list command group
379
+ # ---------------------------------------------------------------------------
380
+
381
# Default --filter for `list amis`. Unlike LaunchConfig's default ami_filter, this
# pattern does not pin an OS variant, so all matching Deep Learning Base AMIs are listed.
DEFAULT_AMI_PREFIX = "Deep Learning Base OSS Nvidia Driver GPU AMI*"
382
+
383
+
384
@main.group(name="list")
def list_cmd() -> None:
    """List AWS resources (instance types, AMIs)."""
    # Container group only: subcommands register via @list_cmd.command(...), and the
    # docstring above is the group's --help text.
387
+
388
+
389
@list_cmd.command(name="instance-types")
@click.option("--prefix", default="g4dn", show_default=True, help="Instance type family prefix to filter on.")
@click.option("--region", default="us-west-2", show_default=True, help="AWS region.")
@click.option("--profile", default=None, help="AWS profile override.")
def list_instance_types_cmd(prefix, region, profile):
    """List EC2 instance types matching a family prefix (e.g. g4dn, p3, g5)."""
    ec2 = boto3.Session(profile_name=profile, region_name=region).client("ec2")

    matches = list_instance_types(ec2, prefix)
    if not matches:
        click.secho(f"No instance types found matching '{prefix}.*'", fg="yellow")
        return

    click.secho(f"\n {len(matches)} instance type(s) matching '{prefix}.*':\n", bold=True, fg="cyan")

    # Header and row formats share the same column widths (24 / 6 / 14).
    header = f"{'Instance Type':<24}{'vCPUs':>6}{'Memory (MiB)':>14} GPU"
    click.echo(" " + click.style(header, fg="bright_white", bold=True))
    click.echo(" " + "-" * 72)

    rows = (
        f" {m['InstanceType']:<24}{m['VCpuCount']:>6}{m['MemoryMiB']:>14} {m['GpuSummary'] or '-'}"
        for m in matches
    )
    for row in rows:
        click.echo(row)

    click.echo()
416
+
417
+
418
@list_cmd.command(name="amis")
@click.option("--filter", "ami_filter", default=DEFAULT_AMI_PREFIX, show_default=True, help="AMI name pattern.")
@click.option("--region", default="us-west-2", show_default=True, help="AWS region.")
@click.option("--profile", default=None, help="AWS profile override.")
def list_amis_cmd(ami_filter, region, profile):
    """List available AMIs matching a name pattern."""
    ec2 = boto3.Session(profile_name=profile, region_name=region).client("ec2")

    found = list_amis(ec2, ami_filter)
    if not found:
        click.secho(f"No AMIs found matching '{ami_filter}'", fg="yellow")
        return

    click.secho(f"\n {len(found)} AMI(s) matching '{ami_filter}' (newest first):\n", bold=True, fg="cyan")

    for image in found:
        image_id = click.style(image["ImageId"], fg="bright_white")
        # Show only the first 10 characters of CreationDate (presumably the
        # YYYY-MM-DD part of an ISO timestamp — confirm against list_amis).
        created_on = image["CreationDate"][:10]
        click.echo(" " + image_id + " " + created_on)
        click.echo(f" {image['Name']}")

    click.echo()
@@ -0,0 +1,24 @@
1
+ """Default configuration for EC2 GPU instance provisioning."""
2
+
3
+ from __future__ import annotations
4
+ import os
5
+ from dataclasses import dataclass, field
6
+ from pathlib import Path
7
+
8
+
9
@dataclass
class LaunchConfig:
    """Settings used to provision one EC2 GPU instance.

    ``key_path`` may be supplied as a ``str`` or any path-like value (including a
    leading ``~``); it is normalised to an expanded :class:`~pathlib.Path` after
    construction.
    """

    instance_type: str = "g4dn.xlarge"
    # AMI name pattern used for the AMI lookup at launch time.
    ami_filter: str = "Deep Learning Base OSS Nvidia Driver GPU AMI (Ubuntu 24.04)*"
    spot: bool = True  # False -> on-demand pricing
    # Local SSH *public* key that is imported into AWS under ``key_name``.
    key_path: Path = field(default_factory=lambda: Path.home() / ".ssh" / "id_ed25519.pub")
    key_name: str = "aws-bootstrap-key"
    region: str = "us-west-2"
    security_group: str = "aws-bootstrap-ssh"
    volume_size: int = 100  # root EBS volume size in GB
    run_setup: bool = True
    dry_run: bool = False
    # Defaults to the AWS_PROFILE environment variable, read when the config is created.
    profile: str | None = field(default_factory=lambda: os.environ.get("AWS_PROFILE"))
    ssh_user: str = "ubuntu"
    tag_value: str = "aws-bootstrap-g4dn"
    alias_prefix: str = "aws-gpu"

    def __post_init__(self) -> None:
        # Callers rely on Path methods (e.g. .exists()), so accept str/PathLike input
        # and expand "~". expanduser() is a no-op for already-expanded paths.
        self.key_path = Path(self.key_path).expanduser()