torchrl 0.11.0__cp314-cp314t-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/benchmark_batched_envs.py +104 -0
- benchmarks/conftest.py +91 -0
- benchmarks/ecosystem/gym_env_throughput.py +321 -0
- benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
- benchmarks/requirements.txt +7 -0
- benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
- benchmarks/test_collectors_benchmark.py +240 -0
- benchmarks/test_compressed_storage_benchmark.py +145 -0
- benchmarks/test_envs_benchmark.py +133 -0
- benchmarks/test_llm.py +101 -0
- benchmarks/test_non_tensor_env_benchmark.py +70 -0
- benchmarks/test_objectives_benchmarks.py +1199 -0
- benchmarks/test_replaybuffer_benchmark.py +254 -0
- sota-check/README.md +35 -0
- sota-implementations/README.md +142 -0
- sota-implementations/a2c/README.md +39 -0
- sota-implementations/a2c/a2c_atari.py +291 -0
- sota-implementations/a2c/a2c_mujoco.py +273 -0
- sota-implementations/a2c/utils_atari.py +240 -0
- sota-implementations/a2c/utils_mujoco.py +160 -0
- sota-implementations/bandits/README.md +7 -0
- sota-implementations/bandits/dqn.py +126 -0
- sota-implementations/cql/cql_offline.py +198 -0
- sota-implementations/cql/cql_online.py +249 -0
- sota-implementations/cql/discrete_cql_offline.py +180 -0
- sota-implementations/cql/discrete_cql_online.py +227 -0
- sota-implementations/cql/utils.py +471 -0
- sota-implementations/crossq/crossq.py +271 -0
- sota-implementations/crossq/utils.py +320 -0
- sota-implementations/ddpg/ddpg.py +231 -0
- sota-implementations/ddpg/utils.py +325 -0
- sota-implementations/decision_transformer/dt.py +163 -0
- sota-implementations/decision_transformer/lamb.py +167 -0
- sota-implementations/decision_transformer/online_dt.py +178 -0
- sota-implementations/decision_transformer/utils.py +562 -0
- sota-implementations/discrete_sac/discrete_sac.py +243 -0
- sota-implementations/discrete_sac/utils.py +324 -0
- sota-implementations/dqn/README.md +30 -0
- sota-implementations/dqn/dqn_atari.py +272 -0
- sota-implementations/dqn/dqn_cartpole.py +236 -0
- sota-implementations/dqn/utils_atari.py +132 -0
- sota-implementations/dqn/utils_cartpole.py +90 -0
- sota-implementations/dreamer/README.md +129 -0
- sota-implementations/dreamer/dreamer.py +586 -0
- sota-implementations/dreamer/dreamer_utils.py +1107 -0
- sota-implementations/expert-iteration/README.md +352 -0
- sota-implementations/expert-iteration/ei_utils.py +770 -0
- sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
- sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
- sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
- sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
- sota-implementations/gail/gail.py +327 -0
- sota-implementations/gail/gail_utils.py +68 -0
- sota-implementations/gail/ppo_utils.py +157 -0
- sota-implementations/grpo/README.md +273 -0
- sota-implementations/grpo/grpo-async.py +437 -0
- sota-implementations/grpo/grpo-sync.py +435 -0
- sota-implementations/grpo/grpo_utils.py +843 -0
- sota-implementations/grpo/requirements_gsm8k.txt +11 -0
- sota-implementations/grpo/requirements_ifeval.txt +16 -0
- sota-implementations/impala/README.md +33 -0
- sota-implementations/impala/impala_multi_node_ray.py +292 -0
- sota-implementations/impala/impala_multi_node_submitit.py +284 -0
- sota-implementations/impala/impala_single_node.py +261 -0
- sota-implementations/impala/utils.py +184 -0
- sota-implementations/iql/discrete_iql.py +230 -0
- sota-implementations/iql/iql_offline.py +164 -0
- sota-implementations/iql/iql_online.py +225 -0
- sota-implementations/iql/utils.py +437 -0
- sota-implementations/multiagent/README.md +74 -0
- sota-implementations/multiagent/iql.py +237 -0
- sota-implementations/multiagent/maddpg_iddpg.py +266 -0
- sota-implementations/multiagent/mappo_ippo.py +267 -0
- sota-implementations/multiagent/qmix_vdn.py +271 -0
- sota-implementations/multiagent/sac.py +337 -0
- sota-implementations/multiagent/utils/__init__.py +4 -0
- sota-implementations/multiagent/utils/logging.py +151 -0
- sota-implementations/multiagent/utils/utils.py +43 -0
- sota-implementations/ppo/README.md +29 -0
- sota-implementations/ppo/ppo_atari.py +305 -0
- sota-implementations/ppo/ppo_mujoco.py +293 -0
- sota-implementations/ppo/utils_atari.py +238 -0
- sota-implementations/ppo/utils_mujoco.py +152 -0
- sota-implementations/ppo_trainer/train.py +21 -0
- sota-implementations/redq/README.md +7 -0
- sota-implementations/redq/redq.py +199 -0
- sota-implementations/redq/utils.py +1060 -0
- sota-implementations/sac/sac-async.py +266 -0
- sota-implementations/sac/sac.py +239 -0
- sota-implementations/sac/utils.py +381 -0
- sota-implementations/sac_trainer/train.py +16 -0
- sota-implementations/td3/td3.py +254 -0
- sota-implementations/td3/utils.py +319 -0
- sota-implementations/td3_bc/td3_bc.py +177 -0
- sota-implementations/td3_bc/utils.py +251 -0
- torchrl/.dylibs/libc++.1.0.dylib +0 -0
- torchrl/__init__.py +144 -0
- torchrl/_extension.py +74 -0
- torchrl/_torchrl.cpython-314t-darwin.so +0 -0
- torchrl/_utils.py +1431 -0
- torchrl/collectors/__init__.py +48 -0
- torchrl/collectors/_base.py +1058 -0
- torchrl/collectors/_constants.py +88 -0
- torchrl/collectors/_multi_async.py +324 -0
- torchrl/collectors/_multi_base.py +1805 -0
- torchrl/collectors/_multi_sync.py +464 -0
- torchrl/collectors/_runner.py +581 -0
- torchrl/collectors/_single.py +2009 -0
- torchrl/collectors/_single_async.py +259 -0
- torchrl/collectors/collectors.py +62 -0
- torchrl/collectors/distributed/__init__.py +32 -0
- torchrl/collectors/distributed/default_configs.py +133 -0
- torchrl/collectors/distributed/generic.py +1306 -0
- torchrl/collectors/distributed/ray.py +1092 -0
- torchrl/collectors/distributed/rpc.py +1006 -0
- torchrl/collectors/distributed/sync.py +731 -0
- torchrl/collectors/distributed/utils.py +160 -0
- torchrl/collectors/llm/__init__.py +10 -0
- torchrl/collectors/llm/base.py +494 -0
- torchrl/collectors/llm/ray_collector.py +275 -0
- torchrl/collectors/llm/utils.py +36 -0
- torchrl/collectors/llm/weight_update/__init__.py +10 -0
- torchrl/collectors/llm/weight_update/vllm.py +348 -0
- torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
- torchrl/collectors/utils.py +433 -0
- torchrl/collectors/weight_update.py +591 -0
- torchrl/csrc/numpy_utils.h +38 -0
- torchrl/csrc/pybind.cpp +27 -0
- torchrl/csrc/segment_tree.h +458 -0
- torchrl/csrc/torch_utils.h +34 -0
- torchrl/csrc/utils.cpp +48 -0
- torchrl/csrc/utils.h +31 -0
- torchrl/data/__init__.py +187 -0
- torchrl/data/datasets/__init__.py +58 -0
- torchrl/data/datasets/atari_dqn.py +878 -0
- torchrl/data/datasets/common.py +281 -0
- torchrl/data/datasets/d4rl.py +489 -0
- torchrl/data/datasets/d4rl_infos.py +187 -0
- torchrl/data/datasets/gen_dgrl.py +375 -0
- torchrl/data/datasets/minari_data.py +643 -0
- torchrl/data/datasets/openml.py +177 -0
- torchrl/data/datasets/openx.py +798 -0
- torchrl/data/datasets/roboset.py +363 -0
- torchrl/data/datasets/utils.py +11 -0
- torchrl/data/datasets/vd4rl.py +432 -0
- torchrl/data/llm/__init__.py +34 -0
- torchrl/data/llm/dataset.py +491 -0
- torchrl/data/llm/history.py +1378 -0
- torchrl/data/llm/prompt.py +198 -0
- torchrl/data/llm/reward.py +225 -0
- torchrl/data/llm/topk.py +186 -0
- torchrl/data/llm/utils.py +543 -0
- torchrl/data/map/__init__.py +21 -0
- torchrl/data/map/hash.py +185 -0
- torchrl/data/map/query.py +204 -0
- torchrl/data/map/tdstorage.py +363 -0
- torchrl/data/map/tree.py +1434 -0
- torchrl/data/map/utils.py +103 -0
- torchrl/data/postprocs/__init__.py +8 -0
- torchrl/data/postprocs/postprocs.py +391 -0
- torchrl/data/replay_buffers/__init__.py +99 -0
- torchrl/data/replay_buffers/checkpointers.py +622 -0
- torchrl/data/replay_buffers/ray_buffer.py +292 -0
- torchrl/data/replay_buffers/replay_buffers.py +2376 -0
- torchrl/data/replay_buffers/samplers.py +2578 -0
- torchrl/data/replay_buffers/scheduler.py +265 -0
- torchrl/data/replay_buffers/storages.py +2412 -0
- torchrl/data/replay_buffers/utils.py +1042 -0
- torchrl/data/replay_buffers/writers.py +781 -0
- torchrl/data/tensor_specs.py +7101 -0
- torchrl/data/utils.py +334 -0
- torchrl/envs/__init__.py +265 -0
- torchrl/envs/async_envs.py +1105 -0
- torchrl/envs/batched_envs.py +3093 -0
- torchrl/envs/common.py +4241 -0
- torchrl/envs/custom/__init__.py +11 -0
- torchrl/envs/custom/chess.py +617 -0
- torchrl/envs/custom/llm.py +214 -0
- torchrl/envs/custom/pendulum.py +401 -0
- torchrl/envs/custom/san_moves.txt +29274 -0
- torchrl/envs/custom/tictactoeenv.py +288 -0
- torchrl/envs/env_creator.py +263 -0
- torchrl/envs/gym_like.py +752 -0
- torchrl/envs/libs/__init__.py +68 -0
- torchrl/envs/libs/_gym_utils.py +326 -0
- torchrl/envs/libs/brax.py +846 -0
- torchrl/envs/libs/dm_control.py +544 -0
- torchrl/envs/libs/envpool.py +447 -0
- torchrl/envs/libs/gym.py +2239 -0
- torchrl/envs/libs/habitat.py +138 -0
- torchrl/envs/libs/isaac_lab.py +87 -0
- torchrl/envs/libs/isaacgym.py +203 -0
- torchrl/envs/libs/jax_utils.py +166 -0
- torchrl/envs/libs/jumanji.py +963 -0
- torchrl/envs/libs/meltingpot.py +599 -0
- torchrl/envs/libs/openml.py +153 -0
- torchrl/envs/libs/openspiel.py +652 -0
- torchrl/envs/libs/pettingzoo.py +1042 -0
- torchrl/envs/libs/procgen.py +351 -0
- torchrl/envs/libs/robohive.py +429 -0
- torchrl/envs/libs/smacv2.py +645 -0
- torchrl/envs/libs/unity_mlagents.py +891 -0
- torchrl/envs/libs/utils.py +147 -0
- torchrl/envs/libs/vmas.py +813 -0
- torchrl/envs/llm/__init__.py +63 -0
- torchrl/envs/llm/chat.py +730 -0
- torchrl/envs/llm/datasets/README.md +4 -0
- torchrl/envs/llm/datasets/__init__.py +17 -0
- torchrl/envs/llm/datasets/gsm8k.py +353 -0
- torchrl/envs/llm/datasets/ifeval.py +274 -0
- torchrl/envs/llm/envs.py +789 -0
- torchrl/envs/llm/libs/README.md +3 -0
- torchrl/envs/llm/libs/__init__.py +8 -0
- torchrl/envs/llm/libs/mlgym.py +869 -0
- torchrl/envs/llm/reward/__init__.py +10 -0
- torchrl/envs/llm/reward/gsm8k.py +324 -0
- torchrl/envs/llm/reward/ifeval/README.md +13 -0
- torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
- torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
- torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
- torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
- torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
- torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
- torchrl/envs/llm/transforms/__init__.py +55 -0
- torchrl/envs/llm/transforms/browser.py +292 -0
- torchrl/envs/llm/transforms/dataloading.py +859 -0
- torchrl/envs/llm/transforms/format.py +73 -0
- torchrl/envs/llm/transforms/kl.py +1544 -0
- torchrl/envs/llm/transforms/policy_version.py +189 -0
- torchrl/envs/llm/transforms/reason.py +323 -0
- torchrl/envs/llm/transforms/tokenizer.py +321 -0
- torchrl/envs/llm/transforms/tools.py +1955 -0
- torchrl/envs/model_based/__init__.py +9 -0
- torchrl/envs/model_based/common.py +180 -0
- torchrl/envs/model_based/dreamer.py +112 -0
- torchrl/envs/transforms/__init__.py +147 -0
- torchrl/envs/transforms/functional.py +48 -0
- torchrl/envs/transforms/gym_transforms.py +203 -0
- torchrl/envs/transforms/module.py +341 -0
- torchrl/envs/transforms/r3m.py +372 -0
- torchrl/envs/transforms/ray_service.py +663 -0
- torchrl/envs/transforms/rb_transforms.py +214 -0
- torchrl/envs/transforms/transforms.py +11835 -0
- torchrl/envs/transforms/utils.py +94 -0
- torchrl/envs/transforms/vc1.py +307 -0
- torchrl/envs/transforms/vecnorm.py +845 -0
- torchrl/envs/transforms/vip.py +407 -0
- torchrl/envs/utils.py +1718 -0
- torchrl/envs/vec_envs.py +11 -0
- torchrl/modules/__init__.py +206 -0
- torchrl/modules/distributions/__init__.py +73 -0
- torchrl/modules/distributions/continuous.py +830 -0
- torchrl/modules/distributions/discrete.py +908 -0
- torchrl/modules/distributions/truncated_normal.py +187 -0
- torchrl/modules/distributions/utils.py +233 -0
- torchrl/modules/llm/__init__.py +62 -0
- torchrl/modules/llm/backends/__init__.py +65 -0
- torchrl/modules/llm/backends/vllm/__init__.py +94 -0
- torchrl/modules/llm/backends/vllm/_models.py +46 -0
- torchrl/modules/llm/backends/vllm/base.py +72 -0
- torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
- torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
- torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
- torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
- torchrl/modules/llm/policies/__init__.py +28 -0
- torchrl/modules/llm/policies/common.py +1809 -0
- torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
- torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
- torchrl/modules/llm/utils.py +23 -0
- torchrl/modules/mcts/__init__.py +21 -0
- torchrl/modules/mcts/scores.py +579 -0
- torchrl/modules/models/__init__.py +86 -0
- torchrl/modules/models/batchrenorm.py +119 -0
- torchrl/modules/models/decision_transformer.py +179 -0
- torchrl/modules/models/exploration.py +731 -0
- torchrl/modules/models/llm.py +156 -0
- torchrl/modules/models/model_based.py +596 -0
- torchrl/modules/models/models.py +1712 -0
- torchrl/modules/models/multiagent.py +1067 -0
- torchrl/modules/models/recipes/impala.py +185 -0
- torchrl/modules/models/utils.py +162 -0
- torchrl/modules/planners/__init__.py +10 -0
- torchrl/modules/planners/cem.py +228 -0
- torchrl/modules/planners/common.py +73 -0
- torchrl/modules/planners/mppi.py +265 -0
- torchrl/modules/tensordict_module/__init__.py +89 -0
- torchrl/modules/tensordict_module/actors.py +2457 -0
- torchrl/modules/tensordict_module/common.py +529 -0
- torchrl/modules/tensordict_module/exploration.py +814 -0
- torchrl/modules/tensordict_module/probabilistic.py +321 -0
- torchrl/modules/tensordict_module/rnn.py +1639 -0
- torchrl/modules/tensordict_module/sequence.py +132 -0
- torchrl/modules/tensordict_module/world_models.py +34 -0
- torchrl/modules/utils/__init__.py +38 -0
- torchrl/modules/utils/mappings.py +9 -0
- torchrl/modules/utils/utils.py +89 -0
- torchrl/objectives/__init__.py +78 -0
- torchrl/objectives/a2c.py +659 -0
- torchrl/objectives/common.py +753 -0
- torchrl/objectives/cql.py +1346 -0
- torchrl/objectives/crossq.py +710 -0
- torchrl/objectives/ddpg.py +453 -0
- torchrl/objectives/decision_transformer.py +371 -0
- torchrl/objectives/deprecated.py +516 -0
- torchrl/objectives/dqn.py +683 -0
- torchrl/objectives/dreamer.py +488 -0
- torchrl/objectives/functional.py +48 -0
- torchrl/objectives/gail.py +258 -0
- torchrl/objectives/iql.py +996 -0
- torchrl/objectives/llm/__init__.py +30 -0
- torchrl/objectives/llm/grpo.py +846 -0
- torchrl/objectives/llm/sft.py +482 -0
- torchrl/objectives/multiagent/__init__.py +8 -0
- torchrl/objectives/multiagent/qmixer.py +396 -0
- torchrl/objectives/ppo.py +1669 -0
- torchrl/objectives/redq.py +683 -0
- torchrl/objectives/reinforce.py +530 -0
- torchrl/objectives/sac.py +1580 -0
- torchrl/objectives/td3.py +570 -0
- torchrl/objectives/td3_bc.py +625 -0
- torchrl/objectives/utils.py +782 -0
- torchrl/objectives/value/__init__.py +28 -0
- torchrl/objectives/value/advantages.py +1956 -0
- torchrl/objectives/value/functional.py +1459 -0
- torchrl/objectives/value/utils.py +360 -0
- torchrl/record/__init__.py +17 -0
- torchrl/record/loggers/__init__.py +23 -0
- torchrl/record/loggers/common.py +48 -0
- torchrl/record/loggers/csv.py +226 -0
- torchrl/record/loggers/mlflow.py +142 -0
- torchrl/record/loggers/tensorboard.py +139 -0
- torchrl/record/loggers/trackio.py +163 -0
- torchrl/record/loggers/utils.py +78 -0
- torchrl/record/loggers/wandb.py +214 -0
- torchrl/record/recorder.py +554 -0
- torchrl/services/__init__.py +79 -0
- torchrl/services/base.py +109 -0
- torchrl/services/ray_service.py +453 -0
- torchrl/testing/__init__.py +107 -0
- torchrl/testing/assertions.py +179 -0
- torchrl/testing/dist_utils.py +122 -0
- torchrl/testing/env_creators.py +227 -0
- torchrl/testing/env_helper.py +35 -0
- torchrl/testing/gym_helpers.py +156 -0
- torchrl/testing/llm_mocks.py +119 -0
- torchrl/testing/mocking_classes.py +2720 -0
- torchrl/testing/modules.py +295 -0
- torchrl/testing/mp_helpers.py +15 -0
- torchrl/testing/ray_helpers.py +293 -0
- torchrl/testing/utils.py +190 -0
- torchrl/trainers/__init__.py +42 -0
- torchrl/trainers/algorithms/__init__.py +11 -0
- torchrl/trainers/algorithms/configs/__init__.py +705 -0
- torchrl/trainers/algorithms/configs/collectors.py +216 -0
- torchrl/trainers/algorithms/configs/common.py +41 -0
- torchrl/trainers/algorithms/configs/data.py +308 -0
- torchrl/trainers/algorithms/configs/envs.py +104 -0
- torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
- torchrl/trainers/algorithms/configs/logging.py +80 -0
- torchrl/trainers/algorithms/configs/modules.py +570 -0
- torchrl/trainers/algorithms/configs/objectives.py +177 -0
- torchrl/trainers/algorithms/configs/trainers.py +340 -0
- torchrl/trainers/algorithms/configs/transforms.py +955 -0
- torchrl/trainers/algorithms/configs/utils.py +252 -0
- torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
- torchrl/trainers/algorithms/configs/weight_update.py +159 -0
- torchrl/trainers/algorithms/ppo.py +373 -0
- torchrl/trainers/algorithms/sac.py +308 -0
- torchrl/trainers/helpers/__init__.py +40 -0
- torchrl/trainers/helpers/collectors.py +416 -0
- torchrl/trainers/helpers/envs.py +573 -0
- torchrl/trainers/helpers/logger.py +33 -0
- torchrl/trainers/helpers/losses.py +132 -0
- torchrl/trainers/helpers/models.py +658 -0
- torchrl/trainers/helpers/replay_buffer.py +59 -0
- torchrl/trainers/helpers/trainers.py +301 -0
- torchrl/trainers/trainers.py +2052 -0
- torchrl/weight_update/__init__.py +33 -0
- torchrl/weight_update/_distributed.py +749 -0
- torchrl/weight_update/_mp.py +624 -0
- torchrl/weight_update/_noupdate.py +102 -0
- torchrl/weight_update/_ray.py +1032 -0
- torchrl/weight_update/_rpc.py +284 -0
- torchrl/weight_update/_shared.py +891 -0
- torchrl/weight_update/llm/__init__.py +32 -0
- torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
- torchrl/weight_update/llm/vllm_nccl.py +710 -0
- torchrl/weight_update/utils.py +73 -0
- torchrl/weight_update/weight_sync_schemes.py +1244 -0
- torchrl-0.11.0.dist-info/METADATA +1308 -0
- torchrl-0.11.0.dist-info/RECORD +395 -0
- torchrl-0.11.0.dist-info/WHEEL +5 -0
- torchrl-0.11.0.dist-info/entry_points.txt +2 -0
- torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
- torchrl-0.11.0.dist-info/top_level.txt +7 -0
|
@@ -0,0 +1,288 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
#
|
|
3
|
+
# This source code is licensed under the MIT license found in the
|
|
4
|
+
# LICENSE file in the root directory of this source tree.
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
import torch
|
|
8
|
+
from tensordict import TensorDict, TensorDictBase
|
|
9
|
+
from torchrl.data.tensor_specs import Categorical, Composite, Unbounded
|
|
10
|
+
from torchrl.envs.common import EnvBase
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class TicTacToeEnv(EnvBase):
|
|
14
|
+
"""A Tic-Tac-Toe implementation.
|
|
15
|
+
|
|
16
|
+
Keyword Args:
|
|
17
|
+
single_player (bool, optional): whether one or two players have to be
|
|
18
|
+
accounted for. ``single_player=True`` means that ``"player1"`` is
|
|
19
|
+
playing randomly. If ``False`` (default), at each turn,
|
|
20
|
+
one of the two players has to play.
|
|
21
|
+
device (torch.device, optional): the device where to put the tensors.
|
|
22
|
+
Defaults to ``None`` (default device).
|
|
23
|
+
|
|
24
|
+
The environment is stateless. To run it across multiple batches, call
|
|
25
|
+
|
|
26
|
+
>>> env.reset(TensorDict(batch_size=desired_batch_size))
|
|
27
|
+
|
|
28
|
+
If the ``"mask"`` entry is present, ``rand_action`` takes it into account to
|
|
29
|
+
generate the next action. Any policy executed on this env should take this
|
|
30
|
+
mask into account, as well as the turn of the player (stored in the ``"turn"``
|
|
31
|
+
output entry).
|
|
32
|
+
|
|
33
|
+
Specs:
|
|
34
|
+
>>> print(env.specs)
|
|
35
|
+
Composite(
|
|
36
|
+
output_spec: Composite(
|
|
37
|
+
full_observation_spec: Composite(
|
|
38
|
+
board: Categorical(
|
|
39
|
+
shape=torch.Size([3, 3]),
|
|
40
|
+
space=DiscreteBox(n=2),
|
|
41
|
+
dtype=torch.int32,
|
|
42
|
+
domain=discrete),
|
|
43
|
+
turn: Categorical(
|
|
44
|
+
shape=torch.Size([1]),
|
|
45
|
+
space=DiscreteBox(n=2),
|
|
46
|
+
dtype=torch.int32,
|
|
47
|
+
domain=discrete),
|
|
48
|
+
mask: Categorical(
|
|
49
|
+
shape=torch.Size([9]),
|
|
50
|
+
space=DiscreteBox(n=2),
|
|
51
|
+
dtype=torch.bool,
|
|
52
|
+
domain=discrete),
|
|
53
|
+
shape=torch.Size([])),
|
|
54
|
+
full_reward_spec: Composite(
|
|
55
|
+
player0: Composite(
|
|
56
|
+
reward: UnboundedContinuous(
|
|
57
|
+
shape=torch.Size([1]),
|
|
58
|
+
space=ContinuousBox(
|
|
59
|
+
low=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True),
|
|
60
|
+
high=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True)),
|
|
61
|
+
dtype=torch.float32,
|
|
62
|
+
domain=continuous),
|
|
63
|
+
shape=torch.Size([])),
|
|
64
|
+
player1: Composite(
|
|
65
|
+
reward: UnboundedContinuous(
|
|
66
|
+
shape=torch.Size([1]),
|
|
67
|
+
space=ContinuousBox(
|
|
68
|
+
low=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True),
|
|
69
|
+
high=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True)),
|
|
70
|
+
dtype=torch.float32,
|
|
71
|
+
domain=continuous),
|
|
72
|
+
shape=torch.Size([])),
|
|
73
|
+
shape=torch.Size([])),
|
|
74
|
+
full_done_spec: Composite(
|
|
75
|
+
done: Categorical(
|
|
76
|
+
shape=torch.Size([1]),
|
|
77
|
+
space=DiscreteBox(n=2),
|
|
78
|
+
dtype=torch.bool,
|
|
79
|
+
domain=discrete),
|
|
80
|
+
terminated: Categorical(
|
|
81
|
+
shape=torch.Size([1]),
|
|
82
|
+
space=DiscreteBox(n=2),
|
|
83
|
+
dtype=torch.bool,
|
|
84
|
+
domain=discrete),
|
|
85
|
+
truncated: Categorical(
|
|
86
|
+
shape=torch.Size([1]),
|
|
87
|
+
space=DiscreteBox(n=2),
|
|
88
|
+
dtype=torch.bool,
|
|
89
|
+
domain=discrete),
|
|
90
|
+
shape=torch.Size([])),
|
|
91
|
+
shape=torch.Size([])),
|
|
92
|
+
input_spec: Composite(
|
|
93
|
+
full_state_spec: Composite(
|
|
94
|
+
board: Categorical(
|
|
95
|
+
shape=torch.Size([3, 3]),
|
|
96
|
+
space=DiscreteBox(n=2),
|
|
97
|
+
dtype=torch.int32,
|
|
98
|
+
domain=discrete),
|
|
99
|
+
turn: Categorical(
|
|
100
|
+
shape=torch.Size([1]),
|
|
101
|
+
space=DiscreteBox(n=2),
|
|
102
|
+
dtype=torch.int32,
|
|
103
|
+
domain=discrete),
|
|
104
|
+
mask: Categorical(
|
|
105
|
+
shape=torch.Size([9]),
|
|
106
|
+
space=DiscreteBox(n=2),
|
|
107
|
+
dtype=torch.bool,
|
|
108
|
+
domain=discrete), shape=torch.Size([])),
|
|
109
|
+
full_action_spec: Composite(
|
|
110
|
+
action: Categorical(
|
|
111
|
+
shape=torch.Size([1]),
|
|
112
|
+
space=DiscreteBox(n=9),
|
|
113
|
+
dtype=torch.int64,
|
|
114
|
+
domain=discrete),
|
|
115
|
+
shape=torch.Size([])),
|
|
116
|
+
shape=torch.Size([])),
|
|
117
|
+
shape=torch.Size([]))
|
|
118
|
+
|
|
119
|
+
To run a dummy rollout, execute the following command:
|
|
120
|
+
|
|
121
|
+
Examples:
|
|
122
|
+
>>> env = TicTacToeEnv()
|
|
123
|
+
>>> env.rollout(10)
|
|
124
|
+
TensorDict(
|
|
125
|
+
fields={
|
|
126
|
+
action: Tensor(shape=torch.Size([9, 1]), device=cpu, dtype=torch.int64, is_shared=False),
|
|
127
|
+
board: Tensor(shape=torch.Size([9, 3, 3]), device=cpu, dtype=torch.int32, is_shared=False),
|
|
128
|
+
done: Tensor(shape=torch.Size([9, 1]), device=cpu, dtype=torch.bool, is_shared=False),
|
|
129
|
+
mask: Tensor(shape=torch.Size([9, 9]), device=cpu, dtype=torch.bool, is_shared=False),
|
|
130
|
+
next: TensorDict(
|
|
131
|
+
fields={
|
|
132
|
+
board: Tensor(shape=torch.Size([9, 3, 3]), device=cpu, dtype=torch.int32, is_shared=False),
|
|
133
|
+
done: Tensor(shape=torch.Size([9, 1]), device=cpu, dtype=torch.bool, is_shared=False),
|
|
134
|
+
mask: Tensor(shape=torch.Size([9, 9]), device=cpu, dtype=torch.bool, is_shared=False),
|
|
135
|
+
player0: TensorDict(
|
|
136
|
+
fields={
|
|
137
|
+
reward: Tensor(shape=torch.Size([9, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
|
|
138
|
+
batch_size=torch.Size([9]),
|
|
139
|
+
device=None,
|
|
140
|
+
is_shared=False),
|
|
141
|
+
player1: TensorDict(
|
|
142
|
+
fields={
|
|
143
|
+
reward: Tensor(shape=torch.Size([9, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
|
|
144
|
+
batch_size=torch.Size([9]),
|
|
145
|
+
device=None,
|
|
146
|
+
is_shared=False),
|
|
147
|
+
terminated: Tensor(shape=torch.Size([9, 1]), device=cpu, dtype=torch.bool, is_shared=False),
|
|
148
|
+
truncated: Tensor(shape=torch.Size([9, 1]), device=cpu, dtype=torch.bool, is_shared=False),
|
|
149
|
+
turn: Tensor(shape=torch.Size([9, 1]), device=cpu, dtype=torch.int32, is_shared=False)},
|
|
150
|
+
batch_size=torch.Size([9]),
|
|
151
|
+
device=None,
|
|
152
|
+
is_shared=False),
|
|
153
|
+
terminated: Tensor(shape=torch.Size([9, 1]), device=cpu, dtype=torch.bool, is_shared=False),
|
|
154
|
+
truncated: Tensor(shape=torch.Size([9, 1]), device=cpu, dtype=torch.bool, is_shared=False),
|
|
155
|
+
turn: Tensor(shape=torch.Size([9, 1]), device=cpu, dtype=torch.int32, is_shared=False)},
|
|
156
|
+
batch_size=torch.Size([9]),
|
|
157
|
+
device=None,
|
|
158
|
+
is_shared=False)
|
|
159
|
+
|
|
160
|
+
"""
|
|
161
|
+
|
|
162
|
+
# batch_locked is set to False since various batch sizes can be provided to the env
|
|
163
|
+
batch_locked: bool = False
|
|
164
|
+
|
|
165
|
+
def __init__(self, *, single_player: bool = False, device=None):
    """Builds the Tic-Tac-Toe environment and registers its specs.

    Keyword Args:
        single_player (bool, optional): if ``True``, ``"player1"`` answers
            each move with a random one inside ``_step``. Defaults to ``False``.
        device (torch.device, optional): device on which all spec tensors live.
            Defaults to ``None`` (default device).
    """
    super().__init__(device=device)
    self.single_player = single_player
    # NOTE: the annotations below were corrected to match the objects actually
    # constructed (the action spec is a Categorical, the reward/done specs are
    # Composite containers — the previous annotations said Unbounded/Categorical).
    self.action_spec: Categorical = Categorical(
        n=9,
        shape=(),
        device=device,
    )

    # Observation: the 3x3 board, whose turn it is, and a flat mask of
    # playable cells.
    self.full_observation_spec: Composite = Composite(
        board=Unbounded(shape=(3, 3), dtype=torch.int, device=device),
        turn=Categorical(
            2,
            shape=(1,),
            dtype=torch.int,
            device=device,
        ),
        mask=Categorical(
            2,
            shape=(9,),
            dtype=torch.bool,
            device=device,
        ),
        device=device,
    )
    # The env is stateless from the caller's perspective: the full state is
    # fed back in at each step, so state and observation specs coincide.
    self.state_spec: Composite = self.observation_spec.clone()

    self.reward_spec: Composite = Composite(
        {
            ("player0", "reward"): Unbounded(shape=(1,), device=device),
            ("player1", "reward"): Unbounded(shape=(1,), device=device),
        },
        device=device,
    )

    self.full_done_spec: Composite = Composite(
        done=Categorical(2, shape=(1,), dtype=torch.bool, device=device),
        device=device,
    )
    self.full_done_spec["terminated"] = self.full_done_spec["done"].clone()
    self.full_done_spec["truncated"] = self.full_done_spec["done"].clone()
