hud-python 0.4.15__py3-none-any.whl → 0.4.17__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of hud-python might be problematic. Click here for more details.

hud/cli/rl/train.py ADDED
@@ -0,0 +1,560 @@
1
+ """Main RL training command implementation."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import asyncio
6
+ import json
7
+ import subprocess
8
+ from pathlib import Path
9
+ from typing import Any
10
+
11
+ import typer
12
+
13
+ from hud.settings import settings
14
+ from hud.utils.design import HUDDesign
15
+
16
+ from .pod import run_prime_training
17
+ from .utils import (
18
+ detect_image_name,
19
+ get_primary_dataset,
20
+ validate_dataset_name,
21
+ )
22
+
23
# Module-level HUDDesign instance; the command functions in this module call it
# for their user-facing output (section titles, info/warning/error lines, prompts).
design = HUDDesign()
24
+
25
+
26
def find_task_json_files() -> list[Path]:
    """Find JSON files that look like task/eval definitions near the working directory.

    The current directory is searched first. Only when it yields nothing are its
    immediate sub-directories searched (one level deep).

    Returns:
        De-duplicated paths relative to the current directory, sorted, with a
        top-level ``tasks.json`` moved to the front when present. Only regular
        files are returned: a directory whose name happens to match a pattern is
        skipped, because callers pass these paths straight to ``open()``.
    """
    patterns = [
        "*task*.json",
        "*eval*.json",
        "*Task*.json",
        "*Eval*.json",
        "*TASK*.json",
        "*EVAL*.json",
        "tasks.json",  # Most common name
    ]
    root = Path(".")

    def _collect(prefix: str) -> set[Path]:
        # A set drops the duplicates produced when one file matches several patterns.
        return {
            match
            for pattern in patterns
            for match in root.glob(f"{prefix}{pattern}")
            if match.is_file()
        }

    # Fall back to one level deep only if the current directory has no match.
    found = _collect("") or _collect("*/")

    # ``False`` sorts before ``True``, so tasks.json (if present) comes first and
    # everything else keeps its normal path ordering.
    tasks_json = Path("tasks.json")
    return sorted(found, key=lambda path: (path != tasks_json, path))
58
+
59
+
60
def _run_preflight_checks(provider: str) -> None:
    """Report which API keys are configured and exit if a required one is missing.

    HUD_API_KEY is always required, PRIME_API_KEY only when ``provider`` is
    "prime"; WANDB_API_KEY is optional and only produces a warning.

    Raises:
        typer.Exit: with code 1 when a required key is missing, after printing
            instructions for setting it.
    """
    design.section_title("🔍 Pre-flight Checks")

    missing_vars = []

    if settings.api_key:
        design.success("✓ HUD_API_KEY configured")
    else:
        missing_vars.append("HUD_API_KEY")

    # Optional: absence is only a warning and never blocks training.
    if getattr(settings, "wandb_api_key", None):
        design.success("✓ WANDB_API_KEY configured")
    else:
        design.warning("⚠ WANDB_API_KEY not set (optional but recommended for training metrics)")

    if provider == "prime":
        if getattr(settings, "prime_api_key", None):
            design.success("✓ PRIME_API_KEY configured")
        else:
            missing_vars.append("PRIME_API_KEY")

    if not missing_vars:
        return

    # e.g. "HUD_API_KEY=your-hud-api-key"
    placeholders = [f"{var}=your-{var.lower().replace('_', '-')}" for var in missing_vars]

    design.error(f"Missing required environment variables: {', '.join(missing_vars)}")
    design.info("")
    design.info("Set them using one of these methods:")
    design.info("1. Environment variables:")
    for assignment in placeholders:
        design.command_example(f"export {assignment}")
    design.info("")
    design.info("2. Create a .env file in your project root:")
    design.command_example("\n".join(placeholders), "env")
    raise typer.Exit(1)


def _resolve_missing_config(status: str) -> Path:
    """Interactively obtain a config file when none could be chosen automatically.

    Args:
        status: the value ``check_requirements`` stored under ``"config"``:
            "multiple" when several YAML files exist in ``configs/``, otherwise
            no config was found.

    Returns:
        Path of the selected or freshly generated config file.

    Raises:
        typer.Exit: with code 1 when the user declines generation or no config
            file exists after generation.
    """
    config_dir = Path("configs")

    if status == "multiple":
        config_names = [f.name for f in config_dir.glob("*.yaml")]
        selected_config = design.select("Multiple config files found. Select one:", config_names)
        return config_dir / selected_config

    generate_config = design.select(
        "No config file found. Would you like to generate one?",
        ["Yes, generate config", "No, I'll create it manually"],
    )
    if generate_config != "Yes, generate config":
        design.info("Please create a config file and try again")
        raise typer.Exit(1)

    design.info("Running 'hud rl init' to generate config...")
    design.info("")
    # Import here to avoid circular imports
    from .init import init_command_wrapper

    init_command_wrapper(".", None, False, False)

    # Fail right away if generation produced nothing, whether the configs
    # directory is missing entirely or exists but holds no YAML file.
    yaml_files = sorted(config_dir.glob("*.yaml")) if config_dir.exists() else []
    if not yaml_files:
        design.error("Config generation failed")
        raise typer.Exit(1)

    design.success(f"Using generated config: {yaml_files[0]}")
    return yaml_files[0]


def _has_global_prime_team() -> bool:
    """Return True when ``prime config view`` reports a non-empty team ID.

    A missing or unrunnable ``prime`` executable, or a non-zero exit code, is
    treated as "no team configured" instead of crashing the command.
    """
    try:
        team_check = subprocess.run(
            ["prime", "config", "view"],  # noqa: S607
            capture_output=True,
            text=True,
            check=False,
        )
    except OSError as e:
        design.warning(f"Could not run 'prime config view': {e}")
        return False

    if team_check.returncode != 0:
        return False

    for line in team_check.stdout.split("\n"):
        if "team id" not in line.lower():
            continue
        # NOTE(review): this reads the second "|"-separated cell as the value,
        # i.e. it assumes rows shaped like "Team ID | <value>". If rows carry a
        # leading "|" the second cell is the label itself - confirm against the
        # real `prime config view` output.
        parts = line.split("|")
        if len(parts) >= 2:
            value = parts[1].strip()
            if value and value != "None":
                return True
    return False


def _save_team_id_globally(team_id: str) -> None:
    """Persist ``team_id`` via ``prime config set-team-id`` and report the real outcome."""
    try:
        result = subprocess.run(  # noqa: S603
            ["prime", "config", "set-team-id", team_id],  # noqa: S607
            check=False,
        )
    except OSError as e:
        design.warning(f"Could not save team ID globally: {e}")
        return

    # Only claim success when the CLI actually succeeded.
    if result.returncode == 0:
        design.success("Team ID saved globally")
    else:
        design.warning(
            f"Could not save team ID globally (prime exited with code {result.returncode})"
        )


def train_command_wrapper(
    model: str,
    dataset: str | None,
    config: Path | None,
    gpus: str,
    provider: str,
    output_dir: Path,
) -> None:
    """Wrapper to handle interactive prompts before entering async context.

    Runs the API-key pre-flight checks, resolves the config file and dataset
    (auto-detecting or prompting as needed), asks how the Prime pod should be
    created when no team ID is configured globally, and finally runs
    ``train_command`` inside ``asyncio.run``.

    Args:
        model: forwarded unchanged to ``train_command``.
        dataset: dataset name or task-file path; ``None`` triggers detection.
        config: config file path; ``None`` triggers detection in ``configs/``.
        gpus: forwarded unchanged to ``train_command``.
        provider: training provider; "prime" enables the Prime-specific checks.
        output_dir: forwarded unchanged to ``train_command``.

    Raises:
        typer.Exit: with code 1 when a required key, config or dataset is missing.
    """
    _run_preflight_checks(provider)

    # Check for required components
    missing = check_requirements(config, dataset)

    # Auto-detect config if not specified and exactly one exists
    if not config and "config" not in missing:
        config_dir = Path("configs")
        if config_dir.exists():
            yaml_files = list(config_dir.glob("*.yaml"))
            if len(yaml_files) == 1:
                config = yaml_files[0]
                design.info(f"Using config: {config}")

    if "config" in missing:
        config = _resolve_missing_config(missing["config"])

    dataset_status = missing.get("dataset")
    if dataset_status == "multiple_json":
        # Multiple JSON files found, let user choose
        json_files = find_task_json_files()
        design.info("Multiple task files found:")
        dataset = design.select(
            "Select a task file to use:",
            choices=[str(f) for f in json_files],
        )
        design.success(f"Selected: {dataset}")
    elif dataset_status == "none":
        design.error("No dataset specified and no task JSON files found")
        design.info("Please use --dataset or create a tasks.json file")
        design.hint("Example: hud hf --name my-org/my-tasks  # Generate tasks from HUD evaluation")
        raise typer.Exit(1)

    # Store user choice for pod creation
    auto_create_pod = None
    team_id = None

    if provider == "prime":
        if _has_global_prime_team():
            design.info("Using globally configured team ID")
        else:
            # Only ask if no global team is configured
            auto_create_pod = design.select(
                "How would you like to create the Prime Intellect pod?",
                ["Personal account (automated)", "Team account (enter team ID)"],
            )

            if auto_create_pod == "Team account (enter team ID)":
                team_id = typer.prompt("Enter your team ID (e.g., team_abc123def456)")
                _save_team_id_globally(team_id)
                # Treat as automated after getting team ID
                auto_create_pod = "Personal account (automated)"

    # Now run the async command
    asyncio.run(
        train_command(
            model=model,
            dataset=dataset,
            config=config,
            gpus=gpus,
            provider=provider,
            output_dir=output_dir,
            auto_create_pod=auto_create_pod,
            team_id=team_id,
        )
    )
239
+
240
+
241
# Smallest dataset the training command accepts (see the messages below).
_MIN_RL_TASKS = 4


def _detect_mcp_config_type(task: dict[str, Any]) -> str:
    """Classify a task's ``mcp_config`` as "remote", "local" or "unknown".

    "remote" as soon as one server entry has a URL containing ``mcp.hud.so``;
    "local" when at least one entry has no ``url`` key (or is not a dict);
    "unknown" otherwise, including a missing or non-dict ``mcp_config``.
    """
    mcp_config = task.get("mcp_config", {})
    if not isinstance(mcp_config, dict):
        return "unknown"

    config_type = "unknown"
    for server_config in mcp_config.values():
        if isinstance(server_config, dict) and "url" in server_config:
            # str() keeps a null/non-string url from raising on the `in` test.
            if "mcp.hud.so" in str(server_config.get("url", "")):
                return "remote"
        else:
            config_type = "local"
    return config_type


def _convert_tasks_to_remote(tasks: list[dict[str, Any]], image: str) -> None:
    """Replace every task's ``mcp_config`` with the hosted HUD MCP endpoint (in place).

    If the image looks local-only (no "/" in the name, or a "local/" prefix) it
    is pushed first and its name is re-read from the lock file.

    Raises:
        typer.Exit: with code 1 when no image name is available.
    """
    design.info("Converting local MCP configs to remote for training...")

    # Get the image name from lock file or environment
    from .utils import get_image_from_lock

    env_image = image or get_image_from_lock()
    if not env_image:
        design.error("No image found for remote MCP conversion")
        design.hint("Run 'hud build' first")
        raise typer.Exit(1)

    # Check if image needs to be pushed
    if "/" not in env_image or env_image.startswith("local/"):
        design.warning(f"Image '{env_image}' appears to be local only")
        design.info("Running 'hud push' to make it publicly available...")
        from hud.cli.push import push_command

        push_command(directory=".", yes=True)
        design.success("Image pushed successfully")
        # Re-read image name after push; never write an empty image into the tasks.
        env_image = get_image_from_lock()
        if not env_image:
            design.error("No image found for remote MCP conversion")
            design.hint("Run 'hud build' first")
            raise typer.Exit(1)

    for task in tasks:
        task["mcp_config"] = {
            "hud": {
                "url": "https://mcp.hud.so/v3/mcp",
                "headers": {
                    "Authorization": "Bearer $HUD_API_KEY",
                    "Mcp-Image": env_image,
                },
            }
        }

    design.success("✓ Converted all tasks to use remote MCP configs")


def _prepare_task_file(dataset: str, image: str) -> int:
    """Validate a JSON task file and rewrite it with remote MCP configs if needed.

    The file may hold a single task object or a list of task objects. Only the
    first task is sampled to decide whether the configs are local.

    Returns:
        The number of tasks in the file.

    Raises:
        typer.Exit: with code 1 on invalid JSON, an unexpected structure, or
            fewer than the minimum number of tasks.
    """
    design.info(f"Validating task file: {dataset}")
    try:
        with open(dataset, encoding="utf-8") as f:
            tasks_data = json.load(f)
    except json.JSONDecodeError as e:
        design.error(f"Invalid JSON in task file: {e}")
        raise typer.Exit(1) from e

    # Handle both single task and array of tasks
    if isinstance(tasks_data, dict):
        tasks = [tasks_data]
    elif isinstance(tasks_data, list):
        tasks = tasks_data
    else:
        design.error("Invalid tasks file format")
        raise typer.Exit(1)

    # Every entry must be an object; the code below reads and writes task keys.
    if not all(isinstance(task, dict) for task in tasks):
        design.error("Invalid tasks file format")
        raise typer.Exit(1)

    dataset_size = len(tasks)
    if dataset_size < _MIN_RL_TASKS:
        design.error(f"Task file has only {dataset_size} tasks")
        design.info(f"RL training requires at least {_MIN_RL_TASKS} tasks for proper batching")
        design.hint("Consider adding more tasks to your JSON file")
        raise typer.Exit(1)

    design.success(f"✓ Task file has {dataset_size} tasks")

    if _detect_mcp_config_type(tasks[0]) == "local":
        _convert_tasks_to_remote(tasks, image)

        # Save the modified tasks back to the file
        with open(dataset, "w", encoding="utf-8") as f:
            json.dump(tasks, f, indent=2)
        design.info("Updated task file with remote configs")

    return dataset_size


def _validate_hf_dataset(dataset: str) -> int | None:
    """Check the size of the "train" split of a HuggingFace dataset.

    Returns:
        The number of train examples, or ``None`` when there is no train split
        or the dataset metadata could not be loaded (a warning is printed and
        training is allowed to continue).

    Raises:
        typer.Exit: with code 1 when the train split is smaller than the minimum.
    """
    design.info(f"Validating dataset: {dataset}")
    try:
        # Try to load dataset info from HuggingFace
        from datasets import load_dataset_builder

        ds_info = load_dataset_builder(dataset).info

        # Check split sizes
        train_size = ds_info.splits.get("train", None) if ds_info.splits else None
        if train_size and train_size.num_examples < _MIN_RL_TASKS:
            design.error(f"Dataset '{dataset}' has only {train_size.num_examples} tasks")
            design.info(f"RL training requires at least {_MIN_RL_TASKS} tasks for proper batching")
            design.hint("Consider adding more tasks or duplicating existing ones")
            raise typer.Exit(1)
        if train_size:
            design.success(f"✓ Dataset has {train_size.num_examples} tasks")
            return train_size.num_examples
    except typer.Exit:
        # typer.Exit is Click's Exit, a RuntimeError subclass: without this
        # clause the broad handler below would swallow the "too few tasks" exit
        # and training would proceed anyway.
        raise
    except Exception as e:
        # If we can't validate, warn but continue
        design.warning(f"Could not validate dataset size: {e}")
        design.info(f"Proceeding with training - ensure dataset has at least {_MIN_RL_TASKS} tasks")
    return None


async def train_command(
    model: str,
    dataset: str | None,
    config: Path | None,
    gpus: str,
    provider: str,
    output_dir: Path,
    auto_create_pod: str | None = None,
    team_id: str | None = None,
) -> None:
    """Run RL training on HUD environments.

    Resolves the environment image and the dataset (an explicit value, a single
    task JSON file found on disk, or the lock-file dataset), validates the
    dataset size, prints the resolved configuration and hands off to
    ``run_remote_training``.

    Raises:
        typer.Exit: with code 1 when the image, config or dataset is missing or
            the dataset fails validation.
    """
    design.header("🤖 HUD RL Training")

    # Get environment image
    image = detect_image_name()
    if not image:
        design.error("No environment image found")
        design.hint("Run 'hud build' first or specify with 'hud rl init <image>'")
        raise typer.Exit(1)

    # Handle dataset (JSON file or HuggingFace dataset)
    dataset_size = None
    is_json_file = False

    if not dataset:
        # Check for JSON files if no dataset specified
        json_files = find_task_json_files()
        if len(json_files) > 1:
            # This case should have been handled in train_command_wrapper
            design.error("Multiple task files found but none selected")
            raise typer.Exit(1)
        if json_files:
            dataset = str(json_files[0])
            design.info(f"Found task file: {dataset}")
            is_json_file = True
        else:
            # Use dataset from lock file
            dataset = get_primary_dataset()
            if dataset:
                design.info(f"Using dataset from lock file: {dataset}")

    # An explicitly passed path to an existing .json file is also a task file.
    if dataset and Path(dataset).exists() and dataset.endswith(".json"):
        is_json_file = True

    if dataset and is_json_file:
        dataset_size = _prepare_task_file(dataset, image)
    elif dataset:
        dataset_size = _validate_hf_dataset(dataset)

    # Display configuration
    design.section_title("📋 Training Configuration")
    design.json_config(
        json.dumps(
            {
                "Model": model,
                "Dataset": dataset,
                "Config": str(config) if config else None,
                "Environment": image,
                "GPUs": gpus,
                "Provider": provider,
                "Output": str(output_dir),
            },
            indent=2,
        )
    )

    if not config:
        design.error("No config file found")
        design.hint("Run 'hud rl init' to generate a config file")
        raise typer.Exit(1)

    if not dataset:
        design.error("No dataset found")
        design.hint("Run 'hud hf tasks.json' to create a dataset")
        raise typer.Exit(1)

    # Always run remote training
    await run_remote_training(
        model=model,
        dataset=dataset,
        config=config,
        gpus=gpus,
        provider=provider,
        output_dir=output_dir,
        image=image,
        auto_create_pod=auto_create_pod,
        team_id=team_id,
        dataset_size=dataset_size,
        is_json_file=is_json_file,
    )
442
+
443
+
444
def check_requirements(config: Path | None, dataset: str | None) -> dict[str, Any]:
    """Report which training inputs cannot be resolved automatically.

    Returns:
        A dict with at most two keys:

        * ``"config"`` - ``"none"`` when ``configs/`` is absent or holds no
          ``*.yaml`` file, ``"multiple"`` when it holds more than one.
        * ``"dataset"`` - ``"multiple_json"`` when several task JSON files are
          found, ``"none"`` when there is neither a task file nor a lock-file
          dataset.

        A key is omitted when the value was supplied by the caller or exactly
        one candidate exists.
    """
    report: dict[str, Any] = {}

    if not config:
        configs_root = Path("configs")
        yaml_count = len(list(configs_root.glob("*.yaml"))) if configs_root.exists() else 0
        if yaml_count == 0:
            report["config"] = "none"
        elif yaml_count > 1:
            report["config"] = "multiple"
        # exactly one YAML file: the caller picks it up on its own

    if not dataset:
        # Task JSON files take precedence over the lock-file dataset.
        candidates = find_task_json_files()
        if len(candidates) > 1:
            report["dataset"] = "multiple_json"
        elif not candidates and not get_primary_dataset():
            report["dataset"] = "none"
        # a single candidate is auto-selected later

    return report
478
+
479
+
480
def generate_config_interactive() -> Path | None:
    """Run the init command, then return a generated config file if one exists.

    Returns:
        The first ``*.yaml`` path yielded from ``configs/``, or ``None`` when the
        directory is missing or contains no YAML file.
    """
    from .init import init_command

    # init_command is awaited to completion before the directory is inspected.
    asyncio.run(init_command(".", None, False, False))

    configs_root = Path("configs")
    if not configs_root.exists():
        return None
    return next(configs_root.glob("*.yaml"), None)
495
+
496
+
497
def create_dataset_interactive() -> str | None:
    """Prompt for a HuggingFace dataset name and create it from ``tasks.json``.

    Returns:
        The dataset name when ``hud hf`` exits with code 0; ``None`` when
        ``tasks.json`` is missing, the name is rejected by
        ``validate_dataset_name``, or the command fails.
    """
    if not Path("tasks.json").exists():
        design.error("No tasks.json file found")
        return None

    chosen_name = typer.prompt("Enter HuggingFace dataset name (e.g., username/dataset-name)")
    if not validate_dataset_name(chosen_name):
        design.error("Invalid dataset name format")
        return None

    # Output is captured so that stderr can be relayed only on failure.
    hf_command = ["hud", "hf", "tasks.json", "--name", chosen_name]
    completed = subprocess.run(hf_command, capture_output=True, text=True)  # noqa: S603
    if completed.returncode != 0:
        design.error("Failed to create dataset")
        if completed.stderr:
            design.error(completed.stderr)
        return None

    return chosen_name
526
+
527
+
528
async def run_remote_training(
    model: str,
    dataset: str,
    config: Path,
    gpus: str,
    provider: str,
    output_dir: Path,
    image: str,
    auto_create_pod: str | None = None,
    team_id: str | None = None,
    dataset_size: int | None = None,
    is_json_file: bool = False,
) -> None:
    """Dispatch training to the selected remote provider.

    Only "prime" is handled: every argument except ``provider`` is passed
    positionally to ``run_prime_training``.

    Raises:
        typer.Exit: with code 1 for any other provider.
    """
    design.section_title("🚀 Remote Training")

    # Reject unsupported providers up front.
    if provider != "prime":
        design.error(f"Provider '{provider}' not yet supported")
        design.info("Currently supported: prime")
        raise typer.Exit(1)

    prime_args = (
        model,
        dataset,
        config,
        gpus,
        output_dir,
        image,
        auto_create_pod,
        team_id,
        dataset_size,
        is_json_file,
    )
    await run_prime_training(*prime_args)