textpolicy 0.0.1__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68)
  1. textpolicy/__init__.py +52 -0
  2. textpolicy/__main__.py +8 -0
  3. textpolicy/algorithms/__init__.py +54 -0
  4. textpolicy/algorithms/grpo.py +642 -0
  5. textpolicy/algorithms/gspo.py +582 -0
  6. textpolicy/buffer/__init__.py +23 -0
  7. textpolicy/buffer/buffer.py +244 -0
  8. textpolicy/buffer/episode.py +383 -0
  9. textpolicy/buffer/sampling.py +438 -0
  10. textpolicy/buffer/storage.py +255 -0
  11. textpolicy/cli.py +67 -0
  12. textpolicy/environment/__init__.py +79 -0
  13. textpolicy/environment/base.py +110 -0
  14. textpolicy/environment/environment.py +46 -0
  15. textpolicy/environment/factory.py +103 -0
  16. textpolicy/environment/gym.py +106 -0
  17. textpolicy/environment/task_suites.py +51 -0
  18. textpolicy/environment/text_generation.py +789 -0
  19. textpolicy/environment/vectorized.py +253 -0
  20. textpolicy/generation/__init__.py +62 -0
  21. textpolicy/generation/lora.py +411 -0
  22. textpolicy/generation/mlx_generation.py +557 -0
  23. textpolicy/generation/reload.py +253 -0
  24. textpolicy/rewards/__init__.py +137 -0
  25. textpolicy/rewards/adapters.py +387 -0
  26. textpolicy/rewards/basic.py +214 -0
  27. textpolicy/rewards/integrated_system.py +338 -0
  28. textpolicy/rewards/mlx_batch_processor.py +447 -0
  29. textpolicy/rewards/registry.py +293 -0
  30. textpolicy/rewards/rollout_rewards.py +410 -0
  31. textpolicy/rewards/verifiers.py +369 -0
  32. textpolicy/rollout/__init__.py +44 -0
  33. textpolicy/rollout/aggregator.py +145 -0
  34. textpolicy/rollout/base.py +108 -0
  35. textpolicy/rollout/rollout.py +142 -0
  36. textpolicy/rollout/runner.py +280 -0
  37. textpolicy/rollout/strategy.py +208 -0
  38. textpolicy/rollout/worker.py +194 -0
  39. textpolicy/training/__init__.py +14 -0
  40. textpolicy/training/metrics.py +242 -0
  41. textpolicy/training/rollout_manager.py +78 -0
  42. textpolicy/training/trainer.py +684 -0
  43. textpolicy/utils/__init__.py +40 -0
  44. textpolicy/utils/benchmarking.py +489 -0
  45. textpolicy/utils/data.py +60 -0
  46. textpolicy/utils/debug.py +170 -0
  47. textpolicy/utils/environment.py +349 -0
  48. textpolicy/utils/logging/__init__.py +22 -0
  49. textpolicy/utils/logging/base.py +48 -0
  50. textpolicy/utils/logging/console.py +61 -0
  51. textpolicy/utils/logging/factory.py +133 -0
  52. textpolicy/utils/logging/multi.py +83 -0
  53. textpolicy/utils/logging/tensorboard.py +65 -0
  54. textpolicy/utils/logging/wandb.py +72 -0
  55. textpolicy/utils/memory.py +118 -0
  56. textpolicy/utils/performance.py +464 -0
  57. textpolicy/utils/timing.py +171 -0
  58. textpolicy/validate.py +101 -0
  59. textpolicy/validation/__init__.py +13 -0
  60. textpolicy/validation/logprob_validation.py +315 -0
  61. textpolicy-0.1.0.dist-info/METADATA +99 -0
  62. textpolicy-0.1.0.dist-info/RECORD +66 -0
  63. textpolicy-0.1.0.dist-info/entry_points.txt +2 -0
  64. textpolicy-0.0.1.dist-info/METADATA +0 -10
  65. textpolicy-0.0.1.dist-info/RECORD +0 -6
  66. {textpolicy-0.0.1.dist-info → textpolicy-0.1.0.dist-info}/WHEEL +0 -0
  67. {textpolicy-0.0.1.dist-info → textpolicy-0.1.0.dist-info}/licenses/LICENSE +0 -0
  68. {textpolicy-0.0.1.dist-info → textpolicy-0.1.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,582 @@
+ # textpolicy/algorithms/gspo.py
+ """
+ Group Sequence Policy Optimization (GSPO).
+
+ GSPO computes importance weights at the sequence level to align with
+ sequence-level rewards. Variants include sequence, token, and hybrid forms.
+ Reference: https://swift.readthedocs.io/en/latest/Instruction/GRPO/AdvancedResearch/GSPO.html
+ """
+ from __future__ import annotations
+ import mlx.core as mx
+ from typing import List, Dict
+
+
+ def compute_sequence_importance_weights(
+     old_logprobs: mx.array,
+     new_logprobs: mx.array,
+     sequence_lengths: List[int],
+     clip_ratio: float = 0.2
+ ) -> mx.array:
+     """
+     Compute sequence-level importance weights for GSPO.
+
+     GSPO formula: w^GSPO_{i} = [π_θ(y_i | x) / π_θ_old(y_i | x)]^(1/|y_i|)
+
+     This normalizes by sequence length to prevent bias toward shorter/longer sequences.
+
+     Args:
+         old_logprobs: Log probabilities from rollout collection [batch_size, seq_len]
+         new_logprobs: Log probabilities from current policy [batch_size, seq_len]
+         sequence_lengths: Length of each sequence in the batch
+         clip_ratio: Clipping ratio applied to the length-normalized ratio
+
+     Returns:
+         Sequence-level importance weights [batch_size]
+
+     Compared to token-level sampling, this reduces variance and matches
+     sequence-level reward assignment.
+     """
+     sequence_weights = []
+
+     current_idx = 0
+     for seq_len in sequence_lengths:
+         # Extract logprobs for this sequence
+         seq_old_logprobs = old_logprobs[current_idx:current_idx + seq_len]
+         seq_new_logprobs = new_logprobs[current_idx:current_idx + seq_len]  # type: ignore
+
+         # Compute sequence-level log probability: sum of token log probs
+         old_seq_logprob = mx.sum(seq_old_logprobs)
+         new_seq_logprob = mx.sum(seq_new_logprobs)
+
+         # Sequence-level importance ratio: π_new(y|x) / π_old(y|x)
+         log_ratio = new_seq_logprob - old_seq_logprob
+
+         # GSPO normalization: raise to power 1/|y_i| to prevent length bias
+         # This ensures sequences of different lengths contribute equally
+         normalized_log_ratio = log_ratio / seq_len
+
+         # Clip in log space to prevent numerical explosion
+         # Without this clip, importance weights can blow up to billion scale
+         clipped_log_ratio = mx.clip(
+             normalized_log_ratio,
+             mx.log(mx.array(1 - clip_ratio)),  # log(0.8) ≈ -0.22
+             mx.log(mx.array(1 + clip_ratio))   # log(1.2) ≈ 0.18
+         )
+
+         # Now safely compute importance weight (will be in range [0.8, 1.2])
+         importance_weight = mx.exp(clipped_log_ratio)
+
+         # Final check: account for float32 precision
+         # exp(log(1.2)) in float32 may produce 1.2000000476837158; enforce exact bounds
+         importance_weight = mx.clip(importance_weight, 1.0 - clip_ratio, 1.0 + clip_ratio)
+
+         # Enforce exact bounds using scalar comparisons
+         # Convert to float for comparison to avoid MLX array comparison issues
+         weight_float = float(importance_weight)
+         if weight_float > 1.0 + clip_ratio:
+             importance_weight = mx.array(1.0 + clip_ratio)
+         elif weight_float < 1.0 - clip_ratio:
+             importance_weight = mx.array(1.0 - clip_ratio)
+
+         sequence_weights.append(importance_weight)
+         current_idx += seq_len
+
+     return mx.array(sequence_weights)
+
+
+ def compute_hybrid_importance_weights(
+     old_logprobs: mx.array,
+     new_logprobs: mx.array,
+     sequence_lengths: List[int],
+     alpha: float = 0.5,
+     beta: float = 0.5
+ ) -> mx.array:
+     """
+     Compute hybrid importance weights using a principled log-space combination.
+
+     Instead of multiplying exp(seq_ratio) * exp(token_ratio), which compounds variance,
+     this uses an additive combination: exp(α * seq_log_ratio + β * token_log_ratio).
+
+     This provides a more stable and theoretically sound approach to combining
+     sequence-level stability with token-level granularity.
+
+     Args:
+         old_logprobs: Log probabilities from rollout collection [total_tokens]
+         new_logprobs: Log probabilities from current policy [total_tokens]
+         sequence_lengths: Length of each sequence in the batch
+         alpha: Weight for sequence-level importance (default: 0.5)
+         beta: Weight for token-level importance (default: 0.5)
+
+     Returns:
+         Hybrid importance weights [total_tokens]
+
+     Advantages:
+         - Avoids explosive multiplication of exponentials
+         - Controlled variance through hyperparameter balance
+         - Principled combination in log-space
+     """
+     # Compute sequence-level log ratios (without exponential)
+     seq_log_ratios = []
+
+     current_idx = 0
+     for seq_len in sequence_lengths:
+         # Extract logprobs for this sequence
+         seq_old_logprobs = old_logprobs[current_idx:current_idx + seq_len]
+         seq_new_logprobs = new_logprobs[current_idx:current_idx + seq_len]  # type: ignore
+
+         # Compute sequence-level log probability: sum of token log probs
+         old_seq_logprob = mx.sum(seq_old_logprobs)
+         new_seq_logprob = mx.sum(seq_new_logprobs)
+
+         # Sequence-level log ratio with GSPO normalization (prevent length bias)
+         log_ratio = new_seq_logprob - old_seq_logprob
+         normalized_seq_log_ratio = log_ratio / seq_len
+
+         seq_log_ratios.append(normalized_seq_log_ratio)
+         current_idx += seq_len
+
+     # Expand sequence-level log ratios to token level
+     token_seq_log_ratios = []
+     for i, seq_len in enumerate(sequence_lengths):
+         # Stop gradient on the sequence-level term so gradients flow only through the token-level ratios
+         seq_log_ratio_sg = mx.stop_gradient(seq_log_ratios[i])
+         token_seq_log_ratios.extend([seq_log_ratio_sg] * seq_len)
+
+     token_seq_log_ratios = mx.array(token_seq_log_ratios)
+
+     # Compute token-level log ratios (with stop gradient on old logprobs)
+     old_logprobs_sg = mx.stop_gradient(old_logprobs)
+     token_log_ratios = new_logprobs - old_logprobs_sg
+
+     # Combine in log-space: α * seq_log_ratio + β * token_log_ratio
+     combined_log_ratios = alpha * token_seq_log_ratios + beta * token_log_ratios
+
+     # Apply single exponential to get final importance weights
+     hybrid_weights = mx.exp(combined_log_ratios)
+
+     return hybrid_weights
+
+
+ def gspo_policy_loss(
+     old_logprobs: mx.array,
+     new_logprobs: mx.array,
+     advantages: mx.array,
+     sequence_lengths: List[int],
+     variant: str = "sequence",
+     clip_ratio: float = 0.2,
+     alpha: float = 0.5,
+     beta: float = 0.5
+ ) -> mx.array:
+     """
+     GSPO policy loss with sequence-level importance sampling.
+
+     Args:
+         old_logprobs: Log probabilities from rollout collection
+         new_logprobs: Log probabilities from current policy
+         advantages: Group-relative advantages (computed the same way as in GRPO)
+         sequence_lengths: Length of each sequence in the batch
+         variant: "sequence" for pure GSPO, "hybrid" for GSPO-token, "token" for GRPO
+         clip_ratio: Clipping ratio for surrogate objective
+         alpha: Weight for sequence-level importance (used in hybrid variant)
+         beta: Weight for token-level importance (used in hybrid variant)
+
+     Returns:
+         Policy loss scalar (to be minimized)
+
+     Key innovation:
+         - Uses sequence-level importance weights instead of token-level
+         - Reduces gradient variance and improves training stability
+         - Better alignment with sequence-level reward signals
+     """
+     if variant == "sequence":
+         # Pure GSPO: sequence-level importance sampling
+         importance_weights = compute_sequence_importance_weights(
+             old_logprobs, new_logprobs, sequence_lengths, clip_ratio
+         )
+
+         # Advantages must align one-to-one with sequences
+         if len(advantages) != len(sequence_lengths):
+             raise ValueError(f"Advantages length {len(advantages)} doesn't match sequences {len(sequence_lengths)}")
+
+         # Apply PPO clipping to sequence-level weights
+         clipped_weights = mx.clip(importance_weights, 1 - clip_ratio, 1 + clip_ratio)
+
+         # Compute surrogate loss at sequence level
+         surr1 = importance_weights * advantages
+         surr2 = clipped_weights * advantages
+         loss = -mx.mean(mx.minimum(surr1, surr2))
+
+     elif variant == "hybrid":
+         # GSPO-token: hybrid sequence and token-level
+         importance_weights = compute_hybrid_importance_weights(
+             old_logprobs, new_logprobs, sequence_lengths, alpha=alpha, beta=beta
+         )
+
+         # Expand advantages to token level
+         token_advantages = []
+         for i, seq_len in enumerate(sequence_lengths):
+             token_advantages.extend([advantages[i]] * seq_len)
+         token_advantages = mx.array(token_advantages)
+
+         # Apply PPO clipping to hybrid weights
+         clipped_weights = mx.clip(importance_weights, 1 - clip_ratio, 1 + clip_ratio)
+
+         # Compute surrogate loss at token level
+         surr1 = importance_weights * token_advantages
+         surr2 = clipped_weights * token_advantages
+         loss = -mx.mean(mx.minimum(surr1, surr2))
+
+     elif variant == "token":
+         # Standard GRPO: token-level importance sampling (for comparison)
+         ratio = mx.exp(new_logprobs - old_logprobs)
+
+         # Expand advantages to token level
+         token_advantages = []
+         for i, seq_len in enumerate(sequence_lengths):
+             token_advantages.extend([advantages[i]] * seq_len)
+         token_advantages = mx.array(token_advantages)
+
+         # Apply PPO clipping
+         clipped_ratio = mx.clip(ratio, 1 - clip_ratio, 1 + clip_ratio)
+
+         # Compute surrogate loss
+         surr1 = ratio * token_advantages
+         surr2 = clipped_ratio * token_advantages
+         loss = -mx.mean(mx.minimum(surr1, surr2))
+
+     else:
+         raise ValueError(f"Unknown GSPO variant: {variant}. Choose 'sequence', 'hybrid', or 'token'")
+
+     return loss
+
+
+ def create_gspo_policy_loss(variant: str = "sequence", clip_ratio: float = 0.2, alpha: float = 0.5, beta: float = 0.5):
+     """
+     Factory function to create a GSPO policy loss function with the standard signature.
+
+     This follows the design guidelines for pure function composition with the universal Trainer.
+
+     Args:
+         variant: GSPO variant ("sequence", "hybrid", or "token")
+         clip_ratio: PPO clipping ratio for importance weights
+         alpha: Weight for sequence-level importance (used in hybrid variant)
+         beta: Weight for token-level importance (used in hybrid variant)
+
+     Returns:
+         Policy loss function with standard signature (old_logprobs, new_logprobs, advantages)
+
+     Usage:
+         trainer = Trainer(
+             model=model,
+             advantage_fn=grpo.compute_advantages_dr_grpo,
+             loss_fn=gspo.create_gspo_policy_loss(variant="sequence"),
+             optimizer=optimizer
+         )
+     """
+     def gspo_policy_loss_fn(old_logprobs: mx.array, new_logprobs: mx.array, advantages: mx.array) -> mx.array:
+         """
+         GSPO policy loss with sequence-level importance sampling.
+
+         Standard signature for use with the universal Trainer.
+         """
+         # GSPO needs sequence lengths, which the standard signature does not provide;
+         # ideally batch_data would carry sequence_lengths information.
+         # For now, use a fallback approach for compatibility.
+
+         # Robust fallback: distribute tokens as evenly as possible across episodes
+         # This handles variable-length sequences by distributing remainder tokens
+         total_tokens = len(old_logprobs) if len(old_logprobs.shape) == 1 else old_logprobs.shape[0]  # type: ignore
+         num_episodes = len(advantages)
+
+         if num_episodes > 0:
+             base_length = total_tokens // num_episodes
+             remainder = total_tokens % num_episodes
+             # Distribute remainder tokens to the first 'remainder' episodes
+             sequence_lengths = [base_length + (1 if i < remainder else 0) for i in range(num_episodes)]
+         else:
+             sequence_lengths = [total_tokens] if total_tokens > 0 else [1]
+
+         return gspo_policy_loss(
+             old_logprobs=old_logprobs,
+             new_logprobs=new_logprobs,
+             advantages=advantages,
+             sequence_lengths=sequence_lengths,
+             variant=variant,
+             clip_ratio=clip_ratio,
+             alpha=alpha,
+             beta=beta
+         )
+
+     return gspo_policy_loss_fn
+
+
+ def create_gspo_metrics(variant: str = "sequence", clip_ratio: float = 0.2):
+     """
+     Factory function to create GSPO metrics function with standard signature.
+
+     Args:
+         variant: GSPO variant being used
+         clip_ratio: Clipping ratio used in loss
+
+     Returns:
+         Metrics function with standard signature
+
+     Usage:
+         trainer = Trainer(
+             model=model,
+             advantage_fn=grpo.compute_advantages_dr_grpo,
+             loss_fn=gspo.create_gspo_policy_loss(variant="sequence"),
+             metrics_fn=gspo.create_gspo_metrics(variant="sequence"),
+             optimizer=optimizer
+         )
+     """
+     def gspo_metrics_fn(old_logprobs: mx.array, new_logprobs: mx.array, advantages: mx.array) -> Dict[str, float]:
+         """GSPO metrics with sequence-level importance weight tracking."""
+         # Robust fallback: distribute tokens as evenly as possible across episodes
+         # This matches the same robust approach used in the policy loss function
+         total_tokens = len(old_logprobs) if len(old_logprobs.shape) == 1 else old_logprobs.shape[0]  # type: ignore
+         num_episodes = len(advantages)
+
+         if num_episodes > 0:
+             base_length = total_tokens // num_episodes
+             remainder = total_tokens % num_episodes
+             # Distribute remainder tokens to first 'remainder' episodes
+             sequence_lengths = [base_length + (1 if i < remainder else 0) for i in range(num_episodes)]
+         else:
+             sequence_lengths = [total_tokens] if total_tokens > 0 else [1]
+
+         return compute_gspo_metrics(
+             old_logprobs=old_logprobs,
+             new_logprobs=new_logprobs,
+             advantages=advantages,
+             sequence_lengths=sequence_lengths,
+             variant=variant,
+             clip_ratio=clip_ratio
+         )
+
+     return gspo_metrics_fn
+
+
+ # Convenience functions that match GRPO interface
+ def policy_loss_sequence(old_logprobs: mx.array, new_logprobs: mx.array, advantages: mx.array) -> mx.array:
+     """GSPO sequence-level policy loss function (standard signature)."""
+     return create_gspo_policy_loss(variant="sequence")(old_logprobs, new_logprobs, advantages)
+
+
+ def policy_loss_hybrid(old_logprobs: mx.array, new_logprobs: mx.array, advantages: mx.array) -> mx.array:
+     """GSPO hybrid policy loss function (standard signature)."""
+     return create_gspo_policy_loss(variant="hybrid")(old_logprobs, new_logprobs, advantages)
+
+
+ def create_policy_loss_hybrid(alpha: float = 0.5, beta: float = 0.5):
+     """
+     Create a GSPO hybrid policy loss function with configurable hyperparameters.
+
+     Args:
+         alpha: Weight for sequence-level importance (0.0 = pure token-level, 1.0 = pure sequence-level)
+         beta: Weight for token-level importance (0.0 = ignore token-level, 1.0 = full token-level)
+
+     Returns:
+         Policy loss function with standard signature
+
+     Example:
+         # Balanced hybrid (default)
+         loss_fn = create_policy_loss_hybrid(alpha=0.5, beta=0.5)
+
+         # More sequence-focused
+         loss_fn = create_policy_loss_hybrid(alpha=0.7, beta=0.3)
+
+         # More token-focused
+         loss_fn = create_policy_loss_hybrid(alpha=0.3, beta=0.7)
+     """
+     def hybrid_loss_fn(old_logprobs: mx.array, new_logprobs: mx.array, advantages: mx.array) -> mx.array:
+         # Use custom alpha/beta parameters for this specific loss function
+         return create_gspo_policy_loss(variant="hybrid", alpha=alpha, beta=beta)(
+             old_logprobs, new_logprobs, advantages
+         )
+     return hybrid_loss_fn
+
+
+ def policy_loss_token(old_logprobs: mx.array, new_logprobs: mx.array, advantages: mx.array) -> mx.array:
+     """GSPO token-level policy loss function (standard signature) - equivalent to GRPO."""
+     return create_gspo_policy_loss(variant="token")(old_logprobs, new_logprobs, advantages)
+
+
+ def compute_metrics_sequence(old_logprobs: mx.array, new_logprobs: mx.array, advantages: mx.array) -> Dict[str, float]:
+     """GSPO sequence-level metrics function (standard signature)."""
+     return create_gspo_metrics(variant="sequence")(old_logprobs, new_logprobs, advantages)
+
+
+ def compute_metrics_hybrid(old_logprobs: mx.array, new_logprobs: mx.array, advantages: mx.array) -> Dict[str, float]:
+     """GSPO hybrid metrics function (standard signature)."""
+     return create_gspo_metrics(variant="hybrid")(old_logprobs, new_logprobs, advantages)
+
+
+ def compute_metrics_token(old_logprobs: mx.array, new_logprobs: mx.array, advantages: mx.array) -> Dict[str, float]:
+     """GSPO token-level metrics function (standard signature)."""
+     return create_gspo_metrics(variant="token")(old_logprobs, new_logprobs, advantages)
+
+
+ def compute_gspo_metrics(
+     old_logprobs: mx.array,
+     new_logprobs: mx.array,
+     advantages: mx.array,
+     sequence_lengths: List[int],
+     variant: str = "sequence",
+     clip_ratio: float = 0.2
+ ) -> Dict[str, float]:
+     """
+     Compute GSPO training metrics for monitoring.
+
+     Args:
+         old_logprobs: Log probabilities from rollout
+         new_logprobs: Log probabilities from current policy
+         advantages: Group-relative advantages
+         sequence_lengths: Length of each sequence in the batch
+         variant: GSPO variant being used
+         clip_ratio: Clipping ratio used in loss
+
+     Returns:
+         Dictionary of metrics for logging/monitoring
+
+     Additional GSPO-specific metrics:
+         - Sequence-level importance weight statistics
+         - Clipping fractions
+         - Length bias indicators
+     """
+     # Standard advantage metrics
+     metrics = {
+         'mean_advantage': mx.mean(advantages).item(),
+         'std_advantage': mx.std(advantages).item(),
+         'min_advantage': mx.min(advantages).item(),
+         'max_advantage': mx.max(advantages).item()
+     }
+
+     if variant == "sequence":
+         # Sequence-level importance weights
+         seq_weights = compute_sequence_importance_weights(
+             old_logprobs, new_logprobs, sequence_lengths, clip_ratio
+         )
+
+         # Sequence weight statistics
+         metrics.update({
+             'mean_seq_weight': mx.mean(seq_weights).item(),
+             'std_seq_weight': mx.std(seq_weights).item(),
+             'max_seq_weight': mx.max(seq_weights).item(),
+             'min_seq_weight': mx.min(seq_weights).item()
+         })
+
+         # Clipping statistics at sequence level
+         clipped = (seq_weights < (1 - clip_ratio)) | (seq_weights > (1 + clip_ratio))
+         metrics['seq_clip_fraction'] = mx.mean(clipped.astype(mx.float32)).item()
+
+     elif variant == "hybrid":
+         # Hybrid importance weights
+         hybrid_weights = compute_hybrid_importance_weights(
+             old_logprobs, new_logprobs, sequence_lengths
+         )
+
+         # Hybrid weight statistics
+         metrics.update({
+             'mean_hybrid_weight': mx.mean(hybrid_weights).item(),
+             'std_hybrid_weight': mx.std(hybrid_weights).item(),
+             'max_hybrid_weight': mx.max(hybrid_weights).item(),
+             'min_hybrid_weight': mx.min(hybrid_weights).item()
+         })
+
+         # Clipping statistics at token level
+         clipped = (hybrid_weights < (1 - clip_ratio)) | (hybrid_weights > (1 + clip_ratio))
+         metrics['hybrid_clip_fraction'] = mx.mean(clipped.astype(mx.float32)).item()
+
+     else:  # token-level (standard GRPO)
+         # Token-level importance ratios
+         ratio = mx.exp(new_logprobs - old_logprobs)
+
+         metrics.update({
+             'mean_token_ratio': mx.mean(ratio).item(),
+             'std_token_ratio': mx.std(ratio).item(),
+             'max_token_ratio': mx.max(ratio).item(),
+             'min_token_ratio': mx.min(ratio).item()
+         })
+
+         # Clipping statistics at token level
+         clipped = (ratio < (1 - clip_ratio)) | (ratio > (1 + clip_ratio))
+         metrics['token_clip_fraction'] = mx.mean(clipped.astype(mx.float32)).item()
+
+     # Length bias analysis
+     if len(sequence_lengths) > 1:
+         length_array = mx.array(sequence_lengths, dtype=mx.float32)
+         metrics.update({
+             'mean_seq_length': mx.mean(length_array).item(),
+             'std_seq_length': mx.std(length_array).item(),
+             'min_seq_length': mx.min(length_array).item(),
+             'max_seq_length': mx.max(length_array).item()
+         })
+
+     # KL divergence approximation
+     kl_div = mx.mean(old_logprobs - new_logprobs)
+     metrics['kl_divergence'] = kl_div.item()
+
+     return metrics
+
+
+ # Algorithm-specific data selectors for GSPO
+ def select_gspo_data(buffer, variant: str = "sequence"):
+     """
+     GSPO data selector: Use all available data with sequence-level organization.
+
+     GSPO requires sequence length information for proper importance weight computation.
+     This selector ensures sequence boundaries are preserved in the batch data.
+
+     Args:
+         buffer: Buffer containing episodes
+         variant: GSPO variant ("sequence", "hybrid", or "token")
+
+     Returns:
+         Batch data organized for GSPO training with sequence length metadata
+     """
+     from .grpo import select_all_data
+
+     # Reuse GRPO's data selection but add sequence length tracking
+     batch_data = select_all_data(buffer)
+
+     # GSPO-specific enhancement: explicit sequence length tracking
+     # This ensures proper importance weight computation
+     if 'episode_lengths' in batch_data:
+         # Use episode lengths as sequence lengths for GSPO
+         batch_data['sequence_lengths'] = batch_data['episode_lengths']
+     else:
+         # Fallback: infer sequence lengths from batch structure
+         # This is less ideal but provides compatibility
+         total_tokens = len(batch_data['obs']) if 'obs' in batch_data else 0
+         num_episodes = len(batch_data['rewards']) if 'rewards' in batch_data else 1
+         avg_length = total_tokens // num_episodes if num_episodes > 0 else 0
+         batch_data['sequence_lengths'] = [avg_length] * num_episodes
+
+     return batch_data
+
+
+ # Compiled versions for maximum performance
+ @mx.compile
+ def compute_sequence_weights_compiled(
+     old_logprobs: mx.array,
+     new_logprobs: mx.array,
+     seq_len: int
+ ) -> mx.array:
+     """Compiled version of sequence weight computation for a single sequence."""
+     old_seq_logprob = mx.sum(old_logprobs)
+     new_seq_logprob = mx.sum(new_logprobs)
+     log_ratio = new_seq_logprob - old_seq_logprob
+     normalized_log_ratio = log_ratio / seq_len
+     return mx.exp(normalized_log_ratio)
+
+
+ @mx.compile
+ def gspo_loss_compiled(
+     importance_weights: mx.array,
+     advantages: mx.array,
+     clip_ratio: float = 0.2
+ ) -> mx.array:
+     """Compiled version of GSPO surrogate loss computation."""
+     clipped_weights = mx.clip(importance_weights, 1 - clip_ratio, 1 + clip_ratio)
+     surr1 = importance_weights * advantages
+     surr2 = clipped_weights * advantages
+     return -mx.mean(mx.minimum(surr1, surr2))
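To make the sequence-level weighting concrete, here is a short usage sketch. It is not part of the diffed wheel: the logprob values are invented for illustration, and the commented-out Trainer wiring simply mirrors the usage shown in the create_gspo_policy_loss / create_gspo_metrics docstrings above, with model and optimizer left as placeholders.

import mlx.core as mx
from textpolicy.algorithms import gspo

# One episode of four generated tokens: per-token logprobs under the rollout
# policy (old) and the current policy (new). Values are made up.
old_lp = mx.array([-1.20, -0.80, -2.10, -0.50])  # sums to -4.60
new_lp = mx.array([-1.00, -0.70, -2.00, -0.50])  # sums to -4.20

# Length-normalized sequence weight per the GSPO formula:
# exp((sum(new) - sum(old)) / |y|) = exp(0.40 / 4) ≈ 1.105,
# inside the default clip range [0.8, 1.2], so clipping leaves it unchanged.
weights = gspo.compute_sequence_importance_weights(old_lp, new_lp, sequence_lengths=[4])
print(float(weights[0]))  # ≈ 1.105

# Wiring into the universal Trainer, as shown in the factory docstrings
# (model, optimizer, and grpo.compute_advantages_dr_grpo come from the rest of the package):
# trainer = Trainer(
#     model=model,
#     advantage_fn=grpo.compute_advantages_dr_grpo,
#     loss_fn=gspo.create_gspo_policy_loss(variant="sequence"),
#     metrics_fn=gspo.create_gspo_metrics(variant="sequence"),
#     optimizer=optimizer,
# )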
@@ -0,0 +1,23 @@
+ # textpolicy/buffer/__init__.py
+ """
+ Modular buffer system for TextPolicy.
+
+ Main components:
+ - Episode: Single episode trajectory management
+ - Buffer: Multi-episode storage and sampling
+ - BufferStorage: Storage and capacity management
+ - BufferSampler: Data retrieval and sampling methods
+ """
+
+ from .episode import Episode
+ from .buffer import Buffer
+ from .storage import BufferStorage
+ from .sampling import BufferSampler
+
+ # Backwards compatibility - maintain existing import structure
+ __all__ = [
+     'Episode',
+     'Buffer',
+     'BufferStorage',
+     'BufferSampler',
+ ]
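As a quick check of the re-exported surface above (again an editorial sketch, not wheel content): only the four names listed in __all__ are assumed, since the constructor signatures live in episode.py, buffer.py, storage.py, and sampling.py and are not shown in this hunk.

from textpolicy.buffer import Episode, Buffer, BufferStorage, BufferSampler

# The __init__ re-exports these classes so they remain importable directly from
# textpolicy.buffer, per the backwards-compatibility comment in the module.
print([cls.__name__ for cls in (Episode, Buffer, BufferStorage, BufferSampler)])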