torchrl 0.11.0__cp314-cp314t-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/benchmark_batched_envs.py +104 -0
- benchmarks/conftest.py +91 -0
- benchmarks/ecosystem/gym_env_throughput.py +321 -0
- benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
- benchmarks/requirements.txt +7 -0
- benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
- benchmarks/test_collectors_benchmark.py +240 -0
- benchmarks/test_compressed_storage_benchmark.py +145 -0
- benchmarks/test_envs_benchmark.py +133 -0
- benchmarks/test_llm.py +101 -0
- benchmarks/test_non_tensor_env_benchmark.py +70 -0
- benchmarks/test_objectives_benchmarks.py +1199 -0
- benchmarks/test_replaybuffer_benchmark.py +254 -0
- sota-check/README.md +35 -0
- sota-implementations/README.md +142 -0
- sota-implementations/a2c/README.md +39 -0
- sota-implementations/a2c/a2c_atari.py +291 -0
- sota-implementations/a2c/a2c_mujoco.py +273 -0
- sota-implementations/a2c/utils_atari.py +240 -0
- sota-implementations/a2c/utils_mujoco.py +160 -0
- sota-implementations/bandits/README.md +7 -0
- sota-implementations/bandits/dqn.py +126 -0
- sota-implementations/cql/cql_offline.py +198 -0
- sota-implementations/cql/cql_online.py +249 -0
- sota-implementations/cql/discrete_cql_offline.py +180 -0
- sota-implementations/cql/discrete_cql_online.py +227 -0
- sota-implementations/cql/utils.py +471 -0
- sota-implementations/crossq/crossq.py +271 -0
- sota-implementations/crossq/utils.py +320 -0
- sota-implementations/ddpg/ddpg.py +231 -0
- sota-implementations/ddpg/utils.py +325 -0
- sota-implementations/decision_transformer/dt.py +163 -0
- sota-implementations/decision_transformer/lamb.py +167 -0
- sota-implementations/decision_transformer/online_dt.py +178 -0
- sota-implementations/decision_transformer/utils.py +562 -0
- sota-implementations/discrete_sac/discrete_sac.py +243 -0
- sota-implementations/discrete_sac/utils.py +324 -0
- sota-implementations/dqn/README.md +30 -0
- sota-implementations/dqn/dqn_atari.py +272 -0
- sota-implementations/dqn/dqn_cartpole.py +236 -0
- sota-implementations/dqn/utils_atari.py +132 -0
- sota-implementations/dqn/utils_cartpole.py +90 -0
- sota-implementations/dreamer/README.md +129 -0
- sota-implementations/dreamer/dreamer.py +586 -0
- sota-implementations/dreamer/dreamer_utils.py +1107 -0
- sota-implementations/expert-iteration/README.md +352 -0
- sota-implementations/expert-iteration/ei_utils.py +770 -0
- sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
- sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
- sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
- sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
- sota-implementations/gail/gail.py +327 -0
- sota-implementations/gail/gail_utils.py +68 -0
- sota-implementations/gail/ppo_utils.py +157 -0
- sota-implementations/grpo/README.md +273 -0
- sota-implementations/grpo/grpo-async.py +437 -0
- sota-implementations/grpo/grpo-sync.py +435 -0
- sota-implementations/grpo/grpo_utils.py +843 -0
- sota-implementations/grpo/requirements_gsm8k.txt +11 -0
- sota-implementations/grpo/requirements_ifeval.txt +16 -0
- sota-implementations/impala/README.md +33 -0
- sota-implementations/impala/impala_multi_node_ray.py +292 -0
- sota-implementations/impala/impala_multi_node_submitit.py +284 -0
- sota-implementations/impala/impala_single_node.py +261 -0
- sota-implementations/impala/utils.py +184 -0
- sota-implementations/iql/discrete_iql.py +230 -0
- sota-implementations/iql/iql_offline.py +164 -0
- sota-implementations/iql/iql_online.py +225 -0
- sota-implementations/iql/utils.py +437 -0
- sota-implementations/multiagent/README.md +74 -0
- sota-implementations/multiagent/iql.py +237 -0
- sota-implementations/multiagent/maddpg_iddpg.py +266 -0
- sota-implementations/multiagent/mappo_ippo.py +267 -0
- sota-implementations/multiagent/qmix_vdn.py +271 -0
- sota-implementations/multiagent/sac.py +337 -0
- sota-implementations/multiagent/utils/__init__.py +4 -0
- sota-implementations/multiagent/utils/logging.py +151 -0
- sota-implementations/multiagent/utils/utils.py +43 -0
- sota-implementations/ppo/README.md +29 -0
- sota-implementations/ppo/ppo_atari.py +305 -0
- sota-implementations/ppo/ppo_mujoco.py +293 -0
- sota-implementations/ppo/utils_atari.py +238 -0
- sota-implementations/ppo/utils_mujoco.py +152 -0
- sota-implementations/ppo_trainer/train.py +21 -0
- sota-implementations/redq/README.md +7 -0
- sota-implementations/redq/redq.py +199 -0
- sota-implementations/redq/utils.py +1060 -0
- sota-implementations/sac/sac-async.py +266 -0
- sota-implementations/sac/sac.py +239 -0
- sota-implementations/sac/utils.py +381 -0
- sota-implementations/sac_trainer/train.py +16 -0
- sota-implementations/td3/td3.py +254 -0
- sota-implementations/td3/utils.py +319 -0
- sota-implementations/td3_bc/td3_bc.py +177 -0
- sota-implementations/td3_bc/utils.py +251 -0
- torchrl/.dylibs/libc++.1.0.dylib +0 -0
- torchrl/__init__.py +144 -0
- torchrl/_extension.py +74 -0
- torchrl/_torchrl.cpython-314t-darwin.so +0 -0
- torchrl/_utils.py +1431 -0
- torchrl/collectors/__init__.py +48 -0
- torchrl/collectors/_base.py +1058 -0
- torchrl/collectors/_constants.py +88 -0
- torchrl/collectors/_multi_async.py +324 -0
- torchrl/collectors/_multi_base.py +1805 -0
- torchrl/collectors/_multi_sync.py +464 -0
- torchrl/collectors/_runner.py +581 -0
- torchrl/collectors/_single.py +2009 -0
- torchrl/collectors/_single_async.py +259 -0
- torchrl/collectors/collectors.py +62 -0
- torchrl/collectors/distributed/__init__.py +32 -0
- torchrl/collectors/distributed/default_configs.py +133 -0
- torchrl/collectors/distributed/generic.py +1306 -0
- torchrl/collectors/distributed/ray.py +1092 -0
- torchrl/collectors/distributed/rpc.py +1006 -0
- torchrl/collectors/distributed/sync.py +731 -0
- torchrl/collectors/distributed/utils.py +160 -0
- torchrl/collectors/llm/__init__.py +10 -0
- torchrl/collectors/llm/base.py +494 -0
- torchrl/collectors/llm/ray_collector.py +275 -0
- torchrl/collectors/llm/utils.py +36 -0
- torchrl/collectors/llm/weight_update/__init__.py +10 -0
- torchrl/collectors/llm/weight_update/vllm.py +348 -0
- torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
- torchrl/collectors/utils.py +433 -0
- torchrl/collectors/weight_update.py +591 -0
- torchrl/csrc/numpy_utils.h +38 -0
- torchrl/csrc/pybind.cpp +27 -0
- torchrl/csrc/segment_tree.h +458 -0
- torchrl/csrc/torch_utils.h +34 -0
- torchrl/csrc/utils.cpp +48 -0
- torchrl/csrc/utils.h +31 -0
- torchrl/data/__init__.py +187 -0
- torchrl/data/datasets/__init__.py +58 -0
- torchrl/data/datasets/atari_dqn.py +878 -0
- torchrl/data/datasets/common.py +281 -0
- torchrl/data/datasets/d4rl.py +489 -0
- torchrl/data/datasets/d4rl_infos.py +187 -0
- torchrl/data/datasets/gen_dgrl.py +375 -0
- torchrl/data/datasets/minari_data.py +643 -0
- torchrl/data/datasets/openml.py +177 -0
- torchrl/data/datasets/openx.py +798 -0
- torchrl/data/datasets/roboset.py +363 -0
- torchrl/data/datasets/utils.py +11 -0
- torchrl/data/datasets/vd4rl.py +432 -0
- torchrl/data/llm/__init__.py +34 -0
- torchrl/data/llm/dataset.py +491 -0
- torchrl/data/llm/history.py +1378 -0
- torchrl/data/llm/prompt.py +198 -0
- torchrl/data/llm/reward.py +225 -0
- torchrl/data/llm/topk.py +186 -0
- torchrl/data/llm/utils.py +543 -0
- torchrl/data/map/__init__.py +21 -0
- torchrl/data/map/hash.py +185 -0
- torchrl/data/map/query.py +204 -0
- torchrl/data/map/tdstorage.py +363 -0
- torchrl/data/map/tree.py +1434 -0
- torchrl/data/map/utils.py +103 -0
- torchrl/data/postprocs/__init__.py +8 -0
- torchrl/data/postprocs/postprocs.py +391 -0
- torchrl/data/replay_buffers/__init__.py +99 -0
- torchrl/data/replay_buffers/checkpointers.py +622 -0
- torchrl/data/replay_buffers/ray_buffer.py +292 -0
- torchrl/data/replay_buffers/replay_buffers.py +2376 -0
- torchrl/data/replay_buffers/samplers.py +2578 -0
- torchrl/data/replay_buffers/scheduler.py +265 -0
- torchrl/data/replay_buffers/storages.py +2412 -0
- torchrl/data/replay_buffers/utils.py +1042 -0
- torchrl/data/replay_buffers/writers.py +781 -0
- torchrl/data/tensor_specs.py +7101 -0
- torchrl/data/utils.py +334 -0
- torchrl/envs/__init__.py +265 -0
- torchrl/envs/async_envs.py +1105 -0
- torchrl/envs/batched_envs.py +3093 -0
- torchrl/envs/common.py +4241 -0
- torchrl/envs/custom/__init__.py +11 -0
- torchrl/envs/custom/chess.py +617 -0
- torchrl/envs/custom/llm.py +214 -0
- torchrl/envs/custom/pendulum.py +401 -0
- torchrl/envs/custom/san_moves.txt +29274 -0
- torchrl/envs/custom/tictactoeenv.py +288 -0
- torchrl/envs/env_creator.py +263 -0
- torchrl/envs/gym_like.py +752 -0
- torchrl/envs/libs/__init__.py +68 -0
- torchrl/envs/libs/_gym_utils.py +326 -0
- torchrl/envs/libs/brax.py +846 -0
- torchrl/envs/libs/dm_control.py +544 -0
- torchrl/envs/libs/envpool.py +447 -0
- torchrl/envs/libs/gym.py +2239 -0
- torchrl/envs/libs/habitat.py +138 -0
- torchrl/envs/libs/isaac_lab.py +87 -0
- torchrl/envs/libs/isaacgym.py +203 -0
- torchrl/envs/libs/jax_utils.py +166 -0
- torchrl/envs/libs/jumanji.py +963 -0
- torchrl/envs/libs/meltingpot.py +599 -0
- torchrl/envs/libs/openml.py +153 -0
- torchrl/envs/libs/openspiel.py +652 -0
- torchrl/envs/libs/pettingzoo.py +1042 -0
- torchrl/envs/libs/procgen.py +351 -0
- torchrl/envs/libs/robohive.py +429 -0
- torchrl/envs/libs/smacv2.py +645 -0
- torchrl/envs/libs/unity_mlagents.py +891 -0
- torchrl/envs/libs/utils.py +147 -0
- torchrl/envs/libs/vmas.py +813 -0
- torchrl/envs/llm/__init__.py +63 -0
- torchrl/envs/llm/chat.py +730 -0
- torchrl/envs/llm/datasets/README.md +4 -0
- torchrl/envs/llm/datasets/__init__.py +17 -0
- torchrl/envs/llm/datasets/gsm8k.py +353 -0
- torchrl/envs/llm/datasets/ifeval.py +274 -0
- torchrl/envs/llm/envs.py +789 -0
- torchrl/envs/llm/libs/README.md +3 -0
- torchrl/envs/llm/libs/__init__.py +8 -0
- torchrl/envs/llm/libs/mlgym.py +869 -0
- torchrl/envs/llm/reward/__init__.py +10 -0
- torchrl/envs/llm/reward/gsm8k.py +324 -0
- torchrl/envs/llm/reward/ifeval/README.md +13 -0
- torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
- torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
- torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
- torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
- torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
- torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
- torchrl/envs/llm/transforms/__init__.py +55 -0
- torchrl/envs/llm/transforms/browser.py +292 -0
- torchrl/envs/llm/transforms/dataloading.py +859 -0
- torchrl/envs/llm/transforms/format.py +73 -0
- torchrl/envs/llm/transforms/kl.py +1544 -0
- torchrl/envs/llm/transforms/policy_version.py +189 -0
- torchrl/envs/llm/transforms/reason.py +323 -0
- torchrl/envs/llm/transforms/tokenizer.py +321 -0
- torchrl/envs/llm/transforms/tools.py +1955 -0
- torchrl/envs/model_based/__init__.py +9 -0
- torchrl/envs/model_based/common.py +180 -0
- torchrl/envs/model_based/dreamer.py +112 -0
- torchrl/envs/transforms/__init__.py +147 -0
- torchrl/envs/transforms/functional.py +48 -0
- torchrl/envs/transforms/gym_transforms.py +203 -0
- torchrl/envs/transforms/module.py +341 -0
- torchrl/envs/transforms/r3m.py +372 -0
- torchrl/envs/transforms/ray_service.py +663 -0
- torchrl/envs/transforms/rb_transforms.py +214 -0
- torchrl/envs/transforms/transforms.py +11835 -0
- torchrl/envs/transforms/utils.py +94 -0
- torchrl/envs/transforms/vc1.py +307 -0
- torchrl/envs/transforms/vecnorm.py +845 -0
- torchrl/envs/transforms/vip.py +407 -0
- torchrl/envs/utils.py +1718 -0
- torchrl/envs/vec_envs.py +11 -0
- torchrl/modules/__init__.py +206 -0
- torchrl/modules/distributions/__init__.py +73 -0
- torchrl/modules/distributions/continuous.py +830 -0
- torchrl/modules/distributions/discrete.py +908 -0
- torchrl/modules/distributions/truncated_normal.py +187 -0
- torchrl/modules/distributions/utils.py +233 -0
- torchrl/modules/llm/__init__.py +62 -0
- torchrl/modules/llm/backends/__init__.py +65 -0
- torchrl/modules/llm/backends/vllm/__init__.py +94 -0
- torchrl/modules/llm/backends/vllm/_models.py +46 -0
- torchrl/modules/llm/backends/vllm/base.py +72 -0
- torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
- torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
- torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
- torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
- torchrl/modules/llm/policies/__init__.py +28 -0
- torchrl/modules/llm/policies/common.py +1809 -0
- torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
- torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
- torchrl/modules/llm/utils.py +23 -0
- torchrl/modules/mcts/__init__.py +21 -0
- torchrl/modules/mcts/scores.py +579 -0
- torchrl/modules/models/__init__.py +86 -0
- torchrl/modules/models/batchrenorm.py +119 -0
- torchrl/modules/models/decision_transformer.py +179 -0
- torchrl/modules/models/exploration.py +731 -0
- torchrl/modules/models/llm.py +156 -0
- torchrl/modules/models/model_based.py +596 -0
- torchrl/modules/models/models.py +1712 -0
- torchrl/modules/models/multiagent.py +1067 -0
- torchrl/modules/models/recipes/impala.py +185 -0
- torchrl/modules/models/utils.py +162 -0
- torchrl/modules/planners/__init__.py +10 -0
- torchrl/modules/planners/cem.py +228 -0
- torchrl/modules/planners/common.py +73 -0
- torchrl/modules/planners/mppi.py +265 -0
- torchrl/modules/tensordict_module/__init__.py +89 -0
- torchrl/modules/tensordict_module/actors.py +2457 -0
- torchrl/modules/tensordict_module/common.py +529 -0
- torchrl/modules/tensordict_module/exploration.py +814 -0
- torchrl/modules/tensordict_module/probabilistic.py +321 -0
- torchrl/modules/tensordict_module/rnn.py +1639 -0
- torchrl/modules/tensordict_module/sequence.py +132 -0
- torchrl/modules/tensordict_module/world_models.py +34 -0
- torchrl/modules/utils/__init__.py +38 -0
- torchrl/modules/utils/mappings.py +9 -0
- torchrl/modules/utils/utils.py +89 -0
- torchrl/objectives/__init__.py +78 -0
- torchrl/objectives/a2c.py +659 -0
- torchrl/objectives/common.py +753 -0
- torchrl/objectives/cql.py +1346 -0
- torchrl/objectives/crossq.py +710 -0
- torchrl/objectives/ddpg.py +453 -0
- torchrl/objectives/decision_transformer.py +371 -0
- torchrl/objectives/deprecated.py +516 -0
- torchrl/objectives/dqn.py +683 -0
- torchrl/objectives/dreamer.py +488 -0
- torchrl/objectives/functional.py +48 -0
- torchrl/objectives/gail.py +258 -0
- torchrl/objectives/iql.py +996 -0
- torchrl/objectives/llm/__init__.py +30 -0
- torchrl/objectives/llm/grpo.py +846 -0
- torchrl/objectives/llm/sft.py +482 -0
- torchrl/objectives/multiagent/__init__.py +8 -0
- torchrl/objectives/multiagent/qmixer.py +396 -0
- torchrl/objectives/ppo.py +1669 -0
- torchrl/objectives/redq.py +683 -0
- torchrl/objectives/reinforce.py +530 -0
- torchrl/objectives/sac.py +1580 -0
- torchrl/objectives/td3.py +570 -0
- torchrl/objectives/td3_bc.py +625 -0
- torchrl/objectives/utils.py +782 -0
- torchrl/objectives/value/__init__.py +28 -0
- torchrl/objectives/value/advantages.py +1956 -0
- torchrl/objectives/value/functional.py +1459 -0
- torchrl/objectives/value/utils.py +360 -0
- torchrl/record/__init__.py +17 -0
- torchrl/record/loggers/__init__.py +23 -0
- torchrl/record/loggers/common.py +48 -0
- torchrl/record/loggers/csv.py +226 -0
- torchrl/record/loggers/mlflow.py +142 -0
- torchrl/record/loggers/tensorboard.py +139 -0
- torchrl/record/loggers/trackio.py +163 -0
- torchrl/record/loggers/utils.py +78 -0
- torchrl/record/loggers/wandb.py +214 -0
- torchrl/record/recorder.py +554 -0
- torchrl/services/__init__.py +79 -0
- torchrl/services/base.py +109 -0
- torchrl/services/ray_service.py +453 -0
- torchrl/testing/__init__.py +107 -0
- torchrl/testing/assertions.py +179 -0
- torchrl/testing/dist_utils.py +122 -0
- torchrl/testing/env_creators.py +227 -0
- torchrl/testing/env_helper.py +35 -0
- torchrl/testing/gym_helpers.py +156 -0
- torchrl/testing/llm_mocks.py +119 -0
- torchrl/testing/mocking_classes.py +2720 -0
- torchrl/testing/modules.py +295 -0
- torchrl/testing/mp_helpers.py +15 -0
- torchrl/testing/ray_helpers.py +293 -0
- torchrl/testing/utils.py +190 -0
- torchrl/trainers/__init__.py +42 -0
- torchrl/trainers/algorithms/__init__.py +11 -0
- torchrl/trainers/algorithms/configs/__init__.py +705 -0
- torchrl/trainers/algorithms/configs/collectors.py +216 -0
- torchrl/trainers/algorithms/configs/common.py +41 -0
- torchrl/trainers/algorithms/configs/data.py +308 -0
- torchrl/trainers/algorithms/configs/envs.py +104 -0
- torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
- torchrl/trainers/algorithms/configs/logging.py +80 -0
- torchrl/trainers/algorithms/configs/modules.py +570 -0
- torchrl/trainers/algorithms/configs/objectives.py +177 -0
- torchrl/trainers/algorithms/configs/trainers.py +340 -0
- torchrl/trainers/algorithms/configs/transforms.py +955 -0
- torchrl/trainers/algorithms/configs/utils.py +252 -0
- torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
- torchrl/trainers/algorithms/configs/weight_update.py +159 -0
- torchrl/trainers/algorithms/ppo.py +373 -0
- torchrl/trainers/algorithms/sac.py +308 -0
- torchrl/trainers/helpers/__init__.py +40 -0
- torchrl/trainers/helpers/collectors.py +416 -0
- torchrl/trainers/helpers/envs.py +573 -0
- torchrl/trainers/helpers/logger.py +33 -0
- torchrl/trainers/helpers/losses.py +132 -0
- torchrl/trainers/helpers/models.py +658 -0
- torchrl/trainers/helpers/replay_buffer.py +59 -0
- torchrl/trainers/helpers/trainers.py +301 -0
- torchrl/trainers/trainers.py +2052 -0
- torchrl/weight_update/__init__.py +33 -0
- torchrl/weight_update/_distributed.py +749 -0
- torchrl/weight_update/_mp.py +624 -0
- torchrl/weight_update/_noupdate.py +102 -0
- torchrl/weight_update/_ray.py +1032 -0
- torchrl/weight_update/_rpc.py +284 -0
- torchrl/weight_update/_shared.py +891 -0
- torchrl/weight_update/llm/__init__.py +32 -0
- torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
- torchrl/weight_update/llm/vllm_nccl.py +710 -0
- torchrl/weight_update/utils.py +73 -0
- torchrl/weight_update/weight_sync_schemes.py +1244 -0
- torchrl-0.11.0.dist-info/METADATA +1308 -0
- torchrl-0.11.0.dist-info/RECORD +395 -0
- torchrl-0.11.0.dist-info/WHEEL +5 -0
- torchrl-0.11.0.dist-info/entry_points.txt +2 -0
- torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
- torchrl-0.11.0.dist-info/top_level.txt +7 -0
|
@@ -0,0 +1,955 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
#
|
|
3
|
+
# This source code is licensed under the MIT license found in the
|
|
4
|
+
# LICENSE file in the root directory of this source tree.
|
|
5
|
+
|
|
6
|
+
from __future__ import annotations
|
|
7
|
+
|
|
8
|
+
from dataclasses import dataclass
|
|
9
|
+
from typing import Any
|
|
10
|
+
|
|
11
|
+
from torchrl.trainers.algorithms.configs.common import ConfigBase
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
@dataclass
class TransformConfig(ConfigBase):
    """Base configuration class for transforms.

    Concrete transform configs subclass this, declare the fields of the
    transform they describe, and set a ``_target_`` dotted path naming the
    transform class to instantiate from this config.
    """

    def __post_init__(self) -> None:
        """Post-initialization hook for transform configurations.

        Intentionally a no-op; subclasses call ``super().__post_init__()``
        so shared validation can be added here in one place later.
        """
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
@dataclass
class NoopResetEnvConfig(TransformConfig):
    """Configuration for NoopResetEnv transform.

    Fields are forwarded to the class named by ``_target_`` when the
    config is instantiated.
    """

    # Number of no-op steps performed after a reset.
    noops: int = 30
    # Whether the no-op count is randomized — presumably sampled up to
    # ``noops``; confirm against NoopResetEnv.
    random: bool = True
    _target_: str = "torchrl.envs.transforms.transforms.NoopResetEnv"

    def __post_init__(self) -> None:
        """Post-initialization hook for NoopResetEnv configuration."""
        super().__post_init__()
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
@dataclass
class StepCounterConfig(TransformConfig):
    """Configuration for StepCounter transform.

    Fields are forwarded to the class named by ``_target_`` when the
    config is instantiated.
    """

    # Maximum number of steps before truncation; None disables the limit.
    max_steps: int | None = None
    # Entry under which the truncation flag is written.
    truncated_key: str | None = "truncated"
    # Entry under which the running step count is written.
    step_count_key: str | None = "step_count"
    # Whether the "done" flag should also be updated on truncation —
    # confirm exact semantics against StepCounter.
    update_done: bool = True
    _target_: str = "torchrl.envs.transforms.transforms.StepCounter"

    def __post_init__(self) -> None:
        """Post-initialization hook for StepCounter configuration."""
        super().__post_init__()
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
@dataclass
class ComposeConfig(TransformConfig):
    """Configuration for Compose transform.

    Aggregates a list of child transform configs that the target class
    applies in order.
    """

    # Child transform configs. Declared as None (not []) because dataclass
    # fields must not use a mutable default; normalized in __post_init__.
    transforms: list[Any] | None = None
    _target_: str = "torchrl.envs.transforms.transforms.Compose"

    def __post_init__(self) -> None:
        """Post-initialization hook for Compose configuration.

        Replaces a None ``transforms`` with a fresh empty list so each
        instance owns its own list.
        """
        super().__post_init__()
        if self.transforms is None:
            self.transforms = []
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
@dataclass
class DoubleToFloatConfig(TransformConfig):
    """Configuration for DoubleToFloat transform.

    Fields are forwarded to the class named by ``_target_`` when the
    config is instantiated.
    """

    # Entries read / written in the forward direction (None -> target's
    # defaults).
    in_keys: list[str] | None = None
    out_keys: list[str] | None = None
    # Entries read / written in the inverse direction.
    in_keys_inv: list[str] | None = None
    out_keys_inv: list[str] | None = None
    _target_: str = "torchrl.envs.transforms.transforms.DoubleToFloat"

    def __post_init__(self) -> None:
        """Post-initialization hook for DoubleToFloat configuration."""
        super().__post_init__()
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
@dataclass
class ToTensorImageConfig(TransformConfig):
    """Configuration for ToTensorImage transform.

    Fields are forwarded to the class named by ``_target_`` when the
    config is instantiated.
    """

    # Whether input pixels are integer-valued — presumably triggers
    # rescaling; confirm against ToTensorImage.
    from_int: bool | None = None
    # Whether to add a leading singleton dimension to the output.
    unsqueeze: bool = False
    # Output dtype given as a string (resolved by the target class).
    dtype: str | None = None
    # Entries read / written (None -> target's defaults).
    in_keys: list[str] | None = None
    out_keys: list[str] | None = None
    # Tolerate non-image-shaped inputs — confirm semantics against
    # ToTensorImage.
    shape_tolerant: bool = False
    _target_: str = "torchrl.envs.transforms.transforms.ToTensorImage"

    def __post_init__(self) -> None:
        """Post-initialization hook for ToTensorImage configuration."""
        super().__post_init__()
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
@dataclass
class ClipTransformConfig(TransformConfig):
    """Configuration for ClipTransform.

    Fields are forwarded to the class named by ``_target_`` when the
    config is instantiated.
    """

    # Entries read / written in the forward direction (None -> target's
    # defaults).
    in_keys: list[str] | None = None
    out_keys: list[str] | None = None
    # Entries read / written in the inverse direction.
    in_keys_inv: list[str] | None = None
    out_keys_inv: list[str] | None = None
    # Clipping bounds; None leaves that side unbounded.
    low: float | None = None
    high: float | None = None
    _target_: str = "torchrl.envs.transforms.transforms.ClipTransform"

    def __post_init__(self) -> None:
        """Post-initialization hook for ClipTransform configuration."""
        super().__post_init__()
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
@dataclass
class ResizeConfig(TransformConfig):
    """Configuration for Resize transform.

    Fields are forwarded to the class named by ``_target_`` when the
    config is instantiated.
    """

    # Target width and height in pixels (84x84 is the common Atari size).
    w: int = 84
    h: int = 84
    # Interpolation mode name passed through to the target class.
    interpolation: str = "bilinear"
    # Entries read / written (None -> target's defaults).
    in_keys: list[str] | None = None
    out_keys: list[str] | None = None
    _target_: str = "torchrl.envs.transforms.transforms.Resize"

    def __post_init__(self) -> None:
        """Post-initialization hook for Resize configuration."""
        super().__post_init__()
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
@dataclass
class CenterCropConfig(TransformConfig):
    """Configuration for CenterCrop transform.

    Fields are forwarded to the class named by ``_target_`` when the
    config is instantiated.
    """

    # Crop size in pixels.
    height: int = 84
    width: int = 84
    # Entries read / written (None -> target's defaults).
    in_keys: list[str] | None = None
    out_keys: list[str] | None = None
    _target_: str = "torchrl.envs.transforms.transforms.CenterCrop"

    def __post_init__(self) -> None:
        """Post-initialization hook for CenterCrop configuration."""
        super().__post_init__()
