textpolicy 0.1.1__tar.gz → 0.1.2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (80)
  1. {textpolicy-0.1.1/textpolicy.egg-info → textpolicy-0.1.2}/PKG-INFO +3 -3
  2. {textpolicy-0.1.1 → textpolicy-0.1.2}/pyproject.toml +4 -4
  3. textpolicy-0.1.2/tests/test_mlx_compatibility.py +315 -0
  4. textpolicy-0.1.2/tests/test_training_pipeline.py +240 -0
  5. {textpolicy-0.1.1 → textpolicy-0.1.2/textpolicy.egg-info}/PKG-INFO +3 -3
  6. {textpolicy-0.1.1 → textpolicy-0.1.2}/textpolicy.egg-info/SOURCES.txt +2 -0
  7. {textpolicy-0.1.1 → textpolicy-0.1.2}/textpolicy.egg-info/requires.txt +2 -2
  8. {textpolicy-0.1.1 → textpolicy-0.1.2}/LICENSE +0 -0
  9. {textpolicy-0.1.1 → textpolicy-0.1.2}/README.md +0 -0
  10. {textpolicy-0.1.1 → textpolicy-0.1.2}/setup.cfg +0 -0
  11. {textpolicy-0.1.1 → textpolicy-0.1.2}/tests/test_gspo_verification.py +0 -0
  12. {textpolicy-0.1.1 → textpolicy-0.1.2}/tests/test_integration_e2e_training.py +0 -0
  13. {textpolicy-0.1.1 → textpolicy-0.1.2}/tests/test_issue_fixes.py +0 -0
  14. {textpolicy-0.1.1 → textpolicy-0.1.2}/tests/test_reward_signatures.py +0 -0
  15. {textpolicy-0.1.1 → textpolicy-0.1.2}/tests/test_rollout_rewards.py +0 -0
  16. {textpolicy-0.1.1 → textpolicy-0.1.2}/tests/test_runner_step_enforcement.py +0 -0
  17. {textpolicy-0.1.1 → textpolicy-0.1.2}/tests/test_validate_installation.py +0 -0
  18. {textpolicy-0.1.1 → textpolicy-0.1.2}/textpolicy/__init__.py +0 -0
  19. {textpolicy-0.1.1 → textpolicy-0.1.2}/textpolicy/__main__.py +0 -0
  20. {textpolicy-0.1.1 → textpolicy-0.1.2}/textpolicy/algorithms/__init__.py +0 -0
  21. {textpolicy-0.1.1 → textpolicy-0.1.2}/textpolicy/algorithms/grpo.py +0 -0
  22. {textpolicy-0.1.1 → textpolicy-0.1.2}/textpolicy/algorithms/gspo.py +0 -0
  23. {textpolicy-0.1.1 → textpolicy-0.1.2}/textpolicy/buffer/__init__.py +0 -0
  24. {textpolicy-0.1.1 → textpolicy-0.1.2}/textpolicy/buffer/buffer.py +0 -0
  25. {textpolicy-0.1.1 → textpolicy-0.1.2}/textpolicy/buffer/episode.py +0 -0
  26. {textpolicy-0.1.1 → textpolicy-0.1.2}/textpolicy/buffer/sampling.py +0 -0
  27. {textpolicy-0.1.1 → textpolicy-0.1.2}/textpolicy/buffer/storage.py +0 -0
  28. {textpolicy-0.1.1 → textpolicy-0.1.2}/textpolicy/cli.py +0 -0
  29. {textpolicy-0.1.1 → textpolicy-0.1.2}/textpolicy/environment/__init__.py +0 -0
  30. {textpolicy-0.1.1 → textpolicy-0.1.2}/textpolicy/environment/base.py +0 -0
  31. {textpolicy-0.1.1 → textpolicy-0.1.2}/textpolicy/environment/environment.py +0 -0
  32. {textpolicy-0.1.1 → textpolicy-0.1.2}/textpolicy/environment/factory.py +0 -0
  33. {textpolicy-0.1.1 → textpolicy-0.1.2}/textpolicy/environment/gym.py +0 -0
  34. {textpolicy-0.1.1 → textpolicy-0.1.2}/textpolicy/environment/task_suites.py +0 -0
  35. {textpolicy-0.1.1 → textpolicy-0.1.2}/textpolicy/environment/text_generation.py +0 -0
  36. {textpolicy-0.1.1 → textpolicy-0.1.2}/textpolicy/environment/vectorized.py +0 -0
  37. {textpolicy-0.1.1 → textpolicy-0.1.2}/textpolicy/generation/__init__.py +0 -0
  38. {textpolicy-0.1.1 → textpolicy-0.1.2}/textpolicy/generation/lora.py +0 -0
  39. {textpolicy-0.1.1 → textpolicy-0.1.2}/textpolicy/generation/mlx_generation.py +0 -0
  40. {textpolicy-0.1.1 → textpolicy-0.1.2}/textpolicy/generation/reload.py +0 -0
  41. {textpolicy-0.1.1 → textpolicy-0.1.2}/textpolicy/rewards/__init__.py +0 -0
  42. {textpolicy-0.1.1 → textpolicy-0.1.2}/textpolicy/rewards/adapters.py +0 -0
  43. {textpolicy-0.1.1 → textpolicy-0.1.2}/textpolicy/rewards/basic.py +0 -0
  44. {textpolicy-0.1.1 → textpolicy-0.1.2}/textpolicy/rewards/integrated_system.py +0 -0
  45. {textpolicy-0.1.1 → textpolicy-0.1.2}/textpolicy/rewards/mlx_batch_processor.py +0 -0
  46. {textpolicy-0.1.1 → textpolicy-0.1.2}/textpolicy/rewards/registry.py +0 -0
  47. {textpolicy-0.1.1 → textpolicy-0.1.2}/textpolicy/rewards/rollout_rewards.py +0 -0
  48. {textpolicy-0.1.1 → textpolicy-0.1.2}/textpolicy/rewards/verifiers.py +0 -0
  49. {textpolicy-0.1.1 → textpolicy-0.1.2}/textpolicy/rollout/__init__.py +0 -0
  50. {textpolicy-0.1.1 → textpolicy-0.1.2}/textpolicy/rollout/aggregator.py +0 -0
  51. {textpolicy-0.1.1 → textpolicy-0.1.2}/textpolicy/rollout/base.py +0 -0
  52. {textpolicy-0.1.1 → textpolicy-0.1.2}/textpolicy/rollout/rollout.py +0 -0
  53. {textpolicy-0.1.1 → textpolicy-0.1.2}/textpolicy/rollout/runner.py +0 -0
  54. {textpolicy-0.1.1 → textpolicy-0.1.2}/textpolicy/rollout/strategy.py +0 -0
  55. {textpolicy-0.1.1 → textpolicy-0.1.2}/textpolicy/rollout/worker.py +0 -0
  56. {textpolicy-0.1.1 → textpolicy-0.1.2}/textpolicy/training/__init__.py +0 -0
  57. {textpolicy-0.1.1 → textpolicy-0.1.2}/textpolicy/training/metrics.py +0 -0
  58. {textpolicy-0.1.1 → textpolicy-0.1.2}/textpolicy/training/rollout_manager.py +0 -0
  59. {textpolicy-0.1.1 → textpolicy-0.1.2}/textpolicy/training/trainer.py +0 -0
  60. {textpolicy-0.1.1 → textpolicy-0.1.2}/textpolicy/utils/__init__.py +0 -0
  61. {textpolicy-0.1.1 → textpolicy-0.1.2}/textpolicy/utils/benchmarking.py +0 -0
  62. {textpolicy-0.1.1 → textpolicy-0.1.2}/textpolicy/utils/data.py +0 -0
  63. {textpolicy-0.1.1 → textpolicy-0.1.2}/textpolicy/utils/debug.py +0 -0
  64. {textpolicy-0.1.1 → textpolicy-0.1.2}/textpolicy/utils/environment.py +0 -0
  65. {textpolicy-0.1.1 → textpolicy-0.1.2}/textpolicy/utils/logging/__init__.py +0 -0
  66. {textpolicy-0.1.1 → textpolicy-0.1.2}/textpolicy/utils/logging/base.py +0 -0
  67. {textpolicy-0.1.1 → textpolicy-0.1.2}/textpolicy/utils/logging/console.py +0 -0
  68. {textpolicy-0.1.1 → textpolicy-0.1.2}/textpolicy/utils/logging/factory.py +0 -0
  69. {textpolicy-0.1.1 → textpolicy-0.1.2}/textpolicy/utils/logging/multi.py +0 -0
  70. {textpolicy-0.1.1 → textpolicy-0.1.2}/textpolicy/utils/logging/tensorboard.py +0 -0
  71. {textpolicy-0.1.1 → textpolicy-0.1.2}/textpolicy/utils/logging/wandb.py +0 -0
  72. {textpolicy-0.1.1 → textpolicy-0.1.2}/textpolicy/utils/memory.py +0 -0
  73. {textpolicy-0.1.1 → textpolicy-0.1.2}/textpolicy/utils/performance.py +0 -0
  74. {textpolicy-0.1.1 → textpolicy-0.1.2}/textpolicy/utils/timing.py +0 -0
  75. {textpolicy-0.1.1 → textpolicy-0.1.2}/textpolicy/validate.py +0 -0
  76. {textpolicy-0.1.1 → textpolicy-0.1.2}/textpolicy/validation/__init__.py +0 -0
  77. {textpolicy-0.1.1 → textpolicy-0.1.2}/textpolicy/validation/logprob_validation.py +0 -0
  78. {textpolicy-0.1.1 → textpolicy-0.1.2}/textpolicy.egg-info/dependency_links.txt +0 -0
  79. {textpolicy-0.1.1 → textpolicy-0.1.2}/textpolicy.egg-info/entry_points.txt +0 -0
  80. {textpolicy-0.1.1 → textpolicy-0.1.2}/textpolicy.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: textpolicy
- Version: 0.1.1
+ Version: 0.1.2
  Summary: Reinforcement learning for text generation on MLX (Apple Silicon): GRPO/GSPO, environments, rollout, rewards, LoRA/QLoRA
  Project-URL: Homepage, https://github.com/teilomillet/textpolicy
  Project-URL: Repository, https://github.com/teilomillet/textpolicy
@@ -16,8 +16,8 @@ Requires-Python: >=3.12
  Description-Content-Type: text/markdown
  License-File: LICENSE
  Requires-Dist: numpy>=2.3.2
- Requires-Dist: mlx>=0.21.0
- Requires-Dist: mlx-lm>=0.21.0
+ Requires-Dist: mlx>=0.22.0
+ Requires-Dist: mlx-lm>=0.22.0
  Requires-Dist: gymnasium>=0.29.0
  Requires-Dist: psutil>=7.0.0
  Requires-Dist: wandb>=0.21.1
@@ -7,15 +7,15 @@ include = ["textpolicy*"]

  [project]
  name = "textpolicy"
- version = "0.1.1"
+ version = "0.1.2"
  description = "Reinforcement learning for text generation on MLX (Apple Silicon): GRPO/GSPO, environments, rollout, rewards, LoRA/QLoRA"
  readme = "README.md"
  requires-python = ">=3.12"
  dependencies = [
  "numpy>=2.3.2",
- "mlx>=0.21.0", # Core MLX framework for Apple Silicon acceleration
- "mlx-lm>=0.21.0", # MLX language models for inference
- "gymnasium>=0.29.0",
+ "mlx>=0.22.0", # Core MLX framework for Apple Silicon acceleration (tested up to 0.30.5)
+ "mlx-lm>=0.22.0", # MLX language models for inference (tested up to 0.30.6)
+ "gymnasium>=0.29.0",
  "psutil>=7.0.0",
  "wandb>=0.21.1",
  "aiohttp>=3.12.15",
@@ -0,0 +1,315 @@
+ """
+ MLX and mlx-lm Compatibility Tests
+
+ These tests verify that textpolicy works correctly with MLX 0.30.x and mlx-lm 0.30.x.
+ They cover the core APIs used throughout the codebase.
+
+ Run with: pytest tests/test_mlx_compatibility.py -v
+ """
+
+ import pytest
+ import mlx.core as mx
+ import mlx.nn as nn
+
+
+ @pytest.mark.unit
+ class TestMLXCoreAPIs:
+     """Test core MLX APIs used in textpolicy."""
+
+     def test_mx_compile_decorator(self):
+         """Test @mx.compile decorator works (used in grpo.py, gspo.py, trainer.py)."""
+         @mx.compile
+         def compiled_fn(x, y):
+             return x + y
+
+         result = compiled_fn(mx.array([1.0]), mx.array([2.0]))
+         assert float(result[0]) == 3.0
+
+     def test_mx_compile_with_value_and_grad(self):
+         """Test mx.compile with nn.value_and_grad (used in trainer.py:90)."""
+         class SimpleModel(nn.Module):
+             def __init__(self):
+                 super().__init__()
+                 self.linear = nn.Linear(4, 2)
+             def __call__(self, x):
+                 return self.linear(x)
+
+         model = SimpleModel()
+         mx.eval(model.parameters())
+
+         def loss_fn(model, x, y):
+             pred = model(x)
+             return mx.mean((pred - y) ** 2)
+
+         loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
+         x = mx.random.normal((2, 4))
+         y = mx.random.normal((2, 2))
+         loss, grads = loss_and_grad_fn(model, x, y)
+
+         assert not mx.isnan(loss)
+         assert isinstance(grads, dict)
+
+     def test_array_operations(self):
+         """Test array operations used in GRPO/GSPO algorithms."""
+         # Ratio computation (grpo.py, gspo.py)
+         old_lp = mx.array([-1.0, -1.2, -0.8])
+         new_lp = mx.array([-1.1, -1.0, -0.9])
+         ratios = mx.exp(new_lp - old_lp)
+
+         assert ratios.shape == (3,)
+         assert all(not mx.isnan(r) for r in ratios)
+
+         # Clipping (PPO-style)
+         clipped = mx.clip(ratios, 0.8, 1.2)
+         # Use small epsilon for float comparison
+         assert all(0.8 - 1e-6 <= float(c) <= 1.2 + 1e-6 for c in clipped)
+
+     def test_array_slicing_with_python_ints(self):
+         """Test array slicing with Python integers (used in gspo.py)."""
+         arr = mx.array([1.0, 2.0, 3.0, 4.0, 5.0])
+
+         # This pattern is used in compute_sequence_importance_weights
+         current_idx = 0
+         for seq_len in [2, 3]:
+             result = arr[current_idx:current_idx + seq_len]
+             assert result.shape[0] == seq_len
+             current_idx += seq_len
+
+
+ @pytest.mark.unit
+ class TestMLXLMAPIs:
+     """Test mlx-lm APIs used in textpolicy."""
+
+     def test_mlx_lm_imports(self):
+         """Test mlx_lm core imports (used in mlx_generation.py)."""
+         from mlx_lm import load, generate
+         assert callable(load)
+         assert callable(generate)
+
+     def test_sample_utils_imports(self):
+         """Test sample_utils imports (used in mlx_generation.py:27)."""
+         from mlx_lm.sample_utils import make_sampler, make_logits_processors
+         assert callable(make_sampler)
+         assert callable(make_logits_processors)
+
+     def test_make_sampler_signature(self):
+         """Test make_sampler accepts expected parameters."""
+         from mlx_lm.sample_utils import make_sampler
+
+         # These params are used in mlx_generation.py:77-81
+         sampler = make_sampler(
+             temp=0.7,
+             top_p=0.9,
+             min_p=0.0,
+             min_tokens_to_keep=2
+         )
+         assert sampler is not None
+
+     def test_quantize_model_signature(self):
+         """Test quantize_model has expected parameters (used in lora.py:335)."""
+         from mlx_lm.utils import quantize_model
+         import inspect
+
+         sig = inspect.signature(quantize_model)
+         params = list(sig.parameters.keys())
+
+         # These params are used in lora.py:342-348
+         expected = ['model', 'config', 'q_group_size', 'q_bits', 'quant_predicate']
+         for p in expected:
+             assert p in params, f"Expected param '{p}' not found in quantize_model"
+
+
+ @pytest.mark.unit
+ @pytest.mark.algorithm
+ class TestGRPOWithMLX:
+     """Test GRPO algorithms with current MLX version."""
+
+     def test_compute_advantages(self):
+         """Test GRPO compute_advantages."""
+         from textpolicy.algorithms import grpo
+
+         rewards = mx.array([1.0, 0.5, -0.5, 0.8, 0.2])
+         advantages = grpo.compute_advantages(rewards)
+
+         assert advantages.shape == rewards.shape
+         # Group-relative: mean should be ~0
+         assert abs(float(mx.mean(advantages))) < 1e-5
+
+     def test_compute_advantages_compiled(self):
+         """Test compiled version of compute_advantages."""
+         from textpolicy.algorithms import grpo
+
+         rewards = mx.array([1.0, 0.5, -0.5, 0.8])
+         advantages = grpo.compute_advantages_compiled(rewards)
+
+         assert advantages.shape == rewards.shape
+
+     def test_policy_loss(self):
+         """Test GRPO policy_loss computation."""
+         from textpolicy.algorithms import grpo
+
+         old_lp = mx.array([-1.0, -1.2, -0.8, -1.1, -0.9])
+         new_lp = mx.array([-1.1, -1.0, -0.9, -1.0, -1.0])
+         advantages = mx.array([0.6, 0.1, -0.9, 0.4, -0.2])
+
+         loss = grpo.policy_loss(old_lp, new_lp, advantages, clip_ratio=0.2)
+
+         assert not mx.isnan(loss)
+         assert loss.shape == ()  # Scalar
+
+     def test_entropy_bonus(self):
+         """Test entropy bonus computation."""
+         from textpolicy.algorithms import grpo
+
+         logprobs = mx.array([-1.0, -2.0, -0.5, -1.5])
+         entropy = grpo.entropy_bonus(logprobs, coefficient=0.01)
+
+         assert not mx.isnan(entropy)
+
+
+ @pytest.mark.unit
+ @pytest.mark.algorithm
+ class TestGSPOWithMLX:
+     """Test GSPO algorithms with current MLX version."""
+
+     def test_compute_sequence_importance_weights(self):
+         """Test sequence-level importance weights."""
+         from textpolicy.algorithms import gspo
+
+         old_lp = mx.array([-1.0, -1.2, -0.8, -1.1, -0.9])
+         new_lp = mx.array([-1.1, -1.0, -0.9, -1.0, -1.0])
+         seq_lens = [2, 3]
+
+         weights = gspo.compute_sequence_importance_weights(
+             old_lp, new_lp, seq_lens, clip_ratio=0.2
+         )
+
+         assert len(weights) == len(seq_lens)
+         assert all(not mx.isnan(w) for w in weights)
+
+     def test_compute_hybrid_importance_weights(self):
+         """Test hybrid importance weights."""
+         from textpolicy.algorithms import gspo
+
+         old_lp = mx.array([-1.0, -1.2, -0.8, -1.1, -0.9])
+         new_lp = mx.array([-1.1, -1.0, -0.9, -1.0, -1.0])
+         seq_lens = [2, 3]
+
+         weights = gspo.compute_hybrid_importance_weights(
+             old_lp, new_lp, seq_lens, alpha=0.5, beta=0.5
+         )
+
+         assert len(weights) == sum(seq_lens)  # Token-level weights
+
+     def test_gspo_policy_loss_sequence_variant(self):
+         """Test GSPO policy loss with sequence variant."""
+         from textpolicy.algorithms import gspo
+
+         old_lp = mx.array([-1.0, -1.2, -0.8, -1.1, -0.9])
+         new_lp = mx.array([-1.1, -1.0, -0.9, -1.0, -1.0])
+         seq_lens = [2, 3]
+         advantages = mx.array([0.5, -0.3])
+
+         loss = gspo.gspo_policy_loss(
+             old_lp, new_lp, advantages, seq_lens,
+             variant='sequence', clip_ratio=0.2
+         )
+
+         assert not mx.isnan(loss)
+         assert loss.shape == ()
+
+     def test_gspo_policy_loss_hybrid_variant(self):
+         """Test GSPO policy loss with hybrid variant."""
+         from textpolicy.algorithms import gspo
+
+         old_lp = mx.array([-1.0, -1.2, -0.8, -1.1, -0.9])
+         new_lp = mx.array([-1.1, -1.0, -0.9, -1.0, -1.0])
+         seq_lens = [2, 3]
+         advantages = mx.array([0.5, -0.3])
+
+         loss = gspo.gspo_policy_loss(
+             old_lp, new_lp, advantages, seq_lens,
+             variant='hybrid', clip_ratio=0.2
+         )
+
+         assert not mx.isnan(loss)
+
+     def test_gspo_policy_loss_token_variant(self):
+         """Test GSPO policy loss with token variant (GRPO fallback)."""
+         from textpolicy.algorithms import gspo
+
+         old_lp = mx.array([-1.0, -1.2, -0.8, -1.1, -0.9])
+         new_lp = mx.array([-1.1, -1.0, -0.9, -1.0, -1.0])
+         seq_lens = [2, 3]
+         advantages = mx.array([0.5, -0.3])
+
+         loss = gspo.gspo_policy_loss(
+             old_lp, new_lp, advantages, seq_lens,
+             variant='token', clip_ratio=0.2
+         )
+
+         assert not mx.isnan(loss)
+
+
+ @pytest.mark.unit
+ class TestTrainerWithMLX:
+     """Test Trainer compilation with MLX."""
+
+     def test_trainer_compiles_loss_function(self):
+         """Test that Trainer correctly compiles the loss function."""
+         import mlx.optimizers as optim
+         from textpolicy.training import Trainer
+         from textpolicy.algorithms import grpo
+
+         class SimpleModel(nn.Module):
+             def __init__(self):
+                 super().__init__()
+                 self.linear = nn.Linear(10, 5)
+             def __call__(self, x):
+                 return self.linear(x)
+
+         model = SimpleModel()
+         mx.eval(model.parameters())
+
+         def loss_fn(model, batch):
+             obs = batch.get('observations', mx.zeros((1, 10)))
+             return mx.mean(model(obs) ** 2)
+
+         trainer = Trainer(
+             model=model,
+             loss_fn=loss_fn,
+             optimizer=optim.Adam(learning_rate=1e-3),
+             advantage_fn=grpo.compute_advantages,
+             compile_training=True
+         )
+
+         # Verify compilation happened
+         assert trainer.loss_and_grad_fn is not None
+
+
+ @pytest.mark.unit
+ class TestLoRAWithMLX:
+     """Test LoRA functions with MLX."""
+
+     def test_lora_functions_importable(self):
+         """Test all LoRA functions can be imported."""
+         from textpolicy.generation.lora import (
+             apply_lora,
+             freeze_base,
+             extract_params,
+             merge_weights,
+             create_lora_setup,
+             create_qlora_setup,
+             apply_quantization_to_model,
+             compute_lora_memory_savings
+         )
+
+         # All should be callable
+         assert callable(apply_lora)
+         assert callable(freeze_base)
+         assert callable(extract_params)
+         assert callable(merge_weights)
+         assert callable(create_lora_setup)
+         assert callable(create_qlora_setup)
+         assert callable(apply_quantization_to_model)
+         assert callable(compute_lora_memory_savings)
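The new test module above relies on the pytest markers unit and algorithm (and test_training_pipeline.py below uses integration). If those markers are not already declared in the project's pytest configuration, running the suite would emit unknown-marker warnings; a hypothetical conftest.py sketch that registers them via pytest's standard pytest_configure hook (the repository may already handle this elsewhere):

# Hypothetical conftest.py sketch: register the custom markers used by the new tests.
# Only needed if the project does not already declare them in its pytest configuration.
def pytest_configure(config):
    for marker, description in [
        ("unit", "fast, isolated unit tests"),
        ("algorithm", "tests exercising GRPO/GSPO algorithm code"),
        ("integration", "end-to-end training pipeline tests"),
    ]:
        config.addinivalue_line("markers", f"{marker}: {description}")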
@@ -0,0 +1,240 @@
+ """
+ Training Pipeline Integration Tests
+
+ These tests verify that the complete training pipeline works correctly,
+ including gradient computation and parameter updates.
+ """
+
+ import pytest
+ import mlx.core as mx
+ import mlx.nn as nn
+ import mlx.optimizers as optim
+
+
+ @pytest.mark.integration
+ class TestTrainingPipeline:
+     """Test the complete training pipeline with gradient updates."""
+
+     def test_grpo_training_step_produces_finite_loss(self):
+         """Test that GRPO training produces finite loss values."""
+         from textpolicy.algorithms import grpo
+
+         # Create simple model
+         class TinyLM(nn.Module):
+             def __init__(self, vocab_size=100, hidden=32):
+                 super().__init__()
+                 self.embed = nn.Embedding(vocab_size, hidden)
+                 self.proj = nn.Linear(hidden, vocab_size)
+
+             def __call__(self, x):
+                 return self.proj(self.embed(x))
+
+         model = TinyLM()
+         mx.eval(model.parameters())
+
+         # Prepare batch data (mimics what data_selector returns)
+         old_logprobs = mx.array([-2.5, -3.1, -2.8, -2.9, -3.0])
+         observations = mx.array([10, 20, 30, 40, 50])
+         actions = mx.array([15, 25, 35, 45, 55])
+         rewards = mx.array([1.0, 0.5, -0.5, 0.8, 0.2])
+
+         # Compute new logprobs from model
+         logits = model(observations)
+         log_probs = nn.log_softmax(logits, axis=-1)
+         new_logprobs = mx.take_along_axis(log_probs, actions[:, None], axis=-1).squeeze(-1)
+
+         # Compute GRPO loss
+         advantages = grpo.compute_advantages(rewards)
+         loss = grpo.policy_loss(old_logprobs, new_logprobs, advantages, clip_ratio=0.2)
+
+         mx.eval(loss)
+         assert not mx.isnan(loss), "Loss should not be NaN"
+         assert not mx.isinf(loss), "Loss should not be infinite"
+
+     def test_grpo_gradients_flow(self):
+         """Test that gradients flow through the GRPO loss."""
+         from textpolicy.algorithms import grpo
+
+         class TinyLM(nn.Module):
+             def __init__(self, vocab_size=100, hidden=32):
+                 super().__init__()
+                 self.embed = nn.Embedding(vocab_size, hidden)
+                 self.proj = nn.Linear(hidden, vocab_size)
+
+             def __call__(self, x):
+                 return self.proj(self.embed(x))
+
+         model = TinyLM()
+         mx.eval(model.parameters())
+
+         def loss_fn(model):
+             observations = mx.array([10, 20, 30, 40, 50])
+             actions = mx.array([15, 25, 35, 45, 55])
+             old_logprobs = mx.array([-2.5, -3.1, -2.8, -2.9, -3.0])
+             rewards = mx.array([1.0, 0.5, -0.5, 0.8, 0.2])
+
+             logits = model(observations)
+             log_probs = nn.log_softmax(logits, axis=-1)
+             new_logprobs = mx.take_along_axis(log_probs, actions[:, None], axis=-1).squeeze(-1)
+
+             advantages = grpo.compute_advantages(rewards)
+             return grpo.policy_loss(old_logprobs, new_logprobs, advantages, clip_ratio=0.2)
+
+         # Compute gradients
+         loss_and_grad = nn.value_and_grad(model, loss_fn)
+         loss, grads = loss_and_grad(model)
+
+         mx.eval(loss, grads)
+
+         assert not mx.isnan(loss), "Loss should not be NaN"
+
+         # Check that at least some gradients are non-zero
+         has_nonzero_grad = False
+         for name, grad in grads.items():
+             if isinstance(grad, dict):
+                 for subname, subgrad in grad.items():
+                     if mx.any(subgrad != 0):
+                         has_nonzero_grad = True
+                         break
+             elif mx.any(grad != 0):
+                 has_nonzero_grad = True
+                 break
+
+         assert has_nonzero_grad, "Should have at least some non-zero gradients"
+
+     def test_optimizer_updates_parameters(self):
+         """Test that optimizer actually updates model parameters."""
+         from textpolicy.algorithms import grpo
+
+         class TinyLM(nn.Module):
+             def __init__(self, vocab_size=100, hidden=32):
+                 super().__init__()
+                 self.embed = nn.Embedding(vocab_size, hidden)
+                 self.proj = nn.Linear(hidden, vocab_size)
+
+             def __call__(self, x):
+                 return self.proj(self.embed(x))
+
+         model = TinyLM()
+         optimizer = optim.Adam(learning_rate=0.01)
+         mx.eval(model.parameters())
+
+         # Store initial parameter values
+         initial_weight = mx.array(model.proj.weight)  # Create a copy
+
+         def loss_fn(model):
+             observations = mx.array([10, 20, 30])
+             actions = mx.array([15, 25, 35])
+             old_logprobs = mx.array([-2.5, -3.1, -2.8])
+             rewards = mx.array([1.0, 0.5, -0.5])
+
+             logits = model(observations)
+             log_probs = nn.log_softmax(logits, axis=-1)
+             new_logprobs = mx.take_along_axis(log_probs, actions[:, None], axis=-1).squeeze(-1)
+
+             advantages = grpo.compute_advantages(rewards)
+             return grpo.policy_loss(old_logprobs, new_logprobs, advantages)
+
+         loss_and_grad = nn.value_and_grad(model, loss_fn)
+
+         # Run a few training steps
+         for _ in range(3):
+             loss, grads = loss_and_grad(model)
+             optimizer.update(model, grads)
+             mx.eval(model.parameters())
+
+         # Check parameters changed
+         final_weight = model.proj.weight
+         params_changed = not mx.allclose(initial_weight, final_weight)
+
+         assert params_changed, "Parameters should change after optimization"
+
+     def test_gspo_training_step_produces_finite_loss(self):
+         """Test that GSPO training produces finite loss values."""
+         from textpolicy.algorithms import gspo
+
+         class TinyLM(nn.Module):
+             def __init__(self, vocab_size=100, hidden=32):
+                 super().__init__()
+                 self.embed = nn.Embedding(vocab_size, hidden)
+                 self.proj = nn.Linear(hidden, vocab_size)
+
+             def __call__(self, x):
+                 return self.proj(self.embed(x))
+
+         model = TinyLM()
+         mx.eval(model.parameters())
+
+         # Prepare batch data
+         old_logprobs = mx.array([-2.5, -3.1, -2.8, -2.9, -3.0])
+         observations = mx.array([10, 20, 30, 40, 50])
+         actions = mx.array([15, 25, 35, 45, 55])
+
+         # Compute new logprobs from model
+         logits = model(observations)
+         log_probs = nn.log_softmax(logits, axis=-1)
+         new_logprobs = mx.take_along_axis(log_probs, actions[:, None], axis=-1).squeeze(-1)
+
+         # GSPO needs sequence-level advantages
+         sequence_lengths = [2, 3]  # 2 sequences
+         advantages = mx.array([0.5, -0.3])  # Per-sequence
+
+         # Test all GSPO variants
+         for variant in ['sequence', 'hybrid', 'token']:
+             loss = gspo.gspo_policy_loss(
+                 old_logprobs, new_logprobs, advantages, sequence_lengths,
+                 variant=variant, clip_ratio=0.2
+             )
+             mx.eval(loss)
+             assert not mx.isnan(loss), f"GSPO {variant} loss should not be NaN"
+             assert not mx.isinf(loss), f"GSPO {variant} loss should not be infinite"
+
+
+ @pytest.mark.integration
+ class TestCompiledTraining:
+     """Test that @mx.compile works correctly with training."""
+
+     def test_compiled_function_works(self):
+         """Test that @mx.compile decorator works with loss functions."""
+         @mx.compile
+         def compiled_loss(x, y):
+             return mx.mean((x - y) ** 2)
+
+         x = mx.array([1.0, 2.0, 3.0])
+         y = mx.array([1.1, 2.1, 3.1])
+
+         loss = compiled_loss(x, y)
+         mx.eval(loss)
+
+         assert not mx.isnan(loss), "Compiled loss should not be NaN"
+         assert float(loss) > 0, "Loss should be positive"
+
+     def test_trainer_compiled_loss_function(self):
+         """Test that Trainer's compiled loss function works correctly.
+
+         This mimics how Trainer.py:90 compiles the loss function.
+         """
+         from textpolicy.training import Trainer
+         from textpolicy.algorithms import grpo
+
+         class TinyLM(nn.Module):
+             def __init__(self):
+                 super().__init__()
+                 self.linear = nn.Linear(10, 10)
+             def __call__(self, x):
+                 return self.linear(x)
+
+         model = TinyLM()
+         mx.eval(model.parameters())
+
+         # Create trainer with compilation enabled
+         trainer = Trainer(
+             model=model,
+             loss_fn=grpo.policy_loss,
+             optimizer=optim.Adam(learning_rate=1e-3),
+             advantage_fn=grpo.compute_advantages,
+             compile_training=True
+         )
+
+         # Verify the compiled function was created
+         assert trainer.loss_and_grad_fn is not None, "Compiled loss function should exist"
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: textpolicy
- Version: 0.1.1
+ Version: 0.1.2
  Summary: Reinforcement learning for text generation on MLX (Apple Silicon): GRPO/GSPO, environments, rollout, rewards, LoRA/QLoRA
  Project-URL: Homepage, https://github.com/teilomillet/textpolicy
  Project-URL: Repository, https://github.com/teilomillet/textpolicy
@@ -16,8 +16,8 @@ Requires-Python: >=3.12
  Description-Content-Type: text/markdown
  License-File: LICENSE
  Requires-Dist: numpy>=2.3.2
- Requires-Dist: mlx>=0.21.0
- Requires-Dist: mlx-lm>=0.21.0
+ Requires-Dist: mlx>=0.22.0
+ Requires-Dist: mlx-lm>=0.22.0
  Requires-Dist: gymnasium>=0.29.0
  Requires-Dist: psutil>=7.0.0
  Requires-Dist: wandb>=0.21.1
@@ -4,9 +4,11 @@ pyproject.toml
  tests/test_gspo_verification.py
  tests/test_integration_e2e_training.py
  tests/test_issue_fixes.py
+ tests/test_mlx_compatibility.py
  tests/test_reward_signatures.py
  tests/test_rollout_rewards.py
  tests/test_runner_step_enforcement.py
+ tests/test_training_pipeline.py
  tests/test_validate_installation.py
  textpolicy/__init__.py
  textpolicy/__main__.py
@@ -1,6 +1,6 @@
  numpy>=2.3.2
- mlx>=0.21.0
- mlx-lm>=0.21.0
+ mlx>=0.22.0
+ mlx-lm>=0.22.0
  gymnasium>=0.29.0
  psutil>=7.0.0
  wandb>=0.21.1