hud-python 0.4.15__py3-none-any.whl → 0.4.16__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of hud-python might be problematic. Click here for more details.
- hud/cli/rl/README.md +243 -0
- hud/cli/rl/__init__.py +82 -0
- hud/cli/rl/init.py +370 -0
- hud/cli/rl/pod.py +491 -0
- hud/cli/rl/ssh.py +288 -0
- hud/cli/rl/train.py +421 -0
- hud/cli/rl/utils.py +165 -0
- hud/utils/tests/test_version.py +1 -1
- hud/version.py +1 -1
- {hud_python-0.4.15.dist-info → hud_python-0.4.16.dist-info}/METADATA +1 -1
- {hud_python-0.4.15.dist-info → hud_python-0.4.16.dist-info}/RECORD +14 -7
- {hud_python-0.4.15.dist-info → hud_python-0.4.16.dist-info}/WHEEL +0 -0
- {hud_python-0.4.15.dist-info → hud_python-0.4.16.dist-info}/entry_points.txt +0 -0
- {hud_python-0.4.15.dist-info → hud_python-0.4.16.dist-info}/licenses/LICENSE +0 -0
hud/cli/rl/pod.py
ADDED
|
@@ -0,0 +1,491 @@
|
|
|
1
|
+
"""Pod creation and management utilities for Prime Intellect."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import random
|
|
6
|
+
import re
|
|
7
|
+
import string
|
|
8
|
+
import subprocess
|
|
9
|
+
import time
|
|
10
|
+
from pathlib import Path # noqa: TC003
|
|
11
|
+
|
|
12
|
+
import typer
|
|
13
|
+
from rich.console import Console
|
|
14
|
+
from rich.live import Live
|
|
15
|
+
from rich.spinner import Spinner
|
|
16
|
+
|
|
17
|
+
from hud.settings import settings
|
|
18
|
+
from hud.utils.design import HUDDesign
|
|
19
|
+
|
|
20
|
+
from .ssh import check_and_configure_ssh_key, connect_and_train
|
|
21
|
+
|
|
22
|
+
# Shared output/styling helper used by every function in this module
# (section titles, info/debug/hint/error/success messages).
design = HUDDesign()
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def parse_gpu_config(gpus: str) -> tuple[int, str]:
    """Split a GPU spec such as ``'2xA100'`` into ``(count, gpu_type)``.

    A spec without a count prefix (e.g. ``'A100'``) defaults to a single GPU.
    Common shorthand GPU names are normalized to Prime Intellect's expected
    identifiers; unrecognized names pass through unchanged.

    Raises:
        typer.Exit: If the text before the ``x`` separator is not an integer.
    """
    count_text, separator, rest = gpus.partition("x")
    if separator:
        try:
            count = int(count_text)
        except ValueError as e:
            design.error(f"Invalid GPU count: {count_text}")
            raise typer.Exit(1) from e
        gpu_type = rest
    else:
        # No explicit count prefix -> one GPU of the given type.
        count, gpu_type = 1, gpus

    # Normalize shorthand names to Prime's canonical identifiers.
    canonical_names = {
        "A100": "A100_80GB",
        "A10": "A10_24GB",
        "H100": "H100_80GB",
        "V100": "V100_32GB",
        "RTX3090": "RTX_3090",
        "RTX4090": "RTX_4090",
    }

    return count, canonical_names.get(gpu_type, gpu_type)
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
async def create_and_connect_prime_pod(
    pod_name: str,
    gpu_type: str,
    gpu_count: int,
    model: str,
    dataset: str,
    config: Path,
    output_dir: Path,
    image: str,
    team_id: str | None = None,
    dataset_size: int | None = None,
) -> None:
    """Create a Prime Intellect pod via the ``prime`` CLI and start training.

    Drives the *interactive* ``prime pods create`` prompt by pre-computing the
    full answer sequence (provider, disk size, image, team, confirmation) and
    piping it to stdin, then polls the new pod until SSH is reachable and hands
    off to ``connect_and_train``. If the SSH key cannot be configured
    automatically, prints step-by-step manual connection instructions instead.

    Args:
        pod_name: Name for the new pod (Prime disallows dots in names).
        gpu_type: GPU type in Prime's format, e.g. ``A100_80GB``.
        gpu_count: Number of GPUs to request.
        model: Model identifier forwarded to the training command.
        dataset: Taskset/dataset identifier forwarded to training.
        config: Local training-config path forwarded to ``connect_and_train``.
        output_dir: Output directory used by the remote training run.
        image: Prime environment image installed on the pod.
        team_id: Explicit Prime team ID. When omitted, a globally configured
            team (detected via ``prime config view``) is used if present.
        dataset_size: Optional dataset size forwarded to ``connect_and_train``.

    Raises:
        typer.Exit: If pod creation fails, the pod ID cannot be parsed from
            CLI output, or the pod never becomes active.

    NOTE(review): this function screen-scrapes the Prime CLI's interactive
    menus (option numbers, table layout, message text). Any Prime CLI UI
    change can silently break the scripted answers — confirm against the
    installed ``prime`` version.
    """
    design.section_title("🌐 Creating Prime Intellect Pod")

    # Base non-interactive flags; the remaining choices are answered via stdin.
    create_cmd = [
        "prime",
        "pods",
        "create",
        "--gpu-type",
        gpu_type,
        "--gpu-count",
        str(gpu_count),
        "--name",
        pod_name,
    ]

    design.info(f"Creating pod: {pod_name}")
    design.info(f"GPU configuration: {gpu_count}x {gpu_type}")

    # Check for global team config first
    has_global_team = False
    if not team_id:  # Only check if not explicitly provided
        team_check = subprocess.run(  # noqa: ASYNC221
            ["prime", "config", "view"],  # noqa: S607
            capture_output=True,
            text=True,
        )
        if team_check.returncode == 0:
            # Parse the table output more carefully
            for line in team_check.stdout.split("\n"):
                # Look for "Team ID" in the table (case insensitive)
                if "team id" in line.lower():
                    # Check if there's a value after the | separator
                    parts = line.split("|")
                    if len(parts) >= 2:
                        # Get the value part and check if it's not empty
                        value = parts[1].strip()
                        if value and value != "None":
                            has_global_team = True
                            # Don't overwrite team_id parameter - that's for explicit user input
                            break

    # Display automated selections
    design.info("")
    design.info("Automated selections:")
    design.info(" Provider: Will select from supported providers")
    design.info(" Disk: Default size")
    design.info(" Image: cuda_12_4_pytorch_2_5")
    if team_id:
        design.info(f" Team: {team_id}")
    elif has_global_team:
        design.info(" Team: Using pre-configured team")
    else:
        design.info(" Team: Personal Account")
    design.info("")

    # First, get the provider list by running the command with minimal input
    design.info("Checking available providers...")

    # Run command with just a newline to see provider list
    # (the CLI prints the provider menu before reading the first answer).
    provider_check = subprocess.run(  # noqa: S603, ASYNC221
        create_cmd,
        input="\n",  # Just send newline to see providers
        text=True,
        capture_output=True,
    )

    # Parse provider list
    provider_lines = []
    provider_map = {}  # Maps provider name to number

    if provider_check.stdout:
        lines = provider_check.stdout.strip().split("\n")
        for line in lines:
            # Look for lines like "1. datacrunch (spot) ($0.65/hr)"
            if ". " in line and ("$" in line or "/hr" in line):
                # Extract provider number and name
                parts = line.strip().split(". ", 1)
                if len(parts) == 2:
                    num = parts[0].strip()
                    # Extract provider name (before parentheses or dollar sign)
                    provider_info = parts[1]
                    provider_name = provider_info.split("(")[0].split("$")[0].strip().lower()
                    provider_map[provider_name] = num
                    provider_lines.append(line.strip())

    # Select provider based on our supported list (in preference order).
    supported_providers = ["datacrunch", "hyperstack"]
    provider_choice = "1"  # Default fallback

    for provider in supported_providers:
        if provider in provider_map:
            provider_choice = provider_map[provider]
            design.info(f"Selected provider: {provider} (option {provider_choice})")
            break

    # Build inputs step by step for clarity
    disk_size = ""  # Just press enter for default
    # NOTE(review): menu index "7" is brittle — breaks if Prime reorders its
    # image list. Confirm it still maps to cuda_12_4_pytorch_2_5.
    image_choice = "7"  # cuda_12_4_pytorch_2_5

    # Log what we're doing
    design.debug("Pod creation configuration:")
    design.debug(f" Team ID provided: {team_id}")
    design.debug(f" Global team detected: {has_global_team}")

    if team_id:
        # Explicit team ID provided, select Custom Team ID (option 3)
        team_choice = "3"
        # Fixed: confirmation should be lowercase 'y'
        inputs = f"{provider_choice}\n{disk_size}\n{image_choice}\n{team_choice}\n{team_id}\ny\n"
        design.debug(f" Using explicit team ID: option {team_choice} with ID {team_id}")
    elif has_global_team:
        # When team is pre-configured, it shows as option 2 - select it
        team_choice = "2"
        # Fixed: confirmation should be lowercase 'y' and come after team selection
        inputs = f"{provider_choice}\n{disk_size}\n{image_choice}\n{team_choice}\ny\n"
        design.debug(f" Using pre-configured team: option {team_choice}")
    else:
        # Personal account (option 1) - just press enter to accept default [1]
        inputs = (
            f"{provider_choice}\n{disk_size}\n{image_choice}\n\ny\n"  # Empty line for default [1]
        )
        design.debug(" Using personal account: default option [1]")

    design.debug(
        f" Input sequence: provider={provider_choice}, disk={disk_size or 'default'}, image={image_choice}, team={team_choice if 'team_choice' in locals() else 'default'}"  # noqa: E501
    )

    # Show found providers
    if provider_lines:
        design.info("")
        design.info("Found providers:")
        for pl in provider_lines[:5]:  # Show first 5
            design.info(f" {pl}")

    try:
        console = Console()

        # Spinner while the real (second) `prime pods create` run executes
        # with the scripted answers.
        with Live(
            Spinner("dots", text="[bold]Creating pod...[/bold]", style="gold"),
            console=console,
            refresh_per_second=10,
        ):
            result = subprocess.run(  # noqa: S603, ASYNC221
                create_cmd,
                input=inputs,
                text=True,
                capture_output=True,
            )

        if result.returncode != 0:
            design.error("Failed to create pod")

            # Parse output for better error reporting
            output_lines = result.stdout.strip().split("\n") if result.stdout else []

            # Look for provider prices
            for line in output_lines:
                if "$" in line and "/hr" in line:
                    design.info(f"Provider option: {line.strip()}")

            # Check for team selection error
            if "invalid selection" in result.stdout.lower():
                design.error("Team selection failed")
                # Find and display the team selection section
                for i, line in enumerate(output_lines):
                    if "Select Team:" in line:
                        design.info("Team selection options:")
                        # Show next few lines
                        for j in range(i, min(i + 6, len(output_lines))):
                            design.info(f" {output_lines[j]}")
                        break

            design.info("")
            design.hint(
                "The Prime CLI interface may have changed. Try running the command manually:"
            )
            design.command_example(
                f"prime pods create --gpu-type {gpu_type} --gpu-count {gpu_count} --name {pod_name}"  # noqa: E501
            )

            # Show error details
            if result.stderr:
                design.error("Error output:")
                for line in result.stderr.strip().split("\n"):
                    design.error(f" {line}")

            # Show last part of stdout for context
            if result.stdout:
                design.info("Command output:")
                # Show last 15 lines for brevity
                for line in output_lines[-15:]:
                    design.info(f" {line}")

            # Known failure mode: some providers demand a --max-price value
            # the scripted answers never supply.
            if "max_price" in str(result.stderr) or "max_price" in str(result.stdout):
                design.warning("")
                design.warning("The selected provider requires a maximum price limit.")
                design.info("This is a known limitation with some providers.")
                design.info("")
                design.hint("Workarounds:")
                design.info("1. Run the command manually and select a different provider")
                design.info("2. Try again later when datacrunch (usually cheapest) is available")
                design.info("3. Use the Prime web interface: https://app.primeintellect.ai")

            design.info("")
            design.info("Debug info:")
            design.info(f" Command: {' '.join(create_cmd)}")
            design.info(f" Pod name: {pod_name}")
            design.info(f" Team ID: {'Provided' if team_id else 'Not provided'}")
            design.info(f" Global team detected: {has_global_team}")

            # Show the exact inputs we sent
            design.info(" Inputs sent (in order):")
            input_parts = inputs.strip().split("\n")
            input_labels = [
                "Provider selection",
                "Disk size",
                "Image selection",
                "Team selection",
                "Team ID (if custom)",
                "Confirmation",
            ]
            for i, (part, label) in enumerate(zip(input_parts, input_labels, strict=False)):
                if part:
                    design.info(f" {i + 1}. {label}: '{part}'")
                else:
                    design.info(f" {i + 1}. {label}: [Enter/default]")

            raise typer.Exit(1)

        # Extract pod ID from output
        output_lines = result.stdout.strip().split("\n")
        pod_id = None
        for line in output_lines:
            if "Successfully created pod" in line:
                # Extract just the pod ID (alphanumeric characters);
                # fall back to the last whitespace-separated token.
                match = re.search(r"pod\s+([a-f0-9]+)", line)
                pod_id = match.group(1) if match else line.split()[-1].strip()
                break

        if not pod_id:
            design.error("Could not extract pod ID from output")
            design.info(f"Output: {result.stdout}")
            raise typer.Exit(1)

        design.success(f"Created pod: {pod_id}")

        # Poll for pod status until SSH is reachable (or timeout -> None).
        ssh_info = await poll_pod_status(pod_id)

        if ssh_info:
            design.success("Pod is ready!")
            design.info(f"SSH: {ssh_info}")

            # Check if SSH key is configured globally
            ssh_key_configured = await check_and_configure_ssh_key()

            if ssh_key_configured:
                # Automatically connect and run training
                await connect_and_train(
                    pod_id=pod_id,
                    ssh_info=ssh_info,
                    model=model,
                    dataset=dataset,
                    config=config,
                    output_dir=output_dir,
                    image=image,
                    dataset_size=dataset_size,
                )
            else:
                # Manual fallback
                design.section_title("📋 Manual Connection Required")
                design.info("SSH key configuration failed. Connect manually:")
                design.info("")
                design.info("1. Download the SSH key from:")
                design.info(" https://app.primeintellect.ai/dashboard/profile")
                design.info("")
                design.info("2. Set permissions:")
                design.command_example("chmod 400 /path/to/prime-key.pem", "")
                design.info("")
                design.info("3. Connect to your instance:")
                design.command_example(f"ssh -i /path/to/prime-key.pem {ssh_info}", "")
                design.info("")
                design.info("4. Run these commands:")
                design.command_example("pip install verifiers hud-vf-gym", "")
                design.command_example(f"prime env install {image}", "")

                # Build training command with env vars
                if settings.wandb_api_key:
                    design.command_example(f"export WANDB_API_KEY={settings.wandb_api_key}", "")

                # NOTE(review): assumes the config ends up at /root/config.yaml
                # on the pod — confirm against connect_and_train's upload path.
                train_cmd = f"""vf-train hud-vf-gym \\
    --model {model} \\
    --env-args '{{"taskset": "{dataset}", "config_path": "/root/config.yaml"}}' \\
    --output-dir {output_dir} \\
    --run-name hud-rl-{pod_id[:8]} \\
    --wandb-project hud-rl"""

                design.command_example(train_cmd, "")
                design.info("")
                design.warning(f"Remember to terminate when done: prime pods terminate {pod_id}")
        else:
            design.error("Pod failed to become active")
            raise typer.Exit(1)

    except subprocess.CalledProcessError as e:
        design.error(f"Failed to create pod: {e}")
        raise typer.Exit(1) from e
|
|
373
|
+
|
|
374
|
+
|
|
375
|
+
async def poll_pod_status(pod_id: str) -> str | None:
    """Poll ``prime pods status`` until the pod exposes an SSH endpoint.

    Checks every 10 seconds (up to 20 minutes), scraping the CLI's table
    output for the ``Status`` and ``SSH`` rows and reflecting progress in a
    live spinner.

    Args:
        pod_id: The Prime pod identifier returned by pod creation.

    Returns:
        The SSH connection string once available, or ``None`` on timeout.

    NOTE(review): parsing relies on the status table's column layout
    (key in column 1, value in column 2) — confirm against the installed
    ``prime`` CLI version.
    """
    console = Console()
    max_attempts = 120  # 20 minutes with 10s intervals
    attempt = 0

    # Create spinner
    spinner = Spinner(
        "dots", text="Waiting for pod to become active (should take 5-20 min)...", style="gold"
    )

    with Live(spinner, console=console, refresh_per_second=10) as live:
        while attempt < max_attempts:
            try:
                # Update check frequency in spinner text every minute
                if attempt % 6 == 0:  # Every minute
                    pass  # Will update in spinner text below

                result = subprocess.run(  # noqa: S603, ASYNC221
                    ["prime", "pods", "status", pod_id],  # noqa: S607
                    capture_output=True,
                    text=True,
                )

                if result.returncode == 0:
                    output = result.stdout
                    elapsed_minutes = (attempt * 10) // 60

                    # Parse status - look for lines with Status and SSH
                    lines = output.split("\n")
                    status_value = None
                    ssh_value = None

                    for line in lines:
                        # Handle both regular pipes | and box-drawing chars │
                        if "|" in line or "│" in line:
                            # Split by either type of pipe
                            separator = "│" if "│" in line else "|"
                            parts = [p.strip() for p in line.split(separator)]

                            # parts[0] is the text before the first separator
                            # (usually empty); key/value sit in columns 1 and 2.
                            if len(parts) >= 3:
                                key = parts[1].strip()
                                value = parts[2].strip()

                                if key == "Status":
                                    status_value = value
                                elif key == "SSH":
                                    ssh_value = value

                    # Update spinner text with current status
                    if status_value:
                        # Include SSH status in spinner text
                        ssh_status = f" | SSH: {ssh_value}" if ssh_value else ""
                        spinner.text = f"Pod status: {status_value} ({elapsed_minutes}m elapsed, should take 5-20 min){ssh_status}"  # noqa: E501

                    # Check if SSH is available (and not N/A)
                    if ssh_value and ssh_value.strip() and ssh_value.strip() != "N/A":
                        # Stop the spinner before logging
                        live.stop()
                        design.success(f"SSH is available: {ssh_value}")
                        return ssh_value

                # Sleep/increment happen even when the status command fails,
                # so a flaky CLI cannot spin this loop forever.
                time.sleep(10)  # Wait 10 seconds # noqa: ASYNC251
                attempt += 1

            except Exception as e:
                # Best-effort: surface the failure in the spinner and retry.
                spinner.text = f"[bold red]Status check failed: {e}[/bold red]"
                time.sleep(10)  # noqa: ASYNC251
                attempt += 1

    # Spinner is done, now we can use design.error
    design.error("Timeout: Pod did not become ready within 20 minutes")
    return None
|
|
448
|
+
|
|
449
|
+
|
|
450
|
+
async def run_prime_training(
    model: str,
    dataset: str,
    config: Path,
    gpus: str,
    output_dir: Path,
    image: str,
    auto_create_pod: str | None = None,
    team_id: str | None = None,
    dataset_size: int | None = None,
) -> None:
    """Entry point for training on Prime Intellect infrastructure.

    Verifies that a Prime API key is configured, derives the GPU count/type
    and a unique pod name from the arguments, then provisions a pod and
    starts training on it.

    Note:
        ``auto_create_pod`` is accepted for interface compatibility but is
        currently unused — a pod is always created automatically.

    Raises:
        typer.Exit: If no Prime API key is configured.
    """
    # Refuse to continue without credentials for the `prime` CLI.
    if not settings.prime_api_key:
        design.error("Prime API key not found")
        design.info("Set your Prime API key:")
        design.info(" export PRIME_API_KEY='your-api-key'")
        design.info(" # or")
        design.info(" prime auth")
        raise typer.Exit(1)

    # Turn e.g. "2xA100" into (2, "A100_80GB").
    gpu_count, gpu_type = parse_gpu_config(gpus)

    # Pod names must be short and dot-free; append a random suffix so
    # repeated runs with the same model don't collide.
    slug = model.split("/")[-1].replace(".", "-").lower()
    alphabet = string.ascii_lowercase + string.digits
    nonce = "".join(random.choices(alphabet, k=4))  # noqa: S311
    pod_name = f"hud-rl-{slug}-{nonce}"[:30]

    # Always create pod automatically
    await create_and_connect_prime_pod(
        pod_name=pod_name,
        gpu_type=gpu_type,
        gpu_count=gpu_count,
        model=model,
        dataset=dataset,
        config=config,
        output_dir=output_dir,
        image=image,
        team_id=team_id,
        dataset_size=dataset_size,
    )
|