torchrl 0.11.0__cp314-cp314-manylinux_2_28_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (394)
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/__init__.py +144 -0
  97. torchrl/_extension.py +74 -0
  98. torchrl/_torchrl.cpython-314-aarch64-linux-gnu.so +0 -0
  99. torchrl/_utils.py +1431 -0
  100. torchrl/collectors/__init__.py +48 -0
  101. torchrl/collectors/_base.py +1058 -0
  102. torchrl/collectors/_constants.py +88 -0
  103. torchrl/collectors/_multi_async.py +324 -0
  104. torchrl/collectors/_multi_base.py +1805 -0
  105. torchrl/collectors/_multi_sync.py +464 -0
  106. torchrl/collectors/_runner.py +581 -0
  107. torchrl/collectors/_single.py +2009 -0
  108. torchrl/collectors/_single_async.py +259 -0
  109. torchrl/collectors/collectors.py +62 -0
  110. torchrl/collectors/distributed/__init__.py +32 -0
  111. torchrl/collectors/distributed/default_configs.py +133 -0
  112. torchrl/collectors/distributed/generic.py +1306 -0
  113. torchrl/collectors/distributed/ray.py +1092 -0
  114. torchrl/collectors/distributed/rpc.py +1006 -0
  115. torchrl/collectors/distributed/sync.py +731 -0
  116. torchrl/collectors/distributed/utils.py +160 -0
  117. torchrl/collectors/llm/__init__.py +10 -0
  118. torchrl/collectors/llm/base.py +494 -0
  119. torchrl/collectors/llm/ray_collector.py +275 -0
  120. torchrl/collectors/llm/utils.py +36 -0
  121. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  122. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  123. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  124. torchrl/collectors/utils.py +433 -0
  125. torchrl/collectors/weight_update.py +591 -0
  126. torchrl/csrc/numpy_utils.h +38 -0
  127. torchrl/csrc/pybind.cpp +27 -0
  128. torchrl/csrc/segment_tree.h +458 -0
  129. torchrl/csrc/torch_utils.h +34 -0
  130. torchrl/csrc/utils.cpp +48 -0
  131. torchrl/csrc/utils.h +31 -0
  132. torchrl/data/__init__.py +187 -0
  133. torchrl/data/datasets/__init__.py +58 -0
  134. torchrl/data/datasets/atari_dqn.py +878 -0
  135. torchrl/data/datasets/common.py +281 -0
  136. torchrl/data/datasets/d4rl.py +489 -0
  137. torchrl/data/datasets/d4rl_infos.py +187 -0
  138. torchrl/data/datasets/gen_dgrl.py +375 -0
  139. torchrl/data/datasets/minari_data.py +643 -0
  140. torchrl/data/datasets/openml.py +177 -0
  141. torchrl/data/datasets/openx.py +798 -0
  142. torchrl/data/datasets/roboset.py +363 -0
  143. torchrl/data/datasets/utils.py +11 -0
  144. torchrl/data/datasets/vd4rl.py +432 -0
  145. torchrl/data/llm/__init__.py +34 -0
  146. torchrl/data/llm/dataset.py +491 -0
  147. torchrl/data/llm/history.py +1378 -0
  148. torchrl/data/llm/prompt.py +198 -0
  149. torchrl/data/llm/reward.py +225 -0
  150. torchrl/data/llm/topk.py +186 -0
  151. torchrl/data/llm/utils.py +543 -0
  152. torchrl/data/map/__init__.py +21 -0
  153. torchrl/data/map/hash.py +185 -0
  154. torchrl/data/map/query.py +204 -0
  155. torchrl/data/map/tdstorage.py +363 -0
  156. torchrl/data/map/tree.py +1434 -0
  157. torchrl/data/map/utils.py +103 -0
  158. torchrl/data/postprocs/__init__.py +8 -0
  159. torchrl/data/postprocs/postprocs.py +391 -0
  160. torchrl/data/replay_buffers/__init__.py +99 -0
  161. torchrl/data/replay_buffers/checkpointers.py +622 -0
  162. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  163. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  164. torchrl/data/replay_buffers/samplers.py +2578 -0
  165. torchrl/data/replay_buffers/scheduler.py +265 -0
  166. torchrl/data/replay_buffers/storages.py +2412 -0
  167. torchrl/data/replay_buffers/utils.py +1042 -0
  168. torchrl/data/replay_buffers/writers.py +781 -0
  169. torchrl/data/tensor_specs.py +7101 -0
  170. torchrl/data/utils.py +334 -0
  171. torchrl/envs/__init__.py +265 -0
  172. torchrl/envs/async_envs.py +1105 -0
  173. torchrl/envs/batched_envs.py +3093 -0
  174. torchrl/envs/common.py +4241 -0
  175. torchrl/envs/custom/__init__.py +11 -0
  176. torchrl/envs/custom/chess.py +617 -0
  177. torchrl/envs/custom/llm.py +214 -0
  178. torchrl/envs/custom/pendulum.py +401 -0
  179. torchrl/envs/custom/san_moves.txt +29274 -0
  180. torchrl/envs/custom/tictactoeenv.py +288 -0
  181. torchrl/envs/env_creator.py +263 -0
  182. torchrl/envs/gym_like.py +752 -0
  183. torchrl/envs/libs/__init__.py +68 -0
  184. torchrl/envs/libs/_gym_utils.py +326 -0
  185. torchrl/envs/libs/brax.py +846 -0
  186. torchrl/envs/libs/dm_control.py +544 -0
  187. torchrl/envs/libs/envpool.py +447 -0
  188. torchrl/envs/libs/gym.py +2239 -0
  189. torchrl/envs/libs/habitat.py +138 -0
  190. torchrl/envs/libs/isaac_lab.py +87 -0
  191. torchrl/envs/libs/isaacgym.py +203 -0
  192. torchrl/envs/libs/jax_utils.py +166 -0
  193. torchrl/envs/libs/jumanji.py +963 -0
  194. torchrl/envs/libs/meltingpot.py +599 -0
  195. torchrl/envs/libs/openml.py +153 -0
  196. torchrl/envs/libs/openspiel.py +652 -0
  197. torchrl/envs/libs/pettingzoo.py +1042 -0
  198. torchrl/envs/libs/procgen.py +351 -0
  199. torchrl/envs/libs/robohive.py +429 -0
  200. torchrl/envs/libs/smacv2.py +645 -0
  201. torchrl/envs/libs/unity_mlagents.py +891 -0
  202. torchrl/envs/libs/utils.py +147 -0
  203. torchrl/envs/libs/vmas.py +813 -0
  204. torchrl/envs/llm/__init__.py +63 -0
  205. torchrl/envs/llm/chat.py +730 -0
  206. torchrl/envs/llm/datasets/README.md +4 -0
  207. torchrl/envs/llm/datasets/__init__.py +17 -0
  208. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  209. torchrl/envs/llm/datasets/ifeval.py +274 -0
  210. torchrl/envs/llm/envs.py +789 -0
  211. torchrl/envs/llm/libs/README.md +3 -0
  212. torchrl/envs/llm/libs/__init__.py +8 -0
  213. torchrl/envs/llm/libs/mlgym.py +869 -0
  214. torchrl/envs/llm/reward/__init__.py +10 -0
  215. torchrl/envs/llm/reward/gsm8k.py +324 -0
  216. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  217. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  218. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  222. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  223. torchrl/envs/llm/transforms/__init__.py +55 -0
  224. torchrl/envs/llm/transforms/browser.py +292 -0
  225. torchrl/envs/llm/transforms/dataloading.py +859 -0
  226. torchrl/envs/llm/transforms/format.py +73 -0
  227. torchrl/envs/llm/transforms/kl.py +1544 -0
  228. torchrl/envs/llm/transforms/policy_version.py +189 -0
  229. torchrl/envs/llm/transforms/reason.py +323 -0
  230. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  231. torchrl/envs/llm/transforms/tools.py +1955 -0
  232. torchrl/envs/model_based/__init__.py +9 -0
  233. torchrl/envs/model_based/common.py +180 -0
  234. torchrl/envs/model_based/dreamer.py +112 -0
  235. torchrl/envs/transforms/__init__.py +147 -0
  236. torchrl/envs/transforms/functional.py +48 -0
  237. torchrl/envs/transforms/gym_transforms.py +203 -0
  238. torchrl/envs/transforms/module.py +341 -0
  239. torchrl/envs/transforms/r3m.py +372 -0
  240. torchrl/envs/transforms/ray_service.py +663 -0
  241. torchrl/envs/transforms/rb_transforms.py +214 -0
  242. torchrl/envs/transforms/transforms.py +11835 -0
  243. torchrl/envs/transforms/utils.py +94 -0
  244. torchrl/envs/transforms/vc1.py +307 -0
  245. torchrl/envs/transforms/vecnorm.py +845 -0
  246. torchrl/envs/transforms/vip.py +407 -0
  247. torchrl/envs/utils.py +1718 -0
  248. torchrl/envs/vec_envs.py +11 -0
  249. torchrl/modules/__init__.py +206 -0
  250. torchrl/modules/distributions/__init__.py +73 -0
  251. torchrl/modules/distributions/continuous.py +830 -0
  252. torchrl/modules/distributions/discrete.py +908 -0
  253. torchrl/modules/distributions/truncated_normal.py +187 -0
  254. torchrl/modules/distributions/utils.py +233 -0
  255. torchrl/modules/llm/__init__.py +62 -0
  256. torchrl/modules/llm/backends/__init__.py +65 -0
  257. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  258. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  259. torchrl/modules/llm/backends/vllm/base.py +72 -0
  260. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  261. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  262. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  263. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  264. torchrl/modules/llm/policies/__init__.py +28 -0
  265. torchrl/modules/llm/policies/common.py +1809 -0
  266. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  267. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  268. torchrl/modules/llm/utils.py +23 -0
  269. torchrl/modules/mcts/__init__.py +21 -0
  270. torchrl/modules/mcts/scores.py +579 -0
  271. torchrl/modules/models/__init__.py +86 -0
  272. torchrl/modules/models/batchrenorm.py +119 -0
  273. torchrl/modules/models/decision_transformer.py +179 -0
  274. torchrl/modules/models/exploration.py +731 -0
  275. torchrl/modules/models/llm.py +156 -0
  276. torchrl/modules/models/model_based.py +596 -0
  277. torchrl/modules/models/models.py +1712 -0
  278. torchrl/modules/models/multiagent.py +1067 -0
  279. torchrl/modules/models/recipes/impala.py +185 -0
  280. torchrl/modules/models/utils.py +162 -0
  281. torchrl/modules/planners/__init__.py +10 -0
  282. torchrl/modules/planners/cem.py +228 -0
  283. torchrl/modules/planners/common.py +73 -0
  284. torchrl/modules/planners/mppi.py +265 -0
  285. torchrl/modules/tensordict_module/__init__.py +89 -0
  286. torchrl/modules/tensordict_module/actors.py +2457 -0
  287. torchrl/modules/tensordict_module/common.py +529 -0
  288. torchrl/modules/tensordict_module/exploration.py +814 -0
  289. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  290. torchrl/modules/tensordict_module/rnn.py +1639 -0
  291. torchrl/modules/tensordict_module/sequence.py +132 -0
  292. torchrl/modules/tensordict_module/world_models.py +34 -0
  293. torchrl/modules/utils/__init__.py +38 -0
  294. torchrl/modules/utils/mappings.py +9 -0
  295. torchrl/modules/utils/utils.py +89 -0
  296. torchrl/objectives/__init__.py +78 -0
  297. torchrl/objectives/a2c.py +659 -0
  298. torchrl/objectives/common.py +753 -0
  299. torchrl/objectives/cql.py +1346 -0
  300. torchrl/objectives/crossq.py +710 -0
  301. torchrl/objectives/ddpg.py +453 -0
  302. torchrl/objectives/decision_transformer.py +371 -0
  303. torchrl/objectives/deprecated.py +516 -0
  304. torchrl/objectives/dqn.py +683 -0
  305. torchrl/objectives/dreamer.py +488 -0
  306. torchrl/objectives/functional.py +48 -0
  307. torchrl/objectives/gail.py +258 -0
  308. torchrl/objectives/iql.py +996 -0
  309. torchrl/objectives/llm/__init__.py +30 -0
  310. torchrl/objectives/llm/grpo.py +846 -0
  311. torchrl/objectives/llm/sft.py +482 -0
  312. torchrl/objectives/multiagent/__init__.py +8 -0
  313. torchrl/objectives/multiagent/qmixer.py +396 -0
  314. torchrl/objectives/ppo.py +1669 -0
  315. torchrl/objectives/redq.py +683 -0
  316. torchrl/objectives/reinforce.py +530 -0
  317. torchrl/objectives/sac.py +1580 -0
  318. torchrl/objectives/td3.py +570 -0
  319. torchrl/objectives/td3_bc.py +625 -0
  320. torchrl/objectives/utils.py +782 -0
  321. torchrl/objectives/value/__init__.py +28 -0
  322. torchrl/objectives/value/advantages.py +1956 -0
  323. torchrl/objectives/value/functional.py +1459 -0
  324. torchrl/objectives/value/utils.py +360 -0
  325. torchrl/record/__init__.py +17 -0
  326. torchrl/record/loggers/__init__.py +23 -0
  327. torchrl/record/loggers/common.py +48 -0
  328. torchrl/record/loggers/csv.py +226 -0
  329. torchrl/record/loggers/mlflow.py +142 -0
  330. torchrl/record/loggers/tensorboard.py +139 -0
  331. torchrl/record/loggers/trackio.py +163 -0
  332. torchrl/record/loggers/utils.py +78 -0
  333. torchrl/record/loggers/wandb.py +214 -0
  334. torchrl/record/recorder.py +554 -0
  335. torchrl/services/__init__.py +79 -0
  336. torchrl/services/base.py +109 -0
  337. torchrl/services/ray_service.py +453 -0
  338. torchrl/testing/__init__.py +107 -0
  339. torchrl/testing/assertions.py +179 -0
  340. torchrl/testing/dist_utils.py +122 -0
  341. torchrl/testing/env_creators.py +227 -0
  342. torchrl/testing/env_helper.py +35 -0
  343. torchrl/testing/gym_helpers.py +156 -0
  344. torchrl/testing/llm_mocks.py +119 -0
  345. torchrl/testing/mocking_classes.py +2720 -0
  346. torchrl/testing/modules.py +295 -0
  347. torchrl/testing/mp_helpers.py +15 -0
  348. torchrl/testing/ray_helpers.py +293 -0
  349. torchrl/testing/utils.py +190 -0
  350. torchrl/trainers/__init__.py +42 -0
  351. torchrl/trainers/algorithms/__init__.py +11 -0
  352. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  353. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  354. torchrl/trainers/algorithms/configs/common.py +41 -0
  355. torchrl/trainers/algorithms/configs/data.py +308 -0
  356. torchrl/trainers/algorithms/configs/envs.py +104 -0
  357. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  358. torchrl/trainers/algorithms/configs/logging.py +80 -0
  359. torchrl/trainers/algorithms/configs/modules.py +570 -0
  360. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  361. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  362. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  363. torchrl/trainers/algorithms/configs/utils.py +252 -0
  364. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  365. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  366. torchrl/trainers/algorithms/ppo.py +373 -0
  367. torchrl/trainers/algorithms/sac.py +308 -0
  368. torchrl/trainers/helpers/__init__.py +40 -0
  369. torchrl/trainers/helpers/collectors.py +416 -0
  370. torchrl/trainers/helpers/envs.py +573 -0
  371. torchrl/trainers/helpers/logger.py +33 -0
  372. torchrl/trainers/helpers/losses.py +132 -0
  373. torchrl/trainers/helpers/models.py +658 -0
  374. torchrl/trainers/helpers/replay_buffer.py +59 -0
  375. torchrl/trainers/helpers/trainers.py +301 -0
  376. torchrl/trainers/trainers.py +2052 -0
  377. torchrl/weight_update/__init__.py +33 -0
  378. torchrl/weight_update/_distributed.py +749 -0
  379. torchrl/weight_update/_mp.py +624 -0
  380. torchrl/weight_update/_noupdate.py +102 -0
  381. torchrl/weight_update/_ray.py +1032 -0
  382. torchrl/weight_update/_rpc.py +284 -0
  383. torchrl/weight_update/_shared.py +891 -0
  384. torchrl/weight_update/llm/__init__.py +32 -0
  385. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  386. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  387. torchrl/weight_update/utils.py +73 -0
  388. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  389. torchrl-0.11.0.dist-info/METADATA +1308 -0
  390. torchrl-0.11.0.dist-info/RECORD +394 -0
  391. torchrl-0.11.0.dist-info/WHEEL +5 -0
  392. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  393. torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
  394. torchrl-0.11.0.dist-info/top_level.txt +7 -0
