trinity-rft 0.2.0__tar.gz → 0.2.1.dev0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (163)
  1. {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/PKG-INFO +38 -19
  2. {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/README.md +31 -16
  3. {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/pyproject.toml +8 -4
  4. {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/__init__.py +1 -1
  5. {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/algorithm/__init__.py +3 -0
  6. trinity_rft-0.2.1.dev0/trinity/algorithm/add_strategy/__init__.py +25 -0
  7. trinity_rft-0.2.1.dev0/trinity/algorithm/add_strategy/add_strategy.py +230 -0
  8. trinity_rft-0.2.1.dev0/trinity/algorithm/add_strategy/correct_bias_add_strategy.py +54 -0
  9. trinity_rft-0.2.1.dev0/trinity/algorithm/add_strategy/duplicate_add_strategy.py +72 -0
  10. trinity_rft-0.2.1.dev0/trinity/algorithm/add_strategy/step_wise_add_strategy.py +123 -0
  11. {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/algorithm/advantage_fn/grpo_advantage.py +2 -2
  12. {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/algorithm/advantage_fn/opmd_advantage.py +4 -2
  13. {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/algorithm/advantage_fn/rloo_advantage.py +1 -1
  14. {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/algorithm/algorithm.py +59 -11
  15. {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/algorithm/entropy_loss_fn/entropy_loss_fn.py +27 -0
  16. {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/algorithm/policy_loss_fn/__init__.py +10 -0
  17. trinity_rft-0.2.1.dev0/trinity/algorithm/policy_loss_fn/chord_policy_loss.py +257 -0
  18. trinity_rft-0.2.1.dev0/trinity/algorithm/policy_loss_fn/gspo_policy_loss.py +76 -0
  19. {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/algorithm/policy_loss_fn/mix_policy_loss.py +23 -26
  20. trinity_rft-0.2.1.dev0/trinity/algorithm/sample_strategy/mix_sample_strategy.py +98 -0
  21. trinity_rft-0.2.1.dev0/trinity/algorithm/sample_strategy/sample_strategy.py +92 -0
  22. trinity_rft-0.2.1.dev0/trinity/algorithm/sample_strategy/utils.py +42 -0
  23. trinity_rft-0.2.1.dev0/trinity/buffer/__init__.py +15 -0
  24. {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/buffer/buffer_reader.py +6 -0
  25. trinity_rft-0.2.0/trinity/buffer/priority_queue.py → trinity_rft-0.2.1.dev0/trinity/buffer/queue.py +96 -27
  26. {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/buffer/ray_wrapper.py +111 -4
  27. {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/buffer/reader/file_reader.py +69 -33
  28. {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/buffer/reader/queue_reader.py +16 -3
  29. {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/buffer/reader/sql_reader.py +8 -0
  30. {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/buffer/schema/sql_schema.py +3 -3
  31. {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/buffer/writer/queue_writer.py +2 -2
  32. {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/cli/launcher.py +16 -10
  33. {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/common/config.py +187 -73
  34. {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/common/constants.py +15 -0
  35. trinity_rft-0.2.1.dev0/trinity/common/experience.py +547 -0
  36. {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/common/models/__init__.py +9 -14
  37. trinity_rft-0.2.1.dev0/trinity/common/models/api/vllm_patch.py +374 -0
  38. trinity_rft-0.2.1.dev0/trinity/common/models/model.py +214 -0
  39. {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/common/models/utils.py +13 -8
  40. trinity_rft-0.2.0/trinity/common/models/vllm_async_model.py → trinity_rft-0.2.1.dev0/trinity/common/models/vllm_model.py +50 -52
  41. {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/common/models/vllm_worker.py +9 -14
  42. {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/common/rewards/__init__.py +2 -0
  43. trinity_rft-0.2.1.dev0/trinity/common/rewards/dapo_reward.py +67 -0
  44. {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/common/rewards/math_reward.py +2 -1
  45. {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/common/rewards/reward_fn.py +1 -1
  46. {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/common/verl_config.py +33 -15
  47. trinity_rft-0.2.1.dev0/trinity/common/workflows/__init__.py +32 -0
  48. {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/common/workflows/customized_math_workflows.py +17 -13
  49. trinity_rft-0.2.1.dev0/trinity/common/workflows/customized_toolcall_workflows.py +265 -0
  50. trinity_rft-0.2.1.dev0/trinity/common/workflows/envs/agentscope/agentscope_react_workflow.py +162 -0
  51. trinity_rft-0.2.1.dev0/trinity/common/workflows/envs/alfworld/RAFT_alfworld_workflow.py +225 -0
  52. trinity_rft-0.2.1.dev0/trinity/common/workflows/envs/alfworld/RAFT_reflect_alfworld_workflow.py +310 -0
  53. trinity_rft-0.2.1.dev0/trinity/common/workflows/envs/alfworld/RAFT_utils.py +196 -0
  54. {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/common/workflows/envs/alfworld/alfworld_workflow.py +119 -3
  55. {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/common/workflows/envs/sciworld/sciworld_workflow.py +1 -1
  56. {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/common/workflows/envs/webshop/webshop_workflow.py +1 -1
  57. trinity_rft-0.2.1.dev0/trinity/common/workflows/eval_workflow.py +90 -0
  58. {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/common/workflows/math_rm_workflow.py +5 -3
  59. trinity_rft-0.2.1.dev0/trinity/common/workflows/step_wise_workflow.py +147 -0
  60. {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/common/workflows/workflow.py +52 -19
  61. {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/explorer/explorer.py +143 -123
  62. {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/explorer/scheduler.py +64 -41
  63. {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/explorer/workflow_runner.py +46 -20
  64. trinity_rft-0.2.1.dev0/trinity/manager/__init__.py +7 -0
  65. {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/manager/config_manager.py +42 -14
  66. {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/manager/config_registry/buffer_config_manager.py +28 -2
  67. {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/manager/config_registry/explorer_config_manager.py +24 -1
  68. {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/manager/config_registry/model_config_manager.py +5 -5
  69. {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/manager/config_registry/trainer_config_manager.py +28 -0
  70. trinity_rft-0.2.1.dev0/trinity/manager/synchronizer.py +325 -0
  71. trinity_rft-0.2.1.dev0/trinity/trainer/trainer.py +215 -0
  72. {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/trainer/verl/dp_actor.py +55 -96
  73. trinity_rft-0.2.1.dev0/trinity/trainer/verl/fsdp_checkpoint_manager.py +363 -0
  74. {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/trainer/verl/fsdp_workers.py +105 -54
  75. trinity_rft-0.2.1.dev0/trinity/trainer/verl/utils.py +152 -0
  76. {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/trainer/verl_trainer.py +60 -120
  77. {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/utils/dlc_utils.py +3 -2
  78. {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/utils/eval_utils.py +2 -130
  79. trinity_rft-0.2.1.dev0/trinity/utils/math_eval_utils.py +558 -0
  80. {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/utils/monitor.py +99 -4
  81. trinity_rft-0.2.1.dev0/trinity/utils/plugin_loader.py +87 -0
  82. {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/utils/registry.py +1 -0
  83. {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity_rft.egg-info/PKG-INFO +38 -19
  84. {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity_rft.egg-info/SOURCES.txt +20 -3
  85. {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity_rft.egg-info/requires.txt +6 -2
  86. trinity_rft-0.2.0/trinity/algorithm/sample_strategy/mix_sample_strategy.py +0 -119
  87. trinity_rft-0.2.0/trinity/algorithm/sample_strategy/sample_strategy.py +0 -142
  88. trinity_rft-0.2.0/trinity/algorithm/sample_strategy/utils.py +0 -79
  89. trinity_rft-0.2.0/trinity/buffer/__init__.py +0 -7
  90. trinity_rft-0.2.0/trinity/buffer/queue.py +0 -112
  91. trinity_rft-0.2.0/trinity/common/experience.py +0 -319
  92. trinity_rft-0.2.0/trinity/common/models/model.py +0 -138
  93. trinity_rft-0.2.0/trinity/common/models/openai_api.py +0 -79
  94. trinity_rft-0.2.0/trinity/common/models/vllm_model.py +0 -293
  95. trinity_rft-0.2.0/trinity/common/workflows/__init__.py +0 -21
  96. trinity_rft-0.2.0/trinity/manager/__init__.py +0 -7
  97. trinity_rft-0.2.0/trinity/trainer/trainer.py +0 -121
  98. trinity_rft-0.2.0/trinity/utils/plugin_loader.py +0 -68
  99. {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/LICENSE +0 -0
  100. {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/setup.cfg +0 -0
  101. {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/setup.py +0 -0
  102. {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/algorithm/advantage_fn/__init__.py +0 -0
  103. {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/algorithm/advantage_fn/advantage_fn.py +0 -0
  104. {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/algorithm/advantage_fn/ppo_advantage.py +0 -0
  105. {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/algorithm/advantage_fn/reinforce_plus_plus_advantage.py +0 -0
  106. {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/algorithm/advantage_fn/remax_advantage.py +0 -0
  107. {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/algorithm/algorithm_manager.py +0 -0
  108. {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/algorithm/entropy_loss_fn/__init__.py +0 -0
  109. {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/algorithm/key_mapper.py +0 -0
  110. {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/algorithm/kl_fn/__init__.py +0 -0
  111. {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/algorithm/kl_fn/kl_fn.py +0 -0
  112. {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/algorithm/policy_loss_fn/dpo_loss.py +0 -0
  113. {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/algorithm/policy_loss_fn/opmd_policy_loss.py +0 -0
  114. {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/algorithm/policy_loss_fn/policy_loss_fn.py +0 -0
  115. {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/algorithm/policy_loss_fn/ppo_policy_loss.py +0 -0
  116. {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/algorithm/policy_loss_fn/sft_loss.py +0 -0
  117. {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/algorithm/sample_strategy/__init__.py +0 -0
  118. {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/algorithm/utils.py +0 -0
  119. {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/buffer/buffer.py +0 -0
  120. {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/buffer/buffer_writer.py +0 -0
  121. {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/buffer/reader/__init__.py +0 -0
  122. {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/buffer/schema/__init__.py +0 -0
  123. {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/buffer/utils.py +0 -0
  124. {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/buffer/writer/__init__.py +0 -0
  125. {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/buffer/writer/file_writer.py +0 -0
  126. {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/buffer/writer/sql_writer.py +0 -0
  127. {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/cli/client.py +0 -0
  128. {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/cli/server.py +0 -0
  129. {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/common/__init__.py +0 -0
  130. {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/common/rewards/accuracy_reward.py +0 -0
  131. {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/common/rewards/agents_reward.py +0 -0
  132. {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/common/rewards/countdown_reward.py +0 -0
  133. {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/common/rewards/format_reward.py +0 -0
  134. {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/common/rewards/human_reward.py +0 -0
  135. {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/common/rewards/tool_reward.py +0 -0
  136. {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/common/rewards/utils.py +0 -0
  137. {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/data/controllers/active_iterator.py +0 -0
  138. {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/data/controllers/default_ops.py +0 -0
  139. {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/data/controllers/task_parser.py +0 -0
  140. {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/data/core/comparator.py +0 -0
  141. {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/data/core/dataset.py +0 -0
  142. {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/data/core/formatter.py +0 -0
  143. {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/data/processors/base.py +0 -0
  144. {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/data/processors/cleaner.py +0 -0
  145. {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/data/processors/human_annotator.py +0 -0
  146. {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/data/processors/synthesizer.py +0 -0
  147. {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/data/server.py +0 -0
  148. {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/data/utils.py +0 -0
  149. {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/explorer/__init__.py +0 -0
  150. {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/manager/config_registry/__init__.py +0 -0
  151. {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/manager/config_registry/algorithm_config_manager.py +0 -0
  152. {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/manager/config_registry/config_registry.py +0 -0
  153. {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/manager/manager.py +0 -0
  154. {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/plugins/__init__.py +0 -0
  155. {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/trainer/__init__.py +0 -0
  156. {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/trainer/verl/__init__.py +0 -0
  157. {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/utils/__init__.py +0 -0
  158. {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/utils/distributed.py +0 -0
  159. {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/utils/log.py +0 -0
  160. {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity/utils/timer.py +0 -0
  161. {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity_rft.egg-info/dependency_links.txt +0 -0
  162. {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity_rft.egg-info/entry_points.txt +0 -0
  163. {trinity_rft-0.2.0 → trinity_rft-0.2.1.dev0}/trinity_rft.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: trinity-rft
3
- Version: 0.2.0
3
+ Version: 0.2.1.dev0
4
4
  Summary: Trinity-RFT: A Framework for Training Large Language Models with Reinforcement Fine-Tuning
5
5
  Author-email: Trinity-RFT Team <trinity-rft@outlook.com>
6
6
  Project-URL: Homepage, https://github.com/modelscope/Trinity-RFT
@@ -15,9 +15,9 @@ Classifier: Programming Language :: Python :: 3.12
15
15
  Requires-Python: >=3.10
16
16
  Description-Content-Type: text/markdown
17
17
  License-File: LICENSE
18
- Requires-Dist: verl==0.4.0
18
+ Requires-Dist: verl==0.4.1
19
19
  Requires-Dist: ray[default]>=2.45.0
20
- Requires-Dist: vllm==0.9.1
20
+ Requires-Dist: vllm<=0.10.0,>=0.9.1
21
21
  Requires-Dist: tensordict==0.6.2
22
22
  Requires-Dist: wandb
23
23
  Requires-Dist: omegaconf
@@ -35,6 +35,8 @@ Requires-Dist: tensorboard
35
35
  Requires-Dist: openai
36
36
  Requires-Dist: jsonlines
37
37
  Requires-Dist: sortedcontainers
38
+ Requires-Dist: word2number
39
+ Requires-Dist: transformers<4.54.0
38
40
  Provides-Extra: data
39
41
  Requires-Dist: py-data-juicer; extra == "data"
40
42
  Provides-Extra: agent
@@ -57,10 +59,13 @@ Requires-Dist: sphinx; extra == "doc"
57
59
  Requires-Dist: sphinx-autobuild; extra == "doc"
58
60
  Requires-Dist: sphinx_rtd_theme; extra == "doc"
59
61
  Requires-Dist: myst-parser; extra == "doc"
62
+ Requires-Dist: sphinxcontrib-apidoc; extra == "doc"
63
+ Requires-Dist: sphinx-multiversion; extra == "doc"
60
64
  Provides-Extra: flash-attn
61
65
  Requires-Dist: flash-attn==2.8.0.post2; extra == "flash-attn"
62
66
  Dynamic: license-file
63
67
 
68
+ [**中文主页**](https://github.com/modelscope/Trinity-RFT/blob/main/README_zh.md) | [**Tutorial**](https://modelscope.github.io/Trinity-RFT/) | [**FAQ**](./docs/sphinx_doc/source/tutorial/faq.md)
64
69
 
65
70
  <div align="center">
66
71
  <img src="https://img.alicdn.com/imgextra/i1/O1CN01lvLpfw25Pl4ohGZnU_!!6000000007519-2-tps-1628-490.png" alt="Trinity-RFT" style="height: 120px;">
@@ -84,6 +89,9 @@ Dynamic: license-file
84
89
 
85
90
  ## 🚀 News
86
91
 
92
+ * [2025-08] ✨ Trinity-RFT v0.2.1 is released with enhanced features for Agentic RL and Async RL.
93
+ * [2025-08] 🎵 We introduce [CHORD](https://github.com/modelscope/Trinity-RFT/tree/main/examples/mix_chord), a dynamic integration of SFT and RL for enhanced LLM fine-tuning ([paper](https://arxiv.org/pdf/2508.11408)).
94
+ * [2025-08] We now support training on general multi-step workflows! Please check out examples for [ALFWorld](./docs/sphinx_doc/source/tutorial/example_step_wise.md) and [ReAct](./docs/sphinx_doc/source/tutorial/example_react.md).
87
95
  * [2025-07] Trinity-RFT v0.2.0 is released.
88
96
  * [2025-07] We update the [technical report](https://arxiv.org/abs/2505.17826) (arXiv v2) with new features, examples, and experiments.
89
97
  * [2025-06] Trinity-RFT v0.1.1 is released.
@@ -143,6 +151,15 @@ It is designed to support diverse application scenarios and serve as a unified p
143
151
  <img src="https://img.alicdn.com/imgextra/i3/O1CN01E7NskS1FFoTI9jlaQ_!!6000000000458-2-tps-1458-682.png" alt="Trinity-RFT-modes">
144
152
  </p>
145
153
 
154
+ </details>
155
+
156
+
157
+ <details>
158
+ <summary>Figure: Concatenated and general multi-step workflows</summary>
159
+
160
+ <p align="center">
161
+ <img src="https://img.alicdn.com/imgextra/i1/O1CN01z1i7kk1jlMEVa8ZHV_!!6000000004588-2-tps-1262-695.png" alt="Trinity-RFT-multi-step">
162
+ </p>
146
163
 
147
164
  </details>
148
165
 
@@ -214,6 +231,12 @@ It is designed to support diverse application scenarios and serve as a unified p
214
231
  ### Step 1: installation
215
232
 
216
233
 
234
+ Requirements:
235
+ - Python version >= 3.10, <= 3.12
236
+ - CUDA version >= 12.4, <= 12.8
237
+ - At least 2 GPUs
238
+
239
+
217
240
  Installation from source **(recommended)**:
218
241
 
219
242
  ```shell
@@ -243,13 +266,15 @@ pip install -e .[flash_attn]
243
266
  # for zsh
244
267
  pip install -e .\[flash_attn\]
245
268
  # Try the following command if you encounter errors during flash-attn installation
246
- # pip install flash-attn -v --no-build-isolation
269
+ # pip install flash-attn==2.8.0.post2 -v --no-build-isolation
247
270
  ```
248
271
 
249
272
  Installation using pip:
250
273
 
251
274
  ```shell
252
275
  pip install trinity-rft==0.2.0
276
+ # install flash-attn separately
277
+ pip install flash-attn==2.8.0.post2
253
278
  ```
254
279
 
255
280
  Installation from docker:
@@ -268,13 +293,6 @@ docker build -f scripts/docker/Dockerfile -t trinity-rft:latest .
268
293
  docker run -it --gpus all --shm-size="64g" --rm -v $PWD:/workspace -v <root_path_of_data_and_checkpoints>:/data trinity-rft:latest
269
294
  ```
270
295
 
271
-
272
- **Requirements:**
273
- Python version >= 3.10,
274
- CUDA version >= 12.4,
275
- and at least 2 GPUs.
276
-
277
-
278
296
  ### Step 2: prepare dataset and model
279
297
 
280
298
 
@@ -291,7 +309,7 @@ huggingface-cli download {model_name} --local-dir $MODEL_PATH/{model_name}
291
309
  modelscope download {model_name} --local_dir $MODEL_PATH/{model_name}
292
310
  ```
293
311
 
294
- For more details about model downloading, see [Huggingface](https://huggingface.co/docs/huggingface_hub/main/en/guides/cli) or [ModelScope](https://modelscope.cn/docs/models/download).
312
+ For more details about model downloading, see [Huggingface](https://huggingface.co/docs/huggingface_hub/main/en/guides/cli) or [ModelScope](https://modelscope.cn/docs/models/download).
295
313
 
296
314
 
297
315
 
@@ -318,13 +336,13 @@ Trinity-RFT provides a web interface for configuring your RFT process.
318
336
  > This is an experimental feature, and we will continue to improve it.
319
337
 
320
338
 
321
- To enable minimal features (mainly for trainer), you can run
339
+ To launch the web interface for minimal configurations, you can run
322
340
 
323
341
  ```bash
324
342
  trinity studio --port 8080
325
343
  ```
326
344
 
327
- Then you can configure your RFT process in the web page and generate a config file. You can save the config for later use or run it directly as described in the following section.
345
+ Then you can configure your RFT process in the web page and generate a config file. You can save the config file for later use or run it directly as described in the following section.
328
346
 
329
347
  Advanced users can also edit the config file directly.
330
348
  We provide example config files in [`examples`](examples/).
@@ -392,7 +410,12 @@ Tutorials for running different RFT modes:
392
410
 
393
411
  Tutorials for adapting Trinity-RFT to a new multi-turn agentic scenario:
394
412
 
395
- + [Multi-turn tasks](./docs/sphinx_doc/source/tutorial/example_multi_turn.md)
413
+ + [Concatenated Multi-turn tasks](./docs/sphinx_doc/source/tutorial/example_multi_turn.md)
414
+
415
+ Tutorials for adapting Trinity-RFT to a general multi-step agentic scenario:
416
+
417
+ + [General Multi-Step tasks](./docs/sphinx_doc/source/tutorial/example_step_wise.md)
418
+ + [ReAct agent tasks](./docs/sphinx_doc/source/tutorial/example_react.md)
396
419
 
397
420
 
398
421
  Tutorials for data-related functionalities:
@@ -416,10 +439,6 @@ Guidelines for developers and researchers:
416
439
 
417
440
 
418
441
 
419
- For some frequently asked questions, see [FAQ](./docs/sphinx_doc/source/tutorial/faq.md).
420
-
421
-
422
-
423
442
 
424
443
  ## Upcoming features
425
444
 
@@ -1,3 +1,4 @@
1
+ [**中文主页**](https://github.com/modelscope/Trinity-RFT/blob/main/README_zh.md) | [**Tutorial**](https://modelscope.github.io/Trinity-RFT/) | [**FAQ**](./docs/sphinx_doc/source/tutorial/faq.md)
1
2
 
2
3
  <div align="center">
3
4
  <img src="https://img.alicdn.com/imgextra/i1/O1CN01lvLpfw25Pl4ohGZnU_!!6000000007519-2-tps-1628-490.png" alt="Trinity-RFT" style="height: 120px;">
@@ -21,6 +22,9 @@
21
22
 
22
23
  ## 🚀 News
23
24
 
25
+ * [2025-08] ✨ Trinity-RFT v0.2.1 is released with enhanced features for Agentic RL and Async RL.
26
+ * [2025-08] 🎵 We introduce [CHORD](https://github.com/modelscope/Trinity-RFT/tree/main/examples/mix_chord), a dynamic integration of SFT and RL for enhanced LLM fine-tuning ([paper](https://arxiv.org/pdf/2508.11408)).
27
+ * [2025-08] We now support training on general multi-step workflows! Please check out examples for [ALFWorld](./docs/sphinx_doc/source/tutorial/example_step_wise.md) and [ReAct](./docs/sphinx_doc/source/tutorial/example_react.md).
24
28
  * [2025-07] Trinity-RFT v0.2.0 is released.
25
29
  * [2025-07] We update the [technical report](https://arxiv.org/abs/2505.17826) (arXiv v2) with new features, examples, and experiments.
26
30
  * [2025-06] Trinity-RFT v0.1.1 is released.
@@ -80,6 +84,15 @@ It is designed to support diverse application scenarios and serve as a unified p
80
84
  <img src="https://img.alicdn.com/imgextra/i3/O1CN01E7NskS1FFoTI9jlaQ_!!6000000000458-2-tps-1458-682.png" alt="Trinity-RFT-modes">
81
85
  </p>
82
86
 
87
+ </details>
88
+
89
+
90
+ <details>
91
+ <summary>Figure: Concatenated and general multi-step workflows</summary>
92
+
93
+ <p align="center">
94
+ <img src="https://img.alicdn.com/imgextra/i1/O1CN01z1i7kk1jlMEVa8ZHV_!!6000000004588-2-tps-1262-695.png" alt="Trinity-RFT-multi-step">
95
+ </p>
83
96
 
84
97
  </details>
85
98
 
@@ -151,6 +164,12 @@ It is designed to support diverse application scenarios and serve as a unified p
151
164
  ### Step 1: installation
152
165
 
153
166
 
167
+ Requirements:
168
+ - Python version >= 3.10, <= 3.12
169
+ - CUDA version >= 12.4, <= 12.8
170
+ - At least 2 GPUs
171
+
172
+
154
173
  Installation from source **(recommended)**:
155
174
 
156
175
  ```shell
@@ -180,13 +199,15 @@ pip install -e .[flash_attn]
180
199
  # for zsh
181
200
  pip install -e .\[flash_attn\]
182
201
  # Try the following command if you encounter errors during flash-attn installation
183
- # pip install flash-attn -v --no-build-isolation
202
+ # pip install flash-attn==2.8.0.post2 -v --no-build-isolation
184
203
  ```
185
204
 
186
205
  Installation using pip:
187
206
 
188
207
  ```shell
189
208
  pip install trinity-rft==0.2.0
209
+ # install flash-attn separately
210
+ pip install flash-attn==2.8.0.post2
190
211
  ```
191
212
 
192
213
  Installation from docker:
@@ -205,13 +226,6 @@ docker build -f scripts/docker/Dockerfile -t trinity-rft:latest .
205
226
  docker run -it --gpus all --shm-size="64g" --rm -v $PWD:/workspace -v <root_path_of_data_and_checkpoints>:/data trinity-rft:latest
206
227
  ```
207
228
 
208
-
209
- **Requirements:**
210
- Python version >= 3.10,
211
- CUDA version >= 12.4,
212
- and at least 2 GPUs.
213
-
214
-
215
229
  ### Step 2: prepare dataset and model
216
230
 
217
231
 
@@ -228,7 +242,7 @@ huggingface-cli download {model_name} --local-dir $MODEL_PATH/{model_name}
228
242
  modelscope download {model_name} --local_dir $MODEL_PATH/{model_name}
229
243
  ```
230
244
 
231
- For more details about model downloading, see [Huggingface](https://huggingface.co/docs/huggingface_hub/main/en/guides/cli) or [ModelScope](https://modelscope.cn/docs/models/download).
245
+ For more details about model downloading, see [Huggingface](https://huggingface.co/docs/huggingface_hub/main/en/guides/cli) or [ModelScope](https://modelscope.cn/docs/models/download).
232
246
 
233
247
 
234
248
 
@@ -255,13 +269,13 @@ Trinity-RFT provides a web interface for configuring your RFT process.
255
269
  > This is an experimental feature, and we will continue to improve it.
256
270
 
257
271
 
258
- To enable minimal features (mainly for trainer), you can run
272
+ To launch the web interface for minimal configurations, you can run
259
273
 
260
274
  ```bash
261
275
  trinity studio --port 8080
262
276
  ```
263
277
 
264
- Then you can configure your RFT process in the web page and generate a config file. You can save the config for later use or run it directly as described in the following section.
278
+ Then you can configure your RFT process in the web page and generate a config file. You can save the config file for later use or run it directly as described in the following section.
265
279
 
266
280
  Advanced users can also edit the config file directly.
267
281
  We provide example config files in [`examples`](examples/).
@@ -329,7 +343,12 @@ Tutorials for running different RFT modes:
329
343
 
330
344
  Tutorials for adapting Trinity-RFT to a new multi-turn agentic scenario:
331
345
 
332
- + [Multi-turn tasks](./docs/sphinx_doc/source/tutorial/example_multi_turn.md)
346
+ + [Concatenated Multi-turn tasks](./docs/sphinx_doc/source/tutorial/example_multi_turn.md)
347
+
348
+ Tutorials for adapting Trinity-RFT to a general multi-step agentic scenario:
349
+
350
+ + [General Multi-Step tasks](./docs/sphinx_doc/source/tutorial/example_step_wise.md)
351
+ + [ReAct agent tasks](./docs/sphinx_doc/source/tutorial/example_react.md)
333
352
 
334
353
 
335
354
  Tutorials for data-related functionalities:
@@ -353,10 +372,6 @@ Guidelines for developers and researchers:
353
372
 
354
373
 
355
374
 
356
- For some frequently asked questions, see [FAQ](./docs/sphinx_doc/source/tutorial/faq.md).
357
-
358
-
359
-
360
375
 
361
376
  ## Upcoming features
362
377
 
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "trinity-rft"
7
- version = "0.2.0"
7
+ version = "0.2.1.dev0"
8
8
  authors = [
9
9
  {name="Trinity-RFT Team", email="trinity-rft@outlook.com"},
10
10
  ]
@@ -21,9 +21,9 @@ classifiers = [
21
21
  ]
22
22
  requires-python = ">=3.10"
23
23
  dependencies = [
24
- "verl==0.4.0",
24
+ "verl==0.4.1",
25
25
  "ray[default]>=2.45.0",
26
- "vllm==0.9.1",
26
+ "vllm>=0.9.1,<=0.10.0",
27
27
  "tensordict==0.6.2",
28
28
  "wandb",
29
29
  "omegaconf",
@@ -41,6 +41,8 @@ dependencies = [
41
41
  "openai",
42
42
  "jsonlines",
43
43
  "sortedcontainers",
44
+ "word2number",
45
+ "transformers<4.54.0", # TODO: remove when https://github.com/vllm-project/vllm-ascend/issues/2046 is fixed
44
46
  ]
45
47
 
46
48
  [project.scripts]
@@ -66,7 +68,7 @@ dev = [
66
68
  "pytest>=8.0.0",
67
69
  "pytest-json-ctrf",
68
70
  "parameterized",
69
- "matplotlib"
71
+ "matplotlib",
70
72
  ]
71
73
 
72
74
  doc = [
@@ -74,6 +76,8 @@ doc = [
74
76
  "sphinx-autobuild",
75
77
  "sphinx_rtd_theme",
76
78
  "myst-parser",
79
+ "sphinxcontrib-apidoc",
80
+ "sphinx-multiversion",
77
81
  ]
78
82
 
79
83
  flash_attn = [
@@ -1,4 +1,4 @@
1
1
  # -*- coding: utf-8 -*-
2
2
  """Trinity-RFT (Reinforcement Fine-Tuning)"""
3
3
 
4
- __version__ = "0.2.0"
4
+ __version__ = "0.2.1"
@@ -1,3 +1,4 @@
1
+ from trinity.algorithm.add_strategy import ADD_STRATEGY, AddStrategy
1
2
  from trinity.algorithm.advantage_fn import ADVANTAGE_FN, AdvantageFn
2
3
  from trinity.algorithm.algorithm import ALGORITHM_TYPE, AlgorithmType
3
4
  from trinity.algorithm.entropy_loss_fn import ENTROPY_LOSS_FN, EntropyLossFn
@@ -18,4 +19,6 @@ __all__ = [
18
19
  "ENTROPY_LOSS_FN",
19
20
  "SampleStrategy",
20
21
  "SAMPLE_STRATEGY",
22
+ "AddStrategy",
23
+ "ADD_STRATEGY",
21
24
  ]
@@ -0,0 +1,25 @@
1
+ from trinity.algorithm.add_strategy.add_strategy import (
2
+ ADD_STRATEGY,
3
+ AddStrategy,
4
+ GRPOAddStrategy,
5
+ OPMDAddStrategy,
6
+ RewardVarianceAddStrategy,
7
+ )
8
+ from trinity.algorithm.add_strategy.correct_bias_add_strategy import (
9
+ CorrectBiasAddStrategy,
10
+ )
11
+ from trinity.algorithm.add_strategy.duplicate_add_strategy import (
12
+ DuplicateInformativeAddStrategy,
13
+ )
14
+ from trinity.algorithm.add_strategy.step_wise_add_strategy import StepWiseGRPOStrategy
15
+
16
+ __all__ = [
17
+ "ADD_STRATEGY",
18
+ "AddStrategy",
19
+ "GRPOAddStrategy",
20
+ "OPMDAddStrategy",
21
+ "StepWiseGRPOStrategy",
22
+ "RewardVarianceAddStrategy",
23
+ "CorrectBiasAddStrategy",
24
+ "DuplicateInformativeAddStrategy",
25
+ ]
@@ -0,0 +1,230 @@
1
+ import asyncio
2
+ from abc import ABC, abstractmethod
3
+ from typing import Dict, List, Literal, Tuple
4
+
5
+ import numpy as np
6
+ import torch
7
+
8
+ from trinity.buffer import BufferWriter
9
+ from trinity.common.experience import Experience
10
+ from trinity.utils.monitor import gather_metrics
11
+ from trinity.utils.registry import Registry
12
+ from trinity.utils.timer import Timer
13
+
14
+ ADD_STRATEGY = Registry("add_strategy")
15
+
16
+
17
class AddStrategy(ABC):
    """Abstract interface for strategies that add experiences to a buffer.

    A strategy keeps a reference to the buffer writer it was constructed with.
    Subclasses must implement :meth:`add` and :meth:`default_args`.
    """

    def __init__(self, writer: BufferWriter, **kwargs) -> None:
        # Extra keyword arguments are accepted but not used by the base class.
        self.writer = writer

    @abstractmethod
    async def add(self, experiences: List[Experience], step: int) -> Tuple[int, Dict]:
        """Add experiences to the buffer.

        Args:
            experiences (`List[Experience]`): The experiences to be added.
            step (`int`): The current step number.

        Returns:
            `int`: The number of experiences added to the buffer.
            `Dict`: Metrics for logging.
        """

    @classmethod
    @abstractmethod
    def default_args(cls) -> dict:
        """Return the default arguments of the add strategy.

        Returns:
            `dict`: The default arguments.
        """
42
+
43
+
44
class GroupAdvantageStrategy(AddStrategy):
    """Base AddStrategy that computes advantages group by group.

    Subclasses decide how experiences are grouped and how the advantage of
    each group is calculated; :meth:`add` drives both steps and writes the
    resulting experiences through ``self.writer``.
    """

    @abstractmethod
    def group_experiences(self, exps: List[Experience]) -> Dict[str, List[Experience]]:
        """Group experiences by a certain criterion.

        Args:
            exps (List[Experience]): List of experiences to be grouped.

        Returns:
            Dict[str, List[Experience]]: Mapping from group identifier to the
            experiences that belong to that group.
        """

    @abstractmethod
    def calculate_group_advantage(
        self, group_id: str, exps: List[Experience]
    ) -> Tuple[List[Experience], Dict]:
        """Calculate advantages for a group of experiences.

        Args:
            group_id (str): The identifier for the group of experiences.
            exps (List[Experience]): List of experiences in the group.

        Returns:
            Tuple[List[Experience], Dict]: The (possibly modified) experiences
            of the group and a dictionary of metrics.
        """

    async def add(self, exps: List[Experience], step: int) -> Tuple[int, Dict]:
        """Group ``exps``, compute per-group advantages and write the results.

        Args:
            exps (List[Experience]): Experiences to process.
            step (int): The current step number (not used here).

        Returns:
            Tuple[int, Dict]: The number of experiences returned by
            ``calculate_group_advantage`` over all groups, and the metrics
            produced by ``gather_metrics`` (empty if it raises ``ValueError``).
        """
        if not exps:
            return 0, {}

        total = 0
        per_group_metrics = []
        pending_writes = []
        for group_id, members in self.group_experiences(exps).items():
            processed, group_metrics = self.calculate_group_advantage(group_id, members)
            per_group_metrics.append(group_metrics)
            total += len(processed)
            # Groups that end up empty are counted as zero and never written.
            if processed:
                pending_writes.append(self.writer.write_async(processed))

        # All group writes are awaited together.
        if pending_writes:
            await asyncio.gather(*pending_writes)

        try:
            metrics = gather_metrics(per_group_metrics, "group_advantages")
        except ValueError:
            metrics = {}  # empty metric list causes ValueError, ignore it
        return total, metrics
92
+
93
+
94
@ADD_STRATEGY.register_module("grpo")
class GRPOAddStrategy(GroupAdvantageStrategy):
    """AddStrategy that calculates GRPO advantages within each task group.

    Each reward is standardized with the mean and standard deviation of its
    group; ``epsilon`` is added to the standard deviation before dividing.
    """

    def __init__(self, writer: BufferWriter, epsilon: float = 1e-6, **kwargs) -> None:
        super().__init__(writer)
        self.epsilon = epsilon

    def group_experiences(self, exps):
        """Group experiences by task id."""
        return group_by(exps, id_type="task")

    def calculate_group_advantage(
        self, group_id: str, exps: List[Experience]
    ) -> Tuple[List[Experience], Dict]:
        """Assign standardized-reward advantages to every experience in ``exps``.

        Args:
            group_id (str): The identifier of the group (not used here).
            exps (List[Experience]): Experiences of one group; modified in place.

        Returns:
            Tuple[List[Experience], Dict]: The same experiences with
            ``advantages`` and ``returns`` set, plus the group's reward mean
            and standard deviation as metrics.
        """
        with torch.no_grad():
            if len(exps) == 1:
                # A single sample has no group statistics: use mean 0, std 1.
                mean = torch.tensor(0.0)
                std = torch.tensor(1.0)
            else:
                rewards = torch.tensor([exp.reward for exp in exps], dtype=torch.float32)
                mean = torch.mean(rewards)
                std = torch.std(rewards)

            for exp in exps:
                normalized = (exp.reward - mean) / (std + self.epsilon)
                # The scalar score is multiplied with the experience's action mask.
                exp.advantages = normalized * exp.action_mask
                exp.returns = exp.advantages.clone()

            metrics = {
                "reward_mean": mean.item(),
                "reward_std": std.item(),
            }

        return exps, metrics

    @classmethod
    def default_args(cls) -> dict:
        """Return the default constructor arguments."""
        return {"epsilon": 1e-6}
131
+
132
+
133
@ADD_STRATEGY.register_module("opmd")
class OPMDAddStrategy(GroupAdvantageStrategy):
    """AddStrategy that calculates OPMD advantages within each task group.

    The advantage of an experience is its reward minus a group baseline. The
    baseline is either the mean reward (``"mean"``) or
    ``tau * (logsumexp(rewards / tau) - log(n))`` (``"logavgexp"``).
    """

    def __init__(
        self, writer: BufferWriter, opmd_baseline: str = "mean", tau: float = 1.0, **kwargs
    ) -> None:
        super().__init__(writer)
        assert opmd_baseline in [
            "mean",
            "logavgexp",
        ], f"opmd_baseline must be 'mean' or 'logavgexp', got {opmd_baseline}"
        self.opmd_baseline = opmd_baseline
        self.tau = tau

    def group_experiences(self, exps):
        """Group experiences by task id."""
        return group_by(exps, id_type="task")

    def calculate_group_advantage(
        self, group_id: str, exps: List[Experience]
    ) -> Tuple[List[Experience], Dict]:
        """Assign baseline-subtracted advantages to every experience in ``exps``.

        Args:
            group_id (str): The identifier of the group (not used here).
            exps (List[Experience]): Experiences of one group; modified in place.

        Returns:
            Tuple[List[Experience], Dict]: The same experiences with
            ``advantages`` and ``returns`` set, plus the group baseline as a
            Python float metric.
        """
        with torch.no_grad():
            if len(exps) == 1:
                # A single sample has no group statistics: use a zero baseline.
                group_baseline = torch.tensor(0.0)
            else:
                group_rewards = torch.tensor([exp.reward for exp in exps], dtype=torch.float32)
                if self.opmd_baseline == "mean":
                    group_baseline = torch.mean(group_rewards)
                else:
                    # log-average-exp: tau * log(mean(exp(r / tau)))
                    group_baseline = self.tau * (
                        torch.logsumexp(group_rewards / self.tau, dim=-1)
                        - torch.log(torch.tensor(len(exps)))
                    )
            for exp in exps:
                score = exp.reward - group_baseline
                exp.advantages = score * exp.action_mask
                exp.returns = exp.advantages.clone()
            metrics = {
                # Report a plain float (not a 0-dim tensor), matching the
                # `.item()` metrics emitted by the GRPO strategy.
                "group_baseline": group_baseline.item(),
            }
        return exps, metrics

    @classmethod
    def default_args(cls) -> dict:
        """Return the default constructor arguments."""
        return {"opmd_baseline": "mean", "tau": 1.0}
178
+
179
+
180
@ADD_STRATEGY.register_module("reward_variance")
class RewardVarianceAddStrategy(AddStrategy):
    """AddStrategy that keeps only task groups whose reward variance is large enough.

    Groups with fewer than two experiences, or whose reward variance does not
    exceed ``variance_threshold``, are not written to the buffer.
    """

    def __init__(self, writer: BufferWriter, variance_threshold: float = 0.0, **kwargs) -> None:
        super().__init__(writer)
        self.variance_threshold = variance_threshold

    async def add(self, experiences: List[Experience], step: int) -> Tuple[int, Dict]:
        """Filter ``experiences`` by per-task reward variance and write the rest.

        Args:
            experiences (List[Experience]): Experiences to filter.
            step (int): The current step number (not used here).

        Returns:
            Tuple[int, Dict]: The number of experiences handed to the writer
            and the metrics dict populated by ``Timer``.
        """
        metrics = {}
        kept = 0
        pending_writes = []
        with Timer(metrics, "add_strategy_time"):
            for group in group_by(experiences, id_type="task").values():
                # Variance filtering needs at least two samples.
                if len(group) < 2:
                    continue
                reward_variance = np.var([exp.reward for exp in group])
                if reward_variance <= self.variance_threshold:
                    continue
                kept += len(group)
                pending_writes.append(self.writer.write_async(group))
            if pending_writes:
                await asyncio.gather(*pending_writes)
        return kept, metrics

    @classmethod
    def default_args(cls) -> dict:
        """Return the default constructor arguments."""
        return {"variance_threshold": 0.0}
210
+
211
+
212
def group_by(
    experiences: List[Experience], id_type: Literal["task", "run", "step"]
) -> Dict[str, List[Experience]]:
    """Group experiences by one component of their ``eid``.

    Args:
        experiences (List[Experience]): Experiences to group.
        id_type: ``"task"``, ``"run"`` or ``"step"``; these select the
            ``tid``, ``rid`` and ``sid`` attribute of ``exp.eid`` respectively.

    Returns:
        Dict[str, List[Experience]]: Mapping from the selected id to the
        experiences carrying it, in order of first appearance.

    Raises:
        ValueError: If ``id_type`` is not one of the supported values.
    """
    attr_by_id_type = {"task": "tid", "run": "rid", "step": "sid"}
    try:
        attr_name = attr_by_id_type[id_type]
    except (KeyError, TypeError):
        # TypeError covers unhashable values, which are equally unknown.
        raise ValueError(f"Unknown id_type: {id_type}") from None

    buckets: Dict[str, List[Experience]] = {}
    for exp in experiences:
        buckets.setdefault(getattr(exp.eid, attr_name), []).append(exp)
    return buckets
@@ -0,0 +1,54 @@
1
+ # -*- coding: utf-8 -*-
2
+ from typing import Dict, List, Tuple
3
+
4
+ import torch
5
+
6
+ from trinity.algorithm.add_strategy.add_strategy import ADD_STRATEGY, GRPOAddStrategy
7
+ from trinity.buffer import BufferWriter
8
+ from trinity.common.experience import Experience
9
+
10
+
11
@ADD_STRATEGY.register_module("correct_bias")
class CorrectBiasAddStrategy(GRPOAddStrategy):
    """An AddStrategy with group advantages that corrects for rank bias.

    See https://arxiv.org/pdf/2506.02355. Before standardization, each reward
    is scaled by ``1 - rank * rank_penalty``, where ``rank`` is the normalized
    rank of the experience's mean log-probability within its group.
    """

    def __init__(
        self, writer: BufferWriter, epsilon: float = 1e-6, rank_penalty: float = 0.25, **kwargs
    ) -> None:
        super().__init__(writer, epsilon, **kwargs)
        self.rank_penalty = rank_penalty

    def calculate_group_advantage(
        self, group_id: str, exps: List[Experience]
    ) -> Tuple[List[Experience], Dict]:
        """Assign rank-corrected, standardized advantages to ``exps``.

        Args:
            group_id (str): The identifier of the group (not used here).
            exps (List[Experience]): Experiences of one group; modified in place.

        Returns:
            Tuple[List[Experience], Dict]: The same experiences with
            ``advantages`` and ``returns`` set, plus the mean and standard
            deviation used for standardization as metrics.
        """
        with torch.no_grad():
            rewards = torch.tensor([exp.reward for exp in exps], dtype=torch.float32)

            if len(exps) == 1:
                # A single sample has no group statistics: use mean 0, std 1
                # and leave its reward uncorrected.
                mean = torch.tensor(0.0)
                std = torch.tensor(1.0)
            else:
                # correct bias
                mean_log_probs = torch.tensor(
                    [torch.mean(exp.logprobs, axis=-1) for exp in exps]
                )
                # argsort of argsort gives each element's rank (0 = lowest).
                ranks = torch.argsort(torch.argsort(mean_log_probs))
                ranks = ranks / len(ranks)
                rewards = rewards * (1 - ranks * self.rank_penalty)

                mean = torch.mean(rewards)
                std = torch.std(rewards)

            for idx, exp in enumerate(exps):
                normalized = (rewards[idx] - mean) / (std + self.epsilon)
                exp.advantages = normalized * exp.action_mask
                exp.returns = exp.advantages.clone()

            metrics = {
                "reward_mean": mean.item(),
                "reward_std": std.item(),
            }

        return exps, metrics

    @classmethod
    def default_args(cls) -> dict:
        """Return the default constructor arguments."""
        return {"epsilon": 1e-6, "rank_penalty": 0.25}