torchrl-0.11.0-cp314-cp314-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/benchmark_batched_envs.py +104 -0
- benchmarks/conftest.py +91 -0
- benchmarks/ecosystem/gym_env_throughput.py +321 -0
- benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
- benchmarks/requirements.txt +7 -0
- benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
- benchmarks/test_collectors_benchmark.py +240 -0
- benchmarks/test_compressed_storage_benchmark.py +145 -0
- benchmarks/test_envs_benchmark.py +133 -0
- benchmarks/test_llm.py +101 -0
- benchmarks/test_non_tensor_env_benchmark.py +70 -0
- benchmarks/test_objectives_benchmarks.py +1199 -0
- benchmarks/test_replaybuffer_benchmark.py +254 -0
- sota-check/README.md +35 -0
- sota-implementations/README.md +142 -0
- sota-implementations/a2c/README.md +39 -0
- sota-implementations/a2c/a2c_atari.py +291 -0
- sota-implementations/a2c/a2c_mujoco.py +273 -0
- sota-implementations/a2c/utils_atari.py +240 -0
- sota-implementations/a2c/utils_mujoco.py +160 -0
- sota-implementations/bandits/README.md +7 -0
- sota-implementations/bandits/dqn.py +126 -0
- sota-implementations/cql/cql_offline.py +198 -0
- sota-implementations/cql/cql_online.py +249 -0
- sota-implementations/cql/discrete_cql_offline.py +180 -0
- sota-implementations/cql/discrete_cql_online.py +227 -0
- sota-implementations/cql/utils.py +471 -0
- sota-implementations/crossq/crossq.py +271 -0
- sota-implementations/crossq/utils.py +320 -0
- sota-implementations/ddpg/ddpg.py +231 -0
- sota-implementations/ddpg/utils.py +325 -0
- sota-implementations/decision_transformer/dt.py +163 -0
- sota-implementations/decision_transformer/lamb.py +167 -0
- sota-implementations/decision_transformer/online_dt.py +178 -0
- sota-implementations/decision_transformer/utils.py +562 -0
- sota-implementations/discrete_sac/discrete_sac.py +243 -0
- sota-implementations/discrete_sac/utils.py +324 -0
- sota-implementations/dqn/README.md +30 -0
- sota-implementations/dqn/dqn_atari.py +272 -0
- sota-implementations/dqn/dqn_cartpole.py +236 -0
- sota-implementations/dqn/utils_atari.py +132 -0
- sota-implementations/dqn/utils_cartpole.py +90 -0
- sota-implementations/dreamer/README.md +129 -0
- sota-implementations/dreamer/dreamer.py +586 -0
- sota-implementations/dreamer/dreamer_utils.py +1107 -0
- sota-implementations/expert-iteration/README.md +352 -0
- sota-implementations/expert-iteration/ei_utils.py +770 -0
- sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
- sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
- sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
- sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
- sota-implementations/gail/gail.py +327 -0
- sota-implementations/gail/gail_utils.py +68 -0
- sota-implementations/gail/ppo_utils.py +157 -0
- sota-implementations/grpo/README.md +273 -0
- sota-implementations/grpo/grpo-async.py +437 -0
- sota-implementations/grpo/grpo-sync.py +435 -0
- sota-implementations/grpo/grpo_utils.py +843 -0
- sota-implementations/grpo/requirements_gsm8k.txt +11 -0
- sota-implementations/grpo/requirements_ifeval.txt +16 -0
- sota-implementations/impala/README.md +33 -0
- sota-implementations/impala/impala_multi_node_ray.py +292 -0
- sota-implementations/impala/impala_multi_node_submitit.py +284 -0
- sota-implementations/impala/impala_single_node.py +261 -0
- sota-implementations/impala/utils.py +184 -0
- sota-implementations/iql/discrete_iql.py +230 -0
- sota-implementations/iql/iql_offline.py +164 -0
- sota-implementations/iql/iql_online.py +225 -0
- sota-implementations/iql/utils.py +437 -0
- sota-implementations/multiagent/README.md +74 -0
- sota-implementations/multiagent/iql.py +237 -0
- sota-implementations/multiagent/maddpg_iddpg.py +266 -0
- sota-implementations/multiagent/mappo_ippo.py +267 -0
- sota-implementations/multiagent/qmix_vdn.py +271 -0
- sota-implementations/multiagent/sac.py +337 -0
- sota-implementations/multiagent/utils/__init__.py +4 -0
- sota-implementations/multiagent/utils/logging.py +151 -0
- sota-implementations/multiagent/utils/utils.py +43 -0
- sota-implementations/ppo/README.md +29 -0
- sota-implementations/ppo/ppo_atari.py +305 -0
- sota-implementations/ppo/ppo_mujoco.py +293 -0
- sota-implementations/ppo/utils_atari.py +238 -0
- sota-implementations/ppo/utils_mujoco.py +152 -0
- sota-implementations/ppo_trainer/train.py +21 -0
- sota-implementations/redq/README.md +7 -0
- sota-implementations/redq/redq.py +199 -0
- sota-implementations/redq/utils.py +1060 -0
- sota-implementations/sac/sac-async.py +266 -0
- sota-implementations/sac/sac.py +239 -0
- sota-implementations/sac/utils.py +381 -0
- sota-implementations/sac_trainer/train.py +16 -0
- sota-implementations/td3/td3.py +254 -0
- sota-implementations/td3/utils.py +319 -0
- sota-implementations/td3_bc/td3_bc.py +177 -0
- sota-implementations/td3_bc/utils.py +251 -0
- torchrl/.dylibs/libc++.1.0.dylib +0 -0
- torchrl/__init__.py +144 -0
- torchrl/_extension.py +74 -0
- torchrl/_torchrl.cpython-314-darwin.so +0 -0
- torchrl/_utils.py +1431 -0
- torchrl/collectors/__init__.py +48 -0
- torchrl/collectors/_base.py +1058 -0
- torchrl/collectors/_constants.py +88 -0
- torchrl/collectors/_multi_async.py +324 -0
- torchrl/collectors/_multi_base.py +1805 -0
- torchrl/collectors/_multi_sync.py +464 -0
- torchrl/collectors/_runner.py +581 -0
- torchrl/collectors/_single.py +2009 -0
- torchrl/collectors/_single_async.py +259 -0
- torchrl/collectors/collectors.py +62 -0
- torchrl/collectors/distributed/__init__.py +32 -0
- torchrl/collectors/distributed/default_configs.py +133 -0
- torchrl/collectors/distributed/generic.py +1306 -0
- torchrl/collectors/distributed/ray.py +1092 -0
- torchrl/collectors/distributed/rpc.py +1006 -0
- torchrl/collectors/distributed/sync.py +731 -0
- torchrl/collectors/distributed/utils.py +160 -0
- torchrl/collectors/llm/__init__.py +10 -0
- torchrl/collectors/llm/base.py +494 -0
- torchrl/collectors/llm/ray_collector.py +275 -0
- torchrl/collectors/llm/utils.py +36 -0
- torchrl/collectors/llm/weight_update/__init__.py +10 -0
- torchrl/collectors/llm/weight_update/vllm.py +348 -0
- torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
- torchrl/collectors/utils.py +433 -0
- torchrl/collectors/weight_update.py +591 -0
- torchrl/csrc/numpy_utils.h +38 -0
- torchrl/csrc/pybind.cpp +27 -0
- torchrl/csrc/segment_tree.h +458 -0
- torchrl/csrc/torch_utils.h +34 -0
- torchrl/csrc/utils.cpp +48 -0
- torchrl/csrc/utils.h +31 -0
- torchrl/data/__init__.py +187 -0
- torchrl/data/datasets/__init__.py +58 -0
- torchrl/data/datasets/atari_dqn.py +878 -0
- torchrl/data/datasets/common.py +281 -0
- torchrl/data/datasets/d4rl.py +489 -0
- torchrl/data/datasets/d4rl_infos.py +187 -0
- torchrl/data/datasets/gen_dgrl.py +375 -0
- torchrl/data/datasets/minari_data.py +643 -0
- torchrl/data/datasets/openml.py +177 -0
- torchrl/data/datasets/openx.py +798 -0
- torchrl/data/datasets/roboset.py +363 -0
- torchrl/data/datasets/utils.py +11 -0
- torchrl/data/datasets/vd4rl.py +432 -0
- torchrl/data/llm/__init__.py +34 -0
- torchrl/data/llm/dataset.py +491 -0
- torchrl/data/llm/history.py +1378 -0
- torchrl/data/llm/prompt.py +198 -0
- torchrl/data/llm/reward.py +225 -0
- torchrl/data/llm/topk.py +186 -0
- torchrl/data/llm/utils.py +543 -0
- torchrl/data/map/__init__.py +21 -0
- torchrl/data/map/hash.py +185 -0
- torchrl/data/map/query.py +204 -0
- torchrl/data/map/tdstorage.py +363 -0
- torchrl/data/map/tree.py +1434 -0
- torchrl/data/map/utils.py +103 -0
- torchrl/data/postprocs/__init__.py +8 -0
- torchrl/data/postprocs/postprocs.py +391 -0
- torchrl/data/replay_buffers/__init__.py +99 -0
- torchrl/data/replay_buffers/checkpointers.py +622 -0
- torchrl/data/replay_buffers/ray_buffer.py +292 -0
- torchrl/data/replay_buffers/replay_buffers.py +2376 -0
- torchrl/data/replay_buffers/samplers.py +2578 -0
- torchrl/data/replay_buffers/scheduler.py +265 -0
- torchrl/data/replay_buffers/storages.py +2412 -0
- torchrl/data/replay_buffers/utils.py +1042 -0
- torchrl/data/replay_buffers/writers.py +781 -0
- torchrl/data/tensor_specs.py +7101 -0
- torchrl/data/utils.py +334 -0
- torchrl/envs/__init__.py +265 -0
- torchrl/envs/async_envs.py +1105 -0
- torchrl/envs/batched_envs.py +3093 -0
- torchrl/envs/common.py +4241 -0
- torchrl/envs/custom/__init__.py +11 -0
- torchrl/envs/custom/chess.py +617 -0
- torchrl/envs/custom/llm.py +214 -0
- torchrl/envs/custom/pendulum.py +401 -0
- torchrl/envs/custom/san_moves.txt +29274 -0
- torchrl/envs/custom/tictactoeenv.py +288 -0
- torchrl/envs/env_creator.py +263 -0
- torchrl/envs/gym_like.py +752 -0
- torchrl/envs/libs/__init__.py +68 -0
- torchrl/envs/libs/_gym_utils.py +326 -0
- torchrl/envs/libs/brax.py +846 -0
- torchrl/envs/libs/dm_control.py +544 -0
- torchrl/envs/libs/envpool.py +447 -0
- torchrl/envs/libs/gym.py +2239 -0
- torchrl/envs/libs/habitat.py +138 -0
- torchrl/envs/libs/isaac_lab.py +87 -0
- torchrl/envs/libs/isaacgym.py +203 -0
- torchrl/envs/libs/jax_utils.py +166 -0
- torchrl/envs/libs/jumanji.py +963 -0
- torchrl/envs/libs/meltingpot.py +599 -0
- torchrl/envs/libs/openml.py +153 -0
- torchrl/envs/libs/openspiel.py +652 -0
- torchrl/envs/libs/pettingzoo.py +1042 -0
- torchrl/envs/libs/procgen.py +351 -0
- torchrl/envs/libs/robohive.py +429 -0
- torchrl/envs/libs/smacv2.py +645 -0
- torchrl/envs/libs/unity_mlagents.py +891 -0
- torchrl/envs/libs/utils.py +147 -0
- torchrl/envs/libs/vmas.py +813 -0
- torchrl/envs/llm/__init__.py +63 -0
- torchrl/envs/llm/chat.py +730 -0
- torchrl/envs/llm/datasets/README.md +4 -0
- torchrl/envs/llm/datasets/__init__.py +17 -0
- torchrl/envs/llm/datasets/gsm8k.py +353 -0
- torchrl/envs/llm/datasets/ifeval.py +274 -0
- torchrl/envs/llm/envs.py +789 -0
- torchrl/envs/llm/libs/README.md +3 -0
- torchrl/envs/llm/libs/__init__.py +8 -0
- torchrl/envs/llm/libs/mlgym.py +869 -0
- torchrl/envs/llm/reward/__init__.py +10 -0
- torchrl/envs/llm/reward/gsm8k.py +324 -0
- torchrl/envs/llm/reward/ifeval/README.md +13 -0
- torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
- torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
- torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
- torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
- torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
- torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
- torchrl/envs/llm/transforms/__init__.py +55 -0
- torchrl/envs/llm/transforms/browser.py +292 -0
- torchrl/envs/llm/transforms/dataloading.py +859 -0
- torchrl/envs/llm/transforms/format.py +73 -0
- torchrl/envs/llm/transforms/kl.py +1544 -0
- torchrl/envs/llm/transforms/policy_version.py +189 -0
- torchrl/envs/llm/transforms/reason.py +323 -0
- torchrl/envs/llm/transforms/tokenizer.py +321 -0
- torchrl/envs/llm/transforms/tools.py +1955 -0
- torchrl/envs/model_based/__init__.py +9 -0
- torchrl/envs/model_based/common.py +180 -0
- torchrl/envs/model_based/dreamer.py +112 -0
- torchrl/envs/transforms/__init__.py +147 -0
- torchrl/envs/transforms/functional.py +48 -0
- torchrl/envs/transforms/gym_transforms.py +203 -0
- torchrl/envs/transforms/module.py +341 -0
- torchrl/envs/transforms/r3m.py +372 -0
- torchrl/envs/transforms/ray_service.py +663 -0
- torchrl/envs/transforms/rb_transforms.py +214 -0
- torchrl/envs/transforms/transforms.py +11835 -0
- torchrl/envs/transforms/utils.py +94 -0
- torchrl/envs/transforms/vc1.py +307 -0
- torchrl/envs/transforms/vecnorm.py +845 -0
- torchrl/envs/transforms/vip.py +407 -0
- torchrl/envs/utils.py +1718 -0
- torchrl/envs/vec_envs.py +11 -0
- torchrl/modules/__init__.py +206 -0
- torchrl/modules/distributions/__init__.py +73 -0
- torchrl/modules/distributions/continuous.py +830 -0
- torchrl/modules/distributions/discrete.py +908 -0
- torchrl/modules/distributions/truncated_normal.py +187 -0
- torchrl/modules/distributions/utils.py +233 -0
- torchrl/modules/llm/__init__.py +62 -0
- torchrl/modules/llm/backends/__init__.py +65 -0
- torchrl/modules/llm/backends/vllm/__init__.py +94 -0
- torchrl/modules/llm/backends/vllm/_models.py +46 -0
- torchrl/modules/llm/backends/vllm/base.py +72 -0
- torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
- torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
- torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
- torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
- torchrl/modules/llm/policies/__init__.py +28 -0
- torchrl/modules/llm/policies/common.py +1809 -0
- torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
- torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
- torchrl/modules/llm/utils.py +23 -0
- torchrl/modules/mcts/__init__.py +21 -0
- torchrl/modules/mcts/scores.py +579 -0
- torchrl/modules/models/__init__.py +86 -0
- torchrl/modules/models/batchrenorm.py +119 -0
- torchrl/modules/models/decision_transformer.py +179 -0
- torchrl/modules/models/exploration.py +731 -0
- torchrl/modules/models/llm.py +156 -0
- torchrl/modules/models/model_based.py +596 -0
- torchrl/modules/models/models.py +1712 -0
- torchrl/modules/models/multiagent.py +1067 -0
- torchrl/modules/models/recipes/impala.py +185 -0
- torchrl/modules/models/utils.py +162 -0
- torchrl/modules/planners/__init__.py +10 -0
- torchrl/modules/planners/cem.py +228 -0
- torchrl/modules/planners/common.py +73 -0
- torchrl/modules/planners/mppi.py +265 -0
- torchrl/modules/tensordict_module/__init__.py +89 -0
- torchrl/modules/tensordict_module/actors.py +2457 -0
- torchrl/modules/tensordict_module/common.py +529 -0
- torchrl/modules/tensordict_module/exploration.py +814 -0
- torchrl/modules/tensordict_module/probabilistic.py +321 -0
- torchrl/modules/tensordict_module/rnn.py +1639 -0
- torchrl/modules/tensordict_module/sequence.py +132 -0
- torchrl/modules/tensordict_module/world_models.py +34 -0
- torchrl/modules/utils/__init__.py +38 -0
- torchrl/modules/utils/mappings.py +9 -0
- torchrl/modules/utils/utils.py +89 -0
- torchrl/objectives/__init__.py +78 -0
- torchrl/objectives/a2c.py +659 -0
- torchrl/objectives/common.py +753 -0
- torchrl/objectives/cql.py +1346 -0
- torchrl/objectives/crossq.py +710 -0
- torchrl/objectives/ddpg.py +453 -0
- torchrl/objectives/decision_transformer.py +371 -0
- torchrl/objectives/deprecated.py +516 -0
- torchrl/objectives/dqn.py +683 -0
- torchrl/objectives/dreamer.py +488 -0
- torchrl/objectives/functional.py +48 -0
- torchrl/objectives/gail.py +258 -0
- torchrl/objectives/iql.py +996 -0
- torchrl/objectives/llm/__init__.py +30 -0
- torchrl/objectives/llm/grpo.py +846 -0
- torchrl/objectives/llm/sft.py +482 -0
- torchrl/objectives/multiagent/__init__.py +8 -0
- torchrl/objectives/multiagent/qmixer.py +396 -0
- torchrl/objectives/ppo.py +1669 -0
- torchrl/objectives/redq.py +683 -0
- torchrl/objectives/reinforce.py +530 -0
- torchrl/objectives/sac.py +1580 -0
- torchrl/objectives/td3.py +570 -0
- torchrl/objectives/td3_bc.py +625 -0
- torchrl/objectives/utils.py +782 -0
- torchrl/objectives/value/__init__.py +28 -0
- torchrl/objectives/value/advantages.py +1956 -0
- torchrl/objectives/value/functional.py +1459 -0
- torchrl/objectives/value/utils.py +360 -0
- torchrl/record/__init__.py +17 -0
- torchrl/record/loggers/__init__.py +23 -0
- torchrl/record/loggers/common.py +48 -0
- torchrl/record/loggers/csv.py +226 -0
- torchrl/record/loggers/mlflow.py +142 -0
- torchrl/record/loggers/tensorboard.py +139 -0
- torchrl/record/loggers/trackio.py +163 -0
- torchrl/record/loggers/utils.py +78 -0
- torchrl/record/loggers/wandb.py +214 -0
- torchrl/record/recorder.py +554 -0
- torchrl/services/__init__.py +79 -0
- torchrl/services/base.py +109 -0
- torchrl/services/ray_service.py +453 -0
- torchrl/testing/__init__.py +107 -0
- torchrl/testing/assertions.py +179 -0
- torchrl/testing/dist_utils.py +122 -0
- torchrl/testing/env_creators.py +227 -0
- torchrl/testing/env_helper.py +35 -0
- torchrl/testing/gym_helpers.py +156 -0
- torchrl/testing/llm_mocks.py +119 -0
- torchrl/testing/mocking_classes.py +2720 -0
- torchrl/testing/modules.py +295 -0
- torchrl/testing/mp_helpers.py +15 -0
- torchrl/testing/ray_helpers.py +293 -0
- torchrl/testing/utils.py +190 -0
- torchrl/trainers/__init__.py +42 -0
- torchrl/trainers/algorithms/__init__.py +11 -0
- torchrl/trainers/algorithms/configs/__init__.py +705 -0
- torchrl/trainers/algorithms/configs/collectors.py +216 -0
- torchrl/trainers/algorithms/configs/common.py +41 -0
- torchrl/trainers/algorithms/configs/data.py +308 -0
- torchrl/trainers/algorithms/configs/envs.py +104 -0
- torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
- torchrl/trainers/algorithms/configs/logging.py +80 -0
- torchrl/trainers/algorithms/configs/modules.py +570 -0
- torchrl/trainers/algorithms/configs/objectives.py +177 -0
- torchrl/trainers/algorithms/configs/trainers.py +340 -0
- torchrl/trainers/algorithms/configs/transforms.py +955 -0
- torchrl/trainers/algorithms/configs/utils.py +252 -0
- torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
- torchrl/trainers/algorithms/configs/weight_update.py +159 -0
- torchrl/trainers/algorithms/ppo.py +373 -0
- torchrl/trainers/algorithms/sac.py +308 -0
- torchrl/trainers/helpers/__init__.py +40 -0
- torchrl/trainers/helpers/collectors.py +416 -0
- torchrl/trainers/helpers/envs.py +573 -0
- torchrl/trainers/helpers/logger.py +33 -0
- torchrl/trainers/helpers/losses.py +132 -0
- torchrl/trainers/helpers/models.py +658 -0
- torchrl/trainers/helpers/replay_buffer.py +59 -0
- torchrl/trainers/helpers/trainers.py +301 -0
- torchrl/trainers/trainers.py +2052 -0
- torchrl/weight_update/__init__.py +33 -0
- torchrl/weight_update/_distributed.py +749 -0
- torchrl/weight_update/_mp.py +624 -0
- torchrl/weight_update/_noupdate.py +102 -0
- torchrl/weight_update/_ray.py +1032 -0
- torchrl/weight_update/_rpc.py +284 -0
- torchrl/weight_update/_shared.py +891 -0
- torchrl/weight_update/llm/__init__.py +32 -0
- torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
- torchrl/weight_update/llm/vllm_nccl.py +710 -0
- torchrl/weight_update/utils.py +73 -0
- torchrl/weight_update/weight_sync_schemes.py +1244 -0
- torchrl-0.11.0.dist-info/METADATA +1308 -0
- torchrl-0.11.0.dist-info/RECORD +395 -0
- torchrl-0.11.0.dist-info/WHEEL +5 -0
- torchrl-0.11.0.dist-info/entry_points.txt +2 -0
- torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
- torchrl-0.11.0.dist-info/top_level.txt +7 -0
|
@@ -0,0 +1,361 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
#
|
|
3
|
+
# This source code is licensed under the MIT license found in the
|
|
4
|
+
# LICENSE file in the root directory of this source tree.
|
|
5
|
+
|
|
6
|
+
from __future__ import annotations
|
|
7
|
+
|
|
8
|
+
from dataclasses import dataclass
|
|
9
|
+
from typing import Any
|
|
10
|
+
|
|
11
|
+
from omegaconf import MISSING
|
|
12
|
+
from torchrl.envs.libs.gym import set_gym_backend
|
|
13
|
+
from torchrl.envs.transforms.transforms import DoubleToFloat
|
|
14
|
+
from torchrl.trainers.algorithms.configs.common import ConfigBase
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
@dataclass
class EnvLibsConfig(ConfigBase):
    """Common parent of every environment-library configuration in this module.

    It contributes the ``_partial_`` flag and a ``__post_init__`` hook that the
    concrete configs chain to via ``super().__post_init__()``.
    """

    # Hydra-style flag; False means the target is built outright rather than
    # returned as a partial (presumably consumed by the instantiation machinery).
    _partial_: bool = False

    def __post_init__(self) -> None:
        """Shared post-init hook for env-lib configs; currently does nothing."""
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
@dataclass
class GymEnvConfig(EnvLibsConfig):
    """Configuration for GymEnv environment.

    The ``_target_`` is :func:`make_gym_env` (defined in this module), not
    ``GymEnv`` itself. That factory consumes ``backend``, ``from_pixels`` and
    ``double_to_float``, and forwards every other keyword argument to
    :class:`~torchrl.envs.libs.gym.GymEnv`.
    """

    # Defaults to omegaconf's MISSING sentinel: expected to be supplied.
    env_name: str = MISSING
    categorical_action_encoding: bool = False
    from_pixels: bool = False
    pixels_only: bool = True
    frame_skip: int = 1
    device: str = "cpu"
    batch_size: list[int] | None = None
    allow_done_after_reset: bool = False
    convert_actions_to_numpy: bool = True
    missing_obs_value: Any = None
    disable_env_checker: bool | None = None
    render_mode: str | None = None
    num_envs: int = 0
    # Gym backend activated by make_gym_env while building the env.
    backend: str = "gymnasium"
    # Exposes make_gym_env's ``double_to_float`` argument, which previously had
    # no corresponding field and so could not be set through a structured
    # config. The default matches make_gym_env's default, so existing configs
    # behave exactly as before.
    double_to_float: bool = False
    _target_: str = "torchrl.trainers.algorithms.configs.envs_libs.make_gym_env"

    def __post_init__(self) -> None:
        """Post-initialization hook for GymEnv configuration."""
        super().__post_init__()
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def make_gym_env(
    env_name: str,
    backend: str = "gymnasium",
    from_pixels: bool = False,
    double_to_float: bool = False,
    **kwargs,
):
    """Create a Gym/Gymnasium environment.

    Args:
        env_name: Name of the environment to create.
        backend: Backend to activate with ``set_gym_backend`` while the
            environment is constructed (e.g. ``"gym"`` or ``"gymnasium"``).
            If ``None``, no backend is forced.
        from_pixels: Whether to use pixel observations.
        double_to_float: If ``True``, append a ``DoubleToFloat`` transform
            acting on the ``"observation"`` key.
        **kwargs: Extra keyword arguments forwarded unchanged to ``GymEnv``.

    Returns:
        The created environment instance. When ``double_to_float`` is set,
        this is the value returned by ``env.append_transform(...)``.
    """
    from contextlib import nullcontext

    from torchrl.envs.libs.gym import GymEnv

    # Only switch backends when one is requested. A single construction site
    # keeps the backend / no-backend paths from drifting apart.
    backend_ctx = set_gym_backend(backend) if backend is not None else nullcontext()
    with backend_ctx:
        env = GymEnv(env_name, from_pixels=from_pixels, **kwargs)

    if double_to_float:
        env = env.append_transform(DoubleToFloat(in_keys=["observation"]))

    return env
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
@dataclass
class MOGymEnvConfig(EnvLibsConfig):
    """Configuration for the multi-objective ``MOGymEnv`` environment.

    Unlike ``GymEnvConfig``, ``_target_`` points directly at
    :class:`torchrl.envs.libs.gym.MOGymEnv`, and ``backend`` / ``frame_skip``
    default to ``None``.
    """

    # Defaults to omegaconf's MISSING sentinel: expected to be supplied.
    env_name: str = MISSING
    categorical_action_encoding: bool = False
    from_pixels: bool = False
    pixels_only: bool = True
    frame_skip: int | None = None
    device: str = "cpu"
    batch_size: list[int] | None = None
    allow_done_after_reset: bool = False
    convert_actions_to_numpy: bool = True
    missing_obs_value: Any = None
    backend: str | None = None
    disable_env_checker: bool | None = None
    render_mode: str | None = None
    num_envs: int = 0

    _target_: str = "torchrl.envs.libs.gym.MOGymEnv"

    def __post_init__(self) -> None:
        """Chain to the parent env-lib post-init hook."""
        super().__post_init__()
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
@dataclass
class BraxEnvConfig(EnvLibsConfig):
    """Configuration for the Brax-backed ``BraxEnv`` environment.

    ``_target_`` points at :class:`torchrl.envs.libs.brax.BraxEnv`. This
    config has no ``pixels_only`` field; it adds ``cache_clear_frequency`` and
    ``requires_grad`` instead.
    """

    # Defaults to omegaconf's MISSING sentinel: expected to be supplied.
    env_name: str = MISSING
    categorical_action_encoding: bool = False
    cache_clear_frequency: int | None = None
    from_pixels: bool = False
    frame_skip: int | None = None
    device: str = "cpu"
    batch_size: list[int] | None = None
    allow_done_after_reset: bool = False
    requires_grad: bool = False

    _target_: str = "torchrl.envs.libs.brax.BraxEnv"

    def __post_init__(self) -> None:
        """Chain to the parent env-lib post-init hook."""
        super().__post_init__()
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
@dataclass
class DMControlEnvConfig(EnvLibsConfig):
    """Configuration for the DeepMind Control ``DMControlEnv`` environment.

    ``_target_`` points at :class:`torchrl.envs.libs.dm_control.DMControlEnv`.
    Both ``env_name`` and ``task_name`` default to omegaconf's ``MISSING``
    sentinel and are expected to be supplied.
    """

    env_name: str = MISSING
    task_name: str = MISSING
    from_pixels: bool = False
    pixels_only: bool = True
    frame_skip: int | None = None
    device: str = "cpu"
    batch_size: list[int] | None = None
    allow_done_after_reset: bool = False

    _target_: str = "torchrl.envs.libs.dm_control.DMControlEnv"

    def __post_init__(self) -> None:
        """Chain to the parent env-lib post-init hook."""
        super().__post_init__()
