torchrl-0.11.0-cp314-cp314t-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/benchmark_batched_envs.py +104 -0
- benchmarks/conftest.py +91 -0
- benchmarks/ecosystem/gym_env_throughput.py +321 -0
- benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
- benchmarks/requirements.txt +7 -0
- benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
- benchmarks/test_collectors_benchmark.py +240 -0
- benchmarks/test_compressed_storage_benchmark.py +145 -0
- benchmarks/test_envs_benchmark.py +133 -0
- benchmarks/test_llm.py +101 -0
- benchmarks/test_non_tensor_env_benchmark.py +70 -0
- benchmarks/test_objectives_benchmarks.py +1199 -0
- benchmarks/test_replaybuffer_benchmark.py +254 -0
- sota-check/README.md +35 -0
- sota-implementations/README.md +142 -0
- sota-implementations/a2c/README.md +39 -0
- sota-implementations/a2c/a2c_atari.py +291 -0
- sota-implementations/a2c/a2c_mujoco.py +273 -0
- sota-implementations/a2c/utils_atari.py +240 -0
- sota-implementations/a2c/utils_mujoco.py +160 -0
- sota-implementations/bandits/README.md +7 -0
- sota-implementations/bandits/dqn.py +126 -0
- sota-implementations/cql/cql_offline.py +198 -0
- sota-implementations/cql/cql_online.py +249 -0
- sota-implementations/cql/discrete_cql_offline.py +180 -0
- sota-implementations/cql/discrete_cql_online.py +227 -0
- sota-implementations/cql/utils.py +471 -0
- sota-implementations/crossq/crossq.py +271 -0
- sota-implementations/crossq/utils.py +320 -0
- sota-implementations/ddpg/ddpg.py +231 -0
- sota-implementations/ddpg/utils.py +325 -0
- sota-implementations/decision_transformer/dt.py +163 -0
- sota-implementations/decision_transformer/lamb.py +167 -0
- sota-implementations/decision_transformer/online_dt.py +178 -0
- sota-implementations/decision_transformer/utils.py +562 -0
- sota-implementations/discrete_sac/discrete_sac.py +243 -0
- sota-implementations/discrete_sac/utils.py +324 -0
- sota-implementations/dqn/README.md +30 -0
- sota-implementations/dqn/dqn_atari.py +272 -0
- sota-implementations/dqn/dqn_cartpole.py +236 -0
- sota-implementations/dqn/utils_atari.py +132 -0
- sota-implementations/dqn/utils_cartpole.py +90 -0
- sota-implementations/dreamer/README.md +129 -0
- sota-implementations/dreamer/dreamer.py +586 -0
- sota-implementations/dreamer/dreamer_utils.py +1107 -0
- sota-implementations/expert-iteration/README.md +352 -0
- sota-implementations/expert-iteration/ei_utils.py +770 -0
- sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
- sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
- sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
- sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
- sota-implementations/gail/gail.py +327 -0
- sota-implementations/gail/gail_utils.py +68 -0
- sota-implementations/gail/ppo_utils.py +157 -0
- sota-implementations/grpo/README.md +273 -0
- sota-implementations/grpo/grpo-async.py +437 -0
- sota-implementations/grpo/grpo-sync.py +435 -0
- sota-implementations/grpo/grpo_utils.py +843 -0
- sota-implementations/grpo/requirements_gsm8k.txt +11 -0
- sota-implementations/grpo/requirements_ifeval.txt +16 -0
- sota-implementations/impala/README.md +33 -0
- sota-implementations/impala/impala_multi_node_ray.py +292 -0
- sota-implementations/impala/impala_multi_node_submitit.py +284 -0
- sota-implementations/impala/impala_single_node.py +261 -0
- sota-implementations/impala/utils.py +184 -0
- sota-implementations/iql/discrete_iql.py +230 -0
- sota-implementations/iql/iql_offline.py +164 -0
- sota-implementations/iql/iql_online.py +225 -0
- sota-implementations/iql/utils.py +437 -0
- sota-implementations/multiagent/README.md +74 -0
- sota-implementations/multiagent/iql.py +237 -0
- sota-implementations/multiagent/maddpg_iddpg.py +266 -0
- sota-implementations/multiagent/mappo_ippo.py +267 -0
- sota-implementations/multiagent/qmix_vdn.py +271 -0
- sota-implementations/multiagent/sac.py +337 -0
- sota-implementations/multiagent/utils/__init__.py +4 -0
- sota-implementations/multiagent/utils/logging.py +151 -0
- sota-implementations/multiagent/utils/utils.py +43 -0
- sota-implementations/ppo/README.md +29 -0
- sota-implementations/ppo/ppo_atari.py +305 -0
- sota-implementations/ppo/ppo_mujoco.py +293 -0
- sota-implementations/ppo/utils_atari.py +238 -0
- sota-implementations/ppo/utils_mujoco.py +152 -0
- sota-implementations/ppo_trainer/train.py +21 -0
- sota-implementations/redq/README.md +7 -0
- sota-implementations/redq/redq.py +199 -0
- sota-implementations/redq/utils.py +1060 -0
- sota-implementations/sac/sac-async.py +266 -0
- sota-implementations/sac/sac.py +239 -0
- sota-implementations/sac/utils.py +381 -0
- sota-implementations/sac_trainer/train.py +16 -0
- sota-implementations/td3/td3.py +254 -0
- sota-implementations/td3/utils.py +319 -0
- sota-implementations/td3_bc/td3_bc.py +177 -0
- sota-implementations/td3_bc/utils.py +251 -0
- torchrl/.dylibs/libc++.1.0.dylib +0 -0
- torchrl/__init__.py +144 -0
- torchrl/_extension.py +74 -0
- torchrl/_torchrl.cpython-314t-darwin.so +0 -0
- torchrl/_utils.py +1431 -0
- torchrl/collectors/__init__.py +48 -0
- torchrl/collectors/_base.py +1058 -0
- torchrl/collectors/_constants.py +88 -0
- torchrl/collectors/_multi_async.py +324 -0
- torchrl/collectors/_multi_base.py +1805 -0
- torchrl/collectors/_multi_sync.py +464 -0
- torchrl/collectors/_runner.py +581 -0
- torchrl/collectors/_single.py +2009 -0
- torchrl/collectors/_single_async.py +259 -0
- torchrl/collectors/collectors.py +62 -0
- torchrl/collectors/distributed/__init__.py +32 -0
- torchrl/collectors/distributed/default_configs.py +133 -0
- torchrl/collectors/distributed/generic.py +1306 -0
- torchrl/collectors/distributed/ray.py +1092 -0
- torchrl/collectors/distributed/rpc.py +1006 -0
- torchrl/collectors/distributed/sync.py +731 -0
- torchrl/collectors/distributed/utils.py +160 -0
- torchrl/collectors/llm/__init__.py +10 -0
- torchrl/collectors/llm/base.py +494 -0
- torchrl/collectors/llm/ray_collector.py +275 -0
- torchrl/collectors/llm/utils.py +36 -0
- torchrl/collectors/llm/weight_update/__init__.py +10 -0
- torchrl/collectors/llm/weight_update/vllm.py +348 -0
- torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
- torchrl/collectors/utils.py +433 -0
- torchrl/collectors/weight_update.py +591 -0
- torchrl/csrc/numpy_utils.h +38 -0
- torchrl/csrc/pybind.cpp +27 -0
- torchrl/csrc/segment_tree.h +458 -0
- torchrl/csrc/torch_utils.h +34 -0
- torchrl/csrc/utils.cpp +48 -0
- torchrl/csrc/utils.h +31 -0
- torchrl/data/__init__.py +187 -0
- torchrl/data/datasets/__init__.py +58 -0
- torchrl/data/datasets/atari_dqn.py +878 -0
- torchrl/data/datasets/common.py +281 -0
- torchrl/data/datasets/d4rl.py +489 -0
- torchrl/data/datasets/d4rl_infos.py +187 -0
- torchrl/data/datasets/gen_dgrl.py +375 -0
- torchrl/data/datasets/minari_data.py +643 -0
- torchrl/data/datasets/openml.py +177 -0
- torchrl/data/datasets/openx.py +798 -0
- torchrl/data/datasets/roboset.py +363 -0
- torchrl/data/datasets/utils.py +11 -0
- torchrl/data/datasets/vd4rl.py +432 -0
- torchrl/data/llm/__init__.py +34 -0
- torchrl/data/llm/dataset.py +491 -0
- torchrl/data/llm/history.py +1378 -0
- torchrl/data/llm/prompt.py +198 -0
- torchrl/data/llm/reward.py +225 -0
- torchrl/data/llm/topk.py +186 -0
- torchrl/data/llm/utils.py +543 -0
- torchrl/data/map/__init__.py +21 -0
- torchrl/data/map/hash.py +185 -0
- torchrl/data/map/query.py +204 -0
- torchrl/data/map/tdstorage.py +363 -0
- torchrl/data/map/tree.py +1434 -0
- torchrl/data/map/utils.py +103 -0
- torchrl/data/postprocs/__init__.py +8 -0
- torchrl/data/postprocs/postprocs.py +391 -0
- torchrl/data/replay_buffers/__init__.py +99 -0
- torchrl/data/replay_buffers/checkpointers.py +622 -0
- torchrl/data/replay_buffers/ray_buffer.py +292 -0
- torchrl/data/replay_buffers/replay_buffers.py +2376 -0
- torchrl/data/replay_buffers/samplers.py +2578 -0
- torchrl/data/replay_buffers/scheduler.py +265 -0
- torchrl/data/replay_buffers/storages.py +2412 -0
- torchrl/data/replay_buffers/utils.py +1042 -0
- torchrl/data/replay_buffers/writers.py +781 -0
- torchrl/data/tensor_specs.py +7101 -0
- torchrl/data/utils.py +334 -0
- torchrl/envs/__init__.py +265 -0
- torchrl/envs/async_envs.py +1105 -0
- torchrl/envs/batched_envs.py +3093 -0
- torchrl/envs/common.py +4241 -0
- torchrl/envs/custom/__init__.py +11 -0
- torchrl/envs/custom/chess.py +617 -0
- torchrl/envs/custom/llm.py +214 -0
- torchrl/envs/custom/pendulum.py +401 -0
- torchrl/envs/custom/san_moves.txt +29274 -0
- torchrl/envs/custom/tictactoeenv.py +288 -0
- torchrl/envs/env_creator.py +263 -0
- torchrl/envs/gym_like.py +752 -0
- torchrl/envs/libs/__init__.py +68 -0
- torchrl/envs/libs/_gym_utils.py +326 -0
- torchrl/envs/libs/brax.py +846 -0
- torchrl/envs/libs/dm_control.py +544 -0
- torchrl/envs/libs/envpool.py +447 -0
- torchrl/envs/libs/gym.py +2239 -0
- torchrl/envs/libs/habitat.py +138 -0
- torchrl/envs/libs/isaac_lab.py +87 -0
- torchrl/envs/libs/isaacgym.py +203 -0
- torchrl/envs/libs/jax_utils.py +166 -0
- torchrl/envs/libs/jumanji.py +963 -0
- torchrl/envs/libs/meltingpot.py +599 -0
- torchrl/envs/libs/openml.py +153 -0
- torchrl/envs/libs/openspiel.py +652 -0
- torchrl/envs/libs/pettingzoo.py +1042 -0
- torchrl/envs/libs/procgen.py +351 -0
- torchrl/envs/libs/robohive.py +429 -0
- torchrl/envs/libs/smacv2.py +645 -0
- torchrl/envs/libs/unity_mlagents.py +891 -0
- torchrl/envs/libs/utils.py +147 -0
- torchrl/envs/libs/vmas.py +813 -0
- torchrl/envs/llm/__init__.py +63 -0
- torchrl/envs/llm/chat.py +730 -0
- torchrl/envs/llm/datasets/README.md +4 -0
- torchrl/envs/llm/datasets/__init__.py +17 -0
- torchrl/envs/llm/datasets/gsm8k.py +353 -0
- torchrl/envs/llm/datasets/ifeval.py +274 -0
- torchrl/envs/llm/envs.py +789 -0
- torchrl/envs/llm/libs/README.md +3 -0
- torchrl/envs/llm/libs/__init__.py +8 -0
- torchrl/envs/llm/libs/mlgym.py +869 -0
- torchrl/envs/llm/reward/__init__.py +10 -0
- torchrl/envs/llm/reward/gsm8k.py +324 -0
- torchrl/envs/llm/reward/ifeval/README.md +13 -0
- torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
- torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
- torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
- torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
- torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
- torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
- torchrl/envs/llm/transforms/__init__.py +55 -0
- torchrl/envs/llm/transforms/browser.py +292 -0
- torchrl/envs/llm/transforms/dataloading.py +859 -0
- torchrl/envs/llm/transforms/format.py +73 -0
- torchrl/envs/llm/transforms/kl.py +1544 -0
- torchrl/envs/llm/transforms/policy_version.py +189 -0
- torchrl/envs/llm/transforms/reason.py +323 -0
- torchrl/envs/llm/transforms/tokenizer.py +321 -0
- torchrl/envs/llm/transforms/tools.py +1955 -0
- torchrl/envs/model_based/__init__.py +9 -0
- torchrl/envs/model_based/common.py +180 -0
- torchrl/envs/model_based/dreamer.py +112 -0
- torchrl/envs/transforms/__init__.py +147 -0
- torchrl/envs/transforms/functional.py +48 -0
- torchrl/envs/transforms/gym_transforms.py +203 -0
- torchrl/envs/transforms/module.py +341 -0
- torchrl/envs/transforms/r3m.py +372 -0
- torchrl/envs/transforms/ray_service.py +663 -0
- torchrl/envs/transforms/rb_transforms.py +214 -0
- torchrl/envs/transforms/transforms.py +11835 -0
- torchrl/envs/transforms/utils.py +94 -0
- torchrl/envs/transforms/vc1.py +307 -0
- torchrl/envs/transforms/vecnorm.py +845 -0
- torchrl/envs/transforms/vip.py +407 -0
- torchrl/envs/utils.py +1718 -0
- torchrl/envs/vec_envs.py +11 -0
- torchrl/modules/__init__.py +206 -0
- torchrl/modules/distributions/__init__.py +73 -0
- torchrl/modules/distributions/continuous.py +830 -0
- torchrl/modules/distributions/discrete.py +908 -0
- torchrl/modules/distributions/truncated_normal.py +187 -0
- torchrl/modules/distributions/utils.py +233 -0
- torchrl/modules/llm/__init__.py +62 -0
- torchrl/modules/llm/backends/__init__.py +65 -0
- torchrl/modules/llm/backends/vllm/__init__.py +94 -0
- torchrl/modules/llm/backends/vllm/_models.py +46 -0
- torchrl/modules/llm/backends/vllm/base.py +72 -0
- torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
- torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
- torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
- torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
- torchrl/modules/llm/policies/__init__.py +28 -0
- torchrl/modules/llm/policies/common.py +1809 -0
- torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
- torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
- torchrl/modules/llm/utils.py +23 -0
- torchrl/modules/mcts/__init__.py +21 -0
- torchrl/modules/mcts/scores.py +579 -0
- torchrl/modules/models/__init__.py +86 -0
- torchrl/modules/models/batchrenorm.py +119 -0
- torchrl/modules/models/decision_transformer.py +179 -0
- torchrl/modules/models/exploration.py +731 -0
- torchrl/modules/models/llm.py +156 -0
- torchrl/modules/models/model_based.py +596 -0
- torchrl/modules/models/models.py +1712 -0
- torchrl/modules/models/multiagent.py +1067 -0
- torchrl/modules/models/recipes/impala.py +185 -0
- torchrl/modules/models/utils.py +162 -0
- torchrl/modules/planners/__init__.py +10 -0
- torchrl/modules/planners/cem.py +228 -0
- torchrl/modules/planners/common.py +73 -0
- torchrl/modules/planners/mppi.py +265 -0
- torchrl/modules/tensordict_module/__init__.py +89 -0
- torchrl/modules/tensordict_module/actors.py +2457 -0
- torchrl/modules/tensordict_module/common.py +529 -0
- torchrl/modules/tensordict_module/exploration.py +814 -0
- torchrl/modules/tensordict_module/probabilistic.py +321 -0
- torchrl/modules/tensordict_module/rnn.py +1639 -0
- torchrl/modules/tensordict_module/sequence.py +132 -0
- torchrl/modules/tensordict_module/world_models.py +34 -0
- torchrl/modules/utils/__init__.py +38 -0
- torchrl/modules/utils/mappings.py +9 -0
- torchrl/modules/utils/utils.py +89 -0
- torchrl/objectives/__init__.py +78 -0
- torchrl/objectives/a2c.py +659 -0
- torchrl/objectives/common.py +753 -0
- torchrl/objectives/cql.py +1346 -0
- torchrl/objectives/crossq.py +710 -0
- torchrl/objectives/ddpg.py +453 -0
- torchrl/objectives/decision_transformer.py +371 -0
- torchrl/objectives/deprecated.py +516 -0
- torchrl/objectives/dqn.py +683 -0
- torchrl/objectives/dreamer.py +488 -0
- torchrl/objectives/functional.py +48 -0
- torchrl/objectives/gail.py +258 -0
- torchrl/objectives/iql.py +996 -0
- torchrl/objectives/llm/__init__.py +30 -0
- torchrl/objectives/llm/grpo.py +846 -0
- torchrl/objectives/llm/sft.py +482 -0
- torchrl/objectives/multiagent/__init__.py +8 -0
- torchrl/objectives/multiagent/qmixer.py +396 -0
- torchrl/objectives/ppo.py +1669 -0
- torchrl/objectives/redq.py +683 -0
- torchrl/objectives/reinforce.py +530 -0
- torchrl/objectives/sac.py +1580 -0
- torchrl/objectives/td3.py +570 -0
- torchrl/objectives/td3_bc.py +625 -0
- torchrl/objectives/utils.py +782 -0
- torchrl/objectives/value/__init__.py +28 -0
- torchrl/objectives/value/advantages.py +1956 -0
- torchrl/objectives/value/functional.py +1459 -0
- torchrl/objectives/value/utils.py +360 -0
- torchrl/record/__init__.py +17 -0
- torchrl/record/loggers/__init__.py +23 -0
- torchrl/record/loggers/common.py +48 -0
- torchrl/record/loggers/csv.py +226 -0
- torchrl/record/loggers/mlflow.py +142 -0
- torchrl/record/loggers/tensorboard.py +139 -0
- torchrl/record/loggers/trackio.py +163 -0
- torchrl/record/loggers/utils.py +78 -0
- torchrl/record/loggers/wandb.py +214 -0
- torchrl/record/recorder.py +554 -0
- torchrl/services/__init__.py +79 -0
- torchrl/services/base.py +109 -0
- torchrl/services/ray_service.py +453 -0
- torchrl/testing/__init__.py +107 -0
- torchrl/testing/assertions.py +179 -0
- torchrl/testing/dist_utils.py +122 -0
- torchrl/testing/env_creators.py +227 -0
- torchrl/testing/env_helper.py +35 -0
- torchrl/testing/gym_helpers.py +156 -0
- torchrl/testing/llm_mocks.py +119 -0
- torchrl/testing/mocking_classes.py +2720 -0
- torchrl/testing/modules.py +295 -0
- torchrl/testing/mp_helpers.py +15 -0
- torchrl/testing/ray_helpers.py +293 -0
- torchrl/testing/utils.py +190 -0
- torchrl/trainers/__init__.py +42 -0
- torchrl/trainers/algorithms/__init__.py +11 -0
- torchrl/trainers/algorithms/configs/__init__.py +705 -0
- torchrl/trainers/algorithms/configs/collectors.py +216 -0
- torchrl/trainers/algorithms/configs/common.py +41 -0
- torchrl/trainers/algorithms/configs/data.py +308 -0
- torchrl/trainers/algorithms/configs/envs.py +104 -0
- torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
- torchrl/trainers/algorithms/configs/logging.py +80 -0
- torchrl/trainers/algorithms/configs/modules.py +570 -0
- torchrl/trainers/algorithms/configs/objectives.py +177 -0
- torchrl/trainers/algorithms/configs/trainers.py +340 -0
- torchrl/trainers/algorithms/configs/transforms.py +955 -0
- torchrl/trainers/algorithms/configs/utils.py +252 -0
- torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
- torchrl/trainers/algorithms/configs/weight_update.py +159 -0
- torchrl/trainers/algorithms/ppo.py +373 -0
- torchrl/trainers/algorithms/sac.py +308 -0
- torchrl/trainers/helpers/__init__.py +40 -0
- torchrl/trainers/helpers/collectors.py +416 -0
- torchrl/trainers/helpers/envs.py +573 -0
- torchrl/trainers/helpers/logger.py +33 -0
- torchrl/trainers/helpers/losses.py +132 -0
- torchrl/trainers/helpers/models.py +658 -0
- torchrl/trainers/helpers/replay_buffer.py +59 -0
- torchrl/trainers/helpers/trainers.py +301 -0
- torchrl/trainers/trainers.py +2052 -0
- torchrl/weight_update/__init__.py +33 -0
- torchrl/weight_update/_distributed.py +749 -0
- torchrl/weight_update/_mp.py +624 -0
- torchrl/weight_update/_noupdate.py +102 -0
- torchrl/weight_update/_ray.py +1032 -0
- torchrl/weight_update/_rpc.py +284 -0
- torchrl/weight_update/_shared.py +891 -0
- torchrl/weight_update/llm/__init__.py +32 -0
- torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
- torchrl/weight_update/llm/vllm_nccl.py +710 -0
- torchrl/weight_update/utils.py +73 -0
- torchrl/weight_update/weight_sync_schemes.py +1244 -0
- torchrl-0.11.0.dist-info/METADATA +1308 -0
- torchrl-0.11.0.dist-info/RECORD +395 -0
- torchrl-0.11.0.dist-info/WHEEL +5 -0
- torchrl-0.11.0.dist-info/entry_points.txt +2 -0
- torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
- torchrl-0.11.0.dist-info/top_level.txt +7 -0
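The hunk that follows reproduces torchrl/data/replay_buffers/replay_buffers.py (+2376 lines). For orientation, here is a minimal usage sketch of the ReplayBuffer class defined in that file. It is illustrative only (not part of the wheel contents) and closely mirrors the class docstring examples; the identity transform passed to transform_factory is a placeholder assumption.

    import torch
    from torchrl.data import ReplayBuffer, ListStorage

    torch.manual_seed(0)
    rb = ReplayBuffer(storage=ListStorage(max_size=1000), batch_size=5)
    indices = rb.extend(range(10))       # write ten items; returns their indices
    sample = rb.sample()                 # 5 elements, as set by batch_size
    sample3 = rb.sample(batch_size=3)    # per-call override of the batch size

    # transform_factory defers construction of the transform (and, by default,
    # of the other components) until first use, keeping the buffer picklable
    # for remote workers.
    rb_lazy = ReplayBuffer(
        storage=lambda: ListStorage(max_size=1000),
        transform_factory=lambda: (lambda x: x),  # placeholder identity transform (assumption)
        batch_size=5,
    )
    assert not rb_lazy.initialized       # nothing built yet (delayed_init auto-set to True)
    _ = len(rb_lazy)                     # first use triggers initialization
    assert rb_lazy.initialized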
|
@@ -0,0 +1,2376 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
#
|
|
3
|
+
# This source code is licensed under the MIT license found in the
|
|
4
|
+
# LICENSE file in the root directory of this source tree.
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
import collections
|
|
8
|
+
import contextlib
|
|
9
|
+
import json
|
|
10
|
+
import multiprocessing
|
|
11
|
+
import textwrap
|
|
12
|
+
import threading
|
|
13
|
+
import warnings
|
|
14
|
+
from collections.abc import Callable, Sequence
|
|
15
|
+
from concurrent.futures import ThreadPoolExecutor
|
|
16
|
+
from pathlib import Path
|
|
17
|
+
from typing import Any
|
|
18
|
+
|
|
19
|
+
import numpy as np
|
|
20
|
+
import torch
|
|
21
|
+
|
|
22
|
+
try:
|
|
23
|
+
from torch.compiler import is_compiling
|
|
24
|
+
except ImportError:
|
|
25
|
+
from torch._dynamo import is_compiling
|
|
26
|
+
|
|
27
|
+
from functools import partial, wraps
|
|
28
|
+
from typing import TYPE_CHECKING, TypeVar
|
|
29
|
+
|
|
30
|
+
from tensordict import (
|
|
31
|
+
is_tensor_collection,
|
|
32
|
+
is_tensorclass,
|
|
33
|
+
LazyStackedTensorDict,
|
|
34
|
+
NestedKey,
|
|
35
|
+
TensorDict,
|
|
36
|
+
TensorDictBase,
|
|
37
|
+
unravel_key,
|
|
38
|
+
)
|
|
39
|
+
from tensordict.nn.utils import _set_dispatch_td_nn_modules
|
|
40
|
+
from tensordict.utils import expand_as_right, expand_right
|
|
41
|
+
from torch import Tensor
|
|
42
|
+
from torch.utils._pytree import tree_map
|
|
43
|
+
|
|
44
|
+
from torchrl._utils import accept_remote_rref_udf_invocation, rl_warnings
|
|
45
|
+
from torchrl.data.replay_buffers.samplers import (
|
|
46
|
+
PrioritizedSampler,
|
|
47
|
+
RandomSampler,
|
|
48
|
+
Sampler,
|
|
49
|
+
SamplerEnsemble,
|
|
50
|
+
)
|
|
51
|
+
from torchrl.data.replay_buffers.storages import (
|
|
52
|
+
_get_default_collate,
|
|
53
|
+
_stack_anything,
|
|
54
|
+
ListStorage,
|
|
55
|
+
Storage,
|
|
56
|
+
StorageEnsemble,
|
|
57
|
+
)
|
|
58
|
+
from torchrl.data.replay_buffers.utils import (
|
|
59
|
+
_is_int,
|
|
60
|
+
_reduce,
|
|
61
|
+
_to_numpy,
|
|
62
|
+
_to_torch,
|
|
63
|
+
INT_CLASSES,
|
|
64
|
+
pin_memory_output,
|
|
65
|
+
)
|
|
66
|
+
from torchrl.data.replay_buffers.writers import (
|
|
67
|
+
RoundRobinWriter,
|
|
68
|
+
TensorDictRoundRobinWriter,
|
|
69
|
+
Writer,
|
|
70
|
+
WriterEnsemble,
|
|
71
|
+
)
|
|
72
|
+
from torchrl.data.utils import DEVICE_TYPING
|
|
73
|
+
from torchrl.envs.transforms.transforms import _InvertTransform, Transform
|
|
74
|
+
|
|
75
|
+
T = TypeVar("T")
|
|
76
|
+
if TYPE_CHECKING:
|
|
77
|
+
from typing import Self
|
|
78
|
+
else:
|
|
79
|
+
Self = T
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
def _maybe_delay_init(func):
|
|
83
|
+
@wraps(func)
|
|
84
|
+
def wrapper(self, *args, **kwargs):
|
|
85
|
+
if self._delayed_init and not self.initialized:
|
|
86
|
+
self._init()
|
|
87
|
+
return func(self, *args, **kwargs)
|
|
88
|
+
|
|
89
|
+
return wrapper
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
class ReplayBuffer:
|
|
93
|
+
"""A generic, composable replay buffer class.
|
|
94
|
+
|
|
95
|
+
Keyword Args:
|
|
96
|
+
storage (Storage, Callable[[], Storage], optional): the storage to be used.
|
|
97
|
+
If a callable is passed, it is used as constructor for the storage.
|
|
98
|
+
If none is provided a default :class:`~torchrl.data.replay_buffers.ListStorage` with
|
|
99
|
+
``max_size`` of ``1_000`` will be created.
|
|
100
|
+
sampler (Sampler, Callable[[], Sampler], optional): the sampler to be used.
|
|
101
|
+
If a callable is passed, it is used as constructor for the sampler.
|
|
102
|
+
If none is provided, a default :class:`~torchrl.data.replay_buffers.RandomSampler`
|
|
103
|
+
will be used.
|
|
104
|
+
writer (Writer, Callable[[], Writer], optional): the writer to be used.
|
|
105
|
+
If a callable is passed, it is used as constructor for the writer.
|
|
106
|
+
If none is provided a default :class:`~torchrl.data.replay_buffers.RoundRobinWriter`
|
|
107
|
+
will be used.
|
|
108
|
+
collate_fn (callable, optional): merges a list of samples to form a
|
|
109
|
+
mini-batch of Tensor(s)/outputs. Used when using batched
|
|
110
|
+
loading from a map-style dataset. The default value will be decided
|
|
111
|
+
based on the storage type.
|
|
112
|
+
pin_memory (bool): whether pin_memory() should be called on the rb
|
|
113
|
+
samples.
|
|
114
|
+
prefetch (int, optional): number of next batches to be prefetched
|
|
115
|
+
using multithreading. Defaults to None (no prefetching).
|
|
116
|
+
transform (Transform or Callable[[Any], Any], optional): Transform to be executed when
|
|
117
|
+
:meth:`sample` is called.
|
|
118
|
+
To chain transforms use the :class:`~torchrl.envs.Compose` class.
|
|
119
|
+
Transforms should be used with :class:`tensordict.TensorDict`
|
|
120
|
+
content. A generic callable can also be passed if the replay buffer
|
|
121
|
+
is used with PyTree structures (see example below).
|
|
122
|
+
Unlike storages, writers and samplers, transform constructors must
|
|
123
|
+
be passed as separate keyword argument :attr:`transform_factory`,
|
|
124
|
+
as it is impossible to distinguish a constructor from a transform.
|
|
125
|
+
transform_factory (Callable[[], Callable], optional): a factory for the
|
|
126
|
+
transform. Exclusive with :attr:`transform`.
|
|
127
|
+
batch_size (int, optional): the batch size to be used when sample() is
|
|
128
|
+
called.
|
|
129
|
+
|
|
130
|
+
.. note::
|
|
131
|
+
The batch-size can be specified at construction time via the
|
|
132
|
+
``batch_size`` argument, or at sampling time. The former should
|
|
133
|
+
be preferred whenever the batch-size is consistent across the
|
|
134
|
+
experiment. If the batch-size is likely to change, it can be
|
|
135
|
+
passed to the :meth:`sample` method. This option is
|
|
136
|
+
incompatible with prefetching (since this requires to know the
|
|
137
|
+
batch-size in advance) as well as with samplers that have a
|
|
138
|
+
``drop_last`` argument.
|
|
139
|
+
|
|
140
|
+
dim_extend (int, optional): indicates the dim to consider for
|
|
141
|
+
extension when calling :meth:`extend`. Defaults to ``storage.ndim-1``.
|
|
142
|
+
When using ``dim_extend > 0``, we recommend using the ``ndim``
|
|
143
|
+
argument in the storage instantiation if that argument is
|
|
144
|
+
available, to let storages know that the data is
|
|
145
|
+
multi-dimensional and keep consistent notions of storage-capacity
|
|
146
|
+
and batch-size during sampling.
|
|
147
|
+
|
|
148
|
+
.. note:: This argument has no effect on :meth:`add` and
|
|
149
|
+
therefore should be used with caution when both :meth:`add`
|
|
150
|
+
and :meth:`extend` are used in a codebase. For example:
|
|
151
|
+
|
|
152
|
+
>>> data = torch.zeros(3, 4)
|
|
153
|
+
>>> rb = ReplayBuffer(
|
|
154
|
+
... storage=LazyTensorStorage(10, ndim=2),
|
|
155
|
+
... dim_extend=1)
|
|
156
|
+
>>> # these two approaches are equivalent:
|
|
157
|
+
>>> for d in data.unbind(1):
|
|
158
|
+
... rb.add(d)
|
|
159
|
+
>>> rb.extend(data)
|
|
160
|
+
|
|
161
|
+
generator (torch.Generator, optional): a generator to use for sampling.
|
|
162
|
+
Using a dedicated generator for the replay buffer can allow a fine-grained control
|
|
163
|
+
over seeding, for instance keeping the global seed different but the RB seed identical
|
|
164
|
+
for distributed jobs.
|
|
165
|
+
Defaults to ``None`` (global default generator).
|
|
166
|
+
|
|
167
|
+
.. warning:: As of now, the generator has no effect on the transforms.
|
|
168
|
+
shared (bool, optional): whether the buffer will be shared using multiprocessing or not.
|
|
169
|
+
Defaults to ``False``.
|
|
170
|
+
compilable (bool, optional): whether the writer is compilable.
|
|
171
|
+
If ``True``, the writer cannot be shared between multiple processes.
|
|
172
|
+
Defaults to ``False``.
|
|
173
|
+
delayed_init (bool, optional): whether to initialize storage, writer, sampler and transform
|
|
174
|
+
the first time the buffer is used rather than during construction.
|
|
175
|
+
This is useful when the replay buffer needs to be pickled and sent to remote workers,
|
|
176
|
+
particularly when using transforms with modules that require gradients.
|
|
177
|
+
If not specified, defaults to ``True`` when ``transform_factory`` is provided,
|
|
178
|
+
and ``False`` otherwise.
|
|
179
|
+
|
|
180
|
+
Examples:
|
|
181
|
+
>>> import torch
|
|
182
|
+
>>>
|
|
183
|
+
>>> from torchrl.data import ReplayBuffer, ListStorage
|
|
184
|
+
>>>
|
|
185
|
+
>>> torch.manual_seed(0)
|
|
186
|
+
>>> rb = ReplayBuffer(
|
|
187
|
+
... storage=ListStorage(max_size=1000),
|
|
188
|
+
... batch_size=5,
|
|
189
|
+
... )
|
|
190
|
+
>>> # populate the replay buffer and get the item indices
|
|
191
|
+
>>> data = range(10)
|
|
192
|
+
>>> indices = rb.extend(data)
|
|
193
|
+
>>> # sample will return as many elements as specified in the constructor
|
|
194
|
+
>>> sample = rb.sample()
|
|
195
|
+
>>> print(sample)
|
|
196
|
+
tensor([4, 9, 3, 0, 3])
|
|
197
|
+
>>> # Passing the batch-size to the sample method overrides the one in the constructor
|
|
198
|
+
>>> sample = rb.sample(batch_size=3)
|
|
199
|
+
>>> print(sample)
|
|
200
|
+
tensor([9, 7, 3])
|
|
201
|
+
>>> # one cans sample using the ``sample`` method or iterate over the buffer
|
|
202
|
+
>>> for i, batch in enumerate(rb):
|
|
203
|
+
... print(i, batch)
|
|
204
|
+
... if i == 3:
|
|
205
|
+
... break
|
|
206
|
+
0 tensor([7, 3, 1, 6, 6])
|
|
207
|
+
1 tensor([9, 8, 6, 6, 8])
|
|
208
|
+
2 tensor([4, 3, 6, 9, 1])
|
|
209
|
+
3 tensor([4, 4, 1, 9, 9])
|
|
210
|
+
|
|
211
|
+
Replay buffers accept *any* kind of data. Not all storage types
|
|
212
|
+
will work, as some expect numerical data only, but the default
|
|
213
|
+
:class:`~torchrl.data.ListStorage` will:
|
|
214
|
+
|
|
215
|
+
Examples:
|
|
216
|
+
>>> torch.manual_seed(0)
|
|
217
|
+
>>> buffer = ReplayBuffer(storage=ListStorage(100), collate_fn=lambda x: x)
|
|
218
|
+
>>> indices = buffer.extend(["a", 1, None])
|
|
219
|
+
>>> buffer.sample(3)
|
|
220
|
+
[None, 'a', None]
|
|
221
|
+
|
|
222
|
+
The :class:`~torchrl.data.replay_buffers.TensorStorage`, :class:`~torchrl.data.replay_buffers.LazyMemmapStorage`
|
|
223
|
+
and :class:`~torchrl.data.replay_buffers.LazyTensorStorage` also work
|
|
224
|
+
with any PyTree structure (a PyTree is a nested structure of arbitrary depth made of dicts,
|
|
225
|
+
lists or tuples where the leaves are tensors) provided that it only contains
|
|
226
|
+
tensor data.
|
|
227
|
+
|
|
228
|
+
Examples:
|
|
229
|
+
>>> from torch.utils._pytree import tree_map
|
|
230
|
+
>>> def transform(x):
|
|
231
|
+
... # Zeros all the data in the pytree
|
|
232
|
+
... return tree_map(lambda y: y * 0, x)
|
|
233
|
+
>>> rb = ReplayBuffer(storage=LazyMemmapStorage(100), transform=transform)
|
|
234
|
+
>>> data = {
|
|
235
|
+
... "a": torch.randn(3),
|
|
236
|
+
... "b": {"c": (torch.zeros(2), [torch.ones(1)])},
|
|
237
|
+
... 30: -torch.ones(()),
|
|
238
|
+
... }
|
|
239
|
+
>>> rb.add(data)
|
|
240
|
+
>>> # The sample has a similar structure to the data (with a leading dimension of 10 for each tensor)
|
|
241
|
+
>>> s = rb.sample(10)
|
|
242
|
+
>>> # let's check that our transform did its job:
|
|
243
|
+
>>> def assert0(x):
|
|
244
|
+
>>> assert (x == 0).all()
|
|
245
|
+
>>> tree_map(assert0, s)
|
|
246
|
+
|
|
247
|
+
"""
|
|
248
|
+
|
|
249
|
+
def __init__(
|
|
250
|
+
self,
|
|
251
|
+
*,
|
|
252
|
+
storage: Storage | Callable[[], Storage] | None = None,
|
|
253
|
+
sampler: Sampler | Callable[[], Sampler] | None = None,
|
|
254
|
+
writer: Writer | Callable[[], Writer] | None = None,
|
|
255
|
+
collate_fn: Callable | None = None,
|
|
256
|
+
pin_memory: bool = False,
|
|
257
|
+
prefetch: int | None = None,
|
|
258
|
+
transform: Transform | Callable | None = None, # noqa-F821
|
|
259
|
+
transform_factory: Callable[[], Transform | Callable]
|
|
260
|
+
| None = None, # noqa-F821
|
|
261
|
+
batch_size: int | None = None,
|
|
262
|
+
dim_extend: int | None = None,
|
|
263
|
+
checkpointer: StorageCheckpointerBase # noqa: F821
|
|
264
|
+
| Callable[[], StorageCheckpointerBase] # noqa: F821
|
|
265
|
+
| None = None, # noqa: F821
|
|
266
|
+
generator: torch.Generator | None = None,
|
|
267
|
+
shared: bool = False,
|
|
268
|
+
compilable: bool | None = None,
|
|
269
|
+
delayed_init: bool | None = None,
|
|
270
|
+
) -> None:
|
|
271
|
+
self._delayed_init = delayed_init
|
|
272
|
+
self._initialized = False
|
|
273
|
+
|
|
274
|
+
# Store init parameters for potential delayed initialization
|
|
275
|
+
self._init_storage = storage
|
|
276
|
+
self._init_sampler = sampler
|
|
277
|
+
self._init_writer = writer
|
|
278
|
+
self._init_collate_fn = collate_fn
|
|
279
|
+
self._init_transform = transform
|
|
280
|
+
self._init_transform_factory = transform_factory
|
|
281
|
+
self._init_checkpointer = checkpointer
|
|
282
|
+
self._init_generator = generator
|
|
283
|
+
self._init_compilable = compilable
|
|
284
|
+
|
|
285
|
+
if transform is not None and transform_factory is not None:
|
|
286
|
+
raise TypeError(
|
|
287
|
+
f"transform and transform_factory are mutually exclusive. "
|
|
288
|
+
f"Got transform={transform} and transform_factory={transform_factory}."
|
|
289
|
+
)
|
|
290
|
+
|
|
291
|
+
# Auto-detect delayed_init when transform_factory is provided
|
|
292
|
+
if transform_factory is not None and delayed_init is None:
|
|
293
|
+
delayed_init = True
|
|
294
|
+
elif delayed_init is None:
|
|
295
|
+
delayed_init = False
|
|
296
|
+
|
|
297
|
+
# Update _delayed_init after auto-detection
|
|
298
|
+
self._delayed_init = delayed_init
|
|
299
|
+
|
|
300
|
+
self._pin_memory = pin_memory
|
|
301
|
+
self._prefetch = bool(prefetch)
|
|
302
|
+
self._prefetch_cap = prefetch or 0
|
|
303
|
+
self._prefetch_queue = collections.deque()
|
|
304
|
+
self._batch_size = batch_size
|
|
305
|
+
|
|
306
|
+
if batch_size is None and prefetch:
|
|
307
|
+
raise ValueError(
|
|
308
|
+
"Dynamic batch-size specification is incompatible "
|
|
309
|
+
"with multithreaded sampling. "
|
|
310
|
+
"When using prefetch, the batch-size must be specified in "
|
|
311
|
+
"advance. "
|
|
312
|
+
)
|
|
313
|
+
|
|
314
|
+
if dim_extend is not None and dim_extend < 0:
|
|
315
|
+
raise ValueError("dim_extend must be a positive value.")
|
|
316
|
+
self._dim_extend = dim_extend
|
|
317
|
+
|
|
318
|
+
if self._prefetch_cap:
|
|
319
|
+
self._prefetch_executor = ThreadPoolExecutor(max_workers=self._prefetch_cap)
|
|
320
|
+
|
|
321
|
+
if shared and prefetch:
|
|
322
|
+
raise ValueError("Cannot share prefetched replay buffers.")
|
|
323
|
+
self.shared = shared
|
|
324
|
+
self.share(self.shared)
|
|
325
|
+
|
|
326
|
+
self._replay_lock = threading.RLock()
|
|
327
|
+
self._futures_lock = threading.RLock()
|
|
328
|
+
|
|
329
|
+
# If not delayed, initialize immediately
|
|
330
|
+
if not self._delayed_init:
|
|
331
|
+
self._init()
|
|
332
|
+
|
|
333
|
+
def _init(self) -> None:
|
|
334
|
+
"""Initialize the replay buffer components.
|
|
335
|
+
|
|
336
|
+
This method is called either immediately during __init__ (if delayed_init=False)
|
|
337
|
+
or on first use of the buffer (if delayed_init=True).
|
|
338
|
+
"""
|
|
339
|
+
if self._initialized:
|
|
340
|
+
return
|
|
341
|
+
|
|
342
|
+
self._initialized = True
|
|
343
|
+
try:
|
|
344
|
+
# Initialize storage
|
|
345
|
+
self._storage = self._maybe_make_storage(
|
|
346
|
+
self._init_storage, compilable=self._init_compilable
|
|
347
|
+
)
|
|
348
|
+
self._storage.attach(self)
|
|
349
|
+
|
|
350
|
+
# Initialize sampler
|
|
351
|
+
self._sampler = self._maybe_make_sampler(self._init_sampler)
|
|
352
|
+
|
|
353
|
+
# Initialize writer
|
|
354
|
+
self._writer = self._maybe_make_writer(self._init_writer)
|
|
355
|
+
self._writer.register_storage(self._storage)
|
|
356
|
+
|
|
357
|
+
# Initialize collate function
|
|
358
|
+
self._get_collate_fn(self._init_collate_fn)
|
|
359
|
+
|
|
360
|
+
# Initialize transform
|
|
361
|
+
self._transform = self._maybe_make_transform(
|
|
362
|
+
self._init_transform, self._init_transform_factory
|
|
363
|
+
)
|
|
364
|
+
|
|
365
|
+
# Check batch_size compatibility with sampler
|
|
366
|
+
if (
|
|
367
|
+
self._batch_size is None
|
|
368
|
+
and hasattr(self._sampler, "drop_last")
|
|
369
|
+
and self._sampler.drop_last
|
|
370
|
+
):
|
|
371
|
+
raise ValueError(
|
|
372
|
+
"Samplers with drop_last=True must work with a predictable batch-size. "
|
|
373
|
+
"Please pass the batch-size to the ReplayBuffer constructor."
|
|
374
|
+
)
|
|
375
|
+
|
|
376
|
+
# Set dim_extend properly now that storage is initialized
|
|
377
|
+
if self._dim_extend is None:
|
|
378
|
+
if self._storage is not None:
|
|
379
|
+
ndim = self._storage.ndim
|
|
380
|
+
self._dim_extend = ndim - 1
|
|
381
|
+
else:
|
|
382
|
+
self._dim_extend = 1
|
|
383
|
+
|
|
384
|
+
# Set checkpointer and generator
|
|
385
|
+
self._storage.checkpointer = self._init_checkpointer
|
|
386
|
+
self.set_rng(generator=self._init_generator)
|
|
387
|
+
|
|
388
|
+
# Initialize prioritized sampler if needed
|
|
389
|
+
self._initialize_prioritized_sampler()
|
|
390
|
+
|
|
391
|
+
# Remove init parameters
|
|
392
|
+
self._init_storage = None
|
|
393
|
+
self._init_sampler = None
|
|
394
|
+
self._init_writer = None
|
|
395
|
+
self._init_collate_fn = None
|
|
396
|
+
self._init_transform = None
|
|
397
|
+
self._init_transform_factory = None
|
|
398
|
+
self._init_checkpointer = None
|
|
399
|
+
self._init_generator = None
|
|
400
|
+
self._init_compilable = None
|
|
401
|
+
except Exception as e:
|
|
402
|
+
self._initialized = False
|
|
403
|
+
raise e
|
|
404
|
+
|
|
405
|
+
@property
|
|
406
|
+
def initialized(self) -> bool:
|
|
407
|
+
"""Whether the replay buffer has been initialized."""
|
|
408
|
+
return self._initialized
|
|
409
|
+
|
|
410
|
+
def _initialize_prioritized_sampler(self) -> None:
|
|
411
|
+
"""Initialize priority trees for existing data when using PrioritizedSampler.
|
|
412
|
+
|
|
413
|
+
This method ensures that when a PrioritizedSampler is used with storage that
|
|
414
|
+
already contains data, the priority trees are properly populated with default
|
|
415
|
+
priorities for all existing entries.
|
|
416
|
+
"""
|
|
417
|
+
from .samplers import PrioritizedSampler
|
|
418
|
+
|
|
419
|
+
if isinstance(self._sampler, PrioritizedSampler) and len(self._storage) > 0:
|
|
420
|
+
# Set default priorities for all existing data
|
|
421
|
+
indices = torch.arange(len(self._storage), dtype=torch.long)
|
|
422
|
+
default_priorities = torch.full(
|
|
423
|
+
(len(self._storage),), self._sampler.default_priority, dtype=torch.float
|
|
424
|
+
)
|
|
425
|
+
self._sampler.update_priority(indices, default_priorities)
|
|
426
|
+
|
|
427
|
+
def _maybe_make_storage(
|
|
428
|
+
self, storage: Storage | Callable[[], Storage] | None, compilable
|
|
429
|
+
) -> Storage:
|
|
430
|
+
if storage is None:
|
|
431
|
+
return ListStorage(max_size=1_000, compilable=compilable)
|
|
432
|
+
elif isinstance(storage, Storage):
|
|
433
|
+
return storage
|
|
434
|
+
elif callable(storage):
|
|
435
|
+
storage = storage()
|
|
436
|
+
if not isinstance(storage, Storage):
|
|
437
|
+
raise TypeError(
|
|
438
|
+
"storage must be either a Storage or a callable returning a storage instance."
|
|
439
|
+
)
|
|
440
|
+
return storage
|
|
441
|
+
|
|
442
|
+
def _maybe_make_sampler(
|
|
443
|
+
self, sampler: Sampler | Callable[[], Sampler] | None
|
|
444
|
+
) -> Sampler:
|
|
445
|
+
if sampler is None:
|
|
446
|
+
return RandomSampler()
|
|
447
|
+
elif isinstance(sampler, Sampler):
|
|
448
|
+
return sampler
|
|
449
|
+
elif callable(sampler):
|
|
450
|
+
sampler = sampler()
|
|
451
|
+
if not isinstance(sampler, Sampler):
|
|
452
|
+
raise TypeError(
|
|
453
|
+
"sampler must be either a Sampler or a callable returning a sampler instance."
|
|
454
|
+
)
|
|
455
|
+
return sampler
|
|
456
|
+
|
|
457
|
+
def _maybe_make_writer(
|
|
458
|
+
self, writer: Writer | Callable[[], Writer] | None
|
|
459
|
+
) -> Writer:
|
|
460
|
+
if writer is None:
|
|
461
|
+
return RoundRobinWriter()
|
|
462
|
+
elif isinstance(writer, Writer):
|
|
463
|
+
return writer
|
|
464
|
+
elif callable(writer):
|
|
465
|
+
writer = writer()
|
|
466
|
+
if not isinstance(writer, Writer):
|
|
467
|
+
raise TypeError(
|
|
468
|
+
"writer must be either a Writer or a callable returning a writer instance."
|
|
469
|
+
)
|
|
470
|
+
return writer
|
|
471
|
+
|
|
472
|
+
def _maybe_make_transform(
|
|
473
|
+
self,
|
|
474
|
+
transform: Transform | Callable[[], Transform] | None,
|
|
475
|
+
transform_factory: Callable | None,
|
|
476
|
+
) -> Transform:
|
|
477
|
+
from torchrl.envs.transforms.transforms import (
|
|
478
|
+
_CallableTransform,
|
|
479
|
+
Compose,
|
|
480
|
+
Transform,
|
|
481
|
+
)
|
|
482
|
+
|
|
483
|
+
if transform_factory is not None:
|
|
484
|
+
if transform is not None:
|
|
485
|
+
raise TypeError(
|
|
486
|
+
"transform and transform_factory cannot be used simultaneously"
|
|
487
|
+
)
|
|
488
|
+
transform = transform_factory()
|
|
489
|
+
if transform is None:
|
|
490
|
+
transform = Compose()
|
|
491
|
+
elif not isinstance(transform, Compose):
|
|
492
|
+
if not isinstance(transform, Transform) and callable(transform):
|
|
493
|
+
transform = _CallableTransform(transform)
|
|
494
|
+
elif not isinstance(transform, Transform):
|
|
495
|
+
raise RuntimeError(
|
|
496
|
+
"transform must be either a Transform instance or a callable."
|
|
497
|
+
)
|
|
498
|
+
transform = Compose(transform)
|
|
499
|
+
transform.eval()
|
|
500
|
+
return transform
|
|
501
|
+
|
|
502
|
+
def share(self, shared: bool = True) -> Self:
|
|
503
|
+
self.shared = shared
|
|
504
|
+
if self.shared:
|
|
505
|
+
self._write_lock = multiprocessing.Lock()
|
|
506
|
+
else:
|
|
507
|
+
self._write_lock = contextlib.nullcontext()
|
|
508
|
+
return self
|
|
509
|
+
|
|
510
|
+
@_maybe_delay_init
|
|
511
|
+
def set_rng(self, generator) -> None:
|
|
512
|
+
self._rng = generator
|
|
513
|
+
self._storage._rng = generator
|
|
514
|
+
self._sampler._rng = generator
|
|
515
|
+
self._writer._rng = generator
|
|
516
|
+
|
|
517
|
+
@property
|
|
518
|
+
def dim_extend(self):
|
|
519
|
+
return self._dim_extend
|
|
520
|
+
|
|
521
|
+
@property
|
|
522
|
+
def batch_size(self):
|
|
523
|
+
"""The batch size of the replay buffer.
|
|
524
|
+
|
|
525
|
+
The batch size can be overriden by setting the `batch_size` parameter in the :meth:`sample` method.
|
|
526
|
+
|
|
527
|
+
It defines both the number of samples returned by :meth:`sample` and the number of samples that are
|
|
528
|
+
yielded by the :class:`ReplayBuffer` iterator.
|
|
529
|
+
"""
|
|
530
|
+
return self._batch_size
|
|
531
|
+
|
|
532
|
+
@dim_extend.setter
|
|
533
|
+
def dim_extend(self, value):
|
|
534
|
+
if (
|
|
535
|
+
hasattr(self, "_dim_extend")
|
|
536
|
+
and self._dim_extend is not None
|
|
537
|
+
and self._dim_extend != value
|
|
538
|
+
):
|
|
539
|
+
raise RuntimeError(
|
|
540
|
+
"dim_extend cannot be reset. Please create a new replay buffer."
|
|
541
|
+
)
|
|
542
|
+
|
|
543
|
+
if value is None:
|
|
544
|
+
if self._initialized and self._storage is not None:
|
|
545
|
+
ndim = self._storage.ndim
|
|
546
|
+
value = ndim - 1
|
|
547
|
+
else:
|
|
548
|
+
value = 1
|
|
549
|
+
|
|
550
|
+
self._dim_extend = value
|
|
551
|
+
|
|
552
|
+
def _transpose(self, data):
|
|
553
|
+
if is_tensor_collection(data):
|
|
554
|
+
return data.transpose(self.dim_extend, 0)
|
|
555
|
+
return tree_map(lambda x: x.transpose(self.dim_extend, 0), data)
|
|
556
|
+
|
|
557
|
+
def _get_collate_fn(self, collate_fn):
|
|
558
|
+
self._collate_fn = (
|
|
559
|
+
collate_fn
|
|
560
|
+
if collate_fn is not None
|
|
561
|
+
else _get_default_collate(
|
|
562
|
+
self._storage, _is_tensordict=isinstance(self, TensorDictReplayBuffer)
|
|
563
|
+
)
|
|
564
|
+
)
|
|
565
|
+
|
|
566
|
+
@_maybe_delay_init
|
|
567
|
+
def set_storage(self, storage: Storage, collate_fn: Callable | None = None):
|
|
568
|
+
"""Sets a new storage in the replay buffer and returns the previous storage.
|
|
569
|
+
|
|
570
|
+
Args:
|
|
571
|
+
storage (Storage): the new storage for the buffer.
|
|
572
|
+
collate_fn (callable, optional): if provided, the collate_fn is set to this
|
|
573
|
+
value. Otherwise it is reset to a default value.
|
|
574
|
+
|
|
575
|
+
"""
|
|
576
|
+
prev_storage = self._storage
|
|
577
|
+
self._storage = storage
|
|
578
|
+
self._get_collate_fn(collate_fn)
|
|
579
|
+
|
|
580
|
+
return prev_storage
|
|
581
|
+
|
|
582
|
+
@_maybe_delay_init
|
|
583
|
+
def set_writer(self, writer: Writer):
|
|
584
|
+
"""Sets a new writer in the replay buffer and returns the previous writer."""
|
|
585
|
+
prev_writer = self._writer
|
|
586
|
+
self._writer = writer
|
|
587
|
+
self._writer.register_storage(self._storage)
|
|
588
|
+
return prev_writer
|
|
589
|
+
|
|
590
|
+
@_maybe_delay_init
|
|
591
|
+
def set_sampler(self, sampler: Sampler):
|
|
592
|
+
"""Sets a new sampler in the replay buffer and returns the previous sampler."""
|
|
593
|
+
prev_sampler = self._sampler
|
|
594
|
+
self._sampler = sampler
|
|
595
|
+
return prev_sampler
|
|
596
|
+
|
|
597
|
+
@_maybe_delay_init
|
|
598
|
+
def __len__(self) -> int:
|
|
599
|
+
with self._replay_lock:
|
|
600
|
+
return len(self._storage)
|
|
601
|
+
|
|
602
|
+
def _getattr(self, attr):
|
|
603
|
+
# To access properties in remote settings, see RayReplayBuffer.write_count for instance
|
|
604
|
+
return getattr(self, attr)
|
|
605
|
+
|
|
606
|
+
def _setattr(self, attr, value):
|
|
607
|
+
# To set properties in remote settings
|
|
608
|
+
setattr(self, attr, value)
|
|
609
|
+
return None # explicit return for remote calls
|
|
610
|
+
|
|
611
|
+
@property
|
|
612
|
+
@_maybe_delay_init
|
|
613
|
+
def write_count(self) -> int:
|
|
614
|
+
"""The total number of items written so far in the buffer through add and extend."""
|
|
615
|
+
return self._writer._write_count
|
|
616
|
+
|
|
617
|
+
def __repr__(self) -> str:
|
|
618
|
+
from torchrl.envs.transforms import Compose
|
|
619
|
+
|
|
620
|
+
storage = textwrap.indent(f"storage={getattr(self, '_storage', None)}", " " * 4)
|
|
621
|
+
writer = textwrap.indent(f"writer={getattr(self, '_writer', None)}", " " * 4)
|
|
622
|
+
sampler = textwrap.indent(f"sampler={getattr(self, '_sampler', None)}", " " * 4)
|
|
623
|
+
if getattr(self, "_transform", None) is not None and not (
|
|
624
|
+
isinstance(self._transform, Compose)
|
|
625
|
+
and not len(getattr(self, "_transform", None))
|
|
626
|
+
):
|
|
627
|
+
transform = textwrap.indent(
|
|
628
|
+
f"transform={getattr(self, '_transform', None)}", " " * 4
|
|
629
|
+
)
|
|
630
|
+
transform = f"\n{self._transform}, "
|
|
631
|
+
else:
|
|
632
|
+
transform = ""
|
|
633
|
+
batch_size = textwrap.indent(
|
|
634
|
+
f"batch_size={getattr(self, '_batch_size', None)}", " " * 4
|
|
635
|
+
)
|
|
636
|
+
collate_fn = textwrap.indent(
|
|
637
|
+
f"collate_fn={getattr(self, '_collate_fn', None)}", " " * 4
|
|
638
|
+
)
|
|
639
|
+
return f"{self.__class__.__name__}(\n{storage}, \n{sampler}, \n{writer}, {transform}\n{batch_size}, \n{collate_fn})"
|
|
640
|
+
|
|
641
|
+
@_maybe_delay_init
|
|
642
|
+
@pin_memory_output
|
|
643
|
+
def __getitem__(self, index: int | torch.Tensor | NestedKey) -> Any:
|
|
644
|
+
if isinstance(index, str) or (isinstance(index, tuple) and unravel_key(index)):
|
|
645
|
+
return self[:][index]
|
|
646
|
+
if isinstance(index, tuple):
|
|
647
|
+
if len(index) == 1:
|
|
648
|
+
return self[index[0]]
|
|
649
|
+
else:
|
|
650
|
+
return self[:][index]
|
|
651
|
+
index = _to_numpy(index)
|
|
652
|
+
|
|
653
|
+
if self.dim_extend > 0:
|
|
654
|
+
index = (slice(None),) * self.dim_extend + (index,)
|
|
655
|
+
with self._replay_lock:
|
|
656
|
+
data = self._storage[index]
|
|
657
|
+
data = self._transpose(data)
|
|
658
|
+
else:
|
|
659
|
+
with self._replay_lock:
|
|
660
|
+
data = self._storage[index]
|
|
661
|
+
|
|
662
|
+
if not isinstance(index, INT_CLASSES):
|
|
663
|
+
data = self._collate_fn(data)
|
|
664
|
+
|
|
665
|
+
if self._transform is not None and len(self._transform):
|
|
666
|
+
with data.unlock_() if is_tensor_collection(
|
|
667
|
+
data
|
|
668
|
+
) else contextlib.nullcontext():
|
|
669
|
+
data = self._transform(data)
|
|
670
|
+
|
|
671
|
+
return data
|
|
672
|
+
|
|
673
|
+
@_maybe_delay_init
|
|
674
|
+
def __setitem__(self, index, value) -> None:
|
|
675
|
+
if isinstance(index, str) or (isinstance(index, tuple) and unravel_key(index)):
|
|
676
|
+
self[:][index] = value
|
|
677
|
+
return
|
|
678
|
+
if isinstance(index, tuple):
|
|
679
|
+
if len(index) == 1:
|
|
680
|
+
self[index[0]] = value
|
|
681
|
+
else:
|
|
682
|
+
self[:][index] = value
|
|
683
|
+
return
|
|
684
|
+
index = _to_numpy(index)
|
|
685
|
+
|
|
686
|
+
if self._transform is not None and len(self._transform):
|
|
687
|
+
value = self._transform.inv(value)
|
|
688
|
+
|
|
689
|
+
if self.dim_extend > 0:
|
|
690
|
+
index = (slice(None),) * self.dim_extend + (index,)
|
|
691
|
+
with self._replay_lock:
|
|
692
|
+
self._storage[index] = self._transpose(value)
|
|
693
|
+
else:
|
|
694
|
+
with self._replay_lock:
|
|
695
|
+
self._storage[index] = value
|
|
696
|
+
return
|
|
697
|
+
|
|
698
|
+
@_maybe_delay_init
|
|
699
|
+
def state_dict(self) -> dict[str, Any]:
|
|
700
|
+
return {
|
|
701
|
+
"_storage": self._storage.state_dict(),
|
|
702
|
+
"_sampler": self._sampler.state_dict(),
|
|
703
|
+
"_writer": self._writer.state_dict(),
|
|
704
|
+
"_transforms": self._transform.state_dict(),
|
|
705
|
+
"_batch_size": self._batch_size,
|
|
706
|
+
"_rng": (self._rng.get_state().clone(), str(self._rng.device))
|
|
707
|
+
if self._rng is not None
|
|
708
|
+
else None,
|
|
709
|
+
}
|
|
710
|
+
|
|
711
|
+
@_maybe_delay_init
|
|
712
|
+
def load_state_dict(self, state_dict: dict[str, Any]) -> None:
|
|
713
|
+
self._storage.load_state_dict(state_dict["_storage"])
|
|
714
|
+
self._sampler.load_state_dict(state_dict["_sampler"])
|
|
715
|
+
self._writer.load_state_dict(state_dict["_writer"])
|
|
716
|
+
self._transform.load_state_dict(state_dict["_transforms"])
|
|
717
|
+
self._batch_size = state_dict["_batch_size"]
|
|
718
|
+
rng = state_dict.get("_rng")
|
|
719
|
+
if rng is not None:
|
|
720
|
+
state, device = rng
|
|
721
|
+
rng = torch.Generator(device=device)
|
|
722
|
+
rng.set_state(state)
|
|
723
|
+
self.set_rng(generator=rng)
|
|
724
|
+
|
|
725
|
+
@_maybe_delay_init
|
|
726
|
+
def dumps(self, path):
|
|
727
|
+
"""Saves the replay buffer on disk at the specified path.
|
|
728
|
+
|
|
729
|
+
Args:
|
|
730
|
+
path (Path or str): path where to save the replay buffer.
|
|
731
|
+
|
|
732
|
+
Examples:
|
|
733
|
+
>>> import tempfile
|
|
734
|
+
>>> import tqdm
|
|
735
|
+
>>> from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer
|
|
736
|
+
>>> from torchrl.data.replay_buffers.samplers import PrioritizedSampler, RandomSampler
|
|
737
|
+
>>> import torch
|
|
738
|
+
>>> from tensordict import TensorDict
|
|
739
|
+
>>> # Build and populate the replay buffer
|
|
740
|
+
>>> S = 1_000_000
|
|
741
|
+
>>> sampler = PrioritizedSampler(S, 1.1, 1.0)
|
|
742
|
+
>>> # sampler = RandomSampler()
|
|
743
|
+
>>> storage = LazyMemmapStorage(S)
|
|
744
|
+
>>> rb = TensorDictReplayBuffer(storage=storage, sampler=sampler)
|
|
745
|
+
>>>
|
|
746
|
+
>>> for _ in tqdm.tqdm(range(100)):
|
|
747
|
+
... td = TensorDict({"obs": torch.randn(100, 3, 4), "next": {"obs": torch.randn(100, 3, 4)}, "td_error": torch.rand(100)}, [100])
|
|
748
|
+
... rb.extend(td)
|
|
749
|
+
... sample = rb.sample(32)
|
|
750
|
+
... rb.update_tensordict_priority(sample)
|
|
751
|
+
>>> # save and load the buffer
|
|
752
|
+
>>> with tempfile.TemporaryDirectory() as tmpdir:
|
|
753
|
+
... rb.dumps(tmpdir)
|
|
754
|
+
...
|
|
755
|
+
... sampler = PrioritizedSampler(S, 1.1, 1.0)
|
|
756
|
+
... # sampler = RandomSampler()
|
|
757
|
+
... storage = LazyMemmapStorage(S)
|
|
758
|
+
... rb_load = TensorDictReplayBuffer(storage=storage, sampler=sampler)
|
|
759
|
+
... rb_load.loads(tmpdir)
|
|
760
|
+
... assert len(rb) == len(rb_load)
|
|
761
|
+
|
|
762
|
+
"""
|
|
763
|
+
path = Path(path).absolute()
|
|
764
|
+
path.mkdir(exist_ok=True)
|
|
765
|
+
self._storage.dumps(path / "storage")
|
|
766
|
+
self._sampler.dumps(path / "sampler")
|
|
767
|
+
self._writer.dumps(path / "writer")
|
|
768
|
+
if self._rng is not None:
|
|
769
|
+
rng_state = TensorDict(
|
|
770
|
+
rng_state=self._rng.get_state().clone(),
|
|
771
|
+
device=self._rng.device,
|
|
772
|
+
)
|
|
773
|
+
rng_state.memmap(path / "rng_state")
|
|
774
|
+
|
|
775
|
+
# fall back on state_dict for transforms
|
|
776
|
+
transform_sd = self._transform.state_dict()
|
|
777
|
+
if transform_sd:
|
|
778
|
+
torch.save(transform_sd, path / "transform.t")
|
|
779
|
+
with open(path / "buffer_metadata.json", "w") as file:
|
|
780
|
+
json.dump({"batch_size": self._batch_size}, file)
|
|
781
|
+
|
|
782
|
+
@_maybe_delay_init
|
|
783
|
+
def loads(self, path):
|
|
784
|
+
"""Loads a replay buffer state at the given path.
|
|
785
|
+
|
|
786
|
+
The buffer should have matching components and be saved using :meth:`dumps`.
|
|
787
|
+
|
|
788
|
+
Args:
|
|
789
|
+
path (Path or str): path where the replay buffer was saved.
|
|
790
|
+
|
|
791
|
+
See :meth:`dumps` for more info.
|
|
792
|
+
|
|
793
|
+
"""
|
|
794
|
+
path = Path(path).absolute()
|
|
795
|
+
self._storage.loads(path / "storage")
|
|
796
|
+
self._sampler.loads(path / "sampler")
|
|
797
|
+
self._writer.loads(path / "writer")
|
|
798
|
+
if (path / "rng_state").exists():
|
|
799
|
+
rng_state = TensorDict.load_memmap(path / "rng_state")
|
|
800
|
+
rng = torch.Generator(device=rng_state.device)
|
|
801
|
+
rng.set_state(rng_state["rng_state"])
|
|
802
|
+
self.set_rng(rng)
|
|
803
|
+
# fall back on state_dict for transforms
|
|
804
|
+
if (path / "transform.t").exists():
|
|
805
|
+
self._transform.load_state_dict(torch.load(path / "transform.t"))
|
|
806
|
+
with open(path / "buffer_metadata.json") as file:
|
|
807
|
+
metadata = json.load(file)
|
|
808
|
+
self._batch_size = metadata["batch_size"]
|
|
809
|
+
|
|
810
|
+
@_maybe_delay_init
|
|
811
|
+
def save(self, *args, **kwargs):
|
|
812
|
+
"""Alias for :meth:`dumps`."""
|
|
813
|
+
return self.dumps(*args, **kwargs)
|
|
814
|
+
|
|
815
|
+
@_maybe_delay_init
|
|
816
|
+
def dump(self, *args, **kwargs):
|
|
817
|
+
"""Alias for :meth:`dumps`."""
|
|
818
|
+
return self.dumps(*args, **kwargs)
|
|
819
|
+
|
|
820
|
+
@_maybe_delay_init
|
|
821
|
+
def load(self, *args, **kwargs):
|
|
822
|
+
"""Alias for :meth:`loads`."""
|
|
823
|
+
return self.loads(*args, **kwargs)
|
|
824
|
+
|
|
825
|
+
@_maybe_delay_init
|
|
826
|
+
def register_save_hook(self, hook: Callable[[Any], Any]):
|
|
827
|
+
"""Registers a save hook for the storage.
|
|
828
|
+
|
|
829
|
+
.. note:: Hooks are currently not serialized when saving a replay buffer: they must
|
|
830
|
+
be manually re-initialized every time the buffer is created.
|
|
831
|
+
|
|
832
|
+
"""
|
|
833
|
+
self._storage.register_save_hook(hook)
|
|
834
|
+
|
|
835
|
+
@_maybe_delay_init
|
|
836
|
+
def register_load_hook(self, hook: Callable[[Any], Any]):
|
|
837
|
+
"""Registers a load hook for the storage.
|
|
838
|
+
|
|
839
|
+
.. note:: Hooks are currently not serialized when saving a replay buffer: they must
|
|
840
|
+
be manually re-initialized every time the buffer is created.
|
|
841
|
+
|
|
842
|
+
"""
|
|
843
|
+
self._storage.register_load_hook(hook)
|
|
844
|
+
|
|
845
|
+
@_maybe_delay_init
|
|
846
|
+
def add(self, data: Any) -> int:
|
|
847
|
+
"""Add a single element to the replay buffer.
|
|
848
|
+
|
|
849
|
+
Args:
|
|
850
|
+
data (Any): data to be added to the replay buffer
|
|
851
|
+
|
|
852
|
+
Returns:
|
|
853
|
+
index where the data lives in the replay buffer.
|
|
854
|
+
"""
|
|
855
|
+
if self._transform is not None and len(self._transform):
|
|
856
|
+
with _set_dispatch_td_nn_modules(is_tensor_collection(data)):
|
|
857
|
+
make_none = False
|
|
858
|
+
# Transforms usually expect a time batch dimension when called within a RB, so we unsqueeze the data temporarily
|
|
859
|
+
is_tc = is_tensor_collection(data)
|
|
860
|
+
cm = data.unsqueeze(-1) if is_tc else contextlib.nullcontext(data)
|
|
861
|
+
new_data = None
|
|
862
|
+
with cm as data_unsq:
|
|
863
|
+
data_unsq_r = self._transform.inv(data_unsq)
|
|
864
|
+
if is_tc and data_unsq_r is not None:
|
|
865
|
+
# this is a no-op whenever the result matches the input
|
|
866
|
+
new_data = data_unsq_r.squeeze(-1)
|
|
867
|
+
else:
|
|
868
|
+
make_none = data_unsq_r is None
|
|
869
|
+
data = new_data if new_data is not None else data
|
|
870
|
+
if make_none:
|
|
871
|
+
data = None
|
|
872
|
+
if data is None:
|
|
873
|
+
return torch.zeros((0, self._storage.ndim), dtype=torch.long)
|
|
874
|
+
if rl_warnings() and is_tensor_collection(data) and data.ndim:
|
|
875
|
+
warnings.warn(
|
|
876
|
+
f"Using `add()` with a TensorDict that has batch_size={data.batch_size}. "
|
|
877
|
+
f"Use `extend()` to add multiple elements, or `add()` with a single element (batch_size=torch.Size([])). "
|
|
878
|
+
"You can silence this warning by setting the `RL_WARNINGS` environment variable to `'0'`."
|
|
879
|
+
)
|
|
880
|
+
|
|
881
|
+
return self._add(data)
|
|
882
|
+
|
|
883
|
+
def _add(self, data):
|
|
884
|
+
with self._replay_lock, self._write_lock:
|
|
885
|
+
index = self._writer.add(data)
|
|
886
|
+
self._sampler.add(index)
|
|
887
|
+
return index
|
|
888
|
+
|
|
889
|
+
    def _extend(self, data: Sequence, *, update_priority: bool = True) -> torch.Tensor:
        is_comp = is_compiling()
        nc = contextlib.nullcontext()
        with self._replay_lock if not is_comp else nc, self._write_lock if not is_comp else nc:
            if self.dim_extend > 0:
                data = self._transpose(data)
            index = self._writer.extend(data)
            self._sampler.extend(index)
        return index

    @_maybe_delay_init
    def extend(
        self, data: Sequence, *, update_priority: bool | None = None
    ) -> torch.Tensor:
        """Extends the replay buffer with one or more elements contained in an iterable.

        If present, the inverse transforms will be called.

        Args:
            data (iterable): collection of data to be added to the replay
                buffer.

        Keyword Args:
            update_priority (bool, optional): Whether to update the priority of the data. Defaults to True.
                Without effect in this class. See :meth:`~torchrl.data.TensorDictReplayBuffer.extend` for more details.

        Returns:
            Indices of the data added to the replay buffer.

        .. warning:: :meth:`~torchrl.data.replay_buffers.ReplayBuffer.extend` can have an
            ambiguous signature when dealing with lists of values, which should be interpreted
            either as PyTree (in which case all elements in the list will be put in a slice
            in the stored PyTree in the storage) or a list of values to add one at a time.
            To solve this, TorchRL makes the clear-cut distinction between list and tuple:
            a tuple will be viewed as a PyTree, a list (at the root level) will be interpreted
            as a stack of values to add one at a time to the buffer.
            For :class:`~torchrl.data.replay_buffers.ListStorage` instances, only
            unbound elements can be provided (no PyTrees).

        """
        if update_priority is not None:
            raise NotImplementedError(
                "update_priority is not supported in this class. See :meth:`~torchrl.data.TensorDictReplayBuffer.extend` for more details."
            )
        if self._transform is not None and len(self._transform):
            with _set_dispatch_td_nn_modules(is_tensor_collection(data)):
                data = self._transform.inv(data)
        if data is None:
            return torch.zeros((0, self._storage.ndim), dtype=torch.long)
        return self._extend(data, update_priority=update_priority)

    @_maybe_delay_init
    def update_priority(
        self,
        index: int | torch.Tensor | tuple[torch.Tensor],
        priority: int | torch.Tensor,
    ) -> None:
        if isinstance(index, tuple):
            index = torch.stack(index, -1)
        priority = torch.as_tensor(priority)
        if self.dim_extend > 0 and priority.ndim > 1:
            priority = self._transpose(priority).flatten()
            # priority = priority.flatten()
        with self._replay_lock, self._write_lock:
            self._sampler.update_priority(index, priority, storage=self.storage)

    @pin_memory_output
    def _sample(self, batch_size: int) -> tuple[Any, dict]:
        is_comp = is_compiling()
        nc = contextlib.nullcontext()
        with self._replay_lock if not is_comp else nc, self._write_lock if not is_comp else nc:
            index, info = self._sampler.sample(self._storage, batch_size)
            info["index"] = index
            data = self._storage.get(index)
        if not isinstance(index, INT_CLASSES):
            data = self._collate_fn(data)
        if self._transform is not None and len(self._transform):
            is_td = is_tensor_collection(data)
            with data.unlock_() if is_td else contextlib.nullcontext(), _set_dispatch_td_nn_modules(
                is_td
            ):
                data = self._transform(data)

        return data, info

    @_maybe_delay_init
    def empty(self, empty_write_count: bool = True):
        """Empties the replay buffer and resets the cursor to 0.

        Args:
            empty_write_count (bool, optional): Whether to empty the write_count attribute. Defaults to `True`.
        """
        self._writer._empty(empty_write_count=empty_write_count)
        self._sampler._empty()
        self._storage._empty()

    @_maybe_delay_init
    def sample(self, batch_size: int | None = None, return_info: bool = False) -> Any:
        """Samples a batch of data from the replay buffer.

        Uses Sampler to sample indices, and retrieves them from Storage.

        Args:
            batch_size (int, optional): size of data to be collected. If none
                is provided, this method will sample a batch-size as indicated
                by the sampler.
            return_info (bool): whether to return info. If True, the result
                is a tuple (data, info). If False, the result is the data.

        Returns:
            A batch of data selected in the replay buffer.
            A tuple containing this batch and info if return_info flag is set to True.
        """
        if (
            batch_size is not None
            and self._batch_size is not None
            and batch_size != self._batch_size
        ):
            warnings.warn(
                f"Got conflicting batch_sizes in constructor ({self._batch_size}) "
                f"and `sample` ({batch_size}). Refer to the ReplayBuffer documentation "
                "for a proper usage of the batch-size arguments. "
                "The batch-size provided to the sample method "
                "will prevail."
            )
        elif batch_size is None and self._batch_size is not None:
            batch_size = self._batch_size
        elif batch_size is None:
            raise RuntimeError(
                "batch_size not specified. You can specify the batch_size when "
                "constructing the replay buffer, or pass it to the sample method. "
                "Refer to the ReplayBuffer documentation "
                "for a proper usage of the batch-size arguments."
            )
        if not self._prefetch:
            result = self._sample(batch_size)
        else:
            with self._futures_lock:
                while (
                    len(self._prefetch_queue)
                    < min(self._sampler._remaining_batches, self._prefetch_cap)
                    and not self._sampler.ran_out
                ) or not len(self._prefetch_queue):
                    fut = self._prefetch_executor.submit(self._sample, batch_size)
                    self._prefetch_queue.append(fut)
                result = self._prefetch_queue.popleft().result()

        if return_info:
            out, info = result
            if getattr(self.storage, "device", None) is not None:
                device = self.storage.device
                info = tree_map(lambda x: x.to(device) if hasattr(x, "to") else x, info)
            return out, info
        return result[0]

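    # --- Illustrative usage sketch (editor's note, not part of the torchrl source). ---
    # The batch size used by `sample()` can come from the constructor or from the call
    # itself (a warning is emitted when both are given and differ), and
    # `return_info=True` additionally exposes the sampled indices.
    #
    #     >>> import torch
    #     >>> from torchrl.data import ReplayBuffer, LazyTensorStorage
    #     >>> rb = ReplayBuffer(storage=LazyTensorStorage(100), batch_size=8)
    #     >>> rb.extend(torch.arange(100))
    #     >>> batch = rb.sample()                           # uses batch_size=8 from the constructor
    #     >>> batch, info = rb.sample(4, return_info=True)  # per-call batch size, with metadata
    #     >>> "index" in info                               # the sampled storage indices
    #     True
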
    @_maybe_delay_init
    def mark_update(self, index: int | torch.Tensor) -> None:
        self._sampler.mark_update(index, storage=self._storage)

    @_maybe_delay_init
    def append_transform(
        self, transform: Transform, *, invert: bool = False  # noqa-F821
    ) -> ReplayBuffer:  # noqa: D417
        """Appends transform at the end.

        Transforms are applied in order when `sample` is called.

        Args:
            transform (Transform): The transform to be appended

        Keyword Args:
            invert (bool, optional): if ``True``, the transform will be inverted (forward calls will be called
                during writing and inverse calls during reading). Defaults to ``False``.

        Example:
            >>> rb = ReplayBuffer(storage=LazyMemmapStorage(10), batch_size=4)
            >>> data = TensorDict({"a": torch.zeros(10)}, [10])
            >>> def t(data):
            ...     data += 1
            ...     return data
            >>> rb.append_transform(t, invert=True)
            >>> rb.extend(data)
            >>> assert (data == 1).all()

        """
        from torchrl.envs.transforms.transforms import _CallableTransform, Transform

        if not isinstance(transform, Transform) and callable(transform):
            transform = _CallableTransform(transform)
        if invert:
            transform = _InvertTransform(transform)
        transform.eval()
        self._transform.append(transform)
        return self

    @_maybe_delay_init
    def insert_transform(
        self,
        index: int,
        transform: Transform,  # noqa-F821
        *,
        invert: bool = False,
    ) -> ReplayBuffer:  # noqa: D417
        """Inserts transform.

        Transforms are executed in order when `sample` is called.

        Args:
            index (int): Position to insert the transform.
            transform (Transform): The transform to be appended

        Keyword Args:
            invert (bool, optional): if ``True``, the transform will be inverted (forward calls will be called
                during writing and inverse calls during reading). Defaults to ``False``.

        """
        transform.eval()
        if invert:
            transform = _InvertTransform(transform)
        self._transform.insert(index, transform)
        return self

    _iterator = None

    @_maybe_delay_init
    def next(self):
        """Returns the next item in the replay buffer.

        This method is used to iterate over the replay buffer in contexts where __iter__ is not available,
        such as :class:`~torchrl.data.replay_buffers.RayReplayBuffer`.
        """
        try:
            if self._iterator is None:
                self._iterator = iter(self)
            out = next(self._iterator)
            # if any, we don't want the device ref to be passed in distributed settings
            if out is not None and (out.device != "cpu"):
                out = out.copy().clear_device_()
            return out
        except StopIteration:
            self._iterator = None
            return None

    @_maybe_delay_init
    def __iter__(self):
        if self._sampler.ran_out:
            self._sampler.ran_out = False
        if self._batch_size is None:
            raise RuntimeError(
                "Cannot iterate over the replay buffer. "
                "Batch_size was not specified during construction of the replay buffer."
            )
        while not self._sampler.ran_out or (
            self._prefetch and len(self._prefetch_queue)
        ):
            yield self.sample()

    @_maybe_delay_init
    def __getstate__(self) -> dict[str, Any]:
        state = self.__dict__.copy()
        if getattr(self, "_rng", None) is not None:
            rng_state = TensorDict(
                rng_state=self._rng.get_state().clone(),
                device=self._rng.device,
            )
            state["_rng"] = rng_state
        _replay_lock = state.pop("_replay_lock", None)
        _futures_lock = state.pop("_futures_lock", None)
        if _replay_lock is not None:
            state["_replay_lock_placeholder"] = None
        if _futures_lock is not None:
            state["_futures_lock_placeholder"] = None
        # Remove non-picklable prefetch objects - they will be recreated on unpickle
        _prefetch_queue = state.pop("_prefetch_queue", None)
        _prefetch_executor = state.pop("_prefetch_executor", None)
        if _prefetch_queue is not None:
            state["_prefetch_queue_placeholder"] = None
        if _prefetch_executor is not None:
            state["_prefetch_executor_placeholder"] = None
        return state

    def __setstate__(self, state: dict[str, Any]):
        rngstate = None
        if "_rng" in state:
            rngstate = state["_rng"]
            if rngstate is not None:
                rng = torch.Generator(device=rngstate.device)
                rng.set_state(rngstate["rng_state"])

        if "_replay_lock_placeholder" in state:
            state.pop("_replay_lock_placeholder")
            _replay_lock = threading.RLock()
            state["_replay_lock"] = _replay_lock
        if "_futures_lock_placeholder" in state:
            state.pop("_futures_lock_placeholder")
            _futures_lock = threading.RLock()
            state["_futures_lock"] = _futures_lock
        # Recreate prefetch objects after unpickling if they were present
        if "_prefetch_queue_placeholder" in state:
            state.pop("_prefetch_queue_placeholder")
            state["_prefetch_queue"] = collections.deque()
        if "_prefetch_executor_placeholder" in state:
            state.pop("_prefetch_executor_placeholder")
            state["_prefetch_executor"] = ThreadPoolExecutor(
                max_workers=state["_prefetch_cap"]
            )
        self.__dict__.update(state)
        if rngstate is not None:
            self.set_rng(rng)

    @property
    @_maybe_delay_init
    def sampler(self) -> Sampler:
        """The sampler of the replay buffer.

        The sampler must be an instance of :class:`~torchrl.data.replay_buffers.Sampler`.

        """
        return self._sampler

    @property
    @_maybe_delay_init
    def writer(self) -> Writer:
        """The writer of the replay buffer.

        The writer must be an instance of :class:`~torchrl.data.replay_buffers.Writer`.

        """
        return self._writer

    @property
    @_maybe_delay_init
    def storage(self) -> Storage:
        """The storage of the replay buffer.

        The storage must be an instance of :class:`~torchrl.data.replay_buffers.Storage`.

        """
        return self._storage

    @property
    @_maybe_delay_init
    def transform(self) -> Transform:
        """The transform of the replay buffer.

        The transform must be an instance of :class:`~torchrl.envs.transforms.Transform`.
        """
        return self._transform


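# --- Illustrative sketch (editor's note, not part of the torchrl source). ---
# It mirrors the `append_transform` docstring above: a plain callable is wrapped in a
# transform, and `invert=True` makes it run on the write path (i.e. inside `extend`).
# The helper name `_sketch_inverted_transform` is hypothetical.
def _sketch_inverted_transform():
    import torch
    from tensordict import TensorDict
    from torchrl.data import LazyTensorStorage, ReplayBuffer

    rb = ReplayBuffer(storage=LazyTensorStorage(10), batch_size=4)

    def add_one(data):
        # runs when data is written, because the transform is inverted
        data += 1
        return data

    rb.append_transform(add_one, invert=True)
    data = TensorDict({"a": torch.zeros(10)}, [10])
    rb.extend(data)
    assert (data["a"] == 1).all()  # the inverse transform ran in-place on write
    return rb

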
class PrioritizedReplayBuffer(ReplayBuffer):
    """Prioritized replay buffer.

    All arguments are keyword-only arguments.

    Presented in "Schaul, T.; Quan, J.; Antonoglou, I.; and Silver, D. 2015.
    Prioritized experience replay." (https://arxiv.org/abs/1511.05952)

    Args:
        alpha (:obj:`float`): exponent α determines how much prioritization is used,
            with α = 0 corresponding to the uniform case.
        beta (:obj:`float`): importance sampling negative exponent.
        eps (:obj:`float`): delta added to the priorities to ensure that the buffer
            does not contain null priorities.
        storage (Storage, optional): the storage to be used. If none is provided
            a default :class:`~torchrl.data.replay_buffers.ListStorage` with
            ``max_size`` of ``1_000`` will be created.
        sampler (Sampler, optional): the sampler to be used. If none is provided,
            a default :class:`~torchrl.data.replay_buffers.PrioritizedSampler` with
            ``alpha``, ``beta``, and ``eps`` will be created.
        collate_fn (callable, optional): merges a list of samples to form a
            mini-batch of Tensor(s)/outputs. Used when using batched
            loading from a map-style dataset. The default value will be decided
            based on the storage type.
        pin_memory (bool): whether pin_memory() should be called on the rb
            samples.
        prefetch (int, optional): number of next batches to be prefetched
            using multithreading. Defaults to None (no prefetching).
        transform (Transform, optional): Transform to be executed when
            sample() is called.
            To chain transforms use the :class:`~torchrl.envs.Compose` class.
            Transforms should be used with :class:`tensordict.TensorDict`
            content. If used with other structures, the transforms should be
            encoded with a ``"data"`` leading key that will be used to
            construct a tensordict from the non-tensordict content.
        batch_size (int, optional): the batch size to be used when sample() is
            called.

            .. note:: The batch-size can be specified at construction time via the
                ``batch_size`` argument, or at sampling time. The former should
                be preferred whenever the batch-size is consistent across the
                experiment. If the batch-size is likely to change, it can be
                passed to the :meth:`sample` method. This option is
                incompatible with prefetching (since this requires to know the
                batch-size in advance) as well as with samplers that have a
                ``drop_last`` argument.

        dim_extend (int, optional): indicates the dim to consider for
            extension when calling :meth:`extend`. Defaults to ``storage.ndim-1``.
            When using ``dim_extend > 0``, we recommend using the ``ndim``
            argument in the storage instantiation if that argument is
            available, to let storages know that the data is
            multi-dimensional and keep consistent notions of storage-capacity
            and batch-size during sampling.

            .. note:: This argument has no effect on :meth:`add` and
                therefore should be used with caution when both :meth:`add`
                and :meth:`extend` are used in a codebase. For example:

                >>> data = torch.zeros(3, 4)
                >>> rb = ReplayBuffer(
                ...     storage=LazyTensorStorage(10, ndim=2),
                ...     dim_extend=1)
                >>> # these two approaches are equivalent:
                >>> for d in data.unbind(1):
                ...     rb.add(d)
                >>> rb.extend(data)

        delayed_init (bool, optional): whether to initialize storage, writer, sampler and transform
            the first time the buffer is used rather than during construction.
            This is useful when the replay buffer needs to be pickled and sent to remote workers,
            particularly when using transforms with modules that require gradients.
            If not specified, defaults to ``True`` when ``transform_factory`` is provided,
            and ``False`` otherwise.

    .. note::
        Generic prioritized replay buffers (ie. non-tensordict backed) require
        calling :meth:`~.sample` with the ``return_info`` argument set to
        ``True`` to have access to the indices, and hence update the priority.
        Using :class:`tensordict.TensorDict` and the related
        :class:`~torchrl.data.TensorDictPrioritizedReplayBuffer` simplifies this
        process.

    Examples:
        >>> import torch
        >>>
        >>> from torchrl.data import ListStorage, PrioritizedReplayBuffer
        >>>
        >>> torch.manual_seed(0)
        >>>
        >>> rb = PrioritizedReplayBuffer(alpha=0.7, beta=0.9, storage=ListStorage(10))
        >>> data = range(10)
        >>> rb.extend(data)
        >>> sample = rb.sample(3)
        >>> print(sample)
        tensor([1, 0, 1])
        >>> # get the info to find what the indices are
        >>> sample, info = rb.sample(5, return_info=True)
        >>> print(sample, info)
        tensor([2, 7, 4, 3, 5]) {'priority_weight': array([1., 1., 1., 1., 1.], dtype=float32), 'index': array([2, 7, 4, 3, 5])}
        >>> # update priority
        >>> priority = torch.ones(5) * 5
        >>> rb.update_priority(info["index"], priority)
        >>> # and now a new sample, the weights should be updated
        >>> sample, info = rb.sample(5, return_info=True)
        >>> print(sample, info)
        tensor([2, 5, 2, 2, 5]) {'priority_weight': array([0.36278465, 0.36278465, 0.36278465, 0.36278465, 0.36278465],
            dtype=float32), 'index': array([2, 5, 2, 2, 5])}

    """

    def __init__(
        self,
        *,
        alpha: float,
        beta: float,
        eps: float = 1e-8,
        dtype: torch.dtype = torch.float,
        storage: Storage | None = None,
        sampler: Sampler | None = None,
        collate_fn: Callable | None = None,
        pin_memory: bool = False,
        prefetch: int | None = None,
        transform: Transform | None = None,  # noqa-F821
        batch_size: int | None = None,
        dim_extend: int | None = None,
        delayed_init: bool = False,
    ) -> None:
        if storage is None:
            storage = ListStorage(max_size=1_000)
        if sampler is None:
            sampler = PrioritizedSampler(storage.max_size, alpha, beta, eps, dtype)
        super().__init__(
            storage=storage,
            sampler=sampler,
            collate_fn=collate_fn,
            pin_memory=pin_memory,
            prefetch=prefetch,
            transform=transform,
            batch_size=batch_size,
            dim_extend=dim_extend,
            delayed_init=delayed_init,
        )


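# --- Illustrative sketch (editor's note, not part of the torchrl source). ---
# A minimal prioritized-replay loop, following the class docstring above: sample with
# `return_info=True` to obtain the indices, then feed new priorities back through
# `update_priority`. The priorities below are placeholders, not a real TD-error, and
# the helper name `_sketch_prioritized_loop` is hypothetical.
def _sketch_prioritized_loop(num_iters: int = 3):
    import torch
    from torchrl.data import ListStorage, PrioritizedReplayBuffer

    rb = PrioritizedReplayBuffer(alpha=0.7, beta=0.9, storage=ListStorage(100))
    rb.extend(list(range(20)))  # a root-level list is added one element at a time
    for _ in range(num_iters):
        sample, info = rb.sample(8, return_info=True)
        # placeholder priorities; an agent would typically use the TD-error magnitude
        new_priorities = torch.rand(8) + 1e-8
        rb.update_priority(info["index"], new_priorities)
    return rb

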
class TensorDictReplayBuffer(ReplayBuffer):
    """TensorDict-specific wrapper around the :class:`~torchrl.data.ReplayBuffer` class.

    Keyword Args:
        storage (Storage, Callable[[], Storage], optional): the storage to be used.
            If a callable is passed, it is used as constructor for the storage.
            If none is provided a default :class:`~torchrl.data.replay_buffers.ListStorage` with
            ``max_size`` of ``1_000`` will be created.
        sampler (Sampler, Callable[[], Sampler], optional): the sampler to be used.
            If a callable is passed, it is used as constructor for the sampler.
            If none is provided, a default :class:`~torchrl.data.replay_buffers.RandomSampler`
            will be used.
        writer (Writer, Callable[[], Writer], optional): the writer to be used.
            If a callable is passed, it is used as constructor for the writer.
            If none is provided a default :class:`~torchrl.data.replay_buffers.TensorDictRoundRobinWriter`
            will be used.
        collate_fn (callable, optional): merges a list of samples to form a
            mini-batch of Tensor(s)/outputs. Used when using batched
            loading from a map-style dataset. The default value will be decided
            based on the storage type.
        pin_memory (bool): whether pin_memory() should be called on the rb
            samples.
        prefetch (int, optional): number of next batches to be prefetched
            using multithreading. Defaults to None (no prefetching).
        transform (Transform or Callable[[Any], Any], optional): Transform to be executed when
            :meth:`sample` is called.
            To chain transforms use the :class:`~torchrl.envs.Compose` class.
            Transforms should be used with :class:`tensordict.TensorDict`
            content. A generic callable can also be passed if the replay buffer
            is used with PyTree structures (see example below).
            Unlike storages, writers and samplers, transform constructors must
            be passed as separate keyword argument :attr:`transform_factory`,
            as it is impossible to distinguish a constructor from a transform.
        transform_factory (Callable[[], Callable], optional): a factory for the
            transform. Exclusive with :attr:`transform`.
        batch_size (int, optional): the batch size to be used when sample() is
            called.

            .. note::
                The batch-size can be specified at construction time via the
                ``batch_size`` argument, or at sampling time. The former should
                be preferred whenever the batch-size is consistent across the
                experiment. If the batch-size is likely to change, it can be
                passed to the :meth:`~.sample` method. This option is
                incompatible with prefetching (since this requires to know the
                batch-size in advance) as well as with samplers that have a
                ``drop_last`` argument.

        priority_key (str, optional): the key at which priority is assumed to
            be stored within TensorDicts added to this ReplayBuffer.
            This is to be used when the sampler is of type
            :class:`~torchrl.data.PrioritizedSampler`.
            Defaults to ``"td_error"``.
        dim_extend (int, optional): indicates the dim to consider for
            extension when calling :meth:`~.extend`. Defaults to ``storage.ndim-1``.
            When using ``dim_extend > 0``, we recommend using the ``ndim``
            argument in the storage instantiation if that argument is
            available, to let storages know that the data is
            multi-dimensional and keep consistent notions of storage-capacity
            and batch-size during sampling.

            .. note:: This argument has no effect on :meth:`~.add` and
                therefore should be used with caution when both :meth:`~.add`
                and :meth:`~.extend` are used in a codebase. For example:

                >>> data = torch.zeros(3, 4)
                >>> rb = ReplayBuffer(
                ...     storage=LazyTensorStorage(10, ndim=2),
                ...     dim_extend=1)
                >>> # these two approaches are equivalent:
                >>> for d in data.unbind(1):
                ...     rb.add(d)
                >>> rb.extend(data)

        generator (torch.Generator, optional): a generator to use for sampling.
            Using a dedicated generator for the replay buffer can allow a fine-grained control
            over seeding, for instance keeping the global seed different but the RB seed identical
            for distributed jobs.
            Defaults to ``None`` (global default generator).

            .. warning:: As of now, the generator has no effect on the transforms.
        shared (bool, optional): whether the buffer will be shared using multiprocessing or not.
            Defaults to ``False``.
        compilable (bool, optional): whether the writer is compilable.
            If ``True``, the writer cannot be shared between multiple processes.
            Defaults to ``False``.
        delayed_init (bool, optional): whether to initialize storage, writer, sampler and transform
            the first time the buffer is used rather than during construction.
            This is useful when the replay buffer needs to be pickled and sent to remote workers,
            particularly when using transforms with modules that require gradients.
            If not specified, defaults to ``True`` when ``transform_factory`` is provided,
            and ``False`` otherwise.

    Examples:
        >>> import torch
        >>>
        >>> from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer
        >>> from tensordict import TensorDict
        >>>
        >>> torch.manual_seed(0)
        >>>
        >>> rb = TensorDictReplayBuffer(storage=LazyTensorStorage(10), batch_size=5)
        >>> data = TensorDict({"a": torch.ones(10, 3), ("b", "c"): torch.zeros(10, 1, 1)}, [10])
        >>> rb.extend(data)
        >>> sample = rb.sample(3)
        >>> # samples keep track of the index
        >>> print(sample)
        TensorDict(
            fields={
                a: Tensor(shape=torch.Size([3, 3]), device=cpu, dtype=torch.float32, is_shared=False),
                b: TensorDict(
                    fields={
                        c: Tensor(shape=torch.Size([3, 1, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
                    batch_size=torch.Size([3]),
                    device=cpu,
                    is_shared=False),
                index: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.int32, is_shared=False)},
            batch_size=torch.Size([3]),
            device=cpu,
            is_shared=False)
        >>> # we can iterate over the buffer
        >>> for i, data in enumerate(rb):
        ...     print(i, data)
        ...     if i == 2:
        ...         break
        0 TensorDict(
            fields={
                a: Tensor(shape=torch.Size([5, 3]), device=cpu, dtype=torch.float32, is_shared=False),
                b: TensorDict(
                    fields={
                        c: Tensor(shape=torch.Size([5, 1, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
                    batch_size=torch.Size([5]),
                    device=cpu,
                    is_shared=False),
                index: Tensor(shape=torch.Size([5]), device=cpu, dtype=torch.int32, is_shared=False)},
            batch_size=torch.Size([5]),
            device=cpu,
            is_shared=False)
        1 TensorDict(
            fields={
                a: Tensor(shape=torch.Size([5, 3]), device=cpu, dtype=torch.float32, is_shared=False),
                b: TensorDict(
                    fields={
                        c: Tensor(shape=torch.Size([5, 1, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
                    batch_size=torch.Size([5]),
                    device=cpu,
                    is_shared=False),
                index: Tensor(shape=torch.Size([5]), device=cpu, dtype=torch.int32, is_shared=False)},
            batch_size=torch.Size([5]),
            device=cpu,
            is_shared=False)

    """

    def __init__(self, *, priority_key: str = "td_error", **kwargs) -> None:
        writer = kwargs.get("writer", None)
        if writer is None:
            kwargs["writer"] = partial(
                TensorDictRoundRobinWriter, compilable=kwargs.get("compilable")
            )
        super().__init__(**kwargs)
        self.priority_key = priority_key

    def _get_priority_item(self, tensordict: TensorDictBase) -> float:
        priority = tensordict.get(self.priority_key, None)
        if self._storage.ndim > 1:
            # We have to flatten the priority otherwise we'll be aggregating
            # the priority across batches
            priority = priority.flatten(0, self._storage.ndim - 1)
        if priority is None:
            return self._sampler.default_priority
        try:
            if priority.numel() > 1:
                priority = _reduce(priority, self._sampler.reduction)
            else:
                priority = priority.item()
        except ValueError:
            raise ValueError(
                f"Found a priority key of size"
                f" {tensordict.get(self.priority_key).shape} but expected "
                f"scalar value"
            )

        if self._storage.ndim > 1:
            priority = priority.unflatten(0, tensordict.shape[: self._storage.ndim])

        return priority

    def _get_priority_vector(self, tensordict: TensorDictBase) -> torch.Tensor:
        priority = tensordict.get(self.priority_key, None)
        if priority is None:
            return torch.tensor(
                self._sampler.default_priority,
                dtype=torch.float,
                device=tensordict.device,
            ).expand(tensordict.shape[0])
        if self._storage.ndim > 1 and priority.ndim >= self._storage.ndim:
            # We have to flatten the priority otherwise we'll be aggregating
            # the priority across batches
            priority = priority.flatten(0, self._storage.ndim - 1)

        priority = priority.reshape(priority.shape[0], -1)
        priority = _reduce(priority, self._sampler.reduction, dim=1)

        if self._storage.ndim > 1:
            priority = priority.unflatten(0, tensordict.shape[: self._storage.ndim])

        return priority

    @_maybe_delay_init
    def add(self, data: TensorDictBase) -> int:
        if self._transform is not None:
            with _set_dispatch_td_nn_modules(is_tensor_collection(data)):
                data = self._transform.inv(data)
        if data is None:
            return torch.zeros((0, self._storage.ndim), dtype=torch.long)

        index = super()._add(data)
        if index is not None:
            if is_tensor_collection(data):
                self._set_index_in_td(data, index)

            self.update_tensordict_priority(data)
        return index

    @_maybe_delay_init
    def extend(
        self, tensordicts: TensorDictBase, *, update_priority: bool | None = None
    ) -> torch.Tensor:
        """Extends the replay buffer with a batch of data.

        Args:
            tensordicts (TensorDictBase): The data to extend the replay buffer with.

        Keyword Args:
            update_priority (bool, optional): Whether to update the priority of the data. Defaults to True.

        Returns:
            The indices of the data that were added to the replay buffer.
        """
        if not isinstance(tensordicts, TensorDictBase):
            raise ValueError(
                f"{self.__class__.__name__} only accepts TensorDictBase subclasses. tensorclasses "
                f"and other types are not compatible with that class. "
                "Please use a regular `ReplayBuffer` instead."
            )
        if self._transform is not None:
            tensordicts = self._transform.inv(tensordicts)
        if tensordicts is None:
            return torch.zeros((0, self._storage.ndim), dtype=torch.long)

        index = super()._extend(tensordicts)

        # TODO: to be usable directly, the indices should be flipped but the issue
        # is that just doing this results in indices that are not sorted like the original data
        # so the actually indices will have to be used on the _storage directly (not on the buffer)
        self._set_index_in_td(tensordicts, index)
        if update_priority is None:
            update_priority = True
        if update_priority:
            try:
                vector = tensordicts.get(self.priority_key)
                if vector is not None:
                    self.update_priority(index, vector)
            except Exception as e:
                raise RuntimeError(
                    "Failed to update priority of extended data. You can try to set update_priority=False in the extend method and update the priority manually."
                ) from e
        return index

    def _set_index_in_td(self, tensordict, index):
        if index is None:
            return
        if _is_int(index):
            index = torch.as_tensor(index, device=tensordict.device)
        elif index.ndim == 2 and index.shape[:1] != tensordict.shape[:1]:
            for dim in range(2, tensordict.ndim + 1):
                if index.shape[:1].numel() == tensordict.shape[:dim].numel():
                    # if index has 2 dims and is in a non-zero format
                    index = index.unflatten(0, tensordict.shape[:dim])
                    break
            else:
                raise RuntimeError(
                    f"could not find how to reshape index with shape {index.shape} to fit in tensordict with shape {tensordict.shape}"
                )
            tensordict.set("index", index)
            return
        tensordict.set("index", expand_as_right(index, tensordict))

    @_maybe_delay_init
    def update_tensordict_priority(self, data: TensorDictBase) -> None:
        if not isinstance(self._sampler, PrioritizedSampler):
            return
        if data.ndim:
            priority = self._get_priority_vector(data)
        else:
            priority = torch.as_tensor(self._get_priority_item(data))
        index = data.get("index")
        if self._storage.ndim > 1 and index.ndim == 2:
            index = index.unbind(-1)
        else:
            while index.shape != priority.shape:
                # reduce index
                index = index[..., 0]
        return self.update_priority(index, priority)

    def sample(
        self,
        batch_size: int | None = None,
        return_info: bool = False,
        include_info: bool | None = None,
    ) -> TensorDictBase:
        """Samples a batch of data from the replay buffer.

        Uses Sampler to sample indices, and retrieves them from Storage.

        Args:
            batch_size (int, optional): size of data to be collected. If none
                is provided, this method will sample a batch-size as indicated
                by the sampler.
            return_info (bool): whether to return info. If True, the result
                is a tuple (data, info). If False, the result is the data.

        Returns:
            A tensordict containing a batch of data selected in the replay buffer.
            A tuple containing this tensordict and info if return_info flag is set to True.
        """
        if include_info is not None:
            warnings.warn(
                "include_info is going to be deprecated soon."
                "The default behavior has changed to `include_info=True` "
                "to avoid bugs linked to wrongly preassigned values in the "
                "output tensordict."
            )

        data, info = super().sample(batch_size, return_info=True)
        is_tc = is_tensor_collection(data)
        if is_tc and not is_tensorclass(data) and include_info in (True, None):
            is_locked = data.is_locked
            if is_locked:
                data.unlock_()
            for key, val in info.items():
                if key == "index" and isinstance(val, tuple):
                    val = torch.stack(val, -1)
                try:
                    val = _to_torch(val, data.device)
                    if val.ndim < data.ndim:
                        val = expand_as_right(val, data)
                    data.set(key, val)
                except RuntimeError:
                    raise RuntimeError(
                        "Failed to set the metadata (e.g., indices or weights) in the sampled tensordict within TensorDictReplayBuffer.sample. "
                        "This is probably caused by a shape mismatch (one of the transforms has probably modified "
                        "the shape of the output tensordict). "
                        "You can always recover these items from the `sample` method from a regular ReplayBuffer "
                        "instance with the 'return_info' flag set to True."
                    )
            if is_locked:
                data.lock_()
        elif not is_tc and include_info in (True, None):
            raise RuntimeError("Cannot include info in non-tensordict data")
        if return_info:
            return data, info
        return data

    @pin_memory_output
    def _sample(self, batch_size: int) -> tuple[Any, dict]:
        is_comp = is_compiling()
        nc = contextlib.nullcontext()
        with self._replay_lock if not is_comp else nc, self._write_lock if not is_comp else nc:
            index, info = self._sampler.sample(self._storage, batch_size)
            info["index"] = index
            data = self._storage.get(index)
        if not isinstance(index, INT_CLASSES):
            data = self._collate_fn(data)
        if self._transform is not None and len(self._transform):
            with data.unlock_(), _set_dispatch_td_nn_modules(True):
                data = self._transform(data)
        return data, info


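# --- Illustrative sketch (editor's note, not part of the torchrl source). ---
# `TensorDictReplayBuffer` reads priorities from `priority_key` (default "td_error")
# when combined with a `PrioritizedSampler`; `extend` then updates the priorities
# automatically and `sample` attaches the "index" key. The import path for
# `PrioritizedSampler` and the helper name `_sketch_tensordict_priorities` are assumptions.
def _sketch_tensordict_priorities():
    import torch
    from tensordict import TensorDict
    from torchrl.data import LazyTensorStorage, PrioritizedSampler, TensorDictReplayBuffer

    rb = TensorDictReplayBuffer(
        storage=LazyTensorStorage(100),
        sampler=PrioritizedSampler(100, 0.7, 0.9),  # max_size, alpha, beta
        batch_size=8,
        priority_key="td_error",
    )
    data = TensorDict({"obs": torch.randn(20, 4), "td_error": torch.rand(20)}, [20])
    rb.extend(data)  # priorities are taken from data["td_error"]
    sample = rb.sample()
    return sample["index"]  # storage indices of the sampled transitions

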
class TensorDictPrioritizedReplayBuffer(TensorDictReplayBuffer):
    """TensorDict-specific wrapper around the :class:`~torchrl.data.PrioritizedReplayBuffer` class.

    This class returns tensordicts with a new key ``"index"`` that represents
    the index of each element in the replay buffer. It also provides the
    :meth:`~.update_tensordict_priority` method that only requires for the
    tensordict to be passed to it with its new priority value.

    Keyword Args:
        alpha (:obj:`float`): exponent α determines how much prioritization is used,
            with α = 0 corresponding to the uniform case.
        beta (:obj:`float`): importance sampling negative exponent.
        eps (:obj:`float`): delta added to the priorities to ensure that the buffer
            does not contain null priorities.
        storage (Storage, Callable[[], Storage], optional): the storage to be used.
            If a callable is passed, it is used as constructor for the storage.
            If none is provided a default :class:`~torchrl.data.replay_buffers.ListStorage` with
            ``max_size`` of ``1_000`` will be created.
        collate_fn (callable, optional): merges a list of samples to form a
            mini-batch of Tensor(s)/outputs. Used when using batched
            loading from a map-style dataset. The default value will be decided
            based on the storage type.
        pin_memory (bool): whether pin_memory() should be called on the rb
            samples.
        prefetch (int, optional): number of next batches to be prefetched
            using multithreading. Defaults to None (no prefetching).
        transform (Transform or Callable[[Any], Any], optional): Transform to be executed when
            :meth:`sample` is called.
            To chain transforms use the :class:`~torchrl.envs.Compose` class.
            Transforms should be used with :class:`tensordict.TensorDict`
            content. A generic callable can also be passed if the replay buffer
            is used with PyTree structures (see example below).
            Unlike storages, writers and samplers, transform constructors must
            be passed as separate keyword argument :attr:`transform_factory`,
            as it is impossible to distinguish a constructor from a transform.
        transform_factory (Callable[[], Callable], optional): a factory for the
            transform. Exclusive with :attr:`transform`.
        batch_size (int, optional): the batch size to be used when sample() is
            called.

            .. note::
                The batch-size can be specified at construction time via the
                ``batch_size`` argument, or at sampling time. The former should
                be preferred whenever the batch-size is consistent across the
                experiment. If the batch-size is likely to change, it can be
                passed to the :meth:`~.sample` method. This option is
                incompatible with prefetching (since this requires to know the
                batch-size in advance) as well as with samplers that have a
                ``drop_last`` argument.

        priority_key (str, optional): the key at which priority is assumed to
            be stored within TensorDicts added to this ReplayBuffer.
            This is to be used when the sampler is of type
            :class:`~torchrl.data.PrioritizedSampler`.
            Defaults to ``"td_error"``.
        reduction (str, optional): the reduction method for multidimensional
            tensordicts (ie stored trajectories). Can be one of "max", "min",
            "median" or "mean".
        dim_extend (int, optional): indicates the dim to consider for
            extension when calling :meth:`~.extend`. Defaults to ``storage.ndim-1``.
            When using ``dim_extend > 0``, we recommend using the ``ndim``
            argument in the storage instantiation if that argument is
            available, to let storages know that the data is
            multi-dimensional and keep consistent notions of storage-capacity
            and batch-size during sampling.

            .. note:: This argument has no effect on :meth:`~.add` and
                therefore should be used with caution when both :meth:`~.add`
                and :meth:`~.extend` are used in a codebase. For example:

                >>> data = torch.zeros(3, 4)
                >>> rb = ReplayBuffer(
                ...     storage=LazyTensorStorage(10, ndim=2),
                ...     dim_extend=1)
                >>> # these two approaches are equivalent:
                >>> for d in data.unbind(1):
                ...     rb.add(d)
                >>> rb.extend(data)

        generator (torch.Generator, optional): a generator to use for sampling.
            Using a dedicated generator for the replay buffer can allow a fine-grained control
            over seeding, for instance keeping the global seed different but the RB seed identical
            for distributed jobs.
            Defaults to ``None`` (global default generator).

            .. warning:: As of now, the generator has no effect on the transforms.
        shared (bool, optional): whether the buffer will be shared using multiprocessing or not.
            Defaults to ``False``.
        compilable (bool, optional): whether the writer is compilable.
            If ``True``, the writer cannot be shared between multiple processes.
            Defaults to ``False``.
        delayed_init (bool, optional): whether to initialize storage, writer, sampler and transform
            the first time the buffer is used rather than during construction.
            This is useful when the replay buffer needs to be pickled and sent to remote workers,
            particularly when using transforms with modules that require gradients.
            If not specified, defaults to ``True`` when ``transform_factory`` is provided,
            and ``False`` otherwise.

    Examples:
        >>> import torch
        >>>
        >>> from torchrl.data import LazyTensorStorage, TensorDictPrioritizedReplayBuffer
        >>> from tensordict import TensorDict
        >>>
        >>> torch.manual_seed(0)
        >>>
        >>> rb = TensorDictPrioritizedReplayBuffer(alpha=0.7, beta=1.1, storage=LazyTensorStorage(10), batch_size=5)
        >>> data = TensorDict({"a": torch.ones(10, 3), ("b", "c"): torch.zeros(10, 3, 1)}, [10])
        >>> rb.extend(data)
        >>> print("len of rb", len(rb))
        len of rb 10
        >>> sample = rb.sample(5)
        >>> print(sample)
        TensorDict(
            fields={
                priority_weight: Tensor(shape=torch.Size([5]), device=cpu, dtype=torch.float32, is_shared=False),
                a: Tensor(shape=torch.Size([5, 3]), device=cpu, dtype=torch.float32, is_shared=False),
                b: TensorDict(
                    fields={
                        c: Tensor(shape=torch.Size([5, 3, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
                    batch_size=torch.Size([5]),
                    device=cpu,
                    is_shared=False),
                index: Tensor(shape=torch.Size([5]), device=cpu, dtype=torch.int64, is_shared=False)},
            batch_size=torch.Size([5]),
            device=cpu,
            is_shared=False)
        >>> print("index", sample["index"])
        index tensor([9, 5, 2, 2, 7])
        >>> # give a high priority to these samples...
        >>> sample.set("td_error", 100*torch.ones(sample.shape))
        >>> # and update priority
        >>> rb.update_tensordict_priority(sample)
        >>> # the new sample should have a high overlap with the previous one
        >>> sample = rb.sample(5)
        >>> print(sample)
        TensorDict(
            fields={
                priority_weight: Tensor(shape=torch.Size([5]), device=cpu, dtype=torch.float32, is_shared=False),
                a: Tensor(shape=torch.Size([5, 3]), device=cpu, dtype=torch.float32, is_shared=False),
                b: TensorDict(
                    fields={
                        c: Tensor(shape=torch.Size([5, 3, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
                    batch_size=torch.Size([5]),
                    device=cpu,
                    is_shared=False),
                index: Tensor(shape=torch.Size([5]), device=cpu, dtype=torch.int64, is_shared=False)},
            batch_size=torch.Size([5]),
            device=cpu,
            is_shared=False)
        >>> print("index", sample["index"])
        index tensor([2, 5, 5, 9, 7])

    """

    def __init__(
        self,
        *,
        alpha: float,
        beta: float,
        priority_key: str = "td_error",
        eps: float = 1e-8,
        storage: Storage | None = None,
        collate_fn: Callable | None = None,
        pin_memory: bool = False,
        prefetch: int | None = None,
        transform: Transform | None = None,  # noqa-F821
        reduction: str = "max",
        batch_size: int | None = None,
        dim_extend: int | None = None,
        generator: torch.Generator | None = None,
        shared: bool = False,
        compilable: bool = False,
    ) -> None:
        storage = self._maybe_make_storage(storage, compilable=compilable)
        sampler = PrioritizedSampler(
            storage.max_size, alpha, beta, eps, reduction=reduction
        )
        super().__init__(
            priority_key=priority_key,
            storage=storage,
            sampler=sampler,
            collate_fn=collate_fn,
            pin_memory=pin_memory,
            prefetch=prefetch,
            transform=transform,
            batch_size=batch_size,
            dim_extend=dim_extend,
            generator=generator,
            shared=shared,
            compilable=compilable,
        )


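# --- Illustrative sketch (editor's note, not part of the torchrl source). ---
# The usual TensorDictPrioritizedReplayBuffer round-trip from the docstring above:
# sample, write the new "td_error" into the sampled tensordict, and hand it back to
# `update_tensordict_priority` (the "index" key carries the storage locations).
# The helper name `_sketch_td_prioritized_roundtrip` is hypothetical.
def _sketch_td_prioritized_roundtrip():
    import torch
    from tensordict import TensorDict
    from torchrl.data import LazyTensorStorage, TensorDictPrioritizedReplayBuffer

    rb = TensorDictPrioritizedReplayBuffer(
        alpha=0.7, beta=1.1, storage=LazyTensorStorage(10), batch_size=5
    )
    rb.extend(TensorDict({"a": torch.ones(10, 3)}, [10]))
    sample = rb.sample()
    sample.set("td_error", 100 * torch.ones(sample.shape))  # pretend these were surprising
    rb.update_tensordict_priority(sample)                   # boosts their sampling probability
    return rb

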
@accept_remote_rref_udf_invocation
class RemoteTensorDictReplayBuffer(TensorDictReplayBuffer):
    """A remote invocation friendly ReplayBuffer class. Public methods can be invoked by remote agents using `torch.rpc` or called locally as normal."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def sample(
        self,
        batch_size: int | None = None,
        include_info: bool | None = None,
        return_info: bool = False,
    ) -> TensorDictBase:
        return super().sample(
            batch_size=batch_size, include_info=include_info, return_info=return_info
        )

    def add(self, data: TensorDictBase) -> int:
        return super().add(data)

    def extend(
        self, tensordicts: list | TensorDictBase, *, update_priority: bool | None = None
    ) -> torch.Tensor:
        return super().extend(tensordicts, update_priority=update_priority)

    def update_priority(
        self, index: int | torch.Tensor, priority: int | torch.Tensor
    ) -> None:
        return super().update_priority(index, priority)

    def update_tensordict_priority(self, data: TensorDictBase) -> None:
        return super().update_tensordict_priority(data)


class InPlaceSampler:
    """[Deprecated] A sampler to write tensordicts in-place."""

    def __init__(self, device: DEVICE_TYPING | None = None):
        raise RuntimeError(
            "This class has been removed without replacement. In-place sampling should be avoided."
        )


def stack_tensors(list_of_tensor_iterators: list) -> tuple[torch.Tensor]:
    """Zips a list of iterables containing tensor-like objects and stacks the resulting lists of tensors together.

    Args:
        list_of_tensor_iterators (list): Sequence containing similar iterators,
            where each element of the nested iterator is a tensor whose
            shape match the tensor of other iterators that have the same index.

    Returns:
        Tuple of stacked tensors.

    Examples:
        >>> list_of_tensor_iterators = [[torch.ones(3), torch.zeros(1,2)]
        ...     for _ in range(4)]
        >>> stack_tensors(list_of_tensor_iterators)
        (tensor([[1., 1., 1.],
                [1., 1., 1.],
                [1., 1., 1.],
                [1., 1., 1.]]), tensor([[[0., 0.]],
        <BLANKLINE>
                [[0., 0.]],
        <BLANKLINE>
                [[0., 0.]],
        <BLANKLINE>
                [[0., 0.]]]))

    """
    return tuple(torch.stack(tensors, 0) for tensors in zip(*list_of_tensor_iterators))


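# --- Illustrative sketch (editor's note, not part of the torchrl source). ---
# `stack_tensors` zips per-step tuples into per-field stacked tensors, e.g. turning a
# list of (observation, action) pairs into one observation batch and one action batch.
# The helper name `_sketch_stack_tensors` is hypothetical.
def _sketch_stack_tensors():
    import torch

    steps = [(torch.randn(4), torch.randn(2)) for _ in range(8)]
    observations, actions = stack_tensors(steps)
    assert observations.shape == (8, 4) and actions.shape == (8, 2)
    return observations, actions

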
class ReplayBufferEnsemble(ReplayBuffer):
    """An ensemble of replay buffers.

    This class allows reading and sampling from multiple replay buffers at once.
    It automatically composes ensemble of storages (:class:`~torchrl.data.replay_buffers.storages.StorageEnsemble`),
    writers (:class:`~torchrl.data.replay_buffers.writers.WriterEnsemble`) and
    samplers (:class:`~torchrl.data.replay_buffers.samplers.SamplerEnsemble`).

    .. note::
        Writing directly to this class is forbidden, but it can be indexed to retrieve
        the nested buffers and extend them.

    There are two distinct ways of constructing a :class:`~torchrl.data.ReplayBufferEnsemble`:
    one can either pass a list of replay buffers, or directly pass the components
    (storage, writers and samplers) like it is done for other replay buffer subclasses.

    Args:
        rbs (sequence of ReplayBuffer instances, optional): the replay buffers to ensemble.
        storages (StorageEnsemble, optional): the ensemble of storages, if the replay
            buffers are not passed.
        samplers (SamplerEnsemble, optional): the ensemble of samplers, if the replay
            buffers are not passed.
        writers (WriterEnsemble, optional): the ensemble of writers, if the replay
            buffers are not passed.
        transform (Transform, optional): if passed, this will be the transform
            of the ensemble of replay buffers. Individual transforms for each
            replay buffer are retrieved from its parent replay buffer, or directly
            written in the :class:`~torchrl.data.replay_buffers.storages.StorageEnsemble`
            object.
        batch_size (int, optional): the batch-size to use during sampling.
        collate_fn (callable, optional): the function to use to collate the
            data after each individual collate_fn has been called and the data
            is placed in a list (along with the buffer id).
        collate_fns (list of callables, optional): collate_fn of each nested
            replay buffer. Retrieved from the :class:`~ReplayBuffer` instances
            if not provided.
        p (list of float or Tensor, optional): a list of floating numbers
            indicating the relative weight of each replay buffer. Can also
            be passed to torchrl.data.replay_buffers.samplers.SamplerEnsemble`
            if the buffer is built explicitly.
        sample_from_all (bool, optional): if ``True``, each dataset will be sampled
            from. This is not compatible with the ``p`` argument. Defaults to ``False``.
            Can also be passed to torchrl.data.replay_buffers.samplers.SamplerEnsemble`
            if the buffer is built explicitly.
        num_buffer_sampled (int, optional): the number of buffers to sample.
            if ``sample_from_all=True``, this has no effect, as it defaults to the
            number of buffers. If ``sample_from_all=False``, buffers will be
            sampled according to the probabilities ``p``. Can also
            be passed to torchrl.data.replay_buffers.samplers.SamplerEnsemble`
            if the buffer is built explicitly.
        generator (torch.Generator, optional): a generator to use for sampling.
            Using a dedicated generator for the replay buffer can allow a fine-grained control
            over seeding, for instance keeping the global seed different but the RB seed identical
            for distributed jobs.
            Defaults to ``None`` (global default generator).

            .. warning:: As of now, the generator has no effect on the transforms.

        shared (bool, optional): whether the buffer will be shared using multiprocessing or not.
            Defaults to ``False``.
        delayed_init (bool, optional): whether to initialize storage, writer, sampler and transform
            the first time the buffer is used rather than during construction.
            This is useful when the replay buffer needs to be pickled and sent to remote workers,
            particularly when using transforms with modules that require gradients.
            If not specified, defaults to ``True`` when ``transform_factory`` is provided,
            and ``False`` otherwise.

    Examples:
        >>> from torchrl.envs import Compose, ToTensorImage, Resize, RenameTransform
        >>> from torchrl.data import TensorDictReplayBuffer, ReplayBufferEnsemble, LazyMemmapStorage
        >>> from tensordict import TensorDict
        >>> import torch
        >>> rb0 = TensorDictReplayBuffer(
        ...     storage=LazyMemmapStorage(10),
        ...     transform=Compose(
        ...         ToTensorImage(in_keys=["pixels", ("next", "pixels")]),
        ...         Resize(32, in_keys=["pixels", ("next", "pixels")]),
        ...         RenameTransform([("some", "key")], ["renamed"]),
        ...     ),
        ... )
        >>> rb1 = TensorDictReplayBuffer(
        ...     storage=LazyMemmapStorage(10),
        ...     transform=Compose(
        ...         ToTensorImage(in_keys=["pixels", ("next", "pixels")]),
        ...         Resize(32, in_keys=["pixels", ("next", "pixels")]),
        ...         RenameTransform(["another_key"], ["renamed"]),
        ...     ),
        ... )
        >>> rb = ReplayBufferEnsemble(
        ...     rb0,
        ...     rb1,
        ...     p=[0.5, 0.5],
        ...     transform=Resize(33, in_keys=["pixels"], out_keys=["pixels33"]),
        ... )
        >>> print(rb)
        ReplayBufferEnsemble(
            storages=StorageEnsemble(
                storages=(<torchrl.data.replay_buffers.storages.LazyMemmapStorage object at 0x13a2ef430>, <torchrl.data.replay_buffers.storages.LazyMemmapStorage object at 0x13a2f9310>),
|
|
2130
|
+
transforms=[Compose(
|
|
2131
|
+
ToTensorImage(keys=['pixels', ('next', 'pixels')]),
|
|
2132
|
+
Resize(w=32, h=32, interpolation=InterpolationMode.BILINEAR, keys=['pixels', ('next', 'pixels')]),
|
|
2133
|
+
RenameTransform(keys=[('some', 'key')])), Compose(
|
|
2134
|
+
ToTensorImage(keys=['pixels', ('next', 'pixels')]),
|
|
2135
|
+
Resize(w=32, h=32, interpolation=InterpolationMode.BILINEAR, keys=['pixels', ('next', 'pixels')]),
|
|
2136
|
+
RenameTransform(keys=['another_key']))]),
|
|
2137
|
+
samplers=SamplerEnsemble(
|
|
2138
|
+
samplers=(<torchrl.data.replay_buffers.samplers.RandomSampler object at 0x13a2f9220>, <torchrl.data.replay_buffers.samplers.RandomSampler object at 0x13a2f9f70>)),
|
|
2139
|
+
writers=WriterEnsemble(
|
|
2140
|
+
writers=(<torchrl.data.replay_buffers.writers.TensorDictRoundRobinWriter object at 0x13a2d9b50>, <torchrl.data.replay_buffers.writers.TensorDictRoundRobinWriter object at 0x13a2f95b0>)),
|
|
2141
|
+
batch_size=None,
|
|
2142
|
+
transform=Compose(
|
|
2143
|
+
Resize(w=33, h=33, interpolation=InterpolationMode.BILINEAR, keys=['pixels'])),
|
|
2144
|
+
collate_fn=<built-in method stack of type object at 0x128648260>)
|
|
2145
|
+
>>> data0 = TensorDict(
|
|
2146
|
+
... {
|
|
2147
|
+
... "pixels": torch.randint(255, (10, 244, 244, 3)),
|
|
2148
|
+
... ("next", "pixels"): torch.randint(255, (10, 244, 244, 3)),
|
|
2149
|
+
... ("some", "key"): torch.randn(10),
|
|
2150
|
+
... },
|
|
2151
|
+
... batch_size=[10],
|
|
2152
|
+
... )
|
|
2153
|
+
>>> data1 = TensorDict(
|
|
2154
|
+
... {
|
|
2155
|
+
... "pixels": torch.randint(255, (10, 64, 64, 3)),
|
|
2156
|
+
... ("next", "pixels"): torch.randint(255, (10, 64, 64, 3)),
|
|
2157
|
+
... "another_key": torch.randn(10),
|
|
2158
|
+
... },
|
|
2159
|
+
... batch_size=[10],
|
|
2160
|
+
... )
|
|
2161
|
+
>>> rb[0].extend(data0)
|
|
2162
|
+
>>> rb[1].extend(data1)
|
|
2163
|
+
>>> for _ in range(2):
|
|
2164
|
+
... sample = rb.sample(10)
|
|
2165
|
+
... assert sample["next", "pixels"].shape == torch.Size([2, 5, 3, 32, 32])
|
|
2166
|
+
... assert sample["pixels"].shape == torch.Size([2, 5, 3, 32, 32])
|
|
2167
|
+
... assert sample["pixels33"].shape == torch.Size([2, 5, 3, 33, 33])
|
|
2168
|
+
... assert sample["renamed"].shape == torch.Size([2, 5])
|
|
2169
|
+
|
|
2170
|
+
"""
|
|
+
+    _collate_fn_val = None
+
+    def __init__(
+        self,
+        *rbs,
+        storages: StorageEnsemble | None = None,
+        samplers: SamplerEnsemble | None = None,
+        writers: WriterEnsemble | None = None,
+        transform: Transform | None = None,  # noqa: F821
+        batch_size: int | None = None,
+        collate_fn: Callable | None = None,
+        collate_fns: list[Callable] | None = None,
+        p: Tensor = None,
+        sample_from_all: bool = False,
+        num_buffer_sampled: int | None = None,
+        generator: torch.Generator | None = None,
+        shared: bool = False,
+        **kwargs,
+    ):
+
+        if collate_fn is None:
+            collate_fn = _stack_anything
+
+        if rbs:
+            if storages is not None or samplers is not None or writers is not None:
+                raise RuntimeError
+            # Ensure all replay buffers are initialized before creating ensemble
+            for rb in rbs:
+                if (
+                    hasattr(rb, "_delayed_init")
+                    and rb._delayed_init
+                    and not rb.initialized
+                ):
+                    rb._init()
+            storages = StorageEnsemble(
+                *[rb._storage for rb in rbs], transforms=[rb._transform for rb in rbs]
+            )
+            samplers = SamplerEnsemble(
+                *[rb._sampler for rb in rbs],
+                p=p,
+                sample_from_all=sample_from_all,
+                num_buffer_sampled=num_buffer_sampled,
+            )
+            writers = WriterEnsemble(*[rb._writer for rb in rbs])
+            if collate_fns is None:
+                collate_fns = [rb._collate_fn for rb in rbs]
+        else:
+            rbs = None
+            if collate_fns is None:
+                collate_fns = [
+                    _get_default_collate(storage) for storage in storages._storages
+                ]
+        self._rbs = rbs
+        self._collate_fns = collate_fns
+        super().__init__(
+            storage=storages,
+            sampler=samplers,
+            writer=writers,
+            transform=transform,
+            batch_size=batch_size,
+            collate_fn=collate_fn,
+            generator=generator,
+            shared=shared,
+            **kwargs,
+        )
+
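
The class docstring above mentions a second construction route that the doctest does not exercise: passing the ensembled components directly, which `__init__` accepts through `storages`, `samplers` and `writers`. The snippet below is a minimal sketch of that route; the storage sizes, the `RandomSampler`/`RoundRobinWriter` choices and the sampling weights are illustrative assumptions, not values taken from this diff.

```python
# Sketch only: component-based construction of a ReplayBufferEnsemble.
# The concrete storages, samplers, writers and sizes below are assumptions.
from torchrl.data import LazyTensorStorage, ReplayBufferEnsemble
from torchrl.data.replay_buffers.samplers import RandomSampler, SamplerEnsemble
from torchrl.data.replay_buffers.storages import StorageEnsemble
from torchrl.data.replay_buffers.writers import RoundRobinWriter, WriterEnsemble

# One entry per nested buffer; the ensembles are built from positional components.
storages = StorageEnsemble(LazyTensorStorage(100), LazyTensorStorage(100))
samplers = SamplerEnsemble(RandomSampler(), RandomSampler(), p=[0.3, 0.7])
writers = WriterEnsemble(RoundRobinWriter(), RoundRobinWriter())

rb = ReplayBufferEnsemble(
    storages=storages,
    samplers=samplers,
    writers=writers,
    batch_size=8,
)
```

When built this way, `collate_fns` falls back to `_get_default_collate` per storage, as the `else` branch of `__init__` above shows.
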
+    def _sample(self, *args, **kwargs):
+        sample, info = super()._sample(*args, **kwargs)
+        if isinstance(sample, TensorDictBase):
+            buffer_ids = info.get(("index", "buffer_ids"))
+            info.set(
+                ("index", "buffer_ids"), expand_right(buffer_ids, sample.batch_size)
+            )
+            if isinstance(info, LazyStackedTensorDict):
+                for _info, _sample in zip(
+                    info.unbind(info.stack_dim), sample.unbind(info.stack_dim)
+                ):
+                    _info.batch_size = _sample.batch_size
+                info = torch.stack(info.tensordicts, info.stack_dim)
+            else:
+                info.batch_size = sample.batch_size
+            sample.update(info)
+
+        return sample, info
+
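
`_sample` above broadcasts the `("index", "buffer_ids")` entry of the sampler info to the sample's batch size and merges the info into the sample, so every sampled element carries the id of the buffer it was drawn from. Below is a minimal sketch of reading that annotation back, assuming the two-buffer `rb` ensemble built in the docstring example; the shapes in the comments follow that example.

```python
# Sketch only: inspecting which nested buffer each sampled element came from.
# `rb` is assumed to be the two-buffer ensemble from the docstring example above.
sample = rb.sample(10)

# Written by `_sample` via `expand_right(buffer_ids, sample.batch_size)`
# and merged into the sample with `sample.update(info)`.
buffer_ids = sample["index", "buffer_ids"]
print(buffer_ids.shape)  # expected to match sample.batch_size, e.g. torch.Size([2, 5])

# Group transitions per source buffer, e.g. to compute per-buffer statistics.
for buffer_id in buffer_ids.unique():
    mask = buffer_ids == buffer_id
    print(buffer_id.item(), mask.sum().item())
```
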
+    @property
+    def _collate_fn(self):
+        def new_collate(samples):
+            samples = [self._collate_fns[i](sample) for (i, sample) in samples]
+            return self._collate_fn_val(samples)
+
+        return new_collate
+
+    @_collate_fn.setter
+    def _collate_fn(self, value):
+        self._collate_fn_val = value
+
+    _INDEX_ERROR = "Expected an index of type torch.Tensor, range, np.ndarray, int, slice or ellipsis, got {} instead."
+
+    def __getitem__(
+        self, index: int | torch.Tensor | tuple | np.ndarray | list | slice | Ellipsis
+    ) -> Any:
+        # accepts inputs:
+        # (int | 1d tensor | 1d list | 1d array | slice | ellipsis | range, int | tensor | list | array | slice | ellipsis | range)
+        # tensor
+        if isinstance(index, tuple):
+            if index[0] is Ellipsis:
+                index = (slice(None), index[1:])
+            rb = self[index[0]]
+            if len(index) > 1:
+                if rb is self:
+                    # then index[0] is an ellipsis/slice(None)
+                    sample = [
+                        (i, storage[index[1:]])
+                        for i, storage in enumerate(self._storage._storages)
+                    ]
+                    return self._collate_fn(sample)
+                if isinstance(rb, ReplayBufferEnsemble):
+                    new_index = (slice(None), *index[1:])
+                    return rb[new_index]
+                return rb[index[1:]]
+            return rb
+        if isinstance(index, slice) and index == slice(None):
+            return self
+        if isinstance(index, (list, range, np.ndarray)):
+            index = torch.as_tensor(index)
+        if isinstance(index, torch.Tensor):
+            if index.ndim > 1:
+                raise RuntimeError(
+                    f"Cannot index a {type(self)} with tensor indices that have more than one dimension."
+                )
+            if index.is_floating_point():
+                raise TypeError(
+                    "A floating point index was received when an integer dtype was expected."
+                )
+        if self._rbs is not None and (
+            isinstance(index, int) or (not isinstance(index, slice) and len(index) == 0)
+        ):
+            try:
+                index = int(index)
+            except Exception:
+                raise IndexError(self._INDEX_ERROR.format(type(index)))
+            try:
+                return self._rbs[index]
+            except IndexError:
+                raise IndexError(self._INDEX_ERROR.format(type(index)))
+
+        if self._rbs is not None:
+            if isinstance(index, torch.Tensor):
+                index = index.tolist()
+                rbs = [self._rbs[i] for i in index]
+                _collate_fns = [self._collate_fns[i] for i in index]
+            else:
+                try:
+                    # slice
+                    rbs = self._rbs[index]
+                    _collate_fns = self._collate_fns[index]
+                except IndexError:
+                    raise IndexError(self._INDEX_ERROR.format(type(index)))
+            p = self._sampler._p[index] if self._sampler._p is not None else None
+            return ReplayBufferEnsemble(
+                *rbs,
+                transform=self._transform,
+                batch_size=self._batch_size,
+                collate_fn=self._collate_fn_val,
+                collate_fns=_collate_fns,
+                sample_from_all=self._sampler.sample_from_all,
+                num_buffer_sampled=self._sampler.num_buffer_sampled,
+                p=p,
+            )
+
+        try:
+            samplers = self._sampler[index]
+            writers = self._writer[index]
+            storages = self._storage[index]
+            if isinstance(index, torch.Tensor):
+                _collate_fns = [self._collate_fns[i] for i in index.tolist()]
+            else:
+                _collate_fns = self._collate_fns[index]
+            p = self._sampler._p[index] if self._sampler._p is not None else None
+
+        except IndexError:
+            raise IndexError(self._INDEX_ERROR.format(type(index)))
+
+        return ReplayBufferEnsemble(
+            samplers=samplers,
+            writers=writers,
+            storages=storages,
+            transform=self._transform,
+            batch_size=self._batch_size,
+            collate_fn=self._collate_fn_val,
+            collate_fns=_collate_fns,
+            sample_from_all=self._sampler.sample_from_all,
+            num_buffer_sampled=self._sampler.num_buffer_sampled,
+            p=p,
+        )
+
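
`__getitem__` above supports two families of indices when the ensemble was built from replay buffers: an integer returns the nested buffer itself (as used by `rb[0].extend(...)` in the docstring example), while a slice, list or 1d tensor of buffer ids builds a new `ReplayBufferEnsemble` restricted to those buffers, carrying over the transform, collate functions and sampling options. A minimal sketch, assuming the `rb` ensemble from the docstring example:

```python
# Sketch only: the indexing paths implemented by __getitem__ above.
# `rb` is assumed to be the two-buffer ensemble from the docstring example.
nested = rb[0]          # the nested TensorDictReplayBuffer; can be extended directly
sub_ensemble = rb[:1]   # a new ReplayBufferEnsemble over the first buffer only
also_sub = rb[[0, 1]]   # list/tensor indices select a subset of buffers as well

assert isinstance(sub_ensemble, type(rb))
```

Tuple indices such as `rb[:, :3]` take the first branch of the method: they read directly from the underlying storages and collate the results, which requires the stored entries to be stackable across buffers, so they are omitted from the sketch.
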
+    def __len__(self):
+        return len(self._storage)
+
+    def __repr__(self):
+        storages = textwrap.indent(f"storages={self._storage}", " " * 4)
+        writers = textwrap.indent(f"writers={self._writer}", " " * 4)
+        samplers = textwrap.indent(f"samplers={self._sampler}", " " * 4)
+        return f"ReplayBufferEnsemble(\n{storages}, \n{samplers}, \n{writers}, \nbatch_size={self._batch_size}, \ntransform={self._transform}, \ncollate_fn={self._collate_fn_val})"