aws-bootstrap-g4dn 0.5.0__py3-none-any.whl → 0.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
aws_bootstrap/config.py CHANGED
@@ -24,3 +24,5 @@ class LaunchConfig:
24
24
  alias_prefix: str = "aws-gpu"
25
25
  ssh_port: int = 22
26
26
  python_version: str | None = None
27
+ ebs_storage: int | None = None
28
+ ebs_volume_id: str | None = None
aws_bootstrap/ec2.py CHANGED
@@ -7,6 +7,11 @@ import botocore.exceptions
7
7
  import click
8
8
 
9
9
  from .config import LaunchConfig
10
+ from .output import echo, is_text, secho
11
+
12
+
13
+ EBS_DEVICE_NAME = "/dev/sdf"
14
+ EBS_MOUNT_POINT = "/data"
10
15
 
11
16
 
12
17
  class CLIError(click.ClickException):
@@ -77,7 +82,7 @@ def ensure_security_group(ec2_client, name: str, tag_value: str, ssh_port: int =
77
82
  if existing["SecurityGroups"]:
78
83
  sg_id = existing["SecurityGroups"][0]["GroupId"]
79
84
  msg = " Security group " + click.style(f"'{name}'", fg="bright_white")
80
- click.echo(msg + f" already exists ({sg_id}), reusing.")
85
+ echo(msg + f" already exists ({sg_id}), reusing.")
81
86
  return sg_id
82
87
 
83
88
  # Create new SG
@@ -109,7 +114,7 @@ def ensure_security_group(ec2_client, name: str, tag_value: str, ssh_port: int =
109
114
  }
110
115
  ],
111
116
  )
112
- click.secho(f" Created security group '{name}' ({sg_id}) with SSH ingress.", fg="green")
117
+ secho(f" Created security group '{name}' ({sg_id}) with SSH ingress.", fg="green")
113
118
  return sg_id
114
119
 
115
120
 
@@ -159,8 +164,8 @@ def launch_instance(ec2_client, config: LaunchConfig, ami_id: str, sg_id: str) -
159
164
  if code in ("MaxSpotInstanceCountExceeded", "VcpuLimitExceeded"):
160
165
  _raise_quota_error(code, config)
161
166
  elif code in ("InsufficientInstanceCapacity", "SpotMaxPriceTooLow") and config.spot:
162
- click.secho(f"\n Spot request failed: {e.response['Error']['Message']}", fg="yellow")
163
- if click.confirm(" Retry as on-demand instance?"):
167
+ secho(f"\n Spot request failed: {e.response['Error']['Message']}", fg="yellow")
168
+ if not is_text() or click.confirm(" Retry as on-demand instance?"):
164
169
  launch_params.pop("InstanceMarketOptions", None)
165
170
  try:
166
171
  response = ec2_client.run_instances(**launch_params)
