torchrl-0.11.0-cp314-cp314-manylinux_2_28_aarch64.whl
This diff represents the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the package versions as they appear in their respective public registries.
- benchmarks/benchmark_batched_envs.py +104 -0
- benchmarks/conftest.py +91 -0
- benchmarks/ecosystem/gym_env_throughput.py +321 -0
- benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
- benchmarks/requirements.txt +7 -0
- benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
- benchmarks/test_collectors_benchmark.py +240 -0
- benchmarks/test_compressed_storage_benchmark.py +145 -0
- benchmarks/test_envs_benchmark.py +133 -0
- benchmarks/test_llm.py +101 -0
- benchmarks/test_non_tensor_env_benchmark.py +70 -0
- benchmarks/test_objectives_benchmarks.py +1199 -0
- benchmarks/test_replaybuffer_benchmark.py +254 -0
- sota-check/README.md +35 -0
- sota-implementations/README.md +142 -0
- sota-implementations/a2c/README.md +39 -0
- sota-implementations/a2c/a2c_atari.py +291 -0
- sota-implementations/a2c/a2c_mujoco.py +273 -0
- sota-implementations/a2c/utils_atari.py +240 -0
- sota-implementations/a2c/utils_mujoco.py +160 -0
- sota-implementations/bandits/README.md +7 -0
- sota-implementations/bandits/dqn.py +126 -0
- sota-implementations/cql/cql_offline.py +198 -0
- sota-implementations/cql/cql_online.py +249 -0
- sota-implementations/cql/discrete_cql_offline.py +180 -0
- sota-implementations/cql/discrete_cql_online.py +227 -0
- sota-implementations/cql/utils.py +471 -0
- sota-implementations/crossq/crossq.py +271 -0
- sota-implementations/crossq/utils.py +320 -0
- sota-implementations/ddpg/ddpg.py +231 -0
- sota-implementations/ddpg/utils.py +325 -0
- sota-implementations/decision_transformer/dt.py +163 -0
- sota-implementations/decision_transformer/lamb.py +167 -0
- sota-implementations/decision_transformer/online_dt.py +178 -0
- sota-implementations/decision_transformer/utils.py +562 -0
- sota-implementations/discrete_sac/discrete_sac.py +243 -0
- sota-implementations/discrete_sac/utils.py +324 -0
- sota-implementations/dqn/README.md +30 -0
- sota-implementations/dqn/dqn_atari.py +272 -0
- sota-implementations/dqn/dqn_cartpole.py +236 -0
- sota-implementations/dqn/utils_atari.py +132 -0
- sota-implementations/dqn/utils_cartpole.py +90 -0
- sota-implementations/dreamer/README.md +129 -0
- sota-implementations/dreamer/dreamer.py +586 -0
- sota-implementations/dreamer/dreamer_utils.py +1107 -0
- sota-implementations/expert-iteration/README.md +352 -0
- sota-implementations/expert-iteration/ei_utils.py +770 -0
- sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
- sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
- sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
- sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
- sota-implementations/gail/gail.py +327 -0
- sota-implementations/gail/gail_utils.py +68 -0
- sota-implementations/gail/ppo_utils.py +157 -0
- sota-implementations/grpo/README.md +273 -0
- sota-implementations/grpo/grpo-async.py +437 -0
- sota-implementations/grpo/grpo-sync.py +435 -0
- sota-implementations/grpo/grpo_utils.py +843 -0
- sota-implementations/grpo/requirements_gsm8k.txt +11 -0
- sota-implementations/grpo/requirements_ifeval.txt +16 -0
- sota-implementations/impala/README.md +33 -0
- sota-implementations/impala/impala_multi_node_ray.py +292 -0
- sota-implementations/impala/impala_multi_node_submitit.py +284 -0
- sota-implementations/impala/impala_single_node.py +261 -0
- sota-implementations/impala/utils.py +184 -0
- sota-implementations/iql/discrete_iql.py +230 -0
- sota-implementations/iql/iql_offline.py +164 -0
- sota-implementations/iql/iql_online.py +225 -0
- sota-implementations/iql/utils.py +437 -0
- sota-implementations/multiagent/README.md +74 -0
- sota-implementations/multiagent/iql.py +237 -0
- sota-implementations/multiagent/maddpg_iddpg.py +266 -0
- sota-implementations/multiagent/mappo_ippo.py +267 -0
- sota-implementations/multiagent/qmix_vdn.py +271 -0
- sota-implementations/multiagent/sac.py +337 -0
- sota-implementations/multiagent/utils/__init__.py +4 -0
- sota-implementations/multiagent/utils/logging.py +151 -0
- sota-implementations/multiagent/utils/utils.py +43 -0
- sota-implementations/ppo/README.md +29 -0
- sota-implementations/ppo/ppo_atari.py +305 -0
- sota-implementations/ppo/ppo_mujoco.py +293 -0
- sota-implementations/ppo/utils_atari.py +238 -0
- sota-implementations/ppo/utils_mujoco.py +152 -0
- sota-implementations/ppo_trainer/train.py +21 -0
- sota-implementations/redq/README.md +7 -0
- sota-implementations/redq/redq.py +199 -0
- sota-implementations/redq/utils.py +1060 -0
- sota-implementations/sac/sac-async.py +266 -0
- sota-implementations/sac/sac.py +239 -0
- sota-implementations/sac/utils.py +381 -0
- sota-implementations/sac_trainer/train.py +16 -0
- sota-implementations/td3/td3.py +254 -0
- sota-implementations/td3/utils.py +319 -0
- sota-implementations/td3_bc/td3_bc.py +177 -0
- sota-implementations/td3_bc/utils.py +251 -0
- torchrl/__init__.py +144 -0
- torchrl/_extension.py +74 -0
- torchrl/_torchrl.cpython-314-aarch64-linux-gnu.so +0 -0
- torchrl/_utils.py +1431 -0
- torchrl/collectors/__init__.py +48 -0
- torchrl/collectors/_base.py +1058 -0
- torchrl/collectors/_constants.py +88 -0
- torchrl/collectors/_multi_async.py +324 -0
- torchrl/collectors/_multi_base.py +1805 -0
- torchrl/collectors/_multi_sync.py +464 -0
- torchrl/collectors/_runner.py +581 -0
- torchrl/collectors/_single.py +2009 -0
- torchrl/collectors/_single_async.py +259 -0
- torchrl/collectors/collectors.py +62 -0
- torchrl/collectors/distributed/__init__.py +32 -0
- torchrl/collectors/distributed/default_configs.py +133 -0
- torchrl/collectors/distributed/generic.py +1306 -0
- torchrl/collectors/distributed/ray.py +1092 -0
- torchrl/collectors/distributed/rpc.py +1006 -0
- torchrl/collectors/distributed/sync.py +731 -0
- torchrl/collectors/distributed/utils.py +160 -0
- torchrl/collectors/llm/__init__.py +10 -0
- torchrl/collectors/llm/base.py +494 -0
- torchrl/collectors/llm/ray_collector.py +275 -0
- torchrl/collectors/llm/utils.py +36 -0
- torchrl/collectors/llm/weight_update/__init__.py +10 -0
- torchrl/collectors/llm/weight_update/vllm.py +348 -0
- torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
- torchrl/collectors/utils.py +433 -0
- torchrl/collectors/weight_update.py +591 -0
- torchrl/csrc/numpy_utils.h +38 -0
- torchrl/csrc/pybind.cpp +27 -0
- torchrl/csrc/segment_tree.h +458 -0
- torchrl/csrc/torch_utils.h +34 -0
- torchrl/csrc/utils.cpp +48 -0
- torchrl/csrc/utils.h +31 -0
- torchrl/data/__init__.py +187 -0
- torchrl/data/datasets/__init__.py +58 -0
- torchrl/data/datasets/atari_dqn.py +878 -0
- torchrl/data/datasets/common.py +281 -0
- torchrl/data/datasets/d4rl.py +489 -0
- torchrl/data/datasets/d4rl_infos.py +187 -0
- torchrl/data/datasets/gen_dgrl.py +375 -0
- torchrl/data/datasets/minari_data.py +643 -0
- torchrl/data/datasets/openml.py +177 -0
- torchrl/data/datasets/openx.py +798 -0
- torchrl/data/datasets/roboset.py +363 -0
- torchrl/data/datasets/utils.py +11 -0
- torchrl/data/datasets/vd4rl.py +432 -0
- torchrl/data/llm/__init__.py +34 -0
- torchrl/data/llm/dataset.py +491 -0
- torchrl/data/llm/history.py +1378 -0
- torchrl/data/llm/prompt.py +198 -0
- torchrl/data/llm/reward.py +225 -0
- torchrl/data/llm/topk.py +186 -0
- torchrl/data/llm/utils.py +543 -0
- torchrl/data/map/__init__.py +21 -0
- torchrl/data/map/hash.py +185 -0
- torchrl/data/map/query.py +204 -0
- torchrl/data/map/tdstorage.py +363 -0
- torchrl/data/map/tree.py +1434 -0
- torchrl/data/map/utils.py +103 -0
- torchrl/data/postprocs/__init__.py +8 -0
- torchrl/data/postprocs/postprocs.py +391 -0
- torchrl/data/replay_buffers/__init__.py +99 -0
- torchrl/data/replay_buffers/checkpointers.py +622 -0
- torchrl/data/replay_buffers/ray_buffer.py +292 -0
- torchrl/data/replay_buffers/replay_buffers.py +2376 -0
- torchrl/data/replay_buffers/samplers.py +2578 -0
- torchrl/data/replay_buffers/scheduler.py +265 -0
- torchrl/data/replay_buffers/storages.py +2412 -0
- torchrl/data/replay_buffers/utils.py +1042 -0
- torchrl/data/replay_buffers/writers.py +781 -0
- torchrl/data/tensor_specs.py +7101 -0
- torchrl/data/utils.py +334 -0
- torchrl/envs/__init__.py +265 -0
- torchrl/envs/async_envs.py +1105 -0
- torchrl/envs/batched_envs.py +3093 -0
- torchrl/envs/common.py +4241 -0
- torchrl/envs/custom/__init__.py +11 -0
- torchrl/envs/custom/chess.py +617 -0
- torchrl/envs/custom/llm.py +214 -0
- torchrl/envs/custom/pendulum.py +401 -0
- torchrl/envs/custom/san_moves.txt +29274 -0
- torchrl/envs/custom/tictactoeenv.py +288 -0
- torchrl/envs/env_creator.py +263 -0
- torchrl/envs/gym_like.py +752 -0
- torchrl/envs/libs/__init__.py +68 -0
- torchrl/envs/libs/_gym_utils.py +326 -0
- torchrl/envs/libs/brax.py +846 -0
- torchrl/envs/libs/dm_control.py +544 -0
- torchrl/envs/libs/envpool.py +447 -0
- torchrl/envs/libs/gym.py +2239 -0
- torchrl/envs/libs/habitat.py +138 -0
- torchrl/envs/libs/isaac_lab.py +87 -0
- torchrl/envs/libs/isaacgym.py +203 -0
- torchrl/envs/libs/jax_utils.py +166 -0
- torchrl/envs/libs/jumanji.py +963 -0
- torchrl/envs/libs/meltingpot.py +599 -0
- torchrl/envs/libs/openml.py +153 -0
- torchrl/envs/libs/openspiel.py +652 -0
- torchrl/envs/libs/pettingzoo.py +1042 -0
- torchrl/envs/libs/procgen.py +351 -0
- torchrl/envs/libs/robohive.py +429 -0
- torchrl/envs/libs/smacv2.py +645 -0
- torchrl/envs/libs/unity_mlagents.py +891 -0
- torchrl/envs/libs/utils.py +147 -0
- torchrl/envs/libs/vmas.py +813 -0
- torchrl/envs/llm/__init__.py +63 -0
- torchrl/envs/llm/chat.py +730 -0
- torchrl/envs/llm/datasets/README.md +4 -0
- torchrl/envs/llm/datasets/__init__.py +17 -0
- torchrl/envs/llm/datasets/gsm8k.py +353 -0
- torchrl/envs/llm/datasets/ifeval.py +274 -0
- torchrl/envs/llm/envs.py +789 -0
- torchrl/envs/llm/libs/README.md +3 -0
- torchrl/envs/llm/libs/__init__.py +8 -0
- torchrl/envs/llm/libs/mlgym.py +869 -0
- torchrl/envs/llm/reward/__init__.py +10 -0
- torchrl/envs/llm/reward/gsm8k.py +324 -0
- torchrl/envs/llm/reward/ifeval/README.md +13 -0
- torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
- torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
- torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
- torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
- torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
- torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
- torchrl/envs/llm/transforms/__init__.py +55 -0
- torchrl/envs/llm/transforms/browser.py +292 -0
- torchrl/envs/llm/transforms/dataloading.py +859 -0
- torchrl/envs/llm/transforms/format.py +73 -0
- torchrl/envs/llm/transforms/kl.py +1544 -0
- torchrl/envs/llm/transforms/policy_version.py +189 -0
- torchrl/envs/llm/transforms/reason.py +323 -0
- torchrl/envs/llm/transforms/tokenizer.py +321 -0
- torchrl/envs/llm/transforms/tools.py +1955 -0
- torchrl/envs/model_based/__init__.py +9 -0
- torchrl/envs/model_based/common.py +180 -0
- torchrl/envs/model_based/dreamer.py +112 -0
- torchrl/envs/transforms/__init__.py +147 -0
- torchrl/envs/transforms/functional.py +48 -0
- torchrl/envs/transforms/gym_transforms.py +203 -0
- torchrl/envs/transforms/module.py +341 -0
- torchrl/envs/transforms/r3m.py +372 -0
- torchrl/envs/transforms/ray_service.py +663 -0
- torchrl/envs/transforms/rb_transforms.py +214 -0
- torchrl/envs/transforms/transforms.py +11835 -0
- torchrl/envs/transforms/utils.py +94 -0
- torchrl/envs/transforms/vc1.py +307 -0
- torchrl/envs/transforms/vecnorm.py +845 -0
- torchrl/envs/transforms/vip.py +407 -0
- torchrl/envs/utils.py +1718 -0
- torchrl/envs/vec_envs.py +11 -0
- torchrl/modules/__init__.py +206 -0
- torchrl/modules/distributions/__init__.py +73 -0
- torchrl/modules/distributions/continuous.py +830 -0
- torchrl/modules/distributions/discrete.py +908 -0
- torchrl/modules/distributions/truncated_normal.py +187 -0
- torchrl/modules/distributions/utils.py +233 -0
- torchrl/modules/llm/__init__.py +62 -0
- torchrl/modules/llm/backends/__init__.py +65 -0
- torchrl/modules/llm/backends/vllm/__init__.py +94 -0
- torchrl/modules/llm/backends/vllm/_models.py +46 -0
- torchrl/modules/llm/backends/vllm/base.py +72 -0
- torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
- torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
- torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
- torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
- torchrl/modules/llm/policies/__init__.py +28 -0
- torchrl/modules/llm/policies/common.py +1809 -0
- torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
- torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
- torchrl/modules/llm/utils.py +23 -0
- torchrl/modules/mcts/__init__.py +21 -0
- torchrl/modules/mcts/scores.py +579 -0
- torchrl/modules/models/__init__.py +86 -0
- torchrl/modules/models/batchrenorm.py +119 -0
- torchrl/modules/models/decision_transformer.py +179 -0
- torchrl/modules/models/exploration.py +731 -0
- torchrl/modules/models/llm.py +156 -0
- torchrl/modules/models/model_based.py +596 -0
- torchrl/modules/models/models.py +1712 -0
- torchrl/modules/models/multiagent.py +1067 -0
- torchrl/modules/models/recipes/impala.py +185 -0
- torchrl/modules/models/utils.py +162 -0
- torchrl/modules/planners/__init__.py +10 -0
- torchrl/modules/planners/cem.py +228 -0
- torchrl/modules/planners/common.py +73 -0
- torchrl/modules/planners/mppi.py +265 -0
- torchrl/modules/tensordict_module/__init__.py +89 -0
- torchrl/modules/tensordict_module/actors.py +2457 -0
- torchrl/modules/tensordict_module/common.py +529 -0
- torchrl/modules/tensordict_module/exploration.py +814 -0
- torchrl/modules/tensordict_module/probabilistic.py +321 -0
- torchrl/modules/tensordict_module/rnn.py +1639 -0
- torchrl/modules/tensordict_module/sequence.py +132 -0
- torchrl/modules/tensordict_module/world_models.py +34 -0
- torchrl/modules/utils/__init__.py +38 -0
- torchrl/modules/utils/mappings.py +9 -0
- torchrl/modules/utils/utils.py +89 -0
- torchrl/objectives/__init__.py +78 -0
- torchrl/objectives/a2c.py +659 -0
- torchrl/objectives/common.py +753 -0
- torchrl/objectives/cql.py +1346 -0
- torchrl/objectives/crossq.py +710 -0
- torchrl/objectives/ddpg.py +453 -0
- torchrl/objectives/decision_transformer.py +371 -0
- torchrl/objectives/deprecated.py +516 -0
- torchrl/objectives/dqn.py +683 -0
- torchrl/objectives/dreamer.py +488 -0
- torchrl/objectives/functional.py +48 -0
- torchrl/objectives/gail.py +258 -0
- torchrl/objectives/iql.py +996 -0
- torchrl/objectives/llm/__init__.py +30 -0
- torchrl/objectives/llm/grpo.py +846 -0
- torchrl/objectives/llm/sft.py +482 -0
- torchrl/objectives/multiagent/__init__.py +8 -0
- torchrl/objectives/multiagent/qmixer.py +396 -0
- torchrl/objectives/ppo.py +1669 -0
- torchrl/objectives/redq.py +683 -0
- torchrl/objectives/reinforce.py +530 -0
- torchrl/objectives/sac.py +1580 -0
- torchrl/objectives/td3.py +570 -0
- torchrl/objectives/td3_bc.py +625 -0
- torchrl/objectives/utils.py +782 -0
- torchrl/objectives/value/__init__.py +28 -0
- torchrl/objectives/value/advantages.py +1956 -0
- torchrl/objectives/value/functional.py +1459 -0
- torchrl/objectives/value/utils.py +360 -0
- torchrl/record/__init__.py +17 -0
- torchrl/record/loggers/__init__.py +23 -0
- torchrl/record/loggers/common.py +48 -0
- torchrl/record/loggers/csv.py +226 -0
- torchrl/record/loggers/mlflow.py +142 -0
- torchrl/record/loggers/tensorboard.py +139 -0
- torchrl/record/loggers/trackio.py +163 -0
- torchrl/record/loggers/utils.py +78 -0
- torchrl/record/loggers/wandb.py +214 -0
- torchrl/record/recorder.py +554 -0
- torchrl/services/__init__.py +79 -0
- torchrl/services/base.py +109 -0
- torchrl/services/ray_service.py +453 -0
- torchrl/testing/__init__.py +107 -0
- torchrl/testing/assertions.py +179 -0
- torchrl/testing/dist_utils.py +122 -0
- torchrl/testing/env_creators.py +227 -0
- torchrl/testing/env_helper.py +35 -0
- torchrl/testing/gym_helpers.py +156 -0
- torchrl/testing/llm_mocks.py +119 -0
- torchrl/testing/mocking_classes.py +2720 -0
- torchrl/testing/modules.py +295 -0
- torchrl/testing/mp_helpers.py +15 -0
- torchrl/testing/ray_helpers.py +293 -0
- torchrl/testing/utils.py +190 -0
- torchrl/trainers/__init__.py +42 -0
- torchrl/trainers/algorithms/__init__.py +11 -0
- torchrl/trainers/algorithms/configs/__init__.py +705 -0
- torchrl/trainers/algorithms/configs/collectors.py +216 -0
- torchrl/trainers/algorithms/configs/common.py +41 -0
- torchrl/trainers/algorithms/configs/data.py +308 -0
- torchrl/trainers/algorithms/configs/envs.py +104 -0
- torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
- torchrl/trainers/algorithms/configs/logging.py +80 -0
- torchrl/trainers/algorithms/configs/modules.py +570 -0
- torchrl/trainers/algorithms/configs/objectives.py +177 -0
- torchrl/trainers/algorithms/configs/trainers.py +340 -0
- torchrl/trainers/algorithms/configs/transforms.py +955 -0
- torchrl/trainers/algorithms/configs/utils.py +252 -0
- torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
- torchrl/trainers/algorithms/configs/weight_update.py +159 -0
- torchrl/trainers/algorithms/ppo.py +373 -0
- torchrl/trainers/algorithms/sac.py +308 -0
- torchrl/trainers/helpers/__init__.py +40 -0
- torchrl/trainers/helpers/collectors.py +416 -0
- torchrl/trainers/helpers/envs.py +573 -0
- torchrl/trainers/helpers/logger.py +33 -0
- torchrl/trainers/helpers/losses.py +132 -0
- torchrl/trainers/helpers/models.py +658 -0
- torchrl/trainers/helpers/replay_buffer.py +59 -0
- torchrl/trainers/helpers/trainers.py +301 -0
- torchrl/trainers/trainers.py +2052 -0
- torchrl/weight_update/__init__.py +33 -0
- torchrl/weight_update/_distributed.py +749 -0
- torchrl/weight_update/_mp.py +624 -0
- torchrl/weight_update/_noupdate.py +102 -0
- torchrl/weight_update/_ray.py +1032 -0
- torchrl/weight_update/_rpc.py +284 -0
- torchrl/weight_update/_shared.py +891 -0
- torchrl/weight_update/llm/__init__.py +32 -0
- torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
- torchrl/weight_update/llm/vllm_nccl.py +710 -0
- torchrl/weight_update/utils.py +73 -0
- torchrl/weight_update/weight_sync_schemes.py +1244 -0
- torchrl-0.11.0.dist-info/METADATA +1308 -0
- torchrl-0.11.0.dist-info/RECORD +394 -0
- torchrl-0.11.0.dist-info/WHEEL +5 -0
- torchrl-0.11.0.dist-info/entry_points.txt +2 -0
- torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
- torchrl-0.11.0.dist-info/top_level.txt +7 -0
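Only one hunk is rendered below; judging by its +2412 line count, it appears to come from `torchrl/data/replay_buffers/storages.py`.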
|
@@ -0,0 +1,2412 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
#
|
|
3
|
+
# This source code is licensed under the MIT license found in the
|
|
4
|
+
# LICENSE file in the root directory of this source tree.
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
import abc
|
|
8
|
+
import atexit
|
|
9
|
+
import logging
|
|
10
|
+
import multiprocessing as mp
|
|
11
|
+
import os
|
|
12
|
+
import shutil
|
|
13
|
+
import signal
|
|
14
|
+
import sys
|
|
15
|
+
import tempfile
|
|
16
|
+
import textwrap
|
|
17
|
+
import warnings
|
|
18
|
+
import weakref
|
|
19
|
+
from collections import OrderedDict
|
|
20
|
+
from collections.abc import Callable, Mapping, Sequence
|
|
21
|
+
from copy import copy
|
|
22
|
+
from multiprocessing.context import get_spawning_popen
|
|
23
|
+
from typing import Any
|
|
24
|
+
|
|
25
|
+
import numpy as np
|
|
26
|
+
import tensordict
|
|
27
|
+
import torch
|
|
28
|
+
from tensordict import (
|
|
29
|
+
is_tensor_collection,
|
|
30
|
+
lazy_stack,
|
|
31
|
+
LazyStackedTensorDict,
|
|
32
|
+
TensorDict,
|
|
33
|
+
TensorDictBase,
|
|
34
|
+
)
|
|
35
|
+
from tensordict.base import _NESTED_TENSORS_AS_LISTS
|
|
36
|
+
from tensordict.memmap import MemoryMappedTensor
|
|
37
|
+
from tensordict.utils import _zip_strict
|
|
38
|
+
from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten
|
|
39
|
+
|
|
40
|
+
from torchrl._utils import _make_ordinal_device, implement_for, logger as torchrl_logger
|
|
41
|
+
from torchrl.data.replay_buffers.checkpointers import (
|
|
42
|
+
CompressedListStorageCheckpointer,
|
|
43
|
+
ListStorageCheckpointer,
|
|
44
|
+
StorageCheckpointerBase,
|
|
45
|
+
StorageEnsembleCheckpointer,
|
|
46
|
+
TensorStorageCheckpointer,
|
|
47
|
+
)
|
|
48
|
+
from torchrl.data.replay_buffers.utils import (
|
|
49
|
+
_init_pytree,
|
|
50
|
+
_is_int,
|
|
51
|
+
INT_CLASSES,
|
|
52
|
+
tree_iter,
|
|
53
|
+
)
|
|
54
|
+
|
|
55
|
+
try:
|
|
56
|
+
from torch.compiler import disable as compile_disable, is_compiling
|
|
57
|
+
except ImportError:
|
|
58
|
+
from torch._dynamo import disable as compile_disable, is_compiling
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
# =============================================================================
|
|
62
|
+
# Memmap Storage Cleanup Infrastructure
|
|
63
|
+
# =============================================================================
|
|
64
|
+
# This module-level infrastructure ensures that memmap files created by
|
|
65
|
+
# LazyMemmapStorage are cleaned up even when scripts are interrupted with
|
|
66
|
+
# Ctrl+C (SIGINT) or killed with SIGTERM.
|
|
67
|
+
|
|
68
|
+
# Registry of storages to clean up (weak references to avoid preventing GC)
|
|
69
|
+
_MEMMAP_STORAGE_REGISTRY: weakref.WeakSet = weakref.WeakSet()
|
|
70
|
+
|
|
71
|
+
# Track if cleanup has already run (to avoid double cleanup)
|
|
72
|
+
_CLEANUP_DONE = False
|
|
73
|
+
|
|
74
|
+
# Store original signal handlers to restore after cleanup
|
|
75
|
+
_ORIGINAL_SIGINT_HANDLER = None
|
|
76
|
+
_ORIGINAL_SIGTERM_HANDLER = None
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
def _cleanup_all_memmap_storages():
|
|
80
|
+
"""Clean up all registered memmap storages.
|
|
81
|
+
|
|
82
|
+
This function is called on exit (via atexit) and on signal interrupts.
|
|
83
|
+
It removes all temporary memmap directories that were created with
|
|
84
|
+
auto_cleanup=True.
|
|
85
|
+
"""
|
|
86
|
+
global _CLEANUP_DONE
|
|
87
|
+
if _CLEANUP_DONE:
|
|
88
|
+
return
|
|
89
|
+
_CLEANUP_DONE = True
|
|
90
|
+
|
|
91
|
+
for storage in list(_MEMMAP_STORAGE_REGISTRY):
|
|
92
|
+
try:
|
|
93
|
+
storage.cleanup()
|
|
94
|
+
except Exception:
|
|
95
|
+
# Ignore errors during cleanup - the storage might already be gone
|
|
96
|
+
pass
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
def _signal_cleanup_handler(signum, frame):
|
|
100
|
+
"""Signal handler that cleans up memmap storages before exiting.
|
|
101
|
+
|
|
102
|
+
This handler is robust to cleanup failures - it will always re-raise the
|
|
103
|
+
signal to ensure proper process termination.
|
|
104
|
+
"""
|
|
105
|
+
# Always ensure we re-raise the signal, even if cleanup fails
|
|
106
|
+
try:
|
|
107
|
+
_cleanup_all_memmap_storages()
|
|
108
|
+
except Exception:
|
|
109
|
+
# Ignore any cleanup errors - we must re-raise the signal
|
|
110
|
+
pass
|
|
111
|
+
|
|
112
|
+
# Re-raise the signal with the original handler (or default behavior)
|
|
113
|
+
if signum == signal.SIGINT:
|
|
114
|
+
original = _ORIGINAL_SIGINT_HANDLER
|
|
115
|
+
elif signum == signal.SIGTERM:
|
|
116
|
+
original = _ORIGINAL_SIGTERM_HANDLER
|
|
117
|
+
else:
|
|
118
|
+
original = signal.SIG_DFL
|
|
119
|
+
|
|
120
|
+
# Restore original handler and re-raise
|
|
121
|
+
signal.signal(signum, original if original else signal.SIG_DFL)
|
|
122
|
+
os.kill(os.getpid(), signum)
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
def _register_cleanup_handlers():
|
|
126
|
+
"""Register atexit and signal handlers for memmap cleanup.
|
|
127
|
+
|
|
128
|
+
This is called once when the first storage with auto_cleanup=True is created.
|
|
129
|
+
"""
|
|
130
|
+
global _ORIGINAL_SIGINT_HANDLER, _ORIGINAL_SIGTERM_HANDLER
|
|
131
|
+
|
|
132
|
+
# Register atexit handler (for normal exits)
|
|
133
|
+
atexit.register(_cleanup_all_memmap_storages)
|
|
134
|
+
|
|
135
|
+
# Register signal handlers (for Ctrl+C and kill)
|
|
136
|
+
# Only register if we're in the main thread (signals can only be handled in main thread)
|
|
137
|
+
try:
|
|
138
|
+
import threading
|
|
139
|
+
|
|
140
|
+
if threading.current_thread() is threading.main_thread():
|
|
141
|
+
_ORIGINAL_SIGINT_HANDLER = signal.signal(
|
|
142
|
+
signal.SIGINT, _signal_cleanup_handler
|
|
143
|
+
)
|
|
144
|
+
_ORIGINAL_SIGTERM_HANDLER = signal.signal(
|
|
145
|
+
signal.SIGTERM, _signal_cleanup_handler
|
|
146
|
+
)
|
|
147
|
+
except (ValueError, RuntimeError):
|
|
148
|
+
# Signal handling not available (e.g., not main thread)
|
|
149
|
+
pass
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
# Flag to track if handlers have been registered
|
|
153
|
+
_CLEANUP_HANDLERS_REGISTERED = False
|
|
154
|
+
|
|
155
|
+
|
|
156
|
+
def _ensure_cleanup_handlers():
|
|
157
|
+
"""Ensure cleanup handlers are registered (called once per process)."""
|
|
158
|
+
global _CLEANUP_HANDLERS_REGISTERED
|
|
159
|
+
if not _CLEANUP_HANDLERS_REGISTERED:
|
|
160
|
+
_register_cleanup_handlers()
|
|
161
|
+
_CLEANUP_HANDLERS_REGISTERED = True
|
|
162
|
+
|
|
163
|
+
|
|
164
|
+
class Storage:
|
|
165
|
+
"""A Storage is the container of a replay buffer.
|
|
166
|
+
|
|
167
|
+
Every storage must have a set, get and __len__ methods implemented.
|
|
168
|
+
Get and set should support integers as well as list of integers.
|
|
169
|
+
|
|
170
|
+
The storage does not need to have a definite size, but if it does one should
|
|
171
|
+
make sure that it is compatible with the buffer size.
|
|
172
|
+
|
|
173
|
+
"""
|
|
174
|
+
|
|
175
|
+
ndim = 1
|
|
176
|
+
max_size: int
|
|
177
|
+
_default_checkpointer: StorageCheckpointerBase = StorageCheckpointerBase
|
|
178
|
+
_rng: torch.Generator | None = None
|
|
179
|
+
|
|
180
|
+
def __init__(
|
|
181
|
+
self,
|
|
182
|
+
max_size: int,
|
|
183
|
+
checkpointer: StorageCheckpointerBase | None = None,
|
|
184
|
+
compilable: bool = False,
|
|
185
|
+
) -> None:
|
|
186
|
+
self.max_size = int(max_size)
|
|
187
|
+
self.checkpointer = checkpointer
|
|
188
|
+
self._compilable = compilable
|
|
189
|
+
self._attached_entities_list = []
|
|
190
|
+
|
|
191
|
+
@property
|
|
192
|
+
def checkpointer(self):
|
|
193
|
+
return self._checkpointer
|
|
194
|
+
|
|
195
|
+
def register_save_hook(self, hook):
|
|
196
|
+
"""Register a save hook for this storage.
|
|
197
|
+
|
|
198
|
+
The hook is forwarded to the checkpointer.
|
|
199
|
+
"""
|
|
200
|
+
self._checkpointer.register_save_hook(hook)
|
|
201
|
+
|
|
202
|
+
def register_load_hook(self, hook):
|
|
203
|
+
"""Register a load hook for this storage.
|
|
204
|
+
|
|
205
|
+
The hook is forwarded to the checkpointer.
|
|
206
|
+
"""
|
|
207
|
+
self._checkpointer.register_load_hook(hook)
|
|
208
|
+
|
|
209
|
+
@checkpointer.setter
|
|
210
|
+
def checkpointer(self, value: StorageCheckpointerBase | None) -> None:
|
|
211
|
+
if value is None:
|
|
212
|
+
value = self._default_checkpointer()
|
|
213
|
+
self._checkpointer = value
|
|
214
|
+
|
|
215
|
+
@property
|
|
216
|
+
def _is_full(self):
|
|
217
|
+
return len(self) == self.max_size
|
|
218
|
+
|
|
219
|
+
@property
|
|
220
|
+
def _attached_entities(self) -> list:
|
|
221
|
+
# RBs that use a given instance of Storage should add
|
|
222
|
+
# themselves to this set.
|
|
223
|
+
_attached_entities_list = getattr(self, "_attached_entities_list", None)
|
|
224
|
+
if _attached_entities_list is None:
|
|
225
|
+
self._attached_entities_list = _attached_entities_list = []
|
|
226
|
+
return _attached_entities_list
|
|
227
|
+
|
|
228
|
+
# TODO: Check this
|
|
229
|
+
@torch._dynamo.assume_constant_result
|
|
230
|
+
def _attached_entities_iter(self):
|
|
231
|
+
return self._attached_entities
|
|
232
|
+
|
|
233
|
+
@abc.abstractmethod
|
|
234
|
+
def set(self, cursor: int, data: Any, *, set_cursor: bool = True):
|
|
235
|
+
...
|
|
236
|
+
|
|
237
|
+
@abc.abstractmethod
|
|
238
|
+
def get(self, index: int) -> Any:
|
|
239
|
+
...
|
|
240
|
+
|
|
241
|
+
def dumps(self, path):
|
|
242
|
+
self.checkpointer.dumps(self, path)
|
|
243
|
+
|
|
244
|
+
def loads(self, path):
|
|
245
|
+
self.checkpointer.loads(self, path)
|
|
246
|
+
|
|
247
|
+
def attach(self, buffer: Any) -> None:
|
|
248
|
+
"""This function attaches a sampler to this storage.
|
|
249
|
+
|
|
250
|
+
Buffers that read from this storage must be included as an attached
|
|
251
|
+
entity by calling this method. This guarantees that when data
|
|
252
|
+
in the storage changes, components are made aware of changes even if the storage
|
|
253
|
+
is shared with other buffers (eg. Priority Samplers).
|
|
254
|
+
|
|
255
|
+
Args:
|
|
256
|
+
buffer: the object that reads from this storage.
|
|
257
|
+
"""
|
|
258
|
+
if buffer not in self._attached_entities:
|
|
259
|
+
self._attached_entities.append(buffer)
|
|
260
|
+
|
|
261
|
+
def __getitem__(self, item):
|
|
262
|
+
return self.get(item)
|
|
263
|
+
|
|
264
|
+
def __setitem__(self, index, value):
|
|
265
|
+
"""Sets values in the storage without updating the cursor or length."""
|
|
266
|
+
return self.set(index, value, set_cursor=False)
|
|
267
|
+
|
|
268
|
+
def __iter__(self):
|
|
269
|
+
for i in range(len(self)):
|
|
270
|
+
yield self[i]
|
|
271
|
+
|
|
272
|
+
@abc.abstractmethod
|
|
273
|
+
def __len__(self):
|
|
274
|
+
...
|
|
275
|
+
|
|
276
|
+
@abc.abstractmethod
|
|
277
|
+
def state_dict(self) -> dict[str, Any]:
|
|
278
|
+
...
|
|
279
|
+
|
|
280
|
+
@abc.abstractmethod
|
|
281
|
+
def load_state_dict(self, state_dict: dict[str, Any]) -> None:
|
|
282
|
+
...
|
|
283
|
+
|
|
284
|
+
@abc.abstractmethod
|
|
285
|
+
def _empty(self):
|
|
286
|
+
...
|
|
287
|
+
|
|
288
|
+
# TODO: Without this disable, compiler recompiles due to changing len(self) guards.
|
|
289
|
+
@compile_disable()
|
|
290
|
+
def _rand_given_ndim(self, batch_size):
|
|
291
|
+
# a method to return random indices given the storage ndim
|
|
292
|
+
if self.ndim == 1:
|
|
293
|
+
return torch.randint(
|
|
294
|
+
0,
|
|
295
|
+
len(self),
|
|
296
|
+
(batch_size,),
|
|
297
|
+
generator=self._rng,
|
|
298
|
+
device=getattr(self, "device", None),
|
|
299
|
+
)
|
|
300
|
+
raise RuntimeError(
|
|
301
|
+
f"Random number generation is not implemented for storage of type {type(self)} with ndim {self.ndim}. "
|
|
302
|
+
f"Please report this exception as well as the use case (incl. buffer construction) on github."
|
|
303
|
+
)
|
|
304
|
+
|
|
305
|
+
@property
|
|
306
|
+
def shape(self):
|
|
307
|
+
if self.ndim == 1:
|
|
308
|
+
return torch.Size([self.max_size])
|
|
309
|
+
raise RuntimeError(
|
|
310
|
+
f"storage.shape is not supported for storages of type {type(self)} when ndim > 1."
|
|
311
|
+
f"Please report this exception as well as the use case (incl. buffer construction) on github."
|
|
312
|
+
)
|
|
313
|
+
|
|
314
|
+
def _max_size_along_dim0(self, *, single_data=None, batched_data=None):
|
|
315
|
+
if self.ndim == 1:
|
|
316
|
+
return self.max_size
|
|
317
|
+
raise RuntimeError(
|
|
318
|
+
f"storage._max_size_along_dim0 is not supported for storages of type {type(self)} when ndim > 1."
|
|
319
|
+
f"Please report this exception as well as the use case (incl. buffer construction) on github."
|
|
320
|
+
)
|
|
321
|
+
|
|
322
|
+
def flatten(self):
|
|
323
|
+
if self.ndim == 1:
|
|
324
|
+
return self
|
|
325
|
+
raise RuntimeError(
|
|
326
|
+
f"storage.flatten is not supported for storages of type {type(self)} when ndim > 1."
|
|
327
|
+
f"Please report this exception as well as the use case (incl. buffer construction) on github."
|
|
328
|
+
)
|
|
329
|
+
|
|
330
|
+
def save(self, *args, **kwargs):
|
|
331
|
+
"""Alias for :meth:`dumps`."""
|
|
332
|
+
return self.dumps(*args, **kwargs)
|
|
333
|
+
|
|
334
|
+
def dump(self, *args, **kwargs):
|
|
335
|
+
"""Alias for :meth:`dumps`."""
|
|
336
|
+
return self.dumps(*args, **kwargs)
|
|
337
|
+
|
|
338
|
+
def load(self, *args, **kwargs):
|
|
339
|
+
"""Alias for :meth:`loads`."""
|
|
340
|
+
return self.loads(*args, **kwargs)
|
|
341
|
+
|
|
342
|
+
def __getstate__(self):
|
|
343
|
+
state = copy(self.__dict__)
|
|
344
|
+
state["_rng"] = None
|
|
345
|
+
return state
|
|
346
|
+
|
|
347
|
+
def __contains__(self, item):
|
|
348
|
+
return self.contains(item)
|
|
349
|
+
|
|
350
|
+
@abc.abstractmethod
|
|
351
|
+
def contains(self, item):
|
|
352
|
+
...
|
|
353
|
+
|
|
354
|
+
|
|
355
|
+
class ListStorage(Storage):
|
|
356
|
+
"""A storage stored in a list.
|
|
357
|
+
|
|
358
|
+
This class cannot be extended with PyTrees, the data provided during calls to
|
|
359
|
+
:meth:`~torchrl.data.replay_buffers.ReplayBuffer.extend` should be iterables
|
|
360
|
+
(like lists, tuples, tensors or tensordicts with non-empty batch-size).
|
|
361
|
+
|
|
362
|
+
Args:
|
|
363
|
+
max_size (int, optional): the maximum number of elements stored in the storage.
|
|
364
|
+
If not provided, an unlimited storage is created.
|
|
365
|
+
|
|
366
|
+
Keyword Args:
|
|
367
|
+
compilable (bool, optional): if ``True``, the storage will be made compatible with :func:`~torch.compile` at
|
|
368
|
+
the cost of being executable in multiprocessed settings.
|
|
369
|
+
device (str, optional): the device to use for the storage. Defaults to `None` (inputs are not moved to the device).
|
|
370
|
+
|
|
371
|
+
"""
|
|
372
|
+
|
|
373
|
+
_default_checkpointer = ListStorageCheckpointer
|
|
374
|
+
|
|
375
|
+
def __init__(
|
|
376
|
+
self,
|
|
377
|
+
max_size: int | None = None,
|
|
378
|
+
*,
|
|
379
|
+
compilable: bool = False,
|
|
380
|
+
device: torch.device | str | int | None = None,
|
|
381
|
+
):
|
|
382
|
+
if max_size is None:
|
|
383
|
+
max_size = torch.iinfo(torch.int64).max
|
|
384
|
+
super().__init__(max_size, compilable=compilable)
|
|
385
|
+
self._storage = []
|
|
386
|
+
self.device = device
|
|
387
|
+
|
|
388
|
+
def _to_device(self, data: Any) -> Any:
|
|
389
|
+
"""Utility method to move data to the device."""
|
|
390
|
+
if self.device is not None:
|
|
391
|
+
if hasattr(data, "to"):
|
|
392
|
+
data = data.to(self.device)
|
|
393
|
+
else:
|
|
394
|
+
data = tree_map(
|
|
395
|
+
lambda x: x.to(self.device) if hasattr(x, "to") else x, data
|
|
396
|
+
)
|
|
397
|
+
return data
|
|
398
|
+
|
|
399
|
+
def set(
|
|
400
|
+
self,
|
|
401
|
+
cursor: int | Sequence[int] | slice,
|
|
402
|
+
data: Any,
|
|
403
|
+
*,
|
|
404
|
+
set_cursor: bool = True,
|
|
405
|
+
):
|
|
406
|
+
if not isinstance(cursor, INT_CLASSES):
|
|
407
|
+
if (isinstance(cursor, torch.Tensor) and cursor.ndim == 0) or (
|
|
408
|
+
isinstance(cursor, np.ndarray) and cursor.ndim == 0
|
|
409
|
+
):
|
|
410
|
+
self.set(int(cursor), data, set_cursor=set_cursor)
|
|
411
|
+
return
|
|
412
|
+
if isinstance(cursor, slice):
|
|
413
|
+
data = self._to_device(data)
|
|
414
|
+
self._set_slice(cursor, data)
|
|
415
|
+
return
|
|
416
|
+
if isinstance(
|
|
417
|
+
data,
|
|
418
|
+
(
|
|
419
|
+
list,
|
|
420
|
+
tuple,
|
|
421
|
+
torch.Tensor,
|
|
422
|
+
TensorDictBase,
|
|
423
|
+
*tensordict.base._ACCEPTED_CLASSES,
|
|
424
|
+
range,
|
|
425
|
+
set,
|
|
426
|
+
np.ndarray,
|
|
427
|
+
),
|
|
428
|
+
):
|
|
429
|
+
for _cursor, _data in _zip_strict(cursor, data):
|
|
430
|
+
self.set(_cursor, _data, set_cursor=set_cursor)
|
|
431
|
+
else:
|
|
432
|
+
raise TypeError(
|
|
433
|
+
f"Cannot extend a {type(self)} with data of type {type(data)}. "
|
|
434
|
+
f"Provide a list, tuple, set, range, np.ndarray, tensor or tensordict subclass instead."
|
|
435
|
+
)
|
|
436
|
+
return
|
|
437
|
+
else:
|
|
438
|
+
if cursor > len(self._storage):
|
|
439
|
+
raise RuntimeError(
|
|
440
|
+
"Cannot append data located more than one item away from "
|
|
441
|
+
f"the storage size: the storage size is {len(self._storage)} "
|
|
442
|
+
f"and the index of the item to be set is {cursor}."
|
|
443
|
+
)
|
|
444
|
+
if cursor >= self.max_size:
|
|
445
|
+
raise RuntimeError(
|
|
446
|
+
f"Cannot append data to the list storage: "
|
|
447
|
+
f"maximum capacity is {self.max_size} "
|
|
448
|
+
f"and the index of the item to be set is {cursor}."
|
|
449
|
+
)
|
|
450
|
+
data = self._to_device(data)
|
|
451
|
+
self._set_item(cursor, data)
|
|
452
|
+
|
|
453
|
+
def _set_item(self, cursor: int, data: Any) -> None:
|
|
454
|
+
"""Set a single item in the storage."""
|
|
455
|
+
if cursor == len(self._storage):
|
|
456
|
+
self._storage.append(data)
|
|
457
|
+
else:
|
|
458
|
+
self._storage[cursor] = data
|
|
459
|
+
|
|
460
|
+
def _set_slice(self, cursor: slice, data: Any) -> None:
|
|
461
|
+
"""Set a slice in the storage."""
|
|
462
|
+
self._storage[cursor] = data
|
|
463
|
+
|
|
464
|
+
def get(self, index: int | Sequence[int] | slice) -> Any:
|
|
465
|
+
if isinstance(index, INT_CLASSES):
|
|
466
|
+
return self._get_item(index)
|
|
467
|
+
elif isinstance(index, slice):
|
|
468
|
+
return self._get_slice(index)
|
|
469
|
+
elif isinstance(index, tuple):
|
|
470
|
+
if len(index) > 1:
|
|
471
|
+
raise RuntimeError(
|
|
472
|
+
f"{type(self).__name__} can only be indexed with one-length tuples."
|
|
473
|
+
)
|
|
474
|
+
return self.get(index[0])
|
|
475
|
+
else:
|
|
476
|
+
if isinstance(index, torch.Tensor) and index.device.type != "cpu":
|
|
477
|
+
index = index.cpu().tolist()
|
|
478
|
+
return self._get_list(index)
|
|
479
|
+
|
|
480
|
+
def _get_item(self, index: int) -> Any:
|
|
481
|
+
"""Get a single item from the storage."""
|
|
482
|
+
return self._storage[index]
|
|
483
|
+
|
|
484
|
+
def _get_slice(self, index: slice) -> Any:
|
|
485
|
+
"""Get a slice from the storage."""
|
|
486
|
+
return self._storage[index]
|
|
487
|
+
|
|
488
|
+
def _get_list(self, index: list) -> list:
|
|
489
|
+
"""Get a list of items from the storage."""
|
|
490
|
+
return [self._storage[i] for i in index]
|
|
491
|
+
|
|
492
|
+
def __len__(self):
|
|
493
|
+
"""Get the length of the storage."""
|
|
494
|
+
return len(self._storage)
|
|
495
|
+
|
|
496
|
+
def state_dict(self) -> dict[str, Any]:
|
|
497
|
+
return {
|
|
498
|
+
"_storage": [
|
|
499
|
+
elt if not hasattr(elt, "state_dict") else elt.state_dict()
|
|
500
|
+
for elt in self._storage
|
|
501
|
+
]
|
|
502
|
+
}
|
|
503
|
+
|
|
504
|
+
def load_state_dict(self, state_dict):
|
|
505
|
+
_storage = state_dict["_storage"]
|
|
506
|
+
self._storage = []
|
|
507
|
+
for elt in _storage:
|
|
508
|
+
if isinstance(elt, torch.Tensor):
|
|
509
|
+
self._storage.append(elt)
|
|
510
|
+
elif isinstance(elt, (dict, OrderedDict)):
|
|
511
|
+
self._storage.append(TensorDict().load_state_dict(elt, strict=False))
|
|
512
|
+
else:
|
|
513
|
+
raise TypeError(
|
|
514
|
+
f"Objects of type {type(elt)} are not supported by ListStorage.load_state_dict"
|
|
515
|
+
)
|
|
516
|
+
|
|
517
|
+
def _empty(self):
|
|
518
|
+
self._storage = []
|
|
519
|
+
|
|
520
|
+
def __getstate__(self):
|
|
521
|
+
if get_spawning_popen() is not None:
|
|
522
|
+
raise RuntimeError(
|
|
523
|
+
f"Cannot share a storage of type {type(self)} between processes."
|
|
524
|
+
)
|
|
525
|
+
state = super().__getstate__()
|
|
526
|
+
return state
|
|
527
|
+
|
|
528
|
+
def __repr__(self):
|
|
529
|
+
storage = getattr(self, "_storage", [None])
|
|
530
|
+
if not storage:
|
|
531
|
+
return f"{self.__class__.__name__}()"
|
|
532
|
+
return f"{self.__class__.__name__}(items=[{storage[0]}, ...])"
|
|
533
|
+
|
|
534
|
+
def contains(self, item):
|
|
535
|
+
if isinstance(item, int):
|
|
536
|
+
if item < 0:
|
|
537
|
+
item += len(self._storage)
|
|
538
|
+
return self._contains_int(item)
|
|
539
|
+
if isinstance(item, torch.Tensor):
|
|
540
|
+
return torch.tensor(
|
|
541
|
+
[self.contains(elt) for elt in item.tolist()],
|
|
542
|
+
dtype=torch.bool,
|
|
543
|
+
device=item.device,
|
|
544
|
+
).reshape_as(item)
|
|
545
|
+
raise NotImplementedError(f"type {type(item)} is not supported yet.")
|
|
546
|
+
|
|
547
|
+
def _contains_int(self, item: int) -> bool:
|
|
548
|
+
"""Check if an integer index is contained in the storage."""
|
|
549
|
+
return 0 <= item < len(self._storage)
|
|
550
|
+
|
|
551
|
+
|
|
552
|
+
class LazyStackStorage(ListStorage):
|
|
553
|
+
"""A ListStorage that returns LazyStackTensorDict instances.
|
|
554
|
+
|
|
555
|
+
This storage allows for heterougeneous structures to be indexed as a single `TensorDict` representation.
|
|
556
|
+
It uses :class:`~tensordict.LazyStackedTensorDict` which operates on non-contiguous lists of tensordicts,
|
|
557
|
+
lazily stacking items when queried.
|
|
558
|
+
This means that this storage is going to be fast to sample but data access may be slow (as it requires a stack).
|
|
559
|
+
Tensors of heterogeneous shapes can also be stored within the storage and stacked together.
|
|
560
|
+
Because the storage is represented as a list, the number of tensors to store in memory will grow linearly with
|
|
561
|
+
the size of the buffer.
|
|
562
|
+
|
|
563
|
+
If possible, nested tensors can also be created via :meth:`~tensordict.LazyStackedTensorDict.densify`
|
|
564
|
+
(see :mod:`~torch.nested`).
|
|
565
|
+
|
|
566
|
+
Args:
|
|
567
|
+
max_size (int, optional): the maximum number of elements stored in the storage.
|
|
568
|
+
If not provided, an unlimited storage is created.
|
|
569
|
+
|
|
570
|
+
Keyword Args:
|
|
571
|
+
compilable (bool, optional): if ``True``, the storage will be made compatible with :func:`~torch.compile` at
|
|
572
|
+
the cost of being executable in multiprocessed settings.
|
|
573
|
+
stack_dim (int, optional): the stack dimension in terms of TensorDict batch sizes. Defaults to `0`.
|
|
574
|
+
device (str, optional): the device to use for the storage. Defaults to `None` (inputs are not moved to the device).
|
|
575
|
+
|
|
576
|
+
Examples:
|
|
577
|
+
>>> import torch
|
|
578
|
+
>>> from torchrl.data import ReplayBuffer, LazyStackStorage
|
|
579
|
+
>>> from tensordict import TensorDict
|
|
580
|
+
>>> _ = torch.manual_seed(0)
|
|
581
|
+
>>> rb = ReplayBuffer(storage=LazyStackStorage(max_size=1000, stack_dim=-1))
|
|
582
|
+
>>> data0 = TensorDict(a=torch.randn((10,)), b=torch.rand(4), c="a string!")
|
|
583
|
+
>>> data1 = TensorDict(a=torch.randn((11,)), b=torch.rand(4), c="another string!")
|
|
584
|
+
>>> _ = rb.add(data0)
|
|
585
|
+
>>> _ = rb.add(data1)
|
|
586
|
+
>>> rb.sample(10)
|
|
587
|
+
LazyStackedTensorDict(
|
|
588
|
+
fields={
|
|
589
|
+
a: Tensor(shape=torch.Size([10, -1]), device=cpu, dtype=torch.float32, is_shared=False),
|
|
590
|
+
b: Tensor(shape=torch.Size([10, 4]), device=cpu, dtype=torch.float32, is_shared=False),
|
|
591
|
+
c: NonTensorStack(
|
|
592
|
+
['another string!', 'another string!', 'another st...,
|
|
593
|
+
batch_size=torch.Size([10]),
|
|
594
|
+
device=None)},
|
|
595
|
+
exclusive_fields={
|
|
596
|
+
},
|
|
597
|
+
batch_size=torch.Size([10]),
|
|
598
|
+
device=None,
|
|
599
|
+
is_shared=False,
|
|
600
|
+
stack_dim=0)
|
|
601
|
+
"""
|
|
602
|
+
|
|
603
|
+
def __init__(
|
|
604
|
+
self,
|
|
605
|
+
max_size: int | None = None,
|
|
606
|
+
*,
|
|
607
|
+
compilable: bool = False,
|
|
608
|
+
stack_dim: int = 0,
|
|
609
|
+
device: torch.device | str | int | None = None,
|
|
610
|
+
):
|
|
611
|
+
super().__init__(max_size=max_size, compilable=compilable, device=device)
|
|
612
|
+
self.stack_dim = stack_dim
|
|
613
|
+
|
|
614
|
+
def get(self, index: int | Sequence[int] | slice) -> Any:
|
|
615
|
+
out = super().get(index=index)
|
|
616
|
+
if isinstance(out, list):
|
|
617
|
+
stack_dim = self.stack_dim
|
|
618
|
+
if stack_dim < 0:
|
|
619
|
+
stack_dim = out[0].ndim + 1 + stack_dim
|
|
620
|
+
out = lazy_stack(list(out), stack_dim)
|
|
621
|
+
return out
|
|
622
|
+
return out
|
|
623
|
+
|
|
624
|
+
|
|
625
|
+
class TensorStorage(Storage):
|
|
626
|
+
"""A storage for tensors and tensordicts.
|
|
627
|
+
|
|
628
|
+
Args:
|
|
629
|
+
storage (tensor or TensorDict): the data buffer to be used.
|
|
630
|
+
max_size (int): size of the storage, i.e. maximum number of elements stored
|
|
631
|
+
in the buffer.
|
|
632
|
+
|
|
633
|
+
Keyword Args:
|
|
634
|
+
device (torch.device, optional): device where the sampled tensors will be
|
|
635
|
+
stored and sent. Default is :obj:`torch.device("cpu")`.
|
|
636
|
+
If "auto" is passed, the device is automatically gathered from the
|
|
637
|
+
first batch of data passed. This is not enabled by default to avoid
|
|
638
|
+
data placed on GPU by mistake, causing OOM issues.
|
|
639
|
+
ndim (int, optional): the number of dimensions to be accounted for when
|
|
640
|
+
measuring the storage size. For instance, a storage of shape ``[3, 4]``
|
|
641
|
+
has capacity ``3`` if ``ndim=1`` and ``12`` if ``ndim=2``.
|
|
642
|
+
Defaults to ``1``.
|
|
643
|
+
compilable (bool, optional): whether the storage is compilable.
|
|
644
|
+
If ``True``, the writer cannot be shared between multiple processes.
|
|
645
|
+
Defaults to ``False``.
|
|
646
|
+
|
|
647
|
+
Examples:
|
|
648
|
+
>>> data = TensorDict({
|
|
649
|
+
... "some data": torch.randn(10, 11),
|
|
650
|
+
... ("some", "nested", "data"): torch.randn(10, 11, 12),
|
|
651
|
+
... }, batch_size=[10, 11])
|
|
652
|
+
>>> storage = TensorStorage(data)
|
|
653
|
+
>>> len(storage) # only the first dimension is considered as indexable
|
|
654
|
+
10
|
|
655
|
+
>>> storage.get(0)
|
|
656
|
+
TensorDict(
|
|
657
|
+
fields={
|
|
658
|
+
some data: Tensor(shape=torch.Size([11]), device=cpu, dtype=torch.float32, is_shared=False),
|
|
659
|
+
some: TensorDict(
|
|
660
|
+
fields={
|
|
661
|
+
nested: TensorDict(
|
|
662
|
+
fields={
|
|
663
|
+
data: Tensor(shape=torch.Size([11, 12]), device=cpu, dtype=torch.float32, is_shared=False)},
|
|
664
|
+
batch_size=torch.Size([11]),
|
|
665
|
+
device=None,
|
|
666
|
+
is_shared=False)},
|
|
667
|
+
batch_size=torch.Size([11]),
|
|
668
|
+
device=None,
|
|
669
|
+
is_shared=False)},
|
|
670
|
+
batch_size=torch.Size([11]),
|
|
671
|
+
device=None,
|
|
672
|
+
is_shared=False)
|
|
673
|
+
>>> storage.set(0, storage.get(0).zero_()) # zeros the data along index ``0``
|
|
674
|
+
|
|
675
|
+
This class also supports tensorclass data.
|
|
676
|
+
|
|
677
|
+
Examples:
|
|
678
|
+
>>> from tensordict import tensorclass
|
|
679
|
+
>>> @tensorclass
|
|
680
|
+
... class MyClass:
|
|
681
|
+
... foo: torch.Tensor
|
|
682
|
+
... bar: torch.Tensor
|
|
683
|
+
>>> data = MyClass(foo=torch.randn(10, 11), bar=torch.randn(10, 11, 12), batch_size=[10, 11])
|
|
684
|
+
>>> storage = TensorStorage(data)
|
|
685
|
+
>>> storage.get(0)
|
|
686
|
+
MyClass(
|
|
687
|
+
bar=Tensor(shape=torch.Size([11, 12]), device=cpu, dtype=torch.float32, is_shared=False),
|
|
688
|
+
foo=Tensor(shape=torch.Size([11]), device=cpu, dtype=torch.float32, is_shared=False),
|
|
689
|
+
batch_size=torch.Size([11]),
|
|
690
|
+
device=None,
|
|
691
|
+
is_shared=False)
|
|
692
|
+
|
|
693
|
+
"""
|
|
694
|
+
|
|
695
|
+
_storage = None
|
|
696
|
+
_default_checkpointer = TensorStorageCheckpointer
|
|
697
|
+
|
|
698
|
+
def __init__(
|
|
699
|
+
self,
|
|
700
|
+
storage,
|
|
701
|
+
max_size=None,
|
|
702
|
+
*,
|
|
703
|
+
device: torch.device | str = "cpu",
|
|
704
|
+
ndim: int = 1,
|
|
705
|
+
compilable: bool = False,
|
|
706
|
+
):
|
|
707
|
+
if not ((storage is None) ^ (max_size is None)):
|
|
708
|
+
if storage is None:
|
|
709
|
+
raise ValueError("Expected storage to be non-null.")
|
|
710
|
+
if max_size != storage.shape[0]:
|
|
711
|
+
raise ValueError(
|
|
712
|
+
"The max-size and the storage shape mismatch: got "
|
|
713
|
+
f"max_size={max_size} for a storage of shape {storage.shape}."
|
|
714
|
+
)
|
|
715
|
+
elif storage is not None:
|
|
716
|
+
if is_tensor_collection(storage):
|
|
717
|
+
max_size = storage.shape[0]
|
|
718
|
+
else:
|
|
719
|
+
max_size = tree_flatten(storage)[0][0].shape[0]
|
|
720
|
+
self.ndim = ndim
|
|
721
|
+
super().__init__(max_size, compilable=compilable)
|
|
722
|
+
self.initialized = storage is not None
|
|
723
|
+
if self.initialized:
|
|
724
|
+
self._len = max_size
|
|
725
|
+
else:
|
|
726
|
+
self._len = 0
|
|
727
|
+
self.device = (
|
|
728
|
+
_make_ordinal_device(torch.device(device))
|
|
729
|
+
if device != "auto"
|
|
730
|
+
else storage.device
|
|
731
|
+
if storage is not None
|
|
732
|
+
else "auto"
|
|
733
|
+
)
|
|
734
|
+
self._storage = storage
|
|
735
|
+
self._last_cursor = None
|
|
736
|
+
self.__dict__["_storage_keys"] = None
|
|
737
|
+
|
|
738
|
+
@property
|
|
739
|
+
def _storage_keys(self) -> list | None:
|
|
740
|
+
"""Cached list of storage keys for filtering incoming data.
|
|
741
|
+
|
|
742
|
+
Returns None if storage is not locked, not a tensor collection, or not initialized.
|
|
743
|
+
Only locked storage (shared memory) needs key filtering to prevent adding
|
|
744
|
+
keys that won't propagate in multiprocessing pipelines.
|
|
745
|
+
"""
|
|
746
|
+
keys = self.__dict__.get("_storage_keys")
|
|
747
|
+
if keys is None and self.initialized and is_tensor_collection(self._storage):
|
|
748
|
+
# Only cache keys if storage is locked - unlocked storage can accept new keys
|
|
749
|
+
if self._storage.is_locked:
|
|
750
|
+
keys = list(
|
|
751
|
+
self._storage.keys(
|
|
752
|
+
include_nested=True,
|
|
753
|
+
leaves_only=True,
|
|
754
|
+
is_leaf=_NESTED_TENSORS_AS_LISTS,
|
|
755
|
+
)
|
|
756
|
+
)
|
|
757
|
+
self.__dict__["_storage_keys"] = keys
|
|
758
|
+
return keys
|
|
759
|
+
|
|
760
|
+
@_storage_keys.setter
|
|
761
|
+
def _storage_keys(self, value):
|
|
762
|
+
self.__dict__["_storage_keys"] = value
|
|
763
|
+
|
|
764
|
+
@property
|
|
765
|
+
def _len(self):
|
|
766
|
+
_len_value = self.__dict__.get("_len_value", None)
|
|
767
|
+
if not self._compilable:
|
|
768
|
+
if _len_value is None:
|
|
769
|
+
_len_value = self._len_value = mp.Value("i", 0)
|
|
770
|
+
return _len_value.value
|
|
771
|
+
else:
|
|
772
|
+
if _len_value is None:
|
|
773
|
+
_len_value = self._len_value = 0
|
|
774
|
+
return _len_value
|
|
775
|
+
|
|
776
|
+
@_len.setter
|
|
777
|
+
def _len(self, value):
|
|
778
|
+
if not is_compiling() and not self._compilable:
|
|
779
|
+
_len_value = self.__dict__.get("_len_value", None)
|
|
780
|
+
if _len_value is None:
|
|
781
|
+
_len_value = self._len_value = mp.Value("i", 0)
|
|
782
|
+
_len_value.value = value
|
|
783
|
+
else:
|
|
784
|
+
self._len_value = value
|
|
785
|
+
|
|
786
|
+
@property
|
|
787
|
+
def _total_shape(self):
|
|
788
|
+
# Total shape, irrespective of how full the storage is
|
|
789
|
+
_total_shape = self.__dict__.get("_total_shape_value", None)
|
|
790
|
+
if _total_shape is None and self.initialized:
|
|
791
|
+
if is_tensor_collection(self._storage):
|
|
792
|
+
_total_shape = self._storage.shape[: self.ndim]
|
|
793
|
+
else:
|
|
794
|
+
leaf = next(tree_iter(self._storage))
|
|
795
|
+
_total_shape = leaf.shape[: self.ndim]
|
|
796
|
+
self.__dict__["_total_shape_value"] = _total_shape
|
|
797
|
+
self._len = torch.Size([self._len_along_dim0, *_total_shape[1:]]).numel()
|
|
798
|
+
return _total_shape
|
|
799
|
+
|
|
800
|
+
@property
|
|
801
|
+
def _is_full(self):
|
|
802
|
+
# whether the storage is full
|
|
803
|
+
return len(self) == self.max_size
|
|
804
|
+
|
|
805
|
+
@property
|
|
806
|
+
def _len_along_dim0(self):
|
|
807
|
+
# returns the length of the buffer along dim0
|
|
808
|
+
len_along_dim = len(self)
|
|
809
|
+
if self.ndim > 1:
|
|
810
|
+
_total_shape = self._total_shape
|
|
811
|
+
if _total_shape is not None:
|
|
812
|
+
len_along_dim = -(len_along_dim // -_total_shape[1:].numel())
|
|
813
|
+
else:
|
|
814
|
+
return None
|
|
815
|
+
return len_along_dim
|
|
816
|
+
|
|
817
|
+
def _max_size_along_dim0(self, *, single_data=None, batched_data=None):
|
|
818
|
+
# returns the max_size of the buffer along dim0
|
|
819
|
+
max_size = self.max_size
|
|
820
|
+
if self.ndim > 1:
|
|
821
|
+
shape = self.shape
|
|
822
|
+
if shape is None:
|
|
823
|
+
if single_data is not None:
|
|
824
|
+
data = single_data
|
|
825
|
+
elif batched_data is not None:
|
|
826
|
+
data = batched_data
|
|
827
|
+
else:
|
|
828
|
+
raise ValueError("single_data or batched_data must be passed.")
|
|
829
|
+
if is_tensor_collection(data):
|
|
830
|
+
datashape = data.shape[: self.ndim]
|
|
831
|
+
else:
|
|
832
|
+
for leaf in tree_iter(data):
|
|
833
|
+
datashape = leaf.shape[: self.ndim]
|
|
834
|
+
break
|
|
835
|
+
if batched_data is not None:
|
|
836
|
+
datashape = datashape[1:]
|
|
837
|
+
max_size = -(max_size // -datashape.numel())
|
|
838
|
+
else:
|
|
839
|
+
max_size = -(max_size // -self._total_shape[1:].numel())
|
|
840
|
+
return max_size
|
|
841
|
+
|
|
842
|
+
@property
|
|
843
|
+
def shape(self):
|
|
844
|
+
# Shape, truncated where needed to accommodate for the length of the storage
|
|
845
|
+
if self._is_full:
|
|
846
|
+
return self._total_shape
|
|
847
|
+
_total_shape = self._total_shape
|
|
848
|
+
if _total_shape is not None:
|
|
849
|
+
return torch.Size([self._len_along_dim0] + list(_total_shape[1:]))
|
|
850
|
+
|
|
851
|
+
# TODO: Without this disable, compiler recompiles for back-to-back calls.
|
|
852
|
+
# Figuring out a way to avoid this disable would give better performance.
|
|
853
|
+
@compile_disable()
|
|
854
|
+
def _rand_given_ndim(self, batch_size):
|
|
855
|
+
return self._rand_given_ndim_impl(batch_size)
|
|
856
|
+
|
|
857
|
+
# At the moment, this is separated into its own function so that we can test
|
|
858
|
+
# it without the `disable` and detect if future updates to the
|
|
859
|
+
# compiler fix the recompile issue.
|
|
860
|
+
def _rand_given_ndim_impl(self, batch_size):
|
|
861
|
+
if self.ndim == 1:
|
|
862
|
+
return super()._rand_given_ndim(batch_size)
|
|
863
|
+
shape = self.shape
|
|
864
|
+
return tuple(
|
|
865
|
+
torch.randint(_dim, (batch_size,), generator=self._rng, device=self.device)
|
|
866
|
+
for _dim in shape
|
|
867
|
+
)
|
|
868
|
+
|
|
869
|
+
def flatten(self):
|
|
870
|
+
if self.ndim == 1:
|
|
871
|
+
return self
|
|
872
|
+
if not self.initialized:
|
|
873
|
+
raise RuntimeError("Cannot flatten a non-initialized storage.")
|
|
874
|
+
if is_tensor_collection(self._storage):
|
|
875
|
+
if self._is_full:
|
|
876
|
+
return TensorStorage(self._storage.flatten(0, self.ndim - 1))
|
|
877
|
+
return TensorStorage(
|
|
878
|
+
self._storage[: self._len_along_dim0].flatten(0, self.ndim - 1)
|
|
879
|
+
)
|
|
880
|
+
if self._is_full:
|
|
881
|
+
return TensorStorage(
|
|
882
|
+
tree_map(lambda x: x.flatten(0, self.ndim - 1), self._storage)
|
|
883
|
+
)
|
|
884
|
+
return TensorStorage(
|
|
885
|
+
tree_map(
|
|
886
|
+
lambda x: x[: self._len_along_dim0].flatten(0, self.ndim - 1),
|
|
887
|
+
self._storage,
|
|
888
|
+
)
|
|
889
|
+
)
|
|
890
|
+
|
|
    def __getstate__(self):
        state = super().__getstate__()
        if get_spawning_popen() is None:
            length = self._len
            del state["_len_value"]
            state["len__context"] = length
        elif not self.initialized:
            if not self.shared_init:
                # check that the storage is initialized
                raise RuntimeError(
                    f"Cowardly refusing to share a storage of type {type(self)} between processes if "
                    f"it has not been initialized yet. You can either:\n"
                    f"- Populate the buffer with some data in the main process before passing it to the other processes (or create the buffer explicitly with a TensorStorage).\n"
                    f"- set shared_init=True when creating the storage such that it can be initialized by the remote processes."
                )
            return state
        else:
            # check that the content is shared, otherwise tell the user we can't help
            storage = self._storage
            STORAGE_ERR = "The storage must be placed in shared memory or memmapped before being shared between processes."

            # If the content is on cpu, it will be placed in shared memory.
            # If it's on cuda, it's already shared.
            # If it's memmapped, there is no concern in this case either.
            # Only if the device is neither "cpu" nor "cuda" may we have a problem.
            def assert_is_sharable(tensor):
                if tensor.device is None or tensor.device.type in (
                    "cuda",
                    "cpu",
                    "meta",
                ):
                    return
                raise RuntimeError(STORAGE_ERR)

            if is_tensor_collection(storage):
                storage.apply(assert_is_sharable, filter_empty=True)
            else:
                # tree_map takes the function first, then the tree
                tree_map(assert_is_sharable, storage)

        return state

    def __setstate__(self, state):
        length = state.pop("len__context", None)
        if length is not None:
            if not state["_compilable"]:
                _len_value = mp.Value("i", length)
                state["_len_value"] = _len_value
            else:
                state["_len_value"] = length
        self.__dict__.update(state)
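
    # A hedged sketch of the (de)serialization pair above: outside of process
    # spawning, pickling folds the shared ``_len`` counter into a plain value
    # and unpickling rebuilds it, so a simple round trip preserves the length.
    #
    #     >>> import pickle
    #     >>> storage = LazyTensorStorage(10)
    #     >>> storage.set(0, torch.zeros(2))  # initialize before sharing
    #     >>> restored = pickle.loads(pickle.dumps(storage))
    #     >>> assert len(restored) == len(storage)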
    def state_dict(self) -> dict[str, Any]:
        _storage = self._storage
        if isinstance(_storage, torch.Tensor):
            pass
        elif is_tensor_collection(_storage):
            _storage = _storage.state_dict()
        elif _storage is None:
            _storage = {}
        else:
            raise TypeError(
                f"Objects of type {type(_storage)} are not supported by {type(self)}.state_dict"
            )
        return {
            "_storage": _storage,
            "initialized": self.initialized,
            "_len": self._len,
        }

    def load_state_dict(self, state_dict):
        _storage = copy(state_dict["_storage"])
        if isinstance(_storage, torch.Tensor):
            if isinstance(self._storage, torch.Tensor):
                self._storage.copy_(_storage)
            elif self._storage is None:
                self._storage = _storage
            else:
                raise RuntimeError(
                    f"Cannot copy a storage of type {type(_storage)} onto another of type {type(self._storage)}"
                )
        elif isinstance(_storage, (dict, OrderedDict)):
            if is_tensor_collection(self._storage):
                self._storage.load_state_dict(_storage, strict=False)
            elif self._storage is None:
                self._storage = TensorDict().load_state_dict(_storage, strict=False)
            else:
                raise RuntimeError(
                    f"Cannot copy a storage of type {type(_storage)} onto another of type {type(self._storage)}. If your storage is pytree-based, use the dumps/load API instead."
                )
        else:
            raise TypeError(
                f"Objects of type {type(_storage)} are not supported by {type(self).__name__}.load_state_dict"
            )
        self.initialized = state_dict["initialized"]
        self._len = state_dict["_len"]
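
    # A minimal sketch of the state-dict round trip defined above: the state
    # dict carries the underlying tensors plus the ``initialized`` flag and
    # the current length, so a freshly built storage of the same capacity can
    # absorb it.
    #
    #     >>> src = LazyTensorStorage(10)
    #     >>> src.set(range(4), TensorDict(a=torch.ones(4, 3), batch_size=[4]))
    #     >>> dest = LazyTensorStorage(10)
    #     >>> dest.load_state_dict(src.state_dict())
    #     >>> assert len(dest) == 4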
    @implement_for("torch", "2.3", compilable=True)
    def _set_tree_map(self, cursor, data, storage):
        def set_tensor(datum, store):
            store[cursor] = datum

        # this won't be available until v2.3
        tree_map(set_tensor, data, storage)

    @implement_for("torch", "2.0", "2.3", compilable=True)
    def _set_tree_map(self, cursor, data, storage):  # noqa: 534
        # flatten data and cursor
        data_flat = tree_flatten(data)[0]
        storage_flat = tree_flatten(storage)[0]
        for datum, store in zip(data_flat, storage_flat):
            store[cursor] = datum

    def _get_new_len(self, data, cursor):
        int_cursor = _is_int(cursor)
        ndim = self.ndim - int_cursor
        if is_tensor_collection(data) or isinstance(data, torch.Tensor):
            numel = data.shape[:ndim].numel()
        else:
            leaf = next(tree_iter(data))
            numel = leaf.shape[:ndim].numel()
        self._len = min(self._len + numel, self.max_size)
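
    # A worked example of the length update above (hypothetical numbers): with
    # ``ndim=2``, ``max_size=100`` and a write of data with batch shape
    # ``[5, 4]`` under a non-integer cursor, ``numel = 5 * 4 = 20`` and
    #
    #     _len <- min(_len + 20, 100)
    #
    # With an integer cursor, one batch dimension is consumed by the cursor,
    # so only ``data.shape[:ndim - 1]`` counts toward the new length.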
    @implement_for("torch", "2.0", None, compilable=True)
    def set(
        self,
        cursor: int | Sequence[int] | slice,
        data: TensorDictBase | torch.Tensor,
        *,
        set_cursor: bool = True,
    ):
        if set_cursor:
            self._last_cursor = cursor

        if isinstance(data, list):
            # flip list
            try:
                data = _flip_list(data)
            except Exception:
                raise RuntimeError(
                    "Stacking the elements of the list resulted in "
                    "an error. "
                    f"Storages of type {type(self)} expect all elements of the list "
                    "to have the same tree structure. If the list is compact (each "
                    "leaf is itself a batch with the appropriate number of elements) "
                    "consider using a tuple instead, as lists are used within `extend` "
                    "for per-item addition."
                )

        if set_cursor:
            self._get_new_len(data, cursor)

        if not self.initialized:
            if not isinstance(cursor, INT_CLASSES):
                if is_tensor_collection(data):
                    self._init(data[0])
                else:
                    self._init(tree_map(lambda x: x[0], data))
            else:
                self._init(data)

        if is_tensor_collection(data):
            # Filter data to only include keys present in storage.
            # _storage_keys is only set when storage is locked (shared memory),
            # so this handles cases where policy outputs extra keys that can't
            # be added to locked shared memory.
            storage_keys = self._storage_keys
            if storage_keys is not None:
                data = data.select(*storage_keys, strict=False)
            try:
                self._storage[cursor] = data
            except RuntimeError as e:
                if "locked" in str(e).lower():
                    # Provide an informative error about key differences
                    self._raise_informative_lock_error(data, e)
                raise
        else:
            self._set_tree_map(cursor, data, self._storage)
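
    # A hedged usage sketch for ``set`` above: the first non-integer write
    # lazily initializes the buffers from ``data[0]``, then writes at the
    # given cursor.
    #
    #     >>> storage = LazyTensorStorage(100)
    #     >>> batch = TensorDict(obs=torch.randn(10, 4), batch_size=[10])
    #     >>> storage.set(range(10), batch)  # first write initializes the buffers
    #     >>> len(storage)
    #     10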
    @implement_for("torch", None, "2.0", compilable=True)
    def set(  # noqa: F811
        self,
        cursor: int | Sequence[int] | slice,
        data: TensorDictBase | torch.Tensor,
        *,
        set_cursor: bool = True,
    ):
        if set_cursor:
            self._last_cursor = cursor

        if isinstance(data, list):
            # flip list
            try:
                data = _flip_list(data)
            except Exception:
                raise RuntimeError(
                    "Stacking the elements of the list resulted in "
                    "an error. "
                    f"Storages of type {type(self)} expect all elements of the list "
                    "to have the same tree structure. If the list is compact (each "
                    "leaf is itself a batch with the appropriate number of elements) "
                    "consider using a tuple instead, as lists are used within `extend` "
                    "for per-item addition."
                )
        if set_cursor:
            self._get_new_len(data, cursor)

        if not is_tensor_collection(data) and not isinstance(data, torch.Tensor):
            raise NotImplementedError(
                "storage extension with pytrees is only available with torch >= 2.0. If you need this "
                "feature, please open an issue on TorchRL's github repository."
            )
        if not self.initialized:
            if not isinstance(cursor, INT_CLASSES):
                self._init(data[0])
            else:
                self._init(data)

        if not isinstance(cursor, (*INT_CLASSES, slice)):
            if not isinstance(cursor, torch.Tensor):
                cursor = torch.tensor(cursor, dtype=torch.long)
            elif cursor.dtype != torch.long:
                cursor = cursor.to(dtype=torch.long)
            if len(cursor) > self._len_along_dim0:
                warnings.warn(
                    "A cursor longer than the storage capacity was provided. "
                    "To accommodate this, the cursor will be truncated to its "
                    "trailing elements so that its length matches the capacity "
                    "of the storage. This may **not** be the desired behavior "
                    "for your application! Make sure that the storage capacity "
                    "is large enough to support the batch size provided."
                )
        # Filter data to only include keys present in storage.
        # _storage_keys is only set when storage is locked (shared memory),
        # so this handles cases where policy outputs extra keys that can't
        # be added to locked shared memory.
        if is_tensor_collection(data):
            storage_keys = self._storage_keys
            if storage_keys is not None:
                data = data.select(*storage_keys, strict=False)
        try:
            self._storage[cursor] = data
        except RuntimeError as e:
            if "locked" in str(e).lower():
                # Provide an informative error about key differences
                self._raise_informative_lock_error(data, e)
            raise

    def _wait_for_init(self):
        pass

    def _raise_informative_lock_error(
        self, data: TensorDictBase | torch.Tensor, original_error: RuntimeError
    ) -> None:
        """Raise an informative error when the storage is locked and the data has different keys.

        This method is called when an assignment to the storage fails due to a lock error.
        It provides detailed information about which keys are new in the data vs. what the
        storage expects.
        """
        if not is_tensor_collection(data) or not is_tensor_collection(self._storage):
            # Can only provide detailed info for tensor collections
            raise original_error

        # Get all keys from both storage and data
        storage_keys = set(
            self._storage.keys(
                include_nested=True, leaves_only=True, is_leaf=_NESTED_TENSORS_AS_LISTS
            )
        )
        data_keys = set(
            data.keys(
                include_nested=True, leaves_only=True, is_leaf=_NESTED_TENSORS_AS_LISTS
            )
        )

        new_keys = data_keys - storage_keys
        missing_keys = storage_keys - data_keys

        error_parts = [
            "Cannot write to locked storage due to key mismatch.",
            f"\nOriginal error: {original_error}",
        ]

        if new_keys:
            error_parts.append(
                f"\n\nNew keys in data (not in storage): {sorted(str(k) for k in new_keys)}"
            )
        if missing_keys:
            error_parts.append(
                f"\n\nMissing keys in data (present in storage): {sorted(str(k) for k in missing_keys)}"
            )

        if new_keys or missing_keys:
            error_parts.append(
                "\n\nThis typically happens when:"
                "\n  1. The policy is called on some steps but not others (e.g., during init_random_frames)"
                "\n  2. A transform conditionally adds keys based on data content"
                "\n  3. Different collectors/workers produce data with different keys"
                "\n\nTo fix this, ensure all data written to the buffer has consistent keys."
            )
        else:
            error_parts.append(
                "\n\nNo key differences detected. The lock error may be due to shape or dtype mismatches."
            )

        raise RuntimeError("".join(error_parts)) from original_error
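
    # An illustrative (hypothetical) scenario for the diagnostics above: if a
    # locked, shared storage was initialized without an "action" entry, a
    # later write that carries one fails, and the re-raised error lists
    # "action" under the new keys along with the usual causes (e.g., a policy
    # that only starts running after init_random_frames).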
    def get(self, index: int | Sequence[int] | slice) -> Any:
        _storage = self._storage
        is_tc = is_tensor_collection(_storage)
        if not self.initialized:
            if getattr(self, "shared_init", False):
                self._wait_for_init()
            raise RuntimeError("Cannot get elements out of a non-initialized storage.")
        if not self._is_full:
            if is_tc:
                storage = self._storage[: self._len_along_dim0]
            else:
                storage = tree_map(lambda x: x[: self._len_along_dim0], self._storage)
        else:
            storage = self._storage
        if not self.initialized:
            raise RuntimeError(
                "Cannot get an item from an uninitialized LazyMemmapStorage"
            )
        if is_tc:
            return storage[index]
        else:
            return tree_map(lambda x: x[index], storage)

    # TODO: Without this disable, compiler recompiles due to changing _len_value guards.
    @compile_disable()
    def __len__(self):
        return self._len

    def _empty(self):
        # assuming that the data structure is the same, we don't need to do
        # anything if the cursor is reset to 0
        self._len = 0

    def _init(self):
        raise NotImplementedError(
            f"{type(self)} must be initialized during construction."
        )

    def __repr__(self):
        if not self.initialized:
            storage_str = textwrap.indent("data=<empty>", 4 * " ")
        elif is_tensor_collection(self._storage):
            storage_str = textwrap.indent(f"data={self[:]}", 4 * " ")
        else:

            def repr_item(x):
                if isinstance(x, torch.Tensor):
                    return f"{x.__class__.__name__}(shape={x.shape}, dtype={x.dtype}, device={x.device})"
                return x.__class__.__name__

            storage_str = textwrap.indent(
                f"data={tree_map(repr_item, self[:])}", 4 * " "
            )
        shape_str = textwrap.indent(f"shape={self.shape}", 4 * " ")
        len_str = textwrap.indent(f"len={len(self)}", 4 * " ")
        maxsize_str = textwrap.indent(f"max_size={self.max_size}", 4 * " ")
        return f"{self.__class__.__name__}(\n{storage_str}, \n{shape_str}, \n{len_str}, \n{maxsize_str})"

    def contains(self, item):
        if isinstance(item, int):
            if item < 0:
                item += self._len_along_dim0

            return 0 <= item < self._len_along_dim0
        if isinstance(item, torch.Tensor):

            def _is_valid_index(idx):
                try:
                    torch.zeros(self.shape, device="meta")[idx]
                    return True
                except IndexError:
                    return False

            if item.ndim:
                return torch.tensor(
                    [_is_valid_index(idx) for idx in item],
                    dtype=torch.bool,
                    device=item.device,
                )
            return torch.tensor(_is_valid_index(item), device=item.device)
        raise NotImplementedError(f"type {type(item)} is not supported yet.")
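
    # A small sketch of ``contains`` above: integers are checked against the
    # current length along dim0; tensors are validated elementwise against a
    # meta-device view of the storage shape.
    #
    #     >>> storage = LazyTensorStorage(10)
    #     >>> storage.set(range(5), TensorDict(a=torch.zeros(5), batch_size=[5]))
    #     >>> storage.contains(4), storage.contains(5)
    #     (True, False)

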
class LazyTensorStorage(TensorStorage):
    """A pre-allocated tensor storage for tensors and tensordicts.

    Args:
        max_size (int): size of the storage, i.e. maximum number of elements stored
            in the buffer.

    Keyword Args:
        device (torch.device, optional): device where the sampled tensors will be
            stored and sent. Default is :obj:`torch.device("cpu")`.
            If "auto" is passed, the device is automatically gathered from the
            first batch of data passed. This is not enabled by default to avoid
            data being placed on GPU by mistake, which could cause OOM issues.
        ndim (int, optional): the number of dimensions to be accounted for when
            measuring the storage size. For instance, a storage of shape ``[3, 4]``
            has capacity ``3`` if ``ndim=1`` and ``12`` if ``ndim=2``.
            Defaults to ``1``.
        compilable (bool, optional): whether the storage is compilable.
            If ``True``, the writer cannot be shared between multiple processes.
            Defaults to ``False``.
        consolidated (bool, optional): if ``True``, the storage will be consolidated after
            its first expansion. Defaults to ``False``.
        shared_init (bool, optional): if ``True``, enables multiprocess coordination
            during storage initialization. The first process initializes with a memmap,
            the others wait and load from the shared memmap. Defaults to ``False``.
        cleanup_memmap (bool, optional): if ``True`` and ``shared_init=True``,
            the temporary memmap will be deleted after initialization and the
            storage will operate in RAM. Defaults to ``True``.

    Examples:
        >>> data = TensorDict({
        ...     "some data": torch.randn(10, 11),
        ...     ("some", "nested", "data"): torch.randn(10, 11, 12),
        ... }, batch_size=[10, 11])
        >>> storage = LazyTensorStorage(100)
        >>> storage.set(range(10), data)
        >>> len(storage)  # only the first dimension is considered as indexable
        10
        >>> storage.get(0)
        TensorDict(
            fields={
                some data: Tensor(shape=torch.Size([11]), device=cpu, dtype=torch.float32, is_shared=False),
                some: TensorDict(
                    fields={
                        nested: TensorDict(
                            fields={
                                data: Tensor(shape=torch.Size([11, 12]), device=cpu, dtype=torch.float32, is_shared=False)},
                            batch_size=torch.Size([11]),
                            device=cpu,
                            is_shared=False)},
                    batch_size=torch.Size([11]),
                    device=cpu,
                    is_shared=False)},
            batch_size=torch.Size([11]),
            device=cpu,
            is_shared=False)
        >>> storage.set(0, storage.get(0).zero_())  # zeros the data along index ``0``

    This class also supports tensorclass data.

    Examples:
        >>> from tensordict import tensorclass
        >>> @tensorclass
        ... class MyClass:
        ...     foo: torch.Tensor
        ...     bar: torch.Tensor
        >>> data = MyClass(foo=torch.randn(10, 11), bar=torch.randn(10, 11, 12), batch_size=[10, 11])
        >>> storage = LazyTensorStorage(10)
        >>> storage.set(range(10), data)
        >>> storage.get(0)
        MyClass(
            bar=Tensor(shape=torch.Size([11, 12]), device=cpu, dtype=torch.float32, is_shared=False),
            foo=Tensor(shape=torch.Size([11]), device=cpu, dtype=torch.float32, is_shared=False),
            batch_size=torch.Size([11]),
            device=cpu,
            is_shared=False)

    """

    _default_checkpointer = TensorStorageCheckpointer

    def __init__(
        self,
        max_size: int,
        *,
        device: torch.device | str = "cpu",
        ndim: int = 1,
        compilable: bool = False,
        consolidated: bool = False,
        shared_init: bool = False,
        cleanup_memmap: bool = True,
    ):
        super().__init__(
            storage=None,
            max_size=max_size,
            device=device,
            ndim=ndim,
            compilable=compilable,
        )
        self.consolidated = consolidated
        self.shared_init = shared_init
        self.cleanup_memmap = cleanup_memmap

        # Initialize multiprocess coordination objects if shared_init is enabled
        if self.shared_init:
            if self._compilable:
                raise RuntimeError(
                    "Cannot share a compilable storage between processes."
                )
            self._init_lock = mp.Lock()
            self._init_event = mp.Event()
            self._make_init_directory()

    def _make_init_directory(self):
        if getattr(self, "scratch_dir", None) is not None:
            self._init_directory = self.scratch_dir
            return
        # Create a shared directory
        self.scratch_dir = self._init_directory = tempfile.mkdtemp(
            prefix="torchrl_storage_init_"
        )
        return
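
    # A hedged sketch of the ``shared_init`` handshake wired up above: the
    # first process to reach ``_init`` grabs the lock and becomes the
    # coordinator, materializes the storage as a memmap in the shared
    # directory, then sets the event; every other process waits on the event
    # and attaches to the same memmap.
    #
    #     >>> storage = LazyTensorStorage(1_000, shared_init=True)
    #     >>> # pass `storage` to worker processes; whichever writes first
    #     >>> # coordinates initialization, the rest wait and attach.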
    def _init(
        self,
        data: TensorDictBase | torch.Tensor | PyTree,  # noqa: F821
    ) -> None:
        if not self.shared_init:
            return self._init_standard(data)

        # Try to become coordinator
        is_coordinator = not self._init_event.is_set()
        is_coordinator = is_coordinator and self._init_lock.acquire(block=False)

        if is_coordinator:
            try:
                # We are the coordinator
                self._init_coordinator(data)
            finally:
                # Signal other processes that initialization is complete
                self._init_event.set()
                self._init_lock.release()
        else:
            # Failed to acquire lock, wait for coordinator
            self._wait_for_init()

        self.initialized = True

    def _init_standard(
        self,
        data: TensorDictBase | torch.Tensor | PyTree,  # noqa: F821
    ) -> None:
        """Standard initialization without multiprocess coordination."""
        if not self._compilable:
            # TODO: Investigate why this seems to have a performance impact with
            # the compiler
            torchrl_logger.debug("Creating a TensorStorage...")
        if self.device == "auto":
            self.device = data.device

        def max_size_along_dim0(data_shape):
            if self.ndim > 1:
                result = (
                    -(self.max_size // -data_shape[: self.ndim - 1].numel()),
                    *data_shape,
                )
                self.max_size = torch.Size(result).numel()
                return result
            return (self.max_size, *data_shape)

        if is_tensor_collection(data):
            out = data.to(self.device)
            out: TensorDictBase = torch.empty_like(
                out.expand(max_size_along_dim0(data.shape))
            )
            if self.consolidated:
                out = out.consolidate()
        else:
            # for a pytree of tensors, create empty tensors of the desired
            # shape, device and dtype
            out = tree_map(
                lambda data: torch.empty(
                    max_size_along_dim0(data.shape),
                    device=self.device,
                    dtype=data.dtype,
                ),
                data,
            )
            if self.consolidated:
                raise ValueError("Cannot consolidate non-tensordict storages.")

        self._storage = out
        self.initialized = True
        if hasattr(self._storage, "shape"):
            torchrl_logger.info(
                f"Initialized LazyTensorStorage with {self._storage.shape} shape"
            )
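
    # A worked example of the capacity computation above: ``-(a // -b)`` is
    # ceiling division. With ``max_size=100``, ``ndim=2`` and incoming
    # per-write elements of shape ``[4]``, the leading dimension becomes
    # ``ceil(100 / 4) = 25``, the allocated shape is ``(25, 4)`` and
    # ``max_size`` is updated to ``25 * 4 = 100``.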
    def _init_coordinator(
        self,
        data: TensorDictBase | torch.Tensor | PyTree,  # noqa: F821
    ) -> None:
        """Initialize storage as the coordinating process using a temporary memmap."""
        # Use LazyMemmapStorage which does everything we want
        temp_memmap_storage = LazyMemmapStorage(
            max_size=self.max_size,
            scratch_dir=self._init_directory,
            ndim=self.ndim,
            existsok=False,
            shared_init=False,  # Don't recurse
        )
        temp_memmap_storage._init_standard(data)
        self._storage = temp_memmap_storage._storage
        return

    def _wait_for_init(self) -> None:
        # wait till the coordinator has initialized
        self._init_event.wait()
        storage = TensorDict.load_memmap(self._init_directory)
        self._storage = storage
        self.initialized = True
        return

    # Read blocks
    def get(self, indices: slice) -> TensorDictBase | torch.Tensor | Any:
        if not self.initialized and self.shared_init:
            # Wait for the coordinator to finish initialization
            self._wait_for_init()
        idx = super().get(indices)
        return idx


class LazyMemmapStorage(LazyTensorStorage):
    """A memory-mapped storage for tensors and tensordicts.

    Args:
        max_size (int): size of the storage, i.e. maximum number of elements stored
            in the buffer.

    Keyword Args:
        scratch_dir (str or path): directory where memmap-tensors will be written.
            If ``shared_init=True`` and no ``scratch_dir`` is provided, a shared
            temporary directory will be created automatically.
        device (torch.device, optional): device where the sampled tensors will be
            stored and sent. Default is :obj:`torch.device("cpu")`.
            If ``None`` is provided, the device is automatically gathered from the
            first batch of data passed. This is not enabled by default to avoid
            data being placed on GPU by mistake, which could cause OOM issues.
        ndim (int, optional): the number of dimensions to be accounted for when
            measuring the storage size. For instance, a storage of shape ``[3, 4]``
            has capacity ``3`` if ``ndim=1`` and ``12`` if ``ndim=2``.
            Defaults to ``1``.
        existsok (bool, optional): whether tensors that already exist on disk may
            be reused. If ``False`` (default), an error is raised when any of the
            tensors already exists on disk; if ``True``, the tensor is opened
            as-is and not overwritten.
        shared_init (bool, optional): if ``True``, enables multiprocess coordination
            during storage initialization. The first process initializes the memmap,
            the others wait and load from the shared directory. Defaults to ``False``.
        auto_cleanup (bool, optional): if ``True``, automatically registers this
            storage for cleanup when the process exits (normally or via Ctrl+C/SIGTERM).
            This removes the memmap files from disk when no longer needed.
            Defaults to ``True`` when ``scratch_dir`` is ``None`` (using a temp directory),
            and ``False`` when a custom ``scratch_dir`` is provided (preserving user data).

    .. note:: When checkpointing a ``LazyMemmapStorage``, one can provide a path identical to where the storage is
        already stored to avoid executing long copies of data that is already stored on disk.
        This will only work if the default :class:`~torchrl.data.TensorStorageCheckpointer` checkpointer is used.

    Example::

        >>> from tensordict import TensorDict
        >>> from torchrl.data import TensorStorage, LazyMemmapStorage, ReplayBuffer
        >>> import tempfile
        >>> from pathlib import Path
        >>> import time
        >>> td = TensorDict(a=0, b=1).expand(1000).clone()
        >>> # We pass a path that is <main_ckpt_dir>/storage to LazyMemmapStorage
        >>> rb_memmap = ReplayBuffer(storage=LazyMemmapStorage(10_000_000, scratch_dir="dump/storage"))
        >>> rb_memmap.extend(td);
        >>> # Checkpointing in `dump` is a zero-copy, as the data is already in `dump/storage`
        >>> rb_memmap.dumps(Path("./dump"))


    Examples:
        >>> data = TensorDict({
        ...     "some data": torch.randn(10, 11),
        ...     ("some", "nested", "data"): torch.randn(10, 11, 12),
        ... }, batch_size=[10, 11])
        >>> storage = LazyMemmapStorage(100)
        >>> storage.set(range(10), data)
        >>> len(storage)  # only the first dimension is considered as indexable
        10
        >>> storage.get(0)
        TensorDict(
            fields={
                some data: MemoryMappedTensor(shape=torch.Size([11]), device=cpu, dtype=torch.float32, is_shared=False),
                some: TensorDict(
                    fields={
                        nested: TensorDict(
                            fields={
                                data: MemoryMappedTensor(shape=torch.Size([11, 12]), device=cpu, dtype=torch.float32, is_shared=False)},
                            batch_size=torch.Size([11]),
                            device=cpu,
                            is_shared=False)},
                    batch_size=torch.Size([11]),
                    device=cpu,
                    is_shared=False)},
            batch_size=torch.Size([11]),
            device=cpu,
            is_shared=False)

    This class also supports tensorclass data.

    Examples:
        >>> from tensordict import tensorclass
        >>> @tensorclass
        ... class MyClass:
        ...     foo: torch.Tensor
        ...     bar: torch.Tensor
        >>> data = MyClass(foo=torch.randn(10, 11), bar=torch.randn(10, 11, 12), batch_size=[10, 11])
        >>> storage = LazyMemmapStorage(10)
        >>> storage.set(range(10), data)
        >>> storage.get(0)
        MyClass(
            bar=MemoryMappedTensor(shape=torch.Size([11, 12]), device=cpu, dtype=torch.float32, is_shared=False),
            foo=MemoryMappedTensor(shape=torch.Size([11]), device=cpu, dtype=torch.float32, is_shared=False),
            batch_size=torch.Size([11]),
            device=cpu,
            is_shared=False)

    """

    _default_checkpointer = TensorStorageCheckpointer

    def __init__(
        self,
        max_size: int,
        *,
        scratch_dir=None,
        device: torch.device | str = "cpu",
        ndim: int = 1,
        existsok: bool = False,
        compilable: bool = False,
        shared_init: bool = False,
        auto_cleanup: bool | None = None,
    ):
        self.initialized = False
        self.scratch_dir = None
        self._scratch_dir_is_temp = scratch_dir is None
        self.existsok = existsok
        if scratch_dir is not None:
            self.scratch_dir = str(scratch_dir)
            if self.scratch_dir[-1] != "/":
                self.scratch_dir += "/"
        super().__init__(
            max_size,
            ndim=ndim,
            compilable=compilable,
            shared_init=shared_init,
            cleanup_memmap=False,
        )
        self.device = (
            _make_ordinal_device(torch.device(device))
            if device != "auto"
            else torch.device("cpu")
        )
        if self.device.type != "cpu":
            raise ValueError(
                "Memory map device other than CPU isn't supported. To cast your data to the desired device, "
                "use `buffer.append_transform(lambda x: x.to(device))` or a similar transform."
            )
        self._len = 0

        # Auto cleanup: default to True for temp dirs, False for user-specified dirs
        if auto_cleanup is None:
            auto_cleanup = self._scratch_dir_is_temp
        self._auto_cleanup = auto_cleanup
        self._cleaned_up = False

        if self._auto_cleanup:
            _ensure_cleanup_handlers()
            _MEMMAP_STORAGE_REGISTRY.add(self)
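
    # A hedged sketch of the cleanup policy configured above: with no
    # ``scratch_dir`` the memmap lives in a temp directory and is removed at
    # process exit; with a user-provided directory the files persist unless
    # cleanup is explicitly requested.
    #
    #     >>> tmp_storage = LazyMemmapStorage(1_000)  # auto_cleanup=True
    #     >>> kept_storage = LazyMemmapStorage(1_000, scratch_dir="./buf")  # auto_cleanup=False
    #     >>> kept_storage.cleanup()  # explicit removal when the data is disposable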
    def state_dict(self) -> dict[str, Any]:
        _storage = self._storage
        if isinstance(_storage, torch.Tensor):
            _storage = _mem_map_tensor_as_tensor(_storage)
        elif isinstance(_storage, TensorDictBase):
            _storage = _storage.apply(_mem_map_tensor_as_tensor).state_dict()
        elif _storage is None:
            _storage = {}
        else:
            raise TypeError(
                f"Objects of type {type(_storage)} are not supported by LazyTensorStorage.state_dict. If you are trying to serialize a PyTree, the storage.dumps/loads is preferred."
            )
        return {
            "_storage": _storage,
            "initialized": self.initialized,
            "_len": self._len,
        }

    def load_state_dict(self, state_dict):
        _storage = copy(state_dict["_storage"])
        if isinstance(_storage, torch.Tensor):
            if isinstance(self._storage, torch.Tensor):
                _mem_map_tensor_as_tensor(self._storage).copy_(_storage)
            elif self._storage is None:
                self._storage = _make_memmap(
                    _storage,
                    path=self.scratch_dir + "/tensor.memmap"
                    if self.scratch_dir is not None
                    else None,
                )
            else:
                raise RuntimeError(
                    f"Cannot copy a storage of type {type(_storage)} onto another of type {type(self._storage)}"
                )
        elif isinstance(_storage, (dict, OrderedDict)):
            if is_tensor_collection(self._storage):
                self._storage.load_state_dict(_storage, strict=False)
                self._storage.memmap_()
            elif self._storage is None:
                warnings.warn(
                    "Loading the storage on an uninitialized TensorDict. "
                    "It is preferable to load a storage onto a "
                    "pre-allocated one whenever possible."
                )
                self._storage = TensorDict().load_state_dict(_storage, strict=False)
                self._storage.memmap_()
            else:
                raise RuntimeError(
                    f"Cannot copy a storage of type {type(_storage)} onto another of type {type(self._storage)}"
                )
        else:
            raise TypeError(
                f"Objects of type {type(_storage)} are not supported by {type(self).__name__}.load_state_dict"
            )
        self.initialized = state_dict["initialized"]
        self._len = state_dict["_len"]

    def _init(
        self,
        data: TensorDictBase | torch.Tensor | PyTree,  # noqa: F821
    ) -> None:
        if not self.shared_init:
            return self._init_standard(data)
        is_coordinator = not self._init_event.is_set()
        is_coordinator = is_coordinator and self._init_lock.acquire(block=False)

        if is_coordinator:
            # coordinator init
            try:
                return self._init_coordinator(data)
            finally:
                self._init_event.set()
                self._init_lock.release()
        else:
            # Non-coordinator: wait for the coordinator to finish
            self._wait_for_init()
            self.initialized = True

    def _init_coordinator(self, data: TensorDictBase | torch.Tensor | Any) -> None:
        return self._init_standard(data)

    def _init_standard(self, data: TensorDictBase | torch.Tensor) -> None:
        torchrl_logger.debug("Creating a MemmapStorage...")
        if self.device == "auto":
            self.device = data.device
        if self.device.type != "cpu":
            raise RuntimeError("Support for Memmap device other than CPU is deprecated")

        def max_size_along_dim0(data_shape):
            if self.ndim > 1:
                result = (
                    -(self.max_size // -data_shape[: self.ndim - 1].numel()),
                    *data_shape,
                )
                self.max_size = torch.Size(result).numel()
                return result
            return (self.max_size, *data_shape)

        if is_tensor_collection(data):
            out = data.clone().to(self.device)
            out = out.expand(max_size_along_dim0(data.shape))
            out = out.memmap_like(prefix=self.scratch_dir, existsok=self.existsok)
            if torchrl_logger.isEnabledFor(logging.DEBUG):
                for key, tensor in sorted(
                    out.items(
                        include_nested=True,
                        leaves_only=True,
                        is_leaf=_NESTED_TENSORS_AS_LISTS,
                    ),
                    key=str,
                ):
                    try:
                        filesize = os.path.getsize(tensor.filename) / 1024 / 1024
                        torchrl_logger.debug(
                            f"\t{key}: {tensor.filename}, {filesize} Mb of storage (size: {tensor.shape})."
                        )
                    except (AttributeError, RuntimeError):
                        pass
        else:
            out = _init_pytree(self.scratch_dir, max_size_along_dim0, data)
        self._storage = out
        if hasattr(self._storage, "shape"):
            torchrl_logger.info(
                f"Initialized LazyMemmapStorage with {self._storage.shape} shape"
            )
        self.initialized = True

    def get(self, index: int | Sequence[int] | slice) -> Any:
        if not self.initialized and self.shared_init:
            # Wait for the coordinator to finish initialization
            self._wait_for_init()
        result = super().get(index)
        return result

    def cleanup(self) -> bool:
        """Clean up memmap files from disk.

        This method removes the memmap directory and all its contents from disk.
        It is automatically called on process exit if ``auto_cleanup=True``.

        Returns:
            bool: ``True`` if cleanup was performed, ``False`` if already cleaned up
                or no cleanup was needed.

        Note:
            After cleanup, the storage is no longer usable. Any attempt to access
            the storage will result in undefined behavior.

        Example:
            >>> storage = LazyMemmapStorage(1000, auto_cleanup=True)
            >>> # ... use storage ...
            >>> storage.cleanup()  # Manually clean up when done
        """
        if getattr(self, "_cleaned_up", False):
            return False

        self._cleaned_up = True

        # Get the directory to clean up
        scratch_dir = getattr(self, "scratch_dir", None)
        if scratch_dir is None:
            # No scratch dir - check if the storage has memmap tensors with temp paths
            storage = getattr(self, "_storage", None)
            if storage is not None and is_tensor_collection(storage):
                # Get all memmap file paths and find their common directory
                paths = set()
                try:
                    for tensor in storage.values(include_nested=True, leaves_only=True):
                        if hasattr(tensor, "filename") and tensor.filename:
                            paths.add(os.path.dirname(tensor.filename))
                except Exception:
                    # Storage might be in an invalid state during cleanup
                    pass
                for path in paths:
                    if (
                        path
                        and os.path.isdir(path)
                        and path.startswith(tempfile.gettempdir())
                    ):
                        try:
                            shutil.rmtree(path)
                            torchrl_logger.debug(f"Cleaned up memmap directory: {path}")
                        except Exception:
                            # Ignore errors - the file might be in use or already deleted
                            pass
                return bool(paths)
            return False

        # Clean up the scratch directory
        scratch_dir = scratch_dir.rstrip("/")
        if os.path.isdir(scratch_dir):
            try:
                shutil.rmtree(scratch_dir)
                torchrl_logger.debug(f"Cleaned up memmap directory: {scratch_dir}")
                return True
            except Exception as e:
                torchrl_logger.warning(f"Failed to clean up memmap directory: {e}")
                return False
        return False

    def __del__(self):
        """Ensure cleanup on garbage collection if auto_cleanup is enabled."""
        if getattr(self, "_auto_cleanup", False) and not getattr(
            self, "_cleaned_up", True
        ):
            self.cleanup()


class CompressedListStorage(ListStorage):
    """A storage that compresses and decompresses data.

    This storage compresses data when storing and decompresses it when retrieving.
    It is particularly useful for storing raw sensory observations like images
    that can be compressed significantly to save memory.

    Args:
        max_size (int): size of the storage, i.e. maximum number of elements stored
            in the buffer.
        compression_fn (callable, optional): function to compress data. Should take
            a tensor and return a compressed byte tensor. Defaults to zstd compression
            (zlib on Python < 3.14).
        decompression_fn (callable, optional): function to decompress data. Should take
            a compressed byte tensor and return the original tensor. Defaults to zstd
            decompression (zlib on Python < 3.14).
        compression_level (int, optional): compression level (1-22 for zstd) when using the default compression function.
            Defaults to 3.
        device (torch.device, optional): device where the sampled tensors will be
            stored and sent. Default is :obj:`torch.device("cpu")`.
        compilable (bool, optional): whether the storage is compilable.
            If ``True``, the writer cannot be shared between multiple processes.
            Defaults to ``False``.

    Examples:
        >>> import torch
        >>> from torchrl.data import CompressedListStorage, ReplayBuffer
        >>> from tensordict import TensorDict
        >>>
        >>> # Create a compressed storage for image data
        >>> storage = CompressedListStorage(max_size=1000, compression_level=3)
        >>> rb = ReplayBuffer(storage=storage, batch_size=5)
        >>>
        >>> # Add some image data
        >>> images = torch.randn(10, 3, 84, 84)  # Atari-like frames
        >>> data = TensorDict({"obs": images}, batch_size=[10])
        >>> rb.extend(data)
        >>>
        >>> # Sample and verify data is decompressed correctly
        >>> sample = rb.sample(3)
        >>> print(sample["obs"].shape)  # torch.Size([3, 3, 84, 84])

    """

    _default_checkpointer = CompressedListStorageCheckpointer

    def __init__(
        self,
        max_size: int,
        *,
        compression_fn: Callable | None = None,
        decompression_fn: Callable | None = None,
        compression_level: int = 3,
        device: torch.device = "cpu",
        compilable: bool = False,
    ):
        super().__init__(max_size, compilable=compilable, device=device)
        self.compression_level = compression_level

        # Set up compression functions
        if compression_fn is None:
            self.compression_fn = self._default_compression_fn
        else:
            self.compression_fn = compression_fn

        if decompression_fn is None:
            self.decompression_fn = self._default_decompression_fn
        else:
            self.decompression_fn = decompression_fn

        # Store compressed data and metadata
        self._storage = []
        self._metadata = []  # Store shape, dtype, device info for each item
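
    # A hedged sketch of plugging in custom (de)compression callables,
    # following the contract documented above (tensor -> uint8 byte tensor
    # and back); zlib is used explicitly here for illustration.
    #
    #     >>> import zlib
    #     >>> def my_compress(t):
    #     ...     raw = t.cpu().numpy().tobytes()
    #     ...     return torch.frombuffer(bytearray(zlib.compress(raw, 6)), dtype=torch.uint8)
    #     >>> def my_decompress(c, metadata):
    #     ...     raw = zlib.decompress(c.numpy().tobytes())
    #     ...     t = torch.frombuffer(bytearray(raw), dtype=metadata["dtype"])
    #     ...     return t.reshape(metadata["shape"])
    #     >>> storage = CompressedListStorage(100, compression_fn=my_compress, decompression_fn=my_decompress)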
    def _default_compression_fn(self, tensor: torch.Tensor) -> torch.Tensor:
        """Default compression using zstd (zlib on Python < 3.14)."""
        if sys.version_info >= (3, 14):
            from compression import zstd

            compressor_fn = zstd.compress

        else:
            import zlib

            compressor_fn = zlib.compress

        # Convert the tensor to bytes
        tensor_bytes = self.to_bytestream(tensor)

        # Compress the byte stream
        compressed_bytes = compressor_fn(tensor_bytes, level=self.compression_level)

        # Convert back to a tensor
        return torch.frombuffer(bytearray(compressed_bytes), dtype=torch.uint8)

    def _default_decompression_fn(
        self, compressed_tensor: torch.Tensor, metadata: dict
    ) -> torch.Tensor:
        """Default decompression using zstd (zlib on Python < 3.14)."""
        if sys.version_info >= (3, 14):
            from compression import zstd

            decompressor_fn = zstd.decompress

        else:
            import zlib

            decompressor_fn = zlib.decompress

        # Convert the tensor to bytes
        compressed_bytes = self.to_bytestream(compressed_tensor.cpu())

        # Decompress the byte stream
        decompressed_bytes = decompressor_fn(compressed_bytes)

        # Convert back to a tensor
        tensor = torch.frombuffer(
            bytearray(decompressed_bytes), dtype=metadata["dtype"]
        )
        tensor = tensor.reshape(metadata["shape"])
        tensor = tensor.to(metadata["device"])

        return tensor
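
    # A minimal round-trip sketch of the two defaults above: a tensor is
    # compressed to a uint8 payload and restored from its metadata.
    #
    #     >>> storage = CompressedListStorage(10)
    #     >>> x = torch.arange(6, dtype=torch.float32).reshape(2, 3)
    #     >>> comp = storage._default_compression_fn(x)
    #     >>> meta = {"shape": x.shape, "dtype": x.dtype, "device": x.device}
    #     >>> y = storage._default_decompression_fn(comp, meta)
    #     >>> assert torch.equal(x, y)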
    def _compress_item(self, item: Any) -> tuple[torch.Tensor, dict]:
        """Compress a single item and return the compressed data with metadata."""
        if isinstance(item, torch.Tensor):
            metadata = {
                "type": "tensor",
                "shape": item.shape,
                "dtype": item.dtype,
                "device": item.device,
            }
            compressed = self.compression_fn(item)
        elif is_tensor_collection(item):
            # For TensorDict, compress each tensor field
            compressed_fields = {}
            metadata = {"type": "tensordict", "fields": {}}

            for key, value in item.items():
                if isinstance(value, torch.Tensor):
                    compressed_fields[key] = self.compression_fn(value)
                    metadata["fields"][key] = {
                        "type": "tensor",
                        "shape": value.shape,
                        "dtype": value.dtype,
                        "device": value.device,
                    }
                else:
                    # For non-tensor data, store as-is
                    compressed_fields[key] = value
                    metadata["fields"][key] = {"type": "non_tensor", "value": value}

            compressed = compressed_fields
        else:
            # For other types, store as-is
            compressed = item
            metadata = {"type": "other", "value": item}

        return compressed, metadata

    def _decompress_item(self, compressed_data: Any, metadata: dict) -> Any:
        """Decompress a single item using its metadata."""
        if metadata["type"] == "tensor":
            return self.decompression_fn(compressed_data, metadata)
        elif metadata["type"] == "tensordict":
            # Reconstruct the TensorDict
            result = TensorDict({}, batch_size=metadata.get("batch_size", []))

            for key, field_metadata in metadata["fields"].items():
                if field_metadata["type"] == "non_tensor":
                    result[key] = field_metadata["value"]
                else:
                    # Decompress the tensor field
                    result[key] = self.decompression_fn(
                        compressed_data[key], field_metadata
                    )

            return result
        else:
            # Return as-is for other types
            return metadata["value"]

    def _set_item(self, cursor: int, data: Any) -> None:
        """Set a single item in the compressed storage."""
        # Ensure we have enough space
        while len(self._storage) <= cursor:
            self._storage.append(None)
            self._metadata.append(None)

        # Compress and store
        compressed_data, metadata = self._compress_item(data)
        self._storage[cursor] = compressed_data
        self._metadata[cursor] = metadata

    def _set_slice(self, cursor: slice, data: Any) -> None:
        """Set a slice in the compressed storage."""
        # Handle slice assignment
        if not hasattr(data, "__iter__"):
            data = [data]
        start, stop, step = cursor.indices(len(self._storage))
        indices = list(range(start, stop, step))

        for i, value in zip(indices, data):
            self._set_item(i, value)

    def _get_item(self, index: int) -> Any:
        """Get a single item from the compressed storage."""
        if index >= len(self._storage) or self._storage[index] is None:
            raise IndexError(f"Index {index} out of bounds or not set")

        compressed_data = self._storage[index]
        metadata = self._metadata[index]
        return self._decompress_item(compressed_data, metadata)

    def _get_slice(self, index: slice) -> list:
        """Get a slice from the compressed storage."""
        start, stop, step = index.indices(len(self._storage))
        results = []
        for i in range(start, stop, step):
            if i < len(self._storage) and self._storage[i] is not None:
                results.append(self._get_item(i))
        return results

    def _get_list(self, index: list) -> list:
        """Get a list of items from the compressed storage."""
        if isinstance(index, torch.Tensor) and index.device.type != "cpu":
            index = index.cpu().tolist()

        results = []
        for i in index:
            if i >= len(self._storage) or self._storage[i] is None:
                raise IndexError(f"Index {i} out of bounds or not set")
            results.append(self._get_item(i))
        return results

    def __len__(self) -> int:
        """Get the number of items currently held in the compressed storage."""
        return len([item for item in self._storage if item is not None])

    def _contains_int(self, item: int) -> bool:
        """Check whether an integer index is contained in the compressed storage."""
        return 0 <= item < len(self._storage) and self._storage[item] is not None

    def _empty(self):
        """Empty the storage."""
        self._storage = []
        self._metadata = []

    def state_dict(self) -> dict[str, Any]:
        """Save the storage state."""
        return {
            "_storage": self._storage,
            "_metadata": self._metadata,
        }

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        """Load the storage state."""
        self._storage = state_dict["_storage"]
        self._metadata = state_dict["_metadata"]

    def to_bytestream(self, data_to_bytestream: torch.Tensor | np.ndarray | Any) -> bytes:
        """Convert data to a byte stream."""
        if isinstance(data_to_bytestream, torch.Tensor):
            byte_stream = data_to_bytestream.cpu().numpy().tobytes()

        elif isinstance(data_to_bytestream, np.ndarray):
            byte_stream = bytes(data_to_bytestream.tobytes())

        else:
            import io
            import pickle

            buffer = io.BytesIO()
            pickle.dump(data_to_bytestream, buffer)
            buffer.seek(0)
            byte_stream = bytes(buffer.read())

        return byte_stream

    def bytes(self):
        """Return the number of bytes in the storage."""

        def compressed_size_from_list(data: Any) -> int:
            if data is None:
                return 0
            elif isinstance(data, bytes):
                return len(data)
            elif isinstance(data, np.ndarray):
                return data.nbytes
            elif isinstance(data, torch.Tensor):
                return compressed_size_from_list(data.cpu().numpy())
            elif isinstance(data, (tuple, list, Sequence)):
                return sum(compressed_size_from_list(item) for item in data)
            elif isinstance(data, Mapping) or is_tensor_collection(data):
                return sum(compressed_size_from_list(value) for value in data.values())
            else:
                return 0

        compressed_size_estimate = compressed_size_from_list(self._storage)
        if compressed_size_estimate == 0:
            if len(self._storage) > 0:
                raise RuntimeError(
                    "Compressed storage is not empty but the compressed size is 0. This is a bug."
                )
            warnings.warn("Compressed storage is empty, returning 0 bytes.")

        return compressed_size_estimate
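
    # A quick sketch of the memory accounting above: after extending a buffer
    # backed by this storage, ``bytes()`` sums the sizes of the compressed
    # payloads (bytes, arrays or uint8 tensors) actually held in the list.
    #
    #     >>> storage = CompressedListStorage(100)
    #     >>> rb = ReplayBuffer(storage=storage, batch_size=4)
    #     >>> rb.extend(TensorDict(obs=torch.zeros(8, 3, 84, 84), batch_size=[8]))
    #     >>> storage.bytes()  # far below the 8 * 3 * 84 * 84 * 4 raw bytes here

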
2179
|
+
class StorageEnsemble(Storage):
|
|
2180
|
+
"""An ensemble of storages.
|
|
2181
|
+
|
|
2182
|
+
This class is designed to work with :class:`~torchrl.data.replay_buffers.replay_buffers.ReplayBufferEnsemble`.
|
|
2183
|
+
|
|
2184
|
+
Args:
|
|
2185
|
+
storages (sequence of Storage): the storages to make the composite storage.
|
|
2186
|
+
|
|
2187
|
+
Keyword Args:
|
|
2188
|
+
transforms (list of :class:`~torchrl.envs.Transform`, optional): a list of
|
|
2189
|
+
transforms of the same length as storages.
|
|
2190
|
+
|
|
2191
|
+
.. warning::
|
|
2192
|
+
This class signatures for :meth:`get` does not match other storages, as
|
|
2193
|
+
it will return a tuple ``(buffer_id, samples)`` rather than just the samples.
|
|
2194
|
+
|
|
2195
|
+
.. warning::
|
|
2196
|
+
This class does not support writing (similarly to :class:`~torchrl.data.replay_buffers.writers.WriterEnsemble`).
|
|
2197
|
+
To extend one of the replay buffers, simply index the parent
|
|
2198
|
+
:class:`~torchrl.data.ReplayBufferEnsemble` object.
|
|
2199
|
+
|
|
2200
|
+
"""
|
|
2201
|
+
|
|
2202 | +     _default_checkpointer = StorageEnsembleCheckpointer
2203 | +
2204 | +     def __init__(
2205 | +         self,
2206 | +         *storages: Storage,
2207 | +         transforms: list[Transform] = None,  # noqa: F821
2208 | +     ):
2209 | +         self._rng_private = None
2210 | +         self._storages = storages
2211 | +         self._transforms = transforms
2212 | +         if transforms is not None and len(transforms) != len(storages):
2213 | +             raise TypeError(
2214 | +                 "transforms must have the same length as the storages provided."
2215 | +             )
2216 | +
2217 | +     @property
2218 | +     def _rng(self):
2219 | +         return self._rng_private
2220 | +
2221 | +     @_rng.setter
2222 | +     def _rng(self, value):
2223 | +         self._rng_private = value
2224 | +         for storage in self._storages:
2225 | +             storage._rng = value
2226 | +
2227 | +     def extend(self, value):
2228 | +         raise RuntimeError
2229 | +
2230 | +     def add(self, value):
2231 | +         raise RuntimeError
2232 | +
2233 | +     def get(self, item):
2234 | +         # we return the buffer id too to be able to track the appropriate collate_fn
2235 | +         buffer_ids = item.get("buffer_ids")
2236 | +         index = item.get("index")
2237 | +         results = []
2238 | +         for buffer_id, sample in zip(buffer_ids, index):
2239 | +             buffer_id = self._convert_id(buffer_id)
2240 | +             results.append((buffer_id, self._get_storage(buffer_id).get(sample)))
2241 | +         if self._transforms is not None:
2242 | +             results = [
2243 | +                 (buffer_id, self._transforms[buffer_id](result))
2244 | +                 if self._transforms[buffer_id] is not None
2245 | +                 else (buffer_id, result)
2246 | +                 for buffer_id, result in results
2247 | +             ]
2248 | +         return results
2249 | +
2250 | +     def _convert_id(self, sub):
2251 | +         if isinstance(sub, torch.Tensor):
2252 | +             sub = sub.item()
2253 | +         return sub
2254 | +
2255 | +     def _get_storage(self, sub):
2256 | +         return self._storages[sub]
2257 | +
2258 | +     def state_dict(self) -> dict[str, Any]:
2259 | +         raise NotImplementedError
2260 | +
2261 | +     def load_state_dict(self, state_dict: dict[str, Any]) -> None:
2262 | +         raise NotImplementedError
2263 | +
2264 | +     _INDEX_ERROR = "Expected an index of type torch.Tensor, range, np.ndarray, int, slice or ellipsis, got {} instead."
2265 | +
2266 | +     def __getitem__(self, index):
2267 | +         if isinstance(index, tuple):
2268 | +             if index[0] is Ellipsis:
2269 | +                 index = (slice(None), index[1:])
2270 | +             result = self[index[0]]
2271 | +             if len(index) > 1:
2272 | +                 if result is self:
2273 | +                     # then index[0] is an ellipsis/slice(None)
2274 | +                     sample = [storage[index[1:]] for storage in self._storages]
2275 | +                     return sample
2276 | +                 if isinstance(result, StorageEnsemble):
2277 | +                     new_index = (slice(None), *index[1:])
2278 | +                     return result[new_index]
2279 | +                 return result[index[1:]]
2280 | +             return result
2281 | +         if isinstance(index, slice) and index == slice(None):
2282 | +             return self
2283 | +         if isinstance(index, (list, range, np.ndarray)):
2284 | +             index = torch.as_tensor(index)
2285 | +         if isinstance(index, torch.Tensor):
2286 | +             if index.ndim > 1:
2287 | +                 raise RuntimeError(
2288 | +                     f"Cannot index a {type(self)} with tensor indices that have more than one dimension."
2289 | +                 )
2290 | +             if index.is_floating_point():
2291 | +                 raise TypeError(
2292 | +                     "A floating point index was received when an integer dtype was expected."
2293 | +                 )
2294 | +         if isinstance(index, int) or (not isinstance(index, slice) and len(index) == 0):
2295 | +             try:
2296 | +                 index = int(index)
2297 | +             except Exception:
2298 | +                 raise IndexError(self._INDEX_ERROR.format(type(index)))
2299 | +             try:
2300 | +                 return self._storages[index]
2301 | +             except IndexError:
2302 | +                 raise IndexError(self._INDEX_ERROR.format(type(index)))
2303 | +         if isinstance(index, torch.Tensor):
2304 | +             index = index.tolist()
2305 | +             storages = [self._storages[i] for i in index]
2306 | +             transforms = (
2307 | +                 [self._transforms[i] for i in index]
2308 | +                 if self._transforms is not None
2309 | +                 else [None] * len(index)
2310 | +             )
2311 | +         else:
2312 | +             # slice
2313 | +             storages = self._storages[index]
2314 | +             transforms = (
2315 | +                 self._transforms[index]
2316 | +                 if self._transforms is not None
2317 | +                 else [None] * len(storages)
2318 | +             )
2319 | +
2320 | +         return StorageEnsemble(*storages, transforms=transforms)
2321 | +
2322 | +     def __len__(self):
2323 | +         return len(self._storages)
2324 | +
2325 | +     def __repr__(self):
2326 | +         storages = textwrap.indent(f"storages={self._storages}", " " * 4)
2327 | +         transforms = textwrap.indent(f"transforms={self._transforms}", " " * 4)
2328 | +         return f"StorageEnsemble(\n{storages}, \n{transforms})"
2329 | +
2330 | +
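
The indexing and `get` contract above is easiest to see in use. Below is a minimal sketch (not part of the diff): it assumes `LazyTensorStorage` and `StorageEnsemble` as exported from `torchrl.data`, and the dict passed to `get` hand-builds the `buffer_ids`/`index` entries that `ReplayBufferEnsemble`'s sampler would normally supply.

```python
# Minimal sketch of the StorageEnsemble contract (illustrative, not from the
# diff): int indexing returns a sub-storage, slicing returns a new ensemble,
# and get() returns (buffer_id, samples) pairs.
import torch
from tensordict import TensorDict
from torchrl.data import LazyTensorStorage, StorageEnsemble

s0, s1 = LazyTensorStorage(10), LazyTensorStorage(20)
s0.set(torch.arange(10), TensorDict({"obs": torch.randn(10, 3)}, [10]))
s1.set(torch.arange(20), TensorDict({"obs": torch.randn(20, 3)}, [20]))

ensemble = StorageEnsemble(s0, s1)
assert len(ensemble) == 2   # number of sub-storages, not stored items
assert ensemble[0] is s0    # integer index -> the sub-storage itself
sub = ensemble[:1]          # slice -> a new StorageEnsemble over s0 only

# The "buffer_ids"/"index" entries mimic what the ensemble sampler produces;
# writing goes through the sub-storages (the ensemble itself is read-only).
pairs = ensemble.get(
    {"buffer_ids": torch.tensor([0, 1]),
     "index": [torch.tensor([0, 2]), torch.tensor([5])]}
)
# pairs == [(0, <2 samples from s0>), (1, <1 sample from s1>)]
```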
2331 | + # Utils
2332 | + def _mem_map_tensor_as_tensor(mem_map_tensor) -> torch.Tensor:
2333 | +     if isinstance(mem_map_tensor, torch.Tensor):
2334 | +         # This will account for MemoryMappedTensors
2335 | +         return mem_map_tensor
2336 | +
2337 | +
2338 | + def _collate_list_tensordict(x):
2339 | +     out = torch.stack(x, 0)
2340 | +     return out
2341 | +
2342 | +
2343 | + @implement_for("torch", "2.4")
2344 | + def _stack_anything(data):
2345 | +     if is_tensor_collection(data[0]):
2346 | +         return LazyStackedTensorDict.maybe_dense_stack(data)
2347 | +     return tree_map(
2348 | +         lambda *x: torch.stack(x),
2349 | +         *data,
2350 | +         is_leaf=lambda x: isinstance(x, torch.Tensor) or is_tensor_collection(x),
2351 | +     )
2352 | +
2353 | +
2354 | + @implement_for("torch", None, "2.4")
2355 | + def _stack_anything(data):  # noqa: F811
2356 | +     from tensordict import _pytree
2357 | +
2358 | +     if not _pytree.PYTREE_REGISTERED_TDS:
2359 | +         raise RuntimeError(
2360 | +             "TensorDict is not registered within PyTree. "
2361 | +             "If you see this error, it means tensordict instances cannot be natively stacked using tree_map. "
2362 | +             "To solve this issue, (a) upgrade pytorch to a version >= 2.4, or (b) make sure TensorDict is registered in PyTree. "
2363 | +             "If this error persists, open an issue on https://github.com/pytorch/rl/issues"
2364 | +         )
2365 | +     if is_tensor_collection(data[0]):
2366 | +         return LazyStackedTensorDict.maybe_dense_stack(data)
2367 | +     flat_trees = []
2368 | +     spec = None
2369 | +     for d in data:
2370 | +         flat_tree, spec = tree_flatten(d)
2371 | +         flat_trees.append(flat_tree)
2372 | +
2373 | +     leaves = []
2374 | +     for leaf in zip(*flat_trees):
2375 | +         leaf = torch.stack(leaf)
2376 | +         leaves.append(leaf)
2377 | +
2378 | +     return tree_unflatten(leaves, spec)
2379 | +
2380 | +
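
Both branches of `_stack_anything` implement the same behavior: a list of arbitrary pytrees is stacked leaf by leaf, preserving the tree structure. A standalone sketch of that behavior (assuming torch >= 2.4, where `tree_map` accepts several trees at once):

```python
# Standalone sketch of leaf-wise pytree stacking (assumes torch >= 2.4).
import torch
from torch.utils._pytree import tree_map

data = [
    {"obs": torch.zeros(3), "meta": (torch.tensor(0.0),)},
    {"obs": torch.ones(3), "meta": (torch.tensor(1.0),)},
]
stacked = tree_map(lambda *leaves: torch.stack(leaves), *data)
# The dict/tuple structure is preserved; every leaf gains a leading
# dimension of len(data): stacked["obs"].shape == (2, 3).
```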
2381 | + def _collate_id(x):
2382 | +     return x
2383 | +
2384 | +
2385 | + def _get_default_collate(storage, _is_tensordict=False):
2386 | +     if isinstance(storage, (LazyStackStorage, TensorStorage)):
2387 | +         return _collate_id
2388 | +     elif isinstance(storage, CompressedListStorage):
2389 | +         return lazy_stack
2390 | +     elif isinstance(storage, (ListStorage, StorageEnsemble)):
2391 | +         return _stack_anything
2392 | +     else:
2393 | +         raise NotImplementedError(
2394 | +             f"Could not find a default collate_fn for storage {type(storage)}."
2395 | +         )
2396 | +
2397 | +
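
The dispatch above encodes a simple rule: contiguous tensor storages already return batched data from `get`, so the identity collate suffices, whereas list-like storages return one item per sampled index and need stacking. A hypothetical illustration:

```python
# Hypothetical illustration of the default-collate rule: ListStorage samples
# come back as a list and get stacked; TensorStorage slices are already batched.
import torch
from torchrl.data import LazyTensorStorage, ListStorage, ReplayBuffer

rb_list = ReplayBuffer(storage=ListStorage(10))          # stacking collate
rb_tensor = ReplayBuffer(storage=LazyTensorStorage(10))  # identity collate

rb_list.extend([torch.zeros(2), torch.ones(2)])
sample = rb_list.sample(2)  # two stored items stacked into a (2, 2) tensor
```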
2398 | + def _make_memmap(tensor, path):
2399 | +     return MemoryMappedTensor.from_tensor(tensor, filename=path)
2400 | +
2401 | +
2402 | + def _make_empty_memmap(shape, dtype, path):
2403 | +     return MemoryMappedTensor.empty(shape=shape, dtype=dtype, filename=path)
2404 | +
2405 | +
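
`_make_memmap` and `_make_empty_memmap` are thin wrappers around tensordict's `MemoryMappedTensor` constructors. A quick sketch of the round-trip they rely on (the temporary file path is illustrative):

```python
# Sketch of the MemoryMappedTensor calls the two helpers wrap.
import os
import tempfile
import torch
from tensordict import MemoryMappedTensor

path = os.path.join(tempfile.mkdtemp(), "buf.memmap")
mm = MemoryMappedTensor.from_tensor(torch.arange(6.0), filename=path)
mm[0] = 42.0  # writes land directly in the backing file
empty = MemoryMappedTensor.empty(
    shape=(10, 3), dtype=torch.float32, filename=path + ".empty"
)
```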
2406 | + def _flip_list(data):
2407 | +     if all(is_tensor_collection(_data) for _data in data):
2408 | +         return torch.stack(data)
2409 | +     flat_data, flat_specs = zip(*[tree_flatten(item) for item in data])
2410 | +     flat_data = zip(*flat_data)
2411 | +     stacks = [torch.stack(item) for item in flat_data]
2412 | +     return tree_unflatten(stacks, flat_specs[0])
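
`_flip_list` performs the converse bookkeeping of a collate: a list of pytrees ("list of structs") becomes a single pytree whose leaves are stacked along a new leading dimension ("struct of stacks"). Illustrative only, since the function is private to this module:

```python
# Illustrative only: _flip_list is private API; imported here just to show
# the list-of-pytrees -> pytree-of-stacked-leaves transformation.
import torch
from torchrl.data.replay_buffers.storages import _flip_list

out = _flip_list([{"a": torch.zeros(2)}, {"a": torch.ones(2)}])
assert out["a"].shape == (2, 2)
```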