|
|
146
|
+
|
|
147
|
+
|
|
148
|
+
@dataclass
class HabitatEnvConfig(EnvLibsConfig):
    """Configuration for the ``HabitatEnv`` environment.

    ``_target_`` points at :class:`torchrl.envs.libs.habitat.HabitatEnv`.
    """

    # Defaults to omegaconf's MISSING sentinel: expected to be supplied.
    env_name: str = MISSING
    from_pixels: bool = False
    pixels_only: bool = True
    frame_skip: int | None = None
    device: str = "cpu"
    batch_size: list[int] | None = None
    allow_done_after_reset: bool = False

    _target_: str = "torchrl.envs.libs.habitat.HabitatEnv"

    def __post_init__(self) -> None:
        """Chain to the parent env-lib post-init hook."""
        super().__post_init__()
|
|
164
|
+
|
|
165
|
+
|
|
166
|
+
@dataclass
class IsaacGymEnvConfig(EnvLibsConfig):
    """Configuration for the ``IsaacGymEnv`` environment.

    ``_target_`` points at :class:`torchrl.envs.libs.isaacgym.IsaacGymEnv`.
    """

    # Defaults to omegaconf's MISSING sentinel: expected to be supplied.
    # NOTE(review): confirm IsaacGymEnv accepts an ``env_name`` keyword; its
    # task parameter may be named differently (possibly ``task``).
    env_name: str = MISSING
    from_pixels: bool = False
    pixels_only: bool = True
    frame_skip: int | None = None
    device: str = "cpu"
    batch_size: list[int] | None = None
    allow_done_after_reset: bool = False

    _target_: str = "torchrl.envs.libs.isaacgym.IsaacGymEnv"

    def __post_init__(self) -> None:
        """Chain to the parent env-lib post-init hook."""
        super().__post_init__()
