torchrl 0.11.0__cp314-cp314-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (394)
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/__init__.py +144 -0
  97. torchrl/_extension.py +74 -0
  98. torchrl/_torchrl.cp314-win_amd64.pyd +0 -0
  99. torchrl/_utils.py +1431 -0
  100. torchrl/collectors/__init__.py +48 -0
  101. torchrl/collectors/_base.py +1058 -0
  102. torchrl/collectors/_constants.py +88 -0
  103. torchrl/collectors/_multi_async.py +324 -0
  104. torchrl/collectors/_multi_base.py +1805 -0
  105. torchrl/collectors/_multi_sync.py +464 -0
  106. torchrl/collectors/_runner.py +581 -0
  107. torchrl/collectors/_single.py +2009 -0
  108. torchrl/collectors/_single_async.py +259 -0
  109. torchrl/collectors/collectors.py +62 -0
  110. torchrl/collectors/distributed/__init__.py +32 -0
  111. torchrl/collectors/distributed/default_configs.py +133 -0
  112. torchrl/collectors/distributed/generic.py +1306 -0
  113. torchrl/collectors/distributed/ray.py +1092 -0
  114. torchrl/collectors/distributed/rpc.py +1006 -0
  115. torchrl/collectors/distributed/sync.py +731 -0
  116. torchrl/collectors/distributed/utils.py +160 -0
  117. torchrl/collectors/llm/__init__.py +10 -0
  118. torchrl/collectors/llm/base.py +494 -0
  119. torchrl/collectors/llm/ray_collector.py +275 -0
  120. torchrl/collectors/llm/utils.py +36 -0
  121. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  122. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  123. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  124. torchrl/collectors/utils.py +433 -0
  125. torchrl/collectors/weight_update.py +591 -0
  126. torchrl/csrc/numpy_utils.h +38 -0
  127. torchrl/csrc/pybind.cpp +27 -0
  128. torchrl/csrc/segment_tree.h +458 -0
  129. torchrl/csrc/torch_utils.h +34 -0
  130. torchrl/csrc/utils.cpp +48 -0
  131. torchrl/csrc/utils.h +31 -0
  132. torchrl/data/__init__.py +187 -0
  133. torchrl/data/datasets/__init__.py +58 -0
  134. torchrl/data/datasets/atari_dqn.py +878 -0
  135. torchrl/data/datasets/common.py +281 -0
  136. torchrl/data/datasets/d4rl.py +489 -0
  137. torchrl/data/datasets/d4rl_infos.py +187 -0
  138. torchrl/data/datasets/gen_dgrl.py +375 -0
  139. torchrl/data/datasets/minari_data.py +643 -0
  140. torchrl/data/datasets/openml.py +177 -0
  141. torchrl/data/datasets/openx.py +798 -0
  142. torchrl/data/datasets/roboset.py +363 -0
  143. torchrl/data/datasets/utils.py +11 -0
  144. torchrl/data/datasets/vd4rl.py +432 -0
  145. torchrl/data/llm/__init__.py +34 -0
  146. torchrl/data/llm/dataset.py +491 -0
  147. torchrl/data/llm/history.py +1378 -0
  148. torchrl/data/llm/prompt.py +198 -0
  149. torchrl/data/llm/reward.py +225 -0
  150. torchrl/data/llm/topk.py +186 -0
  151. torchrl/data/llm/utils.py +543 -0
  152. torchrl/data/map/__init__.py +21 -0
  153. torchrl/data/map/hash.py +185 -0
  154. torchrl/data/map/query.py +204 -0
  155. torchrl/data/map/tdstorage.py +363 -0
  156. torchrl/data/map/tree.py +1434 -0
  157. torchrl/data/map/utils.py +103 -0
  158. torchrl/data/postprocs/__init__.py +8 -0
  159. torchrl/data/postprocs/postprocs.py +391 -0
  160. torchrl/data/replay_buffers/__init__.py +99 -0
  161. torchrl/data/replay_buffers/checkpointers.py +622 -0
  162. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  163. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  164. torchrl/data/replay_buffers/samplers.py +2578 -0
  165. torchrl/data/replay_buffers/scheduler.py +265 -0
  166. torchrl/data/replay_buffers/storages.py +2412 -0
  167. torchrl/data/replay_buffers/utils.py +1042 -0
  168. torchrl/data/replay_buffers/writers.py +781 -0
  169. torchrl/data/tensor_specs.py +7101 -0
  170. torchrl/data/utils.py +334 -0
  171. torchrl/envs/__init__.py +265 -0
  172. torchrl/envs/async_envs.py +1105 -0
  173. torchrl/envs/batched_envs.py +3093 -0
  174. torchrl/envs/common.py +4241 -0
  175. torchrl/envs/custom/__init__.py +11 -0
  176. torchrl/envs/custom/chess.py +617 -0
  177. torchrl/envs/custom/llm.py +214 -0
  178. torchrl/envs/custom/pendulum.py +401 -0
  179. torchrl/envs/custom/san_moves.txt +29274 -0
  180. torchrl/envs/custom/tictactoeenv.py +288 -0
  181. torchrl/envs/env_creator.py +263 -0
  182. torchrl/envs/gym_like.py +752 -0
  183. torchrl/envs/libs/__init__.py +68 -0
  184. torchrl/envs/libs/_gym_utils.py +326 -0
  185. torchrl/envs/libs/brax.py +846 -0
  186. torchrl/envs/libs/dm_control.py +544 -0
  187. torchrl/envs/libs/envpool.py +447 -0
  188. torchrl/envs/libs/gym.py +2239 -0
  189. torchrl/envs/libs/habitat.py +138 -0
  190. torchrl/envs/libs/isaac_lab.py +87 -0
  191. torchrl/envs/libs/isaacgym.py +203 -0
  192. torchrl/envs/libs/jax_utils.py +166 -0
  193. torchrl/envs/libs/jumanji.py +963 -0
  194. torchrl/envs/libs/meltingpot.py +599 -0
  195. torchrl/envs/libs/openml.py +153 -0
  196. torchrl/envs/libs/openspiel.py +652 -0
  197. torchrl/envs/libs/pettingzoo.py +1042 -0
  198. torchrl/envs/libs/procgen.py +351 -0
  199. torchrl/envs/libs/robohive.py +429 -0
  200. torchrl/envs/libs/smacv2.py +645 -0
  201. torchrl/envs/libs/unity_mlagents.py +891 -0
  202. torchrl/envs/libs/utils.py +147 -0
  203. torchrl/envs/libs/vmas.py +813 -0
  204. torchrl/envs/llm/__init__.py +63 -0
  205. torchrl/envs/llm/chat.py +730 -0
  206. torchrl/envs/llm/datasets/README.md +4 -0
  207. torchrl/envs/llm/datasets/__init__.py +17 -0
  208. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  209. torchrl/envs/llm/datasets/ifeval.py +274 -0
  210. torchrl/envs/llm/envs.py +789 -0
  211. torchrl/envs/llm/libs/README.md +3 -0
  212. torchrl/envs/llm/libs/__init__.py +8 -0
  213. torchrl/envs/llm/libs/mlgym.py +869 -0
  214. torchrl/envs/llm/reward/__init__.py +10 -0
  215. torchrl/envs/llm/reward/gsm8k.py +324 -0
  216. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  217. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  218. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  222. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  223. torchrl/envs/llm/transforms/__init__.py +55 -0
  224. torchrl/envs/llm/transforms/browser.py +292 -0
  225. torchrl/envs/llm/transforms/dataloading.py +859 -0
  226. torchrl/envs/llm/transforms/format.py +73 -0
  227. torchrl/envs/llm/transforms/kl.py +1544 -0
  228. torchrl/envs/llm/transforms/policy_version.py +189 -0
  229. torchrl/envs/llm/transforms/reason.py +323 -0
  230. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  231. torchrl/envs/llm/transforms/tools.py +1955 -0
  232. torchrl/envs/model_based/__init__.py +9 -0
  233. torchrl/envs/model_based/common.py +180 -0
  234. torchrl/envs/model_based/dreamer.py +112 -0
  235. torchrl/envs/transforms/__init__.py +147 -0
  236. torchrl/envs/transforms/functional.py +48 -0
  237. torchrl/envs/transforms/gym_transforms.py +203 -0
  238. torchrl/envs/transforms/module.py +341 -0
  239. torchrl/envs/transforms/r3m.py +372 -0
  240. torchrl/envs/transforms/ray_service.py +663 -0
  241. torchrl/envs/transforms/rb_transforms.py +214 -0
  242. torchrl/envs/transforms/transforms.py +11835 -0
  243. torchrl/envs/transforms/utils.py +94 -0
  244. torchrl/envs/transforms/vc1.py +307 -0
  245. torchrl/envs/transforms/vecnorm.py +845 -0
  246. torchrl/envs/transforms/vip.py +407 -0
  247. torchrl/envs/utils.py +1718 -0
  248. torchrl/envs/vec_envs.py +11 -0
  249. torchrl/modules/__init__.py +206 -0
  250. torchrl/modules/distributions/__init__.py +73 -0
  251. torchrl/modules/distributions/continuous.py +830 -0
  252. torchrl/modules/distributions/discrete.py +908 -0
  253. torchrl/modules/distributions/truncated_normal.py +187 -0
  254. torchrl/modules/distributions/utils.py +233 -0
  255. torchrl/modules/llm/__init__.py +62 -0
  256. torchrl/modules/llm/backends/__init__.py +65 -0
  257. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  258. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  259. torchrl/modules/llm/backends/vllm/base.py +72 -0
  260. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  261. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  262. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  263. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  264. torchrl/modules/llm/policies/__init__.py +28 -0
  265. torchrl/modules/llm/policies/common.py +1809 -0
  266. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  267. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  268. torchrl/modules/llm/utils.py +23 -0
  269. torchrl/modules/mcts/__init__.py +21 -0
  270. torchrl/modules/mcts/scores.py +579 -0
  271. torchrl/modules/models/__init__.py +86 -0
  272. torchrl/modules/models/batchrenorm.py +119 -0
  273. torchrl/modules/models/decision_transformer.py +179 -0
  274. torchrl/modules/models/exploration.py +731 -0
  275. torchrl/modules/models/llm.py +156 -0
  276. torchrl/modules/models/model_based.py +596 -0
  277. torchrl/modules/models/models.py +1712 -0
  278. torchrl/modules/models/multiagent.py +1067 -0
  279. torchrl/modules/models/recipes/impala.py +185 -0
  280. torchrl/modules/models/utils.py +162 -0
  281. torchrl/modules/planners/__init__.py +10 -0
  282. torchrl/modules/planners/cem.py +228 -0
  283. torchrl/modules/planners/common.py +73 -0
  284. torchrl/modules/planners/mppi.py +265 -0
  285. torchrl/modules/tensordict_module/__init__.py +89 -0
  286. torchrl/modules/tensordict_module/actors.py +2457 -0
  287. torchrl/modules/tensordict_module/common.py +529 -0
  288. torchrl/modules/tensordict_module/exploration.py +814 -0
  289. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  290. torchrl/modules/tensordict_module/rnn.py +1639 -0
  291. torchrl/modules/tensordict_module/sequence.py +132 -0
  292. torchrl/modules/tensordict_module/world_models.py +34 -0
  293. torchrl/modules/utils/__init__.py +38 -0
  294. torchrl/modules/utils/mappings.py +9 -0
  295. torchrl/modules/utils/utils.py +89 -0
  296. torchrl/objectives/__init__.py +78 -0
  297. torchrl/objectives/a2c.py +659 -0
  298. torchrl/objectives/common.py +753 -0
  299. torchrl/objectives/cql.py +1346 -0
  300. torchrl/objectives/crossq.py +710 -0
  301. torchrl/objectives/ddpg.py +453 -0
  302. torchrl/objectives/decision_transformer.py +371 -0
  303. torchrl/objectives/deprecated.py +516 -0
  304. torchrl/objectives/dqn.py +683 -0
  305. torchrl/objectives/dreamer.py +488 -0
  306. torchrl/objectives/functional.py +48 -0
  307. torchrl/objectives/gail.py +258 -0
  308. torchrl/objectives/iql.py +996 -0
  309. torchrl/objectives/llm/__init__.py +30 -0
  310. torchrl/objectives/llm/grpo.py +846 -0
  311. torchrl/objectives/llm/sft.py +482 -0
  312. torchrl/objectives/multiagent/__init__.py +8 -0
  313. torchrl/objectives/multiagent/qmixer.py +396 -0
  314. torchrl/objectives/ppo.py +1669 -0
  315. torchrl/objectives/redq.py +683 -0
  316. torchrl/objectives/reinforce.py +530 -0
  317. torchrl/objectives/sac.py +1580 -0
  318. torchrl/objectives/td3.py +570 -0
  319. torchrl/objectives/td3_bc.py +625 -0
  320. torchrl/objectives/utils.py +782 -0
  321. torchrl/objectives/value/__init__.py +28 -0
  322. torchrl/objectives/value/advantages.py +1956 -0
  323. torchrl/objectives/value/functional.py +1459 -0
  324. torchrl/objectives/value/utils.py +360 -0
  325. torchrl/record/__init__.py +17 -0
  326. torchrl/record/loggers/__init__.py +23 -0
  327. torchrl/record/loggers/common.py +48 -0
  328. torchrl/record/loggers/csv.py +226 -0
  329. torchrl/record/loggers/mlflow.py +142 -0
  330. torchrl/record/loggers/tensorboard.py +139 -0
  331. torchrl/record/loggers/trackio.py +163 -0
  332. torchrl/record/loggers/utils.py +78 -0
  333. torchrl/record/loggers/wandb.py +214 -0
  334. torchrl/record/recorder.py +554 -0
  335. torchrl/services/__init__.py +79 -0
  336. torchrl/services/base.py +109 -0
  337. torchrl/services/ray_service.py +453 -0
  338. torchrl/testing/__init__.py +107 -0
  339. torchrl/testing/assertions.py +179 -0
  340. torchrl/testing/dist_utils.py +122 -0
  341. torchrl/testing/env_creators.py +227 -0
  342. torchrl/testing/env_helper.py +35 -0
  343. torchrl/testing/gym_helpers.py +156 -0
  344. torchrl/testing/llm_mocks.py +119 -0
  345. torchrl/testing/mocking_classes.py +2720 -0
  346. torchrl/testing/modules.py +295 -0
  347. torchrl/testing/mp_helpers.py +15 -0
  348. torchrl/testing/ray_helpers.py +293 -0
  349. torchrl/testing/utils.py +190 -0
  350. torchrl/trainers/__init__.py +42 -0
  351. torchrl/trainers/algorithms/__init__.py +11 -0
  352. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  353. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  354. torchrl/trainers/algorithms/configs/common.py +41 -0
  355. torchrl/trainers/algorithms/configs/data.py +308 -0
  356. torchrl/trainers/algorithms/configs/envs.py +104 -0
  357. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  358. torchrl/trainers/algorithms/configs/logging.py +80 -0
  359. torchrl/trainers/algorithms/configs/modules.py +570 -0
  360. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  361. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  362. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  363. torchrl/trainers/algorithms/configs/utils.py +252 -0
  364. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  365. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  366. torchrl/trainers/algorithms/ppo.py +373 -0
  367. torchrl/trainers/algorithms/sac.py +308 -0
  368. torchrl/trainers/helpers/__init__.py +40 -0
  369. torchrl/trainers/helpers/collectors.py +416 -0
  370. torchrl/trainers/helpers/envs.py +573 -0
  371. torchrl/trainers/helpers/logger.py +33 -0
  372. torchrl/trainers/helpers/losses.py +132 -0
  373. torchrl/trainers/helpers/models.py +658 -0
  374. torchrl/trainers/helpers/replay_buffer.py +59 -0
  375. torchrl/trainers/helpers/trainers.py +301 -0
  376. torchrl/trainers/trainers.py +2052 -0
  377. torchrl/weight_update/__init__.py +33 -0
  378. torchrl/weight_update/_distributed.py +749 -0
  379. torchrl/weight_update/_mp.py +624 -0
  380. torchrl/weight_update/_noupdate.py +102 -0
  381. torchrl/weight_update/_ray.py +1032 -0
  382. torchrl/weight_update/_rpc.py +284 -0
  383. torchrl/weight_update/_shared.py +891 -0
  384. torchrl/weight_update/llm/__init__.py +32 -0
  385. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  386. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  387. torchrl/weight_update/utils.py +73 -0
  388. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  389. torchrl-0.11.0.dist-info/LICENSE +21 -0
  390. torchrl-0.11.0.dist-info/METADATA +1307 -0
  391. torchrl-0.11.0.dist-info/RECORD +394 -0
  392. torchrl-0.11.0.dist-info/WHEEL +5 -0
  393. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  394. torchrl-0.11.0.dist-info/top_level.txt +7 -0
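The bulk of this diff is the new ReplayBuffer implementation in torchrl/data/replay_buffers/replay_buffers.py (entry 163 above). As orientation for the code that follows, here is a minimal usage sketch (not part of the package contents) of the delayed-initialization path described in the class docstring: with delayed_init=True, or implicitly whenever transform_factory is supplied, the storage, sampler, writer and transform are only built on first use. The sketch assumes only the constructor arguments and the ReplayBuffer.initialized property shown in the diff below.

    from torchrl.data import ListStorage, ReplayBuffer

    rb = ReplayBuffer(
        storage=lambda: ListStorage(max_size=1_000),  # callables act as constructors
        batch_size=5,
        delayed_init=True,  # defer building storage/sampler/writer until first use
    )
    assert not rb.initialized        # nothing has been built yet
    indices = rb.extend(range(10))   # first use triggers initialization
    assert rb.initialized
    sample = rb.sample()             # 5 elements, as set by batch_size

This mirrors the ListStorage example in the docstring; per that docstring, deferring initialization is mainly useful when the buffer must be pickled and sent to remote workers before it is used.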
torchrl/data/replay_buffers/replay_buffers.py
@@ -0,0 +1,2376 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ from __future__ import annotations
6
+
7
+ import collections
8
+ import contextlib
9
+ import json
10
+ import multiprocessing
11
+ import textwrap
12
+ import threading
13
+ import warnings
14
+ from collections.abc import Callable, Sequence
15
+ from concurrent.futures import ThreadPoolExecutor
16
+ from pathlib import Path
17
+ from typing import Any
18
+
19
+ import numpy as np
20
+ import torch
21
+
22
+ try:
23
+ from torch.compiler import is_compiling
24
+ except ImportError:
25
+ from torch._dynamo import is_compiling
26
+
27
+ from functools import partial, wraps
28
+ from typing import TYPE_CHECKING, TypeVar
29
+
30
+ from tensordict import (
31
+ is_tensor_collection,
32
+ is_tensorclass,
33
+ LazyStackedTensorDict,
34
+ NestedKey,
35
+ TensorDict,
36
+ TensorDictBase,
37
+ unravel_key,
38
+ )
39
+ from tensordict.nn.utils import _set_dispatch_td_nn_modules
40
+ from tensordict.utils import expand_as_right, expand_right
41
+ from torch import Tensor
42
+ from torch.utils._pytree import tree_map
43
+
44
+ from torchrl._utils import accept_remote_rref_udf_invocation, rl_warnings
45
+ from torchrl.data.replay_buffers.samplers import (
46
+ PrioritizedSampler,
47
+ RandomSampler,
48
+ Sampler,
49
+ SamplerEnsemble,
50
+ )
51
+ from torchrl.data.replay_buffers.storages import (
52
+ _get_default_collate,
53
+ _stack_anything,
54
+ ListStorage,
55
+ Storage,
56
+ StorageEnsemble,
57
+ )
58
+ from torchrl.data.replay_buffers.utils import (
59
+ _is_int,
60
+ _reduce,
61
+ _to_numpy,
62
+ _to_torch,
63
+ INT_CLASSES,
64
+ pin_memory_output,
65
+ )
66
+ from torchrl.data.replay_buffers.writers import (
67
+ RoundRobinWriter,
68
+ TensorDictRoundRobinWriter,
69
+ Writer,
70
+ WriterEnsemble,
71
+ )
72
+ from torchrl.data.utils import DEVICE_TYPING
73
+ from torchrl.envs.transforms.transforms import _InvertTransform, Transform
74
+
75
+ T = TypeVar("T")
76
+ if TYPE_CHECKING:
77
+ from typing import Self
78
+ else:
79
+ Self = T
80
+
81
+
82
+ def _maybe_delay_init(func):
83
+ @wraps(func)
84
+ def wrapper(self, *args, **kwargs):
85
+ if self._delayed_init and not self.initialized:
86
+ self._init()
87
+ return func(self, *args, **kwargs)
88
+
89
+ return wrapper
90
+
91
+
92
+ class ReplayBuffer:
93
+ """A generic, composable replay buffer class.
94
+
95
+ Keyword Args:
96
+ storage (Storage, Callable[[], Storage], optional): the storage to be used.
97
+ If a callable is passed, it is used as constructor for the storage.
98
+ If none is provided a default :class:`~torchrl.data.replay_buffers.ListStorage` with
99
+ ``max_size`` of ``1_000`` will be created.
100
+ sampler (Sampler, Callable[[], Sampler], optional): the sampler to be used.
101
+ If a callable is passed, it is used as constructor for the sampler.
102
+ If none is provided, a default :class:`~torchrl.data.replay_buffers.RandomSampler`
103
+ will be used.
104
+ writer (Writer, Callable[[], Writer], optional): the writer to be used.
105
+ If a callable is passed, it is used as constructor for the writer.
106
+ If none is provided a default :class:`~torchrl.data.replay_buffers.RoundRobinWriter`
107
+ will be used.
108
+ collate_fn (callable, optional): merges a list of samples to form a
109
+ mini-batch of Tensor(s)/outputs. Used when using batched
110
+ loading from a map-style dataset. The default value will be decided
111
+ based on the storage type.
112
+ pin_memory (bool): whether pin_memory() should be called on the rb
113
+ samples.
114
+ prefetch (int, optional): number of next batches to be prefetched
115
+ using multithreading. Defaults to None (no prefetching).
116
+ transform (Transform or Callable[[Any], Any], optional): Transform to be executed when
117
+ :meth:`sample` is called.
118
+ To chain transforms use the :class:`~torchrl.envs.Compose` class.
119
+ Transforms should be used with :class:`tensordict.TensorDict`
120
+ content. A generic callable can also be passed if the replay buffer
121
+ is used with PyTree structures (see example below).
122
+ Unlike storages, writers and samplers, transform constructors must
123
+ be passed as separate keyword argument :attr:`transform_factory`,
124
+ as it is impossible to distinguish a constructor from a transform.
125
+ transform_factory (Callable[[], Callable], optional): a factory for the
126
+ transform. Exclusive with :attr:`transform`.
127
+ batch_size (int, optional): the batch size to be used when sample() is
128
+ called.
129
+
130
+ .. note::
131
+ The batch-size can be specified at construction time via the
132
+ ``batch_size`` argument, or at sampling time. The former should
133
+ be preferred whenever the batch-size is consistent across the
134
+ experiment. If the batch-size is likely to change, it can be
135
+ passed to the :meth:`sample` method. This option is
136
+ incompatible with prefetching (since this requires to know the
137
+ batch-size in advance) as well as with samplers that have a
138
+ ``drop_last`` argument.
139
+
140
+ dim_extend (int, optional): indicates the dim to consider for
141
+ extension when calling :meth:`extend`. Defaults to ``storage.ndim-1``.
142
+ When using ``dim_extend > 0``, we recommend using the ``ndim``
143
+ argument in the storage instantiation if that argument is
144
+ available, to let storages know that the data is
145
+ multi-dimensional and keep consistent notions of storage-capacity
146
+ and batch-size during sampling.
147
+
148
+ .. note:: This argument has no effect on :meth:`add` and
149
+ therefore should be used with caution when both :meth:`add`
150
+ and :meth:`extend` are used in a codebase. For example:
151
+
152
+ >>> data = torch.zeros(3, 4)
153
+ >>> rb = ReplayBuffer(
154
+ ... storage=LazyTensorStorage(10, ndim=2),
155
+ ... dim_extend=1)
156
+ >>> # these two approaches are equivalent:
157
+ >>> for d in data.unbind(1):
158
+ ... rb.add(d)
159
+ >>> rb.extend(data)
160
+
161
+ generator (torch.Generator, optional): a generator to use for sampling.
162
+ Using a dedicated generator for the replay buffer can allow a fine-grained control
163
+ over seeding, for instance keeping the global seed different but the RB seed identical
164
+ for distributed jobs.
165
+ Defaults to ``None`` (global default generator).
166
+
167
+ .. warning:: As of now, the generator has no effect on the transforms.
168
+ shared (bool, optional): whether the buffer will be shared using multiprocessing or not.
169
+ Defaults to ``False``.
170
+ compilable (bool, optional): whether the writer is compilable.
171
+ If ``True``, the writer cannot be shared between multiple processes.
172
+ Defaults to ``False``.
173
+ delayed_init (bool, optional): whether to initialize storage, writer, sampler and transform
174
+ the first time the buffer is used rather than during construction.
175
+ This is useful when the replay buffer needs to be pickled and sent to remote workers,
176
+ particularly when using transforms with modules that require gradients.
177
+ If not specified, defaults to ``True`` when ``transform_factory`` is provided,
178
+ and ``False`` otherwise.
179
+
180
+ Examples:
181
+ >>> import torch
182
+ >>>
183
+ >>> from torchrl.data import ReplayBuffer, ListStorage
184
+ >>>
185
+ >>> torch.manual_seed(0)
186
+ >>> rb = ReplayBuffer(
187
+ ... storage=ListStorage(max_size=1000),
188
+ ... batch_size=5,
189
+ ... )
190
+ >>> # populate the replay buffer and get the item indices
191
+ >>> data = range(10)
192
+ >>> indices = rb.extend(data)
193
+ >>> # sample will return as many elements as specified in the constructor
194
+ >>> sample = rb.sample()
195
+ >>> print(sample)
196
+ tensor([4, 9, 3, 0, 3])
197
+ >>> # Passing the batch-size to the sample method overrides the one in the constructor
198
+ >>> sample = rb.sample(batch_size=3)
199
+ >>> print(sample)
200
+ tensor([9, 7, 3])
201
+ >>> # one can sample using the ``sample`` method or iterate over the buffer
202
+ >>> for i, batch in enumerate(rb):
203
+ ... print(i, batch)
204
+ ... if i == 3:
205
+ ... break
206
+ 0 tensor([7, 3, 1, 6, 6])
207
+ 1 tensor([9, 8, 6, 6, 8])
208
+ 2 tensor([4, 3, 6, 9, 1])
209
+ 3 tensor([4, 4, 1, 9, 9])
210
+
211
+ Replay buffers accept *any* kind of data. Not all storage types
212
+ will work, as some expect numerical data only, but the default
213
+ :class:`~torchrl.data.ListStorage` will:
214
+
215
+ Examples:
216
+ >>> torch.manual_seed(0)
217
+ >>> buffer = ReplayBuffer(storage=ListStorage(100), collate_fn=lambda x: x)
218
+ >>> indices = buffer.extend(["a", 1, None])
219
+ >>> buffer.sample(3)
220
+ [None, 'a', None]
221
+
222
+ The :class:`~torchrl.data.replay_buffers.TensorStorage`, :class:`~torchrl.data.replay_buffers.LazyMemmapStorage`
223
+ and :class:`~torchrl.data.replay_buffers.LazyTensorStorage` also work
224
+ with any PyTree structure (a PyTree is a nested structure of arbitrary depth made of dicts,
225
+ lists or tuples where the leaves are tensors) provided that it only contains
226
+ tensor data.
227
+
228
+ Examples:
229
+ >>> from torch.utils._pytree import tree_map
230
+ >>> def transform(x):
231
+ ... # Zeros all the data in the pytree
232
+ ... return tree_map(lambda y: y * 0, x)
233
+ >>> rb = ReplayBuffer(storage=LazyMemmapStorage(100), transform=transform)
234
+ >>> data = {
235
+ ... "a": torch.randn(3),
236
+ ... "b": {"c": (torch.zeros(2), [torch.ones(1)])},
237
+ ... 30: -torch.ones(()),
238
+ ... }
239
+ >>> rb.add(data)
240
+ >>> # The sample has a similar structure to the data (with a leading dimension of 10 for each tensor)
241
+ >>> s = rb.sample(10)
242
+ >>> # let's check that our transform did its job:
243
+ >>> def assert0(x):
244
+ ... assert (x == 0).all()
245
+ >>> tree_map(assert0, s)
246
+
247
+ """
248
+
249
+ def __init__(
250
+ self,
251
+ *,
252
+ storage: Storage | Callable[[], Storage] | None = None,
253
+ sampler: Sampler | Callable[[], Sampler] | None = None,
254
+ writer: Writer | Callable[[], Writer] | None = None,
255
+ collate_fn: Callable | None = None,
256
+ pin_memory: bool = False,
257
+ prefetch: int | None = None,
258
+ transform: Transform | Callable | None = None, # noqa-F821
259
+ transform_factory: Callable[[], Transform | Callable]
260
+ | None = None, # noqa-F821
261
+ batch_size: int | None = None,
262
+ dim_extend: int | None = None,
263
+ checkpointer: StorageCheckpointerBase # noqa: F821
264
+ | Callable[[], StorageCheckpointerBase] # noqa: F821
265
+ | None = None, # noqa: F821
266
+ generator: torch.Generator | None = None,
267
+ shared: bool = False,
268
+ compilable: bool | None = None,
269
+ delayed_init: bool | None = None,
270
+ ) -> None:
271
+ self._delayed_init = delayed_init
272
+ self._initialized = False
273
+
274
+ # Store init parameters for potential delayed initialization
275
+ self._init_storage = storage
276
+ self._init_sampler = sampler
277
+ self._init_writer = writer
278
+ self._init_collate_fn = collate_fn
279
+ self._init_transform = transform
280
+ self._init_transform_factory = transform_factory
281
+ self._init_checkpointer = checkpointer
282
+ self._init_generator = generator
283
+ self._init_compilable = compilable
284
+
285
+ if transform is not None and transform_factory is not None:
286
+ raise TypeError(
287
+ f"transform and transform_factory are mutually exclusive. "
288
+ f"Got transform={transform} and transform_factory={transform_factory}."
289
+ )
290
+
291
+ # Auto-detect delayed_init when transform_factory is provided
292
+ if transform_factory is not None and delayed_init is None:
293
+ delayed_init = True
294
+ elif delayed_init is None:
295
+ delayed_init = False
296
+
297
+ # Update _delayed_init after auto-detection
298
+ self._delayed_init = delayed_init
299
+
300
+ self._pin_memory = pin_memory
301
+ self._prefetch = bool(prefetch)
302
+ self._prefetch_cap = prefetch or 0
303
+ self._prefetch_queue = collections.deque()
304
+ self._batch_size = batch_size
305
+
306
+ if batch_size is None and prefetch:
307
+ raise ValueError(
308
+ "Dynamic batch-size specification is incompatible "
309
+ "with multithreaded sampling. "
310
+ "When using prefetch, the batch-size must be specified in "
311
+ "advance. "
312
+ )
313
+
314
+ if dim_extend is not None and dim_extend < 0:
315
+ raise ValueError("dim_extend must be a positive value.")
316
+ self._dim_extend = dim_extend
317
+
318
+ if self._prefetch_cap:
319
+ self._prefetch_executor = ThreadPoolExecutor(max_workers=self._prefetch_cap)
320
+
321
+ if shared and prefetch:
322
+ raise ValueError("Cannot share prefetched replay buffers.")
323
+ self.shared = shared
324
+ self.share(self.shared)
325
+
326
+ self._replay_lock = threading.RLock()
327
+ self._futures_lock = threading.RLock()
328
+
329
+ # If not delayed, initialize immediately
330
+ if not self._delayed_init:
331
+ self._init()
332
+
333
+ def _init(self) -> None:
334
+ """Initialize the replay buffer components.
335
+
336
+ This method is called either immediately during __init__ (if delayed_init=False)
337
+ or on first use of the buffer (if delayed_init=True).
338
+ """
339
+ if self._initialized:
340
+ return
341
+
342
+ self._initialized = True
343
+ try:
344
+ # Initialize storage
345
+ self._storage = self._maybe_make_storage(
346
+ self._init_storage, compilable=self._init_compilable
347
+ )
348
+ self._storage.attach(self)
349
+
350
+ # Initialize sampler
351
+ self._sampler = self._maybe_make_sampler(self._init_sampler)
352
+
353
+ # Initialize writer
354
+ self._writer = self._maybe_make_writer(self._init_writer)
355
+ self._writer.register_storage(self._storage)
356
+
357
+ # Initialize collate function
358
+ self._get_collate_fn(self._init_collate_fn)
359
+
360
+ # Initialize transform
361
+ self._transform = self._maybe_make_transform(
362
+ self._init_transform, self._init_transform_factory
363
+ )
364
+
365
+ # Check batch_size compatibility with sampler
366
+ if (
367
+ self._batch_size is None
368
+ and hasattr(self._sampler, "drop_last")
369
+ and self._sampler.drop_last
370
+ ):
371
+ raise ValueError(
372
+ "Samplers with drop_last=True must work with a predictable batch-size. "
373
+ "Please pass the batch-size to the ReplayBuffer constructor."
374
+ )
375
+
376
+ # Set dim_extend properly now that storage is initialized
377
+ if self._dim_extend is None:
378
+ if self._storage is not None:
379
+ ndim = self._storage.ndim
380
+ self._dim_extend = ndim - 1
381
+ else:
382
+ self._dim_extend = 1
383
+
384
+ # Set checkpointer and generator
385
+ self._storage.checkpointer = self._init_checkpointer
386
+ self.set_rng(generator=self._init_generator)
387
+
388
+ # Initialize prioritized sampler if needed
389
+ self._initialize_prioritized_sampler()
390
+
391
+ # Remove init parameters
392
+ self._init_storage = None
393
+ self._init_sampler = None
394
+ self._init_writer = None
395
+ self._init_collate_fn = None
396
+ self._init_transform = None
397
+ self._init_transform_factory = None
398
+ self._init_checkpointer = None
399
+ self._init_generator = None
400
+ self._init_compilable = None
401
+ except Exception as e:
402
+ self._initialized = False
403
+ raise e
404
+
405
+ @property
406
+ def initialized(self) -> bool:
407
+ """Whether the replay buffer has been initialized."""
408
+ return self._initialized
409
+
410
+ def _initialize_prioritized_sampler(self) -> None:
411
+ """Initialize priority trees for existing data when using PrioritizedSampler.
412
+
413
+ This method ensures that when a PrioritizedSampler is used with storage that
414
+ already contains data, the priority trees are properly populated with default
415
+ priorities for all existing entries.
416
+ """
417
+ from .samplers import PrioritizedSampler
418
+
419
+ if isinstance(self._sampler, PrioritizedSampler) and len(self._storage) > 0:
420
+ # Set default priorities for all existing data
421
+ indices = torch.arange(len(self._storage), dtype=torch.long)
422
+ default_priorities = torch.full(
423
+ (len(self._storage),), self._sampler.default_priority, dtype=torch.float
424
+ )
425
+ self._sampler.update_priority(indices, default_priorities)
426
+
427
+ def _maybe_make_storage(
428
+ self, storage: Storage | Callable[[], Storage] | None, compilable
429
+ ) -> Storage:
430
+ if storage is None:
431
+ return ListStorage(max_size=1_000, compilable=compilable)
432
+ elif isinstance(storage, Storage):
433
+ return storage
434
+ elif callable(storage):
435
+ storage = storage()
436
+ if not isinstance(storage, Storage):
437
+ raise TypeError(
438
+ "storage must be either a Storage or a callable returning a storage instance."
439
+ )
440
+ return storage
441
+
442
+ def _maybe_make_sampler(
443
+ self, sampler: Sampler | Callable[[], Sampler] | None
444
+ ) -> Sampler:
445
+ if sampler is None:
446
+ return RandomSampler()
447
+ elif isinstance(sampler, Sampler):
448
+ return sampler
449
+ elif callable(sampler):
450
+ sampler = sampler()
451
+ if not isinstance(sampler, Sampler):
452
+ raise TypeError(
453
+ "sampler must be either a Sampler or a callable returning a sampler instance."
454
+ )
455
+ return sampler
456
+
457
+ def _maybe_make_writer(
458
+ self, writer: Writer | Callable[[], Writer] | None
459
+ ) -> Writer:
460
+ if writer is None:
461
+ return RoundRobinWriter()
462
+ elif isinstance(writer, Writer):
463
+ return writer
464
+ elif callable(writer):
465
+ writer = writer()
466
+ if not isinstance(writer, Writer):
467
+ raise TypeError(
468
+ "writer must be either a Writer or a callable returning a writer instance."
469
+ )
470
+ return writer
471
+
472
+ def _maybe_make_transform(
473
+ self,
474
+ transform: Transform | Callable[[], Transform] | None,
475
+ transform_factory: Callable | None,
476
+ ) -> Transform:
477
+ from torchrl.envs.transforms.transforms import (
478
+ _CallableTransform,
479
+ Compose,
480
+ Transform,
481
+ )
482
+
483
+ if transform_factory is not None:
484
+ if transform is not None:
485
+ raise TypeError(
486
+ "transform and transform_factory cannot be used simultaneously"
487
+ )
488
+ transform = transform_factory()
489
+ if transform is None:
490
+ transform = Compose()
491
+ elif not isinstance(transform, Compose):
492
+ if not isinstance(transform, Transform) and callable(transform):
493
+ transform = _CallableTransform(transform)
494
+ elif not isinstance(transform, Transform):
495
+ raise RuntimeError(
496
+ "transform must be either a Transform instance or a callable."
497
+ )
498
+ transform = Compose(transform)
499
+ transform.eval()
500
+ return transform
501
+
502
+ def share(self, shared: bool = True) -> Self:
503
+ self.shared = shared
504
+ if self.shared:
505
+ self._write_lock = multiprocessing.Lock()
506
+ else:
507
+ self._write_lock = contextlib.nullcontext()
508
+ return self
509
+
510
+ @_maybe_delay_init
511
+ def set_rng(self, generator) -> None:
512
+ self._rng = generator
513
+ self._storage._rng = generator
514
+ self._sampler._rng = generator
515
+ self._writer._rng = generator
516
+
517
+ @property
518
+ def dim_extend(self):
519
+ return self._dim_extend
520
+
521
+ @property
522
+ def batch_size(self):
523
+ """The batch size of the replay buffer.
524
+
525
+ The batch size can be overridden by setting the `batch_size` parameter in the :meth:`sample` method.
526
+
527
+ It defines both the number of samples returned by :meth:`sample` and the number of samples that are
528
+ yielded by the :class:`ReplayBuffer` iterator.
529
+ """
530
+ return self._batch_size
531
+
532
+ @dim_extend.setter
533
+ def dim_extend(self, value):
534
+ if (
535
+ hasattr(self, "_dim_extend")
536
+ and self._dim_extend is not None
537
+ and self._dim_extend != value
538
+ ):
539
+ raise RuntimeError(
540
+ "dim_extend cannot be reset. Please create a new replay buffer."
541
+ )
542
+
543
+ if value is None:
544
+ if self._initialized and self._storage is not None:
545
+ ndim = self._storage.ndim
546
+ value = ndim - 1
547
+ else:
548
+ value = 1
549
+
550
+ self._dim_extend = value
551
+
552
+ def _transpose(self, data):
553
+ if is_tensor_collection(data):
554
+ return data.transpose(self.dim_extend, 0)
555
+ return tree_map(lambda x: x.transpose(self.dim_extend, 0), data)
556
+
557
+ def _get_collate_fn(self, collate_fn):
558
+ self._collate_fn = (
559
+ collate_fn
560
+ if collate_fn is not None
561
+ else _get_default_collate(
562
+ self._storage, _is_tensordict=isinstance(self, TensorDictReplayBuffer)
563
+ )
564
+ )
565
+
566
+ @_maybe_delay_init
567
+ def set_storage(self, storage: Storage, collate_fn: Callable | None = None):
568
+ """Sets a new storage in the replay buffer and returns the previous storage.
569
+
570
+ Args:
571
+ storage (Storage): the new storage for the buffer.
572
+ collate_fn (callable, optional): if provided, the collate_fn is set to this
573
+ value. Otherwise it is reset to a default value.
574
+
575
+ """
576
+ prev_storage = self._storage
577
+ self._storage = storage
578
+ self._get_collate_fn(collate_fn)
579
+
580
+ return prev_storage
581
+
582
+ @_maybe_delay_init
583
+ def set_writer(self, writer: Writer):
584
+ """Sets a new writer in the replay buffer and returns the previous writer."""
585
+ prev_writer = self._writer
586
+ self._writer = writer
587
+ self._writer.register_storage(self._storage)
588
+ return prev_writer
589
+
590
+ @_maybe_delay_init
591
+ def set_sampler(self, sampler: Sampler):
592
+ """Sets a new sampler in the replay buffer and returns the previous sampler."""
593
+ prev_sampler = self._sampler
594
+ self._sampler = sampler
595
+ return prev_sampler
596
+
597
+ @_maybe_delay_init
598
+ def __len__(self) -> int:
599
+ with self._replay_lock:
600
+ return len(self._storage)
601
+
602
+ def _getattr(self, attr):
603
+ # To access properties in remote settings, see RayReplayBuffer.write_count for instance
604
+ return getattr(self, attr)
605
+
606
+ def _setattr(self, attr, value):
607
+ # To set properties in remote settings
608
+ setattr(self, attr, value)
609
+ return None # explicit return for remote calls
610
+
611
+ @property
612
+ @_maybe_delay_init
613
+ def write_count(self) -> int:
614
+ """The total number of items written so far in the buffer through add and extend."""
615
+ return self._writer._write_count
616
+
617
+ def __repr__(self) -> str:
618
+ from torchrl.envs.transforms import Compose
619
+
620
+ storage = textwrap.indent(f"storage={getattr(self, '_storage', None)}", " " * 4)
621
+ writer = textwrap.indent(f"writer={getattr(self, '_writer', None)}", " " * 4)
622
+ sampler = textwrap.indent(f"sampler={getattr(self, '_sampler', None)}", " " * 4)
623
+ if getattr(self, "_transform", None) is not None and not (
624
+ isinstance(self._transform, Compose)
625
+ and not len(getattr(self, "_transform", None))
626
+ ):
627
+ transform = textwrap.indent(
628
+ f"transform={getattr(self, '_transform', None)}", " " * 4
629
+ )
630
+ transform = f"\n{self._transform}, "
631
+ else:
632
+ transform = ""
633
+ batch_size = textwrap.indent(
634
+ f"batch_size={getattr(self, '_batch_size', None)}", " " * 4
635
+ )
636
+ collate_fn = textwrap.indent(
637
+ f"collate_fn={getattr(self, '_collate_fn', None)}", " " * 4
638
+ )
639
+ return f"{self.__class__.__name__}(\n{storage}, \n{sampler}, \n{writer}, {transform}\n{batch_size}, \n{collate_fn})"
640
+
641
+ @_maybe_delay_init
642
+ @pin_memory_output
643
+ def __getitem__(self, index: int | torch.Tensor | NestedKey) -> Any:
644
+ if isinstance(index, str) or (isinstance(index, tuple) and unravel_key(index)):
645
+ return self[:][index]
646
+ if isinstance(index, tuple):
647
+ if len(index) == 1:
648
+ return self[index[0]]
649
+ else:
650
+ return self[:][index]
651
+ index = _to_numpy(index)
652
+
653
+ if self.dim_extend > 0:
654
+ index = (slice(None),) * self.dim_extend + (index,)
655
+ with self._replay_lock:
656
+ data = self._storage[index]
657
+ data = self._transpose(data)
658
+ else:
659
+ with self._replay_lock:
660
+ data = self._storage[index]
661
+
662
+ if not isinstance(index, INT_CLASSES):
663
+ data = self._collate_fn(data)
664
+
665
+ if self._transform is not None and len(self._transform):
666
+ with data.unlock_() if is_tensor_collection(
667
+ data
668
+ ) else contextlib.nullcontext():
669
+ data = self._transform(data)
670
+
671
+ return data
672
+
673
+ @_maybe_delay_init
674
+ def __setitem__(self, index, value) -> None:
675
+ if isinstance(index, str) or (isinstance(index, tuple) and unravel_key(index)):
676
+ self[:][index] = value
677
+ return
678
+ if isinstance(index, tuple):
679
+ if len(index) == 1:
680
+ self[index[0]] = value
681
+ else:
682
+ self[:][index] = value
683
+ return
684
+ index = _to_numpy(index)
685
+
686
+ if self._transform is not None and len(self._transform):
687
+ value = self._transform.inv(value)
688
+
689
+ if self.dim_extend > 0:
690
+ index = (slice(None),) * self.dim_extend + (index,)
691
+ with self._replay_lock:
692
+ self._storage[index] = self._transpose(value)
693
+ else:
694
+ with self._replay_lock:
695
+ self._storage[index] = value
696
+ return
697
+
698
+ @_maybe_delay_init
699
+ def state_dict(self) -> dict[str, Any]:
700
+ return {
701
+ "_storage": self._storage.state_dict(),
702
+ "_sampler": self._sampler.state_dict(),
703
+ "_writer": self._writer.state_dict(),
704
+ "_transforms": self._transform.state_dict(),
705
+ "_batch_size": self._batch_size,
706
+ "_rng": (self._rng.get_state().clone(), str(self._rng.device))
707
+ if self._rng is not None
708
+ else None,
709
+ }
710
+
711
+ @_maybe_delay_init
712
+ def load_state_dict(self, state_dict: dict[str, Any]) -> None:
713
+ self._storage.load_state_dict(state_dict["_storage"])
714
+ self._sampler.load_state_dict(state_dict["_sampler"])
715
+ self._writer.load_state_dict(state_dict["_writer"])
716
+ self._transform.load_state_dict(state_dict["_transforms"])
717
+ self._batch_size = state_dict["_batch_size"]
718
+ rng = state_dict.get("_rng")
719
+ if rng is not None:
720
+ state, device = rng
721
+ rng = torch.Generator(device=device)
722
+ rng.set_state(state)
723
+ self.set_rng(generator=rng)
724
+
725
+ @_maybe_delay_init
726
+ def dumps(self, path):
727
+ """Saves the replay buffer on disk at the specified path.
728
+
729
+ Args:
730
+ path (Path or str): path where to save the replay buffer.
731
+
732
+ Examples:
733
+ >>> import tempfile
734
+ >>> import tqdm
735
+ >>> from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer
736
+ >>> from torchrl.data.replay_buffers.samplers import PrioritizedSampler, RandomSampler
737
+ >>> import torch
738
+ >>> from tensordict import TensorDict
739
+ >>> # Build and populate the replay buffer
740
+ >>> S = 1_000_000
741
+ >>> sampler = PrioritizedSampler(S, 1.1, 1.0)
742
+ >>> # sampler = RandomSampler()
743
+ >>> storage = LazyMemmapStorage(S)
744
+ >>> rb = TensorDictReplayBuffer(storage=storage, sampler=sampler)
745
+ >>>
746
+ >>> for _ in tqdm.tqdm(range(100)):
747
+ ... td = TensorDict({"obs": torch.randn(100, 3, 4), "next": {"obs": torch.randn(100, 3, 4)}, "td_error": torch.rand(100)}, [100])
748
+ ... rb.extend(td)
749
+ ... sample = rb.sample(32)
750
+ ... rb.update_tensordict_priority(sample)
751
+ >>> # save and load the buffer
752
+ >>> with tempfile.TemporaryDirectory() as tmpdir:
753
+ ... rb.dumps(tmpdir)
754
+ ...
755
+ ... sampler = PrioritizedSampler(S, 1.1, 1.0)
756
+ ... # sampler = RandomSampler()
757
+ ... storage = LazyMemmapStorage(S)
758
+ ... rb_load = TensorDictReplayBuffer(storage=storage, sampler=sampler)
759
+ ... rb_load.loads(tmpdir)
760
+ ... assert len(rb) == len(rb_load)
761
+
762
+ """
763
+ path = Path(path).absolute()
764
+ path.mkdir(exist_ok=True)
765
+ self._storage.dumps(path / "storage")
766
+ self._sampler.dumps(path / "sampler")
767
+ self._writer.dumps(path / "writer")
768
+ if self._rng is not None:
769
+ rng_state = TensorDict(
770
+ rng_state=self._rng.get_state().clone(),
771
+ device=self._rng.device,
772
+ )
773
+ rng_state.memmap(path / "rng_state")
774
+
775
+ # fall back on state_dict for transforms
776
+ transform_sd = self._transform.state_dict()
777
+ if transform_sd:
778
+ torch.save(transform_sd, path / "transform.t")
779
+ with open(path / "buffer_metadata.json", "w") as file:
780
+ json.dump({"batch_size": self._batch_size}, file)
781
+
782
+ @_maybe_delay_init
783
+ def loads(self, path):
784
+ """Loads a replay buffer state at the given path.
785
+
786
+ The buffer should have matching components and be saved using :meth:`dumps`.
787
+
788
+ Args:
789
+ path (Path or str): path where the replay buffer was saved.
790
+
791
+ See :meth:`dumps` for more info.
792
+
793
+ """
794
+ path = Path(path).absolute()
795
+ self._storage.loads(path / "storage")
796
+ self._sampler.loads(path / "sampler")
797
+ self._writer.loads(path / "writer")
798
+ if (path / "rng_state").exists():
799
+ rng_state = TensorDict.load_memmap(path / "rng_state")
800
+ rng = torch.Generator(device=rng_state.device)
801
+ rng.set_state(rng_state["rng_state"])
802
+ self.set_rng(rng)
803
+ # fall back on state_dict for transforms
804
+ if (path / "transform.t").exists():
805
+ self._transform.load_state_dict(torch.load(path / "transform.t"))
806
+ with open(path / "buffer_metadata.json") as file:
807
+ metadata = json.load(file)
808
+ self._batch_size = metadata["batch_size"]
809
+
810
+ @_maybe_delay_init
811
+ def save(self, *args, **kwargs):
812
+ """Alias for :meth:`dumps`."""
813
+ return self.dumps(*args, **kwargs)
814
+
815
+ @_maybe_delay_init
816
+ def dump(self, *args, **kwargs):
817
+ """Alias for :meth:`dumps`."""
818
+ return self.dumps(*args, **kwargs)
819
+
820
+ @_maybe_delay_init
821
+ def load(self, *args, **kwargs):
822
+ """Alias for :meth:`loads`."""
823
+ return self.loads(*args, **kwargs)
824
+
825
+ @_maybe_delay_init
826
+ def register_save_hook(self, hook: Callable[[Any], Any]):
827
+ """Registers a save hook for the storage.
828
+
829
+ .. note:: Hooks are currently not serialized when saving a replay buffer: they must
830
+ be manually re-initialized every time the buffer is created.
831
+
832
+ """
833
+ self._storage.register_save_hook(hook)
834
+
835
+ @_maybe_delay_init
836
+ def register_load_hook(self, hook: Callable[[Any], Any]):
837
+ """Registers a load hook for the storage.
838
+
839
+ .. note:: Hooks are currently not serialized when saving a replay buffer: they must
840
+ be manually re-initialized every time the buffer is created.
841
+
842
+ """
843
+ self._storage.register_load_hook(hook)
844
+
845
+ @_maybe_delay_init
846
+ def add(self, data: Any) -> int:
847
+ """Add a single element to the replay buffer.
848
+
849
+ Args:
850
+ data (Any): data to be added to the replay buffer
851
+
852
+ Returns:
853
+ index where the data lives in the replay buffer.
854
+ """
855
+ if self._transform is not None and len(self._transform):
856
+ with _set_dispatch_td_nn_modules(is_tensor_collection(data)):
857
+ make_none = False
858
+ # Transforms usually expect a time batch dimension when called within a RB, so we unsqueeze the data temporarily
859
+ is_tc = is_tensor_collection(data)
860
+ cm = data.unsqueeze(-1) if is_tc else contextlib.nullcontext(data)
861
+ new_data = None
862
+ with cm as data_unsq:
863
+ data_unsq_r = self._transform.inv(data_unsq)
864
+ if is_tc and data_unsq_r is not None:
865
+ # this is a no-op whenever the result matches the input
866
+ new_data = data_unsq_r.squeeze(-1)
867
+ else:
868
+ make_none = data_unsq_r is None
869
+ data = new_data if new_data is not None else data
870
+ if make_none:
871
+ data = None
872
+ if data is None:
873
+ return torch.zeros((0, self._storage.ndim), dtype=torch.long)
874
+ if rl_warnings() and is_tensor_collection(data) and data.ndim:
875
+ warnings.warn(
876
+ f"Using `add()` with a TensorDict that has batch_size={data.batch_size}. "
877
+ f"Use `extend()` to add multiple elements, or `add()` with a single element (batch_size=torch.Size([])). "
878
+ "You can silence this warning by setting the `RL_WARNINGS` environment variable to `'0'`."
879
+ )
880
+
881
+ return self._add(data)
882
+
883
+ def _add(self, data):
884
+ with self._replay_lock, self._write_lock:
885
+ index = self._writer.add(data)
886
+ self._sampler.add(index)
887
+ return index
888
+
889
+ def _extend(self, data: Sequence, *, update_priority: bool = True) -> torch.Tensor:
890
+ is_comp = is_compiling()
891
+ nc = contextlib.nullcontext()
892
+ with self._replay_lock if not is_comp else nc, self._write_lock if not is_comp else nc:
893
+ if self.dim_extend > 0:
894
+ data = self._transpose(data)
895
+ index = self._writer.extend(data)
896
+ self._sampler.extend(index)
897
+ return index
898
+
899
+ @_maybe_delay_init
900
+ def extend(
901
+ self, data: Sequence, *, update_priority: bool | None = None
902
+ ) -> torch.Tensor:
903
+ """Extends the replay buffer with one or more elements contained in an iterable.
904
+
905
+ If present, the inverse transforms will be called.
906
+
907
+ Args:
908
+ data (iterable): collection of data to be added to the replay
909
+ buffer.
910
+
911
+ Keyword Args:
912
+ update_priority (bool, optional): Whether to update the priority of the data. Defaults to True.
913
+ Has no effect in this class. See :meth:`~torchrl.data.TensorDictReplayBuffer.extend` for more details.
914
+
915
+ Returns:
916
+ Indices of the data added to the replay buffer.
917
+
918
+ .. warning:: :meth:`~torchrl.data.replay_buffers.ReplayBuffer.extend` can have an
919
+ ambiguous signature when dealing with lists of values, which should be interpreted
920
+ either as PyTree (in which case all elements in the list will be put in a slice
921
+ in the stored PyTree in the storage) or a list of values to add one at a time.
922
+ To solve this, TorchRL makes the clear-cut distinction between list and tuple:
923
+ a tuple will be viewed as a PyTree, a list (at the root level) will be interpreted
924
+ as a stack of values to add one at a time to the buffer.
925
+ For :class:`~torchrl.data.replay_buffers.ListStorage` instances, only
926
+ unbound elements can be provided (no PyTrees).
927
+
928
+ """
929
+ if update_priority is not None:
930
+ raise NotImplementedError(
931
+ "update_priority is not supported in this class. See :meth:`~torchrl.data.TensorDictReplayBuffer.extend` for more details."
932
+ )
933
+ if self._transform is not None and len(self._transform):
934
+ with _set_dispatch_td_nn_modules(is_tensor_collection(data)):
935
+ data = self._transform.inv(data)
936
+ if data is None:
937
+ return torch.zeros((0, self._storage.ndim), dtype=torch.long)
938
+ return self._extend(data, update_priority=update_priority)
939
+
940
+ @_maybe_delay_init
941
+ def update_priority(
942
+ self,
943
+ index: int | torch.Tensor | tuple[torch.Tensor],
944
+ priority: int | torch.Tensor,
945
+ ) -> None:
946
+ if isinstance(index, tuple):
947
+ index = torch.stack(index, -1)
948
+ priority = torch.as_tensor(priority)
949
+ if self.dim_extend > 0 and priority.ndim > 1:
950
+ priority = self._transpose(priority).flatten()
951
+ # priority = priority.flatten()
952
+ with self._replay_lock, self._write_lock:
953
+ self._sampler.update_priority(index, priority, storage=self.storage)
954
+
955
+ @pin_memory_output
956
+ def _sample(self, batch_size: int) -> tuple[Any, dict]:
957
+ is_comp = is_compiling()
958
+ nc = contextlib.nullcontext()
959
+ with self._replay_lock if not is_comp else nc, self._write_lock if not is_comp else nc:
960
+ index, info = self._sampler.sample(self._storage, batch_size)
961
+ info["index"] = index
962
+ data = self._storage.get(index)
963
+ if not isinstance(index, INT_CLASSES):
964
+ data = self._collate_fn(data)
965
+ if self._transform is not None and len(self._transform):
966
+ is_td = is_tensor_collection(data)
967
+ with data.unlock_() if is_td else contextlib.nullcontext(), _set_dispatch_td_nn_modules(
968
+ is_td
969
+ ):
970
+ data = self._transform(data)
971
+
972
+ return data, info
973
+
974
+ @_maybe_delay_init
975
+ def empty(self, empty_write_count: bool = True):
976
+ """Empties the replay buffer and reset cursor to 0.
977
+
978
+ Args:
979
+ empty_write_count (bool, optional): Whether to empty the write_count attribute. Defaults to `True`.
980
+ """
981
+ self._writer._empty(empty_write_count=empty_write_count)
982
+ self._sampler._empty()
983
+ self._storage._empty()
984
+
985
+ @_maybe_delay_init
986
+ def sample(self, batch_size: int | None = None, return_info: bool = False) -> Any:
987
+ """Samples a batch of data from the replay buffer.
988
+
989
+ Uses Sampler to sample indices, and retrieves them from Storage.
990
+
991
+ Args:
992
+ batch_size (int, optional): size of data to be collected. If none
993
+ is provided, this method will sample a batch-size as indicated
994
+ by the sampler.
995
+ return_info (bool): whether to return info. If True, the result
996
+ is a tuple (data, info). If False, the result is the data.
997
+
998
+ Returns:
999
+ A batch of data selected in the replay buffer.
1000
+ A tuple containing this batch and info if return_info flag is set to True.
1001
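+
+ Example:
+ A minimal sketch of how the two batch-size arguments interact (sizes are illustrative):
+
+ >>> rb = ReplayBuffer(storage=LazyTensorStorage(100), batch_size=32)
+ >>> rb.extend(torch.randn(100, 4))
+ >>> batch = rb.sample()  # uses the constructor batch_size (32)
+ >>> batch = rb.sample(batch_size=8)  # overrides it; a warning is emitted about the conflict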
+ """
1002
+ if (
1003
+ batch_size is not None
1004
+ and self._batch_size is not None
1005
+ and batch_size != self._batch_size
1006
+ ):
1007
+ warnings.warn(
1008
+ f"Got conflicting batch_sizes in constructor ({self._batch_size}) "
1009
+ f"and `sample` ({batch_size}). Refer to the ReplayBuffer documentation "
1010
+ "for a proper usage of the batch-size arguments. "
1011
+ "The batch-size provided to the sample method "
1012
+ "will prevail."
1013
+ )
1014
+ elif batch_size is None and self._batch_size is not None:
1015
+ batch_size = self._batch_size
1016
+ elif batch_size is None:
1017
+ raise RuntimeError(
1018
+ "batch_size not specified. You can specify the batch_size when "
1019
+ "constructing the replay buffer, or pass it to the sample method. "
1020
+ "Refer to the ReplayBuffer documentation "
1021
+ "for a proper usage of the batch-size arguments."
1022
+ )
1023
+ if not self._prefetch:
1024
+ result = self._sample(batch_size)
1025
+ else:
1026
+ with self._futures_lock:
1027
+ while (
1028
+ len(self._prefetch_queue)
1029
+ < min(self._sampler._remaining_batches, self._prefetch_cap)
1030
+ and not self._sampler.ran_out
1031
+ ) or not len(self._prefetch_queue):
1032
+ fut = self._prefetch_executor.submit(self._sample, batch_size)
1033
+ self._prefetch_queue.append(fut)
1034
+ result = self._prefetch_queue.popleft().result()
1035
+
1036
+ if return_info:
1037
+ out, info = result
1038
+ if getattr(self.storage, "device", None) is not None:
1039
+ device = self.storage.device
1040
+ info = tree_map(lambda x: x.to(device) if hasattr(x, "to") else x, info)
1041
+ return out, info
1042
+ return result[0]
1043
+
1044
+ @_maybe_delay_init
1045
+ def mark_update(self, index: int | torch.Tensor) -> None:
1046
+ self._sampler.mark_update(index, storage=self._storage)
1047
+
1048
+ @_maybe_delay_init
1049
+ def append_transform(
1050
+ self, transform: Transform, *, invert: bool = False # noqa: F821
1051
+ ) -> ReplayBuffer: # noqa: D417
1052
+ """Appends transform at the end.
1053
+
1054
+ Transforms are applied in order when `sample` is called.
1055
+
1056
+ Args:
1057
+ transform (Transform): The transform to be appended
1058
+
1059
+ Keyword Args:
1060
+ invert (bool, optional): if ``True``, the transform will be inverted (forward calls will be called
1061
+ during writing and inverse calls during reading). Defaults to ``False``.
1062
+
1063
+ Example:
1064
+ >>> rb = ReplayBuffer(storage=LazyMemmapStorage(10), batch_size=4)
1065
+ >>> data = TensorDict({"a": torch.zeros(10)}, [10])
1066
+ >>> def t(data):
1067
+ ... data += 1
1068
+ ... return data
1069
+ >>> rb.append_transform(t, invert=True)
1070
+ >>> rb.extend(data)
1071
+ >>> assert (data == 1).all()
1072
+
1073
+ """
1074
+ from torchrl.envs.transforms.transforms import _CallableTransform, Transform
1075
+
1076
+ if not isinstance(transform, Transform) and callable(transform):
1077
+ transform = _CallableTransform(transform)
1078
+ if invert:
1079
+ transform = _InvertTransform(transform)
1080
+ transform.eval()
1081
+ self._transform.append(transform)
1082
+ return self
1083
+
1084
+ @_maybe_delay_init
1085
+ def insert_transform(
1086
+ self,
1087
+ index: int,
1088
+ transform: Transform, # noqa: F821
1089
+ *,
1090
+ invert: bool = False,
1091
+ ) -> ReplayBuffer: # noqa: D417
1092
+ """Inserts transform.
1093
+
1094
+ Transforms are executed in order when `sample` is called.
1095
+
1096
+ Args:
1097
+ index (int): Position to insert the transform.
1098
+ transform (Transform): The transform to be appended
1099
+
1100
+ Keyword Args:
1101
+ invert (bool, optional): if ``True``, the transform will be inverted (forward calls will be called
1102
+ during writing and inverse calls during reading). Defaults to ``False``.
1103
+
1104
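+ Example:
+ A minimal sketch, using :class:`~torchrl.envs.DoubleToFloat` as a stand-in transform
+ (key names are illustrative):
+
+ >>> from torchrl.envs import DoubleToFloat
+ >>> rb = ReplayBuffer(storage=LazyTensorStorage(10), batch_size=4)
+ >>> rb.append_transform(DoubleToFloat(in_keys=["x"]))
+ >>> rb.insert_transform(0, DoubleToFloat(in_keys=["y"]))  # will now run before the appended transform
+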
+ """
1105
+ transform.eval()
1106
+ if invert:
1107
+ transform = _InvertTransform(transform)
1108
+ self._transform.insert(index, transform)
1109
+ return self
1110
+
1111
+ _iterator = None
1112
+
1113
+ @_maybe_delay_init
1114
+ def next(self):
1115
+ """Returns the next item in the replay buffer.
1116
+
1117
+ This method is used to iterate over the replay buffer in contexts where __iter__ is not available,
1118
+ such as :class:`~torchrl.data.replay_buffers.RayReplayBuffer`.
1119
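+
+ Example:
+ A minimal sketch, assuming a sampler that can run out of data
+ (:class:`~torchrl.data.replay_buffers.SamplerWithoutReplacement`):
+
+ >>> from torchrl.data.replay_buffers import SamplerWithoutReplacement
+ >>> rb = ReplayBuffer(
+ ...     storage=LazyTensorStorage(8),
+ ...     sampler=SamplerWithoutReplacement(),
+ ...     batch_size=4)
+ >>> rb.extend(torch.randn(8, 2))
+ >>> first = rb.next()   # first half of the data
+ >>> second = rb.next()  # second half; once the sampler has run out, ``next`` returns ``None``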
+ """
1120
+ try:
1121
+ if self._iterator is None:
1122
+ self._iterator = iter(self)
1123
+ out = next(self._iterator)
1124
+ # if any, we don't want the device ref to be passed in distributed settings
1125
+ if out is not None and (out.device != "cpu"):
1126
+ out = out.copy().clear_device_()
1127
+ return out
1128
+ except StopIteration:
1129
+ self._iterator = None
1130
+ return None
1131
+
1132
+ @_maybe_delay_init
1133
+ def __iter__(self):
1134
+ if self._sampler.ran_out:
1135
+ self._sampler.ran_out = False
1136
+ if self._batch_size is None:
1137
+ raise RuntimeError(
1138
+ "Cannot iterate over the replay buffer. "
1139
+ "Batch_size was not specified during construction of the replay buffer."
1140
+ )
1141
+ while not self._sampler.ran_out or (
1142
+ self._prefetch and len(self._prefetch_queue)
1143
+ ):
1144
+ yield self.sample()
1145
+
1146
+ @_maybe_delay_init
1147
+ def __getstate__(self) -> dict[str, Any]:
1148
+ state = self.__dict__.copy()
1149
+ if getattr(self, "_rng", None) is not None:
1150
+ rng_state = TensorDict(
1151
+ rng_state=self._rng.get_state().clone(),
1152
+ device=self._rng.device,
1153
+ )
1154
+ state["_rng"] = rng_state
1155
+ _replay_lock = state.pop("_replay_lock", None)
1156
+ _futures_lock = state.pop("_futures_lock", None)
1157
+ if _replay_lock is not None:
1158
+ state["_replay_lock_placeholder"] = None
1159
+ if _futures_lock is not None:
1160
+ state["_futures_lock_placeholder"] = None
1161
+ # Remove non-picklable prefetch objects - they will be recreated on unpickle
1162
+ _prefetch_queue = state.pop("_prefetch_queue", None)
1163
+ _prefetch_executor = state.pop("_prefetch_executor", None)
1164
+ if _prefetch_queue is not None:
1165
+ state["_prefetch_queue_placeholder"] = None
1166
+ if _prefetch_executor is not None:
1167
+ state["_prefetch_executor_placeholder"] = None
1168
+ return state
1169
+
1170
+ def __setstate__(self, state: dict[str, Any]):
1171
+ rngstate = None
1172
+ if "_rng" in state:
1173
+ rngstate = state["_rng"]
1174
+ if rngstate is not None:
1175
+ rng = torch.Generator(device=rngstate.device)
1176
+ rng.set_state(rngstate["rng_state"])
1177
+
1178
+ if "_replay_lock_placeholder" in state:
1179
+ state.pop("_replay_lock_placeholder")
1180
+ _replay_lock = threading.RLock()
1181
+ state["_replay_lock"] = _replay_lock
1182
+ if "_futures_lock_placeholder" in state:
1183
+ state.pop("_futures_lock_placeholder")
1184
+ _futures_lock = threading.RLock()
1185
+ state["_futures_lock"] = _futures_lock
1186
+ # Recreate prefetch objects after unpickling if they were present
1187
+ if "_prefetch_queue_placeholder" in state:
1188
+ state.pop("_prefetch_queue_placeholder")
1189
+ state["_prefetch_queue"] = collections.deque()
1190
+ if "_prefetch_executor_placeholder" in state:
1191
+ state.pop("_prefetch_executor_placeholder")
1192
+ state["_prefetch_executor"] = ThreadPoolExecutor(
1193
+ max_workers=state["_prefetch_cap"]
1194
+ )
1195
+ self.__dict__.update(state)
1196
+ if rngstate is not None:
1197
+ self.set_rng(rng)
1198
+
1199
+ @property
1200
+ @_maybe_delay_init
1201
+ def sampler(self) -> Sampler:
1202
+ """The sampler of the replay buffer.
1203
+
1204
+ The sampler must be an instance of :class:`~torchrl.data.replay_buffers.Sampler`.
1205
+
1206
+ """
1207
+ return self._sampler
1208
+
1209
+ @property
1210
+ @_maybe_delay_init
1211
+ def writer(self) -> Writer:
1212
+ """The writer of the replay buffer.
1213
+
1214
+ The writer must be an instance of :class:`~torchrl.data.replay_buffers.Writer`.
1215
+
1216
+ """
1217
+ return self._writer
1218
+
1219
+ @property
1220
+ @_maybe_delay_init
1221
+ def storage(self) -> Storage:
1222
+ """The storage of the replay buffer.
1223
+
1224
+ The storage must be an instance of :class:`~torchrl.data.replay_buffers.Storage`.
1225
+
1226
+ """
1227
+ return self._storage
1228
+
1229
+ @property
1230
+ @_maybe_delay_init
1231
+ def transform(self) -> Transform:
1232
+ """The transform of the replay buffer.
1233
+
1234
+ The transform must be an instance of :class:`~torchrl.envs.transforms.Transform`.
1235
+ """
1236
+ return self._transform
1237
+
1238
+
1239
+ class PrioritizedReplayBuffer(ReplayBuffer):
1240
+ """Prioritized replay buffer.
1241
+
1242
+ All arguments are keyword-only arguments.
1243
+
1244
+ Presented in "Schaul, T.; Quan, J.; Antonoglou, I.; and Silver, D. 2015.
1245
+ Prioritized experience replay." (https://arxiv.org/abs/1511.05952)
1246
+
1247
+ Args:
1248
+ alpha (:obj:`float`): exponent α determines how much prioritization is used,
1249
+ with α = 0 corresponding to the uniform case.
1250
+ beta (:obj:`float`): importance sampling negative exponent.
1251
+ eps (:obj:`float`): delta added to the priorities to ensure that the buffer
1252
+ does not contain null priorities.
1253
+ storage (Storage, optional): the storage to be used. If none is provided
1254
+ a default :class:`~torchrl.data.replay_buffers.ListStorage` with
1255
+ ``max_size`` of ``1_000`` will be created.
1256
+ sampler (Sampler, optional): the sampler to be used. If none is provided,
1257
+ a default :class:`~torchrl.data.replay_buffers.PrioritizedSampler` with
1258
+ ``alpha``, ``beta``, and ``eps`` will be created.
1259
+ collate_fn (callable, optional): merges a list of samples to form a
1260
+ mini-batch of Tensor(s)/outputs. Used when using batched
1261
+ loading from a map-style dataset. The default value will be decided
1262
+ based on the storage type.
1263
+ pin_memory (bool): whether pin_memory() should be called on the rb
1264
+ samples.
1265
+ prefetch (int, optional): number of next batches to be prefetched
1266
+ using multithreading. Defaults to None (no prefetching).
1267
+ transform (Transform, optional): Transform to be executed when
1268
+ sample() is called.
1269
+ To chain transforms use the :class:`~torchrl.envs.Compose` class.
1270
+ Transforms should be used with :class:`tensordict.TensorDict`
1271
+ content. If used with other structures, the transforms should be
1272
+ encoded with a ``"data"`` leading key that will be used to
1273
+ construct a tensordict from the non-tensordict content.
1274
+ batch_size (int, optional): the batch size to be used when sample() is
1275
+ called.
1276
+
1277
+ .. note:: The batch-size can be specified at construction time via the
1278
+ ``batch_size`` argument, or at sampling time. The former should
1279
+ be preferred whenever the batch-size is consistent across the
1280
+ experiment. If the batch-size is likely to change, it can be
1281
+ passed to the :meth:`sample` method. This option is
1282
+ incompatible with prefetching (since this requires to know the
1283
+ batch-size in advance) as well as with samplers that have a
1284
+ ``drop_last`` argument.
1285
+
1286
+ dim_extend (int, optional): indicates the dim to consider for
1287
+ extension when calling :meth:`extend`. Defaults to ``storage.ndim-1``.
1288
+ When using ``dim_extend > 0``, we recommend using the ``ndim``
1289
+ argument in the storage instantiation if that argument is
1290
+ available, to let storages know that the data is
1291
+ multi-dimensional and keep consistent notions of storage-capacity
1292
+ and batch-size during sampling.
1293
+
1294
+ .. note:: This argument has no effect on :meth:`add` and
1295
+ therefore should be used with caution when both :meth:`add`
1296
+ and :meth:`extend` are used in a codebase. For example:
1297
+
1298
+ >>> data = torch.zeros(3, 4)
1299
+ >>> rb = ReplayBuffer(
1300
+ ... storage=LazyTensorStorage(10, ndim=2),
1301
+ ... dim_extend=1)
1302
+ >>> # these two approaches are equivalent:
1303
+ >>> for d in data.unbind(1):
1304
+ ... rb.add(d)
1305
+ >>> rb.extend(data)
1306
+
1307
+ delayed_init (bool, optional): whether to initialize storage, writer, sampler and transform
1308
+ the first time the buffer is used rather than during construction.
1309
+ This is useful when the replay buffer needs to be pickled and sent to remote workers,
1310
+ particularly when using transforms with modules that require gradients.
1311
+ If not specified, defaults to ``True`` when ``transform_factory`` is provided,
1312
+ and ``False`` otherwise.
1313
+
1314
+ .. note::
1315
+ Generic prioritized replay buffers (i.e., non-tensordict-backed) require
1316
+ calling :meth:`~.sample` with the ``return_info`` argument set to
1317
+ ``True`` to have access to the indices, and hence update the priority.
1318
+ Using :class:`tensordict.TensorDict` and the related
1319
+ :class:`~torchrl.data.TensorDictPrioritizedReplayBuffer` simplifies this
1320
+ process.
1321
+
1322
+ Examples:
1323
+ >>> import torch
1324
+ >>>
1325
+ >>> from torchrl.data import ListStorage, PrioritizedReplayBuffer
1326
+ >>>
1327
+ >>> torch.manual_seed(0)
1328
+ >>>
1329
+ >>> rb = PrioritizedReplayBuffer(alpha=0.7, beta=0.9, storage=ListStorage(10))
1330
+ >>> data = range(10)
1331
+ >>> rb.extend(data)
1332
+ >>> sample = rb.sample(3)
1333
+ >>> print(sample)
1334
+ tensor([1, 0, 1])
1335
+ >>> # get the info to find what the indices are
1336
+ >>> sample, info = rb.sample(5, return_info=True)
1337
+ >>> print(sample, info)
1338
+ tensor([2, 7, 4, 3, 5]) {'priority_weight': array([1., 1., 1., 1., 1.], dtype=float32), 'index': array([2, 7, 4, 3, 5])}
1339
+ >>> # update priority
1340
+ >>> priority = torch.ones(5) * 5
1341
+ >>> rb.update_priority(info["index"], priority)
1342
+ >>> # and now a new sample, the weights should be updated
1343
+ >>> sample, info = rb.sample(5, return_info=True)
1344
+ >>> print(sample, info)
1345
+ tensor([2, 5, 2, 2, 5]) {'priority_weight': array([0.36278465, 0.36278465, 0.36278465, 0.36278465, 0.36278465],
1346
+ dtype=float32), 'index': array([2, 5, 2, 2, 5])}
1347
+
1348
+ """
1349
+
1350
+ def __init__(
1351
+ self,
1352
+ *,
1353
+ alpha: float,
1354
+ beta: float,
1355
+ eps: float = 1e-8,
1356
+ dtype: torch.dtype = torch.float,
1357
+ storage: Storage | None = None,
1358
+ sampler: Sampler | None = None,
1359
+ collate_fn: Callable | None = None,
1360
+ pin_memory: bool = False,
1361
+ prefetch: int | None = None,
1362
+ transform: Transform | None = None, # noqa: F821
1363
+ batch_size: int | None = None,
1364
+ dim_extend: int | None = None,
1365
+ delayed_init: bool = False,
1366
+ ) -> None:
1367
+ if storage is None:
1368
+ storage = ListStorage(max_size=1_000)
1369
+ if sampler is None:
1370
+ sampler = PrioritizedSampler(storage.max_size, alpha, beta, eps, dtype)
1371
+ super().__init__(
1372
+ storage=storage,
1373
+ sampler=sampler,
1374
+ collate_fn=collate_fn,
1375
+ pin_memory=pin_memory,
1376
+ prefetch=prefetch,
1377
+ transform=transform,
1378
+ batch_size=batch_size,
1379
+ dim_extend=dim_extend,
1380
+ delayed_init=delayed_init,
1381
+ )
1382
+
1383
+
1384
+ class TensorDictReplayBuffer(ReplayBuffer):
1385
+ """TensorDict-specific wrapper around the :class:`~torchrl.data.ReplayBuffer` class.
1386
+
1387
+ Keyword Args:
1388
+ storage (Storage, Callable[[], Storage], optional): the storage to be used.
1389
+ If a callable is passed, it is used as constructor for the storage.
1390
+ If none is provided a default :class:`~torchrl.data.replay_buffers.ListStorage` with
1391
+ ``max_size`` of ``1_000`` will be created.
1392
+ sampler (Sampler, Callable[[], Sampler], optional): the sampler to be used.
1393
+ If a callable is passed, it is used as constructor for the sampler.
1394
+ If none is provided, a default :class:`~torchrl.data.replay_buffers.RandomSampler`
1395
+ will be used.
1396
+ writer (Writer, Callable[[], Writer], optional): the writer to be used.
1397
+ If a callable is passed, it is used as constructor for the writer.
1398
+ If none is provided a default :class:`~torchrl.data.replay_buffers.TensorDictRoundRobinWriter`
1399
+ will be used.
1400
+ collate_fn (callable, optional): merges a list of samples to form a
1401
+ mini-batch of Tensor(s)/outputs. Used when using batched
1402
+ loading from a map-style dataset. The default value will be decided
1403
+ based on the storage type.
1404
+ pin_memory (bool): whether pin_memory() should be called on the rb
1405
+ samples.
1406
+ prefetch (int, optional): number of next batches to be prefetched
1407
+ using multithreading. Defaults to None (no prefetching).
1408
+ transform (Transform or Callable[[Any], Any], optional): Transform to be executed when
1409
+ :meth:`sample` is called.
1410
+ To chain transforms use the :class:`~torchrl.envs.Compose` class.
1411
+ Transforms should be used with :class:`tensordict.TensorDict`
1412
+ content. A generic callable can also be passed if the replay buffer
1413
+ is used with PyTree structures (see example below).
1414
+ Unlike storages, writers and samplers, transform constructors must
1415
+ be passed as separate keyword argument :attr:`transform_factory`,
1416
+ as it is impossible to distinguish a constructor from a transform.
1417
+ transform_factory (Callable[[], Callable], optional): a factory for the
1418
+ transform. Mutually exclusive with :attr:`transform`.
1419
+ batch_size (int, optional): the batch size to be used when sample() is
1420
+ called.
1421
+
1422
+ .. note::
1423
+ The batch-size can be specified at construction time via the
1424
+ ``batch_size`` argument, or at sampling time. The former should
1425
+ be preferred whenever the batch-size is consistent across the
1426
+ experiment. If the batch-size is likely to change, it can be
1427
+ passed to the :meth:`~.sample` method. This option is
1428
+ incompatible with prefetching (since this requires to know the
1429
+ batch-size in advance) as well as with samplers that have a
1430
+ ``drop_last`` argument.
1431
+
1432
+ priority_key (str, optional): the key at which priority is assumed to
1433
+ be stored within TensorDicts added to this ReplayBuffer.
1434
+ This is to be used when the sampler is of type
1435
+ :class:`~torchrl.data.PrioritizedSampler`.
1436
+ Defaults to ``"td_error"``.
1437
+ dim_extend (int, optional): indicates the dim to consider for
1438
+ extension when calling :meth:`~.extend`. Defaults to ``storage.ndim-1``.
1439
+ When using ``dim_extend > 0``, we recommend using the ``ndim``
1440
+ argument in the storage instantiation if that argument is
1441
+ available, to let storages know that the data is
1442
+ multi-dimensional and keep consistent notions of storage-capacity
1443
+ and batch-size during sampling.
1444
+
1445
+ .. note:: This argument has no effect on :meth:`~.add` and
1446
+ therefore should be used with caution when both :meth:`~.add`
1447
+ and :meth:`~.extend` are used in a codebase. For example:
1448
+
1449
+ >>> data = torch.zeros(3, 4)
1450
+ >>> rb = ReplayBuffer(
1451
+ ... storage=LazyTensorStorage(10, ndim=2),
1452
+ ... dim_extend=1)
1453
+ >>> # these two approaches are equivalent:
1454
+ >>> for d in data.unbind(1):
1455
+ ... rb.add(d)
1456
+ >>> rb.extend(data)
1457
+
1458
+ generator (torch.Generator, optional): a generator to use for sampling.
1459
+ Using a dedicated generator for the replay buffer can allow a fine-grained control
1460
+ over seeding, for instance keeping the global seed different but the RB seed identical
1461
+ for distributed jobs.
1462
+ Defaults to ``None`` (global default generator).
1463
+
1464
+ .. warning:: As of now, the generator has no effect on the transforms.
1465
+ shared (bool, optional): whether the buffer will be shared using multiprocessing or not.
1466
+ Defaults to ``False``.
1467
+ compilable (bool, optional): whether the writer is compilable.
1468
+ If ``True``, the writer cannot be shared between multiple processes.
1469
+ Defaults to ``False``.
1470
+ delayed_init (bool, optional): whether to initialize storage, writer, sampler and transform
1471
+ the first time the buffer is used rather than during construction.
1472
+ This is useful when the replay buffer needs to be pickled and sent to remote workers,
1473
+ particularly when using transforms with modules that require gradients.
1474
+ If not specified, defaults to ``True`` when ``transform_factory`` is provided,
1475
+ and ``False`` otherwise.
1476
+
1477
+ Examples:
1478
+ >>> import torch
1479
+ >>>
1480
+ >>> from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer
1481
+ >>> from tensordict import TensorDict
1482
+ >>>
1483
+ >>> torch.manual_seed(0)
1484
+ >>>
1485
+ >>> rb = TensorDictReplayBuffer(storage=LazyTensorStorage(10), batch_size=5)
1486
+ >>> data = TensorDict({"a": torch.ones(10, 3), ("b", "c"): torch.zeros(10, 1, 1)}, [10])
1487
+ >>> rb.extend(data)
1488
+ >>> sample = rb.sample(3)
1489
+ >>> # samples keep track of the index
1490
+ >>> print(sample)
1491
+ TensorDict(
1492
+ fields={
1493
+ a: Tensor(shape=torch.Size([3, 3]), device=cpu, dtype=torch.float32, is_shared=False),
1494
+ b: TensorDict(
1495
+ fields={
1496
+ c: Tensor(shape=torch.Size([3, 1, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
1497
+ batch_size=torch.Size([3]),
1498
+ device=cpu,
1499
+ is_shared=False),
1500
+ index: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.int32, is_shared=False)},
1501
+ batch_size=torch.Size([3]),
1502
+ device=cpu,
1503
+ is_shared=False)
1504
+ >>> # we can iterate over the buffer
1505
+ >>> for i, data in enumerate(rb):
1506
+ ... print(i, data)
1507
+ ... if i == 2:
1508
+ ... break
1509
+ 0 TensorDict(
1510
+ fields={
1511
+ a: Tensor(shape=torch.Size([5, 3]), device=cpu, dtype=torch.float32, is_shared=False),
1512
+ b: TensorDict(
1513
+ fields={
1514
+ c: Tensor(shape=torch.Size([5, 1, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
1515
+ batch_size=torch.Size([5]),
1516
+ device=cpu,
1517
+ is_shared=False),
1518
+ index: Tensor(shape=torch.Size([5]), device=cpu, dtype=torch.int32, is_shared=False)},
1519
+ batch_size=torch.Size([5]),
1520
+ device=cpu,
1521
+ is_shared=False)
1522
+ 1 TensorDict(
1523
+ fields={
1524
+ a: Tensor(shape=torch.Size([5, 3]), device=cpu, dtype=torch.float32, is_shared=False),
1525
+ b: TensorDict(
1526
+ fields={
1527
+ c: Tensor(shape=torch.Size([5, 1, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
1528
+ batch_size=torch.Size([5]),
1529
+ device=cpu,
1530
+ is_shared=False),
1531
+ index: Tensor(shape=torch.Size([5]), device=cpu, dtype=torch.int32, is_shared=False)},
1532
+ batch_size=torch.Size([5]),
1533
+ device=cpu,
1534
+ is_shared=False)
1535
+
1536
+ """
1537
+
1538
+ def __init__(self, *, priority_key: str = "td_error", **kwargs) -> None:
1539
+ writer = kwargs.get("writer", None)
1540
+ if writer is None:
1541
+ kwargs["writer"] = partial(
1542
+ TensorDictRoundRobinWriter, compilable=kwargs.get("compilable")
1543
+ )
1544
+ super().__init__(**kwargs)
1545
+ self.priority_key = priority_key
1546
+
1547
+ def _get_priority_item(self, tensordict: TensorDictBase) -> float:
1548
+ priority = tensordict.get(self.priority_key, None)
1549
+ if priority is not None and self._storage.ndim > 1:
1550
+ # We have to flatten the priority otherwise we'll be aggregating
1551
+ # the priority across batches
1552
+ priority = priority.flatten(0, self._storage.ndim - 1)
1553
+ if priority is None:
1554
+ return self._sampler.default_priority
1555
+ try:
1556
+ if priority.numel() > 1:
1557
+ priority = _reduce(priority, self._sampler.reduction)
1558
+ else:
1559
+ priority = priority.item()
1560
+ except ValueError:
1561
+ raise ValueError(
1562
+ f"Found a priority key of size"
1563
+ f" {tensordict.get(self.priority_key).shape} but expected "
1564
+ f"scalar value"
1565
+ )
1566
+
1567
+ if self._storage.ndim > 1:
1568
+ priority = priority.unflatten(0, tensordict.shape[: self._storage.ndim])
1569
+
1570
+ return priority
1571
+
1572
+ def _get_priority_vector(self, tensordict: TensorDictBase) -> torch.Tensor:
1573
+ priority = tensordict.get(self.priority_key, None)
1574
+ if priority is None:
1575
+ return torch.tensor(
1576
+ self._sampler.default_priority,
1577
+ dtype=torch.float,
1578
+ device=tensordict.device,
1579
+ ).expand(tensordict.shape[0])
1580
+ if self._storage.ndim > 1 and priority.ndim >= self._storage.ndim:
1581
+ # We have to flatten the priority otherwise we'll be aggregating
1582
+ # the priority across batches
1583
+ priority = priority.flatten(0, self._storage.ndim - 1)
1584
+
1585
+ priority = priority.reshape(priority.shape[0], -1)
1586
+ priority = _reduce(priority, self._sampler.reduction, dim=1)
1587
+
1588
+ if self._storage.ndim > 1:
1589
+ priority = priority.unflatten(0, tensordict.shape[: self._storage.ndim])
1590
+
1591
+ return priority
1592
+
1593
+ @_maybe_delay_init
1594
+ def add(self, data: TensorDictBase) -> int:
1595
+ if self._transform is not None:
1596
+ with _set_dispatch_td_nn_modules(is_tensor_collection(data)):
1597
+ data = self._transform.inv(data)
1598
+ if data is None:
1599
+ return torch.zeros((0, self._storage.ndim), dtype=torch.long)
1600
+
1601
+ index = super()._add(data)
1602
+ if index is not None:
1603
+ if is_tensor_collection(data):
1604
+ self._set_index_in_td(data, index)
1605
+
1606
+ self.update_tensordict_priority(data)
1607
+ return index
1608
+
1609
+ @_maybe_delay_init
1610
+ def extend(
1611
+ self, tensordicts: TensorDictBase, *, update_priority: bool | None = None
1612
+ ) -> torch.Tensor:
1613
+ """Extends the replay buffer with a batch of data.
1614
+
1615
+ Args:
1616
+ tensordicts (TensorDictBase): The data to extend the replay buffer with.
1617
+
1618
+ Keyword Args:
1619
+ update_priority (bool, optional): Whether to update the priority of the data. Defaults to True.
1620
+
1621
+ Returns:
1622
+ The indices of the data that were added to the replay buffer.
1623
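+
+ Example:
+ A minimal sketch of deferring the priority update (it assumes a
+ :class:`~torchrl.data.replay_buffers.PrioritizedSampler` and one scalar
+ ``"td_error"`` value per element of ``data``):
+
+ >>> idx = rb.extend(data, update_priority=False)  # skip the automatic update
+ >>> rb.update_priority(idx, data["td_error"])     # and perform it manually instead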
+ """
1624
+ if not isinstance(tensordicts, TensorDictBase):
1625
+ raise ValueError(
1626
+ f"{self.__class__.__name__} only accepts TensorDictBase subclasses. tensorclasses "
1627
+ f"and other types are not compatible with that class. "
1628
+ "Please use a regular `ReplayBuffer` instead."
1629
+ )
1630
+ if self._transform is not None:
1631
+ tensordicts = self._transform.inv(tensordicts)
1632
+ if tensordicts is None:
1633
+ return torch.zeros((0, self._storage.ndim), dtype=torch.long)
1634
+
1635
+ index = super()._extend(tensordicts)
1636
+
1637
+ # TODO: to be usable directly, the indices should be flipped but the issue
1638
+ # is that just doing this results in indices that are not sorted like the original data
1639
+ # so the actual indices will have to be used on the _storage directly (not on the buffer)
1640
+ self._set_index_in_td(tensordicts, index)
1641
+ if update_priority is None:
1642
+ update_priority = True
1643
+ if update_priority:
1644
+ try:
1645
+ vector = tensordicts.get(self.priority_key)
1646
+ if vector is not None:
1647
+ self.update_priority(index, vector)
1648
+ except Exception as e:
1649
+ raise RuntimeError(
1650
+ "Failed to update priority of extended data. You can try to set update_priority=False in the extend method and update the priority manually."
1651
+ ) from e
1652
+ return index
1653
+
1654
+ def _set_index_in_td(self, tensordict, index):
1655
+ if index is None:
1656
+ return
1657
+ if _is_int(index):
1658
+ index = torch.as_tensor(index, device=tensordict.device)
1659
+ elif index.ndim == 2 and index.shape[:1] != tensordict.shape[:1]:
1660
+ for dim in range(2, tensordict.ndim + 1):
1661
+ if index.shape[:1].numel() == tensordict.shape[:dim].numel():
1662
+ # if index has 2 dims and is in a non-zero format
1663
+ index = index.unflatten(0, tensordict.shape[:dim])
1664
+ break
1665
+ else:
1666
+ raise RuntimeError(
1667
+ f"could not find how to reshape index with shape {index.shape} to fit in tensordict with shape {tensordict.shape}"
1668
+ )
1669
+ tensordict.set("index", index)
1670
+ return
1671
+ tensordict.set("index", expand_as_right(index, tensordict))
1672
+
1673
+ @_maybe_delay_init
1674
+ def update_tensordict_priority(self, data: TensorDictBase) -> None:
1675
+ if not isinstance(self._sampler, PrioritizedSampler):
1676
+ return
1677
+ if data.ndim:
1678
+ priority = self._get_priority_vector(data)
1679
+ else:
1680
+ priority = torch.as_tensor(self._get_priority_item(data))
1681
+ index = data.get("index")
1682
+ if self._storage.ndim > 1 and index.ndim == 2:
1683
+ index = index.unbind(-1)
1684
+ else:
1685
+ while index.shape != priority.shape:
1686
+ # reduce index
1687
+ index = index[..., 0]
1688
+ return self.update_priority(index, priority)
1689
+
1690
+ def sample(
1691
+ self,
1692
+ batch_size: int | None = None,
1693
+ return_info: bool = False,
1694
+ include_info: bool | None = None,
1695
+ ) -> TensorDictBase:
1696
+ """Samples a batch of data from the replay buffer.
1697
+
1698
+ Uses Sampler to sample indices, and retrieves them from Storage.
1699
+
1700
+ Args:
1701
+ batch_size (int, optional): size of data to be collected. If none
1702
+ is provided, this method will sample a batch-size as indicated
1703
+ by the sampler.
1704
+ return_info (bool): whether to return info. If True, the result
1705
+ is a tuple (data, info). If False, the result is the data.
1706
+
1707
+ Returns:
1708
+ A tensordict containing a batch of data selected in the replay buffer.
1709
+ A tuple containing this tensordict and info if return_info flag is set to True.
1710
+ """
1711
+ if include_info is not None:
1712
+ warnings.warn(
1713
+ "include_info is going to be deprecated soon."
1714
+ "The default behavior has changed to `include_info=True` "
1715
+ "to avoid bugs linked to wrongly preassigned values in the "
1716
+ "output tensordict."
1717
+ )
1718
+
1719
+ data, info = super().sample(batch_size, return_info=True)
1720
+ is_tc = is_tensor_collection(data)
1721
+ if is_tc and not is_tensorclass(data) and include_info in (True, None):
1722
+ is_locked = data.is_locked
1723
+ if is_locked:
1724
+ data.unlock_()
1725
+ for key, val in info.items():
1726
+ if key == "index" and isinstance(val, tuple):
1727
+ val = torch.stack(val, -1)
1728
+ try:
1729
+ val = _to_torch(val, data.device)
1730
+ if val.ndim < data.ndim:
1731
+ val = expand_as_right(val, data)
1732
+ data.set(key, val)
1733
+ except RuntimeError:
1734
+ raise RuntimeError(
1735
+ "Failed to set the metadata (e.g., indices or weights) in the sampled tensordict within TensorDictReplayBuffer.sample. "
1736
+ "This is probably caused by a shape mismatch (one of the transforms has probably modified "
1737
+ "the shape of the output tensordict). "
1738
+ "You can always recover these items from the `sample` method from a regular ReplayBuffer "
1739
+ "instance with the 'return_info' flag set to True."
1740
+ )
1741
+ if is_locked:
1742
+ data.lock_()
1743
+ elif not is_tc and include_info in (True, None):
1744
+ raise RuntimeError("Cannot include info in non-tensordict data")
1745
+ if return_info:
1746
+ return data, info
1747
+ return data
1748
+
1749
+ @pin_memory_output
1750
+ def _sample(self, batch_size: int) -> tuple[Any, dict]:
1751
+ is_comp = is_compiling()
1752
+ nc = contextlib.nullcontext()
1753
+ with self._replay_lock if not is_comp else nc, self._write_lock if not is_comp else nc:
1754
+ index, info = self._sampler.sample(self._storage, batch_size)
1755
+ info["index"] = index
1756
+ data = self._storage.get(index)
1757
+ if not isinstance(index, INT_CLASSES):
1758
+ data = self._collate_fn(data)
1759
+ if self._transform is not None and len(self._transform):
1760
+ with data.unlock_(), _set_dispatch_td_nn_modules(True):
1761
+ data = self._transform(data)
1762
+ return data, info
1763
+
1764
+
1765
+ class TensorDictPrioritizedReplayBuffer(TensorDictReplayBuffer):
1766
+ """TensorDict-specific wrapper around the :class:`~torchrl.data.PrioritizedReplayBuffer` class.
1767
+
1768
+ This class returns tensordicts with a new key ``"index"`` that represents
1769
+ the index of each element in the replay buffer. It also provides the
1770
+ :meth:`~.update_tensordict_priority` method that only requires for the
1771
+ tensordict to be passed to it with its new priority value.
1772
+
1773
+ Keyword Args:
1774
+ alpha (:obj:`float`): exponent α determines how much prioritization is used,
1775
+ with α = 0 corresponding to the uniform case.
1776
+ beta (:obj:`float`): importance sampling negative exponent.
1777
+ eps (:obj:`float`): delta added to the priorities to ensure that the buffer
1778
+ does not contain null priorities.
1779
+ storage (Storage, Callable[[], Storage], optional): the storage to be used.
1780
+ If a callable is passed, it is used as constructor for the storage.
1781
+ If none is provided a default :class:`~torchrl.data.replay_buffers.ListStorage` with
1782
+ ``max_size`` of ``1_000`` will be created.
1783
+ collate_fn (callable, optional): merges a list of samples to form a
1784
+ mini-batch of Tensor(s)/outputs. Used when using batched
1785
+ loading from a map-style dataset. The default value will be decided
1786
+ based on the storage type.
1787
+ pin_memory (bool): whether pin_memory() should be called on the rb
1788
+ samples.
1789
+ prefetch (int, optional): number of next batches to be prefetched
1790
+ using multithreading. Defaults to None (no prefetching).
1791
+ transform (Transform or Callable[[Any], Any], optional): Transform to be executed when
1792
+ :meth:`sample` is called.
1793
+ To chain transforms use the :class:`~torchrl.envs.Compose` class.
1794
+ Transforms should be used with :class:`tensordict.TensorDict`
1795
+ content. A generic callable can also be passed if the replay buffer
1796
+ is used with PyTree structures (see example below).
1797
+ Unlike storages, writers and samplers, transform constructors must
1798
+ be passed as separate keyword argument :attr:`transform_factory`,
1799
+ as it is impossible to distinguish a constructor from a transform.
1800
+ transform_factory (Callable[[], Callable], optional): a factory for the
1801
+ transform. Mutually exclusive with :attr:`transform`.
1802
+ batch_size (int, optional): the batch size to be used when sample() is
1803
+ called.
1804
+
1805
+ .. note::
1806
+ The batch-size can be specified at construction time via the
1807
+ ``batch_size`` argument, or at sampling time. The former should
1808
+ be preferred whenever the batch-size is consistent across the
1809
+ experiment. If the batch-size is likely to change, it can be
1810
+ passed to the :meth:`~.sample` method. This option is
1811
+ incompatible with prefetching (since this requires to know the
1812
+ batch-size in advance) as well as with samplers that have a
1813
+ ``drop_last`` argument.
1814
+
1815
+ priority_key (str, optional): the key at which priority is assumed to
1816
+ be stored within TensorDicts added to this ReplayBuffer.
1817
+ This is to be used when the sampler is of type
1818
+ :class:`~torchrl.data.PrioritizedSampler`.
1819
+ Defaults to ``"td_error"``.
1820
+ reduction (str, optional): the reduction method for multidimensional
1821
+ tensordicts (i.e., stored trajectories). Can be one of "max", "min",
1822
+ "median" or "mean".
1823
+ dim_extend (int, optional): indicates the dim to consider for
1824
+ extension when calling :meth:`~.extend`. Defaults to ``storage.ndim-1``.
1825
+ When using ``dim_extend > 0``, we recommend using the ``ndim``
1826
+ argument in the storage instantiation if that argument is
1827
+ available, to let storages know that the data is
1828
+ multi-dimensional and keep consistent notions of storage-capacity
1829
+ and batch-size during sampling.
1830
+
1831
+ .. note:: This argument has no effect on :meth:`~.add` and
1832
+ therefore should be used with caution when both :meth:`~.add`
1833
+ and :meth:`~.extend` are used in a codebase. For example:
1834
+
1835
+ >>> data = torch.zeros(3, 4)
1836
+ >>> rb = ReplayBuffer(
1837
+ ... storage=LazyTensorStorage(10, ndim=2),
1838
+ ... dim_extend=1)
1839
+ >>> # these two approaches are equivalent:
1840
+ >>> for d in data.unbind(1):
1841
+ ... rb.add(d)
1842
+ >>> rb.extend(data)
1843
+
1844
+ generator (torch.Generator, optional): a generator to use for sampling.
1845
+ Using a dedicated generator for the replay buffer can allow a fine-grained control
1846
+ over seeding, for instance keeping the global seed different but the RB seed identical
1847
+ for distributed jobs.
1848
+ Defaults to ``None`` (global default generator).
1849
+
1850
+ .. warning:: As of now, the generator has no effect on the transforms.
1851
+ shared (bool, optional): whether the buffer will be shared using multiprocessing or not.
1852
+ Defaults to ``False``.
1853
+ compilable (bool, optional): whether the writer is compilable.
1854
+ If ``True``, the writer cannot be shared between multiple processes.
1855
+ Defaults to ``False``.
1856
+ delayed_init (bool, optional): whether to initialize storage, writer, sampler and transform
1857
+ the first time the buffer is used rather than during construction.
1858
+ This is useful when the replay buffer needs to be pickled and sent to remote workers,
1859
+ particularly when using transforms with modules that require gradients.
1860
+ If not specified, defaults to ``True`` when ``transform_factory`` is provided,
1861
+ and ``False`` otherwise.
1862
+
1863
+ Examples:
1864
+ >>> import torch
1865
+ >>>
1866
+ >>> from torchrl.data import LazyTensorStorage, TensorDictPrioritizedReplayBuffer
1867
+ >>> from tensordict import TensorDict
1868
+ >>>
1869
+ >>> torch.manual_seed(0)
1870
+ >>>
1871
+ >>> rb = TensorDictPrioritizedReplayBuffer(alpha=0.7, beta=1.1, storage=LazyTensorStorage(10), batch_size=5)
1872
+ >>> data = TensorDict({"a": torch.ones(10, 3), ("b", "c"): torch.zeros(10, 3, 1)}, [10])
1873
+ >>> rb.extend(data)
1874
+ >>> print("len of rb", len(rb))
1875
+ len of rb 10
1876
+ >>> sample = rb.sample(5)
1877
+ >>> print(sample)
1878
+ TensorDict(
1879
+ fields={
1880
+ priority_weight: Tensor(shape=torch.Size([5]), device=cpu, dtype=torch.float32, is_shared=False),
1881
+ a: Tensor(shape=torch.Size([5, 3]), device=cpu, dtype=torch.float32, is_shared=False),
1882
+ b: TensorDict(
1883
+ fields={
1884
+ c: Tensor(shape=torch.Size([5, 3, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
1885
+ batch_size=torch.Size([5]),
1886
+ device=cpu,
1887
+ is_shared=False),
1888
+ index: Tensor(shape=torch.Size([5]), device=cpu, dtype=torch.int64, is_shared=False)},
1889
+ batch_size=torch.Size([5]),
1890
+ device=cpu,
1891
+ is_shared=False)
1892
+ >>> print("index", sample["index"])
1893
+ index tensor([9, 5, 2, 2, 7])
1894
+ >>> # give a high priority to these samples...
1895
+ >>> sample.set("td_error", 100*torch.ones(sample.shape))
1896
+ >>> # and update priority
1897
+ >>> rb.update_tensordict_priority(sample)
1898
+ >>> # the new sample should have a high overlap with the previous one
1899
+ >>> sample = rb.sample(5)
1900
+ >>> print(sample)
1901
+ TensorDict(
1902
+ fields={
1903
+ priority_weight: Tensor(shape=torch.Size([5]), device=cpu, dtype=torch.float32, is_shared=False),
1904
+ a: Tensor(shape=torch.Size([5, 3]), device=cpu, dtype=torch.float32, is_shared=False),
1905
+ b: TensorDict(
1906
+ fields={
1907
+ c: Tensor(shape=torch.Size([5, 3, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
1908
+ batch_size=torch.Size([5]),
1909
+ device=cpu,
1910
+ is_shared=False),
1911
+ index: Tensor(shape=torch.Size([5]), device=cpu, dtype=torch.int64, is_shared=False)},
1912
+ batch_size=torch.Size([5]),
1913
+ device=cpu,
1914
+ is_shared=False)
1915
+ >>> print("index", sample["index"])
1916
+ index tensor([2, 5, 5, 9, 7])
1917
+
1918
+ """
1919
+
1920
+ def __init__(
1921
+ self,
1922
+ *,
1923
+ alpha: float,
1924
+ beta: float,
1925
+ priority_key: str = "td_error",
1926
+ eps: float = 1e-8,
1927
+ storage: Storage | None = None,
1928
+ collate_fn: Callable | None = None,
1929
+ pin_memory: bool = False,
1930
+ prefetch: int | None = None,
1931
+ transform: Transform | None = None, # noqa: F821
1932
+ reduction: str = "max",
1933
+ batch_size: int | None = None,
1934
+ dim_extend: int | None = None,
1935
+ generator: torch.Generator | None = None,
1936
+ shared: bool = False,
1937
+ compilable: bool = False,
1938
+ ) -> None:
1939
+ storage = self._maybe_make_storage(storage, compilable=compilable)
1940
+ sampler = PrioritizedSampler(
1941
+ storage.max_size, alpha, beta, eps, reduction=reduction
1942
+ )
1943
+ super().__init__(
1944
+ priority_key=priority_key,
1945
+ storage=storage,
1946
+ sampler=sampler,
1947
+ collate_fn=collate_fn,
1948
+ pin_memory=pin_memory,
1949
+ prefetch=prefetch,
1950
+ transform=transform,
1951
+ batch_size=batch_size,
1952
+ dim_extend=dim_extend,
1953
+ generator=generator,
1954
+ shared=shared,
1955
+ compilable=compilable,
1956
+ )
1957
+
1958
+
1959
+ @accept_remote_rref_udf_invocation
1960
+ class RemoteTensorDictReplayBuffer(TensorDictReplayBuffer):
1961
+ """A remote invocation friendly ReplayBuffer class. Public methods can be invoked by remote agents using `torch.rpc` or called locally as normal."""
1962
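+
+ # A minimal remote-usage sketch (illustrative only): it assumes that
+ # ``torch.distributed.rpc`` has been initialized and that a worker named
+ # "buffer_node" hosts the buffer; names and sizes here are hypothetical.
+ #
+ #     import torch.distributed.rpc as rpc
+ #     rb_rref = rpc.remote(
+ #         "buffer_node",
+ #         RemoteTensorDictReplayBuffer,
+ #         kwargs={"storage": LazyTensorStorage(1000), "batch_size": 32},
+ #     )
+ #     sample = rb_rref.rpc_sync().sample()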
+
1963
+ def __init__(self, *args, **kwargs):
1964
+ super().__init__(*args, **kwargs)
1965
+
1966
+ def sample(
1967
+ self,
1968
+ batch_size: int | None = None,
1969
+ include_info: bool | None = None,
1970
+ return_info: bool = False,
1971
+ ) -> TensorDictBase:
1972
+ return super().sample(
1973
+ batch_size=batch_size, include_info=include_info, return_info=return_info
1974
+ )
1975
+
1976
+ def add(self, data: TensorDictBase) -> int:
1977
+ return super().add(data)
1978
+
1979
+ def extend(
1980
+ self, tensordicts: list | TensorDictBase, *, update_priority: bool | None = None
1981
+ ) -> torch.Tensor:
1982
+ return super().extend(tensordicts, update_priority=update_priority)
1983
+
1984
+ def update_priority(
1985
+ self, index: int | torch.Tensor, priority: int | torch.Tensor
1986
+ ) -> None:
1987
+ return super().update_priority(index, priority)
1988
+
1989
+ def update_tensordict_priority(self, data: TensorDictBase) -> None:
1990
+ return super().update_tensordict_priority(data)
1991
+
1992
+
1993
+ class InPlaceSampler:
1994
+ """[Deprecated] A sampler to write tennsordicts in-place."""
1995
+
1996
+ def __init__(self, device: DEVICE_TYPING | None = None):
1997
+ raise RuntimeError(
1998
+ "This class has been removed without replacement. In-place sampling should be avoided."
1999
+ )
2000
+
2001
+
2002
+ def stack_tensors(list_of_tensor_iterators: list) -> tuple[torch.Tensor]:
2003
+ """Zips a list of iterables containing tensor-like objects and stacks the resulting lists of tensors together.
2004
+
2005
+ Args:
2006
+ list_of_tensor_iterators (list): Sequence containing similar iterators,
2007
+ where each element of the nested iterator is a tensor whose
2008
+ shape matches that of the tensors at the same index in the other iterators.
2009
+
2010
+ Returns:
2011
+ Tuple of stacked tensors.
2012
+
2013
+ Examples:
2014
+ >>> list_of_tensor_iterators = [[torch.ones(3), torch.zeros(1,2)]
2015
+ ... for _ in range(4)]
2016
+ >>> stack_tensors(list_of_tensor_iterators)
2017
+ (tensor([[1., 1., 1.],
2018
+ [1., 1., 1.],
2019
+ [1., 1., 1.],
2020
+ [1., 1., 1.]]), tensor([[[0., 0.]],
2021
+ <BLANKLINE>
2022
+ [[0., 0.]],
2023
+ <BLANKLINE>
2024
+ [[0., 0.]],
2025
+ <BLANKLINE>
2026
+ [[0., 0.]]]))
2027
+
2028
+ """
2029
+ return tuple(torch.stack(tensors, 0) for tensors in zip(*list_of_tensor_iterators))
2030
+
2031
+
2032
+ class ReplayBufferEnsemble(ReplayBuffer):
2033
+ """An ensemble of replay buffers.
2034
+
2035
+ This class allows reading and sampling from multiple replay buffers at once.
2036
+ It automatically composes ensemble of storages (:class:`~torchrl.data.replay_buffers.storages.StorageEnsemble`),
2037
+ writers (:class:`~torchrl.data.replay_buffers.writers.WriterEnsemble`) and
2038
+ samplers (:class:`~torchrl.data.replay_buffers.samplers.SamplerEnsemble`).
2039
+
2040
+ .. note::
2041
+ Writing directly to this class is forbidden, but it can be indexed to retrieve
2042
+ the nested buffers and extend them.
2043
+
2044
+ There are two distinct ways of constructing a :class:`~torchrl.data.ReplayBufferEnsemble`:
2045
+ one can either pass a list of replay buffers, or directly pass the components
2046
+ (storage, writers and samplers) as is done for other replay buffer subclasses; a sketch of the latter is shown at the end of the examples below.
2047
+
2048
+ Args:
2049
+ rbs (sequence of ReplayBuffer instances, optional): the replay buffers to ensemble.
2050
+ storages (StorageEnsemble, optional): the ensemble of storages, if the replay
2051
+ buffers are not passed.
2052
+ samplers (SamplerEnsemble, optional): the ensemble of samplers, if the replay
2053
+ buffers are not passed.
2054
+ writers (WriterEnsemble, optional): the ensemble of writers, if the replay
2055
+ buffers are not passed.
2056
+ transform (Transform, optional): if passed, this will be the transform
2057
+ of the ensemble of replay buffers. Individual transforms for each
2058
+ replay buffer is retrieved from its parent replay buffer, or directly
2059
+ written in the :class:`~torchrl.data.replay_buffers.storages.StorageEnsemble`
2060
+ object.
2061
+ batch_size (int, optional): the batch-size to use during sampling.
2062
+ collate_fn (callable, optional): the function to use to collate the
2063
+ data after each individual collate_fn has been called and the data
2064
+ is placed in a list (along with the buffer id).
2065
+ collate_fns (list of callables, optional): collate_fn of each nested
2066
+ replay buffer. Retrieved from the :class:`~ReplayBuffer` instances
2067
+ if not provided.
2068
+ p (list of float or Tensor, optional): a list of floating numbers
2069
+ indicating the relative weight of each replay buffer. Can also
2070
+ be passed to :class:`~torchrl.data.replay_buffers.samplers.SamplerEnsemble`
2071
+ if the buffer is built explicitly.
2072
+ sample_from_all (bool, optional): if ``True``, each dataset will be sampled
2073
+ from. This is not compatible with the ``p`` argument. Defaults to ``False``.
2074
+ Can also be passed to :class:`~torchrl.data.replay_buffers.samplers.SamplerEnsemble`
2075
+ if the buffer is built explicitly.
2076
+ num_buffer_sampled (int, optional): the number of buffers to sample.
2077
+ if ``sample_from_all=True``, this has no effect, as it defaults to the
2078
+ number of buffers. If ``sample_from_all=False``, buffers will be
2079
+ sampled according to the probabilities ``p``. Can also
2080
+ be passed to :class:`~torchrl.data.replay_buffers.samplers.SamplerEnsemble`
2081
+ if the buffer is built explicitly.
2082
+ generator (torch.Generator, optional): a generator to use for sampling.
2083
+ Using a dedicated generator for the replay buffer can allow a fine-grained control
2084
+ over seeding, for instance keeping the global seed different but the RB seed identical
2085
+ for distributed jobs.
2086
+ Defaults to ``None`` (global default generator).
2087
+
2088
+ .. warning:: As of now, the generator has no effect on the transforms.
2089
+
2090
+ shared (bool, optional): whether the buffer will be shared using multiprocessing or not.
2091
+ Defaults to ``False``.
2092
+ delayed_init (bool, optional): whether to initialize storage, writer, sampler and transform
2093
+ the first time the buffer is used rather than during construction.
2094
+ This is useful when the replay buffer needs to be pickled and sent to remote workers,
2095
+ particularly when using transforms with modules that require gradients.
2096
+ If not specified, defaults to ``True`` when ``transform_factory`` is provided,
2097
+ and ``False`` otherwise.
2098
+
2099
+ Examples:
2100
+ >>> from torchrl.envs import Compose, ToTensorImage, Resize, RenameTransform
2101
+ >>> from torchrl.data import TensorDictReplayBuffer, ReplayBufferEnsemble, LazyMemmapStorage
2102
+ >>> from tensordict import TensorDict
2103
+ >>> import torch
2104
+ >>> rb0 = TensorDictReplayBuffer(
2105
+ ... storage=LazyMemmapStorage(10),
2106
+ ... transform=Compose(
2107
+ ... ToTensorImage(in_keys=["pixels", ("next", "pixels")]),
2108
+ ... Resize(32, in_keys=["pixels", ("next", "pixels")]),
2109
+ ... RenameTransform([("some", "key")], ["renamed"]),
2110
+ ... ),
2111
+ ... )
2112
+ >>> rb1 = TensorDictReplayBuffer(
2113
+ ... storage=LazyMemmapStorage(10),
2114
+ ... transform=Compose(
2115
+ ... ToTensorImage(in_keys=["pixels", ("next", "pixels")]),
2116
+ ... Resize(32, in_keys=["pixels", ("next", "pixels")]),
2117
+ ... RenameTransform(["another_key"], ["renamed"]),
2118
+ ... ),
2119
+ ... )
2120
+ >>> rb = ReplayBufferEnsemble(
2121
+ ... rb0,
2122
+ ... rb1,
2123
+ ... p=[0.5, 0.5],
2124
+ ... transform=Resize(33, in_keys=["pixels"], out_keys=["pixels33"]),
2125
+ ... )
2126
+ >>> print(rb)
2127
+ ReplayBufferEnsemble(
2128
+ storages=StorageEnsemble(
2129
+ storages=(<torchrl.data.replay_buffers.storages.LazyMemmapStorage object at 0x13a2ef430>, <torchrl.data.replay_buffers.storages.LazyMemmapStorage object at 0x13a2f9310>),
2130
+ transforms=[Compose(
2131
+ ToTensorImage(keys=['pixels', ('next', 'pixels')]),
2132
+ Resize(w=32, h=32, interpolation=InterpolationMode.BILINEAR, keys=['pixels', ('next', 'pixels')]),
2133
+ RenameTransform(keys=[('some', 'key')])), Compose(
2134
+ ToTensorImage(keys=['pixels', ('next', 'pixels')]),
2135
+ Resize(w=32, h=32, interpolation=InterpolationMode.BILINEAR, keys=['pixels', ('next', 'pixels')]),
2136
+ RenameTransform(keys=['another_key']))]),
2137
+ samplers=SamplerEnsemble(
2138
+ samplers=(<torchrl.data.replay_buffers.samplers.RandomSampler object at 0x13a2f9220>, <torchrl.data.replay_buffers.samplers.RandomSampler object at 0x13a2f9f70>)),
2139
+ writers=WriterEnsemble(
2140
+ writers=(<torchrl.data.replay_buffers.writers.TensorDictRoundRobinWriter object at 0x13a2d9b50>, <torchrl.data.replay_buffers.writers.TensorDictRoundRobinWriter object at 0x13a2f95b0>)),
2141
+ batch_size=None,
2142
+ transform=Compose(
2143
+ Resize(w=33, h=33, interpolation=InterpolationMode.BILINEAR, keys=['pixels'])),
2144
+ collate_fn=<built-in method stack of type object at 0x128648260>)
2145
+ >>> data0 = TensorDict(
2146
+ ... {
2147
+ ... "pixels": torch.randint(255, (10, 244, 244, 3)),
2148
+ ... ("next", "pixels"): torch.randint(255, (10, 244, 244, 3)),
2149
+ ... ("some", "key"): torch.randn(10),
2150
+ ... },
2151
+ ... batch_size=[10],
2152
+ ... )
2153
+ >>> data1 = TensorDict(
2154
+ ... {
2155
+ ... "pixels": torch.randint(255, (10, 64, 64, 3)),
2156
+ ... ("next", "pixels"): torch.randint(255, (10, 64, 64, 3)),
2157
+ ... "another_key": torch.randn(10),
2158
+ ... },
2159
+ ... batch_size=[10],
2160
+ ... )
2161
+ >>> rb[0].extend(data0)
2162
+ >>> rb[1].extend(data1)
2163
+ >>> for _ in range(2):
2164
+ ... sample = rb.sample(10)
2165
+ ... assert sample["next", "pixels"].shape == torch.Size([2, 5, 3, 32, 32])
2166
+ ... assert sample["pixels"].shape == torch.Size([2, 5, 3, 32, 32])
2167
+ ... assert sample["pixels33"].shape == torch.Size([2, 5, 3, 33, 33])
2168
+ ... assert sample["renamed"].shape == torch.Size([2, 5])
2169
+
2170
+ """
2171
+
2172
+ _collate_fn_val = None
2173
+
2174
+ def __init__(
2175
+ self,
2176
+ *rbs,
2177
+ storages: StorageEnsemble | None = None,
2178
+ samplers: SamplerEnsemble | None = None,
2179
+ writers: WriterEnsemble | None = None,
2180
+ transform: Transform | None = None, # noqa: F821
2181
+ batch_size: int | None = None,
2182
+ collate_fn: Callable | None = None,
2183
+ collate_fns: list[Callable] | None = None,
2184
+ p: Tensor | None = None,
2185
+ sample_from_all: bool = False,
2186
+ num_buffer_sampled: int | None = None,
2187
+ generator: torch.Generator | None = None,
2188
+ shared: bool = False,
2189
+ **kwargs,
2190
+ ):
2191
+
2192
+ if collate_fn is None:
2193
+ collate_fn = _stack_anything
2194
+
2195
+ if rbs:
2196
+ if storages is not None or samplers is not None or writers is not None:
2197
+ raise RuntimeError
2198
+ # Ensure all replay buffers are initialized before creating ensemble
2199
+ for rb in rbs:
2200
+ if (
2201
+ hasattr(rb, "_delayed_init")
2202
+ and rb._delayed_init
2203
+ and not rb.initialized
2204
+ ):
2205
+ rb._init()
2206
+ storages = StorageEnsemble(
2207
+ *[rb._storage for rb in rbs], transforms=[rb._transform for rb in rbs]
2208
+ )
2209
+ samplers = SamplerEnsemble(
2210
+ *[rb._sampler for rb in rbs],
2211
+ p=p,
2212
+ sample_from_all=sample_from_all,
2213
+ num_buffer_sampled=num_buffer_sampled,
2214
+ )
2215
+ writers = WriterEnsemble(*[rb._writer for rb in rbs])
2216
+ if collate_fns is None:
2217
+ collate_fns = [rb._collate_fn for rb in rbs]
2218
+ else:
2219
+ rbs = None
2220
+ if collate_fns is None:
2221
+ collate_fns = [
2222
+ _get_default_collate(storage) for storage in storages._storages
2223
+ ]
2224
+ self._rbs = rbs
2225
+ self._collate_fns = collate_fns
2226
+ super().__init__(
2227
+ storage=storages,
2228
+ sampler=samplers,
2229
+ writer=writers,
2230
+ transform=transform,
2231
+ batch_size=batch_size,
2232
+ collate_fn=collate_fn,
2233
+ generator=generator,
2234
+ shared=shared,
2235
+ **kwargs,
2236
+ )
2237
+
2238
+ def _sample(self, *args, **kwargs):
2239
+ sample, info = super()._sample(*args, **kwargs)
2240
+ if isinstance(sample, TensorDictBase):
2241
+ buffer_ids = info.get(("index", "buffer_ids"))
2242
+ info.set(
2243
+ ("index", "buffer_ids"), expand_right(buffer_ids, sample.batch_size)
2244
+ )
2245
+ if isinstance(info, LazyStackedTensorDict):
2246
+ for _info, _sample in zip(
2247
+ info.unbind(info.stack_dim), sample.unbind(info.stack_dim)
2248
+ ):
2249
+ _info.batch_size = _sample.batch_size
2250
+ info = torch.stack(info.tensordicts, info.stack_dim)
2251
+ else:
2252
+ info.batch_size = sample.batch_size
2253
+ sample.update(info)
2254
+
2255
+ return sample, info
2256
+
2257
+ @property
2258
+ def _collate_fn(self):
2259
+ def new_collate(samples):
2260
+ samples = [self._collate_fns[i](sample) for (i, sample) in samples]
2261
+ return self._collate_fn_val(samples)
2262
+
2263
+ return new_collate
2264
+
2265
+ @_collate_fn.setter
2266
+ def _collate_fn(self, value):
2267
+ self._collate_fn_val = value
2268
+
2269
+ _INDEX_ERROR = "Expected an index of type torch.Tensor, range, np.ndarray, int, slice or ellipsis, got {} instead."
2270
+
2271
+ def __getitem__(
2272
+ self, index: int | torch.Tensor | tuple | np.ndarray | list | slice | Ellipsis
2273
+ ) -> Any:
2274
+ # accepts inputs:
2275
+ # (int | 1d tensor | 1d list | 1d array | slice | ellipsis | range, int | tensor | list | array | slice | ellipsis | range)
2276
+ # tensor
2277
+ if isinstance(index, tuple):
2278
+ if index[0] is Ellipsis:
2279
+ index = (slice(None), index[1:])
2280
+ rb = self[index[0]]
2281
+ if len(index) > 1:
2282
+ if rb is self:
2283
+ # then index[0] is an ellipsis/slice(None)
2284
+ sample = [
2285
+ (i, storage[index[1:]])
2286
+ for i, storage in enumerate(self._storage._storages)
2287
+ ]
2288
+ return self._collate_fn(sample)
2289
+ if isinstance(rb, ReplayBufferEnsemble):
2290
+ new_index = (slice(None), *index[1:])
2291
+ return rb[new_index]
2292
+ return rb[index[1:]]
2293
+ return rb
2294
+ if isinstance(index, slice) and index == slice(None):
2295
+ return self
2296
+ if isinstance(index, (list, range, np.ndarray)):
2297
+ index = torch.as_tensor(index)
2298
+ if isinstance(index, torch.Tensor):
2299
+ if index.ndim > 1:
2300
+ raise RuntimeError(
2301
+ f"Cannot index a {type(self)} with tensor indices that have more than one dimension."
2302
+ )
2303
+ if index.is_floating_point():
2304
+ raise TypeError(
2305
+ "A floating point index was received when an integer dtype was expected."
2306
+ )
2307
+ if self._rbs is not None and (
2308
+ isinstance(index, int) or (not isinstance(index, slice) and len(index) == 0)
2309
+ ):
2310
+ try:
2311
+ index = int(index)
2312
+ except Exception:
2313
+ raise IndexError(self._INDEX_ERROR.format(type(index)))
2314
+ try:
2315
+ return self._rbs[index]
2316
+ except IndexError:
2317
+ raise IndexError(self._INDEX_ERROR.format(type(index)))
2318
+
2319
+ if self._rbs is not None:
2320
+ if isinstance(index, torch.Tensor):
2321
+ index = index.tolist()
2322
+ rbs = [self._rbs[i] for i in index]
2323
+ _collate_fns = [self._collate_fns[i] for i in index]
2324
+ else:
2325
+ try:
2326
+ # slice
2327
+ rbs = self._rbs[index]
2328
+ _collate_fns = self._collate_fns[index]
2329
+ except IndexError:
2330
+ raise IndexError(self._INDEX_ERROR.format(type(index)))
2331
+ p = self._sampler._p[index] if self._sampler._p is not None else None
2332
+ return ReplayBufferEnsemble(
2333
+ *rbs,
2334
+ transform=self._transform,
2335
+ batch_size=self._batch_size,
2336
+ collate_fn=self._collate_fn_val,
2337
+ collate_fns=_collate_fns,
2338
+ sample_from_all=self._sampler.sample_from_all,
2339
+ num_buffer_sampled=self._sampler.num_buffer_sampled,
2340
+ p=p,
2341
+ )
2342
+
2343
+ try:
2344
+ samplers = self._sampler[index]
2345
+ writers = self._writer[index]
2346
+ storages = self._storage[index]
2347
+ if isinstance(index, torch.Tensor):
2348
+ _collate_fns = [self._collate_fns[i] for i in index.tolist()]
2349
+ else:
2350
+ _collate_fns = self._collate_fns[index]
2351
+ p = self._sampler._p[index] if self._sampler._p is not None else None
2352
+
2353
+ except IndexError:
2354
+ raise IndexError(self._INDEX_ERROR.format(type(index)))
2355
+
2356
+ return ReplayBufferEnsemble(
2357
+ samplers=samplers,
2358
+ writers=writers,
2359
+ storages=storages,
2360
+ transform=self._transform,
2361
+ batch_size=self._batch_size,
2362
+ collate_fn=self._collate_fn_val,
2363
+ collate_fns=_collate_fns,
2364
+ sample_from_all=self._sampler.sample_from_all,
2365
+ num_buffer_sampled=self._sampler.num_buffer_sampled,
2366
+ p=p,
2367
+ )
2368
+
2369
+ def __len__(self):
2370
+ return len(self._storage)
2371
+
2372
+ def __repr__(self):
2373
+ storages = textwrap.indent(f"storages={self._storage}", " " * 4)
2374
+ writers = textwrap.indent(f"writers={self._writer}", " " * 4)
2375
+ samplers = textwrap.indent(f"samplers={self._sampler}", " " * 4)
2376
+ return f"ReplayBufferEnsemble(\n{storages}, \n{samplers}, \n{writers}, \nbatch_size={self._batch_size}, \ntransform={self._transform}, \ncollate_fn={self._collate_fn_val})"