torchrl 0.11.0__cp314-cp314-manylinux_2_28_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (394)
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/__init__.py +144 -0
  97. torchrl/_extension.py +74 -0
  98. torchrl/_torchrl.cpython-314-aarch64-linux-gnu.so +0 -0
  99. torchrl/_utils.py +1431 -0
  100. torchrl/collectors/__init__.py +48 -0
  101. torchrl/collectors/_base.py +1058 -0
  102. torchrl/collectors/_constants.py +88 -0
  103. torchrl/collectors/_multi_async.py +324 -0
  104. torchrl/collectors/_multi_base.py +1805 -0
  105. torchrl/collectors/_multi_sync.py +464 -0
  106. torchrl/collectors/_runner.py +581 -0
  107. torchrl/collectors/_single.py +2009 -0
  108. torchrl/collectors/_single_async.py +259 -0
  109. torchrl/collectors/collectors.py +62 -0
  110. torchrl/collectors/distributed/__init__.py +32 -0
  111. torchrl/collectors/distributed/default_configs.py +133 -0
  112. torchrl/collectors/distributed/generic.py +1306 -0
  113. torchrl/collectors/distributed/ray.py +1092 -0
  114. torchrl/collectors/distributed/rpc.py +1006 -0
  115. torchrl/collectors/distributed/sync.py +731 -0
  116. torchrl/collectors/distributed/utils.py +160 -0
  117. torchrl/collectors/llm/__init__.py +10 -0
  118. torchrl/collectors/llm/base.py +494 -0
  119. torchrl/collectors/llm/ray_collector.py +275 -0
  120. torchrl/collectors/llm/utils.py +36 -0
  121. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  122. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  123. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  124. torchrl/collectors/utils.py +433 -0
  125. torchrl/collectors/weight_update.py +591 -0
  126. torchrl/csrc/numpy_utils.h +38 -0
  127. torchrl/csrc/pybind.cpp +27 -0
  128. torchrl/csrc/segment_tree.h +458 -0
  129. torchrl/csrc/torch_utils.h +34 -0
  130. torchrl/csrc/utils.cpp +48 -0
  131. torchrl/csrc/utils.h +31 -0
  132. torchrl/data/__init__.py +187 -0
  133. torchrl/data/datasets/__init__.py +58 -0
  134. torchrl/data/datasets/atari_dqn.py +878 -0
  135. torchrl/data/datasets/common.py +281 -0
  136. torchrl/data/datasets/d4rl.py +489 -0
  137. torchrl/data/datasets/d4rl_infos.py +187 -0
  138. torchrl/data/datasets/gen_dgrl.py +375 -0
  139. torchrl/data/datasets/minari_data.py +643 -0
  140. torchrl/data/datasets/openml.py +177 -0
  141. torchrl/data/datasets/openx.py +798 -0
  142. torchrl/data/datasets/roboset.py +363 -0
  143. torchrl/data/datasets/utils.py +11 -0
  144. torchrl/data/datasets/vd4rl.py +432 -0
  145. torchrl/data/llm/__init__.py +34 -0
  146. torchrl/data/llm/dataset.py +491 -0
  147. torchrl/data/llm/history.py +1378 -0
  148. torchrl/data/llm/prompt.py +198 -0
  149. torchrl/data/llm/reward.py +225 -0
  150. torchrl/data/llm/topk.py +186 -0
  151. torchrl/data/llm/utils.py +543 -0
  152. torchrl/data/map/__init__.py +21 -0
  153. torchrl/data/map/hash.py +185 -0
  154. torchrl/data/map/query.py +204 -0
  155. torchrl/data/map/tdstorage.py +363 -0
  156. torchrl/data/map/tree.py +1434 -0
  157. torchrl/data/map/utils.py +103 -0
  158. torchrl/data/postprocs/__init__.py +8 -0
  159. torchrl/data/postprocs/postprocs.py +391 -0
  160. torchrl/data/replay_buffers/__init__.py +99 -0
  161. torchrl/data/replay_buffers/checkpointers.py +622 -0
  162. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  163. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  164. torchrl/data/replay_buffers/samplers.py +2578 -0
  165. torchrl/data/replay_buffers/scheduler.py +265 -0
  166. torchrl/data/replay_buffers/storages.py +2412 -0
  167. torchrl/data/replay_buffers/utils.py +1042 -0
  168. torchrl/data/replay_buffers/writers.py +781 -0
  169. torchrl/data/tensor_specs.py +7101 -0
  170. torchrl/data/utils.py +334 -0
  171. torchrl/envs/__init__.py +265 -0
  172. torchrl/envs/async_envs.py +1105 -0
  173. torchrl/envs/batched_envs.py +3093 -0
  174. torchrl/envs/common.py +4241 -0
  175. torchrl/envs/custom/__init__.py +11 -0
  176. torchrl/envs/custom/chess.py +617 -0
  177. torchrl/envs/custom/llm.py +214 -0
  178. torchrl/envs/custom/pendulum.py +401 -0
  179. torchrl/envs/custom/san_moves.txt +29274 -0
  180. torchrl/envs/custom/tictactoeenv.py +288 -0
  181. torchrl/envs/env_creator.py +263 -0
  182. torchrl/envs/gym_like.py +752 -0
  183. torchrl/envs/libs/__init__.py +68 -0
  184. torchrl/envs/libs/_gym_utils.py +326 -0
  185. torchrl/envs/libs/brax.py +846 -0
  186. torchrl/envs/libs/dm_control.py +544 -0
  187. torchrl/envs/libs/envpool.py +447 -0
  188. torchrl/envs/libs/gym.py +2239 -0
  189. torchrl/envs/libs/habitat.py +138 -0
  190. torchrl/envs/libs/isaac_lab.py +87 -0
  191. torchrl/envs/libs/isaacgym.py +203 -0
  192. torchrl/envs/libs/jax_utils.py +166 -0
  193. torchrl/envs/libs/jumanji.py +963 -0
  194. torchrl/envs/libs/meltingpot.py +599 -0
  195. torchrl/envs/libs/openml.py +153 -0
  196. torchrl/envs/libs/openspiel.py +652 -0
  197. torchrl/envs/libs/pettingzoo.py +1042 -0
  198. torchrl/envs/libs/procgen.py +351 -0
  199. torchrl/envs/libs/robohive.py +429 -0
  200. torchrl/envs/libs/smacv2.py +645 -0
  201. torchrl/envs/libs/unity_mlagents.py +891 -0
  202. torchrl/envs/libs/utils.py +147 -0
  203. torchrl/envs/libs/vmas.py +813 -0
  204. torchrl/envs/llm/__init__.py +63 -0
  205. torchrl/envs/llm/chat.py +730 -0
  206. torchrl/envs/llm/datasets/README.md +4 -0
  207. torchrl/envs/llm/datasets/__init__.py +17 -0
  208. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  209. torchrl/envs/llm/datasets/ifeval.py +274 -0
  210. torchrl/envs/llm/envs.py +789 -0
  211. torchrl/envs/llm/libs/README.md +3 -0
  212. torchrl/envs/llm/libs/__init__.py +8 -0
  213. torchrl/envs/llm/libs/mlgym.py +869 -0
  214. torchrl/envs/llm/reward/__init__.py +10 -0
  215. torchrl/envs/llm/reward/gsm8k.py +324 -0
  216. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  217. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  218. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  222. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  223. torchrl/envs/llm/transforms/__init__.py +55 -0
  224. torchrl/envs/llm/transforms/browser.py +292 -0
  225. torchrl/envs/llm/transforms/dataloading.py +859 -0
  226. torchrl/envs/llm/transforms/format.py +73 -0
  227. torchrl/envs/llm/transforms/kl.py +1544 -0
  228. torchrl/envs/llm/transforms/policy_version.py +189 -0
  229. torchrl/envs/llm/transforms/reason.py +323 -0
  230. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  231. torchrl/envs/llm/transforms/tools.py +1955 -0
  232. torchrl/envs/model_based/__init__.py +9 -0
  233. torchrl/envs/model_based/common.py +180 -0
  234. torchrl/envs/model_based/dreamer.py +112 -0
  235. torchrl/envs/transforms/__init__.py +147 -0
  236. torchrl/envs/transforms/functional.py +48 -0
  237. torchrl/envs/transforms/gym_transforms.py +203 -0
  238. torchrl/envs/transforms/module.py +341 -0
  239. torchrl/envs/transforms/r3m.py +372 -0
  240. torchrl/envs/transforms/ray_service.py +663 -0
  241. torchrl/envs/transforms/rb_transforms.py +214 -0
  242. torchrl/envs/transforms/transforms.py +11835 -0
  243. torchrl/envs/transforms/utils.py +94 -0
  244. torchrl/envs/transforms/vc1.py +307 -0
  245. torchrl/envs/transforms/vecnorm.py +845 -0
  246. torchrl/envs/transforms/vip.py +407 -0
  247. torchrl/envs/utils.py +1718 -0
  248. torchrl/envs/vec_envs.py +11 -0
  249. torchrl/modules/__init__.py +206 -0
  250. torchrl/modules/distributions/__init__.py +73 -0
  251. torchrl/modules/distributions/continuous.py +830 -0
  252. torchrl/modules/distributions/discrete.py +908 -0
  253. torchrl/modules/distributions/truncated_normal.py +187 -0
  254. torchrl/modules/distributions/utils.py +233 -0
  255. torchrl/modules/llm/__init__.py +62 -0
  256. torchrl/modules/llm/backends/__init__.py +65 -0
  257. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  258. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  259. torchrl/modules/llm/backends/vllm/base.py +72 -0
  260. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  261. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  262. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  263. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  264. torchrl/modules/llm/policies/__init__.py +28 -0
  265. torchrl/modules/llm/policies/common.py +1809 -0
  266. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  267. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  268. torchrl/modules/llm/utils.py +23 -0
  269. torchrl/modules/mcts/__init__.py +21 -0
  270. torchrl/modules/mcts/scores.py +579 -0
  271. torchrl/modules/models/__init__.py +86 -0
  272. torchrl/modules/models/batchrenorm.py +119 -0
  273. torchrl/modules/models/decision_transformer.py +179 -0
  274. torchrl/modules/models/exploration.py +731 -0
  275. torchrl/modules/models/llm.py +156 -0
  276. torchrl/modules/models/model_based.py +596 -0
  277. torchrl/modules/models/models.py +1712 -0
  278. torchrl/modules/models/multiagent.py +1067 -0
  279. torchrl/modules/models/recipes/impala.py +185 -0
  280. torchrl/modules/models/utils.py +162 -0
  281. torchrl/modules/planners/__init__.py +10 -0
  282. torchrl/modules/planners/cem.py +228 -0
  283. torchrl/modules/planners/common.py +73 -0
  284. torchrl/modules/planners/mppi.py +265 -0
  285. torchrl/modules/tensordict_module/__init__.py +89 -0
  286. torchrl/modules/tensordict_module/actors.py +2457 -0
  287. torchrl/modules/tensordict_module/common.py +529 -0
  288. torchrl/modules/tensordict_module/exploration.py +814 -0
  289. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  290. torchrl/modules/tensordict_module/rnn.py +1639 -0
  291. torchrl/modules/tensordict_module/sequence.py +132 -0
  292. torchrl/modules/tensordict_module/world_models.py +34 -0
  293. torchrl/modules/utils/__init__.py +38 -0
  294. torchrl/modules/utils/mappings.py +9 -0
  295. torchrl/modules/utils/utils.py +89 -0
  296. torchrl/objectives/__init__.py +78 -0
  297. torchrl/objectives/a2c.py +659 -0
  298. torchrl/objectives/common.py +753 -0
  299. torchrl/objectives/cql.py +1346 -0
  300. torchrl/objectives/crossq.py +710 -0
  301. torchrl/objectives/ddpg.py +453 -0
  302. torchrl/objectives/decision_transformer.py +371 -0
  303. torchrl/objectives/deprecated.py +516 -0
  304. torchrl/objectives/dqn.py +683 -0
  305. torchrl/objectives/dreamer.py +488 -0
  306. torchrl/objectives/functional.py +48 -0
  307. torchrl/objectives/gail.py +258 -0
  308. torchrl/objectives/iql.py +996 -0
  309. torchrl/objectives/llm/__init__.py +30 -0
  310. torchrl/objectives/llm/grpo.py +846 -0
  311. torchrl/objectives/llm/sft.py +482 -0
  312. torchrl/objectives/multiagent/__init__.py +8 -0
  313. torchrl/objectives/multiagent/qmixer.py +396 -0
  314. torchrl/objectives/ppo.py +1669 -0
  315. torchrl/objectives/redq.py +683 -0
  316. torchrl/objectives/reinforce.py +530 -0
  317. torchrl/objectives/sac.py +1580 -0
  318. torchrl/objectives/td3.py +570 -0
  319. torchrl/objectives/td3_bc.py +625 -0
  320. torchrl/objectives/utils.py +782 -0
  321. torchrl/objectives/value/__init__.py +28 -0
  322. torchrl/objectives/value/advantages.py +1956 -0
  323. torchrl/objectives/value/functional.py +1459 -0
  324. torchrl/objectives/value/utils.py +360 -0
  325. torchrl/record/__init__.py +17 -0
  326. torchrl/record/loggers/__init__.py +23 -0
  327. torchrl/record/loggers/common.py +48 -0
  328. torchrl/record/loggers/csv.py +226 -0
  329. torchrl/record/loggers/mlflow.py +142 -0
  330. torchrl/record/loggers/tensorboard.py +139 -0
  331. torchrl/record/loggers/trackio.py +163 -0
  332. torchrl/record/loggers/utils.py +78 -0
  333. torchrl/record/loggers/wandb.py +214 -0
  334. torchrl/record/recorder.py +554 -0
  335. torchrl/services/__init__.py +79 -0
  336. torchrl/services/base.py +109 -0
  337. torchrl/services/ray_service.py +453 -0
  338. torchrl/testing/__init__.py +107 -0
  339. torchrl/testing/assertions.py +179 -0
  340. torchrl/testing/dist_utils.py +122 -0
  341. torchrl/testing/env_creators.py +227 -0
  342. torchrl/testing/env_helper.py +35 -0
  343. torchrl/testing/gym_helpers.py +156 -0
  344. torchrl/testing/llm_mocks.py +119 -0
  345. torchrl/testing/mocking_classes.py +2720 -0
  346. torchrl/testing/modules.py +295 -0
  347. torchrl/testing/mp_helpers.py +15 -0
  348. torchrl/testing/ray_helpers.py +293 -0
  349. torchrl/testing/utils.py +190 -0
  350. torchrl/trainers/__init__.py +42 -0
  351. torchrl/trainers/algorithms/__init__.py +11 -0
  352. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  353. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  354. torchrl/trainers/algorithms/configs/common.py +41 -0
  355. torchrl/trainers/algorithms/configs/data.py +308 -0
  356. torchrl/trainers/algorithms/configs/envs.py +104 -0
  357. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  358. torchrl/trainers/algorithms/configs/logging.py +80 -0
  359. torchrl/trainers/algorithms/configs/modules.py +570 -0
  360. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  361. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  362. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  363. torchrl/trainers/algorithms/configs/utils.py +252 -0
  364. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  365. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  366. torchrl/trainers/algorithms/ppo.py +373 -0
  367. torchrl/trainers/algorithms/sac.py +308 -0
  368. torchrl/trainers/helpers/__init__.py +40 -0
  369. torchrl/trainers/helpers/collectors.py +416 -0
  370. torchrl/trainers/helpers/envs.py +573 -0
  371. torchrl/trainers/helpers/logger.py +33 -0
  372. torchrl/trainers/helpers/losses.py +132 -0
  373. torchrl/trainers/helpers/models.py +658 -0
  374. torchrl/trainers/helpers/replay_buffer.py +59 -0
  375. torchrl/trainers/helpers/trainers.py +301 -0
  376. torchrl/trainers/trainers.py +2052 -0
  377. torchrl/weight_update/__init__.py +33 -0
  378. torchrl/weight_update/_distributed.py +749 -0
  379. torchrl/weight_update/_mp.py +624 -0
  380. torchrl/weight_update/_noupdate.py +102 -0
  381. torchrl/weight_update/_ray.py +1032 -0
  382. torchrl/weight_update/_rpc.py +284 -0
  383. torchrl/weight_update/_shared.py +891 -0
  384. torchrl/weight_update/llm/__init__.py +32 -0
  385. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  386. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  387. torchrl/weight_update/utils.py +73 -0
  388. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  389. torchrl-0.11.0.dist-info/METADATA +1308 -0
  390. torchrl-0.11.0.dist-info/RECORD +394 -0
  391. torchrl-0.11.0.dist-info/WHEEL +5 -0
  392. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  393. torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
  394. torchrl-0.11.0.dist-info/top_level.txt +7 -0
