hud-python 0.4.45__py3-none-any.whl → 0.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hud/__init__.py +27 -7
- hud/agents/__init__.py +11 -5
- hud/agents/base.py +220 -500
- hud/agents/claude.py +200 -240
- hud/agents/gemini.py +275 -0
- hud/agents/gemini_cua.py +335 -0
- hud/agents/grounded_openai.py +98 -100
- hud/agents/misc/integration_test_agent.py +51 -20
- hud/agents/misc/response_agent.py +41 -36
- hud/agents/openai.py +291 -292
- hud/agents/{openai_chat_generic.py → openai_chat.py} +80 -34
- hud/agents/operator.py +211 -0
- hud/agents/tests/conftest.py +133 -0
- hud/agents/tests/test_base.py +300 -622
- hud/agents/tests/test_base_runtime.py +233 -0
- hud/agents/tests/test_claude.py +379 -210
- hud/agents/tests/test_client.py +9 -10
- hud/agents/tests/test_gemini.py +369 -0
- hud/agents/tests/test_grounded_openai_agent.py +65 -50
- hud/agents/tests/test_openai.py +376 -140
- hud/agents/tests/test_operator.py +362 -0
- hud/agents/tests/test_run_eval.py +179 -0
- hud/cli/__init__.py +461 -545
- hud/cli/analyze.py +43 -5
- hud/cli/build.py +664 -110
- hud/cli/debug.py +8 -5
- hud/cli/dev.py +882 -734
- hud/cli/eval.py +782 -668
- hud/cli/flows/dev.py +167 -0
- hud/cli/flows/init.py +191 -0
- hud/cli/flows/tasks.py +153 -56
- hud/cli/flows/templates.py +151 -0
- hud/cli/flows/tests/__init__.py +1 -0
- hud/cli/flows/tests/test_dev.py +126 -0
- hud/cli/init.py +60 -58
- hud/cli/push.py +29 -11
- hud/cli/rft.py +311 -0
- hud/cli/rft_status.py +145 -0
- hud/cli/tests/test_analyze.py +5 -5
- hud/cli/tests/test_analyze_metadata.py +3 -2
- hud/cli/tests/test_analyze_module.py +120 -0
- hud/cli/tests/test_build.py +108 -6
- hud/cli/tests/test_build_failure.py +41 -0
- hud/cli/tests/test_build_module.py +50 -0
- hud/cli/tests/test_cli_init.py +6 -1
- hud/cli/tests/test_cli_more_wrappers.py +30 -0
- hud/cli/tests/test_cli_root.py +140 -0
- hud/cli/tests/test_convert.py +361 -0
- hud/cli/tests/test_debug.py +12 -10
- hud/cli/tests/test_dev.py +197 -0
- hud/cli/tests/test_eval.py +251 -0
- hud/cli/tests/test_eval_bedrock.py +51 -0
- hud/cli/tests/test_init.py +124 -0
- hud/cli/tests/test_main_module.py +11 -5
- hud/cli/tests/test_mcp_server.py +12 -100
- hud/cli/tests/test_push_happy.py +74 -0
- hud/cli/tests/test_push_wrapper.py +23 -0
- hud/cli/tests/test_registry.py +1 -1
- hud/cli/tests/test_utils.py +1 -1
- hud/cli/{rl → utils}/celebrate.py +14 -12
- hud/cli/utils/config.py +18 -1
- hud/cli/utils/docker.py +130 -4
- hud/cli/utils/env_check.py +9 -9
- hud/cli/utils/git.py +136 -0
- hud/cli/utils/interactive.py +39 -5
- hud/cli/utils/metadata.py +69 -0
- hud/cli/utils/runner.py +1 -1
- hud/cli/utils/server.py +2 -2
- hud/cli/utils/source_hash.py +3 -3
- hud/cli/utils/tasks.py +4 -1
- hud/cli/utils/tests/__init__.py +0 -0
- hud/cli/utils/tests/test_config.py +58 -0
- hud/cli/utils/tests/test_docker.py +93 -0
- hud/cli/utils/tests/test_docker_hints.py +71 -0
- hud/cli/utils/tests/test_env_check.py +74 -0
- hud/cli/utils/tests/test_environment.py +42 -0
- hud/cli/utils/tests/test_git.py +142 -0
- hud/cli/utils/tests/test_interactive_module.py +60 -0
- hud/cli/utils/tests/test_local_runner.py +50 -0
- hud/cli/utils/tests/test_logging_utils.py +23 -0
- hud/cli/utils/tests/test_metadata.py +49 -0
- hud/cli/utils/tests/test_package_runner.py +35 -0
- hud/cli/utils/tests/test_registry_utils.py +49 -0
- hud/cli/utils/tests/test_remote_runner.py +25 -0
- hud/cli/utils/tests/test_runner_modules.py +52 -0
- hud/cli/utils/tests/test_source_hash.py +36 -0
- hud/cli/utils/tests/test_tasks.py +80 -0
- hud/cli/utils/version_check.py +258 -0
- hud/cli/{rl → utils}/viewer.py +2 -2
- hud/clients/README.md +12 -11
- hud/clients/__init__.py +4 -3
- hud/clients/base.py +166 -26
- hud/clients/environment.py +51 -0
- hud/clients/fastmcp.py +13 -6
- hud/clients/mcp_use.py +40 -15
- hud/clients/tests/test_analyze_scenarios.py +206 -0
- hud/clients/tests/test_protocol.py +9 -3
- hud/datasets/__init__.py +23 -20
- hud/datasets/loader.py +327 -0
- hud/datasets/runner.py +192 -105
- hud/datasets/tests/__init__.py +0 -0
- hud/datasets/tests/test_loader.py +221 -0
- hud/datasets/tests/test_utils.py +315 -0
- hud/datasets/utils.py +270 -90
- hud/environment/__init__.py +50 -0
- hud/environment/connection.py +206 -0
- hud/environment/connectors/__init__.py +33 -0
- hud/environment/connectors/base.py +68 -0
- hud/environment/connectors/local.py +177 -0
- hud/environment/connectors/mcp_config.py +109 -0
- hud/environment/connectors/openai.py +101 -0
- hud/environment/connectors/remote.py +172 -0
- hud/environment/environment.py +694 -0
- hud/environment/integrations/__init__.py +45 -0
- hud/environment/integrations/adk.py +67 -0
- hud/environment/integrations/anthropic.py +196 -0
- hud/environment/integrations/gemini.py +92 -0
- hud/environment/integrations/langchain.py +82 -0
- hud/environment/integrations/llamaindex.py +68 -0
- hud/environment/integrations/openai.py +238 -0
- hud/environment/mock.py +306 -0
- hud/environment/router.py +112 -0
- hud/environment/scenarios.py +493 -0
- hud/environment/tests/__init__.py +1 -0
- hud/environment/tests/test_connection.py +317 -0
- hud/environment/tests/test_connectors.py +218 -0
- hud/environment/tests/test_environment.py +161 -0
- hud/environment/tests/test_integrations.py +257 -0
- hud/environment/tests/test_local_connectors.py +201 -0
- hud/environment/tests/test_scenarios.py +280 -0
- hud/environment/tests/test_tools.py +208 -0
- hud/environment/types.py +23 -0
- hud/environment/utils/__init__.py +35 -0
- hud/environment/utils/formats.py +215 -0
- hud/environment/utils/schema.py +171 -0
- hud/environment/utils/tool_wrappers.py +113 -0
- hud/eval/__init__.py +67 -0
- hud/eval/context.py +674 -0
- hud/eval/display.py +299 -0
- hud/eval/instrument.py +185 -0
- hud/eval/manager.py +466 -0
- hud/eval/parallel.py +268 -0
- hud/eval/task.py +340 -0
- hud/eval/tests/__init__.py +1 -0
- hud/eval/tests/test_context.py +178 -0
- hud/eval/tests/test_eval.py +210 -0
- hud/eval/tests/test_manager.py +152 -0
- hud/eval/tests/test_parallel.py +168 -0
- hud/eval/tests/test_task.py +145 -0
- hud/eval/types.py +63 -0
- hud/eval/utils.py +183 -0
- hud/patches/__init__.py +19 -0
- hud/patches/mcp_patches.py +151 -0
- hud/patches/warnings.py +54 -0
- hud/samples/browser.py +4 -4
- hud/server/__init__.py +2 -1
- hud/server/low_level.py +2 -1
- hud/server/router.py +164 -0
- hud/server/server.py +567 -80
- hud/server/tests/test_mcp_server_integration.py +11 -11
- hud/server/tests/test_mcp_server_more.py +1 -1
- hud/server/tests/test_server_extra.py +2 -0
- hud/settings.py +45 -3
- hud/shared/exceptions.py +36 -10
- hud/shared/hints.py +26 -1
- hud/shared/requests.py +15 -3
- hud/shared/tests/test_exceptions.py +40 -31
- hud/shared/tests/test_hints.py +167 -0
- hud/telemetry/__init__.py +20 -19
- hud/telemetry/exporter.py +201 -0
- hud/telemetry/instrument.py +158 -253
- hud/telemetry/tests/test_eval_telemetry.py +356 -0
- hud/telemetry/tests/test_exporter.py +258 -0
- hud/telemetry/tests/test_instrument.py +401 -0
- hud/tools/__init__.py +16 -2
- hud/tools/apply_patch.py +639 -0
- hud/tools/base.py +54 -4
- hud/tools/bash.py +2 -2
- hud/tools/computer/__init__.py +4 -0
- hud/tools/computer/anthropic.py +2 -2
- hud/tools/computer/gemini.py +385 -0
- hud/tools/computer/hud.py +23 -6
- hud/tools/computer/openai.py +20 -21
- hud/tools/computer/qwen.py +434 -0
- hud/tools/computer/settings.py +37 -0
- hud/tools/edit.py +3 -7
- hud/tools/executors/base.py +4 -2
- hud/tools/executors/pyautogui.py +1 -1
- hud/tools/grounding/grounded_tool.py +13 -18
- hud/tools/grounding/grounder.py +10 -31
- hud/tools/grounding/tests/test_grounded_tool.py +26 -44
- hud/tools/jupyter.py +330 -0
- hud/tools/playwright.py +18 -3
- hud/tools/shell.py +308 -0
- hud/tools/tests/test_apply_patch.py +718 -0
- hud/tools/tests/test_computer.py +4 -9
- hud/tools/tests/test_computer_actions.py +24 -2
- hud/tools/tests/test_jupyter_tool.py +181 -0
- hud/tools/tests/test_shell.py +596 -0
- hud/tools/tests/test_submit.py +85 -0
- hud/tools/tests/test_types.py +193 -0
- hud/tools/types.py +21 -1
- hud/types.py +167 -57
- hud/utils/__init__.py +2 -0
- hud/utils/env.py +67 -0
- hud/utils/hud_console.py +61 -3
- hud/utils/mcp.py +15 -58
- hud/utils/strict_schema.py +162 -0
- hud/utils/tests/test_init.py +1 -2
- hud/utils/tests/test_mcp.py +1 -28
- hud/utils/tests/test_pretty_errors.py +186 -0
- hud/utils/tests/test_tool_shorthand.py +154 -0
- hud/utils/tests/test_version.py +1 -1
- hud/utils/types.py +20 -0
- hud/version.py +1 -1
- hud_python-0.5.1.dist-info/METADATA +264 -0
- hud_python-0.5.1.dist-info/RECORD +299 -0
- {hud_python-0.4.45.dist-info → hud_python-0.5.1.dist-info}/WHEEL +1 -1
- hud/agents/langchain.py +0 -261
- hud/agents/lite_llm.py +0 -72
- hud/cli/rl/__init__.py +0 -180
- hud/cli/rl/config.py +0 -101
- hud/cli/rl/display.py +0 -133
- hud/cli/rl/gpu.py +0 -63
- hud/cli/rl/gpu_utils.py +0 -321
- hud/cli/rl/local_runner.py +0 -595
- hud/cli/rl/presets.py +0 -96
- hud/cli/rl/remote_runner.py +0 -463
- hud/cli/rl/rl_api.py +0 -150
- hud/cli/rl/vllm.py +0 -177
- hud/cli/rl/wait_utils.py +0 -89
- hud/datasets/parallel.py +0 -687
- hud/misc/__init__.py +0 -1
- hud/misc/claude_plays_pokemon.py +0 -292
- hud/otel/__init__.py +0 -35
- hud/otel/collector.py +0 -142
- hud/otel/config.py +0 -181
- hud/otel/context.py +0 -570
- hud/otel/exporters.py +0 -369
- hud/otel/instrumentation.py +0 -135
- hud/otel/processors.py +0 -121
- hud/otel/tests/__init__.py +0 -1
- hud/otel/tests/test_processors.py +0 -197
- hud/rl/README.md +0 -30
- hud/rl/__init__.py +0 -1
- hud/rl/actor.py +0 -176
- hud/rl/buffer.py +0 -405
- hud/rl/chat_template.jinja +0 -101
- hud/rl/config.py +0 -192
- hud/rl/distributed.py +0 -132
- hud/rl/learner.py +0 -637
- hud/rl/tests/__init__.py +0 -1
- hud/rl/tests/test_learner.py +0 -186
- hud/rl/train.py +0 -382
- hud/rl/types.py +0 -101
- hud/rl/utils/start_vllm_server.sh +0 -30
- hud/rl/utils.py +0 -524
- hud/rl/vllm_adapter.py +0 -143
- hud/telemetry/job.py +0 -352
- hud/telemetry/replay.py +0 -74
- hud/telemetry/tests/test_replay.py +0 -40
- hud/telemetry/tests/test_trace.py +0 -63
- hud/telemetry/trace.py +0 -158
- hud/utils/agent_factories.py +0 -86
- hud/utils/async_utils.py +0 -65
- hud/utils/group_eval.py +0 -223
- hud/utils/progress.py +0 -149
- hud/utils/tasks.py +0 -127
- hud/utils/tests/test_async_utils.py +0 -173
- hud/utils/tests/test_progress.py +0 -261
- hud_python-0.4.45.dist-info/METADATA +0 -552
- hud_python-0.4.45.dist-info/RECORD +0 -228
- {hud_python-0.4.45.dist-info → hud_python-0.5.1.dist-info}/entry_points.txt +0 -0
- {hud_python-0.4.45.dist-info → hud_python-0.5.1.dist-info}/licenses/LICENSE +0 -0
hud/cli/rl/local_runner.py
DELETED
@@ -1,595 +0,0 @@

