gpu-dev 0.3.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gpu_dev-0.3.5.dist-info/METADATA +687 -0
- gpu_dev-0.3.5.dist-info/RECORD +14 -0
- gpu_dev-0.3.5.dist-info/WHEEL +5 -0
- gpu_dev-0.3.5.dist-info/entry_points.txt +4 -0
- gpu_dev-0.3.5.dist-info/top_level.txt +1 -0
- gpu_dev_cli/__init__.py +9 -0
- gpu_dev_cli/auth.py +158 -0
- gpu_dev_cli/cli.py +3754 -0
- gpu_dev_cli/config.py +248 -0
- gpu_dev_cli/disks.py +523 -0
- gpu_dev_cli/interactive.py +702 -0
- gpu_dev_cli/name_generator.py +117 -0
- gpu_dev_cli/reservations.py +2231 -0
- gpu_dev_cli/ssh_proxy.py +106 -0
|
@@ -0,0 +1,2231 @@
|
|
|
1
|
+
"""Minimal reservation management for GPU Dev CLI"""
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
import os
|
|
5
|
+
import select
|
|
6
|
+
import signal
|
|
7
|
+
import sys
|
|
8
|
+
import time
|
|
9
|
+
import uuid
|
|
10
|
+
from datetime import datetime, timedelta
|
|
11
|
+
from decimal import Decimal
|
|
12
|
+
from typing import Optional, List, Dict, Any, Union
|
|
13
|
+
|
|
14
|
+
from botocore.exceptions import ClientError
|
|
15
|
+
from rich.console import Console
|
|
16
|
+
from rich.live import Live
|
|
17
|
+
from rich.spinner import Spinner
|
|
18
|
+
|
|
19
|
+
from .config import Config
|
|
20
|
+
from .name_generator import sanitize_name
|
|
21
|
+
from . import __version__
|
|
22
|
+
|
|
23
|
+
console = Console()
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def _make_vscode_link(pod_name: str) -> str:
|
|
27
|
+
"""Create a clickable vscode:// URL for opening remote SSH connection
|
|
28
|
+
|
|
29
|
+
Args:
|
|
30
|
+
pod_name: The SSH host name (e.g., gpu-dev-34f5f9e0)
|
|
31
|
+
|
|
32
|
+
Returns:
|
|
33
|
+
A vscode:// URL that opens VS Code with the remote SSH connection
|
|
34
|
+
"""
|
|
35
|
+
# VS Code remote SSH URL format: vscode://vscode-remote/ssh-remote+<host>/path
|
|
36
|
+
return f"vscode://vscode-remote/ssh-remote+{pod_name}/home/dev"
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def _make_cursor_link(pod_name: str) -> str:
|
|
40
|
+
"""Create a clickable cursor:// URL for opening remote SSH connection in Cursor
|
|
41
|
+
|
|
42
|
+
Args:
|
|
43
|
+
pod_name: The SSH host name (e.g., gpu-dev-34f5f9e0)
|
|
44
|
+
|
|
45
|
+
Returns:
|
|
46
|
+
A cursor:// URL that opens Cursor with the remote SSH connection
|
|
47
|
+
"""
|
|
48
|
+
# Based on VS Code remote SSH URL format: cursor://vscode-remote/ssh-remote+<host>/path
|
|
49
|
+
return f"cursor://vscode-remote/ssh-remote+{pod_name}/home/dev"
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def get_version() -> str:
    """Get CLI version for inclusion in SQS messages"""
    # Single source of truth is the package's __version__; the backend uses
    # this field to detect outdated CLI clients.
    return __version__
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def _add_agent_forwarding_to_ssh(ssh_command: str) -> str:
|
|
58
|
+
"""Add SSH agent forwarding (-A) flag to SSH command if not already present"""
|
|
59
|
+
try:
|
|
60
|
+
if not ssh_command or not ssh_command.startswith("ssh "):
|
|
61
|
+
return ssh_command
|
|
62
|
+
|
|
63
|
+
# Check if -A is already in the command
|
|
64
|
+
if " -A" in ssh_command or ssh_command.endswith(" -A"):
|
|
65
|
+
return ssh_command
|
|
66
|
+
|
|
67
|
+
# Add -A flag after 'ssh'
|
|
68
|
+
parts = ssh_command.split(" ", 1)
|
|
69
|
+
if len(parts) == 2:
|
|
70
|
+
return f"ssh -A {parts[1]}"
|
|
71
|
+
else:
|
|
72
|
+
return "ssh -A"
|
|
73
|
+
|
|
74
|
+
except Exception:
|
|
75
|
+
return ssh_command
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
def _extract_latest_pod_event(pod_events: str) -> str:
|
|
79
|
+
"""Extract the most relevant pod event for display - simplified since Lambda now provides formatted messages"""
|
|
80
|
+
if not pod_events:
|
|
81
|
+
return "Starting pod..."
|
|
82
|
+
|
|
83
|
+
# Lambda now provides pre-formatted messages, so just return them
|
|
84
|
+
# Handle multi-line messages by taking the first meaningful line
|
|
85
|
+
lines = pod_events.split("\n")
|
|
86
|
+
for line in lines:
|
|
87
|
+
line = line.strip()
|
|
88
|
+
if line and not line.startswith("Events:"):
|
|
89
|
+
return line
|
|
90
|
+
|
|
91
|
+
return "Starting pod..."
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
def _generate_vscode_command(ssh_command: str) -> Optional[str]:
|
|
95
|
+
"""Generate VS Code remote connection command from SSH command"""
|
|
96
|
+
try:
|
|
97
|
+
# Extract remote server from SSH command
|
|
98
|
+
# Expected format: ssh dev@<hostname> or various formats with -o options
|
|
99
|
+
if not ssh_command or not ssh_command.startswith("ssh "):
|
|
100
|
+
return None
|
|
101
|
+
|
|
102
|
+
# Parse SSH command to extract hostname
|
|
103
|
+
parts = ssh_command.split()
|
|
104
|
+
hostname = None
|
|
105
|
+
|
|
106
|
+
for i, part in enumerate(parts):
|
|
107
|
+
if "@" in part and not part.startswith("-"):
|
|
108
|
+
# Extract just the hostname part (e.g., from dev@hostname.io)
|
|
109
|
+
hostname = part.split("@")[1]
|
|
110
|
+
break
|
|
111
|
+
|
|
112
|
+
if not hostname:
|
|
113
|
+
return None
|
|
114
|
+
|
|
115
|
+
# Generate VS Code command with ProxyCommand and agent forwarding
|
|
116
|
+
# VS Code will use the ssh command options we provide
|
|
117
|
+
remote_server = f"dev@{hostname}"
|
|
118
|
+
|
|
119
|
+
# Escape single quotes in the ProxyCommand for shell
|
|
120
|
+
proxy_cmd = "gpu-dev-ssh-proxy %h %p"
|
|
121
|
+
|
|
122
|
+
return (
|
|
123
|
+
f"code --remote ssh-remote+{remote_server} "
|
|
124
|
+
f"--ssh-option ForwardAgent=yes "
|
|
125
|
+
f"--ssh-option ProxyCommand='{proxy_cmd}' "
|
|
126
|
+
f"--ssh-option StrictHostKeyChecking=no "
|
|
127
|
+
f"--ssh-option UserKnownHostsFile=/dev/null "
|
|
128
|
+
f"/home/dev"
|
|
129
|
+
)
|
|
130
|
+
|
|
131
|
+
except Exception:
|
|
132
|
+
return None
|
|
133
|
+
|
|
134
|
+
|
|
135
|
+
def _generate_cursor_command(ssh_command: str) -> Optional[str]:
|
|
136
|
+
"""Generate Cursor remote connection command from SSH command"""
|
|
137
|
+
try:
|
|
138
|
+
# Extract remote server from SSH command
|
|
139
|
+
# Expected format: ssh dev@<hostname> or various formats with -o options
|
|
140
|
+
if not ssh_command or not ssh_command.startswith("ssh "):
|
|
141
|
+
return None
|
|
142
|
+
|
|
143
|
+
# Parse SSH command to extract hostname
|
|
144
|
+
parts = ssh_command.split()
|
|
145
|
+
remote_server = parts[-1]
|
|
146
|
+
if '@' in remote_server:
|
|
147
|
+
remote_server = remote_server.split('@')[1]
|
|
148
|
+
|
|
149
|
+
# Return the VS Code command format
|
|
150
|
+
return f"cursor --remote ssh-remote+{remote_server} /home/dev"
|
|
151
|
+
except Exception:
|
|
152
|
+
return None
|
|
153
|
+
|
|
154
|
+
|
|
155
|
+
def _generate_ssh_config(hostname: str, pod_name: str) -> str:
    """Generate SSH config for a reservation

    Args:
        hostname: The FQDN hostname (e.g., old_bison.devservers.io)
        pod_name: The pod name to use as SSH host alias

    Returns:
        SSH config content as string
    """
    # Connections go through the gpu-dev-ssh-proxy helper; host-key checking
    # is disabled because pods are ephemeral and keys change per reservation.
    # NOTE(review): leading whitespace on the option lines is cosmetic to
    # OpenSSH (options may be indented or flush-left) — confirm against the
    # released file if byte-exact output matters.
    config_content = f"""Host {pod_name}
    HostName {hostname}
    User dev
    ForwardAgent yes
    ProxyCommand gpu-dev-ssh-proxy %h %p
    StrictHostKeyChecking no
    UserKnownHostsFile /dev/null
"""
    return config_content
|
|
174
|
+
|
|
175
|
+
|
|
176
|
+
def _check_ssh_config_permission() -> bool:
    """Check if user has given permission to modify ~/.ssh/config and ~/.cursor/ssh_config

    The decision is cached in ~/.gpu-dev/.ssh-config-permission ("yes"/"no").
    If no cached answer exists but an Include directive is already present in
    one of the SSH config files, permission is assumed and cached; otherwise
    the user is prompted interactively once.

    Returns:
        True if permission granted or already set up, False otherwise
    """
    import click
    from pathlib import Path

    gpu_dev_dir = Path.home() / ".gpu-dev"
    permission_file = gpu_dev_dir / ".ssh-config-permission"

    # Check if already asked and answered
    if permission_file.exists():
        try:
            response = permission_file.read_text().strip()
            return response == "yes"
        except Exception:
            # Unreadable cache file: fall through and re-detect / re-ask.
            pass

    # Check if Include already exists in either ~/.ssh/config or ~/.cursor/ssh_config
    config_files = [
        Path.home() / ".ssh" / "config",
        Path.home() / ".cursor" / "ssh_config",
    ]

    for ssh_config in config_files:
        if ssh_config.exists():
            try:
                content = ssh_config.read_text()
                if "Include ~/.gpu-dev/" in content:
                    # Already set up, save permission
                    gpu_dev_dir.mkdir(mode=0o700, exist_ok=True)
                    permission_file.write_text("yes")
                    return True
            except Exception:
                # Unreadable config file: keep checking the other candidate.
                pass

    # Ask user for permission
    console.print("\n[yellow]━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[/yellow]")
    console.print("[cyan]🔧 SSH Configuration Setup[/cyan]\n")
    console.print("To enable easy SSH access and VS Code/Cursor Remote connections,")
    console.print("we can add GPU dev server configs to your SSH config files.")
    console.print("\n[dim]This adds one line at the top of:[/dim]")
    console.print("[dim]  • ~/.ssh/config[/dim]")
    console.print("[dim]  • ~/.cursor/ssh_config[/dim]")
    console.print("[dim]Line added: Include ~/.gpu-dev/*-sshconfig[/dim]\n")
    console.print("[green]Benefits:[/green]")
    console.print("  • Simple commands: [green]ssh <pod-name>[/green]")
    console.print("  • VS Code Remote works: [green]code --remote ssh-remote+<pod-name>[/green]")
    console.print("  • Cursor Remote works: Open Remote SSH in Cursor")
    console.print("\n[dim]Without this, you'll need to use: [green]ssh -F ~/.gpu-dev/<id>-sshconfig <pod-name>[/green][/dim]")
    console.print("[yellow]━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[/yellow]\n")

    approved = click.confirm("Add Include directive to SSH config files?", default=True)

    # Save response so the user is only prompted once.
    gpu_dev_dir.mkdir(mode=0o700, exist_ok=True)
    permission_file.write_text("yes" if approved else "no")

    return approved