@@ -0,0 +1,622 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ from __future__ import annotations
6
+
7
+ import abc
8
+ import json
9
+ import tempfile
10
+ import warnings
11
+ from pathlib import Path
12
+
13
+ import numpy as np
14
+ import torch
15
+ from tensordict import (
16
+ is_tensor_collection,
17
+ lazy_stack,
18
+ NonTensorData,
19
+ PersistentTensorDict,
20
+ TensorDict,
21
+ )
22
+ from tensordict.memmap import MemoryMappedTensor
23
+ from tensordict.utils import _zip_strict
24
+ from torch.utils._pytree import tree_map
25
+ from torchrl._utils import _STRDTYPE2DTYPE
26
+
27
+ from torchrl.data.replay_buffers.utils import (
28
+ _save_pytree,
29
+ Flat2TED,
30
+ H5Combine,
31
+ H5Split,
32
+ Nested2TED,
33
+ TED2Flat,
34
+ TED2Nested,
35
+ )
36
+
37
+
class StorageCheckpointerBase:
    """Public base class for storage checkpointers.

    Each storage checkpointer must implement a ``dumps`` and a ``loads`` method
    that take as input a storage and a path.

    Hooks can be registered with :meth:`register_save_hook` and
    :meth:`register_load_hook`; they are kept in ``_save_hooks`` and
    ``_load_hooks`` in registration order.

    .. note:: ``dumps`` and ``loads`` are decorated with ``abc.abstractmethod``
        but this class does not use ``abc.ABCMeta``, so the decorators are
        informational only and are not enforced at instantiation time.

    """

    def __init__(self):
        self._save_hooks = []
        self._load_hooks = []

    def register_save_hook(self, hook):
        """Registers a save hook for this checkpointer."""
        self._save_hooks.append(hook)

    def register_load_hook(self, hook):
        """Registers a load hook for this checkpointer."""
        self._load_hooks.append(hook)

    def _get_shift_from_last_cursor(self, last_cursor):
        """Computes shift from the last cursor position.

        Args:
            last_cursor: the last cursor written to the storage. Supported
                types are ``slice``, ``int`` (or a NumPy integer scalar),
                ``range``, ``torch.Tensor`` and ``np.ndarray``.

        Returns:
            The last index referenced by the cursor, plus one.

        Raises:
            ValueError: if the cursor type is not one of the supported types.
        """
        if isinstance(last_cursor, slice):
            # NOTE(review): ``slice.stop`` is exclusive, so this returns one
            # more than the ``range`` branch does for the same span - confirm
            # this is intended before changing it.
            return last_cursor.stop + 1
        if isinstance(last_cursor, (int, np.integer)):
            # NumPy integer scalars (e.g. ``np.int64``) are not ``int``
            # instances; accept them and normalise to a Python int.
            return int(last_cursor) + 1
        if isinstance(last_cursor, range):
            return last_cursor[-1] + 1
        if isinstance(last_cursor, (torch.Tensor, np.ndarray)):
            # Flatten so that the last element is well defined for any rank.
            return last_cursor.reshape(-1)[-1].item() + 1
        raise ValueError(f"Unrecognised last_cursor type {type(last_cursor)}.")

    def _set_hooks_shift_is_full(self, storage):
        """Sets shift and is_full attributes on save hooks that have them.

        Reads ``storage._is_full`` and ``storage._last_cursor``. When the last
        cursor is ``None`` a warning is emitted and a shift of ``0`` is used.
        """
        is_full = storage._is_full
        last_cursor = storage._last_cursor
        for hook in self._save_hooks:
            if hasattr(hook, "is_full"):
                hook.is_full = is_full
        if last_cursor is None:
            warnings.warn(
                "last_cursor is None. The replay buffer "
                "may not be saved properly in this setting. To solve this issue, make "
                "sure the storage updates the _last_cursor value during calls to `set`."
            )
            shift = 0
        else:
            shift = self._get_shift_from_last_cursor(last_cursor)
        for hook in self._save_hooks:
            if hasattr(hook, "shift"):
                hook.shift = shift

    @abc.abstractmethod
    def dumps(self, storage, path):
        """Saves ``storage`` under ``path``. To be implemented by subclasses."""
        ...

    @abc.abstractmethod
    def loads(self, storage, path):
        """Loads the content saved under ``path`` into ``storage``. To be implemented by subclasses."""
        ...
