torchrl 0.11.0__cp314-cp314-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (394) hide show
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/__init__.py +144 -0
  97. torchrl/_extension.py +74 -0
  98. torchrl/_torchrl.cp314-win_amd64.pyd +0 -0
  99. torchrl/_utils.py +1431 -0
  100. torchrl/collectors/__init__.py +48 -0
  101. torchrl/collectors/_base.py +1058 -0
  102. torchrl/collectors/_constants.py +88 -0
  103. torchrl/collectors/_multi_async.py +324 -0
  104. torchrl/collectors/_multi_base.py +1805 -0
  105. torchrl/collectors/_multi_sync.py +464 -0
  106. torchrl/collectors/_runner.py +581 -0
  107. torchrl/collectors/_single.py +2009 -0
  108. torchrl/collectors/_single_async.py +259 -0
  109. torchrl/collectors/collectors.py +62 -0
  110. torchrl/collectors/distributed/__init__.py +32 -0
  111. torchrl/collectors/distributed/default_configs.py +133 -0
  112. torchrl/collectors/distributed/generic.py +1306 -0
  113. torchrl/collectors/distributed/ray.py +1092 -0
  114. torchrl/collectors/distributed/rpc.py +1006 -0
  115. torchrl/collectors/distributed/sync.py +731 -0
  116. torchrl/collectors/distributed/utils.py +160 -0
  117. torchrl/collectors/llm/__init__.py +10 -0
  118. torchrl/collectors/llm/base.py +494 -0
  119. torchrl/collectors/llm/ray_collector.py +275 -0
  120. torchrl/collectors/llm/utils.py +36 -0
  121. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  122. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  123. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  124. torchrl/collectors/utils.py +433 -0
  125. torchrl/collectors/weight_update.py +591 -0
  126. torchrl/csrc/numpy_utils.h +38 -0
  127. torchrl/csrc/pybind.cpp +27 -0
  128. torchrl/csrc/segment_tree.h +458 -0
  129. torchrl/csrc/torch_utils.h +34 -0
  130. torchrl/csrc/utils.cpp +48 -0
  131. torchrl/csrc/utils.h +31 -0
  132. torchrl/data/__init__.py +187 -0
  133. torchrl/data/datasets/__init__.py +58 -0
  134. torchrl/data/datasets/atari_dqn.py +878 -0
  135. torchrl/data/datasets/common.py +281 -0
  136. torchrl/data/datasets/d4rl.py +489 -0
  137. torchrl/data/datasets/d4rl_infos.py +187 -0
  138. torchrl/data/datasets/gen_dgrl.py +375 -0
  139. torchrl/data/datasets/minari_data.py +643 -0
  140. torchrl/data/datasets/openml.py +177 -0
  141. torchrl/data/datasets/openx.py +798 -0
  142. torchrl/data/datasets/roboset.py +363 -0
  143. torchrl/data/datasets/utils.py +11 -0
  144. torchrl/data/datasets/vd4rl.py +432 -0
  145. torchrl/data/llm/__init__.py +34 -0
  146. torchrl/data/llm/dataset.py +491 -0
  147. torchrl/data/llm/history.py +1378 -0
  148. torchrl/data/llm/prompt.py +198 -0
  149. torchrl/data/llm/reward.py +225 -0
  150. torchrl/data/llm/topk.py +186 -0
  151. torchrl/data/llm/utils.py +543 -0
  152. torchrl/data/map/__init__.py +21 -0
  153. torchrl/data/map/hash.py +185 -0
  154. torchrl/data/map/query.py +204 -0
  155. torchrl/data/map/tdstorage.py +363 -0
  156. torchrl/data/map/tree.py +1434 -0
  157. torchrl/data/map/utils.py +103 -0
  158. torchrl/data/postprocs/__init__.py +8 -0
  159. torchrl/data/postprocs/postprocs.py +391 -0
  160. torchrl/data/replay_buffers/__init__.py +99 -0
  161. torchrl/data/replay_buffers/checkpointers.py +622 -0
  162. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  163. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  164. torchrl/data/replay_buffers/samplers.py +2578 -0
  165. torchrl/data/replay_buffers/scheduler.py +265 -0
  166. torchrl/data/replay_buffers/storages.py +2412 -0
  167. torchrl/data/replay_buffers/utils.py +1042 -0
  168. torchrl/data/replay_buffers/writers.py +781 -0
  169. torchrl/data/tensor_specs.py +7101 -0
  170. torchrl/data/utils.py +334 -0
  171. torchrl/envs/__init__.py +265 -0
  172. torchrl/envs/async_envs.py +1105 -0
  173. torchrl/envs/batched_envs.py +3093 -0
  174. torchrl/envs/common.py +4241 -0
  175. torchrl/envs/custom/__init__.py +11 -0
  176. torchrl/envs/custom/chess.py +617 -0
  177. torchrl/envs/custom/llm.py +214 -0
  178. torchrl/envs/custom/pendulum.py +401 -0
  179. torchrl/envs/custom/san_moves.txt +29274 -0
  180. torchrl/envs/custom/tictactoeenv.py +288 -0
  181. torchrl/envs/env_creator.py +263 -0
  182. torchrl/envs/gym_like.py +752 -0
  183. torchrl/envs/libs/__init__.py +68 -0
  184. torchrl/envs/libs/_gym_utils.py +326 -0
  185. torchrl/envs/libs/brax.py +846 -0
  186. torchrl/envs/libs/dm_control.py +544 -0
  187. torchrl/envs/libs/envpool.py +447 -0
  188. torchrl/envs/libs/gym.py +2239 -0
  189. torchrl/envs/libs/habitat.py +138 -0
  190. torchrl/envs/libs/isaac_lab.py +87 -0
  191. torchrl/envs/libs/isaacgym.py +203 -0
  192. torchrl/envs/libs/jax_utils.py +166 -0
  193. torchrl/envs/libs/jumanji.py +963 -0
  194. torchrl/envs/libs/meltingpot.py +599 -0
  195. torchrl/envs/libs/openml.py +153 -0
  196. torchrl/envs/libs/openspiel.py +652 -0
  197. torchrl/envs/libs/pettingzoo.py +1042 -0
  198. torchrl/envs/libs/procgen.py +351 -0
  199. torchrl/envs/libs/robohive.py +429 -0
  200. torchrl/envs/libs/smacv2.py +645 -0
  201. torchrl/envs/libs/unity_mlagents.py +891 -0
  202. torchrl/envs/libs/utils.py +147 -0
  203. torchrl/envs/libs/vmas.py +813 -0
  204. torchrl/envs/llm/__init__.py +63 -0
  205. torchrl/envs/llm/chat.py +730 -0
  206. torchrl/envs/llm/datasets/README.md +4 -0
  207. torchrl/envs/llm/datasets/__init__.py +17 -0
  208. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  209. torchrl/envs/llm/datasets/ifeval.py +274 -0
  210. torchrl/envs/llm/envs.py +789 -0
  211. torchrl/envs/llm/libs/README.md +3 -0
  212. torchrl/envs/llm/libs/__init__.py +8 -0
  213. torchrl/envs/llm/libs/mlgym.py +869 -0
  214. torchrl/envs/llm/reward/__init__.py +10 -0
  215. torchrl/envs/llm/reward/gsm8k.py +324 -0
  216. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  217. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  218. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  222. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  223. torchrl/envs/llm/transforms/__init__.py +55 -0
  224. torchrl/envs/llm/transforms/browser.py +292 -0
  225. torchrl/envs/llm/transforms/dataloading.py +859 -0
  226. torchrl/envs/llm/transforms/format.py +73 -0
  227. torchrl/envs/llm/transforms/kl.py +1544 -0
  228. torchrl/envs/llm/transforms/policy_version.py +189 -0
  229. torchrl/envs/llm/transforms/reason.py +323 -0
  230. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  231. torchrl/envs/llm/transforms/tools.py +1955 -0
  232. torchrl/envs/model_based/__init__.py +9 -0
  233. torchrl/envs/model_based/common.py +180 -0
  234. torchrl/envs/model_based/dreamer.py +112 -0
  235. torchrl/envs/transforms/__init__.py +147 -0
  236. torchrl/envs/transforms/functional.py +48 -0
  237. torchrl/envs/transforms/gym_transforms.py +203 -0
  238. torchrl/envs/transforms/module.py +341 -0
  239. torchrl/envs/transforms/r3m.py +372 -0
  240. torchrl/envs/transforms/ray_service.py +663 -0
  241. torchrl/envs/transforms/rb_transforms.py +214 -0
  242. torchrl/envs/transforms/transforms.py +11835 -0
  243. torchrl/envs/transforms/utils.py +94 -0
  244. torchrl/envs/transforms/vc1.py +307 -0
  245. torchrl/envs/transforms/vecnorm.py +845 -0
  246. torchrl/envs/transforms/vip.py +407 -0
  247. torchrl/envs/utils.py +1718 -0
  248. torchrl/envs/vec_envs.py +11 -0
  249. torchrl/modules/__init__.py +206 -0
  250. torchrl/modules/distributions/__init__.py +73 -0
  251. torchrl/modules/distributions/continuous.py +830 -0
  252. torchrl/modules/distributions/discrete.py +908 -0
  253. torchrl/modules/distributions/truncated_normal.py +187 -0
  254. torchrl/modules/distributions/utils.py +233 -0
  255. torchrl/modules/llm/__init__.py +62 -0
  256. torchrl/modules/llm/backends/__init__.py +65 -0
  257. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  258. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  259. torchrl/modules/llm/backends/vllm/base.py +72 -0
  260. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  261. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  262. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  263. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  264. torchrl/modules/llm/policies/__init__.py +28 -0
  265. torchrl/modules/llm/policies/common.py +1809 -0
  266. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  267. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  268. torchrl/modules/llm/utils.py +23 -0
  269. torchrl/modules/mcts/__init__.py +21 -0
  270. torchrl/modules/mcts/scores.py +579 -0
  271. torchrl/modules/models/__init__.py +86 -0
  272. torchrl/modules/models/batchrenorm.py +119 -0
  273. torchrl/modules/models/decision_transformer.py +179 -0
  274. torchrl/modules/models/exploration.py +731 -0
  275. torchrl/modules/models/llm.py +156 -0
  276. torchrl/modules/models/model_based.py +596 -0
  277. torchrl/modules/models/models.py +1712 -0
  278. torchrl/modules/models/multiagent.py +1067 -0
  279. torchrl/modules/models/recipes/impala.py +185 -0
  280. torchrl/modules/models/utils.py +162 -0
  281. torchrl/modules/planners/__init__.py +10 -0
  282. torchrl/modules/planners/cem.py +228 -0
  283. torchrl/modules/planners/common.py +73 -0
  284. torchrl/modules/planners/mppi.py +265 -0
  285. torchrl/modules/tensordict_module/__init__.py +89 -0
  286. torchrl/modules/tensordict_module/actors.py +2457 -0
  287. torchrl/modules/tensordict_module/common.py +529 -0
  288. torchrl/modules/tensordict_module/exploration.py +814 -0
  289. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  290. torchrl/modules/tensordict_module/rnn.py +1639 -0
  291. torchrl/modules/tensordict_module/sequence.py +132 -0
  292. torchrl/modules/tensordict_module/world_models.py +34 -0
  293. torchrl/modules/utils/__init__.py +38 -0
  294. torchrl/modules/utils/mappings.py +9 -0
  295. torchrl/modules/utils/utils.py +89 -0
  296. torchrl/objectives/__init__.py +78 -0
  297. torchrl/objectives/a2c.py +659 -0
  298. torchrl/objectives/common.py +753 -0
  299. torchrl/objectives/cql.py +1346 -0
  300. torchrl/objectives/crossq.py +710 -0
  301. torchrl/objectives/ddpg.py +453 -0
  302. torchrl/objectives/decision_transformer.py +371 -0
  303. torchrl/objectives/deprecated.py +516 -0
  304. torchrl/objectives/dqn.py +683 -0
  305. torchrl/objectives/dreamer.py +488 -0
  306. torchrl/objectives/functional.py +48 -0
  307. torchrl/objectives/gail.py +258 -0
  308. torchrl/objectives/iql.py +996 -0
  309. torchrl/objectives/llm/__init__.py +30 -0
  310. torchrl/objectives/llm/grpo.py +846 -0
  311. torchrl/objectives/llm/sft.py +482 -0
  312. torchrl/objectives/multiagent/__init__.py +8 -0
  313. torchrl/objectives/multiagent/qmixer.py +396 -0
  314. torchrl/objectives/ppo.py +1669 -0
  315. torchrl/objectives/redq.py +683 -0
  316. torchrl/objectives/reinforce.py +530 -0
  317. torchrl/objectives/sac.py +1580 -0
  318. torchrl/objectives/td3.py +570 -0
  319. torchrl/objectives/td3_bc.py +625 -0
  320. torchrl/objectives/utils.py +782 -0
  321. torchrl/objectives/value/__init__.py +28 -0
  322. torchrl/objectives/value/advantages.py +1956 -0
  323. torchrl/objectives/value/functional.py +1459 -0
  324. torchrl/objectives/value/utils.py +360 -0
  325. torchrl/record/__init__.py +17 -0
  326. torchrl/record/loggers/__init__.py +23 -0
  327. torchrl/record/loggers/common.py +48 -0
  328. torchrl/record/loggers/csv.py +226 -0
  329. torchrl/record/loggers/mlflow.py +142 -0
  330. torchrl/record/loggers/tensorboard.py +139 -0
  331. torchrl/record/loggers/trackio.py +163 -0
  332. torchrl/record/loggers/utils.py +78 -0
  333. torchrl/record/loggers/wandb.py +214 -0
  334. torchrl/record/recorder.py +554 -0
  335. torchrl/services/__init__.py +79 -0
  336. torchrl/services/base.py +109 -0
  337. torchrl/services/ray_service.py +453 -0
  338. torchrl/testing/__init__.py +107 -0
  339. torchrl/testing/assertions.py +179 -0
  340. torchrl/testing/dist_utils.py +122 -0
  341. torchrl/testing/env_creators.py +227 -0
  342. torchrl/testing/env_helper.py +35 -0
  343. torchrl/testing/gym_helpers.py +156 -0
  344. torchrl/testing/llm_mocks.py +119 -0
  345. torchrl/testing/mocking_classes.py +2720 -0
  346. torchrl/testing/modules.py +295 -0
  347. torchrl/testing/mp_helpers.py +15 -0
  348. torchrl/testing/ray_helpers.py +293 -0
  349. torchrl/testing/utils.py +190 -0
  350. torchrl/trainers/__init__.py +42 -0
  351. torchrl/trainers/algorithms/__init__.py +11 -0
  352. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  353. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  354. torchrl/trainers/algorithms/configs/common.py +41 -0
  355. torchrl/trainers/algorithms/configs/data.py +308 -0
  356. torchrl/trainers/algorithms/configs/envs.py +104 -0
  357. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  358. torchrl/trainers/algorithms/configs/logging.py +80 -0
  359. torchrl/trainers/algorithms/configs/modules.py +570 -0
  360. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  361. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  362. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  363. torchrl/trainers/algorithms/configs/utils.py +252 -0
  364. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  365. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  366. torchrl/trainers/algorithms/ppo.py +373 -0
  367. torchrl/trainers/algorithms/sac.py +308 -0
  368. torchrl/trainers/helpers/__init__.py +40 -0
  369. torchrl/trainers/helpers/collectors.py +416 -0
  370. torchrl/trainers/helpers/envs.py +573 -0
  371. torchrl/trainers/helpers/logger.py +33 -0
  372. torchrl/trainers/helpers/losses.py +132 -0
  373. torchrl/trainers/helpers/models.py +658 -0
  374. torchrl/trainers/helpers/replay_buffer.py +59 -0
  375. torchrl/trainers/helpers/trainers.py +301 -0
  376. torchrl/trainers/trainers.py +2052 -0
  377. torchrl/weight_update/__init__.py +33 -0
  378. torchrl/weight_update/_distributed.py +749 -0
  379. torchrl/weight_update/_mp.py +624 -0
  380. torchrl/weight_update/_noupdate.py +102 -0
  381. torchrl/weight_update/_ray.py +1032 -0
  382. torchrl/weight_update/_rpc.py +284 -0
  383. torchrl/weight_update/_shared.py +891 -0
  384. torchrl/weight_update/llm/__init__.py +32 -0
  385. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  386. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  387. torchrl/weight_update/utils.py +73 -0
  388. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  389. torchrl-0.11.0.dist-info/LICENSE +21 -0
  390. torchrl-0.11.0.dist-info/METADATA +1307 -0
  391. torchrl-0.11.0.dist-info/RECORD +394 -0
  392. torchrl-0.11.0.dist-info/WHEEL +5 -0
  393. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  394. torchrl-0.11.0.dist-info/top_level.txt +7 -0
