torchrl 0.11.0__cp314-cp314t-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/benchmark_batched_envs.py +104 -0
- benchmarks/conftest.py +91 -0
- benchmarks/ecosystem/gym_env_throughput.py +321 -0
- benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
- benchmarks/requirements.txt +7 -0
- benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
- benchmarks/test_collectors_benchmark.py +240 -0
- benchmarks/test_compressed_storage_benchmark.py +145 -0
- benchmarks/test_envs_benchmark.py +133 -0
- benchmarks/test_llm.py +101 -0
- benchmarks/test_non_tensor_env_benchmark.py +70 -0
- benchmarks/test_objectives_benchmarks.py +1199 -0
- benchmarks/test_replaybuffer_benchmark.py +254 -0
- sota-check/README.md +35 -0
- sota-implementations/README.md +142 -0
- sota-implementations/a2c/README.md +39 -0
- sota-implementations/a2c/a2c_atari.py +291 -0
- sota-implementations/a2c/a2c_mujoco.py +273 -0
- sota-implementations/a2c/utils_atari.py +240 -0
- sota-implementations/a2c/utils_mujoco.py +160 -0
- sota-implementations/bandits/README.md +7 -0
- sota-implementations/bandits/dqn.py +126 -0
- sota-implementations/cql/cql_offline.py +198 -0
- sota-implementations/cql/cql_online.py +249 -0
- sota-implementations/cql/discrete_cql_offline.py +180 -0
- sota-implementations/cql/discrete_cql_online.py +227 -0
- sota-implementations/cql/utils.py +471 -0
- sota-implementations/crossq/crossq.py +271 -0
- sota-implementations/crossq/utils.py +320 -0
- sota-implementations/ddpg/ddpg.py +231 -0
- sota-implementations/ddpg/utils.py +325 -0
- sota-implementations/decision_transformer/dt.py +163 -0
- sota-implementations/decision_transformer/lamb.py +167 -0
- sota-implementations/decision_transformer/online_dt.py +178 -0
- sota-implementations/decision_transformer/utils.py +562 -0
- sota-implementations/discrete_sac/discrete_sac.py +243 -0
- sota-implementations/discrete_sac/utils.py +324 -0
- sota-implementations/dqn/README.md +30 -0
- sota-implementations/dqn/dqn_atari.py +272 -0
- sota-implementations/dqn/dqn_cartpole.py +236 -0
- sota-implementations/dqn/utils_atari.py +132 -0
- sota-implementations/dqn/utils_cartpole.py +90 -0
- sota-implementations/dreamer/README.md +129 -0
- sota-implementations/dreamer/dreamer.py +586 -0
- sota-implementations/dreamer/dreamer_utils.py +1107 -0
- sota-implementations/expert-iteration/README.md +352 -0
- sota-implementations/expert-iteration/ei_utils.py +770 -0
- sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
- sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
- sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
- sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
- sota-implementations/gail/gail.py +327 -0
- sota-implementations/gail/gail_utils.py +68 -0
- sota-implementations/gail/ppo_utils.py +157 -0
- sota-implementations/grpo/README.md +273 -0
- sota-implementations/grpo/grpo-async.py +437 -0
- sota-implementations/grpo/grpo-sync.py +435 -0
- sota-implementations/grpo/grpo_utils.py +843 -0
- sota-implementations/grpo/requirements_gsm8k.txt +11 -0
- sota-implementations/grpo/requirements_ifeval.txt +16 -0
- sota-implementations/impala/README.md +33 -0
- sota-implementations/impala/impala_multi_node_ray.py +292 -0
- sota-implementations/impala/impala_multi_node_submitit.py +284 -0
- sota-implementations/impala/impala_single_node.py +261 -0
- sota-implementations/impala/utils.py +184 -0
- sota-implementations/iql/discrete_iql.py +230 -0
- sota-implementations/iql/iql_offline.py +164 -0
- sota-implementations/iql/iql_online.py +225 -0
- sota-implementations/iql/utils.py +437 -0
- sota-implementations/multiagent/README.md +74 -0
- sota-implementations/multiagent/iql.py +237 -0
- sota-implementations/multiagent/maddpg_iddpg.py +266 -0
- sota-implementations/multiagent/mappo_ippo.py +267 -0
- sota-implementations/multiagent/qmix_vdn.py +271 -0
- sota-implementations/multiagent/sac.py +337 -0
- sota-implementations/multiagent/utils/__init__.py +4 -0
- sota-implementations/multiagent/utils/logging.py +151 -0
- sota-implementations/multiagent/utils/utils.py +43 -0
- sota-implementations/ppo/README.md +29 -0
- sota-implementations/ppo/ppo_atari.py +305 -0
- sota-implementations/ppo/ppo_mujoco.py +293 -0
- sota-implementations/ppo/utils_atari.py +238 -0
- sota-implementations/ppo/utils_mujoco.py +152 -0
- sota-implementations/ppo_trainer/train.py +21 -0
- sota-implementations/redq/README.md +7 -0
- sota-implementations/redq/redq.py +199 -0
- sota-implementations/redq/utils.py +1060 -0
- sota-implementations/sac/sac-async.py +266 -0
- sota-implementations/sac/sac.py +239 -0
- sota-implementations/sac/utils.py +381 -0
- sota-implementations/sac_trainer/train.py +16 -0
- sota-implementations/td3/td3.py +254 -0
- sota-implementations/td3/utils.py +319 -0
- sota-implementations/td3_bc/td3_bc.py +177 -0
- sota-implementations/td3_bc/utils.py +251 -0
- torchrl/.dylibs/libc++.1.0.dylib +0 -0
- torchrl/__init__.py +144 -0
- torchrl/_extension.py +74 -0
- torchrl/_torchrl.cpython-314t-darwin.so +0 -0
- torchrl/_utils.py +1431 -0
- torchrl/collectors/__init__.py +48 -0
- torchrl/collectors/_base.py +1058 -0
- torchrl/collectors/_constants.py +88 -0
- torchrl/collectors/_multi_async.py +324 -0
- torchrl/collectors/_multi_base.py +1805 -0
- torchrl/collectors/_multi_sync.py +464 -0
- torchrl/collectors/_runner.py +581 -0
- torchrl/collectors/_single.py +2009 -0
- torchrl/collectors/_single_async.py +259 -0
- torchrl/collectors/collectors.py +62 -0
- torchrl/collectors/distributed/__init__.py +32 -0
- torchrl/collectors/distributed/default_configs.py +133 -0
- torchrl/collectors/distributed/generic.py +1306 -0
- torchrl/collectors/distributed/ray.py +1092 -0
- torchrl/collectors/distributed/rpc.py +1006 -0
- torchrl/collectors/distributed/sync.py +731 -0
- torchrl/collectors/distributed/utils.py +160 -0
- torchrl/collectors/llm/__init__.py +10 -0
- torchrl/collectors/llm/base.py +494 -0
- torchrl/collectors/llm/ray_collector.py +275 -0
- torchrl/collectors/llm/utils.py +36 -0
- torchrl/collectors/llm/weight_update/__init__.py +10 -0
- torchrl/collectors/llm/weight_update/vllm.py +348 -0
- torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
- torchrl/collectors/utils.py +433 -0
- torchrl/collectors/weight_update.py +591 -0
- torchrl/csrc/numpy_utils.h +38 -0
- torchrl/csrc/pybind.cpp +27 -0
- torchrl/csrc/segment_tree.h +458 -0
- torchrl/csrc/torch_utils.h +34 -0
- torchrl/csrc/utils.cpp +48 -0
- torchrl/csrc/utils.h +31 -0
- torchrl/data/__init__.py +187 -0
- torchrl/data/datasets/__init__.py +58 -0
- torchrl/data/datasets/atari_dqn.py +878 -0
- torchrl/data/datasets/common.py +281 -0
- torchrl/data/datasets/d4rl.py +489 -0
- torchrl/data/datasets/d4rl_infos.py +187 -0
- torchrl/data/datasets/gen_dgrl.py +375 -0
- torchrl/data/datasets/minari_data.py +643 -0
- torchrl/data/datasets/openml.py +177 -0
- torchrl/data/datasets/openx.py +798 -0
- torchrl/data/datasets/roboset.py +363 -0
- torchrl/data/datasets/utils.py +11 -0
- torchrl/data/datasets/vd4rl.py +432 -0
- torchrl/data/llm/__init__.py +34 -0
- torchrl/data/llm/dataset.py +491 -0
- torchrl/data/llm/history.py +1378 -0
- torchrl/data/llm/prompt.py +198 -0
- torchrl/data/llm/reward.py +225 -0
- torchrl/data/llm/topk.py +186 -0
- torchrl/data/llm/utils.py +543 -0
- torchrl/data/map/__init__.py +21 -0
- torchrl/data/map/hash.py +185 -0
- torchrl/data/map/query.py +204 -0
- torchrl/data/map/tdstorage.py +363 -0
- torchrl/data/map/tree.py +1434 -0
- torchrl/data/map/utils.py +103 -0
- torchrl/data/postprocs/__init__.py +8 -0
- torchrl/data/postprocs/postprocs.py +391 -0
- torchrl/data/replay_buffers/__init__.py +99 -0
- torchrl/data/replay_buffers/checkpointers.py +622 -0
- torchrl/data/replay_buffers/ray_buffer.py +292 -0
- torchrl/data/replay_buffers/replay_buffers.py +2376 -0
- torchrl/data/replay_buffers/samplers.py +2578 -0
- torchrl/data/replay_buffers/scheduler.py +265 -0
- torchrl/data/replay_buffers/storages.py +2412 -0
- torchrl/data/replay_buffers/utils.py +1042 -0
- torchrl/data/replay_buffers/writers.py +781 -0
- torchrl/data/tensor_specs.py +7101 -0
- torchrl/data/utils.py +334 -0
- torchrl/envs/__init__.py +265 -0
- torchrl/envs/async_envs.py +1105 -0
- torchrl/envs/batched_envs.py +3093 -0
- torchrl/envs/common.py +4241 -0
- torchrl/envs/custom/__init__.py +11 -0
- torchrl/envs/custom/chess.py +617 -0
- torchrl/envs/custom/llm.py +214 -0
- torchrl/envs/custom/pendulum.py +401 -0
- torchrl/envs/custom/san_moves.txt +29274 -0
- torchrl/envs/custom/tictactoeenv.py +288 -0
- torchrl/envs/env_creator.py +263 -0
- torchrl/envs/gym_like.py +752 -0
- torchrl/envs/libs/__init__.py +68 -0
- torchrl/envs/libs/_gym_utils.py +326 -0
- torchrl/envs/libs/brax.py +846 -0
- torchrl/envs/libs/dm_control.py +544 -0
- torchrl/envs/libs/envpool.py +447 -0
- torchrl/envs/libs/gym.py +2239 -0
- torchrl/envs/libs/habitat.py +138 -0
- torchrl/envs/libs/isaac_lab.py +87 -0
- torchrl/envs/libs/isaacgym.py +203 -0
- torchrl/envs/libs/jax_utils.py +166 -0
- torchrl/envs/libs/jumanji.py +963 -0
- torchrl/envs/libs/meltingpot.py +599 -0
- torchrl/envs/libs/openml.py +153 -0
- torchrl/envs/libs/openspiel.py +652 -0
- torchrl/envs/libs/pettingzoo.py +1042 -0
- torchrl/envs/libs/procgen.py +351 -0
- torchrl/envs/libs/robohive.py +429 -0
- torchrl/envs/libs/smacv2.py +645 -0
- torchrl/envs/libs/unity_mlagents.py +891 -0
- torchrl/envs/libs/utils.py +147 -0
- torchrl/envs/libs/vmas.py +813 -0
- torchrl/envs/llm/__init__.py +63 -0
- torchrl/envs/llm/chat.py +730 -0
- torchrl/envs/llm/datasets/README.md +4 -0
- torchrl/envs/llm/datasets/__init__.py +17 -0
- torchrl/envs/llm/datasets/gsm8k.py +353 -0
- torchrl/envs/llm/datasets/ifeval.py +274 -0
- torchrl/envs/llm/envs.py +789 -0
- torchrl/envs/llm/libs/README.md +3 -0
- torchrl/envs/llm/libs/__init__.py +8 -0
- torchrl/envs/llm/libs/mlgym.py +869 -0
- torchrl/envs/llm/reward/__init__.py +10 -0
- torchrl/envs/llm/reward/gsm8k.py +324 -0
- torchrl/envs/llm/reward/ifeval/README.md +13 -0
- torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
- torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
- torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
- torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
- torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
- torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
- torchrl/envs/llm/transforms/__init__.py +55 -0
- torchrl/envs/llm/transforms/browser.py +292 -0
- torchrl/envs/llm/transforms/dataloading.py +859 -0
- torchrl/envs/llm/transforms/format.py +73 -0
- torchrl/envs/llm/transforms/kl.py +1544 -0
- torchrl/envs/llm/transforms/policy_version.py +189 -0
- torchrl/envs/llm/transforms/reason.py +323 -0
- torchrl/envs/llm/transforms/tokenizer.py +321 -0
- torchrl/envs/llm/transforms/tools.py +1955 -0
- torchrl/envs/model_based/__init__.py +9 -0
- torchrl/envs/model_based/common.py +180 -0
- torchrl/envs/model_based/dreamer.py +112 -0
- torchrl/envs/transforms/__init__.py +147 -0
- torchrl/envs/transforms/functional.py +48 -0
- torchrl/envs/transforms/gym_transforms.py +203 -0
- torchrl/envs/transforms/module.py +341 -0
- torchrl/envs/transforms/r3m.py +372 -0
- torchrl/envs/transforms/ray_service.py +663 -0
- torchrl/envs/transforms/rb_transforms.py +214 -0
- torchrl/envs/transforms/transforms.py +11835 -0
- torchrl/envs/transforms/utils.py +94 -0
- torchrl/envs/transforms/vc1.py +307 -0
- torchrl/envs/transforms/vecnorm.py +845 -0
- torchrl/envs/transforms/vip.py +407 -0
- torchrl/envs/utils.py +1718 -0
- torchrl/envs/vec_envs.py +11 -0
- torchrl/modules/__init__.py +206 -0
- torchrl/modules/distributions/__init__.py +73 -0
- torchrl/modules/distributions/continuous.py +830 -0
- torchrl/modules/distributions/discrete.py +908 -0
- torchrl/modules/distributions/truncated_normal.py +187 -0
- torchrl/modules/distributions/utils.py +233 -0
- torchrl/modules/llm/__init__.py +62 -0
- torchrl/modules/llm/backends/__init__.py +65 -0
- torchrl/modules/llm/backends/vllm/__init__.py +94 -0
- torchrl/modules/llm/backends/vllm/_models.py +46 -0
- torchrl/modules/llm/backends/vllm/base.py +72 -0
- torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
- torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
- torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
- torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
- torchrl/modules/llm/policies/__init__.py +28 -0
- torchrl/modules/llm/policies/common.py +1809 -0
- torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
- torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
- torchrl/modules/llm/utils.py +23 -0
- torchrl/modules/mcts/__init__.py +21 -0
- torchrl/modules/mcts/scores.py +579 -0
- torchrl/modules/models/__init__.py +86 -0
- torchrl/modules/models/batchrenorm.py +119 -0
- torchrl/modules/models/decision_transformer.py +179 -0
- torchrl/modules/models/exploration.py +731 -0
- torchrl/modules/models/llm.py +156 -0
- torchrl/modules/models/model_based.py +596 -0
- torchrl/modules/models/models.py +1712 -0
- torchrl/modules/models/multiagent.py +1067 -0
- torchrl/modules/models/recipes/impala.py +185 -0
- torchrl/modules/models/utils.py +162 -0
- torchrl/modules/planners/__init__.py +10 -0
- torchrl/modules/planners/cem.py +228 -0
- torchrl/modules/planners/common.py +73 -0
- torchrl/modules/planners/mppi.py +265 -0
- torchrl/modules/tensordict_module/__init__.py +89 -0
- torchrl/modules/tensordict_module/actors.py +2457 -0
- torchrl/modules/tensordict_module/common.py +529 -0
- torchrl/modules/tensordict_module/exploration.py +814 -0
- torchrl/modules/tensordict_module/probabilistic.py +321 -0
- torchrl/modules/tensordict_module/rnn.py +1639 -0
- torchrl/modules/tensordict_module/sequence.py +132 -0
- torchrl/modules/tensordict_module/world_models.py +34 -0
- torchrl/modules/utils/__init__.py +38 -0
- torchrl/modules/utils/mappings.py +9 -0
- torchrl/modules/utils/utils.py +89 -0
- torchrl/objectives/__init__.py +78 -0
- torchrl/objectives/a2c.py +659 -0
- torchrl/objectives/common.py +753 -0
- torchrl/objectives/cql.py +1346 -0
- torchrl/objectives/crossq.py +710 -0
- torchrl/objectives/ddpg.py +453 -0
- torchrl/objectives/decision_transformer.py +371 -0
- torchrl/objectives/deprecated.py +516 -0
- torchrl/objectives/dqn.py +683 -0
- torchrl/objectives/dreamer.py +488 -0
- torchrl/objectives/functional.py +48 -0
- torchrl/objectives/gail.py +258 -0
- torchrl/objectives/iql.py +996 -0
- torchrl/objectives/llm/__init__.py +30 -0
- torchrl/objectives/llm/grpo.py +846 -0
- torchrl/objectives/llm/sft.py +482 -0
- torchrl/objectives/multiagent/__init__.py +8 -0
- torchrl/objectives/multiagent/qmixer.py +396 -0
- torchrl/objectives/ppo.py +1669 -0
- torchrl/objectives/redq.py +683 -0
- torchrl/objectives/reinforce.py +530 -0
- torchrl/objectives/sac.py +1580 -0
- torchrl/objectives/td3.py +570 -0
- torchrl/objectives/td3_bc.py +625 -0
- torchrl/objectives/utils.py +782 -0
- torchrl/objectives/value/__init__.py +28 -0
- torchrl/objectives/value/advantages.py +1956 -0
- torchrl/objectives/value/functional.py +1459 -0
- torchrl/objectives/value/utils.py +360 -0
- torchrl/record/__init__.py +17 -0
- torchrl/record/loggers/__init__.py +23 -0
- torchrl/record/loggers/common.py +48 -0
- torchrl/record/loggers/csv.py +226 -0
- torchrl/record/loggers/mlflow.py +142 -0
- torchrl/record/loggers/tensorboard.py +139 -0
- torchrl/record/loggers/trackio.py +163 -0
- torchrl/record/loggers/utils.py +78 -0
- torchrl/record/loggers/wandb.py +214 -0
- torchrl/record/recorder.py +554 -0
- torchrl/services/__init__.py +79 -0
- torchrl/services/base.py +109 -0
- torchrl/services/ray_service.py +453 -0
- torchrl/testing/__init__.py +107 -0
- torchrl/testing/assertions.py +179 -0
- torchrl/testing/dist_utils.py +122 -0
- torchrl/testing/env_creators.py +227 -0
- torchrl/testing/env_helper.py +35 -0
- torchrl/testing/gym_helpers.py +156 -0
- torchrl/testing/llm_mocks.py +119 -0
- torchrl/testing/mocking_classes.py +2720 -0
- torchrl/testing/modules.py +295 -0
- torchrl/testing/mp_helpers.py +15 -0
- torchrl/testing/ray_helpers.py +293 -0
- torchrl/testing/utils.py +190 -0
- torchrl/trainers/__init__.py +42 -0
- torchrl/trainers/algorithms/__init__.py +11 -0
- torchrl/trainers/algorithms/configs/__init__.py +705 -0
- torchrl/trainers/algorithms/configs/collectors.py +216 -0
- torchrl/trainers/algorithms/configs/common.py +41 -0
- torchrl/trainers/algorithms/configs/data.py +308 -0
- torchrl/trainers/algorithms/configs/envs.py +104 -0
- torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
- torchrl/trainers/algorithms/configs/logging.py +80 -0
- torchrl/trainers/algorithms/configs/modules.py +570 -0
- torchrl/trainers/algorithms/configs/objectives.py +177 -0
- torchrl/trainers/algorithms/configs/trainers.py +340 -0
- torchrl/trainers/algorithms/configs/transforms.py +955 -0
- torchrl/trainers/algorithms/configs/utils.py +252 -0
- torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
- torchrl/trainers/algorithms/configs/weight_update.py +159 -0
- torchrl/trainers/algorithms/ppo.py +373 -0
- torchrl/trainers/algorithms/sac.py +308 -0
- torchrl/trainers/helpers/__init__.py +40 -0
- torchrl/trainers/helpers/collectors.py +416 -0
- torchrl/trainers/helpers/envs.py +573 -0
- torchrl/trainers/helpers/logger.py +33 -0
- torchrl/trainers/helpers/losses.py +132 -0
- torchrl/trainers/helpers/models.py +658 -0
- torchrl/trainers/helpers/replay_buffer.py +59 -0
- torchrl/trainers/helpers/trainers.py +301 -0
- torchrl/trainers/trainers.py +2052 -0
- torchrl/weight_update/__init__.py +33 -0
- torchrl/weight_update/_distributed.py +749 -0
- torchrl/weight_update/_mp.py +624 -0
- torchrl/weight_update/_noupdate.py +102 -0
- torchrl/weight_update/_ray.py +1032 -0
- torchrl/weight_update/_rpc.py +284 -0
- torchrl/weight_update/_shared.py +891 -0
- torchrl/weight_update/llm/__init__.py +32 -0
- torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
- torchrl/weight_update/llm/vllm_nccl.py +710 -0
- torchrl/weight_update/utils.py +73 -0
- torchrl/weight_update/weight_sync_schemes.py +1244 -0
- torchrl-0.11.0.dist-info/METADATA +1308 -0
- torchrl-0.11.0.dist-info/RECORD +395 -0
- torchrl-0.11.0.dist-info/WHEEL +5 -0
- torchrl-0.11.0.dist-info/entry_points.txt +2 -0
- torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
- torchrl-0.11.0.dist-info/top_level.txt +7 -0
|
@@ -0,0 +1,996 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
#
|
|
3
|
+
# This source code is licensed under the MIT license found in the
|
|
4
|
+
# LICENSE file in the root directory of this source tree.
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
import warnings
|
|
8
|
+
from dataclasses import dataclass
|
|
9
|
+
|
|
10
|
+
import torch
|
|
11
|
+
from tensordict import TensorDict, TensorDictBase, TensorDictParams
|
|
12
|
+
from tensordict.nn import dispatch, ProbabilisticTensorDictSequential, TensorDictModule
|
|
13
|
+
from tensordict.utils import NestedKey
|
|
14
|
+
from torch import Tensor
|
|
15
|
+
|
|
16
|
+
from torchrl.data.tensor_specs import TensorSpec
|
|
17
|
+
from torchrl.data.utils import _find_action_space
|
|
18
|
+
from torchrl.objectives.common import LossModule
|
|
19
|
+
from torchrl.objectives.utils import (
|
|
20
|
+
_GAMMA_LMBDA_DEPREC_ERROR,
|
|
21
|
+
_pseudo_vmap,
|
|
22
|
+
_reduce,
|
|
23
|
+
_vmap_func,
|
|
24
|
+
default_value_kwargs,
|
|
25
|
+
distance_loss,
|
|
26
|
+
ValueEstimators,
|
|
27
|
+
)
|
|
28
|
+
from torchrl.objectives.value import (
|
|
29
|
+
TD0Estimator,
|
|
30
|
+
TD1Estimator,
|
|
31
|
+
TDLambdaEstimator,
|
|
32
|
+
ValueEstimatorBase,
|
|
33
|
+
)
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class IQLLoss(LossModule):
|
|
37
|
+
r"""TorchRL implementation of the IQL loss.
|
|
38
|
+
|
|
39
|
+
Presented in "Offline Reinforcement Learning with Implicit Q-Learning" https://arxiv.org/abs/2110.06169
|
|
40
|
+
|
|
41
|
+
Args:
|
|
42
|
+
actor_network (ProbabilisticTensorDictSequential): stochastic actor
|
|
43
|
+
qvalue_network (TensorDictModule): Q(s, a) parametric model
|
|
44
|
+
If a single instance of `qvalue_network` is provided, it will be duplicated ``num_qvalue_nets``
|
|
45
|
+
times. If a list of modules is passed, their
|
|
46
|
+
parameters will be stacked unless they share the same identity (in which case
|
|
47
|
+
the original parameter will be expanded).
|
|
48
|
+
|
|
49
|
+
.. warning:: When a list of parameters is passed, it will **not** be compared against the policy parameters
|
|
50
|
+
and all the parameters will be considered as untied.
|
|
51
|
+
|
|
52
|
+
value_network (TensorDictModule, optional): V(s) parametric model.
|
|
53
|
+
|
|
54
|
+
Keyword Args:
|
|
55
|
+
num_qvalue_nets (integer, optional): number of Q-Value networks used.
|
|
56
|
+
Defaults to ``2``.
|
|
57
|
+
loss_function (str, optional): loss function to be used with
|
|
58
|
+
the value function loss. Default is ``"smooth_l1"``.
|
|
59
|
+
temperature (:obj:`float`, optional): Inverse temperature (beta).
|
|
60
|
+
For smaller hyperparameter values, the objective behaves similarly to
|
|
61
|
+
behavioral cloning, while for larger values, it attempts to recover the
|
|
62
|
+
maximum of the Q-function. Defaults to ``1.0``.
|
|
63
|
+
expectile (:obj:`float`, optional): expectile :math:`\tau`. A larger value of :math:`\tau` is crucial
|
|
64
|
+
for antmaze tasks that require dynamic programming ("stitching"). Defaults to ``0.5``.
|
|
65
|
+
priority_key (str, optional): [Deprecated, use .set_keys(priority_key=priority_key) instead]
|
|
66
|
+
tensordict key where to write the priority (for prioritized replay
|
|
67
|
+
buffer usage). Default is ``"td_error"``.
|
|
68
|
+
separate_losses (bool, optional): if ``True``, shared parameters between
|
|
69
|
+
policy and critic will only be trained on the policy loss.
|
|
70
|
+
Defaults to ``False``, i.e., gradients are propagated to shared
|
|
71
|
+
parameters for both policy and critic losses.
|
|
72
|
+
reduction (str, optional): Specifies the reduction to apply to the output:
|
|
73
|
+
``"none"`` | ``"mean"`` | ``"sum"``. ``"none"``: no reduction will be applied,
|
|
74
|
+
``"mean"``: the sum of the output will be divided by the number of
|
|
75
|
+
elements in the output, ``"sum"``: the output will be summed. Default: ``"mean"``.
|
|
76
|
+
deactivate_vmap (bool, optional): whether to deactivate vmap calls and replace them with a plain for loop.
|
|
77
|
+
Defaults to ``False``.
|
|
78
|
+
|
|
79
|
+
Examples:
|
|
80
|
+
>>> import torch
|
|
81
|
+
>>> from torch import nn
|
|
82
|
+
>>> from torchrl.data import Bounded
|
|
83
|
+
>>> from torchrl.modules.distributions import NormalParamExtractor, TanhNormal
|
|
84
|
+
>>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator
|
|
85
|
+
>>> from torchrl.modules.tensordict_module.common import SafeModule
|
|
86
|
+
>>> from torchrl.objectives.iql import IQLLoss
|
|
87
|
+
>>> from tensordict import TensorDict
|
|
88
|
+
>>> n_act, n_obs = 4, 3
|
|
89
|
+
>>> spec = Bounded(-torch.ones(n_act), torch.ones(n_act), (n_act,))
|
|
90
|
+
>>> net = nn.Sequential(nn.Linear(n_obs, 2 * n_act), NormalParamExtractor())
|
|
91
|
+
>>> module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"])
|
|
92
|
+
>>> actor = ProbabilisticActor(
|
|
93
|
+
... module=module,
|
|
94
|
+
... in_keys=["loc", "scale"],
|
|
95
|
+
... spec=spec,
|
|
96
|
+
... distribution_class=TanhNormal)
|
|
97
|
+
>>> class QValueClass(nn.Module):
|
|
98
|
+
... def __init__(self):
|
|
99
|
+
... super().__init__()
|
|
100
|
+
... self.linear = nn.Linear(n_obs + n_act, 1)
|
|
101
|
+
... def forward(self, obs, act):
|
|
102
|
+
... return self.linear(torch.cat([obs, act], -1))
|
|
103
|
+
>>> qvalue = SafeModule(
|
|
104
|
+
... QValueClass(),
|
|
105
|
+
... in_keys=["observation", "action"],
|
|
106
|
+
... out_keys=["state_action_value"],
|
|
107
|
+
... )
|
|
108
|
+
>>> value = SafeModule(
|
|
109
|
+
... nn.Linear(n_obs, 1),
|
|
110
|
+
... in_keys=["observation"],
|
|
111
|
+
... out_keys=["state_value"],
|
|
112
|
+
... )
|
|
113
|
+
>>> loss = IQLLoss(actor, qvalue, value)
|
|
114
|
+
>>> batch = [2, ]
|
|
115
|
+
>>> action = spec.rand(batch)
|
|
116
|
+
>>> data = TensorDict({
|
|
117
|
+
... "observation": torch.randn(*batch, n_obs),
|
|
118
|
+
... "action": action,
|
|
119
|
+
... ("next", "done"): torch.zeros(*batch, 1, dtype=torch.bool),
|
|
120
|
+
... ("next", "terminated"): torch.zeros(*batch, 1, dtype=torch.bool),
|
|
121
|
+
... ("next", "reward"): torch.randn(*batch, 1),
|
|
122
|
+
... ("next", "observation"): torch.randn(*batch, n_obs),
|
|
123
|
+
... }, batch)
|
|
124
|
+
>>> loss(data)
|
|
125
|
+
TensorDict(
|
|
126
|
+
fields={
|
|
127
|
+
entropy: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
|
|
128
|
+
loss_actor: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
|
|
129
|
+
loss_qvalue: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
|
|
130
|
+
loss_value: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
|
|
131
|
+
batch_size=torch.Size([]),
|
|
132
|
+
device=None,
|
|
133
|
+
is_shared=False)
|
|
134
|
+
|
|
135
|
+
This class is compatible with non-tensordict based modules too and can be
|
|
136
|
+
used without resorting to any tensordict-related primitive. In this case,
|
|
137
|
+
the expected keyword arguments are:
|
|
138
|
+
``["action", "next_reward", "next_done", "next_terminated"]`` + in_keys of the actor, value, and qvalue networks.
|
|
139
|
+
The return value is a tuple of tensors in the following order:
|
|
140
|
+
``["loss_actor", "loss_qvalue", "loss_value", "entropy"]``.
|
|
141
|
+
|
|
142
|
+
Examples:
|
|
143
|
+
>>> import torch
|
|
144
|
+
>>> from torch import nn
|
|
145
|
+
>>> from torchrl.data import Bounded
|
|
146
|
+
>>> from torchrl.modules.distributions import NormalParamExtractor, TanhNormal
|
|
147
|
+
>>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator
|
|
148
|
+
>>> from torchrl.modules.tensordict_module.common import SafeModule
|
|
149
|
+
>>> from torchrl.objectives.iql import IQLLoss
|
|
150
|
+
>>> _ = torch.manual_seed(42)
|
|
151
|
+
>>> n_act, n_obs = 4, 3
|
|
152
|
+
>>> spec = Bounded(-torch.ones(n_act), torch.ones(n_act), (n_act,))
|
|
153
|
+
>>> net = nn.Sequential(nn.Linear(n_obs, 2 * n_act), NormalParamExtractor())
|
|
154
|
+
>>> module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"])
|
|
155
|
+
>>> actor = ProbabilisticActor(
|
|
156
|
+
... module=module,
|
|
157
|
+
... in_keys=["loc", "scale"],
|
|
158
|
+
... spec=spec,
|
|
159
|
+
... distribution_class=TanhNormal)
|
|
160
|
+
>>> class QValueClass(nn.Module):
|
|
161
|
+
... def __init__(self):
|
|
162
|
+
... super().__init__()
|
|
163
|
+
... self.linear = nn.Linear(n_obs + n_act, 1)
|
|
164
|
+
... def forward(self, obs, act):
|
|
165
|
+
... return self.linear(torch.cat([obs, act], -1))
|
|
166
|
+
>>> qvalue = SafeModule(
|
|
167
|
+
... QValueClass(),
|
|
168
|
+
... in_keys=["observation", "action"],
|
|
169
|
+
... out_keys=["state_action_value"],
|
|
170
|
+
... )
|
|
171
|
+
>>> value = SafeModule(
|
|
172
|
+
... nn.Linear(n_obs, 1),
|
|
173
|
+
... in_keys=["observation"],
|
|
174
|
+
... out_keys=["state_value"],
|
|
175
|
+
... )
|
|
176
|
+
>>> loss = IQLLoss(actor, qvalue, value)
|
|
177
|
+
>>> batch = [2, ]
|
|
178
|
+
>>> action = spec.rand(batch)
|
|
179
|
+
>>> loss_actor, loss_qvalue, loss_value, entropy = loss(
|
|
180
|
+
... observation=torch.randn(*batch, n_obs),
|
|
181
|
+
... action=action,
|
|
182
|
+
... next_done=torch.zeros(*batch, 1, dtype=torch.bool),
|
|
183
|
+
... next_terminated=torch.zeros(*batch, 1, dtype=torch.bool),
|
|
184
|
+
... next_observation=torch.zeros(*batch, n_obs),
|
|
185
|
+
... next_reward=torch.randn(*batch, 1))
|
|
186
|
+
>>> loss_actor.backward()
|
|
187
|
+
|
|
188
|
+
|
|
189
|
+
The output keys can also be filtered using the :meth:`IQLLoss.select_out_keys`
|
|
190
|
+
method.
|
|
191
|
+
|
|
192
|
+
Examples:
|
|
193
|
+
>>> _ = loss.select_out_keys('loss_actor', 'loss_qvalue')
|
|
194
|
+
>>> loss_actor, loss_qvalue = loss(
|
|
195
|
+
... observation=torch.randn(*batch, n_obs),
|
|
196
|
+
... action=action,
|
|
197
|
+
... next_done=torch.zeros(*batch, 1, dtype=torch.bool),
|
|
198
|
+
... next_terminated=torch.zeros(*batch, 1, dtype=torch.bool),
|
|
199
|
+
... next_observation=torch.zeros(*batch, n_obs),
|
|
200
|
+
... next_reward=torch.randn(*batch, 1))
|
|
201
|
+
>>> loss_actor.backward()
|
|
202
|
+
"""
|
|
203
|
+
|
|
204
|
+
@dataclass
|
|
205
|
+
class _AcceptedKeys:
|
|
206
|
+
"""Maintains default values for all configurable tensordict keys.
|
|
207
|
+
|
|
208
|
+
This class defines which tensordict keys can be set using '.set_keys(key_name=key_value)' and their
|
|
209
|
+
default values.
|
|
210
|
+
|
|
211
|
+
Attributes:
|
|
212
|
+
value (NestedKey): The input tensordict key where the state value is expected.
|
|
213
|
+
Will be used for the underlying value estimator. Defaults to ``"state_value"``.
|
|
214
|
+
action (NestedKey): The input tensordict key where the action is expected.
|
|
215
|
+
Defaults to ``"action"``.
|
|
216
|
+
log_prob (NestedKey): The input tensordict key where the log probability is expected.
|
|
217
|
+
Defaults to ``"_log_prob"``.
|
|
218
|
+
priority (NestedKey): The tensordict key where the target priority is written.
|
|
219
|
+
Defaults to ``"td_error"``.
|
|
220
|
+
state_action_value (NestedKey): The input tensordict key where the
|
|
221
|
+
state-action value is expected. Will be used for the underlying
|
|
222
|
+
value estimator as value key. Defaults to ``"state_action_value"``.
|
|
223
|
+
reward (NestedKey): The input tensordict key where the reward is expected.
|
|
224
|
+
Will be used for the underlying value estimator. Defaults to ``"reward"``.
|
|
225
|
+
done (NestedKey): The key in the input TensorDict that indicates
|
|
226
|
+
whether a trajectory is done. Will be used for the underlying value estimator.
|
|
227
|
+
Defaults to ``"done"``.
|
|
228
|
+
terminated (NestedKey): The key in the input TensorDict that indicates
|
|
229
|
+
whether a trajectory is terminated. Will be used for the underlying value estimator.
|
|
230
|
+
Defaults to ``"terminated"``.
|
|
231
|
+
"""
|
|
232
|
+
|
|
233
|
+
value: NestedKey = "state_value"
|
|
234
|
+
action: NestedKey = "action"
|
|
235
|
+
log_prob: NestedKey = "_log_prob"
|
|
236
|
+
priority: NestedKey = "td_error"
|
|
237
|
+
state_action_value: NestedKey = "state_action_value"
|
|
238
|
+
reward: NestedKey = "reward"
|
|
239
|
+
done: NestedKey = "done"
|
|
240
|
+
terminated: NestedKey = "terminated"
|
|
241
|
+
|
|
242
|
+
tensor_keys: _AcceptedKeys
|
|
243
|
+
default_keys = _AcceptedKeys
|
|
244
|
+
default_value_estimator = ValueEstimators.TD0
|
|
245
|
+
out_keys = [
|
|
246
|
+
"loss_actor",
|
|
247
|
+
"loss_qvalue",
|
|
248
|
+
"loss_value",
|
|
249
|
+
"entropy",
|
|
250
|
+
]
|
|
251
|
+
|
|
252
|
+
actor_network: TensorDictModule
|
|
253
|
+
actor_network_params: TensorDictParams
|
|
254
|
+
target_actor_network_params: TensorDictParams
|
|
255
|
+
qvalue_network: TensorDictModule
|
|
256
|
+
qvalue_network_params: TensorDictParams
|
|
257
|
+
target_qvalue_network_params: TensorDictParams
|
|
258
|
+
value_network: TensorDictModule | None
|
|
259
|
+
value_network_params: TensorDictParams | None
|
|
260
|
+
target_value_network_params: TensorDictParams | None
|
|
261
|
+
|
|
262
|
+
def __init__(
    self,
    actor_network: ProbabilisticTensorDictSequential,
    qvalue_network: TensorDictModule | list[TensorDictModule],
    value_network: TensorDictModule | None,
    *,
    num_qvalue_nets: int = 2,
    loss_function: str = "smooth_l1",
    temperature: float = 1.0,
    expectile: float = 0.5,
    gamma: float | None = None,
    priority_key: str | None = None,
    separate_losses: bool = False,
    reduction: str | None = None,
    deactivate_vmap: bool = False,
) -> None:
    """Register the actor, value and Q-value networks and store the IQL settings.

    Args:
        actor_network: policy module; registered under ``"actor_network"``
            without target parameters.
        qvalue_network: Q-value module(s); registered under
            ``"qvalue_network"`` with ``num_qvalue_nets`` and target
            parameters.
        value_network: state-value module; registered under
            ``"value_network"`` without target parameters.
        num_qvalue_nets: forwarded to ``convert_to_functional`` for the
            Q-value network and stored on the instance.
        loss_function: stored as ``self.loss_function`` (used by
            ``qvalue_loss`` through ``distance_loss``).
        temperature: stored as ``self.temperature`` (scales the exponent in
            ``actor_loss``).
        expectile: stored as ``self.expectile`` (passed to
            ``loss_value_diff`` in ``value_loss``).
        gamma: must be ``None``; any other value raises ``TypeError``.
        priority_key: forwarded to ``_set_deprecated_ctor_keys`` as
            ``priority``.
        separate_losses: when ``True``, the actor parameters (and, for the
            Q-value network, the value parameters too) are passed as
            ``compare_against`` to ``convert_to_functional``.
        reduction: stored as ``self.reduction``; ``None`` becomes ``"mean"``.
        deactivate_vmap: stored on the instance and passed as ``pseudo_vmap``
            in ``_make_vmap``.

    Raises:
        TypeError: if ``gamma`` is not ``None``.
    """
    # Cached key lists; ``in_keys`` is rebuilt lazily by ``_set_in_keys``.
    self._in_keys = None
    self._out_keys = None
    if reduction is None:
        reduction = "mean"
    super().__init__()
    self._set_deprecated_ctor_keys(priority=priority_key)

    self.deactivate_vmap = deactivate_vmap

    # IQL parameter
    self.temperature = temperature
    self.expectile = expectile

    # Actor Network
    self.convert_to_functional(
        actor_network,
        "actor_network",
        create_target_params=False,
    )
    if separate_losses:
        # we want to make sure there are no duplicates in the params: the
        # params of critic must be refs to actor if they're shared
        policy_params = list(actor_network.parameters())
    else:
        policy_params = None
    # Value Function Network
    self.convert_to_functional(
        value_network,
        "value_network",
        create_target_params=False,
        compare_against=policy_params,
    )

    # Q Function Network
    self.delay_qvalue = True
    self.num_qvalue_nets = num_qvalue_nets
    # NOTE(review): ``policy_params`` is always a list when ``separate_losses``
    # is true, so the second half of this condition is redundant.
    if separate_losses and policy_params is not None:
        qvalue_policy_params = list(actor_network.parameters()) + list(
            value_network.parameters()
        )
    else:
        qvalue_policy_params = None
    self.convert_to_functional(
        qvalue_network,
        "qvalue_network",
        num_qvalue_nets,
        create_target_params=True,
        compare_against=qvalue_policy_params,
    )

    self.loss_function = loss_function
    # NOTE(review): this rejection only happens after all networks have been
    # registered above.
    if gamma is not None:
        raise TypeError(_GAMMA_LMBDA_DEPREC_ERROR)
    self._make_vmap()
    self.reduction = reduction
