hud-python 0.4.15__py3-none-any.whl → 0.4.17__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of hud-python might be problematic. See the registry's advisory page for more details.

hud/cli/rl/pod.py ADDED
@@ -0,0 +1,495 @@
1
+ """Pod creation and management utilities for Prime Intellect."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import random
6
+ import re
7
+ import string
8
+ import subprocess
9
+ import time
10
+ from pathlib import Path # noqa: TC003
11
+
12
+ import typer
13
+ from rich.console import Console
14
+ from rich.live import Live
15
+ from rich.spinner import Spinner
16
+
17
+ from hud.settings import settings
18
+ from hud.utils.design import HUDDesign
19
+
20
+ from .ssh import check_and_configure_ssh_key, connect_and_train
21
+
22
+ design = HUDDesign()
23
+
24
+
25
def parse_gpu_config(gpus: str) -> tuple[int, str]:
    """Parse a GPU configuration string like '2xA100' into (count, type).

    Accepts an optional leading count separated by 'x' or 'X' (with optional
    surrounding whitespace); a bare type such as 'H100' defaults to 1 GPU.
    Common short GPU names are mapped to Prime Intellect's expected format.

    Args:
        gpus: GPU spec, e.g. "2xA100", "4XH100", or "H100".

    Returns:
        Tuple of (gpu_count, prime_gpu_type).

    Raises:
        typer.Exit: If the count prefix before 'x' is not an integer.
    """
    spec = gpus.strip()

    # Case-insensitive count prefix: "2xA100", "2XA100", "2 x A100" all work.
    match = re.match(r"^(\d+)\s*[xX]", spec)
    if match:
        count = int(match.group(1))
        gpu_type = spec[match.end() :].strip()
    elif "x" in spec:
        # Preserve the original behavior (clear error) for malformed
        # "<count>x<type>" strings with a non-numeric prefix.
        count_str, gpu_type = spec.split("x", 1)
        try:
            count = int(count_str)
        except ValueError as e:
            design.error(f"Invalid GPU count: {count_str}")
            raise typer.Exit(1) from e
    else:
        # Default to 1 GPU if no count specified
        count = 1
        gpu_type = spec

    # Map common GPU names to Prime's expected format; unknown names pass through.
    gpu_type_map = {
        "A100": "A100_80GB",
        "A10": "A10_24GB",
        "H100": "H100_80GB",
        "V100": "V100_32GB",
        "RTX3090": "RTX_3090",
        "RTX4090": "RTX_4090",
    }

    gpu_type = gpu_type_map.get(gpu_type, gpu_type)

    return count, gpu_type
52
+
53
+
54
async def create_and_connect_prime_pod(
    pod_name: str,
    gpu_type: str,
    gpu_count: int,
    model: str,
    dataset: str,
    config: Path,
    output_dir: Path,
    image: str,
    team_id: str | None = None,
    dataset_size: int | None = None,
    is_json_file: bool = False,
) -> None:
    """Create a Prime Intellect pod via the `prime` CLI and start training on it.

    Drives the interactive `prime pods create` wizard non-interactively by
    pre-computing the stdin answer sequence (provider, disk size, image,
    team, confirmation) and piping it to the subprocess. On success, polls
    until SSH is available, then either connects and trains automatically
    (when an SSH key is configured) or prints manual connection steps.

    Args:
        pod_name: Name for the new pod (callers keep it short, no dots).
        gpu_type: Prime GPU type string, e.g. "A100_80GB".
        gpu_count: Number of GPUs to request.
        model: Model identifier forwarded to the training command.
        dataset: Taskset/dataset identifier forwarded to training.
        config: Local path to the training config file.
        output_dir: Output directory used in the training command.
        image: Prime environment image to install on the pod.
        team_id: Explicit team ID; when None, a pre-configured global team
            is auto-detected from `prime config view`.
        dataset_size: Optional dataset size forwarded to connect_and_train.
        is_json_file: Whether `dataset` refers to a local JSON file.

    Raises:
        typer.Exit: If pod creation fails, the pod ID cannot be parsed from
            the CLI output, or the pod never becomes active.
    """
    design.section_title("🌐 Creating Prime Intellect Pod")

    create_cmd = [
        "prime",
        "pods",
        "create",
        "--gpu-type",
        gpu_type,
        "--gpu-count",
        str(gpu_count),
        "--name",
        pod_name,
    ]

    design.info(f"Creating pod: {pod_name}")
    design.info(f"GPU configuration: {gpu_count}x {gpu_type}")

    # Check for a globally pre-configured team first (affects which wizard
    # option we must answer later).
    has_global_team = False
    if not team_id:  # Only check if not explicitly provided
        team_check = subprocess.run(  # noqa: ASYNC221
            ["prime", "config", "view"],  # noqa: S607
            capture_output=True,
            text=True,
        )
        if team_check.returncode == 0:
            # Parse the table output more carefully
            for line in team_check.stdout.split("\n"):
                # Look for "Team ID" in the table (case insensitive)
                if "team id" in line.lower():
                    # Check if there's a value after the | separator
                    parts = line.split("|")
                    if len(parts) >= 2:
                        # Get the value part and check if it's not empty
                        value = parts[1].strip()
                        if value and value != "None":
                            has_global_team = True
                            # Don't overwrite team_id parameter - that's for explicit user input
                    break

    # Display automated selections so the user knows what will be answered.
    design.info("")
    design.info("Automated selections:")
    design.info(" Provider: Will select from supported providers")
    design.info(" Disk: Default size")
    design.info(" Image: cuda_12_4_pytorch_2_5")
    if team_id:
        design.info(f" Team: {team_id}")
    elif has_global_team:
        design.info(" Team: Using pre-configured team")
    else:
        design.info(" Team: Personal Account")
    design.info("")

    # First, get the provider list by running the command with minimal input
    design.info("Checking available providers...")

    # Run command with just a newline to see provider list (the wizard
    # prints options before reading the first answer).
    provider_check = subprocess.run(  # noqa: S603, ASYNC221
        create_cmd,
        input="\n",  # Just send newline to see providers
        text=True,
        capture_output=True,
    )

    # Parse provider list
    provider_lines = []
    provider_map = {}  # Maps provider name -> menu option number (as string)

    if provider_check.stdout:
        lines = provider_check.stdout.strip().split("\n")
        for line in lines:
            # Look for lines like "1. datacrunch (spot) ($0.65/hr)"
            if ". " in line and ("$" in line or "/hr" in line):
                # Extract provider number and name
                parts = line.strip().split(". ", 1)
                if len(parts) == 2:
                    num = parts[0].strip()
                    # Extract provider name (before parentheses or dollar sign)
                    provider_info = parts[1]
                    provider_name = provider_info.split("(")[0].split("$")[0].strip().lower()
                    provider_map[provider_name] = num
                    provider_lines.append(line.strip())

    # Select provider based on our supported list (first supported one wins).
    supported_providers = ["datacrunch", "hyperstack"]
    provider_choice = "1"  # Default fallback

    for provider in supported_providers:
        if provider in provider_map:
            provider_choice = provider_map[provider]
            design.info(f"Selected provider: {provider} (option {provider_choice})")
            break

    # Build inputs step by step for clarity
    disk_size = ""  # Just press enter for default
    # NOTE(review): "7" assumes the image menu ordering in the Prime CLI is
    # stable (cuda_12_4_pytorch_2_5) — verify when the CLI is upgraded.
    image_choice = "7"  # cuda_12_4_pytorch_2_5

    # Log what we're doing
    design.debug("Pod creation configuration:")
    design.debug(f" Team ID provided: {team_id}")
    design.debug(f" Global team detected: {has_global_team}")

    if team_id:
        # Explicit team ID provided, select Custom Team ID (option 3)
        team_choice = "3"
        # Fixed: confirmation should be lowercase 'y'
        inputs = f"{provider_choice}\n{disk_size}\n{image_choice}\n{team_choice}\n{team_id}\ny\n"
        design.debug(f" Using explicit team ID: option {team_choice} with ID {team_id}")
    elif has_global_team:
        # When team is pre-configured, it shows as option 2 - select it
        team_choice = "2"
        # Fixed: confirmation should be lowercase 'y' and come after team selection
        inputs = f"{provider_choice}\n{disk_size}\n{image_choice}\n{team_choice}\ny\n"
        design.debug(f" Using pre-configured team: option {team_choice}")
    else:
        # Personal account (option 1) - just press enter to accept default [1]
        inputs = (
            f"{provider_choice}\n{disk_size}\n{image_choice}\n\ny\n"  # Empty line for default [1]
        )
        design.debug(" Using personal account: default option [1]")

    design.debug(
        f" Input sequence: provider={provider_choice}, disk={disk_size or 'default'}, image={image_choice}, team={team_choice if 'team_choice' in locals() else 'default'}"  # noqa: E501
    )

    # Show found providers
    if provider_lines:
        design.info("")
        design.info("Found providers:")
        for pl in provider_lines[:5]:  # Show first 5
            design.info(f" {pl}")

    try:
        console = Console()

        # Spinner while the wizard runs with our pre-computed answers.
        with Live(
            Spinner("dots", text="[bold]Creating pod...[/bold]", style="gold"),
            console=console,
            refresh_per_second=10,
        ):
            result = subprocess.run(  # noqa: S603, ASYNC221
                create_cmd,
                input=inputs,
                text=True,
                capture_output=True,
            )

        if result.returncode != 0:
            design.error("Failed to create pod")

            # Parse output for better error reporting
            output_lines = result.stdout.strip().split("\n") if result.stdout else []

            # Look for provider prices
            for line in output_lines:
                if "$" in line and "/hr" in line:
                    design.info(f"Provider option: {line.strip()}")

            # Check for team selection error
            if "invalid selection" in result.stdout.lower():
                design.error("Team selection failed")
                # Find and display the team selection section
                for i, line in enumerate(output_lines):
                    if "Select Team:" in line:
                        design.info("Team selection options:")
                        # Show next few lines
                        for j in range(i, min(i + 6, len(output_lines))):
                            design.info(f" {output_lines[j]}")
                        break

            design.info("")
            design.hint(
                "The Prime CLI interface may have changed. Try running the command manually:"
            )
            design.command_example(
                f"prime pods create --gpu-type {gpu_type} --gpu-count {gpu_count} --name {pod_name}"  # noqa: E501
            )

            # Show error details
            if result.stderr:
                design.error("Error output:")
                for line in result.stderr.strip().split("\n"):
                    design.error(f" {line}")

            # Show last part of stdout for context
            if result.stdout:
                design.info("Command output:")
                # Show last 15 lines for brevity
                for line in output_lines[-15:]:
                    design.info(f" {line}")

            # Known failure mode: some providers require an explicit max price.
            if "max_price" in str(result.stderr) or "max_price" in str(result.stdout):
                design.warning("")
                design.warning("The selected provider requires a maximum price limit.")
                design.info("This is a known limitation with some providers.")
                design.info("")
                design.hint("Workarounds:")
                design.info("1. Run the command manually and select a different provider")
                design.info("2. Try again later when datacrunch (usually cheapest) is available")
                design.info("3. Use the Prime web interface: https://app.primeintellect.ai")

            design.info("")
            design.info("Debug info:")
            design.info(f" Command: {' '.join(create_cmd)}")
            design.info(f" Pod name: {pod_name}")
            design.info(f" Team ID: {'Provided' if team_id else 'Not provided'}")
            design.info(f" Global team detected: {has_global_team}")

            # Show the exact inputs we sent
            design.info(" Inputs sent (in order):")
            input_parts = inputs.strip().split("\n")
            input_labels = [
                "Provider selection",
                "Disk size",
                "Image selection",
                "Team selection",
                "Team ID (if custom)",
                "Confirmation",
            ]
            for i, (part, label) in enumerate(zip(input_parts, input_labels, strict=False)):
                if part:
                    design.info(f" {i + 1}. {label}: '{part}'")
                else:
                    design.info(f" {i + 1}. {label}: [Enter/default]")

            raise typer.Exit(1)

        # Extract pod ID from output
        output_lines = result.stdout.strip().split("\n")
        pod_id = None
        for line in output_lines:
            if "Successfully created pod" in line:
                # Extract just the pod ID (alphanumeric characters)
                match = re.search(r"pod\s+([a-f0-9]+)", line)
                pod_id = match.group(1) if match else line.split()[-1].strip()
                break

        if not pod_id:
            design.error("Could not extract pod ID from output")
            design.info(f"Output: {result.stdout}")
            raise typer.Exit(1)

        design.success(f"Created pod: {pod_id}")

        # Poll for pod status until SSH is reachable (or timeout -> None).
        ssh_info = await poll_pod_status(pod_id)

        if ssh_info:
            design.success("Pod is ready!")
            design.info(f"SSH: {ssh_info}")

            # Check if SSH key is configured globally
            ssh_key_configured = await check_and_configure_ssh_key()

            if ssh_key_configured:
                # Automatically connect and run training
                await connect_and_train(
                    pod_id=pod_id,
                    ssh_info=ssh_info,
                    model=model,
                    dataset=dataset,
                    config=config,
                    output_dir=output_dir,
                    image=image,
                    dataset_size=dataset_size,
                    is_json_file=is_json_file,
                )
            else:
                # Manual fallback: print step-by-step connection instructions.
                design.section_title("📋 Manual Connection Required")
                design.info("SSH key configuration failed. Connect manually:")
                design.info("")
                design.info("1. Download the SSH key from:")
                design.info(" https://app.primeintellect.ai/dashboard/profile")
                design.info("")
                design.info("2. Set permissions:")
                design.command_example("chmod 400 /path/to/prime-key.pem", "")
                design.info("")
                design.info("3. Connect to your instance:")
                design.command_example(f"ssh -i /path/to/prime-key.pem {ssh_info}", "")
                design.info("")
                design.info("4. Run these commands:")
                design.command_example("pip install verifiers hud-vf-gym", "")
                design.command_example(f"prime env install {image}", "")

                # Build training command with env vars
                if settings.wandb_api_key:
                    design.command_example(f"export WANDB_API_KEY={settings.wandb_api_key}", "")

                train_cmd = f"""vf-train hud-vf-gym \\
