hud-python 0.4.28__py3-none-any.whl → 0.4.30__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of hud-python might be problematic.
- hud/__init__.py +2 -1
- hud/agents/base.py +81 -45
- hud/agents/claude.py +8 -4
- hud/agents/openai_chat_generic.py +66 -40
- hud/agents/tests/test_base.py +0 -4
- hud/agents/tests/test_openai.py +1 -1
- hud/cli/__init__.py +182 -52
- hud/cli/dev.py +8 -9
- hud/cli/eval.py +317 -119
- hud/cli/flows/__init__.py +0 -0
- hud/cli/flows/tasks.py +0 -0
- hud/cli/get.py +160 -0
- hud/cli/rl/__init__.py +567 -71
- hud/cli/rl/config.py +94 -0
- hud/cli/rl/display.py +133 -0
- hud/cli/rl/gpu.py +63 -0
- hud/cli/rl/gpu_utils.py +318 -0
- hud/cli/rl/presets.py +96 -0
- hud/cli/rl/remote_runner.py +347 -0
- hud/cli/rl/rl_api.py +150 -0
- hud/cli/rl/vllm.py +177 -0
- hud/cli/tests/test_analyze_metadata.py +0 -1
- hud/cli/utils/tasks.py +26 -0
- hud/clients/base.py +21 -23
- hud/clients/mcp_use.py +36 -44
- hud/clients/tests/test_mcp_use_retry.py +10 -10
- hud/datasets/__init__.py +4 -3
- hud/datasets/{execution/parallel.py → parallel.py} +1 -1
- hud/datasets/{execution/runner.py → runner.py} +1 -1
- hud/datasets/utils.py +1 -1
- hud/native/comparator.py +6 -6
- hud/native/tests/test_comparator.py +8 -8
- hud/native/tests/test_native_init.py +13 -11
- hud/otel/config.py +1 -1
- hud/otel/instrumentation.py +35 -0
- hud/rl/README.md +30 -0
- hud/rl/__init__.py +1 -0
- hud/rl/actor.py +174 -0
- hud/rl/buffer.py +371 -0
- hud/rl/chat_template.jinja +101 -0
- hud/rl/config.py +184 -0
- hud/rl/distributed.py +95 -0
- hud/rl/learner.py +589 -0
- hud/rl/tests/__init__.py +1 -0
- hud/rl/tests/test_learner.py +171 -0
- hud/rl/train.py +354 -0
- hud/rl/types.py +101 -0
- hud/rl/utils/start_vllm_server.sh +30 -0
- hud/rl/utils.py +524 -0
- hud/rl/vllm_adapter.py +125 -0
- hud/settings.py +6 -0
- hud/telemetry/__init__.py +2 -1
- hud/telemetry/job.py +46 -3
- hud/telemetry/tests/test_trace.py +3 -3
- hud/telemetry/trace.py +85 -13
- hud/tools/tests/test_computer.py +3 -3
- hud/tools/tests/test_computer_actions.py +1 -1
- hud/types.py +123 -2
- hud/utils/group_eval.py +223 -0
- hud/utils/hud_console.py +113 -13
- hud/utils/tasks.py +119 -0
- hud/utils/tests/test_version.py +1 -1
- hud/version.py +1 -1
- {hud_python-0.4.28.dist-info → hud_python-0.4.30.dist-info}/METADATA +20 -2
- {hud_python-0.4.28.dist-info → hud_python-0.4.30.dist-info}/RECORD +68 -48
- hud/cli/hf.py +0 -406
- hud/cli/rl/README.md +0 -243
- hud/cli/rl/init.py +0 -370
- hud/cli/rl/pod.py +0 -501
- hud/cli/rl/ssh.py +0 -322
- hud/cli/rl/train.py +0 -562
- hud/cli/rl/utils.py +0 -165
- hud/datasets/execution/__init__.py +0 -13
- hud/datasets/task.py +0 -116
- {hud_python-0.4.28.dist-info → hud_python-0.4.30.dist-info}/WHEEL +0 -0
- {hud_python-0.4.28.dist-info → hud_python-0.4.30.dist-info}/entry_points.txt +0 -0
- {hud_python-0.4.28.dist-info → hud_python-0.4.30.dist-info}/licenses/LICENSE +0 -0
hud/rl/distributed.py
ADDED
@@ -0,0 +1,95 @@
"""Distributed training utilities for GRPO."""

