torchrl-0.11.0-cp314-cp314-manylinux_2_28_aarch64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/benchmark_batched_envs.py +104 -0
- benchmarks/conftest.py +91 -0
- benchmarks/ecosystem/gym_env_throughput.py +321 -0
- benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
- benchmarks/requirements.txt +7 -0
- benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
- benchmarks/test_collectors_benchmark.py +240 -0
- benchmarks/test_compressed_storage_benchmark.py +145 -0
- benchmarks/test_envs_benchmark.py +133 -0
- benchmarks/test_llm.py +101 -0
- benchmarks/test_non_tensor_env_benchmark.py +70 -0
- benchmarks/test_objectives_benchmarks.py +1199 -0
- benchmarks/test_replaybuffer_benchmark.py +254 -0
- sota-check/README.md +35 -0
- sota-implementations/README.md +142 -0
- sota-implementations/a2c/README.md +39 -0
- sota-implementations/a2c/a2c_atari.py +291 -0
- sota-implementations/a2c/a2c_mujoco.py +273 -0
- sota-implementations/a2c/utils_atari.py +240 -0
- sota-implementations/a2c/utils_mujoco.py +160 -0
- sota-implementations/bandits/README.md +7 -0
- sota-implementations/bandits/dqn.py +126 -0
- sota-implementations/cql/cql_offline.py +198 -0
- sota-implementations/cql/cql_online.py +249 -0
- sota-implementations/cql/discrete_cql_offline.py +180 -0
- sota-implementations/cql/discrete_cql_online.py +227 -0
- sota-implementations/cql/utils.py +471 -0
- sota-implementations/crossq/crossq.py +271 -0
- sota-implementations/crossq/utils.py +320 -0
- sota-implementations/ddpg/ddpg.py +231 -0
- sota-implementations/ddpg/utils.py +325 -0
- sota-implementations/decision_transformer/dt.py +163 -0
- sota-implementations/decision_transformer/lamb.py +167 -0
- sota-implementations/decision_transformer/online_dt.py +178 -0
- sota-implementations/decision_transformer/utils.py +562 -0
- sota-implementations/discrete_sac/discrete_sac.py +243 -0
- sota-implementations/discrete_sac/utils.py +324 -0
- sota-implementations/dqn/README.md +30 -0
- sota-implementations/dqn/dqn_atari.py +272 -0
- sota-implementations/dqn/dqn_cartpole.py +236 -0
- sota-implementations/dqn/utils_atari.py +132 -0
- sota-implementations/dqn/utils_cartpole.py +90 -0
- sota-implementations/dreamer/README.md +129 -0
- sota-implementations/dreamer/dreamer.py +586 -0
- sota-implementations/dreamer/dreamer_utils.py +1107 -0
- sota-implementations/expert-iteration/README.md +352 -0
- sota-implementations/expert-iteration/ei_utils.py +770 -0
- sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
- sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
- sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
- sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
- sota-implementations/gail/gail.py +327 -0
- sota-implementations/gail/gail_utils.py +68 -0
- sota-implementations/gail/ppo_utils.py +157 -0
- sota-implementations/grpo/README.md +273 -0
- sota-implementations/grpo/grpo-async.py +437 -0
- sota-implementations/grpo/grpo-sync.py +435 -0
- sota-implementations/grpo/grpo_utils.py +843 -0
- sota-implementations/grpo/requirements_gsm8k.txt +11 -0
- sota-implementations/grpo/requirements_ifeval.txt +16 -0
- sota-implementations/impala/README.md +33 -0
- sota-implementations/impala/impala_multi_node_ray.py +292 -0
- sota-implementations/impala/impala_multi_node_submitit.py +284 -0
- sota-implementations/impala/impala_single_node.py +261 -0
- sota-implementations/impala/utils.py +184 -0
- sota-implementations/iql/discrete_iql.py +230 -0
- sota-implementations/iql/iql_offline.py +164 -0
- sota-implementations/iql/iql_online.py +225 -0
- sota-implementations/iql/utils.py +437 -0
- sota-implementations/multiagent/README.md +74 -0
- sota-implementations/multiagent/iql.py +237 -0
- sota-implementations/multiagent/maddpg_iddpg.py +266 -0
- sota-implementations/multiagent/mappo_ippo.py +267 -0
- sota-implementations/multiagent/qmix_vdn.py +271 -0
- sota-implementations/multiagent/sac.py +337 -0
- sota-implementations/multiagent/utils/__init__.py +4 -0
- sota-implementations/multiagent/utils/logging.py +151 -0
- sota-implementations/multiagent/utils/utils.py +43 -0
- sota-implementations/ppo/README.md +29 -0
- sota-implementations/ppo/ppo_atari.py +305 -0
- sota-implementations/ppo/ppo_mujoco.py +293 -0
- sota-implementations/ppo/utils_atari.py +238 -0
- sota-implementations/ppo/utils_mujoco.py +152 -0
- sota-implementations/ppo_trainer/train.py +21 -0
- sota-implementations/redq/README.md +7 -0
- sota-implementations/redq/redq.py +199 -0
- sota-implementations/redq/utils.py +1060 -0
- sota-implementations/sac/sac-async.py +266 -0
- sota-implementations/sac/sac.py +239 -0
- sota-implementations/sac/utils.py +381 -0
- sota-implementations/sac_trainer/train.py +16 -0
- sota-implementations/td3/td3.py +254 -0
- sota-implementations/td3/utils.py +319 -0
- sota-implementations/td3_bc/td3_bc.py +177 -0
- sota-implementations/td3_bc/utils.py +251 -0
- torchrl/__init__.py +144 -0
- torchrl/_extension.py +74 -0
- torchrl/_torchrl.cpython-314-aarch64-linux-gnu.so +0 -0
- torchrl/_utils.py +1431 -0
- torchrl/collectors/__init__.py +48 -0
- torchrl/collectors/_base.py +1058 -0
- torchrl/collectors/_constants.py +88 -0
- torchrl/collectors/_multi_async.py +324 -0
- torchrl/collectors/_multi_base.py +1805 -0
- torchrl/collectors/_multi_sync.py +464 -0
- torchrl/collectors/_runner.py +581 -0
- torchrl/collectors/_single.py +2009 -0
- torchrl/collectors/_single_async.py +259 -0
- torchrl/collectors/collectors.py +62 -0
- torchrl/collectors/distributed/__init__.py +32 -0
- torchrl/collectors/distributed/default_configs.py +133 -0
- torchrl/collectors/distributed/generic.py +1306 -0
- torchrl/collectors/distributed/ray.py +1092 -0
- torchrl/collectors/distributed/rpc.py +1006 -0
- torchrl/collectors/distributed/sync.py +731 -0
- torchrl/collectors/distributed/utils.py +160 -0
- torchrl/collectors/llm/__init__.py +10 -0
- torchrl/collectors/llm/base.py +494 -0
- torchrl/collectors/llm/ray_collector.py +275 -0
- torchrl/collectors/llm/utils.py +36 -0
- torchrl/collectors/llm/weight_update/__init__.py +10 -0
- torchrl/collectors/llm/weight_update/vllm.py +348 -0
- torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
- torchrl/collectors/utils.py +433 -0
- torchrl/collectors/weight_update.py +591 -0
- torchrl/csrc/numpy_utils.h +38 -0
- torchrl/csrc/pybind.cpp +27 -0
- torchrl/csrc/segment_tree.h +458 -0
- torchrl/csrc/torch_utils.h +34 -0
- torchrl/csrc/utils.cpp +48 -0
- torchrl/csrc/utils.h +31 -0
- torchrl/data/__init__.py +187 -0
- torchrl/data/datasets/__init__.py +58 -0
- torchrl/data/datasets/atari_dqn.py +878 -0
- torchrl/data/datasets/common.py +281 -0
- torchrl/data/datasets/d4rl.py +489 -0
- torchrl/data/datasets/d4rl_infos.py +187 -0
- torchrl/data/datasets/gen_dgrl.py +375 -0
- torchrl/data/datasets/minari_data.py +643 -0
- torchrl/data/datasets/openml.py +177 -0
- torchrl/data/datasets/openx.py +798 -0
- torchrl/data/datasets/roboset.py +363 -0
- torchrl/data/datasets/utils.py +11 -0
- torchrl/data/datasets/vd4rl.py +432 -0
- torchrl/data/llm/__init__.py +34 -0
- torchrl/data/llm/dataset.py +491 -0
- torchrl/data/llm/history.py +1378 -0
- torchrl/data/llm/prompt.py +198 -0
- torchrl/data/llm/reward.py +225 -0
- torchrl/data/llm/topk.py +186 -0
- torchrl/data/llm/utils.py +543 -0
- torchrl/data/map/__init__.py +21 -0
- torchrl/data/map/hash.py +185 -0
- torchrl/data/map/query.py +204 -0
- torchrl/data/map/tdstorage.py +363 -0
- torchrl/data/map/tree.py +1434 -0
- torchrl/data/map/utils.py +103 -0
- torchrl/data/postprocs/__init__.py +8 -0
- torchrl/data/postprocs/postprocs.py +391 -0
- torchrl/data/replay_buffers/__init__.py +99 -0
- torchrl/data/replay_buffers/checkpointers.py +622 -0
- torchrl/data/replay_buffers/ray_buffer.py +292 -0
- torchrl/data/replay_buffers/replay_buffers.py +2376 -0
- torchrl/data/replay_buffers/samplers.py +2578 -0
- torchrl/data/replay_buffers/scheduler.py +265 -0
- torchrl/data/replay_buffers/storages.py +2412 -0
- torchrl/data/replay_buffers/utils.py +1042 -0
- torchrl/data/replay_buffers/writers.py +781 -0
- torchrl/data/tensor_specs.py +7101 -0
- torchrl/data/utils.py +334 -0
- torchrl/envs/__init__.py +265 -0
- torchrl/envs/async_envs.py +1105 -0
- torchrl/envs/batched_envs.py +3093 -0
- torchrl/envs/common.py +4241 -0
- torchrl/envs/custom/__init__.py +11 -0
- torchrl/envs/custom/chess.py +617 -0
- torchrl/envs/custom/llm.py +214 -0
- torchrl/envs/custom/pendulum.py +401 -0
- torchrl/envs/custom/san_moves.txt +29274 -0
- torchrl/envs/custom/tictactoeenv.py +288 -0
- torchrl/envs/env_creator.py +263 -0
- torchrl/envs/gym_like.py +752 -0
- torchrl/envs/libs/__init__.py +68 -0
- torchrl/envs/libs/_gym_utils.py +326 -0
- torchrl/envs/libs/brax.py +846 -0
- torchrl/envs/libs/dm_control.py +544 -0
- torchrl/envs/libs/envpool.py +447 -0
- torchrl/envs/libs/gym.py +2239 -0
- torchrl/envs/libs/habitat.py +138 -0
- torchrl/envs/libs/isaac_lab.py +87 -0
- torchrl/envs/libs/isaacgym.py +203 -0
- torchrl/envs/libs/jax_utils.py +166 -0
- torchrl/envs/libs/jumanji.py +963 -0
- torchrl/envs/libs/meltingpot.py +599 -0
- torchrl/envs/libs/openml.py +153 -0
- torchrl/envs/libs/openspiel.py +652 -0
- torchrl/envs/libs/pettingzoo.py +1042 -0
- torchrl/envs/libs/procgen.py +351 -0
- torchrl/envs/libs/robohive.py +429 -0
- torchrl/envs/libs/smacv2.py +645 -0
- torchrl/envs/libs/unity_mlagents.py +891 -0
- torchrl/envs/libs/utils.py +147 -0
- torchrl/envs/libs/vmas.py +813 -0
- torchrl/envs/llm/__init__.py +63 -0
- torchrl/envs/llm/chat.py +730 -0
- torchrl/envs/llm/datasets/README.md +4 -0
- torchrl/envs/llm/datasets/__init__.py +17 -0
- torchrl/envs/llm/datasets/gsm8k.py +353 -0
- torchrl/envs/llm/datasets/ifeval.py +274 -0
- torchrl/envs/llm/envs.py +789 -0
- torchrl/envs/llm/libs/README.md +3 -0
- torchrl/envs/llm/libs/__init__.py +8 -0
- torchrl/envs/llm/libs/mlgym.py +869 -0
- torchrl/envs/llm/reward/__init__.py +10 -0
- torchrl/envs/llm/reward/gsm8k.py +324 -0
- torchrl/envs/llm/reward/ifeval/README.md +13 -0
- torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
- torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
- torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
- torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
- torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
- torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
- torchrl/envs/llm/transforms/__init__.py +55 -0
- torchrl/envs/llm/transforms/browser.py +292 -0
- torchrl/envs/llm/transforms/dataloading.py +859 -0
- torchrl/envs/llm/transforms/format.py +73 -0
- torchrl/envs/llm/transforms/kl.py +1544 -0
- torchrl/envs/llm/transforms/policy_version.py +189 -0
- torchrl/envs/llm/transforms/reason.py +323 -0
- torchrl/envs/llm/transforms/tokenizer.py +321 -0
- torchrl/envs/llm/transforms/tools.py +1955 -0
- torchrl/envs/model_based/__init__.py +9 -0
- torchrl/envs/model_based/common.py +180 -0
- torchrl/envs/model_based/dreamer.py +112 -0
- torchrl/envs/transforms/__init__.py +147 -0
- torchrl/envs/transforms/functional.py +48 -0
- torchrl/envs/transforms/gym_transforms.py +203 -0
- torchrl/envs/transforms/module.py +341 -0
- torchrl/envs/transforms/r3m.py +372 -0
- torchrl/envs/transforms/ray_service.py +663 -0
- torchrl/envs/transforms/rb_transforms.py +214 -0
- torchrl/envs/transforms/transforms.py +11835 -0
- torchrl/envs/transforms/utils.py +94 -0
- torchrl/envs/transforms/vc1.py +307 -0
- torchrl/envs/transforms/vecnorm.py +845 -0
- torchrl/envs/transforms/vip.py +407 -0
- torchrl/envs/utils.py +1718 -0
- torchrl/envs/vec_envs.py +11 -0
- torchrl/modules/__init__.py +206 -0
- torchrl/modules/distributions/__init__.py +73 -0
- torchrl/modules/distributions/continuous.py +830 -0
- torchrl/modules/distributions/discrete.py +908 -0
- torchrl/modules/distributions/truncated_normal.py +187 -0
- torchrl/modules/distributions/utils.py +233 -0
- torchrl/modules/llm/__init__.py +62 -0
- torchrl/modules/llm/backends/__init__.py +65 -0
- torchrl/modules/llm/backends/vllm/__init__.py +94 -0
- torchrl/modules/llm/backends/vllm/_models.py +46 -0
- torchrl/modules/llm/backends/vllm/base.py +72 -0
- torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
- torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
- torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
- torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
- torchrl/modules/llm/policies/__init__.py +28 -0
- torchrl/modules/llm/policies/common.py +1809 -0
- torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
- torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
- torchrl/modules/llm/utils.py +23 -0
- torchrl/modules/mcts/__init__.py +21 -0
- torchrl/modules/mcts/scores.py +579 -0
- torchrl/modules/models/__init__.py +86 -0
- torchrl/modules/models/batchrenorm.py +119 -0
- torchrl/modules/models/decision_transformer.py +179 -0
- torchrl/modules/models/exploration.py +731 -0
- torchrl/modules/models/llm.py +156 -0
- torchrl/modules/models/model_based.py +596 -0
- torchrl/modules/models/models.py +1712 -0
- torchrl/modules/models/multiagent.py +1067 -0
- torchrl/modules/models/recipes/impala.py +185 -0
- torchrl/modules/models/utils.py +162 -0
- torchrl/modules/planners/__init__.py +10 -0
- torchrl/modules/planners/cem.py +228 -0
- torchrl/modules/planners/common.py +73 -0
- torchrl/modules/planners/mppi.py +265 -0
- torchrl/modules/tensordict_module/__init__.py +89 -0
- torchrl/modules/tensordict_module/actors.py +2457 -0
- torchrl/modules/tensordict_module/common.py +529 -0
- torchrl/modules/tensordict_module/exploration.py +814 -0
- torchrl/modules/tensordict_module/probabilistic.py +321 -0
- torchrl/modules/tensordict_module/rnn.py +1639 -0
- torchrl/modules/tensordict_module/sequence.py +132 -0
- torchrl/modules/tensordict_module/world_models.py +34 -0
- torchrl/modules/utils/__init__.py +38 -0
- torchrl/modules/utils/mappings.py +9 -0
- torchrl/modules/utils/utils.py +89 -0
- torchrl/objectives/__init__.py +78 -0
- torchrl/objectives/a2c.py +659 -0
- torchrl/objectives/common.py +753 -0
- torchrl/objectives/cql.py +1346 -0
- torchrl/objectives/crossq.py +710 -0
- torchrl/objectives/ddpg.py +453 -0
- torchrl/objectives/decision_transformer.py +371 -0
- torchrl/objectives/deprecated.py +516 -0
- torchrl/objectives/dqn.py +683 -0
- torchrl/objectives/dreamer.py +488 -0
- torchrl/objectives/functional.py +48 -0
- torchrl/objectives/gail.py +258 -0
- torchrl/objectives/iql.py +996 -0
- torchrl/objectives/llm/__init__.py +30 -0
- torchrl/objectives/llm/grpo.py +846 -0
- torchrl/objectives/llm/sft.py +482 -0
- torchrl/objectives/multiagent/__init__.py +8 -0
- torchrl/objectives/multiagent/qmixer.py +396 -0
- torchrl/objectives/ppo.py +1669 -0
- torchrl/objectives/redq.py +683 -0
- torchrl/objectives/reinforce.py +530 -0
- torchrl/objectives/sac.py +1580 -0
- torchrl/objectives/td3.py +570 -0
- torchrl/objectives/td3_bc.py +625 -0
- torchrl/objectives/utils.py +782 -0
- torchrl/objectives/value/__init__.py +28 -0
- torchrl/objectives/value/advantages.py +1956 -0
- torchrl/objectives/value/functional.py +1459 -0
- torchrl/objectives/value/utils.py +360 -0
- torchrl/record/__init__.py +17 -0
- torchrl/record/loggers/__init__.py +23 -0
- torchrl/record/loggers/common.py +48 -0
- torchrl/record/loggers/csv.py +226 -0
- torchrl/record/loggers/mlflow.py +142 -0
- torchrl/record/loggers/tensorboard.py +139 -0
- torchrl/record/loggers/trackio.py +163 -0
- torchrl/record/loggers/utils.py +78 -0
- torchrl/record/loggers/wandb.py +214 -0
- torchrl/record/recorder.py +554 -0
- torchrl/services/__init__.py +79 -0
- torchrl/services/base.py +109 -0
- torchrl/services/ray_service.py +453 -0
- torchrl/testing/__init__.py +107 -0
- torchrl/testing/assertions.py +179 -0
- torchrl/testing/dist_utils.py +122 -0
- torchrl/testing/env_creators.py +227 -0
- torchrl/testing/env_helper.py +35 -0
- torchrl/testing/gym_helpers.py +156 -0
- torchrl/testing/llm_mocks.py +119 -0
- torchrl/testing/mocking_classes.py +2720 -0
- torchrl/testing/modules.py +295 -0
- torchrl/testing/mp_helpers.py +15 -0
- torchrl/testing/ray_helpers.py +293 -0
- torchrl/testing/utils.py +190 -0
- torchrl/trainers/__init__.py +42 -0
- torchrl/trainers/algorithms/__init__.py +11 -0
- torchrl/trainers/algorithms/configs/__init__.py +705 -0
- torchrl/trainers/algorithms/configs/collectors.py +216 -0
- torchrl/trainers/algorithms/configs/common.py +41 -0
- torchrl/trainers/algorithms/configs/data.py +308 -0
- torchrl/trainers/algorithms/configs/envs.py +104 -0
- torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
- torchrl/trainers/algorithms/configs/logging.py +80 -0
- torchrl/trainers/algorithms/configs/modules.py +570 -0
- torchrl/trainers/algorithms/configs/objectives.py +177 -0
- torchrl/trainers/algorithms/configs/trainers.py +340 -0
- torchrl/trainers/algorithms/configs/transforms.py +955 -0
- torchrl/trainers/algorithms/configs/utils.py +252 -0
- torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
- torchrl/trainers/algorithms/configs/weight_update.py +159 -0
- torchrl/trainers/algorithms/ppo.py +373 -0
- torchrl/trainers/algorithms/sac.py +308 -0
- torchrl/trainers/helpers/__init__.py +40 -0
- torchrl/trainers/helpers/collectors.py +416 -0
- torchrl/trainers/helpers/envs.py +573 -0
- torchrl/trainers/helpers/logger.py +33 -0
- torchrl/trainers/helpers/losses.py +132 -0
- torchrl/trainers/helpers/models.py +658 -0
- torchrl/trainers/helpers/replay_buffer.py +59 -0
- torchrl/trainers/helpers/trainers.py +301 -0
- torchrl/trainers/trainers.py +2052 -0
- torchrl/weight_update/__init__.py +33 -0
- torchrl/weight_update/_distributed.py +749 -0
- torchrl/weight_update/_mp.py +624 -0
- torchrl/weight_update/_noupdate.py +102 -0
- torchrl/weight_update/_ray.py +1032 -0
- torchrl/weight_update/_rpc.py +284 -0
- torchrl/weight_update/_shared.py +891 -0
- torchrl/weight_update/llm/__init__.py +32 -0
- torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
- torchrl/weight_update/llm/vllm_nccl.py +710 -0
- torchrl/weight_update/utils.py +73 -0
- torchrl/weight_update/weight_sync_schemes.py +1244 -0
- torchrl-0.11.0.dist-info/METADATA +1308 -0
- torchrl-0.11.0.dist-info/RECORD +394 -0
- torchrl-0.11.0.dist-info/WHEEL +5 -0
- torchrl-0.11.0.dist-info/entry_points.txt +2 -0
- torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
- torchrl-0.11.0.dist-info/top_level.txt +7 -0
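
A quick way to confirm that an installed wheel matches the listing above is to import the top-level subpackages it ships. This is a hypothetical sanity check, not part of the package diff; it assumes the wheel has been installed into the current Python 3.14 aarch64 environment (for example via `pip install torchrl==0.11.0`).

```python
# Hypothetical sanity check (not part of the diff): verify the version string
# and that the main subpackages from the listing import cleanly.
import torchrl
import torchrl.collectors
import torchrl.data
import torchrl.envs
import torchrl.modules
import torchrl.objectives
import torchrl.trainers

print(torchrl.__version__)  # expected to report 0.11.0 for this wheel
```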

```diff
@@ -0,0 +1,185 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+from __future__ import annotations
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from tensordict import TensorDictBase
+
+
+# TODO: code small architecture ref in Impala paper
+
+
+class _ResNetBlock(nn.Module):
+    def __init__(
+        self,
+        num_ch,
+    ):
+        super().__init__()
+        resnet_block = []
+        resnet_block.append(nn.ReLU(inplace=True))
+        resnet_block.append(
+            nn.LazyConv2d(
+                out_channels=num_ch,
+                kernel_size=3,
+                stride=1,
+                padding=1,
+            )
+        )
+        resnet_block.append(nn.ReLU(inplace=True))
+        resnet_block.append(
+            nn.Conv2d(
+                in_channels=num_ch,
+                out_channels=num_ch,
+                kernel_size=3,
+                stride=1,
+                padding=1,
+            )
+        )
+        self.seq = nn.Sequential(*resnet_block)
+
+    def forward(self, x):
+        x += self.seq(x)
+        return x
+
+
+class _ConvNetBlock(nn.Module):
+    def __init__(self, num_ch):
+        super().__init__()
+
+        conv = nn.LazyConv2d(
+            out_channels=num_ch,
+            kernel_size=3,
+            stride=1,
+            padding=1,
+        )
+        mp = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+        self.feats_conv = nn.Sequential(conv, mp)
+        self.resnet1 = _ResNetBlock(num_ch=num_ch)
+        self.resnet2 = _ResNetBlock(num_ch=num_ch)
+
+    def forward(self, x):
+        x = self.feats_conv(x)
+        x = self.resnet1(x)
+        x = self.resnet1(x)
+        return x
+
+
+class ImpalaNet(nn.Module):  # noqa: D101
+    def __init__(
+        self,
+        num_actions,
+        channels=(16, 32, 32),
+        out_features=256,
+        use_lstm=False,
+        batch_first=True,
+        clamp_reward=True,
+        one_hot=False,
+    ):
+        super().__init__()
+        self.batch_first = batch_first
+        self.use_lstm = use_lstm
+        self.clamp_reward = clamp_reward
+        self.one_hot = one_hot
+        self.num_actions = num_actions
+
+        layers = [_ConvNetBlock(num_ch) for num_ch in channels]
+        layers += [nn.ReLU(inplace=True)]
+        self.convs = nn.Sequential(*layers)
+        self.fc = nn.Sequential(nn.LazyLinear(out_features), nn.ReLU(inplace=True))
+
+        # FC output size + last reward.
+        core_output_size = out_features + 1
+
+        if use_lstm:
+            self.core = nn.LSTM(
+                core_output_size,
+                out_features,
+                num_layers=1,
+                batch_first=batch_first,
+            )
+            core_output_size = out_features
+
+        self.policy = nn.Linear(core_output_size, self.num_actions)
+        self.baseline = nn.Linear(core_output_size, 1)
+
+    def forward(self, x, reward, done, core_state=None, mask=None):  # noqa: D102
+        if self.batch_first:
+            B, T, *x_shape = x.shape
+            batch_shape = torch.Size([B, T])
+        else:
+            T, B, *x_shape = x.shape
+            batch_shape = torch.Size([T, B])
+        if mask is None:
+            x = x.view(-1, *x.shape[-3:])
+        else:
+            x = x[mask]
+            if x.ndimension() != 4:
+                raise RuntimeError(
+                    f"masked input should have 4 dimensions but got {x.ndimension()} instead"
+                )
+        x = self.convs(x)
+        x = x.view(B * T, -1)
+        x = self.fc(x)
+
+        if mask is None:
+            if self.batch_first:
+                x = x.view(B, T, -1)
+            else:
+                x = x.view(T, B, -1)
+        else:
+            x = self._allocate_masked_x(x, mask)
+
+        if self.clamp_reward:
+            reward = torch.clamp(reward, -1, 1)
+        reward = reward.unsqueeze(-1)
+
+        core_input = torch.cat([x, reward], dim=-1)
+
+        if self.use_lstm:
+            core_output, _ = self.core(core_input, core_state)
+        else:
+            core_output = core_input
+
+        policy_logits = self.policy(core_output)
+        baseline = self.baseline(core_output)
+
+        softmax_vals = F.softmax(policy_logits, dim=-1)
+        action = torch.multinomial(
+            softmax_vals.view(-1, softmax_vals.shape[-1]), num_samples=1
+        ).view(softmax_vals.shape[:-1])
+        if self.one_hot:
+            action = F.one_hot(action, policy_logits.shape[-1])
+
+        if policy_logits.shape[:2] != batch_shape:
+            raise RuntimeError("policy_logits and batch-shape mismatch")
+        if baseline.shape[:2] != batch_shape:
+            raise RuntimeError("baseline and batch-shape mismatch")
+        if action.shape[:2] != batch_shape:
+            raise RuntimeError("action and batch-shape mismatch")
+
+        return (action, policy_logits, baseline), core_state
+
+    def _allocate_masked_x(self, x, mask):
+        x_empty = torch.zeros(
+            *mask.shape[:2], x.shape[-1], device=x.device, dtype=x.dtype
+        )
+        x_empty[mask] = x
+        return x_empty
+
+
+class ImpalaNetTensorDict(ImpalaNet):  # noqa: D101
+    observation_key = "pixels"
+
+    def forward(self, tensordict: TensorDictBase):  # noqa: D102
+        x = tensordict.get(self.observation_key)
+        done = tensordict.get("done").squeeze(-1)
+        reward = tensordict.get("reward").squeeze(-1)
+        mask = tensordict.get(("collector", "mask"))
+        core_state = (
+            tensordict.get("core_state") if "core_state" in tensordict.keys() else None
+        )
+        return super().forward(x, reward, done, core_state=core_state, mask=mask)
```
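
Judging by its 185-line count and class names, this hunk appears to correspond to `torchrl/modules/models/recipes/impala.py` from the listing above. The sketch below is a hypothetical smoke test, not part of the diff, that exercises `ImpalaNet` on dummy pixel data, assuming the classes are importable from that module path in the installed wheel. One detail worth flagging while reading the hunk: `_ConvNetBlock.forward` applies `self.resnet1` twice and never uses `self.resnet2`, which looks unintentional but is reproduced above exactly as published.

```python
# A minimal, hypothetical smoke test for the ImpalaNet recipe shown above.
# Assumes torchrl 0.11.0 is installed and the import path matches the listing.
import torch
from torchrl.modules.models.recipes.impala import ImpalaNet

net = ImpalaNet(num_actions=6)             # lazy conv/linear layers are sized on first call
pixels = torch.randn(2, 5, 3, 84, 84)      # [batch, time, C, H, W] with batch_first=True
reward = torch.randn(2, 5)                 # clamped to [-1, 1] inside forward
done = torch.zeros(2, 5, dtype=torch.bool)

(action, policy_logits, baseline), core_state = net(pixels, reward, done)
print(action.shape)         # torch.Size([2, 5])
print(policy_logits.shape)  # torch.Size([2, 5, 6])
print(baseline.shape)       # torch.Size([2, 5, 1])
```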

```diff
@@ -0,0 +1,162 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+from __future__ import annotations
+
+import inspect
+import warnings
+from collections.abc import Callable, Sequence
+
+import torch
+from torch import nn
+from torchrl.data.utils import DEVICE_TYPING
+from torchrl.modules.models.exploration import NoisyLazyLinear, NoisyLinear
+
+LazyMapping = {
+    nn.Linear: nn.LazyLinear,
+    NoisyLinear: NoisyLazyLinear,
+}
+
+
+class SqueezeLayer(nn.Module):
+    """Squeezing layer.
+
+    Squeezes some given singleton dimensions of an input tensor.
+
+    Args:
+        dims (iterable): dimensions to be squeezed
+            default: (-1,)
+
+    """
+
+    def __init__(self, dims: Sequence[int] = (-1,)):
+        super().__init__()
+        for dim in dims:
+            if dim >= 0:
+                raise RuntimeError("dims must all be < 0")
+        self.dims = dims
+
+    def forward(self, input: torch.Tensor) -> torch.Tensor:  # noqa: D102
+        for dim in self.dims:
+            if input.shape[dim] != 1:
+                raise RuntimeError(
+                    f"Tried to squeeze an input over dims {self.dims} with shape {input.shape}"
+                )
+            input = input.squeeze(dim)
+        return input
+
+
+class Squeeze2dLayer(SqueezeLayer):
+    """Squeezing layer for convolutional neural networks.
+
+    Squeezes the last two singleton dimensions of an input tensor.
+
+    """
+
+    def __init__(self):
+        super().__init__((-2, -1))
+
+
+class SquashDims(nn.Module):
+    """A squashing layer.
+
+    Flattens the N last dimensions of an input tensor.
+
+    Args:
+        ndims_in (int): number of dimensions to be flattened.
+            default = 3
+
+    Examples:
+        >>> from torchrl.modules.models.utils import SquashDims
+        >>> import torch
+        >>> x = torch.randn(1, 2, 3, 4)
+        >>> print(SquashDims()(x).shape)
+        torch.Size([1, 24])
+
+    """
+
+    def __init__(self, ndims_in: int = 3):
+        super().__init__()
+        self.ndims_in = ndims_in
+
+    def forward(self, value: torch.Tensor) -> torch.Tensor:
+        value = value.flatten(-self.ndims_in, -1)
+        return value
+
+
+def _find_depth(depth: int | None, *list_or_ints: Sequence):
+    """Find depth based on a sequence of inputs and a depth indicator.
+
+    If the depth is None, it is inferred by the length of one (or more) matching
+    lists of integers.
+    Raises an exception if depth does not match the list lengths or if lists lengths
+    do not match.
+
+    Args:
+        depth (int, optional): depth of the network
+        *list_or_ints (lists of int or int): if depth is None, at least one of
+            these inputs must be a list of ints of the length of the desired
+            network.
+    """
+    if depth is None:
+        for item in list_or_ints:
+            if isinstance(item, (list, tuple)):
+                depth = len(item)
+    if depth is None:
+        raise ValueError(
+            f"depth=None requires one of the input args (kernel_sizes, strides, "
+            f"num_cells) to be a a list or tuple. Got {tuple(type(item) for item in list_or_ints)}"
+        )
+    return depth
+
+
+def create_on_device(
+    module_class: type[nn.Module] | Callable,
+    device: DEVICE_TYPING | None,
+    *args,
+    **kwargs,
+) -> nn.Module:
+    """Create a new instance of :obj:`module_class` on :obj:`device`.
+
+    The new instance is created directly on the device if its constructor supports this.
+
+    Args:
+        module_class (Type[nn.Module]): the class of module to be created.
+        device (DEVICE_TYPING): device to create the module on.
+        *args: positional arguments to be passed to the module constructor.
+        **kwargs: keyword arguments to be passed to the module constructor.
+
+    """
+    fullargspec = inspect.getfullargspec(module_class.__init__)
+    if "device" in fullargspec.args or "device" in fullargspec.kwonlyargs:
+        return module_class(*args, device=device, **kwargs)
+    else:
+        result = module_class(*args, **kwargs)
+        if hasattr(result, "to"):
+            result = result.to(device)
+        return result
+
+
+def _reset_parameters_recursive(module, warn_if_no_op: bool = True) -> bool:
+    """Recursively resets the parameters of a :class:`~torch.nn.Module` in-place.
+
+    Args:
+        module (torch.nn.Module): the module to reset.
+        warn_if_no_op (bool, optional): whether to raise a warning in case this is a no-op.
+            Defaults to ``True``.
+
+    Returns: whether any parameter has been reset.
+
+    """
+    any_reset = False
+    for layer in module.children():
+        if hasattr(layer, "reset_parameters"):
+            layer.reset_parameters()
+            any_reset |= True
+        any_reset |= _reset_parameters_recursive(layer, warn_if_no_op=False)
+    if warn_if_no_op and not any_reset:
+        warnings.warn(
+            "_reset_parameters_recursive was called without the parameters argument and did not find any parameters to reset"
+        )
+    return any_reset
```
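
This 162-line hunk appears to correspond to `torchrl/modules/models/utils.py` from the listing. The sketch below is a hypothetical usage example (not part of the diff) for `create_on_device`, assuming that module path in the installed wheel: `nn.Linear` accepts a `device` keyword, so it is constructed directly on the target device, while `nn.ReLU` does not and falls back to `.to(device)`.

```python
# Hypothetical usage sketch for create_on_device and SquashDims from the hunk above.
import torch
from torch import nn
from torchrl.modules.models.utils import SquashDims, create_on_device

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
linear = create_on_device(nn.Linear, device, 8, 4)  # nn.Linear exposes device=..., built in place
relu = create_on_device(nn.ReLU, device)            # no device arg -> .to(device) fallback
print(linear.weight.device)

# SquashDims flattens the trailing dimensions, as in its docstring example.
print(SquashDims()(torch.randn(1, 2, 3, 4)).shape)  # torch.Size([1, 24])
```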

```diff
@@ -0,0 +1,10 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from .cem import CEMPlanner
+from .common import MPCPlannerBase
+from .mppi import MPPIPlanner
+
+__all__ = ["CEMPlanner", "MPCPlannerBase", "MPPIPlanner"]
```

```diff
@@ -0,0 +1,228 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+import torch
+from tensordict import TensorDict, TensorDictBase
+from torchrl.modules.planners.common import MPCPlannerBase
+
+if TYPE_CHECKING:
+    from torchrl.envs.common import EnvBase
+
+
+class CEMPlanner(MPCPlannerBase):
+    """CEMPlanner Module.
+
+    Reference: The cross-entropy method for optimization, Botev et al. 2013
+
+    This module will perform a CEM planning step when given a TensorDict
+    containing initial states.
+    The CEM planning step is performed by sampling actions from a Gaussian
+    distribution with zero mean and unit variance.
+    The sampled actions are then used to perform a rollout in the environment.
+    The cumulative rewards obtained with the rollout is then
+    ranked. We select the top-k episodes and use their actions to update the
+    mean and standard deviation of the actions distribution.
+    The CEM planning step is repeated for a specified number of steps.
+
+    A call to the module returns the actions that empirically maximised the
+    returns given a planning horizon
+
+    Args:
+        env (EnvBase): The environment to perform the planning step on (can be
+            `ModelBasedEnv` or :obj:`EnvBase`).
+        planning_horizon (int): The length of the simulated trajectories
+        optim_steps (int): The number of optimization steps used by the MPC
+            planner
+        num_candidates (int): The number of candidates to sample from the
+            Gaussian distributions.
+        top_k (int): The number of top candidates to use to
+            update the mean and standard deviation of the Gaussian distribution.
+        reward_key (str, optional): The key in the TensorDict to use to
+            retrieve the reward. Defaults to "reward".
+        action_key (str, optional): The key in the TensorDict to use to store
+            the action. Defaults to "action"
+
+    Examples:
+        >>> from tensordict import TensorDict
+        >>> from torchrl.data import Composite, Unbounded
+        >>> from torchrl.envs.model_based import ModelBasedEnvBase
+        >>> from torchrl.modules import SafeModule
+        >>> class MyMBEnv(ModelBasedEnvBase):
+        ...     def __init__(self, world_model, device="cpu", dtype=None, batch_size=None):
+        ...         super().__init__(world_model, device=device, dtype=dtype, batch_size=batch_size)
+        ...         self.state_spec = Composite(
+        ...             hidden_observation=Unbounded((4,))
+        ...         )
+        ...         self.observation_spec = Composite(
+        ...             hidden_observation=Unbounded((4,))
+        ...         )
+        ...         self.action_spec = Unbounded((1,))
+        ...         self.reward_spec = Unbounded((1,))
+        ...
+        ...     def _reset(self, tensordict: TensorDict) -> TensorDict:
+        ...         tensordict = TensorDict(
+        ...             {},
+        ...             batch_size=self.batch_size,
+        ...             device=self.device,
+        ...         )
+        ...         tensordict = tensordict.update(
+        ...             self.full_state_spec.rand())
+        ...         tensordict = tensordict.update(
+        ...             self.full_action_spec.rand())
+        ...         tensordict = tensordict.update(
+        ...             self.full_observation_spec.rand())
+        ...         return tensordict
+        ...
+        >>> from torchrl.modules import MLP, WorldModelWrapper
+        >>> import torch.nn as nn
+        >>> world_model = WorldModelWrapper(
+        ...     SafeModule(
+        ...         MLP(out_features=4, activation_class=nn.ReLU, activate_last_layer=True, depth=0),
+        ...         in_keys=["hidden_observation", "action"],
+        ...         out_keys=["hidden_observation"],
+        ...     ),
+        ...     SafeModule(
+        ...         nn.Linear(4, 1),
+        ...         in_keys=["hidden_observation"],
+        ...         out_keys=["reward"],
+        ...     ),
+        ... )
+        >>> env = MyMBEnv(world_model)
+        >>> # Build a planner and use it as actor
+        >>> planner = CEMPlanner(env, 10, 11, 7, 3)
+        >>> env.rollout(5, planner)
+        TensorDict(
+            fields={
+                action: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.float32, is_shared=False),
+                done: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.bool, is_shared=False),
+                hidden_observation: Tensor(shape=torch.Size([5, 4]), device=cpu, dtype=torch.float32, is_shared=False),
+                next: TensorDict(
+                    fields={
+                        done: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.bool, is_shared=False),
+                        hidden_observation: Tensor(shape=torch.Size([5, 4]), device=cpu, dtype=torch.float32, is_shared=False),
+                        reward: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.float32, is_shared=False),
+                        terminated: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
+                    batch_size=torch.Size([5]),
+                    device=cpu,
+                    is_shared=False),
+                terminated: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
+            batch_size=torch.Size([5]),
+            device=cpu,
+            is_shared=False)
+    """
+
+    def __init__(
+        self,
+        env: EnvBase,
+        planning_horizon: int,
+        optim_steps: int,
+        num_candidates: int,
+        top_k: int,
+        reward_key: str = ("next", "reward"),
+        action_key: str = "action",
+    ):
+        super().__init__(env=env, action_key=action_key)
+        self.planning_horizon = planning_horizon
+        self.optim_steps = optim_steps
+        self.num_candidates = num_candidates
+        self.top_k = top_k
+        self.reward_key = reward_key
+
+    def planning(self, tensordict: TensorDictBase) -> torch.Tensor:
+        batch_size = tensordict.batch_size
+        action_shape = (
+            *batch_size,
+            self.num_candidates,
+            self.planning_horizon,
+            *self.action_spec.shape,
+        )
+        action_stats_shape = (
+            *batch_size,
+            1,
+            self.planning_horizon,
+            *self.action_spec.shape,
+        )
+        action_topk_shape = (
+            *batch_size,
+            self.top_k,
+            self.planning_horizon,
+            *self.action_spec.shape,
+        )
+        TIME_DIM = len(self.action_spec.shape) - 3
+        K_DIM = len(self.action_spec.shape) - 4
+        expanded_original_tensordict = (
+            tensordict.unsqueeze(-1)
+            .expand(*batch_size, self.num_candidates)
+            .to_tensordict()
+        )
+        _action_means = torch.zeros(
+            *action_stats_shape,
+            device=tensordict.device,
+            dtype=self.env.action_spec.dtype,
+        )
+        _action_stds = torch.ones_like(_action_means)
+        container = TensorDict(
+            {
+                "tensordict": expanded_original_tensordict,
+                "stats": TensorDict(
+                    {
+                        "_action_means": _action_means,
+                        "_action_stds": _action_stds,
+                    },
+                    [*batch_size, 1, self.planning_horizon],
+                ),
+            },
+            batch_size,
+        )
+
+        for _ in range(self.optim_steps):
+            actions_means = container.get(("stats", "_action_means"))
+            actions_stds = container.get(("stats", "_action_stds"))
+            actions = actions_means + actions_stds * torch.randn(
+                *action_shape,
+                device=actions_means.device,
+                dtype=actions_means.dtype,
+            )
+            actions = self.env.action_spec.project(actions)
+            optim_tensordict = container.get("tensordict").clone()
+            policy = _PrecomputedActionsSequentialSetter(actions)
+            optim_tensordict = self.env.rollout(
+                max_steps=self.planning_horizon,
+                policy=policy,
+                auto_reset=False,
+                tensordict=optim_tensordict,
+            )
+
+            sum_rewards = optim_tensordict.get(self.reward_key).sum(
+                dim=TIME_DIM, keepdim=True
+            )
+            _, top_k = sum_rewards.topk(self.top_k, dim=K_DIM)
+            top_k = top_k.expand(action_topk_shape)
+            best_actions = actions.gather(K_DIM, top_k)
+            container.set_(
+                ("stats", "_action_means"), best_actions.mean(dim=K_DIM, keepdim=True)
+            )
+            container.set_(
+                ("stats", "_action_stds"), best_actions.std(dim=K_DIM, keepdim=True)
+            )
+        action_means = container.get(("stats", "_action_means"))
+        return action_means[..., 0, 0, :]
+
+
+class _PrecomputedActionsSequentialSetter:
+    def __init__(self, actions):
+        self.actions = actions
+        self.cmpt = 0
+
+    def __call__(self, tensordict):
+        # checks that the step count is lower or equal to the horizon
+        if self.cmpt >= self.actions.shape[-2]:
+            raise ValueError("Precomputed actions sequence is too short")
+        tensordict = tensordict.set("action", self.actions[..., self.cmpt, :])
+        self.cmpt += 1
+        return tensordict
```
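
This 228-line hunk appears to correspond to `torchrl/modules/planners/cem.py` from the listing; the class docstring above already shows end-to-end usage with a model-based env. The sketch below is a stand-alone, simplified illustration (not from the package) of the sample / score / top-k / refit loop that `CEMPlanner.planning` runs with TensorDicts and `env.rollout`; `score_fn` is a hypothetical stand-in for the summed rollout reward.

```python
# Stand-alone sketch of one CEM refit step: sample candidate action sequences
# around the current Gaussian, score them, keep the top-k, refit mean and std.
import torch

def cem_step(score_fn, mean, std, num_candidates=64, top_k=8):
    # mean, std: [planning_horizon, action_dim]
    actions = mean + std * torch.randn(num_candidates, *mean.shape)
    scores = score_fn(actions)                   # one scalar per candidate
    best = actions[scores.topk(top_k).indices]   # [top_k, horizon, action_dim]
    return best.mean(dim=0), best.std(dim=0)

mean, std = torch.zeros(5, 1), torch.ones(5, 1)  # horizon=5, 1-D action
for _ in range(10):                              # optim_steps
    mean, std = cem_step(lambda a: -(a - 0.3).pow(2).sum(dim=(-1, -2)), mean, std)
print(mean.squeeze(-1))                          # each entry approaches the optimum 0.3
```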

```diff
@@ -0,0 +1,73 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+from __future__ import annotations
+
+import abc
+from typing import TYPE_CHECKING
+
+import torch
+from tensordict import TensorDictBase
+
+from torchrl.modules import SafeModule
+
+if TYPE_CHECKING:
+    from torchrl.envs.common import EnvBase
+
+
+class MPCPlannerBase(SafeModule, metaclass=abc.ABCMeta):
+    """MPCPlannerBase abstract Module.
+
+    This class inherits from :obj:`SafeModule`. Provided a :obj:`TensorDict`, this module will perform a Model Predictive Control (MPC) planning step.
+    At the end of the planning step, the :obj:`MPCPlanner` will return a proposed action.
+
+    Args:
+        env (EnvBase): The environment to perform the planning step on (Can be :obj:`ModelBasedEnvBase` or :obj:`EnvBase`).
+        action_key (str, optional): The key that will point to the computed action.
+    """
+
+    def __init__(
+        self,
+        env: EnvBase,
+        action_key: str = "action",
+    ):
+        # Check if env is stateless
+        if env.batch_locked:
+            raise ValueError(
+                "Environment is batch_locked. MPCPlanners need an environment that accepts batched inputs with any batch size"
+            )
+        out_keys = [action_key]
+        in_keys = list(env.observation_spec.keys(True, True))
+        super().__init__(env, in_keys=in_keys, out_keys=out_keys)
+        self.env = env
+        self.action_spec = env.action_spec
+        self.to(env.device)
+
+    @abc.abstractmethod
+    def planning(self, td: TensorDictBase) -> torch.Tensor:
+        """Performs the MPC planning step.
+
+        Args:
+            td (TensorDict): The TensorDict to perform the planning step on.
+        """
+        raise NotImplementedError()
+
+    def forward(
+        self,
+        tensordict: TensorDictBase,
+        tensordict_out: TensorDictBase | None = None,
+        **kwargs,
+    ) -> TensorDictBase:
+        if "params" in kwargs or "vmap" in kwargs:
+            raise ValueError(
+                "MPCPlannerBase does not currently support functional programming."
+            )
+        action = self.planning(tensordict)
+        action = self.action_spec.project(action)
+        tensordict_out = self._write_to_tensordict(
+            tensordict,
+            (action,),
+            tensordict_out,
+        )
+        return tensordict_out
```