hud-python 0.4.28__py3-none-any.whl → 0.4.30__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of hud-python might be problematic. Click here for more details.
- hud/__init__.py +2 -1
- hud/agents/base.py +81 -45
- hud/agents/claude.py +8 -4
- hud/agents/openai_chat_generic.py +66 -40
- hud/agents/tests/test_base.py +0 -4
- hud/agents/tests/test_openai.py +1 -1
- hud/cli/__init__.py +182 -52
- hud/cli/dev.py +8 -9
- hud/cli/eval.py +317 -119
- hud/cli/flows/__init__.py +0 -0
- hud/cli/flows/tasks.py +0 -0
- hud/cli/get.py +160 -0
- hud/cli/rl/__init__.py +567 -71
- hud/cli/rl/config.py +94 -0
- hud/cli/rl/display.py +133 -0
- hud/cli/rl/gpu.py +63 -0
- hud/cli/rl/gpu_utils.py +318 -0
- hud/cli/rl/presets.py +96 -0
- hud/cli/rl/remote_runner.py +347 -0
- hud/cli/rl/rl_api.py +150 -0
- hud/cli/rl/vllm.py +177 -0
- hud/cli/tests/test_analyze_metadata.py +0 -1
- hud/cli/utils/tasks.py +26 -0
- hud/clients/base.py +21 -23
- hud/clients/mcp_use.py +36 -44
- hud/clients/tests/test_mcp_use_retry.py +10 -10
- hud/datasets/__init__.py +4 -3
- hud/datasets/{execution/parallel.py → parallel.py} +1 -1
- hud/datasets/{execution/runner.py → runner.py} +1 -1
- hud/datasets/utils.py +1 -1
- hud/native/comparator.py +6 -6
- hud/native/tests/test_comparator.py +8 -8
- hud/native/tests/test_native_init.py +13 -11
- hud/otel/config.py +1 -1
- hud/otel/instrumentation.py +35 -0
- hud/rl/README.md +30 -0
- hud/rl/__init__.py +1 -0
- hud/rl/actor.py +174 -0
- hud/rl/buffer.py +371 -0
- hud/rl/chat_template.jinja +101 -0
- hud/rl/config.py +184 -0
- hud/rl/distributed.py +95 -0
- hud/rl/learner.py +589 -0
- hud/rl/tests/__init__.py +1 -0
- hud/rl/tests/test_learner.py +171 -0
- hud/rl/train.py +354 -0
- hud/rl/types.py +101 -0
- hud/rl/utils/start_vllm_server.sh +30 -0
- hud/rl/utils.py +524 -0
- hud/rl/vllm_adapter.py +125 -0
- hud/settings.py +6 -0
- hud/telemetry/__init__.py +2 -1
- hud/telemetry/job.py +46 -3
- hud/telemetry/tests/test_trace.py +3 -3
- hud/telemetry/trace.py +85 -13
- hud/tools/tests/test_computer.py +3 -3
- hud/tools/tests/test_computer_actions.py +1 -1
- hud/types.py +123 -2
- hud/utils/group_eval.py +223 -0
- hud/utils/hud_console.py +113 -13
- hud/utils/tasks.py +119 -0
- hud/utils/tests/test_version.py +1 -1
- hud/version.py +1 -1
- {hud_python-0.4.28.dist-info → hud_python-0.4.30.dist-info}/METADATA +20 -2
- {hud_python-0.4.28.dist-info → hud_python-0.4.30.dist-info}/RECORD +68 -48
- hud/cli/hf.py +0 -406
- hud/cli/rl/README.md +0 -243
- hud/cli/rl/init.py +0 -370
- hud/cli/rl/pod.py +0 -501
- hud/cli/rl/ssh.py +0 -322
- hud/cli/rl/train.py +0 -562
- hud/cli/rl/utils.py +0 -165
- hud/datasets/execution/__init__.py +0 -13
- hud/datasets/task.py +0 -116
- {hud_python-0.4.28.dist-info → hud_python-0.4.30.dist-info}/WHEEL +0 -0
- {hud_python-0.4.28.dist-info → hud_python-0.4.30.dist-info}/entry_points.txt +0 -0
- {hud_python-0.4.28.dist-info → hud_python-0.4.30.dist-info}/licenses/LICENSE +0 -0
hud/cli/rl/train.py
DELETED
|
@@ -1,562 +0,0 @@
|
|
|
1
|
-
"""Main RL training command implementation."""
|
|
2
|
-
|
|
3
|
-
from __future__ import annotations
|
|
4
|
-
|
|
5
|
-
import asyncio
|
|
6
|
-
import json
|
|
7
|
-
import subprocess
|
|
8
|
-
from pathlib import Path
|
|
9
|
-
from typing import Any
|
|
10
|
-
|
|
11
|
-
import typer
|
|
12
|
-
|
|
13
|
-
from hud.settings import settings
|
|
14
|
-
from hud.utils.hud_console import HUDConsole
|
|
15
|
-
|
|
16
|
-
from .pod import run_prime_training
|
|
17
|
-
from .utils import (
|
|
18
|
-
detect_image_name,
|
|
19
|
-
get_primary_dataset,
|
|
20
|
-
validate_dataset_name,
|
|
21
|
-
)
|
|
22
|
-
|
|
23
|
-
# Module-level console shared by every function in this file for styled output
# (section titles, success/warning/error messages, prompts).
hud_console = HUDConsole()
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
def find_task_json_files() -> list[Path]:
    """Find JSON files containing tasks in the current directory.

    A file matches when it is a regular file whose name ends with ``.json``
    and whose stem contains ``task`` or ``eval`` in any letter case (so
    ``tasks.json``, ``MyTaskSet.json`` and ``EVAL_v2.json`` all match).
    The current directory is searched first; only if nothing matches there
    are its immediate subdirectories searched.

    Returns:
        The matching paths (relative to the current directory), sorted, with
        a top-level ``tasks.json`` moved to the front when present.
    """

    def _is_task_file(path: Path) -> bool:
        # The ".json" suffix stays case-sensitive because train_command only
        # treats a dataset as a task file when its name ends with ".json".
        name = path.name
        if not name.endswith(".json"):
            return False
        stem = name[: -len(".json")].casefold()
        # Directories named like "x_task.json" are not task files.
        return ("task" in stem or "eval" in stem) and path.is_file()

    def _scan(directory: Path) -> list[Path]:
        try:
            return [entry for entry in directory.iterdir() if _is_task_file(entry)]
        except OSError:
            # Unreadable directory (e.g. permission denied): skip it rather than fail.
            return []

    root = Path(".")
    found = _scan(root)

    if not found:
        # Nothing at the top level: look one directory deep.
        try:
            subdirs = [entry for entry in root.iterdir() if entry.is_dir()]
        except OSError:
            subdirs = []
        for subdir in subdirs:
            found.extend(_scan(subdir))

    json_files = sorted(found)

    # Most common name: put a top-level tasks.json first if it exists.
    tasks_json = Path("tasks.json")
    if tasks_json in json_files:
        json_files.remove(tasks_json)
        json_files.insert(0, tasks_json)

    return json_files
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
def train_command_wrapper(
    model: str,
    dataset: str | None,
    config: Path | None,
    gpus: str,
    provider: str,
    output_dir: Path,
) -> None:
    """Wrapper to handle interactive prompts before entering async context.

    All prompts happen here synchronously, before ``asyncio.run`` starts
    ``train_command``. This covers config selection or generation, task-file
    selection, and the Prime team choice.

    Args:
        model: Model identifier to train.
        dataset: Task JSON path or HuggingFace dataset name; auto-detected when None.
        config: Training config path; auto-detected, selected or generated when None.
        gpus: GPU specification forwarded to training.
        provider: Remote training provider name forwarded to training.
        output_dir: Output directory forwarded to training.

    Raises:
        typer.Exit: If required API keys, a config, or a dataset cannot be resolved.
    """
    _check_required_env_vars(provider)

    # Check for required components
    missing = check_requirements(config, dataset)
    config = _resolve_training_config(config, missing)
    dataset = _resolve_training_dataset(dataset, missing)

    # Pod-creation choice and team ID are only relevant for Prime training.
    auto_create_pod: str | None = None
    team_id: str | None = None
    if provider == "prime":
        auto_create_pod, team_id = _resolve_prime_team()

    # Now run the async command
    asyncio.run(
        train_command(
            model=model,
            dataset=dataset,
            config=config,
            gpus=gpus,
            provider=provider,
            output_dir=output_dir,
            auto_create_pod=auto_create_pod,
            team_id=team_id,
        )
    )