```python
"""
Local runner for HUD RL training.

This module encapsulates the local training flow and imports heavy
dependencies (torch, transformers, etc.) only when actually running
locally. The CLI entrypoint should import this module lazily to avoid
pulling heavy deps during remote-only usage.
"""

from __future__ import annotations

import asyncio
import os
import subprocess
import sys
from pathlib import Path

from rich.console import Console

from hud.rl.config import validate_vl_model
from hud.utils.hud_console import hud_console
from hud.utils.tasks import load_tasks

console = Console()


def run_local_training(
    *,
    tasks_file: str,
    model: str | None,
    config_file: Path | None,
    output_dir: str,
    yes: bool,
    restart: bool,
    verbose: bool,
    no_ddp: bool,
    ddp_gpus: str | None,
    vllm_gpu: int | None,
    skip_vllm_startup: bool,
) -> None:
    """Run RL training locally on the current machine.

    Heavy modules are imported inside this function to avoid import-time side effects
    during remote-only runs.
    """
    # Light-weight utilities
    from .config import generate_config_interactive, load_config, save_config
    from .display import display_config_summary, display_gpu_info
    from .gpu import detect_cuda_devices, validate_gpu_memory
    from .presets import get_training_presets

    # Python version compatibility warning for vLLM
    python_version = sys.version_info
    if python_version.major == 3 and python_version.minor >= 13:
        console.print("[red]⚠️ Warning: Python 3.13+ detected![/red]")
        console.print("[yellow]vLLM has compatibility issues with Python 3.13.[/yellow]")
        console.print("[yellow]Recommended: Use Python 3.12 or 3.11[/yellow]")
        console.print("\n[dim]To create a new environment with Python 3.12:[/dim]")
        console.print("[dim] 1. Exit this shell: exit[/dim]")
        console.print("[dim] 2. Remove current venv: sudo rm -rf .venv[/dim]")
        console.print("[dim] 3. Create new venv: uv venv --python 3.12[/dim]")
        console.print("[dim] 4. Install dependencies: uv pip install -e '.[rl]'[/dim]")

        try:
            import typer

            if not yes:
                if not typer.confirm("\nDo you want to continue anyway?", default=False):
                    raise typer.Exit(1)
            else:
                hud_console.warning("Auto-continuing despite Python 3.13+ (--yes mode)")
        except Exception as e:
            hud_console.warning(f"Failed to confirm: {e}")
            return

    # Step 1: Validate CUDA devices
    console.print("[yellow]Checking GPU availability...[/yellow]")
    gpu_info = detect_cuda_devices()

    if not gpu_info["available"]:
        console.print(f"[red]❌ {gpu_info['error']}[/red]")
        console.print("[yellow]RL training requires CUDA-capable GPUs[/yellow]")
        try:
            import typer

            raise typer.Exit(1)
        except Exception:
            return

    display_gpu_info(gpu_info)

    # Perform GPU health check (imports torch lazily)
    all_gpu_indices = [device["index"] for device in gpu_info["devices"]]
    from .gpu_utils import health_check_gpus  # heavy import (torch)

    health_results = health_check_gpus(all_gpu_indices)

    if not health_results["all_healthy"]:
        console.print("\n[yellow]⚠️ Some GPUs failed health checks![/yellow]")
        console.print(
            f"[yellow]Unhealthy GPUs: {list(health_results['unhealthy_gpus'].keys())}[/yellow]"
        )

        if not health_results["healthy_gpus"]:
            console.print("[red]❌ No healthy GPUs available for training![/red]")
            try:
                import typer

                raise typer.Exit(1)
            except Exception:
                return

        console.print(
            f"\n[cyan]You have {len(health_results['healthy_gpus'])} healthy GPUs available.[/cyan]"
        )

        try:
            import typer

            if yes:
                continue_training = True
                hud_console.info("Auto-continuing with healthy GPUs only (--yes mode)")
            else:
                continue_training = typer.confirm(
                    "\nContinue with healthy GPUs only?", default=True
                )
        except Exception:
            continue_training = True

        if not continue_training:
            healthy_str = ",".join(map(str, health_results["healthy_gpus"]))
            console.print("\n[yellow]Exiting. Please resolve GPU issues and try again.[/yellow]")
            console.print("\n[cyan]💡 Tip: To use only healthy GPUs, you can run:[/cyan]")
            console.print(f"[white]hud rl {tasks_file} --ddp-gpus {healthy_str} --local[/white]\n")
            try:
                import typer

                raise typer.Exit(0)
            except Exception:
                return
        else:
            # Continue with healthy GPUs only
            gpu_info["devices"] = [
                d for d in gpu_info["devices"] if d["index"] in health_results["healthy_gpus"]
            ]
            console.print(
                f"\n[green]✅ Continuing with {len(gpu_info['devices'])} healthy GPUs[/green]"
            )

    # Get primary GPU memory for configuration
    primary_gpu = gpu_info["devices"][0]
    gpu_memory_gb = primary_gpu["memory_gb"]

    # Validate GPU memory for 3B model
    if not validate_gpu_memory(gpu_memory_gb, "3B"):
        console.print(f"[red]❌ Insufficient GPU memory ({gpu_memory_gb:.1f} GB)[/red]")
        console.print("[yellow]Qwen 2.5 VL 3B requires at least 12 GB of GPU memory[/yellow]")
        try:
            import typer

            raise typer.Exit(1)
        except Exception:
            return

    # Step 2: Load and validate tasks
    if tasks_file:
        console.print(f"\n[cyan]Loading tasks from: {tasks_file}[/cyan]")
    else:
        possible_files = ["tasks.json", "tasks.jsonl", "browser_2048_tasks.jsonl"]
        for f in possible_files:
            if Path(f).exists():
                tasks_file = f
                console.print(f"[green]Auto-detected tasks file: {f}[/green]")
                break

    if not tasks_file:
        console.print("[red]❌ No tasks file specified or auto-detected[/red]")
        console.print(
            "[yellow]Please provide a tasks file or create one of: tasks.json, tasks.jsonl[/yellow]"  # noqa: E501
        )
        try:
            import typer

            raise typer.Exit(1)
        except Exception:
            return

    tasks = load_tasks(tasks_file)
    console.print(f"[green]✅ Loaded {len(tasks)} tasks[/green]")

    invalid_tasks: list[str] = []
    for i, task in enumerate(tasks):
        if not hasattr(task, "prompt") or not task.prompt:
            invalid_tasks.append(f"Task {i}: missing 'prompt' field")
        if not hasattr(task, "mcp_config") or not task.mcp_config:
            invalid_tasks.append(f"Task {i}: missing 'mcp_config' field")

    if invalid_tasks:
        console.print("[red]❌ Invalid tasks found:[/red]")
        for error in invalid_tasks[:5]:
            console.print(f" - {error}")
        if len(invalid_tasks) > 5:
            console.print(f" ... and {len(invalid_tasks) - 5} more")
        try:
            import typer

            raise typer.Exit(1)
        except Exception:
            return

    # Step 3: Model selection (if not provided)
    if model is None and not config_file:
        if yes:
            model = "Qwen/Qwen2.5-VL-3B-Instruct"  # Default model in yes mode
            hud_console.info(f"Auto-selecting model: {model} (--yes mode)")
        else:
            model = hud_console.select(
                "Select a model for RL training:",
                choices=[
                    {
                        "name": "Qwen 2.5 VL 3B (Recommended - Vision-Language)",
                        "value": "Qwen/Qwen2.5-VL-3B-Instruct",
                    },
                    {"name": "Custom model", "value": "custom"},
                ],
                default=0,
            )

            if model == "custom":
                console.print("Enter the model name (HuggingFace ID):")
                model = input().strip()

    # Validate model is a VL model (whether provided via CLI or selected)
    if model:
        try:
            validate_vl_model(model)
        except ValueError as e:
            console.print(f"\n[red]❌ {e}[/red]")
            try:
                import typer

                raise typer.Exit(1)
            except Exception:
                return
    else:
        try:
            import typer

            raise typer.Exit(1)
        except Exception:
            return

    # Step 4: Generate or load configuration
    if config_file:
        console.print(f"\n[cyan]Loading configuration from: {config_file}[/cyan]")
        config = load_config(config_file)

        # Validate model from config
        if hasattr(config, "model") and hasattr(config.model, "base_model"):
            try:
                validate_vl_model(config.model.base_model)
            except ValueError as e:
                console.print(f"\n[red]❌ {e}[/red]")
                try:
                    import typer

                    raise typer.Exit(1)
                except Exception:
                    return

        # Estimate memory for display
        from .presets import estimate_memory_usage

        estimated_memory = estimate_memory_usage(
            config.training.mini_batch_size,
            config.actor.max_steps_per_episode,
            config.actor.max_new_tokens,
            config.model.max_pixels,
        )
    else:
        console.print("\n[cyan]Generating training configuration...[/cyan]")
        # Get number of GPUs for preset scaling
        num_training_gpus = 1  # Default, will be adjusted later
        if len(gpu_info["devices"]) > 2:
            num_training_gpus = len(gpu_info["devices"]) - 1  # Reserve 1 for vLLM
            console.print(
                f"[yellow]Note: Episodes will be scaled for {num_training_gpus} training GPUs[/yellow]\n"  # noqa: E501
            )

        presets = get_training_presets(gpu_memory_gb)
        config, estimated_memory = generate_config_interactive(
            model_name=model,
            presets=presets,
            yes=yes,
        )

    # Step 5: Save temporary config and display summary
    temp_config_path = Path(".rl_config_temp.json")
    save_config(config, temp_config_path)
    console.print(f"\n[cyan]📝 Configuration saved to: {temp_config_path.absolute()}[/cyan]")
    console.print("[yellow]You can edit this file before starting training.[/yellow]")

    # Display configuration summary
    display_config_summary(config, len(tasks), gpu_info, estimated_memory)

    # Step 6: Ask for confirmation (skip if config was provided or in yes mode)
    if not config_file and not yes:
        console.print("\n[bold yellow]Options:[/bold yellow]")
        console.print(" • Type [green]'start'[/green] to begin training")
        console.print(" • Type [cyan]'edit'[/cyan] to open config in your editor")
        console.print(" • Type [red]'cancel'[/red] to abort")
        console.print("\n[bold]Your choice:[/bold] ", end="")

        while True:
            choice = input().strip().lower()

            if choice == "start":
                config = load_config(temp_config_path)  # Reload config in case it was edited
                break
            elif choice == "edit":
                editor = os.environ.get("EDITOR", "nano")

                if editor == "nano":
                    console.print("\n[cyan]Opening config in nano editor...[/cyan]")
                    console.print("[yellow]Tips:[/yellow]")
                    console.print(" • Edit the configuration values as needed")
                    console.print(" • Press [bold]Ctrl+O[/bold] then [bold]Enter[/bold] to save")
                    console.print(" • Press [bold]Ctrl+X[/bold] to exit")
                    console.print(" • Press [bold]Ctrl+C[/bold] to cancel without saving\n")
                    input("Press Enter to continue...")

                try:
                    subprocess.run([editor, str(temp_config_path)], check=True)  # noqa: S603
                    # Reload and display updated config
                    config = load_config(temp_config_path)
                    from .presets import estimate_memory_usage as _estimate_memory

                    estimated_memory = _estimate_memory(
                        config.training.mini_batch_size,
                        config.actor.max_steps_per_episode,
                        config.actor.max_new_tokens,
                        config.model.max_pixels,
                    )
                    display_config_summary(config, len(tasks), gpu_info, estimated_memory)
                    console.print(
                        "\n[bold]Type 'start' to begin or 'cancel' to abort:[/bold] ", end=""
                    )
                except subprocess.CalledProcessError:
                    console.print(
                        "\n[yellow]Editor closed without saving or was cancelled.[/yellow]"
                    )
                    console.print("[bold]Your choice:[/bold] ", end="")
                except Exception as e:
                    console.print(f"\n[red]Failed to open editor: {e}[/red]")
                    console.print(
                        f"[yellow]Please edit {temp_config_path} manually and type 'start' when ready.[/yellow]"  # noqa: E501
                    )
                    console.print("[bold]Your choice:[/bold] ", end="")
            elif choice == "cancel":
                console.print("[red]Training cancelled[/red]")
                try:
                    import typer

                    if yes:
                        # Always save in yes mode
                        config_path = Path("rl_config.json")
                        save_config(config, config_path)
                        hud_console.info("Auto-saved configuration (--yes mode)")
                    elif typer.confirm("Save this configuration for later?", default=True):
                        config_path = Path("rl_config.json")
                        save_config(config, config_path)
                except Exception as e:
                    hud_console.warning(f"Failed to save config: {e}")

                try:
                    temp_config_path.unlink()
                except Exception as e:
                    hud_console.warning(f"Failed to clean up temp config: {e}")

                try:
                    import typer

                    raise typer.Exit(0)
                except Exception:
                    return
            else:
                console.print(
                    "[red]Invalid choice. Type 'start', 'edit', or 'cancel':[/red] ", end=""
                )
    elif yes:
        # In yes mode, auto-start training
        hud_console.info("Auto-starting training (--yes mode)")
        config = load_config(temp_config_path)
    else:
        console.print("\n[dim]Using provided configuration file...[/dim]")
        config = load_config(temp_config_path)

    # Step 7: Determine if DDP should be used (imports heavy helpers lazily)
    num_gpus = len(gpu_info["devices"])
    use_ddp = False
    training_gpus = [0]  # Default single GPU
    vllm_gpu_idx = 1 if num_gpus > 1 else 0

    if num_gpus > 2 and not no_ddp:
        console.print(f"\n[cyan]🚀 Detected {num_gpus} GPUs - checking DDP configuration...[/cyan]")

        from .gpu_utils import calculate_optimal_gpu_allocation  # heavy import (torch at module)

        gpu_allocation = calculate_optimal_gpu_allocation(gpu_info, config)

        if gpu_allocation["use_ddp"]:
            use_ddp = True
            training_gpus = gpu_allocation["training_gpus"]
            vllm_gpu_idx = gpu_allocation["vllm_gpu"]

            console.print(
                f"[green]✅ Will use DDP with {len(training_gpus)} GPUs for training[/green]"
            )
            console.print(f"[green]✅ GPU {vllm_gpu_idx} reserved for vLLM server[/green]")

            console.print("\n[cyan]Training Configuration:[/cyan]")
            console.print(f" • Groups to process: {gpu_allocation['num_groups']}")
            console.print(f" • Training GPUs: {training_gpus}")
            console.print(f" • Groups per GPU: {gpu_allocation.get('groups_per_gpu', 'N/A'):.1f}")

            if gpu_allocation.get("parallel_efficiency", 1.0) < 0.8:
                console.print(
                    f"\n[yellow]⚠️ GPU efficiency: {gpu_allocation['parallel_efficiency'] * 100:.0f}%[/yellow]"  # noqa: E501
                )
                console.print(
                    f"[yellow]Consider adjusting batch_size to {len(training_gpus) * config.training.group_size} for optimal performance[/yellow]"  # noqa: E501
                )
        else:
            console.print(f"[cyan]{gpu_allocation.get('reason', 'Using single GPU')}[/cyan]")

    # Allow manual overrides
    if ddp_gpus is not None:
        requested_gpus = [int(x) for x in ddp_gpus.split(",")]
        console.print(f"[cyan]Manual GPU selection: {requested_gpus}[/cyan]")
        available_indices = [d["index"] for d in gpu_info["devices"]]
        invalid_gpus = [g for g in requested_gpus if g not in available_indices]
        if invalid_gpus:
            console.print(f"[red]❌ Invalid/unhealthy GPU(s) requested: {invalid_gpus}[/red]")
            console.print(f"[yellow]Available healthy GPUs: {available_indices}[/yellow]")
            try:
                import typer

                raise typer.Exit(1)
            except Exception:
                return
        training_gpus = requested_gpus
        use_ddp = len(training_gpus) > 1

    if vllm_gpu is not None:
        vllm_gpu_idx = vllm_gpu
        console.print(f"[cyan]Manual vLLM GPU: {vllm_gpu_idx}[/cyan]")
        available_indices = [d["index"] for d in gpu_info["devices"]]
        if vllm_gpu_idx not in available_indices:
            console.print(f"[red]❌ vLLM GPU {vllm_gpu_idx} is not available/healthy![/red]")
            console.print(f"[yellow]Available healthy GPUs: {available_indices}[/yellow]")
            try:
                import typer

                raise typer.Exit(1)
            except Exception:
                return

    # Ensure we have at least one training GPU
    if not training_gpus:
        console.print("[red]❌ No available GPUs for training![/red]")
        try:
            import typer

            raise typer.Exit(1)
        except Exception:
            return

    # Always adjust batch_size based on number of training GPUs (lazy import)
    from .gpu_utils import adjust_config_for_ddp  # heavy import (torch at module)

    config = adjust_config_for_ddp(config, len(training_gpus))
    save_config(config, temp_config_path)

    # Step 8: Start vLLM server (unless we're using a remote one)
    if not skip_vllm_startup:
        console.print(f"\n[cyan]Setting up vLLM server on GPU {vllm_gpu_idx}...[/cyan]")

        from .vllm import start_vllm_server, wait_for_vllm_server

        start_vllm_server(config.model.base_model, vllm_gpu_idx, restart=restart)

        server_ready = asyncio.run(wait_for_vllm_server())
        if not server_ready:
            console.print("[red]❌ Failed to start vLLM server[/red]")
            try:
                import typer

                raise typer.Exit(1)
            except Exception:
                return
    else:
        console.print("\n[cyan]Using remote vLLM server (skipping local startup)[/cyan]")

    # Step 9: Run training (DDP or single GPU)
    if use_ddp:
        console.print(
            f"\n[bold green]🎯 Starting DDP training on {len(training_gpus)} GPUs...[/bold green]\n"
        )
        launch_ddp_training(training_gpus, tasks_file, temp_config_path, verbose)
        console.print("\n[green]✅ Training completed successfully![/green]")
    else:
        console.print("\n[bold green]🎯 Starting single-GPU training...[/bold green]\n")
        try:
            # Set verbose in config instead of passing as parameter
            if verbose:
                config.verbose = True

            # Import and run the async training function lazily
            from hud.rl.train import train  # heavy import

            asyncio.run(train(config, tasks))
            console.print("\n[green]✅ Training completed successfully![/green]")

            try:
                temp_config_path.unlink()
            except Exception as e:
                hud_console.warning(f"Failed to clean up temp config: {e}")

        except KeyboardInterrupt:
            console.print("\n[yellow]Training interrupted by user[/yellow]")
            try:
                import typer

                raise typer.Exit(1)
            except Exception:
                return
        except Exception as e:
            console.print(f"\n[red]❌ Training failed: {e}")
            try:
                import typer

                raise typer.Exit(1)
            except Exception:
                return


def launch_ddp_training(
    training_gpus: list[int], tasks_file: str, config_path: Path, verbose: bool
) -> None:
    """Launch DDP training with torchrun.

    Uses subprocess to run the training module, so heavy dependencies load in
    the spawned processes rather than the CLI import path.
    """
    import subprocess as _subprocess
    import sys as _sys

    env = os.environ.copy()
    env["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, training_gpus))

    if not verbose:
        env["HUD_LOG_LEVEL"] = "WARNING"

    cmd = [
        _sys.executable,
        "-m",
        "torch.distributed.run",
        f"--nproc_per_node={len(training_gpus)}",
        "--master_port=29500",
        "-m",
        "hud.rl.train",
        "--config",
        str(config_path),
        "--tasks",
        tasks_file,
    ]

    if verbose:
        cmd.append("--verbose")

    try:
        _subprocess.run(cmd, env=env, check=True)  # noqa: S603
    except _subprocess.CalledProcessError as e:
        console.print(f"\n[red]❌ DDP training failed with exit code {e.returncode}[/red]")
        try:
            import typer

            raise typer.Exit(1)
        except Exception:
            return
    finally:
        try:
            config_path.unlink()
        except Exception as e:
            hud_console.warning(f"Failed to clean up temp config: {e}")
```
hud/cli/rl/presets.py
DELETED
@@ -1,96 +0,0 @@