|
|
237
|
+
|
|
238
|
+
|
|
239
|
+
def _ensure_ssh_config_includes_devgpu() -> bool:
    """Ensure ~/.ssh/config and ~/.cursor/ssh_config include ~/.devgpu/* configs for VS Code/Cursor compatibility

    Asks for (or reads cached) permission first; when granted, prepends
    "Include ~/.gpu-dev/*-sshconfig" to both config files if not already
    present. Failures on one file do not prevent updating the other.

    Returns:
        True if Include was added/exists, False if user declined
    """
    from pathlib import Path

    # Check permission first
    if not _check_ssh_config_permission():
        return False

    include_line = "Include ~/.gpu-dev/*-sshconfig\n"

    # List of config files to update: ~/.ssh/config and ~/.cursor/ssh_config
    config_files = [
        (Path.home() / ".ssh", "config"),
        (Path.home() / ".cursor", "ssh_config"),
    ]

    # Success means at least one file has (or received) the Include line.
    success = False
    for config_dir, config_name in config_files:
        try:
            # Create directory if it doesn't exist
            config_dir.mkdir(mode=0o700, exist_ok=True)

            config_file = config_dir / config_name

            # Read existing config or create empty
            if config_file.exists():
                content = config_file.read_text()
            else:
                content = ""

            # Check if Include already exists
            if "Include ~/.gpu-dev/" in content:
                success = True
                continue

            # Add Include at the top (must be first in SSH config)
            new_content = include_line + "\n" + content
            config_file.write_text(new_content)
            config_file.chmod(0o600)
            success = True
        except Exception:
            # If one fails, we still try the other
            pass

    return success
|
|
288
|
+
|
|
289
|
+
|
|
290
|
+
def create_ssh_config_for_reservation(hostname: str, pod_name: str, reservation_id: str, name: Optional[str] = None) -> tuple[Optional[str], bool]:
    """Create SSH config file for a reservation in ~/.gpu-dev/

    Args:
        hostname: The FQDN hostname (e.g., old_bison.devservers.io)
        pod_name: The pod name to use as SSH host alias
        reservation_id: The reservation ID (full or short)
        name: Optional reservation name to use for filename (falls back to short ID)

    Returns:
        Tuple of (config_path, use_include) where:
        - config_path: Path to the created config file, or None on error
        - use_include: True if ~/.ssh/config includes devgpu configs, False if need -F flag
    """
    from pathlib import Path

    # Create ~/.gpu-dev directory
    gpu_dev_dir = Path.home() / ".gpu-dev"
    gpu_dev_dir.mkdir(mode=0o700, exist_ok=True)

    # Use short ID for filename (always safe, avoids issues with special chars like / in names)
    # For multinode, names like "16x B200 multinode - Node 1/2" contain / which breaks filenames
    short_id = reservation_id[:8]
    filename = f"{short_id}-sshconfig"

    config_file = gpu_dev_dir / filename
    config_content = _generate_ssh_config(hostname, pod_name)

    try:
        # 0o600: SSH config files must not be group/world readable.
        config_file.write_text(config_content)
        config_file.chmod(0o600)

        # Check/ask permission to include in ~/.ssh/config
        use_include = _ensure_ssh_config_includes_devgpu()

        return (str(config_file), use_include)
    except Exception:
        return (None, False)
|
|
328
|
+
|
|
329
|
+
|
|
330
|
+
def remove_ssh_config_for_reservation(reservation_id: str, name: Optional[str] = None) -> bool:
    """Remove SSH config file for a reservation.

    Args:
        reservation_id: The reservation ID (full or short); only the first
            eight characters are used for the filename.
        name: Unused; kept for backwards compatibility with older callers.

    Returns:
        True if successful (or file didn't exist), False on error
    """
    from pathlib import Path

    # Always use short ID for filename (consistent with create_ssh_config_for_reservation)
    short_id = reservation_id[:8]
    filename = f"{short_id}-sshconfig"

    config_file = Path.home() / ".gpu-dev" / filename

    try:
        if config_file.exists():
            config_file.unlink()
        # Fix: the missing-file path previously fell through and implicitly
        # returned None, contradicting the documented contract above.
        return True
    except Exception:
        return False
|
|
354
|
+
|
|
355
|
+
|
|
356
|
+
def is_ssh_include_enabled() -> bool:
    """Report whether the user approved the SSH config Include directive.

    Reads the sentinel file ~/.gpu-dev/.ssh-config-permission; anything other
    than a readable file containing "yes" counts as disabled.
    """
    from pathlib import Path

    marker = Path.home() / ".gpu-dev" / ".ssh-config-permission"
    try:
        if marker.exists():
            return marker.read_text().strip() == "yes"
    except Exception:
        pass
    return False
|
|
371
|
+
|
|
372
|
+
|
|
373
|
+
def get_ssh_config_path(reservation_id: str, name: Optional[str] = None) -> str:
    """Return the path of a reservation's SSH config file (may not exist).

    Args:
        reservation_id: The reservation ID (full or short); only the first
            eight characters are used for the filename.
        name: Unused; kept for backwards compatibility with older callers.

    Returns:
        Path to the config file (may not exist)
    """
    from pathlib import Path

    # Filenames are keyed by the short ID, matching create_ssh_config_for_reservation.
    return str(Path.home() / ".gpu-dev" / f"{reservation_id[:8]}-sshconfig")
|
|
388
|
+
|
|
389
|
+
|
|
390
|
+
class ReservationManager:
|
|
391
|
+
"""Minimal GPU reservations manager - AWS-only"""
|
|
392
|
+
|
|
393
|
+
    def __init__(self, config: Config):
        """Bind the manager to its AWS resources.

        Args:
            config: Project Config providing the DynamoDB resource, SQS
                client, and table/queue names used throughout this class.
        """
        self.config = config
        # DynamoDB Table handle for the reservations table named in config.
        self.reservations_table = config.dynamodb.Table(
            config.reservations_table)
