gpu-dev 0.3.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,2231 @@
1
+ """Minimal reservation management for GPU Dev CLI"""
2
+
3
+ import json
4
+ import os
5
+ import select
6
+ import signal
7
+ import sys
8
+ import time
9
+ import uuid
10
+ from datetime import datetime, timedelta
11
+ from decimal import Decimal
12
+ from typing import Optional, List, Dict, Any, Union
13
+
14
+ from botocore.exceptions import ClientError
15
+ from rich.console import Console
16
+ from rich.live import Live
17
+ from rich.spinner import Spinner
18
+
19
+ from .config import Config
20
+ from .name_generator import sanitize_name
21
+ from . import __version__
22
+
23
+ console = Console()
24
+
25
+
26
+ def _make_vscode_link(pod_name: str) -> str:
27
+ """Create a clickable vscode:// URL for opening remote SSH connection
28
+
29
+ Args:
30
+ pod_name: The SSH host name (e.g., gpu-dev-34f5f9e0)
31
+
32
+ Returns:
33
+ A vscode:// URL that opens VS Code with the remote SSH connection
34
+ """
35
+ # VS Code remote SSH URL format: vscode://vscode-remote/ssh-remote+<host>/path
36
+ return f"vscode://vscode-remote/ssh-remote+{pod_name}/home/dev"
37
+
38
+
39
+ def _make_cursor_link(pod_name: str) -> str:
40
+ """Create a clickable cursor:// URL for opening remote SSH connection in Cursor
41
+
42
+ Args:
43
+ pod_name: The SSH host name (e.g., gpu-dev-34f5f9e0)
44
+
45
+ Returns:
46
+ A cursor:// URL that opens Cursor with the remote SSH connection
47
+ """
48
+ # Based on VS Code remote SSH URL format: cursor://vscode-remote/ssh-remote+<host>/path
49
+ return f"cursor://vscode-remote/ssh-remote+{pod_name}/home/dev"
50
+
51
+
52
def get_version() -> str:
    """Return the CLI version string for inclusion in SQS messages.

    Delegates to the package-level ``__version__`` so every enqueued
    request records which client produced it.
    """
    return __version__
55
+
56
+
57
+ def _add_agent_forwarding_to_ssh(ssh_command: str) -> str:
58
+ """Add SSH agent forwarding (-A) flag to SSH command if not already present"""
59
+ try:
60
+ if not ssh_command or not ssh_command.startswith("ssh "):
61
+ return ssh_command
62
+
63
+ # Check if -A is already in the command
64
+ if " -A" in ssh_command or ssh_command.endswith(" -A"):
65
+ return ssh_command
66
+
67
+ # Add -A flag after 'ssh'
68
+ parts = ssh_command.split(" ", 1)
69
+ if len(parts) == 2:
70
+ return f"ssh -A {parts[1]}"
71
+ else:
72
+ return "ssh -A"
73
+
74
+ except Exception:
75
+ return ssh_command
76
+
77
+
78
+ def _extract_latest_pod_event(pod_events: str) -> str:
79
+ """Extract the most relevant pod event for display - simplified since Lambda now provides formatted messages"""
80
+ if not pod_events:
81
+ return "Starting pod..."
82
+
83
+ # Lambda now provides pre-formatted messages, so just return them
84
+ # Handle multi-line messages by taking the first meaningful line
85
+ lines = pod_events.split("\n")
86
+ for line in lines:
87
+ line = line.strip()
88
+ if line and not line.startswith("Events:"):
89
+ return line
90
+
91
+ return "Starting pod..."
92
+
93
+
94
+ def _generate_vscode_command(ssh_command: str) -> Optional[str]:
95
+ """Generate VS Code remote connection command from SSH command"""
96
+ try:
97
+ # Extract remote server from SSH command
98
+ # Expected format: ssh dev@<hostname> or various formats with -o options
99
+ if not ssh_command or not ssh_command.startswith("ssh "):
100
+ return None
101
+
102
+ # Parse SSH command to extract hostname
103
+ parts = ssh_command.split()
104
+ hostname = None
105
+
106
+ for i, part in enumerate(parts):
107
+ if "@" in part and not part.startswith("-"):
108
+ # Extract just the hostname part (e.g., from dev@hostname.io)
109
+ hostname = part.split("@")[1]
110
+ break
111
+
112
+ if not hostname:
113
+ return None
114
+
115
+ # Generate VS Code command with ProxyCommand and agent forwarding
116
+ # VS Code will use the ssh command options we provide
117
+ remote_server = f"dev@{hostname}"
118
+
119
+ # Escape single quotes in the ProxyCommand for shell
120
+ proxy_cmd = "gpu-dev-ssh-proxy %h %p"
121
+
122
+ return (
123
+ f"code --remote ssh-remote+{remote_server} "
124
+ f"--ssh-option ForwardAgent=yes "
125
+ f"--ssh-option ProxyCommand='{proxy_cmd}' "
126
+ f"--ssh-option StrictHostKeyChecking=no "
127
+ f"--ssh-option UserKnownHostsFile=/dev/null "
128
+ f"/home/dev"
129
+ )
130
+
131
+ except Exception:
132
+ return None
133
+
134
+
135
+ def _generate_cursor_command(ssh_command: str) -> Optional[str]:
136
+ """Generate Cursor remote connection command from SSH command"""
137
+ try:
138
+ # Extract remote server from SSH command
139
+ # Expected format: ssh dev@<hostname> or various formats with -o options
140
+ if not ssh_command or not ssh_command.startswith("ssh "):
141
+ return None
142
+
143
+ # Parse SSH command to extract hostname
144
+ parts = ssh_command.split()
145
+ remote_server = parts[-1]
146
+ if '@' in remote_server:
147
+ remote_server = remote_server.split('@')[1]
148
+
149
+ # Return the VS Code command format
150
+ return f"cursor --remote ssh-remote+{remote_server} /home/dev"
151
+ except Exception:
152
+ return None
153
+
154
+
155
+ def _generate_ssh_config(hostname: str, pod_name: str) -> str:
156
+ """Generate SSH config for a reservation
157
+
158
+ Args:
159
+ hostname: The FQDN hostname (e.g., old_bison.devservers.io)
160
+ pod_name: The pod name to use as SSH host alias
161
+
162
+ Returns:
163
+ SSH config content as string
164
+ """
165
+ config_content = f"""Host {pod_name}
166
+ HostName {hostname}
167
+ User dev
168
+ ForwardAgent yes
169
+ ProxyCommand gpu-dev-ssh-proxy %h %p
170
+ StrictHostKeyChecking no
171
+ UserKnownHostsFile /dev/null
172
+ """
173
+ return config_content
174
+
175
+
176
def _check_ssh_config_permission() -> bool:
    """Check if user has given permission to modify ~/.ssh/config and ~/.cursor/ssh_config

    The answer is cached in ~/.gpu-dev/.ssh-config-permission ("yes"/"no")
    so the interactive prompt is shown at most once. An already-present
    Include line in either config file counts as implicit approval.

    Returns:
        True if permission granted or already set up, False otherwise
    """
    import click
    from pathlib import Path

    gpu_dev_dir = Path.home() / ".gpu-dev"
    permission_file = gpu_dev_dir / ".ssh-config-permission"

    # Check if already asked and answered
    if permission_file.exists():
        try:
            response = permission_file.read_text().strip()
            return response == "yes"
        except Exception:
            # Unreadable cache file: fall through and re-detect / re-ask.
            pass

    # Check if Include already exists in either ~/.ssh/config or ~/.cursor/ssh_config
    config_files = [
        Path.home() / ".ssh" / "config",
        Path.home() / ".cursor" / "ssh_config",
    ]

    for ssh_config in config_files:
        if ssh_config.exists():
            try:
                content = ssh_config.read_text()
                if "Include ~/.gpu-dev/" in content:
                    # Already set up, save permission
                    gpu_dev_dir.mkdir(mode=0o700, exist_ok=True)
                    permission_file.write_text("yes")
                    return True
            except Exception:
                # Unreadable config file: skip it and check the next one.
                pass

    # Ask user for permission
    console.print("\n[yellow]━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[/yellow]")
    console.print("[cyan]🔧 SSH Configuration Setup[/cyan]\n")
    console.print("To enable easy SSH access and VS Code/Cursor Remote connections,")
    console.print("we can add GPU dev server configs to your SSH config files.")
    console.print("\n[dim]This adds one line at the top of:[/dim]")
    console.print("[dim] • ~/.ssh/config[/dim]")
    console.print("[dim] • ~/.cursor/ssh_config[/dim]")
    console.print("[dim]Line added: Include ~/.gpu-dev/*-sshconfig[/dim]\n")
    console.print("[green]Benefits:[/green]")
    console.print(" • Simple commands: [green]ssh <pod-name>[/green]")
    console.print(" • VS Code Remote works: [green]code --remote ssh-remote+<pod-name>[/green]")
    console.print(" • Cursor Remote works: Open Remote SSH in Cursor")
    console.print("\n[dim]Without this, you'll need to use: [green]ssh -F ~/.gpu-dev/<id>-sshconfig <pod-name>[/green][/dim]")
    console.print("[yellow]━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[/yellow]\n")

    approved = click.confirm("Add Include directive to SSH config files?", default=True)

    # Save response so the prompt never repeats.
    gpu_dev_dir.mkdir(mode=0o700, exist_ok=True)
    permission_file.write_text("yes" if approved else "no")

    return approved
237
+
238
+
239
def _ensure_ssh_config_includes_devgpu() -> bool:
    """Ensure the SSH config files include the gpu-dev per-reservation configs.

    Touches both ~/.ssh/config and ~/.cursor/ssh_config (the latter for
    VS Code/Cursor compatibility), asking the user for permission first.

    Returns:
        True if the Include line exists or was added to at least one
        file, False if the user declined.
    """
    from pathlib import Path

    # Respect the user's stored (or freshly asked) decision.
    if not _check_ssh_config_permission():
        return False

    include_line = "Include ~/.gpu-dev/*-sshconfig\n"

    # Files to update: ~/.ssh/config and ~/.cursor/ssh_config.
    targets = [
        (Path.home() / ".ssh", "config"),
        (Path.home() / ".cursor", "ssh_config"),
    ]

    updated = False
    for directory, filename in targets:
        try:
            directory.mkdir(mode=0o700, exist_ok=True)
            config_path = directory / filename

            existing = config_path.read_text() if config_path.exists() else ""

            # Nothing to do when an Include is already present.
            if "Include ~/.gpu-dev/" in existing:
                updated = True
                continue

            # SSH requires Include before any Host block, so prepend.
            config_path.write_text(include_line + "\n" + existing)
            config_path.chmod(0o600)
            updated = True
        except Exception:
            # Best effort per file; keep trying the remaining targets.
            pass

    return updated
288
+
289
+
290
def create_ssh_config_for_reservation(hostname: str, pod_name: str, reservation_id: str, name: Optional[str] = None) -> tuple[Optional[str], bool]:
    """Create the per-reservation SSH config file in ~/.gpu-dev/.

    Args:
        hostname: The FQDN hostname (e.g., old_bison.devservers.io)
        pod_name: The pod name to use as SSH host alias
        reservation_id: The reservation ID (full or short)
        name: Unused; kept for backwards compatibility

    Returns:
        Tuple of (config_path, use_include) where config_path is the
        created file's path (None on error) and use_include tells whether
        ~/.ssh/config includes the gpu-dev configs (False means callers
        must show the -F variant).
    """
    from pathlib import Path

    gpu_dev_dir = Path.home() / ".gpu-dev"
    gpu_dev_dir.mkdir(mode=0o700, exist_ok=True)

    # File names always use the 8-char short ID: reservation names may
    # contain characters like '/' (multinode "Node 1/2") that are not
    # filename-safe.
    config_path = gpu_dev_dir / f"{reservation_id[:8]}-sshconfig"

    try:
        config_path.write_text(_generate_ssh_config(hostname, pod_name))
        config_path.chmod(0o600)

        # Check (or ask for) permission to Include from ~/.ssh/config.
        use_include = _ensure_ssh_config_includes_devgpu()

        return (str(config_path), use_include)
    except Exception:
        return (None, False)
