textpolicy 0.1.1__py3-none-any.whl → 0.1.3__py3-none-any.whl
- textpolicy/__init__.py +3 -0
- textpolicy/algorithms/__init__.py +29 -4
- textpolicy/algorithms/grpo.py +771 -361
- textpolicy/algorithms/length_shaping.py +151 -0
- textpolicy/analysis/__init__.py +23 -0
- textpolicy/analysis/emergence_logger.py +248 -0
- textpolicy/analysis/planning_patterns.py +105 -0
- textpolicy/analysis/serialization.py +65 -0
- textpolicy/generation/mlx_generation.py +36 -21
- textpolicy/tasks/__init__.py +7 -0
- textpolicy/tasks/countdown/__init__.py +21 -0
- textpolicy/tasks/countdown/dataset.py +163 -0
- textpolicy/tasks/countdown/evaluator.py +197 -0
- textpolicy/tasks/countdown/prompt.py +89 -0
- textpolicy/tasks/countdown/reward.py +56 -0
- textpolicy/training/trainer.py +41 -21
- {textpolicy-0.1.1.dist-info → textpolicy-0.1.3.dist-info}/METADATA +3 -3
- {textpolicy-0.1.1.dist-info → textpolicy-0.1.3.dist-info}/RECORD +22 -11
- {textpolicy-0.1.1.dist-info → textpolicy-0.1.3.dist-info}/WHEEL +0 -0
- {textpolicy-0.1.1.dist-info → textpolicy-0.1.3.dist-info}/entry_points.txt +0 -0
- {textpolicy-0.1.1.dist-info → textpolicy-0.1.3.dist-info}/licenses/LICENSE +0 -0
- {textpolicy-0.1.1.dist-info → textpolicy-0.1.3.dist-info}/top_level.txt +0 -0
textpolicy/algorithms/grpo.py
CHANGED
@@ -26,7 +26,48 @@ if mx is None:
             return fn

     mx = _DummyMx()
-from typing import List, Union
+from typing import List, Union, Tuple, Dict, Any, Optional
+from dataclasses import dataclass
+from collections import defaultdict
+
+# Import length shaping utilities from dedicated module
+from .length_shaping import (
+    compute_length_penalty,
+    apply_length_shaping,
+    compute_length_shaping_stats,
+)
+
+
+# --- Clip Configuration Helper ---
+@dataclass
+class ClipConfig:
+    """Configuration for PPO/DAPO clipping bounds."""
+    low: float = 0.2
+    high: float = 0.28
+
+
+def resolve_clip_config(
+    clip_ratio: Optional[float],
+    clip_ratio_low: float = 0.2,
+    clip_ratio_high: float = 0.28,
+) -> ClipConfig:
+    """
+    Resolve clipping configuration with backward compatibility.
+
+    Centralizes the logic for handling symmetric vs asymmetric clipping bounds.
+
+    Args:
+        clip_ratio: Symmetric clipping ratio (backward compatibility).
+            If provided, overrides clip_ratio_low and clip_ratio_high.
+        clip_ratio_low: Lower bound offset (default 0.2)
+        clip_ratio_high: Upper bound offset (default 0.28)
+
+    Returns:
+        ClipConfig with resolved low and high bounds
+    """
+    if clip_ratio is not None:
+        return ClipConfig(low=clip_ratio, high=clip_ratio)
+    return ClipConfig(low=clip_ratio_low, high=clip_ratio_high)


 def compute_advantages(rewards: Union[List[float], mx.array]) -> mx.array:
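The `ClipConfig`/`resolve_clip_config` helper added above centralizes the backward-compatibility rule for clipping bounds. A minimal usage sketch (not part of the diff; it assumes the helpers are importable from `textpolicy.algorithms.grpo` as added in this hunk):

```python
from textpolicy.algorithms.grpo import ClipConfig, resolve_clip_config

# Asymmetric DAPO-style defaults: ratios are clipped to [1 - 0.2, 1 + 0.28].
cfg = resolve_clip_config(clip_ratio=None)
assert (cfg.low, cfg.high) == (0.2, 0.28)

# Legacy symmetric behaviour: a single clip_ratio overrides both offsets.
cfg = resolve_clip_config(clip_ratio=0.2)
assert cfg == ClipConfig(low=0.2, high=0.2)
```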
@@ -113,45 +154,88 @@ def compute_advantages_dr_grpo(rewards: Union[List[float], mx.array]) -> mx.array:

 def policy_loss(
     old_logprobs: mx.array,
-    new_logprobs: mx.array,
+    new_logprobs: mx.array,
     advantages: mx.array,
-    clip_ratio: float =
+    clip_ratio: float = None,
+    clip_ratio_low: float = 0.2,
+    clip_ratio_high: float = 0.28,
+    normalize_constant: int = None
 ) -> mx.array:
     """
-    GRPO policy loss with PPO-style clipping.
-
+    GRPO policy loss with PPO-style clipping, supporting DAPO asymmetric bounds.
+
     Uses clipped surrogate objective but with group-relative advantages
-    instead of GAE advantages.
-
+    instead of GAE advantages. Supports asymmetric clipping bounds (DAPO-style)
+    to prevent entropy collapse while maintaining training stability.
+
+    DAPO insight: Asymmetric bounds allow the model to increase probabilities
+    of good actions more easily than decreasing probabilities of bad actions,
+    promoting diversity and preventing entropy collapse.
+
+    Dr. GRPO insight: Dividing by a fixed constant instead of token count
+    eliminates length bias that artificially inflates incorrect (longer) responses.
+
     Args:
         old_logprobs: Log probabilities from rollout collection
         new_logprobs: Log probabilities from current policy evaluation
         advantages: Group-relative advantages from compute_advantages()
-        clip_ratio:
-
+        clip_ratio: Symmetric clipping ratio (for backward compatibility).
+            If provided, overrides clip_ratio_low and clip_ratio_high.
+        clip_ratio_low: Lower bound offset (default 0.2, gives lower bound of 0.8)
+        clip_ratio_high: Upper bound offset (default 0.28, gives upper bound of 1.28)
+        normalize_constant: Fixed constant divisor for loss normalization.
+            If None (default), uses mean (original behavior).
+            If provided, uses sum/constant to eliminate length bias.
+            Typical values: 1024, or batch_size.
+
     Returns:
         Policy loss scalar (to be minimized)
-
+
     Notes:
         - Fully vectorized (no Python loops over batch)
         - Uses in-place operations where possible
         - Suitable for MLX graph optimization
         - Single forward pass through computation
+        - DAPO defaults: clip_ratio_low=0.2, clip_ratio_high=0.28
+        - Length bias: When using mean, longer sequences have lower per-token
+          contribution, creating implicit bias toward short responses.
+
+    References:
+        DAPO: An Open-Source LLM Reinforcement Learning System at Scale
+        https://arxiv.org/abs/2503.14476
+
+        Dr. GRPO: Understanding R1-Zero-Like Training
+        https://arxiv.org/abs/2503.20783
     """
-    #
+    # Resolve clipping configuration (handles backward compatibility)
+    clip_cfg = resolve_clip_config(clip_ratio, clip_ratio_low, clip_ratio_high)
+
+    # Importance ratio: π_new / π_old
     # MLX optimizes exp() for Apple Silicon
     ratio = mx.exp(new_logprobs - old_logprobs)
-
-    # PPO clipped surrogate objective
-    # L = min(ratio * A, clip(ratio, 1
-    clipped_ratio = mx.clip(ratio, 1 -
-
-    # Element-wise minimum
-    # Negative because we minimize (original maximizes)
+
+    # PPO clipped surrogate objective with asymmetric bounds (DAPO-style)
+    # L = min(ratio * A, clip(ratio, 1-ε_low, 1+ε_high) * A)
+    clipped_ratio = mx.clip(ratio, 1 - clip_cfg.low, 1 + clip_cfg.high)
+
+    # Element-wise minimum
     surr1 = ratio * advantages
     surr2 = clipped_ratio * advantages
-
-
+    min_surr = mx.minimum(surr1, surr2)
+
+    # Normalization: either mean (original) or sum/constant (Dr. GRPO)
+    if normalize_constant is not None:
+        if normalize_constant <= 0:
+            raise ValueError(
+                f"normalize_constant must be positive, got {normalize_constant}"
+            )
+        # Fixed constant normalization eliminates length bias
+        # All sequences contribute equally regardless of length
+        loss = -mx.sum(min_surr) / normalize_constant
+    else:
+        # Original mean behavior (for backward compatibility)
+        loss = -mx.mean(min_surr)
+
     return loss


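The reworked `policy_loss` above accepts both clipping styles plus the Dr. GRPO constant normalizer. A hedged sketch of the new call signatures with toy tensors (illustrative values only; requires mlx, and assumes `policy_loss` is imported from `textpolicy.algorithms.grpo`):

```python
import mlx.core as mx
from textpolicy.algorithms.grpo import policy_loss

old_lp = mx.array([-1.0, -0.7, -2.0, -0.4])  # logprobs recorded at rollout time
new_lp = mx.array([-0.9, -0.8, -1.5, -0.5])  # logprobs under the current policy
adv = mx.array([1.0, -1.0, 1.0, -1.0])       # group-relative advantages

# DAPO-style asymmetric clipping (defaults 0.2 / 0.28), mean-normalized.
loss_mean = policy_loss(old_lp, new_lp, adv)

# Dr. GRPO: sum divided by a fixed constant instead of the mean (length-bias fix).
loss_const = policy_loss(old_lp, new_lp, adv, normalize_constant=1024)

# Backward-compatible symmetric clipping via clip_ratio.
loss_legacy = policy_loss(old_lp, new_lp, adv, clip_ratio=0.2)

mx.eval(loss_mean, loss_const, loss_legacy)
```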
@@ -163,86 +247,491 @@ def compute_advantages_compiled(rewards: mx.array) -> mx.array:
     return rewards - group_mean


-
-
+# --- Compiled Policy Loss Variants ---
+# Two internal compiled functions for different normalization strategies.
+# Compiled functions require static control flow, so we keep them separate.
+
+@mx.compile
+def _policy_loss_compiled_mean(
     old_logprobs: mx.array,
     new_logprobs: mx.array,
     advantages: mx.array,
-
+    clip_ratio_low: float,
+    clip_ratio_high: float
 ) -> mx.array:
-    """
+    """Internal compiled function: mean normalization."""
     ratio = mx.exp(new_logprobs - old_logprobs)
-    clipped_ratio = mx.clip(ratio, 1 -
+    clipped_ratio = mx.clip(ratio, 1 - clip_ratio_low, 1 + clip_ratio_high)
     surr1 = ratio * advantages
     surr2 = clipped_ratio * advantages
     return -mx.mean(mx.minimum(surr1, surr2))


+@mx.compile
+def _policy_loss_compiled_constant(
+    old_logprobs: mx.array,
+    new_logprobs: mx.array,
+    advantages: mx.array,
+    clip_ratio_low: float,
+    clip_ratio_high: float,
+    normalize_constant: float
+) -> mx.array:
+    """Internal compiled function: constant normalization."""
+    ratio = mx.exp(new_logprobs - old_logprobs)
+    clipped_ratio = mx.clip(ratio, 1 - clip_ratio_low, 1 + clip_ratio_high)
+    surr1 = ratio * advantages
+    surr2 = clipped_ratio * advantages
+    return -mx.sum(mx.minimum(surr1, surr2)) / normalize_constant
+
+
+def policy_loss_compiled(
+    old_logprobs: mx.array,
+    new_logprobs: mx.array,
+    advantages: mx.array,
+    clip_ratio: float = None,
+    clip_ratio_low: float = 0.2,
+    clip_ratio_high: float = 0.28
+) -> mx.array:
+    """
+    Compiled version of policy_loss for maximum performance (mean normalization).
+
+    Supports DAPO-style asymmetric clipping bounds with backward compatibility.
+    Uses mean normalization (original behavior).
+
+    Args:
+        old_logprobs: Log probabilities from rollout collection
+        new_logprobs: Log probabilities from current policy evaluation
+        advantages: Group-relative advantages
+        clip_ratio: Symmetric clipping ratio (for backward compatibility).
+            If provided, overrides clip_ratio_low and clip_ratio_high.
+        clip_ratio_low: Lower bound offset (default 0.2)
+        clip_ratio_high: Upper bound offset (default 0.28)
+    """
+    clip_cfg = resolve_clip_config(clip_ratio, clip_ratio_low, clip_ratio_high)
+    return _policy_loss_compiled_mean(
+        old_logprobs, new_logprobs, advantages,
+        clip_cfg.low, clip_cfg.high
+    )
+
+
+def policy_loss_compiled_constant_norm(
+    old_logprobs: mx.array,
+    new_logprobs: mx.array,
+    advantages: mx.array,
+    clip_ratio: float = None,
+    clip_ratio_low: float = 0.2,
+    clip_ratio_high: float = 0.28,
+    normalize_constant: float = 1024.0
+) -> mx.array:
+    """
+    Compiled version of policy_loss with fixed constant normalization (Dr. GRPO).
+
+    Uses sum/constant instead of mean to eliminate length bias.
+
+    Args:
+        old_logprobs: Log probabilities from rollout collection
+        new_logprobs: Log probabilities from current policy evaluation
+        advantages: Group-relative advantages
+        clip_ratio: Symmetric clipping ratio (for backward compatibility).
+            If provided, overrides clip_ratio_low and clip_ratio_high.
+        clip_ratio_low: Lower bound offset (default 0.2)
+        clip_ratio_high: Upper bound offset (default 0.28)
+        normalize_constant: Fixed constant divisor (default 1024)
+
+    References:
+        Dr. GRPO: Understanding R1-Zero-Like Training
+        https://arxiv.org/abs/2503.20783
+    """
+    if normalize_constant <= 0:
+        raise ValueError(
+            f"normalize_constant must be positive, got {normalize_constant}"
+        )
+    clip_cfg = resolve_clip_config(clip_ratio, clip_ratio_low, clip_ratio_high)
+    return _policy_loss_compiled_constant(
+        old_logprobs, new_logprobs, advantages,
+        clip_cfg.low, clip_cfg.high, normalize_constant
+    )
+
+
 def entropy_bonus(logprobs: mx.array, coefficient: float = 0.01) -> mx.array:
     """
     Entropy bonus for exploration (optional GRPO component).
-
+
     Args:
         logprobs: Log probabilities from policy
         coefficient: Entropy coefficient (typically small, like 0.01)
-
+
     Returns:
         Entropy bonus (added to loss for exploration)
     """
     if coefficient <= 0:
         return mx.array(0.0)
-
+
     # Entropy = -sum(p * log(p))
     # For log probabilities: entropy = -sum(exp(logp) * logp)
     probs = mx.exp(logprobs)
     entropy = -mx.sum(probs * logprobs, axis=-1)
-
+
     # Return negative entropy (since we add to loss but want to maximize entropy)
     return -coefficient * mx.mean(entropy)


+# Note: compute_length_penalty, apply_length_shaping, and compute_length_shaping_stats
+# are imported from .length_shaping module (see imports at top of file)
+
+
+# DAPO-style dynamic batch filtering (Issue #9)
+def _get_episode_reward(episode) -> float:
+    """Extract total reward from episode (handles both Episode objects and dicts)."""
+    if hasattr(episode, 'rew'):
+        # Episode object
+        return float(mx.sum(mx.array(episode.rew)).item())
+    else:
+        # Serialized dictionary
+        rew = episode.get('rew', episode.get('reward', [0.0]))
+        if isinstance(rew, (int, float)):
+            return float(rew)
+        return float(mx.sum(mx.array(rew)).item())
+
+
+def _get_prompt_key(episode) -> tuple:
+    """
+    Generate a hashable key for an episode's prompt.
+
+    Handles both Episode objects and serialized dictionaries.
+    Uses the observation (prompt) tokens to identify the prompt.
+    """
+    if hasattr(episode, 'obs'):
+        obs = episode.obs
+    else:
+        obs = episode.get('obs', [])
+
+    # Flatten nested structures to create consistent key
+    flattened = []
+    for item in obs:
+        if hasattr(item, 'tolist'):  # MLX array
+            flattened.extend(item.tolist())
+        elif isinstance(item, list):
+            flattened.extend(item)
+        else:
+            flattened.append(item)
+
+    return tuple(flattened)
+
+
+def _precompute_episode_rewards(episodes: List[Any]) -> List[float]:
+    """
+    Pre-compute rewards for all episodes in a single pass.
+
+    Uses batched MLX evaluation to avoid per-episode .item() sync barriers.
+    All mx.sum() calls are built lazily, then evaluated in one mx.eval() call.
+
+    Args:
+        episodes: List of episodes
+
+    Returns:
+        List of rewards in the same order as episodes
+    """
+    if not episodes:
+        return []
+
+    rewards: List[Optional[float]] = [None] * len(episodes)
+    pending: List[Tuple[int, mx.array]] = []  # (index, lazy_sum) pairs
+
+    for i, ep in enumerate(episodes):
+        if hasattr(ep, 'rew'):
+            rew = ep.rew
+        else:
+            rew = ep.get('rew', ep.get('reward', [0.0]))
+
+        if isinstance(rew, (int, float)):
+            rewards[i] = float(rew)
+        else:
+            pending.append((i, mx.sum(mx.array(rew))))
+
+    # Single sync barrier for all array rewards
+    if pending:
+        indices, lazy_sums = zip(*pending)
+        stacked = mx.stack(list(lazy_sums))
+        mx.eval(stacked)
+        values = stacked.tolist()
+        for idx, val in zip(indices, values):
+            rewards[idx] = float(val)
+
+    return rewards  # type: ignore[return-value]
+
+
+def _compute_group_variance_and_mean(
+    group_indices: List[int],
+    all_rewards: List[float]
+) -> Tuple[float, float]:
+    """
+    Compute variance and mean for a group of episodes using pre-computed rewards.
+
+    Args:
+        group_indices: Indices into all_rewards for this group
+        all_rewards: Pre-computed rewards for all episodes
+
+    Returns:
+        Tuple of (variance, mean)
+    """
+    group_rewards = mx.array([all_rewards[i] for i in group_indices])
+    return mx.var(group_rewards).item(), mx.mean(group_rewards).item()
+
+
+def filter_informative_prompts(
+    episodes: List[Any],
+    min_variance: float = 0.01,
+    keep_single_completion: bool = True
+) -> Tuple[List[Any], Dict[str, Union[int, float]]]:
+    """
+    Filter episodes to keep only informative prompts (DAPO dynamic sampling).
+
+    Removes prompts where all completions have same outcome:
+    - All correct (reward ~1.0): no learning signal (nothing to improve)
+    - All wrong (reward ~0.0): no positive signal (can't learn what works)
+
+    GRPO uses group-relative advantages. If all completions have the same
+    outcome, advantages are zero, producing no gradient and wasting compute.
+
+    Note on single-completion prompts:
+        The DAPO paper (Equation 11) defines informative prompts as having
+        mixed outcomes: `0 < |correct| < G`. This assumes G > 1 completions
+        per prompt. For single-completion prompts (G=1), variance is always 0
+        by definition, but this doesn't mean "all outcomes are the same" -
+        it means we have insufficient data to determine variance.
+
+        By default (keep_single_completion=True), single-completion prompts
+        are kept since they still provide valid gradient signal. Set to False
+        to filter them out (stricter DAPO interpretation).
+
+    Args:
+        episodes: List of episodes (Episode objects or serialized dicts)
+        min_variance: Minimum reward variance to keep a prompt group.
+            Groups with variance below this threshold are filtered out.
+            Default 0.01 filters prompts with essentially identical rewards.
+            Only applied to groups with 2+ completions.
+        keep_single_completion: Whether to keep prompts with only one completion.
+            Default True (keep them). Set False to require
+            multiple completions for variance calculation.
+
+    Returns:
+        Tuple of:
+        - filtered: List of episodes from informative prompts
+        - stats: Dictionary with filtering statistics:
+            - 'prompts_kept': Number of prompt groups kept
+            - 'prompts_dropped_all_correct': Prompts where all completions succeeded
+            - 'prompts_dropped_all_wrong': Prompts where all completions failed
+            - 'prompts_dropped_single': Prompts dropped due to single completion
+            - 'prompts_kept_single': Single-completion prompts that were kept
+            - 'episodes_kept': Total episodes kept
+            - 'episodes_dropped': Total episodes filtered out
+            - 'filter_rate': Fraction of prompts filtered
+
+    Example:
+        >>> filtered, stats = filter_informative_prompts(episodes, min_variance=0.01)
+        >>> print(f"Kept {stats['prompts_kept']} prompts, "
+        ...       f"dropped {stats['prompts_dropped_all_correct']} all-correct, "
+        ...       f"{stats['prompts_dropped_all_wrong']} all-wrong")
+
+    References:
+        DAPO: An Open-Source LLM Reinforcement Learning System at Scale
+        https://arxiv.org/abs/2503.14476 (Equation 11: 0 < |correct| < G)
+
+        GRPO++ Tricks
+        https://cameronrwolfe.substack.com/p/grpo-tricks
+    """
+    if not episodes:
+        return [], {
+            'prompts_kept': 0,
+            'prompts_dropped_all_correct': 0,
+            'prompts_dropped_all_wrong': 0,
+            'prompts_dropped_single': 0,
+            'prompts_kept_single': 0,
+            'episodes_kept': 0,
+            'episodes_dropped': 0,
+            'filter_rate': 0.0,
+        }
+
+    # Pre-compute all rewards once (avoids repeated _get_episode_reward calls)
+    all_rewards = _precompute_episode_rewards(episodes)
+
+    # Group episodes by prompt, storing indices instead of episodes
+    prompt_groups: Dict[tuple, List[int]] = defaultdict(list)
+    for idx, ep in enumerate(episodes):
+        prompt_key = _get_prompt_key(ep)
+        prompt_groups[prompt_key].append(idx)
+
+    filtered = []
+    stats = {
+        'prompts_kept': 0,
+        'prompts_dropped_all_correct': 0,
+        'prompts_dropped_all_wrong': 0,
+        'prompts_dropped_single': 0,
+        'prompts_kept_single': 0,
+        'episodes_kept': 0,
+        'episodes_dropped': 0,
+    }
+
+    for prompt_key, group_indices in prompt_groups.items():
+        group_size = len(group_indices)
+
+        # Handle single-completion prompts separately
+        if group_size == 1:
+            if keep_single_completion:
+                # Keep single-completion prompts (variance undefined, not "zero")
+                filtered.append(episodes[group_indices[0]])
+                stats['prompts_kept'] += 1
+                stats['prompts_kept_single'] += 1
+                stats['episodes_kept'] += 1
+            else:
+                # Filter out single-completion prompts (strict DAPO interpretation)
+                stats['prompts_dropped_single'] += 1
+                stats['episodes_dropped'] += 1
+            continue
+
+        # For groups with 2+ completions, use variance criterion
+        variance, mean_reward = _compute_group_variance_and_mean(group_indices, all_rewards)
+
+        if variance > min_variance:
+            # Informative: mixed outcomes, keep all episodes from this prompt
+            for idx in group_indices:
+                filtered.append(episodes[idx])
+            stats['prompts_kept'] += 1
+            stats['episodes_kept'] += group_size
+        else:
+            # Uninformative: all completions have same outcome
+            stats['episodes_dropped'] += group_size
+            if mean_reward > 0.5:
+                stats['prompts_dropped_all_correct'] += 1
+            else:
+                stats['prompts_dropped_all_wrong'] += 1
+
+    # Compute filter rate
+    total_prompts = len(prompt_groups)
+    stats['filter_rate'] = 1.0 - (stats['prompts_kept'] / total_prompts) if total_prompts > 0 else 0.0
+
+    return filtered, stats
+
+
+def compute_prompt_group_stats(episodes: List[Any]) -> Dict[str, Any]:
+    """
+    Compute statistics about prompt groups for monitoring.
+
+    Useful for understanding the distribution of prompts and completions
+    before and after filtering.
+
+    Args:
+        episodes: List of episodes
+
+    Returns:
+        Dictionary with:
+        - 'num_prompts': Total unique prompts
+        - 'num_episodes': Total episodes
+        - 'completions_per_prompt': Average completions per prompt
+        - 'reward_variance_mean': Mean variance across prompt groups
+        - 'reward_variance_std': Std of variance across prompt groups
+    """
+    if not episodes:
+        return {
+            'num_prompts': 0,
+            'num_episodes': 0,
+            'completions_per_prompt': 0.0,
+            'reward_variance_mean': 0.0,
+            'reward_variance_std': 0.0,
+        }
+
+    # Pre-compute all rewards once
+    all_rewards = _precompute_episode_rewards(episodes)
+
+    # Group by prompt, storing indices
+    prompt_groups: Dict[tuple, List[int]] = defaultdict(list)
+    for idx, ep in enumerate(episodes):
+        prompt_key = _get_prompt_key(ep)
+        prompt_groups[prompt_key].append(idx)
+
+    # Compute variance for each group using pre-computed rewards
+    variances = []
+    for group_indices in prompt_groups.values():
+        group_rewards = mx.array([all_rewards[i] for i in group_indices])
+        variances.append(mx.var(group_rewards).item())
+
+    variances_arr = mx.array(variances) if variances else mx.array([0.0])
+
+    return {
+        'num_prompts': len(prompt_groups),
+        'num_episodes': len(episodes),
+        'completions_per_prompt': len(episodes) / len(prompt_groups) if prompt_groups else 0.0,
+        'reward_variance_mean': float(mx.mean(variances_arr).item()),
+        'reward_variance_std': float(mx.std(variances_arr).item()),
+    }
+
+
 # Convenience function for complete GRPO computation
 def grpo_loss(
     old_logprobs: mx.array,
     new_logprobs: mx.array,
     rewards: Union[List[float], mx.array],
-    clip_ratio: float =
-
+    clip_ratio: float = None,
+    clip_ratio_low: float = 0.2,
+    clip_ratio_high: float = 0.28,
+    entropy_coeff: float = 0.0,
+    normalize_constant: int = None
 ) -> mx.array:
     """
     Complete GRPO loss computation in one function.
-
+
     Combines advantage calculation and policy loss for convenience.
     Can be compiled as a single unit for maximum efficiency.
-
+    Supports DAPO-style asymmetric clipping bounds and Dr. GRPO length-bias fix.
+
     Args:
         old_logprobs: Log probabilities from rollout
         new_logprobs: Log probabilities from current policy
         rewards: Episode rewards for group-relative advantages
-        clip_ratio:
+        clip_ratio: Symmetric clipping ratio (for backward compatibility).
+            If provided, overrides clip_ratio_low and clip_ratio_high.
+        clip_ratio_low: Lower bound offset (default 0.2)
+        clip_ratio_high: Upper bound offset (default 0.28)
         entropy_coeff: Entropy bonus coefficient (0 disables)
-
+        normalize_constant: Fixed constant divisor for loss normalization.
+            If None (default), uses mean. If provided, uses
+            sum/constant to eliminate length bias.
+
     Returns:
         Total GRPO loss (policy + optional entropy)
+
+    References:
+        DAPO: An Open-Source LLM Reinforcement Learning System at Scale
+        https://arxiv.org/abs/2503.14476
+
+        Dr. GRPO: Understanding R1-Zero-Like Training
+        https://arxiv.org/abs/2503.20783
     """
     # Compute group-relative advantages
     advantages = compute_advantages(rewards)
-
+
     # Expand advantages to match logprob sequence length if needed
     if advantages.ndim == 1 and old_logprobs.ndim > 1:
         # Each episode contributes its advantage to all tokens in that episode
         # This requires knowing episode boundaries - simplified version assumes
         # advantages and logprobs are already aligned
         pass
-
-    # Compute policy loss
-    policy_loss_val = policy_loss(
-
+
+    # Compute policy loss with asymmetric clipping and optional length-bias fix
+    policy_loss_val = policy_loss(
+        old_logprobs, new_logprobs, advantages,
+        clip_ratio=clip_ratio,
+        clip_ratio_low=clip_ratio_low,
+        clip_ratio_high=clip_ratio_high,
+        normalize_constant=normalize_constant
+    )
+
     # Add entropy bonus if specified
     if entropy_coeff > 0:
         entropy_bonus_val = entropy_bonus(new_logprobs, entropy_coeff)
         return policy_loss_val + entropy_bonus_val
-
+
     return policy_loss_val


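The filtering utilities added in this hunk can be exercised directly on serialized episode dicts. A hedged sketch (the episode dicts below are invented for illustration; requires mlx):

```python
from textpolicy.algorithms.grpo import (
    compute_prompt_group_stats,
    filter_informative_prompts,
)

# Two completions per prompt: prompt (1, 2) has mixed outcomes, prompt (3, 4)
# is all-wrong and should be dropped by DAPO-style dynamic sampling.
episodes = [
    {'obs': [1, 2], 'act': [10, 11], 'rew': [1.0], 'logprob': [-0.5, -0.6]},
    {'obs': [1, 2], 'act': [12], 'rew': [0.0], 'logprob': [-0.7]},
    {'obs': [3, 4], 'act': [13], 'rew': [0.0], 'logprob': [-0.4]},
    {'obs': [3, 4], 'act': [14], 'rew': [0.0], 'logprob': [-0.3]},
]

print(compute_prompt_group_stats(episodes))  # 2 prompts, 2 completions each

filtered, stats = filter_informative_prompts(episodes, min_variance=0.01)
assert stats['prompts_kept'] == 1
assert stats['prompts_dropped_all_wrong'] == 1
assert len(filtered) == 2  # only the mixed-outcome prompt survives
```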
@@ -251,392 +740,313 @@ def compute_metrics(
     old_logprobs: mx.array,
     new_logprobs: mx.array,
     advantages: mx.array,
-    clip_ratio: float =
+    clip_ratio: float = None,
+    clip_ratio_low: float = 0.2,
+    clip_ratio_high: float = 0.28
 ) -> dict:
     """
     Compute GRPO training metrics for monitoring.
-
+
+    Supports DAPO-style asymmetric clipping bounds and tracks clip fractions
+    for upper vs lower bounds separately.
+
     Args:
         old_logprobs: Log probabilities from rollout
-        new_logprobs: Log probabilities from current policy
+        new_logprobs: Log probabilities from current policy
         advantages: Group-relative advantages
-        clip_ratio:
-
+        clip_ratio: Symmetric clipping ratio (for backward compatibility).
+            If provided, overrides clip_ratio_low and clip_ratio_high.
+        clip_ratio_low: Lower bound offset (default 0.2)
+        clip_ratio_high: Upper bound offset (default 0.28)
+
     Returns:
-        Dictionary of metrics for logging/monitoring
+        Dictionary of metrics for logging/monitoring, including:
+        - clip_fraction_lower: Fraction of ratios clipped at lower bound
+        - clip_fraction_upper: Fraction of ratios clipped at upper bound
+        - clip_fraction: Total fraction of ratios clipped (either bound)
     """
+    # Resolve clipping configuration (handles backward compatibility)
+    clip_cfg = resolve_clip_config(clip_ratio, clip_ratio_low, clip_ratio_high)
+
     # Importance ratio statistics
     ratio = mx.exp(new_logprobs - old_logprobs)
-
-    #
-    clip_lower = 1 -
-    clip_upper = 1 +
-
+
+    # Asymmetric clipping bounds
+    clip_lower = 1 - clip_cfg.low
+    clip_upper = 1 + clip_cfg.high
+
+    # Track clip fractions separately for upper and lower bounds
+    clipped_lower = ratio < clip_lower
+    clipped_upper = ratio > clip_upper
+    clipped = clipped_lower | clipped_upper
+
+    clip_fraction_lower = mx.mean(clipped_lower.astype(mx.float32))
+    clip_fraction_upper = mx.mean(clipped_upper.astype(mx.float32))
     clip_fraction = mx.mean(clipped.astype(mx.float32))
-
+
     # KL divergence approximation
     kl_div = mx.mean(old_logprobs - new_logprobs)
-
+
     return {
         'mean_advantage': mx.mean(advantages).item(),
         'std_advantage': mx.std(advantages).item(),
         'mean_ratio': mx.mean(ratio).item(),
         'clip_fraction': clip_fraction.item(),
+        'clip_fraction_lower': clip_fraction_lower.item(),
+        'clip_fraction_upper': clip_fraction_upper.item(),
         'kl_divergence': kl_div.item(),
         'min_advantage': mx.min(advantages).item(),
-        'max_advantage': mx.max(advantages).item()
+        'max_advantage': mx.max(advantages).item(),
+        'clip_ratio_low': clip_cfg.low,
+        'clip_ratio_high': clip_cfg.high
     }


-#
-def
+# --- Episode Packing Helper ---
+def _flatten_tokens(items: List[Any]) -> List:
+    """Flatten nested token sequences into a flat list."""
+    flattened = []
+    for item in items:
+        if hasattr(item, 'tolist'):  # MLX array
+            flattened.extend(item.tolist())
+        elif isinstance(item, list):  # Python list
+            flattened.extend(item)
+        else:  # Single token
+            flattened.append(item)
+    return flattened
+
+
+def _pack_episodes(episodes: List[Any]) -> Dict[str, Any]:
     """
-
-
-
-
-
+    Pack episodes into batch data for GRPO training.
+
+    This is the shared helper for episode-to-batch conversion, used by all
+    data selectors (select_all_data, select_informative_data, select_recent_data).
+
     Args:
-
-
+        episodes: List of episodes (Episode objects or serialized dicts)
+
     Returns:
-
+        Dictionary with:
+        - 'obs': Flat concatenated full sequences (prompt + response)
+        - 'act': Flat concatenated response tokens
+        - 'logprob': Flat concatenated log probabilities
+        - 'rewards': Episode rewards as MLX array
+        - 'episode_lengths': List of episode lengths
     """
-    from textpolicy.buffer import Buffer
-    if not isinstance(buffer, Buffer):
-        raise TypeError(f"Expected Buffer, got {type(buffer)}")
-
-    # Use all available data - GRPO benefits from larger groups
-    episodes_data = buffer.sample()  # This returns concatenated transitions
-
-    # We need to convert this back to episode structure for reward extraction
-    episodes = buffer.episodes  # Access episodes directly from storage
-
     if not episodes:
-
-
-
-
+        return {
+            'obs': mx.array([], dtype=mx.int64),
+            'act': mx.array([], dtype=mx.int64),
+            'logprob': mx.array([], dtype=mx.float32),
+            'rewards': mx.array([]),
+            'episode_lengths': [],
+        }
+
     episode_lengths = []
-
-    # Collect all transitions
     all_obs = []
     all_acts = []
     all_logprobs = []
-
-
-
-
+    pending_reward_sums: List[Tuple[int, mx.array]] = []
+    scalar_rewards: Dict[int, float] = {}
+
+    for i, episode in enumerate(episodes):
         if hasattr(episode, 'rew'):
             # Episode object with attributes
-
-
-
-
-
-
-            #
-            #
-
-
-            flattened_obs = []
-            for obs in episode.obs:
-                if hasattr(obs, 'tolist'):  # MLX array
-                    flattened_obs.extend(obs.tolist())
-                elif isinstance(obs, list):  # Python list
-                    flattened_obs.extend(obs)
-                else:  # Single token
-                    flattened_obs.append(obs)
-
-            # Extract and flatten action tokens (response)
-            flattened_acts = []
-            for act in episode.act:
-                if hasattr(act, 'tolist'):  # MLX array
-                    flattened_acts.extend(act.tolist())
-                elif isinstance(act, list):  # Python list
-                    flattened_acts.extend(act)
-                else:  # Single token
-                    flattened_acts.append(act)
-
+            pending_reward_sums.append((i, mx.sum(mx.array(episode.rew))))
+
+            # Flatten observation and action tokens
+            flattened_obs = _flatten_tokens(episode.obs)
+            flattened_acts = _flatten_tokens(episode.act)
+
+            # Use flattened token count for episode_lengths (used by _expand_advantages)
+            # This ensures alignment between expanded advantages and actual token sequences
+            episode_lengths.append(len(flattened_acts))
+
             # Create full sequence: [prompt_tokens..., response_tokens...]
             full_sequence = flattened_obs + flattened_acts
             all_obs.append(full_sequence)
             all_acts.append(flattened_acts)
-            all_logprobs.append(episode.logprob if episode.logprob else [])
+            all_logprobs.append(episode.logprob if episode.logprob is not None else [])
         else:
             # Serialized dictionary from multiprocessing
-
-
-
-
-
-
-            #
-
+            rew = episode['rew']
+            if isinstance(rew, (int, float)):
+                scalar_rewards[i] = float(rew)
+            else:
+                pending_reward_sums.append((i, mx.sum(mx.array(rew))))
+
+            # Flatten observation and action tokens
+            flattened_obs = _flatten_tokens(episode['obs'])
+            flattened_acts = _flatten_tokens(episode['act'])
+
+            # Use flattened token count for episode_lengths
+            episode_lengths.append(len(flattened_acts))
+
+            full_sequence = flattened_obs + flattened_acts
             all_obs.append(full_sequence)
-            all_acts.append(
+            all_acts.append(flattened_acts)
             all_logprobs.append(episode.get('logprob', []))
-
-    #
-
-
-
-
+
+    # Batch evaluate all pending reward sums (single sync barrier instead of N)
+    episode_rewards = [0.0] * len(episodes)
+    for idx, val in scalar_rewards.items():
+        episode_rewards[idx] = val
+    if pending_reward_sums:
+        indices, lazy_sums = zip(*pending_reward_sums)
+        stacked = mx.stack(list(lazy_sums))
+        mx.eval(stacked)
+        values = stacked.tolist()
+        for idx, val in zip(indices, values):
+            episode_rewards[idx] = float(val)
+
+    # Find maximum sequence lengths for padding
     max_obs_len = max(len(obs) for obs in all_obs) if all_obs else 0
     max_act_len = max(len(act) for act in all_acts) if all_acts else 0
     max_logprob_len = max(len(logprob) for logprob in all_logprobs) if all_logprobs else 0
-
-    #
-    #
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        print(f"DEBUG: all_obs types: {[type(obs) for obs in all_obs[:3]]}")  # Show first 3 for brevity
-        print(f"DEBUG: all_logprobs types: {[type(logprob) for logprob in all_logprobs[:3]]}")
-        raise
-
-    # GRPO data structure: both observations and actions as flat concatenated sequences
-    # This matches the expected format for GRPO logprob extraction function
-    batch_data = {
-        'obs': mx.concatenate(all_obs_mx) if all_obs_mx else mx.array([]),  # Flat concatenated full sequences
-        'act': mx.concatenate(all_acts_mx) if all_acts_mx else mx.array([]),  # Flat concatenated response tokens
-        'logprob': mx.concatenate([logprob.flatten() for logprob in all_logprobs_mx]) if all_logprobs_mx else mx.array([]),  # Flat sequence for training
+
+    # Convert to MLX arrays with padding
+    # Always create an array for each episode to maintain alignment
+    all_obs_mx = [mx.array(obs, dtype=mx.int64) if obs else mx.array([], dtype=mx.int64) for obs in all_obs]
+    all_acts_mx = [mx.array(act, dtype=mx.int64) if act else mx.array([], dtype=mx.int64) for act in all_acts]
+    all_logprobs_mx = [mx.array(logprob, dtype=mx.float32) if logprob else mx.array([], dtype=mx.float32) for logprob in all_logprobs]
+
+    # Filter out empty arrays for padding/concatenation
+    non_empty_obs = [obs for obs in all_obs_mx if obs.size > 0]
+    non_empty_acts = [act for act in all_acts_mx if act.size > 0]
+    non_empty_logprobs = [logprob for logprob in all_logprobs_mx if logprob.size > 0]
+
+    # Pad using native MLX operations
+    if non_empty_obs:
+        padded_obs = [mx.pad(obs, (0, max_obs_len - obs.shape[0]), constant_values=0)
+                      if obs.shape[0] < max_obs_len else obs[:max_obs_len]
+                      for obs in non_empty_obs]
+    else:
+        padded_obs = []
+
+    if non_empty_acts:
+        padded_acts = [mx.pad(act, (0, max_act_len - act.shape[0]), constant_values=0)
+                       if act.shape[0] < max_act_len else act[:max_act_len]
+                       for act in non_empty_acts]
+    else:
+        padded_acts = []
+
+    if non_empty_logprobs:
+        padded_logprobs = [mx.pad(logprob, (0, max_logprob_len - logprob.shape[0]), constant_values=0.0)
+                           if logprob.shape[0] < max_logprob_len else logprob[:max_logprob_len]
+                           for logprob in non_empty_logprobs]
+    else:
+        padded_logprobs = []
+
+    return {
+        'obs': mx.concatenate(padded_obs) if padded_obs else mx.array([], dtype=mx.int64),
+        'act': mx.concatenate(padded_acts) if padded_acts else mx.array([], dtype=mx.int64),
+        'logprob': mx.concatenate([lp.flatten() for lp in padded_logprobs]) if padded_logprobs else mx.array([], dtype=mx.float32),
         'rewards': mx.array(episode_rewards),
-        'episode_lengths': episode_lengths
+        'episode_lengths': episode_lengths,
     }
-
+
+
+# Algorithm-specific data selection strategies
+def select_all_data(buffer):
+    """
+    GRPO data selector: Use all available data.
+
+    GRPO is on-policy but can benefit from using all collected episodes
+    since group-relative advantages normalize across the entire group.
+
+    Args:
+        buffer: Buffer containing episodes
+
+    Returns:
+        All episode data prepared for training
+    """
+    from textpolicy.buffer import Buffer
+    if not isinstance(buffer, Buffer):
+        raise TypeError(f"Expected Buffer, got {type(buffer)}")
+
+    episodes = buffer.episodes
+    if not episodes:
+        raise ValueError("Buffer is empty - no episodes to train on")
+
+    return _pack_episodes(episodes)
+
+
+def select_informative_data(buffer, min_variance: float = 0.01):
+    """
+    GRPO data selector with dynamic batch filtering (DAPO-style).
+
+    Filters out uninformative prompts where all completions have the same
+    outcome (all correct or all wrong), improving sample efficiency by
+    maintaining meaningful gradient signals.
+
+    This is the recommended selector for GRPO training when using multiple
+    completions per prompt, as it eliminates wasted compute on prompts
+    that provide no learning signal.
+
+    Args:
+        buffer: Buffer containing episodes (Episode objects or serialized dictionaries)
+        min_variance: Minimum reward variance to keep a prompt group.
+            Prompts with variance below this are filtered out.
+
+    Returns:
+        Filtered episode data prepared for training, plus filtering stats.
+
+    Example:
+        >>> batch_data = select_informative_data(buffer, min_variance=0.01)
+        >>> # batch_data includes 'filter_stats' with filtering information
+
+    References:
+        DAPO: An Open-Source LLM Reinforcement Learning System at Scale
+        https://arxiv.org/abs/2503.14476
+    """
+    from textpolicy.buffer import Buffer
+    if not isinstance(buffer, Buffer):
+        raise TypeError(f"Expected Buffer, got {type(buffer)}")
+
+    episodes = buffer.episodes
+    if not episodes:
+        raise ValueError("Buffer is empty - no episodes to train on")
+
+    # Filter to keep only informative prompts
+    filtered_episodes, filter_stats = filter_informative_prompts(episodes, min_variance)
+
+    if not filtered_episodes:
+        raise ValueError(
+            f"All prompts filtered out (min_variance={min_variance}). "
+            f"Stats: {filter_stats}. Consider lowering min_variance or "
+            "ensuring diversity in completions."
+        )
+
+    # Pack filtered episodes using shared helper
+    batch_data = _pack_episodes(filtered_episodes)
+    batch_data['filter_stats'] = filter_stats
     return batch_data


 def select_recent_data(buffer, max_episodes: int = 100):
     """
     GRPO data selector: Use only recent episodes.
-
+
     Alternative selector for GRPO that limits to recent episodes
     for faster training on large buffers.
-
+
     Args:
         buffer: Buffer containing episodes (Episode objects or serialized dictionaries)
         max_episodes: Maximum number of recent episodes to use
-
+
     Returns:
         Recent episode data prepared for training
     """
     from textpolicy.buffer import Buffer
     if not isinstance(buffer, Buffer):
         raise TypeError(f"Expected Buffer, got {type(buffer)}")
-
+
     episodes = buffer.episodes
     if not episodes:
         raise ValueError("Buffer is empty - no episodes to train on")
-
+
     # Select recent episodes
     recent_episodes = episodes[-max_episodes:] if len(episodes) > max_episodes else episodes
-
-
-    episode_rewards = []
-    episode_lengths = []
-    all_obs = []
-    all_acts = []
-    all_logprobs = []
-
-    for episode in recent_episodes:
-        # Handle both Episode objects and serialized dictionaries
-        if hasattr(episode, 'rew'):
-            # Episode object with attributes
-            episode_reward = mx.sum(mx.array(episode.rew)).item()
-            episode_rewards.append(episode_reward)
-            episode_lengths.append(len(episode.obs))
-
-            # For proper logprob extraction during training, we need the full context (prompt + response)
-            # This matches how the model was called during rollout generation
-            # Convert both obs and act to consistent Python list format before concatenation
-            obs_as_lists = []
-            for obs_item in episode.obs:
-                if hasattr(obs_item, 'tolist'):  # MLX array
-                    obs_as_lists.extend(obs_item.tolist())
-                elif isinstance(obs_item, list):  # Already Python list
-                    obs_as_lists.extend(obs_item)
-                else:  # Single item
-                    obs_as_lists.append(obs_item)
-
-            act_as_lists = []
-            for act_item in episode.act:
-                if hasattr(act_item, 'tolist'):  # MLX array
-                    act_as_lists.extend(act_item.tolist())
-                elif isinstance(act_item, list):  # Already Python list
-                    act_as_lists.extend(act_item)
-                else:  # Single item
-                    act_as_lists.append(act_item)
-
-            # Now concatenate the normalized lists
-            full_sequence = obs_as_lists + act_as_lists
-            all_obs.append(full_sequence)
-
-            # Extract actions as consistent Python lists
-            episode_actions = []
-            for act_item in episode.act:
-                if hasattr(act_item, 'tolist'):  # MLX array
-                    episode_actions.extend(act_item.tolist())
-                elif isinstance(act_item, list):  # Already Python list
-                    episode_actions.extend(act_item)
-                else:  # Single item
-                    episode_actions.append(act_item)
-            all_acts.append(episode_actions)
-
-            # Extract logprobs as consistent Python lists
-            episode_logprobs = []
-            if episode.logprob:
-                for logprob_item in episode.logprob:
-                    if hasattr(logprob_item, 'tolist'):  # MLX array
-                        episode_logprobs.extend(logprob_item.tolist())
-                    elif isinstance(logprob_item, list):  # Already Python list
-                        episode_logprobs.extend(logprob_item)
-                    else:  # Single item
-                        episode_logprobs.append(logprob_item)
-            all_logprobs.append(episode_logprobs)
-        else:
-            # Serialized dictionary from multiprocessing
-            episode_reward = mx.sum(episode['rew']).item()
-            episode_rewards.append(episode_reward)
-            episode_lengths.append(len(episode['obs']))
-
-            # For proper logprob extraction during training, we need the full context (prompt + response)
-            # This matches how the model was called during rollout generation
-            # Convert both obs and act to consistent Python list format before concatenation
-            obs_as_lists = []
-            for obs_item in episode['obs']:
-                if hasattr(obs_item, 'tolist'):  # MLX array
-                    obs_as_lists.extend(obs_item.tolist())
-                elif isinstance(obs_item, list):  # Already Python list
-                    obs_as_lists.extend(obs_item)
-                else:  # Single item
-                    obs_as_lists.append(obs_item)
-
-            act_as_lists = []
-            for act_item in episode['act']:
-                if hasattr(act_item, 'tolist'):  # MLX array
-                    act_as_lists.extend(act_item.tolist())
-                elif isinstance(act_item, list):  # Already Python list
-                    act_as_lists.extend(act_item)
-                else:  # Single item
-                    act_as_lists.append(act_item)
-
-            # Now concatenate the normalized lists
-            full_sequence = obs_as_lists + act_as_lists
-            all_obs.append(full_sequence)
-
-            # Extract actions as consistent Python lists
-            episode_actions = []
-            for act_item in episode['act']:
-                if hasattr(act_item, 'tolist'):  # MLX array
-                    episode_actions.extend(act_item.tolist())
-                elif isinstance(act_item, list):  # Already Python list
-                    episode_actions.extend(act_item)
-                else:  # Single item
-                    episode_actions.append(act_item)
-            all_acts.append(episode_actions)
-
-            # Extract logprobs as consistent Python lists
-            episode_logprobs = []
-            if episode.get('logprob'):
-                for logprob_item in episode['logprob']:
-                    if hasattr(logprob_item, 'tolist'):  # MLX array
-                        episode_logprobs.extend(logprob_item.tolist())
-                    elif isinstance(logprob_item, list):  # Already Python list
-                        episode_logprobs.extend(logprob_item)
-                    else:  # Single item
-                        episode_logprobs.append(logprob_item)
-            all_logprobs.append(episode_logprobs)
-
-    # Convert Python lists to MLX arrays before concatenation
-    # This is required because Episode objects store data as Python lists for memory efficiency
-    # For proper logprob extraction, we need uniform-length sequences, so we pad to the maximum length
-
-    # Find maximum sequence length for padding
-    max_obs_len = max(len(obs) for obs in all_obs) if all_obs else 0
-    max_act_len = max(len(act) for act in all_acts) if all_acts else 0
-    max_logprob_len = max(len(logprob) for logprob in all_logprobs) if all_logprobs else 0
-
-    # MLX-native padding and array operations for optimal Apple Silicon performance
-    # Convert all sequences to MLX arrays and pad directly in MLX space
-    try:
-        # Convert all sequences to MLX arrays first (staying in unified memory)
-        all_obs_mx = [mx.array(obs, dtype=mx.int64) for obs in all_obs if obs]
-        all_acts_mx = [mx.array(act, dtype=mx.int64) for act in all_acts if act]
-        all_logprobs_mx = [mx.array(logprob, dtype=mx.float32) for logprob in all_logprobs if logprob]
-
-        # Pad using native MLX operations (more efficient for Apple Silicon)
-        if all_obs_mx:
-            padded_obs_mx = [mx.pad(obs, (0, max_obs_len - obs.shape[0]), constant_values=0)
-                             if obs.shape[0] < max_obs_len else obs[:max_obs_len]
-                             for obs in all_obs_mx]
-        else:
-            padded_obs_mx = []
-
-        if all_acts_mx:
-            padded_acts_mx = [mx.pad(act, (0, max_act_len - act.shape[0]), constant_values=0)
-                              if act.shape[0] < max_act_len else act[:max_act_len]
-                              for act in all_acts_mx]
-        else:
-            padded_acts_mx = []
-
-        if all_logprobs_mx:
-            padded_logprobs_mx = [mx.pad(logprob, (0, max_logprob_len - logprob.shape[0]), constant_values=0.0)
-                                  if logprob.shape[0] < max_logprob_len else logprob[:max_logprob_len]
-                                  for logprob in all_logprobs_mx]
-        else:
-            padded_logprobs_mx = []
-
-        # Use padded MLX arrays directly (no intermediate conversion needed)
-        all_obs_mx = padded_obs_mx
-        all_acts_mx = padded_acts_mx
-        all_logprobs_mx = padded_logprobs_mx
-
-    except Exception as e:
-        print(f"ERROR in MLX array conversion: {e}")
-        print(f"DEBUG: all_obs types: {[type(obs) for obs in all_obs[:3]]}")  # Show first 3 for brevity
-        print(f"DEBUG: all_logprobs types: {[type(logprob) for logprob in all_logprobs[:3]]}")
-        raise
-
-    batch_data = {
-        'obs': mx.concatenate(all_obs_mx) if all_obs_mx else mx.array([]),  # Flat concatenated full sequences
-        'act': mx.concatenate(all_acts_mx) if all_acts_mx else mx.array([]),  # Flat concatenated response tokens
-        'logprob': mx.concatenate([logprob.flatten() for logprob in all_logprobs_mx]) if all_logprobs_mx else mx.array([]),  # Flat sequence for training
-        'rewards': mx.array(episode_rewards),
-        'episode_lengths': episode_lengths
-    }
-
-    return batch_data
+
+    return _pack_episodes(recent_episodes)