torchrl 0.11.0__cp314-cp314t-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/benchmark_batched_envs.py +104 -0
- benchmarks/conftest.py +91 -0
- benchmarks/ecosystem/gym_env_throughput.py +321 -0
- benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
- benchmarks/requirements.txt +7 -0
- benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
- benchmarks/test_collectors_benchmark.py +240 -0
- benchmarks/test_compressed_storage_benchmark.py +145 -0
- benchmarks/test_envs_benchmark.py +133 -0
- benchmarks/test_llm.py +101 -0
- benchmarks/test_non_tensor_env_benchmark.py +70 -0
- benchmarks/test_objectives_benchmarks.py +1199 -0
- benchmarks/test_replaybuffer_benchmark.py +254 -0
- sota-check/README.md +35 -0
- sota-implementations/README.md +142 -0
- sota-implementations/a2c/README.md +39 -0
- sota-implementations/a2c/a2c_atari.py +291 -0
- sota-implementations/a2c/a2c_mujoco.py +273 -0
- sota-implementations/a2c/utils_atari.py +240 -0
- sota-implementations/a2c/utils_mujoco.py +160 -0
- sota-implementations/bandits/README.md +7 -0
- sota-implementations/bandits/dqn.py +126 -0
- sota-implementations/cql/cql_offline.py +198 -0
- sota-implementations/cql/cql_online.py +249 -0
- sota-implementations/cql/discrete_cql_offline.py +180 -0
- sota-implementations/cql/discrete_cql_online.py +227 -0
- sota-implementations/cql/utils.py +471 -0
- sota-implementations/crossq/crossq.py +271 -0
- sota-implementations/crossq/utils.py +320 -0
- sota-implementations/ddpg/ddpg.py +231 -0
- sota-implementations/ddpg/utils.py +325 -0
- sota-implementations/decision_transformer/dt.py +163 -0
- sota-implementations/decision_transformer/lamb.py +167 -0
- sota-implementations/decision_transformer/online_dt.py +178 -0
- sota-implementations/decision_transformer/utils.py +562 -0
- sota-implementations/discrete_sac/discrete_sac.py +243 -0
- sota-implementations/discrete_sac/utils.py +324 -0
- sota-implementations/dqn/README.md +30 -0
- sota-implementations/dqn/dqn_atari.py +272 -0
- sota-implementations/dqn/dqn_cartpole.py +236 -0
- sota-implementations/dqn/utils_atari.py +132 -0
- sota-implementations/dqn/utils_cartpole.py +90 -0
- sota-implementations/dreamer/README.md +129 -0
- sota-implementations/dreamer/dreamer.py +586 -0
- sota-implementations/dreamer/dreamer_utils.py +1107 -0
- sota-implementations/expert-iteration/README.md +352 -0
- sota-implementations/expert-iteration/ei_utils.py +770 -0
- sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
- sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
- sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
- sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
- sota-implementations/gail/gail.py +327 -0
- sota-implementations/gail/gail_utils.py +68 -0
- sota-implementations/gail/ppo_utils.py +157 -0
- sota-implementations/grpo/README.md +273 -0
- sota-implementations/grpo/grpo-async.py +437 -0
- sota-implementations/grpo/grpo-sync.py +435 -0
- sota-implementations/grpo/grpo_utils.py +843 -0
- sota-implementations/grpo/requirements_gsm8k.txt +11 -0
- sota-implementations/grpo/requirements_ifeval.txt +16 -0
- sota-implementations/impala/README.md +33 -0
- sota-implementations/impala/impala_multi_node_ray.py +292 -0
- sota-implementations/impala/impala_multi_node_submitit.py +284 -0
- sota-implementations/impala/impala_single_node.py +261 -0
- sota-implementations/impala/utils.py +184 -0
- sota-implementations/iql/discrete_iql.py +230 -0
- sota-implementations/iql/iql_offline.py +164 -0
- sota-implementations/iql/iql_online.py +225 -0
- sota-implementations/iql/utils.py +437 -0
- sota-implementations/multiagent/README.md +74 -0
- sota-implementations/multiagent/iql.py +237 -0
- sota-implementations/multiagent/maddpg_iddpg.py +266 -0
- sota-implementations/multiagent/mappo_ippo.py +267 -0
- sota-implementations/multiagent/qmix_vdn.py +271 -0
- sota-implementations/multiagent/sac.py +337 -0
- sota-implementations/multiagent/utils/__init__.py +4 -0
- sota-implementations/multiagent/utils/logging.py +151 -0
- sota-implementations/multiagent/utils/utils.py +43 -0
- sota-implementations/ppo/README.md +29 -0
- sota-implementations/ppo/ppo_atari.py +305 -0
- sota-implementations/ppo/ppo_mujoco.py +293 -0
- sota-implementations/ppo/utils_atari.py +238 -0
- sota-implementations/ppo/utils_mujoco.py +152 -0
- sota-implementations/ppo_trainer/train.py +21 -0
- sota-implementations/redq/README.md +7 -0
- sota-implementations/redq/redq.py +199 -0
- sota-implementations/redq/utils.py +1060 -0
- sota-implementations/sac/sac-async.py +266 -0
- sota-implementations/sac/sac.py +239 -0
- sota-implementations/sac/utils.py +381 -0
- sota-implementations/sac_trainer/train.py +16 -0
- sota-implementations/td3/td3.py +254 -0
- sota-implementations/td3/utils.py +319 -0
- sota-implementations/td3_bc/td3_bc.py +177 -0
- sota-implementations/td3_bc/utils.py +251 -0
- torchrl/__init__.py +144 -0
- torchrl/_extension.py +74 -0
- torchrl/_torchrl.cp314t-win_amd64.pyd +0 -0
- torchrl/_utils.py +1431 -0
- torchrl/collectors/__init__.py +48 -0
- torchrl/collectors/_base.py +1058 -0
- torchrl/collectors/_constants.py +88 -0
- torchrl/collectors/_multi_async.py +324 -0
- torchrl/collectors/_multi_base.py +1805 -0
- torchrl/collectors/_multi_sync.py +464 -0
- torchrl/collectors/_runner.py +581 -0
- torchrl/collectors/_single.py +2009 -0
- torchrl/collectors/_single_async.py +259 -0
- torchrl/collectors/collectors.py +62 -0
- torchrl/collectors/distributed/__init__.py +32 -0
- torchrl/collectors/distributed/default_configs.py +133 -0
- torchrl/collectors/distributed/generic.py +1306 -0
- torchrl/collectors/distributed/ray.py +1092 -0
- torchrl/collectors/distributed/rpc.py +1006 -0
- torchrl/collectors/distributed/sync.py +731 -0
- torchrl/collectors/distributed/utils.py +160 -0
- torchrl/collectors/llm/__init__.py +10 -0
- torchrl/collectors/llm/base.py +494 -0
- torchrl/collectors/llm/ray_collector.py +275 -0
- torchrl/collectors/llm/utils.py +36 -0
- torchrl/collectors/llm/weight_update/__init__.py +10 -0
- torchrl/collectors/llm/weight_update/vllm.py +348 -0
- torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
- torchrl/collectors/utils.py +433 -0
- torchrl/collectors/weight_update.py +591 -0
- torchrl/csrc/numpy_utils.h +38 -0
- torchrl/csrc/pybind.cpp +27 -0
- torchrl/csrc/segment_tree.h +458 -0
- torchrl/csrc/torch_utils.h +34 -0
- torchrl/csrc/utils.cpp +48 -0
- torchrl/csrc/utils.h +31 -0
- torchrl/data/__init__.py +187 -0
- torchrl/data/datasets/__init__.py +58 -0
- torchrl/data/datasets/atari_dqn.py +878 -0
- torchrl/data/datasets/common.py +281 -0
- torchrl/data/datasets/d4rl.py +489 -0
- torchrl/data/datasets/d4rl_infos.py +187 -0
- torchrl/data/datasets/gen_dgrl.py +375 -0
- torchrl/data/datasets/minari_data.py +643 -0
- torchrl/data/datasets/openml.py +177 -0
- torchrl/data/datasets/openx.py +798 -0
- torchrl/data/datasets/roboset.py +363 -0
- torchrl/data/datasets/utils.py +11 -0
- torchrl/data/datasets/vd4rl.py +432 -0
- torchrl/data/llm/__init__.py +34 -0
- torchrl/data/llm/dataset.py +491 -0
- torchrl/data/llm/history.py +1378 -0
- torchrl/data/llm/prompt.py +198 -0
- torchrl/data/llm/reward.py +225 -0
- torchrl/data/llm/topk.py +186 -0
- torchrl/data/llm/utils.py +543 -0
- torchrl/data/map/__init__.py +21 -0
- torchrl/data/map/hash.py +185 -0
- torchrl/data/map/query.py +204 -0
- torchrl/data/map/tdstorage.py +363 -0
- torchrl/data/map/tree.py +1434 -0
- torchrl/data/map/utils.py +103 -0
- torchrl/data/postprocs/__init__.py +8 -0
- torchrl/data/postprocs/postprocs.py +391 -0
- torchrl/data/replay_buffers/__init__.py +99 -0
- torchrl/data/replay_buffers/checkpointers.py +622 -0
- torchrl/data/replay_buffers/ray_buffer.py +292 -0
- torchrl/data/replay_buffers/replay_buffers.py +2376 -0
- torchrl/data/replay_buffers/samplers.py +2578 -0
- torchrl/data/replay_buffers/scheduler.py +265 -0
- torchrl/data/replay_buffers/storages.py +2412 -0
- torchrl/data/replay_buffers/utils.py +1042 -0
- torchrl/data/replay_buffers/writers.py +781 -0
- torchrl/data/tensor_specs.py +7101 -0
- torchrl/data/utils.py +334 -0
- torchrl/envs/__init__.py +265 -0
- torchrl/envs/async_envs.py +1105 -0
- torchrl/envs/batched_envs.py +3093 -0
- torchrl/envs/common.py +4241 -0
- torchrl/envs/custom/__init__.py +11 -0
- torchrl/envs/custom/chess.py +617 -0
- torchrl/envs/custom/llm.py +214 -0
- torchrl/envs/custom/pendulum.py +401 -0
- torchrl/envs/custom/san_moves.txt +29274 -0
- torchrl/envs/custom/tictactoeenv.py +288 -0
- torchrl/envs/env_creator.py +263 -0
- torchrl/envs/gym_like.py +752 -0
- torchrl/envs/libs/__init__.py +68 -0
- torchrl/envs/libs/_gym_utils.py +326 -0
- torchrl/envs/libs/brax.py +846 -0
- torchrl/envs/libs/dm_control.py +544 -0
- torchrl/envs/libs/envpool.py +447 -0
- torchrl/envs/libs/gym.py +2239 -0
- torchrl/envs/libs/habitat.py +138 -0
- torchrl/envs/libs/isaac_lab.py +87 -0
- torchrl/envs/libs/isaacgym.py +203 -0
- torchrl/envs/libs/jax_utils.py +166 -0
- torchrl/envs/libs/jumanji.py +963 -0
- torchrl/envs/libs/meltingpot.py +599 -0
- torchrl/envs/libs/openml.py +153 -0
- torchrl/envs/libs/openspiel.py +652 -0
- torchrl/envs/libs/pettingzoo.py +1042 -0
- torchrl/envs/libs/procgen.py +351 -0
- torchrl/envs/libs/robohive.py +429 -0
- torchrl/envs/libs/smacv2.py +645 -0
- torchrl/envs/libs/unity_mlagents.py +891 -0
- torchrl/envs/libs/utils.py +147 -0
- torchrl/envs/libs/vmas.py +813 -0
- torchrl/envs/llm/__init__.py +63 -0
- torchrl/envs/llm/chat.py +730 -0
- torchrl/envs/llm/datasets/README.md +4 -0
- torchrl/envs/llm/datasets/__init__.py +17 -0
- torchrl/envs/llm/datasets/gsm8k.py +353 -0
- torchrl/envs/llm/datasets/ifeval.py +274 -0
- torchrl/envs/llm/envs.py +789 -0
- torchrl/envs/llm/libs/README.md +3 -0
- torchrl/envs/llm/libs/__init__.py +8 -0
- torchrl/envs/llm/libs/mlgym.py +869 -0
- torchrl/envs/llm/reward/__init__.py +10 -0
- torchrl/envs/llm/reward/gsm8k.py +324 -0
- torchrl/envs/llm/reward/ifeval/README.md +13 -0
- torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
- torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
- torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
- torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
- torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
- torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
- torchrl/envs/llm/transforms/__init__.py +55 -0
- torchrl/envs/llm/transforms/browser.py +292 -0
- torchrl/envs/llm/transforms/dataloading.py +859 -0
- torchrl/envs/llm/transforms/format.py +73 -0
- torchrl/envs/llm/transforms/kl.py +1544 -0
- torchrl/envs/llm/transforms/policy_version.py +189 -0
- torchrl/envs/llm/transforms/reason.py +323 -0
- torchrl/envs/llm/transforms/tokenizer.py +321 -0
- torchrl/envs/llm/transforms/tools.py +1955 -0
- torchrl/envs/model_based/__init__.py +9 -0
- torchrl/envs/model_based/common.py +180 -0
- torchrl/envs/model_based/dreamer.py +112 -0
- torchrl/envs/transforms/__init__.py +147 -0
- torchrl/envs/transforms/functional.py +48 -0
- torchrl/envs/transforms/gym_transforms.py +203 -0
- torchrl/envs/transforms/module.py +341 -0
- torchrl/envs/transforms/r3m.py +372 -0
- torchrl/envs/transforms/ray_service.py +663 -0
- torchrl/envs/transforms/rb_transforms.py +214 -0
- torchrl/envs/transforms/transforms.py +11835 -0
- torchrl/envs/transforms/utils.py +94 -0
- torchrl/envs/transforms/vc1.py +307 -0
- torchrl/envs/transforms/vecnorm.py +845 -0
- torchrl/envs/transforms/vip.py +407 -0
- torchrl/envs/utils.py +1718 -0
- torchrl/envs/vec_envs.py +11 -0
- torchrl/modules/__init__.py +206 -0
- torchrl/modules/distributions/__init__.py +73 -0
- torchrl/modules/distributions/continuous.py +830 -0
- torchrl/modules/distributions/discrete.py +908 -0
- torchrl/modules/distributions/truncated_normal.py +187 -0
- torchrl/modules/distributions/utils.py +233 -0
- torchrl/modules/llm/__init__.py +62 -0
- torchrl/modules/llm/backends/__init__.py +65 -0
- torchrl/modules/llm/backends/vllm/__init__.py +94 -0
- torchrl/modules/llm/backends/vllm/_models.py +46 -0
- torchrl/modules/llm/backends/vllm/base.py +72 -0
- torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
- torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
- torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
- torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
- torchrl/modules/llm/policies/__init__.py +28 -0
- torchrl/modules/llm/policies/common.py +1809 -0
- torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
- torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
- torchrl/modules/llm/utils.py +23 -0
- torchrl/modules/mcts/__init__.py +21 -0
- torchrl/modules/mcts/scores.py +579 -0
- torchrl/modules/models/__init__.py +86 -0
- torchrl/modules/models/batchrenorm.py +119 -0
- torchrl/modules/models/decision_transformer.py +179 -0
- torchrl/modules/models/exploration.py +731 -0
- torchrl/modules/models/llm.py +156 -0
- torchrl/modules/models/model_based.py +596 -0
- torchrl/modules/models/models.py +1712 -0
- torchrl/modules/models/multiagent.py +1067 -0
- torchrl/modules/models/recipes/impala.py +185 -0
- torchrl/modules/models/utils.py +162 -0
- torchrl/modules/planners/__init__.py +10 -0
- torchrl/modules/planners/cem.py +228 -0
- torchrl/modules/planners/common.py +73 -0
- torchrl/modules/planners/mppi.py +265 -0
- torchrl/modules/tensordict_module/__init__.py +89 -0
- torchrl/modules/tensordict_module/actors.py +2457 -0
- torchrl/modules/tensordict_module/common.py +529 -0
- torchrl/modules/tensordict_module/exploration.py +814 -0
- torchrl/modules/tensordict_module/probabilistic.py +321 -0
- torchrl/modules/tensordict_module/rnn.py +1639 -0
- torchrl/modules/tensordict_module/sequence.py +132 -0
- torchrl/modules/tensordict_module/world_models.py +34 -0
- torchrl/modules/utils/__init__.py +38 -0
- torchrl/modules/utils/mappings.py +9 -0
- torchrl/modules/utils/utils.py +89 -0
- torchrl/objectives/__init__.py +78 -0
- torchrl/objectives/a2c.py +659 -0
- torchrl/objectives/common.py +753 -0
- torchrl/objectives/cql.py +1346 -0
- torchrl/objectives/crossq.py +710 -0
- torchrl/objectives/ddpg.py +453 -0
- torchrl/objectives/decision_transformer.py +371 -0
- torchrl/objectives/deprecated.py +516 -0
- torchrl/objectives/dqn.py +683 -0
- torchrl/objectives/dreamer.py +488 -0
- torchrl/objectives/functional.py +48 -0
- torchrl/objectives/gail.py +258 -0
- torchrl/objectives/iql.py +996 -0
- torchrl/objectives/llm/__init__.py +30 -0
- torchrl/objectives/llm/grpo.py +846 -0
- torchrl/objectives/llm/sft.py +482 -0
- torchrl/objectives/multiagent/__init__.py +8 -0
- torchrl/objectives/multiagent/qmixer.py +396 -0
- torchrl/objectives/ppo.py +1669 -0
- torchrl/objectives/redq.py +683 -0
- torchrl/objectives/reinforce.py +530 -0
- torchrl/objectives/sac.py +1580 -0
- torchrl/objectives/td3.py +570 -0
- torchrl/objectives/td3_bc.py +625 -0
- torchrl/objectives/utils.py +782 -0
- torchrl/objectives/value/__init__.py +28 -0
- torchrl/objectives/value/advantages.py +1956 -0
- torchrl/objectives/value/functional.py +1459 -0
- torchrl/objectives/value/utils.py +360 -0
- torchrl/record/__init__.py +17 -0
- torchrl/record/loggers/__init__.py +23 -0
- torchrl/record/loggers/common.py +48 -0
- torchrl/record/loggers/csv.py +226 -0
- torchrl/record/loggers/mlflow.py +142 -0
- torchrl/record/loggers/tensorboard.py +139 -0
- torchrl/record/loggers/trackio.py +163 -0
- torchrl/record/loggers/utils.py +78 -0
- torchrl/record/loggers/wandb.py +214 -0
- torchrl/record/recorder.py +554 -0
- torchrl/services/__init__.py +79 -0
- torchrl/services/base.py +109 -0
- torchrl/services/ray_service.py +453 -0
- torchrl/testing/__init__.py +107 -0
- torchrl/testing/assertions.py +179 -0
- torchrl/testing/dist_utils.py +122 -0
- torchrl/testing/env_creators.py +227 -0
- torchrl/testing/env_helper.py +35 -0
- torchrl/testing/gym_helpers.py +156 -0
- torchrl/testing/llm_mocks.py +119 -0
- torchrl/testing/mocking_classes.py +2720 -0
- torchrl/testing/modules.py +295 -0
- torchrl/testing/mp_helpers.py +15 -0
- torchrl/testing/ray_helpers.py +293 -0
- torchrl/testing/utils.py +190 -0
- torchrl/trainers/__init__.py +42 -0
- torchrl/trainers/algorithms/__init__.py +11 -0
- torchrl/trainers/algorithms/configs/__init__.py +705 -0
- torchrl/trainers/algorithms/configs/collectors.py +216 -0
- torchrl/trainers/algorithms/configs/common.py +41 -0
- torchrl/trainers/algorithms/configs/data.py +308 -0
- torchrl/trainers/algorithms/configs/envs.py +104 -0
- torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
- torchrl/trainers/algorithms/configs/logging.py +80 -0
- torchrl/trainers/algorithms/configs/modules.py +570 -0
- torchrl/trainers/algorithms/configs/objectives.py +177 -0
- torchrl/trainers/algorithms/configs/trainers.py +340 -0
- torchrl/trainers/algorithms/configs/transforms.py +955 -0
- torchrl/trainers/algorithms/configs/utils.py +252 -0
- torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
- torchrl/trainers/algorithms/configs/weight_update.py +159 -0
- torchrl/trainers/algorithms/ppo.py +373 -0
- torchrl/trainers/algorithms/sac.py +308 -0
- torchrl/trainers/helpers/__init__.py +40 -0
- torchrl/trainers/helpers/collectors.py +416 -0
- torchrl/trainers/helpers/envs.py +573 -0
- torchrl/trainers/helpers/logger.py +33 -0
- torchrl/trainers/helpers/losses.py +132 -0
- torchrl/trainers/helpers/models.py +658 -0
- torchrl/trainers/helpers/replay_buffer.py +59 -0
- torchrl/trainers/helpers/trainers.py +301 -0
- torchrl/trainers/trainers.py +2052 -0
- torchrl/weight_update/__init__.py +33 -0
- torchrl/weight_update/_distributed.py +749 -0
- torchrl/weight_update/_mp.py +624 -0
- torchrl/weight_update/_noupdate.py +102 -0
- torchrl/weight_update/_ray.py +1032 -0
- torchrl/weight_update/_rpc.py +284 -0
- torchrl/weight_update/_shared.py +891 -0
- torchrl/weight_update/llm/__init__.py +32 -0
- torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
- torchrl/weight_update/llm/vllm_nccl.py +710 -0
- torchrl/weight_update/utils.py +73 -0
- torchrl/weight_update/weight_sync_schemes.py +1244 -0
- torchrl-0.11.0.dist-info/LICENSE +21 -0
- torchrl-0.11.0.dist-info/METADATA +1307 -0
- torchrl-0.11.0.dist-info/RECORD +394 -0
- torchrl-0.11.0.dist-info/WHEEL +5 -0
- torchrl-0.11.0.dist-info/entry_points.txt +2 -0
- torchrl-0.11.0.dist-info/top_level.txt +7 -0
torchrl/modules/distributions/discrete.py
@@ -0,0 +1,908 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

from collections.abc import Sequence

from enum import Enum
from functools import wraps

import torch
import torch.distributions as D
import torch.nn.functional as F
from tensordict.utils import expand_as_right

from torch.distributions.utils import lazy_property, logits_to_probs, probs_to_logits

__all__ = [
    "OneHotCategorical",
    "MaskedCategorical",
    "Ordinal",
    "OneHotOrdinal",
    "LLMMaskedCategorical",
]


def _treat_categorical_params(
    params: torch.Tensor | None = None,
) -> torch.Tensor | None:
    if params is None:
        return None
    if params.shape[-1] == 1:
        params = params[..., 0]
    return params


def rand_one_hot(values: torch.Tensor, do_softmax: bool = True) -> torch.Tensor:
    if do_softmax:
        values = values.softmax(-1)
    out = values.cumsum(-1) > torch.rand_like(values[..., :1])
    out = (out.cumsum(-1) == 1).to(torch.long)
    return out


class _one_hot_wrapper:
    def __init__(self, parent_dist):
        self.parent_dist = parent_dist

    def __call__(self, func):
        @wraps(func)
        def wrapped(_self, *args, **kwargs):
            out = getattr(self.parent_dist, func.__name__)(_self, *args, **kwargs)
            n = _self.num_samples
            return torch.nn.functional.one_hot(out, n)

        return wrapped


class ReparamGradientStrategy(Enum):
    PassThrough = 1
    RelaxedOneHot = 2
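For readers skimming the helpers above: rand_one_hot is inverse-CDF sampling written with a cumulative sum, and ReparamGradientStrategy names the two ways the classes below obtain gradients through discrete samples. A minimal, self-contained sketch of the cumsum trick (illustrative only, not part of the packaged file):

import torch

# Inverse-CDF sampling into a one-hot vector, mirroring the cumsum/rand trick above.
probs = torch.tensor([0.1, 0.6, 0.3])
u = torch.rand(())                                    # one uniform draw in [0, 1)
exceeded = probs.cumsum(-1) > u                       # True from the sampled index onward
one_hot = (exceeded.cumsum(-1) == 1).to(torch.long)   # keep only the first True
print(one_hot)  # e.g. tensor([0, 1, 0])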
class OneHotCategorical(D.Categorical):
    """One-hot categorical distribution.

    This class behaves exactly as torch.distributions.Categorical except that it reads and produces one-hot encodings
    of the discrete tensors.

    Args:
        logits (torch.Tensor): event log probabilities (unnormalized)
        probs (torch.Tensor): event probabilities
        grad_method (ReparamGradientStrategy, optional): strategy to gather
            reparameterized samples.
            ``ReparamGradientStrategy.PassThrough`` will compute the sample gradients
            by using the softmax valued log-probability as a proxy to the
            sample gradients.
            ``ReparamGradientStrategy.RelaxedOneHot`` will use
            :class:`torch.distributions.RelaxedOneHotCategorical` to sample from the distribution.

    Examples:
        >>> torch.manual_seed(0)
        >>> logits = torch.randn(4)
        >>> dist = OneHotCategorical(logits=logits)
        >>> print(dist.rsample((3,)))
        tensor([[1., 0., 0., 0.],
                [0., 0., 0., 1.],
                [1., 0., 0., 0.]])

    """

    num_params: int = 1

    # This is to make the compiler happy, see https://github.com/pytorch/pytorch/issues/140266
    @lazy_property
    def logits(self):
        return probs_to_logits(self.probs)

    @lazy_property
    def probs(self):
        return logits_to_probs(self.logits)

    def __init__(
        self,
        logits: torch.Tensor | None = None,
        probs: torch.Tensor | None = None,
        grad_method: ReparamGradientStrategy = ReparamGradientStrategy.PassThrough,
        **kwargs,
    ) -> None:
        logits = _treat_categorical_params(logits)
        probs = _treat_categorical_params(probs)
        self.grad_method = grad_method
        super().__init__(probs=probs, logits=logits, **kwargs)
        # Get num_samples from logits or probs shape
        if logits is not None:
            self.num_samples = logits.shape[-1]
        else:
            self.num_samples = probs.shape[-1]

    def log_prob(self, value: torch.Tensor) -> torch.Tensor:
        return super().log_prob(value.argmax(dim=-1))

    @property
    def mode(self) -> torch.Tensor:
        if hasattr(self, "logits"):
            return (self.logits == self.logits.max(-1, True)[0]).to(torch.long)
        else:
            return (self.probs == self.probs.max(-1, True)[0]).to(torch.long)

    @property
    def deterministic_sample(self):
        return self.mode

    def entropy(self):
        min_real = torch.finfo(self.logits.dtype).min
        logits = torch.clamp(self.logits, min=min_real)
        p_log_p = logits * self.probs
        return -p_log_p.sum(-1)

    @_one_hot_wrapper(D.Categorical)
    def sample(self, sample_shape: torch.Size | Sequence | None = None) -> torch.Tensor:
        ...

    def rsample(self, sample_shape: torch.Size | Sequence | None = None) -> torch.Tensor:
        if sample_shape is None:
            sample_shape = torch.Size([])
        if hasattr(self, "logits") and self.logits is not None:
            logits = self.logits
            probs = None
        else:
            logits = None
            probs = self.probs
        if self.grad_method == ReparamGradientStrategy.RelaxedOneHot:
            d = D.relaxed_categorical.RelaxedOneHotCategorical(
                1.0, probs=probs, logits=logits
            )
            out = d.rsample(sample_shape)
            out.data.copy_((out == out.max(-1)[0].unsqueeze(-1)).to(out.dtype))
            return out
        elif self.grad_method == ReparamGradientStrategy.PassThrough:
            if logits is not None:
                probs = self.probs
            else:
                probs = torch.softmax(self.logits, dim=-1)
            out = self.sample(sample_shape)
            out = out + probs - probs.detach()
            return out
        else:
            raise ValueError(
                f"Unknown reparameterization strategy {self.grad_method}."
            )
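The default PassThrough strategy is a straight-through estimator: rsample returns a hard one-hot sample in the forward pass while the probs - probs.detach() term routes gradients through the softmax probabilities. A minimal sketch of that behaviour, assuming OneHotCategorical is exported from torchrl.modules as in previous releases:

import torch
from torchrl.modules import OneHotCategorical  # assumed export path

# Hard one-hot values forward, softmax gradients backward (straight-through estimator).
logits = torch.randn(4, requires_grad=True)
dist = OneHotCategorical(logits=logits)
sample = dist.rsample((2,))                 # one-hot rows in the forward pass
loss = (sample * torch.arange(4.0)).sum()   # illustrative loss on the samples
loss.backward()
print(logits.grad)                          # non-zero: gradients reached the logits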
class MaskedCategorical(D.Categorical):
    """MaskedCategorical distribution.

    Reference:
        https://www.tensorflow.org/agents/api_docs/python/tf_agents/distributions/masked/MaskedCategorical

    Args:
        logits (torch.Tensor): event log probabilities (unnormalized)
        probs (torch.Tensor): event probabilities. If provided, the probabilities
            corresponding to masked items will be zeroed and the probability
            re-normalized along its last dimension.

    Keyword Args:
        mask (torch.Tensor): A boolean mask of the same shape as ``logits``/``probs``
            where ``False`` entries are the ones to be masked. Alternatively,
            if ``sparse_mask`` is True, it represents the list of valid indices
            in the distribution. Exclusive with ``indices``.
        indices (torch.Tensor): A dense index tensor representing which actions
            must be taken into account. Exclusive with ``mask``.
        neg_inf (:obj:`float`, optional): The log-probability value allocated to
            invalid (out-of-mask) indices. Defaults to -inf.
        padding_value: The padding value in the mask tensor. When
            sparse_mask == True, the padding_value will be ignored.
        use_cross_entropy (bool, optional): For faster computation of the log-probability,
            the cross_entropy loss functional can be used. Defaults to ``True``.
        padding_side (str, optional): The side of the padding. Defaults to ``"left"``.

    Examples:
        >>> torch.manual_seed(0)
        >>> logits = torch.randn(4) / 100  # almost equal probabilities
        >>> mask = torch.tensor([True, False, True, True])
        >>> dist = MaskedCategorical(logits=logits, mask=mask)
        >>> sample = dist.sample((10,))
        >>> print(sample)  # no `1` in the sample
        tensor([2, 3, 0, 2, 2, 0, 2, 0, 2, 2])
        >>> print(dist.log_prob(sample))
        tensor([-1.1203, -1.0928, -1.0831, -1.1203, -1.1203, -1.0831, -1.1203, -1.0831,
                -1.1203, -1.1203])
        >>> print(dist.log_prob(torch.ones_like(sample)))
        tensor([-inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf])
        >>> # with probabilities
        >>> prob = torch.ones(10)
        >>> prob = prob / prob.sum()
        >>> mask = torch.tensor([False] + 9 * [True])  # first outcome is masked
        >>> dist = MaskedCategorical(probs=prob, mask=mask)
        >>> print(dist.log_prob(torch.arange(10)))
        tensor([   -inf, -2.1972, -2.1972, -2.1972, -2.1972, -2.1972, -2.1972, -2.1972,
                -2.1972, -2.1972])
    """

    @lazy_property
    def logits(self):
        return probs_to_logits(self.probs)

    @lazy_property
    def probs(self):
        return logits_to_probs(self.logits)

    def __init__(
        self,
        logits: torch.Tensor | None = None,
        probs: torch.Tensor | None = None,
        *,
        mask: torch.Tensor | None = None,
        indices: torch.Tensor | None = None,
        neg_inf: float = float("-inf"),
        padding_value: int | None = None,
        use_cross_entropy: bool = True,
        padding_side: str = "left",
    ) -> None:
        if not ((mask is None) ^ (indices is None)):
            raise ValueError(
                f"A ``mask`` or some ``indices`` must be provided for {type(self)}, but not both."
            )
        if mask is None:
            mask = indices
            sparse_mask = True
        else:
            sparse_mask = False

        if probs is not None:
            if logits is not None:
                raise ValueError(
                    "Either `probs` or `logits` must be specified, but not both."
                )
            # unnormalized logits
            probs = probs.clone()
            if mask.dtype == torch.bool:
                probs[~mask] = 0
            else:
                probs = torch.scatter(
                    torch.zeros_like(probs), -1, indices, probs.gather(-1, indices)
                )
            probs = probs / probs.sum(-1, keepdim=True)
            logits = probs.log()
        num_samples = logits.shape[-1]
        self.use_cross_entropy = use_cross_entropy
        logits = self._mask_logits(
            logits,
            mask,
            neg_inf=neg_inf,
            sparse_mask=sparse_mask,
            padding_value=padding_value,
        )
        self.neg_inf = neg_inf
        self._mask = mask
        self._sparse_mask = sparse_mask
        self._padding_value = padding_value
        self._padding_side = padding_side
        super().__init__(logits=logits)
        self.num_samples = num_samples

    @property
    def padding_value(self):
        """Padding value of the distribution mask.

        If the padding value is not set, it will be inferred from the logits.
        """
        return self._padding_value if self._padding_value is not None else 0

    @property
    def padding_side(self):
        return self._padding_side

    @property
    def mask(self):
        if self._sparse_mask:
            raise ValueError("MaskedCategorical.mask does not support sparse masks")
        return self._mask

    def entropy(self):
        """Compute the entropy of the distribution.

        For masked distributions, we only consider the entropy over the valid (unmasked) outcomes.
        Invalid outcomes have zero probability and don't contribute to entropy.
        """
        min_real = torch.finfo(self.logits.dtype).min

        # Clamp logits to avoid numerical issues
        logits = self.logits
        if self._mask.dtype is torch.bool:
            mask = expand_as_right(self._mask, logits)
            mask = (~mask) | (~logits.isfinite())
            logits = torch.masked_fill(logits, mask, min_real)
        else:
            # logits are already masked
            pass
        logits = logits - logits.logsumexp(-1, keepdim=True)

        # Get probabilities and mask them
        probs = logits.exp()

        # Compute entropy only for valid outcomes
        p_log_p = logits * probs
        return -p_log_p.sum(-1)

    def sample(
        self, sample_shape: torch.Size | Sequence[int] | None = None
    ) -> torch.Tensor:
        if sample_shape is None:
            sample_shape = torch.Size()
        else:
            sample_shape = torch.Size(sample_shape)

        ret = super().sample(sample_shape)
        if not self._sparse_mask:
            return ret

        size = ret.size()
        outer_dim = sample_shape.numel()
        inner_dim = self._mask.shape[:-1].numel()
        idx_3d = self._mask.expand(outer_dim, inner_dim, -1)
        ret = idx_3d.gather(dim=-1, index=ret.view(outer_dim, inner_dim, 1))
        return ret.reshape(size)

    def log_prob(self, value: torch.Tensor) -> torch.Tensor:
        if not self._sparse_mask:
            if self.use_cross_entropy:
                logits = self.logits
                if logits.ndim > 2:
                    # Bring channels in 2nd dim
                    logits = logits.permute(0, -1, *range(1, logits.ndim - 1))
                original_value_shape = None
                if logits.ndim == 1 and value.ndim >= 1:
                    if value.ndim >= 2:
                        original_value_shape = value.shape
                        value = value.flatten()
                    logits = logits.unsqueeze(0).expand(value.shape + logits.shape)
                result = -torch.nn.functional.cross_entropy(logits, value, reduce=False)
                if original_value_shape is not None:
                    result = result.unflatten(0, original_value_shape)
            else:
                result = super().log_prob(value)
            result = torch.where(torch.isfinite(result), result, self.neg_inf)
            return result

        idx_3d = self._mask.view(1, -1, self._num_events)
        val_3d = value.view(-1, idx_3d.size(1), 1)
        mask = idx_3d == val_3d
        idx = mask.int().argmax(dim=-1, keepdim=True)
        idx = idx.view_as(value)
        if self.use_cross_entropy:
            logits = self.logits
            if logits.ndim > 2:
                # Bring channels in 2nd dim
                logits = logits.transpose(-1, 1)
            # possible shapes:
            # Don't work with cross_entropy (missing batch dimension)
            # logits.shape = (C,) and idx.shape = (B,)
            # logits.shape = (C,) and idx.shape = (B0, B1, ...) => requires flattening of idx, only one batch dimension
            # work with cross_entropy:
            # logits.shape = (B, C) and idx.shape = (B,)
            # logits.shape = (B, C, d1, d2, ...) and idx.shape = (B, d1, d2, ...)
            original_idx_shape = None
            if logits.ndim == 1 and idx.ndim >= 1:
                if idx.ndim >= 2:
                    original_idx_shape = idx.shape
                    idx = idx.flatten()
                logits = logits.unsqueeze(0).expand(idx.shape + logits.shape)
            ret = -torch.nn.functional.cross_entropy(logits, idx, reduce=False)
            if original_idx_shape is not None:
                ret = ret.unflatten(0, original_idx_shape)
        else:
            ret = super().log_prob(idx)
        # Fill masked values with neg_inf.
        ret = ret.view_as(val_3d)
        ret = ret.masked_fill(
            torch.logical_not(mask.any(dim=-1, keepdim=True)), self.neg_inf
        )
        return ret.view_as(value)

    @staticmethod
    def _mask_logits(
        logits: torch.Tensor,
        mask: torch.Tensor | None = None,
        neg_inf: float = float("-inf"),
        sparse_mask: bool = False,
        padding_value: int | None = None,
    ) -> torch.Tensor:
        if mask is None:
            return logits

        if not sparse_mask:
            return logits.masked_fill(~mask, neg_inf)

        if padding_value is not None:
            padding_mask = mask == padding_value
            if padding_value != 0:
                # Avoid invalid indices in mask.
                mask = mask.masked_fill(padding_mask, 0)
        logits = logits.gather(dim=-1, index=mask)
        if padding_value is not None:
            logits.masked_fill_(padding_mask, neg_inf)
        return logits

    @property
    def deterministic_sample(self):
        return self.mode
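The docstring above only exercises the boolean-mask form; the sparse form passes the valid indices instead and maps samples back to the original outcome space. A hedged usage sketch, assuming MaskedCategorical is exported from torchrl.modules as in previous releases (values shown are illustrative):

import torch
from torchrl.modules import MaskedCategorical  # assumed export path

# Sparse form: list the valid outcomes instead of masking all C of them.
logits = torch.randn(4)
valid = torch.tensor([0, 2, 3])                 # outcome 1 is never allowed
dist = MaskedCategorical(logits=logits, indices=valid)
print(dist.sample((5,)))                        # samples drawn from {0, 2, 3} only
print(dist.log_prob(torch.tensor([0, 2, 3])))   # finite log-probs for valid outcomes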
class MaskedOneHotCategorical(MaskedCategorical):
    """MaskedOneHotCategorical distribution.

    Reference:
        https://www.tensorflow.org/agents/api_docs/python/tf_agents/distributions/masked/MaskedCategorical

    Args:
        logits (torch.Tensor): event log probabilities (unnormalized)
        probs (torch.Tensor): event probabilities. If provided, the probabilities
            corresponding to masked items will be zeroed and the probability
            re-normalized along its last dimension.

    Keyword Args:
        mask (torch.Tensor): A boolean mask of the same shape as ``logits``/``probs``
            where ``False`` entries are the ones to be masked. Alternatively,
            if ``sparse_mask`` is True, it represents the list of valid indices
            in the distribution. Exclusive with ``indices``.
        indices (torch.Tensor): A dense index tensor representing which actions
            must be taken into account. Exclusive with ``mask``.
        neg_inf (:obj:`float`, optional): The log-probability value allocated to
            invalid (out-of-mask) indices. Defaults to -inf.
        padding_value: The padding value in the mask tensor. When
            sparse_mask == True, the padding_value will be ignored.
        grad_method (ReparamGradientStrategy, optional): strategy to gather
            reparameterized samples.
            ``ReparamGradientStrategy.PassThrough`` will compute the sample gradients
            by using the softmax valued log-probability as a proxy to the
            sample gradients.
            ``ReparamGradientStrategy.RelaxedOneHot`` will use
            :class:`torch.distributions.RelaxedOneHotCategorical` to sample from the distribution.

    Examples:
        >>> torch.manual_seed(0)
        >>> logits = torch.randn(4) / 100  # almost equal probabilities
        >>> mask = torch.tensor([True, False, True, True])
        >>> dist = MaskedOneHotCategorical(logits=logits, mask=mask)
        >>> sample = dist.sample((10,))
        >>> print(sample)  # no `1` in the sample
        tensor([[0, 0, 1, 0],
                [0, 0, 0, 1],
                [1, 0, 0, 0],
                [0, 0, 1, 0],
                [0, 0, 1, 0],
                [1, 0, 0, 0],
                [0, 0, 1, 0],
                [1, 0, 0, 0],
                [0, 0, 1, 0],
                [0, 0, 1, 0]])
        >>> print(dist.log_prob(sample))
        tensor([-1.1203, -1.0928, -1.0831, -1.1203, -1.1203, -1.0831, -1.1203, -1.0831,
                -1.1203, -1.1203])
        >>> sample_non_valid = torch.zeros_like(sample)
        >>> sample_non_valid[..., 1] = 1
        >>> print(dist.log_prob(sample_non_valid))
        tensor([-inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf])
        >>> # with probabilities
        >>> prob = torch.ones(10)
        >>> prob = prob / prob.sum()
        >>> mask = torch.tensor([False] + 9 * [True])  # first outcome is masked
        >>> dist = MaskedOneHotCategorical(probs=prob, mask=mask)
        >>> s = torch.arange(10)
        >>> s = torch.nn.functional.one_hot(s, 10)
        >>> print(dist.log_prob(s))
        tensor([   -inf, -2.1972, -2.1972, -2.1972, -2.1972, -2.1972, -2.1972, -2.1972,
                -2.1972, -2.1972])
    """

    @lazy_property
    def logits(self):
        return probs_to_logits(self.probs)

    @lazy_property
    def probs(self):
        return logits_to_probs(self.logits)

    def __init__(
        self,
        logits: torch.Tensor | None = None,
        probs: torch.Tensor | None = None,
        mask: torch.Tensor | None = None,
        indices: torch.Tensor | None = None,
        neg_inf: float = float("-inf"),
        padding_value: int | None = None,
        grad_method: ReparamGradientStrategy = ReparamGradientStrategy.PassThrough,
    ) -> None:
        self.grad_method = grad_method
        super().__init__(
            logits=logits,
            probs=probs,
            mask=mask,
            indices=indices,
            neg_inf=neg_inf,
            padding_value=padding_value,
        )

    @_one_hot_wrapper(MaskedCategorical)
    def sample(
        self, sample_shape: torch.Size | Sequence[int] | None = None
    ) -> torch.Tensor:
        ...

    @property
    def deterministic_sample(self):
        return self.mode

    @property
    def mode(self) -> torch.Tensor:
        if hasattr(self, "logits"):
            return (self.logits == self.logits.max(-1, True)[0]).to(torch.long)
        else:
            return (self.probs == self.probs.max(-1, True)[0]).to(torch.long)

    def log_prob(self, value: torch.Tensor) -> torch.Tensor:
        return super().log_prob(value.argmax(dim=-1))

    def rsample(self, sample_shape: torch.Size | Sequence | None = None) -> torch.Tensor:
        if sample_shape is None:
            sample_shape = torch.Size([])
        if hasattr(self, "logits") and self.logits is not None:
            logits = self.logits
            probs = None
        else:
            logits = None
            probs = self.probs
        if self.grad_method == ReparamGradientStrategy.RelaxedOneHot:
            if self._sparse_mask:
                if probs is not None:
                    probs_extended = torch.full(
                        (*probs.shape[:-1], self.num_samples),
                        0,
                        device=probs.device,
                        dtype=probs.dtype,
                    )
                    probs_extended = torch.scatter(
                        probs_extended, -1, self._mask, probs
                    )
                    logits_extended = None
                else:
                    probs_extended = torch.full(
                        (*logits.shape[:-1], self.num_samples),
                        self.neg_inf,
                        device=logits.device,
                        dtype=logits.dtype,
                    )
                    logits_extended = torch.scatter(
                        probs_extended, -1, self._mask, logits
                    )
                    probs_extended = None
            else:
                probs_extended = probs
                logits_extended = logits

            d = D.relaxed_categorical.RelaxedOneHotCategorical(
                1.0, probs=probs_extended, logits=logits_extended
            )
            out = d.rsample(sample_shape)
            out.data.copy_((out == out.max(-1)[0].unsqueeze(-1)).to(out.dtype))
            return out
        elif self.grad_method == ReparamGradientStrategy.PassThrough:
            if logits is not None:
                probs = self.probs
            else:
                probs = torch.softmax(self.logits, dim=-1)
            if self._sparse_mask:
                probs_extended = torch.full(
                    (*probs.shape[:-1], self.num_samples),
                    0,
                    device=probs.device,
                    dtype=probs.dtype,
                )
                probs_extended = torch.scatter(probs_extended, -1, self._mask, probs)
            else:
                probs_extended = probs

            out = self.sample(sample_shape)
            out = out + probs_extended - probs_extended.detach()
            return out
        else:
            raise ValueError(
                f"Unknown reparameterization strategy {self.grad_method}."
            )
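One point the docstring does not spell out: with a sparse indices mask, rsample scatters the reduced probabilities back to the full outcome size, so reparameterized samples are full-width one-hots with gradients attached and the masked column stays at zero. A small sketch, importing from the module path shown in this diff (illustrative values):

import torch
from torchrl.modules.distributions.discrete import (  # path taken from this file listing
    MaskedOneHotCategorical,
    ReparamGradientStrategy,
)

# Only outcomes 0, 2 and 3 are valid out of 4; PassThrough is the default strategy.
logits = torch.randn(4, requires_grad=True)
dist = MaskedOneHotCategorical(
    logits=logits,
    indices=torch.tensor([0, 2, 3]),
    grad_method=ReparamGradientStrategy.PassThrough,
)
sample = dist.rsample((3,))                  # shape (3, 4): hard one-hots forward
loss = (sample * torch.arange(4.0)).sum()
loss.backward()
print(sample[..., 1].sum(), logits.grad)     # masked column is 0; gradients reach the logits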
class Ordinal(D.Categorical):
    """A discrete distribution for learning to sample from finite ordered sets.

    It is defined in contrast with the `Categorical` distribution, which does
    not impose any notion of proximity or ordering over its support's atoms.
    The `Ordinal` distribution explicitly encodes those concepts, which is
    useful for learning discrete sampling from continuous sets. See §5 of
    `Tang & Agrawal, 2020 <https://arxiv.org/pdf/1901.10500.pdf>`_ for details.

    .. note::
        This class is mostly useful when you want to learn a distribution over
        a finite set which is obtained by discretising a continuous set.

    Args:
        scores (torch.Tensor): a tensor of shape [..., N] where N is the size of the set which supports the distributions.
            Typically, the output of a neural network parametrising the distribution.

    Examples:
        >>> num_atoms, num_samples = 5, 20
        >>> mean = (num_atoms - 1) / 2  # Target mean for samples, centered around the middle atom
        >>> torch.manual_seed(42)
        >>> logits = torch.ones((num_atoms), requires_grad=True)
        >>> optimizer = torch.optim.Adam([logits], lr=0.1)
        >>>
        >>> # Perform optimisation loop to minimise deviation from `mean`
        >>> for _ in range(20):
        >>>     sampler = Ordinal(scores=logits)
        >>>     samples = sampler.sample((num_samples,))
        >>>     # Define loss to encourage samples around the mean by penalising deviation from mean
        >>>     loss = torch.mean((samples - mean) ** 2 * sampler.log_prob(samples))
        >>>     loss.backward()
        >>>     optimizer.step()
        >>>     optimizer.zero_grad()
        >>>
        >>> sampler.probs
        tensor([0.0308, 0.1586, 0.4727, 0.2260, 0.1120], ...)
        >>> # Print histogram to observe sample distribution frequency across 5 bins (0, 1, 2, 3, and 4)
        >>> torch.histogram(sampler.sample((1000,)).reshape(-1).float(), bins=num_atoms)
        torch.return_types.histogram(
            hist=tensor([ 24., 158., 478., 228., 112.]),
            bin_edges=tensor([0.0000, 0.8000, 1.6000, 2.4000, 3.2000, 4.0000]))
    """

    def __init__(self, scores: torch.Tensor):
        logits = _generate_ordinal_logits(scores)
        super().__init__(logits=logits)


class OneHotOrdinal(OneHotCategorical):
    """The one-hot version of the :class:`~torchrl.modules.Ordinal` distribution.

    Args:
        scores (torch.Tensor): a tensor of shape [..., N] where N is the size of the set which supports the distributions.
            Typically, the output of a neural network parametrising the distribution.
    """

    def __init__(self, scores: torch.Tensor):
        logits = _generate_ordinal_logits(scores)
        super().__init__(logits=logits)


def _generate_ordinal_logits(scores: torch.Tensor) -> torch.Tensor:
    """Implements Eq. 4 of `Tang & Agrawal, 2020 <https://arxiv.org/pdf/1901.10500.pdf>`__."""
    # Assigns Bernoulli-like probabilities for each class in the set
    log_probs = F.logsigmoid(scores)
    complementary_log_probs = F.logsigmoid(-scores)

    # Total log-probability for being "larger than k"
    larger_than_log_probs = log_probs.cumsum(dim=-1)

    # Total log-probability for being "smaller than k"
    smaller_than_log_probs = (
        complementary_log_probs.flip(dims=[-1]).cumsum(dim=-1).flip(dims=[-1])
        - complementary_log_probs
    )

    return larger_than_log_probs + smaller_than_log_probs
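For intuition on Eq. 4: atom k receives logit_k = sum_{i<=k} log sigmoid(s_i) + sum_{i>k} log sigmoid(-s_i), i.e. the product of per-threshold "at least k" sigmoids and the complements above k, so neighbouring atoms receive similar mass. A short standalone numerical check of that identity (illustrative only, not part of the packaged file):

import torch
import torch.nn.functional as F

# Eq. 4 by hand: P(k) is proportional to prod_{i<=k} sigmoid(s_i) * prod_{i>k} sigmoid(-s_i).
scores = torch.tensor([2.0, 0.5, -1.0])
sig, csig = torch.sigmoid(scores), torch.sigmoid(-scores)
manual = torch.stack([
    sig[0] * csig[1] * csig[2],   # k = 0
    sig[0] * sig[1] * csig[2],    # k = 1
    sig[0] * sig[1] * sig[2],     # k = 2
])
manual = manual / manual.sum()

# Same quantity through the cumulative-logsigmoid formulation used above.
log_probs = F.logsigmoid(scores)
comp = F.logsigmoid(-scores)
larger = log_probs.cumsum(dim=-1)
smaller = comp.flip(dims=[-1]).cumsum(dim=-1).flip(dims=[-1]) - comp
print(torch.allclose((larger + smaller).softmax(-1), manual))  # True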
class LLMMaskedCategorical(D.Distribution):
    """LLM-optimized masked categorical distribution.

    This class provides a more memory-efficient approach for LLM training by:
    1. Using ignore_index=-100 for log_prob computation (no masking overhead)
    2. Using traditional masking for sampling operations

    This is particularly beneficial for large vocabulary sizes where masking
    all logits can be memory-intensive.

    Args:
        logits (torch.Tensor): Event log probabilities (unnormalized), shape [B, T, C].
            - *B*: batch size (optional)
            - T: sequence length
            - C: vocabulary size (number of classes)
        mask (torch.Tensor): Boolean mask indicating valid positions/tokens.
            - If shape [*B, T]: position-level masking. True means the position is valid (all tokens allowed).
            - If shape [*B, T, C]: token-level masking. True means the token is valid at that position.

            .. warning:: Token-level masking is considerably more memory-intensive than position-level masking.
                Only use this if you need to mask tokens.

        ignore_index (int, optional): Index to ignore in log_prob computation. Defaults to -100.

    Input shapes:
        - logits: [*B, T, C] (required)
        - mask: [*B, T] (position-level) or [*B, T, C] (token-level)
        - tokens (for log_prob): [*B, T] (token indices, with ignore_index for masked positions)

    Use cases:
        1. **Position-level masking**
            >>> logits = torch.randn(2, 10, 50000)  # [B=2, T=10, C=50000]
            >>> mask = torch.ones(2, 10, dtype=torch.bool)  # [B, T]
            >>> mask[0, :5] = False  # mask first 5 positions of first sequence
            >>> dist = LLMMaskedCategorical(logits=logits, mask=mask)
            >>> tokens = torch.randint(0, 50000, (2, 10))  # [B, T]
            >>> tokens[0, :5] = -100  # set masked positions to ignore_index
            >>> log_probs = dist.log_prob(tokens)
            >>> samples = dist.sample()  # [B, T]

        2. **Token-level masking**
            >>> logits = torch.randn(2, 10, 50000)
            >>> mask = torch.ones(2, 10, 50000, dtype=torch.bool)  # [B, T, C]
            >>> mask[0, :5, :1000] = False  # mask first 1000 tokens for first 5 positions
            >>> dist = LLMMaskedCategorical(logits=logits, mask=mask)
            >>> tokens = torch.randint(0, 50000, (2, 10))
            >>> # Optionally, set tokens at fully-masked positions to ignore_index
            >>> log_probs = dist.log_prob(tokens)
            >>> samples = dist.sample()  # [B, T]

    Notes:
        - For log_prob, tokens must be of shape [B, T] and contain valid token indices (0 <= token < C), or ignore_index for masked/ignored positions.
        - For token-level masking, if a token is masked at a given position, log_prob will return -inf for that entry.
        - For position-level masking, if a position is masked (ignore_index), log_prob will return 0.0 for that entry (correct for cross-entropy loss).
        - Sampling always respects the mask (masked tokens/positions are never sampled).

    All documented use cases are covered by tests in test_distributions.py.
    """

    def __init__(
        self,
        logits: torch.Tensor,
        mask: torch.Tensor,
        ignore_index: int = -100,
    ) -> None:
        # Validate shapes
        if logits.shape[:-1] != mask.shape and logits.shape != mask.shape:
            raise ValueError(
                f"Mask shape {mask.shape} must be either logits batch shape {logits.shape[:-1]} "
                f"(for position-level masking) or logits shape {logits.shape} "
                f"(for token-level masking)"
            )

        self._original_logits = logits
        self._mask = mask
        self.ignore_index = ignore_index
        self._position_level_masking = mask.shape == logits.shape[:-1]

        # Create masked logits for sampling (only when needed)
        self._masked_logits = None
        self._masked_dist = None

        # Set up distribution properties
        batch_shape = logits.shape[:-1]
        event_shape = logits.shape[-1:]
        super().__init__(batch_shape=batch_shape, event_shape=event_shape)

    @property
    def _sampling_logits(self):
        """Get masked logits for sampling operations."""
        if self._masked_logits is None:
            # Only create masked logits when needed for sampling
            large_neg = torch.finfo(self._original_logits.dtype).min

            if self._position_level_masking:
                # Position-level masking: expand mask to match logits shape
                mask_expanded = expand_as_right(self._mask, self._original_logits)
                self._masked_logits = self._original_logits.masked_fill(
                    ~mask_expanded, large_neg
                )
            else:
                # Token-level masking: direct masking
                self._masked_logits = self._original_logits.masked_fill(
                    ~self._mask, large_neg
                )
        return self._masked_logits

    @property
    def _sampling_dist(self):
        """Get masked distribution for sampling operations."""
        if self._masked_dist is None:
            self._masked_dist = D.Categorical(logits=self._sampling_logits)
        return self._masked_dist

    def log_prob(self, value: torch.Tensor) -> torch.Tensor:
        """Compute log probabilities using ignore_index approach.

        This is memory-efficient as it doesn't require masking the logits.
        The value tensor should use ignore_index for masked positions.
        """
        if not self._position_level_masking:
            logits = self.masked_logits
        else:
            # Use cross_entropy with ignore_index for efficiency

            # For position-level masking, keep the default behavior (0.0 for ignore_index)
            # This is correct for cross-entropy loss computation
            # For token-level masking, we need to check if specific tokens are masked

            logits = self._original_logits
            value = value.masked_fill(~self._mask, self.ignore_index)
        if value.ndim > 1:
            # Reshape for cross_entropy: (batch, seq_len, vocab) -> (batch*seq_len, vocab)
            logits_flat = logits.reshape(-1, logits.size(-1))
            value_flat = value.reshape(-1)

            # Compute cross_entropy with ignore_index
            log_probs_flat = -F.cross_entropy(
                logits_flat, value_flat, reduce=False, ignore_index=self.ignore_index
            )

            # Reshape back
            log_probs = log_probs_flat.reshape_as(value)
        else:
            log_probs = -F.cross_entropy(
                logits,
                value,
                reduce=False,
                ignore_index=self.ignore_index,
            )
        return log_probs

    def sample(
        self, sample_shape: torch.Size | Sequence[int] | None = None
    ) -> torch.Tensor:
        """Sample from the distribution using masked logits."""
        if sample_shape is None:
            sample_shape = torch.Size()
        return self._sampling_dist.sample(sample_shape)

    def rsample(
        self, sample_shape: torch.Size | Sequence[int] | None = None
    ) -> torch.Tensor:
        """Reparameterized sampling using masked logits."""
        # This would need to be implemented based on the specific reparameterization strategy
        # For now, fall back to regular sampling
        return self.sample(sample_shape)

    @property
    def mode(self) -> torch.Tensor:
        """Get the mode using masked logits."""
        masked_logits = self._sampling_logits
        return masked_logits.argmax(dim=-1)

    def entropy(self) -> torch.Tensor:
        """Compute entropy using masked logits."""
        return self._sampling_dist.entropy()

    def clear_cache(self):
        """Clear cached masked tensors to free memory."""
        self._masked_logits = None
        self._masked_dist = None

    @property
    def mask(self) -> torch.Tensor:
        """Get the mask."""
        return self._mask

    @property
    def logits(self) -> torch.Tensor:
        """Get the original logits."""
        return self._original_logits

    @property
    def probs(self) -> torch.Tensor:
        """Get probabilities from original logits."""
        return torch.softmax(self._original_logits, dim=-1)

    @property
    def masked_logits(self) -> torch.Tensor:
        """Get the masked logits for sampling operations."""
        return self._sampling_logits

    @property
    def masked_dist(self) -> D.Categorical:
        """Get the masked distribution for sampling operations."""
        return self._sampling_dist

    @property
    def position_level_masking(self) -> bool:
        """Whether the mask is position-level (True) or token-level (False)."""
        return self._position_level_masking
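The memory argument in the class docstring boils down to this: a position-level mask costs B*T booleans and lets log_prob run through cross_entropy with ignore_index on the original logits, whereas a token-level mask materialises B*T*C booleans plus a fully masked copy of the logits. A minimal position-level sketch, importing from the module path shown in this diff (shapes are illustrative):

import torch
from torchrl.modules.distributions.discrete import LLMMaskedCategorical  # path from this listing

# Position-level masking: ignore the first 3 positions of every sequence (e.g. the prompt).
B, T, C = 2, 8, 32000
logits = torch.randn(B, T, C)
valid = torch.ones(B, T, dtype=torch.bool)
valid[:, :3] = False
tokens = torch.randint(0, C, (B, T))
tokens[:, :3] = -100                       # ignore_index at the masked positions

dist = LLMMaskedCategorical(logits=logits, mask=valid)
lp = dist.log_prob(tokens)                 # 0.0 at ignored positions, log-prob elsewhere
print(lp.shape, lp[:, :3].abs().max())     # torch.Size([2, 8]) tensor(0.)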