@@ -0,0 +1,643 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ from __future__ import annotations
6
+
7
+ import importlib.util
8
+ import json
9
+ import os.path
10
+ import shutil
11
+ import tempfile
12
+ from collections import defaultdict
13
+ from collections.abc import Callable
14
+ from contextlib import nullcontext
15
+ from dataclasses import asdict
16
+ from pathlib import Path
17
+
18
+ import torch
19
+ from tensordict import (
20
+ is_non_tensor,
21
+ is_tensor_collection,
22
+ NonTensorData,
23
+ NonTensorStack,
24
+ PersistentTensorDict,
25
+ set_list_to_stack,
26
+ TensorDict,
27
+ TensorDictBase,
28
+ )
29
+
30
+ from torchrl._utils import KeyDependentDefaultDict, logger as torchrl_logger
31
+ from torchrl.data.datasets.common import BaseDatasetExperienceReplay
32
+ from torchrl.data.datasets.utils import _get_root_dir
33
+ from torchrl.data.replay_buffers.samplers import Sampler
34
+ from torchrl.data.replay_buffers.storages import TensorStorage
35
+ from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter, Writer
36
+ from torchrl.data.tensor_specs import Bounded, Categorical, Composite, Unbounded
37
+ from torchrl.envs.utils import _classproperty
38
+
39
# Optional-dependency probes: tqdm is used for progress bars during download,
# minari is the dataset backend proper. Both are imported lazily at use sites.
_has_tqdm = importlib.util.find_spec("tqdm", None) is not None
_has_minari = importlib.util.find_spec("minari", None) is not None