|
|
333
|
+
|
|
334
|
+
def _make_vmap(self):
    """Build the batched Q-value callable and store it on the instance.

    The wrapper is created by ``_vmap_func`` with in-dims ``(None, 0)``, the
    module's ``vmap_randomness`` setting, and ``deactivate_vmap`` passed as
    ``pseudo_vmap``. It is stored as ``self._vmap_qvalue_networkN0``.
    """
    batched_qvalue = _vmap_func(
        self.qvalue_network,
        (None, 0),
        randomness=self.vmap_randomness,
        pseudo_vmap=self.deactivate_vmap,
    )
    self._vmap_qvalue_networkN0 = batched_qvalue
|
|
341
|
+
|
|
342
|
+
@property
def device(self) -> torch.device:
    """Deprecated accessor: reading it always raises.

    Raises:
        RuntimeError: unconditionally, with the deprecation message.
    """
    deprecation_message = (
        "The device attributes of the losses is deprecated since v0.3."
    )
    raise RuntimeError(deprecation_message)
|
|
347
|
+
|
|
348
|
+
def _set_in_keys(self):
    """Recompute the list of input keys this loss reads and cache it.

    The list gathers the action key, the ``("next", ...)`` reward / done /
    terminated keys, the actor in-keys (both at the root and under
    ``"next"``), and the in-keys of the Q-value and value networks.
    Duplicates are dropped and the result is stored in ``self._in_keys``.

    The order is deterministic (first occurrence wins). De-duplicating
    through ``set`` would make the order depend on string hashing and
    therefore vary from one interpreter run to the next.
    """
    actor_in_keys = self.actor_network.in_keys
    candidates = [
        self.tensor_keys.action,
        ("next", self.tensor_keys.reward),
        ("next", self.tensor_keys.done),
        ("next", self.tensor_keys.terminated),
        *actor_in_keys,
        # NOTE(review): a tuple-valued actor key ends up nested inside the
        # ("next", key) tuple rather than flattened -- confirm this is intended.
        *[("next", key) for key in actor_in_keys],
        *self.qvalue_network.in_keys,
        *self.value_network.in_keys,
    ]
    # dict.fromkeys removes duplicates like a set but keeps insertion order.
    self._in_keys = list(dict.fromkeys(candidates))