|
|
182
|
+
|
|
183
|
+
|
|
184
|
+
@dataclass
class JumanjiEnvConfig(EnvLibsConfig):
    """Configuration for the ``JumanjiEnv`` environment.

    ``_target_`` points at :class:`torchrl.envs.libs.jumanji.JumanjiEnv`.
    """

    # Defaults to omegaconf's MISSING sentinel: expected to be supplied.
    env_name: str = MISSING
    from_pixels: bool = False
    pixels_only: bool = True
    frame_skip: int | None = None
    device: str = "cpu"
    batch_size: list[int] | None = None
    allow_done_after_reset: bool = False

    _target_: str = "torchrl.envs.libs.jumanji.JumanjiEnv"

    def __post_init__(self) -> None:
        """Chain to the parent env-lib post-init hook."""
        super().__post_init__()
|
|
200
|
+
|
|
201
|
+
|
|
202
|
+
@dataclass
class MeltingpotEnvConfig(EnvLibsConfig):
    """Configuration for the ``MeltingpotEnv`` environment.

    ``_target_`` points at :class:`torchrl.envs.libs.meltingpot.MeltingpotEnv`.
    """

    # Defaults to omegaconf's MISSING sentinel: expected to be supplied.
    # NOTE(review): confirm MeltingpotEnv accepts an ``env_name`` keyword; its
    # first parameter may be named differently (possibly ``substrate``).
    env_name: str = MISSING
    from_pixels: bool = False
    pixels_only: bool = True
    frame_skip: int | None = None
    device: str = "cpu"
    batch_size: list[int] | None = None
    allow_done_after_reset: bool = False

    _target_: str = "torchrl.envs.libs.meltingpot.MeltingpotEnv"

    def __post_init__(self) -> None:
        """Chain to the parent env-lib post-init hook."""
        super().__post_init__()
