textpolicy 0.0.1-py3-none-any.whl → 0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- textpolicy/__init__.py +52 -0
- textpolicy/__main__.py +8 -0
- textpolicy/algorithms/__init__.py +54 -0
- textpolicy/algorithms/grpo.py +642 -0
- textpolicy/algorithms/gspo.py +582 -0
- textpolicy/buffer/__init__.py +23 -0
- textpolicy/buffer/buffer.py +244 -0
- textpolicy/buffer/episode.py +383 -0
- textpolicy/buffer/sampling.py +438 -0
- textpolicy/buffer/storage.py +255 -0
- textpolicy/cli.py +67 -0
- textpolicy/environment/__init__.py +79 -0
- textpolicy/environment/base.py +110 -0
- textpolicy/environment/environment.py +46 -0
- textpolicy/environment/factory.py +103 -0
- textpolicy/environment/gym.py +106 -0
- textpolicy/environment/task_suites.py +51 -0
- textpolicy/environment/text_generation.py +789 -0
- textpolicy/environment/vectorized.py +253 -0
- textpolicy/generation/__init__.py +62 -0
- textpolicy/generation/lora.py +411 -0
- textpolicy/generation/mlx_generation.py +557 -0
- textpolicy/generation/reload.py +253 -0
- textpolicy/rewards/__init__.py +137 -0
- textpolicy/rewards/adapters.py +387 -0
- textpolicy/rewards/basic.py +214 -0
- textpolicy/rewards/integrated_system.py +338 -0
- textpolicy/rewards/mlx_batch_processor.py +447 -0
- textpolicy/rewards/registry.py +293 -0
- textpolicy/rewards/rollout_rewards.py +410 -0
- textpolicy/rewards/verifiers.py +369 -0
- textpolicy/rollout/__init__.py +44 -0
- textpolicy/rollout/aggregator.py +145 -0
- textpolicy/rollout/base.py +108 -0
- textpolicy/rollout/rollout.py +142 -0
- textpolicy/rollout/runner.py +280 -0
- textpolicy/rollout/strategy.py +208 -0
- textpolicy/rollout/worker.py +194 -0
- textpolicy/training/__init__.py +14 -0
- textpolicy/training/metrics.py +242 -0
- textpolicy/training/rollout_manager.py +78 -0
- textpolicy/training/trainer.py +684 -0
- textpolicy/utils/__init__.py +40 -0
- textpolicy/utils/benchmarking.py +489 -0
- textpolicy/utils/data.py +60 -0
- textpolicy/utils/debug.py +170 -0
- textpolicy/utils/environment.py +349 -0
- textpolicy/utils/logging/__init__.py +22 -0
- textpolicy/utils/logging/base.py +48 -0
- textpolicy/utils/logging/console.py +61 -0
- textpolicy/utils/logging/factory.py +133 -0
- textpolicy/utils/logging/multi.py +83 -0
- textpolicy/utils/logging/tensorboard.py +65 -0
- textpolicy/utils/logging/wandb.py +72 -0
- textpolicy/utils/memory.py +118 -0
- textpolicy/utils/performance.py +464 -0
- textpolicy/utils/timing.py +171 -0
- textpolicy/validate.py +101 -0
- textpolicy/validation/__init__.py +13 -0
- textpolicy/validation/logprob_validation.py +315 -0
- textpolicy-0.1.0.dist-info/METADATA +99 -0
- textpolicy-0.1.0.dist-info/RECORD +66 -0
- textpolicy-0.1.0.dist-info/entry_points.txt +2 -0
- textpolicy-0.0.1.dist-info/METADATA +0 -10
- textpolicy-0.0.1.dist-info/RECORD +0 -6
- {textpolicy-0.0.1.dist-info → textpolicy-0.1.0.dist-info}/WHEEL +0 -0
- {textpolicy-0.0.1.dist-info → textpolicy-0.1.0.dist-info}/licenses/LICENSE +0 -0
- {textpolicy-0.0.1.dist-info → textpolicy-0.1.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,582 @@
+# textpolicy/algorithms/gspo.py
+"""
+Group Sequence Policy Optimization (GSPO).
+
+GSPO computes importance weights at the sequence level to align with
+sequence-level rewards. Variants include sequence, token, and hybrid forms.
+Reference: https://swift.readthedocs.io/en/latest/Instruction/GRPO/AdvancedResearch/GSPO.html
+"""
+from __future__ import annotations
+import mlx.core as mx
+from typing import List, Dict
+
+
+def compute_sequence_importance_weights(
+    old_logprobs: mx.array,
+    new_logprobs: mx.array,
+    sequence_lengths: List[int],
+    clip_ratio: float = 0.2
+) -> mx.array:
+    """
+    Compute sequence-level importance weights for GSPO.
+
+    GSPO formula: w^GSPO_{i} = [π_θ(y_i | x) / π_θ_old(y_i | x)]^(1/|y_i|)
+
+    This normalizes by sequence length to prevent bias toward shorter/longer sequences.
+
+    Args:
+        old_logprobs: Log probabilities from rollout collection [batch_size, seq_len]
+        new_logprobs: Log probabilities from current policy [batch_size, seq_len]
+        sequence_lengths: Length of each sequence in the batch
+
+    Returns:
+        Sequence-level importance weights [batch_size]
+
+    Compared to token-level sampling, this reduces variance and matches
+    sequence-level reward assignment.
+    """
+    batch_size = len(sequence_lengths)
+    sequence_weights = []
+
+    current_idx = 0
+    for seq_len in sequence_lengths:
+        # Extract logprobs for this sequence
+        seq_old_logprobs = old_logprobs[current_idx:current_idx + seq_len]
+        seq_new_logprobs = new_logprobs[current_idx:current_idx + seq_len]  # type: ignore
+
+        # Compute sequence-level log probability: sum of token log probs
+        old_seq_logprob = mx.sum(seq_old_logprobs)
+        new_seq_logprob = mx.sum(seq_new_logprobs)
+
+        # Sequence-level importance ratio: π_new(y|x) / π_old(y|x)
+        log_ratio = new_seq_logprob - old_seq_logprob
+
+        # GSPO normalization: raise to power 1/|y_i| to prevent length bias
+        # This ensures sequences of different lengths contribute equally
+        normalized_log_ratio = log_ratio / seq_len
+
+        # Clip in log space to prevent numerical explosion
+        # This is the key missing piece that was causing billion-scale importance weights
+        clipped_log_ratio = mx.clip(
+            normalized_log_ratio,
+            mx.log(mx.array(1 - clip_ratio)),  # log(0.8) ≈ -0.22
+            mx.log(mx.array(1 + clip_ratio))   # log(1.2) ≈ 0.18
+        )
+
+        # Now safely compute importance weight (will be in range [0.8, 1.2])
+        importance_weight = mx.exp(clipped_log_ratio)
+
+        # Final check: account for float32 precision
+        # exp(log(1.2)) in float32 may produce 1.2000000476837158; enforce exact bounds
+        importance_weight = mx.clip(importance_weight, 1.0 - clip_ratio, 1.0 + clip_ratio)
+
+        # Enforce exact bounds using scalar comparisons
+        # Convert to float for comparison to avoid MLX array comparison issues
+        weight_float = float(importance_weight)
+        if weight_float > 1.0 + clip_ratio:
+            importance_weight = mx.array(1.0 + clip_ratio)
+        elif weight_float < 1.0 - clip_ratio:
+            importance_weight = mx.array(1.0 - clip_ratio)
+
+        sequence_weights.append(importance_weight)
+        current_idx += seq_len
+
+    return mx.array(sequence_weights)
+
+
+def compute_hybrid_importance_weights(
+    old_logprobs: mx.array,
+    new_logprobs: mx.array,
+    sequence_lengths: List[int],
+    alpha: float = 0.5,
+    beta: float = 0.5
+) -> mx.array:
+    """
+    Compute hybrid importance weights using principled log-space combination.
+
+    Instead of multiplying exp(seq_ratio) * exp(token_ratio) which compounds variance,
+    uses additive combination: exp(α * seq_log_ratio + β * token_log_ratio)
+
+    This provides a more stable and theoretically sound approach to combining
+    sequence-level stability with token-level granularity.
+
+    Args:
+        old_logprobs: Log probabilities from rollout collection [total_tokens]
+        new_logprobs: Log probabilities from current policy [total_tokens]
+        sequence_lengths: Length of each sequence in the batch
+        alpha: Weight for sequence-level importance (default: 0.5)
+        beta: Weight for token-level importance (default: 0.5)
+
+    Returns:
+        Hybrid importance weights [total_tokens]
+
+    Advantages:
+    - Avoids explosive multiplication of exponentials
+    - Controlled variance through hyperparameter balance
+    - Principled combination in log-space
+    """
+    # Compute sequence-level log ratios (without exponential)
+    batch_size = len(sequence_lengths)
+    seq_log_ratios = []
+
+    current_idx = 0
+    for seq_len in sequence_lengths:
+        # Extract logprobs for this sequence
+        seq_old_logprobs = old_logprobs[current_idx:current_idx + seq_len]
+        seq_new_logprobs = new_logprobs[current_idx:current_idx + seq_len]  # type: ignore
+
+        # Compute sequence-level log probability: sum of token log probs
+        old_seq_logprob = mx.sum(seq_old_logprobs)
+        new_seq_logprob = mx.sum(seq_new_logprobs)
+
+        # Sequence-level log ratio with GSPO normalization (prevent length bias)
+        log_ratio = new_seq_logprob - old_seq_logprob
+        normalized_seq_log_ratio = log_ratio / seq_len
+
+        seq_log_ratios.append(normalized_seq_log_ratio)
+        current_idx += seq_len
+
+    # Expand sequence-level log ratios to token level
+    token_seq_log_ratios = []
+    for i, seq_len in enumerate(sequence_lengths):
+        # Use stop gradient to prevent certain gradient flows
+        seq_log_ratio_sg = mx.stop_gradient(seq_log_ratios[i])
+        token_seq_log_ratios.extend([seq_log_ratio_sg] * seq_len)
+
+    token_seq_log_ratios = mx.array(token_seq_log_ratios)
+
+    # Compute token-level log ratios (with stop gradient on old logprobs)
+    old_logprobs_sg = mx.stop_gradient(old_logprobs)
+    token_log_ratios = new_logprobs - old_logprobs_sg
+
+    # Combine in log-space: α * seq_log_ratio + β * token_log_ratio
+    combined_log_ratios = alpha * token_seq_log_ratios + beta * token_log_ratios
+
+    # Apply single exponential to get final importance weights
+    hybrid_weights = mx.exp(combined_log_ratios)
+
+    return hybrid_weights
+
+
+def gspo_policy_loss(
+    old_logprobs: mx.array,
+    new_logprobs: mx.array,
+    advantages: mx.array,
+    sequence_lengths: List[int],
+    variant: str = "sequence",
+    clip_ratio: float = 0.2,
+    alpha: float = 0.5,
+    beta: float = 0.5
+) -> mx.array:
+    """
+    GSPO policy loss with sequence-level importance sampling.
+
+    Args:
+        old_logprobs: Log probabilities from rollout collection
+        new_logprobs: Log probabilities from current policy
+        advantages: Group-relative advantages (computed same as GRPO)
+        sequence_lengths: Length of each sequence in the batch
+        variant: "sequence" for pure GSPO, "hybrid" for GSPO-token, "token" for GRPO
+        clip_ratio: Clipping ratio for surrogate objective
+        alpha: Weight for sequence-level importance (used in hybrid variant)
+        beta: Weight for token-level importance (used in hybrid variant)
+
+    Returns:
+        Policy loss scalar (to be minimized)
+
+    Key innovation:
+    - Uses sequence-level importance weights instead of token-level
+    - Reduces gradient variance and improves training stability
+    - Better alignment with sequence-level reward signals
+    """
+    if variant == "sequence":
+        # Pure GSPO: sequence-level importance sampling
+        importance_weights = compute_sequence_importance_weights(
+            old_logprobs, new_logprobs, sequence_lengths, clip_ratio
+        )
+
+        # Expand advantages to match sequence weights
+        if len(advantages) != len(sequence_lengths):
+            raise ValueError(f"Advantages length {len(advantages)} doesn't match sequences {len(sequence_lengths)}")
+
+        # Apply PPO clipping to sequence-level weights
+        clipped_weights = mx.clip(importance_weights, 1 - clip_ratio, 1 + clip_ratio)
+
+        # Compute surrogate loss at sequence level
+        surr1 = importance_weights * advantages
+        surr2 = clipped_weights * advantages
+        loss = -mx.mean(mx.minimum(surr1, surr2))
+
+    elif variant == "hybrid":
+        # GSPO-token: hybrid sequence and token-level
+        importance_weights = compute_hybrid_importance_weights(
+            old_logprobs, new_logprobs, sequence_lengths, alpha=alpha, beta=beta
+        )
+
+        # Expand advantages to token level
+        token_advantages = []
+        for i, seq_len in enumerate(sequence_lengths):
+            token_advantages.extend([advantages[i]] * seq_len)
+        token_advantages = mx.array(token_advantages)
+
+        # Apply PPO clipping to hybrid weights
+        clipped_weights = mx.clip(importance_weights, 1 - clip_ratio, 1 + clip_ratio)
+
+        # Compute surrogate loss at token level
+        surr1 = importance_weights * token_advantages
+        surr2 = clipped_weights * token_advantages
+        loss = -mx.mean(mx.minimum(surr1, surr2))
+
+    elif variant == "token":
+        # Standard GRPO: token-level importance sampling (for comparison)
+        ratio = mx.exp(new_logprobs - old_logprobs)
+
+        # Expand advantages to token level
+        token_advantages = []
+        for i, seq_len in enumerate(sequence_lengths):
+            token_advantages.extend([advantages[i]] * seq_len)
+        token_advantages = mx.array(token_advantages)
+
+        # Apply PPO clipping
+        clipped_ratio = mx.clip(ratio, 1 - clip_ratio, 1 + clip_ratio)
+
+        # Compute surrogate loss
+        surr1 = ratio * token_advantages
+        surr2 = clipped_ratio * token_advantages
+        loss = -mx.mean(mx.minimum(surr1, surr2))
+
+    else:
+        raise ValueError(f"Unknown GSPO variant: {variant}. Choose 'sequence', 'hybrid', or 'token'")
+
+    return loss
+
+
+def create_gspo_policy_loss(variant: str = "sequence", clip_ratio: float = 0.2, alpha: float = 0.5, beta: float = 0.5):
+    """
+    Factory function to create GSPO policy loss function with standard signature.
+
+    This follows the design guidelines for pure function composition with the universal Trainer.
+
+    Args:
+        variant: GSPO variant ("sequence", "hybrid", or "token")
+        clip_ratio: PPO clipping ratio for importance weights
+
+    Returns:
+        Policy loss function with standard signature (old_logprobs, new_logprobs, advantages)
+
+    Usage:
+        trainer = Trainer(
+            model=model,
+            advantage_fn=grpo.compute_advantages_dr_grpo,
+            loss_fn=gspo.create_gspo_policy_loss(variant="sequence"),
+            optimizer=optimizer
+        )
+    """
+    def gspo_policy_loss_fn(old_logprobs: mx.array, new_logprobs: mx.array, advantages: mx.array) -> mx.array:
+        """
+        GSPO policy loss with sequence-level importance sampling.
+
+        Standard signature for use with universal Trainer.
+        """
+        # For GSPO, we need sequence lengths. This is a limitation that requires
+        # the batch_data to include sequence_lengths information.
+        # For now, we'll use a fallback approach for compatibility.
+
+        # Robust fallback: distribute tokens as evenly as possible across episodes
+        # This handles variable-length sequences by distributing remainder tokens
+        total_tokens = len(old_logprobs) if len(old_logprobs.shape) == 1 else old_logprobs.shape[0]  # type: ignore
+        num_episodes = len(advantages)
+
+        if num_episodes > 0:
+            base_length = total_tokens // num_episodes
+            remainder = total_tokens % num_episodes
+            # Distribute remainder tokens to first 'remainder' episodes
+            sequence_lengths = [base_length + (1 if i < remainder else 0) for i in range(num_episodes)]
+        else:
+            sequence_lengths = [total_tokens] if total_tokens > 0 else [1]
+
+        return gspo_policy_loss(
+            old_logprobs=old_logprobs,
+            new_logprobs=new_logprobs,
+            advantages=advantages,
+            sequence_lengths=sequence_lengths,
+            variant=variant,
+            clip_ratio=clip_ratio,
+            alpha=alpha,
+            beta=beta
+        )
+
+    return gspo_policy_loss_fn
+
+
+def create_gspo_metrics(variant: str = "sequence", clip_ratio: float = 0.2):
+    """
+    Factory function to create GSPO metrics function with standard signature.
+
+    Args:
+        variant: GSPO variant being used
+        clip_ratio: Clipping ratio used in loss
+
+    Returns:
+        Metrics function with standard signature
+
+    Usage:
+        trainer = Trainer(
+            model=model,
+            advantage_fn=grpo.compute_advantages_dr_grpo,
+            loss_fn=gspo.create_gspo_policy_loss(variant="sequence"),
+            metrics_fn=gspo.create_gspo_metrics(variant="sequence"),
+            optimizer=optimizer
+        )
+    """
+    def gspo_metrics_fn(old_logprobs: mx.array, new_logprobs: mx.array, advantages: mx.array) -> Dict[str, float]:
+        """GSPO metrics with sequence-level importance weight tracking."""
+        # Robust fallback: distribute tokens as evenly as possible across episodes
+        # This matches the same robust approach used in the policy loss function
+        total_tokens = len(old_logprobs) if len(old_logprobs.shape) == 1 else old_logprobs.shape[0]  # type: ignore
+        num_episodes = len(advantages)
+
+        if num_episodes > 0:
+            base_length = total_tokens // num_episodes
+            remainder = total_tokens % num_episodes
+            # Distribute remainder tokens to first 'remainder' episodes
+            sequence_lengths = [base_length + (1 if i < remainder else 0) for i in range(num_episodes)]
+        else:
+            sequence_lengths = [total_tokens] if total_tokens > 0 else [1]
+
+        return compute_gspo_metrics(
+            old_logprobs=old_logprobs,
+            new_logprobs=new_logprobs,
+            advantages=advantages,
+            sequence_lengths=sequence_lengths,
+            variant=variant,
+            clip_ratio=clip_ratio
+        )
+
+    return gspo_metrics_fn
+
+
+# Convenience functions that match GRPO interface
+def policy_loss_sequence(old_logprobs: mx.array, new_logprobs: mx.array, advantages: mx.array) -> mx.array:
+    """GSPO sequence-level policy loss function (standard signature)."""
+    return create_gspo_policy_loss(variant="sequence")(old_logprobs, new_logprobs, advantages)
+
+
+def policy_loss_hybrid(old_logprobs: mx.array, new_logprobs: mx.array, advantages: mx.array) -> mx.array:
+    """GSPO hybrid policy loss function (standard signature)."""
+    return create_gspo_policy_loss(variant="hybrid")(old_logprobs, new_logprobs, advantages)
+
+def create_policy_loss_hybrid(alpha: float = 0.5, beta: float = 0.5):
+    """
+    Create a GSPO hybrid policy loss function with configurable hyperparameters.
+
+    Args:
+        alpha: Weight for sequence-level importance (0.0 = pure token-level, 1.0 = pure sequence-level)
+        beta: Weight for token-level importance (0.0 = ignore token-level, 1.0 = full token-level)
+
+    Returns:
+        Policy loss function with standard signature
+
+    Example:
+        # Balanced hybrid (default)
+        loss_fn = create_policy_loss_hybrid(alpha=0.5, beta=0.5)
+
+        # More sequence-focused
+        loss_fn = create_policy_loss_hybrid(alpha=0.7, beta=0.3)
+
+        # More token-focused
+        loss_fn = create_policy_loss_hybrid(alpha=0.3, beta=0.7)
+    """
+    def hybrid_loss_fn(old_logprobs: mx.array, new_logprobs: mx.array, advantages: mx.array) -> mx.array:
+        # Use custom alpha/beta parameters for this specific loss function
+        return create_gspo_policy_loss(variant="hybrid", alpha=alpha, beta=beta)(
+            old_logprobs, new_logprobs, advantages
+        )
+    return hybrid_loss_fn
+
+
+def policy_loss_token(old_logprobs: mx.array, new_logprobs: mx.array, advantages: mx.array) -> mx.array:
+    """GSPO token-level policy loss function (standard signature) - equivalent to GRPO."""
+    return create_gspo_policy_loss(variant="token")(old_logprobs, new_logprobs, advantages)
+
+
+def compute_metrics_sequence(old_logprobs: mx.array, new_logprobs: mx.array, advantages: mx.array) -> Dict[str, float]:
+    """GSPO sequence-level metrics function (standard signature)."""
+    return create_gspo_metrics(variant="sequence")(old_logprobs, new_logprobs, advantages)
+
+
+def compute_metrics_hybrid(old_logprobs: mx.array, new_logprobs: mx.array, advantages: mx.array) -> Dict[str, float]:
+    """GSPO hybrid metrics function (standard signature)."""
+    return create_gspo_metrics(variant="hybrid")(old_logprobs, new_logprobs, advantages)
+
+
+def compute_metrics_token(old_logprobs: mx.array, new_logprobs: mx.array, advantages: mx.array) -> Dict[str, float]:
+    """GSPO token-level metrics function (standard signature)."""
+    return create_gspo_metrics(variant="token")(old_logprobs, new_logprobs, advantages)
+
+
+def compute_gspo_metrics(
+    old_logprobs: mx.array,
+    new_logprobs: mx.array,
+    advantages: mx.array,
+    sequence_lengths: List[int],
+    variant: str = "sequence",
+    clip_ratio: float = 0.2
+) -> dict:
+    """
+    Compute GSPO training metrics for monitoring.
+
+    Args:
+        old_logprobs: Log probabilities from rollout
+        new_logprobs: Log probabilities from current policy
+        advantages: Group-relative advantages
+        sequence_lengths: Length of each sequence in the batch
+        variant: GSPO variant being used
+        clip_ratio: Clipping ratio used in loss
+
+    Returns:
+        Dictionary of metrics for logging/monitoring
+
+    Additional GSPO-specific metrics:
+    - Sequence-level importance weight statistics
+    - Gradient variance estimates
+    - Length bias indicators
+    """
+    # Standard advantage metrics
+    metrics = {
+        'mean_advantage': mx.mean(advantages).item(),
+        'std_advantage': mx.std(advantages).item(),
+        'min_advantage': mx.min(advantages).item(),
+        'max_advantage': mx.max(advantages).item()
+    }
+
+    if variant == "sequence":
+        # Sequence-level importance weights
+        seq_weights = compute_sequence_importance_weights(
+            old_logprobs, new_logprobs, sequence_lengths, clip_ratio
+        )
+
+        # Sequence weight statistics
+        metrics.update({
+            'mean_seq_weight': mx.mean(seq_weights).item(),
+            'std_seq_weight': mx.std(seq_weights).item(),
+            'max_seq_weight': mx.max(seq_weights).item(),
+            'min_seq_weight': mx.min(seq_weights).item()
+        })
+
+        # Clipping statistics at sequence level
+        clipped = (seq_weights < (1 - clip_ratio)) | (seq_weights > (1 + clip_ratio))
+        metrics['seq_clip_fraction'] = mx.mean(clipped.astype(mx.float32)).item()
+
+    elif variant == "hybrid":
+        # Hybrid importance weights
+        hybrid_weights = compute_hybrid_importance_weights(
+            old_logprobs, new_logprobs, sequence_lengths
+        )
+
+        # Hybrid weight statistics
+        metrics.update({
+            'mean_hybrid_weight': mx.mean(hybrid_weights).item(),
+            'std_hybrid_weight': mx.std(hybrid_weights).item(),
+            'max_hybrid_weight': mx.max(hybrid_weights).item(),
+            'min_hybrid_weight': mx.min(hybrid_weights).item()
+        })
+
+        # Clipping statistics at token level
+        clipped = (hybrid_weights < (1 - clip_ratio)) | (hybrid_weights > (1 + clip_ratio))
+        metrics['hybrid_clip_fraction'] = mx.mean(clipped.astype(mx.float32)).item()
+
+    else:  # token-level (standard GRPO)
+        # Token-level importance ratios
+        ratio = mx.exp(new_logprobs - old_logprobs)
+
+        metrics.update({
+            'mean_token_ratio': mx.mean(ratio).item(),
+            'std_token_ratio': mx.std(ratio).item(),
+            'max_token_ratio': mx.max(ratio).item(),
+            'min_token_ratio': mx.min(ratio).item()
+        })
+
+        # Clipping statistics at token level
+        clipped = (ratio < (1 - clip_ratio)) | (ratio > (1 + clip_ratio))
+        metrics['token_clip_fraction'] = mx.mean(clipped.astype(mx.float32)).item()
+
+    # Length bias analysis
+    if len(sequence_lengths) > 1:
+        length_array = mx.array(sequence_lengths, dtype=mx.float32)
+        metrics.update({
+            'mean_seq_length': mx.mean(length_array).item(),
+            'std_seq_length': mx.std(length_array).item(),
+            'min_seq_length': mx.min(length_array).item(),
+            'max_seq_length': mx.max(length_array).item()
+        })
+
+    # KL divergence approximation
+    kl_div = mx.mean(old_logprobs - new_logprobs)
+    metrics['kl_divergence'] = kl_div.item()
+
+    return metrics
+
+
+# Algorithm-specific data selectors for GSPO
+def select_gspo_data(buffer, variant: str = "sequence"):
+    """
+    GSPO data selector: Use all available data with sequence-level organization.
+
+    GSPO requires sequence length information for proper importance weight computation.
+    This selector ensures sequence boundaries are preserved in the batch data.
+
+    Args:
+        buffer: Buffer containing episodes
+        variant: GSPO variant ("sequence", "hybrid", or "token")
+
+    Returns:
+        Batch data organized for GSPO training with sequence length metadata
+    """
+    from .grpo import select_all_data
+
+    # Reuse GRPO's data selection but add sequence length tracking
+    batch_data = select_all_data(buffer)
+
+    # GSPO-specific enhancement: explicit sequence length tracking
+    # This ensures proper importance weight computation
+    if 'episode_lengths' in batch_data:
+        # Use episode lengths as sequence lengths for GSPO
+        batch_data['sequence_lengths'] = batch_data['episode_lengths']
+    else:
+        # Fallback: infer sequence lengths from batch structure
+        # This is less ideal but provides compatibility
+        total_tokens = len(batch_data['obs']) if 'obs' in batch_data else 0
+        num_episodes = len(batch_data['rewards']) if 'rewards' in batch_data else 1
+        avg_length = total_tokens // num_episodes if num_episodes > 0 else 0
+        batch_data['sequence_lengths'] = [avg_length] * num_episodes
+
+    return batch_data
+
+
+# Compiled versions for maximum performance
+@mx.compile
+def compute_sequence_weights_compiled(
+    old_logprobs: mx.array,
+    new_logprobs: mx.array,
+    seq_len: int
+) -> mx.array:
+    """Compiled version of sequence weight computation for a single sequence."""
+    old_seq_logprob = mx.sum(old_logprobs)
+    new_seq_logprob = mx.sum(new_logprobs)
+    log_ratio = new_seq_logprob - old_seq_logprob
+    normalized_log_ratio = log_ratio / seq_len
+    return mx.exp(normalized_log_ratio)
+
+
+@mx.compile
+def gspo_loss_compiled(
+    importance_weights: mx.array,
+    advantages: mx.array,
+    clip_ratio: float = 0.2
+) -> mx.array:
+    """Compiled version of GSPO surrogate loss computation."""
+    clipped_weights = mx.clip(importance_weights, 1 - clip_ratio, 1 + clip_ratio)
+    surr1 = importance_weights * advantages
+    surr2 = clipped_weights * advantages
+    return -mx.mean(mx.minimum(surr1, surr2))
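The new module sums token log-probabilities per sequence, divides the log-ratio by the sequence length, and clips before exponentiating. A minimal sketch (not part of the package) of how these functions might be exercised, assuming textpolicy 0.1.0 is installed; the dummy log-probabilities, advantages, and episode lengths below are illustrative stand-ins for real model outputs:

import mlx.core as mx
from textpolicy.algorithms import gspo

# Two episodes of lengths 3 and 2, flattened to five token log-probabilities.
old_lp = mx.array([-1.0, -0.8, -1.2, -0.5, -0.9])
new_lp = mx.array([-0.9, -0.7, -1.1, -0.6, -0.8])
advantages = mx.array([1.0, -1.0])   # one group-relative advantage per episode
seq_lens = [3, 2]

# Per-sequence weights: exp of the length-normalized, clipped log-ratio.
weights = gspo.compute_sequence_importance_weights(old_lp, new_lp, seq_lens)

# Clipped surrogate loss over the two sequences (pure GSPO variant).
loss = gspo.gspo_policy_loss(old_lp, new_lp, advantages, seq_lens, variant="sequence")
print(weights, loss)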
@@ -0,0 +1,23 @@
+# textpolicy/buffer/__init__.py
+"""
+Modular buffer system for TextPolicy.
+
+Main components:
+- Episode: Single episode trajectory management
+- Buffer: Multi-episode storage and sampling
+- BufferStorage: Storage and capacity management
+- BufferSampler: Data retrieval and sampling methods
+"""
+
+from .episode import Episode
+from .buffer import Buffer
+from .storage import BufferStorage
+from .sampling import BufferSampler
+
+# Backwards compatibility - maintain existing import structure
+__all__ = [
+    'Episode',
+    'Buffer',
+    'BufferStorage',
+    'BufferSampler',
+]
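The buffer package re-exports its four components at the package level. A small check (not from the package) that the re-exports resolve as declared, assuming the 0.1.0 wheel is installed:

import textpolicy.buffer as buffer

# The names listed in __all__ above should be importable from the package root.
print(buffer.__all__)   # ['Episode', 'Buffer', 'BufferStorage', 'BufferSampler']
assert all(hasattr(buffer, name) for name in buffer.__all__)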