|
|
360
|
+
|
|
361
|
+
@property
def in_keys(self):
    """Input keys of the loss, computed on first access and then cached."""
    cached = self._in_keys
    if cached is not None:
        return cached
    # Nothing cached yet: let ``_set_in_keys`` fill ``self._in_keys``.
    self._set_in_keys()
    return self._in_keys

@in_keys.setter
def in_keys(self, values):
    """Override the cached input keys with ``values`` as given."""
    self._in_keys = values
|
|
370
|
+
|
|
371
|
+
@staticmethod
def loss_value_diff(diff, expectile=0.8):
    """Asymmetric squared error used for the IQL expectile value regression.

    Args:
        diff: tensor of differences.
        expectile: weight applied where ``diff > 0``; the remaining entries
            are weighted by ``1 - expectile``.

    Returns:
        Tensor of element-wise weighted squared differences.
    """
    is_positive = diff > 0
    weight = torch.where(is_positive, expectile, 1 - expectile)
    return weight * diff.pow(2)
|
|
376
|
+
|
|
377
|
+
def _forward_value_estimator_keys(self, **kwargs) -> None:
    """Push the configured keys to the value estimator and refresh ``in_keys``.

    If a value estimator exists, its value / reward / done / terminated keys
    are set from this module's tensor keys. The cached input keys are rebuilt
    in every case. ``kwargs`` is accepted but not used.
    """
    estimator = self._value_estimator
    if estimator is not None:
        estimator.set_keys(
            value=self._tensor_keys.value,
            reward=self.tensor_keys.reward,
            done=self.tensor_keys.done,
            terminated=self.tensor_keys.terminated,
        )
    self._set_in_keys()
