textpolicy 0.1.0__py3-none-any.whl → 0.1.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- textpolicy/__init__.py +2 -1
- textpolicy/environment/text_generation.py +12 -4
- {textpolicy-0.1.0.dist-info → textpolicy-0.1.2.dist-info}/METADATA +14 -4
- {textpolicy-0.1.0.dist-info → textpolicy-0.1.2.dist-info}/RECORD +8 -8
- {textpolicy-0.1.0.dist-info → textpolicy-0.1.2.dist-info}/WHEEL +1 -1
- {textpolicy-0.1.0.dist-info → textpolicy-0.1.2.dist-info}/entry_points.txt +0 -0
- {textpolicy-0.1.0.dist-info → textpolicy-0.1.2.dist-info}/licenses/LICENSE +0 -0
- {textpolicy-0.1.0.dist-info → textpolicy-0.1.2.dist-info}/top_level.txt +0 -0
textpolicy/__init__.py
CHANGED
|
@@ -30,7 +30,7 @@ from .validate import validate_installation
|
|
|
30
30
|
|
|
31
31
|
# Export core reward functions and the reward decorator
|
|
32
32
|
from .rewards.basic import length_reward, keyword_reward, perplexity_reward, accuracy_reward
|
|
33
|
-
from .rewards.registry import reward
|
|
33
|
+
from .rewards.registry import reward, verifier
|
|
34
34
|
|
|
35
35
|
# Build __all__ combining submodule __all__ lists and additional symbols
|
|
36
36
|
__all__ = (
|
|
@@ -48,5 +48,6 @@ __all__ = (
|
|
|
48
48
|
"perplexity_reward",
|
|
49
49
|
"accuracy_reward",
|
|
50
50
|
"reward",
|
|
51
|
+
"verifier",
|
|
51
52
|
]
|
|
52
53
|
)
|
|
@@ -630,24 +630,31 @@ class TextGenerationEnv(Environment):
|
|
|
630
630
|
reward_fn: Callable[[str, str, dict], float],
|
|
631
631
|
max_tokens: int = 25,
|
|
632
632
|
seed: int = 42,
|
|
633
|
-
tokenizer: Any = None
|
|
633
|
+
tokenizer: Any = None,
|
|
634
|
+
examples: Optional[List[dict]] = None
|
|
634
635
|
):
|
|
635
636
|
"""
|
|
636
637
|
Initialize simple text generation environment.
|
|
637
|
-
|
|
638
|
+
|
|
638
639
|
Args:
|
|
639
640
|
prompts: List of prompts to cycle through
|
|
640
641
|
reward_fn: Function that computes reward from (prompt, completion, example)
|
|
641
642
|
max_tokens: Maximum tokens to generate per response
|
|
642
643
|
seed: Random seed for reproducible behavior
|
|
643
644
|
tokenizer: Tokenizer for converting prompts to tokens (required for MLX compatibility)
|
|
645
|
+
examples: Optional list of example dicts to pass to reward function. If provided,
|
|
646
|
+
must have same length as prompts. examples[i] is passed when prompts[i] is used.
|
|
644
647
|
"""
|
|
645
648
|
super().__init__()
|
|
646
649
|
|
|
647
650
|
if tokenizer is None:
|
|
648
651
|
raise ValueError("tokenizer is required for TextGenerationEnv to work with MLX rollout system")
|
|
649
|
-
|
|
652
|
+
|
|
653
|
+
if examples is not None and len(examples) != len(prompts):
|
|
654
|
+
raise ValueError(f"examples length ({len(examples)}) must match prompts length ({len(prompts)})")
|
|
655
|
+
|
|
650
656
|
self.prompts = prompts
|
|
657
|
+
self.examples = examples if examples is not None else [{} for _ in prompts]
|
|
651
658
|
self.reward_fn = reward_fn
|
|
652
659
|
self.max_tokens = max_tokens
|
|
653
660
|
self.tokenizer = tokenizer
|
|
@@ -735,10 +742,11 @@ class TextGenerationEnv(Environment):
|
|
|
735
742
|
|
|
736
743
|
# Compute reward using provided reward function
|
|
737
744
|
# Pass tokenizer for EOS token detection and truncation detection
|
|
745
|
+
prompt_index = self.current_episode % len(self.prompts)
|
|
738
746
|
reward = self.reward_fn(
|
|
739
747
|
prompt=self.current_prompt,
|
|
740
748
|
completion=response_text,
|
|
741
|
-
example=
|
|
749
|
+
example=self.examples[prompt_index],
|
|
742
750
|
tokenizer=self.tokenizer, # Pass tokenizer for EOS detection
|
|
743
751
|
truncated=truncated # Pass truncation flag from environment
|
|
744
752
|
)
|
|
@@ -1,13 +1,23 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: textpolicy
|
|
3
|
-
Version: 0.1.
|
|
4
|
-
Summary: MLX
|
|
3
|
+
Version: 0.1.2
|
|
4
|
+
Summary: Reinforcement learning for text generation on MLX (Apple Silicon): GRPO/GSPO, environments, rollout, rewards, LoRA/QLoRA
|
|
5
|
+
Project-URL: Homepage, https://github.com/teilomillet/textpolicy
|
|
6
|
+
Project-URL: Repository, https://github.com/teilomillet/textpolicy
|
|
7
|
+
Project-URL: Documentation, https://github.com/teilomillet/textpolicy#readme
|
|
8
|
+
Project-URL: Changelog, https://github.com/teilomillet/textpolicy/blob/main/CHANGELOG.md
|
|
9
|
+
Keywords: reinforcement-learning,text-generation,mlx,apple-silicon,lora,qlora,grpo,gspo,rlhf
|
|
10
|
+
Classifier: Programming Language :: Python :: 3
|
|
11
|
+
Classifier: Operating System :: MacOS
|
|
12
|
+
Classifier: Intended Audience :: Developers
|
|
13
|
+
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
14
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
5
15
|
Requires-Python: >=3.12
|
|
6
16
|
Description-Content-Type: text/markdown
|
|
7
17
|
License-File: LICENSE
|
|
8
18
|
Requires-Dist: numpy>=2.3.2
|
|
9
|
-
Requires-Dist: mlx>=0.
|
|
10
|
-
Requires-Dist: mlx-lm>=0.
|
|
19
|
+
Requires-Dist: mlx>=0.22.0
|
|
20
|
+
Requires-Dist: mlx-lm>=0.22.0
|
|
11
21
|
Requires-Dist: gymnasium>=0.29.0
|
|
12
22
|
Requires-Dist: psutil>=7.0.0
|
|
13
23
|
Requires-Dist: wandb>=0.21.1
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
textpolicy/__init__.py,sha256=
|
|
1
|
+
textpolicy/__init__.py,sha256=vDAHJ826gKuTZUjcAftzz-RTX8KuOjH50Uj1RMhjTIQ,1606
|
|
2
2
|
textpolicy/__main__.py,sha256=IlGmjJaW-DJUC7yhxUhbwNOZA3GxkeQGkVbFdS3_wBI,136
|
|
3
3
|
textpolicy/cli.py,sha256=3CcJzrRlin1pgd6Mh312Xp3-EihHtTSvhakyYpdfacs,2107
|
|
4
4
|
textpolicy/validate.py,sha256=lxmegz83B_c-PS3cFHaaL3c9fgWrEaLsDLkpPFtSj8Y,3780
|
|
@@ -16,7 +16,7 @@ textpolicy/environment/environment.py,sha256=o8-RY6wj5xrzDBp77HoY2At3XlBwvreF3DK
|
|
|
16
16
|
textpolicy/environment/factory.py,sha256=pebQo1_M3sMF8Pdc9yvpdXzRXfIDllKJoAQAjQbif0E,3124
|
|
17
17
|
textpolicy/environment/gym.py,sha256=P8Bi8PlDtcWWa9uLuCjkhZnYRVs-mg6iSJVSBkG99f8,3186
|
|
18
18
|
textpolicy/environment/task_suites.py,sha256=ssPnw2Y3eGYaskWf8dUab4rNu_Bx5L284b3VdhgvSPM,1544
|
|
19
|
-
textpolicy/environment/text_generation.py,sha256=
|
|
19
|
+
textpolicy/environment/text_generation.py,sha256=Jql0pEfrPp9tqNsPOAdIP-UYoAUsfV969TMR2uPkUp4,31837
|
|
20
20
|
textpolicy/environment/vectorized.py,sha256=ZROtpmdbh1Oi8c0b0D_vmVzqI16Cp2WZTmkjkRbMoDg,9932
|
|
21
21
|
textpolicy/generation/__init__.py,sha256=J3dc0SPAZChJTsRn47tz8FfIp3XwNgZ-8_H9VBpQYvQ,1266
|
|
22
22
|
textpolicy/generation/lora.py,sha256=xSKRczJY20BrkkU1SSgBtDc30tZjdFE7FhEZPUEoiyg,13747
|
|
@@ -58,9 +58,9 @@ textpolicy/utils/logging/tensorboard.py,sha256=aY9YMReSJkWEhy6SdAAUlHSB4lzDecivB
|
|
|
58
58
|
textpolicy/utils/logging/wandb.py,sha256=U4pxuZNOz2l8XiymK8OFbCpiRTBOLNtnZakC_udttfQ,2206
|
|
59
59
|
textpolicy/validation/__init__.py,sha256=KcyppNi91w0bF51gZ0ykUIKEiF7z6TT37uuavMFScnA,328
|
|
60
60
|
textpolicy/validation/logprob_validation.py,sha256=G_CCy5NRDUTmo7WZIChhNVM3NtP1VmWAjdd5z6TIvos,11749
|
|
61
|
-
textpolicy-0.1.
|
|
62
|
-
textpolicy-0.1.
|
|
63
|
-
textpolicy-0.1.
|
|
64
|
-
textpolicy-0.1.
|
|
65
|
-
textpolicy-0.1.
|
|
66
|
-
textpolicy-0.1.
|
|
61
|
+
textpolicy-0.1.2.dist-info/licenses/LICENSE,sha256=AYDHSNRbiqZt4HHH1gaOoQ2hjYjK4bqw4Vd9UyKzx18,1065
|
|
62
|
+
textpolicy-0.1.2.dist-info/METADATA,sha256=HXAh6fGcTtNez86WFNlr6OnIQZNcswptUXPnBSmXQHM,3895
|
|
63
|
+
textpolicy-0.1.2.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
|
|
64
|
+
textpolicy-0.1.2.dist-info/entry_points.txt,sha256=d0Cj5boT6k_l_beVPWPt9LZMllsN4kbIUmsNsn1BANE,51
|
|
65
|
+
textpolicy-0.1.2.dist-info/top_level.txt,sha256=Ww6_QEF71dI-AYCaugiGeGcgMoFAixSOszSoRsyX-E0,11
|
|
66
|
+
textpolicy-0.1.2.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|