|
|
218
|
+
|
|
219
|
+
|
|
220
|
+
@dataclass
class OpenMLEnvConfig(EnvLibsConfig):
    """Configuration for the ``OpenMLEnv`` environment.

    ``_target_`` points at :class:`torchrl.envs.libs.openml.OpenMLEnv`.
    """

    # Defaults to omegaconf's MISSING sentinel: expected to be supplied.
    # NOTE(review): confirm OpenMLEnv accepts ``env_name`` and the pixel /
    # frame-skip options below; its first parameter may be named differently
    # (possibly ``dataset_name``).
    env_name: str = MISSING
    from_pixels: bool = False
    pixels_only: bool = True
    frame_skip: int | None = None
    device: str = "cpu"
    batch_size: list[int] | None = None
    allow_done_after_reset: bool = False

    _target_: str = "torchrl.envs.libs.openml.OpenMLEnv"

    def __post_init__(self) -> None:
        """Chain to the parent env-lib post-init hook."""
        super().__post_init__()
|
|
236
|
+
|
|
237
|
+
|
|
238
|
+
@dataclass
class OpenSpielEnvConfig(EnvLibsConfig):
    """Configuration for the ``OpenSpielEnv`` environment.

    ``_target_`` points at :class:`torchrl.envs.libs.openspiel.OpenSpielEnv`.
    """

    # Defaults to omegaconf's MISSING sentinel: expected to be supplied.
    # NOTE(review): confirm OpenSpielEnv accepts an ``env_name`` keyword; its
    # first parameter may be named differently (possibly ``game_string``).
    env_name: str = MISSING
    from_pixels: bool = False
    pixels_only: bool = True
    frame_skip: int | None = None
    device: str = "cpu"
    batch_size: list[int] | None = None
    allow_done_after_reset: bool = False

    _target_: str = "torchrl.envs.libs.openspiel.OpenSpielEnv"

    def __post_init__(self) -> None:
        """Chain to the parent env-lib post-init hook."""
        super().__post_init__()
