hud-python 0.4.45__py3-none-any.whl → 0.5.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hud/__init__.py +27 -7
- hud/agents/__init__.py +70 -5
- hud/agents/base.py +238 -500
- hud/agents/claude.py +236 -247
- hud/agents/gateway.py +42 -0
- hud/agents/gemini.py +264 -0
- hud/agents/gemini_cua.py +324 -0
- hud/agents/grounded_openai.py +98 -100
- hud/agents/misc/integration_test_agent.py +51 -20
- hud/agents/misc/response_agent.py +48 -36
- hud/agents/openai.py +282 -296
- hud/agents/{openai_chat_generic.py → openai_chat.py} +63 -33
- hud/agents/operator.py +199 -0
- hud/agents/resolver.py +70 -0
- hud/agents/tests/conftest.py +133 -0
- hud/agents/tests/test_base.py +300 -622
- hud/agents/tests/test_base_runtime.py +233 -0
- hud/agents/tests/test_claude.py +381 -214
- hud/agents/tests/test_client.py +9 -10
- hud/agents/tests/test_gemini.py +369 -0
- hud/agents/tests/test_grounded_openai_agent.py +65 -50
- hud/agents/tests/test_openai.py +377 -140
- hud/agents/tests/test_operator.py +362 -0
- hud/agents/tests/test_resolver.py +192 -0
- hud/agents/tests/test_run_eval.py +179 -0
- hud/agents/types.py +148 -0
- hud/cli/__init__.py +493 -546
- hud/cli/analyze.py +43 -5
- hud/cli/build.py +699 -113
- hud/cli/debug.py +8 -5
- hud/cli/dev.py +889 -732
- hud/cli/eval.py +793 -667
- hud/cli/flows/dev.py +167 -0
- hud/cli/flows/init.py +191 -0
- hud/cli/flows/tasks.py +153 -56
- hud/cli/flows/templates.py +151 -0
- hud/cli/flows/tests/__init__.py +1 -0
- hud/cli/flows/tests/test_dev.py +126 -0
- hud/cli/init.py +60 -58
- hud/cli/pull.py +1 -1
- hud/cli/push.py +38 -13
- hud/cli/rft.py +311 -0
- hud/cli/rft_status.py +145 -0
- hud/cli/tests/test_analyze.py +5 -5
- hud/cli/tests/test_analyze_metadata.py +3 -2
- hud/cli/tests/test_analyze_module.py +120 -0
- hud/cli/tests/test_build.py +110 -8
- hud/cli/tests/test_build_failure.py +41 -0
- hud/cli/tests/test_build_module.py +50 -0
- hud/cli/tests/test_cli_init.py +6 -1
- hud/cli/tests/test_cli_more_wrappers.py +30 -0
- hud/cli/tests/test_cli_root.py +140 -0
- hud/cli/tests/test_convert.py +361 -0
- hud/cli/tests/test_debug.py +12 -10
- hud/cli/tests/test_dev.py +197 -0
- hud/cli/tests/test_eval.py +251 -0
- hud/cli/tests/test_eval_bedrock.py +51 -0
- hud/cli/tests/test_init.py +124 -0
- hud/cli/tests/test_main_module.py +11 -5
- hud/cli/tests/test_mcp_server.py +12 -100
- hud/cli/tests/test_push.py +1 -1
- hud/cli/tests/test_push_happy.py +74 -0
- hud/cli/tests/test_push_wrapper.py +23 -0
- hud/cli/tests/test_registry.py +1 -1
- hud/cli/tests/test_utils.py +1 -1
- hud/cli/{rl → utils}/celebrate.py +14 -12
- hud/cli/utils/config.py +18 -1
- hud/cli/utils/docker.py +130 -4
- hud/cli/utils/env_check.py +9 -9
- hud/cli/utils/git.py +136 -0
- hud/cli/utils/interactive.py +39 -5
- hud/cli/utils/metadata.py +70 -1
- hud/cli/utils/runner.py +1 -1
- hud/cli/utils/server.py +2 -2
- hud/cli/utils/source_hash.py +3 -3
- hud/cli/utils/tasks.py +4 -1
- hud/cli/utils/tests/__init__.py +0 -0
- hud/cli/utils/tests/test_config.py +58 -0
- hud/cli/utils/tests/test_docker.py +93 -0
- hud/cli/utils/tests/test_docker_hints.py +71 -0
- hud/cli/utils/tests/test_env_check.py +74 -0
- hud/cli/utils/tests/test_environment.py +42 -0
- hud/cli/utils/tests/test_git.py +142 -0
- hud/cli/utils/tests/test_interactive_module.py +60 -0
- hud/cli/utils/tests/test_local_runner.py +50 -0
- hud/cli/utils/tests/test_logging_utils.py +23 -0
- hud/cli/utils/tests/test_metadata.py +49 -0
- hud/cli/utils/tests/test_package_runner.py +35 -0
- hud/cli/utils/tests/test_registry_utils.py +49 -0
- hud/cli/utils/tests/test_remote_runner.py +25 -0
- hud/cli/utils/tests/test_runner_modules.py +52 -0
- hud/cli/utils/tests/test_source_hash.py +36 -0
- hud/cli/utils/tests/test_tasks.py +80 -0
- hud/cli/utils/version_check.py +258 -0
- hud/cli/{rl → utils}/viewer.py +2 -2
- hud/clients/README.md +12 -11
- hud/clients/__init__.py +4 -3
- hud/clients/base.py +166 -26
- hud/clients/environment.py +51 -0
- hud/clients/fastmcp.py +13 -6
- hud/clients/mcp_use.py +45 -15
- hud/clients/tests/test_analyze_scenarios.py +206 -0
- hud/clients/tests/test_protocol.py +9 -3
- hud/datasets/__init__.py +23 -20
- hud/datasets/loader.py +326 -0
- hud/datasets/runner.py +198 -105
- hud/datasets/tests/__init__.py +0 -0
- hud/datasets/tests/test_loader.py +221 -0
- hud/datasets/tests/test_utils.py +315 -0
- hud/datasets/utils.py +270 -90
- hud/environment/__init__.py +52 -0
- hud/environment/connection.py +258 -0
- hud/environment/connectors/__init__.py +33 -0
- hud/environment/connectors/base.py +68 -0
- hud/environment/connectors/local.py +177 -0
- hud/environment/connectors/mcp_config.py +137 -0
- hud/environment/connectors/openai.py +101 -0
- hud/environment/connectors/remote.py +172 -0
- hud/environment/environment.py +835 -0
- hud/environment/integrations/__init__.py +45 -0
- hud/environment/integrations/adk.py +67 -0
- hud/environment/integrations/anthropic.py +196 -0
- hud/environment/integrations/gemini.py +92 -0
- hud/environment/integrations/langchain.py +82 -0
- hud/environment/integrations/llamaindex.py +68 -0
- hud/environment/integrations/openai.py +238 -0
- hud/environment/mock.py +306 -0
- hud/environment/router.py +263 -0
- hud/environment/scenarios.py +620 -0
- hud/environment/tests/__init__.py +1 -0
- hud/environment/tests/test_connection.py +317 -0
- hud/environment/tests/test_connectors.py +205 -0
- hud/environment/tests/test_environment.py +593 -0
- hud/environment/tests/test_integrations.py +257 -0
- hud/environment/tests/test_local_connectors.py +242 -0
- hud/environment/tests/test_scenarios.py +1086 -0
- hud/environment/tests/test_tools.py +208 -0
- hud/environment/types.py +23 -0
- hud/environment/utils/__init__.py +35 -0
- hud/environment/utils/formats.py +215 -0
- hud/environment/utils/schema.py +171 -0
- hud/environment/utils/tool_wrappers.py +113 -0
- hud/eval/__init__.py +67 -0
- hud/eval/context.py +727 -0
- hud/eval/display.py +299 -0
- hud/eval/instrument.py +187 -0
- hud/eval/manager.py +533 -0
- hud/eval/parallel.py +268 -0
- hud/eval/task.py +372 -0
- hud/eval/tests/__init__.py +1 -0
- hud/eval/tests/test_context.py +178 -0
- hud/eval/tests/test_eval.py +210 -0
- hud/eval/tests/test_manager.py +152 -0
- hud/eval/tests/test_parallel.py +168 -0
- hud/eval/tests/test_task.py +291 -0
- hud/eval/types.py +65 -0
- hud/eval/utils.py +194 -0
- hud/patches/__init__.py +19 -0
- hud/patches/mcp_patches.py +308 -0
- hud/patches/warnings.py +54 -0
- hud/samples/browser.py +4 -4
- hud/server/__init__.py +2 -1
- hud/server/low_level.py +2 -1
- hud/server/router.py +164 -0
- hud/server/server.py +567 -80
- hud/server/tests/test_mcp_server_integration.py +11 -11
- hud/server/tests/test_mcp_server_more.py +1 -1
- hud/server/tests/test_server_extra.py +2 -0
- hud/settings.py +45 -3
- hud/shared/exceptions.py +36 -10
- hud/shared/hints.py +26 -1
- hud/shared/requests.py +15 -3
- hud/shared/tests/test_exceptions.py +40 -31
- hud/shared/tests/test_hints.py +167 -0
- hud/telemetry/__init__.py +20 -19
- hud/telemetry/exporter.py +201 -0
- hud/telemetry/instrument.py +165 -253
- hud/telemetry/tests/test_eval_telemetry.py +356 -0
- hud/telemetry/tests/test_exporter.py +258 -0
- hud/telemetry/tests/test_instrument.py +401 -0
- hud/tools/__init__.py +18 -2
- hud/tools/agent.py +223 -0
- hud/tools/apply_patch.py +639 -0
- hud/tools/base.py +54 -4
- hud/tools/bash.py +2 -2
- hud/tools/computer/__init__.py +36 -3
- hud/tools/computer/anthropic.py +2 -2
- hud/tools/computer/gemini.py +385 -0
- hud/tools/computer/hud.py +23 -6
- hud/tools/computer/openai.py +20 -21
- hud/tools/computer/qwen.py +434 -0
- hud/tools/computer/settings.py +37 -0
- hud/tools/edit.py +3 -7
- hud/tools/executors/base.py +4 -2
- hud/tools/executors/pyautogui.py +1 -1
- hud/tools/grounding/grounded_tool.py +13 -18
- hud/tools/grounding/grounder.py +10 -31
- hud/tools/grounding/tests/test_grounded_tool.py +26 -44
- hud/tools/jupyter.py +330 -0
- hud/tools/playwright.py +18 -3
- hud/tools/shell.py +308 -0
- hud/tools/tests/test_agent_tool.py +355 -0
- hud/tools/tests/test_apply_patch.py +718 -0
- hud/tools/tests/test_computer.py +4 -9
- hud/tools/tests/test_computer_actions.py +24 -2
- hud/tools/tests/test_jupyter_tool.py +181 -0
- hud/tools/tests/test_shell.py +596 -0
- hud/tools/tests/test_submit.py +85 -0
- hud/tools/tests/test_types.py +193 -0
- hud/tools/types.py +21 -1
- hud/types.py +194 -56
- hud/utils/__init__.py +2 -0
- hud/utils/env.py +67 -0
- hud/utils/hud_console.py +89 -18
- hud/utils/mcp.py +15 -58
- hud/utils/strict_schema.py +162 -0
- hud/utils/tests/test_init.py +1 -2
- hud/utils/tests/test_mcp.py +1 -28
- hud/utils/tests/test_pretty_errors.py +186 -0
- hud/utils/tests/test_tool_shorthand.py +154 -0
- hud/utils/tests/test_version.py +1 -1
- hud/utils/types.py +20 -0
- hud/version.py +1 -1
- hud_python-0.5.13.dist-info/METADATA +264 -0
- hud_python-0.5.13.dist-info/RECORD +305 -0
- {hud_python-0.4.45.dist-info → hud_python-0.5.13.dist-info}/WHEEL +1 -1
- hud/agents/langchain.py +0 -261
- hud/agents/lite_llm.py +0 -72
- hud/cli/rl/__init__.py +0 -180
- hud/cli/rl/config.py +0 -101
- hud/cli/rl/display.py +0 -133
- hud/cli/rl/gpu.py +0 -63
- hud/cli/rl/gpu_utils.py +0 -321
- hud/cli/rl/local_runner.py +0 -595
- hud/cli/rl/presets.py +0 -96
- hud/cli/rl/remote_runner.py +0 -463
- hud/cli/rl/rl_api.py +0 -150
- hud/cli/rl/vllm.py +0 -177
- hud/cli/rl/wait_utils.py +0 -89
- hud/datasets/parallel.py +0 -687
- hud/misc/__init__.py +0 -1
- hud/misc/claude_plays_pokemon.py +0 -292
- hud/otel/__init__.py +0 -35
- hud/otel/collector.py +0 -142
- hud/otel/config.py +0 -181
- hud/otel/context.py +0 -570
- hud/otel/exporters.py +0 -369
- hud/otel/instrumentation.py +0 -135
- hud/otel/processors.py +0 -121
- hud/otel/tests/__init__.py +0 -1
- hud/otel/tests/test_processors.py +0 -197
- hud/rl/README.md +0 -30
- hud/rl/__init__.py +0 -1
- hud/rl/actor.py +0 -176
- hud/rl/buffer.py +0 -405
- hud/rl/chat_template.jinja +0 -101
- hud/rl/config.py +0 -192
- hud/rl/distributed.py +0 -132
- hud/rl/learner.py +0 -637
- hud/rl/tests/__init__.py +0 -1
- hud/rl/tests/test_learner.py +0 -186
- hud/rl/train.py +0 -382
- hud/rl/types.py +0 -101
- hud/rl/utils/start_vllm_server.sh +0 -30
- hud/rl/utils.py +0 -524
- hud/rl/vllm_adapter.py +0 -143
- hud/telemetry/job.py +0 -352
- hud/telemetry/replay.py +0 -74
- hud/telemetry/tests/test_replay.py +0 -40
- hud/telemetry/tests/test_trace.py +0 -63
- hud/telemetry/trace.py +0 -158
- hud/utils/agent_factories.py +0 -86
- hud/utils/async_utils.py +0 -65
- hud/utils/group_eval.py +0 -223
- hud/utils/progress.py +0 -149
- hud/utils/tasks.py +0 -127
- hud/utils/tests/test_async_utils.py +0 -173
- hud/utils/tests/test_progress.py +0 -261
- hud_python-0.4.45.dist-info/METADATA +0 -552
- hud_python-0.4.45.dist-info/RECORD +0 -228
- {hud_python-0.4.45.dist-info → hud_python-0.5.13.dist-info}/entry_points.txt +0 -0
- {hud_python-0.4.45.dist-info → hud_python-0.5.13.dist-info}/licenses/LICENSE +0 -0
hud/rl/learner.py
DELETED
|
@@ -1,637 +0,0 @@
|
|
|
1
|
-
"""GRPO learner for vision-language and text models."""
|
|
2
|
-
|
|
3
|
-
from __future__ import annotations
|
|
4
|
-
|
|
5
|
-
import logging
|
|
6
|
-
import os
|
|
7
|
-
from typing import TYPE_CHECKING, Any
|
|
8
|
-
|
|
9
|
-
import torch
|
|
10
|
-
from peft import LoraConfig, get_peft_model
|
|
11
|
-
from torch.nn.parallel import DistributedDataParallel as DDP
|
|
12
|
-
from transformers import (
|
|
13
|
-
AutoModelForCausalLM,
|
|
14
|
-
AutoProcessor,
|
|
15
|
-
AutoTokenizer,
|
|
16
|
-
Qwen2_5_VLForConditionalGeneration,
|
|
17
|
-
)
|
|
18
|
-
|
|
19
|
-
try:
|
|
20
|
-
from liger_kernel.transformers import apply_liger_kernel_to_qwen2_5_vl # type: ignore
|
|
21
|
-
|
|
22
|
-
LIGER_AVAILABLE = True
|
|
23
|
-
except ImportError:
|
|
24
|
-
LIGER_AVAILABLE = False
|
|
25
|
-
|
|
26
|
-
try:
|
|
27
|
-
import bitsandbytes as bnb # type: ignore
|
|
28
|
-
|
|
29
|
-
BNB_AVAILABLE = True
|
|
30
|
-
except ImportError:
|
|
31
|
-
BNB_AVAILABLE = False
|
|
32
|
-
|
|
33
|
-
from contextlib import nullcontext
|
|
34
|
-
|
|
35
|
-
from hud.rl.distributed import (
|
|
36
|
-
get_local_rank,
|
|
37
|
-
get_world_size,
|
|
38
|
-
is_main_process,
|
|
39
|
-
)
|
|
40
|
-
from hud.rl.utils import (
|
|
41
|
-
batch_training_samples,
|
|
42
|
-
entropy_from_logits,
|
|
43
|
-
get_gpu_utilization,
|
|
44
|
-
get_memory_usage,
|
|
45
|
-
prepare_inputs,
|
|
46
|
-
)
|
|
47
|
-
from hud.utils.hud_console import HUDConsole
|
|
48
|
-
|
|
49
|
-
from .types import TrainingMetrics, TrainingSample
|
|
50
|
-
|
|
51
|
-
logger = logging.getLogger(__name__)
|
|
52
|
-
hud_console = HUDConsole(logger)
|
|
53
|
-
|
|
54
|
-
if TYPE_CHECKING:
|
|
55
|
-
from .config import Config
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
class GRPOLearner:
|
|
59
|
-
"""GRPO learning algorithm for Vision-Language Models (VLMs) and Text Models."""
|
|
60
|
-
|
|
61
|
-
def __init__(self, config: Config) -> None:
    """Record rank/device state and load the processor, policy and optimizer.

    Args:
        config: Training configuration. ``config.model.base_model`` is read
            here to decide whether the model is treated as vision-language.
    """
    self.config = config

    # Distributed topology, as reported by the hud.rl.distributed helpers.
    self.local_rank = get_local_rank()
    self.world_size = get_world_size()

    # One CUDA device per local rank when CUDA is present, otherwise CPU.
    if torch.cuda.is_available():
        device_spec = f"cuda:{self.local_rank}"
    else:
        device_spec = "cpu"
    self.device = torch.device(device_spec)

    # A model is considered vision-language when its name contains "VL".
    self.is_vl_model = "VL" in config.model.base_model

    # _load_models() returns (processor, policy, ref, optimizer).
    loaded = self._load_models()
    self.processor, self.policy, self.ref, self.optimizer = loaded

    # One TrainingMetrics entry is appended per update() call.
    self.metrics: list[TrainingMetrics] = []