|
|
142
|
+
|
|
143
|
+
|
|
144
|
+
@dataclass
class FlattenObservationConfig(TransformConfig):
    """Configuration for FlattenObservation transform.

    Fields are forwarded to the class named by ``_target_`` when the
    config is instantiated.
    """

    # Entries read / written (None -> target's defaults).
    in_keys: list[str] | None = None
    out_keys: list[str] | None = None
    _target_: str = "torchrl.envs.transforms.transforms.FlattenObservation"

    def __post_init__(self) -> None:
        """Post-initialization hook for FlattenObservation configuration."""
        super().__post_init__()
|
|
155
|
+
|
|
156
|
+
|
|
157
|
+
@dataclass
class GrayScaleConfig(TransformConfig):
    """Configuration for GrayScale transform.

    Fields are forwarded to the class named by ``_target_`` when the
    config is instantiated.
    """

    # Entries read / written (None -> target's defaults).
    in_keys: list[str] | None = None
    out_keys: list[str] | None = None
    _target_: str = "torchrl.envs.transforms.transforms.GrayScale"

    def __post_init__(self) -> None:
        """Post-initialization hook for GrayScale configuration."""
        super().__post_init__()
|
|
168
|
+
|
|
169
|
+
|
|
170
|
+
@dataclass
class ObservationNormConfig(TransformConfig):
    """Configuration for ObservationNorm transform.

    Fields are forwarded to the class named by ``_target_`` when the
    config is instantiated.
    """

    # Affine normalization parameters (location/shift and scale).
    loc: float = 0.0
    scale: float = 1.0
    # Entries read / written (None -> target's defaults).
    in_keys: list[str] | None = None
    out_keys: list[str] | None = None
    # Whether loc/scale are interpreted as mean/std of a standard-normal
    # mapping — confirm exact formula against ObservationNorm.
    standard_normal: bool = False
    # Small constant, presumably guarding against division by zero.
    eps: float = 1e-8
    _target_: str = "torchrl.envs.transforms.transforms.ObservationNorm"

    def __post_init__(self) -> None:
        """Post-initialization hook for ObservationNorm configuration."""
        super().__post_init__()
