torchrl-0.11.0-cp314-cp314-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (394)
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/__init__.py +144 -0
  97. torchrl/_extension.py +74 -0
  98. torchrl/_torchrl.cp314-win_amd64.pyd +0 -0
  99. torchrl/_utils.py +1431 -0
  100. torchrl/collectors/__init__.py +48 -0
  101. torchrl/collectors/_base.py +1058 -0
  102. torchrl/collectors/_constants.py +88 -0
  103. torchrl/collectors/_multi_async.py +324 -0
  104. torchrl/collectors/_multi_base.py +1805 -0
  105. torchrl/collectors/_multi_sync.py +464 -0
  106. torchrl/collectors/_runner.py +581 -0
  107. torchrl/collectors/_single.py +2009 -0
  108. torchrl/collectors/_single_async.py +259 -0
  109. torchrl/collectors/collectors.py +62 -0
  110. torchrl/collectors/distributed/__init__.py +32 -0
  111. torchrl/collectors/distributed/default_configs.py +133 -0
  112. torchrl/collectors/distributed/generic.py +1306 -0
  113. torchrl/collectors/distributed/ray.py +1092 -0
  114. torchrl/collectors/distributed/rpc.py +1006 -0
  115. torchrl/collectors/distributed/sync.py +731 -0
  116. torchrl/collectors/distributed/utils.py +160 -0
  117. torchrl/collectors/llm/__init__.py +10 -0
  118. torchrl/collectors/llm/base.py +494 -0
  119. torchrl/collectors/llm/ray_collector.py +275 -0
  120. torchrl/collectors/llm/utils.py +36 -0
  121. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  122. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  123. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  124. torchrl/collectors/utils.py +433 -0
  125. torchrl/collectors/weight_update.py +591 -0
  126. torchrl/csrc/numpy_utils.h +38 -0
  127. torchrl/csrc/pybind.cpp +27 -0
  128. torchrl/csrc/segment_tree.h +458 -0
  129. torchrl/csrc/torch_utils.h +34 -0
  130. torchrl/csrc/utils.cpp +48 -0
  131. torchrl/csrc/utils.h +31 -0
  132. torchrl/data/__init__.py +187 -0
  133. torchrl/data/datasets/__init__.py +58 -0
  134. torchrl/data/datasets/atari_dqn.py +878 -0
  135. torchrl/data/datasets/common.py +281 -0
  136. torchrl/data/datasets/d4rl.py +489 -0
  137. torchrl/data/datasets/d4rl_infos.py +187 -0
  138. torchrl/data/datasets/gen_dgrl.py +375 -0
  139. torchrl/data/datasets/minari_data.py +643 -0
  140. torchrl/data/datasets/openml.py +177 -0
  141. torchrl/data/datasets/openx.py +798 -0
  142. torchrl/data/datasets/roboset.py +363 -0
  143. torchrl/data/datasets/utils.py +11 -0
  144. torchrl/data/datasets/vd4rl.py +432 -0
  145. torchrl/data/llm/__init__.py +34 -0
  146. torchrl/data/llm/dataset.py +491 -0
  147. torchrl/data/llm/history.py +1378 -0
  148. torchrl/data/llm/prompt.py +198 -0
  149. torchrl/data/llm/reward.py +225 -0
  150. torchrl/data/llm/topk.py +186 -0
  151. torchrl/data/llm/utils.py +543 -0
  152. torchrl/data/map/__init__.py +21 -0
  153. torchrl/data/map/hash.py +185 -0
  154. torchrl/data/map/query.py +204 -0
  155. torchrl/data/map/tdstorage.py +363 -0
  156. torchrl/data/map/tree.py +1434 -0
  157. torchrl/data/map/utils.py +103 -0
  158. torchrl/data/postprocs/__init__.py +8 -0
  159. torchrl/data/postprocs/postprocs.py +391 -0
  160. torchrl/data/replay_buffers/__init__.py +99 -0
  161. torchrl/data/replay_buffers/checkpointers.py +622 -0
  162. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  163. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  164. torchrl/data/replay_buffers/samplers.py +2578 -0
  165. torchrl/data/replay_buffers/scheduler.py +265 -0
  166. torchrl/data/replay_buffers/storages.py +2412 -0
  167. torchrl/data/replay_buffers/utils.py +1042 -0
  168. torchrl/data/replay_buffers/writers.py +781 -0
  169. torchrl/data/tensor_specs.py +7101 -0
  170. torchrl/data/utils.py +334 -0
  171. torchrl/envs/__init__.py +265 -0
  172. torchrl/envs/async_envs.py +1105 -0
  173. torchrl/envs/batched_envs.py +3093 -0
  174. torchrl/envs/common.py +4241 -0
  175. torchrl/envs/custom/__init__.py +11 -0
  176. torchrl/envs/custom/chess.py +617 -0
  177. torchrl/envs/custom/llm.py +214 -0
  178. torchrl/envs/custom/pendulum.py +401 -0
  179. torchrl/envs/custom/san_moves.txt +29274 -0
  180. torchrl/envs/custom/tictactoeenv.py +288 -0
  181. torchrl/envs/env_creator.py +263 -0
  182. torchrl/envs/gym_like.py +752 -0
  183. torchrl/envs/libs/__init__.py +68 -0
  184. torchrl/envs/libs/_gym_utils.py +326 -0
  185. torchrl/envs/libs/brax.py +846 -0
  186. torchrl/envs/libs/dm_control.py +544 -0
  187. torchrl/envs/libs/envpool.py +447 -0
  188. torchrl/envs/libs/gym.py +2239 -0
  189. torchrl/envs/libs/habitat.py +138 -0
  190. torchrl/envs/libs/isaac_lab.py +87 -0
  191. torchrl/envs/libs/isaacgym.py +203 -0
  192. torchrl/envs/libs/jax_utils.py +166 -0
  193. torchrl/envs/libs/jumanji.py +963 -0
  194. torchrl/envs/libs/meltingpot.py +599 -0
  195. torchrl/envs/libs/openml.py +153 -0
  196. torchrl/envs/libs/openspiel.py +652 -0
  197. torchrl/envs/libs/pettingzoo.py +1042 -0
  198. torchrl/envs/libs/procgen.py +351 -0
  199. torchrl/envs/libs/robohive.py +429 -0
  200. torchrl/envs/libs/smacv2.py +645 -0
  201. torchrl/envs/libs/unity_mlagents.py +891 -0
  202. torchrl/envs/libs/utils.py +147 -0
  203. torchrl/envs/libs/vmas.py +813 -0
  204. torchrl/envs/llm/__init__.py +63 -0
  205. torchrl/envs/llm/chat.py +730 -0
  206. torchrl/envs/llm/datasets/README.md +4 -0
  207. torchrl/envs/llm/datasets/__init__.py +17 -0
  208. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  209. torchrl/envs/llm/datasets/ifeval.py +274 -0
  210. torchrl/envs/llm/envs.py +789 -0
  211. torchrl/envs/llm/libs/README.md +3 -0
  212. torchrl/envs/llm/libs/__init__.py +8 -0
  213. torchrl/envs/llm/libs/mlgym.py +869 -0
  214. torchrl/envs/llm/reward/__init__.py +10 -0
  215. torchrl/envs/llm/reward/gsm8k.py +324 -0
  216. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  217. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  218. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  222. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  223. torchrl/envs/llm/transforms/__init__.py +55 -0
  224. torchrl/envs/llm/transforms/browser.py +292 -0
  225. torchrl/envs/llm/transforms/dataloading.py +859 -0
  226. torchrl/envs/llm/transforms/format.py +73 -0
  227. torchrl/envs/llm/transforms/kl.py +1544 -0
  228. torchrl/envs/llm/transforms/policy_version.py +189 -0
  229. torchrl/envs/llm/transforms/reason.py +323 -0
  230. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  231. torchrl/envs/llm/transforms/tools.py +1955 -0
  232. torchrl/envs/model_based/__init__.py +9 -0
  233. torchrl/envs/model_based/common.py +180 -0
  234. torchrl/envs/model_based/dreamer.py +112 -0
  235. torchrl/envs/transforms/__init__.py +147 -0
  236. torchrl/envs/transforms/functional.py +48 -0
  237. torchrl/envs/transforms/gym_transforms.py +203 -0
  238. torchrl/envs/transforms/module.py +341 -0
  239. torchrl/envs/transforms/r3m.py +372 -0
  240. torchrl/envs/transforms/ray_service.py +663 -0
  241. torchrl/envs/transforms/rb_transforms.py +214 -0
  242. torchrl/envs/transforms/transforms.py +11835 -0
  243. torchrl/envs/transforms/utils.py +94 -0
  244. torchrl/envs/transforms/vc1.py +307 -0
  245. torchrl/envs/transforms/vecnorm.py +845 -0
  246. torchrl/envs/transforms/vip.py +407 -0
  247. torchrl/envs/utils.py +1718 -0
  248. torchrl/envs/vec_envs.py +11 -0
  249. torchrl/modules/__init__.py +206 -0
  250. torchrl/modules/distributions/__init__.py +73 -0
  251. torchrl/modules/distributions/continuous.py +830 -0
  252. torchrl/modules/distributions/discrete.py +908 -0
  253. torchrl/modules/distributions/truncated_normal.py +187 -0
  254. torchrl/modules/distributions/utils.py +233 -0
  255. torchrl/modules/llm/__init__.py +62 -0
  256. torchrl/modules/llm/backends/__init__.py +65 -0
  257. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  258. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  259. torchrl/modules/llm/backends/vllm/base.py +72 -0
  260. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  261. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  262. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  263. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  264. torchrl/modules/llm/policies/__init__.py +28 -0
  265. torchrl/modules/llm/policies/common.py +1809 -0
  266. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  267. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  268. torchrl/modules/llm/utils.py +23 -0
  269. torchrl/modules/mcts/__init__.py +21 -0
  270. torchrl/modules/mcts/scores.py +579 -0
  271. torchrl/modules/models/__init__.py +86 -0
  272. torchrl/modules/models/batchrenorm.py +119 -0
  273. torchrl/modules/models/decision_transformer.py +179 -0
  274. torchrl/modules/models/exploration.py +731 -0
  275. torchrl/modules/models/llm.py +156 -0
  276. torchrl/modules/models/model_based.py +596 -0
  277. torchrl/modules/models/models.py +1712 -0
  278. torchrl/modules/models/multiagent.py +1067 -0
  279. torchrl/modules/models/recipes/impala.py +185 -0
  280. torchrl/modules/models/utils.py +162 -0
  281. torchrl/modules/planners/__init__.py +10 -0
  282. torchrl/modules/planners/cem.py +228 -0
  283. torchrl/modules/planners/common.py +73 -0
  284. torchrl/modules/planners/mppi.py +265 -0
  285. torchrl/modules/tensordict_module/__init__.py +89 -0
  286. torchrl/modules/tensordict_module/actors.py +2457 -0
  287. torchrl/modules/tensordict_module/common.py +529 -0
  288. torchrl/modules/tensordict_module/exploration.py +814 -0
  289. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  290. torchrl/modules/tensordict_module/rnn.py +1639 -0
  291. torchrl/modules/tensordict_module/sequence.py +132 -0
  292. torchrl/modules/tensordict_module/world_models.py +34 -0
  293. torchrl/modules/utils/__init__.py +38 -0
  294. torchrl/modules/utils/mappings.py +9 -0
  295. torchrl/modules/utils/utils.py +89 -0
  296. torchrl/objectives/__init__.py +78 -0
  297. torchrl/objectives/a2c.py +659 -0
  298. torchrl/objectives/common.py +753 -0
  299. torchrl/objectives/cql.py +1346 -0
  300. torchrl/objectives/crossq.py +710 -0
  301. torchrl/objectives/ddpg.py +453 -0
  302. torchrl/objectives/decision_transformer.py +371 -0
  303. torchrl/objectives/deprecated.py +516 -0
  304. torchrl/objectives/dqn.py +683 -0
  305. torchrl/objectives/dreamer.py +488 -0
  306. torchrl/objectives/functional.py +48 -0
  307. torchrl/objectives/gail.py +258 -0
  308. torchrl/objectives/iql.py +996 -0
  309. torchrl/objectives/llm/__init__.py +30 -0
  310. torchrl/objectives/llm/grpo.py +846 -0
  311. torchrl/objectives/llm/sft.py +482 -0
  312. torchrl/objectives/multiagent/__init__.py +8 -0
  313. torchrl/objectives/multiagent/qmixer.py +396 -0
  314. torchrl/objectives/ppo.py +1669 -0
  315. torchrl/objectives/redq.py +683 -0
  316. torchrl/objectives/reinforce.py +530 -0
  317. torchrl/objectives/sac.py +1580 -0
  318. torchrl/objectives/td3.py +570 -0
  319. torchrl/objectives/td3_bc.py +625 -0
  320. torchrl/objectives/utils.py +782 -0
  321. torchrl/objectives/value/__init__.py +28 -0
  322. torchrl/objectives/value/advantages.py +1956 -0
  323. torchrl/objectives/value/functional.py +1459 -0
  324. torchrl/objectives/value/utils.py +360 -0
  325. torchrl/record/__init__.py +17 -0
  326. torchrl/record/loggers/__init__.py +23 -0
  327. torchrl/record/loggers/common.py +48 -0
  328. torchrl/record/loggers/csv.py +226 -0
  329. torchrl/record/loggers/mlflow.py +142 -0
  330. torchrl/record/loggers/tensorboard.py +139 -0
  331. torchrl/record/loggers/trackio.py +163 -0
  332. torchrl/record/loggers/utils.py +78 -0
  333. torchrl/record/loggers/wandb.py +214 -0
  334. torchrl/record/recorder.py +554 -0
  335. torchrl/services/__init__.py +79 -0
  336. torchrl/services/base.py +109 -0
  337. torchrl/services/ray_service.py +453 -0
  338. torchrl/testing/__init__.py +107 -0
  339. torchrl/testing/assertions.py +179 -0
  340. torchrl/testing/dist_utils.py +122 -0
  341. torchrl/testing/env_creators.py +227 -0
  342. torchrl/testing/env_helper.py +35 -0
  343. torchrl/testing/gym_helpers.py +156 -0
  344. torchrl/testing/llm_mocks.py +119 -0
  345. torchrl/testing/mocking_classes.py +2720 -0
  346. torchrl/testing/modules.py +295 -0
  347. torchrl/testing/mp_helpers.py +15 -0
  348. torchrl/testing/ray_helpers.py +293 -0
  349. torchrl/testing/utils.py +190 -0
  350. torchrl/trainers/__init__.py +42 -0
  351. torchrl/trainers/algorithms/__init__.py +11 -0
  352. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  353. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  354. torchrl/trainers/algorithms/configs/common.py +41 -0
  355. torchrl/trainers/algorithms/configs/data.py +308 -0
  356. torchrl/trainers/algorithms/configs/envs.py +104 -0
  357. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  358. torchrl/trainers/algorithms/configs/logging.py +80 -0
  359. torchrl/trainers/algorithms/configs/modules.py +570 -0
  360. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  361. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  362. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  363. torchrl/trainers/algorithms/configs/utils.py +252 -0
  364. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  365. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  366. torchrl/trainers/algorithms/ppo.py +373 -0
  367. torchrl/trainers/algorithms/sac.py +308 -0
  368. torchrl/trainers/helpers/__init__.py +40 -0
  369. torchrl/trainers/helpers/collectors.py +416 -0
  370. torchrl/trainers/helpers/envs.py +573 -0
  371. torchrl/trainers/helpers/logger.py +33 -0
  372. torchrl/trainers/helpers/losses.py +132 -0
  373. torchrl/trainers/helpers/models.py +658 -0
  374. torchrl/trainers/helpers/replay_buffer.py +59 -0
  375. torchrl/trainers/helpers/trainers.py +301 -0
  376. torchrl/trainers/trainers.py +2052 -0
  377. torchrl/weight_update/__init__.py +33 -0
  378. torchrl/weight_update/_distributed.py +749 -0
  379. torchrl/weight_update/_mp.py +624 -0
  380. torchrl/weight_update/_noupdate.py +102 -0
  381. torchrl/weight_update/_ray.py +1032 -0
  382. torchrl/weight_update/_rpc.py +284 -0
  383. torchrl/weight_update/_shared.py +891 -0
  384. torchrl/weight_update/llm/__init__.py +32 -0
  385. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  386. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  387. torchrl/weight_update/utils.py +73 -0
  388. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  389. torchrl-0.11.0.dist-info/LICENSE +21 -0
  390. torchrl-0.11.0.dist-info/METADATA +1307 -0
  391. torchrl-0.11.0.dist-info/RECORD +394 -0
  392. torchrl-0.11.0.dist-info/WHEEL +5 -0
  393. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  394. torchrl-0.11.0.dist-info/top_level.txt +7 -0
torchrl/data/replay_buffers/samplers.py
@@ -0,0 +1,2578 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+ from __future__ import annotations
+
+ import json
+ import textwrap
+ import warnings
+ from abc import ABC, abstractmethod
+ from collections import OrderedDict
+ from copy import copy, deepcopy
+ from multiprocessing.context import get_spawning_popen
+ from pathlib import Path
+ from typing import Any
+
+ import numpy as np
+ import torch
+ from pyvers import implement_for
+ from tensordict import MemoryMappedTensor, TensorDict
+ from tensordict.utils import NestedKey
+ from torch.utils._pytree import tree_map
+ from torchrl._extension import EXTENSION_WARNING
+ from torchrl._utils import _replace_last, logger, rl_warnings
+ from torchrl.data.replay_buffers.storages import Storage, StorageEnsemble, TensorStorage
+ from torchrl.data.replay_buffers.utils import _auto_device, _is_int, unravel_index
+
+ try:
+     from torchrl._torchrl import (
+         MinSegmentTreeFp32,
+         MinSegmentTreeFp64,
+         SumSegmentTreeFp32,
+         SumSegmentTreeFp64,
+     )
+ except ImportError:
+     # Make default values
+     MinSegmentTreeFp32 = None
+     MinSegmentTreeFp64 = None
+     SumSegmentTreeFp32 = None
+     SumSegmentTreeFp64 = None
+
+ _EMPTY_STORAGE_ERROR = "Cannot sample from an empty storage."
+
+
+ class Sampler(ABC):
+     """A generic sampler base class for composable Replay Buffers."""
+
+     # Some samplers - mainly those without replacement -
+     # need to keep track of the number of remaining batches
+     _remaining_batches = int(torch.iinfo(torch.int64).max)
+
+     # The RNG is set by the replay buffer
+     _rng: torch.Generator | None = None
+
+     @abstractmethod
+     def sample(self, storage: Storage, batch_size: int) -> tuple[Any, dict]:
+         ...
+
+     def add(self, index: int) -> None:
+         return
+
+     def extend(self, index: torch.Tensor) -> None:
+         return
+
+     def update_priority(
+         self,
+         index: int | torch.Tensor,
+         priority: float | torch.Tensor,
+         *,
+         storage: Storage | None = None,
+     ) -> dict | None:
+         warnings.warn(
+             f"Calling update_priority() on a sampler {type(self).__name__} that is not prioritized. Make sure this is the intended behavior."
+         )
+         return
+
+     def mark_update(
+         self, index: int | torch.Tensor, *, storage: Storage | None = None
+     ) -> None:
+         return
+
+     @property
+     def default_priority(self) -> float:
+         return 1.0
+
+     @abstractmethod
+     def state_dict(self) -> dict[str, Any]:
+         ...
+
+     @abstractmethod
+     def load_state_dict(self, state_dict: dict[str, Any]) -> None:
+         ...
+
+     @property
+     def ran_out(self) -> bool:
+         # by default, samplers never run out
+         return False
+
+     @abstractmethod
+     def _empty(self):
+         ...
+
+     @abstractmethod
+     def dumps(self, path):
+         ...
+
+     @abstractmethod
+     def loads(self, path):
+         ...
+
+     def __repr__(self):
+         return f"{self.__class__.__name__}()"
+
+     def __getstate__(self):
+         state = copy(self.__dict__)
+         state["_rng"] = None
+         return state
+
+
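The `Sampler` ABC above is the full contract a custom sampler has to satisfy: `sample()` returns an `(index, info)` pair, and the remaining abstract methods handle resetting and (de)serialization. As a minimal sketch of how the interface composes (not part of the package: the `RecentSampler` name and its linear recency weighting are assumptions for illustration, and a 1-dimensional storage is assumed):

    import torch

    from torchrl.data.replay_buffers.samplers import Sampler


    class RecentSampler(Sampler):
        """Hypothetical sampler biased toward recently written items (illustrative sketch)."""

        def sample(self, storage, batch_size):
            if len(storage) == 0:
                raise RuntimeError("Cannot sample from an empty storage.")
            n = len(storage)
            # Linearly increasing weights: treats higher indices as newer,
            # which only holds before a circular storage wraps around.
            weights = torch.arange(1, n + 1, dtype=torch.float)
            index = torch.multinomial(weights, batch_size, replacement=True)
            return index, {}  # (indices, info dict), per the ABC

        # Stateless sampler: nothing to reset or persist.
        def _empty(self):
            pass

        def dumps(self, path):
            pass

        def loads(self, path):
            pass

        def state_dict(self):
            return {}

        def load_state_dict(self, state_dict):
            return
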
+ class RandomSampler(Sampler):
+     """A uniformly random sampler for composable replay buffers.
+
+     Args:
+         batch_size (int, optional): if provided, the batch size to be used by
+             the replay buffer when calling :meth:`ReplayBuffer.sample`.
+
+     """
+
+     def sample(self, storage: Storage, batch_size: int) -> tuple[torch.Tensor, dict]:
+         if len(storage) == 0:
+             raise RuntimeError(_EMPTY_STORAGE_ERROR)
+         index = storage._rand_given_ndim(batch_size)
+         return index, {}
+
+     def _empty(self):
+         pass
+
+     def dumps(self, path):
+         # no op
+         ...
+
+     def loads(self, path):
+         # no op
+         ...
+
+     def state_dict(self) -> dict[str, Any]:
+         return {}
+
+     def load_state_dict(self, state_dict: dict[str, Any]) -> None:
+         return
+
+
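A short usage sketch of the uniform behavior documented above; the storage size and the "obs" field name are arbitrary choices for the example:

    import torch
    from tensordict import TensorDict

    from torchrl.data.replay_buffers import LazyTensorStorage, RandomSampler, ReplayBuffer

    # Ten items in storage; each sample() call draws uniformly, with replacement.
    rb = ReplayBuffer(storage=LazyTensorStorage(100), sampler=RandomSampler())
    rb.extend(TensorDict({"obs": torch.arange(10)}, batch_size=[10]))
    batch = rb.sample(4)
    print(batch["obs"])  # e.g. tensor([7, 2, 2, 9]); duplicates are possible
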
+ class SamplerWithoutReplacement(Sampler):
+     """A data-consuming sampler that ensures that the same sample is not present in consecutive batches.
+
+     Args:
+         drop_last (bool, optional): if ``True``, the last incomplete sample (if any) will be dropped.
+             If ``False``, this last sample will be kept and (unlike with torch dataloaders)
+             completed with other samples from a fresh indices permutation.
+             Defaults to ``False``.
+         shuffle (bool, optional): if ``False``, the items are not randomly
+             permuted. This makes it possible to iterate over the replay buffer in the
+             order the data was collected. Defaults to ``True``.
+
+     *Caution*: If the size of the storage changes in between two calls, the samples will be re-shuffled
+     (as we can't generally keep track of which samples have been sampled before and which haven't).
+
+     Similarly, it is expected that the storage content remains the same in between two calls,
+     but this is not enforced.
+
+     When the sampler reaches the end of the list of available indices, a new sample order
+     will be generated and the resulting indices will be completed with this new draw, which
+     can lead to duplicated indices, unless the :obj:`drop_last` argument is set to ``True``.
+
+     """
+
+     def __init__(self, drop_last: bool = False, shuffle: bool = True):
+         self._sample_list = None
+         self.len_storage = 0
+         self.drop_last = drop_last
+         self._ran_out = False
+         self.shuffle = shuffle
+
+     def dumps(self, path):
+         path = Path(path)
+         path.mkdir(exist_ok=True)
+
+         TensorDict(self.state_dict()).memmap(path)
+
+     def loads(self, path):
+         sd = TensorDict.load_memmap(path).to_dict()
+         self.load_state_dict(sd)
+
+     def _get_sample_list(self, storage: Storage, len_storage: int, batch_size: int):
+         if storage is None:
+             device = self._sample_list.device
+         else:
+             device = storage.device if hasattr(storage, "device") else None
+
+         if self.shuffle:
+             _sample_list = torch.randperm(
+                 len_storage, device=device, generator=self._rng
+             )
+         else:
+             _sample_list = torch.arange(len_storage, device=device)
+         self._sample_list = _sample_list
+         if self.drop_last:
+             self._remaining_batches = self._sample_list.numel() // batch_size
+         else:
+             self._remaining_batches = -(self._sample_list.numel() // -batch_size)
+
+     def _single_sample(self, len_storage, batch_size):
+         index = self._sample_list[:batch_size]
+         self._sample_list = self._sample_list[batch_size:]
+         if self.drop_last:
+             self._remaining_batches = self._sample_list.numel() // batch_size
+         else:
+             self._remaining_batches = -(self._sample_list.numel() // -batch_size)
+
+         # check if we have enough elements for one more batch, assuming same batch size
+         # will be used each time sample is called
+         if self._sample_list.shape[0] == 0 or (
+             self.drop_last and len(self._sample_list) < batch_size
+         ):
+             self.ran_out = True
+             self._get_sample_list(
+                 storage=None, len_storage=len_storage, batch_size=batch_size
+             )
+         else:
+             self.ran_out = False
+         return index
+
+     def _storage_len(self, storage):
+         return len(storage)
+
+     def sample(
+         self, storage: Storage, batch_size: int
+     ) -> tuple[Any, dict]:  # noqa: F811
+         len_storage = self._storage_len(storage)
+         if len_storage == 0:
+             raise RuntimeError(_EMPTY_STORAGE_ERROR)
+         if self.len_storage != len_storage or self._sample_list is None:
+             self._get_sample_list(storage, len_storage, batch_size=batch_size)
+         if len_storage < batch_size and self.drop_last:
+             raise ValueError(
+                 f"The batch size ({batch_size}) is greater than the storage capacity ({len_storage}). "
+                 "This makes it impossible to return a sample without repeating indices. "
+                 "Consider changing the sampler class or setting the 'drop_last' argument to False."
+             )
+         self.len_storage = len_storage
+         index = self._single_sample(len_storage, batch_size)
+         if storage.ndim > 1:
+             index = unravel_index(index, storage.shape)
+         # we 'always' return the indices. The 'drop_last' just instructs the
+         # sampler to turn to `ran_out = True` whenever the next sample
+         # will be too short. This will be read by the replay buffer
+         # as a signal for an early break of the __iter__().
+         return index, {}
+
+     @property
+     def ran_out(self):
+         return self._ran_out
+
+     @ran_out.setter
+     def ran_out(self, value):
+         self._ran_out = value
+
+     def _empty(self):
+         self._sample_list = None
+         self.len_storage = 0
+         self._ran_out = False
+
+     def state_dict(self) -> dict[str, Any]:
+         return OrderedDict(
+             len_storage=self.len_storage,
+             _sample_list=self._sample_list,
+             drop_last=self.drop_last,
+             _ran_out=self._ran_out,
+         )
+
+     def load_state_dict(self, state_dict: dict[str, Any]) -> None:
+         self.len_storage = state_dict["len_storage"]
+         self._sample_list = state_dict["_sample_list"]
+         self.drop_last = state_dict["drop_last"]
+         self._ran_out = state_dict["_ran_out"]
+
+     def __repr__(self):
+         if self._sample_list is not None:
+             perc = len(self._sample_list) / self.len_storage * 100
+         else:
+             perc = 0.0
+         return f"{self.__class__.__name__}({perc: 4.4f}% sampled)"
+
+
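Because `ran_out` flips exactly when the index permutation is exhausted (see the comment at the end of `sample()` above), iterating over a buffer equipped with this sampler behaves like one epoch over the data. A minimal sketch, with arbitrary sizes:

    import torch
    from tensordict import TensorDict

    from torchrl.data.replay_buffers import (
        LazyTensorStorage,
        ReplayBuffer,
        SamplerWithoutReplacement,
    )

    rb = ReplayBuffer(
        storage=LazyTensorStorage(8),
        sampler=SamplerWithoutReplacement(drop_last=True),
        batch_size=3,
    )
    rb.extend(TensorDict({"obs": torch.arange(8)}, batch_size=[8]))

    seen = []
    for batch in rb:  # __iter__ breaks early once the sampler runs out
        seen.extend(batch["obs"].tolist())
    # 8 items in batches of 3 with drop_last=True -> two batches, no repeats
    assert len(seen) == 6 and len(set(seen)) == 6
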
+ class PrioritizedSampler(Sampler):
+     r"""Prioritized sampler for replay buffer.
+
+     This sampler implements Prioritized Experience Replay (PER) as presented in
+     "Schaul, T.; Quan, J.; Antonoglou, I.; and Silver, D. 2015. Prioritized experience replay."
+     (https://arxiv.org/abs/1511.05952)
+
+     **Core Idea**: Instead of sampling experiences uniformly from the replay buffer,
+     PER samples experiences with probability proportional to their "importance" - typically
+     measured by the magnitude of their temporal-difference (TD) error. This prioritization
+     can lead to faster learning by focusing on experiences that are most informative.
+
+     **How it works**:
+
+     1. Each experience is assigned a priority based on its TD error: :math:`p_i = |\delta_i| + \epsilon`
+     2. Sampling probability is computed as: :math:`P(i) = \frac{p_i^\alpha}{\sum_j p_j^\alpha}`
+     3. Importance sampling weights correct for the bias: :math:`w_i = (N \cdot P(i))^{-\beta}`
+
+     Args:
+         max_capacity (int): maximum capacity of the buffer.
+         alpha (:obj:`float`): exponent :math:`\alpha` determines how much prioritization is used.
+             - :math:`\alpha = 0`: uniform sampling (no prioritization)
+             - :math:`\alpha = 1`: full prioritization based on TD error magnitude
+             - Typical values: 0.4-0.7 for balanced prioritization
+             - Higher :math:`\alpha` means more aggressive prioritization of high-error experiences
+         beta (:obj:`float`): importance sampling negative exponent :math:`\beta`.
+             - :math:`\beta` controls the correction for the bias introduced by prioritization
+             - :math:`\beta = 0`: no correction (biased towards high-priority samples)
+             - :math:`\beta = 1`: full correction (unbiased but potentially unstable)
+             - Typical values: start at 0.4-0.6 and anneal to 1.0 during training
+             - Lower :math:`\beta` early in training provides stability, higher :math:`\beta` later reduces bias
+         eps (:obj:`float`, optional): small constant added to priorities to ensure
+             no experience has zero priority. This prevents experiences from never
+             being sampled. Defaults to 1e-8.
+         dtype (torch.dtype, optional): the dtype used for the priority
+             segment trees. Must be ``torch.float`` (default) or ``torch.double``.
+         reduction (str, optional): the reduction method for multidimensional
+             tensordicts (i.e., stored trajectories). Can be one of "max", "min",
+             "median" or "mean".
+         max_priority_within_buffer (bool, optional): if ``True``, the max-priority
+             is tracked within the buffer. When ``False``, the max-priority tracks
+             the maximum value since the instantiation of the sampler.
+
+     **Parameter Guidelines**:
+
+     - **:math:`\alpha` (alpha)**: Controls how much to prioritize high-error experiences.
+         0.4-0.7: Good balance between learning speed and stability.
+         1.0: Maximum prioritization (may be unstable).
+         0.0: Uniform sampling (no prioritization benefit).
+
+     - **:math:`\beta` (beta)**: Controls importance sampling correction.
+         Start at 0.4-0.6 for training stability.
+         Anneal to 1.0 over training to reduce bias.
+         Lower values = more stable but biased.
+         Higher values = less biased but potentially unstable.
+
+     - **:math:`\epsilon`**: Small constant to prevent zero priorities.
+         1e-8: Good default value.
+         Too small: may cause numerical issues.
+         Too large: reduces prioritization effect.
+
+     Examples:
+         >>> from torchrl.data.replay_buffers import ReplayBuffer, LazyTensorStorage, PrioritizedSampler
+         >>> from tensordict import TensorDict
+         >>> rb = ReplayBuffer(storage=LazyTensorStorage(10), sampler=PrioritizedSampler(max_capacity=10, alpha=1.0, beta=1.0))
+         >>> priority = torch.tensor([0, 1000])
+         >>> data_0 = TensorDict({"reward": 0, "obs": [0], "action": [0], "priority": priority[0]}, [])
+         >>> data_1 = TensorDict({"reward": 1, "obs": [1], "action": [2], "priority": priority[1]}, [])
+         >>> rb.add(data_0)
+         >>> rb.add(data_1)
+         >>> rb.update_priority(torch.tensor([0, 1]), priority=priority)
+         >>> sample, info = rb.sample(10, return_info=True)
+         >>> print(sample)
+         TensorDict(
+             fields={
+                 action: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.int64, is_shared=False),
+                 obs: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.int64, is_shared=False),
+                 priority: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.int64, is_shared=False),
+                 reward: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.int64, is_shared=False)},
+             batch_size=torch.Size([10]),
+             device=cpu,
+             is_shared=False)
+         >>> print(info)
+         {'priority_weight': array([1.e-11, 1.e-11, 1.e-11, 1.e-11, 1.e-11, 1.e-11, 1.e-11, 1.e-11,
+                1.e-11, 1.e-11], dtype=float32), 'index': array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1])}
+
+     .. note:: Using a :class:`~torchrl.data.replay_buffers.TensorDictReplayBuffer` can simplify the
+         process of updating the priorities:
+
+             >>> from torchrl.data.replay_buffers import TensorDictReplayBuffer as TDRB, LazyTensorStorage, PrioritizedSampler
+             >>> from tensordict import TensorDict
+             >>> rb = TDRB(
+             ...     storage=LazyTensorStorage(10),
+             ...     sampler=PrioritizedSampler(max_capacity=10, alpha=1.0, beta=1.0),
+             ...     priority_key="priority",  # This kwarg isn't present in regular RBs
+             ... )
+             >>> priority = torch.tensor([0, 1000])
+             >>> data_0 = TensorDict({"reward": 0, "obs": [0], "action": [0], "priority": priority[0]}, [])
+             >>> data_1 = TensorDict({"reward": 1, "obs": [1], "action": [2], "priority": priority[1]}, [])
+             >>> data = torch.stack([data_0, data_1])
+             >>> rb.extend(data)
+             >>> rb.update_priority(data)  # Reads the "priority" key as indicated in the constructor
+             >>> sample, info = rb.sample(10, return_info=True)
+             >>> print(sample['index'])  # The index is packed with the tensordict
+             tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1])
+
+     """
+
+     def __init__(
+         self,
+         max_capacity: int,
+         alpha: float,
+         beta: float,
+         eps: float = 1e-8,
+         dtype: torch.dtype = torch.float,
+         reduction: str = "max",
+         max_priority_within_buffer: bool = False,
+     ) -> None:
+         if alpha < 0:
+             raise ValueError(
+                 f"alpha must be greater than or equal to 0, got alpha={alpha}"
+             )
+         if beta < 0:
+             raise ValueError(f"beta must be greater than or equal to 0, got beta={beta}")
+
+         self._max_capacity = max_capacity
+         self._alpha = alpha
+         self._beta = beta
+         self._eps = eps
+         self.reduction = reduction
+         self.dtype = dtype
+         self._max_priority_within_buffer = max_priority_within_buffer
+         self._init()
+         if rl_warnings() and SumSegmentTreeFp32 is None:
+             logger.warning(EXTENSION_WARNING)
+
+     def __repr__(self):
+         return f"{self.__class__.__name__}(alpha={self._alpha}, beta={self._beta}, eps={self._eps}, reduction={self.reduction})"
+
+     @property
+     def max_size(self):
+         return self._max_capacity
+
+     @property
+     def alpha(self):
+         return self._alpha
+
+     @alpha.setter
+     def alpha(self, value):
+         self._alpha = value
+
+     @property
+     def beta(self):
+         return self._beta
+
+     @beta.setter
+     def beta(self, value):
+         self._beta = value
+
+     def __getstate__(self):
+         if get_spawning_popen() is not None:
+             raise RuntimeError(
+                 f"Samplers of type {type(self)} cannot be shared between processes."
+             )
+         return super().__getstate__()
+
+     def _init(self) -> None:
+         if SumSegmentTreeFp32 is None:
+             raise RuntimeError(
+                 "SumSegmentTreeFp32 is not available. See warning above."
+             )
+         if MinSegmentTreeFp32 is None:
+             raise RuntimeError(
+                 "MinSegmentTreeFp32 is not available. See warning above."
+             )
+         if SumSegmentTreeFp64 is None:
+             raise RuntimeError(
+                 "SumSegmentTreeFp64 is not available. See warning above."
+             )
+         if MinSegmentTreeFp64 is None:
+             raise RuntimeError(
+                 "MinSegmentTreeFp64 is not available. See warning above."
+             )
+         if self.dtype in (torch.float, torch.FloatType, torch.float32):
+             self._sum_tree = SumSegmentTreeFp32(self._max_capacity)
+             self._min_tree = MinSegmentTreeFp32(self._max_capacity)
+         elif self.dtype in (torch.double, torch.DoubleTensor, torch.float64):
+             self._sum_tree = SumSegmentTreeFp64(self._max_capacity)
+             self._min_tree = MinSegmentTreeFp64(self._max_capacity)
+         else:
+             raise NotImplementedError(
+                 f"dtype {self.dtype} not supported by PrioritizedSampler"
+             )
+         self._max_priority = None
+
+     def _empty(self) -> None:
+         self._init()
+
+     @property
+     def _max_priority(self) -> tuple[float | None, int | None]:
+         max_priority_index = self.__dict__.get("_max_priority")
+         if max_priority_index is None:
+             return (None, None)
+         return max_priority_index
+
+     @_max_priority.setter
+     def _max_priority(self, value: tuple[float | None, int | None]) -> None:
+         self.__dict__["_max_priority"] = value
+
+     def _maybe_erase_max_priority(
+         self, index: torch.Tensor | int | slice | tuple
+     ) -> None:
+         if not self._max_priority_within_buffer:
+             return
+         max_priority_index = self._max_priority[1]
+         if max_priority_index is None:
+             return
+
+         def check_index(index=index, max_priority_index=max_priority_index):
+             if isinstance(index, torch.Tensor):
+                 # index can be 1d or 2d
+                 if index.ndim == 1:
+                     is_overwritten = (index == max_priority_index).any()
+                 else:
+                     is_overwritten = (index == max_priority_index).all(-1).any()
+             elif isinstance(index, int):
+                 is_overwritten = index == max_priority_index
+             elif isinstance(index, slice):
+                 # This won't work if called recursively
+                 is_overwritten = max_priority_index in range(
+                     *index.indices(self._max_capacity)
+                 )
+             elif isinstance(index, tuple):
+                 is_overwritten = isinstance(max_priority_index, tuple)
+                 if is_overwritten:
+                     for idx, mpi in zip(index, max_priority_index):
+                         is_overwritten &= check_index(idx, mpi)
+             else:
+                 raise TypeError(f"index of type {type(index)} is not recognized.")
+             return is_overwritten
+
+         is_overwritten = check_index()
+         if is_overwritten:
+             self._max_priority = None
+
+     @property
+     def default_priority(self) -> float:
+         mp = self._max_priority[0]
+         if mp is None:
+             mp = 1
+         return (mp + self._eps) ** self._alpha
+
+     def sample(self, storage: Storage, batch_size: int) -> tuple[torch.Tensor, dict]:
+         if len(storage) == 0:
+             raise RuntimeError(_EMPTY_STORAGE_ERROR)
+         p_sum = self._sum_tree.query(0, len(storage))
+         p_min = self._min_tree.query(0, len(storage))
+
+         if p_sum <= 0:
+             raise RuntimeError("non-positive p_sum")
+         if p_min <= 0:
+             raise RuntimeError("non-positive p_min")
+         # For some undefined reason, only np.random works here.
+         # All PT attempts fail, even when subsequently transformed into numpy
+         if self._rng is None:
+             mass = np.random.uniform(0.0, p_sum, size=batch_size)
+         else:
+             mass = torch.rand(batch_size, generator=self._rng) * p_sum
+
+         # mass = torch.zeros(batch_size, dtype=torch.double).uniform_(0.0, p_sum)
+         # mass = torch.rand(batch_size).mul_(p_sum)
+         index = self._sum_tree.scan_lower_bound(mass)
+         index = torch.as_tensor(index)
+         if not index.ndim:
+             index = index.unsqueeze(0)
+         index.clamp_max_(len(storage) - 1)
+         weight = torch.as_tensor(self._sum_tree[index])
+         # get indices where weight is 0
+         zero_weight = weight == 0
+         while zero_weight.any():
+             index = torch.where(zero_weight, index - 1, index)
+             if (index < 0).any():
+                 raise RuntimeError("Failed to find a suitable index")
+             weight = torch.as_tensor(self._sum_tree[index])
+             zero_weight = weight == 0
+
+         # Importance sampling weight formula:
+         #   w_i = (p_i / sum(p) * N) ^ (-beta)
+         #   weight_i = w_i / max(w)
+         #   weight_i = (p_i / sum(p) * N) ^ (-beta) /
+         #       ((min(p) / sum(p) * N) ^ (-beta))
+         #   weight_i = ((p_i / sum(p) * N) / (min(p) / sum(p) * N)) ^ (-beta)
+         #   weight_i = (p_i / min(p)) ^ (-beta)
+         # weight = np.power(weight / (p_min + self._eps), -self._beta)
+         weight = torch.pow(weight / p_min, -self._beta)
+         if storage.ndim > 1:
+             index = unravel_index(index, storage.shape)
+         return index, {"priority_weight": weight}
+
+     def add(self, index: torch.Tensor | int) -> None:
+         super().add(index)
+         self._maybe_erase_max_priority(index)
+
+     def extend(self, index: torch.Tensor | tuple) -> None:
+         super().extend(index)
+         self._maybe_erase_max_priority(index)
+
+     @torch.no_grad()
+     def update_priority(
+         self,
+         index: int | torch.Tensor,
+         priority: float | torch.Tensor,
+         *,
+         storage: TensorStorage | None = None,
+     ) -> None:  # noqa: D417
+         """Updates the priority of the data pointed by the index.
+
+         Args:
+             index (int or torch.Tensor): indexes of the priorities to be
+                 updated.
+             priority (Number or torch.Tensor): new priorities of the
+                 indexed elements.
+
+         Keyword Args:
+             storage (Storage, optional): a storage used to map the Nd index size to
+                 the 1d size of the sum_tree and min_tree. Only required whenever
+                 ``index.ndim > 1``.
+
+         """
+         priority = torch.as_tensor(priority, device=torch.device("cpu")).detach()
+         index = torch.as_tensor(index, dtype=torch.long, device=torch.device("cpu"))
+         # we need to reshape priority if it has more than one element or if it has
+         # a different shape than index
+         if priority.numel() > 1 and priority.shape != index.shape:
+             try:
+                 priority = priority.reshape(index.shape[:1])
+             except Exception as err:
+                 raise RuntimeError(
+                     "priority should be a number or an iterable of the same "
+                     f"length as index. Got priority of shape {priority.shape} and index "
+                     f"{index.shape}."
+                 ) from err
+         elif priority.numel() <= 1:
+             priority = priority.squeeze()
+
+         # MaxValueWriter will set -1 for items in the data that we don't want
+         # to update. We therefore have to keep only the non-negative indices.
+         if _is_int(index):
+             if index == -1:
+                 return
+         else:
+             if index.ndim > 1:
+                 if storage is None:
+                     raise RuntimeError(
+                         "storage should be provided to Sampler.update_priority when the storage has more "
+                         "than one dimension."
+                     )
+                 try:
+                     shape = storage.shape
+                 except AttributeError:
+                     raise AttributeError(
+                         "Could not retrieve the storage shape. If your storage is not a TensorStorage subclass "
+                         "or its shape isn't accessible via the shape attribute, submit an issue on GitHub."
+                     )
+                 index = torch.as_tensor(np.ravel_multi_index(index.unbind(-1), shape))
+             valid_index = index >= 0
+             if not valid_index.any():
+                 return
+             if not valid_index.all():
+                 index = index[valid_index]
+                 if priority.ndim:
+                     priority = priority[valid_index]
+
+         max_p, max_p_idx = priority.max(dim=0)
+         cur_max_priority, cur_max_priority_index = self._max_priority
+         if cur_max_priority is None or max_p > cur_max_priority:
+             cur_max_priority, cur_max_priority_index = self._max_priority = (
+                 max_p,
+                 index[max_p_idx] if index.ndim else index,
+             )
+         priority = torch.pow(priority + self._eps, self._alpha)
+         self._sum_tree[index] = priority
+         self._min_tree[index] = priority
+         if (
+             self._max_priority_within_buffer
+             and cur_max_priority_index is not None
+             and (index == cur_max_priority_index).any()
+         ):
+             maxval, maxidx = torch.tensor(
+                 [self._sum_tree[i] for i in range(self._max_capacity)]
+             ).max(0)
+             self._max_priority = (maxval, maxidx)
+
+     def mark_update(
+         self, index: int | torch.Tensor, *, storage: Storage | None = None
+     ) -> None:
+         self.update_priority(index, self.default_priority, storage=storage)
+
+     def state_dict(self) -> dict[str, Any]:
+         return {
+             "_alpha": self._alpha,
+             "_beta": self._beta,
+             "_eps": self._eps,
+             "_max_priority": self._max_priority,
+             "_sum_tree": deepcopy(self._sum_tree),
+             "_min_tree": deepcopy(self._min_tree),
+         }
+
+     def load_state_dict(self, state_dict: dict[str, Any]) -> None:
+         self._alpha = state_dict["_alpha"]
+         self._beta = state_dict["_beta"]
+         self._eps = state_dict["_eps"]
+         self._max_priority = state_dict["_max_priority"]
+         self._sum_tree = state_dict.pop("_sum_tree")
+         self._min_tree = state_dict.pop("_min_tree")
+
+     @implement_for("torch", None, "2.5.0")
+     def dumps(self, path):
+         raise NotImplementedError("This method is not implemented for Torch < 2.5.0")
+
+     @implement_for("torch", "2.5.0", None)
+     def dumps(self, path):  # noqa: F811
+         path = Path(path).absolute()
+         path.mkdir(exist_ok=True)
+         try:
+             mm_st = MemoryMappedTensor.from_filename(
+                 shape=(self._max_capacity,),
+                 dtype=torch.float64,
+                 filename=path / "sumtree.memmap",
+             )
+             mm_mt = MemoryMappedTensor.from_filename(
+                 shape=(self._max_capacity,),
+                 dtype=torch.float64,
+                 filename=path / "mintree.memmap",
+             )
+         except FileNotFoundError:
+             mm_st = MemoryMappedTensor.empty(
+                 (self._max_capacity,),
+                 dtype=torch.float64,
+                 filename=path / "sumtree.memmap",
+             )
+             mm_mt = MemoryMappedTensor.empty(
+                 (self._max_capacity,),
+                 dtype=torch.float64,
+                 filename=path / "mintree.memmap",
+             )
+         mm_st.copy_(
+             torch.as_tensor([self._sum_tree[i] for i in range(self._max_capacity)])
+         )
+         mm_mt.copy_(
+             torch.as_tensor([self._min_tree[i] for i in range(self._max_capacity)])
+         )
+         with open(path / "sampler_metadata.json", "w") as file:
+             json.dump(
+                 tree_map(
+                     float,
+                     {
+                         "_alpha": self._alpha,
+                         "_beta": self._beta,
+                         "_eps": self._eps,
+                         "_max_priority": self._max_priority,
+                         "_max_capacity": self._max_capacity,
+                     },
+                 ),
+                 file,
+             )
+
+     @implement_for("torch", None, "2.5.0")
+     def loads(self, path):
+         raise NotImplementedError("This method is not implemented for Torch < 2.5.0")
+
+     @implement_for("torch", "2.5.0", None)
+     def loads(self, path):  # noqa: F811
+         path = Path(path).absolute()
+         with open(path / "sampler_metadata.json") as file:
+             metadata = json.load(file)
+         self._alpha = metadata["_alpha"]
+         self._beta = metadata["_beta"]
+         self._eps = metadata["_eps"]
+         maxp = tree_map(
+             lambda dest, orig: dest.copy_(orig) if dest is not None else orig,
+             tuple(self._max_priority),
+             tuple(metadata["_max_priority"]),
+         )
+         if all(x is None for x in self._max_priority):
+             self._max_priority = maxp
+         _max_capacity = metadata["_max_capacity"]
+         if _max_capacity != self._max_capacity:
+             raise RuntimeError(
+                 f"max capacity of loaded metadata ({_max_capacity}) differs from self._max_capacity ({self._max_capacity})."
+             )
+         mm_st = MemoryMappedTensor.from_filename(
+             shape=(self._max_capacity,),
+             dtype=torch.float64,
+             filename=path / "sumtree.memmap",
+         )
+         mm_mt = MemoryMappedTensor.from_filename(
+             shape=(self._max_capacity,),
+             dtype=torch.float64,
+             filename=path / "mintree.memmap",
+         )
+         for i, elt in enumerate(mm_st.tolist()):
+             self._sum_tree[i] = elt
+         for i, elt in enumerate(mm_mt.tolist()):
+             self._min_tree[i] = elt
+
+
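Putting the class together end to end: sample with `return_info=True` to obtain the importance-sampling weights and indices, write new priorities back with `update_priority()`, and anneal `beta` toward 1.0 as the docstring guidelines suggest. A hedged training-loop sketch; the TD error below is a random stand-in for the real `|Q_target - Q|`, and the sizes are arbitrary:

    import torch
    from tensordict import TensorDict

    from torchrl.data.replay_buffers import (
        LazyTensorStorage,
        PrioritizedSampler,
        ReplayBuffer,
    )

    sampler = PrioritizedSampler(max_capacity=100, alpha=0.6, beta=0.4)
    rb = ReplayBuffer(storage=LazyTensorStorage(100), sampler=sampler)
    rb.extend(TensorDict({"obs": torch.randn(100, 4)}, batch_size=[100]))

    total_steps = 1_000
    for step in range(total_steps):
        batch, info = rb.sample(32, return_info=True)
        is_weight = torch.as_tensor(info["priority_weight"])
        td_error = torch.randn(32).abs()  # stand-in for the real TD error
        # A weighted loss would use: (is_weight * td_error.pow(2)).mean()
        rb.update_priority(info["index"], td_error)
        # Linear annealing of beta from 0.4 to 1.0 over training
        sampler.beta = 0.4 + (1.0 - 0.4) * step / total_steps
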
+ class SliceSampler(Sampler):
+     """Samples slices of data along the first dimension, given start and stop signals.
+
+     This class samples sub-trajectories with replacement. For a version without
+     replacement, see :class:`~torchrl.data.replay_buffers.samplers.SliceSamplerWithoutReplacement`.
+
+     .. note:: `SliceSampler` can be slow to retrieve the trajectory indices. To accelerate
+         its execution, prefer using `end_key` over `traj_key`, and consider the following
+         keyword arguments: :attr:`compile`, :attr:`cache_values` and :attr:`use_gpu`.
+
+     Keyword Args:
+         num_slices (int): the number of slices to be sampled. The batch-size
+             must be greater than or equal to the ``num_slices`` argument. Exclusive
+             with ``slice_len``.
+         slice_len (int): the length of the slices to be sampled. The batch-size
+             must be greater than or equal to the ``slice_len`` argument and divisible
+             by it. Exclusive with ``num_slices``.
+         end_key (NestedKey, optional): the key indicating the end of a
+             trajectory (or episode). Defaults to ``("next", "done")``.
+         traj_key (NestedKey, optional): the key indicating the trajectories.
+             Defaults to ``"episode"`` (commonly used across datasets in TorchRL).
+         ends (torch.Tensor, optional): a 1d boolean tensor containing the end of run signals.
+             To be used whenever the ``end_key`` or ``traj_key`` is expensive to get,
+             or when this signal is readily available. Must be used with ``cache_values=True``
+             and cannot be used in conjunction with ``end_key`` or ``traj_key``.
+             If provided, it is assumed that the storage is at capacity and that
+             if the last element of the ``ends`` tensor is ``False``,
+             the same trajectory spans across end and beginning.
+         trajectories (torch.Tensor, optional): a 1d integer tensor containing the run ids.
+             To be used whenever the ``end_key`` or ``traj_key`` is expensive to get,
+             or when this signal is readily available. Must be used with ``cache_values=True``
+             and cannot be used in conjunction with ``end_key`` or ``traj_key``.
+             If provided, it is assumed that the storage is at capacity and that
+             if the last element of the trajectory tensor is identical to the first,
+             the same trajectory spans across end and beginning.
+         cache_values (bool, optional): to be used with static datasets.
+             Will cache the start and end signals of the trajectories. This can be safely used even
+             if the trajectory indices change during calls to :class:`~torchrl.data.ReplayBuffer.extend`
+             as this operation will erase the cache.
+
+             .. warning:: ``cache_values=True`` will not work if the sampler is used with a
+                 storage that is extended by another buffer. For instance:
+
+                     >>> buffer0 = ReplayBuffer(storage=storage,
+                     ...     sampler=SliceSampler(num_slices=8, cache_values=True),
+                     ...     writer=ImmutableDatasetWriter())
+                     >>> buffer1 = ReplayBuffer(storage=storage,
+                     ...     sampler=other_sampler)
+                     >>> # Wrong! Does not erase the cache of buffer0's sampler
+                     >>> buffer1.extend(data)
+
+             .. warning:: ``cache_values=True`` will not work as expected if the buffer is
+                 shared between processes and one process is responsible for writing
+                 and one process for sampling, as erasing the cache can only be done locally.
+
+         truncated_key (NestedKey, optional): If not ``None``, this argument
+             indicates where a truncated signal should be written in the output
+             data. This is used to indicate to value estimators where the provided
+             trajectory breaks. Defaults to ``("next", "truncated")``.
+             This feature only works with :class:`~torchrl.data.replay_buffers.TensorDictReplayBuffer`
+             instances (otherwise the truncated key is returned in the info dictionary
+             returned by the :meth:`~torchrl.data.replay_buffers.ReplayBuffer.sample` method).
+         strict_length (bool, optional): if ``False``, trajectories of length
+             shorter than `slice_len` (or `batch_size // num_slices`) will be
+             allowed to appear in the batch. If ``True``, trajectories shorter
+             than required will be filtered out.
+             Be mindful that this can result in an effective `batch_size` shorter
+             than the one asked for! Trajectories can be split using
+             :func:`~torchrl.collectors.split_trajectories`. Defaults to ``True``.
+         compile (bool or dict of kwargs, optional): if ``True``, the bottleneck of
+             the :meth:`~sample` method will be compiled with :func:`~torch.compile`.
+             Keyword arguments can also be passed to torch.compile with this arg.
+             Defaults to ``False``.
+         span (bool, int, Tuple[bool | int, bool | int], optional): if provided, the sampled
+             trajectory will span across the left and/or the right. This means that possibly
+             fewer elements will be provided than what was required. A boolean value means
+             that at least one element will be sampled per trajectory. An integer `i` means
+             that at least `slice_len - i` samples will be gathered for each sampled trajectory.
+             Using tuples allows fine-grained control over the span on the left (beginning
+             of the stored trajectory) and on the right (end of the stored trajectory).
+         use_gpu (bool or torch.device): if ``True`` (or if a device is passed), an accelerator
+             will be used to retrieve the indices of the trajectory starts. This can significantly
+             accelerate the sampling when the buffer content is large.
+             Defaults to ``False``.
+
+     .. note:: To recover the trajectory splits in the storage,
+         :class:`~torchrl.data.replay_buffers.samplers.SliceSampler` will first
+         attempt to find the ``traj_key`` entry in the storage. If it cannot be
+         found, the ``end_key`` will be used to reconstruct the episodes.
+
+     .. note:: When using `strict_length=False`, it is recommended to use
+         :func:`~torchrl.collectors.utils.split_trajectories` to split the sampled trajectories.
+         However, if two samples from the same episode are placed next to each other,
+         this may produce incorrect results. To avoid this issue, consider one of these solutions:
+
+         - using a :class:`~torchrl.data.TensorDictReplayBuffer` instance with the slice sampler
+
+             >>> import torch
+             >>> from tensordict import TensorDict
+             >>> from torchrl.collectors.utils import split_trajectories
+             >>> from torchrl.data import TensorDictReplayBuffer, ReplayBuffer, LazyTensorStorage, SliceSampler, SliceSamplerWithoutReplacement
+             >>>
+             >>> rb = TensorDictReplayBuffer(storage=LazyTensorStorage(max_size=1000),
+             ...     sampler=SliceSampler(
+             ...         slice_len=5, traj_key="episode", strict_length=False,
+             ...     ))
+             ...
+             >>> ep_1 = TensorDict(
+             ...     {"obs": torch.arange(100),
+             ...     "episode": torch.zeros(100),},
+             ...     batch_size=[100]
+             ... )
+             >>> ep_2 = TensorDict(
+             ...     {"obs": torch.arange(4),
+             ...     "episode": torch.ones(4),},
+             ...     batch_size=[4]
+             ... )
+             >>> rb.extend(ep_1)
+             >>> rb.extend(ep_2)
+             >>>
+             >>> s = rb.sample(50)
+             >>> print(s)
+             TensorDict(
+                 fields={
+                     episode: Tensor(shape=torch.Size([46]), device=cpu, dtype=torch.float32, is_shared=False),
+                     index: Tensor(shape=torch.Size([46, 1]), device=cpu, dtype=torch.int64, is_shared=False),
+                     next: TensorDict(
+                         fields={
+                             done: Tensor(shape=torch.Size([46, 1]), device=cpu, dtype=torch.bool, is_shared=False),
+                             terminated: Tensor(shape=torch.Size([46, 1]), device=cpu, dtype=torch.bool, is_shared=False),
+                             truncated: Tensor(shape=torch.Size([46, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
+                         batch_size=torch.Size([46]),
+                         device=cpu,
+                         is_shared=False),
+                     obs: Tensor(shape=torch.Size([46]), device=cpu, dtype=torch.int64, is_shared=False)},
+                 batch_size=torch.Size([46]),
+                 device=cpu,
+                 is_shared=False)
+             >>> t = split_trajectories(s, done_key="truncated")
+             >>> print(t["obs"])
+             tensor([[73, 74, 75, 76, 77],
+                     [ 0,  1,  2,  3,  0],
+                     [ 0,  1,  2,  3,  0],
+                     [41, 42, 43, 44, 45],
+                     [ 0,  1,  2,  3,  0],
+                     [67, 68, 69, 70, 71],
+                     [27, 28, 29, 30, 31],
+                     [80, 81, 82, 83, 84],
+                     [17, 18, 19, 20, 21],
+                     [ 0,  1,  2,  3,  0]])
+             >>> print(t["episode"])
+             tensor([[0., 0., 0., 0., 0.],
+                     [1., 1., 1., 1., 0.],
+                     [1., 1., 1., 1., 0.],
+                     [0., 0., 0., 0., 0.],
+                     [1., 1., 1., 1., 0.],
+                     [0., 0., 0., 0., 0.],
+                     [0., 0., 0., 0., 0.],
+                     [0., 0., 0., 0., 0.],
+                     [0., 0., 0., 0., 0.],
+                     [1., 1., 1., 1., 0.]])
+
+         - using a :class:`~torchrl.data.replay_buffers.samplers.SliceSamplerWithoutReplacement`
+
+             >>> import torch
+             >>> from tensordict import TensorDict
+             >>> from torchrl.collectors.utils import split_trajectories
+             >>> from torchrl.data import ReplayBuffer, LazyTensorStorage, SliceSampler, SliceSamplerWithoutReplacement
+             >>>
+             >>> rb = ReplayBuffer(storage=LazyTensorStorage(max_size=1000),
+             ...     sampler=SliceSamplerWithoutReplacement(
+             ...         slice_len=5, traj_key="episode", strict_length=False
+             ...     ))
+             ...
+             >>> ep_1 = TensorDict(
+             ...     {"obs": torch.arange(100),
+             ...     "episode": torch.zeros(100),},
+             ...     batch_size=[100]
+             ... )
+             >>> ep_2 = TensorDict(
+             ...     {"obs": torch.arange(4),
+             ...     "episode": torch.ones(4),},
+             ...     batch_size=[4]
+             ... )
+             >>> rb.extend(ep_1)
+             >>> rb.extend(ep_2)
+             >>>
+             >>> s = rb.sample(50)
+             >>> t = split_trajectories(s, trajectory_key="episode")
+             >>> print(t["obs"])
+             tensor([[75, 76, 77, 78, 79],
+                     [ 0,  1,  2,  3,  0]])
+             >>> print(t["episode"])
+             tensor([[0., 0., 0., 0., 0.],
+                     [1., 1., 1., 1., 0.]])
+
+     Examples:
+         >>> import torch
+         >>> from tensordict import TensorDict
+         >>> from torchrl.data.replay_buffers import LazyMemmapStorage, TensorDictReplayBuffer
+         >>> from torchrl.data.replay_buffers.samplers import SliceSampler
+         >>> torch.manual_seed(0)
+         >>> rb = TensorDictReplayBuffer(
+         ...     storage=LazyMemmapStorage(1_000_000),
1006
+ ... sampler=SliceSampler(cache_values=True, num_slices=10),
1007
+ ... batch_size=320,
1008
+ ... )
1009
+ >>> episode = torch.zeros(1000, dtype=torch.int)
1010
+ >>> episode[:300] = 1
1011
+ >>> episode[300:550] = 2
1012
+ >>> episode[550:700] = 3
1013
+ >>> episode[700:] = 4
1014
+ >>> data = TensorDict(
1015
+ ... {
1016
+ ... "episode": episode,
1017
+ ... "obs": torch.randn((3, 4, 5)).expand(1000, 3, 4, 5),
1018
+ ... "act": torch.randn((20,)).expand(1000, 20),
1019
+ ... "other": torch.randn((20, 50)).expand(1000, 20, 50),
1020
+ ... }, [1000]
1021
+ ... )
1022
+ >>> rb.extend(data)
1023
+ >>> sample = rb.sample()
1024
+ >>> print("sample:", sample)
1025
+ >>> print("episodes", sample.get("episode").unique())
1026
+ episodes tensor([1, 2, 3, 4], dtype=torch.int32)
1027
+
1028
+ :class:`~torchrl.data.replay_buffers.SliceSampler` is default-compatible with
1029
+ most of TorchRL's datasets:
1030
+
1031
+ Examples:
1032
+ >>> import torch
1033
+ >>>
1034
+ >>> from torchrl.data.datasets import RobosetExperienceReplay
1035
+ >>> from torchrl.data import SliceSampler
1036
+ >>>
1037
+ >>> torch.manual_seed(0)
1038
+ >>> num_slices = 10
1039
+ >>> dataid = list(RobosetExperienceReplay.available_datasets)[0]
1040
+ >>> data = RobosetExperienceReplay(dataid, batch_size=320, sampler=SliceSampler(num_slices=num_slices))
1041
+ >>> for batch in data:
1042
+ ... batch = batch.reshape(num_slices, -1)
1043
+ ... break
1044
+ >>> print("check that each batch only has one episode:", batch["episode"].unique(dim=1))
1045
+ check that each batch only has one episode: tensor([[19],
1046
+ [14],
1047
+ [ 8],
1048
+ [10],
1049
+ [13],
1050
+ [ 4],
1051
+ [ 2],
1052
+ [ 3],
1053
+ [22],
1054
+ [ 8]])
1055
+
1056
+ """
1057
+
1058
+ # Used whenever we need to sample N times as many transitions as requested and then keep only a 1/N fraction of them
1059
+ _batch_size_multiplier: int | None = 1
1060
+
1061
+ def __init__(
1062
+ self,
1063
+ *,
1064
+ num_slices: int | None = None,
1065
+ slice_len: int | None = None,
1066
+ end_key: NestedKey | None = None,
1067
+ traj_key: NestedKey | None = None,
1068
+ ends: torch.Tensor | None = None,
1069
+ trajectories: torch.Tensor | None = None,
1070
+ cache_values: bool = False,
1071
+ truncated_key: NestedKey | None = ("next", "truncated"),
1072
+ strict_length: bool = True,
1073
+ compile: bool | dict = False,
1074
+ span: bool | int | tuple[bool | int, bool | int] = False,
1075
+ use_gpu: torch.device | bool = False,
1076
+ ):
1077
+ self.num_slices = num_slices
1078
+ self.slice_len = slice_len
1079
+ self.end_key = end_key
1080
+ self.traj_key = traj_key
1081
+ self.truncated_key = truncated_key
1082
+ self.cache_values = cache_values
1083
+ self._fetch_traj = True
1084
+ self.strict_length = strict_length
1085
+ self._cache = {}
1086
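+ # Resolve the device used to locate trajectory starts: an explicit device is
+ # used as-is, use_gpu=True falls back to the default accelerator, and
+ # use_gpu=False keeps the search on the trajectory/end tensor's device.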
+ self.use_gpu = bool(use_gpu)
1087
+ self._gpu_device = (
1088
+ None
1089
+ if not self.use_gpu
1090
+ else torch.device(use_gpu)
1091
+ if not isinstance(use_gpu, bool)
1092
+ else _auto_device()
1093
+ )
1094
+
1095
+ if isinstance(span, (bool, int)):
1096
+ span = (span, span)
1097
+ self.span = span
1098
+
1099
+ if trajectories is not None:
1100
+ if traj_key is not None or end_key:
1101
+ raise RuntimeError(
1102
+ "`trajectories` and `end_key` or `traj_key` are exclusive arguments."
1103
+ )
1104
+ if ends is not None:
1105
+ raise RuntimeError("trajectories and ends are exclusive arguments.")
1106
+ if not cache_values:
1107
+ raise RuntimeError(
1108
+ "To be used, trajectories requires `cache_values` to be set to `True`."
1109
+ )
1110
+ vals = self._find_start_stop_traj(
1111
+ trajectory=trajectories,
1112
+ at_capacity=True,
1113
+ )
1114
+ self._cache["stop-and-length"] = vals
1115
+
1116
+ elif ends is not None:
1117
+ if traj_key is not None or end_key:
1118
+ raise RuntimeError(
1119
+ "`ends` and `end_key` or `traj_key` are exclusive arguments."
1120
+ )
1121
+ if trajectories is not None:
1122
+ raise RuntimeError("trajectories and ends are exclusive arguments.")
1123
+ if not cache_values:
1124
+ raise RuntimeError(
1125
+ "To be used, ends requires `cache_values` to be set to `True`."
1126
+ )
1127
+ vals = self._find_start_stop_traj(end=ends, at_capacity=True)
1128
+ self._cache["stop-and-length"] = vals
1129
+
1130
+ else:
1131
+ if traj_key is not None:
1132
+ self._fetch_traj = True
1133
+ elif end_key is not None:
1134
+ self._fetch_traj = False
1135
+ if end_key is None:
1136
+ end_key = ("next", "done")
1137
+ if traj_key is None:
1138
+ traj_key = "episode"
1139
+ self.end_key = end_key
1140
+ self.traj_key = traj_key
1141
+
1142
+ if not ((num_slices is None) ^ (slice_len is None)):
1143
+ raise TypeError(
1144
+ "Either num_slices or slice_len must be not None, and not both. "
1145
+ f"Got num_slices={num_slices} and slice_len={slice_len}."
1146
+ )
1147
+ self.compile = bool(compile)
1148
+ if self.compile:
1149
+ if isinstance(compile, dict):
1150
+ kwargs = compile
1151
+ else:
1152
+ kwargs = {}
1153
+ self._get_index = torch.compile(self._get_index, **kwargs)
1154
+
1155
+ def __getstate__(self):
1156
+ if get_spawning_popen() is not None and self.cache_values:
1157
+ logger.warning(
1158
+ f"It seems you are sharing a {type(self).__name__} across processes with "
1159
+ f"cache_values=True. "
1160
+ f"While this isn't forbidden and could perfectly work if your dataset "
1161
+ f"is unaltered on both processes, remember that calling extend/add on "
1162
+ f"one process will NOT erase the cache on another process's sampler, "
1163
+ f"which will cause synchronization issues."
1164
+ )
1165
+ state = super().__getstate__()
1166
+ state["_cache"] = {}
1167
+ return state
1168
+
1169
+ def extend(self, index: torch.Tensor) -> None:
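+ # Writing to the storage invalidates any cached start/stop indices.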
1170
+ super().extend(index)
1171
+ if self.cache_values:
1172
+ self._cache.clear()
1173
+
1174
+ def add(self, index: torch.Tensor) -> None:
1175
+ super().add(index)
1176
+ if self.cache_values:
1177
+ self._cache.clear()
1178
+
1179
+ def __repr__(self):
1180
+ return (
1181
+ f"{self.__class__.__name__}(num_slices={self.num_slices}, "
1182
+ f"slice_len={self.slice_len}, "
1183
+ f"end_key={self.end_key}, "
1184
+ f"traj_key={self.traj_key}, "
1185
+ f"truncated_key={self.truncated_key}, "
1186
+ f"strict_length={self.strict_length})"
1187
+ )
1188
+
1189
+ def _find_start_stop_traj(
1190
+ self, *, trajectory=None, end=None, at_capacity: bool, cursor=None
1191
+ ):
1192
+ if trajectory is not None:
1193
+ # slower
1194
+ # _, stop_idx = torch.unique_consecutive(trajectory, return_counts=True)
1195
+ # stop_idx = stop_idx.cumsum(0) - 1
1196
+
1197
+ # even slower
1198
+ # t = trajectory.unsqueeze(0)
1199
+ # w = torch.tensor([1, -1], dtype=torch.int).view(1, 1, 2)
1200
+ # stop_idx = torch.conv1d(t, w).nonzero()
1201
+
1202
+ # faster
1203
+ end = trajectory[:-1] != trajectory[1:]
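+ # e.g. trajectory=[0, 0, 1, 1, 2] -> end=[False, True, False, True]; the
+ # final element is handled right below.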
1204
+ if not at_capacity:
1205
+ end = torch.cat([end, torch.ones_like(end[:1])], 0)
1206
+ else:
1207
+ end = torch.cat([end, trajectory[-1:] != trajectory[:1]], 0)
1208
+ length = trajectory.shape[0]
1209
+ else:
1210
+ # We presume that a missing done at the end means the trajectory spans across the end and the beginning of the storage
1211
+ length = end.shape[0]
1212
+ if not at_capacity:
1213
+ end = end.clone()
1214
+ end[length - 1] = True
1215
+ ndim = end.ndim
1216
+
1217
+ if at_capacity:
1218
+ # we must have at least one end by traj to individuate trajectories
1219
+ # so if no end can be found we set it manually
1220
+ if cursor is not None:
1221
+ if isinstance(cursor, torch.Tensor):
1222
+ cursor = cursor[-1].item()
1223
+ elif isinstance(cursor, range):
1224
+ cursor = cursor[-1]
1225
+ if not _is_int(cursor):
1226
+ raise RuntimeError(
1227
+ "cursor should be an integer or a 1d tensor or a range."
1228
+ )
1229
+ end = torch.index_fill(
1230
+ end,
1231
+ index=torch.tensor(cursor, device=end.device, dtype=torch.long),
1232
+ dim=0,
1233
+ value=1,
1234
+ )
1235
+ if not end.any(0).all():
1236
+ mask = ~end.any(0, True)
1237
+ mask = torch.cat([torch.zeros_like(end[:-1]), mask])
1238
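+ # equivalent to `end | mask`: force a stop at the last step of every
+ # column that carries no end-of-trajectory signal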
+ end = torch.masked_fill(mask, end, 1)
1239
+ if ndim == 0:
1240
+ raise RuntimeError(
1241
+ "Expected the end-of-trajectory signal to be at least 1-dimensional."
1242
+ )
1243
+ return self._end_to_start_stop(length=length, end=end)
1244
+
1245
+ def _end_to_start_stop(self, end, length):
1246
+ device = None
1247
+ if self.use_gpu:
1248
+ gpu_device = self._gpu_device
1249
+ if end.device != gpu_device:
1250
+ device = end.device
1251
+ end = end.to(self._gpu_device)
1252
+ # Using transpose ensures the start and stop are sorted the same way
1253
+ stop_idx = end.transpose(0, -1).nonzero()
1254
+ stop_idx[:, [0, -1]] = stop_idx[:, [-1, 0]].clone()
1255
+ # First build the start indices as the stop + 1, we'll shift it later
1256
+ start_idx = stop_idx.clone()
1257
+ start_idx[:, 0] += 1
1258
+ start_idx[:, 0] %= end.shape[0]
1259
+ # shift start: to do this, we check when the non-first dim indices are identical
1260
+ # and get a mask like [False, True, True, False, True, ...] where False means
1261
+ # that there's a switch from one dim to another (ie, a switch from one element of the batch
1262
+ # to another). We roll this one step along the time dimension and these two
1263
+ # masks provide us with the indices of the permutation matrix we need
1264
+ # to apply to start_idx.
1265
+ if start_idx.shape[0] > 1:
1266
+ start_idx_mask = (start_idx[1:, 1:] == start_idx[:-1, 1:]).all(-1)
1267
+ m1 = torch.cat([torch.zeros_like(start_idx_mask[:1]), start_idx_mask])
1268
+ m2 = torch.cat([start_idx_mask, torch.zeros_like(start_idx_mask[:1])])
1269
+ start_idx_replace = torch.empty_like(start_idx)
1270
+ start_idx_replace[m1] = start_idx[m2]
1271
+ start_idx_replace[~m1] = start_idx[~m2]
1272
+ start_idx = start_idx_replace
1273
+ else:
1274
+ # In this case there is only one start, and the stop is already set
1275
+ pass
1276
+ lengths = stop_idx[:, 0] - start_idx[:, 0] + 1
1277
+ lengths[lengths <= 0] = lengths[lengths <= 0] + length
1278
+ if device is not None:
1279
+ return start_idx.to(device), stop_idx.to(device), lengths.to(device)
1280
+ return start_idx, stop_idx, lengths
1281
+
1282
+ def _start_to_end(self, st: torch.Tensor, length: int):
1283
+
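+ # Build `length` consecutive indices from a start coordinate, e.g.
+ # st=torch.tensor([3, 0]) and length=4 -> [[3, 0], [4, 0], [5, 0], [6, 0]].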
1284
+ arange = torch.arange(length, device=st.device, dtype=st.dtype)
1285
+ ndims = st.shape[-1] - 1 if st.ndim else 0
1286
+ if ndims:
1287
+ arange = torch.stack([arange] + [torch.zeros_like(arange)] * ndims, -1)
1288
+ else:
1289
+ arange = arange.unsqueeze(-1)
1290
+ if st.shape != arange.shape:
1291
+ # we do this to make sure that we're not broadcasting the start
1292
+ # wrong as a tensor with shape [N] can't be expanded to [N, 1]
1293
+ # without getting an error
1294
+ st = st.expand_as(arange)
1295
+ return arange + st
1296
+
1297
+ def _tensor_slices_from_startend(self, seq_length, start, storage_length):
1298
+ # start is a 2d tensor resulting from nonzero()
1299
+ # seq_length is a 1d tensor indicating the desired length of each sequence
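+ # e.g. seq_length=3 and start=[[0, 0], [10, 0]] yield
+ # [[0, 0], [1, 0], [2, 0], [10, 0], [11, 0], [12, 0]] (modulo storage_length)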
1300
+
1301
+ if isinstance(seq_length, int):
1302
+ arange = torch.arange(seq_length, device=start.device, dtype=start.dtype)
1303
+ ndims = start.shape[-1] - 1 if (start.ndim - 1) else 0
1304
+ if ndims:
1305
+ arange_reshaped = torch.empty(
1306
+ arange.shape + torch.Size([ndims + 1]),
1307
+ device=start.device,
1308
+ dtype=start.dtype,
1309
+ )
1310
+ arange_reshaped[..., 0] = arange
1311
+ arange_reshaped[..., 1:] = 0
1312
+ else:
1313
+ arange_reshaped = arange.unsqueeze(-1)
1314
+ arange_expanded = arange_reshaped.expand(
1315
+ torch.Size([start.shape[0]]) + arange_reshaped.shape
1316
+ )
1317
+ if start.shape != arange_expanded.shape:
1318
+ n_missing_dims = arange_expanded.dim() - start.dim()
1319
+ start_expanded = start[
1320
+ (slice(None),) + (None,) * n_missing_dims
1321
+ ].expand_as(arange_expanded)
1322
+ result = (start_expanded + arange_expanded).flatten(0, 1)
1323
+
1324
+ else:
1325
+ # when padding is needed
1326
+ result = torch.cat(
1327
+ [
1328
+ self._start_to_end(_start, _seq_len)
1329
+ for _start, _seq_len in zip(start, seq_length)
1330
+ ]
1331
+ )
1332
+ result[:, 0] = result[:, 0] % storage_length
1333
+ return result
1334
+
1335
+ def _get_stop_and_length(self, storage, fallback=True):
1336
+ if self.cache_values and "stop-and-length" in self._cache:
1337
+ return self._cache.get("stop-and-length")
1338
+
1339
+ if self._fetch_traj:
1340
+ # We first try with the traj_key
1341
+ try:
1342
+ if isinstance(storage, TensorStorage):
1343
+ trajectory = storage[:][self._used_traj_key]
1344
+ else:
1345
+ try:
1346
+ trajectory = storage[:][self.traj_key]
1347
+ except Exception:
1348
+ raise RuntimeError(
1349
+ "Could not get a tensordict out of the storage, which is required for SliceSampler to compute the trajectories."
1350
+ )
1351
+ vals = self._find_start_stop_traj(
1352
+ trajectory=trajectory,
1353
+ at_capacity=storage._is_full,
1354
+ cursor=getattr(storage, "_last_cursor", None),
1355
+ )
1356
+ if self.cache_values:
1357
+ self._cache["stop-and-length"] = vals
1358
+ return vals
1359
+ except KeyError:
1360
+ if fallback:
1361
+ self._fetch_traj = False
1362
+ return self._get_stop_and_length(storage, fallback=False)
1363
+ raise
1364
+
1365
+ else:
1366
+ try:
1367
+ try:
1368
+ done = storage[:][self.end_key]
1369
+ except Exception:
1370
+ raise RuntimeError(
1371
+ "Could not get a tensordict out of the storage, which is required for SliceSampler to compute the trajectories."
1372
+ )
1373
+ vals = self._find_start_stop_traj(
1374
+ end=done.squeeze()[: len(storage)],
1375
+ at_capacity=storage._is_full,
1376
+ cursor=getattr(storage, "_last_cursor", None),
1377
+ )
1378
+ if self.cache_values:
1379
+ self._cache["stop-and-length"] = vals
1380
+ return vals
1381
+ except KeyError:
1382
+ if fallback:
1383
+ self._fetch_traj = True
1384
+ return self._get_stop_and_length(storage, fallback=False)
1385
+ raise
1386
+
1387
+ def _adjusted_batch_size(self, batch_size):
1388
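+ # Split the total batch size into (seq_length, num_slices) such that
+ # seq_length * num_slices == batch_size, e.g. batch_size=320 with
+ # num_slices=10 gives seq_length=32.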
+ if self.num_slices is not None:
1389
+ if batch_size % self.num_slices != 0:
1390
+ raise RuntimeError(
1391
+ f"The batch-size must be divisible by the number of slices, got "
1392
+ f"batch_size={batch_size} and num_slices={self.num_slices}."
1393
+ )
1394
+ seq_length = batch_size // self.num_slices
1395
+ num_slices = self.num_slices
1396
+ else:
1397
+ if batch_size % self.slice_len != 0:
1398
+ raise RuntimeError(
1399
+ f"The batch-size must be divisible by the slice length, got "
1400
+ f"batch_size={batch_size} and slice_len={self.slice_len}."
1401
+ )
1402
+ seq_length = self.slice_len
1403
+ num_slices = batch_size // self.slice_len
1404
+ return seq_length, num_slices
1405
+
1406
+ def sample(self, storage: Storage, batch_size: int) -> tuple[torch.Tensor, dict]:
1407
+ if self._batch_size_multiplier is not None:
1408
+ batch_size = batch_size * self._batch_size_multiplier
1409
+ # pick up as many trajs as we need
1410
+ start_idx, stop_idx, lengths = self._get_stop_and_length(storage)
1411
+ # we have to make sure that the number of dims of the storage
1412
+ # is the same as the stop/start signals since we will
1413
+ # use these to sample the storage
1414
+ if start_idx.shape[1] != storage.ndim:
1415
+ raise RuntimeError(
1416
+ f"Expected the end-of-trajectory signal to be "
1417
+ f"{storage.ndim}-dimensional. Got a tensor with shape[1]={start_idx.shape[1]} "
1418
+ "instead."
1419
+ )
1420
+ seq_length, num_slices = self._adjusted_batch_size(batch_size)
1421
+ storage_length = storage.shape[0]
1422
+ return self._sample_slices(
1423
+ lengths,
1424
+ start_idx,
1425
+ stop_idx,
1426
+ seq_length,
1427
+ num_slices,
1428
+ storage_length=storage_length,
1429
+ storage=storage,
1430
+ )
1431
+
1432
+ def _sample_slices(
1433
+ self,
1434
+ lengths: torch.Tensor,
1435
+ start_idx: torch.Tensor,
1436
+ stop_idx: torch.Tensor,
1437
+ seq_length: int,
1438
+ num_slices: int,
1439
+ storage_length: int,
1440
+ traj_idx: torch.Tensor | None = None,
1441
+ *,
1442
+ storage,
1443
+ ) -> tuple[tuple[torch.Tensor, ...], dict[str, Any]]:
1444
+ # start_idx and stop_idx are 2d tensors organized like the output of nonzero()
1445
+
1446
+ def get_traj_idx(maxval):
1447
+ return torch.randint(
1448
+ maxval, (num_slices,), device=lengths.device, generator=self._rng
1449
+ )
1450
+
1451
+ if (lengths < seq_length).any():
1452
+ if self.strict_length:
1453
+ idx = lengths >= seq_length
1454
+ if not idx.any():
1455
+ raise RuntimeError(
1456
+ f"Did not find a single trajectory with sufficient length (length range: {lengths.min()} - {lengths.max()} / required={seq_length}))."
1457
+ )
1458
+ if (
1459
+ isinstance(seq_length, torch.Tensor)
1460
+ and seq_length.shape == lengths.shape
1461
+ ):
1462
+ seq_length = seq_length[idx]
1463
+ lengths_idx = lengths[idx]
1464
+ start_idx = start_idx[idx]
1465
+ stop_idx = stop_idx[idx]
1466
+
1467
+ if traj_idx is None:
1468
+ traj_idx = get_traj_idx(lengths_idx.shape[0])
1469
+ else:
1470
+ # Here we must filter out the indices that correspond to trajectories
1471
+ # we don't want to keep. That could potentially lead to an empty sample.
1472
+ # The difficulty with this adjustment is that traj_idx points to the full
1473
+ # sequence of lengths, but we filter part of it out, so we must
1474
+ # convert traj_idx to a boolean mask, index this mask with the
1475
+ # valid indices and then recover the nonzero.
1476
+ idx_mask = torch.zeros_like(idx)
1477
+ idx_mask[traj_idx] = True
1478
+ traj_idx = idx_mask[idx].nonzero().squeeze(-1)
1479
+ if not traj_idx.numel():
1480
+ raise RuntimeError(
1481
+ "None of the provided indices pointed to a trajectory of "
1482
+ "sufficient length. Consider using strict_length=False for the "
1483
+ "sampler instead."
1484
+ )
1485
+ num_slices = traj_idx.shape[0]
1486
+
1487
+ del idx
1488
+ lengths = lengths_idx
1489
+ else:
1490
+ if traj_idx is None:
1491
+ traj_idx = get_traj_idx(lengths.shape[0])
1492
+ else:
1493
+ num_slices = traj_idx.shape[0]
1494
+
1495
+ # make seq_length a tensor with values clamped by lengths
1496
+ seq_length = lengths[traj_idx].clamp_max(seq_length)
1497
+ else:
1498
+ if traj_idx is None:
1499
+ traj_idx = get_traj_idx(lengths.shape[0])
1500
+ else:
1501
+ num_slices = traj_idx.shape[0]
1502
+ return self._get_index(
1503
+ lengths=lengths,
1504
+ start_idx=start_idx,
1505
+ stop_idx=stop_idx,
1506
+ num_slices=num_slices,
1507
+ seq_length=seq_length,
1508
+ storage_length=storage_length,
1509
+ traj_idx=traj_idx,
1510
+ storage=storage,
1511
+ )
1512
+
1513
+ def _get_index(
1514
+ self,
1515
+ lengths: torch.Tensor,
1516
+ start_idx: torch.Tensor,
1517
+ stop_idx: torch.Tensor,
1518
+ seq_length: int,
1519
+ num_slices: int,
1520
+ storage_length: int,
1521
+ traj_idx: torch.Tensor | None = None,
1522
+ *,
1523
+ storage,
1524
+ ) -> tuple[torch.Tensor, dict]:
1525
+ # end_point is the last possible index for start
1526
+ last_indexable_start = lengths[traj_idx] - seq_length + 1
1527
+ if not self.span[1]:
1528
+ end_point = last_indexable_start
1529
+ elif self.span[1] is True:
1530
+ end_point = lengths[traj_idx] + 1
1531
+ else:
1532
+ span_right = self.span[1]
1533
+ if span_right >= seq_length:
1534
+ raise ValueError(
1535
+ "The right and left span must be strictly lower than the sequence length"
1536
+ )
1537
+ end_point = lengths[traj_idx] - span_right
1538
+
1539
+ if not self.span[0]:
1540
+ start_point = 0
1541
+ elif self.span[0] is True:
1542
+ start_point = -seq_length + 1
1543
+ else:
1544
+ span_left = self.span[0]
1545
+ if span_left >= seq_length:
1546
+ raise ValueError(
1547
+ "The right and left span must be strictly lower than the sequence length"
1548
+ )
1549
+ start_point = -span_left
1550
+
1551
+ relative_starts = (
1552
+ torch.rand(num_slices, device=lengths.device, generator=self._rng)
1553
+ * (end_point - start_point)
1554
+ ).floor().to(start_idx.dtype) + start_point
1555
+
1556
+ if self.span[0]:
1557
+ out_of_traj = relative_starts < 0
1558
+ if out_of_traj.any():
1559
+ # a negative start means sampling fewer elements
1560
+ seq_length = torch.where(
1561
+ ~out_of_traj, seq_length, seq_length + relative_starts
1562
+ )
1563
+ relative_starts = torch.where(~out_of_traj, relative_starts, 0)
1564
+ if self.span[1]:
1565
+ out_of_traj = relative_starts + seq_length > lengths[traj_idx]
1566
+ if out_of_traj.any():
1567
+ # overshooting the end of the trajectory means sampling fewer elements
1568
+ seq_length = torch.minimum(
1569
+ seq_length, lengths[traj_idx] - relative_starts
1570
+ )
1571
+
1572
+ starts = torch.cat(
1573
+ [
1574
+ (start_idx[traj_idx, 0] + relative_starts).unsqueeze(1),
1575
+ start_idx[traj_idx, 1:],
1576
+ ],
1577
+ 1,
1578
+ )
1579
+ index = self._tensor_slices_from_startend(seq_length, starts, storage_length)
1580
+ if self.truncated_key is not None:
1581
+ truncated_key = self.truncated_key
1582
+ done_key = _replace_last(truncated_key, "done")
1583
+ terminated_key = _replace_last(truncated_key, "terminated")
1584
+
1585
+ truncated = torch.zeros(
1586
+ (index.shape[0], 1), dtype=torch.bool, device=index.device
1587
+ )
1588
+ if isinstance(seq_length, int):
1589
+ truncated.view(num_slices, -1)[:, -1] = 1
1590
+ else:
1591
+ truncated[seq_length.cumsum(0) - 1] = 1
1592
+ index = index.to(torch.long).unbind(-1)
1593
+ st_index = storage[index]
1594
+ done = st_index.get(done_key, default=None)
1595
+ if done is None:
1596
+ done = truncated.clone()
1597
+ else:
1598
+ done = done | truncated
1599
+ terminated = st_index.get(terminated_key, default=None)
1600
+ if terminated is None:
1601
+ terminated = torch.zeros_like(truncated)
1602
+ return index, {
1603
+ truncated_key: truncated,
1604
+ done_key: done,
1605
+ terminated_key: terminated,
1606
+ }
1607
+ index = index.to(torch.long).unbind(-1)
1608
+ return index, {}
1609
+
1610
+ @property
1611
+ def _used_traj_key(self):
1612
+ return self.__dict__.get("__used_traj_key", self.traj_key)
1613
+
1614
+ @_used_traj_key.setter
1615
+ def _used_traj_key(self, value):
1616
+ self.__dict__["__used_traj_key"] = value
1617
+
1618
+ @property
1619
+ def _used_end_key(self):
1620
+ return self.__dict__.get("__used_end_key", self.end_key)
1621
+
1622
+ @_used_end_key.setter
1623
+ def _used_end_key(self, value):
1624
+ self.__dict__["__used_end_key"] = value
1625
+
1626
+ def _empty(self):
1627
+ pass
1628
+
1629
+ def dumps(self, path):
1630
+ # no op - cache does not need to be saved
1631
+ ...
1632
+
1633
+ def loads(self, path):
1634
+ # no op
1635
+ ...
1636
+
1637
+ def state_dict(self) -> dict[str, Any]:
1638
+ return {}
1639
+
1640
+ def load_state_dict(self, state_dict: dict[str, Any]) -> None:
1641
+ ...
1642
+
1643
+
1644
+ class SliceSamplerWithoutReplacement(SliceSampler, SamplerWithoutReplacement):
1645
+ """Samples slices of data along the first dimension, given start and stop signals, without replacement.
1646
+
1647
+ In this context, ``without replacement`` means that the same element (NOT trajectory) will not be sampled twice
1648
+ before the counter is automatically reset. Within a single sample, however, only one slice of a given trajectory
1649
+ will appear (see example below).
1650
+
1651
+ This class is to be used with static replay buffers or in between two
1652
+ replay buffer extensions. Extending the replay buffer will reset the
1653
+ the sampler, and continuous sampling without replacement is currently not
1654
+ allowed.
1655
+
1656
+ .. note:: `SliceSamplerWithoutReplacement` can be slow to retrieve the trajectory indices. To accelerate
1657
+ its execution, prefer using `end_key` over `traj_key`, and consider the following
1658
+ keyword arguments: :attr:`compile`, :attr:`cache_values` and :attr:`use_gpu`.
1659
+
1660
+ Keyword Args:
1661
+ drop_last (bool, optional): if ``True``, the last incomplete sample (if any) will be dropped.
1662
+ If ``False``, this last sample will be kept.
1663
+ Defaults to ``False``.
1664
+ num_slices (int): the number of slices to be sampled. The batch-size
1665
+ must be greater or equal to the ``num_slices`` argument. Exclusive
1666
+ with ``slice_len``.
1667
+ slice_len (int): the length of the slices to be sampled. The batch-size
1668
+ must be greater or equal to the ``slice_len`` argument and divisible
1669
+ by it. Exclusive with ``num_slices``.
1670
+ end_key (NestedKey, optional): the key indicating the end of a
1671
+ trajectory (or episode). Defaults to ``("next", "done")``.
1672
+ traj_key (NestedKey, optional): the key indicating the trajectories.
1673
+ Defaults to ``"episode"`` (commonly used across datasets in TorchRL).
1674
+ ends (torch.Tensor, optional): a 1d boolean tensor containing the end of run signals.
1675
+ To be used whenever the ``end_key`` or ``traj_key`` is expensive to get,
1676
+ or when this signal is readily available. Must be used with ``cache_values=True``
1677
+ and cannot be used in conjunction with ``end_key`` or ``traj_key``.
1678
+ trajectories (torch.Tensor, optional): a 1d integer tensor containing the run ids.
1679
+ To be used whenever the ``end_key`` or ``traj_key`` is expensive to get,
1680
+ or when this signal is readily available. Must be used with ``cache_values=True``
1681
+ and cannot be used in conjunction with ``end_key`` or ``traj_key``.
1682
+ truncated_key (NestedKey, optional): If not ``None``, this argument
1683
+ indicates where a truncated signal should be written in the output
1684
+ data. This is used to indicate to value estimators where the provided
1685
+ trajectory breaks. Defaults to ``("next", "truncated")``.
1686
+ This feature only works with :class:`~torchrl.data.replay_buffers.TensorDictReplayBuffer`
1687
+ instances (otherwise the truncated key is returned in the info dictionary
1688
+ returned by the :meth:`~torchrl.data.replay_buffers.ReplayBuffer.sample` method).
1689
+ strict_length (bool, optional): if ``False``, trajectories of length
1690
+ shorter than `slice_len` (or `batch_size // num_slices`) will be
1691
+ allowed to appear in the batch. If ``True``, trajectories shorter
1692
+ than required will be filtered out.
1693
+ Be mindful that this can result in an effective `batch_size` smaller
1694
+ than the one asked for! Trajectories can be split using
1695
+ :func:`~torchrl.collectors.split_trajectories`. Defaults to ``True``.
1696
+ shuffle (bool, optional): if ``False``, the order of the trajectories
1697
+ is not shuffled. Defaults to ``True``.
1698
+ compile (bool or dict of kwargs, optional): if ``True``, the bottleneck of
1699
+ the :meth:`~sample` method will be compiled with :func:`~torch.compile`.
1700
+ Keyword arguments can also be passed to torch.compile with this arg.
1701
+ Defaults to ``False``.
1702
+ use_gpu (bool or torch.device): if ``True`` (or if a device is passed), an accelerator
1703
+ will be used to retrieve the indices of the trajectory starts. This can significantly
1704
+ accelerate the sampling when the buffer content is large.
1705
+ Defaults to ``False``.
1706
+
1707
+ .. note:: To recover the trajectory splits in the storage,
1708
+ :class:`~torchrl.data.replay_buffers.samplers.SliceSamplerWithoutReplacement` will first
1709
+ attempt to find the ``traj_key`` entry in the storage. If it cannot be
1710
+ found, the ``end_key`` will be used to reconstruct the episodes.
1711
+
1712
+ Examples:
1713
+ >>> import torch
1714
+ >>> from tensordict import TensorDict
1715
+ >>> from torchrl.data.replay_buffers import LazyMemmapStorage, TensorDictReplayBuffer
1716
+ >>> from torchrl.data.replay_buffers.samplers import SliceSamplerWithoutReplacement
1717
+ >>>
1718
+ >>> rb = TensorDictReplayBuffer(
1719
+ ... storage=LazyMemmapStorage(1000),
1720
+ ... # asking for 10 slices for a total of 320 elements, i.e., 10 trajectories of 32 transitions each
1721
+ ... sampler=SliceSamplerWithoutReplacement(num_slices=10),
1722
+ ... batch_size=320,
1723
+ ... )
1724
+ >>> episode = torch.zeros(1000, dtype=torch.int)
1725
+ >>> episode[:300] = 1
1726
+ >>> episode[300:550] = 2
1727
+ >>> episode[550:700] = 3
1728
+ >>> episode[700:] = 4
1729
+ >>> data = TensorDict(
1730
+ ... {
1731
+ ... "episode": episode,
1732
+ ... "obs": torch.randn((3, 4, 5)).expand(1000, 3, 4, 5),
1733
+ ... "act": torch.randn((20,)).expand(1000, 20),
1734
+ ... "other": torch.randn((20, 50)).expand(1000, 20, 50),
1735
+ ... }, [1000]
1736
+ ... )
1737
+ >>> rb.extend(data)
1738
+ >>> sample = rb.sample()
1739
+ >>> # since we want trajectories of 32 transitions but there are only 4 episodes to
1740
+ >>> # sample from, we only get 4 x 32 = 128 transitions in this batch
1741
+ >>> print("sample:", sample)
1742
+ >>> print("trajectories in sample", sample.get("episode").unique())
1743
+
1744
+ :class:`~torchrl.data.replay_buffers.SliceSamplerWithoutReplacement` is default-compatible with
1745
+ most of TorchRL's datasets, and allows users to consume datasets in a dataloader-like fashion:
1746
+
1747
+ Examples:
1748
+ >>> import torch
1749
+ >>>
1750
+ >>> from torchrl.data.datasets import RobosetExperienceReplay
1751
+ >>> from torchrl.data import SliceSamplerWithoutReplacement
1752
+ >>>
1753
+ >>> torch.manual_seed(0)
1754
+ >>> num_slices = 10
1755
+ >>> dataid = list(RobosetExperienceReplay.available_datasets)[0]
1756
+ >>> data = RobosetExperienceReplay(dataid, batch_size=320,
1757
+ ... sampler=SliceSamplerWithoutReplacement(num_slices=num_slices))
1758
+ >>> # the last sample is kept, since drop_last=False by default
1759
+ >>> for i, batch in enumerate(data):
1760
+ ... print(batch.get("episode").unique())
1761
+ tensor([ 5, 6, 8, 11, 12, 14, 16, 17, 19, 24])
1762
+ tensor([ 1, 2, 7, 9, 10, 13, 15, 18, 21, 22])
1763
+ tensor([ 0, 3, 4, 20, 23])
1764
+
1765
+ When requesting a large total number of samples with few trajectories and small span, the batch will contain
1766
+ at most one sample of each trajectory:
1767
+
1768
+ Examples:
1769
+ >>> import torch
1770
+ >>> from tensordict import TensorDict
1771
+ >>> from torchrl.collectors.utils import split_trajectories
1772
+ >>> from torchrl.data import ReplayBuffer, LazyTensorStorage, SliceSampler, SliceSamplerWithoutReplacement
1773
+ >>>
1774
+ >>> rb = ReplayBuffer(storage=LazyTensorStorage(max_size=1000),
1775
+ ... sampler=SliceSamplerWithoutReplacement(
1776
+ ... slice_len=5, traj_key="episode",strict_length=False
1777
+ ... ))
1778
+ ...
1779
+ >>> ep_1 = TensorDict(
1780
+ ... {"obs": torch.arange(100),
1781
+ ... "episode": torch.zeros(100),},
1782
+ ... batch_size=[100]
1783
+ ... )
1784
+ >>> ep_2 = TensorDict(
1785
+ ... {"obs": torch.arange(51),
1786
+ ... "episode": torch.ones(51),},
1787
+ ... batch_size=[51]
1788
+ ... )
1789
+ >>> rb.extend(ep_1)
1790
+ >>> rb.extend(ep_2)
1791
+ >>>
1792
+ >>> s = rb.sample(50)
1793
+ >>> t = split_trajectories(s, trajectory_key="episode")
1794
+ >>> print(t["obs"])
1795
+ tensor([[14, 15, 16, 17, 18],
1796
+ [ 3, 4, 5, 6, 7]])
1797
+ >>> print(t["episode"])
1798
+ tensor([[0., 0., 0., 0., 0.],
1799
+ [1., 1., 1., 1., 1.]])
1800
+ >>>
1801
+ >>> s = rb.sample(50)
1802
+ >>> t = split_trajectories(s, trajectory_key="episode")
1803
+ >>> print(t["obs"])
1804
+ tensor([[ 4, 5, 6, 7, 8],
1805
+ [26, 27, 28, 29, 30]])
1806
+ >>> print(t["episode"])
1807
+ tensor([[0., 0., 0., 0., 0.],
1808
+ [1., 1., 1., 1., 1.]])
1809
+
1810
+ """
1811
+
1812
+ def __init__(
1813
+ self,
1814
+ *,
1815
+ num_slices: int | None = None,
1816
+ slice_len: int | None = None,
1817
+ drop_last: bool = False,
1818
+ end_key: NestedKey | None = None,
1819
+ traj_key: NestedKey | None = None,
1820
+ ends: torch.Tensor | None = None,
1821
+ trajectories: torch.Tensor | None = None,
1822
+ truncated_key: NestedKey | None = ("next", "truncated"),
1823
+ strict_length: bool = True,
1824
+ shuffle: bool = True,
1825
+ compile: bool | dict = False,
1826
+ use_gpu: bool | torch.device = False,
1827
+ ):
1828
+ SliceSampler.__init__(
1829
+ self,
1830
+ num_slices=num_slices,
1831
+ slice_len=slice_len,
1832
+ end_key=end_key,
1833
+ traj_key=traj_key,
1834
+ cache_values=True,
1835
+ truncated_key=truncated_key,
1836
+ strict_length=strict_length,
1837
+ ends=ends,
1838
+ trajectories=trajectories,
1839
+ compile=compile,
1840
+ use_gpu=use_gpu,
1841
+ )
1842
+ SamplerWithoutReplacement.__init__(self, drop_last=drop_last, shuffle=shuffle)
1843
+
1844
+ def __repr__(self):
1845
+ if self._sample_list is not None:
1846
+ perc = len(self._sample_list) / self.len_storage * 100
1847
+ else:
1848
+ perc = 0
1849
+ return (
1850
+ f"{self.__class__.__name__}("
1851
+ f"num_slices={self.num_slices}, "
1852
+ f"slice_len={self.slice_len}, "
1853
+ f"end_key={self.end_key}, "
1854
+ f"traj_key={self.traj_key}, "
1855
+ f"truncated_key={self.truncated_key}, "
1856
+ f"strict_length={self.strict_length},"
1857
+ f"{perc}% sampled)"
1858
+ )
1859
+
1860
+ def _empty(self):
1861
+ self._cache = {}
1862
+ SamplerWithoutReplacement._empty(self)
1863
+
1864
+ def _storage_len(self, storage):
1865
+ return self._storage_len_buffer
1866
+
1867
+ def sample(
1868
+ self, storage: Storage, batch_size: int
1869
+ ) -> tuple[tuple[torch.Tensor, ...], dict]:
1870
+ if self._batch_size_multiplier is not None:
1871
+ batch_size = batch_size * self._batch_size_multiplier
1872
+ start_idx, stop_idx, lengths = self._get_stop_and_length(storage)
1873
+ # we have to make sure that the number of dims of the storage
1874
+ # is the same as the stop/start signals since we will
1875
+ # use these to sample the storage
1876
+ if start_idx.shape[1] != storage.ndim:
1877
+ raise RuntimeError(
1878
+ f"Expected the end-of-trajectory signal to be "
1879
+ f"{storage.ndim}-dimensional. Got a {start_idx.shape[1]} tensor "
1880
+ "instead."
1881
+ )
1882
+ self._storage_len_buffer = len(start_idx)
1883
+ # first get indices of the trajectories we want to retrieve
1884
+ seq_length, num_slices = self._adjusted_batch_size(batch_size)
1885
+ indices, _ = SamplerWithoutReplacement.sample(self, storage, num_slices)
1886
+ storage_length = storage.shape[0]
1887
+
1888
+ # traj_idx will either be a single tensor or a tuple that can be reorganized
1889
+ # like the output of nonzero() through stacking.
1890
+ def tuple_to_tensor(traj_idx, lengths=lengths):
1891
+ if isinstance(traj_idx, tuple):
1892
+ traj_idx = torch.arange(len(storage), device=lengths.device).view(
1893
+ storage.shape
1894
+ )[traj_idx]
1895
+ return traj_idx
1896
+
1897
+ idx, info = self._sample_slices(
1898
+ lengths,
1899
+ start_idx,
1900
+ stop_idx,
1901
+ seq_length,
1902
+ num_slices,
1903
+ storage_length,
1904
+ traj_idx=tuple_to_tensor(indices),
1905
+ storage=storage,
1906
+ )
1907
+ return idx, info
1908
+
1909
+ def state_dict(self) -> dict[str, Any]:
1910
+ return SamplerWithoutReplacement.state_dict(self)
1911
+
1912
+ def load_state_dict(self, state_dict: dict[str, Any]) -> None:
1913
+ return SamplerWithoutReplacement.load_state_dict(self, state_dict)
1914
+
1915
+
1916
+ class PrioritizedSliceSampler(SliceSampler, PrioritizedSampler):
1917
+ r"""Samples slices of data along the first dimension, given start and stop signals, using prioritized sampling.
1918
+
1919
+ This class combines trajectory sampling with Prioritized Experience Replay (PER) as presented in
1920
+ "Schaul, T.; Quan, J.; Antonoglou, I.; and Silver, D. 2015. Prioritized experience replay."
1921
+ (https://arxiv.org/abs/1511.05952)
1922
+
1923
+ **Core Idea**: Instead of sampling trajectory slices uniformly, this sampler prioritizes
1924
+ trajectory start points based on the importance of the transitions at those positions.
1925
+ This allows focusing learning on the most informative parts of trajectories.
1926
+
1927
+ **How it works**:
1928
+ 1. Each transition is assigned a priority based on its TD error: :math:`p_i = |\delta_i| + \epsilon`
1929
+ 2. Trajectory start points are sampled with probability: :math:`P(i) = \frac{p_i^\alpha}{\sum_j p_j^\alpha}`
1930
+ 3. Importance sampling weights correct for bias: :math:`w_i = (N \cdot P(i))^{-\beta}`
1931
+ 4. Complete trajectory slices are extracted from the sampled start points
1932
+
1933
+ For more info see :class:`~torchrl.data.replay_buffers.samplers.SliceSampler` and :class:`~torchrl.data.replay_buffers.samplers.PrioritizedSampler`.
1934
+
1935
+ .. warning:: PrioritizedSliceSampler will look at the priorities of the individual transitions and sample the
1936
+ start points accordingly. This means that transitions with a low priority may still appear in the
1937
+ samples if they follow a transition of higher priority, and transitions with a high priority but closer to the
1938
+ end of a trajectory may never be sampled if they cannot be used as start points.
1939
+ Currently, it is the user's responsibility to aggregate priorities across items of a trajectory using
1940
+ :meth:`update_priority`.
1941
+
1942
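+ A minimal sketch of such an aggregation (a hypothetical example: per-transition
+ priorities ``priority``, integer trajectory ids ``episode``, and ``rb`` a buffer
+ built around this sampler):
+
+ >>> import torch
+ >>> episode = torch.tensor([0, 0, 0, 1, 1, 1])
+ >>> priority = torch.tensor([0.1, 0.5, 0.2, 1.0, 0.3, 0.9])
+ >>> # take the max priority within each trajectory and broadcast it back
+ >>> agg = torch.zeros(int(episode.max()) + 1).scatter_reduce(
+ ... 0, episode, priority, reduce="amax", include_self=False)
+ >>> rb.update_priority(torch.arange(6), priority=agg[episode])  # doctest: +SKIP
+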
+ Args:
1943
+ max_capacity (int): maximum capacity of the buffer.
1944
+ alpha (:obj:`float`): exponent :math:`\alpha` determines how much prioritization is used.
1945
+ - :math:`\alpha = 0`: uniform sampling of trajectory start points
1946
+ - :math:`\alpha = 1`: full prioritization based on TD error magnitude at start points
1947
+ - Typical values: 0.4-0.7 for balanced prioritization
1948
+ - Higher :math:`\alpha` means more aggressive prioritization of high-error trajectory regions
1949
+ beta (:obj:`float`): importance sampling negative exponent :math:`\beta`.
1950
+ - :math:`\beta` controls the correction for the bias introduced by prioritization
1951
+ - :math:`\beta = 0`: no correction (biased towards high-priority trajectory regions)
1952
+ - :math:`\beta = 1`: full correction (unbiased but potentially unstable)
1953
+ - Typical values: start at 0.4-0.6 and anneal to 1.0 during training
1954
+ - Lower :math:`\beta` early in training provides stability, higher :math:`\beta` later reduces bias
1955
+ eps (:obj:`float`, optional): small constant added to priorities to ensure
1956
+ no transition has zero priority. This prevents trajectory regions from never
1957
+ being sampled. Defaults to 1e-8.
1958
+ reduction (str, optional): the reduction method for multidimensional
1959
+ tensordicts (i.e., stored trajectory). Can be one of "max", "min",
1960
+ "median" or "mean".
1961
+
1962
+ **Parameter Guidelines**:
1963
+
1964
+ - **:math:`\alpha` (alpha)**: Controls how much to prioritize high-error trajectory regions.
1965
+ 0.4-0.7: Good balance between learning speed and stability.
1966
+ 1.0: Maximum prioritization (may be unstable).
1967
+ 0.0: Uniform sampling (no prioritization benefit).
1968
+
1969
+ - **:math:`\beta` (beta)**: Controls importance sampling correction.
1970
+ Start at 0.4-0.6 for training stability.
1971
+ Anneal to 1.0 over training to reduce bias.
1972
+ Lower values = more stable but biased.
1973
+ Higher values = less biased but potentially unstable.
1974
+
1975
+ - **:math:`\epsilon`**: Small constant to prevent zero priorities.
1976
+ 1e-8: Good default value.
1977
+ Too small: may cause numerical issues.
1978
+ Too large: reduces prioritization effect.
1979
+
1980
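+ As a plain-PyTorch illustration of the formulas above (a sketch with made-up
+ TD errors, independent of the sampler itself):
+
+ >>> import torch
+ >>> td_error = torch.tensor([0.1, 2.0, 0.5])
+ >>> alpha, beta, eps = 0.7, 0.9, 1e-8
+ >>> p = td_error.abs() + eps          # p_i = |delta_i| + eps
+ >>> P = p**alpha / (p**alpha).sum()   # P(i) = p_i^alpha / sum_j p_j^alpha
+ >>> w = (len(p) * P) ** (-beta)       # w_i = (N * P(i))^(-beta)
+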
+ Keyword Args:
1981
+ num_slices (int): the number of slices to be sampled. The batch-size
1982
+ must be greater or equal to the ``num_slices`` argument. Exclusive
1983
+ with ``slice_len``.
1984
+ slice_len (int): the length of the slices to be sampled. The batch-size
1985
+ must be greater or equal to the ``slice_len`` argument and divisible
1986
+ by it. Exclusive with ``num_slices``.
1987
+ end_key (NestedKey, optional): the key indicating the end of a
1988
+ trajectory (or episode). Defaults to ``("next", "done")``.
1989
+ traj_key (NestedKey, optional): the key indicating the trajectories.
1990
+ Defaults to ``"episode"`` (commonly used across datasets in TorchRL).
1991
+ ends (torch.Tensor, optional): a 1d boolean tensor containing the end of run signals.
1992
+ To be used whenever the ``end_key`` or ``traj_key`` is expensive to get,
1993
+ or when this signal is readily available. Must be used with ``cache_values=True``
1994
+ and cannot be used in conjunction with ``end_key`` or ``traj_key``.
1995
+ trajectories (torch.Tensor, optional): a 1d integer tensor containing the run ids.
1996
+ To be used whenever the ``end_key`` or ``traj_key`` is expensive to get,
1997
+ or when this signal is readily available. Must be used with ``cache_values=True``
1998
+ and cannot be used in conjunction with ``end_key`` or ``traj_key``.
1999
+ cache_values (bool, optional): to be used with static datasets.
2000
+ Will cache the start and end signal of the trajectory. This can be safely used even
2001
+ if the trajectory indices change during calls to :class:`~torchrl.data.ReplayBuffer.extend`
2002
+ as this operation will erase the cache.
2003
+
2004
+ .. warning:: ``cache_values=True`` will not work if the sampler is used with a
2005
+ storage that is extended by another buffer. For instance:
2006
+
2007
+ >>> buffer0 = ReplayBuffer(storage=storage,
2008
+ ... sampler=SliceSampler(num_slices=8, cache_values=True),
2009
+ ... writer=ImmutableWriter())
2010
+ >>> buffer1 = ReplayBuffer(storage=storage,
2011
+ ... sampler=other_sampler)
2012
+ >>> # Wrong! Does not erase the cache from the sampler of buffer0
2013
+ >>> buffer1.extend(data)
2014
+
2015
+ .. warning:: ``cache_values=True`` will not work as expected if the buffer is
2016
+ shared between processes and one process is responsible for writing
2017
+ and one process for sampling, as erasing the cache can only be done locally.
2018
+
2019
+ truncated_key (NestedKey, optional): If not ``None``, this argument
2020
+ indicates where a truncated signal should be written in the output
2021
+ data. This is used to indicate to value estimators where the provided
2022
+ trajectory breaks. Defaults to ``("next", "truncated")``.
2023
+ This feature only works with :class:`~torchrl.data.replay_buffers.TensorDictReplayBuffer`
2024
+ instances (otherwise the truncated key is returned in the info dictionary
2025
+ returned by the :meth:`~torchrl.data.replay_buffers.ReplayBuffer.sample` method).
2026
+ strict_length (bool, optional): if ``False``, trajectories of length
2027
+ shorter than `slice_len` (or `batch_size // num_slices`) will be
2028
+ allowed to appear in the batch. If ``True``, trajectories shorter
2029
+ than required will be filtered out.
2030
+ Be mindful that this can result in an effective `batch_size` smaller
2031
+ than the one asked for! Trajectories can be split using
2032
+ :func:`~torchrl.collectors.split_trajectories`. Defaults to ``True``.
2033
+ compile (bool or dict of kwargs, optional): if ``True``, the bottleneck of
2034
+ the :meth:`~sample` method will be compiled with :func:`~torch.compile`.
2035
+ Keyword arguments can also be passed to torch.compile with this arg.
2036
+ Defaults to ``False``.
2037
+ span (bool, int, Tuple[bool | int, bool | int], optional): if provided, the sampled
2038
+ slice may span beyond the left (beginning) and/or the right (end) of the stored
2039
+ trajectory, in which case fewer elements than requested may be returned. A boolean
2040
+ value means that at least one element will be sampled per trajectory. An integer `i`
2041
+ means that at least `slice_len - i` samples will be gathered for each sampled trajectory.
2042
+ Tuples allow fine-grained control over the span on the left (beginning of the
2043
+ stored trajectory) and on the right (end of the stored trajectory).
2044
+ max_priority_within_buffer (bool, optional): if ``True``, the max-priority
2045
+ is tracked within the buffer. When ``False``, the max-priority tracks
2046
+ the maximum value since the instantiation of the sampler.
2047
+ Defaults to ``False``.
2048
+
2049
+ Examples:
2050
+ >>> import torch
2051
+ >>> from torchrl.data.replay_buffers import TensorDictReplayBuffer, LazyMemmapStorage, PrioritizedSliceSampler
2052
+ >>> from tensordict import TensorDict
2053
+ >>> sampler = PrioritizedSliceSampler(max_capacity=9, num_slices=3, alpha=0.7, beta=0.9)
2054
+ >>> rb = TensorDictReplayBuffer(storage=LazyMemmapStorage(9), sampler=sampler, batch_size=6)
2055
+ >>> data = TensorDict(
2056
+ ... {
2057
+ ... "observation": torch.randn(9,16),
2058
+ ... "action": torch.randn(9, 1),
2059
+ ... "episode": torch.tensor([0,0,0,1,1,1,2,2,2], dtype=torch.long),
2060
+ ... "steps": torch.tensor([0,1,2,0,1,2,0,1,2], dtype=torch.long),
2061
+ ... ("next", "observation"): torch.randn(9,16),
2062
+ ... ("next", "reward"): torch.randn(9,1),
2063
+ ... ("next", "done"): torch.tensor([0,0,1,0,0,1,0,0,1], dtype=torch.bool).unsqueeze(1),
2064
+ ... },
2065
+ ... batch_size=[9],
2066
+ ... )
2067
+ >>> rb.extend(data)
2068
+ >>> sample, info = rb.sample(return_info=True)
2069
+ >>> print("episode", sample["episode"].tolist())
2070
+ episode [2, 2, 2, 2, 1, 1]
2071
+ >>> print("steps", sample["steps"].tolist())
2072
+ steps [1, 2, 0, 1, 1, 2]
2073
+ >>> print("weight", info["priority_weight"].tolist())
2074
+ weight [1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
2075
+ >>> priority = torch.tensor([0,3,3,0,0,0,1,1,1])
2076
+ >>> rb.update_priority(torch.arange(0,9,1), priority=priority)
2077
+ >>> sample, info = rb.sample(return_info=True)
2078
+ >>> print("episode", sample["episode"].tolist())
2079
+ episode [2, 2, 2, 2, 2, 2]
2080
+ >>> print("steps", sample["steps"].tolist())
2081
+ steps [1, 2, 0, 1, 0, 1]
2082
+ >>> print("weight", info["priority_weight"].tolist())
2083
+ weight [9.120110917137936e-06, 9.120110917137936e-06, 9.120110917137936e-06, 9.120110917137936e-06, 9.120110917137936e-06, 9.120110917137936e-06]
2084
+ """
2085
+
2086
+ def __init__(
2087
+ self,
2088
+ max_capacity: int,
2089
+ alpha: float,
2090
+ beta: float,
2091
+ eps: float = 1e-8,
2092
+ dtype: torch.dtype = torch.float,
2093
+ reduction: str = "max",
2094
+ *,
2095
+ num_slices: int | None = None,
2096
+ slice_len: int | None = None,
2097
+ end_key: NestedKey | None = None,
2098
+ traj_key: NestedKey | None = None,
2099
+ ends: torch.Tensor | None = None,
2100
+ trajectories: torch.Tensor | None = None,
2101
+ cache_values: bool = False,
2102
+ truncated_key: NestedKey | None = ("next", "truncated"),
2103
+ strict_length: bool = True,
2104
+ compile: bool | dict = False,
2105
+ span: bool | int | tuple[bool | int, bool | int] = False,
2106
+ max_priority_within_buffer: bool = False,
2107
+ ):
2108
+ SliceSampler.__init__(
2109
+ self,
2110
+ num_slices=num_slices,
2111
+ slice_len=slice_len,
2112
+ end_key=end_key,
2113
+ traj_key=traj_key,
2114
+ cache_values=cache_values,
2115
+ truncated_key=truncated_key,
2116
+ strict_length=strict_length,
2117
+ ends=ends,
2118
+ trajectories=trajectories,
2119
+ compile=compile,
2120
+ span=span,
2121
+ )
2122
+ PrioritizedSampler.__init__(
2123
+ self,
2124
+ max_capacity=max_capacity,
2125
+ alpha=alpha,
2126
+ beta=beta,
2127
+ eps=eps,
2128
+ dtype=dtype,
2129
+ reduction=reduction,
2130
+ max_priority_within_buffer=max_priority_within_buffer,
2131
+ )
2132
+ if self.span[0]:
2133
+ # Span left is hard to achieve because we need to sample 'negative' starts, but to sample
2134
+ # the start we rely on PrioritizedSampler which has no idea it's looking at trajectories.
2135
+ #
2136
+ # Another way to go about this would be to stochastically decrease the seq_length to
2137
+ # accommodate this but that would require to over-sample the starts too.
2138
+ #
2139
+ warnings.warn(
2140
+ f"Left spanning is disabled for {type(self).__name__} and will be automatically turned off. "
2141
+ f"If this feature is required, please file an issue on torchrl GitHub repo."
2142
+ )
2143
+ self.span = (0, self.span[1])
2144
+
2145
+ def __repr__(self):
2146
+ return (
2147
+ f"{self.__class__.__name__}("
2148
+ f"num_slices={self.num_slices}, "
2149
+ f"slice_len={self.slice_len}, "
2150
+ f"end_key={self.end_key}, "
2151
+ f"traj_key={self.traj_key}, "
2152
+ f"truncated_key={self.truncated_key}, "
2153
+ f"strict_length={self.strict_length},"
2154
+ f"alpha={self._alpha}, "
2155
+ f"beta={self._beta}, "
2156
+ f"eps={self._eps}"
2157
+ )
2158
+
2159
+ def __getstate__(self):
2160
+ state = SliceSampler.__getstate__(self)
2161
+ state.update(PrioritizedSampler.__getstate__(self))
2162
+ return state
2163
+
2164
+ def mark_update(
2165
+ self, index: int | torch.Tensor, *, storage: Storage | None = None
2166
+ ) -> None:
2167
+ return PrioritizedSampler.mark_update(self, index, storage=storage)
2168
+
2169
+ def _padded_indices(self, shapes, arange) -> torch.Tensor:
2170
+ # this complex mumbo jumbo creates a left padded tensor with valid indices on the right, e.g.
2171
+ # tensor([[ 0, 1, 2, 3, 4],
2172
+ # [-1, -1, 5, 6, 7],
2173
+ # [-1, 8, 9, 10, 11]])
2174
+ # where the -1 items on the left are padded values
2175
+ num_groups = shapes.shape[0]
2176
+ max_group_len = shapes.max()
2177
+ pad_lengths = max_group_len - shapes
2178
+
2179
+ # Get all the start and end indices within arange for each group
2180
+ group_ends = shapes.cumsum(0)
2181
+ group_starts = torch.empty_like(group_ends)
2182
+ group_starts[0] = 0
2183
+ group_starts[1:] = group_ends[:-1]
2184
+ pad = torch.empty(
2185
+ (num_groups, max_group_len), dtype=arange.dtype, device=arange.device
2186
+ )
2187
+ for pad_row, group_start, group_end, pad_len in zip(
2188
+ pad, group_starts, group_ends, pad_lengths
2189
+ ):
2190
+ pad_row[:pad_len] = -1
2191
+ pad_row[pad_len:] = arange[group_start:group_end]
2192
+
2193
+ return pad
2194
+
2195
+ def _preceding_stop_idx(self, storage, lengths, seq_length, start_idx):
2196
+ preceding_stop_idx = self._cache.get("preceding_stop_idx")
2197
+ if preceding_stop_idx is not None:
2198
+ return preceding_stop_idx
2199
+ arange = torch.arange(storage.shape.numel())
2200
+ shapes = lengths.view(-1, 1).cpu()
2201
+ if shapes.sum() - 1 != arange[-1]:
2202
+ raise RuntimeError("Wrong shapes / arange configuration")
2203
+ if not self.strict_length:
2204
+ # First, remove the starts from the arange
2205
+ # We do this because each traj can be sampled
2206
+ all_but_starts = torch.ones(arange.shape, dtype=torch.bool)
2207
+ starts = lengths.cumsum(0)
2208
+ starts = torch.cat([torch.zeros_like(starts[:1]), starts[:-1]])
2209
+ all_but_starts[starts] = False
2210
+ arange = arange[all_but_starts]
2211
+ shapes = shapes - 1
2212
+ pad = self._padded_indices(shapes, arange)
2213
+ _, span_right = self.span[0], self.span[1]
2214
+ if span_right and isinstance(span_right, bool):
2215
+ preceding_stop_idx = pad[:, -1:]
2216
+ else:
2217
+ # Mask the rightmost values of that padded tensor
2218
+ preceding_stop_idx = pad[:, -seq_length + 1 + span_right :]
2219
+ preceding_stop_idx = preceding_stop_idx[preceding_stop_idx >= 0]
2220
+ if storage._is_full:
2221
+ preceding_stop_idx = (
2222
+ preceding_stop_idx
2223
+ + np.ravel_multi_index(
2224
+ tuple(start_idx[0].tolist()), storage._total_shape
2225
+ )
2226
+ ) % storage._total_shape.numel()
2227
+ if self.cache_values:
2228
+ self._cache["preceding_stop_idx"] = preceding_stop_idx
2229
+ return preceding_stop_idx
2230
+
2231
+ def sample(self, storage: Storage, batch_size: int) -> tuple[torch.Tensor, dict]:
2232
+ # Sample `batch_size` indices representing the start of a slice.
2233
+ # The sampling is based on a weight vector.
2234
+ start_idx, stop_idx, lengths = self._get_stop_and_length(storage)
2235
+ seq_length, num_slices = self._adjusted_batch_size(batch_size)
2236
+
2237
+ preceding_stop_idx = self._preceding_stop_idx(
2238
+ storage, lengths, seq_length, start_idx
2239
+ )
2240
+ if storage.ndim > 1:
2241
+ # we need to convert indices of the permuted, flattened storage to indices in a flattened storage (not permuted)
2242
+ # This is because the lengths come as they would for a permuted storage
2243
+ preceding_stop_idx = unravel_index(
2244
+ preceding_stop_idx, (storage.shape[-1], *storage.shape[:-1])
2245
+ )
2246
+ preceding_stop_idx = (preceding_stop_idx[-1], *preceding_stop_idx[:-1])
2247
+ preceding_stop_idx = torch.as_tensor(
2248
+ np.ravel_multi_index(preceding_stop_idx, storage.shape)
2249
+ )
2250
+
2251
+ # force to not sample index at the end of a trajectory
2252
+ vals = torch.tensor(self._sum_tree[preceding_stop_idx.cpu().numpy()])
2253
+ self._sum_tree[preceding_stop_idx.cpu().numpy()] = 0.0
2254
+ # and no need to update self._min_tree
2255
+
2256
+ starts, info = PrioritizedSampler.sample(
2257
+ self, storage=storage, batch_size=batch_size // seq_length
2258
+ )
2259
+ self._sum_tree[preceding_stop_idx.cpu().numpy()] = vals
+         # The seq_length must be truncated if (1) strict_length is False
+         # or (2) span[1] is set.
+         if self.span[1] or not self.strict_length:
+             if not isinstance(starts, torch.Tensor):
+                 starts_tensor = torch.stack(list(starts), dim=-1).to(stop_idx.device)
+             else:
+                 starts_tensor = starts.unsqueeze(1).to(stop_idx.device)
+             # Find the stop that comes after each start index:
+             # if starts_tensor has shape [N, X] and stop_idx has shape [M, X],
+             # diff will have shape [M, N, X].
+             stop_idx_corr = stop_idx.clone()
+             stop_idx_corr[:, 0] = torch.where(
+                 stop_idx[:, 0] < start_idx[:, 0],
+                 stop_idx[:, 0] + storage._len_along_dim0,
+                 stop_idx[:, 0],
+             )
+             diff = stop_idx_corr.unsqueeze(1) - starts_tensor.unsqueeze(0)
+             # Filter out all items that don't belong to the same dim in the storage
+             mask = (diff[:, :, 1:] != 0).any(-1)
+             diff = diff[:, :, 0]
+             diff[mask] = diff.max() + 1
+             diff = diff.reshape(-1, starts_tensor.shape[0])
+             # Remove all negative values from consideration
+             diff[diff < 0] = diff.max() + 1
+             # Take the argmin along dim 0 (thereby reducing dim M)
+             idx = diff.argmin(dim=0)
+             stops = stop_idx_corr[idx, 0]
+             # TODO: this may misbehave with trajectories that span the storage
+             # boundary, although no test currently exposes a failure.
+             if starts_tensor.ndim > 1:
+                 starts_tensor = starts_tensor[:, 0]
+             seq_length = (stops - starts_tensor + 1).clamp_max(seq_length)
+             if (seq_length <= 0).any():
+                 raise RuntimeError(
+                     "failed to compute seq_length, please report this bug"
+                 )
+
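# A worked toy instance of the diff/argmin matching above (illustrative
# numbers only): with starts [2, 6] and corrected stops [4, 9],
#     diff = stops[:, None] - starts[None, :] = [[2, -2], [7, 3]].
# Negative entries are pushed past max(), so the column-wise argmin picks the
# first stop at or after each start: stops [4, 9], hence per-slice lengths
# (stops - starts + 1) = [3, 4], then clamped to seq_length.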
+         if isinstance(starts, tuple):
+             starts = torch.stack(starts, -1)
+             # starts = torch.as_tensor(starts, device=lengths.device)
+         info["priority_weight"] = torch.as_tensor(
+             info["priority_weight"], device=lengths.device
+         )
+
+         # Extend the starting index of each slice by seq_length steps to get
+         # the indices of every step in the slice.
+         index = self._tensor_slices_from_startend(
+             seq_length, starts, storage_length=storage.shape[0]
+         )
+
+         # Repeat the weight of each slice to match its number of steps.
+         info["priority_weight"] = torch.repeat_interleave(
+             info["priority_weight"], seq_length
+         )
+
+         if self.truncated_key is not None:
+             # The following logic is borrowed from SliceSampler.
+             truncated_key = self.truncated_key
+
+             done_key = _replace_last(truncated_key, "done")
+             terminated_key = _replace_last(truncated_key, "terminated")
+
+             truncated = torch.zeros(
+                 (index.shape[0], 1), dtype=torch.bool, device=index.device
+             )
+             if isinstance(seq_length, int):
+                 truncated.view(num_slices, -1)[:, -1] = 1
+             else:
+                 truncated[seq_length.cumsum(0) - 1] = 1
+             index = index.to(torch.long).unbind(-1)
+             st_index = storage[index]
+             try:
+                 done = st_index[done_key] | truncated
+             except KeyError:
+                 done = truncated.clone()
+             try:
+                 terminated = st_index[terminated_key]
+             except KeyError:
+                 terminated = torch.zeros_like(truncated)
+             info.update(
+                 {
+                     truncated_key: truncated,
+                     done_key: done,
+                     terminated_key: terminated,
+                 }
+             )
+             return index, info
+         return index.to(torch.long).unbind(-1), info
+
+     def _empty(self):
+         # no-op on the SliceSampler side; delegate to PrioritizedSampler
+         PrioritizedSampler._empty(self)
+
+     def dumps(self, path):
+         # no-op on the SliceSampler side; delegate to PrioritizedSampler
+         PrioritizedSampler.dumps(self, path)
+
+     def loads(self, path):
+         # no-op on the SliceSampler side; delegate to PrioritizedSampler
+         return PrioritizedSampler.loads(self, path)
+
+     def state_dict(self):
+         # no-op on the SliceSampler side; delegate to PrioritizedSampler
+         return PrioritizedSampler.state_dict(self)
+
+     def add(self, index: torch.Tensor) -> None:
+         PrioritizedSampler.add(self, index)
+         return SliceSampler.add(self, index)
+
+     def extend(self, index: torch.Tensor) -> None:
+         PrioritizedSampler.extend(self, index)
+         return SliceSampler.extend(self, index)
+
+
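# The core trick in `sample()` above -- temporarily zeroing the priorities of
# forbidden indices, drawing from the remaining mass, then restoring the saved
# values -- can be reproduced with a plain priority vector standing in for the
# sum tree (a simplified sketch, not the class's actual `_sum_tree` buffer):
import torch

priorities = torch.tensor([1.0, 4.0, 2.0, 3.0])
forbidden = torch.tensor([1, 3])  # indices that must not be drawn

saved = priorities[forbidden].clone()   # save the masked priorities
priorities[forbidden] = 0.0             # exclude them from the draw
samples = torch.multinomial(priorities, num_samples=2, replacement=True)
priorities[forbidden] = saved           # restore the original weights

# `samples` can only contain indices 0 and 2.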
+ class SamplerEnsemble(Sampler):
+     """An ensemble of samplers.
+
+     This class is designed to work with :class:`~torchrl.data.replay_buffers.replay_buffers.ReplayBufferEnsemble`.
+     It contains the samplers as well as the sampling strategy hyperparameters.
+
+     Args:
+         samplers (sequence of Sampler): the samplers to make the composite sampler.
+
+     Keyword Args:
+         p (list or tensor of probabilities, optional): if provided, indicates the
+             weights of each dataset during sampling.
+         sample_from_all (bool, optional): if ``True``, each dataset will be sampled
+             from. This is not compatible with the ``p`` argument. Defaults to ``False``.
+         num_buffer_sampled (int, optional): the number of buffers to sample.
+             If ``sample_from_all=True``, this has no effect, as it defaults to the
+             number of buffers. If ``sample_from_all=False``, buffers will be
+             sampled according to the probabilities ``p``.
+
+     .. warning::
+         The indices provided in the info dictionary are placed in a :class:`~tensordict.TensorDict` with
+         keys ``index`` and ``buffer_ids`` that allow the upper :class:`~torchrl.data.ReplayBufferEnsemble`
+         and :class:`~torchrl.data.StorageEnsemble` objects to retrieve the data.
+         This format differs from that of other samplers, which usually return
+         indices as regular tensors.
+
+     """
+
+     def __init__(
+         self, *samplers, p=None, sample_from_all=False, num_buffer_sampled=None
+     ):
+         self._rng_private = None
+         self._samplers = samplers
+         self.sample_from_all = sample_from_all
+         if sample_from_all and p is not None:
+             raise RuntimeError(
+                 "Cannot pass both `p` argument and `sample_from_all=True`."
+             )
+         self.p = p
+         self.num_buffer_sampled = num_buffer_sampled
+
+     @property
+     def _rng(self):
+         return self._rng_private
+
+     @_rng.setter
+     def _rng(self, value):
+         self._rng_private = value
+         for sampler in self._samplers:
+             sampler._rng = value
+
+     @property
+     def p(self):
+         return self._p
+
+     @p.setter
+     def p(self, value):
+         if not isinstance(value, torch.Tensor) and value is not None:
+             value = torch.tensor(value)
+         if value is not None:
+             value = value / value.sum().clamp_min(1e-6)
+         self._p = value
+
+     @property
+     def num_buffer_sampled(self):
+         value = self.__dict__.get("_num_buffer_sampled", None)
+         if value is None:
+             value = self.__dict__["_num_buffer_sampled"] = len(self._samplers)
+         return value
+
+     @num_buffer_sampled.setter
+     def num_buffer_sampled(self, value):
+         self.__dict__["_num_buffer_sampled"] = value
+
+     def sample(self, storage, batch_size):
+         if batch_size % self.num_buffer_sampled > 0:
+             raise ValueError(
+                 f"batch_size ({batch_size}) must be a multiple of the number "
+                 f"of buffers sampled ({self.num_buffer_sampled})."
+             )
+         if not isinstance(storage, StorageEnsemble):
+             raise TypeError(
+                 f"Expected a storage of type StorageEnsemble, got {type(storage)}."
+             )
+         sub_batch_size = batch_size // self.num_buffer_sampled
+         if self.sample_from_all:
+             samples, infos = zip(
+                 *[
+                     sampler.sample(storage, sub_batch_size)
+                     for storage, sampler in zip(storage._storages, self._samplers)
+                 ]
+             )
+             buffer_ids = torch.arange(len(samples))
+         else:
+             if self.p is None:
+                 buffer_ids = torch.randint(
+                     len(self._samplers),
+                     (self.num_buffer_sampled,),
+                     generator=self._rng,
+                     device=getattr(storage, "device", None),
+                 )
+             else:
+                 buffer_ids = torch.multinomial(self.p, self.num_buffer_sampled, True)
+             samples, infos = zip(
+                 *[
+                     self._samplers[i].sample(storage._storages[i], sub_batch_size)
+                     for i in buffer_ids.tolist()
+                 ]
+             )
+         samples = [
+             sample if isinstance(sample, torch.Tensor) else torch.stack(sample, -1)
+             for sample in samples
+         ]
+         if all(samples[0].shape == sample.shape for sample in samples[1:]):
+             samples_stack = torch.stack(samples)
+         else:
+             # heterogeneous index shapes: fall back to a nested tensor
+             samples_stack = torch.nested.nested_tensor(list(samples))
+
+         samples = TensorDict(
+             {
+                 "index": samples_stack,
+                 "buffer_ids": buffer_ids,
+             },
+             batch_size=[self.num_buffer_sampled],
+         )
+         infos = torch.stack(
+             [
+                 TensorDict.from_dict(info, batch_dims=samples.ndim - 1)
+                 if info
+                 else TensorDict()
+                 for info in infos
+             ]
+         )
+         return samples, infos
+
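# The shape check above decides between a regular stack and a nested tensor.
# A minimal sketch of that fallback (illustrative values; nested-tensor
# support varies across dtypes and ops):
import torch

a = torch.arange(6).reshape(2, 3)   # indices from one buffer
b = torch.arange(4).reshape(2, 2)   # a differently-shaped set of indices

if a.shape == b.shape:
    stacked = torch.stack([a, b])                 # regular [2, ...] tensor
else:
    stacked = torch.nested.nested_tensor([a, b])  # ragged container

# stacked.is_nested -> True, since shapes (2, 3) and (2, 2) differ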
+     def dumps(self, path: Path):
+         path = Path(path).absolute()
+         for i, sampler in enumerate(self._samplers):
+             sampler.dumps(path / str(i))
+
+     def loads(self, path: Path):
+         path = Path(path).absolute()
+         for i, sampler in enumerate(self._samplers):
+             sampler.loads(path / str(i))
+
+     def state_dict(self) -> dict[str, Any]:
+         state_dict = OrderedDict()
+         for i, sampler in enumerate(self._samplers):
+             state_dict[str(i)] = sampler.state_dict()
+         return state_dict
+
+     def load_state_dict(self, state_dict: dict[str, Any]) -> None:
+         for i, sampler in enumerate(self._samplers):
+             sampler.load_state_dict(state_dict[str(i)])
+
+     def _empty(self):
+         raise NotImplementedError
+
+     _INDEX_ERROR = "Expected an index of type torch.Tensor, range, np.ndarray, int, slice or ellipsis, got {} instead."
+
+     def __getitem__(self, index):
+         if isinstance(index, tuple):
+             if index[0] is Ellipsis:
+                 index = (slice(None), index[1:])
+             result = self[index[0]]
+             if len(index) > 1:
+                 raise IndexError(
+                     f"Tuples of length greater than 1 are not accepted to index samplers of type {type(self)}."
+                 )
+             return result
+         if isinstance(index, slice) and index == slice(None):
+             return self
+         if isinstance(index, (list, range, np.ndarray)):
+             index = torch.as_tensor(index)
+         if isinstance(index, torch.Tensor):
+             if index.ndim > 1:
+                 raise RuntimeError(
+                     f"Cannot index a {type(self)} with tensor indices that have more than one dimension."
+                 )
+             if index.is_floating_point():
+                 raise TypeError(
+                     "A floating point index was received when an integer dtype was expected."
+                 )
+         if isinstance(index, int) or (not isinstance(index, slice) and len(index) == 0):
+             try:
+                 index = int(index)
+             except Exception:
+                 raise IndexError(self._INDEX_ERROR.format(type(index)))
+             try:
+                 return self._samplers[index]
+             except IndexError:
+                 raise IndexError(self._INDEX_ERROR.format(type(index)))
+         if isinstance(index, torch.Tensor):
+             index = index.tolist()
+             samplers = [self._samplers[i] for i in index]
+         else:
+             # slice
+             samplers = self._samplers[index]
+         p = self._p[index] if self._p is not None else None
+         return SamplerEnsemble(
+             *samplers,
+             p=p,
+             sample_from_all=self.sample_from_all,
+             num_buffer_sampled=self.num_buffer_sampled,
+         )
+
+     def __len__(self):
+         return len(self._samplers)
+
+     def __repr__(self):
+         samplers = textwrap.indent(f"samplers={self._samplers}", " " * 4)
+         return f"{self.__class__.__name__}(\n{samplers})"