@@ -0,0 +1,508 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from __future__ import annotations
7
+
8
+ import gc
9
+ import math
10
+ from functools import partial
11
+ from pathlib import Path
12
+
13
+ import hydra
14
+
15
+ from torchrl import merge_ray_runtime_env, torchrl_logger
16
+ from torchrl.data.llm.history import History
17
+ from torchrl.record.loggers.wandb import WandbLogger
18
+ from torchrl.weight_update.llm import get_model_metadata
19
+
20
+ try:
21
+ import ray
22
+ except ImportError:
23
+ raise ImportError(
24
+ "Ray is required for sync training. Please install ray with `pip install ray`."
25
+ )
26
+ import time
27
+
28
+ import torch
29
+ import tqdm
30
+
31
+ from ei_utils import (
32
+ compute_device_allocation,
33
+ create_cosine_scheduler_with_warmup,
34
+ get_inference_model,
35
+ get_train_model,
36
+ log_training_metrics,
37
+ make_env,
38
+ make_weight_sync_scheme,
39
+ RemoteDataLogger,
40
+ )
41
+ from omegaconf import DictConfig
42
+ from ray.util.queue import Queue
43
+
44
+ try:
45
+ from tensordict import set_list_to_stack
46
+ except ImportError:
47
+ raise ImportError(
48
+ "TensorDict is required. Please install it with `pip install tensordict`."
49
+ )
50
+ from torch.amp.autocast_mode import autocast
51
+ from torch.amp.grad_scaler import GradScaler
52
+ from torchrl._utils import timeit
53
+ from torchrl.collectors.llm import RayLLMCollector
54
+ from torchrl.data import LazyStackStorage, ReplayBuffer, SamplerWithoutReplacement
55
+ from torchrl.data.llm.topk import TopKRewardSelector
56
+ from torchrl.data.replay_buffers.ray_buffer import RayReplayBuffer
57
+ from torchrl.objectives.llm.sft import SFTLoss
58
+
59
+ DEFAULT_DIALOG_TURNS_PER_BATCH = 256
60
+
61
+
62
def setup_environment() -> None:
    """Setup required environment variables and configurations.

    Sets the default dtype/device for training and enables tensordict's
    list-to-stack behavior.

    Raises:
        RuntimeError: If CUDA is not available — training requires a GPU.
    """
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is required for training")

    # Set default dtype to float32 for mixed precision training
    torch.set_default_dtype(torch.float32)
    torch.set_default_device("cuda:0")
    set_list_to_stack(True).set()

    # Pin CUDA operations to the first device. The availability check above
    # already guarantees CUDA is present, so no second guard is needed here.
    torch.cuda.set_device("cuda:0")
