aws-bootstrap-g4dn 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aws_bootstrap/__init__.py +1 -0
- aws_bootstrap/cli.py +438 -0
- aws_bootstrap/config.py +24 -0
- aws_bootstrap/ec2.py +341 -0
- aws_bootstrap/resources/__init__.py +0 -0
- aws_bootstrap/resources/gpu_benchmark.py +839 -0
- aws_bootstrap/resources/gpu_smoke_test.ipynb +340 -0
- aws_bootstrap/resources/remote_setup.sh +188 -0
- aws_bootstrap/resources/requirements.txt +8 -0
- aws_bootstrap/ssh.py +513 -0
- aws_bootstrap/tests/__init__.py +0 -0
- aws_bootstrap/tests/test_cli.py +528 -0
- aws_bootstrap/tests/test_config.py +35 -0
- aws_bootstrap/tests/test_ec2.py +313 -0
- aws_bootstrap/tests/test_ssh_config.py +297 -0
- aws_bootstrap/tests/test_ssh_gpu.py +138 -0
- aws_bootstrap_g4dn-0.1.0.dist-info/METADATA +308 -0
- aws_bootstrap_g4dn-0.1.0.dist-info/RECORD +22 -0
- aws_bootstrap_g4dn-0.1.0.dist-info/WHEEL +5 -0
- aws_bootstrap_g4dn-0.1.0.dist-info/entry_points.txt +2 -0
- aws_bootstrap_g4dn-0.1.0.dist-info/licenses/LICENSE +21 -0
- aws_bootstrap_g4dn-0.1.0.dist-info/top_level.txt +1 -0
aws_bootstrap/ec2.py
ADDED
|
@@ -0,0 +1,341 @@
|
|
|
1
|
+
"""EC2 instance provisioning: AMI lookup, security groups, and instance launch."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
from datetime import UTC, datetime
|
|
5
|
+
|
|
6
|
+
import botocore.exceptions
|
|
7
|
+
import click
|
|
8
|
+
|
|
9
|
+
from .config import LaunchConfig
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class CLIError(click.ClickException):
    """Click exception whose message is rendered in red."""

    def show(self, file=None):  # type: ignore[no-untyped-def]
        """Print ``Error: <message>`` in red to *file*, or to stderr when *file* is None."""
        stream = click.get_text_stream("stderr") if file is None else file
        text = f"Error: {self.format_message()}"
        click.secho(text, file=stream, fg="red")
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
# Well-known AMI owners by name prefix
|
|
22
|
+
_OWNER_HINTS = {
|
|
23
|
+
"Deep Learning": ["amazon"],
|
|
24
|
+
"ubuntu": ["099720109477"], # Canonical
|
|
25
|
+
"Ubuntu": ["099720109477"],
|
|
26
|
+
"RHEL": ["309956199498"],
|
|
27
|
+
"al20": ["amazon"], # Amazon Linux
|
|
28
|
+
}
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def get_latest_ami(ec2_client, ami_filter: str) -> dict:
    """Return the newest available x86_64 AMI whose name matches *ami_filter*.

    The owner list comes from the first ``_OWNER_HINTS`` prefix that
    *ami_filter* starts with. When no prefix matches, the request carries no
    ``Owners`` restriction.

    Raises:
        CLIError: If describe_images returns no matching image.
    """
    inferred_owners = next(
        (ids for prefix, ids in _OWNER_HINTS.items() if ami_filter.startswith(prefix)),
        None,
    )

    name_match = {"Name": "name", "Values": [ami_filter]}
    only_available = {"Name": "state", "Values": ["available"]}
    only_x86_64 = {"Name": "architecture", "Values": ["x86_64"]}
    request: dict = {"Filters": [name_match, only_available, only_x86_64]}
    if inferred_owners:
        request["Owners"] = inferred_owners

    candidates = ec2_client.describe_images(**request)["Images"]
    if not candidates:
        raise CLIError(f"No AMI found matching filter: {ami_filter}\nTry adjusting --ami-filter or check the region.")

    # Newest first, ordered by the CreationDate value of each image.
    candidates.sort(key=lambda image: image["CreationDate"], reverse=True)
    return candidates[0]
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
def ensure_security_group(ec2_client, name: str, tag_value: str) -> str:
    """Return the ID of security group *name* in the default VPC, creating it if absent.

    An existing group with that name is reused as-is. A new group is tagged
    with ``created-by`` = *tag_value* and ``Name`` = *name*, then given a TCP
    port 22 ingress rule.

    Raises:
        CLIError: If the account/region has no default VPC.
    """
    default_vpcs = ec2_client.describe_vpcs(Filters=[{"Name": "isDefault", "Values": ["true"]}])
    if not default_vpcs["Vpcs"]:
        raise CLIError("No default VPC found. Create one or specify a VPC.")
    vpc_id = default_vpcs["Vpcs"][0]["VpcId"]

    # Reuse a group of the same name if the default VPC already has one.
    lookup_filters = [
        {"Name": "group-name", "Values": [name]},
        {"Name": "vpc-id", "Values": [vpc_id]},
    ]
    matches = ec2_client.describe_security_groups(Filters=lookup_filters)
    if matches["SecurityGroups"]:
        group_id = matches["SecurityGroups"][0]["GroupId"]
        prefix = " Security group " + click.style(f"'{name}'", fg="bright_white")
        click.echo(prefix + f" already exists ({group_id}), reusing.")
        return group_id

    # No match: create the group, tagged so it can be identified later.
    group_tags = [
        {"Key": "created-by", "Value": tag_value},
        {"Key": "Name", "Value": name},
    ]
    created = ec2_client.create_security_group(
        GroupName=name,
        Description="SSH access for aws-bootstrap-g4dn instances",
        VpcId=vpc_id,
        TagSpecifications=[{"ResourceType": "security-group", "Tags": group_tags}],
    )
    group_id = created["GroupId"]

    # NOTE(review): the SSH rule below is open to 0.0.0.0/0.
    ssh_rule = {
        "IpProtocol": "tcp",
        "FromPort": 22,
        "ToPort": 22,
        "IpRanges": [{"CidrIp": "0.0.0.0/0", "Description": "SSH access"}],
    }
    ec2_client.authorize_security_group_ingress(GroupId=group_id, IpPermissions=[ssh_rule])
    click.secho(f" Created security group '{name}' ({group_id}) with SSH ingress.", fg="green")
    return group_id
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
def launch_instance(ec2_client, config: LaunchConfig, ami_id: str, sg_id: str) -> dict:
    """Launch one EC2 instance (spot or on-demand) and return its description.

    The request uses *ami_id*, *sg_id*, and the instance type, key name, volume
    size and tag value from *config*. When ``config.spot`` is set, a one-time
    spot request is made. If that fails for capacity or price reasons, the user
    is asked whether to retry the same launch as on-demand.

    Raises:
        CLIError: On a quota error (via ``_raise_quota_error``) or when the
            user declines the on-demand retry.
        botocore.exceptions.ClientError: Any other launch failure, re-raised
            unchanged.
    """
    quota_codes = ("MaxSpotInstanceCountExceeded", "VcpuLimitExceeded")
    spot_failure_codes = ("InsufficientInstanceCapacity", "SpotMaxPriceTooLow")

    launch_params = {
        "ImageId": ami_id,
        "InstanceType": config.instance_type,
        "KeyName": config.key_name,
        "SecurityGroupIds": [sg_id],
        "MinCount": 1,
        "MaxCount": 1,
    }
    # gp3 volume mapped at /dev/sda1, deleted together with the instance.
    launch_params["BlockDeviceMappings"] = [
        {
            "DeviceName": "/dev/sda1",
            "Ebs": {
                "VolumeSize": config.volume_size,
                "VolumeType": "gp3",
                "DeleteOnTermination": True,
            },
        }
    ]
    launch_params["TagSpecifications"] = [
        {
            "ResourceType": "instance",
            "Tags": [
                {"Key": "Name", "Value": f"aws-bootstrap-{config.instance_type}"},
                {"Key": "created-by", "Value": config.tag_value},
            ],
        }
    ]

    if config.spot:
        spot_options = {
            "SpotInstanceType": "one-time",
            "InstanceInterruptionBehavior": "terminate",
        }
        launch_params["InstanceMarketOptions"] = {"MarketType": "spot", "SpotOptions": spot_options}

    try:
        response = ec2_client.run_instances(**launch_params)
    except botocore.exceptions.ClientError as e:
        code = e.response["Error"]["Code"]
        if code in quota_codes:
            # _raise_quota_error always raises CLIError.
            _raise_quota_error(code, config)
        elif not (code in spot_failure_codes and config.spot):
            raise
        else:
            click.secho(f"\n Spot request failed: {e.response['Error']['Message']}", fg="yellow")
            if not click.confirm(" Retry as on-demand instance?"):
                raise CLIError("Launch cancelled.") from None
            # Dropping the market options turns the request into on-demand.
            launch_params.pop("InstanceMarketOptions", None)
            try:
                response = ec2_client.run_instances(**launch_params)
            except botocore.exceptions.ClientError as retry_e:
                retry_code = retry_e.response["Error"]["Code"]
                if retry_code in quota_codes:
                    _raise_quota_error(retry_code, config)
                raise

    return response["Instances"][0]
|
|
178
|
+
|
|
179
|
+
|
|
180
|
+
_UBUNTU_AMI = "ubuntu/images/hvm-ssd-gp3/ubuntu-noble-24.04-amd64-server-*"
|
|
181
|
+
|
|
182
|
+
QUOTA_HINT = (
|
|
183
|
+
"See the 'EC2 vCPU Quotas' section in README.md for instructions on\n"
|
|
184
|
+
" checking and requesting quota increases.\n\n"
|
|
185
|
+
" To test the flow without GPU quotas, try:\n"
|
|
186
|
+
f' aws-bootstrap launch --instance-type t3.medium --ami-filter "{_UBUNTU_AMI}"'
|
|
187
|
+
)
|
|
188
|
+
|
|
189
|
+
|
|
190
|
+
def _raise_quota_error(code: str, config: LaunchConfig) -> None:
    """Raise a CLIError explaining that an EC2 quota was exceeded.

    *code* is the AWS error code and selects the wording. The instance type,
    region and spot flag are read from *config*. This function always raises.
    """
    if code == "MaxSpotInstanceCountExceeded":
        label, pricing = "Spot instance", "spot"
    else:
        # NOTE(review): with config.spot set, this pairs the "On-demand vCPU"
        # label with "spot" pricing in the message; confirm that is intended.
        label = "On-demand vCPU"
        pricing = "spot" if config.spot else "on-demand"

    headline = f"{label} quota exceeded for {config.instance_type} in {config.region}.\n\n"
    explanation = f" Your account's {pricing} vCPU limit for this instance family is too low.\n"
    hint = f" {QUOTA_HINT}"
    raise CLIError(headline + explanation + hint)
|
|
203
|
+
|
|
204
|
+
|
|
205
|
+
def find_tagged_instances(ec2_client, tag_value: str) -> list[dict]:
    """Return summaries of non-terminated instances whose ``created-by`` tag equals *tag_value*.

    Each summary holds InstanceId, Name (the ``Name`` tag, or ""), State,
    InstanceType, PublicIp ("" when absent), LaunchTime, Lifecycle
    ("on-demand" when the instance reports none) and AvailabilityZone.
    """
    live_states = ["pending", "running", "stopping", "stopped", "shutting-down"]
    response = ec2_client.describe_instances(
        Filters=[
            {"Name": "tag:created-by", "Values": [tag_value]},
            {"Name": "instance-state-name", "Values": live_states},
        ]
    )
    return [
        {
            "InstanceId": inst["InstanceId"],
            # First tag whose key is "Name"; empty string when there is none.
            "Name": next((tag["Value"] for tag in inst.get("Tags", []) if tag["Key"] == "Name"), ""),
            "State": inst["State"]["Name"],
            "InstanceType": inst["InstanceType"],
            "PublicIp": inst.get("PublicIpAddress", ""),
            "LaunchTime": inst["LaunchTime"],
            "Lifecycle": inst.get("InstanceLifecycle", "on-demand"),
            "AvailabilityZone": inst["Placement"]["AvailabilityZone"],
        }
        for reservation in response["Reservations"]
        for inst in reservation["Instances"]
    ]
|
|
233
|
+
|
|
234
|
+
|
|
235
|
+
def get_spot_price(ec2_client, instance_type: str, availability_zone: str) -> float | None:
|
|
236
|
+
"""Get the current spot price for an instance type in a given AZ.
|
|
237
|
+
|
|
238
|
+
Returns the hourly price as a float, or None if unavailable.
|
|
239
|
+
"""
|
|
240
|
+
response = ec2_client.describe_spot_price_history(
|
|
241
|
+
InstanceTypes=[instance_type],
|
|
242
|
+
ProductDescriptions=["Linux/UNIX"],
|
|
243
|
+
AvailabilityZone=availability_zone,
|
|
244
|
+
StartTime=datetime.now(UTC),
|
|
245
|
+
MaxResults=1,
|
|
246
|
+
)
|
|
247
|
+
prices = response.get("SpotPriceHistory", [])
|
|
248
|
+
if not prices:
|
|
249
|
+
return None
|
|
250
|
+
return float(prices[0]["SpotPrice"])
|
|
251
|
+
|
|
252
|
+
|
|
253
|
+
def list_instance_types(ec2_client, name_prefix: str = "g4dn") -> list[dict]:
    """List EC2 instance types whose name matches ``<name_prefix>.*``.

    Returns dicts with InstanceType, VCpuCount, MemoryMiB and GpuSummary,
    sorted by instance type name. GpuSummary describes the first GPU entry
    only and is "" when the type reports no GPUs.
    """
    type_filter = {"Name": "instance-type", "Values": [f"{name_prefix}.*"]}
    pages = ec2_client.get_paginator("describe_instance_types").paginate(Filters=[type_filter])

    rows = []
    for page in pages:
        for info in page["InstanceTypes"]:
            gpu_entries = info.get("GpuInfo", {}).get("Gpus", [])
            if gpu_entries:
                first = gpu_entries[0]
                mem = first.get("MemoryInfo", {}).get("SizeInMiB", 0)
                summary = f"{first.get('Count', '?')}x {first.get('Name', 'GPU')} ({mem} MiB)"
            else:
                summary = ""
            rows.append(
                {
                    "InstanceType": info["InstanceType"],
                    "VCpuCount": info["VCpuInfo"]["DefaultVCpus"],
                    "MemoryMiB": info["MemoryInfo"]["SizeInMiB"],
                    "GpuSummary": summary,
                }
            )
    return sorted(rows, key=lambda row: row["InstanceType"])
|
|
282
|
+
|
|
283
|
+
|
|
284
|
+
def list_amis(ec2_client, ami_filter: str) -> list[dict]:
    """List available x86_64 AMIs whose name matches *ami_filter*, newest first.

    The owner list comes from the first ``_OWNER_HINTS`` prefix that
    *ami_filter* starts with; without a match no ``Owners`` restriction is
    sent. At most the 20 most recent images are returned, each as a dict with
    ImageId, Name, CreationDate and Architecture ("" when absent).
    """
    request: dict = {
        "Filters": [
            {"Name": "name", "Values": [ami_filter]},
            {"Name": "state", "Values": ["available"]},
            {"Name": "architecture", "Values": ["x86_64"]},
        ],
    }
    matching_owner_lists = [ids for prefix, ids in _OWNER_HINTS.items() if ami_filter.startswith(prefix)]
    # Only the first matching prefix counts, and an empty owner list is ignored.
    if matching_owner_lists and matching_owner_lists[0]:
        request["Owners"] = matching_owner_lists[0]

    found = ec2_client.describe_images(**request)["Images"]
    found.sort(key=lambda image: image["CreationDate"], reverse=True)

    summaries = []
    for image in found[:20]:
        summaries.append(
            {
                "ImageId": image["ImageId"],
                "Name": image["Name"],
                "CreationDate": image["CreationDate"],
                "Architecture": image.get("Architecture", ""),
            }
        )
    return summaries
|
|
318
|
+
|
|
319
|
+
|
|
320
|
+
def terminate_tagged_instances(ec2_client, instance_ids: list[str]) -> list[dict]:
    """Terminate the given instances and return their state changes.

    Args:
        ec2_client: EC2 client exposing ``terminate_instances``.
        instance_ids: IDs of the instances to terminate.

    Returns:
        The ``TerminatingInstances`` list from the API response, or an empty
        list when *instance_ids* is empty.
    """
    # The API needs at least one instance ID, so an empty request is skipped
    # rather than sent and left to fail with a service error.
    if not instance_ids:
        return []
    response = ec2_client.terminate_instances(InstanceIds=instance_ids)
    return response["TerminatingInstances"]
|
|
324
|
+
|
|
325
|
+
|
|
326
|
+
def wait_instance_ready(ec2_client, instance_id: str) -> dict:
    """Block until *instance_id* is running and its status checks pass.

    Runs the ``instance_running`` waiter (10 s delay) and then the
    ``instance_status_ok`` waiter (15 s delay), each with up to 60 attempts,
    printing progress along the way. Returns a freshly described instance
    dict, fetched after the waits to pick up the public IP.
    """
    stages = (
        (
            " Waiting for instance " + click.style(instance_id, fg="bright_white") + " to enter 'running' state...",
            "instance_running",
            10,
            " Instance running.",
        ),
        (
            " Waiting for instance status checks to pass...",
            "instance_status_ok",
            15,
            " Status checks passed.",
        ),
    )
    for announcement, waiter_name, delay, done_message in stages:
        click.echo(announcement)
        stage_waiter = ec2_client.get_waiter(waiter_name)
        stage_waiter.wait(InstanceIds=[instance_id], WaiterConfig={"Delay": delay, "MaxAttempts": 60})
        click.secho(done_message, fg="green")

    # Describe the instance again once it is up.
    described = ec2_client.describe_instances(InstanceIds=[instance_id])
    return described["Reservations"][0]["Instances"][0]
|
|
File without changes
|