torchrl 0.11.0__cp314-cp314t-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/benchmark_batched_envs.py +104 -0
- benchmarks/conftest.py +91 -0
- benchmarks/ecosystem/gym_env_throughput.py +321 -0
- benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
- benchmarks/requirements.txt +7 -0
- benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
- benchmarks/test_collectors_benchmark.py +240 -0
- benchmarks/test_compressed_storage_benchmark.py +145 -0
- benchmarks/test_envs_benchmark.py +133 -0
- benchmarks/test_llm.py +101 -0
- benchmarks/test_non_tensor_env_benchmark.py +70 -0
- benchmarks/test_objectives_benchmarks.py +1199 -0
- benchmarks/test_replaybuffer_benchmark.py +254 -0
- sota-check/README.md +35 -0
- sota-implementations/README.md +142 -0
- sota-implementations/a2c/README.md +39 -0
- sota-implementations/a2c/a2c_atari.py +291 -0
- sota-implementations/a2c/a2c_mujoco.py +273 -0
- sota-implementations/a2c/utils_atari.py +240 -0
- sota-implementations/a2c/utils_mujoco.py +160 -0
- sota-implementations/bandits/README.md +7 -0
- sota-implementations/bandits/dqn.py +126 -0
- sota-implementations/cql/cql_offline.py +198 -0
- sota-implementations/cql/cql_online.py +249 -0
- sota-implementations/cql/discrete_cql_offline.py +180 -0
- sota-implementations/cql/discrete_cql_online.py +227 -0
- sota-implementations/cql/utils.py +471 -0
- sota-implementations/crossq/crossq.py +271 -0
- sota-implementations/crossq/utils.py +320 -0
- sota-implementations/ddpg/ddpg.py +231 -0
- sota-implementations/ddpg/utils.py +325 -0
- sota-implementations/decision_transformer/dt.py +163 -0
- sota-implementations/decision_transformer/lamb.py +167 -0
- sota-implementations/decision_transformer/online_dt.py +178 -0
- sota-implementations/decision_transformer/utils.py +562 -0
- sota-implementations/discrete_sac/discrete_sac.py +243 -0
- sota-implementations/discrete_sac/utils.py +324 -0
- sota-implementations/dqn/README.md +30 -0
- sota-implementations/dqn/dqn_atari.py +272 -0
- sota-implementations/dqn/dqn_cartpole.py +236 -0
- sota-implementations/dqn/utils_atari.py +132 -0
- sota-implementations/dqn/utils_cartpole.py +90 -0
- sota-implementations/dreamer/README.md +129 -0
- sota-implementations/dreamer/dreamer.py +586 -0
- sota-implementations/dreamer/dreamer_utils.py +1107 -0
- sota-implementations/expert-iteration/README.md +352 -0
- sota-implementations/expert-iteration/ei_utils.py +770 -0
- sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
- sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
- sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
- sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
- sota-implementations/gail/gail.py +327 -0
- sota-implementations/gail/gail_utils.py +68 -0
- sota-implementations/gail/ppo_utils.py +157 -0
- sota-implementations/grpo/README.md +273 -0
- sota-implementations/grpo/grpo-async.py +437 -0
- sota-implementations/grpo/grpo-sync.py +435 -0
- sota-implementations/grpo/grpo_utils.py +843 -0
- sota-implementations/grpo/requirements_gsm8k.txt +11 -0
- sota-implementations/grpo/requirements_ifeval.txt +16 -0
- sota-implementations/impala/README.md +33 -0
- sota-implementations/impala/impala_multi_node_ray.py +292 -0
- sota-implementations/impala/impala_multi_node_submitit.py +284 -0
- sota-implementations/impala/impala_single_node.py +261 -0
- sota-implementations/impala/utils.py +184 -0
- sota-implementations/iql/discrete_iql.py +230 -0
- sota-implementations/iql/iql_offline.py +164 -0
- sota-implementations/iql/iql_online.py +225 -0
- sota-implementations/iql/utils.py +437 -0
- sota-implementations/multiagent/README.md +74 -0
- sota-implementations/multiagent/iql.py +237 -0
- sota-implementations/multiagent/maddpg_iddpg.py +266 -0
- sota-implementations/multiagent/mappo_ippo.py +267 -0
- sota-implementations/multiagent/qmix_vdn.py +271 -0
- sota-implementations/multiagent/sac.py +337 -0
- sota-implementations/multiagent/utils/__init__.py +4 -0
- sota-implementations/multiagent/utils/logging.py +151 -0
- sota-implementations/multiagent/utils/utils.py +43 -0
- sota-implementations/ppo/README.md +29 -0
- sota-implementations/ppo/ppo_atari.py +305 -0
- sota-implementations/ppo/ppo_mujoco.py +293 -0
- sota-implementations/ppo/utils_atari.py +238 -0
- sota-implementations/ppo/utils_mujoco.py +152 -0
- sota-implementations/ppo_trainer/train.py +21 -0
- sota-implementations/redq/README.md +7 -0
- sota-implementations/redq/redq.py +199 -0
- sota-implementations/redq/utils.py +1060 -0
- sota-implementations/sac/sac-async.py +266 -0
- sota-implementations/sac/sac.py +239 -0
- sota-implementations/sac/utils.py +381 -0
- sota-implementations/sac_trainer/train.py +16 -0
- sota-implementations/td3/td3.py +254 -0
- sota-implementations/td3/utils.py +319 -0
- sota-implementations/td3_bc/td3_bc.py +177 -0
- sota-implementations/td3_bc/utils.py +251 -0
- torchrl/.dylibs/libc++.1.0.dylib +0 -0
- torchrl/__init__.py +144 -0
- torchrl/_extension.py +74 -0
- torchrl/_torchrl.cpython-314t-darwin.so +0 -0
- torchrl/_utils.py +1431 -0
- torchrl/collectors/__init__.py +48 -0
- torchrl/collectors/_base.py +1058 -0
- torchrl/collectors/_constants.py +88 -0
- torchrl/collectors/_multi_async.py +324 -0
- torchrl/collectors/_multi_base.py +1805 -0
- torchrl/collectors/_multi_sync.py +464 -0
- torchrl/collectors/_runner.py +581 -0
- torchrl/collectors/_single.py +2009 -0
- torchrl/collectors/_single_async.py +259 -0
- torchrl/collectors/collectors.py +62 -0
- torchrl/collectors/distributed/__init__.py +32 -0
- torchrl/collectors/distributed/default_configs.py +133 -0
- torchrl/collectors/distributed/generic.py +1306 -0
- torchrl/collectors/distributed/ray.py +1092 -0
- torchrl/collectors/distributed/rpc.py +1006 -0
- torchrl/collectors/distributed/sync.py +731 -0
- torchrl/collectors/distributed/utils.py +160 -0
- torchrl/collectors/llm/__init__.py +10 -0
- torchrl/collectors/llm/base.py +494 -0
- torchrl/collectors/llm/ray_collector.py +275 -0
- torchrl/collectors/llm/utils.py +36 -0
- torchrl/collectors/llm/weight_update/__init__.py +10 -0
- torchrl/collectors/llm/weight_update/vllm.py +348 -0
- torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
- torchrl/collectors/utils.py +433 -0
- torchrl/collectors/weight_update.py +591 -0
- torchrl/csrc/numpy_utils.h +38 -0
- torchrl/csrc/pybind.cpp +27 -0
- torchrl/csrc/segment_tree.h +458 -0
- torchrl/csrc/torch_utils.h +34 -0
- torchrl/csrc/utils.cpp +48 -0
- torchrl/csrc/utils.h +31 -0
- torchrl/data/__init__.py +187 -0
- torchrl/data/datasets/__init__.py +58 -0
- torchrl/data/datasets/atari_dqn.py +878 -0
- torchrl/data/datasets/common.py +281 -0
- torchrl/data/datasets/d4rl.py +489 -0
- torchrl/data/datasets/d4rl_infos.py +187 -0
- torchrl/data/datasets/gen_dgrl.py +375 -0
- torchrl/data/datasets/minari_data.py +643 -0
- torchrl/data/datasets/openml.py +177 -0
- torchrl/data/datasets/openx.py +798 -0
- torchrl/data/datasets/roboset.py +363 -0
- torchrl/data/datasets/utils.py +11 -0
- torchrl/data/datasets/vd4rl.py +432 -0
- torchrl/data/llm/__init__.py +34 -0
- torchrl/data/llm/dataset.py +491 -0
- torchrl/data/llm/history.py +1378 -0
- torchrl/data/llm/prompt.py +198 -0
- torchrl/data/llm/reward.py +225 -0
- torchrl/data/llm/topk.py +186 -0
- torchrl/data/llm/utils.py +543 -0
- torchrl/data/map/__init__.py +21 -0
- torchrl/data/map/hash.py +185 -0
- torchrl/data/map/query.py +204 -0
- torchrl/data/map/tdstorage.py +363 -0
- torchrl/data/map/tree.py +1434 -0
- torchrl/data/map/utils.py +103 -0
- torchrl/data/postprocs/__init__.py +8 -0
- torchrl/data/postprocs/postprocs.py +391 -0
- torchrl/data/replay_buffers/__init__.py +99 -0
- torchrl/data/replay_buffers/checkpointers.py +622 -0
- torchrl/data/replay_buffers/ray_buffer.py +292 -0
- torchrl/data/replay_buffers/replay_buffers.py +2376 -0
- torchrl/data/replay_buffers/samplers.py +2578 -0
- torchrl/data/replay_buffers/scheduler.py +265 -0
- torchrl/data/replay_buffers/storages.py +2412 -0
- torchrl/data/replay_buffers/utils.py +1042 -0
- torchrl/data/replay_buffers/writers.py +781 -0
- torchrl/data/tensor_specs.py +7101 -0
- torchrl/data/utils.py +334 -0
- torchrl/envs/__init__.py +265 -0
- torchrl/envs/async_envs.py +1105 -0
- torchrl/envs/batched_envs.py +3093 -0
- torchrl/envs/common.py +4241 -0
- torchrl/envs/custom/__init__.py +11 -0
- torchrl/envs/custom/chess.py +617 -0
- torchrl/envs/custom/llm.py +214 -0
- torchrl/envs/custom/pendulum.py +401 -0
- torchrl/envs/custom/san_moves.txt +29274 -0
- torchrl/envs/custom/tictactoeenv.py +288 -0
- torchrl/envs/env_creator.py +263 -0
- torchrl/envs/gym_like.py +752 -0
- torchrl/envs/libs/__init__.py +68 -0
- torchrl/envs/libs/_gym_utils.py +326 -0
- torchrl/envs/libs/brax.py +846 -0
- torchrl/envs/libs/dm_control.py +544 -0
- torchrl/envs/libs/envpool.py +447 -0
- torchrl/envs/libs/gym.py +2239 -0
- torchrl/envs/libs/habitat.py +138 -0
- torchrl/envs/libs/isaac_lab.py +87 -0
- torchrl/envs/libs/isaacgym.py +203 -0
- torchrl/envs/libs/jax_utils.py +166 -0
- torchrl/envs/libs/jumanji.py +963 -0
- torchrl/envs/libs/meltingpot.py +599 -0
- torchrl/envs/libs/openml.py +153 -0
- torchrl/envs/libs/openspiel.py +652 -0
- torchrl/envs/libs/pettingzoo.py +1042 -0
- torchrl/envs/libs/procgen.py +351 -0
- torchrl/envs/libs/robohive.py +429 -0
- torchrl/envs/libs/smacv2.py +645 -0
- torchrl/envs/libs/unity_mlagents.py +891 -0
- torchrl/envs/libs/utils.py +147 -0
- torchrl/envs/libs/vmas.py +813 -0
- torchrl/envs/llm/__init__.py +63 -0
- torchrl/envs/llm/chat.py +730 -0
- torchrl/envs/llm/datasets/README.md +4 -0
- torchrl/envs/llm/datasets/__init__.py +17 -0
- torchrl/envs/llm/datasets/gsm8k.py +353 -0
- torchrl/envs/llm/datasets/ifeval.py +274 -0
- torchrl/envs/llm/envs.py +789 -0
- torchrl/envs/llm/libs/README.md +3 -0
- torchrl/envs/llm/libs/__init__.py +8 -0
- torchrl/envs/llm/libs/mlgym.py +869 -0
- torchrl/envs/llm/reward/__init__.py +10 -0
- torchrl/envs/llm/reward/gsm8k.py +324 -0
- torchrl/envs/llm/reward/ifeval/README.md +13 -0
- torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
- torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
- torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
- torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
- torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
- torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
- torchrl/envs/llm/transforms/__init__.py +55 -0
- torchrl/envs/llm/transforms/browser.py +292 -0
- torchrl/envs/llm/transforms/dataloading.py +859 -0
- torchrl/envs/llm/transforms/format.py +73 -0
- torchrl/envs/llm/transforms/kl.py +1544 -0
- torchrl/envs/llm/transforms/policy_version.py +189 -0
- torchrl/envs/llm/transforms/reason.py +323 -0
- torchrl/envs/llm/transforms/tokenizer.py +321 -0
- torchrl/envs/llm/transforms/tools.py +1955 -0
- torchrl/envs/model_based/__init__.py +9 -0
- torchrl/envs/model_based/common.py +180 -0
- torchrl/envs/model_based/dreamer.py +112 -0
- torchrl/envs/transforms/__init__.py +147 -0
- torchrl/envs/transforms/functional.py +48 -0
- torchrl/envs/transforms/gym_transforms.py +203 -0
- torchrl/envs/transforms/module.py +341 -0
- torchrl/envs/transforms/r3m.py +372 -0
- torchrl/envs/transforms/ray_service.py +663 -0
- torchrl/envs/transforms/rb_transforms.py +214 -0
- torchrl/envs/transforms/transforms.py +11835 -0
- torchrl/envs/transforms/utils.py +94 -0
- torchrl/envs/transforms/vc1.py +307 -0
- torchrl/envs/transforms/vecnorm.py +845 -0
- torchrl/envs/transforms/vip.py +407 -0
- torchrl/envs/utils.py +1718 -0
- torchrl/envs/vec_envs.py +11 -0
- torchrl/modules/__init__.py +206 -0
- torchrl/modules/distributions/__init__.py +73 -0
- torchrl/modules/distributions/continuous.py +830 -0
- torchrl/modules/distributions/discrete.py +908 -0
- torchrl/modules/distributions/truncated_normal.py +187 -0
- torchrl/modules/distributions/utils.py +233 -0
- torchrl/modules/llm/__init__.py +62 -0
- torchrl/modules/llm/backends/__init__.py +65 -0
- torchrl/modules/llm/backends/vllm/__init__.py +94 -0
- torchrl/modules/llm/backends/vllm/_models.py +46 -0
- torchrl/modules/llm/backends/vllm/base.py +72 -0
- torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
- torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
- torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
- torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
- torchrl/modules/llm/policies/__init__.py +28 -0
- torchrl/modules/llm/policies/common.py +1809 -0
- torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
- torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
- torchrl/modules/llm/utils.py +23 -0
- torchrl/modules/mcts/__init__.py +21 -0
- torchrl/modules/mcts/scores.py +579 -0
- torchrl/modules/models/__init__.py +86 -0
- torchrl/modules/models/batchrenorm.py +119 -0
- torchrl/modules/models/decision_transformer.py +179 -0
- torchrl/modules/models/exploration.py +731 -0
- torchrl/modules/models/llm.py +156 -0
- torchrl/modules/models/model_based.py +596 -0
- torchrl/modules/models/models.py +1712 -0
- torchrl/modules/models/multiagent.py +1067 -0
- torchrl/modules/models/recipes/impala.py +185 -0
- torchrl/modules/models/utils.py +162 -0
- torchrl/modules/planners/__init__.py +10 -0
- torchrl/modules/planners/cem.py +228 -0
- torchrl/modules/planners/common.py +73 -0
- torchrl/modules/planners/mppi.py +265 -0
- torchrl/modules/tensordict_module/__init__.py +89 -0
- torchrl/modules/tensordict_module/actors.py +2457 -0
- torchrl/modules/tensordict_module/common.py +529 -0
- torchrl/modules/tensordict_module/exploration.py +814 -0
- torchrl/modules/tensordict_module/probabilistic.py +321 -0
- torchrl/modules/tensordict_module/rnn.py +1639 -0
- torchrl/modules/tensordict_module/sequence.py +132 -0
- torchrl/modules/tensordict_module/world_models.py +34 -0
- torchrl/modules/utils/__init__.py +38 -0
- torchrl/modules/utils/mappings.py +9 -0
- torchrl/modules/utils/utils.py +89 -0
- torchrl/objectives/__init__.py +78 -0
- torchrl/objectives/a2c.py +659 -0
- torchrl/objectives/common.py +753 -0
- torchrl/objectives/cql.py +1346 -0
- torchrl/objectives/crossq.py +710 -0
- torchrl/objectives/ddpg.py +453 -0
- torchrl/objectives/decision_transformer.py +371 -0
- torchrl/objectives/deprecated.py +516 -0
- torchrl/objectives/dqn.py +683 -0
- torchrl/objectives/dreamer.py +488 -0
- torchrl/objectives/functional.py +48 -0
- torchrl/objectives/gail.py +258 -0
- torchrl/objectives/iql.py +996 -0
- torchrl/objectives/llm/__init__.py +30 -0
- torchrl/objectives/llm/grpo.py +846 -0
- torchrl/objectives/llm/sft.py +482 -0
- torchrl/objectives/multiagent/__init__.py +8 -0
- torchrl/objectives/multiagent/qmixer.py +396 -0
- torchrl/objectives/ppo.py +1669 -0
- torchrl/objectives/redq.py +683 -0
- torchrl/objectives/reinforce.py +530 -0
- torchrl/objectives/sac.py +1580 -0
- torchrl/objectives/td3.py +570 -0
- torchrl/objectives/td3_bc.py +625 -0
- torchrl/objectives/utils.py +782 -0
- torchrl/objectives/value/__init__.py +28 -0
- torchrl/objectives/value/advantages.py +1956 -0
- torchrl/objectives/value/functional.py +1459 -0
- torchrl/objectives/value/utils.py +360 -0
- torchrl/record/__init__.py +17 -0
- torchrl/record/loggers/__init__.py +23 -0
- torchrl/record/loggers/common.py +48 -0
- torchrl/record/loggers/csv.py +226 -0
- torchrl/record/loggers/mlflow.py +142 -0
- torchrl/record/loggers/tensorboard.py +139 -0
- torchrl/record/loggers/trackio.py +163 -0
- torchrl/record/loggers/utils.py +78 -0
- torchrl/record/loggers/wandb.py +214 -0
- torchrl/record/recorder.py +554 -0
- torchrl/services/__init__.py +79 -0
- torchrl/services/base.py +109 -0
- torchrl/services/ray_service.py +453 -0
- torchrl/testing/__init__.py +107 -0
- torchrl/testing/assertions.py +179 -0
- torchrl/testing/dist_utils.py +122 -0
- torchrl/testing/env_creators.py +227 -0
- torchrl/testing/env_helper.py +35 -0
- torchrl/testing/gym_helpers.py +156 -0
- torchrl/testing/llm_mocks.py +119 -0
- torchrl/testing/mocking_classes.py +2720 -0
- torchrl/testing/modules.py +295 -0
- torchrl/testing/mp_helpers.py +15 -0
- torchrl/testing/ray_helpers.py +293 -0
- torchrl/testing/utils.py +190 -0
- torchrl/trainers/__init__.py +42 -0
- torchrl/trainers/algorithms/__init__.py +11 -0
- torchrl/trainers/algorithms/configs/__init__.py +705 -0
- torchrl/trainers/algorithms/configs/collectors.py +216 -0
- torchrl/trainers/algorithms/configs/common.py +41 -0
- torchrl/trainers/algorithms/configs/data.py +308 -0
- torchrl/trainers/algorithms/configs/envs.py +104 -0
- torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
- torchrl/trainers/algorithms/configs/logging.py +80 -0
- torchrl/trainers/algorithms/configs/modules.py +570 -0
- torchrl/trainers/algorithms/configs/objectives.py +177 -0
- torchrl/trainers/algorithms/configs/trainers.py +340 -0
- torchrl/trainers/algorithms/configs/transforms.py +955 -0
- torchrl/trainers/algorithms/configs/utils.py +252 -0
- torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
- torchrl/trainers/algorithms/configs/weight_update.py +159 -0
- torchrl/trainers/algorithms/ppo.py +373 -0
- torchrl/trainers/algorithms/sac.py +308 -0
- torchrl/trainers/helpers/__init__.py +40 -0
- torchrl/trainers/helpers/collectors.py +416 -0
- torchrl/trainers/helpers/envs.py +573 -0
- torchrl/trainers/helpers/logger.py +33 -0
- torchrl/trainers/helpers/losses.py +132 -0
- torchrl/trainers/helpers/models.py +658 -0
- torchrl/trainers/helpers/replay_buffer.py +59 -0
- torchrl/trainers/helpers/trainers.py +301 -0
- torchrl/trainers/trainers.py +2052 -0
- torchrl/weight_update/__init__.py +33 -0
- torchrl/weight_update/_distributed.py +749 -0
- torchrl/weight_update/_mp.py +624 -0
- torchrl/weight_update/_noupdate.py +102 -0
- torchrl/weight_update/_ray.py +1032 -0
- torchrl/weight_update/_rpc.py +284 -0
- torchrl/weight_update/_shared.py +891 -0
- torchrl/weight_update/llm/__init__.py +32 -0
- torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
- torchrl/weight_update/llm/vllm_nccl.py +710 -0
- torchrl/weight_update/utils.py +73 -0
- torchrl/weight_update/weight_sync_schemes.py +1244 -0
- torchrl-0.11.0.dist-info/METADATA +1308 -0
- torchrl-0.11.0.dist-info/RECORD +395 -0
- torchrl-0.11.0.dist-info/WHEEL +5 -0
- torchrl-0.11.0.dist-info/entry_points.txt +2 -0
- torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
- torchrl-0.11.0.dist-info/top_level.txt +7 -0
|
@@ -0,0 +1,1346 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
#
|
|
3
|
+
# This source code is licensed under the MIT license found in the
|
|
4
|
+
# LICENSE file in the root directory of this source tree.
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
import math
|
|
8
|
+
import warnings
|
|
9
|
+
from copy import deepcopy
|
|
10
|
+
from dataclasses import dataclass
|
|
11
|
+
|
|
12
|
+
import numpy as np
|
|
13
|
+
import torch
|
|
14
|
+
import torch.nn as nn
|
|
15
|
+
from tensordict import TensorDict, TensorDictBase, TensorDictParams
|
|
16
|
+
from tensordict.nn import dispatch, ProbabilisticTensorDictSequential, TensorDictModule
|
|
17
|
+
from tensordict.utils import NestedKey, unravel_key
|
|
18
|
+
from torch import Tensor
|
|
19
|
+
|
|
20
|
+
from torchrl.data.tensor_specs import Composite
|
|
21
|
+
from torchrl.data.utils import _find_action_space
|
|
22
|
+
from torchrl.envs.utils import ExplorationType, set_exploration_type
|
|
23
|
+
from torchrl.modules.tensordict_module.actors import QValueActor
|
|
24
|
+
from torchrl.modules.tensordict_module.common import ensure_tensordict_compatible
|
|
25
|
+
from torchrl.objectives.common import LossModule
|
|
26
|
+
from torchrl.objectives.utils import (
|
|
27
|
+
_cache_values,
|
|
28
|
+
_GAMMA_LMBDA_DEPREC_ERROR,
|
|
29
|
+
_reduce,
|
|
30
|
+
_vmap_func,
|
|
31
|
+
default_value_kwargs,
|
|
32
|
+
distance_loss,
|
|
33
|
+
ValueEstimators,
|
|
34
|
+
)
|
|
35
|
+
from torchrl.objectives.value import (
|
|
36
|
+
TD0Estimator,
|
|
37
|
+
TD1Estimator,
|
|
38
|
+
TDLambdaEstimator,
|
|
39
|
+
ValueEstimatorBase,
|
|
40
|
+
)
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
class CQLLoss(LossModule):
|
|
44
|
+
"""TorchRL implementation of the continuous CQL loss.
|
|
45
|
+
|
|
46
|
+
Presented in "Conservative Q-Learning for Offline Reinforcement Learning" https://arxiv.org/abs/2006.04779
|
|
47
|
+
|
|
48
|
+
Args:
|
|
49
|
+
actor_network (ProbabilisticTensorDictSequential): stochastic actor
|
|
50
|
+
qvalue_network (TensorDictModule or list of TensorDictModule): Q(s, a) parametric model.
|
|
51
|
+
This module typically outputs a ``"state_action_value"`` entry.
|
|
52
|
+
If a single instance of `qvalue_network` is provided, it will be duplicated ``N``
|
|
53
|
+
times (where ``N=2`` for this loss). If a list of modules is passed, their
|
|
54
|
+
parameters will be stacked unless they share the same identity (in which case
|
|
55
|
+
the original parameter will be expanded).
|
|
56
|
+
|
|
57
|
+
.. warning:: When a list of parameters if passed, it will **not** be compared against the policy parameters
|
|
58
|
+
and all the parameters will be considered as untied.
|
|
59
|
+
|
|
60
|
+
Keyword args:
|
|
61
|
+
loss_function (str, optional): loss function to be used with
|
|
62
|
+
the value function loss. Default is `"smooth_l1"`.
|
|
63
|
+
alpha_init (:obj:`float`, optional): initial entropy multiplier.
|
|
64
|
+
Default is 1.0.
|
|
65
|
+
min_alpha (:obj:`float`, optional): min value of alpha.
|
|
66
|
+
Default is None (no minimum value).
|
|
67
|
+
max_alpha (:obj:`float`, optional): max value of alpha.
|
|
68
|
+
Default is None (no maximum value).
|
|
69
|
+
action_spec (TensorSpec, optional): the action tensor spec. If not provided
|
|
70
|
+
and the target entropy is ``"auto"``, it will be retrieved from
|
|
71
|
+
the actor.
|
|
72
|
+
fixed_alpha (bool, optional): if ``True``, alpha will be fixed to its
|
|
73
|
+
initial value. Otherwise, alpha will be optimized to
|
|
74
|
+
match the 'target_entropy' value.
|
|
75
|
+
Default is ``False``.
|
|
76
|
+
target_entropy (:obj:`float` or str, optional): Target entropy for the
|
|
77
|
+
stochastic policy. Default is "auto", where target entropy is
|
|
78
|
+
computed as :obj:`-prod(n_actions)`.
|
|
79
|
+
delay_actor (bool, optional): Whether to separate the target actor
|
|
80
|
+
networks from the actor networks used for data collection.
|
|
81
|
+
Default is ``False``.
|
|
82
|
+
delay_qvalue (bool, optional): Whether to separate the target Q value
|
|
83
|
+
networks from the Q value networks used for data collection.
|
|
84
|
+
Default is ``True``.
|
|
85
|
+
gamma (:obj:`float`, optional): Discount factor. Default is ``None``.
|
|
86
|
+
temperature (:obj:`float`, optional): CQL temperature. Default is 1.0.
|
|
87
|
+
min_q_weight (:obj:`float`, optional): Minimum Q weight. Default is 1.0.
|
|
88
|
+
max_q_backup (bool, optional): Whether to use the max-min Q backup.
|
|
89
|
+
Default is ``False``.
|
|
90
|
+
deterministic_backup (bool, optional): Whether to use the deterministic. Default is ``True``.
|
|
91
|
+
num_random (int, optional): Number of random actions to sample for the CQL loss.
|
|
92
|
+
Default is 10.
|
|
93
|
+
with_lagrange (bool, optional): Whether to use the Lagrange multiplier.
|
|
94
|
+
Default is ``False``.
|
|
95
|
+
lagrange_thresh (:obj:`float`, optional): Lagrange threshold. Default is 0.0.
|
|
96
|
+
reduction (str, optional): Specifies the reduction to apply to the output:
|
|
97
|
+
``"none"`` | ``"mean"`` | ``"sum"``. ``"none"``: no reduction will be applied,
|
|
98
|
+
``"mean"``: the sum of the output will be divided by the number of
|
|
99
|
+
elements in the output, ``"sum"``: the output will be summed. Default: ``"mean"``.
|
|
100
|
+
deactivate_vmap (bool, optional): whether to deactivate vmap calls and replace them with a plain for loop.
|
|
101
|
+
Defaults to ``False``.
|
|
102
|
+
|
|
103
|
+
Examples:
|
|
104
|
+
>>> import torch
|
|
105
|
+
>>> from torch import nn
|
|
106
|
+
>>> from torchrl.data import Bounded
|
|
107
|
+
>>> from torchrl.modules.distributions import NormalParamExtractor, TanhNormal
|
|
108
|
+
>>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator
|
|
109
|
+
>>> from torchrl.modules.tensordict_module.common import SafeModule
|
|
110
|
+
>>> from torchrl.objectives.cql import CQLLoss
|
|
111
|
+
>>> from tensordict import TensorDict
|
|
112
|
+
>>> n_act, n_obs = 4, 3
|
|
113
|
+
>>> spec = Bounded(-torch.ones(n_act), torch.ones(n_act), (n_act,))
|
|
114
|
+
>>> net = nn.Sequential(nn.Linear(n_obs, 2 * n_act), NormalParamExtractor())
|
|
115
|
+
>>> module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"])
|
|
116
|
+
>>> actor = ProbabilisticActor(
|
|
117
|
+
... module=module,
|
|
118
|
+
... in_keys=["loc", "scale"],
|
|
119
|
+
... spec=spec,
|
|
120
|
+
... distribution_class=TanhNormal)
|
|
121
|
+
>>> class ValueClass(nn.Module):
|
|
122
|
+
... def __init__(self):
|
|
123
|
+
... super().__init__()
|
|
124
|
+
... self.linear = nn.Linear(n_obs + n_act, 1)
|
|
125
|
+
... def forward(self, obs, act):
|
|
126
|
+
... return self.linear(torch.cat([obs, act], -1))
|
|
127
|
+
>>> module = ValueClass()
|
|
128
|
+
>>> qvalue = ValueOperator(
|
|
129
|
+
... module=module,
|
|
130
|
+
... in_keys=['observation', 'action'])
|
|
131
|
+
>>> loss = CQLLoss(actor, qvalue)
|
|
132
|
+
>>> batch = [2, ]
|
|
133
|
+
>>> action = spec.rand(batch)
|
|
134
|
+
>>> data = TensorDict({
|
|
135
|
+
... "observation": torch.randn(*batch, n_obs),
|
|
136
|
+
... "action": action,
|
|
137
|
+
... ("next", "done"): torch.zeros(*batch, 1, dtype=torch.bool),
|
|
138
|
+
... ("next", "terminated"): torch.zeros(*batch, 1, dtype=torch.bool),
|
|
139
|
+
... ("next", "reward"): torch.randn(*batch, 1),
|
|
140
|
+
... ("next", "observation"): torch.randn(*batch, n_obs),
|
|
141
|
+
... }, batch)
|
|
142
|
+
>>> loss(data)
|
|
143
|
+
TensorDict(
|
|
144
|
+
fields={
|
|
145
|
+
alpha: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
|
|
146
|
+
entropy: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
|
|
147
|
+
loss_actor: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
|
|
148
|
+
loss_actor_bc: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
|
|
149
|
+
loss_alpha: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
|
|
150
|
+
loss_cql: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
|
|
151
|
+
loss_qvalue: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
|
|
152
|
+
batch_size=torch.Size([]),
|
|
153
|
+
device=None,
|
|
154
|
+
is_shared=False)
|
|
155
|
+
|
|
156
|
+
This class is compatible with non-tensordict based modules too and can be
|
|
157
|
+
used without recurring to any tensordict-related primitive. In this case,
|
|
158
|
+
the expected keyword arguments are:
|
|
159
|
+
``["action", "next_reward", "next_done", "next_terminated"]`` + in_keys of the actor, value, and qvalue network.
|
|
160
|
+
The return value is a tuple of tensors in the following order:
|
|
161
|
+
``["loss_actor", "loss_qvalue", "loss_alpha", "loss_alpha_prime", "alpha", "entropy"]``.
|
|
162
|
+
|
|
163
|
+
Examples:
|
|
164
|
+
>>> import torch
|
|
165
|
+
>>> from torch import nn
|
|
166
|
+
>>> from torchrl.data import Bounded
|
|
167
|
+
>>> from torchrl.modules.distributions import NormalParamExtractor, TanhNormal
|
|
168
|
+
>>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator
|
|
169
|
+
>>> from torchrl.modules.tensordict_module.common import SafeModule
|
|
170
|
+
>>> from torchrl.objectives.cql import CQLLoss
|
|
171
|
+
>>> _ = torch.manual_seed(42)
|
|
172
|
+
>>> n_act, n_obs = 4, 3
|
|
173
|
+
>>> spec = Bounded(-torch.ones(n_act), torch.ones(n_act), (n_act,))
|
|
174
|
+
>>> net = nn.Sequential(nn.Linear(n_obs, 2 * n_act), NormalParamExtractor())
|
|
175
|
+
>>> module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"])
|
|
176
|
+
>>> actor = ProbabilisticActor(
|
|
177
|
+
... module=module,
|
|
178
|
+
... in_keys=["loc", "scale"],
|
|
179
|
+
... spec=spec,
|
|
180
|
+
... distribution_class=TanhNormal)
|
|
181
|
+
>>> class ValueClass(nn.Module):
|
|
182
|
+
... def __init__(self):
|
|
183
|
+
... super().__init__()
|
|
184
|
+
... self.linear = nn.Linear(n_obs + n_act, 1)
|
|
185
|
+
... def forward(self, obs, act):
|
|
186
|
+
... return self.linear(torch.cat([obs, act], -1))
|
|
187
|
+
>>> module = ValueClass()
|
|
188
|
+
>>> qvalue = ValueOperator(
|
|
189
|
+
... module=module,
|
|
190
|
+
... in_keys=['observation', 'action'])
|
|
191
|
+
>>> loss = CQLLoss(actor, qvalue)
|
|
192
|
+
>>> batch = [2, ]
|
|
193
|
+
>>> action = spec.rand(batch)
|
|
194
|
+
>>> loss_actor, loss_actor_bc, loss_qvalue, loss_cql, *_ = loss(
|
|
195
|
+
... observation=torch.randn(*batch, n_obs),
|
|
196
|
+
... action=action,
|
|
197
|
+
... next_done=torch.zeros(*batch, 1, dtype=torch.bool),
|
|
198
|
+
... next_terminated=torch.zeros(*batch, 1, dtype=torch.bool),
|
|
199
|
+
... next_observation=torch.zeros(*batch, n_obs),
|
|
200
|
+
... next_reward=torch.randn(*batch, 1))
|
|
201
|
+
>>> loss_actor.backward()
|
|
202
|
+
|
|
203
|
+
The output keys can also be filtered using the :meth:`CQLLoss.select_out_keys`
|
|
204
|
+
method.
|
|
205
|
+
|
|
206
|
+
Examples:
|
|
207
|
+
>>> _ = loss.select_out_keys('loss_actor', 'loss_qvalue')
|
|
208
|
+
>>> loss_actor, loss_qvalue = loss(
|
|
209
|
+
... observation=torch.randn(*batch, n_obs),
|
|
210
|
+
... action=action,
|
|
211
|
+
... next_done=torch.zeros(*batch, 1, dtype=torch.bool),
|
|
212
|
+
... next_terminated=torch.zeros(*batch, 1, dtype=torch.bool),
|
|
213
|
+
... next_observation=torch.zeros(*batch, n_obs),
|
|
214
|
+
... next_reward=torch.randn(*batch, 1))
|
|
215
|
+
>>> loss_actor.backward()
|
|
216
|
+
"""
|
|
217
|
+
|
|
218
|
+
@dataclass
class _AcceptedKeys:
    """Maintains default values for all configurable tensordict keys.

    This class defines which tensordict keys can be set using '.set_keys(key_name=key_value)' and their
    default values.

    Attributes:
        action (NestedKey): The input tensordict key where the action is expected.
            Defaults to ``"action"``.
        value (NestedKey): The input tensordict key where the state value is expected.
            Will be used for the underlying value estimator. Defaults to ``"state_value"``.
        state_action_value (NestedKey): The input tensordict key where the
            state action value is expected. Defaults to ``"state_action_value"``.
        log_prob (NestedKey): The key under which the actor loss stores the
            log-probability of the sampled actions in its metadata.
            Defaults to ``"_log_prob"``.
        pred_q1 (NestedKey): The tensordict key where ``q_loss`` writes the predicted
            Q1 values and ``cql_loss`` reads them. Defaults to ``"pred_q1"``.
        pred_q2 (NestedKey): The tensordict key where ``q_loss`` writes the predicted
            Q2 values and ``cql_loss`` reads them. Defaults to ``"pred_q2"``.
        priority (NestedKey): The input tensordict key where the target priority is written to.
            Defaults to ``"td_error"``.
        cql_q1_loss (NestedKey): The tensordict key where ``cql_loss`` writes the CQL Q1
            penalty and ``alpha_prime_loss`` reads it. Defaults to ``"cql_q1_loss"``.
        cql_q2_loss (NestedKey): The tensordict key where ``cql_loss`` writes the CQL Q2
            penalty and ``alpha_prime_loss`` reads it. Defaults to ``"cql_q2_loss"``.
        reward (NestedKey): The input tensordict key where the reward is expected.
            Defaults to ``"reward"``.
        done (NestedKey): The input tensordict key where the done flag is expected.
            Defaults to ``"done"``.
        terminated (NestedKey): The input tensordict key where the terminated flag is expected.
            Defaults to ``"terminated"``.
    """

    action: NestedKey = "action"
    value: NestedKey = "state_value"
    state_action_value: NestedKey = "state_action_value"
    log_prob: NestedKey = "_log_prob"
    pred_q1: NestedKey = "pred_q1"
    pred_q2: NestedKey = "pred_q2"
    priority: NestedKey = "td_error"
    cql_q1_loss: NestedKey = "cql_q1_loss"
    cql_q2_loss: NestedKey = "cql_q2_loss"
    reward: NestedKey = "reward"
    done: NestedKey = "done"
    terminated: NestedKey = "terminated"