|
|
206
|
+
|
|
207
|
+
def _reset(self, reset_td: TensorDict) -> TensorDict:
    """Returns a fresh game state: empty board, all cells playable, done flags cleared.

    The batch shape of the new state follows the (optional) input tensordict,
    which is what makes this env usable with arbitrary batch sizes.
    """
    batch_shape = () if reset_td is None else reset_td.shape
    state = self.state_spec.zero(batch_shape)
    # ``zero`` fills the board with 0; empty cells are encoded as -1,
    # so shift every cell down by one.
    state["board"] -= 1
    # Every cell can be played at the start of a game.
    state["mask"].fill_(True)
    state.update(self.full_done_spec.zero(batch_shape))
    return state
|
|
213
|
+
|
|
214
|
+
def _step(self, state: TensorDict) -> TensorDict:
    """Plays the current player's move and returns the resulting state.

    Cell encoding (as read by the *next* player): ``-1`` empty, ``1`` own
    mark, ``0`` opponent mark. Rewards are 1.0 for the winning player on the
    winning move, 0.0 otherwise. In single-player mode, player1 immediately
    answers with a random move wherever the game is still running.
    """
    board = state["board"].clone()
    turn = state["turn"].clone()
    action = state["action"]
    # Write the current player's mark (1) at the chosen flat cell index.
    # ``flatten`` on this contiguous clone returns a view, so the in-place
    # ``scatter_`` also mutates ``board`` itself.
    board.flatten(-2, -1).scatter_(index=action.unsqueeze(-1), dim=-1, value=1)
    wins = self.win(board, action)

    # Empty cells are encoded as -1; the game ends on a win or a full board.
    mask = board.flatten(-2, -1) == -1
    done = wins | ~mask.any(-1, keepdim=True)
    terminated = done.clone()

    # Only the player who just moved can have won.
    reward_0 = wins & (turn == 0)
    reward_1 = wins & (turn == 1)

    state = TensorDict(
        {
            "done": done,
            "terminated": terminated,
            ("player0", "reward"): reward_0.float(),
            ("player1", "reward"): reward_1.float(),
            # Swap perspective for the next player: own marks (1) become
            # opponent marks (0) and vice versa; empty cells stay -1.
            "board": torch.where(board == -1, board, 1 - board),
            "turn": 1 - turn,
            "mask": mask,
        },
        batch_size=state.batch_size,
    )
    if self.single_player:
        # Select the batch entries where the game continues and it is now
        # player1's (the random player's) turn.
        select = (~done & (turn == 0)).squeeze(-1)
        if select.all():
            state_select = state
        elif select.any():
            state_select = state[select]
        else:
            # Nothing left for the random player to do.
            return state
        # Recurse once: sample a random (masked) move and play it.
        state_select = self._step(self.rand_action(state_select))
        if select.all():
            return state_select
        # Keep the pre-opponent state for entries that were already done.
        return torch.where(done, state, state_select)
    return state