# Translation table from Minari/HDF5 key names to TorchRL's TED naming
# convention. Keys absent from the table map to themselves (identity default).
_NAME_MATCH = KeyDependentDefaultDict(lambda key: key)
_NAME_MATCH["observations"] = "observation"
_NAME_MATCH["rewards"] = "reward"
_NAME_MATCH["truncations"] = "truncated"
_NAME_MATCH["terminations"] = "terminated"
_NAME_MATCH["actions"] = "action"
_NAME_MATCH["infos"] = "info"


# Maps dtype names as stored in dataset metadata (strings) to torch dtypes.
_DTYPE_DIR = {
    "float16": torch.float16,
    "float32": torch.float32,
    "float64": torch.float64,
    "int64": torch.int64,
    "int32": torch.int32,
    "uint8": torch.uint8,
}
59
+
60
+
61
+ class MinariExperienceReplay(BaseDatasetExperienceReplay):
62
+ """Minari Experience replay dataset.
63
+
64
+ Learn more about Minari on their website: https://minari.farama.org/
65
+
66
+ The data format follows the :ref:`TED convention <TED-format>`.
67
+
68
+ Args:
69
+ dataset_id (str): The dataset to be downloaded. Must be part of MinariExperienceReplay.available_datasets
70
+ batch_size (int): Batch-size used during sampling. Can be overridden by `data.sample(batch_size)` if
71
+ necessary.
72
+
73
+ Keyword Args:
74
+ root (Path or str, optional): The Minari dataset root directory.
75
+ The actual dataset memory-mapped files will be saved under
76
+ `<root>/<dataset_id>`. If none is provided, it defaults to
77
+ `~/.cache/torchrl/minari`.
78
+ download (bool or str, optional): Whether the dataset should be downloaded if
79
+ not found. Defaults to ``True``. Download can also be passed as ``"force"``,
80
+ in which case the downloaded data will be overwritten.
81
+ sampler (Sampler, optional): the sampler to be used. If none is provided
82
+ a default RandomSampler() will be used.
83
+ writer (Writer, optional): the writer to be used. If none is provided
84
+ a default :class:`~torchrl.data.replay_buffers.writers.ImmutableDatasetWriter` will be used.
85
+ collate_fn (callable, optional): merges a list of samples to form a
86
+ mini-batch of Tensor(s)/outputs. Used when using batched
87
+ loading from a map-style dataset.
88
+ pin_memory (bool): whether pin_memory() should be called on the rb
89
+ samples.
90
+ prefetch (int, optional): number of next batches to be prefetched
91
+ using multithreading.
92
+ transform (Transform, optional): Transform to be executed when sample() is called.
93
+ To chain transforms use the :class:`~torchrl.envs.transforms.transforms.Compose` class.
94
+ split_trajs (bool, optional): if ``True``, the trajectories will be split
95
+ along the first dimension and padded to have a matching shape.
96
+ To split the trajectories, the ``"done"`` signal will be used, which
97
+ is recovered via ``done = truncated | terminated``. In other words,
98
+ it is assumed that any ``truncated`` or ``terminated`` signal is
99
+ equivalent to the end of a trajectory.
100
+ Defaults to ``False``.
101
+ load_from_local_minari (bool, optional): if ``True``, the dataset will be loaded directly
102
+ from the local Minari cache (typically located at ``~/.minari/datasets``),
103
+ bypassing any remote download. This is useful when working with custom
104
+ Minari datasets previously generated and stored locally, or when network
105
+ access should be avoided. If the dataset is not found in the expected
106
+ cache directory, a ``FileNotFoundError`` will be raised.
107
+ Defaults to ``False``.
108
+
109
+
110
+ Attributes:
111
+ available_datasets: a list of accepted entries to be downloaded.
112
+
113
+ .. note::
114
+ Text data is currently discarded from the wrapped dataset, as there is no
115
+ PyTorch native way of representing text data.
116
+ If this feature is required, please post an issue on TorchRL's GitHub
117
+ repository.
118
+
119
+ Examples:
120
+ >>> from torchrl.data.datasets.minari_data import MinariExperienceReplay
121
+ >>> data = MinariExperienceReplay("door-human-v1", batch_size=32, download="force")
122
+ >>> for sample in data:
123
+ ... torchrl_logger.info(sample)
124
+ ... break
125
+ TensorDict(
126
+ fields={
127
+ action: Tensor(shape=torch.Size([32, 28]), device=cpu, dtype=torch.float32, is_shared=False),
128
+ index: Tensor(shape=torch.Size([32]), device=cpu, dtype=torch.int64, is_shared=False),
129
+ info: TensorDict(
130
+ fields={
131
+ success: Tensor(shape=torch.Size([32]), device=cpu, dtype=torch.bool, is_shared=False)},
132
+ batch_size=torch.Size([32]),
133
+ device=cpu,
134
+ is_shared=False),
135
+ next: TensorDict(
136
+ fields={
137
+ done: Tensor(shape=torch.Size([32, 1]), device=cpu, dtype=torch.bool, is_shared=False),
138
+ info: TensorDict(
139
+ fields={
140
+ success: Tensor(shape=torch.Size([32]), device=cpu, dtype=torch.bool, is_shared=False)},
141
+ batch_size=torch.Size([32]),
142
+ device=cpu,
143
+ is_shared=False),
144
+ observation: Tensor(shape=torch.Size([32, 39]), device=cpu, dtype=torch.float64, is_shared=False),
145
+ reward: Tensor(shape=torch.Size([32, 1]), device=cpu, dtype=torch.float64, is_shared=False),
146
+ state: TensorDict(
147
+ fields={
148
+ door_body_pos: Tensor(shape=torch.Size([32, 3]), device=cpu, dtype=torch.float64, is_shared=False),
149
+ qpos: Tensor(shape=torch.Size([32, 30]), device=cpu, dtype=torch.float64, is_shared=False),
150
+ qvel: Tensor(shape=torch.Size([32, 30]), device=cpu, dtype=torch.float64, is_shared=False)},
151
+ batch_size=torch.Size([32]),
152
+ device=cpu,
153
+ is_shared=False),
154
+ terminated: Tensor(shape=torch.Size([32, 1]), device=cpu, dtype=torch.bool, is_shared=False),
155
+ truncated: Tensor(shape=torch.Size([32, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
156
+ batch_size=torch.Size([32]),
157
+ device=cpu,
158
+ is_shared=False),
159
+ observation: Tensor(shape=torch.Size([32, 39]), device=cpu, dtype=torch.float64, is_shared=False),
160
+ state: TensorDict(
161
+ fields={
162
+ door_body_pos: Tensor(shape=torch.Size([32, 3]), device=cpu, dtype=torch.float64, is_shared=False),
163
+ qpos: Tensor(shape=torch.Size([32, 30]), device=cpu, dtype=torch.float64, is_shared=False),
164
+ qvel: Tensor(shape=torch.Size([32, 30]), device=cpu, dtype=torch.float64, is_shared=False)},
165
+ batch_size=torch.Size([32]),
166
+ device=cpu,
167
+ is_shared=False)},
168
+ batch_size=torch.Size([32]),
169
+ device=cpu,
170
+ is_shared=False)
171
+
172
+ """
173
+
174
    def __init__(
        self,
        dataset_id,
        batch_size: int,
        *,
        root: str | Path | None = None,
        download: bool = True,
        sampler: Sampler | None = None,
        writer: Writer | None = None,
        collate_fn: Callable | None = None,
        pin_memory: bool = False,
        prefetch: int | None = None,
        transform: torchrl.envs.Transform | None = None,  # noqa: F821
        split_trajs: bool = False,
        load_from_local_minari: bool = False,
    ):
        """Build the replay buffer, downloading/preprocessing the dataset if needed.

        Side effects: may create ``root`` on disk, delete previously cached
        data (when ``download == "force"``) and write memory-mapped files
        under ``<root>/<dataset_id>``.
        """
        self.dataset_id = dataset_id
        if root is None:
            # Default cache location; only created when the caller did not
            # supply a root of their own.
            root = _get_root_dir("minari")
            os.makedirs(root, exist_ok=True)
        self.root = root
        self.split_trajs = split_trajs
        self.download = download
        self.load_from_local_minari = load_from_local_minari

        # (Re)build the on-disk storage when: a forced refresh is requested,
        # a regular download is allowed and nothing is cached yet, or the
        # data must be read from the user's local Minari cache.
        if (
            self.download == "force"
            or (self.download and not self._is_downloaded())
            or self.load_from_local_minari
        ):
            if self.download == "force":
                try:
                    # Wipe both the raw and (if distinct) the split layout so
                    # the fresh download starts from a clean slate.
                    if os.path.exists(self.data_path_root):
                        shutil.rmtree(self.data_path_root)
                    if self.data_path != self.data_path_root:
                        shutil.rmtree(self.data_path)
                except FileNotFoundError:
                    # A concurrent/partial cleanup is fine to ignore here.
                    pass
            storage = self._download_and_preproc()
        elif self.split_trajs and not os.path.exists(self.data_path):
            # Raw data is cached but the padded/split layout is missing.
            storage = self._make_split()
        else:
            storage = self._load()
        storage = TensorStorage(storage)

        if writer is None:
            # Datasets are read-only by default.
            writer = ImmutableDatasetWriter()

        super().__init__(
            storage=storage,
            sampler=sampler,
            writer=writer,
            collate_fn=collate_fn,
            pin_memory=pin_memory,
            prefetch=prefetch,
            batch_size=batch_size,
            transform=transform,
        )
