torchrl 0.11.0__cp314-cp314-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/benchmark_batched_envs.py +104 -0
- benchmarks/conftest.py +91 -0
- benchmarks/ecosystem/gym_env_throughput.py +321 -0
- benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
- benchmarks/requirements.txt +7 -0
- benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
- benchmarks/test_collectors_benchmark.py +240 -0
- benchmarks/test_compressed_storage_benchmark.py +145 -0
- benchmarks/test_envs_benchmark.py +133 -0
- benchmarks/test_llm.py +101 -0
- benchmarks/test_non_tensor_env_benchmark.py +70 -0
- benchmarks/test_objectives_benchmarks.py +1199 -0
- benchmarks/test_replaybuffer_benchmark.py +254 -0
- sota-check/README.md +35 -0
- sota-implementations/README.md +142 -0
- sota-implementations/a2c/README.md +39 -0
- sota-implementations/a2c/a2c_atari.py +291 -0
- sota-implementations/a2c/a2c_mujoco.py +273 -0
- sota-implementations/a2c/utils_atari.py +240 -0
- sota-implementations/a2c/utils_mujoco.py +160 -0
- sota-implementations/bandits/README.md +7 -0
- sota-implementations/bandits/dqn.py +126 -0
- sota-implementations/cql/cql_offline.py +198 -0
- sota-implementations/cql/cql_online.py +249 -0
- sota-implementations/cql/discrete_cql_offline.py +180 -0
- sota-implementations/cql/discrete_cql_online.py +227 -0
- sota-implementations/cql/utils.py +471 -0
- sota-implementations/crossq/crossq.py +271 -0
- sota-implementations/crossq/utils.py +320 -0
- sota-implementations/ddpg/ddpg.py +231 -0
- sota-implementations/ddpg/utils.py +325 -0
- sota-implementations/decision_transformer/dt.py +163 -0
- sota-implementations/decision_transformer/lamb.py +167 -0
- sota-implementations/decision_transformer/online_dt.py +178 -0
- sota-implementations/decision_transformer/utils.py +562 -0
- sota-implementations/discrete_sac/discrete_sac.py +243 -0
- sota-implementations/discrete_sac/utils.py +324 -0
- sota-implementations/dqn/README.md +30 -0
- sota-implementations/dqn/dqn_atari.py +272 -0
- sota-implementations/dqn/dqn_cartpole.py +236 -0
- sota-implementations/dqn/utils_atari.py +132 -0
- sota-implementations/dqn/utils_cartpole.py +90 -0
- sota-implementations/dreamer/README.md +129 -0
- sota-implementations/dreamer/dreamer.py +586 -0
- sota-implementations/dreamer/dreamer_utils.py +1107 -0
- sota-implementations/expert-iteration/README.md +352 -0
- sota-implementations/expert-iteration/ei_utils.py +770 -0
- sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
- sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
- sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
- sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
- sota-implementations/gail/gail.py +327 -0
- sota-implementations/gail/gail_utils.py +68 -0
- sota-implementations/gail/ppo_utils.py +157 -0
- sota-implementations/grpo/README.md +273 -0
- sota-implementations/grpo/grpo-async.py +437 -0
- sota-implementations/grpo/grpo-sync.py +435 -0
- sota-implementations/grpo/grpo_utils.py +843 -0
- sota-implementations/grpo/requirements_gsm8k.txt +11 -0
- sota-implementations/grpo/requirements_ifeval.txt +16 -0
- sota-implementations/impala/README.md +33 -0
- sota-implementations/impala/impala_multi_node_ray.py +292 -0
- sota-implementations/impala/impala_multi_node_submitit.py +284 -0
- sota-implementations/impala/impala_single_node.py +261 -0
- sota-implementations/impala/utils.py +184 -0
- sota-implementations/iql/discrete_iql.py +230 -0
- sota-implementations/iql/iql_offline.py +164 -0
- sota-implementations/iql/iql_online.py +225 -0
- sota-implementations/iql/utils.py +437 -0
- sota-implementations/multiagent/README.md +74 -0
- sota-implementations/multiagent/iql.py +237 -0
- sota-implementations/multiagent/maddpg_iddpg.py +266 -0
- sota-implementations/multiagent/mappo_ippo.py +267 -0
- sota-implementations/multiagent/qmix_vdn.py +271 -0
- sota-implementations/multiagent/sac.py +337 -0
- sota-implementations/multiagent/utils/__init__.py +4 -0
- sota-implementations/multiagent/utils/logging.py +151 -0
- sota-implementations/multiagent/utils/utils.py +43 -0
- sota-implementations/ppo/README.md +29 -0
- sota-implementations/ppo/ppo_atari.py +305 -0
- sota-implementations/ppo/ppo_mujoco.py +293 -0
- sota-implementations/ppo/utils_atari.py +238 -0
- sota-implementations/ppo/utils_mujoco.py +152 -0
- sota-implementations/ppo_trainer/train.py +21 -0
- sota-implementations/redq/README.md +7 -0
- sota-implementations/redq/redq.py +199 -0
- sota-implementations/redq/utils.py +1060 -0
- sota-implementations/sac/sac-async.py +266 -0
- sota-implementations/sac/sac.py +239 -0
- sota-implementations/sac/utils.py +381 -0
- sota-implementations/sac_trainer/train.py +16 -0
- sota-implementations/td3/td3.py +254 -0
- sota-implementations/td3/utils.py +319 -0
- sota-implementations/td3_bc/td3_bc.py +177 -0
- sota-implementations/td3_bc/utils.py +251 -0
- torchrl/__init__.py +144 -0
- torchrl/_extension.py +74 -0
- torchrl/_torchrl.cp314-win_amd64.pyd +0 -0
- torchrl/_utils.py +1431 -0
- torchrl/collectors/__init__.py +48 -0
- torchrl/collectors/_base.py +1058 -0
- torchrl/collectors/_constants.py +88 -0
- torchrl/collectors/_multi_async.py +324 -0
- torchrl/collectors/_multi_base.py +1805 -0
- torchrl/collectors/_multi_sync.py +464 -0
- torchrl/collectors/_runner.py +581 -0
- torchrl/collectors/_single.py +2009 -0
- torchrl/collectors/_single_async.py +259 -0
- torchrl/collectors/collectors.py +62 -0
- torchrl/collectors/distributed/__init__.py +32 -0
- torchrl/collectors/distributed/default_configs.py +133 -0
- torchrl/collectors/distributed/generic.py +1306 -0
- torchrl/collectors/distributed/ray.py +1092 -0
- torchrl/collectors/distributed/rpc.py +1006 -0
- torchrl/collectors/distributed/sync.py +731 -0
- torchrl/collectors/distributed/utils.py +160 -0
- torchrl/collectors/llm/__init__.py +10 -0
- torchrl/collectors/llm/base.py +494 -0
- torchrl/collectors/llm/ray_collector.py +275 -0
- torchrl/collectors/llm/utils.py +36 -0
- torchrl/collectors/llm/weight_update/__init__.py +10 -0
- torchrl/collectors/llm/weight_update/vllm.py +348 -0
- torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
- torchrl/collectors/utils.py +433 -0
- torchrl/collectors/weight_update.py +591 -0
- torchrl/csrc/numpy_utils.h +38 -0
- torchrl/csrc/pybind.cpp +27 -0
- torchrl/csrc/segment_tree.h +458 -0
- torchrl/csrc/torch_utils.h +34 -0
- torchrl/csrc/utils.cpp +48 -0
- torchrl/csrc/utils.h +31 -0
- torchrl/data/__init__.py +187 -0
- torchrl/data/datasets/__init__.py +58 -0
- torchrl/data/datasets/atari_dqn.py +878 -0
- torchrl/data/datasets/common.py +281 -0
- torchrl/data/datasets/d4rl.py +489 -0
- torchrl/data/datasets/d4rl_infos.py +187 -0
- torchrl/data/datasets/gen_dgrl.py +375 -0
- torchrl/data/datasets/minari_data.py +643 -0
- torchrl/data/datasets/openml.py +177 -0
- torchrl/data/datasets/openx.py +798 -0
- torchrl/data/datasets/roboset.py +363 -0
- torchrl/data/datasets/utils.py +11 -0
- torchrl/data/datasets/vd4rl.py +432 -0
- torchrl/data/llm/__init__.py +34 -0
- torchrl/data/llm/dataset.py +491 -0
- torchrl/data/llm/history.py +1378 -0
- torchrl/data/llm/prompt.py +198 -0
- torchrl/data/llm/reward.py +225 -0
- torchrl/data/llm/topk.py +186 -0
- torchrl/data/llm/utils.py +543 -0
- torchrl/data/map/__init__.py +21 -0
- torchrl/data/map/hash.py +185 -0
- torchrl/data/map/query.py +204 -0
- torchrl/data/map/tdstorage.py +363 -0
- torchrl/data/map/tree.py +1434 -0
- torchrl/data/map/utils.py +103 -0
- torchrl/data/postprocs/__init__.py +8 -0
- torchrl/data/postprocs/postprocs.py +391 -0
- torchrl/data/replay_buffers/__init__.py +99 -0
- torchrl/data/replay_buffers/checkpointers.py +622 -0
- torchrl/data/replay_buffers/ray_buffer.py +292 -0
- torchrl/data/replay_buffers/replay_buffers.py +2376 -0
- torchrl/data/replay_buffers/samplers.py +2578 -0
- torchrl/data/replay_buffers/scheduler.py +265 -0
- torchrl/data/replay_buffers/storages.py +2412 -0
- torchrl/data/replay_buffers/utils.py +1042 -0
- torchrl/data/replay_buffers/writers.py +781 -0
- torchrl/data/tensor_specs.py +7101 -0
- torchrl/data/utils.py +334 -0
- torchrl/envs/__init__.py +265 -0
- torchrl/envs/async_envs.py +1105 -0
- torchrl/envs/batched_envs.py +3093 -0
- torchrl/envs/common.py +4241 -0
- torchrl/envs/custom/__init__.py +11 -0
- torchrl/envs/custom/chess.py +617 -0
- torchrl/envs/custom/llm.py +214 -0
- torchrl/envs/custom/pendulum.py +401 -0
- torchrl/envs/custom/san_moves.txt +29274 -0
- torchrl/envs/custom/tictactoeenv.py +288 -0
- torchrl/envs/env_creator.py +263 -0
- torchrl/envs/gym_like.py +752 -0
- torchrl/envs/libs/__init__.py +68 -0
- torchrl/envs/libs/_gym_utils.py +326 -0
- torchrl/envs/libs/brax.py +846 -0
- torchrl/envs/libs/dm_control.py +544 -0
- torchrl/envs/libs/envpool.py +447 -0
- torchrl/envs/libs/gym.py +2239 -0
- torchrl/envs/libs/habitat.py +138 -0
- torchrl/envs/libs/isaac_lab.py +87 -0
- torchrl/envs/libs/isaacgym.py +203 -0
- torchrl/envs/libs/jax_utils.py +166 -0
- torchrl/envs/libs/jumanji.py +963 -0
- torchrl/envs/libs/meltingpot.py +599 -0
- torchrl/envs/libs/openml.py +153 -0
- torchrl/envs/libs/openspiel.py +652 -0
- torchrl/envs/libs/pettingzoo.py +1042 -0
- torchrl/envs/libs/procgen.py +351 -0
- torchrl/envs/libs/robohive.py +429 -0
- torchrl/envs/libs/smacv2.py +645 -0
- torchrl/envs/libs/unity_mlagents.py +891 -0
- torchrl/envs/libs/utils.py +147 -0
- torchrl/envs/libs/vmas.py +813 -0
- torchrl/envs/llm/__init__.py +63 -0
- torchrl/envs/llm/chat.py +730 -0
- torchrl/envs/llm/datasets/README.md +4 -0
- torchrl/envs/llm/datasets/__init__.py +17 -0
- torchrl/envs/llm/datasets/gsm8k.py +353 -0
- torchrl/envs/llm/datasets/ifeval.py +274 -0
- torchrl/envs/llm/envs.py +789 -0
- torchrl/envs/llm/libs/README.md +3 -0
- torchrl/envs/llm/libs/__init__.py +8 -0
- torchrl/envs/llm/libs/mlgym.py +869 -0
- torchrl/envs/llm/reward/__init__.py +10 -0
- torchrl/envs/llm/reward/gsm8k.py +324 -0
- torchrl/envs/llm/reward/ifeval/README.md +13 -0
- torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
- torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
- torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
- torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
- torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
- torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
- torchrl/envs/llm/transforms/__init__.py +55 -0
- torchrl/envs/llm/transforms/browser.py +292 -0
- torchrl/envs/llm/transforms/dataloading.py +859 -0
- torchrl/envs/llm/transforms/format.py +73 -0
- torchrl/envs/llm/transforms/kl.py +1544 -0
- torchrl/envs/llm/transforms/policy_version.py +189 -0
- torchrl/envs/llm/transforms/reason.py +323 -0
- torchrl/envs/llm/transforms/tokenizer.py +321 -0
- torchrl/envs/llm/transforms/tools.py +1955 -0
- torchrl/envs/model_based/__init__.py +9 -0
- torchrl/envs/model_based/common.py +180 -0
- torchrl/envs/model_based/dreamer.py +112 -0
- torchrl/envs/transforms/__init__.py +147 -0
- torchrl/envs/transforms/functional.py +48 -0
- torchrl/envs/transforms/gym_transforms.py +203 -0
- torchrl/envs/transforms/module.py +341 -0
- torchrl/envs/transforms/r3m.py +372 -0
- torchrl/envs/transforms/ray_service.py +663 -0
- torchrl/envs/transforms/rb_transforms.py +214 -0
- torchrl/envs/transforms/transforms.py +11835 -0
- torchrl/envs/transforms/utils.py +94 -0
- torchrl/envs/transforms/vc1.py +307 -0
- torchrl/envs/transforms/vecnorm.py +845 -0
- torchrl/envs/transforms/vip.py +407 -0
- torchrl/envs/utils.py +1718 -0
- torchrl/envs/vec_envs.py +11 -0
- torchrl/modules/__init__.py +206 -0
- torchrl/modules/distributions/__init__.py +73 -0
- torchrl/modules/distributions/continuous.py +830 -0
- torchrl/modules/distributions/discrete.py +908 -0
- torchrl/modules/distributions/truncated_normal.py +187 -0
- torchrl/modules/distributions/utils.py +233 -0
- torchrl/modules/llm/__init__.py +62 -0
- torchrl/modules/llm/backends/__init__.py +65 -0
- torchrl/modules/llm/backends/vllm/__init__.py +94 -0
- torchrl/modules/llm/backends/vllm/_models.py +46 -0
- torchrl/modules/llm/backends/vllm/base.py +72 -0
- torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
- torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
- torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
- torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
- torchrl/modules/llm/policies/__init__.py +28 -0
- torchrl/modules/llm/policies/common.py +1809 -0
- torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
- torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
- torchrl/modules/llm/utils.py +23 -0
- torchrl/modules/mcts/__init__.py +21 -0
- torchrl/modules/mcts/scores.py +579 -0
- torchrl/modules/models/__init__.py +86 -0
- torchrl/modules/models/batchrenorm.py +119 -0
- torchrl/modules/models/decision_transformer.py +179 -0
- torchrl/modules/models/exploration.py +731 -0
- torchrl/modules/models/llm.py +156 -0
- torchrl/modules/models/model_based.py +596 -0
- torchrl/modules/models/models.py +1712 -0
- torchrl/modules/models/multiagent.py +1067 -0
- torchrl/modules/models/recipes/impala.py +185 -0
- torchrl/modules/models/utils.py +162 -0
- torchrl/modules/planners/__init__.py +10 -0
- torchrl/modules/planners/cem.py +228 -0
- torchrl/modules/planners/common.py +73 -0
- torchrl/modules/planners/mppi.py +265 -0
- torchrl/modules/tensordict_module/__init__.py +89 -0
- torchrl/modules/tensordict_module/actors.py +2457 -0
- torchrl/modules/tensordict_module/common.py +529 -0
- torchrl/modules/tensordict_module/exploration.py +814 -0
- torchrl/modules/tensordict_module/probabilistic.py +321 -0
- torchrl/modules/tensordict_module/rnn.py +1639 -0
- torchrl/modules/tensordict_module/sequence.py +132 -0
- torchrl/modules/tensordict_module/world_models.py +34 -0
- torchrl/modules/utils/__init__.py +38 -0
- torchrl/modules/utils/mappings.py +9 -0
- torchrl/modules/utils/utils.py +89 -0
- torchrl/objectives/__init__.py +78 -0
- torchrl/objectives/a2c.py +659 -0
- torchrl/objectives/common.py +753 -0
- torchrl/objectives/cql.py +1346 -0
- torchrl/objectives/crossq.py +710 -0
- torchrl/objectives/ddpg.py +453 -0
- torchrl/objectives/decision_transformer.py +371 -0
- torchrl/objectives/deprecated.py +516 -0
- torchrl/objectives/dqn.py +683 -0
- torchrl/objectives/dreamer.py +488 -0
- torchrl/objectives/functional.py +48 -0
- torchrl/objectives/gail.py +258 -0
- torchrl/objectives/iql.py +996 -0
- torchrl/objectives/llm/__init__.py +30 -0
- torchrl/objectives/llm/grpo.py +846 -0
- torchrl/objectives/llm/sft.py +482 -0
- torchrl/objectives/multiagent/__init__.py +8 -0
- torchrl/objectives/multiagent/qmixer.py +396 -0
- torchrl/objectives/ppo.py +1669 -0
- torchrl/objectives/redq.py +683 -0
- torchrl/objectives/reinforce.py +530 -0
- torchrl/objectives/sac.py +1580 -0
- torchrl/objectives/td3.py +570 -0
- torchrl/objectives/td3_bc.py +625 -0
- torchrl/objectives/utils.py +782 -0
- torchrl/objectives/value/__init__.py +28 -0
- torchrl/objectives/value/advantages.py +1956 -0
- torchrl/objectives/value/functional.py +1459 -0
- torchrl/objectives/value/utils.py +360 -0
- torchrl/record/__init__.py +17 -0
- torchrl/record/loggers/__init__.py +23 -0
- torchrl/record/loggers/common.py +48 -0
- torchrl/record/loggers/csv.py +226 -0
- torchrl/record/loggers/mlflow.py +142 -0
- torchrl/record/loggers/tensorboard.py +139 -0
- torchrl/record/loggers/trackio.py +163 -0
- torchrl/record/loggers/utils.py +78 -0
- torchrl/record/loggers/wandb.py +214 -0
- torchrl/record/recorder.py +554 -0
- torchrl/services/__init__.py +79 -0
- torchrl/services/base.py +109 -0
- torchrl/services/ray_service.py +453 -0
- torchrl/testing/__init__.py +107 -0
- torchrl/testing/assertions.py +179 -0
- torchrl/testing/dist_utils.py +122 -0
- torchrl/testing/env_creators.py +227 -0
- torchrl/testing/env_helper.py +35 -0
- torchrl/testing/gym_helpers.py +156 -0
- torchrl/testing/llm_mocks.py +119 -0
- torchrl/testing/mocking_classes.py +2720 -0
- torchrl/testing/modules.py +295 -0
- torchrl/testing/mp_helpers.py +15 -0
- torchrl/testing/ray_helpers.py +293 -0
- torchrl/testing/utils.py +190 -0
- torchrl/trainers/__init__.py +42 -0
- torchrl/trainers/algorithms/__init__.py +11 -0
- torchrl/trainers/algorithms/configs/__init__.py +705 -0
- torchrl/trainers/algorithms/configs/collectors.py +216 -0
- torchrl/trainers/algorithms/configs/common.py +41 -0
- torchrl/trainers/algorithms/configs/data.py +308 -0
- torchrl/trainers/algorithms/configs/envs.py +104 -0
- torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
- torchrl/trainers/algorithms/configs/logging.py +80 -0
- torchrl/trainers/algorithms/configs/modules.py +570 -0
- torchrl/trainers/algorithms/configs/objectives.py +177 -0
- torchrl/trainers/algorithms/configs/trainers.py +340 -0
- torchrl/trainers/algorithms/configs/transforms.py +955 -0
- torchrl/trainers/algorithms/configs/utils.py +252 -0
- torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
- torchrl/trainers/algorithms/configs/weight_update.py +159 -0
- torchrl/trainers/algorithms/ppo.py +373 -0
- torchrl/trainers/algorithms/sac.py +308 -0
- torchrl/trainers/helpers/__init__.py +40 -0
- torchrl/trainers/helpers/collectors.py +416 -0
- torchrl/trainers/helpers/envs.py +573 -0
- torchrl/trainers/helpers/logger.py +33 -0
- torchrl/trainers/helpers/losses.py +132 -0
- torchrl/trainers/helpers/models.py +658 -0
- torchrl/trainers/helpers/replay_buffer.py +59 -0
- torchrl/trainers/helpers/trainers.py +301 -0
- torchrl/trainers/trainers.py +2052 -0
- torchrl/weight_update/__init__.py +33 -0
- torchrl/weight_update/_distributed.py +749 -0
- torchrl/weight_update/_mp.py +624 -0
- torchrl/weight_update/_noupdate.py +102 -0
- torchrl/weight_update/_ray.py +1032 -0
- torchrl/weight_update/_rpc.py +284 -0
- torchrl/weight_update/_shared.py +891 -0
- torchrl/weight_update/llm/__init__.py +32 -0
- torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
- torchrl/weight_update/llm/vllm_nccl.py +710 -0
- torchrl/weight_update/utils.py +73 -0
- torchrl/weight_update/weight_sync_schemes.py +1244 -0
- torchrl-0.11.0.dist-info/LICENSE +21 -0
- torchrl-0.11.0.dist-info/METADATA +1307 -0
- torchrl-0.11.0.dist-info/RECORD +394 -0
- torchrl-0.11.0.dist-info/WHEEL +5 -0
- torchrl-0.11.0.dist-info/entry_points.txt +2 -0
- torchrl-0.11.0.dist-info/top_level.txt +7 -0
@@ -0,0 +1,1042 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+# import tree
+from __future__ import annotations
+
+import contextlib
+import itertools
+import math
+import operator
+import os
+import typing
+from collections.abc import Callable
+from pathlib import Path
+from typing import Any, Union
+
+import numpy as np
+import torch
+from tensordict import (
+    lazy_stack,
+    MemoryMappedTensor,
+    NonTensorData,
+    TensorDict,
+    TensorDictBase,
+    unravel_key,
+)
+from torch import Tensor
+from torch.nn import functional as F
+from torch.utils._pytree import LeafSpec, tree_flatten, tree_unflatten
+from torchrl._utils import implement_for, logger as torchrl_logger
+
+SINGLE_TENSOR_BUFFER_NAME = os.environ.get(
+    "SINGLE_TENSOR_BUFFER_NAME", "_-single-tensor-_"
+)
+
+
+INT_CLASSES_TYPING = Union[int, np.integer]
+if hasattr(typing, "get_args"):
+    INT_CLASSES = typing.get_args(INT_CLASSES_TYPING)
+else:
+    # python 3.7
+    INT_CLASSES = (int, np.integer)
+
+
+def _to_numpy(data: Tensor) -> np.ndarray:
+    return data.detach().cpu().numpy() if isinstance(data, torch.Tensor) else data
+
+
+def _to_torch(
+    data: Tensor, device, pin_memory: bool = False, non_blocking: bool = False
+) -> torch.Tensor:
+    if isinstance(data, np.generic):
+        return torch.as_tensor(data, device=device)
+    elif isinstance(data, np.ndarray):
+        data = torch.from_numpy(data)
+    elif not isinstance(data, Tensor):
+        data = torch.as_tensor(data, device=device)
+
+    if pin_memory:
+        data = data.pin_memory()
+    if device is not None:
+        data = data.to(device, non_blocking=non_blocking)
+
+    return data
+
+
+def pin_memory_output(fun) -> Callable:
+    """Calls pin_memory on outputs of decorated function if they have such method."""
+
+    def decorated_fun(self, *args, **kwargs):
+        output = fun(self, *args, **kwargs)
+        if self._pin_memory:
+            _tuple_out = True
+            if not isinstance(output, tuple):
+                _tuple_out = False
+                output = (output,)
+            output = tuple(_pin_memory(_output) for _output in output)
+            if _tuple_out:
+                return output
+            return output[0]
+        return output
+
+    return decorated_fun
+
+
+def _pin_memory(output: Any) -> Any:
+    if hasattr(output, "pin_memory") and output.device == torch.device("cpu"):
+        return output.pin_memory()
+    else:
+        return output
+
+
+def _reduce(
+    tensor: torch.Tensor, reduction: str, dim: int | None = None
+) -> float | torch.Tensor:
+    """Reduces a tensor given the reduction method."""
+    if reduction == "max":
+        result = tensor.max(dim=dim)
+    elif reduction == "min":
+        result = tensor.min(dim=dim)
+    elif reduction == "mean":
+        result = tensor.mean(dim=dim)
+    elif reduction == "median":
+        result = tensor.median(dim=dim)
+    elif reduction == "sum":
+        result = tensor.sum(dim=dim)
+    else:
+        raise NotImplementedError(f"Unknown reduction method {reduction}")
+    if isinstance(result, tuple):
+        result = result[0]
+    return result.item() if dim is None else result
+
+
+def _is_int(index):
+    if isinstance(index, INT_CLASSES):
+        return True
+    if isinstance(index, (np.ndarray, torch.Tensor)):
+        return index.ndim == 0
+    return False
+
+
+class TED2Flat:
+    """A storage saving hook to serialize TED data in a compact format.
+
+    Args:
+        done_key (NestedKey, optional): the key where the done states should be read.
+            Defaults to ``("next", "done")``.
+        shift_key (NestedKey, optional): the key where the shift will be written.
+            Defaults to "shift".
+        is_full_key (NestedKey, optional): the key where the is_full attribute will be written.
+            Defaults to "is_full".
+        done_keys (Tuple[NestedKey], optional): a tuple of nested keys indicating the done entries.
+            Defaults to ("done", "truncated", "terminated")
+        reward_keys (Tuple[NestedKey], optional): a tuple of nested keys indicating the reward entries.
+            Defaults to ("reward",)
+
+
+    Examples:
+        >>> import tempfile
+        >>>
+        >>> from tensordict import TensorDict
+        >>>
+        >>> from torchrl.collectors import Collector
+        >>> from torchrl.data import ReplayBuffer, TED2Flat, LazyMemmapStorage
+        >>> from torchrl.envs import GymEnv
+        >>> import torch
+        >>>
+        >>> env = GymEnv("CartPole-v1")
+        >>> env.set_seed(0)
+        >>> torch.manual_seed(0)
+        >>> collector = Collector(env, policy=env.rand_step, total_frames=200, frames_per_batch=200)
+        >>> rb = ReplayBuffer(storage=LazyMemmapStorage(200))
+        >>> rb.register_save_hook(TED2Flat())
+        >>> with tempfile.TemporaryDirectory() as tmpdir:
+        ...     for i, data in enumerate(collector):
+        ...         rb.extend(data)
+        ...         rb.dumps(tmpdir)
+        ...     # load the data to represent it
+        ...     td = TensorDict.load(tmpdir + "/storage/")
+        ...     print(td)
+        TensorDict(
+            fields={
+                action: MemoryMappedTensor(shape=torch.Size([200, 2]), device=cpu, dtype=torch.int64, is_shared=True),
+                collector: TensorDict(
+                    fields={
+                        traj_ids: MemoryMappedTensor(shape=torch.Size([200]), device=cpu, dtype=torch.int64, is_shared=True)},
+                    batch_size=torch.Size([]),
+                    device=cpu,
+                    is_shared=False),
+                done: MemoryMappedTensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=True),
+                observation: MemoryMappedTensor(shape=torch.Size([220, 4]), device=cpu, dtype=torch.float32, is_shared=True),
+                reward: MemoryMappedTensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.float32, is_shared=True),
+                terminated: MemoryMappedTensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=True),
+                truncated: MemoryMappedTensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=True)},
+            batch_size=torch.Size([]),
+            device=cpu,
+            is_shared=False)
+
+    """
+
+    _shift: int | None = None
+    _is_full: bool | None = None
+
+    def __init__(
+        self,
+        done_key=("next", "done"),
+        shift_key="shift",
+        is_full_key="is_full",
+        done_keys=("done", "truncated", "terminated"),
+        reward_keys=("reward",),
+    ):
+        self.done_key = done_key
+        self.shift_key = shift_key
+        self.is_full_key = is_full_key
+        self.done_keys = {unravel_key(key) for key in done_keys}
+        self.reward_keys = {unravel_key(key) for key in reward_keys}
+
+    @property
+    def shift(self):
+        return self._shift
+
+    @shift.setter
+    def shift(self, value: int):
+        self._shift = value
+
+    @property
+    def is_full(self):
+        return self._is_full
+
+    @is_full.setter
+    def is_full(self, value: int):
+        self._is_full = value
+
+    def __call__(self, data: TensorDictBase, path: Path = None):
+        # Get the done state
+        shift = self.shift
+        is_full = self.is_full
+
+        # Create an output storage
+        output = TensorDict()
+        output.set_non_tensor(self.is_full_key, is_full)
+        output.set_non_tensor(self.shift_key, shift)
+        output.set_non_tensor("_storage_shape", tuple(data.shape))
+        output.memmap_(path)
+
+        # Preallocate the output
+        done = data.get(self.done_key).squeeze(-1).clone()
+        if not is_full:
+            # shift is the cursor place
+            done[shift - 1] = True
+        else:
+            done = done.roll(-shift, dims=0)
+            done[-1] = True
+        ntraj = done.sum()
+
+        # Get the keys that require extra storage
+        keys_to_expand = set(data.get("next").keys(True, True)) - (
+            self.done_keys.union(self.reward_keys)
+        )
+
+        total_keys = data.exclude("next").keys(True, True)
+        total_keys = set(total_keys).union(set(data.get("next").keys(True, True)))
+
+        len_with_offset = data.numel() + ntraj  # + done[0].numel()
+        for key in total_keys:
+            if key in (self.done_keys.union(self.reward_keys)):
+                entry = data.get(("next", key))
+            else:
+                entry = data.get(key)
+
+            if key in keys_to_expand:
+                shape = torch.Size([len_with_offset, *entry.shape[data.ndim :]])
+                dtype = entry.dtype
+                output.make_memmap(key, shape=shape, dtype=dtype)
+            else:
+                shape = torch.Size([data.numel(), *entry.shape[data.ndim :]])
+                output.make_memmap(key, shape=shape, dtype=entry.dtype)
+
+        if data.ndim == 1:
+            return self._call(
+                data=data,
+                output=output,
+                is_full=is_full,
+                shift=shift,
+                done=done,
+                total_keys=total_keys,
+                keys_to_expand=keys_to_expand,
+            )
+
+        with data.flatten(1, -1) if data.ndim > 2 else contextlib.nullcontext(
+            data
+        ) as data_flat:
+            if data.ndim > 2:
+                done = done.flatten(1, -1)
+            traj_per_dim = done.sum(0)
+            nsteps = data_flat.shape[0]
+
+            start = 0
+            start_with_offset = start
+            stop_with_offset = 0
+            stop = 0
+            for data_slice, done_slice, traj_for_dim in zip(
+                data_flat.unbind(1), done.unbind(1), traj_per_dim
+            ):
+                stop_with_offset = stop_with_offset + nsteps + traj_for_dim
+                cur_slice_offset = slice(start_with_offset, stop_with_offset)
+                start_with_offset = stop_with_offset
+
+                stop = stop + data.shape[0]
+                cur_slice = slice(start, stop)
+                start = stop
+
+                def _index(
+                    key,
+                    val,
+                    keys_to_expand=keys_to_expand,
+                    cur_slice=cur_slice,
+                    cur_slice_offset=cur_slice_offset,
+                ):
+                    if key in keys_to_expand:
+                        return val[cur_slice_offset]
+                    return val[cur_slice]
+
+                out_slice = output.named_apply(_index, nested_keys=True)
+                self._call(
+                    data=data_slice,
+                    output=out_slice,
+                    is_full=is_full,
+                    shift=shift,
+                    done=done_slice,
+                    total_keys=total_keys,
+                    keys_to_expand=keys_to_expand,
+                )
+            return output
+
+    def _call(self, *, data, output, is_full, shift, done, total_keys, keys_to_expand):
+        # capture for each item in data where the observation should be written
+        idx = torch.arange(data.shape[0])
+        idx_done = (idx + done.cumsum(0))[done]
+        idx += torch.nn.functional.pad(done, [1, 0])[:-1].cumsum(0)
+
+        for key in total_keys:
+            if key in (self.done_keys.union(self.reward_keys)):
+                entry = data.get(("next", key))
+            else:
+                entry = data.get(key)
+
+            if key in keys_to_expand:
+                mmap = output.get(key)
+                shifted_next = data.get(("next", key))
+                if is_full:
+                    _roll_inplace(entry, shift=-shift, out=mmap, index_dest=idx)
+                    _roll_inplace(
+                        shifted_next,
+                        shift=-shift,
+                        out=mmap,
+                        index_dest=idx_done,
+                        index_source=done,
+                    )
+                else:
+                    mmap[idx] = entry
+                    mmap[idx_done] = shifted_next[done]
+            elif is_full:
+                mmap = output.get(key)
+                _roll_inplace(entry, shift=-shift, out=mmap)
+            else:
+                mmap = output.get(key)
+                mmap.copy_(entry)
+        return output
+
+
+class Flat2TED:
+    """A storage loading hook to deserialize flattened TED data to TED format.
+
+    Args:
+        done_key (NestedKey, optional): the key where the done states should be read.
+            Defaults to ``("next", "done")``.
+        shift_key (NestedKey, optional): the key where the shift will be written.
+            Defaults to "shift".
+        is_full_key (NestedKey, optional): the key where the is_full attribute will be written.
+            Defaults to "is_full".
+        done_keys (Tuple[NestedKey], optional): a tuple of nested keys indicating the done entries.
+            Defaults to ("done", "truncated", "terminated")
+        reward_keys (Tuple[NestedKey], optional): a tuple of nested keys indicating the reward entries.
+            Defaults to ("reward",)
+
+    Examples:
+        >>> import tempfile
+        >>>
+        >>> from tensordict import TensorDict
+        >>>
+        >>> from torchrl.collectors import Collector
+        >>> from torchrl.data import ReplayBuffer, TED2Flat, LazyMemmapStorage, Flat2TED
+        >>> from torchrl.envs import GymEnv
+        >>> import torch
+        >>>
+        >>> env = GymEnv("CartPole-v1")
+        >>> env.set_seed(0)
+        >>> torch.manual_seed(0)
+        >>> collector = Collector(env, policy=env.rand_step, total_frames=200, frames_per_batch=200)
+        >>> rb = ReplayBuffer(storage=LazyMemmapStorage(200))
+        >>> rb.register_save_hook(TED2Flat())
+        >>> with tempfile.TemporaryDirectory() as tmpdir:
+        ...     for i, data in enumerate(collector):
+        ...         rb.extend(data)
+        ...         rb.dumps(tmpdir)
+        ...     # load the data to represent it
+        ...     td = TensorDict.load(tmpdir + "/storage/")
+        ...
+        ...     rb_load = ReplayBuffer(storage=LazyMemmapStorage(200))
+        ...     rb_load.register_load_hook(Flat2TED())
+        ...     rb_load.load(tmpdir)
+        ...     print("storage after loading", rb_load[:])
+        ...     assert (rb[:] == rb_load[:]).all()
+        storage after loading TensorDict(
+            fields={
+                action: MemoryMappedTensor(shape=torch.Size([200, 2]), device=cpu, dtype=torch.int64, is_shared=False),
+                collector: TensorDict(
+                    fields={
+                        traj_ids: MemoryMappedTensor(shape=torch.Size([200]), device=cpu, dtype=torch.int64, is_shared=False)},
+                    batch_size=torch.Size([200]),
+                    device=cpu,
+                    is_shared=False),
+                done: MemoryMappedTensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False),
+                next: TensorDict(
+                    fields={
+                        done: MemoryMappedTensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False),
+                        observation: MemoryMappedTensor(shape=torch.Size([200, 4]), device=cpu, dtype=torch.float32, is_shared=False),
+                        reward: MemoryMappedTensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.float32, is_shared=False),
+                        terminated: MemoryMappedTensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False),
+                        truncated: MemoryMappedTensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
+                    batch_size=torch.Size([200]),
+                    device=cpu,
+                    is_shared=False),
+                observation: MemoryMappedTensor(shape=torch.Size([200, 4]), device=cpu, dtype=torch.float32, is_shared=False),
+                terminated: MemoryMappedTensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False),
+                truncated: MemoryMappedTensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
+            batch_size=torch.Size([200]),
+            device=cpu,
+            is_shared=False)
+
+
+    """
+
+    def __init__(
+        self,
+        done_key="done",
+        shift_key="shift",
+        is_full_key="is_full",
+        done_keys=("done", "truncated", "terminated"),
+        reward_keys=("reward",),
+    ):
+        self.done_key = done_key
+        self.shift_key = shift_key
+        self.is_full_key = is_full_key
+        self.done_keys = {unravel_key(key) for key in done_keys}
+        self.reward_keys = {unravel_key(key) for key in reward_keys}
+
+    def __call__(self, data: TensorDictBase, out: TensorDictBase = None):
+        _storage_shape = data.get_non_tensor("_storage_shape", default=None)
+        if isinstance(_storage_shape, int):
+            _storage_shape = torch.Size([_storage_shape])
+        shift = data.get_non_tensor(self.shift_key, default=None)
+        is_full = data.get_non_tensor(self.is_full_key, default=None)
+        done = (
+            data.get("done")
+            .reshape((*_storage_shape[1:], -1))
+            .contiguous()
+            .permute(-1, *range(0, len(_storage_shape) - 1))
+            .clone()
+        )
+        if not is_full:
+            # shift is the cursor place
+            done[shift - 1] = True
+        else:
+            # done = done.roll(-shift, dims=0)
+            done[-1] = True
+
+        if _storage_shape is not None and len(_storage_shape) > 1:
+            # iterate over data and allocate
+            if out is None:
+                # out = TensorDict(batch_size=_storage_shape)
+                # for i in range(out.ndim):
+                #     if i >= 2:
+                #         # FLattening the lazy stack will make the data unavailable - we need to find a way to make this
+                #         # possible.
+                #         raise RuntimeError(
+                #             "Checkpointing an uninitialized buffer with more than 2 dimensions is currently not supported. "
+                #             "Please file an issue on GitHub to ask for this feature!"
+                #         )
+                #     out = LazyStackedTensorDict(*out.unbind(i), stack_dim=i)
+                out = TensorDict(batch_size=_storage_shape)
+                for i in range(1, out.ndim):
+                    if i >= 2:
+                        # FLattening the lazy stack will make the data unavailable - we need to find a way to make this
+                        # possible.
+                        raise RuntimeError(
+                            "Checkpointing an uninitialized buffer with more than 2 dimensions is currently not supported. "
+                            "Please file an issue on GitHub to ask for this feature!"
+                        )
+                    out_list = [
+                        out._get_sub_tensordict((slice(None),) * i + (j,))
+                        for j in range(out.shape[i])
+                    ]
+                    out = lazy_stack(out_list, i)
+
+            # Create a function that reads slices of the input data
+            with out.flatten(1, -1) if out.ndim > 2 else contextlib.nullcontext(
+                out
+            ) as out_flat:
+                nsteps = done.shape[0]
+                n_elt_batch = done.shape[1:].numel()
+                traj_per_dim = done.sum(0)
+
+                start = 0
+                start_with_offset = start
+                stop_with_offset = 0
+                stop = 0
+
+                for out_unbound, traj_for_dim in zip(out_flat.unbind(-1), traj_per_dim):
+                    stop_with_offset = stop_with_offset + nsteps + traj_for_dim
+                    cur_slice_offset = slice(start_with_offset, stop_with_offset)
+                    start_with_offset = stop_with_offset
+
+                    stop = stop + nsteps
+                    cur_slice = slice(start, stop)
+                    start = stop
+
+                    def _index(
+                        key,
+                        val,
+                        cur_slice=cur_slice,
+                        nsteps=nsteps,
+                        n_elt_batch=n_elt_batch,
+                        cur_slice_offset=cur_slice_offset,
+                    ):
+                        if val.shape[0] != (nsteps * n_elt_batch):
+                            return val[cur_slice_offset]
+                        return val[cur_slice]
+
+                    data_slice = data.named_apply(
+                        _index, nested_keys=True, batch_size=[]
+                    )
+                    self._call(
+                        data=data_slice,
+                        out=out_unbound,
+                        is_full=is_full,
+                        shift=shift,
+                        _storage_shape=_storage_shape,
+                    )
+            return out
+        return self._call(
+            data=data,
+            out=out,
+            is_full=is_full,
+            shift=shift,
+            _storage_shape=_storage_shape,
+        )
+
+    def _call(self, *, data, out, _storage_shape, shift, is_full):
+        done = data.get(self.done_key)
+        done = done.clone()
+
+        nsteps = done.shape[0]
+
+        # capture for each item in data where the observation should be written
+        idx = torch.arange(done.shape[0])
+        padded_done = F.pad(done.squeeze(-1), [1, 0])
+        root_idx = idx + padded_done[:-1].cumsum(0)
+        next_idx = root_idx + 1
+
+        if out is None:
+            out = TensorDict(batch_size=[nsteps])
+
+        def maybe_roll(entry, out=None):
+            if is_full and shift is not None:
+                if out is not None:
+                    _roll_inplace(entry, shift=shift, out=out)
+                    return
+                else:
+                    return entry.roll(shift, dims=0)
+            if out is not None:
+                out.copy_(entry)
+                return
+            return entry
+
+        root_idx = maybe_roll(root_idx)
+        next_idx = maybe_roll(next_idx)
+        if not is_full:
+            next_idx = next_idx[:-1]
+
+        for key, entry in data.items(True, True):
+            if entry.shape[0] == nsteps:
+                if key in (self.done_keys.union(self.reward_keys)):
+                    if key != "reward" and key not in out.keys(True, True):
+                        # Create a done state at the root full of 0s
+                        out.set(key, torch.zeros_like(entry), inplace=True)
+                    entry = maybe_roll(entry, out=out.get(("next", key), None))
+                    if entry is not None:
+                        out.set(("next", key), entry, inplace=True)
+                else:
+                    # action and similar
+                    entry = maybe_roll(entry, out=out.get(key, default=None))
+                    if entry is not None:
+                        # then out is not locked
+                        out.set(key, entry, inplace=True)
+            else:
+                dest_next = out.get(("next", key), None)
+                if dest_next is not None:
+                    if not is_full:
+                        dest_next = dest_next[:-1]
+                    dest_next.copy_(entry[next_idx])
+                else:
+                    if not is_full:
+                        val = entry[next_idx]
+                        val = torch.cat([val, torch.zeros_like(val[:1])])
+                        out.set(("next", key), val, inplace=True)
+                    else:
+                        out.set(("next", key), entry[next_idx], inplace=True)
+
+                dest = out.get(key, None)
+                if dest is not None:
+                    dest.copy_(entry[root_idx])
+                else:
+                    out.set(key, entry[root_idx], inplace=True)
+        return out
+
+
+class TED2Nested(TED2Flat):
+    """Converts a TED-formatted dataset into a tensordict populated with nested tensors where each row is a trajectory."""
+
+    _shift: int | None = None
+    _is_full: bool | None = None
+
+    def __init__(self, *args, **kwargs):
+        if not hasattr(torch, "_nested_compute_contiguous_strides_offsets"):
+            raise ValueError(
+                f"Unsupported torch version {torch.__version__}. "
+                f"torch>=2.4 is required for {type(self).__name__} to be used."
+            )
+        return super().__init__(*args, **kwargs)
+
+    def __call__(self, data: TensorDictBase, path: Path = None):
+        data = super().__call__(data, path=path)
+
+        shift = self.shift
+        is_full = self.is_full
+        storage_shape = data.get_non_tensor("_storage_shape", (-1,))
+        # place time at the end
+        storage_shape = (*storage_shape[1:], storage_shape[0])
+
+        done = data.get("done")
+        done = done.squeeze(-1).clone()
+        if not is_full:
+            done.view(storage_shape)[..., shift - 1] = True
+        # else:
+        done.view(storage_shape)[..., -1] = True
+
+        ntraj = done.sum()
+
+        nz = done.nonzero(as_tuple=True)[0]
+        traj_lengths = torch.cat([nz[:1] + 1, nz.diff()])
+        # if not is_full:
+        #     traj_lengths = torch.cat(
+        #         [traj_lengths, (done.shape[0] - traj_lengths.sum()).unsqueeze(0)]
+        #     )
+
+        keys_to_expand, keys_to_keep = zip(
+            *[
+                (key, None) if val.shape[0] != done.shape[0] else (None, key)
+                for key, val in data.items(True, True)
+            ]
+        )
+        keys_to_expand = [key for key in keys_to_expand if key is not None]
+        keys_to_keep = [key for key in keys_to_keep if key is not None]
+
+        out = TensorDict(batch_size=[ntraj])
+        out.update(dict(data.non_tensor_items()))
+
+        out.memmap_(path)
+
+        traj_lengths = traj_lengths.unsqueeze(-1)
+        if not is_full:
+            # Increment by one only the trajectories that are not terminal
+            traj_lengths_expand = traj_lengths + (
+                traj_lengths.cumsum(0) % storage_shape[-1] != 0
+            )
+        else:
+            traj_lengths_expand = traj_lengths + 1
+        for key in keys_to_expand:
+            val = data.get(key)
+            shape = torch.cat(
+                [
+                    traj_lengths_expand,
+                    torch.tensor(val.shape[1:], dtype=torch.long).repeat(
+                        traj_lengths.numel(), 1
+                    ),
+                ],
+                -1,
+            )
+            # This works because the storage location is the same as the previous one - no copy is done
+            # but a new shape is written
+            out.make_memmap_from_storage(
+                key, val.untyped_storage(), dtype=val.dtype, shape=shape
+            )
+        for key in keys_to_keep:
+            val = data.get(key)
+            shape = torch.cat(
+                [
+                    traj_lengths,
+                    torch.tensor(val.shape[1:], dtype=torch.long).repeat(
+                        traj_lengths.numel(), 1
+                    ),
+                ],
+                -1,
+            )
+            out.make_memmap_from_storage(
+                key, val.untyped_storage(), dtype=val.dtype, shape=shape
+            )
+        return out
+
+
+class Nested2TED(Flat2TED):
+    """Converts a nested tensordict where each row is a trajectory into the TED format."""
+
+    def __call__(self, data, out: TensorDictBase = None):
+        # Get a flat representation of data
+        def flatten_het_dim(tensor):
+            shape = [tensor.size(i) for i in range(2, tensor.ndim)]
+            tensor = torch.tensor(tensor.untyped_storage(), dtype=tensor.dtype).view(
+                -1, *shape
+            )
+            return tensor
+
+        data = data.apply(flatten_het_dim, batch_size=[])
+        data.auto_batch_size_()
+        return super().__call__(data, out=out)
+
+
+class H5Split(TED2Flat):
+    """Splits a dataset prepared with TED2Nested into a TensorDict where each trajectory is stored as views on their parent nested tensors."""
+
+    _shift: int | None = None
+    _is_full: bool | None = None
+
+    def __call__(self, data):
+        nzeros = int(math.ceil(math.log10(data.shape[0])))
+
+        result = TensorDict(
+            {
+                f"traj_{str(i).zfill(nzeros)}": _data
+                for i, _data in enumerate(data.filter_non_tensor_data().unbind(0))
+            }
+        ).update(dict(data.non_tensor_items()))
+
+        return result
+
+
+class H5Combine:
+    """Combines trajectories in a persistent tensordict into a single standing tensordict stored in filesystem."""
+
+    def __call__(self, data, out=None):
+        # TODO: this load the entire H5 in memory, which can be problematic
+        # Ideally we would want to load it on a memmap tensordict
+        # We currently ignore out in this call but we should leverage that
+        values = [val for key, val in data.items() if key.startswith("traj")]
+        metadata_keys = [key for key in data.keys() if not key.startswith("traj")]
+        result = TensorDict({key: NonTensorData(data[key]) for key in metadata_keys})
+
+        # Create a memmap in file system (no files associated)
+        result.memmap_()
+
+        # Create each entry
+        def initialize(key, *x):
+            result.make_memmap(
+                key,
+                shape=torch.stack([torch.tensor(_x.shape) for _x in x]),
+                dtype=x[0].dtype,
+            )
+            return
+
+        values[0].named_apply(
+            initialize,
+            *values[1:],
+            nested_keys=True,
+            batch_size=[],
+            filter_empty=True,
+        )
+
+        # Populate the entries
+        def populate(key, *x):
+            dest = result.get(key)
+            for i, _x in enumerate(x):
+                dest[i].copy_(_x)
+
+        values[0].named_apply(
+            populate,
+            *values[1:],
+            nested_keys=True,
+            batch_size=[],
+            filter_empty=True,
+        )
+        return result
+
+
+@implement_for("torch", "2.3", None)
+def _path2str(path, default_name=None):
+    # Uses the Keys defined in pytree to build a path
+    from torch.utils._pytree import MappingKey, SequenceKey
+
+    if default_name is None:
+        default_name = SINGLE_TENSOR_BUFFER_NAME
+    if not path:
+        return default_name
+    if isinstance(path, tuple):
+        return "/".join([_path2str(_sub, default_name=default_name) for _sub in path])
+    if isinstance(path, MappingKey):
+        if not isinstance(path.key, (int, str, bytes)):
+            raise ValueError("Values must be of type int, str or bytes in PyTree maps.")
+        result = str(path.key)
+        if result == default_name:
+            raise RuntimeError(
+                "A tensor had the same identifier as the default name used when the buffer contains "
+                f"a single tensor (name={default_name}). This behavior is not allowed. Please rename your "
+                f"tensor in the map/dict or set a new default name with the environment variable SINGLE_TENSOR_BUFFER_NAME."
+            )
+        return result
+    if isinstance(path, SequenceKey):
+        return str(path.idx)
+
+
+@implement_for("torch", None, "2.3")
+def _path2str(path, default_name=None):  # noqa: F811
+    raise RuntimeError
+
+
+def _save_pytree_common(tensor_path, path, tensor, metadata):
+    if "." in tensor_path:
+        tensor_path.replace(".", "_<dot>_")
+    total_tensor_path = path / (tensor_path + ".memmap")
+    if os.path.exists(total_tensor_path):
+        MemoryMappedTensor.from_filename(
+            shape=tensor.shape,
+            filename=total_tensor_path,
+            dtype=tensor.dtype,
+        ).copy_(tensor)
+    else:
+        os.makedirs(total_tensor_path.parent, exist_ok=True)
+        MemoryMappedTensor.from_tensor(
+            tensor,
+            filename=total_tensor_path,
+            copy_existing=True,
+            copy_data=True,
+        )
+    key = tensor_path.replace("/", ".")
+    if key in metadata:
+        raise KeyError(
+            "At least two values have conflicting representations in "
+            f"the data structure to be serialized: {key}."
+        )
+    metadata[key] = {
+        "dtype": str(tensor.dtype),
+        "shape": list(tensor.shape),
+    }
+
+
+@implement_for("torch", "2.3", None)
+def _save_pytree(_storage, metadata, path):
+    from torch.utils._pytree import tree_map_with_path
+
+    def save_tensor(
+        tensor_path: tuple, tensor: torch.Tensor, metadata=metadata, path=path
+    ):
+        tensor_path = _path2str(tensor_path)
+        _save_pytree_common(tensor_path, path, tensor, metadata)
+
+    tree_map_with_path(save_tensor, _storage)
+
+
+@implement_for("torch", None, "2.3")
+def _save_pytree(_storage, metadata, path):  # noqa: F811
+
+    flat_storage, storage_specs = tree_flatten(_storage)
+    storage_paths = _get_paths(storage_specs)
+
+    def save_tensor(
+        tensor_path: str, tensor: torch.Tensor, metadata=metadata, path=path
+    ):
+        _save_pytree_common(tensor_path, path, tensor, metadata)
+
+    for tensor, tensor_path in zip(flat_storage, storage_paths):
+        save_tensor(tensor_path, tensor)
+
+
+def _get_paths(spec, cumulpath=""):
+    # alternative way to build a path without the keys
+    if isinstance(spec, LeafSpec):
+        yield cumulpath if cumulpath else SINGLE_TENSOR_BUFFER_NAME
+
+    contexts = spec.context
+    children_specs = spec.children_specs
+    if contexts is None:
+        contexts = range(len(children_specs))
+
+    for context, spec in zip(contexts, children_specs):
+        cpath = "/".join((cumulpath, str(context))) if cumulpath else str(context)
+        yield from _get_paths(spec, cpath)
+
+
+def _init_pytree_common(tensor_path, scratch_dir, max_size_fn, tensor):
+    if "." in tensor_path:
+        tensor_path.replace(".", "_<dot>_")
+    if scratch_dir is not None:
+        total_tensor_path = Path(scratch_dir) / (tensor_path + ".memmap")
+        if os.path.exists(total_tensor_path):
+            raise RuntimeError(
+                f"The storage of tensor {total_tensor_path} already exists. "
+                f"To load an existing replay buffer, use storage.loads. "
+                f"Choose a different path to store your buffer or delete the existing files."
+            )
+        os.makedirs(total_tensor_path.parent, exist_ok=True)
+    else:
+        total_tensor_path = None
+    out = MemoryMappedTensor.empty(
+        shape=max_size_fn(tensor.shape),
+        filename=total_tensor_path,
+        dtype=tensor.dtype,
+    )
+    try:
+        filesize = os.path.getsize(tensor.filename) / 1024 / 1024
+        torchrl_logger.debug(
+            f"The storage was created in {out.filename} and occupies {filesize} Mb of storage."
+        )
+    except (RuntimeError, AttributeError):
+        pass
+    return out
+
+
+@implement_for("torch", "2.3", None)
+def _init_pytree(scratch_dir, max_size_fn, data):
+    from torch.utils._pytree import tree_map_with_path
+
+    # If not a tensorclass/tensordict, it must be a tensor(-like) or a PyTree
+    # if Tensor, we just create a MemoryMappedTensor of the desired shape, device and dtype
+    def save_tensor(tensor_path: tuple, tensor: torch.Tensor):
+        tensor_path = _path2str(tensor_path)
+        return _init_pytree_common(tensor_path, scratch_dir, max_size_fn, tensor)
+
+    out = tree_map_with_path(save_tensor, data)
+    return out
+
+
+@implement_for("torch", None, "2.3")
+def _init_pytree(scratch_dir, max_size, data):  # noqa: F811
+
+    flat_data, data_specs = tree_flatten(data)
+    data_paths = _get_paths(data_specs)
+    data_paths = list(data_paths)
+
+    # If not a tensorclass/tensordict, it must be a tensor(-like) or a PyTree
+    # if Tensor, we just create a MemoryMappedTensor of the desired shape, device and dtype
+    def save_tensor(tensor_path: str, tensor: torch.Tensor):
+        return _init_pytree_common(tensor_path, scratch_dir, max_size, tensor)
+
+    out = []
+    for tensor, tensor_path in zip(flat_data, data_paths):
+        out.append(save_tensor(tensor_path, tensor))
+
+    return tree_unflatten(out, data_specs)
+
+
+def _roll_inplace(tensor, shift, out, index_dest=None, index_source=None):
+    # slice 0
+    source0 = tensor[:-shift]
+    if index_source is not None:
+        source0 = source0[index_source[shift:]]
+
+    slice0_shift = source0.shape[0]
+    if index_dest is not None:
+        out[index_dest[-slice0_shift:]] = source0
+    else:
+        slice0 = out[-slice0_shift:]
+        slice0.copy_(source0)
+
+    # slice 1
+    source1 = tensor[-shift:]
+    if index_source is not None:
+        source1 = source1[index_source[:shift]]
+    if index_dest is not None:
+        out[index_dest[:-slice0_shift]] = source1
+    else:
+        slice1 = out[:-slice0_shift]
+        slice1.copy_(source1)
+    return out
+
+
+# Copy-paste of unravel-index for PT 2.0
+def _unravel_index(
+    indices: Tensor, shape: int | typing.Sequence[int] | torch.Size
+) -> tuple[Tensor, ...]:
+    res_tensor = _unravel_index_impl(indices, shape)
+    return res_tensor.unbind(-1)
+
+
+def _unravel_index_impl(indices: Tensor, shape: int | typing.Sequence[int]) -> Tensor:
+    if isinstance(shape, (int, torch.SymInt)):
+        shape = torch.Size([shape])
+    else:
+        shape = torch.Size(shape)
+
+    coefs = list(
+        reversed(
+            list(
+                itertools.accumulate(
+                    reversed(shape[1:] + torch.Size([1])), func=operator.mul
+                )
+            )
+        )
+    )
+    return indices.unsqueeze(-1).floor_divide(
+        torch.tensor(coefs, device=indices.device, dtype=torch.int64)
+    ) % torch.tensor(shape, device=indices.device, dtype=torch.int64)
+
+
+@implement_for("torch", None, "2.2")
+def unravel_index(indices, shape):
+    """A version-compatible wrapper around torch.unravel_index."""
+    return _unravel_index(indices, shape)
+
+
+@implement_for("torch", "2.2")
+def unravel_index(indices, shape):  # noqa: F811
+    """A version-compatible wrapper around torch.unravel_index."""
+    return torch.unravel_index(indices, shape)
+
+
+@implement_for("torch", None, "2.3")
+def tree_iter(pytree):
+    """A version-compatible wrapper around tree_iter."""
+    flat_tree, _ = torch.utils._pytree.tree_flatten(pytree)
+    yield from flat_tree
+
+
+@implement_for("torch", "2.3", "2.4")
+def tree_iter(pytree):  # noqa: F811
+    """A version-compatible wrapper around tree_iter."""
+    yield from torch.utils._pytree.tree_leaves(pytree)
+
+
+@implement_for("torch", "2.4")
+def tree_iter(pytree):  # noqa: F811
+    """A version-compatible wrapper around tree_iter."""
+    yield from torch.utils._pytree.tree_iter(pytree)
+
+
+def _auto_device() -> torch.device:
+    if torch.cuda.is_available():
+        return torch.device("cuda:0")
+    elif torch.mps.is_available():
+        return torch.device("mps:0")
+    return torch.device("cpu")