aws-bootstrap-g4dn 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
aws_bootstrap/ec2.py ADDED
@@ -0,0 +1,341 @@
1
+ """EC2 instance provisioning: AMI lookup, security groups, and instance launch."""
2
+
3
+ from __future__ import annotations
4
+ from datetime import UTC, datetime
5
+
6
+ import botocore.exceptions
7
+ import click
8
+
9
+ from .config import LaunchConfig
10
+
11
+
12
class CLIError(click.ClickException):
    """ClickException variant that prints its ``Error: ...`` line in red.

    Apart from the colouring in :meth:`show`, it behaves exactly like
    :class:`click.ClickException` (same message formatting and exit handling).
    """

    def show(self, file=None):  # type: ignore[no-untyped-def]
        """Write ``Error: <message>`` in red to *file*, or to stderr when omitted."""
        target = click.get_text_stream("stderr") if file is None else file
        click.secho(f"Error: {self.format_message()}", file=target, fg="red")
19
+
20
+
21
+ # Well-known AMI owners by name prefix
22
+ _OWNER_HINTS = {
23
+ "Deep Learning": ["amazon"],
24
+ "ubuntu": ["099720109477"], # Canonical
25
+ "Ubuntu": ["099720109477"],
26
+ "RHEL": ["309956199498"],
27
+ "al20": ["amazon"], # Amazon Linux
28
+ }
29
+
30
+
31
def get_latest_ami(ec2_client, ami_filter: str) -> dict:
    """Return the newest available x86_64 AMI whose name matches *ami_filter*.

    *ami_filter* is sent to EC2 as a ``name`` filter, so wildcards such as
    ``*`` may be used. If it starts with one of the prefixes in
    ``_OWNER_HINTS``, the search is restricted to that prefix's owners.
    Otherwise no ``Owners`` restriction is sent.

    Raises:
        CLIError: if no image matches the filter.
    """
    owners = next(
        (ids for prefix, ids in _OWNER_HINTS.items() if ami_filter.startswith(prefix)),
        None,
    )

    request: dict = {
        "Filters": [
            {"Name": "name", "Values": [ami_filter]},
            {"Name": "state", "Values": ["available"]},
            {"Name": "architecture", "Values": ["x86_64"]},
        ],
    }
    if owners:
        request["Owners"] = owners

    candidates = ec2_client.describe_images(**request)["Images"]
    if not candidates:
        raise CLIError(f"No AMI found matching filter: {ami_filter}\nTry adjusting --ami-filter or check the region.")

    # max() returns the first of several equal CreationDate values, the same
    # image a stable newest-first sort would put at index 0.
    return max(candidates, key=lambda image: image["CreationDate"])
60
+
61
+
62
def ensure_security_group(ec2_client, name: str, tag_value: str) -> str:
    """Find or create a security group named *name* in the default VPC.

    An existing group with this name in the default VPC is reused as-is; its
    rules are not inspected or changed. Otherwise a new group is created,
    tagged ``created-by=<tag_value>`` and ``Name=<name>``, and given a single
    TCP port 22 ingress rule from ``0.0.0.0/0``.

    Creating the group and adding its rule are two separate API calls. If
    adding the ingress rule fails, the new group is deleted (best effort)
    before the error is re-raised. Otherwise the next run would find the
    group by name and reuse it with no SSH rule.

    Args:
        ec2_client: boto3 EC2 client.
        name: security group name to look up or create.
        tag_value: value of the ``created-by`` tag on a newly created group.

    Returns:
        The security group ID.

    Raises:
        CLIError: if the region has no default VPC.
    """
    # Find default VPC
    vpcs = ec2_client.describe_vpcs(Filters=[{"Name": "isDefault", "Values": ["true"]}])
    if not vpcs["Vpcs"]:
        raise CLIError("No default VPC found. Create one or specify a VPC.")
    vpc_id = vpcs["Vpcs"][0]["VpcId"]

    # Check if SG already exists
    existing = ec2_client.describe_security_groups(
        Filters=[
            {"Name": "group-name", "Values": [name]},
            {"Name": "vpc-id", "Values": [vpc_id]},
        ]
    )
    if existing["SecurityGroups"]:
        sg_id = existing["SecurityGroups"][0]["GroupId"]
        msg = " Security group " + click.style(f"'{name}'", fg="bright_white")
        click.echo(msg + f" already exists ({sg_id}), reusing.")
        return sg_id

    # Create new SG
    sg = ec2_client.create_security_group(
        GroupName=name,
        Description="SSH access for aws-bootstrap-g4dn instances",
        VpcId=vpc_id,
        TagSpecifications=[
            {
                "ResourceType": "security-group",
                "Tags": [
                    {"Key": "created-by", "Value": tag_value},
                    {"Key": "Name", "Value": name},
                ],
            }
        ],
    )
    sg_id = sg["GroupId"]

    # Add SSH ingress.
    # NOTE(review): this opens SSH to every IPv4 address; consider restricting
    # the CIDR to the caller's public IP.
    try:
        ec2_client.authorize_security_group_ingress(
            GroupId=sg_id,
            IpPermissions=[
                {
                    "IpProtocol": "tcp",
                    "FromPort": 22,
                    "ToPort": 22,
                    "IpRanges": [{"CidrIp": "0.0.0.0/0", "Description": "SSH access"}],
                }
            ],
        )
    except Exception:
        # Roll back so a group without its SSH rule is never reused by name later.
        try:
            ec2_client.delete_security_group(GroupId=sg_id)
        except (botocore.exceptions.BotoCoreError, botocore.exceptions.ClientError):
            # Best effort: warn about the leftover group, but surface the original error.
            click.secho(
                f" Warning: could not delete security group {sg_id} after a failed setup; delete it manually.",
                fg="yellow",
            )
        raise
    click.secho(f" Created security group '{name}' ({sg_id}) with SSH ingress.", fg="green")
    return sg_id
