torchrl 0.11.0__cp314-cp314-manylinux_2_28_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (394)
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/__init__.py +144 -0
  97. torchrl/_extension.py +74 -0
  98. torchrl/_torchrl.cpython-314-aarch64-linux-gnu.so +0 -0
  99. torchrl/_utils.py +1431 -0
  100. torchrl/collectors/__init__.py +48 -0
  101. torchrl/collectors/_base.py +1058 -0
  102. torchrl/collectors/_constants.py +88 -0
  103. torchrl/collectors/_multi_async.py +324 -0
  104. torchrl/collectors/_multi_base.py +1805 -0
  105. torchrl/collectors/_multi_sync.py +464 -0
  106. torchrl/collectors/_runner.py +581 -0
  107. torchrl/collectors/_single.py +2009 -0
  108. torchrl/collectors/_single_async.py +259 -0
  109. torchrl/collectors/collectors.py +62 -0
  110. torchrl/collectors/distributed/__init__.py +32 -0
  111. torchrl/collectors/distributed/default_configs.py +133 -0
  112. torchrl/collectors/distributed/generic.py +1306 -0
  113. torchrl/collectors/distributed/ray.py +1092 -0
  114. torchrl/collectors/distributed/rpc.py +1006 -0
  115. torchrl/collectors/distributed/sync.py +731 -0
  116. torchrl/collectors/distributed/utils.py +160 -0
  117. torchrl/collectors/llm/__init__.py +10 -0
  118. torchrl/collectors/llm/base.py +494 -0
  119. torchrl/collectors/llm/ray_collector.py +275 -0
  120. torchrl/collectors/llm/utils.py +36 -0
  121. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  122. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  123. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  124. torchrl/collectors/utils.py +433 -0
  125. torchrl/collectors/weight_update.py +591 -0
  126. torchrl/csrc/numpy_utils.h +38 -0
  127. torchrl/csrc/pybind.cpp +27 -0
  128. torchrl/csrc/segment_tree.h +458 -0
  129. torchrl/csrc/torch_utils.h +34 -0
  130. torchrl/csrc/utils.cpp +48 -0
  131. torchrl/csrc/utils.h +31 -0
  132. torchrl/data/__init__.py +187 -0
  133. torchrl/data/datasets/__init__.py +58 -0
  134. torchrl/data/datasets/atari_dqn.py +878 -0
  135. torchrl/data/datasets/common.py +281 -0
  136. torchrl/data/datasets/d4rl.py +489 -0
  137. torchrl/data/datasets/d4rl_infos.py +187 -0
  138. torchrl/data/datasets/gen_dgrl.py +375 -0
  139. torchrl/data/datasets/minari_data.py +643 -0
  140. torchrl/data/datasets/openml.py +177 -0
  141. torchrl/data/datasets/openx.py +798 -0
  142. torchrl/data/datasets/roboset.py +363 -0
  143. torchrl/data/datasets/utils.py +11 -0
  144. torchrl/data/datasets/vd4rl.py +432 -0
  145. torchrl/data/llm/__init__.py +34 -0
  146. torchrl/data/llm/dataset.py +491 -0
  147. torchrl/data/llm/history.py +1378 -0
  148. torchrl/data/llm/prompt.py +198 -0
  149. torchrl/data/llm/reward.py +225 -0
  150. torchrl/data/llm/topk.py +186 -0
  151. torchrl/data/llm/utils.py +543 -0
  152. torchrl/data/map/__init__.py +21 -0
  153. torchrl/data/map/hash.py +185 -0
  154. torchrl/data/map/query.py +204 -0
  155. torchrl/data/map/tdstorage.py +363 -0
  156. torchrl/data/map/tree.py +1434 -0
  157. torchrl/data/map/utils.py +103 -0
  158. torchrl/data/postprocs/__init__.py +8 -0
  159. torchrl/data/postprocs/postprocs.py +391 -0
  160. torchrl/data/replay_buffers/__init__.py +99 -0
  161. torchrl/data/replay_buffers/checkpointers.py +622 -0
  162. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  163. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  164. torchrl/data/replay_buffers/samplers.py +2578 -0
  165. torchrl/data/replay_buffers/scheduler.py +265 -0
  166. torchrl/data/replay_buffers/storages.py +2412 -0
  167. torchrl/data/replay_buffers/utils.py +1042 -0
  168. torchrl/data/replay_buffers/writers.py +781 -0
  169. torchrl/data/tensor_specs.py +7101 -0
  170. torchrl/data/utils.py +334 -0
  171. torchrl/envs/__init__.py +265 -0
  172. torchrl/envs/async_envs.py +1105 -0
  173. torchrl/envs/batched_envs.py +3093 -0
  174. torchrl/envs/common.py +4241 -0
  175. torchrl/envs/custom/__init__.py +11 -0
  176. torchrl/envs/custom/chess.py +617 -0
  177. torchrl/envs/custom/llm.py +214 -0
  178. torchrl/envs/custom/pendulum.py +401 -0
  179. torchrl/envs/custom/san_moves.txt +29274 -0
  180. torchrl/envs/custom/tictactoeenv.py +288 -0
  181. torchrl/envs/env_creator.py +263 -0
  182. torchrl/envs/gym_like.py +752 -0
  183. torchrl/envs/libs/__init__.py +68 -0
  184. torchrl/envs/libs/_gym_utils.py +326 -0
  185. torchrl/envs/libs/brax.py +846 -0
  186. torchrl/envs/libs/dm_control.py +544 -0
  187. torchrl/envs/libs/envpool.py +447 -0
  188. torchrl/envs/libs/gym.py +2239 -0
  189. torchrl/envs/libs/habitat.py +138 -0
  190. torchrl/envs/libs/isaac_lab.py +87 -0
  191. torchrl/envs/libs/isaacgym.py +203 -0
  192. torchrl/envs/libs/jax_utils.py +166 -0
  193. torchrl/envs/libs/jumanji.py +963 -0
  194. torchrl/envs/libs/meltingpot.py +599 -0
  195. torchrl/envs/libs/openml.py +153 -0
  196. torchrl/envs/libs/openspiel.py +652 -0
  197. torchrl/envs/libs/pettingzoo.py +1042 -0
  198. torchrl/envs/libs/procgen.py +351 -0
  199. torchrl/envs/libs/robohive.py +429 -0
  200. torchrl/envs/libs/smacv2.py +645 -0
  201. torchrl/envs/libs/unity_mlagents.py +891 -0
  202. torchrl/envs/libs/utils.py +147 -0
  203. torchrl/envs/libs/vmas.py +813 -0
  204. torchrl/envs/llm/__init__.py +63 -0
  205. torchrl/envs/llm/chat.py +730 -0
  206. torchrl/envs/llm/datasets/README.md +4 -0
  207. torchrl/envs/llm/datasets/__init__.py +17 -0
  208. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  209. torchrl/envs/llm/datasets/ifeval.py +274 -0
  210. torchrl/envs/llm/envs.py +789 -0
  211. torchrl/envs/llm/libs/README.md +3 -0
  212. torchrl/envs/llm/libs/__init__.py +8 -0
  213. torchrl/envs/llm/libs/mlgym.py +869 -0
  214. torchrl/envs/llm/reward/__init__.py +10 -0
  215. torchrl/envs/llm/reward/gsm8k.py +324 -0
  216. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  217. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  218. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  222. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  223. torchrl/envs/llm/transforms/__init__.py +55 -0
  224. torchrl/envs/llm/transforms/browser.py +292 -0
  225. torchrl/envs/llm/transforms/dataloading.py +859 -0
  226. torchrl/envs/llm/transforms/format.py +73 -0
  227. torchrl/envs/llm/transforms/kl.py +1544 -0
  228. torchrl/envs/llm/transforms/policy_version.py +189 -0
  229. torchrl/envs/llm/transforms/reason.py +323 -0
  230. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  231. torchrl/envs/llm/transforms/tools.py +1955 -0
  232. torchrl/envs/model_based/__init__.py +9 -0
  233. torchrl/envs/model_based/common.py +180 -0
  234. torchrl/envs/model_based/dreamer.py +112 -0
  235. torchrl/envs/transforms/__init__.py +147 -0
  236. torchrl/envs/transforms/functional.py +48 -0
  237. torchrl/envs/transforms/gym_transforms.py +203 -0
  238. torchrl/envs/transforms/module.py +341 -0
  239. torchrl/envs/transforms/r3m.py +372 -0
  240. torchrl/envs/transforms/ray_service.py +663 -0
  241. torchrl/envs/transforms/rb_transforms.py +214 -0
  242. torchrl/envs/transforms/transforms.py +11835 -0
  243. torchrl/envs/transforms/utils.py +94 -0
  244. torchrl/envs/transforms/vc1.py +307 -0
  245. torchrl/envs/transforms/vecnorm.py +845 -0
  246. torchrl/envs/transforms/vip.py +407 -0
  247. torchrl/envs/utils.py +1718 -0
  248. torchrl/envs/vec_envs.py +11 -0
  249. torchrl/modules/__init__.py +206 -0
  250. torchrl/modules/distributions/__init__.py +73 -0
  251. torchrl/modules/distributions/continuous.py +830 -0
  252. torchrl/modules/distributions/discrete.py +908 -0
  253. torchrl/modules/distributions/truncated_normal.py +187 -0
  254. torchrl/modules/distributions/utils.py +233 -0
  255. torchrl/modules/llm/__init__.py +62 -0
  256. torchrl/modules/llm/backends/__init__.py +65 -0
  257. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  258. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  259. torchrl/modules/llm/backends/vllm/base.py +72 -0
  260. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  261. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  262. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  263. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  264. torchrl/modules/llm/policies/__init__.py +28 -0
  265. torchrl/modules/llm/policies/common.py +1809 -0
  266. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  267. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  268. torchrl/modules/llm/utils.py +23 -0
  269. torchrl/modules/mcts/__init__.py +21 -0
  270. torchrl/modules/mcts/scores.py +579 -0
  271. torchrl/modules/models/__init__.py +86 -0
  272. torchrl/modules/models/batchrenorm.py +119 -0
  273. torchrl/modules/models/decision_transformer.py +179 -0
  274. torchrl/modules/models/exploration.py +731 -0
  275. torchrl/modules/models/llm.py +156 -0
  276. torchrl/modules/models/model_based.py +596 -0
  277. torchrl/modules/models/models.py +1712 -0
  278. torchrl/modules/models/multiagent.py +1067 -0
  279. torchrl/modules/models/recipes/impala.py +185 -0
  280. torchrl/modules/models/utils.py +162 -0
  281. torchrl/modules/planners/__init__.py +10 -0
  282. torchrl/modules/planners/cem.py +228 -0
  283. torchrl/modules/planners/common.py +73 -0
  284. torchrl/modules/planners/mppi.py +265 -0
  285. torchrl/modules/tensordict_module/__init__.py +89 -0
  286. torchrl/modules/tensordict_module/actors.py +2457 -0
  287. torchrl/modules/tensordict_module/common.py +529 -0
  288. torchrl/modules/tensordict_module/exploration.py +814 -0
  289. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  290. torchrl/modules/tensordict_module/rnn.py +1639 -0
  291. torchrl/modules/tensordict_module/sequence.py +132 -0
  292. torchrl/modules/tensordict_module/world_models.py +34 -0
  293. torchrl/modules/utils/__init__.py +38 -0
  294. torchrl/modules/utils/mappings.py +9 -0
  295. torchrl/modules/utils/utils.py +89 -0
  296. torchrl/objectives/__init__.py +78 -0
  297. torchrl/objectives/a2c.py +659 -0
  298. torchrl/objectives/common.py +753 -0
  299. torchrl/objectives/cql.py +1346 -0
  300. torchrl/objectives/crossq.py +710 -0
  301. torchrl/objectives/ddpg.py +453 -0
  302. torchrl/objectives/decision_transformer.py +371 -0
  303. torchrl/objectives/deprecated.py +516 -0
  304. torchrl/objectives/dqn.py +683 -0
  305. torchrl/objectives/dreamer.py +488 -0
  306. torchrl/objectives/functional.py +48 -0
  307. torchrl/objectives/gail.py +258 -0
  308. torchrl/objectives/iql.py +996 -0
  309. torchrl/objectives/llm/__init__.py +30 -0
  310. torchrl/objectives/llm/grpo.py +846 -0
  311. torchrl/objectives/llm/sft.py +482 -0
  312. torchrl/objectives/multiagent/__init__.py +8 -0
  313. torchrl/objectives/multiagent/qmixer.py +396 -0
  314. torchrl/objectives/ppo.py +1669 -0
  315. torchrl/objectives/redq.py +683 -0
  316. torchrl/objectives/reinforce.py +530 -0
  317. torchrl/objectives/sac.py +1580 -0
  318. torchrl/objectives/td3.py +570 -0
  319. torchrl/objectives/td3_bc.py +625 -0
  320. torchrl/objectives/utils.py +782 -0
  321. torchrl/objectives/value/__init__.py +28 -0
  322. torchrl/objectives/value/advantages.py +1956 -0
  323. torchrl/objectives/value/functional.py +1459 -0
  324. torchrl/objectives/value/utils.py +360 -0
  325. torchrl/record/__init__.py +17 -0
  326. torchrl/record/loggers/__init__.py +23 -0
  327. torchrl/record/loggers/common.py +48 -0
  328. torchrl/record/loggers/csv.py +226 -0
  329. torchrl/record/loggers/mlflow.py +142 -0
  330. torchrl/record/loggers/tensorboard.py +139 -0
  331. torchrl/record/loggers/trackio.py +163 -0
  332. torchrl/record/loggers/utils.py +78 -0
  333. torchrl/record/loggers/wandb.py +214 -0
  334. torchrl/record/recorder.py +554 -0
  335. torchrl/services/__init__.py +79 -0
  336. torchrl/services/base.py +109 -0
  337. torchrl/services/ray_service.py +453 -0
  338. torchrl/testing/__init__.py +107 -0
  339. torchrl/testing/assertions.py +179 -0
  340. torchrl/testing/dist_utils.py +122 -0
  341. torchrl/testing/env_creators.py +227 -0
  342. torchrl/testing/env_helper.py +35 -0
  343. torchrl/testing/gym_helpers.py +156 -0
  344. torchrl/testing/llm_mocks.py +119 -0
  345. torchrl/testing/mocking_classes.py +2720 -0
  346. torchrl/testing/modules.py +295 -0
  347. torchrl/testing/mp_helpers.py +15 -0
  348. torchrl/testing/ray_helpers.py +293 -0
  349. torchrl/testing/utils.py +190 -0
  350. torchrl/trainers/__init__.py +42 -0
  351. torchrl/trainers/algorithms/__init__.py +11 -0
  352. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  353. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  354. torchrl/trainers/algorithms/configs/common.py +41 -0
  355. torchrl/trainers/algorithms/configs/data.py +308 -0
  356. torchrl/trainers/algorithms/configs/envs.py +104 -0
  357. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  358. torchrl/trainers/algorithms/configs/logging.py +80 -0
  359. torchrl/trainers/algorithms/configs/modules.py +570 -0
  360. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  361. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  362. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  363. torchrl/trainers/algorithms/configs/utils.py +252 -0
  364. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  365. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  366. torchrl/trainers/algorithms/ppo.py +373 -0
  367. torchrl/trainers/algorithms/sac.py +308 -0
  368. torchrl/trainers/helpers/__init__.py +40 -0
  369. torchrl/trainers/helpers/collectors.py +416 -0
  370. torchrl/trainers/helpers/envs.py +573 -0
  371. torchrl/trainers/helpers/logger.py +33 -0
  372. torchrl/trainers/helpers/losses.py +132 -0
  373. torchrl/trainers/helpers/models.py +658 -0
  374. torchrl/trainers/helpers/replay_buffer.py +59 -0
  375. torchrl/trainers/helpers/trainers.py +301 -0
  376. torchrl/trainers/trainers.py +2052 -0
  377. torchrl/weight_update/__init__.py +33 -0
  378. torchrl/weight_update/_distributed.py +749 -0
  379. torchrl/weight_update/_mp.py +624 -0
  380. torchrl/weight_update/_noupdate.py +102 -0
  381. torchrl/weight_update/_ray.py +1032 -0
  382. torchrl/weight_update/_rpc.py +284 -0
  383. torchrl/weight_update/_shared.py +891 -0
  384. torchrl/weight_update/llm/__init__.py +32 -0
  385. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  386. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  387. torchrl/weight_update/utils.py +73 -0
  388. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  389. torchrl-0.11.0.dist-info/METADATA +1308 -0
  390. torchrl-0.11.0.dist-info/RECORD +394 -0
  391. torchrl-0.11.0.dist-info/WHEEL +5 -0
  392. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  393. torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
  394. torchrl-0.11.0.dist-info/top_level.txt +7 -0
@@ -0,0 +1,273 @@
# GRPO: Group Relative Policy Optimization

This is an implementation of GRPO for language models, built on top of TorchRL.

## Overview

GRPO is a method for training language models using reinforcement learning, with the following key features:
- Multi-GPU support with efficient device management
- Mixed precision training
- Gradient accumulation
- Automatic checkpointing
- Comprehensive logging with Weights & Biases
- Hydra configuration system
- Asynchronous training support with Ray

## Installation

Install dependencies:
```bash
# GSM8K deps
pip install -r sota-implementations/grpo/requirements_gsm8k.txt
# IFEval deps
pip install -r sota-implementations/grpo/requirements_ifeval.txt
```

## Hardware Requirements

- At least 3 CUDA-capable GPUs:
  - Training device(s)
  - vLLM inference device
  - Reference model device

### Device Management

The number of devices for each model component is specified using `num_devices`:

```bash
train_model.num_devices=2 ref_model.num_devices=2 inference_model.num_devices=2
```

This approach:
- Automatically handles device allocation
- Works correctly in both sync and async modes
- Prevents device conflicts between model components
- Is more portable across different machine configurations

## Configuration

The training configuration is managed through Hydra. There are two main configuration files:
- `config/grpo_gsm8k.yaml`: Default configuration, used for GSM8K tasks
- `config/grpo_ifeval.yaml`: Configuration optimized for IFEval tasks

## Usage

### Basic Training

There are two training modes available:

#### Synchronous Mode (Default)
```bash
python sota-implementations/grpo/grpo-sync.py mode=sync train_model.num_devices=2 ref_model.num_devices=2 inference_model.num_devices=2
```

#### Asynchronous Mode (Recommended)
```bash
python sota-implementations/grpo/grpo-async.py mode=async train_model.num_devices=2 ref_model.num_devices=2 inference_model.num_devices=2
```

The key difference between sync and async modes is how data collection and optimization are handled:

**Synchronous Mode (grpo-sync.py)**:
```python
# Three nested loops:
for data in collector:  # Data collection loop
    for epoch in range(epochs):  # Epoch loop
        for batch in replay_buffer:  # Buffer consumption loop
            # Optimize on batch
            loss = loss_fn(batch)
            loss.backward()
            optimizer.step()
    # Weight update
    weight_updater.push_weights(policy_training)
```

**Asynchronous Mode (grpo-async.py)**:
```python
# Start data collection in background
collector.start()

# Single optimization loop
for step in range(total_steps):
    # Sample and optimize
    batch = replay_buffer.sample()
    loss = loss_fn(batch)
    loss.backward()
    optimizer.step()
    # Update weights once in a while
    if cond():
        weight_updater.push_weights(policy_training)
```

Key differences:
1. **Data Collection**:
   - Sync: Data collection and optimization happen sequentially.

     *Note*: Setting `train.sync_iter=False` collects the next batch of data while the current one is being optimized; in that case the maximum policy age is 1. With `train.sync_iter=True` (the default), the maximum policy age is 0.

   - Async: Data collection runs in the background while optimization happens.

2. **Buffer Size**:
   - Sync: The buffer size must equal the batch size returned by the collector (`buffer_size = dialog_turns_per_batch`).
   - Async: The buffer can be larger than the batch size, allowing for more diverse sampling.

3. **Data Processing**:
   - Sync: The same data is processed multiple times (epochs).
   - Async: Each piece of data is processed a non-deterministic number of times.

4. **Weight updates**:
   - Sync: Weights are updated before every data collection.
   - Async: Weights are updated at a given interval (in gradient steps). Each update requires a synchronization between the training and inference processes, so frequent updates will cause both workers to wait for each other.

The async mode offers better performance by:
- Running data collection and optimization concurrently
- Making more efficient use of the GPUs
- Reducing memory overhead
- Improving throughput
- Allowing more flexible buffer management

### Running GRPO on More Than One Node with SLURM

GRPO can be run across more than one node using SLURM, enabling distributed training for moderately scaled workloads.

Two scripts are provided for launching multi-node runs:

- `grpo-sync-multi-node.sbatch`: SLURM job script that launches sync GRPO across multiple nodes using Ray.
- `grpo-async-multi-node.sbatch`: SLURM job script that launches async GRPO across multiple nodes using Ray.

Example usage:

```bash
sbatch sota-implementations/grpo/grpo-sync-multi-node.sbatch
```

### KL Divergences in PPO: Reference vs Inference

KL divergence is a key regularization term in policy optimization algorithms like PPO and in LLM post-training. It measures how much the updated policy diverges from a baseline or reference policy, helping to prevent the new policy from drifting too far and ensuring stable learning.

There are two main types of KL divergences commonly used:

#### 1. KL to Reference Policy (KL[ref || policy])
- **Definition:** Measures how much the new (learned) policy diverges from a fixed reference policy (often the original, pre-trained model).
- **Implementation:** In GRPO, this is computed as `(ref_log_prob - cur_log_prob).expm1() - (ref_log_prob - cur_log_prob)`, which is a numerically stable way to compute KL from log-probabilities.
- **Usage:**
  - **LLM Post-Training:** This is the canonical choice in LLM post-training (e.g., RLHF, DPO, GRPO). The reference is usually the original language model before any RL fine-tuning. Penalizing KL[ref || policy] ensures the fine-tuned model stays close to the original, preserving language quality and preventing over-optimization.
- **Effect:** Encourages the new policy not to deviate too much from the reference, maintaining fluency and generalization.

#### 2. KL to Inference Policy (KL[policy || inference])
- **Definition:** Measures how much the current policy diverges from the policy used to generate the data (the inference policy, sometimes called the behavior policy).
- **Implementation:** In GRPO, this is approximated as `prev_log_prob - cur_log_prob`, where `prev_log_prob` comes from the inference policy that generated the data.
- **Usage:**
  - **Canonical PPO:** In standard PPO (especially in RL for control), this is the canonical KL: KL[policy || inference]. The inference policy is the one that generated the trajectories in the replay buffer. Penalizing this KL ensures that the updated policy does not move too far from the data distribution, stabilizing importance sampling and learning.
- **Effect:** Prevents the policy from making large, unstable updates relative to the data it was trained on.

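As a rough illustration of the two estimators above, here is a minimal, self-contained sketch (the tensors are placeholders; the variable names are chosen for clarity and are not the exact ones used in `grpo_utils.py`):

```python
import torch

# Log-probabilities of the same sampled tokens under three policies.
# Placeholder tensors: in practice these come from the training model,
# the frozen reference model, and the inference (data-generating) model.
cur_log_prob = torch.randn(8)   # current, trainable policy
ref_log_prob = torch.randn(8)   # frozen reference policy
prev_log_prob = torch.randn(8)  # inference/behavior policy

# KL[ref || policy]: numerically stable estimator exp(d) - 1 - d,
# with d = ref_log_prob - cur_log_prob (elementwise non-negative).
d = ref_log_prob - cur_log_prob
kl_to_ref = d.expm1() - d

# KL[policy || inference]: first-order approximation from the log-prob gap
# between the data-generating policy and the current policy.
kl_to_inference = prev_log_prob - cur_log_prob

print(kl_to_ref.mean(), kl_to_inference.mean())
```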

#### Summary Table

| Setting            | Canonical KL Term          | Purpose                                     |
|--------------------|----------------------------|---------------------------------------------|
| PPO (RL control)   | KL[policy \|\| inference]  | Stabilize updates, match data distribution  |
| LLM Post-Training  | KL[ref \|\| policy]        | Stay close to pre-trained model             |

In GRPO, both types of KL can be used and controlled via configuration. Typically, for LLM post-training, the KL to reference is the most important for preserving model quality, while the KL to inference is more about stabilizing the optimization process.

The two KL contributions to the loss are controlled via `train.kl_to_ref_coeff` and `train.kl_to_inference_coeff`, respectively.

Additionally, the KL-to-reference contribution can either be added to the reward when the LLM response is graded, or added directly to the loss; this is selected with the `train.kl_coef_in_loss` config option.

In the original GRPO paper, the KL to reference (KL[ref || policy]) is added **directly to the loss function**, not to the reward. This means that the KL penalty acts as a regularizer during optimization, discouraging the policy from drifting too far from the reference model at every update step. This is in contrast to some RLHF-style approaches, where the KL penalty is added to the reward signal during data collection (i.e., the environment's reward is modified).

**Why does this matter?**
- **KL in the loss (as in GRPO):** The optimization explicitly balances the policy objective and the KL penalty at each gradient step, making the trade-off more direct and stable. This is the canonical approach in GRPO and is controlled by setting `train.kl_coef_in_loss=True` in the config.
- **KL in the reward:** The KL penalty is treated as part of the environment's reward, so the policy is trained to maximize this modified reward. This can sometimes make the effect of the KL less direct, as it is mixed with the task reward during data collection.

In summary, GRPO's approach of adding the KL to reference directly to the loss provides more explicit and stable regularization, and is the recommended setting for most LLM post-training scenarios.
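For example, both coefficients and the placement of the reference KL can be set from the command line (the coefficient values below are placeholders, not tuned recommendations):

```bash
# Put the KL-to-reference term directly in the loss (GRPO-style),
# with illustrative coefficient values.
python grpo-sync.py mode=sync \
  train.kl_coef_in_loss=True \
  train.kl_to_ref_coeff=0.01 \
  train.kl_to_inference_coeff=0.0
```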

### Run with IFEval Config

```bash
python grpo-sync.py mode=sync --config-name grpo_ifeval
```

### Override Config Values

```bash
# Change dataset
python grpo-sync.py mode=sync env.dataset=ifeval

# Modify training parameters
python grpo-sync.py mode=sync optimizer.lr=2e-5 optimizer.weight_decay=0.01

# Change model
python grpo-sync.py mode=sync model.name=meta-llama/Llama-2-7b-hf
```

### Hyperparameter Sweeps

```bash
# Learning rate sweep
python grpo-sync.py mode=sync --multirun optimizer.lr=1e-4,1e-5,1e-6

# Multiple parameters
python grpo-sync.py mode=sync --multirun \
  optimizer.lr=1e-4,1e-5 \
  policy.kl_coef=0.01,0.1
```

Don't forget to set `train.total_dialog_turns` to a reasonable value for sweeps!
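For instance, a sweep can pin `train.total_dialog_turns` explicitly (the value below is only a placeholder):

```bash
# Learning-rate sweep with an explicit cap on collected dialog turns.
python grpo-sync.py mode=sync --multirun \
  optimizer.lr=1e-4,1e-5 \
  train.total_dialog_turns=1000
```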

## Monitoring

Training progress is logged to Weights & Biases with the following metrics:
- Reward
- Advantage
- KL penalty
- Sequence length
- ESS (Effective Sample Size)
- Loss metrics (objective, clip fraction, etc.)
- Gradient norm
- Throughput metrics (in async mode)

## Checkpointing

Checkpoints are saved every `train.checkpoint_frequency` steps and contain:
- Model state
- Optimizer state
- Gradient scaler state (for mixed precision)
- Full configuration
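As a minimal sketch, a saved checkpoint can be inspected with `torch.load`; the path follows the output layout described below, and the key names in the comments are assumptions rather than guarantees about what the training scripts write:

```python
import torch

# Load a checkpoint on CPU for inspection (path and key names are placeholders).
ckpt = torch.load(
    "outputs/2024-01-01/00-00-00/checkpoints/checkpoint_1000.pt",
    map_location="cpu",
    weights_only=False,  # the checkpoint also stores non-tensor objects (config)
)
print(list(ckpt.keys()))
# Typical usage would then be along the lines of:
# model.load_state_dict(ckpt["model"])          # assumed key name
# optimizer.load_state_dict(ckpt["optimizer"])  # assumed key name
```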

## Debugging Out-of-Memory Issues

- vLLM: Reduce `inference_model.gpu_memory_utilization=FRACTION` or the number of environments run in parallel (`env.num_envs=N`).
- KL scoring: If KL scoring is computed over the whole batch of data, reduce the number of environments run in parallel (`env.num_envs=N`).
- Training: Reduce the batch size (`train.optim_batch_size`).
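If memory is still tight, these knobs can be combined in a single launch (the values are illustrative starting points, not recommendations):

```bash
# Lower vLLM memory headroom, run fewer parallel environments,
# and shrink the optimization batch (placeholder values).
python grpo-sync.py mode=sync \
  inference_model.gpu_memory_utilization=0.6 \
  env.num_envs=2 \
  train.optim_batch_size=1
```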

## Directory Structure

```
sota-implementations/grpo/
├── config/
│   ├── grpo_gsm8k.yaml   # Main configuration file
│   └── grpo_ifeval.yaml  # Configuration file for the IFEval task
├── grpo-sync.py          # Synchronous training script
├── grpo-async.py         # Asynchronous training script
├── grpo_utils.py         # Utility functions
└── README.md             # This file
```

## Output Structure

Each run creates a timestamped directory under `outputs/`:
```
outputs/
└── YYYY-MM-DD/
    └── HH-MM-SS/
        ├── checkpoints/
        │   └── checkpoint_*.pt
        └── .hydra/
            └── config.yaml
```

For hyperparameter sweeps, outputs are stored under `multirun/`.