|
|
386
|
+
|
|
387
|
+
@dispatch
def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
    """Compute the actor, Q-value and value losses for one batch.

    The three partial losses are evaluated in the order actor, Q-value,
    value, and their metadata dictionaries are merged. Side effects on the
    input ``tensordict``:

    * ``actor_loss`` writes the detached log-probability under the
      ``log_prob`` key, which is read back here to compute the entropy;
    * the ``"td_error"`` metadata entry is detached, reduced with ``max``
      over dim 0, and written under the ``priority`` key.

    Args:
        tensordict: input data containing the keys listed in ``in_keys``.

    Returns:
        A ``TensorDict`` with ``"loss_actor"``, ``"loss_qvalue"``,
        ``"loss_value"`` and ``"entropy"`` (mean of the negated log-prob).

    Raises:
        RuntimeError: if the partial losses do not share the same shape.
    """
    loss_actor, metadata = self.actor_loss(tensordict)
    loss_qvalue, metadata_qvalue = self.qvalue_loss(tensordict)
    loss_value, metadata_value = self.value_loss(tensordict)
    metadata.update(metadata_qvalue)
    metadata.update(metadata_value)

    if (loss_actor.shape != loss_qvalue.shape) or (
        loss_value is not None and loss_actor.shape != loss_value.shape
    ):
        # The check above tolerates ``loss_value is None``; keep the message
        # safe for that case so RuntimeError is raised, not AttributeError.
        loss_value_shape = None if loss_value is None else loss_value.shape
        raise RuntimeError(
            f"Losses shape mismatch: {loss_actor.shape}, {loss_qvalue.shape} and {loss_value_shape}"
        )
    tensordict.set(
        self.tensor_keys.priority, metadata.pop("td_error").detach().max(0).values
    )
    entropy = -tensordict.get(self.tensor_keys.log_prob).detach()
    out = {
        "loss_actor": loss_actor,
        "loss_qvalue": loss_qvalue,
        "loss_value": loss_value,
        "entropy": entropy.mean(),
    }
    td_out = TensorDict(out)

    self._clear_weakrefs(
        tensordict,
        td_out,
        "actor_network_params",
        "qvalue_network_params",
        "value_network_params",
        "target_actor_network_params",
        "target_qvalue_network_params",
        "target_value_network_params",
    )
    return td_out