232
+
233
+ @_classproperty
234
+ def available_datasets(self):
235
+ if not _has_minari:
236
+ raise ImportError("minari library not found.")
237
+ import minari
238
+
239
+ return minari.list_remote_datasets().keys()
240
+
241
+ def _is_downloaded(self):
242
+ return os.path.exists(self.data_path_root)
243
+
244
+ @property
245
+ def data_path(self) -> Path:
246
+ if self.split_trajs:
247
+ return Path(self.root) / (self.dataset_id + "_split")
248
+ return self.data_path_root
249
+
250
+ @property
251
+ def data_path_root(self) -> Path:
252
+ return Path(self.root) / self.dataset_id
253
+
254
+ @property
255
+ def metadata_path(self) -> Path:
256
+ return Path(self.root) / self.dataset_id / "env_metadata.json"
257
+
258
+ def _download_and_preproc(self):
259
+ if not _has_minari:
260
+ raise ImportError("minari library not found.")
261
+ import minari
262
+
263
+ if _has_tqdm:
264
+ from tqdm import tqdm
265
+
266
+ prev_minari_datasets_path_save = prev_minari_datasets_path = os.environ.get(
267
+ "MINARI_DATASETS_PATH"
268
+ )
269
+ try:
270
+ if prev_minari_datasets_path is None:
271
+ prev_minari_datasets_path = os.path.expanduser("~/.minari/datasets")
272
+ with tempfile.TemporaryDirectory() as tmpdir:
273
+
274
+ total_steps = 0
275
+ td_data = TensorDict()
276
+
277
+ if self.load_from_local_minari:
278
+ # Load minari dataset from user's local Minari cache
279
+
280
+ parent_dir = (
281
+ Path(prev_minari_datasets_path) / self.dataset_id / "data"
282
+ )
283
+ h5_path = parent_dir / "main_data.hdf5"
284
+
285
+ if not h5_path.exists():
286
+ raise FileNotFoundError(
287
+ f"{h5_path} does not exist in local Minari cache!"
288
+ )
289
+
290
+ torchrl_logger.info(
291
+ f"loading dataset from local Minari cache at {h5_path}"
292
+ )
293
+ h5_data = PersistentTensorDict.from_h5(h5_path)
294
+ h5_data = h5_data.to_tensordict()
295
+
296
+ else:
297
+ # temporarily change the minari cache path
298
+ prev_minari_datasets_path_save2 = os.environ.get(
299
+ "MINARI_DATASETS_PATH"
300
+ )
301
+ os.environ["MINARI_DATASETS_PATH"] = tmpdir
302
+ try:
303
+ minari.download_dataset(dataset_id=self.dataset_id)
304
+ finally:
305
+ if prev_minari_datasets_path_save2 is not None:
306
+ os.environ[
307
+ "MINARI_DATASETS_PATH"
308
+ ] = prev_minari_datasets_path_save2
309
+
310
+ parent_dir = Path(tmpdir) / self.dataset_id / "data"
311
+
312
+ torchrl_logger.info(
313
+ "first read through data to create data structure..."
314
+ )
315
+ h5_data = PersistentTensorDict.from_h5(
316
+ parent_dir / "main_data.hdf5"
317
+ )
318
+ h5_data = h5_data.to_tensordict()
319
+
320
+ # populate the tensordict
321
+ episode_dict = {}
322
+ dataset_has_nontensor = False
323
+ for i, (episode_key, episode) in enumerate(h5_data.items()):
324
+ episode_num = int(episode_key[len("episode_") :])
325
+ episode_len = episode["actions"].shape[0]
326
+ episode_dict[episode_num] = (episode_key, episode_len)
327
+ # Get the total number of steps for the dataset
328
+ total_steps += episode_len
329
+ if i == 0:
330
+ td_data.set("episode", 0)
331
+ seen = set()
332
+ for key, val in episode.items():
333
+ match = _NAME_MATCH[key]
334
+ if match in seen:
335
+ continue
336
+ seen.add(match)
337
+ if key in ("observations", "state", "infos"):
338
+ val = episode[key]
339
+ if is_tensor_collection(val) and any(
340
+ isinstance(
341
+ val.get(k), (NonTensorData, NonTensorStack)
342
+ )
343
+ for k in val.keys()
344
+ ):
345
+ non_tensor_probe = val.clone()
346
+ _extract_nontensor_fields(
347
+ non_tensor_probe, recursive=True
348
+ )
349
+ dataset_has_nontensor = True
350
+ if (
351
+ not val.shape
352
+ ): # no need for this, we don't need the proper length: or steps != val.shape[0] - 1:
353
+ if val.is_empty():
354
+ continue
355
+ if is_non_tensor(val):
356
+ continue
357
+ val = _patch_info(val)
358
+ td_data.set(("next", match), torch.zeros_like(val[0]))
359
+ td_data.set(match, torch.zeros_like(val[0]))
360
+ elif key not in ("terminations", "truncations", "rewards"):
361
+ td_data.set(match, torch.zeros_like(val[0]))
362
+ else:
363
+ td_data.set(
364
+ ("next", match),
365
+ torch.zeros_like(val[0].unsqueeze(-1)),
366
+ )
367
+
368
+ # give it the proper size
369
+ td_data["next", "done"] = (
370
+ td_data["next", "truncated"] | td_data["next", "terminated"]
371
+ )
372
+ if "terminated" in td_data.keys():
373
+ td_data["done"] = td_data["truncated"] | td_data["terminated"]
374
+ td_data = td_data.expand(total_steps).contiguous()
375
+ # save to designated location
376
+ torchrl_logger.info(
377
+ f"creating tensordict data in {self.data_path_root}: "
378
+ )
379
+ if dataset_has_nontensor:
380
+ _preallocate_nontensor_fields(
381
+ td_data, episode, total_steps, name_map=_NAME_MATCH
382
+ )
383
+ torchrl_logger.info(f"tensordict structure: {td_data}")
384
+
385
+ torchrl_logger.info(
386
+ f"Reading data from {max(*episode_dict) + 1} episodes"
387
+ )
388
+ index = 0
389
+ with tqdm(total=total_steps) if _has_tqdm else nullcontext() as pbar:
390
+ # iterate over episodes and populate the tensordict
391
+ for episode_num in sorted(episode_dict):
392
+ episode_key, steps = episode_dict[episode_num]
393
+ episode = _patch_nontensor_data_to_stack(
394
+ h5_data.get(episode_key)
395
+ )
396
+ idx = slice(index, (index + steps))
397
+ data_view = td_data[idx]
398
+ data_view.fill_("episode", episode_num)
399
+ for key, val in episode.items():
400
+ match = _NAME_MATCH[key]
401
+ if key in (
402
+ "observations",
403
+ "state",
404
+ "infos",
405
+ ):
406
+ if not val.shape or steps != val.shape[0] - 1:
407
+ if val.is_empty():
408
+ continue
409
+ if is_non_tensor(val):
410
+ continue
411
+ val = _patch_info(val)
412
+ if steps != val.shape[0] - 1:
413
+ raise RuntimeError(
414
+ f"Mismatching number of steps for key {key}: was {steps} but got {val.shape[0] - 1}."
415
+ )
416
+ val_next = val[1:].clone()
417
+ val_copy = val[:-1].clone()
418
+
419
+ data_view["next", match].copy_(val_next)
420
+ data_view[match].copy_(val_copy)
421
+
422
+ if is_tensor_collection(val_next):
423
+ non_tensors_next = _extract_nontensor_fields(
424
+ val_next
425
+ )
426
+ non_tensors_now = _extract_nontensor_fields(
427
+ val_copy
428
+ )
429
+ data_view["next", match].update_(non_tensors_next)
430
+ data_view[match].update_(non_tensors_now)
431
+
432
+ elif key not in ("terminations", "truncations", "rewards"):
433
+ if steps is None:
434
+ steps = val.shape[0]
435
+ else:
436
+ if steps != val.shape[0]:
437
+ raise RuntimeError(
438
+ f"Mismatching number of steps for key {key}: was {steps} but got {val.shape[0]}."
439
+ )
440
+ data_view[match].copy_(val)
441
+ else:
442
+ if steps is None:
443
+ steps = val.shape[0]
444
+ else:
445
+ if steps != val.shape[0]:
446
+ raise RuntimeError(
447
+ f"Mismatching number of steps for key {key}: was {steps} but got {val.shape[0]}."
448
+ )
449
+ data_view[("next", match)].copy_(val.unsqueeze(-1))
450
+ data_view["next", "done"].copy_(
451
+ data_view["next", "terminated"]
452
+ | data_view["next", "truncated"]
453
+ )
454
+ if "done" in data_view.keys():
455
+ data_view["done"].copy_(
456
+ data_view["terminated"] | data_view["truncated"]
457
+ )
458
+ if pbar is not None:
459
+ pbar.update(steps)
460
+ pbar.set_description(
461
+ f"index={index} - episode num {episode_num}"
462
+ )
463
+ index += steps
464
+
465
+ td_data = td_data.memmap_like(self.data_path_root)
466
+ # Add a "done" entry
467
+ if self.split_trajs:
468
+ with td_data.unlock_():
469
+ from torchrl.collectors.utils import split_trajectories
470
+
471
+ td_data = split_trajectories(td_data).memmap_(self.data_path)
472
+ with open(self.metadata_path, "w") as metadata_file:
473
+ dataset = minari.load_dataset(self.dataset_id)
474
+ self.metadata = asdict(dataset.spec)
475
+ self.metadata["observation_space"] = _spec_to_dict(
476
+ self.metadata["observation_space"]
477
+ )
478
+ self.metadata["action_space"] = _spec_to_dict(
479
+ self.metadata["action_space"]
480
+ )
481
+ json.dump(self.metadata, metadata_file)
482
+ self._load_and_proc_metadata()
483
+ return td_data
484
+ finally:
485
+ if prev_minari_datasets_path_save is not None:
486
+ os.environ["MINARI_DATASETS_PATH"] = prev_minari_datasets_path_save
487
+
488
+ def _make_split(self):
489
+ from torchrl.collectors.utils import split_trajectories
490
+
491
+ self._load_and_proc_metadata()
492
+ td_data = TensorDict.load_memmap(self.data_path_root)
493
+ td_data = split_trajectories(td_data).memmap_(self.data_path)
494
+ return td_data
495
+
496
+ def _load(self):
497
+ self._load_and_proc_metadata()
498
+ return TensorDict.load_memmap(self.data_path)
499
+
500
+ def _load_and_proc_metadata(self):
501
+ with open(self.metadata_path) as file:
502
+ self.metadata = json.load(file)
503
+ self.metadata["observation_space"] = _proc_spec(
504
+ self.metadata["observation_space"]
505
+ )
506
+ self.metadata["action_space"] = _proc_spec(self.metadata["action_space"])
507
+
508
+
509
+ def _proc_spec(spec):
510
+ if spec is None:
511
+ return
512
+ if spec["type"] == "Dict":
513
+ return Composite(
514
+ {key: _proc_spec(subspec) for key, subspec in spec["subspaces"].items()}
515
+ )
516
+ elif spec["type"] == "Box":
517
+ if all(item == -float("inf") for item in spec["low"]) and all(
518
+ item == float("inf") for item in spec["high"]
519
+ ):
520
+ return Unbounded(spec["shape"], dtype=_DTYPE_DIR[spec["dtype"]])
521
+ return Bounded(
522
+ shape=spec["shape"],
523
+ low=torch.as_tensor(spec["low"]),
524
+ high=torch.as_tensor(spec["high"]),
525
+ dtype=_DTYPE_DIR[spec["dtype"]],
526
+ )
527
+ elif spec["type"] == "Discrete":
528
+ return Categorical(
529
+ spec["n"], shape=spec["shape"], dtype=_DTYPE_DIR[spec["dtype"]]
530
+ )
531
+ else:
532
+ raise NotImplementedError(f"{type(spec)}")
533
+
534
+
535
def _spec_to_dict(spec):
    """Serialize a Gym space into a JSON-compatible dict (inverse of ``_proc_spec``)."""
    from torchrl.envs.libs.gym import gym_backend

    spaces = gym_backend("spaces")
    if isinstance(spec, spaces.Dict):
        return {
            "type": "Dict",
            "subspaces": {key: _spec_to_dict(val) for key, val in spec.items()},
        }
    if isinstance(spec, spaces.Box):
        return {
            "type": "Box",
            "low": spec.low.tolist(),
            "high": spec.high.tolist(),
            "dtype": str(spec.dtype),
            "shape": tuple(spec.shape),
        }
    if isinstance(spec, spaces.Discrete):
        return {
            "type": "Discrete",
            "dtype": str(spec.dtype),
            "n": int(spec.n),
            "shape": tuple(spec.shape),
        }
    if isinstance(spec, spaces.Text):
        # Text spaces are dropped from the serialized metadata
        return
    raise NotImplementedError(f"{type(spec)}, {str(spec)}")