|
|
185
|
+
|
|
186
|
+
|
|
187
|
+
@dataclass
class CatFramesConfig(TransformConfig):
    """Configuration for CatFrames transform.

    Fields are forwarded to the class named by ``_target_`` when the
    config is instantiated.
    """

    # Number of frames to stack.
    N: int = 4
    # Dimension along which frames are concatenated (-3 is the channel
    # dim for CHW images).
    dim: int = -3
    # Entries read / written (None -> target's defaults).
    in_keys: list[str] | None = None
    out_keys: list[str] | None = None
    _target_: str = "torchrl.envs.transforms.transforms.CatFrames"

    def __post_init__(self) -> None:
        """Post-initialization hook for CatFrames configuration."""
        super().__post_init__()
|
|
200
|
+
|
|
201
|
+
|
|
202
|
+
@dataclass
class RewardClippingConfig(TransformConfig):
    """Configuration for RewardClipping transform.

    Fields are forwarded to the class named by ``_target_`` when the
    config is instantiated.
    """

    # Clipping bounds for the reward; None leaves that side unbounded.
    clamp_min: float | None = None
    clamp_max: float | None = None
    # Entries read / written (None -> target's defaults).
    in_keys: list[str] | None = None
    out_keys: list[str] | None = None
    _target_: str = "torchrl.envs.transforms.transforms.RewardClipping"

    def __post_init__(self) -> None:
        """Post-initialization hook for RewardClipping configuration."""
        super().__post_init__()
