torchrl 0.11.0__cp314-cp314-manylinux_2_28_aarch64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/benchmark_batched_envs.py +104 -0
- benchmarks/conftest.py +91 -0
- benchmarks/ecosystem/gym_env_throughput.py +321 -0
- benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
- benchmarks/requirements.txt +7 -0
- benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
- benchmarks/test_collectors_benchmark.py +240 -0
- benchmarks/test_compressed_storage_benchmark.py +145 -0
- benchmarks/test_envs_benchmark.py +133 -0
- benchmarks/test_llm.py +101 -0
- benchmarks/test_non_tensor_env_benchmark.py +70 -0
- benchmarks/test_objectives_benchmarks.py +1199 -0
- benchmarks/test_replaybuffer_benchmark.py +254 -0
- sota-check/README.md +35 -0
- sota-implementations/README.md +142 -0
- sota-implementations/a2c/README.md +39 -0
- sota-implementations/a2c/a2c_atari.py +291 -0
- sota-implementations/a2c/a2c_mujoco.py +273 -0
- sota-implementations/a2c/utils_atari.py +240 -0
- sota-implementations/a2c/utils_mujoco.py +160 -0
- sota-implementations/bandits/README.md +7 -0
- sota-implementations/bandits/dqn.py +126 -0
- sota-implementations/cql/cql_offline.py +198 -0
- sota-implementations/cql/cql_online.py +249 -0
- sota-implementations/cql/discrete_cql_offline.py +180 -0
- sota-implementations/cql/discrete_cql_online.py +227 -0
- sota-implementations/cql/utils.py +471 -0
- sota-implementations/crossq/crossq.py +271 -0
- sota-implementations/crossq/utils.py +320 -0
- sota-implementations/ddpg/ddpg.py +231 -0
- sota-implementations/ddpg/utils.py +325 -0
- sota-implementations/decision_transformer/dt.py +163 -0
- sota-implementations/decision_transformer/lamb.py +167 -0
- sota-implementations/decision_transformer/online_dt.py +178 -0
- sota-implementations/decision_transformer/utils.py +562 -0
- sota-implementations/discrete_sac/discrete_sac.py +243 -0
- sota-implementations/discrete_sac/utils.py +324 -0
- sota-implementations/dqn/README.md +30 -0
- sota-implementations/dqn/dqn_atari.py +272 -0
- sota-implementations/dqn/dqn_cartpole.py +236 -0
- sota-implementations/dqn/utils_atari.py +132 -0
- sota-implementations/dqn/utils_cartpole.py +90 -0
- sota-implementations/dreamer/README.md +129 -0
- sota-implementations/dreamer/dreamer.py +586 -0
- sota-implementations/dreamer/dreamer_utils.py +1107 -0
- sota-implementations/expert-iteration/README.md +352 -0
- sota-implementations/expert-iteration/ei_utils.py +770 -0
- sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
- sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
- sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
- sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
- sota-implementations/gail/gail.py +327 -0
- sota-implementations/gail/gail_utils.py +68 -0
- sota-implementations/gail/ppo_utils.py +157 -0
- sota-implementations/grpo/README.md +273 -0
- sota-implementations/grpo/grpo-async.py +437 -0
- sota-implementations/grpo/grpo-sync.py +435 -0
- sota-implementations/grpo/grpo_utils.py +843 -0
- sota-implementations/grpo/requirements_gsm8k.txt +11 -0
- sota-implementations/grpo/requirements_ifeval.txt +16 -0
- sota-implementations/impala/README.md +33 -0
- sota-implementations/impala/impala_multi_node_ray.py +292 -0
- sota-implementations/impala/impala_multi_node_submitit.py +284 -0
- sota-implementations/impala/impala_single_node.py +261 -0
- sota-implementations/impala/utils.py +184 -0
- sota-implementations/iql/discrete_iql.py +230 -0
- sota-implementations/iql/iql_offline.py +164 -0
- sota-implementations/iql/iql_online.py +225 -0
- sota-implementations/iql/utils.py +437 -0
- sota-implementations/multiagent/README.md +74 -0
- sota-implementations/multiagent/iql.py +237 -0
- sota-implementations/multiagent/maddpg_iddpg.py +266 -0
- sota-implementations/multiagent/mappo_ippo.py +267 -0
- sota-implementations/multiagent/qmix_vdn.py +271 -0
- sota-implementations/multiagent/sac.py +337 -0
- sota-implementations/multiagent/utils/__init__.py +4 -0
- sota-implementations/multiagent/utils/logging.py +151 -0
- sota-implementations/multiagent/utils/utils.py +43 -0
- sota-implementations/ppo/README.md +29 -0
- sota-implementations/ppo/ppo_atari.py +305 -0
- sota-implementations/ppo/ppo_mujoco.py +293 -0
- sota-implementations/ppo/utils_atari.py +238 -0
- sota-implementations/ppo/utils_mujoco.py +152 -0
- sota-implementations/ppo_trainer/train.py +21 -0
- sota-implementations/redq/README.md +7 -0
- sota-implementations/redq/redq.py +199 -0
- sota-implementations/redq/utils.py +1060 -0
- sota-implementations/sac/sac-async.py +266 -0
- sota-implementations/sac/sac.py +239 -0
- sota-implementations/sac/utils.py +381 -0
- sota-implementations/sac_trainer/train.py +16 -0
- sota-implementations/td3/td3.py +254 -0
- sota-implementations/td3/utils.py +319 -0
- sota-implementations/td3_bc/td3_bc.py +177 -0
- sota-implementations/td3_bc/utils.py +251 -0
- torchrl/__init__.py +144 -0
- torchrl/_extension.py +74 -0
- torchrl/_torchrl.cpython-314-aarch64-linux-gnu.so +0 -0
- torchrl/_utils.py +1431 -0
- torchrl/collectors/__init__.py +48 -0
- torchrl/collectors/_base.py +1058 -0
- torchrl/collectors/_constants.py +88 -0
- torchrl/collectors/_multi_async.py +324 -0
- torchrl/collectors/_multi_base.py +1805 -0
- torchrl/collectors/_multi_sync.py +464 -0
- torchrl/collectors/_runner.py +581 -0
- torchrl/collectors/_single.py +2009 -0
- torchrl/collectors/_single_async.py +259 -0
- torchrl/collectors/collectors.py +62 -0
- torchrl/collectors/distributed/__init__.py +32 -0
- torchrl/collectors/distributed/default_configs.py +133 -0
- torchrl/collectors/distributed/generic.py +1306 -0
- torchrl/collectors/distributed/ray.py +1092 -0
- torchrl/collectors/distributed/rpc.py +1006 -0
- torchrl/collectors/distributed/sync.py +731 -0
- torchrl/collectors/distributed/utils.py +160 -0
- torchrl/collectors/llm/__init__.py +10 -0
- torchrl/collectors/llm/base.py +494 -0
- torchrl/collectors/llm/ray_collector.py +275 -0
- torchrl/collectors/llm/utils.py +36 -0
- torchrl/collectors/llm/weight_update/__init__.py +10 -0
- torchrl/collectors/llm/weight_update/vllm.py +348 -0
- torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
- torchrl/collectors/utils.py +433 -0
- torchrl/collectors/weight_update.py +591 -0
- torchrl/csrc/numpy_utils.h +38 -0
- torchrl/csrc/pybind.cpp +27 -0
- torchrl/csrc/segment_tree.h +458 -0
- torchrl/csrc/torch_utils.h +34 -0
- torchrl/csrc/utils.cpp +48 -0
- torchrl/csrc/utils.h +31 -0
- torchrl/data/__init__.py +187 -0
- torchrl/data/datasets/__init__.py +58 -0
- torchrl/data/datasets/atari_dqn.py +878 -0
- torchrl/data/datasets/common.py +281 -0
- torchrl/data/datasets/d4rl.py +489 -0
- torchrl/data/datasets/d4rl_infos.py +187 -0
- torchrl/data/datasets/gen_dgrl.py +375 -0
- torchrl/data/datasets/minari_data.py +643 -0
- torchrl/data/datasets/openml.py +177 -0
- torchrl/data/datasets/openx.py +798 -0
- torchrl/data/datasets/roboset.py +363 -0
- torchrl/data/datasets/utils.py +11 -0
- torchrl/data/datasets/vd4rl.py +432 -0
- torchrl/data/llm/__init__.py +34 -0
- torchrl/data/llm/dataset.py +491 -0
- torchrl/data/llm/history.py +1378 -0
- torchrl/data/llm/prompt.py +198 -0
- torchrl/data/llm/reward.py +225 -0
- torchrl/data/llm/topk.py +186 -0
- torchrl/data/llm/utils.py +543 -0
- torchrl/data/map/__init__.py +21 -0
- torchrl/data/map/hash.py +185 -0
- torchrl/data/map/query.py +204 -0
- torchrl/data/map/tdstorage.py +363 -0
- torchrl/data/map/tree.py +1434 -0
- torchrl/data/map/utils.py +103 -0
- torchrl/data/postprocs/__init__.py +8 -0
- torchrl/data/postprocs/postprocs.py +391 -0
- torchrl/data/replay_buffers/__init__.py +99 -0
- torchrl/data/replay_buffers/checkpointers.py +622 -0
- torchrl/data/replay_buffers/ray_buffer.py +292 -0
- torchrl/data/replay_buffers/replay_buffers.py +2376 -0
- torchrl/data/replay_buffers/samplers.py +2578 -0
- torchrl/data/replay_buffers/scheduler.py +265 -0
- torchrl/data/replay_buffers/storages.py +2412 -0
- torchrl/data/replay_buffers/utils.py +1042 -0
- torchrl/data/replay_buffers/writers.py +781 -0
- torchrl/data/tensor_specs.py +7101 -0
- torchrl/data/utils.py +334 -0
- torchrl/envs/__init__.py +265 -0
- torchrl/envs/async_envs.py +1105 -0
- torchrl/envs/batched_envs.py +3093 -0
- torchrl/envs/common.py +4241 -0
- torchrl/envs/custom/__init__.py +11 -0
- torchrl/envs/custom/chess.py +617 -0
- torchrl/envs/custom/llm.py +214 -0
- torchrl/envs/custom/pendulum.py +401 -0
- torchrl/envs/custom/san_moves.txt +29274 -0
- torchrl/envs/custom/tictactoeenv.py +288 -0
- torchrl/envs/env_creator.py +263 -0
- torchrl/envs/gym_like.py +752 -0
- torchrl/envs/libs/__init__.py +68 -0
- torchrl/envs/libs/_gym_utils.py +326 -0
- torchrl/envs/libs/brax.py +846 -0
- torchrl/envs/libs/dm_control.py +544 -0
- torchrl/envs/libs/envpool.py +447 -0
- torchrl/envs/libs/gym.py +2239 -0
- torchrl/envs/libs/habitat.py +138 -0
- torchrl/envs/libs/isaac_lab.py +87 -0
- torchrl/envs/libs/isaacgym.py +203 -0
- torchrl/envs/libs/jax_utils.py +166 -0
- torchrl/envs/libs/jumanji.py +963 -0
- torchrl/envs/libs/meltingpot.py +599 -0
- torchrl/envs/libs/openml.py +153 -0
- torchrl/envs/libs/openspiel.py +652 -0
- torchrl/envs/libs/pettingzoo.py +1042 -0
- torchrl/envs/libs/procgen.py +351 -0
- torchrl/envs/libs/robohive.py +429 -0
- torchrl/envs/libs/smacv2.py +645 -0
- torchrl/envs/libs/unity_mlagents.py +891 -0
- torchrl/envs/libs/utils.py +147 -0
- torchrl/envs/libs/vmas.py +813 -0
- torchrl/envs/llm/__init__.py +63 -0
- torchrl/envs/llm/chat.py +730 -0
- torchrl/envs/llm/datasets/README.md +4 -0
- torchrl/envs/llm/datasets/__init__.py +17 -0
- torchrl/envs/llm/datasets/gsm8k.py +353 -0
- torchrl/envs/llm/datasets/ifeval.py +274 -0
- torchrl/envs/llm/envs.py +789 -0
- torchrl/envs/llm/libs/README.md +3 -0
- torchrl/envs/llm/libs/__init__.py +8 -0
- torchrl/envs/llm/libs/mlgym.py +869 -0
- torchrl/envs/llm/reward/__init__.py +10 -0
- torchrl/envs/llm/reward/gsm8k.py +324 -0
- torchrl/envs/llm/reward/ifeval/README.md +13 -0
- torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
- torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
- torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
- torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
- torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
- torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
- torchrl/envs/llm/transforms/__init__.py +55 -0
- torchrl/envs/llm/transforms/browser.py +292 -0
- torchrl/envs/llm/transforms/dataloading.py +859 -0
- torchrl/envs/llm/transforms/format.py +73 -0
- torchrl/envs/llm/transforms/kl.py +1544 -0
- torchrl/envs/llm/transforms/policy_version.py +189 -0
- torchrl/envs/llm/transforms/reason.py +323 -0
- torchrl/envs/llm/transforms/tokenizer.py +321 -0
- torchrl/envs/llm/transforms/tools.py +1955 -0
- torchrl/envs/model_based/__init__.py +9 -0
- torchrl/envs/model_based/common.py +180 -0
- torchrl/envs/model_based/dreamer.py +112 -0
- torchrl/envs/transforms/__init__.py +147 -0
- torchrl/envs/transforms/functional.py +48 -0
- torchrl/envs/transforms/gym_transforms.py +203 -0
- torchrl/envs/transforms/module.py +341 -0
- torchrl/envs/transforms/r3m.py +372 -0
- torchrl/envs/transforms/ray_service.py +663 -0
- torchrl/envs/transforms/rb_transforms.py +214 -0
- torchrl/envs/transforms/transforms.py +11835 -0
- torchrl/envs/transforms/utils.py +94 -0
- torchrl/envs/transforms/vc1.py +307 -0
- torchrl/envs/transforms/vecnorm.py +845 -0
- torchrl/envs/transforms/vip.py +407 -0
- torchrl/envs/utils.py +1718 -0
- torchrl/envs/vec_envs.py +11 -0
- torchrl/modules/__init__.py +206 -0
- torchrl/modules/distributions/__init__.py +73 -0
- torchrl/modules/distributions/continuous.py +830 -0
- torchrl/modules/distributions/discrete.py +908 -0
- torchrl/modules/distributions/truncated_normal.py +187 -0
- torchrl/modules/distributions/utils.py +233 -0
- torchrl/modules/llm/__init__.py +62 -0
- torchrl/modules/llm/backends/__init__.py +65 -0
- torchrl/modules/llm/backends/vllm/__init__.py +94 -0
- torchrl/modules/llm/backends/vllm/_models.py +46 -0
- torchrl/modules/llm/backends/vllm/base.py +72 -0
- torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
- torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
- torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
- torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
- torchrl/modules/llm/policies/__init__.py +28 -0
- torchrl/modules/llm/policies/common.py +1809 -0
- torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
- torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
- torchrl/modules/llm/utils.py +23 -0
- torchrl/modules/mcts/__init__.py +21 -0
- torchrl/modules/mcts/scores.py +579 -0
- torchrl/modules/models/__init__.py +86 -0
- torchrl/modules/models/batchrenorm.py +119 -0
- torchrl/modules/models/decision_transformer.py +179 -0
- torchrl/modules/models/exploration.py +731 -0
- torchrl/modules/models/llm.py +156 -0
- torchrl/modules/models/model_based.py +596 -0
- torchrl/modules/models/models.py +1712 -0
- torchrl/modules/models/multiagent.py +1067 -0
- torchrl/modules/models/recipes/impala.py +185 -0
- torchrl/modules/models/utils.py +162 -0
- torchrl/modules/planners/__init__.py +10 -0
- torchrl/modules/planners/cem.py +228 -0
- torchrl/modules/planners/common.py +73 -0
- torchrl/modules/planners/mppi.py +265 -0
- torchrl/modules/tensordict_module/__init__.py +89 -0
- torchrl/modules/tensordict_module/actors.py +2457 -0
- torchrl/modules/tensordict_module/common.py +529 -0
- torchrl/modules/tensordict_module/exploration.py +814 -0
- torchrl/modules/tensordict_module/probabilistic.py +321 -0
- torchrl/modules/tensordict_module/rnn.py +1639 -0
- torchrl/modules/tensordict_module/sequence.py +132 -0
- torchrl/modules/tensordict_module/world_models.py +34 -0
- torchrl/modules/utils/__init__.py +38 -0
- torchrl/modules/utils/mappings.py +9 -0
- torchrl/modules/utils/utils.py +89 -0
- torchrl/objectives/__init__.py +78 -0
- torchrl/objectives/a2c.py +659 -0
- torchrl/objectives/common.py +753 -0
- torchrl/objectives/cql.py +1346 -0
- torchrl/objectives/crossq.py +710 -0
- torchrl/objectives/ddpg.py +453 -0
- torchrl/objectives/decision_transformer.py +371 -0
- torchrl/objectives/deprecated.py +516 -0
- torchrl/objectives/dqn.py +683 -0
- torchrl/objectives/dreamer.py +488 -0
- torchrl/objectives/functional.py +48 -0
- torchrl/objectives/gail.py +258 -0
- torchrl/objectives/iql.py +996 -0
- torchrl/objectives/llm/__init__.py +30 -0
- torchrl/objectives/llm/grpo.py +846 -0
- torchrl/objectives/llm/sft.py +482 -0
- torchrl/objectives/multiagent/__init__.py +8 -0
- torchrl/objectives/multiagent/qmixer.py +396 -0
- torchrl/objectives/ppo.py +1669 -0
- torchrl/objectives/redq.py +683 -0
- torchrl/objectives/reinforce.py +530 -0
- torchrl/objectives/sac.py +1580 -0
- torchrl/objectives/td3.py +570 -0
- torchrl/objectives/td3_bc.py +625 -0
- torchrl/objectives/utils.py +782 -0
- torchrl/objectives/value/__init__.py +28 -0
- torchrl/objectives/value/advantages.py +1956 -0
- torchrl/objectives/value/functional.py +1459 -0
- torchrl/objectives/value/utils.py +360 -0
- torchrl/record/__init__.py +17 -0
- torchrl/record/loggers/__init__.py +23 -0
- torchrl/record/loggers/common.py +48 -0
- torchrl/record/loggers/csv.py +226 -0
- torchrl/record/loggers/mlflow.py +142 -0
- torchrl/record/loggers/tensorboard.py +139 -0
- torchrl/record/loggers/trackio.py +163 -0
- torchrl/record/loggers/utils.py +78 -0
- torchrl/record/loggers/wandb.py +214 -0
- torchrl/record/recorder.py +554 -0
- torchrl/services/__init__.py +79 -0
- torchrl/services/base.py +109 -0
- torchrl/services/ray_service.py +453 -0
- torchrl/testing/__init__.py +107 -0
- torchrl/testing/assertions.py +179 -0
- torchrl/testing/dist_utils.py +122 -0
- torchrl/testing/env_creators.py +227 -0
- torchrl/testing/env_helper.py +35 -0
- torchrl/testing/gym_helpers.py +156 -0
- torchrl/testing/llm_mocks.py +119 -0
- torchrl/testing/mocking_classes.py +2720 -0
- torchrl/testing/modules.py +295 -0
- torchrl/testing/mp_helpers.py +15 -0
- torchrl/testing/ray_helpers.py +293 -0
- torchrl/testing/utils.py +190 -0
- torchrl/trainers/__init__.py +42 -0
- torchrl/trainers/algorithms/__init__.py +11 -0
- torchrl/trainers/algorithms/configs/__init__.py +705 -0
- torchrl/trainers/algorithms/configs/collectors.py +216 -0
- torchrl/trainers/algorithms/configs/common.py +41 -0
- torchrl/trainers/algorithms/configs/data.py +308 -0
- torchrl/trainers/algorithms/configs/envs.py +104 -0
- torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
- torchrl/trainers/algorithms/configs/logging.py +80 -0
- torchrl/trainers/algorithms/configs/modules.py +570 -0
- torchrl/trainers/algorithms/configs/objectives.py +177 -0
- torchrl/trainers/algorithms/configs/trainers.py +340 -0
- torchrl/trainers/algorithms/configs/transforms.py +955 -0
- torchrl/trainers/algorithms/configs/utils.py +252 -0
- torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
- torchrl/trainers/algorithms/configs/weight_update.py +159 -0
- torchrl/trainers/algorithms/ppo.py +373 -0
- torchrl/trainers/algorithms/sac.py +308 -0
- torchrl/trainers/helpers/__init__.py +40 -0
- torchrl/trainers/helpers/collectors.py +416 -0
- torchrl/trainers/helpers/envs.py +573 -0
- torchrl/trainers/helpers/logger.py +33 -0
- torchrl/trainers/helpers/losses.py +132 -0
- torchrl/trainers/helpers/models.py +658 -0
- torchrl/trainers/helpers/replay_buffer.py +59 -0
- torchrl/trainers/helpers/trainers.py +301 -0
- torchrl/trainers/trainers.py +2052 -0
- torchrl/weight_update/__init__.py +33 -0
- torchrl/weight_update/_distributed.py +749 -0
- torchrl/weight_update/_mp.py +624 -0
- torchrl/weight_update/_noupdate.py +102 -0
- torchrl/weight_update/_ray.py +1032 -0
- torchrl/weight_update/_rpc.py +284 -0
- torchrl/weight_update/_shared.py +891 -0
- torchrl/weight_update/llm/__init__.py +32 -0
- torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
- torchrl/weight_update/llm/vllm_nccl.py +710 -0
- torchrl/weight_update/utils.py +73 -0
- torchrl/weight_update/weight_sync_schemes.py +1244 -0
- torchrl-0.11.0.dist-info/METADATA +1308 -0
- torchrl-0.11.0.dist-info/RECORD +394 -0
- torchrl-0.11.0.dist-info/WHEEL +5 -0
- torchrl-0.11.0.dist-info/entry_points.txt +2 -0
- torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
- torchrl-0.11.0.dist-info/top_level.txt +7 -0
|
@@ -0,0 +1,846 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
#
|
|
3
|
+
# This source code is licensed under the MIT license found in the
|
|
4
|
+
# LICENSE file in the root directory of this source tree.
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
import importlib.util
|
|
8
|
+
import warnings
|
|
9
|
+
|
|
10
|
+
import torch
|
|
11
|
+
from packaging import version
|
|
12
|
+
from tensordict import TensorDict, TensorDictBase
|
|
13
|
+
|
|
14
|
+
from torchrl.data.tensor_specs import Bounded, Composite, Unbounded
|
|
15
|
+
from torchrl.envs.batched_envs import ParallelEnv
|
|
16
|
+
from torchrl.envs.common import _EnvPostInit, _EnvWrapper
|
|
17
|
+
from torchrl.envs.libs.jax_utils import (
|
|
18
|
+
_extract_spec,
|
|
19
|
+
_ndarray_to_tensor,
|
|
20
|
+
_object_to_tensordict,
|
|
21
|
+
_tensor_to_ndarray,
|
|
22
|
+
_tensordict_to_object,
|
|
23
|
+
_tree_flatten,
|
|
24
|
+
_tree_reshape,
|
|
25
|
+
)
|
|
26
|
+
from torchrl.envs.utils import _classproperty
|
|
27
|
+
|
|
28
|
+
_has_brax = importlib.util.find_spec("brax") is not None
|
|
29
|
+
|
|
30
|
+
_DEFAULT_CACHE_CLEAR_FREQUENCY = 20
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def _get_envs():
    """Return the environment names registered in ``brax.envs``.

    Returns:
        list: the keys of ``brax.envs._envs``.

    Raises:
        ImportError: if ``brax`` was not found when this module was loaded.
    """
    if _has_brax:
        # Imported lazily so this module can be loaded without brax installed.
        import brax.envs

        # ``_envs`` is a private attribute of ``brax.envs``; its keys are the
        # returned names.
        return [*brax.envs._envs.keys()]
    raise ImportError("BRAX is not installed in your virtual environment.")
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
class _BraxMeta(_EnvPostInit):
    """Metaclass for BraxEnv that returns a lazy ParallelEnv when num_workers > 1."""

    def __call__(cls, *args, num_workers: int | None = None, **kwargs):
        """Build an instance of ``cls`` or a ``ParallelEnv`` of such instances.

        Args:
            *args: positional constructor arguments. In the multi-worker case
                the first one, if present, is used as the environment name.
            num_workers: requested number of workers; ``None`` means ``1``.
            **kwargs: keyword constructor arguments.

        Returns:
            ``ParallelEnv(num_workers, make_env)`` when the class is named
            ``"BraxEnv"`` and more than one worker is requested; otherwise the
            result of the parent metaclass call with ``*args`` and ``**kwargs``.
        """
        # Always strip a stray "num_workers" entry from kwargs; its value is
        # only used when the explicit keyword was left at None.
        from_kwargs = kwargs.pop("num_workers", 1)
        n_workers = int(from_kwargs if num_workers is None else num_workers)

        is_parallel = cls.__name__ == "BraxEnv" and n_workers > 1
        if not is_parallel:
            return super().__call__(*args, **kwargs)

        # The env name comes from the first positional argument or, failing
        # that, from the "env_name" keyword.
        # NOTE(review): positional arguments after the first are not forwarded
        # to the workers -- confirm this is intended.
        env_name = args[0] if args else kwargs.get("env_name")

        # "env_name" is passed positionally to each worker, so drop it here.
        env_kwargs = dict(kwargs)
        env_kwargs.pop("env_name", None)

        # Factory for a single-worker instance; values are bound as defaults.
        def make_env(_env_name=env_name, _kwargs=env_kwargs):
            return cls(_env_name, num_workers=1, **_kwargs)

        # Lazy ParallelEnv (workers not started yet).
        return ParallelEnv(n_workers, make_env)
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
class BraxWrapper(_EnvWrapper):
|
|
71
|
+
"""Google Brax environment wrapper.
|
|
72
|
+
|
|
73
|
+
Brax offers a vectorized and differentiable simulation framework based on Jax.
|
|
74
|
+
TorchRL's wrapper incurs some overhead for the jax-to-torch conversion,
|
|
75
|
+
but computational graphs can still be built on top of the simulated trajectories,
|
|
76
|
+
allowing for backpropagation through the rollout.
|
|
77
|
+
|
|
78
|
+
GitHub: https://github.com/google/brax
|
|
79
|
+
|
|
80
|
+
Paper: https://arxiv.org/abs/2106.13281
|
|
81
|
+
|
|
82
|
+
Args:
|
|
83
|
+
env (brax.envs.base.PipelineEnv): the environment to wrap.
|
|
84
|
+
categorical_action_encoding (bool, optional): if ``True``, categorical
|
|
85
|
+
specs will be converted to the TorchRL equivalent (:class:`torchrl.data.Categorical`),
|
|
86
|
+
otherwise a one-hot encoding will be used (:class:`torchrl.data.OneHot`).
|
|
87
|
+
Defaults to ``False``.
|
|
88
|
+
cache_clear_frequency (int, optional): automatically clear JAX's internal
|
|
89
|
+
cache every N steps to prevent memory leaks when using ``requires_grad=True``.
|
|
90
|
+
Defaults to ``None``, which uses the default frequency of 20 steps. Pass ``False`` to deactivate automatic cache clearing.
|
|
91
|
+
|
|
92
|
+
Keyword Args:
|
|
93
|
+
from_pixels (bool, optional): Not yet supported.
|
|
94
|
+
frame_skip (int, optional): if provided, indicates for how many steps the
|
|
95
|
+
same action is to be repeated. The observation returned will be the
|
|
96
|
+
last observation of the sequence, whereas the reward will be the sum
|
|
97
|
+
of rewards across steps.
|
|
98
|
+
device (torch.device, optional): if provided, the device on which the data
|
|
99
|
+
is to be cast. Defaults to ``torch.device("cpu")``.
|
|
100
|
+
batch_size (torch.Size, optional): the batch size of the environment.
|
|
101
|
+
In ``brax``, this controls the number of environments simulated in
|
|
102
|
+
parallel via JAX's ``vmap`` on a single device (GPU/TPU). Brax leverages
|
|
103
|
+
MuJoCo XLA (MJX) for hardware-accelerated batched simulation, enabling
|
|
104
|
+
thousands of environments to run in parallel within a single process.
|
|
105
|
+
Defaults to ``torch.Size([])``.
|
|
106
|
+
allow_done_after_reset (bool, optional): if ``True``, it is tolerated
|
|
107
|
+
for envs to be ``done`` just after :meth:`reset` is called.
|
|
108
|
+
Defaults to ``False``.
|
|
109
|
+
|
|
110
|
+
Attributes:
|
|
111
|
+
available_envs: environments available to build
|
|
112
|
+
|
|
113
|
+
Examples:
|
|
114
|
+
>>> import brax.envs
|
|
115
|
+
>>> from torchrl.envs import BraxWrapper
|
|
116
|
+
>>> import torch
|
|
117
|
+
>>> device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
118
|
+
>>> base_env = brax.envs.get_environment("ant")
|
|
119
|
+
>>> env = BraxWrapper(base_env, device=device)
|
|
120
|
+
>>> env.set_seed(0)
|
|
121
|
+
>>> td = env.reset()
|
|
122
|
+
>>> td["action"] = env.action_spec.rand()
|
|
123
|
+
>>> td = env.step(td)
|
|
124
|
+
>>> print(td)
|
|
125
|
+
TensorDict(
|
|
126
|
+
fields={
|
|
127
|
+
action: Tensor(torch.Size([8]), dtype=torch.float32),
|
|
128
|
+
done: Tensor(torch.Size([1]), dtype=torch.bool),
|
|
129
|
+
next: TensorDict(
|
|
130
|
+
fields={
|
|
131
|
+
observation: Tensor(torch.Size([87]), dtype=torch.float32)},
|
|
132
|
+
batch_size=torch.Size([]),
|
|
133
|
+
device=cpu,
|
|
134
|
+
is_shared=False),
|
|
135
|
+
observation: Tensor(torch.Size([87]), dtype=torch.float32),
|
|
136
|
+
reward: Tensor(torch.Size([1]), dtype=torch.float32),
|
|
137
|
+
state: TensorDict(...)},
|
|
138
|
+
batch_size=torch.Size([]),
|
|
139
|
+
device=cpu,
|
|
140
|
+
is_shared=False)
|
|
141
|
+
>>> print(env.available_envs)
|
|
142
|
+
['acrobot', 'ant', 'fast', 'fetch', ...]
|
|
143
|
+
|
|
144
|
+
To take advantage of Brax, one usually executes multiple environments at the
|
|
145
|
+
same time. In the following example, we iteratively test different batch sizes
|
|
146
|
+
and report the execution time for a short rollout:
|
|
147
|
+
|
|
148
|
+
Examples:
|
|
149
|
+
>>> import torch
|
|
150
|
+
>>> from torch.utils.benchmark import Timer
|
|
151
|
+
>>> device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
152
|
+
>>> for batch_size in [4, 16, 128]:
|
|
153
|
+
... timer = Timer('''
|
|
154
|
+
... env.rollout(100)
|
|
155
|
+
... ''',
|
|
156
|
+
... setup=f'''
|
|
157
|
+
... import brax.envs
|
|
158
|
+
... from torchrl.envs import BraxWrapper
|
|
159
|
+
... env = BraxWrapper(brax.envs.get_environment("ant"), batch_size=[{batch_size}], device="{device}")
|
|
160
|
+
... env.set_seed(0)
|
|
161
|
+
... env.rollout(2)
|
|
162
|
+
... ''')
|
|
163
|
+
... print(batch_size, timer.timeit(10))
|
|
164
|
+
4
|
|
165
|
+
env.rollout(100)
|
|
166
|
+
setup: [...]
|
|
167
|
+
310.00 ms
|
|
168
|
+
1 measurement, 10 runs , 1 thread
|
|
169
|
+
|
|
170
|
+
16
|
|
171
|
+
env.rollout(100)
|
|
172
|
+
setup: [...]
|
|
173
|
+
268.46 ms
|
|
174
|
+
1 measurement, 10 runs , 1 thread
|
|
175
|
+
|
|
176
|
+
128
|
|
177
|
+
env.rollout(100)
|
|
178
|
+
setup: [...]
|
|
179
|
+
433.80 ms
|
|
180
|
+
1 measurement, 10 runs , 1 thread
|
|
181
|
+
|
|
182
|
+
One can backpropagate through the rollout and optimize the policy directly:
|
|
183
|
+
|
|
184
|
+
>>> import brax.envs
|
|
185
|
+
>>> from torchrl.envs import BraxWrapper
|
|
186
|
+
>>> from tensordict.nn import TensorDictModule
|
|
187
|
+
>>> from torch import nn
|
|
188
|
+
>>> import torch
|
|
189
|
+
>>>
|
|
190
|
+
>>> env = BraxWrapper(brax.envs.get_environment("ant"), batch_size=[10], requires_grad=True, cache_clear_frequency=100)
|
|
191
|
+
>>> env.set_seed(0)
|
|
192
|
+
>>> torch.manual_seed(0)
|
|
193
|
+
>>> policy = TensorDictModule(nn.Linear(27, 8), in_keys=["observation"], out_keys=["action"])
|
|
194
|
+
>>>
|
|
195
|
+
>>> td = env.rollout(10, policy)
|
|
196
|
+
>>>
|
|
197
|
+
>>> td["next", "reward"].mean().backward(retain_graph=True)
|
|
198
|
+
>>> print(policy.module.weight.grad.norm())
|
|
199
|
+
tensor(213.8605)
|
|
200
|
+
|
|
201
|
+
"""
|
|
202
|
+
|
|
203
|
+
git_url = "https://github.com/google/brax"
|
|
204
|
+
|
|
205
|
+
@_classproperty
def available_envs(cls):
    """Names returned by ``_get_envs()``, or ``[]`` when brax is not installed."""
    return list(_get_envs()) if _has_brax else []