|
|
424
|
+
|
|
425
|
+
def actor_loss(self, tensordict: TensorDictBase) -> tuple[Tensor, dict]:
    """Compute the actor loss ``-(exp_a * log_prob)``.

    ``exp_a`` is ``exp((min_q - value) * temperature)`` clamped to at most
    100. ``min_q`` is the minimum over dim 0 of the target Q-value outputs
    and ``value`` is the value-network output computed under ``no_grad``.
    The detached log-probability is also written to ``tensordict`` under the
    ``log_prob`` key.

    Args:
        tensordict: batch holding the action and the network inputs.

    Returns:
        The reduced loss and an empty metadata dict.

    Raises:
        RuntimeError: if the log-probability and ``min_q`` shapes differ.
    """
    keys = self.tensor_keys
    param_names = (
        "actor_network_params",
        "qvalue_network_params",
        "value_network_params",
        "target_actor_network_params",
        "target_qvalue_network_params",
        "target_value_network_params",
    )

    # Log-probability of the stored actions under the current actor params.
    with self.actor_network_params.to_module(self.actor_network):
        policy_dist = self.actor_network.get_dist(tensordict)
    action_log_prob = policy_dist.log_prob(tensordict[keys.action])

    # Minimum over dim 0 of the Q-values computed with the target params.
    q_inputs = tensordict.select(*self.qvalue_network.in_keys, strict=False)
    q_outputs = self._vmap_qvalue_networkN0(
        q_inputs, self.target_qvalue_network_params
    )
    q_values = q_outputs.get(keys.state_action_value)
    min_q = q_values.min(0)[0].squeeze(-1)

    if action_log_prob.shape != min_q.shape:
        raise RuntimeError(
            f"Losses shape mismatch: {action_log_prob.shape} and {min_q.shape}"
        )

    # State value, evaluated without tracking gradients.
    with torch.no_grad():
        value_inputs = tensordict.select(*self.value_network.in_keys, strict=False)
        value_td = value_inputs.detach()
        with self.value_network_params.to_module(self.value_network):
            self.value_network(value_td)
        value = value_td.get(keys.value).squeeze(-1)

    advantage_weight = torch.exp((min_q - value) * self.temperature)
    advantage_weight = advantage_weight.clamp_max(100)

    # Store the detached log-prob on the input tensordict.
    tensordict.set(keys.log_prob, action_log_prob.detach())
    per_sample_loss = -(advantage_weight * action_log_prob)
    loss_actor = _reduce(per_sample_loss, reduction=self.reduction)
    self._clear_weakrefs(tensordict, *param_names)
    return loss_actor, {}
