torchrl 0.11.0__cp314-cp314-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/benchmark_batched_envs.py +104 -0
- benchmarks/conftest.py +91 -0
- benchmarks/ecosystem/gym_env_throughput.py +321 -0
- benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
- benchmarks/requirements.txt +7 -0
- benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
- benchmarks/test_collectors_benchmark.py +240 -0
- benchmarks/test_compressed_storage_benchmark.py +145 -0
- benchmarks/test_envs_benchmark.py +133 -0
- benchmarks/test_llm.py +101 -0
- benchmarks/test_non_tensor_env_benchmark.py +70 -0
- benchmarks/test_objectives_benchmarks.py +1199 -0
- benchmarks/test_replaybuffer_benchmark.py +254 -0
- sota-check/README.md +35 -0
- sota-implementations/README.md +142 -0
- sota-implementations/a2c/README.md +39 -0
- sota-implementations/a2c/a2c_atari.py +291 -0
- sota-implementations/a2c/a2c_mujoco.py +273 -0
- sota-implementations/a2c/utils_atari.py +240 -0
- sota-implementations/a2c/utils_mujoco.py +160 -0
- sota-implementations/bandits/README.md +7 -0
- sota-implementations/bandits/dqn.py +126 -0
- sota-implementations/cql/cql_offline.py +198 -0
- sota-implementations/cql/cql_online.py +249 -0
- sota-implementations/cql/discrete_cql_offline.py +180 -0
- sota-implementations/cql/discrete_cql_online.py +227 -0
- sota-implementations/cql/utils.py +471 -0
- sota-implementations/crossq/crossq.py +271 -0
- sota-implementations/crossq/utils.py +320 -0
- sota-implementations/ddpg/ddpg.py +231 -0
- sota-implementations/ddpg/utils.py +325 -0
- sota-implementations/decision_transformer/dt.py +163 -0
- sota-implementations/decision_transformer/lamb.py +167 -0
- sota-implementations/decision_transformer/online_dt.py +178 -0
- sota-implementations/decision_transformer/utils.py +562 -0
- sota-implementations/discrete_sac/discrete_sac.py +243 -0
- sota-implementations/discrete_sac/utils.py +324 -0
- sota-implementations/dqn/README.md +30 -0
- sota-implementations/dqn/dqn_atari.py +272 -0
- sota-implementations/dqn/dqn_cartpole.py +236 -0
- sota-implementations/dqn/utils_atari.py +132 -0
- sota-implementations/dqn/utils_cartpole.py +90 -0
- sota-implementations/dreamer/README.md +129 -0
- sota-implementations/dreamer/dreamer.py +586 -0
- sota-implementations/dreamer/dreamer_utils.py +1107 -0
- sota-implementations/expert-iteration/README.md +352 -0
- sota-implementations/expert-iteration/ei_utils.py +770 -0
- sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
- sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
- sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
- sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
- sota-implementations/gail/gail.py +327 -0
- sota-implementations/gail/gail_utils.py +68 -0
- sota-implementations/gail/ppo_utils.py +157 -0
- sota-implementations/grpo/README.md +273 -0
- sota-implementations/grpo/grpo-async.py +437 -0
- sota-implementations/grpo/grpo-sync.py +435 -0
- sota-implementations/grpo/grpo_utils.py +843 -0
- sota-implementations/grpo/requirements_gsm8k.txt +11 -0
- sota-implementations/grpo/requirements_ifeval.txt +16 -0
- sota-implementations/impala/README.md +33 -0
- sota-implementations/impala/impala_multi_node_ray.py +292 -0
- sota-implementations/impala/impala_multi_node_submitit.py +284 -0
- sota-implementations/impala/impala_single_node.py +261 -0
- sota-implementations/impala/utils.py +184 -0
- sota-implementations/iql/discrete_iql.py +230 -0
- sota-implementations/iql/iql_offline.py +164 -0
- sota-implementations/iql/iql_online.py +225 -0
- sota-implementations/iql/utils.py +437 -0
- sota-implementations/multiagent/README.md +74 -0
- sota-implementations/multiagent/iql.py +237 -0
- sota-implementations/multiagent/maddpg_iddpg.py +266 -0
- sota-implementations/multiagent/mappo_ippo.py +267 -0
- sota-implementations/multiagent/qmix_vdn.py +271 -0
- sota-implementations/multiagent/sac.py +337 -0
- sota-implementations/multiagent/utils/__init__.py +4 -0
- sota-implementations/multiagent/utils/logging.py +151 -0
- sota-implementations/multiagent/utils/utils.py +43 -0
- sota-implementations/ppo/README.md +29 -0
- sota-implementations/ppo/ppo_atari.py +305 -0
- sota-implementations/ppo/ppo_mujoco.py +293 -0
- sota-implementations/ppo/utils_atari.py +238 -0
- sota-implementations/ppo/utils_mujoco.py +152 -0
- sota-implementations/ppo_trainer/train.py +21 -0
- sota-implementations/redq/README.md +7 -0
- sota-implementations/redq/redq.py +199 -0
- sota-implementations/redq/utils.py +1060 -0
- sota-implementations/sac/sac-async.py +266 -0
- sota-implementations/sac/sac.py +239 -0
- sota-implementations/sac/utils.py +381 -0
- sota-implementations/sac_trainer/train.py +16 -0
- sota-implementations/td3/td3.py +254 -0
- sota-implementations/td3/utils.py +319 -0
- sota-implementations/td3_bc/td3_bc.py +177 -0
- sota-implementations/td3_bc/utils.py +251 -0
- torchrl/__init__.py +144 -0
- torchrl/_extension.py +74 -0
- torchrl/_torchrl.cp314-win_amd64.pyd +0 -0
- torchrl/_utils.py +1431 -0
- torchrl/collectors/__init__.py +48 -0
- torchrl/collectors/_base.py +1058 -0
- torchrl/collectors/_constants.py +88 -0
- torchrl/collectors/_multi_async.py +324 -0
- torchrl/collectors/_multi_base.py +1805 -0
- torchrl/collectors/_multi_sync.py +464 -0
- torchrl/collectors/_runner.py +581 -0
- torchrl/collectors/_single.py +2009 -0
- torchrl/collectors/_single_async.py +259 -0
- torchrl/collectors/collectors.py +62 -0
- torchrl/collectors/distributed/__init__.py +32 -0
- torchrl/collectors/distributed/default_configs.py +133 -0
- torchrl/collectors/distributed/generic.py +1306 -0
- torchrl/collectors/distributed/ray.py +1092 -0
- torchrl/collectors/distributed/rpc.py +1006 -0
- torchrl/collectors/distributed/sync.py +731 -0
- torchrl/collectors/distributed/utils.py +160 -0
- torchrl/collectors/llm/__init__.py +10 -0
- torchrl/collectors/llm/base.py +494 -0
- torchrl/collectors/llm/ray_collector.py +275 -0
- torchrl/collectors/llm/utils.py +36 -0
- torchrl/collectors/llm/weight_update/__init__.py +10 -0
- torchrl/collectors/llm/weight_update/vllm.py +348 -0
- torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
- torchrl/collectors/utils.py +433 -0
- torchrl/collectors/weight_update.py +591 -0
- torchrl/csrc/numpy_utils.h +38 -0
- torchrl/csrc/pybind.cpp +27 -0
- torchrl/csrc/segment_tree.h +458 -0
- torchrl/csrc/torch_utils.h +34 -0
- torchrl/csrc/utils.cpp +48 -0
- torchrl/csrc/utils.h +31 -0
- torchrl/data/__init__.py +187 -0
- torchrl/data/datasets/__init__.py +58 -0
- torchrl/data/datasets/atari_dqn.py +878 -0
- torchrl/data/datasets/common.py +281 -0
- torchrl/data/datasets/d4rl.py +489 -0
- torchrl/data/datasets/d4rl_infos.py +187 -0
- torchrl/data/datasets/gen_dgrl.py +375 -0
- torchrl/data/datasets/minari_data.py +643 -0
- torchrl/data/datasets/openml.py +177 -0
- torchrl/data/datasets/openx.py +798 -0
- torchrl/data/datasets/roboset.py +363 -0
- torchrl/data/datasets/utils.py +11 -0
- torchrl/data/datasets/vd4rl.py +432 -0
- torchrl/data/llm/__init__.py +34 -0
- torchrl/data/llm/dataset.py +491 -0
- torchrl/data/llm/history.py +1378 -0
- torchrl/data/llm/prompt.py +198 -0
- torchrl/data/llm/reward.py +225 -0
- torchrl/data/llm/topk.py +186 -0
- torchrl/data/llm/utils.py +543 -0
- torchrl/data/map/__init__.py +21 -0
- torchrl/data/map/hash.py +185 -0
- torchrl/data/map/query.py +204 -0
- torchrl/data/map/tdstorage.py +363 -0
- torchrl/data/map/tree.py +1434 -0
- torchrl/data/map/utils.py +103 -0
- torchrl/data/postprocs/__init__.py +8 -0
- torchrl/data/postprocs/postprocs.py +391 -0
- torchrl/data/replay_buffers/__init__.py +99 -0
- torchrl/data/replay_buffers/checkpointers.py +622 -0
- torchrl/data/replay_buffers/ray_buffer.py +292 -0
- torchrl/data/replay_buffers/replay_buffers.py +2376 -0
- torchrl/data/replay_buffers/samplers.py +2578 -0
- torchrl/data/replay_buffers/scheduler.py +265 -0
- torchrl/data/replay_buffers/storages.py +2412 -0
- torchrl/data/replay_buffers/utils.py +1042 -0
- torchrl/data/replay_buffers/writers.py +781 -0
- torchrl/data/tensor_specs.py +7101 -0
- torchrl/data/utils.py +334 -0
- torchrl/envs/__init__.py +265 -0
- torchrl/envs/async_envs.py +1105 -0
- torchrl/envs/batched_envs.py +3093 -0
- torchrl/envs/common.py +4241 -0
- torchrl/envs/custom/__init__.py +11 -0
- torchrl/envs/custom/chess.py +617 -0
- torchrl/envs/custom/llm.py +214 -0
- torchrl/envs/custom/pendulum.py +401 -0
- torchrl/envs/custom/san_moves.txt +29274 -0
- torchrl/envs/custom/tictactoeenv.py +288 -0
- torchrl/envs/env_creator.py +263 -0
- torchrl/envs/gym_like.py +752 -0
- torchrl/envs/libs/__init__.py +68 -0
- torchrl/envs/libs/_gym_utils.py +326 -0
- torchrl/envs/libs/brax.py +846 -0
- torchrl/envs/libs/dm_control.py +544 -0
- torchrl/envs/libs/envpool.py +447 -0
- torchrl/envs/libs/gym.py +2239 -0
- torchrl/envs/libs/habitat.py +138 -0
- torchrl/envs/libs/isaac_lab.py +87 -0
- torchrl/envs/libs/isaacgym.py +203 -0
- torchrl/envs/libs/jax_utils.py +166 -0
- torchrl/envs/libs/jumanji.py +963 -0
- torchrl/envs/libs/meltingpot.py +599 -0
- torchrl/envs/libs/openml.py +153 -0
- torchrl/envs/libs/openspiel.py +652 -0
- torchrl/envs/libs/pettingzoo.py +1042 -0
- torchrl/envs/libs/procgen.py +351 -0
- torchrl/envs/libs/robohive.py +429 -0
- torchrl/envs/libs/smacv2.py +645 -0
- torchrl/envs/libs/unity_mlagents.py +891 -0
- torchrl/envs/libs/utils.py +147 -0
- torchrl/envs/libs/vmas.py +813 -0
- torchrl/envs/llm/__init__.py +63 -0
- torchrl/envs/llm/chat.py +730 -0
- torchrl/envs/llm/datasets/README.md +4 -0
- torchrl/envs/llm/datasets/__init__.py +17 -0
- torchrl/envs/llm/datasets/gsm8k.py +353 -0
- torchrl/envs/llm/datasets/ifeval.py +274 -0
- torchrl/envs/llm/envs.py +789 -0
- torchrl/envs/llm/libs/README.md +3 -0
- torchrl/envs/llm/libs/__init__.py +8 -0
- torchrl/envs/llm/libs/mlgym.py +869 -0
- torchrl/envs/llm/reward/__init__.py +10 -0
- torchrl/envs/llm/reward/gsm8k.py +324 -0
- torchrl/envs/llm/reward/ifeval/README.md +13 -0
- torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
- torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
- torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
- torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
- torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
- torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
- torchrl/envs/llm/transforms/__init__.py +55 -0
- torchrl/envs/llm/transforms/browser.py +292 -0
- torchrl/envs/llm/transforms/dataloading.py +859 -0
- torchrl/envs/llm/transforms/format.py +73 -0
- torchrl/envs/llm/transforms/kl.py +1544 -0
- torchrl/envs/llm/transforms/policy_version.py +189 -0
- torchrl/envs/llm/transforms/reason.py +323 -0
- torchrl/envs/llm/transforms/tokenizer.py +321 -0
- torchrl/envs/llm/transforms/tools.py +1955 -0
- torchrl/envs/model_based/__init__.py +9 -0
- torchrl/envs/model_based/common.py +180 -0
- torchrl/envs/model_based/dreamer.py +112 -0
- torchrl/envs/transforms/__init__.py +147 -0
- torchrl/envs/transforms/functional.py +48 -0
- torchrl/envs/transforms/gym_transforms.py +203 -0
- torchrl/envs/transforms/module.py +341 -0
- torchrl/envs/transforms/r3m.py +372 -0
- torchrl/envs/transforms/ray_service.py +663 -0
- torchrl/envs/transforms/rb_transforms.py +214 -0
- torchrl/envs/transforms/transforms.py +11835 -0
- torchrl/envs/transforms/utils.py +94 -0
- torchrl/envs/transforms/vc1.py +307 -0
- torchrl/envs/transforms/vecnorm.py +845 -0
- torchrl/envs/transforms/vip.py +407 -0
- torchrl/envs/utils.py +1718 -0
- torchrl/envs/vec_envs.py +11 -0
- torchrl/modules/__init__.py +206 -0
- torchrl/modules/distributions/__init__.py +73 -0
- torchrl/modules/distributions/continuous.py +830 -0
- torchrl/modules/distributions/discrete.py +908 -0
- torchrl/modules/distributions/truncated_normal.py +187 -0
- torchrl/modules/distributions/utils.py +233 -0
- torchrl/modules/llm/__init__.py +62 -0
- torchrl/modules/llm/backends/__init__.py +65 -0
- torchrl/modules/llm/backends/vllm/__init__.py +94 -0
- torchrl/modules/llm/backends/vllm/_models.py +46 -0
- torchrl/modules/llm/backends/vllm/base.py +72 -0
- torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
- torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
- torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
- torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
- torchrl/modules/llm/policies/__init__.py +28 -0
- torchrl/modules/llm/policies/common.py +1809 -0
- torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
- torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
- torchrl/modules/llm/utils.py +23 -0
- torchrl/modules/mcts/__init__.py +21 -0
- torchrl/modules/mcts/scores.py +579 -0
- torchrl/modules/models/__init__.py +86 -0
- torchrl/modules/models/batchrenorm.py +119 -0
- torchrl/modules/models/decision_transformer.py +179 -0
- torchrl/modules/models/exploration.py +731 -0
- torchrl/modules/models/llm.py +156 -0
- torchrl/modules/models/model_based.py +596 -0
- torchrl/modules/models/models.py +1712 -0
- torchrl/modules/models/multiagent.py +1067 -0
- torchrl/modules/models/recipes/impala.py +185 -0
- torchrl/modules/models/utils.py +162 -0
- torchrl/modules/planners/__init__.py +10 -0
- torchrl/modules/planners/cem.py +228 -0
- torchrl/modules/planners/common.py +73 -0
- torchrl/modules/planners/mppi.py +265 -0
- torchrl/modules/tensordict_module/__init__.py +89 -0
- torchrl/modules/tensordict_module/actors.py +2457 -0
- torchrl/modules/tensordict_module/common.py +529 -0
- torchrl/modules/tensordict_module/exploration.py +814 -0
- torchrl/modules/tensordict_module/probabilistic.py +321 -0
- torchrl/modules/tensordict_module/rnn.py +1639 -0
- torchrl/modules/tensordict_module/sequence.py +132 -0
- torchrl/modules/tensordict_module/world_models.py +34 -0
- torchrl/modules/utils/__init__.py +38 -0
- torchrl/modules/utils/mappings.py +9 -0
- torchrl/modules/utils/utils.py +89 -0
- torchrl/objectives/__init__.py +78 -0
- torchrl/objectives/a2c.py +659 -0
- torchrl/objectives/common.py +753 -0
- torchrl/objectives/cql.py +1346 -0
- torchrl/objectives/crossq.py +710 -0
- torchrl/objectives/ddpg.py +453 -0
- torchrl/objectives/decision_transformer.py +371 -0
- torchrl/objectives/deprecated.py +516 -0
- torchrl/objectives/dqn.py +683 -0
- torchrl/objectives/dreamer.py +488 -0
- torchrl/objectives/functional.py +48 -0
- torchrl/objectives/gail.py +258 -0
- torchrl/objectives/iql.py +996 -0
- torchrl/objectives/llm/__init__.py +30 -0
- torchrl/objectives/llm/grpo.py +846 -0
- torchrl/objectives/llm/sft.py +482 -0
- torchrl/objectives/multiagent/__init__.py +8 -0
- torchrl/objectives/multiagent/qmixer.py +396 -0
- torchrl/objectives/ppo.py +1669 -0
- torchrl/objectives/redq.py +683 -0
- torchrl/objectives/reinforce.py +530 -0
- torchrl/objectives/sac.py +1580 -0
- torchrl/objectives/td3.py +570 -0
- torchrl/objectives/td3_bc.py +625 -0
- torchrl/objectives/utils.py +782 -0
- torchrl/objectives/value/__init__.py +28 -0
- torchrl/objectives/value/advantages.py +1956 -0
- torchrl/objectives/value/functional.py +1459 -0
- torchrl/objectives/value/utils.py +360 -0
- torchrl/record/__init__.py +17 -0
- torchrl/record/loggers/__init__.py +23 -0
- torchrl/record/loggers/common.py +48 -0
- torchrl/record/loggers/csv.py +226 -0
- torchrl/record/loggers/mlflow.py +142 -0
- torchrl/record/loggers/tensorboard.py +139 -0
- torchrl/record/loggers/trackio.py +163 -0
- torchrl/record/loggers/utils.py +78 -0
- torchrl/record/loggers/wandb.py +214 -0
- torchrl/record/recorder.py +554 -0
- torchrl/services/__init__.py +79 -0
- torchrl/services/base.py +109 -0
- torchrl/services/ray_service.py +453 -0
- torchrl/testing/__init__.py +107 -0
- torchrl/testing/assertions.py +179 -0
- torchrl/testing/dist_utils.py +122 -0
- torchrl/testing/env_creators.py +227 -0
- torchrl/testing/env_helper.py +35 -0
- torchrl/testing/gym_helpers.py +156 -0
- torchrl/testing/llm_mocks.py +119 -0
- torchrl/testing/mocking_classes.py +2720 -0
- torchrl/testing/modules.py +295 -0
- torchrl/testing/mp_helpers.py +15 -0
- torchrl/testing/ray_helpers.py +293 -0
- torchrl/testing/utils.py +190 -0
- torchrl/trainers/__init__.py +42 -0
- torchrl/trainers/algorithms/__init__.py +11 -0
- torchrl/trainers/algorithms/configs/__init__.py +705 -0
- torchrl/trainers/algorithms/configs/collectors.py +216 -0
- torchrl/trainers/algorithms/configs/common.py +41 -0
- torchrl/trainers/algorithms/configs/data.py +308 -0
- torchrl/trainers/algorithms/configs/envs.py +104 -0
- torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
- torchrl/trainers/algorithms/configs/logging.py +80 -0
- torchrl/trainers/algorithms/configs/modules.py +570 -0
- torchrl/trainers/algorithms/configs/objectives.py +177 -0
- torchrl/trainers/algorithms/configs/trainers.py +340 -0
- torchrl/trainers/algorithms/configs/transforms.py +955 -0
- torchrl/trainers/algorithms/configs/utils.py +252 -0
- torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
- torchrl/trainers/algorithms/configs/weight_update.py +159 -0
- torchrl/trainers/algorithms/ppo.py +373 -0
- torchrl/trainers/algorithms/sac.py +308 -0
- torchrl/trainers/helpers/__init__.py +40 -0
- torchrl/trainers/helpers/collectors.py +416 -0
- torchrl/trainers/helpers/envs.py +573 -0
- torchrl/trainers/helpers/logger.py +33 -0
- torchrl/trainers/helpers/losses.py +132 -0
- torchrl/trainers/helpers/models.py +658 -0
- torchrl/trainers/helpers/replay_buffer.py +59 -0
- torchrl/trainers/helpers/trainers.py +301 -0
- torchrl/trainers/trainers.py +2052 -0
- torchrl/weight_update/__init__.py +33 -0
- torchrl/weight_update/_distributed.py +749 -0
- torchrl/weight_update/_mp.py +624 -0
- torchrl/weight_update/_noupdate.py +102 -0
- torchrl/weight_update/_ray.py +1032 -0
- torchrl/weight_update/_rpc.py +284 -0
- torchrl/weight_update/_shared.py +891 -0
- torchrl/weight_update/llm/__init__.py +32 -0
- torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
- torchrl/weight_update/llm/vllm_nccl.py +710 -0
- torchrl/weight_update/utils.py +73 -0
- torchrl/weight_update/weight_sync_schemes.py +1244 -0
- torchrl-0.11.0.dist-info/LICENSE +21 -0
- torchrl-0.11.0.dist-info/METADATA +1307 -0
- torchrl-0.11.0.dist-info/RECORD +394 -0
- torchrl-0.11.0.dist-info/WHEEL +5 -0
- torchrl-0.11.0.dist-info/entry_points.txt +2 -0
- torchrl-0.11.0.dist-info/top_level.txt +7 -0
torchrl/data/utils.py
ADDED
|
@@ -0,0 +1,334 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
#
|
|
3
|
+
# This source code is licensed under the MIT license found in the
|
|
4
|
+
# LICENSE file in the root directory of this source tree.
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
import functools
|
|
8
|
+
import typing
|
|
9
|
+
from collections.abc import Callable
|
|
10
|
+
from typing import Any, Union
|
|
11
|
+
|
|
12
|
+
import cloudpickle
|
|
13
|
+
import numpy as np
|
|
14
|
+
import torch
|
|
15
|
+
from torch import Tensor
|
|
16
|
+
from torchrl.data.tensor_specs import (
|
|
17
|
+
Binary,
|
|
18
|
+
Categorical,
|
|
19
|
+
Composite,
|
|
20
|
+
MultiCategorical,
|
|
21
|
+
MultiOneHot,
|
|
22
|
+
OneHot,
|
|
23
|
+
Stacked,
|
|
24
|
+
StackedComposite,
|
|
25
|
+
TensorSpec,
|
|
26
|
+
)
|
|
27
|
+
|
|
28
|
+
# Mapping from numpy dtypes to the equivalent torch dtypes.
numpy_to_torch_dtype_dict = {
    np.dtype("bool"): torch.bool,
    np.dtype("uint8"): torch.uint8,
    np.dtype("int8"): torch.int8,
    np.dtype("int16"): torch.int16,
    np.dtype("int32"): torch.int32,
    np.dtype("int64"): torch.int64,
    np.dtype("float16"): torch.float16,
    np.dtype("float32"): torch.float32,
    np.dtype("float64"): torch.float64,
    np.dtype("complex64"): torch.complex64,
    np.dtype("complex128"): torch.complex128,
}
# Inverse mapping (torch dtype -> numpy dtype). The forward mapping is
# one-to-one, so inverting it loses no entry.
torch_to_numpy_dtype_dict = {
    value: key for key, value in numpy_to_torch_dtype_dict.items()
}
# Type alias for device-like arguments.
DEVICE_TYPING = Union[torch.device, str, int]
# Member types of DEVICE_TYPING as a tuple, usable with ``isinstance``.
# ``typing.get_args`` is always available: this module evaluates ``list[Any]``
# at import time (INDEX_TYPING below), which already requires Python >= 3.9,
# so a ``hasattr(typing, "get_args")`` fallback could never be taken.
DEVICE_TYPING_ARGS = typing.get_args(DEVICE_TYPING)

