hud-python 0.4.45__py3-none-any.whl → 0.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hud/__init__.py +27 -7
- hud/agents/__init__.py +11 -5
- hud/agents/base.py +220 -500
- hud/agents/claude.py +200 -240
- hud/agents/gemini.py +275 -0
- hud/agents/gemini_cua.py +335 -0
- hud/agents/grounded_openai.py +98 -100
- hud/agents/misc/integration_test_agent.py +51 -20
- hud/agents/misc/response_agent.py +41 -36
- hud/agents/openai.py +291 -292
- hud/agents/{openai_chat_generic.py → openai_chat.py} +80 -34
- hud/agents/operator.py +211 -0
- hud/agents/tests/conftest.py +133 -0
- hud/agents/tests/test_base.py +300 -622
- hud/agents/tests/test_base_runtime.py +233 -0
- hud/agents/tests/test_claude.py +379 -210
- hud/agents/tests/test_client.py +9 -10
- hud/agents/tests/test_gemini.py +369 -0
- hud/agents/tests/test_grounded_openai_agent.py +65 -50
- hud/agents/tests/test_openai.py +376 -140
- hud/agents/tests/test_operator.py +362 -0
- hud/agents/tests/test_run_eval.py +179 -0
- hud/cli/__init__.py +461 -545
- hud/cli/analyze.py +43 -5
- hud/cli/build.py +664 -110
- hud/cli/debug.py +8 -5
- hud/cli/dev.py +882 -734
- hud/cli/eval.py +782 -668
- hud/cli/flows/dev.py +167 -0
- hud/cli/flows/init.py +191 -0
- hud/cli/flows/tasks.py +153 -56
- hud/cli/flows/templates.py +151 -0
- hud/cli/flows/tests/__init__.py +1 -0
- hud/cli/flows/tests/test_dev.py +126 -0
- hud/cli/init.py +60 -58
- hud/cli/push.py +29 -11
- hud/cli/rft.py +311 -0
- hud/cli/rft_status.py +145 -0
- hud/cli/tests/test_analyze.py +5 -5
- hud/cli/tests/test_analyze_metadata.py +3 -2
- hud/cli/tests/test_analyze_module.py +120 -0
- hud/cli/tests/test_build.py +108 -6
- hud/cli/tests/test_build_failure.py +41 -0
- hud/cli/tests/test_build_module.py +50 -0
- hud/cli/tests/test_cli_init.py +6 -1
- hud/cli/tests/test_cli_more_wrappers.py +30 -0
- hud/cli/tests/test_cli_root.py +140 -0
- hud/cli/tests/test_convert.py +361 -0
- hud/cli/tests/test_debug.py +12 -10
- hud/cli/tests/test_dev.py +197 -0
- hud/cli/tests/test_eval.py +251 -0
- hud/cli/tests/test_eval_bedrock.py +51 -0
- hud/cli/tests/test_init.py +124 -0
- hud/cli/tests/test_main_module.py +11 -5
- hud/cli/tests/test_mcp_server.py +12 -100
- hud/cli/tests/test_push_happy.py +74 -0
- hud/cli/tests/test_push_wrapper.py +23 -0
- hud/cli/tests/test_registry.py +1 -1
- hud/cli/tests/test_utils.py +1 -1
- hud/cli/{rl → utils}/celebrate.py +14 -12
- hud/cli/utils/config.py +18 -1
- hud/cli/utils/docker.py +130 -4
- hud/cli/utils/env_check.py +9 -9
- hud/cli/utils/git.py +136 -0
- hud/cli/utils/interactive.py +39 -5
- hud/cli/utils/metadata.py +69 -0
- hud/cli/utils/runner.py +1 -1
- hud/cli/utils/server.py +2 -2
- hud/cli/utils/source_hash.py +3 -3
- hud/cli/utils/tasks.py +4 -1
- hud/cli/utils/tests/__init__.py +0 -0
- hud/cli/utils/tests/test_config.py +58 -0
- hud/cli/utils/tests/test_docker.py +93 -0
- hud/cli/utils/tests/test_docker_hints.py +71 -0
- hud/cli/utils/tests/test_env_check.py +74 -0
- hud/cli/utils/tests/test_environment.py +42 -0
- hud/cli/utils/tests/test_git.py +142 -0
- hud/cli/utils/tests/test_interactive_module.py +60 -0
- hud/cli/utils/tests/test_local_runner.py +50 -0
- hud/cli/utils/tests/test_logging_utils.py +23 -0
- hud/cli/utils/tests/test_metadata.py +49 -0
- hud/cli/utils/tests/test_package_runner.py +35 -0
- hud/cli/utils/tests/test_registry_utils.py +49 -0
- hud/cli/utils/tests/test_remote_runner.py +25 -0
- hud/cli/utils/tests/test_runner_modules.py +52 -0
- hud/cli/utils/tests/test_source_hash.py +36 -0
- hud/cli/utils/tests/test_tasks.py +80 -0
- hud/cli/utils/version_check.py +258 -0
- hud/cli/{rl → utils}/viewer.py +2 -2
- hud/clients/README.md +12 -11
- hud/clients/__init__.py +4 -3
- hud/clients/base.py +166 -26
- hud/clients/environment.py +51 -0
- hud/clients/fastmcp.py +13 -6
- hud/clients/mcp_use.py +40 -15
- hud/clients/tests/test_analyze_scenarios.py +206 -0
- hud/clients/tests/test_protocol.py +9 -3
- hud/datasets/__init__.py +23 -20
- hud/datasets/loader.py +327 -0
- hud/datasets/runner.py +192 -105
- hud/datasets/tests/__init__.py +0 -0
- hud/datasets/tests/test_loader.py +221 -0
- hud/datasets/tests/test_utils.py +315 -0
- hud/datasets/utils.py +270 -90
- hud/environment/__init__.py +50 -0
- hud/environment/connection.py +206 -0
- hud/environment/connectors/__init__.py +33 -0
- hud/environment/connectors/base.py +68 -0
- hud/environment/connectors/local.py +177 -0
- hud/environment/connectors/mcp_config.py +109 -0
- hud/environment/connectors/openai.py +101 -0
- hud/environment/connectors/remote.py +172 -0
- hud/environment/environment.py +694 -0
- hud/environment/integrations/__init__.py +45 -0
- hud/environment/integrations/adk.py +67 -0
- hud/environment/integrations/anthropic.py +196 -0
- hud/environment/integrations/gemini.py +92 -0
- hud/environment/integrations/langchain.py +82 -0
- hud/environment/integrations/llamaindex.py +68 -0
- hud/environment/integrations/openai.py +238 -0
- hud/environment/mock.py +306 -0
- hud/environment/router.py +112 -0
- hud/environment/scenarios.py +493 -0
- hud/environment/tests/__init__.py +1 -0
- hud/environment/tests/test_connection.py +317 -0
- hud/environment/tests/test_connectors.py +218 -0
- hud/environment/tests/test_environment.py +161 -0
- hud/environment/tests/test_integrations.py +257 -0
- hud/environment/tests/test_local_connectors.py +201 -0
- hud/environment/tests/test_scenarios.py +280 -0
- hud/environment/tests/test_tools.py +208 -0
- hud/environment/types.py +23 -0
- hud/environment/utils/__init__.py +35 -0
- hud/environment/utils/formats.py +215 -0
- hud/environment/utils/schema.py +171 -0
- hud/environment/utils/tool_wrappers.py +113 -0
- hud/eval/__init__.py +67 -0
- hud/eval/context.py +674 -0
- hud/eval/display.py +299 -0
- hud/eval/instrument.py +185 -0
- hud/eval/manager.py +466 -0
- hud/eval/parallel.py +268 -0
- hud/eval/task.py +340 -0
- hud/eval/tests/__init__.py +1 -0
- hud/eval/tests/test_context.py +178 -0
- hud/eval/tests/test_eval.py +210 -0
- hud/eval/tests/test_manager.py +152 -0
- hud/eval/tests/test_parallel.py +168 -0
- hud/eval/tests/test_task.py +145 -0
- hud/eval/types.py +63 -0
- hud/eval/utils.py +183 -0
- hud/patches/__init__.py +19 -0
- hud/patches/mcp_patches.py +151 -0
- hud/patches/warnings.py +54 -0
- hud/samples/browser.py +4 -4
- hud/server/__init__.py +2 -1
- hud/server/low_level.py +2 -1
- hud/server/router.py +164 -0
- hud/server/server.py +567 -80
- hud/server/tests/test_mcp_server_integration.py +11 -11
- hud/server/tests/test_mcp_server_more.py +1 -1
- hud/server/tests/test_server_extra.py +2 -0
- hud/settings.py +45 -3
- hud/shared/exceptions.py +36 -10
- hud/shared/hints.py +26 -1
- hud/shared/requests.py +15 -3
- hud/shared/tests/test_exceptions.py +40 -31
- hud/shared/tests/test_hints.py +167 -0
- hud/telemetry/__init__.py +20 -19
- hud/telemetry/exporter.py +201 -0
- hud/telemetry/instrument.py +158 -253
- hud/telemetry/tests/test_eval_telemetry.py +356 -0
- hud/telemetry/tests/test_exporter.py +258 -0
- hud/telemetry/tests/test_instrument.py +401 -0
- hud/tools/__init__.py +16 -2
- hud/tools/apply_patch.py +639 -0
- hud/tools/base.py +54 -4
- hud/tools/bash.py +2 -2
- hud/tools/computer/__init__.py +4 -0
- hud/tools/computer/anthropic.py +2 -2
- hud/tools/computer/gemini.py +385 -0
- hud/tools/computer/hud.py +23 -6
- hud/tools/computer/openai.py +20 -21
- hud/tools/computer/qwen.py +434 -0
- hud/tools/computer/settings.py +37 -0
- hud/tools/edit.py +3 -7
- hud/tools/executors/base.py +4 -2
- hud/tools/executors/pyautogui.py +1 -1
- hud/tools/grounding/grounded_tool.py +13 -18
- hud/tools/grounding/grounder.py +10 -31
- hud/tools/grounding/tests/test_grounded_tool.py +26 -44
- hud/tools/jupyter.py +330 -0
- hud/tools/playwright.py +18 -3
- hud/tools/shell.py +308 -0
- hud/tools/tests/test_apply_patch.py +718 -0
- hud/tools/tests/test_computer.py +4 -9
- hud/tools/tests/test_computer_actions.py +24 -2
- hud/tools/tests/test_jupyter_tool.py +181 -0
- hud/tools/tests/test_shell.py +596 -0
- hud/tools/tests/test_submit.py +85 -0
- hud/tools/tests/test_types.py +193 -0
- hud/tools/types.py +21 -1
- hud/types.py +167 -57
- hud/utils/__init__.py +2 -0
- hud/utils/env.py +67 -0
- hud/utils/hud_console.py +61 -3
- hud/utils/mcp.py +15 -58
- hud/utils/strict_schema.py +162 -0
- hud/utils/tests/test_init.py +1 -2
- hud/utils/tests/test_mcp.py +1 -28
- hud/utils/tests/test_pretty_errors.py +186 -0
- hud/utils/tests/test_tool_shorthand.py +154 -0
- hud/utils/tests/test_version.py +1 -1
- hud/utils/types.py +20 -0
- hud/version.py +1 -1
- hud_python-0.5.1.dist-info/METADATA +264 -0
- hud_python-0.5.1.dist-info/RECORD +299 -0
- {hud_python-0.4.45.dist-info → hud_python-0.5.1.dist-info}/WHEEL +1 -1
- hud/agents/langchain.py +0 -261
- hud/agents/lite_llm.py +0 -72
- hud/cli/rl/__init__.py +0 -180
- hud/cli/rl/config.py +0 -101
- hud/cli/rl/display.py +0 -133
- hud/cli/rl/gpu.py +0 -63
- hud/cli/rl/gpu_utils.py +0 -321
- hud/cli/rl/local_runner.py +0 -595
- hud/cli/rl/presets.py +0 -96
- hud/cli/rl/remote_runner.py +0 -463
- hud/cli/rl/rl_api.py +0 -150
- hud/cli/rl/vllm.py +0 -177
- hud/cli/rl/wait_utils.py +0 -89
- hud/datasets/parallel.py +0 -687
- hud/misc/__init__.py +0 -1
- hud/misc/claude_plays_pokemon.py +0 -292
- hud/otel/__init__.py +0 -35
- hud/otel/collector.py +0 -142
- hud/otel/config.py +0 -181
- hud/otel/context.py +0 -570
- hud/otel/exporters.py +0 -369
- hud/otel/instrumentation.py +0 -135
- hud/otel/processors.py +0 -121
- hud/otel/tests/__init__.py +0 -1
- hud/otel/tests/test_processors.py +0 -197
- hud/rl/README.md +0 -30
- hud/rl/__init__.py +0 -1
- hud/rl/actor.py +0 -176
- hud/rl/buffer.py +0 -405
- hud/rl/chat_template.jinja +0 -101
- hud/rl/config.py +0 -192
- hud/rl/distributed.py +0 -132
- hud/rl/learner.py +0 -637
- hud/rl/tests/__init__.py +0 -1
- hud/rl/tests/test_learner.py +0 -186
- hud/rl/train.py +0 -382
- hud/rl/types.py +0 -101
- hud/rl/utils/start_vllm_server.sh +0 -30
- hud/rl/utils.py +0 -524
- hud/rl/vllm_adapter.py +0 -143
- hud/telemetry/job.py +0 -352
- hud/telemetry/replay.py +0 -74
- hud/telemetry/tests/test_replay.py +0 -40
- hud/telemetry/tests/test_trace.py +0 -63
- hud/telemetry/trace.py +0 -158
- hud/utils/agent_factories.py +0 -86
- hud/utils/async_utils.py +0 -65
- hud/utils/group_eval.py +0 -223
- hud/utils/progress.py +0 -149
- hud/utils/tasks.py +0 -127
- hud/utils/tests/test_async_utils.py +0 -173
- hud/utils/tests/test_progress.py +0 -261
- hud_python-0.4.45.dist-info/METADATA +0 -552
- hud_python-0.4.45.dist-info/RECORD +0 -228
- {hud_python-0.4.45.dist-info → hud_python-0.5.1.dist-info}/entry_points.txt +0 -0
- {hud_python-0.4.45.dist-info → hud_python-0.5.1.dist-info}/licenses/LICENSE +0 -0
hud/rl/tests/test_learner.py
DELETED
|
@@ -1,186 +0,0 @@
|
|
|
1
|
-
from __future__ import annotations
|
|
2
|
-
|
|
3
|
-
import pytest
|
|
4
|
-
import torch
|
|
5
|
-
|
|
6
|
-
from hud.rl.config import Config
|
|
7
|
-
from hud.rl.learner import GRPOLearner
|
|
8
|
-
from hud.rl.types import TrainingSample
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
@pytest.fixture()
def learner_stub(monkeypatch):
    """Build a ``GRPOLearner`` whose model loading is replaced by a tiny stub.

    The config is shrunk to one epoch, group size 1 and mini-batch size 1 with
    the 8-bit optimizer disabled. ``GRPOLearner._load_models`` is patched to
    return a one-parameter module plus an SGD optimizer over it, so no real
    model is ever loaded.
    """
    config = Config()
    training = config.training
    # Speed up: tiny settings
    training.epochs = 1
    training.group_size = 1
    training.mini_batch_size = 1
    training.use_8bit_optimizer = False

    class _OneParamPolicy(torch.nn.Module):
        """Module holding a single zero-initialised trainable parameter."""

        def __init__(self):
            super().__init__()
            self.w = torch.nn.Parameter(torch.zeros(1))

    def _fake_load_models(self):
        # Stands in for GRPOLearner._load_models to avoid heavy model init.
        policy = _OneParamPolicy()
        optimizer = torch.optim.SGD(policy.parameters(), lr=0.1)
        return None, policy, None, optimizer

    monkeypatch.setattr(GRPOLearner, "_load_models", _fake_load_models, raising=True)
    return GRPOLearner(config)
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
def make_sample(
    pol_logp_tok: torch.Tensor,
    old_logp_tok: torch.Tensor,
    ref_logp_tok: torch.Tensor,
    advantage: float,
):
    """Build a minimal ``TrainingSample`` accepted by ``GRPOLearner.compute_loss``.

    ``sanity_check()`` needs an ``assistant_mask`` of length T-1 and an
    ``attention_mask`` of length T, where T-1 is the last dimension of
    ``pol_logp_tok`` (only its size is used here).
    """
    n_pred = pol_logp_tok.size(-1)
    seq_len = n_pred + 1
    inputs = dict(
        input_ids=torch.zeros(1, seq_len, dtype=torch.long),
        attention_mask=torch.ones(1, seq_len, dtype=torch.long),
        assistant_mask=torch.ones(1, n_pred, dtype=torch.bool),
    )
    # advantage must be 1D so .view(-1, 1) works in compute_loss
    adv_tensor = torch.tensor([advantage], dtype=torch.float32)
    return TrainingSample(
        inputs=inputs,
        old_logprobs=old_logp_tok,
        ref_logprobs=ref_logp_tok,
        advantage=adv_tensor,
    )
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
def patch_compute_logprobs(
    monkeypatch, learner: GRPOLearner, pol_logp_tok: torch.Tensor, pol_entropy_tok: torch.Tensor
):
    """Make ``GRPOLearner.compute_logprobs`` return fixed tensors.

    The stub ignores the model and returns ``(pol_logp_tok, pol_entropy_tok)``
    moved to the device of ``inputs["input_ids"]``. That is the
    ``(pol_logp, pol_entropy)`` pair ``compute_loss`` expects. The patch is
    applied on the class, so ``learner`` itself is not modified.
    """

    def _fixed_logprobs(self, model, inputs):
        target_device = inputs["input_ids"].device
        return pol_logp_tok.to(target_device), pol_entropy_tok.to(target_device)

    monkeypatch.setattr(GRPOLearner, "compute_logprobs", _fixed_logprobs, raising=True)
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
def test_per_token_mean_vs_sum(monkeypatch, learner_stub: GRPOLearner):
    """With ``token_agg="sum"`` the per-token loss equals the mean loss times the token count."""
    # Setup
    pol = torch.tensor([[-1.0, -1.0, -1.0, -1.0]], dtype=torch.float32)  # logp
    old = torch.tensor([[-1.2, -0.8, -1.0, -1.1]], dtype=torch.float32)
    ref = torch.tensor([[-1.0, -1.0, -1.0, -1.0]], dtype=torch.float32)
    ent = torch.zeros_like(pol)
    # Derive the token count from the data so the expectation cannot drift
    # out of sync with the tensors above.
    num_tokens = pol.size(-1)
    patch_compute_logprobs(monkeypatch, learner_stub, pol, ent)

    # Common config
    learner_stub.config.training.kl_beta = 0.0
    learner_stub.config.training.entropy_beta = 0.0
    learner_stub.config.training.top_eps = 0.2
    learner_stub.config.training.bottom_eps = 0.1

    sample = make_sample(pol, old, ref, advantage=1.0)

    # token_agg=mean
    learner_stub.config.training.ppo_mode = "per_token"
    learner_stub.config.training.token_agg = "mean"
    loss_mean = learner_stub.compute_loss(sample).item()

    # token_agg=sum
    learner_stub.config.training.token_agg = "sum"
    loss_sum = learner_stub.compute_loss(sample).item()

    # Expect sum ≈ mean * num_tokens
    assert pytest.approx(loss_sum, rel=1e-5) == loss_mean * num_tokens
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
def test_per_trace_vs_per_token(monkeypatch, learner_stub: GRPOLearner):
    """Equal per-token log-prob deltas make ``per_trace`` match ``per_token`` with mean aggregation."""
    policy_lp = torch.full((1, 3), -1.0, dtype=torch.float32)
    old_lp = torch.full((1, 3), -1.2, dtype=torch.float32)
    ref_lp = torch.full((1, 3), -1.1, dtype=torch.float32)
    patch_compute_logprobs(monkeypatch, learner_stub, policy_lp, torch.zeros_like(policy_lp))

    training = learner_stub.config.training
    training.kl_beta = 0.0
    training.entropy_beta = 0.0
    training.top_eps = 0.2
    training.bottom_eps = 0.1

    sample = make_sample(policy_lp, old_lp, ref_lp, advantage=1.0)

    training.ppo_mode = "per_token"
    training.token_agg = "mean"
    per_token_loss = learner_stub.compute_loss(sample).item()

    # token_agg stays "mean"; only the PPO mode changes.
    training.ppo_mode = "per_trace"
    per_trace_loss = learner_stub.compute_loss(sample).item()

    assert pytest.approx(per_trace_loss, rel=1e-6) == per_token_loss
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
def test_entropy_beta_effect(monkeypatch, learner_stub: GRPOLearner):
    """Raising ``entropy_beta`` from 0 to 2 is expected to raise the loss by 2 * mean(entropy).

    Policy, old and reference log-probs are identical and the advantage is
    zero, so only the entropy term should change the loss.
    """
    logp = torch.tensor([[-1.0, -1.1]], dtype=torch.float32)
    entropy = torch.tensor([[0.5, 1.5]], dtype=torch.float32)
    patch_compute_logprobs(monkeypatch, learner_stub, logp, entropy)

    training = learner_stub.config.training
    # No policy/kl effect, only entropy
    training.ppo_mode = "per_token"
    training.token_agg = "mean"
    training.kl_beta = 0.0

    sample = make_sample(logp, logp.clone(), logp.clone(), advantage=0.0)

    losses = {}
    for beta in (0.0, 2.0):
        training.entropy_beta = beta
        losses[beta] = learner_stub.compute_loss(sample).item()

    # Mean entropy = (0.5+1.5)/2 = 1.0, scaled by beta=2.0 -> +2.0
    assert pytest.approx(losses[2.0] - losses[0.0], rel=1e-6) == 2.0
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
def test_skip_update_when_zero_adv(monkeypatch, learner_stub: GRPOLearner):
    """``update`` still performs exactly one optimizer step for a zero-advantage group.

    ``prepare_groups`` and ``compute_loss`` are stubbed. The single group yields
    a loss that is connected to the policy parameters, so ``backward`` works,
    but produces zero gradients.
    """

    # Patch prepare_groups to yield a single group with a minibatch-like object
    class MiniBatch:
        """Stand-in minibatch exposing only ``advantage`` and ``to_device``."""

        def __init__(self):
            self.advantage = torch.zeros(1)

        def to_device(self, device: torch.device) -> MiniBatch:
            return self

    def _stub_prepare_groups(self, samples: list[TrainingSample]) -> list[list[MiniBatch]]:
        # One group containing two minibatches.
        return [[MiniBatch(), MiniBatch()]]

    monkeypatch.setattr(GRPOLearner, "prepare_groups", _stub_prepare_groups, raising=True)

    # Return a zero scalar loss that *depends* on params so backward works,
    # but has zero gradients (no update signal).
    def _zero_loss(self, sample) -> torch.Tensor:
        return sum(p.sum() for p in self.policy.parameters()) * 0.0  # type: ignore

    monkeypatch.setattr(GRPOLearner, "compute_loss", _zero_loss, raising=True)

    # Count optimizer.step calls
    steps = {"n": 0}

    def _count_step(*args, **kwargs):
        # Accept the optional ``closure`` argument that Optimizer.step allows.
        steps["n"] += 1

    monkeypatch.setattr(learner_stub.optimizer, "step", _count_step, raising=False)

    # Ensure dummy backward can touch a parameter
    assert any(p.requires_grad for p in learner_stub.policy.parameters())

    learner_stub.update([])
    # With the current learner implementation we still call optimizer.step()
    # even if the per-minibatch "advantage" is zero (the step is a no-op
    # because the gradients are zero). So we expect exactly one step here.
    assert steps["n"] == 1