76
+
77
+
78
def train(
    replay_buffer: ReplayBuffer,
    cfg: DictConfig,
    collector: RayLLMCollector,
    devices: list[int] | None = None,
):
    """Main training loop for EI sync.

    This function implements synchronous training where data collection and optimization
    happen in separate, consecutive steps. The total number of steps is determined by the number of epochs,
    samples per epoch, and batches collected.

    Args:
        replay_buffer: The replay buffer to store experiences. The sampler will typically be a `SamplerWithoutReplacement`.
        cfg: The configuration object containing training parameters
        collector: The collector object.
        devices: The devices to use for the training model.

    Raises:
        RuntimeError: If the vLLM engine cannot be retrieved from the collector policy.
        ValueError: If the collector yields non-``None`` data (the collector is
            expected to write directly to the replay buffer).
    """
    # Setup training model and tokenizer
    policy_training, train_tokenizer = get_train_model(
        cfg, devices=devices, chat_template_name="qwen"
    )
    # NOTE(review): this indexes devices[0] — assumes `devices` is a non-empty
    # list here even though the signature allows None; verify against callers.
    train_device = devices[0]  # Use first device for batch processing

    # Setup loss function
    loss_fn = SFTLoss(
        actor_network=policy_training,
        kl_to_ref_coeff=cfg.train.kl_to_ref_coeff,
        tokenizer=train_tokenizer,
        tokenizer_kwargs={"chat_template_name": "qwen"},
        device=torch.device(f"cuda:{train_device}")
        if train_device is not None
        else None,
        loss_function=cfg.train.loss_function,
        beta=cfg.train.minor_sft_beta,
    )
    if cfg.model.compile:
        loss_fn = torch.compile(loss_fn)

    # Create weight sync scheme
    weight_sync_scheme = make_weight_sync_scheme(
        master_address="localhost",  # Since we're running locally
        master_port=None,  # Will auto-assign an open port
        vllm_tp_size=cfg.inference_model.num_devices
        if cfg.inference_model.num_devices is not None
        else len(cfg.inference_model.get("devices", [1])),
    )

    # Set up weight sender
    torchrl_logger.info("Setting up weight synchronization scheme...")
    sender = weight_sync_scheme.create_sender()
    sender.register_model(policy_training)

    # Get vLLM engine reference from collector's policy
    vllm_engine = collector.policy.model if hasattr(collector, "policy") else None
    if vllm_engine is None:
        raise RuntimeError("Could not get vLLM engine from collector policy")

    # Initialize collective group
    torchrl_logger.info("Initializing collective group...")
    metadata = get_model_metadata(policy_training)
    sender.init_all_workers_group(metadata, vllm_engine=vllm_engine)

    # First weight update
    with timeit("update_policy_weights"):
        sender.update_weights()
    timeit.print(prefix="First update_policy_weights_ time")
    timeit.reset()

    # Make optimizer
    torchrl_logger.info("Starting optimizer.")
    optimizer = torch.optim.Adam(
        policy_training.parameters(),
        lr=cfg.optimizer.lr,
        weight_decay=cfg.optimizer.weight_decay,
        fused=False,
    )
    # Single GradScaler shared across the whole run. It must NOT be re-created
    # per step: GradScaler keeps dynamic loss-scale state (growth/backoff
    # history) that only works if the same instance persists across iterations.
    scaler = GradScaler(enabled=cfg.train.mixed_precision)

    # Calculate total optimization steps for scheduler
    # The training loop structure: for each collector iteration, we do cfg.train.epochs epochs
    # Each epoch processes the entire replay buffer, and optimization happens every gradient_accumulation_steps
    # We need to estimate the total number of optimization steps
    # For now, we'll use a conservative estimate based on the total dialog turns
    # This can be refined based on the actual training dynamics
    total_optim_steps = (
        cfg.train.total_dialog_turns
        * cfg.train.epochs
        // cfg.train.gradient_accumulation_steps
    )

    # Create scheduler if enabled
    scheduler = None
    if cfg.optimizer.scheduler.enabled:
        warmup_steps = cfg.optimizer.scheduler.warmup_steps
        num_cycles = cfg.optimizer.scheduler.num_cycles
        torchrl_logger.info(
            f"Creating {cfg.optimizer.scheduler.type} scheduler with {warmup_steps} warmup steps out of {total_optim_steps} total steps"
        )

        scheduler = create_cosine_scheduler_with_warmup(
            optimizer,
            num_warmup_steps=warmup_steps,
            num_training_steps=total_optim_steps,
            num_cycles=num_cycles,
        )

    # Make checkpoint dir
    checkpoint_dir = Path(cfg.logging.checkpoint_dir)
    checkpoint_dir.mkdir(parents=True, exist_ok=True)

    # Make wandb logger
    torchrl_logger.info("Starting wandb logger.")
    experiment_name = cfg.logging.experiment_name
    if experiment_name is not None:
        experiment_name = [experiment_name]
    else:
        experiment_name = []

    experiment_name.append(cfg.env.dataset)
    experiment_name.append(cfg.model.name)

    # Create local wandb logger for training metrics
    wandb_config = {
        "project": "ei-sync",
        "exp_name": "-".join(["ei-sync"] + experiment_name),
    }
    wandb_logger = WandbLogger(**wandb_config)

    # Pass the logging actor reference to the collector
    log_queue = Queue(maxsize=1000)
    collector.set_postproc(RemoteDataLogger(log_queue=log_queue))

    # Training loop
    torchrl_logger.info("Starting training loop.")
    pbar = tqdm.tqdm(total=cfg.train.total_dialog_turns)
    grad_norm = 0.0  # Initialize grad_norm
    data_read_count = 0

    global_step = 0
    optim_step = 0  # Track optimization steps separately for scheduler
    start_time = time.time()
    write_count = replay_buffer.write_count
    for data in collector:
        new_write_count = replay_buffer.write_count
        if new_write_count == write_count:
            torchrl_logger.warning("No new writes to replay buffer")
            continue
        pbar.update(new_write_count - write_count)
        write_count = new_write_count

        # data is None as the collector directly writes to the replay buffer
        if data is not None:
            raise ValueError("Data is not None")

        for _ in range(cfg.train.epochs):
            # Iterate over the replay buffer
            for batch in replay_buffer:
                batch = batch.to(train_device)
                global_step += 1
                pbar.set_description(
                    f"Gradient step {global_step}, writes: {replay_buffer.write_count}, batch size: {batch.shape}"
                )
                # For logging purposes, we get the last element of the history
                # and convert it to a string
                history: History = batch.view(-1)[0]["next", "history", "prompt"]
                history_str: list[str] | str = history.apply_chat_template(
                    tokenizer=train_tokenizer
                )
                while not isinstance(history_str, str):
                    history_str = "\n".join(history_str)

                data_read_count += batch.numel()

                with timeit("forward_pass"):
                    # Forward pass with mixed precision
                    with autocast("cuda", enabled=cfg.train.mixed_precision):
                        loss = loss_fn(batch)
                        if loss.loss_kl_to_ref is not None:
                            loss_val = loss.loss_sft + loss.loss_kl_to_ref
                        else:
                            loss_val = loss.loss_sft
                        loss_val = loss_val / cfg.train.gradient_accumulation_steps

                with timeit("backward_pass"):
                    # Backward pass. Reuse the run-wide `scaler` for float16
                    # mixed precision: re-creating a GradScaler here would
                    # reset its dynamic loss-scale state on every step.
                    if (
                        cfg.train.mixed_precision
                        and cfg.train_model.torch_dtype == "float16"
                    ):
                        scaler.scale(loss_val).backward()
                    else:
                        loss_val.backward()

                # Optimization step
                if ((global_step + 1) % cfg.train.gradient_accumulation_steps) == 0:
                    with timeit("optim_step"):
                        if (
                            cfg.train.mixed_precision
                            and cfg.train_model.torch_dtype == "float16"
                        ):
                            # Unscale before clipping so the norm is computed
                            # on true (unscaled) gradients.
                            scaler.unscale_(optimizer)

                        grad_norm = torch.nn.utils.clip_grad_norm_(
                            policy_training.parameters(),
                            cfg.optimizer.clip_grad_norm,
                        )

                        if (
                            cfg.train.mixed_precision
                            and cfg.train_model.torch_dtype == "float16"
                        ):
                            scaler.step(optimizer)
                            scaler.update()
                        else:
                            optimizer.step()
                        optimizer.zero_grad(set_to_none=True)

                        # Step the scheduler
                        if scheduler is not None:
                            scheduler.step()

                        # Increment optimization step counter
                        optim_step += 1

                # Clear memory
                del loss_val
                torch.cuda.empty_cache()
                gc.collect()

                # Update metrics
                if (global_step % cfg.train.logging_frequency) == 0:
                    log_training_metrics(
                        wandb_logger=wandb_logger,
                        replay_buffer=replay_buffer,
                        batch=batch,
                        loss=loss,
                        grad_norm=grad_norm,
                        global_step=global_step,
                        data_read_count=data_read_count,
                        collector=collector,
                        start_time=start_time,
                        gradient_accumulation_steps=cfg.train.gradient_accumulation_steps,
                        history_str=history_str,
                    )
                    # Log additional metrics
                    wandb_logger.log_scalar(
                        "learning_rate",
                        float(optimizer.param_groups[0]["lr"]),
                        step=global_step,
                    )
                    wandb_logger.log_scalar("optim_step", optim_step, step=global_step)
                    while not log_queue.empty():
                        logs = log_queue.get()
                        for k, v in logs.items():
                            wandb_logger.log_scalar(k, v, step=global_step)

                # Update policy weights
                if (
                    cfg.train.weight_update_frequency is not None
                    and (global_step + 1) % cfg.train.weight_update_frequency == 0
                ):
                    with timeit("update_policy_weights"):
                        torchrl_logger.info("Updating policy weights...")
                        sender.update_weights()
                    torch.cuda.empty_cache()
                    gc.collect()
                # Checkpointing disabled to prevent disk space issues
                # if (global_step + 1) % cfg.train.checkpoint_frequency == 0:
                #     with timeit("save_checkpoint"):
                #         torchrl_logger.info(
                #             f"Saving checkpoint {(global_step+1) // cfg.train.checkpoint_frequency}..."
                #         )
                #         checkpoint = {
                #             "step": global_step,
                #             "model_state_dict": policy_training.model.state_dict(),
                #             "optimizer_state_dict": optimizer.state_dict(),
                #             "scaler_state_dict": scaler.state_dict(),
                #             "config": dict(cfg),
                #         }
                #         torch.save(checkpoint, checkpoint_dir / f"checkpoint_{global_step:04d}.pt")

        # Update policy weights
        if cfg.train.weight_update_frequency is None:
            # If weight_update_frequency is not set, we update the weights after each batch
            with timeit("update_policy_weights"):
                torchrl_logger.info("Updating policy weights...")
                sender.update_weights()
            torch.cuda.empty_cache()
            gc.collect()

        timeit.print(prefix="timeit")
        for key, val in timeit.todict().items():
            wandb_logger.log_scalar(f"timeit/{key}", val)
        timeit.reset()

        if cfg.train.empty_replay_buffer:
            replay_buffer.empty(empty_write_count=False)

    pbar.close()
    collector.shutdown()