INDEX_TYPING = Union[None, int, slice, str, Tensor, list[Any], tuple[Any, ...]]
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
# Normalizes the ways an action space can be designated -- a spec class or one
# of several string spellings -- to a single canonical string.
# NOTE(review): the canonical multi-one-hot value is spelled "mult_one_hot";
# other modules presumably compare against that exact string, so keep it.
ACTION_SPACE_MAP = {
    # Spec classes.
    OneHot: "one_hot",
    MultiOneHot: "mult_one_hot",
    Binary: "binary",
    Categorical: "categorical",
    # String aliases of the spaces above.
    "one_hot": "one_hot",
    "one-hot": "one_hot",
    "mult_one_hot": "mult_one_hot",
    "mult-one-hot": "mult_one_hot",
    "multi_one_hot": "mult_one_hot",
    "multi-one-hot": "mult_one_hot",
    "binary": "binary",
    "categorical": "categorical",
    # Multi-categorical (a.k.a. multi-discrete): spec class, then aliases.
    MultiCategorical: "multi_categorical",
    "multi_categorical": "multi_categorical",
    "multi-categorical": "multi_categorical",
    "multi_discrete": "multi_categorical",
    "multi-discrete": "multi_categorical",
}
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
def consolidate_spec(
    spec: Composite,
    recurse_through_entries: bool = True,
    recurse_through_stack: bool = True,
):
    """Given a TensorSpec, removes exclusive keys by adding 0-shaped specs.

    A key is "exclusive" when only some of the specs stacked in a
    :class:`StackedComposite` have it. Every stacked spec that lacks such a key
    receives a placeholder entry built by ``_empty_like_spec`` from the values
    the key takes in the other stacked specs, so all of them expose the same keys.

    The function works on ``spec.clone()``; specs that are neither
    ``Composite`` nor ``StackedComposite`` are returned as that clone.

    Args:
        spec (Composite): the spec to be consolidated.
        recurse_through_entries (bool): if True, call the function recursively
            on all composite entries of the spec. Default is True.
        recurse_through_stack (bool): if True and the provided spec is lazy,
            the function is called recursively on every spec in its stack
            before exclusive keys are collected. Default is True.

    Returns:
        the consolidated spec.
    """
    spec = spec.clone()

    if not isinstance(spec, (Composite, StackedComposite)):
        return spec

    if isinstance(spec, StackedComposite):
        # Keys shared by every spec of the stack.
        keys = set(spec.keys())
        # For each stacked spec, the keys it has that are not shared.
        exclusive_keys_per_spec = [set() for _ in spec._specs]
        # Every exclusive key mapped to the list of values it takes in the stack.
        exclusive_keys_examples = {}
        for spec_index, sub_spec in enumerate(spec._specs):
            if recurse_through_stack:
                sub_spec = consolidate_spec(
                    sub_spec, recurse_through_entries, recurse_through_stack
                )
                spec._specs[spec_index] = sub_spec
            for sub_spec_key in sub_spec.keys():
                if sub_spec_key not in keys:  # exclusive key
                    exclusive_keys_per_spec[spec_index].add(sub_spec_key)
                    exclusive_keys_examples.setdefault(sub_spec_key, []).append(
                        sub_spec[sub_spec_key]
                    )

        # Add the missing exclusive entries to each stacked spec. The set of
        # exclusive keys does not change inside the loop, so build it once.
        all_exclusive_keys = set(exclusive_keys_examples.keys())
        for sub_spec, exclusive_keys in zip(spec._specs, exclusive_keys_per_spec):
            for exclusive_key in all_exclusive_keys.difference(exclusive_keys):
                sub_spec.set(
                    exclusive_key,
                    _empty_like_spec(
                        exclusive_keys_examples[exclusive_key], sub_spec.shape
                    ),
                )

    if recurse_through_entries:
        for key, value in spec.items():
            if isinstance(value, (Composite, StackedComposite)):
                spec.set(
                    key,
                    consolidate_spec(
                        value, recurse_through_entries, recurse_through_stack
                    ),
                )
    return spec