|
|
469
|
+
|
|
470
|
+
def value_loss(self, tensordict: TensorDictBase) -> tuple[Tensor, dict]:
    """Compute the expectile value loss.

    The loss is ``loss_value_diff(min_q - value, expectile)``. ``min_q`` is
    the minimum over dim 0 of the Q-values computed with the target Q-value
    parameters, and ``value`` is the value-network output.

    Args:
        tensordict: batch holding the inputs of the Q-value and value networks.

    Returns:
        The reduced loss and an empty metadata dict.
    """
    keys = self.tensor_keys
    param_names = (
        "actor_network_params",
        "qvalue_network_params",
        "value_network_params",
        "target_actor_network_params",
        "target_qvalue_network_params",
        "target_value_network_params",
    )

    # Minimum over dim 0 of the target-parameter Q-values.
    q_inputs = tensordict.select(*self.qvalue_network.in_keys, strict=False)
    q_outputs = self._vmap_qvalue_networkN0(
        q_inputs, self.target_qvalue_network_params
    )
    min_q = q_outputs.get(keys.state_action_value).min(0)[0].squeeze(-1)

    # State value from the value network (run in place on the selected td).
    value_td = tensordict.select(*self.value_network.in_keys, strict=False)
    with self.value_network_params.to_module(self.value_network):
        self.value_network(value_td)
    value = value_td.get(keys.value).squeeze(-1)

    per_sample_loss = self.loss_value_diff(min_q - value, self.expectile)
    reduced_loss = _reduce(per_sample_loss, reduction=self.reduction)
    self._clear_weakrefs(tensordict, *param_names)
    return reduced_loss, {}
|
|
492
|
+
|
|
493
|
+
def qvalue_loss(self, tensordict: TensorDictBase) -> tuple[Tensor, dict]:
    """Compute the Q-value loss against the value-estimator target.

    The target comes from ``self.value_estimator.value_estimate`` called with
    the target value-network parameters. Predictions come from the Q-value
    networks evaluated with the online parameters. The distance (chosen by
    ``self.loss_function``) is summed over dim 0 and then reduced.

    Args:
        tensordict: batch holding ``"next"``, the actor in-keys and the action.

    Returns:
        The reduced loss and a metadata dict whose ``"td_error"`` entry is
        the detached squared difference between prediction and target.
    """
    # Work on the sub-tensordict the estimator and Q-networks need.
    selected = tensordict.select(
        "next", *self.actor_network.in_keys, self.tensor_keys.action, strict=False
    )

    raw_target = self.value_estimator.value_estimate(
        selected, target_params=self.target_value_network_params
    )
    target_value = raw_target.squeeze(-1)

    q_inputs = selected.select(*self.qvalue_network.in_keys, strict=False)
    q_outputs = self._vmap_qvalue_networkN0(q_inputs, self.qvalue_network_params)
    pred_val = q_outputs.get(self.tensor_keys.state_action_value).squeeze(-1)

    td_error = (pred_val - target_value).pow(2)
    per_net_loss = distance_loss(
        pred_val,
        target_value.expand_as(pred_val),
        loss_function=self.loss_function,
    )
    loss_qval = _reduce(per_net_loss.sum(0), reduction=self.reduction)
    metadata = {"td_error": td_error.detach()}
    self._clear_weakrefs(
        selected,
        "actor_network_params",
        "qvalue_network_params",
        "value_network_params",
        "target_actor_network_params",
        "target_qvalue_network_params",
        "target_value_network_params",
    )
    return loss_qval, metadata
|
|
527
|
+
|
|
528
|
+
def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams):
    """Create the value estimator used to build the Q-value target.

    Args:
        value_type: a ``ValueEstimators`` member, a ``ValueEstimatorBase``
            instance or subclass, or ``None`` to use
            ``self.default_value_estimator``.
        **hyperparams: overrides applied on top of
            ``default_value_kwargs(value_type)`` (and ``self.gamma`` when the
            instance has that attribute).

    Returns:
        The result of ``LossModule.make_value_estimator`` when ``value_type``
        is a ``ValueEstimatorBase`` instance or subclass; otherwise ``None``.

    Raises:
        NotImplementedError: for ``ValueEstimators.GAE`` and for any value
            type not handled below.
    """
    if value_type is None:
        value_type = self.default_value_estimator

    # Estimator instances and classes are handled by the base implementation.
    if isinstance(value_type, ValueEstimatorBase) or (
        isinstance(value_type, type) and issubclass(value_type, ValueEstimatorBase)
    ):
        return LossModule.make_value_estimator(self, value_type, **hyperparams)

    self.value_type = value_type
    value_net = self.value_network

    # Defaults first, then the instance's gamma (if any), then user overrides.
    estimator_kwargs = dict(default_value_kwargs(value_type))
    if hasattr(self, "gamma"):
        estimator_kwargs["gamma"] = self.gamma
    estimator_kwargs.update(hyperparams)

    if value_type is ValueEstimators.TD1:
        estimator_cls = TD1Estimator
    elif value_type is ValueEstimators.TD0:
        estimator_cls = TD0Estimator
    elif value_type is ValueEstimators.GAE:
        raise NotImplementedError(
            f"Value type {value_type} it not implemented for loss {type(self)}."
        )
    elif value_type is ValueEstimators.TDLambda:
        estimator_cls = TDLambdaEstimator
    else:
        raise NotImplementedError(f"Unknown value type {value_type}")
    self._value_estimator = estimator_cls(**estimator_kwargs, value_network=value_net)

    self._value_estimator.set_keys(
        value_target="value_target",
        value=self.tensor_keys.value,
        reward=self.tensor_keys.reward,
        done=self.tensor_keys.done,
        terminated=self.tensor_keys.terminated,
    )
