torchrl 0.11.0__cp314-cp314-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (394)
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/__init__.py +144 -0
  97. torchrl/_extension.py +74 -0
  98. torchrl/_torchrl.cp314-win_amd64.pyd +0 -0
  99. torchrl/_utils.py +1431 -0
  100. torchrl/collectors/__init__.py +48 -0
  101. torchrl/collectors/_base.py +1058 -0
  102. torchrl/collectors/_constants.py +88 -0
  103. torchrl/collectors/_multi_async.py +324 -0
  104. torchrl/collectors/_multi_base.py +1805 -0
  105. torchrl/collectors/_multi_sync.py +464 -0
  106. torchrl/collectors/_runner.py +581 -0
  107. torchrl/collectors/_single.py +2009 -0
  108. torchrl/collectors/_single_async.py +259 -0
  109. torchrl/collectors/collectors.py +62 -0
  110. torchrl/collectors/distributed/__init__.py +32 -0
  111. torchrl/collectors/distributed/default_configs.py +133 -0
  112. torchrl/collectors/distributed/generic.py +1306 -0
  113. torchrl/collectors/distributed/ray.py +1092 -0
  114. torchrl/collectors/distributed/rpc.py +1006 -0
  115. torchrl/collectors/distributed/sync.py +731 -0
  116. torchrl/collectors/distributed/utils.py +160 -0
  117. torchrl/collectors/llm/__init__.py +10 -0
  118. torchrl/collectors/llm/base.py +494 -0
  119. torchrl/collectors/llm/ray_collector.py +275 -0
  120. torchrl/collectors/llm/utils.py +36 -0
  121. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  122. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  123. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  124. torchrl/collectors/utils.py +433 -0
  125. torchrl/collectors/weight_update.py +591 -0
  126. torchrl/csrc/numpy_utils.h +38 -0
  127. torchrl/csrc/pybind.cpp +27 -0
  128. torchrl/csrc/segment_tree.h +458 -0
  129. torchrl/csrc/torch_utils.h +34 -0
  130. torchrl/csrc/utils.cpp +48 -0
  131. torchrl/csrc/utils.h +31 -0
  132. torchrl/data/__init__.py +187 -0
  133. torchrl/data/datasets/__init__.py +58 -0
  134. torchrl/data/datasets/atari_dqn.py +878 -0
  135. torchrl/data/datasets/common.py +281 -0
  136. torchrl/data/datasets/d4rl.py +489 -0
  137. torchrl/data/datasets/d4rl_infos.py +187 -0
  138. torchrl/data/datasets/gen_dgrl.py +375 -0
  139. torchrl/data/datasets/minari_data.py +643 -0
  140. torchrl/data/datasets/openml.py +177 -0
  141. torchrl/data/datasets/openx.py +798 -0
  142. torchrl/data/datasets/roboset.py +363 -0
  143. torchrl/data/datasets/utils.py +11 -0
  144. torchrl/data/datasets/vd4rl.py +432 -0
  145. torchrl/data/llm/__init__.py +34 -0
  146. torchrl/data/llm/dataset.py +491 -0
  147. torchrl/data/llm/history.py +1378 -0
  148. torchrl/data/llm/prompt.py +198 -0
  149. torchrl/data/llm/reward.py +225 -0
  150. torchrl/data/llm/topk.py +186 -0
  151. torchrl/data/llm/utils.py +543 -0
  152. torchrl/data/map/__init__.py +21 -0
  153. torchrl/data/map/hash.py +185 -0
  154. torchrl/data/map/query.py +204 -0
  155. torchrl/data/map/tdstorage.py +363 -0
  156. torchrl/data/map/tree.py +1434 -0
  157. torchrl/data/map/utils.py +103 -0
  158. torchrl/data/postprocs/__init__.py +8 -0
  159. torchrl/data/postprocs/postprocs.py +391 -0
  160. torchrl/data/replay_buffers/__init__.py +99 -0
  161. torchrl/data/replay_buffers/checkpointers.py +622 -0
  162. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  163. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  164. torchrl/data/replay_buffers/samplers.py +2578 -0
  165. torchrl/data/replay_buffers/scheduler.py +265 -0
  166. torchrl/data/replay_buffers/storages.py +2412 -0
  167. torchrl/data/replay_buffers/utils.py +1042 -0
  168. torchrl/data/replay_buffers/writers.py +781 -0
  169. torchrl/data/tensor_specs.py +7101 -0
  170. torchrl/data/utils.py +334 -0
  171. torchrl/envs/__init__.py +265 -0
  172. torchrl/envs/async_envs.py +1105 -0
  173. torchrl/envs/batched_envs.py +3093 -0
  174. torchrl/envs/common.py +4241 -0
  175. torchrl/envs/custom/__init__.py +11 -0
  176. torchrl/envs/custom/chess.py +617 -0
  177. torchrl/envs/custom/llm.py +214 -0
  178. torchrl/envs/custom/pendulum.py +401 -0
  179. torchrl/envs/custom/san_moves.txt +29274 -0
  180. torchrl/envs/custom/tictactoeenv.py +288 -0
  181. torchrl/envs/env_creator.py +263 -0
  182. torchrl/envs/gym_like.py +752 -0
  183. torchrl/envs/libs/__init__.py +68 -0
  184. torchrl/envs/libs/_gym_utils.py +326 -0
  185. torchrl/envs/libs/brax.py +846 -0
  186. torchrl/envs/libs/dm_control.py +544 -0
  187. torchrl/envs/libs/envpool.py +447 -0
  188. torchrl/envs/libs/gym.py +2239 -0
  189. torchrl/envs/libs/habitat.py +138 -0
  190. torchrl/envs/libs/isaac_lab.py +87 -0
  191. torchrl/envs/libs/isaacgym.py +203 -0
  192. torchrl/envs/libs/jax_utils.py +166 -0
  193. torchrl/envs/libs/jumanji.py +963 -0
  194. torchrl/envs/libs/meltingpot.py +599 -0
  195. torchrl/envs/libs/openml.py +153 -0
  196. torchrl/envs/libs/openspiel.py +652 -0
  197. torchrl/envs/libs/pettingzoo.py +1042 -0
  198. torchrl/envs/libs/procgen.py +351 -0
  199. torchrl/envs/libs/robohive.py +429 -0
  200. torchrl/envs/libs/smacv2.py +645 -0
  201. torchrl/envs/libs/unity_mlagents.py +891 -0
  202. torchrl/envs/libs/utils.py +147 -0
  203. torchrl/envs/libs/vmas.py +813 -0
  204. torchrl/envs/llm/__init__.py +63 -0
  205. torchrl/envs/llm/chat.py +730 -0
  206. torchrl/envs/llm/datasets/README.md +4 -0
  207. torchrl/envs/llm/datasets/__init__.py +17 -0
  208. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  209. torchrl/envs/llm/datasets/ifeval.py +274 -0
  210. torchrl/envs/llm/envs.py +789 -0
  211. torchrl/envs/llm/libs/README.md +3 -0
  212. torchrl/envs/llm/libs/__init__.py +8 -0
  213. torchrl/envs/llm/libs/mlgym.py +869 -0
  214. torchrl/envs/llm/reward/__init__.py +10 -0
  215. torchrl/envs/llm/reward/gsm8k.py +324 -0
  216. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  217. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  218. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  222. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  223. torchrl/envs/llm/transforms/__init__.py +55 -0
  224. torchrl/envs/llm/transforms/browser.py +292 -0
  225. torchrl/envs/llm/transforms/dataloading.py +859 -0
  226. torchrl/envs/llm/transforms/format.py +73 -0
  227. torchrl/envs/llm/transforms/kl.py +1544 -0
  228. torchrl/envs/llm/transforms/policy_version.py +189 -0
  229. torchrl/envs/llm/transforms/reason.py +323 -0
  230. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  231. torchrl/envs/llm/transforms/tools.py +1955 -0
  232. torchrl/envs/model_based/__init__.py +9 -0
  233. torchrl/envs/model_based/common.py +180 -0
  234. torchrl/envs/model_based/dreamer.py +112 -0
  235. torchrl/envs/transforms/__init__.py +147 -0
  236. torchrl/envs/transforms/functional.py +48 -0
  237. torchrl/envs/transforms/gym_transforms.py +203 -0
  238. torchrl/envs/transforms/module.py +341 -0
  239. torchrl/envs/transforms/r3m.py +372 -0
  240. torchrl/envs/transforms/ray_service.py +663 -0
  241. torchrl/envs/transforms/rb_transforms.py +214 -0
  242. torchrl/envs/transforms/transforms.py +11835 -0
  243. torchrl/envs/transforms/utils.py +94 -0
  244. torchrl/envs/transforms/vc1.py +307 -0
  245. torchrl/envs/transforms/vecnorm.py +845 -0
  246. torchrl/envs/transforms/vip.py +407 -0
  247. torchrl/envs/utils.py +1718 -0
  248. torchrl/envs/vec_envs.py +11 -0
  249. torchrl/modules/__init__.py +206 -0
  250. torchrl/modules/distributions/__init__.py +73 -0
  251. torchrl/modules/distributions/continuous.py +830 -0
  252. torchrl/modules/distributions/discrete.py +908 -0
  253. torchrl/modules/distributions/truncated_normal.py +187 -0
  254. torchrl/modules/distributions/utils.py +233 -0
  255. torchrl/modules/llm/__init__.py +62 -0
  256. torchrl/modules/llm/backends/__init__.py +65 -0
  257. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  258. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  259. torchrl/modules/llm/backends/vllm/base.py +72 -0
  260. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  261. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  262. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  263. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  264. torchrl/modules/llm/policies/__init__.py +28 -0
  265. torchrl/modules/llm/policies/common.py +1809 -0
  266. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  267. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  268. torchrl/modules/llm/utils.py +23 -0
  269. torchrl/modules/mcts/__init__.py +21 -0
  270. torchrl/modules/mcts/scores.py +579 -0
  271. torchrl/modules/models/__init__.py +86 -0
  272. torchrl/modules/models/batchrenorm.py +119 -0
  273. torchrl/modules/models/decision_transformer.py +179 -0
  274. torchrl/modules/models/exploration.py +731 -0
  275. torchrl/modules/models/llm.py +156 -0
  276. torchrl/modules/models/model_based.py +596 -0
  277. torchrl/modules/models/models.py +1712 -0
  278. torchrl/modules/models/multiagent.py +1067 -0
  279. torchrl/modules/models/recipes/impala.py +185 -0
  280. torchrl/modules/models/utils.py +162 -0
  281. torchrl/modules/planners/__init__.py +10 -0
  282. torchrl/modules/planners/cem.py +228 -0
  283. torchrl/modules/planners/common.py +73 -0
  284. torchrl/modules/planners/mppi.py +265 -0
  285. torchrl/modules/tensordict_module/__init__.py +89 -0
  286. torchrl/modules/tensordict_module/actors.py +2457 -0
  287. torchrl/modules/tensordict_module/common.py +529 -0
  288. torchrl/modules/tensordict_module/exploration.py +814 -0
  289. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  290. torchrl/modules/tensordict_module/rnn.py +1639 -0
  291. torchrl/modules/tensordict_module/sequence.py +132 -0
  292. torchrl/modules/tensordict_module/world_models.py +34 -0
  293. torchrl/modules/utils/__init__.py +38 -0
  294. torchrl/modules/utils/mappings.py +9 -0
  295. torchrl/modules/utils/utils.py +89 -0
  296. torchrl/objectives/__init__.py +78 -0
  297. torchrl/objectives/a2c.py +659 -0
  298. torchrl/objectives/common.py +753 -0
  299. torchrl/objectives/cql.py +1346 -0
  300. torchrl/objectives/crossq.py +710 -0
  301. torchrl/objectives/ddpg.py +453 -0
  302. torchrl/objectives/decision_transformer.py +371 -0
  303. torchrl/objectives/deprecated.py +516 -0
  304. torchrl/objectives/dqn.py +683 -0
  305. torchrl/objectives/dreamer.py +488 -0
  306. torchrl/objectives/functional.py +48 -0
  307. torchrl/objectives/gail.py +258 -0
  308. torchrl/objectives/iql.py +996 -0
  309. torchrl/objectives/llm/__init__.py +30 -0
  310. torchrl/objectives/llm/grpo.py +846 -0
  311. torchrl/objectives/llm/sft.py +482 -0
  312. torchrl/objectives/multiagent/__init__.py +8 -0
  313. torchrl/objectives/multiagent/qmixer.py +396 -0
  314. torchrl/objectives/ppo.py +1669 -0
  315. torchrl/objectives/redq.py +683 -0
  316. torchrl/objectives/reinforce.py +530 -0
  317. torchrl/objectives/sac.py +1580 -0
  318. torchrl/objectives/td3.py +570 -0
  319. torchrl/objectives/td3_bc.py +625 -0
  320. torchrl/objectives/utils.py +782 -0
  321. torchrl/objectives/value/__init__.py +28 -0
  322. torchrl/objectives/value/advantages.py +1956 -0
  323. torchrl/objectives/value/functional.py +1459 -0
  324. torchrl/objectives/value/utils.py +360 -0
  325. torchrl/record/__init__.py +17 -0
  326. torchrl/record/loggers/__init__.py +23 -0
  327. torchrl/record/loggers/common.py +48 -0
  328. torchrl/record/loggers/csv.py +226 -0
  329. torchrl/record/loggers/mlflow.py +142 -0
  330. torchrl/record/loggers/tensorboard.py +139 -0
  331. torchrl/record/loggers/trackio.py +163 -0
  332. torchrl/record/loggers/utils.py +78 -0
  333. torchrl/record/loggers/wandb.py +214 -0
  334. torchrl/record/recorder.py +554 -0
  335. torchrl/services/__init__.py +79 -0
  336. torchrl/services/base.py +109 -0
  337. torchrl/services/ray_service.py +453 -0
  338. torchrl/testing/__init__.py +107 -0
  339. torchrl/testing/assertions.py +179 -0
  340. torchrl/testing/dist_utils.py +122 -0
  341. torchrl/testing/env_creators.py +227 -0
  342. torchrl/testing/env_helper.py +35 -0
  343. torchrl/testing/gym_helpers.py +156 -0
  344. torchrl/testing/llm_mocks.py +119 -0
  345. torchrl/testing/mocking_classes.py +2720 -0
  346. torchrl/testing/modules.py +295 -0
  347. torchrl/testing/mp_helpers.py +15 -0
  348. torchrl/testing/ray_helpers.py +293 -0
  349. torchrl/testing/utils.py +190 -0
  350. torchrl/trainers/__init__.py +42 -0
  351. torchrl/trainers/algorithms/__init__.py +11 -0
  352. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  353. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  354. torchrl/trainers/algorithms/configs/common.py +41 -0
  355. torchrl/trainers/algorithms/configs/data.py +308 -0
  356. torchrl/trainers/algorithms/configs/envs.py +104 -0
  357. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  358. torchrl/trainers/algorithms/configs/logging.py +80 -0
  359. torchrl/trainers/algorithms/configs/modules.py +570 -0
  360. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  361. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  362. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  363. torchrl/trainers/algorithms/configs/utils.py +252 -0
  364. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  365. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  366. torchrl/trainers/algorithms/ppo.py +373 -0
  367. torchrl/trainers/algorithms/sac.py +308 -0
  368. torchrl/trainers/helpers/__init__.py +40 -0
  369. torchrl/trainers/helpers/collectors.py +416 -0
  370. torchrl/trainers/helpers/envs.py +573 -0
  371. torchrl/trainers/helpers/logger.py +33 -0
  372. torchrl/trainers/helpers/losses.py +132 -0
  373. torchrl/trainers/helpers/models.py +658 -0
  374. torchrl/trainers/helpers/replay_buffer.py +59 -0
  375. torchrl/trainers/helpers/trainers.py +301 -0
  376. torchrl/trainers/trainers.py +2052 -0
  377. torchrl/weight_update/__init__.py +33 -0
  378. torchrl/weight_update/_distributed.py +749 -0
  379. torchrl/weight_update/_mp.py +624 -0
  380. torchrl/weight_update/_noupdate.py +102 -0
  381. torchrl/weight_update/_ray.py +1032 -0
  382. torchrl/weight_update/_rpc.py +284 -0
  383. torchrl/weight_update/_shared.py +891 -0
  384. torchrl/weight_update/llm/__init__.py +32 -0
  385. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  386. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  387. torchrl/weight_update/utils.py +73 -0
  388. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  389. torchrl-0.11.0.dist-info/LICENSE +21 -0
  390. torchrl-0.11.0.dist-info/METADATA +1307 -0
  391. torchrl-0.11.0.dist-info/RECORD +394 -0
  392. torchrl-0.11.0.dist-info/WHEEL +5 -0
  393. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  394. torchrl-0.11.0.dist-info/top_level.txt +7 -0
