torchrl 0.11.0__cp314-cp314-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/benchmark_batched_envs.py +104 -0
- benchmarks/conftest.py +91 -0
- benchmarks/ecosystem/gym_env_throughput.py +321 -0
- benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
- benchmarks/requirements.txt +7 -0
- benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
- benchmarks/test_collectors_benchmark.py +240 -0
- benchmarks/test_compressed_storage_benchmark.py +145 -0
- benchmarks/test_envs_benchmark.py +133 -0
- benchmarks/test_llm.py +101 -0
- benchmarks/test_non_tensor_env_benchmark.py +70 -0
- benchmarks/test_objectives_benchmarks.py +1199 -0
- benchmarks/test_replaybuffer_benchmark.py +254 -0
- sota-check/README.md +35 -0
- sota-implementations/README.md +142 -0
- sota-implementations/a2c/README.md +39 -0
- sota-implementations/a2c/a2c_atari.py +291 -0
- sota-implementations/a2c/a2c_mujoco.py +273 -0
- sota-implementations/a2c/utils_atari.py +240 -0
- sota-implementations/a2c/utils_mujoco.py +160 -0
- sota-implementations/bandits/README.md +7 -0
- sota-implementations/bandits/dqn.py +126 -0
- sota-implementations/cql/cql_offline.py +198 -0
- sota-implementations/cql/cql_online.py +249 -0
- sota-implementations/cql/discrete_cql_offline.py +180 -0
- sota-implementations/cql/discrete_cql_online.py +227 -0
- sota-implementations/cql/utils.py +471 -0
- sota-implementations/crossq/crossq.py +271 -0
- sota-implementations/crossq/utils.py +320 -0
- sota-implementations/ddpg/ddpg.py +231 -0
- sota-implementations/ddpg/utils.py +325 -0
- sota-implementations/decision_transformer/dt.py +163 -0
- sota-implementations/decision_transformer/lamb.py +167 -0
- sota-implementations/decision_transformer/online_dt.py +178 -0
- sota-implementations/decision_transformer/utils.py +562 -0
- sota-implementations/discrete_sac/discrete_sac.py +243 -0
- sota-implementations/discrete_sac/utils.py +324 -0
- sota-implementations/dqn/README.md +30 -0
- sota-implementations/dqn/dqn_atari.py +272 -0
- sota-implementations/dqn/dqn_cartpole.py +236 -0
- sota-implementations/dqn/utils_atari.py +132 -0
- sota-implementations/dqn/utils_cartpole.py +90 -0
- sota-implementations/dreamer/README.md +129 -0
- sota-implementations/dreamer/dreamer.py +586 -0
- sota-implementations/dreamer/dreamer_utils.py +1107 -0
- sota-implementations/expert-iteration/README.md +352 -0
- sota-implementations/expert-iteration/ei_utils.py +770 -0
- sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
- sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
- sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
- sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
- sota-implementations/gail/gail.py +327 -0
- sota-implementations/gail/gail_utils.py +68 -0
- sota-implementations/gail/ppo_utils.py +157 -0
- sota-implementations/grpo/README.md +273 -0
- sota-implementations/grpo/grpo-async.py +437 -0
- sota-implementations/grpo/grpo-sync.py +435 -0
- sota-implementations/grpo/grpo_utils.py +843 -0
- sota-implementations/grpo/requirements_gsm8k.txt +11 -0
- sota-implementations/grpo/requirements_ifeval.txt +16 -0
- sota-implementations/impala/README.md +33 -0
- sota-implementations/impala/impala_multi_node_ray.py +292 -0
- sota-implementations/impala/impala_multi_node_submitit.py +284 -0
- sota-implementations/impala/impala_single_node.py +261 -0
- sota-implementations/impala/utils.py +184 -0
- sota-implementations/iql/discrete_iql.py +230 -0
- sota-implementations/iql/iql_offline.py +164 -0
- sota-implementations/iql/iql_online.py +225 -0
- sota-implementations/iql/utils.py +437 -0
- sota-implementations/multiagent/README.md +74 -0
- sota-implementations/multiagent/iql.py +237 -0
- sota-implementations/multiagent/maddpg_iddpg.py +266 -0
- sota-implementations/multiagent/mappo_ippo.py +267 -0
- sota-implementations/multiagent/qmix_vdn.py +271 -0
- sota-implementations/multiagent/sac.py +337 -0
- sota-implementations/multiagent/utils/__init__.py +4 -0
- sota-implementations/multiagent/utils/logging.py +151 -0
- sota-implementations/multiagent/utils/utils.py +43 -0
- sota-implementations/ppo/README.md +29 -0
- sota-implementations/ppo/ppo_atari.py +305 -0
- sota-implementations/ppo/ppo_mujoco.py +293 -0
- sota-implementations/ppo/utils_atari.py +238 -0
- sota-implementations/ppo/utils_mujoco.py +152 -0
- sota-implementations/ppo_trainer/train.py +21 -0
- sota-implementations/redq/README.md +7 -0
- sota-implementations/redq/redq.py +199 -0
- sota-implementations/redq/utils.py +1060 -0
- sota-implementations/sac/sac-async.py +266 -0
- sota-implementations/sac/sac.py +239 -0
- sota-implementations/sac/utils.py +381 -0
- sota-implementations/sac_trainer/train.py +16 -0
- sota-implementations/td3/td3.py +254 -0
- sota-implementations/td3/utils.py +319 -0
- sota-implementations/td3_bc/td3_bc.py +177 -0
- sota-implementations/td3_bc/utils.py +251 -0
- torchrl/__init__.py +144 -0
- torchrl/_extension.py +74 -0
- torchrl/_torchrl.cp314-win_amd64.pyd +0 -0
- torchrl/_utils.py +1431 -0
- torchrl/collectors/__init__.py +48 -0
- torchrl/collectors/_base.py +1058 -0
- torchrl/collectors/_constants.py +88 -0
- torchrl/collectors/_multi_async.py +324 -0
- torchrl/collectors/_multi_base.py +1805 -0
- torchrl/collectors/_multi_sync.py +464 -0
- torchrl/collectors/_runner.py +581 -0
- torchrl/collectors/_single.py +2009 -0
- torchrl/collectors/_single_async.py +259 -0
- torchrl/collectors/collectors.py +62 -0
- torchrl/collectors/distributed/__init__.py +32 -0
- torchrl/collectors/distributed/default_configs.py +133 -0
- torchrl/collectors/distributed/generic.py +1306 -0
- torchrl/collectors/distributed/ray.py +1092 -0
- torchrl/collectors/distributed/rpc.py +1006 -0
- torchrl/collectors/distributed/sync.py +731 -0
- torchrl/collectors/distributed/utils.py +160 -0
- torchrl/collectors/llm/__init__.py +10 -0
- torchrl/collectors/llm/base.py +494 -0
- torchrl/collectors/llm/ray_collector.py +275 -0
- torchrl/collectors/llm/utils.py +36 -0
- torchrl/collectors/llm/weight_update/__init__.py +10 -0
- torchrl/collectors/llm/weight_update/vllm.py +348 -0
- torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
- torchrl/collectors/utils.py +433 -0
- torchrl/collectors/weight_update.py +591 -0
- torchrl/csrc/numpy_utils.h +38 -0
- torchrl/csrc/pybind.cpp +27 -0
- torchrl/csrc/segment_tree.h +458 -0
- torchrl/csrc/torch_utils.h +34 -0
- torchrl/csrc/utils.cpp +48 -0
- torchrl/csrc/utils.h +31 -0
- torchrl/data/__init__.py +187 -0
- torchrl/data/datasets/__init__.py +58 -0
- torchrl/data/datasets/atari_dqn.py +878 -0
- torchrl/data/datasets/common.py +281 -0
- torchrl/data/datasets/d4rl.py +489 -0
- torchrl/data/datasets/d4rl_infos.py +187 -0
- torchrl/data/datasets/gen_dgrl.py +375 -0
- torchrl/data/datasets/minari_data.py +643 -0
- torchrl/data/datasets/openml.py +177 -0
- torchrl/data/datasets/openx.py +798 -0
- torchrl/data/datasets/roboset.py +363 -0
- torchrl/data/datasets/utils.py +11 -0
- torchrl/data/datasets/vd4rl.py +432 -0
- torchrl/data/llm/__init__.py +34 -0
- torchrl/data/llm/dataset.py +491 -0
- torchrl/data/llm/history.py +1378 -0
- torchrl/data/llm/prompt.py +198 -0
- torchrl/data/llm/reward.py +225 -0
- torchrl/data/llm/topk.py +186 -0
- torchrl/data/llm/utils.py +543 -0
- torchrl/data/map/__init__.py +21 -0
- torchrl/data/map/hash.py +185 -0
- torchrl/data/map/query.py +204 -0
- torchrl/data/map/tdstorage.py +363 -0
- torchrl/data/map/tree.py +1434 -0
- torchrl/data/map/utils.py +103 -0
- torchrl/data/postprocs/__init__.py +8 -0
- torchrl/data/postprocs/postprocs.py +391 -0
- torchrl/data/replay_buffers/__init__.py +99 -0
- torchrl/data/replay_buffers/checkpointers.py +622 -0
- torchrl/data/replay_buffers/ray_buffer.py +292 -0
- torchrl/data/replay_buffers/replay_buffers.py +2376 -0
- torchrl/data/replay_buffers/samplers.py +2578 -0
- torchrl/data/replay_buffers/scheduler.py +265 -0
- torchrl/data/replay_buffers/storages.py +2412 -0
- torchrl/data/replay_buffers/utils.py +1042 -0
- torchrl/data/replay_buffers/writers.py +781 -0
- torchrl/data/tensor_specs.py +7101 -0
- torchrl/data/utils.py +334 -0
- torchrl/envs/__init__.py +265 -0
- torchrl/envs/async_envs.py +1105 -0
- torchrl/envs/batched_envs.py +3093 -0
- torchrl/envs/common.py +4241 -0
- torchrl/envs/custom/__init__.py +11 -0
- torchrl/envs/custom/chess.py +617 -0
- torchrl/envs/custom/llm.py +214 -0
- torchrl/envs/custom/pendulum.py +401 -0
- torchrl/envs/custom/san_moves.txt +29274 -0
- torchrl/envs/custom/tictactoeenv.py +288 -0
- torchrl/envs/env_creator.py +263 -0
- torchrl/envs/gym_like.py +752 -0
- torchrl/envs/libs/__init__.py +68 -0
- torchrl/envs/libs/_gym_utils.py +326 -0
- torchrl/envs/libs/brax.py +846 -0
- torchrl/envs/libs/dm_control.py +544 -0
- torchrl/envs/libs/envpool.py +447 -0
- torchrl/envs/libs/gym.py +2239 -0
- torchrl/envs/libs/habitat.py +138 -0
- torchrl/envs/libs/isaac_lab.py +87 -0
- torchrl/envs/libs/isaacgym.py +203 -0
- torchrl/envs/libs/jax_utils.py +166 -0
- torchrl/envs/libs/jumanji.py +963 -0
- torchrl/envs/libs/meltingpot.py +599 -0
- torchrl/envs/libs/openml.py +153 -0
- torchrl/envs/libs/openspiel.py +652 -0
- torchrl/envs/libs/pettingzoo.py +1042 -0
- torchrl/envs/libs/procgen.py +351 -0
- torchrl/envs/libs/robohive.py +429 -0
- torchrl/envs/libs/smacv2.py +645 -0
- torchrl/envs/libs/unity_mlagents.py +891 -0
- torchrl/envs/libs/utils.py +147 -0
- torchrl/envs/libs/vmas.py +813 -0
- torchrl/envs/llm/__init__.py +63 -0
- torchrl/envs/llm/chat.py +730 -0
- torchrl/envs/llm/datasets/README.md +4 -0
- torchrl/envs/llm/datasets/__init__.py +17 -0
- torchrl/envs/llm/datasets/gsm8k.py +353 -0
- torchrl/envs/llm/datasets/ifeval.py +274 -0
- torchrl/envs/llm/envs.py +789 -0
- torchrl/envs/llm/libs/README.md +3 -0
- torchrl/envs/llm/libs/__init__.py +8 -0
- torchrl/envs/llm/libs/mlgym.py +869 -0
- torchrl/envs/llm/reward/__init__.py +10 -0
- torchrl/envs/llm/reward/gsm8k.py +324 -0
- torchrl/envs/llm/reward/ifeval/README.md +13 -0
- torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
- torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
- torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
- torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
- torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
- torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
- torchrl/envs/llm/transforms/__init__.py +55 -0
- torchrl/envs/llm/transforms/browser.py +292 -0
- torchrl/envs/llm/transforms/dataloading.py +859 -0
- torchrl/envs/llm/transforms/format.py +73 -0
- torchrl/envs/llm/transforms/kl.py +1544 -0
- torchrl/envs/llm/transforms/policy_version.py +189 -0
- torchrl/envs/llm/transforms/reason.py +323 -0
- torchrl/envs/llm/transforms/tokenizer.py +321 -0
- torchrl/envs/llm/transforms/tools.py +1955 -0
- torchrl/envs/model_based/__init__.py +9 -0
- torchrl/envs/model_based/common.py +180 -0
- torchrl/envs/model_based/dreamer.py +112 -0
- torchrl/envs/transforms/__init__.py +147 -0
- torchrl/envs/transforms/functional.py +48 -0
- torchrl/envs/transforms/gym_transforms.py +203 -0
- torchrl/envs/transforms/module.py +341 -0
- torchrl/envs/transforms/r3m.py +372 -0
- torchrl/envs/transforms/ray_service.py +663 -0
- torchrl/envs/transforms/rb_transforms.py +214 -0
- torchrl/envs/transforms/transforms.py +11835 -0
- torchrl/envs/transforms/utils.py +94 -0
- torchrl/envs/transforms/vc1.py +307 -0
- torchrl/envs/transforms/vecnorm.py +845 -0
- torchrl/envs/transforms/vip.py +407 -0
- torchrl/envs/utils.py +1718 -0
- torchrl/envs/vec_envs.py +11 -0
- torchrl/modules/__init__.py +206 -0
- torchrl/modules/distributions/__init__.py +73 -0
- torchrl/modules/distributions/continuous.py +830 -0
- torchrl/modules/distributions/discrete.py +908 -0
- torchrl/modules/distributions/truncated_normal.py +187 -0
- torchrl/modules/distributions/utils.py +233 -0
- torchrl/modules/llm/__init__.py +62 -0
- torchrl/modules/llm/backends/__init__.py +65 -0
- torchrl/modules/llm/backends/vllm/__init__.py +94 -0
- torchrl/modules/llm/backends/vllm/_models.py +46 -0
- torchrl/modules/llm/backends/vllm/base.py +72 -0
- torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
- torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
- torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
- torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
- torchrl/modules/llm/policies/__init__.py +28 -0
- torchrl/modules/llm/policies/common.py +1809 -0
- torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
- torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
- torchrl/modules/llm/utils.py +23 -0
- torchrl/modules/mcts/__init__.py +21 -0
- torchrl/modules/mcts/scores.py +579 -0
- torchrl/modules/models/__init__.py +86 -0
- torchrl/modules/models/batchrenorm.py +119 -0
- torchrl/modules/models/decision_transformer.py +179 -0
- torchrl/modules/models/exploration.py +731 -0
- torchrl/modules/models/llm.py +156 -0
- torchrl/modules/models/model_based.py +596 -0
- torchrl/modules/models/models.py +1712 -0
- torchrl/modules/models/multiagent.py +1067 -0
- torchrl/modules/models/recipes/impala.py +185 -0
- torchrl/modules/models/utils.py +162 -0
- torchrl/modules/planners/__init__.py +10 -0
- torchrl/modules/planners/cem.py +228 -0
- torchrl/modules/planners/common.py +73 -0
- torchrl/modules/planners/mppi.py +265 -0
- torchrl/modules/tensordict_module/__init__.py +89 -0
- torchrl/modules/tensordict_module/actors.py +2457 -0
- torchrl/modules/tensordict_module/common.py +529 -0
- torchrl/modules/tensordict_module/exploration.py +814 -0
- torchrl/modules/tensordict_module/probabilistic.py +321 -0
- torchrl/modules/tensordict_module/rnn.py +1639 -0
- torchrl/modules/tensordict_module/sequence.py +132 -0
- torchrl/modules/tensordict_module/world_models.py +34 -0
- torchrl/modules/utils/__init__.py +38 -0
- torchrl/modules/utils/mappings.py +9 -0
- torchrl/modules/utils/utils.py +89 -0
- torchrl/objectives/__init__.py +78 -0
- torchrl/objectives/a2c.py +659 -0
- torchrl/objectives/common.py +753 -0
- torchrl/objectives/cql.py +1346 -0
- torchrl/objectives/crossq.py +710 -0
- torchrl/objectives/ddpg.py +453 -0
- torchrl/objectives/decision_transformer.py +371 -0
- torchrl/objectives/deprecated.py +516 -0
- torchrl/objectives/dqn.py +683 -0
- torchrl/objectives/dreamer.py +488 -0
- torchrl/objectives/functional.py +48 -0
- torchrl/objectives/gail.py +258 -0
- torchrl/objectives/iql.py +996 -0
- torchrl/objectives/llm/__init__.py +30 -0
- torchrl/objectives/llm/grpo.py +846 -0
- torchrl/objectives/llm/sft.py +482 -0
- torchrl/objectives/multiagent/__init__.py +8 -0
- torchrl/objectives/multiagent/qmixer.py +396 -0
- torchrl/objectives/ppo.py +1669 -0
- torchrl/objectives/redq.py +683 -0
- torchrl/objectives/reinforce.py +530 -0
- torchrl/objectives/sac.py +1580 -0
- torchrl/objectives/td3.py +570 -0
- torchrl/objectives/td3_bc.py +625 -0
- torchrl/objectives/utils.py +782 -0
- torchrl/objectives/value/__init__.py +28 -0
- torchrl/objectives/value/advantages.py +1956 -0
- torchrl/objectives/value/functional.py +1459 -0
- torchrl/objectives/value/utils.py +360 -0
- torchrl/record/__init__.py +17 -0
- torchrl/record/loggers/__init__.py +23 -0
- torchrl/record/loggers/common.py +48 -0
- torchrl/record/loggers/csv.py +226 -0
- torchrl/record/loggers/mlflow.py +142 -0
- torchrl/record/loggers/tensorboard.py +139 -0
- torchrl/record/loggers/trackio.py +163 -0
- torchrl/record/loggers/utils.py +78 -0
- torchrl/record/loggers/wandb.py +214 -0
- torchrl/record/recorder.py +554 -0
- torchrl/services/__init__.py +79 -0
- torchrl/services/base.py +109 -0
- torchrl/services/ray_service.py +453 -0
- torchrl/testing/__init__.py +107 -0
- torchrl/testing/assertions.py +179 -0
- torchrl/testing/dist_utils.py +122 -0
- torchrl/testing/env_creators.py +227 -0
- torchrl/testing/env_helper.py +35 -0
- torchrl/testing/gym_helpers.py +156 -0
- torchrl/testing/llm_mocks.py +119 -0
- torchrl/testing/mocking_classes.py +2720 -0
- torchrl/testing/modules.py +295 -0
- torchrl/testing/mp_helpers.py +15 -0
- torchrl/testing/ray_helpers.py +293 -0
- torchrl/testing/utils.py +190 -0
- torchrl/trainers/__init__.py +42 -0
- torchrl/trainers/algorithms/__init__.py +11 -0
- torchrl/trainers/algorithms/configs/__init__.py +705 -0
- torchrl/trainers/algorithms/configs/collectors.py +216 -0
- torchrl/trainers/algorithms/configs/common.py +41 -0
- torchrl/trainers/algorithms/configs/data.py +308 -0
- torchrl/trainers/algorithms/configs/envs.py +104 -0
- torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
- torchrl/trainers/algorithms/configs/logging.py +80 -0
- torchrl/trainers/algorithms/configs/modules.py +570 -0
- torchrl/trainers/algorithms/configs/objectives.py +177 -0
- torchrl/trainers/algorithms/configs/trainers.py +340 -0
- torchrl/trainers/algorithms/configs/transforms.py +955 -0
- torchrl/trainers/algorithms/configs/utils.py +252 -0
- torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
- torchrl/trainers/algorithms/configs/weight_update.py +159 -0
- torchrl/trainers/algorithms/ppo.py +373 -0
- torchrl/trainers/algorithms/sac.py +308 -0
- torchrl/trainers/helpers/__init__.py +40 -0
- torchrl/trainers/helpers/collectors.py +416 -0
- torchrl/trainers/helpers/envs.py +573 -0
- torchrl/trainers/helpers/logger.py +33 -0
- torchrl/trainers/helpers/losses.py +132 -0
- torchrl/trainers/helpers/models.py +658 -0
- torchrl/trainers/helpers/replay_buffer.py +59 -0
- torchrl/trainers/helpers/trainers.py +301 -0
- torchrl/trainers/trainers.py +2052 -0
- torchrl/weight_update/__init__.py +33 -0
- torchrl/weight_update/_distributed.py +749 -0
- torchrl/weight_update/_mp.py +624 -0
- torchrl/weight_update/_noupdate.py +102 -0
- torchrl/weight_update/_ray.py +1032 -0
- torchrl/weight_update/_rpc.py +284 -0
- torchrl/weight_update/_shared.py +891 -0
- torchrl/weight_update/llm/__init__.py +32 -0
- torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
- torchrl/weight_update/llm/vllm_nccl.py +710 -0
- torchrl/weight_update/utils.py +73 -0
- torchrl/weight_update/weight_sync_schemes.py +1244 -0
- torchrl-0.11.0.dist-info/LICENSE +21 -0
- torchrl-0.11.0.dist-info/METADATA +1307 -0
- torchrl-0.11.0.dist-info/RECORD +394 -0
- torchrl-0.11.0.dist-info/WHEEL +5 -0
- torchrl-0.11.0.dist-info/entry_points.txt +2 -0
- torchrl-0.11.0.dist-info/top_level.txt +7 -0
torchrl/modules/distributions/continuous.py

