hud-python 0.4.47__py3-none-any.whl → 0.4.49__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of hud-python has been flagged as potentially problematic; see the package's registry page for more details.
- hud/agents/base.py +55 -142
- hud/agents/claude.py +5 -6
- hud/agents/grounded_openai.py +1 -1
- hud/agents/misc/integration_test_agent.py +2 -0
- hud/agents/tests/test_base.py +2 -5
- hud/cli/__init__.py +80 -215
- hud/cli/build.py +105 -45
- hud/cli/dev.py +614 -743
- hud/cli/eval.py +14 -9
- hud/cli/flows/tasks.py +100 -21
- hud/cli/init.py +18 -14
- hud/cli/push.py +27 -9
- hud/cli/rl/local_runner.py +28 -16
- hud/cli/rl/vllm.py +2 -0
- hud/cli/tests/test_analyze_metadata.py +3 -2
- hud/cli/tests/test_eval.py +574 -0
- hud/cli/tests/test_mcp_server.py +6 -95
- hud/cli/tests/test_utils.py +1 -1
- hud/cli/utils/env_check.py +9 -9
- hud/cli/utils/source_hash.py +1 -1
- hud/datasets/parallel.py +0 -12
- hud/datasets/runner.py +1 -4
- hud/rl/actor.py +4 -2
- hud/rl/distributed.py +1 -1
- hud/rl/learner.py +2 -1
- hud/rl/train.py +1 -1
- hud/server/__init__.py +2 -1
- hud/server/router.py +160 -0
- hud/server/server.py +246 -79
- hud/telemetry/trace.py +1 -1
- hud/tools/base.py +20 -10
- hud/tools/computer/__init__.py +2 -0
- hud/tools/computer/qwen.py +431 -0
- hud/tools/computer/settings.py +16 -0
- hud/tools/executors/pyautogui.py +1 -1
- hud/tools/playwright.py +1 -1
- hud/types.py +2 -3
- hud/utils/hud_console.py +43 -0
- hud/utils/tests/test_version.py +1 -1
- hud/version.py +1 -1
- {hud_python-0.4.47.dist-info → hud_python-0.4.49.dist-info}/METADATA +1 -1
- {hud_python-0.4.47.dist-info → hud_python-0.4.49.dist-info}/RECORD +45 -42
- {hud_python-0.4.47.dist-info → hud_python-0.4.49.dist-info}/WHEEL +0 -0
- {hud_python-0.4.47.dist-info → hud_python-0.4.49.dist-info}/entry_points.txt +0 -0
- {hud_python-0.4.47.dist-info → hud_python-0.4.49.dist-info}/licenses/LICENSE +0 -0
hud/cli/eval.py
CHANGED
|
@@ -199,6 +199,8 @@ async def run_single_task(
|
|
|
199
199
|
) -> None:
|
|
200
200
|
"""Load one task and execute it, or detect if JSON contains a list and run as dataset."""
|
|
201
201
|
|
|
202
|
+
# Provide early feedback to user
|
|
203
|
+
hud_console.info("🔧 Initializing evaluation...")
|
|
202
204
|
# Import Task and run_dataset lazily
|
|
203
205
|
try:
|
|
204
206
|
from hud.utils.tasks import load_tasks
|
|
@@ -318,7 +320,10 @@ async def run_single_task(
|
|
|
318
320
|
)
|
|
319
321
|
display_group_statistics(stats, show_details=True)
|
|
320
322
|
else:
|
|
321
|
-
#
|
|
323
|
+
# Enable agent step logging for single task mode
|
|
324
|
+
logging.getLogger("hud.agents").setLevel(logging.INFO)
|
|
325
|
+
logging.getLogger("hud.agents.base").setLevel(logging.INFO)
|
|
326
|
+
|
|
322
327
|
with hud.trace(name=task_prompt):
|
|
323
328
|
agent = build_agent(
|
|
324
329
|
agent_type,
|
|
@@ -352,6 +357,9 @@ async def run_full_dataset(
|
|
|
352
357
|
Uses either asyncio-based run_dataset or process-based parallel execution
|
|
353
358
|
depending on the parallel flag."""
|
|
354
359
|
|
|
360
|
+
# Provide early feedback to user
|
|
361
|
+
hud_console.info("🔧 Initializing evaluation...")
|
|
362
|
+
|
|
355
363
|
# Import run_dataset lazily
|
|
356
364
|
try:
|
|
357
365
|
from hud.datasets import run_dataset, run_dataset_parallel, run_dataset_parallel_manual
|
|
@@ -367,7 +375,7 @@ async def run_full_dataset(
|
|
|
367
375
|
hud_console.info(f"📊 Loading tasks from: {source}…")
|
|
368
376
|
tasks: list[Task] = load_tasks(source) # type: ignore[assignment]
|
|
369
377
|
|
|
370
|
-
if
|
|
378
|
+
if len(tasks) == 0:
|
|
371
379
|
hud_console.error(f"No tasks found in: {source}")
|
|
372
380
|
raise typer.Exit(1)
|
|
373
381
|
|
|
@@ -646,10 +654,10 @@ def eval_command(
|
|
|
646
654
|
hud eval hud-evals/SheetBench-50 --full --agent claude
|
|
647
655
|
|
|
648
656
|
# Run large dataset with PARALLEL execution (auto-optimized)
|
|
649
|
-
hud eval hud-evals/OSWorld-Verified-
|
|
657
|
+
hud eval hud-evals/OSWorld-Verified-Gold --full --parallel
|
|
650
658
|
|
|
651
659
|
# Parallel mode with manual configuration (16 workers, 25 tasks each)
|
|
652
|
-
hud eval hud-evals/OSWorld-Verified-
|
|
660
|
+
hud eval hud-evals/OSWorld-Verified-Gold --full --parallel --max-workers 16
|
|
653
661
|
|
|
654
662
|
# Limit total concurrent tasks to prevent rate limits
|
|
655
663
|
hud eval hud-evals/SheetBench-50 --full --parallel --max-concurrent 20
|
|
@@ -674,6 +682,8 @@ def eval_command(
|
|
|
674
682
|
"""
|
|
675
683
|
from hud.settings import settings
|
|
676
684
|
|
|
685
|
+
# Always configure basic logging so agent steps can be logged
|
|
686
|
+
# Set to INFO by default for consistency with run_evaluation.py
|
|
677
687
|
if very_verbose:
|
|
678
688
|
logging.basicConfig(
|
|
679
689
|
level=logging.DEBUG,
|
|
@@ -683,11 +693,6 @@ def eval_command(
|
|
|
683
693
|
logging.getLogger("hud.agents").setLevel(logging.DEBUG)
|
|
684
694
|
logging.getLogger("hud.agents.base").setLevel(logging.DEBUG)
|
|
685
695
|
elif verbose:
|
|
686
|
-
logging.basicConfig(
|
|
687
|
-
level=logging.INFO,
|
|
688
|
-
format="%(asctime)s - %(name)s - %(message)s",
|
|
689
|
-
datefmt="%H:%M:%S",
|
|
690
|
-
)
|
|
691
696
|
logging.getLogger("hud.agents").setLevel(logging.INFO)
|
|
692
697
|
logging.getLogger("hud.agents.base").setLevel(logging.INFO)
|
|
693
698
|
|
hud/cli/flows/tasks.py
CHANGED
|
@@ -78,26 +78,38 @@ def _ensure_pushed(env_dir: Path, lock_data: dict[str, Any]) -> dict[str, Any]:
|
|
|
78
78
|
|
|
79
79
|
|
|
80
80
|
def _derive_remote_image(lock_data: dict[str, Any]) -> str:
|
|
81
|
-
"""Derive org/name:tag from lock file for MCP header.
|
|
81
|
+
"""Derive org/name:tag from lock file for remote MCP header.
|
|
82
82
|
|
|
83
|
-
Preference order:
|
|
84
|
-
1) lock_data["push"]["image_with_tag"]
|
|
85
|
-
2)
|
|
83
|
+
Preference order (new lock first, then legacy):
|
|
84
|
+
1) lock_data["push"]["image_with_tag"] (exact org/name:tag that was pushed)
|
|
85
|
+
2) lock_data["images"]["local"] (base name with internal version)
|
|
86
|
+
3) lock_data["image"] (legacy field; may contain tag or digest)
|
|
86
87
|
"""
|
|
87
|
-
|
|
88
|
+
if not isinstance(lock_data, dict): # Defensive
|
|
89
|
+
raise typer.Exit(1)
|
|
88
90
|
|
|
89
|
-
# 1)
|
|
90
|
-
|
|
91
|
+
# 1) Prefer the exact image that was pushed (org/name:tag)
|
|
92
|
+
push_info = lock_data.get("push") or {}
|
|
93
|
+
pushed_with_tag = str(push_info.get("image_with_tag") or "").strip()
|
|
91
94
|
if pushed_with_tag:
|
|
92
95
|
name, tag = extract_name_and_tag(pushed_with_tag)
|
|
93
96
|
return f"{name}:{tag}"
|
|
94
97
|
|
|
95
|
-
#
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
98
|
+
# 2) Fall back to the local tag recorded in the new lock schema
|
|
99
|
+
images = lock_data.get("images") or {}
|
|
100
|
+
local_image = str(images.get("local") or "").strip()
|
|
101
|
+
if local_image:
|
|
102
|
+
name, tag = extract_name_and_tag(local_image)
|
|
103
|
+
return f"{name}:{tag}"
|
|
104
|
+
|
|
105
|
+
# 3) Legacy top-level image field
|
|
106
|
+
legacy_image = str(lock_data.get("image") or "").strip()
|
|
107
|
+
if legacy_image:
|
|
108
|
+
name, tag = extract_name_and_tag(legacy_image)
|
|
109
|
+
return f"{name}:{tag}"
|
|
110
|
+
|
|
111
|
+
# If none of the above exist, we cannot derive an image
|
|
112
|
+
raise typer.Exit(1)
|
|
101
113
|
|
|
102
114
|
|
|
103
115
|
def _extract_existing_images(tasks: list[Task]) -> set[str]:
|
|
@@ -183,6 +195,63 @@ def _extract_dotenv_api_key_vars(env_dir: Path) -> set[str]:
|
|
|
183
195
|
return detected
|
|
184
196
|
|
|
185
197
|
|
|
198
|
+
def _extract_env_vars_from_docker_args(args: list[str]) -> set[str]:
|
|
199
|
+
"""Extract environment variable names from docker run arguments.
|
|
200
|
+
|
|
201
|
+
Parses args like: ["run", "--rm", "-i", "-e", "API_KEY=value", "-e", "TOKEN", "image:tag"]
|
|
202
|
+
Returns set of env var names (not values).
|
|
203
|
+
"""
|
|
204
|
+
env_vars: set[str] = set()
|
|
205
|
+
i = 0
|
|
206
|
+
while i < len(args):
|
|
207
|
+
arg = args[i]
|
|
208
|
+
|
|
209
|
+
# Check for -e or --env flags
|
|
210
|
+
if arg in ("-e", "--env"):
|
|
211
|
+
if i + 1 < len(args):
|
|
212
|
+
env_spec = args[i + 1]
|
|
213
|
+
# Could be "KEY=value" or just "KEY"
|
|
214
|
+
var_name = env_spec.split("=", 1)[0].strip()
|
|
215
|
+
if var_name:
|
|
216
|
+
env_vars.add(var_name)
|
|
217
|
+
i += 2
|
|
218
|
+
continue
|
|
219
|
+
# Check for --env=KEY=value format
|
|
220
|
+
elif arg.startswith("--env="):
|
|
221
|
+
env_spec = arg[6:] # Remove "--env=" prefix
|
|
222
|
+
var_name = env_spec.split("=", 1)[0].strip()
|
|
223
|
+
if var_name:
|
|
224
|
+
env_vars.add(var_name)
|
|
225
|
+
|
|
226
|
+
i += 1
|
|
227
|
+
|
|
228
|
+
env_vars.discard("HUD_API_KEY")
|
|
229
|
+
return env_vars
|
|
230
|
+
|
|
231
|
+
|
|
232
|
+
def _extract_vars_from_task_configs(raw_tasks: list[dict[str, Any]]) -> set[str]:
|
|
233
|
+
"""Extract environment variable names from docker run commands in task mcp_configs."""
|
|
234
|
+
all_env_vars: set[str] = set()
|
|
235
|
+
|
|
236
|
+
for task in raw_tasks:
|
|
237
|
+
mcp_config = task.get("mcp_config", {})
|
|
238
|
+
|
|
239
|
+
# Iterate through all server configs
|
|
240
|
+
for server_config in mcp_config.values():
|
|
241
|
+
if not isinstance(server_config, dict):
|
|
242
|
+
continue
|
|
243
|
+
|
|
244
|
+
command = server_config.get("command", "")
|
|
245
|
+
args = server_config.get("args", [])
|
|
246
|
+
|
|
247
|
+
# Only process docker run commands
|
|
248
|
+
if command == "docker" and "run" in args:
|
|
249
|
+
env_vars = _extract_env_vars_from_docker_args(args)
|
|
250
|
+
all_env_vars.update(env_vars)
|
|
251
|
+
|
|
252
|
+
return all_env_vars
|
|
253
|
+
|
|
254
|
+
|
|
186
255
|
def convert_tasks_to_remote(tasks_file: str) -> str:
|
|
187
256
|
"""Convert a local tasks file to remote MCP tasks and return new filename.
|
|
188
257
|
|
|
@@ -297,12 +366,21 @@ def convert_tasks_to_remote(tasks_file: str) -> str:
|
|
|
297
366
|
hud_console.success(f"Updated {tasks_path.name} with latest image: {remote_image}")
|
|
298
367
|
return str(tasks_path)
|
|
299
368
|
|
|
300
|
-
# Extract
|
|
369
|
+
# Extract environment variables from multiple sources:
|
|
370
|
+
# 1. Lock file (authoritative for required env vars)
|
|
301
371
|
provided_keys = _extract_api_key_vars(lock_data)
|
|
372
|
+
|
|
373
|
+
# 2. Task configs (docker run -e flags)
|
|
374
|
+
task_env_vars = _extract_vars_from_task_configs(raw_tasks)
|
|
375
|
+
|
|
376
|
+
# 3. .env file (detect API-like vars)
|
|
302
377
|
dotenv_keys = _extract_dotenv_api_key_vars(env_dir)
|
|
303
378
|
|
|
304
|
-
#
|
|
305
|
-
|
|
379
|
+
# Combine: lock file vars + task config vars, then check for missing from .env
|
|
380
|
+
all_detected = provided_keys | task_env_vars
|
|
381
|
+
|
|
382
|
+
# If .env contains API-like vars not yet included, offer to add them
|
|
383
|
+
missing = sorted(dotenv_keys - all_detected)
|
|
306
384
|
if missing:
|
|
307
385
|
names_preview = ", ".join(missing)
|
|
308
386
|
prompt = (
|
|
@@ -310,7 +388,10 @@ def convert_tasks_to_remote(tasks_file: str) -> str:
|
|
|
310
388
|
"Include them as remote headers (values will be ${VAR} placeholders)?"
|
|
311
389
|
)
|
|
312
390
|
if hud_console.confirm(prompt, default=True):
|
|
313
|
-
|
|
391
|
+
all_detected.update(missing)
|
|
392
|
+
|
|
393
|
+
# Final set of env vars to convert to headers
|
|
394
|
+
provided_keys = all_detected
|
|
314
395
|
|
|
315
396
|
extra_api_key_headers: dict[str, str] = {}
|
|
316
397
|
for var_name in provided_keys:
|
|
@@ -364,10 +445,8 @@ def convert_tasks_to_remote(tasks_file: str) -> str:
|
|
|
364
445
|
item["setup_tool"] = _simplify_tool_call(t.setup_tool)
|
|
365
446
|
if t.evaluate_tool is not None:
|
|
366
447
|
item["evaluate_tool"] = _simplify_tool_call(t.evaluate_tool)
|
|
367
|
-
if t.
|
|
368
|
-
item["
|
|
369
|
-
if t.system_prompt is not None:
|
|
370
|
-
item["system_prompt"] = t.system_prompt
|
|
448
|
+
if t.agent_config is not None:
|
|
449
|
+
item["agent_config"] = t.agent_config
|
|
371
450
|
if t.metadata:
|
|
372
451
|
item["metadata"] = t.metadata
|
|
373
452
|
if t.id is not None:
|
hud/cli/init.py
CHANGED
|
@@ -29,9 +29,12 @@ SKIP_DIR_NAMES = {"node_modules", "__pycache__", "dist", "build", ".next", ".git
|
|
|
29
29
|
|
|
30
30
|
# Files that need placeholder replacement
|
|
31
31
|
PLACEHOLDER_FILES = {
|
|
32
|
-
"pyproject.toml",
|
|
32
|
+
"server/pyproject.toml",
|
|
33
|
+
"environment/pyproject.toml",
|
|
34
|
+
"server/main.py",
|
|
35
|
+
"server/README.md",
|
|
36
|
+
"environment/README.md",
|
|
33
37
|
"tasks.json",
|
|
34
|
-
"src/controller/server.py",
|
|
35
38
|
"test_env.ipynb",
|
|
36
39
|
"README.md",
|
|
37
40
|
}
|
|
@@ -48,7 +51,7 @@ def _replace_placeholders(target_dir: Path, env_name: str) -> list[str]:
|
|
|
48
51
|
List of files that were modified
|
|
49
52
|
"""
|
|
50
53
|
modified_files = []
|
|
51
|
-
placeholder = "
|
|
54
|
+
placeholder = "blank" # Placeholder used in blank environment template
|
|
52
55
|
|
|
53
56
|
# Normalize environment name for use in code/configs
|
|
54
57
|
# Replace spaces and special chars with underscores for Python identifiers
|
|
@@ -240,17 +243,18 @@ def create_environment(
|
|
|
240
243
|
f"Downloaded {len(files_created_dl)} files in {duration_ms} ms into {target_dir}"
|
|
241
244
|
)
|
|
242
245
|
|
|
243
|
-
# Replace placeholders in template files
|
|
244
|
-
|
|
245
|
-
|
|
246
|
-
|
|
247
|
-
|
|
248
|
-
|
|
249
|
-
|
|
250
|
-
|
|
251
|
-
|
|
252
|
-
|
|
253
|
-
|
|
246
|
+
# Replace placeholders in template files (only for blank preset)
|
|
247
|
+
if preset_normalized == "blank":
|
|
248
|
+
hud_console.section_title("Customizing template files")
|
|
249
|
+
modified_files = _replace_placeholders(target_dir, name)
|
|
250
|
+
if modified_files:
|
|
251
|
+
hud_console.success(f"Replaced placeholders in {len(modified_files)} files:")
|
|
252
|
+
for file in modified_files[:5]: # Show first 5 files
|
|
253
|
+
hud_console.status_item(file, "updated")
|
|
254
|
+
if len(modified_files) > 5:
|
|
255
|
+
hud_console.info(f"... and {len(modified_files) - 5} more files")
|
|
256
|
+
else:
|
|
257
|
+
hud_console.info("No placeholder replacements needed")
|
|
254
258
|
|
|
255
259
|
hud_console.section_title("Top-level files and folders")
|
|
256
260
|
for entry in sorted(os.listdir(target_dir)):
|
hud/cli/push.py
CHANGED
|
@@ -163,10 +163,7 @@ def push_environment(
|
|
|
163
163
|
lock_data = yaml.safe_load(f)
|
|
164
164
|
|
|
165
165
|
# Handle both old and new lock file formats
|
|
166
|
-
local_image = lock_data.get("image", "")
|
|
167
|
-
if not local_image and "build" in lock_data:
|
|
168
|
-
# New format might have image elsewhere
|
|
169
|
-
local_image = lock_data.get("image", "")
|
|
166
|
+
local_image = lock_data.get("images", {}).get("local") or lock_data.get("image", "")
|
|
170
167
|
|
|
171
168
|
# Get internal version from lock file
|
|
172
169
|
internal_version = lock_data.get("build", {}).get("version", None)
|
|
@@ -293,7 +290,7 @@ def push_environment(
|
|
|
293
290
|
# Push the image
|
|
294
291
|
hud_console.progress_message(f"Pushing {image} to registry...")
|
|
295
292
|
|
|
296
|
-
# Show push output
|
|
293
|
+
# Show push output (filtered for cleaner display)
|
|
297
294
|
process = subprocess.Popen( # noqa: S603
|
|
298
295
|
["docker", "push", image], # noqa: S607
|
|
299
296
|
stdout=subprocess.PIPE,
|
|
@@ -303,8 +300,27 @@ def push_environment(
|
|
|
303
300
|
errors="replace",
|
|
304
301
|
)
|
|
305
302
|
|
|
303
|
+
# Filter output to only show meaningful progress
|
|
304
|
+
layers_pushed = 0
|
|
306
305
|
for line in process.stdout or []:
|
|
307
|
-
|
|
306
|
+
line = line.rstrip()
|
|
307
|
+
# Only show: digest, pushed, mounted, or error lines
|
|
308
|
+
if any(
|
|
309
|
+
keyword in line.lower()
|
|
310
|
+
for keyword in ["digest:", "pushed", "mounted", "error", "denied"]
|
|
311
|
+
):
|
|
312
|
+
if "pushed" in line.lower():
|
|
313
|
+
layers_pushed += 1
|
|
314
|
+
if (
|
|
315
|
+
verbose
|
|
316
|
+
or "error" in line.lower()
|
|
317
|
+
or "denied" in line.lower()
|
|
318
|
+
or "digest:" in line.lower()
|
|
319
|
+
):
|
|
320
|
+
hud_console.info(line)
|
|
321
|
+
|
|
322
|
+
if layers_pushed > 0 and not verbose:
|
|
323
|
+
hud_console.info(f"Pushed {layers_pushed} layer(s)")
|
|
308
324
|
|
|
309
325
|
process.wait()
|
|
310
326
|
|
|
@@ -331,8 +347,10 @@ def push_environment(
|
|
|
331
347
|
hud_console.section_title("Pushed Image")
|
|
332
348
|
hud_console.status_item("Registry", pushed_digest, primary=True)
|
|
333
349
|
|
|
334
|
-
# Update the lock file with
|
|
335
|
-
|
|
350
|
+
# Update the lock file with pushed image reference
|
|
351
|
+
if "images" not in lock_data:
|
|
352
|
+
lock_data["images"] = {}
|
|
353
|
+
lock_data["images"]["pushed"] = image
|
|
336
354
|
|
|
337
355
|
# Add push information
|
|
338
356
|
from datetime import UTC, datetime
|
|
@@ -348,7 +366,7 @@ def push_environment(
|
|
|
348
366
|
with open(lock_path, "w") as f:
|
|
349
367
|
yaml.dump(lock_data, f, default_flow_style=False, sort_keys=False)
|
|
350
368
|
|
|
351
|
-
hud_console.success("Updated lock file with
|
|
369
|
+
hud_console.success("Updated lock file with pushed image reference")
|
|
352
370
|
|
|
353
371
|
# Upload lock file to HUD registry
|
|
354
372
|
try:
|
hud/cli/rl/local_runner.py
CHANGED
|
@@ -190,9 +190,9 @@ def run_local_training(
|
|
|
190
190
|
|
|
191
191
|
invalid_tasks: list[str] = []
|
|
192
192
|
for i, task in enumerate(tasks):
|
|
193
|
-
if not hasattr(task, "prompt") or not task.prompt:
|
|
193
|
+
if not hasattr(task, "prompt") or not task.prompt: # type: ignore
|
|
194
194
|
invalid_tasks.append(f"Task {i}: missing 'prompt' field")
|
|
195
|
-
if not hasattr(task, "mcp_config") or not task.mcp_config:
|
|
195
|
+
if not hasattr(task, "mcp_config") or not task.mcp_config: # type: ignore
|
|
196
196
|
invalid_tasks.append(f"Task {i}: missing 'mcp_config' field")
|
|
197
197
|
|
|
198
198
|
if invalid_tasks:
|
|
@@ -230,19 +230,33 @@ def run_local_training(
|
|
|
230
230
|
console.print("Enter the model name (HuggingFace ID):")
|
|
231
231
|
model = input().strip()
|
|
232
232
|
|
|
233
|
-
#
|
|
234
|
-
if
|
|
233
|
+
# try to get model from config file
|
|
234
|
+
if config_file:
|
|
235
|
+
console.print(f"\n[cyan]Loading configuration from: {config_file}[/cyan]")
|
|
236
|
+
config = load_config(config_file)
|
|
237
|
+
if hasattr(config, "model") and hasattr(config.model, "base_model"):
|
|
238
|
+
if model is None:
|
|
239
|
+
model = config.model.base_model
|
|
240
|
+
else:
|
|
241
|
+
console.print(
|
|
242
|
+
f"[yellow]Model already set to {model}, using that instead "
|
|
243
|
+
f"of {config.model.base_model}[/yellow] (override)"
|
|
244
|
+
)
|
|
245
|
+
|
|
246
|
+
if model is None:
|
|
247
|
+
console.print("[red]❌ No model specified either through CLI or config file[/red]")
|
|
235
248
|
try:
|
|
236
|
-
|
|
237
|
-
except ValueError as e:
|
|
238
|
-
console.print(f"\n[red]❌ {e}[/red]")
|
|
239
|
-
try:
|
|
240
|
-
import typer
|
|
249
|
+
import typer
|
|
241
250
|
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
|
|
245
|
-
|
|
251
|
+
raise typer.Exit(1)
|
|
252
|
+
except Exception:
|
|
253
|
+
return
|
|
254
|
+
|
|
255
|
+
# Validate model is a VL model (whether provided via CLI or selected)
|
|
256
|
+
try:
|
|
257
|
+
validate_vl_model(model)
|
|
258
|
+
except ValueError as e:
|
|
259
|
+
console.print(f"\n[red]❌ {e}[/red]")
|
|
246
260
|
try:
|
|
247
261
|
import typer
|
|
248
262
|
|
|
@@ -488,7 +502,6 @@ def run_local_training(
|
|
|
488
502
|
from .vllm import start_vllm_server, wait_for_vllm_server
|
|
489
503
|
|
|
490
504
|
start_vllm_server(config.model.base_model, vllm_gpu_idx, restart=restart)
|
|
491
|
-
|
|
492
505
|
server_ready = asyncio.run(wait_for_vllm_server())
|
|
493
506
|
if not server_ready:
|
|
494
507
|
console.print("[red]❌ Failed to start vLLM server[/red]")
|
|
@@ -507,7 +520,6 @@ def run_local_training(
|
|
|
507
520
|
f"\n[bold green]🎯 Starting DDP training on {len(training_gpus)} GPUs...[/bold green]\n"
|
|
508
521
|
)
|
|
509
522
|
launch_ddp_training(training_gpus, tasks_file, temp_config_path, verbose)
|
|
510
|
-
console.print("\n[green]✅ Training completed successfully![/green]")
|
|
511
523
|
else:
|
|
512
524
|
console.print("\n[bold green]🎯 Starting single-GPU training...[/bold green]\n")
|
|
513
525
|
try:
|
|
@@ -518,7 +530,7 @@ def run_local_training(
|
|
|
518
530
|
# Import and run the async training function lazily
|
|
519
531
|
from hud.rl.train import train # heavy import
|
|
520
532
|
|
|
521
|
-
asyncio.run(train(config, tasks))
|
|
533
|
+
asyncio.run(train(config, tasks)) # type: ignore
|
|
522
534
|
console.print("\n[green]✅ Training completed successfully![/green]")
|
|
523
535
|
|
|
524
536
|
try:
|
hud/cli/rl/vllm.py
CHANGED
|
@@ -165,6 +165,8 @@ async def wait_for_vllm_server(timeout: int = 360) -> bool: # noqa: ASYNC109
|
|
|
165
165
|
if response.status_code == 200:
|
|
166
166
|
console.print("[green]✅ vLLM server is ready![/green]")
|
|
167
167
|
return True
|
|
168
|
+
except httpx.ConnectError:
|
|
169
|
+
pass
|
|
168
170
|
except Exception as e:
|
|
169
171
|
hud_console.error(f"Failed to connect to vLLM server: {e}")
|
|
170
172
|
|
|
@@ -214,6 +214,7 @@ class TestAnalyzeFromMetadata:
|
|
|
214
214
|
|
|
215
215
|
@mock.patch("hud.cli.utils.metadata.check_local_cache")
|
|
216
216
|
@mock.patch("hud.cli.utils.metadata.fetch_lock_from_registry")
|
|
217
|
+
@mock.patch("hud.cli.utils.metadata.hud_console")
|
|
217
218
|
@mock.patch("hud.cli.utils.metadata.console")
|
|
218
219
|
async def test_analyze_not_found(self, mock_console, mock_hud_console, mock_fetch, mock_check):
|
|
219
220
|
"""Test when environment not found anywhere."""
|
|
@@ -222,9 +223,9 @@ class TestAnalyzeFromMetadata:
|
|
|
222
223
|
|
|
223
224
|
await analyze_from_metadata("test/notfound:latest", "json", verbose=False)
|
|
224
225
|
|
|
225
|
-
# Should show error
|
|
226
|
+
# Should show error via hud_console
|
|
226
227
|
mock_hud_console.error.assert_called_with("Environment metadata not found")
|
|
227
|
-
# Should print suggestions
|
|
228
|
+
# Should print suggestions via console
|
|
228
229
|
mock_console.print.assert_called()
|
|
229
230
|
|
|
230
231
|
@mock.patch("hud.cli.utils.metadata.check_local_cache")
|