99
+
100
+
class ListStorageCheckpointer(StorageCheckpointerBase):
    """A storage checkpointer for ListStorage.

    Currently not implemented: both ``dumps`` and ``loads`` raise
    ``NotImplementedError``.

    """

    @staticmethod
    def dumps(storage, path):
        """Always raises ``NotImplementedError``; the arguments are ignored."""
        message = "ListStorage doesn't support serialization via `dumps` - `loads` API."
        raise NotImplementedError(message)

    @staticmethod
    def loads(storage, path):
        """Always raises ``NotImplementedError``; the arguments are ignored."""
        message = "ListStorage doesn't support serialization via `dumps` - `loads` API."
        raise NotImplementedError(message)
119
+
120
+
class CompressedListStorageCheckpointer(StorageCheckpointerBase):
    """A storage checkpointer for CompressedListStorage.

    This checkpointer saves compressed data and metadata using memory-mapped storage
    for efficient disk I/O and memory usage.

    On-disk layout (all under the target directory):

    - ``compressed_data/``: memory-mapped stack of the non-``None`` items;
    - ``data_indices.json``: one entry per storage slot, either ``None`` or the
      position of that slot's item within the stack;
    - ``metadata.json``: the storage metadata, with torch-specific objects
      encoded as ``{"__type__": ..., "value": ...}`` dictionaries.

    """

    @staticmethod
    def _is_json_leaf(item):
        """Tells ``tree_map`` which metadata values must be converted as a whole."""
        return isinstance(
            item,
            (
                torch.Size,
                torch.dtype,
                torch.device,
                str,
                int,
                float,
                bool,
                torch.Tensor,
                NonTensorData,
            ),
        )

    @staticmethod
    def _to_json_serializable(item):
        """Encodes torch-specific values as tagged dictionaries; returns others unchanged."""
        if isinstance(item, torch.Size):
            return {"__type__": "torch.Size", "value": list(item)}
        if isinstance(item, torch.dtype):
            return {"__type__": "torch.dtype", "value": str(item)}
        if isinstance(item, torch.device):
            return {"__type__": "torch.device", "value": str(item)}
        if isinstance(item, torch.Tensor):
            return {"__type__": "torch.Tensor", "value": item.tolist()}
        if isinstance(item, NonTensorData):
            return {"__type__": "NonTensorData", "value": item.data}
        return item

    @staticmethod
    def _is_tagged_leaf(item):
        """Tells ``tree_map`` that a tagged dictionary must not be traversed."""
        return isinstance(item, dict) and "__type__" in item

    @staticmethod
    def _from_json_serializable(item):
        """Decodes a tagged dictionary written by ``_to_json_serializable``.

        Raises:
            ValueError: if a ``torch.dtype`` entry does not name a torch dtype.
        """
        if not (isinstance(item, dict) and "__type__" in item):
            return item
        tag = item["__type__"]
        value = item["value"]
        if tag == "torch.Size":
            return torch.Size(value)
        if tag == "torch.dtype":
            # SECURITY: the dtype string comes from a file on disk. It used to
            # be passed to ``eval`` when the attribute lookup failed, which
            # would execute arbitrary code from a tampered checkpoint. Only a
            # validated attribute lookup is performed now.
            dtype = getattr(torch, value.replace("torch.", ""), None)
            if not isinstance(dtype, torch.dtype):
                raise ValueError(
                    f"Unrecognised torch dtype {value!r} in checkpoint metadata."
                )
            return dtype
        if tag == "torch.device":
            return torch.device(value)
        if tag == "torch.Tensor":
            return torch.tensor(value)
        if tag == "NonTensorData":
            return NonTensorData(value)
        return item

    def dumps(self, storage, path):
        """Save compressed storage to disk using memory-mapped storage.

        Args:
            storage: The CompressedListStorage instance to save
            path: Directory path where to save the storage

        Raises:
            RuntimeError: if the storage has no ``_storage`` attribute or if it
                is empty.
        """
        import shutil

        path = Path(path)
        path.mkdir(exist_ok=True)

        if not hasattr(storage, "_storage") or len(storage._storage) == 0:
            raise RuntimeError(
                "Cannot save an empty or non-initialized CompressedListStorage."
            )

        # Get state dict from storage
        state_dict = storage.state_dict()
        compressed_data = state_dict["_storage"]
        metadata = state_dict["_metadata"]

        # Everything is first written to a temporary directory and copied to
        # ``path`` at the end.
        with tempfile.TemporaryDirectory() as tmp_dir:
            tmp_path = Path(tmp_dir)

            # ``data_indices`` always has one entry per storage slot (``None``
            # for empty slots) so that the storage length and the position of
            # empty slots can be restored, even when every slot is empty.
            data_indices = []
            non_none_data = []
            for item in compressed_data:
                if item is None:
                    data_indices.append(None)
                    continue
                if isinstance(item, dict):
                    # Dict data (tensordict fields) maps directly to a TensorDict
                    processed_item = TensorDict(item, batch_size=[])
                else:
                    # Tensors and any other type are wrapped under a "data" key
                    processed_item = TensorDict({"data": item}, batch_size=[])
                data_indices.append(len(non_none_data))
                non_none_data.append(processed_item)

            if non_none_data:
                # Use lazy_stack to handle heterogeneous structures
                stacked_data = lazy_stack(non_none_data)
                stacked_data.memmap_(tmp_path / "compressed_data")

            # Process metadata for JSON serialization
            serializable_metadata = tree_map(
                self._to_json_serializable, metadata, is_leaf=self._is_json_leaf
            )

            # Save metadata and indices
            metadata_file = tmp_path / "metadata.json"
            with open(metadata_file, "w") as f:
                json.dump(serializable_metadata, f, indent=2)

            indices_file = tmp_path / "data_indices.json"
            with open(indices_file, "w") as f:
                json.dump(data_indices, f, indent=2)

            # Copy all files from temp directory to final destination
            for item in tmp_path.iterdir():
                if item.is_file():
                    shutil.copy2(item, path / item.name)
                elif item.is_dir():
                    shutil.copytree(item, path / item.name, dirs_exist_ok=True)

    def loads(self, storage, path):
        """Load compressed storage from disk.

        Args:
            storage: The CompressedListStorage instance to load into
            path: Directory path where the storage was saved

        Raises:
            FileNotFoundError: if ``metadata.json`` or ``data_indices.json`` is
                missing from ``path``.
            ValueError: if the number of stored items does not match the number
                of non-``None`` entries in the data indices, or if the metadata
                holds an unknown dtype.
        """
        path = Path(path)

        # Load metadata
        metadata_file = path / "metadata.json"
        if not metadata_file.exists():
            raise FileNotFoundError(f"Metadata file not found at {metadata_file}")

        with open(metadata_file) as f:
            serializable_metadata = json.load(f)

        # Load data indices
        indices_file = path / "data_indices.json"
        if not indices_file.exists():
            raise FileNotFoundError(f"Data indices file not found at {indices_file}")

        with open(indices_file) as f:
            data_indices = json.load(f)

        # Convert serializable metadata back to original format
        metadata = tree_map(
            self._from_json_serializable,
            serializable_metadata,
            is_leaf=self._is_tagged_leaf,
        )

        memmap_path = path / "compressed_data"
        if memmap_path.exists():
            # Load the memmapped data: only non-None slots were stacked
            stored_items = TensorDict.load_memmap(memmap_path).tolist()
            num_expected = sum(index is not None for index in data_indices)
            if len(stored_items) != num_expected:
                raise ValueError(
                    f"Length of compressed data ({len(stored_items)}) does not match "
                    f"the number of non-None data indices ({num_expected})"
                )
            # Rebuild the slot list, putting ``None`` back where it was saved.
            compressed_data = []
            for index, mtdt in _zip_strict(data_indices, metadata):
                if index is None:
                    compressed_data.append(None)
                    continue
                data = stored_items[index]
                if mtdt["type"] == "tensor":
                    # Tensors were wrapped under a "data" key when saved
                    data = data["data"]
                compressed_data.append(data)
        else:
            # No data to load
            compressed_data = [None] * len(data_indices)

        # Load into storage
        storage._storage = compressed_data
        storage._metadata = metadata
