textpolicy 0.1.2__py3-none-any.whl → 0.1.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -26,7 +26,48 @@ if mx is None:
26
26
  return fn
27
27
 
28
28
  mx = _DummyMx()
29
- from typing import List, Union
29
+ from typing import List, Union, Tuple, Dict, Any, Optional
30
+ from dataclasses import dataclass
31
+ from collections import defaultdict
32
+
33
+ # Import length shaping utilities from dedicated module
34
+ from .length_shaping import (
35
+ compute_length_penalty,
36
+ apply_length_shaping,
37
+ compute_length_shaping_stats,
38
+ )
39
+
40
+
41
+ # --- Clip Configuration Helper ---
42
+ @dataclass
43
+ class ClipConfig:
44
+ """Configuration for PPO/DAPO clipping bounds."""
45
+ low: float = 0.2
46
+ high: float = 0.28
47
+
48
+
49
+ def resolve_clip_config(
50
+ clip_ratio: Optional[float],
51
+ clip_ratio_low: float = 0.2,
52
+ clip_ratio_high: float = 0.28,
53
+ ) -> ClipConfig:
54
+ """
55
+ Resolve clipping configuration with backward compatibility.
56
+
57
+ Centralizes the logic for handling symmetric vs asymmetric clipping bounds.
58
+
59
+ Args:
60
+ clip_ratio: Symmetric clipping ratio (backward compatibility).
61
+ If provided, overrides clip_ratio_low and clip_ratio_high.
62
+ clip_ratio_low: Lower bound offset (default 0.2)
63
+ clip_ratio_high: Upper bound offset (default 0.28)
64
+
65
+ Returns:
66
+ ClipConfig with resolved low and high bounds
67
+ """
68
+ if clip_ratio is not None:
69
+ return ClipConfig(low=clip_ratio, high=clip_ratio)
70
+ return ClipConfig(low=clip_ratio_low, high=clip_ratio_high)
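# Editor's example (illustrative sketch, not part of the released package):
# how the three common call patterns resolve. The dataclass generates __eq__,
# so the expected bounds can be asserted directly.
assert resolve_clip_config(0.2) == ClipConfig(low=0.2, high=0.2)              # legacy symmetric
assert resolve_clip_config(None) == ClipConfig(low=0.2, high=0.28)            # DAPO defaults
assert resolve_clip_config(None, 0.1, 0.3) == ClipConfig(low=0.1, high=0.3)   # explicit asymmetric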
30
71
 
31
72
 
32
73
  def compute_advantages(rewards: Union[List[float], mx.array]) -> mx.array:
@@ -113,45 +154,88 @@ def compute_advantages_dr_grpo(rewards: Union[List[float], mx.array]) -> mx.arra
113
154
 
114
155
  def policy_loss(
115
156
  old_logprobs: mx.array,
116
- new_logprobs: mx.array,
157
+ new_logprobs: mx.array,
117
158
  advantages: mx.array,
118
- clip_ratio: float = 0.2
159
+ clip_ratio: Optional[float] = None,
160
+ clip_ratio_low: float = 0.2,
161
+ clip_ratio_high: float = 0.28,
162
+ normalize_constant: Optional[int] = None
119
163
  ) -> mx.array:
120
164
  """
121
- GRPO policy loss with PPO-style clipping.
122
-
165
+ GRPO policy loss with PPO-style clipping, supporting DAPO asymmetric bounds.
166
+
123
167
  Uses clipped surrogate objective but with group-relative advantages
124
- instead of GAE advantages.
125
-
168
+ instead of GAE advantages. Supports asymmetric clipping bounds (DAPO-style)
169
+ to prevent entropy collapse while maintaining training stability.
170
+
171
+ DAPO insight: Asymmetric bounds allow the model to increase probabilities
172
+ of good actions more easily than decreasing probabilities of bad actions,
173
+ promoting diversity and preventing entropy collapse.
174
+
175
+ Dr. GRPO insight: Dividing by a fixed constant instead of token count
176
+ eliminates length bias that artificially inflates incorrect (longer) responses.
177
+
126
178
  Args:
127
179
  old_logprobs: Log probabilities from rollout collection
128
180
  new_logprobs: Log probabilities from current policy evaluation
129
181
  advantages: Group-relative advantages from compute_advantages()
130
- clip_ratio: Clipping ratio for surrogate objective
131
-
182
+ clip_ratio: Symmetric clipping ratio (for backward compatibility).
183
+ If provided, overrides clip_ratio_low and clip_ratio_high.
184
+ clip_ratio_low: Lower bound offset (default 0.2, gives lower bound of 0.8)
185
+ clip_ratio_high: Upper bound offset (default 0.28, gives upper bound of 1.28)
186
+ normalize_constant: Fixed constant divisor for loss normalization.
187
+ If None (default), uses mean (original behavior).
188
+ If provided, uses sum/constant to eliminate length bias.
189
+ Typical values: 1024, or batch_size.
190
+
132
191
  Returns:
133
192
  Policy loss scalar (to be minimized)
134
-
193
+
135
194
  Notes:
136
195
  - Fully vectorized (no Python loops over batch)
137
196
  - Uses in-place operations where possible
138
197
  - Suitable for MLX graph optimization
139
198
  - Single forward pass through computation
199
+ - DAPO defaults: clip_ratio_low=0.2, clip_ratio_high=0.28
200
+ - Length bias: When using the mean, longer sequences have lower per-token
201
+ contribution, creating an implicit bias toward short responses.
202
+
203
+ References:
204
+ DAPO: An Open-Source LLM Reinforcement Learning System at Scale
205
+ https://arxiv.org/abs/2503.14476
206
+
207
+ Dr. GRPO: Understanding R1-Zero-Like Training
208
+ https://arxiv.org/abs/2503.20783
140
209
  """
141
- # Importance ratio: π_new / π_old
210
+ # Resolve clipping configuration (handles backward compatibility)
211
+ clip_cfg = resolve_clip_config(clip_ratio, clip_ratio_low, clip_ratio_high)
212
+
213
+ # Importance ratio: π_new / π_old
142
214
  # MLX optimizes exp() for Apple Silicon
143
215
  ratio = mx.exp(new_logprobs - old_logprobs)
144
-
145
- # PPO clipped surrogate objective
146
- # L = min(ratio * A, clip(ratio, 1-ε, 1+ε) * A)
147
- clipped_ratio = mx.clip(ratio, 1 - clip_ratio, 1 + clip_ratio)
148
-
149
- # Element-wise minimum and mean reduction
150
- # Negative because we minimize (original maximizes)
216
+
217
+ # PPO clipped surrogate objective with asymmetric bounds (DAPO-style)
218
+ # L = min(ratio * A, clip(ratio, 1-ε_low, 1+ε_high) * A)
219
+ clipped_ratio = mx.clip(ratio, 1 - clip_cfg.low, 1 + clip_cfg.high)
220
+
221
+ # Element-wise minimum
151
222
  surr1 = ratio * advantages
152
223
  surr2 = clipped_ratio * advantages
153
- loss = -mx.mean(mx.minimum(surr1, surr2))
154
-
224
+ min_surr = mx.minimum(surr1, surr2)
225
+
226
+ # Normalization: either mean (original) or sum/constant (Dr. GRPO)
227
+ if normalize_constant is not None:
228
+ if normalize_constant <= 0:
229
+ raise ValueError(
230
+ f"normalize_constant must be positive, got {normalize_constant}"
231
+ )
232
+ # Fixed constant normalization eliminates length bias
233
+ # All sequences contribute equally regardless of length
234
+ loss = -mx.sum(min_surr) / normalize_constant
235
+ else:
236
+ # Original mean behavior (for backward compatibility)
237
+ loss = -mx.mean(min_surr)
238
+
155
239
  return loss
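# Editor's example (illustrative sketch, not part of the released package):
# a toy call showing the default asymmetric clip (0.2 low / 0.28 high) and the
# Dr. GRPO constant-normalized variant on the same inputs.
_toy_old = mx.array([-1.0, -1.2, -0.7])
_toy_new = mx.array([-0.8, -1.5, -0.7])
_toy_adv = mx.array([1.0, -1.0, 0.5])
_loss_mean = policy_loss(_toy_old, _toy_new, _toy_adv)                            # mean over tokens
_loss_const = policy_loss(_toy_old, _toy_new, _toy_adv, normalize_constant=1024)  # sum / 1024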
156
240
 
157
241
 
@@ -163,86 +247,491 @@ def compute_advantages_compiled(rewards: mx.array) -> mx.array:
163
247
  return rewards - group_mean
164
248
 
165
249
 
166
- @mx.compile
167
- def policy_loss_compiled(
250
+ # --- Compiled Policy Loss Variants ---
251
+ # Two internal compiled functions for different normalization strategies.
252
+ # Compiled functions require static control flow, so we keep them separate.
253
+
254
+ @mx.compile
255
+ def _policy_loss_compiled_mean(
168
256
  old_logprobs: mx.array,
169
257
  new_logprobs: mx.array,
170
258
  advantages: mx.array,
171
- clip_ratio: float = 0.2
259
+ clip_ratio_low: float,
260
+ clip_ratio_high: float
172
261
  ) -> mx.array:
173
- """Compiled version of policy_loss for maximum performance."""
262
+ """Internal compiled function: mean normalization."""
174
263
  ratio = mx.exp(new_logprobs - old_logprobs)
175
- clipped_ratio = mx.clip(ratio, 1 - clip_ratio, 1 + clip_ratio)
264
+ clipped_ratio = mx.clip(ratio, 1 - clip_ratio_low, 1 + clip_ratio_high)
176
265
  surr1 = ratio * advantages
177
266
  surr2 = clipped_ratio * advantages
178
267
  return -mx.mean(mx.minimum(surr1, surr2))
179
268
 
180
269
 
270
+ @mx.compile
271
+ def _policy_loss_compiled_constant(
272
+ old_logprobs: mx.array,
273
+ new_logprobs: mx.array,
274
+ advantages: mx.array,
275
+ clip_ratio_low: float,
276
+ clip_ratio_high: float,
277
+ normalize_constant: float
278
+ ) -> mx.array:
279
+ """Internal compiled function: constant normalization."""
280
+ ratio = mx.exp(new_logprobs - old_logprobs)
281
+ clipped_ratio = mx.clip(ratio, 1 - clip_ratio_low, 1 + clip_ratio_high)
282
+ surr1 = ratio * advantages
283
+ surr2 = clipped_ratio * advantages
284
+ return -mx.sum(mx.minimum(surr1, surr2)) / normalize_constant
285
+
286
+
287
+ def policy_loss_compiled(
288
+ old_logprobs: mx.array,
289
+ new_logprobs: mx.array,
290
+ advantages: mx.array,
291
+ clip_ratio: Optional[float] = None,
292
+ clip_ratio_low: float = 0.2,
293
+ clip_ratio_high: float = 0.28
294
+ ) -> mx.array:
295
+ """
296
+ Compiled version of policy_loss for maximum performance (mean normalization).
297
+
298
+ Supports DAPO-style asymmetric clipping bounds with backward compatibility.
299
+ Uses mean normalization (original behavior).
300
+
301
+ Args:
302
+ old_logprobs: Log probabilities from rollout collection
303
+ new_logprobs: Log probabilities from current policy evaluation
304
+ advantages: Group-relative advantages
305
+ clip_ratio: Symmetric clipping ratio (for backward compatibility).
306
+ If provided, overrides clip_ratio_low and clip_ratio_high.
307
+ clip_ratio_low: Lower bound offset (default 0.2)
308
+ clip_ratio_high: Upper bound offset (default 0.28)
309
+ """
310
+ clip_cfg = resolve_clip_config(clip_ratio, clip_ratio_low, clip_ratio_high)
311
+ return _policy_loss_compiled_mean(
312
+ old_logprobs, new_logprobs, advantages,
313
+ clip_cfg.low, clip_cfg.high
314
+ )
315
+
316
+
317
+ def policy_loss_compiled_constant_norm(
318
+ old_logprobs: mx.array,
319
+ new_logprobs: mx.array,
320
+ advantages: mx.array,
321
+ clip_ratio: Optional[float] = None,
322
+ clip_ratio_low: float = 0.2,
323
+ clip_ratio_high: float = 0.28,
324
+ normalize_constant: float = 1024.0
325
+ ) -> mx.array:
326
+ """
327
+ Compiled version of policy_loss with fixed constant normalization (Dr. GRPO).
328
+
329
+ Uses sum/constant instead of mean to eliminate length bias.
330
+
331
+ Args:
332
+ old_logprobs: Log probabilities from rollout collection
333
+ new_logprobs: Log probabilities from current policy evaluation
334
+ advantages: Group-relative advantages
335
+ clip_ratio: Symmetric clipping ratio (for backward compatibility).
336
+ If provided, overrides clip_ratio_low and clip_ratio_high.
337
+ clip_ratio_low: Lower bound offset (default 0.2)
338
+ clip_ratio_high: Upper bound offset (default 0.28)
339
+ normalize_constant: Fixed constant divisor (default 1024)
340
+
341
+ References:
342
+ Dr. GRPO: Understanding R1-Zero-Like Training
343
+ https://arxiv.org/abs/2503.20783
344
+ """
345
+ if normalize_constant <= 0:
346
+ raise ValueError(
347
+ f"normalize_constant must be positive, got {normalize_constant}"
348
+ )
349
+ clip_cfg = resolve_clip_config(clip_ratio, clip_ratio_low, clip_ratio_high)
350
+ return _policy_loss_compiled_constant(
351
+ old_logprobs, new_logprobs, advantages,
352
+ clip_cfg.low, clip_cfg.high, normalize_constant
353
+ )
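# Editor's example (illustrative sketch, not part of the released package):
# the two compiled entry points accept the same clip arguments and differ only
# in how the summed surrogate is normalized.
_c_old = mx.array([-1.0, -0.9])
_c_new = mx.array([-0.9, -1.1])
_c_adv = mx.array([0.5, -0.5])
_c_mean = policy_loss_compiled(_c_old, _c_new, _c_adv)
_c_const = policy_loss_compiled_constant_norm(_c_old, _c_new, _c_adv, normalize_constant=8.0)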
354
+
355
+
181
356
  def entropy_bonus(logprobs: mx.array, coefficient: float = 0.01) -> mx.array:
182
357
  """
183
358
  Entropy bonus for exploration (optional GRPO component).
184
-
359
+
185
360
  Args:
186
361
  logprobs: Log probabilities from policy
187
362
  coefficient: Entropy coefficient (typically small, like 0.01)
188
-
363
+
189
364
  Returns:
190
365
  Entropy bonus (added to loss for exploration)
191
366
  """
192
367
  if coefficient <= 0:
193
368
  return mx.array(0.0)
194
-
369
+
195
370
  # Entropy = -sum(p * log(p))
196
371
  # For log probabilities: entropy = -sum(exp(logp) * logp)
197
372
  probs = mx.exp(logprobs)
198
373
  entropy = -mx.sum(probs * logprobs, axis=-1)
199
-
374
+
200
375
  # Return negative entropy (since we add to loss but want to maximize entropy)
201
376
  return -coefficient * mx.mean(entropy)
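# Editor's example (illustrative sketch, not part of the released package):
# a uniform distribution over 3 tokens maximizes entropy (ln 3 ≈ 1.0986), so it
# yields the most negative (most rewarding) bonus; a peaked one stays near zero.
_uniform_logp = mx.log(mx.array([[1 / 3, 1 / 3, 1 / 3]]))
_peaked_logp = mx.log(mx.array([[0.98, 0.01, 0.01]]))
_bonus_uniform = entropy_bonus(_uniform_logp, coefficient=0.01)   # ≈ -0.011
_bonus_peaked = entropy_bonus(_peaked_logp, coefficient=0.01)     # ≈ -0.001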
202
377
 
203
378
 
379
+ # Note: compute_length_penalty, apply_length_shaping, and compute_length_shaping_stats
380
+ # are imported from .length_shaping module (see imports at top of file)
381
+
382
+
383
+ # DAPO-style dynamic batch filtering (Issue #9)
384
+ def _get_episode_reward(episode) -> float:
385
+ """Extract total reward from episode (handles both Episode objects and dicts)."""
386
+ if hasattr(episode, 'rew'):
387
+ # Episode object
388
+ return float(mx.sum(mx.array(episode.rew)).item())
389
+ else:
390
+ # Serialized dictionary
391
+ rew = episode.get('rew', episode.get('reward', [0.0]))
392
+ if isinstance(rew, (int, float)):
393
+ return float(rew)
394
+ return float(mx.sum(mx.array(rew)).item())
395
+
396
+
397
+ def _get_prompt_key(episode) -> tuple:
398
+ """
399
+ Generate a hashable key for an episode's prompt.
400
+
401
+ Handles both Episode objects and serialized dictionaries.
402
+ Uses the observation (prompt) tokens to identify the prompt.
403
+ """
404
+ if hasattr(episode, 'obs'):
405
+ obs = episode.obs
406
+ else:
407
+ obs = episode.get('obs', [])
408
+
409
+ # Flatten nested structures to create consistent key
410
+ flattened = []
411
+ for item in obs:
412
+ if hasattr(item, 'tolist'): # MLX array
413
+ flattened.extend(item.tolist())
414
+ elif isinstance(item, list):
415
+ flattened.extend(item)
416
+ else:
417
+ flattened.append(item)
418
+
419
+ return tuple(flattened)
420
+
421
+
422
+ def _precompute_episode_rewards(episodes: List[Any]) -> List[float]:
423
+ """
424
+ Pre-compute rewards for all episodes in a single pass.
425
+
426
+ Uses batched MLX evaluation to avoid per-episode .item() sync barriers.
427
+ All mx.sum() calls are built lazily, then evaluated in one mx.eval() call.
428
+
429
+ Args:
430
+ episodes: List of episodes
431
+
432
+ Returns:
433
+ List of rewards in the same order as episodes
434
+ """
435
+ if not episodes:
436
+ return []
437
+
438
+ rewards: List[Optional[float]] = [None] * len(episodes)
439
+ pending: List[Tuple[int, mx.array]] = [] # (index, lazy_sum) pairs
440
+
441
+ for i, ep in enumerate(episodes):
442
+ if hasattr(ep, 'rew'):
443
+ rew = ep.rew
444
+ else:
445
+ rew = ep.get('rew', ep.get('reward', [0.0]))
446
+
447
+ if isinstance(rew, (int, float)):
448
+ rewards[i] = float(rew)
449
+ else:
450
+ pending.append((i, mx.sum(mx.array(rew))))
451
+
452
+ # Single sync barrier for all array rewards
453
+ if pending:
454
+ indices, lazy_sums = zip(*pending)
455
+ stacked = mx.stack(list(lazy_sums))
456
+ mx.eval(stacked)
457
+ values = stacked.tolist()
458
+ for idx, val in zip(indices, values):
459
+ rewards[idx] = float(val)
460
+
461
+ return rewards # type: ignore[return-value]
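# Editor's example (illustrative sketch, not part of the released package):
# rewards for a mixed batch of dict-style episodes are gathered with a single
# mx.eval() barrier instead of one .item() sync per episode.
_reward_eps = [
    {'rew': [1.0, 0.0, 1.0]},   # list reward -> summed lazily, evaluated in batch
    {'reward': 2.0},            # scalar reward -> taken directly, no sync needed
]
_precomputed = _precompute_episode_rewards(_reward_eps)   # [2.0, 2.0]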
462
+
463
+
464
+ def _compute_group_variance_and_mean(
465
+ group_indices: List[int],
466
+ all_rewards: List[float]
467
+ ) -> Tuple[float, float]:
468
+ """
469
+ Compute variance and mean for a group of episodes using pre-computed rewards.
470
+
471
+ Args:
472
+ group_indices: Indices into all_rewards for this group
473
+ all_rewards: Pre-computed rewards for all episodes
474
+
475
+ Returns:
476
+ Tuple of (variance, mean)
477
+ """
478
+ group_rewards = mx.array([all_rewards[i] for i in group_indices])
479
+ return mx.var(group_rewards).item(), mx.mean(group_rewards).item()
480
+
481
+
482
+ def filter_informative_prompts(
483
+ episodes: List[Any],
484
+ min_variance: float = 0.01,
485
+ keep_single_completion: bool = True
486
+ ) -> Tuple[List[Any], Dict[str, Union[int, float]]]:
487
+ """
488
+ Filter episodes to keep only informative prompts (DAPO dynamic sampling).
489
+
490
+ Removes prompts where all completions have same outcome:
491
+ - All correct (reward ~1.0): no learning signal (nothing to improve)
492
+ - All wrong (reward ~0.0): no positive signal (can't learn what works)
493
+
494
+ GRPO uses group-relative advantages. If all completions have the same
495
+ outcome, advantages are zero, producing no gradient and wasting compute.
496
+
497
+ Note on single-completion prompts:
498
+ The DAPO paper (Equation 11) defines informative prompts as having
499
+ mixed outcomes: `0 < |correct| < G`. This assumes G > 1 completions
500
+ per prompt. For single-completion prompts (G=1), variance is always 0
501
+ by definition, but this doesn't mean "all outcomes are the same" -
502
+ it means we have insufficient data to determine variance.
503
+
504
+ By default (keep_single_completion=True), single-completion prompts
505
+ are kept since they still provide valid gradient signal. Set to False
506
+ to filter them out (stricter DAPO interpretation).
507
+
508
+ Args:
509
+ episodes: List of episodes (Episode objects or serialized dicts)
510
+ min_variance: Minimum reward variance to keep a prompt group.
511
+ Groups with variance below this threshold are filtered out.
512
+ Default 0.01 filters prompts with essentially identical rewards.
513
+ Only applied to groups with 2+ completions.
514
+ keep_single_completion: Whether to keep prompts with only one completion.
515
+ Default True (keep them). Set False to require
516
+ multiple completions for variance calculation.
517
+
518
+ Returns:
519
+ Tuple of:
520
+ - filtered: List of episodes from informative prompts
521
+ - stats: Dictionary with filtering statistics:
522
+ - 'prompts_kept': Number of prompt groups kept
523
+ - 'prompts_dropped_all_correct': Prompts where all completions succeeded
524
+ - 'prompts_dropped_all_wrong': Prompts where all completions failed
525
+ - 'prompts_dropped_single': Prompts dropped due to single completion
526
+ - 'prompts_kept_single': Single-completion prompts that were kept
527
+ - 'episodes_kept': Total episodes kept
528
+ - 'episodes_dropped': Total episodes filtered out
529
+ - 'filter_rate': Fraction of prompts filtered
530
+
531
+ Example:
532
+ >>> filtered, stats = filter_informative_prompts(episodes, min_variance=0.01)
533
+ >>> print(f"Kept {stats['prompts_kept']} prompts, "
534
+ ... f"dropped {stats['prompts_dropped_all_correct']} all-correct, "
535
+ ... f"{stats['prompts_dropped_all_wrong']} all-wrong")
536
+
537
+ References:
538
+ DAPO: An Open-Source LLM Reinforcement Learning System at Scale
539
+ https://arxiv.org/abs/2503.14476 (Equation 11: 0 < |correct| < G)
540
+
541
+ GRPO++ Tricks
542
+ https://cameronrwolfe.substack.com/p/grpo-tricks
543
+ """
544
+ if not episodes:
545
+ return [], {
546
+ 'prompts_kept': 0,
547
+ 'prompts_dropped_all_correct': 0,
548
+ 'prompts_dropped_all_wrong': 0,
549
+ 'prompts_dropped_single': 0,
550
+ 'prompts_kept_single': 0,
551
+ 'episodes_kept': 0,
552
+ 'episodes_dropped': 0,
553
+ 'filter_rate': 0.0,
554
+ }
555
+
556
+ # Pre-compute all rewards once (avoids repeated _get_episode_reward calls)
557
+ all_rewards = _precompute_episode_rewards(episodes)
558
+
559
+ # Group episodes by prompt, storing indices instead of episodes
560
+ prompt_groups: Dict[tuple, List[int]] = defaultdict(list)
561
+ for idx, ep in enumerate(episodes):
562
+ prompt_key = _get_prompt_key(ep)
563
+ prompt_groups[prompt_key].append(idx)
564
+
565
+ filtered = []
566
+ stats = {
567
+ 'prompts_kept': 0,
568
+ 'prompts_dropped_all_correct': 0,
569
+ 'prompts_dropped_all_wrong': 0,
570
+ 'prompts_dropped_single': 0,
571
+ 'prompts_kept_single': 0,
572
+ 'episodes_kept': 0,
573
+ 'episodes_dropped': 0,
574
+ }
575
+
576
+ for prompt_key, group_indices in prompt_groups.items():
577
+ group_size = len(group_indices)
578
+
579
+ # Handle single-completion prompts separately
580
+ if group_size == 1:
581
+ if keep_single_completion:
582
+ # Keep single-completion prompts (variance undefined, not "zero")
583
+ filtered.append(episodes[group_indices[0]])
584
+ stats['prompts_kept'] += 1
585
+ stats['prompts_kept_single'] += 1
586
+ stats['episodes_kept'] += 1
587
+ else:
588
+ # Filter out single-completion prompts (strict DAPO interpretation)
589
+ stats['prompts_dropped_single'] += 1
590
+ stats['episodes_dropped'] += 1
591
+ continue
592
+
593
+ # For groups with 2+ completions, use variance criterion
594
+ variance, mean_reward = _compute_group_variance_and_mean(group_indices, all_rewards)
595
+
596
+ if variance > min_variance:
597
+ # Informative: mixed outcomes, keep all episodes from this prompt
598
+ for idx in group_indices:
599
+ filtered.append(episodes[idx])
600
+ stats['prompts_kept'] += 1
601
+ stats['episodes_kept'] += group_size
602
+ else:
603
+ # Uninformative: all completions have same outcome
604
+ stats['episodes_dropped'] += group_size
605
+ if mean_reward > 0.5:
606
+ stats['prompts_dropped_all_correct'] += 1
607
+ else:
608
+ stats['prompts_dropped_all_wrong'] += 1
609
+
610
+ # Compute filter rate
611
+ total_prompts = len(prompt_groups)
612
+ stats['filter_rate'] = 1.0 - (stats['prompts_kept'] / total_prompts) if total_prompts > 0 else 0.0
613
+
614
+ return filtered, stats
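# Editor's example (illustrative sketch, not part of the released package):
# the two completions of prompt [1, 2] disagree (reward variance 0.25), so the
# group is kept; both completions of prompt [3, 4] fail, so it is dropped.
_prompt_eps = [
    {'obs': [1, 2], 'act': [10], 'rew': [1.0]},
    {'obs': [1, 2], 'act': [11], 'rew': [0.0]},
    {'obs': [3, 4], 'act': [12], 'rew': [0.0]},
    {'obs': [3, 4], 'act': [13], 'rew': [0.0]},
]
_kept, _filter_stats = filter_informative_prompts(_prompt_eps, min_variance=0.01)
# len(_kept) == 2, _filter_stats['prompts_dropped_all_wrong'] == 1, filter_rate == 0.5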
615
+
616
+
617
+ def compute_prompt_group_stats(episodes: List[Any]) -> Dict[str, Any]:
618
+ """
619
+ Compute statistics about prompt groups for monitoring.
620
+
621
+ Useful for understanding the distribution of prompts and completions
622
+ before and after filtering.
623
+
624
+ Args:
625
+ episodes: List of episodes
626
+
627
+ Returns:
628
+ Dictionary with:
629
+ - 'num_prompts': Total unique prompts
630
+ - 'num_episodes': Total episodes
631
+ - 'completions_per_prompt': Average completions per prompt
632
+ - 'reward_variance_mean': Mean variance across prompt groups
633
+ - 'reward_variance_std': Std of variance across prompt groups
634
+ """
635
+ if not episodes:
636
+ return {
637
+ 'num_prompts': 0,
638
+ 'num_episodes': 0,
639
+ 'completions_per_prompt': 0.0,
640
+ 'reward_variance_mean': 0.0,
641
+ 'reward_variance_std': 0.0,
642
+ }
643
+
644
+ # Pre-compute all rewards once
645
+ all_rewards = _precompute_episode_rewards(episodes)
646
+
647
+ # Group by prompt, storing indices
648
+ prompt_groups: Dict[tuple, List[int]] = defaultdict(list)
649
+ for idx, ep in enumerate(episodes):
650
+ prompt_key = _get_prompt_key(ep)
651
+ prompt_groups[prompt_key].append(idx)
652
+
653
+ # Compute variance for each group using pre-computed rewards
654
+ variances = []
655
+ for group_indices in prompt_groups.values():
656
+ group_rewards = mx.array([all_rewards[i] for i in group_indices])
657
+ variances.append(mx.var(group_rewards).item())
658
+
659
+ variances_arr = mx.array(variances) if variances else mx.array([0.0])
660
+
661
+ return {
662
+ 'num_prompts': len(prompt_groups),
663
+ 'num_episodes': len(episodes),
664
+ 'completions_per_prompt': len(episodes) / len(prompt_groups) if prompt_groups else 0.0,
665
+ 'reward_variance_mean': float(mx.mean(variances_arr).item()),
666
+ 'reward_variance_std': float(mx.std(variances_arr).item()),
667
+ }
668
+
669
+
204
670
  # Convenience function for complete GRPO computation
205
671
  def grpo_loss(
206
672
  old_logprobs: mx.array,
207
673
  new_logprobs: mx.array,
208
674
  rewards: Union[List[float], mx.array],
209
- clip_ratio: float = 0.2,
210
- entropy_coeff: float = 0.0
675
+ clip_ratio: Optional[float] = None,
676
+ clip_ratio_low: float = 0.2,
677
+ clip_ratio_high: float = 0.28,
678
+ entropy_coeff: float = 0.0,
679
+ normalize_constant: Optional[int] = None
211
680
  ) -> mx.array:
212
681
  """
213
682
  Complete GRPO loss computation in one function.
214
-
683
+
215
684
  Combines advantage calculation and policy loss for convenience.
216
685
  Can be compiled as a single unit for maximum efficiency.
217
-
686
+ Supports DAPO-style asymmetric clipping bounds and Dr. GRPO length-bias fix.
687
+
218
688
  Args:
219
689
  old_logprobs: Log probabilities from rollout
220
690
  new_logprobs: Log probabilities from current policy
221
691
  rewards: Episode rewards for group-relative advantages
222
- clip_ratio: PPO clipping ratio
692
+ clip_ratio: Symmetric clipping ratio (for backward compatibility).
693
+ If provided, overrides clip_ratio_low and clip_ratio_high.
694
+ clip_ratio_low: Lower bound offset (default 0.2)
695
+ clip_ratio_high: Upper bound offset (default 0.28)
223
696
  entropy_coeff: Entropy bonus coefficient (0 disables)
224
-
697
+ normalize_constant: Fixed constant divisor for loss normalization.
698
+ If None (default), uses mean. If provided, uses
699
+ sum/constant to eliminate length bias.
700
+
225
701
  Returns:
226
702
  Total GRPO loss (policy + optional entropy)
703
+
704
+ References:
705
+ DAPO: An Open-Source LLM Reinforcement Learning System at Scale
706
+ https://arxiv.org/abs/2503.14476
707
+
708
+ Dr. GRPO: Understanding R1-Zero-Like Training
709
+ https://arxiv.org/abs/2503.20783
227
710
  """
228
711
  # Compute group-relative advantages
229
712
  advantages = compute_advantages(rewards)
230
-
713
+
231
714
  # Expand advantages to match logprob sequence length if needed
232
715
  if advantages.ndim == 1 and old_logprobs.ndim > 1:
233
716
  # Each episode contributes its advantage to all tokens in that episode
234
717
  # This requires knowing episode boundaries - simplified version assumes
235
718
  # advantages and logprobs are already aligned
236
719
  pass
237
-
238
- # Compute policy loss
239
- policy_loss_val = policy_loss(old_logprobs, new_logprobs, advantages, clip_ratio)
240
-
720
+
721
+ # Compute policy loss with asymmetric clipping and optional length-bias fix
722
+ policy_loss_val = policy_loss(
723
+ old_logprobs, new_logprobs, advantages,
724
+ clip_ratio=clip_ratio,
725
+ clip_ratio_low=clip_ratio_low,
726
+ clip_ratio_high=clip_ratio_high,
727
+ normalize_constant=normalize_constant
728
+ )
729
+
241
730
  # Add entropy bonus if specified
242
731
  if entropy_coeff > 0:
243
732
  entropy_bonus_val = entropy_bonus(new_logprobs, entropy_coeff)
244
733
  return policy_loss_val + entropy_bonus_val
245
-
734
+
246
735
  return policy_loss_val
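# Editor's example (illustrative sketch, not part of the released package):
# one-call GRPO loss with the DAPO clip defaults and the Dr. GRPO constant
# normalization. As noted above, this assumes advantages and logprobs are
# already aligned one-to-one.
_g_old = mx.array([-1.0, -1.1, -0.9, -1.3])
_g_new = mx.array([-0.9, -1.0, -1.2, -1.3])
_g_rewards = [1.0, 1.0, 0.0, 0.0]
_g_loss = grpo_loss(_g_old, _g_new, _g_rewards, entropy_coeff=0.0, normalize_constant=4)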
247
736
 
248
737
 
@@ -251,392 +740,313 @@ def compute_metrics(
251
740
  old_logprobs: mx.array,
252
741
  new_logprobs: mx.array,
253
742
  advantages: mx.array,
254
- clip_ratio: float = 0.2
743
+ clip_ratio: Optional[float] = None,
744
+ clip_ratio_low: float = 0.2,
745
+ clip_ratio_high: float = 0.28
255
746
  ) -> dict:
256
747
  """
257
748
  Compute GRPO training metrics for monitoring.
258
-
749
+
750
+ Supports DAPO-style asymmetric clipping bounds and tracks clip fractions
751
+ for upper vs lower bounds separately.
752
+
259
753
  Args:
260
754
  old_logprobs: Log probabilities from rollout
261
- new_logprobs: Log probabilities from current policy
755
+ new_logprobs: Log probabilities from current policy
262
756
  advantages: Group-relative advantages
263
- clip_ratio: Clipping ratio used in loss
264
-
757
+ clip_ratio: Symmetric clipping ratio (for backward compatibility).
758
+ If provided, overrides clip_ratio_low and clip_ratio_high.
759
+ clip_ratio_low: Lower bound offset (default 0.2)
760
+ clip_ratio_high: Upper bound offset (default 0.28)
761
+
265
762
  Returns:
266
- Dictionary of metrics for logging/monitoring
763
+ Dictionary of metrics for logging/monitoring, including:
764
+ - clip_fraction_lower: Fraction of ratios clipped at lower bound
765
+ - clip_fraction_upper: Fraction of ratios clipped at upper bound
766
+ - clip_fraction: Total fraction of ratios clipped (either bound)
267
767
  """
768
+ # Resolve clipping configuration (handles backward compatibility)
769
+ clip_cfg = resolve_clip_config(clip_ratio, clip_ratio_low, clip_ratio_high)
770
+
268
771
  # Importance ratio statistics
269
772
  ratio = mx.exp(new_logprobs - old_logprobs)
270
-
271
- # Clipping statistics
272
- clip_lower = 1 - clip_ratio
273
- clip_upper = 1 + clip_ratio
274
- clipped = (ratio < clip_lower) | (ratio > clip_upper)
773
+
774
+ # Asymmetric clipping bounds
775
+ clip_lower = 1 - clip_cfg.low
776
+ clip_upper = 1 + clip_cfg.high
777
+
778
+ # Track clip fractions separately for upper and lower bounds
779
+ clipped_lower = ratio < clip_lower
780
+ clipped_upper = ratio > clip_upper
781
+ clipped = clipped_lower | clipped_upper
782
+
783
+ clip_fraction_lower = mx.mean(clipped_lower.astype(mx.float32))
784
+ clip_fraction_upper = mx.mean(clipped_upper.astype(mx.float32))
275
785
  clip_fraction = mx.mean(clipped.astype(mx.float32))
276
-
786
+
277
787
  # KL divergence approximation
278
788
  kl_div = mx.mean(old_logprobs - new_logprobs)
279
-
789
+
280
790
  return {
281
791
  'mean_advantage': mx.mean(advantages).item(),
282
792
  'std_advantage': mx.std(advantages).item(),
283
793
  'mean_ratio': mx.mean(ratio).item(),
284
794
  'clip_fraction': clip_fraction.item(),
795
+ 'clip_fraction_lower': clip_fraction_lower.item(),
796
+ 'clip_fraction_upper': clip_fraction_upper.item(),
285
797
  'kl_divergence': kl_div.item(),
286
798
  'min_advantage': mx.min(advantages).item(),
287
- 'max_advantage': mx.max(advantages).item()
799
+ 'max_advantage': mx.max(advantages).item(),
800
+ 'clip_ratio_low': clip_cfg.low,
801
+ 'clip_ratio_high': clip_cfg.high
288
802
  }
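# Editor's example (illustrative sketch, not part of the released package):
# with the asymmetric defaults, a ratio of 0.7 trips the lower bound (0.8)
# while a ratio of 1.2 stays inside the upper bound (1.28).
_m_old = mx.array([0.0, 0.0, 0.0])
_m_new = mx.log(mx.array([0.7, 1.2, 1.0]))   # ratios become exactly [0.7, 1.2, 1.0]
_m = compute_metrics(_m_old, _m_new, mx.array([1.0, -1.0, 0.5]))
# _m['clip_fraction_lower'] ≈ 0.33, _m['clip_fraction_upper'] == 0.0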
289
803
 
290
804
 
291
- # Algorithm-specific data selection strategies
292
- def select_all_data(buffer):
805
+ # --- Episode Packing Helper ---
806
+ def _flatten_tokens(items: List[Any]) -> List:
807
+ """Flatten nested token sequences into a flat list."""
808
+ flattened = []
809
+ for item in items:
810
+ if hasattr(item, 'tolist'): # MLX array
811
+ flattened.extend(item.tolist())
812
+ elif isinstance(item, list): # Python list
813
+ flattened.extend(item)
814
+ else: # Single token
815
+ flattened.append(item)
816
+ return flattened
817
+
818
+
819
+ def _pack_episodes(episodes: List[Any]) -> Dict[str, Any]:
293
820
  """
294
- GRPO data selector: Use all available data.
295
-
296
- GRPO is on-policy but can benefit from using all collected episodes
297
- since group-relative advantages normalize across the entire group.
298
-
821
+ Pack episodes into batch data for GRPO training.
822
+
823
+ This is the shared helper for episode-to-batch conversion, used by all
824
+ data selectors (select_all_data, select_informative_data, select_recent_data).
825
+
299
826
  Args:
300
- buffer: Buffer containing episodes
301
-
827
+ episodes: List of episodes (Episode objects or serialized dicts)
828
+
302
829
  Returns:
303
- All episode data prepared for training
830
+ Dictionary with:
831
+ - 'obs': Flat concatenated full sequences (prompt + response)
832
+ - 'act': Flat concatenated response tokens
833
+ - 'logprob': Flat concatenated log probabilities
834
+ - 'rewards': Episode rewards as MLX array
835
+ - 'episode_lengths': List of episode lengths
304
836
  """
305
- from textpolicy.buffer import Buffer
306
- if not isinstance(buffer, Buffer):
307
- raise TypeError(f"Expected Buffer, got {type(buffer)}")
308
-
309
- # Use all available data - GRPO benefits from larger groups
310
- episodes_data = buffer.sample() # This returns concatenated transitions
311
-
312
- # We need to convert this back to episode structure for reward extraction
313
- episodes = buffer.episodes # Access episodes directly from storage
314
-
315
837
  if not episodes:
316
- raise ValueError("Buffer is empty - no episodes to train on")
317
-
318
- # Extract episode rewards for advantage computation
319
- episode_rewards = []
838
+ return {
839
+ 'obs': mx.array([], dtype=mx.int64),
840
+ 'act': mx.array([], dtype=mx.int64),
841
+ 'logprob': mx.array([], dtype=mx.float32),
842
+ 'rewards': mx.array([]),
843
+ 'episode_lengths': [],
844
+ }
845
+
320
846
  episode_lengths = []
321
-
322
- # Collect all transitions
323
847
  all_obs = []
324
848
  all_acts = []
325
849
  all_logprobs = []
326
-
327
- for episode in episodes:
328
- # Episode reward (sum of all rewards in episode)
329
- # Handle both Episode objects and serialized dictionaries
850
+ pending_reward_sums: List[Tuple[int, mx.array]] = []
851
+ scalar_rewards: Dict[int, float] = {}
852
+
853
+ for i, episode in enumerate(episodes):
330
854
  if hasattr(episode, 'rew'):
331
855
  # Episode object with attributes
332
- episode_reward = mx.sum(mx.array(episode.rew)).item()
333
- episode_rewards.append(episode_reward)
334
- episode_lengths.append(len(episode.obs))
335
-
336
- # Collect transitions
337
- # For proper logprob extraction during training, we need the full context (prompt + response)
338
- # This matches how the model was called during rollout generation
339
- # Flatten nested token sequences to create uniform token arrays
340
-
341
- # Extract and flatten observation tokens (prompt)
342
- flattened_obs = []
343
- for obs in episode.obs:
344
- if hasattr(obs, 'tolist'): # MLX array
345
- flattened_obs.extend(obs.tolist())
346
- elif isinstance(obs, list): # Python list
347
- flattened_obs.extend(obs)
348
- else: # Single token
349
- flattened_obs.append(obs)
350
-
351
- # Extract and flatten action tokens (response)
352
- flattened_acts = []
353
- for act in episode.act:
354
- if hasattr(act, 'tolist'): # MLX array
355
- flattened_acts.extend(act.tolist())
356
- elif isinstance(act, list): # Python list
357
- flattened_acts.extend(act)
358
- else: # Single token
359
- flattened_acts.append(act)
360
-
856
+ pending_reward_sums.append((i, mx.sum(mx.array(episode.rew))))
857
+
858
+ # Flatten observation and action tokens
859
+ flattened_obs = _flatten_tokens(episode.obs)
860
+ flattened_acts = _flatten_tokens(episode.act)
861
+
862
+ # Use flattened token count for episode_lengths (used by _expand_advantages)
863
+ # This ensures alignment between expanded advantages and actual token sequences
864
+ episode_lengths.append(len(flattened_acts))
865
+
361
866
  # Create full sequence: [prompt_tokens..., response_tokens...]
362
867
  full_sequence = flattened_obs + flattened_acts
363
868
  all_obs.append(full_sequence)
364
869
  all_acts.append(flattened_acts)
365
- all_logprobs.append(episode.logprob if episode.logprob else [])
870
+ all_logprobs.append(episode.logprob if episode.logprob is not None else [])
366
871
  else:
367
872
  # Serialized dictionary from multiprocessing
368
- episode_reward = mx.sum(episode['rew']).item()
369
- episode_rewards.append(episode_reward)
370
- episode_lengths.append(len(episode['obs']))
371
-
372
- # Collect transitions
373
- # For proper logprob extraction during training, we need the full context (prompt + response)
374
- # This matches how the model was called during rollout generation
375
- full_sequence = episode['obs'] + episode['act'] # Concatenate prompt + response
873
+ rew = episode['rew']
874
+ if isinstance(rew, (int, float)):
875
+ scalar_rewards[i] = float(rew)
876
+ else:
877
+ pending_reward_sums.append((i, mx.sum(mx.array(rew))))
878
+
879
+ # Flatten observation and action tokens
880
+ flattened_obs = _flatten_tokens(episode['obs'])
881
+ flattened_acts = _flatten_tokens(episode['act'])
882
+
883
+ # Use flattened token count for episode_lengths
884
+ episode_lengths.append(len(flattened_acts))
885
+
886
+ full_sequence = flattened_obs + flattened_acts
376
887
  all_obs.append(full_sequence)
377
- all_acts.append(episode['act'])
888
+ all_acts.append(flattened_acts)
378
889
  all_logprobs.append(episode.get('logprob', []))
379
-
380
- # Convert Python lists to MLX arrays before concatenation
381
- # This is required because Episode objects store data as Python lists for memory efficiency
382
- # For proper logprob extraction, we need uniform-length sequences, so we pad to the maximum length
383
-
384
- # Find maximum sequence length for padding
890
+
891
+ # Batch evaluate all pending reward sums (single sync barrier instead of N)
892
+ episode_rewards = [0.0] * len(episodes)
893
+ for idx, val in scalar_rewards.items():
894
+ episode_rewards[idx] = val
895
+ if pending_reward_sums:
896
+ indices, lazy_sums = zip(*pending_reward_sums)
897
+ stacked = mx.stack(list(lazy_sums))
898
+ mx.eval(stacked)
899
+ values = stacked.tolist()
900
+ for idx, val in zip(indices, values):
901
+ episode_rewards[idx] = float(val)
902
+
903
+ # Find maximum sequence lengths for padding
385
904
  max_obs_len = max(len(obs) for obs in all_obs) if all_obs else 0
386
905
  max_act_len = max(len(act) for act in all_acts) if all_acts else 0
387
906
  max_logprob_len = max(len(logprob) for logprob in all_logprobs) if all_logprobs else 0
388
-
389
- # MLX-native padding and array operations for optimal Apple Silicon performance
390
- # Convert all sequences to MLX arrays and pad directly in MLX space
391
- try:
392
- # Convert all sequences to MLX arrays first (staying in unified memory)
393
- all_obs_mx = [mx.array(obs, dtype=mx.int64) for obs in all_obs if obs]
394
- all_acts_mx = [mx.array(act, dtype=mx.int64) for act in all_acts if act]
395
- all_logprobs_mx = [mx.array(logprob, dtype=mx.float32) for logprob in all_logprobs if logprob]
396
-
397
- # Pad using native MLX operations (more efficient for Apple Silicon)
398
- if all_obs_mx:
399
- padded_obs_mx = [mx.pad(obs, (0, max_obs_len - obs.shape[0]), constant_values=0)
400
- if obs.shape[0] < max_obs_len else obs[:max_obs_len]
401
- for obs in all_obs_mx]
402
- else:
403
- padded_obs_mx = []
404
-
405
- if all_acts_mx:
406
- padded_acts_mx = [mx.pad(act, (0, max_act_len - act.shape[0]), constant_values=0)
407
- if act.shape[0] < max_act_len else act[:max_act_len]
408
- for act in all_acts_mx]
409
- else:
410
- padded_acts_mx = []
411
-
412
- if all_logprobs_mx:
413
- padded_logprobs_mx = [mx.pad(logprob, (0, max_logprob_len - logprob.shape[0]), constant_values=0.0)
414
- if logprob.shape[0] < max_logprob_len else logprob[:max_logprob_len]
415
- for logprob in all_logprobs_mx]
416
- else:
417
- padded_logprobs_mx = []
418
-
419
- # Use padded MLX arrays directly (no intermediate conversion needed)
420
- all_obs_mx = padded_obs_mx
421
- all_acts_mx = padded_acts_mx
422
- all_logprobs_mx = padded_logprobs_mx
423
-
424
- except Exception as e:
425
- print(f"ERROR in MLX array conversion: {e}")
426
- print(f"DEBUG: all_obs types: {[type(obs) for obs in all_obs[:3]]}") # Show first 3 for brevity
427
- print(f"DEBUG: all_logprobs types: {[type(logprob) for logprob in all_logprobs[:3]]}")
428
- raise
429
-
430
- # GRPO data structure: both observations and actions as flat concatenated sequences
431
- # This matches the expected format for GRPO logprob extraction function
432
- batch_data = {
433
- 'obs': mx.concatenate(all_obs_mx) if all_obs_mx else mx.array([]), # Flat concatenated full sequences
434
- 'act': mx.concatenate(all_acts_mx) if all_acts_mx else mx.array([]), # Flat concatenated response tokens
435
- 'logprob': mx.concatenate([logprob.flatten() for logprob in all_logprobs_mx]) if all_logprobs_mx else mx.array([]), # Flat sequence for training
907
+
908
+ # Convert to MLX arrays with padding
909
+ # Always create an array for each episode to maintain alignment
910
+ all_obs_mx = [mx.array(obs, dtype=mx.int64) if obs else mx.array([], dtype=mx.int64) for obs in all_obs]
911
+ all_acts_mx = [mx.array(act, dtype=mx.int64) if act else mx.array([], dtype=mx.int64) for act in all_acts]
912
+ all_logprobs_mx = [mx.array(logprob, dtype=mx.float32) if logprob else mx.array([], dtype=mx.float32) for logprob in all_logprobs]
913
+
914
+ # Filter out empty arrays for padding/concatenation
915
+ non_empty_obs = [obs for obs in all_obs_mx if obs.size > 0]
916
+ non_empty_acts = [act for act in all_acts_mx if act.size > 0]
917
+ non_empty_logprobs = [logprob for logprob in all_logprobs_mx if logprob.size > 0]
918
+
919
+ # Pad using native MLX operations
920
+ if non_empty_obs:
921
+ padded_obs = [mx.pad(obs, (0, max_obs_len - obs.shape[0]), constant_values=0)
922
+ if obs.shape[0] < max_obs_len else obs[:max_obs_len]
923
+ for obs in non_empty_obs]
924
+ else:
925
+ padded_obs = []
926
+
927
+ if non_empty_acts:
928
+ padded_acts = [mx.pad(act, (0, max_act_len - act.shape[0]), constant_values=0)
929
+ if act.shape[0] < max_act_len else act[:max_act_len]
930
+ for act in non_empty_acts]
931
+ else:
932
+ padded_acts = []
933
+
934
+ if non_empty_logprobs:
935
+ padded_logprobs = [mx.pad(logprob, (0, max_logprob_len - logprob.shape[0]), constant_values=0.0)
936
+ if logprob.shape[0] < max_logprob_len else logprob[:max_logprob_len]
937
+ for logprob in non_empty_logprobs]
938
+ else:
939
+ padded_logprobs = []
940
+
941
+ return {
942
+ 'obs': mx.concatenate(padded_obs) if padded_obs else mx.array([], dtype=mx.int64),
943
+ 'act': mx.concatenate(padded_acts) if padded_acts else mx.array([], dtype=mx.int64),
944
+ 'logprob': mx.concatenate([lp.flatten() for lp in padded_logprobs]) if padded_logprobs else mx.array([], dtype=mx.float32),
436
945
  'rewards': mx.array(episode_rewards),
437
- 'episode_lengths': episode_lengths
946
+ 'episode_lengths': episode_lengths,
438
947
  }
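# Editor's example (illustrative sketch, not part of the released package):
# packing two dict-style episodes; 'obs' holds the padded prompt+response
# sequences concatenated flat, and 'episode_lengths' counts response tokens.
_pack_eps = [
    {'obs': [1, 2, 3], 'act': [4, 5], 'rew': [1.0], 'logprob': [-0.5, -0.7]},
    {'obs': [1, 2, 3], 'act': [6], 'rew': [0.0], 'logprob': [-1.2]},
]
_packed = _pack_episodes(_pack_eps)
# _packed['rewards'] -> array([1.0, 0.0]); _packed['episode_lengths'] -> [2, 1]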
439
-
948
+
949
+
950
+ # Algorithm-specific data selection strategies
951
+ def select_all_data(buffer):
952
+ """
953
+ GRPO data selector: Use all available data.
954
+
955
+ GRPO is on-policy but can benefit from using all collected episodes
956
+ since group-relative advantages normalize across the entire group.
957
+
958
+ Args:
959
+ buffer: Buffer containing episodes
960
+
961
+ Returns:
962
+ All episode data prepared for training
963
+ """
964
+ from textpolicy.buffer import Buffer
965
+ if not isinstance(buffer, Buffer):
966
+ raise TypeError(f"Expected Buffer, got {type(buffer)}")
967
+
968
+ episodes = buffer.episodes
969
+ if not episodes:
970
+ raise ValueError("Buffer is empty - no episodes to train on")
971
+
972
+ return _pack_episodes(episodes)
973
+
974
+
975
+ def select_informative_data(buffer, min_variance: float = 0.01):
976
+ """
977
+ GRPO data selector with dynamic batch filtering (DAPO-style).
978
+
979
+ Filters out uninformative prompts where all completions have the same
980
+ outcome (all correct or all wrong), improving sample efficiency by
981
+ maintaining meaningful gradient signals.
982
+
983
+ This is the recommended selector for GRPO training when using multiple
984
+ completions per prompt, as it eliminates wasted compute on prompts
985
+ that provide no learning signal.
986
+
987
+ Args:
988
+ buffer: Buffer containing episodes (Episode objects or serialized dictionaries)
989
+ min_variance: Minimum reward variance to keep a prompt group.
990
+ Prompts with variance below this are filtered out.
991
+
992
+ Returns:
993
+ Filtered episode data prepared for training, plus filtering stats.
994
+
995
+ Example:
996
+ >>> batch_data = select_informative_data(buffer, min_variance=0.01)
997
+ >>> # batch_data includes 'filter_stats' with filtering information
998
+
999
+ References:
1000
+ DAPO: An Open-Source LLM Reinforcement Learning System at Scale
1001
+ https://arxiv.org/abs/2503.14476
1002
+ """
1003
+ from textpolicy.buffer import Buffer
1004
+ if not isinstance(buffer, Buffer):
1005
+ raise TypeError(f"Expected Buffer, got {type(buffer)}")
1006
+
1007
+ episodes = buffer.episodes
1008
+ if not episodes:
1009
+ raise ValueError("Buffer is empty - no episodes to train on")
1010
+
1011
+ # Filter to keep only informative prompts
1012
+ filtered_episodes, filter_stats = filter_informative_prompts(episodes, min_variance)
1013
+
1014
+ if not filtered_episodes:
1015
+ raise ValueError(
1016
+ f"All prompts filtered out (min_variance={min_variance}). "
1017
+ f"Stats: {filter_stats}. Consider lowering min_variance or "
1018
+ "ensuring diversity in completions."
1019
+ )
1020
+
1021
+ # Pack filtered episodes using shared helper
1022
+ batch_data = _pack_episodes(filtered_episodes)
1023
+ batch_data['filter_stats'] = filter_stats
440
1024
  return batch_data
441
1025
 
442
1026
 
443
1027
  def select_recent_data(buffer, max_episodes: int = 100):
444
1028
  """
445
1029
  GRPO data selector: Use only recent episodes.
446
-
1030
+
447
1031
  Alternative selector for GRPO that limits to recent episodes
448
1032
  for faster training on large buffers.
449
-
1033
+
450
1034
  Args:
451
1035
  buffer: Buffer containing episodes (Episode objects or serialized dictionaries)
452
1036
  max_episodes: Maximum number of recent episodes to use
453
-
1037
+
454
1038
  Returns:
455
1039
  Recent episode data prepared for training
456
1040
  """
457
1041
  from textpolicy.buffer import Buffer
458
1042
  if not isinstance(buffer, Buffer):
459
1043
  raise TypeError(f"Expected Buffer, got {type(buffer)}")
460
-
1044
+
461
1045
  episodes = buffer.episodes
462
1046
  if not episodes:
463
1047
  raise ValueError("Buffer is empty - no episodes to train on")
464
-
1048
+
465
1049
  # Select recent episodes
466
1050
  recent_episodes = episodes[-max_episodes:] if len(episodes) > max_episodes else episodes
467
-
468
- # Process recent episodes
469
- episode_rewards = []
470
- episode_lengths = []
471
- all_obs = []
472
- all_acts = []
473
- all_logprobs = []
474
-
475
- for episode in recent_episodes:
476
- # Handle both Episode objects and serialized dictionaries
477
- if hasattr(episode, 'rew'):
478
- # Episode object with attributes
479
- episode_reward = mx.sum(mx.array(episode.rew)).item()
480
- episode_rewards.append(episode_reward)
481
- episode_lengths.append(len(episode.obs))
482
-
483
- # For proper logprob extraction during training, we need the full context (prompt + response)
484
- # This matches how the model was called during rollout generation
485
- # Convert both obs and act to consistent Python list format before concatenation
486
- obs_as_lists = []
487
- for obs_item in episode.obs:
488
- if hasattr(obs_item, 'tolist'): # MLX array
489
- obs_as_lists.extend(obs_item.tolist())
490
- elif isinstance(obs_item, list): # Already Python list
491
- obs_as_lists.extend(obs_item)
492
- else: # Single item
493
- obs_as_lists.append(obs_item)
494
-
495
- act_as_lists = []
496
- for act_item in episode.act:
497
- if hasattr(act_item, 'tolist'): # MLX array
498
- act_as_lists.extend(act_item.tolist())
499
- elif isinstance(act_item, list): # Already Python list
500
- act_as_lists.extend(act_item)
501
- else: # Single item
502
- act_as_lists.append(act_item)
503
-
504
- # Now concatenate the normalized lists
505
- full_sequence = obs_as_lists + act_as_lists
506
- all_obs.append(full_sequence)
507
-
508
- # Extract actions as consistent Python lists
509
- episode_actions = []
510
- for act_item in episode.act:
511
- if hasattr(act_item, 'tolist'): # MLX array
512
- episode_actions.extend(act_item.tolist())
513
- elif isinstance(act_item, list): # Already Python list
514
- episode_actions.extend(act_item)
515
- else: # Single item
516
- episode_actions.append(act_item)
517
- all_acts.append(episode_actions)
518
-
519
- # Extract logprobs as consistent Python lists
520
- episode_logprobs = []
521
- if episode.logprob:
522
- for logprob_item in episode.logprob:
523
- if hasattr(logprob_item, 'tolist'): # MLX array
524
- episode_logprobs.extend(logprob_item.tolist())
525
- elif isinstance(logprob_item, list): # Already Python list
526
- episode_logprobs.extend(logprob_item)
527
- else: # Single item
528
- episode_logprobs.append(logprob_item)
529
- all_logprobs.append(episode_logprobs)
530
- else:
531
- # Serialized dictionary from multiprocessing
532
- episode_reward = mx.sum(episode['rew']).item()
533
- episode_rewards.append(episode_reward)
534
- episode_lengths.append(len(episode['obs']))
535
-
536
- # For proper logprob extraction during training, we need the full context (prompt + response)
537
- # This matches how the model was called during rollout generation
538
- # Convert both obs and act to consistent Python list format before concatenation
539
- obs_as_lists = []
540
- for obs_item in episode['obs']:
541
- if hasattr(obs_item, 'tolist'): # MLX array
542
- obs_as_lists.extend(obs_item.tolist())
543
- elif isinstance(obs_item, list): # Already Python list
544
- obs_as_lists.extend(obs_item)
545
- else: # Single item
546
- obs_as_lists.append(obs_item)
547
-
548
- act_as_lists = []
549
- for act_item in episode['act']:
550
- if hasattr(act_item, 'tolist'): # MLX array
551
- act_as_lists.extend(act_item.tolist())
552
- elif isinstance(act_item, list): # Already Python list
553
- act_as_lists.extend(act_item)
554
- else: # Single item
555
- act_as_lists.append(act_item)
556
-
557
- # Now concatenate the normalized lists
558
- full_sequence = obs_as_lists + act_as_lists
559
- all_obs.append(full_sequence)
560
-
561
- # Extract actions as consistent Python lists
562
- episode_actions = []
563
- for act_item in episode['act']:
564
- if hasattr(act_item, 'tolist'): # MLX array
565
- episode_actions.extend(act_item.tolist())
566
- elif isinstance(act_item, list): # Already Python list
567
- episode_actions.extend(act_item)
568
- else: # Single item
569
- episode_actions.append(act_item)
570
- all_acts.append(episode_actions)
571
-
572
- # Extract logprobs as consistent Python lists
573
- episode_logprobs = []
574
- if episode.get('logprob'):
575
- for logprob_item in episode['logprob']:
576
- if hasattr(logprob_item, 'tolist'): # MLX array
577
- episode_logprobs.extend(logprob_item.tolist())
578
- elif isinstance(logprob_item, list): # Already Python list
579
- episode_logprobs.extend(logprob_item)
580
- else: # Single item
581
- episode_logprobs.append(logprob_item)
582
- all_logprobs.append(episode_logprobs)
583
-
584
- # Convert Python lists to MLX arrays before concatenation
585
- # This is required because Episode objects store data as Python lists for memory efficiency
586
- # For proper logprob extraction, we need uniform-length sequences, so we pad to the maximum length
587
-
588
- # Find maximum sequence length for padding
589
- max_obs_len = max(len(obs) for obs in all_obs) if all_obs else 0
590
- max_act_len = max(len(act) for act in all_acts) if all_acts else 0
591
- max_logprob_len = max(len(logprob) for logprob in all_logprobs) if all_logprobs else 0
592
-
593
- # MLX-native padding and array operations for optimal Apple Silicon performance
594
- # Convert all sequences to MLX arrays and pad directly in MLX space
595
- try:
596
- # Convert all sequences to MLX arrays first (staying in unified memory)
597
- all_obs_mx = [mx.array(obs, dtype=mx.int64) for obs in all_obs if obs]
598
- all_acts_mx = [mx.array(act, dtype=mx.int64) for act in all_acts if act]
599
- all_logprobs_mx = [mx.array(logprob, dtype=mx.float32) for logprob in all_logprobs if logprob]
600
-
601
- # Pad using native MLX operations (more efficient for Apple Silicon)
602
- if all_obs_mx:
603
- padded_obs_mx = [mx.pad(obs, (0, max_obs_len - obs.shape[0]), constant_values=0)
604
- if obs.shape[0] < max_obs_len else obs[:max_obs_len]
605
- for obs in all_obs_mx]
606
- else:
607
- padded_obs_mx = []
608
-
609
- if all_acts_mx:
610
- padded_acts_mx = [mx.pad(act, (0, max_act_len - act.shape[0]), constant_values=0)
611
- if act.shape[0] < max_act_len else act[:max_act_len]
612
- for act in all_acts_mx]
613
- else:
614
- padded_acts_mx = []
615
-
616
- if all_logprobs_mx:
617
- padded_logprobs_mx = [mx.pad(logprob, (0, max_logprob_len - logprob.shape[0]), constant_values=0.0)
618
- if logprob.shape[0] < max_logprob_len else logprob[:max_logprob_len]
619
- for logprob in all_logprobs_mx]
620
- else:
621
- padded_logprobs_mx = []
622
-
623
- # Use padded MLX arrays directly (no intermediate conversion needed)
624
- all_obs_mx = padded_obs_mx
625
- all_acts_mx = padded_acts_mx
626
- all_logprobs_mx = padded_logprobs_mx
627
-
628
- except Exception as e:
629
- print(f"ERROR in MLX array conversion: {e}")
630
- print(f"DEBUG: all_obs types: {[type(obs) for obs in all_obs[:3]]}") # Show first 3 for brevity
631
- print(f"DEBUG: all_logprobs types: {[type(logprob) for logprob in all_logprobs[:3]]}")
632
- raise
633
-
634
- batch_data = {
635
- 'obs': mx.concatenate(all_obs_mx) if all_obs_mx else mx.array([]), # Flat concatenated full sequences
636
- 'act': mx.concatenate(all_acts_mx) if all_acts_mx else mx.array([]), # Flat concatenated response tokens
637
- 'logprob': mx.concatenate([logprob.flatten() for logprob in all_logprobs_mx]) if all_logprobs_mx else mx.array([]), # Flat sequence for training
638
- 'rewards': mx.array(episode_rewards),
639
- 'episode_lengths': episode_lengths
640
- }
641
-
642
- return batch_data
1051
+
1052
+ return _pack_episodes(recent_episodes)
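# Editor's note (illustrative sketch, not part of the released package): all
# three selectors return the batch schema produced by _pack_episodes, so
# downstream training code can stay selector-agnostic. The Buffer API itself is
# not shown in this diff, so this hypothetical helper only inspects the batch.
def _describe_batch(batch: Dict[str, Any]) -> str:
    """Summarize a packed GRPO batch for logging (hypothetical helper)."""
    parts = [
        f"episodes={len(batch['episode_lengths'])}",
        f"tokens={int(batch['act'].size)}",
    ]
    if len(batch['episode_lengths']) > 0:
        parts.append(f"mean_reward={float(mx.mean(batch['rewards']).item()):.3f}")
    if 'filter_stats' in batch:
        parts.append(f"filter_rate={batch['filter_stats']['filter_rate']:.2f}")
    return ", ".join(parts)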