torchrl 0.11.0__cp314-cp314t-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/benchmark_batched_envs.py +104 -0
- benchmarks/conftest.py +91 -0
- benchmarks/ecosystem/gym_env_throughput.py +321 -0
- benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
- benchmarks/requirements.txt +7 -0
- benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
- benchmarks/test_collectors_benchmark.py +240 -0
- benchmarks/test_compressed_storage_benchmark.py +145 -0
- benchmarks/test_envs_benchmark.py +133 -0
- benchmarks/test_llm.py +101 -0
- benchmarks/test_non_tensor_env_benchmark.py +70 -0
- benchmarks/test_objectives_benchmarks.py +1199 -0
- benchmarks/test_replaybuffer_benchmark.py +254 -0
- sota-check/README.md +35 -0
- sota-implementations/README.md +142 -0
- sota-implementations/a2c/README.md +39 -0
- sota-implementations/a2c/a2c_atari.py +291 -0
- sota-implementations/a2c/a2c_mujoco.py +273 -0
- sota-implementations/a2c/utils_atari.py +240 -0
- sota-implementations/a2c/utils_mujoco.py +160 -0
- sota-implementations/bandits/README.md +7 -0
- sota-implementations/bandits/dqn.py +126 -0
- sota-implementations/cql/cql_offline.py +198 -0
- sota-implementations/cql/cql_online.py +249 -0
- sota-implementations/cql/discrete_cql_offline.py +180 -0
- sota-implementations/cql/discrete_cql_online.py +227 -0
- sota-implementations/cql/utils.py +471 -0
- sota-implementations/crossq/crossq.py +271 -0
- sota-implementations/crossq/utils.py +320 -0
- sota-implementations/ddpg/ddpg.py +231 -0
- sota-implementations/ddpg/utils.py +325 -0
- sota-implementations/decision_transformer/dt.py +163 -0
- sota-implementations/decision_transformer/lamb.py +167 -0
- sota-implementations/decision_transformer/online_dt.py +178 -0
- sota-implementations/decision_transformer/utils.py +562 -0
- sota-implementations/discrete_sac/discrete_sac.py +243 -0
- sota-implementations/discrete_sac/utils.py +324 -0
- sota-implementations/dqn/README.md +30 -0
- sota-implementations/dqn/dqn_atari.py +272 -0
- sota-implementations/dqn/dqn_cartpole.py +236 -0
- sota-implementations/dqn/utils_atari.py +132 -0
- sota-implementations/dqn/utils_cartpole.py +90 -0
- sota-implementations/dreamer/README.md +129 -0
- sota-implementations/dreamer/dreamer.py +586 -0
- sota-implementations/dreamer/dreamer_utils.py +1107 -0
- sota-implementations/expert-iteration/README.md +352 -0
- sota-implementations/expert-iteration/ei_utils.py +770 -0
- sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
- sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
- sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
- sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
- sota-implementations/gail/gail.py +327 -0
- sota-implementations/gail/gail_utils.py +68 -0
- sota-implementations/gail/ppo_utils.py +157 -0
- sota-implementations/grpo/README.md +273 -0
- sota-implementations/grpo/grpo-async.py +437 -0
- sota-implementations/grpo/grpo-sync.py +435 -0
- sota-implementations/grpo/grpo_utils.py +843 -0
- sota-implementations/grpo/requirements_gsm8k.txt +11 -0
- sota-implementations/grpo/requirements_ifeval.txt +16 -0
- sota-implementations/impala/README.md +33 -0
- sota-implementations/impala/impala_multi_node_ray.py +292 -0
- sota-implementations/impala/impala_multi_node_submitit.py +284 -0
- sota-implementations/impala/impala_single_node.py +261 -0
- sota-implementations/impala/utils.py +184 -0
- sota-implementations/iql/discrete_iql.py +230 -0
- sota-implementations/iql/iql_offline.py +164 -0
- sota-implementations/iql/iql_online.py +225 -0
- sota-implementations/iql/utils.py +437 -0
- sota-implementations/multiagent/README.md +74 -0
- sota-implementations/multiagent/iql.py +237 -0
- sota-implementations/multiagent/maddpg_iddpg.py +266 -0
- sota-implementations/multiagent/mappo_ippo.py +267 -0
- sota-implementations/multiagent/qmix_vdn.py +271 -0
- sota-implementations/multiagent/sac.py +337 -0
- sota-implementations/multiagent/utils/__init__.py +4 -0
- sota-implementations/multiagent/utils/logging.py +151 -0
- sota-implementations/multiagent/utils/utils.py +43 -0
- sota-implementations/ppo/README.md +29 -0
- sota-implementations/ppo/ppo_atari.py +305 -0
- sota-implementations/ppo/ppo_mujoco.py +293 -0
- sota-implementations/ppo/utils_atari.py +238 -0
- sota-implementations/ppo/utils_mujoco.py +152 -0
- sota-implementations/ppo_trainer/train.py +21 -0
- sota-implementations/redq/README.md +7 -0
- sota-implementations/redq/redq.py +199 -0
- sota-implementations/redq/utils.py +1060 -0
- sota-implementations/sac/sac-async.py +266 -0
- sota-implementations/sac/sac.py +239 -0
- sota-implementations/sac/utils.py +381 -0
- sota-implementations/sac_trainer/train.py +16 -0
- sota-implementations/td3/td3.py +254 -0
- sota-implementations/td3/utils.py +319 -0
- sota-implementations/td3_bc/td3_bc.py +177 -0
- sota-implementations/td3_bc/utils.py +251 -0
- torchrl/__init__.py +144 -0
- torchrl/_extension.py +74 -0
- torchrl/_torchrl.cp314t-win_amd64.pyd +0 -0
- torchrl/_utils.py +1431 -0
- torchrl/collectors/__init__.py +48 -0
- torchrl/collectors/_base.py +1058 -0
- torchrl/collectors/_constants.py +88 -0
- torchrl/collectors/_multi_async.py +324 -0
- torchrl/collectors/_multi_base.py +1805 -0
- torchrl/collectors/_multi_sync.py +464 -0
- torchrl/collectors/_runner.py +581 -0
- torchrl/collectors/_single.py +2009 -0
- torchrl/collectors/_single_async.py +259 -0
- torchrl/collectors/collectors.py +62 -0
- torchrl/collectors/distributed/__init__.py +32 -0
- torchrl/collectors/distributed/default_configs.py +133 -0
- torchrl/collectors/distributed/generic.py +1306 -0
- torchrl/collectors/distributed/ray.py +1092 -0
- torchrl/collectors/distributed/rpc.py +1006 -0
- torchrl/collectors/distributed/sync.py +731 -0
- torchrl/collectors/distributed/utils.py +160 -0
- torchrl/collectors/llm/__init__.py +10 -0
- torchrl/collectors/llm/base.py +494 -0
- torchrl/collectors/llm/ray_collector.py +275 -0
- torchrl/collectors/llm/utils.py +36 -0
- torchrl/collectors/llm/weight_update/__init__.py +10 -0
- torchrl/collectors/llm/weight_update/vllm.py +348 -0
- torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
- torchrl/collectors/utils.py +433 -0
- torchrl/collectors/weight_update.py +591 -0
- torchrl/csrc/numpy_utils.h +38 -0
- torchrl/csrc/pybind.cpp +27 -0
- torchrl/csrc/segment_tree.h +458 -0
- torchrl/csrc/torch_utils.h +34 -0
- torchrl/csrc/utils.cpp +48 -0
- torchrl/csrc/utils.h +31 -0
- torchrl/data/__init__.py +187 -0
- torchrl/data/datasets/__init__.py +58 -0
- torchrl/data/datasets/atari_dqn.py +878 -0
- torchrl/data/datasets/common.py +281 -0
- torchrl/data/datasets/d4rl.py +489 -0
- torchrl/data/datasets/d4rl_infos.py +187 -0
- torchrl/data/datasets/gen_dgrl.py +375 -0
- torchrl/data/datasets/minari_data.py +643 -0
- torchrl/data/datasets/openml.py +177 -0
- torchrl/data/datasets/openx.py +798 -0
- torchrl/data/datasets/roboset.py +363 -0
- torchrl/data/datasets/utils.py +11 -0
- torchrl/data/datasets/vd4rl.py +432 -0
- torchrl/data/llm/__init__.py +34 -0
- torchrl/data/llm/dataset.py +491 -0
- torchrl/data/llm/history.py +1378 -0
- torchrl/data/llm/prompt.py +198 -0
- torchrl/data/llm/reward.py +225 -0
- torchrl/data/llm/topk.py +186 -0
- torchrl/data/llm/utils.py +543 -0
- torchrl/data/map/__init__.py +21 -0
- torchrl/data/map/hash.py +185 -0
- torchrl/data/map/query.py +204 -0
- torchrl/data/map/tdstorage.py +363 -0
- torchrl/data/map/tree.py +1434 -0
- torchrl/data/map/utils.py +103 -0
- torchrl/data/postprocs/__init__.py +8 -0
- torchrl/data/postprocs/postprocs.py +391 -0
- torchrl/data/replay_buffers/__init__.py +99 -0
- torchrl/data/replay_buffers/checkpointers.py +622 -0
- torchrl/data/replay_buffers/ray_buffer.py +292 -0
- torchrl/data/replay_buffers/replay_buffers.py +2376 -0
- torchrl/data/replay_buffers/samplers.py +2578 -0
- torchrl/data/replay_buffers/scheduler.py +265 -0
- torchrl/data/replay_buffers/storages.py +2412 -0
- torchrl/data/replay_buffers/utils.py +1042 -0
- torchrl/data/replay_buffers/writers.py +781 -0
- torchrl/data/tensor_specs.py +7101 -0
- torchrl/data/utils.py +334 -0
- torchrl/envs/__init__.py +265 -0
- torchrl/envs/async_envs.py +1105 -0
- torchrl/envs/batched_envs.py +3093 -0
- torchrl/envs/common.py +4241 -0
- torchrl/envs/custom/__init__.py +11 -0
- torchrl/envs/custom/chess.py +617 -0
- torchrl/envs/custom/llm.py +214 -0
- torchrl/envs/custom/pendulum.py +401 -0
- torchrl/envs/custom/san_moves.txt +29274 -0
- torchrl/envs/custom/tictactoeenv.py +288 -0
- torchrl/envs/env_creator.py +263 -0
- torchrl/envs/gym_like.py +752 -0
- torchrl/envs/libs/__init__.py +68 -0
- torchrl/envs/libs/_gym_utils.py +326 -0
- torchrl/envs/libs/brax.py +846 -0
- torchrl/envs/libs/dm_control.py +544 -0
- torchrl/envs/libs/envpool.py +447 -0
- torchrl/envs/libs/gym.py +2239 -0
- torchrl/envs/libs/habitat.py +138 -0
- torchrl/envs/libs/isaac_lab.py +87 -0
- torchrl/envs/libs/isaacgym.py +203 -0
- torchrl/envs/libs/jax_utils.py +166 -0
- torchrl/envs/libs/jumanji.py +963 -0
- torchrl/envs/libs/meltingpot.py +599 -0
- torchrl/envs/libs/openml.py +153 -0
- torchrl/envs/libs/openspiel.py +652 -0
- torchrl/envs/libs/pettingzoo.py +1042 -0
- torchrl/envs/libs/procgen.py +351 -0
- torchrl/envs/libs/robohive.py +429 -0
- torchrl/envs/libs/smacv2.py +645 -0
- torchrl/envs/libs/unity_mlagents.py +891 -0
- torchrl/envs/libs/utils.py +147 -0
- torchrl/envs/libs/vmas.py +813 -0
- torchrl/envs/llm/__init__.py +63 -0
- torchrl/envs/llm/chat.py +730 -0
- torchrl/envs/llm/datasets/README.md +4 -0
- torchrl/envs/llm/datasets/__init__.py +17 -0
- torchrl/envs/llm/datasets/gsm8k.py +353 -0
- torchrl/envs/llm/datasets/ifeval.py +274 -0
- torchrl/envs/llm/envs.py +789 -0
- torchrl/envs/llm/libs/README.md +3 -0
- torchrl/envs/llm/libs/__init__.py +8 -0
- torchrl/envs/llm/libs/mlgym.py +869 -0
- torchrl/envs/llm/reward/__init__.py +10 -0
- torchrl/envs/llm/reward/gsm8k.py +324 -0
- torchrl/envs/llm/reward/ifeval/README.md +13 -0
- torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
- torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
- torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
- torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
- torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
- torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
- torchrl/envs/llm/transforms/__init__.py +55 -0
- torchrl/envs/llm/transforms/browser.py +292 -0
- torchrl/envs/llm/transforms/dataloading.py +859 -0
- torchrl/envs/llm/transforms/format.py +73 -0
- torchrl/envs/llm/transforms/kl.py +1544 -0
- torchrl/envs/llm/transforms/policy_version.py +189 -0
- torchrl/envs/llm/transforms/reason.py +323 -0
- torchrl/envs/llm/transforms/tokenizer.py +321 -0
- torchrl/envs/llm/transforms/tools.py +1955 -0
- torchrl/envs/model_based/__init__.py +9 -0
- torchrl/envs/model_based/common.py +180 -0
- torchrl/envs/model_based/dreamer.py +112 -0
- torchrl/envs/transforms/__init__.py +147 -0
- torchrl/envs/transforms/functional.py +48 -0
- torchrl/envs/transforms/gym_transforms.py +203 -0
- torchrl/envs/transforms/module.py +341 -0
- torchrl/envs/transforms/r3m.py +372 -0
- torchrl/envs/transforms/ray_service.py +663 -0
- torchrl/envs/transforms/rb_transforms.py +214 -0
- torchrl/envs/transforms/transforms.py +11835 -0
- torchrl/envs/transforms/utils.py +94 -0
- torchrl/envs/transforms/vc1.py +307 -0
- torchrl/envs/transforms/vecnorm.py +845 -0
- torchrl/envs/transforms/vip.py +407 -0
- torchrl/envs/utils.py +1718 -0
- torchrl/envs/vec_envs.py +11 -0
- torchrl/modules/__init__.py +206 -0
- torchrl/modules/distributions/__init__.py +73 -0
- torchrl/modules/distributions/continuous.py +830 -0
- torchrl/modules/distributions/discrete.py +908 -0
- torchrl/modules/distributions/truncated_normal.py +187 -0
- torchrl/modules/distributions/utils.py +233 -0
- torchrl/modules/llm/__init__.py +62 -0
- torchrl/modules/llm/backends/__init__.py +65 -0
- torchrl/modules/llm/backends/vllm/__init__.py +94 -0
- torchrl/modules/llm/backends/vllm/_models.py +46 -0
- torchrl/modules/llm/backends/vllm/base.py +72 -0
- torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
- torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
- torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
- torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
- torchrl/modules/llm/policies/__init__.py +28 -0
- torchrl/modules/llm/policies/common.py +1809 -0
- torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
- torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
- torchrl/modules/llm/utils.py +23 -0
- torchrl/modules/mcts/__init__.py +21 -0
- torchrl/modules/mcts/scores.py +579 -0
- torchrl/modules/models/__init__.py +86 -0
- torchrl/modules/models/batchrenorm.py +119 -0
- torchrl/modules/models/decision_transformer.py +179 -0
- torchrl/modules/models/exploration.py +731 -0
- torchrl/modules/models/llm.py +156 -0
- torchrl/modules/models/model_based.py +596 -0
- torchrl/modules/models/models.py +1712 -0
- torchrl/modules/models/multiagent.py +1067 -0
- torchrl/modules/models/recipes/impala.py +185 -0
- torchrl/modules/models/utils.py +162 -0
- torchrl/modules/planners/__init__.py +10 -0
- torchrl/modules/planners/cem.py +228 -0
- torchrl/modules/planners/common.py +73 -0
- torchrl/modules/planners/mppi.py +265 -0
- torchrl/modules/tensordict_module/__init__.py +89 -0
- torchrl/modules/tensordict_module/actors.py +2457 -0
- torchrl/modules/tensordict_module/common.py +529 -0
- torchrl/modules/tensordict_module/exploration.py +814 -0
- torchrl/modules/tensordict_module/probabilistic.py +321 -0
- torchrl/modules/tensordict_module/rnn.py +1639 -0
- torchrl/modules/tensordict_module/sequence.py +132 -0
- torchrl/modules/tensordict_module/world_models.py +34 -0
- torchrl/modules/utils/__init__.py +38 -0
- torchrl/modules/utils/mappings.py +9 -0
- torchrl/modules/utils/utils.py +89 -0
- torchrl/objectives/__init__.py +78 -0
- torchrl/objectives/a2c.py +659 -0
- torchrl/objectives/common.py +753 -0
- torchrl/objectives/cql.py +1346 -0
- torchrl/objectives/crossq.py +710 -0
- torchrl/objectives/ddpg.py +453 -0
- torchrl/objectives/decision_transformer.py +371 -0
- torchrl/objectives/deprecated.py +516 -0
- torchrl/objectives/dqn.py +683 -0
- torchrl/objectives/dreamer.py +488 -0
- torchrl/objectives/functional.py +48 -0
- torchrl/objectives/gail.py +258 -0
- torchrl/objectives/iql.py +996 -0
- torchrl/objectives/llm/__init__.py +30 -0
- torchrl/objectives/llm/grpo.py +846 -0
- torchrl/objectives/llm/sft.py +482 -0
- torchrl/objectives/multiagent/__init__.py +8 -0
- torchrl/objectives/multiagent/qmixer.py +396 -0
- torchrl/objectives/ppo.py +1669 -0
- torchrl/objectives/redq.py +683 -0
- torchrl/objectives/reinforce.py +530 -0
- torchrl/objectives/sac.py +1580 -0
- torchrl/objectives/td3.py +570 -0
- torchrl/objectives/td3_bc.py +625 -0
- torchrl/objectives/utils.py +782 -0
- torchrl/objectives/value/__init__.py +28 -0
- torchrl/objectives/value/advantages.py +1956 -0
- torchrl/objectives/value/functional.py +1459 -0
- torchrl/objectives/value/utils.py +360 -0
- torchrl/record/__init__.py +17 -0
- torchrl/record/loggers/__init__.py +23 -0
- torchrl/record/loggers/common.py +48 -0
- torchrl/record/loggers/csv.py +226 -0
- torchrl/record/loggers/mlflow.py +142 -0
- torchrl/record/loggers/tensorboard.py +139 -0
- torchrl/record/loggers/trackio.py +163 -0
- torchrl/record/loggers/utils.py +78 -0
- torchrl/record/loggers/wandb.py +214 -0
- torchrl/record/recorder.py +554 -0
- torchrl/services/__init__.py +79 -0
- torchrl/services/base.py +109 -0
- torchrl/services/ray_service.py +453 -0
- torchrl/testing/__init__.py +107 -0
- torchrl/testing/assertions.py +179 -0
- torchrl/testing/dist_utils.py +122 -0
- torchrl/testing/env_creators.py +227 -0
- torchrl/testing/env_helper.py +35 -0
- torchrl/testing/gym_helpers.py +156 -0
- torchrl/testing/llm_mocks.py +119 -0
- torchrl/testing/mocking_classes.py +2720 -0
- torchrl/testing/modules.py +295 -0
- torchrl/testing/mp_helpers.py +15 -0
- torchrl/testing/ray_helpers.py +293 -0
- torchrl/testing/utils.py +190 -0
- torchrl/trainers/__init__.py +42 -0
- torchrl/trainers/algorithms/__init__.py +11 -0
- torchrl/trainers/algorithms/configs/__init__.py +705 -0
- torchrl/trainers/algorithms/configs/collectors.py +216 -0
- torchrl/trainers/algorithms/configs/common.py +41 -0
- torchrl/trainers/algorithms/configs/data.py +308 -0
- torchrl/trainers/algorithms/configs/envs.py +104 -0
- torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
- torchrl/trainers/algorithms/configs/logging.py +80 -0
- torchrl/trainers/algorithms/configs/modules.py +570 -0
- torchrl/trainers/algorithms/configs/objectives.py +177 -0
- torchrl/trainers/algorithms/configs/trainers.py +340 -0
- torchrl/trainers/algorithms/configs/transforms.py +955 -0
- torchrl/trainers/algorithms/configs/utils.py +252 -0
- torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
- torchrl/trainers/algorithms/configs/weight_update.py +159 -0
- torchrl/trainers/algorithms/ppo.py +373 -0
- torchrl/trainers/algorithms/sac.py +308 -0
- torchrl/trainers/helpers/__init__.py +40 -0
- torchrl/trainers/helpers/collectors.py +416 -0
- torchrl/trainers/helpers/envs.py +573 -0
- torchrl/trainers/helpers/logger.py +33 -0
- torchrl/trainers/helpers/losses.py +132 -0
- torchrl/trainers/helpers/models.py +658 -0
- torchrl/trainers/helpers/replay_buffer.py +59 -0
- torchrl/trainers/helpers/trainers.py +301 -0
- torchrl/trainers/trainers.py +2052 -0
- torchrl/weight_update/__init__.py +33 -0
- torchrl/weight_update/_distributed.py +749 -0
- torchrl/weight_update/_mp.py +624 -0
- torchrl/weight_update/_noupdate.py +102 -0
- torchrl/weight_update/_ray.py +1032 -0
- torchrl/weight_update/_rpc.py +284 -0
- torchrl/weight_update/_shared.py +891 -0
- torchrl/weight_update/llm/__init__.py +32 -0
- torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
- torchrl/weight_update/llm/vllm_nccl.py +710 -0
- torchrl/weight_update/utils.py +73 -0
- torchrl/weight_update/weight_sync_schemes.py +1244 -0
- torchrl-0.11.0.dist-info/LICENSE +21 -0
- torchrl-0.11.0.dist-info/METADATA +1307 -0
- torchrl-0.11.0.dist-info/RECORD +394 -0
- torchrl-0.11.0.dist-info/WHEEL +5 -0
- torchrl-0.11.0.dist-info/entry_points.txt +2 -0
- torchrl-0.11.0.dist-info/top_level.txt +7 -0
torchrl/objectives/value/functional.py
@@ -0,0 +1,1459 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

import math
import warnings
from functools import wraps

import torch

try:
    from torch.compiler import is_dynamo_compiling
except ImportError:
    from torch._dynamo import is_compiling as is_dynamo_compiling

__all__ = [
    "generalized_advantage_estimate",
    "vec_generalized_advantage_estimate",
    "td0_advantage_estimate",
    "td0_return_estimate",
    "td1_return_estimate",
    "vec_td1_return_estimate",
    "td1_advantage_estimate",
    "vec_td1_advantage_estimate",
    "td_lambda_return_estimate",
    "vec_td_lambda_return_estimate",
    "td_lambda_advantage_estimate",
    "vec_td_lambda_advantage_estimate",
    "vtrace_advantage_estimate",
]

from torchrl.objectives.value.utils import (
    _custom_conv1d,
    _get_num_per_traj,
    _inv_pad_sequence,
    _make_gammas_tensor,
    _split_and_pad_sequence,
)

SHAPE_ERR = (
    "All input tensors (value, reward and done states) must share a unique shape."
)


def _transpose_time(fun):
    """Checks the time_dim argument of the function to allow for any dim.

    If not -2, makes a transpose of all the multi-dim input tensors to bring
    time at -2, and does the opposite transform for the outputs.
    """
    ERROR = (
        "The tensor shape and the time dimension are not compatible: "
        "got {} and time_dim={}."
    )

    @wraps(fun)
    def transposed_fun(*args, **kwargs):
        time_dim = kwargs.pop("time_dim", -2)

        def transpose_tensor(tensor):
            if not isinstance(tensor, torch.Tensor) or tensor.numel() <= 1:
                return tensor, False
            if time_dim >= 0:
                timedim = time_dim - tensor.ndim
            else:
                timedim = time_dim
            if timedim < -tensor.ndim or timedim >= 0:
                raise RuntimeError(ERROR.format(tensor.shape, timedim))
            if tensor.ndim >= 2:
                single_dim = False
                tensor = tensor.transpose(timedim, -2)
            elif tensor.ndim == 1 and timedim == -1:
                single_dim = True
                tensor = tensor.unsqueeze(-1)
            else:
                raise RuntimeError(ERROR.format(tensor.shape, timedim))
            return tensor, single_dim

        if time_dim != -2:
            single_dim = False
            if args:
                args, single_dim = zip(*(transpose_tensor(arg) for arg in args))
                single_dim = any(single_dim)
            for k, item in list(kwargs.items()):
                item, sd = transpose_tensor(item)
                single_dim = single_dim or sd
                kwargs[k] = item
            # We don't pass time_dim because it isn't supposed to be used thereafter
            out = fun(*args, **kwargs)
            if isinstance(out, torch.Tensor):
                out = transpose_tensor(out)[0]
                if single_dim:
                    out = out.squeeze(-2)
                return out
            if single_dim:
                return tuple(transpose_tensor(_out)[0].squeeze(-2) for _out in out)
            return tuple(transpose_tensor(_out)[0] for _out in out)
        # We don't pass time_dim because it isn't supposed to be used thereafter
        out = fun(*args, **kwargs)
        if isinstance(out, tuple):
            for _out in out:
                if _out.ndim < 2:
                    raise RuntimeError(ERROR.format(_out.shape, time_dim))
        else:
            if out.ndim < 2:
                raise RuntimeError(ERROR.format(out.shape, time_dim))
        return out

    return transposed_fun


########################################################################
# GAE
# ---


@_transpose_time
def generalized_advantage_estimate(
    gamma: float,
    lmbda: float,
    state_value: torch.Tensor,
    next_state_value: torch.Tensor,
    reward: torch.Tensor,
    done: torch.Tensor,
    terminated: torch.Tensor | None = None,
    *,
    time_dim: int = -2,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Generalized advantage estimate of a trajectory.

    Refer to "HIGH-DIMENSIONAL CONTINUOUS CONTROL USING GENERALIZED ADVANTAGE ESTIMATION"
    https://arxiv.org/pdf/1506.02438.pdf for more context.

    Args:
        gamma (scalar): exponential mean discount.
        lmbda (scalar): trajectory discount.
        state_value (Tensor): value function result with old_state input.
        next_state_value (Tensor): value function result with new_state input.
        reward (Tensor): reward of taking actions in the environment.
        done (Tensor): boolean flag for end of trajectory.
        terminated (Tensor): boolean flag for the end of episode. Defaults to ``done``
            if not provided.
        time_dim (int): dimension where the time is unrolled. Defaults to -2.

    All tensors (values, reward and done) must have shape
    ``[*Batch x TimeSteps x *F]``, with ``*F`` feature dimensions.

    """
    if terminated is None:
        terminated = done.clone()
    if not (
        next_state_value.shape
        == state_value.shape
        == reward.shape
        == done.shape
        == terminated.shape
    ):
        raise RuntimeError(SHAPE_ERR)
    dtype = next_state_value.dtype
    device = state_value.device
    not_done = (~done).int()
    not_terminated = (~terminated).int()
    *batch_size, time_steps, lastdim = not_done.shape
    advantage = torch.empty(
        *batch_size, time_steps, lastdim, device=device, dtype=dtype
    )
    prev_advantage = 0
    g_not_terminated = gamma * not_terminated
    delta = reward + (g_not_terminated * next_state_value) - state_value
    discount = lmbda * gamma * not_done
    for t in reversed(range(time_steps)):
        prev_advantage = advantage[..., t, :] = delta[..., t, :] + (
            prev_advantage * discount[..., t, :]
        )

    value_target = advantage + state_value

    return advantage, value_target

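# Editor's illustrative sketch (not part of the released file): a minimal usage
# example for the function above, assuming a batch of 2 trajectories of 5 steps
# with a single feature dimension, as required by the documented
# ``[*Batch x TimeSteps x *F]`` layout. The recursion above computes
#     delta_t = r_t + gamma * V(s_{t+1}) * (1 - terminated_t) - V(s_t)
#     A_t     = delta_t + gamma * lmbda * (1 - done_t) * A_{t+1}
# and returns ``value_target = A + V``.
#
#     >>> B, T = 2, 5
#     >>> value = torch.randn(B, T, 1)
#     >>> next_value = torch.randn(B, T, 1)
#     >>> reward = torch.randn(B, T, 1)
#     >>> done = torch.zeros(B, T, 1, dtype=torch.bool)
#     >>> adv, target = generalized_advantage_estimate(
#     ...     0.99, 0.95, value, next_value, reward, done
#     ... )
#     >>> assert adv.shape == target.shape == (B, T, 1)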


def _geom_series_like(t, r, thr):
    """Creates a geometric series of the form [1, gammalmbda, gammalmbda**2] with the shape of `t`.

    Drops all elements which are smaller than `thr` (unless in compile mode).
    """
    if is_dynamo_compiling():
        if isinstance(r, torch.Tensor):
            rs = r.expand_as(t)
        else:
            rs = torch.full_like(t, r)
    else:
        if isinstance(r, torch.Tensor):
            r = r.item()

        if r == 0.0:
            return torch.zeros_like(t)
        elif r >= 1.0:
            lim = t.numel()
        else:
            lim = int(math.log(thr) / math.log(r))

        rs = torch.full_like(t[:lim], r)
    rs[0] = 1.0
    rs = rs.cumprod(0)
    rs = rs.unsqueeze(-1)
    return rs

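# Editor's illustrative sketch (not part of the released file): in eager mode the
# helper above builds the truncated geometric filter used by the vectorized GAE.
# For instance, with ratio gamma * lmbda = 0.9 the (column) series starts as
# [1, 0.9, 0.81, ...] and is cut off once terms would drop below ``thr``.
#
#     >>> series = _geom_series_like(torch.ones(100), 0.9, thr=1e-7)
#     >>> series[:3, 0]
#     tensor([1.0000, 0.9000, 0.8100])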


def _fast_vec_gae(
    reward: torch.Tensor,
    state_value: torch.Tensor,
    next_state_value: torch.Tensor,
    done: torch.Tensor,
    terminated: torch.Tensor,
    gamma: float,
    lmbda: float,
    thr: float = 1e-7,
):
    """Fast vectorized Generalized Advantage Estimate when gamma and lmbda are scalars.

    In contrast to `vec_generalized_advantage_estimate` this function does not need
    to allocate a big tensor of the form [B, T, T].

    Args:
        reward (torch.Tensor): a [*B, T, F] tensor containing rewards
        state_value (torch.Tensor): a [*B, T, F] tensor containing state values (value function)
        next_state_value (torch.Tensor): a [*B, T, F] tensor containing next state values (value function)
        done (torch.Tensor): a [B, T] boolean tensor containing the done states.
        terminated (torch.Tensor): a [B, T] boolean tensor containing the terminated states.
        gamma (scalar): the gamma decay (trajectory discount)
        lmbda (scalar): the lambda decay (exponential mean discount)
        thr (:obj:`float`): threshold for the filter. Below this limit, components will be ignored.
            Defaults to 1e-7.

    All tensors (values, reward and done) must have shape
    ``[*Batch x TimeSteps x F]``, with ``F`` feature dimensions.

    """
    # _get_num_per_traj and _split_and_pad_sequence need
    # time dimension at last position
    done = done.transpose(-2, -1)
    terminated = terminated.transpose(-2, -1)
    reward = reward.transpose(-2, -1)
    state_value = state_value.transpose(-2, -1)
    next_state_value = next_state_value.transpose(-2, -1)

    gammalmbda = gamma * lmbda
    not_terminated = (~terminated).int()
    td0 = reward + not_terminated * gamma * next_state_value - state_value

    num_per_traj = _get_num_per_traj(done)
    td0_flat, mask = _split_and_pad_sequence(td0, num_per_traj, return_mask=True)

    gammalmbdas = _geom_series_like(td0_flat[0], gammalmbda, thr=thr)

    advantage = _custom_conv1d(td0_flat.unsqueeze(1), gammalmbdas)
    advantage = advantage.squeeze(1)
    advantage = advantage[mask].view_as(reward)

    value_target = advantage + state_value

    advantage = advantage.transpose(-1, -2)
    value_target = value_target.transpose(-1, -2)

    return advantage, value_target


@_transpose_time
def vec_generalized_advantage_estimate(
    gamma: float | torch.Tensor,
    lmbda: float | torch.Tensor,
    state_value: torch.Tensor,
    next_state_value: torch.Tensor,
    reward: torch.Tensor,
    done: torch.Tensor,
    terminated: torch.Tensor | None = None,
    *,
    time_dim: int = -2,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Vectorized Generalized advantage estimate of a trajectory.

    Refer to "HIGH-DIMENSIONAL CONTINUOUS CONTROL USING GENERALIZED ADVANTAGE ESTIMATION"
    https://arxiv.org/pdf/1506.02438.pdf for more context.

    Args:
        gamma (scalar): exponential mean discount.
        lmbda (scalar): trajectory discount.
        state_value (Tensor): value function result with old_state input.
        next_state_value (Tensor): value function result with new_state input.
        reward (Tensor): reward of taking actions in the environment.
        done (Tensor): boolean flag for end of trajectory.
        terminated (Tensor): boolean flag for the end of episode. Defaults to ``done``
            if not provided.
        time_dim (int): dimension where the time is unrolled. Defaults to -2.

    All tensors (values, reward and done) must have shape
    ``[*Batch x TimeSteps x *F]``, with ``*F`` feature dimensions.

    """
    if terminated is None:
        terminated = done.clone()
    if not (
        next_state_value.shape
        == state_value.shape
        == reward.shape
        == done.shape
        == terminated.shape
    ):
        raise RuntimeError(SHAPE_ERR)
    dtype = state_value.dtype
    *batch_size, time_steps, lastdim = terminated.shape

    value = gamma * lmbda

    if isinstance(value, torch.Tensor) and value.numel() > 1:
        # create tensor while ensuring that gradients are passed
        not_done = (~done).to(dtype)
        gammalmbdas = not_done * value
    else:
        # when gamma and lmbda are scalars, use fast_vec_gae implementation
        return _fast_vec_gae(
            reward=reward,
            state_value=state_value,
            next_state_value=next_state_value,
            done=done,
            terminated=terminated,
            gamma=gamma,
            lmbda=lmbda,
        )

    gammalmbdas = _make_gammas_tensor(gammalmbdas, time_steps, True)
    gammalmbdas = gammalmbdas.cumprod(-2)

    # Skip data-dependent truncation optimization during compile (causes guards)
    if not is_dynamo_compiling():
        first_below_thr = gammalmbdas < 1e-7
        # if we have multiple gammas, we only want to truncate if _all_ of
        # the geometric sequences fall below the threshold
        first_below_thr = first_below_thr.flatten(0, 1).all(0).all(-1)
        if first_below_thr.any():
            first_below_thr = torch.where(first_below_thr)[0][0].item()
            gammalmbdas = gammalmbdas[..., :first_below_thr, :]

    not_terminated = (~terminated).to(dtype)
    td0 = reward + not_terminated * gamma * next_state_value - state_value

    if len(batch_size) > 1:
        td0 = td0.flatten(0, len(batch_size) - 1)
    elif not len(batch_size):
        td0 = td0.unsqueeze(0)

    td0_r = td0.transpose(-2, -1)
    shapes = td0_r.shape[:2]
    if lastdim != 1:
        # then we flatten again the first dims and reset a singleton in between
        td0_r = td0_r.flatten(0, 1).unsqueeze(1)
    advantage = _custom_conv1d(td0_r, gammalmbdas)
    if lastdim != 1:
        advantage = advantage.squeeze(1).unflatten(0, shapes)

    if len(batch_size) > 1:
        advantage = advantage.unflatten(0, batch_size)
    elif not len(batch_size):
        advantage = advantage.squeeze(0)

    advantage = advantage.transpose(-2, -1)
    value_target = advantage + state_value
    return advantage, value_target

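# Editor's illustrative sketch (not part of the released file): with scalar gamma
# and lmbda the function above dispatches to ``_fast_vec_gae`` (a convolution with
# the truncated geometric filter), while tensor-valued gamma * lmbda falls back to
# the [B, T, T] gamma matrix built by ``_make_gammas_tensor``. Both paths are
# expected to closely match the sequential ``generalized_advantage_estimate``:
#
#     >>> B, T = 2, 5
#     >>> value, next_value, reward = (torch.randn(B, T, 1) for _ in range(3))
#     >>> done = torch.zeros(B, T, 1, dtype=torch.bool)
#     >>> adv_seq, tgt_seq = generalized_advantage_estimate(
#     ...     0.99, 0.95, value, next_value, reward, done
#     ... )
#     >>> adv_vec, tgt_vec = vec_generalized_advantage_estimate(
#     ...     0.99, 0.95, value, next_value, reward, done
#     ... )
#     >>> torch.testing.assert_close(adv_seq, adv_vec, rtol=1e-4, atol=1e-4)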


########################################################################
# TD(0)
# -----


def td0_advantage_estimate(
    gamma: float,
    state_value: torch.Tensor,
    next_state_value: torch.Tensor,
    reward: torch.Tensor,
    done: torch.Tensor,
    terminated: torch.Tensor | None = None,
) -> torch.Tensor:
    """TD(0) advantage estimate of a trajectory.

    Also known as bootstrapped Temporal Difference or one-step return.

    Args:
        gamma (scalar): exponential mean discount.
        state_value (Tensor): value function result with old_state input.
        next_state_value (Tensor): value function result with new_state input.
        reward (Tensor): reward of taking actions in the environment.
        done (Tensor): boolean flag for end of trajectory.
        terminated (Tensor): boolean flag for the end of episode. Defaults to ``done``
            if not provided.

    All tensors (values, reward and done) must have shape
    ``[*Batch x TimeSteps x *F]``, with ``*F`` feature dimensions.

    """
    if terminated is None:
        terminated = done.clone()
    if not (
        next_state_value.shape
        == state_value.shape
        == reward.shape
        == done.shape
        == terminated.shape
    ):
        raise RuntimeError(SHAPE_ERR)
    returns = td0_return_estimate(gamma, next_state_value, reward, terminated)
    advantage = returns - state_value
    return advantage


def td0_return_estimate(
    gamma: float,
    next_state_value: torch.Tensor,
    reward: torch.Tensor,
    terminated: torch.Tensor | None = None,
    *,
    done: torch.Tensor | None = None,
) -> torch.Tensor:
    # noqa: D417
    """TD(0) discounted return estimate of a trajectory.

    Also known as bootstrapped Temporal Difference or one-step return.

    Args:
        gamma (scalar): exponential mean discount.
        next_state_value (Tensor): value function result with new_state input.
            must be a [Batch x TimeSteps x 1] or [Batch x TimeSteps] tensor
        reward (Tensor): reward of taking actions in the environment.
            must be a [Batch x TimeSteps x 1] or [Batch x TimeSteps] tensor
        terminated (Tensor): boolean flag for the end of episode. Defaults to ``done``
            if not provided.

    Keyword Args:
        done (Tensor): Deprecated. Use ``terminated`` instead.

    All tensors (values, reward and done) must have shape
    ``[*Batch x TimeSteps x *F]``, with ``*F`` feature dimensions.

    """
    if done is not None and terminated is None:
        terminated = done.clone()
        warnings.warn(
            "done for td0_return_estimate is deprecated. Pass ``terminated`` instead."
        )
    if not (next_state_value.shape == reward.shape == terminated.shape):
        raise RuntimeError(SHAPE_ERR)
    not_terminated = (~terminated).int()
    returns = reward + gamma * not_terminated * next_state_value
    return returns

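# Editor's illustrative sketch (not part of the released file): the one-step
# return above is ``r_t + gamma * V(s_{t+1})`` wherever the episode has not
# terminated, and just ``r_t`` where it has. For a single transition:
#
#     >>> next_value = torch.tensor([[2.0]])
#     >>> reward = torch.tensor([[1.0]])
#     >>> terminated = torch.tensor([[False]])
#     >>> td0_return_estimate(0.9, next_value, reward, terminated)
#     tensor([[2.8000]])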


########################################################################
# TD(1)
# ----------


@_transpose_time
def td1_return_estimate(
    gamma: float,
    next_state_value: torch.Tensor,
    reward: torch.Tensor,
    done: torch.Tensor,
    terminated: torch.Tensor | None = None,
    rolling_gamma: bool | None = None,
    *,
    time_dim: int = -2,
) -> torch.Tensor:
    r"""TD(1) return estimate.

    Args:
        gamma (scalar): exponential mean discount.
        next_state_value (Tensor): value function result with new_state input.
        reward (Tensor): reward of taking actions in the environment.
        done (Tensor): boolean flag for end of trajectory.
        terminated (Tensor): boolean flag for the end of episode. Defaults to ``done``
            if not provided.
        rolling_gamma (bool, optional): if ``True``, it is assumed that each gamma
            of a gamma tensor is tied to a single event:

            >>> gamma = [g1, g2, g3, g4]
            >>> value = [v1, v2, v3, v4]
            >>> return = [
            ...     v1 + g1 v2 + g1 g2 v3 + g1 g2 g3 v4,
            ...     v2 + g2 v3 + g2 g3 v4,
            ...     v3 + g3 v4,
            ...     v4,
            ... ]

            if ``False``, it is assumed that each gamma is tied to the upcoming
            trajectory:

            >>> gamma = [g1, g2, g3, g4]
            >>> value = [v1, v2, v3, v4]
            >>> return = [
            ...     v1 + g1 v2 + g1**2 v3 + g1**3 v4,
            ...     v2 + g2 v3 + g2**2 v4,
            ...     v3 + g3 v4,
            ...     v4,
            ... ]

            Default is ``True``.
        time_dim (int): dimension where the time is unrolled. Defaults to -2.

    All tensors (values, reward and done) must have shape
    ``[*Batch x TimeSteps x *F]``, with ``*F`` feature dimensions.

    """
    if terminated is None:
        terminated = done.clone()
    if not (next_state_value.shape == reward.shape == done.shape == terminated.shape):
        raise RuntimeError(SHAPE_ERR)
    not_done = (~done).int()
    not_terminated = (~terminated).int()

    returns = torch.empty_like(next_state_value)

    T = returns.shape[-2]

    single_gamma = False
    if not (isinstance(gamma, torch.Tensor) and gamma.shape == not_done.shape):
        single_gamma = True
        if isinstance(gamma, torch.Tensor):
            # Use expand instead of full_like to avoid .item() call which creates
            # unbacked symbols during torch.compile tracing.
            if gamma.device != next_state_value.device:
                gamma = gamma.to(next_state_value.device)
            gamma = gamma.expand(next_state_value.shape)
        else:
            gamma = torch.full_like(next_state_value, gamma)

    if rolling_gamma is None:
        rolling_gamma = True
    elif not rolling_gamma and single_gamma:
        raise RuntimeError(
            "rolling_gamma=False is expected only with time-sensitive gamma values"
        )

    done_but_not_terminated = (done & ~terminated).int()
    if rolling_gamma:
        gamma = gamma * not_terminated
        g = next_state_value[..., -1, :]
        for i in reversed(range(T)):
            # if not done (and hence not terminated), get the bootstrapped value
            # if done but not terminated, get next_val
            # if terminated, take nothing (gamma = 0)
            dnt = done_but_not_terminated[..., i, :]
            g = returns[..., i, :] = reward[..., i, :] + gamma[..., i, :] * (
                (1 - dnt) * g + dnt * next_state_value[..., i, :]
            )
    else:
        for k in range(T):
            g = 0
            _gamma = gamma[..., k, :]
            nd = not_terminated
            _gamma = _gamma.unsqueeze(-2) * nd
            for i in reversed(range(k, T)):
                dnt = done_but_not_terminated[..., i, :]
                g = reward[..., i, :] + _gamma[..., i, :] * (
                    (1 - dnt) * g + dnt * next_state_value[..., i, :]
                )
            returns[..., k, :] = g
    return returns

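# Editor's illustrative sketch (not part of the released file): with a scalar
# gamma and no early ``done``, the recursion above reduces to the fully
# bootstrapped return r_t + gamma r_{t+1} + ... + gamma^(T-1-t) r_{T-1}
# + gamma^(T-t) V(s_T). For a 3-step trajectory with zero rewards, only the
# discounted final value remains:
#
#     >>> next_value = torch.ones(3, 1)
#     >>> reward = torch.zeros(3, 1)
#     >>> done = torch.tensor([[False], [False], [True]])
#     >>> terminated = torch.zeros(3, 1, dtype=torch.bool)
#     >>> td1_return_estimate(0.5, next_value, reward, done, terminated)
#     tensor([[0.1250],
#             [0.2500],
#             [0.5000]])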


def td1_advantage_estimate(
    gamma: float,
    state_value: torch.Tensor,
    next_state_value: torch.Tensor,
    reward: torch.Tensor,
    done: torch.Tensor,
    terminated: torch.Tensor | None = None,
    rolling_gamma: bool | None = None,
    time_dim: int = -2,
) -> torch.Tensor:
    """TD(1) advantage estimate.

    Args:
        gamma (scalar): exponential mean discount.
        state_value (Tensor): value function result with old_state input.
        next_state_value (Tensor): value function result with new_state input.
        reward (Tensor): reward of taking actions in the environment.
        done (Tensor): boolean flag for end of trajectory.
        terminated (Tensor): boolean flag for the end of episode. Defaults to ``done``
            if not provided.
        rolling_gamma (bool, optional): if ``True``, it is assumed that each gamma
            of a gamma tensor is tied to a single event:

            >>> gamma = [g1, g2, g3, g4]
            >>> value = [v1, v2, v3, v4]
            >>> return = [
            ...     v1 + g1 v2 + g1 g2 v3 + g1 g2 g3 v4,
            ...     v2 + g2 v3 + g2 g3 v4,
            ...     v3 + g3 v4,
            ...     v4,
            ... ]

            if ``False``, it is assumed that each gamma is tied to the upcoming
            trajectory:

            >>> gamma = [g1, g2, g3, g4]
            >>> value = [v1, v2, v3, v4]
            >>> return = [
            ...     v1 + g1 v2 + g1**2 v3 + g1**3 v4,
            ...     v2 + g2 v3 + g2**2 v4,
            ...     v3 + g3 v4,
            ...     v4,
            ... ]

            Default is ``True``.
        time_dim (int): dimension where the time is unrolled. Defaults to -2.

    All tensors (values, reward and done) must have shape
    ``[*Batch x TimeSteps x *F]``, with ``*F`` feature dimensions.

    """
    if terminated is None:
        terminated = done.clone()
    if not (
        next_state_value.shape
        == state_value.shape
        == reward.shape
        == done.shape
        == terminated.shape
    ):
        raise RuntimeError(SHAPE_ERR)
    if not state_value.shape == next_state_value.shape:
        raise RuntimeError("shape of state_value and next_state_value must match")
    returns = td1_return_estimate(
        gamma,
        next_state_value,
        reward,
        done,
        terminated=terminated,
        rolling_gamma=rolling_gamma,
        time_dim=time_dim,
    )
    advantage = returns - state_value
    return advantage


@_transpose_time
def vec_td1_return_estimate(
    gamma,
    next_state_value,
    reward,
    done: torch.Tensor,
    terminated: torch.Tensor | None = None,
    rolling_gamma: bool | None = None,
    time_dim: int = -2,
):
    """Vectorized TD(1) return estimate.

    Args:
        gamma (scalar, Tensor): exponential mean discount. If tensor-valued,
        next_state_value (Tensor): value function result with new_state input.
        reward (Tensor): reward of taking actions in the environment.
        done (Tensor): boolean flag for end of trajectory.
        terminated (Tensor): boolean flag for the end of episode. Defaults to ``done``
            if not provided.
        rolling_gamma (bool, optional): if ``True``, it is assumed that each gamma
            of the gamma tensor is tied to a single event:

            >>> gamma = [g1, g2, g3, g4]
            >>> value = [v1, v2, v3, v4]
            >>> return = [
            ...     v1 + g1 v2 + g1 g2 v3 + g1 g2 g3 v4,
            ...     v2 + g2 v3 + g2 g3 v4,
            ...     v3 + g3 v4,
            ...     v4,
            ... ]

            if ``False``, it is assumed that each gamma is tied to the upcoming
            trajectory:

            >>> gamma = [g1, g2, g3, g4]
            >>> value = [v1, v2, v3, v4]
            >>> return = [
            ...     v1 + g1 v2 + g1**2 v3 + g1**3 v4,
            ...     v2 + g2 v3 + g2**2 v4,
            ...     v3 + g3 v4,
            ...     v4,
            ... ]

            Default is ``True``.
        time_dim (int): dimension where the time is unrolled. Defaults to ``-2``.

    All tensors (values, reward and done) must have shape
    ``[*Batch x TimeSteps x *F]``, with ``*F`` feature dimensions.

    """
    return vec_td_lambda_return_estimate(
        gamma=gamma,
        next_state_value=next_state_value,
        reward=reward,
        done=done,
        terminated=terminated,
        rolling_gamma=rolling_gamma,
        lmbda=1,
        time_dim=time_dim,
    )


def vec_td1_advantage_estimate(
    gamma,
    state_value,
    next_state_value,
    reward,
    done: torch.Tensor,
    terminated: torch.Tensor | None = None,
    rolling_gamma: bool | None = None,
    time_dim: int = -2,
):
    """Vectorized TD(1) advantage estimate.

    Args:
        gamma (scalar, Tensor): exponential mean discount. If tensor-valued,
        state_value (Tensor): value function result with old_state input.
        next_state_value (Tensor): value function result with new_state input.
        reward (Tensor): reward of taking actions in the environment.
        done (Tensor): boolean flag for end of trajectory.
        terminated (Tensor): boolean flag for the end of episode. Defaults to ``done``
            if not provided.
        rolling_gamma (bool, optional): if ``True``, it is assumed that each gamma
            of a gamma tensor is tied to a single event:

            >>> gamma = [g1, g2, g3, g4]
            >>> value = [v1, v2, v3, v4]
            >>> return = [
            ...     v1 + g1 v2 + g1 g2 v3 + g1 g2 g3 v4,
            ...     v2 + g2 v3 + g2 g3 v4,
            ...     v3 + g3 v4,
            ...     v4,
            ... ]

            if ``False``, it is assumed that each gamma is tied to the upcoming
            trajectory:

            >>> gamma = [g1, g2, g3, g4]
            >>> value = [v1, v2, v3, v4]
            >>> return = [
            ...     v1 + g1 v2 + g1**2 v3 + g1**3 v4,
            ...     v2 + g2 v3 + g2**2 v4,
            ...     v3 + g3 v4,
            ...     v4,
            ... ]

            Default is ``True``.
        time_dim (int): dimension where the time is unrolled. Defaults to -2.

    All tensors (values, reward and done) must have shape
    ``[*Batch x TimeSteps x *F]``, with ``*F`` feature dimensions.

    """
    if terminated is None:
        terminated = done.clone()
    if not (
        next_state_value.shape
        == state_value.shape
        == reward.shape
        == done.shape
        == terminated.shape
    ):
        raise RuntimeError(SHAPE_ERR)
    return (
        vec_td1_return_estimate(
            gamma,
            next_state_value,
            reward,
            done=done,
            terminated=terminated,
            rolling_gamma=rolling_gamma,
            time_dim=time_dim,
        )
        - state_value
    )

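# Editor's illustrative sketch (not part of the released file): the vectorized
# TD(1) helpers above are thin wrappers around the TD(lambda) machinery with
# ``lmbda=1``, so for a scalar gamma they are expected to closely match the
# sequential ``td1_return_estimate``:
#
#     >>> B, T = 2, 4
#     >>> next_value, reward = torch.randn(B, T, 1), torch.randn(B, T, 1)
#     >>> done = torch.zeros(B, T, 1, dtype=torch.bool)
#     >>> r_seq = td1_return_estimate(0.99, next_value, reward, done)
#     >>> r_vec = vec_td1_return_estimate(0.99, next_value, reward, done)
#     >>> torch.testing.assert_close(r_seq, r_vec, rtol=1e-4, atol=1e-4)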


########################################################################
# TD(lambda)
# ----------


@_transpose_time
def td_lambda_return_estimate(
    gamma: float,
    lmbda: float,
    next_state_value: torch.Tensor,
    reward: torch.Tensor,
    done: torch.Tensor,
    terminated: torch.Tensor | None = None,
    rolling_gamma: bool | None = None,
    *,
    time_dim: int = -2,
) -> torch.Tensor:
    r"""TD(:math:`\lambda`) return estimate.

    Args:
        gamma (scalar): exponential mean discount.
        lmbda (scalar): trajectory discount.
        next_state_value (Tensor): value function result with new_state input.
        reward (Tensor): reward of taking actions in the environment.
        done (Tensor): boolean flag for end of trajectory.
        terminated (Tensor): boolean flag for the end of episode. Defaults to ``done``
            if not provided.
        rolling_gamma (bool, optional): if ``True``, it is assumed that each gamma
            of a gamma tensor is tied to a single event:

            >>> gamma = [g1, g2, g3, g4]
            >>> value = [v1, v2, v3, v4]
            >>> return = [
            ...     v1 + g1 v2 + g1 g2 v3 + g1 g2 g3 v4,
            ...     v2 + g2 v3 + g2 g3 v4,
            ...     v3 + g3 v4,
            ...     v4,
            ... ]

            if ``False``, it is assumed that each gamma is tied to the upcoming
            trajectory:

            >>> gamma = [g1, g2, g3, g4]
            >>> value = [v1, v2, v3, v4]
            >>> return = [
            ...     v1 + g1 v2 + g1**2 v3 + g1**3 v4,
            ...     v2 + g2 v3 + g2**2 v4,
            ...     v3 + g3 v4,
            ...     v4,
            ... ]

            Default is ``True``.
        time_dim (int): dimension where the time is unrolled. Defaults to -2.

    All tensors (values, reward and done) must have shape
    ``[*Batch x TimeSteps x *F]``, with ``*F`` feature dimensions.

    """
    if terminated is None:
        terminated = done.clone()
    if not (next_state_value.shape == reward.shape == done.shape == terminated.shape):
        raise RuntimeError(SHAPE_ERR)

    not_terminated = (~terminated).int()

    returns = torch.empty_like(next_state_value)
    next_state_value = next_state_value * not_terminated

    *batch, T, lastdim = returns.shape

    # if gamma is not a tensor of the same shape as other inputs, we use rolling_gamma = True
    single_gamma = False
    if not (isinstance(gamma, torch.Tensor) and gamma.shape == done.shape):
        single_gamma = True
        if isinstance(gamma, torch.Tensor):
            # Use expand instead of full_like to avoid .item() call which creates
            # unbacked symbols during torch.compile tracing.
            if gamma.device != next_state_value.device:
                gamma = gamma.to(next_state_value.device)
            gamma = gamma.expand(next_state_value.shape)
        else:
            gamma = torch.full_like(next_state_value, gamma)

    single_lambda = False
    if not (isinstance(lmbda, torch.Tensor) and lmbda.shape == done.shape):
        single_lambda = True
        if isinstance(lmbda, torch.Tensor):
            # Use expand instead of full_like to avoid .item() call which creates
            # unbacked symbols during torch.compile tracing.
            if lmbda.device != next_state_value.device:
                lmbda = lmbda.to(next_state_value.device)
            lmbda = lmbda.expand(next_state_value.shape)
        else:
            lmbda = torch.full_like(next_state_value, lmbda)

    if rolling_gamma is None:
        rolling_gamma = True
    elif not rolling_gamma and single_gamma and single_lambda:
        raise RuntimeError(
            "rolling_gamma=False is expected only with time-sensitive gamma or lambda values"
        )
    if rolling_gamma:
        g = next_state_value[..., -1, :]
        for i in reversed(range(T)):
            dn = done[..., i, :].int()
            nv = next_state_value[..., i, :]
            lmd = lmbda[..., i, :]
            # if done, the bootstrapped gain is the next value, otherwise it's the
            # value we computed during the previous iter
            g = g * (1 - dn) + nv * dn
            g = returns[..., i, :] = reward[..., i, :] + gamma[..., i, :] * (
                (1 - lmd) * nv + lmd * g
            )
    else:
        for k in range(T):
            g = next_state_value[..., -1, :]
            _gamma = gamma[..., k, :]
            _lambda = lmbda[..., k, :]
            for i in reversed(range(k, T)):
                dn = done[..., i, :].int()
                nv = next_state_value[..., i, :]
                g = g * (1 - dn) + nv * dn
                g = reward[..., i, :] + _gamma * ((1 - _lambda) * nv + _lambda * g)
            returns[..., k, :] = g

    return returns

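# Editor's illustrative sketch (not part of the released file): the lambda-return
# above interpolates between the one-step and the fully bootstrapped estimates;
# with ``lmbda=0`` the recursion collapses to the TD(0) return and with
# ``lmbda=1`` to the TD(1) return, which this hypothetical check illustrates:
#
#     >>> T = 6
#     >>> next_value, reward = torch.randn(T, 1), torch.randn(T, 1)
#     >>> done = torch.zeros(T, 1, dtype=torch.bool)
#     >>> r_lmbda0 = td_lambda_return_estimate(0.9, 0.0, next_value, reward, done)
#     >>> torch.testing.assert_close(
#     ...     r_lmbda0, td0_return_estimate(0.9, next_value, reward, done)
#     ... )
#     >>> r_lmbda1 = td_lambda_return_estimate(0.9, 1.0, next_value, reward, done)
#     >>> torch.testing.assert_close(
#     ...     r_lmbda1, td1_return_estimate(0.9, next_value, reward, done)
#     ... )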
|
913
|
+
def td_lambda_advantage_estimate(
    gamma: float,
    lmbda: float,
    state_value: torch.Tensor,
    next_state_value: torch.Tensor,
    reward: torch.Tensor,
    done: torch.Tensor,
    terminated: torch.Tensor | None = None,
    rolling_gamma: bool | None = None,
    # not a kwarg because used directly
    time_dim: int = -2,
) -> torch.Tensor:
    r"""TD(:math:`\lambda`) advantage estimate.

    Args:
        gamma (scalar): exponential mean discount.
        lmbda (scalar): trajectory discount.
        state_value (Tensor): value function result with old_state input.
        next_state_value (Tensor): value function result with new_state input.
        reward (Tensor): reward of taking actions in the environment.
        done (Tensor): boolean flag for end of trajectory.
        terminated (Tensor): boolean flag for the end of episode. Defaults to ``done``
            if not provided.
        rolling_gamma (bool, optional): if ``True``, it is assumed that each gamma
            of a gamma tensor is tied to a single event:

            >>> gamma = [g1, g2, g3, g4]
            >>> value = [v1, v2, v3, v4]
            >>> return = [
            ...     v1 + g1 v2 + g1 g2 v3 + g1 g2 g3 v4,
            ...     v2 + g2 v3 + g2 g3 v4,
            ...     v3 + g3 v4,
            ...     v4,
            ... ]

            if ``False``, it is assumed that each gamma is tied to the upcoming
            trajectory:

            >>> gamma = [g1, g2, g3, g4]
            >>> value = [v1, v2, v3, v4]
            >>> return = [
            ...     v1 + g1 v2 + g1**2 v3 + g1**3 v4,
            ...     v2 + g2 v3 + g2**2 v4,
            ...     v3 + g3 v4,
            ...     v4,
            ... ]

            Default is ``True``.
        time_dim (int): dimension where the time is unrolled. Defaults to -2.

    All tensors (values, reward and done) must have shape
    ``[*Batch x TimeSteps x *F]``, with ``*F`` feature dimensions.

    """
    if terminated is None:
        terminated = done.clone()
    if not (
        next_state_value.shape
        == state_value.shape
        == reward.shape
        == done.shape
        == terminated.shape
    ):
        raise RuntimeError(SHAPE_ERR)
    if not state_value.shape == next_state_value.shape:
        raise RuntimeError("shape of state_value and next_state_value must match")
    returns = td_lambda_return_estimate(
        gamma,
        lmbda,
        next_state_value,
        reward,
        done,
        terminated=terminated,
        rolling_gamma=rolling_gamma,
        time_dim=time_dim,
    )
    advantage = returns - state_value
    return advantage

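
# Illustrative usage sketch (not part of the torchrl source): calling
# td_lambda_advantage_estimate on a batch of trajectories laid out as
# [Batch x TimeSteps x F] with scalar gamma/lmbda. The helper name
# _example_td_lambda_advantage and the shapes below are hypothetical.
def _example_td_lambda_advantage():
    import torch

    B, T, F = 4, 10, 1
    state_value = torch.randn(B, T, F)
    next_state_value = torch.randn(B, T, F)
    reward = torch.randn(B, T, F)
    done = torch.zeros(B, T, F, dtype=torch.bool)
    done[:, -1] = True  # each trajectory ends at its last step
    advantage = td_lambda_advantage_estimate(
        gamma=0.99,
        lmbda=0.95,
        state_value=state_value,
        next_state_value=next_state_value,
        reward=reward,
        done=done,
        terminated=done.clone(),
    )
    assert advantage.shape == (B, T, F)
    return advantage
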
def _fast_td_lambda_return_estimate(
    gamma: torch.Tensor | float,
    lmbda: float,
    next_state_value: torch.Tensor,
    reward: torch.Tensor,
    done: torch.Tensor,
    terminated: torch.Tensor,
    thr: float = 1e-7,
):
    """Fast vectorized TD lambda return estimate.

    In contrast to the generalized `vec_td_lambda_return_estimate`, this function does not need
    to allocate a big tensor of the form [B, T, T], but it only works with gamma/lmbda being scalars.

    Args:
        gamma (scalar): the gamma decay, can be a tensor with a single element (trajectory discount)
        lmbda (scalar): the lambda decay (exponential mean discount)
        next_state_value (torch.Tensor): a [*B, T, F] tensor containing next state values (value function)
        reward (torch.Tensor): a [*B, T, F] tensor containing rewards
        done (Tensor): boolean flag for end of trajectory.
        terminated (Tensor): boolean flag for end of episode.
        thr (:obj:`float`): threshold for the filter. Below this limit, components will be ignored.
            Defaults to 1e-7.

    All tensors (values, reward and done) must have shape
    ``[*Batch x TimeSteps x F]``, with ``F`` feature dimensions.

    """
    device = reward.device
    done = done.transpose(-2, -1)
    terminated = terminated.transpose(-2, -1)
    reward = reward.transpose(-2, -1)
    next_state_value = next_state_value.transpose(-2, -1)

    # the only valid next states are those where the trajectory does not terminate
    next_state_value = (~terminated).int() * next_state_value

    # Use torch.full to create directly on device (avoids DeviceCopy in cudagraph)
    # Handle both scalar and single-element tensor gamma
    if isinstance(gamma, torch.Tensor):
        gamma_tensor = gamma.to(device).view(1)
    else:
        gamma_tensor = torch.full((1,), gamma, device=device)
    gammalmbda = gamma_tensor * lmbda

    num_per_traj = _get_num_per_traj(done)

    done = done.clone()
    done[..., -1] = 1
    not_done = (~done).int()

    t = reward + next_state_value * gamma_tensor * (1 - not_done * lmbda)

    t_flat, mask = _split_and_pad_sequence(t, num_per_traj, return_mask=True)

    gammalmbdas = _geom_series_like(t_flat[0], gammalmbda, thr=thr)

    ret_flat = _custom_conv1d(t_flat.unsqueeze(1), gammalmbdas)
    ret = ret_flat.squeeze(1)[mask]

    return ret.view_as(reward).transpose(-1, -2)

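
# Minimal identity sketch (not part of the torchrl source): for scalar gamma
# and lmbda on a single trajectory that ends at its last step, the TD(lambda)
# return obtained by backward recursion equals a discounted sum (a convolution)
# of t_k = r_k + gamma * (1 - lmbda) * V_{k+1}, with a full bootstrap at the
# final step. This is the identity the fast path above exploits through
# _geom_series_like/_custom_conv1d; the sketch below uses only plain torch and
# its name (_example_fast_td_lambda_identity) is hypothetical.
def _example_fast_td_lambda_identity(gamma: float = 0.95, lmbda: float = 0.9, T: int = 6):
    import torch

    torch.manual_seed(0)
    reward = torch.randn(T)
    next_value = torch.randn(T)  # next_value[t] plays the role of V_{t+1}

    # Backward recursion: G_t = r_t + gamma * ((1 - lmbda) * V_{t+1} + lmbda * G_{t+1}),
    # bootstrapping with V_T at the last step.
    g = next_value[-1]
    returns = torch.empty(T)
    for t in reversed(range(T)):
        g = reward[t] + gamma * ((1 - lmbda) * next_value[t] + lmbda * g)
        returns[t] = g

    # Convolution form: G_t = sum_j (gamma * lmbda)**j * t_{t+j}, where
    # t_k = r_k + gamma * (1 - lmbda) * V_{k+1}, except at the final step
    # where the bootstrap is full: t_{T-1} = r_{T-1} + gamma * V_T.
    t_k = reward + gamma * (1 - lmbda) * next_value
    t_k[-1] = reward[-1] + gamma * next_value[-1]
    conv_returns = torch.empty(T)
    for t in range(T):
        ks = torch.arange(T - t, dtype=torch.float32)
        conv_returns[t] = ((gamma * lmbda) ** ks * t_k[t:]).sum()

    assert torch.allclose(returns, conv_returns, atol=1e-5)
    return returns
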
@_transpose_time
def vec_td_lambda_return_estimate(
    gamma,
    lmbda,
    next_state_value,
    reward,
    done,
    terminated: torch.Tensor | None = None,
    rolling_gamma: bool | None = None,
    *,
    time_dim: int = -2,
):
    r"""Vectorized TD(:math:`\lambda`) return estimate.

    Args:
        gamma (scalar, Tensor): exponential mean discount. If tensor-valued,
            must be a [Batch x TimeSteps x 1] tensor.
        lmbda (scalar): trajectory discount.
        next_state_value (Tensor): value function result with new_state input.
            must be a [Batch x TimeSteps x 1] tensor
        reward (Tensor): reward of taking actions in the environment.
            must be a [Batch x TimeSteps x 1] or [Batch x TimeSteps] tensor
        done (Tensor): boolean flag for end of trajectory.
        terminated (Tensor): boolean flag for the end of episode. Defaults to ``done``
            if not provided.
        rolling_gamma (bool, optional): if ``True``, it is assumed that each gamma
            of a gamma tensor is tied to a single event:

            >>> gamma = [g1, g2, g3, g4]
            >>> value = [v1, v2, v3, v4]
            >>> return = [
            ...     v1 + g1 v2 + g1 g2 v3 + g1 g2 g3 v4,
            ...     v2 + g2 v3 + g2 g3 v4,
            ...     v3 + g3 v4,
            ...     v4,
            ... ]

            if ``False``, it is assumed that each gamma is tied to the upcoming
            trajectory:

            >>> gamma = [g1, g2, g3, g4]
            >>> value = [v1, v2, v3, v4]
            >>> return = [
            ...     v1 + g1 v2 + g1**2 v3 + g1**3 v4,
            ...     v2 + g2 v3 + g2**2 v4,
            ...     v3 + g3 v4,
            ...     v4,
            ... ]

            Default is ``True``.
        time_dim (int): dimension where the time is unrolled. Defaults to -2.

    All tensors (values, reward and done) must have shape
    ``[*Batch x TimeSteps x *F]``, with ``*F`` feature dimensions.

    """
    if terminated is None:
        terminated = done.clone()
    if not (next_state_value.shape == reward.shape == done.shape == terminated.shape):
        raise RuntimeError(SHAPE_ERR)

    gamma_thr = 1e-7
    shape = next_state_value.shape

    *batch, T, lastdim = shape

    def _is_scalar(tensor):
        return not isinstance(tensor, torch.Tensor) or tensor.numel() == 1

    # There are two use-cases: if gamma/lmbda are scalars we can use the
    # fast implementation, if not we must construct a gamma tensor.
    if _is_scalar(gamma) and _is_scalar(lmbda):
        return _fast_td_lambda_return_estimate(
            gamma=gamma,
            lmbda=lmbda,
            next_state_value=next_state_value,
            reward=reward,
            done=done,
            terminated=terminated,
            thr=gamma_thr,
        )

    next_state_value = next_state_value.transpose(-2, -1).unsqueeze(-2)
    if len(batch):
        next_state_value = next_state_value.flatten(0, len(batch))

    reward = reward.transpose(-2, -1).unsqueeze(-2)
    if len(batch):
        reward = reward.flatten(0, len(batch))

    """Vectorized version of td_lambda_advantage_estimate"""
    device = reward.device
    not_done = (~done).int()
    not_terminated = (~terminated).int().transpose(-2, -1).unsqueeze(-2)
    if len(batch):
        not_terminated = not_terminated.flatten(0, len(batch))
    next_state_value = next_state_value * not_terminated

    if rolling_gamma is None:
        rolling_gamma = True
    if not rolling_gamma and not is_dynamo_compiling():
        # Skip this validation during compile to avoid CUDA syncs
        terminated_follows_terminated = terminated[..., 1:, :][
            terminated[..., :-1, :]
        ].all()
        if not terminated_follows_terminated:
            raise NotImplementedError(
                "When using rolling_gamma=False and vectorized TD(lambda) with time-dependent gamma, "
                "make sure that consecutive trajectories are separated as different batch "
                "items. Propagating a gamma value across trajectories is not permitted with "
                "this method. Check that you need to use rolling_gamma=False, and if so "
                "consider using the non-vectorized version of the return computation or splitting "
                "your trajectories."
            )

    if rolling_gamma:
        # Make the coefficient table
        gammas = _make_gammas_tensor(gamma * not_done, T, rolling_gamma)
        gammas_cp = torch.cumprod(gammas, -2)
        lambdas = torch.ones(T + 1, 1, device=device)
        lambdas[1:] = lmbda
        lambdas_cp = torch.cumprod(lambdas, -2)
        lambdas = lambdas[1:]
        dec = gammas_cp * lambdas_cp

        gammas = _make_gammas_tensor(gamma, T, rolling_gamma)
        gammas = gammas[..., 1:, :]
        if gammas.ndimension() == 4 and gammas.shape[1] > 1:
            gammas = gammas[:, :1]
        if lambdas.ndimension() == 4 and lambdas.shape[1] > 1:
            lambdas = lambdas[:, :1]

        not_done = not_done.transpose(-2, -1).unsqueeze(-2)
        if len(batch):
            not_done = not_done.flatten(0, len(batch))
        # lambdas = lambdas * not_done

        v3 = (gammas * lambdas).squeeze(-1) * next_state_value * not_done
        v3[..., :-1] = 0
        out = _custom_conv1d(
            reward
            + gammas.squeeze(-1)
            * next_state_value
            * (1 - lambdas.squeeze(-1) * not_done)
            + v3,
            dec,
        )

        return out.view(*batch, lastdim, T).transpose(-2, -1)
    else:
        raise NotImplementedError(
            "The vectorized version of TD(lambda) with rolling_gamma=False is currently not available. "
            "To use this feature, use the non-vectorized version of TD(lambda). You can expect "
            "good speed improvements by decorating the function with torch.compile!"
        )

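
# Consistency sketch (not part of the torchrl source): for scalar gamma/lmbda
# the vectorized estimator above dispatches to _fast_td_lambda_return_estimate
# and should match the element-wise td_lambda_return_estimate defined earlier
# in this file up to numerical tolerance. The helper name
# _example_vec_td_lambda_consistency is hypothetical.
def _example_vec_td_lambda_consistency():
    import torch

    B, T, F = 3, 7, 1
    next_state_value = torch.randn(B, T, F)
    reward = torch.randn(B, T, F)
    done = torch.zeros(B, T, F, dtype=torch.bool)
    done[:, -1] = True
    terminated = done.clone()

    ret_vec = vec_td_lambda_return_estimate(
        0.99, 0.95, next_state_value, reward, done, terminated=terminated
    )
    ret_loop = td_lambda_return_estimate(
        0.99, 0.95, next_state_value, reward, done, terminated=terminated
    )
    assert torch.allclose(ret_vec, ret_loop, atol=1e-4, rtol=1e-4)
    return ret_vec
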
def vec_td_lambda_advantage_estimate(
    gamma,
    lmbda,
    state_value,
    next_state_value,
    reward,
    done,
    terminated: torch.Tensor | None = None,
    rolling_gamma: bool | None = None,
    # not a kwarg because used directly
    time_dim: int = -2,
):
    r"""Vectorized TD(:math:`\lambda`) advantage estimate.

    Args:
        gamma (scalar, Tensor): exponential mean discount. If tensor-valued,
            must be a [Batch x TimeSteps x 1] tensor.
        lmbda (scalar): trajectory discount.
        state_value (Tensor): value function result with old_state input.
        next_state_value (Tensor): value function result with new_state input.
        reward (Tensor): reward of taking actions in the environment.
        done (Tensor): boolean flag for end of trajectory.
        terminated (Tensor): boolean flag for the end of episode. Defaults to ``done``
            if not provided.
        rolling_gamma (bool, optional): if ``True``, it is assumed that each gamma
            of a gamma tensor is tied to a single event:

            >>> gamma = [g1, g2, g3, g4]
            >>> value = [v1, v2, v3, v4]
            >>> return = [
            ...     v1 + g1 v2 + g1 g2 v3 + g1 g2 g3 v4,
            ...     v2 + g2 v3 + g2 g3 v4,
            ...     v3 + g3 v4,
            ...     v4,
            ... ]

            if ``False``, it is assumed that each gamma is tied to the upcoming
            trajectory:

            >>> gamma = [g1, g2, g3, g4]
            >>> value = [v1, v2, v3, v4]
            >>> return = [
            ...     v1 + g1 v2 + g1**2 v3 + g1**3 v4,
            ...     v2 + g2 v3 + g2**2 v4,
            ...     v3 + g3 v4,
            ...     v4,
            ... ]

            Default is ``True``.
        time_dim (int): dimension where the time is unrolled. Defaults to -2.

    All tensors (values, reward and done) must have shape
    ``[*Batch x TimeSteps x *F]``, with ``*F`` feature dimensions.

    """
    if terminated is None:
        terminated = done.clone()
    if not (
        next_state_value.shape
        == state_value.shape
        == reward.shape
        == done.shape
        == terminated.shape
    ):
        raise RuntimeError(SHAPE_ERR)
    return (
        vec_td_lambda_return_estimate(
            gamma,
            lmbda,
            next_state_value,
            reward,
            done=done,
            terminated=terminated,
            rolling_gamma=rolling_gamma,
            time_dim=time_dim,
        )
        - state_value
    )

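
# Illustrative usage sketch (not part of the torchrl source): the vectorized
# advantage is simply the vectorized TD(lambda) return minus the state value,
# as implemented above. Shapes follow [Batch x TimeSteps x F]; the helper name
# _example_vec_td_lambda_advantage is hypothetical.
def _example_vec_td_lambda_advantage():
    import torch

    B, T, F = 4, 10, 1
    state_value = torch.randn(B, T, F)
    next_state_value = torch.randn(B, T, F)
    reward = torch.randn(B, T, F)
    done = torch.zeros(B, T, F, dtype=torch.bool)
    done[:, -1] = True
    adv = vec_td_lambda_advantage_estimate(
        gamma=0.99,
        lmbda=0.95,
        state_value=state_value,
        next_state_value=next_state_value,
        reward=reward,
        done=done,
        terminated=done.clone(),
    )
    assert adv.shape == (B, T, F)
    return adv
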
########################################################################
# V-Trace
# -----


@_transpose_time
def vtrace_advantage_estimate(
    gamma: float,
    log_pi: torch.Tensor,
    log_mu: torch.Tensor,
    state_value: torch.Tensor,
    next_state_value: torch.Tensor,
    reward: torch.Tensor,
    done: torch.Tensor,
    terminated: torch.Tensor | None = None,
    rho_thresh: float | torch.Tensor = 1.0,
    c_thresh: float | torch.Tensor = 1.0,
    # not a kwarg because used directly
    time_dim: int = -2,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Computes V-Trace off-policy actor critic targets.

    Refer to "IMPALA: Scalable Distributed Deep-RL with Importance Weighted Actor-Learner Architectures"
    https://arxiv.org/abs/1802.01561 for more context.

    Args:
        gamma (scalar): exponential mean discount.
        log_pi (Tensor): current (learner) actor log probability of taking actions in the environment.
        log_mu (Tensor): collection (behavior) actor log probability of taking actions in the environment.
        state_value (Tensor): value function result with state input.
        next_state_value (Tensor): value function result with next_state input.
        reward (Tensor): reward of taking actions in the environment.
        done (Tensor): boolean flag for end of episode.
        terminated (torch.Tensor): a [B, T] boolean tensor containing the terminated states.
        rho_thresh (Union[float, Tensor]): rho clipping parameter for importance weights.
        c_thresh (Union[float, Tensor]): c clipping parameter for importance weights.
        time_dim (int): dimension where the time is unrolled. Defaults to -2.

    All tensors (values, reward and done) must have shape
    ``[*Batch x TimeSteps x *F]``, with ``*F`` feature dimensions.
    """
    if not (next_state_value.shape == state_value.shape == reward.shape == done.shape):
        raise RuntimeError(SHAPE_ERR)

    device = state_value.device

    if not isinstance(rho_thresh, torch.Tensor):
        rho_thresh = torch.tensor(rho_thresh, device=device)
    if not isinstance(c_thresh, torch.Tensor):
        c_thresh = torch.tensor(c_thresh, device=device)

    c_thresh = c_thresh.to(device)
    rho_thresh = rho_thresh.to(device)

    not_done = (~done).int()
    not_terminated = not_done if terminated is None else (~terminated).int()
    *batch_size, time_steps, lastdim = not_done.shape
    done_discounts = gamma * not_done
    terminated_discounts = gamma * not_terminated

    rho = (log_pi - log_mu).exp()
    clipped_rho = rho.clamp_max(rho_thresh)
    deltas = clipped_rho * (
        reward + terminated_discounts * next_state_value - state_value
    )
    clipped_c = rho.clamp_max(c_thresh)

    vs_minus_v_xs = [torch.zeros_like(next_state_value[..., -1, :])]
    for i in reversed(range(time_steps)):
        discount_t, c_t, delta_t = (
            done_discounts[..., i, :],
            clipped_c[..., i, :],
            deltas[..., i, :],
        )
        vs_minus_v_xs.append(delta_t + discount_t * c_t * vs_minus_v_xs[-1])
    vs_minus_v_xs = torch.stack(vs_minus_v_xs[1:], dim=time_dim)
    vs_minus_v_xs = torch.flip(vs_minus_v_xs, dims=[time_dim])
    vs = vs_minus_v_xs + state_value
    vs_t_plus_1 = torch.cat(
        [vs[..., 1:, :], next_state_value[..., -1:, :]], dim=time_dim
    )
    advantages = clipped_rho * (
        reward + terminated_discounts * vs_t_plus_1 - state_value
    )

    return advantages, vs

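
# Illustrative usage sketch (not part of the torchrl source): V-trace targets
# for a batch of off-policy trajectories. log_pi holds the per-step action
# log-probabilities under the current (learner) policy, log_mu those recorded
# at collection time; with rho_thresh=c_thresh=1.0 the importance weights are
# clipped at 1, as in the IMPALA defaults. The helper name _example_vtrace is
# hypothetical.
def _example_vtrace():
    import torch

    B, T, F = 2, 8, 1
    state_value = torch.randn(B, T, F)
    next_state_value = torch.randn(B, T, F)
    reward = torch.randn(B, T, F)
    done = torch.zeros(B, T, F, dtype=torch.bool)
    done[:, -1] = True
    log_pi = torch.randn(B, T, F).clamp(-5.0, 0.0)  # learner policy log-probs
    log_mu = torch.randn(B, T, F).clamp(-5.0, 0.0)  # behavior policy log-probs

    advantages, vs = vtrace_advantage_estimate(
        gamma=0.99,
        log_pi=log_pi,
        log_mu=log_mu,
        state_value=state_value,
        next_state_value=next_state_value,
        reward=reward,
        done=done,
        terminated=done.clone(),
    )
    assert advantages.shape == vs.shape == (B, T, F)
    return advantages, vs
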
########################################################################
# Reward to go
# ------------


@_transpose_time
def reward2go(
    reward,
    done,
    gamma,
    *,
    time_dim: int = -2,
):
    """Compute the discounted cumulative sum of rewards given multiple trajectories and the episode ends.

    Args:
        reward (torch.Tensor): A tensor containing the rewards
            received at each time step over multiple trajectories.
        done (Tensor): boolean flag for end of episode. Differs from
            truncated, where the episode did not end but was interrupted.
        gamma (:obj:`float`): The discount factor to use for computing the
            discounted cumulative sum of rewards.
        time_dim (int): dimension where the time is unrolled. Defaults to -2.

    Returns:
        torch.Tensor: A tensor of shape [B, T] containing the discounted cumulative
            sum of rewards (reward-to-go) at each time step.

    Examples:
        >>> reward = torch.ones(1, 10)
        >>> done = torch.zeros(1, 10, dtype=torch.bool)
        >>> done[:, [3, 7]] = True
        >>> reward2go(reward, done, 0.99, time_dim=-1)
        tensor([[3.9404],
                [2.9701],
                [1.9900],
                [1.0000],
                [3.9404],
                [2.9701],
                [1.9900],
                [1.0000],
                [1.9900],
                [1.0000]])

    """
    shape = reward.shape
    if shape != done.shape:
        raise ValueError(
            f"reward and done must share the same shape, got {reward.shape} and {done.shape}"
        )
    # flatten if needed
    if reward.ndim > 2:
        # we know time dim is at -2, let's put it at -3
        rflip = reward.transpose(-2, -3)
        rflip_shape = rflip.shape[-2:]
        r2go = reward2go(
            rflip.flatten(-2, -1), done.transpose(-2, -3).flatten(-2, -1), gamma=gamma
        ).unflatten(-1, rflip_shape)
        return r2go.transpose(-2, -3)

    # place time at dim -1
    reward = reward.transpose(-2, -1)
    done = done.transpose(-2, -1)

    num_per_traj = _get_num_per_traj(done)
    td0_flat = _split_and_pad_sequence(reward, num_per_traj)
    gammas = _geom_series_like(td0_flat[0], gamma, thr=1e-7)
    cumsum = _custom_conv1d(td0_flat.unsqueeze(1), gammas)
    cumsum = cumsum.squeeze(1)
    cumsum = _inv_pad_sequence(cumsum, num_per_traj)
    cumsum = cumsum.reshape_as(reward)
    cumsum = cumsum.transpose(-2, -1)
    if cumsum.shape != shape:
        try:
            cumsum = cumsum.reshape(shape)
        except RuntimeError:
            raise RuntimeError(
                f"Wrong shape for output reward2go: {cumsum.shape} when {shape} was expected."
            )
    return cumsum
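

# Usage sketch (not part of the torchrl source): reward2go with the default
# time_dim=-2 and an explicit feature dimension, i.e. the
# [Batch x TimeSteps x F] layout used throughout this file. With constant unit
# rewards, the reward-to-go at step t of a length-T trajectory is the
# geometric sum 1 + gamma + ... + gamma**(T - 1 - t). The helper name
# _example_reward2go_batched is hypothetical.
def _example_reward2go_batched(gamma: float = 0.9):
    import torch

    B, T, F = 2, 6, 1
    reward = torch.ones(B, T, F)
    done = torch.zeros(B, T, F, dtype=torch.bool)
    done[:, -1] = True  # one trajectory of length T per batch element
    r2g = reward2go(reward, done, gamma)
    assert r2g.shape == (B, T, F)
    # Last step: only its own reward is left.
    assert torch.allclose(r2g[:, -1], torch.ones(B, F))
    # First step: full geometric sum over the trajectory.
    expected_first = (1 - gamma**T) / (1 - gamma)
    assert torch.allclose(r2g[:, 0], torch.full((B, F), expected_first))
    return r2g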