324
+
325
+
326
+ class TensorStorageCheckpointer(StorageCheckpointerBase):
327
+ """A storage checkpointer for TensorStorages.
328
+
329
+ This class supports TensorDict-based storages as well as pytrees.
330
+
331
+ This class will call save and load hooks if provided. These hooks should take as input the
332
+ data being transformed as well as the path where the data should be saved.
333
+
334
+ """
335
+
def dumps(self, storage, path):
    """Write ``storage`` to the directory ``path``.

    The storage content is passed through every registered save hook, then
    saved either as a memory-mapped tensor collection or, for any other
    structure, through ``_save_pytree``. A ``storage_metadata.json`` file
    recording the pytree metadata, the ``is_pytree`` flag and the storage
    length is written alongside.

    Args:
        storage: the storage to save; it must be initialized.
        path: directory where the data is written (created if missing).

    Raises:
        RuntimeError: if the storage is not initialized.
    """
    target_dir = Path(path)
    target_dir.mkdir(exist_ok=True)

    if not storage.initialized:
        raise RuntimeError("Cannot save a non-initialized storage.")

    pytree_metadata = {}
    data = storage._storage

    # Hooks exposing ``shift`` / ``is_full`` are updated before they run.
    self._set_hooks_shift_is_full(storage)
    for save_hook in self._save_hooks:
        data = save_hook(data, path=target_dir)

    is_pytree = not is_tensor_collection(data)
    if is_pytree:
        _save_pytree(data, pytree_metadata, target_dir)
    else:
        already_at_target = (
            data.is_memmap()
            and data.saved_path
            and Path(data.saved_path).absolute() == Path(target_dir).absolute()
        )
        if already_at_target:
            # The data is already memory-mapped at this location.
            data.memmap_refresh_()
        else:
            # try to load the path and overwrite.
            data.memmap(target_dir, copy_existing=True)

    summary = {
        "metadata": pytree_metadata,
        "is_pytree": is_pytree,
        "len": storage._len,
    }
    with open(target_dir / "storage_metadata.json", "w") as file:
        json.dump(summary, file)