380
+
381
+
382
@hydra.main(version_base=None, config_path="config", config_name="ei_gsm8k")
def main(cfg):
    """Sync expert-iteration entry point.

    Wires up the Ray actors — replay buffer, LLM collector, and a remote
    training handler — from the Hydra config, then blocks until training
    completes.

    Args:
        cfg: Hydra/OmegaConf configuration (``config/ei_gsm8k``).

    Raises:
        ValueError: if run with ``train.sync=False`` or any model's
            ``num_devices`` is unset.
    """
    # This script only supports the synchronous loop; async has its own file.
    if not cfg.train.sync:
        raise ValueError(
            "expert-iteration-sync.py must run in sync mode (`python expert-iteration-sync.py mode=sync`). Please use expert-iteration-async.py for async mode (`python expert-iteration-async.py mode=async`)."
        )

    # Decide which devices each model (and Ray itself) gets.
    devices = compute_device_allocation(cfg)

    if not ray.is_initialized():
        # Flatten the OmegaConf init section into plain dicts, dropping
        # private ("_"-prefixed) keys that ray.init does not accept.
        ray_init_config = {}
        for key, value in dict(cfg.ray.init_config).items():
            if key.startswith("_"):
                continue
            ray_init_config[key] = (
                dict(value) if isinstance(value, DictConfig) else value
            )
        # Add the computed GPU budget and merge with the default runtime_env.
        ray_init_config["num_gpus"] = devices["ray_num_gpus"]
        ray_init_config = merge_ray_runtime_env(ray_init_config)
        torchrl_logger.info(f"Ray init config: {ray_init_config=}")
        ray.init(**ray_init_config)

    # Fail fast if any model's device count was left unset.
    if cfg.inference_model.num_devices is None:
        raise ValueError(
            "Inference model num_devices must be set via inference_model.num_devices"
        )
    if cfg.ref_model.num_devices is None:
        raise ValueError("Ref model num_devices must be set via ref_model.num_devices")
    if cfg.train_model.num_devices is None:
        raise ValueError(
            "Train model num_devices must be set via train_model.num_devices"
        )

    # Materialize the per-actor Ray configs as plain dicts.
    replay_buffer_config = dict(cfg.ray.replay_buffer_config)
    collector_config = dict(cfg.ray.collector_config)
    train_handler_config = dict(cfg.ray.train_handler_config)

    inference_policy = get_inference_model(
        cfg, devices=devices["inference_model_devices"]
    )
    torchrl_logger.info(f"Inference policy: {inference_policy}")

    torchrl_logger.info(f"Starting replay buffer with {replay_buffer_config=}")
    # Buffer capacity: explicit config wins; otherwise derive it from the
    # top-k selection math, or use a large sentinel when the buffer is
    # emptied after every batch anyway.
    capacity = cfg.train.buffer_size
    if capacity is None:
        if cfg.train.empty_replay_buffer:
            # we can just set a big number, the buffer will be emptied anyway
            capacity = 1000000
        else:
            turns_per_batch = cfg.train.dialog_turns_per_batch
            if turns_per_batch is None:
                turns_per_batch = DEFAULT_DIALOG_TURNS_PER_BATCH
            capacity = int(
                math.ceil(turns_per_batch * cfg.train.topk_size / cfg.env.repeats)
            )
    rb = RayReplayBuffer(
        storage=partial(
            LazyStackStorage,
            capacity,
            device="cpu",
        ),
        sampler=SamplerWithoutReplacement,
        transform_factory=partial(
            TopKRewardSelector,
            total_dialog_turns=cfg.env.repeats,
            topk_size=cfg.train.topk_size,
        ),
        batch_size=cfg.train.optim_batch_size,
        remote_config=replay_buffer_config,
    )
    torchrl_logger.info(f"Replay buffer: {rb}")

    # The ref model is instantiated inside the collector, so the collector
    # actor only needs the ref model's GPU allocation.
    collector_config["num_gpus"] = cfg.ref_model.num_devices
    torchrl_logger.info(f"Starting collector with {collector_config=}")

    turns_per_batch = cfg.train.dialog_turns_per_batch
    if turns_per_batch is None:
        # Hardcoded default for now.
        turns_per_batch = DEFAULT_DIALOG_TURNS_PER_BATCH

    collector = RayLLMCollector(
        env=partial(make_env, cfg, devices=devices["ref_model_devices"]),
        policy=inference_policy,
        dialog_turns_per_batch=turns_per_batch,
        total_dialog_turns=cfg.train.total_dialog_turns,
        replay_buffer=rb,
        ray_init_config=None,  # Ray is already initialized
        weight_updater=None,  # created later, once the remote LLM exists
        track_policy_version=True,
        remote_config=collector_config,
        sync_iter=cfg.train.sync_iter,
        verbose=True,
    )
    # Block until the remote collector actor finishes constructing.
    ray.get(collector._collector.is_initialized.remote())
    torchrl_logger.info(f"Collector: {collector}")

    # Keep only the resources the training actor needs from the raw config.
    train_handler_config = {
        "num_cpus": train_handler_config.get("num_cpus", 1),
        "num_gpus": cfg.train_model.num_devices,
    }
    torchrl_logger.info(f"Starting training handler with {train_handler_config=}")
    train_handler = ray.remote(
        **train_handler_config,
    )(train)

    # Launch training remotely and wait for it to finish.
    ray.get(
        train_handler.remote(rb, cfg, collector, devices["train_model_devices"])
    )