114
+
115
+
116
def launch_instance(
    ec2_client,
    config: LaunchConfig,
    ami_id: str,
    sg_id: str,
    root_device_name: str = "/dev/sda1",
) -> dict:
    """Launch one EC2 instance, as spot or on-demand depending on ``config.spot``.

    The instance gets a gp3 volume of ``config.volume_size`` (deleted on
    termination) mapped at *root_device_name*. It is tagged
    ``Name=aws-bootstrap-<instance_type>`` and ``created-by=<config.tag_value>``.
    If a spot launch fails with ``InsufficientInstanceCapacity`` or
    ``SpotMaxPriceTooLow``, the user is asked whether to retry on-demand.

    Args:
        ec2_client: boto3 EC2 client.
        config: launch settings; this function reads ``instance_type``,
            ``key_name``, ``volume_size``, ``tag_value`` and ``spot``.
        ami_id: ID of the AMI to boot.
        sg_id: ID of the security group to attach.
        root_device_name: device name for the volume mapping. Pass the AMI's
            ``RootDeviceName`` so that ``config.volume_size`` applies to the
            root disk. AMIs whose root device is not ``/dev/sda1`` (e.g.
            ``/dev/xvda``) need their own value here. Defaults to
            ``/dev/sda1``, which was previously hard-coded.

    Returns:
        The first entry of the ``Instances`` list returned by ``run_instances``.

    Raises:
        CLIError: on a spot-count or vCPU quota error, or when the user
            declines the on-demand retry.
        botocore.exceptions.ClientError: for any other API error.
    """
    quota_codes = ("MaxSpotInstanceCountExceeded", "VcpuLimitExceeded")
    spot_fallback_codes = ("InsufficientInstanceCapacity", "SpotMaxPriceTooLow")

    launch_params: dict = {
        "ImageId": ami_id,
        "InstanceType": config.instance_type,
        "KeyName": config.key_name,
        "SecurityGroupIds": [sg_id],
        "MinCount": 1,
        "MaxCount": 1,
        "BlockDeviceMappings": [
            {
                "DeviceName": root_device_name,
                "Ebs": {
                    "VolumeSize": config.volume_size,
                    "VolumeType": "gp3",
                    "DeleteOnTermination": True,
                },
            }
        ],
        "TagSpecifications": [
            {
                "ResourceType": "instance",
                "Tags": [
                    {"Key": "Name", "Value": f"aws-bootstrap-{config.instance_type}"},
                    {"Key": "created-by", "Value": config.tag_value},
                ],
            }
        ],
    }

    if config.spot:
        launch_params["InstanceMarketOptions"] = {
            "MarketType": "spot",
            "SpotOptions": {
                "SpotInstanceType": "one-time",
                "InstanceInterruptionBehavior": "terminate",
            },
        }

    try:
        response = ec2_client.run_instances(**launch_params)
    except botocore.exceptions.ClientError as e:
        code = e.response["Error"]["Code"]
        if code in quota_codes:
            _raise_quota_error(code, config)
        elif code in spot_fallback_codes and config.spot:
            click.secho(f"\n Spot request failed: {e.response['Error']['Message']}", fg="yellow")
            if not click.confirm(" Retry as on-demand instance?"):
                raise CLIError("Launch cancelled.") from None
            # Dropping the market options turns the retry into a plain on-demand launch.
            launch_params.pop("InstanceMarketOptions", None)
            try:
                response = ec2_client.run_instances(**launch_params)
            except botocore.exceptions.ClientError as retry_e:
                retry_code = retry_e.response["Error"]["Code"]
                if retry_code in quota_codes:
                    _raise_quota_error(retry_code, config)
                raise
        else:
            raise

    return response["Instances"][0]