328
+
329
+
330
def remove_ssh_config_for_reservation(reservation_id: str, name: Optional[str] = None) -> bool:
    """Remove the SSH config file for a reservation.

    Args:
        reservation_id: The reservation ID (full or short)
        name: Unused; kept for backwards compatibility

    Returns:
        True if successful (or file didn't exist), False on error
    """
    from pathlib import Path

    # Always use the short ID for the filename, matching
    # create_ssh_config_for_reservation.
    config_file = Path.home() / ".gpu-dev" / f"{reservation_id[:8]}-sshconfig"

    try:
        # missing_ok removes the exists()/unlink() race and makes the
        # "file already gone" case an explicit success, as documented.
        config_file.unlink(missing_ok=True)
        return True
    except Exception:
        return False
354
+
355
+
356
def is_ssh_include_enabled() -> bool:
    """Report whether the user approved the SSH config Include directive.

    Returns:
        True when the stored permission marker contains "yes",
        False otherwise (including when the marker is missing or
        unreadable).
    """
    from pathlib import Path

    marker = Path.home() / ".gpu-dev" / ".ssh-config-permission"
    try:
        return marker.exists() and marker.read_text().strip() == "yes"
    except Exception:
        # Unreadable marker is treated the same as "not approved".
        return False
371
+
372
+
373
def get_ssh_config_path(reservation_id: str, name: Optional[str] = None) -> str:
    """Return the SSH config file path for a reservation.

    Args:
        reservation_id: The reservation ID (full or short)
        name: Unused; kept for backwards compatibility

    Returns:
        Path string for the config file (the file may not exist yet)
    """
    from pathlib import Path

    # Filename is keyed on the 8-char short ID, matching
    # create_ssh_config_for_reservation.
    return str(Path.home() / ".gpu-dev" / f"{reservation_id[:8]}-sshconfig")
388
+
389
+
390
+ class ReservationManager:
391
+ """Minimal GPU reservations manager - AWS-only"""
392
+
393
    def __init__(self, config: Config):
        """Store the CLI config and bind the DynamoDB reservations table.

        Args:
            config: Shared CLI configuration providing AWS clients
                (dynamodb, sqs_client) and resource names.
        """
        self.config = config
        # Table handle used by every query/scan helper below.
        self.reservations_table = config.dynamodb.Table(
            config.reservations_table)
397
+
398
+ def create_reservation(
399
+ self,
400
+ user_id: str,
401
+ gpu_count: int,
402
+ gpu_type: str,
403
+ duration_hours: Union[int, float],
404
+ name: Optional[str] = None,
405
+ github_user: Optional[str] = None,
406
+ jupyter_enabled: bool = False,
407
+ recreate_env: bool = False,
408
+ dockerfile: Optional[str] = None,
409
+ no_persistent_disk: bool = False,
410
+ dockerimage: Optional[str] = None,
411
+ preserve_entrypoint: bool = False,
412
+ disk_name: Optional[str] = None,
413
+ node_labels: Optional[Dict[str, str]] = None,
414
+ ) -> Optional[str]:
415
+ """Create a new GPU reservation"""
416
+ try:
417
+ reservation_id = str(uuid.uuid4())
418
+ created_at = datetime.utcnow().isoformat()
419
+
420
+ # Process the name: sanitize user input or let Lambda generate
421
+ processed_name = None
422
+ if name:
423
+ # Sanitize user-provided name
424
+ processed_name = sanitize_name(name)
425
+ # If sanitization results in empty string, let Lambda generate
426
+ if not processed_name:
427
+ processed_name = None
428
+ # If no name provided, let Lambda generate (processed_name stays None)
429
+
430
+ # Create initial reservation record for polling
431
+ # Convert float to Decimal for DynamoDB compatibility
432
+ duration_decimal = Decimal(str(duration_hours))
433
+
434
+ initial_reservation = {
435
+ "reservation_id": reservation_id,
436
+ "user_id": user_id,
437
+ "gpu_count": gpu_count,
438
+ "gpu_type": gpu_type,
439
+ "duration_hours": duration_decimal,
440
+ "name": processed_name,
441
+ "created_at": created_at,
442
+ "status": "pending",
443
+ "expires_at": (
444
+ datetime.utcnow() + timedelta(hours=duration_hours)
445
+ ).isoformat(),
446
+ "jupyter_enabled": jupyter_enabled,
447
+ }
448
+
449
+ # Add github_user if provided
450
+ if github_user:
451
+ initial_reservation["github_user"] = github_user
452
+
453
+ # Send processing request to SQS queue (Lambda will create the initial record)
454
+ # Use float for SQS message (JSON serializable)
455
+ message = {
456
+ "reservation_id": reservation_id,
457
+ "user_id": user_id,
458
+ "gpu_count": gpu_count,
459
+ "gpu_type": gpu_type,
460
+ "duration_hours": float(duration_hours),
461
+ "name": processed_name,
462
+ "created_at": created_at,
463
+ "status": "pending",
464
+ "jupyter_enabled": jupyter_enabled,
465
+ "recreate_env": recreate_env,
466
+ "no_persistent_disk": no_persistent_disk,
467
+ "version": get_version(),
468
+ }
469
+
470
+ # Add github_user if provided
471
+ if github_user:
472
+ message["github_user"] = github_user
473
+
474
+ # Add Docker options if provided
475
+ if dockerfile:
476
+ message["dockerfile"] = dockerfile
477
+ if dockerimage:
478
+ message["dockerimage"] = dockerimage
479
+ # Always include preserve_entrypoint flag (don't make it conditional)
480
+ message["preserve_entrypoint"] = preserve_entrypoint
481
+
482
+ # Add disk_name if provided
483
+ if disk_name:
484
+ message["disk_name"] = disk_name
485
+
486
+ # Add node_labels if provided (for node selection preferences)
487
+ if node_labels:
488
+ message["node_labels"] = node_labels
489
+
490
+ queue_url = self.config.get_queue_url()
491
+ self.config.sqs_client.send_message(
492
+ QueueUrl=queue_url, MessageBody=json.dumps(message)
493
+ )
494
+
495
+ return reservation_id
496
+
497
+ except Exception as e:
498
+ console.print(f"[red]❌ Error creating reservation: {str(e)}[/red]")
499
+ return None
500
+
501
+ def create_multinode_reservation(
502
+ self,
503
+ user_id: str,
504
+ gpu_count: int,
505
+ gpu_type: str,
506
+ duration_hours: Union[int, float],
507
+ name: Optional[str] = None,
508
+ github_user: Optional[str] = None,
509
+ jupyter_enabled: bool = False,
510
+ recreate_env: bool = False,
511
+ dockerfile: Optional[str] = None,
512
+ dockerimage: Optional[str] = None,
513
+ no_persistent_disk: bool = False,
514
+ preserve_entrypoint: bool = False,
515
+ disk_name: Optional[str] = None,
516
+ node_labels: Optional[Dict[str, str]] = None,
517
+ ) -> Optional[List[str]]:
518
+ """Create multiple GPU reservations for multinode setup"""
519
+ try:
520
+ # Determine GPU config
521
+ gpu_configs = {
522
+ "t4": {"max_gpus": 4},
523
+ "l4": {"max_gpus": 4},
524
+ "a10g": {"max_gpus": 4},
525
+ "t4-small": {"max_gpus": 1},
526
+ "g5g": {"max_gpus": 2},
527
+ "a100": {"max_gpus": 8},
528
+ "h100": {"max_gpus": 8},
529
+ "h200": {"max_gpus": 8},
530
+ "b200": {"max_gpus": 8},
531
+ }
532
+
533
+ max_gpus_per_node = gpu_configs[gpu_type]["max_gpus"]
534
+ num_nodes = gpu_count // max_gpus_per_node
535
+
536
+ if gpu_count % max_gpus_per_node != 0:
537
+ console.print(
538
+ f"[red]❌ GPU count must be multiple of {max_gpus_per_node} for {gpu_type}[/red]")
539
+ return None
540
+
541
+ # Generate a master reservation ID to group related nodes
542
+ master_reservation_id = str(uuid.uuid4())
543
+ created_at = datetime.utcnow().isoformat()
544
+ reservation_ids = []
545
+
546
+ # Create reservation for each node
547
+ for node_idx in range(num_nodes):
548
+ node_reservation_id = str(uuid.uuid4())
549
+ reservation_ids.append(node_reservation_id)
550
+
551
+ # Node-specific name
552
+ base_name = name or f'{gpu_count}x {gpu_type.upper()} multinode'
553
+ node_name = f"{base_name} - Node {node_idx + 1}/{num_nodes}"
554
+
555
+ # Create reservation message for this node
556
+ message = {
557
+ "reservation_id": node_reservation_id,
558
+ "master_reservation_id": master_reservation_id, # Group related nodes
559
+ "node_index": node_idx,
560
+ "total_nodes": num_nodes,
561
+ "user_id": user_id,
562
+ "gpu_count": max_gpus_per_node, # GPUs per node
563
+ "total_gpu_count": gpu_count, # Total GPUs across all nodes
564
+ "gpu_type": gpu_type,
565
+ "duration_hours": float(duration_hours),
566
+ "name": node_name,
567
+ "version": get_version(),
568
+ "created_at": created_at,
569
+ "status": "pending",
570
+ # Only enable Jupyter on master node
571
+ "jupyter_enabled": jupyter_enabled and node_idx == 0,
572
+ "recreate_env": recreate_env,
573
+ "is_multinode": True,
574
+ "no_persistent_disk": no_persistent_disk,
575
+ }
576
+
577
+ if github_user:
578
+ message["github_user"] = github_user
579
+
580
+ # Add Docker options if provided
581
+ if dockerfile:
582
+ message["dockerfile"] = dockerfile
583
+ if dockerimage:
584
+ message["dockerimage"] = dockerimage
585
+ # Always include preserve_entrypoint flag (don't make it conditional)
586
+ message["preserve_entrypoint"] = preserve_entrypoint
587
+
588
+ # Add disk_name if provided (only for master node in multinode setup)
589
+ if disk_name and node_idx == 0:
590
+ message["disk_name"] = disk_name
591
+
592
+ # Add node_labels if provided (for node selection preferences)
593
+ if node_labels:
594
+ message["node_labels"] = node_labels
595
+
596
+ # Send to SQS queue
597
+ queue_url = self.config.get_queue_url()
598
+ self.config.sqs_client.send_message(
599
+ QueueUrl=queue_url, MessageBody=json.dumps(message)
600
+ )
601
+
602
+ return reservation_ids
603
+
604
+ except Exception as e:
605
+ console.print(
606
+ f"[red]❌ Error creating multinode reservation: {str(e)}[/red]")
607
+ return None
608
+
609
+ def list_reservations(
610
+ self,
611
+ user_filter: Optional[str] = None,
612
+ statuses_to_include: Optional[List[str]] = None,
613
+ ) -> List[Dict[str, Any]]:
614
+ """List GPU reservations with flexible filtering"""
615
+ try:
616
+ all_reservations = []
617
+
618
+ if user_filter:
619
+ # Query by specific user with pagination
620
+ response = self.reservations_table.query(
621
+ IndexName="UserIndex",
622
+ KeyConditionExpression="user_id = :user_id",
623
+ ExpressionAttributeValues={":user_id": user_filter},
624
+ )
625
+ all_reservations = response.get("Items", [])
626
+
627
+ # Handle pagination for UserIndex query
628
+ while "LastEvaluatedKey" in response:
629
+ response = self.reservations_table.query(
630
+ IndexName="UserIndex",
631
+ KeyConditionExpression="user_id = :user_id",
632
+ ExpressionAttributeValues={":user_id": user_filter},
633
+ ExclusiveStartKey=response["LastEvaluatedKey"]
634
+ )
635
+ all_reservations.extend(response.get("Items", []))
636
+ else:
637
+ # Get all reservations (scan with pagination for admin use)
638
+ all_reservations = []
639
+ response = self.reservations_table.scan()
640
+ all_reservations.extend(response.get("Items", []))
641
+
642
+ # Handle pagination
643
+ while "LastEvaluatedKey" in response:
644
+ response = self.reservations_table.scan(
645
+ ExclusiveStartKey=response["LastEvaluatedKey"]
646
+ )
647
+ all_reservations.extend(response.get("Items", []))
648
+
649
+ # Filter by status if specified
650
+ if statuses_to_include:
651
+ filtered_reservations = [
652
+ reservation
653
+ for reservation in all_reservations
654
+ if reservation.get("status") in statuses_to_include
655
+ ]
656
+ return filtered_reservations
657
+
658
+ return all_reservations
659
+
660
+ except Exception as e:
661
+ console.print(f"[red]❌ Error listing reservations: {str(e)}[/red]")
662
+ return []
663
+
664
+ def cancel_reservation(self, reservation_id: str, user_id: str) -> bool:
665
+ """Cancel a GPU reservation by sending cancellation message to queue"""
666
+ try:
667
+ # Send cancellation request to SQS queue for processing
668
+ message = {
669
+ "type": "cancellation",
670
+ "reservation_id": reservation_id,
671
+ "user_id": user_id,
672
+ "requested_at": datetime.utcnow().isoformat(),
673
+ "version": get_version(),
674
+ }
675
+
676
+ queue_url = self.config.get_queue_url()
677
+ self.config.sqs_client.send_message(
678
+ QueueUrl=queue_url, MessageBody=json.dumps(message)
679
+ )
680
+
681
+ console.print(
682
+ f"[yellow]⏳ Cancellation request submitted for reservation {reservation_id[:8]}...[/yellow]"
683
+ )
684
+ console.print(
685
+ "[yellow]💡 The reservation will be cancelled shortly. Use 'gpu-dev list' to check status.[/yellow]"
686
+ )
687
+ return True
688
+
689
+ except Exception as e:
690
+ console.print(
691
+ f"[red]❌ Error submitting cancellation request: {str(e)}[/red]"
692
+ )
693
+ return False
694
+
695
    def wait_for_multinode_reservation_completion(
        self, reservation_ids: List[str], timeout_minutes: Optional[int] = 10, verbose: bool = False
    ) -> Optional[List[Dict[str, Any]]]:
        """Poll until all node reservations of a multinode job complete.

        Thin wrapper over the shared polling helper, forcing
        is_multinode=True.

        Args:
            reservation_ids: IDs of every node reservation to wait on.
            timeout_minutes: Polling timeout; semantics (including None)
                live in the shared helper — presumably "wait forever".
            verbose: Forwarded for extra progress output.

        Returns:
            Per-node reservation details on success, or None (exact
            contract defined by _wait_for_reservations_completion).
        """
        return self._wait_for_reservations_completion(reservation_ids, timeout_minutes, is_multinode=True, verbose=verbose)