|
|
140
|
+
|
|
141
|
+
|
|
142
|
+
def _empty_like_spec(specs: list[TensorSpec], shape):
    """Build a placeholder spec for an exclusive key of a lazy stack.

    Args:
        specs: the values the exclusive key takes in the stacked specs that do
            have it. All of them must be of the same class.
        shape: the shape of the stacked spec that will receive the placeholder.

    Returns:
        a placeholder derived from ``specs[0]``:

        * composite values give ``specs[0].empty()``;
        * ``Stacked`` values give a ``Stacked`` of placeholders along the same
          stack dim;
        * otherwise the dims along which the values disagree become 0. If they
          all agree (and the shape holds no 0), every dim at index
          ``>= len(shape)`` becomes 0 instead.

    Raises:
        ValueError: if the specs in ``specs`` are not all of the same class.
    """
    reference = specs[0]
    reference_cls = reference.__class__
    if any(other.__class__ != reference_cls for other in specs[1:]):
        raise ValueError(
            "Found same key in lazy specs corresponding to entries with different classes"
        )

    if isinstance(reference, (Composite, StackedComposite)):
        # Composite values: an empty composite spec with the same batch size.
        return reference.empty()

    if isinstance(reference, Stacked):
        # Stacked values: rebuild a Stacked spec along the same stack dim so
        # that stacking adds no new -1 dims. The inner placeholders get
        # ``shape`` with the stack dim removed.
        stack_dim = reference.stack_dim
        inner_shape = [*shape[:stack_dim], *shape[stack_dim + 1 :]]
        inner_placeholders = [
            _empty_like_spec(reference._specs, inner_shape)
            for _ in reference._specs
        ]
        return Stacked(*inner_placeholders, dim=stack_dim)

    # Plain tensor specs: every dim along which the values disagree becomes 0.
    reference_shape = reference.shape
    target_shape = [
        0 if any(other.shape[dim] != reference_shape[dim] for other in specs) else size
        for dim, size in enumerate(reference_shape)
    ]
    if 0 not in target_shape:
        # All values share the same shape: keep the dims at indices covered by
        # ``shape`` and zero out the remaining ones.
        target_shape = [
            size if dim < len(shape) else 0 for dim, size in enumerate(target_shape)
        ]

    # Index down to a single element, then expand to the placeholder shape.
    placeholder = reference[(0,) * len(reference.shape)]
    return placeholder.expand(target_shape)