376
+
377
+ def loads(self, storage, path):
378
+ with open(path / "storage_metadata.json") as file:
379
+ metadata = json.load(file)
380
+ is_pytree = metadata["is_pytree"]
381
+ _len = metadata["len"]
382
+ if is_pytree:
383
+ if self._load_hooks:
384
+ raise RuntimeError(
385
+ "Loading hooks are not compatible with PyTree storages."
386
+ )
387
+ path = Path(path)
388
+ for local_path, md in metadata["metadata"].items():
389
+ # load tensor
390
+ local_path_dot = local_path.replace(".", "/")
391
+ total_tensor_path = path / (local_path_dot + ".memmap")
392
+ shape = torch.Size(md["shape"])
393
+ dtype = _STRDTYPE2DTYPE[md["dtype"]]
394
+ tensor = MemoryMappedTensor.from_filename(
395
+ filename=total_tensor_path, shape=shape, dtype=dtype
396
+ )
397
+ # split path
398
+ local_path = local_path.split(".")
399
+ # replace potential dots
400
+ local_path = [_path.replace("_<dot>_", ".") for _path in local_path]
401
+ if storage.initialized:
402
+ # copy in-place
403
+ _storage_tensor = storage._storage
404
+ # in this case there is a single tensor, so we skip
405
+ if local_path != ["_-single-tensor-_"]:
406
+ for _path in local_path:
407
+ if _path.isdigit():
408
+ _path_attempt = int(_path)
409
+ try:
410
+ _storage_tensor = _storage_tensor[_path_attempt]
411
+ continue
412
+ except IndexError:
413
+ pass
414
+ _storage_tensor = _storage_tensor[_path]
415
+ _storage_tensor.copy_(tensor)
416
+ else:
417
+ raise RuntimeError(
418
+ "Cannot fill a non-initialized pytree-based TensorStorage."
419
+ )
420
+ else:
421
+ _storage = TensorDict.load_memmap(path)
422
+ if storage.initialized:
423
+ dest = storage._storage
424
+ else:
425
+ # TODO: This could load the RAM a lot, maybe try to catch this within the hook and use memmap instead
426
+ dest = None
427
+ for hook in self._load_hooks:
428
+ _storage = hook(_storage, out=dest)
429
+ if not storage.initialized:
430
+ from torchrl.data.replay_buffers.storages import LazyMemmapStorage
431
+
432
+ if (
433
+ isinstance(storage, LazyMemmapStorage)
434
+ and storage.scratch_dir
435
+ and Path(storage.scratch_dir).absolute() == Path(path).absolute()
436
+ ):
437
+ storage._storage = TensorDict.load_memmap(path)
438
+ storage.initialized = True
439
+ else:
440
+ # this should not be reached if is_pytree=True
441
+ storage._init(_storage[0])
442
+ storage._storage.update_(_storage)
443
+ elif (
444
+ storage._storage.is_memmap()
445
+ and storage._storage.saved_path
446
+ and Path(storage._storage.saved_path).absolute()
447
+ == Path(path).absolute()
448
+ ):
449
+ # If the storage is already where it should be, we don't need to load anything.
450
+ storage._storage.memmap_refresh_()
451
+
452
+ else:
453
+ storage._storage.copy_(_storage)
454
+ storage._len = _len
455
+
456
+
457
class FlatStorageCheckpointer(TensorStorageCheckpointer):
    """Saves the storage in a compact form, saving space on the TED format.

    This class explicitly assumes and does NOT check that:

    - done states (including terminated and truncated) at the root are always False;
    - observations in the "next" tensordict are shifted by one step in the future (this
      is not the case when a multi-step transform is used for instance) unless `done` is `True`
      in which case the observation in `("next", key)` at time `t` and the one in `key` at time
      `t+1` should not match.

    .. seealso:: The full list of arguments can be found in :class:`~torchrl.data.TED2Flat`.

    """

    def __init__(self, done_keys=None, reward_keys=None):
        super().__init__()
        # Only forward the options that were explicitly provided, so that the
        # hooks fall back on their own defaults for the others.
        provided = {"done_keys": done_keys, "reward_keys": reward_keys}
        hook_kwargs = {
            name: value for name, value in provided.items() if value is not None
        }
        self._save_hooks = [TED2Flat(**hook_kwargs)]
        self._load_hooks = [Flat2TED(**hook_kwargs)]