|
|
253
|
+
|
|
254
|
+
def _set_seed(self, seed: int | None) -> None:
|
|
255
|
+
...
|
|
256
|
+
|
|
257
|
+
@staticmethod
|
|
258
|
+
def win(board: torch.Tensor, action: torch.Tensor):
|
|
259
|
+
row = action // 3 # type: ignore
|
|
260
|
+
col = action % 3 # type: ignore
|
|
261
|
+
if board[..., row, :].sum() == 3:
|
|
262
|
+
return True
|
|
263
|
+
if board[..., col].sum() == 3:
|
|
264
|
+
return True
|
|
265
|
+
if board.diagonal(0, -2, -1).sum() == 3:
|
|
266
|
+
return True
|
|
267
|
+
if board.flip(-1).diagonal(0, -2, -1).sum() == 3:
|
|
268
|
+
return True
|
|
269
|
+
return False
|
|
270
|
+
|
|
271
|
+
@staticmethod
|
|
272
|
+
def full(board: torch.Tensor) -> bool:
|
|
273
|
+
return torch.sym_int(board.abs().sum()) == 9
|
|
274
|
+
|
|
275
|
+
@staticmethod
|
|
276
|
+
def get_action_mask():
|
|
277
|
+
pass
|
|
278
|
+
|
|
279
|
+
def rand_action(self, tensordict: TensorDictBase | None = None):
    """Samples a random action, honoring the ``"mask"`` entry when present.

    Args:
        tensordict (TensorDictBase, optional): the current state. When it
            carries a ``"mask"`` entry, masked-out (already played) cells are
            excluded from the draw. Defaults to ``None``, in which case an
            empty unbatched tensordict is created. (The previous
            implementation advertised ``None`` as a valid default but
            crashed on ``tensordict.get`` when it was actually passed.)

    Returns:
        the input tensordict with a sampled action set at ``self.action_key``.
    """
    if tensordict is None:
        # Honor the advertised default instead of raising AttributeError.
        tensordict = TensorDict()
    mask = tensordict.get("mask")
    action_spec = self.action_spec
    if tensordict.ndim:
        # Broadcast the spec so one action is drawn per batch element.
        action_spec = action_spec.expand(tensordict.shape)
    else:
        # Clone before masking so the env-level spec is left untouched.
        action_spec = action_spec.clone()
    action_spec.update_mask(mask)
    tensordict.set(self.action_key, action_spec.rand())
    return tensordict