def _check_required_env_vars(provider: str) -> None:
    """Report configured API keys and exit with setup instructions if any are missing."""
    hud_console.section_title("🔍 Pre-flight Checks")

    missing_vars: list[str] = []

    if settings.api_key:
        hud_console.success("✓ HUD_API_KEY configured")
    else:
        missing_vars.append("HUD_API_KEY")

    # WANDB is optional: only warn when it is absent.
    if getattr(settings, "wandb_api_key", None):
        hud_console.success("✓ WANDB_API_KEY configured")
    else:
        hud_console.warning(
            "⚠ WANDB_API_KEY not set (optional but recommended for training metrics)"
        )

    # The Prime key is only required when training on Prime.
    if provider == "prime":
        if getattr(settings, "prime_api_key", None):
            hud_console.success("✓ PRIME_API_KEY configured")
        else:
            missing_vars.append("PRIME_API_KEY")

    if not missing_vars:
        return

    placeholders = {var: f"your-{var.lower().replace('_', '-')}" for var in missing_vars}
    hud_console.error(f"Missing required environment variables: {', '.join(missing_vars)}")
    hud_console.info("")
    hud_console.info("Set them using one of these methods:")
    hud_console.info("1. Environment variables:")
    for var, placeholder in placeholders.items():
        hud_console.command_example(f"export {var}={placeholder}")
    hud_console.info("")
    hud_console.info("2. Create a .env file in your project root:")
    hud_console.command_example(
        "\n".join(f"{var}={placeholder}" for var, placeholder in placeholders.items()),
        "env",
    )
    raise typer.Exit(1)