481
+
482
+
483
class NestedStorageCheckpointer(FlatStorageCheckpointer):
    """Saves the storage in a compact form, saving space on the TED format and using memory-mapped nested tensors.

    This class explicitly assumes and does NOT check that:

    - done states (including terminated and truncated) at the root are always False;
    - observations in the "next" tensordict are shifted by one step in the future (this
      is not the case when a multi-step transform is used for instance).

    .. seealso:: The full list of arguments can be found in :class:`~torchrl.data.TED2Flat`.

    """

    def __init__(self, done_keys=None, reward_keys=None):
        # The parent constructor is called without arguments; the hooks it
        # installs are replaced by the nested variants below.
        super().__init__()
        provided = {"done_keys": done_keys, "reward_keys": reward_keys}
        hook_kwargs = {
            name: value for name, value in provided.items() if value is not None
        }
        self._save_hooks = [TED2Nested(**hook_kwargs)]
        self._load_hooks = [Nested2TED(**hook_kwargs)]
505
+
506
+
507
class H5StorageCheckpointer(NestedStorageCheckpointer):
    """Saves the storage in a compact form, saving space on the TED format and using H5 format to save the data.

    This class explicitly assumes and does NOT check that:

    - done states (including terminated and truncated) at the root are always False;
    - observations in the "next" tensordict are shifted by one step in the future (this
      is not the case when a multi-step transform is used for instance).

    Keyword Args:
        checkpoint_file: the filename where to save the checkpointed data.
            This will be ignored iff the path passed to dumps / loads ends with the ``.h5``
            suffix. Defaults to ``"checkpoint.h5"``.
        h5_kwargs (Dict[str, Any] or Tuple[Tuple[str, Any], ...]): kwargs to be
            passed to :meth:`h5py.File.create_dataset`.

    .. note:: To prevent out-of-memory issues, the data of the H5 file will be temporarily written
        on memory-mapped tensors stored in shared file system. The physical memory usage may increase
        during loading as a consequence.

    .. seealso:: The full list of arguments can be found in :class:`~torchrl.data.TED2Flat`. Note that this class only
        supports keyword arguments.

    """

    def __init__(
        self,
        *,
        checkpoint_file: str = "checkpoint.h5",
        done_keys=None,
        reward_keys=None,
        h5_kwargs=None,
        **kwargs,
    ):
        # The intermediate constructors are bypassed on purpose: the hooks
        # are set explicitly below.
        StorageCheckpointerBase.__init__(self)
        ted2_kwargs = kwargs
        if done_keys is not None:
            ted2_kwargs["done_keys"] = done_keys
        if reward_keys is not None:
            ted2_kwargs["reward_keys"] = reward_keys
        self._save_hooks = [TED2Nested(**ted2_kwargs), H5Split()]
        self._load_hooks = [H5Combine(), Nested2TED(**ted2_kwargs)]
        self.kwargs = {} if h5_kwargs is None else dict(h5_kwargs)
        self.checkpoint_file = checkpoint_file

    def dumps(self, storage, path):
        """Write the content of ``storage`` to an H5 file.

        Args:
            storage: the storage to save. Must be initialized.
            path (str or Path): either a file ending in ``.h5`` or a directory
                in which ``checkpoint_file`` is created.

        Raises:
            RuntimeError: if the storage is not initialized or the checkpoint
                directory cannot be created.
            ValueError: if the data produced by the save hooks is not a
                tensor collection.
        """
        path = self._get_path(path)

        self._set_hooks_shift_is_full(storage)

        if not storage.initialized:
            raise RuntimeError("Cannot save a non-initialized storage.")
        _storage = storage._storage
        length = storage._len
        for hook in self._save_hooks:
            # we don't pass a path here since we're not reusing the tensordict
            _storage = hook(_storage)
        if is_tensor_collection(_storage):
            # try to load the path and overwrite.
            data = PersistentTensorDict.from_dict(_storage, path, **self.kwargs)
            # The storage length is stored alongside the data in the same file.
            data["_len"] = NonTensorData(data=length)
        else:
            raise ValueError("Only tensor collections are supported.")

    def loads(self, storage, path):
        """Populate ``storage`` from an H5 checkpoint written by :meth:`dumps`.

        Args:
            storage: the storage to fill; it is initialized from the first
                loaded element when needed.
            path (str or Path): either a file ending in ``.h5`` or a directory
                containing ``checkpoint_file``.
        """
        path = self._get_path(path)
        data = PersistentTensorDict.from_h5(path)
        if storage.initialized:
            dest = storage._storage
        else:
            # TODO: This could load the RAM a lot, maybe try to catch this within the hook and use memmap instead
            dest = None
        # Read the length before the hooks transform ``data``.
        _len = data["_len"]
        for hook in self._load_hooks:
            data = hook(data, out=dest)
        if not storage.initialized:
            storage._init(data[0])
            storage._storage.update_(data)
        else:
            storage._storage.copy_(data)
        storage._len = _len

    def _get_path(self, path):
        """Return the absolute path, as a string, of the H5 file matching ``path``.

        A path ending in ``.h5`` is used as is. Any other path is treated as
        a directory, created if needed, to which ``checkpoint_file`` is
        appended.

        Raises:
            RuntimeError: if the checkpoint directory cannot be created.
        """
        path = Path(path)
        if path.suffix == ".h5":
            return str(path.absolute())
        try:
            path.mkdir(exist_ok=True)
        except Exception as err:
            # Chain the original error so the actual cause (permissions,
            # missing parent directory, ...) stays visible in the traceback.
            raise RuntimeError(
                f"Failed to create the checkpoint directory {path}."
            ) from err
        path = path / self.checkpoint_file
        return str(path.absolute())