|
|
186
|
+
|
|
187
|
+
|
|
188
|
+
def check_no_exclusive_keys(spec: TensorSpec, recurse: bool = True):
    """Tell whether a spec is free of exclusive keys.

    A lazily stacked composite has exclusive keys when its stacked specs do not
    all expose the same set of keys.

    Args:
        spec (TensorSpec): the spec to check
        recurse (bool): if True, check recursively in nested specs. Default is True.

    Returns:
        ``False`` as soon as exclusive keys are found, ``True`` otherwise. Specs
        that are not composite, and ``Composite`` specs when ``recurse`` is
        False, always give ``True``.
    """
    if isinstance(spec, StackedComposite):
        shared_keys = set(spec.keys())
        for stacked in spec._specs:
            if recurse and not check_no_exclusive_keys(stacked):
                return False
            if set(stacked.keys()) != shared_keys:
                return False
        return True
    if isinstance(spec, Composite) and recurse:
        return all(check_no_exclusive_keys(entry) for entry in spec.values())
    return True
|
|
209
|
+
|
|
210
|
+
|
|
211
|
+
def contains_lazy_spec(spec: TensorSpec) -> bool:
    """Tell whether a spec is, or contains, a lazily stacked spec.

    Args:
        spec (TensorSpec): the spec to check

    Returns:
        ``True`` if ``spec`` is a ``Stacked`` or ``StackedComposite``, or if
        any entry of a ``Composite`` spec (searched recursively) is one.
    """
    if isinstance(spec, (Stacked, StackedComposite)):
        return True
    if isinstance(spec, Composite):
        return any(contains_lazy_spec(inner) for inner in spec.values())
    return False