700
+
701
+ def get_connection_info(
702
+ self, reservation_id: str, user_id: str
703
+ ) -> Optional[Dict[str, Any]]:
704
+ """Get SSH connection information for a reservation"""
705
+ try:
706
+ # Query by user first (efficient), then filter by reservation_id prefix
707
+ response = self.reservations_table.query(
708
+ IndexName="UserIndex",
709
+ KeyConditionExpression="user_id = :user_id",
710
+ ExpressionAttributeValues={":user_id": user_id},
711
+ )
712
+ all_reservations = response.get("Items", [])
713
+
714
+ # Handle pagination for UserIndex query
715
+ while "LastEvaluatedKey" in response:
716
+ response = self.reservations_table.query(
717
+ IndexName="UserIndex",
718
+ KeyConditionExpression="user_id = :user_id",
719
+ ExpressionAttributeValues={":user_id": user_id},
720
+ ExclusiveStartKey=response["LastEvaluatedKey"]
721
+ )
722
+ all_reservations.extend(response.get("Items", []))
723
+
724
+ # Filter by reservation_id prefix in memory
725
+ matching_reservations = [
726
+ res for res in all_reservations
727
+ if res.get("reservation_id", "").startswith(reservation_id)
728
+ ]
729
+
730
+ if len(matching_reservations) == 0:
731
+ return None
732
+ elif len(matching_reservations) > 1:
733
+ return None # Ambiguous - need longer prefix
734
+
735
+ reservation = matching_reservations[0]
736
+
737
+ return {
738
+ "ssh_command": reservation.get("ssh_command", "ssh user@pending"),
739
+ "pod_name": reservation.get("pod_name", "pending"),
740
+ "namespace": reservation.get("namespace", "default"),
741
+ "gpu_count": reservation["gpu_count"],
742
+ "status": reservation["status"],
743
+ "launched_at": reservation.get("launched_at"),
744
+ "expires_at": reservation.get("expires_at"),
745
+ "created_at": reservation.get("created_at"),
746
+ "reservation_id": reservation["reservation_id"],
747
+ "name": reservation.get("name"),
748
+ "instance_type": reservation.get("instance_type", "unknown"),
749
+ "gpu_type": reservation.get("gpu_type", "unknown"),
750
+ "failure_reason": reservation.get("failure_reason", ""),
751
+ "current_detailed_status": reservation.get("current_detailed_status", ""),
752
+ "status_history": reservation.get("status_history", []),
753
+ "pod_logs": reservation.get("pod_logs", ""),
754
+ "jupyter_url": reservation.get("jupyter_url", ""),
755
+ "jupyter_port": reservation.get("jupyter_port", ""),
756
+ "jupyter_token": reservation.get("jupyter_token", ""),
757
+ "jupyter_enabled": reservation.get("jupyter_enabled", False),
758
+ "jupyter_error": reservation.get("jupyter_error", ""),
759
+ "ebs_volume_id": reservation.get("ebs_volume_id", ""),
760
+ "secondary_users": reservation.get("secondary_users", []),
761
+ "warning": reservation.get("warning", ""),
762
+ }
763
+
764
+ except Exception as e:
765
+ console.print(
766
+ f"[red]❌ Error getting connection info: {str(e)}[/red]")
767
+ return None
768
+
769
+ def enable_jupyter(self, reservation_id: str, user_id: str) -> bool:
770
+ """Enable Jupyter Lab for an active reservation"""
771
+ try:
772
+ # Send message to Lambda to start Jupyter service in pod
773
+ # Lambda will handle both the pod changes and DynamoDB updates
774
+ message = {
775
+ "action": "enable_jupyter",
776
+ "reservation_id": reservation_id,
777
+ "user_id": user_id,
778
+ "version": get_version(),
779
+ }
780
+
781
+ queue_url = self.config.get_queue_url()
782
+ self.config.sqs_client.send_message(
783
+ QueueUrl=queue_url, MessageBody=json.dumps(message)
784
+ )
785
+
786
+ console.print(
787
+ f"[yellow]⏳ Jupyter enable request submitted for reservation {reservation_id[:8]}...[/yellow]"
788
+ )
789
+
790
+ # Poll for 3 minutes to show the outcome
791
+ return self._poll_jupyter_action_result(
792
+ reservation_id, user_id, "enable", timeout_minutes=3
793
+ )
794
+
795
+ except Exception as e:
796
+ console.print(
797
+ f"[red]❌ Error submitting Jupyter enable request: {str(e)}[/red]"
798
+ )
799
+ return False
800
+
801
+ def disable_jupyter(self, reservation_id: str, user_id: str) -> bool:
802
+ """Disable Jupyter Lab for an active reservation"""
803
+ try:
804
+ # Send message to Lambda to stop Jupyter service in pod
805
+ # Lambda will handle both the pod changes and DynamoDB updates
806
+ message = {
807
+ "action": "disable_jupyter",
808
+ "reservation_id": reservation_id,
809
+ "user_id": user_id,
810
+ "version": get_version(),
811
+ }
812
+
813
+ queue_url = self.config.get_queue_url()
814
+ self.config.sqs_client.send_message(
815
+ QueueUrl=queue_url, MessageBody=json.dumps(message)
816
+ )
817
+
818
+ console.print(
819
+ f"[yellow]⏳ Jupyter disable request submitted for reservation {reservation_id[:8]}...[/yellow]"
820
+ )
821
+
822
+ # Poll for 3 minutes to show the outcome
823
+ return self._poll_jupyter_action_result(
824
+ reservation_id, user_id, "disable", timeout_minutes=3
825
+ )
826
+
827
+ except Exception as e:
828
+ console.print(
829
+ f"[red]❌ Error submitting Jupyter disable request: {str(e)}[/red]"
830
+ )
831
+ return False
832
+
833
+ def add_user(self, reservation_id: str, user_id: str, github_username: str) -> bool:
834
+ """Add a secondary user to an active reservation"""
835
+ try:
836
+ # Validate GitHub username format (basic validation)
837
+ if (
838
+ not github_username
839
+ or not github_username.replace("-", "").replace("_", "").isalnum()
840
+ ):
841
+ console.print(
842
+ f"[red]❌ Invalid GitHub username: {github_username}[/red]"
843
+ )
844
+ return False
845
+
846
+ # Send message to Lambda to add user SSH keys to pod
847
+ # Lambda will handle fetching GitHub keys and updating the pod
848
+ message = {
849
+ "action": "add_user",
850
+ "reservation_id": reservation_id,
851
+ "user_id": user_id,
852
+ "github_username": github_username,
853
+ "version": get_version(),
854
+ }
855
+
856
+ queue_url = self.config.get_queue_url()
857
+ self.config.sqs_client.send_message(
858
+ QueueUrl=queue_url, MessageBody=json.dumps(message)
859
+ )
860
+
861
+ console.print(
862
+ f"[yellow]⏳ Adding user {github_username} to reservation {reservation_id[:8]}...[/yellow]"
863
+ )
864
+
865
+ # Poll for 3 minutes to show the outcome
866
+ return self._poll_add_user_result(
867
+ reservation_id, user_id, github_username, timeout_minutes=3
868
+ )
869
+
870
+ except Exception as e:
871
+ console.print(
872
+ f"[red]❌ Error adding user {github_username}: {str(e)}[/red]"
873
+ )
874
+ return False
875
+
876
+ def extend_reservation(self, reservation_id: str, user_id: str, extension_hours: float) -> bool:
877
+ """Extend an active reservation by the specified number of hours"""
878
+ try:
879
+ # Capture current expiration BEFORE sending extension request to avoid race condition
880
+ response = self.reservations_table.query(
881
+ IndexName="UserIndex",
882
+ KeyConditionExpression="user_id = :user_id",
883
+ ExpressionAttributeValues={":user_id": user_id},
884
+ )
885
+ all_reservations = response.get("Items", [])
886
+
887
+ # Handle pagination for UserIndex query
888
+ while "LastEvaluatedKey" in response:
889
+ response = self.reservations_table.query(
890
+ IndexName="UserIndex",
891
+ KeyConditionExpression="user_id = :user_id",
892
+ ExpressionAttributeValues={":user_id": user_id},
893
+ ExclusiveStartKey=response["LastEvaluatedKey"]
894
+ )
895
+ all_reservations.extend(response.get("Items", []))
896
+
897
+ matching_reservations = [
898
+ res for res in all_reservations
899
+ if res.get("reservation_id", "").startswith(reservation_id)
900
+ ]
901
+
902
+ initial_expires_at = None
903
+ if matching_reservations:
904
+ initial_expires_at = matching_reservations[0].get("expires_at", "")
905
+
906
+ # Send message to Lambda to extend reservation
907
+ # Lambda will handle both the expiration timestamp update and any necessary pod updates
908
+ message = {
909
+ "action": "extend_reservation",
910
+ "reservation_id": reservation_id,
911
+ "extension_hours": extension_hours,
912
+ "version": get_version(),
913
+ }
914
+
915
+ queue_url = self.config.get_queue_url()
916
+ self.config.sqs_client.send_message(
917
+ QueueUrl=queue_url, MessageBody=json.dumps(message)
918
+ )
919
+
920
+ console.print(
921
+ f"[yellow]⏳ Extension request submitted for reservation {reservation_id[:8]}...[/yellow]"
922
+ )
923
+
924
+ # Poll for 3 minutes to show the outcome
925
+ return self._poll_extend_action_result(
926
+ reservation_id, user_id, extension_hours, timeout_minutes=3, initial_expires_at=initial_expires_at
927
+ )
928
+
929
+ except Exception as e:
930
+ console.print(
931
+ f"[red]❌ Error submitting extension request: {str(e)}[/red]")
932
+ return False
933
+
934
+ def get_gpu_availability_by_type(self) -> Optional[Dict[str, Dict[str, Any]]]:
935
+ """Get GPU availability information by GPU type from real-time availability table"""
936
+ try:
937
+ # Try to get real-time availability from the availability table
938
+ availability_table_name = self.config.availability_table
939
+ availability_table = self.config.dynamodb.Table(
940
+ availability_table_name)
941
+
942
+ # Scan the whole availability table with pagination
943
+ response = availability_table.scan()
944
+ availability_info = {}
945
+ all_items = response.get("Items", [])
946
+
947
+ # Handle pagination for availability table
948
+ while "LastEvaluatedKey" in response:
949
+ response = availability_table.scan(
950
+ ExclusiveStartKey=response["LastEvaluatedKey"]
951
+ )
952
+ all_items.extend(response.get("Items", []))
953
+
954
+ for item in all_items:
955
+ gpu_type = item["gpu_type"]
956
+ queue_length = self._get_queue_length_for_gpu_type(gpu_type)
957
+ estimated_wait = queue_length * 15 if queue_length > 0 else 0
958
+
959
+ availability_info[gpu_type] = {
960
+ "available": int(item.get("available_gpus", 0)),
961
+ "total": int(item.get("total_gpus", 0)),
962
+ "max_reservable": int(item.get("max_reservable", 0)),
963
+ "full_nodes_available": int(item.get("full_nodes_available", 0)),
964
+ "gpus_per_instance": int(item.get("gpus_per_instance", 0)),
965
+ "queue_length": queue_length,
966
+ "estimated_wait_minutes": estimated_wait,
967
+ "running_instances": int(item.get("running_instances", 0)),
968
+ "desired_capacity": int(item.get("desired_capacity", 0)),
969
+ "last_updated": item.get("last_updated_timestamp", 0),
970
+ }
971
+
972
+ return availability_info
973
+
974
+ except Exception as e:
975
+ console.print(
976
+ f"[red]❌ Error getting GPU availability: {str(e)}[/red]")
977
+ return None
978
+
979
+ def _get_static_gpu_config(
980
+ self, gpu_type: str, queue_length: int, estimated_wait: int
981
+ ) -> Dict[str, Any]:
982
+ """Get static GPU configuration as fallback when real-time data unavailable"""
983
+ static_configs = {
984
+ # 2x p4d.24xlarge = 16 A100s
985
+ "a100": {"available": 0, "total": 16},
986
+ # 1x p6-b200.48xlarge = 8 B200s
987
+ "b200": {"available": 0, "total": 8},
988
+ # 2x p5e.48xlarge = 16 H200s
989
+ "h200": {"available": 0, "total": 16},
990
+ "h100": {"available": 0, "total": 16}, # 2x p5.48xlarge = 16 H100s
991
+ "t4": {"available": 0, "total": 8}, # 2x g4dn.12xlarge = 8 T4s
992
+ }
993
+
994
+ config = static_configs.get(gpu_type, {"available": 0, "total": 0})
995
+ return {
996
+ "available": config["available"],
997
+ "total": config["total"],
998
+ "queue_length": queue_length,
999
+ "estimated_wait_minutes": estimated_wait,
1000
+ "running_instances": 0,
1001
+ "desired_capacity": 0,
1002
+ "last_updated": 0,
1003
+ }
1004
+
1005
    def _get_queue_length_for_gpu_type(self, gpu_type: str) -> int:
        """Get the number of queued reservations for a specific GPU type.

        Counts reservations in both "queued" and "pending" status using the
        StatusGpuTypeIndex GSI, falling back to a full table scan per status
        when the index query fails. Returns 0 on any unexpected error.
        """
        try:
            total_count = 0

            # Count queued reservations for this GPU type
            for status in ["queued", "pending"]:
                try:
                    response = self.reservations_table.query(
                        IndexName="StatusGpuTypeIndex",
                        KeyConditionExpression="#status = :status AND gpu_type = :gpu_type",
                        # "status" is a DynamoDB reserved word, hence the alias.
                        ExpressionAttributeNames={"#status": "status"},
                        ExpressionAttributeValues={
                            ":status": status,
                            ":gpu_type": gpu_type,
                        },
                    )
                    total_count += len(response.get("Items", []))

                    # Handle pagination for StatusGpuTypeIndex query
                    while "LastEvaluatedKey" in response:
                        response = self.reservations_table.query(
                            IndexName="StatusGpuTypeIndex",
                            KeyConditionExpression="#status = :status AND gpu_type = :gpu_type",
                            ExpressionAttributeNames={"#status": "status"},
                            ExpressionAttributeValues={
                                ":status": status,
                                ":gpu_type": gpu_type,
                            },
                            ExclusiveStartKey=response["LastEvaluatedKey"]
                        )
                        total_count += len(response.get("Items", []))
                except Exception as query_error:
                    # Fallback to scanning if the composite index doesn't exist yet
                    # NOTE(review): this fallback triggers on ANY query error
                    # (throttling included), and contains() does substring
                    # matching, so the counts can differ from the indexed path —
                    # confirm stored status/gpu_type values are exact literals.
                    console.print(
                        f"[dim]Fallback: scanning for {status} {gpu_type} reservations[/dim]"
                    )
                    response = self.reservations_table.scan(
                        FilterExpression="contains(#status, :status) AND contains(gpu_type, :gpu_type)",
                        ExpressionAttributeNames={"#status": "status"},
                        ExpressionAttributeValues={
                            ":status": status,
                            ":gpu_type": gpu_type,
                        },
                    )
                    total_count += len(response.get("Items", []))

                    # Handle pagination for fallback scan
                    while "LastEvaluatedKey" in response:
                        response = self.reservations_table.scan(
                            FilterExpression="contains(#status, :status) AND contains(gpu_type, :gpu_type)",
                            ExpressionAttributeNames={"#status": "status"},
                            ExpressionAttributeValues={
                                ":status": status,
                                ":gpu_type": gpu_type,
                            },
                            ExclusiveStartKey=response["LastEvaluatedKey"]
                        )
                        total_count += len(response.get("Items", []))

            return total_count

        except Exception as e:
            console.print(
                f"[red]❌ Error getting queue length for {gpu_type}: {str(e)}[/red]"
            )
            return 0