|
|
75
|
-
|
|
76
|
-
def log(self, message: str) -> None:
    """Send *message* to the HUD console, prefixed with this process's local rank."""
    rank_tagged = f"[{self.local_rank}] {message}"
    hud_console.info_log(rank_tagged)
|
|
78
|
-
|
|
79
|
-
def _load_models(self) -> tuple[Any, Any, Any, Any]:
    """Load the processor, the LoRA-wrapped policy model and its optimizer.

    Returns:
        ``(processor, policy, ref, optimizer)``. ``ref`` is always ``None``
        here; no separate reference model is loaded by this method.

    Raises:
        ImportError, ValueError: Re-raised from ``from_pretrained`` when the
            configured attention implementation is anything other than
            ``"flash_attention_2"`` (only that case falls back to eager).
    """
    model_cfg = self.config.model

    # A model is treated as vision-language when its name contains "VL".
    is_vl_model = "VL" in model_cfg.base_model
    model_type = "Vision-Language" if is_vl_model else "Text"
    self.log(f"Loading {model_type} model: {model_cfg.base_model}")

    # Liger kernels are only patched in for the VL model; for text models a
    # use_liger request is a no-op.
    if model_cfg.use_liger and LIGER_AVAILABLE:
        if is_vl_model:
            self.log("Applying Liger kernel optimizations to Qwen2.5-VL")
            apply_liger_kernel_to_qwen2_5_vl(
                rope=True,  # Optimized RoPE
                rms_norm=True,  # Optimized RMSNorm
                swiglu=True,  # Optimized SwiGLU
                fused_linear_cross_entropy=True,  # Fused Linear+CrossEntropy for memory
            )
    elif model_cfg.use_liger and not LIGER_AVAILABLE:
        self.log(
            "Liger kernel requested but not installed. Install with: pip install liger-kernel"
        )

    # Processor for VL models, plain tokenizer for text models.
    if is_vl_model:
        # Some environments require remote code for Qwen2.5-VL processors
        processor = AutoProcessor.from_pretrained(
            model_cfg.base_model,
            min_pixels=model_cfg.min_pixels,
            max_pixels=model_cfg.max_pixels,
            trust_remote_code=True,
        )
    else:
        processor = AutoTokenizer.from_pretrained(model_cfg.base_model)

    attn_implementation = model_cfg.attn_implementation
    model_class = Qwen2_5_VLForConditionalGeneration if is_vl_model else AutoModelForCausalLM

    try:
        policy = model_class.from_pretrained(
            model_cfg.base_model,
            torch_dtype=torch.bfloat16,
            attn_implementation=attn_implementation,
            trust_remote_code=True,
        )
        self.log(f"Using {attn_implementation} for attention")
    except (ImportError, ValueError) as e:
        # Only fall back when flash_attention_2 was requested; any other
        # failure is a genuine error and is propagated unchanged.
        if attn_implementation != "flash_attention_2":
            raise
        self.log(f"Flash Attention 2 not available ({e}), using eager attention")
        policy = model_class.from_pretrained(
            model_cfg.base_model,
            torch_dtype=torch.bfloat16,
            attn_implementation="eager",
            # Keep the fallback consistent with the primary load above so a
            # model that needs remote code does not fail only on this path.
            trust_remote_code=True,
        )

    policy = policy.to(self.device)  # type: ignore

    if model_cfg.gradient_checkpointing:
        policy.gradient_checkpointing_enable()
        self.log("Gradient checkpointing enabled for memory efficiency")

    # Attach LoRA adapters; the KV cache is switched off before wrapping.
    lora_config = LoraConfig(
        r=model_cfg.lora_r,
        lora_alpha=model_cfg.lora_alpha,
        lora_dropout=model_cfg.lora_dropout,
        task_type="CAUSAL_LM",
        bias="none",
        target_modules=list(model_cfg.target_modules),
    )
    policy.config.use_cache = False
    policy = get_peft_model(policy, lora_config)

    # Wrap with DDP only when more than one process participates.
    if self.world_size > 1:
        policy = DDP(
            policy,
            device_ids=[self.local_rank],
            output_device=self.local_rank,
            broadcast_buffers=False,
            find_unused_parameters=True,
        )
        self.log("Wrapped model (find_unused_parameters=True)")

    # The optimizer is built from the unwrapped module's trainable parameters.
    base_model = policy.module if hasattr(policy, "module") else policy
    trainable_params = [p for _, p in base_model.named_parameters() if p.requires_grad]  # type: ignore

    training_cfg = self.config.training
    if training_cfg.use_8bit_optimizer and BNB_AVAILABLE:
        hud_console.info("Using 8-bit AdamW optimizer from bitsandbytes")
        optimizer = bnb.optim.AdamW8bit(
            trainable_params,
            lr=training_cfg.lr,
            betas=training_cfg.adam_betas,
            eps=training_cfg.adam_eps,
        )
    else:
        if training_cfg.use_8bit_optimizer:
            # Mirror the Liger notice: do not silently ignore the request.
            self.log(
                "8-bit optimizer requested but bitsandbytes is not installed. "
                "Install with: pip install bitsandbytes"
            )
        self.log("Using standard FP32 AdamW optimizer")
        optimizer = torch.optim.AdamW(
            trainable_params,
            lr=training_cfg.lr,
            betas=training_cfg.adam_betas,
            eps=training_cfg.adam_eps,
        )

    self.log(f"Optimizer: {type(optimizer).__name__}")
    num_params = sum(p.numel() for p in trainable_params)
    self.log(f"Number of trainable parameters: {num_params:,}")

    return processor, policy, None, optimizer