|
hud/rl/train.py
DELETED
|
@@ -1,382 +0,0 @@
|
|
|
1
|
-
"""Main training loop for GRPO RL."""
|
|
2
|
-
|
|
3
|
-
from __future__ import annotations
|
|
4
|
-
|
|
5
|
-
import os
|
|
6
|
-
|
|
7
|
-
# Disable tokenizer parallelism warnings and enable expandable CUDA allocator
# segments. ``setdefault`` keeps any value the user already exported instead
# of silently overwriting it.
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
|
|
10
|
-
import argparse
|
|
11
|
-
import asyncio
|
|
12
|
-
import json
|
|
13
|
-
import logging
|
|
14
|
-
from pathlib import Path
|
|
15
|
-
from typing import TYPE_CHECKING, cast
|
|
16
|
-
|
|
17
|
-
import hud
|
|
18
|
-
from hud.rl.actor import Actor
|
|
19
|
-
from hud.rl.buffer import DatasetBuffer, ReplayBuffer
|
|
20
|
-
from hud.rl.config import Config
|
|
21
|
-
from hud.rl.distributed import (
|
|
22
|
-
broadcast_object,
|
|
23
|
-
cleanup_distributed,
|
|
24
|
-
get_global_rank,
|
|
25
|
-
get_world_size,
|
|
26
|
-
is_main_process,
|
|
27
|
-
scatter_object,
|
|
28
|
-
setup_distributed,
|
|
29
|
-
synchronize,
|
|
30
|
-
)
|
|
31
|
-
from hud.rl.learner import GRPOLearner
|
|
32
|
-
from hud.rl.utils import (
|
|
33
|
-
aggregate_metrics_across_ranks,
|
|
34
|
-
ensure_dir,
|
|
35
|
-
preprocess_advantages,
|
|
36
|
-
set_seed,
|
|
37
|
-
)
|
|
38
|
-
from hud.rl.vllm_adapter import VLLMAdapter
|
|
39
|
-
from hud.utils.hud_console import HUDConsole
|
|
40
|
-
from hud.utils.tasks import load_tasks
|
|
41
|
-
|
|
42
|
-
if TYPE_CHECKING:
|
|
43
|
-
from hud.types import Task
|
|
44
|
-
# Console wrapper around this module's logger, used for the training output below.
_train_logger = logging.getLogger(__name__)
hud_console = HUDConsole(_train_logger)
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
async def train(config: Config, tasks: list[Task]) -> None:
    """Run the main GRPO training loop on every distributed rank.

    Rank 0 (the main process) owns the ``Actor`` that rolls out episodes, the
    ``VLLMAdapter`` used to hot-load checkpoints, and the HUD job that metrics
    are logged to. Every rank builds a ``GRPOLearner`` and trains on the slice
    of preprocessed samples that rank 0 scatters to it each step.

    Args:
        config: Full training configuration.
        tasks: Tasks used to build the ``DatasetBuffer``.

    Raises:
        SystemExit: If ``groups_per_batch`` is not divisible by the world size.
        Exception: Any error raised inside the training loop is re-raised after
            the job (on rank 0) has been marked as failed.
    """
    import time  # times the episode rollouts on rank 0

    # Setup distributed environment
    setup_distributed()

    # Initialize components
    set_seed(config.seed + get_global_rank())  # Different seed per rank
    ensure_dir(config.out_dir)
    # Pick a single level: ``logging.basicConfig`` is a no-op once the root logger
    # has handlers, so a second call could never raise the level to DEBUG.
    if config.very_verbose:
        logging.basicConfig(level=logging.DEBUG)
        # Keep httpx request logs at INFO instead of DEBUG
        logging.getLogger("httpx").setLevel(logging.INFO)
    elif config.verbose:
        logging.basicConfig(level=logging.INFO)
        # Silence per-request httpx INFO logs
        logging.getLogger("httpx").setLevel(logging.WARNING)

    if is_main_process():
        hud_console.header("Starting GRPO Training")
        hud_console.section_title(
            f"\n[1/3] Initializing components (world_size={get_world_size()})..."
        )

    num_gpus = get_world_size()

    # Actor is responsible for running tasks and collecting episodes
    actor = Actor(config) if is_main_process() else None

    # Learner is responsible for updating the policy
    learner = GRPOLearner(config)

    # Dataset buffer is responsible for storing tasks
    dataset_buffer = DatasetBuffer(tasks, config)
    if is_main_process():
        hud_console.key_value_table(dataset_buffer.info)

    # Checked on every rank so that all ranks stop together. Tear down the
    # process group before exiting; the try/finally below is not active yet.
    if dataset_buffer.groups_per_batch % num_gpus != 0:
        hud_console.warning(
            f"Groups per batch {dataset_buffer.groups_per_batch} is not divisible by number of GPUs {num_gpus}"  # noqa: E501
        )
        cleanup_distributed()
        raise SystemExit(1)

    # Replay buffer is responsible for storing episodes for training
    trace_buffer = ReplayBuffer(config)

    # VLLM adapter is responsible for loading and unloading adapters (only on main process)
    vllm = (
        VLLMAdapter(config.actor.vllm_base_url, config.actor.vllm_api_key)
        if is_main_process()
        else None
    )

    # Training state
    step = 0
    last_metrics = None  # Store last successful metrics for error recovery

    if is_main_process():
        hud_console.section_title("\n[2/3] Running training loop...")

    # Create job on main process and distribute ID across GPUs
    if is_main_process():
        hud_console.info(f"Creating job with config.job_id: {config.job_id}")
        job_obj = hud.create_job(
            job_id=config.job_id,
            name=config.job_name,
            metadata={"config": config.to_dict(), "agent_class": config.model.base_model},
        )
        hud_console.info(f"Created job with job_obj.id: {job_obj.id}")
        job_obj.update_status_sync("running")
        job_id = job_obj.id
    else:
        job_obj = None
        job_id = None

    # Broadcast job ID to all ranks
    job_id = broadcast_object(job_id, src=0)

    try:
        while len(dataset_buffer) > 0:
            if is_main_process():
                hud_console.section_title(f"Step {step + 1}/{dataset_buffer.training_steps}")
                hud_console.info(f"{len(dataset_buffer)} tasks remaining")
            # Get batch of tasks (all ranks need same tasks). A separate name keeps
            # the ``tasks`` argument from being rebound inside the loop.
            batch_tasks = dataset_buffer.get_tasks()

            # Initialize variables on all ranks
            global_reward_stats = None
            global_advantage_stats = None

            # Step-state gate: ensure all ranks branch coherently
            state = {"ok": False, "err": None, "num_samples": 0}
            rank_samples = None
            episode_time_value = None

            # Only rank 0 runs tasks and prepares distribution
            if is_main_process() and actor is not None:
                try:
                    episode_start_time = time.time()
                    traces = await actor.run_tasks(batch_tasks, job_id=job_id)
                    episode_time = time.time() - episode_start_time
                    hud_console.info(f"Sampled {len(traces)} traces in {episode_time:.1f}s")
                    trace_buffer.add(traces)
                    global_reward_stats = [trace.reward for trace in traces]

                    # Get all traces from buffer for distribution
                    all_traces = trace_buffer.sample_traces()

                    # Preprocess traces to training samples
                    preprocessed_traces = preprocess_advantages(all_traces, config)

                    # Store these for later use in metrics
                    global_advantage_stats = [sample.advantage for sample in preprocessed_traces]

                    # Distribute preprocessed samples in groups across ranks via scatter
                    # Ensure list length is a multiple of num_gpus by allowing empty per-rank slices
                    gpu_batch_size = max(1, (len(preprocessed_traces) + num_gpus - 1) // num_gpus)
                    rank_samples = [
                        preprocessed_traces[i : i + gpu_batch_size]
                        for i in range(0, len(preprocessed_traces), gpu_batch_size)
                    ]
                    # Pad rank_samples to exactly num_gpus entries
                    if len(rank_samples) < num_gpus:
                        rank_samples.extend([[] for _ in range(num_gpus - len(rank_samples))])

                    # Log distribution info
                    dist_msg = (
                        f"Distributing {len(preprocessed_traces)} samples as {gpu_batch_size} "
                        f"sized batches across {num_gpus} GPUs"
                    )
                    hud_console.info(dist_msg)
                    for rank in range(num_gpus):
                        n_samples = len(rank_samples[rank]) if rank < len(rank_samples) else 0
                        hud_console.info(f" Rank {rank}: {n_samples} samples")

                    hud_console.section_title(f"Training on {len(all_traces)} traces")
                    episode_time_value = episode_time

                    state.update({"ok": True, "num_samples": len(preprocessed_traces)})
                except Exception as e:
                    state.update({"ok": False, "err": str(e)})

            # Broadcast step-state to keep ranks in lockstep
            state = broadcast_object(state, src=0)
            if not state.get("ok", False):
                # Surface the rank-0 error; otherwise the captured message is lost.
                hud_console.warning(
                    "Step failed on rank 0; skipping this step coherently: "
                    f"{state.get('err')}"
                )
                synchronize()
                continue

            # Scatter per-rank samples; each rank receives only its slice
            my_samples = scatter_object(rank_samples if is_main_process() else None, src=0)
            # Broadcast the episode time (small object)
            episode_time_value = broadcast_object(episode_time_value, src=0)

            # Process only assigned samples
            last_metrics = learner.update(my_samples)

            # Add episode time (same for all ranks since episodes run on rank 0)
            if episode_time_value is not None:
                last_metrics.update(
                    {
                        "episode_time": episode_time_value,
                    }
                )

            # Aggregate metrics across all GPUs for proper statistics
            aggregate_metrics_across_ranks(last_metrics)

            if is_main_process() and job_obj is not None:
                # Use the global statistics we collected before distribution
                if global_reward_stats is not None and global_advantage_stats is not None:
                    last_metrics.update(
                        {
                            "advantage": global_advantage_stats,
                            "reward": global_reward_stats,
                        }
                    )
                else:
                    # Fallback: use only this rank's data
                    hud_console.warning("Global statistics not available, using partial data")
                    last_metrics.update(
                        {
                            "advantage": [sample.advantage for sample in my_samples]
                            if my_samples
                            else [],
                            "reward": [sample.reward for sample in my_samples]
                            if my_samples
                            else [],
                        }
                    )

                job_obj.log_sync(last_metrics.to_dict())

                if step % config.stats_interval == 0:
                    hud_console.key_value_table(last_metrics.to_dict())

            # Increment step counter on all processes
            step += 1

            # Save checkpoint and update vLLM (only on main process)
            if step % config.training.save_every_batches == 0:
                if is_main_process() and vllm is not None and actor is not None:
                    hud_console.section_title("Saving checkpoint and updating vLLM")
                    checkpoint_path = Path(config.out_dir) / f"{config.adapter_prefix}-{step}"
                    learner.save(str(checkpoint_path))

                    # Wait for 6 seconds to ensure the checkpoint is saved
                    await asyncio.sleep(6)

                    # If there is a previous adapter, unload it
                    current_adapter = vllm.get_current()
                    if current_adapter is not None:
                        vllm.unload_adapter(current_adapter)

                    adapter_name = f"{config.adapter_prefix}-{step}"
                    if vllm.load_adapter(adapter_name, str(checkpoint_path)):
                        actor.update_adapter(adapter_name)
                        hud_console.info(f"✓ Checkpoint saved and loaded: {adapter_name}")
                    else:
                        hud_console.warning(f"Failed to hot-load adapter {adapter_name}")

                # Ensure all processes wait for checkpoint operations to complete
                synchronize()

        if is_main_process():
            hud_console.section_title("\n[3/3] Training completed!")
            # Update job status to completed
            if job_obj:
                job_obj.update_status_sync("completed")
    except Exception as e:
        # Log error and any available metrics before failing
        hud_console.error(f"Training failed on rank {get_global_rank()}: {e}")

        if is_main_process():
            # Log final metrics if we have any
            if last_metrics and job_obj:
                try:
                    job_obj.log_sync(last_metrics.to_dict())
                except Exception:
                    hud_console.warning("Failed to log final metrics")

            # Update job status to failed
            if job_obj:
                job_obj.update_status_sync("failed")

        # Re-raise; the finally block below still performs distributed cleanup.
        raise

    finally:
        # Try to sync one last time, but don't fail if it doesn't work
        try:
            synchronize()
        except Exception:
            hud_console.warning("Failed to synchronize during cleanup")

        # Clean up distributed environment
        cleanup_distributed()
|
|
305
|
-
|
|
306
|
-
|
|
307
|
-
async def main() -> None:
    """Command-line entry point for GRPO RL training.

    Parses CLI arguments, builds a ``Config`` (from the ``--config`` JSON file
    or the defaults), optionally shrinks it for ``--test`` runs, aborts when the
    rough peak-memory estimate exceeds the budget, resolves the task list and
    finally awaits ``train(config, tasks)``.

    Task source precedence: ``--tasks-json`` (inline JSON list), then
    ``--tasks`` (file path or HuggingFace dataset name), then the file
    ``browser_2048_tasks.jsonl`` in the current working directory.

    Raises:
        ValueError: if no task source is given and the default task file does
            not exist.
        SystemExit: with status 1 when the memory estimate exceeds the budget.
    """
    import sys  # local import: only needed for the memory-budget abort below

    parser = argparse.ArgumentParser(description="GRPO RL Training")
    parser.add_argument("--config", type=str, help="Path to config JSON file")
    parser.add_argument("--test", action="store_true", help="Run in test mode")
    # NOTE(review): --debug and --verbose are parsed but not read in this function.
    parser.add_argument("--debug", action="store_true", help="Enable debug mode")
    parser.add_argument("--verbose", action="store_true", help="Enable verbose mode")
    # Task input arguments
    parser.add_argument(
        "--tasks", type=str, help="Path to tasks JSONL file or HuggingFace dataset name"
    )
    parser.add_argument("--tasks-json", type=json.loads, help="Tasks as JSON list string")

    args = parser.parse_args()

    # Load config
    if args.config:
        with open(args.config, encoding="utf-8") as f:  # noqa: ASYNC230
            config_dict = json.load(f)
        config = Config.from_dict(config_dict)
    else:
        config = Config()

    # Apply test mode settings
    if args.test:
        hud_console.info("[TEST MODE] Using minimal configuration")
        eps = 6
        config.training.batch_size = eps
        config.actor.max_parallel_episodes = 12
        config.training.group_size = eps
        config.training.mini_batch_size = 3
        config.training.training_steps = 4
        config.actor.max_steps_per_episode = 4

    # Rough peak-memory heuristic: a fixed baseline (GB) plus a term proportional
    # to (tokens per forward pass) x (max image pixels). The constants are
    # presumably empirically tuned -- TODO confirm their provenance.
    initial_memory_gb = 8.0
    scaling_factor = 4 / (28 * 28 * 256 * 1024)
    memory_budget_gb = 75.0
    token_estimate = (
        config.training.mini_batch_size
        * config.actor.max_steps_per_episode
        * config.actor.max_new_tokens
    )
    hud_console.info(f"Estimated tokens per forward pass: {token_estimate}")
    image_estimate = config.model.max_pixels
    total_memory = initial_memory_gb + scaling_factor * token_estimate * image_estimate
    hud_console.info(f"Estimated memory peak: {total_memory:.2f} GB")
    if total_memory > memory_budget_gb:
        # Name the settings that actually enter the estimate above.
        hud_console.warning(
            "Potential memory usage is too high, decrease mini batch size, "
            "max steps per episode, max new tokens, or max pixels"
        )
        # sys.exit instead of the site-module ``exit`` helper, which is not
        # guaranteed to exist (e.g. under ``python -S`` or frozen builds).
        sys.exit(1)

    # Load tasks: inline JSON first, then --tasks, then the default file.
    if args.tasks_json:
        # Tasks provided as JSON list via command line
        tasks = load_tasks(args.tasks_json)
    elif args.tasks:
        # Tasks provided as file path or HuggingFace dataset
        tasks = load_tasks(args.tasks)
    else:
        # Default to browser_2048_tasks.jsonl (relative to the working directory)
        default_tasks_path = "browser_2048_tasks.jsonl"
        if not Path(default_tasks_path).exists():
            raise ValueError(
                "No tasks specified. Use --tasks or --tasks-json, or place "
                f"{default_tasks_path} in the working directory"
            )
        hud_console.info(f"No tasks specified, using default: {default_tasks_path}")
        tasks = load_tasks(default_tasks_path)

    # Run training
    tasks_typed = cast("list[Task]", tasks)
    await train(config, tasks_typed)


if __name__ == "__main__":
    asyncio.run(main())
|
hud/rl/types.py
DELETED
|
@@ -1,101 +0,0 @@
|
|
|
1
|
-
"""Shared types for RL training."""
|
|
2
|
-
|
|
3
|
-
from __future__ import annotations
|
|
4
|
-
|
|
5
|
-
import math
|
|
6
|
-
from typing import Any
|
|
7
|
-
|
|
8
|
-
from pydantic import ConfigDict, Field
|
|
9
|
-
from pydantic.dataclasses import dataclass
|
|
10
|
-
|
|
11
|
-
from hud.types import Trace
|
|
12
|
-
|
|
13
|
-
try:
|
|
14
|
-
import torch
|
|
15
|
-
except ImportError:
|
|
16
|
-
raise ImportError("uv tool install hud-python[rl] to use this module") from None
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
class TrainingSample(Trace):
    """A trace extended with the tensors consumed by one GRPO update.

    ``to_device`` moves the tensor-valued fields onto a target device in place
    and hands back the same instance so the call can be chained.
    """

    model_config = ConfigDict(arbitrary_types_allowed=True)
    # Tokenized inputs for the model's forward pass: the input tokens,
    # logit mask, and any other per-sample tensors.
    inputs: dict[str, torch.Tensor] = Field(default_factory=dict)
    old_logprobs: torch.Tensor | None = Field(default=None)
    ref_logprobs: torch.Tensor | None = Field(default=None)

    # Weighted advantage from the group calculation.
    advantage: torch.Tensor | None = Field(default=None)

    def to_device(self, device: torch.device) -> TrainingSample:
        """Relocate this sample's tensors to ``device`` and return ``self``.

        Values in ``inputs`` without a ``.to`` method are kept as-is; tensor
        fields that are ``None`` stay ``None``.
        """

        def relocate(tensor: torch.Tensor | None) -> torch.Tensor | None:
            if tensor is None:
                return None
            return tensor.to(device)

        moved_inputs: dict[str, torch.Tensor] = {}
        for key, item in self.inputs.items():
            if hasattr(item, "to"):
                moved_inputs[key] = item.to(device, non_blocking=True)
            else:
                moved_inputs[key] = item
        self.inputs = moved_inputs

        self.advantage = relocate(self.advantage)
        self.old_logprobs = relocate(self.old_logprobs)
        self.ref_logprobs = relocate(self.ref_logprobs)
        return self
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
@dataclass
class Metric:
    """Running accumulator for a named scalar metric.

    Every observed value is kept in ``values``. After each ``update``, ``mean``
    and ``std`` are refreshed over all values seen so far. ``std`` is the
    population standard deviation (variance divided by N).
    """

    name: str = Field(default="")
    mean: float = Field(default=0.0)
    std: float = Field(default=0.0)
    values: list[float] = Field(default_factory=list)

    def update(
        self, value: float | torch.Tensor | list[float] | list[int] | list[torch.Tensor]
    ) -> None:
        """Record ``value`` (a scalar or a list of scalars) and recompute stats.

        Tensors are converted with ``.item()`` so ``values`` only holds Python
        numbers. This assumes single-element tensors, since ``.item()`` raises
        otherwise. An empty list on a metric with no recorded values leaves
        ``mean``/``std`` untouched instead of dividing by zero.
        """
        new_items = value if isinstance(value, list) else [value]
        # Convert each element (not the list itself) so tensors inside a list
        # are stored as plain Python numbers, matching ``values: list[float]``.
        self.values.extend(
            item.item() if isinstance(item, torch.Tensor) else item for item in new_items
        )

        count = len(self.values)
        if count == 0:
            return
        self.mean = float(sum(self.values) / count)
        variance = sum((x - self.mean) ** 2 for x in self.values) / count
        self.std = math.sqrt(variance)
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
@dataclass
class TrainingMetrics:
    """Metrics for GRPO training (per training step).

    Each field is an independent ``Metric`` accumulator. ``update`` routes
    values to them by field name, and ``to_dict`` flattens them into
    ``<field>_mean`` / ``<field>_std`` entries.
    """

    # default_factory builds fresh Metric objects for every TrainingMetrics.
    # This avoids depending on the validator copying one shared instance
    # created at class-definition time.

    # Learner metrics
    grad_norm: Metric = Field(default_factory=Metric)
    loss: Metric = Field(default_factory=Metric)
    kl: Metric = Field(default_factory=Metric)
    reward: Metric = Field(default_factory=Metric)
    advantage: Metric = Field(default_factory=Metric)
    policy_ratio: Metric = Field(default_factory=Metric)
    tokens: Metric = Field(default_factory=Metric)
    entropy: Metric = Field(default_factory=Metric)

    # Computation metrics
    gpu_util: Metric = Field(default_factory=Metric)  # GPU utilization percentage
    gpu_memory: Metric = Field(default_factory=Metric)  # GPU memory usage in GB
    episode_time: Metric = Field(default_factory=Metric)  # Time to run episodes (actor)
    training_time: Metric = Field(default_factory=Metric)  # Time for gradient updates (learner)
    samples_per_second: Metric = Field(default_factory=Metric)  # Training throughput

    def update(self, metrics: dict[str, Any]) -> None:
        """Feed each value in ``metrics`` to the same-named field's ``Metric``.

        Keys that do not name a field are silently ignored.
        """
        for key, value in metrics.items():
            if key in self.__dataclass_fields__:
                getattr(self, key).update(value)

    def to_dict(self) -> dict[str, Any]:
        """Return ``{"<field>_mean": ..., "<field>_std": ...}`` for every field."""
        final_metrics: dict[str, Any] = {}
        for key in self.__dataclass_fields__:
            metric = getattr(self, key)
            final_metrics[f"{key}_mean"] = metric.mean
            final_metrics[f"{key}_std"] = metric.std
        return final_metrics
|
|
@@ -1,30 +0,0 @@
|
|
|
1
|
-
#!/bin/bash
# Start vLLM server with OpenAI-compatible API
#
# Serves Qwen/Qwen2.5-VL-3B-Instruct on GPU 1, port 8000, with LoRA enabled
# and runtime adapter loading allowed. Combined stdout/stderr is mirrored to
# vllm_debug.log in the current directory.

# Without pipefail the script's exit status would be tee's, hiding a vLLM failure.
set -o pipefail

echo "Starting vLLM server for Qwen2.5-VL-3B-Instruct..."

# Enable runtime LoRA adapter loading
export VLLM_ALLOW_RUNTIME_LORA_UPDATING=True

export TOKENIZERS_PARALLELISM=false
export VLLM_LOGGING_LEVEL=DEBUG
# Synchronous kernel launches: better error messages for CUDA errors, at a speed cost.
export CUDA_LAUNCH_BLOCKING=1

# Common vLLM server command
# Using CUDA_VISIBLE_DEVICES to put vLLM on GPU 1
# NOTE(review): --host 0.0.0.0 listens on all interfaces while the API key is a
# hard-coded placeholder; keep this to trusted/local networks.
# --chat-template is a relative path, so run this script from the directory
# containing chat_template.jinja.
CUDA_VISIBLE_DEVICES=1 uv run vllm serve \
    Qwen/Qwen2.5-VL-3B-Instruct \
    --api-key token-abc123 \
    --host 0.0.0.0 \
    --port 8000 \
    --tensor-parallel-size 1 \
    --trust-remote-code \
    --max-model-len 16384 \
    --enable-lora \
    --max-lora-rank 64 \
    --max-cpu-loras 4 \
    --enable-auto-tool-choice \
    --tool-call-parser hermes \
    --chat-template chat_template.jinja \
    --enable-log-requests \
    --uvicorn-log-level=debug 2>&1 | tee vllm_debug.log
|