|
|
265
|
+
|
|
266
|
+
# Container of configurable key names; the accepted keys and their defaults
# are defined by ``_AcceptedKeys`` (settable through ``.set_keys(...)``).
tensor_keys: _AcceptedKeys
default_keys = _AcceptedKeys
# Estimator built by ``make_value_estimator`` when called with ``value_type=None``.
default_value_estimator = ValueEstimators.TD0

# Networks and their parameter containers. NOTE(review): presumably populated by
# ``convert_to_functional`` in ``__init__`` — confirm against LossModule.
actor_network: TensorDictModule
qvalue_network: TensorDictModule
actor_network_params: TensorDictParams
qvalue_network_params: TensorDictParams
# Target parameter copies; ``__init__`` requests them through
# ``create_target_params=delay_actor`` / ``create_target_params=delay_qvalue``.
target_actor_network_params: TensorDictParams
target_qvalue_network_params: TensorDictParams
|
|
276
|
+
|
|
277
|
+
def __init__(
    self,
    actor_network: ProbabilisticTensorDictSequential,
    qvalue_network: TensorDictModule | list[TensorDictModule],
    *,
    loss_function: str = "smooth_l1",
    alpha_init: float = 1.0,
    min_alpha: float | None = None,
    max_alpha: float | None = None,
    action_spec=None,
    fixed_alpha: bool = False,
    target_entropy: str | float = "auto",
    delay_actor: bool = False,
    delay_qvalue: bool = True,
    gamma: float | None = None,
    temperature: float = 1.0,
    min_q_weight: float = 1.0,
    max_q_backup: bool = False,
    deterministic_backup: bool = True,
    num_random: int = 10,
    with_lagrange: bool = False,
    lagrange_thresh: float = 0.0,
    reduction: str | None = None,
    deactivate_vmap: bool = False,
) -> None:
    """Build the CQL loss; see the class docstring for the argument descriptions.

    Raises:
        ValueError: if exactly one alpha bound is given and ``max_alpha == 0``.
        TypeError: if the deprecated ``gamma`` argument is passed.
    """
    self._out_keys = None
    if reduction is None:
        reduction = "mean"
    super().__init__()

    # Actor
    self.delay_actor = delay_actor
    self.convert_to_functional(
        actor_network,
        "actor_network",
        create_target_params=self.delay_actor,
    )
    self.deactivate_vmap = deactivate_vmap

    # Q value: two q-value networks (clipped double-Q)
    self.delay_qvalue = delay_qvalue
    self.num_qvalue_nets = 2

    self.convert_to_functional(
        qvalue_network,
        "qvalue_network",
        self.num_qvalue_nets,
        create_target_params=self.delay_qvalue,
        compare_against=list(actor_network.parameters()),
    )

    self.loss_function = loss_function
    try:
        device = next(self.parameters()).device
    except (AttributeError, StopIteration):
        # ``next`` raises StopIteration when the networks hold no parameters;
        # fall back to the default device in that case.
        device = getattr(torch, "get_default_device", lambda: torch.device("cpu"))()
    self.register_buffer("alpha_init", torch.tensor(alpha_init, device=device))
    # If only one of the two bounds is given, fill in the other one.
    if bool(min_alpha) ^ bool(max_alpha):
        min_alpha = min_alpha if min_alpha else 0.0
        if max_alpha == 0:
            raise ValueError("max_alpha must be either None or greater than 0.")
        max_alpha = max_alpha if max_alpha else 1e9
    if min_alpha:
        self.register_buffer(
            "min_log_alpha", torch.tensor(min_alpha, device=device).log()
        )
    else:
        self.min_log_alpha = None
    if max_alpha:
        self.register_buffer(
            "max_log_alpha", torch.tensor(max_alpha, device=device).log()
        )
    else:
        self.max_log_alpha = None
    self.fixed_alpha = fixed_alpha
    # A fixed alpha is a buffer (not optimized); otherwise it is a learnable parameter.
    if fixed_alpha:
        self.register_buffer(
            "log_alpha", torch.tensor(math.log(alpha_init), device=device)
        )
    else:
        self.register_parameter(
            "log_alpha",
            torch.nn.Parameter(torch.tensor(math.log(alpha_init), device=device)),
        )

    # The target entropy is resolved lazily by the ``target_entropy`` property.
    self._target_entropy = target_entropy
    self._action_spec = action_spec
    self.target_entropy_buffer = None

    if gamma is not None:
        raise TypeError(_GAMMA_LMBDA_DEPREC_ERROR)

    self.temperature = temperature
    self.min_q_weight = min_q_weight
    self.max_q_backup = max_q_backup
    self.deterministic_backup = deterministic_backup
    self.num_random = num_random
    self.with_lagrange = with_lagrange

    if self.with_lagrange:
        self.target_action_gap = lagrange_thresh
        self.register_parameter(
            "log_alpha_prime",
            torch.nn.Parameter(torch.tensor(math.log(1.0), device=device)),
        )
    self._make_vmap()
    self.reduction = reduction
    # Resolve (and cache as a buffer) the target entropy eagerly.
    _ = self.target_entropy