|
|
254
|
+
|
|
255
|
+
|
|
256
|
+
@dataclass
class PettingZooEnvConfig(EnvLibsConfig):
    """Configuration for the ``PettingZooEnv`` environment.

    ``_target_`` points at :class:`torchrl.envs.libs.pettingzoo.PettingZooEnv`.
    """

    # Defaults to omegaconf's MISSING sentinel: expected to be supplied.
    # NOTE(review): confirm PettingZooEnv accepts an ``env_name`` keyword; its
    # first parameter may be named differently (possibly ``task``).
    env_name: str = MISSING
    from_pixels: bool = False
    pixels_only: bool = True
    frame_skip: int | None = None
    device: str = "cpu"
    batch_size: list[int] | None = None
    allow_done_after_reset: bool = False

    _target_: str = "torchrl.envs.libs.pettingzoo.PettingZooEnv"

    def __post_init__(self) -> None:
        """Chain to the parent env-lib post-init hook."""
        super().__post_init__()
|
|
272
|
+
|
|
273
|
+
|
|
274
|
+
@dataclass
class RoboHiveEnvConfig(EnvLibsConfig):
    """Configuration for the ``RoboHiveEnv`` environment.

    ``_target_`` points at :class:`torchrl.envs.libs.robohive.RoboHiveEnv`.
    """

    # Defaults to omegaconf's MISSING sentinel: expected to be supplied.
    env_name: str = MISSING
    from_pixels: bool = False
    pixels_only: bool = True
    frame_skip: int | None = None
    device: str = "cpu"
    batch_size: list[int] | None = None
    allow_done_after_reset: bool = False

    _target_: str = "torchrl.envs.libs.robohive.RoboHiveEnv"

    def __post_init__(self) -> None:
        """Chain to the parent env-lib post-init hook."""
        super().__post_init__()