|
|
215
|
+
|
|
216
|
+
|
|
217
|
+
@dataclass
class RewardScalingConfig(TransformConfig):
    """Configuration for RewardScaling transform.

    Fields mirror :class:`ObservationNormConfig` (same affine
    parameterization) and are forwarded to the class named by
    ``_target_`` when the config is instantiated.
    """

    # Affine scaling parameters (location/shift and scale).
    loc: float = 0.0
    scale: float = 1.0
    # Entries read / written (None -> target's defaults).
    in_keys: list[str] | None = None
    out_keys: list[str] | None = None
    # Whether loc/scale are interpreted as mean/std of a standard-normal
    # mapping — confirm exact formula against RewardScaling.
    standard_normal: bool = False
    # Small constant, presumably guarding against division by zero.
    eps: float = 1e-8
    _target_: str = "torchrl.envs.transforms.transforms.RewardScaling"

    def __post_init__(self) -> None:
        """Post-initialization hook for RewardScaling configuration."""
        super().__post_init__()
|
|
232
|
+
|
|
233
|
+
|
|
234
|
+
@dataclass
class VecNormConfig(TransformConfig):
    """Configuration for VecNorm transform.

    Fields are forwarded to the class named by ``_target_`` when the
    config is instantiated.
    """

    # Entries read / written (None -> target's defaults).
    in_keys: list[str] | None = None
    out_keys: list[str] | None = None
    # Decay rate of the running statistics — presumably an exponential
    # moving average; confirm against VecNorm.
    decay: float = 0.99
    # Small constant, presumably guarding against division by zero.
    eps: float = 1e-8
    _target_: str = "torchrl.envs.transforms.transforms.VecNorm"

    def __post_init__(self) -> None:
        """Post-initialization hook for VecNorm configuration."""
        super().__post_init__()
|
|
247
|
+
|
|
248
|
+
|
|
249
|
+
@dataclass
class FrameSkipTransformConfig(TransformConfig):
    """Configuration for FrameSkipTransform.

    Fields are forwarded to the class named by ``_target_`` when the
    config is instantiated.
    """

    # Number of environment steps collapsed into one transform step.
    frame_skip: int = 4
    # NOTE(review): FrameSkipTransform may not accept in_keys/out_keys —
    # verify these fields are actually consumed by the target class.
    in_keys: list[str] | None = None
    out_keys: list[str] | None = None
    _target_: str = "torchrl.envs.transforms.transforms.FrameSkipTransform"

    def __post_init__(self) -> None:
        """Post-initialization hook for FrameSkipTransform configuration."""
        super().__post_init__()
