textpolicy 0.0.1__py3-none-any.whl → 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- textpolicy/__init__.py +52 -0
- textpolicy/__main__.py +8 -0
- textpolicy/algorithms/__init__.py +54 -0
- textpolicy/algorithms/grpo.py +642 -0
- textpolicy/algorithms/gspo.py +582 -0
- textpolicy/buffer/__init__.py +23 -0
- textpolicy/buffer/buffer.py +244 -0
- textpolicy/buffer/episode.py +383 -0
- textpolicy/buffer/sampling.py +438 -0
- textpolicy/buffer/storage.py +255 -0
- textpolicy/cli.py +67 -0
- textpolicy/environment/__init__.py +79 -0
- textpolicy/environment/base.py +110 -0
- textpolicy/environment/environment.py +46 -0
- textpolicy/environment/factory.py +103 -0
- textpolicy/environment/gym.py +106 -0
- textpolicy/environment/task_suites.py +51 -0
- textpolicy/environment/text_generation.py +789 -0
- textpolicy/environment/vectorized.py +253 -0
- textpolicy/generation/__init__.py +62 -0
- textpolicy/generation/lora.py +411 -0
- textpolicy/generation/mlx_generation.py +557 -0
- textpolicy/generation/reload.py +253 -0
- textpolicy/rewards/__init__.py +137 -0
- textpolicy/rewards/adapters.py +387 -0
- textpolicy/rewards/basic.py +214 -0
- textpolicy/rewards/integrated_system.py +338 -0
- textpolicy/rewards/mlx_batch_processor.py +447 -0
- textpolicy/rewards/registry.py +293 -0
- textpolicy/rewards/rollout_rewards.py +410 -0
- textpolicy/rewards/verifiers.py +369 -0
- textpolicy/rollout/__init__.py +44 -0
- textpolicy/rollout/aggregator.py +145 -0
- textpolicy/rollout/base.py +108 -0
- textpolicy/rollout/rollout.py +142 -0
- textpolicy/rollout/runner.py +280 -0
- textpolicy/rollout/strategy.py +208 -0
- textpolicy/rollout/worker.py +194 -0
- textpolicy/training/__init__.py +14 -0
- textpolicy/training/metrics.py +242 -0
- textpolicy/training/rollout_manager.py +78 -0
- textpolicy/training/trainer.py +684 -0
- textpolicy/utils/__init__.py +40 -0
- textpolicy/utils/benchmarking.py +489 -0
- textpolicy/utils/data.py +60 -0
- textpolicy/utils/debug.py +170 -0
- textpolicy/utils/environment.py +349 -0
- textpolicy/utils/logging/__init__.py +22 -0
- textpolicy/utils/logging/base.py +48 -0
- textpolicy/utils/logging/console.py +61 -0
- textpolicy/utils/logging/factory.py +133 -0
- textpolicy/utils/logging/multi.py +83 -0
- textpolicy/utils/logging/tensorboard.py +65 -0
- textpolicy/utils/logging/wandb.py +72 -0
- textpolicy/utils/memory.py +118 -0
- textpolicy/utils/performance.py +464 -0
- textpolicy/utils/timing.py +171 -0
- textpolicy/validate.py +101 -0
- textpolicy/validation/__init__.py +13 -0
- textpolicy/validation/logprob_validation.py +315 -0
- textpolicy-0.1.0.dist-info/METADATA +99 -0
- textpolicy-0.1.0.dist-info/RECORD +66 -0
- textpolicy-0.1.0.dist-info/entry_points.txt +2 -0
- textpolicy-0.0.1.dist-info/METADATA +0 -10
- textpolicy-0.0.1.dist-info/RECORD +0 -6
- {textpolicy-0.0.1.dist-info → textpolicy-0.1.0.dist-info}/WHEEL +0 -0
- {textpolicy-0.0.1.dist-info → textpolicy-0.1.0.dist-info}/licenses/LICENSE +0 -0
- {textpolicy-0.0.1.dist-info → textpolicy-0.1.0.dist-info}/top_level.txt +0 -0
textpolicy/__init__.py
CHANGED
|
@@ -0,0 +1,52 @@
|
|
|
1
|
+
"""
|
|
2
|
+
TextPolicy: RL library for text generation with MLX.
|
|
3
|
+
|
|
4
|
+
This module exposes the public API entry points for algorithms,
|
|
5
|
+
training, generation, environment, and rewards.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
# Submodule imports for building the public API
|
|
9
|
+
from . import algorithms, generation, training
|
|
10
|
+
|
|
11
|
+
# Export RL algorithms as defined in textpolicy.algorithms.__all__
|
|
12
|
+
from .algorithms import * # noqa: F403,F401
|
|
13
|
+
|
|
14
|
+
# Export text generation utilities (load_model, generate_tokens, etc.)
|
|
15
|
+
from .generation import * # noqa: F403,F401
|
|
16
|
+
|
|
17
|
+
# Export training components (Trainer, RolloutManager, TrainingMetrics)
|
|
18
|
+
from .training import * # noqa: F403,F401
|
|
19
|
+
|
|
20
|
+
# Export environment components and factory functions
|
|
21
|
+
from .environment import (
|
|
22
|
+
TextGenerationEnvironment,
|
|
23
|
+
TextGenerationEnv,
|
|
24
|
+
create_text_generation_test_env,
|
|
25
|
+
validate_learning_progress,
|
|
26
|
+
)
|
|
27
|
+
|
|
28
|
+
# Export installation validation utilities
|
|
29
|
+
from .validate import validate_installation
|
|
30
|
+
|
|
31
|
+
# Export core reward functions and the reward decorator
|
|
32
|
+
from .rewards.basic import length_reward, keyword_reward, perplexity_reward, accuracy_reward
|
|
33
|
+
from .rewards.registry import reward
|
|
34
|
+
|
|
35
|
+
# Build __all__ combining submodule __all__ lists and additional symbols
|
|
36
|
+
__all__ = (
|
|
37
|
+
algorithms.__all__
|
|
38
|
+
+ generation.__all__
|
|
39
|
+
+ training.__all__
|
|
40
|
+
+ [
|
|
41
|
+
"TextGenerationEnvironment",
|
|
42
|
+
"TextGenerationEnv",
|
|
43
|
+
"create_text_generation_test_env",
|
|
44
|
+
"validate_learning_progress",
|
|
45
|
+
"validate_installation",
|
|
46
|
+
"length_reward",
|
|
47
|
+
"keyword_reward",
|
|
48
|
+
"perplexity_reward",
|
|
49
|
+
"accuracy_reward",
|
|
50
|
+
"reward",
|
|
51
|
+
]
|
|
52
|
+
)
|
textpolicy/algorithms/__init__.py
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
1
|
+
# textpolicy/algorithms/__init__.py
|
|
2
|
+
"""
|
|
3
|
+
Reinforcement learning algorithms for MLX and Apple Silicon.
|
|
4
|
+
|
|
5
|
+
GRPO: group-relative advantages with PPO-style clipping.
|
|
6
|
+
GSPO: sequence-level importance sampling (sequence, token, and hybrid variants).
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from .grpo import (
|
|
10
|
+
compute_advantages,
|
|
11
|
+
compute_advantages_dr_grpo,
|
|
12
|
+
policy_loss,
|
|
13
|
+
grpo_loss,
|
|
14
|
+
compute_metrics,
|
|
15
|
+
entropy_bonus,
|
|
16
|
+
select_all_data,
|
|
17
|
+
select_recent_data
|
|
18
|
+
)
|
|
19
|
+
|
|
20
|
+
from .gspo import (
|
|
21
|
+
create_gspo_policy_loss,
|
|
22
|
+
create_gspo_metrics,
|
|
23
|
+
policy_loss_sequence,
|
|
24
|
+
policy_loss_hybrid,
|
|
25
|
+
create_policy_loss_hybrid,
|
|
26
|
+
policy_loss_token,
|
|
27
|
+
compute_metrics_sequence,
|
|
28
|
+
compute_metrics_hybrid,
|
|
29
|
+
compute_metrics_token,
|
|
30
|
+
select_gspo_data
|
|
31
|
+
)
|
|
32
|
+
|
|
33
|
+
__all__ = [
|
|
34
|
+
# GRPO functions
|
|
35
|
+
"compute_advantages",
|
|
36
|
+
"compute_advantages_dr_grpo",
|
|
37
|
+
"policy_loss",
|
|
38
|
+
"grpo_loss",
|
|
39
|
+
"compute_metrics",
|
|
40
|
+
"entropy_bonus",
|
|
41
|
+
"select_all_data",
|
|
42
|
+
"select_recent_data",
|
|
43
|
+
# GSPO functions
|
|
44
|
+
"create_gspo_policy_loss",
|
|
45
|
+
"create_gspo_metrics",
|
|
46
|
+
"policy_loss_sequence",
|
|
47
|
+
"policy_loss_hybrid",
|
|
48
|
+
"create_policy_loss_hybrid",
|
|
49
|
+
"policy_loss_token",
|
|
50
|
+
"compute_metrics_sequence",
|
|
51
|
+
"compute_metrics_hybrid",
|
|
52
|
+
"compute_metrics_token",
|
|
53
|
+
"select_gspo_data"
|
|
54
|
+
]
|