|
|
385
|
+
|
|
386
|
+
def _make_vmap(self):
    """Create the vmapped q-value callables used by the individual losses.

    ``_vmap_qvalue_networkN0`` maps only over the stacked parameters (in_dims
    ``(None, 0)``): one input tensordict is evaluated by every q-value net.
    ``_vmap_qvalue_network00`` uses ``_vmap_func``'s default in_dims and is
    called by ``cql_loss`` with an input tensordict and params that are both
    stacked along dim 0.
    """
    shared_kwargs = {
        "randomness": self.vmap_randomness,
        "pseudo_vmap": self.deactivate_vmap,
    }
    self._vmap_qvalue_networkN0 = _vmap_func(
        self.qvalue_network, (None, 0), **shared_kwargs
    )
    self._vmap_qvalue_network00 = _vmap_func(self.qvalue_network, **shared_kwargs)
|
|
398
|
+
|
|
399
|
+
@property
def target_entropy(self):
    """Target entropy used by :meth:`alpha_loss`.

    On first access the value is resolved from the ``target_entropy``
    constructor argument and cached in the ``target_entropy_buffer`` buffer;
    later accesses return that buffer. With ``"auto"`` the target is minus the
    number of elements of a single action, i.e. the action spec's shape with
    its container dimensions removed.

    Raises:
        RuntimeError: if ``"auto"`` is requested but no action spec is available.
    """
    target_entropy = self.target_entropy_buffer
    if target_entropy is not None:
        return target_entropy

    # Remove the plain ``None`` placeholder so the name can become a buffer.
    delattr(self, "target_entropy_buffer")
    target_entropy = self._target_entropy
    action_spec = self._action_spec
    actor_network = self.actor_network
    try:
        device = next(self.parameters()).device
    except StopIteration:
        # No parameters at all (e.g. fixed alpha and parameter-free networks):
        # use the device log_alpha was created on.
        device = self.log_alpha.device
    if target_entropy == "auto":
        action_spec = (
            action_spec
            if action_spec is not None
            else getattr(actor_network, "spec", None)
        )
        if action_spec is None:
            raise RuntimeError(
                "Cannot infer the dimensionality of the action. Consider providing "
                "the target entropy explicitly or provide the spec of the "
                "action tensor in the actor network."
            )
        if not isinstance(action_spec, Composite):
            action_spec = Composite({self.tensor_keys.action: action_spec})
        # For nested action keys, strip the shape of the enclosing container.
        if (
            isinstance(self.tensor_keys.action, tuple)
            and len(self.tensor_keys.action) > 1
        ):
            action_container_shape = action_spec[
                self.tensor_keys.action[:-1]
            ].shape
        else:
            action_container_shape = action_spec.shape
        target_entropy = -float(
            action_spec[self.tensor_keys.action]
            .shape[len(action_container_shape) :]
            .numel()
        )
    self.register_buffer(
        "target_entropy_buffer", torch.tensor(target_entropy, device=device)
    )
    return self.target_entropy_buffer