def _resolve_training_config(config: Path | None, missing: dict[str, Any]) -> Path | None:
    """Return the config to use: given, auto-detected, user-selected, or freshly generated."""
    config_dir = Path("configs")

    if "config" not in missing:
        # Auto-detect config if not specified and exactly one exists
        if not config and config_dir.exists():
            yaml_files = list(config_dir.glob("*.yaml"))
            if len(yaml_files) == 1:
                config = yaml_files[0]
                hud_console.info(f"Using config: {config}")
        return config

    if missing["config"] == "multiple":
        config_names = sorted(f.name for f in config_dir.glob("*.yaml"))
        selected_config = hud_console.select(
            "Multiple config files found. Select one:", config_names
        )
        return config_dir / selected_config

    # No config found, offer to generate
    generate_config = hud_console.select(
        "No config file found. Would you like to generate one?",
        ["Yes, generate config", "No, I'll create it manually"],
    )
    if generate_config != "Yes, generate config":
        hud_console.info("Please create a config file and try again")
        raise typer.Exit(1)

    hud_console.info("Running 'hud rl init' to generate config...")
    hud_console.info("")
    # Import here to avoid circular imports
    from .init import init_command_wrapper

    init_command_wrapper(".", None, False, False)

    yaml_files = list(config_dir.glob("*.yaml")) if config_dir.exists() else []
    if not yaml_files:
        hud_console.error("Config generation failed")
        raise typer.Exit(1)

    # Prefer the most recently written file so the freshly generated config is
    # chosen over any pre-existing one (glob order is arbitrary).
    config = max(yaml_files, key=lambda p: p.stat().st_mtime)
    hud_console.success(f"Using generated config: {config}")
    return config


def _resolve_training_dataset(dataset: str | None, missing: dict[str, Any]) -> str | None:
    """Let the user pick among several task files, or exit if no dataset source exists."""
    status = missing.get("dataset")

    if status == "multiple_json":
        # Multiple JSON files found, let user choose
        json_files = find_task_json_files()
        hud_console.info("Multiple task files found:")
        dataset = hud_console.select(
            "Select a task file to use:",
            choices=[str(f) for f in json_files],
        )
        hud_console.success(f"Selected: {dataset}")
    elif status == "none":
        hud_console.error("No dataset specified and no task JSON files found")
        hud_console.info("Please use --dataset or create a tasks.json file")
        hud_console.hint(
            "Example: hud hf --name my-org/my-tasks # Generate tasks from HUD evaluation"
        )
        raise typer.Exit(1)

    return dataset