|
|
@@ -0,0 +1,263 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
#
|
|
3
|
+
# This source code is licensed under the MIT license found in the
|
|
4
|
+
# LICENSE file in the root directory of this source tree.
|
|
5
|
+
|
|
6
|
+
from __future__ import annotations
|
|
7
|
+
|
|
8
|
+
from collections import OrderedDict
|
|
9
|
+
from collections.abc import Callable
|
|
10
|
+
from multiprocessing.sharedctypes import Synchronized
|
|
11
|
+
from multiprocessing.synchronize import Lock, RLock
|
|
12
|
+
|
|
13
|
+
import torch
|
|
14
|
+
from tensordict import TensorDictBase
|
|
15
|
+
from torchrl._utils import logger as torchrl_logger
|
|
16
|
+
from torchrl.data.utils import CloudpickleWrapper
|
|
17
|
+
from torchrl.envs.common import EnvBase, EnvMetaData
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class EnvCreator:
    """Environment creator class.

    EnvCreator is a generic environment creator class that can substitute
    lambda functions when creating environments in multiprocessing contexts.
    If the environment created on a subprocess must share information with the
    main process (e.g. for the VecNorm transform), EnvCreator will pass the
    pointers to the tensordicts in shared memory to each process such that
    all of them are synchronised.

    Args:
        create_env_fn (callable): a callable that returns an EnvBase
            instance.
        create_env_kwargs (dict, optional): the kwargs of the env creator.
        share_memory (bool, optional): if False, the resulting tensordict
            from the environment won't be placed in shared memory.
        **kwargs: additional keyword arguments to be passed to the environment
            during construction.

    Examples:
        >>> # We create the same environment on 2 processes using VecNorm
        >>> # and check that the discounted count of observations matches on
        >>> # both workers, even if one has not executed any step
        >>> import time
        >>> from torchrl.envs.libs.gym import GymEnv
        >>> from torchrl.envs.transforms import VecNorm, TransformedEnv
        >>> from torchrl.envs import EnvCreator
        >>> from torch import multiprocessing as mp
        >>> env_fn = lambda: TransformedEnv(GymEnv("Pendulum-v1"), VecNorm())
        >>> env_creator = EnvCreator(env_fn)
        >>>
        >>> def test_env1(env_creator):
        ...     env = env_creator()
        ...     tensordict = env.reset()
        ...     for _ in range(10):
        ...         env.rand_step(tensordict)
        ...         if tensordict.get(("next", "done")):
        ...             tensordict = env.reset(tensordict)
        ...     print("env 1: ", env.transform._td.get(("next", "observation_count")))
        >>>
        >>> def test_env2(env_creator):
        ...     env = env_creator()
        ...     time.sleep(5)
        ...     print("env 2: ", env.transform._td.get(("next", "observation_count")))
        >>>
        >>> if __name__ == "__main__":
        ...     ps = []
        ...     p1 = mp.Process(target=test_env1, args=(env_creator,))
        ...     p1.start()
        ...     ps.append(p1)
        ...     p2 = mp.Process(target=test_env2, args=(env_creator,))
        ...     p2.start()
        ...     ps.append(p2)
        ...     for p in ps:
        ...         p.join()
        env 1: tensor([11.9934])
        env 2: tensor([11.9934])
    """

    def __init__(
        self,
        create_env_fn: Callable[..., EnvBase],
        create_env_kwargs: dict | None = None,
        share_memory: bool = True,
        **kwargs,
    ) -> None:
        # Wrap plain callables (lambdas, closures) so that they can be
        # pickled across process boundaries; EnvCreator and
        # CloudpickleWrapper instances already are picklable.
        if not isinstance(create_env_fn, (EnvCreator, CloudpickleWrapper)):
            self.create_env_fn = CloudpickleWrapper(create_env_fn)
        else:
            self.create_env_fn = create_env_fn

        # Extra **kwargs are merged with create_env_kwargs; the explicit
        # create_env_kwargs dict takes precedence on key collisions.
        self.create_env_kwargs = kwargs
        if isinstance(create_env_kwargs, dict):
            self.create_env_kwargs.update(create_env_kwargs)
        self.initialized = False
        self._meta_data = None
        self._share_memory = share_memory
        self.init_()

    def make_variant(self, **kwargs) -> EnvCreator:
        """Creates a variant of the EnvCreator, pointing to the same underlying metadata but with different keyword arguments during construction.

        This can be useful with transforms that share a state, like :class:`~torchrl.envs.TrajCounter`.

        Examples:
            >>> from torchrl.envs import GymEnv
            >>> env_creator_pendulum = EnvCreator(GymEnv, env_name="Pendulum-v1")
            >>> env_creator_cartpole = env_creator_pendulum.make_variant(env_name="CartPole-v1")

        """
        # Copy self (shallow: metadata and state dict stay shared on purpose)
        out = type(self).__new__(type(self))
        out.__dict__.update(self.__dict__)
        # Replace (not update in place) the kwargs dict: after the shallow
        # copy above it would be aliased with self.create_env_kwargs, and
        # updating it in place would silently mutate the original creator.
        out.create_env_kwargs = {**self.create_env_kwargs, **kwargs}
        return out

    def share_memory(self, state_dict: OrderedDict) -> None:
        """Recursively moves tensordicts of a state-dict to shared memory and prunes entries that need no sharing.

        TensorDictBase values are shared in place then dropped from the
        state-dict (the shared storage is what matters, not the entry);
        plain tensors are dropped as well; nested OrderedDicts are recursed
        into.
        """
        # Iterate over a snapshot since we delete keys while traversing.
        for key, item in list(state_dict.items()):
            if isinstance(item, (TensorDictBase,)):
                if not item.is_shared():
                    item.share_memory_()
                else:
                    torchrl_logger.info(
                        f"{self.env_type}: {item} is already shared"
                    )  # , deleting key'val)
                del state_dict[key]
            elif isinstance(item, OrderedDict):
                self.share_memory(item)
            elif isinstance(item, torch.Tensor):
                del state_dict[key]

    @property
    def meta_data(self) -> EnvMetaData:
        # Populated by init_(); accessing it beforehand is a usage error.
        if self._meta_data is None:
            raise RuntimeError(
                "meta_data is None in EnvCreator. " "Make sure init_() has been called."
            )
        return self._meta_data

    @meta_data.setter
    def meta_data(self, value: EnvMetaData):
        self._meta_data = value

    @staticmethod
    def _is_mp_value(val):
        """Returns ``True`` for multiprocessing primitives (Value/Lock/RLock) that must be forwarded to sub-process envs."""
        if isinstance(val, (Synchronized,)) and hasattr(val, "_obj"):
            return True
        # Also check for lock types which need to be shared across processes
        if isinstance(val, (Lock, RLock)):
            return True
        return False

    @classmethod
    def _find_mp_values(cls, env_or_transform, values, prefix=()):
        """Collects ``(attribute-path, value)`` pairs for every mp primitive found in the env/transform tree.

        The path is a tuple of attribute names (and integer indices for
        Compose members) that _set_mp_value can later replay on a freshly
        created env.
        """
        # Local import to avoid a circular import at module load time.
        from torchrl.envs.transforms.transforms import Compose, TransformedEnv

        if isinstance(env_or_transform, EnvBase) and isinstance(
            env_or_transform, TransformedEnv
        ):
            # Recurse into both the transform stack and the wrapped env.
            cls._find_mp_values(
                env_or_transform.transform,
                values=values,
                prefix=prefix + ("transform",),
            )
            cls._find_mp_values(
                env_or_transform.base_env, values=values, prefix=prefix + ("base_env",)
            )
        elif isinstance(env_or_transform, Compose):
            for i, t in enumerate(env_or_transform.transforms):
                cls._find_mp_values(t, values=values, prefix=prefix + (i,))
        # Scan direct attributes of the current node as well.
        for k, v in env_or_transform.__dict__.items():
            if cls._is_mp_value(v):
                values.append((prefix + (k,), v))
        return values

    def init_(self) -> EnvCreator:
        """Instantiates a throwaway env once to harvest its metadata, state-dict and mp primitives.

        Returns:
            self, for chaining.
        """
        shadow_env = self.create_env_fn(**self.create_env_kwargs)
        # One reset + random step so that lazily-initialized transform state
        # (e.g. VecNorm statistics) exists before we snapshot it.
        tensordict = shadow_env.reset()
        shadow_env.rand_step(tensordict)
        self.env_type = type(shadow_env)
        self._transform_state_dict = shadow_env.state_dict()
        # Extract any mp.Value object from the env
        self._mp_values = self._find_mp_values(shadow_env, values=[])

        if self._share_memory:
            self.share_memory(self._transform_state_dict)
        self.initialized = True
        self.meta_data = EnvMetaData.metadata_from_env(shadow_env)
        shadow_env.close()
        del shadow_env
        return self

    @classmethod
    def _set_mp_value(cls, env, key, value):
        """Replays an attribute path produced by _find_mp_values on ``env`` and assigns ``value`` at the leaf."""
        if len(key) > 1:
            # Integer components index into Compose; strings are attributes.
            if isinstance(key[0], int):
                return cls._set_mp_value(env[key[0]], key[1:], value)
            else:
                return cls._set_mp_value(getattr(env, key[0]), key[1:], value)
        else:
            setattr(env, key[0], value)

    def __call__(self, **kwargs) -> EnvBase:
        """Creates a new env instance, wired to the shared state captured at init_() time.

        Raises:
            RuntimeError: if init_() has not run.
        """
        if not self.initialized:
            raise RuntimeError("EnvCreator must be initialized before being called.")
        kwargs.update(self.create_env_kwargs)  # create_env_kwargs precedes
        env = self.create_env_fn(**kwargs)
        if self._mp_values:
            for k, v in self._mp_values:
                self._set_mp_value(env, k, v)
        # strict=False: shared entries were pruned from the snapshot.
        env.load_state_dict(self._transform_state_dict, strict=False)
        return env

    def state_dict(self) -> OrderedDict:
        """Returns the captured transform state-dict (empty OrderedDict if none)."""
        if self._transform_state_dict is None:
            return OrderedDict()
        return self._transform_state_dict

    def load_state_dict(self, state_dict: OrderedDict) -> None:
        """Copies ``state_dict`` values into the stored transform state-dict, in place."""
        if self._transform_state_dict is not None:
            for key, item in state_dict.items():
                item_to_update = self._transform_state_dict[key]
                item_to_update.copy_(item)

    def __repr__(self) -> str:
        # .items() is required here: iterating the dict directly yields
        # keys only and the (key, item) unpacking would raise.
        substr = ", ".join(
            [f"{key}: {type(item)}" for key, item in self.create_env_kwargs.items()]
        )
        return f"EnvCreator({self.create_env_fn}({substr}))"