|
|
441
|
+
|
|
442
|
+
def _forward_value_estimator_keys(self, **kwargs) -> None:
    """Propagate this loss's key names to the value estimator, if one exists."""
    estimator = self._value_estimator
    if estimator is None:
        return
    keys = self.tensor_keys
    estimator.set_keys(
        value=self._tensor_keys.value,
        reward=keys.reward,
        done=keys.done,
        terminated=keys.terminated,
    )
|
|
450
|
+
|
|
451
|
+
def make_value_estimator(
    self, value_type: ValueEstimators | None = None, **hyperparams
):
    """Create the value estimator used to compute the Bellman targets.

    Args:
        value_type: a :class:`ValueEstimators` member, or a ``ValueEstimatorBase``
            subclass/instance (delegated to :meth:`LossModule.make_value_estimator`).
            Defaults to ``self.default_value_estimator``.
        **hyperparams: overrides applied on top of ``default_value_kwargs(value_type)``.

    Raises:
        NotImplementedError: for ``ValueEstimators.GAE`` and for unknown value types.
    """
    if value_type is None:
        value_type = self.default_value_estimator

    # Handle ValueEstimatorBase instance or class
    if isinstance(value_type, ValueEstimatorBase) or (
        isinstance(value_type, type) and issubclass(value_type, ValueEstimatorBase)
    ):
        return LossModule.make_value_estimator(self, value_type, **hyperparams)

    self.value_type = value_type

    # we will take care of computing the next value inside this module
    value_net = None

    hp = dict(default_value_kwargs(value_type))
    hp.update(hyperparams)
    if value_type is ValueEstimators.TD1:
        self._value_estimator = TD1Estimator(
            **hp,
            value_network=value_net,
        )
    elif value_type is ValueEstimators.TD0:
        self._value_estimator = TD0Estimator(
            **hp,
            value_network=value_net,
        )
    elif value_type is ValueEstimators.GAE:
        raise NotImplementedError(
            f"Value type {value_type} is not implemented for loss {type(self)}."
        )
    elif value_type is ValueEstimators.TDLambda:
        self._value_estimator = TDLambdaEstimator(
            **hp,
            value_network=value_net,
        )
    else:
        raise NotImplementedError(f"Unknown value type {value_type}")

    tensor_keys = {
        "value_target": "value_target",
        "value": self.tensor_keys.value,
        "reward": self.tensor_keys.reward,
        "done": self.tensor_keys.done,
        "terminated": self.tensor_keys.terminated,
    }
    self._value_estimator.set_keys(**tensor_keys)
