torchrl 0.11.0__cp314-cp314-manylinux_2_28_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (394)
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/__init__.py +144 -0
  97. torchrl/_extension.py +74 -0
  98. torchrl/_torchrl.cpython-314-aarch64-linux-gnu.so +0 -0
  99. torchrl/_utils.py +1431 -0
  100. torchrl/collectors/__init__.py +48 -0
  101. torchrl/collectors/_base.py +1058 -0
  102. torchrl/collectors/_constants.py +88 -0
  103. torchrl/collectors/_multi_async.py +324 -0
  104. torchrl/collectors/_multi_base.py +1805 -0
  105. torchrl/collectors/_multi_sync.py +464 -0
  106. torchrl/collectors/_runner.py +581 -0
  107. torchrl/collectors/_single.py +2009 -0
  108. torchrl/collectors/_single_async.py +259 -0
  109. torchrl/collectors/collectors.py +62 -0
  110. torchrl/collectors/distributed/__init__.py +32 -0
  111. torchrl/collectors/distributed/default_configs.py +133 -0
  112. torchrl/collectors/distributed/generic.py +1306 -0
  113. torchrl/collectors/distributed/ray.py +1092 -0
  114. torchrl/collectors/distributed/rpc.py +1006 -0
  115. torchrl/collectors/distributed/sync.py +731 -0
  116. torchrl/collectors/distributed/utils.py +160 -0
  117. torchrl/collectors/llm/__init__.py +10 -0
  118. torchrl/collectors/llm/base.py +494 -0
  119. torchrl/collectors/llm/ray_collector.py +275 -0
  120. torchrl/collectors/llm/utils.py +36 -0
  121. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  122. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  123. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  124. torchrl/collectors/utils.py +433 -0
  125. torchrl/collectors/weight_update.py +591 -0
  126. torchrl/csrc/numpy_utils.h +38 -0
  127. torchrl/csrc/pybind.cpp +27 -0
  128. torchrl/csrc/segment_tree.h +458 -0
  129. torchrl/csrc/torch_utils.h +34 -0
  130. torchrl/csrc/utils.cpp +48 -0
  131. torchrl/csrc/utils.h +31 -0
  132. torchrl/data/__init__.py +187 -0
  133. torchrl/data/datasets/__init__.py +58 -0
  134. torchrl/data/datasets/atari_dqn.py +878 -0
  135. torchrl/data/datasets/common.py +281 -0
  136. torchrl/data/datasets/d4rl.py +489 -0
  137. torchrl/data/datasets/d4rl_infos.py +187 -0
  138. torchrl/data/datasets/gen_dgrl.py +375 -0
  139. torchrl/data/datasets/minari_data.py +643 -0
  140. torchrl/data/datasets/openml.py +177 -0
  141. torchrl/data/datasets/openx.py +798 -0
  142. torchrl/data/datasets/roboset.py +363 -0
  143. torchrl/data/datasets/utils.py +11 -0
  144. torchrl/data/datasets/vd4rl.py +432 -0
  145. torchrl/data/llm/__init__.py +34 -0
  146. torchrl/data/llm/dataset.py +491 -0
  147. torchrl/data/llm/history.py +1378 -0
  148. torchrl/data/llm/prompt.py +198 -0
  149. torchrl/data/llm/reward.py +225 -0
  150. torchrl/data/llm/topk.py +186 -0
  151. torchrl/data/llm/utils.py +543 -0
  152. torchrl/data/map/__init__.py +21 -0
  153. torchrl/data/map/hash.py +185 -0
  154. torchrl/data/map/query.py +204 -0
  155. torchrl/data/map/tdstorage.py +363 -0
  156. torchrl/data/map/tree.py +1434 -0
  157. torchrl/data/map/utils.py +103 -0
  158. torchrl/data/postprocs/__init__.py +8 -0
  159. torchrl/data/postprocs/postprocs.py +391 -0
  160. torchrl/data/replay_buffers/__init__.py +99 -0
  161. torchrl/data/replay_buffers/checkpointers.py +622 -0
  162. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  163. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  164. torchrl/data/replay_buffers/samplers.py +2578 -0
  165. torchrl/data/replay_buffers/scheduler.py +265 -0
  166. torchrl/data/replay_buffers/storages.py +2412 -0
  167. torchrl/data/replay_buffers/utils.py +1042 -0
  168. torchrl/data/replay_buffers/writers.py +781 -0
  169. torchrl/data/tensor_specs.py +7101 -0
  170. torchrl/data/utils.py +334 -0
  171. torchrl/envs/__init__.py +265 -0
  172. torchrl/envs/async_envs.py +1105 -0
  173. torchrl/envs/batched_envs.py +3093 -0
  174. torchrl/envs/common.py +4241 -0
  175. torchrl/envs/custom/__init__.py +11 -0
  176. torchrl/envs/custom/chess.py +617 -0
  177. torchrl/envs/custom/llm.py +214 -0
  178. torchrl/envs/custom/pendulum.py +401 -0
  179. torchrl/envs/custom/san_moves.txt +29274 -0
  180. torchrl/envs/custom/tictactoeenv.py +288 -0
  181. torchrl/envs/env_creator.py +263 -0
  182. torchrl/envs/gym_like.py +752 -0
  183. torchrl/envs/libs/__init__.py +68 -0
  184. torchrl/envs/libs/_gym_utils.py +326 -0
  185. torchrl/envs/libs/brax.py +846 -0
  186. torchrl/envs/libs/dm_control.py +544 -0
  187. torchrl/envs/libs/envpool.py +447 -0
  188. torchrl/envs/libs/gym.py +2239 -0
  189. torchrl/envs/libs/habitat.py +138 -0
  190. torchrl/envs/libs/isaac_lab.py +87 -0
  191. torchrl/envs/libs/isaacgym.py +203 -0
  192. torchrl/envs/libs/jax_utils.py +166 -0
  193. torchrl/envs/libs/jumanji.py +963 -0
  194. torchrl/envs/libs/meltingpot.py +599 -0
  195. torchrl/envs/libs/openml.py +153 -0
  196. torchrl/envs/libs/openspiel.py +652 -0
  197. torchrl/envs/libs/pettingzoo.py +1042 -0
  198. torchrl/envs/libs/procgen.py +351 -0
  199. torchrl/envs/libs/robohive.py +429 -0
  200. torchrl/envs/libs/smacv2.py +645 -0
  201. torchrl/envs/libs/unity_mlagents.py +891 -0
  202. torchrl/envs/libs/utils.py +147 -0
  203. torchrl/envs/libs/vmas.py +813 -0
  204. torchrl/envs/llm/__init__.py +63 -0
  205. torchrl/envs/llm/chat.py +730 -0
  206. torchrl/envs/llm/datasets/README.md +4 -0
  207. torchrl/envs/llm/datasets/__init__.py +17 -0
  208. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  209. torchrl/envs/llm/datasets/ifeval.py +274 -0
  210. torchrl/envs/llm/envs.py +789 -0
  211. torchrl/envs/llm/libs/README.md +3 -0
  212. torchrl/envs/llm/libs/__init__.py +8 -0
  213. torchrl/envs/llm/libs/mlgym.py +869 -0
  214. torchrl/envs/llm/reward/__init__.py +10 -0
  215. torchrl/envs/llm/reward/gsm8k.py +324 -0
  216. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  217. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  218. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  222. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  223. torchrl/envs/llm/transforms/__init__.py +55 -0
  224. torchrl/envs/llm/transforms/browser.py +292 -0
  225. torchrl/envs/llm/transforms/dataloading.py +859 -0
  226. torchrl/envs/llm/transforms/format.py +73 -0
  227. torchrl/envs/llm/transforms/kl.py +1544 -0
  228. torchrl/envs/llm/transforms/policy_version.py +189 -0
  229. torchrl/envs/llm/transforms/reason.py +323 -0
  230. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  231. torchrl/envs/llm/transforms/tools.py +1955 -0
  232. torchrl/envs/model_based/__init__.py +9 -0
  233. torchrl/envs/model_based/common.py +180 -0
  234. torchrl/envs/model_based/dreamer.py +112 -0
  235. torchrl/envs/transforms/__init__.py +147 -0
  236. torchrl/envs/transforms/functional.py +48 -0
  237. torchrl/envs/transforms/gym_transforms.py +203 -0
  238. torchrl/envs/transforms/module.py +341 -0
  239. torchrl/envs/transforms/r3m.py +372 -0
  240. torchrl/envs/transforms/ray_service.py +663 -0
  241. torchrl/envs/transforms/rb_transforms.py +214 -0
  242. torchrl/envs/transforms/transforms.py +11835 -0
  243. torchrl/envs/transforms/utils.py +94 -0
  244. torchrl/envs/transforms/vc1.py +307 -0
  245. torchrl/envs/transforms/vecnorm.py +845 -0
  246. torchrl/envs/transforms/vip.py +407 -0
  247. torchrl/envs/utils.py +1718 -0
  248. torchrl/envs/vec_envs.py +11 -0
  249. torchrl/modules/__init__.py +206 -0
  250. torchrl/modules/distributions/__init__.py +73 -0
  251. torchrl/modules/distributions/continuous.py +830 -0
  252. torchrl/modules/distributions/discrete.py +908 -0
  253. torchrl/modules/distributions/truncated_normal.py +187 -0
  254. torchrl/modules/distributions/utils.py +233 -0
  255. torchrl/modules/llm/__init__.py +62 -0
  256. torchrl/modules/llm/backends/__init__.py +65 -0
  257. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  258. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  259. torchrl/modules/llm/backends/vllm/base.py +72 -0
  260. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  261. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  262. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  263. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  264. torchrl/modules/llm/policies/__init__.py +28 -0
  265. torchrl/modules/llm/policies/common.py +1809 -0
  266. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  267. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  268. torchrl/modules/llm/utils.py +23 -0
  269. torchrl/modules/mcts/__init__.py +21 -0
  270. torchrl/modules/mcts/scores.py +579 -0
  271. torchrl/modules/models/__init__.py +86 -0
  272. torchrl/modules/models/batchrenorm.py +119 -0
  273. torchrl/modules/models/decision_transformer.py +179 -0
  274. torchrl/modules/models/exploration.py +731 -0
  275. torchrl/modules/models/llm.py +156 -0
  276. torchrl/modules/models/model_based.py +596 -0
  277. torchrl/modules/models/models.py +1712 -0
  278. torchrl/modules/models/multiagent.py +1067 -0
  279. torchrl/modules/models/recipes/impala.py +185 -0
  280. torchrl/modules/models/utils.py +162 -0
  281. torchrl/modules/planners/__init__.py +10 -0
  282. torchrl/modules/planners/cem.py +228 -0
  283. torchrl/modules/planners/common.py +73 -0
  284. torchrl/modules/planners/mppi.py +265 -0
  285. torchrl/modules/tensordict_module/__init__.py +89 -0
  286. torchrl/modules/tensordict_module/actors.py +2457 -0
  287. torchrl/modules/tensordict_module/common.py +529 -0
  288. torchrl/modules/tensordict_module/exploration.py +814 -0
  289. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  290. torchrl/modules/tensordict_module/rnn.py +1639 -0
  291. torchrl/modules/tensordict_module/sequence.py +132 -0
  292. torchrl/modules/tensordict_module/world_models.py +34 -0
  293. torchrl/modules/utils/__init__.py +38 -0
  294. torchrl/modules/utils/mappings.py +9 -0
  295. torchrl/modules/utils/utils.py +89 -0
  296. torchrl/objectives/__init__.py +78 -0
  297. torchrl/objectives/a2c.py +659 -0
  298. torchrl/objectives/common.py +753 -0
  299. torchrl/objectives/cql.py +1346 -0
  300. torchrl/objectives/crossq.py +710 -0
  301. torchrl/objectives/ddpg.py +453 -0
  302. torchrl/objectives/decision_transformer.py +371 -0
  303. torchrl/objectives/deprecated.py +516 -0
  304. torchrl/objectives/dqn.py +683 -0
  305. torchrl/objectives/dreamer.py +488 -0
  306. torchrl/objectives/functional.py +48 -0
  307. torchrl/objectives/gail.py +258 -0
  308. torchrl/objectives/iql.py +996 -0
  309. torchrl/objectives/llm/__init__.py +30 -0
  310. torchrl/objectives/llm/grpo.py +846 -0
  311. torchrl/objectives/llm/sft.py +482 -0
  312. torchrl/objectives/multiagent/__init__.py +8 -0
  313. torchrl/objectives/multiagent/qmixer.py +396 -0
  314. torchrl/objectives/ppo.py +1669 -0
  315. torchrl/objectives/redq.py +683 -0
  316. torchrl/objectives/reinforce.py +530 -0
  317. torchrl/objectives/sac.py +1580 -0
  318. torchrl/objectives/td3.py +570 -0
  319. torchrl/objectives/td3_bc.py +625 -0
  320. torchrl/objectives/utils.py +782 -0
  321. torchrl/objectives/value/__init__.py +28 -0
  322. torchrl/objectives/value/advantages.py +1956 -0
  323. torchrl/objectives/value/functional.py +1459 -0
  324. torchrl/objectives/value/utils.py +360 -0
  325. torchrl/record/__init__.py +17 -0
  326. torchrl/record/loggers/__init__.py +23 -0
  327. torchrl/record/loggers/common.py +48 -0
  328. torchrl/record/loggers/csv.py +226 -0
  329. torchrl/record/loggers/mlflow.py +142 -0
  330. torchrl/record/loggers/tensorboard.py +139 -0
  331. torchrl/record/loggers/trackio.py +163 -0
  332. torchrl/record/loggers/utils.py +78 -0
  333. torchrl/record/loggers/wandb.py +214 -0
  334. torchrl/record/recorder.py +554 -0
  335. torchrl/services/__init__.py +79 -0
  336. torchrl/services/base.py +109 -0
  337. torchrl/services/ray_service.py +453 -0
  338. torchrl/testing/__init__.py +107 -0
  339. torchrl/testing/assertions.py +179 -0
  340. torchrl/testing/dist_utils.py +122 -0
  341. torchrl/testing/env_creators.py +227 -0
  342. torchrl/testing/env_helper.py +35 -0
  343. torchrl/testing/gym_helpers.py +156 -0
  344. torchrl/testing/llm_mocks.py +119 -0
  345. torchrl/testing/mocking_classes.py +2720 -0
  346. torchrl/testing/modules.py +295 -0
  347. torchrl/testing/mp_helpers.py +15 -0
  348. torchrl/testing/ray_helpers.py +293 -0
  349. torchrl/testing/utils.py +190 -0
  350. torchrl/trainers/__init__.py +42 -0
  351. torchrl/trainers/algorithms/__init__.py +11 -0
  352. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  353. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  354. torchrl/trainers/algorithms/configs/common.py +41 -0
  355. torchrl/trainers/algorithms/configs/data.py +308 -0
  356. torchrl/trainers/algorithms/configs/envs.py +104 -0
  357. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  358. torchrl/trainers/algorithms/configs/logging.py +80 -0
  359. torchrl/trainers/algorithms/configs/modules.py +570 -0
  360. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  361. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  362. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  363. torchrl/trainers/algorithms/configs/utils.py +252 -0
  364. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  365. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  366. torchrl/trainers/algorithms/ppo.py +373 -0
  367. torchrl/trainers/algorithms/sac.py +308 -0
  368. torchrl/trainers/helpers/__init__.py +40 -0
  369. torchrl/trainers/helpers/collectors.py +416 -0
  370. torchrl/trainers/helpers/envs.py +573 -0
  371. torchrl/trainers/helpers/logger.py +33 -0
  372. torchrl/trainers/helpers/losses.py +132 -0
  373. torchrl/trainers/helpers/models.py +658 -0
  374. torchrl/trainers/helpers/replay_buffer.py +59 -0
  375. torchrl/trainers/helpers/trainers.py +301 -0
  376. torchrl/trainers/trainers.py +2052 -0
  377. torchrl/weight_update/__init__.py +33 -0
  378. torchrl/weight_update/_distributed.py +749 -0
  379. torchrl/weight_update/_mp.py +624 -0
  380. torchrl/weight_update/_noupdate.py +102 -0
  381. torchrl/weight_update/_ray.py +1032 -0
  382. torchrl/weight_update/_rpc.py +284 -0
  383. torchrl/weight_update/_shared.py +891 -0
  384. torchrl/weight_update/llm/__init__.py +32 -0
  385. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  386. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  387. torchrl/weight_update/utils.py +73 -0
  388. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  389. torchrl-0.11.0.dist-info/METADATA +1308 -0
  390. torchrl-0.11.0.dist-info/RECORD +394 -0
  391. torchrl-0.11.0.dist-info/WHEEL +5 -0
  392. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  393. torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
  394. torchrl-0.11.0.dist-info/top_level.txt +7 -0
sota-implementations/expert-iteration/expert-iteration-async.py
@@ -0,0 +1,512 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ from __future__ import annotations
+
+ import gc
+ import time
+ from functools import partial
+ from pathlib import Path
+
+ import hydra
+
+ from torchrl import merge_ray_runtime_env, torchrl_logger
+ from torchrl.data.llm.history import History
+ from torchrl.record.loggers.wandb import WandbLogger
+ from torchrl.weight_update.llm import get_model_metadata
+
+ try:
+     import ray
+ except ImportError:
+     raise ImportError(
+         "Ray is required for async training. Please install ray with `pip install ray`."
+     )
+ import torch
+ import tqdm
+
+ from ei_utils import (
+     compute_device_allocation,
+     create_cosine_scheduler_with_warmup,
+     get_inference_model,
+     get_train_model,
+     log_training_metrics,
+     make_env,
+     make_weight_sync_scheme,
+     RemoteDataLogger,
+ )
+ from omegaconf import DictConfig
+ from ray.util.queue import Queue
+
+ try:
+     from tensordict import set_list_to_stack
+ except ImportError:
+     raise ImportError(
+         "TensorDict is required. Please install it with `pip install tensordict`."
+     )
+ from torch.amp.autocast_mode import autocast
+ from torch.amp.grad_scaler import GradScaler
+ from torchrl._utils import timeit
+ from torchrl.collectors.llm import RayLLMCollector
+ from torchrl.data import (
+     LazyStackStorage,
+     PrioritizedSampler,
+     ReplayBuffer,
+     TensorDictReplayBuffer,
+ )
+ from torchrl.data.llm.topk import TopKRewardSelector
+ from torchrl.data.replay_buffers.ray_buffer import RayReplayBuffer
+ from torchrl.objectives.llm.sft import SFTLoss
+
+
+ def setup_environment() -> None:
+     """Set up required environment variables and configurations."""
+     if not torch.cuda.is_available():
+         raise RuntimeError("CUDA is required for training")
+
+     # Set default dtype to float32 for mixed precision training
+     torch.set_default_dtype(torch.float32)
+     torch.set_default_device("cuda:0")
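+     # Have tensordict stack list values into lazy stacks instead of keeping
+     # raw Python lists; the LLM data pipeline expects this behavior.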
+     set_list_to_stack(True).set()
+
+     # CUDA availability was verified above; pin the default CUDA device
+     torch.cuda.set_device("cuda:0")
+
+
+ def train(
+     replay_buffer: ReplayBuffer,
+     cfg: DictConfig,
+     collector: RayLLMCollector,
+     devices: list[int] | None = None,
+ ):
+     """Main training loop for EI async.
+
+     This function implements asynchronous training where data collection and
+     optimization happen concurrently. The total number of steps is determined
+     by the number of epochs, samples per epoch, and batches collected.
+
+     Args:
+         replay_buffer: The replay buffer to store experiences.
+         cfg: The configuration object containing training parameters.
+         collector: The collector object.
+         devices: The devices to use for the training model.
+     """
+     # Setup training model and tokenizer
+     policy_training, train_tokenizer = get_train_model(
+         cfg, devices=devices, chat_template_name="qwen"
+     )
+     train_device = devices[0]  # Use first device for batch processing
+
+     # Setup loss function
+     loss_fn = SFTLoss(
+         actor_network=policy_training,
+         kl_to_ref_coeff=cfg.train.kl_to_ref_coeff,
+         tokenizer=train_tokenizer,
+         tokenizer_kwargs={"chat_template_name": "qwen"},
+         device=torch.device(f"cuda:{train_device}")
+         if train_device is not None
+         else None,
+         loss_function=cfg.train.loss_function,
+         beta=cfg.train.minor_sft_beta,
+     )
+     if cfg.model.compile:
+         loss_fn = torch.compile(loss_fn)
+
+     # Get the vLLM engine from the inference policy.
+     # Note: in expert iteration the inference policy is typically created in
+     # get_inference_model; here we retrieve the engine from the collector's
+     # policy and create the weight sync scheme explicitly, as in GRPO.
+
+     # Create weight sync scheme
+     weight_sync_scheme = make_weight_sync_scheme(
+         master_address="localhost",  # Since we're running locally
+         master_port=None,  # Will auto-assign an open port
+         vllm_tp_size=cfg.inference_model.num_devices
+         if cfg.inference_model.num_devices is not None
+         else len(cfg.inference_model.get("devices", [1])),
+     )
+
+     # Set up weight sender
+     torchrl_logger.info("Setting up weight synchronization scheme...")
+     sender = weight_sync_scheme.create_sender()
+     sender.register_model(policy_training)
+
+     # Get vLLM engine reference from collector's policy
+     # The collector has the policy which wraps the vLLM engine
+     vllm_engine = collector.policy.model if hasattr(collector, "policy") else None
+     if vllm_engine is None:
+         raise RuntimeError("Could not get vLLM engine from collector policy")
+
+     # Initialize collective group
+     torchrl_logger.info("Initializing collective group...")
+     metadata = get_model_metadata(policy_training)
+     sender.init_all_workers_group(metadata, vllm_engine=vllm_engine)
+
+     # First weight update
+     with timeit("update_policy_weights"):
+         sender.update_weights()
+     timeit.print(prefix="First update_policy_weights_ time")
+     timeit.reset()
+
+     # Make optimizer
+     torchrl_logger.info("Starting optimizer.")
+     optimizer = torch.optim.Adam(
+         policy_training.parameters(),
+         lr=cfg.optimizer.lr,
+         weight_decay=cfg.optimizer.weight_decay,
+         fused=False,
+     )
+     scaler = GradScaler(enabled=cfg.train.mixed_precision)
+
+     # Estimate the total optimization steps for the scheduler: each collector
+     # iteration runs cfg.train.epochs epochs over the replay buffer, and an
+     # optimizer step happens every gradient_accumulation_steps micro-batches.
+     # This is a conservative estimate based on the total dialog turns.
+     total_optim_steps = (
+         cfg.train.total_dialog_turns
+         * cfg.train.epochs
+         // cfg.train.gradient_accumulation_steps
+     )
+
+     # Create scheduler if enabled
+     scheduler = None
+     if cfg.optimizer.scheduler.enabled:
+         warmup_steps = cfg.optimizer.scheduler.warmup_steps
+         num_cycles = cfg.optimizer.scheduler.num_cycles
+         torchrl_logger.info(
+             f"Creating {cfg.optimizer.scheduler.type} scheduler with {warmup_steps} warmup steps out of {total_optim_steps} total steps"
+         )
+
+         scheduler = create_cosine_scheduler_with_warmup(
+             optimizer,
+             num_warmup_steps=warmup_steps,
+             num_training_steps=total_optim_steps,
+             num_cycles=num_cycles,
+         )
+
+     # Make checkpoint dir
+     checkpoint_dir = Path(cfg.logging.checkpoint_dir)
+     checkpoint_dir.mkdir(parents=True, exist_ok=True)
+
+     # Make wandb logger
+     torchrl_logger.info("Starting wandb logger.")
+     experiment_name = cfg.logging.experiment_name
+     if experiment_name is not None:
+         experiment_name = [experiment_name]
+     else:
+         experiment_name = []
+
+     experiment_name.append(cfg.env.dataset)
+     experiment_name.append(cfg.model.name)
+
+     # Create local wandb logger for training metrics
+     wandb_config = {
+         "project": "ei-async",
+         "exp_name": "-".join(["ei-async"] + experiment_name),
+     }
+     wandb_logger = WandbLogger(**wandb_config)
+
+     # Route collection-side metrics to the trainer through a Ray queue:
+     # RemoteDataLogger runs as the collector's postproc and fills the queue,
+     # which is drained into wandb in the training loop below.
+     log_queue = Queue(maxsize=1000)
+     collector.set_postproc(RemoteDataLogger(log_queue=log_queue))
+
+     # Start collector
+     collector.start()
+
+     # Wait for initial data
+     while not replay_buffer.write_count:
+         time.sleep(1)
+
+     # Training loop
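+     # -(a // -b) is ceiling division: the number of optimization batches per
+     # epoch, times the number of epochs.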
+     total_steps = (
+         -(cfg.train.total_dialog_turns // -cfg.train.optim_batch_size)
+         * cfg.train.epochs
+     )
+     torchrl_logger.info(f"Total steps: {total_steps}")
+
+     pbar = tqdm.tqdm(total=total_steps)
+     grad_norm = 0.0  # Initialize grad_norm
+     data_read_count = 0
+     optim_step = 0
+     start_time = time.time()
+
+     for step in range(total_steps):
+         pbar.update(1)
+         pbar.set_description(f"Step {step}, writes: {replay_buffer.write_count}")
+
+         with timeit("sampling"):
+             # Sample batch and move to device
+             batch = replay_buffer.sample(cfg.train.optim_batch_size).to(train_device)
+
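+         # Staleness control: skip batches whose policy version is too far
+         # from the collector's current version.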
+         max_policy_age = (
+             batch.view(-1)[0]["next", "policy_version"] - collector.policy_version
+         ).max()
+         if (
+             cfg.train.max_policy_age is not None
+             and max_policy_age > cfg.train.max_policy_age
+         ):
+             # Skip this batch, as it's too old
+             torchrl_logger.info(f"Skipping batch with policy age {max_policy_age}")
+             continue
+
+         # For logging purposes, we get the last element of the history
+         # and convert it to a string
+         history: History = batch.view(-1)[0]["next", "history", "prompt"]
+         history_str: list[str] | str = history.apply_chat_template(
+             tokenizer=train_tokenizer
+         )
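+         # apply_chat_template may return a (possibly nested) list of strings;
+         # join until a single string remains.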
+         while not isinstance(history_str, str):
+             history_str = "\n".join(history_str)
+
+         data_read_count += batch.numel()
+
+         with timeit("forward_pass"):
+             # Forward pass with mixed precision
+             with autocast("cuda", enabled=cfg.train.mixed_precision):
+                 loss = loss_fn(batch)
+                 if loss.loss_kl_to_ref is not None:
+                     loss_val = loss.loss_sft + loss.loss_kl_to_ref
+                 else:
+                     loss_val = loss.loss_sft
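+                 # Scale the loss so gradients accumulated over
+                 # gradient_accumulation_steps micro-batches form an average.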
+                 loss_val = loss_val / cfg.train.gradient_accumulation_steps
+         with timeit("backward_pass"):
+             # Backward pass. Reuse the scaler created above: re-creating a
+             # GradScaler every step would reset its dynamic loss-scale state.
+             if cfg.train.mixed_precision and cfg.train_model.torch_dtype == "float16":
+                 scaler.scale(loss_val).backward()
+             else:
+                 loss_val.backward()
+
+         # Optimization step
+         if ((step + 1) % cfg.train.gradient_accumulation_steps) == 0:
+             with timeit("optim_step"):
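+                 # fp16 gradients must be unscaled before clipping so the norm
+                 # is computed on true gradient values.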
+                 if (
+                     cfg.train.mixed_precision
+                     and cfg.train_model.torch_dtype == "float16"
+                 ):
+                     scaler.unscale_(optimizer)
+
+                 grad_norm = torch.nn.utils.clip_grad_norm_(
+                     policy_training.parameters(),
+                     cfg.optimizer.clip_grad_norm,
+                 )
+
+                 if (
+                     cfg.train.mixed_precision
+                     and cfg.train_model.torch_dtype == "float16"
+                 ):
+                     scaler.step(optimizer)
+                     scaler.update()
+                 else:
+                     optimizer.step()
+                 optimizer.zero_grad(set_to_none=True)
+
+                 # Step the scheduler
+                 if scheduler is not None:
+                     scheduler.step()
+
+                 # Increment optimization step counter
+                 optim_step += 1
+
+         # Update metrics
+         if (step % cfg.train.logging_frequency) == 0:
+             log_training_metrics(
+                 wandb_logger=wandb_logger,
+                 replay_buffer=replay_buffer,
+                 batch=batch,
+                 loss=loss,
+                 grad_norm=grad_norm,
+                 global_step=step,
+                 data_read_count=data_read_count,
+                 collector=collector,
+                 start_time=start_time,
+                 gradient_accumulation_steps=cfg.train.gradient_accumulation_steps,
+                 history_str=history_str,
+             )
+             # Log additional metrics
+             wandb_logger.log_scalar(
+                 "learning_rate", float(optimizer.param_groups[0]["lr"]), step=step
+             )
+             wandb_logger.log_scalar("optim_step", optim_step, step=step)
+             while not log_queue.empty():
+                 logs = log_queue.get()
+                 for k, v in logs.items():
+                     wandb_logger.log_scalar(k, v)
+
+         # Update policy weights
+         if step % cfg.train.weight_update_frequency == 0:
+             with timeit("update_policy_weights"):
+                 torchrl_logger.info("Updating policy weights...")
+                 sender.update_weights()
+             # TODO: do we need this? Does it interfere with other processes?
+             # torch.cuda.empty_cache()
+             gc.collect()
+
+         # Checkpointing disabled to prevent disk space issues
+         # if (step + 1) % cfg.train.checkpoint_frequency == 0:
+         #     with timeit("save_checkpoint"):
+         #         torchrl_logger.info(
+         #             f"Saving checkpoint {(step+1) // cfg.train.checkpoint_frequency}..."
+         #         )
+         #         checkpoint = {
+         #             "step": step,
+         #             "model_state_dict": policy_training.model.state_dict(),
+         #             "optimizer_state_dict": optimizer.state_dict(),
+         #             "scaler_state_dict": scaler.state_dict(),
+         #             "config": dict(cfg),
+         #         }
+         #         torch.save(checkpoint, checkpoint_dir / f"checkpoint_{step:04d}.pt")
+
+         if step % cfg.train.weight_update_frequency == 0:
+             timeit.print(prefix="timeit")
+             for key, val in timeit.todict().items():
+                 wandb_logger.log_scalar(f"timeit/{key}", val)
+             timeit.reset()
+
+         # Clear memory
+         del loss_val
+         # TODO: do we need this? Does it interfere with other processes?
+         # torch.cuda.empty_cache()
+         gc.collect()
+
+     pbar.close()
+     collector.shutdown()
+
+
+ @hydra.main(version_base=None, config_path="config", config_name="ei_gsm8k")
+ def main(cfg):
+     # Force async mode
+     if cfg.train.sync:
+         raise ValueError(
+             "expert-iteration-async.py must run in async mode (`python expert-iteration-async.py mode=async`). Please use expert-iteration-sync.py for sync mode (`python expert-iteration-sync.py mode=sync`)."
+         )
+
+     # Compute device allocation
+     device_config = compute_device_allocation(cfg)
+
+     if not ray.is_initialized():
+         # Convert OmegaConf to regular dict and filter out unsupported parameters
+         ray_init_config = {
+             k: dict(v) if isinstance(v, DictConfig) else v
+             for k, v in dict(cfg.ray.init_config).items()
+             if not k.startswith("_")
+         }
+
+         # Add computed GPU configuration and merge with default runtime_env
+         ray_init_config["num_gpus"] = device_config["ray_num_gpus"]
+         ray_init_config = merge_ray_runtime_env(ray_init_config)
+         torchrl_logger.info(f"Ray init config: {ray_init_config}")
+         ray.init(**ray_init_config)
+
+     # Check that num_devices is set for each model
+     if cfg.inference_model.num_devices is None:
+         raise ValueError(
+             "Inference model num_devices must be set via inference_model.num_devices"
+         )
+     if cfg.ref_model.num_devices is None:
+         raise ValueError("Ref model num_devices must be set via ref_model.num_devices")
+     if cfg.train_model.num_devices is None:
+         raise ValueError(
+             "Train model num_devices must be set via train_model.num_devices"
+         )
+
+     # Convert OmegaConf to regular dict for Ray configs
+     replay_buffer_config = dict(cfg.ray.replay_buffer_config)
+     collector_config = dict(cfg.ray.collector_config)
+     train_handler_config = dict(cfg.ray.train_handler_config)
+
+     inference_policy = get_inference_model(
+         cfg, devices=device_config["inference_model_devices"]
+     )
+     torchrl_logger.info(f"Inference policy: {inference_policy}")
+
+     torchrl_logger.info(f"Starting replay buffer with {replay_buffer_config=}")
+     rb_size = cfg.train.buffer_size
+     if rb_size is None:
+         # Hardcoded default for now
+         rb_size = 256
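+     # With prioritized sampling, replay is biased toward high-reward
+     # trajectories, using ("next", "reward") as the priority key.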
+     if cfg.train.prioritized_sampling:
+         rb_cls = TensorDictReplayBuffer
+         rb_sampler_cls = partial(
+             PrioritizedSampler,
+             max_capacity=rb_size,
+             alpha=cfg.train.prioritized_sampling_alpha,
+             beta=cfg.train.prioritized_sampling_beta,
+             eps=cfg.train.prioritized_sampling_epsilon,
+         )
+         kwargs = {"priority_key": ("next", "reward")}
+     else:
+         rb_cls = ReplayBuffer
+         rb_sampler_cls = None
+         kwargs = {}
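+     # Expert-iteration filtering: TopKRewardSelector keeps only the
+     # topk_size highest-reward completions among the cfg.env.repeats
+     # rollouts of each prompt, so the buffer stores "expert" data for SFT.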
+     rb = RayReplayBuffer(
+         storage=partial(
+             LazyStackStorage,
+             rb_size,
+             device="cpu",
+         ),
+         transform_factory=partial(
+             TopKRewardSelector,
+             total_dialog_turns=cfg.env.repeats,
+             topk_size=cfg.train.topk_size,
+         ),
+         batch_size=cfg.train.optim_batch_size,
+         remote_config=replay_buffer_config,
+         replay_buffer_cls=rb_cls,
+         sampler=rb_sampler_cls,
+         **kwargs,
+     )
+     torchrl_logger.info(f"Replay buffer: {rb}")
+
+     # Create remote collector using RayLLMCollector
+     collector_config["num_gpus"] = (
+         # The ref model is instantiated inside the collector (the inference
+         # engine holds its own GPUs), so the collector actor only needs the
+         # ref model's devices.
+         cfg.ref_model.num_devices
+     )
+     torchrl_logger.info(f"Starting collector with {collector_config=}")
+
+     dialog_turns_per_batch = cfg.train.dialog_turns_per_batch
+     if dialog_turns_per_batch is None:
+         # Default: deliver one full round of prompt repeats per batch
+         dialog_turns_per_batch = cfg.env.repeats
+
+     collector = RayLLMCollector(
+         env=partial(make_env, cfg, devices=device_config["ref_model_devices"]),
+         policy=inference_policy,
+         dialog_turns_per_batch=dialog_turns_per_batch,
+         total_dialog_turns=cfg.train.total_dialog_turns,
+         replay_buffer=rb,
+         ray_init_config=None,  # Ray is already initialized
+         weight_updater=None,  # We'll create this after getting the remote LLM
+         track_policy_version=True,
+         remote_config=collector_config,
+         verbose=True,
+     )
+     # Ensure collector is initialized by calling a method that will block until ready
+     ray.get(collector._collector.is_initialized.remote())
+     torchrl_logger.info(f"Collector: {collector}")
+
+     train_handler_config = {
+         "num_cpus": train_handler_config.get("num_cpus", 1),
+         "num_gpus": cfg.train_model.num_devices,
+     }
+     torchrl_logger.info(f"Starting training handler with {train_handler_config=}")
+     train_handler = ray.remote(
+         **train_handler_config,
+     )(train)
+
+     # Launch training and block until it completes
+     ray.get(
+         train_handler.remote(rb, cfg, collector, device_config["train_model_devices"])
+     )
+
+
+ if __name__ == "__main__":
+     # Setup environment
+     setup_environment()
+     main()