torchrl 0.11.0__cp314-cp314-manylinux_2_28_aarch64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/benchmark_batched_envs.py +104 -0
- benchmarks/conftest.py +91 -0
- benchmarks/ecosystem/gym_env_throughput.py +321 -0
- benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
- benchmarks/requirements.txt +7 -0
- benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
- benchmarks/test_collectors_benchmark.py +240 -0
- benchmarks/test_compressed_storage_benchmark.py +145 -0
- benchmarks/test_envs_benchmark.py +133 -0
- benchmarks/test_llm.py +101 -0
- benchmarks/test_non_tensor_env_benchmark.py +70 -0
- benchmarks/test_objectives_benchmarks.py +1199 -0
- benchmarks/test_replaybuffer_benchmark.py +254 -0
- sota-check/README.md +35 -0
- sota-implementations/README.md +142 -0
- sota-implementations/a2c/README.md +39 -0
- sota-implementations/a2c/a2c_atari.py +291 -0
- sota-implementations/a2c/a2c_mujoco.py +273 -0
- sota-implementations/a2c/utils_atari.py +240 -0
- sota-implementations/a2c/utils_mujoco.py +160 -0
- sota-implementations/bandits/README.md +7 -0
- sota-implementations/bandits/dqn.py +126 -0
- sota-implementations/cql/cql_offline.py +198 -0
- sota-implementations/cql/cql_online.py +249 -0
- sota-implementations/cql/discrete_cql_offline.py +180 -0
- sota-implementations/cql/discrete_cql_online.py +227 -0
- sota-implementations/cql/utils.py +471 -0
- sota-implementations/crossq/crossq.py +271 -0
- sota-implementations/crossq/utils.py +320 -0
- sota-implementations/ddpg/ddpg.py +231 -0
- sota-implementations/ddpg/utils.py +325 -0
- sota-implementations/decision_transformer/dt.py +163 -0
- sota-implementations/decision_transformer/lamb.py +167 -0
- sota-implementations/decision_transformer/online_dt.py +178 -0
- sota-implementations/decision_transformer/utils.py +562 -0
- sota-implementations/discrete_sac/discrete_sac.py +243 -0
- sota-implementations/discrete_sac/utils.py +324 -0
- sota-implementations/dqn/README.md +30 -0
- sota-implementations/dqn/dqn_atari.py +272 -0
- sota-implementations/dqn/dqn_cartpole.py +236 -0
- sota-implementations/dqn/utils_atari.py +132 -0
- sota-implementations/dqn/utils_cartpole.py +90 -0
- sota-implementations/dreamer/README.md +129 -0
- sota-implementations/dreamer/dreamer.py +586 -0
- sota-implementations/dreamer/dreamer_utils.py +1107 -0
- sota-implementations/expert-iteration/README.md +352 -0
- sota-implementations/expert-iteration/ei_utils.py +770 -0
- sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
- sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
- sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
- sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
- sota-implementations/gail/gail.py +327 -0
- sota-implementations/gail/gail_utils.py +68 -0
- sota-implementations/gail/ppo_utils.py +157 -0
- sota-implementations/grpo/README.md +273 -0
- sota-implementations/grpo/grpo-async.py +437 -0
- sota-implementations/grpo/grpo-sync.py +435 -0
- sota-implementations/grpo/grpo_utils.py +843 -0
- sota-implementations/grpo/requirements_gsm8k.txt +11 -0
- sota-implementations/grpo/requirements_ifeval.txt +16 -0
- sota-implementations/impala/README.md +33 -0
- sota-implementations/impala/impala_multi_node_ray.py +292 -0
- sota-implementations/impala/impala_multi_node_submitit.py +284 -0
- sota-implementations/impala/impala_single_node.py +261 -0
- sota-implementations/impala/utils.py +184 -0
- sota-implementations/iql/discrete_iql.py +230 -0
- sota-implementations/iql/iql_offline.py +164 -0
- sota-implementations/iql/iql_online.py +225 -0
- sota-implementations/iql/utils.py +437 -0
- sota-implementations/multiagent/README.md +74 -0
- sota-implementations/multiagent/iql.py +237 -0
- sota-implementations/multiagent/maddpg_iddpg.py +266 -0
- sota-implementations/multiagent/mappo_ippo.py +267 -0
- sota-implementations/multiagent/qmix_vdn.py +271 -0
- sota-implementations/multiagent/sac.py +337 -0
- sota-implementations/multiagent/utils/__init__.py +4 -0
- sota-implementations/multiagent/utils/logging.py +151 -0
- sota-implementations/multiagent/utils/utils.py +43 -0
- sota-implementations/ppo/README.md +29 -0
- sota-implementations/ppo/ppo_atari.py +305 -0
- sota-implementations/ppo/ppo_mujoco.py +293 -0
- sota-implementations/ppo/utils_atari.py +238 -0
- sota-implementations/ppo/utils_mujoco.py +152 -0
- sota-implementations/ppo_trainer/train.py +21 -0
- sota-implementations/redq/README.md +7 -0
- sota-implementations/redq/redq.py +199 -0
- sota-implementations/redq/utils.py +1060 -0
- sota-implementations/sac/sac-async.py +266 -0
- sota-implementations/sac/sac.py +239 -0
- sota-implementations/sac/utils.py +381 -0
- sota-implementations/sac_trainer/train.py +16 -0
- sota-implementations/td3/td3.py +254 -0
- sota-implementations/td3/utils.py +319 -0
- sota-implementations/td3_bc/td3_bc.py +177 -0
- sota-implementations/td3_bc/utils.py +251 -0
- torchrl/__init__.py +144 -0
- torchrl/_extension.py +74 -0
- torchrl/_torchrl.cpython-314-aarch64-linux-gnu.so +0 -0
- torchrl/_utils.py +1431 -0
- torchrl/collectors/__init__.py +48 -0
- torchrl/collectors/_base.py +1058 -0
- torchrl/collectors/_constants.py +88 -0
- torchrl/collectors/_multi_async.py +324 -0
- torchrl/collectors/_multi_base.py +1805 -0
- torchrl/collectors/_multi_sync.py +464 -0
- torchrl/collectors/_runner.py +581 -0
- torchrl/collectors/_single.py +2009 -0
- torchrl/collectors/_single_async.py +259 -0
- torchrl/collectors/collectors.py +62 -0
- torchrl/collectors/distributed/__init__.py +32 -0
- torchrl/collectors/distributed/default_configs.py +133 -0
- torchrl/collectors/distributed/generic.py +1306 -0
- torchrl/collectors/distributed/ray.py +1092 -0
- torchrl/collectors/distributed/rpc.py +1006 -0
- torchrl/collectors/distributed/sync.py +731 -0
- torchrl/collectors/distributed/utils.py +160 -0
- torchrl/collectors/llm/__init__.py +10 -0
- torchrl/collectors/llm/base.py +494 -0
- torchrl/collectors/llm/ray_collector.py +275 -0
- torchrl/collectors/llm/utils.py +36 -0
- torchrl/collectors/llm/weight_update/__init__.py +10 -0
- torchrl/collectors/llm/weight_update/vllm.py +348 -0
- torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
- torchrl/collectors/utils.py +433 -0
- torchrl/collectors/weight_update.py +591 -0
- torchrl/csrc/numpy_utils.h +38 -0
- torchrl/csrc/pybind.cpp +27 -0
- torchrl/csrc/segment_tree.h +458 -0
- torchrl/csrc/torch_utils.h +34 -0
- torchrl/csrc/utils.cpp +48 -0
- torchrl/csrc/utils.h +31 -0
- torchrl/data/__init__.py +187 -0
- torchrl/data/datasets/__init__.py +58 -0
- torchrl/data/datasets/atari_dqn.py +878 -0
- torchrl/data/datasets/common.py +281 -0
- torchrl/data/datasets/d4rl.py +489 -0
- torchrl/data/datasets/d4rl_infos.py +187 -0
- torchrl/data/datasets/gen_dgrl.py +375 -0
- torchrl/data/datasets/minari_data.py +643 -0
- torchrl/data/datasets/openml.py +177 -0
- torchrl/data/datasets/openx.py +798 -0
- torchrl/data/datasets/roboset.py +363 -0
- torchrl/data/datasets/utils.py +11 -0
- torchrl/data/datasets/vd4rl.py +432 -0
- torchrl/data/llm/__init__.py +34 -0
- torchrl/data/llm/dataset.py +491 -0
- torchrl/data/llm/history.py +1378 -0
- torchrl/data/llm/prompt.py +198 -0
- torchrl/data/llm/reward.py +225 -0
- torchrl/data/llm/topk.py +186 -0
- torchrl/data/llm/utils.py +543 -0
- torchrl/data/map/__init__.py +21 -0
- torchrl/data/map/hash.py +185 -0
- torchrl/data/map/query.py +204 -0
- torchrl/data/map/tdstorage.py +363 -0
- torchrl/data/map/tree.py +1434 -0
- torchrl/data/map/utils.py +103 -0
- torchrl/data/postprocs/__init__.py +8 -0
- torchrl/data/postprocs/postprocs.py +391 -0
- torchrl/data/replay_buffers/__init__.py +99 -0
- torchrl/data/replay_buffers/checkpointers.py +622 -0
- torchrl/data/replay_buffers/ray_buffer.py +292 -0
- torchrl/data/replay_buffers/replay_buffers.py +2376 -0
- torchrl/data/replay_buffers/samplers.py +2578 -0
- torchrl/data/replay_buffers/scheduler.py +265 -0
- torchrl/data/replay_buffers/storages.py +2412 -0
- torchrl/data/replay_buffers/utils.py +1042 -0
- torchrl/data/replay_buffers/writers.py +781 -0
- torchrl/data/tensor_specs.py +7101 -0
- torchrl/data/utils.py +334 -0
- torchrl/envs/__init__.py +265 -0
- torchrl/envs/async_envs.py +1105 -0
- torchrl/envs/batched_envs.py +3093 -0
- torchrl/envs/common.py +4241 -0
- torchrl/envs/custom/__init__.py +11 -0
- torchrl/envs/custom/chess.py +617 -0
- torchrl/envs/custom/llm.py +214 -0
- torchrl/envs/custom/pendulum.py +401 -0
- torchrl/envs/custom/san_moves.txt +29274 -0
- torchrl/envs/custom/tictactoeenv.py +288 -0
- torchrl/envs/env_creator.py +263 -0
- torchrl/envs/gym_like.py +752 -0
- torchrl/envs/libs/__init__.py +68 -0
- torchrl/envs/libs/_gym_utils.py +326 -0
- torchrl/envs/libs/brax.py +846 -0
- torchrl/envs/libs/dm_control.py +544 -0
- torchrl/envs/libs/envpool.py +447 -0
- torchrl/envs/libs/gym.py +2239 -0
- torchrl/envs/libs/habitat.py +138 -0
- torchrl/envs/libs/isaac_lab.py +87 -0
- torchrl/envs/libs/isaacgym.py +203 -0
- torchrl/envs/libs/jax_utils.py +166 -0
- torchrl/envs/libs/jumanji.py +963 -0
- torchrl/envs/libs/meltingpot.py +599 -0
- torchrl/envs/libs/openml.py +153 -0
- torchrl/envs/libs/openspiel.py +652 -0
- torchrl/envs/libs/pettingzoo.py +1042 -0
- torchrl/envs/libs/procgen.py +351 -0
- torchrl/envs/libs/robohive.py +429 -0
- torchrl/envs/libs/smacv2.py +645 -0
- torchrl/envs/libs/unity_mlagents.py +891 -0
- torchrl/envs/libs/utils.py +147 -0
- torchrl/envs/libs/vmas.py +813 -0
- torchrl/envs/llm/__init__.py +63 -0
- torchrl/envs/llm/chat.py +730 -0
- torchrl/envs/llm/datasets/README.md +4 -0
- torchrl/envs/llm/datasets/__init__.py +17 -0
- torchrl/envs/llm/datasets/gsm8k.py +353 -0
- torchrl/envs/llm/datasets/ifeval.py +274 -0
- torchrl/envs/llm/envs.py +789 -0
- torchrl/envs/llm/libs/README.md +3 -0
- torchrl/envs/llm/libs/__init__.py +8 -0
- torchrl/envs/llm/libs/mlgym.py +869 -0
- torchrl/envs/llm/reward/__init__.py +10 -0
- torchrl/envs/llm/reward/gsm8k.py +324 -0
- torchrl/envs/llm/reward/ifeval/README.md +13 -0
- torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
- torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
- torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
- torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
- torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
- torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
- torchrl/envs/llm/transforms/__init__.py +55 -0
- torchrl/envs/llm/transforms/browser.py +292 -0
- torchrl/envs/llm/transforms/dataloading.py +859 -0
- torchrl/envs/llm/transforms/format.py +73 -0
- torchrl/envs/llm/transforms/kl.py +1544 -0
- torchrl/envs/llm/transforms/policy_version.py +189 -0
- torchrl/envs/llm/transforms/reason.py +323 -0
- torchrl/envs/llm/transforms/tokenizer.py +321 -0
- torchrl/envs/llm/transforms/tools.py +1955 -0
- torchrl/envs/model_based/__init__.py +9 -0
- torchrl/envs/model_based/common.py +180 -0
- torchrl/envs/model_based/dreamer.py +112 -0
- torchrl/envs/transforms/__init__.py +147 -0
- torchrl/envs/transforms/functional.py +48 -0
- torchrl/envs/transforms/gym_transforms.py +203 -0
- torchrl/envs/transforms/module.py +341 -0
- torchrl/envs/transforms/r3m.py +372 -0
- torchrl/envs/transforms/ray_service.py +663 -0
- torchrl/envs/transforms/rb_transforms.py +214 -0
- torchrl/envs/transforms/transforms.py +11835 -0
- torchrl/envs/transforms/utils.py +94 -0
- torchrl/envs/transforms/vc1.py +307 -0
- torchrl/envs/transforms/vecnorm.py +845 -0
- torchrl/envs/transforms/vip.py +407 -0
- torchrl/envs/utils.py +1718 -0
- torchrl/envs/vec_envs.py +11 -0
- torchrl/modules/__init__.py +206 -0
- torchrl/modules/distributions/__init__.py +73 -0
- torchrl/modules/distributions/continuous.py +830 -0
- torchrl/modules/distributions/discrete.py +908 -0
- torchrl/modules/distributions/truncated_normal.py +187 -0
- torchrl/modules/distributions/utils.py +233 -0
- torchrl/modules/llm/__init__.py +62 -0
- torchrl/modules/llm/backends/__init__.py +65 -0
- torchrl/modules/llm/backends/vllm/__init__.py +94 -0
- torchrl/modules/llm/backends/vllm/_models.py +46 -0
- torchrl/modules/llm/backends/vllm/base.py +72 -0
- torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
- torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
- torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
- torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
- torchrl/modules/llm/policies/__init__.py +28 -0
- torchrl/modules/llm/policies/common.py +1809 -0
- torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
- torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
- torchrl/modules/llm/utils.py +23 -0
- torchrl/modules/mcts/__init__.py +21 -0
- torchrl/modules/mcts/scores.py +579 -0
- torchrl/modules/models/__init__.py +86 -0
- torchrl/modules/models/batchrenorm.py +119 -0
- torchrl/modules/models/decision_transformer.py +179 -0
- torchrl/modules/models/exploration.py +731 -0
- torchrl/modules/models/llm.py +156 -0
- torchrl/modules/models/model_based.py +596 -0
- torchrl/modules/models/models.py +1712 -0
- torchrl/modules/models/multiagent.py +1067 -0
- torchrl/modules/models/recipes/impala.py +185 -0
- torchrl/modules/models/utils.py +162 -0
- torchrl/modules/planners/__init__.py +10 -0
- torchrl/modules/planners/cem.py +228 -0
- torchrl/modules/planners/common.py +73 -0
- torchrl/modules/planners/mppi.py +265 -0
- torchrl/modules/tensordict_module/__init__.py +89 -0
- torchrl/modules/tensordict_module/actors.py +2457 -0
- torchrl/modules/tensordict_module/common.py +529 -0
- torchrl/modules/tensordict_module/exploration.py +814 -0
- torchrl/modules/tensordict_module/probabilistic.py +321 -0
- torchrl/modules/tensordict_module/rnn.py +1639 -0
- torchrl/modules/tensordict_module/sequence.py +132 -0
- torchrl/modules/tensordict_module/world_models.py +34 -0
- torchrl/modules/utils/__init__.py +38 -0
- torchrl/modules/utils/mappings.py +9 -0
- torchrl/modules/utils/utils.py +89 -0
- torchrl/objectives/__init__.py +78 -0
- torchrl/objectives/a2c.py +659 -0
- torchrl/objectives/common.py +753 -0
- torchrl/objectives/cql.py +1346 -0
- torchrl/objectives/crossq.py +710 -0
- torchrl/objectives/ddpg.py +453 -0
- torchrl/objectives/decision_transformer.py +371 -0
- torchrl/objectives/deprecated.py +516 -0
- torchrl/objectives/dqn.py +683 -0
- torchrl/objectives/dreamer.py +488 -0
- torchrl/objectives/functional.py +48 -0
- torchrl/objectives/gail.py +258 -0
- torchrl/objectives/iql.py +996 -0
- torchrl/objectives/llm/__init__.py +30 -0
- torchrl/objectives/llm/grpo.py +846 -0
- torchrl/objectives/llm/sft.py +482 -0
- torchrl/objectives/multiagent/__init__.py +8 -0
- torchrl/objectives/multiagent/qmixer.py +396 -0
- torchrl/objectives/ppo.py +1669 -0
- torchrl/objectives/redq.py +683 -0
- torchrl/objectives/reinforce.py +530 -0
- torchrl/objectives/sac.py +1580 -0
- torchrl/objectives/td3.py +570 -0
- torchrl/objectives/td3_bc.py +625 -0
- torchrl/objectives/utils.py +782 -0
- torchrl/objectives/value/__init__.py +28 -0
- torchrl/objectives/value/advantages.py +1956 -0
- torchrl/objectives/value/functional.py +1459 -0
- torchrl/objectives/value/utils.py +360 -0
- torchrl/record/__init__.py +17 -0
- torchrl/record/loggers/__init__.py +23 -0
- torchrl/record/loggers/common.py +48 -0
- torchrl/record/loggers/csv.py +226 -0
- torchrl/record/loggers/mlflow.py +142 -0
- torchrl/record/loggers/tensorboard.py +139 -0
- torchrl/record/loggers/trackio.py +163 -0
- torchrl/record/loggers/utils.py +78 -0
- torchrl/record/loggers/wandb.py +214 -0
- torchrl/record/recorder.py +554 -0
- torchrl/services/__init__.py +79 -0
- torchrl/services/base.py +109 -0
- torchrl/services/ray_service.py +453 -0
- torchrl/testing/__init__.py +107 -0
- torchrl/testing/assertions.py +179 -0
- torchrl/testing/dist_utils.py +122 -0
- torchrl/testing/env_creators.py +227 -0
- torchrl/testing/env_helper.py +35 -0
- torchrl/testing/gym_helpers.py +156 -0
- torchrl/testing/llm_mocks.py +119 -0
- torchrl/testing/mocking_classes.py +2720 -0
- torchrl/testing/modules.py +295 -0
- torchrl/testing/mp_helpers.py +15 -0
- torchrl/testing/ray_helpers.py +293 -0
- torchrl/testing/utils.py +190 -0
- torchrl/trainers/__init__.py +42 -0
- torchrl/trainers/algorithms/__init__.py +11 -0
- torchrl/trainers/algorithms/configs/__init__.py +705 -0
- torchrl/trainers/algorithms/configs/collectors.py +216 -0
- torchrl/trainers/algorithms/configs/common.py +41 -0
- torchrl/trainers/algorithms/configs/data.py +308 -0
- torchrl/trainers/algorithms/configs/envs.py +104 -0
- torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
- torchrl/trainers/algorithms/configs/logging.py +80 -0
- torchrl/trainers/algorithms/configs/modules.py +570 -0
- torchrl/trainers/algorithms/configs/objectives.py +177 -0
- torchrl/trainers/algorithms/configs/trainers.py +340 -0
- torchrl/trainers/algorithms/configs/transforms.py +955 -0
- torchrl/trainers/algorithms/configs/utils.py +252 -0
- torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
- torchrl/trainers/algorithms/configs/weight_update.py +159 -0
- torchrl/trainers/algorithms/ppo.py +373 -0
- torchrl/trainers/algorithms/sac.py +308 -0
- torchrl/trainers/helpers/__init__.py +40 -0
- torchrl/trainers/helpers/collectors.py +416 -0
- torchrl/trainers/helpers/envs.py +573 -0
- torchrl/trainers/helpers/logger.py +33 -0
- torchrl/trainers/helpers/losses.py +132 -0
- torchrl/trainers/helpers/models.py +658 -0
- torchrl/trainers/helpers/replay_buffer.py +59 -0
- torchrl/trainers/helpers/trainers.py +301 -0
- torchrl/trainers/trainers.py +2052 -0
- torchrl/weight_update/__init__.py +33 -0
- torchrl/weight_update/_distributed.py +749 -0
- torchrl/weight_update/_mp.py +624 -0
- torchrl/weight_update/_noupdate.py +102 -0
- torchrl/weight_update/_ray.py +1032 -0
- torchrl/weight_update/_rpc.py +284 -0
- torchrl/weight_update/_shared.py +891 -0
- torchrl/weight_update/llm/__init__.py +32 -0
- torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
- torchrl/weight_update/llm/vllm_nccl.py +710 -0
- torchrl/weight_update/utils.py +73 -0
- torchrl/weight_update/weight_sync_schemes.py +1244 -0
- torchrl-0.11.0.dist-info/METADATA +1308 -0
- torchrl-0.11.0.dist-info/RECORD +394 -0
- torchrl-0.11.0.dist-info/WHEEL +5 -0
- torchrl-0.11.0.dist-info/entry_points.txt +2 -0
- torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
- torchrl-0.11.0.dist-info/top_level.txt +7 -0
|
@@ -0,0 +1,573 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
#
|
|
3
|
+
# This source code is licensed under the MIT license found in the
|
|
4
|
+
# LICENSE file in the root directory of this source tree.
|
|
5
|
+
|
|
6
|
+
# This makes omegaconf unhappy with typing.Any
|
|
7
|
+
# Therefore we need Optional and Union
|
|
8
|
+
# from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
import importlib.util
|
|
11
|
+
from collections.abc import Callable, Sequence
|
|
12
|
+
from copy import copy
|
|
13
|
+
from dataclasses import dataclass, field as dataclass_field
|
|
14
|
+
from typing import Any, Optional, Union
|
|
15
|
+
|
|
16
|
+
import torch
|
|
17
|
+
from torchrl._utils import logger as torchrl_logger, VERBOSE
|
|
18
|
+
from torchrl.envs import ParallelEnv
|
|
19
|
+
from torchrl.envs.common import EnvBase
|
|
20
|
+
from torchrl.envs.env_creator import env_creator, EnvCreator
|
|
21
|
+
from torchrl.envs.libs.dm_control import DMControlEnv
|
|
22
|
+
from torchrl.envs.libs.gym import GymEnv
|
|
23
|
+
from torchrl.envs.transforms import (
|
|
24
|
+
CatFrames,
|
|
25
|
+
CatTensors,
|
|
26
|
+
CenterCrop,
|
|
27
|
+
Compose,
|
|
28
|
+
DoubleToFloat,
|
|
29
|
+
GrayScale,
|
|
30
|
+
NoopResetEnv,
|
|
31
|
+
ObservationNorm,
|
|
32
|
+
Resize,
|
|
33
|
+
RewardScaling,
|
|
34
|
+
ToTensorImage,
|
|
35
|
+
TransformedEnv,
|
|
36
|
+
VecNorm,
|
|
37
|
+
)
|
|
38
|
+
from torchrl.envs.transforms.transforms import (
|
|
39
|
+
FlattenObservation,
|
|
40
|
+
gSDENoise,
|
|
41
|
+
InitTracker,
|
|
42
|
+
StepCounter,
|
|
43
|
+
)
|
|
44
|
+
from torchrl.record.loggers import Logger
|
|
45
|
+
from torchrl.record.recorder import VideoRecorder
|
|
46
|
+
|
|
47
|
+
LIBS = {
|
|
48
|
+
"gym": GymEnv,
|
|
49
|
+
"dm_control": DMControlEnv,
|
|
50
|
+
}
|
|
51
|
+
|
|
52
|
+
_has_omegaconf = importlib.util.find_spec("omegaconf") is not None
|
|
53
|
+
if _has_omegaconf:
|
|
54
|
+
from omegaconf import DictConfig
|
|
55
|
+
else:
|
|
56
|
+
|
|
57
|
+
class DictConfig: # noqa
|
|
58
|
+
...
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def correct_for_frame_skip(cfg: DictConfig) -> DictConfig:  # noqa: F821
    """Correct the arguments for the input frame_skip, by dividing all the arguments that reflect a count of frames by the frame_skip.

    This is aimed at avoiding unknowingly over-sampling from the environment, i.e. targeting a total number of frames
    of 1M but actually collecting frame_skip * 1M frames.

    Args:
        cfg (DictConfig): DictConfig containing some frame-counting argument, including:
            "max_frames_per_traj", "total_frames", "frames_per_batch", "record_frames", "annealing_frames",
            "init_random_frames", "init_env_steps"

    Returns:
        the input DictConfig, modified in-place.

    """
    frame_skip = cfg.frame_skip
    # With no frame skipping, every frame count is already correct.
    if frame_skip == 1:
        return cfg
    # Each of these config entries counts environment frames and must be
    # rescaled so the effective interaction budget matches the target.
    frame_count_args = (
        "max_frames_per_traj",
        "total_frames",
        "frames_per_batch",
        "record_frames",
        "annealing_frames",
        "init_random_frames",
        "init_env_steps",
        "noops",
    )
    for arg_name in frame_count_args:
        if hasattr(cfg, arg_name):
            setattr(cfg, arg_name, getattr(cfg, arg_name) // frame_skip)
    return cfg
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
def make_env_transforms(
    env,
    cfg,
    video_tag,
    logger,
    env_name,
    stats,
    norm_obs_only,
    env_library,
    action_dim_gsde,
    state_dim_gsde,
    batch_dims=0,
    obs_norm_state_dict=None,
):
    """Creates the typical transforms for an env.

    Args:
        env: base environment to wrap in a ``TransformedEnv``.
        cfg: config object (presumably an omegaconf ``DictConfig``) read for
            ``from_pixels``, ``vecnorm``, ``norm_rewards``, ``reward_scaling``,
            ``reward_loc``, ``center_crop``, ``catframes``, ``image_size``,
            ``grayscale`` and (optionally) ``gSDE``.
        video_tag: if non-empty, a ``VideoRecorder`` transform is appended,
            tagged ``"{video_tag}_{env_name}_video"``.
        logger: logger handed to the ``VideoRecorder`` transform.
        env_name: environment name, used here only to build the video tag.
        stats: loc/scale entries for ``ObservationNorm``; takes precedence
            over ``obs_norm_state_dict`` when both are given.
        norm_obs_only: if True, the reward is not rescaled/normalized.
        env_library: not read in this function; kept for caller compatibility.
        action_dim_gsde: action dim forwarded to ``gSDENoise`` when ``cfg.gSDE``.
        state_dim_gsde: state dim forwarded to ``gSDENoise`` when ``cfg.gSDE``.
        batch_dims: not read in this function; kept for caller compatibility.
        obs_norm_state_dict: fallback stats for ``ObservationNorm`` when
            ``stats`` is ``None``.

    Returns:
        the ``TransformedEnv`` wrapping ``env`` with all transforms appended.
    """
    env = TransformedEnv(env)

    from_pixels = cfg.from_pixels
    vecnorm = cfg.vecnorm
    # Online reward normalization only makes sense when VecNorm is active.
    norm_rewards = vecnorm and cfg.norm_rewards
    _norm_obs_only = norm_obs_only or not norm_rewards
    reward_scaling = cfg.reward_scaling
    reward_loc = cfg.reward_loc

    if len(video_tag):
        center_crop = cfg.center_crop
        if center_crop:
            # center_crop is a sequence in the config; the recorder takes one value.
            center_crop = center_crop[0]
        env.append_transform(
            VideoRecorder(
                logger=logger,
                tag=f"{video_tag}_{env_name}_video",
                center_crop=center_crop,
            ),
        )

    if from_pixels:
        if not cfg.catframes:
            raise RuntimeError(
                "this env builder currently only accepts positive catframes values "
                "when pixels are being used."
            )
        # Pixel pipeline: tensor image -> (optional crop) -> resize ->
        # (optional grayscale) -> flatten leading dims into channel dim ->
        # stack the last `catframes` frames along the channel dim.
        env.append_transform(ToTensorImage())
        if cfg.center_crop:
            env.append_transform(CenterCrop(*cfg.center_crop))
        env.append_transform(Resize(cfg.image_size, cfg.image_size))
        if cfg.grayscale:
            env.append_transform(GrayScale())
        env.append_transform(FlattenObservation(0, -3, allow_positive_dim=True))
        env.append_transform(CatFrames(N=cfg.catframes, in_keys=["pixels"], dim=-3))
        # Stats precedence: explicit `stats` > `obs_norm_state_dict` > empty.
        if stats is None and obs_norm_state_dict is None:
            obs_stats = {}
        elif stats is None:
            obs_stats = copy(obs_norm_state_dict)
        else:
            obs_stats = copy(stats)
        obs_stats["standard_normal"] = True
        obs_norm = ObservationNorm(**obs_stats, in_keys=["pixels"])
        env.append_transform(obs_norm)
    # If rewards are normalized online (or only observations should be
    # normalized), neutralize the static affine reward transform.
    if norm_rewards:
        reward_scaling = 1.0
        reward_loc = 0.0
    if norm_obs_only:
        reward_scaling = 1.0
        reward_loc = 0.0
    if reward_scaling is not None:
        env.append_transform(RewardScaling(reward_loc, reward_scaling))

    if not from_pixels:
        # Collect every observation entry that is not pixel data and does not
        # belong to the state spec, to concatenate into one vector.
        selected_keys = [
            key
            for key in env.observation_spec.keys(True, True)
            if ("pixels" not in key) and (key not in env.state_spec.keys(True, True))
        ]

        # even if there is a single tensor, it'll be renamed in "observation_vector"
        out_key = "observation_vector"
        env.append_transform(CatTensors(in_keys=selected_keys, out_key=out_key))

        if not vecnorm:
            # Static normalization; same stats precedence as the pixel branch.
            if stats is None and obs_norm_state_dict is None:
                _stats = {}
            elif stats is None:
                _stats = copy(obs_norm_state_dict)
            else:
                _stats = copy(stats)
            _stats.update({"standard_normal": True})
            obs_norm = ObservationNorm(
                **_stats,
                in_keys=[out_key],
            )
            env.append_transform(obs_norm)
        else:
            # Online normalization; also covers the reward unless
            # observation-only normalization was requested.
            env.append_transform(
                VecNorm(
                    in_keys=[out_key, "reward"] if not _norm_obs_only else [out_key],
                    decay=0.9999,
                )
            )

        env.append_transform(DoubleToFloat())

        if hasattr(cfg, "catframes") and cfg.catframes:
            env.append_transform(CatFrames(N=cfg.catframes, in_keys=[out_key], dim=-1))

    else:
        env.append_transform(DoubleToFloat())

    if hasattr(cfg, "gSDE") and cfg.gSDE:
        env.append_transform(
            gSDENoise(action_dim=action_dim_gsde, state_dim=state_dim_gsde)
        )

    # Bookkeeping transforms appended last: step counting and reset tracking.
    env.append_transform(StepCounter())
    env.append_transform(InitTracker())

    return env
|
|
211
|
+
|
|
212
|
+
|
|
213
|
+
def get_norm_state_dict(env):
    """Gets the normalization loc and scale from the env state_dict."""
    full_state_dict = env.state_dict()
    # Keep only the normalization statistics: entries whose key ends
    # with "loc" or "scale".
    return {
        name: value
        for name, value in full_state_dict.items()
        if name.endswith(("loc", "scale"))
    }
|
|
222
|
+
|
|
223
|
+
|
|
224
|
+
def transformed_env_constructor(
    cfg: DictConfig,  # noqa: F821
    video_tag: str = "",
    logger: Optional[Logger] = None,  # noqa
    stats: Optional[dict] = None,
    norm_obs_only: bool = False,
    use_env_creator: bool = False,
    custom_env_maker: Optional[Callable] = None,
    custom_env: Optional[EnvBase] = None,
    return_transformed_envs: bool = True,
    action_dim_gsde: Optional[int] = None,
    state_dim_gsde: Optional[int] = None,
    batch_dims: Optional[int] = 0,
    obs_norm_state_dict: Optional[dict] = None,
) -> Union[Callable, EnvCreator]:
    """Returns an environment creator from an argparse.Namespace built with the appropriate parser constructor.

    Args:
        cfg (DictConfig): a DictConfig containing the arguments of the script.
        video_tag (str, optional): video tag to be passed to the Logger object
        logger (Logger, optional): logger associated with the script
        stats (dict, optional): a dictionary containing the :obj:`loc` and :obj:`scale` for the `ObservationNorm` transform
        norm_obs_only (bool, optional): If `True` and `VecNorm` is used, the reward won't be normalized online.
            Default is `False`.
        use_env_creator (bool, optional): whether the `EnvCreator` class should be used. By using `EnvCreator`,
            one can make sure that running statistics will be put in shared memory and accessible for all workers
            when using a `VecNorm` transform. Default is `True`.
        custom_env_maker (callable, optional): if your env maker is not part
            of torchrl env wrappers, a custom callable
            can be passed instead. In this case it will override the
            constructor retrieved from `args`.
        custom_env (EnvBase, optional): if an existing environment needs to be
            transformed_in, it can be passed directly to this helper. `custom_env_maker`
            and `custom_env` are exclusive features.
        return_transformed_envs (bool, optional): if ``True``, a transformed_in environment
            is returned.
        action_dim_gsde (int, Optional): if gSDE is used, this can present the action dim to initialize the noise.
            Make sure this is indicated in environment executed in parallel.
        state_dim_gsde: if gSDE is used, this can present the state dim to initialize the noise.
            Make sure this is indicated in environment executed in parallel.
        batch_dims (int, optional): number of dimensions of a batch of data. If a single env is
            used, it should be 0 (default). If multiple envs are being transformed in parallel,
            it should be set to 1 (or the number of dims of the batch).
        obs_norm_state_dict (dict, optional): the state_dict of the ObservationNorm transform to be loaded into the
            environment
    """

    def make_transformed_env(**kwargs) -> TransformedEnv:
        # Closure over `cfg` and the outer arguments; `kwargs` are forwarded
        # to the env constructor (library wrapper or custom maker).
        env_name = cfg.env_name
        env_task = cfg.env_task
        env_library = LIBS[cfg.env_library]
        frame_skip = cfg.frame_skip
        from_pixels = cfg.from_pixels
        categorical_action_encoding = cfg.categorical_action_encoding

        if custom_env is None and custom_env_maker is None:
            # Default path: build the env from a registered library wrapper.
            if isinstance(cfg.collector_device, str):
                device = cfg.collector_device
            elif isinstance(cfg.collector_device, Sequence):
                # Several collector devices given: the env lives on the first one.
                device = cfg.collector_device[0]
            else:
                raise ValueError(
                    "collector_device must be either a string or a sequence of strings"
                )
            env_kwargs = {
                "env_name": env_name,
                "device": device,
                "frame_skip": frame_skip,
                # Recording a video requires pixels even if training does not.
                "from_pixels": from_pixels or len(video_tag),
                "pixels_only": from_pixels,
            }
            if env_library is GymEnv:
                env_kwargs.update(
                    {"categorical_action_encoding": categorical_action_encoding}
                )
            elif categorical_action_encoding:
                raise NotImplementedError(
                    "categorical_action_encoding=True is currently only compatible with GymEnvs."
                )
            if env_library is DMControlEnv:
                env_kwargs.update({"task_name": env_task})
            env_kwargs.update(kwargs)
            env = env_library(**env_kwargs)
        elif custom_env is None and custom_env_maker is not None:
            env = custom_env_maker(**kwargs)
        elif custom_env_maker is None and custom_env is not None:
            env = custom_env
        else:
            raise RuntimeError("cannot provide both custom_env and custom_env_maker")

        if cfg.noops and custom_env is None:
            # this is a bit hacky: if custom_env is not None, it is probably a ParallelEnv
            # that already has its NoopResetEnv set for the contained envs.
            # There is a risk however that we're just skipping the NoopsReset instantiation
            env = TransformedEnv(env, NoopResetEnv(cfg.noops))
        if not return_transformed_envs:
            return env

        return make_env_transforms(
            env,
            cfg,
            video_tag,
            logger,
            env_name,
            stats,
            norm_obs_only,
            env_library,
            action_dim_gsde,
            state_dim_gsde,
            batch_dims=batch_dims,
            obs_norm_state_dict=obs_norm_state_dict,
        )

    # Wrapping in EnvCreator shares VecNorm running stats across workers.
    if use_env_creator:
        return env_creator(make_transformed_env)
    return make_transformed_env
|
|
340
|
+
|
|
341
|
+
|
|
342
|
+
def parallel_env_constructor(
    cfg: DictConfig, **kwargs  # noqa: F821
) -> Union[ParallelEnv, EnvCreator]:
    """Returns a parallel environment from an argparse.Namespace built with the appropriate parser constructor.

    Args:
        cfg (DictConfig): config containing user-defined arguments
        kwargs: keyword arguments for the `transformed_env_constructor` method.
    """
    use_batch_transform = cfg.batch_transform
    if not use_batch_transform:
        raise NotImplementedError(
            "batch_transform must be set to True for the recorder to be synced "
            "with the collection envs."
        )
    # Both branches below build their env maker lazily through an EnvCreator.
    kwargs.update({"cfg": cfg, "use_env_creator": True})
    if cfg.env_per_collector == 1:
        # Single env per collector: no ParallelEnv wrapper needed, hand the
        # (lazy) constructor back directly.
        return transformed_env_constructor(**kwargs)

    env_maker = transformed_env_constructor(
        return_transformed_envs=not use_batch_transform, **kwargs
    )
    workers_env = ParallelEnv(
        num_workers=cfg.env_per_collector,
        create_env_fn=env_maker,
        create_env_kwargs=None,
        pin_memory=cfg.pin_memory,
    )
    if use_batch_transform:
        # Apply the transforms once, on top of the batched env, rather than
        # once per worker.
        kwargs.update(
            {
                "cfg": cfg,
                "use_env_creator": False,
                "custom_env": workers_env,
                "batch_dims": 1,
            }
        )
        return transformed_env_constructor(**kwargs)()
    return workers_env
|
|
383
|
+
|
|
384
|
+
|
|
385
|
+
@torch.no_grad()
def get_stats_random_rollout(
    cfg: DictConfig,  # noqa: F821
    proof_environment: EnvBase = None,
    key: Optional[str] = None,
):
    """Gathers stats (loc and scale) from an environment using random rollouts.

    Args:
        cfg (DictConfig): a config object with `init_env_steps` field, indicating
            the total number of frames to be collected to compute the stats.
        proof_environment (EnvBase instance, optional): if provided, this env will
            be used to execute the rollouts. If not, it will be created using
            the cfg object.
        key (str, optional): if provided, the stats of this key will be gathered.
            If not, it is expected that only one key exists in `env.observation_spec`.

    """
    proof_env_is_none = proof_environment is None
    if proof_env_is_none:
        proof_environment = transformed_env_constructor(
            cfg=cfg, use_env_creator=False, stats={"loc": 0.0, "scale": 1.0}
        )()

    if VERBOSE:
        torchrl_logger.info("computing state stats")
    if not hasattr(cfg, "init_env_steps"):
        raise AttributeError("init_env_steps missing from arguments.")

    # BUGFIX: infer the observation key *before* collecting data. The rollout
    # loop below calls `_td_stats.get(key)`, which fails when key is None, so
    # the inference must not be deferred until after the loop.
    if key is None:
        keys = list(proof_environment.observation_spec.keys(True, True))
        key = keys.pop()
        if len(keys):
            raise RuntimeError(
                f"More than one key exists in the observation_specs: {[key] + keys} were found, "
                "thus get_stats_random_rollout cannot infer which to compute the stats of."
            )

    n = 0
    val_stats = []
    # Collect at least init_env_steps frames of the selected observation.
    while n < cfg.init_env_steps:
        _td_stats = proof_environment.rollout(max_steps=cfg.init_env_steps)
        n += _td_stats.numel()
        val = _td_stats.get(key).cpu()
        val_stats.append(val)
        del _td_stats, val
    val_stats = torch.cat(val_stats, 0)

    if key == "pixels":
        # Pixel observations: a single scalar loc/scale over all dimensions.
        m = val_stats.mean()
        s = val_stats.std()
    else:
        # Vector observations: per-feature loc/scale over the collected steps.
        m = val_stats.mean(dim=0)
        s = val_stats.std(dim=0)
    # Guard against degenerate (constant) features: scale of 0 would divide
    # by zero downstream, so reset those entries to identity normalization.
    m[s == 0] = 0.0
    s[s == 0] = 1.0

    if VERBOSE:
        torchrl_logger.info(
            f"stats computed for {val_stats.numel()} steps. Got: \n"
            f"loc = {m}, \n"
            f"scale = {s}"
        )
    if not torch.isfinite(m).all():
        raise RuntimeError("non-finite values found in mean")
    if not torch.isfinite(s).all():
        raise RuntimeError("non-finite values found in sd")
    stats = {"loc": m, "scale": s}
    if proof_env_is_none:
        # Only tear down the env we created ourselves.
        proof_environment.close()
        if (
            proof_environment.device != torch.device("cpu")
            and torch.cuda.device_count() > 0
        ):
            torch.cuda.empty_cache()
        del proof_environment
    return stats
|
|
462
|
+
|
|
463
|
+
|
|
464
|
+
def initialize_observation_norm_transforms(
    proof_environment: EnvBase,
    num_iter: int = 1000,
    key: Optional[Union[str, tuple[str, ...]]] = None,
):
    """Calls :obj:`ObservationNorm.init_stats` on all uninitialized :obj:`ObservationNorm` instances of a :obj:`TransformedEnv`.

    If an :obj:`ObservationNorm` already has non-null :obj:`loc` or :obj:`scale`, a call to :obj:`initialize_observation_norm_transforms` will be a no-op.
    Similarly, if the transformed environment does not contain any :obj:`ObservationNorm`, a call to this function will have no effect.
    If no key is provided but the observations of the :obj:`EnvBase` contains more than one key, an exception will
    be raised.

    Args:
        proof_environment (EnvBase instance, optional): if provided, this env will
            be used to execute the rollouts. If not, it will be created using
            the cfg object.
        num_iter (int): Number of iterations used for initializing the :obj:`ObservationNorms`
        key (str, optional): if provided, the stats of this key will be gathered.
            If not, it is expected that only one key exists in `env.observation_spec`.

    """
    env_transform = proof_environment.transform
    # Nothing that could hold an ObservationNorm: silently no-op.
    if not isinstance(env_transform, (Compose, ObservationNorm)):
        return

    if key is None:
        remaining = list(proof_environment.base_env.observation_spec.keys(True, True))
        key = remaining.pop()
        if remaining:
            raise RuntimeError(
                f"More than one key exists in the observation_specs: {[key] + remaining} were found, "
                "thus initialize_observation_norm_transforms cannot infer which to compute the stats of."
            )

    # Treat a lone ObservationNorm the same way as one nested in a Compose.
    candidates = env_transform if isinstance(env_transform, Compose) else [env_transform]
    for candidate in candidates:
        if isinstance(candidate, ObservationNorm) and not candidate.initialized:
            candidate.init_stats(num_iter=num_iter, key=key)
|
|
505
|
+
|
|
506
|
+
|
|
507
|
+
def retrieve_observation_norms_state_dict(proof_environment: TransformedEnv):
    """Traverses the transforms of the environment and retrieves the :obj:`ObservationNorm` state dicts.

    Returns a list of tuples ``(idx, state_dict)`` for each :obj:`ObservationNorm` transform in ``proof_environment``.
    If the environment transforms do not contain any :obj:`ObservationNorm`, returns an empty list.

    Args:
        proof_environment (EnvBase instance, optional): the :obj:`TransformedEnv` to retrieve the :obj:`ObservationNorm`
            state dict from
    """
    env_transform = proof_environment.transform
    if isinstance(env_transform, Compose):
        # Keep the position of each ObservationNorm within the Compose so
        # callers can splice the state dict back at the right index.
        return [
            (idx, t.state_dict())
            for idx, t in enumerate(env_transform)
            if isinstance(t, ObservationNorm)
        ]
    if isinstance(env_transform, ObservationNorm):
        return [(0, env_transform.state_dict())]
    return []
|
|
528
|
+
|
|
529
|
+
|
|
530
|
+
@dataclass
class EnvConfig:
    """Environment config struct.

    Consumed by the env constructor helpers above (``transformed_env_constructor``,
    ``parallel_env_constructor``, ``get_stats_random_rollout``).
    """

    env_library: str = "gym"
    # env_library used for the simulated environment. Default=gym
    env_name: str = "Humanoid-v2"
    # name of the environment to be created. Default=Humanoid-v2
    env_task: str = ""
    # task (if any) for the environment; empty by default.
    # NOTE(review): original comment said "Default=run", contradicting the "" default — confirm.
    from_pixels: bool = False
    # whether the environment output should be state vector(s) (default) or the pixels.
    frame_skip: int = 1
    # frame_skip for the environment. Note that this value does NOT impact the buffer size,
    # maximum steps per trajectory, frames per batch or any other factor in the algorithm,
    # e.g. if the total number of frames that has to be computed is 50e6 and the frame skip is 4
    # the actual number of frames retrieved will be 200e6. Default=1.
    reward_scaling: Any = None  # noqa
    # scale of the reward.
    reward_loc: float = 0.0
    # location of the reward.
    init_env_steps: int = 1000
    # number of random steps used to compute normalizing constants
    # (see get_stats_random_rollout).
    vecnorm: bool = False
    # Normalizes the environment observation and reward outputs with the running statistics obtained across processes.
    norm_rewards: bool = False
    # If True, rewards will be normalized on the fly. This may interfere with SAC update rule and should be used cautiously.
    norm_stats: bool = True
    # normalization based on random data collection toggle.
    # NOTE(review): original comment said "Deactivates the normalization", which conflicts
    # with the True default — likely inherited from an inverted CLI flag; confirm polarity.
    noops: int = 0
    # number of random steps to do after reset. Default is 0
    catframes: int = 0
    # Number of frames to concatenate through time. Default is 0 (do not use CatFrames).
    center_crop: Any = dataclass_field(default_factory=lambda: [])
    # center crop size.
    grayscale: bool = True
    # grayscale transform toggle.
    # NOTE(review): original comment said "Disables grayscale transform", which conflicts
    # with the True default — likely inherited from an inverted CLI flag; confirm polarity.
    max_frames_per_traj: int = 1000
    # Number of steps before a reset of the environment is called (if it has not been flagged as done before).
    batch_transform: bool = False
    # if ``True``, the transforms will be applied to the parallel env, and not to each individual env.
    image_size: int = 84
    # image side length after resizing — TODO(review): undocumented in the original; confirm semantics.
    categorical_action_encoding: bool = False
    # if True and environment has discrete action space, then it is encoded as categorical values rather than one-hot.
|
|
@@ -0,0 +1,33 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
#
|
|
3
|
+
# This source code is licensed under the MIT license found in the
|
|
4
|
+
# LICENSE file in the root directory of this source tree.
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
from dataclasses import dataclass, field
|
|
8
|
+
from typing import Any
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
@dataclass
class LoggerConfig:
    """Logger config data-class."""

    logger: str = "csv"
    # recorder type to be used. One of 'tensorboard', 'wandb' or 'csv'
    record_video: bool = False
    # whether a video of the task should be rendered during logging.
    no_video: bool = True
    # video recording opt-out flag.
    # NOTE(review): original comment duplicated the record_video one; presumably this
    # disables video rendering — confirm polarity against the consumers of this config.
    exp_name: str = ""
    # experiment name. Used for logging directory.
    # A date and uuid will be joined to account for multiple experiments with the same name.
    record_interval: int = 1000
    # number of batch collections in between two collections of validation rollouts. Default=1000.
    record_frames: int = 1000
    # number of steps in validation rollouts. Default=1000.
    recorder_log_keys: Any = field(default_factory=lambda: None)
    # Keys to log in the recorder
    offline_logging: bool = True
    # If True, Wandb will do the logging offline
    project_name: str = ""
    # The name of the project for WandB
|