textpolicy 0.0.1-py3-none-any.whl → 0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- textpolicy/__init__.py +52 -0
- textpolicy/__main__.py +8 -0
- textpolicy/algorithms/__init__.py +54 -0
- textpolicy/algorithms/grpo.py +642 -0
- textpolicy/algorithms/gspo.py +582 -0
- textpolicy/buffer/__init__.py +23 -0
- textpolicy/buffer/buffer.py +244 -0
- textpolicy/buffer/episode.py +383 -0
- textpolicy/buffer/sampling.py +438 -0
- textpolicy/buffer/storage.py +255 -0
- textpolicy/cli.py +67 -0
- textpolicy/environment/__init__.py +79 -0
- textpolicy/environment/base.py +110 -0
- textpolicy/environment/environment.py +46 -0
- textpolicy/environment/factory.py +103 -0
- textpolicy/environment/gym.py +106 -0
- textpolicy/environment/task_suites.py +51 -0
- textpolicy/environment/text_generation.py +789 -0
- textpolicy/environment/vectorized.py +253 -0
- textpolicy/generation/__init__.py +62 -0
- textpolicy/generation/lora.py +411 -0
- textpolicy/generation/mlx_generation.py +557 -0
- textpolicy/generation/reload.py +253 -0
- textpolicy/rewards/__init__.py +137 -0
- textpolicy/rewards/adapters.py +387 -0
- textpolicy/rewards/basic.py +214 -0
- textpolicy/rewards/integrated_system.py +338 -0
- textpolicy/rewards/mlx_batch_processor.py +447 -0
- textpolicy/rewards/registry.py +293 -0
- textpolicy/rewards/rollout_rewards.py +410 -0
- textpolicy/rewards/verifiers.py +369 -0
- textpolicy/rollout/__init__.py +44 -0
- textpolicy/rollout/aggregator.py +145 -0
- textpolicy/rollout/base.py +108 -0
- textpolicy/rollout/rollout.py +142 -0
- textpolicy/rollout/runner.py +280 -0
- textpolicy/rollout/strategy.py +208 -0
- textpolicy/rollout/worker.py +194 -0
- textpolicy/training/__init__.py +14 -0
- textpolicy/training/metrics.py +242 -0
- textpolicy/training/rollout_manager.py +78 -0
- textpolicy/training/trainer.py +684 -0
- textpolicy/utils/__init__.py +40 -0
- textpolicy/utils/benchmarking.py +489 -0
- textpolicy/utils/data.py +60 -0
- textpolicy/utils/debug.py +170 -0
- textpolicy/utils/environment.py +349 -0
- textpolicy/utils/logging/__init__.py +22 -0
- textpolicy/utils/logging/base.py +48 -0
- textpolicy/utils/logging/console.py +61 -0
- textpolicy/utils/logging/factory.py +133 -0
- textpolicy/utils/logging/multi.py +83 -0
- textpolicy/utils/logging/tensorboard.py +65 -0
- textpolicy/utils/logging/wandb.py +72 -0
- textpolicy/utils/memory.py +118 -0
- textpolicy/utils/performance.py +464 -0
- textpolicy/utils/timing.py +171 -0
- textpolicy/validate.py +101 -0
- textpolicy/validation/__init__.py +13 -0
- textpolicy/validation/logprob_validation.py +315 -0
- textpolicy-0.1.0.dist-info/METADATA +99 -0
- textpolicy-0.1.0.dist-info/RECORD +66 -0
- textpolicy-0.1.0.dist-info/entry_points.txt +2 -0
- textpolicy-0.0.1.dist-info/METADATA +0 -10
- textpolicy-0.0.1.dist-info/RECORD +0 -6
- {textpolicy-0.0.1.dist-info → textpolicy-0.1.0.dist-info}/WHEEL +0 -0
- {textpolicy-0.0.1.dist-info → textpolicy-0.1.0.dist-info}/licenses/LICENSE +0 -0
- {textpolicy-0.0.1.dist-info → textpolicy-0.1.0.dist-info}/top_level.txt +0 -0
textpolicy/validate.py
ADDED
@@ -0,0 +1,101 @@
```python
"""
Lightweight programmatic installation validation for TextPolicy.

This module provides a fast health check without network or model downloads.
It verifies critical imports, environment step contracts, and basic rollout
plumbing using a minimal text generation environment.

Usage:
    from textpolicy.validate import validate_installation
    report = validate_installation()
    assert report["status"] == "ok"
"""

from typing import Any, Dict, List


def validate_installation(verbose: bool = True) -> Dict[str, Any]:
    """
    Run a series of quick validation checks.

    - Import checks: mlx, gymnasium, mlx_lm (optional)
    - Environment contract: TextGenerationEnv reset/step returns
    - Rollout shape sanity: basic reward extraction path

    Returns:
        A dictionary with keys:
            status: "ok" or "fail"
            checks: mapping of check name to details
            errors: list of error messages
    """
    checks: Dict[str, Any] = {}
    errors: List[str] = []

    # 1) Import checks: MLX (required for policies), gymnasium (adapters), mlx_lm (optional)
    try:
        import mlx.core as mx  # type: ignore
        checks["mlx"] = {"available": True, "version": getattr(mx, "__version__", "unknown")}
    except Exception as e:  # pragma: no cover - environment dependent
        checks["mlx"] = {"available": False, "error": str(e)}
        errors.append("MLX not available: install 'mlx' for full functionality")

    try:
        import gymnasium as gym  # type: ignore
        checks["gymnasium"] = {"available": True, "version": getattr(gym, "__version__", "unknown")}
    except Exception as e:  # pragma: no cover
        checks["gymnasium"] = {"available": False, "error": str(e)}

    try:
        import mlx_lm  # type: ignore
        checks["mlx_lm"] = {"available": True, "version": getattr(mlx_lm, "__version__", "unknown")}
    except Exception:  # pragma: no cover
        checks["mlx_lm"] = {"available": False}

    # 2) Environment contract + reward path using a dummy tokenizer
    try:
        from textpolicy.environment.text_generation import TextGenerationEnv

        class _DummyTokenizer:
            def encode(self, text):
                return [ord(c) % 256 for c in text]

            def decode(self, ids):
                return "".join(chr(int(i) % 256) for i in ids)

        def _reward(prompt: str, completion: str, example: dict, **kwargs) -> float:
            return float(len(completion.split()))

        env = TextGenerationEnv(["Hello"], _reward, tokenizer=_DummyTokenizer())
        obs, info = env.reset()
        step_result = env.step("a b c")
        ok = (
            isinstance(step_result, dict)
            and {"observation", "reward", "terminated", "truncated", "info"}.issubset(step_result.keys())
            and step_result["reward"] > 0
        )
        checks["environment_contract"] = {"ok": ok}
        if not ok:
            errors.append("Environment.step did not return the required dict shape or reward")
    except Exception as e:
        checks["environment_contract"] = {"ok": False, "error": str(e)}
        errors.append(f"Environment contract failed: {e}")

    status = "ok" if not errors else "fail"

    report = {"status": status, "checks": checks, "errors": errors}
    if verbose:
        _print_report(report)
    return report


def _print_report(report: Dict[str, Any]) -> None:
    """Pretty-print validation results with high signal-to-noise."""
    status = report.get("status", "fail")
    print(f"TextPolicy validation: {status}")
    for name, detail in report.get("checks", {}).items():
        print(f"- {name}: {detail}")
    if report.get("errors"):
        print("Errors:")
        for msg in report["errors"]:
            print(f"  - {msg}")
```
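A note on consuming this module: the report contract above (keys `status`, `checks`, `errors`) makes it easy to gate CI on a healthy install. A minimal sketch, assuming only that documented contract; the test name is hypothetical, not part of the package:

```python
# Hypothetical smoke test built on the validate_installation() contract above.
from textpolicy.validate import validate_installation

def test_installation_is_healthy():
    report = validate_installation(verbose=False)
    # status is "ok" only when the errors list is empty
    assert report["status"] == "ok", "; ".join(report["errors"])
```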
textpolicy/validation/__init__.py
ADDED
@@ -0,0 +1,13 @@
```python
# textpolicy/validation/__init__.py
"""
Validation utilities for TextPolicy RL training system.

Validation functions to ensure training components work correctly.
"""

from .logprob_validation import validate_logprob_implementation, LogprobValidator

__all__ = [
    "validate_logprob_implementation",
    "LogprobValidator"
]
```
textpolicy/validation/logprob_validation.py
ADDED
@@ -0,0 +1,315 @@
```python
# textpolicy/validation/logprob_validation.py
"""
Critical Logprob Validation for MLX RL Training.

This module provides rigorous testing of logprob extraction to ensure
policy gradient algorithms receive correct probability values.

Incorrect logprob computation is the #1 cause of RL training failure.
This validation catches errors before they break training.
"""

import mlx.core as mx
import mlx.nn as nn
from typing import Tuple, Dict, Any, List
import numpy as np

from ..generation.mlx_generation import compute_logprobs


class LogprobValidator:
    """
    Comprehensive validation of logprob extraction correctness.

    This class implements multiple tests to ensure logprob functions
    compute the correct values needed for policy gradient training.
    """

    @staticmethod
    def create_test_model(vocab_size: int = 100, hidden_size: int = 64) -> nn.Module:
        """Create a simple test model for logprob validation."""
        class SimpleTestModel(nn.Module):
            def __init__(self, vocab_size: int, hidden_size: int):
                super().__init__()
                self.embedding = nn.Embedding(vocab_size, hidden_size)
                self.linear = nn.Linear(hidden_size, vocab_size)
                self.vocab_size = vocab_size

            def __call__(self, x):
                # Simple model for testing
                if x.ndim == 1:
                    x = x[None]  # Add batch dimension

                # Get embeddings and average them (simple pooling)
                h = self.embedding(x)   # [batch, seq, hidden]
                h = mx.mean(h, axis=1)  # [batch, hidden]

                # Expand back to sequence length for autoregressive simulation
                seq_len = x.shape[1]  # type: ignore
                h_expanded = mx.broadcast_to(h[:, None, :], (h.shape[0], seq_len, h.shape[1]))  # type: ignore

                # Generate logits for each position
                logits = self.linear(h_expanded)  # [batch, seq, vocab]

                return logits

        return SimpleTestModel(vocab_size, hidden_size)

    @staticmethod
    def test_autoregressive_indexing():
        """
        Test that logprob extraction uses correct autoregressive indexing.

        This test verifies autoregressive indexing. Incorrect indexing breaks training.
        """
        print("Testing autoregressive indexing...")

        model = LogprobValidator.create_test_model()

        # Create test data
        prompt_tokens = mx.array([1, 2, 3])  # 3 tokens
        response_tokens = mx.array([4, 5])   # 2 tokens

        # Test the logprob computation
        try:
            logprobs = compute_logprobs(model, prompt_tokens, response_tokens)

            # Validate output shape
            assert logprobs.shape == (2,), f"Expected shape (2,), got {logprobs.shape}"

            # Validate logprobs are reasonable (negative values)
            assert mx.all(logprobs <= 0), "Logprobs should be <= 0 (log of probabilities)"

            print("Autoregressive indexing test passed")
            return True

        except Exception as e:
            print(f"Autoregressive indexing test failed: {e}")
            return False

    @staticmethod
    def test_teacher_forcing_consistency():
        """
        Test that logprobs match teacher-forcing computation.

        When we provide the full sequence [prompt + response] to the model,
        the logprobs should match what the model actually computed.
        """
        print("Testing teacher-forcing consistency...")

        model = LogprobValidator.create_test_model()

        prompt_tokens = mx.array([10, 20])
        response_tokens = mx.array([30, 40, 50])

        try:
            # Method 1: Our logprob function
            computed_logprobs = compute_logprobs(model, prompt_tokens, response_tokens)

            # Method 2: Manual teacher-forcing computation
            full_sequence = mx.concatenate([prompt_tokens, response_tokens])
            logits = model(full_sequence[None])  # [1, seq_len, vocab_size]

            # Extract the same logits our function should use
            prompt_len = len(prompt_tokens)
            manual_logits = logits[0, prompt_len-1:prompt_len-1+len(response_tokens), :]
            manual_log_probs = manual_logits - mx.logsumexp(manual_logits, axis=-1, keepdims=True)
            manual_selected = manual_log_probs[mx.arange(len(response_tokens)), response_tokens]

            # Compare results
            diff = mx.abs(computed_logprobs - manual_selected)
            max_diff = float(mx.max(diff))

            assert max_diff < 1e-6, f"Logprob mismatch: max difference {max_diff}"

            print(f"Teacher-forcing consistency test passed (max diff: {max_diff:.2e})")
            return True

        except Exception as e:
            print(f"Teacher-forcing consistency test failed: {e}")
            return False

    @staticmethod
    def test_batch_processing():
        """Test logprob extraction works correctly with batched inputs."""
        print("Testing batch processing...")

        model = LogprobValidator.create_test_model()

        try:
            # Test single sequence
            prompt1 = mx.array([1, 2])
            response1 = mx.array([3, 4])
            logprobs1 = compute_logprobs(model, prompt1, response1)

            # Test another sequence
            prompt2 = mx.array([5, 6])
            response2 = mx.array([7, 8])
            logprobs2 = compute_logprobs(model, prompt2, response2)

            # Results should be consistent shapes
            assert logprobs1.shape == logprobs2.shape, "Batch processing shape inconsistency"

            print("Batch processing test passed")
            return True

        except Exception as e:
            print(f"Batch processing test failed: {e}")
            return False

    @staticmethod
    def test_edge_cases():
        """Test edge cases that could break training."""
        print("Testing edge cases...")

        model = LogprobValidator.create_test_model()

        try:
            # Test empty response
            prompt = mx.array([1, 2, 3])
            empty_response = mx.array([])
            logprobs_empty = compute_logprobs(model, prompt, empty_response)
            assert len(logprobs_empty) == 0, "Empty response should return empty logprobs"

            # Test single token response
            single_response = mx.array([4])
            logprobs_single = compute_logprobs(model, prompt, single_response)
            assert len(logprobs_single) == 1, "Single token response should return single logprob"

            # Test a short prompt
            short_prompt = mx.array([1])
            response = mx.array([2, 3])
            logprobs_short = compute_logprobs(model, short_prompt, response)
            assert len(logprobs_short) == 2, "Should handle short prompts correctly"

            print("Edge cases test passed")
            return True

        except Exception as e:
            print(f"Edge cases test failed: {e}")
            return False

    @staticmethod
    def test_numerical_stability():
        """Test numerical stability of logprob computation."""
        print("Testing numerical stability...")

        model = LogprobValidator.create_test_model()

        try:
            prompt = mx.array([1, 2, 3])
            response = mx.array([4, 5, 6])

            logprobs = compute_logprobs(model, prompt, response)

            # Check for NaN or Inf values
            assert not mx.any(mx.isnan(logprobs)), "NaN values detected in logprobs"
            assert not mx.any(mx.isinf(logprobs)), "Inf values detected in logprobs"

            # Check logprobs are reasonable (log-probabilities are non-positive)
            assert mx.all(logprobs <= 0), "Logprobs should be negative (log probabilities)"
            assert mx.all(logprobs >= -50), "Logprobs should not be too negative"

            print("Numerical stability test passed")
            return True

        except Exception as e:
            print(f"Numerical stability test failed: {e}")
            return False

    @staticmethod
    def test_gradient_computation():
        """Test that gradients can flow through logprob computation when needed."""
        print("Testing gradient computation...")

        model = LogprobValidator.create_test_model()

        try:
            prompt = mx.array([1, 2])
            response = mx.array([3, 4])

            # Test that we can compute gradients through the model
            def loss_fn(model_params):
                # Forward pass
                full_seq = mx.concatenate([prompt, response])
                logits = model(full_seq[None])

                # Simple loss (not using our logprob function, just testing model)
                return mx.mean(logits)

            # Test gradient computation
            grads = mx.grad(loss_fn)(model.parameters())

            # Should get gradients for all parameters
            assert len(grads) > 0, "No gradients computed"

            print("Gradient computation test passed")
            return True

        except Exception as e:
            print(f"Gradient computation test failed: {e}")
            return False

    @staticmethod
    def run_full_validation() -> bool:
        """
        Run complete logprob validation suite.

        Returns:
            True if all tests pass, False otherwise
        """
        print("=" * 60)
        print("LOGPROB VALIDATION")
        print("=" * 60)
        print("Testing logprob extraction correctness for RL training...")
        print()

        tests = [
            LogprobValidator.test_autoregressive_indexing,
            LogprobValidator.test_teacher_forcing_consistency,
            LogprobValidator.test_batch_processing,
            LogprobValidator.test_edge_cases,
            LogprobValidator.test_numerical_stability,
            LogprobValidator.test_gradient_computation
        ]

        passed = 0
        total = len(tests)

        for test in tests:
            try:
                if test():
                    passed += 1
                print()
            except Exception as e:
                print(f"Test {test.__name__} crashed: {e}")
                print()

        print("=" * 60)
        print(f"VALIDATION RESULTS: {passed}/{total} tests passed")

        if passed == total:
            print("ALL LOGPROB TESTS PASSED!")
            print("Logprob extraction is correct for RL training")
            return True
        else:
            print("LOGPROB VALIDATION FAILED!")
            print("RL training may not work correctly with the current implementation")
            return False


def validate_logprob_implementation() -> bool:
    """
    Main entry point for logprob validation.

    This function should be called before starting any RL training
    to ensure logprob extraction is implemented correctly.

    Returns:
        True if validation passes, False otherwise
    """
    return LogprobValidator.run_full_validation()


if __name__ == "__main__":
    validate_logprob_implementation()
```
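The teacher-forcing test above pins down the indexing contract that `compute_logprobs` (imported from `textpolicy.generation.mlx_generation`, whose source is not part of this diff) must satisfy: position `i` of the model's output predicts token `i + 1`, so the logits that score response token `t` live at index `prompt_len - 1 + t`. A minimal reference sketch of just that contract, mirroring the test's manual computation rather than the packaged implementation:

```python
import mlx.core as mx

def reference_logprobs(model, prompt_tokens: mx.array, response_tokens: mx.array) -> mx.array:
    """Sketch of the contract the tests encode; the packaged compute_logprobs may differ."""
    if len(response_tokens) == 0:
        return mx.array([])  # edge case exercised by test_edge_cases
    full = mx.concatenate([prompt_tokens, response_tokens])
    logits = model(full[None])  # [1, seq_len, vocab]
    # Logits at prompt_len - 1 + t predict response token t (teacher forcing).
    start = len(prompt_tokens) - 1
    scoring = logits[0, start:start + len(response_tokens), :]
    # Normalize to log-probabilities, then select each realized token.
    log_probs = scoring - mx.logsumexp(scoring, axis=-1, keepdims=True)
    return log_probs[mx.arange(len(response_tokens)), response_tokens]
```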
textpolicy-0.1.0.dist-info/METADATA
ADDED
@@ -0,0 +1,99 @@
````text
Metadata-Version: 2.4
Name: textpolicy
Version: 0.1.0
Summary: MLX-optimized reward and verification system for text generation RL
Requires-Python: >=3.12
Description-Content-Type: text/markdown
License-File: LICENSE
Requires-Dist: numpy>=2.3.2
Requires-Dist: mlx>=0.21.0
Requires-Dist: mlx-lm>=0.21.0
Requires-Dist: gymnasium>=0.29.0
Requires-Dist: psutil>=7.0.0
Requires-Dist: wandb>=0.21.1
Requires-Dist: aiohttp>=3.12.15
Requires-Dist: pytest>=8.4.1
Provides-Extra: external
Requires-Dist: aiohttp>=3.8.0; extra == "external"
Requires-Dist: pydantic>=2.0.0; extra == "external"
Provides-Extra: dev
Requires-Dist: pytest>=7.0.0; extra == "dev"
Requires-Dist: black>=22.0.0; extra == "dev"
Requires-Dist: ruff>=0.1.0; extra == "dev"
Dynamic: license-file

# TextPolicy

Reinforcement learning toolkit for text generation on MLX (Apple Silicon).
TextPolicy provides algorithms (GRPO/GSPO), text-generation environments, a rollout runner,
reward functions with a decorator registry, and LoRA/QLoRA utilities.

## Install (uv)

```bash
uv add textpolicy
```

Optional model integration:

```bash
uv add mlx mlx-lm
```

## Quickstart

Working example using a real model and tokenizer (mlx-lm required):

```python
import mlx.core as mx
import textpolicy as tp
from textpolicy import load_model, create_policy
from textpolicy.environment.text_generation import TextGenerationEnv
from textpolicy.rollout import RolloutRunner, create_strategy

# 1) Load model and tokenizer (mlx-lm)
model, tokenizer = load_model("Qwen/Qwen3-0.6B")

# 2) Create a policy (controls generation)
generation_params = {"max_tokens": 25, "temperature": 0.7}
policy_fn = create_policy(model, tokenizer, generation_params)

# 3) Define a reward function (env uses this to score responses)
@tp.reward
def length_reward(prompt: str, completion: str, example: dict, **kwargs) -> float:
    return float(len(completion.split()))

# 4) Create an environment (requires a tokenizer)
env = TextGenerationEnv(["What is AI?"], length_reward, tokenizer=tokenizer)

# 5) Collect one rollout step
strategy = create_strategy('grpo')
runner = RolloutRunner(env, policy=policy_fn, strategy=strategy, max_steps=1)
buffer = runner.collect()
print(len(buffer.episodes))
```

Docs:
- Quickstart: `docs/QUICKSTART_UV.md`
- LoRA/QLoRA: `docs/10_lora_qlora.md`
- Full index: `docs/index.md`

FAQ:
- Do I need a model?
  - Yes for generation with `create_policy`.
    Use `load_model()` (mlx-lm) to get `(model, tokenizer)`.
    For reward-only code (no generation), a model is not required.
- Do I need a tokenizer?
  - Yes.
    Both `TextGenerationEnv` and `TextGenerationEnvironment` require a tokenizer.
    `load_model()` returns one for mlx-lm models.
- How do I control generation?
  - Pass `generation_params` to `create_policy` (for example, `max_tokens`, `temperature`, `top_p`, `repetition_penalty`).
- What does `step()` return?
  - A dict with `observation`, `reward`, `terminated`, `truncated`, `info`. The runner enforces this.

Examples:
- 01–06: reward functions, batch processing, minimal training
- 08: GRPO training with rollout + buffer
- 09–10: length reduction (GRPO/GSPO)
- 11: LoRA/QLoRA configuration
````
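Worth noting alongside the FAQ: the `generation_params` dict is the single knob for decoding behavior. A hedged sketch using the `create_policy(model, tokenizer, generation_params)` signature from the quickstart and the parameter names the FAQ lists; the specific values here are illustrative, not package defaults:

```python
from textpolicy import load_model, create_policy

model, tokenizer = load_model("Qwen/Qwen3-0.6B")  # model id from the quickstart above
policy_fn = create_policy(
    model,
    tokenizer,
    {
        "max_tokens": 64,           # cap on generated tokens (illustrative value)
        "temperature": 0.7,         # sampling temperature
        "top_p": 0.9,               # nucleus sampling cutoff
        "repetition_penalty": 1.1,  # discourage verbatim repeats
    },
)
```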
textpolicy-0.1.0.dist-info/RECORD
ADDED
@@ -0,0 +1,66 @@
```text
textpolicy/__init__.py,sha256=u4u0fIHfAvXFN2ATHCsG0Tx4xGfOcfuOITBTmKbGhrw,1576
textpolicy/__main__.py,sha256=IlGmjJaW-DJUC7yhxUhbwNOZA3GxkeQGkVbFdS3_wBI,136
textpolicy/cli.py,sha256=3CcJzrRlin1pgd6Mh312Xp3-EihHtTSvhakyYpdfacs,2107
textpolicy/validate.py,sha256=lxmegz83B_c-PS3cFHaaL3c9fgWrEaLsDLkpPFtSj8Y,3780
textpolicy/algorithms/__init__.py,sha256=muJSuiJkaGg-zSaGIYkaB7UbLh6UJYMdI60SGqTgNWM,1257
textpolicy/algorithms/grpo.py,sha256=1j_C70Bgwrnr_BCAl_qvAsH3Mg9yMOW-D4vhPUxUpFQ,26261
textpolicy/algorithms/gspo.py,sha256=OWvJolldTSTEOsCIwio3ER0hTWkYsJ1e0BBJElgJ2mc,23485
textpolicy/buffer/__init__.py,sha256=bnSkX9Oe1ajau-yqC2PYNF4a4ELVP05zjlkDmIerXlw,569
textpolicy/buffer/buffer.py,sha256=mDie8ZiWgsjNJ4LiKyfpQNLzN1K0UICxI8XaqQacUMM,7917
textpolicy/buffer/episode.py,sha256=iNyVqeMLzOMauz1Z3fs9JUyL7g7IEC9t8GN1eypThy4,15875
textpolicy/buffer/sampling.py,sha256=aE92R69TQe8c5V0WL_EaKHLXSSxbI2wDlgvtR80fJDc,17598
textpolicy/buffer/storage.py,sha256=dYdROkL4KHx0w3XAqJFvRs6gRbYs3dAsiHqdMnwBMAI,9248
textpolicy/environment/__init__.py,sha256=TzhkXiAq10iC9PGGJ-gn2Bm2ttmJHpQ8d3Lxf_wUnNg,1685
textpolicy/environment/base.py,sha256=7Bh5tlFz20RC-iPQhnpv-zigHf8a6cEwBC3ISV-HS1Y,2829
textpolicy/environment/environment.py,sha256=o8-RY6wj5xrzDBp77HoY2At3XlBwvreF3DKpDFpC7js,1240
textpolicy/environment/factory.py,sha256=pebQo1_M3sMF8Pdc9yvpdXzRXfIDllKJoAQAjQbif0E,3124
textpolicy/environment/gym.py,sha256=P8Bi8PlDtcWWa9uLuCjkhZnYRVs-mg6iSJVSBkG99f8,3186
textpolicy/environment/task_suites.py,sha256=ssPnw2Y3eGYaskWf8dUab4rNu_Bx5L284b3VdhgvSPM,1544
textpolicy/environment/text_generation.py,sha256=BXSJS_05Q89cPFfdXcUKxOXSZm3HBR3KMi55BnOdoLY,31258
textpolicy/environment/vectorized.py,sha256=ZROtpmdbh1Oi8c0b0D_vmVzqI16Cp2WZTmkjkRbMoDg,9932
textpolicy/generation/__init__.py,sha256=J3dc0SPAZChJTsRn47tz8FfIp3XwNgZ-8_H9VBpQYvQ,1266
textpolicy/generation/lora.py,sha256=xSKRczJY20BrkkU1SSgBtDc30tZjdFE7FhEZPUEoiyg,13747
textpolicy/generation/mlx_generation.py,sha256=r__oXHiAtAQ4xq4ODUwS7FrXL40Hu9cwoS5sZOhsAfs,20468
textpolicy/generation/reload.py,sha256=-eJE3LXmN-kDatUQjM0--VZp0jjqWgBslYcmNcQZ_A8,7998
textpolicy/rewards/__init__.py,sha256=mg_wL7oedL_5KLsnaJuPVc_ZHZqZKXRHg9ws2gSifMk,4769
textpolicy/rewards/adapters.py,sha256=Ffwi9eF_mx6DdCoRRmzl7TdhqNJycpz1TovJXa0XxXk,12843
textpolicy/rewards/basic.py,sha256=xlMMfCmLm3TrDJsxpJ-h9vlc-m27tTrvnZ-JGUOD89A,6921
textpolicy/rewards/integrated_system.py,sha256=eGK69J2cAfJD4GoL_ANivA8ZWpXHLtu6A0YoMwGTAzI,11243
textpolicy/rewards/mlx_batch_processor.py,sha256=97jFIHwqq75q7-LADVIBCbNqZJTU8jpbs8xcyUrJNfQ,16850
textpolicy/rewards/registry.py,sha256=azuz4HpbacBUss2-FS0Wji_FOUs7NLtwgpsEynqq7ds,11437
textpolicy/rewards/rollout_rewards.py,sha256=7bSkbBJwsb9MLOQ-YMutocpIgGI_AZdgFjRCh2xv0iY,13805
textpolicy/rewards/verifiers.py,sha256=Xnb9FC5Lc0CE34Wwoc99AwKRt0Ut0gDliImr_YDChiY,11596
textpolicy/rollout/__init__.py,sha256=PKjY1NmsARTPmUwzNLSp2tFU4NvgJ4NUP1VPy2g3nxI,1229
textpolicy/rollout/aggregator.py,sha256=U2ywEPWwBHGgpEPHxmu2ywSibSa_pkD_GuHN3JGcwck,4467
textpolicy/rollout/base.py,sha256=CuyzsHM_yn3eRKldLCcEDfmnqFnHoq-rJ7k0f-nYHw8,2919
textpolicy/rollout/rollout.py,sha256=h3gs_U-NfoIKpBVf1NFeZGInvSki8RDATsq1__ne8Qo,4499
textpolicy/rollout/runner.py,sha256=9bB0B1GlEGNtr8bhEYQbpY1WBzJQK0MoFrsbZTQ-Lzw,10993
textpolicy/rollout/strategy.py,sha256=Q97wxgq-FCienL15P1l-pXYEWiUZrh861UmtStj4x3E,7577
textpolicy/rollout/worker.py,sha256=aXOKRtkivKwDks8g8VtaWUv-wQMPR72idZxPuNtwmSE,6939
textpolicy/training/__init__.py,sha256=TmcW2BqmwO4DaDDr4n2g1QOtHeVPxgw6xZdeYTmzjD8,282
textpolicy/training/metrics.py,sha256=fmY1ZBdyEgYrfH18H3fOZ-dieMtjVNzjxjdxd7yo7OU,7582
textpolicy/training/rollout_manager.py,sha256=ETD7WTbbaQ8uUzrHPBCDX-PawmEJfSK6Kd5N-dvIZRY,2328
textpolicy/training/trainer.py,sha256=kG7tduOKHPFVVewyspgm360enowTpNpwaLhZWuIc9vo,29268
textpolicy/utils/__init__.py,sha256=v0ji-jnegGRydzmAOccKY4XC0nkBbBZqdHXzk-i6ers,1220
textpolicy/utils/benchmarking.py,sha256=YDN24vU8SL_EsrANQWF1qbmXtfhF4Woj8yjez-h-Io0,18682
textpolicy/utils/data.py,sha256=KJoPzYWYVAJawvDX1BHzwBZEpCXLSBC168rjud7MSB0,1413
textpolicy/utils/debug.py,sha256=ir_5DF88_yZbU43w-o_o05EivgPv9AgNVRovL-adNIE,6139
textpolicy/utils/environment.py,sha256=LyYQgpZVfEDyPlD7774_AHR9crOC6NNGjd6J37ltLGM,13319
textpolicy/utils/memory.py,sha256=H8mfUY52iU5VRPhOLPdanWvBgEGLtLCoUE0xpJIMcfM,3391
textpolicy/utils/performance.py,sha256=YzLMT_bAPs9TnVwmrPvWOEwi4UAy5Bmbr5zTpJacmGA,17823
textpolicy/utils/timing.py,sha256=MUCBUC22zQ3nMVxhif_XhHlyZUxYGhheFMs7bf6ZKg8,4869
textpolicy/utils/logging/__init__.py,sha256=YeJ18H7suPzYPgXlfSsE90GaAVF_5lOcbGCDCy8tvuA,512
textpolicy/utils/logging/base.py,sha256=3BBg318dEPNIHYtoJnFiJcktvS5KMcIUNXBKKfn8x_g,1304
textpolicy/utils/logging/console.py,sha256=WIcOO-tT2FhrqKhWO89qSSNmLA06up_KenrN6TGoyo8,1680
textpolicy/utils/logging/factory.py,sha256=vAkMShn7bnVRDuZNOKaVXNmF6XPUNeqaFFPF4dZw70E,3920
textpolicy/utils/logging/multi.py,sha256=kIxuoXiZ4nf_p8JlnzYxtqA0r82LJaqHnO5mEHlJdgM,2501
textpolicy/utils/logging/tensorboard.py,sha256=aY9YMReSJkWEhy6SdAAUlHSB4lzDecivBC8K7CZPcO4,1949
textpolicy/utils/logging/wandb.py,sha256=U4pxuZNOz2l8XiymK8OFbCpiRTBOLNtnZakC_udttfQ,2206
textpolicy/validation/__init__.py,sha256=KcyppNi91w0bF51gZ0ykUIKEiF7z6TT37uuavMFScnA,328
textpolicy/validation/logprob_validation.py,sha256=G_CCy5NRDUTmo7WZIChhNVM3NtP1VmWAjdd5z6TIvos,11749
textpolicy-0.1.0.dist-info/licenses/LICENSE,sha256=AYDHSNRbiqZt4HHH1gaOoQ2hjYjK4bqw4Vd9UyKzx18,1065
textpolicy-0.1.0.dist-info/METADATA,sha256=XdyIh8e2IIRymRf31vu1MuVM2aaut2qsZ5PcsjHrl9Y,3199
textpolicy-0.1.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
textpolicy-0.1.0.dist-info/entry_points.txt,sha256=d0Cj5boT6k_l_beVPWPt9LZMllsN4kbIUmsNsn1BANE,51
textpolicy-0.1.0.dist-info/top_level.txt,sha256=Ww6_QEF71dI-AYCaugiGeGcgMoFAixSOszSoRsyX-E0,11
textpolicy-0.1.0.dist-info/RECORD,,
```
textpolicy-0.0.1.dist-info/RECORD
REMOVED
@@ -1,6 +0,0 @@
```diff
-textpolicy/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-textpolicy-0.0.1.dist-info/licenses/LICENSE,sha256=AYDHSNRbiqZt4HHH1gaOoQ2hjYjK4bqw4Vd9UyKzx18,1065
-textpolicy-0.0.1.dist-info/METADATA,sha256=Hm0fs04Q8V79bM4-GI0QMPDPEWN97oe5ULRex1orYDQ,211
-textpolicy-0.0.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
-textpolicy-0.0.1.dist-info/top_level.txt,sha256=Ww6_QEF71dI-AYCaugiGeGcgMoFAixSOszSoRsyX-E0,11
-textpolicy-0.0.1.dist-info/RECORD,,
```
{textpolicy-0.0.1.dist-info → textpolicy-0.1.0.dist-info}/WHEEL
File without changes

{textpolicy-0.0.1.dist-info → textpolicy-0.1.0.dist-info}/licenses/LICENSE
File without changes

{textpolicy-0.0.1.dist-info → textpolicy-0.1.0.dist-info}/top_level.txt
File without changes