torchrl 0.11.0__cp314-cp314t-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (394) hide show
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/__init__.py +144 -0
  97. torchrl/_extension.py +74 -0
  98. torchrl/_torchrl.cp314t-win_amd64.pyd +0 -0
  99. torchrl/_utils.py +1431 -0
  100. torchrl/collectors/__init__.py +48 -0
  101. torchrl/collectors/_base.py +1058 -0
  102. torchrl/collectors/_constants.py +88 -0
  103. torchrl/collectors/_multi_async.py +324 -0
  104. torchrl/collectors/_multi_base.py +1805 -0
  105. torchrl/collectors/_multi_sync.py +464 -0
  106. torchrl/collectors/_runner.py +581 -0
  107. torchrl/collectors/_single.py +2009 -0
  108. torchrl/collectors/_single_async.py +259 -0
  109. torchrl/collectors/collectors.py +62 -0
  110. torchrl/collectors/distributed/__init__.py +32 -0
  111. torchrl/collectors/distributed/default_configs.py +133 -0
  112. torchrl/collectors/distributed/generic.py +1306 -0
  113. torchrl/collectors/distributed/ray.py +1092 -0
  114. torchrl/collectors/distributed/rpc.py +1006 -0
  115. torchrl/collectors/distributed/sync.py +731 -0
  116. torchrl/collectors/distributed/utils.py +160 -0
  117. torchrl/collectors/llm/__init__.py +10 -0
  118. torchrl/collectors/llm/base.py +494 -0
  119. torchrl/collectors/llm/ray_collector.py +275 -0
  120. torchrl/collectors/llm/utils.py +36 -0
  121. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  122. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  123. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  124. torchrl/collectors/utils.py +433 -0
  125. torchrl/collectors/weight_update.py +591 -0
  126. torchrl/csrc/numpy_utils.h +38 -0
  127. torchrl/csrc/pybind.cpp +27 -0
  128. torchrl/csrc/segment_tree.h +458 -0
  129. torchrl/csrc/torch_utils.h +34 -0
  130. torchrl/csrc/utils.cpp +48 -0
  131. torchrl/csrc/utils.h +31 -0
  132. torchrl/data/__init__.py +187 -0
  133. torchrl/data/datasets/__init__.py +58 -0
  134. torchrl/data/datasets/atari_dqn.py +878 -0
  135. torchrl/data/datasets/common.py +281 -0
  136. torchrl/data/datasets/d4rl.py +489 -0
  137. torchrl/data/datasets/d4rl_infos.py +187 -0
  138. torchrl/data/datasets/gen_dgrl.py +375 -0
  139. torchrl/data/datasets/minari_data.py +643 -0
  140. torchrl/data/datasets/openml.py +177 -0
  141. torchrl/data/datasets/openx.py +798 -0
  142. torchrl/data/datasets/roboset.py +363 -0
  143. torchrl/data/datasets/utils.py +11 -0
  144. torchrl/data/datasets/vd4rl.py +432 -0
  145. torchrl/data/llm/__init__.py +34 -0
  146. torchrl/data/llm/dataset.py +491 -0
  147. torchrl/data/llm/history.py +1378 -0
  148. torchrl/data/llm/prompt.py +198 -0
  149. torchrl/data/llm/reward.py +225 -0
  150. torchrl/data/llm/topk.py +186 -0
  151. torchrl/data/llm/utils.py +543 -0
  152. torchrl/data/map/__init__.py +21 -0
  153. torchrl/data/map/hash.py +185 -0
  154. torchrl/data/map/query.py +204 -0
  155. torchrl/data/map/tdstorage.py +363 -0
  156. torchrl/data/map/tree.py +1434 -0
  157. torchrl/data/map/utils.py +103 -0
  158. torchrl/data/postprocs/__init__.py +8 -0
  159. torchrl/data/postprocs/postprocs.py +391 -0
  160. torchrl/data/replay_buffers/__init__.py +99 -0
  161. torchrl/data/replay_buffers/checkpointers.py +622 -0
  162. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  163. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  164. torchrl/data/replay_buffers/samplers.py +2578 -0
  165. torchrl/data/replay_buffers/scheduler.py +265 -0
  166. torchrl/data/replay_buffers/storages.py +2412 -0
  167. torchrl/data/replay_buffers/utils.py +1042 -0
  168. torchrl/data/replay_buffers/writers.py +781 -0
  169. torchrl/data/tensor_specs.py +7101 -0
  170. torchrl/data/utils.py +334 -0
  171. torchrl/envs/__init__.py +265 -0
  172. torchrl/envs/async_envs.py +1105 -0
  173. torchrl/envs/batched_envs.py +3093 -0
  174. torchrl/envs/common.py +4241 -0
  175. torchrl/envs/custom/__init__.py +11 -0
  176. torchrl/envs/custom/chess.py +617 -0
  177. torchrl/envs/custom/llm.py +214 -0
  178. torchrl/envs/custom/pendulum.py +401 -0
  179. torchrl/envs/custom/san_moves.txt +29274 -0
  180. torchrl/envs/custom/tictactoeenv.py +288 -0
  181. torchrl/envs/env_creator.py +263 -0
  182. torchrl/envs/gym_like.py +752 -0
  183. torchrl/envs/libs/__init__.py +68 -0
  184. torchrl/envs/libs/_gym_utils.py +326 -0
  185. torchrl/envs/libs/brax.py +846 -0
  186. torchrl/envs/libs/dm_control.py +544 -0
  187. torchrl/envs/libs/envpool.py +447 -0
  188. torchrl/envs/libs/gym.py +2239 -0
  189. torchrl/envs/libs/habitat.py +138 -0
  190. torchrl/envs/libs/isaac_lab.py +87 -0
  191. torchrl/envs/libs/isaacgym.py +203 -0
  192. torchrl/envs/libs/jax_utils.py +166 -0
  193. torchrl/envs/libs/jumanji.py +963 -0
  194. torchrl/envs/libs/meltingpot.py +599 -0
  195. torchrl/envs/libs/openml.py +153 -0
  196. torchrl/envs/libs/openspiel.py +652 -0
  197. torchrl/envs/libs/pettingzoo.py +1042 -0
  198. torchrl/envs/libs/procgen.py +351 -0
  199. torchrl/envs/libs/robohive.py +429 -0
  200. torchrl/envs/libs/smacv2.py +645 -0
  201. torchrl/envs/libs/unity_mlagents.py +891 -0
  202. torchrl/envs/libs/utils.py +147 -0
  203. torchrl/envs/libs/vmas.py +813 -0
  204. torchrl/envs/llm/__init__.py +63 -0
  205. torchrl/envs/llm/chat.py +730 -0
  206. torchrl/envs/llm/datasets/README.md +4 -0
  207. torchrl/envs/llm/datasets/__init__.py +17 -0
  208. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  209. torchrl/envs/llm/datasets/ifeval.py +274 -0
  210. torchrl/envs/llm/envs.py +789 -0
  211. torchrl/envs/llm/libs/README.md +3 -0
  212. torchrl/envs/llm/libs/__init__.py +8 -0
  213. torchrl/envs/llm/libs/mlgym.py +869 -0
  214. torchrl/envs/llm/reward/__init__.py +10 -0
  215. torchrl/envs/llm/reward/gsm8k.py +324 -0
  216. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  217. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  218. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  222. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  223. torchrl/envs/llm/transforms/__init__.py +55 -0
  224. torchrl/envs/llm/transforms/browser.py +292 -0
  225. torchrl/envs/llm/transforms/dataloading.py +859 -0
  226. torchrl/envs/llm/transforms/format.py +73 -0
  227. torchrl/envs/llm/transforms/kl.py +1544 -0
  228. torchrl/envs/llm/transforms/policy_version.py +189 -0
  229. torchrl/envs/llm/transforms/reason.py +323 -0
  230. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  231. torchrl/envs/llm/transforms/tools.py +1955 -0
  232. torchrl/envs/model_based/__init__.py +9 -0
  233. torchrl/envs/model_based/common.py +180 -0
  234. torchrl/envs/model_based/dreamer.py +112 -0
  235. torchrl/envs/transforms/__init__.py +147 -0
  236. torchrl/envs/transforms/functional.py +48 -0
  237. torchrl/envs/transforms/gym_transforms.py +203 -0
  238. torchrl/envs/transforms/module.py +341 -0
  239. torchrl/envs/transforms/r3m.py +372 -0
  240. torchrl/envs/transforms/ray_service.py +663 -0
  241. torchrl/envs/transforms/rb_transforms.py +214 -0
  242. torchrl/envs/transforms/transforms.py +11835 -0
  243. torchrl/envs/transforms/utils.py +94 -0
  244. torchrl/envs/transforms/vc1.py +307 -0
  245. torchrl/envs/transforms/vecnorm.py +845 -0
  246. torchrl/envs/transforms/vip.py +407 -0
  247. torchrl/envs/utils.py +1718 -0
  248. torchrl/envs/vec_envs.py +11 -0
  249. torchrl/modules/__init__.py +206 -0
  250. torchrl/modules/distributions/__init__.py +73 -0
  251. torchrl/modules/distributions/continuous.py +830 -0
  252. torchrl/modules/distributions/discrete.py +908 -0
  253. torchrl/modules/distributions/truncated_normal.py +187 -0
  254. torchrl/modules/distributions/utils.py +233 -0
  255. torchrl/modules/llm/__init__.py +62 -0
  256. torchrl/modules/llm/backends/__init__.py +65 -0
  257. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  258. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  259. torchrl/modules/llm/backends/vllm/base.py +72 -0
  260. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  261. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  262. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  263. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  264. torchrl/modules/llm/policies/__init__.py +28 -0
  265. torchrl/modules/llm/policies/common.py +1809 -0
  266. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  267. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  268. torchrl/modules/llm/utils.py +23 -0
  269. torchrl/modules/mcts/__init__.py +21 -0
  270. torchrl/modules/mcts/scores.py +579 -0
  271. torchrl/modules/models/__init__.py +86 -0
  272. torchrl/modules/models/batchrenorm.py +119 -0
  273. torchrl/modules/models/decision_transformer.py +179 -0
  274. torchrl/modules/models/exploration.py +731 -0
  275. torchrl/modules/models/llm.py +156 -0
  276. torchrl/modules/models/model_based.py +596 -0
  277. torchrl/modules/models/models.py +1712 -0
  278. torchrl/modules/models/multiagent.py +1067 -0
  279. torchrl/modules/models/recipes/impala.py +185 -0
  280. torchrl/modules/models/utils.py +162 -0
  281. torchrl/modules/planners/__init__.py +10 -0
  282. torchrl/modules/planners/cem.py +228 -0
  283. torchrl/modules/planners/common.py +73 -0
  284. torchrl/modules/planners/mppi.py +265 -0
  285. torchrl/modules/tensordict_module/__init__.py +89 -0
  286. torchrl/modules/tensordict_module/actors.py +2457 -0
  287. torchrl/modules/tensordict_module/common.py +529 -0
  288. torchrl/modules/tensordict_module/exploration.py +814 -0
  289. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  290. torchrl/modules/tensordict_module/rnn.py +1639 -0
  291. torchrl/modules/tensordict_module/sequence.py +132 -0
  292. torchrl/modules/tensordict_module/world_models.py +34 -0
  293. torchrl/modules/utils/__init__.py +38 -0
  294. torchrl/modules/utils/mappings.py +9 -0
  295. torchrl/modules/utils/utils.py +89 -0
  296. torchrl/objectives/__init__.py +78 -0
  297. torchrl/objectives/a2c.py +659 -0
  298. torchrl/objectives/common.py +753 -0
  299. torchrl/objectives/cql.py +1346 -0
  300. torchrl/objectives/crossq.py +710 -0
  301. torchrl/objectives/ddpg.py +453 -0
  302. torchrl/objectives/decision_transformer.py +371 -0
  303. torchrl/objectives/deprecated.py +516 -0
  304. torchrl/objectives/dqn.py +683 -0
  305. torchrl/objectives/dreamer.py +488 -0
  306. torchrl/objectives/functional.py +48 -0
  307. torchrl/objectives/gail.py +258 -0
  308. torchrl/objectives/iql.py +996 -0
  309. torchrl/objectives/llm/__init__.py +30 -0
  310. torchrl/objectives/llm/grpo.py +846 -0
  311. torchrl/objectives/llm/sft.py +482 -0
  312. torchrl/objectives/multiagent/__init__.py +8 -0
  313. torchrl/objectives/multiagent/qmixer.py +396 -0
  314. torchrl/objectives/ppo.py +1669 -0
  315. torchrl/objectives/redq.py +683 -0
  316. torchrl/objectives/reinforce.py +530 -0
  317. torchrl/objectives/sac.py +1580 -0
  318. torchrl/objectives/td3.py +570 -0
  319. torchrl/objectives/td3_bc.py +625 -0
  320. torchrl/objectives/utils.py +782 -0
  321. torchrl/objectives/value/__init__.py +28 -0
  322. torchrl/objectives/value/advantages.py +1956 -0
  323. torchrl/objectives/value/functional.py +1459 -0
  324. torchrl/objectives/value/utils.py +360 -0
  325. torchrl/record/__init__.py +17 -0
  326. torchrl/record/loggers/__init__.py +23 -0
  327. torchrl/record/loggers/common.py +48 -0
  328. torchrl/record/loggers/csv.py +226 -0
  329. torchrl/record/loggers/mlflow.py +142 -0
  330. torchrl/record/loggers/tensorboard.py +139 -0
  331. torchrl/record/loggers/trackio.py +163 -0
  332. torchrl/record/loggers/utils.py +78 -0
  333. torchrl/record/loggers/wandb.py +214 -0
  334. torchrl/record/recorder.py +554 -0
  335. torchrl/services/__init__.py +79 -0
  336. torchrl/services/base.py +109 -0
  337. torchrl/services/ray_service.py +453 -0
  338. torchrl/testing/__init__.py +107 -0
  339. torchrl/testing/assertions.py +179 -0
  340. torchrl/testing/dist_utils.py +122 -0
  341. torchrl/testing/env_creators.py +227 -0
  342. torchrl/testing/env_helper.py +35 -0
  343. torchrl/testing/gym_helpers.py +156 -0
  344. torchrl/testing/llm_mocks.py +119 -0
  345. torchrl/testing/mocking_classes.py +2720 -0
  346. torchrl/testing/modules.py +295 -0
  347. torchrl/testing/mp_helpers.py +15 -0
  348. torchrl/testing/ray_helpers.py +293 -0
  349. torchrl/testing/utils.py +190 -0
  350. torchrl/trainers/__init__.py +42 -0
  351. torchrl/trainers/algorithms/__init__.py +11 -0
  352. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  353. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  354. torchrl/trainers/algorithms/configs/common.py +41 -0
  355. torchrl/trainers/algorithms/configs/data.py +308 -0
  356. torchrl/trainers/algorithms/configs/envs.py +104 -0
  357. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  358. torchrl/trainers/algorithms/configs/logging.py +80 -0
  359. torchrl/trainers/algorithms/configs/modules.py +570 -0
  360. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  361. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  362. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  363. torchrl/trainers/algorithms/configs/utils.py +252 -0
  364. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  365. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  366. torchrl/trainers/algorithms/ppo.py +373 -0
  367. torchrl/trainers/algorithms/sac.py +308 -0
  368. torchrl/trainers/helpers/__init__.py +40 -0
  369. torchrl/trainers/helpers/collectors.py +416 -0
  370. torchrl/trainers/helpers/envs.py +573 -0
  371. torchrl/trainers/helpers/logger.py +33 -0
  372. torchrl/trainers/helpers/losses.py +132 -0
  373. torchrl/trainers/helpers/models.py +658 -0
  374. torchrl/trainers/helpers/replay_buffer.py +59 -0
  375. torchrl/trainers/helpers/trainers.py +301 -0
  376. torchrl/trainers/trainers.py +2052 -0
  377. torchrl/weight_update/__init__.py +33 -0
  378. torchrl/weight_update/_distributed.py +749 -0
  379. torchrl/weight_update/_mp.py +624 -0
  380. torchrl/weight_update/_noupdate.py +102 -0
  381. torchrl/weight_update/_ray.py +1032 -0
  382. torchrl/weight_update/_rpc.py +284 -0
  383. torchrl/weight_update/_shared.py +891 -0
  384. torchrl/weight_update/llm/__init__.py +32 -0
  385. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  386. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  387. torchrl/weight_update/utils.py +73 -0
  388. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  389. torchrl-0.11.0.dist-info/LICENSE +21 -0
  390. torchrl-0.11.0.dist-info/METADATA +1307 -0
  391. torchrl-0.11.0.dist-info/RECORD +394 -0
  392. torchrl-0.11.0.dist-info/WHEEL +5 -0
  393. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  394. torchrl-0.11.0.dist-info/top_level.txt +7 -0
