textpolicy 0.0.1.tar.gz → 0.1.0.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (83)
  1. textpolicy-0.1.0/PKG-INFO +99 -0
  2. textpolicy-0.1.0/README.md +75 -0
  3. textpolicy-0.1.0/pyproject.toml +30 -0
  4. textpolicy-0.1.0/tests/test_gspo_verification.py +215 -0
  5. textpolicy-0.1.0/tests/test_integration_e2e_training.py +49 -0
  6. textpolicy-0.1.0/tests/test_reward_signatures.py +131 -0
  7. textpolicy-0.1.0/tests/test_rollout_rewards.py +228 -0
  8. textpolicy-0.1.0/tests/test_runner_step_enforcement.py +80 -0
  9. textpolicy-0.1.0/tests/test_validate_installation.py +12 -0
  10. textpolicy-0.1.0/textpolicy/__init__.py +52 -0
  11. textpolicy-0.1.0/textpolicy/__main__.py +8 -0
  12. textpolicy-0.1.0/textpolicy/algorithms/__init__.py +54 -0
  13. textpolicy-0.1.0/textpolicy/algorithms/grpo.py +642 -0
  14. textpolicy-0.1.0/textpolicy/algorithms/gspo.py +582 -0
  15. textpolicy-0.1.0/textpolicy/buffer/__init__.py +23 -0
  16. textpolicy-0.1.0/textpolicy/buffer/buffer.py +244 -0
  17. textpolicy-0.1.0/textpolicy/buffer/episode.py +383 -0
  18. textpolicy-0.1.0/textpolicy/buffer/sampling.py +438 -0
  19. textpolicy-0.1.0/textpolicy/buffer/storage.py +255 -0
  20. textpolicy-0.1.0/textpolicy/cli.py +67 -0
  21. textpolicy-0.1.0/textpolicy/environment/__init__.py +79 -0
  22. textpolicy-0.1.0/textpolicy/environment/base.py +110 -0
  23. textpolicy-0.1.0/textpolicy/environment/environment.py +46 -0
  24. textpolicy-0.1.0/textpolicy/environment/factory.py +103 -0
  25. textpolicy-0.1.0/textpolicy/environment/gym.py +106 -0
  26. textpolicy-0.1.0/textpolicy/environment/task_suites.py +51 -0
  27. textpolicy-0.1.0/textpolicy/environment/text_generation.py +789 -0
  28. textpolicy-0.1.0/textpolicy/environment/vectorized.py +253 -0
  29. textpolicy-0.1.0/textpolicy/generation/__init__.py +62 -0
  30. textpolicy-0.1.0/textpolicy/generation/lora.py +411 -0
  31. textpolicy-0.1.0/textpolicy/generation/mlx_generation.py +557 -0
  32. textpolicy-0.1.0/textpolicy/generation/reload.py +253 -0
  33. textpolicy-0.1.0/textpolicy/rewards/__init__.py +137 -0
  34. textpolicy-0.1.0/textpolicy/rewards/adapters.py +387 -0
  35. textpolicy-0.1.0/textpolicy/rewards/basic.py +214 -0
  36. textpolicy-0.1.0/textpolicy/rewards/integrated_system.py +338 -0
  37. textpolicy-0.1.0/textpolicy/rewards/mlx_batch_processor.py +447 -0
  38. textpolicy-0.1.0/textpolicy/rewards/registry.py +293 -0
  39. textpolicy-0.1.0/textpolicy/rewards/rollout_rewards.py +410 -0
  40. textpolicy-0.1.0/textpolicy/rewards/verifiers.py +369 -0
  41. textpolicy-0.1.0/textpolicy/rollout/__init__.py +44 -0
  42. textpolicy-0.1.0/textpolicy/rollout/aggregator.py +145 -0
  43. textpolicy-0.1.0/textpolicy/rollout/base.py +108 -0
  44. textpolicy-0.1.0/textpolicy/rollout/rollout.py +142 -0
  45. textpolicy-0.1.0/textpolicy/rollout/runner.py +280 -0
  46. textpolicy-0.1.0/textpolicy/rollout/strategy.py +208 -0
  47. textpolicy-0.1.0/textpolicy/rollout/worker.py +194 -0
  48. textpolicy-0.1.0/textpolicy/training/__init__.py +14 -0
  49. textpolicy-0.1.0/textpolicy/training/metrics.py +242 -0
  50. textpolicy-0.1.0/textpolicy/training/rollout_manager.py +78 -0
  51. textpolicy-0.1.0/textpolicy/training/trainer.py +684 -0
  52. textpolicy-0.1.0/textpolicy/utils/__init__.py +40 -0
  53. textpolicy-0.1.0/textpolicy/utils/benchmarking.py +489 -0
  54. textpolicy-0.1.0/textpolicy/utils/data.py +60 -0
  55. textpolicy-0.1.0/textpolicy/utils/debug.py +170 -0
  56. textpolicy-0.1.0/textpolicy/utils/environment.py +349 -0
  57. textpolicy-0.1.0/textpolicy/utils/logging/__init__.py +22 -0
  58. textpolicy-0.1.0/textpolicy/utils/logging/base.py +48 -0
  59. textpolicy-0.1.0/textpolicy/utils/logging/console.py +61 -0
  60. textpolicy-0.1.0/textpolicy/utils/logging/factory.py +133 -0
  61. textpolicy-0.1.0/textpolicy/utils/logging/multi.py +83 -0
  62. textpolicy-0.1.0/textpolicy/utils/logging/tensorboard.py +65 -0
  63. textpolicy-0.1.0/textpolicy/utils/logging/wandb.py +72 -0
  64. textpolicy-0.1.0/textpolicy/utils/memory.py +118 -0
  65. textpolicy-0.1.0/textpolicy/utils/performance.py +464 -0
  66. textpolicy-0.1.0/textpolicy/utils/timing.py +171 -0
  67. textpolicy-0.1.0/textpolicy/validate.py +101 -0
  68. textpolicy-0.1.0/textpolicy/validation/__init__.py +13 -0
  69. textpolicy-0.1.0/textpolicy/validation/logprob_validation.py +315 -0
  70. textpolicy-0.1.0/textpolicy.egg-info/PKG-INFO +99 -0
  71. textpolicy-0.1.0/textpolicy.egg-info/SOURCES.txt +75 -0
  72. textpolicy-0.1.0/textpolicy.egg-info/entry_points.txt +2 -0
  73. textpolicy-0.1.0/textpolicy.egg-info/requires.txt +17 -0
  74. textpolicy-0.0.1/PKG-INFO +0 -10
  75. textpolicy-0.0.1/README.md +0 -1
  76. textpolicy-0.0.1/pyproject.toml +0 -7
  77. textpolicy-0.0.1/textpolicy/__init__.py +0 -0
  78. textpolicy-0.0.1/textpolicy.egg-info/PKG-INFO +0 -10
  79. textpolicy-0.0.1/textpolicy.egg-info/SOURCES.txt +0 -8
  80. {textpolicy-0.0.1 → textpolicy-0.1.0}/LICENSE +0 -0
  81. {textpolicy-0.0.1 → textpolicy-0.1.0}/setup.cfg +0 -0
  82. {textpolicy-0.0.1 → textpolicy-0.1.0}/textpolicy.egg-info/dependency_links.txt +0 -0
  83. {textpolicy-0.0.1 → textpolicy-0.1.0}/textpolicy.egg-info/top_level.txt +0 -0
textpolicy-0.1.0/PKG-INFO
@@ -0,0 +1,99 @@
+ Metadata-Version: 2.4
+ Name: textpolicy
+ Version: 0.1.0
+ Summary: MLX-optimized reward and verification system for text generation RL
+ Requires-Python: >=3.12
+ Description-Content-Type: text/markdown
+ License-File: LICENSE
+ Requires-Dist: numpy>=2.3.2
+ Requires-Dist: mlx>=0.21.0
+ Requires-Dist: mlx-lm>=0.21.0
+ Requires-Dist: gymnasium>=0.29.0
+ Requires-Dist: psutil>=7.0.0
+ Requires-Dist: wandb>=0.21.1
+ Requires-Dist: aiohttp>=3.12.15
+ Requires-Dist: pytest>=8.4.1
+ Provides-Extra: external
+ Requires-Dist: aiohttp>=3.8.0; extra == "external"
+ Requires-Dist: pydantic>=2.0.0; extra == "external"
+ Provides-Extra: dev
+ Requires-Dist: pytest>=7.0.0; extra == "dev"
+ Requires-Dist: black>=22.0.0; extra == "dev"
+ Requires-Dist: ruff>=0.1.0; extra == "dev"
+ Dynamic: license-file
+
+ # TextPolicy
+
+ Reinforcement learning toolkit for text generation on MLX (Apple Silicon).
+ TextPolicy provides algorithms (GRPO/GSPO), text-generation environments, a rollout runner,
+ reward functions with a decorator registry, and LoRA/QLoRA utilities.
+
+ ## Install (uv)
+
+ ```bash
+ uv add textpolicy
+ ```
+
+ Optional model integration:
+
+ ```bash
+ uv add mlx mlx-lm
+ ```
+
+ ## Quickstart
+
+ Working example using a real model and tokenizer (mlx-lm required):
+
+ ```python
+ import mlx.core as mx
+ import textpolicy as tp
+ from textpolicy import load_model, create_policy
+ from textpolicy.environment.text_generation import TextGenerationEnv
+ from textpolicy.rollout import RolloutRunner, create_strategy
+
+ # 1) Load model and tokenizer (mlx-lm)
+ model, tokenizer = load_model("Qwen/Qwen3-0.6B")
+
+ # 2) Create a policy (controls generation)
+ generation_params = {"max_tokens": 25, "temperature": 0.7}
+ policy_fn = create_policy(model, tokenizer, generation_params)
+
+ # 3) Define a reward function (env uses this to score responses)
+ @tp.reward
+ def length_reward(prompt: str, completion: str, example: dict, **kwargs) -> float:
+     return float(len(completion.split()))
+
+ # 4) Create an environment (requires a tokenizer)
+ env = TextGenerationEnv(["What is AI?"], length_reward, tokenizer=tokenizer)
+
+ # 5) Collect one rollout step
+ strategy = create_strategy('grpo')
+ runner = RolloutRunner(env, policy=policy_fn, strategy=strategy, max_steps=1)
+ buffer = runner.collect()
+ print(len(buffer.episodes))
+ ```
+
+ Docs:
+ - Quickstart: `docs/QUICKSTART_UV.md`
+ - LoRA/QLoRA: `docs/10_lora_qlora.md`
+ - Full index: `docs/index.md`
+
+ FAQ:
+ - Do I need a model?
+   - Yes for generation with `create_policy`.
+     Use `load_model()` (mlx-lm) to get `(model, tokenizer)`.
+     For reward-only code (no generation), a model is not required.
+ - Do I need a tokenizer?
+   - Yes.
+     Both `TextGenerationEnv` and `TextGenerationEnvironment` require a tokenizer.
+     `load_model()` returns one for mlx-lm models.
+ - How do I control generation?
+   - Pass `generation_params` to `create_policy` (for example, `max_tokens`, `temperature`, `top_p`, `repetition_penalty`).
+ - What does `step()` return?
+   - A dict with `observation`, `reward`, `terminated`, `truncated`, `info`. The runner enforces this.
+
+ Examples:
+ - 01–06: reward functions, batch processing, minimal training
+ - 08: GRPO training with rollout + buffer
+ - 09–10: length reduction (GRPO/GSPO)
+ - 11: LoRA/QLoRA configuration
textpolicy-0.1.0/README.md
@@ -0,0 +1,75 @@
+ # TextPolicy
+
+ Reinforcement learning toolkit for text generation on MLX (Apple Silicon).
+ TextPolicy provides algorithms (GRPO/GSPO), text-generation environments, a rollout runner,
+ reward functions with a decorator registry, and LoRA/QLoRA utilities.
+
+ ## Install (uv)
+
+ ```bash
+ uv add textpolicy
+ ```
+
+ Optional model integration:
+
+ ```bash
+ uv add mlx mlx-lm
+ ```
+
+ ## Quickstart
+
+ Working example using a real model and tokenizer (mlx-lm required):
+
+ ```python
+ import mlx.core as mx
+ import textpolicy as tp
+ from textpolicy import load_model, create_policy
+ from textpolicy.environment.text_generation import TextGenerationEnv
+ from textpolicy.rollout import RolloutRunner, create_strategy
+
+ # 1) Load model and tokenizer (mlx-lm)
+ model, tokenizer = load_model("Qwen/Qwen3-0.6B")
+
+ # 2) Create a policy (controls generation)
+ generation_params = {"max_tokens": 25, "temperature": 0.7}
+ policy_fn = create_policy(model, tokenizer, generation_params)
+
+ # 3) Define a reward function (env uses this to score responses)
+ @tp.reward
+ def length_reward(prompt: str, completion: str, example: dict, **kwargs) -> float:
+     return float(len(completion.split()))
+
+ # 4) Create an environment (requires a tokenizer)
+ env = TextGenerationEnv(["What is AI?"], length_reward, tokenizer=tokenizer)
+
+ # 5) Collect one rollout step
+ strategy = create_strategy('grpo')
+ runner = RolloutRunner(env, policy=policy_fn, strategy=strategy, max_steps=1)
+ buffer = runner.collect()
+ print(len(buffer.episodes))
+ ```
+
+ Docs:
+ - Quickstart: `docs/QUICKSTART_UV.md`
+ - LoRA/QLoRA: `docs/10_lora_qlora.md`
+ - Full index: `docs/index.md`
+
+ FAQ:
+ - Do I need a model?
+   - Yes for generation with `create_policy`.
+     Use `load_model()` (mlx-lm) to get `(model, tokenizer)`.
+     For reward-only code (no generation), a model is not required.
+ - Do I need a tokenizer?
+   - Yes.
+     Both `TextGenerationEnv` and `TextGenerationEnvironment` require a tokenizer.
+     `load_model()` returns one for mlx-lm models.
+ - How do I control generation?
+   - Pass `generation_params` to `create_policy` (for example, `max_tokens`, `temperature`, `top_p`, `repetition_penalty`).
+ - What does `step()` return?
+   - A dict with `observation`, `reward`, `terminated`, `truncated`, `info`. The runner enforces this.
+
+ Examples:
+ - 01–06: reward functions, batch processing, minimal training
+ - 08: GRPO training with rollout + buffer
+ - 09–10: length reduction (GRPO/GSPO)
+ - 11: LoRA/QLoRA configuration
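
The `step()` item in the FAQ above can be made concrete with a small sketch. Only the keys come from the FAQ; the example values and comments are illustrative placeholders, not the exact types the environment emits:

```python
# Shape of a step() result as described in the FAQ above (keys only are authoritative).
step_result = {
    "observation": [101, 2023],   # e.g. token ids for the current state (placeholder)
    "reward": 0.75,               # score produced by the registered reward function
    "terminated": True,           # episode ended normally
    "truncated": False,           # episode cut short (e.g. token budget reached)
    "info": {},                   # auxiliary diagnostics
}

# The rollout runner enforces exactly this dict contract.
required_keys = {"observation", "reward", "terminated", "truncated", "info"}
assert required_keys.issubset(step_result)
```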
textpolicy-0.1.0/pyproject.toml
@@ -0,0 +1,30 @@
+ [build-system]
+ requires = ["setuptools>=61.0", "wheel"]
+ build-backend = "setuptools.build_meta"
+
+ [tool.setuptools.packages.find]
+ include = ["textpolicy*"]
+
+ [project]
+ name = "textpolicy"
+ version = "0.1.0"
+ description = "MLX-optimized reward and verification system for text generation RL"
+ readme = "README.md"
+ requires-python = ">=3.12"
+ dependencies = [
+     "numpy>=2.3.2",
+     "mlx>=0.21.0",      # Core MLX framework for Apple Silicon acceleration
+     "mlx-lm>=0.21.0",   # MLX language models for inference
+     "gymnasium>=0.29.0",
+     "psutil>=7.0.0",
+     "wandb>=0.21.1",
+     "aiohttp>=3.12.15",
+     "pytest>=8.4.1",
+ ]
+
+ [project.scripts]
+ textpolicy = "textpolicy.cli:main"
+
+ [project.optional-dependencies]
+ external = ["aiohttp>=3.8.0", "pydantic>=2.0.0"]
+ dev = ["pytest>=7.0.0", "black>=22.0.0", "ruff>=0.1.0"]
textpolicy-0.1.0/tests/test_gspo_verification.py
@@ -0,0 +1,215 @@
+ """
+ GSPO Verification Tests - Comprehensive Testing of GSPO Implementation
+
+ This test module verifies that GSPO is working correctly by testing:
+ 1. Basic functionality of GSPO components
+ 2. Comparison with GRPO behavior
+ 3. Sequence-level vs token-level importance sampling
+ 4. Mathematical correctness of importance weights
+ 5. Training dynamics and convergence
+ """
+
+ import pytest
+ import mlx.core as mx
+ import mlx.optimizers as optim
+ import numpy as np
+ from textpolicy.algorithms import grpo, gspo
+ from textpolicy.generation.mlx_generation import load_model, create_policy
+ from textpolicy.rollout import RolloutCoordinator
+ from textpolicy.buffer import Buffer
+ from textpolicy.training import Trainer
+
+
+ @pytest.mark.unit
+ @pytest.mark.algorithm
+ class TestGSPOBasicFunctionality:
+     """Test basic GSPO functions work correctly."""
+
+     def test_sequence_importance_weights(self):
+         """Test sequence-level importance weights computation."""
+         # Create test data
+         old_logprobs = mx.array([-1.0, -1.2, -0.8, -1.1, -0.9])  # 5 tokens
+         new_logprobs = mx.array([-1.1, -1.0, -0.9, -1.0, -1.0])  # 5 tokens
+         sequence_lengths = [2, 3]  # Two sequences: 2 tokens + 3 tokens
+
+         # Test sequence-level importance weights
+         seq_weights = gspo.compute_sequence_importance_weights(
+             old_logprobs, new_logprobs, sequence_lengths, clip_ratio=0.2
+         )
+
+         assert len(seq_weights) == len(sequence_lengths), \
+             f"Expected {len(sequence_lengths)} weights, got {len(seq_weights)}"
+         assert all(not mx.isnan(w) and not mx.isinf(w) for w in seq_weights), \
+             "All weights should be finite"
+
+     def test_gspo_policy_loss(self):
+         """Test GSPO policy loss computation."""
+         old_logprobs = mx.array([-1.0, -1.2, -0.8, -1.1, -0.9])
+         new_logprobs = mx.array([-1.1, -1.0, -0.9, -1.0, -1.0])
+         sequence_lengths = [2, 3]
+         advantages = mx.array([0.5, -0.3])  # Advantages for each sequence
+
+         # Test GSPO policy loss
+         loss = gspo.gspo_policy_loss(
+             old_logprobs, new_logprobs, advantages, sequence_lengths, variant="sequence"
+         )
+
+         assert not mx.isnan(loss) and not mx.isinf(loss), "Loss should be finite"
+         assert isinstance(loss, mx.array), "Loss should be an MLX array"
+
+     def test_hybrid_importance_weights(self):
+         """Test hybrid importance weights computation."""
+         old_logprobs = mx.array([-1.0, -1.2, -0.8, -1.1, -0.9])
+         new_logprobs = mx.array([-1.1, -1.0, -0.9, -1.0, -1.0])
+         sequence_lengths = [2, 3]
+
+         # Test hybrid variant
+         hybrid_weights = gspo.compute_hybrid_importance_weights(
+             old_logprobs, new_logprobs, sequence_lengths
+         )
+
+         assert len(hybrid_weights) == len(old_logprobs), \
+             f"Expected {len(old_logprobs)} hybrid weights, got {len(hybrid_weights)}"
+         assert all(not mx.isnan(w) and not mx.isinf(w) for w in hybrid_weights), \
+             "All hybrid weights should be finite"
+
+
+ @pytest.mark.unit
+ @pytest.mark.algorithm
+ class TestGSPOvsGRPO:
+     """Test that GSPO produces different importance weights than GRPO."""
+
+     def test_importance_weight_differences(self):
+         """Test that GSPO produces different importance weights than GRPO."""
+         # Create test data with clear differences between old and new policies
+         old_logprobs = mx.array([-2.0, -2.0, -1.0, -1.0])  # 4 tokens
+         new_logprobs = mx.array([-1.0, -1.0, -2.0, -2.0])  # Policy changed significantly
+         sequence_lengths = [2, 2]  # Two sequences of equal length
+
+         # Compute GSPO sequence-level importance weights
+         gspo_weights = gspo.compute_sequence_importance_weights(
+             old_logprobs, new_logprobs, sequence_lengths, clip_ratio=0.2
+         )
+
+         # Compute GRPO token-level importance ratios for comparison
+         grpo_ratios = mx.exp(new_logprobs - old_logprobs)
+         grpo_ratios_clipped = mx.clip(grpo_ratios, 0.8, 1.2)
+
+         # GSPO should produce sequence-level weights (2 values)
+         # GRPO produces token-level ratios (4 values)
+         assert len(gspo_weights) == len(sequence_lengths), \
+             f"GSPO should produce {len(sequence_lengths)} sequence weights"
+         assert len(grpo_ratios) == len(old_logprobs), \
+             f"GRPO should produce {len(old_logprobs)} token ratios"
+
+         # The approaches should be fundamentally different
+         # GSPO normalizes by sequence length, GRPO doesn't
+         assert len(gspo_weights) != len(grpo_ratios), \
+             "GSPO and GRPO should produce different numbers of weights"
+
+
+ @pytest.mark.unit
+ @pytest.mark.algorithm
+ class TestGSPOClipping:
+     """Test GSPO clipping behavior."""
+
+     def test_clipping_bounds_respected(self):
+         """Test that importance weights respect clipping bounds."""
+         # Test extreme case to verify clipping
+         old_logprobs = mx.array([-10.0, -1.0])  # Extreme difference
+         new_logprobs = mx.array([-1.0, -1.0])
+         sequence_lengths = [2]
+         clip_ratio = 0.2
+
+         # Compute sequence weights
+         weights = gspo.compute_sequence_importance_weights(
+             old_logprobs, new_logprobs, sequence_lengths, clip_ratio=clip_ratio
+         )
+
+         # Weights should be clipped between (1-clip_ratio) and (1+clip_ratio)
+         lower_bound = 1.0 - clip_ratio
+         upper_bound = 1.0 + clip_ratio
+
+         # Use tolerance for floating-point comparisons due to MLX float32 precision
+         # MLX uses float32 by default, which has precision ~1.19e-7
+         tolerance = 1e-6  # Conservative tolerance for float32 precision issues
+
+         for weight in weights:
+             weight_val = float(weight)  # Convert MLX scalar to Python float
+             assert lower_bound - tolerance <= weight_val <= upper_bound + tolerance, \
+                 f"Weight {weight_val} outside clipping bounds [{lower_bound}, {upper_bound}] with tolerance {tolerance}"
+
+     def test_length_normalization_effect(self):
+         """Test that GSPO properly normalizes by sequence length."""
+         # Identical sequences of different lengths should have similar weights
+         old_logprobs_short = mx.array([-1.0, -1.0])  # 2 tokens
+         new_logprobs_short = mx.array([-0.5, -0.5])  # Better by 0.5 per token
+
+         old_logprobs_long = mx.array([-1.0, -1.0, -1.0, -1.0])  # 4 tokens
+         new_logprobs_long = mx.array([-0.5, -0.5, -0.5, -0.5])  # Better by 0.5 per token
+
+         weight_short = gspo.compute_sequence_importance_weights(
+             old_logprobs_short, new_logprobs_short, [2], clip_ratio=1.0  # No clipping
+         )
+         weight_long = gspo.compute_sequence_importance_weights(
+             old_logprobs_long, new_logprobs_long, [4], clip_ratio=1.0  # No clipping
+         )
+
+         # Both should be similar due to length normalization
+         # Short: exp((sum(-0.5) - sum(-1.0)) / 2) = exp((-1.0 - (-2.0)) / 2) = exp(0.5)
+         # Long:  exp((sum(-0.5) - sum(-1.0)) / 4) = exp((-2.0 - (-4.0)) / 4) = exp(0.5)
+         short_val = float(weight_short[0])
+         long_val = float(weight_long[0])
+
+         # They should be approximately equal due to length normalization
+         assert abs(short_val - long_val) < 0.01, \
+             f"Length normalization failed: short={short_val}, long={long_val}"
+
+
+ @pytest.mark.integration
+ @pytest.mark.algorithm
+ @pytest.mark.slow
+ class TestGSPOTraining:
+     """Integration tests for GSPO training."""
+
+     def test_gspo_training_step(self):
+         """Test a complete GSPO training step."""
+         # This is a minimal integration test
+         # Create minimal test data
+         old_logprobs = mx.array([-1.0, -1.0, -1.0, -1.0])
+         new_logprobs = mx.array([-0.8, -0.8, -1.2, -1.2])
+         advantages = mx.array([0.5, -0.3])
+         sequence_lengths = [2, 2]
+
+         # Test that we can compute a complete loss
+         loss = gspo.gspo_policy_loss(
+             old_logprobs, new_logprobs, advantages, sequence_lengths, variant="sequence"
+         )
+
+         assert not mx.isnan(loss) and not mx.isinf(loss), "Training loss should be finite"
+         assert float(loss) != 0.0, "Loss should be non-zero for non-trivial inputs"
+
+     def test_gspo_metrics_computation(self):
+         """Test GSPO metrics computation."""
+         old_logprobs = mx.array([-1.0, -1.0, -1.0, -1.0])
+         new_logprobs = mx.array([-0.8, -0.8, -1.2, -1.2])
+         advantages = mx.array([0.5, -0.3])
+
+         # Test metrics computation
+         metrics_fn = gspo.create_gspo_metrics(variant="sequence")
+         metrics = metrics_fn(old_logprobs, new_logprobs, advantages)
+
+         assert isinstance(metrics, dict), "Metrics should be a dictionary"
+         assert len(metrics) > 0, "Metrics should not be empty"
+
+         # Check for expected metric keys
+         expected_keys = ['mean_advantage', 'std_advantage']
+         for key in expected_keys:
+             assert key in metrics, f"Missing expected metric: {key}"
+             assert isinstance(metrics[key], (int, float)), \
+                 f"Metric {key} should be numeric, got {type(metrics[key])}"
+
+
+ if __name__ == "__main__":
+     # Allow running this file directly for debugging
+     pytest.main([__file__, "-v"])
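
The length-normalization property checked in `test_length_normalization_effect` reduces to simple arithmetic: the sequence-level weight is the exponential of the length-normalized log-probability difference. A plain-Python sketch of that calculation (independent of the library's `compute_sequence_importance_weights`, and only mirroring the worked numbers in the test's comments):

```python
import math

# Sequence-level importance weight: exp((sum(new) - sum(old)) / length).
# Equal per-token improvement gives the same weight regardless of sequence length.
def sequence_weight(old_logprobs, new_logprobs):
    diff = sum(new_logprobs) - sum(old_logprobs)
    return math.exp(diff / len(old_logprobs))

short = sequence_weight([-1.0, -1.0], [-0.5, -0.5])     # exp(0.5) ~ 1.6487
long = sequence_weight([-1.0] * 4, [-0.5] * 4)          # exp(0.5) ~ 1.6487
assert abs(short - long) < 1e-9
```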
textpolicy-0.1.0/tests/test_integration_e2e_training.py
@@ -0,0 +1,49 @@
+ import pytest
+
+
+ @pytest.mark.integration
+ def test_e2e_minimal_rollout_grpo():
+     """
+     Minimal end-to-end rollout + buffer collection using TextGenerationEnv
+     with a dummy tokenizer and a trivial policy. This validates that
+     the environment returns dict-shaped step results and the runner
+     normalization path works as expected.
+
+     Kept intentionally lightweight for CI (no external model downloads).
+     """
+     try:
+         import mlx.core as mx  # type: ignore
+     except Exception:
+         pytest.skip("MLX not available")
+
+     from textpolicy.environment.text_generation import TextGenerationEnv
+     from textpolicy.rollout.runner import RolloutRunner
+     from textpolicy.rollout.strategy import create_strategy
+
+     class DummyTokenizer:
+         def encode(self, text):
+             return [ord(c) % 256 for c in text]
+
+         def decode(self, ids):
+             return "".join(chr(int(i) % 256) for i in ids)
+
+     def reward_fn(prompt, completion, example, **kwargs) -> float:
+         # Simple length reward in words
+         return float(len(completion.split()))
+
+     # Create simple environment
+     env = TextGenerationEnv(["Hello"], reward_fn, tokenizer=DummyTokenizer())
+
+     # Policy returns tokens that decode to 'a b c'
+     def simple_policy(obs_mx, deterministic=False):
+         return mx.array([97, 32, 98, 32, 99], dtype=mx.int32), {}
+
+     strategy = create_strategy('grpo')
+     runner = RolloutRunner(env, policy=simple_policy, strategy=strategy, max_steps=2)
+
+     buffer = runner.collect()
+     assert len(buffer.episodes) >= 1
+     ep = buffer.episodes[0]
+     # Episode stores rewards in `rew`
+     assert len(ep.rew) >= 1
+     assert all(r > 0 for r in ep.rew)
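
For reference, the reward this trivial policy earns can be verified by hand, since the dummy tokenizer is just an `ord`/`chr` round trip (a standalone check, not part of the test file above):

```python
# The fixed action [97, 32, 98, 32, 99] decodes via chr() to "a b c",
# so the word-count reward is 3.0 on every step of the test above.
ids = [97, 32, 98, 32, 99]
text = "".join(chr(i % 256) for i in ids)
assert text == "a b c"
assert float(len(text.split())) == 3.0
```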
textpolicy-0.1.0/tests/test_reward_signatures.py
@@ -0,0 +1,131 @@
+ """
+ Reward Function Signature Tests
+
+ Test reward function signatures and compatibility to ensure proper integration.
+ """
+
+ import pytest
+ from textpolicy.rewards import length_reward, keyword_reward, perplexity_reward, accuracy_reward
+
+
+ @pytest.mark.unit
+ @pytest.mark.reward
+ class TestRewardFunctionSignatures:
+     """Test reward function signatures for compatibility."""
+
+     def test_reward_functions_import(self):
+         """Test that all reward functions can be imported successfully."""
+         # Test that imports work
+         assert callable(length_reward), "length_reward should be callable"
+         assert callable(keyword_reward), "keyword_reward should be callable"
+         assert callable(perplexity_reward), "perplexity_reward should be callable"
+         assert callable(accuracy_reward), "accuracy_reward should be callable"
+
+     def test_length_reward_signature(self):
+         """Test length_reward function signature."""
+         test_prompt = "What is AI?"
+         test_completion = "AI is artificial intelligence technology that enables machines to simulate human thinking."
+         test_example = {"target_length": 15}
+
+         # Test basic call
+         try:
+             reward = length_reward(test_prompt, test_completion, test_example)
+             assert isinstance(reward, (int, float)), "length_reward should return numeric value"
+         except Exception as e:
+             pytest.fail(f"length_reward failed with signature (prompt, completion, example): {e}")
+
+     def test_keyword_reward_signature(self):
+         """Test keyword_reward function signature."""
+         test_prompt = "What is AI?"
+         test_completion = "AI is artificial intelligence technology that enables machines to simulate human thinking."
+         test_example = {"keywords": ["AI", "intelligence"]}
+
+         try:
+             reward = keyword_reward(test_prompt, test_completion, test_example)
+             assert isinstance(reward, (int, float)), "keyword_reward should return numeric value"
+         except Exception as e:
+             pytest.fail(f"keyword_reward failed with signature (prompt, completion, example): {e}")
+
+     def test_perplexity_reward_signature(self):
+         """Test perplexity_reward function signature."""
+         test_prompt = "What is AI?"
+         test_completion = "AI is artificial intelligence technology."
+         test_example = {"max_perplexity": 10.0}
+
+         try:
+             reward = perplexity_reward(test_prompt, test_completion, test_example)
+             assert isinstance(reward, (int, float)), "perplexity_reward should return numeric value"
+         except Exception as e:
+             pytest.fail(f"perplexity_reward failed with signature (prompt, completion, example): {e}")
+
+     def test_accuracy_reward_signature(self):
+         """Test accuracy_reward function signature."""
+         test_prompt = "What is 2+2?"
+         test_completion = "4"
+         test_example = {"correct_answer": "4"}
+
+         try:
+             reward = accuracy_reward(test_prompt, test_completion, test_example)
+             assert isinstance(reward, (int, float)), "accuracy_reward should return numeric value"
+         except Exception as e:
+             pytest.fail(f"accuracy_reward failed with signature (prompt, completion, example): {e}")
+
+     @pytest.mark.parametrize("reward_func,example_data", [
+         (length_reward, {"target_length": 15}),
+         (keyword_reward, {"keywords": ["test", "example"]}),
+         (perplexity_reward, {"max_perplexity": 10.0}),
+         (accuracy_reward, {"correct_answer": "test answer"}),
+     ])
+     def test_reward_function_consistency(self, reward_func, example_data):
+         """Test that all reward functions follow consistent signature patterns."""
+         test_prompt = "Test prompt"
+         test_completion = "Test completion response"
+
+         # All reward functions should accept (prompt, completion, example) signature
+         try:
+             result = reward_func(test_prompt, test_completion, example_data)
+             assert isinstance(result, (int, float)), \
+                 f"{reward_func.__name__} should return numeric value"
+             assert -1.0 <= result <= 1.0, \
+                 f"{reward_func.__name__} should return value in [-1, 1] range, got {result}"
+         except Exception as e:
+             pytest.fail(f"{reward_func.__name__} failed with standard signature: {e}")
+
+
+ @pytest.mark.integration
+ @pytest.mark.reward
+ class TestRewardIntegration:
+     """Test reward function integration with the system."""
+
+     def test_reward_functions_with_realistic_data(self):
+         """Test reward functions with realistic data."""
+         prompt = "Explain what machine learning is in simple terms."
+         completion = "Machine learning is a type of artificial intelligence that allows computers to learn and improve from data without being explicitly programmed for every task."
+
+         # Test length reward
+         length_example = {"target_length": 20}
+         length_result = length_reward(prompt, completion, length_example)
+         assert isinstance(length_result, (int, float))
+
+         # Test keyword reward
+         keyword_example = {"keywords": ["machine", "learning", "artificial", "intelligence"]}
+         keyword_result = keyword_reward(prompt, completion, keyword_example)
+         assert isinstance(keyword_result, (int, float))
+
+         # Test perplexity reward (if model available)
+         perplexity_example = {"max_perplexity": 15.0}
+         try:
+             perplexity_result = perplexity_reward(prompt, completion, perplexity_example)
+             assert isinstance(perplexity_result, (int, float))
+         except Exception:
+             # Perplexity might fail if model not available, which is acceptable
+             pytest.skip("Perplexity reward requires model - skipping")
+
+         # Test accuracy reward
+         accuracy_example = {"correct_answer": "machine learning"}
+         accuracy_result = accuracy_reward(prompt, completion, accuracy_example)
+         assert isinstance(accuracy_result, (int, float))
+
+
+ if __name__ == "__main__":
+     pytest.main([__file__, "-v"])
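
A custom reward that satisfies the contract these tests enforce, i.e. the `(prompt, completion, example)` signature with a numeric result kept in `[-1, 1]`, could look like the following sketch. The function name and the `correct_answer` key are illustrative, not part of the package:

```python
def contains_answer_reward(prompt: str, completion: str, example: dict, **kwargs) -> float:
    """Return 1.0 if the expected answer appears in the completion, else -1.0."""
    answer = str(example.get("correct_answer", "")).lower()
    if not answer:
        return 0.0  # no reference answer provided
    return 1.0 if answer in completion.lower() else -1.0
```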