torchrl-0.11.0-cp314-cp314-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/benchmark_batched_envs.py +104 -0
- benchmarks/conftest.py +91 -0
- benchmarks/ecosystem/gym_env_throughput.py +321 -0
- benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
- benchmarks/requirements.txt +7 -0
- benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
- benchmarks/test_collectors_benchmark.py +240 -0
- benchmarks/test_compressed_storage_benchmark.py +145 -0
- benchmarks/test_envs_benchmark.py +133 -0
- benchmarks/test_llm.py +101 -0
- benchmarks/test_non_tensor_env_benchmark.py +70 -0
- benchmarks/test_objectives_benchmarks.py +1199 -0
- benchmarks/test_replaybuffer_benchmark.py +254 -0
- sota-check/README.md +35 -0
- sota-implementations/README.md +142 -0
- sota-implementations/a2c/README.md +39 -0
- sota-implementations/a2c/a2c_atari.py +291 -0
- sota-implementations/a2c/a2c_mujoco.py +273 -0
- sota-implementations/a2c/utils_atari.py +240 -0
- sota-implementations/a2c/utils_mujoco.py +160 -0
- sota-implementations/bandits/README.md +7 -0
- sota-implementations/bandits/dqn.py +126 -0
- sota-implementations/cql/cql_offline.py +198 -0
- sota-implementations/cql/cql_online.py +249 -0
- sota-implementations/cql/discrete_cql_offline.py +180 -0
- sota-implementations/cql/discrete_cql_online.py +227 -0
- sota-implementations/cql/utils.py +471 -0
- sota-implementations/crossq/crossq.py +271 -0
- sota-implementations/crossq/utils.py +320 -0
- sota-implementations/ddpg/ddpg.py +231 -0
- sota-implementations/ddpg/utils.py +325 -0
- sota-implementations/decision_transformer/dt.py +163 -0
- sota-implementations/decision_transformer/lamb.py +167 -0
- sota-implementations/decision_transformer/online_dt.py +178 -0
- sota-implementations/decision_transformer/utils.py +562 -0
- sota-implementations/discrete_sac/discrete_sac.py +243 -0
- sota-implementations/discrete_sac/utils.py +324 -0
- sota-implementations/dqn/README.md +30 -0
- sota-implementations/dqn/dqn_atari.py +272 -0
- sota-implementations/dqn/dqn_cartpole.py +236 -0
- sota-implementations/dqn/utils_atari.py +132 -0
- sota-implementations/dqn/utils_cartpole.py +90 -0
- sota-implementations/dreamer/README.md +129 -0
- sota-implementations/dreamer/dreamer.py +586 -0
- sota-implementations/dreamer/dreamer_utils.py +1107 -0
- sota-implementations/expert-iteration/README.md +352 -0
- sota-implementations/expert-iteration/ei_utils.py +770 -0
- sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
- sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
- sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
- sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
- sota-implementations/gail/gail.py +327 -0
- sota-implementations/gail/gail_utils.py +68 -0
- sota-implementations/gail/ppo_utils.py +157 -0
- sota-implementations/grpo/README.md +273 -0
- sota-implementations/grpo/grpo-async.py +437 -0
- sota-implementations/grpo/grpo-sync.py +435 -0
- sota-implementations/grpo/grpo_utils.py +843 -0
- sota-implementations/grpo/requirements_gsm8k.txt +11 -0
- sota-implementations/grpo/requirements_ifeval.txt +16 -0
- sota-implementations/impala/README.md +33 -0
- sota-implementations/impala/impala_multi_node_ray.py +292 -0
- sota-implementations/impala/impala_multi_node_submitit.py +284 -0
- sota-implementations/impala/impala_single_node.py +261 -0
- sota-implementations/impala/utils.py +184 -0
- sota-implementations/iql/discrete_iql.py +230 -0
- sota-implementations/iql/iql_offline.py +164 -0
- sota-implementations/iql/iql_online.py +225 -0
- sota-implementations/iql/utils.py +437 -0
- sota-implementations/multiagent/README.md +74 -0
- sota-implementations/multiagent/iql.py +237 -0
- sota-implementations/multiagent/maddpg_iddpg.py +266 -0
- sota-implementations/multiagent/mappo_ippo.py +267 -0
- sota-implementations/multiagent/qmix_vdn.py +271 -0
- sota-implementations/multiagent/sac.py +337 -0
- sota-implementations/multiagent/utils/__init__.py +4 -0
- sota-implementations/multiagent/utils/logging.py +151 -0
- sota-implementations/multiagent/utils/utils.py +43 -0
- sota-implementations/ppo/README.md +29 -0
- sota-implementations/ppo/ppo_atari.py +305 -0
- sota-implementations/ppo/ppo_mujoco.py +293 -0
- sota-implementations/ppo/utils_atari.py +238 -0
- sota-implementations/ppo/utils_mujoco.py +152 -0
- sota-implementations/ppo_trainer/train.py +21 -0
- sota-implementations/redq/README.md +7 -0
- sota-implementations/redq/redq.py +199 -0
- sota-implementations/redq/utils.py +1060 -0
- sota-implementations/sac/sac-async.py +266 -0
- sota-implementations/sac/sac.py +239 -0
- sota-implementations/sac/utils.py +381 -0
- sota-implementations/sac_trainer/train.py +16 -0
- sota-implementations/td3/td3.py +254 -0
- sota-implementations/td3/utils.py +319 -0
- sota-implementations/td3_bc/td3_bc.py +177 -0
- sota-implementations/td3_bc/utils.py +251 -0
- torchrl/.dylibs/libc++.1.0.dylib +0 -0
- torchrl/__init__.py +144 -0
- torchrl/_extension.py +74 -0
- torchrl/_torchrl.cpython-314-darwin.so +0 -0
- torchrl/_utils.py +1431 -0
- torchrl/collectors/__init__.py +48 -0
- torchrl/collectors/_base.py +1058 -0
- torchrl/collectors/_constants.py +88 -0
- torchrl/collectors/_multi_async.py +324 -0
- torchrl/collectors/_multi_base.py +1805 -0
- torchrl/collectors/_multi_sync.py +464 -0
- torchrl/collectors/_runner.py +581 -0
- torchrl/collectors/_single.py +2009 -0
- torchrl/collectors/_single_async.py +259 -0
- torchrl/collectors/collectors.py +62 -0
- torchrl/collectors/distributed/__init__.py +32 -0
- torchrl/collectors/distributed/default_configs.py +133 -0
- torchrl/collectors/distributed/generic.py +1306 -0
- torchrl/collectors/distributed/ray.py +1092 -0
- torchrl/collectors/distributed/rpc.py +1006 -0
- torchrl/collectors/distributed/sync.py +731 -0
- torchrl/collectors/distributed/utils.py +160 -0
- torchrl/collectors/llm/__init__.py +10 -0
- torchrl/collectors/llm/base.py +494 -0
- torchrl/collectors/llm/ray_collector.py +275 -0
- torchrl/collectors/llm/utils.py +36 -0
- torchrl/collectors/llm/weight_update/__init__.py +10 -0
- torchrl/collectors/llm/weight_update/vllm.py +348 -0
- torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
- torchrl/collectors/utils.py +433 -0
- torchrl/collectors/weight_update.py +591 -0
- torchrl/csrc/numpy_utils.h +38 -0
- torchrl/csrc/pybind.cpp +27 -0
- torchrl/csrc/segment_tree.h +458 -0
- torchrl/csrc/torch_utils.h +34 -0
- torchrl/csrc/utils.cpp +48 -0
- torchrl/csrc/utils.h +31 -0
- torchrl/data/__init__.py +187 -0
- torchrl/data/datasets/__init__.py +58 -0
- torchrl/data/datasets/atari_dqn.py +878 -0
- torchrl/data/datasets/common.py +281 -0
- torchrl/data/datasets/d4rl.py +489 -0
- torchrl/data/datasets/d4rl_infos.py +187 -0
- torchrl/data/datasets/gen_dgrl.py +375 -0
- torchrl/data/datasets/minari_data.py +643 -0
- torchrl/data/datasets/openml.py +177 -0
- torchrl/data/datasets/openx.py +798 -0
- torchrl/data/datasets/roboset.py +363 -0
- torchrl/data/datasets/utils.py +11 -0
- torchrl/data/datasets/vd4rl.py +432 -0
- torchrl/data/llm/__init__.py +34 -0
- torchrl/data/llm/dataset.py +491 -0
- torchrl/data/llm/history.py +1378 -0
- torchrl/data/llm/prompt.py +198 -0
- torchrl/data/llm/reward.py +225 -0
- torchrl/data/llm/topk.py +186 -0
- torchrl/data/llm/utils.py +543 -0
- torchrl/data/map/__init__.py +21 -0
- torchrl/data/map/hash.py +185 -0
- torchrl/data/map/query.py +204 -0
- torchrl/data/map/tdstorage.py +363 -0
- torchrl/data/map/tree.py +1434 -0
- torchrl/data/map/utils.py +103 -0
- torchrl/data/postprocs/__init__.py +8 -0
- torchrl/data/postprocs/postprocs.py +391 -0
- torchrl/data/replay_buffers/__init__.py +99 -0
- torchrl/data/replay_buffers/checkpointers.py +622 -0
- torchrl/data/replay_buffers/ray_buffer.py +292 -0
- torchrl/data/replay_buffers/replay_buffers.py +2376 -0
- torchrl/data/replay_buffers/samplers.py +2578 -0
- torchrl/data/replay_buffers/scheduler.py +265 -0
- torchrl/data/replay_buffers/storages.py +2412 -0
- torchrl/data/replay_buffers/utils.py +1042 -0
- torchrl/data/replay_buffers/writers.py +781 -0
- torchrl/data/tensor_specs.py +7101 -0
- torchrl/data/utils.py +334 -0
- torchrl/envs/__init__.py +265 -0
- torchrl/envs/async_envs.py +1105 -0
- torchrl/envs/batched_envs.py +3093 -0
- torchrl/envs/common.py +4241 -0
- torchrl/envs/custom/__init__.py +11 -0
- torchrl/envs/custom/chess.py +617 -0
- torchrl/envs/custom/llm.py +214 -0
- torchrl/envs/custom/pendulum.py +401 -0
- torchrl/envs/custom/san_moves.txt +29274 -0
- torchrl/envs/custom/tictactoeenv.py +288 -0
- torchrl/envs/env_creator.py +263 -0
- torchrl/envs/gym_like.py +752 -0
- torchrl/envs/libs/__init__.py +68 -0
- torchrl/envs/libs/_gym_utils.py +326 -0
- torchrl/envs/libs/brax.py +846 -0
- torchrl/envs/libs/dm_control.py +544 -0
- torchrl/envs/libs/envpool.py +447 -0
- torchrl/envs/libs/gym.py +2239 -0
- torchrl/envs/libs/habitat.py +138 -0
- torchrl/envs/libs/isaac_lab.py +87 -0
- torchrl/envs/libs/isaacgym.py +203 -0
- torchrl/envs/libs/jax_utils.py +166 -0
- torchrl/envs/libs/jumanji.py +963 -0
- torchrl/envs/libs/meltingpot.py +599 -0
- torchrl/envs/libs/openml.py +153 -0
- torchrl/envs/libs/openspiel.py +652 -0
- torchrl/envs/libs/pettingzoo.py +1042 -0
- torchrl/envs/libs/procgen.py +351 -0
- torchrl/envs/libs/robohive.py +429 -0
- torchrl/envs/libs/smacv2.py +645 -0
- torchrl/envs/libs/unity_mlagents.py +891 -0
- torchrl/envs/libs/utils.py +147 -0
- torchrl/envs/libs/vmas.py +813 -0
- torchrl/envs/llm/__init__.py +63 -0
- torchrl/envs/llm/chat.py +730 -0
- torchrl/envs/llm/datasets/README.md +4 -0
- torchrl/envs/llm/datasets/__init__.py +17 -0
- torchrl/envs/llm/datasets/gsm8k.py +353 -0
- torchrl/envs/llm/datasets/ifeval.py +274 -0
- torchrl/envs/llm/envs.py +789 -0
- torchrl/envs/llm/libs/README.md +3 -0
- torchrl/envs/llm/libs/__init__.py +8 -0
- torchrl/envs/llm/libs/mlgym.py +869 -0
- torchrl/envs/llm/reward/__init__.py +10 -0
- torchrl/envs/llm/reward/gsm8k.py +324 -0
- torchrl/envs/llm/reward/ifeval/README.md +13 -0
- torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
- torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
- torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
- torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
- torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
- torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
- torchrl/envs/llm/transforms/__init__.py +55 -0
- torchrl/envs/llm/transforms/browser.py +292 -0
- torchrl/envs/llm/transforms/dataloading.py +859 -0
- torchrl/envs/llm/transforms/format.py +73 -0
- torchrl/envs/llm/transforms/kl.py +1544 -0
- torchrl/envs/llm/transforms/policy_version.py +189 -0
- torchrl/envs/llm/transforms/reason.py +323 -0
- torchrl/envs/llm/transforms/tokenizer.py +321 -0
- torchrl/envs/llm/transforms/tools.py +1955 -0
- torchrl/envs/model_based/__init__.py +9 -0
- torchrl/envs/model_based/common.py +180 -0
- torchrl/envs/model_based/dreamer.py +112 -0
- torchrl/envs/transforms/__init__.py +147 -0
- torchrl/envs/transforms/functional.py +48 -0
- torchrl/envs/transforms/gym_transforms.py +203 -0
- torchrl/envs/transforms/module.py +341 -0
- torchrl/envs/transforms/r3m.py +372 -0
- torchrl/envs/transforms/ray_service.py +663 -0
- torchrl/envs/transforms/rb_transforms.py +214 -0
- torchrl/envs/transforms/transforms.py +11835 -0
- torchrl/envs/transforms/utils.py +94 -0
- torchrl/envs/transforms/vc1.py +307 -0
- torchrl/envs/transforms/vecnorm.py +845 -0
- torchrl/envs/transforms/vip.py +407 -0
- torchrl/envs/utils.py +1718 -0
- torchrl/envs/vec_envs.py +11 -0
- torchrl/modules/__init__.py +206 -0
- torchrl/modules/distributions/__init__.py +73 -0
- torchrl/modules/distributions/continuous.py +830 -0
- torchrl/modules/distributions/discrete.py +908 -0
- torchrl/modules/distributions/truncated_normal.py +187 -0
- torchrl/modules/distributions/utils.py +233 -0
- torchrl/modules/llm/__init__.py +62 -0
- torchrl/modules/llm/backends/__init__.py +65 -0
- torchrl/modules/llm/backends/vllm/__init__.py +94 -0
- torchrl/modules/llm/backends/vllm/_models.py +46 -0
- torchrl/modules/llm/backends/vllm/base.py +72 -0
- torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
- torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
- torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
- torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
- torchrl/modules/llm/policies/__init__.py +28 -0
- torchrl/modules/llm/policies/common.py +1809 -0
- torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
- torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
- torchrl/modules/llm/utils.py +23 -0
- torchrl/modules/mcts/__init__.py +21 -0
- torchrl/modules/mcts/scores.py +579 -0
- torchrl/modules/models/__init__.py +86 -0
- torchrl/modules/models/batchrenorm.py +119 -0
- torchrl/modules/models/decision_transformer.py +179 -0
- torchrl/modules/models/exploration.py +731 -0
- torchrl/modules/models/llm.py +156 -0
- torchrl/modules/models/model_based.py +596 -0
- torchrl/modules/models/models.py +1712 -0
- torchrl/modules/models/multiagent.py +1067 -0
- torchrl/modules/models/recipes/impala.py +185 -0
- torchrl/modules/models/utils.py +162 -0
- torchrl/modules/planners/__init__.py +10 -0
- torchrl/modules/planners/cem.py +228 -0
- torchrl/modules/planners/common.py +73 -0
- torchrl/modules/planners/mppi.py +265 -0
- torchrl/modules/tensordict_module/__init__.py +89 -0
- torchrl/modules/tensordict_module/actors.py +2457 -0
- torchrl/modules/tensordict_module/common.py +529 -0
- torchrl/modules/tensordict_module/exploration.py +814 -0
- torchrl/modules/tensordict_module/probabilistic.py +321 -0
- torchrl/modules/tensordict_module/rnn.py +1639 -0
- torchrl/modules/tensordict_module/sequence.py +132 -0
- torchrl/modules/tensordict_module/world_models.py +34 -0
- torchrl/modules/utils/__init__.py +38 -0
- torchrl/modules/utils/mappings.py +9 -0
- torchrl/modules/utils/utils.py +89 -0
- torchrl/objectives/__init__.py +78 -0
- torchrl/objectives/a2c.py +659 -0
- torchrl/objectives/common.py +753 -0
- torchrl/objectives/cql.py +1346 -0
- torchrl/objectives/crossq.py +710 -0
- torchrl/objectives/ddpg.py +453 -0
- torchrl/objectives/decision_transformer.py +371 -0
- torchrl/objectives/deprecated.py +516 -0
- torchrl/objectives/dqn.py +683 -0
- torchrl/objectives/dreamer.py +488 -0
- torchrl/objectives/functional.py +48 -0
- torchrl/objectives/gail.py +258 -0
- torchrl/objectives/iql.py +996 -0
- torchrl/objectives/llm/__init__.py +30 -0
- torchrl/objectives/llm/grpo.py +846 -0
- torchrl/objectives/llm/sft.py +482 -0
- torchrl/objectives/multiagent/__init__.py +8 -0
- torchrl/objectives/multiagent/qmixer.py +396 -0
- torchrl/objectives/ppo.py +1669 -0
- torchrl/objectives/redq.py +683 -0
- torchrl/objectives/reinforce.py +530 -0
- torchrl/objectives/sac.py +1580 -0
- torchrl/objectives/td3.py +570 -0
- torchrl/objectives/td3_bc.py +625 -0
- torchrl/objectives/utils.py +782 -0
- torchrl/objectives/value/__init__.py +28 -0
- torchrl/objectives/value/advantages.py +1956 -0
- torchrl/objectives/value/functional.py +1459 -0
- torchrl/objectives/value/utils.py +360 -0
- torchrl/record/__init__.py +17 -0
- torchrl/record/loggers/__init__.py +23 -0
- torchrl/record/loggers/common.py +48 -0
- torchrl/record/loggers/csv.py +226 -0
- torchrl/record/loggers/mlflow.py +142 -0
- torchrl/record/loggers/tensorboard.py +139 -0
- torchrl/record/loggers/trackio.py +163 -0
- torchrl/record/loggers/utils.py +78 -0
- torchrl/record/loggers/wandb.py +214 -0
- torchrl/record/recorder.py +554 -0
- torchrl/services/__init__.py +79 -0
- torchrl/services/base.py +109 -0
- torchrl/services/ray_service.py +453 -0
- torchrl/testing/__init__.py +107 -0
- torchrl/testing/assertions.py +179 -0
- torchrl/testing/dist_utils.py +122 -0
- torchrl/testing/env_creators.py +227 -0
- torchrl/testing/env_helper.py +35 -0
- torchrl/testing/gym_helpers.py +156 -0
- torchrl/testing/llm_mocks.py +119 -0
- torchrl/testing/mocking_classes.py +2720 -0
- torchrl/testing/modules.py +295 -0
- torchrl/testing/mp_helpers.py +15 -0
- torchrl/testing/ray_helpers.py +293 -0
- torchrl/testing/utils.py +190 -0
- torchrl/trainers/__init__.py +42 -0
- torchrl/trainers/algorithms/__init__.py +11 -0
- torchrl/trainers/algorithms/configs/__init__.py +705 -0
- torchrl/trainers/algorithms/configs/collectors.py +216 -0
- torchrl/trainers/algorithms/configs/common.py +41 -0
- torchrl/trainers/algorithms/configs/data.py +308 -0
- torchrl/trainers/algorithms/configs/envs.py +104 -0
- torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
- torchrl/trainers/algorithms/configs/logging.py +80 -0
- torchrl/trainers/algorithms/configs/modules.py +570 -0
- torchrl/trainers/algorithms/configs/objectives.py +177 -0
- torchrl/trainers/algorithms/configs/trainers.py +340 -0
- torchrl/trainers/algorithms/configs/transforms.py +955 -0
- torchrl/trainers/algorithms/configs/utils.py +252 -0
- torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
- torchrl/trainers/algorithms/configs/weight_update.py +159 -0
- torchrl/trainers/algorithms/ppo.py +373 -0
- torchrl/trainers/algorithms/sac.py +308 -0
- torchrl/trainers/helpers/__init__.py +40 -0
- torchrl/trainers/helpers/collectors.py +416 -0
- torchrl/trainers/helpers/envs.py +573 -0
- torchrl/trainers/helpers/logger.py +33 -0
- torchrl/trainers/helpers/losses.py +132 -0
- torchrl/trainers/helpers/models.py +658 -0
- torchrl/trainers/helpers/replay_buffer.py +59 -0
- torchrl/trainers/helpers/trainers.py +301 -0
- torchrl/trainers/trainers.py +2052 -0
- torchrl/weight_update/__init__.py +33 -0
- torchrl/weight_update/_distributed.py +749 -0
- torchrl/weight_update/_mp.py +624 -0
- torchrl/weight_update/_noupdate.py +102 -0
- torchrl/weight_update/_ray.py +1032 -0
- torchrl/weight_update/_rpc.py +284 -0
- torchrl/weight_update/_shared.py +891 -0
- torchrl/weight_update/llm/__init__.py +32 -0
- torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
- torchrl/weight_update/llm/vllm_nccl.py +710 -0
- torchrl/weight_update/utils.py +73 -0
- torchrl/weight_update/weight_sync_schemes.py +1244 -0
- torchrl-0.11.0.dist-info/METADATA +1308 -0
- torchrl-0.11.0.dist-info/RECORD +395 -0
- torchrl-0.11.0.dist-info/WHEEL +5 -0
- torchrl-0.11.0.dist-info/entry_points.txt +2 -0
- torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
- torchrl-0.11.0.dist-info/top_level.txt +7 -0
torchrl/modules/llm/policies/transformers_wrapper.py
@@ -0,0 +1,2756 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

import contextlib
import threading
from contextlib import nullcontext
from copy import copy
from typing import Any, Literal

import torch
from tensordict import (
    lazy_stack,
    LazyStackedTensorDict,
    MetaData,
    NonTensorStack,
    set_list_to_stack,
    TensorDict,
    TensorDictBase,
)
from tensordict.utils import _zip_strict, NestedKey
from torch import distributions as D
from torch.nn.utils.rnn import pad_sequence
from torchrl import logger as torchrl_logger
from torchrl.modules.llm.policies.common import (
    _batching,
    _extract_responses_from_full_histories,
    ChatHistory,
    LLMWrapperBase,
    LogProbs,
    Masks,
    Text,
    Tokens,
)
from torchrl.modules.utils.utils import _unpad_tensors


class TransformersWrapper(LLMWrapperBase):
    """A wrapper class for Hugging Face Transformers models, providing a consistent interface for text generation and log probability computation.

    Packing vs Padding:
        - Packing (`pad_model_input=False`):
            * More memory efficient for variable-length sequences.
            * Not all models support packed input (requires custom attention masks and position ids).
            * May be less compatible with some HuggingFace models or custom architectures.
        - Padding (`pad_model_input=True`):
            * Universally supported by all models.
            * Wastes memory for short sequences in a batch.
            * Simpler, but less efficient for highly variable-length data.
        - If unsure, use padding for maximum compatibility. Use packing for large batches of variable-length data and when your model supports it.
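
    For example, a log-probability-only wrapper that relies on packing could be built as
    follows (an illustrative sketch; whether packing actually works depends on the model):

        >>> scorer = TransformersWrapper(
        ...     model,
        ...     tokenizer=tokenizer,
        ...     input_mode="history",
        ...     generate=False,           # score existing text, do not generate
        ...     pad_model_input=False,    # pack variable-length sequences
        ... )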

    Additional error handling is provided for empty and overlong sequences.

    Args:
        model (transformers.AutoModelForCausalLM | str): The Hugging Face Transformers model to wrap.
            If a string, it will be passed to `transformers.AutoModelForCausalLM.from_pretrained` (and `AutoTokenizer.from_pretrained`
            if `tokenizer` is not provided).

    Keyword Args:
        tokenizer (transformers.tokenization_utils.PreTrainedTokenizer | str | None, optional): The tokenizer to use for
            encoding and decoding text. If `None`, the tokenizer associated with the model will be used.
            If a string, it will be passed to `transformers.AutoTokenizer.from_pretrained`. Defaults to `None`.
        input_mode (str, optional): The input modality to use. Must be one of `"history"`, `"text"`, or `"tokens"`.
            Defaults to `"history"`.
        input_key (str | None, optional): The key for the input data. If `None`, defaults to
            - `("history", "prompt")` for `"history"` when `generate=True`, `("history", "full")` for `"history"` when `generate=False`
            - `("text", "prompt")` for `"text"` when `generate=True`, `("text", "full")` for `"text"` when `generate=False`
            - `("tokens", "prompt")` for `"tokens"` when `generate=True`, `("tokens", "full")` for `"tokens"` when `generate=False`
        attention_mask_key (str, optional): The key for attention masks (used in `"tokens"` mode). Defaults to `"attention_mask"`.

            .. warning:: This argument is under development and may change in the future.

        generate (bool, optional): Whether to enable text generation. If `True`, the model will generate text based on the input.
            If `False`, only log probabilities will be computed. Defaults to `True`.
        return_log_probs (bool, optional): Whether to return log probabilities. Defaults to `True`; must be `True` when `generate=False`.
        generate_kwargs (dict | None, optional): Additional arguments to pass to the model's generate method. Defaults to `None`.

            **Standardized Parameters (cross-backend compatible):**

            * **max_new_tokens** (int): Maximum number of new tokens to generate
            * **num_return_sequences** (int): Number of sequences to return
            * **temperature** (float): Sampling temperature (0.0 = deterministic, higher = more random)
            * **top_p** (float): Nucleus sampling parameter (0.0-1.0)
            * **top_k** (int): Top-k sampling parameter
            * **repetition_penalty** (float): Penalty for repeating tokens
            * **do_sample** (bool): Whether to use sampling vs greedy decoding
            * **num_beams** (int): Number of beams for beam search
            * **length_penalty** (float): Penalty for sequence length
            * **early_stopping** (bool): Whether to stop early in beam search
            * **stop_sequences** (list): Sequences that stop generation (requires custom stopping criteria)
            * **skip_special_tokens** (bool): Whether to skip special tokens in output
            * **logprobs** (bool): Whether to return log probabilities (maps to output_scores)

                .. warning:: Usage of this parameter is discouraged as it may conflict with the `generate` parameter
                    of the class.

            **Transformers-Specific Parameters:**

            * **pad_token_id** (int): Token ID for padding
            * **eos_token_id** (int): Token ID for end of sequence
            * **bad_words_ids** (list): List of token IDs to avoid
            * **force_words_ids** (list): List of token IDs to force
            * **no_repeat_ngram_size** (int): Size of n-grams to avoid repeating
            * **encoder_repetition_penalty** (float): Repetition penalty for encoder-decoder models
            * **num_beam_groups** (int): Number of beam groups for diverse beam search
            * **diversity_penalty** (float): Penalty for beam diversity
            * **output_scores** (bool): Whether to output scores
            * **return_dict_in_generate** (bool): Whether to return dict in generate

            **Legacy Parameter Support:**

            * **max_tokens** (int): Automatically converted to max_new_tokens
            * **n** (int): Automatically converted to num_return_sequences

            **Parameter Conflict Resolution:**

            When both legacy and standardized parameter names are provided,
            a :exc:`ValueError` is raised to prevent confusion. For example:

            * If both ``max_tokens`` and ``max_new_tokens`` are passed, an error is raised
            * If both ``n`` and ``num_return_sequences`` are passed, an error is raised

            This ensures clear parameter usage and prevents unexpected behavior.

        tokenizer_kwargs (dict | None, optional): Additional arguments to pass to the tokenizer. Defaults to `None`.
        pad_output (bool, optional): Whether to pad the output sequences to a uniform length. This does not impact the underlying padding
            during the call to the model. To use padding or packing during the model `forward` call, see `pad_model_input`.
            Defaults to `False`.
        pad_model_input (bool, optional): Whether to pad the model input sequences to a uniform length.
            If `False`, packing will be used instead. Packing is generally more memory efficient than padding,
            but this feature may not work with all models.
            `pad_model_input` can only be used when `generate=False`.
            This does not impact the padding of the model output - one may ask for padded output through `pad_output=True` while the model
            is called with `pad_model_input=False`.
            Defaults to `True`.
        inplace (Literal[True, False, "empty"] | None, optional): Determines how the module should handle in-place operations. Defaults to `True`.
        device (torch.device | None, optional): The device to use for computation. Defaults to `None`.
        layout (torch.layout | None, optional): The layout to use for the output tensors when `pad_output=False`. Defaults to `torch.strided`.
        num_samples (int | None, optional): The number of samples to generate. Defaults to `None` (one sample, and no batch-dimension for it).
            Can also be set via the `generate_kwargs["num_return_sequences"] = value` argument. Requires the "do_sample" argument to be set to `True` in `generate_kwargs`.
        chat_template_name (Literal["chatml_format", "qwen"] | None, optional): The name of the chat template to use when applying the chat
            template to the history. Defaults to `None`. For `input_mode="history"` only.
        chat_template (str | None, optional): The chat template to use when applying the chat template to the history.
            Defaults to `None`. For `input_mode="history"` only.
        log_probs_key (NestedKey | None, optional): The key for the log probabilities :class:`~torchrl.modules.llm.policies.LogProbs` object. Defaults to `"log_probs"`.
        text_key (NestedKey | None, optional): The key for the action :class:`~torchrl.modules.llm.policies.Text` object. Defaults to `"text"`.
        tokens_key (NestedKey | None, optional): The key for the action :class:`~torchrl.modules.llm.policies.Tokens` object. Defaults to `"tokens"`.
        masks_key (NestedKey | None, optional): The key for the action :class:`~torchrl.modules.llm.policies.Masks` object. Defaults to `"masks"`.
        history_key (NestedKey | None, optional): The key for the action :class:`~torchrl.modules.llm.policies.ChatHistory` object. Defaults to `"history"`.
        batching (bool | None, optional): Whether to enable batching. See `Batching`_ below for more details.
        min_batch_size (int | None, optional): The minimum batch size to use for batching. See `Batching`_ below for more details.
        max_batch_size (int | None, optional): The maximum batch size to use for batching. See `Batching`_ below for more details.
        batching_timeout (float, optional): The timeout for batching. See `Batching`_ below for more details.

    .. _Batching:

    **Batching**

    Batching is a feature that allows the module to process multiple inputs in a single call.
    It is designed to work in a multi-threaded environment.
    To enable batching, it suffices to set `batching=True`, which will set `min_batch_size` to 1 if not provided.
    If you want a different value for `min_batch_size` or `max_batch_size` for fine-grained control,
    you can set `batching=True` and then set `min_batch_size` or `max_batch_size` to a value greater than or equal to 1.
    The way batching works is as follows:
    - If `min_batch_size` is not provided but `max_batch_size` is, `min_batch_size` is set to 1.
    - If `max_batch_size` is not provided but `min_batch_size` is, `max_batch_size` is set to the number of inputs in the queue.
    - When the model is called, a check is performed to see if the number of inputs in the queue is greater than or equal to `min_batch_size`.
      If it is, the batch is processed immediately, while waiting for the previous batch to be processed if the model is busy.
      Otherwise, the input is added to the queue and the function waits for the batch to be completed.
    While waiting for the batch to be completed, a timeout of `batching_timeout` seconds applies: if the batch is not
    completed after `batching_timeout` seconds, the remaining items are processed as-is and the function returns after
    at most `batching_timeout` seconds (plus the time to finish processing the previous and current batch).
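
    For the behavior described above, a minimal sketch (batch bounds and timeout are
    illustrative values only):

        >>> batched_wrapper = TransformersWrapper(
        ...     model,
        ...     tokenizer=tokenizer,
        ...     batching=True,          # min_batch_size defaults to 1
        ...     max_batch_size=8,
        ...     batching_timeout=5.0,
        ... )
        >>> # Calls issued concurrently from several threads are grouped into shared forward passes.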

    Input Keys:
        The input key depends on both `input_mode` and `generate`:

        - If `input_mode="history"` and `generate=True`: `input_key` (defaults to `("history", "prompt")`)
        - If `input_mode="history"` and `generate=False`: `input_key` (defaults to `("history", "full")`)
        - If `input_mode="text"` and `generate=True`: `input_key` (defaults to `("text", "prompt")`)
        - If `input_mode="text"` and `generate=False`: `input_key` (defaults to `("text", "full")`)
        - If `input_mode="tokens"` and `generate=True`: `input_key` (defaults to `("tokens", "prompt")`)
        - If `input_mode="tokens"` and `generate=False`: `input_key` (defaults to `("tokens", "full")`)

    Output Keys:
        The output keys are automatically determined based on the input_mode:
        - **Tokens**: Always returned (`tokens_key`, defaults to `"tokens"`)
        - **Text**: Returned for `"text"` and `"history"` modes (`text_key`, defaults to `"text"`)
        - **History**: Returned only for `"history"` mode (`history_key`, defaults to `"history"`)
        - **Masks**: Always returned (`masks_key`, defaults to `"masks"`)
        - **Log Probs**: Returned when `return_log_probs=True` (`log_probs_key`, defaults to `"log_probs"`)

        Example output structure for `input_mode="history"`::

            TensorDict(
                text=Text(prompt=..., response=..., full=...),
                masks=Masks(all_attention_mask=..., all_assistant_mask=...),
                tokens=Tokens(prompt=..., response=..., full=...),
                log_probs=LogProbs(prompt=..., response=..., full=...),
                history=ChatHistory(prompt=..., response=..., full=...)
            )

    Example:
        >>> from transformers import AutoModelForCausalLM, AutoTokenizer
        >>> from torchrl.data.llm import History
        >>> from torchrl.modules.llm.policies import ChatHistory
        >>>
        >>> model = AutoModelForCausalLM.from_pretrained("gpt2")
        >>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
        >>>
        >>> # History input (recommended for RL environments)
        >>> wrapper = TransformersWrapper(
        ...     model,
        ...     tokenizer=tokenizer,
        ...     input_mode="history",
        ...     generate=True,
        ...     return_log_probs=True,
        ...     generate_kwargs={
        ...         "max_new_tokens": 50,  # Standardized parameter
        ...         "temperature": 0.7,
        ...         "top_p": 0.9,
        ...         "do_sample": True,
        ...     }
        ... )
        >>>
        >>> history = History.from_chats([[
        ...     {"role": "user", "content": "Hello"},
        ...     {"role": "assistant", "content": "Hi there!"}
        ... ]])
        >>> chat_history = ChatHistory(prompt=history)
        >>> result = wrapper(TensorDict(history=chat_history, batch_size=(1,)))
        >>> print(result["text"].response)  # Generated text
        >>> print(result["log_probs"].response)  # Log probabilities
        >>> print(result["history"].response)  # History with response
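        >>>
        >>> # A log-probabilities-only variant can reuse the same model weights
        >>> # (a sketch: with generate=False the wrapper only scores the provided text)
        >>> lp_wrapper = wrapper.get_new_version(generate=False)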

    Attributes:
        collector: The collector associated with the module, if it exists.

    .. seealso::
        - :class:`~torchrl.modules.llm.policies.LLMWrapperBase`
        - :class:`~torchrl.modules.llm.policies.vLLMWrapper`
    """

    def __init__(
        self,
        model,
        *,
        tokenizer=None,
        input_mode: str = "history",
        input_key: str | None = None,
        attention_mask_key: str = "attention_mask",
        generate: bool = True,
        generate_kwargs: dict | None = None,
        tokenizer_kwargs: dict | None = None,
        pad_output: bool = False,
        pad_model_input: bool | None = None,
        inplace: Literal[True, False, "empty"] | None = None,
        device: torch.device | None = None,
        layout: torch.layout | None = None,
        num_samples: int | None = None,
        chat_template_name: Literal["chatml_format", "qwen"] | None = None,
        chat_template: str | None = None,
        return_log_probs: bool | None = None,
        history_key: NestedKey | None = "history",
        text_key: NestedKey | None = "text",
        tokens_key: NestedKey | None = "tokens",
        masks_key: NestedKey | None = "masks",
        log_probs_key: NestedKey | None = "log_probs",
        batching: bool | None = None,
        min_batch_size: int | None = None,
        max_batch_size: int | None = None,
        batching_timeout: float = 10.0,
    ):
        super().__init__()

        if batching and min_batch_size is None:
            min_batch_size = 1
        elif (min_batch_size is not None or max_batch_size is not None) and (
            batching is False
        ):
            raise ValueError(
                "min_batch_size and max_batch_size must be None if batching is False."
            )

        # Validate that min_batch_size <= max_batch_size when both are specified
        if min_batch_size is not None and max_batch_size is not None:
            if min_batch_size > max_batch_size:
                raise ValueError(
                    f"min_batch_size ({min_batch_size}) must be <= max_batch_size ({max_batch_size})"
                )

        self._min_batch_size = min_batch_size
        self._max_batch_size = max_batch_size
        self._batching_timeout = batching_timeout
        self._batch_queue = []
        self._futures = []
        if self.batching:
            self._batching_lock = threading.Lock()
        else:
            self._batching_lock = None
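
        # `model` and `tokenizer` may also be given as string identifiers (hub names
        # or local paths); in that case they are instantiated via `from_pretrained`.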
        if isinstance(model, str):
            if tokenizer is None:
                from transformers import AutoTokenizer

                tokenizer = AutoTokenizer.from_pretrained(model)

            from transformers import AutoModelForCausalLM

            model = AutoModelForCausalLM.from_pretrained(model)

        if isinstance(tokenizer, str):
            from transformers import AutoTokenizer

            tokenizer = AutoTokenizer.from_pretrained(tokenizer)

        # Validate input_mode
        if input_mode not in ["history", "text", "tokens"]:
            raise ValueError(
                f"input_mode must be one of 'history', 'text', 'tokens'. Got '{input_mode}'"
            )

        self.model = model
        self.input_mode = input_mode
        self.attention_mask_key = attention_mask_key
        self.generate = generate
        if pad_model_input is not None and generate:
            raise ValueError("pad_model_input is not supported when generate=True.")
        pad_model_input = pad_model_input if pad_model_input is not None else True
        self.pad_model_input = pad_model_input

        # Auto-determine what to return based on input mode
        self.return_history = input_mode in ("history",)
        self.return_text = input_mode in ("text", "history")
        self.return_tokens = input_mode in ("tokens", "history", "text")
        self.return_masks = True
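        # Log-probs are mandatory when generate=False (the wrapper then only scores
        # the input); when generating, they default to True unless explicitly disabled.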
        if return_log_probs is False and not generate:
            raise ValueError("return_log_probs must be True when generate=False.")
        return_log_probs = (
            True
            if (return_log_probs is None and generate) or (not generate)
            else bool(return_log_probs)
        )
        self.return_log_probs = return_log_probs

        self.history_key = history_key
        self.text_key = text_key
        self.tokens_key = tokens_key
        self.masks_key = masks_key
        self.log_probs_key = log_probs_key
        if not isinstance(pad_output, bool):
            raise ValueError("pad_output must be a boolean")
        self.pad_output = pad_output
        self._device = device
        if not pad_output and layout is None:
            layout = torch.strided
        self.layout = layout
        padding_value = None

        # Auto-determine input_key if not provided

        # Set input keys based on mode and generate parameter
        if input_mode == "history":
            if generate:
                self.in_keys = [
                    ("history", "prompt") if input_key is None else input_key
                ]
            else:
                self.in_keys = [("history", "full") if input_key is None else input_key]
        elif input_mode == "text":
            if generate:
                self.in_keys = [("text", "prompt") if input_key is None else input_key]
            else:
                self.in_keys = [("text", "full") if input_key is None else input_key]
        elif input_mode == "tokens":
            if generate:
                self.in_keys = [
                    ("tokens", "prompt") if input_key is None else input_key
                ]
            else:
                self.in_keys = [("tokens", "full") if input_key is None else input_key]
        self.input_key = self.in_keys[0]

        # Set output keys based on auto-determined return flags
        self.out_keys = []
        if self.return_text:
            self.out_keys.append(self.text_key)
        if self.return_masks:
            self.out_keys.append(self.masks_key)
        if self.return_tokens:
            self.out_keys.append(self.tokens_key)
        if self.return_log_probs:
            self.out_keys.append(self.log_probs_key)
        if self.return_history:
            self.out_keys.append(self.history_key)

        # Tokenizer setup
        if not tokenizer_kwargs:
            tokenizer_kwargs = {}
        else:
            tokenizer_kwargs = dict(tokenizer_kwargs)
        if not tokenizer_kwargs.setdefault("return_attention_mask", True):
            raise RuntimeError("return_attention_mask must be True")

        # We always pad, so we always return tensors
        return_tensors = "pt"
        tokenizer_kwargs.setdefault("padding", True)
        if return_tensors:
            if (
                tokenizer_kwargs.setdefault("return_tensors", return_tensors)
                != return_tensors
            ):
                raise RuntimeError

        # We always pad atm
        if tokenizer_kwargs.setdefault("padding_side", "left") != "left":
            raise RuntimeError

        self.tokenizer_kwargs = tokenizer_kwargs

        # Get tokenizer if needed
        if (
            pad_output or (input_mode in ["text", "history"] and not generate)
        ) and tokenizer is None:
            tokenizer = model.get_tokenizer()
        self.tokenizer = tokenizer

        if self.tokenizer is not None and (
            not hasattr(self.tokenizer, "pad_token") or self.tokenizer.pad_token is None
        ):
            self.tokenizer.pad_token = self.tokenizer.eos_token
        if self.tokenizer is not None:
            padding_value = self.tokenizer(self.tokenizer.pad_token)["input_ids"][0]
        self.padding_value = padding_value

        # Generate kwargs setup
        if generate_kwargs is None:
            generate_kwargs = {}
        else:
            generate_kwargs = dict(generate_kwargs)

        # Standardize common parameters
        generate_kwargs = self._standardize_generate_kwargs(generate_kwargs)

        # Extract wrapper-specific parameters
        transformers_specific_kwargs = self._get_wrapper_specific_kwargs(
            generate_kwargs, "transformers"
        )

        # Convert common parameters to Transformers format
        transformers_kwargs = {}
        for key, value in generate_kwargs.items():
            if key in self.COMMON_GENERATION_PARAMS:
                # Convert common names to Transformers names
                if key == "stop_sequences":
                    # Transformers uses stopping_criteria for stop sequences
                    # This requires custom stopping criteria implementation
                    # For now, we'll warn and skip this parameter
                    import warnings

                    warnings.warn(
                        "stop_sequences parameter is not yet fully supported in TransformersWrapper. "
                        "Use eos_token_id or implement custom stopping criteria for full support.",
                        UserWarning,
                        stacklevel=2,
                    )
                    continue
                elif key == "logprobs":
                    transformers_kwargs["output_scores"] = value
                else:
                    # Direct mapping for other common parameters
                    transformers_kwargs[key] = value

        # Add Transformers-specific parameters
        transformers_kwargs.update(transformers_specific_kwargs)
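
        # Generating several sequences per prompt adds a sample dimension to the output,
        # so results cannot be written back in-place into the input tensordict.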
        self.num_samples = num_samples
        if (
            transformers_kwargs.get("num_return_sequences", 1) > 1
            or num_samples is not None
        ):
            if inplace in (True, "empty"):
                raise ValueError(
                    "inplace must be False (or None) when generating more than one sample."
                )
            if inplace is None:
                inplace = False
            if (
                transformers_kwargs.get("num_return_sequences", 1) > 1
                and num_samples is not None
                and transformers_kwargs.get("num_return_sequences", 1) != num_samples
            ):
                raise ValueError("num_samples differs from generate_kwargs['n'].")
            elif num_samples is None:
                self.num_samples = transformers_kwargs.get("num_return_sequences", 1)
            transformers_kwargs["num_return_sequences"] = self.num_samples
        elif inplace is None:
            inplace = True

        self.inplace = inplace

        if not generate:
            # We want only the log-probs, we generate a single token (that we then discard)
            # and retrieve the prompt log-probs
            transformers_kwargs["max_new_tokens"] = 1

        transformers_kwargs.setdefault("tokenizer", self.tokenizer)
        transformers_kwargs.setdefault("output_logits", self.return_log_probs)
        transformers_kwargs.setdefault("return_dict_in_generate", True)

        self.generate_kwargs = transformers_kwargs

        # Additional transformers-specific settings
        self.chat_template_name = chat_template_name
        self.chat_template = chat_template

        # Flag to track when we're in a get_dist call
        self._in_get_dist_call = False

    def get_new_version(self, **kwargs):
        """Returns a new version of the module with altered parameters.

        For instance, the generate parameter can be altered to enable text generation or log-probabilities computation.
        This is especially useful when one wants to avoid re-initializing the module with a new set of parameters, when the
        same parameters could be used to gather log-probs.

        Positional arguments are not supported.

        See the class constructor for more details about the parameters.
        """
        # Build the constructor arguments by using current values for missing parameters
        constructor_kwargs = {}

        # Model is always required
        constructor_kwargs["model"] = kwargs.get("model", self.model)

        # Check for each parameter and use current value if not provided
        if "tokenizer" in kwargs:
            constructor_kwargs["tokenizer"] = kwargs["tokenizer"]
        elif hasattr(self, "tokenizer"):
            constructor_kwargs["tokenizer"] = self.tokenizer

        if "input_mode" in kwargs:
            constructor_kwargs["input_mode"] = kwargs["input_mode"]
        elif hasattr(self, "input_mode"):
            constructor_kwargs["input_mode"] = self.input_mode

        if "input_key" in kwargs:
            constructor_kwargs["input_key"] = kwargs["input_key"]
        elif hasattr(self, "input_key"):
            constructor_kwargs["input_key"] = self.input_key

        if "attention_mask_key" in kwargs:
            constructor_kwargs["attention_mask_key"] = kwargs["attention_mask_key"]
        elif hasattr(self, "attention_mask_key"):
            constructor_kwargs["attention_mask_key"] = self.attention_mask_key

        if "generate" in kwargs:
            constructor_kwargs["generate"] = kwargs["generate"]
        elif hasattr(self, "generate"):
            constructor_kwargs["generate"] = self.generate

        if "generate_kwargs" in kwargs:
            constructor_kwargs["generate_kwargs"] = kwargs["generate_kwargs"]
        elif hasattr(self, "generate_kwargs"):
            constructor_kwargs["generate_kwargs"] = self.generate_kwargs

        if "pad_output" in kwargs:
            constructor_kwargs["pad_output"] = kwargs["pad_output"]
        elif hasattr(self, "pad_output"):
            constructor_kwargs["pad_output"] = self.pad_output

        if "tokenizer_kwargs" in kwargs:
            constructor_kwargs["tokenizer_kwargs"] = kwargs["tokenizer_kwargs"]
        elif hasattr(self, "tokenizer_kwargs"):
            constructor_kwargs["tokenizer_kwargs"] = self.tokenizer_kwargs
            if (
                "pad_output" in kwargs
                and kwargs.get("pad_output")
                != constructor_kwargs["tokenizer_kwargs"]["padding"]
            ):
                constructor_kwargs["tokenizer_kwargs"]["padding"] = kwargs.get(
                    "pad_output"
                )

        if "inplace" in kwargs:
            constructor_kwargs["inplace"] = kwargs["inplace"]
        elif hasattr(self, "inplace"):
            constructor_kwargs["inplace"] = self.inplace

        if "device" in kwargs:
            constructor_kwargs["device"] = kwargs["device"]
        elif hasattr(self, "_device"):
            constructor_kwargs["device"] = self._device

        if "layout" in kwargs:
            constructor_kwargs["layout"] = kwargs["layout"]
        elif hasattr(self, "layout"):
            constructor_kwargs["layout"] = self.layout

        if "num_samples" in kwargs:
            constructor_kwargs["num_samples"] = kwargs["num_samples"]
        elif hasattr(self, "num_samples"):
            constructor_kwargs["num_samples"] = self.num_samples

        if "chat_template_name" in kwargs:
            constructor_kwargs["chat_template_name"] = kwargs["chat_template_name"]
        elif hasattr(self, "chat_template_name"):
            constructor_kwargs["chat_template_name"] = self.chat_template_name

        if "chat_template" in kwargs:
            constructor_kwargs["chat_template"] = kwargs["chat_template"]
        elif hasattr(self, "chat_template"):
            constructor_kwargs["chat_template"] = self.chat_template

        if "text_key" in kwargs:
            constructor_kwargs["text_key"] = kwargs["text_key"]
        elif hasattr(self, "text_key"):
            constructor_kwargs["text_key"] = self.text_key

        if "tokens_key" in kwargs:
            constructor_kwargs["tokens_key"] = kwargs["tokens_key"]
        elif hasattr(self, "tokens_key"):
            constructor_kwargs["tokens_key"] = self.tokens_key

        if "masks_key" in kwargs:
            constructor_kwargs["masks_key"] = kwargs["masks_key"]
        elif hasattr(self, "masks_key"):
            constructor_kwargs["masks_key"] = self.masks_key

        if "log_probs_key" in kwargs:
            constructor_kwargs["log_probs_key"] = kwargs["log_probs_key"]
        elif hasattr(self, "log_probs_key"):
            constructor_kwargs["log_probs_key"] = self.log_probs_key

        # Create and return new instance
        return type(self)(**constructor_kwargs)

    @set_list_to_stack(True)
    @_batching
    def forward(
        self,
        tensordict: TensorDictBase,
        *,
        tensordict_out: TensorDictBase | None = None,
        logits_only: bool = False,
        **kwargs,
    ) -> TensorDictBase:
        tensordict_orig = tensordict
        if not tensordict.ndim:
            if tensordict_out is not None:
                raise ValueError(
                    "tensordict_out must not be provided when tensordict.ndim == 0. If this is needed, "
                    "please submit an issue on github."
                )
            # unsqueeze - squeeze the input
            return self.forward(lazy_stack([tensordict]), logits_only=logits_only)[0]
        elif tensordict.ndim > 1:
            if tensordict_out is not None:
                raise ValueError(
                    "tensordict_out must not be provided when tensordict.ndim > 1. If this is needed, "
                    "please submit an issue on github."
                )
            return self.forward(tensordict.reshape(-1), logits_only=logits_only).view(
                tensordict.shape
            )

        if not isinstance(tensordict, LazyStackedTensorDict):
            tensordict = tensordict.to_lazystack(0)

        _source_device = None
        if self._device:
            _source_device = tensordict.device
        if tensordict.device:
            tensordict = tensordict.copy().clear_device_()

        if kwargs:
            from transformers import GenerationConfig

            cfg = GenerationConfig(**kwargs)
        else:
            cfg = None
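
        # Pre-allocate the output as a lazily stacked TensorDict; when num_samples is
        # set, an extra sample dimension is inserted at dim 1 of the batch shape.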
        if self.num_samples is not None:
            out = (
                TensorDict(
                    device=tensordict.device,
                    batch_size=(
                        tensordict.batch_size[0],
                        self.num_samples,
                        *tensordict.batch_size[1:],
                    ),
                )
                .to_lazystack(1)
                .to_lazystack(0)
            )
        else:
            out = TensorDict(
                device=tensordict.device, batch_size=tensordict.batch_size
            ).to_lazystack(0)

        if self.input_mode == "history":
            if self.generate:
                out = self._from_transformers_generate_history(tensordict, cfg, out)
            else:
                out = self._from_transformers_logprobs_history(
                    tensordict, cfg, out, logits_only=logits_only
                )
        elif self.input_mode == "text":
            if self.generate:
                out = self._from_transformers_generate_text(tensordict, cfg, out)
            else:
                out = self._from_transformers_logprobs_text(
                    tensordict, cfg, out, logits_only=logits_only
                )
        elif self.input_mode == "tokens":
            if self.generate:
                out = self._from_transformers_generate_tokens(tensordict, cfg, out)
            else:
                out = self._from_transformers_logprobs_tokens(
                    tensordict, cfg, out, logits_only=logits_only
                )

        if _source_device:
            out = out.to(_source_device)
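
        # When no tensordict_out is passed, the `inplace` setting decides where results
        # are written: True writes back into the input, False returns the fresh structure,
        # and "empty" fills an empty tensordict with the output keys only.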
        if tensordict_out is None:
            if self.inplace is True:
                # The output is the input
                tensordict_out = tensordict_orig
            elif self.inplace is False:
                # The output is the new structure
                tensordict_out = out
            elif self.inplace == "empty":
                # The output is empty
                tensordict_out = tensordict.empty()

        if tensordict_out is not None and tensordict_out is not out:
            result = tensordict_out.exclude(*self.out_keys, inplace=True)
            result.update(out, keys_to_update=self.out_keys)
        elif tensordict_out is out:
            result = out.select(*self.out_keys)
        elif self.inplace:
            result = out
            keys = list(set(self.out_keys + list(tensordict.keys(True, True))))
            result = tensordict.exclude(*self.out_keys, inplace=True).update(
                result, keys_to_update=keys
            )
        else:
            result = out
        return result

    def _from_transformers_generate_history(self, td, cfg, out) -> TensorDictBase:
        """Generate text from history input."""
        from torchrl.data.llm import History

        # Validate input
        if self.input_key not in td:
            raise ValueError(
                f"Expected '{self.input_key}' key for history input mode, "
                f"but found keys: {list(td.keys())}"
            )

        history = td.get(self.input_key)
        if not isinstance(history, History):
            raise TypeError(
                f"Expected History object for '{self.input_key}', got {type(history)}"
            )

        # Apply chat template
        tokenizer_kwargs = {}
        if self.chat_template_name is not None:
            tokenizer_kwargs.setdefault("chat_template_name", self.chat_template_name)
        if self.chat_template is not None:
            tokenizer_kwargs.setdefault("chat_template", self.chat_template)
        tokenizer_kwargs.setdefault("add_generation_prompt", True)
        text_prompt = history.apply_chat_template(
            tokenizer=self.tokenizer, **tokenizer_kwargs
        )
        if not isinstance(text_prompt, list):
            raise ValueError(
                f"Expected list of text for history input, got {type(text_prompt)}"
            )
|
|
785
|
+
tokenizer_kwargs.setdefault("return_assistant_tokens_mask", False)
|
|
786
|
+
tokenizer_kwargs.setdefault("tokenize", True)
|
|
787
|
+
tokenizer_kwargs.setdefault("padding", False)
|
|
788
|
+
tokenizer_kwargs.setdefault("return_dict", True)
|
|
789
|
+
response_struct = history.apply_chat_template(
|
|
790
|
+
tokenizer=self.tokenizer, **tokenizer_kwargs
|
|
791
|
+
)
|
|
792
|
+
|
|
793
|
+
if self._device is not None:
|
|
794
|
+
response_struct = response_struct.to(self._device)
|
|
795
|
+
|
|
796
|
+
tokens_prompt_padded = response_struct.get(
|
|
797
|
+
"input_ids",
|
|
798
|
+
as_padded_tensor=True,
|
|
799
|
+
padding_value=self.padding_value,
|
|
800
|
+
padding_side="left",
|
|
801
|
+
)
|
|
802
|
+
attention_mask_prompt_padded = response_struct.get(
|
|
803
|
+
"attention_mask",
|
|
804
|
+
as_padded_tensor=True,
|
|
805
|
+
padding_value=0,
|
|
806
|
+
padding_side="left",
|
|
807
|
+
)
|
|
808
|
+
|
|
809
|
+
if attention_mask_prompt_padded is None:
|
|
810
|
+
attention_mask_prompt_padded = (
|
|
811
|
+
tokens_prompt_padded != self.tokenizer.pad_token_id
|
|
812
|
+
)
|
|
813
|
+
|
|
814
|
+
result = self._generate_from_tokens(
|
|
815
|
+
tokens_prompt_padded, attention_mask_prompt_padded, cfg, out
|
|
816
|
+
)
|
|
817
|
+
|
|
818
|
+
# Generate using text path
|
|
819
|
+
if self.pad_output:
|
|
820
|
+
result[(self.tokens_key, "prompt")] = (
|
|
821
|
+
tokens_prompt_padded
|
|
822
|
+
if not self.num_samples
|
|
823
|
+
else tokens_prompt_padded.unsqueeze(1).repeat(1, self.num_samples, 1)
|
|
824
|
+
)
|
|
825
|
+
else:
|
|
826
|
+
tokens_prompt_unpadded = response_struct.get(
|
|
827
|
+
"input_ids",
|
|
828
|
+
as_nested_tensor=True,
|
|
829
|
+
)
|
|
830
|
+
if not self.num_samples:
|
|
831
|
+
result[(self.tokens_key, "prompt")] = tokens_prompt_unpadded
|
|
832
|
+
else:
|
|
833
|
+
for r in result.unbind(1):
|
|
834
|
+
r[(self.tokens_key, "prompt")] = tokens_prompt_unpadded
|
|
835
|
+
|
|
836
|
+
text_result = Text._from_tensordict(result.empty())
|
|
837
|
+
result.set(self.text_key, text_result)
|
|
838
|
+
if not self.num_samples:
|
|
839
|
+
text_result.prompt = text_prompt
|
|
840
|
+
else:
|
|
841
|
+
for r in result.unbind(1):
|
|
842
|
+
r[self.text_key, "prompt"] = text_prompt
|
|
843
|
+
with result.view(-1) as result_flat:
|
|
844
|
+
if self.pad_output:
|
|
845
|
+
tokens_full_padded = result_flat.get(
|
|
846
|
+
(self.tokens_key, "full"),
|
|
847
|
+
as_padded_tensor=True,
|
|
848
|
+
padding_side="right",
|
|
849
|
+
padding_value=self.padding_value,
|
|
850
|
+
)
|
|
851
|
+
if tokens_full_padded is None:
|
|
852
|
+
raise ValueError("tokens_full_padded is None")
|
|
853
|
+
text_full = self.tokenizer.batch_decode(
|
|
854
|
+
tokens_full_padded, skip_special_tokens=False
|
|
855
|
+
)
|
|
856
|
+
else:
|
|
857
|
+
tokens_full_unpadded = result_flat.get(
|
|
858
|
+
(self.tokens_key, "full"), as_list=True
|
|
859
|
+
)
|
|
860
|
+
if tokens_full_unpadded is None:
|
|
861
|
+
raise ValueError("tokens_full_unpadded is None")
|
|
862
|
+
text_full = self.tokenizer.batch_decode(
|
|
863
|
+
tokens_full_unpadded, skip_special_tokens=False
|
|
864
|
+
)
|
|
865
|
+
text_prompt = result_flat[self.text_key, "prompt"]
|
|
866
|
+
text_response = [
|
|
867
|
+
txt[len(prompt) :]
|
|
868
|
+
for txt, prompt in _zip_strict(text_full, text_prompt)
|
|
869
|
+
]
|
|
870
|
+
result_flat.set((self.text_key, "full"), text_full)
|
|
871
|
+
result_flat.set((self.text_key, "response"), text_response)
|
|
872
|
+
# Now parse the full text back to a history object, and use the extra history objects
|
|
873
|
+
# as response
|
|
874
|
+
history_chat = ChatHistory._from_tensordict(result.empty())
|
|
875
|
+
if self.num_samples is None:
|
|
876
|
+
history_chat.prompt = history
|
|
877
|
+
else:
|
|
878
|
+
for h in history_chat.unbind(1):
|
|
879
|
+
h.prompt = history
|
|
880
|
+
with history_chat.view(-1) as history_chat_flat:
|
|
881
|
+
prompt_histories = history_chat_flat.prompt
|
|
882
|
+
# Extract response histories from full text
|
|
883
|
+
h_responses = _extract_responses_from_full_histories(
|
|
884
|
+
text_full, prompt_histories, self.chat_template_name, self.tokenizer
|
|
885
|
+
)
|
|
886
|
+
history_chat_flat.response = h_responses
|
|
887
|
+
# Combine prompt and response to create full history
|
|
888
|
+
history_chat_flat.full = history_chat_flat.prompt.extend(
|
|
889
|
+
h_responses, inplace=False, dim=-1
|
|
890
|
+
)
|
|
891
|
+
result.set(self.history_key, history_chat)
|
|
892
|
+
return result
|
|
893
|
+
|
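# --- Editor's note (illustrative sketch, not part of the diff) ---------------------
# Response extraction in the method above relies on the decoded full string starting with
# the decoded prompt, so the response is simply the remaining suffix. A plain-string
# illustration of that slicing (no tokenizer involved):
text_full = ["<s>User: hi Assistant: hello", "<s>User: bye Assistant: goodbye"]
text_prompt = ["<s>User: hi Assistant:", "<s>User: bye Assistant:"]
text_response = [full[len(prompt):] for full, prompt in zip(text_full, text_prompt, strict=True)]
print(text_response)  # [' hello', ' goodbye']
# ------------------------------------------------------------------------------------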
+    def _from_transformers_logprobs_history(self, td, cfg, out, logits_only=False):
+        """Compute log-probs from history input."""
+        from torchrl.data.llm import History
+
+        # Validate input
+        if self.input_key not in td:
+            raise ValueError(
+                f"Expected '{self.input_key}' key for history input mode, "
+                f"but found keys: {list(td.keys())}"
+            )
+
+        history = td.get(self.input_key)
+        if not isinstance(history, History):
+            raise TypeError(
+                f"Expected History object for '{self.input_key}', got {type(history)}"
+            )
+
+        # Apply chat template
+        tokenizer_kwargs = {}
+        if self.chat_template_name is not None:
+            tokenizer_kwargs.setdefault("chat_template_name", self.chat_template_name)
+        if self.chat_template is not None:
+            tokenizer_kwargs.setdefault("chat_template", self.chat_template)
+        tokenizer_kwargs.setdefault("add_generation_prompt", False)
+        text_full = history.apply_chat_template(
+            tokenizer=self.tokenizer, **tokenizer_kwargs
+        )
+
+        tokenizer_kwargs.setdefault("return_assistant_tokens_mask", True)
+        tokenizer_kwargs.setdefault("tokenize", True)
+        tokenizer_kwargs.setdefault("padding", False)
+        tokenizer_kwargs.setdefault("return_dict", True)
+
+        with torch.device(self._device) if self._device is not None else nullcontext():
+            response_tokens = history.apply_chat_template(
+                tokenizer=self.tokenizer, **tokenizer_kwargs
+            )
+        if not isinstance(response_tokens, TensorDictBase):
+            raise ValueError(
+                f"Expected TensorDictBase for history input, got {type(response_tokens)}"
+            )
+        result = self._logprobs_from_history_tokens(
+            response_tokens, cfg, out, logits_only=logits_only
+        )
+        text_result = Text._from_tensordict(result.empty())
+        result.set(self.text_key, text_result)
+        result[self.text_key, "full"] = text_full
+        result.set(self.history_key, ChatHistory(full=history))
+        return result
+
+    def _cat_text(self, text, response_text):
+        """Concatenate text and response text."""
+        if isinstance(text, list):
+            return [self._cat_text(t, t_) for t, t_ in _zip_strict(text, response_text)]
+        else:
+            return text + response_text
+
+    def _generate_from_text(self, text, cfg, out) -> TensorDictBase:
+        """Generate text from text input."""
+        pad_val = self.tokenizer.pad_token_id
+
+        # Convert text to list format
+        if isinstance(text, str):
+            text = [text]
+        elif not isinstance(text, list):
+            text = text.tolist()
+
+        tokenizer_kwargs = dict(self.tokenizer_kwargs)
+        tokenizer_kwargs.setdefault("padding", True)
+
+        with torch.device(
+            self._device
+        ) if self._device is not None else contextlib.nullcontext():
+            tokens_in = self.tokenizer(text, **tokenizer_kwargs)
+        if self._device is not None:
+            tokens_in = tokens_in.to(self._device)
+        # We are going to map this tokens_in to a tensordict to facilitate the padding in case we need it
+        tokens_in = dict(tokens_in)
+        for k, v in dict(tokens_in).items():
+            if isinstance(v, list):
+                if isinstance(v[0], torch.Tensor):
+                    v = torch.nested.nested_tensor(v)
+                else:
+                    v = torch.nested.nested_tensor([torch.tensor(t) for t in v])
+            tokens_in[k] = v
+        tokens_in = (
+            TensorDict(batch_size=tokens_in["input_ids"].size(0))
+            .to_lazystack(0)
+            .update(tokens_in)
+        )
+        tokens_prompt_padded = tokens_in.get(
+            "input_ids",
+            as_padded_tensor=True,
+            padding_side="left",
+            padding_value=pad_val,
+        )
+        attention_mask_prompt_padded = tokens_in.get(
+            "attention_mask",
+            as_padded_tensor=True,
+            padding_side="left",
+            padding_value=0,
+        )
+
+        if cfg is not None:
+            kwargs = copy(self.generate_kwargs)
+            kwargs["generation_config"] = cfg
+        else:
+            kwargs = self.generate_kwargs
+
+        tokens_out = self.model.generate(
+            input_ids=tokens_prompt_padded,
+            attention_mask=attention_mask_prompt_padded,
+            **kwargs,
+        )
+        tokens_full_padded = tokens_out["sequences"]
+        tokens_response_padded = tokens_full_padded[
+            ..., tokens_prompt_padded.shape[-1] :
+        ]
+
+        attention_mask_response_padded = tokens_response_padded != pad_val
+        if self.num_samples:
+            attention_mask_full_padded = torch.cat(
+                [
+                    attention_mask_prompt_padded.repeat_interleave(
+                        self.num_samples, dim=0
+                    ),
+                    attention_mask_response_padded,
+                ],
+                dim=-1,
+            )
+        else:
+            attention_mask_full_padded = torch.cat(
+                [attention_mask_prompt_padded, attention_mask_response_padded], dim=-1
+            )
+        tokens_response_unpadded = _unpad_tensors(
+            tokens_response_padded, attention_mask_response_padded, as_nested=False
+        )
+
+        if self.return_log_probs:
+            # These are only for the new tokens, not for the prompt - to get that, we'd need to run the forward pass again
+            logits = torch.stack(list(tokens_out["logits"]), 1)
+            log_probs, logits = self._log_probs_generate(
+                tokens_response_padded, logits, pad_val=-100, pad=False
+            )
+
+        response_text = self.tokenizer.batch_decode(
+            tokens_response_unpadded, skip_special_tokens=False
+        )
+
+        # Build output TensorClass objects
+        if self.num_samples is not None:
+            text = [txt for txt in text for _ in range(self.num_samples)]
+        text_obj = Text._from_tensordict(out.empty())
+        with text_obj.view(-1) as text_obj_flat:
+            text_obj_flat.prompt = text
+            text_obj_flat.response = response_text
+            text_obj_flat.full = self._cat_text(text, response_text)
+        out.set(self.text_key, text_obj)
+
+        tokens_obj = Tokens._from_tensordict(out.empty())
+        if self.pad_output:
+            prompt = tokens_prompt_padded
+        else:
+            prompt = _unpad_tensors(
+                tokens_prompt_padded, attention_mask_prompt_padded, as_nested=False
+            )
+        if tokens_obj.ndim == 2:
+            for i in range(self.num_samples):
+                tokens_obj[:, i].prompt = prompt
+        else:
+            tokens_obj.prompt = prompt
+        with tokens_obj.view(-1) as tokens_obj_flat:
+            if not self.pad_output:
+                tokens_obj_flat.response = tokens_response_unpadded
+                tokens_full_unpadded = _unpad_tensors(
+                    tokens_full_padded, attention_mask_full_padded, as_nested=False
+                )
+                tokens_obj_flat.full = tokens_full_unpadded
+            else:
+                tokens_obj_flat.response = tokens_response_padded
+                tokens_obj_flat.full = tokens_full_padded
+        tokens_obj.padded = MetaData(self.pad_output)
+        out.set(self.tokens_key, tokens_obj)
+
+        masks_obj = Masks._from_tensordict(out.empty())
+        if out.ndim == 2:
+            attention_mask_full_padded = attention_mask_full_padded.unflatten(
+                0, (-1, self.num_samples)
+            )
+        if self.pad_output:
+            masks_obj.all_attention_mask = attention_mask_full_padded.bool()
+        else:
+            if out.ndim == 2:
+                with tokens_obj.view(-1) as tokens_obj_flat, masks_obj.view(
+                    -1
+                ) as masks_obj_flat:
+                    attention_mask_full_unpadded = attention_mask_full_padded.flatten(
+                        0, 1
+                    )
+                    attention_mask_full_unpadded = _unpad_tensors(
+                        attention_mask_full_unpadded.bool(),
+                        attention_mask_full_padded.flatten(0, 1),
+                        as_nested=False,
+                    )
+                    masks_obj_flat.all_attention_mask = attention_mask_full_unpadded
+            else:
+                attention_mask_full_unpadded = _unpad_tensors(
+                    attention_mask_full_padded.bool(),
+                    attention_mask_full_padded,
+                    as_nested=False,
+                )
+                masks_obj.all_attention_mask = attention_mask_full_unpadded
+        masks_obj.all_assistant_mask = None
+        masks_obj.padded = MetaData(self.pad_output)
+        out.set(self.masks_key, masks_obj)
+
+        if self.return_log_probs:
+            log_probs_obj = LogProbs._from_tensordict(out.empty())
+            with log_probs_obj.view(-1) as log_probs_obj_flat:
+                # Unfortunate but we only have the log-probs for the new tokens, not for the prompt - to get that, we'd need to run the forward pass again
+                if self.pad_output:
+                    log_probs_obj_flat.prompt = None
+                    log_probs_obj_flat.response = log_probs
+                    log_probs_obj_flat.full = None
+                else:
+                    log_probs_unpadded = _unpad_tensors(
+                        log_probs, attention_mask_response_padded, as_nested=False
+                    )
+                    log_probs_obj_flat.prompt = None
+                    log_probs_obj_flat.response = log_probs_unpadded
+                    log_probs_obj_flat.full = None
+            log_probs_obj.padded = MetaData(self.pad_output)
+            out.set(self.log_probs_key, log_probs_obj)
+
+        # Add logits to output if we're in a get_dist call
+        if self._in_get_dist_call:
+            if self.pad_output:
+                out.set("logits", logits)
+            else:
+                logits_full_unpadded = _unpad_tensors(
+                    logits, attention_mask_full_padded, as_nested=False
+                )
+                out.set("logits", logits_full_unpadded)
+
+        return out
+
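# --- Editor's note (illustrative sketch, not part of the diff) ---------------------
# With left-padded prompts, `generate()` returns sequences laid out as [prompt | new
# tokens], so the completion is recovered by slicing off the prompt width, exactly as in
# `_generate_from_text` above. Toy illustration with made-up token ids (no model needed):
import torch

pad = 0
tokens_prompt_padded = torch.tensor([[pad, 5, 6], [7, 8, 9]])  # left-padded prompts
sequences = torch.cat([tokens_prompt_padded, torch.tensor([[11, 12], [13, 14]])], dim=-1)
tokens_response = sequences[..., tokens_prompt_padded.shape[-1]:]
attention_mask_response = tokens_response != pad
print(tokens_response.tolist())  # [[11, 12], [13, 14]]
# ------------------------------------------------------------------------------------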
+    def _cat_tensors(
+        self,
+        tokens: torch.Tensor | list[torch.Tensor],
+        response_tokens: torch.Tensor | list[torch.Tensor],
+        cast: torch.dtype | None = None,
+    ):
+        """Concatenate tokens and response tokens."""
+        if isinstance(tokens, list) or isinstance(response_tokens, list):
+            return [
+                self._cat_tensors(t, t_, cast=cast)
+                for t, t_ in _zip_strict(tokens, response_tokens)
+            ]
+        else:
+            result = torch.cat([tokens, response_tokens], dim=-1)
+            if cast is not None:
+                result = result.to(cast)
+            return result
+
+    def _logprobs_from_history_tokens(
+        self, response_tokens, cfg, out, logits_only=False
+    ):
+        """Compute log-probs from history tokens."""
+        pad_val = self.tokenizer.pad_token_id
+
+        if cfg is not None:
+            kwargs = copy(self.generate_kwargs)
+            kwargs["generation_config"] = cfg
+        else:
+            kwargs = self.generate_kwargs
+
+        # non-packed forward pass
+        if self.pad_model_input:
+            # unfortunately HF wants us to use padded tensors
+            tokens_full_padded = response_tokens.get(
+                "input_ids",
+                as_padded_tensor=True,
+                padding_side="left",
+                padding_value=pad_val,
+            )
+            if not isinstance(tokens_full_padded, torch.Tensor):
+                raise ValueError(
+                    f"Expected Tensor for tokens_full_padded, got {type(tokens_full_padded)}"
+                )
+            attention_mask_full_padded = response_tokens.get(
+                "attention_mask",
+                as_padded_tensor=True,
+                padding_side="left",
+                padding_value=0,
+            )
+            if not isinstance(attention_mask_full_padded, torch.Tensor):
+                raise ValueError(
+                    f"Expected Tensor for attention_mask_full_padded, got {type(attention_mask_full_padded)}"
+                )
+
+            (
+                log_probs_full_padded,
+                logits_full_padded,
+            ) = self._model_forward_with_padded_sequences(
+                tokens_full_padded,
+                attention_mask_full_padded,
+                pad_val=pad_val,
+                logits_only=logits_only,
+                **kwargs,
+            )
+        else:
+            # unfortunately HF wants us to use padded tensors
+            tokens_full_unpadded = response_tokens.get(
+                "input_ids",
+                as_nested_tensor=True,
+                layout=torch.jagged,
+            )
+            attention_mask_full_unpadded = response_tokens.get(
+                "attention_mask",
+                as_nested_tensor=True,
+                layout=torch.jagged,
+            )
+            (
+                log_probs_full_unpadded,
+                logits_full_unpadded,
+            ) = self._model_forward_with_packed_sequences(
+                # TODO: no padding if we don't need to
+                tokens_full_unpadded,
+                attention_mask_full_unpadded,
+                pad=False,
+                logits_only=logits_only,
+                **kwargs,
+            )
+            tokens_full_padded = pad_sequence(
+                tokens_full_unpadded.unbind(0),
+                batch_first=True,
+                padding_value=pad_val,
+                padding_side="left",
+            )
+            attention_mask_full_padded = pad_sequence(
+                attention_mask_full_unpadded.unbind(0),
+                batch_first=True,
+                padding_value=0,
+                padding_side="left",
+            )
+            if log_probs_full_unpadded is not None:
+                log_probs_full_padded = pad_sequence(
+                    log_probs_full_unpadded.unbind(0),
+                    batch_first=True,
+                    padding_value=0.0,
+                    padding_side="left",
+                )
+            else:
+                log_probs_full_padded = None
+            logits_full_padded = pad_sequence(
+                logits_full_unpadded.unbind(0),
+                batch_first=True,
+                padding_value=0.0,
+                padding_side="left",
+            )
+        # Build output TensorClass objects
+        text_obj = Text._from_tensordict(
+            TensorDict(batch_size=out.batch_size).to_lazystack(0)
+        )
+        text_obj.prompt = None
+        text_obj.response = None
+        text_obj.full = None
+        out.set(self.text_key, text_obj)
+
+        all_assistant_mask_padded = response_tokens.get(
+            "assistant_masks",
+            as_padded_tensor=True,
+            padding_side="left",
+            padding_value=0,
+        )
+        if all_assistant_mask_padded is not None:
+            all_assistant_mask_padded = all_assistant_mask_padded.bool()
+        masks_obj = Masks._from_tensordict(
+            TensorDict(batch_size=out.batch_size).to_lazystack(0)
+        )
+        if self.pad_output:
+            masks_obj.all_attention_mask = attention_mask_full_padded.bool()
+            if all_assistant_mask_padded is not None:
+                masks_obj.all_assistant_mask = all_assistant_mask_padded
+        else:
+            masks_obj.all_attention_mask = _unpad_tensors(
+                attention_mask_full_padded.bool(),
+                attention_mask_full_padded,
+                as_nested=False,
+            )
+            if all_assistant_mask_padded is not None:
+                masks_obj.all_assistant_mask = _unpad_tensors(
+                    all_assistant_mask_padded,
+                    attention_mask_full_padded,
+                    as_nested=False,
+                )
+        masks_obj.padded = MetaData(self.pad_output)
+        out.set(self.masks_key, masks_obj)
+
+        tokens_obj = Tokens._from_tensordict(
+            TensorDict(batch_size=out.batch_size).to_lazystack(0)
+        )
+        if self.pad_output:
+            tokens_obj.full = tokens_full_padded
+        else:
+            input_ids_full_unpadded = _unpad_tensors(
+                tokens_full_padded, attention_mask_full_padded, as_nested=False
+            )
+            tokens_obj.full = input_ids_full_unpadded
+        tokens_obj.response = None
+        tokens_obj.padded = MetaData(self.pad_output)
+        out.set(self.tokens_key, tokens_obj)
+
+        if not logits_only:
+            log_probs_obj = LogProbs._from_tensordict(
+                TensorDict(batch_size=out.batch_size).to_lazystack(0)
+            )
+            if self.pad_output:
+                log_probs_obj.full = log_probs_full_padded
+            else:
+                log_probs_full_unpadded = _unpad_tensors(
+                    log_probs_full_padded, attention_mask_full_padded, as_nested=False
+                )
+                log_probs_obj.full = log_probs_full_unpadded
+            log_probs_obj.response = None
+            log_probs_obj.padded = MetaData(self.pad_output)
+            out.set(self.log_probs_key, log_probs_obj)
+
+        # Add logits to output if we're in a get_dist call
+        if self._in_get_dist_call:
+            if self.pad_output:
+                out.set("logits", logits_full_padded)
+            else:
+                logits_full_unpadded = _unpad_tensors(
+                    logits_full_padded, attention_mask_full_padded, as_nested=False
+                )
+                out.set("logits", logits_full_unpadded)
+
+        return out
+
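# --- Editor's note (illustrative sketch, not part of the diff) ---------------------
# The packed branch above converts jagged (nested) tensors back to dense, left-padded
# tensors with `pad_sequence(..., padding_side="left")`. A minimal standalone example,
# assuming a PyTorch version where `pad_sequence` accepts `padding_side` (as the code
# above already requires):
import torch
from torch.nn.utils.rnn import pad_sequence

seqs = [torch.tensor([1, 2, 3]), torch.tensor([4, 5])]
nt = torch.nested.nested_tensor(seqs, layout=torch.jagged)
padded = pad_sequence(nt.unbind(0), batch_first=True, padding_value=0, padding_side="left")
print(padded.tolist())  # [[1, 2, 3], [0, 4, 5]]
# ------------------------------------------------------------------------------------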
+    def _from_transformers_generate_text(self, td, cfg, out) -> TensorDictBase:
+        """Generate text from text input."""
+        # Validate input
+        if self.input_key not in td:
+            raise ValueError(
+                f"Expected '{self.input_key}' key for text input mode, "
+                f"but found keys: {list(td.keys())}"
+            )
+
+        text = td.get(self.input_key)
+        if text is None:
+            raise ValueError(f"Expected '{self.input_key}' key for text input mode")
+        if isinstance(text, NonTensorStack):
+            text = text.tolist()
+        if not isinstance(text, list):
+            raise ValueError(f"Expected list of text for text input, got {type(text)}")
+        return self._generate_from_text(text, cfg, out)
+
+    def _from_transformers_logprobs_text(self, td, cfg, out, logits_only=False):
+        """Compute log-probs from text input."""
+        # Validate input
+        if self.input_key not in td:
+            raise ValueError(
+                f"Expected '{self.input_key}' key for text input mode, "
+                f"but found keys: {list(td.keys())}"
+            )
+
+        text = td.get(self.input_key)
+        if isinstance(text, NonTensorStack):
+            text = text.tolist()
+        if text is None:
+            raise ValueError(f"Expected '{self.input_key}' key for text input mode")
+        if not isinstance(text, list):
+            raise ValueError(f"Expected list of text for text input, got {type(text)}")
+        # Tokenize the text
+        if self.tokenizer is None:
+            raise ValueError(
+                "Tokenizer is required for log-probs computation with text input"
+            )
+
+        # Convert text to list format
+        if isinstance(text, str):
+            text = [text]
+        elif not isinstance(text, list):
+            text = text.tolist()
+
+        # Tokenize the text
+        tokenizer_kwargs = dict(self.tokenizer_kwargs)
+        with torch.device(
+            self._device
+        ) if self._device is not None else contextlib.nullcontext():
+            tokens_in = self.tokenizer(text, **tokenizer_kwargs)
+
+        if cfg is not None:
+            kwargs = copy(self.generate_kwargs)
+            kwargs["generation_config"] = cfg
+        else:
+            kwargs = self.generate_kwargs
+
+        # We are going to map this tokens_in to a tensordict to facilitate the padding in case we need it
+        tokens_in = (
+            TensorDict(batch_size=len(tokens_in["input_ids"]))
+            .to_lazystack(0)
+            .update(dict(tokens_in))
+        )
+        pad_val = self.padding_value
+
+        if self.pad_model_input:
+            tokens_full_padded = tokens_in.get(
+                "input_ids",
+                as_padded_tensor=True,
+                padding_side="left",
+                padding_value=pad_val,
+            )
+            attention_mask_full_padded = tokens_in.get(
+                "attention_mask",
+                as_padded_tensor=True,
+                padding_side="left",
+                padding_value=0,
+            )
+
+            (
+                log_probs_full_padded,
+                logits_full_padded,
+            ) = self._model_forward_with_padded_sequences(
+                tokens_full_padded,
+                attention_mask_full_padded,
+                pad_val=pad_val,
+                logits_only=logits_only,
+                **kwargs,
+            )
+        else:
+            # packed forward pass
+            tokens_full_unpadded = tokens_in.get(
+                "input_ids",
+                as_nested_tensor=True,
+                layout=torch.jagged,
+            )
+            attention_mask_full_unpadded = tokens_in.get(
+                "attention_mask",
+                as_nested_tensor=True,
+                layout=torch.jagged,
+            )
+            (
+                log_probs_full_unpadded,
+                logits_full_unpadded,
+            ) = self._model_forward_with_packed_sequences(
+                tokens_full_unpadded, attention_mask_full_unpadded, pad=False, **kwargs
+            )
+            tokens_full_padded = pad_sequence(
+                tokens_full_unpadded.unbind(0),
+                batch_first=True,
+                padding_value=pad_val,
+                padding_side="left",
+            )
+            attention_mask_full_padded = pad_sequence(
+                attention_mask_full_unpadded.unbind(0),
+                batch_first=True,
+                padding_value=0,
+                padding_side="left",
+            )
+            log_probs_full_padded = pad_sequence(
+                log_probs_full_unpadded.unbind(0),
+                batch_first=True,
+                padding_value=0.0,
+                padding_side="left",
+            )
+            logits_full_padded = pad_sequence(
+                logits_full_unpadded.unbind(0),
+                batch_first=True,
+                padding_value=0.0,
+                padding_side="left",
+            )
+
+        # Build output TensorClass objects
+        text_obj = Text._from_tensordict(
+            TensorDict(batch_size=out.batch_size).to_lazystack(0)
+        )
+        text_obj.prompt = None
+        text_obj.response = None
+        text_obj.full = text
+        out.set(self.text_key, text_obj)
+
+        tokens_obj = Tokens._from_tensordict(
+            TensorDict(batch_size=out.batch_size).to_lazystack(0)
+        )
+        if self.pad_output:
+            tokens_obj.full = tokens_full_padded
+        else:
+            input_ids_full_unpadded = _unpad_tensors(
+                tokens_full_padded, attention_mask_full_padded, as_nested=False
+            )
+            tokens_obj.full = input_ids_full_unpadded
+        tokens_obj.response = None
+        tokens_obj.padded = MetaData(self.pad_output)
+        out.set(self.tokens_key, tokens_obj)
+
+        masks_obj = Masks._from_tensordict(
+            TensorDict(batch_size=out.batch_size).to_lazystack(0)
+        )
+        if self.pad_output:
+            masks_obj.all_attention_mask = attention_mask_full_padded.bool()
+            masks_obj.all_assistant_mask = td.get(("masks", "all_assistant_mask"))
+        else:
+            attention_mask_full_unpadded = _unpad_tensors(
+                attention_mask_full_padded.bool(),
+                attention_mask_full_padded,
+                as_nested=False,
+            )
+            masks_obj.all_attention_mask = attention_mask_full_unpadded
+            masks_obj.all_assistant_mask = td.get(
+                ("masks", "all_assistant_mask"), as_list=True
+            )
+        masks_obj.padded = MetaData(self.pad_output)
+        out.set(self.masks_key, masks_obj)
+
+        if not logits_only:
+            log_probs_obj = LogProbs._from_tensordict(
+                TensorDict(batch_size=out.batch_size).to_lazystack(0)
+            )
+            if self.pad_output:
+                log_probs_obj.full = log_probs_full_padded
+            else:
+                log_probs_full_unpadded = _unpad_tensors(
+                    log_probs_full_padded, attention_mask_full_padded, as_nested=False
+                )
+                log_probs_obj.full = log_probs_full_unpadded
+            log_probs_obj.response = None
+            log_probs_obj.padded = MetaData(self.pad_output)
+            out.set(self.log_probs_key, log_probs_obj)
+
+        # Add logits to output if we're in a get_dist call
+        if self._in_get_dist_call:
+            if self.pad_output:
+                out.set("logits", logits_full_padded)
+            else:
+                logits_full_unpadded = _unpad_tensors(
+                    logits_full_padded, attention_mask_full_padded, as_nested=False
+                )
+                out.set("logits", logits_full_unpadded)
+
+        return out
+
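# --- Editor's note (illustrative sketch, not part of the diff) ---------------------
# Throughout this file, ragged batches live in lazily stacked TensorDicts and are
# materialised on demand as padded or nested tensors via `get(..., as_padded_tensor=True)`.
# A small standalone example of that retrieval pattern, assuming a recent tensordict
# release with `lazy_stack` and the padded `get` options (as used above):
import torch
from tensordict import TensorDict, lazy_stack

rows = [
    TensorDict({"input_ids": torch.tensor([1, 2, 3])}, batch_size=[]),
    TensorDict({"input_ids": torch.tensor([4, 5])}, batch_size=[]),
]
td = lazy_stack(rows)
padded = td.get("input_ids", as_padded_tensor=True, padding_side="left", padding_value=0)
print(padded.tolist())  # [[1, 2, 3], [0, 4, 5]]
# ------------------------------------------------------------------------------------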
+    def _from_transformers_generate_tokens(
+        self, td: TensorDictBase, cfg: dict | None, out: TensorDictBase
+    ) -> TensorDictBase:
+        """Generate text from tokens input."""
+        # Validate input
+        if self.input_key not in td:
+            raise ValueError(
+                f"Expected '{self.input_key}' key for tokens input mode, "
+                f"but found keys: {list(td.keys())}"
+            )
+
+        pad_val = self.tokenizer.pad_token_id
+
+        input_ids_prompt_padded = td.get(
+            self.input_key,
+            as_padded_tensor=True,
+            padding_side="left",
+            padding_value=pad_val,
+        )
+        attention_mask_prompt_padded = td.get(
+            ("masks", "all_attention_mask"),
+            as_padded_tensor=True,
+            padding_side="left",
+            padding_value=False,
+        )
+        if attention_mask_prompt_padded is None:
+            attention_mask_prompt_padded = td.get(
+                self.attention_mask_key,
+                as_padded_tensor=True,
+                padding_side="left",
+                padding_value=False,
+            )
+        if attention_mask_prompt_padded is None:
+            attention_mask_prompt_padded = input_ids_prompt_padded != pad_val
+        return self._generate_from_tokens(
+            input_ids_prompt_padded, attention_mask_prompt_padded, cfg, out
+        )
+
+    def _generate_from_tokens(
+        self,
+        tokens_prompt_padded: torch.Tensor,
+        attention_mask_prompt_padded: torch.Tensor,
+        cfg: dict | None,
+        out: TensorDictBase,
+    ) -> TensorDictBase:
+        if cfg is not None:
+            kwargs = copy(self.generate_kwargs)
+            kwargs["generation_config"] = cfg
+        else:
+            kwargs = self.generate_kwargs
+
+        tokens_out_struct = self.model.generate(
+            input_ids=tokens_prompt_padded,
+            attention_mask=attention_mask_prompt_padded,
+            **kwargs,
+        )
+        tokens_full_padded = tokens_out_struct["sequences"]
+        tokens_response_padded = tokens_full_padded[:, tokens_prompt_padded.shape[-1] :]
+        pad_val = getattr(self.tokenizer, "pad_token_id", None)
+        if pad_val is None:
+            pad_val = self.padding_value
+        attention_mask_reponse_padded = tokens_response_padded != pad_val
+        attention_mask_full_padded = tokens_full_padded != pad_val
+        tokens_response_unpadded = _unpad_tensors(
+            tokens_response_padded, attention_mask_reponse_padded, as_nested=False
+        )
+
+        if self.return_log_probs:
+            # These are only for the new tokens, not for the prompt - to get that, we'd need to run the forward pass again
+            logits_response_padded = tokens_out_struct["logits"]
+            logits_response_padded = torch.stack(list(logits_response_padded), 1)
+            (
+                log_probs_response_padded,
+                logits_response_padded,
+            ) = self._log_probs_generate(
+                tokens_response_padded,
+                logits_response_padded,
+                pad_val=pad_val,
+                pad=False,
+            )
+
+        response_text = self.tokenizer.batch_decode(
+            tokens_response_unpadded, skip_special_tokens=False
+        )
+
+        # Build output TensorClass objects
+        text_obj = Text._from_tensordict(out.empty())
+        text_obj.prompt = None  # We don't have text in tokens mode
+        with text_obj.view(-1) as text_obj_flat:
+            text_obj_flat.response = response_text
+        text_obj.full = None  # we don't have text in tokens mode so no all_text either
+        out.set(self.text_key, text_obj)
+
+        tokens_obj = Tokens._from_tensordict(out.empty())
+        if not self.pad_output:
+            input_ids_prompt_unpadded = _unpad_tensors(
+                tokens_prompt_padded,
+                attention_mask_prompt_padded,
+                as_nested=False,
+            )
+        if self.num_samples is not None:
+            # replicate tokens
+            for i in range(self.num_samples):
+                tokens_obj[:, i].prompt = (
+                    input_ids_prompt_unpadded
+                    if not self.pad_output
+                    else tokens_prompt_padded
+                )
+        else:
+            tokens_obj.prompt = (
+                input_ids_prompt_unpadded
+                if not self.pad_output
+                else tokens_prompt_padded
+            )
+        with tokens_obj.view(-1) as tokens_obj_flat:
+            if self.pad_output:
+                tokens_obj_flat.response = tokens_response_padded
+                tokens_obj_flat.full = tokens_full_padded
+            else:
+                tokens_obj_flat.response = tokens_response_unpadded
+                tokens_full_unpadded = _unpad_tensors(
+                    tokens_full_padded, attention_mask_full_padded, as_nested=False
+                )
+                tokens_obj_flat.full = tokens_full_unpadded
+        tokens_obj.padded = MetaData(self.pad_output)
+        out.set(self.tokens_key, tokens_obj)
+
+        masks_obj = Masks._from_tensordict(out.empty())
+        if out.ndim == 2:
+            attention_mask_full_padded = attention_mask_full_padded.unflatten(
+                0, (-1, self.num_samples)
+            )
+        if self.pad_output:
+            # Get "real" attention masks
+            masks_obj.all_attention_mask = attention_mask_full_padded.bool()
+        else:
+            # Get "real" attention masks
+            # We can use select to avoid batch-size problems
+            _td = torch.ones_like(
+                out.select(("tokens", "full"))
+                .copy()
+                .rename_key_(("tokens", "full"), "all_attention_mask")
+            ).bool()
+            del _td["tokens"]
+            masks_obj.update(_td)
+        masks_obj.all_assistant_mask = None
+        masks_obj.padded = MetaData(self.pad_output)
+        out.set(self.masks_key, masks_obj)
+
+        if self.return_log_probs:
+            log_probs_obj = LogProbs._from_tensordict(out.empty())
+            if self.num_samples is None:
+                if self.pad_output:
+                    log_probs_obj.response = log_probs_response_padded
+                else:
+                    log_probs_response_unpadded = _unpad_tensors(
+                        log_probs_response_padded,
+                        attention_mask_reponse_padded,
+                        as_nested=False,
+                    )
+                    log_probs_obj.response = log_probs_response_unpadded
+            else:
+                with log_probs_obj.view(-1) as log_probs_obj_flat:
+                    if self.pad_output:
+                        log_probs_obj_flat.response = log_probs_response_padded
+                    else:
+                        log_probs_response_unpadded = _unpad_tensors(
+                            log_probs_response_padded,
+                            attention_mask_reponse_padded,
+                            as_nested=False,
+                        )
+                        log_probs_obj_flat.response = log_probs_response_unpadded
+            log_probs_obj.padded = MetaData(self.pad_output)
+            out.set(self.log_probs_key, log_probs_obj)
+
+        return out
+
+    def _from_transformers_logprobs_tokens(
+        self,
+        td: TensorDictBase,
+        cfg: dict | None,
+        out: TensorDictBase,
+        logits_only=False,
+    ) -> TensorDictBase:
+        """Compute log-probs from tokens input."""
+        # Validate input
+        if self.input_key not in td:
+            raise ValueError(
+                f"Expected '{self.input_key}' key for tokens input mode, "
+                f"but found keys: {list(td.keys(isinstance(self.input_key, tuple)))}"
+            )
+
+        pad_val = self.tokenizer.pad_token_id
+
+        if cfg is not None:
+            kwargs = copy(self.generate_kwargs)
+            kwargs["generation_config"] = cfg
+        else:
+            kwargs = self.generate_kwargs
+
+        if self.pad_model_input:
+            tokens_full_padded = td.get(
+                self.input_key,
+                as_padded_tensor=True,
+                padding_side="left",
+                padding_value=pad_val,
+            )
+            # Attention mask: try first the regular entry, then the key provided in the constructor, finally fallback on eager attention mask
+            attention_mask_full_padded = td.get(
+                ("masks", "all_attention_mask"),
+                as_padded_tensor=True,
+                padding_side="left",
+                padding_value=False,
+            )
+            if attention_mask_full_padded is None:
+                attention_mask_full_padded = td.get(
+                    self.attention_mask_key,
+                    as_padded_tensor=True,
+                    padding_side="left",
+                    padding_value=False,
+                )
+            if attention_mask_full_padded is None:
+                attention_mask_full_padded = tokens_full_padded != pad_val
+
+            (
+                log_probs_full_padded,
+                logits_full_padded,
+            ) = self._model_forward_with_padded_sequences(
+                tokens_full_padded,
+                attention_mask_full_padded,
+                pad_val=pad_val,
+                logits_only=logits_only,
+                **kwargs,
+            )
+        else:
+            # packed forward pass
+            # unfortunately HF wants us to use padded tensors
+            tokens_full_unpadded = td.get(
+                self.input_key,
+                as_nested_tensor=True,
+                layout=torch.jagged,
+            )
+            if tokens_full_unpadded is None:
+                raise ValueError(
+                    f"Expected '{self.input_key}' key for tokens input mode, but found keys: {list(td.keys())}"
+                )
+            # Attention mask: try first the regular entry, then the key provided in the constructor, finally fallback on eager attention mask
+            attention_mask_full_unpadded = td.get(
+                ("masks", "all_attention_mask"),
+                as_nested_tensor=True,
+                layout=torch.jagged,
+            )
+            if attention_mask_full_unpadded is None:
+                attention_mask_full_unpadded = td.get(
+                    self.attention_mask_key,
+                    as_nested_tensor=True,
+                    layout=torch.jagged,
+                )
+            if attention_mask_full_unpadded is None:
+                # does this even work?
+                attention_mask_full_unpadded = tokens_full_unpadded != pad_val
+
+            (
+                log_probs_full_unpadded,
+                logits_full_unpadded,
+            ) = self._model_forward_with_packed_sequences(
+                # TODO: no padding if we don't need to
+                tokens_full_unpadded,
+                attention_mask_full_unpadded,
+                pad=False,
+                logits_only=logits_only,
+                **kwargs,
+            )
+            tokens_full_padded = pad_sequence(
+                tokens_full_unpadded.unbind(0),
+                batch_first=True,
+                padding_value=pad_val,
+                padding_side="left",
+            )
+            attention_mask_full_padded = pad_sequence(
+                attention_mask_full_unpadded.unbind(0),
+                batch_first=True,
+                padding_value=0,
+                padding_side="left",
+            )
+            if log_probs_full_unpadded is not None:
+                log_probs_full_padded = pad_sequence(
+                    log_probs_full_unpadded.unbind(0),
+                    batch_first=True,
+                    padding_value=0.0,
+                    padding_side="left",
+                )
+            else:
+                log_probs_full_padded = None
+            logits_full_padded = pad_sequence(
+                logits_full_unpadded.unbind(0),
+                batch_first=True,
+                padding_value=0.0,
+                padding_side="left",
+            )
+
+        # Build output TensorClass objects
+        text_obj = Text._from_tensordict(
+            TensorDict(batch_size=out.batch_size).to_lazystack(0)
+        )
+        text_obj.prompt = None
+        text_obj.response = None
+        text_obj.full = None
+        out.set(self.text_key, text_obj)
+
+        tokens_obj = Tokens._from_tensordict(
+            TensorDict(batch_size=out.batch_size).to_lazystack(0)
+        )
+        if not self.pad_output:
+            input_ids_full_unpadded = _unpad_tensors(
+                tokens_full_padded, attention_mask_full_padded, as_nested=False
+            )
+            tokens_obj.full = input_ids_full_unpadded
+        else:
+            tokens_obj.full = tokens_full_padded
+        tokens_obj.response = None
+        tokens_obj.padded = MetaData(self.pad_output)
+        out.set(self.tokens_key, tokens_obj)
+
+        masks_obj = Masks._from_tensordict(
+            TensorDict(batch_size=out.batch_size).to_lazystack(0)
+        )
+        if self.pad_output:
+            masks_obj.all_attention_mask = attention_mask_full_padded.bool()
+            masks_obj.all_assistant_mask = td.get(("masks", "all_assistant_mask"))
+        else:
+            masks_obj.all_attention_mask = _unpad_tensors(
+                attention_mask_full_padded.bool(),
+                attention_mask_full_padded,
+                as_nested=False,
+            )
+            masks_obj.all_assistant_mask = td.get(
+                ("masks", "all_assistant_mask"), as_list=True
+            )
+
+        masks_obj.padded = MetaData(self.pad_output)
+        out.set(self.masks_key, masks_obj)
+
+        if not logits_only:
+            log_probs_obj = LogProbs._from_tensordict(
+                TensorDict(batch_size=out.batch_size).to_lazystack(0)
+            )
+            if self.pad_output:
+                log_probs_obj.full = log_probs_full_padded
+            else:
+                log_probs_full_unpadded = _unpad_tensors(
+                    log_probs_full_padded, attention_mask_full_padded, as_nested=False
+                )
+                log_probs_obj.full = log_probs_full_unpadded
+            log_probs_obj.response = None
+            log_probs_obj.padded = MetaData(self.pad_output)
+            out.set(self.log_probs_key, log_probs_obj)
+
+        # Add logits to output if we're in a get_dist call
+        if self._in_get_dist_call:
+            if self.pad_output:
+                out.set("logits", logits_full_padded)
+            else:
+                logits_full_unpadded = _unpad_tensors(
+                    logits_full_padded, attention_mask_full_padded, as_nested=False
+                )
+                out.set("logits", logits_full_unpadded)
+        return out
+
+    @classmethod
+    def _log_probs_generate(cls, tokens, logits, pad_val=-100, pad: bool = True):
+        if pad:
+            tokens = pad_sequence(
+                tokens,
+                padding_value=pad_val,
+                batch_first=True,
+                padding_side="left",
+            )
+            logits = pad_sequence(
+                logits,
+                padding_value=0.0,
+                batch_first=True,
+                padding_side="left",
+            )
+
+        # logits = logits.log_softmax(dim=-1)
+        # log_probs = logits.gather(-1, tokens.unsqueeze(-1)).squeeze(-1)
+        td = TensorDict(logits=logits, tokens=tokens).auto_batch_size_()
+        with td.flatten() as tdflat:
+            tdflat["log_probs"] = -torch.nn.functional.cross_entropy(
+                tdflat["logits"], tdflat["tokens"], reduce=False, ignore_index=pad_val
+            )
+        td["log_probs"][:, 0] = 0
+        log_probs = td["log_probs"]
+        return log_probs, logits
+
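# --- Editor's note (illustrative sketch, not part of the diff) ---------------------
# `_log_probs_generate` above obtains per-token log-probabilities as the negative
# cross-entropy of the logits against the sampled tokens; this is the same quantity as
# gathering from the log-softmax, which the short check below verifies:
import torch
import torch.nn.functional as F

logits = torch.randn(4, 10)            # (tokens, vocab)
tokens = torch.randint(0, 10, (4,))
lp_ce = -F.cross_entropy(logits, tokens, reduction="none")
lp_gather = logits.log_softmax(-1).gather(-1, tokens.unsqueeze(-1)).squeeze(-1)
torch.testing.assert_close(lp_ce, lp_gather)
# ------------------------------------------------------------------------------------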
+    def _compute_log_probs_from_model_output(
+        self, model_output, input_ids, attention_mask, pad_val, logits_only=False
+    ):
+        """Compute log-probs from model output without modifying original tensors.
+
+        Args:
+            model_output: Output from the model containing logits
+            input_ids: Original input token ids
+            attention_mask: Original attention mask
+            pad_val: Padding token value to ignore in loss computation
+            logits_only: Whether to return only the logits.
+
+        Returns:
+            tuple: (log_probs, shifted_logits) where log_probs are the computed log probabilities
+            and shifted_logits are the logits shifted to align with tokens
+        """
+        logits = model_output["logits"]
+
+        # Create shifted versions for log-prob computation without modifying originals
+        shifted_logits = logits[:, :-1, :]
+        # shifted_logits = shifted_logits - shifted_logits.logsumexp(dim=-1, keepdim=True)
+        shifted_logits = torch.cat(
+            [torch.zeros_like(shifted_logits[:, :1]), shifted_logits], 1
+        )
+
+        shifted_input_ids = input_ids[:, 1:]
+        shifted_input_ids = torch.cat(
+            [torch.zeros_like(shifted_input_ids[:, :1]), shifted_input_ids], 1
+        )
+
+        # Check that the shape is correct
+        if shifted_logits.shape[-2] != shifted_input_ids.shape[-1]:
+            raise ValueError(
+                f"The logits shape {shifted_logits.shape} does not match the input ids shape {shifted_input_ids.shape}"
+            )
+        if logits_only:
+            return None, shifted_logits
+
+        # Compute log-probs
+        td = TensorDict(
+            logits=shifted_logits, tokens=shifted_input_ids
+        ).auto_batch_size_()
+        with td.flatten() as tdflat:
+            tdflat["log_probs"] = -torch.nn.functional.cross_entropy(
+                tdflat["logits"],
+                tdflat["tokens"],
+                reduce=False,
+                ignore_index=pad_val,
+            )
+        # For consistency with vllm, we set the log-probs of the first token to 0
+        # However, the first element may not be the first - we want the first of the attention mask,
+        # i.e, the first element that is true on the left
+        attention_mask = attention_mask.bool()
+        attention_mask_first_left = ~attention_mask[:, :-1] & attention_mask[:, 1:]
+        attention_mask_first_left = torch.cat(
+            [
+                torch.zeros_like(attention_mask_first_left[..., :1]),
+                attention_mask_first_left,
+            ],
+            -1,
+        )
+        attention_mask_first_left[~(attention_mask_first_left.any(-1)), 0] = True
+        assert attention_mask_first_left.any(-1).all()
+        attention_mask_first_left = attention_mask_first_left | ~attention_mask
+        td["log_probs"][attention_mask_first_left] = 0
+
+        return td["log_probs"], shifted_logits
+
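# --- Editor's note (illustrative sketch, not part of the diff) ---------------------
# The shifting above aligns next-token predictions with their targets: logits at position
# t score token t+1, so both logits and token ids are shifted by one and padded with a
# dummy first column whose log-prob is later set to zero. A compact reproduction of that
# alignment with random data:
import torch

B, T, V = 1, 5, 7
logits = torch.randn(B, T, V)
input_ids = torch.randint(0, V, (B, T))

shifted_logits = torch.cat([torch.zeros_like(logits[:, :1]), logits[:, :-1]], dim=1)
shifted_ids = torch.cat([torch.zeros_like(input_ids[:, :1]), input_ids[:, 1:]], dim=1)
log_probs = shifted_logits.log_softmax(-1).gather(-1, shifted_ids.unsqueeze(-1)).squeeze(-1)
log_probs[:, 0] = 0.0  # convention used above for the first (or first attended) token
print(log_probs.shape)  # torch.Size([1, 5])
# ------------------------------------------------------------------------------------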
+    def get_dist(
+        self,
+        tensordict: TensorDictBase,
+        tensordict_out: TensorDictBase | None = None,
+        logits_key: NestedKey = "logits",
+        mask_key: NestedKey | None = None,
+        as_padded_tensor: bool | None = None,
+        as_nested_tensor: bool | None = None,
+        padding_value: float | None = None,
+        padding_side: str = "right",
+        layout: torch.layout | None = None,
+        **kwargs,
+    ) -> D.Distribution:
+        """Get distribution from logits/log-probs with optional masking.
+
+        This method enables logits computation for distribution creation.
+        """
+        self._in_get_dist_call = True
+        self.out_keys += ["logits"]
+        try:
+            return super().get_dist(
+                tensordict,
+                tensordict_out,
+                logits_key,
+                mask_key,
+                as_padded_tensor,
+                as_nested_tensor,
+                padding_value,
+                padding_side,
+                layout,
+                **kwargs,
+            )
+        finally:
+            self._in_get_dist_call = False
+            self.out_keys.remove("logits")
+
+    def _get_dist_with_prompt_mask(
+        self,
+        tensordict: TensorDictBase,
+        tokens_key: NestedKey = ("tokens", "prompt"),
+        logits_key: NestedKey = "logits",
+        assistant_mask_key: NestedKey = ("masks", "all_assistant_mask"),
+        attention_mask_key: NestedKey = ("masks", "all_attention_mask"),
+        **kwargs,
+    ) -> D.Distribution:
+        """Get distribution masked to only include response tokens (exclude prompt).
+
+        This method enables logits computation for distribution creation.
+
+        This is a provisional method that will be replaced by the `get_dist` method once we have a better masking strategy.
+        """
+        self._in_get_dist_call = True
+        self.out_keys += ["logits"]
+        try:
+            return super()._get_dist_with_prompt_mask(
+                tensordict,
+                tokens_key,
+                logits_key,
+                assistant_mask_key,
+                attention_mask_key,
+                **kwargs,
+            )
+        finally:
+            self._in_get_dist_call = False
+            self.out_keys.remove("logits")
+
+    def _get_dist_with_assistant_mask(
+        self,
+        tensordict: TensorDictBase,
+        assistant_mask_key: NestedKey = ("masks", "all_assistant_mask"),
+        logits_key: NestedKey = "logits",
+        **kwargs,
+    ) -> D.Distribution:
+        """Get distribution masked to only include assistant tokens.
+
+        This method enables logits computation for distribution creation.
+
+        This is a provisional method that will be replaced by the `get_dist` method once we have a better masking strategy.
+        """
+        self._in_get_dist_call = True
+        self.out_keys += ["logits"]
+        try:
+            return super()._get_dist_with_assistant_mask(
+                tensordict, assistant_mask_key, logits_key, **kwargs
+            )
+        finally:
+            self._in_get_dist_call = False
+            self.out_keys.remove("logits")
+
+    def _get_dist_with_attention_mask(
+        self,
+        tensordict: TensorDictBase,
+        attention_mask_key: NestedKey = ("masks", "all_attention_mask"),
+        logits_key: NestedKey = "logits",
+        **kwargs,
+    ) -> D.Distribution:
+        """Get distribution masked using attention mask.
+
+        This method enables logits computation for distribution creation.
+
+        This is a provisional method that will be replaced by the `get_dist` method once we have a better masking strategy.
+        """
+        self._in_get_dist_call = True
+        self.out_keys += ["logits"]
+        try:
+            return super()._get_dist_with_attention_mask(
+                tensordict, attention_mask_key, logits_key, **kwargs
+            )
+        finally:
+            self._in_get_dist_call = False
+            self.out_keys.remove("logits")
+
+    def _get_dist_with_custom_mask(
+        self,
+        tensordict: TensorDictBase,
+        mask: torch.Tensor,
+        logits_key: NestedKey = "logits",
+        **kwargs,
+    ) -> D.Distribution:
+        """Get distribution with custom mask.
+
+        This method enables logits computation for distribution creation.
+        """
+        self._in_get_dist_call = True
+        self.out_keys += ["logits"]
+        try:
+            return super()._get_dist_with_custom_mask(
+                tensordict, mask, logits_key, **kwargs
+            )
+        finally:
+            self._in_get_dist_call = False
+            self.out_keys.remove("logits")
+
+    # Convenience methods for common LLM training scenarios
+    def _get_sft_dist(self, tensordict: TensorDictBase, **kwargs) -> D.Distribution:
+        """Get distribution suitable for SFT loss (response tokens only).
+
+        This method enables logits computation for distribution creation.
+
+        This is a provisional method that will be replaced by the `get_dist` method once we have a better masking strategy.
+        """
+        self._in_get_dist_call = True
+        self.out_keys += ["logits"]
+        try:
+            return super()._get_sft_dist(tensordict, **kwargs)
+        finally:
+            self._in_get_dist_call = False
+            self.out_keys.remove("logits")
+
+    def _get_rlhf_dist(self, tensordict: TensorDictBase, **kwargs) -> D.Distribution:
+        """Get distribution suitable for RLHF loss (assistant tokens only).
+
+        This method enables logits computation for distribution creation.
+
+        This is a provisional method that will be replaced by the `get_dist` method once we have a better masking strategy.
+        """
+        self._in_get_dist_call = True
+        self.out_keys += ["logits"]
+        try:
+            return super()._get_rlhf_dist(tensordict, **kwargs)
+        finally:
+            self._in_get_dist_call = False
+            self.out_keys.remove("logits")
+
+    def _get_generic_dist(self, tensordict: TensorDictBase, **kwargs) -> D.Distribution:
+        """Get distribution suitable for generic losses (all tokens).
+
+        This method enables logits computation for distribution creation.
+
+        This is a provisional method that will be replaced by the `get_dist` method once we have a better masking strategy.
+        """
+        self._in_get_dist_call = True
+        self.out_keys += ["logits"]
+        try:
+            return super()._get_generic_dist(tensordict, **kwargs)
+        finally:
+            self._in_get_dist_call = False
+            self.out_keys.remove("logits")
+
|
2180
|
+
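    # Editor's illustration (not part of the torchrl source): a minimal sketch of
    # how the masked distributions exposed above could be consumed in a loss,
    # assuming `policy` is a configured TransformersWrapper and `td` a TensorDict
    # holding the expected token entries (`target_tokens` below is a placeholder):
    #
    #     >>> dist = policy._get_sft_dist(td)        # distribution over response tokens
    #     >>> nll = -dist.log_prob(target_tokens)    # per-token negative log-likelihood
    #     >>> loss = nll.mean()
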
    def _pack_sequences(
        self,
        input_ids: torch.nested.NestedTensor,
        attention_mask: torch.nested.NestedTensor,
    ) -> tuple[torch.Tensor, torch.Tensor, dict[str, Any]]:
        """Pack sequences into a single tensor."""
        packed_input_ids = input_ids.values()
        lengths = input_ids.lengths()
        if lengths is None:
            offsets = input_ids.offsets()
            lengths = offsets.diff()
            offsets = offsets[1:]
        else:
            offsets = lengths.cumsum(0)
        # Create block-diagonal attention mask to prevent cross-sequence attention
        attention_mask = self._create_block_diagonal_attention_mask(lengths)
        # Create position IDs that restart for each sequence
        position_ids = self._create_packed_position_ids(
            lengths, total_length=packed_input_ids.numel()
        )

        packing_metadata = {
            "sequence_lengths": lengths,
            "cumulative_lengths": offsets,
            "attention_mask": attention_mask,
            "position_ids": position_ids,
        }

        return (
            packed_input_ids.unsqueeze(0),
            attention_mask.unsqueeze(0),
            packing_metadata,
        )

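    # Editor's illustration (not part of the torchrl source): a minimal sketch of
    # what `_pack_sequences` consumes for two sequences of lengths 3 and 2, using
    # a jagged nested tensor:
    #
    #     >>> import torch
    #     >>> ids = torch.nested.nested_tensor(
    #     ...     [torch.tensor([5, 6, 7]), torch.tensor([8, 9])], layout=torch.jagged
    #     ... )
    #     >>> ids.values()   # packed ids: tensor([5, 6, 7, 8, 9])
    #     >>> ids.offsets()  # tensor([0, 3, 5]) -> lengths [3, 2]
    #
    # The method returns the packed ids with a leading batch dim of 1, the
    # block-diagonal attention mask and the packing metadata dict shown above.
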
    def _model_forward_with_padded_sequences(
        self,
        tokens_full_padded: torch.Tensor,
        attention_mask_full_padded: torch.Tensor,
        *,
        pad_val: float | int | torch.Tensor | None = None,
        logits_only: bool = False,
        **kwargs,
    ):
        """Forward pass with padded sequences."""
        # Error handling for empty sequences
        if tokens_full_padded.numel() == 0:
            raise ValueError(
                "Input contains empty sequences. Packing/padding requires at least one token per sequence."
            )
        # Error handling for overlong sequences
        config = getattr(self.model, "config", None)
        max_len = getattr(config, "max_position_embeddings", None)
        if max_len is not None and tokens_full_padded.shape[-1] > max_len:
            raise ValueError(
                f"Input sequence length ({tokens_full_padded.shape[-1]}) exceeds model's max_position_embeddings ({max_len}). Consider truncating or splitting your input."
            )
        tokens_out_struct = self.model(
            tokens_full_padded, attention_mask_full_padded, **kwargs
        )
        (
            log_probs_full_padded,
            logits_full_padded,
        ) = self._compute_log_probs_from_model_output(
            tokens_out_struct,
            tokens_full_padded,
            attention_mask_full_padded,
            pad_val,
            logits_only=logits_only,
        )
        return log_probs_full_padded, logits_full_padded

    def _model_forward_with_packed_sequences(
        self,
        flat_input_ids: torch.Tensor,
        block_diag_attention_mask: torch.Tensor,
        *,
        pad: bool = True,
        logits_only: bool = False,
        **kwargs,
    ):
        """Pack sequences into a single tensor and forward them through the model.

        Args:
            flat_input_ids (NestedTensor): NestedTensor of shape (batch_size, -1)
            block_diag_attention_mask (NestedTensor): NestedTensor of shape (batch_size, -1)

        Keyword Args:
            pad (bool): Whether to pad the output tensors.
            logits_only (bool): Whether to return only logits.
            kwargs (dict): Additional keyword arguments to pass to the model.

        Returns:
            A ``(log_probs, logits)`` tuple.
        """
        # Error handling for empty sequences
        if flat_input_ids.numel() == 0:
            raise ValueError(
                "Input contains empty sequences. Packing requires at least one token per sequence."
            )
        # Error handling for overlong sequences
        # Note: Skipping this check for nested tensors due to symbolic representation issues
        # The model will handle sequence length limits internally
        max_len = getattr(self.model.config, "max_position_embeddings", None)
        if max_len is not None and not hasattr(flat_input_ids, "size"):
            # Only check for regular tensors, not nested tensors
            actual_size = flat_input_ids.shape[-1]
            if actual_size > max_len:
                raise ValueError(
                    f"Input sequence length ({actual_size}) exceeds model's max_position_embeddings ({max_len}). Consider truncating or splitting your input."
                )
        (
            flat_input_ids,
            block_diag_attention_mask,
            packing_metadata,
        ) = self._pack_sequences(flat_input_ids, block_diag_attention_mask)

        outputs = self.model(
            input_ids=flat_input_ids,
            attention_mask=block_diag_attention_mask.unsqueeze(0),
            position_ids=packing_metadata["position_ids"],
            use_cache=False,  # Disable KV cache for packing
            **kwargs,
        )
        log_probs, logits = self._unpack_outputs(
            outputs, packing_metadata, flat_input_ids, pad=pad, logits_only=logits_only
        )
        return log_probs, logits

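    # Editor's illustration (not part of the torchrl source): the packed path
    # above roughly chains as follows, assuming `wrapper` is a TransformersWrapper
    # instance and `ids` / `mask` are jagged nested tensors:
    #
    #     >>> log_probs, logits = wrapper._model_forward_with_packed_sequences(
    #     ...     ids, mask, pad=True, logits_only=False
    #     ... )
    #
    # Internally this packs the sequences, runs a single forward pass with a
    # block-diagonal mask and per-sequence position ids (KV cache disabled),
    # then unpacks the flat outputs back into per-sequence tensors.
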
    def _unpack_outputs(
        self,
        outputs,
        packing_metadata: dict[str, Any],
        flat_input_ids: torch.Tensor,
        pad: bool = True,
        logits_only: bool = False,
    ) -> tuple[torch.Tensor | None, torch.Tensor]:
        """Unpack outputs using nested tensors - zero syncs."""
        # use cross_entropy to compute log_probs
        log_probs, logits = self._compute_log_probs_from_model_output(
            outputs,
            flat_input_ids,
            torch.ones_like(flat_input_ids, dtype=torch.bool),
            -100,
            logits_only=logits_only,
        )
        # check shapes: [1, L] for log_probs, [1, L, vocab_size] for logits
        sequence_lengths = packing_metadata["sequence_lengths"]
        if logits_only:
            log_probs = None
        else:
            if log_probs.shape != logits.shape[:2]:
                raise ValueError(
                    f"Log probs shape {log_probs.shape=} does not match logits shape {logits.shape[:2]=}"
                )
            if log_probs.ndim != 2:
                raise ValueError(f"Log probs shape {log_probs.shape=} is not 2D")
            if logits.ndim != 3:
                raise ValueError(f"Logits shape {logits.shape=} is not 3D")
            if log_probs.shape[1] != sequence_lengths.sum():
                raise ValueError(
                    f"Log probs shape {log_probs.shape=} does not match sequence lengths {sequence_lengths.sum()=}"
                )

            log_probs = log_probs.squeeze(0)
            nested_logprobs = torch.nested.nested_tensor_from_jagged(
                log_probs,
                lengths=sequence_lengths,
            )

        logits = logits.squeeze(0)
        nested_logits = torch.nested.nested_tensor_from_jagged(
            logits,  # Remove batch dim: (total_length, vocab_size)
            lengths=sequence_lengths,
        )

        if logits_only:
            if pad:
                return None, nested_logits.to_padded_tensor(padding=0.0)
            return None, nested_logits
        else:
            if pad:
                return nested_logprobs.to_padded_tensor(
                    padding=0.0
                ), nested_logits.to_padded_tensor(padding=0.0)
            return nested_logprobs, nested_logits

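    # Editor's illustration (not part of the torchrl source): how the
    # jagged-to-padded conversion used above behaves for lengths [3, 2]:
    #
    #     >>> import torch
    #     >>> flat = torch.arange(5.)  # packed values, total length 5
    #     >>> nt = torch.nested.nested_tensor_from_jagged(
    #     ...     flat, lengths=torch.tensor([3, 2])
    #     ... )
    #     >>> nt.to_padded_tensor(padding=0.0)  # -> [[0., 1., 2.], [3., 4., 0.]]
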
    def _create_block_diagonal_attention_mask(
        self, sequence_lengths: torch.Tensor
    ) -> torch.Tensor:
        """Efficient creation of a block-diagonal attention mask.

        Zero cuda syncs, no integer involved except len(tensor) - compilable.

        Args:
            sequence_lengths: Tensor of shape (batch_size,) containing the lengths of the sequences

        Returns:
            attention_mask: Tensor of shape (batch_size, total_length, total_length)
                where each sequence can only attend to itself.
        """
        seq_ids = torch.arange(len(sequence_lengths), device=sequence_lengths.device)
        position_to_seq_id = seq_ids.repeat_interleave(sequence_lengths)

        attention_mask = position_to_seq_id.unsqueeze(
            1
        ) == position_to_seq_id.unsqueeze(0)
        return attention_mask

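    # Editor's illustration (not part of the torchrl source): for
    # sequence_lengths = [2, 3], position_to_seq_id is [0, 0, 1, 1, 1] and the
    # resulting 5x5 boolean mask is block diagonal:
    #
    #     [[True,  True,  False, False, False],
    #      [True,  True,  False, False, False],
    #      [False, False, True,  True,  True ],
    #      [False, False, True,  True,  True ],
    #      [False, False, True,  True,  True ]]
    #
    # so tokens can only attend to positions inside their own sequence.
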
    def repeat_interleave_causal(self, sequence_lengths: torch.Tensor) -> torch.Tensor:
        """Same as _create_block_diagonal_attention_mask, but with causal masking."""
        total_length = sequence_lengths.sum()

        seq_ids = torch.arange(len(sequence_lengths), device=sequence_lengths.device)
        position_to_seq_id = seq_ids.repeat_interleave(sequence_lengths)

        positions = torch.arange(int(total_length), device=sequence_lengths.device)

        same_sequence = position_to_seq_id.unsqueeze(1) == position_to_seq_id.unsqueeze(
            0
        )
        causal = positions.unsqueeze(0) <= positions.unsqueeze(1)

        attention_mask = same_sequence & causal
        return attention_mask

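    # Editor's illustration (not part of the torchrl source): with the causal
    # variant above and sequence_lengths = [2, 3], the same-sequence mask is
    # additionally intersected with a lower-triangular mask, giving:
    #
    #     [[True,  False, False, False, False],
    #      [True,  True,  False, False, False],
    #      [False, False, True,  False, False],
    #      [False, False, True,  True,  False],
    #      [False, False, True,  True,  True ]]
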
    def _create_packed_position_ids(
        self, sequence_lengths: torch.Tensor, total_length: int | None = None
    ) -> torch.Tensor:
        """Create position IDs that restart from 0 for each sequence.

        For sequences of length [3, 2], creates: [0, 1, 2, 0, 1]

        No cuda syncs.
        """
        if total_length is None:
            total_length = int(sequence_lengths.sum().item())

        # Create global position IDs: [0, 1, 2, 3, 4]
        global_positions = torch.arange(total_length, device=sequence_lengths.device)

        # Create sequence start offsets repeated for each position: [0, 0, 0, 3, 3]
        offsets = torch.cat(
            [
                torch.zeros(1, device=sequence_lengths.device),
                sequence_lengths.cumsum(0)[:-1],
            ]
        )
        sequence_starts = offsets.repeat_interleave(sequence_lengths)

        # Subtract to get local positions: [0, 1, 2, 0, 1]
        position_ids = global_positions - sequence_starts

        return position_ids.unsqueeze(0)  # (1, total_length)


class RemoteTransformersWrapper:
    """A remote Ray actor wrapper for TransformersWrapper that provides a simplified interface.

    This class wraps a TransformersWrapper instance as a Ray actor, allowing remote execution
    while providing a clean interface that doesn't require explicit `remote()` and `get()` calls.

    Args:
        model (str): The Hugging Face Transformers model to wrap.
            Must be a string (model name or path) that will be passed to `transformers.AutoModelForCausalLM.from_pretrained`.
            Transformers models are not serializable, so only model names/paths are supported.
        max_concurrency (int, optional): Maximum number of concurrent calls to the remote actor. Defaults to 16.
        validate_model (bool, optional): Whether to validate the model. Defaults to True.
        actor_name (str, optional): Name of the Ray actor to create or reuse. If an actor with this
            name already exists, it is reused instead of creating a new one. Defaults to None.
        num_gpus (int, optional): Number of GPUs to use. Defaults to 1.
        num_cpus (int, optional): Number of CPUs to use. Defaults to 1.
        **kwargs: All other arguments are passed directly to TransformersWrapper.

    Example:
        >>> import ray
        >>> from torchrl.modules.llm.policies import RemoteTransformersWrapper
        >>>
        >>> # Initialize Ray if not already done
        >>> if not ray.is_initialized():
        ...     ray.init()
        >>>
        >>> # Create remote wrapper
        >>> remote_wrapper = RemoteTransformersWrapper(
        ...     model="gpt2",
        ...     input_mode="history",
        ...     generate=True,
        ...     generate_kwargs={"max_new_tokens": 50}
        ... )
        >>>
        >>> # Use like a regular wrapper (no remote/get calls needed)
        >>> result = remote_wrapper(tensordict_input)
        >>> print(result["text"].response)
    """

    def __init__(
        self,
        model,
        max_concurrency: int = 16,
        validate_model: bool = True,
        actor_name: str | None = None,
        num_gpus: int = 1,
        num_cpus: int = 1,
        **kwargs,
    ):
        import ray

        # Validate model parameter - only strings are allowed for Transformers
        if not isinstance(model, str) and validate_model:
            raise ValueError(
                "For RemoteTransformersWrapper, the model parameter must be a string "
                f"(model name or path). Got type: {type(model)}. "
                "Transformers models are not serializable, so only model names/paths are supported. "
                "You can bypass this check by setting validate_model=False."
            )

        if not ray.is_initialized():
            ray.init()

        if actor_name is not None:
            # Check if an actor with this name already exists
            try:
                existing_actor = ray.get_actor(actor_name)
                # If we can get the actor, assume it's alive and use it
                self._remote_wrapper = existing_actor
                torchrl_logger.info(f"Using existing actor {actor_name}")
                return
            except ValueError:
                # Actor doesn't exist, create a new one
                torchrl_logger.info(f"Creating new actor {actor_name}")

        # Create the remote actor with the unique name
        self._remote_wrapper = (
            ray.remote(TransformersWrapper)
            .options(
                max_concurrency=max_concurrency,
                name=actor_name,
                num_gpus=num_gpus,
                num_cpus=num_cpus,
            )
            .remote(model, **kwargs)
        )

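    # Editor's illustration (not part of the torchrl source): because __init__
    # looks up existing actors by name, passing the same `actor_name` twice reuses
    # one underlying Ray actor instead of loading the model again (the model name
    # below is just an example):
    #
    #     >>> w1 = RemoteTransformersWrapper(model="gpt2", actor_name="shared_policy")
    #     >>> w2 = RemoteTransformersWrapper(model="gpt2", actor_name="shared_policy")
    #     >>> # w1 and w2 now talk to the same named actor
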
    def __call__(self, tensordict, **kwargs):
        """Forward pass that automatically handles remote execution."""
        import ray

        return ray.get(self._remote_wrapper.forward.remote(tensordict, **kwargs))

    def get_new_version(self, **kwargs):
        """Get a new version of the wrapper with altered parameters."""
        import ray

        return ray.get(self._remote_wrapper.get_new_version.remote(**kwargs))

    def get_dist(self, tensordict, **kwargs):
        """Get distribution from logits/log-probs with optional masking."""
        import ray

        return ray.get(self._remote_wrapper.get_dist.remote(tensordict, **kwargs))

    def get_dist_with_prompt_mask(self, tensordict, **kwargs):
        """Get distribution masked to only include response tokens (exclude prompt)."""
        import ray

        return ray.get(
            self._remote_wrapper.get_dist_with_prompt_mask.remote(tensordict, **kwargs)
        )

    def _get_dist_with_assistant_mask(self, tensordict, **kwargs):
        """Get distribution masked to only include assistant tokens."""
        import ray

        return ray.get(
            self._remote_wrapper._get_dist_with_assistant_mask.remote(
                tensordict, **kwargs
            )
        )

    def _get_dist_with_attention_mask(self, tensordict, **kwargs):
        """Get distribution masked using attention mask."""
        import ray

        return ray.get(
            self._remote_wrapper._get_dist_with_attention_mask.remote(
                tensordict, **kwargs
            )
        )

    def _get_dist_with_custom_mask(self, tensordict, **kwargs):
        """Get distribution with custom mask."""
        import ray

        return ray.get(
            self._remote_wrapper._get_dist_with_custom_mask.remote(tensordict, **kwargs)
        )

    def _get_sft_dist(self, tensordict, **kwargs):
        """Get distribution suitable for SFT loss (response tokens only)."""
        import ray

        return ray.get(self._remote_wrapper._get_sft_dist.remote(tensordict, **kwargs))

    def _get_rlhf_dist(self, tensordict, **kwargs):
        """Get distribution suitable for RLHF loss (assistant tokens only)."""
        import ray

        return ray.get(self._remote_wrapper._get_rlhf_dist.remote(tensordict, **kwargs))

    def _get_generic_dist(self, tensordict, **kwargs):
        """Get distribution suitable for generic losses (all tokens)."""
        import ray

        return ray.get(
            self._remote_wrapper._get_generic_dist.remote(tensordict, **kwargs)
        )

    def log_prob(self, data, **kwargs):
        """Compute log probabilities."""
        import ray

        return ray.get(self._remote_wrapper.log_prob.remote(data, **kwargs))

    def cleanup_batching(self):
        """Clean up batching resources."""
        import ray

        return ray.get(self._remote_wrapper.cleanup_batching.remote())

    def __del__(self):
        """Cleanup when the wrapper is destroyed."""
        try:
            import ray

            if hasattr(self, "_remote_wrapper") and ray.is_initialized():
                # Clean up batching resources
                try:
                    ray.get(self._remote_wrapper.cleanup_batching.remote())
                except Exception:
                    pass  # Ignore cleanup errors during destruction
        except Exception:
            pass  # Ignore any errors during cleanup

    def __enter__(self):
        """Context manager entry."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Context manager exit with cleanup."""
        self.cleanup_batching()

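    # Editor's illustration (not part of the torchrl source): since the wrapper
    # defines __enter__/__exit__, it can be used as a context manager so that
    # batching resources are cleaned up on exit:
    #
    #     >>> with RemoteTransformersWrapper(model="gpt2", generate=True) as wrapper:
    #     ...     out = wrapper(tensordict_input)
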
    def get_batching_state(self):
        """Get the current batching state."""
        import ray

        return ray.get(self._remote_wrapper.get_batching_state.remote())

    @property
    def generate(self):
        """Whether text generation is enabled."""
        import ray

        return ray.get(self._remote_wrapper.generate.remote)

    @property
    def pad_output(self):
        """Whether output sequences are padded."""
        import ray

        return ray.get(self._remote_wrapper.pad_output.remote)

    @property
    def text_key(self):
        """The key for text output."""
        import ray

        return ray.get(self._remote_wrapper.text_key.remote)

    @property
    def tokens_key(self):
        """The key for tokens output."""
        import ray

        return ray.get(self._remote_wrapper.tokens_key.remote)

    @property
    def masks_key(self):
        """The key for masks output."""
        import ray

        return ray.get(self._remote_wrapper.masks_key.remote)

    @property
    def log_probs_key(self):
        """The key for log probabilities output."""
        import ray

        return ray.get(self._remote_wrapper.log_probs_key.remote)

    @property
    def in_keys(self):
        """The input keys."""
        import ray

        return ray.get(self._remote_wrapper.in_keys.remote)

    @property
    def out_keys(self):
        """The output keys."""
        import ray

        return ray.get(self._remote_wrapper.out_keys.remote)

    @property
    def inplace(self):
        """Whether in-place operations are used."""
        import ray

        return ray.get(self._remote_wrapper.inplace.remote)

    @property
    def device(self):
        """The device used for computation."""
        import ray

        return ray.get(self._remote_wrapper.device.remote)

    @property
    def layout(self):
        """The layout used for output tensors."""
        import ray

        return ray.get(self._remote_wrapper.layout.remote)

    @property
    def num_samples(self):
        """The number of samples to generate."""
        import ray

        return ray.get(self._remote_wrapper.num_samples.remote)

    @property
    def batching(self):
        """Whether batching is enabled."""
        import ray

        return ray.get(self._remote_wrapper.batching.remote)

    @property
    def collector(self):
        """The collector associated with the module."""
        import ray

        return ray.get(self._remote_wrapper.collector.remote)

    @property
    def log_prob_keys(self):
        """The keys for log probabilities."""
        import ray

        return ray.get(self._remote_wrapper.log_prob_keys.remote)

    @log_prob_keys.setter
    def log_prob_keys(self, value):
        """Set the keys for log probabilities."""
        import ray

        ray.get(self._remote_wrapper.log_prob_keys.remote(value))

    @property
    def dist_params_keys(self):
        """The keys for distribution parameters."""
        import ray

        return ray.get(self._remote_wrapper.dist_params_keys.remote)

    @property
    def dist_sample_keys(self):
        """The keys for distribution samples."""
        import ray

        return ray.get(self._remote_wrapper.dist_sample_keys.remote)