|
|
397
|
+
|
|
398
|
+
def create_reservation(
|
|
399
|
+
self,
|
|
400
|
+
user_id: str,
|
|
401
|
+
gpu_count: int,
|
|
402
|
+
gpu_type: str,
|
|
403
|
+
duration_hours: Union[int, float],
|
|
404
|
+
name: Optional[str] = None,
|
|
405
|
+
github_user: Optional[str] = None,
|
|
406
|
+
jupyter_enabled: bool = False,
|
|
407
|
+
recreate_env: bool = False,
|
|
408
|
+
dockerfile: Optional[str] = None,
|
|
409
|
+
no_persistent_disk: bool = False,
|
|
410
|
+
dockerimage: Optional[str] = None,
|
|
411
|
+
preserve_entrypoint: bool = False,
|
|
412
|
+
disk_name: Optional[str] = None,
|
|
413
|
+
node_labels: Optional[Dict[str, str]] = None,
|
|
414
|
+
) -> Optional[str]:
|
|
415
|
+
"""Create a new GPU reservation"""
|
|
416
|
+
try:
|
|
417
|
+
reservation_id = str(uuid.uuid4())
|
|
418
|
+
created_at = datetime.utcnow().isoformat()
|
|
419
|
+
|
|
420
|
+
# Process the name: sanitize user input or let Lambda generate
|
|
421
|
+
processed_name = None
|
|
422
|
+
if name:
|
|
423
|
+
# Sanitize user-provided name
|
|
424
|
+
processed_name = sanitize_name(name)
|
|
425
|
+
# If sanitization results in empty string, let Lambda generate
|
|
426
|
+
if not processed_name:
|
|
427
|
+
processed_name = None
|
|
428
|
+
# If no name provided, let Lambda generate (processed_name stays None)
|
|
429
|
+
|
|
430
|
+
# Create initial reservation record for polling
|
|
431
|
+
# Convert float to Decimal for DynamoDB compatibility
|
|
432
|
+
duration_decimal = Decimal(str(duration_hours))
|
|
433
|
+
|
|
434
|
+
initial_reservation = {
|
|
435
|
+
"reservation_id": reservation_id,
|
|
436
|
+
"user_id": user_id,
|
|
437
|
+
"gpu_count": gpu_count,
|
|
438
|
+
"gpu_type": gpu_type,
|
|
439
|
+
"duration_hours": duration_decimal,
|
|
440
|
+
"name": processed_name,
|
|
441
|
+
"created_at": created_at,
|
|
442
|
+
"status": "pending",
|
|
443
|
+
"expires_at": (
|
|
444
|
+
datetime.utcnow() + timedelta(hours=duration_hours)
|
|
445
|
+
).isoformat(),
|
|
446
|
+
"jupyter_enabled": jupyter_enabled,
|
|
447
|
+
}
|
|
448
|
+
|
|
449
|
+
# Add github_user if provided
|
|
450
|
+
if github_user:
|
|
451
|
+
initial_reservation["github_user"] = github_user
|
|
452
|
+
|
|
453
|
+
# Send processing request to SQS queue (Lambda will create the initial record)
|
|
454
|
+
# Use float for SQS message (JSON serializable)
|
|
455
|
+
message = {
|
|
456
|
+
"reservation_id": reservation_id,
|
|
457
|
+
"user_id": user_id,
|
|
458
|
+
"gpu_count": gpu_count,
|
|
459
|
+
"gpu_type": gpu_type,
|
|
460
|
+
"duration_hours": float(duration_hours),
|
|
461
|
+
"name": processed_name,
|
|
462
|
+
"created_at": created_at,
|
|
463
|
+
"status": "pending",
|
|
464
|
+
"jupyter_enabled": jupyter_enabled,
|
|
465
|
+
"recreate_env": recreate_env,
|
|
466
|
+
"no_persistent_disk": no_persistent_disk,
|
|
467
|
+
"version": get_version(),
|
|
468
|
+
}
|
|
469
|
+
|
|
470
|
+
# Add github_user if provided
|
|
471
|
+
if github_user:
|
|
472
|
+
message["github_user"] = github_user
|
|
473
|
+
|
|
474
|
+
# Add Docker options if provided
|
|
475
|
+
if dockerfile:
|
|
476
|
+
message["dockerfile"] = dockerfile
|
|
477
|
+
if dockerimage:
|
|
478
|
+
message["dockerimage"] = dockerimage
|
|
479
|
+
# Always include preserve_entrypoint flag (don't make it conditional)
|
|
480
|
+
message["preserve_entrypoint"] = preserve_entrypoint
|
|
481
|
+
|
|
482
|
+
# Add disk_name if provided
|
|
483
|
+
if disk_name:
|
|
484
|
+
message["disk_name"] = disk_name
|
|
485
|
+
|
|
486
|
+
# Add node_labels if provided (for node selection preferences)
|
|
487
|
+
if node_labels:
|
|
488
|
+
message["node_labels"] = node_labels
|
|
489
|
+
|
|
490
|
+
queue_url = self.config.get_queue_url()
|
|
491
|
+
self.config.sqs_client.send_message(
|
|
492
|
+
QueueUrl=queue_url, MessageBody=json.dumps(message)
|
|
493
|
+
)
|
|
494
|
+
|
|
495
|
+
return reservation_id
|
|
496
|
+
|
|
497
|
+
except Exception as e:
|
|
498
|
+
console.print(f"[red]❌ Error creating reservation: {str(e)}[/red]")
|
|
499
|
+
return None
|
|
500
|
+
|
|
501
|
+
def create_multinode_reservation(
|
|
502
|
+
self,
|
|
503
|
+
user_id: str,
|
|
504
|
+
gpu_count: int,
|
|
505
|
+
gpu_type: str,
|
|
506
|
+
duration_hours: Union[int, float],
|
|
507
|
+
name: Optional[str] = None,
|
|
508
|
+
github_user: Optional[str] = None,
|
|
509
|
+
jupyter_enabled: bool = False,
|
|
510
|
+
recreate_env: bool = False,
|
|
511
|
+
dockerfile: Optional[str] = None,
|
|
512
|
+
dockerimage: Optional[str] = None,
|
|
513
|
+
no_persistent_disk: bool = False,
|
|
514
|
+
preserve_entrypoint: bool = False,
|
|
515
|
+
disk_name: Optional[str] = None,
|
|
516
|
+
node_labels: Optional[Dict[str, str]] = None,
|
|
517
|
+
) -> Optional[List[str]]:
|
|
518
|
+
"""Create multiple GPU reservations for multinode setup"""
|
|
519
|
+
try:
|
|
520
|
+
# Determine GPU config
|
|
521
|
+
gpu_configs = {
|
|
522
|
+
"t4": {"max_gpus": 4},
|
|
523
|
+
"l4": {"max_gpus": 4},
|
|
524
|
+
"a10g": {"max_gpus": 4},
|
|
525
|
+
"t4-small": {"max_gpus": 1},
|
|
526
|
+
"g5g": {"max_gpus": 2},
|
|
527
|
+
"a100": {"max_gpus": 8},
|
|
528
|
+
"h100": {"max_gpus": 8},
|
|
529
|
+
"h200": {"max_gpus": 8},
|
|
530
|
+
"b200": {"max_gpus": 8},
|
|
531
|
+
}
|
|
532
|
+
|
|
533
|
+
max_gpus_per_node = gpu_configs[gpu_type]["max_gpus"]
|
|
534
|
+
num_nodes = gpu_count // max_gpus_per_node
|
|
535
|
+
|
|
536
|
+
if gpu_count % max_gpus_per_node != 0:
|
|
537
|
+
console.print(
|
|
538
|
+
f"[red]❌ GPU count must be multiple of {max_gpus_per_node} for {gpu_type}[/red]")
|
|
539
|
+
return None
|
|
540
|
+
|
|
541
|
+
# Generate a master reservation ID to group related nodes
|
|
542
|
+
master_reservation_id = str(uuid.uuid4())
|
|
543
|
+
created_at = datetime.utcnow().isoformat()
|
|
544
|
+
reservation_ids = []
|
|
545
|
+
|
|
546
|
+
# Create reservation for each node
|
|
547
|
+
for node_idx in range(num_nodes):
|
|
548
|
+
node_reservation_id = str(uuid.uuid4())
|
|
549
|
+
reservation_ids.append(node_reservation_id)
|
|
550
|
+
|
|
551
|
+
# Node-specific name
|
|
552
|
+
base_name = name or f'{gpu_count}x {gpu_type.upper()} multinode'
|
|
553
|
+
node_name = f"{base_name} - Node {node_idx + 1}/{num_nodes}"
|
|
554
|
+
|
|
555
|
+
# Create reservation message for this node
|
|
556
|
+
message = {
|
|
557
|
+
"reservation_id": node_reservation_id,
|
|
558
|
+
"master_reservation_id": master_reservation_id, # Group related nodes
|
|
559
|
+
"node_index": node_idx,
|
|
560
|
+
"total_nodes": num_nodes,
|
|
561
|
+
"user_id": user_id,
|
|
562
|
+
"gpu_count": max_gpus_per_node, # GPUs per node
|
|
563
|
+
"total_gpu_count": gpu_count, # Total GPUs across all nodes
|
|
564
|
+
"gpu_type": gpu_type,
|
|
565
|
+
"duration_hours": float(duration_hours),
|
|
566
|
+
"name": node_name,
|
|
567
|
+
"version": get_version(),
|
|
568
|
+
"created_at": created_at,
|
|
569
|
+
"status": "pending",
|
|
570
|
+
# Only enable Jupyter on master node
|
|
571
|
+
"jupyter_enabled": jupyter_enabled and node_idx == 0,
|
|
572
|
+
"recreate_env": recreate_env,
|
|
573
|
+
"is_multinode": True,
|
|
574
|
+
"no_persistent_disk": no_persistent_disk,
|
|
575
|
+
}
|
|
576
|
+
|
|
577
|
+
if github_user:
|
|
578
|
+
message["github_user"] = github_user
|
|
579
|
+
|
|
580
|
+
# Add Docker options if provided
|
|
581
|
+
if dockerfile:
|
|
582
|
+
message["dockerfile"] = dockerfile
|
|
583
|
+
if dockerimage:
|
|
584
|
+
message["dockerimage"] = dockerimage
|
|
585
|
+
# Always include preserve_entrypoint flag (don't make it conditional)
|
|
586
|
+
message["preserve_entrypoint"] = preserve_entrypoint
|
|
587
|
+
|
|
588
|
+
# Add disk_name if provided (only for master node in multinode setup)
|
|
589
|
+
if disk_name and node_idx == 0:
|
|
590
|
+
message["disk_name"] = disk_name
|
|
591
|
+
|
|
592
|
+
# Add node_labels if provided (for node selection preferences)
|
|
593
|
+
if node_labels:
|
|
594
|
+
message["node_labels"] = node_labels
|
|
595
|
+
|
|
596
|
+
# Send to SQS queue
|
|
597
|
+
queue_url = self.config.get_queue_url()
|
|
598
|
+
self.config.sqs_client.send_message(
|
|
599
|
+
QueueUrl=queue_url, MessageBody=json.dumps(message)
|
|
600
|
+
)
|
|
601
|
+
|
|
602
|
+
return reservation_ids
|
|
603
|
+
|
|
604
|
+
except Exception as e:
|
|
605
|
+
console.print(
|
|
606
|
+
f"[red]❌ Error creating multinode reservation: {str(e)}[/red]")
|
|
607
|
+
return None
|
|
608
|
+
|
|
609
|
+
def list_reservations(
|
|
610
|
+
self,
|
|
611
|
+
user_filter: Optional[str] = None,
|
|
612
|
+
statuses_to_include: Optional[List[str]] = None,
|
|
613
|
+
) -> List[Dict[str, Any]]:
|
|
614
|
+
"""List GPU reservations with flexible filtering"""
|
|
615
|
+
try:
|
|
616
|
+
all_reservations = []
|
|
617
|
+
|
|
618
|
+
if user_filter:
|
|
619
|
+
# Query by specific user with pagination
|
|
620
|
+
response = self.reservations_table.query(
|
|
621
|
+
IndexName="UserIndex",
|
|
622
|
+
KeyConditionExpression="user_id = :user_id",
|
|
623
|
+
ExpressionAttributeValues={":user_id": user_filter},
|
|
624
|
+
)
|
|
625
|
+
all_reservations = response.get("Items", [])
|
|
626
|
+
|
|
627
|
+
# Handle pagination for UserIndex query
|
|
628
|
+
while "LastEvaluatedKey" in response:
|
|
629
|
+
response = self.reservations_table.query(
|
|
630
|
+
IndexName="UserIndex",
|
|
631
|
+
KeyConditionExpression="user_id = :user_id",
|
|
632
|
+
ExpressionAttributeValues={":user_id": user_filter},
|
|
633
|
+
ExclusiveStartKey=response["LastEvaluatedKey"]
|
|
634
|
+
)
|
|
635
|
+
all_reservations.extend(response.get("Items", []))
|
|
636
|
+
else:
|
|
637
|
+
# Get all reservations (scan with pagination for admin use)
|
|
638
|
+
all_reservations = []
|
|
639
|
+
response = self.reservations_table.scan()
|
|
640
|
+
all_reservations.extend(response.get("Items", []))
|
|
641
|
+
|
|
642
|
+
# Handle pagination
|
|
643
|
+
while "LastEvaluatedKey" in response:
|
|
644
|
+
response = self.reservations_table.scan(
|
|
645
|
+
ExclusiveStartKey=response["LastEvaluatedKey"]
|
|
646
|
+
)
|
|
647
|
+
all_reservations.extend(response.get("Items", []))
|
|
648
|
+
|
|
649
|
+
# Filter by status if specified
|
|
650
|
+
if statuses_to_include:
|
|
651
|
+
filtered_reservations = [
|
|
652
|
+
reservation
|
|
653
|
+
for reservation in all_reservations
|
|
654
|
+
if reservation.get("status") in statuses_to_include
|
|
655
|
+
]
|
|
656
|
+
return filtered_reservations
|
|
657
|
+
|
|
658
|
+
return all_reservations
|
|
659
|
+
|
|
660
|
+
except Exception as e:
|
|
661
|
+
console.print(f"[red]❌ Error listing reservations: {str(e)}[/red]")
|
|
662
|
+
return []
|
|
663
|
+
|
|
664
|
+
    def cancel_reservation(self, reservation_id: str, user_id: str) -> bool:
        """Cancel a GPU reservation by sending cancellation message to queue

        The cancellation is asynchronous: this only enqueues the request;
        the Lambda consumer performs the actual teardown.

        Args:
            reservation_id: ID of the reservation to cancel.
            user_id: Requesting user (the backend authorizes the request).

        Returns:
            True if the request was enqueued, False on a queueing error.
        """
        try:
            # Send cancellation request to SQS queue for processing
            message = {
                "type": "cancellation",
                "reservation_id": reservation_id,
                "user_id": user_id,
                "requested_at": datetime.utcnow().isoformat(),
                "version": get_version(),
            }

            queue_url = self.config.get_queue_url()
            self.config.sqs_client.send_message(
                QueueUrl=queue_url, MessageBody=json.dumps(message)
            )

            console.print(
                f"[yellow]⏳ Cancellation request submitted for reservation {reservation_id[:8]}...[/yellow]"
            )
            console.print(
                "[yellow]💡 The reservation will be cancelled shortly. Use 'gpu-dev list' to check status.[/yellow]"
            )
            return True

        except Exception as e:
            console.print(
                f"[red]❌ Error submitting cancellation request: {str(e)}[/red]"
            )
            return False
|
|
694
|
+
|
|
695
|
+
    def wait_for_multinode_reservation_completion(
        self, reservation_ids: List[str], timeout_minutes: Optional[int] = 10, verbose: bool = False
    ) -> Optional[List[Dict[str, Any]]]:
        """Poll for multiple reservation completion using shared polling logic

        Thin wrapper around _wait_for_reservations_completion with
        is_multinode=True so progress is reported per node.
        """
        return self._wait_for_reservations_completion(reservation_ids, timeout_minutes, is_multinode=True, verbose=verbose)
|
|
700
|
+
|
|
701
|
+
    def get_connection_info(
        self, reservation_id: str, user_id: str
    ) -> Optional[Dict[str, Any]]:
        """Get SSH connection information for a reservation

        Args:
            reservation_id: Full ID or a unique prefix of it.
            user_id: Owner whose reservations are searched (UserIndex GSI).

        Returns:
            Dict of connection/status fields, or None when the prefix matches
            no reservation, matches more than one (ambiguous), or on error.
        """
        try:
            # Query by user first (efficient), then filter by reservation_id prefix
            response = self.reservations_table.query(
                IndexName="UserIndex",
                KeyConditionExpression="user_id = :user_id",
                ExpressionAttributeValues={":user_id": user_id},
            )
            all_reservations = response.get("Items", [])

            # Handle pagination for UserIndex query
            while "LastEvaluatedKey" in response:
                response = self.reservations_table.query(
                    IndexName="UserIndex",
                    KeyConditionExpression="user_id = :user_id",
                    ExpressionAttributeValues={":user_id": user_id},
                    ExclusiveStartKey=response["LastEvaluatedKey"]
                )
                all_reservations.extend(response.get("Items", []))

            # Filter by reservation_id prefix in memory
            matching_reservations = [
                res for res in all_reservations
                if res.get("reservation_id", "").startswith(reservation_id)
            ]

            if len(matching_reservations) == 0:
                return None
            elif len(matching_reservations) > 1:
                return None  # Ambiguous - need longer prefix

            reservation = matching_reservations[0]

            # NOTE(review): gpu_count/status/reservation_id are indexed
            # directly and would raise KeyError if absent; the broad except
            # below converts that into a printed error + None. All other
            # fields fall back to sentinel defaults via .get().
            return {
                "ssh_command": reservation.get("ssh_command", "ssh user@pending"),
                "pod_name": reservation.get("pod_name", "pending"),
                "namespace": reservation.get("namespace", "default"),
                "gpu_count": reservation["gpu_count"],
                "status": reservation["status"],
                "launched_at": reservation.get("launched_at"),
                "expires_at": reservation.get("expires_at"),
                "created_at": reservation.get("created_at"),
                "reservation_id": reservation["reservation_id"],
                "name": reservation.get("name"),
                "instance_type": reservation.get("instance_type", "unknown"),
                "gpu_type": reservation.get("gpu_type", "unknown"),
                "failure_reason": reservation.get("failure_reason", ""),
                "current_detailed_status": reservation.get("current_detailed_status", ""),
                "status_history": reservation.get("status_history", []),
                "pod_logs": reservation.get("pod_logs", ""),
                "jupyter_url": reservation.get("jupyter_url", ""),
                "jupyter_port": reservation.get("jupyter_port", ""),
                "jupyter_token": reservation.get("jupyter_token", ""),
                "jupyter_enabled": reservation.get("jupyter_enabled", False),
                "jupyter_error": reservation.get("jupyter_error", ""),
                "ebs_volume_id": reservation.get("ebs_volume_id", ""),
                "secondary_users": reservation.get("secondary_users", []),
                "warning": reservation.get("warning", ""),
            }

        except Exception as e:
            console.print(
                f"[red]❌ Error getting connection info: {str(e)}[/red]")
            return None
