textpolicy 0.0.1__py3-none-any.whl → 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68)
  1. textpolicy/__init__.py +53 -0
  2. textpolicy/__main__.py +8 -0
  3. textpolicy/algorithms/__init__.py +54 -0
  4. textpolicy/algorithms/grpo.py +642 -0
  5. textpolicy/algorithms/gspo.py +582 -0
  6. textpolicy/buffer/__init__.py +23 -0
  7. textpolicy/buffer/buffer.py +244 -0
  8. textpolicy/buffer/episode.py +383 -0
  9. textpolicy/buffer/sampling.py +438 -0
  10. textpolicy/buffer/storage.py +255 -0
  11. textpolicy/cli.py +67 -0
  12. textpolicy/environment/__init__.py +79 -0
  13. textpolicy/environment/base.py +110 -0
  14. textpolicy/environment/environment.py +46 -0
  15. textpolicy/environment/factory.py +103 -0
  16. textpolicy/environment/gym.py +106 -0
  17. textpolicy/environment/task_suites.py +51 -0
  18. textpolicy/environment/text_generation.py +797 -0
  19. textpolicy/environment/vectorized.py +253 -0
  20. textpolicy/generation/__init__.py +62 -0
  21. textpolicy/generation/lora.py +411 -0
  22. textpolicy/generation/mlx_generation.py +557 -0
  23. textpolicy/generation/reload.py +253 -0
  24. textpolicy/rewards/__init__.py +137 -0
  25. textpolicy/rewards/adapters.py +387 -0
  26. textpolicy/rewards/basic.py +214 -0
  27. textpolicy/rewards/integrated_system.py +338 -0
  28. textpolicy/rewards/mlx_batch_processor.py +447 -0
  29. textpolicy/rewards/registry.py +293 -0
  30. textpolicy/rewards/rollout_rewards.py +410 -0
  31. textpolicy/rewards/verifiers.py +369 -0
  32. textpolicy/rollout/__init__.py +44 -0
  33. textpolicy/rollout/aggregator.py +145 -0
  34. textpolicy/rollout/base.py +108 -0
  35. textpolicy/rollout/rollout.py +142 -0
  36. textpolicy/rollout/runner.py +280 -0
  37. textpolicy/rollout/strategy.py +208 -0
  38. textpolicy/rollout/worker.py +194 -0
  39. textpolicy/training/__init__.py +14 -0
  40. textpolicy/training/metrics.py +242 -0
  41. textpolicy/training/rollout_manager.py +78 -0
  42. textpolicy/training/trainer.py +684 -0
  43. textpolicy/utils/__init__.py +40 -0
  44. textpolicy/utils/benchmarking.py +489 -0
  45. textpolicy/utils/data.py +60 -0
  46. textpolicy/utils/debug.py +170 -0
  47. textpolicy/utils/environment.py +349 -0
  48. textpolicy/utils/logging/__init__.py +22 -0
  49. textpolicy/utils/logging/base.py +48 -0
  50. textpolicy/utils/logging/console.py +61 -0
  51. textpolicy/utils/logging/factory.py +133 -0
  52. textpolicy/utils/logging/multi.py +83 -0
  53. textpolicy/utils/logging/tensorboard.py +65 -0
  54. textpolicy/utils/logging/wandb.py +72 -0
  55. textpolicy/utils/memory.py +118 -0
  56. textpolicy/utils/performance.py +464 -0
  57. textpolicy/utils/timing.py +171 -0
  58. textpolicy/validate.py +101 -0
  59. textpolicy/validation/__init__.py +13 -0
  60. textpolicy/validation/logprob_validation.py +315 -0
  61. textpolicy-0.1.1.dist-info/METADATA +109 -0
  62. textpolicy-0.1.1.dist-info/RECORD +66 -0
  63. {textpolicy-0.0.1.dist-info → textpolicy-0.1.1.dist-info}/WHEEL +1 -1
  64. textpolicy-0.1.1.dist-info/entry_points.txt +2 -0
  65. textpolicy-0.0.1.dist-info/METADATA +0 -10
  66. textpolicy-0.0.1.dist-info/RECORD +0 -6
  67. {textpolicy-0.0.1.dist-info → textpolicy-0.1.1.dist-info}/licenses/LICENSE +0 -0
  68. {textpolicy-0.0.1.dist-info → textpolicy-0.1.1.dist-info}/top_level.txt +0 -0
textpolicy/algorithms/grpo.py (new file)
@@ -0,0 +1,642 @@
+ # textpolicy/algorithms/grpo.py
+ """
+ Group Relative Policy Optimization (GRPO) - Pure Functions for MLX.
+
+ GRPO eliminates value function training by using group-relative advantages:
+     A(τ) = R(τ) - mean(R(group))
+
+ These pure functions are designed for:
+ - MLX compilation with @mx.compile
+ - Apple Silicon unified memory
+ - Low abstraction cost
+ - Composability
+ """
+
+ from __future__ import annotations
+
+ from typing import List, Union
+
+ try:
+     import mlx.core as mx  # type: ignore
+ except ImportError:
+     mx = None  # MLX is optional; any function that touches mx arrays will fail without it
+
+ # Provide a no-op compile decorator when MLX is not available
+ if mx is None:
+     class _DummyMx:
+         def compile(self, fn):
+             return fn
+
+     mx = _DummyMx()
+
+
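
A minimal sketch of what the fallback above provides (not part of the file; assumes MLX is absent): the module still imports, @mx.compile degrades to a pass-through, and only code that actually touches arrays fails.

    def double(x):
        return x + x

    decorated = mx.compile(double)  # with _DummyMx in place, this just returns `double`
    assert decorated is double
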
+ def compute_advantages(rewards: Union[List[float], mx.array]) -> mx.array:
+     """
+     Compute group-relative advantages for GRPO.
+
+     Core GRPO innovation: use the group mean as the baseline instead of a learned
+     value function, which removes the separate value network while still providing
+     stable gradients.
+
+     Formula: A(τ) = R(τ) - mean(R(group))
+
+     Args:
+         rewards: Episode rewards, either Python list or MLX array
+
+     Returns:
+         Group-relative advantages as MLX array
+
+     Notes:
+         - Single vectorized operation (no Python loops)
+         - Minimal memory allocation
+         - Suitable for @mx.compile decoration
+         - Handles variable batch sizes
+     """
+     if isinstance(rewards, list):
+         if not rewards:
+             return mx.array([])
+         rewards_tensor = mx.array(rewards, dtype=mx.float32)
+     elif isinstance(rewards, mx.array):
+         rewards_tensor = rewards.astype(mx.float32)
+     else:
+         raise TypeError(f"Expected list or mx.array, got {type(rewards)}")
+
+     # Group-relative advantages: rewards relative to group mean
+     # Broadcasting handles the subtraction efficiently
+     group_mean = mx.mean(rewards_tensor)
+     advantages = rewards_tensor - group_mean
+
+     return advantages
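
A small worked example of the group-relative baseline (usage sketch, assuming MLX is installed):

    from textpolicy.algorithms.grpo import compute_advantages

    # Three completions of the same prompt scored 1.0, 0.0 and 2.0.
    # The group mean is 1.0, so the advantages are [0.0, -1.0, 1.0].
    adv = compute_advantages([1.0, 0.0, 2.0])
    print(adv)  # array([0, -1, 1], dtype=float32)
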
+
+
+ def compute_advantages_dr_grpo(rewards: Union[List[float], mx.array]) -> mx.array:
+     """
+     Compute advantages using Dr. GRPO (GRPO Done Right) - bias-corrected version.
+
+     Based on https://arxiv.org/html/2503.20783, this version fixes two key biases:
+     1. Response-level length bias: Removes 1/|o_i| normalization
+     2. Question-level difficulty bias: Removes std normalization
+
+     Dr. GRPO formula: A(τ) = R(τ) - mean(R(group))
+     (Same as basic GRPO but ensures no hidden normalizations)
+
+     Args:
+         rewards: Episode rewards, either Python list or MLX array
+
+     Returns:
+         Unbiased group-relative advantages as MLX array
+
+     Key improvements over standard GRPO:
+     - No response length normalization (prevents length bias)
+     - No standard deviation normalization (prevents difficulty bias)
+     - Recovers original unbiased policy gradient objective
+     """
+     if isinstance(rewards, list):
+         if not rewards:
+             return mx.array([])
+         rewards_tensor = mx.array(rewards, dtype=mx.float32)
+     elif isinstance(rewards, mx.array):
+         rewards_tensor = rewards.astype(mx.float32)
+     else:
+         raise TypeError(f"Expected list or mx.array, got {type(rewards)}")
+
+     # Dr. GRPO: Pure group-relative advantages without any normalization bias
+     # Key insight: Keep advantages raw to avoid length/difficulty biases
+     group_mean = mx.mean(rewards_tensor)
+     advantages = rewards_tensor - group_mean
+
+     # Do not apply extra normalizations that introduce bias:
+     # - NO division by response length |o_i| (creates length bias)
+     # - NO division by std(rewards) (creates difficulty bias)
+     # - Keep raw advantage signal for unbiased learning
+
+     return advantages
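
For contrast, the normalizations that Dr. GRPO deliberately omits would look roughly like this (illustrative sketch only; `rewards_tensor` reuses the name from the function body above, and `response_length` is a hypothetical per-episode token count):

    centered = rewards_tensor - mx.mean(rewards_tensor)
    biased_by_difficulty = centered / (mx.std(rewards_tensor) + 1e-8)  # std normalization (difficulty bias)
    biased_by_length = centered / response_length                      # 1/|o_i| normalization (length bias)
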
+
+
+ def policy_loss(
+     old_logprobs: mx.array,
+     new_logprobs: mx.array,
+     advantages: mx.array,
+     clip_ratio: float = 0.2
+ ) -> mx.array:
+     """
+     GRPO policy loss with PPO-style clipping.
+
+     Uses the clipped surrogate objective, but with group-relative advantages
+     instead of GAE advantages.
+
+     Args:
+         old_logprobs: Log probabilities from rollout collection
+         new_logprobs: Log probabilities from current policy evaluation
+         advantages: Group-relative advantages from compute_advantages()
+         clip_ratio: Clipping ratio for surrogate objective
+
+     Returns:
+         Policy loss scalar (to be minimized)
+
+     Notes:
+         - Fully vectorized (no Python loops over batch)
+         - Uses in-place operations where possible
+         - Suitable for MLX graph optimization
+         - Single forward pass through computation
+     """
+     # Importance ratio: π_new / π_old
+     # MLX optimizes exp() for Apple Silicon
+     ratio = mx.exp(new_logprobs - old_logprobs)
+
+     # PPO clipped surrogate objective
+     # L = min(ratio * A, clip(ratio, 1-ε, 1+ε) * A)
+     clipped_ratio = mx.clip(ratio, 1 - clip_ratio, 1 + clip_ratio)
+
+     # Element-wise minimum and mean reduction
+     # Negative because we minimize (the original objective is maximized)
+     surr1 = ratio * advantages
+     surr2 = clipped_ratio * advantages
+     loss = -mx.mean(mx.minimum(surr1, surr2))
+
+     return loss
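
A quick numeric sketch of the clipping behaviour (illustrative, assuming MLX is installed): with clip_ratio=0.2, a ratio of about 1.5 on a positive-advantage sample is clipped to 1.2, while a ratio of about 0.8 passes through unchanged.

    old_lp = mx.array([-1.0, -1.0])
    new_lp = mx.array([-1.0 + 0.405, -1.0 - 0.223])  # ratios ≈ exp(0.405) ≈ 1.5 and exp(-0.223) ≈ 0.8
    adv = mx.array([1.0, 1.0])

    # per-element surrogate: min(1.5 * 1, 1.2 * 1) = 1.2 and min(0.8 * 1, 0.8 * 1) = 0.8
    # loss = -mean([1.2, 0.8]) ≈ -1.0
    print(policy_loss(old_lp, new_lp, adv, clip_ratio=0.2))
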
+
+
+ # Optional: Compiled versions for maximum performance
+ @mx.compile
+ def compute_advantages_compiled(rewards: mx.array) -> mx.array:
+     """Compiled version of compute_advantages for maximum performance."""
+     group_mean = mx.mean(rewards)
+     return rewards - group_mean
+
+
+ @mx.compile
+ def policy_loss_compiled(
+     old_logprobs: mx.array,
+     new_logprobs: mx.array,
+     advantages: mx.array,
+     clip_ratio: float = 0.2
+ ) -> mx.array:
+     """Compiled version of policy_loss for maximum performance."""
+     ratio = mx.exp(new_logprobs - old_logprobs)
+     clipped_ratio = mx.clip(ratio, 1 - clip_ratio, 1 + clip_ratio)
+     surr1 = ratio * advantages
+     surr2 = clipped_ratio * advantages
+     return -mx.mean(mx.minimum(surr1, surr2))
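
Usage is the same as for the uncompiled functions (sketch): the first call traces and compiles the graph, and later calls with matching shapes reuse it.

    rewards = mx.array([1.0, 0.0, 2.0])
    adv = compute_advantages_compiled(rewards)      # first call triggers compilation
    adv = compute_advantages_compiled(rewards * 2)  # later calls reuse the compiled graph
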
+
+
+ def entropy_bonus(logprobs: mx.array, coefficient: float = 0.01) -> mx.array:
+     """
+     Entropy bonus for exploration (optional GRPO component).
+
+     Args:
+         logprobs: Log probabilities from policy
+         coefficient: Entropy coefficient (typically small, like 0.01)
+
+     Returns:
+         Entropy bonus term to add to the loss (higher entropy lowers the loss)
+     """
+     if coefficient <= 0:
+         return mx.array(0.0)
+
+     # Entropy = -sum(p * log(p))
+     # For log probabilities: entropy = -sum(exp(logp) * logp)
+     probs = mx.exp(logprobs)
+     entropy = -mx.sum(probs * logprobs, axis=-1)
+
+     # Return -coefficient * mean(entropy): adding this term to the loss rewards
+     # higher-entropy policies (we minimize the loss but want to maximize entropy)
+     return -coefficient * mx.mean(entropy)
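
A worked example (sketch): for a uniform two-token distribution the entropy is ln 2 ≈ 0.693, so with coefficient 0.01 the returned term is roughly -0.0069.

    import math

    logprobs = mx.array([math.log(0.5), math.log(0.5)])
    print(entropy_bonus(logprobs, coefficient=0.01))  # ≈ -0.00693
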
+
+
+ # Convenience function for complete GRPO computation
+ def grpo_loss(
+     old_logprobs: mx.array,
+     new_logprobs: mx.array,
+     rewards: Union[List[float], mx.array],
+     clip_ratio: float = 0.2,
+     entropy_coeff: float = 0.0
+ ) -> mx.array:
+     """
+     Complete GRPO loss computation in one function.
+
+     Combines advantage calculation and policy loss for convenience.
+     Can be compiled as a single unit for maximum efficiency.
+
+     Args:
+         old_logprobs: Log probabilities from rollout
+         new_logprobs: Log probabilities from current policy
+         rewards: Episode rewards for group-relative advantages
+         clip_ratio: PPO clipping ratio
+         entropy_coeff: Entropy bonus coefficient (0 disables)
+
+     Returns:
+         Total GRPO loss (policy + optional entropy)
+     """
+     # Compute group-relative advantages
+     advantages = compute_advantages(rewards)
+
+     # Expand advantages to match logprob sequence length if needed
+     if advantages.ndim == 1 and old_logprobs.ndim > 1:
+         # Each episode contributes its advantage to all tokens in that episode.
+         # This requires knowing episode boundaries - this simplified version assumes
+         # advantages and logprobs are already aligned.
+         pass
+
+     # Compute policy loss
+     policy_loss_val = policy_loss(old_logprobs, new_logprobs, advantages, clip_ratio)
+
+     # Add entropy bonus if specified
+     if entropy_coeff > 0:
+         entropy_bonus_val = entropy_bonus(new_logprobs, entropy_coeff)
+         return policy_loss_val + entropy_bonus_val
+
+     return policy_loss_val
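
Putting the pieces together (usage sketch; the arrays stand in for per-sample log-probabilities that would normally come from the rollout buffer and from re-evaluating the current policy):

    old_lp = mx.array([-1.2, -0.7, -2.0])
    new_lp = mx.array([-1.1, -0.8, -1.9])
    loss = grpo_loss(old_lp, new_lp, rewards=[1.0, 0.0, 2.0],
                     clip_ratio=0.2, entropy_coeff=0.0)
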
+
+
+ # Performance monitoring utilities
+ def compute_metrics(
+     old_logprobs: mx.array,
+     new_logprobs: mx.array,
+     advantages: mx.array,
+     clip_ratio: float = 0.2
+ ) -> dict:
+     """
+     Compute GRPO training metrics for monitoring.
+
+     Args:
+         old_logprobs: Log probabilities from rollout
+         new_logprobs: Log probabilities from current policy
+         advantages: Group-relative advantages
+         clip_ratio: Clipping ratio used in loss
+
+     Returns:
+         Dictionary of metrics for logging/monitoring
+     """
+     # Importance ratio statistics
+     ratio = mx.exp(new_logprobs - old_logprobs)
+
+     # Clipping statistics
+     clip_lower = 1 - clip_ratio
+     clip_upper = 1 + clip_ratio
+     clipped = (ratio < clip_lower) | (ratio > clip_upper)
+     clip_fraction = mx.mean(clipped.astype(mx.float32))
+
+     # KL divergence approximation
+     kl_div = mx.mean(old_logprobs - new_logprobs)
+
+     return {
+         'mean_advantage': mx.mean(advantages).item(),
+         'std_advantage': mx.std(advantages).item(),
+         'mean_ratio': mx.mean(ratio).item(),
+         'clip_fraction': clip_fraction.item(),
+         'kl_divergence': kl_div.item(),
+         'min_advantage': mx.min(advantages).item(),
+         'max_advantage': mx.max(advantages).item()
+     }
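
Note that 'kl_divergence' here is the simple mean(log π_old - log π_new) estimator rather than an exact KL. A sketch of how the metrics might be consumed, reusing old_lp and new_lp from the sketch above (the thresholds are illustrative only):

    adv = compute_advantages([1.0, 0.0, 2.0])
    metrics = compute_metrics(old_lp, new_lp, adv, clip_ratio=0.2)
    if metrics['clip_fraction'] > 0.3 or metrics['kl_divergence'] > 0.05:
        print("policy is drifting far from the rollout policy; consider smaller or fewer updates")
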
+
+
+ # Algorithm-specific data selection strategies
+ def select_all_data(buffer):
+     """
+     GRPO data selector: Use all available data.
+
+     GRPO is on-policy but can benefit from using all collected episodes
+     since group-relative advantages normalize across the entire group.
+
+     Args:
+         buffer: Buffer containing episodes
+
+     Returns:
+         All episode data prepared for training
+     """
+     from textpolicy.buffer import Buffer
+     if not isinstance(buffer, Buffer):
+         raise TypeError(f"Expected Buffer, got {type(buffer)}")
+
+     # Use all available data - GRPO benefits from larger groups
+     episodes_data = buffer.sample()  # This returns concatenated transitions
+
+     # We need to convert this back to episode structure for reward extraction
+     episodes = buffer.episodes  # Access episodes directly from storage
+
+     if not episodes:
+         raise ValueError("Buffer is empty - no episodes to train on")
+
+     # Extract episode rewards for advantage computation
+     episode_rewards = []
+     episode_lengths = []
+
+     # Collect all transitions
+     all_obs = []
+     all_acts = []
+     all_logprobs = []
+
+     for episode in episodes:
+         # Episode reward (sum of all rewards in episode)
+         # Handle both Episode objects and serialized dictionaries
+         if hasattr(episode, 'rew'):
+             # Episode object with attributes
+             episode_reward = mx.sum(mx.array(episode.rew)).item()
+             episode_rewards.append(episode_reward)
+             episode_lengths.append(len(episode.obs))
+
+             # Collect transitions
+             # For proper logprob extraction during training, we need the full context (prompt + response)
+             # This matches how the model was called during rollout generation
+             # Flatten nested token sequences to create uniform token arrays
+
+             # Extract and flatten observation tokens (prompt)
+             flattened_obs = []
+             for obs in episode.obs:
+                 if hasattr(obs, 'tolist'):  # MLX array
+                     flattened_obs.extend(obs.tolist())
+                 elif isinstance(obs, list):  # Python list
+                     flattened_obs.extend(obs)
+                 else:  # Single token
+                     flattened_obs.append(obs)
+
+             # Extract and flatten action tokens (response)
+             flattened_acts = []
+             for act in episode.act:
+                 if hasattr(act, 'tolist'):  # MLX array
+                     flattened_acts.extend(act.tolist())
+                 elif isinstance(act, list):  # Python list
+                     flattened_acts.extend(act)
+                 else:  # Single token
+                     flattened_acts.append(act)
+
+             # Create full sequence: [prompt_tokens..., response_tokens...]
+             full_sequence = flattened_obs + flattened_acts
+             all_obs.append(full_sequence)
+             all_acts.append(flattened_acts)
+             all_logprobs.append(episode.logprob if episode.logprob else [])
+         else:
+             # Serialized dictionary from multiprocessing
+             episode_reward = mx.sum(episode['rew']).item()
+             episode_rewards.append(episode_reward)
+             episode_lengths.append(len(episode['obs']))
+
+             # Collect transitions
+             # For proper logprob extraction during training, we need the full context (prompt + response)
+             # This matches how the model was called during rollout generation
+             full_sequence = episode['obs'] + episode['act']  # Concatenate prompt + response
+             all_obs.append(full_sequence)
+             all_acts.append(episode['act'])
+             all_logprobs.append(episode.get('logprob', []))
+
+     # Convert Python lists to MLX arrays before concatenation
+     # This is required because Episode objects store data as Python lists for memory efficiency
+     # For proper logprob extraction, we need uniform-length sequences, so we pad to the maximum length
+
+     # Find maximum sequence length for padding
+     max_obs_len = max(len(obs) for obs in all_obs) if all_obs else 0
+     max_act_len = max(len(act) for act in all_acts) if all_acts else 0
+     max_logprob_len = max(len(logprob) for logprob in all_logprobs) if all_logprobs else 0
+
+     # MLX-native padding and array operations for optimal Apple Silicon performance
+     # Convert all sequences to MLX arrays and pad directly in MLX space
+     try:
+         # Convert all sequences to MLX arrays first (staying in unified memory)
+         all_obs_mx = [mx.array(obs, dtype=mx.int64) for obs in all_obs if obs]
+         all_acts_mx = [mx.array(act, dtype=mx.int64) for act in all_acts if act]
+         all_logprobs_mx = [mx.array(logprob, dtype=mx.float32) for logprob in all_logprobs if logprob]
+
+         # Pad using native MLX operations (more efficient for Apple Silicon)
+         if all_obs_mx:
+             padded_obs_mx = [mx.pad(obs, (0, max_obs_len - obs.shape[0]), constant_values=0)
+                              if obs.shape[0] < max_obs_len else obs[:max_obs_len]
+                              for obs in all_obs_mx]
+         else:
+             padded_obs_mx = []
+
+         if all_acts_mx:
+             padded_acts_mx = [mx.pad(act, (0, max_act_len - act.shape[0]), constant_values=0)
+                               if act.shape[0] < max_act_len else act[:max_act_len]
+                               for act in all_acts_mx]
+         else:
+             padded_acts_mx = []
+
+         if all_logprobs_mx:
+             padded_logprobs_mx = [mx.pad(logprob, (0, max_logprob_len - logprob.shape[0]), constant_values=0.0)
+                                   if logprob.shape[0] < max_logprob_len else logprob[:max_logprob_len]
+                                   for logprob in all_logprobs_mx]
+         else:
+             padded_logprobs_mx = []
+
+         # Use padded MLX arrays directly (no intermediate conversion needed)
+         all_obs_mx = padded_obs_mx
+         all_acts_mx = padded_acts_mx
+         all_logprobs_mx = padded_logprobs_mx
+
+     except Exception as e:
+         print(f"ERROR in MLX array conversion: {e}")
+         print(f"DEBUG: all_obs types: {[type(obs) for obs in all_obs[:3]]}")  # Show first 3 for brevity
+         print(f"DEBUG: all_logprobs types: {[type(logprob) for logprob in all_logprobs[:3]]}")
+         raise
+
+     # GRPO data structure: both observations and actions as flat concatenated sequences
+     # This matches the expected format for the GRPO logprob extraction function
+     batch_data = {
+         'obs': mx.concatenate(all_obs_mx) if all_obs_mx else mx.array([]),  # Flat concatenated full sequences
+         'act': mx.concatenate(all_acts_mx) if all_acts_mx else mx.array([]),  # Flat concatenated response tokens
+         'logprob': mx.concatenate([logprob.flatten() for logprob in all_logprobs_mx]) if all_logprobs_mx else mx.array([]),  # Flat sequence for training
+         'rewards': mx.array(episode_rewards),
+         'episode_lengths': episode_lengths
+     }
+
+     return batch_data
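
How a training step might consume the selector's output (sketch; `buffer` is assumed to be a populated textpolicy.buffer.Buffer, and the call that re-evaluates the current policy is hypothetical and shown only as a comment):

    batch = select_all_data(buffer)
    advantages = compute_advantages(batch['rewards'])   # one advantage per episode
    old_logprobs = batch['logprob']                     # flat, padded per-token logprobs from rollout
    # new_logprobs = <current policy forward pass over batch['obs']>  (not part of this module)
    # loss = policy_loss(old_logprobs, new_logprobs, advantages)
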
+
+
+ def select_recent_data(buffer, max_episodes: int = 100):
+     """
+     GRPO data selector: Use only recent episodes.
+
+     Alternative selector for GRPO that limits to recent episodes
+     for faster training on large buffers.
+
+     Args:
+         buffer: Buffer containing episodes (Episode objects or serialized dictionaries)
+         max_episodes: Maximum number of recent episodes to use
+
+     Returns:
+         Recent episode data prepared for training
+     """
+     from textpolicy.buffer import Buffer
+     if not isinstance(buffer, Buffer):
+         raise TypeError(f"Expected Buffer, got {type(buffer)}")
+
+     episodes = buffer.episodes
+     if not episodes:
+         raise ValueError("Buffer is empty - no episodes to train on")
+
+     # Select recent episodes
+     recent_episodes = episodes[-max_episodes:] if len(episodes) > max_episodes else episodes
+
+     # Process recent episodes
+     episode_rewards = []
+     episode_lengths = []
+     all_obs = []
+     all_acts = []
+     all_logprobs = []
+
+     for episode in recent_episodes:
+         # Handle both Episode objects and serialized dictionaries
+         if hasattr(episode, 'rew'):
+             # Episode object with attributes
+             episode_reward = mx.sum(mx.array(episode.rew)).item()
+             episode_rewards.append(episode_reward)
+             episode_lengths.append(len(episode.obs))
+
+             # For proper logprob extraction during training, we need the full context (prompt + response)
+             # This matches how the model was called during rollout generation
+             # Convert both obs and act to consistent Python list format before concatenation
+             obs_as_lists = []
+             for obs_item in episode.obs:
+                 if hasattr(obs_item, 'tolist'):  # MLX array
+                     obs_as_lists.extend(obs_item.tolist())
+                 elif isinstance(obs_item, list):  # Already Python list
+                     obs_as_lists.extend(obs_item)
+                 else:  # Single item
+                     obs_as_lists.append(obs_item)
+
+             act_as_lists = []
+             for act_item in episode.act:
+                 if hasattr(act_item, 'tolist'):  # MLX array
+                     act_as_lists.extend(act_item.tolist())
+                 elif isinstance(act_item, list):  # Already Python list
+                     act_as_lists.extend(act_item)
+                 else:  # Single item
+                     act_as_lists.append(act_item)
+
+             # Now concatenate the normalized lists
+             full_sequence = obs_as_lists + act_as_lists
+             all_obs.append(full_sequence)
+
+             # Extract actions as consistent Python lists
+             episode_actions = []
+             for act_item in episode.act:
+                 if hasattr(act_item, 'tolist'):  # MLX array
+                     episode_actions.extend(act_item.tolist())
+                 elif isinstance(act_item, list):  # Already Python list
+                     episode_actions.extend(act_item)
+                 else:  # Single item
+                     episode_actions.append(act_item)
+             all_acts.append(episode_actions)
+
+             # Extract logprobs as consistent Python lists
+             episode_logprobs = []
+             if episode.logprob:
+                 for logprob_item in episode.logprob:
+                     if hasattr(logprob_item, 'tolist'):  # MLX array
+                         episode_logprobs.extend(logprob_item.tolist())
+                     elif isinstance(logprob_item, list):  # Already Python list
+                         episode_logprobs.extend(logprob_item)
+                     else:  # Single item
+                         episode_logprobs.append(logprob_item)
+             all_logprobs.append(episode_logprobs)
+         else:
+             # Serialized dictionary from multiprocessing
+             episode_reward = mx.sum(episode['rew']).item()
+             episode_rewards.append(episode_reward)
+             episode_lengths.append(len(episode['obs']))
+
+             # For proper logprob extraction during training, we need the full context (prompt + response)
+             # This matches how the model was called during rollout generation
+             # Convert both obs and act to consistent Python list format before concatenation
+             obs_as_lists = []
+             for obs_item in episode['obs']:
+                 if hasattr(obs_item, 'tolist'):  # MLX array
+                     obs_as_lists.extend(obs_item.tolist())
+                 elif isinstance(obs_item, list):  # Already Python list
+                     obs_as_lists.extend(obs_item)
+                 else:  # Single item
+                     obs_as_lists.append(obs_item)
+
+             act_as_lists = []
+             for act_item in episode['act']:
+                 if hasattr(act_item, 'tolist'):  # MLX array
+                     act_as_lists.extend(act_item.tolist())
+                 elif isinstance(act_item, list):  # Already Python list
+                     act_as_lists.extend(act_item)
+                 else:  # Single item
+                     act_as_lists.append(act_item)
+
+             # Now concatenate the normalized lists
+             full_sequence = obs_as_lists + act_as_lists
+             all_obs.append(full_sequence)
+
+             # Extract actions as consistent Python lists
+             episode_actions = []
+             for act_item in episode['act']:
+                 if hasattr(act_item, 'tolist'):  # MLX array
+                     episode_actions.extend(act_item.tolist())
+                 elif isinstance(act_item, list):  # Already Python list
+                     episode_actions.extend(act_item)
+                 else:  # Single item
+                     episode_actions.append(act_item)
+             all_acts.append(episode_actions)
+
+             # Extract logprobs as consistent Python lists
+             episode_logprobs = []
+             if episode.get('logprob'):
+                 for logprob_item in episode['logprob']:
+                     if hasattr(logprob_item, 'tolist'):  # MLX array
+                         episode_logprobs.extend(logprob_item.tolist())
+                     elif isinstance(logprob_item, list):  # Already Python list
+                         episode_logprobs.extend(logprob_item)
+                     else:  # Single item
+                         episode_logprobs.append(logprob_item)
+             all_logprobs.append(episode_logprobs)
+
+     # Convert Python lists to MLX arrays before concatenation
+     # This is required because Episode objects store data as Python lists for memory efficiency
+     # For proper logprob extraction, we need uniform-length sequences, so we pad to the maximum length
+
+     # Find maximum sequence length for padding
+     max_obs_len = max(len(obs) for obs in all_obs) if all_obs else 0
+     max_act_len = max(len(act) for act in all_acts) if all_acts else 0
+     max_logprob_len = max(len(logprob) for logprob in all_logprobs) if all_logprobs else 0
+
+     # MLX-native padding and array operations for optimal Apple Silicon performance
+     # Convert all sequences to MLX arrays and pad directly in MLX space
+     try:
+         # Convert all sequences to MLX arrays first (staying in unified memory)
+         all_obs_mx = [mx.array(obs, dtype=mx.int64) for obs in all_obs if obs]
+         all_acts_mx = [mx.array(act, dtype=mx.int64) for act in all_acts if act]
+         all_logprobs_mx = [mx.array(logprob, dtype=mx.float32) for logprob in all_logprobs if logprob]
+
+         # Pad using native MLX operations (more efficient for Apple Silicon)
+         if all_obs_mx:
+             padded_obs_mx = [mx.pad(obs, (0, max_obs_len - obs.shape[0]), constant_values=0)
+                              if obs.shape[0] < max_obs_len else obs[:max_obs_len]
+                              for obs in all_obs_mx]
+         else:
+             padded_obs_mx = []
+
+         if all_acts_mx:
+             padded_acts_mx = [mx.pad(act, (0, max_act_len - act.shape[0]), constant_values=0)
+                               if act.shape[0] < max_act_len else act[:max_act_len]
+                               for act in all_acts_mx]
+         else:
+             padded_acts_mx = []
+
+         if all_logprobs_mx:
+             padded_logprobs_mx = [mx.pad(logprob, (0, max_logprob_len - logprob.shape[0]), constant_values=0.0)
+                                   if logprob.shape[0] < max_logprob_len else logprob[:max_logprob_len]
+                                   for logprob in all_logprobs_mx]
+         else:
+             padded_logprobs_mx = []
+
+         # Use padded MLX arrays directly (no intermediate conversion needed)
+         all_obs_mx = padded_obs_mx
+         all_acts_mx = padded_acts_mx
+         all_logprobs_mx = padded_logprobs_mx
+
+     except Exception as e:
+         print(f"ERROR in MLX array conversion: {e}")
+         print(f"DEBUG: all_obs types: {[type(obs) for obs in all_obs[:3]]}")  # Show first 3 for brevity
+         print(f"DEBUG: all_logprobs types: {[type(logprob) for logprob in all_logprobs[:3]]}")
+         raise
+
+     batch_data = {
+         'obs': mx.concatenate(all_obs_mx) if all_obs_mx else mx.array([]),  # Flat concatenated full sequences
+         'act': mx.concatenate(all_acts_mx) if all_acts_mx else mx.array([]),  # Flat concatenated response tokens
+         'logprob': mx.concatenate([logprob.flatten() for logprob in all_logprobs_mx]) if all_logprobs_mx else mx.array([]),  # Flat sequence for training
+         'rewards': mx.array(episode_rewards),
+         'episode_lengths': episode_lengths
+     }
+
+     return batch_data
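
The recent-episode selector is a drop-in alternative when the buffer grows large (sketch, same assumptions as above):

    batch = select_recent_data(buffer, max_episodes=50)   # only the 50 newest episodes
    print(batch['rewards'].shape, len(batch['episode_lengths']))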