|
|
498
|
+
|
|
499
|
+
@property
def in_keys(self):
    """Keys this loss reads from the input tensordict.

    Duplicates are removed while keeping first-occurrence order, so the result
    is deterministic across runs (unlike ``list(set(...))``).
    """
    keys = [
        self.tensor_keys.action,
        ("next", self.tensor_keys.reward),
        ("next", self.tensor_keys.done),
        ("next", self.tensor_keys.terminated),
        *self.actor_network.in_keys,
        *[("next", key) for key in self.actor_network.in_keys],
        *self.qvalue_network.in_keys,
    ]

    # dict.fromkeys de-duplicates like set() but preserves insertion order.
    return list(dict.fromkeys(keys))
|
|
512
|
+
|
|
513
|
+
@property
def out_keys(self):
    """Entries of the loss output; defaults are built lazily on first access.

    ``"loss_alpha_prime"`` is only part of the defaults when ``with_lagrange`` is set.
    """
    if self._out_keys is not None:
        return self._out_keys
    default_out_keys = [
        "loss_actor",
        "loss_actor_bc",
        "loss_qvalue",
        "loss_cql",
        "loss_alpha",
        "alpha",
        "entropy",
    ]
    if self.with_lagrange:
        default_out_keys.append("loss_alpha_prime")
    self._out_keys = default_out_keys
    return self._out_keys

@out_keys.setter
def out_keys(self, values):
    # Explicitly chosen output keys replace the lazily built defaults.
    self._out_keys = values
|
|
533
|
+
|
|
534
|
+
@dispatch
def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
    """Compute all CQL loss terms for a batch and write the TD-error priority.

    The call order matters: ``q_loss`` writes the ``pred_q1``/``pred_q2``
    entries that ``cql_loss`` reads, and ``cql_loss`` writes the
    ``cql_q1_loss``/``cql_q2_loss`` entries that ``alpha_prime_loss`` reads.

    Returns:
        A TensorDict with the keys listed in ``out_keys``.
    """
    q_loss, metadata = self.q_loss(tensordict)
    cql_loss, cql_metadata = self.cql_loss(tensordict)
    if self.with_lagrange:
        alpha_prime_loss, alpha_prime_metadata = self.alpha_prime_loss(tensordict)
        metadata.update(alpha_prime_metadata)
    loss_actor_bc, bc_metadata = self.actor_bc_loss(tensordict)
    loss_actor, actor_metadata = self.actor_loss(tensordict)
    # The temperature loss consumes the actor's log-prob metadata (a dict).
    loss_alpha, alpha_metadata = self.alpha_loss(actor_metadata)
    metadata.update(bc_metadata)
    metadata.update(cql_metadata)
    metadata.update(actor_metadata)
    metadata.update(alpha_metadata)
    # Priority = max over the stacked q-value nets (dim 0) of the squared TD error.
    tensordict.set(
        self.tensor_keys.priority, metadata.pop("td_error").detach().max(0).values
    )
    out = {
        "loss_actor": loss_actor,
        "loss_actor_bc": loss_actor_bc,
        "loss_qvalue": q_loss,
        "loss_cql": cql_loss,
        "loss_alpha": loss_alpha,
        "alpha": self._alpha,
        "entropy": -actor_metadata.get(self.tensor_keys.log_prob).mean().detach(),
    }
    if self.with_lagrange:
        out["loss_alpha_prime"] = alpha_prime_loss.mean()
    td_loss = TensorDict(out)
    self._clear_weakrefs(
        tensordict,
        td_loss,
        "actor_network_params",
        "qvalue_network_params",
        "target_actor_network_params",
        "target_qvalue_network_params",
    )
    return td_loss
|
|
572
|
+
|
|
573
|
+
@property
@_cache_values
def _cached_detach_qvalue_params(self):
    """Detached copy of the q-value params (cached via ``_cache_values``).

    Used by the actor loss so that no gradient reaches the q-value parameters.
    """
    live_params = self.qvalue_network_params
    return live_params.detach()
|
|
577
|
+
|
|
578
|
+
def actor_bc_loss(self, tensordict: TensorDictBase) -> tuple[Tensor, dict]:
    """Compute the behavior-cloning actor loss ``alpha * log_pi(a_new|s) - log_pi(a_data|s)``.

    ``a_new`` is a reparameterized sample from the current policy and
    ``a_data`` is the action stored under ``tensor_keys.action``.

    Returns:
        A ``(loss, metadata)`` tuple; ``metadata["bc_log_prob"]`` is the
        detached mean log-likelihood of the dataset actions under the policy.
    """
    with set_exploration_type(
        ExplorationType.RANDOM
    ), self.actor_network_params.to_module(self.actor_network):
        dist = self.actor_network.get_dist(
            tensordict,
        )
        a_reparm = dist.rsample()
        log_prob = dist.log_prob(a_reparm)
        bc_log_prob = dist.log_prob(tensordict.get(self.tensor_keys.action))

    bc_actor_loss = self._alpha * log_prob - bc_log_prob
    bc_actor_loss = _reduce(bc_actor_loss, reduction=self.reduction)
    metadata = {"bc_log_prob": bc_log_prob.mean().detach()}
    self._clear_weakrefs(
        tensordict,
        "actor_network_params",
        "qvalue_network_params",
        "target_actor_network_params",
        "target_qvalue_network_params",
    )
    return bc_actor_loss, metadata
|
|
600
|
+
|
|
601
|
+
def actor_loss(self, tensordict: TensorDictBase) -> tuple[Tensor, dict]:
    """Compute the SAC-style actor loss ``alpha * log_pi(a|s) - min_i Q_i(s, a)``.

    ``a`` is a reparameterized sample from the current policy. The Q-values
    are computed with detached q-value parameters, so this loss produces no
    gradient for the q-value networks.

    Returns:
        A ``(loss, metadata)`` tuple; ``metadata[tensor_keys.log_prob]`` is the
        detached log-probability of the sampled actions.

    Raises:
        RuntimeError: if selecting the q-value inputs returns the input
            tensordict itself, or if log-prob and Q-value shapes disagree.
    """
    with set_exploration_type(
        ExplorationType.RANDOM
    ), self.actor_network_params.to_module(self.actor_network):
        dist = self.actor_network.get_dist(
            tensordict,
        )
        a_reparm = dist.rsample()
        log_prob = dist.log_prob(a_reparm)

    td_q = tensordict.select(*self.qvalue_network.in_keys, strict=False)
    if td_q is tensordict:
        # The sampled action is written into td_q below; it must be a distinct
        # object so the stored dataset action is not overwritten.
        raise RuntimeError(
            "Selecting the q-value network in_keys returned the input tensordict "
            "itself; writing the sampled action would overwrite the stored action."
        )
    td_q.set(self.tensor_keys.action, a_reparm)
    td_q = self._vmap_qvalue_networkN0(
        td_q,
        self._cached_detach_qvalue_params,
    )
    # Clipped double-Q: minimum over the stacked q-value nets (dim 0).
    min_q_logprob = (
        td_q.get(self.tensor_keys.state_action_value).min(0)[0].squeeze(-1)
    )

    if log_prob.shape != min_q_logprob.shape:
        raise RuntimeError(
            f"Losses shape mismatch: {log_prob.shape} and {min_q_logprob.shape}"
        )

    metadata = {}
    metadata[self.tensor_keys.log_prob] = log_prob.detach()
    actor_loss = self._alpha * log_prob - min_q_logprob
    actor_loss = _reduce(actor_loss, reduction=self.reduction)
    self._clear_weakrefs(
        tensordict,
        "actor_network_params",
        "qvalue_network_params",
        "target_actor_network_params",
        "target_qvalue_network_params",
    )
    return actor_loss, metadata
|
|
640
|
+
|
|
641
|
+
def _get_policy_actions(self, data, actor_params, num_actions=10):
    """Sample ``num_actions`` policy actions for every entry of ``data``.

    The actor inputs are repeated ``num_actions`` times along the last batch
    dimension (``repeat_interleave``). Actions are then sampled with the
    ``.data`` of ``actor_params`` swapped into the actor.

    Returns:
        A ``(tensordict, log_prob)`` tuple. The tensordict holds the repeated
        actor inputs plus the sampled actions; ``log_prob`` holds the
        log-probability of those actions.
    """
    original_batch = data.batch_size
    expanded_batch = list(original_batch[:-1]) + [original_batch[-1] * num_actions]
    actor_in_keys = [unravel_key(key) for key in self.actor_network.in_keys]
    repeat_dim = data.ndim - 1

    def _repeat_actor_inputs(name, value):
        # Non-actor entries yield None and are left out of the result.
        if name not in actor_in_keys:
            return None
        return value.repeat_interleave(num_actions, dim=repeat_dim)

    repeated = data.named_apply(
        _repeat_actor_inputs, batch_size=expanded_batch, filter_empty=True
    )
    with set_exploration_type(ExplorationType.RANDOM), actor_params.data.to_module(
        self.actor_network
    ):
        dist = self.actor_network.get_dist(repeated)
        sampled = dist.rsample()
        repeated.set(self.tensor_keys.action, sampled)
        sampled_log_prob = dist.log_prob(sampled)

    selected = repeated.select(
        *self.actor_network.in_keys, self.tensor_keys.action, strict=False
    )
    return selected, sampled_log_prob