|
|
225
|
+
|
|
226
|
+
|
|
227
|
+
class _CloudpickleWrapperMeta(type):
    """Metaclass making :class:`CloudpickleWrapper` construction idempotent.

    Wrapping an object that already is a ``CloudpickleWrapper``, without extra
    keyword arguments, returns that same object instead of nesting wrappers.
    """

    def __call__(cls, obj, **kwargs):
        # Keyword arguments must reach ``__init__``. A ``__call__(cls, obj)``
        # signature rejects them with a TypeError, which made the
        # ``**kwargs`` parameter of ``CloudpickleWrapper.__init__`` unusable.
        if isinstance(obj, cls) and not kwargs:
            return obj
        return super().__call__(obj, **kwargs)


class CloudpickleWrapper(metaclass=_CloudpickleWrapperMeta):
    """A wrapper for callables that allows serialization in multiprocessed settings.

    The wrapped callable and the stored keyword arguments are serialized with
    ``cloudpickle`` (see ``__getstate__`` / ``__setstate__``).

    Args:
        fn (Callable): the callable to wrap. ``EnvCreator`` instances are rejected.
        **kwargs: keyword arguments passed to ``fn`` on every call. On a name
            clash they override the keyword arguments given at call time.

    Raises:
        RuntimeError: if ``fn`` is an ``EnvCreator``.
    """

    def __init__(self, fn: Callable, **kwargs):
        if fn.__class__.__name__ == "EnvCreator":
            raise RuntimeError(
                "CloudpickleWrapper usage with EnvCreator class is "
                "prohibited as it breaks the transmission of shared tensors."
            )
        # ``update_wrapper`` copies the wrapped object's ``__dict__`` onto
        # ``self``. If that object has its own ``fn``/``kwargs`` attributes they
        # would clobber the wrapper's state, so that state is assigned last.
        functools.update_wrapper(self, getattr(fn, "forward", fn))
        self.fn = fn
        self.kwargs = kwargs

    def __getstate__(self):
        return cloudpickle.dumps((self.fn, self.kwargs))

    def __setstate__(self, ob: bytes):
        # NOTE: unpickling executes arbitrary code; ``ob`` must come from a
        # trusted process (normally another worker of the same program).
        fn, kwargs = cloudpickle.loads(ob)
        # Same ordering as in ``__init__``: let our own state win over any
        # attributes copied from the wrapped object's ``__dict__``.
        functools.update_wrapper(self, getattr(fn, "forward", fn))
        self.fn = fn
        self.kwargs = kwargs

    def __call__(self, *args, **kwargs) -> Any:
        # Stored kwargs take precedence over call-time kwargs.
        kwargs.update(self.kwargs)
        return self.fn(*args, **kwargs)