561
+
562
+
563
def _patch_info(info_td):
    """Pad the short entries of an info tensordict so that all entries share the
    same leading length.

    Entries are grouped by leading dimension; exactly two group sizes (``T`` and
    ``T + 1``) are allowed. Entries of length ``T`` are zero-padded at the front
    so they line up with the ``T + 1`` entries.

    Args:
        info_td: sub-tensordict whose values may have ``T`` or ``T + 1``
            leading elements.

    Returns:
        a new tensordict with batch size ``T + 1``.

    Raises:
        RuntimeError: if more than two distinct leading lengths are found.
    """
    # Some info dicts have tensors with one less element than others
    # We explicitly assume that the missing item is in the first position because
    # it wasn't given at reset time.
    # An alternative explanation could be that the last element is missing because
    # deemed useless for training...
    unique_shapes = defaultdict(list)
    for subkey, subval in info_td.items():
        unique_shapes[subval.shape[0]].append(subkey)
    if len(unique_shapes) == 1:
        # all entries share one length: register an empty T+1 group so the
        # two-group invariant below holds
        unique_shapes[subval.shape[0] + 1] = []
    if len(unique_shapes) != 2:
        raise RuntimeError(
            f"Unique shapes in a sub-tensordict can only be of length 2, got shapes {unique_shapes}."
        )
    val_td = info_td.to_tensordict()
    min_shape = min(*unique_shapes)  # can only be found at root
    max_shape = min_shape + 1
    val_td_sel = val_td.select(*unique_shapes[min_shape])
    # prepend one zero element to each short entry to reach length T+1
    val_td_sel = val_td_sel.apply(
        lambda x: torch.cat([torch.zeros_like(x[:1]), x], 0), batch_size=[min_shape + 1]
    )
    source = val_td.select(*unique_shapes[max_shape])
    # make sure source has no batch size
    source.batch_size = ()
    if not source.is_empty():
        val_td_sel.update(source, update_batch_size=True)
    return val_td_sel
