hud-python 0.4.28__py3-none-any.whl → 0.4.30__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of hud-python might be problematic.

Files changed (77)
  1. hud/__init__.py +2 -1
  2. hud/agents/base.py +81 -45
  3. hud/agents/claude.py +8 -4
  4. hud/agents/openai_chat_generic.py +66 -40
  5. hud/agents/tests/test_base.py +0 -4
  6. hud/agents/tests/test_openai.py +1 -1
  7. hud/cli/__init__.py +182 -52
  8. hud/cli/dev.py +8 -9
  9. hud/cli/eval.py +317 -119
  10. hud/cli/flows/__init__.py +0 -0
  11. hud/cli/flows/tasks.py +0 -0
  12. hud/cli/get.py +160 -0
  13. hud/cli/rl/__init__.py +567 -71
  14. hud/cli/rl/config.py +94 -0
  15. hud/cli/rl/display.py +133 -0
  16. hud/cli/rl/gpu.py +63 -0
  17. hud/cli/rl/gpu_utils.py +318 -0
  18. hud/cli/rl/presets.py +96 -0
  19. hud/cli/rl/remote_runner.py +347 -0
  20. hud/cli/rl/rl_api.py +150 -0
  21. hud/cli/rl/vllm.py +177 -0
  22. hud/cli/tests/test_analyze_metadata.py +0 -1
  23. hud/cli/utils/tasks.py +26 -0
  24. hud/clients/base.py +21 -23
  25. hud/clients/mcp_use.py +36 -44
  26. hud/clients/tests/test_mcp_use_retry.py +10 -10
  27. hud/datasets/__init__.py +4 -3
  28. hud/datasets/{execution/parallel.py → parallel.py} +1 -1
  29. hud/datasets/{execution/runner.py → runner.py} +1 -1
  30. hud/datasets/utils.py +1 -1
  31. hud/native/comparator.py +6 -6
  32. hud/native/tests/test_comparator.py +8 -8
  33. hud/native/tests/test_native_init.py +13 -11
  34. hud/otel/config.py +1 -1
  35. hud/otel/instrumentation.py +35 -0
  36. hud/rl/README.md +30 -0
  37. hud/rl/__init__.py +1 -0
  38. hud/rl/actor.py +174 -0
  39. hud/rl/buffer.py +371 -0
  40. hud/rl/chat_template.jinja +101 -0
  41. hud/rl/config.py +184 -0
  42. hud/rl/distributed.py +95 -0
  43. hud/rl/learner.py +589 -0
  44. hud/rl/tests/__init__.py +1 -0
  45. hud/rl/tests/test_learner.py +171 -0
  46. hud/rl/train.py +354 -0
  47. hud/rl/types.py +101 -0
  48. hud/rl/utils/start_vllm_server.sh +30 -0
  49. hud/rl/utils.py +524 -0
  50. hud/rl/vllm_adapter.py +125 -0
  51. hud/settings.py +6 -0
  52. hud/telemetry/__init__.py +2 -1
  53. hud/telemetry/job.py +46 -3
  54. hud/telemetry/tests/test_trace.py +3 -3
  55. hud/telemetry/trace.py +85 -13
  56. hud/tools/tests/test_computer.py +3 -3
  57. hud/tools/tests/test_computer_actions.py +1 -1
  58. hud/types.py +123 -2
  59. hud/utils/group_eval.py +223 -0
  60. hud/utils/hud_console.py +113 -13
  61. hud/utils/tasks.py +119 -0
  62. hud/utils/tests/test_version.py +1 -1
  63. hud/version.py +1 -1
  64. {hud_python-0.4.28.dist-info → hud_python-0.4.30.dist-info}/METADATA +20 -2
  65. {hud_python-0.4.28.dist-info → hud_python-0.4.30.dist-info}/RECORD +68 -48
  66. hud/cli/hf.py +0 -406
  67. hud/cli/rl/README.md +0 -243
  68. hud/cli/rl/init.py +0 -370
  69. hud/cli/rl/pod.py +0 -501
  70. hud/cli/rl/ssh.py +0 -322
  71. hud/cli/rl/train.py +0 -562
  72. hud/cli/rl/utils.py +0 -165
  73. hud/datasets/execution/__init__.py +0 -13
  74. hud/datasets/task.py +0 -116
  75. {hud_python-0.4.28.dist-info → hud_python-0.4.30.dist-info}/WHEEL +0 -0
  76. {hud_python-0.4.28.dist-info → hud_python-0.4.30.dist-info}/entry_points.txt +0 -0
  77. {hud_python-0.4.28.dist-info → hud_python-0.4.30.dist-info}/licenses/LICENSE +0 -0
hud/rl/distributed.py ADDED
@@ -0,0 +1,95 @@
+ """Distributed training utilities for GRPO."""
+
+ from __future__ import annotations
+
+ import os
+ from typing import Any
+
+ import torch
+ import torch.distributed as dist
+
+
+ def setup_distributed() -> None:
+     """Initialize distributed training environment."""
+     if "RANK" in os.environ and int(os.environ["WORLD_SIZE"]) > 1:
+         # Set device for this process
+         local_rank = int(os.environ["LOCAL_RANK"])
+         torch.cuda.set_device(local_rank)
+
+         # Initialize process group
+         dist.init_process_group("nccl")
+
+
+ def get_local_rank() -> int:
+     """Get local rank from environment."""
+     return int(os.environ.get("LOCAL_RANK", 0))
+
+
+ def get_global_rank() -> int:
+     """Get global rank from environment."""
+     return int(os.environ.get("RANK", 0))
+
+
+ def get_world_size() -> int:
+     """Get world size from environment."""
+     return int(os.environ.get("WORLD_SIZE", 1))
+
+
+ def cleanup_distributed() -> None:
+     """Clean up distributed environment."""
+     if dist.is_initialized():
+         dist.destroy_process_group()
+
+
+ def is_main_process() -> bool:
+     """Check if this is the main process (rank 0)."""
+     if not dist.is_initialized():
+         return True
+     return dist.get_rank() == 0
+
+
+ def synchronize() -> None:
+     """Synchronize all processes."""
+     if dist.is_initialized():
+         dist.barrier()
+
+
+ def all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor:
+     """Average a tensor across all processes."""
+     if not dist.is_initialized():
+         return tensor
+
+     world_size = dist.get_world_size()
+     dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
+     tensor /= world_size
+     return tensor
+
+
+ def broadcast_object(obj: Any, src: int = 0) -> Any:
+     """Broadcast a Python object from src rank to all ranks."""
+     if not dist.is_initialized():
+         return obj
+
+     obj_list = [obj] if dist.get_rank() == src else [None]
+     dist.broadcast_object_list(obj_list, src=src)
+     return obj_list[0]
+
+
+ def gather_tensors(tensor: torch.Tensor) -> list[torch.Tensor] | None:
+     """Gather tensors from all ranks to rank 0.
+
+     Returns:
+         List of tensors on rank 0, None on other ranks
+     """
+     if not dist.is_initialized():
+         return [tensor]
+
+     world_size = dist.get_world_size()
+
+     if dist.get_rank() == 0:
+         gathered = [torch.zeros_like(tensor) for _ in range(world_size)]
+         dist.gather(tensor, gathered, dst=0)
+         return gathered
+     else:
+         dist.gather(tensor, None, dst=0)
+         return None
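Editor's note: the helpers above are thin wrappers around torch.distributed that degrade to no-ops in single-process runs. As orientation only (not part of the package), a minimal sketch of how they might compose under a torchrun launch could look like the following; the loss tensor is a placeholder standing in for real per-rank training state.

# Orientation sketch, not part of hud-python. Assumes launch via torchrun;
# without RANK/WORLD_SIZE set, every helper below falls back to single-process behavior.
import torch

from hud.rl.distributed import (
    all_reduce_mean,
    cleanup_distributed,
    get_local_rank,
    is_main_process,
    setup_distributed,
)


def main() -> None:
    setup_distributed()  # initializes NCCL only when RANK/WORLD_SIZE are present
    device = torch.device(
        f"cuda:{get_local_rank()}" if torch.cuda.is_available() else "cpu"
    )

    loss = torch.tensor(0.5, device=device)  # placeholder per-rank scalar
    loss = all_reduce_mean(loss)  # averaged across ranks when distributed is initialized

    if is_main_process():
        print(f"mean loss across ranks: {loss.item():.4f}")

    cleanup_distributed()


if __name__ == "__main__":
    main()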
hud/rl/learner.py ADDED
@@ -0,0 +1,589 @@
+ """GRPO learner for vision-language and text models."""
+
+ from __future__ import annotations
+
+ import logging
+ import os
+ from typing import TYPE_CHECKING, Any
+
+ import torch
+ import torch.nn.functional as F
+ from peft import LoraConfig, get_peft_model
+ from torch.nn.parallel import DistributedDataParallel as DDP
+ from transformers import (
+     AutoModelForCausalLM,
+     AutoProcessor,
+     AutoTokenizer,
+     Qwen2_5_VLForConditionalGeneration,
+ )
+
+ try:
+     from liger_kernel.transformers import apply_liger_kernel_to_qwen2_5_vl  # type: ignore
+
+     LIGER_AVAILABLE = True
+ except ImportError:
+     LIGER_AVAILABLE = False
+
+ try:
+     import bitsandbytes as bnb  # type: ignore
+
+     BNB_AVAILABLE = True
+ except ImportError:
+     BNB_AVAILABLE = False
+
+ from contextlib import nullcontext
+
+ from hud.rl.distributed import (
+     get_local_rank,
+     get_world_size,
+     is_main_process,
+ )
+ from hud.rl.utils import (
+     batch_training_samples,
+     entropy_from_logits,
+     get_gpu_utilization,
+     get_memory_usage,
+     prepare_inputs,
+ )
+ from hud.utils.hud_console import HUDConsole
+
+ from .types import TrainingMetrics, TrainingSample
+
+ logger = logging.getLogger(__name__)
+ hud_console = HUDConsole(logger)
+
+ if TYPE_CHECKING:
+     from .config import Config
+
+
+ class GRPOLearner:
+     """GRPO learning algorithm for Vision-Language Models (VLMs) and Text Models."""
+
+     def __init__(self, config: Config) -> None:
+         self.config = config
+         self.local_rank = get_local_rank()
+         self.world_size = get_world_size()
+         self.device = torch.device(
+             f"cuda:{self.local_rank}" if torch.cuda.is_available() else "cpu"
+         )
+
+         # Detect model type
+         self.is_vl_model = "VL" in config.model.base_model
+
+         # Load models and processor
+         self.processor, self.policy, self.ref, self.optimizer = self._load_models()
+         self.metrics: list[TrainingMetrics] = []
+
+     def log(self, message: str) -> None:
+         hud_console.info_log(f"[{self.local_rank}] {message}")
+
+     def _load_models(self) -> tuple[Any, Any, Any, Any]:
+         """Load policy, reference models and optimizer."""
+         model_cfg = self.config.model
+
+         # Detect if this is a VL model or standard text model
+         is_vl_model = "VL" in model_cfg.base_model
+         model_type = "Vision-Language" if is_vl_model else "Text"
+         self.log(f"Loading {model_type} model: {model_cfg.base_model}")
+
+         # Apply Liger kernel optimizations if available and enabled
+         if model_cfg.use_liger and LIGER_AVAILABLE:
+             if is_vl_model:
+                 self.log("Applying Liger kernel optimizations to Qwen2.5-VL")
+                 apply_liger_kernel_to_qwen2_5_vl(
+                     rope=True,  # Optimized RoPE
+                     rms_norm=True,  # Optimized RMSNorm
+                     swiglu=True,  # Optimized SwiGLU
+                     fused_linear_cross_entropy=True,  # Fused Linear+CrossEntropy for memory
+                 )
+         elif model_cfg.use_liger and not LIGER_AVAILABLE:
+             self.log(
+                 "Liger kernel requested but not installed. Install with: pip install liger-kernel"
+             )
+
+         # Load processor/tokenizer based on model type
+         if is_vl_model:
+             # Some environments require remote code for Qwen2.5-VL processors
+             processor = AutoProcessor.from_pretrained(
+                 model_cfg.base_model,
+                 min_pixels=model_cfg.min_pixels,
+                 max_pixels=model_cfg.max_pixels,
+                 trust_remote_code=True,
+             )
+         else:
+             processor = AutoTokenizer.from_pretrained(model_cfg.base_model)
+
+         # Load policy model with LoRA
+         # Use attention implementation from config
+         attn_implementation = model_cfg.attn_implementation
+
+         # Choose the appropriate model class
+         model_class = Qwen2_5_VLForConditionalGeneration if is_vl_model else AutoModelForCausalLM
+
+         try:
+             policy = model_class.from_pretrained(
+                 model_cfg.base_model,
+                 torch_dtype=torch.bfloat16,
+                 attn_implementation=attn_implementation,
+                 trust_remote_code=True,
+             )
+             self.log(f"Using {attn_implementation} for attention")
+         except (ImportError, ValueError) as e:
+             # Only fallback if explicitly using flash_attention_2 and it's not available
+             if attn_implementation == "flash_attention_2":
+                 self.log(f"Flash Attention 2 not available ({e}), using eager attention")
+                 policy = model_class.from_pretrained(
+                     model_cfg.base_model,
+                     torch_dtype=torch.bfloat16,
+                     attn_implementation="eager",
+                 )
+             else:
+                 raise  # Re-raise if it's a different error
+
+         # Move model to device
+         policy = policy.to(self.device)  # type: ignore
+         # Enable gradient checkpointing for memory efficiency
+         if model_cfg.gradient_checkpointing:
+             policy.gradient_checkpointing_enable()
+             self.log("Gradient checkpointing enabled for memory efficiency")
+
+         # Add LoRA adapters
+         lora_config = LoraConfig(
+             r=model_cfg.lora_r,
+             lora_alpha=model_cfg.lora_alpha,
+             lora_dropout=model_cfg.lora_dropout,
+             task_type="CAUSAL_LM",
+             bias="none",
+             target_modules=list(model_cfg.target_modules),
+         )
+         policy.config.use_cache = False
+         policy = get_peft_model(policy, lora_config)
+
+         # Wrap with DDP if in distributed mode
+         if self.world_size > 1:
+             policy = DDP(
+                 policy,
+                 device_ids=[self.local_rank],
+                 output_device=self.local_rank,
+                 broadcast_buffers=False,
+                 find_unused_parameters=True,
+             )
+             self.log("Wrapped model (find_unused_parameters=True)")
+
+         # Create optimizer - need to access underlying model if DDP
+         base_model = policy.module if hasattr(policy, "module") else policy
+         trainable_params = [p for _, p in base_model.named_parameters() if p.requires_grad]  # type: ignore
+
+         # Use 8-bit optimizer if configured
+         if self.config.training.use_8bit_optimizer and BNB_AVAILABLE:
+             hud_console.info("Using 8-bit AdamW optimizer from bitsandbytes")
+             optimizer = bnb.optim.AdamW8bit(
+                 trainable_params,
+                 lr=self.config.training.lr,
+                 betas=self.config.training.adam_betas,
+                 eps=self.config.training.adam_eps,
+             )
+         else:
+             self.log("Using standard FP32 AdamW optimizer")
+             optimizer = torch.optim.AdamW(
+                 trainable_params,
+                 lr=self.config.training.lr,
+                 betas=self.config.training.adam_betas,
+                 eps=self.config.training.adam_eps,
+             )
+
+         # Log optimizer info
+         self.log(f"Optimizer: {type(optimizer).__name__}")
+         num_params = sum(p.numel() for p in trainable_params)
+         self.log(f"Number of trainable parameters: {num_params:,}")
+
+         return processor, policy, None, optimizer
+
+     def prepare_groups(
+         self,
+         samples: list[TrainingSample],
+     ) -> list[list[TrainingSample]]:
+         """Prepare groups of samples for training."""
+         # Prepare inputs with messages
+         batch = []
+         for sample in samples:
+             inputs = prepare_inputs(sample, self.processor)
+             # If inputs are invalid, create dummy inputs to maintain batch size
+             if (
+                 not inputs
+                 or "input_ids" not in inputs
+                 or inputs.get("input_ids", torch.tensor([])).numel() == 0
+             ):
+                 hud_console.warning_log("Sample has invalid inputs, using dummy values")
+                 # Create minimal dummy inputs to keep batch size consistent
+                 inputs = {
+                     "input_ids": torch.zeros(1, 2, dtype=torch.long),  # Minimal sequence
+                     "attention_mask": torch.ones(1, 2, dtype=torch.long),
+                     "assistant_mask": torch.zeros(1, 1, dtype=torch.bool),  # T-1 length
+                 }
+             elif "assistant_mask" not in inputs:
+                 hud_console.warning_log("Sample missing assistant_mask, creating zero mask")
+                 seq_len = inputs["input_ids"].shape[-1]
+                 inputs["assistant_mask"] = torch.zeros(
+                     inputs["input_ids"].shape[0], seq_len - 1, dtype=torch.bool
+                 )
+
+             new_sample = TrainingSample(**sample.model_dump())
+             new_sample.inputs = inputs
+             new_sample.advantage = sample.advantage
+             batch.append(new_sample)
+
+         with hud_console.progress("Processing batch of traces...") as progress, torch.no_grad():
+             for i, sample in enumerate(batch):
+                 if is_main_process():
+                     progress.update(f"Processing batch of traces... {i}/{len(batch)}")
+                 if sample.inputs:
+                     sample = sample.to_device(self.device)
+                     sample.old_logprobs, _ = self.compute_logprobs(self.policy, sample.inputs)
+
+             policy_module = self.policy.module if hasattr(self.policy, "module") else self.policy
+             with policy_module.disable_adapter():
+                 for i, sample in enumerate(batch):
+                     if is_main_process():
+                         progress.update(f"Processing batch of traces... {i}/{len(batch)}")
+                     if sample.inputs:
+                         sample.ref_logprobs, _ = self.compute_logprobs(self.policy, sample.inputs)
+
+         hud_console.info_log("Creating mini-batches...")
+         group_size = self.config.training.group_size
+         processed_batch = []
+         if not self.config.training.accumulate_over_minibatches:
+             # Find minibatches and group them via batch_training_samples
+             # Minibatches control the batch size of the forward pass to the model
+             mb_size = self.config.training.mini_batch_size
+             group_size = group_size // mb_size
+             for i in range(0, len(batch), mb_size):
+                 processed_batch.extend(batch_training_samples(batch[i : i + mb_size]))
+         else:
+             processed_batch = batch
+
+         for sample in processed_batch:
+             sample.to_device(torch.device("cpu"))
+
+         # Convert to grouped batches (if updating the model after each task group)
+         if self.config.training.update_after_group:
+             return [
+                 processed_batch[i : i + group_size]
+                 for i in range(0, len(processed_batch), group_size)
+             ]
+         else:
+             return [processed_batch]
+
+     def update(self, samples: list[TrainingSample]) -> TrainingMetrics:
+         """Perform a gradient update on a batch."""
+         import time
+
+         training_start_time = time.time()
+
+         # Always create metrics for synchronization
+         self.metrics.append(TrainingMetrics())
+         metrics = self.metrics[-1]
+
+         # Prepare groups for GRPO training
+         groups = self.prepare_groups(samples)
+         self.log(f"Updating over {len(groups)} groups")
+
+         # Update over mini batch size
+         with hud_console.progress("Gradient update...") as progress:
+             for epoch in range(self.config.training.epochs):  # Do not accumulate across epochs
+                 progress.update(f"Training epoch {epoch + 1}/{self.config.training.epochs}")
+                 for group_idx, group in enumerate(groups):  # Do not accumulate across "groups"
+                     self.optimizer.zero_grad(set_to_none=True)
+
+                     debug_per_group = ""
+                     grad_accum_steps = len(group)
+                     # Tensor for distributed sync
+                     global_skip = torch.zeros(1, device=self.device)
+
+                     for s_idx, sample_minibatch in enumerate(group):
+                         # self.log(f"{group_idx} {sample_minibatch.inputs['assistant_mask'].sum()}")
+                         # mini_updated = sample_minibatch.inputs["assistant_mask"].sum() > 0
+
+                         # Update mini_updated globally
+                         # self.log(f"{group_idx} Mini updated: {mini_updated}")
+
+                         # Do not sync until the last minibatch
+                         if s_idx < len(group) - 1 and self.world_size > 1:
+                             ddp_ctx = self.policy.no_sync()
+                         else:
+                             ddp_ctx = nullcontext()
+
+                         with ddp_ctx, torch.autocast(device_type="cuda", dtype=torch.bfloat16):
+                             try:
+                                 # if mini_updated:
+                                 loss = self.compute_loss(sample_minibatch) / grad_accum_steps
+                                 debug_per_group += f"l{s_idx}:{round(loss.item(), 3)!s} "
+                                 loss.backward()
+                                 # else:  # Dummy backward that touches all params, produces zero g
+                                 #     dummy = sum(p.sum() for p in self.policy.parameters()) * 0.0
+                                 #     debug_per_group += f"d{s_idx}:{str(round(dummy.item(), 3))} "
+                                 #     dummy.backward()
+                                 # self.log(f"{group_idx} GPU Backward: {get_gpu_utilization():.1f}% | Memory: {get_memory_usage():.2f} GB")  # noqa: E501
+                             except torch.cuda.OutOfMemoryError:
+                                 hud_console.warning_log(
+                                     f"{group_idx} CUDA OOM for {sample_minibatch.inputs['input_ids'].numel()} tokens; skipping minibatch"  # noqa: E501
+                                 )
+                                 # Dummy backward to keep DDP happy
+                                 dummy = torch.sum(p.sum() for p in self.policy.parameters()) * 0.0  # type: ignore
+                                 debug_per_group += f"o{s_idx}:{round(dummy.item(), 3)!s} "
+                                 dummy.backward()
+                                 # mark global skip if OOM
+                                 global_skip.fill_(1)
+                                 continue
+
+                         if torch.cuda.is_available():
+                             torch.cuda.empty_cache()
+
+                     # After minibatches loop, sync skip across ranks
+                     if torch.distributed.is_initialized():
+                         torch.distributed.all_reduce(global_skip, op=torch.distributed.ReduceOp.MAX)
+                     skip_any = bool(global_skip.item())
+
+                     if skip_any:
+                         self.log(f"G[{group_idx}] {debug_per_group} N/A (skipped)")
+                         continue
+
+                     grad_norm = torch.nn.utils.clip_grad_norm_(
+                         self.policy.parameters(),
+                         self.config.training.grad_clip,
+                         error_if_nonfinite=True,
+                     )
+                     self.optimizer.step()
+
+                     debug_per_group += f"g:{round(grad_norm.item(), 3)!s}"
+                     self.log(f"G[{group_idx}] {debug_per_group}")
+
+                     metrics.update(
+                         {
+                             "grad_norm": grad_norm.item()
+                             if isinstance(grad_norm, torch.Tensor)
+                             else float(grad_norm),
+                         }
+                     )
+
+         # Calculate training time and throughput
+         training_time = time.time() - training_start_time
+         total_samples = (
+             len(groups) * self.config.training.group_size * self.config.training.mini_batch_size
+         )
+         samples_per_second = total_samples / training_time if training_time > 0 else 0.0
+
+         metrics.update(
+             {
+                 "training_time": training_time,
+                 "samples_per_second": samples_per_second,
+             }
+         )
+
+         return metrics
+
+     def compute_loss(self, sample: TrainingSample) -> torch.Tensor:
+         """Compute GRPO loss for a batch of samples."""
+         training_cfg = self.config.training
+         metrics = self.metrics[-1] if len(self.metrics) > 0 else TrainingMetrics()
+
+         sample.to_device(self.device)
+
+         pol_logp, pol_entropy = self.compute_logprobs(
+             self.policy,
+             sample.inputs,
+         )
+
+         sanity_check(sample, pol_logp, sample.old_logprobs, sample.ref_logprobs)
+
+         metrics.update(
+             {
+                 "gpu_util": get_gpu_utilization(),  # Track peak utilization
+                 "gpu_memory": get_memory_usage(),  # Track memory usage
+             }
+         )
+         self.log(f"GPU Util: {get_gpu_utilization():.1f}% | Memory: {get_memory_usage():.2f} GB")
+
+         old_logp = sample.old_logprobs
+         ref_logp = sample.ref_logprobs
+
+         if old_logp is None or ref_logp is None or sample.advantage is None:
+             raise ValueError("old_logp, ref_logp, or sample.advantage is None")
+
+         # Use assistant mask to remove non-assistant tokens
+         m = sample.inputs["assistant_mask"]
+
+         # Aggregate per trace or per token
+         if training_cfg.ppo_mode == "per_trace":
+             counts = m.sum(dim=1).clamp_min(1.0)
+             pol_logp = (pol_logp * m.float()).sum(dim=1) / counts
+             pol_entropy = (pol_entropy * m.float()).sum(dim=1) / counts
+             old_logp = (old_logp * m.float()).sum(dim=1) / counts
+             ref_logp = (ref_logp * m.float()).sum(dim=1) / counts
+
+         # Clip log probability differences
+         log_ratio = torch.where(m, pol_logp - old_logp, torch.zeros_like(pol_logp))
+         ratio_tok = torch.exp(log_ratio.clamp(-20.0, 20.0))
+
+         # Ensure advantage shape matches ratio_tok for broadcasting
+         advantage = (
+             sample.advantage.view(-1, 1) if ratio_tok.dim() == 2 else sample.advantage.squeeze(-1)
+         )
+
+         unclipped = ratio_tok * advantage
+         clipped = (
+             torch.clamp(ratio_tok, 1 - training_cfg.top_eps, 1 + training_cfg.bottom_eps)
+             * advantage
+         )
+
+         policy_term = -torch.minimum(unclipped, clipped)
+
+         # Clip log probability differences in KL
+         log_rho = torch.where(m, pol_logp - ref_logp, torch.zeros_like(pol_logp))
+         rho_tok = torch.exp(log_rho.clamp(-20.0, 20.0))
+         kl_approx = rho_tok - torch.log(rho_tok) - 1
+
+         total_loss = (
+             policy_term + training_cfg.kl_beta * kl_approx + training_cfg.entropy_beta * pol_entropy
+         )
+
+         # Aggregate loss
+         if training_cfg.ppo_mode == "per_trace":
+             total_loss = total_loss.mean() if training_cfg.token_agg == "mean" else total_loss.sum()  # noqa: S105
+         else:
+             if training_cfg.token_agg == "mean":  # noqa: S105
+                 total_loss = (total_loss * m).sum() / m.sum().clamp_min(1.0)
+             else:
+                 total_loss = (total_loss * m).sum()
+
+         # Compute metrics only over masked (assistant) tokens
+         mask_count = m.sum().clamp_min(1.0)
+         metrics.update(
+             {
+                 "policy_ratio": (ratio_tok * m).sum().item() / mask_count.item()
+                 if mask_count.item() > 0
+                 else 1.0,
+                 "kl": (kl_approx * m).sum().item() / mask_count.item()
+                 if mask_count.item() > 0
+                 else 0.0,
+                 "entropy": (pol_entropy * m).sum().item() / mask_count.item()
+                 if mask_count.item() > 0
+                 else 0.0,
+                 "tokens": sample.inputs["input_ids"].numel(),
+                 "loss": total_loss.item(),
+             }
+         )
+
+         sample.to_device(torch.device("cpu"))
+
+         return total_loss
+
+     def compute_logprobs(self, model: Any, inputs: Any) -> tuple[torch.Tensor, torch.Tensor]:
+         """Compute masked per-token log probabilities via the model.
+
+         Returns log probabilities for the actual next tokens.
+         """
+         try:
+             model_inputs = {k: v for k, v in inputs.items() if k != "assistant_mask"}
+             out = model(**model_inputs)
+
+             logits = out.logits / self.config.actor.temperature
+             log_probs = F.log_softmax(logits, dim=-1)
+
+             targets = inputs["input_ids"][:, 1:]
+             token_log_probs = log_probs[:, :-1].gather(-1, targets.unsqueeze(-1)).squeeze(-1)
+
+             # Compute entropy only for assistant tokens to save memory
+             assistant_mask = inputs["assistant_mask"]
+             entropy = torch.zeros_like(token_log_probs)
+             if assistant_mask.any():
+                 entropy[assistant_mask] = entropy_from_logits(logits[:, :-1][assistant_mask])
+
+             return token_log_probs, entropy
+         except (IndexError, RuntimeError) as e:
+             # Handle empty inputs or DDP errors
+             hud_console.warning_log(f"Error in compute_logprobs: {e}. Returning dummy values.")
+             # Return dummy values that match expected shapes
+             seq_len = inputs["input_ids"].shape[1] - 1 if "input_ids" in inputs else 0
+             batch_size = inputs["input_ids"].shape[0] if "input_ids" in inputs else 1
+             dummy_logprobs = torch.zeros(batch_size, seq_len, device=self.device)
+             dummy_entropy = torch.zeros(batch_size, seq_len, device=self.device)
+             return dummy_logprobs, dummy_entropy
+
+     def save(self, path: str) -> None:
+         """Save the current policy checkpoint (only on rank 0)."""
+         if is_main_process():
+             os.makedirs(path, exist_ok=True)
+             # Unwrap DDP model if needed
+             model_to_save = self.policy.module if hasattr(self.policy, "module") else self.policy
+             model_to_save.save_pretrained(path)
+             self.log(f"Saved checkpoint to {path}")
+
+     def load(self, path: str) -> None:
+         """Load a policy checkpoint."""
+         # Would need to reload LoRA weights
+         self.log(f"Loading checkpoint from {path}")
+         # Implementation depends on PEFT version
+
+
+ def sanity_check(
+     sample: TrainingSample,
+     pol_logp: torch.Tensor,
+     old_logp: torch.Tensor | None,
+     ref_logp: torch.Tensor | None,
+ ) -> None:
+     assert "assistant_mask" in sample.inputs  # noqa: S101
+     m = sample.inputs["assistant_mask"]
+     if old_logp is None or ref_logp is None:
+         return
+     with torch.no_grad():
+         B, K = pol_logp.shape
+         assert old_logp.shape == (B, K), "old_logp shape mismatch"  # noqa: S101
+         assert ref_logp.shape == (B, K), "ref_logp shape mismatch"  # noqa: S101
+         assert m.shape == (B, K), "assistant_mask shape mismatch"  # noqa: S101
+
+         # Check mask is subset of attention_mask[:, 1:]
+         att = sample.inputs.get("attention_mask", None)
+         if att is not None and att.dim() == 2:
+             att_shift = att[:, 1:].bool()
+             bad = (m & ~att_shift).sum().item()
+             if bad > 0:
+                 hud_console.warning_log(f"assistant_mask overlaps padding: {bad} tokens")
+
+         # Finiteness on masked entries only
+         def _stats(name: str, t: torch.Tensor) -> None:
+             sel = t[m]
+             if sel.numel() == 0:
+                 hud_console.warning_log(f"{name} empty under mask")
+                 return
+             finite = torch.isfinite(sel)
+             if finite.sum() < sel.numel():
+                 hud_console.warning_log(
+                     f"{name} non-finite: {((~finite).sum().item())}/{sel.numel()}"
+                 )
+             sel = sel[finite].float()
+
+         _stats("pol_logp", pol_logp)
+         _stats("old_logp", old_logp)
+         _stats("ref_logp", ref_logp)
+
+         # Log-probabilities should be <= 0 (log-softmax)
+         if (pol_logp[m] > 1e-6).any():
+             hud_console.warning_log("pol_logp has positive values under mask")
+
+         # Precompute masked deltas and ratios for diagnostics (before exp)
+         masked_log_ratio = torch.zeros_like(pol_logp)
+         masked_log_ratio[m] = (pol_logp - old_logp)[m]
+         masked_log_rho = torch.zeros_like(pol_logp)
+         masked_log_rho[m] = (pol_logp - ref_logp)[m]
+
+         _stats("log_ratio(masked)", masked_log_ratio)
+         _stats("log_rho(masked)", masked_log_rho)
+
+         # Ratios after clamp (diagnostic only)
+         ratio_diag = torch.zeros_like(pol_logp)
+         rho_diag = torch.zeros_like(pol_logp)
+         ratio_diag[m] = torch.exp(masked_log_ratio[m].clamp(-20.0, 20.0))
+         rho_diag[m] = torch.exp(masked_log_rho[m].clamp(-20.0, 20.0))
+         _stats("ratio_tok(masked)", ratio_diag)
+         _stats("rho_tok(masked)", rho_diag)
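Editor's note: for readers skimming the diff, compute_loss above combines a clipped-ratio policy surrogate with an approximate KL penalty against the adapter-disabled reference pass, averaged over assistant tokens. The following self-contained sketch (not part of the package) restates that per-token math with plain tensors standing in for TrainingSample fields; eps and kl_beta are simplified stand-ins for the config's top_eps/bottom_eps and kl_beta values.

# Standalone sketch of the per-token GRPO objective mirrored from compute_loss above.
import torch


def grpo_token_loss(
    pol_logp: torch.Tensor,  # (B, T-1) log-probs under the current policy
    old_logp: torch.Tensor,  # (B, T-1) log-probs snapshotted before the update
    ref_logp: torch.Tensor,  # (B, T-1) log-probs from the adapter-disabled reference
    mask: torch.Tensor,      # (B, T-1) bool mask selecting assistant tokens
    adv: torch.Tensor,       # (B, 1) per-trace advantage
    eps: float = 0.2,        # stand-in for the clipping range
    kl_beta: float = 0.04,   # stand-in for the KL coefficient
) -> torch.Tensor:
    m = mask.float()
    # Clipped importance-ratio surrogate (PPO-style)
    ratio = torch.exp((pol_logp - old_logp).clamp(-20.0, 20.0))
    policy_term = -torch.minimum(ratio * adv, torch.clamp(ratio, 1 - eps, 1 + eps) * adv)
    # k3 approximation of KL(policy || reference), as in compute_loss
    rho = torch.exp((pol_logp - ref_logp).clamp(-20.0, 20.0))
    kl_approx = rho - torch.log(rho) - 1
    # Mean over assistant tokens only (the token_agg == "mean" path)
    per_token = policy_term + kl_beta * kl_approx
    return (per_token * m).sum() / m.sum().clamp_min(1.0)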
hud/rl/tests/__init__.py ADDED
@@ -0,0 +1 @@
+ """Tests for RL module."""