|
|
200
|
-
|
|
201
|
-
def prepare_groups(
    self,
    samples: list[TrainingSample],
) -> list[list[TrainingSample]]:
    """Tokenize samples, attach old/reference log-probs and split into update groups.

    Args:
        samples: Samples to train on; each is copied before being modified.

    Returns:
        A list of groups. When ``training.update_after_group`` is set, the
        processed batch is cut into consecutive slices of ``group_size``
        entries; otherwise a single group holding the whole batch is returned.
    """
    batch = []
    for sample in samples:
        inputs = prepare_inputs(sample, self.processor)
        if (
            not inputs
            or "input_ids" not in inputs
            or inputs.get("input_ids", torch.tensor([])).numel() == 0
        ):
            hud_console.warning_log("Sample has invalid inputs, using dummy values")
            # Minimal placeholder so the batch keeps one entry per sample.
            # The all-False assistant mask is one shorter than the sequence.
            inputs = {
                "input_ids": torch.zeros(1, 2, dtype=torch.long),
                "attention_mask": torch.ones(1, 2, dtype=torch.long),
                "assistant_mask": torch.zeros(1, 1, dtype=torch.bool),
            }
        elif "assistant_mask" not in inputs:
            hud_console.warning_log("Sample missing assistant_mask, creating zero mask")
            seq_len = inputs["input_ids"].shape[-1]
            inputs["assistant_mask"] = torch.zeros(
                inputs["input_ids"].shape[0], seq_len - 1, dtype=torch.bool
            )

        new_sample = TrainingSample(**sample.model_dump())
        new_sample.inputs = inputs
        new_sample.advantage = sample.advantage
        batch.append(new_sample)

    with hud_console.progress("Processing batch of traces...") as progress, torch.no_grad():
        # Pass 1: log-probs with the adapters active.
        for i, sample in enumerate(batch):
            if is_main_process():
                progress.update(f"Processing batch of traces... {i}/{len(batch)}")
            if sample.inputs:
                # NOTE(review): presumably to_device() moves tensors in place and
                # returns the same sample, so the batch entry is updated — confirm.
                sample = sample.to_device(self.device)
                sample.old_logprobs, _ = self.compute_logprobs(self.policy, sample.inputs)
                # Free GPU memory for this sample immediately
                sample.to_device(torch.device("cpu"))

        # Pass 2: reference log-probs from the same model with adapters disabled.
        policy_module = self.policy.module if hasattr(self.policy, "module") else self.policy
        with policy_module.disable_adapter():
            for i, sample in enumerate(batch):
                if is_main_process():
                    progress.update(f"Processing batch of traces... {i}/{len(batch)}")
                if sample.inputs:
                    # Move back to GPU for reference computation, then free
                    sample = sample.to_device(self.device)
                    sample.ref_logprobs, _ = self.compute_logprobs(self.policy, sample.inputs)
                    sample.to_device(torch.device("cpu"))

    hud_console.info_log("Creating mini-batches...")
    group_size = self.config.training.group_size
    processed_batch = []
    if not self.config.training.accumulate_over_minibatches:
        # Minibatches control the batch size of the forward pass to the model;
        # once samples are merged, a group is counted in minibatches.
        mb_size = self.config.training.mini_batch_size
        # Clamp to 1: when mini_batch_size exceeds group_size the integer
        # division yields 0, and range() below rejects a zero step.
        group_size = max(1, group_size // mb_size)
        for i in range(0, len(batch), mb_size):
            processed_batch.extend(batch_training_samples(batch[i : i + mb_size]))
    else:
        processed_batch = batch

    for sample in processed_batch:
        sample.to_device(torch.device("cpu"))

    # Convert to grouped batches (if updating the model after each task group)
    if self.config.training.update_after_group:
        return [
            processed_batch[i : i + group_size]
            for i in range(0, len(processed_batch), group_size)
        ]
    return [processed_batch]
|
|
280
|
-
|
|
281
|
-
def update(self, samples: list[TrainingSample]) -> TrainingMetrics:
    """Run gradient updates over *samples* and return the metrics for this call.

    For every epoch and every group from ``prepare_groups``, gradients are
    accumulated over the group's minibatches, clipped, and applied in one
    optimizer step. If any rank hit CUDA OOM within a group, the step for
    that group is skipped on all ranks.

    Args:
        samples: Training samples to update on.

    Returns:
        The ``TrainingMetrics`` object appended to ``self.metrics`` for this call.
    """
    import time

    training_start_time = time.time()

    # Always create metrics for synchronization
    self.metrics.append(TrainingMetrics())
    metrics = self.metrics[-1]

    groups = self.prepare_groups(samples)
    self.log(f"Updating over {len(groups)} groups")

    with hud_console.progress("Gradient update...") as progress:
        for epoch in range(self.config.training.epochs):  # Do not accumulate across epochs
            progress.update(f"Training epoch {epoch + 1}/{self.config.training.epochs}")
            for group_idx, group in enumerate(groups):  # Do not accumulate across "groups"
                self.optimizer.zero_grad(set_to_none=True)

                debug_per_group = ""
                grad_accum_steps = len(group)
                # Set to 1 on OOM; max-reduced across ranks after the group.
                global_skip = torch.zeros(1, device=self.device)

                for s_idx, sample_minibatch in enumerate(group):
                    # Defer DDP gradient sync until the group's last minibatch.
                    if s_idx < len(group) - 1 and self.world_size > 1:
                        ddp_ctx = self.policy.no_sync()
                    else:
                        ddp_ctx = nullcontext()

                    with ddp_ctx, torch.autocast(device_type="cuda", dtype=torch.bfloat16):
                        try:
                            loss = self.compute_loss(sample_minibatch) / grad_accum_steps
                            debug_per_group += f"l{s_idx}:{round(loss.item(), 3)!s} "
                            loss.backward()
                        except torch.cuda.OutOfMemoryError:
                            hud_console.warning_log(
                                f"{group_idx} CUDA OOM for {sample_minibatch.inputs['input_ids'].numel()} tokens; skipping minibatch"  # noqa: E501
                            )
                            # Dummy backward to keep DDP happy: a zero-valued
                            # scalar that still touches every parameter. The
                            # builtin sum() is required here; torch.sum() does
                            # not accept a generator and would raise TypeError.
                            dummy = sum(p.sum() for p in self.policy.parameters()) * 0.0  # type: ignore
                            debug_per_group += f"o{s_idx}:{round(dummy.item(), 3)!s} "
                            dummy.backward()
                            # mark global skip if OOM
                            global_skip.fill_(1)
                            continue

                    if torch.cuda.is_available():
                        torch.cuda.empty_cache()

                # After minibatches loop, sync skip across ranks
                if torch.distributed.is_initialized():
                    torch.distributed.all_reduce(global_skip, op=torch.distributed.ReduceOp.MAX)
                skip_any = bool(global_skip.item())

                if skip_any:
                    self.log(f"G[{group_idx}] {debug_per_group} N/A (skipped)")
                    continue

                grad_norm = torch.nn.utils.clip_grad_norm_(
                    self.policy.parameters(),
                    self.config.training.grad_clip,
                    error_if_nonfinite=True,
                )
                self.optimizer.step()

                grad_norm_value = (
                    grad_norm.item() if isinstance(grad_norm, torch.Tensor) else float(grad_norm)
                )
                debug_per_group += f"g:{round(grad_norm_value, 3)!s}"
                self.log(f"G[{group_idx}] {debug_per_group}")

                metrics.update({"grad_norm": grad_norm_value})

    # Wall-clock time covers prepare_groups() as well as the optimizer steps.
    training_time = time.time() - training_start_time
    # NOTE(review): this is a config-derived estimate, not a count of the
    # samples actually processed — confirm it is the intended throughput basis.
    total_samples = (
        len(groups) * self.config.training.group_size * self.config.training.mini_batch_size
    )
    samples_per_second = total_samples / training_time if training_time > 0 else 0.0

    metrics.update(
        {
            "training_time": training_time,
            "samples_per_second": samples_per_second,
        }
    )

    return metrics
|
|
388
|
-
|
|
389
|
-
def compute_loss(self, sample: TrainingSample) -> torch.Tensor:
    """Compute the GRPO loss for one (mini)batch sample.

    The loss is the clipped policy-ratio term plus ``kl_beta`` times a KL
    estimate against the reference log-probs plus ``entropy_beta`` times the
    policy entropy, aggregated according to ``ppo_mode`` and ``token_agg``.

    Args:
        sample: Sample carrying ``inputs``, ``old_logprobs``, ``ref_logprobs``
            and ``advantage``. It is moved to ``self.device`` for the
            computation and back to CPU before returning.

    Returns:
        The scalar loss tensor.

    Raises:
        ValueError: If ``old_logprobs``, ``ref_logprobs`` or ``advantage`` is None.
    """
    cfg = self.config.training
    # Record into the metrics of the current update() call when there is one.
    metrics = self.metrics[-1] if self.metrics else TrainingMetrics()

    sample.to_device(self.device)

    pol_logp, pol_entropy = self.compute_logprobs(self.policy, sample.inputs)

    sanity_check(sample, pol_logp, sample.old_logprobs, sample.ref_logprobs)

    metrics.update(
        {
            "gpu_util": get_gpu_utilization(),
            "gpu_memory": get_memory_usage(),
        }
    )
    self.log(f"GPU Util: {get_gpu_utilization():.1f}% | Memory: {get_memory_usage():.2f} GB")

    old_logp = sample.old_logprobs
    ref_logp = sample.ref_logprobs
    if old_logp is None or ref_logp is None or sample.advantage is None:
        raise ValueError("old_logp, ref_logp, or sample.advantage is None")

    # Boolean mask selecting assistant tokens.
    mask = sample.inputs["assistant_mask"]
    per_trace = cfg.ppo_mode == "per_trace"

    if per_trace:
        # Replace per-token values by their masked mean over dim 1.
        # NOTE(review): the torch.where calls below still use the un-reduced
        # mask, so they rely on broadcasting against these reduced tensors —
        # confirm this is intended for batch sizes above 1.
        mask_f = mask.float()
        counts = mask.sum(dim=1).clamp_min(1.0)
        pol_logp = (pol_logp * mask_f).sum(dim=1) / counts
        pol_entropy = (pol_entropy * mask_f).sum(dim=1) / counts
        old_logp = (old_logp * mask_f).sum(dim=1) / counts
        ref_logp = (ref_logp * mask_f).sum(dim=1) / counts

    # Policy ratio vs. the old policy; log-ratio is zeroed off-mask and clamped.
    log_ratio = torch.where(mask, pol_logp - old_logp, torch.zeros_like(pol_logp))
    ratio_tok = torch.exp(log_ratio.clamp(-20.0, 20.0))

    # Shape the advantage so it broadcasts against ratio_tok.
    if ratio_tok.dim() == 2:
        advantage = sample.advantage.view(-1, 1)
    else:
        advantage = sample.advantage.squeeze(-1)

    lower_bound = 1 - cfg.top_eps
    upper_bound = 1 + cfg.bottom_eps
    unclipped = ratio_tok * advantage
    clipped = torch.clamp(ratio_tok, lower_bound, upper_bound) * advantage
    policy_term = -torch.minimum(unclipped, clipped)

    # KL estimate vs. the reference: rho - log(rho) - 1, which is zero at rho == 1.
    log_rho = torch.where(mask, pol_logp - ref_logp, torch.zeros_like(pol_logp))
    rho_tok = torch.exp(log_rho.clamp(-20.0, 20.0))
    kl_approx = rho_tok - torch.log(rho_tok) - 1

    total_loss = policy_term + cfg.kl_beta * kl_approx + cfg.entropy_beta * pol_entropy

    # Reduce to a scalar.
    use_mean = cfg.token_agg == "mean"  # noqa: S105
    if per_trace:
        total_loss = total_loss.mean() if use_mean else total_loss.sum()
    elif use_mean:
        total_loss = (total_loss * mask).sum() / mask.sum().clamp_min(1.0)
    else:
        total_loss = (total_loss * mask).sum()

    # Metrics are averaged over masked (assistant) tokens only.
    mask_count = mask.sum().clamp_min(1.0)

    def masked_mean(values: torch.Tensor, fallback: float) -> float:
        denom = mask_count.item()
        if denom > 0:
            return (values * mask).sum().item() / denom
        return fallback

    metrics.update(
        {
            "policy_ratio": masked_mean(ratio_tok, 1.0),
            "kl": masked_mean(kl_approx, 0.0),
            "entropy": masked_mean(pol_entropy, 0.0),
            "tokens": sample.inputs["input_ids"].numel(),
            "loss": total_loss.item(),
        }
    )

    sample.to_device(torch.device("cpu"))

    return total_loss
|
|
484
|
-
|
|
485
|
-
def compute_logprobs(self, model: Any, inputs: Any) -> tuple[torch.Tensor, torch.Tensor]:
    """Compute per-token log probabilities and entropies via the model.

    The model is called with every entry of ``inputs`` except
    ``"assistant_mask"``. Its logits are divided by
    ``self.config.actor.temperature`` and shifted by one position so that
    position ``t`` scores token ``t + 1`` of ``inputs["input_ids"]``.

    Args:
        model: Callable accepting the model inputs as keyword arguments and
            returning an object with a ``logits`` attribute.
        inputs: Mapping holding at least ``"input_ids"`` and
            ``"assistant_mask"``.

    Returns:
        ``(token_log_probs, entropy)``. ``entropy`` has the same shape as
        ``token_log_probs`` and is zero wherever ``assistant_mask`` is false.
        If the forward pass raises ``IndexError`` or ``RuntimeError``, a
        warning is logged and zero-valued dummy tensors are returned instead.
    """
    try:
        model_inputs = {k: v for k, v in inputs.items() if k != "assistant_mask"}
        out = model(**model_inputs)

        logits = out.logits / self.config.actor.temperature

        targets = inputs["input_ids"][:, 1:]

        # Align logits to predict next token: drop the last position
        next_logits = logits[:, :-1, :]

        token_log_probs = _selective_log_softmax(next_logits, targets)

        # Compute entropy only for assistant tokens to save memory
        assistant_mask = inputs["assistant_mask"]
        entropy = torch.zeros_like(token_log_probs)
        if assistant_mask.any():
            entropy[assistant_mask] = entropy_from_logits(next_logits[assistant_mask])

        return token_log_probs, entropy
    except (IndexError, RuntimeError) as e:
        # Handle empty inputs or DDP errors
        hud_console.warning_log(f"Error in compute_logprobs: {e}. Returning dummy values.")
        # Return dummy values that match expected shapes. The sequence length is
        # clamped at zero: a zero-length input would otherwise give -1, and
        # torch.zeros rejects negative dimensions, crashing this fallback.
        seq_len = max(inputs["input_ids"].shape[1] - 1, 0) if "input_ids" in inputs else 0
        batch_size = inputs["input_ids"].shape[0] if "input_ids" in inputs else 1
        # Create dummy tensors that still participate in autograd so backward doesn't fail
        try:
            # Touch params to build a graph
            param_sum = torch.sum(next(self.policy.parameters()))
            base = param_sum * 0.0
        except StopIteration:
            base = torch.tensor(0.0, device=self.device)
        dummy_logprobs = (
            base + torch.zeros(batch_size, seq_len, device=self.device)
        ).requires_grad_(True)
        dummy_entropy = (
            base + torch.zeros(batch_size, seq_len, device=self.device)
        ).requires_grad_(True)
        return dummy_logprobs, dummy_entropy
|
|
530
|
-
|
|
531
|
-
def save(self, path: str) -> None:
    """Write the current policy checkpoint into ``path``.

    Does nothing unless ``is_main_process()`` is true. The directory is
    created if it is missing.
    """
    if not is_main_process():
        return

    os.makedirs(path, exist_ok=True)
    # A DDP-wrapped policy exposes the underlying model as ``.module``
    policy = self.policy
    unwrapped = policy.module if hasattr(policy, "module") else policy
    unwrapped.save_pretrained(path)
    self.log(f"Saved checkpoint to {path}")
|
|
539
|
-
|
|
540
|
-
def load(self, path: str) -> None:
    """Announce that a policy checkpoint is being loaded from ``path``.

    Only the log line is emitted; no weights are restored by this method.
    Reloading the LoRA weights is still to be implemented, and how to do it
    depends on the PEFT version.
    """
    announcement = f"Loading checkpoint from {path}"
    self.log(announcement)
|
|
545
|
-
|
|
546
|
-
|
|
547
|
-
def sanity_check(
    sample: TrainingSample,
    pol_logp: torch.Tensor,
    old_logp: torch.Tensor | None,
    ref_logp: torch.Tensor | None,
) -> None:
    """Run diagnostic checks on a sample's per-token log-probabilities.

    Problems are reported through ``hud_console.warning_log``; nothing is
    returned and no input is modified. All tensor work runs under
    ``torch.no_grad()``.

    Args:
        sample: Object whose ``inputs`` mapping holds ``"assistant_mask"`` and,
            optionally, a 2-D ``"attention_mask"``.
        pol_logp: 2-D tensor of current-policy log-probabilities.
        old_logp: Old-policy log-probabilities, or ``None``.
        ref_logp: Reference-policy log-probabilities, or ``None``.

    Raises:
        AssertionError: If ``"assistant_mask"`` is missing from
            ``sample.inputs``, or if ``old_logp``, ``ref_logp`` or the mask do
            not have the same shape as ``pol_logp``.
    """
    assert "assistant_mask" in sample.inputs  # noqa: S101
    m = sample.inputs["assistant_mask"]
    # Every check below compares against old/ref values, so skip when either is absent
    if old_logp is None or ref_logp is None:
        return
    with torch.no_grad():
        B, K = pol_logp.shape
        assert old_logp.shape == (B, K), "old_logp shape mismatch"  # noqa: S101
        assert ref_logp.shape == (B, K), "ref_logp shape mismatch"  # noqa: S101
        assert m.shape == (B, K), "assistant_mask shape mismatch"  # noqa: S101

        # Check mask is subset of attention_mask[:, 1:]
        att = sample.inputs.get("attention_mask", None)
        if att is not None and att.dim() == 2:
            att_shift = att[:, 1:].bool()
            bad = (m & ~att_shift).sum().item()
            if bad > 0:
                hud_console.warning_log(f"assistant_mask overlaps padding: {bad} tokens")

        # Finiteness on masked entries only
        def _check_finite(name: str, t: torch.Tensor) -> None:
            sel = t[m]
            if sel.numel() == 0:
                hud_console.warning_log(f"{name} empty under mask")
                return
            finite = torch.isfinite(sel)
            if finite.sum() < sel.numel():
                hud_console.warning_log(
                    f"{name} non-finite: {((~finite).sum().item())}/{sel.numel()}"
                )

        _check_finite("pol_logp", pol_logp)
        _check_finite("old_logp", old_logp)
        _check_finite("ref_logp", ref_logp)

        # Log-probabilities should be <= 0 (log-softmax)
        if (pol_logp[m] > 1e-6).any():
            hud_console.warning_log("pol_logp has positive values under mask")

        # Precompute masked deltas and ratios for diagnostics (before exp)
        masked_log_ratio = torch.zeros_like(pol_logp)
        masked_log_ratio[m] = (pol_logp - old_logp)[m]
        masked_log_rho = torch.zeros_like(pol_logp)
        masked_log_rho[m] = (pol_logp - ref_logp)[m]

        _check_finite("log_ratio(masked)", masked_log_ratio)
        _check_finite("log_rho(masked)", masked_log_rho)

        # Ratios after clamp (diagnostic only)
        ratio_diag = torch.zeros_like(pol_logp)
        rho_diag = torch.zeros_like(pol_logp)
        ratio_diag[m] = torch.exp(masked_log_ratio[m].clamp(-20.0, 20.0))
        rho_diag[m] = torch.exp(masked_log_rho[m].clamp(-20.0, 20.0))
        _check_finite("ratio_tok(masked)", ratio_diag)
        _check_finite("rho_tok(masked)", rho_diag)
|
|
608
|
-
|
|
609
|
-
|
|
610
|
-
def _selective_log_softmax(
|
|
611
|
-
logits_bt_v: torch.Tensor,
|
|
612
|
-
index_bt: torch.Tensor,
|
|
613
|
-
) -> torch.Tensor:
|
|
614
|
-
"""Gather log softmax for selected indices with reduced peak memory.
|
|
615
|
-
|
|
616
|
-
Uses logsumexp subtraction for float32/64; falls back to per-row
|
|
617
|
-
log_softmax for bf16/fp16.
|
|
618
|
-
logits_bt_v: [B, T, V]
|
|
619
|
-
index_bt: [B, T]
|
|
620
|
-
Returns: [B, T]
|
|
621
|
-
"""
|
|
622
|
-
if logits_bt_v.dtype in (torch.float32, torch.float64):
|
|
623
|
-
# Compute logsumexp per [B, T] in a loop over batch to reduce
|
|
624
|
-
# peak from B*T*V to T*V
|
|
625
|
-
logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits_bt_v])
|
|
626
|
-
selected_logits = torch.gather(logits_bt_v, dim=-1, index=index_bt.unsqueeze(-1)).squeeze(
|
|
627
|
-
-1
|
|
628
|
-
)
|
|
629
|
-
return selected_logits - logsumexp_values
|
|
630
|
-
# Reduced precision: numerically stable route using per-row log_softmax
|
|
631
|
-
token_logprobs_rows: list[torch.Tensor] = []
|
|
632
|
-
for logits_row, index_row in zip(logits_bt_v, index_bt, strict=True):
|
|
633
|
-
logprobs_row = logits_row.log_softmax(dim=-1)
|
|
634
|
-
token_logprobs_rows.append(
|
|
635
|
-
torch.gather(logprobs_row, dim=-1, index=index_row.unsqueeze(-1)).squeeze(-1)
|
|
636
|
-
)
|
|
637
|
-
return torch.stack(token_logprobs_rows)
|
hud/rl/tests/__init__.py
DELETED
|
@@ -1 +0,0 @@
|
|
|
1
|
-
"""Tests for RL module."""
|