|
|
210
|
+
|
|
211
|
+
libname = "brax"
|
|
212
|
+
|
|
213
|
+
_lib = None
|
|
214
|
+
_jax = None
|
|
215
|
+
|
|
216
|
+
@_classproperty
def lib(cls):
    """The ``brax`` module, imported on first access and cached on ``cls._lib``."""
    module = cls._lib
    if module is None:
        import brax
        import brax.envs  # make sure the ``envs`` submodule is loaded as well

        module = cls._lib = brax
    return module
|
|
226
|
+
|
|
227
|
+
@_classproperty
def jax(cls):
    """The ``jax`` module, imported on first access and cached on ``cls._jax``."""
    module = cls._jax
    if module is None:
        import jax

        module = cls._jax = jax
    return module
|
|
236
|
+
|
|
237
|
+
def __init__(
    self,
    env=None,
    categorical_action_encoding=False,
    cache_clear_frequency: int | None = None,
    **kwargs,
):
    """Store the wrapper options, then delegate to the parent constructor.

    Args:
        env: the environment to wrap. When not ``None`` it is placed under
            ``kwargs["env"]`` before the parent constructor is called.
        categorical_action_encoding: stored on the instance as
            ``_categorical_action_encoding``.
        cache_clear_frequency: ``None`` or ``True`` select the default
            frequency (``_DEFAULT_CACHE_CLEAR_FREQUENCY``); ``False`` or ``0``
            deactivate automatic cache clearing; any other value is stored
            as given.
        **kwargs: forwarded to the parent constructor.
    """
    if env is not None:
        kwargs["env"] = env
    self._seed_calls_reset = None
    self._categorical_action_encoding = categorical_action_encoding
    # The None/True/False sentinels are matched by identity. A membership
    # test such as ``x in (None, True)`` compares by equality, and because
    # ``1 == True`` an explicit frequency of 1 would be replaced by the
    # default.
    if cache_clear_frequency is None or cache_clear_frequency is True:
        self._cache_clear_frequency = _DEFAULT_CACHE_CLEAR_FREQUENCY
    elif cache_clear_frequency is False or cache_clear_frequency == 0:
        # False (and 0, which compares equal to it) turns clearing off.
        self._cache_clear_frequency = False
    else:
        self._cache_clear_frequency = cache_clear_frequency
    self._step_count = 0
    super().__init__(**kwargs)
    if not self.device:
        warnings.warn(
            f"No device is set for env {self}. "
            f"Setting a device in Brax wrapped environments is strongly recommended."
        )
