textpolicy 0.0.1.tar.gz → 0.1.1.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (84)
  1. textpolicy-0.1.1/PKG-INFO +109 -0
  2. textpolicy-0.1.1/README.md +75 -0
  3. textpolicy-0.1.1/pyproject.toml +56 -0
  4. textpolicy-0.1.1/tests/test_gspo_verification.py +215 -0
  5. textpolicy-0.1.1/tests/test_integration_e2e_training.py +49 -0
  6. textpolicy-0.1.1/tests/test_issue_fixes.py +218 -0
  7. textpolicy-0.1.1/tests/test_reward_signatures.py +131 -0
  8. textpolicy-0.1.1/tests/test_rollout_rewards.py +228 -0
  9. textpolicy-0.1.1/tests/test_runner_step_enforcement.py +80 -0
  10. textpolicy-0.1.1/tests/test_validate_installation.py +12 -0
  11. textpolicy-0.1.1/textpolicy/__init__.py +53 -0
  12. textpolicy-0.1.1/textpolicy/__main__.py +8 -0
  13. textpolicy-0.1.1/textpolicy/algorithms/__init__.py +54 -0
  14. textpolicy-0.1.1/textpolicy/algorithms/grpo.py +642 -0
  15. textpolicy-0.1.1/textpolicy/algorithms/gspo.py +582 -0
  16. textpolicy-0.1.1/textpolicy/buffer/__init__.py +23 -0
  17. textpolicy-0.1.1/textpolicy/buffer/buffer.py +244 -0
  18. textpolicy-0.1.1/textpolicy/buffer/episode.py +383 -0
  19. textpolicy-0.1.1/textpolicy/buffer/sampling.py +438 -0
  20. textpolicy-0.1.1/textpolicy/buffer/storage.py +255 -0
  21. textpolicy-0.1.1/textpolicy/cli.py +67 -0
  22. textpolicy-0.1.1/textpolicy/environment/__init__.py +79 -0
  23. textpolicy-0.1.1/textpolicy/environment/base.py +110 -0
  24. textpolicy-0.1.1/textpolicy/environment/environment.py +46 -0
  25. textpolicy-0.1.1/textpolicy/environment/factory.py +103 -0
  26. textpolicy-0.1.1/textpolicy/environment/gym.py +106 -0
  27. textpolicy-0.1.1/textpolicy/environment/task_suites.py +51 -0
  28. textpolicy-0.1.1/textpolicy/environment/text_generation.py +797 -0
  29. textpolicy-0.1.1/textpolicy/environment/vectorized.py +253 -0
  30. textpolicy-0.1.1/textpolicy/generation/__init__.py +62 -0
  31. textpolicy-0.1.1/textpolicy/generation/lora.py +411 -0
  32. textpolicy-0.1.1/textpolicy/generation/mlx_generation.py +557 -0
  33. textpolicy-0.1.1/textpolicy/generation/reload.py +253 -0
  34. textpolicy-0.1.1/textpolicy/rewards/__init__.py +137 -0
  35. textpolicy-0.1.1/textpolicy/rewards/adapters.py +387 -0
  36. textpolicy-0.1.1/textpolicy/rewards/basic.py +214 -0
  37. textpolicy-0.1.1/textpolicy/rewards/integrated_system.py +338 -0
  38. textpolicy-0.1.1/textpolicy/rewards/mlx_batch_processor.py +447 -0
  39. textpolicy-0.1.1/textpolicy/rewards/registry.py +293 -0
  40. textpolicy-0.1.1/textpolicy/rewards/rollout_rewards.py +410 -0
  41. textpolicy-0.1.1/textpolicy/rewards/verifiers.py +369 -0
  42. textpolicy-0.1.1/textpolicy/rollout/__init__.py +44 -0
  43. textpolicy-0.1.1/textpolicy/rollout/aggregator.py +145 -0
  44. textpolicy-0.1.1/textpolicy/rollout/base.py +108 -0
  45. textpolicy-0.1.1/textpolicy/rollout/rollout.py +142 -0
  46. textpolicy-0.1.1/textpolicy/rollout/runner.py +280 -0
  47. textpolicy-0.1.1/textpolicy/rollout/strategy.py +208 -0
  48. textpolicy-0.1.1/textpolicy/rollout/worker.py +194 -0
  49. textpolicy-0.1.1/textpolicy/training/__init__.py +14 -0
  50. textpolicy-0.1.1/textpolicy/training/metrics.py +242 -0
  51. textpolicy-0.1.1/textpolicy/training/rollout_manager.py +78 -0
  52. textpolicy-0.1.1/textpolicy/training/trainer.py +684 -0
  53. textpolicy-0.1.1/textpolicy/utils/__init__.py +40 -0
  54. textpolicy-0.1.1/textpolicy/utils/benchmarking.py +489 -0
  55. textpolicy-0.1.1/textpolicy/utils/data.py +60 -0
  56. textpolicy-0.1.1/textpolicy/utils/debug.py +170 -0
  57. textpolicy-0.1.1/textpolicy/utils/environment.py +349 -0
  58. textpolicy-0.1.1/textpolicy/utils/logging/__init__.py +22 -0
  59. textpolicy-0.1.1/textpolicy/utils/logging/base.py +48 -0
  60. textpolicy-0.1.1/textpolicy/utils/logging/console.py +61 -0
  61. textpolicy-0.1.1/textpolicy/utils/logging/factory.py +133 -0
  62. textpolicy-0.1.1/textpolicy/utils/logging/multi.py +83 -0
  63. textpolicy-0.1.1/textpolicy/utils/logging/tensorboard.py +65 -0
  64. textpolicy-0.1.1/textpolicy/utils/logging/wandb.py +72 -0
  65. textpolicy-0.1.1/textpolicy/utils/memory.py +118 -0
  66. textpolicy-0.1.1/textpolicy/utils/performance.py +464 -0
  67. textpolicy-0.1.1/textpolicy/utils/timing.py +171 -0
  68. textpolicy-0.1.1/textpolicy/validate.py +101 -0
  69. textpolicy-0.1.1/textpolicy/validation/__init__.py +13 -0
  70. textpolicy-0.1.1/textpolicy/validation/logprob_validation.py +315 -0
  71. textpolicy-0.1.1/textpolicy.egg-info/PKG-INFO +109 -0
  72. textpolicy-0.1.1/textpolicy.egg-info/SOURCES.txt +76 -0
  73. textpolicy-0.1.1/textpolicy.egg-info/entry_points.txt +2 -0
  74. textpolicy-0.1.1/textpolicy.egg-info/requires.txt +17 -0
  75. textpolicy-0.0.1/PKG-INFO +0 -10
  76. textpolicy-0.0.1/README.md +0 -1
  77. textpolicy-0.0.1/pyproject.toml +0 -7
  78. textpolicy-0.0.1/textpolicy/__init__.py +0 -0
  79. textpolicy-0.0.1/textpolicy.egg-info/PKG-INFO +0 -10
  80. textpolicy-0.0.1/textpolicy.egg-info/SOURCES.txt +0 -8
  81. {textpolicy-0.0.1 → textpolicy-0.1.1}/LICENSE +0 -0
  82. {textpolicy-0.0.1 → textpolicy-0.1.1}/setup.cfg +0 -0
  83. {textpolicy-0.0.1 → textpolicy-0.1.1}/textpolicy.egg-info/dependency_links.txt +0 -0
  84. {textpolicy-0.0.1 → textpolicy-0.1.1}/textpolicy.egg-info/top_level.txt +0 -0
textpolicy-0.1.1/PKG-INFO
@@ -0,0 +1,109 @@
+ Metadata-Version: 2.4
+ Name: textpolicy
+ Version: 0.1.1
+ Summary: Reinforcement learning for text generation on MLX (Apple Silicon): GRPO/GSPO, environments, rollout, rewards, LoRA/QLoRA
+ Project-URL: Homepage, https://github.com/teilomillet/textpolicy
+ Project-URL: Repository, https://github.com/teilomillet/textpolicy
+ Project-URL: Documentation, https://github.com/teilomillet/textpolicy#readme
+ Project-URL: Changelog, https://github.com/teilomillet/textpolicy/blob/main/CHANGELOG.md
+ Keywords: reinforcement-learning,text-generation,mlx,apple-silicon,lora,qlora,grpo,gspo,rlhf
+ Classifier: Programming Language :: Python :: 3
+ Classifier: Operating System :: MacOS
+ Classifier: Intended Audience :: Developers
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
+ Classifier: License :: OSI Approved :: MIT License
+ Requires-Python: >=3.12
+ Description-Content-Type: text/markdown
+ License-File: LICENSE
+ Requires-Dist: numpy>=2.3.2
+ Requires-Dist: mlx>=0.21.0
+ Requires-Dist: mlx-lm>=0.21.0
+ Requires-Dist: gymnasium>=0.29.0
+ Requires-Dist: psutil>=7.0.0
+ Requires-Dist: wandb>=0.21.1
+ Requires-Dist: aiohttp>=3.12.15
+ Requires-Dist: pytest>=8.4.1
+ Provides-Extra: external
+ Requires-Dist: aiohttp>=3.8.0; extra == "external"
+ Requires-Dist: pydantic>=2.0.0; extra == "external"
+ Provides-Extra: dev
+ Requires-Dist: pytest>=7.0.0; extra == "dev"
+ Requires-Dist: black>=22.0.0; extra == "dev"
+ Requires-Dist: ruff>=0.1.0; extra == "dev"
+ Dynamic: license-file
+
+ # TextPolicy
+
+ Reinforcement learning toolkit for text generation on MLX (Apple Silicon).
+ TextPolicy provides algorithms (GRPO/GSPO), text-generation environments, a rollout runner,
+ reward functions with a decorator registry, and LoRA/QLoRA utilities.
+
+ ## Install (uv)
+
+ ```bash
+ uv add textpolicy
+ ```
+
+ Optional model integration:
+
+ ```bash
+ uv add mlx mlx-lm
+ ```
+
+ ## Quickstart
+
+ Working example using a real model and tokenizer (mlx-lm required):
+
+ ```python
+ import mlx.core as mx
+ import textpolicy as tp
+ from textpolicy import load_model, create_policy
+ from textpolicy.environment.text_generation import TextGenerationEnv
+ from textpolicy.rollout import RolloutRunner, create_strategy
+
+ # 1) Load model and tokenizer (mlx-lm)
+ model, tokenizer = load_model("Qwen/Qwen3-0.6B")
+
+ # 2) Create a policy (controls generation)
+ generation_params = {"max_tokens": 25, "temperature": 0.7}
+ policy_fn = create_policy(model, tokenizer, generation_params)
+
+ # 3) Define a reward function (env uses this to score responses)
+ @tp.reward
+ def length_reward(prompt: str, completion: str, example: dict, **kwargs) -> float:
+     return float(len(completion.split()))
+
+ # 4) Create an environment (requires a tokenizer)
+ env = TextGenerationEnv(["What is AI?"], length_reward, tokenizer=tokenizer)
+
+ # 5) Collect one rollout step
+ strategy = create_strategy('grpo')
+ runner = RolloutRunner(env, policy=policy_fn, strategy=strategy, max_steps=1)
+ buffer = runner.collect()
+ print(len(buffer.episodes))
+ ```
+
+ Docs:
+ - Quickstart: `docs/QUICKSTART_UV.md`
+ - LoRA/QLoRA: `docs/10_lora_qlora.md`
+ - Full index: `docs/index.md`
+
+ FAQ:
+ - Do I need a model?
+   - Yes, for generation with `create_policy`.
+     Use `load_model()` (mlx-lm) to get `(model, tokenizer)`.
+     For reward-only code (no generation), a model is not required.
+ - Do I need a tokenizer?
+   - Yes.
+     Both `TextGenerationEnv` and `TextGenerationEnvironment` require a tokenizer.
+     `load_model()` returns one for mlx-lm models.
+ - How do I control generation?
+   - Pass `generation_params` to `create_policy` (for example, `max_tokens`, `temperature`, `top_p`, `repetition_penalty`).
+ - What does `step()` return?
+   - A dict with `observation`, `reward`, `terminated`, `truncated`, `info`. The runner enforces this.
+
+ Examples:
+ - 01–06: reward functions, batch processing, minimal training
+ - 08: GRPO training with rollout + buffer
+ - 09–10: length reduction (GRPO/GSPO)
+ - 11: LoRA/QLoRA configuration
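
The two FAQ entries above about controlling generation and the `step()` contract can be made concrete with a small sketch. It assumes only what the FAQ states: the `generation_params` keys accepted by `create_policy`, and the dict keys a step result carries. The helper `check_step_result` is hypothetical, not a textpolicy API.

```python
# Sketch based on the FAQ above; check_step_result is a hypothetical helper,
# not part of textpolicy. How a step result is obtained is not shown here.

# Generation knobs listed in the FAQ, passed to create_policy(model, tokenizer, generation_params)
generation_params = {
    "max_tokens": 64,
    "temperature": 0.7,
    "top_p": 0.9,
    "repetition_penalty": 1.1,
}

def check_step_result(result: dict) -> float:
    """Validate the dict shape the FAQ describes and return the scalar reward."""
    expected = {"observation", "reward", "terminated", "truncated", "info"}
    missing = expected - result.keys()
    if missing:
        raise ValueError(f"step() result is missing keys: {sorted(missing)}")
    return float(result["reward"])
```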
textpolicy-0.1.1/README.md
@@ -0,0 +1,75 @@
+ # TextPolicy
+
+ Reinforcement learning toolkit for text generation on MLX (Apple Silicon).
+ TextPolicy provides algorithms (GRPO/GSPO), text-generation environments, a rollout runner,
+ reward functions with a decorator registry, and LoRA/QLoRA utilities.
+
+ ## Install (uv)
+
+ ```bash
+ uv add textpolicy
+ ```
+
+ Optional model integration:
+
+ ```bash
+ uv add mlx mlx-lm
+ ```
+
+ ## Quickstart
+
+ Working example using a real model and tokenizer (mlx-lm required):
+
+ ```python
+ import mlx.core as mx
+ import textpolicy as tp
+ from textpolicy import load_model, create_policy
+ from textpolicy.environment.text_generation import TextGenerationEnv
+ from textpolicy.rollout import RolloutRunner, create_strategy
+
+ # 1) Load model and tokenizer (mlx-lm)
+ model, tokenizer = load_model("Qwen/Qwen3-0.6B")
+
+ # 2) Create a policy (controls generation)
+ generation_params = {"max_tokens": 25, "temperature": 0.7}
+ policy_fn = create_policy(model, tokenizer, generation_params)
+
+ # 3) Define a reward function (env uses this to score responses)
+ @tp.reward
+ def length_reward(prompt: str, completion: str, example: dict, **kwargs) -> float:
+     return float(len(completion.split()))
+
+ # 4) Create an environment (requires a tokenizer)
+ env = TextGenerationEnv(["What is AI?"], length_reward, tokenizer=tokenizer)
+
+ # 5) Collect one rollout step
+ strategy = create_strategy('grpo')
+ runner = RolloutRunner(env, policy=policy_fn, strategy=strategy, max_steps=1)
+ buffer = runner.collect()
+ print(len(buffer.episodes))
+ ```
+
+ Docs:
+ - Quickstart: `docs/QUICKSTART_UV.md`
+ - LoRA/QLoRA: `docs/10_lora_qlora.md`
+ - Full index: `docs/index.md`
+
+ FAQ:
+ - Do I need a model?
+   - Yes, for generation with `create_policy`.
+     Use `load_model()` (mlx-lm) to get `(model, tokenizer)`.
+     For reward-only code (no generation), a model is not required.
+ - Do I need a tokenizer?
+   - Yes.
+     Both `TextGenerationEnv` and `TextGenerationEnvironment` require a tokenizer.
+     `load_model()` returns one for mlx-lm models.
+ - How do I control generation?
+   - Pass `generation_params` to `create_policy` (for example, `max_tokens`, `temperature`, `top_p`, `repetition_penalty`).
+ - What does `step()` return?
+   - A dict with `observation`, `reward`, `terminated`, `truncated`, `info`. The runner enforces this.
+
+ Examples:
+ - 01–06: reward functions, batch processing, minimal training
+ - 08: GRPO training with rollout + buffer
+ - 09–10: length reduction (GRPO/GSPO)
+ - 11: LoRA/QLoRA configuration
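
The README describes "reward functions with a decorator registry"; beyond the length reward in the Quickstart, a task-specific reward follows the same documented signature. In the sketch below, the `keywords` field read from `example` is an illustrative assumption, not a package contract.

```python
import textpolicy as tp

# Uses the documented reward signature (prompt, completion, example, **kwargs) -> float.
# The "keywords" payload inside `example` is assumed for illustration only.
@tp.reward
def keyword_overlap_reward(prompt: str, completion: str, example: dict, **kwargs) -> float:
    wanted = set(example.get("keywords", []))
    if not wanted:
        return 0.0
    # Fraction of expected keywords that appear in the completion (case-insensitive)
    found = {w for w in wanted if w.lower() in completion.lower()}
    return len(found) / len(wanted)
```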
textpolicy-0.1.1/pyproject.toml
@@ -0,0 +1,56 @@
+ [build-system]
+ requires = ["setuptools>=61.0", "wheel"]
+ build-backend = "setuptools.build_meta"
+
+ [tool.setuptools.packages.find]
+ include = ["textpolicy*"]
+
+ [project]
+ name = "textpolicy"
+ version = "0.1.1"
+ description = "Reinforcement learning for text generation on MLX (Apple Silicon): GRPO/GSPO, environments, rollout, rewards, LoRA/QLoRA"
+ readme = "README.md"
+ requires-python = ">=3.12"
+ dependencies = [
+     "numpy>=2.3.2",
+     "mlx>=0.21.0",  # Core MLX framework for Apple Silicon acceleration
+     "mlx-lm>=0.21.0",  # MLX language models for inference
+     "gymnasium>=0.29.0",
+     "psutil>=7.0.0",
+     "wandb>=0.21.1",
+     "aiohttp>=3.12.15",
+     "pytest>=8.4.1",
+ ]
+
+ keywords = [
+     "reinforcement-learning",
+     "text-generation",
+     "mlx",
+     "apple-silicon",
+     "lora",
+     "qlora",
+     "grpo",
+     "gspo",
+     "rlhf",
+ ]
+
+ classifiers = [
+     "Programming Language :: Python :: 3",
+     "Operating System :: MacOS",
+     "Intended Audience :: Developers",
+     "Topic :: Scientific/Engineering :: Artificial Intelligence",
+     "License :: OSI Approved :: MIT License",
+ ]
+
+ [project.urls]
+ Homepage = "https://github.com/teilomillet/textpolicy"
+ Repository = "https://github.com/teilomillet/textpolicy"
+ Documentation = "https://github.com/teilomillet/textpolicy#readme"
+ Changelog = "https://github.com/teilomillet/textpolicy/blob/main/CHANGELOG.md"
+
+ [project.scripts]
+ textpolicy = "textpolicy.cli:main"
+
+ [project.optional-dependencies]
+ external = ["aiohttp>=3.8.0", "pydantic>=2.0.0"]
+ dev = ["pytest>=7.0.0", "black>=22.0.0", "ruff>=0.1.0"]
textpolicy-0.1.1/tests/test_gspo_verification.py
@@ -0,0 +1,215 @@
+ """
+ GSPO Verification Tests - Comprehensive Testing of GSPO Implementation
+
+ This test module verifies that GSPO is working correctly by testing:
+ 1. Basic functionality of GSPO components
+ 2. Comparison with GRPO behavior
+ 3. Sequence-level vs token-level importance sampling
+ 4. Mathematical correctness of importance weights
+ 5. Training dynamics and convergence
+ """
+
+ import pytest
+ import mlx.core as mx
+ import mlx.optimizers as optim
+ import numpy as np
+ from textpolicy.algorithms import grpo, gspo
+ from textpolicy.generation.mlx_generation import load_model, create_policy
+ from textpolicy.rollout import RolloutCoordinator
+ from textpolicy.buffer import Buffer
+ from textpolicy.training import Trainer
+
+
+ @pytest.mark.unit
+ @pytest.mark.algorithm
+ class TestGSPOBasicFunctionality:
+     """Test basic GSPO functions work correctly."""
+
+     def test_sequence_importance_weights(self):
+         """Test sequence-level importance weights computation."""
+         # Create test data
+         old_logprobs = mx.array([-1.0, -1.2, -0.8, -1.1, -0.9])  # 5 tokens
+         new_logprobs = mx.array([-1.1, -1.0, -0.9, -1.0, -1.0])  # 5 tokens
+         sequence_lengths = [2, 3]  # Two sequences: 2 tokens + 3 tokens
+
+         # Test sequence-level importance weights
+         seq_weights = gspo.compute_sequence_importance_weights(
+             old_logprobs, new_logprobs, sequence_lengths, clip_ratio=0.2
+         )
+
+         assert len(seq_weights) == len(sequence_lengths), \
+             f"Expected {len(sequence_lengths)} weights, got {len(seq_weights)}"
+         assert all(not mx.isnan(w) and not mx.isinf(w) for w in seq_weights), \
+             "All weights should be finite"
+
+     def test_gspo_policy_loss(self):
+         """Test GSPO policy loss computation."""
+         old_logprobs = mx.array([-1.0, -1.2, -0.8, -1.1, -0.9])
+         new_logprobs = mx.array([-1.1, -1.0, -0.9, -1.0, -1.0])
+         sequence_lengths = [2, 3]
+         advantages = mx.array([0.5, -0.3])  # Advantages for each sequence
+
+         # Test GSPO policy loss
+         loss = gspo.gspo_policy_loss(
+             old_logprobs, new_logprobs, advantages, sequence_lengths, variant="sequence"
+         )
+
+         assert not mx.isnan(loss) and not mx.isinf(loss), "Loss should be finite"
+         assert isinstance(loss, mx.array), "Loss should be an MLX array"
+
+     def test_hybrid_importance_weights(self):
+         """Test hybrid importance weights computation."""
+         old_logprobs = mx.array([-1.0, -1.2, -0.8, -1.1, -0.9])
+         new_logprobs = mx.array([-1.1, -1.0, -0.9, -1.0, -1.0])
+         sequence_lengths = [2, 3]
+
+         # Test hybrid variant
+         hybrid_weights = gspo.compute_hybrid_importance_weights(
+             old_logprobs, new_logprobs, sequence_lengths
+         )
+
+         assert len(hybrid_weights) == len(old_logprobs), \
+             f"Expected {len(old_logprobs)} hybrid weights, got {len(hybrid_weights)}"
+         assert all(not mx.isnan(w) and not mx.isinf(w) for w in hybrid_weights), \
+             "All hybrid weights should be finite"
+
+
+ @pytest.mark.unit
+ @pytest.mark.algorithm
+ class TestGSPOvsGRPO:
+     """Test that GSPO produces different importance weights than GRPO."""
+
+     def test_importance_weight_differences(self):
+         """Test that GSPO produces different importance weights than GRPO."""
+         # Create test data with clear differences between old and new policies
+         old_logprobs = mx.array([-2.0, -2.0, -1.0, -1.0])  # 4 tokens
+         new_logprobs = mx.array([-1.0, -1.0, -2.0, -2.0])  # Policy changed significantly
+         sequence_lengths = [2, 2]  # Two sequences of equal length
+
+         # Compute GSPO sequence-level importance weights
+         gspo_weights = gspo.compute_sequence_importance_weights(
+             old_logprobs, new_logprobs, sequence_lengths, clip_ratio=0.2
+         )
+
+         # Compute GRPO token-level importance ratios for comparison
+         grpo_ratios = mx.exp(new_logprobs - old_logprobs)
+         grpo_ratios_clipped = mx.clip(grpo_ratios, 0.8, 1.2)
+
+         # GSPO should produce sequence-level weights (2 values)
+         # GRPO produces token-level ratios (4 values)
+         assert len(gspo_weights) == len(sequence_lengths), \
+             f"GSPO should produce {len(sequence_lengths)} sequence weights"
+         assert len(grpo_ratios) == len(old_logprobs), \
+             f"GRPO should produce {len(old_logprobs)} token ratios"
+
+         # The approaches should be fundamentally different
+         # GSPO normalizes by sequence length, GRPO doesn't
+         assert len(gspo_weights) != len(grpo_ratios), \
+             "GSPO and GRPO should produce different numbers of weights"
+
+
+ @pytest.mark.unit
+ @pytest.mark.algorithm
+ class TestGSPOClipping:
+     """Test GSPO clipping behavior."""
+
+     def test_clipping_bounds_respected(self):
+         """Test that importance weights respect clipping bounds."""
+         # Test extreme case to verify clipping
+         old_logprobs = mx.array([-10.0, -1.0])  # Extreme difference
+         new_logprobs = mx.array([-1.0, -1.0])
+         sequence_lengths = [2]
+         clip_ratio = 0.2
+
+         # Compute sequence weights
+         weights = gspo.compute_sequence_importance_weights(
+             old_logprobs, new_logprobs, sequence_lengths, clip_ratio=clip_ratio
+         )
+
+         # Weights should be clipped between (1-clip_ratio) and (1+clip_ratio)
+         lower_bound = 1.0 - clip_ratio
+         upper_bound = 1.0 + clip_ratio
+
+         # Use tolerance for floating-point comparisons due to MLX float32 precision
+         # MLX uses float32 by default, which has precision ~1.19e-7
+         tolerance = 1e-6  # Conservative tolerance for float32 precision issues
+
+         for weight in weights:
+             weight_val = float(weight)  # Convert MLX scalar to Python float
+             assert lower_bound - tolerance <= weight_val <= upper_bound + tolerance, \
+                 f"Weight {weight_val} outside clipping bounds [{lower_bound}, {upper_bound}] with tolerance {tolerance}"
+
+     def test_length_normalization_effect(self):
+         """Test that GSPO properly normalizes by sequence length."""
+         # Identical sequences of different lengths should have similar weights
+         old_logprobs_short = mx.array([-1.0, -1.0])  # 2 tokens
+         new_logprobs_short = mx.array([-0.5, -0.5])  # Better by 0.5 per token
+
+         old_logprobs_long = mx.array([-1.0, -1.0, -1.0, -1.0])  # 4 tokens
+         new_logprobs_long = mx.array([-0.5, -0.5, -0.5, -0.5])  # Better by 0.5 per token
+
+         weight_short = gspo.compute_sequence_importance_weights(
+             old_logprobs_short, new_logprobs_short, [2], clip_ratio=1.0  # No clipping
+         )
+         weight_long = gspo.compute_sequence_importance_weights(
+             old_logprobs_long, new_logprobs_long, [4], clip_ratio=1.0  # No clipping
+         )
+
+         # Both should be similar due to length normalization
+         # Short: exp((sum(-0.5) - sum(-1.0)) / 2) = exp((-1.0 - (-2.0)) / 2) = exp(0.5)
+         # Long: exp((sum(-0.5) - sum(-1.0)) / 4) = exp((-2.0 - (-4.0)) / 4) = exp(0.5)
+         short_val = float(weight_short[0])
+         long_val = float(weight_long[0])
+
+         # They should be approximately equal due to length normalization
+         assert abs(short_val - long_val) < 0.01, \
+             f"Length normalization failed: short={short_val}, long={long_val}"
+
+
+ @pytest.mark.integration
+ @pytest.mark.algorithm
+ @pytest.mark.slow
+ class TestGSPOTraining:
+     """Integration tests for GSPO training."""
+
+     def test_gspo_training_step(self):
+         """Test a complete GSPO training step."""
+         # This is a minimal integration test
+         # Create minimal test data
+         old_logprobs = mx.array([-1.0, -1.0, -1.0, -1.0])
+         new_logprobs = mx.array([-0.8, -0.8, -1.2, -1.2])
+         advantages = mx.array([0.5, -0.3])
+         sequence_lengths = [2, 2]
+
+         # Test that we can compute a complete loss
+         loss = gspo.gspo_policy_loss(
+             old_logprobs, new_logprobs, advantages, sequence_lengths, variant="sequence"
+         )
+
+         assert not mx.isnan(loss) and not mx.isinf(loss), "Training loss should be finite"
+         assert float(loss) != 0.0, "Loss should be non-zero for non-trivial inputs"
+
+     def test_gspo_metrics_computation(self):
+         """Test GSPO metrics computation."""
+         old_logprobs = mx.array([-1.0, -1.0, -1.0, -1.0])
+         new_logprobs = mx.array([-0.8, -0.8, -1.2, -1.2])
+         advantages = mx.array([0.5, -0.3])
+
+         # Test metrics computation
+         metrics_fn = gspo.create_gspo_metrics(variant="sequence")
+         metrics = metrics_fn(old_logprobs, new_logprobs, advantages)
+
+         assert isinstance(metrics, dict), "Metrics should be a dictionary"
+         assert len(metrics) > 0, "Metrics should not be empty"
+
+         # Check for expected metric keys
+         expected_keys = ['mean_advantage', 'std_advantage']
+         for key in expected_keys:
+             assert key in metrics, f"Missing expected metric: {key}"
+             assert isinstance(metrics[key], (int, float)), \
+                 f"Metric {key} should be numeric, got {type(metrics[key])}"
+
+
+ if __name__ == "__main__":
+     # Allow running this file directly for debugging
+     pytest.main([__file__, "-v"])
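
The tests above pin down the arithmetic expected of GSPO's sequence-level weights: one weight per sequence, equal to the exponential of the length-normalized log-probability difference, clipped to [1 - clip_ratio, 1 + clip_ratio]. Below is a minimal reference sketch of that arithmetic, assuming exactly the form implied by the length-normalization and clipping tests; it is not the package's `compute_sequence_importance_weights`.

```python
import mlx.core as mx

def sequence_importance_weight_sketch(old_logprobs, new_logprobs, sequence_lengths, clip_ratio=0.2):
    """Reference arithmetic mirroring what the tests check; not the package implementation."""
    weights = []
    start = 0
    for length in sequence_lengths:
        old_seq = old_logprobs[start:start + length]
        new_seq = new_logprobs[start:start + length]
        # Length-normalized sequence ratio: exp((sum(new) - sum(old)) / length)
        ratio = mx.exp((mx.sum(new_seq) - mx.sum(old_seq)) / length)
        weights.append(mx.clip(ratio, 1.0 - clip_ratio, 1.0 + clip_ratio))
        start += length
    return mx.stack(weights)

# Worked check matching test_length_normalization_effect: a 2-token and a 4-token
# sequence that each improve by 0.5 logprob per token both give exp(0.5) ~= 1.6487
# before clipping, so length normalization makes weights comparable across lengths.
```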
textpolicy-0.1.1/tests/test_integration_e2e_training.py
@@ -0,0 +1,49 @@
+ import pytest
+
+
+ @pytest.mark.integration
+ def test_e2e_minimal_rollout_grpo():
+     """
+     Minimal end-to-end rollout + buffer collection using TextGenerationEnv
+     with a dummy tokenizer and a trivial policy. This validates that
+     the environment returns dict-shaped step results and the runner
+     normalization path works as expected.
+
+     Kept intentionally lightweight for CI (no external model downloads).
+     """
+     try:
+         import mlx.core as mx  # type: ignore
+     except Exception:
+         pytest.skip("MLX not available")
+
+     from textpolicy.environment.text_generation import TextGenerationEnv
+     from textpolicy.rollout.runner import RolloutRunner
+     from textpolicy.rollout.strategy import create_strategy
+
+     class DummyTokenizer:
+         def encode(self, text):
+             return [ord(c) % 256 for c in text]
+
+         def decode(self, ids):
+             return "".join(chr(int(i) % 256) for i in ids)
+
+     def reward_fn(prompt, completion, example, **kwargs) -> float:
+         # Simple length reward in words
+         return float(len(completion.split()))
+
+     # Create simple environment
+     env = TextGenerationEnv(["Hello"], reward_fn, tokenizer=DummyTokenizer())
+
+     # Policy returns tokens that decode to 'a b c'
+     def simple_policy(obs_mx, deterministic=False):
+         return mx.array([97, 32, 98, 32, 99], dtype=mx.int32), {}
+
+     strategy = create_strategy('grpo')
+     runner = RolloutRunner(env, policy=simple_policy, strategy=strategy, max_steps=2)
+
+     buffer = runner.collect()
+     assert len(buffer.episodes) >= 1
+     ep = buffer.episodes[0]
+     # Episode stores rewards in `rew`
+     assert len(ep.rew) >= 1
+     assert all(r > 0 for r in ep.rew)