|
|
290
|
+
|
|
291
|
+
|
|
292
|
+
@dataclass
class SMACv2EnvConfig(EnvLibsConfig):
    """Configuration for the ``SMACv2Env`` environment.

    ``_target_`` points at :class:`torchrl.envs.libs.smacv2.SMACv2Env`.
    """

    # Defaults to omegaconf's MISSING sentinel: expected to be supplied.
    # NOTE(review): confirm SMACv2Env accepts an ``env_name`` keyword; its
    # first parameter may be named differently (possibly ``map_name``).
    env_name: str = MISSING
    from_pixels: bool = False
    pixels_only: bool = True
    frame_skip: int | None = None
    device: str = "cpu"
    batch_size: list[int] | None = None
    allow_done_after_reset: bool = False

    _target_: str = "torchrl.envs.libs.smacv2.SMACv2Env"

    def __post_init__(self) -> None:
        """Chain to the parent env-lib post-init hook."""
        super().__post_init__()
|
|
308
|
+
|
|
309
|
+
|
|
310
|
+
@dataclass
class UnityMLAgentsEnvConfig(EnvLibsConfig):
    """Configuration for the ``UnityMLAgentsEnv`` environment.

    ``_target_`` points at
    :class:`torchrl.envs.libs.unity_mlagents.UnityMLAgentsEnv`.
    """

    # Defaults to omegaconf's MISSING sentinel: expected to be supplied.
    # NOTE(review): confirm UnityMLAgentsEnv accepts an ``env_name`` keyword;
    # the environment may be selected by differently named parameters
    # (possibly ``file_name`` / ``registered_name``).
    env_name: str = MISSING
    from_pixels: bool = False
    pixels_only: bool = True
    frame_skip: int | None = None
    device: str = "cpu"
    batch_size: list[int] | None = None
    allow_done_after_reset: bool = False

    _target_: str = "torchrl.envs.libs.unity_mlagents.UnityMLAgentsEnv"

    def __post_init__(self) -> None:
        """Chain to the parent env-lib post-init hook."""
        super().__post_init__()