1072
+
1073
    def _poll_jupyter_action_result(
        self, reservation_id: str, user_id: str, action: str, timeout_minutes: int = 3
    ) -> bool:
        """Poll reservation table for Jupyter action result.

        Watches the reservation's jupyter_enabled/jupyter_url fields until the
        requested action ("enable" or "disable") is reflected, showing a live
        spinner. Returns True on confirmed completion, False on error or
        after timeout_minutes without a state change.
        """
        try:
            start_time = time.time()
            timeout_seconds = timeout_minutes * 60

            with Live(console=console, refresh_per_second=2) as live:
                spinner = Spinner(
                    "dots", text=f"🔄 Processing Jupyter {action} request..."
                )
                live.update(spinner)

                # Populated on the first successful read; used to detect change.
                initial_state = None

                while time.time() - start_time < timeout_seconds:
                    try:
                        # Get current reservation state - query by user first, then filter by prefix
                        response = self.reservations_table.query(
                            IndexName="UserIndex",
                            KeyConditionExpression="user_id = :user_id",
                            ExpressionAttributeValues={":user_id": user_id},
                        )
                        all_reservations = response.get("Items", [])

                        # Handle pagination for UserIndex query
                        while "LastEvaluatedKey" in response:
                            response = self.reservations_table.query(
                                IndexName="UserIndex",
                                KeyConditionExpression="user_id = :user_id",
                                ExpressionAttributeValues={
                                    ":user_id": user_id},
                                ExclusiveStartKey=response["LastEvaluatedKey"]
                            )
                            all_reservations.extend(response.get("Items", []))

                        # Filter by reservation_id prefix in memory
                        items = [
                            res for res in all_reservations
                            if res.get("reservation_id", "").startswith(reservation_id)
                        ]
                        if len(items) == 0:
                            # Reservation row not visible yet — retry shortly.
                            spinner.text = f"🔄 Waiting for reservation data..."
                            live.update(spinner)
                            time.sleep(2)
                            continue
                        elif len(items) > 1:
                            # Ambiguous prefix; proceed with the first match.
                            spinner.text = f"🔄 Multiple reservations found for {reservation_id}, using first match..."
                            live.update(spinner)

                        reservation = items[0]

                        # Capture initial state on first iteration
                        if initial_state is None:
                            initial_state = {
                                "jupyter_enabled": reservation.get(
                                    "jupyter_enabled", False
                                ),
                                "jupyter_url": reservation.get("jupyter_url", ""),
                                "jupyter_port": reservation.get("jupyter_port", 0),
                            }

                        current_jupyter_enabled = reservation.get(
                            "jupyter_enabled", False
                        )
                        jupyter_url = reservation.get("jupyter_url", "")
                        jupyter_port = reservation.get("jupyter_port", 0)

                        # Check if the action has completed
                        if action == "enable":
                            # Enable is only "done" once a URL is published,
                            # not merely when the flag flips.
                            if current_jupyter_enabled and jupyter_url:
                                live.stop()
                                console.print(
                                    f"[green]✅ Jupyter Lab enabled successfully![/green]"
                                )
                                console.print(
                                    f"[cyan]🔗 Jupyter URL:[/cyan] {jupyter_url}"
                                )
                                console.print(
                                    f"[cyan]🔌 Port:[/cyan] {jupyter_port}")
                                return True
                            elif (
                                current_jupyter_enabled
                                != initial_state["jupyter_enabled"]
                            ):
                                spinner.text = f"🔄 Jupyter enabled, waiting for URL..."
                        else:  # disable
                            # Disable is complete when both flag and URL clear.
                            if not current_jupyter_enabled and not jupyter_url:
                                live.stop()
                                console.print(
                                    f"[green]✅ Jupyter Lab disabled successfully![/green]"
                                )
                                return True
                            elif (
                                current_jupyter_enabled
                                != initial_state["jupyter_enabled"]
                            ):
                                spinner.text = f"🔄 Stopping Jupyter service..."

                        live.update(spinner)
                        time.sleep(3)

                    except Exception as poll_error:
                        # Any polling failure aborts immediately (no retry).
                        console.print(
                            f"[red]❌ Error polling Jupyter status: {poll_error}[/red]"
                        )
                        return False

                # Timeout reached
                live.stop()
                console.print(
                    f"[yellow]⏰ Timeout after {timeout_minutes} minutes[/yellow]"
                )
                console.print(
                    f"[yellow]💡 Use 'gpu-dev show {reservation_id[:8]}' to check Jupyter status[/yellow]"
                )
                return False

        except Exception as e:
            console.print(
                f"[red]❌ Error during Jupyter {action} polling: {str(e)}[/red]"
            )
            return False
