torchrl-0.11.0-cp314-cp314-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/benchmark_batched_envs.py +104 -0
- benchmarks/conftest.py +91 -0
- benchmarks/ecosystem/gym_env_throughput.py +321 -0
- benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
- benchmarks/requirements.txt +7 -0
- benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
- benchmarks/test_collectors_benchmark.py +240 -0
- benchmarks/test_compressed_storage_benchmark.py +145 -0
- benchmarks/test_envs_benchmark.py +133 -0
- benchmarks/test_llm.py +101 -0
- benchmarks/test_non_tensor_env_benchmark.py +70 -0
- benchmarks/test_objectives_benchmarks.py +1199 -0
- benchmarks/test_replaybuffer_benchmark.py +254 -0
- sota-check/README.md +35 -0
- sota-implementations/README.md +142 -0
- sota-implementations/a2c/README.md +39 -0
- sota-implementations/a2c/a2c_atari.py +291 -0
- sota-implementations/a2c/a2c_mujoco.py +273 -0
- sota-implementations/a2c/utils_atari.py +240 -0
- sota-implementations/a2c/utils_mujoco.py +160 -0
- sota-implementations/bandits/README.md +7 -0
- sota-implementations/bandits/dqn.py +126 -0
- sota-implementations/cql/cql_offline.py +198 -0
- sota-implementations/cql/cql_online.py +249 -0
- sota-implementations/cql/discrete_cql_offline.py +180 -0
- sota-implementations/cql/discrete_cql_online.py +227 -0
- sota-implementations/cql/utils.py +471 -0
- sota-implementations/crossq/crossq.py +271 -0
- sota-implementations/crossq/utils.py +320 -0
- sota-implementations/ddpg/ddpg.py +231 -0
- sota-implementations/ddpg/utils.py +325 -0
- sota-implementations/decision_transformer/dt.py +163 -0
- sota-implementations/decision_transformer/lamb.py +167 -0
- sota-implementations/decision_transformer/online_dt.py +178 -0
- sota-implementations/decision_transformer/utils.py +562 -0
- sota-implementations/discrete_sac/discrete_sac.py +243 -0
- sota-implementations/discrete_sac/utils.py +324 -0
- sota-implementations/dqn/README.md +30 -0
- sota-implementations/dqn/dqn_atari.py +272 -0
- sota-implementations/dqn/dqn_cartpole.py +236 -0
- sota-implementations/dqn/utils_atari.py +132 -0
- sota-implementations/dqn/utils_cartpole.py +90 -0
- sota-implementations/dreamer/README.md +129 -0
- sota-implementations/dreamer/dreamer.py +586 -0
- sota-implementations/dreamer/dreamer_utils.py +1107 -0
- sota-implementations/expert-iteration/README.md +352 -0
- sota-implementations/expert-iteration/ei_utils.py +770 -0
- sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
- sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
- sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
- sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
- sota-implementations/gail/gail.py +327 -0
- sota-implementations/gail/gail_utils.py +68 -0
- sota-implementations/gail/ppo_utils.py +157 -0
- sota-implementations/grpo/README.md +273 -0
- sota-implementations/grpo/grpo-async.py +437 -0
- sota-implementations/grpo/grpo-sync.py +435 -0
- sota-implementations/grpo/grpo_utils.py +843 -0
- sota-implementations/grpo/requirements_gsm8k.txt +11 -0
- sota-implementations/grpo/requirements_ifeval.txt +16 -0
- sota-implementations/impala/README.md +33 -0
- sota-implementations/impala/impala_multi_node_ray.py +292 -0
- sota-implementations/impala/impala_multi_node_submitit.py +284 -0
- sota-implementations/impala/impala_single_node.py +261 -0
- sota-implementations/impala/utils.py +184 -0
- sota-implementations/iql/discrete_iql.py +230 -0
- sota-implementations/iql/iql_offline.py +164 -0
- sota-implementations/iql/iql_online.py +225 -0
- sota-implementations/iql/utils.py +437 -0
- sota-implementations/multiagent/README.md +74 -0
- sota-implementations/multiagent/iql.py +237 -0
- sota-implementations/multiagent/maddpg_iddpg.py +266 -0
- sota-implementations/multiagent/mappo_ippo.py +267 -0
- sota-implementations/multiagent/qmix_vdn.py +271 -0
- sota-implementations/multiagent/sac.py +337 -0
- sota-implementations/multiagent/utils/__init__.py +4 -0
- sota-implementations/multiagent/utils/logging.py +151 -0
- sota-implementations/multiagent/utils/utils.py +43 -0
- sota-implementations/ppo/README.md +29 -0
- sota-implementations/ppo/ppo_atari.py +305 -0
- sota-implementations/ppo/ppo_mujoco.py +293 -0
- sota-implementations/ppo/utils_atari.py +238 -0
- sota-implementations/ppo/utils_mujoco.py +152 -0
- sota-implementations/ppo_trainer/train.py +21 -0
- sota-implementations/redq/README.md +7 -0
- sota-implementations/redq/redq.py +199 -0
- sota-implementations/redq/utils.py +1060 -0
- sota-implementations/sac/sac-async.py +266 -0
- sota-implementations/sac/sac.py +239 -0
- sota-implementations/sac/utils.py +381 -0
- sota-implementations/sac_trainer/train.py +16 -0
- sota-implementations/td3/td3.py +254 -0
- sota-implementations/td3/utils.py +319 -0
- sota-implementations/td3_bc/td3_bc.py +177 -0
- sota-implementations/td3_bc/utils.py +251 -0
- torchrl/__init__.py +144 -0
- torchrl/_extension.py +74 -0
- torchrl/_torchrl.cp314-win_amd64.pyd +0 -0
- torchrl/_utils.py +1431 -0
- torchrl/collectors/__init__.py +48 -0
- torchrl/collectors/_base.py +1058 -0
- torchrl/collectors/_constants.py +88 -0
- torchrl/collectors/_multi_async.py +324 -0
- torchrl/collectors/_multi_base.py +1805 -0
- torchrl/collectors/_multi_sync.py +464 -0
- torchrl/collectors/_runner.py +581 -0
- torchrl/collectors/_single.py +2009 -0
- torchrl/collectors/_single_async.py +259 -0
- torchrl/collectors/collectors.py +62 -0
- torchrl/collectors/distributed/__init__.py +32 -0
- torchrl/collectors/distributed/default_configs.py +133 -0
- torchrl/collectors/distributed/generic.py +1306 -0
- torchrl/collectors/distributed/ray.py +1092 -0
- torchrl/collectors/distributed/rpc.py +1006 -0
- torchrl/collectors/distributed/sync.py +731 -0
- torchrl/collectors/distributed/utils.py +160 -0
- torchrl/collectors/llm/__init__.py +10 -0
- torchrl/collectors/llm/base.py +494 -0
- torchrl/collectors/llm/ray_collector.py +275 -0
- torchrl/collectors/llm/utils.py +36 -0
- torchrl/collectors/llm/weight_update/__init__.py +10 -0
- torchrl/collectors/llm/weight_update/vllm.py +348 -0
- torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
- torchrl/collectors/utils.py +433 -0
- torchrl/collectors/weight_update.py +591 -0
- torchrl/csrc/numpy_utils.h +38 -0
- torchrl/csrc/pybind.cpp +27 -0
- torchrl/csrc/segment_tree.h +458 -0
- torchrl/csrc/torch_utils.h +34 -0
- torchrl/csrc/utils.cpp +48 -0
- torchrl/csrc/utils.h +31 -0
- torchrl/data/__init__.py +187 -0
- torchrl/data/datasets/__init__.py +58 -0
- torchrl/data/datasets/atari_dqn.py +878 -0
- torchrl/data/datasets/common.py +281 -0
- torchrl/data/datasets/d4rl.py +489 -0
- torchrl/data/datasets/d4rl_infos.py +187 -0
- torchrl/data/datasets/gen_dgrl.py +375 -0
- torchrl/data/datasets/minari_data.py +643 -0
- torchrl/data/datasets/openml.py +177 -0
- torchrl/data/datasets/openx.py +798 -0
- torchrl/data/datasets/roboset.py +363 -0
- torchrl/data/datasets/utils.py +11 -0
- torchrl/data/datasets/vd4rl.py +432 -0
- torchrl/data/llm/__init__.py +34 -0
- torchrl/data/llm/dataset.py +491 -0
- torchrl/data/llm/history.py +1378 -0
- torchrl/data/llm/prompt.py +198 -0
- torchrl/data/llm/reward.py +225 -0
- torchrl/data/llm/topk.py +186 -0
- torchrl/data/llm/utils.py +543 -0
- torchrl/data/map/__init__.py +21 -0
- torchrl/data/map/hash.py +185 -0
- torchrl/data/map/query.py +204 -0
- torchrl/data/map/tdstorage.py +363 -0
- torchrl/data/map/tree.py +1434 -0
- torchrl/data/map/utils.py +103 -0
- torchrl/data/postprocs/__init__.py +8 -0
- torchrl/data/postprocs/postprocs.py +391 -0
- torchrl/data/replay_buffers/__init__.py +99 -0
- torchrl/data/replay_buffers/checkpointers.py +622 -0
- torchrl/data/replay_buffers/ray_buffer.py +292 -0
- torchrl/data/replay_buffers/replay_buffers.py +2376 -0
- torchrl/data/replay_buffers/samplers.py +2578 -0
- torchrl/data/replay_buffers/scheduler.py +265 -0
- torchrl/data/replay_buffers/storages.py +2412 -0
- torchrl/data/replay_buffers/utils.py +1042 -0
- torchrl/data/replay_buffers/writers.py +781 -0
- torchrl/data/tensor_specs.py +7101 -0
- torchrl/data/utils.py +334 -0
- torchrl/envs/__init__.py +265 -0
- torchrl/envs/async_envs.py +1105 -0
- torchrl/envs/batched_envs.py +3093 -0
- torchrl/envs/common.py +4241 -0
- torchrl/envs/custom/__init__.py +11 -0
- torchrl/envs/custom/chess.py +617 -0
- torchrl/envs/custom/llm.py +214 -0
- torchrl/envs/custom/pendulum.py +401 -0
- torchrl/envs/custom/san_moves.txt +29274 -0
- torchrl/envs/custom/tictactoeenv.py +288 -0
- torchrl/envs/env_creator.py +263 -0
- torchrl/envs/gym_like.py +752 -0
- torchrl/envs/libs/__init__.py +68 -0
- torchrl/envs/libs/_gym_utils.py +326 -0
- torchrl/envs/libs/brax.py +846 -0
- torchrl/envs/libs/dm_control.py +544 -0
- torchrl/envs/libs/envpool.py +447 -0
- torchrl/envs/libs/gym.py +2239 -0
- torchrl/envs/libs/habitat.py +138 -0
- torchrl/envs/libs/isaac_lab.py +87 -0
- torchrl/envs/libs/isaacgym.py +203 -0
- torchrl/envs/libs/jax_utils.py +166 -0
- torchrl/envs/libs/jumanji.py +963 -0
- torchrl/envs/libs/meltingpot.py +599 -0
- torchrl/envs/libs/openml.py +153 -0
- torchrl/envs/libs/openspiel.py +652 -0
- torchrl/envs/libs/pettingzoo.py +1042 -0
- torchrl/envs/libs/procgen.py +351 -0
- torchrl/envs/libs/robohive.py +429 -0
- torchrl/envs/libs/smacv2.py +645 -0
- torchrl/envs/libs/unity_mlagents.py +891 -0
- torchrl/envs/libs/utils.py +147 -0
- torchrl/envs/libs/vmas.py +813 -0
- torchrl/envs/llm/__init__.py +63 -0
- torchrl/envs/llm/chat.py +730 -0
- torchrl/envs/llm/datasets/README.md +4 -0
- torchrl/envs/llm/datasets/__init__.py +17 -0
- torchrl/envs/llm/datasets/gsm8k.py +353 -0
- torchrl/envs/llm/datasets/ifeval.py +274 -0
- torchrl/envs/llm/envs.py +789 -0
- torchrl/envs/llm/libs/README.md +3 -0
- torchrl/envs/llm/libs/__init__.py +8 -0
- torchrl/envs/llm/libs/mlgym.py +869 -0
- torchrl/envs/llm/reward/__init__.py +10 -0
- torchrl/envs/llm/reward/gsm8k.py +324 -0
- torchrl/envs/llm/reward/ifeval/README.md +13 -0
- torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
- torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
- torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
- torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
- torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
- torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
- torchrl/envs/llm/transforms/__init__.py +55 -0
- torchrl/envs/llm/transforms/browser.py +292 -0
- torchrl/envs/llm/transforms/dataloading.py +859 -0
- torchrl/envs/llm/transforms/format.py +73 -0
- torchrl/envs/llm/transforms/kl.py +1544 -0
- torchrl/envs/llm/transforms/policy_version.py +189 -0
- torchrl/envs/llm/transforms/reason.py +323 -0
- torchrl/envs/llm/transforms/tokenizer.py +321 -0
- torchrl/envs/llm/transforms/tools.py +1955 -0
- torchrl/envs/model_based/__init__.py +9 -0
- torchrl/envs/model_based/common.py +180 -0
- torchrl/envs/model_based/dreamer.py +112 -0
- torchrl/envs/transforms/__init__.py +147 -0
- torchrl/envs/transforms/functional.py +48 -0
- torchrl/envs/transforms/gym_transforms.py +203 -0
- torchrl/envs/transforms/module.py +341 -0
- torchrl/envs/transforms/r3m.py +372 -0
- torchrl/envs/transforms/ray_service.py +663 -0
- torchrl/envs/transforms/rb_transforms.py +214 -0
- torchrl/envs/transforms/transforms.py +11835 -0
- torchrl/envs/transforms/utils.py +94 -0
- torchrl/envs/transforms/vc1.py +307 -0
- torchrl/envs/transforms/vecnorm.py +845 -0
- torchrl/envs/transforms/vip.py +407 -0
- torchrl/envs/utils.py +1718 -0
- torchrl/envs/vec_envs.py +11 -0
- torchrl/modules/__init__.py +206 -0
- torchrl/modules/distributions/__init__.py +73 -0
- torchrl/modules/distributions/continuous.py +830 -0
- torchrl/modules/distributions/discrete.py +908 -0
- torchrl/modules/distributions/truncated_normal.py +187 -0
- torchrl/modules/distributions/utils.py +233 -0
- torchrl/modules/llm/__init__.py +62 -0
- torchrl/modules/llm/backends/__init__.py +65 -0
- torchrl/modules/llm/backends/vllm/__init__.py +94 -0
- torchrl/modules/llm/backends/vllm/_models.py +46 -0
- torchrl/modules/llm/backends/vllm/base.py +72 -0
- torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
- torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
- torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
- torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
- torchrl/modules/llm/policies/__init__.py +28 -0
- torchrl/modules/llm/policies/common.py +1809 -0
- torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
- torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
- torchrl/modules/llm/utils.py +23 -0
- torchrl/modules/mcts/__init__.py +21 -0
- torchrl/modules/mcts/scores.py +579 -0
- torchrl/modules/models/__init__.py +86 -0
- torchrl/modules/models/batchrenorm.py +119 -0
- torchrl/modules/models/decision_transformer.py +179 -0
- torchrl/modules/models/exploration.py +731 -0
- torchrl/modules/models/llm.py +156 -0
- torchrl/modules/models/model_based.py +596 -0
- torchrl/modules/models/models.py +1712 -0
- torchrl/modules/models/multiagent.py +1067 -0
- torchrl/modules/models/recipes/impala.py +185 -0
- torchrl/modules/models/utils.py +162 -0
- torchrl/modules/planners/__init__.py +10 -0
- torchrl/modules/planners/cem.py +228 -0
- torchrl/modules/planners/common.py +73 -0
- torchrl/modules/planners/mppi.py +265 -0
- torchrl/modules/tensordict_module/__init__.py +89 -0
- torchrl/modules/tensordict_module/actors.py +2457 -0
- torchrl/modules/tensordict_module/common.py +529 -0
- torchrl/modules/tensordict_module/exploration.py +814 -0
- torchrl/modules/tensordict_module/probabilistic.py +321 -0
- torchrl/modules/tensordict_module/rnn.py +1639 -0
- torchrl/modules/tensordict_module/sequence.py +132 -0
- torchrl/modules/tensordict_module/world_models.py +34 -0
- torchrl/modules/utils/__init__.py +38 -0
- torchrl/modules/utils/mappings.py +9 -0
- torchrl/modules/utils/utils.py +89 -0
- torchrl/objectives/__init__.py +78 -0
- torchrl/objectives/a2c.py +659 -0
- torchrl/objectives/common.py +753 -0
- torchrl/objectives/cql.py +1346 -0
- torchrl/objectives/crossq.py +710 -0
- torchrl/objectives/ddpg.py +453 -0
- torchrl/objectives/decision_transformer.py +371 -0
- torchrl/objectives/deprecated.py +516 -0
- torchrl/objectives/dqn.py +683 -0
- torchrl/objectives/dreamer.py +488 -0
- torchrl/objectives/functional.py +48 -0
- torchrl/objectives/gail.py +258 -0
- torchrl/objectives/iql.py +996 -0
- torchrl/objectives/llm/__init__.py +30 -0
- torchrl/objectives/llm/grpo.py +846 -0
- torchrl/objectives/llm/sft.py +482 -0
- torchrl/objectives/multiagent/__init__.py +8 -0
- torchrl/objectives/multiagent/qmixer.py +396 -0
- torchrl/objectives/ppo.py +1669 -0
- torchrl/objectives/redq.py +683 -0
- torchrl/objectives/reinforce.py +530 -0
- torchrl/objectives/sac.py +1580 -0
- torchrl/objectives/td3.py +570 -0
- torchrl/objectives/td3_bc.py +625 -0
- torchrl/objectives/utils.py +782 -0
- torchrl/objectives/value/__init__.py +28 -0
- torchrl/objectives/value/advantages.py +1956 -0
- torchrl/objectives/value/functional.py +1459 -0
- torchrl/objectives/value/utils.py +360 -0
- torchrl/record/__init__.py +17 -0
- torchrl/record/loggers/__init__.py +23 -0
- torchrl/record/loggers/common.py +48 -0
- torchrl/record/loggers/csv.py +226 -0
- torchrl/record/loggers/mlflow.py +142 -0
- torchrl/record/loggers/tensorboard.py +139 -0
- torchrl/record/loggers/trackio.py +163 -0
- torchrl/record/loggers/utils.py +78 -0
- torchrl/record/loggers/wandb.py +214 -0
- torchrl/record/recorder.py +554 -0
- torchrl/services/__init__.py +79 -0
- torchrl/services/base.py +109 -0
- torchrl/services/ray_service.py +453 -0
- torchrl/testing/__init__.py +107 -0
- torchrl/testing/assertions.py +179 -0
- torchrl/testing/dist_utils.py +122 -0
- torchrl/testing/env_creators.py +227 -0
- torchrl/testing/env_helper.py +35 -0
- torchrl/testing/gym_helpers.py +156 -0
- torchrl/testing/llm_mocks.py +119 -0
- torchrl/testing/mocking_classes.py +2720 -0
- torchrl/testing/modules.py +295 -0
- torchrl/testing/mp_helpers.py +15 -0
- torchrl/testing/ray_helpers.py +293 -0
- torchrl/testing/utils.py +190 -0
- torchrl/trainers/__init__.py +42 -0
- torchrl/trainers/algorithms/__init__.py +11 -0
- torchrl/trainers/algorithms/configs/__init__.py +705 -0
- torchrl/trainers/algorithms/configs/collectors.py +216 -0
- torchrl/trainers/algorithms/configs/common.py +41 -0
- torchrl/trainers/algorithms/configs/data.py +308 -0
- torchrl/trainers/algorithms/configs/envs.py +104 -0
- torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
- torchrl/trainers/algorithms/configs/logging.py +80 -0
- torchrl/trainers/algorithms/configs/modules.py +570 -0
- torchrl/trainers/algorithms/configs/objectives.py +177 -0
- torchrl/trainers/algorithms/configs/trainers.py +340 -0
- torchrl/trainers/algorithms/configs/transforms.py +955 -0
- torchrl/trainers/algorithms/configs/utils.py +252 -0
- torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
- torchrl/trainers/algorithms/configs/weight_update.py +159 -0
- torchrl/trainers/algorithms/ppo.py +373 -0
- torchrl/trainers/algorithms/sac.py +308 -0
- torchrl/trainers/helpers/__init__.py +40 -0
- torchrl/trainers/helpers/collectors.py +416 -0
- torchrl/trainers/helpers/envs.py +573 -0
- torchrl/trainers/helpers/logger.py +33 -0
- torchrl/trainers/helpers/losses.py +132 -0
- torchrl/trainers/helpers/models.py +658 -0
- torchrl/trainers/helpers/replay_buffer.py +59 -0
- torchrl/trainers/helpers/trainers.py +301 -0
- torchrl/trainers/trainers.py +2052 -0
- torchrl/weight_update/__init__.py +33 -0
- torchrl/weight_update/_distributed.py +749 -0
- torchrl/weight_update/_mp.py +624 -0
- torchrl/weight_update/_noupdate.py +102 -0
- torchrl/weight_update/_ray.py +1032 -0
- torchrl/weight_update/_rpc.py +284 -0
- torchrl/weight_update/_shared.py +891 -0
- torchrl/weight_update/llm/__init__.py +32 -0
- torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
- torchrl/weight_update/llm/vllm_nccl.py +710 -0
- torchrl/weight_update/utils.py +73 -0
- torchrl/weight_update/weight_sync_schemes.py +1244 -0
- torchrl-0.11.0.dist-info/LICENSE +21 -0
- torchrl-0.11.0.dist-info/METADATA +1307 -0
- torchrl-0.11.0.dist-info/RECORD +394 -0
- torchrl-0.11.0.dist-info/WHEEL +5 -0
- torchrl-0.11.0.dist-info/entry_points.txt +2 -0
- torchrl-0.11.0.dist-info/top_level.txt +7 -0
torchrl/data/replay_buffers/samplers.py
@@ -0,0 +1,2578 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+from __future__ import annotations
+
+import json
+import textwrap
+import warnings
+from abc import ABC, abstractmethod
+from collections import OrderedDict
+from copy import copy, deepcopy
+from multiprocessing.context import get_spawning_popen
+from pathlib import Path
+from typing import Any
+
+import numpy as np
+import torch
+from pyvers import implement_for
+from tensordict import MemoryMappedTensor, TensorDict
+from tensordict.utils import NestedKey
+from torch.utils._pytree import tree_map
+from torchrl._extension import EXTENSION_WARNING
+from torchrl._utils import _replace_last, logger, rl_warnings
+from torchrl.data.replay_buffers.storages import Storage, StorageEnsemble, TensorStorage
+from torchrl.data.replay_buffers.utils import _auto_device, _is_int, unravel_index
+
+try:
+    from torchrl._torchrl import (
+        MinSegmentTreeFp32,
+        MinSegmentTreeFp64,
+        SumSegmentTreeFp32,
+        SumSegmentTreeFp64,
+    )
+except ImportError:
+    # Make default values
+    MinSegmentTreeFp32 = None
+    MinSegmentTreeFp64 = None
+    SumSegmentTreeFp32 = None
+    SumSegmentTreeFp64 = None
+
+_EMPTY_STORAGE_ERROR = "Cannot sample from an empty storage."
+
+
+class Sampler(ABC):
+    """A generic sampler base class for composable Replay Buffers."""
+
+    # Some samplers - mainly those without replacement -
+    # need to keep track of the number of remaining batches
+    _remaining_batches = int(torch.iinfo(torch.int64).max)
+
+    # The RNG is set by the replay buffer
+    _rng: torch.Generator | None = None
+
+    @abstractmethod
+    def sample(self, storage: Storage, batch_size: int) -> tuple[Any, dict]:
+        ...
+
+    def add(self, index: int) -> None:
+        return
+
+    def extend(self, index: torch.Tensor) -> None:
+        return
+
+    def update_priority(
+        self,
+        index: int | torch.Tensor,
+        priority: float | torch.Tensor,
+        *,
+        storage: Storage | None = None,
+    ) -> dict | None:
+        warnings.warn(
+            f"Calling update_priority() on a sampler {type(self).__name__} that is not prioritized. Make sure this is the intended behavior."
+        )
+        return
+
+    def mark_update(
+        self, index: int | torch.Tensor, *, storage: Storage | None = None
+    ) -> None:
+        return
+
+    @property
+    def default_priority(self) -> float:
+        return 1.0
+
+    @abstractmethod
+    def state_dict(self) -> dict[str, Any]:
+        ...
+
+    @abstractmethod
+    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
+        ...
+
+    @property
+    def ran_out(self) -> bool:
+        # by default, samplers never run out
+        return False
+
+    @abstractmethod
+    def _empty(self):
+        ...
+
+    @abstractmethod
+    def dumps(self, path):
+        ...
+
+    @abstractmethod
+    def loads(self, path):
+        ...
+
+    def __repr__(self):
+        return f"{self.__class__.__name__}()"
+
+    def __getstate__(self):
+        state = copy(self.__dict__)
+        state["_rng"] = None
+        return state
+
+
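The `Sampler` ABC above fixes the interface every sampler must implement (`sample`, `_empty`, `dumps`, `loads`, `state_dict`, `load_state_dict`). As a rough illustration of that contract — not part of the package, with `LastNSampler` a made-up name, and assuming a flat 1-D storage — a minimal stateless sampler could look like this:

    import torch
    from torchrl.data.replay_buffers.samplers import Sampler

    class LastNSampler(Sampler):
        """Hypothetical sampler returning the most recently written indices."""

        def sample(self, storage, batch_size):
            if len(storage) == 0:
                raise RuntimeError("Cannot sample from an empty storage.")
            # same (index, info_dict) contract as the samplers below
            index = torch.arange(max(0, len(storage) - batch_size), len(storage))
            return index, {}

        def _empty(self):
            pass  # stateless: nothing to reset

        def dumps(self, path):
            ...  # nothing to persist

        def loads(self, path):
            ...

        def state_dict(self):
            return {}

        def load_state_dict(self, state_dict):
            return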
+class RandomSampler(Sampler):
+    """A uniformly random sampler for composable replay buffers.
+
+    Args:
+        batch_size (int, optional): if provided, the batch size to be used by
+            the replay buffer when calling :meth:`ReplayBuffer.sample`.
+
+    """
+
+    def sample(self, storage: Storage, batch_size: int) -> tuple[torch.Tensor, dict]:
+        if len(storage) == 0:
+            raise RuntimeError(_EMPTY_STORAGE_ERROR)
+        index = storage._rand_given_ndim(batch_size)
+        return index, {}
+
+    def _empty(self):
+        pass
+
+    def dumps(self, path):
+        # no op
+        ...
+
+    def loads(self, path):
+        # no op
+        ...
+
+    def state_dict(self) -> dict[str, Any]:
+        return {}
+
+    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
+        return
+
+
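For context, a minimal usage sketch of `RandomSampler` (dummy data; any of the package's storages would do):

    import torch
    from torchrl.data.replay_buffers import LazyTensorStorage, RandomSampler, ReplayBuffer

    rb = ReplayBuffer(storage=LazyTensorStorage(100), sampler=RandomSampler())
    rb.extend(torch.arange(10))  # fill with dummy data
    batch = rb.sample(4)         # four items drawn uniformly at random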
+class SamplerWithoutReplacement(Sampler):
+    """A data-consuming sampler that ensures that the same sample is not present in consecutive batches.
+
+    Args:
+        drop_last (bool, optional): if ``True``, the last incomplete sample (if any) will be dropped.
+            If ``False``, this last sample will be kept and (unlike with torch dataloaders)
+            completed with other samples from a fresh indices permutation.
+            Defaults to ``False``.
+        shuffle (bool, optional): if ``False``, the items are not randomly
+            permuted. This enables to iterate over the replay buffer in the
+            order the data was collected. Defaults to ``True``.
+
+    *Caution*: If the size of the storage changes in between two calls, the samples will be re-shuffled
+    (as we can't generally keep track of which samples have been sampled before and which haven't).
+
+    Similarly, it is expected that the storage content remains the same in between two calls,
+    but this is not enforced.
+
+    When the sampler reaches the end of the list of available indices, a new sample order
+    will be generated and the resulting indices will be completed with this new draw, which
+    can lead to duplicated indices, unless the :obj:`drop_last` argument is set to ``True``.
+
+    """
+
+    def __init__(self, drop_last: bool = False, shuffle: bool = True):
+        self._sample_list = None
+        self.len_storage = 0
+        self.drop_last = drop_last
+        self._ran_out = False
+        self.shuffle = shuffle
+
+    def dumps(self, path):
+        path = Path(path)
+        path.mkdir(exist_ok=True)
+
+        TensorDict(self.state_dict()).memmap(path)
+
+    def loads(self, path):
+        sd = TensorDict.load_memmap(path).to_dict()
+        self.load_state_dict(sd)
+
+    def _get_sample_list(self, storage: Storage, len_storage: int, batch_size: int):
+        if storage is None:
+            device = self._sample_list.device
+        else:
+            device = storage.device if hasattr(storage, "device") else None
+
+        if self.shuffle:
+            _sample_list = torch.randperm(
+                len_storage, device=device, generator=self._rng
+            )
+        else:
+            _sample_list = torch.arange(len_storage, device=device)
+        self._sample_list = _sample_list
+        if self.drop_last:
+            self._remaining_batches = self._sample_list.numel() // batch_size
+        else:
+            self._remaining_batches = -(self._sample_list.numel() // -batch_size)
+
+    def _single_sample(self, len_storage, batch_size):
+        index = self._sample_list[:batch_size]
+        self._sample_list = self._sample_list[batch_size:]
+        if self.drop_last:
+            self._remaining_batches = self._sample_list.numel() // batch_size
+        else:
+            self._remaining_batches = -(self._sample_list.numel() // -batch_size)
+
+        # check if we have enough elements for one more batch, assuming same batch size
+        # will be used each time sample is called
+        if self._sample_list.shape[0] == 0 or (
+            self.drop_last and len(self._sample_list) < batch_size
+        ):
+            self.ran_out = True
+            self._get_sample_list(
+                storage=None, len_storage=len_storage, batch_size=batch_size
+            )
+        else:
+            self.ran_out = False
+        return index
+
+    def _storage_len(self, storage):
+        return len(storage)
+
+    def sample(
+        self, storage: Storage, batch_size: int
+    ) -> tuple[Any, dict]:  # noqa: F811
+        len_storage = self._storage_len(storage)
+        if len_storage == 0:
+            raise RuntimeError(_EMPTY_STORAGE_ERROR)
+        if not len_storage:
+            raise RuntimeError("An empty storage was passed")
+        if self.len_storage != len_storage or self._sample_list is None:
+            self._get_sample_list(storage, len_storage, batch_size=batch_size)
+        if len_storage < batch_size and self.drop_last:
+            raise ValueError(
+                f"The batch size ({batch_size}) is greater than the storage capacity ({len_storage}). "
+                "This makes it impossible to return a sample without repeating indices. "
+                "Consider changing the sampler class or turn the 'drop_last' argument to False."
+            )
+        self.len_storage = len_storage
+        index = self._single_sample(len_storage, batch_size)
+        if storage.ndim > 1:
+            index = unravel_index(index, storage.shape)
+        # we 'always' return the indices. The 'drop_last' just instructs the
+        # sampler to turn to `ran_out = True` whenever the next sample
+        # will be too short. This will be read by the replay buffer
+        # as a signal for an early break of the __iter__().
+        return index, {}
+
+    @property
+    def ran_out(self):
+        return self._ran_out
+
+    @ran_out.setter
+    def ran_out(self, value):
+        self._ran_out = value
+
+    def _empty(self):
+        self._sample_list = None
+        self.len_storage = 0
+        self._ran_out = False
+
+    def state_dict(self) -> dict[str, Any]:
+        return OrderedDict(
+            len_storage=self.len_storage,
+            _sample_list=self._sample_list,
+            drop_last=self.drop_last,
+            _ran_out=self._ran_out,
+        )
+
+    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
+        self.len_storage = state_dict["len_storage"]
+        self._sample_list = state_dict["_sample_list"]
+        self.drop_last = state_dict["drop_last"]
+        self._ran_out = state_dict["_ran_out"]
+
+    def __repr__(self):
+        if self._sample_list is not None:
+            perc = len(self._sample_list) / self.len_storage * 100
+        else:
+            perc = 0.0
+        return f"{self.__class__.__name__}({perc: 4.4f}% sampled)"
+
+
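`SamplerWithoutReplacement` effectively turns the buffer into a one-pass-per-permutation dataset. A hedged usage sketch (dummy data), relying on the `ran_out` signal that, per the comments in `sample()` above, makes the buffer's `__iter__` stop early:

    import torch
    from torchrl.data.replay_buffers import LazyTensorStorage, ReplayBuffer, SamplerWithoutReplacement

    rb = ReplayBuffer(
        storage=LazyTensorStorage(8),
        sampler=SamplerWithoutReplacement(drop_last=True),
        batch_size=3,
    )
    rb.extend(torch.arange(8))
    for batch in rb:    # stops when the sampler runs out of fresh indices
        print(batch)    # three distinct items per batch within a sweep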
+class PrioritizedSampler(Sampler):
+    r"""Prioritized sampler for replay buffer.
+
+    This sampler implements Prioritized Experience Replay (PER) as presented in
+    "Schaul, T.; Quan, J.; Antonoglou, I.; and Silver, D. 2015. Prioritized experience replay."
+    (https://arxiv.org/abs/1511.05952)
+
+    **Core Idea**: Instead of sampling experiences uniformly from the replay buffer,
+    PER samples experiences with probability proportional to their "importance" - typically
+    measured by the magnitude of their temporal-difference (TD) error. This prioritization
+    can lead to faster learning by focusing on experiences that are most informative.
+
+    **How it works**:
+    1. Each experience is assigned a priority based on its TD error: :math:`p_i = |\delta_i| + \epsilon`
+    2. Sampling probability is computed as: :math:`P(i) = \frac{p_i^\alpha}{\sum_j p_j^\alpha}`
+    3. Importance sampling weights correct for the bias: :math:`w_i = (N \cdot P(i))^{-\beta}`
+
+    Args:
+        max_capacity (int): maximum capacity of the buffer.
+        alpha (:obj:`float`): exponent :math:`\alpha` determines how much prioritization is used.
+            - :math:`\alpha = 0`: uniform sampling (no prioritization)
+            - :math:`\alpha = 1`: full prioritization based on TD error magnitude
+            - Typical values: 0.4-0.7 for balanced prioritization
+            - Higher :math:`\alpha` means more aggressive prioritization of high-error experiences
+        beta (:obj:`float`): importance sampling negative exponent :math:`\beta`.
+            - :math:`\beta` controls the correction for the bias introduced by prioritization
+            - :math:`\beta = 0`: no correction (biased towards high-priority samples)
+            - :math:`\beta = 1`: full correction (unbiased but potentially unstable)
+            - Typical values: start at 0.4-0.6 and anneal to 1.0 during training
+            - Lower :math:`\beta` early in training provides stability, higher :math:`\beta` later reduces bias
+        eps (:obj:`float`, optional): small constant added to priorities to ensure
+            no experience has zero priority. This prevents experiences from never
+            being sampled. Defaults to 1e-8.
+        reduction (str, optional): the reduction method for multidimensional
+            tensordicts (ie stored trajectory). Can be one of "max", "min",
+            "median" or "mean".
+        max_priority_within_buffer (bool, optional): if ``True``, the max-priority
+            is tracked within the buffer. When ``False``, the max-priority tracks
+            the maximum value since the instantiation of the sampler.
+
+    **Parameter Guidelines**:
+
+    - **:math:`\alpha` (alpha)**: Controls how much to prioritize high-error experiences.
+        0.4-0.7: Good balance between learning speed and stability.
+        1.0: Maximum prioritization (may be unstable).
+        0.0: Uniform sampling (no prioritization benefit).
+
+    - **:math:`\beta` (beta)**: Controls importance sampling correction.
+        Start at 0.4-0.6 for training stability.
+        Anneal to 1.0 over training to reduce bias.
+        Lower values = more stable but biased.
+        Higher values = less biased but potentially unstable.
+
+    - **:math:`\epsilon`**: Small constant to prevent zero priorities.
+        1e-8: Good default value.
+        Too small: may cause numerical issues.
+        Too large: reduces prioritization effect.
+
+    Examples:
+        >>> from torchrl.data.replay_buffers import ReplayBuffer, LazyTensorStorage, PrioritizedSampler
+        >>> from tensordict import TensorDict
+        >>> rb = ReplayBuffer(storage=LazyTensorStorage(10), sampler=PrioritizedSampler(max_capacity=10, alpha=1.0, beta=1.0))
+        >>> priority = torch.tensor([0, 1000])
+        >>> data_0 = TensorDict({"reward": 0, "obs": [0], "action": [0], "priority": priority[0]}, [])
+        >>> data_1 = TensorDict({"reward": 1, "obs": [1], "action": [2], "priority": priority[1]}, [])
+        >>> rb.add(data_0)
+        >>> rb.add(data_1)
+        >>> rb.update_priority(torch.tensor([0, 1]), priority=priority)
+        >>> sample, info = rb.sample(10, return_info=True)
+        >>> print(sample)
+        TensorDict(
+            fields={
+                action: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.int64, is_shared=False),
+                obs: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.int64, is_shared=False),
+                priority: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.int64, is_shared=False),
+                reward: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.int64, is_shared=False)},
+            batch_size=torch.Size([10]),
+            device=cpu,
+            is_shared=False)
+        >>> print(info)
+        {'priority_weight': array([1.e-11, 1.e-11, 1.e-11, 1.e-11, 1.e-11, 1.e-11, 1.e-11, 1.e-11,
+               1.e-11, 1.e-11], dtype=float32), 'index': array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1])}
+
+    .. note:: Using a :class:`~torchrl.data.replay_buffers.TensorDictReplayBuffer` can smoothen the
+        process of updating the priorities:
+
+        >>> from torchrl.data.replay_buffers import TensorDictReplayBuffer as TDRB, LazyTensorStorage, PrioritizedSampler
+        >>> from tensordict import TensorDict
+        >>> rb = TDRB(
+        ...     storage=LazyTensorStorage(10),
+        ...     sampler=PrioritizedSampler(max_capacity=10, alpha=1.0, beta=1.0),
+        ...     priority_key="priority",  # This kwarg isn't present in regular RBs
+        ... )
+        >>> priority = torch.tensor([0, 1000])
+        >>> data_0 = TensorDict({"reward": 0, "obs": [0], "action": [0], "priority": priority[0]}, [])
+        >>> data_1 = TensorDict({"reward": 1, "obs": [1], "action": [2], "priority": priority[1]}, [])
+        >>> data = torch.stack([data_0, data_1])
+        >>> rb.extend(data)
+        >>> rb.update_priority(data)  # Reads the "priority" key as indicated in the constructor
+        >>> sample, info = rb.sample(10, return_info=True)
+        >>> print(sample['index'])  # The index is packed with the tensordict
+        tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1])
+
+    """
+
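To make the docstring's formulas concrete, here is a small hand computation (pure illustration with made-up TD errors, alpha=1, beta=1), matching the `weight_i = (p_i / min(p)) ** -beta` simplification derived in the comments of `sample()` further down:

    import torch

    delta = torch.tensor([1.0, 3.0])    # |TD errors| of two stored transitions
    eps, alpha, beta = 1e-8, 1.0, 1.0
    p = (delta + eps) ** alpha          # priorities p_i
    P = p / p.sum()                     # sampling probabilities: [0.25, 0.75]
    w = (len(p) * P) ** -beta           # raw IS weights: [2.0, 0.6667]
    w = w / w.max()                     # normalized weights: [1.0, 0.3333]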
+    def __init__(
+        self,
+        max_capacity: int,
+        alpha: float,
+        beta: float,
+        eps: float = 1e-8,
+        dtype: torch.dtype = torch.float,
+        reduction: str = "max",
+        max_priority_within_buffer: bool = False,
+    ) -> None:
+        if alpha < 0:
+            raise ValueError(
+                f"alpha must be greater or equal than 0, got alpha={alpha}"
+            )
+        if beta < 0:
+            raise ValueError(f"beta must be greater or equal to 0, got beta={beta}")
+
+        self._max_capacity = max_capacity
+        self._alpha = alpha
+        self._beta = beta
+        self._eps = eps
+        self.reduction = reduction
+        self.dtype = dtype
+        self._max_priority_within_buffer = max_priority_within_buffer
+        self._init()
+        if rl_warnings() and SumSegmentTreeFp32 is None:
+            logger.warning(EXTENSION_WARNING)
+
+    def __repr__(self):
+        return f"{self.__class__.__name__}(alpha={self._alpha}, beta={self._beta}, eps={self._eps}, reduction={self.reduction})"
+
+    @property
+    def max_size(self):
+        return self._max_capacity
+
+    @property
+    def alpha(self):
+        return self._alpha
+
+    @alpha.setter
+    def alpha(self, value):
+        self._alpha = value
+
+    @property
+    def beta(self):
+        return self._beta
+
+    @beta.setter
+    def beta(self, value):
+        self._beta = value
+
+    def __getstate__(self):
+        if get_spawning_popen() is not None:
+            raise RuntimeError(
+                f"Samplers of type {type(self)} cannot be shared between processes."
+            )
+        return super().__getstate__()
+
+    def _init(self) -> None:
+        if SumSegmentTreeFp32 is None:
+            raise RuntimeError(
+                "SumSegmentTreeFp32 is not available. See warning above."
+            )
+        if MinSegmentTreeFp32 is None:
+            raise RuntimeError(
+                "MinSegmentTreeFp32 is not available. See warning above."
+            )
+        if SumSegmentTreeFp64 is None:
+            raise RuntimeError(
+                "SumSegmentTreeFp64 is not available. See warning above."
+            )
+        if MinSegmentTreeFp64 is None:
+            raise RuntimeError(
+                "MinSegmentTreeFp64 is not available. See warning above."
+            )
+        if self.dtype in (torch.float, torch.FloatType, torch.float32):
+            self._sum_tree = SumSegmentTreeFp32(self._max_capacity)
+            self._min_tree = MinSegmentTreeFp32(self._max_capacity)
+        elif self.dtype in (torch.double, torch.DoubleTensor, torch.float64):
+            self._sum_tree = SumSegmentTreeFp64(self._max_capacity)
+            self._min_tree = MinSegmentTreeFp64(self._max_capacity)
+        else:
+            raise NotImplementedError(
+                f"dtype {self.dtype} not supported by PrioritizedSampler"
+            )
+        self._max_priority = None
+
+    def _empty(self) -> None:
+        self._init()
+
+    @property
+    def _max_priority(self) -> tuple[float | None, int | None]:
+        max_priority_index = self.__dict__.get("_max_priority")
+        if max_priority_index is None:
+            return (None, None)
+        return max_priority_index
+
+    @_max_priority.setter
+    def _max_priority(self, value: tuple[float | None, int | None]) -> None:
+        self.__dict__["_max_priority"] = value
+
+    def _maybe_erase_max_priority(
+        self, index: torch.Tensor | int | slice | tuple
+    ) -> None:
+        if not self._max_priority_within_buffer:
+            return
+        max_priority_index = self._max_priority[1]
+        if max_priority_index is None:
+            return
+
+        def check_index(index=index, max_priority_index=max_priority_index):
+            if isinstance(index, torch.Tensor):
+                # index can be 1d or 2d
+                if index.ndim == 1:
+                    is_overwritten = (index == max_priority_index).any()
+                else:
+                    is_overwritten = (index == max_priority_index).all(-1).any()
+            elif isinstance(index, int):
+                is_overwritten = index == max_priority_index
+            elif isinstance(index, slice):
+                # This won't work if called recursively
+                is_overwritten = max_priority_index in range(
+                    *index.indices(self._max_capacity)
+                )
+            elif isinstance(index, tuple):
+                is_overwritten = isinstance(max_priority_index, tuple)
+                if is_overwritten:
+                    for idx, mpi in zip(index, max_priority_index):
+                        is_overwritten &= check_index(idx, mpi)
+            else:
+                raise TypeError(f"index of type {type(index)} is not recognized.")
+            return is_overwritten
+
+        is_overwritten = check_index()
+        if is_overwritten:
+            self._max_priority = None
+
+    @property
+    def default_priority(self) -> float:
+        mp = self._max_priority[0]
+        if mp is None:
+            mp = 1
+        return (mp + self._eps) ** self._alpha
+
+    def sample(self, storage: Storage, batch_size: int) -> torch.Tensor:
+        if len(storage) == 0:
+            raise RuntimeError(_EMPTY_STORAGE_ERROR)
+        p_sum = self._sum_tree.query(0, len(storage))
+        p_min = self._min_tree.query(0, len(storage))
+
+        if p_sum <= 0:
+            raise RuntimeError("non-positive p_sum")
+        if p_min <= 0:
+            raise RuntimeError("non-positive p_min")
+        # For some undefined reason, only np.random works here.
+        # All PT attempts fail, even when subsequently transformed into numpy
+        if self._rng is None:
+            mass = np.random.uniform(0.0, p_sum, size=batch_size)
+        else:
+            mass = torch.rand(batch_size, generator=self._rng) * p_sum
+
+        # mass = torch.zeros(batch_size, dtype=torch.double).uniform_(0.0, p_sum)
+        # mass = torch.rand(batch_size).mul_(p_sum)
+        index = self._sum_tree.scan_lower_bound(mass)
+        index = torch.as_tensor(index)
+        if not index.ndim:
+            index = index.unsqueeze(0)
+        index.clamp_max_(len(storage) - 1)
+        weight = torch.as_tensor(self._sum_tree[index])
+        # get indices where weight is 0
+        zero_weight = weight == 0
+        while zero_weight.any():
+            index = torch.where(zero_weight, index - 1, index)
+            if (index < 0).any():
+                raise RuntimeError("Failed to find a suitable index")
+            weight = torch.as_tensor(self._sum_tree[index])
+            zero_weight = weight == 0
+
+        # Importance sampling weight formula:
+        #   w_i = (p_i / sum(p) * N) ^ (-beta)
+        #   weight_i = w_i / max(w)
+        #   weight_i = (p_i / sum(p) * N) ^ (-beta) /
+        #       ((min(p) / sum(p) * N) ^ (-beta))
+        #   weight_i = ((p_i / sum(p) * N) / (min(p) / sum(p) * N)) ^ (-beta)
+        #   weight_i = (p_i / min(p)) ^ (-beta)
+        # weight = np.power(weight / (p_min + self._eps), -self._beta)
+        weight = torch.pow(weight / p_min, -self._beta)
+        if storage.ndim > 1:
+            index = unravel_index(index, storage.shape)
+        return index, {"priority_weight": weight}
+
+    def add(self, index: torch.Tensor | int) -> None:
+        super().add(index)
+        self._maybe_erase_max_priority(index)
+
+    def extend(self, index: torch.Tensor | tuple) -> None:
+        super().extend(index)
+        self._maybe_erase_max_priority(index)
+
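A sketch of how the `priority_weight` returned by `sample()` above is typically consumed; the per-sample error here is a stand-in, and priorities are seeded explicitly first, as in the class docstring:

    import torch
    from torchrl.data.replay_buffers import LazyTensorStorage, PrioritizedSampler, ReplayBuffer

    rb = ReplayBuffer(
        storage=LazyTensorStorage(10),
        sampler=PrioritizedSampler(max_capacity=10, alpha=0.6, beta=0.4),
    )
    rb.extend(torch.randn(10, 4))
    rb.update_priority(torch.arange(10), priority=torch.rand(10) + 0.1)  # seed priorities
    sample, info = rb.sample(8, return_info=True)
    weight = torch.as_tensor(info["priority_weight"])
    td_error = sample.sum(-1).abs()             # stand-in for a real TD error
    loss = (weight * td_error.pow(2)).mean()    # importance-sampling corrected loss
    rb.update_priority(info["index"], td_error)  # feed new errors back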
+    @torch.no_grad()
+    def update_priority(
+        self,
+        index: int | torch.Tensor,
+        priority: float | torch.Tensor,
+        *,
+        storage: TensorStorage | None = None,
+    ) -> None:  # noqa: D417
+        """Updates the priority of the data pointed by the index.
+
+        Args:
+            index (int or torch.Tensor): indexes of the priorities to be
+                updated.
+            priority (Number or torch.Tensor): new priorities of the
+                indexed elements.
+
+        Keyword Args:
+            storage (Storage, optional): a storage used to map the Nd index size to
+                the 1d size of the sum_tree and min_tree. Only required whenever
+                ``index.ndim > 2``.
+
+        """
+        priority = torch.as_tensor(priority, device=torch.device("cpu")).detach()
+        index = torch.as_tensor(index, dtype=torch.long, device=torch.device("cpu"))
+        # we need to reshape priority if it has more than one element or if it has
+        # a different shape than index
+        if priority.numel() > 1 and priority.shape != index.shape:
+            try:
+                priority = priority.reshape(index.shape[:1])
+            except Exception as err:
+                raise RuntimeError(
+                    "priority should be a number or an iterable of the same "
+                    f"length as index. Got priority of shape {priority.shape} and index "
+                    f"{index.shape}."
+                ) from err
+        elif priority.numel() <= 1:
+            priority = priority.squeeze()
+
+        # MaxValueWriter will set -1 for items in the data that we don't want
+        # to update. We therefore have to keep only the non-negative indices.
+        if _is_int(index):
+            if index == -1:
+                return
+        else:
+            if index.ndim > 1:
+                if storage is None:
+                    raise RuntimeError(
+                        "storage should be provided to Sampler.update_priority when the storage has more "
+                        "than one dimension."
+                    )
+                try:
+                    shape = storage.shape
+                except AttributeError:
+                    raise AttributeError(
+                        "Could not retrieve the storage shape. If your storage is not a TensorStorage subclass "
+                        "or its shape isn't accessible via the shape attribute, submit an issue on GitHub."
+                    )
+                index = torch.as_tensor(np.ravel_multi_index(index.unbind(-1), shape))
+            valid_index = index >= 0
+            if not valid_index.any():
+                return
+            if not valid_index.all():
+                index = index[valid_index]
+                if priority.ndim:
+                    priority = priority[valid_index]
+
+        max_p, max_p_idx = priority.max(dim=0)
+        cur_max_priority, cur_max_priority_index = self._max_priority
+        if cur_max_priority is None or max_p > cur_max_priority:
+            cur_max_priority, cur_max_priority_index = self._max_priority = (
+                max_p,
+                index[max_p_idx] if index.ndim else index,
+            )
+        priority = torch.pow(priority + self._eps, self._alpha)
+        self._sum_tree[index] = priority
+        self._min_tree[index] = priority
+        if (
+            self._max_priority_within_buffer
+            and cur_max_priority_index is not None
+            and (index == cur_max_priority_index).any()
+        ):
+            maxval, maxidx = torch.tensor(
+                [self._sum_tree[i] for i in range(self._max_capacity)]
+            ).max(0)
+            self._max_priority = (maxval, maxidx)
+
+    def mark_update(
+        self, index: int | torch.Tensor, *, storage: Storage | None = None
+    ) -> None:
+        self.update_priority(index, self.default_priority, storage=storage)
+
+    def state_dict(self) -> dict[str, Any]:
+        return {
+            "_alpha": self._alpha,
+            "_beta": self._beta,
+            "_eps": self._eps,
+            "_max_priority": self._max_priority,
+            "_sum_tree": deepcopy(self._sum_tree),
+            "_min_tree": deepcopy(self._min_tree),
+        }
+
+    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
+        self._alpha = state_dict["_alpha"]
+        self._beta = state_dict["_beta"]
+        self._eps = state_dict["_eps"]
+        self._max_priority = state_dict["_max_priority"]
+        self._sum_tree = state_dict.pop("_sum_tree")
+        self._min_tree = state_dict.pop("_min_tree")
+
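A quick sketch of round-tripping the in-memory state via `state_dict`/`load_state_dict` above (assumes the compiled segment-tree extension is available, otherwise construction raises):

    from torchrl.data.replay_buffers import PrioritizedSampler

    sampler = PrioritizedSampler(max_capacity=10, alpha=0.7, beta=0.5)
    sd = sampler.state_dict()   # deep-copies the sum/min segment trees
    fresh = PrioritizedSampler(max_capacity=10, alpha=0.0, beta=0.0)
    fresh.load_state_dict(sd)   # alpha/beta/eps, max-priority and trees restored
    assert fresh.alpha == 0.7 and fresh.beta == 0.5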
+    @implement_for("torch", None, "2.5.0")
+    def dumps(self, path):
+        raise NotImplementedError("This method is not implemented for Torch < 2.5.0")
+
+    @implement_for("torch", "2.5.0", None)
+    def dumps(self, path):  # noqa: F811
+        path = Path(path).absolute()
+        path.mkdir(exist_ok=True)
+        try:
+            mm_st = MemoryMappedTensor.from_filename(
+                shape=(self._max_capacity,),
+                dtype=torch.float64,
+                filename=path / "sumtree.memmap",
+            )
+            mm_mt = MemoryMappedTensor.from_filename(
+                shape=(self._max_capacity,),
+                dtype=torch.float64,
+                filename=path / "mintree.memmap",
+            )
+        except FileNotFoundError:
+            mm_st = MemoryMappedTensor.empty(
+                (self._max_capacity,),
+                dtype=torch.float64,
+                filename=path / "sumtree.memmap",
+            )
+            mm_mt = MemoryMappedTensor.empty(
+                (self._max_capacity,),
+                dtype=torch.float64,
+                filename=path / "mintree.memmap",
+            )
+        mm_st.copy_(
+            torch.as_tensor([self._sum_tree[i] for i in range(self._max_capacity)])
+        )
+        mm_mt.copy_(
+            torch.as_tensor([self._min_tree[i] for i in range(self._max_capacity)])
+        )
+        with open(path / "sampler_metadata.json", "w") as file:
+            json.dump(
+                tree_map(
+                    float,
+                    {
+                        "_alpha": self._alpha,
+                        "_beta": self._beta,
+                        "_eps": self._eps,
+                        "_max_priority": self._max_priority,
+                        "_max_capacity": self._max_capacity,
+                    },
+                ),
+                file,
+            )
+
+    @implement_for("torch", None, "2.5.0")
+    def loads(self, path):
+        raise NotImplementedError("This method is not implemented for Torch < 2.5.0")
+
+    @implement_for("torch", "2.5.0", None)
+    def loads(self, path):  # noqa: F811
+        path = Path(path).absolute()
+        with open(path / "sampler_metadata.json") as file:
+            metadata = json.load(file)
+        self._alpha = metadata["_alpha"]
+        self._beta = metadata["_beta"]
+        self._eps = metadata["_eps"]
+        maxp = tree_map(
+            lambda dest, orig: dest.copy_(orig) if dest is not None else orig,
+            tuple(self._max_priority),
+            tuple(metadata["_max_priority"]),
+        )
+        if all(x is None for x in self._max_priority):
+            self._max_priority = maxp
+        _max_capacity = metadata["_max_capacity"]
+        if _max_capacity != self._max_capacity:
+            raise RuntimeError(
+                f"max capacity of loaded metadata ({_max_capacity}) differs from self._max_capacity ({self._max_capacity})."
+            )
+        mm_st = MemoryMappedTensor.from_filename(
+            shape=(self._max_capacity,),
+            dtype=torch.float64,
+            filename=path / "sumtree.memmap",
+        )
+        mm_mt = MemoryMappedTensor.from_filename(
+            shape=(self._max_capacity,),
+            dtype=torch.float64,
+            filename=path / "mintree.memmap",
+        )
+        for i, elt in enumerate(mm_st.tolist()):
+            self._sum_tree[i] = elt
+        for i, elt in enumerate(mm_mt.tolist()):
+            self._min_tree[i] = elt
+
+
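And a sketch of the on-disk cycle implemented by `dumps`/`loads` above (requires torch >= 2.5.0 per the `implement_for` guards; the temporary directory is illustrative):

    import tempfile
    from pathlib import Path

    import torch
    from torchrl.data.replay_buffers import PrioritizedSampler

    sampler = PrioritizedSampler(max_capacity=16, alpha=0.7, beta=0.5)
    sampler.update_priority(torch.arange(4), torch.tensor([1.0, 2.0, 3.0, 4.0]))

    path = Path(tempfile.mkdtemp()) / "sampler"
    sampler.dumps(path)   # writes sumtree.memmap, mintree.memmap, sampler_metadata.json
    restored = PrioritizedSampler(max_capacity=16, alpha=0.7, beta=0.5)
    restored.loads(path)  # metadata and both trees are read back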
class SliceSampler(Sampler):
|
|
803
|
+
"""Samples slices of data along the first dimension, given start and stop signals.
|
|
804
|
+
|
|
805
|
+
This class samples sub-trajectories with replacement. For a version without
|
|
806
|
+
replacement, see :class:`~torchrl.data.replay_buffers.samplers.SliceSamplerWithoutReplacement`.
|
|
807
|
+
|
|
808
|
+
.. note:: `SliceSampler` can be slow to retrieve the trajectory indices. To accelerate
|
|
809
|
+
its execution, prefer using `end_key` over `traj_key`, and consider the following
|
|
810
|
+
keyword arguments: :attr:`compile`, :attr:`cache_values` and :attr:`use_gpu`.
|
|
811
|
+
|
|
812
|
+

    Keyword Args:
        num_slices (int): the number of slices to be sampled. The batch-size
            must be greater or equal to the ``num_slices`` argument. Exclusive
            with ``slice_len``.
        slice_len (int): the length of the slices to be sampled. The batch-size
            must be greater or equal to the ``slice_len`` argument and divisible
            by it. Exclusive with ``num_slices``.
        end_key (NestedKey, optional): the key indicating the end of a
            trajectory (or episode). Defaults to ``("next", "done")``.
        traj_key (NestedKey, optional): the key indicating the trajectories.
            Defaults to ``"episode"`` (commonly used across datasets in TorchRL).
        ends (torch.Tensor, optional): a 1d boolean tensor containing the end of run signals.
            To be used whenever the ``end_key`` or ``traj_key`` is expensive to get,
            or when this signal is readily available. Must be used with ``cache_values=True``
            and cannot be used in conjunction with ``end_key`` or ``traj_key``.
            If provided, it is assumed that the storage is at capacity and that
            if the last element of the ``ends`` tensor is ``False``,
            the same trajectory spans across end and beginning.
        trajectories (torch.Tensor, optional): a 1d integer tensor containing the run ids.
            To be used whenever the ``end_key`` or ``traj_key`` is expensive to get,
            or when this signal is readily available. Must be used with ``cache_values=True``
            and cannot be used in conjunction with ``end_key`` or ``traj_key``.
            If provided, it is assumed that the storage is at capacity and that
            if the last element of the trajectory tensor is identical to the first,
            the same trajectory spans across end and beginning.
        cache_values (bool, optional): to be used with static datasets.
            Will cache the start and end signal of the trajectory. This can be safely used even
            if the trajectory indices change during calls to :class:`~torchrl.data.ReplayBuffer.extend`
            as this operation will erase the cache.

            .. warning:: ``cache_values=True`` will not work if the sampler is used with a
                storage that is extended by another buffer. For instance:

                    >>> buffer0 = ReplayBuffer(storage=storage,
                    ...     sampler=SliceSampler(num_slices=8, cache_values=True),
                    ...     writer=ImmutableWriter())
                    >>> buffer1 = ReplayBuffer(storage=storage,
                    ...     sampler=other_sampler)
                    >>> # Wrong! Does not erase the buffer from the sampler of buffer0
                    >>> buffer1.extend(data)

            .. warning:: ``cache_values=True`` will not work as expected if the buffer is
                shared between processes and one process is responsible for writing
                and one process for sampling, as erasing the cache can only be done locally.

        truncated_key (NestedKey, optional): If not ``None``, this argument
            indicates where a truncated signal should be written in the output
            data. This is used to indicate to value estimators where the provided
            trajectory breaks. Defaults to ``("next", "truncated")``.
            This feature only works with :class:`~torchrl.data.replay_buffers.TensorDictReplayBuffer`
            instances (otherwise the truncated key is returned in the info dictionary
            returned by the :meth:`~torchrl.data.replay_buffers.ReplayBuffer.sample` method).
        strict_length (bool, optional): if ``False``, trajectories of length
            shorter than `slice_len` (or `batch_size // num_slices`) will be
            allowed to appear in the batch. If ``True``, trajectories shorter
            than required will be filtered out.
            Be mindful that this can result in an effective `batch_size` shorter
            than the one asked for! Trajectories can be split using
            :func:`~torchrl.collectors.split_trajectories`. Defaults to ``True``.
        compile (bool or dict of kwargs, optional): if ``True``, the bottleneck of
            the :meth:`~sample` method will be compiled with :func:`~torch.compile`.
            Keyword arguments can also be passed to torch.compile with this arg.
            Defaults to ``False``.
        span (bool, int, Tuple[bool | int, bool | int], optional): if provided, the sampled
            trajectory will span across the left and/or the right. This means that possibly
            fewer elements will be provided than what was required. A boolean value means
            that at least one element will be sampled per trajectory. An integer `i` means
            that at least `slice_len - i` samples will be gathered for each sampled trajectory.
            Using tuples allows a fine-grained control over the span on the left (beginning
            of the stored trajectory) and on the right (end of the stored trajectory),
            as sketched after this list.
        use_gpu (bool or torch.device): if ``True`` (or if a device is passed), an accelerator
            will be used to retrieve the indices of the trajectory starts. This can significantly
            accelerate the sampling when the buffer content is large.
            Defaults to ``False``.
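
    As a sketch of the ``span`` mechanics (buffer construction omitted), the
    following sampler keeps the left side strict but lets slices start near a
    trajectory end, so up to ``slice_len - 1`` trailing elements may be missing:

        >>> spanning_sampler = SliceSampler(slice_len=8, span=(False, True))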

    .. note:: To recover the trajectory splits in the storage,
        :class:`~torchrl.data.replay_buffers.samplers.SliceSampler` will first
        attempt to find the ``traj_key`` entry in the storage. If it cannot be
        found, the ``end_key`` will be used to reconstruct the episodes.

    .. note:: When using `strict_length=False`, it is recommended to use
        :func:`~torchrl.collectors.utils.split_trajectories` to split the sampled trajectories.
        However, if two samples from the same episode are placed next to each other,
        this may produce incorrect results. To avoid this issue, consider one of these solutions:

    - using a :class:`~torchrl.data.TensorDictReplayBuffer` instance with the slice sampler

        >>> import torch
        >>> from tensordict import TensorDict
        >>> from torchrl.collectors.utils import split_trajectories
        >>> from torchrl.data import TensorDictReplayBuffer, ReplayBuffer, LazyTensorStorage, SliceSampler, SliceSamplerWithoutReplacement
        >>>
        >>> rb = TensorDictReplayBuffer(storage=LazyTensorStorage(max_size=1000),
        ...     sampler=SliceSampler(
        ...         slice_len=5, traj_key="episode", strict_length=False,
        ...     ))
        ...
        >>> ep_1 = TensorDict(
        ...     {"obs": torch.arange(100),
        ...     "episode": torch.zeros(100),},
        ...     batch_size=[100]
        ... )
        >>> ep_2 = TensorDict(
        ...     {"obs": torch.arange(4),
        ...     "episode": torch.ones(4),},
        ...     batch_size=[4]
        ... )
        >>> rb.extend(ep_1)
        >>> rb.extend(ep_2)
        >>>
        >>> s = rb.sample(50)
        >>> print(s)
        TensorDict(
            fields={
                episode: Tensor(shape=torch.Size([46]), device=cpu, dtype=torch.float32, is_shared=False),
                index: Tensor(shape=torch.Size([46, 1]), device=cpu, dtype=torch.int64, is_shared=False),
                next: TensorDict(
                    fields={
                        done: Tensor(shape=torch.Size([46, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                        terminated: Tensor(shape=torch.Size([46, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                        truncated: Tensor(shape=torch.Size([46, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
                    batch_size=torch.Size([46]),
                    device=cpu,
                    is_shared=False),
                obs: Tensor(shape=torch.Size([46]), device=cpu, dtype=torch.int64, is_shared=False)},
            batch_size=torch.Size([46]),
            device=cpu,
            is_shared=False)
        >>> t = split_trajectories(s, done_key="truncated")
        >>> print(t["obs"])
        tensor([[73, 74, 75, 76, 77],
                [ 0,  1,  2,  3,  0],
                [ 0,  1,  2,  3,  0],
                [41, 42, 43, 44, 45],
                [ 0,  1,  2,  3,  0],
                [67, 68, 69, 70, 71],
                [27, 28, 29, 30, 31],
                [80, 81, 82, 83, 84],
                [17, 18, 19, 20, 21],
                [ 0,  1,  2,  3,  0]])
        >>> print(t["episode"])
        tensor([[0., 0., 0., 0., 0.],
                [1., 1., 1., 1., 0.],
                [1., 1., 1., 1., 0.],
                [0., 0., 0., 0., 0.],
                [1., 1., 1., 1., 0.],
                [0., 0., 0., 0., 0.],
                [0., 0., 0., 0., 0.],
                [0., 0., 0., 0., 0.],
                [0., 0., 0., 0., 0.],
                [1., 1., 1., 1., 0.]])

    - using a :class:`~torchrl.data.replay_buffers.samplers.SliceSamplerWithoutReplacement`

        >>> import torch
        >>> from tensordict import TensorDict
        >>> from torchrl.collectors.utils import split_trajectories
        >>> from torchrl.data import ReplayBuffer, LazyTensorStorage, SliceSampler, SliceSamplerWithoutReplacement
        >>>
        >>> rb = ReplayBuffer(storage=LazyTensorStorage(max_size=1000),
        ...     sampler=SliceSamplerWithoutReplacement(
        ...         slice_len=5, traj_key="episode", strict_length=False
        ...     ))
        ...
        >>> ep_1 = TensorDict(
        ...     {"obs": torch.arange(100),
        ...     "episode": torch.zeros(100),},
        ...     batch_size=[100]
        ... )
        >>> ep_2 = TensorDict(
        ...     {"obs": torch.arange(4),
        ...     "episode": torch.ones(4),},
        ...     batch_size=[4]
        ... )
        >>> rb.extend(ep_1)
        >>> rb.extend(ep_2)
        >>>
        >>> s = rb.sample(50)
        >>> t = split_trajectories(s, trajectory_key="episode")
        >>> print(t["obs"])
        tensor([[75, 76, 77, 78, 79],
                [ 0,  1,  2,  3,  0]])
        >>> print(t["episode"])
        tensor([[0., 0., 0., 0., 0.],
                [1., 1., 1., 1., 0.]])

    Examples:
        >>> import torch
        >>> from tensordict import TensorDict
        >>> from torchrl.data.replay_buffers import LazyMemmapStorage, TensorDictReplayBuffer
        >>> from torchrl.data.replay_buffers.samplers import SliceSampler
        >>> torch.manual_seed(0)
        >>> rb = TensorDictReplayBuffer(
        ...     storage=LazyMemmapStorage(1_000_000),
        ...     sampler=SliceSampler(cache_values=True, num_slices=10),
        ...     batch_size=320,
        ... )
        >>> episode = torch.zeros(1000, dtype=torch.int)
        >>> episode[:300] = 1
        >>> episode[300:550] = 2
        >>> episode[550:700] = 3
        >>> episode[700:] = 4
        >>> data = TensorDict(
        ...     {
        ...         "episode": episode,
        ...         "obs": torch.randn((3, 4, 5)).expand(1000, 3, 4, 5),
        ...         "act": torch.randn((20,)).expand(1000, 20),
        ...         "other": torch.randn((20, 50)).expand(1000, 20, 50),
        ...     }, [1000]
        ... )
        >>> rb.extend(data)
        >>> sample = rb.sample()
        >>> print("sample:", sample)
        >>> print("episodes", sample.get("episode").unique())
        episodes tensor([1, 2, 3, 4], dtype=torch.int32)

    :class:`~torchrl.data.replay_buffers.SliceSampler` is default-compatible with
    most of TorchRL's datasets:

    Examples:
        >>> import torch
        >>>
        >>> from torchrl.data.datasets import RobosetExperienceReplay
        >>> from torchrl.data import SliceSampler
        >>>
        >>> torch.manual_seed(0)
        >>> num_slices = 10
        >>> dataid = list(RobosetExperienceReplay.available_datasets)[0]
        >>> data = RobosetExperienceReplay(dataid, batch_size=320, sampler=SliceSampler(num_slices=num_slices))
        >>> for batch in data:
        ...     batch = batch.reshape(num_slices, -1)
        ...     break
        >>> print("check that each batch only has one episode:", batch["episode"].unique(dim=1))
        check that each batch only has one episode: tensor([[19],
                [14],
                [ 8],
                [10],
                [13],
                [ 4],
                [ 2],
                [ 3],
                [22],
                [ 8]])

    """

    # We use this whenever we need to sample N times too many transitions to then select only a 1/N fraction of them
    _batch_size_multiplier: int | None = 1

    def __init__(
        self,
        *,
        num_slices: int | None = None,
        slice_len: int | None = None,
        end_key: NestedKey | None = None,
        traj_key: NestedKey | None = None,
        ends: torch.Tensor | None = None,
        trajectories: torch.Tensor | None = None,
        cache_values: bool = False,
        truncated_key: NestedKey | None = ("next", "truncated"),
        strict_length: bool = True,
        compile: bool | dict = False,
        span: bool | int | tuple[bool | int, bool | int] = False,
        use_gpu: torch.device | bool = False,
    ):
        self.num_slices = num_slices
        self.slice_len = slice_len
        self.end_key = end_key
        self.traj_key = traj_key
        self.truncated_key = truncated_key
        self.cache_values = cache_values
        self._fetch_traj = True
        self.strict_length = strict_length
        self._cache = {}
        self.use_gpu = bool(use_gpu)
        self._gpu_device = (
            None
            if not self.use_gpu
            else torch.device(use_gpu)
            if not isinstance(use_gpu, bool)
            else _auto_device()
        )
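        # Note: use_gpu=True resolves to an automatically detected accelerator,
        # while an explicit torch.device pins the index computation to that device.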

        if isinstance(span, (bool, int)):
            span = (span, span)
        self.span = span

        if trajectories is not None:
            if traj_key is not None or end_key:
                raise RuntimeError(
                    "`trajectories` and `end_key` or `traj_key` are exclusive arguments."
                )
            if ends is not None:
                raise RuntimeError("trajectories and ends are exclusive arguments.")
            if not cache_values:
                raise RuntimeError(
                    "To be used, trajectories requires `cache_values` to be set to `True`."
                )
            vals = self._find_start_stop_traj(
                trajectory=trajectories,
                at_capacity=True,
            )
            self._cache["stop-and-length"] = vals

        elif ends is not None:
            if traj_key is not None or end_key:
                raise RuntimeError(
                    "`ends` and `end_key` or `traj_key` are exclusive arguments."
                )
            if trajectories is not None:
                raise RuntimeError("trajectories and ends are exclusive arguments.")
            if not cache_values:
                raise RuntimeError(
                    "To be used, ends requires `cache_values` to be set to `True`."
                )
            vals = self._find_start_stop_traj(end=ends, at_capacity=True)
            self._cache["stop-and-length"] = vals

        else:
            if traj_key is not None:
                self._fetch_traj = True
            elif end_key is not None:
                self._fetch_traj = False
            if end_key is None:
                end_key = ("next", "done")
            if traj_key is None:
                traj_key = "episode"
            self.end_key = end_key
            self.traj_key = traj_key

        if not ((num_slices is None) ^ (slice_len is None)):
            raise TypeError(
                "Either num_slices or slice_len must be not None, and not both. "
                f"Got num_slices={num_slices} and slice_len={slice_len}."
            )
        self.compile = bool(compile)
        if self.compile:
            if isinstance(compile, dict):
                kwargs = compile
            else:
                kwargs = {}
            self._get_index = torch.compile(self._get_index, **kwargs)

    def __getstate__(self):
        if get_spawning_popen() is not None and self.cache_values:
            logger.warning(
                f"It seems you are sharing a {type(self).__name__} across processes with "
                f"cache_values=True. "
                f"While this isn't forbidden and could perfectly work if your dataset "
                f"is unaltered on both processes, remember that calling extend/add on "
                f"one process will NOT erase the cache on another process's sampler, "
                f"which will cause synchronization issues."
            )
        state = super().__getstate__()
        state["_cache"] = {}
        return state

    def extend(self, index: torch.Tensor) -> None:
        super().extend(index)
        if self.cache_values:
            self._cache.clear()

    def add(self, index: torch.Tensor) -> None:
        super().add(index)
        if self.cache_values:
            self._cache.clear()

    def __repr__(self):
        return (
            f"{self.__class__.__name__}(num_slices={self.num_slices}, "
            f"slice_len={self.slice_len}, "
            f"end_key={self.end_key}, "
            f"traj_key={self.traj_key}, "
            f"truncated_key={self.truncated_key}, "
            f"strict_length={self.strict_length})"
        )

    def _find_start_stop_traj(
        self, *, trajectory=None, end=None, at_capacity: bool, cursor=None
    ):
        if trajectory is not None:
            # slower
            # _, stop_idx = torch.unique_consecutive(trajectory, return_counts=True)
            # stop_idx = stop_idx.cumsum(0) - 1

            # even slower
            # t = trajectory.unsqueeze(0)
            # w = torch.tensor([1, -1], dtype=torch.int).view(1, 1, 2)
            # stop_idx = torch.conv1d(t, w).nonzero()

            # faster
            end = trajectory[:-1] != trajectory[1:]
            if not at_capacity:
                end = torch.cat([end, torch.ones_like(end[:1])], 0)
            else:
                end = torch.cat([end, trajectory[-1:] != trajectory[:1]], 0)
            length = trajectory.shape[0]
        else:
            # We presume that not done at the end means that the traj spans across end and beginning of storage
            length = end.shape[0]
            if not at_capacity:
                end = end.clone()
                end[length - 1] = True
        ndim = end.ndim

        if at_capacity:
            # we must have at least one end by traj to individuate trajectories
            # so if no end can be found we set it manually
            if cursor is not None:
                if isinstance(cursor, torch.Tensor):
                    cursor = cursor[-1].item()
                elif isinstance(cursor, range):
                    cursor = cursor[-1]
                if not _is_int(cursor):
                    raise RuntimeError(
                        "cursor should be an integer or a 1d tensor or a range."
                    )
                end = torch.index_fill(
                    end,
                    index=torch.tensor(cursor, device=end.device, dtype=torch.long),
                    dim=0,
                    value=1,
                )
            if not end.any(0).all():
                mask = ~end.any(0, True)
                mask = torch.cat([torch.zeros_like(end[:-1]), mask])
                end = torch.masked_fill(mask, end, 1)
        if ndim == 0:
            raise RuntimeError(
                "Expected the end-of-trajectory signal to be at least 1-dimensional."
            )
        return self._end_to_start_stop(length=length, end=end)

    def _end_to_start_stop(self, end, length):
        device = None
        if self.use_gpu:
            gpu_device = self._gpu_device
            if end.device != gpu_device:
                device = end.device
                end = end.to(self._gpu_device)
        # Using transpose ensures the start and stop are sorted the same way
        stop_idx = end.transpose(0, -1).nonzero()
        stop_idx[:, [0, -1]] = stop_idx[:, [-1, 0]].clone()
        # First build the start indices as the stop + 1, we'll shift it later
        start_idx = stop_idx.clone()
        start_idx[:, 0] += 1
        start_idx[:, 0] %= end.shape[0]
        # shift start: to do this, we check when the non-first dim indices are identical
        # and get a mask like [False, True, True, False, True, ...] where False means
        # that there's a switch from one dim to another (ie, a switch from one element of the batch
        # to another). We roll this one step along the time dimension and these two
        # masks provide us with the indices of the permutation matrix we need
        # to apply to start_idx.
        if start_idx.shape[0] > 1:
            start_idx_mask = (start_idx[1:, 1:] == start_idx[:-1, 1:]).all(-1)
            m1 = torch.cat([torch.zeros_like(start_idx_mask[:1]), start_idx_mask])
            m2 = torch.cat([start_idx_mask, torch.zeros_like(start_idx_mask[:1])])
            start_idx_replace = torch.empty_like(start_idx)
            start_idx_replace[m1] = start_idx[m2]
            start_idx_replace[~m1] = start_idx[~m2]
            start_idx = start_idx_replace
        else:
            # In this case we have only one start and stop has already been set
            pass
        lengths = stop_idx[:, 0] - start_idx[:, 0] + 1
        lengths[lengths <= 0] = lengths[lengths <= 0] + length
        if device is not None:
            return start_idx.to(device), stop_idx.to(device), lengths.to(device)
        return start_idx, stop_idx, lengths
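
    # Worked example (illustrative): for a 1d end signal [0, 0, 1, 0, 1] with
    # length=5, the stops are [2, 4], the starts resolve to [0, 3] and the
    # lengths to [3, 2]. A trailing False combined with at_capacity=True would
    # instead mean the last trajectory wraps around to the start of the storage.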

    def _start_to_end(self, st: torch.Tensor, length: int):

        arange = torch.arange(length, device=st.device, dtype=st.dtype)
        ndims = st.shape[-1] - 1 if st.ndim else 0
        if ndims:
            arange = torch.stack([arange] + [torch.zeros_like(arange)] * ndims, -1)
        else:
            arange = arange.unsqueeze(-1)
        if st.shape != arange.shape:
            # we do this to make sure that we're not broadcasting the start
            # wrong as a tensor with shape [N] can't be expanded to [N, 1]
            # without getting an error
            st = st.expand_as(arange)
        return arange + st

    def _tensor_slices_from_startend(self, seq_length, start, storage_length):
        # start is a 2d tensor resulting from nonzero()
        # seq_length is a 1d tensor indicating the desired length of each sequence

        if isinstance(seq_length, int):
            arange = torch.arange(seq_length, device=start.device, dtype=start.dtype)
            ndims = start.shape[-1] - 1 if (start.ndim - 1) else 0
            if ndims:
                arange_reshaped = torch.empty(
                    arange.shape + torch.Size([ndims + 1]),
                    device=start.device,
                    dtype=start.dtype,
                )
                arange_reshaped[..., 0] = arange
                arange_reshaped[..., 1:] = 0
            else:
                arange_reshaped = arange.unsqueeze(-1)
            arange_expanded = arange_reshaped.expand(
                torch.Size([start.shape[0]]) + arange_reshaped.shape
            )
            if start.shape != arange_expanded.shape:
                n_missing_dims = arange_expanded.dim() - start.dim()
                start_expanded = start[
                    (slice(None),) + (None,) * n_missing_dims
                ].expand_as(arange_expanded)
            result = (start_expanded + arange_expanded).flatten(0, 1)

        else:
            # when padding is needed
            result = torch.cat(
                [
                    self._start_to_end(_start, _seq_len)
                    for _start, _seq_len in zip(start, seq_length)
                ]
            )
        result[:, 0] = result[:, 0] % storage_length
        return result
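
    # Worked example (illustrative): with seq_length=3, start=[[6]] and
    # storage_length=8, the flattened result is [[6], [7], [0]] -- the modulo on
    # the time column lets a slice wrap around a circular storage.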

    def _get_stop_and_length(self, storage, fallback=True):
        if self.cache_values and "stop-and-length" in self._cache:
            return self._cache.get("stop-and-length")

        if self._fetch_traj:
            # We first try with the traj_key
            try:
                if isinstance(storage, TensorStorage):
                    trajectory = storage[:][self._used_traj_key]
                else:
                    try:
                        trajectory = storage[:][self.traj_key]
                    except Exception:
                        raise RuntimeError(
                            "Could not get a tensordict out of the storage, which is required for SliceSampler to compute the trajectories."
                        )
                vals = self._find_start_stop_traj(
                    trajectory=trajectory,
                    at_capacity=storage._is_full,
                    cursor=getattr(storage, "_last_cursor", None),
                )
                if self.cache_values:
                    self._cache["stop-and-length"] = vals
                return vals
            except KeyError:
                if fallback:
                    self._fetch_traj = False
                    return self._get_stop_and_length(storage, fallback=False)
                raise

        else:
            try:
                try:
                    done = storage[:][self.end_key]
                except Exception:
                    raise RuntimeError(
                        "Could not get a tensordict out of the storage, which is required for SliceSampler to compute the trajectories."
                    )
                vals = self._find_start_stop_traj(
                    end=done.squeeze()[: len(storage)],
                    at_capacity=storage._is_full,
                    cursor=getattr(storage, "_last_cursor", None),
                )
                if self.cache_values:
                    self._cache["stop-and-length"] = vals
                return vals
            except KeyError:
                if fallback:
                    self._fetch_traj = True
                    return self._get_stop_and_length(storage, fallback=False)
                raise

    def _adjusted_batch_size(self, batch_size):
        if self.num_slices is not None:
            if batch_size % self.num_slices != 0:
                raise RuntimeError(
                    f"The batch-size must be divisible by the number of slices, got "
                    f"batch_size={batch_size} and num_slices={self.num_slices}."
                )
            seq_length = batch_size // self.num_slices
            num_slices = self.num_slices
        else:
            if batch_size % self.slice_len != 0:
                raise RuntimeError(
                    f"The batch-size must be divisible by the slice length, got "
                    f"batch_size={batch_size} and slice_len={self.slice_len}."
                )
            seq_length = self.slice_len
            num_slices = batch_size // self.slice_len
        return seq_length, num_slices
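
    # Illustrative: batch_size=320 with num_slices=10 gives seq_length=32, and
    # batch_size=320 with slice_len=32 gives num_slices=10. Both paths enforce
    # divisibility so the sampled slices tile the batch exactly.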

    def sample(self, storage: Storage, batch_size: int) -> tuple[torch.Tensor, dict]:
        if self._batch_size_multiplier is not None:
            batch_size = batch_size * self._batch_size_multiplier
        # pick up as many trajs as we need
        start_idx, stop_idx, lengths = self._get_stop_and_length(storage)
        # we have to make sure that the number of dims of the storage
        # is the same as the stop/start signals since we will
        # use these to sample the storage
        if start_idx.shape[1] != storage.ndim:
            raise RuntimeError(
                f"Expected the end-of-trajectory signal to be "
                f"{storage.ndim}-dimensional. Got a tensor with shape[1]={start_idx.shape[1]} "
                "instead."
            )
        seq_length, num_slices = self._adjusted_batch_size(batch_size)
        storage_length = storage.shape[0]
        return self._sample_slices(
            lengths,
            start_idx,
            stop_idx,
            seq_length,
            num_slices,
            storage_length=storage_length,
            storage=storage,
        )
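
    # In short: sample() resolves trajectory boundaries once (possibly from the
    # cache), then delegates per-trajectory slice selection to _sample_slices
    # and the final index arithmetic to _get_index.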

    def _sample_slices(
        self,
        lengths: torch.Tensor,
        start_idx: torch.Tensor,
        stop_idx: torch.Tensor,
        seq_length: int,
        num_slices: int,
        storage_length: int,
        traj_idx: torch.Tensor | None = None,
        *,
        storage,
    ) -> tuple[tuple[torch.Tensor, ...], dict[str, Any]]:
        # start_idx and stop_idx are 2d tensors organized like a non-zero

        def get_traj_idx(maxval):
            return torch.randint(
                maxval, (num_slices,), device=lengths.device, generator=self._rng
            )

        if (lengths < seq_length).any():
            if self.strict_length:
                idx = lengths >= seq_length
                if not idx.any():
                    raise RuntimeError(
                        f"Did not find a single trajectory with sufficient length (length range: {lengths.min()} - {lengths.max()} / required={seq_length}))."
                    )
                if (
                    isinstance(seq_length, torch.Tensor)
                    and seq_length.shape == lengths.shape
                ):
                    seq_length = seq_length[idx]
                lengths_idx = lengths[idx]
                start_idx = start_idx[idx]
                stop_idx = stop_idx[idx]

                if traj_idx is None:
                    traj_idx = get_traj_idx(lengths_idx.shape[0])
                else:
                    # Here we must filter out the indices that correspond to trajectories
                    # we don't want to keep. That could potentially lead to an empty sample.
                    # The difficulty with this adjustment is that traj_idx points to a full
                    # sequences of lengths, but we filter out part of it so we must
                    # convert traj_idx to a boolean mask, index this mask with the
                    # valid indices and then recover the nonzero.
                    idx_mask = torch.zeros_like(idx)
                    idx_mask[traj_idx] = True
                    traj_idx = idx_mask[idx].nonzero().squeeze(-1)
                    if not traj_idx.numel():
                        raise RuntimeError(
                            "None of the provided indices pointed to a trajectory of "
                            "sufficient length. Consider using strict_length=False for the "
                            "sampler instead."
                        )
                    num_slices = traj_idx.shape[0]

                del idx
                lengths = lengths_idx
            else:
                if traj_idx is None:
                    traj_idx = get_traj_idx(lengths.shape[0])
                else:
                    num_slices = traj_idx.shape[0]

                # make seq_length a tensor with values clamped by lengths
                seq_length = lengths[traj_idx].clamp_max(seq_length)
        else:
            if traj_idx is None:
                traj_idx = get_traj_idx(lengths.shape[0])
            else:
                num_slices = traj_idx.shape[0]
        return self._get_index(
            lengths=lengths,
            start_idx=start_idx,
            stop_idx=stop_idx,
            num_slices=num_slices,
            seq_length=seq_length,
            storage_length=storage_length,
            traj_idx=traj_idx,
            storage=storage,
        )

    def _get_index(
        self,
        lengths: torch.Tensor,
        start_idx: torch.Tensor,
        stop_idx: torch.Tensor,
        seq_length: int,
        num_slices: int,
        storage_length: int,
        traj_idx: torch.Tensor | None = None,
        *,
        storage,
    ) -> tuple[torch.Tensor, dict]:
        # end_point is the last possible index for start
        last_indexable_start = lengths[traj_idx] - seq_length + 1
        if not self.span[1]:
            end_point = last_indexable_start
        elif self.span[1] is True:
            end_point = lengths[traj_idx] + 1
        else:
            span_left = self.span[1]
            if span_left >= seq_length:
                raise ValueError(
                    "The right and left span must be strictly lower than the sequence length"
                )
            end_point = lengths[traj_idx] - span_left

        if not self.span[0]:
            start_point = 0
        elif self.span[0] is True:
            start_point = -seq_length + 1
        else:
            span_right = self.span[0]
            if span_right >= seq_length:
                raise ValueError(
                    "The right and left span must be strictly lower than the sequence length"
                )
            start_point = -span_right

        relative_starts = (
            torch.rand(num_slices, device=lengths.device, generator=self._rng)
            * (end_point - start_point)
        ).floor().to(start_idx.dtype) + start_point
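
        # relative_starts is drawn uniformly over [start_point, end_point): for
        # example, with a trajectory of length 10, seq_length=4 and no span,
        # start_point=0 and end_point=7, so starts land in 0..6 (illustrative).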

        if self.span[0]:
            out_of_traj = relative_starts < 0
            if out_of_traj.any():
                # a negative start means sampling fewer elements
                seq_length = torch.where(
                    ~out_of_traj, seq_length, seq_length + relative_starts
                )
                relative_starts = torch.where(~out_of_traj, relative_starts, 0)
        if self.span[1]:
            out_of_traj = relative_starts + seq_length > lengths[traj_idx]
            if out_of_traj.any():
                # overshooting the trajectory end likewise means sampling fewer elements
                seq_length = torch.minimum(
                    seq_length, lengths[traj_idx] - relative_starts
                )

        starts = torch.cat(
            [
                (start_idx[traj_idx, 0] + relative_starts).unsqueeze(1),
                start_idx[traj_idx, 1:],
            ],
            1,
        )
        index = self._tensor_slices_from_startend(seq_length, starts, storage_length)
        if self.truncated_key is not None:
            truncated_key = self.truncated_key
            done_key = _replace_last(truncated_key, "done")
            terminated_key = _replace_last(truncated_key, "terminated")

            truncated = torch.zeros(
                (index.shape[0], 1), dtype=torch.bool, device=index.device
            )
            if isinstance(seq_length, int):
                truncated.view(num_slices, -1)[:, -1] = 1
            else:
                truncated[seq_length.cumsum(0) - 1] = 1
            index = index.to(torch.long).unbind(-1)
            st_index = storage[index]
            done = st_index.get(done_key, default=None)
            if done is None:
                done = truncated.clone()
            else:
                done = done | truncated
            terminated = st_index.get(terminated_key, default=None)
            if terminated is None:
                terminated = torch.zeros_like(truncated)
            return index, {
                truncated_key: truncated,
                done_key: done,
                terminated_key: terminated,
            }
        index = index.to(torch.long).unbind(-1)
        return index, {}

    @property
    def _used_traj_key(self):
        return self.__dict__.get("__used_traj_key", self.traj_key)

    @_used_traj_key.setter
    def _used_traj_key(self, value):
        self.__dict__["__used_traj_key"] = value

    @property
    def _used_end_key(self):
        return self.__dict__.get("__used_end_key", self.end_key)

    @_used_end_key.setter
    def _used_end_key(self, value):
        self.__dict__["__used_end_key"] = value

    def _empty(self):
        pass

    def dumps(self, path):
        # no op - cache does not need to be saved
        ...

    def loads(self, path):
        # no op
        ...

    def state_dict(self) -> dict[str, Any]:
        return {}

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        ...


class SliceSamplerWithoutReplacement(SliceSampler, SamplerWithoutReplacement):
    """Samples slices of data along the first dimension, given start and stop signals, without replacement.

    In this context, ``without replacement`` means that the same element (NOT trajectory) will not be sampled twice
    before the counter is automatically reset. Within a single sample, however, only one slice of a given trajectory
    will appear (see example below).

    This class is to be used with static replay buffers or in between two
    replay buffer extensions. Extending the replay buffer will reset the
    sampler, and continuous sampling without replacement is currently not
    allowed.

    .. note:: `SliceSamplerWithoutReplacement` can be slow to retrieve the trajectory indices. To accelerate
        its execution, prefer using `end_key` over `traj_key`, and consider the following
        keyword arguments: :attr:`compile`, :attr:`cache_values` and :attr:`use_gpu`.

    Keyword Args:
        drop_last (bool, optional): if ``True``, the last incomplete sample (if any) will be dropped.
            If ``False``, this last sample will be kept.
            Defaults to ``False``.
        num_slices (int): the number of slices to be sampled. The batch-size
            must be greater or equal to the ``num_slices`` argument. Exclusive
            with ``slice_len``.
        slice_len (int): the length of the slices to be sampled. The batch-size
            must be greater or equal to the ``slice_len`` argument and divisible
            by it. Exclusive with ``num_slices``.
        end_key (NestedKey, optional): the key indicating the end of a
            trajectory (or episode). Defaults to ``("next", "done")``.
        traj_key (NestedKey, optional): the key indicating the trajectories.
            Defaults to ``"episode"`` (commonly used across datasets in TorchRL).
        ends (torch.Tensor, optional): a 1d boolean tensor containing the end of run signals.
            To be used whenever the ``end_key`` or ``traj_key`` is expensive to get,
            or when this signal is readily available. Must be used with ``cache_values=True``
            and cannot be used in conjunction with ``end_key`` or ``traj_key``.
        trajectories (torch.Tensor, optional): a 1d integer tensor containing the run ids.
            To be used whenever the ``end_key`` or ``traj_key`` is expensive to get,
            or when this signal is readily available. Must be used with ``cache_values=True``
            and cannot be used in conjunction with ``end_key`` or ``traj_key``.
        truncated_key (NestedKey, optional): If not ``None``, this argument
            indicates where a truncated signal should be written in the output
            data. This is used to indicate to value estimators where the provided
            trajectory breaks. Defaults to ``("next", "truncated")``.
            This feature only works with :class:`~torchrl.data.replay_buffers.TensorDictReplayBuffer`
            instances (otherwise the truncated key is returned in the info dictionary
            returned by the :meth:`~torchrl.data.replay_buffers.ReplayBuffer.sample` method).
        strict_length (bool, optional): if ``False``, trajectories of length
            shorter than `slice_len` (or `batch_size // num_slices`) will be
            allowed to appear in the batch. If ``True``, trajectories shorter
            than required will be filtered out.
            Be mindful that this can result in an effective `batch_size` shorter
            than the one asked for! Trajectories can be split using
            :func:`~torchrl.collectors.split_trajectories`. Defaults to ``True``.
        shuffle (bool, optional): if ``False``, the order of the trajectories
            is not shuffled. Defaults to ``True``.
        compile (bool or dict of kwargs, optional): if ``True``, the bottleneck of
            the :meth:`~sample` method will be compiled with :func:`~torch.compile`.
            Keyword arguments can also be passed to torch.compile with this arg.
            Defaults to ``False``.
        use_gpu (bool or torch.device): if ``True`` (or if a device is passed), an accelerator
            will be used to retrieve the indices of the trajectory starts. This can significantly
            accelerate the sampling when the buffer content is large.
            Defaults to ``False``.

    .. note:: To recover the trajectory splits in the storage,
        :class:`~torchrl.data.replay_buffers.samplers.SliceSamplerWithoutReplacement` will first
        attempt to find the ``traj_key`` entry in the storage. If it cannot be
        found, the ``end_key`` will be used to reconstruct the episodes.

    Examples:
        >>> import torch
        >>> from tensordict import TensorDict
        >>> from torchrl.data.replay_buffers import LazyMemmapStorage, TensorDictReplayBuffer
        >>> from torchrl.data.replay_buffers.samplers import SliceSamplerWithoutReplacement
        >>>
        >>> rb = TensorDictReplayBuffer(
        ...     storage=LazyMemmapStorage(1000),
        ...     # asking for 10 slices for a total of 320 elements, ie, 10 trajectories of 32 transitions each
        ...     sampler=SliceSamplerWithoutReplacement(num_slices=10),
        ...     batch_size=320,
        ... )
        >>> episode = torch.zeros(1000, dtype=torch.int)
        >>> episode[:300] = 1
        >>> episode[300:550] = 2
        >>> episode[550:700] = 3
        >>> episode[700:] = 4
        >>> data = TensorDict(
        ...     {
        ...         "episode": episode,
        ...         "obs": torch.randn((3, 4, 5)).expand(1000, 3, 4, 5),
        ...         "act": torch.randn((20,)).expand(1000, 20),
        ...         "other": torch.randn((20, 50)).expand(1000, 20, 50),
        ...     }, [1000]
        ... )
        >>> rb.extend(data)
        >>> sample = rb.sample()
        >>> # since we want trajectories of 32 transitions but there are only 4 episodes to
        >>> # sample from, we only get 4 x 32 = 128 transitions in this batch
        >>> print("sample:", sample)
        >>> print("trajectories in sample", sample.get("episode").unique())

    :class:`~torchrl.data.replay_buffers.SliceSamplerWithoutReplacement` is default-compatible with
    most of TorchRL's datasets, and allows users to consume datasets in a dataloader-like fashion:

    Examples:
        >>> import torch
        >>>
        >>> from torchrl.data.datasets import RobosetExperienceReplay
        >>> from torchrl.data import SliceSamplerWithoutReplacement
        >>>
        >>> torch.manual_seed(0)
        >>> num_slices = 10
        >>> dataid = list(RobosetExperienceReplay.available_datasets)[0]
        >>> data = RobosetExperienceReplay(dataid, batch_size=320,
        ...     sampler=SliceSamplerWithoutReplacement(num_slices=num_slices))
        >>> # the last sample is kept, since drop_last=False by default
        >>> for i, batch in enumerate(data):
        ...     print(batch.get("episode").unique())
        tensor([ 5,  6,  8, 11, 12, 14, 16, 17, 19, 24])
        tensor([ 1,  2,  7,  9, 10, 13, 15, 18, 21, 22])
        tensor([ 0,  3,  4, 20, 23])

    When requesting a large total number of samples with few trajectories and small span, the batch will contain
    only at most one sample of each trajectory:

    Examples:
        >>> import torch
        >>> from tensordict import TensorDict
        >>> from torchrl.collectors.utils import split_trajectories
        >>> from torchrl.data import ReplayBuffer, LazyTensorStorage, SliceSampler, SliceSamplerWithoutReplacement
        >>>
        >>> rb = ReplayBuffer(storage=LazyTensorStorage(max_size=1000),
        ...     sampler=SliceSamplerWithoutReplacement(
        ...         slice_len=5, traj_key="episode", strict_length=False
        ...     ))
        ...
        >>> ep_1 = TensorDict(
        ...     {"obs": torch.arange(100),
        ...     "episode": torch.zeros(100),},
        ...     batch_size=[100]
        ... )
        >>> ep_2 = TensorDict(
        ...     {"obs": torch.arange(51),
        ...     "episode": torch.ones(51),},
        ...     batch_size=[51]
        ... )
        >>> rb.extend(ep_1)
        >>> rb.extend(ep_2)
        >>>
        >>> s = rb.sample(50)
        >>> t = split_trajectories(s, trajectory_key="episode")
        >>> print(t["obs"])
        tensor([[14, 15, 16, 17, 18],
                [ 3,  4,  5,  6,  7]])
        >>> print(t["episode"])
        tensor([[0., 0., 0., 0., 0.],
                [1., 1., 1., 1., 1.]])
        >>>
        >>> s = rb.sample(50)
        >>> t = split_trajectories(s, trajectory_key="episode")
        >>> print(t["obs"])
        tensor([[ 4,  5,  6,  7,  8],
                [26, 27, 28, 29, 30]])
        >>> print(t["episode"])
        tensor([[0., 0., 0., 0., 0.],
                [1., 1., 1., 1., 1.]])

    """

    def __init__(
        self,
        *,
        num_slices: int | None = None,
        slice_len: int | None = None,
        drop_last: bool = False,
        end_key: NestedKey | None = None,
        traj_key: NestedKey | None = None,
        ends: torch.Tensor | None = None,
        trajectories: torch.Tensor | None = None,
        truncated_key: NestedKey | None = ("next", "truncated"),
        strict_length: bool = True,
        shuffle: bool = True,
        compile: bool | dict = False,
        use_gpu: bool | torch.device = False,
    ):
        SliceSampler.__init__(
            self,
            num_slices=num_slices,
            slice_len=slice_len,
            end_key=end_key,
            traj_key=traj_key,
            cache_values=True,
            truncated_key=truncated_key,
            strict_length=strict_length,
            ends=ends,
            trajectories=trajectories,
            compile=compile,
            use_gpu=use_gpu,
        )
        SamplerWithoutReplacement.__init__(self, drop_last=drop_last, shuffle=shuffle)

    def __repr__(self):
        if self._sample_list is not None:
            perc = len(self._sample_list) / self.len_storage * 100
        else:
            perc = 0
        return (
            f"{self.__class__.__name__}("
            f"num_slices={self.num_slices}, "
            f"slice_len={self.slice_len}, "
            f"end_key={self.end_key}, "
            f"traj_key={self.traj_key}, "
            f"truncated_key={self.truncated_key}, "
            f"strict_length={self.strict_length},"
            f"{perc}% sampled)"
        )

    def _empty(self):
        self._cache = {}
        SamplerWithoutReplacement._empty(self)

    def _storage_len(self, storage):
        return self._storage_len_buffer
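
    # Note: for this sampler the "length" seen by SamplerWithoutReplacement is
    # the number of trajectory start points (see sample() below), not the number
    # of stored elements, so the without-replacement counter runs over trajectories.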

    def sample(
        self, storage: Storage, batch_size: int
    ) -> tuple[tuple[torch.Tensor, ...], dict]:
        if self._batch_size_multiplier is not None:
            batch_size = batch_size * self._batch_size_multiplier
        start_idx, stop_idx, lengths = self._get_stop_and_length(storage)
        # we have to make sure that the number of dims of the storage
        # is the same as the stop/start signals since we will
        # use these to sample the storage
        if start_idx.shape[1] != storage.ndim:
            raise RuntimeError(
                f"Expected the end-of-trajectory signal to be "
                f"{storage.ndim}-dimensional. Got a {start_idx.shape[1]} tensor "
                "instead."
            )
        self._storage_len_buffer = len(start_idx)
        # first get indices of the trajectories we want to retrieve
        seq_length, num_slices = self._adjusted_batch_size(batch_size)
        indices, _ = SamplerWithoutReplacement.sample(self, storage, num_slices)
        storage_length = storage.shape[0]

        # traj_idx will either be a single tensor or a tuple that can be reorganized
        # like a non-zero through stacking.
        def tuple_to_tensor(traj_idx, lengths=lengths):
            if isinstance(traj_idx, tuple):
                traj_idx = torch.arange(len(storage), device=lengths.device).view(
                    storage.shape
                )[traj_idx]
            return traj_idx
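
        # Illustrative: for a 2d storage of shape (3, 4), a (row, col) index
        # tuple maps to flat indices row * 4 + col, which is the layout that
        # _sample_slices expects.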

        idx, info = self._sample_slices(
            lengths,
            start_idx,
            stop_idx,
            seq_length,
            num_slices,
            storage_length,
            traj_idx=tuple_to_tensor(indices),
            storage=storage,
        )
        return idx, info

    def state_dict(self) -> dict[str, Any]:
        return SamplerWithoutReplacement.state_dict(self)

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        return SamplerWithoutReplacement.load_state_dict(self, state_dict)


class PrioritizedSliceSampler(SliceSampler, PrioritizedSampler):
    r"""Samples slices of data along the first dimension, given start and stop signals, using prioritized sampling.

    This class combines trajectory sampling with Prioritized Experience Replay (PER) as presented in
    "Schaul, T.; Quan, J.; Antonoglou, I.; and Silver, D. 2015. Prioritized experience replay."
    (https://arxiv.org/abs/1511.05952)

    **Core Idea**: Instead of sampling trajectory slices uniformly, this sampler prioritizes
    trajectory start points based on the importance of the transitions at those positions.
    This allows focusing learning on the most informative parts of trajectories.

    **How it works**:
    1. Each transition is assigned a priority based on its TD error: :math:`p_i = |\delta_i| + \epsilon`
    2. Trajectory start points are sampled with probability: :math:`P(i) = \frac{p_i^\alpha}{\sum_j p_j^\alpha}`
    3. Importance sampling weights correct for bias: :math:`w_i = (N \cdot P(i))^{-\beta}`
    4. Complete trajectory slices are extracted from the sampled start points
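
    As a sketch of the weighting math above (plain tensors, independent of the
    sampler API):

        >>> import torch
        >>> alpha, beta, eps = 0.7, 0.9, 1e-8
        >>> p = torch.tensor([0.1, 0.9]) + eps      # p_i = |delta_i| + eps
        >>> P = p.pow(alpha) / p.pow(alpha).sum()   # sampling probabilities
        >>> w = (len(p) * P).pow(-beta)             # importance-sampling weights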

    For more info see :class:`~torchrl.data.replay_buffers.samplers.SliceSampler` and :class:`~torchrl.data.replay_buffers.samplers.PrioritizedSampler`.

    .. warning:: PrioritizedSliceSampler will look at the priorities of the individual transitions and sample the
        start points accordingly. This means that transitions with a low priority may as well appear in the
        samples if they follow another of higher priority, and transitions with a high priority but closer to the
        end of a trajectory may never be sampled if they cannot be used as start points.
        Currently, it is the user's responsibility to aggregate priorities across the items of a trajectory using
        :meth:`update_priority`.
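
    A sketch of one such aggregation, assuming a replay buffer ``rb`` equipped
    with this sampler, and that ``index``, ``td_error`` and ``num_slices`` come
    from the training step (hypothetical names): each slice's maximum TD error
    is broadcast to all of its transitions before the update.

        >>> per_slice = td_error.reshape(num_slices, -1)
        >>> per_slice = per_slice.max(dim=1, keepdim=True).values.expand_as(per_slice)
        >>> rb.update_priority(index, priority=per_slice.reshape(-1))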
|
|
1942
|
+
Args:
|
|
1943
|
+
max_capacity (int): maximum capacity of the buffer.
|
|
1944
|
+
alpha (:obj:`float`): exponent :math:`\alpha` determines how much prioritization is used.
|
|
1945
|
+
- :math:`\alpha = 0`: uniform sampling of trajectory start points
|
|
1946
|
+
- :math:`\alpha = 1`: full prioritization based on TD error magnitude at start points
|
|
1947
|
+
- Typical values: 0.4-0.7 for balanced prioritization
|
|
1948
|
+
- Higher :math:`\alpha` means more aggressive prioritization of high-error trajectory regions
|
|
1949
|
+
beta (:obj:`float`): importance sampling negative exponent :math:`\beta`.
|
|
1950
|
+
- :math:`\beta` controls the correction for the bias introduced by prioritization
|
|
1951
|
+
- :math:`\beta = 0`: no correction (biased towards high-priority trajectory regions)
|
|
1952
|
+
- :math:`\beta = 1`: full correction (unbiased but potentially unstable)
|
|
1953
|
+
- Typical values: start at 0.4-0.6 and anneal to 1.0 during training
|
|
1954
|
+
- Lower :math:`\beta` early in training provides stability, higher :math:`\beta` later reduces bias
|
|
1955
|
+
eps (:obj:`float`, optional): small constant added to priorities to ensure
|
|
1956
|
+
no transition has zero priority. This prevents trajectory regions from never
|
|
1957
|
+
being sampled. Defaults to 1e-8.
|
|
1958
|
+
reduction (str, optional): the reduction method for multidimensional
|
|
1959
|
+
tensordicts (i.e., stored trajectory). Can be one of "max", "min",
|
|
1960
|
+
"median" or "mean".
|
|
1961
|
+
|
|
1962
|
+
**Parameter Guidelines**:
|
|
1963
|
+
|
|
1964
|
+
- **:math:`\alpha` (alpha)**: Controls how much to prioritize high-error trajectory regions.
|
|
1965
|
+
0.4-0.7: Good balance between learning speed and stability.
|
|
1966
|
+
1.0: Maximum prioritization (may be unstable).
|
|
1967
|
+
0.0: Uniform sampling (no prioritization benefit).
|
|
1968
|
+
|
|
1969
|
+
- **:math:`\beta` (beta)**: Controls importance sampling correction.
|
|
1970
|
+
Start at 0.4-0.6 for training stability.
|
|
1971
|
+
Anneal to 1.0 over training to reduce bias.
|
|
1972
|
+
Lower values = more stable but biased.
|
|
1973
|
+
Higher values = less biased but potentially unstable.
|
|
1974
|
+
|
|
1975
|
+
- **:math:`\\epsilon`**: Small constant to prevent zero priorities.
|
|
1976
|
+
1e-8: Good default value.
|
|
1977
|
+
Too small: may cause numerical issues.
|
|
1978
|
+
Too large: reduces prioritization effect.
|
|
1979
|
+
|
|
1980
|
+
Keyword Args:
|
|
1981
|
+
num_slices (int): the number of slices to be sampled. The batch-size
|
|
1982
|
+
must be greater or equal to the ``num_slices`` argument. Exclusive
|
|
1983
|
+
with ``slice_len``.
|
|
1984
|
+
slice_len (int): the length of the slices to be sampled. The batch-size
|
|
1985
|
+
must be greater or equal to the ``slice_len`` argument and divisible
|
|
1986
|
+
by it. Exclusive with ``num_slices``.
|
|
1987
|
+
end_key (NestedKey, optional): the key indicating the end of a
|
|
1988
|
+
trajectory (or episode). Defaults to ``("next", "done")``.
|
|
1989
|
+
traj_key (NestedKey, optional): the key indicating the trajectories.
|
|
1990
|
+
Defaults to ``"episode"`` (commonly used across datasets in TorchRL).
|
|
1991
|
+
ends (torch.Tensor, optional): a 1d boolean tensor containing the end of run signals.
|
|
1992
|
+
To be used whenever the ``end_key`` or ``traj_key`` is expensive to get,
|
|
1993
|
+
or when this signal is readily available. Must be used with ``cache_values=True``
|
|
1994
|
+
and cannot be used in conjunction with ``end_key`` or ``traj_key``.
|
|
1995
|
+
trajectories (torch.Tensor, optional): a 1d integer tensor containing the run ids.
|
|
1996
|
+
To be used whenever the ``end_key`` or ``traj_key`` is expensive to get,
|
|
1997
|
+
or when this signal is readily available. Must be used with ``cache_values=True``
|
|
1998
|
+
and cannot be used in conjunction with ``end_key`` or ``traj_key``.
|
|
1999
|
+
cache_values (bool, optional): to be used with static datasets.
|
|
2000
|
+
Will cache the start and end signal of the trajectory. This can be safely used even
|
|
2001
|
+
if the trajectory indices change during calls to :class:`~torchrl.data.ReplayBuffer.extend`
|
|
2002
|
+
as this operation will erase the cache.
|
|
2003
|
+
|
|
2004
|
+
.. warning:: ``cache_values=True`` will not work if the sampler is used with a
|
|
2005
|
+
storage that is extended by another buffer. For instance:
|
|
2006
|
+
|
|
2007
|
+
>>> buffer0 = ReplayBuffer(storage=storage,
|
|
2008
|
+
... sampler=SliceSampler(num_slices=8, cache_values=True),
|
|
2009
|
+
... writer=ImmutableWriter())
|
|
2010
|
+
>>> buffer1 = ReplayBuffer(storage=storage,
|
|
2011
|
+
... sampler=other_sampler)
|
|
2012
|
+
>>> # Wrong! Does not erase the buffer from the sampler of buffer0
|
|
2013
|
+
>>> buffer1.extend(data)
|
|
2014
|
+
|
|
2015
|
+
.. warning:: ``cache_values=True`` will not work as expected if the buffer is
|
|
2016
|
+
shared between processes and one process is responsible for writing
|
|
2017
|
+
and one process for sampling, as erasing the cache can only be done locally.
|
|
2018
|
+
|
|
2019
|
+
        truncated_key (NestedKey, optional): If not ``None``, this argument
            indicates where a truncated signal should be written in the output
            data. This is used to indicate to value estimators where the provided
            trajectory breaks. Defaults to ``("next", "truncated")``.
            This feature only works with :class:`~torchrl.data.replay_buffers.TensorDictReplayBuffer`
            instances (otherwise the truncated key is returned in the info dictionary
            returned by the :meth:`~torchrl.data.replay_buffers.ReplayBuffer.sample` method).
        strict_length (bool, optional): if ``False``, trajectories of length
            shorter than `slice_len` (or `batch_size // num_slices`) will be
            allowed to appear in the batch. If ``True``, trajectories shorter
            than required will be filtered out.
            Be mindful that this can result in an effective `batch_size` shorter
            than the one asked for! Trajectories can be split using
            :func:`~torchrl.collectors.split_trajectories`. Defaults to ``True``.
        compile (bool or dict of kwargs, optional): if ``True``, the bottleneck of
            the :meth:`~sample` method will be compiled with :func:`~torch.compile`.
            Keyword arguments can also be passed to torch.compile with this arg.
            Defaults to ``False``.
        span (bool, int, Tuple[bool | int, bool | int], optional): if provided, the sampled
            trajectory will span across the left and/or the right. This means that possibly
            fewer elements will be provided than what was required. A boolean value means
            that at least one element will be sampled per trajectory. An integer `i` means
            that at least `slice_len - i` samples will be gathered for each sampled trajectory.
            Using tuples allows fine-grained control over the span on the left (beginning
            of the stored trajectory) and on the right (end of the stored trajectory).
        max_priority_within_buffer (bool, optional): if ``True``, the max-priority
            is tracked within the buffer. When ``False``, the max-priority tracks
            the maximum value since the instantiation of the sampler.
            Defaults to ``False``.

    Examples:
        >>> import torch
        >>> from torchrl.data.replay_buffers import TensorDictReplayBuffer, LazyMemmapStorage, PrioritizedSliceSampler
        >>> from tensordict import TensorDict
        >>> sampler = PrioritizedSliceSampler(max_capacity=9, num_slices=3, alpha=0.7, beta=0.9)
        >>> rb = TensorDictReplayBuffer(storage=LazyMemmapStorage(9), sampler=sampler, batch_size=6)
        >>> data = TensorDict(
        ...     {
        ...         "observation": torch.randn(9, 16),
        ...         "action": torch.randn(9, 1),
        ...         "episode": torch.tensor([0, 0, 0, 1, 1, 1, 2, 2, 2], dtype=torch.long),
        ...         "steps": torch.tensor([0, 1, 2, 0, 1, 2, 0, 1, 2], dtype=torch.long),
        ...         ("next", "observation"): torch.randn(9, 16),
        ...         ("next", "reward"): torch.randn(9, 1),
        ...         ("next", "done"): torch.tensor([0, 0, 1, 0, 0, 1, 0, 0, 1], dtype=torch.bool).unsqueeze(1),
        ...     },
        ...     batch_size=[9],
        ... )
        >>> rb.extend(data)
        >>> sample, info = rb.sample(return_info=True)
        >>> print("episode", sample["episode"].tolist())
        episode [2, 2, 2, 2, 1, 1]
        >>> print("steps", sample["steps"].tolist())
        steps [1, 2, 0, 1, 1, 2]
        >>> print("weight", info["priority_weight"].tolist())
        weight [1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
        >>> priority = torch.tensor([0, 3, 3, 0, 0, 0, 1, 1, 1])
        >>> rb.update_priority(torch.arange(0, 9, 1), priority=priority)
        >>> sample, info = rb.sample(return_info=True)
        >>> print("episode", sample["episode"].tolist())
        episode [2, 2, 2, 2, 2, 2]
        >>> print("steps", sample["steps"].tolist())
        steps [1, 2, 0, 1, 0, 1]
        >>> print("weight", info["priority_weight"].tolist())
        weight [9.120110917137936e-06, 9.120110917137936e-06, 9.120110917137936e-06, 9.120110917137936e-06, 9.120110917137936e-06, 9.120110917137936e-06]
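
    As a minimal additional sketch (not part of the example above; the trajectory
    tensor below is an assumption matching the 9-element storage), the run ids can
    also be passed directly through ``trajectories`` together with ``cache_values=True``,
    bypassing the ``traj_key`` lookup:

        >>> sampler = PrioritizedSliceSampler(
        ...     max_capacity=9, num_slices=3, alpha=0.7, beta=0.9,
        ...     trajectories=torch.tensor([0, 0, 0, 1, 1, 1, 2, 2, 2]),
        ...     cache_values=True,
        ... )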
"""
|
|
2085
|
+
|
|
2086
|
+
    def __init__(
        self,
        max_capacity: int,
        alpha: float,
        beta: float,
        eps: float = 1e-8,
        dtype: torch.dtype = torch.float,
        reduction: str = "max",
        *,
        num_slices: int | None = None,
        slice_len: int | None = None,
        end_key: NestedKey | None = None,
        traj_key: NestedKey | None = None,
        ends: torch.Tensor | None = None,
        trajectories: torch.Tensor | None = None,
        cache_values: bool = False,
        truncated_key: NestedKey | None = ("next", "truncated"),
        strict_length: bool = True,
        compile: bool | dict = False,
        span: bool | int | tuple[bool | int, bool | int] = False,
        max_priority_within_buffer: bool = False,
    ):
        SliceSampler.__init__(
            self,
            num_slices=num_slices,
            slice_len=slice_len,
            end_key=end_key,
            traj_key=traj_key,
            cache_values=cache_values,
            truncated_key=truncated_key,
            strict_length=strict_length,
            ends=ends,
            trajectories=trajectories,
            compile=compile,
            span=span,
        )
        PrioritizedSampler.__init__(
            self,
            max_capacity=max_capacity,
            alpha=alpha,
            beta=beta,
            eps=eps,
            dtype=dtype,
            reduction=reduction,
            max_priority_within_buffer=max_priority_within_buffer,
        )
        if self.span[0]:
            # Left spanning is hard to achieve because we need to sample 'negative'
            # starts, but to sample the start we rely on PrioritizedSampler, which has
            # no idea it's looking at trajectories.
            #
            # Another way to go about this would be to stochastically decrease the
            # seq_length to accommodate it, but that would require over-sampling the
            # starts too.
            warnings.warn(
                f"Left spanning is disabled for {type(self).__name__} and will be automatically turned off. "
                "If this feature is required, please file an issue on the torchrl GitHub repo."
            )
            self.span = (0, self.span[1])

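    # NB: both parent initializers are invoked explicitly rather than through
    # ``super()`` because the two bases consume disjoint argument sets:
    # ``SliceSampler.__init__`` handles the slicing options, while
    # ``PrioritizedSampler.__init__`` handles the priority options.
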
    def __repr__(self):
        return (
            f"{self.__class__.__name__}("
            f"num_slices={self.num_slices}, "
            f"slice_len={self.slice_len}, "
            f"end_key={self.end_key}, "
            f"traj_key={self.traj_key}, "
            f"truncated_key={self.truncated_key}, "
            f"strict_length={self.strict_length}, "
            f"alpha={self._alpha}, "
            f"beta={self._beta}, "
            f"eps={self._eps})"
        )

    def __getstate__(self):
        state = SliceSampler.__getstate__(self)
        state.update(PrioritizedSampler.__getstate__(self))
        return state

    def mark_update(
        self, index: int | torch.Tensor, *, storage: Storage | None = None
    ) -> None:
        return PrioritizedSampler.mark_update(self, index, storage=storage)

    def _padded_indices(self, shapes, arange) -> torch.Tensor:
        # This creates a left-padded tensor with the valid indices on the right, e.g.
        # tensor([[ 0,  1,  2,  3,  4],
        #         [-1, -1,  5,  6,  7],
        #         [-1,  8,  9, 10, 11]])
        # where the -1 items on the left are padding values.
        num_groups = shapes.shape[0]
        max_group_len = shapes.max()
        pad_lengths = max_group_len - shapes

        # Get all the start and end indices within arange for each group
        group_ends = shapes.cumsum(0)
        group_starts = torch.empty_like(group_ends)
        group_starts[0] = 0
        group_starts[1:] = group_ends[:-1]
        pad = torch.empty(
            (num_groups, max_group_len), dtype=arange.dtype, device=arange.device
        )
        for pad_row, group_start, group_end, pad_len in zip(
            pad, group_starts, group_ends, pad_lengths
        ):
            pad_row[:pad_len] = -1
            pad_row[pad_len:] = arange[group_start:group_end]

        return pad

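    # A worked example of ``_padded_indices`` (values assumed, chosen to reproduce
    # the tensor shown in the comment above):
    #
    #     >>> shapes = torch.tensor([[5], [3], [4]])  # lengths.view(-1, 1)
    #     >>> arange = torch.arange(12)
    #     >>> sampler._padded_indices(shapes, arange)
    #     tensor([[ 0,  1,  2,  3,  4],
    #             [-1, -1,  5,  6,  7],
    #             [-1,  8,  9, 10, 11]])
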
    def _preceding_stop_idx(self, storage, lengths, seq_length, start_idx):
        preceding_stop_idx = self._cache.get("preceding_stop_idx")
        if preceding_stop_idx is not None:
            return preceding_stop_idx
        arange = torch.arange(storage.shape.numel())
        shapes = lengths.view(-1, 1).cpu()
        if shapes.sum() - 1 != arange[-1]:
            raise RuntimeError("Wrong shapes / arange configuration")
        if not self.strict_length:
            # First, remove the starts from the arange. We do this so that the first
            # step of every trajectory remains sampleable, even when the trajectory
            # is shorter than the required sequence length.
            all_but_starts = torch.ones(arange.shape, dtype=torch.bool)
            starts = lengths.cumsum(0)
            starts = torch.cat([torch.zeros_like(starts[:1]), starts[:-1]])
            all_but_starts[starts] = False
            arange = arange[all_but_starts]
            shapes = shapes - 1
        pad = self._padded_indices(shapes, arange)
        _, span_right = self.span[0], self.span[1]
        if span_right and isinstance(span_right, bool):
            preceding_stop_idx = pad[:, -1:]
        else:
            # Mask the rightmost values of that padded tensor
            preceding_stop_idx = pad[:, -seq_length + 1 + span_right :]
        preceding_stop_idx = preceding_stop_idx[preceding_stop_idx >= 0]
        if storage._is_full:
            preceding_stop_idx = (
                preceding_stop_idx
                + np.ravel_multi_index(
                    tuple(start_idx[0].tolist()), storage._total_shape
                )
            ) % storage._total_shape.numel()
        if self.cache_values:
            self._cache["preceding_stop_idx"] = preceding_stop_idx
        return preceding_stop_idx

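    # The sampling logic below proceeds in three steps: (1) compute the indices that
    # cannot start a slice because a trajectory stop follows within ``seq_length``
    # (``_preceding_stop_idx``); (2) zero out their priorities in the sum-tree so
    # that ``PrioritizedSampler.sample`` cannot pick them, restoring the original
    # values right after; (3) expand each sampled start into a slice of indices and
    # repeat the priority weights accordingly.
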
    def sample(self, storage: Storage, batch_size: int) -> tuple[torch.Tensor, dict]:
        # Sample `batch_size` indices representing the start of a slice.
        # The sampling is based on a weight vector.
        start_idx, stop_idx, lengths = self._get_stop_and_length(storage)
        seq_length, num_slices = self._adjusted_batch_size(batch_size)

        preceding_stop_idx = self._preceding_stop_idx(
            storage, lengths, seq_length, start_idx
        )
        if storage.ndim > 1:
            # We need to convert indices of the permuted, flattened storage to indices
            # in a flattened (but not permuted) storage. This is because the lengths
            # come as they would for a permuted storage.
            preceding_stop_idx = unravel_index(
                preceding_stop_idx, (storage.shape[-1], *storage.shape[:-1])
            )
            preceding_stop_idx = (preceding_stop_idx[-1], *preceding_stop_idx[:-1])
            preceding_stop_idx = torch.as_tensor(
                np.ravel_multi_index(preceding_stop_idx, storage.shape)
            )

        # Force the sampler to never pick an index at the end of a trajectory:
        # temporarily zero out those priorities and restore them below.
        vals = torch.tensor(self._sum_tree[preceding_stop_idx.cpu().numpy()])
        self._sum_tree[preceding_stop_idx.cpu().numpy()] = 0.0
        # No need to update self._min_tree.

        starts, info = PrioritizedSampler.sample(
            self, storage=storage, batch_size=batch_size // seq_length
        )
        self._sum_tree[preceding_stop_idx.cpu().numpy()] = vals
        # We must truncate the seq_length if (1) not strict_length or (2) span[1]
        if self.span[1] or not self.strict_length:
            if not isinstance(starts, torch.Tensor):
                starts_tensor = torch.stack(list(starts), dim=-1).to(stop_idx.device)
            else:
                starts_tensor = starts.unsqueeze(1).to(stop_idx.device)
            # Find the stop that comes after the start index:
            # say starts_tensor has shape [N, X] and stop_idx has shape [M, X],
            # then diff will have shape [M, N, X].
            stop_idx_corr = stop_idx.clone()
            stop_idx_corr[:, 0] = torch.where(
                stop_idx[:, 0] < start_idx[:, 0],
                stop_idx[:, 0] + storage._len_along_dim0,
                stop_idx[:, 0],
            )
            diff = stop_idx_corr.unsqueeze(1) - starts_tensor.unsqueeze(0)
            # Filter out all items that don't belong to the same dim in the storage
            mask = (diff[:, :, 1:] != 0).any(-1)
            diff = diff[:, :, 0]
            diff[mask] = diff.max() + 1
            diff = diff.reshape(-1, starts_tensor.shape[0])
            # We remove all negative values from consideration
            diff[diff < 0] = diff.max() + 1
            # Take the argmin along dim 0 (thereby reducing dim M)
            idx = diff.argmin(dim=0)
            stops = stop_idx_corr[idx, 0]
            # TODO: here things may not work because we could have spanning trajs,
            # though I cannot show that it breaks in the tests
            if starts_tensor.ndim > 1:
                starts_tensor = starts_tensor[:, 0]
            seq_length = (stops - starts_tensor + 1).clamp_max(seq_length)
            if (seq_length <= 0).any():
                raise RuntimeError(
                    "failed to compute seq_length, please report this bug"
                )

        if isinstance(starts, tuple):
            starts = torch.stack(starts, -1)
            # starts = torch.as_tensor(starts, device=lengths.device)
        info["priority_weight"] = torch.as_tensor(
            info["priority_weight"], device=lengths.device
        )

        # Extend the starting indices of each slice with seq_length to get the
        # indices of all steps
        index = self._tensor_slices_from_startend(
            seq_length, starts, storage_length=storage.shape[0]
        )

        # Repeat the weight of each slice to match the number of steps
        info["priority_weight"] = torch.repeat_interleave(
            info["priority_weight"], seq_length
        )

        if self.truncated_key is not None:
            # The following logic is borrowed from SliceSampler
            truncated_key = self.truncated_key

            done_key = _replace_last(truncated_key, "done")
            terminated_key = _replace_last(truncated_key, "terminated")

            truncated = torch.zeros(
                (index.shape[0], 1), dtype=torch.bool, device=index.device
            )
            if isinstance(seq_length, int):
                truncated.view(num_slices, -1)[:, -1] = 1
            else:
                truncated[seq_length.cumsum(0) - 1] = 1
            index = index.to(torch.long).unbind(-1)
            st_index = storage[index]
            try:
                done = st_index[done_key] | truncated
            except KeyError:
                done = truncated.clone()
            try:
                terminated = st_index[terminated_key]
            except KeyError:
                terminated = torch.zeros_like(truncated)
            info.update(
                {
                    truncated_key: truncated,
                    done_key: done,
                    terminated_key: terminated,
                }
            )
            return index, info
        return index.to(torch.long).unbind(-1), info

    def _empty(self):
        # no op for SliceSampler
        PrioritizedSampler._empty(self)

    def dumps(self, path):
        # no op for SliceSampler
        PrioritizedSampler.dumps(self, path)

    def loads(self, path):
        # no op for SliceSampler
        return PrioritizedSampler.loads(self, path)

    def state_dict(self):
        # no op for SliceSampler
        return PrioritizedSampler.state_dict(self)

    def add(self, index: torch.Tensor) -> None:
        PrioritizedSampler.add(self, index)
        return SliceSampler.add(self, index)

    def extend(self, index: torch.Tensor) -> None:
        PrioritizedSampler.extend(self, index)
        return SliceSampler.extend(self, index)


class SamplerEnsemble(Sampler):
    """An ensemble of samplers.

    This class is designed to work with :class:`~torchrl.data.replay_buffers.replay_buffers.ReplayBufferEnsemble`.
    It contains the samplers as well as the sampling strategy hyperparameters.

    Args:
        samplers (sequence of Sampler): the samplers to make the composite sampler.

    Keyword Args:
        p (list or tensor of probabilities, optional): if provided, indicates the
            weights of each dataset during sampling.
        sample_from_all (bool, optional): if ``True``, each dataset will be sampled
            from. This is not compatible with the ``p`` argument. Defaults to ``False``.
        num_buffer_sampled (int, optional): the number of buffers to sample.
            If ``sample_from_all=True``, this has no effect, as it defaults to the
            number of buffers. If ``sample_from_all=False``, buffers will be
            sampled according to the probabilities ``p``.

    .. warning::
        The indices provided in the info dictionary are placed in a :class:`~tensordict.TensorDict` with
        keys ``index`` and ``buffer_ids`` that allow the upper :class:`~torchrl.data.ReplayBufferEnsemble`
        and :class:`~torchrl.data.StorageEnsemble` objects to retrieve the data.
        This format is different from other samplers, which usually return indices
        as regular tensors.

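    A minimal construction sketch (the member samplers and weights below are
    arbitrary choices, not prescribed by this class):

    Examples:
        >>> from torchrl.data.replay_buffers import RandomSampler, SamplerEnsemble
        >>> sampler = SamplerEnsemble(RandomSampler(), RandomSampler(), p=[0.5, 0.5])
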
"""
|
|
2399
|
+
|
|
2400
|
+
    def __init__(
        self, *samplers, p=None, sample_from_all=False, num_buffer_sampled=None
    ):
        self._rng_private = None
        self._samplers = samplers
        self.sample_from_all = sample_from_all
        if sample_from_all and p is not None:
            raise RuntimeError(
                "Cannot pass both `p` argument and `sample_from_all=True`."
            )
        self.p = p
        self.num_buffer_sampled = num_buffer_sampled

    @property
    def _rng(self):
        return self._rng_private

    @_rng.setter
    def _rng(self, value):
        self._rng_private = value
        for sampler in self._samplers:
            sampler._rng = value

    @property
    def p(self):
        return self._p

    @p.setter
    def p(self, value):
        if not isinstance(value, torch.Tensor) and value is not None:
            value = torch.tensor(value)
        if value is not None:
            value = value / value.sum().clamp_min(1e-6)
        self._p = value

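    # For example, ``p=[1, 2, 1]`` is stored as ``tensor([0.25, 0.50, 0.25])``:
    # the setter above normalizes any provided weights to sum to one (with a
    # ``clamp_min`` guard against an all-zero vector).
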
    @property
    def num_buffer_sampled(self):
        value = self.__dict__.get("_num_buffer_sampled", None)
        if value is None:
            value = self.__dict__["_num_buffer_sampled"] = len(self._samplers)
        return value

    @num_buffer_sampled.setter
    def num_buffer_sampled(self, value):
        self.__dict__["_num_buffer_sampled"] = value

    def sample(self, storage, batch_size):
        if batch_size % self.num_buffer_sampled > 0:
            raise ValueError(
                "batch_size must be a multiple of num_buffer_sampled."
            )
        if not isinstance(storage, StorageEnsemble):
            raise TypeError("Expected a StorageEnsemble instance.")
        sub_batch_size = batch_size // self.num_buffer_sampled
        if self.sample_from_all:
            samples, infos = zip(
                *[
                    sampler.sample(storage, sub_batch_size)
                    for storage, sampler in zip(storage._storages, self._samplers)
                ]
            )
            buffer_ids = torch.arange(len(samples))
        else:
            if self.p is None:
                buffer_ids = torch.randint(
                    len(self._samplers),
                    (self.num_buffer_sampled,),
                    generator=self._rng,
                    device=getattr(storage, "device", None),
                )
            else:
                buffer_ids = torch.multinomial(self.p, self.num_buffer_sampled, True)
            samples, infos = zip(
                *[
                    self._samplers[i].sample(storage._storages[i], sub_batch_size)
                    for i in buffer_ids.tolist()
                ]
            )
        samples = [
            sample if isinstance(sample, torch.Tensor) else torch.stack(sample, -1)
            for sample in samples
        ]
        if all(samples[0].shape == sample.shape for sample in samples[1:]):
            samples_stack = torch.stack(samples)
        else:
            samples_stack = torch.nested.nested_tensor(list(samples))

        samples = TensorDict(
            {
                "index": samples_stack,
                "buffer_ids": buffer_ids,
            },
            batch_size=[self.num_buffer_sampled],
        )
        infos = torch.stack(
            [
                TensorDict.from_dict(info, batch_dims=samples.ndim - 1)
                if info
                else TensorDict()
                for info in infos
            ]
        )
        return samples, infos

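    # The TensorDict returned by ``sample`` therefore carries one entry per sampled
    # buffer: ``samples["index"]`` stacks (or nests, when shapes differ) the
    # per-buffer indices, and ``samples["buffer_ids"]`` records which buffer each
    # row came from.
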
    def dumps(self, path: Path):
        path = Path(path).absolute()
        for i, sampler in enumerate(self._samplers):
            sampler.dumps(path / str(i))

    def loads(self, path: Path):
        path = Path(path).absolute()
        for i, sampler in enumerate(self._samplers):
            sampler.loads(path / str(i))

    def state_dict(self) -> dict[str, Any]:
        state_dict = OrderedDict()
        for i, sampler in enumerate(self._samplers):
            state_dict[str(i)] = sampler.state_dict()
        return state_dict

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        for i, sampler in enumerate(self._samplers):
            sampler.load_state_dict(state_dict[str(i)])

    def _empty(self):
        raise NotImplementedError

    _INDEX_ERROR = "Expected an index of type torch.Tensor, range, np.ndarray, int, slice or ellipsis, got {} instead."

    def __getitem__(self, index):
        if isinstance(index, tuple):
            if index[0] is Ellipsis:
                index = (slice(None), index[1:])
            result = self[index[0]]
            if len(index) > 1:
                raise IndexError(
                    f"Tuples of length greater than 1 are not accepted to index samplers of type {type(self)}."
                )
            return result
        if isinstance(index, slice) and index == slice(None):
            return self
        if isinstance(index, (list, range, np.ndarray)):
            index = torch.as_tensor(index)
        if isinstance(index, torch.Tensor):
            if index.ndim > 1:
                raise RuntimeError(
                    f"Cannot index a {type(self)} with tensor indices that have more than one dimension."
                )
            if index.is_floating_point():
                raise TypeError(
                    "A floating point index was received when an integer dtype was expected."
                )
        if isinstance(index, int) or (not isinstance(index, slice) and len(index) == 0):
            try:
                index = int(index)
            except Exception:
                raise IndexError(self._INDEX_ERROR.format(type(index)))
            try:
                return self._samplers[index]
            except IndexError:
                raise IndexError(self._INDEX_ERROR.format(type(index)))
        if isinstance(index, torch.Tensor):
            index = index.tolist()
            samplers = [self._samplers[i] for i in index]
        else:
            # slice
            samplers = self._samplers[index]
        p = self._p[index] if self._p is not None else None
        return SamplerEnsemble(
            *samplers,
            p=p,
            sample_from_all=self.sample_from_all,
            num_buffer_sampled=self.num_buffer_sampled,
        )

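    # Indexing sketch (assuming an ``ensemble`` instance): ``ensemble[0]`` returns
    # the first member sampler, while ``ensemble[:2]`` or ``ensemble[[0, 1]]``
    # returns a new SamplerEnsemble restricted to those members, with ``p`` sliced
    # accordingly.
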
    def __len__(self):
        return len(self._samplers)

    def __repr__(self):
        samplers = textwrap.indent(f"samplers={self._samplers}", " " * 4)
        return f"{self.__class__.__name__}(\n{samplers})"