|
|
326
|
+
|
|
327
|
+
|
|
328
|
+
@dataclass
class VmasEnvConfig(EnvLibsConfig):
    """Configuration for the ``VmasEnv`` environment.

    ``_target_`` points at :class:`torchrl.envs.libs.vmas.VmasEnv`.
    """

    # Defaults to omegaconf's MISSING sentinel: expected to be supplied.
    # NOTE(review): confirm VmasEnv accepts an ``env_name`` keyword; its first
    # parameter may be named differently (possibly ``scenario``).
    env_name: str = MISSING
    from_pixels: bool = False
    pixels_only: bool = True
    frame_skip: int | None = None
    device: str = "cpu"
    batch_size: list[int] | None = None
    allow_done_after_reset: bool = False

    _target_: str = "torchrl.envs.libs.vmas.VmasEnv"

    def __post_init__(self) -> None:
        """Chain to the parent env-lib post-init hook."""
        super().__post_init__()
|
|
344
|
+
|
|
345
|
+
|
|
346
|
+
@dataclass
class MultiThreadedEnvConfig(EnvLibsConfig):
    """Configuration for the envpool-based ``MultiThreadedEnv`` environment.

    ``_target_`` points at :class:`torchrl.envs.libs.envpool.MultiThreadedEnv`.
    """

    # Defaults to omegaconf's MISSING sentinel: expected to be supplied.
    # NOTE(review): confirm MultiThreadedEnv can be built from these fields
    # alone; it may also require a worker-count argument (possibly
    # ``num_workers``) that this config does not declare.
    env_name: str = MISSING
    from_pixels: bool = False
    pixels_only: bool = True
    frame_skip: int | None = None
    device: str = "cpu"
    batch_size: list[int] | None = None
    allow_done_after_reset: bool = False

    _target_: str = "torchrl.envs.libs.envpool.MultiThreadedEnv"

    def __post_init__(self) -> None:
        """Chain to the parent env-lib post-init hook."""
        super().__post_init__()