def _resolve_prime_team() -> tuple[str | None, str | None]:
    """Decide how the Prime pod is created, prompting for a team ID when needed.

    Returns:
        ``(auto_create_pod, team_id)``. Both are None when ``prime config view``
        reports a team ID that is already configured globally.
    """
    # Check if team ID is globally configured. A missing `prime` executable must
    # not abort with a traceback; fall through to asking the user instead.
    try:
        team_check = subprocess.run(
            ["prime", "config", "view"],  # noqa: S607
            capture_output=True,
            text=True,
        )
    except OSError as e:
        hud_console.warning(f"Could not run 'prime config view': {e}")
        team_check = None

    if (
        team_check is not None
        and team_check.returncode == 0
        and _prime_output_has_team_id(team_check.stdout)
    ):
        hud_console.info("Using globally configured team ID")
        return None, None

    # Only ask if no global team is configured
    auto_create_pod = hud_console.select(
        "How would you like to create the Prime Intellect pod?",
        ["Personal account (automated)", "Team account (enter team ID)"],
    )
    if auto_create_pod != "Team account (enter team ID)":
        return auto_create_pod, None

    team_id = typer.prompt("Enter your team ID (e.g., team_abc123def456)").strip()

    # Save it globally, reporting failure instead of claiming success unconditionally.
    try:
        saved = subprocess.run(["prime", "config", "set-team-id", team_id])  # noqa: S603, S607
        saved_ok = saved.returncode == 0
    except OSError:
        saved_ok = False

    if saved_ok:
        hud_console.success("Team ID saved globally")
    else:
        hud_console.warning("Could not save team ID globally; using it for this run only")

    # Treat as automated after getting team ID
    return "Personal account (automated)", team_id


def _prime_output_has_team_id(config_view_output: str) -> bool:
    """Return True if `prime config view` output shows a non-empty Team ID value.

    On a row mentioning "team id", the cell right after the label cell is taken
    as the value. Splitting on separators and locating the label cell works with
    or without a leading table border. ``│`` box-drawing separators are accepted
    too, in case the table is rendered with them. Empty or "None" values count
    as not configured.
    """
    for line in config_view_output.splitlines():
        if "team id" not in line.lower():
            continue
        cells = [cell.strip() for cell in line.replace("│", "|").split("|")]
        for index, cell in enumerate(cells[:-1]):
            if "team id" in cell.lower():
                value = cells[index + 1]
                if value and value != "None":
                    return True
                break
    return False
|
|
241
|
-
|
|
242
|
-
|
|
243
|
-
# Minimum number of tasks RL training needs for proper batching.
_MIN_TRAINING_TASKS = 4


async def train_command(
    model: str,
    dataset: str | None,
    config: Path | None,
    gpus: str,
    provider: str,
    output_dir: Path,
    auto_create_pod: str | None = None,
    team_id: str | None = None,
) -> None:
    """Run RL training on HUD environments.

    The command proceeds in four steps:

    1. Detect the environment image.
    2. Resolve the dataset, which is either a local task JSON file or a
       HuggingFace dataset.
    3. Validate the dataset and, for task files, convert local MCP configs to
       HUD-hosted remote ones.
    4. Display the configuration and hand off to ``run_remote_training``.

    Raises:
        typer.Exit: If the image, config or dataset is missing or invalid.
    """
    hud_console.header("🤖 HUD RL Training")

    # Get environment image
    image = detect_image_name()
    if not image:
        hud_console.error("No environment image found")
        hud_console.hint("Run 'hud build' first or specify with 'hud rl init <image>'")
        raise typer.Exit(1)

    # Fail on a missing config before any side effects (image push, task file rewrite).
    if not config:
        hud_console.error("No config file found")
        hud_console.hint("Run 'hud rl init' to generate a config file")
        raise typer.Exit(1)

    dataset, is_json_file = _resolve_dataset_source(dataset)
    if not dataset:
        hud_console.error("No dataset found")
        hud_console.hint("Run 'hud hf tasks.json' to create a dataset")
        raise typer.Exit(1)

    # Validate dataset
    if is_json_file:
        dataset_size: int | None = _prepare_task_file(dataset, image)
    else:
        dataset_size = _check_hf_dataset_size(dataset)

    # Display configuration
    hud_console.section_title("📋 Training Configuration")
    hud_console.json_config(
        json.dumps(
            {
                "Model": model,
                "Dataset": dataset,
                "Config": str(config) if config else None,
                "Environment": image,
                "GPUs": gpus,
                "Provider": provider,
                "Output": str(output_dir),
            },
            indent=2,
        )
    )

    # Always run remote training
    await run_remote_training(
        model=model,
        dataset=dataset,
        config=config,
        gpus=gpus,
        provider=provider,
        output_dir=output_dir,
        image=image,
        auto_create_pod=auto_create_pod,
        team_id=team_id,
        dataset_size=dataset_size,
        is_json_file=is_json_file,
    )