|
|
259
|
+
|
|
260
|
+
|
|
261
|
+
def _process_action_space_spec(action_space, spec):
    """Reconcile an ``action_space`` designation with an action ``spec``.

    Args:
        action_space: ``None``, a key of ``ACTION_SPACE_MAP`` (string alias or
            spec class), or a non-composite ``TensorSpec``.
        spec: ``None``, an action ``TensorSpec``, or a ``Composite`` holding an
            ``"action"`` entry at its root, or a leaf whose nested key ends
            with ``"action"``.

    Returns:
        ``(action_space, spec)``. ``action_space`` is the canonical string
        given by ``_find_action_space``. ``spec`` is the spec that was passed
        (the composite one if a composite was given), or ``action_space`` itself
        when only a ``TensorSpec`` action space was given.

    Raises:
        KeyError: if ``spec`` is a ``Composite`` without an action entry.
        ValueError: if ``action_space`` is a ``Composite``, if ``action_space``
            and ``spec`` disagree, or if both are ``None``.
    """
    original_spec = spec
    composite_spec = False
    if isinstance(spec, Composite):
        # Only a single action tensor is supported: use the root "action"
        # entry, or else the first leaf whose nested key ends with "action".
        try:
            if "action" in spec.keys():
                _key = "action"
            else:
                for _key in spec.keys(True, True):
                    if isinstance(_key, tuple) and _key[-1] == "action":
                        break
                else:
                    raise KeyError
            spec = spec[_key]
            composite_spec = True
        except KeyError as err:
            raise KeyError(
                "action could not be found in the spec. Make sure "
                "you pass a spec that is either a native action spec or a composite action spec "
                "with a leaf 'action' entry. Otherwise, simply remove the spec and use the action_space only."
            ) from err
    if action_space is not None:
        if isinstance(action_space, Composite):
            raise ValueError("action_space cannot be of type Composite.")
        if (
            spec is not None
            and isinstance(action_space, TensorSpec)
            and action_space is not spec
        ):
            raise ValueError(
                "Passing an action_space as a TensorSpec and a spec isn't allowed, unless they match."
            )
        if isinstance(action_space, TensorSpec):
            spec = action_space
        action_space = _find_action_space(action_space)
        # check that the spec and action_space match
        if spec is not None and _find_action_space(spec) != action_space:
            raise ValueError(
                f"The action spec and the action space do not match: got action_space={action_space} and spec={spec}."
            )
    elif spec is not None:
        action_space = _find_action_space(spec)
    else:
        raise ValueError(
            "Neither action_space nor spec was defined. The action space cannot be inferred."
        )
    if composite_spec:
        # Hand back the composite spec the caller passed, not its action leaf.
        spec = original_spec
    return action_space, spec