|
|
262
|
+
|
|
263
|
+
def _check_kwargs(self, kwargs: dict):
    """Validate the installed brax version and the ``"env"`` entry of ``kwargs``.

    Args:
        kwargs: constructor keyword arguments; must contain ``"env"``.

    Raises:
        ImportError: if the installed brax version is lower than 0.10.4.
        TypeError: if ``"env"`` is missing from ``kwargs`` or is not an
            instance of ``brax.envs.Env``.
    """
    brax_lib = self.lib
    installed = version.parse(brax_lib.__version__)
    minimum = version.parse("0.10.4")
    if installed < minimum:
        raise ImportError("Brax v0.10.4 or greater is required.")

    if "env" not in kwargs:
        raise TypeError("Could not find environment key 'env' in kwargs.")
    if not isinstance(kwargs["env"], brax_lib.envs.Env):
        raise TypeError("env is not of type 'brax.envs.Env'.")
|
|
273
|
+
|
|
274
|
+
def _build_env(
|
|
275
|
+
self,
|
|
276
|
+
env,
|
|
277
|
+
_seed: int | None = None,
|
|
278
|
+
from_pixels: bool = False,
|
|
279
|
+
render_kwargs: dict | None = None,
|
|
280
|
+
pixels_only: bool = False,
|
|
281
|
+
requires_grad: bool = False,
|
|
282
|
+
camera_id: int | str = 0,
|
|
283
|
+
**kwargs,
|
|
284
|
+
):
|
|
285
|
+
self.from_pixels = from_pixels
|
|
286
|
+
self.pixels_only = pixels_only
|
|
287
|
+
self.requires_grad = requires_grad
|
|
288
|
+
|
|
289
|
+
if from_pixels:
|
|
290
|
+
raise NotImplementedError(
|
|
291
|
+
"from_pixels=True is not yest supported within BraxWrapper"
|
|
292
|
+
)
|
|
293
|
+
return env
|
|
294
|
+
|
|
295
|
+
def _make_state_spec(self, env: brax.envs.env.Env):  # noqa: F821
    """Derive the state spec from a single reset of ``env``.

    ``env`` is reset with the fixed key ``PRNGKey(0)``. The resulting state
    is converted to a tensordict with an empty batch size on ``self.device``,
    turned into a spec with ``_extract_spec`` and expanded to
    ``self.batch_size``.
    """
    probe_state = env.reset(self.jax.random.PRNGKey(0))
    unbatched = _object_to_tensordict(probe_state, self.device, batch_size=())
    return _extract_spec(unbatched).expand(self.batch_size)