503
+
504
+
505
# Script entry point: prepare the process (setup_environment — defined
# elsewhere; presumably env vars / seeding, confirm) before Hydra parses
# the CLI and invokes main().
if __name__ == "__main__":
    # Setup environment
    setup_environment()
    main()
@@ -0,0 +1,13 @@
1
+ torch==2.7.0
2
+ transformers==4.52.4
3
+ peft==0.15.2
4
+ bitsandbytes==0.46.0
5
+ datasets==3.6.0
6
+ wandb==0.19.11
7
+ hydra-core==1.3.2
8
+ ray==2.52.1
9
+ tqdm==4.67.1
10
+ tensordict==0.9.0
11
+ vllm==0.9.0.1
12
+ accelerate==1.7.0
13
+ xformers==0.0.30
@@ -0,0 +1,16 @@
1
+ torch==2.7.0
2
+ transformers==4.52.4
3
+ peft==0.15.2
4
+ bitsandbytes==0.46.0
5
+ datasets==3.6.0
6
+ wandb==0.19.11
7
+ hydra-core==1.3.2
8
+ ray==2.52.1
9
+ tqdm==4.67.1
10
+ tensordict==0.9.0
11
+ vllm==0.9.0.1
12
+ accelerate==1.7.0
13
+ xformers==0.0.30
14
+ nltk==3.9.1
15
+ langdetect==1.0.9
16
+ immutabledict==4.2.1