|
|
312
|
+
|
|
313
|
+
|
|
314
|
+
def _find_action_space(action_space) -> str:
    """Map an action-space designation to its canonical ``ACTION_SPACE_MAP`` value.

    Args:
        action_space: a ``TensorSpec`` instance, a ``Composite`` holding an
            action entry, or any key of ``ACTION_SPACE_MAP`` (a spec class or a
            string alias).

    Returns:
        the canonical action-space string (a value of ``ACTION_SPACE_MAP``).

    Raises:
        KeyError: if a ``Composite`` spec has no action entry.
        ValueError: if the designation has no entry in ``ACTION_SPACE_MAP``.
    """
    if isinstance(action_space, TensorSpec):
        if isinstance(action_space, Composite):
            if "action" in action_space.keys():
                _key = "action"
            else:
                # Fall back on the first leaf whose nested key ends with "action".
                for _key in action_space.keys(True, True):
                    if isinstance(_key, tuple) and _key[-1] == "action":
                        break
                else:
                    raise KeyError(
                        "action could not be found in the composite spec: expected "
                        "an 'action' entry or a nested leaf key ending with 'action'."
                    )
            action_space = action_space[_key]
        action_space = type(action_space)
        # Walk the MRO so that subclasses of registered spec classes resolve to
        # their parent's action space. A class is first in its own MRO, so
        # exact matches give the same result as a direct lookup.
        for cls in action_space.__mro__:
            if cls in ACTION_SPACE_MAP:
                return ACTION_SPACE_MAP[cls]
    try:
        return ACTION_SPACE_MAP[action_space]
    except KeyError:
        raise ValueError(
            f"action_space was not specified/not compatible and could not be retrieved from the value network. Got action_space={action_space}."
        ) from None