@@ -0,0 +1,435 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from __future__ import annotations
7
+
8
+ import gc
9
+ import os
10
+ from functools import partial
11
+ from pathlib import Path
12
+
13
+ import hydra
14
+
15
+ from torchrl import merge_ray_runtime_env, torchrl_logger
16
+ from torchrl.data.llm.history import History
17
+ from torchrl.record.loggers.wandb import WandbLogger
18
+ from torchrl.weight_update.llm import get_model_metadata
19
+
20
+ try:
21
+ import ray
22
+ except ImportError:
23
+ raise ImportError(
24
+ "Ray is required for sync training. Please install ray with `pip install ray`."
25
+ )
26
+ import time
27
+
28
+ import torch
29
+ import tqdm
30
+
31
+ from grpo_utils import (
32
+ add_kl_transforms_to_replay_buffer,
33
+ check_grpo_dependencies,
34
+ compute_device_allocation,
35
+ get_inference_model,
36
+ get_train_model,
37
+ log_training_metrics,
38
+ make_env,
39
+ make_weight_sync_scheme,
40
+ )
41
+ from omegaconf import DictConfig
42
+
43
+ try:
44
+ from tensordict import set_list_to_stack
45
+ except ImportError:
46
+ raise ImportError(
47
+ "TensorDict is required. Please install it with `pip install tensordict`."
48
+ )
49
+ from torch.amp.autocast_mode import autocast
50
+ from torch.amp.grad_scaler import GradScaler
51
+ from torchrl._utils import timeit
52
+ from torchrl.collectors.llm import RayLLMCollector
53
+ from torchrl.data import LazyStackStorage, ReplayBuffer, SamplerWithoutReplacement
54
+ from torchrl.data.replay_buffers.ray_buffer import RayReplayBuffer
55
+ from torchrl.objectives.llm.grpo import GRPOLoss, MCAdvantage
56
+
57
+
58
def setup_environment() -> None:
    """Setup required environment variables and configurations.

    Configures the process-wide torch defaults used by the training entry
    point: float32 default dtype (autocast handles mixed precision locally),
    ``cuda:0`` as the default/current device, and tensordict's
    list-to-stack behavior.

    Raises:
        RuntimeError: If no CUDA device is available — training requires a GPU.
    """
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is required for training")

    # Set default dtype to float32 for mixed precision training
    torch.set_default_dtype(torch.float32)
    torch.set_default_device("cuda:0")
    set_list_to_stack(True).set()

    # We raised above if CUDA is unavailable, so the device can be set
    # unconditionally (the previous `if torch.cuda.is_available()` guard
    # was always true here).
    torch.cuda.set_device("cuda:0")