|
|
768
|
+
|
|
769
|
+
def enable_jupyter(self, reservation_id: str, user_id: str) -> bool:
    """Enable Jupyter Lab for an active reservation.

    Enqueues an ``enable_jupyter`` request on SQS (the backend Lambda
    performs both the pod changes and the DynamoDB updates), then polls
    the reservation record for up to 3 minutes and reports the outcome.

    Returns True once Jupyter is confirmed enabled, False on submission
    error or polling timeout.
    """
    try:
        # Only the request is sent from here; the Lambda worker applies
        # the actual change to the pod and the reservation record.
        request_body = json.dumps(
            {
                "action": "enable_jupyter",
                "reservation_id": reservation_id,
                "user_id": user_id,
                "version": get_version(),
            }
        )
        self.config.sqs_client.send_message(
            QueueUrl=self.config.get_queue_url(),
            MessageBody=request_body,
        )

        console.print(
            f"[yellow]⏳ Jupyter enable request submitted for reservation {reservation_id[:8]}...[/yellow]"
        )

        # Watch the reservation record for up to 3 minutes to surface the result.
        return self._poll_jupyter_action_result(
            reservation_id, user_id, "enable", timeout_minutes=3
        )

    except Exception as e:
        console.print(
            f"[red]❌ Error submitting Jupyter enable request: {str(e)}[/red]"
        )
        return False
|
|
800
|
+
|
|
801
|
+
def disable_jupyter(self, reservation_id: str, user_id: str) -> bool:
    """Disable Jupyter Lab for an active reservation.

    Enqueues a ``disable_jupyter`` request on SQS (the backend Lambda
    performs both the pod changes and the DynamoDB updates), then polls
    the reservation record for up to 3 minutes and reports the outcome.

    Returns True once Jupyter is confirmed disabled, False on submission
    error or polling timeout.
    """
    try:
        # Only the request is sent from here; the Lambda worker applies
        # the actual change to the pod and the reservation record.
        request_body = json.dumps(
            {
                "action": "disable_jupyter",
                "reservation_id": reservation_id,
                "user_id": user_id,
                "version": get_version(),
            }
        )
        self.config.sqs_client.send_message(
            QueueUrl=self.config.get_queue_url(),
            MessageBody=request_body,
        )

        console.print(
            f"[yellow]⏳ Jupyter disable request submitted for reservation {reservation_id[:8]}...[/yellow]"
        )

        # Watch the reservation record for up to 3 minutes to surface the result.
        return self._poll_jupyter_action_result(
            reservation_id, user_id, "disable", timeout_minutes=3
        )

    except Exception as e:
        console.print(
            f"[red]❌ Error submitting Jupyter disable request: {str(e)}[/red]"
        )
        return False
|
|
832
|
+
|
|
833
|
+
def add_user(self, reservation_id: str, user_id: str, github_username: str) -> bool:
    """Add a secondary user to an active reservation.

    Performs a basic sanity check on the GitHub username, then enqueues
    an ``add_user`` request on SQS. The backend Lambda fetches the user's
    GitHub SSH keys and installs them in the pod; we poll for up to
    3 minutes to report whether that completed.
    """
    try:
        # Lightweight format check: alphanumerics plus '-' and '_'.
        # (An empty string, or one that is only separators, is rejected
        # because ''.isalnum() is False.)
        core = github_username.replace("-", "").replace("_", "")
        if not github_username or not core.isalnum():
            console.print(
                f"[red]❌ Invalid GitHub username: {github_username}[/red]"
            )
            return False

        # The Lambda worker fetches the GitHub keys and updates the pod;
        # we only enqueue the request here.
        request_body = json.dumps(
            {
                "action": "add_user",
                "reservation_id": reservation_id,
                "user_id": user_id,
                "github_username": github_username,
                "version": get_version(),
            }
        )
        self.config.sqs_client.send_message(
            QueueUrl=self.config.get_queue_url(),
            MessageBody=request_body,
        )

        console.print(
            f"[yellow]⏳ Adding user {github_username} to reservation {reservation_id[:8]}...[/yellow]"
        )

        # Watch the reservation record for up to 3 minutes to surface the result.
        return self._poll_add_user_result(
            reservation_id, user_id, github_username, timeout_minutes=3
        )

    except Exception as e:
        console.print(
            f"[red]❌ Error adding user {github_username}: {str(e)}[/red]"
        )
        return False
|
|
875
|
+
|
|
876
|
+
def extend_reservation(self, reservation_id: str, user_id: str, extension_hours: float) -> bool:
    """Extend an active reservation by the specified number of hours.

    Reads the current ``expires_at`` BEFORE enqueuing the extension so
    the poller can detect the change without racing the Lambda update.
    Returns True once the new expiration is observed, False on error or
    polling timeout.
    """
    try:
        # Capture current expiration BEFORE sending extension request to avoid race condition
        response = self.reservations_table.query(
            IndexName="UserIndex",
            KeyConditionExpression="user_id = :user_id",
            ExpressionAttributeValues={":user_id": user_id},
        )
        all_reservations = response.get("Items", [])

        # Handle pagination for UserIndex query
        while "LastEvaluatedKey" in response:
            response = self.reservations_table.query(
                IndexName="UserIndex",
                KeyConditionExpression="user_id = :user_id",
                ExpressionAttributeValues={":user_id": user_id},
                ExclusiveStartKey=response["LastEvaluatedKey"]
            )
            all_reservations.extend(response.get("Items", []))

        # The caller may pass a short ID prefix; match in memory.
        matching_reservations = [
            res for res in all_reservations
            if res.get("reservation_id", "").startswith(reservation_id)
        ]

        initial_expires_at = None
        if matching_reservations:
            # First match wins; ambiguity is tolerated here because the
            # poller below applies the same first-match rule.
            initial_expires_at = matching_reservations[0].get("expires_at", "")

        # Send message to Lambda to extend reservation
        # Lambda will handle both the expiration timestamp update and any necessary pod updates
        # NOTE(review): unlike the other actions, this message carries no
        # "user_id" field — presumably the Lambda resolves it from the
        # reservation; confirm against the Lambda handler.
        message = {
            "action": "extend_reservation",
            "reservation_id": reservation_id,
            "extension_hours": extension_hours,
            "version": get_version(),
        }

        queue_url = self.config.get_queue_url()
        self.config.sqs_client.send_message(
            QueueUrl=queue_url, MessageBody=json.dumps(message)
        )

        console.print(
            f"[yellow]⏳ Extension request submitted for reservation {reservation_id[:8]}...[/yellow]"
        )

        # Poll for 3 minutes to show the outcome
        return self._poll_extend_action_result(
            reservation_id, user_id, extension_hours, timeout_minutes=3, initial_expires_at=initial_expires_at
        )

    except Exception as e:
        console.print(
            f"[red]❌ Error submitting extension request: {str(e)}[/red]")
        return False
|
|
933
|
+
|
|
934
|
+
def get_gpu_availability_by_type(self) -> Optional[Dict[str, Dict[str, Any]]]:
    """Get GPU availability information by GPU type from real-time availability table.

    Scans the whole availability table (following pagination) and, for
    each GPU type found, augments the stored counters with the current
    queue length and a rough wait estimate (15 minutes per queued
    reservation). Returns None if the table cannot be read.
    """
    try:
        # Real-time data lives in a dedicated DynamoDB table.
        table = self.config.dynamodb.Table(self.config.availability_table)

        # Full paginated scan of the availability table.
        items = []
        page = table.scan()
        while True:
            items.extend(page.get("Items", []))
            last_key = page.get("LastEvaluatedKey")
            if last_key is None:
                break
            page = table.scan(ExclusiveStartKey=last_key)

        availability_info = {}
        for record in items:
            kind = record["gpu_type"]
            pending = self._get_queue_length_for_gpu_type(kind)
            # Rough heuristic: ~15 minutes of wait per queued reservation.
            wait_minutes = pending * 15 if pending > 0 else 0

            availability_info[kind] = {
                "available": int(record.get("available_gpus", 0)),
                "total": int(record.get("total_gpus", 0)),
                "max_reservable": int(record.get("max_reservable", 0)),
                "full_nodes_available": int(record.get("full_nodes_available", 0)),
                "gpus_per_instance": int(record.get("gpus_per_instance", 0)),
                "queue_length": pending,
                "estimated_wait_minutes": wait_minutes,
                "running_instances": int(record.get("running_instances", 0)),
                "desired_capacity": int(record.get("desired_capacity", 0)),
                "last_updated": record.get("last_updated_timestamp", 0),
            }

        return availability_info

    except Exception as e:
        console.print(
            f"[red]❌ Error getting GPU availability: {str(e)}[/red]")
        return None
|
|
978
|
+
|
|
979
|
+
def _get_static_gpu_config(
|
|
980
|
+
self, gpu_type: str, queue_length: int, estimated_wait: int
|
|
981
|
+
) -> Dict[str, Any]:
|
|
982
|
+
"""Get static GPU configuration as fallback when real-time data unavailable"""
|
|
983
|
+
static_configs = {
|
|
984
|
+
# 2x p4d.24xlarge = 16 A100s
|
|
985
|
+
"a100": {"available": 0, "total": 16},
|
|
986
|
+
# 1x p6-b200.48xlarge = 8 B200s
|
|
987
|
+
"b200": {"available": 0, "total": 8},
|
|
988
|
+
# 2x p5e.48xlarge = 16 H200s
|
|
989
|
+
"h200": {"available": 0, "total": 16},
|
|
990
|
+
"h100": {"available": 0, "total": 16}, # 2x p5.48xlarge = 16 H100s
|
|
991
|
+
"t4": {"available": 0, "total": 8}, # 2x g4dn.12xlarge = 8 T4s
|
|
992
|
+
}
|
|
993
|
+
|
|
994
|
+
config = static_configs.get(gpu_type, {"available": 0, "total": 0})
|
|
995
|
+
return {
|
|
996
|
+
"available": config["available"],
|
|
997
|
+
"total": config["total"],
|
|
998
|
+
"queue_length": queue_length,
|
|
999
|
+
"estimated_wait_minutes": estimated_wait,
|
|
1000
|
+
"running_instances": 0,
|
|
1001
|
+
"desired_capacity": 0,
|
|
1002
|
+
"last_updated": 0,
|
|
1003
|
+
}
|
|
1004
|
+
|
|
1005
|
+
def _get_queue_length_for_gpu_type(self, gpu_type: str) -> int:
    """Get the number of queued reservations for a specific GPU type.

    Counts reservations in the "queued" and "pending" states, preferring
    the StatusGpuTypeIndex composite index and falling back to a full
    table scan when that index does not exist yet. Both paths follow
    DynamoDB pagination. Returns 0 on error.
    """
    try:
        total_count = 0

        # Count queued reservations for this GPU type
        for status in ["queued", "pending"]:
            expr_names = {"#status": "status"}  # "status" is a reserved word
            expr_values = {":status": status, ":gpu_type": gpu_type}
            try:
                # Shared kwargs so the pagination loop stays in sync with
                # the initial query.
                query_kwargs = {
                    "IndexName": "StatusGpuTypeIndex",
                    "KeyConditionExpression": "#status = :status AND gpu_type = :gpu_type",
                    "ExpressionAttributeNames": expr_names,
                    "ExpressionAttributeValues": expr_values,
                }
                response = self.reservations_table.query(**query_kwargs)
                total_count += len(response.get("Items", []))

                # Handle pagination for StatusGpuTypeIndex query
                while "LastEvaluatedKey" in response:
                    response = self.reservations_table.query(
                        ExclusiveStartKey=response["LastEvaluatedKey"],
                        **query_kwargs,
                    )
                    total_count += len(response.get("Items", []))
            except Exception:
                # Fallback to scanning if the composite index doesn't exist yet
                console.print(
                    f"[dim]Fallback: scanning for {status} {gpu_type} reservations[/dim]"
                )
                # BUG FIX: the original filter used contains(), which does
                # substring matching — e.g. gpu_type "t4" would also count
                # items whose gpu_type merely contains "t4". Exact equality
                # matches the indexed path's semantics.
                scan_kwargs = {
                    "FilterExpression": "#status = :status AND gpu_type = :gpu_type",
                    "ExpressionAttributeNames": expr_names,
                    "ExpressionAttributeValues": expr_values,
                }
                response = self.reservations_table.scan(**scan_kwargs)
                total_count += len(response.get("Items", []))

                # Handle pagination for fallback scan
                while "LastEvaluatedKey" in response:
                    response = self.reservations_table.scan(
                        ExclusiveStartKey=response["LastEvaluatedKey"],
                        **scan_kwargs,
                    )
                    total_count += len(response.get("Items", []))

        return total_count

    except Exception as e:
        console.print(
            f"[red]❌ Error getting queue length for {gpu_type}: {str(e)}[/red]"
        )
        return 0
