torchrl 0.11.0__cp314-cp314t-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/benchmark_batched_envs.py +104 -0
- benchmarks/conftest.py +91 -0
- benchmarks/ecosystem/gym_env_throughput.py +321 -0
- benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
- benchmarks/requirements.txt +7 -0
- benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
- benchmarks/test_collectors_benchmark.py +240 -0
- benchmarks/test_compressed_storage_benchmark.py +145 -0
- benchmarks/test_envs_benchmark.py +133 -0
- benchmarks/test_llm.py +101 -0
- benchmarks/test_non_tensor_env_benchmark.py +70 -0
- benchmarks/test_objectives_benchmarks.py +1199 -0
- benchmarks/test_replaybuffer_benchmark.py +254 -0
- sota-check/README.md +35 -0
- sota-implementations/README.md +142 -0
- sota-implementations/a2c/README.md +39 -0
- sota-implementations/a2c/a2c_atari.py +291 -0
- sota-implementations/a2c/a2c_mujoco.py +273 -0
- sota-implementations/a2c/utils_atari.py +240 -0
- sota-implementations/a2c/utils_mujoco.py +160 -0
- sota-implementations/bandits/README.md +7 -0
- sota-implementations/bandits/dqn.py +126 -0
- sota-implementations/cql/cql_offline.py +198 -0
- sota-implementations/cql/cql_online.py +249 -0
- sota-implementations/cql/discrete_cql_offline.py +180 -0
- sota-implementations/cql/discrete_cql_online.py +227 -0
- sota-implementations/cql/utils.py +471 -0
- sota-implementations/crossq/crossq.py +271 -0
- sota-implementations/crossq/utils.py +320 -0
- sota-implementations/ddpg/ddpg.py +231 -0
- sota-implementations/ddpg/utils.py +325 -0
- sota-implementations/decision_transformer/dt.py +163 -0
- sota-implementations/decision_transformer/lamb.py +167 -0
- sota-implementations/decision_transformer/online_dt.py +178 -0
- sota-implementations/decision_transformer/utils.py +562 -0
- sota-implementations/discrete_sac/discrete_sac.py +243 -0
- sota-implementations/discrete_sac/utils.py +324 -0
- sota-implementations/dqn/README.md +30 -0
- sota-implementations/dqn/dqn_atari.py +272 -0
- sota-implementations/dqn/dqn_cartpole.py +236 -0
- sota-implementations/dqn/utils_atari.py +132 -0
- sota-implementations/dqn/utils_cartpole.py +90 -0
- sota-implementations/dreamer/README.md +129 -0
- sota-implementations/dreamer/dreamer.py +586 -0
- sota-implementations/dreamer/dreamer_utils.py +1107 -0
- sota-implementations/expert-iteration/README.md +352 -0
- sota-implementations/expert-iteration/ei_utils.py +770 -0
- sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
- sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
- sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
- sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
- sota-implementations/gail/gail.py +327 -0
- sota-implementations/gail/gail_utils.py +68 -0
- sota-implementations/gail/ppo_utils.py +157 -0
- sota-implementations/grpo/README.md +273 -0
- sota-implementations/grpo/grpo-async.py +437 -0
- sota-implementations/grpo/grpo-sync.py +435 -0
- sota-implementations/grpo/grpo_utils.py +843 -0
- sota-implementations/grpo/requirements_gsm8k.txt +11 -0
- sota-implementations/grpo/requirements_ifeval.txt +16 -0
- sota-implementations/impala/README.md +33 -0
- sota-implementations/impala/impala_multi_node_ray.py +292 -0
- sota-implementations/impala/impala_multi_node_submitit.py +284 -0
- sota-implementations/impala/impala_single_node.py +261 -0
- sota-implementations/impala/utils.py +184 -0
- sota-implementations/iql/discrete_iql.py +230 -0
- sota-implementations/iql/iql_offline.py +164 -0
- sota-implementations/iql/iql_online.py +225 -0
- sota-implementations/iql/utils.py +437 -0
- sota-implementations/multiagent/README.md +74 -0
- sota-implementations/multiagent/iql.py +237 -0
- sota-implementations/multiagent/maddpg_iddpg.py +266 -0
- sota-implementations/multiagent/mappo_ippo.py +267 -0
- sota-implementations/multiagent/qmix_vdn.py +271 -0
- sota-implementations/multiagent/sac.py +337 -0
- sota-implementations/multiagent/utils/__init__.py +4 -0
- sota-implementations/multiagent/utils/logging.py +151 -0
- sota-implementations/multiagent/utils/utils.py +43 -0
- sota-implementations/ppo/README.md +29 -0
- sota-implementations/ppo/ppo_atari.py +305 -0
- sota-implementations/ppo/ppo_mujoco.py +293 -0
- sota-implementations/ppo/utils_atari.py +238 -0
- sota-implementations/ppo/utils_mujoco.py +152 -0
- sota-implementations/ppo_trainer/train.py +21 -0
- sota-implementations/redq/README.md +7 -0
- sota-implementations/redq/redq.py +199 -0
- sota-implementations/redq/utils.py +1060 -0
- sota-implementations/sac/sac-async.py +266 -0
- sota-implementations/sac/sac.py +239 -0
- sota-implementations/sac/utils.py +381 -0
- sota-implementations/sac_trainer/train.py +16 -0
- sota-implementations/td3/td3.py +254 -0
- sota-implementations/td3/utils.py +319 -0
- sota-implementations/td3_bc/td3_bc.py +177 -0
- sota-implementations/td3_bc/utils.py +251 -0
- torchrl/.dylibs/libc++.1.0.dylib +0 -0
- torchrl/__init__.py +144 -0
- torchrl/_extension.py +74 -0
- torchrl/_torchrl.cpython-314t-darwin.so +0 -0
- torchrl/_utils.py +1431 -0
- torchrl/collectors/__init__.py +48 -0
- torchrl/collectors/_base.py +1058 -0
- torchrl/collectors/_constants.py +88 -0
- torchrl/collectors/_multi_async.py +324 -0
- torchrl/collectors/_multi_base.py +1805 -0
- torchrl/collectors/_multi_sync.py +464 -0
- torchrl/collectors/_runner.py +581 -0
- torchrl/collectors/_single.py +2009 -0
- torchrl/collectors/_single_async.py +259 -0
- torchrl/collectors/collectors.py +62 -0
- torchrl/collectors/distributed/__init__.py +32 -0
- torchrl/collectors/distributed/default_configs.py +133 -0
- torchrl/collectors/distributed/generic.py +1306 -0
- torchrl/collectors/distributed/ray.py +1092 -0
- torchrl/collectors/distributed/rpc.py +1006 -0
- torchrl/collectors/distributed/sync.py +731 -0
- torchrl/collectors/distributed/utils.py +160 -0
- torchrl/collectors/llm/__init__.py +10 -0
- torchrl/collectors/llm/base.py +494 -0
- torchrl/collectors/llm/ray_collector.py +275 -0
- torchrl/collectors/llm/utils.py +36 -0
- torchrl/collectors/llm/weight_update/__init__.py +10 -0
- torchrl/collectors/llm/weight_update/vllm.py +348 -0
- torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
- torchrl/collectors/utils.py +433 -0
- torchrl/collectors/weight_update.py +591 -0
- torchrl/csrc/numpy_utils.h +38 -0
- torchrl/csrc/pybind.cpp +27 -0
- torchrl/csrc/segment_tree.h +458 -0
- torchrl/csrc/torch_utils.h +34 -0
- torchrl/csrc/utils.cpp +48 -0
- torchrl/csrc/utils.h +31 -0
- torchrl/data/__init__.py +187 -0
- torchrl/data/datasets/__init__.py +58 -0
- torchrl/data/datasets/atari_dqn.py +878 -0
- torchrl/data/datasets/common.py +281 -0
- torchrl/data/datasets/d4rl.py +489 -0
- torchrl/data/datasets/d4rl_infos.py +187 -0
- torchrl/data/datasets/gen_dgrl.py +375 -0
- torchrl/data/datasets/minari_data.py +643 -0
- torchrl/data/datasets/openml.py +177 -0
- torchrl/data/datasets/openx.py +798 -0
- torchrl/data/datasets/roboset.py +363 -0
- torchrl/data/datasets/utils.py +11 -0
- torchrl/data/datasets/vd4rl.py +432 -0
- torchrl/data/llm/__init__.py +34 -0
- torchrl/data/llm/dataset.py +491 -0
- torchrl/data/llm/history.py +1378 -0
- torchrl/data/llm/prompt.py +198 -0
- torchrl/data/llm/reward.py +225 -0
- torchrl/data/llm/topk.py +186 -0
- torchrl/data/llm/utils.py +543 -0
- torchrl/data/map/__init__.py +21 -0
- torchrl/data/map/hash.py +185 -0
- torchrl/data/map/query.py +204 -0
- torchrl/data/map/tdstorage.py +363 -0
- torchrl/data/map/tree.py +1434 -0
- torchrl/data/map/utils.py +103 -0
- torchrl/data/postprocs/__init__.py +8 -0
- torchrl/data/postprocs/postprocs.py +391 -0
- torchrl/data/replay_buffers/__init__.py +99 -0
- torchrl/data/replay_buffers/checkpointers.py +622 -0
- torchrl/data/replay_buffers/ray_buffer.py +292 -0
- torchrl/data/replay_buffers/replay_buffers.py +2376 -0
- torchrl/data/replay_buffers/samplers.py +2578 -0
- torchrl/data/replay_buffers/scheduler.py +265 -0
- torchrl/data/replay_buffers/storages.py +2412 -0
- torchrl/data/replay_buffers/utils.py +1042 -0
- torchrl/data/replay_buffers/writers.py +781 -0
- torchrl/data/tensor_specs.py +7101 -0
- torchrl/data/utils.py +334 -0
- torchrl/envs/__init__.py +265 -0
- torchrl/envs/async_envs.py +1105 -0
- torchrl/envs/batched_envs.py +3093 -0
- torchrl/envs/common.py +4241 -0
- torchrl/envs/custom/__init__.py +11 -0
- torchrl/envs/custom/chess.py +617 -0
- torchrl/envs/custom/llm.py +214 -0
- torchrl/envs/custom/pendulum.py +401 -0
- torchrl/envs/custom/san_moves.txt +29274 -0
- torchrl/envs/custom/tictactoeenv.py +288 -0
- torchrl/envs/env_creator.py +263 -0
- torchrl/envs/gym_like.py +752 -0
- torchrl/envs/libs/__init__.py +68 -0
- torchrl/envs/libs/_gym_utils.py +326 -0
- torchrl/envs/libs/brax.py +846 -0
- torchrl/envs/libs/dm_control.py +544 -0
- torchrl/envs/libs/envpool.py +447 -0
- torchrl/envs/libs/gym.py +2239 -0
- torchrl/envs/libs/habitat.py +138 -0
- torchrl/envs/libs/isaac_lab.py +87 -0
- torchrl/envs/libs/isaacgym.py +203 -0
- torchrl/envs/libs/jax_utils.py +166 -0
- torchrl/envs/libs/jumanji.py +963 -0
- torchrl/envs/libs/meltingpot.py +599 -0
- torchrl/envs/libs/openml.py +153 -0
- torchrl/envs/libs/openspiel.py +652 -0
- torchrl/envs/libs/pettingzoo.py +1042 -0
- torchrl/envs/libs/procgen.py +351 -0
- torchrl/envs/libs/robohive.py +429 -0
- torchrl/envs/libs/smacv2.py +645 -0
- torchrl/envs/libs/unity_mlagents.py +891 -0
- torchrl/envs/libs/utils.py +147 -0
- torchrl/envs/libs/vmas.py +813 -0
- torchrl/envs/llm/__init__.py +63 -0
- torchrl/envs/llm/chat.py +730 -0
- torchrl/envs/llm/datasets/README.md +4 -0
- torchrl/envs/llm/datasets/__init__.py +17 -0
- torchrl/envs/llm/datasets/gsm8k.py +353 -0
- torchrl/envs/llm/datasets/ifeval.py +274 -0
- torchrl/envs/llm/envs.py +789 -0
- torchrl/envs/llm/libs/README.md +3 -0
- torchrl/envs/llm/libs/__init__.py +8 -0
- torchrl/envs/llm/libs/mlgym.py +869 -0
- torchrl/envs/llm/reward/__init__.py +10 -0
- torchrl/envs/llm/reward/gsm8k.py +324 -0
- torchrl/envs/llm/reward/ifeval/README.md +13 -0
- torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
- torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
- torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
- torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
- torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
- torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
- torchrl/envs/llm/transforms/__init__.py +55 -0
- torchrl/envs/llm/transforms/browser.py +292 -0
- torchrl/envs/llm/transforms/dataloading.py +859 -0
- torchrl/envs/llm/transforms/format.py +73 -0
- torchrl/envs/llm/transforms/kl.py +1544 -0
- torchrl/envs/llm/transforms/policy_version.py +189 -0
- torchrl/envs/llm/transforms/reason.py +323 -0
- torchrl/envs/llm/transforms/tokenizer.py +321 -0
- torchrl/envs/llm/transforms/tools.py +1955 -0
- torchrl/envs/model_based/__init__.py +9 -0
- torchrl/envs/model_based/common.py +180 -0
- torchrl/envs/model_based/dreamer.py +112 -0
- torchrl/envs/transforms/__init__.py +147 -0
- torchrl/envs/transforms/functional.py +48 -0
- torchrl/envs/transforms/gym_transforms.py +203 -0
- torchrl/envs/transforms/module.py +341 -0
- torchrl/envs/transforms/r3m.py +372 -0
- torchrl/envs/transforms/ray_service.py +663 -0
- torchrl/envs/transforms/rb_transforms.py +214 -0
- torchrl/envs/transforms/transforms.py +11835 -0
- torchrl/envs/transforms/utils.py +94 -0
- torchrl/envs/transforms/vc1.py +307 -0
- torchrl/envs/transforms/vecnorm.py +845 -0
- torchrl/envs/transforms/vip.py +407 -0
- torchrl/envs/utils.py +1718 -0
- torchrl/envs/vec_envs.py +11 -0
- torchrl/modules/__init__.py +206 -0
- torchrl/modules/distributions/__init__.py +73 -0
- torchrl/modules/distributions/continuous.py +830 -0
- torchrl/modules/distributions/discrete.py +908 -0
- torchrl/modules/distributions/truncated_normal.py +187 -0
- torchrl/modules/distributions/utils.py +233 -0
- torchrl/modules/llm/__init__.py +62 -0
- torchrl/modules/llm/backends/__init__.py +65 -0
- torchrl/modules/llm/backends/vllm/__init__.py +94 -0
- torchrl/modules/llm/backends/vllm/_models.py +46 -0
- torchrl/modules/llm/backends/vllm/base.py +72 -0
- torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
- torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
- torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
- torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
- torchrl/modules/llm/policies/__init__.py +28 -0
- torchrl/modules/llm/policies/common.py +1809 -0
- torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
- torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
- torchrl/modules/llm/utils.py +23 -0
- torchrl/modules/mcts/__init__.py +21 -0
- torchrl/modules/mcts/scores.py +579 -0
- torchrl/modules/models/__init__.py +86 -0
- torchrl/modules/models/batchrenorm.py +119 -0
- torchrl/modules/models/decision_transformer.py +179 -0
- torchrl/modules/models/exploration.py +731 -0
- torchrl/modules/models/llm.py +156 -0
- torchrl/modules/models/model_based.py +596 -0
- torchrl/modules/models/models.py +1712 -0
- torchrl/modules/models/multiagent.py +1067 -0
- torchrl/modules/models/recipes/impala.py +185 -0
- torchrl/modules/models/utils.py +162 -0
- torchrl/modules/planners/__init__.py +10 -0
- torchrl/modules/planners/cem.py +228 -0
- torchrl/modules/planners/common.py +73 -0
- torchrl/modules/planners/mppi.py +265 -0
- torchrl/modules/tensordict_module/__init__.py +89 -0
- torchrl/modules/tensordict_module/actors.py +2457 -0
- torchrl/modules/tensordict_module/common.py +529 -0
- torchrl/modules/tensordict_module/exploration.py +814 -0
- torchrl/modules/tensordict_module/probabilistic.py +321 -0
- torchrl/modules/tensordict_module/rnn.py +1639 -0
- torchrl/modules/tensordict_module/sequence.py +132 -0
- torchrl/modules/tensordict_module/world_models.py +34 -0
- torchrl/modules/utils/__init__.py +38 -0
- torchrl/modules/utils/mappings.py +9 -0
- torchrl/modules/utils/utils.py +89 -0
- torchrl/objectives/__init__.py +78 -0
- torchrl/objectives/a2c.py +659 -0
- torchrl/objectives/common.py +753 -0
- torchrl/objectives/cql.py +1346 -0
- torchrl/objectives/crossq.py +710 -0
- torchrl/objectives/ddpg.py +453 -0
- torchrl/objectives/decision_transformer.py +371 -0
- torchrl/objectives/deprecated.py +516 -0
- torchrl/objectives/dqn.py +683 -0
- torchrl/objectives/dreamer.py +488 -0
- torchrl/objectives/functional.py +48 -0
- torchrl/objectives/gail.py +258 -0
- torchrl/objectives/iql.py +996 -0
- torchrl/objectives/llm/__init__.py +30 -0
- torchrl/objectives/llm/grpo.py +846 -0
- torchrl/objectives/llm/sft.py +482 -0
- torchrl/objectives/multiagent/__init__.py +8 -0
- torchrl/objectives/multiagent/qmixer.py +396 -0
- torchrl/objectives/ppo.py +1669 -0
- torchrl/objectives/redq.py +683 -0
- torchrl/objectives/reinforce.py +530 -0
- torchrl/objectives/sac.py +1580 -0
- torchrl/objectives/td3.py +570 -0
- torchrl/objectives/td3_bc.py +625 -0
- torchrl/objectives/utils.py +782 -0
- torchrl/objectives/value/__init__.py +28 -0
- torchrl/objectives/value/advantages.py +1956 -0
- torchrl/objectives/value/functional.py +1459 -0
- torchrl/objectives/value/utils.py +360 -0
- torchrl/record/__init__.py +17 -0
- torchrl/record/loggers/__init__.py +23 -0
- torchrl/record/loggers/common.py +48 -0
- torchrl/record/loggers/csv.py +226 -0
- torchrl/record/loggers/mlflow.py +142 -0
- torchrl/record/loggers/tensorboard.py +139 -0
- torchrl/record/loggers/trackio.py +163 -0
- torchrl/record/loggers/utils.py +78 -0
- torchrl/record/loggers/wandb.py +214 -0
- torchrl/record/recorder.py +554 -0
- torchrl/services/__init__.py +79 -0
- torchrl/services/base.py +109 -0
- torchrl/services/ray_service.py +453 -0
- torchrl/testing/__init__.py +107 -0
- torchrl/testing/assertions.py +179 -0
- torchrl/testing/dist_utils.py +122 -0
- torchrl/testing/env_creators.py +227 -0
- torchrl/testing/env_helper.py +35 -0
- torchrl/testing/gym_helpers.py +156 -0
- torchrl/testing/llm_mocks.py +119 -0
- torchrl/testing/mocking_classes.py +2720 -0
- torchrl/testing/modules.py +295 -0
- torchrl/testing/mp_helpers.py +15 -0
- torchrl/testing/ray_helpers.py +293 -0
- torchrl/testing/utils.py +190 -0
- torchrl/trainers/__init__.py +42 -0
- torchrl/trainers/algorithms/__init__.py +11 -0
- torchrl/trainers/algorithms/configs/__init__.py +705 -0
- torchrl/trainers/algorithms/configs/collectors.py +216 -0
- torchrl/trainers/algorithms/configs/common.py +41 -0
- torchrl/trainers/algorithms/configs/data.py +308 -0
- torchrl/trainers/algorithms/configs/envs.py +104 -0
- torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
- torchrl/trainers/algorithms/configs/logging.py +80 -0
- torchrl/trainers/algorithms/configs/modules.py +570 -0
- torchrl/trainers/algorithms/configs/objectives.py +177 -0
- torchrl/trainers/algorithms/configs/trainers.py +340 -0
- torchrl/trainers/algorithms/configs/transforms.py +955 -0
- torchrl/trainers/algorithms/configs/utils.py +252 -0
- torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
- torchrl/trainers/algorithms/configs/weight_update.py +159 -0
- torchrl/trainers/algorithms/ppo.py +373 -0
- torchrl/trainers/algorithms/sac.py +308 -0
- torchrl/trainers/helpers/__init__.py +40 -0
- torchrl/trainers/helpers/collectors.py +416 -0
- torchrl/trainers/helpers/envs.py +573 -0
- torchrl/trainers/helpers/logger.py +33 -0
- torchrl/trainers/helpers/losses.py +132 -0
- torchrl/trainers/helpers/models.py +658 -0
- torchrl/trainers/helpers/replay_buffer.py +59 -0
- torchrl/trainers/helpers/trainers.py +301 -0
- torchrl/trainers/trainers.py +2052 -0
- torchrl/weight_update/__init__.py +33 -0
- torchrl/weight_update/_distributed.py +749 -0
- torchrl/weight_update/_mp.py +624 -0
- torchrl/weight_update/_noupdate.py +102 -0
- torchrl/weight_update/_ray.py +1032 -0
- torchrl/weight_update/_rpc.py +284 -0
- torchrl/weight_update/_shared.py +891 -0
- torchrl/weight_update/llm/__init__.py +32 -0
- torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
- torchrl/weight_update/llm/vllm_nccl.py +710 -0
- torchrl/weight_update/utils.py +73 -0
- torchrl/weight_update/weight_sync_schemes.py +1244 -0
- torchrl-0.11.0.dist-info/METADATA +1308 -0
- torchrl-0.11.0.dist-info/RECORD +395 -0
- torchrl-0.11.0.dist-info/WHEEL +5 -0
- torchrl-0.11.0.dist-info/entry_points.txt +2 -0
- torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
- torchrl-0.11.0.dist-info/top_level.txt +7 -0
|
@@ -0,0 +1,214 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
#
|
|
3
|
+
# This source code is licensed under the MIT license found in the
|
|
4
|
+
# LICENSE file in the root directory of this source tree.
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
import torch
|
|
8
|
+
from tensordict import NestedKey, TensorDictBase
|
|
9
|
+
from torchrl.data.postprocs.postprocs import _multi_step_func
|
|
10
|
+
from torchrl.envs.transforms.transforms import Transform
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class MultiStepTransform(Transform):
    """A MultiStep transformation for ReplayBuffers.

    This transform keeps the previous ``n_steps`` observations in a local buffer.
    The inverse transform (called during :meth:`~torchrl.data.ReplayBuffer.extend`)
    outputs the transformed previous ``n_steps`` with the ``T-n_steps`` current
    frames.

    All entries in the ``"next"`` tensordict that are not part of the ``done_keys``
    or ``reward_keys`` will be mapped to their respective ``t + n_steps - 1``
    correspondent.

    This transform is a more hyperparameter resistant version of
    :class:`~torchrl.data.postprocs.postprocs.MultiStep`:
    the replay buffer transform will make the multi-step transform insensitive
    to the collectors hyperparameters, whereas the post-process
    version will output results that are sensitive to these
    (because collectors have no memory of previous output).

    Args:
        n_steps (int): Number of steps in multi-step. The number of steps can be
            dynamically changed by changing the ``n_steps`` attribute of this
            transform.
        gamma (:obj:`float`): Discount factor.

    Keyword Args:
        reward_keys (list of NestedKey, optional): the reward keys in the input tensordict.
            The reward entries indicated by these keys will be accumulated and discounted
            across ``n_steps`` steps in the future. A corresponding ``<reward_key>_orig``
            entry will be written in the ``"next"`` entry of the output tensordict
            to keep track of the original value of the reward.
            Defaults to ``["reward"]``.
        done_key (NestedKey, optional): the done key in the input tensordict, used to indicate
            an end of trajectory.
            Defaults to ``"done"``.
        done_keys (list of NestedKey, optional): the list of end keys in the input tensordict.
            All the entries indicated by these keys will be left untouched by the transform.
            Defaults to ``["done", "truncated", "terminated"]``.
        mask_key (NestedKey, optional): the mask key in the input tensordict.
            The mask represents the valid frames in the input tensordict and
            should have a shape that allows the input tensordict to be masked
            with.
            Defaults to ``"mask"``.

    .. note::
        Until at least ``n_steps + 1`` frames have been accumulated across calls
        to :meth:`~torchrl.data.ReplayBuffer.extend`, the inverse call returns
        ``None`` and the incoming frames are only buffered.

    Examples:
        >>> from torchrl.envs import GymEnv, TransformedEnv, StepCounter, MultiStepTransform, SerialEnv
        >>> from torchrl.data import ReplayBuffer, LazyTensorStorage
        >>> rb = ReplayBuffer(
        ...     storage=LazyTensorStorage(100, ndim=2),
        ...     transform=MultiStepTransform(n_steps=3, gamma=0.95)
        ... )
        >>> base_env = SerialEnv(2, lambda: GymEnv("CartPole"))
        >>> env = TransformedEnv(base_env, StepCounter())
        >>> _ = env.set_seed(0)
        >>> _ = torch.manual_seed(0)
        >>> tdreset = env.reset()
        >>> for _ in range(100):
        ...     rollout = env.rollout(max_steps=50, break_when_any_done=False,
        ...         tensordict=tdreset, auto_reset=False)
        ...     indices = rb.extend(rollout)
        ...     tdreset = rollout[..., -1]["next"]
        >>> print("step_count", rb[:]["step_count"][:, :5])
        step_count tensor([[[ 9],
                 [10],
                 [11],
                 [12],
                 [13]],
        <BLANKLINE>
                [[12],
                 [13],
                 [14],
                 [15],
                 [16]]])
        >>> # The next step_count is 3 steps in the future
        >>> print("next step_count", rb[:]["next", "step_count"][:, :5])
        next step_count tensor([[[13],
                 [14],
                 [15],
                 [16],
                 [17]],
        <BLANKLINE>
                [[16],
                 [17],
                 [18],
                 [19],
                 [20]]])

    """

    # Error raised when the transform is attached to an environment: it only
    # implements the inverse (write-time) path of a replay buffer.
    ENV_ERR = (
        "The MultiStepTransform is only an inverse transform and can "
        "be applied exclusively to replay buffers."
    )

    def __init__(
        self,
        n_steps,
        gamma,
        *,
        reward_keys: list[NestedKey] | None = None,
        done_key: NestedKey | None = None,
        done_keys: list[NestedKey] | None = None,
        mask_key: NestedKey | None = None,
    ):
        super().__init__()
        self.n_steps = n_steps
        self.reward_keys = reward_keys
        self.done_key = done_key
        self.done_keys = done_keys
        self.mask_key = mask_key
        self.gamma = gamma
        # Rolling buffer holding the trailing ``n_steps`` frames of the data
        # seen so far; ``None`` until the first inverse call.
        self._buffer = None
        # Whether the "not attached to an env" check has been performed.
        self._validated = False

    @property
    def n_steps(self):
        """The look ahead window of the transform.

        This value can be dynamically edited during training.
        """
        return self._n_steps

    @n_steps.setter
    def n_steps(self, value):
        # Reject bools explicitly: ``isinstance(True, int)`` is ``True`` in
        # Python, so ``n_steps=True`` would otherwise silently pass as 1.
        if isinstance(value, bool) or not isinstance(value, int) or value < 1:
            raise ValueError(
                "The value of n_steps must be a strictly positive integer."
            )
        self._n_steps = value

    @property
    def done_key(self):
        """The done key used to detect the end of a trajectory (default: ``"done"``)."""
        return self._done_key

    @done_key.setter
    def done_key(self, value):
        if value is None:
            value = "done"
        self._done_key = value

    @property
    def done_keys(self):
        """The end-of-trajectory keys left untouched by the transform."""
        return self._done_keys

    @done_keys.setter
    def done_keys(self, value):
        if value is None:
            value = ["done", "terminated", "truncated"]
        self._done_keys = value

    @property
    def reward_keys(self):
        """The reward keys accumulated and discounted over ``n_steps``."""
        return self._reward_keys

    @reward_keys.setter
    def reward_keys(self, value):
        if value is None:
            value = [
                "reward",
            ]
        self._reward_keys = value

    @property
    def mask_key(self):
        """The key of the validity mask in the input tensordict (default: ``"mask"``)."""
        return self._mask_key

    @mask_key.setter
    def mask_key(self, value):
        if value is None:
            value = "mask"
        self._mask_key = value

    def _validate(self):
        # This transform must not be attached to an environment: it only makes
        # sense on the replay-buffer (inverse) path.
        if self.parent is not None:
            raise ValueError(self.ENV_ERR)
        self._validated = True

    def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase | None:
        """Transform data written to the buffer into ``n_steps`` transitions.

        Returns the multi-step-transformed frames whose full look-ahead window
        is available, or ``None`` while fewer than ``n_steps + 1`` frames have
        been accumulated (warm-up).
        """
        if not self._validated:
            self._validate()

        total_cat = self._append_tensordict(tensordict)
        if total_cat.shape[-1] > self.n_steps:
            out = _multi_step_func(
                total_cat,
                done_key=self.done_key,
                done_keys=self.done_keys,
                reward_keys=self.reward_keys,
                mask_key=self.mask_key,
                n_steps=self.n_steps,
                gamma=self.gamma,
            )
            # The trailing ``n_steps`` frames lack a complete look-ahead
            # window; they stay in ``self._buffer`` and are emitted on the
            # next call.
            return out[..., : -self.n_steps]
        # Not enough frames accumulated yet to produce a single complete
        # multi-step transition: buffer everything, emit nothing. This was an
        # implicit fall-through before; made explicit for clarity.
        return None

    def _append_tensordict(self, data):
        """Concatenate ``data`` after the buffered frames and refresh the buffer.

        Returns the concatenation of the previously buffered trailing
        ``n_steps`` frames (if any) with ``data``, along the time dimension.
        """
        if self._buffer is None:
            # First call: nothing buffered yet.
            total_cat = data
        else:
            total_cat = torch.cat([self._buffer, data], -1)
        # Keep a copy of the last ``n_steps`` frames for the next call.
        self._buffer = total_cat[..., -self.n_steps :].copy()
        return total_cat