|
|
667
|
+
|
|
668
|
+
def _get_value_v(self, tensordict, _alpha, actor_params, qval_params):
    """Compute the bootstrapped target value for the q-value loss.

    The next-state value is either the minimum over the q-value nets at a
    fresh policy action (optionally entropy-regularized), or, with
    ``max_q_backup``, the minimum over nets of the maximum over
    ``num_random`` sampled actions. It is written to
    ``("next", value_estimator.tensor_keys.value)``, and the value estimator
    turns it into the target.

    Args:
        tensordict: batch containing a ``"next"`` sub-tensordict.
        _alpha: entropy coefficient used when ``deterministic_backup`` is False.
        actor_params: actor parameters; their ``.data`` is used for sampling.
        qval_params: q-value parameters; their ``.data`` is used for evaluation.

    Returns:
        The target value tensor with its last dimension squeezed.
    """
    tensordict = tensordict.clone(False)
    # get actions and log-probs
    # TODO: wait for compile to handle this properly
    # Manually swap the actor params in; they are swapped back below.
    actor_data = actor_params.data.to_module(self.actor_network)
    with set_exploration_type(ExplorationType.RANDOM):
        next_tensordict = tensordict.get("next").clone(False)
        next_dist = self.actor_network.get_dist(next_tensordict)
        next_action = next_dist.rsample()
        next_tensordict.set(self.tensor_keys.action, next_action)
        next_sample_log_prob = next_dist.log_prob(next_action)
    # Restore the parameters that were in the actor before the swap.
    actor_data.to_module(self.actor_network, return_swap=False)

    # get q-values
    if not self.max_q_backup:
        next_tensordict_expand = self._vmap_qvalue_networkN0(
            next_tensordict, qval_params.data
        )
        # Clipped double-Q: minimum over the stacked q-value nets (dim 0).
        next_state_value = next_tensordict_expand.get(
            self.tensor_keys.state_action_value
        ).min(0)[0]
        # Align the log-prob with the value's trailing dims before subtracting.
        if (
            next_state_value.shape[-len(next_sample_log_prob.shape) :]
            != next_sample_log_prob.shape
        ):
            next_sample_log_prob = next_sample_log_prob.unsqueeze(-1)
        if not self.deterministic_backup:
            next_state_value = next_state_value - _alpha * next_sample_log_prob

    if self.max_q_backup:
        next_tensordict, _ = self._get_policy_actions(
            tensordict.get("next").copy(),
            actor_params,
            num_actions=self.num_random,
        )
        next_tensordict_expand = self._vmap_qvalue_networkN0(
            next_tensordict, qval_params.data
        )

        state_action_value = next_tensordict_expand.get(
            self.tensor_keys.state_action_value
        )
        # take max over actions
        # (the num_random samples were repeat_interleave'd along the last batch dim)
        state_action_value = state_action_value.reshape(
            torch.Size(
                [self.num_qvalue_nets, *tensordict.shape, self.num_random, -1]
            )
        ).max(-2)[0]
        # take min over qvalue nets
        next_state_value = state_action_value.min(0)[0]

    tensordict.set(
        ("next", self.value_estimator.tensor_keys.value), next_state_value
    )
    target_value = self.value_estimator.value_estimate(tensordict).squeeze(-1)
    return target_value
|
|
724
|
+
|
|
725
|
+
def q_loss(self, tensordict: TensorDictBase) -> tuple[Tensor, dict]:
    """Compute the TD (Bellman) loss summed over both q-value networks.

    Side effect: the un-squeezed predictions of the two networks are written
    to ``tensor_keys.pred_q1`` / ``tensor_keys.pred_q2`` in ``tensordict``
    for :meth:`cql_loss`.

    Returns:
        A ``(loss, metadata)`` tuple; ``metadata["td_error"]`` is the detached
        squared TD error of each q-value network.
    """
    # The current alpha is passed explicitly to the target computation.
    target_value = self._get_value_v(
        tensordict.copy(),
        self._alpha,
        self.actor_network_params,
        self.target_qvalue_network_params,
    )

    q_inputs = tensordict.select(*self.qvalue_network.in_keys, strict=False)
    q_outputs = self._vmap_qvalue_networkN0(q_inputs, self.qvalue_network_params)
    stacked_q = q_outputs.get(self.tensor_keys.state_action_value)

    # Stash the per-network predictions for the CQL penalty.
    tensordict.set(self.tensor_keys.pred_q1, stacked_q[0])
    tensordict.set(self.tensor_keys.pred_q2, stacked_q[1])

    stacked_q = stacked_q.squeeze(-1)
    per_network_loss = distance_loss(
        stacked_q,
        target_value.expand_as(stacked_q),
        loss_function=self.loss_function,
    )
    loss_qval = _reduce(per_network_loss.sum(0), reduction=self.reduction)
    squared_td_error = (stacked_q - target_value).pow(2)
    metadata = {"td_error": squared_td_error.detach()}
    self._clear_weakrefs(
        tensordict,
        "actor_network_params",
        "qvalue_network_params",
        "target_actor_network_params",
        "target_qvalue_network_params",
    )
    return loss_qval, metadata
|
|
762
|
+
|
|
763
|
+
def cql_loss(self, tensordict: TensorDictBase) -> tuple[Tensor, dict]:
    """Compute the conservative (CQL) penalty for both q-value networks.

    Q-values are evaluated for uniformly random actions, for policy actions
    at the current states and for policy actions at the next states, all in
    one vmapped call. Their importance-weighted logsumexp, minus the
    predictions stored by :meth:`q_loss`, forms the penalty. The per-network
    penalties are written to ``tensor_keys.cql_q1_loss`` /
    ``tensor_keys.cql_q2_loss`` for :meth:`alpha_prime_loss`.

    Raises:
        KeyError: if :meth:`q_loss` has not stored ``pred_q1``/``pred_q2`` yet.
    """
    pred_q1 = tensordict.get(self.tensor_keys.pred_q1)
    pred_q2 = tensordict.get(self.tensor_keys.pred_q2)

    if pred_q1 is None:
        raise KeyError(
            f"Couldn't find the pred_q1 with key {self.tensor_keys.pred_q1} in the input tensordict. "
            "This could be caused by calling cql_loss method before q_loss method."
        )
    if pred_q2 is None:
        raise KeyError(
            f"Couldn't find the pred_q2 with key {self.tensor_keys.pred_q2} in the input tensordict. "
            "This could be caused by calling cql_loss method before q_loss method."
        )

    # Uniform random actions in [-1, 1]; together with random_density below this
    # assumes a [-1, 1] box action space. NOTE(review): confirm action bounds.
    random_actions_tensor = pred_q1.new_empty(
        (
            *tensordict.shape[:-1],
            tensordict.shape[-1] * self.num_random,
            tensordict[self.tensor_keys.action].shape[-1],
        )
    ).uniform_(-1, 1)
    curr_actions_td, curr_log_pis = self._get_policy_actions(
        tensordict.copy(),
        self.actor_network_params,
        num_actions=self.num_random,
    )
    new_curr_actions_td, new_log_pis = self._get_policy_actions(
        tensordict.get("next").copy(),
        self.actor_network_params,
        num_actions=self.num_random,
    )

    # process all in one forward pass
    # stack qvalue params (one copy per action source: random, current, next)
    qvalue_params = torch.cat(
        [
            self.qvalue_network_params,
            self.qvalue_network_params,
            self.qvalue_network_params,
        ],
        0,
    )
    # select and stack input params
    # q value random action
    tensordict_q_random = tensordict.select(
        *self.actor_network.in_keys, strict=False
    )

    batch_size = tensordict_q_random.batch_size
    batch_size = list(batch_size[:-1]) + [batch_size[-1] * self.num_random]
    in_keys = [unravel_key(key) for key in self.actor_network.in_keys]

    def filter_and_repeat(name, x):
        # Repeat actor inputs num_random times along the last batch dim; other
        # entries return None and are not kept.
        if name in in_keys:
            return x.repeat_interleave(
                self.num_random, dim=tensordict_q_random.ndim - 1
            )

    tensordict_q_random = tensordict_q_random.named_apply(
        filter_and_repeat,
        batch_size=batch_size,
        filter_empty=True,
    )
    tensordict_q_random.set(self.tensor_keys.action, random_actions_tensor)
    # Leading dim: [random x nets, current x nets, next x nets], matching qvalue_params.
    cql_tensordict = torch.cat(
        [
            tensordict_q_random.expand(
                self.num_qvalue_nets, *curr_actions_td.batch_size
            ),
            curr_actions_td.expand(
                self.num_qvalue_nets, *curr_actions_td.batch_size
            ),
            new_curr_actions_td.expand(
                self.num_qvalue_nets, *curr_actions_td.batch_size
            ),
        ],
        0,
    )
    cql_tensordict = cql_tensordict.contiguous()

    cql_tensordict_expand = self._vmap_qvalue_network00(
        cql_tensordict, qvalue_params
    )
    # get q values
    state_action_value = cql_tensordict_expand.get(
        self.tensor_keys.state_action_value
    )
    # split q values
    (q_random, q_curr, q_new,) = state_action_value.split(
        [
            self.num_qvalue_nets,
            self.num_qvalue_nets,
            self.num_qvalue_nets,
        ],
        dim=0,
    )

    # importance sampled version
    # log-density of the uniform distribution on [-1, 1]^action_dim
    random_density = np.log(
        0.5 ** curr_actions_td[self.tensor_keys.action].shape[-1]
    )
    # NOTE(review): each concatenated term appears to keep a trailing dim of 1
    # (cf. ``.squeeze(-1)`` in q_loss), so the logsumexp below would run over 3
    # values per sampled row rather than over all 3 * num_random sampled actions
    # of a state as in the CQL paper — confirm this is intended.
    cat_q1 = torch.cat(
        [
            q_random[0] - random_density,
            q_new[0] - new_log_pis.detach().unsqueeze(-1),
            q_curr[0] - curr_log_pis.detach().unsqueeze(-1),
        ],
        -1,
    )
    cat_q2 = torch.cat(
        [
            q_random[1] - random_density,
            q_new[1] - new_log_pis.detach().unsqueeze(-1),
            q_curr[1] - curr_log_pis.detach().unsqueeze(-1),
        ],
        -1,
    )

    min_qf1_loss = (
        torch.logsumexp(cat_q1 / self.temperature, dim=-1)
        * self.min_q_weight
        * self.temperature
    )
    min_qf2_loss = (
        torch.logsumexp(cat_q2 / self.temperature, dim=-1)
        * self.min_q_weight
        * self.temperature
    )

    # Subtract the log likelihood of data
    cql_q1_loss = min_qf1_loss.flatten() - pred_q1 * self.min_q_weight
    cql_q2_loss = min_qf2_loss.flatten() - pred_q2 * self.min_q_weight

    # write cql losses in tensordict for alpha prime loss
    tensordict.set(self.tensor_keys.cql_q1_loss, cql_q1_loss)
    tensordict.set(self.tensor_keys.cql_q2_loss, cql_q2_loss)

    cql_q_loss = (cql_q1_loss + cql_q2_loss).mean(-1)
    cql_q_loss = _reduce(cql_q_loss, reduction=self.reduction)

    self._clear_weakrefs(
        tensordict,
        "actor_network_params",
        "qvalue_network_params",
        "target_actor_network_params",
        "target_qvalue_network_params",
    )
    return cql_q_loss, {}
