textpolicy 0.1.0__tar.gz → 0.1.2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {textpolicy-0.1.0/textpolicy.egg-info → textpolicy-0.1.2}/PKG-INFO +14 -4
- textpolicy-0.1.2/pyproject.toml +56 -0
- textpolicy-0.1.2/tests/test_issue_fixes.py +218 -0
- textpolicy-0.1.2/tests/test_mlx_compatibility.py +315 -0
- textpolicy-0.1.2/tests/test_training_pipeline.py +240 -0
- {textpolicy-0.1.0 → textpolicy-0.1.2}/textpolicy/__init__.py +2 -1
- {textpolicy-0.1.0 → textpolicy-0.1.2}/textpolicy/environment/text_generation.py +12 -4
- {textpolicy-0.1.0 → textpolicy-0.1.2/textpolicy.egg-info}/PKG-INFO +14 -4
- {textpolicy-0.1.0 → textpolicy-0.1.2}/textpolicy.egg-info/SOURCES.txt +3 -0
- {textpolicy-0.1.0 → textpolicy-0.1.2}/textpolicy.egg-info/requires.txt +2 -2
- textpolicy-0.1.0/pyproject.toml +0 -30
- {textpolicy-0.1.0 → textpolicy-0.1.2}/LICENSE +0 -0
- {textpolicy-0.1.0 → textpolicy-0.1.2}/README.md +0 -0
- {textpolicy-0.1.0 → textpolicy-0.1.2}/setup.cfg +0 -0
- {textpolicy-0.1.0 → textpolicy-0.1.2}/tests/test_gspo_verification.py +0 -0
- {textpolicy-0.1.0 → textpolicy-0.1.2}/tests/test_integration_e2e_training.py +0 -0
- {textpolicy-0.1.0 → textpolicy-0.1.2}/tests/test_reward_signatures.py +0 -0
- {textpolicy-0.1.0 → textpolicy-0.1.2}/tests/test_rollout_rewards.py +0 -0
- {textpolicy-0.1.0 → textpolicy-0.1.2}/tests/test_runner_step_enforcement.py +0 -0
- {textpolicy-0.1.0 → textpolicy-0.1.2}/tests/test_validate_installation.py +0 -0
- {textpolicy-0.1.0 → textpolicy-0.1.2}/textpolicy/__main__.py +0 -0
- {textpolicy-0.1.0 → textpolicy-0.1.2}/textpolicy/algorithms/__init__.py +0 -0
- {textpolicy-0.1.0 → textpolicy-0.1.2}/textpolicy/algorithms/grpo.py +0 -0
- {textpolicy-0.1.0 → textpolicy-0.1.2}/textpolicy/algorithms/gspo.py +0 -0
- {textpolicy-0.1.0 → textpolicy-0.1.2}/textpolicy/buffer/__init__.py +0 -0
- {textpolicy-0.1.0 → textpolicy-0.1.2}/textpolicy/buffer/buffer.py +0 -0
- {textpolicy-0.1.0 → textpolicy-0.1.2}/textpolicy/buffer/episode.py +0 -0
- {textpolicy-0.1.0 → textpolicy-0.1.2}/textpolicy/buffer/sampling.py +0 -0
- {textpolicy-0.1.0 → textpolicy-0.1.2}/textpolicy/buffer/storage.py +0 -0
- {textpolicy-0.1.0 → textpolicy-0.1.2}/textpolicy/cli.py +0 -0
- {textpolicy-0.1.0 → textpolicy-0.1.2}/textpolicy/environment/__init__.py +0 -0
- {textpolicy-0.1.0 → textpolicy-0.1.2}/textpolicy/environment/base.py +0 -0
- {textpolicy-0.1.0 → textpolicy-0.1.2}/textpolicy/environment/environment.py +0 -0
- {textpolicy-0.1.0 → textpolicy-0.1.2}/textpolicy/environment/factory.py +0 -0
- {textpolicy-0.1.0 → textpolicy-0.1.2}/textpolicy/environment/gym.py +0 -0
- {textpolicy-0.1.0 → textpolicy-0.1.2}/textpolicy/environment/task_suites.py +0 -0
- {textpolicy-0.1.0 → textpolicy-0.1.2}/textpolicy/environment/vectorized.py +0 -0
- {textpolicy-0.1.0 → textpolicy-0.1.2}/textpolicy/generation/__init__.py +0 -0
- {textpolicy-0.1.0 → textpolicy-0.1.2}/textpolicy/generation/lora.py +0 -0
- {textpolicy-0.1.0 → textpolicy-0.1.2}/textpolicy/generation/mlx_generation.py +0 -0
- {textpolicy-0.1.0 → textpolicy-0.1.2}/textpolicy/generation/reload.py +0 -0
- {textpolicy-0.1.0 → textpolicy-0.1.2}/textpolicy/rewards/__init__.py +0 -0
- {textpolicy-0.1.0 → textpolicy-0.1.2}/textpolicy/rewards/adapters.py +0 -0
- {textpolicy-0.1.0 → textpolicy-0.1.2}/textpolicy/rewards/basic.py +0 -0
- {textpolicy-0.1.0 → textpolicy-0.1.2}/textpolicy/rewards/integrated_system.py +0 -0
- {textpolicy-0.1.0 → textpolicy-0.1.2}/textpolicy/rewards/mlx_batch_processor.py +0 -0
- {textpolicy-0.1.0 → textpolicy-0.1.2}/textpolicy/rewards/registry.py +0 -0
- {textpolicy-0.1.0 → textpolicy-0.1.2}/textpolicy/rewards/rollout_rewards.py +0 -0
- {textpolicy-0.1.0 → textpolicy-0.1.2}/textpolicy/rewards/verifiers.py +0 -0
- {textpolicy-0.1.0 → textpolicy-0.1.2}/textpolicy/rollout/__init__.py +0 -0
- {textpolicy-0.1.0 → textpolicy-0.1.2}/textpolicy/rollout/aggregator.py +0 -0
- {textpolicy-0.1.0 → textpolicy-0.1.2}/textpolicy/rollout/base.py +0 -0
- {textpolicy-0.1.0 → textpolicy-0.1.2}/textpolicy/rollout/rollout.py +0 -0
- {textpolicy-0.1.0 → textpolicy-0.1.2}/textpolicy/rollout/runner.py +0 -0
- {textpolicy-0.1.0 → textpolicy-0.1.2}/textpolicy/rollout/strategy.py +0 -0
- {textpolicy-0.1.0 → textpolicy-0.1.2}/textpolicy/rollout/worker.py +0 -0
- {textpolicy-0.1.0 → textpolicy-0.1.2}/textpolicy/training/__init__.py +0 -0
- {textpolicy-0.1.0 → textpolicy-0.1.2}/textpolicy/training/metrics.py +0 -0
- {textpolicy-0.1.0 → textpolicy-0.1.2}/textpolicy/training/rollout_manager.py +0 -0
- {textpolicy-0.1.0 → textpolicy-0.1.2}/textpolicy/training/trainer.py +0 -0
- {textpolicy-0.1.0 → textpolicy-0.1.2}/textpolicy/utils/__init__.py +0 -0
- {textpolicy-0.1.0 → textpolicy-0.1.2}/textpolicy/utils/benchmarking.py +0 -0
- {textpolicy-0.1.0 → textpolicy-0.1.2}/textpolicy/utils/data.py +0 -0
- {textpolicy-0.1.0 → textpolicy-0.1.2}/textpolicy/utils/debug.py +0 -0
- {textpolicy-0.1.0 → textpolicy-0.1.2}/textpolicy/utils/environment.py +0 -0
- {textpolicy-0.1.0 → textpolicy-0.1.2}/textpolicy/utils/logging/__init__.py +0 -0
- {textpolicy-0.1.0 → textpolicy-0.1.2}/textpolicy/utils/logging/base.py +0 -0
- {textpolicy-0.1.0 → textpolicy-0.1.2}/textpolicy/utils/logging/console.py +0 -0
- {textpolicy-0.1.0 → textpolicy-0.1.2}/textpolicy/utils/logging/factory.py +0 -0
- {textpolicy-0.1.0 → textpolicy-0.1.2}/textpolicy/utils/logging/multi.py +0 -0
- {textpolicy-0.1.0 → textpolicy-0.1.2}/textpolicy/utils/logging/tensorboard.py +0 -0
- {textpolicy-0.1.0 → textpolicy-0.1.2}/textpolicy/utils/logging/wandb.py +0 -0
- {textpolicy-0.1.0 → textpolicy-0.1.2}/textpolicy/utils/memory.py +0 -0
- {textpolicy-0.1.0 → textpolicy-0.1.2}/textpolicy/utils/performance.py +0 -0
- {textpolicy-0.1.0 → textpolicy-0.1.2}/textpolicy/utils/timing.py +0 -0
- {textpolicy-0.1.0 → textpolicy-0.1.2}/textpolicy/validate.py +0 -0
- {textpolicy-0.1.0 → textpolicy-0.1.2}/textpolicy/validation/__init__.py +0 -0
- {textpolicy-0.1.0 → textpolicy-0.1.2}/textpolicy/validation/logprob_validation.py +0 -0
- {textpolicy-0.1.0 → textpolicy-0.1.2}/textpolicy.egg-info/dependency_links.txt +0 -0
- {textpolicy-0.1.0 → textpolicy-0.1.2}/textpolicy.egg-info/entry_points.txt +0 -0
- {textpolicy-0.1.0 → textpolicy-0.1.2}/textpolicy.egg-info/top_level.txt +0 -0
{textpolicy-0.1.0/textpolicy.egg-info → textpolicy-0.1.2}/PKG-INFO
@@ -1,13 +1,23 @@
 Metadata-Version: 2.4
 Name: textpolicy
-Version: 0.1.0
-Summary: MLX-optimized reward and verification system for text generation RL
+Version: 0.1.2
+Summary: Reinforcement learning for text generation on MLX (Apple Silicon): GRPO/GSPO, environments, rollout, rewards, LoRA/QLoRA
+Project-URL: Homepage, https://github.com/teilomillet/textpolicy
+Project-URL: Repository, https://github.com/teilomillet/textpolicy
+Project-URL: Documentation, https://github.com/teilomillet/textpolicy#readme
+Project-URL: Changelog, https://github.com/teilomillet/textpolicy/blob/main/CHANGELOG.md
+Keywords: reinforcement-learning,text-generation,mlx,apple-silicon,lora,qlora,grpo,gspo,rlhf
+Classifier: Programming Language :: Python :: 3
+Classifier: Operating System :: MacOS
+Classifier: Intended Audience :: Developers
+Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
+Classifier: License :: OSI Approved :: MIT License
 Requires-Python: >=3.12
 Description-Content-Type: text/markdown
 License-File: LICENSE
 Requires-Dist: numpy>=2.3.2
-Requires-Dist: mlx>=0.21.0
-Requires-Dist: mlx-lm>=0.21.0
+Requires-Dist: mlx>=0.22.0
+Requires-Dist: mlx-lm>=0.22.0
 Requires-Dist: gymnasium>=0.29.0
 Requires-Dist: psutil>=7.0.0
 Requires-Dist: wandb>=0.21.1

textpolicy-0.1.2/pyproject.toml
@@ -0,0 +1,56 @@
+[build-system]
+requires = ["setuptools>=61.0", "wheel"]
+build-backend = "setuptools.build_meta"
+
+[tool.setuptools.packages.find]
+include = ["textpolicy*"]
+
+[project]
+name = "textpolicy"
+version = "0.1.2"
+description = "Reinforcement learning for text generation on MLX (Apple Silicon): GRPO/GSPO, environments, rollout, rewards, LoRA/QLoRA"
+readme = "README.md"
+requires-python = ">=3.12"
+dependencies = [
+    "numpy>=2.3.2",
+    "mlx>=0.22.0",     # Core MLX framework for Apple Silicon acceleration (tested up to 0.30.5)
+    "mlx-lm>=0.22.0",  # MLX language models for inference (tested up to 0.30.6)
+    "gymnasium>=0.29.0",
+    "psutil>=7.0.0",
+    "wandb>=0.21.1",
+    "aiohttp>=3.12.15",
+    "pytest>=8.4.1",
+]
+
+keywords = [
+    "reinforcement-learning",
+    "text-generation",
+    "mlx",
+    "apple-silicon",
+    "lora",
+    "qlora",
+    "grpo",
+    "gspo",
+    "rlhf",
+]
+
+classifiers = [
+    "Programming Language :: Python :: 3",
+    "Operating System :: MacOS",
+    "Intended Audience :: Developers",
+    "Topic :: Scientific/Engineering :: Artificial Intelligence",
+    "License :: OSI Approved :: MIT License",
+]
+
+[project.urls]
+Homepage = "https://github.com/teilomillet/textpolicy"
+Repository = "https://github.com/teilomillet/textpolicy"
+Documentation = "https://github.com/teilomillet/textpolicy#readme"
+Changelog = "https://github.com/teilomillet/textpolicy/blob/main/CHANGELOG.md"
+
+[project.scripts]
+textpolicy = "textpolicy.cli:main"
+
+[project.optional-dependencies]
+external = ["aiohttp>=3.8.0", "pydantic>=2.0.0"]
+dev = ["pytest>=7.0.0", "black>=22.0.0", "ruff>=0.1.0"]

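The packaging changes above are metadata-only: a new summary, project URLs, keywords, classifiers, and the mlx/mlx-lm floor raised to 0.22.0. A quick way to confirm what an installed build reports is to read the generated metadata back. This is a minimal sketch using only the standard library, assuming textpolicy 0.1.2 is installed:

```python
from importlib import metadata

# Inspect the metadata that setuptools generates from the [project] table above.
dist = metadata.distribution("textpolicy")
print(dist.version)                                      # expected: 0.1.2
print(dist.metadata["Summary"])                          # the new RL-focused summary
print([r for r in (dist.requires or []) if "mlx" in r])  # mlx>=0.22.0, mlx-lm>=0.22.0
```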
textpolicy-0.1.2/tests/test_issue_fixes.py
@@ -0,0 +1,218 @@
+"""
+Tests for GitHub Issues #2 and #3 fixes.
+
+Issue #2: TextGenerationEnv cannot pass example metadata to reward functions
+Issue #3: Export @verifier decorator at top level for API consistency
+"""
+
+import pytest
+
+
+class TestIssue3VerifierExport:
+    """Test that @verifier decorator is exported at top level (Issue #3)."""
+
+    def test_verifier_accessible_at_top_level(self):
+        """tp.verifier should be accessible without deep import."""
+        import textpolicy as tp
+
+        assert hasattr(tp, "verifier"), "verifier should be exported at top level"
+        assert callable(tp.verifier), "verifier should be callable"
+
+    def test_verifier_same_as_deep_import(self):
+        """tp.verifier should be the same function as deep import."""
+        import textpolicy as tp
+        from textpolicy.rewards import verifier
+
+        assert tp.verifier is verifier, "tp.verifier should be same as textpolicy.rewards.verifier"
+
+    def test_verifier_decorator_works(self):
+        """@tp.verifier should work as a decorator."""
+        import textpolicy as tp
+
+        @tp.verifier
+        def my_test_verifier(prompt, completion, example, **kwargs):
+            return len(completion) > 0
+
+        # Verify it's registered and callable
+        assert callable(my_test_verifier)
+        result = my_test_verifier("test prompt", "test completion", {})
+        assert isinstance(result, bool)
+
+    def test_reward_and_verifier_both_at_top_level(self):
+        """Both @reward and @verifier should be at top level for API consistency."""
+        import textpolicy as tp
+
+        assert hasattr(tp, "reward"), "reward should be at top level"
+        assert hasattr(tp, "verifier"), "verifier should be at top level"
+
+
+@pytest.mark.integration
+class TestIssue2ExamplesParameter:
+    """Test that TextGenerationEnv passes example metadata to reward functions (Issue #2)."""
+
+    @pytest.fixture
+    def dummy_tokenizer(self):
+        """Provide a minimal tokenizer for tests."""
+        class DummyTokenizer:
+            def encode(self, text):
+                return [ord(c) % 256 for c in text]
+
+            def decode(self, ids):
+                return "".join(chr(int(i) % 256) for i in ids)
+
+        return DummyTokenizer()
+
+    def test_env_accepts_examples_parameter(self, dummy_tokenizer):
+        """TextGenerationEnv should accept an examples parameter."""
+        from textpolicy.environment.text_generation import TextGenerationEnv
+
+        prompts = ["Hello", "World"]
+        examples = [{"key": "value1"}, {"key": "value2"}]
+
+        def reward_fn(prompt, completion, example, **kwargs):
+            return 1.0
+
+        # Should not raise
+        env = TextGenerationEnv(
+            prompts=prompts,
+            reward_fn=reward_fn,
+            tokenizer=dummy_tokenizer,
+            examples=examples,
+        )
+        assert env.examples == examples
+
+    def test_env_defaults_to_empty_dicts_when_no_examples(self, dummy_tokenizer):
+        """When examples not provided, should default to empty dicts."""
+        from textpolicy.environment.text_generation import TextGenerationEnv
+
+        prompts = ["Hello", "World"]
+
+        def reward_fn(prompt, completion, example, **kwargs):
+            return 1.0
+
+        env = TextGenerationEnv(
+            prompts=prompts, reward_fn=reward_fn, tokenizer=dummy_tokenizer
+        )
+
+        assert env.examples == [{}, {}]
+
+    def test_env_validates_examples_length(self, dummy_tokenizer):
+        """Should raise ValueError if examples length != prompts length."""
+        from textpolicy.environment.text_generation import TextGenerationEnv
+
+        prompts = ["Hello", "World"]
+        examples = [{"key": "value1"}]  # Wrong length
+
+        def reward_fn(prompt, completion, example, **kwargs):
+            return 1.0
+
+        with pytest.raises(ValueError, match="examples length.*must match prompts length"):
+            TextGenerationEnv(
+                prompts=prompts,
+                reward_fn=reward_fn,
+                tokenizer=dummy_tokenizer,
+                examples=examples,
+            )
+
+    def test_example_passed_to_reward_function(self, dummy_tokenizer):
+        """Reward function should receive the correct example for each prompt."""
+        from textpolicy.environment.text_generation import TextGenerationEnv
+
+        prompts = ["Question 1", "Question 2"]
+        examples = [
+            {"db_id": "database_1", "gold_sql": "SELECT 1"},
+            {"db_id": "database_2", "gold_sql": "SELECT 2"},
+        ]
+
+        received_examples = []
+
+        def capture_reward(prompt, completion, example, **kwargs):
+            received_examples.append(example.copy())
+            return 1.0
+
+        env = TextGenerationEnv(
+            prompts=prompts,
+            reward_fn=capture_reward,
+            tokenizer=dummy_tokenizer,
+            examples=examples,
+        )
+
+        # Episode 0 should use examples[0]
+        env.reset()
+        env.step("some response")
+        assert received_examples[0] == {"db_id": "database_1", "gold_sql": "SELECT 1"}
+
+        # Episode 1 should use examples[1]
+        env.reset()
+        env.step("another response")
+        assert received_examples[1] == {"db_id": "database_2", "gold_sql": "SELECT 2"}
+
+    def test_examples_cycle_with_prompts(self, dummy_tokenizer):
+        """Examples should cycle correctly when prompts cycle."""
+        from textpolicy.environment.text_generation import TextGenerationEnv
+
+        prompts = ["P1", "P2"]
+        examples = [{"idx": 0}, {"idx": 1}]
+
+        received_indices = []
+
+        def capture_reward(prompt, completion, example, **kwargs):
+            received_indices.append(example.get("idx"))
+            return 1.0
+
+        env = TextGenerationEnv(
+            prompts=prompts,
+            reward_fn=capture_reward,
+            tokenizer=dummy_tokenizer,
+            examples=examples,
+        )
+
+        # Run 4 episodes (should cycle through prompts twice)
+        for _ in range(4):
+            env.reset()
+            env.step("response")
+
+        # Should have received [0, 1, 0, 1]
+        assert received_indices == [0, 1, 0, 1]
+
+    def test_litmus_test_from_issue(self, dummy_tokenizer):
+        """Run the exact litmus test from Issue #2."""
+        from textpolicy.environment.text_generation import TextGenerationEnv
+        import textpolicy as tp
+
+        examples = [
+            {"db_id": "concert_singer", "gold_sql": "SELECT COUNT(*) FROM singer"},
+            {"db_id": "pets_1", "gold_sql": "SELECT COUNT(*) FROM pets"},
+        ]
+        prompts = [
+            "Schema: singer(id, name)\nQuestion: How many singers?",
+            "Schema: pets(id, name)\nQuestion: How many pets?",
+        ]
+
+        captured_db_ids = []
+
+        @tp.reward
+        def check_example(prompt, completion, example, **kwargs):
+            db_id = example.get("db_id")
+            captured_db_ids.append(db_id)
+            return 1.0
+
+        env = TextGenerationEnv(prompts, check_example, examples=examples, tokenizer=dummy_tokenizer)
+
+        # First episode
+        env.reset()
+        env.step("some action")
+
+        # Should have captured 'concert_singer'
+        assert captured_db_ids[0] == "concert_singer", f"Expected 'concert_singer', got {captured_db_ids[0]}"
+
+        # Second episode
+        env.reset()
+        env.step("another action")
+
+        # Should have captured 'pets_1'
+        assert captured_db_ids[1] == "pets_1", f"Expected 'pets_1', got {captured_db_ids[1]}"
+
+
+if __name__ == "__main__":
+    pytest.main([__file__, "-v"])

textpolicy-0.1.2/tests/test_mlx_compatibility.py
@@ -0,0 +1,315 @@
+"""
+MLX and mlx-lm Compatibility Tests
+
+These tests verify that textpolicy works correctly with MLX 0.30.x and mlx-lm 0.30.x.
+They cover the core APIs used throughout the codebase.
+
+Run with: pytest tests/test_mlx_compatibility.py -v
+"""
+
+import pytest
+import mlx.core as mx
+import mlx.nn as nn
+
+
+@pytest.mark.unit
+class TestMLXCoreAPIs:
+    """Test core MLX APIs used in textpolicy."""
+
+    def test_mx_compile_decorator(self):
+        """Test @mx.compile decorator works (used in grpo.py, gspo.py, trainer.py)."""
+        @mx.compile
+        def compiled_fn(x, y):
+            return x + y
+
+        result = compiled_fn(mx.array([1.0]), mx.array([2.0]))
+        assert float(result[0]) == 3.0
+
+    def test_mx_compile_with_value_and_grad(self):
+        """Test mx.compile with nn.value_and_grad (used in trainer.py:90)."""
+        class SimpleModel(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear = nn.Linear(4, 2)
+            def __call__(self, x):
+                return self.linear(x)
+
+        model = SimpleModel()
+        mx.eval(model.parameters())
+
+        def loss_fn(model, x, y):
+            pred = model(x)
+            return mx.mean((pred - y) ** 2)
+
+        loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
+        x = mx.random.normal((2, 4))
+        y = mx.random.normal((2, 2))
+        loss, grads = loss_and_grad_fn(model, x, y)
+
+        assert not mx.isnan(loss)
+        assert isinstance(grads, dict)
+
+    def test_array_operations(self):
+        """Test array operations used in GRPO/GSPO algorithms."""
+        # Ratio computation (grpo.py, gspo.py)
+        old_lp = mx.array([-1.0, -1.2, -0.8])
+        new_lp = mx.array([-1.1, -1.0, -0.9])
+        ratios = mx.exp(new_lp - old_lp)
+
+        assert ratios.shape == (3,)
+        assert all(not mx.isnan(r) for r in ratios)
+
+        # Clipping (PPO-style)
+        clipped = mx.clip(ratios, 0.8, 1.2)
+        # Use small epsilon for float comparison
+        assert all(0.8 - 1e-6 <= float(c) <= 1.2 + 1e-6 for c in clipped)
+
+    def test_array_slicing_with_python_ints(self):
+        """Test array slicing with Python integers (used in gspo.py)."""
+        arr = mx.array([1.0, 2.0, 3.0, 4.0, 5.0])
+
+        # This pattern is used in compute_sequence_importance_weights
+        current_idx = 0
+        for seq_len in [2, 3]:
+            result = arr[current_idx:current_idx + seq_len]
+            assert result.shape[0] == seq_len
+            current_idx += seq_len
+
+
+@pytest.mark.unit
+class TestMLXLMAPIs:
+    """Test mlx-lm APIs used in textpolicy."""
+
+    def test_mlx_lm_imports(self):
+        """Test mlx_lm core imports (used in mlx_generation.py)."""
+        from mlx_lm import load, generate
+        assert callable(load)
+        assert callable(generate)
+
+    def test_sample_utils_imports(self):
+        """Test sample_utils imports (used in mlx_generation.py:27)."""
+        from mlx_lm.sample_utils import make_sampler, make_logits_processors
+        assert callable(make_sampler)
+        assert callable(make_logits_processors)
+
+    def test_make_sampler_signature(self):
+        """Test make_sampler accepts expected parameters."""
+        from mlx_lm.sample_utils import make_sampler
+
+        # These params are used in mlx_generation.py:77-81
+        sampler = make_sampler(
+            temp=0.7,
+            top_p=0.9,
+            min_p=0.0,
+            min_tokens_to_keep=2
+        )
+        assert sampler is not None
+
+    def test_quantize_model_signature(self):
+        """Test quantize_model has expected parameters (used in lora.py:335)."""
+        from mlx_lm.utils import quantize_model
+        import inspect
+
+        sig = inspect.signature(quantize_model)
+        params = list(sig.parameters.keys())
+
+        # These params are used in lora.py:342-348
+        expected = ['model', 'config', 'q_group_size', 'q_bits', 'quant_predicate']
+        for p in expected:
+            assert p in params, f"Expected param '{p}' not found in quantize_model"
+
+
+@pytest.mark.unit
+@pytest.mark.algorithm
+class TestGRPOWithMLX:
+    """Test GRPO algorithms with current MLX version."""
+
+    def test_compute_advantages(self):
+        """Test GRPO compute_advantages."""
+        from textpolicy.algorithms import grpo
+
+        rewards = mx.array([1.0, 0.5, -0.5, 0.8, 0.2])
+        advantages = grpo.compute_advantages(rewards)
+
+        assert advantages.shape == rewards.shape
+        # Group-relative: mean should be ~0
+        assert abs(float(mx.mean(advantages))) < 1e-5
+
+    def test_compute_advantages_compiled(self):
+        """Test compiled version of compute_advantages."""
+        from textpolicy.algorithms import grpo
+
+        rewards = mx.array([1.0, 0.5, -0.5, 0.8])
+        advantages = grpo.compute_advantages_compiled(rewards)
+
+        assert advantages.shape == rewards.shape
+
+    def test_policy_loss(self):
+        """Test GRPO policy_loss computation."""
+        from textpolicy.algorithms import grpo
+
+        old_lp = mx.array([-1.0, -1.2, -0.8, -1.1, -0.9])
+        new_lp = mx.array([-1.1, -1.0, -0.9, -1.0, -1.0])
+        advantages = mx.array([0.6, 0.1, -0.9, 0.4, -0.2])
+
+        loss = grpo.policy_loss(old_lp, new_lp, advantages, clip_ratio=0.2)
+
+        assert not mx.isnan(loss)
+        assert loss.shape == ()  # Scalar
+
+    def test_entropy_bonus(self):
+        """Test entropy bonus computation."""
+        from textpolicy.algorithms import grpo
+
+        logprobs = mx.array([-1.0, -2.0, -0.5, -1.5])
+        entropy = grpo.entropy_bonus(logprobs, coefficient=0.01)
+
+        assert not mx.isnan(entropy)
+
+
+@pytest.mark.unit
+@pytest.mark.algorithm
+class TestGSPOWithMLX:
+    """Test GSPO algorithms with current MLX version."""
+
+    def test_compute_sequence_importance_weights(self):
+        """Test sequence-level importance weights."""
+        from textpolicy.algorithms import gspo
+
+        old_lp = mx.array([-1.0, -1.2, -0.8, -1.1, -0.9])
+        new_lp = mx.array([-1.1, -1.0, -0.9, -1.0, -1.0])
+        seq_lens = [2, 3]
+
+        weights = gspo.compute_sequence_importance_weights(
+            old_lp, new_lp, seq_lens, clip_ratio=0.2
+        )
+
+        assert len(weights) == len(seq_lens)
+        assert all(not mx.isnan(w) for w in weights)
+
+    def test_compute_hybrid_importance_weights(self):
+        """Test hybrid importance weights."""
+        from textpolicy.algorithms import gspo
+
+        old_lp = mx.array([-1.0, -1.2, -0.8, -1.1, -0.9])
+        new_lp = mx.array([-1.1, -1.0, -0.9, -1.0, -1.0])
+        seq_lens = [2, 3]
+
+        weights = gspo.compute_hybrid_importance_weights(
+            old_lp, new_lp, seq_lens, alpha=0.5, beta=0.5
+        )
+
+        assert len(weights) == sum(seq_lens)  # Token-level weights
+
+    def test_gspo_policy_loss_sequence_variant(self):
+        """Test GSPO policy loss with sequence variant."""
+        from textpolicy.algorithms import gspo
+
+        old_lp = mx.array([-1.0, -1.2, -0.8, -1.1, -0.9])
+        new_lp = mx.array([-1.1, -1.0, -0.9, -1.0, -1.0])
+        seq_lens = [2, 3]
+        advantages = mx.array([0.5, -0.3])
+
+        loss = gspo.gspo_policy_loss(
+            old_lp, new_lp, advantages, seq_lens,
+            variant='sequence', clip_ratio=0.2
+        )
+
+        assert not mx.isnan(loss)
+        assert loss.shape == ()
+
+    def test_gspo_policy_loss_hybrid_variant(self):
+        """Test GSPO policy loss with hybrid variant."""
+        from textpolicy.algorithms import gspo
+
+        old_lp = mx.array([-1.0, -1.2, -0.8, -1.1, -0.9])
+        new_lp = mx.array([-1.1, -1.0, -0.9, -1.0, -1.0])
+        seq_lens = [2, 3]
+        advantages = mx.array([0.5, -0.3])
+
+        loss = gspo.gspo_policy_loss(
+            old_lp, new_lp, advantages, seq_lens,
+            variant='hybrid', clip_ratio=0.2
+        )
+
+        assert not mx.isnan(loss)
+
+    def test_gspo_policy_loss_token_variant(self):
+        """Test GSPO policy loss with token variant (GRPO fallback)."""
+        from textpolicy.algorithms import gspo
+
+        old_lp = mx.array([-1.0, -1.2, -0.8, -1.1, -0.9])
+        new_lp = mx.array([-1.1, -1.0, -0.9, -1.0, -1.0])
+        seq_lens = [2, 3]
+        advantages = mx.array([0.5, -0.3])
+
+        loss = gspo.gspo_policy_loss(
+            old_lp, new_lp, advantages, seq_lens,
+            variant='token', clip_ratio=0.2
+        )
+
+        assert not mx.isnan(loss)
+
+
+@pytest.mark.unit
+class TestTrainerWithMLX:
+    """Test Trainer compilation with MLX."""
+
+    def test_trainer_compiles_loss_function(self):
+        """Test that Trainer correctly compiles the loss function."""
+        import mlx.optimizers as optim
+        from textpolicy.training import Trainer
+        from textpolicy.algorithms import grpo
+
+        class SimpleModel(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear = nn.Linear(10, 5)
+            def __call__(self, x):
+                return self.linear(x)
+
+        model = SimpleModel()
+        mx.eval(model.parameters())
+
+        def loss_fn(model, batch):
+            obs = batch.get('observations', mx.zeros((1, 10)))
+            return mx.mean(model(obs) ** 2)
+
+        trainer = Trainer(
+            model=model,
+            loss_fn=loss_fn,
+            optimizer=optim.Adam(learning_rate=1e-3),
+            advantage_fn=grpo.compute_advantages,
+            compile_training=True
+        )
+
+        # Verify compilation happened
+        assert trainer.loss_and_grad_fn is not None
+
+
+@pytest.mark.unit
+class TestLoRAWithMLX:
+    """Test LoRA functions with MLX."""
+
+    def test_lora_functions_importable(self):
+        """Test all LoRA functions can be imported."""
+        from textpolicy.generation.lora import (
+            apply_lora,
+            freeze_base,
+            extract_params,
+            merge_weights,
+            create_lora_setup,
+            create_qlora_setup,
+            apply_quantization_to_model,
+            compute_lora_memory_savings
+        )
+
+        # All should be callable
+        assert callable(apply_lora)
+        assert callable(freeze_base)
+        assert callable(extract_params)
+        assert callable(merge_weights)
+        assert callable(create_lora_setup)
+        assert callable(create_qlora_setup)
+        assert callable(apply_quantization_to_model)
+        assert callable(compute_lora_memory_savings)

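The GRPO/GSPO checks above only assert shapes and finiteness; the quantity they exercise is the PPO-style clipped importance-ratio objective. The following is a standalone sketch of that math in plain MLX, not textpolicy's actual `grpo.policy_loss` (whose internals may differ):

```python
import mlx.core as mx

def clipped_surrogate(old_lp, new_lp, advantages, clip_ratio=0.2):
    # Importance ratio between the updated policy and the policy that sampled the data.
    ratios = mx.exp(new_lp - old_lp)
    # Clip the ratio to [1 - eps, 1 + eps] and take the pessimistic branch.
    clipped = mx.clip(ratios, 1.0 - clip_ratio, 1.0 + clip_ratio)
    # Negate so the result is a loss to minimize.
    return -mx.mean(mx.minimum(ratios * advantages, clipped * advantages))

old_lp = mx.array([-1.0, -1.2, -0.8, -1.1, -0.9])
new_lp = mx.array([-1.1, -1.0, -0.9, -1.0, -1.0])
advantages = mx.array([0.6, 0.1, -0.9, 0.4, -0.2])
print(clipped_surrogate(old_lp, new_lp, advantages))  # a finite scalar, as the tests expect
```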
textpolicy-0.1.2/tests/test_training_pipeline.py
@@ -0,0 +1,240 @@
+"""
+Training Pipeline Integration Tests
+
+These tests verify that the complete training pipeline works correctly,
+including gradient computation and parameter updates.
+"""
+
+import pytest
+import mlx.core as mx
+import mlx.nn as nn
+import mlx.optimizers as optim
+
+
+@pytest.mark.integration
+class TestTrainingPipeline:
+    """Test the complete training pipeline with gradient updates."""
+
+    def test_grpo_training_step_produces_finite_loss(self):
+        """Test that GRPO training produces finite loss values."""
+        from textpolicy.algorithms import grpo
+
+        # Create simple model
+        class TinyLM(nn.Module):
+            def __init__(self, vocab_size=100, hidden=32):
+                super().__init__()
+                self.embed = nn.Embedding(vocab_size, hidden)
+                self.proj = nn.Linear(hidden, vocab_size)
+
+            def __call__(self, x):
+                return self.proj(self.embed(x))
+
+        model = TinyLM()
+        mx.eval(model.parameters())
+
+        # Prepare batch data (mimics what data_selector returns)
+        old_logprobs = mx.array([-2.5, -3.1, -2.8, -2.9, -3.0])
+        observations = mx.array([10, 20, 30, 40, 50])
+        actions = mx.array([15, 25, 35, 45, 55])
+        rewards = mx.array([1.0, 0.5, -0.5, 0.8, 0.2])
+
+        # Compute new logprobs from model
+        logits = model(observations)
+        log_probs = nn.log_softmax(logits, axis=-1)
+        new_logprobs = mx.take_along_axis(log_probs, actions[:, None], axis=-1).squeeze(-1)
+
+        # Compute GRPO loss
+        advantages = grpo.compute_advantages(rewards)
+        loss = grpo.policy_loss(old_logprobs, new_logprobs, advantages, clip_ratio=0.2)
+
+        mx.eval(loss)
+        assert not mx.isnan(loss), "Loss should not be NaN"
+        assert not mx.isinf(loss), "Loss should not be infinite"
+
+    def test_grpo_gradients_flow(self):
+        """Test that gradients flow through the GRPO loss."""
+        from textpolicy.algorithms import grpo
+
+        class TinyLM(nn.Module):
+            def __init__(self, vocab_size=100, hidden=32):
+                super().__init__()
+                self.embed = nn.Embedding(vocab_size, hidden)
+                self.proj = nn.Linear(hidden, vocab_size)
+
+            def __call__(self, x):
+                return self.proj(self.embed(x))
+
+        model = TinyLM()
+        mx.eval(model.parameters())
+
+        def loss_fn(model):
+            observations = mx.array([10, 20, 30, 40, 50])
+            actions = mx.array([15, 25, 35, 45, 55])
+            old_logprobs = mx.array([-2.5, -3.1, -2.8, -2.9, -3.0])
+            rewards = mx.array([1.0, 0.5, -0.5, 0.8, 0.2])
+
+            logits = model(observations)
+            log_probs = nn.log_softmax(logits, axis=-1)
+            new_logprobs = mx.take_along_axis(log_probs, actions[:, None], axis=-1).squeeze(-1)
+
+            advantages = grpo.compute_advantages(rewards)
+            return grpo.policy_loss(old_logprobs, new_logprobs, advantages, clip_ratio=0.2)
+
+        # Compute gradients
+        loss_and_grad = nn.value_and_grad(model, loss_fn)
+        loss, grads = loss_and_grad(model)
+
+        mx.eval(loss, grads)
+
+        assert not mx.isnan(loss), "Loss should not be NaN"
+
+        # Check that at least some gradients are non-zero
+        has_nonzero_grad = False
+        for name, grad in grads.items():
+            if isinstance(grad, dict):
+                for subname, subgrad in grad.items():
+                    if mx.any(subgrad != 0):
+                        has_nonzero_grad = True
+                        break
+            elif mx.any(grad != 0):
+                has_nonzero_grad = True
+                break
+
+        assert has_nonzero_grad, "Should have at least some non-zero gradients"
+
+    def test_optimizer_updates_parameters(self):
+        """Test that optimizer actually updates model parameters."""
+        from textpolicy.algorithms import grpo
+
+        class TinyLM(nn.Module):
+            def __init__(self, vocab_size=100, hidden=32):
+                super().__init__()
+                self.embed = nn.Embedding(vocab_size, hidden)
+                self.proj = nn.Linear(hidden, vocab_size)
+
+            def __call__(self, x):
+                return self.proj(self.embed(x))
+
+        model = TinyLM()
+        optimizer = optim.Adam(learning_rate=0.01)
+        mx.eval(model.parameters())
+
+        # Store initial parameter values
+        initial_weight = mx.array(model.proj.weight)  # Create a copy
+
+        def loss_fn(model):
+            observations = mx.array([10, 20, 30])
+            actions = mx.array([15, 25, 35])
+            old_logprobs = mx.array([-2.5, -3.1, -2.8])
+            rewards = mx.array([1.0, 0.5, -0.5])
+
+            logits = model(observations)
+            log_probs = nn.log_softmax(logits, axis=-1)
+            new_logprobs = mx.take_along_axis(log_probs, actions[:, None], axis=-1).squeeze(-1)
+
+            advantages = grpo.compute_advantages(rewards)
+            return grpo.policy_loss(old_logprobs, new_logprobs, advantages)
+
+        loss_and_grad = nn.value_and_grad(model, loss_fn)
+
+        # Run a few training steps
+        for _ in range(3):
+            loss, grads = loss_and_grad(model)
+            optimizer.update(model, grads)
+            mx.eval(model.parameters())
+
+        # Check parameters changed
+        final_weight = model.proj.weight
+        params_changed = not mx.allclose(initial_weight, final_weight)
+
+        assert params_changed, "Parameters should change after optimization"
+
+    def test_gspo_training_step_produces_finite_loss(self):
+        """Test that GSPO training produces finite loss values."""
+        from textpolicy.algorithms import gspo
+
+        class TinyLM(nn.Module):
+            def __init__(self, vocab_size=100, hidden=32):
+                super().__init__()
+                self.embed = nn.Embedding(vocab_size, hidden)
+                self.proj = nn.Linear(hidden, vocab_size)
+
+            def __call__(self, x):
+                return self.proj(self.embed(x))
+
+        model = TinyLM()
+        mx.eval(model.parameters())
+
+        # Prepare batch data
+        old_logprobs = mx.array([-2.5, -3.1, -2.8, -2.9, -3.0])
+        observations = mx.array([10, 20, 30, 40, 50])
+        actions = mx.array([15, 25, 35, 45, 55])
+
+        # Compute new logprobs from model
+        logits = model(observations)
+        log_probs = nn.log_softmax(logits, axis=-1)
+        new_logprobs = mx.take_along_axis(log_probs, actions[:, None], axis=-1).squeeze(-1)
+
+        # GSPO needs sequence-level advantages
+        sequence_lengths = [2, 3]  # 2 sequences
+        advantages = mx.array([0.5, -0.3])  # Per-sequence
+
+        # Test all GSPO variants
+        for variant in ['sequence', 'hybrid', 'token']:
+            loss = gspo.gspo_policy_loss(
+                old_logprobs, new_logprobs, advantages, sequence_lengths,
+                variant=variant, clip_ratio=0.2
+            )
+            mx.eval(loss)
+            assert not mx.isnan(loss), f"GSPO {variant} loss should not be NaN"
+            assert not mx.isinf(loss), f"GSPO {variant} loss should not be infinite"
+
+
+@pytest.mark.integration
+class TestCompiledTraining:
+    """Test that @mx.compile works correctly with training."""
+
+    def test_compiled_function_works(self):
+        """Test that @mx.compile decorator works with loss functions."""
+        @mx.compile
+        def compiled_loss(x, y):
+            return mx.mean((x - y) ** 2)
+
+        x = mx.array([1.0, 2.0, 3.0])
+        y = mx.array([1.1, 2.1, 3.1])
+
+        loss = compiled_loss(x, y)
+        mx.eval(loss)
+
+        assert not mx.isnan(loss), "Compiled loss should not be NaN"
+        assert float(loss) > 0, "Loss should be positive"
+
+    def test_trainer_compiled_loss_function(self):
+        """Test that Trainer's compiled loss function works correctly.
+
+        This mimics how Trainer.py:90 compiles the loss function.
+        """
+        from textpolicy.training import Trainer
+        from textpolicy.algorithms import grpo
+
+        class TinyLM(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear = nn.Linear(10, 10)
+            def __call__(self, x):
+                return self.linear(x)
+
+        model = TinyLM()
+        mx.eval(model.parameters())
+
+        # Create trainer with compilation enabled
+        trainer = Trainer(
+            model=model,
+            loss_fn=grpo.policy_loss,
+            optimizer=optim.Adam(learning_rate=1e-3),
+            advantage_fn=grpo.compute_advantages,
+            compile_training=True
+        )
+
+        # Verify the compiled function was created
+        assert trainer.loss_and_grad_fn is not None, "Compiled loss function should exist"

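The three GRPO pipeline tests above each isolate one piece of a training step (loss, gradients, optimizer update). Assembled, the step they jointly exercise looks roughly like the sketch below, reusing the tests' toy TinyLM; it is an illustration of how the pieces fit together, not textpolicy's Trainer loop:

```python
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
from textpolicy.algorithms import grpo

class TinyLM(nn.Module):  # same toy model the tests use
    def __init__(self, vocab_size=100, hidden=32):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden)
        self.proj = nn.Linear(hidden, vocab_size)
    def __call__(self, x):
        return self.proj(self.embed(x))

model, optimizer = TinyLM(), optim.Adam(learning_rate=1e-2)

def loss_fn(model):
    obs = mx.array([10, 20, 30])
    actions = mx.array([15, 25, 35])
    old_lp = mx.array([-2.5, -3.1, -2.8])   # logprobs recorded at rollout time
    rewards = mx.array([1.0, 0.5, -0.5])
    # Gather the current policy's logprobs for the sampled actions.
    new_lp = mx.take_along_axis(
        nn.log_softmax(model(obs), axis=-1), actions[:, None], axis=-1
    ).squeeze(-1)
    return grpo.policy_loss(old_lp, new_lp, grpo.compute_advantages(rewards))

step = nn.value_and_grad(model, loss_fn)     # loss plus gradients w.r.t. model parameters
loss, grads = step(model)
optimizer.update(model, grads)               # one GRPO update
mx.eval(model.parameters(), loss)
```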
{textpolicy-0.1.0 → textpolicy-0.1.2}/textpolicy/__init__.py
@@ -30,7 +30,7 @@ from .validate import validate_installation
 
 # Export core reward functions and the reward decorator
 from .rewards.basic import length_reward, keyword_reward, perplexity_reward, accuracy_reward
-from .rewards.registry import reward
+from .rewards.registry import reward, verifier
 
 # Build __all__ combining submodule __all__ lists and additional symbols
 __all__ = (
@@ -48,5 +48,6 @@ __all__ = (
     "perplexity_reward",
     "accuracy_reward",
     "reward",
+    "verifier",
 ]
 )

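The one-line import change above is what makes `verifier` reachable from the package root (Issue #3), alongside the existing `reward` decorator. Usage, with illustrative function names:

```python
import textpolicy as tp

@tp.verifier  # previously only available as: from textpolicy.rewards import verifier
def ends_with_period(prompt, completion, example, **kwargs):
    # Verifiers return a boolean judgement about the completion.
    return completion.strip().endswith(".")

@tp.reward
def short_completion(prompt, completion, example, **kwargs):
    # Rewards return a float score.
    return 1.0 if len(completion) < 100 else 0.0
```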
{textpolicy-0.1.0 → textpolicy-0.1.2}/textpolicy/environment/text_generation.py
@@ -630,24 +630,31 @@ class TextGenerationEnv(Environment):
         reward_fn: Callable[[str, str, dict], float],
         max_tokens: int = 25,
         seed: int = 42,
-        tokenizer: Any = None
+        tokenizer: Any = None,
+        examples: Optional[List[dict]] = None
     ):
         """
         Initialize simple text generation environment.
-
+
         Args:
             prompts: List of prompts to cycle through
             reward_fn: Function that computes reward from (prompt, completion, example)
             max_tokens: Maximum tokens to generate per response
            seed: Random seed for reproducible behavior
            tokenizer: Tokenizer for converting prompts to tokens (required for MLX compatibility)
+            examples: Optional list of example dicts to pass to reward function. If provided,
+                must have same length as prompts. examples[i] is passed when prompts[i] is used.
         """
         super().__init__()
 
         if tokenizer is None:
             raise ValueError("tokenizer is required for TextGenerationEnv to work with MLX rollout system")
-
+
+        if examples is not None and len(examples) != len(prompts):
+            raise ValueError(f"examples length ({len(examples)}) must match prompts length ({len(prompts)})")
+
         self.prompts = prompts
+        self.examples = examples if examples is not None else [{} for _ in prompts]
         self.reward_fn = reward_fn
         self.max_tokens = max_tokens
         self.tokenizer = tokenizer
@@ -735,10 +742,11 @@ class TextGenerationEnv(Environment):
 
         # Compute reward using provided reward function
         # Pass tokenizer for EOS token detection and truncation detection
+        prompt_index = self.current_episode % len(self.prompts)
         reward = self.reward_fn(
             prompt=self.current_prompt,
             completion=response_text,
-            example=
+            example=self.examples[prompt_index],
             tokenizer=self.tokenizer,  # Pass tokenizer for EOS detection
             truncated=truncated  # Pass truncation flag from environment
         )

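Combined with the docstring addition above, the `examples` parameter lets per-prompt metadata reach the reward function (Issue #2). A usage sketch mirroring the new tests; the prompt, SQL strings, and DummyTokenizer are illustrative stand-ins:

```python
import textpolicy as tp
from textpolicy.environment.text_generation import TextGenerationEnv

prompts = ["Schema: singer(id, name)\nQuestion: How many singers?"]
examples = [{"db_id": "concert_singer", "gold_sql": "SELECT COUNT(*) FROM singer"}]

@tp.reward
def sql_reward(prompt, completion, example, **kwargs):
    # example is examples[i] for the prompt used in the current episode.
    return 1.0 if example["gold_sql"].lower() in completion.lower() else 0.0

class DummyTokenizer:  # stand-in; pass your mlx-lm tokenizer in real runs
    def encode(self, text):
        return [ord(c) % 256 for c in text]
    def decode(self, ids):
        return "".join(chr(int(i) % 256) for i in ids)

env = TextGenerationEnv(prompts, sql_reward, tokenizer=DummyTokenizer(), examples=examples)
```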
{textpolicy-0.1.0 → textpolicy-0.1.2/textpolicy.egg-info}/PKG-INFO
@@ -1,13 +1,23 @@
 Metadata-Version: 2.4
 Name: textpolicy
-Version: 0.1.0
-Summary: MLX-optimized reward and verification system for text generation RL
+Version: 0.1.2
+Summary: Reinforcement learning for text generation on MLX (Apple Silicon): GRPO/GSPO, environments, rollout, rewards, LoRA/QLoRA
+Project-URL: Homepage, https://github.com/teilomillet/textpolicy
+Project-URL: Repository, https://github.com/teilomillet/textpolicy
+Project-URL: Documentation, https://github.com/teilomillet/textpolicy#readme
+Project-URL: Changelog, https://github.com/teilomillet/textpolicy/blob/main/CHANGELOG.md
+Keywords: reinforcement-learning,text-generation,mlx,apple-silicon,lora,qlora,grpo,gspo,rlhf
+Classifier: Programming Language :: Python :: 3
+Classifier: Operating System :: MacOS
+Classifier: Intended Audience :: Developers
+Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
+Classifier: License :: OSI Approved :: MIT License
 Requires-Python: >=3.12
 Description-Content-Type: text/markdown
 License-File: LICENSE
 Requires-Dist: numpy>=2.3.2
-Requires-Dist: mlx>=0.21.0
-Requires-Dist: mlx-lm>=0.21.0
+Requires-Dist: mlx>=0.22.0
+Requires-Dist: mlx-lm>=0.22.0
 Requires-Dist: gymnasium>=0.29.0
 Requires-Dist: psutil>=7.0.0
 Requires-Dist: wandb>=0.21.1

{textpolicy-0.1.0 → textpolicy-0.1.2}/textpolicy.egg-info/SOURCES.txt
@@ -3,9 +3,12 @@ README.md
 pyproject.toml
 tests/test_gspo_verification.py
 tests/test_integration_e2e_training.py
+tests/test_issue_fixes.py
+tests/test_mlx_compatibility.py
 tests/test_reward_signatures.py
 tests/test_rollout_rewards.py
 tests/test_runner_step_enforcement.py
+tests/test_training_pipeline.py
 tests/test_validate_installation.py
 textpolicy/__init__.py
 textpolicy/__main__.py

textpolicy-0.1.0/pyproject.toml DELETED
@@ -1,30 +0,0 @@
-[build-system]
-requires = ["setuptools>=61.0", "wheel"]
-build-backend = "setuptools.build_meta"
-
-[tool.setuptools.packages.find]
-include = ["textpolicy*"]
-
-[project]
-name = "textpolicy"
-version = "0.1.0"
-description = "MLX-optimized reward and verification system for text generation RL"
-readme = "README.md"
-requires-python = ">=3.12"
-dependencies = [
-    "numpy>=2.3.2",
-    "mlx>=0.21.0",     # Core MLX framework for Apple Silicon acceleration
-    "mlx-lm>=0.21.0",  # MLX language models for inference
-    "gymnasium>=0.29.0",
-    "psutil>=7.0.0",
-    "wandb>=0.21.1",
-    "aiohttp>=3.12.15",
-    "pytest>=8.4.1",
-]
-
-[project.scripts]
-textpolicy = "textpolicy.cli:main"
-
-[project.optional-dependencies]
-external = ["aiohttp>=3.8.0", "pydantic>=2.0.0"]
-dev = ["pytest>=7.0.0", "black>=22.0.0", "ruff>=0.1.0"]