|
|
261
|
+
|
|
262
|
+
|
|
263
|
+
@dataclass
class EndOfLifeTransformConfig(TransformConfig):
    """Configuration for EndOfLifeTransform.

    Atari-oriented: ``eol_attribute`` points into the wrapped gym env's
    ALE interface. Fields are forwarded to the class named by
    ``_target_`` (note this one lives in ``gym_transforms``, not
    ``transforms``).
    """

    # Entry under which the end-of-life flag is written.
    eol_key: str = "end-of-life"
    # Entry under which the remaining-lives count is written.
    lives_key: str = "lives"
    # Done entry consulted/updated by the transform.
    done_key: str = "done"
    # Attribute path on the wrapped env used to read the life count.
    eol_attribute: str = "unwrapped.ale.lives"
    _target_: str = "torchrl.envs.transforms.gym_transforms.EndOfLifeTransform"

    def __post_init__(self) -> None:
        """Post-initialization hook for EndOfLifeTransform configuration."""
        super().__post_init__()
|
|
276
|
+
|
|
277
|
+
|
|
278
|
+
@dataclass
class MultiStepTransformConfig(TransformConfig):
    """Configuration for MultiStepTransform.

    Fields are forwarded to the class named by ``_target_`` (note this
    one lives in ``rb_transforms``, i.e. it is a replay-buffer
    transform).
    """

    # Number of steps aggregated into each multi-step return.
    n_steps: int = 3
    # Discount factor used in the n-step accumulation.
    gamma: float = 0.99
    # Entries read / written (None -> target's defaults).
    in_keys: list[str] | None = None
    out_keys: list[str] | None = None
    _target_: str = "torchrl.envs.transforms.rb_transforms.MultiStepTransform"

    def __post_init__(self) -> None:
        """Post-initialization hook for MultiStepTransform configuration."""
        super().__post_init__()
|
|
291
|
+
|
|
292
|
+
|
|
293
|
+
@dataclass
class TargetReturnConfig(TransformConfig):
    """Configuration for TargetReturn transform.

    Fields are forwarded to the class named by ``_target_`` when the
    config is instantiated.
    """

    # Return-to-go target value handed to the policy.
    target_return: float = 10.0
    # Update mode for the running target — "reduce" presumably
    # subtracts collected reward; confirm against TargetReturn.
    mode: str = "reduce"
    # Entries read / written (None -> target's defaults).
    in_keys: list[str] | None = None
    out_keys: list[str] | None = None
    # Reset entry used to reinitialize the target (None -> target's
    # default reset key).
    reset_key: str | None = None
    _target_: str = "torchrl.envs.transforms.transforms.TargetReturn"

    def __post_init__(self) -> None:
        """Post-initialization hook for TargetReturn configuration."""
        super().__post_init__()
|
|
307
|
+
|
|
308
|
+
|
|
309
|
+
@dataclass
class BinarizeRewardConfig(TransformConfig):
    """Configuration for BinarizeReward transform.

    Fields are forwarded to the class named by ``_target_`` when the
    config is instantiated.
    """

    # Entries read / written (None -> target's defaults).
    in_keys: list[str] | None = None
    out_keys: list[str] | None = None
    _target_: str = "torchrl.envs.transforms.transforms.BinarizeReward"

    def __post_init__(self) -> None:
        """Post-initialization hook for BinarizeReward configuration."""
        super().__post_init__()
|
|
320
|
+
|
|
321
|
+
|
|
322
|
+
@dataclass
class ActionDiscretizerConfig(TransformConfig):
    """Configuration for ActionDiscretizer transform.

    Fields are forwarded to the class named by ``_target_`` when the
    config is instantiated.
    """

    # Number of discrete bins per continuous action dimension.
    num_intervals: int = 10
    # Entry holding the (continuous) action to discretize.
    action_key: str = "action"
    # Entry the discretized action is written to (None -> target's
    # default, presumably in place of ``action_key``).
    out_action_key: str | None = None
    # Sampling strategy name for bin placement (None -> target's
    # default); confirm allowed values against ActionDiscretizer.
    sampling: str | None = None
    # Whether the discretized action is categorical (index) rather than
    # one-hot — confirm against ActionDiscretizer.
    categorical: bool = True
    _target_: str = "torchrl.envs.transforms.transforms.ActionDiscretizer"

    def __post_init__(self) -> None:
        """Post-initialization hook for ActionDiscretizer configuration."""
        super().__post_init__()
|
|
336
|
+
|
|
337
|
+
|
|
338
|
+
@dataclass
class AutoResetTransformConfig(TransformConfig):
    """Configuration for AutoResetTransform.

    Fields are forwarded to the class named by ``_target_`` when the
    config is instantiated.
    """

    # Whether terminal observations are replaced (None -> target's
    # default behavior).
    replace: bool | None = None
    # Fill values used per dtype when padding terminal entries;
    # ``fill_float`` is a string ("nan") resolved by the target class.
    fill_float: str = "nan"
    fill_int: int = -1
    fill_bool: bool = False
    _target_: str = "torchrl.envs.transforms.transforms.AutoResetTransform"

    def __post_init__(self) -> None:
        """Post-initialization hook for AutoResetTransform configuration."""
        super().__post_init__()
|
|
351
|
+
|
|
352
|
+
|
|
353
|
+
@dataclass
class BatchSizeTransformConfig(TransformConfig):
    """Options used to build a ``BatchSizeTransform`` via ``_target_``."""

    batch_size: list[int] | None = None  # target batch shape; None defers to the transform
    reshape_fn: Any = None  # optional callable controlling the reshape
    reset_func: Any = None  # optional callable producing reset data
    env_kwarg: bool = False  # whether the env is passed to ``reset_func`` as a kwarg — confirm against transform docs
    # Dotted import path of the class that gets instantiated from this config.
    _target_: str = "torchrl.envs.transforms.transforms.BatchSizeTransform"

    def __post_init__(self) -> None:
        """Delegate to the base config's post-init checks."""
        super().__post_init__()
|
|
366
|
+
|
|
367
|
+
|
|
368
|
+
@dataclass
class DeviceCastTransformConfig(TransformConfig):
    """Options used to build a ``DeviceCastTransform`` via ``_target_``."""

    device: str = "cpu"  # destination device for the cast
    in_keys: list[str] | None = None  # forward-pass input keys; None keeps transform defaults
    out_keys: list[str] | None = None  # forward-pass output keys; None keeps transform defaults
    in_keys_inv: list[str] | None = None  # inverse-pass input keys
    out_keys_inv: list[str] | None = None  # inverse-pass output keys
    # Dotted import path of the class that gets instantiated from this config.
    _target_: str = "torchrl.envs.transforms.transforms.DeviceCastTransform"

    def __post_init__(self) -> None:
        """Delegate to the base config's post-init checks."""
        super().__post_init__()
|
|
382
|
+
|
|
383
|
+
|
|
384
|
+
@dataclass
class DTypeCastTransformConfig(TransformConfig):
    """Options used to build a ``DTypeCastTransform`` via ``_target_``."""

    dtype: str = "torch.float32"  # destination dtype expressed as a string for config files
    in_keys: list[str] | None = None  # forward-pass input keys; None keeps transform defaults
    out_keys: list[str] | None = None  # forward-pass output keys; None keeps transform defaults
    in_keys_inv: list[str] | None = None  # inverse-pass input keys
    out_keys_inv: list[str] | None = None  # inverse-pass output keys
    # Dotted import path of the class that gets instantiated from this config.
    _target_: str = "torchrl.envs.transforms.transforms.DTypeCastTransform"

    def __post_init__(self) -> None:
        """Delegate to the base config's post-init checks."""
        super().__post_init__()
|
|
398
|
+
|
|
399
|
+
|
|
400
|
+
@dataclass
class UnsqueezeTransformConfig(TransformConfig):
    """Options used to build an ``UnsqueezeTransform`` via ``_target_``."""

    dim: int = 0  # dimension at which the singleton axis is inserted
    in_keys: list[str] | None = None  # keys to read; None keeps transform defaults
    out_keys: list[str] | None = None  # keys to write; None keeps transform defaults
    # Dotted import path of the class that gets instantiated from this config.
    _target_: str = "torchrl.envs.transforms.transforms.UnsqueezeTransform"

    def __post_init__(self) -> None:
        """Delegate to the base config's post-init checks."""
        super().__post_init__()