|
|
912
|
+
|
|
913
|
+
def alpha_prime_loss(self, tensordict: TensorDictBase) -> tuple[Tensor, dict]:
    """Compute the loss of the CQL Lagrange multiplier ``alpha_prime``.

    Reads the per-network penalties written by :meth:`cql_loss` and returns
    ``-0.5 * alpha_prime * ((mean(cql_q1) - gap) + (mean(cql_q2) - gap))``,
    where ``alpha_prime = min(exp(log_alpha_prime), 1e6)`` and ``gap`` is
    ``lagrange_thresh``. Only usable with ``with_lagrange=True``, because
    ``__init__`` sets ``log_alpha_prime`` and ``target_action_gap`` only then.

    Returns:
        A ``(loss, {})`` tuple.

    Raises:
        KeyError: if :meth:`cql_loss` has not written the penalty entries yet.
    """
    cql_q1_loss = tensordict.get(self.tensor_keys.cql_q1_loss)
    cql_q2_loss = tensordict.get(self.tensor_keys.cql_q2_loss)

    if cql_q1_loss is None:
        raise KeyError(
            f"Couldn't find the cql_q1_loss with key {self.tensor_keys.cql_q1_loss} in the input tensordict. "
            "This could be caused by calling alpha_prime_loss method before cql_loss method."
        )
    if cql_q2_loss is None:
        raise KeyError(
            f"Couldn't find the cql_q2_loss with key {self.tensor_keys.cql_q2_loss} in the input tensordict. "
            "This could be caused by calling alpha_prime_loss method before cql_loss method."
        )

    alpha_prime = torch.clamp_max(self.log_alpha_prime.exp(), max=1000000.0)
    min_qf1_loss = alpha_prime * (cql_q1_loss.mean() - self.target_action_gap)
    min_qf2_loss = alpha_prime * (cql_q2_loss.mean() - self.target_action_gap)

    alpha_prime_loss = (-min_qf1_loss - min_qf2_loss) * 0.5
    alpha_prime_loss = _reduce(alpha_prime_loss, reduction=self.reduction)
    self._clear_weakrefs(
        tensordict,
        "actor_network_params",
        "qvalue_network_params",
        "target_actor_network_params",
        "target_qvalue_network_params",
    )
    return alpha_prime_loss, {}
|
|
942
|
+
|
|
943
|
+
def alpha_loss(self, tensordict: TensorDictBase | dict) -> tuple[Tensor, dict]:
    """Compute the entropy-coefficient loss ``-log_alpha * (log_pi + target_entropy)``.

    Args:
        tensordict: mapping that holds the policy log-probabilities under
            ``tensor_keys.log_prob``. ``forward`` passes the actor-loss
            metadata dict here.

    Returns:
        A ``(loss, {})`` tuple. If ``target_entropy`` is None, the loss is a
        zero placeholder.
    """
    log_pi = tensordict.get(self.tensor_keys.log_prob)
    target_entropy = self.target_entropy
    if target_entropy is not None:
        # we can compute this loss even if log_alpha is not a parameter
        alpha_loss = -self.log_alpha * (log_pi.detach() + target_entropy)
    else:
        # placeholder
        alpha_loss = torch.zeros_like(log_pi)
    alpha_loss = _reduce(alpha_loss, reduction=self.reduction)
    self._clear_weakrefs(
        tensordict,
        "actor_network_params",
        "qvalue_network_params",
        "target_actor_network_params",
        "target_qvalue_network_params",
    )
    return alpha_loss, {}
|
|
960
|
+
|
|
961
|
+
@property
def _alpha(self):
    """Current entropy coefficient ``exp(log_alpha)`` as a detached tensor.

    If a lower and/or upper bound was configured, ``log_alpha`` is first
    clamped in place to ``[min_log_alpha, max_log_alpha]``.
    """
    has_bounds = not (self.min_log_alpha is None and self.max_log_alpha is None)
    if has_bounds:
        self.log_alpha.data.clamp_(self.min_log_alpha, self.max_log_alpha)
    current_log_alpha = self.log_alpha.data
    return current_log_alpha.exp()
|
|
967
|
+
|
|
968
|
+
|
|
969
|
+
class DiscreteCQLLoss(LossModule):
|
|
970
|
+
"""TorchRL implementation of the discrete CQL loss.
|
|
971
|
+
|
|
972
|
+
This class implements the discrete conservative Q-learning (CQL) loss function, as presented in the paper
|
|
973
|
+
"Conservative Q-Learning for Offline Reinforcement Learning" (https://arxiv.org/abs/2006.04779).
|
|
974
|
+
|
|
975
|
+
Args:
|
|
976
|
+
value_network (Union[QValueActor, nn.Module]): The Q-value network used to estimate state-action values.
|
|
977
|
+
Keyword Args:
|
|
978
|
+
loss_function (Optional[str]): The distance function used to calculate the distance between the predicted
|
|
979
|
+
Q-values and the target Q-values. Defaults to ``l2``.
|
|
980
|
+
delay_value (bool): Whether to separate the target Q value
|
|
981
|
+
networks from the Q value networks used for data collection.
|
|
982
|
+
Default is ``True``.
|
|
983
|
+
gamma (:obj:`float`, optional): Discount factor. Default is ``None``.
|
|
984
|
+
action_space: The action space of the environment. If None, it is inferred from the value network.
|
|
985
|
+
Defaults to None.
|
|
986
|
+
reduction (str, optional): Specifies the reduction to apply to the output:
|
|
987
|
+
``"none"`` | ``"mean"`` | ``"sum"``. ``"none"``: no reduction will be applied,
|
|
988
|
+
``"mean"``: the sum of the output will be divided by the number of
|
|
989
|
+
elements in the output, ``"sum"``: the output will be summed. Default: ``"mean"``.
|
|
990
|
+
|
|
991
|
+
Examples:
|
|
992
|
+
>>> from torchrl.modules import MLP, QValueActor
|
|
993
|
+
>>> from torchrl.data import OneHot
|
|
994
|
+
>>> from torchrl.objectives import DiscreteCQLLoss
|
|
995
|
+
>>> n_obs, n_act = 4, 3
|
|
996
|
+
>>> value_net = MLP(in_features=n_obs, out_features=n_act)
|
|
997
|
+
>>> spec = OneHot(n_act)
|
|
998
|
+
>>> actor = QValueActor(value_net, in_keys=["observation"], action_space=spec)
|
|
999
|
+
>>> loss = DiscreteCQLLoss(actor, action_space=spec)
|
|
1000
|
+
>>> batch = [10,]
|
|
1001
|
+
>>> data = TensorDict({
|
|
1002
|
+
... "observation": torch.randn(*batch, n_obs),
|
|
1003
|
+
... "action": spec.rand(batch),
|
|
1004
|
+
... ("next", "observation"): torch.randn(*batch, n_obs),
|
|
1005
|
+
... ("next", "done"): torch.zeros(*batch, 1, dtype=torch.bool),
|
|
1006
|
+
... ("next", "terminated"): torch.zeros(*batch, 1, dtype=torch.bool),
|
|
1007
|
+
... ("next", "reward"): torch.randn(*batch, 1)
|
|
1008
|
+
... }, batch)
|
|
1009
|
+
>>> loss(data)
|
|
1010
|
+
TensorDict(
|
|
1011
|
+
fields={
|
|
1012
|
+
loss_cql: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
|
|
1013
|
+
loss_qvalue: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
|
|
1014
|
+
pred_value: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
|
|
1015
|
+
target_value: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
|
|
1016
|
+
td_error: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False)},
|
|
1017
|
+
batch_size=torch.Size([]),
|
|
1018
|
+
device=None,
|
|
1019
|
+
is_shared=False)
|
|
1020
|
+
|
|
1021
|
+
This class is compatible with non-tensordict based modules too and can be
|
|
1022
|
+
used without recurring to any tensordict-related primitive. In this case,
|
|
1023
|
+
the expected keyword arguments are:
|
|
1024
|
+
``["observation", "next_observation", "action", "next_reward", "next_done", "next_terminated"]``,
|
|
1025
|
+
and a single loss value is returned.
|
|
1026
|
+
|
|
1027
|
+
Examples:
|
|
1028
|
+
>>> from torchrl.objectives import DiscreteCQLLoss
|
|
1029
|
+
>>> from torchrl.data import OneHot
|
|
1030
|
+
>>> from torch import nn
|
|
1031
|
+
>>> import torch
|
|
1032
|
+
>>> n_obs = 3
|
|
1033
|
+
>>> n_action = 4
|
|
1034
|
+
>>> action_spec = OneHot(n_action)
|
|
1035
|
+
>>> value_network = nn.Linear(n_obs, n_action) # a simple value model
|
|
1036
|
+
>>> dcql_loss = DiscreteCQLLoss(value_network, action_space=action_spec)
|
|
1037
|
+
>>> # define data
|
|
1038
|
+
>>> observation = torch.randn(n_obs)
|
|
1039
|
+
>>> next_observation = torch.randn(n_obs)
|
|
1040
|
+
>>> action = action_spec.rand()
|
|
1041
|
+
>>> next_reward = torch.randn(1)
|
|
1042
|
+
>>> next_done = torch.zeros(1, dtype=torch.bool)
|
|
1043
|
+
>>> next_terminated = torch.zeros(1, dtype=torch.bool)
|
|
1044
|
+
>>> loss_val = dcql_loss(
|
|
1045
|
+
... observation=observation,
|
|
1046
|
+
... next_observation=next_observation,
|
|
1047
|
+
... next_reward=next_reward,
|
|
1048
|
+
... next_done=next_done,
|
|
1049
|
+
... next_terminated=next_terminated,
|
|
1050
|
+
... action=action)
|
|
1051
|
+
"""
|
|
1052
|
+
|
|
1053
|
+
@dataclass
class _AcceptedKeys:
    """Maintains default values for all configurable tensordict keys.

    This class defines which tensordict keys can be set using '.set_keys(key_name=key_value)' and their
    default values.

    Attributes:
        value_target (NestedKey): The input tensordict key where the target state value is expected.
            Will be used for the underlying value estimator. Defaults to ``"value_target"``.
        value (NestedKey): The input tensordict key where the chosen action value is expected.
            Will be used for the underlying value estimator. Defaults to ``"chosen_action_value"``.
        action_value (NestedKey): The input tensordict key where the action value is expected.
            Defaults to ``"action_value"``.
        action (NestedKey): The input tensordict key where the action is expected.
            Defaults to ``"action"``.
        priority (NestedKey): The input tensordict key where the target priority is written to.
            Defaults to ``"td_error"``.
        reward (NestedKey): The input tensordict key where the reward is expected.
            Will be used for the underlying value estimator. Defaults to ``"reward"``.
        done (NestedKey): The key in the input TensorDict that indicates
            whether a trajectory is done. Will be used for the underlying value estimator.
            Defaults to ``"done"``.
        terminated (NestedKey): The key in the input TensorDict that indicates
            whether a trajectory is terminated. Will be used for the underlying value estimator.
            Defaults to ``"terminated"``.
        pred_val (NestedKey): The key where the predicted value will be written
            in the input tensordict. This value is subsequently used by cql_loss.
            Defaults to ``"pred_val"``.

    """

    # The field order below is the order of the generated dataclass
    # ``__init__`` parameters; keep it stable.
    value_target: NestedKey = "value_target"
    value: NestedKey = "chosen_action_value"
    action_value: NestedKey = "action_value"
    action: NestedKey = "action"
    priority: NestedKey = "td_error"
    reward: NestedKey = "reward"
    done: NestedKey = "done"
    terminated: NestedKey = "terminated"
    pred_val: NestedKey = "pred_val"
|
|
1094
|
+
|
|
1095
|
+
# Container of the configurable key names; presumably populated from
# ``default_keys`` by the base class -- TODO confirm.
tensor_keys: _AcceptedKeys
default_keys = _AcceptedKeys
# Used by ``make_value_estimator`` when no ``value_type`` is given.
default_value_estimator = ValueEstimators.TD0
# Loss entries produced by ``forward`` (its output additionally carries the
# metadata entries returned by ``value_loss``).
out_keys = [
    "loss_qvalue",
    "loss_cql",
]

