torchrl 0.11.0__cp314-cp314t-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (395)
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/.dylibs/libc++.1.0.dylib +0 -0
  97. torchrl/__init__.py +144 -0
  98. torchrl/_extension.py +74 -0
  99. torchrl/_torchrl.cpython-314t-darwin.so +0 -0
  100. torchrl/_utils.py +1431 -0
  101. torchrl/collectors/__init__.py +48 -0
  102. torchrl/collectors/_base.py +1058 -0
  103. torchrl/collectors/_constants.py +88 -0
  104. torchrl/collectors/_multi_async.py +324 -0
  105. torchrl/collectors/_multi_base.py +1805 -0
  106. torchrl/collectors/_multi_sync.py +464 -0
  107. torchrl/collectors/_runner.py +581 -0
  108. torchrl/collectors/_single.py +2009 -0
  109. torchrl/collectors/_single_async.py +259 -0
  110. torchrl/collectors/collectors.py +62 -0
  111. torchrl/collectors/distributed/__init__.py +32 -0
  112. torchrl/collectors/distributed/default_configs.py +133 -0
  113. torchrl/collectors/distributed/generic.py +1306 -0
  114. torchrl/collectors/distributed/ray.py +1092 -0
  115. torchrl/collectors/distributed/rpc.py +1006 -0
  116. torchrl/collectors/distributed/sync.py +731 -0
  117. torchrl/collectors/distributed/utils.py +160 -0
  118. torchrl/collectors/llm/__init__.py +10 -0
  119. torchrl/collectors/llm/base.py +494 -0
  120. torchrl/collectors/llm/ray_collector.py +275 -0
  121. torchrl/collectors/llm/utils.py +36 -0
  122. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  123. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  124. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  125. torchrl/collectors/utils.py +433 -0
  126. torchrl/collectors/weight_update.py +591 -0
  127. torchrl/csrc/numpy_utils.h +38 -0
  128. torchrl/csrc/pybind.cpp +27 -0
  129. torchrl/csrc/segment_tree.h +458 -0
  130. torchrl/csrc/torch_utils.h +34 -0
  131. torchrl/csrc/utils.cpp +48 -0
  132. torchrl/csrc/utils.h +31 -0
  133. torchrl/data/__init__.py +187 -0
  134. torchrl/data/datasets/__init__.py +58 -0
  135. torchrl/data/datasets/atari_dqn.py +878 -0
  136. torchrl/data/datasets/common.py +281 -0
  137. torchrl/data/datasets/d4rl.py +489 -0
  138. torchrl/data/datasets/d4rl_infos.py +187 -0
  139. torchrl/data/datasets/gen_dgrl.py +375 -0
  140. torchrl/data/datasets/minari_data.py +643 -0
  141. torchrl/data/datasets/openml.py +177 -0
  142. torchrl/data/datasets/openx.py +798 -0
  143. torchrl/data/datasets/roboset.py +363 -0
  144. torchrl/data/datasets/utils.py +11 -0
  145. torchrl/data/datasets/vd4rl.py +432 -0
  146. torchrl/data/llm/__init__.py +34 -0
  147. torchrl/data/llm/dataset.py +491 -0
  148. torchrl/data/llm/history.py +1378 -0
  149. torchrl/data/llm/prompt.py +198 -0
  150. torchrl/data/llm/reward.py +225 -0
  151. torchrl/data/llm/topk.py +186 -0
  152. torchrl/data/llm/utils.py +543 -0
  153. torchrl/data/map/__init__.py +21 -0
  154. torchrl/data/map/hash.py +185 -0
  155. torchrl/data/map/query.py +204 -0
  156. torchrl/data/map/tdstorage.py +363 -0
  157. torchrl/data/map/tree.py +1434 -0
  158. torchrl/data/map/utils.py +103 -0
  159. torchrl/data/postprocs/__init__.py +8 -0
  160. torchrl/data/postprocs/postprocs.py +391 -0
  161. torchrl/data/replay_buffers/__init__.py +99 -0
  162. torchrl/data/replay_buffers/checkpointers.py +622 -0
  163. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  164. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  165. torchrl/data/replay_buffers/samplers.py +2578 -0
  166. torchrl/data/replay_buffers/scheduler.py +265 -0
  167. torchrl/data/replay_buffers/storages.py +2412 -0
  168. torchrl/data/replay_buffers/utils.py +1042 -0
  169. torchrl/data/replay_buffers/writers.py +781 -0
  170. torchrl/data/tensor_specs.py +7101 -0
  171. torchrl/data/utils.py +334 -0
  172. torchrl/envs/__init__.py +265 -0
  173. torchrl/envs/async_envs.py +1105 -0
  174. torchrl/envs/batched_envs.py +3093 -0
  175. torchrl/envs/common.py +4241 -0
  176. torchrl/envs/custom/__init__.py +11 -0
  177. torchrl/envs/custom/chess.py +617 -0
  178. torchrl/envs/custom/llm.py +214 -0
  179. torchrl/envs/custom/pendulum.py +401 -0
  180. torchrl/envs/custom/san_moves.txt +29274 -0
  181. torchrl/envs/custom/tictactoeenv.py +288 -0
  182. torchrl/envs/env_creator.py +263 -0
  183. torchrl/envs/gym_like.py +752 -0
  184. torchrl/envs/libs/__init__.py +68 -0
  185. torchrl/envs/libs/_gym_utils.py +326 -0
  186. torchrl/envs/libs/brax.py +846 -0
  187. torchrl/envs/libs/dm_control.py +544 -0
  188. torchrl/envs/libs/envpool.py +447 -0
  189. torchrl/envs/libs/gym.py +2239 -0
  190. torchrl/envs/libs/habitat.py +138 -0
  191. torchrl/envs/libs/isaac_lab.py +87 -0
  192. torchrl/envs/libs/isaacgym.py +203 -0
  193. torchrl/envs/libs/jax_utils.py +166 -0
  194. torchrl/envs/libs/jumanji.py +963 -0
  195. torchrl/envs/libs/meltingpot.py +599 -0
  196. torchrl/envs/libs/openml.py +153 -0
  197. torchrl/envs/libs/openspiel.py +652 -0
  198. torchrl/envs/libs/pettingzoo.py +1042 -0
  199. torchrl/envs/libs/procgen.py +351 -0
  200. torchrl/envs/libs/robohive.py +429 -0
  201. torchrl/envs/libs/smacv2.py +645 -0
  202. torchrl/envs/libs/unity_mlagents.py +891 -0
  203. torchrl/envs/libs/utils.py +147 -0
  204. torchrl/envs/libs/vmas.py +813 -0
  205. torchrl/envs/llm/__init__.py +63 -0
  206. torchrl/envs/llm/chat.py +730 -0
  207. torchrl/envs/llm/datasets/README.md +4 -0
  208. torchrl/envs/llm/datasets/__init__.py +17 -0
  209. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  210. torchrl/envs/llm/datasets/ifeval.py +274 -0
  211. torchrl/envs/llm/envs.py +789 -0
  212. torchrl/envs/llm/libs/README.md +3 -0
  213. torchrl/envs/llm/libs/__init__.py +8 -0
  214. torchrl/envs/llm/libs/mlgym.py +869 -0
  215. torchrl/envs/llm/reward/__init__.py +10 -0
  216. torchrl/envs/llm/reward/gsm8k.py +324 -0
  217. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  218. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  222. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  223. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  224. torchrl/envs/llm/transforms/__init__.py +55 -0
  225. torchrl/envs/llm/transforms/browser.py +292 -0
  226. torchrl/envs/llm/transforms/dataloading.py +859 -0
  227. torchrl/envs/llm/transforms/format.py +73 -0
  228. torchrl/envs/llm/transforms/kl.py +1544 -0
  229. torchrl/envs/llm/transforms/policy_version.py +189 -0
  230. torchrl/envs/llm/transforms/reason.py +323 -0
  231. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  232. torchrl/envs/llm/transforms/tools.py +1955 -0
  233. torchrl/envs/model_based/__init__.py +9 -0
  234. torchrl/envs/model_based/common.py +180 -0
  235. torchrl/envs/model_based/dreamer.py +112 -0
  236. torchrl/envs/transforms/__init__.py +147 -0
  237. torchrl/envs/transforms/functional.py +48 -0
  238. torchrl/envs/transforms/gym_transforms.py +203 -0
  239. torchrl/envs/transforms/module.py +341 -0
  240. torchrl/envs/transforms/r3m.py +372 -0
  241. torchrl/envs/transforms/ray_service.py +663 -0
  242. torchrl/envs/transforms/rb_transforms.py +214 -0
  243. torchrl/envs/transforms/transforms.py +11835 -0
  244. torchrl/envs/transforms/utils.py +94 -0
  245. torchrl/envs/transforms/vc1.py +307 -0
  246. torchrl/envs/transforms/vecnorm.py +845 -0
  247. torchrl/envs/transforms/vip.py +407 -0
  248. torchrl/envs/utils.py +1718 -0
  249. torchrl/envs/vec_envs.py +11 -0
  250. torchrl/modules/__init__.py +206 -0
  251. torchrl/modules/distributions/__init__.py +73 -0
  252. torchrl/modules/distributions/continuous.py +830 -0
  253. torchrl/modules/distributions/discrete.py +908 -0
  254. torchrl/modules/distributions/truncated_normal.py +187 -0
  255. torchrl/modules/distributions/utils.py +233 -0
  256. torchrl/modules/llm/__init__.py +62 -0
  257. torchrl/modules/llm/backends/__init__.py +65 -0
  258. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  259. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  260. torchrl/modules/llm/backends/vllm/base.py +72 -0
  261. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  262. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  263. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  264. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  265. torchrl/modules/llm/policies/__init__.py +28 -0
  266. torchrl/modules/llm/policies/common.py +1809 -0
  267. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  268. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  269. torchrl/modules/llm/utils.py +23 -0
  270. torchrl/modules/mcts/__init__.py +21 -0
  271. torchrl/modules/mcts/scores.py +579 -0
  272. torchrl/modules/models/__init__.py +86 -0
  273. torchrl/modules/models/batchrenorm.py +119 -0
  274. torchrl/modules/models/decision_transformer.py +179 -0
  275. torchrl/modules/models/exploration.py +731 -0
  276. torchrl/modules/models/llm.py +156 -0
  277. torchrl/modules/models/model_based.py +596 -0
  278. torchrl/modules/models/models.py +1712 -0
  279. torchrl/modules/models/multiagent.py +1067 -0
  280. torchrl/modules/models/recipes/impala.py +185 -0
  281. torchrl/modules/models/utils.py +162 -0
  282. torchrl/modules/planners/__init__.py +10 -0
  283. torchrl/modules/planners/cem.py +228 -0
  284. torchrl/modules/planners/common.py +73 -0
  285. torchrl/modules/planners/mppi.py +265 -0
  286. torchrl/modules/tensordict_module/__init__.py +89 -0
  287. torchrl/modules/tensordict_module/actors.py +2457 -0
  288. torchrl/modules/tensordict_module/common.py +529 -0
  289. torchrl/modules/tensordict_module/exploration.py +814 -0
  290. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  291. torchrl/modules/tensordict_module/rnn.py +1639 -0
  292. torchrl/modules/tensordict_module/sequence.py +132 -0
  293. torchrl/modules/tensordict_module/world_models.py +34 -0
  294. torchrl/modules/utils/__init__.py +38 -0
  295. torchrl/modules/utils/mappings.py +9 -0
  296. torchrl/modules/utils/utils.py +89 -0
  297. torchrl/objectives/__init__.py +78 -0
  298. torchrl/objectives/a2c.py +659 -0
  299. torchrl/objectives/common.py +753 -0
  300. torchrl/objectives/cql.py +1346 -0
  301. torchrl/objectives/crossq.py +710 -0
  302. torchrl/objectives/ddpg.py +453 -0
  303. torchrl/objectives/decision_transformer.py +371 -0
  304. torchrl/objectives/deprecated.py +516 -0
  305. torchrl/objectives/dqn.py +683 -0
  306. torchrl/objectives/dreamer.py +488 -0
  307. torchrl/objectives/functional.py +48 -0
  308. torchrl/objectives/gail.py +258 -0
  309. torchrl/objectives/iql.py +996 -0
  310. torchrl/objectives/llm/__init__.py +30 -0
  311. torchrl/objectives/llm/grpo.py +846 -0
  312. torchrl/objectives/llm/sft.py +482 -0
  313. torchrl/objectives/multiagent/__init__.py +8 -0
  314. torchrl/objectives/multiagent/qmixer.py +396 -0
  315. torchrl/objectives/ppo.py +1669 -0
  316. torchrl/objectives/redq.py +683 -0
  317. torchrl/objectives/reinforce.py +530 -0
  318. torchrl/objectives/sac.py +1580 -0
  319. torchrl/objectives/td3.py +570 -0
  320. torchrl/objectives/td3_bc.py +625 -0
  321. torchrl/objectives/utils.py +782 -0
  322. torchrl/objectives/value/__init__.py +28 -0
  323. torchrl/objectives/value/advantages.py +1956 -0
  324. torchrl/objectives/value/functional.py +1459 -0
  325. torchrl/objectives/value/utils.py +360 -0
  326. torchrl/record/__init__.py +17 -0
  327. torchrl/record/loggers/__init__.py +23 -0
  328. torchrl/record/loggers/common.py +48 -0
  329. torchrl/record/loggers/csv.py +226 -0
  330. torchrl/record/loggers/mlflow.py +142 -0
  331. torchrl/record/loggers/tensorboard.py +139 -0
  332. torchrl/record/loggers/trackio.py +163 -0
  333. torchrl/record/loggers/utils.py +78 -0
  334. torchrl/record/loggers/wandb.py +214 -0
  335. torchrl/record/recorder.py +554 -0
  336. torchrl/services/__init__.py +79 -0
  337. torchrl/services/base.py +109 -0
  338. torchrl/services/ray_service.py +453 -0
  339. torchrl/testing/__init__.py +107 -0
  340. torchrl/testing/assertions.py +179 -0
  341. torchrl/testing/dist_utils.py +122 -0
  342. torchrl/testing/env_creators.py +227 -0
  343. torchrl/testing/env_helper.py +35 -0
  344. torchrl/testing/gym_helpers.py +156 -0
  345. torchrl/testing/llm_mocks.py +119 -0
  346. torchrl/testing/mocking_classes.py +2720 -0
  347. torchrl/testing/modules.py +295 -0
  348. torchrl/testing/mp_helpers.py +15 -0
  349. torchrl/testing/ray_helpers.py +293 -0
  350. torchrl/testing/utils.py +190 -0
  351. torchrl/trainers/__init__.py +42 -0
  352. torchrl/trainers/algorithms/__init__.py +11 -0
  353. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  354. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  355. torchrl/trainers/algorithms/configs/common.py +41 -0
  356. torchrl/trainers/algorithms/configs/data.py +308 -0
  357. torchrl/trainers/algorithms/configs/envs.py +104 -0
  358. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  359. torchrl/trainers/algorithms/configs/logging.py +80 -0
  360. torchrl/trainers/algorithms/configs/modules.py +570 -0
  361. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  362. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  363. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  364. torchrl/trainers/algorithms/configs/utils.py +252 -0
  365. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  366. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  367. torchrl/trainers/algorithms/ppo.py +373 -0
  368. torchrl/trainers/algorithms/sac.py +308 -0
  369. torchrl/trainers/helpers/__init__.py +40 -0
  370. torchrl/trainers/helpers/collectors.py +416 -0
  371. torchrl/trainers/helpers/envs.py +573 -0
  372. torchrl/trainers/helpers/logger.py +33 -0
  373. torchrl/trainers/helpers/losses.py +132 -0
  374. torchrl/trainers/helpers/models.py +658 -0
  375. torchrl/trainers/helpers/replay_buffer.py +59 -0
  376. torchrl/trainers/helpers/trainers.py +301 -0
  377. torchrl/trainers/trainers.py +2052 -0
  378. torchrl/weight_update/__init__.py +33 -0
  379. torchrl/weight_update/_distributed.py +749 -0
  380. torchrl/weight_update/_mp.py +624 -0
  381. torchrl/weight_update/_noupdate.py +102 -0
  382. torchrl/weight_update/_ray.py +1032 -0
  383. torchrl/weight_update/_rpc.py +284 -0
  384. torchrl/weight_update/_shared.py +891 -0
  385. torchrl/weight_update/llm/__init__.py +32 -0
  386. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  387. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  388. torchrl/weight_update/utils.py +73 -0
  389. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  390. torchrl-0.11.0.dist-info/METADATA +1308 -0
  391. torchrl-0.11.0.dist-info/RECORD +395 -0
  392. torchrl-0.11.0.dist-info/WHEEL +5 -0
  393. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  394. torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
  395. torchrl-0.11.0.dist-info/top_level.txt +7 -0
torchrl/data/replay_buffers/utils.py
@@ -0,0 +1,1042 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ # import tree
6
+ from __future__ import annotations
7
+
8
+ import contextlib
9
+ import itertools
10
+ import math
11
+ import operator
12
+ import os
13
+ import typing
14
+ from collections.abc import Callable
15
+ from pathlib import Path
16
+ from typing import Any, Union
17
+
18
+ import numpy as np
19
+ import torch
20
+ from tensordict import (
21
+ lazy_stack,
22
+ MemoryMappedTensor,
23
+ NonTensorData,
24
+ TensorDict,
25
+ TensorDictBase,
26
+ unravel_key,
27
+ )
28
+ from torch import Tensor
29
+ from torch.nn import functional as F
30
+ from torch.utils._pytree import LeafSpec, tree_flatten, tree_unflatten
31
+ from torchrl._utils import implement_for, logger as torchrl_logger
32
+
33
+ SINGLE_TENSOR_BUFFER_NAME = os.environ.get(
34
+ "SINGLE_TENSOR_BUFFER_NAME", "_-single-tensor-_"
35
+ )
36
+
37
+
38
+ INT_CLASSES_TYPING = Union[int, np.integer]
39
+ if hasattr(typing, "get_args"):
40
+ INT_CLASSES = typing.get_args(INT_CLASSES_TYPING)
41
+ else:
42
+ # python 3.7
43
+ INT_CLASSES = (int, np.integer)
44
+
45
+
46
+ def _to_numpy(data: Tensor) -> np.ndarray:
47
+ return data.detach().cpu().numpy() if isinstance(data, torch.Tensor) else data
48
+
49
+
50
+ def _to_torch(
51
+ data: Tensor, device, pin_memory: bool = False, non_blocking: bool = False
52
+ ) -> torch.Tensor:
53
+ if isinstance(data, np.generic):
54
+ return torch.as_tensor(data, device=device)
55
+ elif isinstance(data, np.ndarray):
56
+ data = torch.from_numpy(data)
57
+ elif not isinstance(data, Tensor):
58
+ data = torch.as_tensor(data, device=device)
59
+
60
+ if pin_memory:
61
+ data = data.pin_memory()
62
+ if device is not None:
63
+ data = data.to(device, non_blocking=non_blocking)
64
+
65
+ return data
66
+
67
+
68
+ def pin_memory_output(fun) -> Callable:
69
+ """Calls pin_memory on outputs of decorated function if they have such method."""
70
+
71
+ def decorated_fun(self, *args, **kwargs):
72
+ output = fun(self, *args, **kwargs)
73
+ if self._pin_memory:
74
+ _tuple_out = True
75
+ if not isinstance(output, tuple):
76
+ _tuple_out = False
77
+ output = (output,)
78
+ output = tuple(_pin_memory(_output) for _output in output)
79
+ if _tuple_out:
80
+ return output
81
+ return output[0]
82
+ return output
83
+
84
+ return decorated_fun
85
+
86
+
87
+ def _pin_memory(output: Any) -> Any:
88
+ if hasattr(output, "pin_memory") and output.device == torch.device("cpu"):
89
+ return output.pin_memory()
90
+ else:
91
+ return output
92
+
93
+
94
+ def _reduce(
95
+ tensor: torch.Tensor, reduction: str, dim: int | None = None
96
+ ) -> float | torch.Tensor:
97
+ """Reduces a tensor given the reduction method."""
98
+ if reduction == "max":
99
+ result = tensor.max(dim=dim)
100
+ elif reduction == "min":
101
+ result = tensor.min(dim=dim)
102
+ elif reduction == "mean":
103
+ result = tensor.mean(dim=dim)
104
+ elif reduction == "median":
105
+ result = tensor.median(dim=dim)
106
+ elif reduction == "sum":
107
+ result = tensor.sum(dim=dim)
108
+ else:
109
+ raise NotImplementedError(f"Unknown reduction method {reduction}")
110
+ if isinstance(result, tuple):
111
+ result = result[0]
112
+ return result.item() if dim is None else result
113
+
114
+
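A minimal sketch (not part of the packaged module) of how ``_reduce`` behaves: with ``dim=None`` it collapses to a Python float via ``.item()``, and with a ``dim`` it keeps the values tensor and drops any indices returned by ``max``/``min``/``median``:

    >>> _reduce(torch.tensor([1.0, 2.0, 3.0]), "sum")                   # full reduction -> Python float
    6.0
    >>> _reduce(torch.tensor([[1.0, 2.0], [3.0, 4.0]]), "max", dim=0)   # values only, indices discarded
    tensor([3., 4.])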
115
+ def _is_int(index):
116
+ if isinstance(index, INT_CLASSES):
117
+ return True
118
+ if isinstance(index, (np.ndarray, torch.Tensor)):
119
+ return index.ndim == 0
120
+ return False
121
+
122
+
123
+ class TED2Flat:
124
+ """A storage saving hook to serialize TED data in a compact format.
125
+
126
+ Args:
127
+ done_key (NestedKey, optional): the key where the done states should be read.
128
+ Defaults to ``("next", "done")``.
129
+ shift_key (NestedKey, optional): the key where the shift will be written.
130
+ Defaults to "shift".
131
+ is_full_key (NestedKey, optional): the key where the is_full attribute will be written.
132
+ Defaults to "is_full".
133
+ done_keys (Tuple[NestedKey], optional): a tuple of nested keys indicating the done entries.
134
+ Defaults to ("done", "truncated", "terminated")
135
+ reward_keys (Tuple[NestedKey], optional): a tuple of nested keys indicating the reward entries.
136
+ Defaults to ("reward",)
137
+
138
+
139
+ Examples:
140
+ >>> import tempfile
141
+ >>>
142
+ >>> from tensordict import TensorDict
143
+ >>>
144
+ >>> from torchrl.collectors import Collector
145
+ >>> from torchrl.data import ReplayBuffer, TED2Flat, LazyMemmapStorage
146
+ >>> from torchrl.envs import GymEnv
147
+ >>> import torch
148
+ >>>
149
+ >>> env = GymEnv("CartPole-v1")
150
+ >>> env.set_seed(0)
151
+ >>> torch.manual_seed(0)
152
+ >>> collector = Collector(env, policy=env.rand_step, total_frames=200, frames_per_batch=200)
153
+ >>> rb = ReplayBuffer(storage=LazyMemmapStorage(200))
154
+ >>> rb.register_save_hook(TED2Flat())
155
+ >>> with tempfile.TemporaryDirectory() as tmpdir:
156
+ ... for i, data in enumerate(collector):
157
+ ... rb.extend(data)
158
+ ... rb.dumps(tmpdir)
159
+ ... # load the data to represent it
160
+ ... td = TensorDict.load(tmpdir + "/storage/")
161
+ ... print(td)
162
+ TensorDict(
163
+ fields={
164
+ action: MemoryMappedTensor(shape=torch.Size([200, 2]), device=cpu, dtype=torch.int64, is_shared=True),
165
+ collector: TensorDict(
166
+ fields={
167
+ traj_ids: MemoryMappedTensor(shape=torch.Size([200]), device=cpu, dtype=torch.int64, is_shared=True)},
168
+ batch_size=torch.Size([]),
169
+ device=cpu,
170
+ is_shared=False),
171
+ done: MemoryMappedTensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=True),
172
+ observation: MemoryMappedTensor(shape=torch.Size([220, 4]), device=cpu, dtype=torch.float32, is_shared=True),
173
+ reward: MemoryMappedTensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.float32, is_shared=True),
174
+ terminated: MemoryMappedTensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=True),
175
+ truncated: MemoryMappedTensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=True)},
176
+ batch_size=torch.Size([]),
177
+ device=cpu,
178
+ is_shared=False)
179
+
180
+ """
181
+
182
+ _shift: int | None = None
183
+ _is_full: bool | None = None
184
+
185
+ def __init__(
186
+ self,
187
+ done_key=("next", "done"),
188
+ shift_key="shift",
189
+ is_full_key="is_full",
190
+ done_keys=("done", "truncated", "terminated"),
191
+ reward_keys=("reward",),
192
+ ):
193
+ self.done_key = done_key
194
+ self.shift_key = shift_key
195
+ self.is_full_key = is_full_key
196
+ self.done_keys = {unravel_key(key) for key in done_keys}
197
+ self.reward_keys = {unravel_key(key) for key in reward_keys}
198
+
199
+ @property
200
+ def shift(self):
201
+ return self._shift
202
+
203
+ @shift.setter
204
+ def shift(self, value: int):
205
+ self._shift = value
206
+
207
+ @property
208
+ def is_full(self):
209
+ return self._is_full
210
+
211
+ @is_full.setter
212
+ def is_full(self, value: int):
213
+ self._is_full = value
214
+
215
+ def __call__(self, data: TensorDictBase, path: Path = None):
216
+ # Get the done state
217
+ shift = self.shift
218
+ is_full = self.is_full
219
+
220
+ # Create an output storage
221
+ output = TensorDict()
222
+ output.set_non_tensor(self.is_full_key, is_full)
223
+ output.set_non_tensor(self.shift_key, shift)
224
+ output.set_non_tensor("_storage_shape", tuple(data.shape))
225
+ output.memmap_(path)
226
+
227
+ # Preallocate the output
228
+ done = data.get(self.done_key).squeeze(-1).clone()
229
+ if not is_full:
230
+ # shift is the cursor place
231
+ done[shift - 1] = True
232
+ else:
233
+ done = done.roll(-shift, dims=0)
234
+ done[-1] = True
235
+ ntraj = done.sum()
236
+
237
+ # Get the keys that require extra storage
238
+ keys_to_expand = set(data.get("next").keys(True, True)) - (
239
+ self.done_keys.union(self.reward_keys)
240
+ )
241
+
242
+ total_keys = data.exclude("next").keys(True, True)
243
+ total_keys = set(total_keys).union(set(data.get("next").keys(True, True)))
244
+
245
+ len_with_offset = data.numel() + ntraj # + done[0].numel()
246
+ for key in total_keys:
247
+ if key in (self.done_keys.union(self.reward_keys)):
248
+ entry = data.get(("next", key))
249
+ else:
250
+ entry = data.get(key)
251
+
252
+ if key in keys_to_expand:
253
+ shape = torch.Size([len_with_offset, *entry.shape[data.ndim :]])
254
+ dtype = entry.dtype
255
+ output.make_memmap(key, shape=shape, dtype=dtype)
256
+ else:
257
+ shape = torch.Size([data.numel(), *entry.shape[data.ndim :]])
258
+ output.make_memmap(key, shape=shape, dtype=entry.dtype)
259
+
260
+ if data.ndim == 1:
261
+ return self._call(
262
+ data=data,
263
+ output=output,
264
+ is_full=is_full,
265
+ shift=shift,
266
+ done=done,
267
+ total_keys=total_keys,
268
+ keys_to_expand=keys_to_expand,
269
+ )
270
+
271
+ with data.flatten(1, -1) if data.ndim > 2 else contextlib.nullcontext(
272
+ data
273
+ ) as data_flat:
274
+ if data.ndim > 2:
275
+ done = done.flatten(1, -1)
276
+ traj_per_dim = done.sum(0)
277
+ nsteps = data_flat.shape[0]
278
+
279
+ start = 0
280
+ start_with_offset = start
281
+ stop_with_offset = 0
282
+ stop = 0
283
+ for data_slice, done_slice, traj_for_dim in zip(
284
+ data_flat.unbind(1), done.unbind(1), traj_per_dim
285
+ ):
286
+ stop_with_offset = stop_with_offset + nsteps + traj_for_dim
287
+ cur_slice_offset = slice(start_with_offset, stop_with_offset)
288
+ start_with_offset = stop_with_offset
289
+
290
+ stop = stop + data.shape[0]
291
+ cur_slice = slice(start, stop)
292
+ start = stop
293
+
294
+ def _index(
295
+ key,
296
+ val,
297
+ keys_to_expand=keys_to_expand,
298
+ cur_slice=cur_slice,
299
+ cur_slice_offset=cur_slice_offset,
300
+ ):
301
+ if key in keys_to_expand:
302
+ return val[cur_slice_offset]
303
+ return val[cur_slice]
304
+
305
+ out_slice = output.named_apply(_index, nested_keys=True)
306
+ self._call(
307
+ data=data_slice,
308
+ output=out_slice,
309
+ is_full=is_full,
310
+ shift=shift,
311
+ done=done_slice,
312
+ total_keys=total_keys,
313
+ keys_to_expand=keys_to_expand,
314
+ )
315
+ return output
316
+
317
+ def _call(self, *, data, output, is_full, shift, done, total_keys, keys_to_expand):
318
+ # capture for each item in data where the observation should be written
319
+ idx = torch.arange(data.shape[0])
320
+ idx_done = (idx + done.cumsum(0))[done]
321
+ idx += torch.nn.functional.pad(done, [1, 0])[:-1].cumsum(0)
322
+
323
+ for key in total_keys:
324
+ if key in (self.done_keys.union(self.reward_keys)):
325
+ entry = data.get(("next", key))
326
+ else:
327
+ entry = data.get(key)
328
+
329
+ if key in keys_to_expand:
330
+ mmap = output.get(key)
331
+ shifted_next = data.get(("next", key))
332
+ if is_full:
333
+ _roll_inplace(entry, shift=-shift, out=mmap, index_dest=idx)
334
+ _roll_inplace(
335
+ shifted_next,
336
+ shift=-shift,
337
+ out=mmap,
338
+ index_dest=idx_done,
339
+ index_source=done,
340
+ )
341
+ else:
342
+ mmap[idx] = entry
343
+ mmap[idx_done] = shifted_next[done]
344
+ elif is_full:
345
+ mmap = output.get(key)
346
+ _roll_inplace(entry, shift=-shift, out=mmap)
347
+ else:
348
+ mmap = output.get(key)
349
+ mmap.copy_(entry)
350
+ return output
351
+
352
+
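Why the example above stores 220 observation rows but only 200 reward rows: keys found under ``("next", ...)`` that are not done/reward entries get one extra row per completed trajectory (the terminal observation), i.e. ``data.numel() + ntraj`` rows, while per-step keys keep ``data.numel()`` rows. A back-of-the-envelope sketch, assuming the 200-frame batch above contains 20 completed trajectories:

    >>> nsteps, ntraj = 200, 20
    >>> nsteps + ntraj   # rows for expanded keys such as "observation"
    220
    >>> nsteps           # rows for per-step keys such as "action" or "reward"
    200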
353
+ class Flat2TED:
354
+ """A storage loading hook to deserialize flattened TED data to TED format.
355
+
356
+ Args:
357
+ done_key (NestedKey, optional): the key where the done states should be read.
358
+ Defaults to ``("next", "done")``.
359
+ shift_key (NestedKey, optional): the key where the shift will be written.
360
+ Defaults to "shift".
361
+ is_full_key (NestedKey, optional): the key where the is_full attribute will be written.
362
+ Defaults to "is_full".
363
+ done_keys (Tuple[NestedKey], optional): a tuple of nested keys indicating the done entries.
364
+ Defaults to ("done", "truncated", "terminated")
365
+ reward_keys (Tuple[NestedKey], optional): a tuple of nested keys indicating the reward entries.
366
+ Defaults to ("reward",)
367
+
368
+ Examples:
369
+ >>> import tempfile
370
+ >>>
371
+ >>> from tensordict import TensorDict
372
+ >>>
373
+ >>> from torchrl.collectors import Collector
374
+ >>> from torchrl.data import ReplayBuffer, TED2Flat, LazyMemmapStorage, Flat2TED
375
+ >>> from torchrl.envs import GymEnv
376
+ >>> import torch
377
+ >>>
378
+ >>> env = GymEnv("CartPole-v1")
379
+ >>> env.set_seed(0)
380
+ >>> torch.manual_seed(0)
381
+ >>> collector = Collector(env, policy=env.rand_step, total_frames=200, frames_per_batch=200)
382
+ >>> rb = ReplayBuffer(storage=LazyMemmapStorage(200))
383
+ >>> rb.register_save_hook(TED2Flat())
384
+ >>> with tempfile.TemporaryDirectory() as tmpdir:
385
+ ... for i, data in enumerate(collector):
386
+ ... rb.extend(data)
387
+ ... rb.dumps(tmpdir)
388
+ ... # load the data to represent it
389
+ ... td = TensorDict.load(tmpdir + "/storage/")
390
+ ...
391
+ ... rb_load = ReplayBuffer(storage=LazyMemmapStorage(200))
392
+ ... rb_load.register_load_hook(Flat2TED())
393
+ ... rb_load.load(tmpdir)
394
+ ... print("storage after loading", rb_load[:])
395
+ ... assert (rb[:] == rb_load[:]).all()
396
+ storage after loading TensorDict(
397
+ fields={
398
+ action: MemoryMappedTensor(shape=torch.Size([200, 2]), device=cpu, dtype=torch.int64, is_shared=False),
399
+ collector: TensorDict(
400
+ fields={
401
+ traj_ids: MemoryMappedTensor(shape=torch.Size([200]), device=cpu, dtype=torch.int64, is_shared=False)},
402
+ batch_size=torch.Size([200]),
403
+ device=cpu,
404
+ is_shared=False),
405
+ done: MemoryMappedTensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False),
406
+ next: TensorDict(
407
+ fields={
408
+ done: MemoryMappedTensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False),
409
+ observation: MemoryMappedTensor(shape=torch.Size([200, 4]), device=cpu, dtype=torch.float32, is_shared=False),
410
+ reward: MemoryMappedTensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.float32, is_shared=False),
411
+ terminated: MemoryMappedTensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False),
412
+ truncated: MemoryMappedTensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
413
+ batch_size=torch.Size([200]),
414
+ device=cpu,
415
+ is_shared=False),
416
+ observation: MemoryMappedTensor(shape=torch.Size([200, 4]), device=cpu, dtype=torch.float32, is_shared=False),
417
+ terminated: MemoryMappedTensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False),
418
+ truncated: MemoryMappedTensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
419
+ batch_size=torch.Size([200]),
420
+ device=cpu,
421
+ is_shared=False)
422
+
423
+
424
+ """
425
+
426
+ def __init__(
427
+ self,
428
+ done_key="done",
429
+ shift_key="shift",
430
+ is_full_key="is_full",
431
+ done_keys=("done", "truncated", "terminated"),
432
+ reward_keys=("reward",),
433
+ ):
434
+ self.done_key = done_key
435
+ self.shift_key = shift_key
436
+ self.is_full_key = is_full_key
437
+ self.done_keys = {unravel_key(key) for key in done_keys}
438
+ self.reward_keys = {unravel_key(key) for key in reward_keys}
439
+
440
+ def __call__(self, data: TensorDictBase, out: TensorDictBase = None):
441
+ _storage_shape = data.get_non_tensor("_storage_shape", default=None)
442
+ if isinstance(_storage_shape, int):
443
+ _storage_shape = torch.Size([_storage_shape])
444
+ shift = data.get_non_tensor(self.shift_key, default=None)
445
+ is_full = data.get_non_tensor(self.is_full_key, default=None)
446
+ done = (
447
+ data.get("done")
448
+ .reshape((*_storage_shape[1:], -1))
449
+ .contiguous()
450
+ .permute(-1, *range(0, len(_storage_shape) - 1))
451
+ .clone()
452
+ )
453
+ if not is_full:
454
+ # shift is the cursor place
455
+ done[shift - 1] = True
456
+ else:
457
+ # done = done.roll(-shift, dims=0)
458
+ done[-1] = True
459
+
460
+ if _storage_shape is not None and len(_storage_shape) > 1:
461
+ # iterate over data and allocate
462
+ if out is None:
463
+ # out = TensorDict(batch_size=_storage_shape)
464
+ # for i in range(out.ndim):
465
+ # if i >= 2:
466
+ # # Flattening the lazy stack will make the data unavailable - we need to find a way to make this
467
+ # # possible.
468
+ # raise RuntimeError(
469
+ # "Checkpointing an uninitialized buffer with more than 2 dimensions is currently not supported. "
470
+ # "Please file an issue on GitHub to ask for this feature!"
471
+ # )
472
+ # out = LazyStackedTensorDict(*out.unbind(i), stack_dim=i)
473
+ out = TensorDict(batch_size=_storage_shape)
474
+ for i in range(1, out.ndim):
475
+ if i >= 2:
476
+ # Flattening the lazy stack will make the data unavailable - we need to find a way to make this
477
+ # possible.
478
+ raise RuntimeError(
479
+ "Checkpointing an uninitialized buffer with more than 2 dimensions is currently not supported. "
480
+ "Please file an issue on GitHub to ask for this feature!"
481
+ )
482
+ out_list = [
483
+ out._get_sub_tensordict((slice(None),) * i + (j,))
484
+ for j in range(out.shape[i])
485
+ ]
486
+ out = lazy_stack(out_list, i)
487
+
488
+ # Create a function that reads slices of the input data
489
+ with out.flatten(1, -1) if out.ndim > 2 else contextlib.nullcontext(
490
+ out
491
+ ) as out_flat:
492
+ nsteps = done.shape[0]
493
+ n_elt_batch = done.shape[1:].numel()
494
+ traj_per_dim = done.sum(0)
495
+
496
+ start = 0
497
+ start_with_offset = start
498
+ stop_with_offset = 0
499
+ stop = 0
500
+
501
+ for out_unbound, traj_for_dim in zip(out_flat.unbind(-1), traj_per_dim):
502
+ stop_with_offset = stop_with_offset + nsteps + traj_for_dim
503
+ cur_slice_offset = slice(start_with_offset, stop_with_offset)
504
+ start_with_offset = stop_with_offset
505
+
506
+ stop = stop + nsteps
507
+ cur_slice = slice(start, stop)
508
+ start = stop
509
+
510
+ def _index(
511
+ key,
512
+ val,
513
+ cur_slice=cur_slice,
514
+ nsteps=nsteps,
515
+ n_elt_batch=n_elt_batch,
516
+ cur_slice_offset=cur_slice_offset,
517
+ ):
518
+ if val.shape[0] != (nsteps * n_elt_batch):
519
+ return val[cur_slice_offset]
520
+ return val[cur_slice]
521
+
522
+ data_slice = data.named_apply(
523
+ _index, nested_keys=True, batch_size=[]
524
+ )
525
+ self._call(
526
+ data=data_slice,
527
+ out=out_unbound,
528
+ is_full=is_full,
529
+ shift=shift,
530
+ _storage_shape=_storage_shape,
531
+ )
532
+ return out
533
+ return self._call(
534
+ data=data,
535
+ out=out,
536
+ is_full=is_full,
537
+ shift=shift,
538
+ _storage_shape=_storage_shape,
539
+ )
540
+
541
+ def _call(self, *, data, out, _storage_shape, shift, is_full):
542
+ done = data.get(self.done_key)
543
+ done = done.clone()
544
+
545
+ nsteps = done.shape[0]
546
+
547
+ # capture for each item in data where the observation should be written
548
+ idx = torch.arange(done.shape[0])
549
+ padded_done = F.pad(done.squeeze(-1), [1, 0])
550
+ root_idx = idx + padded_done[:-1].cumsum(0)
551
+ next_idx = root_idx + 1
552
+
553
+ if out is None:
554
+ out = TensorDict(batch_size=[nsteps])
555
+
556
+ def maybe_roll(entry, out=None):
557
+ if is_full and shift is not None:
558
+ if out is not None:
559
+ _roll_inplace(entry, shift=shift, out=out)
560
+ return
561
+ else:
562
+ return entry.roll(shift, dims=0)
563
+ if out is not None:
564
+ out.copy_(entry)
565
+ return
566
+ return entry
567
+
568
+ root_idx = maybe_roll(root_idx)
569
+ next_idx = maybe_roll(next_idx)
570
+ if not is_full:
571
+ next_idx = next_idx[:-1]
572
+
573
+ for key, entry in data.items(True, True):
574
+ if entry.shape[0] == nsteps:
575
+ if key in (self.done_keys.union(self.reward_keys)):
576
+ if key != "reward" and key not in out.keys(True, True):
577
+ # Create a done state at the root full of 0s
578
+ out.set(key, torch.zeros_like(entry), inplace=True)
579
+ entry = maybe_roll(entry, out=out.get(("next", key), None))
580
+ if entry is not None:
581
+ out.set(("next", key), entry, inplace=True)
582
+ else:
583
+ # action and similar
584
+ entry = maybe_roll(entry, out=out.get(key, default=None))
585
+ if entry is not None:
586
+ # then out is not locked
587
+ out.set(key, entry, inplace=True)
588
+ else:
589
+ dest_next = out.get(("next", key), None)
590
+ if dest_next is not None:
591
+ if not is_full:
592
+ dest_next = dest_next[:-1]
593
+ dest_next.copy_(entry[next_idx])
594
+ else:
595
+ if not is_full:
596
+ val = entry[next_idx]
597
+ val = torch.cat([val, torch.zeros_like(val[:1])])
598
+ out.set(("next", key), val, inplace=True)
599
+ else:
600
+ out.set(("next", key), entry[next_idx], inplace=True)
601
+
602
+ dest = out.get(key, None)
603
+ if dest is not None:
604
+ dest.copy_(entry[root_idx])
605
+ else:
606
+ out.set(key, entry[root_idx], inplace=True)
607
+ return out
608
+
609
+
610
+ class TED2Nested(TED2Flat):
611
+ """Converts a TED-formatted dataset into a tensordict populated with nested tensors where each row is a trajectory."""
612
+
613
+ _shift: int | None = None
614
+ _is_full: bool | None = None
615
+
616
+ def __init__(self, *args, **kwargs):
617
+ if not hasattr(torch, "_nested_compute_contiguous_strides_offsets"):
618
+ raise ValueError(
619
+ f"Unsupported torch version {torch.__version__}. "
620
+ f"torch>=2.4 is required for {type(self).__name__} to be used."
621
+ )
622
+ return super().__init__(*args, **kwargs)
623
+
624
+ def __call__(self, data: TensorDictBase, path: Path = None):
625
+ data = super().__call__(data, path=path)
626
+
627
+ shift = self.shift
628
+ is_full = self.is_full
629
+ storage_shape = data.get_non_tensor("_storage_shape", (-1,))
630
+ # place time at the end
631
+ storage_shape = (*storage_shape[1:], storage_shape[0])
632
+
633
+ done = data.get("done")
634
+ done = done.squeeze(-1).clone()
635
+ if not is_full:
636
+ done.view(storage_shape)[..., shift - 1] = True
637
+ # else:
638
+ done.view(storage_shape)[..., -1] = True
639
+
640
+ ntraj = done.sum()
641
+
642
+ nz = done.nonzero(as_tuple=True)[0]
643
+ traj_lengths = torch.cat([nz[:1] + 1, nz.diff()])
644
+ # if not is_full:
645
+ # traj_lengths = torch.cat(
646
+ # [traj_lengths, (done.shape[0] - traj_lengths.sum()).unsqueeze(0)]
647
+ # )
648
+
649
+ keys_to_expand, keys_to_keep = zip(
650
+ *[
651
+ (key, None) if val.shape[0] != done.shape[0] else (None, key)
652
+ for key, val in data.items(True, True)
653
+ ]
654
+ )
655
+ keys_to_expand = [key for key in keys_to_expand if key is not None]
656
+ keys_to_keep = [key for key in keys_to_keep if key is not None]
657
+
658
+ out = TensorDict(batch_size=[ntraj])
659
+ out.update(dict(data.non_tensor_items()))
660
+
661
+ out.memmap_(path)
662
+
663
+ traj_lengths = traj_lengths.unsqueeze(-1)
664
+ if not is_full:
665
+ # Increment by one only the trajectories that are not terminal
666
+ traj_lengths_expand = traj_lengths + (
667
+ traj_lengths.cumsum(0) % storage_shape[-1] != 0
668
+ )
669
+ else:
670
+ traj_lengths_expand = traj_lengths + 1
671
+ for key in keys_to_expand:
672
+ val = data.get(key)
673
+ shape = torch.cat(
674
+ [
675
+ traj_lengths_expand,
676
+ torch.tensor(val.shape[1:], dtype=torch.long).repeat(
677
+ traj_lengths.numel(), 1
678
+ ),
679
+ ],
680
+ -1,
681
+ )
682
+ # This works because the storage location is the same as the previous one - no copy is done
683
+ # but a new shape is written
684
+ out.make_memmap_from_storage(
685
+ key, val.untyped_storage(), dtype=val.dtype, shape=shape
686
+ )
687
+ for key in keys_to_keep:
688
+ val = data.get(key)
689
+ shape = torch.cat(
690
+ [
691
+ traj_lengths,
692
+ torch.tensor(val.shape[1:], dtype=torch.long).repeat(
693
+ traj_lengths.numel(), 1
694
+ ),
695
+ ],
696
+ -1,
697
+ )
698
+ out.make_memmap_from_storage(
699
+ key, val.untyped_storage(), dtype=val.dtype, shape=shape
700
+ )
701
+ return out
702
+
703
+
704
+ class Nested2TED(Flat2TED):
705
+ """Converts a nested tensordict where each row is a trajectory into the TED format."""
706
+
707
+ def __call__(self, data, out: TensorDictBase = None):
708
+ # Get a flat representation of data
709
+ def flatten_het_dim(tensor):
710
+ shape = [tensor.size(i) for i in range(2, tensor.ndim)]
711
+ tensor = torch.tensor(tensor.untyped_storage(), dtype=tensor.dtype).view(
712
+ -1, *shape
713
+ )
714
+ return tensor
715
+
716
+ data = data.apply(flatten_het_dim, batch_size=[])
717
+ data.auto_batch_size_()
718
+ return super().__call__(data, out=out)
719
+
720
+
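A minimal usage sketch (not part of the packaged module) for the nested-tensor hooks, assuming the same CartPole collection as in the TED2Flat example above; ``TED2Nested`` requires torch>=2.4:

    >>> rb = ReplayBuffer(storage=LazyMemmapStorage(200))
    >>> rb.register_save_hook(TED2Nested())        # on dumps: one row per trajectory, stored as nested tensors
    >>> rb_load = ReplayBuffer(storage=LazyMemmapStorage(200))
    >>> rb_load.register_load_hook(Nested2TED())   # on load: restores the flat TED layout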
721
+ class H5Split(TED2Flat):
722
+ """Splits a dataset prepared with TED2Nested into a TensorDict where each trajectory is stored as views on their parent nested tensors."""
723
+
724
+ _shift: int | None = None
725
+ _is_full: bool | None = None
726
+
727
+ def __call__(self, data):
728
+ nzeros = int(math.ceil(math.log10(data.shape[0])))
729
+
730
+ result = TensorDict(
731
+ {
732
+ f"traj_{str(i).zfill(nzeros)}": _data
733
+ for i, _data in enumerate(data.filter_non_tensor_data().unbind(0))
734
+ }
735
+ ).update(dict(data.non_tensor_items()))
736
+
737
+ return result
738
+
739
+
740
+ class H5Combine:
741
+ """Combines trajectories in a persistent tensordict into a single standing tensordict stored in filesystem."""
742
+
743
+ def __call__(self, data, out=None):
744
+ # TODO: this loads the entire H5 in memory, which can be problematic
745
+ # Ideally we would want to load it on a memmap tensordict
746
+ # We currently ignore out in this call but we should leverage that
747
+ values = [val for key, val in data.items() if key.startswith("traj")]
748
+ metadata_keys = [key for key in data.keys() if not key.startswith("traj")]
749
+ result = TensorDict({key: NonTensorData(data[key]) for key in metadata_keys})
750
+
751
+ # Create a memmap in file system (no files associated)
752
+ result.memmap_()
753
+
754
+ # Create each entry
755
+ def initialize(key, *x):
756
+ result.make_memmap(
757
+ key,
758
+ shape=torch.stack([torch.tensor(_x.shape) for _x in x]),
759
+ dtype=x[0].dtype,
760
+ )
761
+ return
762
+
763
+ values[0].named_apply(
764
+ initialize,
765
+ *values[1:],
766
+ nested_keys=True,
767
+ batch_size=[],
768
+ filter_empty=True,
769
+ )
770
+
771
+ # Populate the entries
772
+ def populate(key, *x):
773
+ dest = result.get(key)
774
+ for i, _x in enumerate(x):
775
+ dest[i].copy_(_x)
776
+
777
+ values[0].named_apply(
778
+ populate,
779
+ *values[1:],
780
+ nested_keys=True,
781
+ batch_size=[],
782
+ filter_empty=True,
783
+ )
784
+ return result
785
+
786
+
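A rough sketch (not part of the packaged module) of how the two H5 hooks pair up, assuming ``nested`` is a tensordict produced by ``TED2Nested`` with one row per trajectory:

    >>> split = H5Split()(nested)       # keys like "traj_00", "traj_01", ... plus the non-tensor metadata
    >>> combined = H5Combine()(split)   # back to a single memory-mapped tensordict holding all trajectories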
787
+ @implement_for("torch", "2.3", None)
788
+ def _path2str(path, default_name=None):
789
+ # Uses the Keys defined in pytree to build a path
790
+ from torch.utils._pytree import MappingKey, SequenceKey
791
+
792
+ if default_name is None:
793
+ default_name = SINGLE_TENSOR_BUFFER_NAME
794
+ if not path:
795
+ return default_name
796
+ if isinstance(path, tuple):
797
+ return "/".join([_path2str(_sub, default_name=default_name) for _sub in path])
798
+ if isinstance(path, MappingKey):
799
+ if not isinstance(path.key, (int, str, bytes)):
800
+ raise ValueError("Values must be of type int, str or bytes in PyTree maps.")
801
+ result = str(path.key)
802
+ if result == default_name:
803
+ raise RuntimeError(
804
+ "A tensor had the same identifier as the default name used when the buffer contains "
805
+ f"a single tensor (name={default_name}). This behavior is not allowed. Please rename your "
806
+ f"tensor in the map/dict or set a new default name with the environment variable SINGLE_TENSOR_BUFFER_NAME."
807
+ )
808
+ return result
809
+ if isinstance(path, SequenceKey):
810
+ return str(path.idx)
811
+
812
+
813
+ @implement_for("torch", None, "2.3")
814
+ def _path2str(path, default_name=None): # noqa: F811
815
+ raise RuntimeError
816
+
817
+
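For reference, a small sketch of the strings ``_path2str`` produces on torch>=2.3, assuming the key-path classes from ``torch.utils._pytree`` and the default single-tensor name:

    >>> from torch.utils._pytree import MappingKey, SequenceKey
    >>> _path2str((MappingKey("obs"), SequenceKey(0)))
    'obs/0'
    >>> _path2str(())   # empty path falls back to SINGLE_TENSOR_BUFFER_NAME
    '_-single-tensor-_'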
818
+ def _save_pytree_common(tensor_path, path, tensor, metadata):
819
+ if "." in tensor_path:
820
+ tensor_path = tensor_path.replace(".", "_<dot>_")  # str.replace returns a new string; the result must be reassigned
821
+ total_tensor_path = path / (tensor_path + ".memmap")
822
+ if os.path.exists(total_tensor_path):
823
+ MemoryMappedTensor.from_filename(
824
+ shape=tensor.shape,
825
+ filename=total_tensor_path,
826
+ dtype=tensor.dtype,
827
+ ).copy_(tensor)
828
+ else:
829
+ os.makedirs(total_tensor_path.parent, exist_ok=True)
830
+ MemoryMappedTensor.from_tensor(
831
+ tensor,
832
+ filename=total_tensor_path,
833
+ copy_existing=True,
834
+ copy_data=True,
835
+ )
836
+ key = tensor_path.replace("/", ".")
837
+ if key in metadata:
838
+ raise KeyError(
839
+ "At least two values have conflicting representations in "
840
+ f"the data structure to be serialized: {key}."
841
+ )
842
+ metadata[key] = {
843
+ "dtype": str(tensor.dtype),
844
+ "shape": list(tensor.shape),
845
+ }
846
+
847
+
848
+ @implement_for("torch", "2.3", None)
849
+ def _save_pytree(_storage, metadata, path):
850
+ from torch.utils._pytree import tree_map_with_path
851
+
852
+ def save_tensor(
853
+ tensor_path: tuple, tensor: torch.Tensor, metadata=metadata, path=path
854
+ ):
855
+ tensor_path = _path2str(tensor_path)
856
+ _save_pytree_common(tensor_path, path, tensor, metadata)
857
+
858
+ tree_map_with_path(save_tensor, _storage)
859
+
860
+
861
+ @implement_for("torch", None, "2.3")
862
+ def _save_pytree(_storage, metadata, path): # noqa: F811
863
+
864
+ flat_storage, storage_specs = tree_flatten(_storage)
865
+ storage_paths = _get_paths(storage_specs)
866
+
867
+ def save_tensor(
868
+ tensor_path: str, tensor: torch.Tensor, metadata=metadata, path=path
869
+ ):
870
+ _save_pytree_common(tensor_path, path, tensor, metadata)
871
+
872
+ for tensor, tensor_path in zip(flat_storage, storage_paths):
873
+ save_tensor(tensor_path, tensor)
874
+
875
+
876
+ def _get_paths(spec, cumulpath=""):
877
+ # alternative way to build a path without the keys
878
+ if isinstance(spec, LeafSpec):
879
+ yield cumulpath if cumulpath else SINGLE_TENSOR_BUFFER_NAME
880
+
881
+ contexts = spec.context
882
+ children_specs = spec.children_specs
883
+ if contexts is None:
884
+ contexts = range(len(children_specs))
885
+
886
+ for context, spec in zip(contexts, children_specs):
887
+ cpath = "/".join((cumulpath, str(context))) if cumulpath else str(context)
888
+ yield from _get_paths(spec, cpath)
889
+
890
+
891
+ def _init_pytree_common(tensor_path, scratch_dir, max_size_fn, tensor):
892
+ if "." in tensor_path:
893
+ tensor_path = tensor_path.replace(".", "_<dot>_")  # str.replace returns a new string; the result must be reassigned
894
+ if scratch_dir is not None:
895
+ total_tensor_path = Path(scratch_dir) / (tensor_path + ".memmap")
896
+ if os.path.exists(total_tensor_path):
897
+ raise RuntimeError(
898
+ f"The storage of tensor {total_tensor_path} already exists. "
899
+ f"To load an existing replay buffer, use storage.loads. "
900
+ f"Choose a different path to store your buffer or delete the existing files."
901
+ )
902
+ os.makedirs(total_tensor_path.parent, exist_ok=True)
903
+ else:
904
+ total_tensor_path = None
905
+ out = MemoryMappedTensor.empty(
906
+ shape=max_size_fn(tensor.shape),
907
+ filename=total_tensor_path,
908
+ dtype=tensor.dtype,
909
+ )
910
+ try:
911
+ filesize = os.path.getsize(tensor.filename) / 1024 / 1024
912
+ torchrl_logger.debug(
913
+ f"The storage was created in {out.filename} and occupies {filesize} Mb of storage."
914
+ )
915
+ except (RuntimeError, AttributeError):
916
+ pass
917
+ return out
918
+
919
+
920
+ @implement_for("torch", "2.3", None)
921
+ def _init_pytree(scratch_dir, max_size_fn, data):
922
+ from torch.utils._pytree import tree_map_with_path
923
+
924
+ # If not a tensorclass/tensordict, it must be a tensor(-like) or a PyTree
925
+ # if Tensor, we just create a MemoryMappedTensor of the desired shape, device and dtype
926
+ def save_tensor(tensor_path: tuple, tensor: torch.Tensor):
927
+ tensor_path = _path2str(tensor_path)
928
+ return _init_pytree_common(tensor_path, scratch_dir, max_size_fn, tensor)
929
+
930
+ out = tree_map_with_path(save_tensor, data)
931
+ return out
932
+
933
+
934
+ @implement_for("torch", None, "2.3")
935
+ def _init_pytree(scratch_dir, max_size, data): # noqa: F811
936
+
937
+ flat_data, data_specs = tree_flatten(data)
938
+ data_paths = _get_paths(data_specs)
939
+ data_paths = list(data_paths)
940
+
941
+ # If not a tensorclass/tensordict, it must be a tensor(-like) or a PyTree
942
+ # if Tensor, we just create a MemoryMappedTensor of the desired shape, device and dtype
943
+ def save_tensor(tensor_path: str, tensor: torch.Tensor):
944
+ return _init_pytree_common(tensor_path, scratch_dir, max_size, tensor)
945
+
946
+ out = []
947
+ for tensor, tensor_path in zip(flat_data, data_paths):
948
+ out.append(save_tensor(tensor_path, tensor))
949
+
950
+ return tree_unflatten(out, data_specs)
951
+
952
+
953
+ def _roll_inplace(tensor, shift, out, index_dest=None, index_source=None):
954
+ # slice 0
955
+ source0 = tensor[:-shift]
956
+ if index_source is not None:
957
+ source0 = source0[index_source[shift:]]
958
+
959
+ slice0_shift = source0.shape[0]
960
+ if index_dest is not None:
961
+ out[index_dest[-slice0_shift:]] = source0
962
+ else:
963
+ slice0 = out[-slice0_shift:]
964
+ slice0.copy_(source0)
965
+
966
+ # slice 1
967
+ source1 = tensor[-shift:]
968
+ if index_source is not None:
969
+ source1 = source1[index_source[:shift]]
970
+ if index_dest is not None:
971
+ out[index_dest[:-slice0_shift]] = source1
972
+ else:
973
+ slice1 = out[:-slice0_shift]
974
+ slice1.copy_(source1)
975
+ return out
976
+
977
+
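A quick equivalence check for ``_roll_inplace`` in the simple case without index arguments (a sketch, not part of the packaged module):

    >>> t = torch.arange(6)
    >>> out = torch.empty_like(t)
    >>> _roll_inplace(t, shift=-2, out=out)
    tensor([2, 3, 4, 5, 0, 1])
    >>> torch.equal(out, t.roll(-2, dims=0))
    True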
978
+ # Copy-paste of unravel-index for PT 2.0
979
+ def _unravel_index(
980
+ indices: Tensor, shape: int | typing.Sequence[int] | torch.Size
981
+ ) -> tuple[Tensor, ...]:
982
+ res_tensor = _unravel_index_impl(indices, shape)
983
+ return res_tensor.unbind(-1)
984
+
985
+
986
+ def _unravel_index_impl(indices: Tensor, shape: int | typing.Sequence[int]) -> Tensor:
987
+ if isinstance(shape, (int, torch.SymInt)):
988
+ shape = torch.Size([shape])
989
+ else:
990
+ shape = torch.Size(shape)
991
+
992
+ coefs = list(
993
+ reversed(
994
+ list(
995
+ itertools.accumulate(
996
+ reversed(shape[1:] + torch.Size([1])), func=operator.mul
997
+ )
998
+ )
999
+ )
1000
+ )
1001
+ return indices.unsqueeze(-1).floor_divide(
1002
+ torch.tensor(coefs, device=indices.device, dtype=torch.int64)
1003
+ ) % torch.tensor(shape, device=indices.device, dtype=torch.int64)
1004
+
1005
+
1006
+ @implement_for("torch", None, "2.2")
1007
+ def unravel_index(indices, shape):
1008
+ """A version-compatible wrapper around torch.unravel_index."""
1009
+ return _unravel_index(indices, shape)
1010
+
1011
+
1012
+ @implement_for("torch", "2.2")
1013
+ def unravel_index(indices, shape): # noqa: F811
1014
+ """A version-compatible wrapper around torch.unravel_index."""
1015
+ return torch.unravel_index(indices, shape)
1016
+
1017
+
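A tiny sanity check for the wrapper; both code paths return a tuple of coordinate tensors:

    >>> unravel_index(torch.tensor([0, 5, 7]), (2, 4))
    (tensor([0, 1, 1]), tensor([0, 1, 3]))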
1018
+ @implement_for("torch", None, "2.3")
1019
+ def tree_iter(pytree):
1020
+ """A version-compatible wrapper around tree_iter."""
1021
+ flat_tree, _ = torch.utils._pytree.tree_flatten(pytree)
1022
+ yield from flat_tree
1023
+
1024
+
1025
+ @implement_for("torch", "2.3", "2.4")
1026
+ def tree_iter(pytree): # noqa: F811
1027
+ """A version-compatible wrapper around tree_iter."""
1028
+ yield from torch.utils._pytree.tree_leaves(pytree)
1029
+
1030
+
1031
+ @implement_for("torch", "2.4")
1032
+ def tree_iter(pytree): # noqa: F811
1033
+ """A version-compatible wrapper around tree_iter."""
1034
+ yield from torch.utils._pytree.tree_iter(pytree)
1035
+
1036
+
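Regardless of the torch version, ``tree_iter`` yields the leaves of an arbitrary pytree (a sketch, not part of the packaged module):

    >>> list(tree_iter({"a": 1, "b": (2, 3)}))
    [1, 2, 3]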
1037
+ def _auto_device() -> torch.device:
1038
+ if torch.cuda.is_available():
1039
+ return torch.device("cuda:0")
1040
+ elif torch.mps.is_available():
1041
+ return torch.device("mps:0")
1042
+ return torch.device("cpu")