|
|
@@ -0,0 +1,80 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
#
|
|
3
|
+
# This source code is licensed under the MIT license found in the
|
|
4
|
+
# LICENSE file in the root directory of this source tree.
|
|
5
|
+
|
|
6
|
+
from __future__ import annotations
|
|
7
|
+
|
|
8
|
+
from dataclasses import dataclass
|
|
9
|
+
|
|
10
|
+
from torchrl.trainers.algorithms.configs.common import ConfigBase
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
@dataclass
class LoggerConfig(ConfigBase):
    """Base class for logger configurations.

    This class declares no fields of its own. The logger configs in this
    module subclass it, adding logger-specific fields and a ``_target_``
    string naming the logger class to build.
    """

    def __post_init__(self) -> None:
        """Post-initialization hook; intentionally a no-op."""
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
@dataclass
class WandbLoggerConfig(LoggerConfig):
    """Configuration for a Weights & Biases logger.

    ``exp_name`` has no default and must be given. ``_target_`` names
    :class:`~torchrl.record.loggers.wandb.WandbLogger`.

    .. seealso::
        :class:`~torchrl.record.loggers.wandb.WandbLogger`
    """

    exp_name: str
    offline: bool = False
    save_dir: str | None = None
    # Named ``id`` (shadowing the builtin) presumably to mirror the logger's
    # keyword argument; keep the name as-is.
    id: str | None = None
    project: str | None = None
    video_fps: int = 32
    log_dir: str | None = None

    _target_: str = "torchrl.record.loggers.wandb.WandbLogger"

    def __post_init__(self) -> None:
        """No post-processing is performed for this config."""
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
@dataclass
class TensorboardLoggerConfig(LoggerConfig):
    """Configuration for a TensorBoard logger.

    ``exp_name`` has no default and must be given. ``log_dir`` defaults to
    ``"tb_logs"``.

    .. seealso::
        :class:`~torchrl.record.loggers.tensorboard.TensorboardLogger`
    """

    exp_name: str
    log_dir: str = "tb_logs"

    _target_: str = "torchrl.record.loggers.tensorboard.TensorboardLogger"

    def __post_init__(self) -> None:
        """No post-processing is performed for this config."""
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
@dataclass
class CSVLoggerConfig(LoggerConfig):
    """Configuration for a CSV logger.

    ``exp_name`` has no default and must be given. Videos default to the
    ``"pt"`` format at 30 fps.

    .. seealso::
        :class:`~torchrl.record.loggers.csv.CSVLogger`
    """

    exp_name: str
    log_dir: str | None = None
    video_format: str = "pt"
    video_fps: int = 30

    _target_: str = "torchrl.record.loggers.csv.CSVLogger"

    def __post_init__(self) -> None:
        """No post-processing is performed for this config."""
|