def _resolve_dataset_source(dataset: str | None) -> tuple[str | None, bool]:
    """Pick the dataset to train on and report whether it is a local JSON task file."""
    is_json_file = False

    if not dataset:
        json_files = find_task_json_files()
        if len(json_files) == 1:
            dataset = str(json_files[0])
            hud_console.info(f"Found task file: {dataset}")
            is_json_file = True
        elif json_files:
            # This case should have been handled in train_command_wrapper
            hud_console.error("Multiple task files found but none selected")
            raise typer.Exit(1)
        else:
            # Use dataset from lock file
            dataset = get_primary_dataset()
            if dataset:
                hud_console.info(f"Using dataset from lock file: {dataset}")

    # Check if dataset is a file path
    if dataset and dataset.endswith(".json") and Path(dataset).exists():
        is_json_file = True

    return dataset, is_json_file


def _prepare_task_file(dataset: str, image: str) -> int:
    """Validate a local task JSON file and convert local MCP configs to remote ones.

    Returns:
        The number of tasks in the file.

    Raises:
        typer.Exit: If the file is not valid UTF-8 JSON, is not a task object or
            list of task objects, has too few tasks, or no image is available
            for the remote conversion.
    """
    hud_console.info(f"Validating task file: {dataset}")
    try:
        with open(dataset, encoding="utf-8") as f:
            tasks_data = json.load(f)
    except (json.JSONDecodeError, UnicodeDecodeError) as e:
        hud_console.error(f"Invalid JSON in task file: {e}")
        raise typer.Exit(1) from e

    # Handle both single task and array of tasks
    if isinstance(tasks_data, dict):
        tasks = [tasks_data]
    elif isinstance(tasks_data, list):
        tasks = tasks_data
    else:
        tasks = None

    # Every task must be an object; anything else would crash on `.get` below.
    if tasks is None or not all(isinstance(task, dict) for task in tasks):
        hud_console.error("Invalid tasks file format")
        raise typer.Exit(1)

    dataset_size = len(tasks)
    if dataset_size < _MIN_TRAINING_TASKS:
        hud_console.error(f"Task file has only {dataset_size} tasks")
        hud_console.info(
            f"RL training requires at least {_MIN_TRAINING_TASKS} tasks for proper batching"
        )
        hud_console.hint("Consider adding more tasks to your JSON file")
        raise typer.Exit(1)

    hud_console.success(f"✓ Task file has {dataset_size} tasks")

    # Only the first task is inspected; when it is local, every task is converted.
    if _mcp_config_kind(tasks[0].get("mcp_config")) == "local":
        _convert_tasks_to_remote(tasks, dataset, image)

    return dataset_size


def _mcp_config_kind(mcp_config: Any) -> str:
    """Classify an MCP config mapping as "remote", "local" or "unknown".

    The classification rules are:

    - "remote": any server entry has a URL containing ``mcp.hud.so``.
    - "local": there is at least one other server entry. This includes entries
      with a non-HUD URL and entries with no URL at all.
    - "unknown": there are no server entries, or the value is not a mapping.
    """
    if not isinstance(mcp_config, dict):
        return "unknown"
    servers = [server for server in mcp_config.values() if isinstance(server, dict)]
    if any("mcp.hud.so" in str(server.get("url") or "") for server in servers):
        return "remote"
    return "local" if servers else "unknown"