591
+
592
+
593
def _patch_nontensor_data_to_stack(tensordict: TensorDictBase):
    """Recursively replaces all NonTensorData fields in the TensorDict with NonTensorStack."""
    for key, value in tensordict.items():
        if isinstance(value, TensorDictBase):
            # sub-tensordicts are patched in place
            _patch_nontensor_data_to_stack(value)
        elif isinstance(value, NonTensorData):
            # re-assigning a python list under set_list_to_stack(True) stores
            # the entry as a NonTensorStack
            items = list(value.data)
            with set_list_to_stack(True):
                tensordict[key] = items
    return tensordict
603
+
604
+
605
def _extract_nontensor_fields(
    tensordict: TensorDictBase, recursive: bool = False
) -> TensorDict:
    """Deletes the NonTensor fields from tensordict and returns the deleted tensordict.

    Args:
        tensordict (TensorDictBase): tensordict to strip; mutated in place.
        recursive (bool): if ``True``, nested tensordicts are stripped too.

    Returns:
        TensorDict: the removed non-tensor entries, built with the same batch
        size as ``tensordict``.
    """
    extracted = {}
    for key in list(tensordict.keys()):
        val = tensordict.get(key)
        if is_non_tensor(val):
            extracted[key] = val
            # removal happens on the caller's tensordict
            del tensordict[key]
        elif recursive and is_tensor_collection(val):
            nested = _extract_nontensor_fields(val, recursive=True)
            # NOTE(review): len() on a TensorDict reflects the leading batch
            # dimension rather than the number of keys, so for batched inputs
            # this may be truthy even when nothing was extracted — confirm
            # whether `not nested.is_empty()` was intended here.
            if len(nested) > 0:
                extracted[key] = nested
    return TensorDict(extracted, batch_size=tensordict.batch_size)