1197
+
1198
    def _poll_add_user_result(
        self, reservation_id: str, user_id: str, github_username: str, timeout_minutes: int = 3
    ) -> bool:
        """Poll reservation table for add user action result.

        Watches the reservation's secondary_users list until github_username
        appears, showing a live spinner. Returns True once the user is seen,
        False on error or after timeout_minutes.
        """
        try:
            start_time = time.time()
            timeout_seconds = timeout_minutes * 60

            with Live(console=console, refresh_per_second=2) as live:
                spinner = Spinner(
                    "dots", text=f"🔄 Adding user {github_username}...")
                live.update(spinner)

                # Snapshot of secondary_users from the first successful read.
                initial_secondary_users = None

                while time.time() - start_time < timeout_seconds:
                    try:
                        # Get current reservation state - query by user first, then filter by prefix
                        response = self.reservations_table.query(
                            IndexName="UserIndex",
                            KeyConditionExpression="user_id = :user_id",
                            ExpressionAttributeValues={":user_id": user_id},
                        )
                        all_reservations = response.get("Items", [])

                        # Handle pagination for UserIndex query
                        while "LastEvaluatedKey" in response:
                            response = self.reservations_table.query(
                                IndexName="UserIndex",
                                KeyConditionExpression="user_id = :user_id",
                                ExpressionAttributeValues={
                                    ":user_id": user_id},
                                ExclusiveStartKey=response["LastEvaluatedKey"]
                            )
                            all_reservations.extend(response.get("Items", []))

                        # Filter by reservation_id prefix in memory
                        items = [
                            res for res in all_reservations
                            if res.get("reservation_id", "").startswith(reservation_id)
                        ]
                        if len(items) == 0:
                            # Reservation row not visible yet — retry shortly.
                            spinner.text = f"🔄 Waiting for reservation data..."
                            live.update(spinner)
                            time.sleep(2)
                            continue
                        elif len(items) > 1:
                            # Ambiguous prefix; proceed with the first match.
                            spinner.text = f"🔄 Multiple reservations found for {reservation_id}, using first match..."
                            live.update(spinner)

                        reservation = items[0]

                        # Capture initial state on first iteration
                        if initial_secondary_users is None:
                            initial_secondary_users = reservation.get(
                                "secondary_users", []
                            )

                        current_secondary_users = reservation.get(
                            "secondary_users", [])

                        # Check if the user has been added
                        if github_username in current_secondary_users:
                            live.stop()
                            console.print(
                                f"[green]✅ User {github_username} added successfully![/green]"
                            )
                            console.print(
                                f"[cyan]👥 Secondary users:[/cyan] {', '.join(current_secondary_users)}"
                            )
                            return True
                        elif len(current_secondary_users) != len(
                            initial_secondary_users
                        ):
                            # List changed but target user not present yet.
                            spinner.text = (
                                f"🔄 User list updated, verifying {github_username}..."
                            )

                        live.update(spinner)
                        time.sleep(3)

                    except Exception as poll_error:
                        # Any polling failure aborts immediately (no retry).
                        console.print(
                            f"[red]❌ Error polling add user status: {poll_error}[/red]"
                        )
                        return False

                # Timeout reached
                live.stop()
                console.print(
                    f"[yellow]⏰ Timeout after {timeout_minutes} minutes[/yellow]"
                )
                console.print(
                    f"[yellow]💡 Use 'gpu-dev show {reservation_id[:8]}' to check user status[/yellow]"
                )
                return False

        except Exception as e:
            console.print(
                f"[red]❌ Error during add user polling: {str(e)}[/red]")
            return False
1299
+
1300
    def _poll_extend_action_result(
        self, reservation_id: str, user_id: str, extension_hours: float, timeout_minutes: int = 3, initial_expires_at: Optional[str] = None
    ) -> bool:
        """Poll reservation table for extend action result.

        Success is detected when expires_at differs from the baseline
        (initial_expires_at if supplied, otherwise the first value read
        here). Fails fast if the record carries an extension_error; returns
        False on timeout.
        """
        try:
            start_time = time.time()
            timeout_seconds = timeout_minutes * 60

            with Live(console=console, refresh_per_second=2) as live:
                spinner = Spinner(
                    "dots",
                    text=f"🔄 Extending reservation by {extension_hours} hours...",
                )
                live.update(spinner)

                # Use pre-captured initial_expires_at if provided (to avoid race condition)
                initial_expiration = initial_expires_at

                while time.time() - start_time < timeout_seconds:
                    try:
                        # Get current reservation state - query by user first, then filter by prefix
                        response = self.reservations_table.query(
                            IndexName="UserIndex",
                            KeyConditionExpression="user_id = :user_id",
                            ExpressionAttributeValues={":user_id": user_id},
                        )
                        all_reservations = response.get("Items", [])

                        # Handle pagination for UserIndex query
                        while "LastEvaluatedKey" in response:
                            response = self.reservations_table.query(
                                IndexName="UserIndex",
                                KeyConditionExpression="user_id = :user_id",
                                ExpressionAttributeValues={
                                    ":user_id": user_id},
                                ExclusiveStartKey=response["LastEvaluatedKey"]
                            )
                            all_reservations.extend(response.get("Items", []))

                        # Filter by reservation_id prefix in memory
                        items = [
                            res for res in all_reservations
                            if res.get("reservation_id", "").startswith(reservation_id)
                        ]
                        if len(items) == 0:
                            # Reservation row not visible yet — retry shortly.
                            spinner.text = f"🔄 Waiting for reservation data..."
                            live.update(spinner)
                            time.sleep(2)
                            continue
                        elif len(items) > 1:
                            # Ambiguous prefix; proceed with the first match.
                            spinner.text = f"🔄 Multiple reservations found for {reservation_id}, using first match..."
                            live.update(spinner)

                        reservation = items[0]

                        # Capture initial expiration on first iteration
                        if initial_expiration is None:
                            initial_expiration = reservation.get(
                                "expires_at", "")

                        current_expiration = reservation.get("expires_at", "")

                        # Check for extension failure indicators
                        # NOTE(review): last_updated is read but never used here.
                        last_updated = reservation.get("last_updated", 0)
                        extension_error = reservation.get(
                            "extension_error", "")

                        # If there's an extension error, fail immediately
                        if extension_error:
                            live.stop()
                            console.print(
                                f"[red]❌ Extension failed: {extension_error}[/red]"
                            )
                            return False

                        # Check if the expiration has been updated (different from initial)
                        if (
                            current_expiration != initial_expiration
                            and current_expiration
                        ):
                            live.stop()
                            from datetime import datetime, timezone

                            try:
                                # Treat as naive datetime and manually add UTC timezone (matches list command)
                                naive_dt = datetime.fromisoformat(current_expiration)
                                exp_dt_utc = naive_dt.replace(tzinfo=timezone.utc)
                                # Convert to local timezone
                                local_exp = exp_dt_utc.astimezone()
                                # Format with same style as list command: month-day hour:minute
                                formatted_expiration = local_exp.strftime("%m-%d %H:%M")
                                console.print(
                                    f"[green]✅ Extended reservation {reservation_id} by {extension_hours} hours -- your new expiration is {formatted_expiration}[/green]"
                                )
                                return True
                            except Exception:
                                # Fallback to raw display if parsing fails
                                console.print(
                                    f"[green]✅ Extended reservation {reservation_id} by {extension_hours} hours -- your new expiration is {current_expiration}[/green]"
                                )
                                return True

                        spinner.text = f"🔄 Processing extension request..."
                        live.update(spinner)
                        time.sleep(2)

                    except Exception as poll_error:
                        # Unlike the other pollers, transient errors here are
                        # retried until the overall timeout expires.
                        spinner.text = f"🔄 Checking extension status (retry)..."
                        live.update(spinner)
                        time.sleep(2)

            live.stop()
            console.print(
                f"[red]❌ Extension request timed out after {timeout_minutes} minutes[/red]"
            )
            console.print(
                f"[yellow]The extension may still be processing. Check status with: gpu-dev list[/yellow]"
            )
            return False  # Return failure on timeout

        except Exception as e:
            console.print(
                f"[red]❌ Error polling extension result: {str(e)}[/red]")
            return False