def _convert_tasks_to_remote(tasks: list[dict[str, Any]], dataset: str, image: str) -> None:
    """Point every task at the HUD-hosted MCP server and rewrite the task file in place.

    If the image looks local-only, ``hud push`` is run first and the pushed
    image name is re-read from the lock file.
    """
    hud_console.info("Converting local MCP configs to remote for training...")

    # Get the image name from lock file or environment
    from .utils import get_image_from_lock

    env_image = image or get_image_from_lock()
    if not env_image:
        hud_console.error("No image found for remote MCP conversion")
        hud_console.hint("Run 'hud build' first")
        raise typer.Exit(1)

    # Heuristic: no "/" (no registry/namespace) or a "local/" prefix means local-only.
    if "/" not in env_image or env_image.startswith("local/"):
        hud_console.warning(f"Image '{env_image}' appears to be local only")
        hud_console.info("Running 'hud push' to make it publicly available...")
        from hud.cli.push import push_command

        push_command(directory=".", yes=True)
        hud_console.success("Image pushed successfully")
        # Re-read image name after push; never write a null image into the tasks.
        env_image = get_image_from_lock()
        if not env_image:
            hud_console.error("Could not read the pushed image name from the lock file")
            raise typer.Exit(1)

    # Convert all tasks to use remote MCP
    for task in tasks:
        task["mcp_config"] = {
            "hud": {
                "url": "https://mcp.hud.so/v3/mcp",
                "headers": {
                    "Authorization": "Bearer $HUD_API_KEY",
                    "Mcp-Image": env_image,
                },
            }
        }

    hud_console.success("✓ Converted all tasks to use remote MCP configs")

    # Save the modified tasks back to the file
    with open(dataset, "w", encoding="utf-8") as f:
        json.dump(tasks, f, indent=2)
    hud_console.info("Updated task file with remote configs")


def _check_hf_dataset_size(dataset: str) -> int | None:
    """Validate a HuggingFace dataset's train split size when metadata is available.

    Returns:
        The number of train examples, or None when it could not be determined.

    Raises:
        typer.Exit: If the train split is known to be smaller than the minimum.
    """
    hud_console.info(f"Validating dataset: {dataset}")
    try:
        # Try to load dataset info from HuggingFace
        from datasets import load_dataset_builder

        ds_info = load_dataset_builder(dataset).info
        train_split = ds_info.splits.get("train", None) if ds_info.splits else None
        num_examples = train_split.num_examples if train_split else None
    except Exception as e:
        # If we can't validate, warn but continue
        hud_console.warning(f"Could not validate dataset size: {e}")
        hud_console.info(
            f"Proceeding with training - ensure dataset has at least {_MIN_TRAINING_TASKS} tasks"
        )
        return None

    if num_examples is None:
        return None

    # Checked outside the try: typer.Exit derives from Exception and would
    # otherwise be swallowed by the broad handler above.
    if num_examples < _MIN_TRAINING_TASKS:
        hud_console.error(f"Dataset '{dataset}' has only {num_examples} tasks")
        hud_console.info(
            f"RL training requires at least {_MIN_TRAINING_TASKS} tasks for proper batching"
        )
        hud_console.hint("Consider adding more tasks or duplicating existing ones")
        raise typer.Exit(1)

    hud_console.success(f"✓ Dataset has {num_examples} tasks")
    return num_examples
|
|
444
|
-
|
|
445
|
-
|
|
446
|
-
def check_requirements(config: Path | None, dataset: str | None) -> dict[str, Any]:
    """Report which training inputs still need the user's attention.

    Args:
        config: Explicit config path, or a falsy value to look in ``configs/``.
        dataset: Explicit dataset, or a falsy value to look for task JSON files
            and then the lock file's primary dataset.

    Returns:
        A mapping with the following possible keys:

        - ``"config"``: ``"none"`` when no ``configs/*.yaml`` exists, or
          ``"multiple"`` when there are several to choose from.
        - ``"dataset"``: ``"multiple_json"`` when several task files were
          found, or ``"none"`` when neither a task file nor a lock-file
          dataset exists.

        Keys are omitted when the item is present or can be chosen
        automatically.
    """
    problems: dict[str, Any] = {}

    if not config:
        configs_dir = Path("configs")
        yaml_count = len(list(configs_dir.glob("*.yaml"))) if configs_dir.exists() else 0
        # Exactly one config is fine: the caller picks it automatically.
        if yaml_count == 0:
            problems["config"] = "none"
        elif yaml_count > 1:
            problems["config"] = "multiple"

    if not dataset:
        # Local task JSON files take precedence over the lock-file dataset.
        candidates = find_task_json_files()
        if len(candidates) > 1:
            problems["dataset"] = "multiple_json"
        elif not candidates and not get_primary_dataset():
            problems["dataset"] = "none"

    return problems
