torchrl 0.11.0__cp314-cp314-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/benchmark_batched_envs.py +104 -0
- benchmarks/conftest.py +91 -0
- benchmarks/ecosystem/gym_env_throughput.py +321 -0
- benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
- benchmarks/requirements.txt +7 -0
- benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
- benchmarks/test_collectors_benchmark.py +240 -0
- benchmarks/test_compressed_storage_benchmark.py +145 -0
- benchmarks/test_envs_benchmark.py +133 -0
- benchmarks/test_llm.py +101 -0
- benchmarks/test_non_tensor_env_benchmark.py +70 -0
- benchmarks/test_objectives_benchmarks.py +1199 -0
- benchmarks/test_replaybuffer_benchmark.py +254 -0
- sota-check/README.md +35 -0
- sota-implementations/README.md +142 -0
- sota-implementations/a2c/README.md +39 -0
- sota-implementations/a2c/a2c_atari.py +291 -0
- sota-implementations/a2c/a2c_mujoco.py +273 -0
- sota-implementations/a2c/utils_atari.py +240 -0
- sota-implementations/a2c/utils_mujoco.py +160 -0
- sota-implementations/bandits/README.md +7 -0
- sota-implementations/bandits/dqn.py +126 -0
- sota-implementations/cql/cql_offline.py +198 -0
- sota-implementations/cql/cql_online.py +249 -0
- sota-implementations/cql/discrete_cql_offline.py +180 -0
- sota-implementations/cql/discrete_cql_online.py +227 -0
- sota-implementations/cql/utils.py +471 -0
- sota-implementations/crossq/crossq.py +271 -0
- sota-implementations/crossq/utils.py +320 -0
- sota-implementations/ddpg/ddpg.py +231 -0
- sota-implementations/ddpg/utils.py +325 -0
- sota-implementations/decision_transformer/dt.py +163 -0
- sota-implementations/decision_transformer/lamb.py +167 -0
- sota-implementations/decision_transformer/online_dt.py +178 -0
- sota-implementations/decision_transformer/utils.py +562 -0
- sota-implementations/discrete_sac/discrete_sac.py +243 -0
- sota-implementations/discrete_sac/utils.py +324 -0
- sota-implementations/dqn/README.md +30 -0
- sota-implementations/dqn/dqn_atari.py +272 -0
- sota-implementations/dqn/dqn_cartpole.py +236 -0
- sota-implementations/dqn/utils_atari.py +132 -0
- sota-implementations/dqn/utils_cartpole.py +90 -0
- sota-implementations/dreamer/README.md +129 -0
- sota-implementations/dreamer/dreamer.py +586 -0
- sota-implementations/dreamer/dreamer_utils.py +1107 -0
- sota-implementations/expert-iteration/README.md +352 -0
- sota-implementations/expert-iteration/ei_utils.py +770 -0
- sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
- sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
- sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
- sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
- sota-implementations/gail/gail.py +327 -0
- sota-implementations/gail/gail_utils.py +68 -0
- sota-implementations/gail/ppo_utils.py +157 -0
- sota-implementations/grpo/README.md +273 -0
- sota-implementations/grpo/grpo-async.py +437 -0
- sota-implementations/grpo/grpo-sync.py +435 -0
- sota-implementations/grpo/grpo_utils.py +843 -0
- sota-implementations/grpo/requirements_gsm8k.txt +11 -0
- sota-implementations/grpo/requirements_ifeval.txt +16 -0
- sota-implementations/impala/README.md +33 -0
- sota-implementations/impala/impala_multi_node_ray.py +292 -0
- sota-implementations/impala/impala_multi_node_submitit.py +284 -0
- sota-implementations/impala/impala_single_node.py +261 -0
- sota-implementations/impala/utils.py +184 -0
- sota-implementations/iql/discrete_iql.py +230 -0
- sota-implementations/iql/iql_offline.py +164 -0
- sota-implementations/iql/iql_online.py +225 -0
- sota-implementations/iql/utils.py +437 -0
- sota-implementations/multiagent/README.md +74 -0
- sota-implementations/multiagent/iql.py +237 -0
- sota-implementations/multiagent/maddpg_iddpg.py +266 -0
- sota-implementations/multiagent/mappo_ippo.py +267 -0
- sota-implementations/multiagent/qmix_vdn.py +271 -0
- sota-implementations/multiagent/sac.py +337 -0
- sota-implementations/multiagent/utils/__init__.py +4 -0
- sota-implementations/multiagent/utils/logging.py +151 -0
- sota-implementations/multiagent/utils/utils.py +43 -0
- sota-implementations/ppo/README.md +29 -0
- sota-implementations/ppo/ppo_atari.py +305 -0
- sota-implementations/ppo/ppo_mujoco.py +293 -0
- sota-implementations/ppo/utils_atari.py +238 -0
- sota-implementations/ppo/utils_mujoco.py +152 -0
- sota-implementations/ppo_trainer/train.py +21 -0
- sota-implementations/redq/README.md +7 -0
- sota-implementations/redq/redq.py +199 -0
- sota-implementations/redq/utils.py +1060 -0
- sota-implementations/sac/sac-async.py +266 -0
- sota-implementations/sac/sac.py +239 -0
- sota-implementations/sac/utils.py +381 -0
- sota-implementations/sac_trainer/train.py +16 -0
- sota-implementations/td3/td3.py +254 -0
- sota-implementations/td3/utils.py +319 -0
- sota-implementations/td3_bc/td3_bc.py +177 -0
- sota-implementations/td3_bc/utils.py +251 -0
- torchrl/.dylibs/libc++.1.0.dylib +0 -0
- torchrl/__init__.py +144 -0
- torchrl/_extension.py +74 -0
- torchrl/_torchrl.cpython-314-darwin.so +0 -0
- torchrl/_utils.py +1431 -0
- torchrl/collectors/__init__.py +48 -0
- torchrl/collectors/_base.py +1058 -0
- torchrl/collectors/_constants.py +88 -0
- torchrl/collectors/_multi_async.py +324 -0
- torchrl/collectors/_multi_base.py +1805 -0
- torchrl/collectors/_multi_sync.py +464 -0
- torchrl/collectors/_runner.py +581 -0
- torchrl/collectors/_single.py +2009 -0
- torchrl/collectors/_single_async.py +259 -0
- torchrl/collectors/collectors.py +62 -0
- torchrl/collectors/distributed/__init__.py +32 -0
- torchrl/collectors/distributed/default_configs.py +133 -0
- torchrl/collectors/distributed/generic.py +1306 -0
- torchrl/collectors/distributed/ray.py +1092 -0
- torchrl/collectors/distributed/rpc.py +1006 -0
- torchrl/collectors/distributed/sync.py +731 -0
- torchrl/collectors/distributed/utils.py +160 -0
- torchrl/collectors/llm/__init__.py +10 -0
- torchrl/collectors/llm/base.py +494 -0
- torchrl/collectors/llm/ray_collector.py +275 -0
- torchrl/collectors/llm/utils.py +36 -0
- torchrl/collectors/llm/weight_update/__init__.py +10 -0
- torchrl/collectors/llm/weight_update/vllm.py +348 -0
- torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
- torchrl/collectors/utils.py +433 -0
- torchrl/collectors/weight_update.py +591 -0
- torchrl/csrc/numpy_utils.h +38 -0
- torchrl/csrc/pybind.cpp +27 -0
- torchrl/csrc/segment_tree.h +458 -0
- torchrl/csrc/torch_utils.h +34 -0
- torchrl/csrc/utils.cpp +48 -0
- torchrl/csrc/utils.h +31 -0
- torchrl/data/__init__.py +187 -0
- torchrl/data/datasets/__init__.py +58 -0
- torchrl/data/datasets/atari_dqn.py +878 -0
- torchrl/data/datasets/common.py +281 -0
- torchrl/data/datasets/d4rl.py +489 -0
- torchrl/data/datasets/d4rl_infos.py +187 -0
- torchrl/data/datasets/gen_dgrl.py +375 -0
- torchrl/data/datasets/minari_data.py +643 -0
- torchrl/data/datasets/openml.py +177 -0
- torchrl/data/datasets/openx.py +798 -0
- torchrl/data/datasets/roboset.py +363 -0
- torchrl/data/datasets/utils.py +11 -0
- torchrl/data/datasets/vd4rl.py +432 -0
- torchrl/data/llm/__init__.py +34 -0
- torchrl/data/llm/dataset.py +491 -0
- torchrl/data/llm/history.py +1378 -0
- torchrl/data/llm/prompt.py +198 -0
- torchrl/data/llm/reward.py +225 -0
- torchrl/data/llm/topk.py +186 -0
- torchrl/data/llm/utils.py +543 -0
- torchrl/data/map/__init__.py +21 -0
- torchrl/data/map/hash.py +185 -0
- torchrl/data/map/query.py +204 -0
- torchrl/data/map/tdstorage.py +363 -0
- torchrl/data/map/tree.py +1434 -0
- torchrl/data/map/utils.py +103 -0
- torchrl/data/postprocs/__init__.py +8 -0
- torchrl/data/postprocs/postprocs.py +391 -0
- torchrl/data/replay_buffers/__init__.py +99 -0
- torchrl/data/replay_buffers/checkpointers.py +622 -0
- torchrl/data/replay_buffers/ray_buffer.py +292 -0
- torchrl/data/replay_buffers/replay_buffers.py +2376 -0
- torchrl/data/replay_buffers/samplers.py +2578 -0
- torchrl/data/replay_buffers/scheduler.py +265 -0
- torchrl/data/replay_buffers/storages.py +2412 -0
- torchrl/data/replay_buffers/utils.py +1042 -0
- torchrl/data/replay_buffers/writers.py +781 -0
- torchrl/data/tensor_specs.py +7101 -0
- torchrl/data/utils.py +334 -0
- torchrl/envs/__init__.py +265 -0
- torchrl/envs/async_envs.py +1105 -0
- torchrl/envs/batched_envs.py +3093 -0
- torchrl/envs/common.py +4241 -0
- torchrl/envs/custom/__init__.py +11 -0
- torchrl/envs/custom/chess.py +617 -0
- torchrl/envs/custom/llm.py +214 -0
- torchrl/envs/custom/pendulum.py +401 -0
- torchrl/envs/custom/san_moves.txt +29274 -0
- torchrl/envs/custom/tictactoeenv.py +288 -0
- torchrl/envs/env_creator.py +263 -0
- torchrl/envs/gym_like.py +752 -0
- torchrl/envs/libs/__init__.py +68 -0
- torchrl/envs/libs/_gym_utils.py +326 -0
- torchrl/envs/libs/brax.py +846 -0
- torchrl/envs/libs/dm_control.py +544 -0
- torchrl/envs/libs/envpool.py +447 -0
- torchrl/envs/libs/gym.py +2239 -0
- torchrl/envs/libs/habitat.py +138 -0
- torchrl/envs/libs/isaac_lab.py +87 -0
- torchrl/envs/libs/isaacgym.py +203 -0
- torchrl/envs/libs/jax_utils.py +166 -0
- torchrl/envs/libs/jumanji.py +963 -0
- torchrl/envs/libs/meltingpot.py +599 -0
- torchrl/envs/libs/openml.py +153 -0
- torchrl/envs/libs/openspiel.py +652 -0
- torchrl/envs/libs/pettingzoo.py +1042 -0
- torchrl/envs/libs/procgen.py +351 -0
- torchrl/envs/libs/robohive.py +429 -0
- torchrl/envs/libs/smacv2.py +645 -0
- torchrl/envs/libs/unity_mlagents.py +891 -0
- torchrl/envs/libs/utils.py +147 -0
- torchrl/envs/libs/vmas.py +813 -0
- torchrl/envs/llm/__init__.py +63 -0
- torchrl/envs/llm/chat.py +730 -0
- torchrl/envs/llm/datasets/README.md +4 -0
- torchrl/envs/llm/datasets/__init__.py +17 -0
- torchrl/envs/llm/datasets/gsm8k.py +353 -0
- torchrl/envs/llm/datasets/ifeval.py +274 -0
- torchrl/envs/llm/envs.py +789 -0
- torchrl/envs/llm/libs/README.md +3 -0
- torchrl/envs/llm/libs/__init__.py +8 -0
- torchrl/envs/llm/libs/mlgym.py +869 -0
- torchrl/envs/llm/reward/__init__.py +10 -0
- torchrl/envs/llm/reward/gsm8k.py +324 -0
- torchrl/envs/llm/reward/ifeval/README.md +13 -0
- torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
- torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
- torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
- torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
- torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
- torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
- torchrl/envs/llm/transforms/__init__.py +55 -0
- torchrl/envs/llm/transforms/browser.py +292 -0
- torchrl/envs/llm/transforms/dataloading.py +859 -0
- torchrl/envs/llm/transforms/format.py +73 -0
- torchrl/envs/llm/transforms/kl.py +1544 -0
- torchrl/envs/llm/transforms/policy_version.py +189 -0
- torchrl/envs/llm/transforms/reason.py +323 -0
- torchrl/envs/llm/transforms/tokenizer.py +321 -0
- torchrl/envs/llm/transforms/tools.py +1955 -0
- torchrl/envs/model_based/__init__.py +9 -0
- torchrl/envs/model_based/common.py +180 -0
- torchrl/envs/model_based/dreamer.py +112 -0
- torchrl/envs/transforms/__init__.py +147 -0
- torchrl/envs/transforms/functional.py +48 -0
- torchrl/envs/transforms/gym_transforms.py +203 -0
- torchrl/envs/transforms/module.py +341 -0
- torchrl/envs/transforms/r3m.py +372 -0
- torchrl/envs/transforms/ray_service.py +663 -0
- torchrl/envs/transforms/rb_transforms.py +214 -0
- torchrl/envs/transforms/transforms.py +11835 -0
- torchrl/envs/transforms/utils.py +94 -0
- torchrl/envs/transforms/vc1.py +307 -0
- torchrl/envs/transforms/vecnorm.py +845 -0
- torchrl/envs/transforms/vip.py +407 -0
- torchrl/envs/utils.py +1718 -0
- torchrl/envs/vec_envs.py +11 -0
- torchrl/modules/__init__.py +206 -0
- torchrl/modules/distributions/__init__.py +73 -0
- torchrl/modules/distributions/continuous.py +830 -0
- torchrl/modules/distributions/discrete.py +908 -0
- torchrl/modules/distributions/truncated_normal.py +187 -0
- torchrl/modules/distributions/utils.py +233 -0
- torchrl/modules/llm/__init__.py +62 -0
- torchrl/modules/llm/backends/__init__.py +65 -0
- torchrl/modules/llm/backends/vllm/__init__.py +94 -0
- torchrl/modules/llm/backends/vllm/_models.py +46 -0
- torchrl/modules/llm/backends/vllm/base.py +72 -0
- torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
- torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
- torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
- torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
- torchrl/modules/llm/policies/__init__.py +28 -0
- torchrl/modules/llm/policies/common.py +1809 -0
- torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
- torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
- torchrl/modules/llm/utils.py +23 -0
- torchrl/modules/mcts/__init__.py +21 -0
- torchrl/modules/mcts/scores.py +579 -0
- torchrl/modules/models/__init__.py +86 -0
- torchrl/modules/models/batchrenorm.py +119 -0
- torchrl/modules/models/decision_transformer.py +179 -0
- torchrl/modules/models/exploration.py +731 -0
- torchrl/modules/models/llm.py +156 -0
- torchrl/modules/models/model_based.py +596 -0
- torchrl/modules/models/models.py +1712 -0
- torchrl/modules/models/multiagent.py +1067 -0
- torchrl/modules/models/recipes/impala.py +185 -0
- torchrl/modules/models/utils.py +162 -0
- torchrl/modules/planners/__init__.py +10 -0
- torchrl/modules/planners/cem.py +228 -0
- torchrl/modules/planners/common.py +73 -0
- torchrl/modules/planners/mppi.py +265 -0
- torchrl/modules/tensordict_module/__init__.py +89 -0
- torchrl/modules/tensordict_module/actors.py +2457 -0
- torchrl/modules/tensordict_module/common.py +529 -0
- torchrl/modules/tensordict_module/exploration.py +814 -0
- torchrl/modules/tensordict_module/probabilistic.py +321 -0
- torchrl/modules/tensordict_module/rnn.py +1639 -0
- torchrl/modules/tensordict_module/sequence.py +132 -0
- torchrl/modules/tensordict_module/world_models.py +34 -0
- torchrl/modules/utils/__init__.py +38 -0
- torchrl/modules/utils/mappings.py +9 -0
- torchrl/modules/utils/utils.py +89 -0
- torchrl/objectives/__init__.py +78 -0
- torchrl/objectives/a2c.py +659 -0
- torchrl/objectives/common.py +753 -0
- torchrl/objectives/cql.py +1346 -0
- torchrl/objectives/crossq.py +710 -0
- torchrl/objectives/ddpg.py +453 -0
- torchrl/objectives/decision_transformer.py +371 -0
- torchrl/objectives/deprecated.py +516 -0
- torchrl/objectives/dqn.py +683 -0
- torchrl/objectives/dreamer.py +488 -0
- torchrl/objectives/functional.py +48 -0
- torchrl/objectives/gail.py +258 -0
- torchrl/objectives/iql.py +996 -0
- torchrl/objectives/llm/__init__.py +30 -0
- torchrl/objectives/llm/grpo.py +846 -0
- torchrl/objectives/llm/sft.py +482 -0
- torchrl/objectives/multiagent/__init__.py +8 -0
- torchrl/objectives/multiagent/qmixer.py +396 -0
- torchrl/objectives/ppo.py +1669 -0
- torchrl/objectives/redq.py +683 -0
- torchrl/objectives/reinforce.py +530 -0
- torchrl/objectives/sac.py +1580 -0
- torchrl/objectives/td3.py +570 -0
- torchrl/objectives/td3_bc.py +625 -0
- torchrl/objectives/utils.py +782 -0
- torchrl/objectives/value/__init__.py +28 -0
- torchrl/objectives/value/advantages.py +1956 -0
- torchrl/objectives/value/functional.py +1459 -0
- torchrl/objectives/value/utils.py +360 -0
- torchrl/record/__init__.py +17 -0
- torchrl/record/loggers/__init__.py +23 -0
- torchrl/record/loggers/common.py +48 -0
- torchrl/record/loggers/csv.py +226 -0
- torchrl/record/loggers/mlflow.py +142 -0
- torchrl/record/loggers/tensorboard.py +139 -0
- torchrl/record/loggers/trackio.py +163 -0
- torchrl/record/loggers/utils.py +78 -0
- torchrl/record/loggers/wandb.py +214 -0
- torchrl/record/recorder.py +554 -0
- torchrl/services/__init__.py +79 -0
- torchrl/services/base.py +109 -0
- torchrl/services/ray_service.py +453 -0
- torchrl/testing/__init__.py +107 -0
- torchrl/testing/assertions.py +179 -0
- torchrl/testing/dist_utils.py +122 -0
- torchrl/testing/env_creators.py +227 -0
- torchrl/testing/env_helper.py +35 -0
- torchrl/testing/gym_helpers.py +156 -0
- torchrl/testing/llm_mocks.py +119 -0
- torchrl/testing/mocking_classes.py +2720 -0
- torchrl/testing/modules.py +295 -0
- torchrl/testing/mp_helpers.py +15 -0
- torchrl/testing/ray_helpers.py +293 -0
- torchrl/testing/utils.py +190 -0
- torchrl/trainers/__init__.py +42 -0
- torchrl/trainers/algorithms/__init__.py +11 -0
- torchrl/trainers/algorithms/configs/__init__.py +705 -0
- torchrl/trainers/algorithms/configs/collectors.py +216 -0
- torchrl/trainers/algorithms/configs/common.py +41 -0
- torchrl/trainers/algorithms/configs/data.py +308 -0
- torchrl/trainers/algorithms/configs/envs.py +104 -0
- torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
- torchrl/trainers/algorithms/configs/logging.py +80 -0
- torchrl/trainers/algorithms/configs/modules.py +570 -0
- torchrl/trainers/algorithms/configs/objectives.py +177 -0
- torchrl/trainers/algorithms/configs/trainers.py +340 -0
- torchrl/trainers/algorithms/configs/transforms.py +955 -0
- torchrl/trainers/algorithms/configs/utils.py +252 -0
- torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
- torchrl/trainers/algorithms/configs/weight_update.py +159 -0
- torchrl/trainers/algorithms/ppo.py +373 -0
- torchrl/trainers/algorithms/sac.py +308 -0
- torchrl/trainers/helpers/__init__.py +40 -0
- torchrl/trainers/helpers/collectors.py +416 -0
- torchrl/trainers/helpers/envs.py +573 -0
- torchrl/trainers/helpers/logger.py +33 -0
- torchrl/trainers/helpers/losses.py +132 -0
- torchrl/trainers/helpers/models.py +658 -0
- torchrl/trainers/helpers/replay_buffer.py +59 -0
- torchrl/trainers/helpers/trainers.py +301 -0
- torchrl/trainers/trainers.py +2052 -0
- torchrl/weight_update/__init__.py +33 -0
- torchrl/weight_update/_distributed.py +749 -0
- torchrl/weight_update/_mp.py +624 -0
- torchrl/weight_update/_noupdate.py +102 -0
- torchrl/weight_update/_ray.py +1032 -0
- torchrl/weight_update/_rpc.py +284 -0
- torchrl/weight_update/_shared.py +891 -0
- torchrl/weight_update/llm/__init__.py +32 -0
- torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
- torchrl/weight_update/llm/vllm_nccl.py +710 -0
- torchrl/weight_update/utils.py +73 -0
- torchrl/weight_update/weight_sync_schemes.py +1244 -0
- torchrl-0.11.0.dist-info/METADATA +1308 -0
- torchrl-0.11.0.dist-info/RECORD +395 -0
- torchrl-0.11.0.dist-info/WHEEL +5 -0
- torchrl-0.11.0.dist-info/entry_points.txt +2 -0
- torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
- torchrl-0.11.0.dist-info/top_level.txt +7 -0
|
@@ -0,0 +1,363 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
#
|
|
3
|
+
# This source code is licensed under the MIT license found in the
|
|
4
|
+
# LICENSE file in the root directory of this source tree.
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
import abc
|
|
8
|
+
import functools
|
|
9
|
+
from abc import abstractmethod
|
|
10
|
+
from collections.abc import Callable
|
|
11
|
+
from typing import Any, Generic, TypeVar
|
|
12
|
+
|
|
13
|
+
import torch
|
|
14
|
+
from tensordict import is_tensor_collection, NestedKey, TensorDictBase
|
|
15
|
+
from tensordict.nn.common import TensorDictModuleBase
|
|
16
|
+
|
|
17
|
+
from torchrl.data.map.hash import RandomProjectionHash, SipHash
|
|
18
|
+
from torchrl.data.map.query import QueryModule
|
|
19
|
+
from torchrl.data.replay_buffers.storages import (
|
|
20
|
+
_get_default_collate,
|
|
21
|
+
LazyTensorStorage,
|
|
22
|
+
TensorStorage,
|
|
23
|
+
)
|
|
24
|
+
|
|
25
|
+
# Type parameters of :class:`TensorMap`: ``K`` is the key type, ``V`` the value type.
K, V = TypeVar("K"), TypeVar("V")
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class TensorMap(abc.ABC, Generic[K, V]):
    """Abstract map-like storage interface, keyed by ``K`` and holding ``V`` values.

    Subclasses must provide :meth:`clear`, :meth:`__getitem__`,
    :meth:`__setitem__`, :meth:`__len__` and :meth:`contains`. The ``in``
    operator is routed to :meth:`contains`.

    This class is for internal use, please use derived classes instead.
    """

    def __contains__(self, item):
        # ``item in mapping`` simply defers to the abstract ``contains``.
        return self.contains(item)

    @abstractmethod
    def __getitem__(self, item: K) -> V:
        """Returns the value(s) stored for ``item``."""
        raise NotImplementedError

    @abstractmethod
    def __setitem__(self, key: K, value: V) -> None:
        """Stores ``value`` under ``key``."""
        raise NotImplementedError

    @abstractmethod
    def contains(self, item: K) -> torch.Tensor:
        """Returns a tensor telling whether ``item`` is present in the map."""
        raise NotImplementedError

    @abstractmethod
    def __len__(self) -> int:
        """Returns the number of stored elements."""
        raise NotImplementedError

    @abstractmethod
    def clear(self) -> None:
        """Removes every stored element."""
        raise NotImplementedError
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
class TensorDictMap(TensorDictModuleBase, TensorMap[TensorDictBase, TensorDictBase]):
    """A Map-Storage for TensorDict.

    This module resembles a storage. It takes a tensordict as its input and
    returns another tensordict as output, similar to a TensorDictModuleBase.
    In addition, it offers a python-mapping-like interface: the entries of an
    input tensordict at ``in_keys`` are turned into an index by the query
    module, and that index is used to read from / write to the storage
    (``map[key]``, ``map[key] = value``, ``map.contains(key)``, ``len(map)``).

    Keyword Args:
        query_module (QueryModule): a query module, typically an instance of
            :class:`~torchrl.data.map.QueryModule`, used to map a set of tensordict
            entries to a hash key. Its ``in_keys`` become the ``in_keys`` of this
            module, and the index it computes is read at its ``index_key``.
        storage (TensorStorage): the storage that holds the values, indexed by the
            index computed by ``query_module`` (e.g. a
            :class:`~torchrl.data.LazyTensorStorage`).
        collate_fn (callable, optional): a function to use to collate samples from the
            storage. Defaults to a custom value for each known storage type (stack for
            :class:`~torchrl.data.ListStorage`, identity for :class:`~torchrl.data.TensorStorage`
            subtypes and others).
        out_keys (list of NestedKey, optional): if provided, values are restricted to
            these keys before being written (missing keys are ignored). If not
            provided, the out-keys are inferred lazily from the content of the
            storage, and values are written as they are.
        write_fn (callable, optional): a function that transforms values before
            they are written. It is called as ``write_fn(new_value)`` for entries
            that are not in the map yet, and as ``write_fn(new_value, previous_value)``
            for entries that are already present.

    Examples:
        >>> import torch
        >>> from tensordict import TensorDict
        >>> from torchrl.data import LazyTensorStorage
        >>> query_module = QueryModule(
        ...     in_keys=["key1", "key2"],
        ...     index_key="index",
        ... )
        >>> embedding_storage = LazyTensorStorage(1000)
        >>> tensor_dict_storage = TensorDictMap(
        ...     query_module=query_module,
        ...     storage=embedding_storage,
        ... )
        >>> index = TensorDict(
        ...     {
        ...         "key1": torch.Tensor([[-1], [1], [3], [-3]]),
        ...         "key2": torch.Tensor([[0], [2], [4], [-4]]),
        ...     },
        ...     batch_size=(4,),
        ... )
        >>> value = TensorDict(
        ...     {"out": torch.Tensor([[10], [20], [30], [40]])}, batch_size=(4,)
        ... )
        >>> tensor_dict_storage[index] = value
        >>> tensor_dict_storage[index]
        TensorDict(
            fields={
                out: Tensor(shape=torch.Size([4, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([4]),
            device=None,
            is_shared=False)
        >>> assert torch.sum(tensor_dict_storage.contains(index)).item() == 4
        >>> new_index = index.clone(True)
        >>> new_index["key3"] = torch.Tensor([[4], [5], [6], [7]])
        >>> # "key3" is not an in-key, so it does not change the lookup
        >>> retrieve_value = tensor_dict_storage[new_index]
        >>> assert (retrieve_value["out"] == value["out"]).all()
    """

    def __init__(
        self,
        *,
        query_module: QueryModule,
        storage: TensorStorage,
        collate_fn: Callable[[Any], Any] | None = None,
        out_keys: list[NestedKey] | None = None,
        write_fn: Callable[[Any, Any], Any] | None = None,
    ) -> None:
        super().__init__()

        self.in_keys = query_module.in_keys
        if out_keys is not None:
            # User-provided out-keys (non-lazy): values are filtered on them on write.
            self.out_keys = out_keys

        self.query_module = query_module
        self.index_key = query_module.index_key
        self.storage = storage
        # Set by _maybe_add_batch when a batch dimension had to be prepended.
        self.batch_added = False
        if collate_fn is None:
            collate_fn = _get_default_collate(self.storage)
        self.collate_fn = collate_fn
        self.write_fn = write_fn

    @property
    def max_size(self):
        """The ``max_size`` of the underlying storage."""
        return self.storage.max_size

    @property
    def out_keys(self) -> list[NestedKey]:
        """The output keys of the map.

        They are either the user-provided keys, or, lazily, the leaf keys of
        the tensordict held by a :class:`~torchrl.data.TensorStorage`.

        Raises:
            AttributeError: if no out-keys were provided and none can be
                inferred from the storage yet.
        """
        cached = self.__dict__.get("_out_keys_and_lazy")
        if cached is not None:
            return cached[0]
        storage = self.storage
        if isinstance(storage, TensorStorage) and is_tensor_collection(
            storage._storage
        ):
            out_keys = list(storage._storage.keys(True, True))
            # Cache the inferred keys and flag them as lazy (not user-provided).
            self._out_keys_and_lazy = (out_keys, True)
            return out_keys
        raise AttributeError(
            f"No out-keys found in the storage of type {type(storage)}"
        )

    @out_keys.setter
    def out_keys(self, value):
        self._out_keys_and_lazy = (value, False)

    def _has_lazy_out_keys(self):
        """Returns ``True`` unless out-keys were explicitly provided by the user."""
        cached = self.__dict__.get("_out_keys_and_lazy")
        if cached is None:
            return True
        return cached[1]

    @classmethod
    def from_tensordict_pair(
        cls,
        source,
        dest,
        in_keys: list[NestedKey],
        out_keys: list[NestedKey] | None = None,
        max_size: int = 1000,
        storage_constructor: type | None = None,
        hash_module: Callable | None = None,
        collate_fn: Callable[[Any], Any] | None = None,
        write_fn: Callable[[Any, Any], Any] | None = None,
        consolidated: bool | None = None,
    ) -> TensorDictMap:
        """Creates a new :class:`TensorDictMap` from a pair of tensordicts (source and dest) using pre-defined rules of thumb.

        Args:
            source (TensorDict): An example of source tensordict, used as index in the storage.
                Its entries at ``in_keys`` are inspected to choose the default hash modules.
            dest (TensorDict): An example of dest tensordict, i.e. of the data to be stored.
                It is currently not used by this method.
            in_keys (List[NestedKey]): a list of keys to use in the map.
            out_keys (List[NestedKey]): a list of keys to return in the output tensordict.
                All keys absent from out_keys, even if present in ``dest``, will not be stored
                in the storage. Defaults to ``None`` (all keys are registered).
            max_size (int, optional): the maximum number of elements in the storage. Ignored if the
                ``storage_constructor`` is passed. Defaults to ``1000``.
            storage_constructor (callable, optional): a callable that returns the storage when
                called without positional arguments (e.g. a storage class wrapped in
                :func:`functools.partial`). Defaults to
                ``functools.partial(LazyTensorStorage, max_size)``.
            hash_module (Callable, optional): a hash function to use in the :class:`~torchrl.data.map.QueryModule`.
                Defaults to one hash module per in-key: :class:`~torchrl.data.map.SipHash` for
                entries without a feature dimension or with at most
                ``RandomProjectionHash._N_COMPONENTS_DEFAULT`` features, and
                :class:`~torchrl.data.map.RandomProjectionHash` for larger ones.
            collate_fn (callable, optional): a function to use to collate samples from the
                storage. Defaults to a custom value for each known storage type (stack for
                :class:`~torchrl.data.ListStorage`, identity for :class:`~torchrl.data.TensorStorage`
                subtypes and others).
            write_fn (callable, optional): forwarded to the constructor (see :class:`TensorDictMap`).
            consolidated (bool, optional): whether to consolidate the storage in a single storage tensor.
                With the default storage constructor, ``None`` is treated as ``False``; with a custom
                ``storage_constructor`` the flag is only forwarded when it is not ``None``.
                Defaults to ``None``.

        Examples:
            >>> # The following example requires torchrl and gymnasium to be installed
            >>> from torchrl.envs import GymEnv
            >>> torch.manual_seed(0)
            >>> env = GymEnv("CartPole-v1")
            >>> env.set_seed(0)
            >>> rollout = env.rollout(100)
            >>> source, dest = rollout.exclude("next"), rollout.get("next")
            >>> storage = TensorDictMap.from_tensordict_pair(
            ...     source, dest,
            ...     in_keys=["observation", "action"],
            ... )
            >>> # maps the (obs, action) tuple to a corresponding next state
            >>> storage[source] = dest
            >>> print(source["_index"])
            tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13])
            >>> storage[source]
            TensorDict(
                fields={
                    done: Tensor(shape=torch.Size([14, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                    observation: Tensor(shape=torch.Size([14, 4]), device=cpu, dtype=torch.float32, is_shared=False),
                    reward: Tensor(shape=torch.Size([14, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                    terminated: Tensor(shape=torch.Size([14, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                    truncated: Tensor(shape=torch.Size([14, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
                batch_size=torch.Size([14]),
                device=None,
                is_shared=False)

        """
        # Build query module
        if hash_module is None:
            # Pick one hash module per in-key: entries with more features than
            # RandomProjectionHash._N_COMPONENTS_DEFAULT are projected down to
            # that dimensionality by a RandomProjectionHash; others use SipHash.
            hash_module = []
            for in_key in in_keys:
                entry = source[in_key]
                if entry.ndim == source.ndim:
                    # this is a good example of why td/tc are useful - carrying metadata
                    # allows us to know if there's a feature dim or not
                    n_feat = 0
                else:
                    n_feat = entry.shape[-1]
                if n_feat > RandomProjectionHash._N_COMPONENTS_DEFAULT:
                    _hash_module = RandomProjectionHash()
                else:
                    _hash_module = SipHash()
                hash_module.append(_hash_module)
        query_module = QueryModule(in_keys, hash_module=hash_module)

        # Build key_to_storage
        if storage_constructor is None:
            storage_constructor = functools.partial(
                LazyTensorStorage, max_size, consolidated=bool(consolidated)
            )
        elif consolidated is not None:
            storage_constructor = functools.partial(
                storage_constructor, consolidated=consolidated
            )
        storage = storage_constructor()
        result = cls(
            query_module=query_module,
            storage=storage,
            collate_fn=collate_fn,
            out_keys=out_keys,
            write_fn=write_fn,
        )
        return result

    def clear(self) -> None:
        """Removes every entry from the underlying storage.

        ``self.storage`` is normally a single storage (this is what
        :meth:`from_tensordict_pair` builds and what the other methods index
        into). A plain ``dict`` of storages is still accepted, in which case
        each of them is cleared.
        """
        storage = self.storage
        if isinstance(storage, dict):
            for mem in storage.values():
                mem.clear()
        else:
            # A single storage has no ``values()``. ``_empty`` is the storage-level
            # reset hook. NOTE(review): assumes the storage implements torchrl's
            # ``Storage._empty`` — confirm for custom storages.
            storage._empty()

    def _to_index(
        self, item: TensorDictBase, extend: bool, clone: bool | None = None
    ) -> torch.Tensor:
        """Runs the query module on ``item`` and returns the entry it produces at ``index_key``.

        ``extend`` and ``clone`` are forwarded to the query module (writes use
        ``extend=True``, reads and membership tests use ``extend=False``).
        """
        item = self.query_module(item, extend=extend, clone=clone)
        return item[self.index_key]

    def _maybe_add_batch(
        self, item: TensorDictBase, value: TensorDictBase | None
    ) -> tuple[TensorDictBase, TensorDictBase | None]:
        """Prepends a batch dim to ``item`` (and ``value``, if given) when ``item`` is 0-dim.

        Records in ``self.batch_added`` whether a dimension was added, so that
        :meth:`_maybe_remove_batch` can undo it on the result.
        """
        self.batch_added = False
        if len(item.batch_size) == 0:
            self.batch_added = True

            item = item.unsqueeze(dim=0)
            if value is not None:
                value = value.unsqueeze(dim=0)

        return item, value

    def _maybe_remove_batch(self, item: TensorDictBase) -> TensorDictBase:
        """Squeezes the leading dim added by the last :meth:`_maybe_add_batch` call, if any."""
        if self.batch_added:
            item = item.squeeze(dim=0)
        return item

    def __getitem__(self, item: TensorDictBase) -> TensorDictBase:
        """Returns the stored values matching the keys in ``item``.

        The index is computed on a shallow copy of ``item``, so the input
        tensordict is not modified.
        """
        item = item.copy()
        item, _ = self._maybe_add_batch(item, None)

        index = self._to_index(item, extend=False, clone=False)

        res = self.storage[index]
        res = self.collate_fn(res)
        res = self._maybe_remove_batch(res)
        return res

    def __setitem__(self, item: TensorDictBase, value: TensorDictBase):
        """Writes ``value`` at the index computed from ``item``.

        Unlike :meth:`__getitem__`, the index is computed with ``clone=None``
        and ``extend=True`` on ``item`` itself. If several elements of ``item``
        resolve to the same index, they are written one by one, in order.
        """
        if not self._has_lazy_out_keys():
            # TODO: make this work with pytrees and avoid calling select if keys match
            value = value.select(*self.out_keys, strict=False)
        item, value = self._maybe_add_batch(item, value)
        index = self._to_index(item, extend=True)
        if index.unique().numel() < index.numel():
            # If multiple values point to the same place in the storage, we cannot process them by batch
            # There could be a better way to deal with this, using unique ids.
            vals = []
            for it, val in zip(item.split(1), value.split(1)):
                self[it] = val
                vals.append(val)
            # __setitem__ may affect the content of the input data.
            # The split(1) chunks already carry a leading dim of size 1, so they
            # are concatenated (not stacked) to keep every leaf's original shape.
            value.update(torch.cat(vals, 0))
            return
        if self.write_fn is not None:
            # We use this block in the following context: the value written in the storage is already present,
            # but it needs to be updated.
            # We first check if the value is already there using `contains`. If so, we pass the new value and the
            # previous one to write_fn. The values that are not present are passed alone.
            if len(self):
                modifiable = self.contains(item)
                if modifiable.any():
                    to_modify = (value[modifiable], self[item[modifiable]])
                    v1 = self.write_fn(*to_modify)
                    result = value.empty()
                    result[modifiable] = v1
                    result[~modifiable] = self.write_fn(value[~modifiable])
                    value = result
                else:
                    value = self.write_fn(value)
            else:
                value = self.write_fn(value)
        self.storage.set(index, value)

    def __len__(self):
        """Returns the length of the underlying storage."""
        return len(self.storage)

    def contains(self, item: TensorDictBase) -> torch.Tensor:
        """Returns ``storage.contains`` for the index of ``item``.

        The result is a mask-like tensor (it is used for boolean indexing in
        :meth:`__setitem__`); its leading dim is squeezed for 0-dim inputs.
        """
        item, _ = self._maybe_add_batch(item, None)
        index = self._to_index(item, extend=False, clone=True)

        res = self.storage.contains(index)
        res = self._maybe_remove_batch(res)
        return res
|