|
|
1072
|
+
|
|
1073
|
+
def _poll_jupyter_action_result(
    self, reservation_id: str, user_id: str, action: str, timeout_minutes: int = 3
) -> bool:
    """Poll reservation table for Jupyter action result.

    Repeatedly queries the user's reservations (UserIndex) and watches
    the matching record's jupyter_* fields until the requested ``action``
    ("enable" or "disable") is reflected, a polling error occurs, or
    ``timeout_minutes`` elapses. A Rich Live spinner shows progress.
    Returns True on confirmed success, False otherwise.
    """
    try:
        start_time = time.time()
        timeout_seconds = timeout_minutes * 60

        with Live(console=console, refresh_per_second=2) as live:
            spinner = Spinner(
                "dots", text=f"🔄 Processing Jupyter {action} request..."
            )
            live.update(spinner)

            # Snapshot of jupyter fields from the first successful read,
            # used to detect partial progress (enabled flag flipped but
            # URL not yet published).
            initial_state = None

            while time.time() - start_time < timeout_seconds:
                try:
                    # Get current reservation state - query by user first, then filter by prefix
                    response = self.reservations_table.query(
                        IndexName="UserIndex",
                        KeyConditionExpression="user_id = :user_id",
                        ExpressionAttributeValues={":user_id": user_id},
                    )
                    all_reservations = response.get("Items", [])

                    # Handle pagination for UserIndex query
                    while "LastEvaluatedKey" in response:
                        response = self.reservations_table.query(
                            IndexName="UserIndex",
                            KeyConditionExpression="user_id = :user_id",
                            ExpressionAttributeValues={
                                ":user_id": user_id},
                            ExclusiveStartKey=response["LastEvaluatedKey"]
                        )
                        all_reservations.extend(response.get("Items", []))

                    # Filter by reservation_id prefix in memory
                    items = [
                        res for res in all_reservations
                        if res.get("reservation_id", "").startswith(reservation_id)
                    ]
                    if len(items) == 0:
                        # Record not visible yet; keep spinning.
                        spinner.text = f"🔄 Waiting for reservation data..."
                        live.update(spinner)
                        time.sleep(2)
                        continue
                    elif len(items) > 1:
                        # Ambiguous prefix: proceed with the first match.
                        spinner.text = f"🔄 Multiple reservations found for {reservation_id}, using first match..."
                        live.update(spinner)

                    reservation = items[0]

                    # Capture initial state on first iteration
                    if initial_state is None:
                        initial_state = {
                            "jupyter_enabled": reservation.get(
                                "jupyter_enabled", False
                            ),
                            "jupyter_url": reservation.get("jupyter_url", ""),
                            "jupyter_port": reservation.get("jupyter_port", 0),
                        }

                    current_jupyter_enabled = reservation.get(
                        "jupyter_enabled", False
                    )
                    jupyter_url = reservation.get("jupyter_url", "")
                    jupyter_port = reservation.get("jupyter_port", 0)

                    # Check if the action has completed
                    if action == "enable":
                        # Success requires both the flag and a published URL.
                        if current_jupyter_enabled and jupyter_url:
                            live.stop()
                            console.print(
                                f"[green]✅ Jupyter Lab enabled successfully![/green]"
                            )
                            console.print(
                                f"[cyan]🔗 Jupyter URL:[/cyan] {jupyter_url}"
                            )
                            console.print(
                                f"[cyan]🔌 Port:[/cyan] {jupyter_port}")
                            return True
                        elif (
                            current_jupyter_enabled
                            != initial_state["jupyter_enabled"]
                        ):
                            # Flag flipped but URL still pending.
                            spinner.text = f"🔄 Jupyter enabled, waiting for URL..."
                    else:  # disable
                        # Success requires both the flag and the URL cleared.
                        if not current_jupyter_enabled and not jupyter_url:
                            live.stop()
                            console.print(
                                f"[green]✅ Jupyter Lab disabled successfully![/green]"
                            )
                            return True
                        elif (
                            current_jupyter_enabled
                            != initial_state["jupyter_enabled"]
                        ):
                            spinner.text = f"🔄 Stopping Jupyter service..."

                    live.update(spinner)
                    time.sleep(3)

                except Exception as poll_error:
                    # Any polling failure aborts immediately rather than retrying.
                    console.print(
                        f"[red]❌ Error polling Jupyter status: {poll_error}[/red]"
                    )
                    return False

            # Timeout reached
            live.stop()
            console.print(
                f"[yellow]⏰ Timeout after {timeout_minutes} minutes[/yellow]"
            )
            console.print(
                f"[yellow]💡 Use 'gpu-dev show {reservation_id[:8]}' to check Jupyter status[/yellow]"
            )
            return False

    except Exception as e:
        console.print(
            f"[red]❌ Error during Jupyter {action} polling: {str(e)}[/red]"
        )
        return False
|
|
1197
|
+
|
|
1198
|
+
def _poll_add_user_result(
    self, reservation_id: str, user_id: str, github_username: str, timeout_minutes: int = 3
) -> bool:
    """Poll reservation table for add user action result.

    Repeatedly queries the user's reservations (UserIndex) and watches
    the matching record's ``secondary_users`` list until
    ``github_username`` appears, a polling error occurs, or
    ``timeout_minutes`` elapses. A Rich Live spinner shows progress.
    Returns True once the user is present, False otherwise.
    """
    try:
        start_time = time.time()
        timeout_seconds = timeout_minutes * 60

        with Live(console=console, refresh_per_second=2) as live:
            spinner = Spinner(
                "dots", text=f"🔄 Adding user {github_username}...")
            live.update(spinner)

            # Snapshot of the secondary_users list from the first
            # successful read, used to detect intermediate changes.
            initial_secondary_users = None

            while time.time() - start_time < timeout_seconds:
                try:
                    # Get current reservation state - query by user first, then filter by prefix
                    response = self.reservations_table.query(
                        IndexName="UserIndex",
                        KeyConditionExpression="user_id = :user_id",
                        ExpressionAttributeValues={":user_id": user_id},
                    )
                    all_reservations = response.get("Items", [])

                    # Handle pagination for UserIndex query
                    while "LastEvaluatedKey" in response:
                        response = self.reservations_table.query(
                            IndexName="UserIndex",
                            KeyConditionExpression="user_id = :user_id",
                            ExpressionAttributeValues={
                                ":user_id": user_id},
                            ExclusiveStartKey=response["LastEvaluatedKey"]
                        )
                        all_reservations.extend(response.get("Items", []))

                    # Filter by reservation_id prefix in memory
                    items = [
                        res for res in all_reservations
                        if res.get("reservation_id", "").startswith(reservation_id)
                    ]
                    if len(items) == 0:
                        # Record not visible yet; keep spinning.
                        spinner.text = f"🔄 Waiting for reservation data..."
                        live.update(spinner)
                        time.sleep(2)
                        continue
                    elif len(items) > 1:
                        # Ambiguous prefix: proceed with the first match.
                        spinner.text = f"🔄 Multiple reservations found for {reservation_id}, using first match..."
                        live.update(spinner)

                    reservation = items[0]

                    # Capture initial state on first iteration
                    if initial_secondary_users is None:
                        initial_secondary_users = reservation.get(
                            "secondary_users", []
                        )

                    current_secondary_users = reservation.get(
                        "secondary_users", [])

                    # Check if the user has been added
                    if github_username in current_secondary_users:
                        live.stop()
                        console.print(
                            f"[green]✅ User {github_username} added successfully![/green]"
                        )
                        console.print(
                            f"[cyan]👥 Secondary users:[/cyan] {', '.join(current_secondary_users)}"
                        )
                        return True
                    elif len(current_secondary_users) != len(
                        initial_secondary_users
                    ):
                        # List changed but target user not confirmed yet.
                        spinner.text = (
                            f"🔄 User list updated, verifying {github_username}..."
                        )

                    live.update(spinner)
                    time.sleep(3)

                except Exception as poll_error:
                    # Any polling failure aborts immediately rather than retrying.
                    console.print(
                        f"[red]❌ Error polling add user status: {poll_error}[/red]"
                    )
                    return False

            # Timeout reached
            live.stop()
            console.print(
                f"[yellow]⏰ Timeout after {timeout_minutes} minutes[/yellow]"
            )
            console.print(
                f"[yellow]💡 Use 'gpu-dev show {reservation_id[:8]}' to check user status[/yellow]"
            )
            return False

    except Exception as e:
        console.print(
            f"[red]❌ Error during add user polling: {str(e)}[/red]")
        return False