from __future__ import annotations

import os
from typing import Any

import torch
import torch.distributed as dist


def setup_distributed() -> None:
    """Initialize distributed training environment."""
    if "RANK" in os.environ and int(os.environ["WORLD_SIZE"]) > 1:
        # Set device for this process
        local_rank = int(os.environ["LOCAL_RANK"])
        torch.cuda.set_device(local_rank)

        # Initialize process group
        dist.init_process_group("nccl")


def get_local_rank() -> int:
    """Get local rank from environment."""
    return int(os.environ.get("LOCAL_RANK", 0))


def get_global_rank() -> int:
    """Get global rank from environment."""
    return int(os.environ.get("RANK", 0))


def get_world_size() -> int:
    """Get world size from environment."""
    return int(os.environ.get("WORLD_SIZE", 1))


def cleanup_distributed() -> None:
    """Clean up distributed environment."""
    if dist.is_initialized():
        dist.destroy_process_group()


def is_main_process() -> bool:
    """Check if this is the main process (rank 0)."""
    if not dist.is_initialized():
        return True
    return dist.get_rank() == 0


def synchronize() -> None:
    """Synchronize all processes."""
    if dist.is_initialized():
        dist.barrier()


def all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor:
    """Average a tensor across all processes."""
    if not dist.is_initialized():
        return tensor

    world_size = dist.get_world_size()
    dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
    tensor /= world_size
    return tensor


def broadcast_object(obj: Any, src: int = 0) -> Any:
    """Broadcast a Python object from src rank to all ranks."""
    if not dist.is_initialized():
        return obj

    obj_list = [obj] if dist.get_rank() == src else [None]
    dist.broadcast_object_list(obj_list, src=src)
    return obj_list[0]


def gather_tensors(tensor: torch.Tensor) -> list[torch.Tensor] | None:
    """Gather tensors from all ranks to rank 0.

    Returns:
        List of tensors on rank 0, None on other ranks
    """
    if not dist.is_initialized():
        return [tensor]

    world_size = dist.get_world_size()

    if dist.get_rank() == 0:
        gathered = [torch.zeros_like(tensor) for _ in range(world_size)]
        dist.gather(tensor, gathered, dst=0)
        return gathered
    else:
        dist.gather(tensor, None, dst=0)
        return None
hud/rl/learner.py
ADDED
@@ -0,0 +1,589 @@
"""GRPO learner for vision-language and text models."""

from __future__ import annotations

import logging
import os
from typing import TYPE_CHECKING, Any

import torch
import torch.nn.functional as F
from peft import LoraConfig, get_peft_model
from torch.nn.parallel import DistributedDataParallel as DDP
from transformers import (
    AutoModelForCausalLM,
    AutoProcessor,
    AutoTokenizer,
    Qwen2_5_VLForConditionalGeneration,
)

try:
    from liger_kernel.transformers import apply_liger_kernel_to_qwen2_5_vl  # type: ignore

    LIGER_AVAILABLE = True
except ImportError:
    LIGER_AVAILABLE = False

try:
    import bitsandbytes as bnb  # type: ignore

    BNB_AVAILABLE = True
except ImportError:
    BNB_AVAILABLE = False

from contextlib import nullcontext

from hud.rl.distributed import (
    get_local_rank,
    get_world_size,
    is_main_process,
)
from hud.rl.utils import (
    batch_training_samples,
    entropy_from_logits,
    get_gpu_utilization,
    get_memory_usage,
    prepare_inputs,
)
from hud.utils.hud_console import HUDConsole

from .types import TrainingMetrics, TrainingSample

logger = logging.getLogger(__name__)
hud_console = HUDConsole(logger)

if TYPE_CHECKING:
    from .config import Config