@@ -0,0 +1,437 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from __future__ import annotations
7
+
8
+ import gc
9
+ import os
10
+ import time
11
+ from functools import partial
12
+ from pathlib import Path
13
+
14
+ import hydra
15
+
16
+ from torchrl import merge_ray_runtime_env, torchrl_logger
17
+ from torchrl.data.llm.history import History
18
+ from torchrl.record.loggers.wandb import WandbLogger
19
+ from torchrl.weight_update.llm import get_model_metadata
20
+
21
+ try:
22
+ import ray
23
+ except ImportError:
24
+ raise ImportError(
25
+ "Ray is required for async training. Please install ray with `pip install ray`."
26
+ )
27
+ import torch
28
+ import tqdm
29
+
30
+ from grpo_utils import (
31
+ add_kl_transforms_to_replay_buffer,
32
+ check_grpo_dependencies,
33
+ compute_device_allocation,
34
+ get_inference_model,
35
+ get_train_model,
36
+ log_training_metrics,
37
+ make_env,
38
+ make_weight_sync_scheme,
39
+ )
40
+ from omegaconf import DictConfig
41
+
42
+ try:
43
+ from tensordict import set_list_to_stack
44
+ except ImportError:
45
+ raise ImportError(
46
+ "TensorDict is required. Please install it with `pip install tensordict`."
47
+ )
48
+ from torch.amp.autocast_mode import autocast
49
+ from torch.amp.grad_scaler import GradScaler
50
+ from torchrl._utils import timeit
51
+ from torchrl.collectors.llm import RayLLMCollector
52
+ from torchrl.data import LazyStackStorage, ReplayBuffer
53
+ from torchrl.data.replay_buffers.ray_buffer import RayReplayBuffer
54
+ from torchrl.objectives.llm.grpo import GRPOLoss, MCAdvantage
55
+
56
+
57
def setup_environment() -> None:
    """Setup required environment variables and configurations.

    Sets float32 as the default dtype (mixed precision is handled via
    autocast/GradScaler elsewhere), pins ``cuda:0`` as the default device,
    and enables tensordict's list-to-stack behavior.

    Raises:
        RuntimeError: If CUDA is not available — GRPO async training
            requires a GPU.
    """
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is required for training")

    # Set default dtype to float32 for mixed precision training
    torch.set_default_dtype(torch.float32)
    torch.set_default_device("cuda:0")
    set_list_to_stack(True).set()

    # CUDA availability is guaranteed here (checked above), so pin the
    # current device directly. (The original re-tested availability in a
    # redundant, always-true guard.)
    torch.cuda.set_device("cuda:0")