|
|
480
|
-
|
|
481
|
-
|
|
482
|
-
def generate_config_interactive() -> Path | None:
    """Generate config interactively and return the path.

    Runs ``init_command(".", None, False, False)`` to completion, then looks
    in ``configs/`` for the result.

    Returns:
        The most recently modified ``configs/*.yaml``, so a freshly generated
        file wins over any pre-existing config. Returns None if no config
        file exists.
    """
    from .init import init_command

    # Run init command
    asyncio.run(init_command(".", None, False, False))

    # Look for generated config
    config_dir = Path("configs")
    if not config_dir.exists():
        return None

    yaml_files = list(config_dir.glob("*.yaml"))
    if not yaml_files:
        return None

    # Glob order is arbitrary; pick the newest file rather than yaml_files[0].
    return max(yaml_files, key=lambda p: p.stat().st_mtime)
|
|
497
|
-
|
|
498
|
-
|
|
499
|
-
def create_dataset_interactive() -> str | None:
    """Create dataset interactively and return the name.

    The function proceeds as follows:

    1. Require a ``tasks.json`` in the current directory.
    2. Prompt for a HuggingFace dataset name and validate it.
    3. Run ``hud hf tasks.json --name <name>`` as a subprocess.

    Returns:
        The dataset name on success. On any failure, an error is printed and
        None is returned.
    """
    # Check if tasks.json exists
    tasks_file = Path("tasks.json")
    if not tasks_file.exists():
        hud_console.error("No tasks.json file found")
        return None

    # Prompt for dataset name (surrounding whitespace is not part of the name)
    dataset_name = typer.prompt(
        "Enter HuggingFace dataset name (e.g., username/dataset-name)"
    ).strip()

    if not validate_dataset_name(dataset_name):
        hud_console.error("Invalid dataset name format")
        return None

    # Run hf command
    try:
        result = subprocess.run(  # noqa: S603
            ["hud", "hf", str(tasks_file), "--name", dataset_name],  # noqa: S607
            capture_output=True,
            text=True,
        )
    except OSError as e:
        # e.g. the `hud` executable is not on PATH; report instead of crashing.
        hud_console.error(f"Failed to create dataset: {e}")
        return None

    if result.returncode != 0:
        hud_console.error("Failed to create dataset")
        if result.stderr:
            hud_console.error(result.stderr)
        return None

    return dataset_name
|
|
528
|
-
|
|
529
|
-
|
|
530
|
-
async def run_remote_training(
    model: str,
    dataset: str,
    config: Path,
    gpus: str,
    provider: str,
    output_dir: Path,
    image: str,
    auto_create_pod: str | None = None,
    team_id: str | None = None,
    dataset_size: int | None = None,
    is_json_file: bool = False,
) -> None:
    """Dispatch training to the selected remote provider.

    Only ``"prime"`` is implemented. It is delegated to ``run_prime_training``,
    which receives every argument except ``provider``.

    Raises:
        typer.Exit: If ``provider`` is not supported.
    """
    hud_console.section_title("🚀 Remote Training")

    # Guard clause: reject unsupported providers up front.
    if provider != "prime":
        hud_console.error(f"Provider '{provider}' not yet supported")
        hud_console.info("Currently supported: prime")
        raise typer.Exit(1)

    prime_args = (
        model,
        dataset,
        config,
        gpus,
        output_dir,
        image,
        auto_create_pod,
        team_id,
        dataset_size,
        is_json_file,
    )
    await run_prime_training(*prime_args)
|