@@ -325,17 +330,141 @@ def terminate_tagged_instances(ec2_client, instance_ids: list[str]) -> list[dict
325
330
 
326
331
def wait_instance_ready(ec2_client, instance_id: str) -> dict:
    """Wait for the instance to be running and pass status checks.

    Blocks on two EC2 waiters in sequence ('instance_running', then
    'instance_status_ok') and finally re-describes the instance so the
    returned dict carries post-boot fields such as the public IP.
    """
    styled_id = click.style(instance_id, fg="bright_white")
    echo(" Waiting for instance " + styled_id + " to enter 'running' state...")
    ec2_client.get_waiter("instance_running").wait(
        InstanceIds=[instance_id], WaiterConfig={"Delay": 10, "MaxAttempts": 60}
    )
    secho(" Instance running.", fg="green")

    echo(" Waiting for instance status checks to pass...")
    ec2_client.get_waiter("instance_status_ok").wait(
        InstanceIds=[instance_id], WaiterConfig={"Delay": 15, "MaxAttempts": 60}
    )
    secho(" Status checks passed.", fg="green")

    # Refresh the description: the original launch response predates the
    # public-IP assignment, so callers need this re-fetched view.
    described = ec2_client.describe_instances(InstanceIds=[instance_id])
    return described["Reservations"][0]["Instances"][0]
347
+
348
+
349
+ # ---------------------------------------------------------------------------
350
+ # EBS data volume operations
351
+ # ---------------------------------------------------------------------------
352
+
353
+
354
def create_ebs_volume(ec2_client, size_gb: int, availability_zone: str, tag_value: str, instance_id: str) -> str:
    """Create a gp3 EBS volume and wait for it to become available.

    The volume is tagged so it can later be found via
    ``find_ebs_volumes_for_instance`` (``created-by`` plus an
    ``aws-bootstrap-instance`` tag pointing at *instance_id*).

    Returns the volume ID.
    """
    volume_tags = [
        {"Key": "created-by", "Value": tag_value},
        {"Key": "Name", "Value": f"aws-bootstrap-data-{instance_id}"},
        {"Key": "aws-bootstrap-instance", "Value": instance_id},
    ]
    created = ec2_client.create_volume(
        AvailabilityZone=availability_zone,
        Size=size_gb,
        VolumeType="gp3",
        TagSpecifications=[{"ResourceType": "volume", "Tags": volume_tags}],
    )
    volume_id = created["VolumeId"]

    # Block until AWS reports the volume 'available' (5s poll, up to ~2 min).
    ec2_client.get_waiter("volume_available").wait(
        VolumeIds=[volume_id], WaiterConfig={"Delay": 5, "MaxAttempts": 24}
    )
    return volume_id
379
+
380
+
381
def validate_ebs_volume(ec2_client, volume_id: str, availability_zone: str) -> dict:
    """Validate that an existing EBS volume can be attached.

    Checks that the volume exists, is available (not in-use), and is in the
    correct availability zone. Returns the volume description dict.

    Raises CLIError for validation failures.
    """
    try:
        described = ec2_client.describe_volumes(VolumeIds=[volume_id])
    except botocore.exceptions.ClientError as e:
        # Newer API versions raise instead of returning an empty list.
        if e.response["Error"]["Code"] == "InvalidVolume.NotFound":
            raise CLIError(f"EBS volume not found: {volume_id}") from None
        raise

    found = described["Volumes"]
    if not found:
        raise CLIError(f"EBS volume not found: {volume_id}")

    volume = found[0]

    state = volume["State"]
    if state != "available":
        raise CLIError(
            f"EBS volume {volume_id} is currently '{state}' (must be 'available').\n"
            " Detach it from its current instance first."
        )

    # EBS attachment is AZ-local; a cross-AZ attach would fail later anyway.
    zone = volume["AvailabilityZone"]
    if zone != availability_zone:
        raise CLIError(
            f"EBS volume {volume_id} is in {zone} "
            f"but the instance is in {availability_zone}.\n"
            " EBS volumes must be in the same availability zone as the instance."
        )

    return volume
416
+
417
+
418
def attach_ebs_volume(ec2_client, volume_id: str, instance_id: str, device_name: str = EBS_DEVICE_NAME) -> None:
    """Attach an EBS volume to an instance and wait for it to be in-use."""
    ec2_client.attach_volume(VolumeId=volume_id, InstanceId=instance_id, Device=device_name)
    # Poll until the attachment completes (volume state flips to 'in-use').
    in_use = ec2_client.get_waiter("volume_in_use")
    in_use.wait(VolumeIds=[volume_id], WaiterConfig={"Delay": 5, "MaxAttempts": 24})
427
+
428
+
429
def detach_ebs_volume(ec2_client, volume_id: str) -> None:
    """Detach an EBS volume and wait for it to become available."""
    ec2_client.detach_volume(VolumeId=volume_id)
    # A detached volume returns to 'available' once the OS releases it.
    available = ec2_client.get_waiter("volume_available")
    available.wait(VolumeIds=[volume_id], WaiterConfig={"Delay": 5, "MaxAttempts": 24})
434
+
435
+
436
def delete_ebs_volume(ec2_client, volume_id: str) -> None:
    """Permanently delete an EBS volume (no waiter; deletion is fire-and-forget)."""
    ec2_client.delete_volume(VolumeId=volume_id)
439
+
440
+
441
def find_ebs_volumes_for_instance(ec2_client, instance_id: str, tag_value: str) -> list[dict]:
    """Find EBS data volumes associated with an instance via tags.

    Returns a list of dicts with VolumeId, Size, Device, and State.
    Excludes root volumes (only returns volumes tagged by aws-bootstrap).
    """
    tag_filters = [
        {"Name": "tag:aws-bootstrap-instance", "Values": [instance_id]},
        {"Name": "tag:created-by", "Values": [tag_value]},
    ]
    try:
        described = ec2_client.describe_volumes(Filters=tag_filters)
    except botocore.exceptions.ClientError:
        # Best-effort lookup: an API failure is treated as "no volumes".
        return []

    results = []
    for volume in described.get("Volumes", []):
        attachments = volume.get("Attachments")
        results.append(
            {
                "VolumeId": volume["VolumeId"],
                "Size": volume["Size"],
                # Unattached volumes report an empty device string.
                "Device": attachments[0].get("Device", "") if attachments else "",
                "State": volume["State"],
            }
        )
    return results
@@ -0,0 +1,106 @@
1
+ """Output formatting for structured CLI output (JSON, YAML, table, text)."""
2
+
3
+ from __future__ import annotations
4
+ import json
5
+ from datetime import datetime
6
+ from enum import StrEnum
7
+ from pathlib import Path
8
+ from typing import Any
9
+
10
+ import click
11
+
12
+
13
class OutputFormat(StrEnum):
    """Supported CLI output formats; values double as the user-facing flag strings."""

    # NOTE: StrEnum requires Python 3.11+.
    # TEXT is the default; echo()/secho() below only print in this mode,
    # while emit() is a no-op for it.
    TEXT = "text"
    JSON = "json"
    YAML = "yaml"
    TABLE = "table"
18
+
19
+
20
def get_format(ctx: click.Context | None = None) -> OutputFormat:
    """Return the current output format from the click context.

    Falls back to ``OutputFormat.TEXT`` when no context is active or the
    context carries no ``obj`` dict (e.g. in unit tests or library use).
    """
    context = ctx if ctx is not None else click.get_current_context(silent=True)
    if context is None or context.obj is None:
        return OutputFormat.TEXT
    return context.obj.get("output_format", OutputFormat.TEXT)
27
+
28
+
29
def is_text(ctx: click.Context | None = None) -> bool:
    """Return True if the current output format is text (default)."""
    current = get_format(ctx)
    return current == OutputFormat.TEXT
32
+
33
+
34
+ def _default_serializer(obj: Any) -> Any:
35
+ """JSON serializer for objects not serializable by default."""
36
+ if isinstance(obj, datetime):
37
+ return obj.isoformat()
38
+ if isinstance(obj, Path):
39
+ return str(obj)
40
+ raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")
41
+
42
+
43
def emit(data: dict | list, *, headers: dict[str, str] | None = None, ctx: click.Context | None = None) -> None:
    """Emit structured data in the configured output format.

    For JSON/YAML: serializes the data directly.
    For TABLE: renders using tabulate. If *data* is a list of dicts, uses
    *headers* mapping ``{dict_key: column_label}`` for column selection/ordering.
    If *data* is a single dict, renders as key-value pairs.
    In TEXT mode this function does nothing (text output is printed inline
    by the callers via echo()/secho()).
    """
    fmt = get_format(ctx)

    if fmt == OutputFormat.JSON:
        click.echo(json.dumps(data, indent=2, default=_default_serializer))
        return

    if fmt == OutputFormat.YAML:
        import yaml  # noqa: PLC0415

        # Convert datetime/Path objects before YAML dump
        # (round-tripping through JSON reuses _default_serializer, so YAML
        # output stays consistent with the JSON format).
        prepared = json.loads(json.dumps(data, default=_default_serializer))
        click.echo(yaml.dump(prepared, default_flow_style=False, sort_keys=False).rstrip())
        return

    if fmt == OutputFormat.TABLE:
        from tabulate import tabulate  # noqa: PLC0415

        table_data = data
        # Unwrap dict-wrapped lists (e.g. {"instances": [...]}) for table rendering
        # — the first list-valued entry wins; only attempted when *headers* is given.
        if isinstance(data, dict) and headers:
            for v in data.values():
                if isinstance(v, list):
                    table_data = v
                    break

        if isinstance(table_data, list) and table_data and isinstance(table_data[0], dict):
            if headers:
                # headers selects AND orders the columns; missing keys render blank.
                keys = list(headers.keys())
                col_labels = list(headers.values())
                rows = [[row.get(k, "") for k in keys] for row in table_data]
            else:
                # No headers: take the column set from the first row.
                col_labels = list(table_data[0].keys())
                keys = col_labels
                rows = [[row.get(k, "") for k in keys] for row in table_data]
            click.echo(tabulate(rows, headers=col_labels, tablefmt="simple"))
        elif isinstance(table_data, dict):
            # Single dict: render as a two-column Key/Value table.
            rows = [[k, v] for k, v in table_data.items()]
            click.echo(tabulate(rows, headers=["Key", "Value"], tablefmt="simple"))
        elif isinstance(table_data, list):
            # Empty list
            # NOTE(review): a non-empty list of non-dict rows also lands here and
            # prints "(no data)" — confirm callers only pass lists of dicts.
            click.echo("(no data)")
        return

    # TEXT format: emit() is a no-op in text mode (text output is handled inline)
95
+
96
+
97
def echo(msg: str = "", **kwargs: Any) -> None:
    """Wrap ``click.echo``; silent in non-text output modes."""
    if not is_text():
        return
    click.echo(msg, **kwargs)
101
+
102
+
103
def secho(msg: str = "", **kwargs: Any) -> None:
    """Wrap ``click.secho``; silent in non-text output modes."""
    if not is_text():
        return
    click.secho(msg, **kwargs)
@@ -48,8 +48,8 @@ fi
48
48
  # 2. Install utilities
49
49
  echo ""
50
50
  echo "[2/6] Installing utilities..."
51
- sudo apt-get update -qq
52
- sudo apt-get install -y -qq htop tmux tree jq
51
+ sudo DEBIAN_FRONTEND=noninteractive apt-get update -qq
52
+ sudo DEBIAN_FRONTEND=noninteractive apt-get install -y -qq htop tmux tree jq ffmpeg
53
53
 
54
54
  # 3. Set up Python environment with uv
55
55
  echo ""
aws_bootstrap/ssh.py CHANGED
@@ -13,6 +13,7 @@ from pathlib import Path
13
13
  import click
14
14
 
15
15
  from .gpu import _GPU_ARCHITECTURES, GpuInfo
16
+ from .output import echo, secho
16
17
 
17
18
 
18
19
  # ---------------------------------------------------------------------------
@@ -54,7 +55,7 @@ def import_key_pair(ec2_client, key_name: str, key_path: Path) -> str:
54
55
  # Check if key pair already exists
55
56
  try:
56
57
  existing = ec2_client.describe_key_pairs(KeyNames=[key_name])
57
- click.echo(" Key pair " + click.style(f"'{key_name}'", fg="bright_white") + " already exists, reusing.")
58
+ echo(" Key pair " + click.style(f"'{key_name}'", fg="bright_white") + " already exists, reusing.")
58
59
  return existing["KeyPairs"][0]["KeyName"]
59
60
  except ec2_client.exceptions.ClientError as e:
60
61
  if "InvalidKeyPair.NotFound" not in str(e):
@@ -70,7 +71,7 @@ def import_key_pair(ec2_client, key_name: str, key_path: Path) -> str:
70
71
  }