|
|
303
|
+
|
|
304
|
+
def _make_specs(self, env: brax.envs.env.Env) -> None:  # noqa: F821
    """Populate the action, reward, observation and state specs.

    Every spec is created on ``self.device`` with ``self.batch_size`` as
    leading dimensions:

    - action: ``Bounded`` in ``[-1, 1]`` with trailing size ``env.action_size``;
    - reward: ``Unbounded`` with trailing size ``1``;
    - observation: ``Composite`` holding an ``Unbounded`` entry
      ``"observation"`` with trailing size ``env.observation_size``.

    The spec from ``_make_state_spec`` is stored under ``"state"`` in
    ``state_spec``, and a clone of it under ``"state"`` in ``observation_spec``.
    """
    batch_size = self.batch_size
    device = self.device

    self.action_spec = Bounded(
        low=-1, high=1, shape=(*batch_size, env.action_size), device=device
    )
    self.reward_spec = Unbounded(shape=[*batch_size, 1], device=device)

    obs_leaf = Unbounded(shape=(*batch_size, env.observation_size), device=device)
    self.observation_spec = Composite(observation=obs_leaf, shape=batch_size)

    # The state spec is extracted from an actual reset of the env instance.
    extracted = self._make_state_spec(env)
    self.state_spec["state"] = extracted
    self.observation_spec["state"] = extracted.clone()
