textpolicy 0.1.0__py3-none-any.whl → 0.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
textpolicy/__init__.py CHANGED
@@ -30,7 +30,7 @@ from .validate import validate_installation
30
30
 
31
31
  # Export core reward functions and the reward decorator
32
32
  from .rewards.basic import length_reward, keyword_reward, perplexity_reward, accuracy_reward
33
- from .rewards.registry import reward
33
+ from .rewards.registry import reward, verifier
34
34
 
35
35
  # Build __all__ combining submodule __all__ lists and additional symbols
36
36
  __all__ = (
@@ -48,5 +48,6 @@ __all__ = (
48
48
  "perplexity_reward",
49
49
  "accuracy_reward",
50
50
  "reward",
51
+ "verifier",
51
52
  ]
52
53
  )
@@ -630,24 +630,31 @@ class TextGenerationEnv(Environment):
630
630
  reward_fn: Callable[[str, str, dict], float],
631
631
  max_tokens: int = 25,
632
632
  seed: int = 42,
633
- tokenizer: Any = None
633
+ tokenizer: Any = None,
634
+ examples: Optional[List[dict]] = None
634
635
  ):
635
636
  """
636
637
  Initialize simple text generation environment.
637
-
638
+
638
639
  Args:
639
640
  prompts: List of prompts to cycle through
640
641
  reward_fn: Function that computes reward from (prompt, completion, example)
641
642
  max_tokens: Maximum tokens to generate per response
642
643
  seed: Random seed for reproducible behavior
643
644
  tokenizer: Tokenizer for converting prompts to tokens (required for MLX compatibility)
645
+ examples: Optional list of example dicts to pass to reward function. If provided,
646
+ must have same length as prompts. examples[i] is passed when prompts[i] is used.
644
647
  """
645
648
  super().__init__()
646
649
 
647
650
  if tokenizer is None:
648
651
  raise ValueError("tokenizer is required for TextGenerationEnv to work with MLX rollout system")
649
-
652
+
653
+ if examples is not None and len(examples) != len(prompts):
654
+ raise ValueError(f"examples length ({len(examples)}) must match prompts length ({len(prompts)})")
655
+
650
656
  self.prompts = prompts
657
+ self.examples = examples if examples is not None else [{} for _ in prompts]
651
658
  self.reward_fn = reward_fn
652
659
  self.max_tokens = max_tokens
653
660
  self.tokenizer = tokenizer
@@ -735,10 +742,11 @@ class TextGenerationEnv(Environment):
735
742
 
736
743
  # Compute reward using provided reward function
737
744
  # Pass tokenizer for EOS token detection and truncation detection
745
+ prompt_index = self.current_episode % len(self.prompts)
738
746
  reward = self.reward_fn(
739
747
  prompt=self.current_prompt,
740
748
  completion=response_text,
741
- example={},
749
+ example=self.examples[prompt_index],
742
750
  tokenizer=self.tokenizer, # Pass tokenizer for EOS detection
743
751
  truncated=truncated # Pass truncation flag from environment
744
752
  )
@@ -1,13 +1,23 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: textpolicy
3
- Version: 0.1.0
4
- Summary: MLX-optimized reward and verification system for text generation RL
3
+ Version: 0.1.2
4
+ Summary: Reinforcement learning for text generation on MLX (Apple Silicon): GRPO/GSPO, environments, rollout, rewards, LoRA/QLoRA
5
+ Project-URL: Homepage, https://github.com/teilomillet/textpolicy
6
+ Project-URL: Repository, https://github.com/teilomillet/textpolicy
7
+ Project-URL: Documentation, https://github.com/teilomillet/textpolicy#readme
8
+ Project-URL: Changelog, https://github.com/teilomillet/textpolicy/blob/main/CHANGELOG.md
9
+ Keywords: reinforcement-learning,text-generation,mlx,apple-silicon,lora,qlora,grpo,gspo,rlhf
10
+ Classifier: Programming Language :: Python :: 3
11
+ Classifier: Operating System :: MacOS
12
+ Classifier: Intended Audience :: Developers
13
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
14
+ Classifier: License :: OSI Approved :: MIT License
5
15
  Requires-Python: >=3.12
6
16
  Description-Content-Type: text/markdown
7
17
  License-File: LICENSE
8
18
  Requires-Dist: numpy>=2.3.2
9
- Requires-Dist: mlx>=0.21.0
10
- Requires-Dist: mlx-lm>=0.21.0
19
+ Requires-Dist: mlx>=0.22.0
20
+ Requires-Dist: mlx-lm>=0.22.0
11
21
  Requires-Dist: gymnasium>=0.29.0
12
22
  Requires-Dist: psutil>=7.0.0
13
23
  Requires-Dist: wandb>=0.21.1
