torchrl-0.11.0-cp314-cp314t-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/benchmark_batched_envs.py +104 -0
- benchmarks/conftest.py +91 -0
- benchmarks/ecosystem/gym_env_throughput.py +321 -0
- benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
- benchmarks/requirements.txt +7 -0
- benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
- benchmarks/test_collectors_benchmark.py +240 -0
- benchmarks/test_compressed_storage_benchmark.py +145 -0
- benchmarks/test_envs_benchmark.py +133 -0
- benchmarks/test_llm.py +101 -0
- benchmarks/test_non_tensor_env_benchmark.py +70 -0
- benchmarks/test_objectives_benchmarks.py +1199 -0
- benchmarks/test_replaybuffer_benchmark.py +254 -0
- sota-check/README.md +35 -0
- sota-implementations/README.md +142 -0
- sota-implementations/a2c/README.md +39 -0
- sota-implementations/a2c/a2c_atari.py +291 -0
- sota-implementations/a2c/a2c_mujoco.py +273 -0
- sota-implementations/a2c/utils_atari.py +240 -0
- sota-implementations/a2c/utils_mujoco.py +160 -0
- sota-implementations/bandits/README.md +7 -0
- sota-implementations/bandits/dqn.py +126 -0
- sota-implementations/cql/cql_offline.py +198 -0
- sota-implementations/cql/cql_online.py +249 -0
- sota-implementations/cql/discrete_cql_offline.py +180 -0
- sota-implementations/cql/discrete_cql_online.py +227 -0
- sota-implementations/cql/utils.py +471 -0
- sota-implementations/crossq/crossq.py +271 -0
- sota-implementations/crossq/utils.py +320 -0
- sota-implementations/ddpg/ddpg.py +231 -0
- sota-implementations/ddpg/utils.py +325 -0
- sota-implementations/decision_transformer/dt.py +163 -0
- sota-implementations/decision_transformer/lamb.py +167 -0
- sota-implementations/decision_transformer/online_dt.py +178 -0
- sota-implementations/decision_transformer/utils.py +562 -0
- sota-implementations/discrete_sac/discrete_sac.py +243 -0
- sota-implementations/discrete_sac/utils.py +324 -0
- sota-implementations/dqn/README.md +30 -0
- sota-implementations/dqn/dqn_atari.py +272 -0
- sota-implementations/dqn/dqn_cartpole.py +236 -0
- sota-implementations/dqn/utils_atari.py +132 -0
- sota-implementations/dqn/utils_cartpole.py +90 -0
- sota-implementations/dreamer/README.md +129 -0
- sota-implementations/dreamer/dreamer.py +586 -0
- sota-implementations/dreamer/dreamer_utils.py +1107 -0
- sota-implementations/expert-iteration/README.md +352 -0
- sota-implementations/expert-iteration/ei_utils.py +770 -0
- sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
- sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
- sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
- sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
- sota-implementations/gail/gail.py +327 -0
- sota-implementations/gail/gail_utils.py +68 -0
- sota-implementations/gail/ppo_utils.py +157 -0
- sota-implementations/grpo/README.md +273 -0
- sota-implementations/grpo/grpo-async.py +437 -0
- sota-implementations/grpo/grpo-sync.py +435 -0
- sota-implementations/grpo/grpo_utils.py +843 -0
- sota-implementations/grpo/requirements_gsm8k.txt +11 -0
- sota-implementations/grpo/requirements_ifeval.txt +16 -0
- sota-implementations/impala/README.md +33 -0
- sota-implementations/impala/impala_multi_node_ray.py +292 -0
- sota-implementations/impala/impala_multi_node_submitit.py +284 -0
- sota-implementations/impala/impala_single_node.py +261 -0
- sota-implementations/impala/utils.py +184 -0
- sota-implementations/iql/discrete_iql.py +230 -0
- sota-implementations/iql/iql_offline.py +164 -0
- sota-implementations/iql/iql_online.py +225 -0
- sota-implementations/iql/utils.py +437 -0
- sota-implementations/multiagent/README.md +74 -0
- sota-implementations/multiagent/iql.py +237 -0
- sota-implementations/multiagent/maddpg_iddpg.py +266 -0
- sota-implementations/multiagent/mappo_ippo.py +267 -0
- sota-implementations/multiagent/qmix_vdn.py +271 -0
- sota-implementations/multiagent/sac.py +337 -0
- sota-implementations/multiagent/utils/__init__.py +4 -0
- sota-implementations/multiagent/utils/logging.py +151 -0
- sota-implementations/multiagent/utils/utils.py +43 -0
- sota-implementations/ppo/README.md +29 -0
- sota-implementations/ppo/ppo_atari.py +305 -0
- sota-implementations/ppo/ppo_mujoco.py +293 -0
- sota-implementations/ppo/utils_atari.py +238 -0
- sota-implementations/ppo/utils_mujoco.py +152 -0
- sota-implementations/ppo_trainer/train.py +21 -0
- sota-implementations/redq/README.md +7 -0
- sota-implementations/redq/redq.py +199 -0
- sota-implementations/redq/utils.py +1060 -0
- sota-implementations/sac/sac-async.py +266 -0
- sota-implementations/sac/sac.py +239 -0
- sota-implementations/sac/utils.py +381 -0
- sota-implementations/sac_trainer/train.py +16 -0
- sota-implementations/td3/td3.py +254 -0
- sota-implementations/td3/utils.py +319 -0
- sota-implementations/td3_bc/td3_bc.py +177 -0
- sota-implementations/td3_bc/utils.py +251 -0
- torchrl/__init__.py +144 -0
- torchrl/_extension.py +74 -0
- torchrl/_torchrl.cp314t-win_amd64.pyd +0 -0
- torchrl/_utils.py +1431 -0
- torchrl/collectors/__init__.py +48 -0
- torchrl/collectors/_base.py +1058 -0
- torchrl/collectors/_constants.py +88 -0
- torchrl/collectors/_multi_async.py +324 -0
- torchrl/collectors/_multi_base.py +1805 -0
- torchrl/collectors/_multi_sync.py +464 -0
- torchrl/collectors/_runner.py +581 -0
- torchrl/collectors/_single.py +2009 -0
- torchrl/collectors/_single_async.py +259 -0
- torchrl/collectors/collectors.py +62 -0
- torchrl/collectors/distributed/__init__.py +32 -0
- torchrl/collectors/distributed/default_configs.py +133 -0
- torchrl/collectors/distributed/generic.py +1306 -0
- torchrl/collectors/distributed/ray.py +1092 -0
- torchrl/collectors/distributed/rpc.py +1006 -0
- torchrl/collectors/distributed/sync.py +731 -0
- torchrl/collectors/distributed/utils.py +160 -0
- torchrl/collectors/llm/__init__.py +10 -0
- torchrl/collectors/llm/base.py +494 -0
- torchrl/collectors/llm/ray_collector.py +275 -0
- torchrl/collectors/llm/utils.py +36 -0
- torchrl/collectors/llm/weight_update/__init__.py +10 -0
- torchrl/collectors/llm/weight_update/vllm.py +348 -0
- torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
- torchrl/collectors/utils.py +433 -0
- torchrl/collectors/weight_update.py +591 -0
- torchrl/csrc/numpy_utils.h +38 -0
- torchrl/csrc/pybind.cpp +27 -0
- torchrl/csrc/segment_tree.h +458 -0
- torchrl/csrc/torch_utils.h +34 -0
- torchrl/csrc/utils.cpp +48 -0
- torchrl/csrc/utils.h +31 -0
- torchrl/data/__init__.py +187 -0
- torchrl/data/datasets/__init__.py +58 -0
- torchrl/data/datasets/atari_dqn.py +878 -0
- torchrl/data/datasets/common.py +281 -0
- torchrl/data/datasets/d4rl.py +489 -0
- torchrl/data/datasets/d4rl_infos.py +187 -0
- torchrl/data/datasets/gen_dgrl.py +375 -0
- torchrl/data/datasets/minari_data.py +643 -0
- torchrl/data/datasets/openml.py +177 -0
- torchrl/data/datasets/openx.py +798 -0
- torchrl/data/datasets/roboset.py +363 -0
- torchrl/data/datasets/utils.py +11 -0
- torchrl/data/datasets/vd4rl.py +432 -0
- torchrl/data/llm/__init__.py +34 -0
- torchrl/data/llm/dataset.py +491 -0
- torchrl/data/llm/history.py +1378 -0
- torchrl/data/llm/prompt.py +198 -0
- torchrl/data/llm/reward.py +225 -0
- torchrl/data/llm/topk.py +186 -0
- torchrl/data/llm/utils.py +543 -0
- torchrl/data/map/__init__.py +21 -0
- torchrl/data/map/hash.py +185 -0
- torchrl/data/map/query.py +204 -0
- torchrl/data/map/tdstorage.py +363 -0
- torchrl/data/map/tree.py +1434 -0
- torchrl/data/map/utils.py +103 -0
- torchrl/data/postprocs/__init__.py +8 -0
- torchrl/data/postprocs/postprocs.py +391 -0
- torchrl/data/replay_buffers/__init__.py +99 -0
- torchrl/data/replay_buffers/checkpointers.py +622 -0
- torchrl/data/replay_buffers/ray_buffer.py +292 -0
- torchrl/data/replay_buffers/replay_buffers.py +2376 -0
- torchrl/data/replay_buffers/samplers.py +2578 -0
- torchrl/data/replay_buffers/scheduler.py +265 -0
- torchrl/data/replay_buffers/storages.py +2412 -0
- torchrl/data/replay_buffers/utils.py +1042 -0
- torchrl/data/replay_buffers/writers.py +781 -0
- torchrl/data/tensor_specs.py +7101 -0
- torchrl/data/utils.py +334 -0
- torchrl/envs/__init__.py +265 -0
- torchrl/envs/async_envs.py +1105 -0
- torchrl/envs/batched_envs.py +3093 -0
- torchrl/envs/common.py +4241 -0
- torchrl/envs/custom/__init__.py +11 -0
- torchrl/envs/custom/chess.py +617 -0
- torchrl/envs/custom/llm.py +214 -0
- torchrl/envs/custom/pendulum.py +401 -0
- torchrl/envs/custom/san_moves.txt +29274 -0
- torchrl/envs/custom/tictactoeenv.py +288 -0
- torchrl/envs/env_creator.py +263 -0
- torchrl/envs/gym_like.py +752 -0
- torchrl/envs/libs/__init__.py +68 -0
- torchrl/envs/libs/_gym_utils.py +326 -0
- torchrl/envs/libs/brax.py +846 -0
- torchrl/envs/libs/dm_control.py +544 -0
- torchrl/envs/libs/envpool.py +447 -0
- torchrl/envs/libs/gym.py +2239 -0
- torchrl/envs/libs/habitat.py +138 -0
- torchrl/envs/libs/isaac_lab.py +87 -0
- torchrl/envs/libs/isaacgym.py +203 -0
- torchrl/envs/libs/jax_utils.py +166 -0
- torchrl/envs/libs/jumanji.py +963 -0
- torchrl/envs/libs/meltingpot.py +599 -0
- torchrl/envs/libs/openml.py +153 -0
- torchrl/envs/libs/openspiel.py +652 -0
- torchrl/envs/libs/pettingzoo.py +1042 -0
- torchrl/envs/libs/procgen.py +351 -0
- torchrl/envs/libs/robohive.py +429 -0
- torchrl/envs/libs/smacv2.py +645 -0
- torchrl/envs/libs/unity_mlagents.py +891 -0
- torchrl/envs/libs/utils.py +147 -0
- torchrl/envs/libs/vmas.py +813 -0
- torchrl/envs/llm/__init__.py +63 -0
- torchrl/envs/llm/chat.py +730 -0
- torchrl/envs/llm/datasets/README.md +4 -0
- torchrl/envs/llm/datasets/__init__.py +17 -0
- torchrl/envs/llm/datasets/gsm8k.py +353 -0
- torchrl/envs/llm/datasets/ifeval.py +274 -0
- torchrl/envs/llm/envs.py +789 -0
- torchrl/envs/llm/libs/README.md +3 -0
- torchrl/envs/llm/libs/__init__.py +8 -0
- torchrl/envs/llm/libs/mlgym.py +869 -0
- torchrl/envs/llm/reward/__init__.py +10 -0
- torchrl/envs/llm/reward/gsm8k.py +324 -0
- torchrl/envs/llm/reward/ifeval/README.md +13 -0
- torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
- torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
- torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
- torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
- torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
- torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
- torchrl/envs/llm/transforms/__init__.py +55 -0
- torchrl/envs/llm/transforms/browser.py +292 -0
- torchrl/envs/llm/transforms/dataloading.py +859 -0
- torchrl/envs/llm/transforms/format.py +73 -0
- torchrl/envs/llm/transforms/kl.py +1544 -0
- torchrl/envs/llm/transforms/policy_version.py +189 -0
- torchrl/envs/llm/transforms/reason.py +323 -0
- torchrl/envs/llm/transforms/tokenizer.py +321 -0
- torchrl/envs/llm/transforms/tools.py +1955 -0
- torchrl/envs/model_based/__init__.py +9 -0
- torchrl/envs/model_based/common.py +180 -0
- torchrl/envs/model_based/dreamer.py +112 -0
- torchrl/envs/transforms/__init__.py +147 -0
- torchrl/envs/transforms/functional.py +48 -0
- torchrl/envs/transforms/gym_transforms.py +203 -0
- torchrl/envs/transforms/module.py +341 -0
- torchrl/envs/transforms/r3m.py +372 -0
- torchrl/envs/transforms/ray_service.py +663 -0
- torchrl/envs/transforms/rb_transforms.py +214 -0
- torchrl/envs/transforms/transforms.py +11835 -0
- torchrl/envs/transforms/utils.py +94 -0
- torchrl/envs/transforms/vc1.py +307 -0
- torchrl/envs/transforms/vecnorm.py +845 -0
- torchrl/envs/transforms/vip.py +407 -0
- torchrl/envs/utils.py +1718 -0
- torchrl/envs/vec_envs.py +11 -0
- torchrl/modules/__init__.py +206 -0
- torchrl/modules/distributions/__init__.py +73 -0
- torchrl/modules/distributions/continuous.py +830 -0
- torchrl/modules/distributions/discrete.py +908 -0
- torchrl/modules/distributions/truncated_normal.py +187 -0
- torchrl/modules/distributions/utils.py +233 -0
- torchrl/modules/llm/__init__.py +62 -0
- torchrl/modules/llm/backends/__init__.py +65 -0
- torchrl/modules/llm/backends/vllm/__init__.py +94 -0
- torchrl/modules/llm/backends/vllm/_models.py +46 -0
- torchrl/modules/llm/backends/vllm/base.py +72 -0
- torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
- torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
- torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
- torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
- torchrl/modules/llm/policies/__init__.py +28 -0
- torchrl/modules/llm/policies/common.py +1809 -0
- torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
- torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
- torchrl/modules/llm/utils.py +23 -0
- torchrl/modules/mcts/__init__.py +21 -0
- torchrl/modules/mcts/scores.py +579 -0
- torchrl/modules/models/__init__.py +86 -0
- torchrl/modules/models/batchrenorm.py +119 -0
- torchrl/modules/models/decision_transformer.py +179 -0
- torchrl/modules/models/exploration.py +731 -0
- torchrl/modules/models/llm.py +156 -0
- torchrl/modules/models/model_based.py +596 -0
- torchrl/modules/models/models.py +1712 -0
- torchrl/modules/models/multiagent.py +1067 -0
- torchrl/modules/models/recipes/impala.py +185 -0
- torchrl/modules/models/utils.py +162 -0
- torchrl/modules/planners/__init__.py +10 -0
- torchrl/modules/planners/cem.py +228 -0
- torchrl/modules/planners/common.py +73 -0
- torchrl/modules/planners/mppi.py +265 -0
- torchrl/modules/tensordict_module/__init__.py +89 -0
- torchrl/modules/tensordict_module/actors.py +2457 -0
- torchrl/modules/tensordict_module/common.py +529 -0
- torchrl/modules/tensordict_module/exploration.py +814 -0
- torchrl/modules/tensordict_module/probabilistic.py +321 -0
- torchrl/modules/tensordict_module/rnn.py +1639 -0
- torchrl/modules/tensordict_module/sequence.py +132 -0
- torchrl/modules/tensordict_module/world_models.py +34 -0
- torchrl/modules/utils/__init__.py +38 -0
- torchrl/modules/utils/mappings.py +9 -0
- torchrl/modules/utils/utils.py +89 -0
- torchrl/objectives/__init__.py +78 -0
- torchrl/objectives/a2c.py +659 -0
- torchrl/objectives/common.py +753 -0
- torchrl/objectives/cql.py +1346 -0
- torchrl/objectives/crossq.py +710 -0
- torchrl/objectives/ddpg.py +453 -0
- torchrl/objectives/decision_transformer.py +371 -0
- torchrl/objectives/deprecated.py +516 -0
- torchrl/objectives/dqn.py +683 -0
- torchrl/objectives/dreamer.py +488 -0
- torchrl/objectives/functional.py +48 -0
- torchrl/objectives/gail.py +258 -0
- torchrl/objectives/iql.py +996 -0
- torchrl/objectives/llm/__init__.py +30 -0
- torchrl/objectives/llm/grpo.py +846 -0
- torchrl/objectives/llm/sft.py +482 -0
- torchrl/objectives/multiagent/__init__.py +8 -0
- torchrl/objectives/multiagent/qmixer.py +396 -0
- torchrl/objectives/ppo.py +1669 -0
- torchrl/objectives/redq.py +683 -0
- torchrl/objectives/reinforce.py +530 -0
- torchrl/objectives/sac.py +1580 -0
- torchrl/objectives/td3.py +570 -0
- torchrl/objectives/td3_bc.py +625 -0
- torchrl/objectives/utils.py +782 -0
- torchrl/objectives/value/__init__.py +28 -0
- torchrl/objectives/value/advantages.py +1956 -0
- torchrl/objectives/value/functional.py +1459 -0
- torchrl/objectives/value/utils.py +360 -0
- torchrl/record/__init__.py +17 -0
- torchrl/record/loggers/__init__.py +23 -0
- torchrl/record/loggers/common.py +48 -0
- torchrl/record/loggers/csv.py +226 -0
- torchrl/record/loggers/mlflow.py +142 -0
- torchrl/record/loggers/tensorboard.py +139 -0
- torchrl/record/loggers/trackio.py +163 -0
- torchrl/record/loggers/utils.py +78 -0
- torchrl/record/loggers/wandb.py +214 -0
- torchrl/record/recorder.py +554 -0
- torchrl/services/__init__.py +79 -0
- torchrl/services/base.py +109 -0
- torchrl/services/ray_service.py +453 -0
- torchrl/testing/__init__.py +107 -0
- torchrl/testing/assertions.py +179 -0
- torchrl/testing/dist_utils.py +122 -0
- torchrl/testing/env_creators.py +227 -0
- torchrl/testing/env_helper.py +35 -0
- torchrl/testing/gym_helpers.py +156 -0
- torchrl/testing/llm_mocks.py +119 -0
- torchrl/testing/mocking_classes.py +2720 -0
- torchrl/testing/modules.py +295 -0
- torchrl/testing/mp_helpers.py +15 -0
- torchrl/testing/ray_helpers.py +293 -0
- torchrl/testing/utils.py +190 -0
- torchrl/trainers/__init__.py +42 -0
- torchrl/trainers/algorithms/__init__.py +11 -0
- torchrl/trainers/algorithms/configs/__init__.py +705 -0
- torchrl/trainers/algorithms/configs/collectors.py +216 -0
- torchrl/trainers/algorithms/configs/common.py +41 -0
- torchrl/trainers/algorithms/configs/data.py +308 -0
- torchrl/trainers/algorithms/configs/envs.py +104 -0
- torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
- torchrl/trainers/algorithms/configs/logging.py +80 -0
- torchrl/trainers/algorithms/configs/modules.py +570 -0
- torchrl/trainers/algorithms/configs/objectives.py +177 -0
- torchrl/trainers/algorithms/configs/trainers.py +340 -0
- torchrl/trainers/algorithms/configs/transforms.py +955 -0
- torchrl/trainers/algorithms/configs/utils.py +252 -0
- torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
- torchrl/trainers/algorithms/configs/weight_update.py +159 -0
- torchrl/trainers/algorithms/ppo.py +373 -0
- torchrl/trainers/algorithms/sac.py +308 -0
- torchrl/trainers/helpers/__init__.py +40 -0
- torchrl/trainers/helpers/collectors.py +416 -0
- torchrl/trainers/helpers/envs.py +573 -0
- torchrl/trainers/helpers/logger.py +33 -0
- torchrl/trainers/helpers/losses.py +132 -0
- torchrl/trainers/helpers/models.py +658 -0
- torchrl/trainers/helpers/replay_buffer.py +59 -0
- torchrl/trainers/helpers/trainers.py +301 -0
- torchrl/trainers/trainers.py +2052 -0
- torchrl/weight_update/__init__.py +33 -0
- torchrl/weight_update/_distributed.py +749 -0
- torchrl/weight_update/_mp.py +624 -0
- torchrl/weight_update/_noupdate.py +102 -0
- torchrl/weight_update/_ray.py +1032 -0
- torchrl/weight_update/_rpc.py +284 -0
- torchrl/weight_update/_shared.py +891 -0
- torchrl/weight_update/llm/__init__.py +32 -0
- torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
- torchrl/weight_update/llm/vllm_nccl.py +710 -0
- torchrl/weight_update/utils.py +73 -0
- torchrl/weight_update/weight_sync_schemes.py +1244 -0
- torchrl-0.11.0.dist-info/LICENSE +21 -0
- torchrl-0.11.0.dist-info/METADATA +1307 -0
- torchrl-0.11.0.dist-info/RECORD +394 -0
- torchrl-0.11.0.dist-info/WHEEL +5 -0
- torchrl-0.11.0.dist-info/entry_points.txt +2 -0
- torchrl-0.11.0.dist-info/top_level.txt +7 -0
torchrl/objectives/dreamer.py
@@ -0,0 +1,488 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+from __future__ import annotations
+
+from dataclasses import dataclass
+
+import torch
+from tensordict import TensorDict
+from tensordict.nn import TensorDictModule
+from tensordict.utils import NestedKey
+
+from torchrl._utils import _maybe_record_function_decorator, _maybe_timeit
+from torchrl.envs.model_based.dreamer import DreamerEnv
+from torchrl.envs.utils import ExplorationType, set_exploration_type, step_mdp
+from torchrl.objectives.common import LossModule
+from torchrl.objectives.utils import (
+    _GAMMA_LMBDA_DEPREC_ERROR,
+    default_value_kwargs,
+    distance_loss,
+    hold_out_net,
+    ValueEstimators,
+)
+from torchrl.objectives.value import (
+    TD0Estimator,
+    TD1Estimator,
+    TDLambdaEstimator,
+    ValueEstimatorBase,
+)
+
+
+class DreamerModelLoss(LossModule):
+    """Dreamer Model Loss.
+
+    Computes the loss of the Dreamer world model. The loss is composed of the
+    KL divergence between the prior and posterior of the RSSM,
+    the reconstruction loss over the reconstructed observation and the reward
+    loss over the predicted reward.
+
+    Reference: https://arxiv.org/abs/1912.01603.
+
+    Args:
+        world_model (TensorDictModule): the world model.
+        lambda_kl (:obj:`float`, optional): the weight of the KL divergence loss. Default: 1.0.
+        lambda_reco (:obj:`float`, optional): the weight of the reconstruction loss. Default: 1.0.
+        lambda_reward (:obj:`float`, optional): the weight of the reward loss. Default: 1.0.
+        reco_loss (str, optional): the reconstruction loss. Default: "l2".
+        reward_loss (str, optional): the reward loss. Default: "l2".
+        free_nats (int, optional): the free nats. Default: 3.
+        delayed_clamp (bool, optional): if ``True``, the KL clamping occurs after
+            averaging. If ``False`` (default), the KL divergence is clamped to the
+            free nats value first and then averaged.
+        global_average (bool, optional): if ``True``, the losses will be averaged
+            over all dimensions. Otherwise, a sum will be performed over all
+            non-batch/time dimensions and an average over batch and time.
+            Default: False.
+    """
+
+    @dataclass
+    class _AcceptedKeys:
+        """Maintains default values for all configurable tensordict keys.
+
+        This class defines which tensordict keys can be set using '.set_keys(key_name=key_value)' and their
+        default values.
+
+        Attributes:
+            reward (NestedKey): The reward is expected to be in the tensordict
+                key ("next", reward). Defaults to ``"reward"``.
+            true_reward (NestedKey): The `true_reward` will be stored in the
+                tensordict key ("next", true_reward). Defaults to ``"true_reward"``.
+            prior_mean (NestedKey): The prior mean is expected to be in the
+                tensordict key ("next", prior_mean). Defaults to ``"prior_mean"``.
+            prior_std (NestedKey): The prior std is expected to be in the
+                tensordict key ("next", prior_std). Defaults to ``"prior_std"``.
+            posterior_mean (NestedKey): The posterior mean is expected to be in
+                the tensordict key ("next", posterior_mean). Defaults to ``"posterior_mean"``.
+            posterior_std (NestedKey): The posterior std is expected to be in
+                the tensordict key ("next", posterior_std). Defaults to ``"posterior_std"``.
+            pixels (NestedKey): The pixels are expected to be in the tensordict key ("next", pixels).
+                Defaults to ``"pixels"``.
+            reco_pixels (NestedKey): The reconstructed pixels are expected to be
+                in the tensordict key ("next", reco_pixels). Defaults to ``"reco_pixels"``.
+        """
+
+        reward: NestedKey = "reward"
+        true_reward: NestedKey = "true_reward"
+        prior_mean: NestedKey = "prior_mean"
+        prior_std: NestedKey = "prior_std"
+        posterior_mean: NestedKey = "posterior_mean"
+        posterior_std: NestedKey = "posterior_std"
+        pixels: NestedKey = "pixels"
+        reco_pixels: NestedKey = "reco_pixels"
+
+    tensor_keys: _AcceptedKeys
+    default_keys = _AcceptedKeys
+
+    decoder: TensorDictModule
+    reward_model: TensorDictModule
+    world_model: TensorDictModule
+
+    def __init__(
+        self,
+        world_model: TensorDictModule,
+        *,
+        lambda_kl: float = 1.0,
+        lambda_reco: float = 1.0,
+        lambda_reward: float = 1.0,
+        reco_loss: str | None = None,
+        reward_loss: str | None = None,
+        free_nats: int = 3,
+        delayed_clamp: bool = False,
+        global_average: bool = False,
+    ):
+        super().__init__()
+        self.world_model = world_model
+        self.reco_loss = reco_loss if reco_loss is not None else "l2"
+        self.reward_loss = reward_loss if reward_loss is not None else "l2"
+        self.lambda_kl = lambda_kl
+        self.lambda_reco = lambda_reco
+        self.lambda_reward = lambda_reward
+        self.free_nats = free_nats
+        self.delayed_clamp = delayed_clamp
+        self.global_average = global_average
+        self.__dict__["decoder"] = self.world_model[0][-1]
+        self.__dict__["reward_model"] = self.world_model[1]
+
+    def _forward_value_estimator_keys(self, **kwargs) -> None:
+        pass
+
+    @_maybe_record_function_decorator("world_model_loss/forward")
+    def forward(self, tensordict: TensorDict) -> torch.Tensor:
+        tensordict = tensordict.copy()
+        tensordict.rename_key_(
+            ("next", self.tensor_keys.reward),
+            ("next", self.tensor_keys.true_reward),
+        )
+
+        tensordict = self.world_model(tensordict)
+
+        prior_mean = tensordict.get(("next", self.tensor_keys.prior_mean))
+        prior_std = tensordict.get(("next", self.tensor_keys.prior_std))
+        posterior_mean = tensordict.get(("next", self.tensor_keys.posterior_mean))
+        posterior_std = tensordict.get(("next", self.tensor_keys.posterior_std))
+
+        kl_loss = self.kl_loss(
+            prior_mean,
+            prior_std,
+            posterior_mean,
+            posterior_std,
+        ).unsqueeze(-1)
+
+        # Ensure contiguous layout for torch.compile compatibility
+        # The gradient from distance_loss flows back through decoder convolutions
+        pixels = tensordict.get(("next", self.tensor_keys.pixels)).contiguous()
+        reco_pixels = tensordict.get(
+            ("next", self.tensor_keys.reco_pixels)
+        ).contiguous()
+        reco_loss = distance_loss(
+            pixels,
+            reco_pixels,
+            self.reco_loss,
+        )
+        if not self.global_average:
+            reco_loss = reco_loss.sum((-3, -2, -1))
+        reco_loss = reco_loss.mean().unsqueeze(-1)
+
+        true_reward = tensordict.get(("next", self.tensor_keys.true_reward))
+        pred_reward = tensordict.get(("next", self.tensor_keys.reward))
+        reward_loss = distance_loss(
+            true_reward,
+            pred_reward,
+            self.reward_loss,
+        )
+        if not self.global_average:
+            reward_loss = reward_loss.squeeze(-1)
+        reward_loss = reward_loss.mean().unsqueeze(-1)
+
+        td_out = TensorDict(
+            loss_model_kl=self.lambda_kl * kl_loss,
+            loss_model_reco=self.lambda_reco * reco_loss,
+            loss_model_reward=self.lambda_reward * reward_loss,
+        )
+        self._clear_weakrefs(tensordict, td_out)
+
+        return (td_out, tensordict.data)
+
+    @staticmethod
+    def normal_log_probability(x, mean, std):
+        return (
+            -0.5 * ((x.to(mean.dtype) - mean) / std).pow(2) - std.log()
+        )  # - 0.5 * math.log(2 * math.pi)
+
+    def kl_loss(
+        self,
+        prior_mean: torch.Tensor,
+        prior_std: torch.Tensor,
+        posterior_mean: torch.Tensor,
+        posterior_std: torch.Tensor,
+    ) -> torch.Tensor:
+        kl = (
+            torch.log(prior_std / posterior_std)
+            + (posterior_std**2 + (prior_mean - posterior_mean) ** 2)
+            / (2 * prior_std**2)
+            - 0.5
+        )
+        if not self.global_average:
+            kl = kl.sum(-1)
+        if self.delayed_clamp:
+            kl = kl.mean().clamp_min(self.free_nats)
+        else:
+            kl = kl.clamp_min(self.free_nats).mean()
+        return kl
+
+
+class DreamerActorLoss(LossModule):
+    """Dreamer Actor Loss.
+
+    Computes the loss of the Dreamer actor. The actor loss is computed as the
+    negative average lambda return.
+
+    Reference: https://arxiv.org/abs/1912.01603.
+
+    Args:
+        actor_model (TensorDictModule): the actor model.
+        value_model (TensorDictModule): the value model.
+        model_based_env (DreamerEnv): the model based environment.
+        imagination_horizon (int, optional): The number of steps to unroll the
+            model. Defaults to ``15``.
+        discount_loss (bool, optional): if ``True``, the loss is discounted with a
+            gamma discount factor. Defaults to ``True``.
+
+    """
+
+    @dataclass
+    class _AcceptedKeys:
+        """Maintains default values for all configurable tensordict keys.
+
+        This class defines which tensordict keys can be set using '.set_keys(key_name=key_value)' and their
+        default values.
+
+        Attributes:
+            belief (NestedKey): The input tensordict key where the belief is expected.
+                Defaults to ``"belief"``.
+            reward (NestedKey): The reward is expected to be in the tensordict key ("next", reward).
+                Defaults to ``"reward"``.
+            value (NestedKey): The value is expected to be in the tensordict key ("next", value).
+                Will be used for the underlying value estimator. Defaults to ``"state_value"``.
+            done (NestedKey): The input tensordict key where the flag if a
+                trajectory is done is expected ("next", done). Defaults to ``"done"``.
+            terminated (NestedKey): The input tensordict key where the flag if a
+                trajectory is terminated is expected ("next", terminated). Defaults to ``"terminated"``.
+        """
+
+        belief: NestedKey = "belief"
+        reward: NestedKey = "reward"
+        value: NestedKey = "state_value"
+        done: NestedKey = "done"
+        terminated: NestedKey = "terminated"
+
+    tensor_keys: _AcceptedKeys
+    default_keys = _AcceptedKeys
+    default_value_estimator = ValueEstimators.TDLambda
+
+    value_model: TensorDictModule
+    actor_model: TensorDictModule
+
+    def __init__(
+        self,
+        actor_model: TensorDictModule,
+        value_model: TensorDictModule,
+        model_based_env: DreamerEnv,
+        *,
+        imagination_horizon: int = 15,
+        discount_loss: bool = True,  # for consistency with paper
+        gamma: int | None = None,
+        lmbda: int | None = None,
+    ):
+        super().__init__()
+        self.actor_model = actor_model
+        self.__dict__["value_model"] = value_model
+        self.model_based_env = model_based_env
+        self.imagination_horizon = imagination_horizon
+        self.discount_loss = discount_loss
+        if gamma is not None:
+            raise TypeError(_GAMMA_LMBDA_DEPREC_ERROR)
+        if lmbda is not None:
+            raise TypeError(_GAMMA_LMBDA_DEPREC_ERROR)
+
+    def _forward_value_estimator_keys(self, **kwargs) -> None:
+        if self._value_estimator is not None:
+            self._value_estimator.set_keys(
+                value=self._tensor_keys.value,
+            )
+
+    @_maybe_record_function_decorator("actor_loss/forward")
+    def forward(self, tensordict: TensorDict) -> tuple[TensorDict, TensorDict]:
+        tensordict = tensordict.select("state", self.tensor_keys.belief).data
+
+        with _maybe_timeit("actor_loss/time-rollout"), hold_out_net(
+            self.model_based_env
+        ), set_exploration_type(ExplorationType.RANDOM):
+            tensordict = self.model_based_env.reset(tensordict.copy())
+            fake_data = self.model_based_env.rollout(
+                max_steps=self.imagination_horizon,
+                policy=self.actor_model,
+                auto_reset=False,
+                tensordict=tensordict,
+            )
+            next_tensordict = step_mdp(fake_data, keep_other=True)
+            with hold_out_net(self.value_model):
+                next_tensordict = self.value_model(next_tensordict)
+
+        reward = fake_data.get(("next", self.tensor_keys.reward))
+        next_value = next_tensordict.get(self.tensor_keys.value)
+        lambda_target = self.lambda_target(reward, next_value)
+        fake_data.set("lambda_target", lambda_target)
+
+        if self.discount_loss:
+            gamma = self.value_estimator.gamma.to(tensordict.device)
+            discount = gamma.expand(lambda_target.shape).clone()
+            discount[..., 0, :] = 1
+            discount = discount.cumprod(dim=-2)
+            actor_loss = -(lambda_target * discount).sum((-2, -1)).mean()
+        else:
+            actor_loss = -lambda_target.sum((-2, -1)).mean()
+        loss_tensordict = TensorDict({"loss_actor": actor_loss}, [])
+        self._clear_weakrefs(tensordict, loss_tensordict)
+
+        return loss_tensordict, fake_data.data
+
+    def lambda_target(self, reward: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
+        done = torch.zeros(reward.shape, dtype=torch.bool, device=reward.device)
+        terminated = torch.zeros(reward.shape, dtype=torch.bool, device=reward.device)
+        input_tensordict = TensorDict(
+            {
+                ("next", self.tensor_keys.reward): reward,
+                ("next", self.tensor_keys.value): value,
+                ("next", self.tensor_keys.done): done,
+                ("next", self.tensor_keys.terminated): terminated,
+            },
+            [],
+        )
+        return self.value_estimator.value_estimate(input_tensordict)
+
+    def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams):
+        if value_type is None:
+            value_type = self.default_value_estimator
+
+        # Handle ValueEstimatorBase instance or class
+        if isinstance(value_type, ValueEstimatorBase) or (
+            isinstance(value_type, type) and issubclass(value_type, ValueEstimatorBase)
+        ):
+            return LossModule.make_value_estimator(self, value_type, **hyperparams)
+
+        self.value_type = value_type
+        value_net = None
+        hp = dict(default_value_kwargs(value_type))
+        if hasattr(self, "gamma"):
+            hp["gamma"] = self.gamma
+        hp.update(hyperparams)
+        if value_type is ValueEstimators.TD1:
+            self._value_estimator = TD1Estimator(
+                **hp,
+                value_network=value_net,
+            )
+        elif value_type is ValueEstimators.TD0:
+            self._value_estimator = TD0Estimator(
+                **hp,
+                value_network=value_net,
+            )
+        elif value_type is ValueEstimators.GAE:
+            if hasattr(self, "lmbda"):
+                hp["lmbda"] = self.lmbda
+            raise NotImplementedError(
+                f"Value type {value_type} is not implemented for loss {type(self)}."
+            )
+        elif value_type is ValueEstimators.TDLambda:
+            if hasattr(self, "lmbda"):
+                hp["lmbda"] = self.lmbda
+            self._value_estimator = TDLambdaEstimator(
+                **hp,
+                value_network=value_net,
+                vectorized=True,  # TODO: vectorized version seems not to be similar to the non vectorised
+            )
+        else:
+            raise NotImplementedError(f"Unknown value type {value_type}")
+
+        tensor_keys = {
+            "value": self.tensor_keys.value,
+            "value_target": "value_target",
+        }
+        self._value_estimator.set_keys(**tensor_keys)
+
+
+class DreamerValueLoss(LossModule):
+    """Dreamer Value Loss.
+
+    Computes the loss of the Dreamer value model. The value loss is computed
+    between the predicted value and the lambda target.
+
+    Reference: https://arxiv.org/abs/1912.01603.
+
+    Args:
+        value_model (TensorDictModule): the value model.
+        value_loss (str, optional): the loss to use for the value loss.
+            Default: ``"l2"``.
+        discount_loss (bool, optional): if ``True``, the loss is discounted with a
+            gamma discount factor. Default: ``True``.
+        gamma (:obj:`float`, optional): the gamma discount factor. Default: ``0.99``.
+
+    """
+
+    @dataclass
+    class _AcceptedKeys:
+        """Maintains default values for all configurable tensordict keys.
+
+        This class defines which tensordict keys can be set using '.set_keys(key_name=key_value)' and their
+        default values.
+
+        Attributes:
+            value (NestedKey): The input tensordict key where the state value is expected.
+                Defaults to ``"state_value"``.
+        """
+
+        value: NestedKey = "state_value"
+
+    tensor_keys: _AcceptedKeys
+    default_keys = _AcceptedKeys
+
+    value_model: TensorDictModule
+
+    def __init__(
+        self,
+        value_model: TensorDictModule,
+        value_loss: str | None = None,
+        discount_loss: bool = True,  # for consistency with paper
+        gamma: float = 0.99,
+    ):
+        super().__init__()
+        self.value_model = value_model
+        self.value_loss = value_loss if value_loss is not None else "l2"
+        self.gamma = gamma
+        self.discount_loss = discount_loss
+
+    def _forward_value_estimator_keys(self, **kwargs) -> None:
+        pass
+
+    @_maybe_record_function_decorator("value_loss/forward")
+    def forward(self, fake_data) -> torch.Tensor:
+        lambda_target = fake_data.get("lambda_target")
+
+        tensordict_select = fake_data.select(*self.value_model.in_keys, strict=False)
+        self.value_model(tensordict_select)
+
+        if self.discount_loss:
+            discount = self.gamma * torch.ones_like(
+                lambda_target, device=lambda_target.device
+            )
+            discount[..., 0, :] = 1
+            discount = discount.cumprod(dim=-2)
+            value_loss = (
+                (
+                    discount
+                    * distance_loss(
+                        tensordict_select.get(self.tensor_keys.value),
+                        lambda_target,
+                        self.value_loss,
+                    )
+                )
+                .sum((-1, -2))
+                .mean()
+            )
+        else:
+            value_loss = (
+                distance_loss(
+                    tensordict_select.get(self.tensor_keys.value),
+                    lambda_target,
+                    self.value_loss,
+                )
+                .sum((-1, -2))
+                .mean()
+            )
+
+        loss_tensordict = TensorDict({"loss_value": value_loss})
+        self._clear_weakrefs(fake_data, loss_tensordict)
+
+        return loss_tensordict, fake_data
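A side note on the KL term above (an illustration, not part of the released file): DreamerModelLoss.kl_loss implements the closed-form KL divergence between two diagonal Gaussians, KL(posterior || prior), followed by free-nats clamping. A minimal sketch, assuming only torch is available, that reproduces the formula and checks it against torch.distributions:

import torch
from torch.distributions import Normal, kl_divergence

# Toy prior/posterior statistics with a trailing state dimension of 8.
prior_mean, prior_std = torch.zeros(4, 8), torch.ones(4, 8)
posterior_mean, posterior_std = torch.randn(4, 8), torch.rand(4, 8) + 0.5

# Closed-form KL(posterior || prior), as written in kl_loss, summed over the state dim.
kl = (
    torch.log(prior_std / posterior_std)
    + (posterior_std**2 + (prior_mean - posterior_mean) ** 2) / (2 * prior_std**2)
    - 0.5
).sum(-1)

# Same quantity via torch.distributions.
kl_ref = kl_divergence(
    Normal(posterior_mean, posterior_std), Normal(prior_mean, prior_std)
).sum(-1)
assert torch.allclose(kl, kl_ref, atol=1e-5)

# Default path (delayed_clamp=False): clamp each sample to free_nats=3, then average.
kl_loss = kl.clamp_min(3.0).mean()

The clamp keeps the gradient from collapsing the posterior onto the prior once the KL is already below the free-nats budget.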
torchrl/objectives/functional.py
@@ -0,0 +1,48 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+from __future__ import annotations
+
+import torch
+
+
+def cross_entropy_loss(
+    log_policy: torch.Tensor, action: torch.Tensor, inplace: bool = False
+) -> torch.Tensor:
+    """Returns the cross entropy loss defined as the log-softmax value indexed by the action index.
+
+    Supports discrete (integer) actions or one-hot encodings.
+
+    Args:
+        log_policy: Tensor of the log_softmax values of the policy.
+        action: Integer or one-hot representation of the actions undertaken. Must have a shape log_policy.shape[:-1]
+            (integer representation) or log_policy.shape (one-hot).
+        inplace: fills log_policy in-place with 0.0 at non-selected actions before summing along the last dimension.
+            This is usually faster, but it will change the value of log_policy in place, which may lead to unwanted
+            behaviors.
+
+    """
+    if action.shape == log_policy.shape:
+        if action.dtype not in (torch.bool, torch.long, torch.uint8):
+            raise TypeError(
+                f"Cross-entropy loss with {action.dtype} dtype is not permitted"
+            )
+        if not ((action == 1).sum(-1) == 1).all():
+            raise RuntimeError(
+                "Expected the action tensor to be a one hot encoding of the actions taken, "
+                "but got more/less than one non-null boolean index on the last dimension"
+            )
+        if inplace:
+            cross_entropy = log_policy.masked_fill_(~action.bool(), 0.0).sum(-1)
+        else:
+            cross_entropy = (log_policy * action).sum(-1)
+    elif action.shape == log_policy.shape[:-1]:
+        cross_entropy = torch.gather(log_policy, dim=-1, index=action[..., None])
+        cross_entropy.squeeze_(-1)
+    else:
+        raise RuntimeError(
+            f"unexpected action shape in cross_entropy_loss with log_policy.shape={log_policy.shape} and "
+            f"action.shape={action.shape}"
+        )
+    return cross_entropy
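For reference, a minimal usage sketch of cross_entropy_loss (not part of the package; it assumes the hunk above ships as torchrl.objectives.functional, consistent with the +48 entry for that file in the list): the function returns the selected log-probabilities for either integer or one-hot actions, which differ from torch's NLL loss only by a sign and a reduction.

import torch
import torch.nn.functional as F
from torchrl.objectives.functional import cross_entropy_loss

logits = torch.randn(5, 3)
log_policy = F.log_softmax(logits, dim=-1)
actions = torch.tensor([0, 2, 1, 1, 0])

# Integer actions (shape == log_policy.shape[:-1]) and one-hot actions give the same result.
picked_int = cross_entropy_loss(log_policy, actions)
picked_hot = cross_entropy_loss(log_policy, F.one_hot(actions, 3).bool())
assert torch.allclose(picked_int, picked_hot)

# The selected log-probabilities relate to the usual NLL loss by a sign flip and a mean.
assert torch.allclose(-picked_int.mean(), F.nll_loss(log_policy, actions))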