|
|
335
|
+
|
|
336
|
+
def _make_state_example(self):
    """Build a template Brax state shaped like ``self.batch_size``.

    One PRNG key per batch element is split from ``PRNGKey(0)``, the
    vmapped/jitted reset is called on the stacked keys, and the result is
    reshaped to the batch size.

    Returns:
        The reshaped state object produced by the reset.
    """
    jax = self.jax
    n_envs = self.batch_size.numel()
    per_env_keys = jax.random.split(jax.random.PRNGKey(0), n_envs)
    flat_state = self._vmap_jit_env_reset(jax.numpy.stack(per_env_keys))
    return _tree_reshape(flat_state, self.batch_size)
|
|
344
|
+
|
|
345
|
+
def _init_env(self) -> int | None:
|
|
346
|
+
jax = self.jax
|
|
347
|
+
self._key = None
|
|
348
|
+
self._vmap_jit_env_reset = jax.vmap(jax.jit(self._env.reset))
|
|
349
|
+
self._vmap_jit_env_step = jax.vmap(jax.jit(self._env.step))
|
|
350
|
+
self._state_example = self._make_state_example()
|
|
351
|
+
|
|
352
|
+
def _set_seed(self, seed: int | None) -> None:
|
|
353
|
+
jax = self.jax
|
|
354
|
+
if seed is None:
|
|
355
|
+
raise Exception("Brax requires an integer seed.")
|
|
356
|
+
self._key = jax.random.PRNGKey(seed)
|
|
357
|
+
|
|
358
|
+
def _reset(self, tensordict: TensorDictBase = None, **kwargs) -> TensorDictBase:
    """Reset every batched environment and return the initial tensordict.

    Args:
        tensordict: not read by this implementation.
        **kwargs: not read by this implementation.

    Returns:
        A tensordict with ``"observation"``, ``"done"``, ``"terminated"`` and
        the full Brax ``"state"``; no ``"reward"`` entry is placed at the root.
    """
    jax = self.jax

    # Lazily create a PRNG key when none was set: use self._seed if present,
    # otherwise fall back to seed 0.
    if getattr(self, "_key", None) is None:
        fallback_seed = getattr(self, "_seed", None)
        self._key = jax.random.PRNGKey(
            0 if fallback_seed is None else int(fallback_seed)
        )

    # The first split key becomes the new carry key; the rest go one per env.
    self._key, *env_keys = jax.random.split(self._key, 1 + self.numel())

    # Vectorized, jitted reset over the stacked keys.
    raw_state = self._vmap_jit_env_reset(jax.numpy.stack(env_keys))

    # Restore the batch shape, then convert to a tensordict.
    state_td = _object_to_tensordict(
        _tree_reshape(raw_state, self.batch_size), self.device, self.batch_size
    )

    # Reward and done are viewed with the reward spec's shape.
    state_td["reward"] = state_td.get("reward").view(*self.reward_spec.shape)
    state_td["done"] = state_td.get("done").view(*self.reward_spec.shape)
    done_flag = state_td["done"].bool()

    root_fields = {
        "observation": state_td.get("obs"),
        "done": done_flag,
        "terminated": done_flag.clone(),
        "state": state_td,
    }
    return TensorDict._new_unsafe(
        source=root_fields,
        batch_size=self.batch_size,
        device=self.device,
    )
