textpolicy 0.0.1__tar.gz → 0.1.1__tar.gz
This diff shows the changes between package versions as published to their public registry. It is provided for informational purposes only.
- textpolicy-0.1.1/PKG-INFO +109 -0
- textpolicy-0.1.1/README.md +75 -0
- textpolicy-0.1.1/pyproject.toml +56 -0
- textpolicy-0.1.1/tests/test_gspo_verification.py +215 -0
- textpolicy-0.1.1/tests/test_integration_e2e_training.py +49 -0
- textpolicy-0.1.1/tests/test_issue_fixes.py +218 -0
- textpolicy-0.1.1/tests/test_reward_signatures.py +131 -0
- textpolicy-0.1.1/tests/test_rollout_rewards.py +228 -0
- textpolicy-0.1.1/tests/test_runner_step_enforcement.py +80 -0
- textpolicy-0.1.1/tests/test_validate_installation.py +12 -0
- textpolicy-0.1.1/textpolicy/__init__.py +53 -0
- textpolicy-0.1.1/textpolicy/__main__.py +8 -0
- textpolicy-0.1.1/textpolicy/algorithms/__init__.py +54 -0
- textpolicy-0.1.1/textpolicy/algorithms/grpo.py +642 -0
- textpolicy-0.1.1/textpolicy/algorithms/gspo.py +582 -0
- textpolicy-0.1.1/textpolicy/buffer/__init__.py +23 -0
- textpolicy-0.1.1/textpolicy/buffer/buffer.py +244 -0
- textpolicy-0.1.1/textpolicy/buffer/episode.py +383 -0
- textpolicy-0.1.1/textpolicy/buffer/sampling.py +438 -0
- textpolicy-0.1.1/textpolicy/buffer/storage.py +255 -0
- textpolicy-0.1.1/textpolicy/cli.py +67 -0
- textpolicy-0.1.1/textpolicy/environment/__init__.py +79 -0
- textpolicy-0.1.1/textpolicy/environment/base.py +110 -0
- textpolicy-0.1.1/textpolicy/environment/environment.py +46 -0
- textpolicy-0.1.1/textpolicy/environment/factory.py +103 -0
- textpolicy-0.1.1/textpolicy/environment/gym.py +106 -0
- textpolicy-0.1.1/textpolicy/environment/task_suites.py +51 -0
- textpolicy-0.1.1/textpolicy/environment/text_generation.py +797 -0
- textpolicy-0.1.1/textpolicy/environment/vectorized.py +253 -0
- textpolicy-0.1.1/textpolicy/generation/__init__.py +62 -0
- textpolicy-0.1.1/textpolicy/generation/lora.py +411 -0
- textpolicy-0.1.1/textpolicy/generation/mlx_generation.py +557 -0
- textpolicy-0.1.1/textpolicy/generation/reload.py +253 -0
- textpolicy-0.1.1/textpolicy/rewards/__init__.py +137 -0
- textpolicy-0.1.1/textpolicy/rewards/adapters.py +387 -0
- textpolicy-0.1.1/textpolicy/rewards/basic.py +214 -0
- textpolicy-0.1.1/textpolicy/rewards/integrated_system.py +338 -0
- textpolicy-0.1.1/textpolicy/rewards/mlx_batch_processor.py +447 -0
- textpolicy-0.1.1/textpolicy/rewards/registry.py +293 -0
- textpolicy-0.1.1/textpolicy/rewards/rollout_rewards.py +410 -0
- textpolicy-0.1.1/textpolicy/rewards/verifiers.py +369 -0
- textpolicy-0.1.1/textpolicy/rollout/__init__.py +44 -0
- textpolicy-0.1.1/textpolicy/rollout/aggregator.py +145 -0
- textpolicy-0.1.1/textpolicy/rollout/base.py +108 -0
- textpolicy-0.1.1/textpolicy/rollout/rollout.py +142 -0
- textpolicy-0.1.1/textpolicy/rollout/runner.py +280 -0
- textpolicy-0.1.1/textpolicy/rollout/strategy.py +208 -0
- textpolicy-0.1.1/textpolicy/rollout/worker.py +194 -0
- textpolicy-0.1.1/textpolicy/training/__init__.py +14 -0
- textpolicy-0.1.1/textpolicy/training/metrics.py +242 -0
- textpolicy-0.1.1/textpolicy/training/rollout_manager.py +78 -0
- textpolicy-0.1.1/textpolicy/training/trainer.py +684 -0
- textpolicy-0.1.1/textpolicy/utils/__init__.py +40 -0
- textpolicy-0.1.1/textpolicy/utils/benchmarking.py +489 -0
- textpolicy-0.1.1/textpolicy/utils/data.py +60 -0
- textpolicy-0.1.1/textpolicy/utils/debug.py +170 -0
- textpolicy-0.1.1/textpolicy/utils/environment.py +349 -0
- textpolicy-0.1.1/textpolicy/utils/logging/__init__.py +22 -0
- textpolicy-0.1.1/textpolicy/utils/logging/base.py +48 -0
- textpolicy-0.1.1/textpolicy/utils/logging/console.py +61 -0
- textpolicy-0.1.1/textpolicy/utils/logging/factory.py +133 -0
- textpolicy-0.1.1/textpolicy/utils/logging/multi.py +83 -0
- textpolicy-0.1.1/textpolicy/utils/logging/tensorboard.py +65 -0
- textpolicy-0.1.1/textpolicy/utils/logging/wandb.py +72 -0
- textpolicy-0.1.1/textpolicy/utils/memory.py +118 -0
- textpolicy-0.1.1/textpolicy/utils/performance.py +464 -0
- textpolicy-0.1.1/textpolicy/utils/timing.py +171 -0
- textpolicy-0.1.1/textpolicy/validate.py +101 -0
- textpolicy-0.1.1/textpolicy/validation/__init__.py +13 -0
- textpolicy-0.1.1/textpolicy/validation/logprob_validation.py +315 -0
- textpolicy-0.1.1/textpolicy.egg-info/PKG-INFO +109 -0
- textpolicy-0.1.1/textpolicy.egg-info/SOURCES.txt +76 -0
- textpolicy-0.1.1/textpolicy.egg-info/entry_points.txt +2 -0
- textpolicy-0.1.1/textpolicy.egg-info/requires.txt +17 -0
- textpolicy-0.0.1/PKG-INFO +0 -10
- textpolicy-0.0.1/README.md +0 -1
- textpolicy-0.0.1/pyproject.toml +0 -7
- textpolicy-0.0.1/textpolicy/__init__.py +0 -0
- textpolicy-0.0.1/textpolicy.egg-info/PKG-INFO +0 -10
- textpolicy-0.0.1/textpolicy.egg-info/SOURCES.txt +0 -8
- {textpolicy-0.0.1 → textpolicy-0.1.1}/LICENSE +0 -0
- {textpolicy-0.0.1 → textpolicy-0.1.1}/setup.cfg +0 -0
- {textpolicy-0.0.1 → textpolicy-0.1.1}/textpolicy.egg-info/dependency_links.txt +0 -0
- {textpolicy-0.0.1 → textpolicy-0.1.1}/textpolicy.egg-info/top_level.txt +0 -0
--- /dev/null
+++ textpolicy-0.1.1/PKG-INFO
@@ -0,0 +1,109 @@
+Metadata-Version: 2.4
+Name: textpolicy
+Version: 0.1.1
+Summary: Reinforcement learning for text generation on MLX (Apple Silicon): GRPO/GSPO, environments, rollout, rewards, LoRA/QLoRA
+Project-URL: Homepage, https://github.com/teilomillet/textpolicy
+Project-URL: Repository, https://github.com/teilomillet/textpolicy
+Project-URL: Documentation, https://github.com/teilomillet/textpolicy#readme
+Project-URL: Changelog, https://github.com/teilomillet/textpolicy/blob/main/CHANGELOG.md
+Keywords: reinforcement-learning,text-generation,mlx,apple-silicon,lora,qlora,grpo,gspo,rlhf
+Classifier: Programming Language :: Python :: 3
+Classifier: Operating System :: MacOS
+Classifier: Intended Audience :: Developers
+Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
+Classifier: License :: OSI Approved :: MIT License
+Requires-Python: >=3.12
+Description-Content-Type: text/markdown
+License-File: LICENSE
+Requires-Dist: numpy>=2.3.2
+Requires-Dist: mlx>=0.21.0
+Requires-Dist: mlx-lm>=0.21.0
+Requires-Dist: gymnasium>=0.29.0
+Requires-Dist: psutil>=7.0.0
+Requires-Dist: wandb>=0.21.1
+Requires-Dist: aiohttp>=3.12.15
+Requires-Dist: pytest>=8.4.1
+Provides-Extra: external
+Requires-Dist: aiohttp>=3.8.0; extra == "external"
+Requires-Dist: pydantic>=2.0.0; extra == "external"
+Provides-Extra: dev
+Requires-Dist: pytest>=7.0.0; extra == "dev"
+Requires-Dist: black>=22.0.0; extra == "dev"
+Requires-Dist: ruff>=0.1.0; extra == "dev"
+Dynamic: license-file
+
+# TextPolicy
+
+Reinforcement learning toolkit for text generation on MLX (Apple Silicon).
+TextPolicy provides algorithms (GRPO/GSPO), text-generation environments, a rollout runner,
+reward functions with a decorator registry, and LoRA/QLoRA utilities.
+
+## Install (uv)
+
+```bash
+uv add textpolicy
+```
+
+Optional model integration:
+
+```bash
+uv add mlx mlx-lm
+```
+
+## Quickstart
+
+Working example using a real model and tokenizer (mlx-lm required):
+
+```python
+import mlx.core as mx
+import textpolicy as tp
+from textpolicy import load_model, create_policy
+from textpolicy.environment.text_generation import TextGenerationEnv
+from textpolicy.rollout import RolloutRunner, create_strategy
+
+# 1) Load model and tokenizer (mlx-lm)
+model, tokenizer = load_model("Qwen/Qwen3-0.6B")
+
+# 2) Create a policy (controls generation)
+generation_params = {"max_tokens": 25, "temperature": 0.7}
+policy_fn = create_policy(model, tokenizer, generation_params)
+
+# 3) Define a reward function (env uses this to score responses)
+@tp.reward
+def length_reward(prompt: str, completion: str, example: dict, **kwargs) -> float:
+    return float(len(completion.split()))
+
+# 4) Create an environment (requires a tokenizer)
+env = TextGenerationEnv(["What is AI?"], length_reward, tokenizer=tokenizer)
+
+# 5) Collect one rollout step
+strategy = create_strategy('grpo')
+runner = RolloutRunner(env, policy=policy_fn, strategy=strategy, max_steps=1)
+buffer = runner.collect()
+print(len(buffer.episodes))
+```
+
+Docs:
+- Quickstart: `docs/QUICKSTART_UV.md`
+- LoRA/QLoRA: `docs/10_lora_qlora.md`
+- Full index: `docs/index.md`
+
+FAQ:
+- Do I need a model?
+  - Yes for generation with `create_policy`.
+    Use `load_model()` (mlx‑lm) to get `(model, tokenizer)`.
+    For reward‑only code (no generation), a model is not required.
+- Do I need a tokenizer?
+  - Yes.
+    Both `TextGenerationEnv` and `TextGenerationEnvironment` require a tokenizer.
+    `load_model()` returns one for mlx‑lm models.
+- How do I control generation?
+  - Pass `generation_params` to `create_policy` (for example, `max_tokens`, `temperature`, `top_p`, `repetition_penalty`).
+- What does `step()` return?
+  - A dict with `observation`, `reward`, `terminated`, `truncated`, `info`. The runner enforces this.
+
+Examples:
+- 01–06: reward functions, batch processing, minimal training
+- 08: GRPO training with rollout + buffer
+- 09–10: length reduction (GRPO/GSPO)
+- 11: LoRA/QLoRA configuration
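The quickstart embedded in the PKG-INFO above defines rewards as plain functions with the signature `(prompt, completion, example, **kwargs) -> float`, registered with the `@tp.reward` decorator. As a rough illustration of how a different criterion would slot into the same `TextGenerationEnv`, here is a minimal sketch; the keyword criterion and the `example["keyword"]` field are illustrative assumptions, not part of the package's documented API.

```python
import textpolicy as tp

@tp.reward
def keyword_reward(prompt: str, completion: str, example: dict, **kwargs) -> float:
    # Illustrative criterion: 1.0 if the completion mentions a target keyword.
    # The "keyword" field of `example` is an assumption made for this sketch.
    target = (example or {}).get("keyword", "AI")
    return 1.0 if target.lower() in completion.lower() else 0.0

# It would then be passed to the environment exactly like length_reward above:
# env = TextGenerationEnv(["What is AI?"], keyword_reward, tokenizer=tokenizer)
```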
--- /dev/null
+++ textpolicy-0.1.1/README.md
@@ -0,0 +1,75 @@
+# TextPolicy
+
+Reinforcement learning toolkit for text generation on MLX (Apple Silicon).
+TextPolicy provides algorithms (GRPO/GSPO), text-generation environments, a rollout runner,
+reward functions with a decorator registry, and LoRA/QLoRA utilities.
+
+## Install (uv)
+
+```bash
+uv add textpolicy
+```
+
+Optional model integration:
+
+```bash
+uv add mlx mlx-lm
+```
+
+## Quickstart
+
+Working example using a real model and tokenizer (mlx-lm required):
+
+```python
+import mlx.core as mx
+import textpolicy as tp
+from textpolicy import load_model, create_policy
+from textpolicy.environment.text_generation import TextGenerationEnv
+from textpolicy.rollout import RolloutRunner, create_strategy
+
+# 1) Load model and tokenizer (mlx-lm)
+model, tokenizer = load_model("Qwen/Qwen3-0.6B")
+
+# 2) Create a policy (controls generation)
+generation_params = {"max_tokens": 25, "temperature": 0.7}
+policy_fn = create_policy(model, tokenizer, generation_params)
+
+# 3) Define a reward function (env uses this to score responses)
+@tp.reward
+def length_reward(prompt: str, completion: str, example: dict, **kwargs) -> float:
+    return float(len(completion.split()))
+
+# 4) Create an environment (requires a tokenizer)
+env = TextGenerationEnv(["What is AI?"], length_reward, tokenizer=tokenizer)
+
+# 5) Collect one rollout step
+strategy = create_strategy('grpo')
+runner = RolloutRunner(env, policy=policy_fn, strategy=strategy, max_steps=1)
+buffer = runner.collect()
+print(len(buffer.episodes))
+```
+
+Docs:
+- Quickstart: `docs/QUICKSTART_UV.md`
+- LoRA/QLoRA: `docs/10_lora_qlora.md`
+- Full index: `docs/index.md`
+
+FAQ:
+- Do I need a model?
+  - Yes for generation with `create_policy`.
+    Use `load_model()` (mlx‑lm) to get `(model, tokenizer)`.
+    For reward‑only code (no generation), a model is not required.
+- Do I need a tokenizer?
+  - Yes.
+    Both `TextGenerationEnv` and `TextGenerationEnvironment` require a tokenizer.
+    `load_model()` returns one for mlx‑lm models.
+- How do I control generation?
+  - Pass `generation_params` to `create_policy` (for example, `max_tokens`, `temperature`, `top_p`, `repetition_penalty`).
+- What does `step()` return?
+  - A dict with `observation`, `reward`, `terminated`, `truncated`, `info`. The runner enforces this.
+
+Examples:
+- 01–06: reward functions, batch processing, minimal training
+- 08: GRPO training with rollout + buffer
+- 09–10: length reduction (GRPO/GSPO)
+- 11: LoRA/QLoRA configuration
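The README FAQ above states that generation is controlled by passing `generation_params` to `create_policy`, naming `max_tokens`, `temperature`, `top_p`, and `repetition_penalty` as accepted keys. A minimal sketch of a fuller configuration, assuming the same `load_model`/`create_policy` API as the quickstart; the specific values are placeholders, not recommendations:

```python
from textpolicy import load_model, create_policy

model, tokenizer = load_model("Qwen/Qwen3-0.6B")

# All four keys are named in the README FAQ; the values are illustrative.
generation_params = {
    "max_tokens": 64,           # cap completion length
    "temperature": 0.7,         # sampling temperature
    "top_p": 0.9,               # nucleus sampling cutoff
    "repetition_penalty": 1.1,  # discourage repeated tokens
}
policy_fn = create_policy(model, tokenizer, generation_params)
```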
--- /dev/null
+++ textpolicy-0.1.1/pyproject.toml
@@ -0,0 +1,56 @@
+[build-system]
+requires = ["setuptools>=61.0", "wheel"]
+build-backend = "setuptools.build_meta"
+
+[tool.setuptools.packages.find]
+include = ["textpolicy*"]
+
+[project]
+name = "textpolicy"
+version = "0.1.1"
+description = "Reinforcement learning for text generation on MLX (Apple Silicon): GRPO/GSPO, environments, rollout, rewards, LoRA/QLoRA"
+readme = "README.md"
+requires-python = ">=3.12"
+dependencies = [
+    "numpy>=2.3.2",
+    "mlx>=0.21.0",  # Core MLX framework for Apple Silicon acceleration
+    "mlx-lm>=0.21.0",  # MLX language models for inference
+    "gymnasium>=0.29.0",
+    "psutil>=7.0.0",
+    "wandb>=0.21.1",
+    "aiohttp>=3.12.15",
+    "pytest>=8.4.1",
+]
+
+keywords = [
+    "reinforcement-learning",
+    "text-generation",
+    "mlx",
+    "apple-silicon",
+    "lora",
+    "qlora",
+    "grpo",
+    "gspo",
+    "rlhf",
+]
+
+classifiers = [
+    "Programming Language :: Python :: 3",
+    "Operating System :: MacOS",
+    "Intended Audience :: Developers",
+    "Topic :: Scientific/Engineering :: Artificial Intelligence",
+    "License :: OSI Approved :: MIT License",
+]
+
+[project.urls]
+Homepage = "https://github.com/teilomillet/textpolicy"
+Repository = "https://github.com/teilomillet/textpolicy"
+Documentation = "https://github.com/teilomillet/textpolicy#readme"
+Changelog = "https://github.com/teilomillet/textpolicy/blob/main/CHANGELOG.md"
+
+[project.scripts]
+textpolicy = "textpolicy.cli:main"
+
+[project.optional-dependencies]
+external = ["aiohttp>=3.8.0", "pydantic>=2.0.0"]
+dev = ["pytest>=7.0.0", "black>=22.0.0", "ruff>=0.1.0"]
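The `[project.scripts]` table above registers a `textpolicy` console script pointing at `textpolicy.cli:main`, and `[project.optional-dependencies]` declares `external` and `dev` extras. A small sketch of how an installed copy could be inspected with the standard library, assuming textpolicy 0.1.1 has been installed from this metadata:

```python
from importlib.metadata import version, entry_points

# Distribution version as declared in [project].
print(version("textpolicy"))  # expected: 0.1.1

# Console script registered by [project.scripts].
scripts = entry_points(group="console_scripts")
tp_script = next(ep for ep in scripts if ep.name == "textpolicy")
print(tp_script.value)  # expected: textpolicy.cli:main
```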
--- /dev/null
+++ textpolicy-0.1.1/tests/test_gspo_verification.py
@@ -0,0 +1,215 @@
+"""
+GSPO Verification Tests - Comprehensive Testing of GSPO Implementation
+
+This test module verifies that GSPO is working correctly by testing:
+1. Basic functionality of GSPO components
+2. Comparison with GRPO behavior
+3. Sequence-level vs token-level importance sampling
+4. Mathematical correctness of importance weights
+5. Training dynamics and convergence
+"""
+
+import pytest
+import mlx.core as mx
+import mlx.optimizers as optim
+import numpy as np
+from textpolicy.algorithms import grpo, gspo
+from textpolicy.generation.mlx_generation import load_model, create_policy
+from textpolicy.rollout import RolloutCoordinator
+from textpolicy.buffer import Buffer
+from textpolicy.training import Trainer
+
+
+@pytest.mark.unit
+@pytest.mark.algorithm
+class TestGSPOBasicFunctionality:
+    """Test basic GSPO functions work correctly."""
+
+    def test_sequence_importance_weights(self):
+        """Test sequence-level importance weights computation."""
+        # Create test data
+        old_logprobs = mx.array([-1.0, -1.2, -0.8, -1.1, -0.9])  # 5 tokens
+        new_logprobs = mx.array([-1.1, -1.0, -0.9, -1.0, -1.0])  # 5 tokens
+        sequence_lengths = [2, 3]  # Two sequences: 2 tokens + 3 tokens
+
+        # Test sequence-level importance weights
+        seq_weights = gspo.compute_sequence_importance_weights(
+            old_logprobs, new_logprobs, sequence_lengths, clip_ratio=0.2
+        )
+
+        assert len(seq_weights) == len(sequence_lengths), \
+            f"Expected {len(sequence_lengths)} weights, got {len(seq_weights)}"
+        assert all(not mx.isnan(w) and not mx.isinf(w) for w in seq_weights), \
+            "All weights should be finite"
+
+    def test_gspo_policy_loss(self):
+        """Test GSPO policy loss computation."""
+        old_logprobs = mx.array([-1.0, -1.2, -0.8, -1.1, -0.9])
+        new_logprobs = mx.array([-1.1, -1.0, -0.9, -1.0, -1.0])
+        sequence_lengths = [2, 3]
+        advantages = mx.array([0.5, -0.3])  # Advantages for each sequence
+
+        # Test GSPO policy loss
+        loss = gspo.gspo_policy_loss(
+            old_logprobs, new_logprobs, advantages, sequence_lengths, variant="sequence"
+        )
+
+        assert not mx.isnan(loss) and not mx.isinf(loss), "Loss should be finite"
+        assert isinstance(loss, mx.array), "Loss should be an MLX array"
+
+    def test_hybrid_importance_weights(self):
+        """Test hybrid importance weights computation."""
+        old_logprobs = mx.array([-1.0, -1.2, -0.8, -1.1, -0.9])
+        new_logprobs = mx.array([-1.1, -1.0, -0.9, -1.0, -1.0])
+        sequence_lengths = [2, 3]
+
+        # Test hybrid variant
+        hybrid_weights = gspo.compute_hybrid_importance_weights(
+            old_logprobs, new_logprobs, sequence_lengths
+        )
+
+        assert len(hybrid_weights) == len(old_logprobs), \
+            f"Expected {len(old_logprobs)} hybrid weights, got {len(hybrid_weights)}"
+        assert all(not mx.isnan(w) and not mx.isinf(w) for w in hybrid_weights), \
+            "All hybrid weights should be finite"
+
+
+@pytest.mark.unit
+@pytest.mark.algorithm
+class TestGSPOvsGRPO:
+    """Test that GSPO produces different importance weights than GRPO."""
+
+    def test_importance_weight_differences(self):
+        """Test that GSPO produces different importance weights than GRPO."""
+        # Create test data with clear differences between old and new policies
+        old_logprobs = mx.array([-2.0, -2.0, -1.0, -1.0])  # 4 tokens
+        new_logprobs = mx.array([-1.0, -1.0, -2.0, -2.0])  # Policy changed significantly
+        sequence_lengths = [2, 2]  # Two sequences of equal length
+
+        # Compute GSPO sequence-level importance weights
+        gspo_weights = gspo.compute_sequence_importance_weights(
+            old_logprobs, new_logprobs, sequence_lengths, clip_ratio=0.2
+        )
+
+        # Compute GRPO token-level importance ratios for comparison
+        grpo_ratios = mx.exp(new_logprobs - old_logprobs)
+        grpo_ratios_clipped = mx.clip(grpo_ratios, 0.8, 1.2)
+
+        # GSPO should produce sequence-level weights (2 values)
+        # GRPO produces token-level ratios (4 values)
+        assert len(gspo_weights) == len(sequence_lengths), \
+            f"GSPO should produce {len(sequence_lengths)} sequence weights"
+        assert len(grpo_ratios) == len(old_logprobs), \
+            f"GRPO should produce {len(old_logprobs)} token ratios"
+
+        # The approaches should be fundamentally different
+        # GSPO normalizes by sequence length, GRPO doesn't
+        assert len(gspo_weights) != len(grpo_ratios), \
+            "GSPO and GRPO should produce different numbers of weights"
+
+
+@pytest.mark.unit
+@pytest.mark.algorithm
+class TestGSPOClipping:
+    """Test GSPO clipping behavior."""
+
+    def test_clipping_bounds_respected(self):
+        """Test that importance weights respect clipping bounds."""
+        # Test extreme case to verify clipping
+        old_logprobs = mx.array([-10.0, -1.0])  # Extreme difference
+        new_logprobs = mx.array([-1.0, -1.0])
+        sequence_lengths = [2]
+        clip_ratio = 0.2
+
+        # Compute sequence weights
+        weights = gspo.compute_sequence_importance_weights(
+            old_logprobs, new_logprobs, sequence_lengths, clip_ratio=clip_ratio
+        )
+
+        # Weights should be clipped between (1-clip_ratio) and (1+clip_ratio)
+        lower_bound = 1.0 - clip_ratio
+        upper_bound = 1.0 + clip_ratio
+
+        # Use tolerance for floating-point comparisons due to MLX float32 precision
+        # MLX uses float32 by default, which has precision ~1.19e-7
+        tolerance = 1e-6  # Conservative tolerance for float32 precision issues
+
+        for weight in weights:
+            weight_val = float(weight)  # Convert MLX scalar to Python float
+            assert lower_bound - tolerance <= weight_val <= upper_bound + tolerance, \
+                f"Weight {weight_val} outside clipping bounds [{lower_bound}, {upper_bound}] with tolerance {tolerance}"
+
+    def test_length_normalization_effect(self):
+        """Test that GSPO properly normalizes by sequence length."""
+        # Identical sequences of different lengths should have similar weights
+        old_logprobs_short = mx.array([-1.0, -1.0])  # 2 tokens
+        new_logprobs_short = mx.array([-0.5, -0.5])  # Better by 0.5 per token
+
+        old_logprobs_long = mx.array([-1.0, -1.0, -1.0, -1.0])  # 4 tokens
+        new_logprobs_long = mx.array([-0.5, -0.5, -0.5, -0.5])  # Better by 0.5 per token
+
+        weight_short = gspo.compute_sequence_importance_weights(
+            old_logprobs_short, new_logprobs_short, [2], clip_ratio=1.0  # No clipping
+        )
+        weight_long = gspo.compute_sequence_importance_weights(
+            old_logprobs_long, new_logprobs_long, [4], clip_ratio=1.0  # No clipping
+        )
+
+        # Both should be similar due to length normalization
+        # Short: exp((sum(-0.5) - sum(-1.0)) / 2) = exp((−1.0 − (−2.0)) / 2) = exp(0.5)
+        # Long: exp((sum(-0.5) - sum(-1.0)) / 4) = exp((−2.0 − (−4.0)) / 4) = exp(0.5)
+        short_val = float(weight_short[0])
+        long_val = float(weight_long[0])
+
+        # They should be approximately equal due to length normalization
+        assert abs(short_val - long_val) < 0.01, \
+            f"Length normalization failed: short={short_val}, long={long_val}"
+
+
+@pytest.mark.integration
+@pytest.mark.algorithm
+@pytest.mark.slow
+class TestGSPOTraining:
+    """Integration tests for GSPO training."""
+
+    def test_gspo_training_step(self):
+        """Test a complete GSPO training step."""
+        # This is a minimal integration test
+        # Create minimal test data
+        old_logprobs = mx.array([-1.0, -1.0, -1.0, -1.0])
+        new_logprobs = mx.array([-0.8, -0.8, -1.2, -1.2])
+        advantages = mx.array([0.5, -0.3])
+        sequence_lengths = [2, 2]
+
+        # Test that we can compute a complete loss
+        loss = gspo.gspo_policy_loss(
+            old_logprobs, new_logprobs, advantages, sequence_lengths, variant="sequence"
+        )
+
+        assert not mx.isnan(loss) and not mx.isinf(loss), "Training loss should be finite"
+        assert float(loss) != 0.0, "Loss should be non-zero for non-trivial inputs"
+
+    def test_gspo_metrics_computation(self):
+        """Test GSPO metrics computation."""
+        old_logprobs = mx.array([-1.0, -1.0, -1.0, -1.0])
+        new_logprobs = mx.array([-0.8, -0.8, -1.2, -1.2])
+        advantages = mx.array([0.5, -0.3])
+
+        # Test metrics computation
+        metrics_fn = gspo.create_gspo_metrics(variant="sequence")
+        metrics = metrics_fn(old_logprobs, new_logprobs, advantages)
+
+        assert isinstance(metrics, dict), "Metrics should be a dictionary"
+        assert len(metrics) > 0, "Metrics should not be empty"
+
+        # Check for expected metric keys
+        expected_keys = ['mean_advantage', 'std_advantage']
+        for key in expected_keys:
+            assert key in metrics, f"Missing expected metric: {key}"
+            assert isinstance(metrics[key], (int, float)), \
+                f"Metric {key} should be numeric, got {type(metrics[key])}"
+
+
+if __name__ == "__main__":
+    # Allow running this file directly for debugging
+    pytest.main([__file__, "-v"])
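The comments in `test_length_normalization_effect` above work through the length-normalized importance weight, exp((sum(new) - sum(old)) / length), arriving at exp(0.5) for both the 2-token and the 4-token case. A standalone recomputation of that arithmetic in plain Python (no textpolicy or MLX dependency; the formula is the one assumed by the test's comments):

```python
import math

def sequence_weight(old_logprobs, new_logprobs):
    # Length-normalized sequence ratio, as described in the test comments.
    return math.exp((sum(new_logprobs) - sum(old_logprobs)) / len(old_logprobs))

short = sequence_weight([-1.0, -1.0], [-0.5, -0.5])   # exp(1.0 / 2) = exp(0.5)
long_ = sequence_weight([-1.0] * 4, [-0.5] * 4)       # exp(2.0 / 4) = exp(0.5)

assert abs(short - long_) < 1e-12
print(round(short, 4), round(long_, 4))  # 1.6487 1.6487
```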
--- /dev/null
+++ textpolicy-0.1.1/tests/test_integration_e2e_training.py
@@ -0,0 +1,49 @@
+import pytest
+
+
+@pytest.mark.integration
+def test_e2e_minimal_rollout_grpo():
+    """
+    Minimal end-to-end rollout + buffer collection using TextGenerationEnv
+    with a dummy tokenizer and a trivial policy. This validates that
+    the environment returns dict-shaped step results and the runner
+    normalization path works as expected.
+
+    Kept intentionally lightweight for CI (no external model downloads).
+    """
+    try:
+        import mlx.core as mx  # type: ignore
+    except Exception:
+        pytest.skip("MLX not available")
+
+    from textpolicy.environment.text_generation import TextGenerationEnv
+    from textpolicy.rollout.runner import RolloutRunner
+    from textpolicy.rollout.strategy import create_strategy
+
+    class DummyTokenizer:
+        def encode(self, text):
+            return [ord(c) % 256 for c in text]
+
+        def decode(self, ids):
+            return "".join(chr(int(i) % 256) for i in ids)
+
+    def reward_fn(prompt, completion, example, **kwargs) -> float:
+        # Simple length reward in words
+        return float(len(completion.split()))
+
+    # Create simple environment
+    env = TextGenerationEnv(["Hello"], reward_fn, tokenizer=DummyTokenizer())
+
+    # Policy returns tokens that decode to 'a b c'
+    def simple_policy(obs_mx, deterministic=False):
+        return mx.array([97, 32, 98, 32, 99], dtype=mx.int32), {}
+
+    strategy = create_strategy('grpo')
+    runner = RolloutRunner(env, policy=simple_policy, strategy=strategy, max_steps=2)
+
+    buffer = runner.collect()
+    assert len(buffer.episodes) >= 1
+    ep = buffer.episodes[0]
+    # Episode stores rewards in `rew`
+    assert len(ep.rew) >= 1
+    assert all(r > 0 for r in ep.rew)
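For reference, the fixtures in this test are consistent by construction: the token IDs returned by `simple_policy` decode to `"a b c"` under `DummyTokenizer`'s `chr(id % 256)` rule, so the word-count reward is 3.0 for every step. A quick standalone check of that arithmetic, independent of the library:

```python
tokens = [97, 32, 98, 32, 99]
completion = "".join(chr(int(i) % 256) for i in tokens)

assert completion == "a b c"
assert float(len(completion.split())) == 3.0  # matches reward_fn's word count
```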