|
|
1299
|
+
|
|
1300
|
+
def _poll_extend_action_result(
    self, reservation_id: str, user_id: str, extension_hours: float, timeout_minutes: int = 3, initial_expires_at: Optional[str] = None
) -> bool:
    """Poll reservation table for extend action result.

    Watches the matching reservation's ``expires_at`` until it differs
    from the pre-request value (success), an ``extension_error`` field
    appears (failure), or ``timeout_minutes`` elapses. The caller should
    pass ``initial_expires_at`` captured before enqueuing the request to
    avoid racing the Lambda's update.
    """
    try:
        start_time = time.time()
        timeout_seconds = timeout_minutes * 60

        with Live(console=console, refresh_per_second=2) as live:
            spinner = Spinner(
                "dots",
                text=f"🔄 Extending reservation by {extension_hours} hours...",
            )
            live.update(spinner)

            # Use pre-captured initial_expires_at if provided (to avoid race condition)
            initial_expiration = initial_expires_at

            while time.time() - start_time < timeout_seconds:
                try:
                    # Get current reservation state - query by user first, then filter by prefix
                    response = self.reservations_table.query(
                        IndexName="UserIndex",
                        KeyConditionExpression="user_id = :user_id",
                        ExpressionAttributeValues={":user_id": user_id},
                    )
                    all_reservations = response.get("Items", [])

                    # Handle pagination for UserIndex query
                    while "LastEvaluatedKey" in response:
                        response = self.reservations_table.query(
                            IndexName="UserIndex",
                            KeyConditionExpression="user_id = :user_id",
                            ExpressionAttributeValues={
                                ":user_id": user_id},
                            ExclusiveStartKey=response["LastEvaluatedKey"]
                        )
                        all_reservations.extend(response.get("Items", []))

                    # Filter by reservation_id prefix in memory
                    items = [
                        res for res in all_reservations
                        if res.get("reservation_id", "").startswith(reservation_id)
                    ]
                    if len(items) == 0:
                        # Record not visible yet; keep spinning.
                        spinner.text = f"🔄 Waiting for reservation data..."
                        live.update(spinner)
                        time.sleep(2)
                        continue
                    elif len(items) > 1:
                        # Ambiguous prefix: proceed with the first match.
                        spinner.text = f"🔄 Multiple reservations found for {reservation_id}, using first match..."
                        live.update(spinner)

                    reservation = items[0]

                    # Capture initial expiration on first iteration
                    # (only when the caller did not pre-capture it).
                    if initial_expiration is None:
                        initial_expiration = reservation.get(
                            "expires_at", "")

                    current_expiration = reservation.get("expires_at", "")

                    # Check for extension failure indicators
                    last_updated = reservation.get("last_updated", 0)
                    extension_error = reservation.get(
                        "extension_error", "")

                    # If there's an extension error, fail immediately
                    if extension_error:
                        live.stop()
                        console.print(
                            f"[red]❌ Extension failed: {extension_error}[/red]"
                        )
                        return False

                    # Check if the expiration has been updated (different from initial)
                    if (
                        current_expiration != initial_expiration
                        and current_expiration
                    ):
                        live.stop()
                        from datetime import datetime, timezone

                        try:
                            # Treat as naive datetime and manually add UTC timezone (matches list command)
                            naive_dt = datetime.fromisoformat(current_expiration)
                            exp_dt_utc = naive_dt.replace(tzinfo=timezone.utc)
                            # Convert to local timezone
                            local_exp = exp_dt_utc.astimezone()
                            # Format with same style as list command: month-day hour:minute
                            formatted_expiration = local_exp.strftime("%m-%d %H:%M")
                            console.print(
                                f"[green]✅ Extended reservation {reservation_id} by {extension_hours} hours -- your new expiration is {formatted_expiration}[/green]"
                            )
                            return True
                        except Exception:
                            # Fallback to raw display if parsing fails
                            console.print(
                                f"[green]✅ Extended reservation {reservation_id} by {extension_hours} hours -- your new expiration is {current_expiration}[/green]"
                            )
                            return True

                    spinner.text = f"🔄 Processing extension request..."
                    live.update(spinner)
                    time.sleep(2)

                except Exception as poll_error:
                    # Unlike the other pollers, transient errors here are
                    # retried until the timeout rather than aborting.
                    spinner.text = f"🔄 Checking extension status (retry)..."
                    live.update(spinner)
                    time.sleep(2)

            # Timeout reached
            live.stop()
            console.print(
                f"[red]❌ Extension request timed out after {timeout_minutes} minutes[/red]"
            )
            console.print(
                f"[yellow]The extension may still be processing. Check status with: gpu-dev list[/yellow]"
            )
            return False  # Return failure on timeout

    except Exception as e:
        console.print(
            f"[red]❌ Error polling extension result: {str(e)}[/red]")
        return False
|
|
1424
|
+
|
|
1425
|
+
def get_cluster_status(self) -> Optional[Dict[str, Any]]:
    """Get overall GPU cluster status from availability table.

    Combines a full scan of the reservations table with the per-type
    availability data to produce aggregate counters. Queue length comes
    from SQS when reachable, otherwise from counting "pending"
    reservations. Returns None on error.
    """
    try:
        # Get reservations with pagination
        reservations_response = self.reservations_table.scan()
        reservations = reservations_response.get("Items", [])

        # Handle pagination for admin stats scan
        while "LastEvaluatedKey" in reservations_response:
            reservations_response = self.reservations_table.scan(
                ExclusiveStartKey=reservations_response["LastEvaluatedKey"]
            )
            reservations.extend(reservations_response.get("Items", []))

        # Get total GPUs from availability table
        availability_info = self.get_gpu_availability_by_type()
        total_gpus = 0
        available_gpus = 0

        if availability_info:
            # Only the counters are needed, not the GPU-type keys.
            for info in availability_info.values():
                total_gpus += info.get("total", 0)
                available_gpus += info.get("available", 0)

        # Calculate stats
        active_reservations = [
            r for r in reservations if r.get("status") == "active"
        ]
        reserved_gpus = sum(int(r.get("gpu_count", 0))
                            for r in active_reservations)

        # Get queue length
        try:
            queue_url = self.config.get_queue_url()
            queue_attrs = self.config.sqs_client.get_queue_attributes(
                QueueUrl=queue_url, AttributeNames=[
                    "ApproximateNumberOfMessages"]
            )
            queue_length = int(
                queue_attrs["Attributes"]["ApproximateNumberOfMessages"]
            )
        # BUG FIX: was a bare `except:`, which would also swallow
        # KeyboardInterrupt/SystemExit during the SQS call.
        except Exception:
            # Best-effort fallback: count pending reservations instead.
            queue_length = len(
                [r for r in reservations if r.get("status") == "pending"]
            )

        return {
            "total_gpus": total_gpus,
            "available_gpus": available_gpus,
            "reserved_gpus": reserved_gpus,
            "active_reservations": len(active_reservations),
            "queue_length": queue_length,
        }

    except Exception as e:
        console.print(
            f"[red]❌ Error getting cluster status: {str(e)}[/red]")
        return None