71
72
  ],
72
73
  )
73
- click.secho(f" Imported key pair '{key_name}' from {key_path}", fg="green")
74
+ secho(f" Imported key pair '{key_name}' from {key_path}", fg="green")
74
75
  return key_name
75
76
 
76
77
 
@@ -88,7 +89,7 @@ def wait_for_ssh(host: str, user: str, key_path: Path, retries: int = 30, delay:
88
89
  sock = socket.create_connection((host, port), timeout=5)
89
90
  sock.close()
90
91
  except (TimeoutError, ConnectionRefusedError, OSError):
91
- click.echo(" SSH not ready " + click.style(f"(attempt {attempt}/{retries})", dim=True) + ", waiting...")
92
+ echo(" SSH not ready " + click.style(f"(attempt {attempt}/{retries})", dim=True) + ", waiting...")
92
93
  time.sleep(delay)
93
94
  continue
94
95
 
@@ -106,10 +107,10 @@ def wait_for_ssh(host: str, user: str, key_path: Path, retries: int = 30, delay:
106
107
  ]
107
108
  result = subprocess.run(cmd, capture_output=True, text=True)
108
109
  if result.returncode == 0:
109
- click.secho(" SSH connection established.", fg="green")
110
+ secho(" SSH connection established.", fg="green")
110
111
  return True
111
112
 
112
- click.echo(" SSH not ready " + click.style(f"(attempt {attempt}/{retries})", dim=True) + ", waiting...")
113
+ echo(" SSH not ready " + click.style(f"(attempt {attempt}/{retries})", dim=True) + ", waiting...")
113
114
  time.sleep(delay)
114
115
 
115
116
  return False
@@ -125,89 +126,89 @@ def run_remote_setup(
125
126
  requirements_path = script_path.parent / "requirements.txt"
126
127
 
127
128
  # SCP the requirements file
128
- click.echo(" Uploading requirements.txt...")
129
+ echo(" Uploading requirements.txt...")
129
130
  req_result = subprocess.run(
130
131
  ["scp", *ssh_opts, *scp_port_opts, str(requirements_path), f"{user}@{host}:/tmp/requirements.txt"],
131
132
  capture_output=True,
132
133
  text=True,
133
134
  )
134
135
  if req_result.returncode != 0:
135
- click.secho(f" SCP failed: {req_result.stderr}", fg="red", err=True)
136
+ secho(f" SCP failed: {req_result.stderr}", fg="red", err=True)
136
137
  return False
137
138
 
138
139
  # SCP the GPU benchmark script
139
140
  benchmark_path = script_path.parent / "gpu_benchmark.py"
140
- click.echo(" Uploading gpu_benchmark.py...")
141
+ echo(" Uploading gpu_benchmark.py...")
141
142
  bench_result = subprocess.run(
142
143
  ["scp", *ssh_opts, *scp_port_opts, str(benchmark_path), f"{user}@{host}:/tmp/gpu_benchmark.py"],
143
144
  capture_output=True,
144
145
  text=True,
145
146
  )
146
147
  if bench_result.returncode != 0:
147
- click.secho(f" SCP failed: {bench_result.stderr}", fg="red", err=True)
148
+ secho(f" SCP failed: {bench_result.stderr}", fg="red", err=True)
148
149
  return False
149
150
 
150
151
  # SCP the GPU smoke test notebook
151
152
  notebook_path = script_path.parent / "gpu_smoke_test.ipynb"
152
- click.echo(" Uploading gpu_smoke_test.ipynb...")
153
+ echo(" Uploading gpu_smoke_test.ipynb...")
153
154
  nb_result = subprocess.run(
154
155
  ["scp", *ssh_opts, *scp_port_opts, str(notebook_path), f"{user}@{host}:/tmp/gpu_smoke_test.ipynb"],
155
156
  capture_output=True,
156
157
  text=True,
157
158
  )
158
159
  if nb_result.returncode != 0:
159
- click.secho(f" SCP failed: {nb_result.stderr}", fg="red", err=True)
160
+ secho(f" SCP failed: {nb_result.stderr}", fg="red", err=True)
160
161
  return False
161
162
 
162
163
  # SCP the CUDA example source
163
164
  saxpy_path = script_path.parent / "saxpy.cu"
164
- click.echo(" Uploading saxpy.cu...")
165
+ echo(" Uploading saxpy.cu...")
165
166
  saxpy_result = subprocess.run(
166
167
  ["scp", *ssh_opts, *scp_port_opts, str(saxpy_path), f"{user}@{host}:/tmp/saxpy.cu"],
167
168
  capture_output=True,
168
169
  text=True,
169
170
  )
170
171
  if saxpy_result.returncode != 0:
171
- click.secho(f" SCP failed: {saxpy_result.stderr}", fg="red", err=True)
172
+ secho(f" SCP failed: {saxpy_result.stderr}", fg="red", err=True)
172
173
  return False
173
174
 
174
175
  # SCP the VSCode launch.json
175
176
  launch_json_path = script_path.parent / "launch.json"
176
- click.echo(" Uploading launch.json...")
177
+ echo(" Uploading launch.json...")
177
178
  launch_result = subprocess.run(
178
179
  ["scp", *ssh_opts, *scp_port_opts, str(launch_json_path), f"{user}@{host}:/tmp/launch.json"],
179
180
  capture_output=True,
180
181
  text=True,
181
182
  )
182
183
  if launch_result.returncode != 0:
183
- click.secho(f" SCP failed: {launch_result.stderr}", fg="red", err=True)
184
+ secho(f" SCP failed: {launch_result.stderr}", fg="red", err=True)
184
185
  return False
185
186
 
186
187
  # SCP the VSCode tasks.json
187
188
  tasks_json_path = script_path.parent / "tasks.json"
188
- click.echo(" Uploading tasks.json...")
189
+ echo(" Uploading tasks.json...")
189
190
  tasks_result = subprocess.run(
190
191
  ["scp", *ssh_opts, *scp_port_opts, str(tasks_json_path), f"{user}@{host}:/tmp/tasks.json"],
191
192
  capture_output=True,
192
193
  text=True,
193
194
  )
194
195
  if tasks_result.returncode != 0:
195
- click.secho(f" SCP failed: {tasks_result.stderr}", fg="red", err=True)
196
+ secho(f" SCP failed: {tasks_result.stderr}", fg="red", err=True)
196
197
  return False
197
198
 
198
199
  # SCP the script
199
- click.echo(" Uploading remote_setup.sh...")
200
+ echo(" Uploading remote_setup.sh...")
200
201
  scp_result = subprocess.run(
201
202
  ["scp", *ssh_opts, *scp_port_opts, str(script_path), f"{user}@{host}:/tmp/remote_setup.sh"],
202
203
  capture_output=True,
203
204
  text=True,
204
205
  )
205
206
  if scp_result.returncode != 0:
206
- click.secho(f" SCP failed: {scp_result.stderr}", fg="red", err=True)
207
+ secho(f" SCP failed: {scp_result.stderr}", fg="red", err=True)
207
208
  return False
208
209
 
209
210
  # Execute the script, passing PYTHON_VERSION as an inline env var if specified
210
- click.echo(" Running remote_setup.sh on instance...")
211
+ echo(" Running remote_setup.sh on instance...")
211
212
  remote_cmd = "chmod +x /tmp/remote_setup.sh && "
212
213
  if python_version:
213
214
  remote_cmd += f"PYTHON_VERSION={python_version} "
@@ -374,6 +375,37 @@ def list_ssh_hosts(config_path: Path | None = None) -> dict[str, str]:
374
375
  return result
375
376
 
376
377
 
378
def find_stale_ssh_hosts(live_instance_ids: set[str], config_path: Path | None = None) -> list[tuple[str, str]]:
    """Identify SSH config entries whose instances no longer exist.

    Returns ``[(instance_id, alias), ...]`` for entries where the instance ID
    is **not** in *live_instance_ids*, sorted by alias.
    """
    known = list_ssh_hosts(config_path)
    return sorted(
        ((instance_id, alias) for instance_id, alias in known.items() if instance_id not in live_instance_ids),
        key=lambda entry: entry[1],
    )
388
+
389
+
390
def cleanup_stale_ssh_hosts(
    live_instance_ids: set[str],
    config_path: Path | None = None,
    dry_run: bool = False,
) -> list[CleanupResult]:
    """Remove SSH config entries for terminated/non-existent instances.

    If *dry_run* is ``True``, entries are identified but not removed.
    Returns a list of :class:`CleanupResult` objects.
    """
    outcomes: list[CleanupResult] = []
    for instance_id, alias in find_stale_ssh_hosts(live_instance_ids, config_path):
        if not dry_run:
            remove_ssh_host(instance_id, config_path)
        # removed mirrors whether we actually touched the config file.
        outcomes.append(CleanupResult(instance_id=instance_id, alias=alias, removed=not dry_run))
    return outcomes
407
+
408
+
377
409
  _INSTANCE_ID_RE = re.compile(r"^i-[0-9a-f]{8,17}$")
378
410
 
379
411
 
@@ -402,6 +434,15 @@ def resolve_instance_id(value: str, config_path: Path | None = None) -> str | No
402
434
  return None
403
435
 
404
436
 
437
@dataclass
class CleanupResult:
    """Result of cleaning up a single stale SSH config entry."""

    # EC2 instance ID whose SSH config entry was (or would be) removed.
    instance_id: str
    # Host alias of that entry as it appeared in the SSH config.
    alias: str
    # True if the entry was actually removed; False on a dry run.
    removed: bool
444
+
445
+
405
446
  @dataclass
406
447
  class SSHHostDetails:
407
448
  """Connection details parsed from an SSH config stanza."""
@@ -515,6 +556,87 @@ def query_gpu_info(host: str, user: str, key_path: Path, timeout: int = 10, port
515
556
  return None
516
557
 
517
558
 
559
+ # ---------------------------------------------------------------------------
560
+ # EBS volume mount
561
+ # ---------------------------------------------------------------------------
562
+
563
+
564
def mount_ebs_volume(
    host: str,
    user: str,
    key_path: Path,
    volume_id: str,
    mount_point: str = "/data",
    format_volume: bool = True,
    port: int = 22,
) -> bool:
    """Mount an EBS volume on the remote instance via SSH.

    Detects the NVMe device by volume ID serial, formats if requested,
    mounts at *mount_point*, and adds an fstab entry for persistence.

    Only formats when the device has no existing filesystem (blkid probe),
    so reattaching a data volume preserves its contents.

    Returns True on success, False on failure.
    """
    ssh_opts = _ssh_opts(key_path)
    port_opts = ["-p", str(port)] if port != 22 else []

    # NVMe serial is the volume ID with the hyphen removed
    # (e.g. "vol-0abc..." -> "vol0abc..."); only the hyphen is stripped,
    # the "vol" prefix is kept.
    vol_serial = volume_id.replace("-", "")

    format_cmd = ""
    if format_volume:
        # Format only if blkid finds no filesystem signature on the device.
        format_cmd = (
            ' if ! sudo blkid "$DEVICE" > /dev/null 2>&1; then\n'
            ' echo "Formatting $DEVICE as ext4..."\n'
            ' sudo mkfs.ext4 "$DEVICE"\n'
            " fi\n"
        )

    remote_script = (
        "set -e\n"
        "# Detect EBS device by NVMe serial (Nitro instances)\n"
        f'SERIAL="{vol_serial}"\n'
        "DEVICE=$(lsblk -o NAME,SERIAL -dpn 2>/dev/null | "
        "awk -v s=\"$SERIAL\" '$2 == s {print $1}' | head -1)\n"
        "# Fallback to common device paths\n"
        'if [ -z "$DEVICE" ]; then\n'
        " for dev in /dev/nvme1n1 /dev/xvdf /dev/sdf; do\n"
        ' if [ -b "$dev" ]; then DEVICE="$dev"; break; fi\n'
        " done\n"
        "fi\n"
        'if [ -z "$DEVICE" ]; then\n'
        ' echo "ERROR: Could not find EBS device" >&2\n'
        " exit 1\n"
        "fi\n"
        'echo "Found EBS device: $DEVICE"\n'
        f"{format_cmd}"
        f"sudo mkdir -p {mount_point}\n"
        f'sudo mount "$DEVICE" {mount_point}\n'
        f"sudo chown {user}:{user} {mount_point}\n"
        "# Add fstab entry for reboot persistence\n"
        'UUID=$(sudo blkid -s UUID -o value "$DEVICE")\n'
        'if [ -n "$UUID" ]; then\n'
        # (f-prefix on the next two lines is redundant — no placeholders / only
        # mount_point — but harmless.)
        f' if ! grep -q "$UUID" /etc/fstab; then\n'
        f' echo "UUID=$UUID {mount_point} ext4 defaults,nofail 0 2" | sudo tee -a /etc/fstab > /dev/null\n'
        " fi\n"
        "fi\n"
        f'echo "Mounted $DEVICE at {mount_point}"'
    )

    cmd = [
        "ssh",
        *ssh_opts,
        *port_opts,
        "-o",
        "ConnectTimeout=10",
        f"{user}@{host}",
        remote_script,
    ]

    # Stream output to the terminal (capture_output=False) so the user sees
    # formatting/mount progress live; success is the ssh exit status.
    result = subprocess.run(cmd, capture_output=False)
    return result.returncode == 0
638
+
639
+
518
640
  # ---------------------------------------------------------------------------
519
641
  # Internal helpers
520
642
  # ---------------------------------------------------------------------------