|
|
575
|
+
|
|
576
|
+
|
|
577
|
+
class DiscreteIQLLoss(IQLLoss):
|
|
578
|
+
r"""TorchRL implementation of the discrete IQL loss.
|
|
579
|
+
|
|
580
|
+
Presented in "Offline Reinforcement Learning with Implicit Q-Learning" https://arxiv.org/abs/2110.06169
|
|
581
|
+
|
|
582
|
+
Args:
|
|
583
|
+
actor_network (ProbabilisticTensorDictSequential): stochastic actor
|
|
584
|
+
qvalue_network (TensorDictModule): Q(s, a) parametric model.
|
|
585
|
+
value_network (TensorDictModule, optional): V(s) parametric model.
|
|
586
|
+
|
|
587
|
+
Keyword Args:
|
|
588
|
+
action_space (str or TensorSpec): Action space. Must be one of
|
|
589
|
+
``"one-hot"``, ``"mult_one_hot"``, ``"binary"`` or ``"categorical"``,
|
|
590
|
+
or an instance of the corresponding specs (:class:`torchrl.data.OneHot`,
|
|
591
|
+
:class:`torchrl.data.MultiOneHot`,
|
|
592
|
+
:class:`torchrl.data.Binary` or :class:`torchrl.data.Categorical`).
|
|
593
|
+
num_qvalue_nets (integer, optional): number of Q-Value networks used.
|
|
594
|
+
Defaults to ``2``.
|
|
595
|
+
loss_function (str, optional): loss function to be used with
|
|
596
|
+
the value function loss. Default is `"smooth_l1"`.
|
|
597
|
+
temperature (:obj:`float`, optional): Inverse temperature (beta).
|
|
598
|
+
For smaller hyperparameter values, the objective behaves similarly to
|
|
599
|
+
behavioral cloning, while for larger values, it attempts to recover the
|
|
600
|
+
maximum of the Q-function.
|
|
601
|
+
expectile (:obj:`float`, optional): expectile :math:`\tau`. A larger value of :math:`\tau` is crucial
|
|
602
|
+
for antmaze tasks that require dynamic programming ("stitching").
|
|
603
|
+
priority_key (str, optional): [Deprecated, use .set_keys(priority=priority_key) instead]
|
|
604
|
+
tensordict key where to write the priority (for prioritized replay
|
|
605
|
+
buffer usage). Default is `"td_error"`.
|
|
606
|
+
separate_losses (bool, optional): if ``True``, shared parameters between
|
|
607
|
+
policy and critic will only be trained on the policy loss.
|
|
608
|
+
Defaults to ``False``, i.e., gradients are propagated to shared
|
|
609
|
+
parameters for both policy and critic losses.
|
|
610
|
+
reduction (str, optional): Specifies the reduction to apply to the output:
|
|
611
|
+
``"none"`` | ``"mean"`` | ``"sum"``. ``"none"``: no reduction will be applied,
|
|
612
|
+
``"mean"``: the sum of the output will be divided by the number of
|
|
613
|
+
elements in the output, ``"sum"``: the output will be summed. Default: ``"mean"``.
|
|
614
|
+
|
|
615
|
+
Examples:
|
|
616
|
+
>>> import torch
|
|
617
|
+
>>> from torch import nn
|
|
618
|
+
>>> from torchrl.data.tensor_specs import OneHot
|
|
619
|
+
>>> from torchrl.modules.distributions.discrete import OneHotCategorical
|
|
620
|
+
>>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor
|
|
621
|
+
>>> from torchrl.modules.tensordict_module.common import SafeModule
|
|
622
|
+
>>> from torchrl.objectives.iql import DiscreteIQLLoss
|
|
623
|
+
>>> from tensordict import TensorDict
|
|
624
|
+
>>> n_act, n_obs = 4, 3
|
|
625
|
+
>>> spec = OneHot(n_act)
|
|
626
|
+
>>> module = SafeModule(nn.Linear(n_obs, n_act), in_keys=["observation"], out_keys=["logits"])
|
|
627
|
+
>>> actor = ProbabilisticActor(
|
|
628
|
+
... module=module,
|
|
629
|
+
... in_keys=["logits"],
|
|
630
|
+
... out_keys=["action"],
|
|
631
|
+
... spec=spec,
|
|
632
|
+
... distribution_class=OneHotCategorical)
|
|
633
|
+
>>> qvalue = SafeModule(
|
|
634
|
+
... nn.Linear(n_obs, n_act),
|
|
635
|
+
... in_keys=["observation"],
|
|
636
|
+
... out_keys=["state_action_value"],
|
|
637
|
+
... )
|
|
638
|
+
>>> value = SafeModule(
|
|
639
|
+
... nn.Linear(n_obs, 1),
|
|
640
|
+
... in_keys=["observation"],
|
|
641
|
+
... out_keys=["state_value"],
|
|
642
|
+
... )
|
|
643
|
+
>>> loss = DiscreteIQLLoss(actor, qvalue, value)
|
|
644
|
+
>>> batch = [2, ]
|
|
645
|
+
>>> action = spec.rand(batch).long()
|
|
646
|
+
>>> data = TensorDict({
|
|
647
|
+
... "observation": torch.randn(*batch, n_obs),
|
|
648
|
+
... "action": action,
|
|
649
|
+
... ("next", "done"): torch.zeros(*batch, 1, dtype=torch.bool),
|
|
650
|
+
... ("next", "terminated"): torch.zeros(*batch, 1, dtype=torch.bool),
|
|
651
|
+
... ("next", "reward"): torch.randn(*batch, 1),
|
|
652
|
+
... ("next", "observation"): torch.randn(*batch, n_obs),
|
|
653
|
+
... }, batch)
|
|
654
|
+
>>> loss(data)
|
|
655
|
+
TensorDict(
|
|
656
|
+
fields={
|
|
657
|
+
entropy: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
|
|
658
|
+
loss_actor: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
|
|
659
|
+
loss_qvalue: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
|
|
660
|
+
loss_value: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
|
|
661
|
+
batch_size=torch.Size([]),
|
|
662
|
+
device=None,
|
|
663
|
+
is_shared=False)
|
|
664
|
+
|
|
665
|
+
This class is compatible with non-tensordict based modules too and can be
|
|
666
|
+
used without recurring to any tensordict-related primitive. In this case,
|
|
667
|
+
the expected keyword arguments are:
|
|
668
|
+
``["action", "next_reward", "next_done", "next_terminated"]`` + in_keys of the actor, value, and qvalue network
|
|
669
|
+
The return value is a tuple of tensors in the following order:
|
|
670
|
+
``["loss_actor", "loss_qvalue", "loss_value", "entropy"]``.
|
|
671
|
+
|
|
672
|
+
Examples:
|
|
673
|
+
>>> import torch
|
|
674
|
+
>>> import torch
|
|
675
|
+
>>> from torch import nn
|
|
676
|
+
>>> from torchrl.data.tensor_specs import OneHot
|
|
677
|
+
>>> from torchrl.modules.distributions.discrete import OneHotCategorical
|
|
678
|
+
>>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor
|
|
679
|
+
>>> from torchrl.modules.tensordict_module.common import SafeModule
|
|
680
|
+
>>> from torchrl.objectives.iql import DiscreteIQLLoss
|
|
681
|
+
>>> _ = torch.manual_seed(42)
|
|
682
|
+
>>> n_act, n_obs = 4, 3
|
|
683
|
+
>>> spec = OneHot(n_act)
|
|
684
|
+
>>> module = SafeModule(nn.Linear(n_obs, n_act), in_keys=["observation"], out_keys=["logits"])
|
|
685
|
+
>>> actor = ProbabilisticActor(
|
|
686
|
+
... module=module,
|
|
687
|
+
... in_keys=["logits"],
|
|
688
|
+
... out_keys=["action"],
|
|
689
|
+
... spec=spec,
|
|
690
|
+
... distribution_class=OneHotCategorical)
|
|
691
|
+
>>> qvalue = SafeModule(
|
|
692
|
+
... nn.Linear(n_obs, n_act),
|
|
693
|
+
... in_keys=["observation"],
|
|
694
|
+
... out_keys=["state_action_value"],
|
|
695
|
+
... )
|
|
696
|
+
>>> value = SafeModule(
|
|
697
|
+
... nn.Linear(n_obs, 1),
|
|
698
|
+
... in_keys=["observation"],
|
|
699
|
+
... out_keys=["state_value"],
|
|
700
|
+
... )
|
|
701
|
+
>>> loss = DiscreteIQLLoss(actor, qvalue, value)
|
|
702
|
+
>>> batch = [2, ]
|
|
703
|
+
>>> action = spec.rand(batch).long()
|
|
704
|
+
>>> loss_actor, loss_qvalue, loss_value, entropy = loss(
|
|
705
|
+
... observation=torch.randn(*batch, n_obs),
|
|
706
|
+
... action=action,
|
|
707
|
+
... next_done=torch.zeros(*batch, 1, dtype=torch.bool),
|
|
708
|
+
... next_terminated=torch.zeros(*batch, 1, dtype=torch.bool),
|
|
709
|
+
... next_observation=torch.zeros(*batch, n_obs),
|
|
710
|
+
... next_reward=torch.randn(*batch, 1))
|
|
711
|
+
>>> loss_actor.backward()
|
|
712
|
+
|
|
713
|
+
|
|
714
|
+
The output keys can also be filtered using the :meth:`DiscreteIQLLoss.select_out_keys`
|
|
715
|
+
method.
|
|
716
|
+
|
|
717
|
+
Examples:
|
|
718
|
+
>>> _ = loss.select_out_keys('loss_actor', 'loss_qvalue', 'loss_value')
|
|
719
|
+
>>> loss_actor, loss_qvalue, loss_value = loss(
|
|
720
|
+
... observation=torch.randn(*batch, n_obs),
|
|
721
|
+
... action=action,
|
|
722
|
+
... next_done=torch.zeros(*batch, 1, dtype=torch.bool),
|
|
723
|
+
... next_terminated=torch.zeros(*batch, 1, dtype=torch.bool),
|
|
724
|
+
... next_observation=torch.zeros(*batch, n_obs),
|
|
725
|
+
... next_reward=torch.randn(*batch, 1))
|
|
726
|
+
>>> loss_actor.backward()
|
|
727
|
+
"""
|
|
728
|
+
|
|
729
|
+
@dataclass
class _AcceptedKeys:
    """Default names of the tensordict entries used by the loss.

    Each field is a key that can be overridden with
    ``.set_keys(key_name=key_value)``; the defaults are listed below.

    Attributes:
        value (NestedKey): key of the state value, also handed to the value
            estimator. Defaults to ``"state_value"``.
        action (NestedKey): key of the action. Defaults to ``"action"``.
        log_prob (NestedKey): key of the log-probability.
            Defaults to ``"_log_prob"``.
        priority (NestedKey): key under which the priority is written.
            Defaults to ``"td_error"``.
        state_action_value (NestedKey): key of the state-action value.
            Defaults to ``"state_action_value"``.
        reward (NestedKey): key of the reward, also handed to the value
            estimator. Defaults to ``"reward"``.
        done (NestedKey): key of the done flag, also handed to the value
            estimator. Defaults to ``"done"``.
        terminated (NestedKey): key of the terminated flag, also handed to
            the value estimator. Defaults to ``"terminated"``.
    """

    value: NestedKey = "state_value"
    action: NestedKey = "action"
    log_prob: NestedKey = "_log_prob"
    priority: NestedKey = "td_error"
    state_action_value: NestedKey = "state_action_value"
    reward: NestedKey = "reward"
    done: NestedKey = "done"
    terminated: NestedKey = "terminated"
|
|
766
|
+
|
|
767
|
+
# Annotation only: the per-instance key container is typed as the
# ``_AcceptedKeys`` dataclass declared in this class.
tensor_keys: _AcceptedKeys
# Dataclass whose field defaults are the default tensordict key names.
default_keys = _AcceptedKeys
# Default estimator type for this loss.
default_value_estimator = ValueEstimators.TD0
# Names of the loss outputs (same list as documented in the class docstring).
out_keys = [
    "loss_actor",
    "loss_qvalue",
    "loss_value",
    "entropy",
]

# Annotations only (no values assigned here); the attributes themselves are
# presumably created by the parent constructor -- TODO confirm.
actor_network: TensorDictModule
actor_network_params: TensorDictParams
target_actor_network_params: TensorDictParams
qvalue_network: TensorDictModule
qvalue_network_params: TensorDictParams
target_qvalue_network_params: TensorDictParams
value_network: TensorDictModule | None
value_network_params: TensorDictParams | None
target_value_network_params: TensorDictParams | None
|
|
786
|
+
|
|
787
|
+
def __init__(
    self,
    actor_network: ProbabilisticTensorDictSequential,
    qvalue_network: TensorDictModule,
    value_network: TensorDictModule | None,
    *,
    action_space: str | TensorSpec = None,
    num_qvalue_nets: int = 2,
    loss_function: str = "smooth_l1",
    temperature: float = 1.0,
    expectile: float = 0.5,
    gamma: float | None = None,
    priority_key: str | None = None,
    separate_losses: bool = False,
    reduction: str | None = None,
    deactivate_vmap: bool = False,
) -> None:
    """Validate the discrete-specific arguments and delegate to the parent.

    Args:
        actor_network: stochastic actor, forwarded to the parent constructor.
        qvalue_network: Q-value module, forwarded to the parent constructor.
        value_network: state-value module, forwarded to the parent constructor.
        action_space: action-space description passed to
            ``_find_action_space``. ``None`` emits a warning and falls back
            to ``"one-hot"``.
        num_qvalue_nets: forwarded to the parent constructor.
        loss_function: forwarded to the parent constructor.
        temperature: forwarded to the parent constructor.
        expectile: forwarded to the parent constructor; must be below 1.0.
        gamma: forwarded to the parent constructor.
        priority_key: forwarded to the parent constructor.
        separate_losses: forwarded to the parent constructor.
        reduction: stored as ``self.reduction``; ``None`` becomes ``"mean"``.
        deactivate_vmap: forwarded to the parent constructor. The loss
            methods of this class read ``self.deactivate_vmap`` to choose
            between ``_pseudo_vmap`` and ``torch.vmap``, so the option has to
            be settable from here. Defaults to ``False``, the value the
            parent used when this argument could not be passed.

    Raises:
        ValueError: if ``expectile`` is greater than or equal to 1.0.
    """
    self._in_keys = None
    self._out_keys = None
    if reduction is None:
        reduction = "mean"
    if expectile >= 1.0:
        raise ValueError(f"Expectile should be lower than 1.0 but is {expectile}")
    super().__init__(
        actor_network=actor_network,
        qvalue_network=qvalue_network,
        value_network=value_network,
        num_qvalue_nets=num_qvalue_nets,
        loss_function=loss_function,
        temperature=temperature,
        expectile=expectile,
        gamma=gamma,
        priority_key=priority_key,
        separate_losses=separate_losses,
        deactivate_vmap=deactivate_vmap,
    )
    if action_space is None:
        warnings.warn(
            "action_space was not specified. DiscreteIQLLoss will default to 'one-hot'. "
            "This behavior will be deprecated soon and a space will have to be passed. "
            "Check the DiscreteIQLLoss documentation to see how to pass the action space. "
        )
        action_space = "one-hot"
    self.action_space = _find_action_space(action_space)
    # Set after the parent constructor so the normalized value is the one kept.
    self.reduction = reduction