71
+
72
def train(
    replay_buffer: ReplayBuffer,
    cfg: DictConfig,
    collectors: list[RayLLMCollector],
    inference_policy,
    devices: list[int] | None = None,
):
    """Main training loop for GRPO async.

    This function implements asynchronous training where data collection and optimization
    happen concurrently. The total number of steps is determined by the number of epochs,
    samples per epoch, and batches collected.

    Args:
        replay_buffer: The replay buffer to store experiences
        cfg: The configuration object containing training parameters
        collectors: The collectors objects.
        inference_policy: The inference policy wrapper; its ``model`` attribute
            (the vLLM engine) is used to set up weight synchronization.
        devices: The devices to use for the training model.
    """
    # Setup training model and tokenizer
    policy_training, train_tokenizer = get_train_model(cfg, devices=devices)
    train_device = torch.device(f"cuda:{devices[0]}" if devices else "cuda:0")

    # Setup loss function. The KL-to-ref coefficient is only applied inside
    # the loss when both kl_coef_in_loss and use_kl_to_ref are enabled.
    loss_fn = GRPOLoss(
        actor_network=policy_training,
        kl_to_ref_coeff=cfg.train.kl_to_ref_coeff
        if (cfg.train.kl_coef_in_loss and cfg.train.use_kl_to_ref)
        else 0.0,
        kl_to_inference_coeff=cfg.train.kl_to_inference_coeff,
        entropy_coeff=cfg.train.entropy_coeff,
        masking_strategy="rlhf" if cfg.env.reasoning else "sft",
        device=train_device,
    )
    if cfg.env.reasoning:
        # TODO: this is clunky, we should find a way to do this more naturally
        loss_fn.set_keys(sample_log_prob=("next", "log_probs", "full"))
    if cfg.model.compile:
        loss_fn = torch.compile(loss_fn)

    vllm_engine = inference_policy.model

    # Create weight sync scheme for the collectors
    weight_sync_scheme = make_weight_sync_scheme(vllm_engine=vllm_engine)

    # Set up weight sync scheme for collectors
    # Note: We need to get the sender after the collectors are created
    # For now, we'll update the collectors to use the scheme
    torchrl_logger.info("Setting up weight synchronization scheme...")

    # We'll need to manually set up the sender since collectors were already created
    # without the scheme. In production, collectors should be created with weight_sync_schemes parameter.
    sender = weight_sync_scheme.create_sender()
    sender.register_model(policy_training)

    # Initialize collective group
    torchrl_logger.info("Initializing collective group...")
    metadata = get_model_metadata(policy_training)
    sender.init_all_workers_group(metadata, vllm_engine=vllm_engine)

    # First weight update
    with timeit("update_policy_weights"):
        sender.update_weights()
    torchrl_logger.info("Completed first update_policy_weights. Starting collectors...")
    timeit.print(prefix="First update_policy_weights_ time")
    timeit.reset()

    for i, collector in enumerate(collectors):
        torchrl_logger.info(f"Starting collector {i}...")
        collector.start()

    # Block until the (async) collectors have produced at least one batch.
    while not replay_buffer.write_count:
        torchrl_logger.info("Waiting for replay buffer...")
        time.sleep(1)

    # Make optimizer
    optimizer = torch.optim.Adam(
        policy_training.parameters(),
        lr=cfg.optimizer.lr,
        weight_decay=cfg.optimizer.weight_decay,
        fused=False,
    )
    # Loss scaling is only needed for float16; bf16/fp32 skip the scaler path.
    # This condition is loop-invariant, so compute it once.
    use_fp16_scaling = (
        cfg.train.mixed_precision and cfg.train_model.torch_dtype == "float16"
    )
    # A single persistent scaler: its dynamic loss-scale state must survive
    # across iterations. (Previously a fresh GradScaler was constructed inside
    # the backward pass on every step, which reset that state each time.)
    scaler = GradScaler(enabled=cfg.train.mixed_precision)

    # Make checkpoint dir
    checkpoint_dir = Path(cfg.logging.checkpoint_dir)
    checkpoint_dir.mkdir(parents=True, exist_ok=True)

    # Make wandb logger
    experiment_name = cfg.logging.experiment_name
    if experiment_name is not None:
        experiment_name = [experiment_name]
    else:
        experiment_name = []

    experiment_name.append(cfg.env.dataset)
    experiment_name.append(cfg.model.name)
    wandb_logger = WandbLogger(
        project="grpo-async", exp_name="-".join(["grpo-async"] + experiment_name)
    )

    # Training loop. The double-negative floor division is ceil division:
    # ceil(total_dialog_turns / optim_batch_size) * epochs.
    total_steps = (
        -(cfg.train.total_dialog_turns // -cfg.train.optim_batch_size)
        * cfg.train.epochs
    )
    torchrl_logger.info(f"Total steps: {total_steps}")

    pbar = tqdm.tqdm(total=total_steps)
    grad_norm = 0.0  # Initialize grad_norm
    data_read_count = 0
    start_time = time.time()

    for step in range(total_steps):
        if not any(collector.is_running() for collector in collectors):
            torchrl_logger.info("Collectors stopped, stopping training")
            break
        pbar.update(1)
        pbar.set_description(f"Step {step}, writes: {replay_buffer.write_count}")

        with timeit("sampling"):
            # Sample the correct batch size for gradient accumulation
            # The replay buffer is configured with batch_size = optim_batch_size // gradient_accumulation_steps
            # So we should sample that amount per step, not the full optim_batch_size
            batch_size_per_step = (
                cfg.train.optim_batch_size // cfg.train.gradient_accumulation_steps
            )
            batch = replay_buffer.sample(batch_size_per_step).to(train_device)
            history: History = batch.view(-1)[0]["history", "full"]
            history_str: list[str] | str = history.apply_chat_template(
                tokenizer=train_tokenizer
            )
            # apply_chat_template may return nested lists of strings; flatten
            # until a single string remains (used only for logging).
            while not isinstance(history_str, str):
                history_str = "\n".join(history_str)

            data_read_count += batch.numel()

        with timeit("forward_pass"):
            with autocast("cuda", enabled=cfg.train.mixed_precision):
                loss = loss_fn(batch)
                # Scale the loss down so gradients accumulate to the average
                # over the effective optim batch.
                loss_val = (
                    loss.mean(reduce=True) / cfg.train.gradient_accumulation_steps
                )

        with timeit("backward_pass"):
            if use_fp16_scaling:
                # Reuse the persistent scaler so its dynamic loss-scale state
                # carries over between steps.
                scaler.scale(loss_val).backward()
            else:
                loss_val.backward()

        # Only step the optimizer every gradient_accumulation_steps backward passes.
        if (step + 1) % cfg.train.gradient_accumulation_steps == 0:
            with timeit("optim_step"):
                if use_fp16_scaling:
                    # Unscale before clipping so the clip threshold applies to
                    # the true gradient magnitudes.
                    scaler.unscale_(optimizer)

                grad_norm = torch.nn.utils.clip_grad_norm_(
                    policy_training.parameters(),
                    cfg.optimizer.clip_grad_norm,
                )

                if use_fp16_scaling:
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    optimizer.step()
                optimizer.zero_grad(set_to_none=True)

        if (step % cfg.train.logging_frequency) == 0:
            log_training_metrics(
                wandb_logger=wandb_logger,
                replay_buffer=replay_buffer,
                batch=batch,
                loss=loss,
                grad_norm=grad_norm,
                global_step=step,
                data_read_count=data_read_count,
                collector=collectors[0],
                start_time=start_time,
                gradient_accumulation_steps=cfg.train.gradient_accumulation_steps,
                history_str=history_str,
                use_kl_to_ref=cfg.train.use_kl_to_ref,
            )

        if step % cfg.train.weight_update_frequency == 0:
            with timeit("update_policy_weights"):
                torchrl_logger.info("Updating policy weights...")
                sender.update_weights()
            # TODO: do we need this? Does it interfere with other processes?
            # torch.cuda.empty_cache()
            gc.collect()

        # Checkpointing intentionally disabled to prevent disk space issues;
        # re-enable by saving {step, model/optimizer/scaler state_dicts, cfg}
        # to checkpoint_dir every cfg.train.checkpoint_frequency steps.

        if step % cfg.train.weight_update_frequency == 0:
            timeit.print(prefix="timeit")
            for key, val in timeit.todict().items():
                wandb_logger.log_scalar(f"timeit/{key}", val)
            timeit.reset()

        del loss_val
        # TODO: do we need this? Does it interfere with other processes?
        # torch.cuda.empty_cache()
        gc.collect()

    pbar.close()
    # Shut down every collector. (The original called shutdown() on the
    # leaked loop variable, stopping only the last collector.)
    for collector in collectors:
        collector.shutdown()
298
+
299
+
300
@hydra.main(version_base=None, config_path="config", config_name="grpo_gsm8k")
def main(cfg):
    """Async-mode GRPO entry point.

    Validates the Hydra config, initializes Ray (unless already running or
    managed externally), builds the inference policy, a Ray-hosted replay
    buffer, and one Ray collector per environment, then launches the
    ``train`` function as a Ray task and blocks until it completes.

    Args:
        cfg: Hydra/OmegaConf config (``config/grpo_gsm8k.yaml`` by default).

    Raises:
        ValueError: if sync mode is requested, if any required
            ``*_model.num_devices`` is unset, if ``optim_batch_size`` is not
            divisible by ``gradient_accumulation_steps``, or if ``sync_iter``
            is set (unsupported in async mode).
    """
    # Check for required GRPO dependencies (fails fast before any Ray setup).
    check_grpo_dependencies()

    # Force async mode: this script only supports the async training loop;
    # the sync variant lives in a sibling script.
    if cfg.train.sync:
        raise ValueError(
            "grpo-async.py must run in async mode (`python grpo-async.py mode=async`). Please use grpo-sync.py for sync mode (`python grpo-sync.py mode=sync`)."
        )

    # Compute device allocation (project helper; returns at least
    # "ray_num_gpus", "inference_model_devices", "train_model_devices").
    device_config = compute_device_allocation(cfg)

    if not ray.is_initialized():
        # Convert OmegaConf to regular dict and filter out unsupported
        # parameters (keys starting with "_" are treated as private/meta).
        ray_init_config = {
            k: dict(v) if isinstance(v, DictConfig) else v
            for k, v in dict(cfg.ray.init_config).items()
            if not k.startswith("_")
        }

        # Add computed GPU configuration and merge with default runtime_env.
        ray_init_config["num_gpus"] = device_config["ray_num_gpus"]
        ray_init_config = merge_ray_runtime_env(ray_init_config)
        torchrl_logger.info(f"Ray init config: {ray_init_config=}")
        # NOTE(review): truthiness of the raw env-var string — any non-empty
        # value (including "0" or "false") selects the external cluster path.
        ray_managed_externally = os.environ.get("RAY_CLUSTER_MANAGED_EXTERNALLY")
        if ray_managed_externally:
            # Attach to an existing cluster instead of starting a local one.
            ray.init(address="auto")
        else:
            ray.init(**ray_init_config)

    # Check that device counts are explicitly set; there is no implicit
    # default for any of the three model roles.
    if cfg.inference_model.num_devices is None:
        raise ValueError(
            "Inference model num_devices must be set via inference_model.num_devices"
        )
    # The reference model is only needed when a KL penalty to it is used.
    if cfg.train.use_kl_to_ref and cfg.ref_model.num_devices is None:
        raise ValueError(
            "Ref model num_devices must be set via ref_model.num_devices when use_kl_to_ref is True"
        )
    if cfg.train_model.num_devices is None:
        raise ValueError(
            "Train model num_devices must be set via train_model.num_devices"
        )

    # Convert OmegaConf to regular dicts so they can be passed to Ray actors.
    replay_buffer_config = dict(cfg.ray.replay_buffer_config)
    collector_config = dict(cfg.ray.collector_config)
    train_handler_config = dict(cfg.ray.train_handler_config)

    # Shared inference policy: all collectors below use this same object.
    inference_policy = get_inference_model(
        cfg,
        devices=device_config["inference_model_devices"],
    )
    torchrl_logger.info(f"Inference policy: {inference_policy}")

    torchrl_logger.info(f"Starting replay buffer with {replay_buffer_config=}")
    # Sampling below uses floor division by gradient_accumulation_steps, so
    # require exact divisibility to avoid silently dropping samples.
    if cfg.train.optim_batch_size % cfg.train.gradient_accumulation_steps != 0:
        raise ValueError(
            "optim_batch_size must be divisible by gradient_accumulation_steps"
        )
    rb = RayReplayBuffer(
        # Storage capacity: explicit buffer_size, else one full round of
        # prompts (repeats * num_envs).
        storage=partial(
            LazyStackStorage,
            cfg.train.buffer_size
            if cfg.train.buffer_size
            else cfg.env.repeats * cfg.env.num_envs,
        ),
        # Monte-Carlo advantage computed over each group of `repeats` rollouts.
        transform_factory=partial(MCAdvantage, grpo_size=cfg.env.repeats),
        # Per-sample micro-batch; max(1, ...) guards tiny configurations.
        batch_size=max(
            1, cfg.train.optim_batch_size // cfg.train.gradient_accumulation_steps
        ),
        remote_config=replay_buffer_config,
    )

    # Optionally attach KL-to-ref transforms (project helper; no-op when the
    # config disables KL — TODO confirm against its definition).
    add_kl_transforms_to_replay_buffer(rb, cfg)

    torchrl_logger.info(f"Replay buffer: {rb}")

    # Collectors do generation through the vLLM-style inference policy, so
    # they request no GPUs of their own; 2 CPUs each for env stepping.
    collector_config["num_gpus"] = 0
    collector_config["num_cpus"] = 2
    torchrl_logger.info(f"Starting collector with {collector_config=}")

    if cfg.train.sync_iter is not None:
        raise ValueError("sync_iter is not supported in async mode.")
    collectors = []
    for i in tqdm.trange(cfg.env.num_envs, desc="Starting collectors"):
        collector = RayLLMCollector(
            env=partial(make_env, cfg, single_env=True),
            policy=inference_policy,
            dialog_turns_per_batch=cfg.train.dialog_turns_per_batch,
            total_dialog_turns=cfg.train.total_dialog_turns,
            replay_buffer=rb,
            # Ray is already initialized above; the collector must not re-init.
            ray_init_config=None,
            weight_updater=None,
            track_policy_version=True,
            remote_config=collector_config,
            yield_only_last_steps=cfg.env.reasoning,
            verbose=False,
        )
        collectors.append(collector)
        if i == 0:
            # Wait for the first collector to initialize fully before
            # spawning the rest (serializes one-time heavy setup).
            ray.get(collector._collector.is_initialized.remote())
    # Fan out initialization checks for the remaining collectors and wait on
    # all of them at once.
    inits = []
    for collector in tqdm.tqdm(
        collectors[1:], desc="Checking collector initialization"
    ):
        inits.append(collector._collector.is_initialized.remote())
    ray.get(inits)
    torchrl_logger.info("All collectors initialized")

    # NOTE(review): this rebuild keeps only num_cpus from
    # cfg.ray.train_handler_config and discards any other keys it may carry;
    # num_gpus is always taken from train_model.num_devices — confirm that
    # dropping extra keys is intended.
    train_handler_config = {
        "num_cpus": train_handler_config.get("num_cpus", 1),
        "num_gpus": cfg.train_model.num_devices,
    }
    torchrl_logger.info(f"Starting training handler with {train_handler_config=}")
    # Wrap the module-level `train` function as a Ray remote task with the
    # resource requirements above.
    train_handler = ray.remote(
        **train_handler_config,
    )(train)

    # Launch training and block until the remote task finishes.
    ray.get(
        train_handler.remote(
            rb,
            cfg,
            collectors,
            inference_policy,
            devices=device_config["train_model_devices"],
        )
    )
432
+
433
+
434
if __name__ == "__main__":
    # Prepare the process environment (project helper) before Hydra takes
    # over argument parsing and working-directory management in main().
    setup_environment()
    main()