class GRPOLearner:
    """GRPO learning algorithm for Vision-Language Models (VLMs) and Text Models."""

    def __init__(self, config: Config) -> None:
        self.config = config
        self.local_rank = get_local_rank()
        self.world_size = get_world_size()
        self.device = torch.device(
            f"cuda:{self.local_rank}" if torch.cuda.is_available() else "cpu"
        )

        # Detect model type
        self.is_vl_model = "VL" in config.model.base_model

        # Load models and processor
        self.processor, self.policy, self.ref, self.optimizer = self._load_models()
        self.metrics: list[TrainingMetrics] = []

    def log(self, message: str) -> None:
        hud_console.info_log(f"[{self.local_rank}] {message}")

    def _load_models(self) -> tuple[Any, Any, Any, Any]:
        """Load policy, reference models and optimizer."""
        model_cfg = self.config.model

        # Detect if this is a VL model or standard text model
        is_vl_model = "VL" in model_cfg.base_model
        model_type = "Vision-Language" if is_vl_model else "Text"
        self.log(f"Loading {model_type} model: {model_cfg.base_model}")

        # Apply Liger kernel optimizations if available and enabled
        if model_cfg.use_liger and LIGER_AVAILABLE:
            if is_vl_model:
                self.log("Applying Liger kernel optimizations to Qwen2.5-VL")
                apply_liger_kernel_to_qwen2_5_vl(
                    rope=True,  # Optimized RoPE
                    rms_norm=True,  # Optimized RMSNorm
                    swiglu=True,  # Optimized SwiGLU
                    fused_linear_cross_entropy=True,  # Fused Linear+CrossEntropy for memory
                )
        elif model_cfg.use_liger and not LIGER_AVAILABLE:
            self.log(
                "Liger kernel requested but not installed. Install with: pip install liger-kernel"
            )

        # Load processor/tokenizer based on model type
        if is_vl_model:
            # Some environments require remote code for Qwen2.5-VL processors
            processor = AutoProcessor.from_pretrained(
                model_cfg.base_model,
                min_pixels=model_cfg.min_pixels,
                max_pixels=model_cfg.max_pixels,
                trust_remote_code=True,
            )
        else:
            processor = AutoTokenizer.from_pretrained(model_cfg.base_model)

        # Load policy model with LoRA
        # Use attention implementation from config
        attn_implementation = model_cfg.attn_implementation

        # Choose the appropriate model class
        model_class = Qwen2_5_VLForConditionalGeneration if is_vl_model else AutoModelForCausalLM

        try:
            policy = model_class.from_pretrained(
                model_cfg.base_model,
                torch_dtype=torch.bfloat16,
                attn_implementation=attn_implementation,
                trust_remote_code=True,
            )
            self.log(f"Using {attn_implementation} for attention")
        except (ImportError, ValueError) as e:
            # Only fallback if explicitly using flash_attention_2 and it's not available
            if attn_implementation == "flash_attention_2":
                self.log(f"Flash Attention 2 not available ({e}), using eager attention")
                policy = model_class.from_pretrained(
                    model_cfg.base_model,
                    torch_dtype=torch.bfloat16,
                    attn_implementation="eager",
                )
            else:
                raise  # Re-raise if it's a different error

        # Move model to device
        policy = policy.to(self.device)  # type: ignore
        # Enable gradient checkpointing for memory efficiency
        if model_cfg.gradient_checkpointing:
            policy.gradient_checkpointing_enable()
            self.log("Gradient checkpointing enabled for memory efficiency")

        # Add LoRA adapters
        lora_config = LoraConfig(
            r=model_cfg.lora_r,
            lora_alpha=model_cfg.lora_alpha,
            lora_dropout=model_cfg.lora_dropout,
            task_type="CAUSAL_LM",
            bias="none",
            target_modules=list(model_cfg.target_modules),
        )
        policy.config.use_cache = False
        policy = get_peft_model(policy, lora_config)

        # Wrap with DDP if in distributed mode
        if self.world_size > 1:
            policy = DDP(
                policy,
                device_ids=[self.local_rank],
                output_device=self.local_rank,
                broadcast_buffers=False,
                find_unused_parameters=True,
            )
            self.log("Wrapped model (find_unused_parameters=True)")

        # Create optimizer - need to access underlying model if DDP
        base_model = policy.module if hasattr(policy, "module") else policy
        trainable_params = [p for _, p in base_model.named_parameters() if p.requires_grad]  # type: ignore

        # Use 8-bit optimizer if configured
        if self.config.training.use_8bit_optimizer and BNB_AVAILABLE:
            hud_console.info("Using 8-bit AdamW optimizer from bitsandbytes")
            optimizer = bnb.optim.AdamW8bit(
                trainable_params,
                lr=self.config.training.lr,
                betas=self.config.training.adam_betas,
                eps=self.config.training.adam_eps,
            )
        else:
            self.log("Using standard FP32 AdamW optimizer")
            optimizer = torch.optim.AdamW(
                trainable_params,
                lr=self.config.training.lr,
                betas=self.config.training.adam_betas,
                eps=self.config.training.adam_eps,
            )

        # Log optimizer info
        self.log(f"Optimizer: {type(optimizer).__name__}")
        num_params = sum(p.numel() for p in trainable_params)
        self.log(f"Number of trainable parameters: {num_params:,}")

        return processor, policy, None, optimizer

    def prepare_groups(
        self,
        samples: list[TrainingSample],
    ) -> list[list[TrainingSample]]:
        """Prepare groups of samples for training."""
        # Prepare inputs with messages
        batch = []
        for sample in samples:
            inputs = prepare_inputs(sample, self.processor)
            # If inputs are invalid, create dummy inputs to maintain batch size
            if (
                not inputs
                or "input_ids" not in inputs
                or inputs.get("input_ids", torch.tensor([])).numel() == 0
            ):
                hud_console.warning_log("Sample has invalid inputs, using dummy values")
                # Create minimal dummy inputs to keep batch size consistent
                inputs = {
                    "input_ids": torch.zeros(1, 2, dtype=torch.long),  # Minimal sequence
                    "attention_mask": torch.ones(1, 2, dtype=torch.long),
                    "assistant_mask": torch.zeros(1, 1, dtype=torch.bool),  # T-1 length
                }
            elif "assistant_mask" not in inputs:
                hud_console.warning_log("Sample missing assistant_mask, creating zero mask")
                seq_len = inputs["input_ids"].shape[-1]
                inputs["assistant_mask"] = torch.zeros(
                    inputs["input_ids"].shape[0], seq_len - 1, dtype=torch.bool
                )

            new_sample = TrainingSample(**sample.model_dump())
            new_sample.inputs = inputs
            new_sample.advantage = sample.advantage
            batch.append(new_sample)

        with hud_console.progress("Processing batch of traces...") as progress, torch.no_grad():
            for i, sample in enumerate(batch):
                if is_main_process():
                    progress.update(f"Processing batch of traces... {i}/{len(batch)}")
                if sample.inputs:
                    sample = sample.to_device(self.device)
                    sample.old_logprobs, _ = self.compute_logprobs(self.policy, sample.inputs)

            policy_module = self.policy.module if hasattr(self.policy, "module") else self.policy
            with policy_module.disable_adapter():
                for i, sample in enumerate(batch):
                    if is_main_process():
                        progress.update(f"Processing batch of traces... {i}/{len(batch)}")
                    if sample.inputs:
                        sample.ref_logprobs, _ = self.compute_logprobs(self.policy, sample.inputs)

        hud_console.info_log("Creating mini-batches...")
        group_size = self.config.training.group_size
        processed_batch = []
        if not self.config.training.accumulate_over_minibatches:
            # Find minibatches and group them via batch_training_samples
            # Minibatches control the batch size of the forward pass to the model
            mb_size = self.config.training.mini_batch_size
            group_size = group_size // mb_size
            for i in range(0, len(batch), mb_size):
                processed_batch.extend(batch_training_samples(batch[i : i + mb_size]))
        else:
            processed_batch = batch

        for sample in processed_batch:
            sample.to_device(torch.device("cpu"))

        # Convert to grouped batches (if updating the model after each task group)
        if self.config.training.update_after_group:
            return [
                processed_batch[i : i + group_size]
                for i in range(0, len(processed_batch), group_size)
            ]
        else:
            return [processed_batch]

    def update(self, samples: list[TrainingSample]) -> TrainingMetrics:
        """Perform a gradient update on a batch."""
        import time

        training_start_time = time.time()

        # Always create metrics for synchronization
        self.metrics.append(TrainingMetrics())
        metrics = self.metrics[-1]

        # Prepare groups for GRPO training
        groups = self.prepare_groups(samples)
        self.log(f"Updating over {len(groups)} groups")

        # Update over mini batch size
        with hud_console.progress("Gradient update...") as progress:
            for epoch in range(self.config.training.epochs):  # Do not accumulate across epochs
                progress.update(f"Training epoch {epoch + 1}/{self.config.training.epochs}")
                for group_idx, group in enumerate(groups):  # Do not accumulate across "groups"
                    self.optimizer.zero_grad(set_to_none=True)

                    debug_per_group = ""
                    grad_accum_steps = len(group)
                    # Tensor for distributed sync
                    global_skip = torch.zeros(1, device=self.device)

                    for s_idx, sample_minibatch in enumerate(group):
                        # self.log(f"{group_idx} {sample_minibatch.inputs['assistant_mask'].sum()}")
                        # mini_updated = sample_minibatch.inputs["assistant_mask"].sum() > 0

                        # Update mini_updated globally
                        # self.log(f"{group_idx} Mini updated: {mini_updated}")

                        # Do not sync until the last minibatch
                        if s_idx < len(group) - 1 and self.world_size > 1:
                            ddp_ctx = self.policy.no_sync()
                        else:
                            ddp_ctx = nullcontext()

                        with ddp_ctx, torch.autocast(device_type="cuda", dtype=torch.bfloat16):
                            try:
                                # if mini_updated:
                                loss = self.compute_loss(sample_minibatch) / grad_accum_steps
                                debug_per_group += f"l{s_idx}:{round(loss.item(), 3)!s} "
                                loss.backward()
                                # else:  # Dummy backward that touches all params, produces zero g
                                #     dummy = sum(p.sum() for p in self.policy.parameters()) * 0.0
                                #     debug_per_group += f"d{s_idx}:{str(round(dummy.item(), 3))} "
                                #     dummy.backward()
                                # self.log(f"{group_idx} GPU Backward: {get_gpu_utilization():.1f}% | Memory: {get_memory_usage():.2f} GB")  # noqa: E501
                            except torch.cuda.OutOfMemoryError:
                                hud_console.warning_log(
                                    f"{group_idx} CUDA OOM for {sample_minibatch.inputs['input_ids'].numel()} tokens; skipping minibatch"  # noqa: E501
                                )
                                # Dummy backward to keep DDP happy
                                dummy = torch.sum(p.sum() for p in self.policy.parameters()) * 0.0  # type: ignore
                                debug_per_group += f"o{s_idx}:{round(dummy.item(), 3)!s} "
                                dummy.backward()
                                # mark global skip if OOM
                                global_skip.fill_(1)
                                continue

                        if torch.cuda.is_available():
                            torch.cuda.empty_cache()

                    # After minibatches loop, sync skip across ranks
                    if torch.distributed.is_initialized():
                        torch.distributed.all_reduce(global_skip, op=torch.distributed.ReduceOp.MAX)
                    skip_any = bool(global_skip.item())

                    if skip_any:
                        self.log(f"G[{group_idx}] {debug_per_group} N/A (skipped)")
                        continue

                    grad_norm = torch.nn.utils.clip_grad_norm_(
                        self.policy.parameters(),
                        self.config.training.grad_clip,
                        error_if_nonfinite=True,
                    )
                    self.optimizer.step()

                    debug_per_group += f"g:{round(grad_norm.item(), 3)!s}"
                    self.log(f"G[{group_idx}] {debug_per_group}")

                    metrics.update(
                        {
                            "grad_norm": grad_norm.item()
                            if isinstance(grad_norm, torch.Tensor)
                            else float(grad_norm),
                        }
                    )

        # Calculate training time and throughput
        training_time = time.time() - training_start_time
        total_samples = (
            len(groups) * self.config.training.group_size * self.config.training.mini_batch_size
        )
        samples_per_second = total_samples / training_time if training_time > 0 else 0.0

        metrics.update(
            {
                "training_time": training_time,
                "samples_per_second": samples_per_second,
            }
        )

        return metrics

    def compute_loss(self, sample: TrainingSample) -> torch.Tensor:
        """Compute GRPO loss for a batch of samples."""
        training_cfg = self.config.training
        metrics = self.metrics[-1] if len(self.metrics) > 0 else TrainingMetrics()

        sample.to_device(self.device)

        pol_logp, pol_entropy = self.compute_logprobs(
            self.policy,
            sample.inputs,
        )

        sanity_check(sample, pol_logp, sample.old_logprobs, sample.ref_logprobs)

        metrics.update(
            {
                "gpu_util": get_gpu_utilization(),  # Track peak utilization
                "gpu_memory": get_memory_usage(),  # Track memory usage
            }
        )
        self.log(f"GPU Util: {get_gpu_utilization():.1f}% | Memory: {get_memory_usage():.2f} GB")

        old_logp = sample.old_logprobs
        ref_logp = sample.ref_logprobs

        if old_logp is None or ref_logp is None or sample.advantage is None:
            raise ValueError("old_logp, ref_logp, or sample.advantage is None")

        # Use assistant mask to remove non-assistant tokens
        m = sample.inputs["assistant_mask"]

        # Aggregate per trace or per token
        if training_cfg.ppo_mode == "per_trace":
            counts = m.sum(dim=1).clamp_min(1.0)
            pol_logp = (pol_logp * m.float()).sum(dim=1) / counts
            pol_entropy = (pol_entropy * m.float()).sum(dim=1) / counts
            old_logp = (old_logp * m.float()).sum(dim=1) / counts
            ref_logp = (ref_logp * m.float()).sum(dim=1) / counts

        # Clip log probability differences
        log_ratio = torch.where(m, pol_logp - old_logp, torch.zeros_like(pol_logp))
        ratio_tok = torch.exp(log_ratio.clamp(-20.0, 20.0))

        # Ensure advantage shape matches ratio_tok for broadcasting
        advantage = (
            sample.advantage.view(-1, 1) if ratio_tok.dim() == 2 else sample.advantage.squeeze(-1)
        )

        unclipped = ratio_tok * advantage
        clipped = (
            torch.clamp(ratio_tok, 1 - training_cfg.top_eps, 1 + training_cfg.bottom_eps)
            * advantage
        )

        policy_term = -torch.minimum(unclipped, clipped)

        # Clip log probability differences in KL
        log_rho = torch.where(m, pol_logp - ref_logp, torch.zeros_like(pol_logp))
        rho_tok = torch.exp(log_rho.clamp(-20.0, 20.0))
        kl_approx = rho_tok - torch.log(rho_tok) - 1

        total_loss = (
            policy_term + training_cfg.kl_beta * kl_approx + training_cfg.entropy_beta * pol_entropy
        )

        # Aggregate loss
        if training_cfg.ppo_mode == "per_trace":
            total_loss = total_loss.mean() if training_cfg.token_agg == "mean" else total_loss.sum()  # noqa: S105
        else:
            if training_cfg.token_agg == "mean":  # noqa: S105
                total_loss = (total_loss * m).sum() / m.sum().clamp_min(1.0)
            else:
                total_loss = (total_loss * m).sum()

        # Compute metrics only over masked (assistant) tokens
        mask_count = m.sum().clamp_min(1.0)
        metrics.update(
            {
                "policy_ratio": (ratio_tok * m).sum().item() / mask_count.item()
                if mask_count.item() > 0
                else 1.0,
                "kl": (kl_approx * m).sum().item() / mask_count.item()
                if mask_count.item() > 0
                else 0.0,
                "entropy": (pol_entropy * m).sum().item() / mask_count.item()
                if mask_count.item() > 0
                else 0.0,
                "tokens": sample.inputs["input_ids"].numel(),
                "loss": total_loss.item(),
            }
        )

        sample.to_device(torch.device("cpu"))

        return total_loss

    def compute_logprobs(self, model: Any, inputs: Any) -> tuple[torch.Tensor, torch.Tensor]:
        """Compute masked per-token log probabilities via the model.

        Returns log probabilities for the actual next tokens.
        """
        try:
            model_inputs = {k: v for k, v in inputs.items() if k != "assistant_mask"}
            out = model(**model_inputs)

            logits = out.logits / self.config.actor.temperature
            log_probs = F.log_softmax(logits, dim=-1)

            targets = inputs["input_ids"][:, 1:]
            token_log_probs = log_probs[:, :-1].gather(-1, targets.unsqueeze(-1)).squeeze(-1)

            # Compute entropy only for assistant tokens to save memory
            assistant_mask = inputs["assistant_mask"]
            entropy = torch.zeros_like(token_log_probs)
            if assistant_mask.any():
                entropy[assistant_mask] = entropy_from_logits(logits[:, :-1][assistant_mask])

            return token_log_probs, entropy
        except (IndexError, RuntimeError) as e:
            # Handle empty inputs or DDP errors
            hud_console.warning_log(f"Error in compute_logprobs: {e}. Returning dummy values.")
            # Return dummy values that match expected shapes
            seq_len = inputs["input_ids"].shape[1] - 1 if "input_ids" in inputs else 0
            batch_size = inputs["input_ids"].shape[0] if "input_ids" in inputs else 1
            dummy_logprobs = torch.zeros(batch_size, seq_len, device=self.device)
            dummy_entropy = torch.zeros(batch_size, seq_len, device=self.device)
            return dummy_logprobs, dummy_entropy

    def save(self, path: str) -> None:
        """Save the current policy checkpoint (only on rank 0)."""
        if is_main_process():
            os.makedirs(path, exist_ok=True)
            # Unwrap DDP model if needed
            model_to_save = self.policy.module if hasattr(self.policy, "module") else self.policy
            model_to_save.save_pretrained(path)
            self.log(f"Saved checkpoint to {path}")

    def load(self, path: str) -> None:
        """Load a policy checkpoint."""
        # Would need to reload LoRA weights
        self.log(f"Loading checkpoint from {path}")
        # Implementation depends on PEFT version