|
|
412
|
+
|
|
413
|
+
|
|
414
|
+
@dataclass
class SqueezeTransformConfig(TransformConfig):
    """Options used to build a ``SqueezeTransform`` via ``_target_``."""

    dim: int = 0  # dimension whose singleton axis is removed
    in_keys: list[str] | None = None  # keys to read; None keeps transform defaults
    out_keys: list[str] | None = None  # keys to write; None keeps transform defaults
    # Dotted import path of the class that gets instantiated from this config.
    _target_: str = "torchrl.envs.transforms.transforms.SqueezeTransform"

    def __post_init__(self) -> None:
        """Delegate to the base config's post-init checks."""
        super().__post_init__()
|
|
426
|
+
|
|
427
|
+
|
|
428
|
+
@dataclass
class PermuteTransformConfig(TransformConfig):
    """Options used to build a ``PermuteTransform`` via ``_target_``."""

    dims: list[int] | None = None  # permutation order; defaults to [0, 2, 1] after init
    in_keys: list[str] | None = None  # keys to read; None keeps transform defaults
    out_keys: list[str] | None = None  # keys to write; None keeps transform defaults
    # Dotted import path of the class that gets instantiated from this config.
    _target_: str = "torchrl.envs.transforms.transforms.PermuteTransform"

    def __post_init__(self) -> None:
        """Run base checks, then materialize the default permutation."""
        super().__post_init__()
        # Mutable default cannot live on the field itself, so it is filled in here.
        self.dims = [0, 2, 1] if self.dims is None else self.dims
|
|
442
|
+
|
|
443
|
+
|
|
444
|
+
@dataclass
class CatTensorsConfig(TransformConfig):
    """Options used to build a ``CatTensors`` transform via ``_target_``."""

    dim: int = -1  # dimension along which the entries are concatenated
    in_keys: list[str] | None = None  # keys to concatenate; None keeps transform defaults
    out_keys: list[str] | None = None  # destination keys; None keeps transform defaults
    # Dotted import path of the class that gets instantiated from this config.
    _target_: str = "torchrl.envs.transforms.transforms.CatTensors"

    def __post_init__(self) -> None:
        """Delegate to the base config's post-init checks."""
        super().__post_init__()
|
|
456
|
+
|
|
457
|
+
|
|
458
|
+
@dataclass
class StackConfig(TransformConfig):
    """Options used to build a ``Stack`` transform via ``_target_``."""

    dim: int = 0  # dimension along which entries are stacked
    in_keys: list[str] | None = None  # keys to stack; None keeps transform defaults
    out_keys: list[str] | None = None  # destination keys; None keeps transform defaults
    # Dotted import path of the class that gets instantiated from this config.
    _target_: str = "torchrl.envs.transforms.transforms.Stack"

    def __post_init__(self) -> None:
        """Delegate to the base config's post-init checks."""
        super().__post_init__()
|
|
470
|
+
|
|
471
|
+
|
|
472
|
+
@dataclass
class DiscreteActionProjectionConfig(TransformConfig):
    """Options used to build a ``DiscreteActionProjection`` via ``_target_``."""

    num_actions: int = 4  # size of the projected discrete action space
    in_keys: list[str] | None = None  # keys to read; None keeps transform defaults
    out_keys: list[str] | None = None  # keys to write; None keeps transform defaults
    # Dotted import path of the class that gets instantiated from this config.
    _target_: str = "torchrl.envs.transforms.transforms.DiscreteActionProjection"

    def __post_init__(self) -> None:
        """Delegate to the base config's post-init checks."""
        super().__post_init__()
|
|
484
|
+
|
|
485
|
+
|
|
486
|
+
@dataclass
class TensorDictPrimerConfig(TransformConfig):
    """Options used to build a ``TensorDictPrimer`` via ``_target_``."""

    primer_spec: Any = None  # spec describing the entries to prime; forwarded as-is
    in_keys: list[str] | None = None  # keys to read; None keeps transform defaults
    out_keys: list[str] | None = None  # keys to write; None keeps transform defaults
    # Dotted import path of the class that gets instantiated from this config.
    _target_: str = "torchrl.envs.transforms.transforms.TensorDictPrimer"

    def __post_init__(self) -> None:
        """Delegate to the base config's post-init checks."""
        super().__post_init__()
|
|
498
|
+
|
|
499
|
+
|
|
500
|
+
@dataclass
class PinMemoryTransformConfig(TransformConfig):
    """Options used to build a ``PinMemoryTransform`` via ``_target_``."""

    in_keys: list[str] | None = None  # keys to read; None keeps transform defaults
    out_keys: list[str] | None = None  # keys to write; None keeps transform defaults
    # Dotted import path of the class that gets instantiated from this config.
    _target_: str = "torchrl.envs.transforms.transforms.PinMemoryTransform"

    def __post_init__(self) -> None:
        """Delegate to the base config's post-init checks."""
        super().__post_init__()
|
|
511
|
+
|
|
512
|
+
|
|
513
|
+
@dataclass
class RewardSumConfig(TransformConfig):
    """Options used to build a ``RewardSum`` transform via ``_target_``."""

    in_keys: list[str] | None = None  # reward keys to accumulate; None keeps transform defaults
    out_keys: list[str] | None = None  # destination keys; None keeps transform defaults
    # Dotted import path of the class that gets instantiated from this config.
    _target_: str = "torchrl.envs.transforms.transforms.RewardSum"

    def __post_init__(self) -> None:
        """Delegate to the base config's post-init checks."""
        super().__post_init__()
|
|
524
|
+
|
|
525
|
+
|
|
526
|
+
@dataclass
class ExcludeTransformConfig(TransformConfig):
    """Options used to build an ``ExcludeTransform`` via ``_target_``."""

    exclude_keys: list[str] | None = None  # keys to drop; normalized to [] after init
    # Dotted import path of the class that gets instantiated from this config.
    _target_: str = "torchrl.envs.transforms.transforms.ExcludeTransform"

    def __post_init__(self) -> None:
        """Run base checks, then normalize the key list."""
        super().__post_init__()
        # Mutable default cannot live on the field itself, so it is filled in here.
        self.exclude_keys = [] if self.exclude_keys is None else self.exclude_keys
|
|
538
|
+
|
|
539
|
+
|
|
540
|
+
@dataclass
class SelectTransformConfig(TransformConfig):
    """Options used to build a ``SelectTransform`` via ``_target_``."""

    include_keys: list[str] | None = None  # keys to keep; normalized to [] after init
    # Dotted import path of the class that gets instantiated from this config.
    _target_: str = "torchrl.envs.transforms.transforms.SelectTransform"

    def __post_init__(self) -> None:
        """Run base checks, then normalize the key list."""
        super().__post_init__()
        # Mutable default cannot live on the field itself, so it is filled in here.
        self.include_keys = [] if self.include_keys is None else self.include_keys
|
|
552
|
+
|
|
553
|
+
|
|
554
|
+
@dataclass
class TimeMaxPoolConfig(TransformConfig):
    """Options used to build a ``TimeMaxPool`` transform via ``_target_``."""

    dim: int = -1  # dimension over which pooling is applied
    in_keys: list[str] | None = None  # keys to pool; None keeps transform defaults
    out_keys: list[str] | None = None  # destination keys; None keeps transform defaults
    # Dotted import path of the class that gets instantiated from this config.
    _target_: str = "torchrl.envs.transforms.transforms.TimeMaxPool"

    def __post_init__(self) -> None:
        """Delegate to the base config's post-init checks."""
        super().__post_init__()
|
|
566
|
+
|
|
567
|
+
|
|
568
|
+
@dataclass
class RandomCropTensorDictConfig(TransformConfig):
    """Options used to build a ``RandomCropTensorDict`` via ``_target_``."""

    crop_size: list[int] | None = None  # crop extent; defaults to [84, 84] after init
    in_keys: list[str] | None = None  # keys to crop; None keeps transform defaults
    out_keys: list[str] | None = None  # destination keys; None keeps transform defaults
    # Dotted import path of the class that gets instantiated from this config.
    _target_: str = "torchrl.envs.transforms.transforms.RandomCropTensorDict"

    def __post_init__(self) -> None:
        """Run base checks, then materialize the default crop size."""
        super().__post_init__()
        # Mutable default cannot live on the field itself, so it is filled in here.
        self.crop_size = [84, 84] if self.crop_size is None else self.crop_size
|
|
582
|
+
|
|
583
|
+
|
|
584
|
+
@dataclass
class InitTrackerConfig(TransformConfig):
    """Options used to build an ``InitTracker`` transform via ``_target_``."""

    init_key: str | None = None  # key under which the init flag is stored; None keeps transform defaults
    # Dotted import path of the class that gets instantiated from this config.
    _target_: str = "torchrl.envs.transforms.transforms.InitTracker"

    def __post_init__(self) -> None:
        """Delegate to the base config's post-init checks."""
        super().__post_init__()