178
+
179
+
180
+ _UBUNTU_AMI = "ubuntu/images/hvm-ssd-gp3/ubuntu-noble-24.04-amd64-server-*"
181
+
182
+ QUOTA_HINT = (
183
+ "See the 'EC2 vCPU Quotas' section in README.md for instructions on\n"
184
+ " checking and requesting quota increases.\n\n"
185
+ " To test the flow without GPU quotas, try:\n"
186
+ f' aws-bootstrap launch --instance-type t3.medium --ami-filter "{_UBUNTU_AMI}"'
187
+ )
188
+
189
+
190
def _raise_quota_error(code: str, config: LaunchConfig) -> None:
    """Raise a CLIError explaining that an instance launch hit an account quota.

    ``"MaxSpotInstanceCountExceeded"`` is reported as a spot-instance quota.
    Any other code is reported with the "On-demand vCPU" label, and the
    pricing word is chosen from ``config.spot``. The message names
    ``config.instance_type`` and ``config.region`` and ends with
    ``QUOTA_HINT``.

    NOTE(review): for a non-spot-count code with ``config.spot`` true, the
    message says "On-demand vCPU quota exceeded" and then "spot vCPU limit".
    Confirm which wording is intended.

    Raises:
        CLIError: always.
    """
    spot_count_hit = code == "MaxSpotInstanceCountExceeded"
    label = "Spot instance" if spot_count_hit else "On-demand vCPU"
    pricing = "spot" if (spot_count_hit or config.spot) else "on-demand"
    raise CLIError(
        f"{label} quota exceeded for {config.instance_type} in {config.region}.\n\n"
        f" Your account's {pricing} vCPU limit for this instance family is too low.\n"
        f" {QUOTA_HINT}"
    )
203
+
204
+
205
def find_tagged_instances(ec2_client, tag_value: str) -> list[dict]:
    """Return summaries of all non-terminated instances tagged ``created-by=<tag_value>``.

    Results are collected from every page of ``describe_instances`` (via the
    boto3 paginator), so no instances are dropped when EC2 splits the
    response.

    Each summary dict has these keys:
        InstanceId, InstanceType, LaunchTime, AvailabilityZone
        Name: value of the ``Name`` tag, or ``""`` if absent
        State: instance state name
        PublicIp: public IPv4 address, or ``""`` if none
        Lifecycle: ``InstanceLifecycle`` (e.g. ``"spot"``), or ``"on-demand"`` if absent
    """
    paginator = ec2_client.get_paginator("describe_instances")
    pages = paginator.paginate(
        Filters=[
            {"Name": "tag:created-by", "Values": [tag_value]},
            {
                "Name": "instance-state-name",
                "Values": ["pending", "running", "stopping", "stopped", "shutting-down"],
            },
        ]
    )

    instances: list[dict] = []
    for page in pages:
        for reservation in page["Reservations"]:
            for inst in reservation["Instances"]:
                name = next((tag["Value"] for tag in inst.get("Tags", []) if tag["Key"] == "Name"), "")
                instances.append(
                    {
                        "InstanceId": inst["InstanceId"],
                        "Name": name,
                        "State": inst["State"]["Name"],
                        "InstanceType": inst["InstanceType"],
                        "PublicIp": inst.get("PublicIpAddress", ""),
                        "LaunchTime": inst["LaunchTime"],
                        "Lifecycle": inst.get("InstanceLifecycle", "on-demand"),
                        "AvailabilityZone": inst["Placement"]["AvailabilityZone"],
                    }
                )
    return instances