|
|
229
|
+
|
|
230
|
+
|
|
231
|
+
def env_creator(fun: Callable) -> EnvCreator:
    """Wrap ``fun`` into an :class:`EnvCreator` instance (functional helper)."""
    creator = EnvCreator(fun)
    return creator
|
|
234
|
+
|
|
235
|
+
|
|
236
|
+
def get_env_metadata(env_or_creator: EnvBase | Callable, kwargs: dict | None = None):
    """Retrieves a EnvMetaData object from an env."""
    if isinstance(env_or_creator, (EnvBase,)):
        # A concrete env instance: read the metadata off it directly.
        return EnvMetaData.metadata_from_env(env_or_creator)

    if isinstance(env_or_creator, EnvCreator):
        # The creator has already cached its metadata; just make sure any
        # kwargs supplied here agree with those used at construction time.
        kwargs_consistent = (
            kwargs == env_or_creator.create_env_kwargs
            or kwargs is None
            or len(kwargs) == 0
        )
        if not kwargs_consistent:
            raise RuntimeError(
                "kwargs mismatch between EnvCreator and the kwargs provided to get_env_metadata:"
                f"got EnvCreator.create_env_kwargs={env_or_creator.create_env_kwargs} and "
                f"kwargs = {kwargs}"
            )
        return env_or_creator.meta_data.clone()

    # then env is a creator: build a shadow env and extract its metadata.
    env = env_or_creator(**(kwargs if kwargs is not None else {}))
    return EnvMetaData.metadata_from_env(env)
|