71
+
72
+
73
def train(
    replay_buffer: ReplayBuffer,
    cfg: DictConfig,
    collector: RayLLMCollector,
    inference_policy,
    devices: list[int] | None = None,
):
    """Main training loop for GRPO sync.

    This function implements synchronous training where data collection and optimization
    happen in separate, consecutive steps. The total number of steps is determined by the number of epochs,
    samples per epoch, and batches collected.

    Args:
        replay_buffer: The replay buffer to store experiences. The sampler will typically be a `SamplerWithoutReplacement`.
        cfg: The configuration object containing training parameters
        collector: The collector object.
        inference_policy: The inference policy wrapper; its ``.model`` attribute
            (the vLLM engine) is the target of weight synchronization.
        devices: The devices to use for the training model.
    """
    # Setup training model and tokenizer
    policy_training, train_tokenizer = get_train_model(cfg, devices=devices)
    train_device = torch.device(f"cuda:{devices[0]}" if devices else "cuda:0")

    # Setup loss function. The KL-to-ref coefficient is only applied in-loss
    # when both kl_coef_in_loss and use_kl_to_ref are enabled.
    loss_fn = GRPOLoss(
        actor_network=policy_training,
        kl_to_ref_coeff=cfg.train.kl_to_ref_coeff
        if (cfg.train.kl_coef_in_loss and cfg.train.use_kl_to_ref)
        else 0.0,
        kl_to_inference_coeff=cfg.train.kl_to_inference_coeff,
        entropy_coeff=cfg.train.entropy_coeff,
        masking_strategy="rlhf" if cfg.env.reasoning else "sft",
        device=train_device,
    )
    if cfg.env.reasoning:
        # TODO: this is clunky, we should find a way to do this more naturally
        loss_fn.set_keys(sample_log_prob=("next", "log_probs", "full"))
    if cfg.model.compile:
        loss_fn = torch.compile(loss_fn)

    vllm_engine = inference_policy.model

    # Create weight sync scheme
    weight_sync_scheme = make_weight_sync_scheme(vllm_engine=vllm_engine)

    # Set up weight sender
    torchrl_logger.info("Setting up weight synchronization scheme...")
    sender = weight_sync_scheme.create_sender()
    sender.register_model(policy_training)

    # Initialize collective group
    torchrl_logger.info("Initializing collective group...")
    metadata = get_model_metadata(policy_training)
    sender.init_all_workers_group(metadata, vllm_engine=vllm_engine)

    # First weight update so inference starts from the training weights.
    with timeit("update_policy_weights"):
        sender.update_weights()
    timeit.print(prefix="First update_policy_weights_ time")
    timeit.reset()

    # Make optimizer
    torchrl_logger.info("Starting optimizer.")
    optimizer = torch.optim.Adam(
        policy_training.parameters(),
        lr=cfg.optimizer.lr,
        weight_decay=cfg.optimizer.weight_decay,
        fused=False,
    )
    # Single persistent GradScaler: its scale statistics must survive across
    # steps for dynamic loss scaling to work.
    scaler = GradScaler(enabled=cfg.train.mixed_precision)

    # Make checkpoint dir (checkpoint writes themselves are disabled below).
    checkpoint_dir = Path(cfg.logging.checkpoint_dir)
    checkpoint_dir.mkdir(parents=True, exist_ok=True)

    # Make wandb logger
    torchrl_logger.info("Starting wandb logger.")
    experiment_name = cfg.logging.experiment_name
    if experiment_name is not None:
        experiment_name = [experiment_name]
    else:
        experiment_name = []

    experiment_name.append(cfg.env.dataset)
    experiment_name.append(cfg.model.name)
    wandb_logger = WandbLogger(
        project="grpo-sync", exp_name="-".join(["grpo-sync"] + experiment_name)
    )

    # Training loop
    torchrl_logger.info("Starting training loop.")
    pbar = tqdm.tqdm(collector)
    grad_norm = 0.0  # Initialize grad_norm
    data_read_count = 0

    global_step = 0
    start_time = time.time()
    for data in pbar:
        # Wait for the replay buffer to be filled - when reasoning, we collect trajectories
        # so the buffer may not be filled straight away
        if not len(replay_buffer):
            torchrl_logger.info("Waiting for replay buffer to be filled")
            continue
        else:
            torchrl_logger.info(f"Replay buffer filled: {len(replay_buffer)}")

        pbar.update(1)

        # data is None as the collector directly writes to the replay buffer
        if data is not None:
            raise ValueError("Data is not None")

        for _ in range(cfg.train.epochs):
            # Iterate over the replay buffer
            for batch in replay_buffer:
                batch = batch.to(train_device)
                global_step += 1
                pbar.set_description(
                    f"Gradient step {global_step}, writes: {replay_buffer.write_count}, batch size: {batch.shape}"
                )
                # Render one prompt of the batch for logging purposes.
                history: History = batch.view(-1)[0]["next", "history"].prompt
                history_str: list[str] | str = history.apply_chat_template(
                    tokenizer=train_tokenizer
                )
                while not isinstance(history_str, str):
                    history_str = "\n".join(history_str)

                data_read_count += batch.numel()

                with timeit("forward_pass"):
                    with autocast("cuda", enabled=cfg.train.mixed_precision):
                        loss = loss_fn(batch)
                        # Divide by the accumulation steps so accumulated
                        # gradients average rather than sum.
                        loss_val = (
                            loss.mean(reduce=True)
                            / cfg.train.gradient_accumulation_steps
                        )

                with timeit("backward_pass"):
                    if (
                        cfg.train.mixed_precision
                        and cfg.train_model.torch_dtype == "float16"
                    ):
                        # Reuse the persistent scaler created above.
                        # (Re-creating a GradScaler here every step would
                        # reset its scale statistics each iteration, so
                        # inf/nan steps could never durably adjust the
                        # scale and scaler.update() would be a no-op in
                        # practice.)
                        scaler.scale(loss_val).backward()
                    else:
                        loss_val.backward()

                # Only step the optimizer every gradient_accumulation_steps.
                if ((global_step + 1) % cfg.train.gradient_accumulation_steps) == 0:
                    with timeit("optim_step"):
                        if (
                            cfg.train.mixed_precision
                            and cfg.train_model.torch_dtype == "float16"
                        ):
                            # Unscale before clipping so the clip threshold
                            # applies to true gradient magnitudes.
                            scaler.unscale_(optimizer)

                        grad_norm = torch.nn.utils.clip_grad_norm_(
                            policy_training.parameters(),
                            cfg.optimizer.clip_grad_norm,
                        )

                        if (
                            cfg.train.mixed_precision
                            and cfg.train_model.torch_dtype == "float16"
                        ):
                            scaler.step(optimizer)
                            scaler.update()
                        else:
                            optimizer.step()
                        optimizer.zero_grad(set_to_none=True)

                del loss_val
                # TODO: do we need this? Does it interfere with other processes?
                # torch.cuda.empty_cache()
                gc.collect()

                if (global_step % cfg.train.logging_frequency) == 0:
                    log_training_metrics(
                        wandb_logger=wandb_logger,
                        replay_buffer=replay_buffer,
                        batch=batch,
                        loss=loss,
                        grad_norm=grad_norm,
                        global_step=global_step,
                        data_read_count=data_read_count,
                        collector=collector,
                        start_time=start_time,
                        gradient_accumulation_steps=cfg.train.gradient_accumulation_steps,
                        history_str=history_str,
                        use_kl_to_ref=cfg.train.use_kl_to_ref,
                    )

                # Checkpointing disabled to prevent disk space issues
                # if (global_step + 1) % cfg.train.checkpoint_frequency == 0:
                #     with timeit("save_checkpoint"):
                #         torchrl_logger.info(
                #             f"Saving checkpoint {(global_step+1) // cfg.train.checkpoint_frequency}..."
                #         )
                #         checkpoint = {
                #             "step": global_step,
                #             "model_state_dict": policy_training.model.state_dict(),
                #             "optimizer_state_dict": optimizer.state_dict(),
                #             "scaler_state_dict": scaler.state_dict(),
                #             "config": dict(cfg),
                #         }
                #         torch.save(checkpoint, checkpoint_dir / f"checkpoint_{global_step:04d}.pt")

        # Push the freshly optimized weights to the vLLM inference engine.
        with timeit("update_policy_weights"):
            torchrl_logger.info("Updating policy weights...")
            sender.update_weights()
        # TODO: do we need this? Does it interfere with other processes?
        # torch.cuda.empty_cache()
        gc.collect()

        timeit.print(prefix="timeit")
        for key, val in timeit.todict().items():
            wandb_logger.log_scalar(f"timeit/{key}", val)
        timeit.reset()

        if cfg.train.empty_replay_buffer:
            replay_buffer.empty(empty_write_count=False)

    pbar.close()
    collector.shutdown()