```python
"""Training configuration presets for different GPU configurations."""

from __future__ import annotations

from typing import Any


def get_training_presets(gpu_memory_gb: float) -> list[dict[str, Any]]:
    """Get training configuration presets based on GPU memory."""
    # Time estimates based on provided benchmarks
    if gpu_memory_gb >= 40:  # A100 40GB or better
        presets = [
            {
                "name": "More Steps",
                "max_steps_per_episode": 12,
                "mini_batch_size": 1,
                "group_size": 4,
                "batch_size": 8,
                "max_new_tokens": 256,
                "tasks_per_hour": 847,
                "steps_per_hour": 424,
                "lr": 3e-5,
                "epochs": 2,
            },
            {
                "name": "Balanced (Recommended)",
                "max_steps_per_episode": 5,
                "mini_batch_size": 1,
                "group_size": 6,
                "batch_size": 12,
                "max_new_tokens": 1024,
                "tasks_per_hour": 738,
                "steps_per_hour": 415,
                "lr": 3e-5,
                "epochs": 2,
            },
            {
                "name": "Low Variance",
                "max_steps_per_episode": 3,
                "mini_batch_size": 2,
                "group_size": 8,
                "batch_size": 16,
                "max_new_tokens": 512,
                "tasks_per_hour": 900,
                "steps_per_hour": 450,
                "lr": 3e-5,
                "epochs": 2,
            },
        ]
    elif gpu_memory_gb >= 24:  # RTX 4090, A10, etc
        presets = [
            {
                "name": "Balanced (Recommended)",
                "max_steps_per_episode": 4,
                "mini_batch_size": 1,
                "group_size": 4,
                "batch_size": 16,
                "lr": 1e-4,
                "epochs": 2,
            },
            {
                "name": "Low Variance",
                "max_steps_per_episode": 3,
                "mini_batch_size": 2,
                "group_size": 4,
                "batch_size": 16,
                "lr": 5e-5,
                "epochs": 2,
            },
        ]
    else:  # Smaller GPUs
        presets = [
            {
                "name": "Test",
                "max_steps_per_episode": 5,
                "mini_batch_size": 1,
                "group_size": 4,
                "batch_size": 8,
                "lr": 1e-4,
                "epochs": 1,
            },
        ]

    return presets


def estimate_memory_usage(
    mini_batch_size: int, max_steps: int, max_new_tokens: int, max_pixels: int
) -> float:
    """Calculate estimated GPU memory usage using the formula from train.py."""
    INITIAL_MEMORY = 8.0
    SCALING_FACTOR = 4 / (28 * 28 * 256 * 1024)
    token_estimate = mini_batch_size * max_steps * max_new_tokens
    image_estimate = max_pixels
    total_memory = INITIAL_MEMORY + SCALING_FACTOR * token_estimate * image_estimate
    return total_memory
```