620
+
621
+
622
def _preallocate_nontensor_fields(
    td_data: TensorDictBase, example: TensorDictBase, total_steps: int, name_map: dict
):
    """Preallocates NonTensorStack fields in td_data based on an example TensorDict, applying key remapping.

    Args:
        td_data (TensorDictBase): destination tensordict (one row per step);
            entries are created both at the root and under ``"next"``.
        example (TensorDictBase): episode used as a template to discover which
            keys hold non-tensor data.
        total_steps (int): number of rows to preallocate.
        name_map (dict): source-key -> destination-key remapping.
    """
    with set_list_to_stack(True):

        def _recurse(src_td: TensorDictBase, dst_td: TensorDictBase, prefix=()):
            # Mirror src_td's non-tensor leaves into dst_td under remapped keys.
            for key, val in src_td.items():
                mapped_key = name_map.get(key, key)
                full_dst_key = prefix + (mapped_key,)

                if is_non_tensor(val):
                    # Placeholder stack: the integer `total_steps` repeated
                    # `total_steps` times; presumably overwritten with real
                    # values later — TODO confirm every row is filled downstream.
                    dummy_stack = NonTensorStack(
                        *[total_steps for _ in range(total_steps)]
                    )
                    dst_td.set(full_dst_key, dummy_stack)
                    dst_td.set(("next",) + full_dst_key, dummy_stack)

                elif is_tensor_collection(val):
                    _recurse(val, dst_td, full_dst_key)

        _recurse(example, td_data)