torchrl 0.11.0__cp314-cp314t-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/benchmark_batched_envs.py +104 -0
- benchmarks/conftest.py +91 -0
- benchmarks/ecosystem/gym_env_throughput.py +321 -0
- benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
- benchmarks/requirements.txt +7 -0
- benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
- benchmarks/test_collectors_benchmark.py +240 -0
- benchmarks/test_compressed_storage_benchmark.py +145 -0
- benchmarks/test_envs_benchmark.py +133 -0
- benchmarks/test_llm.py +101 -0
- benchmarks/test_non_tensor_env_benchmark.py +70 -0
- benchmarks/test_objectives_benchmarks.py +1199 -0
- benchmarks/test_replaybuffer_benchmark.py +254 -0
- sota-check/README.md +35 -0
- sota-implementations/README.md +142 -0
- sota-implementations/a2c/README.md +39 -0
- sota-implementations/a2c/a2c_atari.py +291 -0
- sota-implementations/a2c/a2c_mujoco.py +273 -0
- sota-implementations/a2c/utils_atari.py +240 -0
- sota-implementations/a2c/utils_mujoco.py +160 -0
- sota-implementations/bandits/README.md +7 -0
- sota-implementations/bandits/dqn.py +126 -0
- sota-implementations/cql/cql_offline.py +198 -0
- sota-implementations/cql/cql_online.py +249 -0
- sota-implementations/cql/discrete_cql_offline.py +180 -0
- sota-implementations/cql/discrete_cql_online.py +227 -0
- sota-implementations/cql/utils.py +471 -0
- sota-implementations/crossq/crossq.py +271 -0
- sota-implementations/crossq/utils.py +320 -0
- sota-implementations/ddpg/ddpg.py +231 -0
- sota-implementations/ddpg/utils.py +325 -0
- sota-implementations/decision_transformer/dt.py +163 -0
- sota-implementations/decision_transformer/lamb.py +167 -0
- sota-implementations/decision_transformer/online_dt.py +178 -0
- sota-implementations/decision_transformer/utils.py +562 -0
- sota-implementations/discrete_sac/discrete_sac.py +243 -0
- sota-implementations/discrete_sac/utils.py +324 -0
- sota-implementations/dqn/README.md +30 -0
- sota-implementations/dqn/dqn_atari.py +272 -0
- sota-implementations/dqn/dqn_cartpole.py +236 -0
- sota-implementations/dqn/utils_atari.py +132 -0
- sota-implementations/dqn/utils_cartpole.py +90 -0
- sota-implementations/dreamer/README.md +129 -0
- sota-implementations/dreamer/dreamer.py +586 -0
- sota-implementations/dreamer/dreamer_utils.py +1107 -0
- sota-implementations/expert-iteration/README.md +352 -0
- sota-implementations/expert-iteration/ei_utils.py +770 -0
- sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
- sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
- sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
- sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
- sota-implementations/gail/gail.py +327 -0
- sota-implementations/gail/gail_utils.py +68 -0
- sota-implementations/gail/ppo_utils.py +157 -0
- sota-implementations/grpo/README.md +273 -0
- sota-implementations/grpo/grpo-async.py +437 -0
- sota-implementations/grpo/grpo-sync.py +435 -0
- sota-implementations/grpo/grpo_utils.py +843 -0
- sota-implementations/grpo/requirements_gsm8k.txt +11 -0
- sota-implementations/grpo/requirements_ifeval.txt +16 -0
- sota-implementations/impala/README.md +33 -0
- sota-implementations/impala/impala_multi_node_ray.py +292 -0
- sota-implementations/impala/impala_multi_node_submitit.py +284 -0
- sota-implementations/impala/impala_single_node.py +261 -0
- sota-implementations/impala/utils.py +184 -0
- sota-implementations/iql/discrete_iql.py +230 -0
- sota-implementations/iql/iql_offline.py +164 -0
- sota-implementations/iql/iql_online.py +225 -0
- sota-implementations/iql/utils.py +437 -0
- sota-implementations/multiagent/README.md +74 -0
- sota-implementations/multiagent/iql.py +237 -0
- sota-implementations/multiagent/maddpg_iddpg.py +266 -0
- sota-implementations/multiagent/mappo_ippo.py +267 -0
- sota-implementations/multiagent/qmix_vdn.py +271 -0
- sota-implementations/multiagent/sac.py +337 -0
- sota-implementations/multiagent/utils/__init__.py +4 -0
- sota-implementations/multiagent/utils/logging.py +151 -0
- sota-implementations/multiagent/utils/utils.py +43 -0
- sota-implementations/ppo/README.md +29 -0
- sota-implementations/ppo/ppo_atari.py +305 -0
- sota-implementations/ppo/ppo_mujoco.py +293 -0
- sota-implementations/ppo/utils_atari.py +238 -0
- sota-implementations/ppo/utils_mujoco.py +152 -0
- sota-implementations/ppo_trainer/train.py +21 -0
- sota-implementations/redq/README.md +7 -0
- sota-implementations/redq/redq.py +199 -0
- sota-implementations/redq/utils.py +1060 -0
- sota-implementations/sac/sac-async.py +266 -0
- sota-implementations/sac/sac.py +239 -0
- sota-implementations/sac/utils.py +381 -0
- sota-implementations/sac_trainer/train.py +16 -0
- sota-implementations/td3/td3.py +254 -0
- sota-implementations/td3/utils.py +319 -0
- sota-implementations/td3_bc/td3_bc.py +177 -0
- sota-implementations/td3_bc/utils.py +251 -0
- torchrl/.dylibs/libc++.1.0.dylib +0 -0
- torchrl/__init__.py +144 -0
- torchrl/_extension.py +74 -0
- torchrl/_torchrl.cpython-314t-darwin.so +0 -0
- torchrl/_utils.py +1431 -0
- torchrl/collectors/__init__.py +48 -0
- torchrl/collectors/_base.py +1058 -0
- torchrl/collectors/_constants.py +88 -0
- torchrl/collectors/_multi_async.py +324 -0
- torchrl/collectors/_multi_base.py +1805 -0
- torchrl/collectors/_multi_sync.py +464 -0
- torchrl/collectors/_runner.py +581 -0
- torchrl/collectors/_single.py +2009 -0
- torchrl/collectors/_single_async.py +259 -0
- torchrl/collectors/collectors.py +62 -0
- torchrl/collectors/distributed/__init__.py +32 -0
- torchrl/collectors/distributed/default_configs.py +133 -0
- torchrl/collectors/distributed/generic.py +1306 -0
- torchrl/collectors/distributed/ray.py +1092 -0
- torchrl/collectors/distributed/rpc.py +1006 -0
- torchrl/collectors/distributed/sync.py +731 -0
- torchrl/collectors/distributed/utils.py +160 -0
- torchrl/collectors/llm/__init__.py +10 -0
- torchrl/collectors/llm/base.py +494 -0
- torchrl/collectors/llm/ray_collector.py +275 -0
- torchrl/collectors/llm/utils.py +36 -0
- torchrl/collectors/llm/weight_update/__init__.py +10 -0
- torchrl/collectors/llm/weight_update/vllm.py +348 -0
- torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
- torchrl/collectors/utils.py +433 -0
- torchrl/collectors/weight_update.py +591 -0
- torchrl/csrc/numpy_utils.h +38 -0
- torchrl/csrc/pybind.cpp +27 -0
- torchrl/csrc/segment_tree.h +458 -0
- torchrl/csrc/torch_utils.h +34 -0
- torchrl/csrc/utils.cpp +48 -0
- torchrl/csrc/utils.h +31 -0
- torchrl/data/__init__.py +187 -0
- torchrl/data/datasets/__init__.py +58 -0
- torchrl/data/datasets/atari_dqn.py +878 -0
- torchrl/data/datasets/common.py +281 -0
- torchrl/data/datasets/d4rl.py +489 -0
- torchrl/data/datasets/d4rl_infos.py +187 -0
- torchrl/data/datasets/gen_dgrl.py +375 -0
- torchrl/data/datasets/minari_data.py +643 -0
- torchrl/data/datasets/openml.py +177 -0
- torchrl/data/datasets/openx.py +798 -0
- torchrl/data/datasets/roboset.py +363 -0
- torchrl/data/datasets/utils.py +11 -0
- torchrl/data/datasets/vd4rl.py +432 -0
- torchrl/data/llm/__init__.py +34 -0
- torchrl/data/llm/dataset.py +491 -0
- torchrl/data/llm/history.py +1378 -0
- torchrl/data/llm/prompt.py +198 -0
- torchrl/data/llm/reward.py +225 -0
- torchrl/data/llm/topk.py +186 -0
- torchrl/data/llm/utils.py +543 -0
- torchrl/data/map/__init__.py +21 -0
- torchrl/data/map/hash.py +185 -0
- torchrl/data/map/query.py +204 -0
- torchrl/data/map/tdstorage.py +363 -0
- torchrl/data/map/tree.py +1434 -0
- torchrl/data/map/utils.py +103 -0
- torchrl/data/postprocs/__init__.py +8 -0
- torchrl/data/postprocs/postprocs.py +391 -0
- torchrl/data/replay_buffers/__init__.py +99 -0
- torchrl/data/replay_buffers/checkpointers.py +622 -0
- torchrl/data/replay_buffers/ray_buffer.py +292 -0
- torchrl/data/replay_buffers/replay_buffers.py +2376 -0
- torchrl/data/replay_buffers/samplers.py +2578 -0
- torchrl/data/replay_buffers/scheduler.py +265 -0
- torchrl/data/replay_buffers/storages.py +2412 -0
- torchrl/data/replay_buffers/utils.py +1042 -0
- torchrl/data/replay_buffers/writers.py +781 -0
- torchrl/data/tensor_specs.py +7101 -0
- torchrl/data/utils.py +334 -0
- torchrl/envs/__init__.py +265 -0
- torchrl/envs/async_envs.py +1105 -0
- torchrl/envs/batched_envs.py +3093 -0
- torchrl/envs/common.py +4241 -0
- torchrl/envs/custom/__init__.py +11 -0
- torchrl/envs/custom/chess.py +617 -0
- torchrl/envs/custom/llm.py +214 -0
- torchrl/envs/custom/pendulum.py +401 -0
- torchrl/envs/custom/san_moves.txt +29274 -0
- torchrl/envs/custom/tictactoeenv.py +288 -0
- torchrl/envs/env_creator.py +263 -0
- torchrl/envs/gym_like.py +752 -0
- torchrl/envs/libs/__init__.py +68 -0
- torchrl/envs/libs/_gym_utils.py +326 -0
- torchrl/envs/libs/brax.py +846 -0
- torchrl/envs/libs/dm_control.py +544 -0
- torchrl/envs/libs/envpool.py +447 -0
- torchrl/envs/libs/gym.py +2239 -0
- torchrl/envs/libs/habitat.py +138 -0
- torchrl/envs/libs/isaac_lab.py +87 -0
- torchrl/envs/libs/isaacgym.py +203 -0
- torchrl/envs/libs/jax_utils.py +166 -0
- torchrl/envs/libs/jumanji.py +963 -0
- torchrl/envs/libs/meltingpot.py +599 -0
- torchrl/envs/libs/openml.py +153 -0
- torchrl/envs/libs/openspiel.py +652 -0
- torchrl/envs/libs/pettingzoo.py +1042 -0
- torchrl/envs/libs/procgen.py +351 -0
- torchrl/envs/libs/robohive.py +429 -0
- torchrl/envs/libs/smacv2.py +645 -0
- torchrl/envs/libs/unity_mlagents.py +891 -0
- torchrl/envs/libs/utils.py +147 -0
- torchrl/envs/libs/vmas.py +813 -0
- torchrl/envs/llm/__init__.py +63 -0
- torchrl/envs/llm/chat.py +730 -0
- torchrl/envs/llm/datasets/README.md +4 -0
- torchrl/envs/llm/datasets/__init__.py +17 -0
- torchrl/envs/llm/datasets/gsm8k.py +353 -0
- torchrl/envs/llm/datasets/ifeval.py +274 -0
- torchrl/envs/llm/envs.py +789 -0
- torchrl/envs/llm/libs/README.md +3 -0
- torchrl/envs/llm/libs/__init__.py +8 -0
- torchrl/envs/llm/libs/mlgym.py +869 -0
- torchrl/envs/llm/reward/__init__.py +10 -0
- torchrl/envs/llm/reward/gsm8k.py +324 -0
- torchrl/envs/llm/reward/ifeval/README.md +13 -0
- torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
- torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
- torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
- torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
- torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
- torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
- torchrl/envs/llm/transforms/__init__.py +55 -0
- torchrl/envs/llm/transforms/browser.py +292 -0
- torchrl/envs/llm/transforms/dataloading.py +859 -0
- torchrl/envs/llm/transforms/format.py +73 -0
- torchrl/envs/llm/transforms/kl.py +1544 -0
- torchrl/envs/llm/transforms/policy_version.py +189 -0
- torchrl/envs/llm/transforms/reason.py +323 -0
- torchrl/envs/llm/transforms/tokenizer.py +321 -0
- torchrl/envs/llm/transforms/tools.py +1955 -0
- torchrl/envs/model_based/__init__.py +9 -0
- torchrl/envs/model_based/common.py +180 -0
- torchrl/envs/model_based/dreamer.py +112 -0
- torchrl/envs/transforms/__init__.py +147 -0
- torchrl/envs/transforms/functional.py +48 -0
- torchrl/envs/transforms/gym_transforms.py +203 -0
- torchrl/envs/transforms/module.py +341 -0
- torchrl/envs/transforms/r3m.py +372 -0
- torchrl/envs/transforms/ray_service.py +663 -0
- torchrl/envs/transforms/rb_transforms.py +214 -0
- torchrl/envs/transforms/transforms.py +11835 -0
- torchrl/envs/transforms/utils.py +94 -0
- torchrl/envs/transforms/vc1.py +307 -0
- torchrl/envs/transforms/vecnorm.py +845 -0
- torchrl/envs/transforms/vip.py +407 -0
- torchrl/envs/utils.py +1718 -0
- torchrl/envs/vec_envs.py +11 -0
- torchrl/modules/__init__.py +206 -0
- torchrl/modules/distributions/__init__.py +73 -0
- torchrl/modules/distributions/continuous.py +830 -0
- torchrl/modules/distributions/discrete.py +908 -0
- torchrl/modules/distributions/truncated_normal.py +187 -0
- torchrl/modules/distributions/utils.py +233 -0
- torchrl/modules/llm/__init__.py +62 -0
- torchrl/modules/llm/backends/__init__.py +65 -0
- torchrl/modules/llm/backends/vllm/__init__.py +94 -0
- torchrl/modules/llm/backends/vllm/_models.py +46 -0
- torchrl/modules/llm/backends/vllm/base.py +72 -0
- torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
- torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
- torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
- torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
- torchrl/modules/llm/policies/__init__.py +28 -0
- torchrl/modules/llm/policies/common.py +1809 -0
- torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
- torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
- torchrl/modules/llm/utils.py +23 -0
- torchrl/modules/mcts/__init__.py +21 -0
- torchrl/modules/mcts/scores.py +579 -0
- torchrl/modules/models/__init__.py +86 -0
- torchrl/modules/models/batchrenorm.py +119 -0
- torchrl/modules/models/decision_transformer.py +179 -0
- torchrl/modules/models/exploration.py +731 -0
- torchrl/modules/models/llm.py +156 -0
- torchrl/modules/models/model_based.py +596 -0
- torchrl/modules/models/models.py +1712 -0
- torchrl/modules/models/multiagent.py +1067 -0
- torchrl/modules/models/recipes/impala.py +185 -0
- torchrl/modules/models/utils.py +162 -0
- torchrl/modules/planners/__init__.py +10 -0
- torchrl/modules/planners/cem.py +228 -0
- torchrl/modules/planners/common.py +73 -0
- torchrl/modules/planners/mppi.py +265 -0
- torchrl/modules/tensordict_module/__init__.py +89 -0
- torchrl/modules/tensordict_module/actors.py +2457 -0
- torchrl/modules/tensordict_module/common.py +529 -0
- torchrl/modules/tensordict_module/exploration.py +814 -0
- torchrl/modules/tensordict_module/probabilistic.py +321 -0
- torchrl/modules/tensordict_module/rnn.py +1639 -0
- torchrl/modules/tensordict_module/sequence.py +132 -0
- torchrl/modules/tensordict_module/world_models.py +34 -0
- torchrl/modules/utils/__init__.py +38 -0
- torchrl/modules/utils/mappings.py +9 -0
- torchrl/modules/utils/utils.py +89 -0
- torchrl/objectives/__init__.py +78 -0
- torchrl/objectives/a2c.py +659 -0
- torchrl/objectives/common.py +753 -0
- torchrl/objectives/cql.py +1346 -0
- torchrl/objectives/crossq.py +710 -0
- torchrl/objectives/ddpg.py +453 -0
- torchrl/objectives/decision_transformer.py +371 -0
- torchrl/objectives/deprecated.py +516 -0
- torchrl/objectives/dqn.py +683 -0
- torchrl/objectives/dreamer.py +488 -0
- torchrl/objectives/functional.py +48 -0
- torchrl/objectives/gail.py +258 -0
- torchrl/objectives/iql.py +996 -0
- torchrl/objectives/llm/__init__.py +30 -0
- torchrl/objectives/llm/grpo.py +846 -0
- torchrl/objectives/llm/sft.py +482 -0
- torchrl/objectives/multiagent/__init__.py +8 -0
- torchrl/objectives/multiagent/qmixer.py +396 -0
- torchrl/objectives/ppo.py +1669 -0
- torchrl/objectives/redq.py +683 -0
- torchrl/objectives/reinforce.py +530 -0
- torchrl/objectives/sac.py +1580 -0
- torchrl/objectives/td3.py +570 -0
- torchrl/objectives/td3_bc.py +625 -0
- torchrl/objectives/utils.py +782 -0
- torchrl/objectives/value/__init__.py +28 -0
- torchrl/objectives/value/advantages.py +1956 -0
- torchrl/objectives/value/functional.py +1459 -0
- torchrl/objectives/value/utils.py +360 -0
- torchrl/record/__init__.py +17 -0
- torchrl/record/loggers/__init__.py +23 -0
- torchrl/record/loggers/common.py +48 -0
- torchrl/record/loggers/csv.py +226 -0
- torchrl/record/loggers/mlflow.py +142 -0
- torchrl/record/loggers/tensorboard.py +139 -0
- torchrl/record/loggers/trackio.py +163 -0
- torchrl/record/loggers/utils.py +78 -0
- torchrl/record/loggers/wandb.py +214 -0
- torchrl/record/recorder.py +554 -0
- torchrl/services/__init__.py +79 -0
- torchrl/services/base.py +109 -0
- torchrl/services/ray_service.py +453 -0
- torchrl/testing/__init__.py +107 -0
- torchrl/testing/assertions.py +179 -0
- torchrl/testing/dist_utils.py +122 -0
- torchrl/testing/env_creators.py +227 -0
- torchrl/testing/env_helper.py +35 -0
- torchrl/testing/gym_helpers.py +156 -0
- torchrl/testing/llm_mocks.py +119 -0
- torchrl/testing/mocking_classes.py +2720 -0
- torchrl/testing/modules.py +295 -0
- torchrl/testing/mp_helpers.py +15 -0
- torchrl/testing/ray_helpers.py +293 -0
- torchrl/testing/utils.py +190 -0
- torchrl/trainers/__init__.py +42 -0
- torchrl/trainers/algorithms/__init__.py +11 -0
- torchrl/trainers/algorithms/configs/__init__.py +705 -0
- torchrl/trainers/algorithms/configs/collectors.py +216 -0
- torchrl/trainers/algorithms/configs/common.py +41 -0
- torchrl/trainers/algorithms/configs/data.py +308 -0
- torchrl/trainers/algorithms/configs/envs.py +104 -0
- torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
- torchrl/trainers/algorithms/configs/logging.py +80 -0
- torchrl/trainers/algorithms/configs/modules.py +570 -0
- torchrl/trainers/algorithms/configs/objectives.py +177 -0
- torchrl/trainers/algorithms/configs/trainers.py +340 -0
- torchrl/trainers/algorithms/configs/transforms.py +955 -0
- torchrl/trainers/algorithms/configs/utils.py +252 -0
- torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
- torchrl/trainers/algorithms/configs/weight_update.py +159 -0
- torchrl/trainers/algorithms/ppo.py +373 -0
- torchrl/trainers/algorithms/sac.py +308 -0
- torchrl/trainers/helpers/__init__.py +40 -0
- torchrl/trainers/helpers/collectors.py +416 -0
- torchrl/trainers/helpers/envs.py +573 -0
- torchrl/trainers/helpers/logger.py +33 -0
- torchrl/trainers/helpers/losses.py +132 -0
- torchrl/trainers/helpers/models.py +658 -0
- torchrl/trainers/helpers/replay_buffer.py +59 -0
- torchrl/trainers/helpers/trainers.py +301 -0
- torchrl/trainers/trainers.py +2052 -0
- torchrl/weight_update/__init__.py +33 -0
- torchrl/weight_update/_distributed.py +749 -0
- torchrl/weight_update/_mp.py +624 -0
- torchrl/weight_update/_noupdate.py +102 -0
- torchrl/weight_update/_ray.py +1032 -0
- torchrl/weight_update/_rpc.py +284 -0
- torchrl/weight_update/_shared.py +891 -0
- torchrl/weight_update/llm/__init__.py +32 -0
- torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
- torchrl/weight_update/llm/vllm_nccl.py +710 -0
- torchrl/weight_update/utils.py +73 -0
- torchrl/weight_update/weight_sync_schemes.py +1244 -0
- torchrl-0.11.0.dist-info/METADATA +1308 -0
- torchrl-0.11.0.dist-info/RECORD +395 -0
- torchrl-0.11.0.dist-info/WHEEL +5 -0
- torchrl-0.11.0.dist-info/entry_points.txt +2 -0
- torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
- torchrl-0.11.0.dist-info/top_level.txt +7 -0
@@ -0,0 +1,963 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

import importlib.util

import numpy as np
import torch
from packaging import version
from tensordict import TensorDict, TensorDictBase
from torchrl.envs.common import _EnvPostInit
from torchrl.envs.utils import _classproperty

_has_jumanji = importlib.util.find_spec("jumanji") is not None
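# The module itself imports without jumanji installed; this flag gates every
# code path that actually touches the library (imports of `jumanji` happen
# lazily, inside the functions that need them).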

from torchrl.data.tensor_specs import (
    Bounded,
    Categorical,
    Composite,
    DEVICE_TYPING,
    MultiCategorical,
    MultiOneHot,
    OneHot,
    TensorSpec,
    Unbounded,
)
from torchrl.data.utils import numpy_to_torch_dtype_dict
from torchrl.envs.gym_like import GymLikeEnv

from torchrl.envs.libs.jax_utils import (
    _extract_spec,
    _ndarray_to_tensor,
    _object_to_tensordict,
    _tensordict_to_object,
    _tree_flatten,
    _tree_reshape,
)


def _get_envs():
    if not _has_jumanji:
        raise ImportError("Jumanji is not installed in your virtual environment.")
    import jumanji

    return jumanji.registered_environments()


def _jumanji_to_torchrl_spec_transform(
    spec,
    dtype: torch.dtype | None = None,
    device: DEVICE_TYPING = None,
    categorical_action_encoding: bool = True,
) -> TensorSpec:
    import jumanji

    if isinstance(spec, jumanji.specs.DiscreteArray):
        action_space_cls = Categorical if categorical_action_encoding else OneHot
        if dtype is None:
            dtype = numpy_to_torch_dtype_dict[spec.dtype]
        return action_space_cls(spec.num_values, dtype=dtype, device=device)
    if isinstance(spec, jumanji.specs.MultiDiscreteArray):
        action_space_cls = (
            MultiCategorical if categorical_action_encoding else MultiOneHot
        )
        if dtype is None:
            dtype = numpy_to_torch_dtype_dict[spec.dtype]
        return action_space_cls(
            torch.as_tensor(np.asarray(spec.num_values)), dtype=dtype, device=device
        )
    elif isinstance(spec, jumanji.specs.BoundedArray):
        shape = spec.shape
        if dtype is None:
            dtype = numpy_to_torch_dtype_dict[spec.dtype]
        return Bounded(
            shape=shape,
            low=np.asarray(spec.minimum),
            high=np.asarray(spec.maximum),
            dtype=dtype,
            device=device,
        )
    elif isinstance(spec, jumanji.specs.Array):
        shape = spec.shape
        if dtype is None:
            dtype = numpy_to_torch_dtype_dict[spec.dtype]
        return Unbounded(shape=shape, dtype=dtype, device=device)
    elif isinstance(spec, jumanji.specs.Spec) and hasattr(spec, "__dict__"):
        new_spec = {}
        for key, value in spec.__dict__.items():
            if isinstance(value, jumanji.specs.Spec):
                if key.endswith("_obs"):
                    key = key[:-4]
                if key.endswith("_spec"):
                    key = key[:-5]
                new_spec[key] = _jumanji_to_torchrl_spec_transform(
                    value, dtype, device, categorical_action_encoding
                )
        return Composite(**new_spec)
    else:
        raise TypeError(f"Unsupported spec type {type(spec)}")
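
# A minimal sketch of what the conversion above produces, assuming jumanji is
# installed (the exact shapes/dtypes are environment-dependent):
#
#     >>> import jumanji
#     >>> env = jumanji.make("Snake-v1")
#     >>> _jumanji_to_torchrl_spec_transform(env.action_spec)
#     Categorical(...)  # a DiscreteArray maps to Categorical (or OneHot)
#     >>> _jumanji_to_torchrl_spec_transform(env.observation_spec)
#     Composite(...)    # nested jumanji specs are recursed into a Composite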


class _JumanjiMakeRender(_EnvPostInit):
    def __call__(self, *args, **kwargs):
        instance = super().__call__(*args, **kwargs)
        if instance.from_pixels:
            return instance.make_render()
        return instance
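
# _EnvPostInit runs this hook right after __init__ completes, so a wrapper
# built with from_pixels=True is transparently replaced by its rendering
# TransformedEnv (see make_render below) before being handed back to the caller.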


class JumanjiWrapper(GymLikeEnv, metaclass=_JumanjiMakeRender):
    """Jumanji's environment wrapper.

    Jumanji offers a vectorized simulation framework based on Jax.
    TorchRL's wrapper incurs some overhead for the jax-to-torch conversion,
    but computational graphs can still be built on top of the simulated trajectories,
    allowing for backpropagation through the rollout.

    GitHub: https://github.com/instadeepai/jumanji

    Doc: https://instadeepai.github.io/jumanji/

    Paper: https://arxiv.org/abs/2306.09884

    .. note:: For better performance, turn `jit` on when instantiating this class.
        The `jit` attribute can also be flipped during code execution:

        >>> env.jit = True  # use jit
        >>> env.jit = False  # eager

    Args:
        env (jumanji.env.Environment): the env to wrap.
        categorical_action_encoding (bool, optional): if ``True``, categorical
            specs will be converted to the TorchRL equivalent (:class:`torchrl.data.Categorical`),
            otherwise a one-hot encoding will be used (:class:`torchrl.data.OneHot`).
            Defaults to ``True``.

    Keyword Args:
        batch_size (torch.Size, optional): the batch size of the environment.
            With ``jumanji``, this indicates the number of vectorized environments.
            If the batch-size is empty, the environment is not batch-locked and an arbitrary number
            of environments can be executed simultaneously.
            Defaults to ``torch.Size([])``.

            >>> import jumanji
            >>> from torchrl.envs import JumanjiWrapper
            >>> base_env = jumanji.make("Snake-v1")
            >>> env = JumanjiWrapper(base_env)
            >>> # Setting the batch-size of the TensorDict instead of the env allows one to control
            >>> # the number of envs being run simultaneously
            >>> tdreset = env.reset(TensorDict(batch_size=[32]))
            >>> # Execute a rollout until all envs are done or max steps is reached, whichever comes first
            >>> rollout = env.rollout(100, break_when_all_done=True, auto_reset=False, tensordict=tdreset)

        from_pixels (bool, optional): Whether the environment should render its output.
            This will drastically impact the environment throughput. Only the first environment
            will be rendered. See :meth:`~torchrl.envs.JumanjiWrapper.render` for more information.
            Defaults to `False`.
        frame_skip (int, optional): if provided, indicates for how many steps the
            same action is to be repeated. The observation returned will be the
            last observation of the sequence, whereas the reward will be the sum
            of rewards across steps.
        device (torch.device, optional): if provided, the device on which the data
            is to be cast. Defaults to ``torch.device("cpu")``.
        allow_done_after_reset (bool, optional): if ``True``, it is tolerated
            for envs to be ``done`` just after :meth:`reset` is called.
            Defaults to ``False``.
        jit (bool, optional): whether the step and reset method should be wrapped in `jit`.
            Defaults to ``True``.

    Attributes:
        available_envs: environments available to build

    Examples:
        >>> import jumanji
        >>> from torchrl.envs import JumanjiWrapper
        >>> base_env = jumanji.make("Snake-v1")
        >>> env = JumanjiWrapper(base_env)
        >>> env.set_seed(0)
        >>> td = env.reset()
        >>> td["action"] = env.action_spec.rand()
        >>> td = env.step(td)
        >>> print(td)
        TensorDict(
            fields={
                action: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int32, is_shared=False),
                action_mask: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.bool, is_shared=False),
                done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
                grid: Tensor(shape=torch.Size([12, 12, 5]), device=cpu, dtype=torch.float32, is_shared=False),
                next: TensorDict(
                    fields={
                        action_mask: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.bool, is_shared=False),
                        done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
                        grid: Tensor(shape=torch.Size([12, 12, 5]), device=cpu, dtype=torch.float32, is_shared=False),
                        reward: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
                        state: TensorDict(
                            fields={
                                action_mask: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.bool, is_shared=False),
                                body: Tensor(shape=torch.Size([12, 12]), device=cpu, dtype=torch.bool, is_shared=False),
                                body_state: Tensor(shape=torch.Size([12, 12]), device=cpu, dtype=torch.int32, is_shared=False),
                                fruit_position: TensorDict(
                                    fields={
                                        col: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int32, is_shared=False),
                                        row: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int32, is_shared=False)},
                                    batch_size=torch.Size([]),
                                    device=cpu,
                                    is_shared=False),
                                head_position: TensorDict(
                                    fields={
                                        col: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int32, is_shared=False),
                                        row: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int32, is_shared=False)},
                                    batch_size=torch.Size([]),
                                    device=cpu,
                                    is_shared=False),
                                key: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.int32, is_shared=False),
                                length: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int32, is_shared=False),
                                step_count: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int32, is_shared=False),
                                tail: Tensor(shape=torch.Size([12, 12]), device=cpu, dtype=torch.bool, is_shared=False)},
                            batch_size=torch.Size([]),
                            device=cpu,
                            is_shared=False),
                        step_count: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int32, is_shared=False),
                        terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)},
                    batch_size=torch.Size([]),
                    device=cpu,
                    is_shared=False),
                state: TensorDict(
                    fields={
                        action_mask: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.bool, is_shared=False),
                        body: Tensor(shape=torch.Size([12, 12]), device=cpu, dtype=torch.bool, is_shared=False),
                        body_state: Tensor(shape=torch.Size([12, 12]), device=cpu, dtype=torch.int32, is_shared=False),
                        fruit_position: TensorDict(
                            fields={
                                col: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int32, is_shared=False),
                                row: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int32, is_shared=False)},
                            batch_size=torch.Size([]),
                            device=cpu,
                            is_shared=False),
                        head_position: TensorDict(
                            fields={
                                col: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int32, is_shared=False),
                                row: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int32, is_shared=False)},
                            batch_size=torch.Size([]),
                            device=cpu,
                            is_shared=False),
                        key: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.int32, is_shared=False),
                        length: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int32, is_shared=False),
                        step_count: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int32, is_shared=False),
                        tail: Tensor(shape=torch.Size([12, 12]), device=cpu, dtype=torch.bool, is_shared=False)},
                    batch_size=torch.Size([]),
                    device=cpu,
                    is_shared=False),
                step_count: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int32, is_shared=False),
                terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)},
            batch_size=torch.Size([]),
            device=cpu,
            is_shared=False)
        >>> print(env.available_envs)
        ['Game2048-v1',
         'Maze-v0',
         'Cleaner-v0',
         'CVRP-v1',
         'MultiCVRP-v0',
         'Minesweeper-v0',
         'RubiksCube-v0',
         'Knapsack-v1',
         'Sudoku-v0',
         'Snake-v1',
         'TSP-v1',
         'Connector-v2',
         'MMST-v0',
         'GraphColoring-v0',
         'RubiksCube-partly-scrambled-v0',
         'RobotWarehouse-v0',
         'Tetris-v0',
         'BinPack-v2',
         'Sudoku-very-easy-v0',
         'JobShop-v0']

    To take advantage of Jumanji, one usually executes multiple environments at the
    same time.

        >>> import jumanji
        >>> from torchrl.envs import JumanjiWrapper
        >>> base_env = jumanji.make("Snake-v1")
        >>> env = JumanjiWrapper(base_env, batch_size=[10])
        >>> env.set_seed(0)
        >>> td = env.reset()
        >>> td["action"] = env.action_spec.rand()
        >>> td = env.step(td)

    In the following example, we iteratively test different batch sizes
    and report the execution time for a short rollout:

    Examples:
        >>> from torch.utils.benchmark import Timer
        >>> for batch_size in [4, 16, 128]:
        ...     timer = Timer(
        ...     '''
        ...     env.rollout(100)
        ...     ''',
        ...     setup=f'''
        ...     from torchrl.envs import JumanjiWrapper
        ...     import jumanji
        ...     env = JumanjiWrapper(jumanji.make('Snake-v1'), batch_size=[{batch_size}])
        ...     env.set_seed(0)
        ...     env.rollout(2)
        ...     ''')
        ...     print(batch_size, timer.timeit(number=10))
        4
        env.rollout(100)
        setup: [...]
          Median: 122.40 ms
          2 measurements, 1 runs per measurement, 1 thread

        16
        env.rollout(100)
        setup: [...]
          Median: 134.39 ms
          2 measurements, 1 runs per measurement, 1 thread

        128
        env.rollout(100)
        setup: [...]
          Median: 172.31 ms
          2 measurements, 1 runs per measurement, 1 thread

    """

    git_url = "https://github.com/instadeepai/jumanji"
    libname = "jumanji"

    @_classproperty
    def available_envs(cls):
        if not _has_jumanji:
            return []
        return sorted(_get_envs())

    @property
    def lib(self):
        import jumanji

        if version.parse(jumanji.__version__) < version.parse("1.0.0"):
            raise ImportError("jumanji version must be >= 1.0.0")
        return jumanji

    def __init__(
        self,
        env: jumanji.env.Environment = None,  # noqa: F821
        categorical_action_encoding=True,
        jit: bool = True,
        **kwargs,
    ):
        if not _has_jumanji:
            raise ImportError(
                "jumanji is not installed or importing it failed. Consider checking your installation."
            )
        self.categorical_action_encoding = categorical_action_encoding
        if env is not None:
            kwargs["env"] = env
        batch_locked = kwargs.pop("batch_locked", kwargs.get("batch_size") is not None)
        super().__init__(**kwargs)
        self._batch_locked = batch_locked
        self.jit = jit

    @property
    def jit(self):
        return self._jit

    @jit.setter
    def jit(self, value):
        self._jit = value
        if value:
            import jax

            self._env_reset = jax.jit(self._env.reset)
            self._env_step = jax.jit(self._env.step)
        else:
            self._env_reset = self._env.reset
            self._env_step = self._env.step
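
    # Flipping `jit` simply re-binds `_env_reset`/`_env_step`. For example (a
    # sketch, assuming `env` is a JumanjiWrapper instance):
    #
    #     >>> env.jit = True   # step/reset now go through jax.jit wrappers
    #     >>> env.rollout(10)  # the first call triggers XLA compilation
    #     >>> env.jit = False  # back to eager jax execution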

    def _build_env(
        self,
        env,
        _seed: int | None = None,
        from_pixels: bool = False,
        render_kwargs: dict | None = None,
        pixels_only: bool = False,
        camera_id: int | str = 0,
        **kwargs,
    ):
        self.from_pixels = from_pixels
        self.pixels_only = pixels_only

        return env

    def make_render(self):
        """Returns a transformed environment that can be rendered.

        Examples:
            >>> from torchrl.envs import JumanjiEnv
            >>> from torchrl.record import CSVLogger, VideoRecorder
            >>>
            >>> envname = JumanjiEnv.available_envs[-1]
            >>> logger = CSVLogger("jumanji", video_format="mp4", video_fps=2)
            >>> env = JumanjiEnv(envname, from_pixels=True)
            >>>
            >>> env = env.append_transform(
            ...     VideoRecorder(logger=logger, in_keys=["pixels"], tag=envname)
            ... )
            >>> env.set_seed(0)
            >>> r = env.rollout(100)
            >>> env.transform.dump()

        """
        from torchrl.record import PixelRenderTransform

        return self.append_transform(
            PixelRenderTransform(
                out_keys=["pixels"],
                pass_tensordict=True,
                as_non_tensor=bool(self.batch_size),
                as_numpy=bool(self.batch_size),
            )
        )
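
    # Only the first sub-environment of a batched env is rendered (see the
    # class docstring), so for non-empty batch sizes the frame is kept as
    # non-tensor numpy data instead of being expanded to the batch shape.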

    def _make_state_example(self, env):
        import jax
        from jax import numpy as jnp

        key = jax.random.PRNGKey(0)
        keys = jax.random.split(key, self.batch_size.numel())
        state, _ = jax.vmap(env.reset)(jnp.stack(keys))
        state = _tree_reshape(state, self.batch_size)
        return state
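
    # The pattern above is standard jax: split one PRNG key into N, vmap the
    # reset over the stacked keys, then un-flatten the resulting pytree back
    # to the env batch shape. A standalone sketch (assuming jax is installed):
    #
    #     >>> import jax, jax.numpy as jnp
    #     >>> keys = jax.random.split(jax.random.PRNGKey(0), 6)
    #     >>> out = jax.vmap(lambda k: jax.random.uniform(k, (3,)))(jnp.stack(keys))
    #     >>> out.shape  # one leading dim per vmapped key
    #     (6, 3)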

    def _make_state_spec(self, env) -> TensorSpec:
        import jax

        key = jax.random.PRNGKey(0)
        state, _ = env.reset(key)
        state_dict = _object_to_tensordict(state, self.device, batch_size=())
        state_spec = _extract_spec(state_dict)
        return state_spec

    def _make_action_spec(self, env) -> TensorSpec:
        action_spec = _jumanji_to_torchrl_spec_transform(
            env.action_spec,
            device=self.device,
            categorical_action_encoding=self.categorical_action_encoding,
        )
        action_spec = action_spec.expand(*self.batch_size, *action_spec.shape)
        return action_spec

    def _make_observation_spec(self, env) -> TensorSpec:
        jumanji = self.lib

        spec = env.observation_spec
        new_spec = _jumanji_to_torchrl_spec_transform(spec, device=self.device)
        if isinstance(spec, jumanji.specs.Array):
            return Composite(observation=new_spec).expand(self.batch_size)
        elif isinstance(spec, jumanji.specs.Spec):
            return Composite(**{k: v for k, v in new_spec.items()}).expand(
                self.batch_size
            )
        else:
            raise TypeError(f"Unsupported spec type {type(spec)}")

    def _make_reward_spec(self, env) -> TensorSpec:
        reward_spec = _jumanji_to_torchrl_spec_transform(
            env.reward_spec, device=self.device
        )
        if not len(reward_spec.shape):
            reward_spec.shape = torch.Size([1])
        return reward_spec.expand([*self.batch_size, *reward_spec.shape])

    def _make_specs(self, env: jumanji.env.Environment) -> None:  # noqa: F821

        # extract spec from jumanji definition
        self.action_spec = self._make_action_spec(env)
        self.observation_spec = self._make_observation_spec(env)
        self.reward_spec = self._make_reward_spec(env)

        # extract state spec from instance
        state_spec = self._make_state_spec(env).expand(self.batch_size)
        self.state_spec["state"] = state_spec
        self.observation_spec["state"] = state_spec.clone()

        # build state example for data conversion
        self._state_example = self._make_state_example(env)

    def _check_kwargs(self, kwargs: dict):
        jumanji = self.lib
        if "env" not in kwargs:
            raise TypeError("Could not find environment key 'env' in kwargs.")
        env = kwargs["env"]
        if not isinstance(env, (jumanji.env.Environment,)):
            raise TypeError("env is not of type 'jumanji.env.Environment'.")

    def _init_env(self):
        pass

    @property
    def key(self):
        key = getattr(self, "_key", None)
        if key is None:
            raise RuntimeError(
                "the env.key attribute wasn't found. Make sure to call `env.set_seed(seed)` before any interaction."
            )
        return key

    @key.setter
    def key(self, value):
        self._key = value

    def _set_seed(self, seed: int | None) -> None:
        import jax

        if seed is None:
            raise Exception("Jumanji requires an integer seed.")
        self.key = jax.random.PRNGKey(seed)
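
    # Jumanji's RNG is stateless: the wrapper keeps a single PRNG key and, at
    # every reset, splits it into one fresh subkey per batched env (see
    # `_reset` below), so seeding once makes whole rollouts reproducible.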

    def read_state(self, state, batch_size=None):
        state_dict = _object_to_tensordict(
            state, self.device, self.batch_size if batch_size is None else batch_size
        )
        return self.state_spec["state"].encode(state_dict)

    def read_obs(self, obs, batch_size=None):
        from jax import numpy as jnp

        if isinstance(obs, (list, jnp.ndarray, np.ndarray)):
            obs_dict = _ndarray_to_tensor(obs).to(self.device)
        else:
            obs_dict = _object_to_tensordict(
                obs, self.device, self.batch_size if batch_size is None else batch_size
            )
        return super().read_obs(obs_dict)
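
    # Observations come in two flavours: plain (jax/numpy) arrays go straight
    # through `_ndarray_to_tensor`, while structured observations (dataclass
    # pytrees) are first unpacked into a TensorDict field by field.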

    def render(
        self,
        tensordict,
        matplotlib_backend: str | None = None,
        as_numpy: bool = False,
        **kwargs,
    ):
        """Renders the environment output given an input tensordict.

        This method is intended to be called by the :class:`~torchrl.record.PixelRenderTransform`
        created whenever `from_pixels=True` is selected.
        To create an appropriate rendering transform, use a similar call as below:

            >>> from torchrl.record import PixelRenderTransform
            >>> matplotlib_backend = None  # Change this value if a specific matplotlib backend has to be used.
            >>> env = env.append_transform(
            ...     PixelRenderTransform(out_keys=["pixels"], pass_tensordict=True, matplotlib_backend=matplotlib_backend)
            ... )

        This pipeline will write a `"pixels"` entry in your output tensordict.

        Args:
            tensordict (TensorDictBase): a tensordict containing a state to represent
            matplotlib_backend (str, optional): the matplotlib backend
            as_numpy (bool, optional): if ``False``, the np.ndarray will be converted to a torch.Tensor.
                Defaults to ``False``.

        """
        import io

        import jax
        import jax.numpy as jnp
        import jumanji

        try:
            import matplotlib
            import matplotlib.pyplot as plt
            import PIL
            import torchvision.transforms.v2.functional
        except ImportError as err:
            raise ImportError(
                "Rendering with Jumanji requires torchvision, matplotlib and PIL to be installed."
            ) from err

        if matplotlib_backend is not None:
            matplotlib.use(matplotlib_backend)

        # Get only one env
        _state_example = self._state_example
        while tensordict.ndim:
            tensordict = tensordict[0]
            _state_example = jax.tree_util.tree_map(
                lambda x: jnp.take(x, 0, axis=0), _state_example
            )
        # Patch jumanji is_notebook
        is_notebook = jumanji.environments.is_notebook
        try:
            jumanji.environments.is_notebook = lambda: False

            isinteractive = plt.isinteractive()
            plt.ion()
            buf = io.BytesIO()
            state = _tensordict_to_object(
                tensordict.get("state"),
                _state_example,
                batch_size=tensordict.batch_size if not self.batch_locked else None,
            )
            self._env.render(state, **kwargs)
            plt.savefig(buf, format="png")
            buf.seek(0)
            # Load the image into a PIL object.
            img = PIL.Image.open(buf)
            img_array = torchvision.transforms.v2.functional.pil_to_tensor(img)
            if not isinteractive:
                plt.ioff()
                plt.close()
            if not as_numpy:
                return img_array[:3]
            return img_array[:3].numpy().copy()
        finally:
            jumanji.environments.is_notebook = is_notebook

    def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
        import jax

        if self.batch_locked:
            batch_size = self.batch_size
        else:
            batch_size = tensordict.batch_size

        # prepare inputs
        state = _tensordict_to_object(
            tensordict.get("state"),
            self._state_example,
            batch_size=tensordict.batch_size if not self.batch_locked else None,
        )
        action = self.read_action(tensordict.get("action"))

        # flatten batch size into vector
        state = _tree_flatten(state, batch_size)
        action = _tree_flatten(action, batch_size)

        # jax vectorizing map on env.step
        state, timestep = jax.vmap(self._env_step)(state, action)

        # reshape batch size from vector
        state = _tree_reshape(state, batch_size)
        timestep = _tree_reshape(timestep, batch_size)

        # collect outputs
        state_dict = self.read_state(state, batch_size=batch_size)
        obs_dict = self.read_obs(timestep.observation, batch_size=batch_size)
        reward = self.read_reward(np.asarray(timestep.reward))
        done = timestep.step_type == self.lib.types.StepType.LAST
        done = _ndarray_to_tensor(done).view(torch.bool).to(self.device)

        # build results
        tensordict_out = TensorDict(
            source=obs_dict,
            batch_size=tensordict.batch_size,
            device=self.device,
        )
        tensordict_out.set("reward", reward)
        tensordict_out.set("done", done)
        tensordict_out.set("terminated", done)
        tensordict_out["state"] = state_dict

        return tensordict_out
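
    # `jax.vmap` only maps over a single leading axis, so multi-dim batch
    # shapes are flattened to a vector first and restored afterwards. The same
    # dance in isolation (a sketch, assuming jax is installed):
    #
    #     >>> import jax, jax.numpy as jnp
    #     >>> x = jnp.ones((2, 5, 3))          # batch_size=(2, 5), feature=3
    #     >>> flat = x.reshape(-1, 3)          # _tree_flatten: (10, 3)
    #     >>> y = jax.vmap(jnp.sum)(flat)      # vectorized over the 10 envs
    #     >>> y.reshape(2, 5).shape            # _tree_reshape
    #     (2, 5)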

    def _reset(
        self, tensordict: TensorDictBase | None = None, **kwargs
    ) -> TensorDictBase:
        import jax
        from jax import numpy as jnp

        if self.batch_locked or tensordict is None:
            numel = self.numel()
            batch_size = self.batch_size
        else:
            numel = tensordict.numel()
            batch_size = tensordict.batch_size

        # generate random keys
        self.key, *keys = jax.random.split(self.key, numel + 1)

        # jax vectorizing map on env.reset
        state, timestep = jax.vmap(self._env_reset)(jnp.stack(keys))

        # reshape batch size from vector
        state = _tree_reshape(state, batch_size)
        timestep = _tree_reshape(timestep, batch_size)

        # collect outputs
        state_dict = self.read_state(state, batch_size=batch_size)
        obs_dict = self.read_obs(timestep.observation, batch_size=batch_size)
        if not self.batch_locked:
            done_td = self.full_done_spec.zero(batch_size)
        else:
            done_td = self.full_done_spec.zero()

        # build results
        tensordict_out = TensorDict(
            source=obs_dict,
            batch_size=batch_size,
            device=self.device,
        )
        tensordict_out.update(done_td)
        tensordict_out["state"] = state_dict

        return tensordict_out

    def read_reward(self, reward):
        """Reads the reward and maps it to the reward space.

        Args:
            reward (torch.Tensor or TensorDict): reward to be mapped.

        """
        if isinstance(reward, int) and reward == 0:
            return self.reward_spec.zero()
        if self.batch_locked:
            reward = self.reward_spec.encode(reward, ignore_device=True)
        else:
            reward = torch.as_tensor(reward)
            if not reward.ndim or (reward.shape[-1] != self.reward_spec.shape[-1]):
                reward = reward.unsqueeze(-1)

        if reward is None:
            reward = torch.tensor(np.nan).expand(self.reward_spec.shape)

        return reward

    def _output_transform(self, step_outputs_tuple: tuple) -> tuple:
        ...

    def _reset_output_transform(self, reset_outputs_tuple: tuple) -> tuple:
        ...


class JumanjiEnv(JumanjiWrapper):
    """Jumanji environment wrapper built with the environment name.

    Jumanji offers a vectorized simulation framework based on Jax.
    TorchRL's wrapper incurs some overhead for the jax-to-torch conversion,
    but computational graphs can still be built on top of the simulated trajectories,
    allowing for backpropagation through the rollout.

    GitHub: https://github.com/instadeepai/jumanji

    Doc: https://instadeepai.github.io/jumanji/

    Paper: https://arxiv.org/abs/2306.09884

    Args:
        env_name (str): the name of the environment to wrap. Must be part of :attr:`~.available_envs`.
        categorical_action_encoding (bool, optional): if ``True``, categorical
            specs will be converted to the TorchRL equivalent (:class:`torchrl.data.Categorical`),
            otherwise a one-hot encoding will be used (:class:`torchrl.data.OneHot`).
            Defaults to ``True``.

    Keyword Args:
        from_pixels (bool, optional): Not yet supported.
        frame_skip (int, optional): if provided, indicates for how many steps the
            same action is to be repeated. The observation returned will be the
            last observation of the sequence, whereas the reward will be the sum
            of rewards across steps.
        device (torch.device, optional): if provided, the device on which the data
            is to be cast. Defaults to ``torch.device("cpu")``.
        batch_size (torch.Size, optional): the batch size of the environment.
            With ``jumanji``, this indicates the number of vectorized environments.
            Defaults to ``torch.Size([])``.
        allow_done_after_reset (bool, optional): if ``True``, it is tolerated
            for envs to be ``done`` just after :meth:`reset` is called.
            Defaults to ``False``.

    Attributes:
        available_envs: environments available to build

    Examples:
        >>> from torchrl.envs import JumanjiEnv
        >>> env = JumanjiEnv("Snake-v1")
        >>> env.set_seed(0)
        >>> td = env.reset()
        >>> td["action"] = env.action_spec.rand()
        >>> td = env.step(td)
        >>> print(td)
        TensorDict(
            fields={
                action: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int32, is_shared=False),
                action_mask: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.bool, is_shared=False),
                done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
                grid: Tensor(shape=torch.Size([12, 12, 5]), device=cpu, dtype=torch.float32, is_shared=False),
                next: TensorDict(
                    fields={
                        action_mask: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.bool, is_shared=False),
                        done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
                        grid: Tensor(shape=torch.Size([12, 12, 5]), device=cpu, dtype=torch.float32, is_shared=False),
                        reward: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
                        state: TensorDict(
                            fields={
                                action_mask: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.bool, is_shared=False),
                                body: Tensor(shape=torch.Size([12, 12]), device=cpu, dtype=torch.bool, is_shared=False),
                                body_state: Tensor(shape=torch.Size([12, 12]), device=cpu, dtype=torch.int32, is_shared=False),
                                fruit_position: TensorDict(
                                    fields={
                                        col: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int32, is_shared=False),
                                        row: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int32, is_shared=False)},
                                    batch_size=torch.Size([]),
                                    device=cpu,
                                    is_shared=False),
                                head_position: TensorDict(
                                    fields={
                                        col: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int32, is_shared=False),
                                        row: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int32, is_shared=False)},
                                    batch_size=torch.Size([]),
                                    device=cpu,
                                    is_shared=False),
                                key: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.int32, is_shared=False),
                                length: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int32, is_shared=False),
                                step_count: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int32, is_shared=False),
                                tail: Tensor(shape=torch.Size([12, 12]), device=cpu, dtype=torch.bool, is_shared=False)},
                            batch_size=torch.Size([]),
                            device=cpu,
                            is_shared=False),
                        step_count: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int32, is_shared=False),
                        terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)},
                    batch_size=torch.Size([]),
                    device=cpu,
                    is_shared=False),
                state: TensorDict(
                    fields={
                        action_mask: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.bool, is_shared=False),
                        body: Tensor(shape=torch.Size([12, 12]), device=cpu, dtype=torch.bool, is_shared=False),
                        body_state: Tensor(shape=torch.Size([12, 12]), device=cpu, dtype=torch.int32, is_shared=False),
                        fruit_position: TensorDict(
                            fields={
                                col: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int32, is_shared=False),
                                row: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int32, is_shared=False)},
                            batch_size=torch.Size([]),
                            device=cpu,
                            is_shared=False),
                        head_position: TensorDict(
                            fields={
                                col: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int32, is_shared=False),
                                row: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int32, is_shared=False)},
                            batch_size=torch.Size([]),
                            device=cpu,
                            is_shared=False),
                        key: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.int32, is_shared=False),
                        length: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int32, is_shared=False),
                        step_count: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int32, is_shared=False),
                        tail: Tensor(shape=torch.Size([12, 12]), device=cpu, dtype=torch.bool, is_shared=False)},
                    batch_size=torch.Size([]),
                    device=cpu,
                    is_shared=False),
                step_count: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int32, is_shared=False),
                terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)},
            batch_size=torch.Size([]),
            device=cpu,
            is_shared=False)
        >>> print(env.available_envs)
        ['Game2048-v1',
         'Maze-v0',
         'Cleaner-v0',
         'CVRP-v1',
         'MultiCVRP-v0',
         'Minesweeper-v0',
         'RubiksCube-v0',
         'Knapsack-v1',
         'Sudoku-v0',
         'Snake-v1',
         'TSP-v1',
         'Connector-v2',
         'MMST-v0',
         'GraphColoring-v0',
         'RubiksCube-partly-scrambled-v0',
         'RobotWarehouse-v0',
         'Tetris-v0',
         'BinPack-v2',
         'Sudoku-very-easy-v0',
         'JobShop-v0']

    To take advantage of Jumanji, one usually executes multiple environments at the
    same time.

        >>> from torchrl.envs import JumanjiEnv
        >>> env = JumanjiEnv("Snake-v1", batch_size=[10])
        >>> env.set_seed(0)
        >>> td = env.reset()
        >>> td["action"] = env.action_spec.rand()
        >>> td = env.step(td)

    In the following example, we iteratively test different batch sizes
    and report the execution time for a short rollout:

    Examples:
        >>> from torch.utils.benchmark import Timer
        >>> for batch_size in [4, 16, 128]:
        ...     timer = Timer(
        ...     '''
        ...     env.rollout(100)
        ...     ''',
        ...     setup=f'''
        ...     from torchrl.envs import JumanjiEnv
        ...     env = JumanjiEnv('Snake-v1', batch_size=[{batch_size}])
        ...     env.set_seed(0)
        ...     env.rollout(2)
        ...     ''')
        ...     print(batch_size, timer.timeit(number=10))
        4 <torch.utils.benchmark.utils.common.Measurement object at 0x1fca91910>
        env.rollout(100)
        setup: [...]
          Median: 122.40 ms
          2 measurements, 1 runs per measurement, 1 thread
        16 <torch.utils.benchmark.utils.common.Measurement object at 0x1ff9baee0>
        env.rollout(100)
        setup: [...]
          Median: 134.39 ms
          2 measurements, 1 runs per measurement, 1 thread
        128 <torch.utils.benchmark.utils.common.Measurement object at 0x1ff9ba7c0>
        env.rollout(100)
        setup: [...]
          Median: 172.31 ms
          2 measurements, 1 runs per measurement, 1 thread
    """

    def __init__(self, env_name, **kwargs):
        kwargs["env_name"] = env_name
        super().__init__(**kwargs)

    def _build_env(
        self,
        env_name: str,
        **kwargs,
    ) -> jumanji.env.Environment:  # noqa: F821
        if not _has_jumanji:
            raise ImportError(
                f"jumanji not found, unable to create {env_name}. "
                f"Consider installing jumanji. More info:"
                f" {self.git_url}."
            )
        from_pixels = kwargs.pop("from_pixels", False)
        pixels_only = kwargs.pop("pixels_only", True)
        if kwargs:
            raise ValueError(f"Extra kwargs are not supported by {type(self)}.")
        self.wrapper_frame_skip = 1
        env = self.lib.make(env_name, **kwargs)
        return super()._build_env(env, pixels_only=pixels_only, from_pixels=from_pixels)

    @property
    def env_name(self):
        return self._constructor_kwargs["env_name"]

    def _check_kwargs(self, kwargs: dict):
        if "env_name" not in kwargs:
            raise TypeError("Expected 'env_name' to be part of kwargs")

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(env={self.env_name}, batch_size={self.batch_size}, device={self.device})"