296
+
297
+
298
@hydra.main(version_base=None, config_path="config", config_name="grpo_gsm8k")
def main(cfg):
    """Hydra entry point: validate the config, initialize Ray, build the
    replay buffer, inference policy and collector, then launch :func:`train`
    as a Ray remote task and block until it finishes.

    Args:
        cfg: Hydra/OmegaConf configuration (defaults from ``config/grpo_gsm8k``).

    Raises:
        ValueError: On invalid sync-mode configuration (see checks below).
    """
    # Check for required GRPO dependencies
    check_grpo_dependencies()

    # Force sync mode
    if not cfg.train.sync:
        raise ValueError(
            "grpo-sync.py must run in sync mode (`python grpo-sync.py mode=sync`). Please use grpo-async.py for async mode (`python grpo-async.py mode=async`)."
        )
    if cfg.train.weight_update_frequency is not None:
        raise ValueError("weight_update_frequency must be left empty in sync mode.")

    # Compute device allocation (GPU split between inference/ref/train models)
    device_config = compute_device_allocation(cfg)

    if not ray.is_initialized():
        # Convert OmegaConf to regular dict and filter out unsupported parameters
        # (keys starting with "_" are Hydra/OmegaConf internals).
        ray_init_config = {
            k: dict(v) if isinstance(v, DictConfig) else v
            for k, v in dict(cfg.ray.init_config).items()
            if not k.startswith("_")
        }

        # Add computed GPU configuration and merge with default runtime_env
        ray_init_config["num_gpus"] = device_config["ray_num_gpus"]
        ray_init_config = merge_ray_runtime_env(ray_init_config)
        torchrl_logger.info(f"Ray init config: {ray_init_config=}")
        # When the cluster is managed externally (e.g. by a scheduler), attach
        # to the existing cluster instead of starting a local one.
        ray_managed_externally = os.environ.get("RAY_CLUSTER_MANAGED_EXTERNALLY")
        if ray_managed_externally:
            ray.init(address="auto")
        else:
            ray.init(**ray_init_config)

    # Check if num_devices is set for every model that needs GPUs
    if cfg.inference_model.num_devices is None:
        raise ValueError(
            "Inference model num_devices must be set via inference_model.num_devices"
        )
    if cfg.train.use_kl_to_ref and cfg.ref_model.num_devices is None:
        raise ValueError(
            "Ref model num_devices must be set via ref_model.num_devices when use_kl_to_ref is True"
        )
    if cfg.train_model.num_devices is None:
        raise ValueError(
            "Train model num_devices must be set via train_model.num_devices"
        )

    # Convert OmegaConf to regular dict for Ray configs
    replay_buffer_config = dict(cfg.ray.replay_buffer_config)
    collector_config = dict(cfg.ray.collector_config)
    train_handler_config = dict(cfg.ray.train_handler_config)

    # vLLM-backed policy used by the collector for rollouts.
    inference_policy = get_inference_model(
        cfg, devices=device_config["inference_model_devices"]
    )
    torchrl_logger.info(f"Inference policy: {inference_policy}")

    torchrl_logger.info(f"Starting replay buffer with {replay_buffer_config=}")
    # In sync mode the buffer must hold exactly one collection batch.
    if cfg.train.buffer_size is not None and (
        cfg.train.buffer_size != cfg.train.dialog_turns_per_batch
    ):
        raise ValueError(
            "buffer_size must be equal to dialog_turns_per_batch in sync settings."
        )

    if cfg.train.optim_batch_size % cfg.train.gradient_accumulation_steps != 0:
        raise ValueError(
            "optim_batch_size must be divisible by gradient_accumulation_steps"
        )

    rb = RayReplayBuffer(
        storage=partial(
            LazyStackStorage,
            # Since we cache the values in the queue until we have "repeats" samples,
            # the buffer can be bigger than what the dialog_turns_per_batch (at most repeats * num_envs)
            cfg.train.buffer_size
            if cfg.train.buffer_size
            else cfg.env.repeats * cfg.env.num_envs,
        ),
        sampler=SamplerWithoutReplacement,
        # MCAdvantage groups "repeats" rollouts of the same prompt to compute
        # the Monte-Carlo GRPO advantage.
        transform_factory=partial(MCAdvantage, grpo_size=cfg.env.repeats),
        batch_size=max(
            1, cfg.train.optim_batch_size // cfg.train.gradient_accumulation_steps
        ),
        remote_config=replay_buffer_config,
    )

    # Optionally attach KL-related transforms (ref-model log-probs etc.).
    add_kl_transforms_to_replay_buffer(rb, cfg)

    torchrl_logger.info(f"Replay buffer: {rb}")

    # The collector drives the vLLM engine, which owns its own GPUs.
    collector_config["num_gpus"] = 0
    collector_config["num_cpus"] = cfg.ray.collector_config.get("num_cpus", 1)
    torchrl_logger.info(f"Starting collector with {collector_config=}")

    collector = RayLLMCollector(
        env=partial(make_env, cfg),
        policy=inference_policy,
        dialog_turns_per_batch=cfg.train.dialog_turns_per_batch,
        total_dialog_turns=cfg.train.total_dialog_turns,
        replay_buffer=rb,
        ray_init_config=None,
        weight_updater=None,
        track_policy_version=True,
        remote_config=collector_config,
        sync_iter=cfg.train.sync_iter,
        verbose=False,
        yield_only_last_steps=cfg.env.reasoning,
    )
    # Block until the remote collector actor is fully initialized.
    ray.get(collector._collector.is_initialized.remote())
    torchrl_logger.info(f"Collector: {collector}")

    # Resource request for the training task itself.
    train_handler_config = {
        "num_cpus": train_handler_config.get("num_cpus", 1),
        "num_gpus": cfg.train_model.num_devices,
    }
    torchrl_logger.info(f"Starting training handler with {train_handler_config=}")
    train_handler = ray.remote(
        **train_handler_config,
    )(train)

    # launch training and block until completion
    ray.get(
        train_handler.remote(
            rb,
            cfg,
            collector,
            inference_policy,
            devices=device_config["train_model_devices"],
        )
    )
430
+
431
+
432
if __name__ == "__main__":
    # Configure process-wide torch/CUDA defaults before the Hydra entry
    # point spins up Ray and the training task.
    setup_environment()
    main()