600
+
601
+
602
class StorageEnsembleCheckpointer(StorageCheckpointerBase):
    """Checkpointer for ensemble storages."""

    @staticmethod
    def dumps(storage, path: Path):
        """Save every sub-storage of the ensemble, then the ensemble's transforms.

        Sub-storage ``i`` is saved under ``path / str(i)`` and the state dict
        of transform ``i`` under ``path / f"{i}_transform.pt"``.

        Args:
            storage: the ensemble storage, exposing ``_storages`` and
                ``_transforms``.
            path (str or Path): destination directory, created if missing.
        """
        path = Path(path).absolute()
        # The transform files are written directly into ``path``.
        path.mkdir(parents=True, exist_ok=True)
        # Use a dedicated loop variable: rebinding ``storage`` here would make
        # the ``_transforms`` lookup below target the last sub-storage rather
        # than the ensemble itself.
        for i, sub_storage in enumerate(storage._storages):
            sub_storage.dumps(path / str(i))
        if storage._transforms is not None:
            for i, transform in enumerate(storage._transforms):
                torch.save(transform.state_dict(), path / f"{i}_transform.pt")

    @staticmethod
    def loads(storage, path: Path):
        """Load every sub-storage of the ensemble, then the ensemble's transforms.

        Expects the layout written by :meth:`dumps`.

        Args:
            storage: the ensemble storage, exposing ``_storages`` and
                ``_transforms``.
            path (str or Path): directory containing the checkpoint.
        """
        path = Path(path).absolute()
        for i, _storage in enumerate(storage._storages):
            _storage.loads(path / str(i))
        if storage._transforms is not None:
            for i, transform in enumerate(storage._transforms):
                # NOTE: torch.load deserializes files from disk; only load
                # checkpoints coming from a trusted source.
                transform.load_state_dict(torch.load(path / f"{i}_transform.pt"))