|
|
594
|
+
|
|
595
|
+
|
|
596
|
+
@dataclass
class RenameTransformConfig(TransformConfig):
    """Options used to build a ``RenameTransform`` via ``_target_``."""

    key_mapping: dict[str, str] | None = None  # old-key -> new-key map; normalized to {} after init
    # Dotted import path of the class that gets instantiated from this config.
    _target_: str = "torchrl.envs.transforms.transforms.RenameTransform"

    def __post_init__(self) -> None:
        """Run base checks, then normalize the mapping."""
        super().__post_init__()
        # Mutable default cannot live on the field itself, so it is filled in here.
        self.key_mapping = {} if self.key_mapping is None else self.key_mapping
|
|
608
|
+
|
|
609
|
+
|
|
610
|
+
@dataclass
class Reward2GoTransformConfig(TransformConfig):
    """Options used to build a ``Reward2GoTransform`` via ``_target_``."""

    gamma: float = 0.99  # discount factor used for the reward-to-go computation
    in_keys: list[str] | None = None  # keys to read; None keeps transform defaults
    out_keys: list[str] | None = None  # keys to write; None keeps transform defaults
    # Dotted import path of the class that gets instantiated from this config.
    _target_: str = "torchrl.envs.transforms.transforms.Reward2GoTransform"

    def __post_init__(self) -> None:
        """Delegate to the base config's post-init checks."""
        super().__post_init__()
|
|
622
|
+
|
|
623
|
+
|
|
624
|
+
@dataclass
class ActionMaskConfig(TransformConfig):
    """Options used to build an ``ActionMask`` transform via ``_target_``."""

    mask_key: str = "action_mask"  # key under which the action mask is expected
    in_keys: list[str] | None = None  # keys to read; None keeps transform defaults
    out_keys: list[str] | None = None  # keys to write; None keeps transform defaults
    # Dotted import path of the class that gets instantiated from this config.
    _target_: str = "torchrl.envs.transforms.transforms.ActionMask"

    def __post_init__(self) -> None:
        """Delegate to the base config's post-init checks."""
        super().__post_init__()
|
|
636
|
+
|
|
637
|
+
|
|
638
|
+
@dataclass
class VecGymEnvTransformConfig(TransformConfig):
    """Options used to build a ``VecGymEnvTransform`` via ``_target_``."""

    in_keys: list[str] | None = None  # keys to read; None keeps transform defaults
    out_keys: list[str] | None = None  # keys to write; None keeps transform defaults
    # Dotted import path of the class that gets instantiated from this config.
    _target_: str = "torchrl.envs.transforms.transforms.VecGymEnvTransform"

    def __post_init__(self) -> None:
        """Delegate to the base config's post-init checks."""
        super().__post_init__()
|
|
649
|
+
|
|
650
|
+
|
|
651
|
+
@dataclass
class BurnInTransformConfig(TransformConfig):
    """Options used to build a ``BurnInTransform`` via ``_target_``."""

    burn_in: int = 10  # number of burn-in steps forwarded to the transform
    in_keys: list[str] | None = None  # keys to read; None keeps transform defaults
    out_keys: list[str] | None = None  # keys to write; None keeps transform defaults
    # Dotted import path of the class that gets instantiated from this config.
    _target_: str = "torchrl.envs.transforms.transforms.BurnInTransform"

    def __post_init__(self) -> None:
        """Delegate to the base config's post-init checks."""
        super().__post_init__()
|
|
663
|
+
|
|
664
|
+
|
|
665
|
+
@dataclass
class SignTransformConfig(TransformConfig):
    """Options used to build a ``SignTransform`` via ``_target_``."""

    in_keys: list[str] | None = None  # keys to read; None keeps transform defaults
    out_keys: list[str] | None = None  # keys to write; None keeps transform defaults
    # Dotted import path of the class that gets instantiated from this config.
    _target_: str = "torchrl.envs.transforms.transforms.SignTransform"

    def __post_init__(self) -> None:
        """Delegate to the base config's post-init checks."""
        super().__post_init__()
|
|
676
|
+
|
|
677
|
+
|
|
678
|
+
@dataclass
class RemoveEmptySpecsConfig(TransformConfig):
    """Options used to build a ``RemoveEmptySpecs`` transform via ``_target_``."""

    # Dotted import path of the class that gets instantiated from this config.
    _target_: str = "torchrl.envs.transforms.transforms.RemoveEmptySpecs"

    def __post_init__(self) -> None:
        """Delegate to the base config's post-init checks."""
        super().__post_init__()
|
|
687
|
+
|
|
688
|
+
|
|
689
|
+
@dataclass
class TrajCounterConfig(TransformConfig):
    """Options used to build a ``TrajCounter`` transform via ``_target_``."""

    out_key: str = "traj_count"  # key under which the trajectory count is written
    repeats: int | None = None  # forwarded as-is; None keeps transform defaults
    # Dotted import path of the class that gets instantiated from this config.
    _target_: str = "torchrl.envs.transforms.transforms.TrajCounter"

    def __post_init__(self) -> None:
        """Delegate to the base config's post-init checks."""
        super().__post_init__()
|
|
700
|
+
|
|
701
|
+
|
|
702
|
+
@dataclass
class LineariseRewardsConfig(TransformConfig):
    """Options used to build a ``LineariseRewards`` transform via ``_target_``."""

    in_keys: list[str] | None = None  # reward keys to combine; normalized to [] after init
    out_keys: list[str] | None = None  # destination keys; None keeps transform defaults
    weights: list[float] | None = None  # per-reward weights; None keeps transform defaults
    # Dotted import path of the class that gets instantiated from this config.
    _target_: str = "torchrl.envs.transforms.transforms.LineariseRewards"

    def __post_init__(self) -> None:
        """Run base checks, then normalize the input key list."""
        super().__post_init__()
        # Mutable default cannot live on the field itself, so it is filled in here.
        self.in_keys = [] if self.in_keys is None else self.in_keys
|
|
716
|
+
|
|
717
|
+
|
|
718
|
+
@dataclass
class ConditionalSkipConfig(TransformConfig):
    """Options used to build a ``ConditionalSkip`` transform via ``_target_``."""

    cond: Any = None  # skip condition (callable or similar); forwarded as-is
    # Dotted import path of the class that gets instantiated from this config.
    _target_: str = "torchrl.envs.transforms.transforms.ConditionalSkip"

    def __post_init__(self) -> None:
        """Delegate to the base config's post-init checks."""
        super().__post_init__()
|
|
728
|
+
|
|
729
|
+
|
|
730
|
+
@dataclass
class MultiActionConfig(TransformConfig):
    """Options used to build a ``MultiAction`` transform via ``_target_``."""

    dim: int = 1  # dimension along which multiple actions are laid out
    stack_rewards: bool = True  # whether per-action rewards are stacked
    stack_observations: bool = False  # whether per-action observations are stacked
    # Dotted import path of the class that gets instantiated from this config.
    _target_: str = "torchrl.envs.transforms.transforms.MultiAction"

    def __post_init__(self) -> None:
        """Delegate to the base config's post-init checks."""
        super().__post_init__()
|
|
742
|
+
|
|
743
|
+
|
|
744
|
+
@dataclass
class TimerConfig(TransformConfig):
    """Options used to build a ``Timer`` transform via ``_target_``."""

    out_keys: list[str] | None = None  # keys to write; None keeps transform defaults
    time_key: str = "time"  # key under which the timing information is stored
    # Dotted import path of the class that gets instantiated from this config.
    _target_: str = "torchrl.envs.transforms.transforms.Timer"

    def __post_init__(self) -> None:
        """Delegate to the base config's post-init checks."""
        super().__post_init__()
|
|
755
|
+
|
|
756
|
+
|
|
757
|
+
@dataclass
class ConditionalPolicySwitchConfig(TransformConfig):
    """Options used to build a ``ConditionalPolicySwitch`` via ``_target_``."""

    policy: Any = None  # alternate policy to switch to; forwarded as-is
    condition: Any = None  # switch condition (callable or similar); forwarded as-is
    # Dotted import path of the class that gets instantiated from this config.
    _target_: str = "torchrl.envs.transforms.transforms.ConditionalPolicySwitch"

    def __post_init__(self) -> None:
        """Delegate to the base config's post-init checks."""
        super().__post_init__()