233
+
234
+
235
+ def get_spot_price(ec2_client, instance_type: str, availability_zone: str) -> float | None:
236
+ """Get the current spot price for an instance type in a given AZ.
237
+
238
+ Returns the hourly price as a float, or None if unavailable.
239
+ """
240
+ response = ec2_client.describe_spot_price_history(
241
+ InstanceTypes=[instance_type],
242
+ ProductDescriptions=["Linux/UNIX"],
243
+ AvailabilityZone=availability_zone,
244
+ StartTime=datetime.now(UTC),
245
+ MaxResults=1,
246
+ )
247
+ prices = response.get("SpotPriceHistory", [])
248
+ if not prices:
249
+ return None
250
+ return float(prices[0]["SpotPrice"])
251
+
252
+
253
def list_instance_types(ec2_client, name_prefix: str = "g4dn") -> list[dict]:
    """List EC2 instance types in a family, e.g. ``"g4dn"`` or ``"p3"``.

    Queries every page of ``describe_instance_types`` filtered on
    ``<name_prefix>.*``. Returns one dict per type with keys ``InstanceType``,
    ``VCpuCount``, ``MemoryMiB`` and ``GpuSummary``, sorted by instance type
    name. ``GpuSummary`` describes only the first GPU entry, as
    ``"<count>x <name> (<mem> MiB)"``, and is ``""`` for types without GPUs.
    """

    def gpu_summary(info: dict) -> str:
        # Format the first listed GPU, with placeholders for missing fields.
        gpus = info.get("GpuInfo", {}).get("Gpus", [])
        if not gpus:
            return ""
        first = gpus[0]
        mem = first.get("MemoryInfo", {}).get("SizeInMiB", 0)
        return f"{first.get('Count', '?')}x {first.get('Name', 'GPU')} ({mem} MiB)"

    paginator = ec2_client.get_paginator("describe_instance_types")
    pages = paginator.paginate(
        Filters=[{"Name": "instance-type", "Values": [f"{name_prefix}.*"]}],
    )
    rows = [
        {
            "InstanceType": info["InstanceType"],
            "VCpuCount": info["VCpuInfo"]["DefaultVCpus"],
            "MemoryMiB": info["MemoryInfo"]["SizeInMiB"],
            "GpuSummary": gpu_summary(info),
        }
        for page in pages
        for info in page["InstanceTypes"]
    ]
    return sorted(rows, key=lambda row: row["InstanceType"])
282
+
283
+
284
def list_amis(ec2_client, ami_filter: str) -> list[dict]:
    """List up to 20 available x86_64 AMIs whose name matches *ami_filter*.

    If *ami_filter* starts with a prefix from ``_OWNER_HINTS``, the search is
    restricted to that prefix's owners; otherwise no ``Owners`` restriction
    is sent. Returns dicts with ``ImageId``, ``Name``, ``CreationDate`` and
    ``Architecture`` (``""`` if absent), newest first.
    """
    owners = next(
        (ids for prefix, ids in _OWNER_HINTS.items() if ami_filter.startswith(prefix)),
        None,
    )

    request: dict = {
        "Filters": [
            {"Name": "name", "Values": [ami_filter]},
            {"Name": "state", "Values": ["available"]},
            {"Name": "architecture", "Values": ["x86_64"]},
        ],
    }
    if owners:
        request["Owners"] = owners

    newest_first = sorted(
        ec2_client.describe_images(**request)["Images"],
        key=lambda image: image["CreationDate"],
        reverse=True,
    )
    return [
        {
            "ImageId": image["ImageId"],
            "Name": image["Name"],
            "CreationDate": image["CreationDate"],
            "Architecture": image.get("Architecture", ""),
        }
        for image in newest_first[:20]
    ]
318
+
319
+
320
def terminate_tagged_instances(ec2_client, instance_ids: list[str]) -> list[dict]:
    """Terminate instances by ID and return EC2's ``TerminatingInstances`` state changes.

    An empty *instance_ids* returns ``[]`` without calling the API, because
    TerminateInstances requires at least one instance ID. Any iterable of IDs
    is accepted; it is materialised into a list before the request.
    """
    ids = list(instance_ids)
    if not ids:
        return []
    response = ec2_client.terminate_instances(InstanceIds=ids)
    return response["TerminatingInstances"]
324
+
325
+
326
def wait_instance_ready(ec2_client, instance_id: str) -> dict:
    """Block until *instance_id* is running and has passed its status checks.

    Runs the ``instance_running`` waiter (10 s delay, up to 60 attempts) and
    then the ``instance_status_ok`` waiter (15 s delay, up to 60 attempts),
    printing a line before and after each. Waiter failures, such as
    timeouts, propagate unchanged. Afterwards the instance is described again
    so the returned dict reflects its current state, e.g. a public IP
    assigned after launch.

    Returns:
        The instance dict from ``describe_instances``.
    """
    phases = (
        (
            "instance_running",
            10,
            " Waiting for instance " + click.style(instance_id, fg="bright_white") + " to enter 'running' state...",
            " Instance running.",
        ),
        (
            "instance_status_ok",
            15,
            " Waiting for instance status checks to pass...",
            " Status checks passed.",
        ),
    )
    for waiter_name, delay, before, after in phases:
        click.echo(before)
        ec2_client.get_waiter(waiter_name).wait(
            InstanceIds=[instance_id],
            WaiterConfig={"Delay": delay, "MaxAttempts": 60},
        )
        click.secho(after, fg="green")

    # Describe again to pick up fields filled in after launch (e.g. public IP).
    reservations = ec2_client.describe_instances(InstanceIds=[instance_id])["Reservations"]
    return reservations[0]["Instances"][0]
File without changes