hud-python 0.4.28__py3-none-any.whl → 0.4.29__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (75)
  1. hud/__init__.py +2 -1
  2. hud/agents/base.py +73 -45
  3. hud/agents/claude.py +8 -4
  4. hud/agents/openai_chat_generic.py +65 -40
  5. hud/agents/tests/test_base.py +0 -4
  6. hud/agents/tests/test_openai.py +1 -1
  7. hud/cli/__init__.py +182 -52
  8. hud/cli/dev.py +8 -9
  9. hud/cli/eval.py +317 -119
  10. hud/cli/flows/__init__.py +0 -0
  11. hud/cli/flows/tasks.py +0 -0
  12. hud/cli/get.py +160 -0
  13. hud/cli/rl/__init__.py +563 -71
  14. hud/cli/rl/config.py +94 -0
  15. hud/cli/rl/display.py +133 -0
  16. hud/cli/rl/gpu.py +63 -0
  17. hud/cli/rl/gpu_utils.py +318 -0
  18. hud/cli/rl/presets.py +96 -0
  19. hud/cli/rl/remote_runner.py +348 -0
  20. hud/cli/rl/rl_api.py +150 -0
  21. hud/cli/rl/vllm.py +177 -0
  22. hud/cli/tests/test_analyze_metadata.py +0 -1
  23. hud/cli/utils/tasks.py +26 -0
  24. hud/clients/base.py +21 -23
  25. hud/clients/mcp_use.py +36 -44
  26. hud/clients/tests/test_mcp_use_retry.py +10 -10
  27. hud/datasets/__init__.py +4 -3
  28. hud/datasets/{execution/parallel.py → parallel.py} +1 -1
  29. hud/datasets/{execution/runner.py → runner.py} +1 -1
  30. hud/datasets/utils.py +1 -1
  31. hud/native/tests/test_native_init.py +1 -1
  32. hud/otel/config.py +1 -1
  33. hud/otel/instrumentation.py +35 -0
  34. hud/rl/README.md +31 -0
  35. hud/rl/__init__.py +1 -0
  36. hud/rl/actor.py +174 -0
  37. hud/rl/buffer.py +371 -0
  38. hud/rl/chat_template.jinja +101 -0
  39. hud/rl/config.py +184 -0
  40. hud/rl/distributed.py +95 -0
  41. hud/rl/learner.py +586 -0
  42. hud/rl/tests/__init__.py +1 -0
  43. hud/rl/tests/test_learner.py +171 -0
  44. hud/rl/train.py +354 -0
  45. hud/rl/types.py +101 -0
  46. hud/rl/utils/start_vllm_server.sh +30 -0
  47. hud/rl/utils.py +524 -0
  48. hud/rl/vllm_adapter.py +125 -0
  49. hud/settings.py +6 -0
  50. hud/telemetry/__init__.py +2 -1
  51. hud/telemetry/job.py +46 -3
  52. hud/telemetry/tests/test_trace.py +3 -3
  53. hud/telemetry/trace.py +85 -13
  54. hud/tools/tests/test_computer.py +3 -3
  55. hud/tools/tests/test_computer_actions.py +1 -1
  56. hud/types.py +123 -2
  57. hud/utils/group_eval.py +223 -0
  58. hud/utils/hud_console.py +113 -13
  59. hud/utils/tasks.py +119 -0
  60. hud/utils/tests/test_version.py +1 -1
  61. hud/version.py +1 -1
  62. {hud_python-0.4.28.dist-info → hud_python-0.4.29.dist-info}/METADATA +20 -2
  63. {hud_python-0.4.28.dist-info → hud_python-0.4.29.dist-info}/RECORD +66 -46
  64. hud/cli/hf.py +0 -406
  65. hud/cli/rl/README.md +0 -243
  66. hud/cli/rl/init.py +0 -370
  67. hud/cli/rl/pod.py +0 -501
  68. hud/cli/rl/ssh.py +0 -322
  69. hud/cli/rl/train.py +0 -562
  70. hud/cli/rl/utils.py +0 -165
  71. hud/datasets/execution/__init__.py +0 -13
  72. hud/datasets/task.py +0 -116
  73. {hud_python-0.4.28.dist-info → hud_python-0.4.29.dist-info}/WHEEL +0 -0
  74. {hud_python-0.4.28.dist-info → hud_python-0.4.29.dist-info}/entry_points.txt +0 -0
  75. {hud_python-0.4.28.dist-info → hud_python-0.4.29.dist-info}/licenses/LICENSE +0 -0
hud/rl/learner.py ADDED
@@ -0,0 +1,586 @@
+ """GRPO learner for vision-language and text models."""
+
+ from __future__ import annotations
+
+ import logging
+ import os
+ from typing import TYPE_CHECKING, Any
+
+ import torch
+ import torch.nn.functional as F
+ from peft import LoraConfig, get_peft_model
+ from torch.nn.parallel import DistributedDataParallel as DDP
+ from transformers import (
+     AutoModelForCausalLM,
+     AutoProcessor,
+     AutoTokenizer,
+     Qwen2_5_VLForConditionalGeneration,
+ )
+
+ try:
+     from liger_kernel.transformers import apply_liger_kernel_to_qwen2_5_vl  # type: ignore
+
+     LIGER_AVAILABLE = True
+ except ImportError:
+     LIGER_AVAILABLE = False
+
+ try:
+     import bitsandbytes as bnb  # type: ignore
+
+     BNB_AVAILABLE = True
+ except ImportError:
+     BNB_AVAILABLE = False
+
+ from contextlib import nullcontext
+
+ from hud.rl.distributed import (
+     get_local_rank,
+     get_world_size,
+     is_main_process,
+ )
+ from hud.rl.utils import (
+     batch_training_samples,
+     entropy_from_logits,
+     get_gpu_utilization,
+     get_memory_usage,
+     prepare_inputs,
+ )
+ from hud.utils.hud_console import HUDConsole
+
+ from .types import TrainingMetrics, TrainingSample
+
+ logger = logging.getLogger(__name__)
+ hud_console = HUDConsole(logger)
+
+ if TYPE_CHECKING:
+     from .config import Config
+
+
+ class GRPOLearner:
+     """GRPO learning algorithm for Vision-Language Models (VLMs) and text models."""
+
+     def __init__(self, config: Config) -> None:
+         self.config = config
+         self.local_rank = get_local_rank()
+         self.world_size = get_world_size()
+         self.device = torch.device(
+             f"cuda:{self.local_rank}" if torch.cuda.is_available() else "cpu"
+         )
+
+         # Detect model type
+         self.is_vl_model = "VL" in config.model.base_model
+
+         # Load models and processor
+         self.processor, self.policy, self.ref, self.optimizer = self._load_models()
+         self.metrics: list[TrainingMetrics] = []
+
+     def log(self, message: str) -> None:
+         hud_console.info_log(f"[{self.local_rank}] {message}")
+
+     def _load_models(self) -> tuple[Any, Any, Any, Any]:
+         """Load the processor, policy model, and optimizer.
+
+         The reference-model slot is always returned as None: reference logprobs
+         are computed by temporarily disabling the LoRA adapters (see
+         prepare_groups), so no second copy of the base weights is kept in memory.
+         """
+         model_cfg = self.config.model
+
+         # Detect if this is a VL model or a standard text model
+         is_vl_model = "VL" in model_cfg.base_model
+         model_type = "Vision-Language" if is_vl_model else "Text"
+         self.log(f"Loading {model_type} model: {model_cfg.base_model}")
+
+         # Apply Liger kernel optimizations if available and enabled
+         if model_cfg.use_liger and LIGER_AVAILABLE:
+             if is_vl_model:
+                 self.log("Applying Liger kernel optimizations to Qwen2.5-VL")
+                 apply_liger_kernel_to_qwen2_5_vl(
+                     rope=True,  # Optimized RoPE
+                     rms_norm=True,  # Optimized RMSNorm
+                     swiglu=True,  # Optimized SwiGLU
+                     fused_linear_cross_entropy=True,  # Fused linear + cross-entropy saves memory
+                 )
+         elif model_cfg.use_liger and not LIGER_AVAILABLE:
+             self.log(
+                 "Liger kernel requested but not installed. Install with: pip install liger-kernel"
+             )
+
+         # Load processor/tokenizer based on model type
+         if is_vl_model:
+             processor = AutoProcessor.from_pretrained(
+                 model_cfg.base_model,
+                 min_pixels=model_cfg.min_pixels,
+                 max_pixels=model_cfg.max_pixels,
+             )
+         else:
+             processor = AutoTokenizer.from_pretrained(model_cfg.base_model)
+
+         # Load policy model with LoRA
+         # Use attention implementation from config
+         attn_implementation = model_cfg.attn_implementation
+
+         # Choose the appropriate model class
+         model_class = Qwen2_5_VLForConditionalGeneration if is_vl_model else AutoModelForCausalLM
+
+         try:
+             policy = model_class.from_pretrained(
+                 model_cfg.base_model,
+                 torch_dtype=torch.bfloat16,
+                 attn_implementation=attn_implementation,
+             )
+             self.log(f"Using {attn_implementation} for attention")
+         except (ImportError, ValueError) as e:
+             # Fall back only when flash_attention_2 was requested but is unavailable
+             if attn_implementation == "flash_attention_2":
+                 self.log(f"Flash Attention 2 not available ({e}), using eager attention")
+                 policy = model_class.from_pretrained(
+                     model_cfg.base_model,
+                     torch_dtype=torch.bfloat16,
+                     attn_implementation="eager",
+                 )
+             else:
+                 raise  # Re-raise if it's a different error
+
+         # Move model to device
+         policy = policy.to(self.device)  # type: ignore
+         # Enable gradient checkpointing for memory efficiency
+         if model_cfg.gradient_checkpointing:
+             policy.gradient_checkpointing_enable()
+             self.log("Gradient checkpointing enabled for memory efficiency")
+
+         # Add LoRA adapters
+         lora_config = LoraConfig(
+             r=model_cfg.lora_r,
+             lora_alpha=model_cfg.lora_alpha,
+             lora_dropout=model_cfg.lora_dropout,
+             task_type="CAUSAL_LM",
+             bias="none",
+             target_modules=list(model_cfg.target_modules),
+         )
+         policy.config.use_cache = False
+         policy = get_peft_model(policy, lora_config)
+
+         # Wrap with DDP if in distributed mode
+         if self.world_size > 1:
+             policy = DDP(
+                 policy,
+                 device_ids=[self.local_rank],
+                 output_device=self.local_rank,
+                 broadcast_buffers=False,
+                 find_unused_parameters=True,
+             )
+             self.log("Wrapped model with DDP (find_unused_parameters=True)")
+
+         # Create optimizer; access the underlying model if wrapped in DDP
+         base_model = policy.module if hasattr(policy, "module") else policy
+         trainable_params = [p for _, p in base_model.named_parameters() if p.requires_grad]  # type: ignore
+
+         # Use 8-bit optimizer if configured
+         if self.config.training.use_8bit_optimizer and BNB_AVAILABLE:
+             hud_console.info("Using 8-bit AdamW optimizer from bitsandbytes")
+             optimizer = bnb.optim.AdamW8bit(
+                 trainable_params,
+                 lr=self.config.training.lr,
+                 betas=self.config.training.adam_betas,
+                 eps=self.config.training.adam_eps,
+             )
+         else:
+             self.log("Using standard FP32 AdamW optimizer")
+             optimizer = torch.optim.AdamW(
+                 trainable_params,
+                 lr=self.config.training.lr,
+                 betas=self.config.training.adam_betas,
+                 eps=self.config.training.adam_eps,
+             )
+
+         # Log optimizer info
+         self.log(f"Optimizer: {type(optimizer).__name__}")
+         num_params = sum(p.numel() for p in trainable_params)
+         self.log(f"Number of trainable parameters: {num_params:,}")
+
+         return processor, policy, None, optimizer
+
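+     # prepare_groups makes three passes over the incoming samples before any
+     # gradient step: (1) tokenize and validate each sample, substituting dummy
+     # tensors so the batch size stays consistent; (2) record old_logprobs under
+     # the current policy; (3) record ref_logprobs with the LoRA adapters
+     # disabled, which stands in for a separate reference model.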
+     def prepare_groups(
+         self,
+         samples: list[TrainingSample],
+     ) -> list[list[TrainingSample]]:
+         """Prepare groups of samples for training."""
+         # Prepare inputs with messages
+         batch = []
+         for sample in samples:
+             inputs = prepare_inputs(sample, self.processor)
+             # If inputs are invalid, create dummy inputs to maintain batch size
+             if (
+                 not inputs
+                 or "input_ids" not in inputs
+                 or inputs.get("input_ids", torch.tensor([])).numel() == 0
+             ):
+                 hud_console.warning_log("Sample has invalid inputs, using dummy values")
+                 # Create minimal dummy inputs to keep batch size consistent
+                 inputs = {
+                     "input_ids": torch.zeros(1, 2, dtype=torch.long),  # Minimal sequence
+                     "attention_mask": torch.ones(1, 2, dtype=torch.long),
+                     "assistant_mask": torch.zeros(1, 1, dtype=torch.bool),  # T-1 length
+                 }
+             elif "assistant_mask" not in inputs:
+                 hud_console.warning_log("Sample missing assistant_mask, creating zero mask")
+                 seq_len = inputs["input_ids"].shape[-1]
+                 inputs["assistant_mask"] = torch.zeros(
+                     inputs["input_ids"].shape[0], seq_len - 1, dtype=torch.bool
+                 )
+
+             new_sample = TrainingSample(**sample.model_dump())
+             new_sample.inputs = inputs
+             new_sample.advantage = sample.advantage
+             batch.append(new_sample)
+
+         with hud_console.progress("Processing batch of traces...") as progress, torch.no_grad():
+             for i, sample in enumerate(batch):
+                 if is_main_process():
+                     progress.update(f"Processing batch of traces... {i}/{len(batch)}")
+                 if sample.inputs:
+                     sample = sample.to_device(self.device)
+                     sample.old_logprobs, _ = self.compute_logprobs(self.policy, sample.inputs)
+
+             policy_module = self.policy.module if hasattr(self.policy, "module") else self.policy
+             with policy_module.disable_adapter():
+                 for i, sample in enumerate(batch):
+                     if is_main_process():
+                         progress.update(f"Processing batch of traces... {i}/{len(batch)}")
+                     if sample.inputs:
+                         sample.ref_logprobs, _ = self.compute_logprobs(self.policy, sample.inputs)
+
+         hud_console.info_log("Creating mini-batches...")
+         group_size = self.config.training.group_size
+         processed_batch = []
+         if not self.config.training.accumulate_over_minibatches:
+             # Split into mini-batches and merge each via batch_training_samples;
+             # the mini-batch size controls the batch size of each forward pass
+             mb_size = self.config.training.mini_batch_size
+             group_size = group_size // mb_size
+             for i in range(0, len(batch), mb_size):
+                 processed_batch.extend(batch_training_samples(batch[i : i + mb_size]))
+         else:
+             processed_batch = batch
+
+         for sample in processed_batch:
+             sample.to_device(torch.device("cpu"))
+
+         # Convert to grouped batches (if updating the model after each task group)
+         if self.config.training.update_after_group:
+             return [
+                 processed_batch[i : i + group_size]
+                 for i in range(0, len(processed_batch), group_size)
+             ]
+         else:
+             return [processed_batch]
+
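+     # update takes one optimizer step per group: gradients accumulate across the
+     # group's mini-batches (DDP gradient sync is deferred with no_sync until the
+     # last one), and a CUDA OOM on any rank raises a global_skip flag that is
+     # all-reduced so every rank skips the step together and DDP stays in lockstep.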
+     def update(self, samples: list[TrainingSample]) -> TrainingMetrics:
+         """Perform a gradient update on a batch."""
+         import time
+
+         training_start_time = time.time()
+
+         # Always create metrics for synchronization
+         self.metrics.append(TrainingMetrics())
+         metrics = self.metrics[-1]
+
+         # Prepare groups for GRPO training
+         groups = self.prepare_groups(samples)
+         self.log(f"Updating over {len(groups)} groups")
+
+         # Update over mini-batches
+         with hud_console.progress("Gradient update...") as progress:
+             for epoch in range(self.config.training.epochs):  # Do not accumulate across epochs
+                 progress.update(f"Training epoch {epoch + 1}/{self.config.training.epochs}")
+                 for group_idx, group in enumerate(groups):  # Do not accumulate across "groups"
+                     self.optimizer.zero_grad(set_to_none=True)
+
+                     debug_per_group = ""
+                     grad_accum_steps = len(group)
+                     # Tensor for distributed sync
+                     global_skip = torch.zeros(1, device=self.device)
+
+                     for s_idx, sample_minibatch in enumerate(group):
+                         # self.log(f"{group_idx} {sample_minibatch.inputs['assistant_mask'].sum()}")
+                         # mini_updated = sample_minibatch.inputs["assistant_mask"].sum() > 0
+
+                         # Update mini_updated globally
+                         # self.log(f"{group_idx} Mini updated: {mini_updated}")
+
+                         # Do not sync gradients until the last minibatch
+                         if s_idx < len(group) - 1 and self.world_size > 1:
+                             ddp_ctx = self.policy.no_sync()
+                         else:
+                             ddp_ctx = nullcontext()
+
+                         with ddp_ctx, torch.autocast(device_type="cuda", dtype=torch.bfloat16):
+                             try:
+                                 # if mini_updated:
+                                 loss = self.compute_loss(sample_minibatch) / grad_accum_steps
+                                 debug_per_group += f"l{s_idx}:{round(loss.item(), 3)!s} "
+                                 loss.backward()
+                                 # else:  # Dummy backward that touches all params, produces zero grads
+                                 #     dummy = sum(p.sum() for p in self.policy.parameters()) * 0.0
+                                 #     debug_per_group += f"d{s_idx}:{str(round(dummy.item(), 3))} "
+                                 #     dummy.backward()
+                                 # self.log(f"{group_idx} GPU Backward: {get_gpu_utilization():.1f}% | Memory: {get_memory_usage():.2f} GB")  # noqa: E501
+                             except torch.cuda.OutOfMemoryError:
+                                 hud_console.warning_log(
+                                     f"{group_idx} CUDA OOM for {sample_minibatch.inputs['input_ids'].numel()} tokens; skipping minibatch"  # noqa: E501
+                                 )
+                                 # Dummy backward over all params to keep DDP's reducer in sync
+                                 dummy = sum(p.sum() for p in self.policy.parameters()) * 0.0  # type: ignore
+                                 debug_per_group += f"o{s_idx}:{round(dummy.item(), 3)!s} "
+                                 dummy.backward()
+                                 # Mark global skip if OOM
+                                 global_skip.fill_(1)
+                                 continue
+
+                         if torch.cuda.is_available():
+                             torch.cuda.empty_cache()
+
+                     # After the minibatch loop, sync the skip flag across ranks
+                     if torch.distributed.is_initialized():
+                         torch.distributed.all_reduce(global_skip, op=torch.distributed.ReduceOp.MAX)
+                     skip_any = bool(global_skip.item())
+
+                     if skip_any:
+                         self.log(f"G[{group_idx}] {debug_per_group} N/A (skipped)")
+                         continue
+
+                     grad_norm = torch.nn.utils.clip_grad_norm_(
+                         self.policy.parameters(),
+                         self.config.training.grad_clip,
+                         error_if_nonfinite=True,
+                     )
+                     self.optimizer.step()
+
+                     debug_per_group += f"g:{round(grad_norm.item(), 3)!s}"
+                     self.log(f"G[{group_idx}] {debug_per_group}")
+
+                     metrics.update(
+                         {
+                             "grad_norm": grad_norm.item()
+                             if isinstance(grad_norm, torch.Tensor)
+                             else float(grad_norm),
+                         }
+                     )
+
+         # Calculate training time and throughput
+         training_time = time.time() - training_start_time
+         total_samples = (
+             len(groups) * self.config.training.group_size * self.config.training.mini_batch_size
+         )
+         samples_per_second = total_samples / training_time if training_time > 0 else 0.0
+
+         metrics.update(
+             {
+                 "training_time": training_time,
+                 "samples_per_second": samples_per_second,
+             }
+         )
+
+         return metrics
+
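+     # compute_loss implements the clipped-surrogate GRPO objective. Per assistant
+     # token t (or per trace, when ppo_mode == "per_trace"):
+     #     r_t   = exp(logp - old_logp)    # ratio to the rollout policy
+     #     rho_t = exp(logp - ref_logp)    # ratio to the adapter-free reference
+     #     loss  = -min(r_t * A, clip(r_t, 1 - bottom_eps, 1 + top_eps) * A)
+     #             + kl_beta * (rho_t - log(rho_t) - 1)   # k3 KL estimator
+     #             + entropy_beta * H_t
+     # with log-ratios clamped to [-20, 20] before exponentiation, then reduced
+     # over assistant tokens according to token_agg.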
+     def compute_loss(self, sample: TrainingSample) -> torch.Tensor:
+         """Compute GRPO loss for a batch of samples."""
+         training_cfg = self.config.training
+         metrics = self.metrics[-1] if len(self.metrics) > 0 else TrainingMetrics()
+
+         sample.to_device(self.device)
+
+         pol_logp, pol_entropy = self.compute_logprobs(
+             self.policy,
+             sample.inputs,
+         )
+
+         sanity_check(sample, pol_logp, sample.old_logprobs, sample.ref_logprobs)
+
+         metrics.update(
+             {
+                 "gpu_util": get_gpu_utilization(),  # Track peak utilization
+                 "gpu_memory": get_memory_usage(),  # Track memory usage
+             }
+         )
+         self.log(f"GPU Util: {get_gpu_utilization():.1f}% | Memory: {get_memory_usage():.2f} GB")
+
+         old_logp = sample.old_logprobs
+         ref_logp = sample.ref_logprobs
+
+         if old_logp is None or ref_logp is None or sample.advantage is None:
+             raise ValueError("old_logp, ref_logp, or sample.advantage is None")
+
+         # Use assistant mask to remove non-assistant tokens
+         m = sample.inputs["assistant_mask"]
+
+         # Aggregate per trace or per token
+         if training_cfg.ppo_mode == "per_trace":
+             counts = m.sum(dim=1).clamp_min(1.0)
+             pol_logp = (pol_logp * m.float()).sum(dim=1) / counts
+             pol_entropy = (pol_entropy * m.float()).sum(dim=1) / counts
+             old_logp = (old_logp * m.float()).sum(dim=1) / counts
+             ref_logp = (ref_logp * m.float()).sum(dim=1) / counts
+
+         # Clip log probability differences before exponentiating
+         log_ratio = torch.where(m, pol_logp - old_logp, torch.zeros_like(pol_logp))
+         ratio_tok = torch.exp(log_ratio.clamp(-20.0, 20.0))
+
+         # Ensure advantage shape matches ratio_tok for broadcasting
+         advantage = (
+             sample.advantage.view(-1, 1) if ratio_tok.dim() == 2 else sample.advantage.squeeze(-1)
+         )
+
+         unclipped = ratio_tok * advantage
+         # Asymmetric clip range: [1 - bottom_eps, 1 + top_eps]
+         clipped = (
+             torch.clamp(ratio_tok, 1 - training_cfg.bottom_eps, 1 + training_cfg.top_eps)
+             * advantage
+         )
+
+         policy_term = -torch.minimum(unclipped, clipped)
+
+         # Clip log probability differences in KL
+         log_rho = torch.where(m, pol_logp - ref_logp, torch.zeros_like(pol_logp))
+         rho_tok = torch.exp(log_rho.clamp(-20.0, 20.0))
+         kl_approx = rho_tok - torch.log(rho_tok) - 1
+
+         total_loss = (
+             policy_term + training_cfg.kl_beta * kl_approx + training_cfg.entropy_beta * pol_entropy
+         )
+
+         # Aggregate loss
+         if training_cfg.ppo_mode == "per_trace":
+             total_loss = total_loss.mean() if training_cfg.token_agg == "mean" else total_loss.sum()  # noqa: S105
+         else:
+             if training_cfg.token_agg == "mean":  # noqa: S105
+                 total_loss = (total_loss * m).sum() / m.sum().clamp_min(1.0)
+             else:
+                 total_loss = (total_loss * m).sum()
+
+         # Compute metrics only over masked (assistant) tokens
+         mask_count = m.sum().clamp_min(1.0)
+         metrics.update(
+             {
+                 "policy_ratio": (ratio_tok * m).sum().item() / mask_count.item()
+                 if mask_count.item() > 0
+                 else 1.0,
+                 "kl": (kl_approx * m).sum().item() / mask_count.item()
+                 if mask_count.item() > 0
+                 else 0.0,
+                 "entropy": (pol_entropy * m).sum().item() / mask_count.item()
+                 if mask_count.item() > 0
+                 else 0.0,
+                 "tokens": sample.inputs["input_ids"].numel(),
+                 "loss": total_loss.item(),
+             }
+         )
+
+         sample.to_device(torch.device("cpu"))
+
+         return total_loss
+
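+     # compute_logprobs scales logits by the actor's sampling temperature so the
+     # scores match the distribution the rollouts were drawn from, and computes
+     # entropy only at assistant-token positions to bound activation memory.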
+     def compute_logprobs(self, model: Any, inputs: Any) -> tuple[torch.Tensor, torch.Tensor]:
+         """Compute masked per-token log probabilities via the model.
+
+         Returns log probabilities for the actual next tokens.
+         """
+         try:
+             model_inputs = {k: v for k, v in inputs.items() if k != "assistant_mask"}
+             out = model(**model_inputs)
+
+             logits = out.logits / self.config.actor.temperature
+             log_probs = F.log_softmax(logits, dim=-1)
+
+             targets = inputs["input_ids"][:, 1:]
+             token_log_probs = log_probs[:, :-1].gather(-1, targets.unsqueeze(-1)).squeeze(-1)
+
+             # Compute entropy only for assistant tokens to save memory
+             assistant_mask = inputs["assistant_mask"]
+             entropy = torch.zeros_like(token_log_probs)
+             if assistant_mask.any():
+                 entropy[assistant_mask] = entropy_from_logits(logits[:, :-1][assistant_mask])
+
+             return token_log_probs, entropy
+         except (IndexError, RuntimeError) as e:
+             # Handle empty inputs or DDP errors
+             hud_console.warning_log(f"Error in compute_logprobs: {e}. Returning dummy values.")
+             # Return dummy values that match the expected shapes
+             seq_len = inputs["input_ids"].shape[1] - 1 if "input_ids" in inputs else 0
+             batch_size = inputs["input_ids"].shape[0] if "input_ids" in inputs else 1
+             dummy_logprobs = torch.zeros(batch_size, seq_len, device=self.device)
+             dummy_entropy = torch.zeros(batch_size, seq_len, device=self.device)
+             return dummy_logprobs, dummy_entropy
+
+     def save(self, path: str) -> None:
+         """Save the current policy checkpoint (only on rank 0)."""
+         if is_main_process():
+             os.makedirs(path, exist_ok=True)
+             # Unwrap DDP model if needed
+             model_to_save = self.policy.module if hasattr(self.policy, "module") else self.policy
+             model_to_save.save_pretrained(path)
+             self.log(f"Saved checkpoint to {path}")
+
+     def load(self, path: str) -> None:
+         """Load a policy checkpoint."""
+         # Would need to reload LoRA weights
+         self.log(f"Loading checkpoint from {path}")
+         # Implementation depends on PEFT version
+
+
+ def sanity_check(
+     sample: TrainingSample,
+     pol_logp: torch.Tensor,
+     old_logp: torch.Tensor | None,
+     ref_logp: torch.Tensor | None,
+ ) -> None:
+     """Run shape, mask, and finiteness diagnostics on masked logprobs (warnings only)."""
+     assert "assistant_mask" in sample.inputs  # noqa: S101
+     m = sample.inputs["assistant_mask"]
+     if old_logp is None or ref_logp is None:
+         return
+     with torch.no_grad():
+         B, K = pol_logp.shape
+         assert old_logp.shape == (B, K), "old_logp shape mismatch"  # noqa: S101
+         assert ref_logp.shape == (B, K), "ref_logp shape mismatch"  # noqa: S101
+         assert m.shape == (B, K), "assistant_mask shape mismatch"  # noqa: S101
+
+         # Check the mask is a subset of attention_mask[:, 1:]
+         att = sample.inputs.get("attention_mask", None)
+         if att is not None and att.dim() == 2:
+             att_shift = att[:, 1:].bool()
+             bad = (m & ~att_shift).sum().item()
+             if bad > 0:
+                 hud_console.warning_log(f"assistant_mask overlaps padding: {bad} tokens")
+
+         # Finiteness on masked entries only
+         def _stats(name: str, t: torch.Tensor) -> None:
+             sel = t[m]
+             if sel.numel() == 0:
+                 hud_console.warning_log(f"{name} empty under mask")
+                 return
+             finite = torch.isfinite(sel)
+             if finite.sum() < sel.numel():
+                 hud_console.warning_log(
+                     f"{name} non-finite: {((~finite).sum().item())}/{sel.numel()}"
+                 )
+                 sel = sel[finite].float()
+
+         _stats("pol_logp", pol_logp)
+         _stats("old_logp", old_logp)
+         _stats("ref_logp", ref_logp)
+
+         # Log-probabilities should be <= 0 (log-softmax)
+         if (pol_logp[m] > 1e-6).any():
+             hud_console.warning_log("pol_logp has positive values under mask")
+
+         # Precompute masked deltas and ratios for diagnostics (before exp)
+         masked_log_ratio = torch.zeros_like(pol_logp)
+         masked_log_ratio[m] = (pol_logp - old_logp)[m]
+         masked_log_rho = torch.zeros_like(pol_logp)
+         masked_log_rho[m] = (pol_logp - ref_logp)[m]
+
+         _stats("log_ratio(masked)", masked_log_ratio)
+         _stats("log_rho(masked)", masked_log_rho)
+
+         # Ratios after clamp (diagnostic only)
+         ratio_diag = torch.zeros_like(pol_logp)
+         rho_diag = torch.zeros_like(pol_logp)
+         ratio_diag[m] = torch.exp(masked_log_ratio[m].clamp(-20.0, 20.0))
+         rho_diag[m] = torch.exp(masked_log_rho[m].clamp(-20.0, 20.0))
+         _stats("ratio_tok(masked)", ratio_diag)
+         _stats("rho_tok(masked)", rho_diag)
hud/rl/tests/__init__.py ADDED
@@ -0,0 +1 @@
+ """Tests for RL module."""