textpolicy 0.0.1__py3-none-any.whl → 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68) hide show
  1. textpolicy/__init__.py +53 -0
  2. textpolicy/__main__.py +8 -0
  3. textpolicy/algorithms/__init__.py +54 -0
  4. textpolicy/algorithms/grpo.py +642 -0
  5. textpolicy/algorithms/gspo.py +582 -0
  6. textpolicy/buffer/__init__.py +23 -0
  7. textpolicy/buffer/buffer.py +244 -0
  8. textpolicy/buffer/episode.py +383 -0
  9. textpolicy/buffer/sampling.py +438 -0
  10. textpolicy/buffer/storage.py +255 -0
  11. textpolicy/cli.py +67 -0
  12. textpolicy/environment/__init__.py +79 -0
  13. textpolicy/environment/base.py +110 -0
  14. textpolicy/environment/environment.py +46 -0
  15. textpolicy/environment/factory.py +103 -0
  16. textpolicy/environment/gym.py +106 -0
  17. textpolicy/environment/task_suites.py +51 -0
  18. textpolicy/environment/text_generation.py +797 -0
  19. textpolicy/environment/vectorized.py +253 -0
  20. textpolicy/generation/__init__.py +62 -0
  21. textpolicy/generation/lora.py +411 -0
  22. textpolicy/generation/mlx_generation.py +557 -0
  23. textpolicy/generation/reload.py +253 -0
  24. textpolicy/rewards/__init__.py +137 -0
  25. textpolicy/rewards/adapters.py +387 -0
  26. textpolicy/rewards/basic.py +214 -0
  27. textpolicy/rewards/integrated_system.py +338 -0
  28. textpolicy/rewards/mlx_batch_processor.py +447 -0
  29. textpolicy/rewards/registry.py +293 -0
  30. textpolicy/rewards/rollout_rewards.py +410 -0
  31. textpolicy/rewards/verifiers.py +369 -0
  32. textpolicy/rollout/__init__.py +44 -0
  33. textpolicy/rollout/aggregator.py +145 -0
  34. textpolicy/rollout/base.py +108 -0
  35. textpolicy/rollout/rollout.py +142 -0
  36. textpolicy/rollout/runner.py +280 -0
  37. textpolicy/rollout/strategy.py +208 -0
  38. textpolicy/rollout/worker.py +194 -0
  39. textpolicy/training/__init__.py +14 -0
  40. textpolicy/training/metrics.py +242 -0
  41. textpolicy/training/rollout_manager.py +78 -0
  42. textpolicy/training/trainer.py +684 -0
  43. textpolicy/utils/__init__.py +40 -0
  44. textpolicy/utils/benchmarking.py +489 -0
  45. textpolicy/utils/data.py +60 -0
  46. textpolicy/utils/debug.py +170 -0
  47. textpolicy/utils/environment.py +349 -0
  48. textpolicy/utils/logging/__init__.py +22 -0
  49. textpolicy/utils/logging/base.py +48 -0
  50. textpolicy/utils/logging/console.py +61 -0
  51. textpolicy/utils/logging/factory.py +133 -0
  52. textpolicy/utils/logging/multi.py +83 -0
  53. textpolicy/utils/logging/tensorboard.py +65 -0
  54. textpolicy/utils/logging/wandb.py +72 -0
  55. textpolicy/utils/memory.py +118 -0
  56. textpolicy/utils/performance.py +464 -0
  57. textpolicy/utils/timing.py +171 -0
  58. textpolicy/validate.py +101 -0
  59. textpolicy/validation/__init__.py +13 -0
  60. textpolicy/validation/logprob_validation.py +315 -0
  61. textpolicy-0.1.1.dist-info/METADATA +109 -0
  62. textpolicy-0.1.1.dist-info/RECORD +66 -0
  63. {textpolicy-0.0.1.dist-info → textpolicy-0.1.1.dist-info}/WHEEL +1 -1
  64. textpolicy-0.1.1.dist-info/entry_points.txt +2 -0
  65. textpolicy-0.0.1.dist-info/METADATA +0 -10
  66. textpolicy-0.0.1.dist-info/RECORD +0 -6
  67. {textpolicy-0.0.1.dist-info → textpolicy-0.1.1.dist-info}/licenses/LICENSE +0 -0
  68. {textpolicy-0.0.1.dist-info → textpolicy-0.1.1.dist-info}/top_level.txt +0 -0
textpolicy/validate.py ADDED
@@ -0,0 +1,101 @@
1
+ """
2
+ Lightweight programmatic installation validation for TextPolicy.
3
+
4
+ This module provides a fast health check without network or model downloads.
5
+ It verifies critical imports, environment step contracts, and basic rollout
6
+ plumbing using a minimal text generation environment.
7
+
8
+ Usage:
9
+ from textpolicy.validate import validate_installation
10
+ report = validate_installation()
11
+ assert report["status"] == "ok"
12
+ """
13
+
14
+ from typing import Any, Dict, List
15
+
16
+
17
def validate_installation(verbose: bool = True) -> Dict[str, Any]:
    """
    Run a series of quick, offline validation checks.

    Checks performed:
      - Import checks: ``mlx`` (a failure is recorded as an error),
        ``gymnasium`` and ``mlx_lm`` (failures are recorded in ``checks``
        only and do not fail the run).
      - Environment contract: ``TextGenerationEnv.reset()`` and ``step()``
        are exercised with a dummy tokenizer and a word-count reward.
        ``step()`` must return a dict containing ``observation``,
        ``reward``, ``terminated``, ``truncated`` and ``info``, and the
        reward must be positive.

    Args:
        verbose: When true, the report is printed via ``_print_report``.

    Returns:
        A dictionary with keys:
          status: "ok" when no errors were recorded, otherwise "fail"
          checks: mapping of check name to details
          errors: list of error messages
    """
    checks: Dict[str, Any] = {}
    errors: List[str] = []

    # 1) Import checks: MLX (required for policies), gymnasium (adapters), mlx_lm (optional)
    # The handlers are deliberately broad: any import-time failure, not only
    # ImportError, means the package cannot be used from this interpreter.
    try:
        import mlx.core as mx  # type: ignore
        checks["mlx"] = {"available": True, "version": getattr(mx, "__version__", "unknown")}
    except Exception as e:  # pragma: no cover - environment dependent
        checks["mlx"] = {"available": False, "error": str(e)}
        errors.append("MLX not available: install 'mlx' for full functionality")

    try:
        import gymnasium as gym  # type: ignore
        checks["gymnasium"] = {"available": True, "version": getattr(gym, "__version__", "unknown")}
    except Exception as e:  # pragma: no cover
        checks["gymnasium"] = {"available": False, "error": str(e)}

    try:
        import mlx_lm  # type: ignore
        checks["mlx_lm"] = {"available": True, "version": getattr(mlx_lm, "__version__", "unknown")}
    except Exception as e:  # pragma: no cover
        # Keep the failure reason, consistent with the mlx/gymnasium checks above.
        checks["mlx_lm"] = {"available": False, "error": str(e)}

    # 2) Environment contract + reward path using a dummy tokenizer
    try:
        from textpolicy.environment.text_generation import TextGenerationEnv

        class _DummyTokenizer:
            # Byte-level stand-in so no model or network access is needed.
            def encode(self, text):
                return [ord(c) % 256 for c in text]

            def decode(self, ids):
                return "".join(chr(int(i) % 256) for i in ids)

        def _reward(prompt: str, completion: str, example: dict, **kwargs) -> float:
            return float(len(completion.split()))

        env = TextGenerationEnv(["Hello"], _reward, tokenizer=_DummyTokenizer())
        obs, info = env.reset()
        step_result = env.step("a b c")
        # bool() guarantees a plain Python bool in the report even when the
        # reward comparison yields a non-bool truthy object (e.g. an array scalar).
        ok = bool(
            isinstance(step_result, dict)
            and {"observation", "reward", "terminated", "truncated", "info"}.issubset(step_result.keys())
            and step_result["reward"] > 0
        )
        checks["environment_contract"] = {"ok": ok}
        if not ok:
            errors.append("Environment.step did not return the required dict shape or reward")
    except Exception as e:
        checks["environment_contract"] = {"ok": False, "error": str(e)}
        errors.append(f"Environment contract failed: {e}")

    status = "ok" if not errors else "fail"

    report = {"status": status, "checks": checks, "errors": errors}
    if verbose:
        _print_report(report)
    return report
89
+
90
+
91
+ def _print_report(report: Dict[str, Any]) -> None:
92
+ """Pretty-print validation results with high signal-to-noise."""
93
+ status = report.get("status", "fail")
94
+ print(f"TextPolicy validation: {status}")
95
+ for name, detail in report.get("checks", {}).items():
96
+ print(f"- {name}: {detail}")
97
+ if report.get("errors"):
98
+ print("Errors:")
99
+ for msg in report["errors"]:
100
+ print(f" - {msg}")
101
+
@@ -0,0 +1,13 @@
1
+ # textpolicy/validation/__init__.py
2
+ """
3
+ Validation utilities for TextPolicy RL training system.
4
+
5
+ Validation functions to ensure training components work correctly.
6
+ """
7
+
8
+ from .logprob_validation import validate_logprob_implementation, LogprobValidator
9
+
10
+ __all__ = [
11
+ "validate_logprob_implementation",
12
+ "LogprobValidator"
13
+ ]
@@ -0,0 +1,315 @@
1
+ # textpolicy/validation/logprob_validation.py
2
+ """
3
+ Critical Logprob Validation for MLX RL Training.
4
+
5
+ This module provides rigorous testing of logprob extraction to ensure
6
+ policy gradient algorithms receive correct probability values.
7
+
8
+ Incorrect logprob computation is the #1 cause of RL training failure.
9
+ This validation catches errors before they break training.
10
+ """
11
+
12
+ import mlx.core as mx
13
+ import mlx.nn as nn
14
+ from typing import Tuple, Dict, Any, List
15
+ import numpy as np
16
+
17
+ from ..generation.mlx_generation import compute_logprobs
18
+
19
+
20
class LogprobValidator:
    """
    Validation suite for logprob extraction correctness.

    Every ``test_*`` method is a self-contained static check that builds a
    small random model, exercises ``compute_logprobs`` (or the model itself),
    prints a one-line outcome and returns True on success / False on failure.
    ``run_full_validation`` runs them all and summarizes the result.
    """

    @staticmethod
    def create_test_model(vocab_size: int = 100, hidden_size: int = 64) -> nn.Module:
        """
        Create a simple test model for logprob validation.

        Args:
            vocab_size: Number of embedding rows and of output logits.
            hidden_size: Width of the embedding / linear input.

        Returns:
            A module mapping token ids to logits with one row per position.
        """
        class SimpleTestModel(nn.Module):
            def __init__(self, vocab_size: int, hidden_size: int):
                super().__init__()
                self.embedding = nn.Embedding(vocab_size, hidden_size)
                self.linear = nn.Linear(hidden_size, vocab_size)
                self.vocab_size = vocab_size

            def __call__(self, x):
                # Simple model for testing
                if x.ndim == 1:
                    x = x[None]  # Add batch dimension

                # Get embeddings and average them (simple pooling)
                h = self.embedding(x)  # [batch, seq, hidden]
                h = mx.mean(h, axis=1)  # [batch, hidden]

                # Expand back to sequence length for autoregressive simulation
                seq_len = x.shape[1]  # type: ignore
                h_expanded = mx.broadcast_to(h[:, None, :], (h.shape[0], seq_len, h.shape[1]))  # type: ignore

                # Generate logits for each position (identical across positions,
                # since every position sees the same pooled vector)
                logits = self.linear(h_expanded)  # [batch, seq, vocab]

                return logits

        return SimpleTestModel(vocab_size, hidden_size)

    @staticmethod
    def test_autoregressive_indexing():
        """
        Test that logprob extraction returns one non-positive value per response token.

        Checks the output shape for a 3-token prompt / 2-token response and
        that every returned value is <= 0.

        Returns:
            True if the checks pass, False otherwise.
        """
        print("Testing autoregressive indexing...")

        model = LogprobValidator.create_test_model()

        # Create test data
        prompt_tokens = mx.array([1, 2, 3])  # 3 tokens
        response_tokens = mx.array([4, 5])  # 2 tokens

        # Test the logprob computation
        try:
            logprobs = compute_logprobs(model, prompt_tokens, response_tokens)

            # Validate output shape
            assert logprobs.shape == (2,), f"Expected shape (2,), got {logprobs.shape}"

            # Validate logprobs are reasonable (negative values)
            assert mx.all(logprobs <= 0), "Logprobs should be <= 0 (log of probabilities)"

            print("Autoregressive indexing test passed")
            return True

        except Exception as e:
            print(f"Autoregressive indexing test failed: {e}")
            return False

    @staticmethod
    def test_teacher_forcing_consistency():
        """
        Test that logprobs match a manual teacher-forcing computation.

        The full sequence [prompt + response] is fed to the model, the logits
        at positions ``prompt_len - 1`` onward are log-softmaxed, and the
        entries for the response tokens are compared with ``compute_logprobs``.

        Returns:
            True if the maximum absolute difference is below 1e-6, False otherwise.
        """
        print("Testing teacher-forcing consistency...")

        model = LogprobValidator.create_test_model()

        prompt_tokens = mx.array([10, 20])
        response_tokens = mx.array([30, 40, 50])

        try:
            # Method 1: Our logprob function
            computed_logprobs = compute_logprobs(model, prompt_tokens, response_tokens)

            # Method 2: Manual teacher-forcing computation
            full_sequence = mx.concatenate([prompt_tokens, response_tokens])
            logits = model(full_sequence[None])  # [1, seq_len, vocab_size]

            # Extract the same logits our function should use: the logits at
            # position i score token i + 1, hence the "- 1" offset.
            prompt_len = len(prompt_tokens)
            manual_logits = logits[0, prompt_len-1:prompt_len-1+len(response_tokens), :]
            manual_log_probs = manual_logits - mx.logsumexp(manual_logits, axis=-1, keepdims=True)
            manual_selected = manual_log_probs[mx.arange(len(response_tokens)), response_tokens]

            # Compare results
            diff = mx.abs(computed_logprobs - manual_selected)
            max_diff = float(mx.max(diff))

            assert max_diff < 1e-6, f"Logprob mismatch: max difference {max_diff}"

            print(f"Teacher-forcing consistency test passed (max diff: {max_diff:.2e})")
            return True

        except Exception as e:
            print(f"Teacher-forcing consistency test failed: {e}")
            return False

    @staticmethod
    def test_batch_processing():
        """
        Test that two equally sized inputs produce equally shaped logprobs.

        Note: the sequences are processed one call at a time; this compares
        output shapes only and does not pass a batched tensor.

        Returns:
            True if the shapes agree, False otherwise.
        """
        print("Testing batch processing...")

        model = LogprobValidator.create_test_model()

        try:
            # Test single sequence
            prompt1 = mx.array([1, 2])
            response1 = mx.array([3, 4])
            logprobs1 = compute_logprobs(model, prompt1, response1)

            # Test another sequence
            prompt2 = mx.array([5, 6])
            response2 = mx.array([7, 8])
            logprobs2 = compute_logprobs(model, prompt2, response2)

            # Results should be consistent shapes
            assert logprobs1.shape == logprobs2.shape, "Batch processing shape inconsistency"

            print("Batch processing test passed")
            return True

        except Exception as e:
            print(f"Batch processing test failed: {e}")
            return False

    @staticmethod
    def test_edge_cases():
        """
        Test edge cases that could break training.

        Covers an empty response, a single-token response and a one-token
        prompt, checking only the number of returned logprobs.

        Returns:
            True if all cases pass, False otherwise.
        """
        print("Testing edge cases...")

        model = LogprobValidator.create_test_model()

        try:
            # Test empty response. The dtype is given explicitly because an
            # empty list carries no integer type information, and token ids
            # must stay integral when combined with the prompt.
            prompt = mx.array([1, 2, 3])
            empty_response = mx.array([], dtype=mx.int32)
            logprobs_empty = compute_logprobs(model, prompt, empty_response)
            assert len(logprobs_empty) == 0, "Empty response should return empty logprobs"

            # Test single token response
            single_response = mx.array([4])
            logprobs_single = compute_logprobs(model, prompt, single_response)
            assert len(logprobs_single) == 1, "Single token response should return single logprob"

            # Test a short prompt
            short_prompt = mx.array([1])
            response = mx.array([2, 3])
            logprobs_short = compute_logprobs(model, short_prompt, response)
            assert len(logprobs_short) == 2, "Should handle short prompts correctly"

            print("Edge cases test passed")
            return True

        except Exception as e:
            print(f"Edge cases test failed: {e}")
            return False

    @staticmethod
    def test_numerical_stability():
        """
        Test numerical stability of logprob computation.

        Verifies that the logprobs contain no NaN/Inf and lie in [-50, 0].

        Returns:
            True if the checks pass, False otherwise.
        """
        print("Testing numerical stability...")

        model = LogprobValidator.create_test_model()

        try:
            prompt = mx.array([1, 2, 3])
            response = mx.array([4, 5, 6])

            logprobs = compute_logprobs(model, prompt, response)

            # Check for NaN or Inf values
            assert not mx.any(mx.isnan(logprobs)), "NaN values detected in logprobs"
            assert not mx.any(mx.isinf(logprobs)), "Inf values detected in logprobs"

            # Check logprobs are reasonable (log-probabilities are non-positive)
            assert mx.all(logprobs <= 0), "Logprobs should be negative (log probabilities)"
            assert mx.all(logprobs >= -50), "Logprobs should not be too negative"

            print("Numerical stability test passed")
            return True

        except Exception as e:
            print(f"Numerical stability test failed: {e}")
            return False

    @staticmethod
    def test_gradient_computation():
        """
        Test that gradients flow from a loss back to the model parameters.

        The loss is the mean of the model's logits over the concatenated
        prompt + response (``compute_logprobs`` is not involved). Gradients
        are taken with respect to the model's trainable parameters, and at
        least one of them must be non-zero.

        Returns:
            True if non-zero gradients are produced, False otherwise.
        """
        print("Testing gradient computation...")

        model = LogprobValidator.create_test_model()

        try:
            # Function-scope import keeps this method self-contained.
            from mlx.utils import tree_flatten

            prompt = mx.array([1, 2])
            response = mx.array([3, 4])
            full_seq = mx.concatenate([prompt, response])

            def loss_fn(tokens):
                # Simple loss (not using our logprob function, just testing model)
                return mx.mean(model(tokens[None]))

            # nn.value_and_grad differentiates with respect to the model's
            # trainable parameters, so the gradients really depend on the
            # forward pass instead of on an unused function argument.
            _, grads = nn.value_and_grad(model, loss_fn)(full_seq)

            # Should get gradients for the parameters, and not all of them zero
            flat_grads = tree_flatten(grads)
            assert len(flat_grads) > 0, "No gradients computed"
            assert any(
                float(mx.sum(mx.abs(g))) > 0 for _, g in flat_grads
            ), "All gradients are zero"

            print("Gradient computation test passed")
            return True

        except Exception as e:
            print(f"Gradient computation test failed: {e}")
            return False

    @staticmethod
    def run_full_validation() -> bool:
        """
        Run complete logprob validation suite.

        Each test is run in turn; a test that raises is reported as crashed
        and counted as not passed.

        Returns:
            True if all tests pass, False otherwise
        """
        print("=" * 60)
        print("LOGPROB VALIDATION")
        print("=" * 60)
        print("Testing logprob extraction correctness for RL training...")
        print()

        tests = [
            LogprobValidator.test_autoregressive_indexing,
            LogprobValidator.test_teacher_forcing_consistency,
            LogprobValidator.test_batch_processing,
            LogprobValidator.test_edge_cases,
            LogprobValidator.test_numerical_stability,
            LogprobValidator.test_gradient_computation
        ]

        passed = 0
        total = len(tests)

        for test in tests:
            try:
                if test():
                    passed += 1
                print()
            except Exception as e:
                print(f"Test {test.__name__} crashed: {e}")
                print()

        print("=" * 60)
        print(f"VALIDATION RESULTS: {passed}/{total} tests passed")

        if passed == total:
            print("ALL LOGPROB TESTS PASSED!")
            print("Logprob extraction is correct for RL training")
            return True
        else:
            print("LOGPROB VALIDATION FAILED!")
            print("RL training may not work correctly with the current implementation")
            return False
299
+
300
+
301
def validate_logprob_implementation() -> bool:
    """
    Run the full logprob validation suite and report the outcome.

    Intended to be called before RL training starts, so that a broken
    logprob extraction is detected up front. Delegates to
    ``LogprobValidator.run_full_validation``.

    Returns:
        True if validation passes, False otherwise
    """
    all_passed = LogprobValidator.run_full_validation()
    return all_passed
312
+
313
+
314
if __name__ == "__main__":
    # Propagate the validation outcome as the process exit status
    # (0 = all tests passed, 1 = at least one failed) so shells and CI can react.
    raise SystemExit(0 if validate_logprob_implementation() else 1)
@@ -0,0 +1,109 @@
1
+ Metadata-Version: 2.4
2
+ Name: textpolicy
3
+ Version: 0.1.1
4
+ Summary: Reinforcement learning for text generation on MLX (Apple Silicon): GRPO/GSPO, environments, rollout, rewards, LoRA/QLoRA
5
+ Project-URL: Homepage, https://github.com/teilomillet/textpolicy
6
+ Project-URL: Repository, https://github.com/teilomillet/textpolicy
7
+ Project-URL: Documentation, https://github.com/teilomillet/textpolicy#readme
8
+ Project-URL: Changelog, https://github.com/teilomillet/textpolicy/blob/main/CHANGELOG.md
9
+ Keywords: reinforcement-learning,text-generation,mlx,apple-silicon,lora,qlora,grpo,gspo,rlhf
10
+ Classifier: Programming Language :: Python :: 3
11
+ Classifier: Operating System :: MacOS
12
+ Classifier: Intended Audience :: Developers
13
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
14
+ Classifier: License :: OSI Approved :: MIT License
15
+ Requires-Python: >=3.12
16
+ Description-Content-Type: text/markdown
17
+ License-File: LICENSE
18
+ Requires-Dist: numpy>=2.3.2
19
+ Requires-Dist: mlx>=0.21.0
20
+ Requires-Dist: mlx-lm>=0.21.0
21
+ Requires-Dist: gymnasium>=0.29.0
22
+ Requires-Dist: psutil>=7.0.0
23
+ Requires-Dist: wandb>=0.21.1
24
+ Requires-Dist: aiohttp>=3.12.15
25
+ Requires-Dist: pytest>=8.4.1
26
+ Provides-Extra: external
27
+ Requires-Dist: aiohttp>=3.8.0; extra == "external"
28
+ Requires-Dist: pydantic>=2.0.0; extra == "external"
29
+ Provides-Extra: dev
30
+ Requires-Dist: pytest>=7.0.0; extra == "dev"
31
+ Requires-Dist: black>=22.0.0; extra == "dev"
32
+ Requires-Dist: ruff>=0.1.0; extra == "dev"
33
+ Dynamic: license-file
34
+
35
+ # TextPolicy
36
+
37
+ Reinforcement learning toolkit for text generation on MLX (Apple Silicon).
38
+ TextPolicy provides algorithms (GRPO/GSPO), text-generation environments, a rollout runner,
39
+ reward functions with a decorator registry, and LoRA/QLoRA utilities.
40
+
41
+ ## Install (uv)
42
+
43
+ ```bash
44
+ uv add textpolicy
45
+ ```
46
+
47
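+ Model integration (`mlx` and `mlx-lm` are already declared as required dependencies of `textpolicy`; add them explicitly only if you want to manage their versions yourself):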
48
+
49
+ ```bash
50
+ uv add mlx mlx-lm
51
+ ```
52
+
53
+ ## Quickstart
54
+
55
+ Working example using a real model and tokenizer (mlx-lm required):
56
+
57
+ ```python
58
+ import mlx.core as mx
59
+ import textpolicy as tp
60
+ from textpolicy import load_model, create_policy
61
+ from textpolicy.environment.text_generation import TextGenerationEnv
62
+ from textpolicy.rollout import RolloutRunner, create_strategy
63
+
64
+ # 1) Load model and tokenizer (mlx-lm)
65
+ model, tokenizer = load_model("Qwen/Qwen3-0.6B")
66
+
67
+ # 2) Create a policy (controls generation)
68
+ generation_params = {"max_tokens": 25, "temperature": 0.7}
69
+ policy_fn = create_policy(model, tokenizer, generation_params)
70
+
71
+ # 3) Define a reward function (env uses this to score responses)
72
+ @tp.reward
73
+ def length_reward(prompt: str, completion: str, example: dict, **kwargs) -> float:
74
+ return float(len(completion.split()))
75
+
76
+ # 4) Create an environment (requires a tokenizer)
77
+ env = TextGenerationEnv(["What is AI?"], length_reward, tokenizer=tokenizer)
78
+
79
+ # 5) Collect one rollout step
80
+ strategy = create_strategy('grpo')
81
+ runner = RolloutRunner(env, policy=policy_fn, strategy=strategy, max_steps=1)
82
+ buffer = runner.collect()
83
+ print(len(buffer.episodes))
84
+ ```
85
+
86
+ Docs:
87
+ - Quickstart: `docs/QUICKSTART_UV.md`
88
+ - LoRA/QLoRA: `docs/10_lora_qlora.md`
89
+ - Full index: `docs/index.md`
90
+
91
+ FAQ:
92
+ - Do I need a model?
93
+ - Yes for generation with `create_policy`.
94
+ Use `load_model()` (mlx‑lm) to get `(model, tokenizer)`.
95
+ For reward‑only code (no generation), a model is not required.
96
+ - Do I need a tokenizer?
97
+ - Yes.
98
+ Both `TextGenerationEnv` and `TextGenerationEnvironment` require a tokenizer.
99
+ `load_model()` returns one for mlx‑lm models.
100
+ - How do I control generation?
101
+ - Pass `generation_params` to `create_policy` (for example, `max_tokens`, `temperature`, `top_p`, `repetition_penalty`).
102
+ - What does `step()` return?
103
+ - A dict with `observation`, `reward`, `terminated`, `truncated`, `info`. The runner enforces this.
104
+
105
+ Examples:
106
+ - 01–06: reward functions, batch processing, minimal training
107
+ - 08: GRPO training with rollout + buffer
108
+ - 09–10: length reduction (GRPO/GSPO)
109
+ - 11: LoRA/QLoRA configuration
@@ -0,0 +1,66 @@
1
+ textpolicy/__init__.py,sha256=vDAHJ826gKuTZUjcAftzz-RTX8KuOjH50Uj1RMhjTIQ,1606
2
+ textpolicy/__main__.py,sha256=IlGmjJaW-DJUC7yhxUhbwNOZA3GxkeQGkVbFdS3_wBI,136
3
+ textpolicy/cli.py,sha256=3CcJzrRlin1pgd6Mh312Xp3-EihHtTSvhakyYpdfacs,2107
4
+ textpolicy/validate.py,sha256=lxmegz83B_c-PS3cFHaaL3c9fgWrEaLsDLkpPFtSj8Y,3780
5
+ textpolicy/algorithms/__init__.py,sha256=muJSuiJkaGg-zSaGIYkaB7UbLh6UJYMdI60SGqTgNWM,1257
6
+ textpolicy/algorithms/grpo.py,sha256=1j_C70Bgwrnr_BCAl_qvAsH3Mg9yMOW-D4vhPUxUpFQ,26261
7
+ textpolicy/algorithms/gspo.py,sha256=OWvJolldTSTEOsCIwio3ER0hTWkYsJ1e0BBJElgJ2mc,23485
8
+ textpolicy/buffer/__init__.py,sha256=bnSkX9Oe1ajau-yqC2PYNF4a4ELVP05zjlkDmIerXlw,569
9
+ textpolicy/buffer/buffer.py,sha256=mDie8ZiWgsjNJ4LiKyfpQNLzN1K0UICxI8XaqQacUMM,7917
10
+ textpolicy/buffer/episode.py,sha256=iNyVqeMLzOMauz1Z3fs9JUyL7g7IEC9t8GN1eypThy4,15875
11
+ textpolicy/buffer/sampling.py,sha256=aE92R69TQe8c5V0WL_EaKHLXSSxbI2wDlgvtR80fJDc,17598
12
+ textpolicy/buffer/storage.py,sha256=dYdROkL4KHx0w3XAqJFvRs6gRbYs3dAsiHqdMnwBMAI,9248
13
+ textpolicy/environment/__init__.py,sha256=TzhkXiAq10iC9PGGJ-gn2Bm2ttmJHpQ8d3Lxf_wUnNg,1685
14
+ textpolicy/environment/base.py,sha256=7Bh5tlFz20RC-iPQhnpv-zigHf8a6cEwBC3ISV-HS1Y,2829
15
+ textpolicy/environment/environment.py,sha256=o8-RY6wj5xrzDBp77HoY2At3XlBwvreF3DKpDFpC7js,1240
16
+ textpolicy/environment/factory.py,sha256=pebQo1_M3sMF8Pdc9yvpdXzRXfIDllKJoAQAjQbif0E,3124
17
+ textpolicy/environment/gym.py,sha256=P8Bi8PlDtcWWa9uLuCjkhZnYRVs-mg6iSJVSBkG99f8,3186
18
+ textpolicy/environment/task_suites.py,sha256=ssPnw2Y3eGYaskWf8dUab4rNu_Bx5L284b3VdhgvSPM,1544
19
+ textpolicy/environment/text_generation.py,sha256=Jql0pEfrPp9tqNsPOAdIP-UYoAUsfV969TMR2uPkUp4,31837
20
+ textpolicy/environment/vectorized.py,sha256=ZROtpmdbh1Oi8c0b0D_vmVzqI16Cp2WZTmkjkRbMoDg,9932
21
+ textpolicy/generation/__init__.py,sha256=J3dc0SPAZChJTsRn47tz8FfIp3XwNgZ-8_H9VBpQYvQ,1266
22
+ textpolicy/generation/lora.py,sha256=xSKRczJY20BrkkU1SSgBtDc30tZjdFE7FhEZPUEoiyg,13747
23
+ textpolicy/generation/mlx_generation.py,sha256=r__oXHiAtAQ4xq4ODUwS7FrXL40Hu9cwoS5sZOhsAfs,20468
24
+ textpolicy/generation/reload.py,sha256=-eJE3LXmN-kDatUQjM0--VZp0jjqWgBslYcmNcQZ_A8,7998
25
+ textpolicy/rewards/__init__.py,sha256=mg_wL7oedL_5KLsnaJuPVc_ZHZqZKXRHg9ws2gSifMk,4769
26
+ textpolicy/rewards/adapters.py,sha256=Ffwi9eF_mx6DdCoRRmzl7TdhqNJycpz1TovJXa0XxXk,12843
27
+ textpolicy/rewards/basic.py,sha256=xlMMfCmLm3TrDJsxpJ-h9vlc-m27tTrvnZ-JGUOD89A,6921
28
+ textpolicy/rewards/integrated_system.py,sha256=eGK69J2cAfJD4GoL_ANivA8ZWpXHLtu6A0YoMwGTAzI,11243
29
+ textpolicy/rewards/mlx_batch_processor.py,sha256=97jFIHwqq75q7-LADVIBCbNqZJTU8jpbs8xcyUrJNfQ,16850
30
+ textpolicy/rewards/registry.py,sha256=azuz4HpbacBUss2-FS0Wji_FOUs7NLtwgpsEynqq7ds,11437
31
+ textpolicy/rewards/rollout_rewards.py,sha256=7bSkbBJwsb9MLOQ-YMutocpIgGI_AZdgFjRCh2xv0iY,13805
32
+ textpolicy/rewards/verifiers.py,sha256=Xnb9FC5Lc0CE34Wwoc99AwKRt0Ut0gDliImr_YDChiY,11596
33
+ textpolicy/rollout/__init__.py,sha256=PKjY1NmsARTPmUwzNLSp2tFU4NvgJ4NUP1VPy2g3nxI,1229
34
+ textpolicy/rollout/aggregator.py,sha256=U2ywEPWwBHGgpEPHxmu2ywSibSa_pkD_GuHN3JGcwck,4467
35
+ textpolicy/rollout/base.py,sha256=CuyzsHM_yn3eRKldLCcEDfmnqFnHoq-rJ7k0f-nYHw8,2919
36
+ textpolicy/rollout/rollout.py,sha256=h3gs_U-NfoIKpBVf1NFeZGInvSki8RDATsq1__ne8Qo,4499
37
+ textpolicy/rollout/runner.py,sha256=9bB0B1GlEGNtr8bhEYQbpY1WBzJQK0MoFrsbZTQ-Lzw,10993
38
+ textpolicy/rollout/strategy.py,sha256=Q97wxgq-FCienL15P1l-pXYEWiUZrh861UmtStj4x3E,7577
39
+ textpolicy/rollout/worker.py,sha256=aXOKRtkivKwDks8g8VtaWUv-wQMPR72idZxPuNtwmSE,6939
40
+ textpolicy/training/__init__.py,sha256=TmcW2BqmwO4DaDDr4n2g1QOtHeVPxgw6xZdeYTmzjD8,282
41
+ textpolicy/training/metrics.py,sha256=fmY1ZBdyEgYrfH18H3fOZ-dieMtjVNzjxjdxd7yo7OU,7582
42
+ textpolicy/training/rollout_manager.py,sha256=ETD7WTbbaQ8uUzrHPBCDX-PawmEJfSK6Kd5N-dvIZRY,2328
43
+ textpolicy/training/trainer.py,sha256=kG7tduOKHPFVVewyspgm360enowTpNpwaLhZWuIc9vo,29268
44
+ textpolicy/utils/__init__.py,sha256=v0ji-jnegGRydzmAOccKY4XC0nkBbBZqdHXzk-i6ers,1220
45
+ textpolicy/utils/benchmarking.py,sha256=YDN24vU8SL_EsrANQWF1qbmXtfhF4Woj8yjez-h-Io0,18682
46
+ textpolicy/utils/data.py,sha256=KJoPzYWYVAJawvDX1BHzwBZEpCXLSBC168rjud7MSB0,1413
47
+ textpolicy/utils/debug.py,sha256=ir_5DF88_yZbU43w-o_o05EivgPv9AgNVRovL-adNIE,6139
48
+ textpolicy/utils/environment.py,sha256=LyYQgpZVfEDyPlD7774_AHR9crOC6NNGjd6J37ltLGM,13319
49
+ textpolicy/utils/memory.py,sha256=H8mfUY52iU5VRPhOLPdanWvBgEGLtLCoUE0xpJIMcfM,3391
50
+ textpolicy/utils/performance.py,sha256=YzLMT_bAPs9TnVwmrPvWOEwi4UAy5Bmbr5zTpJacmGA,17823
51
+ textpolicy/utils/timing.py,sha256=MUCBUC22zQ3nMVxhif_XhHlyZUxYGhheFMs7bf6ZKg8,4869
52
+ textpolicy/utils/logging/__init__.py,sha256=YeJ18H7suPzYPgXlfSsE90GaAVF_5lOcbGCDCy8tvuA,512
53
+ textpolicy/utils/logging/base.py,sha256=3BBg318dEPNIHYtoJnFiJcktvS5KMcIUNXBKKfn8x_g,1304
54
+ textpolicy/utils/logging/console.py,sha256=WIcOO-tT2FhrqKhWO89qSSNmLA06up_KenrN6TGoyo8,1680
55
+ textpolicy/utils/logging/factory.py,sha256=vAkMShn7bnVRDuZNOKaVXNmF6XPUNeqaFFPF4dZw70E,3920
56
+ textpolicy/utils/logging/multi.py,sha256=kIxuoXiZ4nf_p8JlnzYxtqA0r82LJaqHnO5mEHlJdgM,2501
57
+ textpolicy/utils/logging/tensorboard.py,sha256=aY9YMReSJkWEhy6SdAAUlHSB4lzDecivBC8K7CZPcO4,1949
58
+ textpolicy/utils/logging/wandb.py,sha256=U4pxuZNOz2l8XiymK8OFbCpiRTBOLNtnZakC_udttfQ,2206
59
+ textpolicy/validation/__init__.py,sha256=KcyppNi91w0bF51gZ0ykUIKEiF7z6TT37uuavMFScnA,328
60
+ textpolicy/validation/logprob_validation.py,sha256=G_CCy5NRDUTmo7WZIChhNVM3NtP1VmWAjdd5z6TIvos,11749
61
+ textpolicy-0.1.1.dist-info/licenses/LICENSE,sha256=AYDHSNRbiqZt4HHH1gaOoQ2hjYjK4bqw4Vd9UyKzx18,1065
62
+ textpolicy-0.1.1.dist-info/METADATA,sha256=CrrIoETuh6xExhyqrhWq-8KcHSNVeuyzo9oZ8uxLOIU,3895
63
+ textpolicy-0.1.1.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
64
+ textpolicy-0.1.1.dist-info/entry_points.txt,sha256=d0Cj5boT6k_l_beVPWPt9LZMllsN4kbIUmsNsn1BANE,51
65
+ textpolicy-0.1.1.dist-info/top_level.txt,sha256=Ww6_QEF71dI-AYCaugiGeGcgMoFAixSOszSoRsyX-E0,11
66
+ textpolicy-0.1.1.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (80.9.0)
2
+ Generator: setuptools (80.10.2)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5
 
@@ -0,0 +1,2 @@
1
+ [console_scripts]
2
+ textpolicy = textpolicy.cli:main
@@ -1,10 +0,0 @@
1
- Metadata-Version: 2.4
2
- Name: textpolicy
3
- Version: 0.0.1
4
- Summary: Add your description here
5
- Requires-Python: >=3.12
6
- Description-Content-Type: text/markdown
7
- License-File: LICENSE
8
- Dynamic: license-file
9
-
10
- # textpolicy
@@ -1,6 +0,0 @@
1
- textpolicy/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
- textpolicy-0.0.1.dist-info/licenses/LICENSE,sha256=AYDHSNRbiqZt4HHH1gaOoQ2hjYjK4bqw4Vd9UyKzx18,1065
3
- textpolicy-0.0.1.dist-info/METADATA,sha256=Hm0fs04Q8V79bM4-GI0QMPDPEWN97oe5ULRex1orYDQ,211
4
- textpolicy-0.0.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
5
- textpolicy-0.0.1.dist-info/top_level.txt,sha256=Ww6_QEF71dI-AYCaugiGeGcgMoFAixSOszSoRsyX-E0,11
6
- textpolicy-0.0.1.dist-info/RECORD,,