def sanity_check(
    sample: TrainingSample,
    pol_logp: torch.Tensor,
    old_logp: torch.Tensor | None,
    ref_logp: torch.Tensor | None,
) -> None:
    assert "assistant_mask" in sample.inputs  # noqa: S101
    m = sample.inputs["assistant_mask"]
    if old_logp is None or ref_logp is None:
        return
    with torch.no_grad():
        B, K = pol_logp.shape
        assert old_logp.shape == (B, K), "old_logp shape mismatch"  # noqa: S101
        assert ref_logp.shape == (B, K), "ref_logp shape mismatch"  # noqa: S101
        assert m.shape == (B, K), "assistant_mask shape mismatch"  # noqa: S101

        # Check mask is subset of attention_mask[:, 1:]
        att = sample.inputs.get("attention_mask", None)
        if att is not None and att.dim() == 2:
            att_shift = att[:, 1:].bool()
            bad = (m & ~att_shift).sum().item()
            if bad > 0:
                hud_console.warning_log(f"assistant_mask overlaps padding: {bad} tokens")

        # Finiteness on masked entries only
        def _stats(name: str, t: torch.Tensor) -> None:
            sel = t[m]
            if sel.numel() == 0:
                hud_console.warning_log(f"{name} empty under mask")
                return
            finite = torch.isfinite(sel)
            if finite.sum() < sel.numel():
                hud_console.warning_log(
                    f"{name} non-finite: {((~finite).sum().item())}/{sel.numel()}"
                )
            sel = sel[finite].float()

        _stats("pol_logp", pol_logp)
        _stats("old_logp", old_logp)
        _stats("ref_logp", ref_logp)

        # Log-probabilities should be <= 0 (log-softmax)
        if (pol_logp[m] > 1e-6).any():
            hud_console.warning_log("pol_logp has positive values under mask")

        # Precompute masked deltas and ratios for diagnostics (before exp)
        masked_log_ratio = torch.zeros_like(pol_logp)
        masked_log_ratio[m] = (pol_logp - old_logp)[m]
        masked_log_rho = torch.zeros_like(pol_logp)
        masked_log_rho[m] = (pol_logp - ref_logp)[m]

        _stats("log_ratio(masked)", masked_log_ratio)
        _stats("log_rho(masked)", masked_log_rho)

        # Ratios after clamp (diagnostic only)
        ratio_diag = torch.zeros_like(pol_logp)
        rho_diag = torch.zeros_like(pol_logp)
        ratio_diag[m] = torch.exp(masked_log_ratio[m].clamp(-20.0, 20.0))
        rho_diag[m] = torch.exp(masked_log_rho[m].clamp(-20.0, 20.0))
        _stats("ratio_tok(masked)", ratio_diag)
        _stats("rho_tok(masked)", rho_diag)
hud/rl/tests/__init__.py
ADDED
@@ -0,0 +1 @@
"""Tests for RL module."""