|
|
830
|
+
|
|
831
|
+
def actor_loss(self, tensordict: TensorDictBase) -> tuple[Tensor, dict]:
    """Compute the actor loss ``-(exp_a * log_prob)`` for discrete actions.

    The Q-value of the stored action is selected per Q-network (``gather``
    for ``"categorical"``, masked sum for ``"one_hot"``) and the minimum is
    taken over dim 0. ``exp_a`` is ``exp((min_Q - value) * temperature)``
    clamped to at most 100. The detached log-probability is also written to
    ``tensordict`` under the ``log_prob`` key.

    Args:
        tensordict: batch holding the action and the network inputs.

    Returns:
        The reduced loss and an empty metadata dict.

    Raises:
        RuntimeError: if ``self.action_space`` is neither ``"categorical"``
            nor ``"one_hot"``, or if the log-probability and ``min_Q``
            shapes differ.
    """
    # KL loss
    with self.actor_network_params.to_module(self.actor_network):
        dist = self.actor_network.get_dist(tensordict)

    log_prob = dist.log_prob(tensordict[self.tensor_keys.action])

    # Min Q value
    td_q = tensordict.select(*self.qvalue_network.in_keys, strict=False)
    td_q = self._vmap_qvalue_networkN0(td_q, self.target_qvalue_network_params)
    state_action_value = td_q.get(self.tensor_keys.state_action_value)
    action = tensordict.get(self.tensor_keys.action)
    # NOTE(review): only "categorical" and "one_hot" are handled here; any
    # other action-space value falls through to the RuntimeError below.
    if self.action_space == "categorical":
        if action.ndim < (state_action_value.ndim - (td_q.ndim - tensordict.ndim)):
            # unsqueeze the action if it lacks one trailing singleton dim
            action = action.unsqueeze(-1)
        if self.deactivate_vmap:
            vmap = _pseudo_vmap
        else:
            vmap = torch.vmap
        # Map over dim 0 of the Q-values only; the same action indices are
        # reused for every slice (in_dims ``(0, None)``).
        chosen_state_action_value = vmap(
            lambda state_action_value, action: torch.gather(
                state_action_value, -1, index=action
            ).squeeze(-1),
            (0, None),
        )(state_action_value, action)
    elif self.action_space == "one_hot":
        action = action.to(torch.float)
        # Multiplying by the action and summing keeps the entries it marks.
        chosen_state_action_value = (state_action_value * action).sum(-1)
    else:
        raise RuntimeError(f"Unknown action space {self.action_space}.")
    min_Q, _ = torch.min(chosen_state_action_value, dim=0)
    if log_prob.shape != min_Q.shape:
        raise RuntimeError(
            f"Losses shape mismatch: {log_prob.shape} and {min_Q.shape}"
        )
    with torch.no_grad():
        # state value
        td_copy = tensordict.select(
            *self.value_network.in_keys, strict=False
        ).detach()
        with self.value_network_params.to_module(self.value_network):
            self.value_network(td_copy)
        value = td_copy.get(self.tensor_keys.value).squeeze(
            -1
        )  # assert has no gradient

    exp_a = torch.exp((min_Q - value) * self.temperature)
    exp_a = exp_a.clamp_max(100)

    # write the detached log_prob in the input tensordict
    tensordict.set(self.tensor_keys.log_prob, log_prob.detach())
    loss_actor = -(exp_a * log_prob)
    loss_actor = _reduce(loss_actor, reduction=self.reduction)
    self._clear_weakrefs(
        tensordict,
        "actor_network_params",
        "qvalue_network_params",
        "value_network_params",
        "target_actor_network_params",
        "target_qvalue_network_params",
        "target_value_network_params",
    )
    return loss_actor, {}
|
|
895
|
+
|
|
896
|
+
def value_loss(self, tensordict: TensorDictBase) -> tuple[Tensor, dict]:
    """Compute the expectile value loss for discrete actions.

    The Q-value of the stored action is selected per Q-network using the
    target Q-value parameters, under ``no_grad``, and the minimum is taken
    over dim 0. The loss is ``loss_value_diff(min_Q - value, expectile)``,
    where ``value`` is the value-network output.

    Args:
        tensordict: batch holding the action and the network inputs.

    Returns:
        The reduced loss and an empty metadata dict.

    Raises:
        RuntimeError: if ``self.action_space`` is neither ``"categorical"``
            nor ``"one_hot"``.
    """
    # Min Q value
    with torch.no_grad():
        # Min Q value
        td_q = tensordict.select(*self.qvalue_network.in_keys, strict=False)
        td_q = self._vmap_qvalue_networkN0(td_q, self.target_qvalue_network_params)
        state_action_value = td_q.get(self.tensor_keys.state_action_value)
        action = tensordict.get(self.tensor_keys.action)
        if self.action_space == "categorical":
            if action.ndim < (
                state_action_value.ndim - (td_q.ndim - tensordict.ndim)
            ):
                # unsqueeze the action if it lacks one trailing singleton dim
                action = action.unsqueeze(-1)
            if self.deactivate_vmap:
                vmap = _pseudo_vmap
            else:
                vmap = torch.vmap
            # Map over dim 0 of the Q-values only; the same action indices
            # are reused for every slice (in_dims ``(0, None)``).
            chosen_state_action_value = vmap(
                lambda state_action_value, action: torch.gather(
                    state_action_value, -1, index=action
                ).squeeze(-1),
                (0, None),
            )(state_action_value, action)
        elif self.action_space == "one_hot":
            action = action.to(torch.float)
            chosen_state_action_value = (state_action_value * action).sum(-1)
        else:
            raise RuntimeError(f"Unknown action space {self.action_space}.")
        min_Q, _ = torch.min(chosen_state_action_value, dim=0)
    # state value
    # NOTE(review): the value network is run outside the no_grad block here
    # -- presumably so this loss can backpropagate into it; confirm.
    td_copy = tensordict.select(*self.value_network.in_keys, strict=False)
    with self.value_network_params.to_module(self.value_network):
        self.value_network(td_copy)
    value = td_copy.get(self.tensor_keys.value).squeeze(-1)
    value_loss = self.loss_value_diff(min_Q - value, self.expectile)
    value_loss = _reduce(value_loss, reduction=self.reduction)
    self._clear_weakrefs(
        tensordict,
        "actor_network_params",
        "qvalue_network_params",
        "value_network_params",
        "target_actor_network_params",
        "target_qvalue_network_params",
        "target_value_network_params",
    )
    return value_loss, {}
|
|
943
|
+
|
|
944
|
+
def qvalue_loss(self, tensordict: TensorDictBase) -> tuple[Tensor, dict]:
    """Compute the Q-value loss against the value-estimator target.

    The target is obtained from ``self.value_estimator.value_estimate`` using
    the target value-network parameters, under ``torch.no_grad()``. The
    prediction is the Q-value of the action stored in ``tensordict``, read
    from the output of ``self._vmap_qvalue_networkN0`` (whose leading
    dimension is presumably one entry per Q-network -- confirm against that
    helper).

    Args:
        tensordict: input data holding the Q-network inputs, the action and
            the ``"next"`` entries used by the value estimator.

    Returns:
        A tuple ``(loss_qval, metadata)`` where ``loss_qval`` is the distance
        loss summed over dim 0 and then reduced with ``self.reduction``, and
        ``metadata`` maps ``"td_error"`` to the detached squared difference
        between prediction and target.

    Raises:
        RuntimeError: if ``self.action_space`` is neither ``"categorical"``
            nor ``"one_hot"``.
    """
    # Build the target from the "next" data; no gradient flows through it.
    actor_in_keys = self.actor_network.in_keys
    target_input = tensordict.select(
        "next", *actor_in_keys, self.tensor_keys.action, strict=False
    )
    with torch.no_grad():
        target_estimate = self.value_estimator.value_estimate(
            target_input, target_params=self.target_value_network_params
        )
        target_value = target_estimate.squeeze(-1)

    # Predict the current Q-values with the (non-target) Q-network params.
    q_input = tensordict.select(*self.qvalue_network.in_keys, strict=False)
    q_output = self._vmap_qvalue_networkN0(q_input, self.qvalue_network_params)
    q_values = q_output.get(self.tensor_keys.state_action_value)
    taken_action = tensordict.get(self.tensor_keys.action)

    if self.action_space == "categorical":
        # Extra batch dims added by the vmapped call are discounted before
        # comparing ranks with the action.
        extra_batch_dims = q_output.ndim - tensordict.ndim
        if taken_action.ndim < (q_values.ndim - extra_batch_dims):
            # The action lacks one trailing singleton dim: add it so it can
            # serve as a gather index.
            taken_action = taken_action.unsqueeze(-1)

        mapper = _pseudo_vmap if self.deactivate_vmap else torch.vmap

        def _gather_chosen(values, index):
            # Pick the Q-value at the action index along the last dim.
            return torch.gather(values, -1, index=index).squeeze(-1)

        # Map over dim 0 of the Q-values; the action is shared (None).
        pred_val = mapper(_gather_chosen, (0, None))(q_values, taken_action)
    elif self.action_space == "one_hot":
        one_hot_mask = taken_action.to(torch.float)
        pred_val = (q_values * one_hot_mask).sum(-1)
    else:
        raise RuntimeError(f"Unknown action space {self.action_space}.")

    expanded_target = target_value.expand_as(pred_val)
    td_error = (pred_val - expanded_target).pow(2)
    per_sample_loss = distance_loss(
        pred_val,
        expanded_target,
        loss_function=self.loss_function,
    ).sum(0)
    loss_qval = _reduce(per_sample_loss, reduction=self.reduction)
    metadata = {"td_error": td_error.detach()}

    self._clear_weakrefs(
        tensordict,
        "actor_network_params",
        "qvalue_network_params",
        "value_network_params",
        "target_actor_network_params",
        "target_qvalue_network_params",
        "target_value_network_params",
    )
    return loss_qval, metadata
|