|
|
395
|
+
|
|
396
|
+
def _step_without_grad(self, tensordict: TensorDictBase):
    """Advance the batched environments by one step without autograd support.

    Args:
        tensordict: must provide ``"state"`` and ``"action"`` entries.

    Returns:
        A tensordict with ``"observation"``, ``"reward"``, ``"done"``,
        ``"terminated"`` and the full next ``"state"``.
    """
    # Torch -> JAX conversion of the state (rebuilt against the template) and action.
    brax_state = _tensordict_to_object(tensordict.get("state"), self._state_example)
    brax_action = _tensor_to_ndarray(tensordict.get("action"))

    # Collapse the batch dimensions before calling the vmapped step.
    flat_state = _tree_flatten(brax_state, self.batch_size)
    flat_action = _tree_flatten(brax_action, self.batch_size)

    stepped = self._vmap_jit_env_step(flat_state, flat_action)

    # Restore the batch shape and convert back to tensors.
    stepped = _tree_reshape(stepped, self.batch_size)
    next_td = _object_to_tensordict(stepped, self.device, self.batch_size)

    # Reward and done are viewed with the reward spec's shape.
    next_td.set("reward", next_td.get("reward").view(self.reward_spec.shape))
    next_td.set("done", next_td.get("done").view(self.reward_spec.shape))
    done_flag = next_td["done"].bool()

    root_fields = {
        "observation": next_td.get("obs"),
        "reward": next_td["reward"],
        "done": done_flag,
        "terminated": done_flag.clone(),
        "state": next_td,
    }
    return TensorDict._new_unsafe(
        source=root_fields,
        batch_size=self.batch_size,
        device=self.device,
    )
|
|
430
|
+
|
|
431
|
+
def _step_with_grad(self, tensordict: TensorDictBase):
    """Advance the batched environments by one step through ``_BraxEnvStep``.

    The step is routed through the ``_BraxEnvStep`` autograd function so that
    the returned observation, reward and pipeline-state tensors come from
    ``_BraxEnvStep.apply``.

    Args:
        tensordict: must provide ``"action"`` and ``"state"``; the state must
            contain a ``"pipeline_state"`` sub-tensordict.

    Returns:
        A tensordict with ``"observation"``, ``"reward"``, ``"done"``,
        ``"terminated"`` and the full next ``"state"``.
    """
    action = tensordict.get("action")
    state = tensordict.get("state")
    # Split the pipeline state into parallel key/value tuples: the values are
    # passed positionally to the autograd function, the keys are used to put
    # the returned values back in place.
    qp_keys, qp_values = zip(*state.get("pipeline_state").items())

    # call env step with autograd function
    next_state_nograd, next_obs, next_reward, *next_qp_values = _BraxEnvStep.apply(
        self, state, action, *qp_values
    )

    # extract done values: we assume a shape identical to reward
    next_done = next_state_nograd.get("done").view(*self.reward_spec.shape)
    next_reward = next_reward.view(*self.reward_spec.shape)

    # Overwrite the entries of the returned state with the tensors returned
    # individually by the autograd function.
    next_state = next_state_nograd
    next_state["obs"] = next_obs
    next_state.set("reward", next_reward)
    next_state.set("done", next_done)
    next_done = next_done.bool()
    next_state.get("pipeline_state").update(dict(zip(qp_keys, next_qp_values)))

    # build result
    tensordict_out = TensorDict._new_unsafe(
        source={
            "observation": next_obs,
            "reward": next_reward,
            "done": next_done,
            # Clone so "done" and "terminated" do not share the same tensor,
            # as done in _reset and _step_without_grad.
            "terminated": next_done.clone(),
            "state": next_state,
        },
        batch_size=self.batch_size,
        device=self.device,
    )
    return tensordict_out
|
|
468
|
+
|
|
469
|
+
def _step(
|
|
470
|
+
self,
|
|
471
|
+
tensordict: TensorDictBase,
|
|
472
|
+
) -> TensorDictBase:
|
|
473
|
+
|
|
474
|
+
if self.requires_grad:
|
|
475
|
+
out = self._step_with_grad(tensordict)
|
|
476
|
+
else:
|
|
477
|
+
out = self._step_without_grad(tensordict)
|
|
478
|
+
|
|
479
|
+
self._step_count += 1
|
|
480
|
+
if (
|
|
481
|
+
self._cache_clear_frequency
|
|
482
|
+
and (self._step_count % self._cache_clear_frequency) == 0
|
|
483
|
+
):
|
|
484
|
+
self.clear_cache()
|
|
485
|
+
|
|
486
|
+
return out
|
|
487
|
+
|
|
488
|
+
def clear_cache(self):
    """Best-effort clearing of JAX caches.

    Intended to be called periodically when using ``requires_grad=True`` to
    limit memory accumulation. Does nothing when the instance has no ``jax``
    attribute; any exception raised while clearing is swallowed.
    """
    if not hasattr(self, "jax"):
        return

    try:
        # Call ``clear_caches`` on ``jax.jit`` and then on ``jax`` itself,
        # whenever the attribute is present.
        for holder in (self.jax.jit, self.jax):
            if hasattr(holder, "clear_caches"):
                holder.clear_caches()

        # Also try the jaxlib extension module; failures here are swallowed
        # separately.
        try:
            import jaxlib

            if hasattr(jaxlib, "xla_extension"):
                jaxlib.xla_extension.clear_caches()
        except Exception:
            pass
    except Exception:
        pass