@@ -0,0 +1,830 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+from __future__ import annotations
+
+import weakref
+from collections.abc import Callable, Sequence
+from numbers import Number
+
+import numpy as np
+import torch
+from packaging import version
+from torch import distributions as D, nn
+from torch.distributions import constraints
+from torch.distributions.transforms import _InverseTransform
+
+from torchrl._utils import safe_is_current_stream_capturing
+from torchrl.modules.distributions.truncated_normal import (
+    TruncatedNormal as _TruncatedNormal,
+)
+from torchrl.modules.distributions.utils import (
+    _cast_device,
+    FasterTransformedDistribution,
+    safeatanh_noeps,
+    safetanh_noeps,
+)
+
+# speeds up distribution construction
+D.Distribution.set_default_validate_args(False)
+
+try:
+    from torch.compiler import assume_constant_result
+except ImportError:
+    from torch._dynamo import assume_constant_result
+
+try:
+    from torch.compiler import is_compiling
+except ImportError:
+    from torch._dynamo import is_compiling
+
+TORCH_VERSION = version.parse(torch.__version__).base_version
+TORCH_VERSION_PRE_2_6 = version.parse(TORCH_VERSION) < version.parse("2.6.0")
+
+
+class IndependentNormal(D.Independent):
+    """Implements a Normal distribution with location scaling.
+
+    Location scaling prevents the location from being "too far" from 0, which ultimately
+    leads to numerically unstable samples and poor gradient computation (e.g. gradient explosion).
+    In practice, the location is computed according to
+
+    .. math::
+        loc = tanh(loc / upscale) * upscale.
+
+    This behavior can be disabled by switching off the tanh_loc parameter (see below).
+
+
+    Args:
+        loc (torch.Tensor): normal distribution location parameter
+        scale (torch.Tensor, float, or callable): normal distribution sigma parameter (square root of variance).
+            Can be a tensor, a float, or a callable that takes the ``loc`` tensor as input and returns the scale tensor.
+            Using a callable (e.g., ``torch.ones_like`` or ``functools.partial(torch.full_like, fill_value=0.1)``)
+            avoids explicit device transfers like ``torch.tensor(val, device=device)`` and prevents graph breaks
+            in :func:`torch.compile`.
+        upscale (torch.Tensor or number, optional): 'a' scaling factor in the formula:
+
+            .. math::
+                loc = tanh(loc / upscale) * upscale.
+
+            Default is 5.0
+
+        tanh_loc (bool, optional): if ``True``, the above formula is used for
+            the location scaling, otherwise the raw value
+            is kept. Default is ``False``;
+
+    Example:
+        >>> import torch
+        >>> from functools import partial
+        >>> from torchrl.modules.distributions import IndependentNormal
+        >>> loc = torch.zeros(3, 4)
+        >>> # Using a callable scale avoids device transfers and graph breaks in torch.compile
+        >>> dist = IndependentNormal(loc, scale=torch.ones_like)
+        >>> # For a custom scale value, use partial to create a callable
+        >>> dist = IndependentNormal(loc, scale=partial(torch.full_like, fill_value=0.1))
+        >>> sample = dist.sample()
+        >>> sample.shape
+        torch.Size([3, 4])
+
+    """
+
+    num_params: int = 2
+
+    def __init__(
+        self,
+        loc: torch.Tensor,
+        scale: torch.Tensor | float | Callable[[torch.Tensor], torch.Tensor],
+        upscale: float = 5.0,
+        tanh_loc: bool = False,
+        event_dim: int = 1,
+        **kwargs,
+    ):
+        self.tanh_loc = tanh_loc
+        self.upscale = upscale
+        self._event_dim = event_dim
+        self._kwargs = kwargs
+        # Support callable scale (e.g., torch.ones_like) for compile-friendliness
+        if callable(scale) and not isinstance(scale, torch.Tensor):
+            scale = scale(loc)
+        elif not isinstance(scale, torch.Tensor):
+            scale = torch.as_tensor(scale, device=loc.device, dtype=loc.dtype)
+        elif scale.device != loc.device:
+            scale = scale.to(loc.device, non_blocking=loc.device.type == "cuda")
+        super().__init__(D.Normal(loc, scale, **kwargs), event_dim)
+
+    def update(self, loc, scale):
+        if self.tanh_loc:
+            loc = self.upscale * (loc / self.upscale).tanh()
+        # Support callable scale (e.g., torch.ones_like) for compile-friendliness
+        if callable(scale) and not isinstance(scale, torch.Tensor):
+            scale = scale(loc)
+        elif not isinstance(scale, torch.Tensor):
+            scale = torch.as_tensor(scale, device=loc.device, dtype=loc.dtype)
+        elif scale.device != loc.device:
+            scale = scale.to(loc.device, non_blocking=loc.device.type == "cuda")
+        super().__init__(D.Normal(loc, scale, **self._kwargs), self._event_dim)
+
+    @property
+    def mode(self):
+        return self.base_dist.mean
+
+    @property
+    def deterministic_sample(self):
+        return self.mean
+
+
+class SafeTanhTransform(D.TanhTransform):
+    """TanhTransform subclass that ensures that the transformation is numerically invertible."""
+
+    def _call(self, x: torch.Tensor) -> torch.Tensor:
+        return safetanh_noeps(x)
+
+    def _inverse(self, y: torch.Tensor) -> torch.Tensor:
+        return safeatanh_noeps(y)
+
+    @property
+    def inv(self):
+        inv = None
+        if self._inv is not None:
+            inv = self._inv()
+        if inv is None:
+            inv = _InverseTransform(self)
+            if not is_compiling():
+                self._inv = weakref.ref(inv)
+        return inv
+
+
+class NormalParamWrapper(nn.Module):  # noqa: D101
+    def __init__(
+        self,
+        operator: nn.Module,
+        scale_mapping: str = "biased_softplus_1.0",
+        scale_lb: Number = 1e-4,
+    ) -> None:
+        raise RuntimeError(
+            "NormalParamWrapper has been deprecated in favor of `tensordict.nn.NormalParamExtractor`. Use this class instead."
+        )
+
+
+class TruncatedNormal(D.Independent):
+    """Implements a Truncated Normal distribution with location scaling.
+
+    Location scaling prevents the location from being "too far" from 0, which ultimately
+    leads to numerically unstable samples and poor gradient computation (e.g. gradient explosion).
+    In practice, the location is computed according to
+
+    .. math::
+        loc = tanh(loc / upscale) * upscale.
+
+    This behavior can be disabled by switching off the tanh_loc parameter (see below).
+
+
+    Args:
+        loc (torch.Tensor): normal distribution location parameter
+        scale (torch.Tensor): normal distribution sigma parameter (square root of variance)
+        upscale (torch.Tensor or number, optional): 'a' scaling factor in the formula:
+
+            .. math::
+                loc = tanh(loc / upscale) * upscale.
+
+            Default is 5.0
+
+        low (torch.Tensor or number, optional): minimum value of the distribution. Default = -1.0;
+        high (torch.Tensor or number, optional): maximum value of the distribution. Default = 1.0;
+        tanh_loc (bool, optional): if ``True``, the above formula is used for
+            the location scaling, otherwise the raw value is kept.
+            Default is ``False``;
+    """
+
+    num_params: int = 2
+
+    base_dist: _TruncatedNormal
+
+    arg_constraints = {
+        "loc": constraints.real,
+        "scale": constraints.greater_than(1e-6),
+    }
+
+    def __init__(
+        self,
+        loc: torch.Tensor,
+        scale: torch.Tensor,
+        upscale: torch.Tensor | float = 5.0,
+        low: torch.Tensor | float = -1.0,
+        high: torch.Tensor | float = 1.0,
+        tanh_loc: bool = False,
+    ):
+
+        err_msg = "TruncatedNormal high values must be strictly greater than low values"
+        if isinstance(high, torch.Tensor) or isinstance(low, torch.Tensor):
+            if not (high > low).all():
+                raise RuntimeError(err_msg)
+        elif isinstance(high, Number) and isinstance(low, Number):
+            if not high > low:
+                raise RuntimeError(err_msg)
+        else:
+            if not all(high > low):
+                raise RuntimeError(err_msg)
+
+        if isinstance(high, torch.Tensor):
+            self.non_trivial_max = (high != 1.0).any()
+        else:
+            self.non_trivial_max = high != 1.0
+
+        if isinstance(low, torch.Tensor):
+            self.non_trivial_min = (low != -1.0).any()
+        else:
+            self.non_trivial_min = low != -1.0
+        self.tanh_loc = tanh_loc
+
+        self.device = loc.device
+        self.upscale = torch.as_tensor(upscale, device=self.device)
+
+        high = torch.as_tensor(high, device=self.device)
+        low = torch.as_tensor(low, device=self.device)
+        self.low = low
+        self.high = high
+        self.update(loc, scale)
+
+    @property
+    def min(self):
+        self._warn_minmax()
+        return self.low
+
+    @property
+    def max(self):
+        self._warn_minmax()
+        return self.high
+
+    def update(self, loc: torch.Tensor, scale: torch.Tensor) -> None:
+        if self.tanh_loc:
+            loc = (loc / self.upscale).tanh() * self.upscale
+        self.loc = loc
+        self.scale = scale
+
+        base_dist = _TruncatedNormal(
+            loc,
+            scale,
+            a=self.low.expand_as(loc),
+            b=self.high.expand_as(scale),
+            device=self.device,
+        )
+        super().__init__(base_dist, 1, validate_args=False)
+
+    @property
+    def mode(self):
+        m = self.base_dist.loc
+        a = self.base_dist._non_std_a + self.base_dist._dtype_min_gt_0
+        b = self.base_dist._non_std_b - self.base_dist._dtype_min_gt_0
+        m = torch.min(torch.stack([m, b], -1), dim=-1)[0]
+        return torch.max(torch.stack([m, a], -1), dim=-1)[0]
+
+    @property
+    def deterministic_sample(self):
+        return self.mean
+
+    def log_prob(self, value, **kwargs):
+        above_or_below = (self.low > value) | (self.high < value)
+        a = self.base_dist._non_std_a + self.base_dist._dtype_min_gt_0
+        a = a.expand_as(value)
+        b = self.base_dist._non_std_b - self.base_dist._dtype_min_gt_0
+        b = b.expand_as(value)
+        value = torch.min(torch.stack([value, b], -1), dim=-1)[0]
+        value = torch.max(torch.stack([value, a], -1), dim=-1)[0]
+        lp = super().log_prob(value, **kwargs)
+        if above_or_below.any():
+            if self.event_shape:
+                above_or_below = above_or_below.flatten(-len(self.event_shape), -1).any(
+                    -1
+                )
+            lp = torch.masked_fill(
+                lp,
+                above_or_below.expand_as(lp),
+                torch.tensor(-float("inf"), device=lp.device, dtype=lp.dtype),
+            )
+        return lp
+
+
+class _PatchedComposeTransform(D.ComposeTransform):
+    @property
+    def inv(self):
+        inv = None
+        if self._inv is not None:
+            inv = self._inv()
+        if inv is None:
+            inv = _PatchedComposeTransform([p.inv for p in reversed(self.parts)])
+            if not is_compiling():
+                self._inv = weakref.ref(inv)
+                inv._inv = weakref.ref(self)
+        return inv
+
+
+class _PatchedAffineTransform(D.AffineTransform):
+    @property
+    def inv(self):
+        inv = None
+        if self._inv is not None:
+            inv = self._inv()
+        if inv is None:
+            inv = _InverseTransform(self)
+            if not is_compiling():
+                self._inv = weakref.ref(inv)
+        return inv
+
+
+class TanhNormal(FasterTransformedDistribution):
+    """Implements a TanhNormal distribution with location scaling.
+
+    Location scaling prevents the location from being "too far" from 0 when a
+    ``TanhTransform`` is applied, which ultimately
+    leads to numerically unstable samples and poor gradient computation
+    (e.g. gradient explosion).
+    In practice, with location scaling the location is computed according to
+
+    .. math::
+        loc = tanh(loc / upscale) * upscale.
+
+
+    Args:
+        loc (torch.Tensor): normal distribution location parameter
+        scale (torch.Tensor, float, or callable): normal distribution sigma parameter (square root of variance).
+            Can be a tensor, a float, or a callable that takes the ``loc`` tensor as input and returns the scale tensor.
+            Using a callable (e.g., ``torch.ones_like`` or ``functools.partial(torch.full_like, fill_value=0.1)``)
+            avoids explicit device transfers like ``torch.tensor(val, device=device)`` and prevents graph breaks
+            in :func:`torch.compile`.
+        upscale (torch.Tensor or number): 'a' scaling factor in the formula:
+
+            .. math::
+                loc = tanh(loc / upscale) * upscale.
+
+        low (torch.Tensor or number, optional): minimum value of the distribution. Default is -1.0;
+        high (torch.Tensor or number, optional): maximum value of the distribution. Default is 1.0;
+        event_dims (int, optional): number of dimensions describing the action.
+            Default is 1. Setting ``event_dims`` to ``0`` will result in a log-probability that has the same shape
+            as the input, ``1`` will reduce (sum over) the last dimension, ``2`` the last two etc.
+        tanh_loc (bool, optional): if ``True``, the above formula is used for the location scaling, otherwise the raw
+            value is kept. Default is ``False``;
+        safe_tanh (bool, optional): if ``True``, the Tanh transform is done "safely", to avoid numerical overflows.
+            This will currently break with :func:`torch.compile`.
+
+    Example:
+        >>> import torch
+        >>> from functools import partial
+        >>> from torchrl.modules.distributions import TanhNormal
+        >>> loc = torch.zeros(3, 4)
+        >>> # Using a callable scale avoids device transfers and graph breaks in torch.compile
+        >>> dist = TanhNormal(loc, scale=torch.ones_like)
+        >>> # For a custom scale value, use partial to create a callable
+        >>> dist = TanhNormal(loc, scale=partial(torch.full_like, fill_value=0.1))
+        >>> sample = dist.sample()
+        >>> sample.shape
+        torch.Size([3, 4])
+
+    """
+
+    arg_constraints = {
+        "loc": constraints.real,
+        "scale": constraints.greater_than(1e-6),
+    }
+
+    num_params = 2
+
+    def __init__(
+        self,
+        loc: torch.Tensor,
+        scale: torch.Tensor | float | Callable[[torch.Tensor], torch.Tensor],
+        upscale: torch.Tensor | Number = 5.0,
+        low: torch.Tensor | Number = -1.0,
+        high: torch.Tensor | Number = 1.0,
+        event_dims: int | None = None,
+        tanh_loc: bool = False,
+        safe_tanh: bool = True,
+    ):
+        if not isinstance(loc, torch.Tensor):
+            loc = torch.as_tensor(loc, dtype=torch.get_default_dtype())
+        _non_blocking = loc.device.type == "cuda"
+        # Support callable scale (e.g., torch.ones_like) for compile-friendliness
+        if callable(scale) and not isinstance(scale, torch.Tensor):
+            scale = scale(loc)
+        elif not isinstance(scale, torch.Tensor):
+            scale = torch.as_tensor(scale, device=loc.device, dtype=loc.dtype)
+        elif scale.device != loc.device:
+            scale = scale.to(loc.device, non_blocking=_non_blocking)
+        if event_dims is None:
+            event_dims = min(1, loc.ndim)
+
+        err_msg = "TanhNormal high values must be strictly greater than low values"
+        if not is_compiling() and not safe_is_current_stream_capturing():
+            if isinstance(high, torch.Tensor) or isinstance(low, torch.Tensor):
+                if not (high > low).all():
+                    raise RuntimeError(err_msg)
+            elif isinstance(high, Number) and isinstance(low, Number):
+                if not high > low:
+                    raise RuntimeError(err_msg)
+            else:
+                if not all(high > low):
+                    raise RuntimeError(err_msg)
+
+        if not isinstance(high, torch.Tensor):
+            high = torch.as_tensor(high, device=loc.device)
+        elif high.device != loc.device:
+            high = high.to(loc.device, non_blocking=_non_blocking)
+        if not isinstance(low, torch.Tensor):
+            low = torch.as_tensor(low, device=loc.device)
+        elif low.device != loc.device:
+            low = low.to(loc.device, non_blocking=_non_blocking)
+        if not is_compiling() and not safe_is_current_stream_capturing():
+            self.non_trivial_max = (high != 1.0).any()
+            self.non_trivial_min = (low != -1.0).any()
+        else:
+            self.non_trivial_max = self.non_trivial_min = True
+
+        self.tanh_loc = tanh_loc
+        self._event_dims = event_dims
+
+        self.device = loc.device
+        self.upscale = (
+            upscale
+            if not isinstance(upscale, torch.Tensor)
+            else upscale.to(self.device, non_blocking=_non_blocking)
+        )
+
+        low = low.to(loc.device, non_blocking=_non_blocking)
+        self.low = low
+        self.high = high
+
+        if safe_tanh:
+            if is_compiling() and TORCH_VERSION_PRE_2_6:
+                _err_compile_safetanh()
+            t = SafeTanhTransform()
+        else:
+            t = D.TanhTransform()
+        # t = D.TanhTransform()
+        if is_compiling() or (self.non_trivial_max or self.non_trivial_min):
+            t = _PatchedComposeTransform(
+                [
+                    t,
+                    _PatchedAffineTransform(
+                        loc=(high + low) / 2, scale=(high - low) / 2
+                    ),
+                ]
+            )
+        self._t = t
+
+        self.update(loc, scale)
+
+    @property
+    def min(self):
+        self._warn_minmax()
+        return self.low
+
+    @property
+    def max(self):
+        self._warn_minmax()
+        return self.high
+
+    def update(self, loc: torch.Tensor, scale: torch.Tensor) -> None:
+        if self.tanh_loc:
+            loc = (loc / self.upscale).tanh() * self.upscale
+        # loc must be rescaled if tanh_loc
+        if is_compiling() or (self.non_trivial_max or self.non_trivial_min):
+            loc = loc + (self.high - self.low) / 2 + self.low
+        # Support callable scale (e.g., torch.ones_like) for compile-friendliness
+        if callable(scale) and not isinstance(scale, torch.Tensor):
+            scale = scale(loc)
+        elif not isinstance(scale, torch.Tensor):
+            scale = torch.as_tensor(scale, device=loc.device, dtype=loc.dtype)
+        elif scale.device != loc.device:
+            scale = scale.to(loc.device, non_blocking=loc.device.type == "cuda")
+        self.loc = loc
+        self.scale = scale
+
+        if (
+            hasattr(self, "base_dist")
+            and (self.root_dist.loc.shape == self.loc.shape)
+            and (self.root_dist.scale.shape == self.scale.shape)
+        ):
+            self.root_dist.loc = self.loc
+            self.root_dist.scale = self.scale
+        else:
+            if self._event_dims > 0:
+                base = D.Independent(D.Normal(self.loc, self.scale), self._event_dims)
+                super().__init__(base, self._t)
+            else:
+                base = D.Normal(self.loc, self.scale)
+                super().__init__(base, self._t)
+
+    @property
+    def support(self):
+        return D.constraints.real
+
+    @property
+    def root_dist(self):
+        bd = self
+        while hasattr(bd, "base_dist"):
+            bd = bd.base_dist
+        return bd
+
+    @property
+    def mode(self):
+        raise RuntimeError(
+            f"The distribution {type(self).__name__} has no analytical mode. "
+            "Use ExplorationMode.DETERMINISTIC to get a deterministic sample from it."
+        )
+
+    @property
+    def deterministic_sample(self):
+        m = self.root_dist.mean
+        for t in self.transforms:
+            m = t(m)
+        return m
+
+    @torch.enable_grad()
+    def get_mode(self):
+        """Computes an estimation of the mode using the Adam optimizer."""
+        # Get starting point
+        m = self.sample((1000,)).mean(0)
+        m = torch.nn.Parameter(m.clamp(self.low, self.high).detach())
+        optim = torch.optim.Adam((m,), lr=1e-2)
+        self_copy = type(self)(
+            loc=self.loc.detach(),
+            scale=self.scale.detach(),
+            low=self.low.detach(),
+            high=self.high.detach(),
+            event_dims=self._event_dims,
+            upscale=self.upscale,
+            tanh_loc=False,
+        )
+        for _ in range(200):
+            lp = -self_copy.log_prob(m)
+            lp.mean().backward()
+            mc = m.clone().detach()
+            m.grad.clamp_max_(1)
+            optim.step()
+            optim.zero_grad()
+            m.data.clamp_(self_copy.low, self_copy.high)
+            nans = m.isnan()
+            if nans.any():
+                m.data = torch.where(nans, mc, m.data)
+            if (m - mc).norm() < 1e-3:
+                break
+        return m.detach()
+
+    @property
+    def mean(self):
+        raise NotImplementedError(
+            f"{type(self).__name__} does not have a closed form formula for the average. "
+            "An estimate of this value can be computed using dist.sample((N,)).mean(dim=0), "
+            "where N is a large number of samples."
+        )
+
+
+def uniform_sample_tanhnormal(dist: TanhNormal, size=None) -> torch.Tensor:
+    """Defines what uniform sampling looks like for a TanhNormal distribution.
+
+    Args:
+        dist (TanhNormal): distribution defining the space where the sampling should occur.
+        size (torch.Size): batch-size of the output tensor
+
+    Returns:
+         a tensor sampled uniformly in the boundaries defined by the input distribution.
+
+    """
+    if size is None:
+        size = torch.Size([])
+    return torch.rand_like(dist.sample(size)) * (dist.max - dist.min) + dist.min
+
+
+class Delta(D.Distribution):
+    """Delta distribution.
+
+    Args:
+        param (torch.Tensor): parameter of the delta distribution;
+        atol (number, optional): absolute tolerance to consider that a tensor matches the distribution parameter;
+            Default is 1e-6
+        rtol (number, optional): relative tolerance to consider that a tensor matches the distribution parameter;
+            Default is 1e-6
+        batch_shape (torch.Size, optional): batch shape;
+        event_shape (torch.Size, optional): shape of the outcome.
+
+    """
+
+    arg_constraints: dict = {}
+
+    def __init__(
+        self,
+        param: torch.Tensor,
+        atol: float = 1e-6,
+        rtol: float = 1e-6,
+        batch_shape: torch.Size | Sequence[int] = None,
+        event_shape: torch.Size | Sequence[int] = None,
+    ):
+        if batch_shape is None:
+            batch_shape = torch.Size([])
+        if event_shape is None:
+            event_shape = torch.Size([])
+        self.update(param)
+        self.atol = atol
+        self.rtol = rtol
+        if not len(batch_shape) and not len(event_shape):
+            batch_shape = param.shape[:-1]
+            event_shape = param.shape[-1:]
+        super().__init__(batch_shape=batch_shape, event_shape=event_shape)
+
+    def expand(self, batch_shape: torch.Size, _instance=None):
+        if self.batch_shape != tuple(batch_shape):
+            return type(self)(
+                self.param.expand((*batch_shape, *self.event_shape)),
+                atol=self.atol,
+                rtol=self.rtol,
+            )
+        return self
+
+    def update(self, param):
+        self.param = param
+
+    def _is_equal(self, value: torch.Tensor) -> torch.Tensor:
+        param = self.param.expand_as(value)
+        is_equal = abs(value - param) < self.atol + self.rtol * abs(param)
+        for i in range(-1, -len(self.event_shape) - 1, -1):
+            is_equal = is_equal.all(i)
+        return is_equal
+
+    def log_prob(self, value: torch.Tensor) -> torch.Tensor:
+        is_equal = self._is_equal(value)
+        out = torch.zeros_like(is_equal, dtype=value.dtype)
+        out.masked_fill_(is_equal, np.inf)
+        out.masked_fill_(~is_equal, -np.inf)
+        return out
+
+    @torch.no_grad()
+    def sample(self, size=None) -> torch.Tensor:
+        if size is None:
+            size = torch.Size([])
+        return self.param.expand(*size, *self.param.shape)
+
+    def rsample(self, size=None) -> torch.Tensor:
+        if size is None:
+            size = torch.Size([])
+        return self.param.expand(*size, *self.param.shape)
+
+    @property
+    def mode(self) -> torch.Tensor:
+        return self.param
+
+    @property
+    def deterministic_sample(self):
+        return self.mean
+
+    @property
+    def mean(self) -> torch.Tensor:
+        return self.param
+
+
+class TanhDelta(FasterTransformedDistribution):
+    """Implements a Tanh-transformed Delta distribution.
+
+    Args:
+        param (torch.Tensor): parameter of the delta distribution;
+        low (torch.Tensor or number, optional): minimum value of the distribution. Default is -1.0;
+        high (torch.Tensor or number, optional): maximum value of the distribution. Default is 1.0;
+        event_dims (int, optional): number of dimensions describing the action.
+            Default is 1;
+        atol (number, optional): absolute tolerance to consider that a tensor matches the distribution parameter;
+            Default is 1e-6
+        rtol (number, optional): relative tolerance to consider that a tensor matches the distribution parameter;
+            Default is 1e-6
+        safe (bool, optional): if ``True``, the Tanh transform is done "safely",
+            to avoid numerical overflows. Default is ``True``;
+
+    """
+
+    arg_constraints = {
+        "loc": constraints.real,
+    }
+
+    def __init__(
+        self,
+        param: torch.Tensor,
+        low: torch.Tensor | float = -1.0,
+        high: torch.Tensor | float = 1.0,
+        event_dims: int = 1,
+        atol: float = 1e-6,
+        rtol: float = 1e-6,
+        safe: bool = True,
+    ):
+        minmax_msg = "high value has been found to be equal or less than low value"
+        if isinstance(high, torch.Tensor) or isinstance(low, torch.Tensor):
+            if is_compiling():
+                assert (high > low).all()
+            else:
+                if not (high > low).all():
+                    raise ValueError(minmax_msg)
+        elif isinstance(high, Number) and isinstance(low, Number):
+            if is_compiling():
+                assert high > low
+            elif high <= low:
+                raise ValueError(minmax_msg)
+        else:
+            if not all(high > low):
+                raise ValueError(minmax_msg)
+
+        if safe:
+            if is_compiling():
+                _err_compile_safetanh()
+            t = SafeTanhTransform()
+        else:
+            t = torch.distributions.TanhTransform()
+        non_trivial_min = is_compiling() or (
+            (isinstance(low, torch.Tensor) and (low != -1.0).any())
+            or (not isinstance(low, torch.Tensor) and low != -1.0)
+        )
+        non_trivial_max = is_compiling() or (
+            (isinstance(high, torch.Tensor) and (high != 1.0).any())
+            or (not isinstance(high, torch.Tensor) and high != 1.0)
+        )
+        self.non_trivial = non_trivial_min or non_trivial_max
+
+        self.low = _cast_device(low, param.device)
+        self.high = _cast_device(high, param.device)
+        loc = self.update(param)
+
+        if self.non_trivial:
+            t = _PatchedComposeTransform(
+                [
+                    t,
+                    _PatchedAffineTransform(
+                        loc=(self.high + self.low) / 2, scale=(self.high - self.low) / 2
+                    ),
+                ]
+            )
+        event_shape = param.shape[-event_dims:]
+        batch_shape = param.shape[:-event_dims]
+        base = Delta(
+            loc,
+            atol=atol,
+            rtol=rtol,
+            batch_shape=batch_shape,
+            event_shape=event_shape,
+        )
+
+        super().__init__(base, t)
+
+    @property
+    def min(self):
+        self._warn_minmax()
+        return self.low
+
+    @property
+    def max(self):
+        self._warn_minmax()
+        return self.high
+
+    def update(self, net_output: torch.Tensor) -> torch.Tensor | None:
+        loc = net_output
+        if self.non_trivial:
+            device = loc.device
+            shift = _cast_device(self.high - self.low, device)
+            loc = loc + shift / 2 + _cast_device(self.low, device)
+        if hasattr(self, "base_dist"):
+            self.base_dist.update(loc)
+        else:
+            return loc
+
+    @property
+    def mode(self) -> torch.Tensor:
+        mode = self.base_dist.param
+        for t in self.transforms:
+            mode = t(mode)
+        return mode
+
+    @property
+    def deterministic_sample(self):
+        return self.mode
+
+    @property
+    def mean(self) -> torch.Tensor:
+        raise AttributeError("TanhDelta mean has no analytical form.")
+
+
+def _uniform_sample_delta(dist: Delta, size=None) -> torch.Tensor:
+    if size is None:
+        size = torch.Size([])
+    return torch.randn_like(dist.sample(size))
+
+
+uniform_sample_delta = _uniform_sample_delta
+
+
+def _err_compile_safetanh():
+    raise RuntimeError(
+        "safe_tanh=True in TanhNormal is not compatible with torch.compile with torch pre 2.6.0. "
+        "To deactivate it, pass safe_tanh=False. "
+        "If you are using a ProbabilisticTensorDictModule, this can be done via "
+        "`distribution_kwargs={'safe_tanh': False}`. "
+        "See https://github.com/pytorch/pytorch/issues/133529 for more details."
+    )
+
+
+_warn_compile_safetanh = assume_constant_result(_err_compile_safetanh)