|
|
1483
|
+
|
|
1484
|
+
def _wait_for_reservations_completion(
    self, reservation_ids: List[str], timeout_minutes: Optional[int] = 10, is_multinode: bool = False, verbose: bool = False
) -> Optional[List[Dict[str, Any]]]:
    """Shared polling logic for both single and multinode reservations (always creates SSH config).

    Polls ``self.reservations_table`` (DynamoDB) roughly every 3 seconds until
    every reservation reaches a terminal state, the timeout elapses, or the
    user interrupts.  While polling, a Rich ``Live`` display shows either a
    single spinner (single node / most states) or a per-node panel (multinode
    "preparing" state).

    Signal behaviour:
        * SIGINT (Ctrl+C)  -> cancel the reservation(s) and return None.
        * SIGTERM/SIGQUIT  -> exit the tool but leave reservation(s) running.

    Args:
        reservation_ids: Reservation ids to poll, one per node.
        timeout_minutes: Give up after this many minutes; ``None`` polls forever.
        is_multinode: Use aggregate/per-node display and the more conservative
            multinode failure rules (all-or-majority failed, not any-failed).
        verbose: Print ``[DEBUG]`` lines whenever a node's status changes.

    Returns:
        The list of reservation items (DynamoDB dicts) once all nodes are
        active with a usable SSH command, or ``None`` on failure,
        cancellation, clean exit, polling error, or timeout.
    """

    # Human-readable message for each backend status value.
    status_messages = {
        "pending": "⏳ Reservation request submitted, waiting for processing...",
        "queued": "📋 In queue - waiting for GPU resources...",
        "preparing": "🚀 GPUs found! Preparing your development environment...",
        "creating_server": "🐳 Building custom Docker image...",
        "active": "✅ Reservation complete!",
        "failed": "❌ Reservation failed",
        "cancelled": "🛑 Reservation cancelled",
    }

    start_time = time.time()
    # None timeout means "poll until a terminal state" with no deadline.
    timeout_seconds = timeout_minutes * 60 if timeout_minutes is not None else None
    last_status = None          # last aggregate status shown (dedupe display updates)
    last_message = None         # last spinner message shown
    cancelled = False           # set by SIGINT handler
    close_tool = False          # set by SIGTERM/SIGQUIT handler (exit, keep reservation)
    show_queue_help = True      # print the queue help line only once
    # Snapshot of the first server-provided wait estimate so we can count down
    # locally instead of re-reading a possibly-stale estimate every poll.
    queue_state = {"initial_estimated_wait": None,
                   "queue_start_time": None}
    total_nodes = len(reservation_ids)

    # Track previous node statuses to only show changes
    previous_node_statuses = {}

    def handle_interrupt(signum, frame):
        """Handle Ctrl+C to cancel reservation(s)"""
        nonlocal cancelled
        cancelled = True

    def handle_clean_exit(signum, frame):
        """Handle clean exit signal (SIGTERM)"""
        nonlocal close_tool
        close_tool = True
        reservation_text = "reservations" if is_multinode else "reservation"
        console.print(
            f"\n[cyan]🔄 Clean exit requested - keeping {reservation_text} active...[/cyan]"
        )

    def check_keyboard_input():
        """Check if clean exit was requested via signal"""
        return close_tool

    # Set up signal handlers
    signal.signal(signal.SIGTERM, handle_clean_exit)
    try:
        # SIGQUIT (Ctrl+\) is unavailable on some platforms (e.g. Windows),
        # hence the AttributeError/OSError fallback advertising SIGTERM instead.
        signal.signal(signal.SIGQUIT, handle_clean_exit)
        action_text = "cancel all reservations" if is_multinode else "cancel reservation"
        keep_text = "keep reservations" if is_multinode else "keep reservation"
        console.print(
            f"[dim]💡 Press [cyan]Ctrl+C[/cyan] to {action_text} • Press [cyan]Ctrl+backslash[/cyan] to exit but {keep_text}[/dim]"
        )
    except (AttributeError, OSError):
        action_text = "cancel all reservations" if is_multinode else "cancel reservation"
        keep_text = "keep reservations" if is_multinode else "keep reservation"
        console.print(
            f"[dim]💡 Press [cyan]Ctrl+C[/cyan] to {action_text} • Send [cyan]SIGTERM[/cyan] to exit but {keep_text}[/dim]"
        )
        console.print(
            f"[dim] (From another terminal: [cyan]kill {os.getpid()}[/cyan])[/dim]"
        )

    # Set up signal handler for Ctrl+C
    old_handler = signal.signal(signal.SIGINT, handle_interrupt)

    try:
        with Live(console=console, refresh_per_second=4) as live:
            initial_text = f"📡 Starting multinode reservation..." if is_multinode else "🔄 Sending reservation request..."
            spinner = Spinner("dots", text=initial_text)
            live.update(spinner)

            while (
                (timeout_seconds is None or time.time() -
                 start_time < timeout_seconds)
                and not cancelled
                and not close_tool
            ):
                try:
                    # Check for keyboard input (clean exit)
                    if check_keyboard_input():
                        break

                    # Get current status of all reservations
                    all_reservations = []
                    node_details = []

                    for i, res_id in enumerate(reservation_ids):
                        try:
                            response = self.reservations_table.get_item(
                                Key={"reservation_id": res_id})
                            if "Item" in response:
                                reservation = response["Item"]
                                all_reservations.append(reservation)

                                status = reservation.get(
                                    "status", "unknown")
                                failure_reason = reservation.get(
                                    "failure_reason", "")
                                current_detailed_status = reservation.get(
                                    "current_detailed_status", "")
                                queue_position = reservation.get(
                                    "queue_position", "?")
                                estimated_wait = reservation.get(
                                    "estimated_wait_minutes", "?")
                                gpu_count = reservation.get("gpu_count", 1)

                                # Debug what we're reading from DynamoDB - only show if status changed
                                if verbose:
                                    node_key = f"node_{i+1}_{res_id[:8]}"
                                    current_node_status = f"status={status}, detailed={current_detailed_status}"
                                    if previous_node_statuses.get(node_key) != current_node_status:
                                        print(
                                            f"[DEBUG] Node {i+1} ({res_id[:8]}): {current_node_status}")
                                        previous_node_statuses[node_key] = current_node_status

                                node_details.append({
                                    "index": i,
                                    "status": status,
                                    "failure_reason": failure_reason,
                                    "current_detailed_status": current_detailed_status,
                                    "queue_position": queue_position,
                                    "estimated_wait": estimated_wait,
                                    "gpu_count": gpu_count,
                                    "reservation": reservation
                                })
                            else:
                                # No reservation found yet, keep waiting
                                if not is_multinode:
                                    spinner.text = "📡 Waiting for reservation status update..."
                                    live.update(spinner)
                                    time.sleep(2)
                                    # NOTE(review): this `continue` advances the
                                    # enumerate() loop, not the outer while; with a
                                    # single id the practical effect is the same.
                                    continue
                                else:
                                    # Multinode: record a placeholder so indices stay aligned.
                                    node_details.append({
                                        "index": i, "status": "unknown", "failure_reason": "",
                                        "current_detailed_status": "", "queue_position": "?",
                                        "estimated_wait": "?", "gpu_count": 0, "reservation": None
                                    })
                        except Exception as e:
                            # Transient DynamoDB/network errors: keep polling with
                            # an "error" placeholder rather than aborting.
                            if verbose:
                                print(
                                    f"[DEBUG] Exception querying {res_id[:8]}: {e}")
                            node_details.append({
                                "index": i, "status": "error", "failure_reason": "Connection error",
                                "current_detailed_status": "", "queue_position": "?",
                                "estimated_wait": "?", "gpu_count": 0, "reservation": None
                            })

                    # Calculate aggregate status
                    statuses = [node["status"] for node in node_details]
                    active_count = statuses.count("active")
                    failed_count = statuses.count("failed")
                    cancelled_count = statuses.count("cancelled")
                    preparing_count = statuses.count("preparing")
                    queued_count = statuses.count("queued")

                    # Debug multinode status calculation - only show when aggregate status changes
                    # Only when there are mixed statuses
                    if is_multinode and verbose and len(set(statuses)) > 1:
                        print(
                            f"[DEBUG] Mixed node statuses: active={active_count}, preparing={preparing_count}, queued={queued_count}, failed={failed_count}, total={total_nodes}")

                    # Determine aggregate status for multinode reservations
                    # Only consider it failed if ALL nodes are explicitly failed/cancelled
                    # or if there's a significant portion failed (more than half)
                    if is_multinode:
                        # For multinode, be more conservative about declaring failure
                        if failed_count + cancelled_count >= total_nodes:
                            # All nodes failed - definitely failed
                            aggregate_status = "failed"
                        elif active_count == total_nodes:
                            # All nodes active - success
                            aggregate_status = "active"
                        elif active_count + preparing_count == total_nodes:
                            # All nodes either active or preparing - still working
                            aggregate_status = "preparing" if preparing_count > 0 else "active"
                        elif queued_count > 0:
                            # Any nodes queued - still in queue
                            aggregate_status = "queued"
                        elif failed_count + cancelled_count > total_nodes // 2:
                            # More than half failed - likely a real failure
                            aggregate_status = "failed"
                        else:
                            # Mixed state - keep preparing/pending
                            aggregate_status = "preparing" if preparing_count > 0 else "pending"

                        # Debug aggregate status decision - only show when status changes
                        if is_multinode and verbose and aggregate_status != last_status:
                            print(
                                f"[DEBUG] Calculated aggregate_status: {aggregate_status}")

                    else:
                        # Single node - use original logic
                        if failed_count > 0 or cancelled_count > 0:
                            aggregate_status = "failed"
                        elif active_count == total_nodes:
                            aggregate_status = "active"
                        elif preparing_count > 0:
                            aggregate_status = "preparing"
                        elif queued_count > 0:
                            aggregate_status = "queued"
                        else:
                            # Check for creating_server status
                            creating_server_count = statuses.count(
                                "creating_server")
                            if creating_server_count > 0:
                                aggregate_status = "creating_server"
                            else:
                                aggregate_status = "pending"

                    # Build status message based on aggregate status and mode
                    message = ""

                    if aggregate_status == "queued":
                        # Use first queued node's info for display
                        queued_nodes = [
                            node for node in node_details if node["status"] == "queued"]
                        if queued_nodes:
                            first_queued = queued_nodes[0]
                            queue_position = first_queued["queue_position"]
                            estimated_wait = first_queued["estimated_wait"]

                            # Initialize countdown logic: snapshot the first valid
                            # server estimate, then count down locally below.
                            if (
                                aggregate_status != last_status and estimated_wait != "?"
                            ) or (
                                estimated_wait != "?"
                                and queue_state["initial_estimated_wait"] is None
                            ):
                                try:
                                    wait_minutes = (
                                        int(estimated_wait)
                                        if isinstance(estimated_wait, (int, str))
                                        and str(estimated_wait).isdigit()
                                        else None
                                    )
                                    if wait_minutes is not None:
                                        queue_state["initial_estimated_wait"] = wait_minutes
                                        queue_state["queue_start_time"] = time.time(
                                        )
                                except (ValueError, TypeError):
                                    pass

                            # Calculate dynamic countdown
                            if (
                                queue_state["initial_estimated_wait"] is not None
                                and queue_state["queue_start_time"] is not None
                            ):
                                elapsed_minutes = (
                                    time.time() -
                                    queue_state["queue_start_time"]
                                ) / 60
                                remaining_wait = max(
                                    0,
                                    queue_state["initial_estimated_wait"] -
                                    elapsed_minutes,
                                )
                                wait_display = (
                                    f"{remaining_wait:.0f} min"
                                    if remaining_wait > 0
                                    else "Soon"
                                )
                            else:
                                wait_display = (
                                    f"{estimated_wait} min"
                                    if estimated_wait != "?"
                                    else "Calculating..."
                                )

                            if is_multinode:
                                total_gpus = sum(
                                    node["gpu_count"] for node in node_details if node["reservation"])
                                message = f"📋 Position #{queue_position} in queue • Estimated wait: {wait_display} • {total_gpus} GPUs across {total_nodes} nodes"
                            else:
                                gpu_count = first_queued["gpu_count"]
                                message = f"📋 You are #{queue_position} in queue • Estimated wait: {wait_display} • {gpu_count} GPU(s) requested"

                            # Show help message once when entering queue
                            if show_queue_help and aggregate_status != last_status:
                                help_text = "\n[dim]💡 Press [cyan]Ctrl+C[/cyan] to cancel reservation • Use [cyan]gpu-dev list[/cyan] to check status[/dim]"
                                console.print(help_text)
                                show_queue_help = False
                        else:
                            message = f"📋 Nodes in queue... ({active_count}/{total_nodes} ready)" if is_multinode else "📋 In queue..."

                    elif aggregate_status == "preparing":
                        if is_multinode:
                            # Show detailed preparation info for multinode - show ALL nodes, not just preparing ones
                            detailed_events = []
                            for node in node_details:
                                node_status = node["status"]
                                if node_status == "active":
                                    detailed_events.append(f"✓ Ready")
                                elif node.get("current_detailed_status"):
                                    detailed_events.append(
                                        node["current_detailed_status"])
                                elif node.get("failure_reason") and node_status == "failed":
                                    detailed_events.append(
                                        node["failure_reason"])
                                elif node_status == "preparing":
                                    detailed_events.append(
                                        "Preparing environment...")
                                elif node_status in ["pending", "queued"]:
                                    detailed_events.append(
                                        f"{node_status.title()}...")
                                else:
                                    detailed_events.append(node_status)

                            # Increased from 4 to 16
                            if detailed_events and len(detailed_events) <= 16:
                                # For multinode, create a custom multi-line display with individual spinners
                                from rich.table import Table
                                from rich.text import Text
                                from rich.panel import Panel
                                from rich.console import Group

                                # Create a list of renderable items
                                node_lines = []

                                for i, event in enumerate(detailed_events):
                                    node_num = i + 1
                                    node_status = node_details[i]["status"]

                                    # Create individual spinner or checkmark for each node
                                    if node_status == "active":
                                        # Ready node - show checkmark without spinner
                                        line = Text(
                                            f"✓ Node {node_num}: Ready", style="green")
                                    else:
                                        # Not ready - create a spinner for this specific node
                                        node_spinner = Spinner(
                                            "dots", text=f"Node {node_num}: {event}")
                                        line = node_spinner

                                    node_lines.append(line)

                                # Group all lines together
                                group = Group(*node_lines)

                                # Add summary line
                                summary = Text(
                                    f"({active_count}/{total_nodes} ready)", style="cyan")
                                full_display = Group(group, summary)

                                # Update live display with all spinners
                                panel = Panel(
                                    full_display, title="🚀 Multinode Setup", expand=False)
                                live.update(panel)

                                # Don't set message since we're using custom display
                                message = None
                            else:
                                # Summarize if we have many nodes
                                preparing_count = statuses.count(
                                    "preparing")
                                message = f"🚀 Preparing {preparing_count} nodes... ({active_count}/{total_nodes} ready)"
                        else:
                            # Show detailed preparation info for single node
                            preparing_nodes = [
                                node for node in node_details if node["status"] == "preparing"]
                            node = preparing_nodes[0] if preparing_nodes else node_details[0]

                            # Use unified status tracking - prefer current_detailed_status, fall back to failure_reason for actual failures
                            current_detailed_status = node.get(
                                "current_detailed_status", "")
                            failure_reason = node.get("failure_reason", "") if node.get(
                                "status") == "failed" else ""

                            if current_detailed_status:
                                message = f"🚀 {current_detailed_status}"
                            elif failure_reason:
                                message = f"🚀 Failed: {failure_reason}"
                            else:
                                message = status_messages.get(
                                    aggregate_status, f"Status: {aggregate_status}")

                    elif aggregate_status == "failed":
                        failed_nodes = [node for node in node_details if node["status"] in [
                            "failed", "cancelled"]]
                        if is_multinode:
                            failure_details = []
                            for node in failed_nodes:
                                reason = node["failure_reason"] if node["failure_reason"] else node["status"]
                                failure_details.append(
                                    f"Node {node['index']+1}: {reason}")

                            # Add debug info about all node statuses for troubleshooting
                            debug_details = []
                            for node in node_details:
                                status = node["status"]
                                debug_details.append(
                                    f"Node {node['index']+1}: {status}")

                            status_display = "\n".join(
                                [f" {detail}" for detail in failure_details])
                            debug_display = "\n".join(
                                [f" {detail}" for detail in debug_details])
                            message = f"❌ Multinode failed ({failed_count + cancelled_count}/{total_nodes})\n{status_display}"
                            live.update(Spinner("dots", text=message))
                            time.sleep(2)
                            console.print(
                                f"\n[red]❌ Multinode reservation failed ({failed_count + cancelled_count}/{total_nodes} nodes failed)[/red]")
                            for detail in failure_details:
                                console.print(f"[red] {detail}[/red]")
                            console.print(
                                f"\n[dim]Debug - All node statuses:[/dim]")
                            for detail in debug_details:
                                console.print(f"[dim] {detail}[/dim]")
                            return None
                        else:
                            # Handle single node failure below in completion check
                            pass

                    elif aggregate_status == "active":
                        if is_multinode:
                            # Check if all nodes are truly ready: "active" status AND valid SSH command
                            nodes_ready = 0
                            for node in node_details:
                                # Placeholder ("ssh user@pending") or internal
                                # cluster hostname means SSH isn't reachable yet.
                                if (node["status"] == "active" and
                                        node["reservation"] and
                                        node["reservation"].get("ssh_command", "ssh user@pending") not in ["ssh user@pending"] and
                                        not node["reservation"].get("ssh_command", "").endswith(".cluster.local")):
                                    nodes_ready += 1

                            if nodes_ready == total_nodes:
                                # All nodes truly ready with SSH access
                                live.update(
                                    Spinner("dots", text=f"✅ All {total_nodes} nodes ready!"))
                                time.sleep(1)
                                console.print(
                                    f"\n[green]✅ Multinode reservation complete! All {total_nodes} nodes are ready.[/green]")

                                # Create SSH config files and show connection info for each node
                                for node in node_details:
                                    if node["reservation"]:
                                        res = node["reservation"]
                                        fqdn = res.get("fqdn")
                                        pod_name = res.get("pod_name")
                                        res_id = res.get("reservation_id")
                                        res_name = res.get("name")

                                        # Create SSH config file for this node
                                        config_path = None
                                        use_include = False
                                        if fqdn and pod_name and res_id:
                                            try:
                                                config_path, use_include = create_ssh_config_for_reservation(
                                                    fqdn, pod_name, res_id, res_name)
                                            except Exception as e:
                                                console.print(
                                                    f"[yellow]⚠️ Could not create SSH config for node {node['index']+1}: {str(e)}[/yellow]")

                                        # Show connection info
                                        if config_path and pod_name and use_include:
                                            console.print(
                                                f"[cyan]🖥️ Node {node['index']+1}:[/cyan] [green]ssh {pod_name}[/green]")
                                        else:
                                            ssh_command = res.get(
                                                "ssh_command", "ssh user@pending")
                                            ssh_with_forwarding = _add_agent_forwarding_to_ssh(
                                                ssh_command)
                                            console.print(
                                                f"[cyan]🖥️ Node {node['index']+1}:[/cyan] {ssh_with_forwarding}")

                                return all_reservations
                            else:
                                # Some nodes are "active" but SSH not ready yet - keep preparing
                                # For multinode, don't override detailed Panel display with summary message
                                # The preparing logic above will show detailed per-node status
                                message = f"🚀 Setting up SSH access... ({nodes_ready}/{total_nodes} ready)"
                                # Don't directly update spinner here - let the main logic handle display
                        else:
                            # Handle single node completion below in completion check
                            pass

                    else:
                        # Default pending/unknown status
                        if is_multinode:
                            message = f"⏳ Processing multinode reservation... ({active_count}/{total_nodes} ready)"
                        else:
                            # Check for detailed status during Docker builds or other detailed operations
                            if aggregate_status == "creating_server" and len(all_reservations) > 0:
                                reservation = all_reservations[0]
                                current_detailed_status = reservation.get(
                                    "current_detailed_status", "")
                                if current_detailed_status:
                                    message = f"🐳 {current_detailed_status}"
                                else:
                                    message = status_messages.get(
                                        aggregate_status, f"Status: {aggregate_status}")
                            else:
                                message = status_messages.get(
                                    aggregate_status, f"Status: {aggregate_status}")

                    # Update spinner if status changed, message changed, or we're in certain states
                    # BUT: Don't override custom Panel display for multinode with spinner
                    if (aggregate_status != last_status or
                            message != last_message or
                            aggregate_status in ["queued", "preparing", "creating_server"]):
                        if message and not (is_multinode and aggregate_status == "preparing"):
                            # Only use spinner for single-node or non-preparing multinode states
                            spinner.text = message
                            last_status = aggregate_status
                            last_message = message
                            live.update(spinner)
                    elif not is_multinode and message:
                        # Single node - always use spinner
                        spinner.text = message
                        last_status = aggregate_status
                        last_message = message
                        live.update(spinner)
                    # For multinode preparing with custom display, we already updated above with Panel

                    # Check for single-node completion states (when not multinode or already handled above)
                    if not is_multinode and aggregate_status == "active":
                        reservation = all_reservations[0]
                        ssh_command = reservation.get(
                            "ssh_command", "ssh user@pending")

                        # Only complete if we have a real SSH command (not pending/placeholder)
                        if ssh_command != "ssh user@pending" and not ssh_command.endswith(".cluster.local"):
                            live.stop()
                            duration_hours = reservation.get(
                                "duration_hours", 8)
                            reservation_id = reservation["reservation_id"]

                            console.print(
                                f"\n[green]✅ Reservation complete![/green]")
                            console.print(
                                f"[cyan]📋 Reservation ID:[/cyan] {reservation_id}")
                            console.print(
                                f"[cyan]⏰ Valid for:[/cyan] {duration_hours} hours")

                            # Show quick connect command
                            short_id = reservation_id[:8]
                            console.print(
                                f"[cyan]⚡ Quick Connect:[/cyan] [green]gpu-dev connect {short_id}[/green]")

                            # Always create SSH config file for this reservation
                            fqdn = reservation.get("fqdn")
                            pod_name = reservation.get("pod_name")
                            res_id = reservation.get("reservation_id")
                            res_name = reservation.get("name")
                            config_path = None
                            use_include = False
                            if fqdn and pod_name and res_id:
                                try:
                                    config_path, use_include = create_ssh_config_for_reservation(
                                        fqdn, pod_name, res_id, res_name)
                                except Exception as e:
                                    console.print(
                                        f"[yellow]⚠️ Could not create SSH config: {str(e)}[/yellow]")

                            # Show SSH command using config file if created, otherwise fallback
                            if config_path and pod_name:
                                if use_include:
                                    # User approved Include - show simple commands
                                    console.print(
                                        f"[cyan]🖥️ SSH Command:[/cyan] [green]ssh {pod_name}[/green]")
                                    # Create clickable VS Code link
                                    vscode_url = _make_vscode_link(pod_name)
                                    vscode_command = f"code --remote ssh-remote+{pod_name} /home/dev"
                                    console.print(
                                        f"[cyan]💻 VS Code Remote:[/cyan] [link={vscode_url}][green]{vscode_command}[/green][/link]")

                                    # Create clickable Cursor link
                                    cursor_url = _make_cursor_link(pod_name)
                                    cursor_command = f"cursor --remote ssh-remote+{pod_name} /home/dev"
                                    console.print(
                                        f"[cyan]🖥️ Cursor Remote:[/cyan] [link={cursor_url}][green]{cursor_command}[/green][/link]")
                                else:
                                    # User declined Include - show commands with -F flag
                                    console.print(
                                        f"[cyan]🖥️ SSH Command:[/cyan] [green]ssh -F {config_path} {pod_name}[/green]")
                                    console.print(
                                        f"[cyan]💻 VS Code/Cursor:[/cyan] Add [green]Include ~/.gpu-dev/*-sshconfig[/green] to ~/.ssh/config and ~/.cursor/ssh_config")
                                    console.print(
                                        f"[dim] Or run: [green]gpu-dev config ssh-include enable[/green][/dim]")
                            else:
                                # Fallback to full SSH command if config creation failed
                                ssh_with_forwarding = _add_agent_forwarding_to_ssh(ssh_command)
                                console.print(
                                    f"[cyan]🖥️ SSH Command:[/cyan] {ssh_with_forwarding}")

                                vscode_command = _generate_vscode_command(ssh_command)
                                if vscode_command:
                                    console.print(
                                        f"[cyan]💻 VS Code Remote:[/cyan] {vscode_command}")

                                cursor_command = _generate_cursor_command(ssh_command)
                                if cursor_command:
                                    console.print(
                                        f"[cyan]🖱️ Cursor Remote:[/cyan] {cursor_command}")

                            # Show Jupyter link if enabled
                            jupyter_enabled = reservation.get(
                                "jupyter_enabled", False)
                            jupyter_url = reservation.get(
                                "jupyter_url", "")
                            if jupyter_enabled and jupyter_url:
                                console.print(
                                    f"[cyan]📊 Jupyter Lab:[/cyan] {jupyter_url}")

                            return all_reservations
                        else:
                            # Still preparing - show status but don't complete yet
                            current_detailed_status = reservation.get(
                                "current_detailed_status", "")
                            if current_detailed_status:
                                message = f"🚀 {current_detailed_status}"
                            else:
                                message = "🚀 Setting up external SSH access..."

                            if message != (last_status if isinstance(last_status, str) else ""):
                                spinner.text = message
                                live.update(spinner)

                    elif not is_multinode and aggregate_status in ["failed", "cancelled"]:
                        live.stop()
                        reservation = all_reservations[0] if all_reservations else {
                        }
                        failure_reason = reservation.get(
                            "failure_reason",
                            reservation.get("current_detailed_status", "Unknown error"))
                        reservation_id = reservation.get(
                            "reservation_id", "unknown")

                        if aggregate_status == "failed":
                            console.print(
                                f"\n[red]❌ Reservation failed: {failure_reason}[/red]")
                            console.print(
                                f"[red]📋 Reservation ID: {reservation_id}[/red]")

                            # Show pod logs if available
                            pod_logs = reservation.get("pod_logs", "")
                            if pod_logs and pod_logs.strip():
                                from rich.panel import Panel
                                from rich.text import Text

                                console.print(
                                    "\n[red]🔍 Pod logs (last 20 lines) - Details:[/red]")
                                log_text = Text(pod_logs)
                                log_panel = Panel(
                                    log_text,
                                    title="🐚 Container Startup Logs",
                                    title_align="left",
                                    border_style="red",
                                    expand=False,
                                )
                                console.print(log_panel)
                        else:
                            console.print(
                                f"\n[yellow]🛑 Reservation was cancelled[/yellow]")

                        return None

                    # Continue polling
                    time.sleep(3)

                except Exception as e:
                    console.print(
                        f"\n[red]❌ Error polling reservation status: {str(e)}[/red]")
                    return None

            # Handle cancellation
            if cancelled:
                live.stop()
                action_text = "multinode reservation" if is_multinode else "reservation request"
                console.print(
                    f"\n[yellow]⚠️ Cancelling {action_text}...[/yellow]")

                # Cancel all reservations
                success_count = 0
                for res_id in reservation_ids:
                    try:
                        # Look up user_id because cancel_reservation authorizes by user.
                        response = self.reservations_table.get_item(
                            Key={"reservation_id": res_id})
                        if "Item" in response:
                            user_id = response["Item"].get(
                                "user_id", "unknown")
                            if self.cancel_reservation(res_id, user_id):
                                success_count += 1
                    except Exception as e:
                        console.print(
                            f"[red]❌ Error cancelling reservation {res_id[:8]}: {str(e)}[/red]")

                if success_count == len(reservation_ids):
                    success_text = "All reservations cancelled successfully" if is_multinode else "Reservation cancelled successfully"
                    console.print(f"[green]✅ {success_text}[/green]")
                elif success_count > 0:
                    console.print(
                        f"[yellow]⚠️ {success_count}/{len(reservation_ids)} reservations cancelled[/yellow]")
                else:
                    fail_text = "Failed to cancel reservations" if is_multinode else "Failed to cancel reservation"
                    console.print(f"[red]❌ {fail_text}[/red]")

                return None

            # Handle clean exit
            if close_tool:
                live.stop()
                if is_multinode:
                    id_display = ", ".join([res_id[:8]
                                            for res_id in reservation_ids])
                    console.print(
                        f"\n[cyan]📱 Exiting - multinode reservations {id_display} continue in background...[/cyan]")
                else:
                    console.print(
                        f"\n[cyan]📱 Exiting - reservation {reservation_ids[0][:8]} continues in background...[/cyan]")
                console.print(
                    "[cyan]💡 Use 'gpu-dev list' to check status[/cyan]")
                if not is_multinode:
                    console.print(
                        f"[cyan]💡 Use 'gpu-dev show {reservation_ids[0][:8]}' to get connection details when ready[/cyan]")
                return None

            # Timeout reached
            live.stop()
            if timeout_minutes is not None:
                console.print(
                    f"\n[yellow]⏰ Timeout reached after {timeout_minutes} minutes[/yellow]")
            else:
                console.print(
                    f"\n[yellow]⏰ Polling stopped unexpectedly[/yellow]")
            console.print(
                "[yellow]🔍 Check reservation status manually with: gpu-dev list[/yellow]")
            return None

    finally:
        # Restore original signal handlers
        signal.signal(signal.SIGINT, old_handler)
        try:
            signal.signal(signal.SIGTERM, signal.SIG_DFL)
            signal.signal(signal.SIGQUIT, signal.SIG_DFL)
        except (AttributeError, OSError):
            pass
|
|
2224
|
+
|
|
2225
|
+
def wait_for_reservation_completion(
    self, reservation_id: str, timeout_minutes: Optional[int] = 10, verbose: bool = False
) -> Optional[Dict[str, Any]]:
    """Poll for single reservation completion using shared polling logic (always creates SSH config)

    Thin single-node wrapper: delegates to the shared multi-reservation
    poller with a one-element id list and unwraps the result.

    Args:
        reservation_id: The reservation to poll.
        timeout_minutes: Give up after this many minutes; None polls forever.
        verbose: Forwarded to the shared poller for [DEBUG] output.

    Returns:
        The reservation item dict when it becomes active, otherwise None.
    """
    completed = self._wait_for_reservations_completion(
        [reservation_id], timeout_minutes, is_multinode=False, verbose=verbose)
    if not completed:
        return None
    return completed[0]
|