trinity-rft 0.2.0__tar.gz → 0.2.1.dev0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/PKG-INFO +38 -19
- {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/README.md +31 -16
- {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/pyproject.toml +8 -4
- {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/__init__.py +1 -1
- {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/algorithm/__init__.py +3 -0
- trinity_rft-0.2.1.dev0/trinity/algorithm/add_strategy/__init__.py +25 -0
- trinity_rft-0.2.1.dev0/trinity/algorithm/add_strategy/add_strategy.py +230 -0
- trinity_rft-0.2.1.dev0/trinity/algorithm/add_strategy/correct_bias_add_strategy.py +54 -0
- trinity_rft-0.2.1.dev0/trinity/algorithm/add_strategy/duplicate_add_strategy.py +72 -0
- trinity_rft-0.2.1.dev0/trinity/algorithm/add_strategy/step_wise_add_strategy.py +123 -0
- {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/algorithm/advantage_fn/grpo_advantage.py +2 -2
- {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/algorithm/advantage_fn/opmd_advantage.py +4 -2
- {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/algorithm/advantage_fn/rloo_advantage.py +1 -1
- {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/algorithm/algorithm.py +59 -11
- {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/algorithm/entropy_loss_fn/entropy_loss_fn.py +27 -0
- {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/algorithm/policy_loss_fn/__init__.py +10 -0
- trinity_rft-0.2.1.dev0/trinity/algorithm/policy_loss_fn/chord_policy_loss.py +257 -0
- trinity_rft-0.2.1.dev0/trinity/algorithm/policy_loss_fn/gspo_policy_loss.py +76 -0
- {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/algorithm/policy_loss_fn/mix_policy_loss.py +23 -26
- trinity_rft-0.2.1.dev0/trinity/algorithm/sample_strategy/mix_sample_strategy.py +98 -0
- trinity_rft-0.2.1.dev0/trinity/algorithm/sample_strategy/sample_strategy.py +92 -0
- trinity_rft-0.2.1.dev0/trinity/algorithm/sample_strategy/utils.py +42 -0
- trinity_rft-0.2.1.dev0/trinity/buffer/__init__.py +15 -0
- {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/buffer/buffer_reader.py +6 -0
- trinity_rft-0.2.0/trinity/buffer/priority_queue.py → trinity_rft-0.2.1.dev0/trinity/buffer/queue.py +96 -27
- {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/buffer/ray_wrapper.py +111 -4
- {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/buffer/reader/file_reader.py +69 -33
- {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/buffer/reader/queue_reader.py +16 -3
- {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/buffer/reader/sql_reader.py +8 -0
- {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/buffer/schema/sql_schema.py +3 -3
- {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/buffer/writer/queue_writer.py +2 -2
- {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/cli/launcher.py +16 -10
- {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/common/config.py +187 -73
- {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/common/constants.py +15 -0
- trinity_rft-0.2.1.dev0/trinity/common/experience.py +547 -0
- {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/common/models/__init__.py +9 -14
- trinity_rft-0.2.1.dev0/trinity/common/models/api/vllm_patch.py +374 -0
- trinity_rft-0.2.1.dev0/trinity/common/models/model.py +214 -0
- {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/common/models/utils.py +13 -8
- trinity_rft-0.2.0/trinity/common/models/vllm_async_model.py → trinity_rft-0.2.1.dev0/trinity/common/models/vllm_model.py +50 -52
- {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/common/models/vllm_worker.py +9 -14
- {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/common/rewards/__init__.py +2 -0
- trinity_rft-0.2.1.dev0/trinity/common/rewards/dapo_reward.py +67 -0
- {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/common/rewards/math_reward.py +2 -1
- {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/common/rewards/reward_fn.py +1 -1
- {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/common/verl_config.py +33 -15
- trinity_rft-0.2.1.dev0/trinity/common/workflows/__init__.py +32 -0
- {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/common/workflows/customized_math_workflows.py +17 -13
- trinity_rft-0.2.1.dev0/trinity/common/workflows/customized_toolcall_workflows.py +265 -0
- trinity_rft-0.2.1.dev0/trinity/common/workflows/envs/agentscope/agentscope_react_workflow.py +162 -0
- trinity_rft-0.2.1.dev0/trinity/common/workflows/envs/alfworld/RAFT_alfworld_workflow.py +225 -0
- trinity_rft-0.2.1.dev0/trinity/common/workflows/envs/alfworld/RAFT_reflect_alfworld_workflow.py +310 -0
- trinity_rft-0.2.1.dev0/trinity/common/workflows/envs/alfworld/RAFT_utils.py +196 -0
- {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/common/workflows/envs/alfworld/alfworld_workflow.py +119 -3
- {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/common/workflows/envs/sciworld/sciworld_workflow.py +1 -1
- {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/common/workflows/envs/webshop/webshop_workflow.py +1 -1
- trinity_rft-0.2.1.dev0/trinity/common/workflows/eval_workflow.py +90 -0
- {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/common/workflows/math_rm_workflow.py +5 -3
- trinity_rft-0.2.1.dev0/trinity/common/workflows/step_wise_workflow.py +147 -0
- {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/common/workflows/workflow.py +52 -19
- {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/explorer/explorer.py +143 -123
- {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/explorer/scheduler.py +64 -41
- {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/explorer/workflow_runner.py +46 -20
- trinity_rft-0.2.1.dev0/trinity/manager/__init__.py +7 -0
- {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/manager/config_manager.py +42 -14
- {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/manager/config_registry/buffer_config_manager.py +28 -2
- {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/manager/config_registry/explorer_config_manager.py +24 -1
- {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/manager/config_registry/model_config_manager.py +5 -5
- {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/manager/config_registry/trainer_config_manager.py +28 -0
- trinity_rft-0.2.1.dev0/trinity/manager/synchronizer.py +325 -0
- trinity_rft-0.2.1.dev0/trinity/trainer/trainer.py +215 -0
- {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/trainer/verl/dp_actor.py +55 -96
- trinity_rft-0.2.1.dev0/trinity/trainer/verl/fsdp_checkpoint_manager.py +363 -0
- {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/trainer/verl/fsdp_workers.py +105 -54
- trinity_rft-0.2.1.dev0/trinity/trainer/verl/utils.py +152 -0
- {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/trainer/verl_trainer.py +60 -120
- {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/utils/dlc_utils.py +3 -2
- {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/utils/eval_utils.py +2 -130
- trinity_rft-0.2.1.dev0/trinity/utils/math_eval_utils.py +558 -0
- {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/utils/monitor.py +99 -4
- trinity_rft-0.2.1.dev0/trinity/utils/plugin_loader.py +87 -0
- {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/utils/registry.py +1 -0
- {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity_rft.egg-info/PKG-INFO +38 -19
- {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity_rft.egg-info/SOURCES.txt +20 -3
- {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity_rft.egg-info/requires.txt +6 -2
- trinity_rft-0.2.0/trinity/algorithm/sample_strategy/mix_sample_strategy.py +0 -119
- trinity_rft-0.2.0/trinity/algorithm/sample_strategy/sample_strategy.py +0 -142
- trinity_rft-0.2.0/trinity/algorithm/sample_strategy/utils.py +0 -79
- trinity_rft-0.2.0/trinity/buffer/__init__.py +0 -7
- trinity_rft-0.2.0/trinity/buffer/queue.py +0 -112
- trinity_rft-0.2.0/trinity/common/experience.py +0 -319
- trinity_rft-0.2.0/trinity/common/models/model.py +0 -138
- trinity_rft-0.2.0/trinity/common/models/openai_api.py +0 -79
- trinity_rft-0.2.0/trinity/common/models/vllm_model.py +0 -293
- trinity_rft-0.2.0/trinity/common/workflows/__init__.py +0 -21
- trinity_rft-0.2.0/trinity/manager/__init__.py +0 -7
- trinity_rft-0.2.0/trinity/trainer/trainer.py +0 -121
- trinity_rft-0.2.0/trinity/utils/plugin_loader.py +0 -68
- {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/LICENSE +0 -0
- {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/setup.cfg +0 -0
- {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/setup.py +0 -0
- {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/algorithm/advantage_fn/__init__.py +0 -0
- {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/algorithm/advantage_fn/advantage_fn.py +0 -0
- {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/algorithm/advantage_fn/ppo_advantage.py +0 -0
- {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/algorithm/advantage_fn/reinforce_plus_plus_advantage.py +0 -0
- {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/algorithm/advantage_fn/remax_advantage.py +0 -0
- {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/algorithm/algorithm_manager.py +0 -0
- {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/algorithm/entropy_loss_fn/__init__.py +0 -0
- {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/algorithm/key_mapper.py +0 -0
- {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/algorithm/kl_fn/__init__.py +0 -0
- {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/algorithm/kl_fn/kl_fn.py +0 -0
- {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/algorithm/policy_loss_fn/dpo_loss.py +0 -0
- {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/algorithm/policy_loss_fn/opmd_policy_loss.py +0 -0
- {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/algorithm/policy_loss_fn/policy_loss_fn.py +0 -0
- {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/algorithm/policy_loss_fn/ppo_policy_loss.py +0 -0
- {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/algorithm/policy_loss_fn/sft_loss.py +0 -0
- {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/algorithm/sample_strategy/__init__.py +0 -0
- {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/algorithm/utils.py +0 -0
- {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/buffer/buffer.py +0 -0
- {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/buffer/buffer_writer.py +0 -0
- {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/buffer/reader/__init__.py +0 -0
- {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/buffer/schema/__init__.py +0 -0
- {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/buffer/utils.py +0 -0
- {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/buffer/writer/__init__.py +0 -0
- {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/buffer/writer/file_writer.py +0 -0
- {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/buffer/writer/sql_writer.py +0 -0
- {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/cli/client.py +0 -0
- {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/cli/server.py +0 -0
- {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/common/__init__.py +0 -0
- {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/common/rewards/accuracy_reward.py +0 -0
- {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/common/rewards/agents_reward.py +0 -0
- {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/common/rewards/countdown_reward.py +0 -0
- {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/common/rewards/format_reward.py +0 -0
- {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/common/rewards/human_reward.py +0 -0
- {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/common/rewards/tool_reward.py +0 -0
- {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/common/rewards/utils.py +0 -0
- {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/data/controllers/active_iterator.py +0 -0
- {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/data/controllers/default_ops.py +0 -0
- {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/data/controllers/task_parser.py +0 -0
- {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/data/core/comparator.py +0 -0
- {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/data/core/dataset.py +0 -0
- {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/data/core/formatter.py +0 -0
- {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/data/processors/base.py +0 -0
- {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/data/processors/cleaner.py +0 -0
- {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/data/processors/human_annotator.py +0 -0
- {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/data/processors/synthesizer.py +0 -0
- {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/data/server.py +0 -0
- {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/data/utils.py +0 -0
- {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/explorer/__init__.py +0 -0
- {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/manager/config_registry/__init__.py +0 -0
- {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/manager/config_registry/algorithm_config_manager.py +0 -0
- {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/manager/config_registry/config_registry.py +0 -0
- {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/manager/manager.py +0 -0
- {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/plugins/__init__.py +0 -0
- {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/trainer/__init__.py +0 -0
- {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/trainer/verl/__init__.py +0 -0
- {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/utils/__init__.py +0 -0
- {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/utils/distributed.py +0 -0
- {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/utils/log.py +0 -0
- {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/utils/timer.py +0 -0
- {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity_rft.egg-info/dependency_links.txt +0 -0
- {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity_rft.egg-info/entry_points.txt +0 -0
- {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity_rft.egg-info/top_level.txt +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: trinity-rft
|
|
3
|
-
Version: 0.2.
|
|
3
|
+
Version: 0.2.1.dev0
|
|
4
4
|
Summary: Trinity-RFT: A Framework for Training Large Language Models with Reinforcement Fine-Tuning
|
|
5
5
|
Author-email: Trinity-RFT Team <trinity-rft@outlook.com>
|
|
6
6
|
Project-URL: Homepage, https://github.com/modelscope/Trinity-RFT
|
|
@@ -15,9 +15,9 @@ Classifier: Programming Language :: Python :: 3.12
|
|
|
15
15
|
Requires-Python: >=3.10
|
|
16
16
|
Description-Content-Type: text/markdown
|
|
17
17
|
License-File: LICENSE
|
|
18
|
-
Requires-Dist: verl==0.4.
|
|
18
|
+
Requires-Dist: verl==0.4.1
|
|
19
19
|
Requires-Dist: ray[default]>=2.45.0
|
|
20
|
-
Requires-Dist: vllm
|
|
20
|
+
Requires-Dist: vllm<=0.10.0,>=0.9.1
|
|
21
21
|
Requires-Dist: tensordict==0.6.2
|
|
22
22
|
Requires-Dist: wandb
|
|
23
23
|
Requires-Dist: omegaconf
|
|
@@ -35,6 +35,8 @@ Requires-Dist: tensorboard
|
|
|
35
35
|
Requires-Dist: openai
|
|
36
36
|
Requires-Dist: jsonlines
|
|
37
37
|
Requires-Dist: sortedcontainers
|
|
38
|
+
Requires-Dist: word2number
|
|
39
|
+
Requires-Dist: transformers<4.54.0
|
|
38
40
|
Provides-Extra: data
|
|
39
41
|
Requires-Dist: py-data-juicer; extra == "data"
|
|
40
42
|
Provides-Extra: agent
|
|
@@ -57,10 +59,13 @@ Requires-Dist: sphinx; extra == "doc"
|
|
|
57
59
|
Requires-Dist: sphinx-autobuild; extra == "doc"
|
|
58
60
|
Requires-Dist: sphinx_rtd_theme; extra == "doc"
|
|
59
61
|
Requires-Dist: myst-parser; extra == "doc"
|
|
62
|
+
Requires-Dist: sphinxcontrib-apidoc; extra == "doc"
|
|
63
|
+
Requires-Dist: sphinx-multiversion; extra == "doc"
|
|
60
64
|
Provides-Extra: flash-attn
|
|
61
65
|
Requires-Dist: flash-attn==2.8.0.post2; extra == "flash-attn"
|
|
62
66
|
Dynamic: license-file
|
|
63
67
|
|
|
68
|
+
[**中文主页**](https://github.com/modelscope/Trinity-RFT/blob/main/README_zh.md) | [**Tutorial**](https://modelscope.github.io/Trinity-RFT/) | [**FAQ**](./docs/sphinx_doc/source/tutorial/faq.md)
|
|
64
69
|
|
|
65
70
|
<div align="center">
|
|
66
71
|
<img src="https://img.alicdn.com/imgextra/i1/O1CN01lvLpfw25Pl4ohGZnU_!!6000000007519-2-tps-1628-490.png" alt="Trinity-RFT" style="height: 120px;">
|
|
@@ -84,6 +89,9 @@ Dynamic: license-file
|
|
|
84
89
|
|
|
85
90
|
## 🚀 News
|
|
86
91
|
|
|
92
|
+
* [2025-08] ✨ Trinity-RFT v0.2.1 is released with enhanced features for Agentic RL and Async RL.
|
|
93
|
+
* [2025-08] 🎵 We introduce [CHORD](https://github.com/modelscope/Trinity-RFT/tree/main/examples/mix_chord), a dynamic integration of SFT and RL for enhanced LLM fine-tuning ([paper](https://arxiv.org/pdf/2508.11408)).
|
|
94
|
+
* [2025-08] We now support training on general multi-step workflows! Please check out examples for [ALFWorld](./docs/sphinx_doc/source/tutorial/example_step_wise.md) and [ReAct](./docs/sphinx_doc/source/tutorial/example_react.md).
|
|
87
95
|
* [2025-07] Trinity-RFT v0.2.0 is released.
|
|
88
96
|
* [2025-07] We update the [technical report](https://arxiv.org/abs/2505.17826) (arXiv v2) with new features, examples, and experiments.
|
|
89
97
|
* [2025-06] Trinity-RFT v0.1.1 is released.
|
|
@@ -143,6 +151,15 @@ It is designed to support diverse application scenarios and serve as a unified p
|
|
|
143
151
|
<img src="https://img.alicdn.com/imgextra/i3/O1CN01E7NskS1FFoTI9jlaQ_!!6000000000458-2-tps-1458-682.png" alt="Trinity-RFT-modes">
|
|
144
152
|
</p>
|
|
145
153
|
|
|
154
|
+
</details>
|
|
155
|
+
|
|
156
|
+
|
|
157
|
+
<details>
|
|
158
|
+
<summary>Figure: Concatenated and general multi-step workflows</summary>
|
|
159
|
+
|
|
160
|
+
<p align="center">
|
|
161
|
+
<img src="https://img.alicdn.com/imgextra/i1/O1CN01z1i7kk1jlMEVa8ZHV_!!6000000004588-2-tps-1262-695.png" alt="Trinity-RFT-multi-step">
|
|
162
|
+
</p>
|
|
146
163
|
|
|
147
164
|
</details>
|
|
148
165
|
|
|
@@ -214,6 +231,12 @@ It is designed to support diverse application scenarios and serve as a unified p
|
|
|
214
231
|
### Step 1: installation
|
|
215
232
|
|
|
216
233
|
|
|
234
|
+
Requirements:
|
|
235
|
+
- Python version >= 3.10, <= 3.12
|
|
236
|
+
- CUDA version >= 12.4, <= 12.8
|
|
237
|
+
- At least 2 GPUs
|
|
238
|
+
|
|
239
|
+
|
|
217
240
|
Installation from source **(recommended)**:
|
|
218
241
|
|
|
219
242
|
```shell
|
|
@@ -243,13 +266,15 @@ pip install -e .[flash_attn]
|
|
|
243
266
|
# for zsh
|
|
244
267
|
pip install -e .\[flash_attn\]
|
|
245
268
|
# Try the following command if you encounter errors during flash-attn installation
|
|
246
|
-
# pip install flash-attn -v --no-build-isolation
|
|
269
|
+
# pip install flash-attn==2.8.0.post2 -v --no-build-isolation
|
|
247
270
|
```
|
|
248
271
|
|
|
249
272
|
Installation using pip:
|
|
250
273
|
|
|
251
274
|
```shell
|
|
252
275
|
pip install trinity-rft==0.2.0
|
|
276
|
+
# install flash-attn separately
|
|
277
|
+
pip install flash-attn==2.8.0.post2
|
|
253
278
|
```
|
|
254
279
|
|
|
255
280
|
Installation from docker:
|
|
@@ -268,13 +293,6 @@ docker build -f scripts/docker/Dockerfile -t trinity-rft:latest .
|
|
|
268
293
|
docker run -it --gpus all --shm-size="64g" --rm -v $PWD:/workspace -v <root_path_of_data_and_checkpoints>:/data trinity-rft:latest
|
|
269
294
|
```
|
|
270
295
|
|
|
271
|
-
|
|
272
|
-
**Requirements:**
|
|
273
|
-
Python version >= 3.10,
|
|
274
|
-
CUDA version >= 12.4,
|
|
275
|
-
and at least 2 GPUs.
|
|
276
|
-
|
|
277
|
-
|
|
278
296
|
### Step 2: prepare dataset and model
|
|
279
297
|
|
|
280
298
|
|
|
@@ -291,7 +309,7 @@ huggingface-cli download {model_name} --local-dir $MODEL_PATH/{model_name}
|
|
|
291
309
|
modelscope download {model_name} --local_dir $MODEL_PATH/{model_name}
|
|
292
310
|
```
|
|
293
311
|
|
|
294
|
-
For more details about model downloading, see [Huggingface](https://huggingface.co/docs/huggingface_hub/main/en/guides/cli) or
|
|
312
|
+
For more details about model downloading, see [Huggingface](https://huggingface.co/docs/huggingface_hub/main/en/guides/cli) or [ModelScope](https://modelscope.cn/docs/models/download).
|
|
295
313
|
|
|
296
314
|
|
|
297
315
|
|
|
@@ -318,13 +336,13 @@ Trinity-RFT provides a web interface for configuring your RFT process.
|
|
|
318
336
|
> This is an experimental feature, and we will continue to improve it.
|
|
319
337
|
|
|
320
338
|
|
|
321
|
-
To
|
|
339
|
+
To launch the web interface for minimal configurations, you can run
|
|
322
340
|
|
|
323
341
|
```bash
|
|
324
342
|
trinity studio --port 8080
|
|
325
343
|
```
|
|
326
344
|
|
|
327
|
-
Then you can configure your RFT process in the web page and generate a config file. You can save the config for later use or run it directly as described in the following section.
|
|
345
|
+
Then you can configure your RFT process in the web page and generate a config file. You can save the config file for later use or run it directly as described in the following section.
|
|
328
346
|
|
|
329
347
|
Advanced users can also edit the config file directly.
|
|
330
348
|
We provide example config files in [`examples`](examples/).
|
|
@@ -392,7 +410,12 @@ Tutorials for running different RFT modes:
|
|
|
392
410
|
|
|
393
411
|
Tutorials for adapting Trinity-RFT to a new multi-turn agentic scenario:
|
|
394
412
|
|
|
395
|
-
+ [Multi-turn tasks](./docs/sphinx_doc/source/tutorial/example_multi_turn.md)
|
|
413
|
+
+ [Concatenated Multi-turn tasks](./docs/sphinx_doc/source/tutorial/example_multi_turn.md)
|
|
414
|
+
|
|
415
|
+
Tutorials for adapting Trinity-RFT to a general multi-step agentic scenario:
|
|
416
|
+
|
|
417
|
+
+ [General Multi-Step tasks](./docs/sphinx_doc/source/tutorial/example_step_wise.md)
|
|
418
|
+
+ [ReAct agent tasks](./docs/sphinx_doc/source/tutorial/example_react.md)
|
|
396
419
|
|
|
397
420
|
|
|
398
421
|
Tutorials for data-related functionalities:
|
|
@@ -416,10 +439,6 @@ Guidelines for developers and researchers:
|
|
|
416
439
|
|
|
417
440
|
|
|
418
441
|
|
|
419
|
-
For some frequently asked questions, see [FAQ](./docs/sphinx_doc/source/tutorial/faq.md).
|
|
420
|
-
|
|
421
|
-
|
|
422
|
-
|
|
423
442
|
|
|
424
443
|
## Upcoming features
|
|
425
444
|
|
|
@@ -1,3 +1,4 @@
|
|
|
1
|
+
[**中文主页**](https://github.com/modelscope/Trinity-RFT/blob/main/README_zh.md) | [**Tutorial**](https://modelscope.github.io/Trinity-RFT/) | [**FAQ**](./docs/sphinx_doc/source/tutorial/faq.md)
|
|
1
2
|
|
|
2
3
|
<div align="center">
|
|
3
4
|
<img src="https://img.alicdn.com/imgextra/i1/O1CN01lvLpfw25Pl4ohGZnU_!!6000000007519-2-tps-1628-490.png" alt="Trinity-RFT" style="height: 120px;">
|
|
@@ -21,6 +22,9 @@
|
|
|
21
22
|
|
|
22
23
|
## 🚀 News
|
|
23
24
|
|
|
25
|
+
* [2025-08] ✨ Trinity-RFT v0.2.1 is released with enhanced features for Agentic RL and Async RL.
|
|
26
|
+
* [2025-08] 🎵 We introduce [CHORD](https://github.com/modelscope/Trinity-RFT/tree/main/examples/mix_chord), a dynamic integration of SFT and RL for enhanced LLM fine-tuning ([paper](https://arxiv.org/pdf/2508.11408)).
|
|
27
|
+
* [2025-08] We now support training on general multi-step workflows! Please check out examples for [ALFWorld](./docs/sphinx_doc/source/tutorial/example_step_wise.md) and [ReAct](./docs/sphinx_doc/source/tutorial/example_react.md).
|
|
24
28
|
* [2025-07] Trinity-RFT v0.2.0 is released.
|
|
25
29
|
* [2025-07] We update the [technical report](https://arxiv.org/abs/2505.17826) (arXiv v2) with new features, examples, and experiments.
|
|
26
30
|
* [2025-06] Trinity-RFT v0.1.1 is released.
|
|
@@ -80,6 +84,15 @@ It is designed to support diverse application scenarios and serve as a unified p
|
|
|
80
84
|
<img src="https://img.alicdn.com/imgextra/i3/O1CN01E7NskS1FFoTI9jlaQ_!!6000000000458-2-tps-1458-682.png" alt="Trinity-RFT-modes">
|
|
81
85
|
</p>
|
|
82
86
|
|
|
87
|
+
</details>
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
<details>
|
|
91
|
+
<summary>Figure: Concatenated and general multi-step workflows</summary>
|
|
92
|
+
|
|
93
|
+
<p align="center">
|
|
94
|
+
<img src="https://img.alicdn.com/imgextra/i1/O1CN01z1i7kk1jlMEVa8ZHV_!!6000000004588-2-tps-1262-695.png" alt="Trinity-RFT-multi-step">
|
|
95
|
+
</p>
|
|
83
96
|
|
|
84
97
|
</details>
|
|
85
98
|
|
|
@@ -151,6 +164,12 @@ It is designed to support diverse application scenarios and serve as a unified p
|
|
|
151
164
|
### Step 1: installation
|
|
152
165
|
|
|
153
166
|
|
|
167
|
+
Requirements:
|
|
168
|
+
- Python version >= 3.10, <= 3.12
|
|
169
|
+
- CUDA version >= 12.4, <= 12.8
|
|
170
|
+
- At least 2 GPUs
|
|
171
|
+
|
|
172
|
+
|
|
154
173
|
Installation from source **(recommended)**:
|
|
155
174
|
|
|
156
175
|
```shell
|
|
@@ -180,13 +199,15 @@ pip install -e .[flash_attn]
|
|
|
180
199
|
# for zsh
|
|
181
200
|
pip install -e .\[flash_attn\]
|
|
182
201
|
# Try the following command if you encounter errors during flash-attn installation
|
|
183
|
-
# pip install flash-attn -v --no-build-isolation
|
|
202
|
+
# pip install flash-attn==2.8.0.post2 -v --no-build-isolation
|
|
184
203
|
```
|
|
185
204
|
|
|
186
205
|
Installation using pip:
|
|
187
206
|
|
|
188
207
|
```shell
|
|
189
208
|
pip install trinity-rft==0.2.0
|
|
209
|
+
# install flash-attn separately
|
|
210
|
+
pip install flash-attn==2.8.0.post2
|
|
190
211
|
```
|
|
191
212
|
|
|
192
213
|
Installation from docker:
|
|
@@ -205,13 +226,6 @@ docker build -f scripts/docker/Dockerfile -t trinity-rft:latest .
|
|
|
205
226
|
docker run -it --gpus all --shm-size="64g" --rm -v $PWD:/workspace -v <root_path_of_data_and_checkpoints>:/data trinity-rft:latest
|
|
206
227
|
```
|
|
207
228
|
|
|
208
|
-
|
|
209
|
-
**Requirements:**
|
|
210
|
-
Python version >= 3.10,
|
|
211
|
-
CUDA version >= 12.4,
|
|
212
|
-
and at least 2 GPUs.
|
|
213
|
-
|
|
214
|
-
|
|
215
229
|
### Step 2: prepare dataset and model
|
|
216
230
|
|
|
217
231
|
|
|
@@ -228,7 +242,7 @@ huggingface-cli download {model_name} --local-dir $MODEL_PATH/{model_name}
|
|
|
228
242
|
modelscope download {model_name} --local_dir $MODEL_PATH/{model_name}
|
|
229
243
|
```
|
|
230
244
|
|
|
231
|
-
For more details about model downloading, see [Huggingface](https://huggingface.co/docs/huggingface_hub/main/en/guides/cli) or
|
|
245
|
+
For more details about model downloading, see [Huggingface](https://huggingface.co/docs/huggingface_hub/main/en/guides/cli) or [ModelScope](https://modelscope.cn/docs/models/download).
|
|
232
246
|
|
|
233
247
|
|
|
234
248
|
|
|
@@ -255,13 +269,13 @@ Trinity-RFT provides a web interface for configuring your RFT process.
|
|
|
255
269
|
> This is an experimental feature, and we will continue to improve it.
|
|
256
270
|
|
|
257
271
|
|
|
258
|
-
To
|
|
272
|
+
To launch the web interface for minimal configurations, you can run
|
|
259
273
|
|
|
260
274
|
```bash
|
|
261
275
|
trinity studio --port 8080
|
|
262
276
|
```
|
|
263
277
|
|
|
264
|
-
Then you can configure your RFT process in the web page and generate a config file. You can save the config for later use or run it directly as described in the following section.
|
|
278
|
+
Then you can configure your RFT process in the web page and generate a config file. You can save the config file for later use or run it directly as described in the following section.
|
|
265
279
|
|
|
266
280
|
Advanced users can also edit the config file directly.
|
|
267
281
|
We provide example config files in [`examples`](examples/).
|
|
@@ -329,7 +343,12 @@ Tutorials for running different RFT modes:
|
|
|
329
343
|
|
|
330
344
|
Tutorials for adapting Trinity-RFT to a new multi-turn agentic scenario:
|
|
331
345
|
|
|
332
|
-
+ [Multi-turn tasks](./docs/sphinx_doc/source/tutorial/example_multi_turn.md)
|
|
346
|
+
+ [Concatenated Multi-turn tasks](./docs/sphinx_doc/source/tutorial/example_multi_turn.md)
|
|
347
|
+
|
|
348
|
+
Tutorials for adapting Trinity-RFT to a general multi-step agentic scenario:
|
|
349
|
+
|
|
350
|
+
+ [General Multi-Step tasks](./docs/sphinx_doc/source/tutorial/example_step_wise.md)
|
|
351
|
+
+ [ReAct agent tasks](./docs/sphinx_doc/source/tutorial/example_react.md)
|
|
333
352
|
|
|
334
353
|
|
|
335
354
|
Tutorials for data-related functionalities:
|
|
@@ -353,10 +372,6 @@ Guidelines for developers and researchers:
|
|
|
353
372
|
|
|
354
373
|
|
|
355
374
|
|
|
356
|
-
For some frequently asked questions, see [FAQ](./docs/sphinx_doc/source/tutorial/faq.md).
|
|
357
|
-
|
|
358
|
-
|
|
359
|
-
|
|
360
375
|
|
|
361
376
|
## Upcoming features
|
|
362
377
|
|
|
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
|
|
4
4
|
|
|
5
5
|
[project]
|
|
6
6
|
name = "trinity-rft"
|
|
7
|
-
version = "0.2.
|
|
7
|
+
version = "0.2.1.dev0"
|
|
8
8
|
authors = [
|
|
9
9
|
{name="Trinity-RFT Team", email="trinity-rft@outlook.com"},
|
|
10
10
|
]
|
|
@@ -21,9 +21,9 @@ classifiers = [
|
|
|
21
21
|
]
|
|
22
22
|
requires-python = ">=3.10"
|
|
23
23
|
dependencies = [
|
|
24
|
-
"verl==0.4.
|
|
24
|
+
"verl==0.4.1",
|
|
25
25
|
"ray[default]>=2.45.0",
|
|
26
|
-
"vllm
|
|
26
|
+
"vllm>=0.9.1,<=0.10.0",
|
|
27
27
|
"tensordict==0.6.2",
|
|
28
28
|
"wandb",
|
|
29
29
|
"omegaconf",
|
|
@@ -41,6 +41,8 @@ dependencies = [
|
|
|
41
41
|
"openai",
|
|
42
42
|
"jsonlines",
|
|
43
43
|
"sortedcontainers",
|
|
44
|
+
"word2number",
|
|
45
|
+
"transformers<4.54.0", # TODO: remove when https://github.com/vllm-project/vllm-ascend/issues/2046 is fixed
|
|
44
46
|
]
|
|
45
47
|
|
|
46
48
|
[project.scripts]
|
|
@@ -66,7 +68,7 @@ dev = [
|
|
|
66
68
|
"pytest>=8.0.0",
|
|
67
69
|
"pytest-json-ctrf",
|
|
68
70
|
"parameterized",
|
|
69
|
-
"matplotlib"
|
|
71
|
+
"matplotlib",
|
|
70
72
|
]
|
|
71
73
|
|
|
72
74
|
doc = [
|
|
@@ -74,6 +76,8 @@ doc = [
|
|
|
74
76
|
"sphinx-autobuild",
|
|
75
77
|
"sphinx_rtd_theme",
|
|
76
78
|
"myst-parser",
|
|
79
|
+
"sphinxcontrib-apidoc",
|
|
80
|
+
"sphinx-multiversion",
|
|
77
81
|
]
|
|
78
82
|
|
|
79
83
|
flash_attn = [
|
|
@@ -1,3 +1,4 @@
|
|
|
1
|
+
from trinity.algorithm.add_strategy import ADD_STRATEGY, AddStrategy
|
|
1
2
|
from trinity.algorithm.advantage_fn import ADVANTAGE_FN, AdvantageFn
|
|
2
3
|
from trinity.algorithm.algorithm import ALGORITHM_TYPE, AlgorithmType
|
|
3
4
|
from trinity.algorithm.entropy_loss_fn import ENTROPY_LOSS_FN, EntropyLossFn
|
|
@@ -18,4 +19,6 @@ __all__ = [
|
|
|
18
19
|
"ENTROPY_LOSS_FN",
|
|
19
20
|
"SampleStrategy",
|
|
20
21
|
"SAMPLE_STRATEGY",
|
|
22
|
+
"AddStrategy",
|
|
23
|
+
"ADD_STRATEGY",
|
|
21
24
|
]
|
|
@@ -0,0 +1,25 @@
|
|
|
1
|
+
from trinity.algorithm.add_strategy.add_strategy import (
|
|
2
|
+
ADD_STRATEGY,
|
|
3
|
+
AddStrategy,
|
|
4
|
+
GRPOAddStrategy,
|
|
5
|
+
OPMDAddStrategy,
|
|
6
|
+
RewardVarianceAddStrategy,
|
|
7
|
+
)
|
|
8
|
+
from trinity.algorithm.add_strategy.correct_bias_add_strategy import (
|
|
9
|
+
CorrectBiasAddStrategy,
|
|
10
|
+
)
|
|
11
|
+
from trinity.algorithm.add_strategy.duplicate_add_strategy import (
|
|
12
|
+
DuplicateInformativeAddStrategy,
|
|
13
|
+
)
|
|
14
|
+
from trinity.algorithm.add_strategy.step_wise_add_strategy import StepWiseGRPOStrategy
|
|
15
|
+
|
|
16
|
+
__all__ = [
|
|
17
|
+
"ADD_STRATEGY",
|
|
18
|
+
"AddStrategy",
|
|
19
|
+
"GRPOAddStrategy",
|
|
20
|
+
"OPMDAddStrategy",
|
|
21
|
+
"StepWiseGRPOStrategy",
|
|
22
|
+
"RewardVarianceAddStrategy",
|
|
23
|
+
"CorrectBiasAddStrategy",
|
|
24
|
+
"DuplicateInformativeAddStrategy",
|
|
25
|
+
]
|
|
@@ -0,0 +1,230 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
from abc import ABC, abstractmethod
|
|
3
|
+
from typing import Dict, List, Literal, Tuple
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
import torch
|
|
7
|
+
|
|
8
|
+
from trinity.buffer import BufferWriter
|
|
9
|
+
from trinity.common.experience import Experience
|
|
10
|
+
from trinity.utils.monitor import gather_metrics
|
|
11
|
+
from trinity.utils.registry import Registry
|
|
12
|
+
from trinity.utils.timer import Timer
|
|
13
|
+
|
|
14
|
+
# Registry of add strategies. Implementations register themselves under a short
# name via the ``@ADD_STRATEGY.register_module(<name>)`` decorator (e.g. "grpo",
# "opmd", "reward_variance").
ADD_STRATEGY = Registry("add_strategy")
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class AddStrategy(ABC):
    """Abstract base for strategies that post-process experiences and write them to a buffer.

    A concrete strategy receives the ``BufferWriter`` at construction time and
    must implement both :meth:`add` and :meth:`default_args`.
    """

    def __init__(self, writer: BufferWriter, **kwargs) -> None:
        # Any extra keyword arguments are accepted and ignored by the base class.
        self.writer = writer

    @abstractmethod
    async def add(self, experiences: List[Experience], step: int) -> Tuple[int, Dict]:
        """Process a batch of experiences and write the selected ones to the buffer.

        Args:
            experiences (`List[Experience]`): The experiences to be added.
            step (`int`): The current step number.

        Returns:
            `Tuple[int, Dict]`: The number of experiences written to the
            buffer, followed by a dictionary of metrics for logging.
        """

    @classmethod
    @abstractmethod
    def default_args(cls) -> dict:
        """Return the default constructor arguments of this add strategy.

        Returns:
            `dict`: Mapping from argument name to its default value.
        """
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
class GroupAdvantageStrategy(AddStrategy):
    """An example AddStrategy that calculates group advantages.

    Subclasses define how experiences are partitioned
    (:meth:`group_experiences`) and how each partition is scored
    (:meth:`calculate_group_advantage`). :meth:`add` drives both steps, writes
    every non-empty scored group to the buffer concurrently, and aggregates the
    per-group metrics under the ``"group_advantages"`` prefix.
    """

    @abstractmethod
    def group_experiences(self, exps: List[Experience]) -> Dict[str, List[Experience]]:
        """Partition experiences into groups.

        Args:
            exps (List[Experience]): Experiences to be grouped.

        Returns:
            Dict[str, List[Experience]]: Mapping from group identifier to the
            experiences belonging to that group.
        """

    @abstractmethod
    def calculate_group_advantage(
        self, group_id: str, exps: List[Experience]
    ) -> Tuple[List[Experience], Dict]:
        """Compute advantages for one group of experiences.

        Args:
            group_id (str): Identifier of the group.
            exps (List[Experience]): Experiences in the group.

        Returns:
            Tuple[List[Experience], Dict]: The (possibly modified) experiences
            to write, and a dictionary of metrics for this group.
        """

    async def add(self, exps: List[Experience], step: int) -> Tuple[int, Dict]:
        if len(exps) == 0:
            return 0, {}

        written = 0
        per_group_metrics = []
        pending_writes = []
        for gid, members in self.group_experiences(exps).items():
            scored, group_metrics = self.calculate_group_advantage(gid, members)
            per_group_metrics.append(group_metrics)
            written += len(scored)
            # Only schedule a write for groups that still contain experiences.
            if len(scored) > 0:
                pending_writes.append(self.writer.write_async(scored))

        if pending_writes:
            await asyncio.gather(*pending_writes)

        try:
            metrics = gather_metrics(per_group_metrics, "group_advantages")
        except ValueError:
            # gather_metrics raises ValueError on an empty metric list; report no metrics.
            metrics = {}
        return written, metrics
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
@ADD_STRATEGY.register_module("grpo")
class GRPOAddStrategy(GroupAdvantageStrategy):
    """An example AddStrategy that calculates GRPO advantages.

    Experiences are grouped by task. Within a group every reward is normalised
    as ``(reward - mean) / (std + epsilon)`` and broadcast over the action mask.
    A single-member group uses mean 0 and std 1 instead.
    """

    def __init__(self, writer: BufferWriter, epsilon: float = 1e-6, **kwargs) -> None:
        super().__init__(writer)
        # Added to the std so an all-equal group does not divide by zero.
        self.epsilon = epsilon

    def group_experiences(self, exps):
        return group_by(exps, id_type="task")

    def calculate_group_advantage(
        self, group_id: str, exps: List[Experience]
    ) -> Tuple[List[Experience], Dict]:
        with torch.no_grad():
            if len(exps) != 1:
                rewards = torch.tensor([e.reward for e in exps], dtype=torch.float32)
                mean_r = torch.mean(rewards)
                std_r = torch.std(rewards)
            else:
                # A lone sample has no peers to compare against.
                mean_r = torch.tensor(0.0)
                std_r = torch.tensor(1.0)

            denominator = std_r + self.epsilon
            for exp in exps:
                exp.advantages = ((exp.reward - mean_r) / denominator) * exp.action_mask
                exp.returns = exp.advantages.clone()

            metrics = {"reward_mean": mean_r.item(), "reward_std": std_r.item()}

        return exps, metrics

    @classmethod
    def default_args(cls) -> dict:
        return {"epsilon": 1e-6}
|
|
131
|
+
|
|
132
|
+
|
|
133
|
+
@ADD_STRATEGY.register_module("opmd")
class OPMDAddStrategy(GroupAdvantageStrategy):
    """An example AddStrategy that calculates OPMD advantages.

    Experiences are grouped by task. Each experience's advantage is its reward
    minus a per-group baseline, broadcast over its action mask.

    Args:
        writer: Buffer writer the scored experiences are sent to.
        opmd_baseline: ``"mean"`` uses the group's mean reward as baseline;
            ``"logavgexp"`` uses ``tau * log(mean(exp(reward / tau)))``.
        tau: Temperature used by the ``"logavgexp"`` baseline.
    """

    def __init__(
        self, writer: BufferWriter, opmd_baseline: str = "mean", tau: float = 1.0, **kwargs
    ) -> None:
        super().__init__(writer)
        assert opmd_baseline in [
            "mean",
            "logavgexp",
        ], f"opmd_baseline must be 'mean' or 'logavgexp', got {opmd_baseline}"
        self.opmd_baseline = opmd_baseline
        self.tau = tau

    def group_experiences(self, exps):
        return group_by(exps, id_type="task")

    def calculate_group_advantage(
        self, group_id: str, exps: List[Experience]
    ) -> Tuple[List[Experience], Dict]:
        """Subtract the group baseline from every reward in the group.

        Returns:
            The same experiences (with ``advantages``/``returns`` set) and a
            metrics dict ``{"group_baseline": float}``.
        """
        with torch.no_grad():
            if len(exps) == 1:
                # No peers to compare against: the raw reward is the advantage.
                group_baseline = torch.tensor(0.0)
            else:
                group_rewards = torch.tensor([exp.reward for exp in exps], dtype=torch.float32)
                if self.opmd_baseline == "mean":
                    group_baseline = torch.mean(group_rewards)
                else:
                    # logsumexp(x) - log(n) == log(mean(exp(x))), computed without overflow.
                    group_baseline = self.tau * (
                        torch.logsumexp(group_rewards / self.tau, dim=-1)
                        - torch.log(torch.tensor(float(len(exps))))
                    )
            for exp in exps:
                score = exp.reward - group_baseline
                exp.advantages = score * exp.action_mask
                exp.returns = exp.advantages.clone()
            # Report a plain Python float, as the GRPO strategy does for its
            # metrics, instead of handing a 0-d tensor to gather_metrics.
            metrics = {
                "group_baseline": group_baseline.item(),
            }
        return exps, metrics

    @classmethod
    def default_args(cls) -> dict:
        return {"opmd_baseline": "mean", "tau": 1.0}
|
|
178
|
+
|
|
179
|
+
|
|
180
|
+
@ADD_STRATEGY.register_module("reward_variance")
class RewardVarianceAddStrategy(AddStrategy):
    """An example AddStrategy that filters experiences based on a reward variance threshold.

    Experiences are grouped by task. A group is written to the buffer only if
    it has at least two members and the variance of its rewards is strictly
    greater than ``variance_threshold``.
    """

    def __init__(self, writer: BufferWriter, variance_threshold: float = 0.0, **kwargs) -> None:
        super().__init__(writer)
        self.variance_threshold = variance_threshold

    async def add(self, experiences: List[Experience], step: int) -> Tuple[int, Dict]:
        metrics = {}
        kept = 0
        writes = []
        with Timer(metrics, "add_strategy_time"):
            for members in group_by(experiences, id_type="task").values():
                # A single sample carries no variance information.
                if len(members) < 2:
                    continue
                if np.var([e.reward for e in members]) <= self.variance_threshold:
                    continue
                kept += len(members)
                writes.append(self.writer.write_async(members))
            if writes:
                await asyncio.gather(*writes)
        return kept, metrics

    @classmethod
    def default_args(cls) -> dict:
        return {"variance_threshold": 0.0}
|
|
210
|
+
|
|
211
|
+
|
|
212
|
+
def group_by(
    experiences: List[Experience], id_type: Literal["task", "run", "step"]
) -> Dict[str, List[Experience]]:
    """Group experiences by ID.

    Args:
        experiences: Experiences to partition.
        id_type: Which ``eid`` component to group on. ``"task"`` uses
            ``eid.tid``, ``"run"`` uses ``eid.rid`` and ``"step"`` uses
            ``eid.sid``.

    Returns:
        A dict mapping each distinct id value to its experiences. Keys appear
        in first-seen order, and each list preserves the input order.

    Raises:
        ValueError: If ``id_type`` is not one of the values above.
    """
    for alias, attr in (("task", "tid"), ("run", "rid"), ("step", "sid")):
        if id_type == alias:
            eid_field = attr
            break
    else:
        raise ValueError(f"Unknown id_type: {id_type}")

    buckets = {}
    for exp in experiences:
        buckets.setdefault(getattr(exp.eid, eid_field), []).append(exp)
    return buckets
|
|
@@ -0,0 +1,54 @@
|
|
|
1
|
+
# -*- coding: utf-8 -*-
|
|
2
|
+
from typing import Dict, List, Tuple
|
|
3
|
+
|
|
4
|
+
import torch
|
|
5
|
+
|
|
6
|
+
from trinity.algorithm.add_strategy.add_strategy import ADD_STRATEGY, GRPOAddStrategy
|
|
7
|
+
from trinity.buffer import BufferWriter
|
|
8
|
+
from trinity.common.experience import Experience
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
@ADD_STRATEGY.register_module("correct_bias")
class CorrectBiasAddStrategy(GRPOAddStrategy):
    """GRPO-style group advantage that first corrects rewards for rank bias (https://arxiv.org/pdf/2506.02355).

    Within a group of two or more samples, each sample is ranked by its mean
    token log-prob: rank 0 is the least likely sample and rank n-1 the most
    likely. Its reward is scaled by ``1 - (rank / n) * rank_penalty``, so more
    likely samples keep less of their reward. The scaled rewards are then
    normalised as in GRPO.
    """

    def __init__(
        self, writer: BufferWriter, epsilon: float = 1e-6, rank_penalty: float = 0.25, **kwargs
    ) -> None:
        super().__init__(writer, epsilon, **kwargs)
        self.rank_penalty = rank_penalty

    def calculate_group_advantage(
        self, group_id: str, exps: List[Experience]
    ) -> Tuple[List[Experience], Dict]:
        with torch.no_grad():
            rewards = torch.tensor([exp.reward for exp in exps], dtype=torch.float32)

            if len(exps) != 1:
                # Double argsort turns scores into ascending ranks 0..n-1.
                mean_logprobs = torch.tensor([exp.logprobs.mean(dim=-1) for exp in exps])
                relative_rank = torch.argsort(torch.argsort(mean_logprobs)) / len(exps)
                rewards = rewards * (1 - relative_rank * self.rank_penalty)
                centre = torch.mean(rewards)
                spread = torch.std(rewards)
            else:
                # A lone sample: identity normalisation, no rank correction.
                centre = torch.tensor(0.0)
                spread = torch.tensor(1.0)

            for exp, corrected in zip(exps, rewards):
                exp.advantages = ((corrected - centre) / (spread + self.epsilon)) * exp.action_mask
                exp.returns = exp.advantages.clone()

            metrics = {"reward_mean": centre.item(), "reward_std": spread.item()}

        return exps, metrics

    @classmethod
    def default_args(cls) -> dict:
        return {"epsilon": 1e-6, "rank_penalty": 0.25}
|