# Network and parameter attributes; presumably set up by
# ``convert_to_functional`` in ``__init__`` -- TODO confirm.
value_network: TensorDictModule
value_network_params: TensorDictParams
target_value_network_params: TensorDictParams
|
|
1106
|
+
|
|
1107
|
+
def __init__(
    self,
    value_network: QValueActor | nn.Module,
    *,
    loss_function: str | None = "l2",
    delay_value: bool = True,
    gamma: float | None = None,
    action_space=None,
    reduction: str | None = None,
) -> None:
    # Build the discrete CQL loss; arguments are described in the class docstring.
    self._in_keys = None
    if reduction is None:
        reduction = "mean"
    super().__init__()
    self.delay_value = delay_value
    # Presumably wraps a plain nn.Module into a QValueActor so that the loss
    # can work with tensordicts.
    value_network = ensure_tensordict_compatible(
        module=value_network,
        wrapper_type=QValueActor,
        action_space=action_space,
    )

    # Register the value network and its parameters; target parameters are
    # requested only when ``delay_value`` is True.
    self.convert_to_functional(
        value_network,
        "value_network",
        create_target_params=self.delay_value,
    )

    self.value_network_in_keys = value_network.in_keys

    self.loss_function = loss_function
    if action_space is None:
        # Infer the action space from the value network: try its ``spec``
        # first, then its ``action_space`` attribute. Each lookup is guarded
        # on its own so a network exposing only one of them still works.
        try:
            action_space = value_network.spec
        except AttributeError:
            try:
                action_space = value_network.action_space
            except AttributeError as err:
                raise ValueError(self.ACTION_SPEC_ERROR) from err
    if action_space is None:
        warnings.warn(
            "action_space was not specified. DiscreteCQLLoss will default to 'one-hot'. "
            "This behavior will be deprecated soon and a space will have to be passed. "
            "Check the DiscreteCQLLoss documentation to see how to pass the action space."
        )
        action_space = "one-hot"
    # Normalized action-space identifier; the losses compare it with "categorical".
    self.action_space = _find_action_space(action_space)
    self.reduction = reduction

    # ``gamma`` is no longer accepted by the constructor.
    if gamma is not None:
        raise TypeError(_GAMMA_LMBDA_DEPREC_ERROR)
|
|
1159
|
+
|
|
1160
|
+
def _forward_value_estimator_keys(self, **kwargs) -> None:
|
|
1161
|
+
if self._value_estimator is not None:
|
|
1162
|
+
self._value_estimator.set_keys(
|
|
1163
|
+
value_target=self.tensor_keys.value_target,
|
|
1164
|
+
value=self._tensor_keys.value,
|
|
1165
|
+
reward=self._tensor_keys.reward,
|
|
1166
|
+
done=self._tensor_keys.done,
|
|
1167
|
+
terminated=self._tensor_keys.terminated,
|
|
1168
|
+
)
|
|
1169
|
+
self._set_in_keys()
|
|
1170
|
+
|
|
1171
|
+
def _set_in_keys(self):
    """Rebuild ``self._in_keys``: the sorted, de-duplicated keys this loss reads.

    They are the action key, the ``"next"`` reward/done/terminated keys, and
    every value-network input key both as-is and under ``"next"``.
    """
    keys = self.tensor_keys
    network_keys = self.value_network.in_keys
    collected = {keys.action}
    collected.update(
        unravel_key(("next", leaf))
        for leaf in (keys.reward, keys.done, keys.terminated)
    )
    collected.update(network_keys)
    collected.update(unravel_key(("next", key)) for key in network_keys)
    # Plain-string and tuple keys cannot be compared directly, so sort on str().
    self._in_keys = sorted(collected, key=str)
|
|
1181
|
+
|
|
1182
|
+
def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams):
    """Create the value estimator used to compute the TD targets.

    Args:
        value_type: a ``ValueEstimators`` member, or a ``ValueEstimatorBase``
            instance/subclass (delegated to ``LossModule.make_value_estimator``).
            Defaults to ``self.default_value_estimator``.
        **hyperparams: overrides applied on top of ``default_value_kwargs(value_type)``.

    Raises:
        NotImplementedError: for ``ValueEstimators.GAE`` or an unknown value type.
    """
    if value_type is None:
        value_type = self.default_value_estimator

    # Handle ValueEstimatorBase instance or class
    if isinstance(value_type, ValueEstimatorBase) or (
        isinstance(value_type, type) and issubclass(value_type, ValueEstimatorBase)
    ):
        return LossModule.make_value_estimator(self, value_type, **hyperparams)

    self.value_type = value_type

    # we will take care of computing the next value inside this module:
    # the estimator gets its own copy of the network, loaded with our params.
    value_net = deepcopy(self.value_network)
    self.value_network_params.to_module(value_net, return_swap=False)

    hp = dict(default_value_kwargs(value_type))
    hp.update(hyperparams)
    if value_type is ValueEstimators.TD1:
        self._value_estimator = TD1Estimator(**hp, value_network=value_net)
    elif value_type is ValueEstimators.TD0:
        self._value_estimator = TD0Estimator(**hp, value_network=value_net)
    elif value_type is ValueEstimators.GAE:
        raise NotImplementedError(
            f"Value type {value_type} is not implemented for loss {type(self)}."
        )
    elif value_type is ValueEstimators.TDLambda:
        self._value_estimator = TDLambdaEstimator(**hp, value_network=value_net)
    else:
        raise NotImplementedError(f"Unknown value type {value_type}")

    # Use the configured value_target key rather than the literal default, so
    # the estimator agrees with ``_forward_value_estimator_keys`` regardless of
    # whether ``set_keys`` was called before or after this method.
    self._value_estimator.set_keys(
        value_target=self.tensor_keys.value_target,
        value=self.tensor_keys.value,
        reward=self.tensor_keys.reward,
        done=self.tensor_keys.done,
        terminated=self.tensor_keys.terminated,
    )
|
|
1230
|
+
|
|
1231
|
+
@property
def in_keys(self):
    """Input keys of the loss, built lazily by ``_set_in_keys`` while unset (``None``)."""
    cached = self._in_keys
    if cached is None:
        self._set_in_keys()
        cached = self._in_keys
    return cached
|
|
1236
|
+
|
|
1237
|
+
@in_keys.setter
def in_keys(self, values):
    """Override the stored input keys; assigning ``None`` makes the getter rebuild them."""
    self._in_keys = values
|
|
1240
|
+
|
|
1241
|
+
@dispatch
def value_loss(
    self,
    tensordict: TensorDictBase,
) -> tuple[torch.Tensor, dict]:
    """Compute the TD (Q-value regression) term of the discrete CQL loss.

    Side effects on ``tensordict`` (both written with ``inplace=True``):

    * the detached squared TD error under ``tensor_keys.priority``, usable as
      a replay-buffer priority;
    * the predicted action values under ``tensor_keys.pred_val``, which
      :meth:`cql_loss` reads back, so this method must run first.

    Returns:
        ``(loss, metadata)`` where ``loss`` is reduced per ``self.reduction``
        and ``metadata`` holds detached ``td_error``, ``pred_value`` and
        ``target_value`` summaries.
    """
    # The value network writes its outputs into this non-recursive clone.
    td_copy = tensordict.clone(False)
    with self.value_network_params.to_module(self.value_network):
        self.value_network(td_copy)

    action = tensordict.get(self.tensor_keys.action)
    pred_val = td_copy.get(self.tensor_keys.action_value)

    if self.action_space == "categorical":
        if action.shape != pred_val.shape:
            # unsqueeze the action if it lacks its trailing singleton dim
            action = action.unsqueeze(-1)
        pred_val_index = torch.gather(pred_val, -1, index=action).squeeze(-1)
    else:
        # mask-style actions: select the taken action's value by weighting
        action = action.to(torch.float)
        pred_val_index = (pred_val * action).sum(-1)

    # calculate target value with the detached target parameters
    target_value = self.value_estimator.value_estimate(
        td_copy, params=self._cached_detached_target_value_params
    ).squeeze(-1)

    # The squared TD error only serves as a replay priority and for logging;
    # detach it so the tensor stored in ``tensordict`` does not keep this
    # step's autograd graph alive.
    td_error = (pred_val_index - target_value).detach().pow(2).unsqueeze(-1)

    tensordict.set(
        self.tensor_keys.priority,
        td_error,
        inplace=True,
    )
    # Not detached: ``cql_loss`` builds its loss from these values.
    tensordict.set(
        self.tensor_keys.pred_val,
        pred_val,
        inplace=True,
    )
    loss = 0.5 * distance_loss(pred_val_index, target_value, self.loss_function)
    loss = _reduce(loss, reduction=self.reduction)

    metadata = {
        "td_error": td_error.mean(0),
        "pred_value": pred_val.mean().detach(),
        "target_value": target_value.mean().detach(),
    }

    return loss, metadata
|
|
1290
|
+
|
|
1291
|
+
@dispatch
def forward(self, tensordict: TensorDictBase) -> TensorDict:
    """Computes the discrete CQL loss for a tensordict sampled from a replay buffer.

    As a side effect, the squared TD error is written into ``tensordict`` under
    the priority key (``"td_error"`` by default) so that prioritized replay
    buffers can update the priorities of the sampled items.

    Args:
        tensordict (TensorDictBase): a tensordict with the ``"action"`` key and
            the in_keys of the value network (observations, plus ``"done"``,
            ``"terminated"`` and ``"reward"`` in a ``"next"`` tensordict).

    Returns:
        A TensorDict with an empty batch size holding ``"loss_qvalue"`` (TD
        regression term), ``"loss_cql"`` (conservative regularizer) and the
        detached ``"td_error"``, ``"pred_value"`` and ``"target_value"``
        entries reported by :meth:`value_loss`.
    """
    # value_loss must run first: it stores the predicted action values that
    # cql_loss reads back from ``tensordict``.
    qvalue_loss, value_metadata = self.value_loss(tensordict)
    conservative_loss, _ = self.cql_loss(tensordict)
    return TensorDict(
        source={
            "loss_qvalue": qvalue_loss,
            "loss_cql": conservative_loss,
            **value_metadata,
        },
        batch_size=[],
    )
|
|
1319
|
+
|
|
1320
|
+
@property
@_cache_values
def _cached_detached_target_value_params(self):
    """Target value-network parameters, detached from autograd.

    Wrapped by ``_cache_values`` (presumably memoizing the result -- TODO
    confirm). ``value_loss`` passes it as ``params`` to the target value estimate.
    """
    target_params = self.target_value_network_params
    detached = target_params.detach()
    return detached
|
|
1324
|
+
|
|
1325
|
+
def cql_loss(self, tensordict):
    """Compute the conservative regularizer ``logsumexp_a' Q(s, a') - Q(s, a)``.

    Reads the predicted action values that :meth:`value_loss` stored under
    ``tensor_keys.pred_val`` and the taken actions under ``tensor_keys.action``.

    Returns:
        ``(loss_cql, {})`` with ``loss_cql`` reduced per ``self.reduction``.

    Raises:
        KeyError: if the predicted values are missing, e.g. because
            :meth:`value_loss` has not been called on ``tensordict`` yet.
    """
    qvalues = tensordict.get(self.tensor_keys.pred_val, default=None)
    if qvalues is None:
        raise KeyError(
            f"Couldn't find the predicted qvalue with key {self.tensor_keys.pred_val} in the input tensordict. "
            "This could be caused by calling cql_loss method before value_loss."
        )

    current_action = tensordict.get(self.tensor_keys.action)

    logsumexp = torch.logsumexp(qvalues, dim=-1, keepdim=True)
    if self.action_space == "categorical":
        if current_action.shape != qvalues.shape:
            # unsqueeze the action if it lacks its trailing singleton dim
            current_action = current_action.unsqueeze(-1)
        q_a = qvalues.gather(-1, current_action)
    else:
        # mask-style actions: select the taken action's value by weighting
        q_a = (qvalues * current_action).sum(dim=-1, keepdim=True)

    loss_cql = (logsumexp - q_a).squeeze(-1)
    loss_cql = _reduce(loss_cql, reduction=self.reduction)
    return loss_cql, {}
|