textpolicy 0.1.1__py3-none-any.whl → 0.1.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- textpolicy/__init__.py +3 -0
- textpolicy/algorithms/__init__.py +29 -4
- textpolicy/algorithms/grpo.py +771 -361
- textpolicy/algorithms/length_shaping.py +151 -0
- textpolicy/analysis/__init__.py +23 -0
- textpolicy/analysis/emergence_logger.py +248 -0
- textpolicy/analysis/planning_patterns.py +105 -0
- textpolicy/analysis/serialization.py +65 -0
- textpolicy/generation/mlx_generation.py +36 -21
- textpolicy/tasks/__init__.py +7 -0
- textpolicy/tasks/countdown/__init__.py +21 -0
- textpolicy/tasks/countdown/dataset.py +163 -0
- textpolicy/tasks/countdown/evaluator.py +197 -0
- textpolicy/tasks/countdown/prompt.py +89 -0
- textpolicy/tasks/countdown/reward.py +56 -0
- textpolicy/training/trainer.py +41 -21
- {textpolicy-0.1.1.dist-info → textpolicy-0.1.3.dist-info}/METADATA +3 -3
- {textpolicy-0.1.1.dist-info → textpolicy-0.1.3.dist-info}/RECORD +22 -11
- {textpolicy-0.1.1.dist-info → textpolicy-0.1.3.dist-info}/WHEEL +0 -0
- {textpolicy-0.1.1.dist-info → textpolicy-0.1.3.dist-info}/entry_points.txt +0 -0
- {textpolicy-0.1.1.dist-info → textpolicy-0.1.3.dist-info}/licenses/LICENSE +0 -0
- {textpolicy-0.1.1.dist-info → textpolicy-0.1.3.dist-info}/top_level.txt +0 -0
textpolicy/__init__.py
CHANGED
|
@@ -8,6 +8,9 @@ training, generation, environment, and rewards.
|
|
|
8
8
|
# Submodule imports for building the public API
|
|
9
9
|
from . import algorithms, generation, training
|
|
10
10
|
|
|
11
|
+
# Import tasks to trigger auto-registration of task reward functions
|
|
12
|
+
from . import tasks # noqa: F401
|
|
13
|
+
|
|
11
14
|
# Export RL algorithms as defined in textpolicy.algorithms.__all__
|
|
12
15
|
from .algorithms import * # noqa: F403,F401
|
|
13
16
|
|
|
@@ -7,6 +7,7 @@ GSPO: sequence-level importance sampling (sequence, token, and hybrid variants).
|
|
|
7
7
|
"""
|
|
8
8
|
|
|
9
9
|
from .grpo import (
|
|
10
|
+
# Core GRPO functions
|
|
10
11
|
compute_advantages,
|
|
11
12
|
compute_advantages_dr_grpo,
|
|
12
13
|
policy_loss,
|
|
@@ -14,7 +15,19 @@ from .grpo import (
|
|
|
14
15
|
compute_metrics,
|
|
15
16
|
entropy_bonus,
|
|
16
17
|
select_all_data,
|
|
17
|
-
select_recent_data
|
|
18
|
+
select_recent_data,
|
|
19
|
+
# Compiled versions
|
|
20
|
+
compute_advantages_compiled,
|
|
21
|
+
policy_loss_compiled,
|
|
22
|
+
policy_loss_compiled_constant_norm,
|
|
23
|
+
# Length shaping (DAPO-style soft overlong penalties)
|
|
24
|
+
compute_length_penalty,
|
|
25
|
+
apply_length_shaping,
|
|
26
|
+
compute_length_shaping_stats,
|
|
27
|
+
# Dynamic batch filtering
|
|
28
|
+
filter_informative_prompts,
|
|
29
|
+
compute_prompt_group_stats,
|
|
30
|
+
select_informative_data,
|
|
18
31
|
)
|
|
19
32
|
|
|
20
33
|
from .gspo import (
|
|
@@ -31,15 +44,27 @@ from .gspo import (
|
|
|
31
44
|
)
|
|
32
45
|
|
|
33
46
|
__all__ = [
|
|
34
|
-
# GRPO functions
|
|
47
|
+
# GRPO core functions
|
|
35
48
|
"compute_advantages",
|
|
36
49
|
"compute_advantages_dr_grpo",
|
|
37
50
|
"policy_loss",
|
|
38
|
-
"grpo_loss",
|
|
51
|
+
"grpo_loss",
|
|
39
52
|
"compute_metrics",
|
|
40
53
|
"entropy_bonus",
|
|
41
54
|
"select_all_data",
|
|
42
55
|
"select_recent_data",
|
|
56
|
+
# GRPO compiled versions
|
|
57
|
+
"compute_advantages_compiled",
|
|
58
|
+
"policy_loss_compiled",
|
|
59
|
+
"policy_loss_compiled_constant_norm",
|
|
60
|
+
# GRPO length shaping (DAPO-style)
|
|
61
|
+
"compute_length_penalty",
|
|
62
|
+
"apply_length_shaping",
|
|
63
|
+
"compute_length_shaping_stats",
|
|
64
|
+
# GRPO dynamic batch filtering
|
|
65
|
+
"filter_informative_prompts",
|
|
66
|
+
"compute_prompt_group_stats",
|
|
67
|
+
"select_informative_data",
|
|
43
68
|
# GSPO functions
|
|
44
69
|
"create_gspo_policy_loss",
|
|
45
70
|
"create_gspo_metrics",
|
|
@@ -50,5 +75,5 @@ __all__ = [
|
|
|
50
75
|
"compute_metrics_sequence",
|
|
51
76
|
"compute_metrics_hybrid",
|
|
52
77
|
"compute_metrics_token",
|
|
53
|
-
"select_gspo_data"
|
|
78
|
+
"select_gspo_data",
|
|
54
79
|
]
|