|
|
512
|
+
|
|
513
|
+
|
|
514
|
+
class BraxEnv(BraxWrapper, metaclass=_BraxMeta):
    """Google Brax environment wrapper built with the environment name.

    Brax offers a vectorized and differentiable simulation framework based on Jax.
    TorchRL's wrapper incurs some overhead for the jax-to-torch conversion,
    but computational graphs can still be built on top of the simulated trajectories,
    allowing for backpropagation through the rollout.

    GitHub: https://github.com/google/brax

    Paper: https://arxiv.org/abs/2106.13281

    Args:
        env_name (str): the environment name of the env to wrap. Must be part of
            :attr:`~.available_envs`.
        categorical_action_encoding (bool, optional): if ``True``, categorical
            specs will be converted to the TorchRL equivalent (:class:`torchrl.data.Categorical`),
            otherwise a one-hot encoding will be used (:class:`torchrl.data.OneHot`).
            Defaults to ``False``.
        cache_clear_frequency (int, optional): automatically clear JAX's internal
            cache every N steps to prevent memory leaks when using ``requires_grad=True``.
            Defaults to ``False`` (deactivates automatic cache clearing).

    Keyword Args:
        from_pixels (bool, optional): Not yet supported.
        frame_skip (int, optional): if provided, indicates for how many steps the
            same action is to be repeated. The observation returned will be the
            last observation of the sequence, whereas the reward will be the sum
            of rewards across steps.
        device (torch.device, optional): if provided, the device on which the data
            is to be cast. Defaults to ``torch.device("cpu")``.
        batch_size (torch.Size, optional): the batch size of the environment.
            In ``brax``, this controls the number of environments simulated in
            parallel via JAX's ``vmap`` on a single device (GPU/TPU). Brax leverages
            MuJoCo XLA (MJX) for hardware-accelerated batched simulation, enabling
            thousands of environments to run in parallel within a single process.
            Defaults to ``torch.Size([])``.
        allow_done_after_reset (bool, optional): if ``True``, it is tolerated
            for envs to be ``done`` just after :meth:`reset` is called.
            Defaults to ``False``.
        num_workers (int, optional): if greater than 1, a lazy :class:`~torchrl.envs.ParallelEnv`
            will be returned instead, with each worker instantiating its own
            :class:`~torchrl.envs.BraxEnv` instance. Defaults to ``None``.

    .. note::
        There are two orthogonal ways to scale environment throughput:

        - **batch_size**: Uses Brax's native JAX-based vectorization (``vmap``) to run
          multiple environments in parallel on a single GPU/TPU. This is highly efficient
          for moderate batch sizes where the MJX solver has not yet saturated.
        - **num_workers**: Uses TorchRL's :class:`~torchrl.envs.ParallelEnv` to spawn
          multiple Python processes, each running its own ``BraxEnv``.

        These can be combined: ``BraxEnv("ant", batch_size=[128], num_workers=4)`` creates
        4 worker processes, each running 128 vectorized environments, for a total of 512
        parallel environments. This hybrid approach can be beneficial when the MJX solver
        saturates on a single device, or when distributing across multiple GPUs/CPUs.

    Attributes:
        available_envs: environments available to build

    Examples:
        >>> from torchrl.envs import BraxEnv
        >>> import torch
        >>> device = "cuda" if torch.cuda.is_available() else "cpu"
        >>> env = BraxEnv("ant", device=device)
        >>> env.set_seed(0)
        >>> td = env.reset()
        >>> td["action"] = env.action_spec.rand()
        >>> td = env.step(td)
        >>> print(td)
        TensorDict(
            fields={
                action: Tensor(torch.Size([8]), dtype=torch.float32),
                done: Tensor(torch.Size([1]), dtype=torch.bool),
                next: TensorDict(
                    fields={
                        observation: Tensor(torch.Size([87]), dtype=torch.float32)},
                    batch_size=torch.Size([]),
                    device=cpu,
                    is_shared=False),
                observation: Tensor(torch.Size([87]), dtype=torch.float32),
                reward: Tensor(torch.Size([1]), dtype=torch.float32),
                state: TensorDict(...)},
            batch_size=torch.Size([]),
            device=cpu,
            is_shared=False)
        >>> print(env.available_envs)
        ['acrobot', 'ant', 'fast', 'fetch', ...]

        # Example: create a parallel environment with 4 workers. This returns a lazy
        # ParallelEnv; each worker will instantiate a BraxEnv with num_workers=1.
        >>> from torchrl.envs import BraxEnv
        >>> par_env = BraxEnv("ant", batch_size=[8], num_workers=4, device="cpu")
        >>> # par_env is a ParallelEnv; start interacting as usual
        >>> par_env.set_seed(0)
        >>> td = par_env.reset()
        >>> print(td.shape)
        torch.Size([4, 8])
        >>> td["action"] = par_env.action_spec.rand()
        >>> td = par_env.step(td)

    To take advantage of Brax, one usually executes multiple environments at the
    same time. In the following example, we iteratively test different batch sizes
    and report the execution time for a short rollout:

    Examples:
        >>> import torch
        >>> from torch.utils.benchmark import Timer
        >>> device = "cuda" if torch.cuda.is_available() else "cpu"
        >>> for batch_size in [4, 16, 128]:
        ...     timer = Timer('''
        ... env.rollout(100)
        ... ''',
        ...     setup=f'''
        ... from torchrl.envs import BraxEnv
        ... env = BraxEnv("ant", batch_size=[{batch_size}], device="{device}")
        ... env.set_seed(0)
        ... env.rollout(2)
        ... ''')
        ...     print(batch_size, timer.timeit(10))
        4
        env.rollout(100)
        setup: [...]
        310.00 ms
        1 measurement, 10 runs , 1 thread

        16
        env.rollout(100)
        setup: [...]
        268.46 ms
        1 measurement, 10 runs , 1 thread

        128
        env.rollout(100)
        setup: [...]
        433.80 ms
        1 measurement, 10 runs , 1 thread

    One can backpropagate through the rollout and optimize the policy directly:

        >>> from torchrl.envs import BraxEnv
        >>> from tensordict.nn import TensorDictModule
        >>> from torch import nn
        >>> import torch
        >>>
        >>> env = BraxEnv("ant", batch_size=[10], requires_grad=True, cache_clear_frequency=100)
        >>> env.set_seed(0)
        >>> torch.manual_seed(0)
        >>> policy = TensorDictModule(nn.Linear(27, 8), in_keys=["observation"], out_keys=["action"])
        >>>
        >>> td = env.rollout(10, policy)
        >>>
        >>> td["next", "reward"].mean().backward(retain_graph=True)
        >>> print(policy.module.weight.grad.norm())
        tensor(213.8605)

    """

    def __init__(self, env_name, **kwargs):
        # The parent constructor receives the name through kwargs.
        kwargs["env_name"] = env_name
        super().__init__(**kwargs)

    def _build_env(
        self,
        env_name: str,
        **kwargs,
    ) -> brax.envs.env.Env:  # noqa: F821
        """Instantiate the named Brax environment and hand it to the wrapper.

        Args:
            env_name: name passed to ``brax.envs.get_environment``.
            **kwargs: only ``from_pixels``, ``pixels_only``, ``requires_grad``
                and ``cache_clear_frequency`` are accepted.

        Raises:
            ImportError: if brax is not available.
            ValueError: if any other keyword argument is passed.
        """
        if not _has_brax:
            raise ImportError(
                f"brax not found, unable to create {env_name}. "
                f"Consider downloading and installing brax from"
                f" {self.git_url}"
            )
        from_pixels = kwargs.pop("from_pixels", False)
        pixels_only = kwargs.pop("pixels_only", True)
        requires_grad = kwargs.pop("requires_grad", False)
        cache_clear_frequency = kwargs.pop("cache_clear_frequency", False)
        if kwargs:
            # Keep the historical message as a prefix and name the offenders.
            raise ValueError(
                "kwargs not supported. "
                f"Got unexpected keyword arguments: {sorted(kwargs)}"
            )
        self.wrapper_frame_skip = 1
        # kwargs is necessarily empty at this point, so nothing is forwarded.
        env = self.lib.envs.get_environment(env_name)
        return super()._build_env(
            env,
            pixels_only=pixels_only,
            from_pixels=from_pixels,
            requires_grad=requires_grad,
            cache_clear_frequency=cache_clear_frequency,
        )

    @property
    def env_name(self):
        """The environment name stored in the constructor kwargs."""
        return self._constructor_kwargs["env_name"]

    def _check_kwargs(self, kwargs: dict):
        """Raise ``TypeError`` unless ``kwargs`` contains ``"env_name"``."""
        if "env_name" not in kwargs:
            raise TypeError("Expected 'env_name' to be part of kwargs")

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(env={self.env_name}, batch_size={self.batch_size}, device={self.device})"
|
|
714
|
+
|
|
715
|
+
|
|
716
|
+
class _BraxEnvStep(torch.autograd.Function):
|
|
717
|
+
@staticmethod
|
|
718
|
+
def forward(ctx, env: BraxWrapper, state_td, action_tensor, *qp_values):
|
|
719
|
+
import jax
|
|
720
|
+
|
|
721
|
+
# convert tensors to ndarrays
|
|
722
|
+
state_obj = _tensordict_to_object(state_td, env._state_example)
|
|
723
|
+
action_nd = _tensor_to_ndarray(action_tensor)
|
|
724
|
+
|
|
725
|
+
# flatten batch size
|
|
726
|
+
state = _tree_flatten(state_obj, env.batch_size)
|
|
727
|
+
action = _tree_flatten(action_nd, env.batch_size)
|
|
728
|
+
|
|
729
|
+
# call vjp with jit and vmap
|
|
730
|
+
next_state, vjp_fn = jax.vjp(env._vmap_jit_env_step, state, action)
|
|
731
|
+
|
|
732
|
+
# reshape batch size
|
|
733
|
+
next_state_reshape = _tree_reshape(next_state, env.batch_size)
|
|
734
|
+
|
|
735
|
+
# convert ndarrays to tensors
|
|
736
|
+
next_state_tensor = _object_to_tensordict(
|
|
737
|
+
next_state_reshape, device=env.device, batch_size=env.batch_size
|
|
738
|
+
)
|
|
739
|
+
|
|
740
|
+
# save context
|
|
741
|
+
ctx.vjp_fn = vjp_fn
|
|
742
|
+
ctx.next_state = next_state_tensor
|
|
743
|
+
ctx.env = env
|
|
744
|
+
# Mark that backward hasn't been called yet
|
|
745
|
+
ctx._backward_called = False
|
|
746
|
+
|
|
747
|
+
return (
|
|
748
|
+
next_state_tensor, # no gradient
|
|
749
|
+
next_state_tensor["obs"],
|
|
750
|
+
next_state_tensor["reward"],
|
|
751
|
+
*next_state_tensor["pipeline_state"].values(),
|
|
752
|
+
)
|
|
753
|
+
|
|
754
|
+
@staticmethod
|
|
755
|
+
def backward(ctx, _, grad_next_obs, grad_next_reward, *grad_next_qp_values):
|
|
756
|
+
# Prevent multiple backward calls on the same context
|
|
757
|
+
if hasattr(ctx, "_backward_called") and ctx._backward_called:
|
|
758
|
+
return (None, None, *([None] * len(grad_next_qp_values)))
|
|
759
|
+
|
|
760
|
+
ctx._backward_called = True
|
|
761
|
+
|
|
762
|
+
pipeline_state = dict(
|
|
763
|
+
zip(ctx.next_state.get("pipeline_state").keys(), grad_next_qp_values)
|
|
764
|
+
)
|
|
765
|
+
none_keys = []
|
|
766
|
+
|
|
767
|
+
def _make_none(key, val):
|
|
768
|
+
if val is not None:
|
|
769
|
+
return val
|
|
770
|
+
none_keys.append(key)
|
|
771
|
+
return torch.zeros_like(ctx.next_state.get(("pipeline_state", key)))
|
|
772
|
+
|
|
773
|
+
pipeline_state = {
|
|
774
|
+
key: _make_none(key, val) for key, val in pipeline_state.items()
|
|
775
|
+
}
|
|
776
|
+
metrics = ctx.next_state.get("metrics", None)
|
|
777
|
+
if metrics is None:
|
|
778
|
+
metrics = {}
|
|
779
|
+
info = ctx.next_state.get("info", None)
|
|
780
|
+
if info is None:
|
|
781
|
+
info = {}
|
|
782
|
+
grad_next_state_td = TensorDict(
|
|
783
|
+
source={
|
|
784
|
+
"pipeline_state": pipeline_state,
|
|
785
|
+
"obs": grad_next_obs,
|
|
786
|
+
"reward": grad_next_reward,
|
|
787
|
+
"done": torch.zeros_like(ctx.next_state.get("done")),
|
|
788
|
+
"metrics": {k: torch.zeros_like(v) for k, v in metrics.items()},
|
|
789
|
+
"info": {k: torch.zeros_like(v) for k, v in info.items()},
|
|
790
|
+
},
|
|
791
|
+
device=ctx.env.device,
|
|
792
|
+
batch_size=ctx.env.batch_size,
|
|
793
|
+
)
|
|
794
|
+
# convert tensors to ndarrays
|
|
795
|
+
grad_next_state_obj = _tensordict_to_object(
|
|
796
|
+
grad_next_state_td, ctx.env._state_example
|
|
797
|
+
)
|
|
798
|
+
|
|
799
|
+
# flatten batch size
|
|
800
|
+
grad_next_state_flat = _tree_flatten(grad_next_state_obj, ctx.env.batch_size)
|
|
801
|
+
|
|
802
|
+
# call vjp to get gradients
|
|
803
|
+
grad_state, grad_action = ctx.vjp_fn(grad_next_state_flat)
|
|
804
|
+
# assert grad_action.device == ctx.env.device
|
|
805
|
+
|
|
806
|
+
# reshape batch size
|
|
807
|
+
grad_state = _tree_reshape(grad_state, ctx.env.batch_size)
|
|
808
|
+
grad_action = _tree_reshape(grad_action, ctx.env.batch_size)
|
|
809
|
+
# assert grad_action.device == ctx.env.device
|
|
810
|
+
|
|
811
|
+
# convert ndarrays to tensors
|
|
812
|
+
grad_state_qp = _object_to_tensordict(
|
|
813
|
+
grad_state.pipeline_state,
|
|
814
|
+
device=ctx.env.device,
|
|
815
|
+
batch_size=ctx.env.batch_size,
|
|
816
|
+
)
|
|
817
|
+
grad_action = _ndarray_to_tensor(grad_action).to(ctx.env.device)
|
|
818
|
+
grad_state_qp = {
|
|
819
|
+
key: val if key not in none_keys else None
|
|
820
|
+
for key, val in grad_state_qp.items()
|
|
821
|
+
}
|
|
822
|
+
grads = (grad_action, *grad_state_qp.values())
|
|
823
|
+
|
|
824
|
+
# Clean up context to prevent memory leaks
|
|
825
|
+
try:
|
|
826
|
+
# Clear JAX VJP function reference
|
|
827
|
+
del ctx.vjp_fn
|
|
828
|
+
except AttributeError:
|
|
829
|
+
pass
|
|
830
|
+
try:
|
|
831
|
+
# Clear stored tensors
|
|
832
|
+
del ctx.next_state
|
|
833
|
+
except AttributeError:
|
|
834
|
+
pass
|
|
835
|
+
try:
|
|
836
|
+
# Clear environment reference
|
|
837
|
+
del ctx.env
|
|
838
|
+
except AttributeError:
|
|
839
|
+
pass
|
|
840
|
+
try:
|
|
841
|
+
# Clear the backward flag
|
|
842
|
+
del ctx._backward_called
|
|
843
|
+
except AttributeError:
|
|
844
|
+
pass
|
|
845
|
+
|
|
846
|
+
return (None, None, *grads)
|