textpolicy 0.0.1__py3-none-any.whl → 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- textpolicy/__init__.py +53 -0
- textpolicy/__main__.py +8 -0
- textpolicy/algorithms/__init__.py +54 -0
- textpolicy/algorithms/grpo.py +642 -0
- textpolicy/algorithms/gspo.py +582 -0
- textpolicy/buffer/__init__.py +23 -0
- textpolicy/buffer/buffer.py +244 -0
- textpolicy/buffer/episode.py +383 -0
- textpolicy/buffer/sampling.py +438 -0
- textpolicy/buffer/storage.py +255 -0
- textpolicy/cli.py +67 -0
- textpolicy/environment/__init__.py +79 -0
- textpolicy/environment/base.py +110 -0
- textpolicy/environment/environment.py +46 -0
- textpolicy/environment/factory.py +103 -0
- textpolicy/environment/gym.py +106 -0
- textpolicy/environment/task_suites.py +51 -0
- textpolicy/environment/text_generation.py +797 -0
- textpolicy/environment/vectorized.py +253 -0
- textpolicy/generation/__init__.py +62 -0
- textpolicy/generation/lora.py +411 -0
- textpolicy/generation/mlx_generation.py +557 -0
- textpolicy/generation/reload.py +253 -0
- textpolicy/rewards/__init__.py +137 -0
- textpolicy/rewards/adapters.py +387 -0
- textpolicy/rewards/basic.py +214 -0
- textpolicy/rewards/integrated_system.py +338 -0
- textpolicy/rewards/mlx_batch_processor.py +447 -0
- textpolicy/rewards/registry.py +293 -0
- textpolicy/rewards/rollout_rewards.py +410 -0
- textpolicy/rewards/verifiers.py +369 -0
- textpolicy/rollout/__init__.py +44 -0
- textpolicy/rollout/aggregator.py +145 -0
- textpolicy/rollout/base.py +108 -0
- textpolicy/rollout/rollout.py +142 -0
- textpolicy/rollout/runner.py +280 -0
- textpolicy/rollout/strategy.py +208 -0
- textpolicy/rollout/worker.py +194 -0
- textpolicy/training/__init__.py +14 -0
- textpolicy/training/metrics.py +242 -0
- textpolicy/training/rollout_manager.py +78 -0
- textpolicy/training/trainer.py +684 -0
- textpolicy/utils/__init__.py +40 -0
- textpolicy/utils/benchmarking.py +489 -0
- textpolicy/utils/data.py +60 -0
- textpolicy/utils/debug.py +170 -0
- textpolicy/utils/environment.py +349 -0
- textpolicy/utils/logging/__init__.py +22 -0
- textpolicy/utils/logging/base.py +48 -0
- textpolicy/utils/logging/console.py +61 -0
- textpolicy/utils/logging/factory.py +133 -0
- textpolicy/utils/logging/multi.py +83 -0
- textpolicy/utils/logging/tensorboard.py +65 -0
- textpolicy/utils/logging/wandb.py +72 -0
- textpolicy/utils/memory.py +118 -0
- textpolicy/utils/performance.py +464 -0
- textpolicy/utils/timing.py +171 -0
- textpolicy/validate.py +101 -0
- textpolicy/validation/__init__.py +13 -0
- textpolicy/validation/logprob_validation.py +315 -0
- textpolicy-0.1.1.dist-info/METADATA +109 -0
- textpolicy-0.1.1.dist-info/RECORD +66 -0
- {textpolicy-0.0.1.dist-info → textpolicy-0.1.1.dist-info}/WHEEL +1 -1
- textpolicy-0.1.1.dist-info/entry_points.txt +2 -0
- textpolicy-0.0.1.dist-info/METADATA +0 -10
- textpolicy-0.0.1.dist-info/RECORD +0 -6
- {textpolicy-0.0.1.dist-info → textpolicy-0.1.1.dist-info}/licenses/LICENSE +0 -0
- {textpolicy-0.0.1.dist-info → textpolicy-0.1.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,642 @@
# textpolicy/algorithms/grpo.py
"""
Group Relative Policy Optimization (GRPO) - Pure Functions for MLX.

GRPO eliminates value function training by using group-relative advantages:
    A(τ) = R(τ) - mean(R(group))

These pure functions are designed for:
- MLX compilation with @mx.compile
- Apple Silicon unified memory
- Low abstraction cost
- Composability
"""

from __future__ import annotations

try:
    import mlx.core as mx  # type: ignore
except ImportError:
    mx = None  # MLX is optional; compilation-decorated functions will error if MLX is missing

# Provide a no-op compile decorator when MLX is not available
if mx is None:
    class _DummyMx:
        def compile(self, fn):
            return fn

    mx = _DummyMx()
from typing import List, Union


def compute_advantages(rewards: Union[List[float], mx.array]) -> mx.array:
    """
    Compute group-relative advantages for GRPO.

    Core GRPO innovation: Use group mean as baseline instead of value function.
    This eliminates 50% of neural network training while providing stable gradients.

    Formula: A(τ) = R(τ) - mean(R(group))

    Args:
        rewards: Episode rewards, either Python list or MLX array

    Returns:
        Group-relative advantages as MLX array

    Notes:
        - Single vectorized operation (no Python loops)
        - Minimal memory allocation
        - Suitable for @mx.compile decoration
        - Handles variable batch sizes
    """
    if isinstance(rewards, list):
        if not rewards:
            return mx.array([])
        rewards_tensor = mx.array(rewards, dtype=mx.float32)
    elif isinstance(rewards, mx.array):
        rewards_tensor = rewards.astype(mx.float32)
    else:
        raise TypeError(f"Expected list or mx.array, got {type(rewards)}")

    # Group-relative advantages: rewards relative to group mean
    # Broadcasting handles the subtraction efficiently
    group_mean = mx.mean(rewards_tensor)
    advantages = rewards_tensor - group_mean

    return advantages
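
A minimal numeric sketch of the group-relative baseline (an editor's illustration, not part of the packaged file; it only assumes the wheel and mlx are importable): the group mean is subtracted from each episode reward, so above-average completions get positive advantages and below-average ones negative.

from textpolicy.algorithms.grpo import compute_advantages

# Three completions sampled for the same prompt, one scalar reward each.
rewards = [1.0, 0.5, 0.0]                 # group mean = 0.5
advantages = compute_advantages(rewards)
print(advantages)                         # expected: [0.5, 0.0, -0.5]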


def compute_advantages_dr_grpo(rewards: Union[List[float], mx.array]) -> mx.array:
    """
    Compute advantages using Dr. GRPO (GRPO Done Right) - bias-corrected version.

    Based on https://arxiv.org/html/2503.20783, this version fixes two key biases:
    1. Response-level length bias: Removes 1/|o_i| normalization
    2. Question-level difficulty bias: Removes std normalization

    Dr. GRPO formula: A(τ) = R(τ) - mean(R(group))
    (Same as basic GRPO but ensures no hidden normalizations)

    Args:
        rewards: Episode rewards, either Python list or MLX array

    Returns:
        Unbiased group-relative advantages as MLX array

    Key improvements over standard GRPO:
    - No response length normalization (prevents length bias)
    - No standard deviation normalization (prevents difficulty bias)
    - Recovers original unbiased policy gradient objective
    """
    if isinstance(rewards, list):
        if not rewards:
            return mx.array([])
        rewards_tensor = mx.array(rewards, dtype=mx.float32)
    elif isinstance(rewards, mx.array):
        rewards_tensor = rewards.astype(mx.float32)
    else:
        raise TypeError(f"Expected list or mx.array, got {type(rewards)}")

    # Dr. GRPO: Pure group-relative advantages without any normalization bias
    # Key insight: Keep advantages raw to avoid length/difficulty biases
    group_mean = mx.mean(rewards_tensor)
    advantages = rewards_tensor - group_mean

    # Do not apply extra normalizations that introduce bias:
    # - NO division by response length |o_i| (creates length bias)
    # - NO division by std(rewards) (creates difficulty bias)
    # - Keep raw advantage signal for unbiased learning

    return advantages
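
For contrast, a short editorial sketch (the std-normalized line is written out only to show what Dr. GRPO deliberately omits; it is not part of this module, and only mlx is assumed):

import mlx.core as mx
from textpolicy.algorithms.grpo import compute_advantages_dr_grpo

rewards = mx.array([2.0, 0.0, 1.0])
adv = compute_advantages_dr_grpo(rewards)        # [1.0, -1.0, 0.0]: mean-centered, left raw

# The question-level normalization Dr. GRPO removes would look like this;
# it inflates advantages on prompts where rewards barely vary.
adv_std_normalized = adv / (mx.std(adv) + 1e-8)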


def policy_loss(
    old_logprobs: mx.array,
    new_logprobs: mx.array,
    advantages: mx.array,
    clip_ratio: float = 0.2
) -> mx.array:
    """
    GRPO policy loss with PPO-style clipping.

    Uses clipped surrogate objective but with group-relative advantages
    instead of GAE advantages.

    Args:
        old_logprobs: Log probabilities from rollout collection
        new_logprobs: Log probabilities from current policy evaluation
        advantages: Group-relative advantages from compute_advantages()
        clip_ratio: Clipping ratio for surrogate objective

    Returns:
        Policy loss scalar (to be minimized)

    Notes:
        - Fully vectorized (no Python loops over batch)
        - Uses in-place operations where possible
        - Suitable for MLX graph optimization
        - Single forward pass through computation
    """
    # Importance ratio: π_new / π_old
    # MLX optimizes exp() for Apple Silicon
    ratio = mx.exp(new_logprobs - old_logprobs)

    # PPO clipped surrogate objective
    # L = min(ratio * A, clip(ratio, 1-ε, 1+ε) * A)
    clipped_ratio = mx.clip(ratio, 1 - clip_ratio, 1 + clip_ratio)

    # Element-wise minimum and mean reduction
    # Negative because we minimize (original maximizes)
    surr1 = ratio * advantages
    surr2 = clipped_ratio * advantages
    loss = -mx.mean(mx.minimum(surr1, surr2))

    return loss
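
A toy call (editor's sketch; the numbers are made up, and only mlx plus the functions above are assumed) showing the pieces fitting together: per-episode log-probs, group-relative advantages, then the clipped surrogate.

import mlx.core as mx
from textpolicy.algorithms.grpo import compute_advantages, policy_loss

old_lp = mx.array([-1.0, -1.2, -0.8])     # recorded at rollout time
new_lp = mx.array([-0.9, -1.5, -0.8])     # under the current policy
adv = compute_advantages([1.0, 0.0, 0.5])

loss = policy_loss(old_lp, new_lp, adv, clip_ratio=0.2)
print(loss.item())                        # scalar to minimize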


# Optional: Compiled versions for maximum performance
@mx.compile
def compute_advantages_compiled(rewards: mx.array) -> mx.array:
    """Compiled version of compute_advantages for maximum performance."""
    group_mean = mx.mean(rewards)
    return rewards - group_mean


@mx.compile
def policy_loss_compiled(
    old_logprobs: mx.array,
    new_logprobs: mx.array,
    advantages: mx.array,
    clip_ratio: float = 0.2
) -> mx.array:
    """Compiled version of policy_loss for maximum performance."""
    ratio = mx.exp(new_logprobs - old_logprobs)
    clipped_ratio = mx.clip(ratio, 1 - clip_ratio, 1 + clip_ratio)
    surr1 = ratio * advantages
    surr2 = clipped_ratio * advantages
    return -mx.mean(mx.minimum(surr1, surr2))
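
Usage is identical to the uncompiled functions. A brief editorial sketch (only mlx is assumed); note that MLX compiled functions are re-traced when input shapes or dtypes change, so the compiled variants pay off most with fixed batch shapes, and since MLX is lazy, mx.eval forces computation before timing or logging.

import mlx.core as mx
from textpolicy.algorithms.grpo import (
    compute_advantages_compiled,
    policy_loss_compiled,
)

adv = compute_advantages_compiled(mx.array([1.0, 0.0, 0.5]))
loss = policy_loss_compiled(
    mx.array([-1.0, -1.2, -0.8]),   # old log-probs
    mx.array([-0.9, -1.5, -0.8]),   # new log-probs
    adv,
)
mx.eval(loss)                        # materialize the lazy computation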
|
|
179
|
+
|
|
180
|
+
|
|
181
|
+
def entropy_bonus(logprobs: mx.array, coefficient: float = 0.01) -> mx.array:
|
|
182
|
+
"""
|
|
183
|
+
Entropy bonus for exploration (optional GRPO component).
|
|
184
|
+
|
|
185
|
+
Args:
|
|
186
|
+
logprobs: Log probabilities from policy
|
|
187
|
+
coefficient: Entropy coefficient (typically small, like 0.01)
|
|
188
|
+
|
|
189
|
+
Returns:
|
|
190
|
+
Entropy bonus (added to loss for exploration)
|
|
191
|
+
"""
|
|
192
|
+
if coefficient <= 0:
|
|
193
|
+
return mx.array(0.0)
|
|
194
|
+
|
|
195
|
+
# Entropy = -sum(p * log(p))
|
|
196
|
+
# For log probabilities: entropy = -sum(exp(logp) * logp)
|
|
197
|
+
probs = mx.exp(logprobs)
|
|
198
|
+
entropy = -mx.sum(probs * logprobs, axis=-1)
|
|
199
|
+
|
|
200
|
+
# Return negative entropy (since we add to loss but want to maximize entropy)
|
|
201
|
+
return -coefficient * mx.mean(entropy)
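
The sum over axis=-1 means the last axis is treated as a full distribution (e.g. the vocabulary at one step); it is only a true entropy when that axis sums to one. A small editorial sketch, assuming mlx, comparing a near-uniform and a peaked distribution:

import mlx.core as mx
from textpolicy.algorithms.grpo import entropy_bonus

uniform = mx.log(mx.full((1, 4), 0.25))                   # high entropy (ln 4)
peaked = mx.log(mx.array([[0.97, 0.01, 0.01, 0.01]]))     # low entropy

print(entropy_bonus(uniform).item())   # ≈ -0.0139: larger bonus subtracted from the loss
print(entropy_bonus(peaked).item())    # ≈ -0.0017: almost no bonus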


# Convenience function for complete GRPO computation
def grpo_loss(
    old_logprobs: mx.array,
    new_logprobs: mx.array,
    rewards: Union[List[float], mx.array],
    clip_ratio: float = 0.2,
    entropy_coeff: float = 0.0
) -> mx.array:
    """
    Complete GRPO loss computation in one function.

    Combines advantage calculation and policy loss for convenience.
    Can be compiled as a single unit for maximum efficiency.

    Args:
        old_logprobs: Log probabilities from rollout
        new_logprobs: Log probabilities from current policy
        rewards: Episode rewards for group-relative advantages
        clip_ratio: PPO clipping ratio
        entropy_coeff: Entropy bonus coefficient (0 disables)

    Returns:
        Total GRPO loss (policy + optional entropy)
    """
    # Compute group-relative advantages
    advantages = compute_advantages(rewards)

    # Expand advantages to match logprob sequence length if needed
    if advantages.ndim == 1 and old_logprobs.ndim > 1:
        # Each episode contributes its advantage to all tokens in that episode
        # This requires knowing episode boundaries - simplified version assumes
        # advantages and logprobs are already aligned
        pass

    # Compute policy loss
    policy_loss_val = policy_loss(old_logprobs, new_logprobs, advantages, clip_ratio)

    # Add entropy bonus if specified
    if entropy_coeff > 0:
        entropy_bonus_val = entropy_bonus(new_logprobs, entropy_coeff)
        return policy_loss_val + entropy_bonus_val

    return policy_loss_val
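
As the comment above notes, the wrapper assumes advantages and log-probs are already aligned; the editorial sketch below keeps that alignment trivial by using one summed log-prob and one terminal reward per episode (made-up values, only mlx assumed).

import mlx.core as mx
from textpolicy.algorithms.grpo import grpo_loss

old_lp = mx.array([-1.0, -1.2, -0.8])   # one summed log-prob per episode
new_lp = mx.array([-0.9, -1.5, -0.8])
rewards = [1.0, 0.0, 0.5]               # one terminal reward per episode

loss = grpo_loss(old_lp, new_lp, rewards, clip_ratio=0.2)
print(loss.item())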


# Performance monitoring utilities
def compute_metrics(
    old_logprobs: mx.array,
    new_logprobs: mx.array,
    advantages: mx.array,
    clip_ratio: float = 0.2
) -> dict:
    """
    Compute GRPO training metrics for monitoring.

    Args:
        old_logprobs: Log probabilities from rollout
        new_logprobs: Log probabilities from current policy
        advantages: Group-relative advantages
        clip_ratio: Clipping ratio used in loss

    Returns:
        Dictionary of metrics for logging/monitoring
    """
    # Importance ratio statistics
    ratio = mx.exp(new_logprobs - old_logprobs)

    # Clipping statistics
    clip_lower = 1 - clip_ratio
    clip_upper = 1 + clip_ratio
    clipped = (ratio < clip_lower) | (ratio > clip_upper)
    clip_fraction = mx.mean(clipped.astype(mx.float32))

    # KL divergence approximation
    kl_div = mx.mean(old_logprobs - new_logprobs)

    return {
        'mean_advantage': mx.mean(advantages).item(),
        'std_advantage': mx.std(advantages).item(),
        'mean_ratio': mx.mean(ratio).item(),
        'clip_fraction': clip_fraction.item(),
        'kl_divergence': kl_div.item(),
        'min_advantage': mx.min(advantages).item(),
        'max_advantage': mx.max(advantages).item()
    }
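
A logging sketch (editorial; values are made up). The 'kl_divergence' entry is the simple mean(old - new) estimator used above, so small negative values are normal and not a bug.

import mlx.core as mx
from textpolicy.algorithms.grpo import compute_advantages, compute_metrics

old_lp = mx.array([-1.0, -1.2, -0.8])
new_lp = mx.array([-0.9, -1.5, -0.8])
adv = compute_advantages([1.0, 0.0, 0.5])

metrics = compute_metrics(old_lp, new_lp, adv, clip_ratio=0.2)
print(f"clip={metrics['clip_fraction']:.2f}  kl={metrics['kl_divergence']:.4f}")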


# Algorithm-specific data selection strategies
def select_all_data(buffer):
    """
    GRPO data selector: Use all available data.

    GRPO is on-policy but can benefit from using all collected episodes
    since group-relative advantages normalize across the entire group.

    Args:
        buffer: Buffer containing episodes

    Returns:
        All episode data prepared for training
    """
    from textpolicy.buffer import Buffer
    if not isinstance(buffer, Buffer):
        raise TypeError(f"Expected Buffer, got {type(buffer)}")

    # Use all available data - GRPO benefits from larger groups
    episodes_data = buffer.sample()  # This returns concatenated transitions

    # We need to convert this back to episode structure for reward extraction
    episodes = buffer.episodes  # Access episodes directly from storage

    if not episodes:
        raise ValueError("Buffer is empty - no episodes to train on")

    # Extract episode rewards for advantage computation
    episode_rewards = []
    episode_lengths = []

    # Collect all transitions
    all_obs = []
    all_acts = []
    all_logprobs = []

    for episode in episodes:
        # Episode reward (sum of all rewards in episode)
        # Handle both Episode objects and serialized dictionaries
        if hasattr(episode, 'rew'):
            # Episode object with attributes
            episode_reward = mx.sum(mx.array(episode.rew)).item()
            episode_rewards.append(episode_reward)
            episode_lengths.append(len(episode.obs))

            # Collect transitions
            # For proper logprob extraction during training, we need the full context (prompt + response)
            # This matches how the model was called during rollout generation
            # Flatten nested token sequences to create uniform token arrays

            # Extract and flatten observation tokens (prompt)
            flattened_obs = []
            for obs in episode.obs:
                if hasattr(obs, 'tolist'):  # MLX array
                    flattened_obs.extend(obs.tolist())
                elif isinstance(obs, list):  # Python list
                    flattened_obs.extend(obs)
                else:  # Single token
                    flattened_obs.append(obs)

            # Extract and flatten action tokens (response)
            flattened_acts = []
            for act in episode.act:
                if hasattr(act, 'tolist'):  # MLX array
                    flattened_acts.extend(act.tolist())
                elif isinstance(act, list):  # Python list
                    flattened_acts.extend(act)
                else:  # Single token
                    flattened_acts.append(act)

            # Create full sequence: [prompt_tokens..., response_tokens...]
            full_sequence = flattened_obs + flattened_acts
            all_obs.append(full_sequence)
            all_acts.append(flattened_acts)
            all_logprobs.append(episode.logprob if episode.logprob else [])
        else:
            # Serialized dictionary from multiprocessing
            episode_reward = mx.sum(episode['rew']).item()
            episode_rewards.append(episode_reward)
            episode_lengths.append(len(episode['obs']))

            # Collect transitions
            # For proper logprob extraction during training, we need the full context (prompt + response)
            # This matches how the model was called during rollout generation
            full_sequence = episode['obs'] + episode['act']  # Concatenate prompt + response
            all_obs.append(full_sequence)
            all_acts.append(episode['act'])
            all_logprobs.append(episode.get('logprob', []))

    # Convert Python lists to MLX arrays before concatenation
    # This is required because Episode objects store data as Python lists for memory efficiency
    # For proper logprob extraction, we need uniform-length sequences, so we pad to the maximum length

    # Find maximum sequence length for padding
    max_obs_len = max(len(obs) for obs in all_obs) if all_obs else 0
    max_act_len = max(len(act) for act in all_acts) if all_acts else 0
    max_logprob_len = max(len(logprob) for logprob in all_logprobs) if all_logprobs else 0

    # MLX-native padding and array operations for optimal Apple Silicon performance
    # Convert all sequences to MLX arrays and pad directly in MLX space
    try:
        # Convert all sequences to MLX arrays first (staying in unified memory)
        all_obs_mx = [mx.array(obs, dtype=mx.int64) for obs in all_obs if obs]
        all_acts_mx = [mx.array(act, dtype=mx.int64) for act in all_acts if act]
        all_logprobs_mx = [mx.array(logprob, dtype=mx.float32) for logprob in all_logprobs if logprob]

        # Pad using native MLX operations (more efficient for Apple Silicon)
        if all_obs_mx:
            padded_obs_mx = [mx.pad(obs, (0, max_obs_len - obs.shape[0]), constant_values=0)
                             if obs.shape[0] < max_obs_len else obs[:max_obs_len]
                             for obs in all_obs_mx]
        else:
            padded_obs_mx = []

        if all_acts_mx:
            padded_acts_mx = [mx.pad(act, (0, max_act_len - act.shape[0]), constant_values=0)
                              if act.shape[0] < max_act_len else act[:max_act_len]
                              for act in all_acts_mx]
        else:
            padded_acts_mx = []

        if all_logprobs_mx:
            padded_logprobs_mx = [mx.pad(logprob, (0, max_logprob_len - logprob.shape[0]), constant_values=0.0)
                                  if logprob.shape[0] < max_logprob_len else logprob[:max_logprob_len]
                                  for logprob in all_logprobs_mx]
        else:
            padded_logprobs_mx = []

        # Use padded MLX arrays directly (no intermediate conversion needed)
        all_obs_mx = padded_obs_mx
        all_acts_mx = padded_acts_mx
        all_logprobs_mx = padded_logprobs_mx

    except Exception as e:
        print(f"ERROR in MLX array conversion: {e}")
        print(f"DEBUG: all_obs types: {[type(obs) for obs in all_obs[:3]]}")  # Show first 3 for brevity
        print(f"DEBUG: all_logprobs types: {[type(logprob) for logprob in all_logprobs[:3]]}")
        raise

    # GRPO data structure: both observations and actions as flat concatenated sequences
    # This matches the expected format for GRPO logprob extraction function
    batch_data = {
        'obs': mx.concatenate(all_obs_mx) if all_obs_mx else mx.array([]),  # Flat concatenated full sequences
        'act': mx.concatenate(all_acts_mx) if all_acts_mx else mx.array([]),  # Flat concatenated response tokens
        'logprob': mx.concatenate([logprob.flatten() for logprob in all_logprobs_mx]) if all_logprobs_mx else mx.array([]),  # Flat sequence for training
        'rewards': mx.array(episode_rewards),
        'episode_lengths': episode_lengths
    }

    return batch_data
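
Rough usage (editorial sketch; the bare Buffer() construction is an assumption about textpolicy/buffer/buffer.py's default constructor, and the buffer must already contain episodes or the selector raises ValueError). The 'rewards' entry of the returned dict is what feeds compute_advantages, one advantage per episode.

from textpolicy.buffer import Buffer
from textpolicy.algorithms.grpo import select_all_data, compute_advantages

buffer = Buffer()              # hypothetical default construction; see textpolicy/buffer/buffer.py
# ... rollout workers append episodes to the buffer here ...

batch = select_all_data(buffer)
advantages = compute_advantages(batch['rewards'])      # one advantage per episode
print(batch['obs'].shape, batch['act'].shape, batch['episode_lengths'])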


def select_recent_data(buffer, max_episodes: int = 100):
    """
    GRPO data selector: Use only recent episodes.

    Alternative selector for GRPO that limits to recent episodes
    for faster training on large buffers.

    Args:
        buffer: Buffer containing episodes (Episode objects or serialized dictionaries)
        max_episodes: Maximum number of recent episodes to use

    Returns:
        Recent episode data prepared for training
    """
    from textpolicy.buffer import Buffer
    if not isinstance(buffer, Buffer):
        raise TypeError(f"Expected Buffer, got {type(buffer)}")

    episodes = buffer.episodes
    if not episodes:
        raise ValueError("Buffer is empty - no episodes to train on")

    # Select recent episodes
    recent_episodes = episodes[-max_episodes:] if len(episodes) > max_episodes else episodes

    # Process recent episodes
    episode_rewards = []
    episode_lengths = []
    all_obs = []
    all_acts = []
    all_logprobs = []

    for episode in recent_episodes:
        # Handle both Episode objects and serialized dictionaries
        if hasattr(episode, 'rew'):
            # Episode object with attributes
            episode_reward = mx.sum(mx.array(episode.rew)).item()
            episode_rewards.append(episode_reward)
            episode_lengths.append(len(episode.obs))

            # For proper logprob extraction during training, we need the full context (prompt + response)
            # This matches how the model was called during rollout generation
            # Convert both obs and act to consistent Python list format before concatenation
            obs_as_lists = []
            for obs_item in episode.obs:
                if hasattr(obs_item, 'tolist'):  # MLX array
                    obs_as_lists.extend(obs_item.tolist())
                elif isinstance(obs_item, list):  # Already Python list
                    obs_as_lists.extend(obs_item)
                else:  # Single item
                    obs_as_lists.append(obs_item)

            act_as_lists = []
            for act_item in episode.act:
                if hasattr(act_item, 'tolist'):  # MLX array
                    act_as_lists.extend(act_item.tolist())
                elif isinstance(act_item, list):  # Already Python list
                    act_as_lists.extend(act_item)
                else:  # Single item
                    act_as_lists.append(act_item)

            # Now concatenate the normalized lists
            full_sequence = obs_as_lists + act_as_lists
            all_obs.append(full_sequence)

            # Extract actions as consistent Python lists
            episode_actions = []
            for act_item in episode.act:
                if hasattr(act_item, 'tolist'):  # MLX array
                    episode_actions.extend(act_item.tolist())
                elif isinstance(act_item, list):  # Already Python list
                    episode_actions.extend(act_item)
                else:  # Single item
                    episode_actions.append(act_item)
            all_acts.append(episode_actions)

            # Extract logprobs as consistent Python lists
            episode_logprobs = []
            if episode.logprob:
                for logprob_item in episode.logprob:
                    if hasattr(logprob_item, 'tolist'):  # MLX array
                        episode_logprobs.extend(logprob_item.tolist())
                    elif isinstance(logprob_item, list):  # Already Python list
                        episode_logprobs.extend(logprob_item)
                    else:  # Single item
                        episode_logprobs.append(logprob_item)
            all_logprobs.append(episode_logprobs)
        else:
            # Serialized dictionary from multiprocessing
            episode_reward = mx.sum(episode['rew']).item()
            episode_rewards.append(episode_reward)
            episode_lengths.append(len(episode['obs']))

            # For proper logprob extraction during training, we need the full context (prompt + response)
            # This matches how the model was called during rollout generation
            # Convert both obs and act to consistent Python list format before concatenation
            obs_as_lists = []
            for obs_item in episode['obs']:
                if hasattr(obs_item, 'tolist'):  # MLX array
                    obs_as_lists.extend(obs_item.tolist())
                elif isinstance(obs_item, list):  # Already Python list
                    obs_as_lists.extend(obs_item)
                else:  # Single item
                    obs_as_lists.append(obs_item)

            act_as_lists = []
            for act_item in episode['act']:
                if hasattr(act_item, 'tolist'):  # MLX array
                    act_as_lists.extend(act_item.tolist())
                elif isinstance(act_item, list):  # Already Python list
                    act_as_lists.extend(act_item)
                else:  # Single item
                    act_as_lists.append(act_item)

            # Now concatenate the normalized lists
            full_sequence = obs_as_lists + act_as_lists
            all_obs.append(full_sequence)

            # Extract actions as consistent Python lists
            episode_actions = []
            for act_item in episode['act']:
                if hasattr(act_item, 'tolist'):  # MLX array
                    episode_actions.extend(act_item.tolist())
                elif isinstance(act_item, list):  # Already Python list
                    episode_actions.extend(act_item)
                else:  # Single item
                    episode_actions.append(act_item)
            all_acts.append(episode_actions)

            # Extract logprobs as consistent Python lists
            episode_logprobs = []
            if episode.get('logprob'):
                for logprob_item in episode['logprob']:
                    if hasattr(logprob_item, 'tolist'):  # MLX array
                        episode_logprobs.extend(logprob_item.tolist())
                    elif isinstance(logprob_item, list):  # Already Python list
                        episode_logprobs.extend(logprob_item)
                    else:  # Single item
                        episode_logprobs.append(logprob_item)
            all_logprobs.append(episode_logprobs)

    # Convert Python lists to MLX arrays before concatenation
    # This is required because Episode objects store data as Python lists for memory efficiency
    # For proper logprob extraction, we need uniform-length sequences, so we pad to the maximum length

    # Find maximum sequence length for padding
    max_obs_len = max(len(obs) for obs in all_obs) if all_obs else 0
    max_act_len = max(len(act) for act in all_acts) if all_acts else 0
    max_logprob_len = max(len(logprob) for logprob in all_logprobs) if all_logprobs else 0

    # MLX-native padding and array operations for optimal Apple Silicon performance
    # Convert all sequences to MLX arrays and pad directly in MLX space
    try:
        # Convert all sequences to MLX arrays first (staying in unified memory)
        all_obs_mx = [mx.array(obs, dtype=mx.int64) for obs in all_obs if obs]
        all_acts_mx = [mx.array(act, dtype=mx.int64) for act in all_acts if act]
        all_logprobs_mx = [mx.array(logprob, dtype=mx.float32) for logprob in all_logprobs if logprob]

        # Pad using native MLX operations (more efficient for Apple Silicon)
        if all_obs_mx:
            padded_obs_mx = [mx.pad(obs, (0, max_obs_len - obs.shape[0]), constant_values=0)
                             if obs.shape[0] < max_obs_len else obs[:max_obs_len]
                             for obs in all_obs_mx]
        else:
            padded_obs_mx = []

        if all_acts_mx:
            padded_acts_mx = [mx.pad(act, (0, max_act_len - act.shape[0]), constant_values=0)
                              if act.shape[0] < max_act_len else act[:max_act_len]
                              for act in all_acts_mx]
        else:
            padded_acts_mx = []

        if all_logprobs_mx:
            padded_logprobs_mx = [mx.pad(logprob, (0, max_logprob_len - logprob.shape[0]), constant_values=0.0)
                                  if logprob.shape[0] < max_logprob_len else logprob[:max_logprob_len]
                                  for logprob in all_logprobs_mx]
        else:
            padded_logprobs_mx = []

        # Use padded MLX arrays directly (no intermediate conversion needed)
        all_obs_mx = padded_obs_mx
        all_acts_mx = padded_acts_mx
        all_logprobs_mx = padded_logprobs_mx

    except Exception as e:
        print(f"ERROR in MLX array conversion: {e}")
        print(f"DEBUG: all_obs types: {[type(obs) for obs in all_obs[:3]]}")  # Show first 3 for brevity
        print(f"DEBUG: all_logprobs types: {[type(logprob) for logprob in all_logprobs[:3]]}")
        raise

    batch_data = {
        'obs': mx.concatenate(all_obs_mx) if all_obs_mx else mx.array([]),  # Flat concatenated full sequences
        'act': mx.concatenate(all_acts_mx) if all_acts_mx else mx.array([]),  # Flat concatenated response tokens
        'logprob': mx.concatenate([logprob.flatten() for logprob in all_logprobs_mx]) if all_logprobs_mx else mx.array([]),  # Flat sequence for training
        'rewards': mx.array(episode_rewards),
        'episode_lengths': episode_lengths
    }

    return batch_data
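
Finally, a small editorial sketch of when each selector makes sense (the names come from this module; the size threshold is an arbitrary illustration): using the whole buffer keeps the group baseline broad, while capping to recent episodes keeps updates closer to on-policy on long-running runs, as the docstrings above suggest.

from textpolicy.algorithms.grpo import select_all_data, select_recent_data

def pick_selector(num_episodes: int, cap: int = 100):
    """Choose a GRPO data selector based on how many episodes the buffer holds."""
    if num_episodes <= cap:
        return select_all_data                                   # small buffer: use the full group
    return lambda buffer: select_recent_data(buffer, max_episodes=cap)   # large buffer: recent episodes only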