torchrl 0.11.0__cp314-cp314-manylinux_2_28_aarch64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/benchmark_batched_envs.py +104 -0
- benchmarks/conftest.py +91 -0
- benchmarks/ecosystem/gym_env_throughput.py +321 -0
- benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
- benchmarks/requirements.txt +7 -0
- benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
- benchmarks/test_collectors_benchmark.py +240 -0
- benchmarks/test_compressed_storage_benchmark.py +145 -0
- benchmarks/test_envs_benchmark.py +133 -0
- benchmarks/test_llm.py +101 -0
- benchmarks/test_non_tensor_env_benchmark.py +70 -0
- benchmarks/test_objectives_benchmarks.py +1199 -0
- benchmarks/test_replaybuffer_benchmark.py +254 -0
- sota-check/README.md +35 -0
- sota-implementations/README.md +142 -0
- sota-implementations/a2c/README.md +39 -0
- sota-implementations/a2c/a2c_atari.py +291 -0
- sota-implementations/a2c/a2c_mujoco.py +273 -0
- sota-implementations/a2c/utils_atari.py +240 -0
- sota-implementations/a2c/utils_mujoco.py +160 -0
- sota-implementations/bandits/README.md +7 -0
- sota-implementations/bandits/dqn.py +126 -0
- sota-implementations/cql/cql_offline.py +198 -0
- sota-implementations/cql/cql_online.py +249 -0
- sota-implementations/cql/discrete_cql_offline.py +180 -0
- sota-implementations/cql/discrete_cql_online.py +227 -0
- sota-implementations/cql/utils.py +471 -0
- sota-implementations/crossq/crossq.py +271 -0
- sota-implementations/crossq/utils.py +320 -0
- sota-implementations/ddpg/ddpg.py +231 -0
- sota-implementations/ddpg/utils.py +325 -0
- sota-implementations/decision_transformer/dt.py +163 -0
- sota-implementations/decision_transformer/lamb.py +167 -0
- sota-implementations/decision_transformer/online_dt.py +178 -0
- sota-implementations/decision_transformer/utils.py +562 -0
- sota-implementations/discrete_sac/discrete_sac.py +243 -0
- sota-implementations/discrete_sac/utils.py +324 -0
- sota-implementations/dqn/README.md +30 -0
- sota-implementations/dqn/dqn_atari.py +272 -0
- sota-implementations/dqn/dqn_cartpole.py +236 -0
- sota-implementations/dqn/utils_atari.py +132 -0
- sota-implementations/dqn/utils_cartpole.py +90 -0
- sota-implementations/dreamer/README.md +129 -0
- sota-implementations/dreamer/dreamer.py +586 -0
- sota-implementations/dreamer/dreamer_utils.py +1107 -0
- sota-implementations/expert-iteration/README.md +352 -0
- sota-implementations/expert-iteration/ei_utils.py +770 -0
- sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
- sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
- sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
- sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
- sota-implementations/gail/gail.py +327 -0
- sota-implementations/gail/gail_utils.py +68 -0
- sota-implementations/gail/ppo_utils.py +157 -0
- sota-implementations/grpo/README.md +273 -0
- sota-implementations/grpo/grpo-async.py +437 -0
- sota-implementations/grpo/grpo-sync.py +435 -0
- sota-implementations/grpo/grpo_utils.py +843 -0
- sota-implementations/grpo/requirements_gsm8k.txt +11 -0
- sota-implementations/grpo/requirements_ifeval.txt +16 -0
- sota-implementations/impala/README.md +33 -0
- sota-implementations/impala/impala_multi_node_ray.py +292 -0
- sota-implementations/impala/impala_multi_node_submitit.py +284 -0
- sota-implementations/impala/impala_single_node.py +261 -0
- sota-implementations/impala/utils.py +184 -0
- sota-implementations/iql/discrete_iql.py +230 -0
- sota-implementations/iql/iql_offline.py +164 -0
- sota-implementations/iql/iql_online.py +225 -0
- sota-implementations/iql/utils.py +437 -0
- sota-implementations/multiagent/README.md +74 -0
- sota-implementations/multiagent/iql.py +237 -0
- sota-implementations/multiagent/maddpg_iddpg.py +266 -0
- sota-implementations/multiagent/mappo_ippo.py +267 -0
- sota-implementations/multiagent/qmix_vdn.py +271 -0
- sota-implementations/multiagent/sac.py +337 -0
- sota-implementations/multiagent/utils/__init__.py +4 -0
- sota-implementations/multiagent/utils/logging.py +151 -0
- sota-implementations/multiagent/utils/utils.py +43 -0
- sota-implementations/ppo/README.md +29 -0
- sota-implementations/ppo/ppo_atari.py +305 -0
- sota-implementations/ppo/ppo_mujoco.py +293 -0
- sota-implementations/ppo/utils_atari.py +238 -0
- sota-implementations/ppo/utils_mujoco.py +152 -0
- sota-implementations/ppo_trainer/train.py +21 -0
- sota-implementations/redq/README.md +7 -0
- sota-implementations/redq/redq.py +199 -0
- sota-implementations/redq/utils.py +1060 -0
- sota-implementations/sac/sac-async.py +266 -0
- sota-implementations/sac/sac.py +239 -0
- sota-implementations/sac/utils.py +381 -0
- sota-implementations/sac_trainer/train.py +16 -0
- sota-implementations/td3/td3.py +254 -0
- sota-implementations/td3/utils.py +319 -0
- sota-implementations/td3_bc/td3_bc.py +177 -0
- sota-implementations/td3_bc/utils.py +251 -0
- torchrl/__init__.py +144 -0
- torchrl/_extension.py +74 -0
- torchrl/_torchrl.cpython-314-aarch64-linux-gnu.so +0 -0
- torchrl/_utils.py +1431 -0
- torchrl/collectors/__init__.py +48 -0
- torchrl/collectors/_base.py +1058 -0
- torchrl/collectors/_constants.py +88 -0
- torchrl/collectors/_multi_async.py +324 -0
- torchrl/collectors/_multi_base.py +1805 -0
- torchrl/collectors/_multi_sync.py +464 -0
- torchrl/collectors/_runner.py +581 -0
- torchrl/collectors/_single.py +2009 -0
- torchrl/collectors/_single_async.py +259 -0
- torchrl/collectors/collectors.py +62 -0
- torchrl/collectors/distributed/__init__.py +32 -0
- torchrl/collectors/distributed/default_configs.py +133 -0
- torchrl/collectors/distributed/generic.py +1306 -0
- torchrl/collectors/distributed/ray.py +1092 -0
- torchrl/collectors/distributed/rpc.py +1006 -0
- torchrl/collectors/distributed/sync.py +731 -0
- torchrl/collectors/distributed/utils.py +160 -0
- torchrl/collectors/llm/__init__.py +10 -0
- torchrl/collectors/llm/base.py +494 -0
- torchrl/collectors/llm/ray_collector.py +275 -0
- torchrl/collectors/llm/utils.py +36 -0
- torchrl/collectors/llm/weight_update/__init__.py +10 -0
- torchrl/collectors/llm/weight_update/vllm.py +348 -0
- torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
- torchrl/collectors/utils.py +433 -0
- torchrl/collectors/weight_update.py +591 -0
- torchrl/csrc/numpy_utils.h +38 -0
- torchrl/csrc/pybind.cpp +27 -0
- torchrl/csrc/segment_tree.h +458 -0
- torchrl/csrc/torch_utils.h +34 -0
- torchrl/csrc/utils.cpp +48 -0
- torchrl/csrc/utils.h +31 -0
- torchrl/data/__init__.py +187 -0
- torchrl/data/datasets/__init__.py +58 -0
- torchrl/data/datasets/atari_dqn.py +878 -0
- torchrl/data/datasets/common.py +281 -0
- torchrl/data/datasets/d4rl.py +489 -0
- torchrl/data/datasets/d4rl_infos.py +187 -0
- torchrl/data/datasets/gen_dgrl.py +375 -0
- torchrl/data/datasets/minari_data.py +643 -0
- torchrl/data/datasets/openml.py +177 -0
- torchrl/data/datasets/openx.py +798 -0
- torchrl/data/datasets/roboset.py +363 -0
- torchrl/data/datasets/utils.py +11 -0
- torchrl/data/datasets/vd4rl.py +432 -0
- torchrl/data/llm/__init__.py +34 -0
- torchrl/data/llm/dataset.py +491 -0
- torchrl/data/llm/history.py +1378 -0
- torchrl/data/llm/prompt.py +198 -0
- torchrl/data/llm/reward.py +225 -0
- torchrl/data/llm/topk.py +186 -0
- torchrl/data/llm/utils.py +543 -0
- torchrl/data/map/__init__.py +21 -0
- torchrl/data/map/hash.py +185 -0
- torchrl/data/map/query.py +204 -0
- torchrl/data/map/tdstorage.py +363 -0
- torchrl/data/map/tree.py +1434 -0
- torchrl/data/map/utils.py +103 -0
- torchrl/data/postprocs/__init__.py +8 -0
- torchrl/data/postprocs/postprocs.py +391 -0
- torchrl/data/replay_buffers/__init__.py +99 -0
- torchrl/data/replay_buffers/checkpointers.py +622 -0
- torchrl/data/replay_buffers/ray_buffer.py +292 -0
- torchrl/data/replay_buffers/replay_buffers.py +2376 -0
- torchrl/data/replay_buffers/samplers.py +2578 -0
- torchrl/data/replay_buffers/scheduler.py +265 -0
- torchrl/data/replay_buffers/storages.py +2412 -0
- torchrl/data/replay_buffers/utils.py +1042 -0
- torchrl/data/replay_buffers/writers.py +781 -0
- torchrl/data/tensor_specs.py +7101 -0
- torchrl/data/utils.py +334 -0
- torchrl/envs/__init__.py +265 -0
- torchrl/envs/async_envs.py +1105 -0
- torchrl/envs/batched_envs.py +3093 -0
- torchrl/envs/common.py +4241 -0
- torchrl/envs/custom/__init__.py +11 -0
- torchrl/envs/custom/chess.py +617 -0
- torchrl/envs/custom/llm.py +214 -0
- torchrl/envs/custom/pendulum.py +401 -0
- torchrl/envs/custom/san_moves.txt +29274 -0
- torchrl/envs/custom/tictactoeenv.py +288 -0
- torchrl/envs/env_creator.py +263 -0
- torchrl/envs/gym_like.py +752 -0
- torchrl/envs/libs/__init__.py +68 -0
- torchrl/envs/libs/_gym_utils.py +326 -0
- torchrl/envs/libs/brax.py +846 -0
- torchrl/envs/libs/dm_control.py +544 -0
- torchrl/envs/libs/envpool.py +447 -0
- torchrl/envs/libs/gym.py +2239 -0
- torchrl/envs/libs/habitat.py +138 -0
- torchrl/envs/libs/isaac_lab.py +87 -0
- torchrl/envs/libs/isaacgym.py +203 -0
- torchrl/envs/libs/jax_utils.py +166 -0
- torchrl/envs/libs/jumanji.py +963 -0
- torchrl/envs/libs/meltingpot.py +599 -0
- torchrl/envs/libs/openml.py +153 -0
- torchrl/envs/libs/openspiel.py +652 -0
- torchrl/envs/libs/pettingzoo.py +1042 -0
- torchrl/envs/libs/procgen.py +351 -0
- torchrl/envs/libs/robohive.py +429 -0
- torchrl/envs/libs/smacv2.py +645 -0
- torchrl/envs/libs/unity_mlagents.py +891 -0
- torchrl/envs/libs/utils.py +147 -0
- torchrl/envs/libs/vmas.py +813 -0
- torchrl/envs/llm/__init__.py +63 -0
- torchrl/envs/llm/chat.py +730 -0
- torchrl/envs/llm/datasets/README.md +4 -0
- torchrl/envs/llm/datasets/__init__.py +17 -0
- torchrl/envs/llm/datasets/gsm8k.py +353 -0
- torchrl/envs/llm/datasets/ifeval.py +274 -0
- torchrl/envs/llm/envs.py +789 -0
- torchrl/envs/llm/libs/README.md +3 -0
- torchrl/envs/llm/libs/__init__.py +8 -0
- torchrl/envs/llm/libs/mlgym.py +869 -0
- torchrl/envs/llm/reward/__init__.py +10 -0
- torchrl/envs/llm/reward/gsm8k.py +324 -0
- torchrl/envs/llm/reward/ifeval/README.md +13 -0
- torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
- torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
- torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
- torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
- torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
- torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
- torchrl/envs/llm/transforms/__init__.py +55 -0
- torchrl/envs/llm/transforms/browser.py +292 -0
- torchrl/envs/llm/transforms/dataloading.py +859 -0
- torchrl/envs/llm/transforms/format.py +73 -0
- torchrl/envs/llm/transforms/kl.py +1544 -0
- torchrl/envs/llm/transforms/policy_version.py +189 -0
- torchrl/envs/llm/transforms/reason.py +323 -0
- torchrl/envs/llm/transforms/tokenizer.py +321 -0
- torchrl/envs/llm/transforms/tools.py +1955 -0
- torchrl/envs/model_based/__init__.py +9 -0
- torchrl/envs/model_based/common.py +180 -0
- torchrl/envs/model_based/dreamer.py +112 -0
- torchrl/envs/transforms/__init__.py +147 -0
- torchrl/envs/transforms/functional.py +48 -0
- torchrl/envs/transforms/gym_transforms.py +203 -0
- torchrl/envs/transforms/module.py +341 -0
- torchrl/envs/transforms/r3m.py +372 -0
- torchrl/envs/transforms/ray_service.py +663 -0
- torchrl/envs/transforms/rb_transforms.py +214 -0
- torchrl/envs/transforms/transforms.py +11835 -0
- torchrl/envs/transforms/utils.py +94 -0
- torchrl/envs/transforms/vc1.py +307 -0
- torchrl/envs/transforms/vecnorm.py +845 -0
- torchrl/envs/transforms/vip.py +407 -0
- torchrl/envs/utils.py +1718 -0
- torchrl/envs/vec_envs.py +11 -0
- torchrl/modules/__init__.py +206 -0
- torchrl/modules/distributions/__init__.py +73 -0
- torchrl/modules/distributions/continuous.py +830 -0
- torchrl/modules/distributions/discrete.py +908 -0
- torchrl/modules/distributions/truncated_normal.py +187 -0
- torchrl/modules/distributions/utils.py +233 -0
- torchrl/modules/llm/__init__.py +62 -0
- torchrl/modules/llm/backends/__init__.py +65 -0
- torchrl/modules/llm/backends/vllm/__init__.py +94 -0
- torchrl/modules/llm/backends/vllm/_models.py +46 -0
- torchrl/modules/llm/backends/vllm/base.py +72 -0
- torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
- torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
- torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
- torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
- torchrl/modules/llm/policies/__init__.py +28 -0
- torchrl/modules/llm/policies/common.py +1809 -0
- torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
- torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
- torchrl/modules/llm/utils.py +23 -0
- torchrl/modules/mcts/__init__.py +21 -0
- torchrl/modules/mcts/scores.py +579 -0
- torchrl/modules/models/__init__.py +86 -0
- torchrl/modules/models/batchrenorm.py +119 -0
- torchrl/modules/models/decision_transformer.py +179 -0
- torchrl/modules/models/exploration.py +731 -0
- torchrl/modules/models/llm.py +156 -0
- torchrl/modules/models/model_based.py +596 -0
- torchrl/modules/models/models.py +1712 -0
- torchrl/modules/models/multiagent.py +1067 -0
- torchrl/modules/models/recipes/impala.py +185 -0
- torchrl/modules/models/utils.py +162 -0
- torchrl/modules/planners/__init__.py +10 -0
- torchrl/modules/planners/cem.py +228 -0
- torchrl/modules/planners/common.py +73 -0
- torchrl/modules/planners/mppi.py +265 -0
- torchrl/modules/tensordict_module/__init__.py +89 -0
- torchrl/modules/tensordict_module/actors.py +2457 -0
- torchrl/modules/tensordict_module/common.py +529 -0
- torchrl/modules/tensordict_module/exploration.py +814 -0
- torchrl/modules/tensordict_module/probabilistic.py +321 -0
- torchrl/modules/tensordict_module/rnn.py +1639 -0
- torchrl/modules/tensordict_module/sequence.py +132 -0
- torchrl/modules/tensordict_module/world_models.py +34 -0
- torchrl/modules/utils/__init__.py +38 -0
- torchrl/modules/utils/mappings.py +9 -0
- torchrl/modules/utils/utils.py +89 -0
- torchrl/objectives/__init__.py +78 -0
- torchrl/objectives/a2c.py +659 -0
- torchrl/objectives/common.py +753 -0
- torchrl/objectives/cql.py +1346 -0
- torchrl/objectives/crossq.py +710 -0
- torchrl/objectives/ddpg.py +453 -0
- torchrl/objectives/decision_transformer.py +371 -0
- torchrl/objectives/deprecated.py +516 -0
- torchrl/objectives/dqn.py +683 -0
- torchrl/objectives/dreamer.py +488 -0
- torchrl/objectives/functional.py +48 -0
- torchrl/objectives/gail.py +258 -0
- torchrl/objectives/iql.py +996 -0
- torchrl/objectives/llm/__init__.py +30 -0
- torchrl/objectives/llm/grpo.py +846 -0
- torchrl/objectives/llm/sft.py +482 -0
- torchrl/objectives/multiagent/__init__.py +8 -0
- torchrl/objectives/multiagent/qmixer.py +396 -0
- torchrl/objectives/ppo.py +1669 -0
- torchrl/objectives/redq.py +683 -0
- torchrl/objectives/reinforce.py +530 -0
- torchrl/objectives/sac.py +1580 -0
- torchrl/objectives/td3.py +570 -0
- torchrl/objectives/td3_bc.py +625 -0
- torchrl/objectives/utils.py +782 -0
- torchrl/objectives/value/__init__.py +28 -0
- torchrl/objectives/value/advantages.py +1956 -0
- torchrl/objectives/value/functional.py +1459 -0
- torchrl/objectives/value/utils.py +360 -0
- torchrl/record/__init__.py +17 -0
- torchrl/record/loggers/__init__.py +23 -0
- torchrl/record/loggers/common.py +48 -0
- torchrl/record/loggers/csv.py +226 -0
- torchrl/record/loggers/mlflow.py +142 -0
- torchrl/record/loggers/tensorboard.py +139 -0
- torchrl/record/loggers/trackio.py +163 -0
- torchrl/record/loggers/utils.py +78 -0
- torchrl/record/loggers/wandb.py +214 -0
- torchrl/record/recorder.py +554 -0
- torchrl/services/__init__.py +79 -0
- torchrl/services/base.py +109 -0
- torchrl/services/ray_service.py +453 -0
- torchrl/testing/__init__.py +107 -0
- torchrl/testing/assertions.py +179 -0
- torchrl/testing/dist_utils.py +122 -0
- torchrl/testing/env_creators.py +227 -0
- torchrl/testing/env_helper.py +35 -0
- torchrl/testing/gym_helpers.py +156 -0
- torchrl/testing/llm_mocks.py +119 -0
- torchrl/testing/mocking_classes.py +2720 -0
- torchrl/testing/modules.py +295 -0
- torchrl/testing/mp_helpers.py +15 -0
- torchrl/testing/ray_helpers.py +293 -0
- torchrl/testing/utils.py +190 -0
- torchrl/trainers/__init__.py +42 -0
- torchrl/trainers/algorithms/__init__.py +11 -0
- torchrl/trainers/algorithms/configs/__init__.py +705 -0
- torchrl/trainers/algorithms/configs/collectors.py +216 -0
- torchrl/trainers/algorithms/configs/common.py +41 -0
- torchrl/trainers/algorithms/configs/data.py +308 -0
- torchrl/trainers/algorithms/configs/envs.py +104 -0
- torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
- torchrl/trainers/algorithms/configs/logging.py +80 -0
- torchrl/trainers/algorithms/configs/modules.py +570 -0
- torchrl/trainers/algorithms/configs/objectives.py +177 -0
- torchrl/trainers/algorithms/configs/trainers.py +340 -0
- torchrl/trainers/algorithms/configs/transforms.py +955 -0
- torchrl/trainers/algorithms/configs/utils.py +252 -0
- torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
- torchrl/trainers/algorithms/configs/weight_update.py +159 -0
- torchrl/trainers/algorithms/ppo.py +373 -0
- torchrl/trainers/algorithms/sac.py +308 -0
- torchrl/trainers/helpers/__init__.py +40 -0
- torchrl/trainers/helpers/collectors.py +416 -0
- torchrl/trainers/helpers/envs.py +573 -0
- torchrl/trainers/helpers/logger.py +33 -0
- torchrl/trainers/helpers/losses.py +132 -0
- torchrl/trainers/helpers/models.py +658 -0
- torchrl/trainers/helpers/replay_buffer.py +59 -0
- torchrl/trainers/helpers/trainers.py +301 -0
- torchrl/trainers/trainers.py +2052 -0
- torchrl/weight_update/__init__.py +33 -0
- torchrl/weight_update/_distributed.py +749 -0
- torchrl/weight_update/_mp.py +624 -0
- torchrl/weight_update/_noupdate.py +102 -0
- torchrl/weight_update/_ray.py +1032 -0
- torchrl/weight_update/_rpc.py +284 -0
- torchrl/weight_update/_shared.py +891 -0
- torchrl/weight_update/llm/__init__.py +32 -0
- torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
- torchrl/weight_update/llm/vllm_nccl.py +710 -0
- torchrl/weight_update/utils.py +73 -0
- torchrl/weight_update/weight_sync_schemes.py +1244 -0
- torchrl-0.11.0.dist-info/METADATA +1308 -0
- torchrl-0.11.0.dist-info/RECORD +394 -0
- torchrl-0.11.0.dist-info/WHEEL +5 -0
- torchrl-0.11.0.dist-info/entry_points.txt +2 -0
- torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
- torchrl-0.11.0.dist-info/top_level.txt +7 -0
@@ -0,0 +1,781 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

import heapq
import json
import textwrap
from abc import ABC, abstractmethod
from collections.abc import Sequence
from copy import copy
from multiprocessing.context import get_spawning_popen
from pathlib import Path
from typing import Any

import numpy as np
import torch
from tensordict import is_tensor_collection, MemoryMappedTensor, TensorDictBase
from tensordict.utils import expand_as_right, is_tensorclass
from torch import multiprocessing as mp
from torchrl._utils import _STRDTYPE2DTYPE

try:
    from torch.compiler import disable as compile_disable
except ImportError:
    from torch._dynamo import disable as compile_disable

try:
    from torch.utils._pytree import tree_leaves
except ImportError:
    from torch.utils._pytree import tree_flatten

    def tree_leaves(data):  # noqa: D103
        tree_flat, _ = tree_flatten(data)
        return tree_flat


from torchrl.data.replay_buffers.storages import Storage
from torchrl.data.replay_buffers.utils import _is_int, _reduce


class Writer(ABC):
    """A ReplayBuffer base Writer class."""

    _storage: Storage
    _rng: torch.Generator | None = None

    def __init__(self, compilable: bool = False) -> None:
        self._storage = None
        self._compilable = compilable

    def register_storage(self, storage: Storage) -> None:
        self._storage = storage

    @abstractmethod
    def add(self, data: Any) -> int:
        """Inserts one piece of data at an appropriate index, and returns that index."""
        ...

    @abstractmethod
    def extend(self, data: Sequence) -> torch.Tensor:
        """Inserts a series of data points at appropriate indices, and returns a tensor containing the indices."""
        ...

    @abstractmethod
    def _empty(self, empty_write_count: bool = True) -> None:
        ...

    @abstractmethod
    def dumps(self, path):
        ...

    @abstractmethod
    def loads(self, path):
        ...

    @abstractmethod
    def state_dict(self) -> dict[str, Any]:
        ...

    @abstractmethod
    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        ...

    def _replicate_index(self, index):
        # replicates the index in a non-zero format to have as many indices as
        # elements truly written when the storage is multidim
        if self._storage.ndim == 1:
            return index
        device = (
            index.device if isinstance(index, torch.Tensor) else torch.device("cpu")
        )
        mesh = torch.stack(
            torch.meshgrid(
                *(torch.arange(dim, device=device) for dim in self._storage.shape[1:])
            ),
            -1,
        ).flatten(0, -2)
        if _is_int(index):
            index0 = torch.as_tensor(int(index)).expand(mesh.shape[0], 1)
            return torch.cat([index0, mesh], 1)
        return torch.cat(
            [
                index.repeat_interleave(mesh.shape[0]).unsqueeze(1),
                mesh.repeat(index.numel(), 1),
            ],
            1,
        )

    def __repr__(self):
        return f"{self.__class__.__name__}()"

    def __getstate__(self):
        state = copy(self.__dict__)
        state["_rng"] = None
        return state


class ImmutableDatasetWriter(Writer):
    """A blocking writer for immutable datasets."""

    WRITING_ERR = "This dataset doesn't allow writing."

    def add(self, data: Any) -> int:
        raise RuntimeError(self.WRITING_ERR)

    def extend(self, data: Sequence) -> torch.Tensor:
        raise RuntimeError(self.WRITING_ERR)

    def _empty(self, empty_write_count: bool = True) -> None:
        raise RuntimeError(self.WRITING_ERR)

    def dumps(self, path):
        ...

    def loads(self, path):
        ...

    def state_dict(self) -> dict[str, Any]:
        return {}

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        return


class RoundRobinWriter(Writer):
    """A RoundRobin Writer class for composable replay buffers.

    Args:
        compilable (bool, optional): whether the writer is compilable.
            If ``True``, the writer cannot be shared between multiple processes.
            Defaults to ``False``.

    """

    def __init__(self, compilable: bool = False) -> None:
        super().__init__(compilable=compilable)
        self._cursor = 0
        self._write_count  # noqa

    def dumps(self, path):
        path = Path(path).absolute()
        path.mkdir(exist_ok=True)
        with open(path / "metadata.json", "w") as file:
            json.dump({"cursor": self._cursor}, file)

    def loads(self, path):
        path = Path(path).absolute()
        with open(path / "metadata.json") as file:
            metadata = json.load(file)
        self._cursor = metadata["cursor"]

    def add(self, data: Any) -> int | torch.Tensor:
        index = self._cursor
        _cursor = self._cursor
        # we need to update the cursor first to avoid race conditions between workers
        self._cursor = (self._cursor + 1) % self._storage._max_size_along_dim0(
            single_data=data
        )
        self._write_count += 1
        # Replicate index requires the shape of the storage to be known
        # Other than that, a "flat" (1d) index is ok to write the data
        self._storage.set(_cursor, data)
        index = self._replicate_index(index)
        self._mark_update_entities(index)
        return index

    def extend(self, data: Sequence) -> torch.Tensor:
        cur_size = self._cursor
        if is_tensor_collection(data) or isinstance(data, torch.Tensor):
            batch_size = len(data)
        elif isinstance(data, list):
            batch_size = len(data)
        else:
            batch_size = len(tree_leaves(data)[0])
        if batch_size == 0:
            raise RuntimeError(f"Expected at least one element in extend. Got {data=}")
        device = data.device if hasattr(data, "device") else None
        max_size_along0 = self._storage._max_size_along_dim0(batched_data=data)
        index = (
            torch.arange(
                cur_size, batch_size + cur_size, dtype=torch.long, device=device
            )
            % max_size_along0
        )
        # we need to update the cursor first to avoid race conditions between workers
        self._cursor = (batch_size + cur_size) % max_size_along0
        self._write_count += batch_size
        # Replicate index requires the shape of the storage to be known
        # Other than that, a "flat" (1d) index is ok to write the data
        self._storage.set(index, data)
        index = self._replicate_index(index)
        self._mark_update_entities(index)
        return index

    def state_dict(self) -> dict[str, Any]:
        return {"_cursor": self._cursor}

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        self._cursor = state_dict["_cursor"]

    def _empty(self, empty_write_count: bool = True) -> None:
        self._cursor = 0
        if empty_write_count:
            self._write_count = 0

    # TODO: Workaround for PyTorch nightly regression where compiler can't handle
    # method calls on objects returned from _attached_entities_iter()
    @compile_disable()
    def _mark_update_entities(self, index: torch.Tensor) -> None:
        """Mark entities as updated with the given index."""
        for ent in self._storage._attached_entities_iter():
            ent.mark_update(index)

    @property
    def _cursor(self):
        _cursor_value = self.__dict__.get("_cursor_value", None)
        if not self._compilable:
            if _cursor_value is None:
                _cursor_value = self._cursor_value = mp.Value("i", 0)
            return _cursor_value.value
        else:
            if _cursor_value is None:
                _cursor_value = self._cursor_value = 0
            return _cursor_value

    @_cursor.setter
    def _cursor(self, value):
        if not self._compilable:
            _cursor_value = self.__dict__.get("_cursor_value", None)
            if _cursor_value is None:
                _cursor_value = self._cursor_value = mp.Value("i", 0)
            _cursor_value.value = value
        else:
            self._cursor_value = value

    @property
    def _write_count(self):
        _write_count = self.__dict__.get("_write_count_value", None)
        if not self._compilable:
            if _write_count is None:
                _write_count = self._write_count_value = mp.Value("i", 0)
            return _write_count.value
        else:
            if _write_count is None:
                _write_count = self._write_count_value = 0
            return _write_count

    @_write_count.setter
    def _write_count(self, value):
        if not self._compilable:
            _write_count = self.__dict__.get("_write_count_value", None)
            if _write_count is None:
                _write_count = self._write_count_value = mp.Value("i", 0)
            _write_count.value = value
        else:
            self._write_count_value = value

    def __getstate__(self):
        state = super().__getstate__()
        if get_spawning_popen() is None:
            cursor = self._cursor
            write_count = self._write_count
            del state["_cursor_value"]
            del state["_write_count_value"]
            state["cursor__context"] = cursor
            state["write_count__context"] = write_count
        return state

    def __setstate__(self, state):
        cursor = state.pop("cursor__context", None)
        write_count = state.pop("write_count__context", None)
        if cursor is not None:
            if not state["_compilable"]:
                _cursor_value = mp.Value("i", cursor)
            else:
                _cursor_value = cursor
            state["_cursor_value"] = _cursor_value
        if write_count is not None:
            if not state["_compilable"]:
                _write_count_value = mp.Value("i", write_count)
            else:
                _write_count_value = write_count
            state["_write_count_value"] = _write_count_value
        self.__dict__.update(state)

    def __repr__(self):
        return f"{self.__class__.__name__}(cursor={int(self._cursor)}, full_storage={self._storage._is_full})"


class TensorDictRoundRobinWriter(RoundRobinWriter):
    """A RoundRobin Writer class for composable, tensordict-based replay buffers."""

    def add(self, data: Any) -> int | torch.Tensor:
        index = self._cursor
        # we need to update the cursor first to avoid race conditions between workers
        max_size_along_dim0 = self._storage._max_size_along_dim0(single_data=data)
        self._cursor = (index + 1) % max_size_along_dim0
        self._write_count += 1
        if not is_tensorclass(data):
            data.set(
                "index",
                expand_as_right(
                    torch.as_tensor(index, device=data.device, dtype=torch.long), data
                ),
            )
        self._storage.set(index, data)
        index = self._replicate_index(index)
        self._mark_update_entities(index)
        return index

    def extend(self, data: Sequence) -> torch.Tensor:
        cur_size = self._cursor
        batch_size = len(data)
        device = data.device if hasattr(data, "device") else None
        max_size_along_dim0 = self._storage._max_size_along_dim0(batched_data=data)
        index = (
            torch.arange(
                cur_size, batch_size + cur_size, dtype=torch.long, device=device
            )
            % max_size_along_dim0
        )
        # we need to update the cursor first to avoid race conditions between workers
        self._cursor = (batch_size + cur_size) % max_size_along_dim0
        self._write_count += batch_size
        # storage must convert the data to the appropriate format if needed
        if not is_tensorclass(data):
            data.set(
                "index",
                expand_as_right(
                    torch.as_tensor(index, device=data.device, dtype=torch.long), data
                ),
            )
        # Replicate index requires the shape of the storage to be known
        # Other than that, a "flat" (1d) index is ok to write the data
        self._storage.set(index, data)
        index = self._replicate_index(index)
        self._mark_update_entities(index)
        return index


class TensorDictMaxValueWriter(Writer):
    """A Writer class for composable replay buffers that keeps the top elements based on some ranking key.

    Args:
        rank_key (str or tuple of str): the key to rank the elements by. Defaults to ``("next", "reward")``.
        reduction (str): the reduction method to use if the rank key has more than one element.
            Can be ``"max"``, ``"min"``, ``"mean"``, ``"median"`` or ``"sum"``.

    Examples:
        >>> import torch
        >>> from tensordict import TensorDict
        >>> from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer, TensorDictMaxValueWriter
        >>> from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
        >>> rb = TensorDictReplayBuffer(
        ...     storage=LazyTensorStorage(1),
        ...     sampler=SamplerWithoutReplacement(),
        ...     batch_size=1,
        ...     writer=TensorDictMaxValueWriter(rank_key="key"),
        ... )
        >>> td = TensorDict({
        ...     "key": torch.tensor(range(10)),
        ...     "obs": torch.tensor(range(10))
        ... }, batch_size=10)
        >>> rb.extend(td)
        >>> print(rb.sample().get("obs").item())
        9
        >>> td = TensorDict({
        ...     "key": torch.tensor(range(10, 20)),
        ...     "obs": torch.tensor(range(10, 20))
        ... }, batch_size=10)
        >>> rb.extend(td)
        >>> print(rb.sample().get("obs").item())
        19
        >>> td = TensorDict({
        ...     "key": torch.tensor(range(10)),
        ...     "obs": torch.tensor(range(10))
        ... }, batch_size=10)
        >>> rb.extend(td)
        >>> print(rb.sample().get("obs").item())
        19

    .. note::
        This class isn't compatible with storages with more than one dimension.
        This doesn't mean that storing trajectories is prohibited, but that
        the trajectories stored must be stored on a per-trajectory basis.
        Here are some examples of valid and invalid usages of the class.
        First, a flat buffer where we store individual transitions:

        >>> from torchrl.data import TensorStorage
        >>> # Simplest use case: data comes in 1d and is stored as such
        >>> data = TensorDict({
        ...     "obs": torch.zeros(10, 3),
        ...     "reward": torch.zeros(10, 1),
        ... }, batch_size=[10])
        >>> rb = TensorDictReplayBuffer(
        ...     storage=LazyTensorStorage(max_size=100),
        ...     writer=TensorDictMaxValueWriter(rank_key="reward")
        ... )
        >>> # We initialize the buffer: a total of 100 *transitions* can be stored
        >>> rb.extend(data)
        >>> # Samples 5 *transitions* at random
        >>> sample = rb.sample(5)
        >>> assert sample.shape == (5,)

        Second, a buffer where we store trajectories. The max signal is aggregated
        in each batch (e.g. the reward of each rollout is summed):

        >>> # One can also store batches of data, each batch being a sub-trajectory
        >>> env = ParallelEnv(2, lambda: GymEnv("Pendulum-v1"))
        >>> # Get a batch of [2, 10] -- format is [Batch, Time]
        >>> rollout = env.rollout(max_steps=10)
        >>> rb = TensorDictReplayBuffer(
        ...     storage=LazyTensorStorage(max_size=100),
        ...     writer=TensorDictMaxValueWriter(rank_key="reward")
        ... )
        >>> # We initialize the buffer: a total of 100 *trajectories* (!) can be stored
        >>> rb.extend(rollout)
        >>> # Sample 5 trajectories at random
        >>> sample = rb.sample(5)
        >>> assert sample.shape == (5, 10)

        If data come in batch but a flat buffer is needed, we can simply flatten
        the data before extending the buffer:

        >>> rb = TensorDictReplayBuffer(
        ...     storage=LazyTensorStorage(max_size=100),
        ...     writer=TensorDictMaxValueWriter(rank_key="reward")
        ... )
        >>> # We initialize the buffer: a total of 100 *transitions* can be stored
        >>> rb.extend(rollout.reshape(-1))
        >>> # Sample 5 trajectories at random
        >>> sample = rb.sample(5)
        >>> assert sample.shape == (5,)

        It is not possible to create a buffer that is extended along the time
        dimension, which is usually the recommended way of using buffers with
        batches of trajectories. Since trajectories are overlapping, it's hard
        if not impossible to aggregate the reward values and compare them.
        This constructor isn't valid (notice the ndim argument):

        >>> rb = TensorDictReplayBuffer(
        ...     storage=LazyTensorStorage(max_size=100, ndim=2),  # Breaks!
        ...     writer=TensorDictMaxValueWriter(rank_key="reward")
        ... )

    """

    def __init__(self, rank_key=None, reduction: str = "sum", **kwargs) -> None:
        super().__init__(**kwargs)
        self._cursor = 0
        self._current_top_values = []
        self._rank_key = rank_key
        self._reduction = reduction
        if self._rank_key is None:
            self._rank_key = ("next", "reward")

    def register_storage(self, storage: Storage) -> None:
        if storage.ndim > 1:
            raise ValueError(
                "TensorDictMaxValueWriter is not compatible with storages with more than one dimension. "
                "See the docstring constructor note about storing trajectories with TensorDictMaxValueWriter."
            )
        return super().register_storage(storage)

    def get_insert_index(self, data: Any) -> int:
        """Returns the index where the data should be inserted, or ``None`` if it should not be inserted."""
        if not is_tensor_collection(data):
            raise RuntimeError(
                f"{type(self)} expects data to be a tensor collection (tensordict or tensorclass). Found a {type(data)} instead."
            )
        if data.batch_dims > 1:
            raise RuntimeError(
                "Expected input tensordict to have no more than 1 dimension, got"
                f"tensordict.batch_size = {data.batch_size}"
            )

        ret = None
        rank_data = data.get(self._rank_key)

        # If time dimension, sum along it.
        if rank_data.numel() > 1:
            rank_data = _reduce(rank_data.reshape(-1), self._reduction, dim=0)
        else:
            rank_data = rank_data.item()

        if rank_data is None:
            raise KeyError(f"Rank key {self._rank_key} not found in data.")

        # If the buffer is not full, add the data
        if len(self._current_top_values) < self._storage.max_size:
            ret = self._cursor
            self._cursor = (self._cursor + 1) % self._storage.max_size

            # Add new reward to the heap
            heapq.heappush(self._current_top_values, (rank_data, ret))

        # If the buffer is full, check if the new data is better than the worst data in the buffer
        elif rank_data > self._current_top_values[0][0]:

            # retrieve position of the smallest value
            min_sample = heapq.heappop(self._current_top_values)
            ret = min_sample[1]

            # Add new reward to the heap
            heapq.heappush(self._current_top_values, (rank_data, ret))

        return ret

    @property
    def _write_count(self):
        _write_count = self.__dict__.get("_write_count_value", None)
        if _write_count is None:
            _write_count = self._write_count_value = mp.Value("i", 0)
        return _write_count.value

    @_write_count.setter
    def _write_count(self, value):
        _write_count = self.__dict__.get("_write_count_value", None)
        if _write_count is None:
            _write_count = self._write_count_value = mp.Value("i", 0)
        _write_count.value = value

    def add(self, data: Any) -> int | torch.Tensor:
        """Inserts a single element of data at an appropriate index, and returns that index.

        The ``rank_key`` in the data passed to this module should be structured as [].
        If it has more dimensions, it will be reduced to a single value using the ``reduction`` method.
        """
        index = self.get_insert_index(data)
        if index is not None:
            data.set("index", index)
            self._write_count += 1
            # Replicate index requires the shape of the storage to be known
            # Other than that, a "flat" (1d) index is ok to write the data
            self._storage.set(index, data)
            index = self._replicate_index(index)
            for ent in self._storage._attached_entities_iter():
                ent.mark_update(index)
        return index

    def extend(self, data: TensorDictBase) -> None:
        """Inserts a series of data points at appropriate indices.

        The ``rank_key`` in the data passed to this module should be structured as [B].
        If it has more dimensions, it will be reduced to a single value using the ``reduction`` method.
        """
        # a map of [idx_in_storage, idx_in_data]
        data_to_replace = {}
        for data_idx, sample in enumerate(data):
            storage_idx = self.get_insert_index(sample)
            if storage_idx is not None:
                self._write_count += 1
                data_to_replace[storage_idx] = data_idx

        # -1 will be interpreted as invalid by prioritized buffers
        # Replace the data in the storage all at once
        if len(data_to_replace) > 0:
            storage_idx, data_idx = zip(*data_to_replace.items())
            index = data.get("index", None)
            dtype = index.dtype if index is not None else torch.long
            device = index.device if index is not None else data.device
            out_index = torch.full(data.shape, -1, dtype=torch.long, device=device)
            data_idx = torch.as_tensor(data_idx, dtype=dtype, device=device)
            storage_idx = torch.as_tensor(storage_idx, dtype=dtype, device=device)
            out_index[data_idx] = storage_idx
            self._storage.set(storage_idx, data[data_idx])
        else:
            device = getattr(self._storage, "device", None)
            out_index = torch.full(data.shape, -1, dtype=torch.long, device=device)
        index = self._replicate_index(out_index)
        self._mark_update_entities(index)
        return index

    # TODO: Workaround for PyTorch nightly regression where compiler can't handle
    # method calls on objects returned from _attached_entities_iter()
    @compile_disable()
    def _mark_update_entities(self, index: torch.Tensor) -> None:
        """Mark entities as updated with the given index."""
        for ent in self._storage._attached_entities_iter():
            ent.mark_update(index)

    def _empty(self, empty_write_count: bool = True) -> None:
        self._cursor = 0
        self._current_top_values = []
        if empty_write_count:
            self._write_count = 0

    def __getstate__(self):
        if get_spawning_popen() is not None:
            raise RuntimeError(
                f"Writers of type {type(self)} cannot be shared between processes. "
                f"Please submit an issue at https://github.com/pytorch/rl if this feature is needed."
            )
        state = super().__getstate__()
        # Handle the mp.Value object for pickling
        if "_write_count_value" in state:
            write_count = self._write_count
            del state["_write_count_value"]
            state["write_count__context"] = write_count
        return state

    def __setstate__(self, state):
        write_count = state.pop("write_count__context", None)
        if write_count is not None:
            state["_write_count_value"] = mp.Value("i", write_count)
        self.__dict__.update(state)

    def dumps(self, path):
        path = Path(path).absolute()
        path.mkdir(exist_ok=True)
        t = torch.as_tensor(self._current_top_values)
        try:
            MemoryMappedTensor.from_filename(
                filename=path / "current_top_values.memmap",
                shape=t.shape,
                dtype=t.dtype,
            ).copy_(t)
        except FileNotFoundError:
            MemoryMappedTensor.from_tensor(
                t, filename=path / "current_top_values.memmap"
            )
        with open(path / "metadata.json", "w") as file:
            json.dump(
                {
                    "cursor": self._cursor,
                    "rank_key": self._rank_key,
                    "dtype": str(t.dtype),
                    "shape": list(t.shape),
                },
                file,
            )

    def loads(self, path):
        path = Path(path).absolute()
        with open(path / "metadata.json") as file:
            metadata = json.load(file)
        self._cursor = metadata["cursor"]
        self._rank_key = metadata["rank_key"]
        shape = torch.Size(metadata["shape"])
        dtype = metadata["dtype"]
        self._current_top_values = MemoryMappedTensor.from_filename(
            filename=path / "current_top_values.memmap",
            dtype=_STRDTYPE2DTYPE[dtype],
            shape=shape,
        ).tolist()

    def state_dict(self) -> dict[str, Any]:
        raise NotImplementedError

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        raise NotImplementedError

    def __repr__(self):
        return f"{self.__class__.__name__}(cursor={int(self._cursor)}, full_storage={self._storage._is_full}, rank_key={self._rank_key}, reduction={self._reduction})"


class WriterEnsemble(Writer):
    """An ensemble of writers.

    This class is designed to work with :class:`~torchrl.data.replay_buffers.replay_buffers.ReplayBufferEnsemble`.
    It contains the writers but blocks writing with any of them.

    Args:
        writers (sequence of Writer): the writers to make the composite writer.

    .. warning::
        This class does not support writing.
        To extend one of the replay buffers, simply index the parent
        :class:`~torchrl.data.ReplayBufferEnsemble` object.

    """

    def __init__(self, *writers):
        self._rng_private = None
        self._writers = writers

    @property
    def _rng(self):
        return self._rng_private

    @_rng.setter
    def _rng(self, value):
        self._rng_private = value
        for writer in self._writers:
            writer._rng = value

    def _empty(self, empty_write_count: bool = True) -> None:
        raise NotImplementedError

    def dumps(self, path: Path):
        path = Path(path).absolute()
        for i, writer in enumerate(self._writers):
            writer.dumps(path / str(i))

    def loads(self, path: Path):
        path = Path(path).absolute()
        for i, writer in enumerate(self._writers):
            writer.loads(path / str(i))

    def add(self):
        raise NotImplementedError

    def extend(self):
        raise NotImplementedError

    _INDEX_ERROR = "Expected an index of type torch.Tensor, range, np.ndarray, int, slice or ellipsis, got {} instead."

    def __getitem__(self, index):
        if isinstance(index, tuple):
            if index[0] is Ellipsis:
                index = (slice(None), index[1:])
            result = self[index[0]]
            if len(index) > 1:
                raise IndexError(
                    f"Tuple of length greater than 1 are not accepted to index writers of type {type(self)}."
                )
            return result
        if isinstance(index, slice) and index == slice(None):
            return self
        if isinstance(index, (list, range, np.ndarray)):
            index = torch.as_tensor(index)
        if isinstance(index, torch.Tensor):
            if index.ndim > 1:
                raise RuntimeError(
                    f"Cannot index a {type(self)} with tensor indices that have more than one dimension."
                )
            if index.is_floating_point():
                raise TypeError(
                    "A floating point index was received when an integer dtype was expected."
                )
        if isinstance(index, int) or (not isinstance(index, slice) and len(index) == 0):
            try:
                index = int(index)
            except Exception:
                raise IndexError(self._INDEX_ERROR.format(type(index)))
            try:
                return self._writers[index]
            except IndexError:
                raise IndexError(self._INDEX_ERROR.format(type(index)))
        if isinstance(index, torch.Tensor):
            index = index.tolist()
            writers = [self._writers[i] for i in index]
        else:
            # slice
            writers = self._writers[index]
        return WriterEnsemble(*writers)

    def __len__(self):
        return len(self._writers)

    def __repr__(self):
        writers = textwrap.indent(f"writers={self._writers}", " " * 4)
        return f"WriterEnsemble(\n{writers})"

    def state_dict(self) -> dict[str, Any]:
        raise NotImplementedError

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        raise NotImplementedError
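
Usage note: the file above only defines writer classes; a writer is normally consumed through a replay buffer, which calls ``add``/``extend`` on it. Below is a minimal sketch in the same doctest style as the in-file examples. It assumes ``ReplayBuffer``, ``LazyTensorStorage`` and ``RoundRobinWriter`` are importable from ``torchrl.data`` as part of this wheel's public API; the exact printed values are illustrative of the round-robin wrap-around behavior, not verified against this diff.

>>> import torch
>>> from torchrl.data import LazyTensorStorage, ReplayBuffer, RoundRobinWriter
>>> # RoundRobinWriter advances its cursor modulo the storage size, so once the
>>> # buffer is full the oldest entries are overwritten first.
>>> rb = ReplayBuffer(storage=LazyTensorStorage(max_size=4), writer=RoundRobinWriter())
>>> indices = rb.extend(torch.arange(6))  # elements 4 and 5 wrap around and overwrite slots 0 and 1
>>> len(rb)
4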