|
|
768
|
+
|
|
769
|
+
|
|
770
|
+
@dataclass
class KLRewardTransformConfig(TransformConfig):
    """Options used to build a ``KLRewardTransform`` via ``_target_``.

    Note the target lives in the ``llm`` transform module, unlike most
    configs in this file.
    """

    in_keys: list[str] | None = None  # keys to read; None keeps transform defaults
    out_keys: list[str] | None = None  # keys to write; None keeps transform defaults
    # Dotted import path of the class that gets instantiated from this config.
    _target_: str = "torchrl.envs.transforms.llm.KLRewardTransform"

    def __post_init__(self) -> None:
        """Delegate to the base config's post-init checks."""
        super().__post_init__()
|
|
781
|
+
|
|
782
|
+
|
|
783
|
+
@dataclass
class R3MTransformConfig(TransformConfig):
    """Options used to build an ``R3MTransform`` via ``_target_``."""

    in_keys: list[str] | None = None  # image keys to embed; None keeps transform defaults
    out_keys: list[str] | None = None  # destination keys; None keeps transform defaults
    model_name: str = "resnet18"  # backbone identifier forwarded to the transform
    device: str = "cpu"  # device on which the model runs
    # Dotted import path of the class that gets instantiated from this config.
    _target_: str = "torchrl.envs.transforms.r3m.R3MTransform"

    def __post_init__(self) -> None:
        """Delegate to the base config's post-init checks."""
        super().__post_init__()
|
|
796
|
+
|
|
797
|
+
|
|
798
|
+
@dataclass
class VC1TransformConfig(TransformConfig):
    """Options used to build a ``VC1Transform`` via ``_target_``."""

    in_keys: list[str] | None = None  # image keys to embed; None keeps transform defaults
    out_keys: list[str] | None = None  # destination keys; None keeps transform defaults
    device: str = "cpu"  # device on which the model runs
    # Dotted import path of the class that gets instantiated from this config.
    _target_: str = "torchrl.envs.transforms.vc1.VC1Transform"

    def __post_init__(self) -> None:
        """Delegate to the base config's post-init checks."""
        super().__post_init__()
|
|
810
|
+
|
|
811
|
+
|
|
812
|
+
@dataclass
class VIPTransformConfig(TransformConfig):
    """Options used to build a ``VIPTransform`` via ``_target_``."""

    in_keys: list[str] | None = None  # image keys to embed; None keeps transform defaults
    out_keys: list[str] | None = None  # destination keys; None keeps transform defaults
    device: str = "cpu"  # device on which the model runs
    # Dotted import path of the class that gets instantiated from this config.
    _target_: str = "torchrl.envs.transforms.vip.VIPTransform"

    def __post_init__(self) -> None:
        """Delegate to the base config's post-init checks."""
        super().__post_init__()
|
|
824
|
+
|
|
825
|
+
|
|
826
|
+
@dataclass
class VIPRewardTransformConfig(TransformConfig):
    """Options used to build a ``VIPRewardTransform`` via ``_target_``."""

    in_keys: list[str] | None = None  # keys to read; None keeps transform defaults
    out_keys: list[str] | None = None  # keys to write; None keeps transform defaults
    device: str = "cpu"  # device on which the model runs
    # Dotted import path of the class that gets instantiated from this config.
    _target_: str = "torchrl.envs.transforms.vip.VIPRewardTransform"

    def __post_init__(self) -> None:
        """Delegate to the base config's post-init checks."""
        super().__post_init__()
|
|
838
|
+
|
|
839
|
+
|
|
840
|
+
@dataclass
class VecNormV2Config(TransformConfig):
    """Options used to build a ``VecNormV2`` transform via ``_target_``."""

    in_keys: list[str] | None = None  # keys to normalize; None keeps transform defaults
    out_keys: list[str] | None = None  # destination keys; None keeps transform defaults
    decay: float = 0.99  # decay used by the running statistics
    eps: float = 1e-8  # numerical-stability epsilon
    # Dotted import path of the class that gets instantiated from this config.
    _target_: str = "torchrl.envs.transforms.vecnorm.VecNormV2"

    def __post_init__(self) -> None:
        """Delegate to the base config's post-init checks."""
        super().__post_init__()
|
|
853
|
+
|
|
854
|
+
|
|
855
|
+
@dataclass
class FiniteTensorDictCheckConfig(TransformConfig):
    """Options used to build a ``FiniteTensorDictCheck`` via ``_target_``."""

    in_keys: list[str] | None = None  # keys to check; None keeps transform defaults
    out_keys: list[str] | None = None  # keys to write; None keeps transform defaults
    # Dotted import path of the class that gets instantiated from this config.
    _target_: str = "torchrl.envs.transforms.transforms.FiniteTensorDictCheck"

    def __post_init__(self) -> None:
        """Delegate to the base config's post-init checks."""
        super().__post_init__()
|
|
866
|
+
|
|
867
|
+
|
|
868
|
+
@dataclass
class UnaryTransformConfig(TransformConfig):
    """Options used to build a ``UnaryTransform`` via ``_target_``."""

    fn: Any = None  # callable applied by the transform; forwarded as-is
    in_keys: list[str] | None = None  # keys to read; None keeps transform defaults
    out_keys: list[str] | None = None  # keys to write; None keeps transform defaults
    # Dotted import path of the class that gets instantiated from this config.
    _target_: str = "torchrl.envs.transforms.transforms.UnaryTransform"

    def __post_init__(self) -> None:
        """Delegate to the base config's post-init checks."""
        super().__post_init__()
|
|
880
|
+
|
|
881
|
+
|
|
882
|
+
@dataclass
class HashConfig(TransformConfig):
    """Options used to build a ``Hash`` transform via ``_target_``."""

    in_keys: list[str] | None = None  # keys to hash; None keeps transform defaults
    out_keys: list[str] | None = None  # destination keys; None keeps transform defaults
    # Dotted import path of the class that gets instantiated from this config.
    _target_: str = "torchrl.envs.transforms.transforms.Hash"

    def __post_init__(self) -> None:
        """Delegate to the base config's post-init checks."""
        super().__post_init__()
|
|
893
|
+
|
|
894
|
+
|
|
895
|
+
@dataclass
class TokenizerConfig(TransformConfig):
    """Options used to build a ``Tokenizer`` transform via ``_target_``."""

    vocab_size: int = 1000  # vocabulary size forwarded to the transform
    in_keys: list[str] | None = None  # text keys to tokenize; None keeps transform defaults
    out_keys: list[str] | None = None  # destination keys; None keeps transform defaults
    # Dotted import path of the class that gets instantiated from this config.
    _target_: str = "torchrl.envs.transforms.transforms.Tokenizer"

    def __post_init__(self) -> None:
        """Delegate to the base config's post-init checks."""
        super().__post_init__()
|
|
907
|
+
|
|
908
|
+
|
|
909
|
+
@dataclass
class CropConfig(TransformConfig):
    """Options used to build a ``Crop`` transform via ``_target_``."""

    top: int = 0  # vertical offset of the crop window
    left: int = 0  # horizontal offset of the crop window
    height: int = 84  # crop height in pixels
    width: int = 84  # crop width in pixels
    in_keys: list[str] | None = None  # image keys to crop; None keeps transform defaults
    out_keys: list[str] | None = None  # destination keys; None keeps transform defaults
    # Dotted import path of the class that gets instantiated from this config.
    _target_: str = "torchrl.envs.transforms.transforms.Crop"

    def __post_init__(self) -> None:
        """Delegate to the base config's post-init checks."""
        super().__post_init__()
|
|
924
|
+
|
|
925
|
+
|
|
926
|
+
@dataclass
class FlattenTensorDictConfig(TransformConfig):
    """Options used to build a ``FlattenTensorDict`` transform via ``_target_``.

    Per the original description, the instantiated transform flattens the
    tensordict's batch dimension on the inverse pass, which suits replay
    buffers that store data with a flat batch structure.
    """

    # Dotted import path of the class that gets instantiated from this config.
    _target_: str = "torchrl.envs.transforms.transforms.FlattenTensorDict"

    def __post_init__(self) -> None:
        """Delegate to the base config's post-init checks."""
        super().__post_init__()
|
|
940
|
+
|
|
941
|
+
|
|
942
|
+
@dataclass
|
|
943
|
+
class ModuleTransformConfig(TransformConfig):
|
|
944
|
+
"""Configuration for ModuleTransform."""
|
|
945
|
+
|
|
946
|
+
module: Any = None
|
|
947
|
+
device: Any = None
|
|
948
|
+
no_grad: bool = False
|
|
949
|
+
inverse: bool = False
|
|
950
|
+
_target_: str = "torchrl.envs.transforms.module.ModuleTransform"
|
|
951
|
+
_partial_: bool = False
|
|
952
|
+
|
|
953
|
+
def __post_init__(self) -> None:
|
|
954
|
+
"""Post-initialization hook for ModuleTransform configuration."""
|
|
955
|
+
super().__post_init__()
|