@@ -1,4 +1,4 @@
1
- textpolicy/__init__.py,sha256=u4u0fIHfAvXFN2ATHCsG0Tx4xGfOcfuOITBTmKbGhrw,1576
1
+ textpolicy/__init__.py,sha256=vDAHJ826gKuTZUjcAftzz-RTX8KuOjH50Uj1RMhjTIQ,1606
2
2
  textpolicy/__main__.py,sha256=IlGmjJaW-DJUC7yhxUhbwNOZA3GxkeQGkVbFdS3_wBI,136
3
3
  textpolicy/cli.py,sha256=3CcJzrRlin1pgd6Mh312Xp3-EihHtTSvhakyYpdfacs,2107
4
4
  textpolicy/validate.py,sha256=lxmegz83B_c-PS3cFHaaL3c9fgWrEaLsDLkpPFtSj8Y,3780
@@ -16,7 +16,7 @@ textpolicy/environment/environment.py,sha256=o8-RY6wj5xrzDBp77HoY2At3XlBwvreF3DK
16
16
  textpolicy/environment/factory.py,sha256=pebQo1_M3sMF8Pdc9yvpdXzRXfIDllKJoAQAjQbif0E,3124
17
17
  textpolicy/environment/gym.py,sha256=P8Bi8PlDtcWWa9uLuCjkhZnYRVs-mg6iSJVSBkG99f8,3186
18
18
  textpolicy/environment/task_suites.py,sha256=ssPnw2Y3eGYaskWf8dUab4rNu_Bx5L284b3VdhgvSPM,1544
19
- textpolicy/environment/text_generation.py,sha256=BXSJS_05Q89cPFfdXcUKxOXSZm3HBR3KMi55BnOdoLY,31258
19
+ textpolicy/environment/text_generation.py,sha256=Jql0pEfrPp9tqNsPOAdIP-UYoAUsfV969TMR2uPkUp4,31837
20
20
  textpolicy/environment/vectorized.py,sha256=ZROtpmdbh1Oi8c0b0D_vmVzqI16Cp2WZTmkjkRbMoDg,9932
21
21
  textpolicy/generation/__init__.py,sha256=J3dc0SPAZChJTsRn47tz8FfIp3XwNgZ-8_H9VBpQYvQ,1266
22
22
  textpolicy/generation/lora.py,sha256=xSKRczJY20BrkkU1SSgBtDc30tZjdFE7FhEZPUEoiyg,13747
@@ -58,9 +58,9 @@ textpolicy/utils/logging/tensorboard.py,sha256=aY9YMReSJkWEhy6SdAAUlHSB4lzDecivB
58
58
  textpolicy/utils/logging/wandb.py,sha256=U4pxuZNOz2l8XiymK8OFbCpiRTBOLNtnZakC_udttfQ,2206
59
59
  textpolicy/validation/__init__.py,sha256=KcyppNi91w0bF51gZ0ykUIKEiF7z6TT37uuavMFScnA,328
60
60
  textpolicy/validation/logprob_validation.py,sha256=G_CCy5NRDUTmo7WZIChhNVM3NtP1VmWAjdd5z6TIvos,11749
61
- textpolicy-0.1.0.dist-info/licenses/LICENSE,sha256=AYDHSNRbiqZt4HHH1gaOoQ2hjYjK4bqw4Vd9UyKzx18,1065
62
- textpolicy-0.1.0.dist-info/METADATA,sha256=XdyIh8e2IIRymRf31vu1MuVM2aaut2qsZ5PcsjHrl9Y,3199
63
- textpolicy-0.1.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
64
- textpolicy-0.1.0.dist-info/entry_points.txt,sha256=d0Cj5boT6k_l_beVPWPt9LZMllsN4kbIUmsNsn1BANE,51
65
- textpolicy-0.1.0.dist-info/top_level.txt,sha256=Ww6_QEF71dI-AYCaugiGeGcgMoFAixSOszSoRsyX-E0,11
66
- textpolicy-0.1.0.dist-info/RECORD,,
61
+ textpolicy-0.1.2.dist-info/licenses/LICENSE,sha256=AYDHSNRbiqZt4HHH1gaOoQ2hjYjK4bqw4Vd9UyKzx18,1065
62
+ textpolicy-0.1.2.dist-info/METADATA,sha256=HXAh6fGcTtNez86WFNlr6OnIQZNcswptUXPnBSmXQHM,3895
63
+ textpolicy-0.1.2.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
64
+ textpolicy-0.1.2.dist-info/entry_points.txt,sha256=d0Cj5boT6k_l_beVPWPt9LZMllsN4kbIUmsNsn1BANE,51
65
+ textpolicy-0.1.2.dist-info/top_level.txt,sha256=Ww6_QEF71dI-AYCaugiGeGcgMoFAixSOszSoRsyX-E0,11
66
+ textpolicy-0.1.2.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (80.9.0)
2
+ Generator: setuptools (80.10.2)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5