--model {model} \\
--env-args '{{"taskset": "{dataset}", "config_path": "/root/config.yaml"}}' \\
--output-dir {output_dir} \\
--run-name hud-rl-{pod_id[:8]} \\
--wandb-project hud-rl"""

                design.command_example(train_cmd, "")
                design.info("")
                design.warning(f"Remember to terminate when done: prime pods terminate {pod_id}")
        else:
            design.error("Pod failed to become active")
            raise typer.Exit(1)

    except subprocess.CalledProcessError as e:
        design.error(f"Failed to create pod: {e}")
        raise typer.Exit(1) from e
375
+
376
+
377
async def poll_pod_status(pod_id: str) -> str | None:
    """Poll `prime pods status` until an SSH connection string is available.

    Checks every 10 seconds for up to 20 minutes, rendering progress in a
    spinner. NOTE(review): although declared async, this blocks the event
    loop via subprocess.run/time.sleep (see the ASYNC noqa markers).

    Args:
        pod_id: ID of the pod to watch.

    Returns:
        The SSH connection string once available, or None on timeout.
    """
    console = Console()
    max_attempts = 120  # 20 minutes with 10s intervals
    attempt = 0

    # Create spinner
    spinner = Spinner(
        "dots", text="Waiting for pod to become active (should take 5-20 min)...", style="gold"
    )

    with Live(spinner, console=console, refresh_per_second=10) as live:
        while attempt < max_attempts:
            try:
                # Update check frequency in spinner text every minute
                # (intentional no-op: the text is refreshed below anyway).
                if attempt % 6 == 0:  # Every minute
                    pass  # Will update in spinner text below

                result = subprocess.run(  # noqa: S603, ASYNC221
                    ["prime", "pods", "status", pod_id],  # noqa: S607
                    capture_output=True,
                    text=True,
                )

                if result.returncode == 0:
                    output = result.stdout
                    elapsed_minutes = (attempt * 10) // 60

                    # Parse status - look for table rows with Status and SSH keys
                    lines = output.split("\n")
                    status_value = None
                    ssh_value = None

                    for line in lines:
                        # Handle both regular pipes | and box-drawing chars │
                        if "|" in line or "│" in line:
                            # Split by either type of pipe
                            separator = "│" if "│" in line else "|"
                            parts = [p.strip() for p in line.split(separator)]

                            # parts[0] is the text before the first pipe (usually
                            # empty); key/value live in parts[1]/parts[2].
                            if len(parts) >= 3:
                                key = parts[1].strip()
                                value = parts[2].strip()

                                if key == "Status":
                                    status_value = value
                                elif key == "SSH":
                                    ssh_value = value

                    # Update spinner text with current status
                    if status_value:
                        # Include SSH status in spinner text
                        ssh_status = f" | SSH: {ssh_value}" if ssh_value else ""
                        spinner.text = f"Pod status: {status_value} ({elapsed_minutes}m elapsed, should take 5-20 min){ssh_status}"  # noqa: E501

                    # Check if SSH is available (and not N/A)
                    if ssh_value and ssh_value.strip() and ssh_value.strip() != "N/A":
                        # Stop the spinner before logging
                        live.stop()
                        design.success(f"SSH is available: {ssh_value}")
                        return ssh_value

                time.sleep(10)  # Wait 10 seconds # noqa: ASYNC251
                attempt += 1

            except Exception as e:
                # Transient failures (CLI hiccups, parse errors) are surfaced in
                # the spinner and retried until the attempt budget runs out.
                spinner.text = f"[bold red]Status check failed: {e}[/bold red]"
                time.sleep(10)  # noqa: ASYNC251
                attempt += 1

    # Spinner is done, now we can use design.error
    design.error("Timeout: Pod did not become ready within 20 minutes")
    return None
450
+
451
+
452
async def run_prime_training(
    model: str,
    dataset: str,
    config: Path,
    gpus: str,
    output_dir: Path,
    image: str,
    auto_create_pod: str | None = None,
    team_id: str | None = None,
    dataset_size: int | None = None,
    is_json_file: bool = False,
) -> None:
    """Entry point for launching RL training on Prime Intellect.

    Verifies a Prime API key is configured, translates the GPU spec into a
    count and type, derives a short unique pod name from the model, and
    delegates to create_and_connect_prime_pod. The `auto_create_pod`
    argument is accepted for compatibility but unused here: a pod is always
    created automatically.

    Raises:
        typer.Exit: If no Prime API key is configured.
    """
    # Fail fast when the API key is missing — nothing below can work.
    if not settings.prime_api_key:
        design.error("Prime API key not found")
        design.info("Set your Prime API key:")
        design.info(" export PRIME_API_KEY='your-api-key'")
        design.info(" # or")
        design.info(" prime auth")
        raise typer.Exit(1)

    gpu_count, gpu_type = parse_gpu_config(gpus)

    # Pod names cannot contain dots and must stay short; a random 4-char
    # suffix keeps names unique across repeated runs of the same model.
    base = model.split("/")[-1].replace(".", "-").lower()
    suffix = "".join(random.choices(string.ascii_lowercase + string.digits, k=4))  # noqa: S311
    pod_name = f"hud-rl-{base}-{suffix}"[:30]

    await create_and_connect_prime_pod(
        pod_name=pod_name,
        gpu_type=gpu_type,
        gpu_count=gpu_count,
        model=model,
        dataset=dataset,
        config=config,
        output_dir=output_dir,
        image=image,
        team_id=team_id,
        dataset_size=dataset_size,
        is_json_file=is_json_file,
    )