1424
+
1425
+ def get_cluster_status(self) -> Optional[Dict[str, Any]]:
1426
+ """Get overall GPU cluster status from availability table"""
1427
+ try:
1428
+ # Get reservations with pagination
1429
+ reservations_response = self.reservations_table.scan()
1430
+ reservations = reservations_response.get("Items", [])
1431
+
1432
+ # Handle pagination for admin stats scan
1433
+ while "LastEvaluatedKey" in reservations_response:
1434
+ reservations_response = self.reservations_table.scan(
1435
+ ExclusiveStartKey=reservations_response["LastEvaluatedKey"]
1436
+ )
1437
+ reservations.extend(reservations_response.get("Items", []))
1438
+
1439
+ # Get total GPUs from availability table
1440
+ availability_info = self.get_gpu_availability_by_type()
1441
+ total_gpus = 0
1442
+ available_gpus = 0
1443
+
1444
+ if availability_info:
1445
+ for gpu_type, info in availability_info.items():
1446
+ total_gpus += info.get("total", 0)
1447
+ available_gpus += info.get("available", 0)
1448
+
1449
+ # Calculate stats
1450
+ active_reservations = [
1451
+ r for r in reservations if r.get("status") == "active"
1452
+ ]
1453
+ reserved_gpus = sum(int(r.get("gpu_count", 0))
1454
+ for r in active_reservations)
1455
+
1456
+ # Get queue length
1457
+ try:
1458
+ queue_url = self.config.get_queue_url()
1459
+ queue_attrs = self.config.sqs_client.get_queue_attributes(
1460
+ QueueUrl=queue_url, AttributeNames=[
1461
+ "ApproximateNumberOfMessages"]
1462
+ )
1463
+ queue_length = int(
1464
+ queue_attrs["Attributes"]["ApproximateNumberOfMessages"]
1465
+ )
1466
+ except:
1467
+ queue_length = len(
1468
+ [r for r in reservations if r.get("status") == "pending"]
1469
+ )
1470
+
1471
+ return {
1472
+ "total_gpus": total_gpus,
1473
+ "available_gpus": available_gpus,
1474
+ "reserved_gpus": reserved_gpus,
1475
+ "active_reservations": len(active_reservations),
1476
+ "queue_length": queue_length,
1477
+ }
1478
+
1479
+ except Exception as e:
1480
+ console.print(
1481
+ f"[red]❌ Error getting cluster status: {str(e)}[/red]")
1482
+ return None
1483
+
1484
+ def _wait_for_reservations_completion(
1485
+ self, reservation_ids: List[str], timeout_minutes: Optional[int] = 10, is_multinode: bool = False, verbose: bool = False
1486
+ ) -> Optional[List[Dict[str, Any]]]:
1487
+ """Shared polling logic for both single and multinode reservations (always creates SSH config)"""
1488
+
1489
+ status_messages = {
1490
+ "pending": "⏳ Reservation request submitted, waiting for processing...",
1491
+ "queued": "📋 In queue - waiting for GPU resources...",
1492
+ "preparing": "🚀 GPUs found! Preparing your development environment...",
1493
+ "creating_server": "🐳 Building custom Docker image...",
1494
+ "active": "✅ Reservation complete!",
1495
+ "failed": "❌ Reservation failed",
1496
+ "cancelled": "🛑 Reservation cancelled",
1497
+ }
1498
+
1499
+ start_time = time.time()
1500
+ timeout_seconds = timeout_minutes * 60 if timeout_minutes is not None else None
1501
+ last_status = None
1502
+ last_message = None
1503
+ cancelled = False
1504
+ close_tool = False
1505
+ show_queue_help = True
1506
+ queue_state = {"initial_estimated_wait": None,
1507
+ "queue_start_time": None}
1508
+ total_nodes = len(reservation_ids)
1509
+
1510
+ # Track previous node statuses to only show changes
1511
+ previous_node_statuses = {}
1512
+
1513
+ def handle_interrupt(signum, frame):
1514
+ """Handle Ctrl+C to cancel reservation(s)"""
1515
+ nonlocal cancelled
1516
+ cancelled = True
1517
+
1518
+ def handle_clean_exit(signum, frame):
1519
+ """Handle clean exit signal (SIGTERM)"""
1520
+ nonlocal close_tool
1521
+ close_tool = True
1522
+ reservation_text = "reservations" if is_multinode else "reservation"
1523
+ console.print(
1524
+ f"\n[cyan]🔄 Clean exit requested - keeping {reservation_text} active...[/cyan]"
1525
+ )
1526
+
1527
+ def check_keyboard_input():
1528
+ """Check if clean exit was requested via signal"""
1529
+ return close_tool
1530
+
1531
+ # Set up signal handlers
1532
+ signal.signal(signal.SIGTERM, handle_clean_exit)
1533
+ try:
1534
+ signal.signal(signal.SIGQUIT, handle_clean_exit)
1535
+ action_text = "cancel all reservations" if is_multinode else "cancel reservation"
1536
+ keep_text = "keep reservations" if is_multinode else "keep reservation"
1537
+ console.print(
1538
+ f"[dim]💡 Press [cyan]Ctrl+C[/cyan] to {action_text} • Press [cyan]Ctrl+backslash[/cyan] to exit but {keep_text}[/dim]"
1539
+ )
1540
+ except (AttributeError, OSError):
1541
+ action_text = "cancel all reservations" if is_multinode else "cancel reservation"
1542
+ keep_text = "keep reservations" if is_multinode else "keep reservation"
1543
+ console.print(
1544
+ f"[dim]💡 Press [cyan]Ctrl+C[/cyan] to {action_text} • Send [cyan]SIGTERM[/cyan] to exit but {keep_text}[/dim]"
1545
+ )
1546
+ console.print(
1547
+ f"[dim] (From another terminal: [cyan]kill {os.getpid()}[/cyan])[/dim]"
1548
+ )
1549
+
1550
+ # Set up signal handler for Ctrl+C
1551
+ old_handler = signal.signal(signal.SIGINT, handle_interrupt)
1552
+
1553
+ try:
1554
+ with Live(console=console, refresh_per_second=4) as live:
1555
+ initial_text = f"📡 Starting multinode reservation..." if is_multinode else "🔄 Sending reservation request..."
1556
+ spinner = Spinner("dots", text=initial_text)
1557
+ live.update(spinner)
1558
+
1559
+ while (
1560
+ (timeout_seconds is None or time.time() -
1561
+ start_time < timeout_seconds)
1562
+ and not cancelled
1563
+ and not close_tool
1564
+ ):
1565
+ try:
1566
+ # Check for keyboard input (clean exit)
1567
+ if check_keyboard_input():
1568
+ break
1569
+
1570
+ # Get current status of all reservations
1571
+ all_reservations = []
1572
+ node_details = []
1573
+
1574
+ for i, res_id in enumerate(reservation_ids):
1575
+ try:
1576
+ response = self.reservations_table.get_item(
1577
+ Key={"reservation_id": res_id})
1578
+ if "Item" in response:
1579
+ reservation = response["Item"]
1580
+ all_reservations.append(reservation)
1581
+
1582
+ status = reservation.get(
1583
+ "status", "unknown")
1584
+ failure_reason = reservation.get(
1585
+ "failure_reason", "")
1586
+ current_detailed_status = reservation.get(
1587
+ "current_detailed_status", "")
1588
+ queue_position = reservation.get(
1589
+ "queue_position", "?")
1590
+ estimated_wait = reservation.get(
1591
+ "estimated_wait_minutes", "?")
1592
+ gpu_count = reservation.get("gpu_count", 1)
1593
+
1594
+ # Debug what we're reading from DynamoDB - only show if status changed
1595
+ if verbose:
1596
+ node_key = f"node_{i+1}_{res_id[:8]}"
1597
+ current_node_status = f"status={status}, detailed={current_detailed_status}"
1598
+ if previous_node_statuses.get(node_key) != current_node_status:
1599
+ print(
1600
+ f"[DEBUG] Node {i+1} ({res_id[:8]}): {current_node_status}")
1601
+ previous_node_statuses[node_key] = current_node_status
1602
+
1603
+ node_details.append({
1604
+ "index": i,
1605
+ "status": status,
1606
+ "failure_reason": failure_reason,
1607
+ "current_detailed_status": current_detailed_status,
1608
+ "queue_position": queue_position,
1609
+ "estimated_wait": estimated_wait,
1610
+ "gpu_count": gpu_count,
1611
+ "reservation": reservation
1612
+ })
1613
+ else:
1614
+ # No reservation found yet, keep waiting
1615
+ if not is_multinode:
1616
+ spinner.text = "📡 Waiting for reservation status update..."
1617
+ live.update(spinner)
1618
+ time.sleep(2)
1619
+ continue
1620
+ else:
1621
+ node_details.append({
1622
+ "index": i, "status": "unknown", "failure_reason": "",
1623
+ "current_detailed_status": "", "queue_position": "?",
1624
+ "estimated_wait": "?", "gpu_count": 0, "reservation": None
1625
+ })
1626
+ except Exception as e:
1627
+ if verbose:
1628
+ print(
1629
+ f"[DEBUG] Exception querying {res_id[:8]}: {e}")
1630
+ node_details.append({
1631
+ "index": i, "status": "error", "failure_reason": "Connection error",
1632
+ "current_detailed_status": "", "queue_position": "?",
1633
+ "estimated_wait": "?", "gpu_count": 0, "reservation": None
1634
+ })
1635
+
1636
+ # Calculate aggregate status
1637
+ statuses = [node["status"] for node in node_details]
1638
+ active_count = statuses.count("active")
1639
+ failed_count = statuses.count("failed")
1640
+ cancelled_count = statuses.count("cancelled")
1641
+ preparing_count = statuses.count("preparing")
1642
+ queued_count = statuses.count("queued")
1643
+
1644
+ # Debug multinode status calculation - only show when aggregate status changes
1645
+ # Only when there are mixed statuses
1646
+ if is_multinode and verbose and len(set(statuses)) > 1:
1647
+ print(
1648
+ f"[DEBUG] Mixed node statuses: active={active_count}, preparing={preparing_count}, queued={queued_count}, failed={failed_count}, total={total_nodes}")
1649
+
1650
+ # Determine aggregate status for multinode reservations
1651
+ # Only consider it failed if ALL nodes are explicitly failed/cancelled
1652
+ # or if there's a significant portion failed (more than half)
1653
+ if is_multinode:
1654
+ # For multinode, be more conservative about declaring failure
1655
+ if failed_count + cancelled_count >= total_nodes:
1656
+ # All nodes failed - definitely failed
1657
+ aggregate_status = "failed"
1658
+ elif active_count == total_nodes:
1659
+ # All nodes active - success
1660
+ aggregate_status = "active"
1661
+ elif active_count + preparing_count == total_nodes:
1662
+ # All nodes either active or preparing - still working
1663
+ aggregate_status = "preparing" if preparing_count > 0 else "active"
1664
+ elif queued_count > 0:
1665
+ # Any nodes queued - still in queue
1666
+ aggregate_status = "queued"
1667
+ elif failed_count + cancelled_count > total_nodes // 2:
1668
+ # More than half failed - likely a real failure
1669
+ aggregate_status = "failed"
1670
+ else:
1671
+ # Mixed state - keep preparing/pending
1672
+ aggregate_status = "preparing" if preparing_count > 0 else "pending"
1673
+
1674
+ # Debug aggregate status decision - only show when status changes
1675
+ if is_multinode and verbose and aggregate_status != last_status:
1676
+ print(
1677
+ f"[DEBUG] Calculated aggregate_status: {aggregate_status}")
1678
+
1679
+ else:
1680
+ # Single node - use original logic
1681
+ if failed_count > 0 or cancelled_count > 0:
1682
+ aggregate_status = "failed"
1683
+ elif active_count == total_nodes:
1684
+ aggregate_status = "active"
1685
+ elif preparing_count > 0:
1686
+ aggregate_status = "preparing"
1687
+ elif queued_count > 0:
1688
+ aggregate_status = "queued"
1689
+ else:
1690
+ # Check for creating_server status
1691
+ creating_server_count = statuses.count(
1692
+ "creating_server")
1693
+ if creating_server_count > 0:
1694
+ aggregate_status = "creating_server"
1695
+ else:
1696
+ aggregate_status = "pending"
1697
+
1698
+ # Build status message based on aggregate status and mode
1699
+ message = ""
1700
+
1701
+ if aggregate_status == "queued":
1702
+ # Use first queued node's info for display
1703
+ queued_nodes = [
1704
+ node for node in node_details if node["status"] == "queued"]
1705
+ if queued_nodes:
1706
+ first_queued = queued_nodes[0]
1707
+ queue_position = first_queued["queue_position"]
1708
+ estimated_wait = first_queued["estimated_wait"]
1709
+
1710
+ # Initialize countdown logic
1711
+ if (
1712
+ aggregate_status != last_status and estimated_wait != "?"
1713
+ ) or (
1714
+ estimated_wait != "?"
1715
+ and queue_state["initial_estimated_wait"] is None
1716
+ ):
1717
+ try:
1718
+ wait_minutes = (
1719
+ int(estimated_wait)
1720
+ if isinstance(estimated_wait, (int, str))
1721
+ and str(estimated_wait).isdigit()
1722
+ else None
1723
+ )
1724
+ if wait_minutes is not None:
1725
+ queue_state["initial_estimated_wait"] = wait_minutes
1726
+ queue_state["queue_start_time"] = time.time(
1727
+ )
1728
+ except (ValueError, TypeError):
1729
+ pass
1730
+
1731
+ # Calculate dynamic countdown
1732
+ if (
1733
+ queue_state["initial_estimated_wait"] is not None
1734
+ and queue_state["queue_start_time"] is not None
1735
+ ):
1736
+ elapsed_minutes = (
1737
+ time.time() -
1738
+ queue_state["queue_start_time"]
1739
+ ) / 60
1740
+ remaining_wait = max(
1741
+ 0,
1742
+ queue_state["initial_estimated_wait"] -
1743
+ elapsed_minutes,
1744
+ )
1745
+ wait_display = (
1746
+ f"{remaining_wait:.0f} min"
1747
+ if remaining_wait > 0
1748
+ else "Soon"
1749
+ )
1750
+ else:
1751
+ wait_display = (
1752
+ f"{estimated_wait} min"
1753
+ if estimated_wait != "?"
1754
+ else "Calculating..."
1755
+ )
1756
+
1757
+ if is_multinode:
1758
+ total_gpus = sum(
1759
+ node["gpu_count"] for node in node_details if node["reservation"])
1760
+ message = f"📋 Position #{queue_position} in queue • Estimated wait: {wait_display} • {total_gpus} GPUs across {total_nodes} nodes"
1761
+ else:
1762
+ gpu_count = first_queued["gpu_count"]
1763
+ message = f"📋 You are #{queue_position} in queue • Estimated wait: {wait_display} • {gpu_count} GPU(s) requested"
1764
+
1765
+ # Show help message once when entering queue
1766
+ if show_queue_help and aggregate_status != last_status:
1767
+ help_text = "\n[dim]💡 Press [cyan]Ctrl+C[/cyan] to cancel reservation • Use [cyan]gpu-dev list[/cyan] to check status[/dim]"
1768
+ console.print(help_text)
1769
+ show_queue_help = False
1770
+ else:
1771
+ message = f"📋 Nodes in queue... ({active_count}/{total_nodes} ready)" if is_multinode else "📋 In queue..."
1772
+
1773
+ elif aggregate_status == "preparing":
1774
+ if is_multinode:
1775
+ # Show detailed preparation info for multinode - show ALL nodes, not just preparing ones
1776
+ detailed_events = []
1777
+ for node in node_details:
1778
+ node_status = node["status"]
1779
+ if node_status == "active":
1780
+ detailed_events.append(f"✓ Ready")
1781
+ elif node.get("current_detailed_status"):
1782
+ detailed_events.append(
1783
+ node["current_detailed_status"])
1784
+ elif node.get("failure_reason") and node_status == "failed":
1785
+ detailed_events.append(
1786
+ node["failure_reason"])
1787
+ elif node_status == "preparing":
1788
+ detailed_events.append(
1789
+ "Preparing environment...")
1790
+ elif node_status in ["pending", "queued"]:
1791
+ detailed_events.append(
1792
+ f"{node_status.title()}...")
1793
+ else:
1794
+ detailed_events.append(node_status)
1795
+
1796
+ # Increased from 4 to 16
1797
+ if detailed_events and len(detailed_events) <= 16:
1798
+ # For multinode, create a custom multi-line display with individual spinners
1799
+ from rich.table import Table
1800
+ from rich.text import Text
1801
+ from rich.panel import Panel
1802
+ from rich.console import Group
1803
+
1804
+ # Create a list of renderable items
1805
+ node_lines = []
1806
+
1807
+ for i, event in enumerate(detailed_events):
1808
+ node_num = i + 1
1809
+ node_status = node_details[i]["status"]
1810
+
1811
+ # Create individual spinner or checkmark for each node
1812
+ if node_status == "active":
1813
+ # Ready node - show checkmark without spinner
1814
+ line = Text(
1815
+ f"✓ Node {node_num}: Ready", style="green")
1816
+ else:
1817
+ # Not ready - create a spinner for this specific node
1818
+ node_spinner = Spinner(
1819
+ "dots", text=f"Node {node_num}: {event}")
1820
+ line = node_spinner
1821
+
1822
+ node_lines.append(line)
1823
+
1824
+ # Group all lines together
1825
+ group = Group(*node_lines)
1826
+
1827
+ # Add summary line
1828
+ summary = Text(
1829
+ f"({active_count}/{total_nodes} ready)", style="cyan")
1830
+ full_display = Group(group, summary)
1831
+
1832
+ # Update live display with all spinners
1833
+ panel = Panel(
1834
+ full_display, title="🚀 Multinode Setup", expand=False)
1835
+ live.update(panel)
1836
+
1837
+ # Don't set message since we're using custom display
1838
+ message = None
1839
+ else:
1840
+ # Summarize if we have many nodes
1841
+ preparing_count = statuses.count(
1842
+ "preparing")
1843
+ message = f"🚀 Preparing {preparing_count} nodes... ({active_count}/{total_nodes} ready)"
1844
+ else:
1845
+ # Show detailed preparation info for single node
1846
+ preparing_nodes = [
1847
+ node for node in node_details if node["status"] == "preparing"]
1848
+ node = preparing_nodes[0] if preparing_nodes else node_details[0]
1849
+
1850
+ # Use unified status tracking - prefer current_detailed_status, fall back to failure_reason for actual failures
1851
+ current_detailed_status = node.get(
1852
+ "current_detailed_status", "")
1853
+ failure_reason = node.get("failure_reason", "") if node.get(
1854
+ "status") == "failed" else ""
1855
+
1856
+ if current_detailed_status:
1857
+ message = f"🚀 {current_detailed_status}"
1858
+ elif failure_reason:
1859
+ message = f"🚀 Failed: {failure_reason}"
1860
+ else:
1861
+ message = status_messages.get(
1862
+ aggregate_status, f"Status: {aggregate_status}")
1863
+
1864
+ elif aggregate_status == "failed":
1865
+ failed_nodes = [node for node in node_details if node["status"] in [
1866
+ "failed", "cancelled"]]
1867
+ if is_multinode:
1868
+ failure_details = []
1869
+ for node in failed_nodes:
1870
+ reason = node["failure_reason"] if node["failure_reason"] else node["status"]
1871
+ failure_details.append(
1872
+ f"Node {node['index']+1}: {reason}")
1873
+
1874
+ # Add debug info about all node statuses for troubleshooting
1875
+ debug_details = []
1876
+ for node in node_details:
1877
+ status = node["status"]
1878
+ debug_details.append(
1879
+ f"Node {node['index']+1}: {status}")
1880
+
1881
+ status_display = "\n".join(
1882
+ [f" {detail}" for detail in failure_details])
1883
+ debug_display = "\n".join(
1884
+ [f" {detail}" for detail in debug_details])
1885
+ message = f"❌ Multinode failed ({failed_count + cancelled_count}/{total_nodes})\n{status_display}"
1886
+ live.update(Spinner("dots", text=message))
1887
+ time.sleep(2)
1888
+ console.print(
1889
+ f"\n[red]❌ Multinode reservation failed ({failed_count + cancelled_count}/{total_nodes} nodes failed)[/red]")
1890
+ for detail in failure_details:
1891
+ console.print(f"[red] {detail}[/red]")
1892
+ console.print(
1893
+ f"\n[dim]Debug - All node statuses:[/dim]")
1894
+ for detail in debug_details:
1895
+ console.print(f"[dim] {detail}[/dim]")
1896
+ return None
1897
+ else:
1898
+ # Handle single node failure below in completion check
1899
+ pass
1900
+
1901
+ elif aggregate_status == "active":
1902
+ if is_multinode:
1903
+ # Check if all nodes are truly ready: "active" status AND valid SSH command
1904
+ nodes_ready = 0
1905
+ for node in node_details:
1906
+ if (node["status"] == "active" and
1907
+ node["reservation"] and
1908
+ node["reservation"].get("ssh_command", "ssh user@pending") not in ["ssh user@pending"] and
1909
+ not node["reservation"].get("ssh_command", "").endswith(".cluster.local")):
1910
+ nodes_ready += 1
1911
+
1912
+ if nodes_ready == total_nodes:
1913
+ # All nodes truly ready with SSH access
1914
+ live.update(
1915
+ Spinner("dots", text=f"✅ All {total_nodes} nodes ready!"))
1916
+ time.sleep(1)
1917
+ console.print(
1918
+ f"\n[green]✅ Multinode reservation complete! All {total_nodes} nodes are ready.[/green]")
1919
+
1920
+ # Create SSH config files and show connection info for each node
1921
+ for node in node_details:
1922
+ if node["reservation"]:
1923
+ res = node["reservation"]
1924
+ fqdn = res.get("fqdn")
1925
+ pod_name = res.get("pod_name")
1926
+ res_id = res.get("reservation_id")
1927
+ res_name = res.get("name")
1928
+
1929
+ # Create SSH config file for this node
1930
+ config_path = None
1931
+ use_include = False
1932
+ if fqdn and pod_name and res_id:
1933
+ try:
1934
+ config_path, use_include = create_ssh_config_for_reservation(
1935
+ fqdn, pod_name, res_id, res_name)
1936
+ except Exception as e:
1937
+ console.print(
1938
+ f"[yellow]⚠️ Could not create SSH config for node {node['index']+1}: {str(e)}[/yellow]")
1939
+
1940
+ # Show connection info
1941
+ if config_path and pod_name and use_include:
1942
+ console.print(
1943
+ f"[cyan]🖥️ Node {node['index']+1}:[/cyan] [green]ssh {pod_name}[/green]")
1944
+ else:
1945
+ ssh_command = res.get(
1946
+ "ssh_command", "ssh user@pending")
1947
+ ssh_with_forwarding = _add_agent_forwarding_to_ssh(
1948
+ ssh_command)
1949
+ console.print(
1950
+ f"[cyan]🖥️ Node {node['index']+1}:[/cyan] {ssh_with_forwarding}")
1951
+
1952
+ return all_reservations
1953
+ else:
1954
+ # Some nodes are "active" but SSH not ready yet - keep preparing
1955
+ # For multinode, don't override detailed Panel display with summary message
1956
+ # The preparing logic above will show detailed per-node status
1957
+ message = f"🚀 Setting up SSH access... ({nodes_ready}/{total_nodes} ready)"
1958
+ # Don't directly update spinner here - let the main logic handle display
1959
+ else:
1960
+ # Handle single node completion below in completion check
1961
+ pass
1962
+
1963
+ else:
1964
+ # Default pending/unknown status
1965
+ if is_multinode:
1966
+ message = f"⏳ Processing multinode reservation... ({active_count}/{total_nodes} ready)"
1967
+ else:
1968
+ # Check for detailed status during Docker builds or other detailed operations
1969
+ if aggregate_status == "creating_server" and len(all_reservations) > 0:
1970
+ reservation = all_reservations[0]
1971
+ current_detailed_status = reservation.get(
1972
+ "current_detailed_status", "")
1973
+ if current_detailed_status:
1974
+ message = f"🐳 {current_detailed_status}"
1975
+ else:
1976
+ message = status_messages.get(
1977
+ aggregate_status, f"Status: {aggregate_status}")
1978
+ else:
1979
+ message = status_messages.get(
1980
+ aggregate_status, f"Status: {aggregate_status}")
1981
+
1982
+ # Update spinner if status changed, message changed, or we're in certain states
1983
+ # BUT: Don't override custom Panel display for multinode with spinner
1984
+ if (aggregate_status != last_status or
1985
+ message != last_message or
1986
+ aggregate_status in ["queued", "preparing", "creating_server"]):
1987
+ if message and not (is_multinode and aggregate_status == "preparing"):
1988
+ # Only use spinner for single-node or non-preparing multinode states
1989
+ spinner.text = message
1990
+ last_status = aggregate_status
1991
+ last_message = message
1992
+ live.update(spinner)
1993
+ elif not is_multinode and message:
1994
+ # Single node - always use spinner
1995
+ spinner.text = message
1996
+ last_status = aggregate_status
1997
+ last_message = message
1998
+ live.update(spinner)
1999
+ # For multinode preparing with custom display, we already updated above with Panel
2000
+
2001
+ # Check for single-node completion states (when not multinode or already handled above)
2002
+ if not is_multinode and aggregate_status == "active":
2003
+ reservation = all_reservations[0]
2004
+ ssh_command = reservation.get(
2005
+ "ssh_command", "ssh user@pending")
2006
+
2007
+ # Only complete if we have a real SSH command (not pending/placeholder)
2008
+ if ssh_command != "ssh user@pending" and not ssh_command.endswith(".cluster.local"):
2009
+ live.stop()
2010
+ duration_hours = reservation.get(
2011
+ "duration_hours", 8)
2012
+ reservation_id = reservation["reservation_id"]
2013
+
2014
+ console.print(
2015
+ f"\n[green]✅ Reservation complete![/green]")
2016
+ console.print(
2017
+ f"[cyan]📋 Reservation ID:[/cyan] {reservation_id}")
2018
+ console.print(
2019
+ f"[cyan]⏰ Valid for:[/cyan] {duration_hours} hours")
2020
+
2021
+ # Show quick connect command
2022
+ short_id = reservation_id[:8]
2023
+ console.print(
2024
+ f"[cyan]⚡ Quick Connect:[/cyan] [green]gpu-dev connect {short_id}[/green]")
2025
+
2026
+ # Always create SSH config file for this reservation
2027
+ fqdn = reservation.get("fqdn")
2028
+ pod_name = reservation.get("pod_name")
2029
+ res_id = reservation.get("reservation_id")
2030
+ res_name = reservation.get("name")
2031
+ config_path = None
2032
+ use_include = False
2033
+ if fqdn and pod_name and res_id:
2034
+ try:
2035
+ config_path, use_include = create_ssh_config_for_reservation(
2036
+ fqdn, pod_name, res_id, res_name)
2037
+ except Exception as e:
2038
+ console.print(
2039
+ f"[yellow]⚠️ Could not create SSH config: {str(e)}[/yellow]")
2040
+
2041
+ # Show SSH command using config file if created, otherwise fallback
2042
+ if config_path and pod_name:
2043
+ if use_include:
2044
+ # User approved Include - show simple commands
2045
+ console.print(
2046
+ f"[cyan]🖥️ SSH Command:[/cyan] [green]ssh {pod_name}[/green]")
2047
+ # Create clickable VS Code link
2048
+ vscode_url = _make_vscode_link(pod_name)
2049
+ vscode_command = f"code --remote ssh-remote+{pod_name} /home/dev"
2050
+ console.print(
2051
+ f"[cyan]💻 VS Code Remote:[/cyan] [link={vscode_url}][green]{vscode_command}[/green][/link]")
2052
+
2053
+ # Create clickable Cursor link
2054
+ cursor_url = _make_cursor_link(pod_name)
2055
+ cursor_command = f"cursor --remote ssh-remote+{pod_name} /home/dev"
2056
+ console.print(
2057
+ f"[cyan]🖥️ Cursor Remote:[/cyan] [link={cursor_url}][green]{cursor_command}[/green][/link]")
2058
+ else:
2059
+ # User declined Include - show commands with -F flag
2060
+ console.print(
2061
+ f"[cyan]🖥️ SSH Command:[/cyan] [green]ssh -F {config_path} {pod_name}[/green]")
2062
+ console.print(
2063
+ f"[cyan]💻 VS Code/Cursor:[/cyan] Add [green]Include ~/.gpu-dev/*-sshconfig[/green] to ~/.ssh/config and ~/.cursor/ssh_config")
2064
+ console.print(
2065
+ f"[dim] Or run: [green]gpu-dev config ssh-include enable[/green][/dim]")
2066
+ else:
2067
+ # Fallback to full SSH command if config creation failed
2068
+ ssh_with_forwarding = _add_agent_forwarding_to_ssh(ssh_command)
2069
+ console.print(
2070
+ f"[cyan]🖥️ SSH Command:[/cyan] {ssh_with_forwarding}")
2071
+
2072
+ vscode_command = _generate_vscode_command(ssh_command)
2073
+ if vscode_command:
2074
+ console.print(
2075
+ f"[cyan]💻 VS Code Remote:[/cyan] {vscode_command}")
2076
+
2077
+ cursor_command = _generate_cursor_command(ssh_command)
2078
+ if cursor_command:
2079
+ console.print(
2080
+ f"[cyan]🖱️ Cursor Remote:[/cyan] {cursor_command}")
2081
+
2082
+ # Show Jupyter link if enabled
2083
+ jupyter_enabled = reservation.get(
2084
+ "jupyter_enabled", False)
2085
+ jupyter_url = reservation.get(
2086
+ "jupyter_url", "")
2087
+ if jupyter_enabled and jupyter_url:
2088
+ console.print(
2089
+ f"[cyan]📊 Jupyter Lab:[/cyan] {jupyter_url}")
2090
+
2091
+ return all_reservations
2092
+ else:
2093
+ # Still preparing - show status but don't complete yet
2094
+ current_detailed_status = reservation.get(
2095
+ "current_detailed_status", "")
2096
+ if current_detailed_status:
2097
+ message = f"🚀 {current_detailed_status}"
2098
+ else:
2099
+ message = "🚀 Setting up external SSH access..."
2100
+
2101
+ if message != (last_status if isinstance(last_status, str) else ""):
2102
+ spinner.text = message
2103
+ live.update(spinner)
2104
+
2105
+ elif not is_multinode and aggregate_status in ["failed", "cancelled"]:
2106
+ live.stop()
2107
+ reservation = all_reservations[0] if all_reservations else {
2108
+ }
2109
+ failure_reason = reservation.get(
2110
+ "failure_reason",
2111
+ reservation.get("current_detailed_status", "Unknown error"))
2112
+ reservation_id = reservation.get(
2113
+ "reservation_id", "unknown")
2114
+
2115
+ if aggregate_status == "failed":
2116
+ console.print(
2117
+ f"\n[red]❌ Reservation failed: {failure_reason}[/red]")
2118
+ console.print(
2119
+ f"[red]📋 Reservation ID: {reservation_id}[/red]")
2120
+
2121
+ # Show pod logs if available
2122
+ pod_logs = reservation.get("pod_logs", "")
2123
+ if pod_logs and pod_logs.strip():
2124
+ from rich.panel import Panel
2125
+ from rich.text import Text
2126
+
2127
+ console.print(
2128
+ "\n[red]🔍 Pod logs (last 20 lines) - Details:[/red]")
2129
+ log_text = Text(pod_logs)
2130
+ log_panel = Panel(
2131
+ log_text,
2132
+ title="🐚 Container Startup Logs",
2133
+ title_align="left",
2134
+ border_style="red",
2135
+ expand=False,
2136
+ )
2137
+ console.print(log_panel)
2138
+ else:
2139
+ console.print(
2140
+ f"\n[yellow]🛑 Reservation was cancelled[/yellow]")
2141
+
2142
+ return None
2143
+
2144
+ # Continue polling
2145
+ time.sleep(3)
2146
+
2147
+ except Exception as e:
2148
+ console.print(
2149
+ f"\n[red]❌ Error polling reservation status: {str(e)}[/red]")
2150
+ return None
2151
+
2152
+ # Handle cancellation
2153
+ if cancelled:
2154
+ live.stop()
2155
+ action_text = "multinode reservation" if is_multinode else "reservation request"
2156
+ console.print(
2157
+ f"\n[yellow]⚠️ Cancelling {action_text}...[/yellow]")
2158
+
2159
+ # Cancel all reservations
2160
+ success_count = 0
2161
+ for res_id in reservation_ids:
2162
+ try:
2163
+ response = self.reservations_table.get_item(
2164
+ Key={"reservation_id": res_id})
2165
+ if "Item" in response:
2166
+ user_id = response["Item"].get(
2167
+ "user_id", "unknown")
2168
+ if self.cancel_reservation(res_id, user_id):
2169
+ success_count += 1
2170
+ except Exception as e:
2171
+ console.print(
2172
+ f"[red]❌ Error cancelling reservation {res_id[:8]}: {str(e)}[/red]")
2173
+
2174
+ if success_count == len(reservation_ids):
2175
+ success_text = "All reservations cancelled successfully" if is_multinode else "Reservation cancelled successfully"
2176
+ console.print(f"[green]✅ {success_text}[/green]")
2177
+ elif success_count > 0:
2178
+ console.print(
2179
+ f"[yellow]⚠️ {success_count}/{len(reservation_ids)} reservations cancelled[/yellow]")
2180
+ else:
2181
+ fail_text = "Failed to cancel reservations" if is_multinode else "Failed to cancel reservation"
2182
+ console.print(f"[red]❌ {fail_text}[/red]")
2183
+
2184
+ return None
2185
+
2186
+ # Handle clean exit
2187
+ if close_tool:
2188
+ live.stop()
2189
+ if is_multinode:
2190
+ id_display = ", ".join([res_id[:8]
2191
+ for res_id in reservation_ids])
2192
+ console.print(
2193
+ f"\n[cyan]📱 Exiting - multinode reservations {id_display} continue in background...[/cyan]")
2194
+ else:
2195
+ console.print(
2196
+ f"\n[cyan]📱 Exiting - reservation {reservation_ids[0][:8]} continues in background...[/cyan]")
2197
+ console.print(
2198
+ "[cyan]💡 Use 'gpu-dev list' to check status[/cyan]")
2199
+ if not is_multinode:
2200
+ console.print(
2201
+ f"[cyan]💡 Use 'gpu-dev show {reservation_ids[0][:8]}' to get connection details when ready[/cyan]")
2202
+ return None
2203
+
2204
+ # Timeout reached
2205
+ live.stop()
2206
+ if timeout_minutes is not None:
2207
+ console.print(
2208
+ f"\n[yellow]⏰ Timeout reached after {timeout_minutes} minutes[/yellow]")
2209
+ else:
2210
+ console.print(
2211
+ f"\n[yellow]⏰ Polling stopped unexpectedly[/yellow]")
2212
+ console.print(
2213
+ "[yellow]🔍 Check reservation status manually with: gpu-dev list[/yellow]")
2214
+ return None
2215
+
2216
+ finally:
2217
+ # Restore original signal handlers
2218
+ signal.signal(signal.SIGINT, old_handler)
2219
+ try:
2220
+ signal.signal(signal.SIGTERM, signal.SIG_DFL)
2221
+ signal.signal(signal.SIGQUIT, signal.SIG_DFL)
2222
+ except (AttributeError, OSError):
2223
+ pass
2224
+
2225
+ def wait_for_reservation_completion(
2226
+ self, reservation_id: str, timeout_minutes: Optional[int] = 10, verbose: bool = False
2227
+ ) -> Optional[Dict[str, Any]]:
2228
+ """Poll for single reservation completion using shared polling logic (always creates SSH config)"""
2229
+ results = self._wait_for_reservations_completion(
2230
+ [reservation_id], timeout_minutes, is_multinode=False, verbose=verbose)
2231
+ return results[0] if results else None