|
torchrl/envs/__init__.py
ADDED
|
@@ -0,0 +1,265 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
#
|
|
3
|
+
# This source code is licensed under the MIT license found in the
|
|
4
|
+
# LICENSE file in the root directory of this source tree.
|
|
5
|
+
|
|
6
|
+
from .async_envs import AsyncEnvPool, ProcessorAsyncEnvPool, ThreadingAsyncEnvPool
|
|
7
|
+
from .batched_envs import ParallelEnv, SerialEnv
|
|
8
|
+
from .common import EnvBase, EnvMetaData, make_tensordict
|
|
9
|
+
from .custom import ChessEnv, LLMHashingEnv, PendulumEnv, TicTacToeEnv
|
|
10
|
+
from .env_creator import env_creator, EnvCreator, get_env_metadata
|
|
11
|
+
from .gym_like import default_info_dict_reader, GymLikeEnv
|
|
12
|
+
from .libs import (
|
|
13
|
+
BraxEnv,
|
|
14
|
+
BraxWrapper,
|
|
15
|
+
DMControlEnv,
|
|
16
|
+
DMControlWrapper,
|
|
17
|
+
gym_backend,
|
|
18
|
+
GymEnv,
|
|
19
|
+
GymWrapper,
|
|
20
|
+
HabitatEnv,
|
|
21
|
+
IsaacGymEnv,
|
|
22
|
+
IsaacGymWrapper,
|
|
23
|
+
IsaacLabWrapper,
|
|
24
|
+
JumanjiEnv,
|
|
25
|
+
JumanjiWrapper,
|
|
26
|
+
MeltingpotEnv,
|
|
27
|
+
MeltingpotWrapper,
|
|
28
|
+
MOGymEnv,
|
|
29
|
+
MOGymWrapper,
|
|
30
|
+
MultiThreadedEnv,
|
|
31
|
+
MultiThreadedEnvWrapper,
|
|
32
|
+
OpenMLEnv,
|
|
33
|
+
OpenSpielEnv,
|
|
34
|
+
OpenSpielWrapper,
|
|
35
|
+
PettingZooEnv,
|
|
36
|
+
PettingZooWrapper,
|
|
37
|
+
ProcgenEnv,
|
|
38
|
+
ProcgenWrapper,
|
|
39
|
+
register_gym_spec_conversion,
|
|
40
|
+
RoboHiveEnv,
|
|
41
|
+
set_gym_backend,
|
|
42
|
+
SMACv2Env,
|
|
43
|
+
SMACv2Wrapper,
|
|
44
|
+
UnityMLAgentsEnv,
|
|
45
|
+
UnityMLAgentsWrapper,
|
|
46
|
+
VmasEnv,
|
|
47
|
+
VmasWrapper,
|
|
48
|
+
)
|
|
49
|
+
from .model_based import DreamerDecoder, DreamerEnv, ModelBasedEnvBase
|
|
50
|
+
from .transforms import (
|
|
51
|
+
ActionDiscretizer,
|
|
52
|
+
ActionMask,
|
|
53
|
+
AutoResetEnv,
|
|
54
|
+
AutoResetTransform,
|
|
55
|
+
BatchSizeTransform,
|
|
56
|
+
BinarizeReward,
|
|
57
|
+
BurnInTransform,
|
|
58
|
+
CatFrames,
|
|
59
|
+
CatTensors,
|
|
60
|
+
CenterCrop,
|
|
61
|
+
ClipTransform,
|
|
62
|
+
Compose,
|
|
63
|
+
ConditionalPolicySwitch,
|
|
64
|
+
ConditionalSkip,
|
|
65
|
+
Crop,
|
|
66
|
+
DeviceCastTransform,
|
|
67
|
+
DiscreteActionProjection,
|
|
68
|
+
DoubleToFloat,
|
|
69
|
+
DTypeCastTransform,
|
|
70
|
+
EndOfLifeTransform,
|
|
71
|
+
ExcludeTransform,
|
|
72
|
+
FiniteTensorDictCheck,
|
|
73
|
+
FlattenObservation,
|
|
74
|
+
FrameSkipTransform,
|
|
75
|
+
GrayScale,
|
|
76
|
+
gSDENoise,
|
|
77
|
+
Hash,
|
|
78
|
+
InitTracker,
|
|
79
|
+
LineariseRewards,
|
|
80
|
+
MultiAction,
|
|
81
|
+
MultiStepTransform,
|
|
82
|
+
NoopResetEnv,
|
|
83
|
+
ObservationNorm,
|
|
84
|
+
ObservationTransform,
|
|
85
|
+
PermuteTransform,
|
|
86
|
+
PinMemoryTransform,
|
|
87
|
+
R3MTransform,
|
|
88
|
+
RandomCropTensorDict,
|
|
89
|
+
RemoveEmptySpecs,
|
|
90
|
+
RenameTransform,
|
|
91
|
+
Resize,
|
|
92
|
+
Reward2GoTransform,
|
|
93
|
+
RewardClipping,
|
|
94
|
+
RewardScaling,
|
|
95
|
+
RewardSum,
|
|
96
|
+
SelectTransform,
|
|
97
|
+
SignTransform,
|
|
98
|
+
SqueezeTransform,
|
|
99
|
+
Stack,
|
|
100
|
+
StepCounter,
|
|
101
|
+
TargetReturn,
|
|
102
|
+
TensorDictPrimer,
|
|
103
|
+
TimeMaxPool,
|
|
104
|
+
Timer,
|
|
105
|
+
Tokenizer,
|
|
106
|
+
ToTensorImage,
|
|
107
|
+
TrajCounter,
|
|
108
|
+
Transform,
|
|
109
|
+
TransformedEnv,
|
|
110
|
+
UnaryTransform,
|
|
111
|
+
UnsqueezeTransform,
|
|
112
|
+
VC1Transform,
|
|
113
|
+
VecGymEnvTransform,
|
|
114
|
+
VecNorm,
|
|
115
|
+
VecNormV2,
|
|
116
|
+
VIPRewardTransform,
|
|
117
|
+
VIPTransform,
|
|
118
|
+
)
|
|
119
|
+
from .utils import (
|
|
120
|
+
check_env_specs,
|
|
121
|
+
check_marl_grouping,
|
|
122
|
+
exploration_type,
|
|
123
|
+
ExplorationType,
|
|
124
|
+
get_available_libraries,
|
|
125
|
+
make_composite_from_td,
|
|
126
|
+
MarlGroupMapType,
|
|
127
|
+
set_exploration_type,
|
|
128
|
+
step_mdp,
|
|
129
|
+
terminated_or_truncated,
|
|
130
|
+
)
|
|
131
|
+
|
|
132
|
+
# Public API of ``torchrl.envs``. Kept in case-sensitive (ASCII) order --
# CamelCase names first, lower_case helpers last -- so new entries have an
# obvious place and the list is easy to audit against the imports above.
__all__ = [
    "ActionDiscretizer",
    "ActionMask",
    "AsyncEnvPool",
    "AutoResetEnv",
    "AutoResetTransform",
    "BatchSizeTransform",
    "BinarizeReward",
    "BraxEnv",
    "BraxWrapper",
    "BurnInTransform",
    "CatFrames",
    "CatTensors",
    "CenterCrop",
    "ChessEnv",
    "ClipTransform",
    "Compose",
    "ConditionalPolicySwitch",
    "ConditionalSkip",
    "Crop",
    "DMControlEnv",
    "DMControlWrapper",
    "DTypeCastTransform",
    "DeviceCastTransform",
    "DiscreteActionProjection",
    "DoubleToFloat",
    "DreamerDecoder",
    "DreamerEnv",
    "EndOfLifeTransform",
    "EnvBase",
    "EnvCreator",
    "EnvMetaData",
    "ExcludeTransform",
    "ExplorationType",
    "FiniteTensorDictCheck",
    "FlattenObservation",
    "FrameSkipTransform",
    "GrayScale",
    "GymEnv",
    "GymLikeEnv",
    "GymWrapper",
    "HabitatEnv",
    "Hash",
    "InitTracker",
    "IsaacGymEnv",
    "IsaacGymWrapper",
    "IsaacLabWrapper",
    "JumanjiEnv",
    "JumanjiWrapper",
    "LLMHashingEnv",
    "LineariseRewards",
    "MOGymEnv",
    "MOGymWrapper",
    "MarlGroupMapType",
    "MeltingpotEnv",
    "MeltingpotWrapper",
    "ModelBasedEnvBase",
    "MultiAction",
    "MultiStepTransform",
    "MultiThreadedEnv",
    "MultiThreadedEnvWrapper",
    "NoopResetEnv",
    "ObservationNorm",
    "ObservationTransform",
    "OpenMLEnv",
    "OpenSpielEnv",
    "OpenSpielWrapper",
    "ParallelEnv",
    "PendulumEnv",
    "PermuteTransform",
    "PettingZooEnv",
    "PettingZooWrapper",
    "PinMemoryTransform",
    "ProcessorAsyncEnvPool",
    "ProcgenEnv",
    "ProcgenWrapper",
    "R3MTransform",
    "RandomCropTensorDict",
    "RemoveEmptySpecs",
    "RenameTransform",
    "Resize",
    "Reward2GoTransform",
    "RewardClipping",
    "RewardScaling",
    "RewardSum",
    "RoboHiveEnv",
    "SMACv2Env",
    "SMACv2Wrapper",
    "SelectTransform",
    "SerialEnv",
    "SignTransform",
    "SqueezeTransform",
    "Stack",
    "StepCounter",
    "TargetReturn",
    "TensorDictPrimer",
    "ThreadingAsyncEnvPool",
    "TicTacToeEnv",
    "TimeMaxPool",
    "Timer",
    "ToTensorImage",
    "Tokenizer",
    "TrajCounter",
    "Transform",
    "TransformedEnv",
    "UnaryTransform",
    "UnityMLAgentsEnv",
    "UnityMLAgentsWrapper",
    "UnsqueezeTransform",
    "VC1Transform",
    "VIPRewardTransform",
    "VIPTransform",
    "VecGymEnvTransform",
    "VecNorm",
    "VecNormV2",
    "VmasEnv",
    "VmasWrapper",
    "check_env_specs",
    "check_marl_grouping",
    "default_info_dict_reader",
    "env_creator",
    "exploration_type",
    "gSDENoise",
    "get_available_libraries",
    "get_env_metadata",
    "gym_backend",
    "make_composite_from_td",
    "make_tensordict",
    "register_gym_spec_conversion",
    "set_exploration_type",
    "set_gym_backend",
    "step_mdp",
    "terminated_or_truncated",
]
|