torchrl-0.11.0-cp314-cp314t-manylinux_2_28_aarch64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/benchmark_batched_envs.py +104 -0
- benchmarks/conftest.py +91 -0
- benchmarks/ecosystem/gym_env_throughput.py +321 -0
- benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
- benchmarks/requirements.txt +7 -0
- benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
- benchmarks/test_collectors_benchmark.py +240 -0
- benchmarks/test_compressed_storage_benchmark.py +145 -0
- benchmarks/test_envs_benchmark.py +133 -0
- benchmarks/test_llm.py +101 -0
- benchmarks/test_non_tensor_env_benchmark.py +70 -0
- benchmarks/test_objectives_benchmarks.py +1199 -0
- benchmarks/test_replaybuffer_benchmark.py +254 -0
- sota-check/README.md +35 -0
- sota-implementations/README.md +142 -0
- sota-implementations/a2c/README.md +39 -0
- sota-implementations/a2c/a2c_atari.py +291 -0
- sota-implementations/a2c/a2c_mujoco.py +273 -0
- sota-implementations/a2c/utils_atari.py +240 -0
- sota-implementations/a2c/utils_mujoco.py +160 -0
- sota-implementations/bandits/README.md +7 -0
- sota-implementations/bandits/dqn.py +126 -0
- sota-implementations/cql/cql_offline.py +198 -0
- sota-implementations/cql/cql_online.py +249 -0
- sota-implementations/cql/discrete_cql_offline.py +180 -0
- sota-implementations/cql/discrete_cql_online.py +227 -0
- sota-implementations/cql/utils.py +471 -0
- sota-implementations/crossq/crossq.py +271 -0
- sota-implementations/crossq/utils.py +320 -0
- sota-implementations/ddpg/ddpg.py +231 -0
- sota-implementations/ddpg/utils.py +325 -0
- sota-implementations/decision_transformer/dt.py +163 -0
- sota-implementations/decision_transformer/lamb.py +167 -0
- sota-implementations/decision_transformer/online_dt.py +178 -0
- sota-implementations/decision_transformer/utils.py +562 -0
- sota-implementations/discrete_sac/discrete_sac.py +243 -0
- sota-implementations/discrete_sac/utils.py +324 -0
- sota-implementations/dqn/README.md +30 -0
- sota-implementations/dqn/dqn_atari.py +272 -0
- sota-implementations/dqn/dqn_cartpole.py +236 -0
- sota-implementations/dqn/utils_atari.py +132 -0
- sota-implementations/dqn/utils_cartpole.py +90 -0
- sota-implementations/dreamer/README.md +129 -0
- sota-implementations/dreamer/dreamer.py +586 -0
- sota-implementations/dreamer/dreamer_utils.py +1107 -0
- sota-implementations/expert-iteration/README.md +352 -0
- sota-implementations/expert-iteration/ei_utils.py +770 -0
- sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
- sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
- sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
- sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
- sota-implementations/gail/gail.py +327 -0
- sota-implementations/gail/gail_utils.py +68 -0
- sota-implementations/gail/ppo_utils.py +157 -0
- sota-implementations/grpo/README.md +273 -0
- sota-implementations/grpo/grpo-async.py +437 -0
- sota-implementations/grpo/grpo-sync.py +435 -0
- sota-implementations/grpo/grpo_utils.py +843 -0
- sota-implementations/grpo/requirements_gsm8k.txt +11 -0
- sota-implementations/grpo/requirements_ifeval.txt +16 -0
- sota-implementations/impala/README.md +33 -0
- sota-implementations/impala/impala_multi_node_ray.py +292 -0
- sota-implementations/impala/impala_multi_node_submitit.py +284 -0
- sota-implementations/impala/impala_single_node.py +261 -0
- sota-implementations/impala/utils.py +184 -0
- sota-implementations/iql/discrete_iql.py +230 -0
- sota-implementations/iql/iql_offline.py +164 -0
- sota-implementations/iql/iql_online.py +225 -0
- sota-implementations/iql/utils.py +437 -0
- sota-implementations/multiagent/README.md +74 -0
- sota-implementations/multiagent/iql.py +237 -0
- sota-implementations/multiagent/maddpg_iddpg.py +266 -0
- sota-implementations/multiagent/mappo_ippo.py +267 -0
- sota-implementations/multiagent/qmix_vdn.py +271 -0
- sota-implementations/multiagent/sac.py +337 -0
- sota-implementations/multiagent/utils/__init__.py +4 -0
- sota-implementations/multiagent/utils/logging.py +151 -0
- sota-implementations/multiagent/utils/utils.py +43 -0
- sota-implementations/ppo/README.md +29 -0
- sota-implementations/ppo/ppo_atari.py +305 -0
- sota-implementations/ppo/ppo_mujoco.py +293 -0
- sota-implementations/ppo/utils_atari.py +238 -0
- sota-implementations/ppo/utils_mujoco.py +152 -0
- sota-implementations/ppo_trainer/train.py +21 -0
- sota-implementations/redq/README.md +7 -0
- sota-implementations/redq/redq.py +199 -0
- sota-implementations/redq/utils.py +1060 -0
- sota-implementations/sac/sac-async.py +266 -0
- sota-implementations/sac/sac.py +239 -0
- sota-implementations/sac/utils.py +381 -0
- sota-implementations/sac_trainer/train.py +16 -0
- sota-implementations/td3/td3.py +254 -0
- sota-implementations/td3/utils.py +319 -0
- sota-implementations/td3_bc/td3_bc.py +177 -0
- sota-implementations/td3_bc/utils.py +251 -0
- torchrl/__init__.py +144 -0
- torchrl/_extension.py +74 -0
- torchrl/_torchrl.cpython-314t-aarch64-linux-gnu.so +0 -0
- torchrl/_utils.py +1431 -0
- torchrl/collectors/__init__.py +48 -0
- torchrl/collectors/_base.py +1058 -0
- torchrl/collectors/_constants.py +88 -0
- torchrl/collectors/_multi_async.py +324 -0
- torchrl/collectors/_multi_base.py +1805 -0
- torchrl/collectors/_multi_sync.py +464 -0
- torchrl/collectors/_runner.py +581 -0
- torchrl/collectors/_single.py +2009 -0
- torchrl/collectors/_single_async.py +259 -0
- torchrl/collectors/collectors.py +62 -0
- torchrl/collectors/distributed/__init__.py +32 -0
- torchrl/collectors/distributed/default_configs.py +133 -0
- torchrl/collectors/distributed/generic.py +1306 -0
- torchrl/collectors/distributed/ray.py +1092 -0
- torchrl/collectors/distributed/rpc.py +1006 -0
- torchrl/collectors/distributed/sync.py +731 -0
- torchrl/collectors/distributed/utils.py +160 -0
- torchrl/collectors/llm/__init__.py +10 -0
- torchrl/collectors/llm/base.py +494 -0
- torchrl/collectors/llm/ray_collector.py +275 -0
- torchrl/collectors/llm/utils.py +36 -0
- torchrl/collectors/llm/weight_update/__init__.py +10 -0
- torchrl/collectors/llm/weight_update/vllm.py +348 -0
- torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
- torchrl/collectors/utils.py +433 -0
- torchrl/collectors/weight_update.py +591 -0
- torchrl/csrc/numpy_utils.h +38 -0
- torchrl/csrc/pybind.cpp +27 -0
- torchrl/csrc/segment_tree.h +458 -0
- torchrl/csrc/torch_utils.h +34 -0
- torchrl/csrc/utils.cpp +48 -0
- torchrl/csrc/utils.h +31 -0
- torchrl/data/__init__.py +187 -0
- torchrl/data/datasets/__init__.py +58 -0
- torchrl/data/datasets/atari_dqn.py +878 -0
- torchrl/data/datasets/common.py +281 -0
- torchrl/data/datasets/d4rl.py +489 -0
- torchrl/data/datasets/d4rl_infos.py +187 -0
- torchrl/data/datasets/gen_dgrl.py +375 -0
- torchrl/data/datasets/minari_data.py +643 -0
- torchrl/data/datasets/openml.py +177 -0
- torchrl/data/datasets/openx.py +798 -0
- torchrl/data/datasets/roboset.py +363 -0
- torchrl/data/datasets/utils.py +11 -0
- torchrl/data/datasets/vd4rl.py +432 -0
- torchrl/data/llm/__init__.py +34 -0
- torchrl/data/llm/dataset.py +491 -0
- torchrl/data/llm/history.py +1378 -0
- torchrl/data/llm/prompt.py +198 -0
- torchrl/data/llm/reward.py +225 -0
- torchrl/data/llm/topk.py +186 -0
- torchrl/data/llm/utils.py +543 -0
- torchrl/data/map/__init__.py +21 -0
- torchrl/data/map/hash.py +185 -0
- torchrl/data/map/query.py +204 -0
- torchrl/data/map/tdstorage.py +363 -0
- torchrl/data/map/tree.py +1434 -0
- torchrl/data/map/utils.py +103 -0
- torchrl/data/postprocs/__init__.py +8 -0
- torchrl/data/postprocs/postprocs.py +391 -0
- torchrl/data/replay_buffers/__init__.py +99 -0
- torchrl/data/replay_buffers/checkpointers.py +622 -0
- torchrl/data/replay_buffers/ray_buffer.py +292 -0
- torchrl/data/replay_buffers/replay_buffers.py +2376 -0
- torchrl/data/replay_buffers/samplers.py +2578 -0
- torchrl/data/replay_buffers/scheduler.py +265 -0
- torchrl/data/replay_buffers/storages.py +2412 -0
- torchrl/data/replay_buffers/utils.py +1042 -0
- torchrl/data/replay_buffers/writers.py +781 -0
- torchrl/data/tensor_specs.py +7101 -0
- torchrl/data/utils.py +334 -0
- torchrl/envs/__init__.py +265 -0
- torchrl/envs/async_envs.py +1105 -0
- torchrl/envs/batched_envs.py +3093 -0
- torchrl/envs/common.py +4241 -0
- torchrl/envs/custom/__init__.py +11 -0
- torchrl/envs/custom/chess.py +617 -0
- torchrl/envs/custom/llm.py +214 -0
- torchrl/envs/custom/pendulum.py +401 -0
- torchrl/envs/custom/san_moves.txt +29274 -0
- torchrl/envs/custom/tictactoeenv.py +288 -0
- torchrl/envs/env_creator.py +263 -0
- torchrl/envs/gym_like.py +752 -0
- torchrl/envs/libs/__init__.py +68 -0
- torchrl/envs/libs/_gym_utils.py +326 -0
- torchrl/envs/libs/brax.py +846 -0
- torchrl/envs/libs/dm_control.py +544 -0
- torchrl/envs/libs/envpool.py +447 -0
- torchrl/envs/libs/gym.py +2239 -0
- torchrl/envs/libs/habitat.py +138 -0
- torchrl/envs/libs/isaac_lab.py +87 -0
- torchrl/envs/libs/isaacgym.py +203 -0
- torchrl/envs/libs/jax_utils.py +166 -0
- torchrl/envs/libs/jumanji.py +963 -0
- torchrl/envs/libs/meltingpot.py +599 -0
- torchrl/envs/libs/openml.py +153 -0
- torchrl/envs/libs/openspiel.py +652 -0
- torchrl/envs/libs/pettingzoo.py +1042 -0
- torchrl/envs/libs/procgen.py +351 -0
- torchrl/envs/libs/robohive.py +429 -0
- torchrl/envs/libs/smacv2.py +645 -0
- torchrl/envs/libs/unity_mlagents.py +891 -0
- torchrl/envs/libs/utils.py +147 -0
- torchrl/envs/libs/vmas.py +813 -0
- torchrl/envs/llm/__init__.py +63 -0
- torchrl/envs/llm/chat.py +730 -0
- torchrl/envs/llm/datasets/README.md +4 -0
- torchrl/envs/llm/datasets/__init__.py +17 -0
- torchrl/envs/llm/datasets/gsm8k.py +353 -0
- torchrl/envs/llm/datasets/ifeval.py +274 -0
- torchrl/envs/llm/envs.py +789 -0
- torchrl/envs/llm/libs/README.md +3 -0
- torchrl/envs/llm/libs/__init__.py +8 -0
- torchrl/envs/llm/libs/mlgym.py +869 -0
- torchrl/envs/llm/reward/__init__.py +10 -0
- torchrl/envs/llm/reward/gsm8k.py +324 -0
- torchrl/envs/llm/reward/ifeval/README.md +13 -0
- torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
- torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
- torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
- torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
- torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
- torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
- torchrl/envs/llm/transforms/__init__.py +55 -0
- torchrl/envs/llm/transforms/browser.py +292 -0
- torchrl/envs/llm/transforms/dataloading.py +859 -0
- torchrl/envs/llm/transforms/format.py +73 -0
- torchrl/envs/llm/transforms/kl.py +1544 -0
- torchrl/envs/llm/transforms/policy_version.py +189 -0
- torchrl/envs/llm/transforms/reason.py +323 -0
- torchrl/envs/llm/transforms/tokenizer.py +321 -0
- torchrl/envs/llm/transforms/tools.py +1955 -0
- torchrl/envs/model_based/__init__.py +9 -0
- torchrl/envs/model_based/common.py +180 -0
- torchrl/envs/model_based/dreamer.py +112 -0
- torchrl/envs/transforms/__init__.py +147 -0
- torchrl/envs/transforms/functional.py +48 -0
- torchrl/envs/transforms/gym_transforms.py +203 -0
- torchrl/envs/transforms/module.py +341 -0
- torchrl/envs/transforms/r3m.py +372 -0
- torchrl/envs/transforms/ray_service.py +663 -0
- torchrl/envs/transforms/rb_transforms.py +214 -0
- torchrl/envs/transforms/transforms.py +11835 -0
- torchrl/envs/transforms/utils.py +94 -0
- torchrl/envs/transforms/vc1.py +307 -0
- torchrl/envs/transforms/vecnorm.py +845 -0
- torchrl/envs/transforms/vip.py +407 -0
- torchrl/envs/utils.py +1718 -0
- torchrl/envs/vec_envs.py +11 -0
- torchrl/modules/__init__.py +206 -0
- torchrl/modules/distributions/__init__.py +73 -0
- torchrl/modules/distributions/continuous.py +830 -0
- torchrl/modules/distributions/discrete.py +908 -0
- torchrl/modules/distributions/truncated_normal.py +187 -0
- torchrl/modules/distributions/utils.py +233 -0
- torchrl/modules/llm/__init__.py +62 -0
- torchrl/modules/llm/backends/__init__.py +65 -0
- torchrl/modules/llm/backends/vllm/__init__.py +94 -0
- torchrl/modules/llm/backends/vllm/_models.py +46 -0
- torchrl/modules/llm/backends/vllm/base.py +72 -0
- torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
- torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
- torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
- torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
- torchrl/modules/llm/policies/__init__.py +28 -0
- torchrl/modules/llm/policies/common.py +1809 -0
- torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
- torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
- torchrl/modules/llm/utils.py +23 -0
- torchrl/modules/mcts/__init__.py +21 -0
- torchrl/modules/mcts/scores.py +579 -0
- torchrl/modules/models/__init__.py +86 -0
- torchrl/modules/models/batchrenorm.py +119 -0
- torchrl/modules/models/decision_transformer.py +179 -0
- torchrl/modules/models/exploration.py +731 -0
- torchrl/modules/models/llm.py +156 -0
- torchrl/modules/models/model_based.py +596 -0
- torchrl/modules/models/models.py +1712 -0
- torchrl/modules/models/multiagent.py +1067 -0
- torchrl/modules/models/recipes/impala.py +185 -0
- torchrl/modules/models/utils.py +162 -0
- torchrl/modules/planners/__init__.py +10 -0
- torchrl/modules/planners/cem.py +228 -0
- torchrl/modules/planners/common.py +73 -0
- torchrl/modules/planners/mppi.py +265 -0
- torchrl/modules/tensordict_module/__init__.py +89 -0
- torchrl/modules/tensordict_module/actors.py +2457 -0
- torchrl/modules/tensordict_module/common.py +529 -0
- torchrl/modules/tensordict_module/exploration.py +814 -0
- torchrl/modules/tensordict_module/probabilistic.py +321 -0
- torchrl/modules/tensordict_module/rnn.py +1639 -0
- torchrl/modules/tensordict_module/sequence.py +132 -0
- torchrl/modules/tensordict_module/world_models.py +34 -0
- torchrl/modules/utils/__init__.py +38 -0
- torchrl/modules/utils/mappings.py +9 -0
- torchrl/modules/utils/utils.py +89 -0
- torchrl/objectives/__init__.py +78 -0
- torchrl/objectives/a2c.py +659 -0
- torchrl/objectives/common.py +753 -0
- torchrl/objectives/cql.py +1346 -0
- torchrl/objectives/crossq.py +710 -0
- torchrl/objectives/ddpg.py +453 -0
- torchrl/objectives/decision_transformer.py +371 -0
- torchrl/objectives/deprecated.py +516 -0
- torchrl/objectives/dqn.py +683 -0
- torchrl/objectives/dreamer.py +488 -0
- torchrl/objectives/functional.py +48 -0
- torchrl/objectives/gail.py +258 -0
- torchrl/objectives/iql.py +996 -0
- torchrl/objectives/llm/__init__.py +30 -0
- torchrl/objectives/llm/grpo.py +846 -0
- torchrl/objectives/llm/sft.py +482 -0
- torchrl/objectives/multiagent/__init__.py +8 -0
- torchrl/objectives/multiagent/qmixer.py +396 -0
- torchrl/objectives/ppo.py +1669 -0
- torchrl/objectives/redq.py +683 -0
- torchrl/objectives/reinforce.py +530 -0
- torchrl/objectives/sac.py +1580 -0
- torchrl/objectives/td3.py +570 -0
- torchrl/objectives/td3_bc.py +625 -0
- torchrl/objectives/utils.py +782 -0
- torchrl/objectives/value/__init__.py +28 -0
- torchrl/objectives/value/advantages.py +1956 -0
- torchrl/objectives/value/functional.py +1459 -0
- torchrl/objectives/value/utils.py +360 -0
- torchrl/record/__init__.py +17 -0
- torchrl/record/loggers/__init__.py +23 -0
- torchrl/record/loggers/common.py +48 -0
- torchrl/record/loggers/csv.py +226 -0
- torchrl/record/loggers/mlflow.py +142 -0
- torchrl/record/loggers/tensorboard.py +139 -0
- torchrl/record/loggers/trackio.py +163 -0
- torchrl/record/loggers/utils.py +78 -0
- torchrl/record/loggers/wandb.py +214 -0
- torchrl/record/recorder.py +554 -0
- torchrl/services/__init__.py +79 -0
- torchrl/services/base.py +109 -0
- torchrl/services/ray_service.py +453 -0
- torchrl/testing/__init__.py +107 -0
- torchrl/testing/assertions.py +179 -0
- torchrl/testing/dist_utils.py +122 -0
- torchrl/testing/env_creators.py +227 -0
- torchrl/testing/env_helper.py +35 -0
- torchrl/testing/gym_helpers.py +156 -0
- torchrl/testing/llm_mocks.py +119 -0
- torchrl/testing/mocking_classes.py +2720 -0
- torchrl/testing/modules.py +295 -0
- torchrl/testing/mp_helpers.py +15 -0
- torchrl/testing/ray_helpers.py +293 -0
- torchrl/testing/utils.py +190 -0
- torchrl/trainers/__init__.py +42 -0
- torchrl/trainers/algorithms/__init__.py +11 -0
- torchrl/trainers/algorithms/configs/__init__.py +705 -0
- torchrl/trainers/algorithms/configs/collectors.py +216 -0
- torchrl/trainers/algorithms/configs/common.py +41 -0
- torchrl/trainers/algorithms/configs/data.py +308 -0
- torchrl/trainers/algorithms/configs/envs.py +104 -0
- torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
- torchrl/trainers/algorithms/configs/logging.py +80 -0
- torchrl/trainers/algorithms/configs/modules.py +570 -0
- torchrl/trainers/algorithms/configs/objectives.py +177 -0
- torchrl/trainers/algorithms/configs/trainers.py +340 -0
- torchrl/trainers/algorithms/configs/transforms.py +955 -0
- torchrl/trainers/algorithms/configs/utils.py +252 -0
- torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
- torchrl/trainers/algorithms/configs/weight_update.py +159 -0
- torchrl/trainers/algorithms/ppo.py +373 -0
- torchrl/trainers/algorithms/sac.py +308 -0
- torchrl/trainers/helpers/__init__.py +40 -0
- torchrl/trainers/helpers/collectors.py +416 -0
- torchrl/trainers/helpers/envs.py +573 -0
- torchrl/trainers/helpers/logger.py +33 -0
- torchrl/trainers/helpers/losses.py +132 -0
- torchrl/trainers/helpers/models.py +658 -0
- torchrl/trainers/helpers/replay_buffer.py +59 -0
- torchrl/trainers/helpers/trainers.py +301 -0
- torchrl/trainers/trainers.py +2052 -0
- torchrl/weight_update/__init__.py +33 -0
- torchrl/weight_update/_distributed.py +749 -0
- torchrl/weight_update/_mp.py +624 -0
- torchrl/weight_update/_noupdate.py +102 -0
- torchrl/weight_update/_ray.py +1032 -0
- torchrl/weight_update/_rpc.py +284 -0
- torchrl/weight_update/_shared.py +891 -0
- torchrl/weight_update/llm/__init__.py +32 -0
- torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
- torchrl/weight_update/llm/vllm_nccl.py +710 -0
- torchrl/weight_update/utils.py +73 -0
- torchrl/weight_update/weight_sync_schemes.py +1244 -0
- torchrl-0.11.0.dist-info/METADATA +1308 -0
- torchrl-0.11.0.dist-info/RECORD +394 -0
- torchrl-0.11.0.dist-info/WHEEL +5 -0
- torchrl-0.11.0.dist-info/entry_points.txt +2 -0
- torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
- torchrl-0.11.0.dist-info/top_level.txt +7 -0
|
@@ -0,0 +1,2241 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
#
|
|
3
|
+
# This source code is licensed under the MIT license found in the
|
|
4
|
+
# LICENSE file in the root directory of this source tree.
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
import collections
|
|
8
|
+
|
|
9
|
+
import importlib.util
|
|
10
|
+
import threading
|
|
11
|
+
import warnings
|
|
12
|
+
from typing import Any, Literal, TYPE_CHECKING
|
|
13
|
+
|
|
14
|
+
import torch
|
|
15
|
+
from tensordict import (
|
|
16
|
+
lazy_stack,
|
|
17
|
+
LazyStackedTensorDict,
|
|
18
|
+
MetaData,
|
|
19
|
+
NonTensorStack,
|
|
20
|
+
set_list_to_stack,
|
|
21
|
+
TensorDict,
|
|
22
|
+
TensorDictBase,
|
|
23
|
+
)
|
|
24
|
+
from tensordict.tensorclass import from_dataclass, TensorClass
|
|
25
|
+
from tensordict.utils import _zip_strict, NestedKey
|
|
26
|
+
from torch import distributions as D
|
|
27
|
+
from torch.nn.utils.rnn import pad_sequence
|
|
28
|
+
|
|
29
|
+
from torchrl.envs.utils import _classproperty
|
|
30
|
+
from torchrl.modules.llm.policies.common import (
|
|
31
|
+
_batching,
|
|
32
|
+
_extract_responses_from_full_histories,
|
|
33
|
+
ChatHistory,
|
|
34
|
+
LLMWrapperBase,
|
|
35
|
+
LogProbs,
|
|
36
|
+
Masks,
|
|
37
|
+
Text,
|
|
38
|
+
Tokens,
|
|
39
|
+
)
|
|
40
|
+
from torchrl.modules.utils.utils import _unpad_tensors
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
_HAS_VLLM = importlib.util.find_spec("vllm") is not None
|
|
44
|
+
_HAS_TRANSFORMERS = importlib.util.find_spec("transformers") is not None
|
|
45
|
+
|
|
46
|
+
if TYPE_CHECKING:
|
|
47
|
+
from vllm.inputs import TokensPrompt # type: ignore[import-not-found]
|
|
48
|
+
from vllm.outputs import RequestOutput # type: ignore[import-not-found]
|
|
49
|
+
from vllm.sampling_params import SamplingParams # type: ignore[import-not-found]
|
|
50
|
+
elif _HAS_VLLM:
|
|
51
|
+
from vllm.outputs import RequestOutput
|
|
52
|
+
from vllm.sampling_params import SamplingParams
|
|
53
|
+
|
|
54
|
+
try:
|
|
55
|
+
from vllm.inputs import TokensPrompt
|
|
56
|
+
except ImportError:
|
|
57
|
+
# Fallback for older vLLM versions
|
|
58
|
+
TokensPrompt = None
|
|
59
|
+
else:
|
|
60
|
+
SamplingParams = None # Will error at usage if vLLM not available
|
|
61
|
+
RequestOutput = None
|
|
62
|
+
TokensPrompt = None
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
def _require_transformers() -> None:
|
|
66
|
+
if not _HAS_TRANSFORMERS:
|
|
67
|
+
raise ImportError(
|
|
68
|
+
"transformers is required for vLLMWrapper. Please install it with `pip install transformers`."
|
|
69
|
+
)
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
def _require_vllm():
|
|
73
|
+
"""Import vLLM lazily.
|
|
74
|
+
|
|
75
|
+
We intentionally avoid importing vLLM at module import time because importing vLLM can
|
|
76
|
+
load native extensions that may hard-crash the interpreter on some platforms.
|
|
77
|
+
"""
|
|
78
|
+
if not _HAS_VLLM:
|
|
79
|
+
raise ImportError(
|
|
80
|
+
"vllm is required for vLLMWrapper. Please install it with `pip install vllm`."
|
|
81
|
+
)
|
|
82
|
+
import vllm as _vllm # local import is intentional / required
|
|
83
|
+
|
|
84
|
+
return _vllm
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
# Import async vLLM engines
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
class vLLMWrapper(LLMWrapperBase):
|
|
91
|
+
"""A wrapper class for vLLM models, providing a consistent interface for text generation and log probability computation.
|
|
92
|
+
|
|
93
|
+
This class is a subclass of :class:`~torchrl.modules.llm.policies.LLMWrapperBase` and provides a unified API for handling different input
|
|
94
|
+
modalities (history, text, tokens) with consistent output structure using :class:`~tensordict.TensorClass` objects.
|
|
95
|
+
|
|
96
|
+
The wrapper supports both synchronous (vllm.LLM) and asynchronous (:class:`~torchrl.modules.llm.backends.AsyncVLLM`) vLLM engines.
|
|
97
|
+
|
|
98
|
+
.. note::
|
|
99
|
+
**Recommended: Use AsyncVLLM for better performance**
|
|
100
|
+
|
|
101
|
+
For distributed inference and better resource utilization, we recommend using
|
|
102
|
+
:class:`~torchrl.modules.llm.backends.AsyncVLLM` instead of the synchronous vllm.LLM:
|
|
103
|
+
|
|
104
|
+
>>> from torchrl.modules.llm.backends import AsyncVLLM
|
|
105
|
+
>>> from torchrl.modules.llm import vLLMWrapper
|
|
106
|
+
>>>
|
|
107
|
+
>>> # Recommended approach
|
|
108
|
+
>>> async_engine = AsyncVLLM.from_pretrained("Qwen/Qwen2.5-3B", num_replicas=2)
|
|
109
|
+
>>> wrapper = vLLMWrapper(async_engine, input_mode="history", generate=True)
|
|
110
|
+
|
|
111
|
+
AsyncVLLM provides:
|
|
112
|
+
- Better GPU utilization through Ray-based distribution
|
|
113
|
+
- Multiple replicas for higher throughput
|
|
114
|
+
- Native vLLM batching for optimal performance
|
|
115
|
+
- Automatic resource management and cleanup
|
|
116
|
+
|
|
117
|
+
Args:
|
|
118
|
+
model (vllm.LLM | AsyncVLLM | Ray Actor | str): The vLLM model to wrap.
|
|
119
|
+
- If a string, it will be converted to an AsyncVLLM instance (recommended)
|
|
120
|
+
- If a vllm.LLM instance, uses synchronous generation via `model.generate()`
|
|
121
|
+
- If an AsyncVLLM instance, uses async generation via `model.generate()`
|
|
122
|
+
- If a Ray actor with generate method, uses remote calls via `ray.get(model.generate.remote())`
|
|
123
|
+
|
|
124
|
+
Keyword Args:
|
|
125
|
+
tokenizer (transformers.tokenization_utils.PreTrainedTokenizer | str | None, optional): The tokenizer to use for encoding and decoding text.
|
|
126
|
+
If `None`, the tokenizer associated with the model will be used. If a string, it will be passed to `transformers.AutoTokenizer.from_pretrained`.
|
|
127
|
+
Defaults to `None`.
|
|
128
|
+
input_mode (str, optional): The input modality to use. Must be one of `"history"`, `"text"`, or `"tokens"`. Defaults to `"history"`.
|
|
129
|
+
input_key (str | None, optional): The key for the input data. If `None`, defaults to
|
|
130
|
+
- `("history", "prompt")` for `"history"` when `generate=True`, `("history", "full")` for `"history"` when `generate=False`
|
|
131
|
+
- `("text", "prompt")` for `"text"` when `generate=True`, `("text", "full")` for `"text"` when `generate=False`
|
|
132
|
+
- `("tokens", "prompt")` for `"tokens"` when `generate=True`, `("tokens", "full")` for `"tokens"` when `generate=False`
|
|
133
|
+
attention_mask_key (str, optional): The key for attention masks (used in `"tokens"` mode). Defaults to `"attention_mask"`.
|
|
134
|
+
|
|
135
|
+
.. warning:: This argument is under development and may change in the future.
|
|
136
|
+
|
|
137
|
+
generate (bool, optional): Whether to enable text generation. If `True`, the model will generate text based on the input.
|
|
138
|
+
If `False`, only log probabilities will be computed. Defaults to `True`.
|
|
139
|
+
return_log_probs (bool, optional): Whether to return log probabilities. Defaults to `True`.
|
|
140
|
+
generate_kwargs (dict | None, optional): Additional arguments to pass to the model's generate method. Defaults to `None`.
|
|
141
|
+
|
|
142
|
+
**Standardized Parameters (cross-backend compatible):**
|
|
143
|
+
|
|
144
|
+
* **max_new_tokens** (int): Maximum number of new tokens to generate (maps to vLLM's max_tokens)
|
|
145
|
+
* **num_return_sequences** (int): Number of sequences to return (maps to vLLM's n)
|
|
146
|
+
* **temperature** (float): Sampling temperature (0.0 = deterministic, higher = more random)
|
|
147
|
+
* **top_p** (float): Nucleus sampling parameter (0.0-1.0)
|
|
148
|
+
* **top_k** (int): Top-k sampling parameter
|
|
149
|
+
* **repetition_penalty** (float): Penalty for repeating tokens
|
|
150
|
+
* **do_sample** (bool): Whether to use sampling vs greedy decoding
|
|
151
|
+
* **num_beams** (int): Number of beams for beam search
|
|
152
|
+
* **length_penalty** (float): Penalty for sequence length
|
|
153
|
+
* **early_stopping** (bool): Whether to stop early in beam search
|
|
154
|
+
* **stop_sequences** (list): Sequences that stop generation (maps to vLLM's stop)
|
|
155
|
+
* **skip_special_tokens** (bool): Whether to skip special tokens in output
|
|
156
|
+
* **logprobs** (bool): Whether to return log probabilities
|
|
157
|
+
|
|
158
|
+
.. warning:: Usage of this parameter is discouraged as it may conflict with the `generate` parameter
|
|
159
|
+
of the class.
|
|
160
|
+
|
|
161
|
+
**vLLM-Specific Parameters:**
|
|
162
|
+
|
|
163
|
+
* **presence_penalty** (float): Penalty for token presence
|
|
164
|
+
* **frequency_penalty** (float): Penalty for token frequency
|
|
165
|
+
* **ignore_eos** (bool): Whether to ignore EOS token
|
|
166
|
+
* **prompt_logprobs** (bool): Whether to return prompt log probabilities
|
|
167
|
+
* **detokenize** (bool): Whether to detokenize output
|
|
168
|
+
* **include_stop_str_in_output** (bool): Whether to include stop strings in output
|
|
169
|
+
* **spaces_between_special_tokens** (bool): Whether to add spaces between special tokens
|
|
170
|
+
* **sampling_type** (str): Type of sampling to use
|
|
171
|
+
* **temperature_last** (bool): Whether to apply temperature only to last token
|
|
172
|
+
* **top_p_last** (bool): Whether to apply top_p only to last token
|
|
173
|
+
* **top_k_last** (bool): Whether to apply top_k only to last token
|
|
174
|
+
|
|
175
|
+
**Legacy Parameter Support:**
|
|
176
|
+
|
|
177
|
+
* **max_tokens** (int): Automatically converted to max_new_tokens
|
|
178
|
+
* **n** (int): Automatically converted to num_return_sequences
|
|
179
|
+
|
|
180
|
+
**Parameter Conflict Resolution:**
|
|
181
|
+
|
|
182
|
+
When both legacy (vLLM-specific) and standardized parameter names are provided,
|
|
183
|
+
a :exc:`ValueError` is raised to prevent confusion. For example:
|
|
184
|
+
|
|
185
|
+
* If both ``max_tokens`` and ``max_new_tokens`` are passed, an error is raised
|
|
186
|
+
* If both ``n`` and ``num_return_sequences`` are passed, an error is raised
|
|
187
|
+
|
|
188
|
+
This ensures clear parameter usage and prevents unexpected behavior.
|
|
189
|
+
|
|
190
|
+
tokenizer_kwargs (dict | None, optional): Additional arguments to pass to the tokenizer. Defaults to `None`.
|
|
191
|
+
pad_output (bool, optional): Whether to pad the output sequences to a uniform length. Defaults to `False`.
|
|
192
|
+
pad_model_input (bool, optional): Whether to pad the model input sequences to a uniform length.
|
|
193
|
+
This is not supported by vLLM.
|
|
194
|
+
inplace (Literal[True, False, "empty"] | None, optional): Determines how the module should handle in-place operations. Defaults to `True`.
|
|
195
|
+
device (torch.device | None, optional): The device to use for computation. Defaults to `None`.
|
|
196
|
+
layout (torch.layout | None, optional): The layout to use for the output tensors when `pad_output=False`. Defaults to `torch.strided`.
|
|
197
|
+
chat_template_name (Literal["chatml_format", "qwen"] | None, optional): The name of the chat template to use when applying the chat template to the history.
|
|
198
|
+
Defaults to `None`. For `input_mode="history"` only.
|
|
199
|
+
chat_template (str | None, optional): The chat template to use when applying the chat template to the history. Defaults to `None`.
|
|
200
|
+
For `input_mode="history"` only.
|
|
201
|
+
num_samples (int | None, optional): The number of samples to generate. Defaults to `None` (one sample, and no batch-dimension for it).
|
|
202
|
+
Can also be set via the `generate_kwargs["n"] = value` argument.
|
|
203
|
+
log_probs_key (NestedKey | None, optional): The key for the log probabilities :class:`~torchrl.modules.llm.policies.LogProbs` object. Defaults to `"log_probs"`.
|
|
204
|
+
text_key (NestedKey | None, optional): The key for the action :class:`~torchrl.modules.llm.policies.Text` object. Defaults to `"text"`.
|
|
205
|
+
tokens_key (NestedKey | None, optional): The key for the action :class:`~torchrl.modules.llm.policies.Tokens` object. Defaults to `"tokens"`.
|
|
206
|
+
masks_key (NestedKey | None, optional): The key for the action :class:`~torchrl.modules.llm.policies.Masks` object. Defaults to `"masks"`.
|
|
207
|
+
history_key (NestedKey | None, optional): The key for the action :class:`~torchrl.modules.llm.policies.ChatHistory` object. Defaults to `"history"`.
|
|
208
|
+
batching (bool, optional): Whether to enable batching. Defaults to `False`. See `Batching`_ below for more details.
|
|
209
|
+
min_batch_size (int | None, optional): The minimum batch size to use for batching. See `Batching`_ below for more details.
|
|
210
|
+
max_batch_size (int | None, optional): The maximum batch size to use for batching. See `Batching`_ below for more details.
|
|
211
|
+
batching_timeout (float, optional): The timeout for batching. See `Batching`_ below for more details.
|
|
212
|
+
|
|
213
|
+
.. _Batching:
|
|
214
|
+
|
|
215
|
+
**Batching**
|
|
216
|
+
|
|
217
|
+
Batching is a feature that allows the module to process multiple inputs in a single call.
|
|
218
|
+
It is designed to work in a multi-threaded environment.
|
|
219
|
+
To enable batching, it suffices to set `batching=True` which will set `min_batch_size` to 1 if not provided.
|
|
220
|
+
If you want to set a different value for `min_batch_size` or `max_batch_size` for a fine-grained control,
|
|
221
|
+
you can to set `batching=True` and then set `min_batch_size` or `max_batch_size` to a value greater or equal to 1.
|
|
222
|
+
The way batching works is as follows:
|
|
223
|
+
- If `min_batch_size` is not provided but `max_batch_size` is, `min_batch_size` is set to 1.
|
|
224
|
+
- If `max_batch_size` is not provided but `min_batch_size` is, `max_batch_size` is set to the number of inputs in the queue.
|
|
225
|
+
- When the model is called, a check is performed to see if the number of inputs in the queue is greater or equal to `min_batch_size`.
|
|
226
|
+
If it is, the batch is processed immediately, while waiting for the previous batch to be processed if the model is busy.
|
|
227
|
+
Otherwise, the input is added to the queue and the function waits for the batch to be completed.
|
|
228
|
+
While waiting for the batch to be completed, a timeout is set to `batching_timeout` seconds such that if the batch is not
|
|
229
|
+
completed after `batching_timeout` seconds, the remaining items to process are processed as is and the function returns after
|
|
230
|
+
at most `batching_timeout` seconds (plus the time to finish processing the previous and current batch).
|
|
231
|
+
|
|
232
|
+
Input Keys:
|
|
233
|
+
The input key depends on both `input_mode` and `generate`:
|
|
234
|
+
|
|
235
|
+
- If `input_mode="history"` and `generate=True`: `input_key` (defaults to `("history", "prompt")`)
|
|
236
|
+
- If `input_mode="history"` and `generate=False`: `input_key` (defaults to `("history", "full")`)
|
|
237
|
+
- If `input_mode="text"` and `generate=True`: `input_key` (defaults to `("text", "prompt")`)
|
|
238
|
+
- If `input_mode="text"` and `generate=False`: `input_key` (defaults to `("text", "full")`)
|
|
239
|
+
- If `input_mode="tokens"` and `generate=True`: `input_key` (defaults to `("tokens", "prompt")`)
|
|
240
|
+
- If `input_mode="tokens"` and `generate=False`: `input_key` (defaults to `("tokens", "full")`)
|
|
241
|
+
|
|
242
|
+
Output Keys:
|
|
243
|
+
The output keys are automatically determined based on the input_mode:
|
|
244
|
+
- **Tokens**: Always returned (`tokens_key`, defaults to `"tokens"`)
|
|
245
|
+
- **Text**: Returned for `"text"` and `"history"` modes (`text_key`, defaults to `"text"`)
|
|
246
|
+
- **History**: Returned only for `"history"` mode (`history_key`, defaults to `"history"`)
|
|
247
|
+
- **Masks**: Always returned (`masks_key`, defaults to `"masks"`)
|
|
248
|
+
- **Log Probs**: Returned when `return_log_probs=True` (`log_probs_key`, defaults to `"log_probs"`)
|
|
249
|
+
|
|
250
|
+
Example output structure for `input_mode="history"`::
|
|
251
|
+
|
|
252
|
+
TensorDict(
|
|
253
|
+
text=Text(prompt=..., response=..., full=...),
|
|
254
|
+
masks=Masks(all_attention_mask=..., all_assistant_mask=...),
|
|
255
|
+
tokens=Tokens(prompt=..., response=..., full=...),
|
|
256
|
+
log_probs=LogProbs(prompt=..., response=..., full=...),
|
|
257
|
+
history=ChatHistory(prompt=..., response=..., full=...)
|
|
258
|
+
)
|
|
259
|
+
|
|
260
|
+
Example:
|
|
261
|
+
>>> from vllm import LLM
|
|
262
|
+
>>> from transformers import AutoTokenizer
|
|
263
|
+
>>> from torchrl.data.llm import History
|
|
264
|
+
>>> from torchrl.modules.llm.policies import ChatHistory
|
|
265
|
+
>>>
|
|
266
|
+
>>> model = LLM("gpt2")
|
|
267
|
+
>>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
|
|
268
|
+
>>>
|
|
269
|
+
>>> # History input (recommended for RL environments)
|
|
270
|
+
>>> wrapper = vLLMWrapper(
|
|
271
|
+
... model,
|
|
272
|
+
... tokenizer=tokenizer,
|
|
273
|
+
... input_mode="history",
|
|
274
|
+
... generate=True,
|
|
275
|
+
... return_log_probs=True,
|
|
276
|
+
... generate_kwargs={
|
|
277
|
+
... "max_new_tokens": 50, # Standardized parameter
|
|
278
|
+
... "temperature": 0.7,
|
|
279
|
+
... "top_p": 0.9,
|
|
280
|
+
... "do_sample": True,
|
|
281
|
+
... }
|
|
282
|
+
... )
|
|
283
|
+
>>>
|
|
284
|
+
>>> history = History.from_chats([[
|
|
285
|
+
... {"role": "user", "content": "Hello"},
|
|
286
|
+
... {"role": "assistant", "content": "Hi there!"}
|
|
287
|
+
... ]])
|
|
288
|
+
>>> chat_history = ChatHistory(prompt=history)
|
|
289
|
+
>>> result = wrapper(TensorDict(history=chat_history, batch_size=(1,)))
|
|
290
|
+
>>> print(result["text"].response) # Generated text
|
|
291
|
+
>>> print(result["log_probs"].response) # Log probabilities
|
|
292
|
+
>>> print(result["history"].response) # History with response
|
|
293
|
+
|
|
294
|
+
Attributes:
|
|
295
|
+
collector: The collector associated with the module, if it exists.
|
|
296
|
+
|
|
297
|
+
.. seealso::
|
|
298
|
+
- :class:`~torchrl.modules.llm.policies.LLMWrapperBase`
|
|
299
|
+
- :class:`~torchrl.modules.llm.policies.TransformersWrapper`
|
|
300
|
+
"""
|
|
301
|
+
|
|
302
|
+
def __init__(
|
|
303
|
+
self,
|
|
304
|
+
model: Any, # vllm.LLM | AsyncVLLMEngineService | AsyncLLMEngineExtended | str
|
|
305
|
+
*,
|
|
306
|
+
tokenizer: callable | str | None = None, # type: ignore
|
|
307
|
+
input_mode: str = "history",
|
|
308
|
+
input_key: NestedKey | None = None,
|
|
309
|
+
attention_mask_key: str = "attention_mask",
|
|
310
|
+
generate: bool = True,
|
|
311
|
+
generate_kwargs: dict | None = None,
|
|
312
|
+
tokenizer_kwargs: dict | None = None,
|
|
313
|
+
pad_output: bool = False,
|
|
314
|
+
pad_model_input: bool | None = None,
|
|
315
|
+
inplace: Literal[True, False, "empty"] | None = None,
|
|
316
|
+
device: torch.device | None = None,
|
|
317
|
+
layout: torch.layout | None = None,
|
|
318
|
+
num_samples: int | None = None,
|
|
319
|
+
chat_template_name: Literal["chatml_format", "qwen"] | None = None,
|
|
320
|
+
chat_template: str | None = None,
|
|
321
|
+
return_log_probs: bool | None = None,
|
|
322
|
+
history_key: NestedKey | None = "history",
|
|
323
|
+
text_key: NestedKey | None = "text",
|
|
324
|
+
tokens_key: NestedKey | None = "tokens",
|
|
325
|
+
masks_key: NestedKey | None = "masks",
|
|
326
|
+
log_probs_key: NestedKey | None = "log_probs",
|
|
327
|
+
batching: bool | None = None,
|
|
328
|
+
min_batch_size: int | None = None,
|
|
329
|
+
max_batch_size: int | None = None,
|
|
330
|
+
batching_timeout: float = 10.0,
|
|
331
|
+
):
|
|
332
|
+
super().__init__()
|
|
333
|
+
|
|
334
|
+
if batching and min_batch_size is None:
|
|
335
|
+
min_batch_size = 1
|
|
336
|
+
elif (min_batch_size is not None or max_batch_size is not None) and (
|
|
337
|
+
batching is False
|
|
338
|
+
):
|
|
339
|
+
raise ValueError(
|
|
340
|
+
"min_batch_size and max_batch_size must be None if batching is False."
|
|
341
|
+
)
|
|
342
|
+
|
|
343
|
+
# Validate that min_batch_size <= max_batch_size when both are specified
|
|
344
|
+
if min_batch_size is not None and max_batch_size is not None:
|
|
345
|
+
if min_batch_size > max_batch_size:
|
|
346
|
+
raise ValueError(
|
|
347
|
+
f"min_batch_size ({min_batch_size}) must be <= max_batch_size ({max_batch_size})"
|
|
348
|
+
)
|
|
349
|
+
|
|
350
|
+
self._min_batch_size = min_batch_size
|
|
351
|
+
self._max_batch_size = max_batch_size
|
|
352
|
+
self._batching_timeout = batching_timeout
|
|
353
|
+
self._batch_queue = []
|
|
354
|
+
self._futures = []
|
|
355
|
+
if self.batching:
|
|
356
|
+
self._batching_lock = threading.Lock()
|
|
357
|
+
else:
|
|
358
|
+
self._batching_lock = None
|
|
359
|
+
|
|
360
|
+
_require_transformers()
|
|
361
|
+
|
|
362
|
+
# Detect and initialize model
|
|
363
|
+
if isinstance(model, str):
|
|
364
|
+
# Import lazily to avoid importing vLLM backends unless actually needed.
|
|
365
|
+
from torchrl.modules.llm.backends.vllm import ( # local import is intentional / required
|
|
366
|
+
AsyncVLLM,
|
|
367
|
+
)
|
|
368
|
+
|
|
369
|
+
model = AsyncVLLM.from_pretrained(model)
|
|
370
|
+
|
|
371
|
+
# Validate model type
|
|
372
|
+
model_type = type(model)
|
|
373
|
+
model_module = getattr(model_type, "__module__", "")
|
|
374
|
+
model_name = getattr(model_type, "__name__", "")
|
|
375
|
+
if model_name == "AsyncVLLM" and model_module.startswith(
|
|
376
|
+
"torchrl.modules.llm.backends.vllm"
|
|
377
|
+
):
|
|
378
|
+
self._model_type = "async_vllm"
|
|
379
|
+
elif model_name == "LLM" and model_module.startswith("vllm"):
|
|
380
|
+
self._model_type = "sync_vllm"
|
|
381
|
+
elif hasattr(model, "generate") and hasattr(model, "remote"):
|
|
382
|
+
# Ray actor with generate method
|
|
383
|
+
self._model_type = "ray_actor"
|
|
384
|
+
else:
|
|
385
|
+
raise ValueError(
|
|
386
|
+
f"model must be a string, vllm.LLM, AsyncVLLM, or Ray actor. Got {type(model)}"
|
|
387
|
+
)
|
|
388
|
+
|
|
389
|
+
if isinstance(tokenizer, str):
|
|
390
|
+
from transformers import AutoTokenizer
|
|
391
|
+
|
|
392
|
+
tokenizer = AutoTokenizer.from_pretrained(tokenizer)
|
|
393
|
+
# Import vLLM lazily: only needed if we are going to interact with vLLM types.
|
|
394
|
+
# (This keeps importing this module safe even if vLLM hard-crashes on import.)
|
|
395
|
+
if self._model_type in ("sync_vllm",):
|
|
396
|
+
_require_vllm()
|
|
397
|
+
|
|
398
|
+
# Validate input_mode
|
|
399
|
+
if input_mode not in ["history", "text", "tokens"]:
|
|
400
|
+
raise ValueError(
|
|
401
|
+
f"input_mode must be one of 'history', 'text', 'tokens'. Got '{input_mode}'"
|
|
402
|
+
)
|
|
403
|
+
|
|
404
|
+
self.model = model
|
|
405
|
+
self.input_mode = input_mode
|
|
406
|
+
self.attention_mask_key = attention_mask_key
|
|
407
|
+
self.generate = generate
|
|
408
|
+
if pad_model_input is not None:
|
|
409
|
+
raise ValueError("pad_model_input is not supported by vLLMWrapper.")
|
|
410
|
+
|
|
411
|
+
# Auto-determine what to return based on input mode
|
|
412
|
+
self.return_history = input_mode in ("history",)
|
|
413
|
+
self.return_text = input_mode in ("text", "history")
|
|
414
|
+
self.return_tokens = input_mode in ("tokens", "history", "text")
|
|
415
|
+
self.return_masks = True
|
|
416
|
+
if return_log_probs is False and not generate:
|
|
417
|
+
raise ValueError("return_log_probs must be True when generate=False.")
|
|
418
|
+
return_log_probs = (
|
|
419
|
+
True
|
|
420
|
+
if (return_log_probs is None and generate) or (not generate)
|
|
421
|
+
else bool(return_log_probs)
|
|
422
|
+
)
|
|
423
|
+
self.return_log_probs = return_log_probs
|
|
424
|
+
|
|
425
|
+
self.history_key = history_key
|
|
426
|
+
self.log_probs_key = log_probs_key
|
|
427
|
+
self.masks_key = masks_key
|
|
428
|
+
self.text_key = text_key
|
|
429
|
+
self.tokens_key = tokens_key
|
|
430
|
+
|
|
431
|
+
if not isinstance(pad_output, bool):
|
|
432
|
+
raise ValueError("pad_output must be a boolean")
|
|
433
|
+
self.pad_output = pad_output
|
|
434
|
+
self._device = device
|
|
435
|
+
if not pad_output and layout is None:
|
|
436
|
+
layout = torch.strided
|
|
437
|
+
self.layout = layout
|
|
438
|
+
padding_value = None
|
|
439
|
+
|
|
440
|
+
# Set input keys based on mode and generate parameter
|
|
441
|
+
if input_mode == "history":
|
|
442
|
+
if generate:
|
|
443
|
+
self.in_keys = [
|
|
444
|
+
("history", "prompt") if input_key is None else input_key
|
|
445
|
+
]
|
|
446
|
+
else:
|
|
447
|
+
self.in_keys = [("history", "full") if input_key is None else input_key]
|
|
448
|
+
elif input_mode == "text":
|
|
449
|
+
if generate:
|
|
450
|
+
self.in_keys = [("text", "prompt") if input_key is None else input_key]
|
|
451
|
+
else:
|
|
452
|
+
self.in_keys = [("text", "full") if input_key is None else input_key]
|
|
453
|
+
elif input_mode == "tokens":
|
|
454
|
+
if generate:
|
|
455
|
+
self.in_keys = [
|
|
456
|
+
("tokens", "prompt") if input_key is None else input_key
|
|
457
|
+
]
|
|
458
|
+
else:
|
|
459
|
+
self.in_keys = [("tokens", "full") if input_key is None else input_key]
|
|
460
|
+
else:
|
|
461
|
+
raise ValueError(f"Invalid input_mode: {input_mode}")
|
|
462
|
+
self.input_key = self.in_keys[0]
|
|
463
|
+
|
|
464
|
+
# Set output keys based on auto-determined return flags
|
|
465
|
+
self.out_keys = []
|
|
466
|
+
if self.return_text:
|
|
467
|
+
self.out_keys.append(self.text_key)
|
|
468
|
+
if self.return_masks:
|
|
469
|
+
self.out_keys.append(self.masks_key)
|
|
470
|
+
if self.return_tokens:
|
|
471
|
+
self.out_keys.append(self.tokens_key)
|
|
472
|
+
if self.return_log_probs:
|
|
473
|
+
self.out_keys.append(self.log_probs_key)
|
|
474
|
+
if self.return_history:
|
|
475
|
+
self.out_keys.append(self.history_key)
|
|
476
|
+
|
|
477
|
+
# Tokenizer setup
|
|
478
|
+
if not tokenizer_kwargs:
|
|
479
|
+
tokenizer_kwargs = {}
|
|
480
|
+
if not tokenizer_kwargs.setdefault("return_attention_mask", True):
|
|
481
|
+
raise RuntimeError("return_attention_mask must be True")
|
|
482
|
+
|
|
483
|
+
# If we don't pad, we use lists
|
|
484
|
+
return_tensors = "pt" if self.pad_output else False
|
|
485
|
+
if return_tensors:
|
|
486
|
+
if (
|
|
487
|
+
tokenizer_kwargs.setdefault("return_tensors", return_tensors)
|
|
488
|
+
!= return_tensors
|
|
489
|
+
):
|
|
490
|
+
raise RuntimeError
|
|
491
|
+
if tokenizer_kwargs.setdefault("padding", self.pad_output) not in (
|
|
492
|
+
self.pad_output,
|
|
493
|
+
):
|
|
494
|
+
raise RuntimeError
|
|
495
|
+
if tokenizer_kwargs.setdefault("padding_side", "left") != "left":
|
|
496
|
+
raise RuntimeError
|
|
497
|
+
|
|
498
|
+
self.tokenizer_kwargs = tokenizer_kwargs
|
|
499
|
+
|
|
500
|
+
# Get tokenizer if needed
|
|
501
|
+
if tokenizer is None:
|
|
502
|
+
try:
|
|
503
|
+
if hasattr(model, "get_tokenizer"):
|
|
504
|
+
tokenizer = model.get_tokenizer()
|
|
505
|
+
else:
|
|
506
|
+
# Try to extract model name and load tokenizer as fallback
|
|
507
|
+
model_name = self._extract_model_name(model)
|
|
508
|
+
if model_name:
|
|
509
|
+
warnings.warn(
|
|
510
|
+
f"No tokenizer provided. Attempting to load tokenizer from model name: {model_name}"
|
|
511
|
+
)
|
|
512
|
+
from transformers import AutoTokenizer
|
|
513
|
+
|
|
514
|
+
try:
|
|
515
|
+
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
|
516
|
+
except Exception as tokenizer_error:
|
|
517
|
+
warnings.warn(
|
|
518
|
+
f"Failed to load tokenizer from {model_name}: {tokenizer_error}"
|
|
519
|
+
)
|
|
520
|
+
else:
|
|
521
|
+
warnings.warn(
|
|
522
|
+
"No tokenizer provided and no tokenizer found in model."
|
|
523
|
+
)
|
|
524
|
+
except Exception as e:
|
|
525
|
+
warnings.warn(f"Could not get tokenizer from model: {e}")
|
|
526
|
+
self.tokenizer = tokenizer
|
|
527
|
+
|
|
528
|
+
if self.tokenizer is not None and (
|
|
529
|
+
not hasattr(self.tokenizer, "pad_token") or self.tokenizer.pad_token is None
|
|
530
|
+
):
|
|
531
|
+
self.tokenizer.pad_token = self.tokenizer.eos_token
|
|
532
|
+
if self.tokenizer is not None:
|
|
533
|
+
padding_value = self.tokenizer(self.tokenizer.pad_token)["input_ids"][0]
|
|
534
|
+
self.padding_value = padding_value
|
|
535
|
+
|
|
536
|
+
# Generate kwargs setup
|
|
537
|
+
if generate_kwargs is None:
|
|
538
|
+
generate_kwargs = {}
|
|
539
|
+
else:
|
|
540
|
+
generate_kwargs = dict(generate_kwargs)
|
|
541
|
+
|
|
542
|
+
# Standardize common parameters
|
|
543
|
+
generate_kwargs = self._standardize_generate_kwargs(generate_kwargs)
|
|
544
|
+
|
|
545
|
+
# Extract wrapper-specific parameters
|
|
546
|
+
vllm_specific_kwargs = self._get_wrapper_specific_kwargs(
|
|
547
|
+
generate_kwargs, "vllm"
|
|
548
|
+
)
|
|
549
|
+
|
|
550
|
+
# Convert common parameters back to vLLM format
|
|
551
|
+
vllm_kwargs = {}
|
|
552
|
+
for key, value in generate_kwargs.items():
|
|
553
|
+
if key in self.COMMON_GENERATION_PARAMS:
|
|
554
|
+
# Convert common names to vLLM names
|
|
555
|
+
if key == "max_new_tokens":
|
|
556
|
+
vllm_kwargs["max_tokens"] = value
|
|
557
|
+
elif key == "num_return_sequences":
|
|
558
|
+
vllm_kwargs["n"] = value
|
|
559
|
+
elif key == "stop_sequences":
|
|
560
|
+
vllm_kwargs["stop"] = value
|
|
561
|
+
elif key == "logprobs":
|
|
562
|
+
# vLLM expects int for logprobs, not bool
|
|
563
|
+
if isinstance(value, bool):
|
|
564
|
+
value = 1 if value else None
|
|
565
|
+
vllm_kwargs["logprobs"] = value
|
|
566
|
+
elif key == "do_sample":
|
|
567
|
+
# do_sample is handled through the sampling parameters
|
|
568
|
+
# If do_sample=False, we use greedy decoding (temperature=0)
|
|
569
|
+
# If do_sample=True, we use the provided sampling parameters
|
|
570
|
+
if not value:
|
|
571
|
+
vllm_kwargs["temperature"] = 0.0
|
|
572
|
+
# If do_sample=True, we keep the existing temperature/top_p/top_k values
|
|
573
|
+
elif key in ["length_penalty", "early_stopping", "num_beams"]:
|
|
574
|
+
# These parameters are not supported by vLLM, skip them
|
|
575
|
+
pass
|
|
576
|
+
else:
|
|
577
|
+
# Direct mapping for other common parameters
|
|
578
|
+
vllm_kwargs[key] = value
|
|
579
|
+
|
|
580
|
+
# Add vLLM-specific parameters
|
|
581
|
+
vllm_kwargs.update(vllm_specific_kwargs)
|
|
582
|
+
|
|
583
|
+
self.num_samples = num_samples
|
|
584
|
+
if vllm_kwargs.get("n", 1) > 1 or num_samples is not None:
|
|
585
|
+
if inplace in (True, "empty"):
|
|
586
|
+
raise ValueError(
|
|
587
|
+
"inplace must be False (or None) when generating more than one sample."
|
|
588
|
+
)
|
|
589
|
+
if inplace is None:
|
|
590
|
+
inplace = False
|
|
591
|
+
if (
|
|
592
|
+
vllm_kwargs.get("n", 1) > 1
|
|
593
|
+
and num_samples is not None
|
|
594
|
+
and vllm_kwargs.get("n", 1) != num_samples
|
|
595
|
+
):
|
|
596
|
+
raise ValueError("num_samples differs from generate_kwargs['n'].")
|
|
597
|
+
elif num_samples is None:
|
|
598
|
+
self.num_samples = vllm_kwargs.get("n", 1)
|
|
599
|
+
vllm_kwargs["n"] = self.num_samples
|
|
600
|
+
elif inplace is None:
|
|
601
|
+
inplace = True
|
|
602
|
+
|
|
603
|
+
self.inplace = inplace
|
|
604
|
+
|
|
605
|
+
# vLLM expects int for logprobs, not bool. Use 1 if True, None if False.
|
|
606
|
+
prompt_logprobs = 1 if return_log_probs else None
|
|
607
|
+
|
|
608
|
+
if not generate:
|
|
609
|
+
# We want only the log-probs, we generate a single token (that we then discard)
|
|
610
|
+
# and retrieve the prompt log-probs
|
|
611
|
+
vllm_kwargs["max_tokens"] = 1
|
|
612
|
+
if not return_log_probs:
|
|
613
|
+
raise ValueError("return_log_probs must be True when generate=False.")
|
|
614
|
+
|
|
615
|
+
vllm_kwargs.setdefault("detokenize", not pad_output)
|
|
616
|
+
vllm_kwargs.setdefault("prompt_logprobs", prompt_logprobs)
|
|
617
|
+
vllm_kwargs.setdefault("logprobs", 1 if return_log_probs else None)
|
|
618
|
+
vllm_kwargs.setdefault("include_stop_str_in_output", True)
|
|
619
|
+
vllm_kwargs.setdefault("skip_special_tokens", False)
|
|
620
|
+
|
|
621
|
+
sampling_params = SamplingParams(**vllm_kwargs)
|
|
622
|
+
self.sampling_params = sampling_params
|
|
623
|
+
|
|
624
|
+
# Additional transformers-specific settings
|
|
625
|
+
self.chat_template_name = chat_template_name
|
|
626
|
+
self.chat_template = chat_template
|
|
627
|
+
|
|
628
|
+
def get_new_version(self, **kwargs):
|
|
629
|
+
"""Returns a new version of the module with altered parameters.
|
|
630
|
+
|
|
631
|
+
For instance, the generate parameter can be altered to enable text generation or log-probabilities computation.
|
|
632
|
+
This is especially useful when one wants to avoid re-initializing the module with a new set of parameters, when the
|
|
633
|
+
same parameters could be used to gather log-probs.
|
|
634
|
+
|
|
635
|
+
Positional arguments are not supported.
|
|
636
|
+
|
|
637
|
+
See the class constructor for more details about the parameters.
|
|
638
|
+
"""
|
|
639
|
+
# Build the constructor arguments by using current values for missing parameters
|
|
640
|
+
constructor_kwargs = {}
|
|
641
|
+
|
|
642
|
+
# Model is always required
|
|
643
|
+
constructor_kwargs["model"] = kwargs.get("model", self.model)
|
|
644
|
+
|
|
645
|
+
# Check for each parameter and use current value if not provided
|
|
646
|
+
if "tokenizer" in kwargs:
|
|
647
|
+
constructor_kwargs["tokenizer"] = kwargs["tokenizer"]
|
|
648
|
+
elif hasattr(self, "tokenizer"):
|
|
649
|
+
constructor_kwargs["tokenizer"] = self.tokenizer
|
|
650
|
+
|
|
651
|
+
if "input_mode" in kwargs:
|
|
652
|
+
constructor_kwargs["input_mode"] = kwargs["input_mode"]
|
|
653
|
+
elif hasattr(self, "input_mode"):
|
|
654
|
+
constructor_kwargs["input_mode"] = self.input_mode
|
|
655
|
+
|
|
656
|
+
if "input_key" in kwargs:
|
|
657
|
+
constructor_kwargs["input_key"] = kwargs["input_key"]
|
|
658
|
+
# Since the input_key is dynamically determined, we don't want to set it here
|
|
659
|
+
# elif hasattr(self, "input_key"):
|
|
660
|
+
# constructor_kwargs["input_key"] = self.input_key
|
|
661
|
+
|
|
662
|
+
if "attention_mask_key" in kwargs:
|
|
663
|
+
constructor_kwargs["attention_mask_key"] = kwargs["attention_mask_key"]
|
|
664
|
+
elif hasattr(self, "attention_mask_key"):
|
|
665
|
+
constructor_kwargs["attention_mask_key"] = self.attention_mask_key
|
|
666
|
+
|
|
667
|
+
if "generate" in kwargs:
|
|
668
|
+
constructor_kwargs["generate"] = kwargs["generate"]
|
|
669
|
+
elif hasattr(self, "generate"):
|
|
670
|
+
constructor_kwargs["generate"] = self.generate
|
|
671
|
+
|
|
672
|
+
if "return_log_probs" in kwargs:
|
|
673
|
+
constructor_kwargs["return_log_probs"] = kwargs["return_log_probs"]
|
|
674
|
+
elif not constructor_kwargs.get("generate", True):
|
|
675
|
+
# if we are not generating, we want to return log-probs
|
|
676
|
+
constructor_kwargs["return_log_probs"] = True
|
|
677
|
+
elif hasattr(self, "return_log_probs"):
|
|
678
|
+
constructor_kwargs["return_log_probs"] = self.return_log_probs
|
|
679
|
+
|
|
680
|
+
if "generate_kwargs" in kwargs:
|
|
681
|
+
constructor_kwargs["generate_kwargs"] = kwargs["generate_kwargs"]
|
|
682
|
+
elif hasattr(self, "generate_kwargs"):
|
|
683
|
+
constructor_kwargs["generate_kwargs"] = self.generate_kwargs
|
|
684
|
+
|
|
685
|
+
if "pad_output" in kwargs:
|
|
686
|
+
constructor_kwargs["pad_output"] = kwargs["pad_output"]
|
|
687
|
+
elif hasattr(self, "pad_output"):
|
|
688
|
+
constructor_kwargs["pad_output"] = self.pad_output
|
|
689
|
+
|
|
690
|
+
if "tokenizer_kwargs" in kwargs:
|
|
691
|
+
constructor_kwargs["tokenizer_kwargs"] = kwargs["tokenizer_kwargs"]
|
|
692
|
+
elif hasattr(self, "tokenizer_kwargs"):
|
|
693
|
+
constructor_kwargs["tokenizer_kwargs"] = dict(self.tokenizer_kwargs)
|
|
694
|
+
if (
|
|
695
|
+
"pad_output" in kwargs
|
|
696
|
+
and kwargs.get("pad_output")
|
|
697
|
+
!= constructor_kwargs["tokenizer_kwargs"]["padding"]
|
|
698
|
+
):
|
|
699
|
+
constructor_kwargs["tokenizer_kwargs"]["padding"] = kwargs.get(
|
|
700
|
+
"pad_output"
|
|
701
|
+
)
|
|
702
|
+
|
|
703
|
+
if "inplace" in kwargs:
|
|
704
|
+
constructor_kwargs["inplace"] = kwargs["inplace"]
|
|
705
|
+
elif hasattr(self, "inplace"):
|
|
706
|
+
constructor_kwargs["inplace"] = self.inplace
|
|
707
|
+
|
|
708
|
+
if "device" in kwargs:
|
|
709
|
+
constructor_kwargs["device"] = kwargs["device"]
|
|
710
|
+
elif hasattr(self, "_device"):
|
|
711
|
+
constructor_kwargs["device"] = self._device
|
|
712
|
+
|
|
713
|
+
if "layout" in kwargs:
|
|
714
|
+
constructor_kwargs["layout"] = kwargs["layout"]
|
|
715
|
+
elif hasattr(self, "layout"):
|
|
716
|
+
constructor_kwargs["layout"] = self.layout
|
|
717
|
+
|
|
718
|
+
if "num_samples" in kwargs:
|
|
719
|
+
constructor_kwargs["num_samples"] = kwargs["num_samples"]
|
|
720
|
+
elif hasattr(self, "num_samples"):
|
|
721
|
+
constructor_kwargs["num_samples"] = self.num_samples
|
|
722
|
+
|
|
723
|
+
if "chat_template_name" in kwargs:
|
|
724
|
+
constructor_kwargs["chat_template_name"] = kwargs["chat_template_name"]
|
|
725
|
+
elif hasattr(self, "chat_template_name"):
|
|
726
|
+
constructor_kwargs["chat_template_name"] = self.chat_template_name
|
|
727
|
+
|
|
728
|
+
if "chat_template" in kwargs:
|
|
729
|
+
constructor_kwargs["chat_template"] = kwargs["chat_template"]
|
|
730
|
+
elif hasattr(self, "chat_template"):
|
|
731
|
+
constructor_kwargs["chat_template"] = self.chat_template
|
|
732
|
+
|
|
733
|
+
if "history_key" in kwargs:
|
|
734
|
+
constructor_kwargs["history_key"] = kwargs["history_key"]
|
|
735
|
+
elif hasattr(self, "history_key"):
|
|
736
|
+
constructor_kwargs["history_key"] = self.history_key
|
|
737
|
+
|
|
738
|
+
if "text_key" in kwargs:
|
|
739
|
+
constructor_kwargs["text_key"] = kwargs["text_key"]
|
|
740
|
+
elif hasattr(self, "text_key"):
|
|
741
|
+
constructor_kwargs["text_key"] = self.text_key
|
|
742
|
+
|
|
743
|
+
if "tokens_key" in kwargs:
|
|
744
|
+
constructor_kwargs["tokens_key"] = kwargs["tokens_key"]
|
|
745
|
+
elif hasattr(self, "tokens_key"):
|
|
746
|
+
constructor_kwargs["tokens_key"] = self.tokens_key
|
|
747
|
+
|
|
748
|
+
if "masks_key" in kwargs:
|
|
749
|
+
constructor_kwargs["masks_key"] = kwargs["masks_key"]
|
|
750
|
+
elif hasattr(self, "masks_key"):
|
|
751
|
+
constructor_kwargs["masks_key"] = self.masks_key
|
|
752
|
+
|
|
753
|
+
if "log_probs_key" in kwargs:
|
|
754
|
+
constructor_kwargs["log_probs_key"] = kwargs["log_probs_key"]
|
|
755
|
+
elif hasattr(self, "log_probs_key"):
|
|
756
|
+
constructor_kwargs["log_probs_key"] = self.log_probs_key
|
|
757
|
+
|
|
758
|
+
# Create and return new instance
|
|
759
|
+
return type(self)(**constructor_kwargs)
|
|
760
|
+
|
|
761
|
+
    def set_tokenizer(self, tokenizer):
        """Set the tokenizer for the wrapper. Useful for async engines where tokenizer retrieval is deferred."""
        self.tokenizer = tokenizer
        if self.tokenizer is not None and (
            not hasattr(self.tokenizer, "pad_token") or self.tokenizer.pad_token is None
        ):
            self.tokenizer.pad_token = self.tokenizer.eos_token
        if self.tokenizer is not None:
            padding_value = self.tokenizer(self.tokenizer.pad_token)["input_ids"][0]
        else:
            padding_value = None
        self.padding_value = padding_value

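    # Illustrative sketch (not part of the packaged source): `set_tokenizer` is meant for
    # the case where the tokenizer only becomes available after the engine is built, e.g.
    # with an async vLLM engine. The tokenizer name below is an assumption.
    #
    #     from transformers import AutoTokenizer
    #     tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
    #     wrapper.set_tokenizer(tokenizer)  # also derives wrapper.padding_value from the pad token
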
    def _extract_model_name(self, model) -> str | None:
        """Extract model name from different model types for tokenizer fallback."""
        try:
            # For AsyncVLLM, try to get the model name from engine_args
            if hasattr(model, "engine_args") and hasattr(model.engine_args, "model"):
                return model.engine_args.model

            # For vllm.LLM, try to get the model name
            elif hasattr(model, "llm_engine") and hasattr(
                model.llm_engine, "model_config"
            ):
                return getattr(model.llm_engine.model_config, "model", None)

            # For Ray actors, try to get model name via remote call
            elif hasattr(model, "remote") and hasattr(model, "get_model_name"):
                import ray

                try:
                    return ray.get(model.get_model_name.remote())
                except Exception:
                    pass

            # Try common attributes that might contain model name
            for attr in ["model_name", "model", "model_path", "_model_name"]:
                if hasattr(model, attr):
                    value = getattr(model, attr)
                    if isinstance(value, str):
                        return value

            return None
        except Exception:
            return None

    def _call_generate(self, *args, **kwargs):
        """Call generate method based on model type.

        In vLLM 0.14+, prompt_token_ids should be passed as TokensPrompt objects
        rather than as a keyword argument.
        """
        # Convert prompt_token_ids to TokensPrompt format for vLLM 0.14+ compatibility
        prompt_token_ids = kwargs.pop("prompt_token_ids", None)
        if prompt_token_ids is not None and TokensPrompt is not None:
            # Convert list of token ID lists to TokensPrompt objects
            if isinstance(prompt_token_ids, list) and len(prompt_token_ids) > 0:
                if isinstance(prompt_token_ids[0], list):
                    # List of token ID lists -> list of TokensPrompt
                    prompts = [
                        TokensPrompt(prompt_token_ids=tids) for tids in prompt_token_ids
                    ]
                else:
                    # Single token ID list -> single TokensPrompt
                    prompts = TokensPrompt(prompt_token_ids=prompt_token_ids)
                # Insert prompts as the first positional argument
                args = (prompts,) + args
        elif prompt_token_ids is not None:
            # Fallback for older vLLM versions that still support prompt_token_ids kwarg
            kwargs["prompt_token_ids"] = prompt_token_ids

        if self._model_type == "ray_actor":
            import ray

            return ray.get(self.model.generate.remote(*args, **kwargs))
        else:
            # Both sync_vllm and async_vllm have direct generate methods
            return self.model.generate(*args, **kwargs)

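    # Illustrative sketch (assumption, not packaged code): the TokensPrompt conversion done
    # above is roughly what one would write when calling vLLM directly on pre-tokenized input.
    #
    #     from vllm import LLM, SamplingParams
    #     from vllm.inputs import TokensPrompt
    #
    #     llm = LLM(model="facebook/opt-125m")
    #     outs = llm.generate(
    #         [TokensPrompt(prompt_token_ids=[1, 2, 3])],
    #         SamplingParams(max_tokens=8),
    #     )
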
    @set_list_to_stack(True)
    @_batching
    def forward(
        self,
        tensordict: TensorDictBase,
        *,
        tensordict_out: TensorDictBase | None = None,
        logits_only: bool = False,
        **kwargs,
    ) -> TensorDictBase:
        tensordict_orig = tensordict
        if not tensordict.ndim:
            if tensordict_out is not None:
                raise ValueError(
                    "tensordict_out must not be provided when tensordict.ndim == 0. If this is needed, "
                    "please submit an issue on github."
                )
            # unsqueeze - squeeze the input
            return self.forward(lazy_stack([tensordict]), logits_only=logits_only)[0]
        elif tensordict.ndim > 1:
            if tensordict_out is not None:
                raise ValueError(
                    "tensordict_out must not be provided when tensordict.ndim > 1. If this is needed, "
                    "please submit an issue on github."
                )
            return self.forward(tensordict.reshape(-1), logits_only=logits_only).view(
                tensordict.shape
            )

        if not isinstance(tensordict, LazyStackedTensorDict):
            tensordict = tensordict.to_lazystack(0)

        _source_device = None
        if self._device:
            _source_device = tensordict.device
        if tensordict.device:
            tensordict = tensordict.copy().clear_device_()

        if kwargs:
            from vllm import SamplingParams

            sampling_params = SamplingParams(**kwargs)
        else:
            sampling_params = self.sampling_params

        if self.num_samples is not None:
            out = (
                TensorDict(
                    device=tensordict.device,
                    batch_size=(
                        tensordict.batch_size[0],
                        self.num_samples,
                        *tensordict.batch_size[1:],
                    ),
                )
                .to_lazystack(1)
                .to_lazystack(0)
            )
        else:
            out = TensorDict(
                device=tensordict.device, batch_size=tensordict.batch_size
            ).to_lazystack(0)

        if self.input_mode == "history":
            if self.generate:
                out = self._from_vllm_generate_history(tensordict, sampling_params, out)
            else:
                out = self._from_vllm_logprobs_history(tensordict, sampling_params, out)
        elif self.input_mode == "text":
            if self.generate:
                out = self._from_vllm_generate_text(tensordict, sampling_params, out)
            else:
                out = self._from_vllm_logprobs_text(tensordict, sampling_params, out)
        elif self.input_mode == "tokens":
            if self.generate:
                out = self._from_vllm_generate_tokens(tensordict, sampling_params, out)
            else:
                out = self._from_vllm_logprobs_tokens(tensordict, sampling_params, out)

        if _source_device:
            out = out.to(_source_device)

        if tensordict_out is None:
            if self.inplace is True:
                # The output is the input
                tensordict_out = tensordict_orig
            elif self.inplace is False:
                # The output is the new structure
                tensordict_out = out
            elif self.inplace == "empty":
                # The output is empty
                tensordict_out = tensordict.empty()

        if tensordict_out is not None and tensordict_out is not out:
            result = tensordict_out.exclude(*self.out_keys, inplace=True)
            result.update(out, keys_to_update=self.out_keys)
        elif tensordict_out is out:
            result = out.select(*self.out_keys)
        elif self.inplace:
            result = out
            keys = list(set(self.out_keys + list(tensordict.keys(True, True))))
            result = tensordict.exclude(*self.out_keys, inplace=True).update(
                result, keys_to_update=keys
            )
        else:
            result = out
        return result

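    # Illustrative usage sketch (assumption, not packaged code): forward() expects a
    # (possibly lazy-stacked) TensorDict carrying the key that matches `input_mode`.
    #
    #     td = TensorDict(
    #         text=NonTensorStack("The capital of France is"), batch_size=(1,)
    #     )
    #     result = wrapper(td)              # __call__ dispatches to forward()
    #     result["text", "response"]        # generated continuation(s)
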
    def _from_vllm_generate_history(
        self,
        tensordict_input: TensorDictBase,
        sampling_params: Any,
        out: TensorDictBase,
    ) -> TensorDictBase:
        """Generate text from history input."""
        from torchrl.data.llm import History

        assert isinstance(
            tensordict_input, TensorDictBase
        ), f"tensordict_input must be TensorDictBase, got {type(tensordict_input)}"
        assert isinstance(
            sampling_params, SamplingParams
        ), f"sampling_params must be SamplingParams, got {type(sampling_params)}"
        assert isinstance(
            out, TensorDictBase
        ), f"out must be TensorDictBase, got {type(out)}"

        # Validate input
        if self.input_key not in tensordict_input:
            raise ValueError(
                f"Expected '{self.input_key}' key for history input mode, "
                f"but found keys: {list(tensordict_input.keys())}"
            )

        history = tensordict_input.get(self.input_key)
        if not isinstance(history, History):
            raise TypeError(
                f"Expected History object for '{self.input_key}', got {type(history)}"
            )

        # Apply chat template
        tokenizer_kwargs = {}
        if self.chat_template_name is not None:
            tokenizer_kwargs.setdefault("chat_template_name", self.chat_template_name)
        if self.chat_template is not None:
            tokenizer_kwargs.setdefault("chat_template", self.chat_template)
        tokenizer_kwargs.setdefault("add_generation_prompt", True)
        text_prompt = history.apply_chat_template(
            tokenizer=self.tokenizer, **tokenizer_kwargs
        )

        tokenizer_kwargs.setdefault("return_assistant_tokens_mask", False)
        tokenizer_kwargs.setdefault("tokenize", True)
        tokenizer_kwargs.setdefault("padding", False)
        tokenizer_kwargs.setdefault("return_dict", True)
        response_struct = history.apply_chat_template(
            tokenizer=self.tokenizer, **tokenizer_kwargs
        )
        tokens_prompt_padded = None
        tokens_prompt_unpadded = None
        if self.pad_output:
            tokens_prompt_padded = response_struct.get(
                "input_ids",
                as_padded_tensor=True,
                padding_value=self.padding_value,
                padding_side="left",
            )
        else:
            tokens_prompt_unpadded = response_struct.get("input_ids", as_list=True)

        result = self._generate_from_tokens(
            tokens_prompt_padded=tokens_prompt_padded,
            tokens_prompt_unpadded=tokens_prompt_unpadded,
            sampling_params=sampling_params,
            out=out,
        )

        # Generate using text path
        if self.pad_output:
            result[(self.tokens_key, "prompt")] = (
                tokens_prompt_padded
                if not self.num_samples
                else tokens_prompt_padded.unsqueeze(1).repeat(1, self.num_samples, 1)
            )
        else:
            tokens_prompt_nested = torch.nested.as_nested_tensor(tokens_prompt_unpadded)
            if not self.num_samples:
                result[(self.tokens_key, "prompt")] = tokens_prompt_nested
            else:
                for r in result.unbind(1):
                    r[(self.tokens_key, "prompt")] = tokens_prompt_nested

        text_result = Text._from_tensordict(result.empty())
        result.set(self.text_key, text_result)
        if not self.num_samples:
            text_result.prompt = text_prompt
        else:
            for r in result.unbind(1):
                r[self.text_key, "prompt"] = text_prompt
        with result.view(-1) as result_flat:
            if self.pad_output:
                tokens_full_padded = result_flat.get(
                    (self.tokens_key, "full"),
                    as_padded_tensor=True,
                    padding_side="right",
                    padding_value=self.padding_value,
                )
                if tokens_full_padded is None:
                    raise ValueError("tokens_full_padded is None")
                text_full = self.tokenizer.batch_decode(
                    tokens_full_padded, skip_special_tokens=False
                )
            else:
                tokens_full_unpadded = result_flat.get(
                    (self.tokens_key, "full"), as_list=True
                )
                if tokens_full_unpadded is None:
                    raise ValueError("tokens_full_unpadded is None")
                text_full = self.tokenizer.batch_decode(
                    tokens_full_unpadded, skip_special_tokens=False
                )
            text_prompt = result_flat[self.text_key, "prompt"]
            text_response = [
                txt[len(prompt) :]
                for txt, prompt in _zip_strict(text_full, text_prompt)
            ]
            result_flat.set((self.text_key, "full"), text_full)
            result_flat.set((self.text_key, "response"), text_response)

        # Now parse the full text back to a history object, and use the extra history objects
        # as response
        history_chat = ChatHistory._from_tensordict(result.empty())
        if self.num_samples is None:
            history_chat.prompt = history
        else:
            for h in history_chat.unbind(1):
                h.prompt = history
        with history_chat.view(-1) as history_chat_flat:
            prompt_histories = history_chat_flat.prompt
            # Extract response histories from full text
            h_responses = _extract_responses_from_full_histories(
                text_full, prompt_histories, self.chat_template_name, self.tokenizer
            )
            history_chat_flat.response = h_responses
            history_chat_flat.full = history_chat_flat.prompt.extend(
                h_responses, inplace=False, dim=-1
            )
        result.set(self.history_key, history_chat)
        return result

    def _from_vllm_logprobs_history(
        self,
        tensordict_input: TensorDictBase,
        sampling_params: Any,
        out: TensorDictBase,
    ) -> TensorDictBase:
        """Compute log-probs from history input."""
        assert isinstance(
            tensordict_input, TensorDictBase
        ), f"tensordict_input must be TensorDictBase, got {type(tensordict_input)}"
        assert isinstance(
            sampling_params, SamplingParams
        ), f"sampling_params must be SamplingParams, got {type(sampling_params)}"
        assert isinstance(
            out, TensorDictBase
        ), f"out must be TensorDictBase, got {type(out)}"

        from torchrl.data.llm import History

        # Validate input
        if self.input_key not in tensordict_input:
            raise ValueError(
                f"Expected '{self.input_key}' key for history input mode, "
                f"but found keys: {list(tensordict_input.keys())}"
            )

        history = tensordict_input.get(self.input_key)
        if not isinstance(history, History):
            raise TypeError(
                f"Expected History object for '{self.input_key}', got {type(history)}"
            )

        # Apply chat template
        tokenizer_kwargs = {}
        if self.chat_template_name is not None:
            tokenizer_kwargs.setdefault("chat_template_name", self.chat_template_name)
        if self.chat_template is not None:
            tokenizer_kwargs.setdefault("chat_template", self.chat_template)
        tokenizer_kwargs.setdefault("add_generation_prompt", False)
        text_full = history.apply_chat_template(
            tokenizer=self.tokenizer, **tokenizer_kwargs
        )
        tokenizer_kwargs.setdefault("return_assistant_tokens_mask", True)
        tokenizer_kwargs.setdefault("tokenize", True)
        tokenizer_kwargs.setdefault("padding", False)
        tokenizer_kwargs.setdefault("return_dict", True)
        response_struct = history.apply_chat_template(
            tokenizer=self.tokenizer, **tokenizer_kwargs
        )

        result = self._logprobs_from_tokens(
            response_struct=response_struct, sampling_params=sampling_params, out=out
        )
        text_result = Text._from_tensordict(result.empty())
        result.set(self.text_key, text_result)
        result[self.text_key, "full"] = text_full
        result.set(self.history_key, ChatHistory(full=history))
        return result

    def _from_vllm_generate_text(
        self, td: TensorDictBase, sampling_params: Any, out: TensorDictBase
    ) -> TensorDictBase:
        """Generate text from text input."""
        # Type assertions
        assert isinstance(
            td, TensorDictBase
        ), f"td must be TensorDictBase, got {type(td)}"
        assert isinstance(
            sampling_params, SamplingParams
        ), f"sampling_params must be SamplingParams, got {type(sampling_params)}"
        assert isinstance(
            out, TensorDictBase
        ), f"out must be TensorDictBase, got {type(out)}"

        # Validate input
        if self.input_key not in td:
            raise ValueError(
                f"Expected '{self.input_key}' key for text input mode, "
                f"but found keys: {list(td.keys())}"
            )

        text = td.get(self.input_key)
        if text is None:
            raise ValueError(f"Expected '{self.input_key}' key for text input mode")

        return self._generate_from_text(text, sampling_params, out)

    def _from_vllm_logprobs_text(
        self, td: TensorDictBase, sampling_params: Any, out: TensorDictBase
    ) -> TensorDictBase:
        """Compute log-probs from text input."""
        # Type assertions
        assert isinstance(
            td, TensorDictBase
        ), f"td must be TensorDictBase, got {type(td)}"
        assert isinstance(
            sampling_params, SamplingParams
        ), f"sampling_params must be SamplingParams, got {type(sampling_params)}"
        assert isinstance(
            out, TensorDictBase
        ), f"out must be TensorDictBase, got {type(out)}"

        # Validate input
        if self.input_key not in td:
            raise ValueError(
                f"Expected '{self.input_key}' key for text input mode, "
                f"but found keys: {list(td.keys())}"
            )

        text = td.get(self.input_key)
        if text is None:
            raise ValueError(f"Expected '{self.input_key}' key for text input mode")

        return self._logprobs_from_text(text, sampling_params, out)

    def _from_vllm_generate_tokens(
        self, td: TensorDictBase, sampling_params: Any, out: TensorDictBase
    ) -> TensorDictBase:
        """Generate text from tokens input."""
        # Type assertions
        assert isinstance(
            td, TensorDictBase
        ), f"td must be TensorDictBase, got {type(td)}"
        assert isinstance(
            sampling_params, SamplingParams
        ), f"sampling_params must be SamplingParams, got {type(sampling_params)}"
        assert isinstance(
            out, TensorDictBase
        ), f"out must be TensorDictBase, got {type(out)}"

        # Validate input
        if self.input_key not in td:
            raise ValueError(
                f"Expected '{self.input_key}' key for tokens input mode, "
                f"but found keys: {list(td.keys())}"
            )

        tokens_prompt_padded = None
        tokens_prompt_unpadded = None
        if self.pad_output:
            tokens_prompt_padded = td.get(self.input_key)
        else:
            tokens_prompt_unpadded = list(td.get(self.input_key, as_list=True))
            # make sure we remove the padding tokens
            tokens_prompt_unpadded = [
                tokens[tokens != self.padding_value]
                for tokens in tokens_prompt_unpadded
            ]

        return self._generate_from_tokens(
            tokens_prompt_unpadded=tokens_prompt_unpadded,
            tokens_prompt_padded=tokens_prompt_padded,
            sampling_params=sampling_params,
            out=out,
        )

    def _from_vllm_logprobs_tokens(
        self, td: TensorDictBase, sampling_params: Any, out: TensorDictBase
    ) -> TensorDictBase:
        """Compute log-probs from tokens input."""
        # Type assertions
        assert isinstance(
            td, TensorDictBase
        ), f"td must be TensorDictBase, got {type(td)}"
        assert isinstance(
            sampling_params, SamplingParams
        ), f"sampling_params must be SamplingParams, got {type(sampling_params)}"
        assert isinstance(
            out, TensorDictBase
        ), f"out must be TensorDictBase, got {type(out)}"

        # Validate input
        if self.input_key not in td:
            raise ValueError(
                f"Expected '{self.input_key}' key for tokens input mode, "
                f"but found keys: {list(td.keys())}"
            )

        tokens_full_padded = None
        tokens_full_unpadded = None
        if self.pad_output:
            tokens_full_padded = td.get(self.input_key)
        else:
            tokens_full_unpadded = list(td.get(self.input_key, as_list=True))
            # make sure we remove the padding tokens
            tokens_full_unpadded = [
                tokens[tokens != self.padding_value] for tokens in tokens_full_unpadded
            ]

        return self._logprobs_from_tokens(
            response_struct=None,
            tokens_full_unpadded=tokens_full_unpadded,
            tokens_full_padded=tokens_full_padded,
            sampling_params=sampling_params,
            out=out,
        )

    def _cat_text(
        self, text: str | list[str], response_text: str | list[str] | None
    ) -> str | list[str]:
        """Concatenate text and response text."""
        assert isinstance(
            text, (str, list)
        ), f"text must be str or list, got {type(text)}"

        # Handle None response_text (when tokenizer is not available)
        if response_text is None:
            raise RuntimeError(
                "response_text is None, likely due to missing tokenizer. "
                "Cannot decode vLLM response without a tokenizer. "
                "Please provide a tokenizer explicitly or ensure the model has one available."
            )

        assert isinstance(
            response_text, (str, list)
        ), f"response_text must be str or list, got {type(response_text)}"

        if isinstance(text, list):
            return [self._cat_text(t, t_) for t, t_ in _zip_strict(text, response_text)]
        else:
            return text + response_text

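    # Illustrative sketch (assumption): _cat_text mirrors plain string concatenation,
    # applied element-wise when given lists, e.g.
    #
    #     self._cat_text(["a", "b"], ["1", "2"])  # -> ["a1", "b2"]
    #     self._cat_text("a", "1")                # -> "a1"
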
    def _generate_from_text(
        self,
        text: str | list[str] | NonTensorStack,
        sampling_params: Any,
        out: TensorDictBase,
    ) -> TensorDictBase:
        """Generate text from text input."""
        # Convert text to list format
        if isinstance(text, str):
            text = [text]
        elif not isinstance(text, list):
            text = text.tolist()

        assert isinstance(
            text, (str, list)
        ), f"text must be str or list, got {type(text)}"
        assert isinstance(
            sampling_params, SamplingParams
        ), f"sampling_params must be SamplingParams, got {type(sampling_params)}"
        assert isinstance(
            out, TensorDictBase
        ), f"out must be TensorDictBase, got {type(out)}"

        generate_kwargs = {"sampling_params": sampling_params}
        args = ()

        # Convert text to list format
        if isinstance(text, str):
            text = [text]
        elif not isinstance(text, list):
            text = text.tolist()

        # Call generate based on model type
        request_output = self._call_generate(text, *args, **generate_kwargs)

        request_output_tc = _RequestOutput_tc.from_request_output(request_output)

        # Extract response tokens and text
        outputs = (
            request_output_tc.outputs.view(-1)
            if self.num_samples is not None
            else request_output_tc.outputs
        )
        if self.pad_output:
            response_tokens_padded = outputs.view(-1).get(
                "token_ids",
                as_padded_tensor=self.pad_output,
                padding_value=self.padding_value,
                padding_side="right",
            )
        response_tokens_list = outputs.view(-1).get(
            "token_ids",
            as_list=True,
        )
        self._check_not_padded(response_tokens_list)
        if self.tokenizer is not None:
            response_text = self.tokenizer.batch_decode(
                response_tokens_list, skip_special_tokens=False
            )
        else:
            response_text = None

        # Build output TensorClass objects

        masks_obj = Masks._from_tensordict(out.empty())
        masks_obj.all_attention_mask = None
        masks_obj.all_assistant_mask = None
        masks_obj.padded = MetaData(self.pad_output)
        out.set(self.masks_key, masks_obj)

        if self.num_samples is not None:
            text = [txt for txt in text for _ in range(self.num_samples)]
        text_obj = Text._from_tensordict(out.empty())
        with text_obj.view(-1) as text_obj_flat:
            text_obj_flat.prompt = text
            text_obj_flat.response = response_text
            text_obj_flat.full = self._cat_text(text, response_text)
        out.set(self.text_key, text_obj)

        tokens_obj = Tokens._from_tensordict(out.empty())
        with tokens_obj.view(-1) as tokens_obj_flat:
            tokens_obj_flat.prompt = None  # We don't have prompt tokens in this path
            if self.pad_output:
                tokens_obj_flat.response = response_tokens_padded
                self._check_padded(response_tokens_padded)
            else:
                tokens_obj_flat.response = response_tokens_list
                self._check_not_padded(response_tokens_list)
            tokens_obj_flat.full = (
                None  # we don't have prompt tokens in this path so no all_tokens either
            )
        tokens_obj.padded = MetaData(self.pad_output)
        out.set(self.tokens_key, tokens_obj)

        if self.return_log_probs:
            log_probs_obj = LogProbs._from_tensordict(out.empty())
            with log_probs_obj.view(-1) as log_probs_obj_flat:
                if self.pad_output:
                    log_probs_padded = outputs.get(
                        "logprobs",
                        as_padded_tensor=self.pad_output,
                        padding_value=self.padding_value,
                        padding_side="right",
                    )
                    self._check_padded(log_probs_padded)
                    log_probs_obj_flat.response = log_probs_padded
                    log_probs_obj_flat.full = log_probs_padded
                else:
                    log_probs_list = outputs.get(
                        "logprobs",
                        as_list=True,
                    )
                    self._check_not_padded(log_probs_list)
                    log_probs_obj_flat.response = log_probs_list
                    log_probs_obj_flat.full = log_probs_list
                log_probs_obj_flat.prompt = None
            log_probs_obj.padded = MetaData(self.pad_output)
            out.set(self.log_probs_key, log_probs_obj)

        return out

    def _logprobs_from_text(
        self,
        text: str | list[str] | NonTensorStack,
        sampling_params: Any,
        out: TensorDictBase,
    ) -> TensorDictBase:
        """Compute log-probs from text input."""
        # Convert text to list format
        if isinstance(text, str):
            text = [text]
        elif not isinstance(text, list):
            text = text.tolist()

        assert isinstance(
            text, (str, list)
        ), f"text must be str or list, got {type(text)}"
        assert isinstance(
            sampling_params, SamplingParams
        ), f"sampling_params must be SamplingParams, got {type(sampling_params)}"
        assert isinstance(
            out, TensorDictBase
        ), f"out must be TensorDictBase, got {type(out)}"

        # Tokenize the text
        if self.tokenizer is None:
            raise ValueError(
                "Tokenizer is required for log-probs computation with text input"
            )

        # Tokenize the text
        tokenized_output = self.tokenizer(text, **self.tokenizer_kwargs)
        if self.pad_output:
            tokens_full_padded = tokenized_output["input_ids"]
            attention_mask_full_padded = tokenized_output["attention_mask"]
            tokens_full_list = self._to_list(
                tokens_full_padded, attention_mask_full_padded
            )
        else:
            tokens_full_unpadded = tokenized_output["input_ids"]
            tokens_full_list = self._to_list(tokens_full_unpadded, None)
            attention_mask_full_unpadded = tokenized_output["attention_mask"]
            attention_mask_full_unpadded = [
                am.bool()
                if isinstance(am, torch.Tensor)
                else torch.tensor(am, dtype=torch.bool)
                for am in attention_mask_full_unpadded
            ]

        # Convert to list format for vLLM
        generate_kwargs = {
            "sampling_params": sampling_params,
            "prompt_token_ids": tokens_full_list,
        }

        # Generate with vLLM to get prompt_logprobs
        request_output = self._call_generate(**generate_kwargs)

        request_output_tc = _RequestOutput_tc.from_request_output(request_output)

        # Extract log-probs from prompt_logprobs
        if self.pad_output:
            # For padded case, use all prompt_logprobs
            log_probs_full_padded = request_output_tc.get(
                "prompt_logprobs",
                as_padded_tensor=True,
                padding_value=0,
                padding_side="left",
            )

            # Mask out padding
            attention_mask_full_padded = tokens_full_padded != self.padding_value
            log_probs_full_padded = torch.where(
                attention_mask_full_padded, log_probs_full_padded, 0.0
            )
        else:
            # For unpadded case, extract from each sequence
            log_probs_full_unpadded = request_output_tc.get(
                "prompt_logprobs", as_list=True
            )
            self._check_not_padded(log_probs_full_unpadded)

        masks_obj = Masks._from_tensordict(
            TensorDict(batch_size=out.batch_size).to_lazystack(0)
        )
        if self.pad_output:
            self._check_padded(attention_mask_full_padded)
            masks_obj.all_attention_mask = attention_mask_full_padded.bool()
        else:
            self._check_not_padded(attention_mask_full_unpadded)
            masks_obj.all_attention_mask = attention_mask_full_unpadded
        masks_obj.padded = MetaData(self.pad_output)
        out.set(self.masks_key, masks_obj)

        # Build output TensorClass objects
        text_obj = Text._from_tensordict(
            TensorDict(batch_size=out.batch_size).to_lazystack(0)
        )
        text_obj.prompt = None
        text_obj.response = None
        text_obj.full = text
        out.set(self.text_key, text_obj)

        tokens_obj = Tokens._from_tensordict(
            TensorDict(batch_size=out.batch_size).to_lazystack(0)
        )
        if self.pad_output:
            self._check_padded(tokens_full_padded)
            tokens_obj.full = tokens_full_padded
        else:
            tokens_obj.full = tokens_full_unpadded
        tokens_obj.response = None
        tokens_obj.padded = MetaData(self.pad_output)
        out.set(self.tokens_key, tokens_obj)

        if self.return_log_probs:
            log_probs_obj = LogProbs._from_tensordict(
                TensorDict(batch_size=out.batch_size).to_lazystack(0)
            )
            if self.pad_output:
                self._check_padded(log_probs_full_padded)
                log_probs_obj.full = log_probs_full_padded
            else:
                self._check_not_padded(log_probs_full_unpadded)
                log_probs_obj.full = log_probs_full_unpadded
            log_probs_obj.response = None
            log_probs_obj.padded = MetaData(self.pad_output)
            out.set(self.log_probs_key, log_probs_obj)

        return out

    def _cat_tensors(
        self,
        tokens: list[torch.Tensor] | torch.Tensor,
        response_tokens: list[torch.Tensor] | torch.Tensor,
    ) -> list[torch.Tensor] | torch.Tensor:
        """Concatenate tokens and response tokens."""
        if isinstance(tokens, list) or isinstance(response_tokens, list):
            return [
                self._cat_tensors(t, t_)
                for t, t_ in _zip_strict(tokens, response_tokens)
            ]
        else:
            return torch.cat([tokens, response_tokens], dim=-1)

    def _generate_from_tokens(
        self,
        tokens_prompt_unpadded: list[torch.Tensor] | None,
        tokens_prompt_padded: torch.Tensor | None,
        sampling_params: Any,
        out: TensorDictBase,
    ) -> TensorDictBase:
        """Generate text from tokens input."""
        assert isinstance(
            tokens_prompt_padded, (torch.Tensor, type(None))
        ), f"tokens_prompt_padded must be torch.Tensor or None, got {type(tokens_prompt_padded)}"
        assert isinstance(
            tokens_prompt_unpadded, (list, type(None))
        ), f"tokens_prompt_unpadded must be list or None, got {type(tokens_prompt_unpadded)}"
        assert isinstance(
            sampling_params, SamplingParams
        ), f"sampling_params must be SamplingParams, got {type(sampling_params)}"
        assert isinstance(
            out, TensorDictBase
        ), f"out must be TensorDictBase, got {type(out)}"

        generate_kwargs = {"sampling_params": sampling_params}
        args = ()
        empirical_attention_mask = None

        if tokens_prompt_unpadded is None:
            # TODO: To be on the safe side, we may do this even in the unpadded case since we're not sure
            # the user passed an unpadded tensor in the first place.
            empirical_attention_mask = tokens_prompt_padded != self.padding_value
            tokens_prompt_list = self._to_list(
                tokens_prompt_padded, empirical_attention_mask
            )
        else:
            tokens_prompt_list = self._to_list(tokens_prompt_unpadded, None)
        generate_kwargs.update({"prompt_token_ids": tokens_prompt_list})

        # Call generate based on model type
        request_output = self._call_generate(*args, **generate_kwargs)

        request_output_tc = _RequestOutput_tc.from_request_output(request_output)

        # Extract response tokens and text
        outputs = (
            request_output_tc.outputs.view(-1)
            if self.num_samples is not None
            else request_output_tc.outputs
        )
        if self.pad_output:
            tokens_response_padded = outputs.get(
                "token_ids",
                as_padded_tensor=self.pad_output,
                padding_value=self.padding_value,
                padding_side="right",
            )
            self._check_padded(tokens_response_padded)
        tokens_response_unpadded = outputs.get(
            "token_ids",
            as_list=True,
        )
        self._check_not_padded(tokens_response_unpadded)

        tokens_obj = Tokens._from_tensordict(out.empty())
        if self.pad_output:
            self._check_padded(tokens_response_padded)
            self._check_padded(tokens_prompt_padded)
        else:
            self._check_not_padded(tokens_response_unpadded)
            self._check_not_padded(tokens_prompt_unpadded)

        if self.num_samples is not None:
            # replicate tokens
            for i in range(self.num_samples):
                tokens_obj[:, i].prompt = (
                    tokens_prompt_unpadded
                    if not self.pad_output
                    else tokens_prompt_padded
                )
        else:
            tokens_obj.prompt = (
                tokens_prompt_unpadded if not self.pad_output else tokens_prompt_padded
            )
        with tokens_obj.view(-1) as tokens_obj_flat:
            if self.pad_output:
                tokens_obj_flat.response = tokens_response_padded
                tokens_full_padded = self._cat_tensors(
                    tokens_obj_flat.prompt, tokens_response_padded
                )
                tokens_obj_flat.full = tokens_full_padded
            else:
                tokens_obj_flat.response = tokens_response_unpadded
                tokens_full_unpadded = self._cat_tensors(
                    tokens_obj_flat.get("prompt", as_list=True),
                    tokens_response_unpadded,
                )
                tokens_obj_flat.full = tokens_full_unpadded
        tokens_obj.padded = MetaData(self.pad_output)
        out.set(self.tokens_key, tokens_obj)

        masks_obj = Masks._from_tensordict(out.empty())
        # self.return_tokens must be True
        if self.pad_output:
            # Get "real" attention masks
            full_attention_mask_padded = tokens_obj.get("full") != self.padding_value
            masks_obj.all_attention_mask = full_attention_mask_padded.bool()
        else:
            # Get "real" attention masks
            # We can use select to avoid batch-size problems
            _td = torch.ones_like(
                out.select(("tokens", "full"))
                .copy()
                .rename_key_(("tokens", "full"), "all_attention_mask")
            ).bool()
            del _td["tokens"]
            masks_obj.update(_td)
        masks_obj.all_assistant_mask = None
        masks_obj.padded = MetaData(self.pad_output)
        out.set(self.masks_key, masks_obj)

        if self.return_log_probs:
            if self.pad_output:
                log_probs_padded = outputs.get(
                    "logprobs",
                    as_padded_tensor=self.pad_output,
                    padding_value=self.padding_value,
                    padding_side="right",
                )
            else:
                log_probs_list = outputs.get(
                    "logprobs",
                    as_list=True,
                )
                self._check_not_padded(log_probs_list)
            if self.num_samples is None:
                # TODO: this is not correct, we should use the prompt_logprobs
                # but they're not returned by vLLM
                if self.pad_output:
                    prompt_logprobs_padded = request_output_tc.get(
                        "prompt_logprobs",
                        as_padded_tensor=self.pad_output,
                        padding_value=self.padding_value,
                        padding_side="right",
                    )
                    if (
                        prompt_logprobs_padded.shape[-1]
                        != tokens_prompt_padded.shape[-1]
                    ):
                        tshape = tokens_prompt_padded.shape
                        oshape = prompt_logprobs_padded.shape
                        # it could be that the input was padded already - padding again then
                        prompt_logprobs_padded = torch.cat(
                            [
                                prompt_logprobs_padded.new_zeros(
                                    tshape[:-1] + (tshape[-1] - oshape[-1],)
                                ),
                                prompt_logprobs_padded,
                            ],
                            -1,
                        )
                else:
                    prompt_logprobs_list = request_output_tc.get(
                        "prompt_logprobs",
                        as_list=True,
                    )
                    self._check_not_padded(prompt_logprobs_list)
            log_probs_obj = LogProbs._from_tensordict(out.empty())
            if self.pad_output:
                self._check_padded(log_probs_padded)
                if self.num_samples is None:
                    self._check_padded(prompt_logprobs_padded)
                    log_probs_obj.prompt = prompt_logprobs_padded
            else:
                self._check_not_padded(log_probs_list)
                if self.num_samples is None:
                    self._check_not_padded(prompt_logprobs_list)
                    log_probs_obj.prompt = prompt_logprobs_list
            with log_probs_obj.view(-1) as log_probs_obj_flat:
                log_probs_obj_flat.response = (
                    log_probs_padded if self.pad_output else log_probs_list
                )
                if self.num_samples is None:
                    if self.pad_output:
                        log_probs_obj_flat.full = self._cat_tensors(
                            log_probs_obj_flat.prompt, log_probs_padded
                        )
                    else:
                        log_probs_obj_flat.full = self._cat_tensors(
                            log_probs_obj_flat.get("prompt", as_list=True),
                            log_probs_list,
                        )
                else:
                    log_probs_obj_flat.full = None
            log_probs_obj.padded = MetaData(self.pad_output)
            out.set(self.log_probs_key, log_probs_obj)
        return out

    def _logprobs_from_tokens(
        self,
        *,
        response_struct: TensorDictBase | None = None,
        tokens_full_unpadded: list[torch.Tensor] | None = None,
        tokens_full_padded: torch.Tensor | None = None,
        sampling_params: Any | None = None,
        out: TensorDictBase | None = None,
    ) -> TensorDictBase:
        """Compute log-probs from tokens input."""
        assert isinstance(
            response_struct, (TensorDictBase, type(None))
        ), f"response_struct must be TensorDictBase or None, got {type(response_struct)}"
        assert isinstance(
            tokens_full_unpadded, (list, type(None))
        ), f"tokens_full_unpadded must be list or None, got {type(tokens_full_unpadded)}"
        assert isinstance(
            tokens_full_padded, (torch.Tensor, type(None))
        ), f"tokens_full_padded must be torch.Tensor or None, got {type(tokens_full_padded)}"
        assert isinstance(
            sampling_params, (SamplingParams, type(None))
        ), f"sampling_params must be SamplingParams or None, got {type(sampling_params)}"
        assert isinstance(
            out, (TensorDictBase, type(None))
        ), f"out must be TensorDictBase or None, got {type(out)}"

        # Convert to list format for vLLM
        if response_struct is not None:
            tokens_full_padded = response_struct.get(
                "input_ids",
                as_padded_tensor=True,
                padding_value=self.padding_value,
                padding_side="left",
            )
            attention_mask_full_padded = response_struct.get(
                "attention_mask",
                as_padded_tensor=True,
                padding_value=False,
                padding_side="left",
            ).bool()
            attention_mask_full_unpadded = _unpad_tensors(
                attention_mask_full_padded, attention_mask_full_padded, as_nested=False
            )
        elif tokens_full_unpadded is not None:
            tokens_full_padded = pad_sequence(
                tokens_full_unpadded,
                padding_value=self.padding_value,
                batch_first=True,
                padding_side="left",
            )
            attention_mask_full_unpadded = [
                t != self.padding_value for t in tokens_full_unpadded
            ]
            attention_mask_full_padded = pad_sequence(
                attention_mask_full_unpadded,
                padding_value=False,
                batch_first=True,
                padding_side="left",
            )
        elif tokens_full_padded is not None:
            attention_mask_full_padded = tokens_full_padded != self.padding_value
        else:
            raise ValueError("Either response_struct or tokens must be provided")

        assert isinstance(tokens_full_padded, torch.Tensor)
        assert isinstance(attention_mask_full_padded, torch.Tensor)
        if tokens_full_unpadded is None:
            tokens_full_list = self._to_list(
                tokens_full_padded, attention_mask_full_padded
            )
        else:
            tokens_full_list = self._to_list(tokens_full_unpadded, None)

        generate_kwargs = {
            "sampling_params": sampling_params,
            "prompt_token_ids": tokens_full_list,
        }

        # Generate with vLLM to get prompt_logprobs
        tokens_out_stuct = self._call_generate(**generate_kwargs)

        request_output_tc = _RequestOutput_tc.from_request_output(tokens_out_stuct)

        # For unpadded case, extract from each sequence
        log_probs_full_unpadded = request_output_tc.get("prompt_logprobs", as_list=True)

        # Extract log-probs from prompt_logprobs
        if self.pad_output:
            # For padded case, use all prompt_logprobs
            if attention_mask_full_padded is not None:
                attention_mask_full_padded = tokens_full_padded != self.padding_value
            log_probs_full_padded = torch.zeros_like(
                tokens_full_padded, dtype=torch.get_default_dtype()
            )
            log_probs_full_padded[attention_mask_full_padded] = torch.cat(
                log_probs_full_unpadded, -1
            )
        else:
            self._check_not_padded(log_probs_full_unpadded)

        assistant_mask_full_padded = None
        if response_struct is not None:
            assistant_mask_full_padded = response_struct.get(
                "assistant_masks",
                as_padded_tensor=True,
                padding_side="left",
                padding_value=0,
            )
        if assistant_mask_full_padded is not None:
            assistant_mask_full_padded = assistant_mask_full_padded.bool()
            if not self.pad_output:
                assistant_mask_full_unpadded = _unpad_tensors(
                    assistant_mask_full_padded,
                    attention_mask_full_padded,
                    as_nested=False,
                )
            else:
                assistant_mask_full_unpadded = None
        else:
            assistant_mask_full_unpadded = None

        masks_obj = Masks._from_tensordict(
            TensorDict(batch_size=out.batch_size).to_lazystack(0)
        )
        if self.pad_output:
            self._check_padded(attention_mask_full_padded)
            masks_obj.all_attention_mask = attention_mask_full_padded.bool()
            if assistant_mask_full_padded is not None:
                masks_obj.all_assistant_mask = assistant_mask_full_padded
        else:
            self._check_not_padded(attention_mask_full_unpadded)
            masks_obj.all_attention_mask = attention_mask_full_unpadded
            if assistant_mask_full_unpadded is not None:
                masks_obj.all_assistant_mask = assistant_mask_full_unpadded
        masks_obj.padded = MetaData(self.pad_output)
        out.set(self.masks_key, masks_obj)

        tokens_obj = Tokens._from_tensordict(
            TensorDict(batch_size=out.batch_size).to_lazystack(0)
        )
        if self.pad_output:
            self._check_padded(tokens_full_padded)
            tokens_obj.full = tokens_full_padded
        else:
            tokens_obj.full = tokens_full_unpadded
        tokens_obj.response = None
        tokens_obj.padded = MetaData(self.pad_output)
        out.set(self.tokens_key, tokens_obj)

        log_probs_obj = LogProbs._from_tensordict(
            TensorDict(batch_size=out.batch_size).to_lazystack(0)
        )
        if self.pad_output:
            self._check_padded(log_probs_full_padded)
            log_probs_obj.full = log_probs_full_padded
        else:
            self._check_not_padded(log_probs_full_unpadded)
            log_probs_obj.full = log_probs_full_unpadded
        log_probs_obj.response = None
        log_probs_obj.padded = MetaData(self.pad_output)
        out.set(self.log_probs_key, log_probs_obj)

        return out

    def _to_list(
        self,
        tokens_padded: torch.Tensor | list[torch.Tensor],
        attention_mask_padded: torch.Tensor | None,
    ) -> list[list[int]]:
        """Converts a tensor of integers into a masked list (of lists) of integers."""
        if isinstance(tokens_padded, torch.Tensor):
            parent = []
            queue = collections.deque()
            if attention_mask_padded is None:
                attention_mask_padded = torch.ones_like(tokens_padded)
            queue.append((tokens_padded, attention_mask_padded.bool(), parent))
            while queue:
                token_tensor, attention_mask_bool, _parent = queue.popleft()
                if token_tensor.ndim == 1:
                    _parent.extend(token_tensor[attention_mask_bool].tolist())
                else:
                    _parent.extend([[] for _ in range(token_tensor.shape[0])])
                    queue.extend(
                        [
                            (t, m, local_parent)
                            for t, m, local_parent in zip(
                                token_tensor, attention_mask_bool, _parent
                            )
                        ]
                    )
            tokens_list = parent
        elif isinstance(tokens_padded, list):
            parent = []
            queue = collections.deque()
            queue.append((tokens_padded, parent))
            while queue:
                tokens_list, _parent = queue.popleft()
                if isinstance(tokens_list, list) and isinstance(
                    tokens_list[0], (list, torch.Tensor)
                ):
                    _parent.extend([[] for _ in tokens_list])
                    queue.extend(
                        [
                            (t, local_parent)
                            for t, local_parent in zip(tokens_list, _parent)
                        ]
                    )
                    continue
                elif isinstance(tokens_list, torch.Tensor):
                    tokens_list = tokens_list.tolist()
                _parent.extend(tokens_list)
            tokens_list = parent

        return tokens_list

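    # Illustrative sketch (assumption): _to_list drops masked-out (padding) positions while
    # converting to nested Python lists, e.g.
    #
    #     tokens = torch.tensor([[5, 6, 0], [7, 0, 0]])
    #     mask = tokens != 0
    #     self._to_list(tokens, mask)  # -> [[5, 6], [7]]
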
    @_classproperty
    def CompletionOutput_tc(cls):
        _vllm = _require_vllm()

        if hasattr(cls, "_CompletionOutput_tc"):
            return cls._CompletionOutput_tc
        CompletionOutput_tc = from_dataclass(_vllm.outputs.CompletionOutput)  # type: ignore
        cls._CompletionOutput_tc = CompletionOutput_tc
        return CompletionOutput_tc

def get_dist(
|
|
1999
|
+
self,
|
|
2000
|
+
tensordict: TensorDictBase,
|
|
2001
|
+
tensordict_out: TensorDictBase | None = None,
|
|
2002
|
+
logits_key: NestedKey = "logits",
|
|
2003
|
+
mask_key: NestedKey | None = None,
|
|
2004
|
+
as_padded_tensor: bool | None = None,
|
|
2005
|
+
as_nested_tensor: bool | None = None,
|
|
2006
|
+
padding_value: float | None = None,
|
|
2007
|
+
+        padding_side: str = "right",
+        layout: torch.layout | None = None,
+        **kwargs,
+    ) -> D.Distribution:
+        """Get distribution from logits/log-probs with optional masking.
+
+        vLLM does not return logits, so this method is not supported.
+        """
+        raise NotImplementedError(
+            "vLLM does not return logits, so get_dist is not supported"
+        )
+
+    def get_dist_with_prompt_mask(
+        self,
+        tensordict: TensorDictBase,
+        tokens_key: NestedKey = ("tokens", "full"),
+        logits_key: NestedKey = "logits",
+        assistant_mask_key: NestedKey = ("masks", "all_assistant_mask"),
+        attention_mask_key: NestedKey = ("masks", "all_attention_mask"),
+        **kwargs,
+    ) -> D.Distribution:
+        """Get distribution masked to only include response tokens (exclude prompt).
+
+        vLLM does not return logits, so this method is not supported.
+
+        This is a provisional method that will be replaced by the `get_dist` method once we have a better masking strategy.
+        """
+        raise NotImplementedError(
+            "vLLM does not return logits, so get_dist_with_prompt_mask is not supported"
+        )
+
+    def _get_dist_with_assistant_mask(
+        self,
+        tensordict: TensorDictBase,
+        assistant_mask_key: NestedKey = ("masks", "all_assistant_mask"),
+        logits_key: NestedKey = "logits",
+        **kwargs,
+    ) -> D.Distribution:
+        """Get distribution masked to only include assistant tokens.
+
+        vLLM does not return logits, so this method is not supported.
+
+        This is a provisional method that will be replaced by the `get_dist` method once we have a better masking strategy.
+        """
+        raise NotImplementedError(
+            "vLLM does not return logits, so get_dist_with_assistant_mask is not supported"
+        )
+
+    def _get_dist_with_attention_mask(
+        self,
+        tensordict: TensorDictBase,
+        attention_mask_key: NestedKey = ("masks", "all_attention_mask"),
+        logits_key: NestedKey = "logits",
+        **kwargs,
+    ) -> D.Distribution:
+        """Get distribution masked using attention mask.
+
+        vLLM does not return logits, so this method is not supported.
+
+        This is a provisional method that will be replaced by the `get_dist` method once we have a better masking strategy.
+        """
+        raise NotImplementedError(
+            "vLLM does not return logits, so get_dist_with_attention_mask is not supported"
+        )
+
+    def _get_dist_with_custom_mask(
+        self,
+        tensordict: TensorDictBase,
+        mask: torch.Tensor,
+        logits_key: NestedKey = "logits",
+        **kwargs,
+    ) -> D.Distribution:
+        """Get distribution with custom mask.
+
+        vLLM does not return logits, so this method is not supported.
+
+        This is a provisional method that will be replaced by the `get_dist` method once we have a better masking strategy.
+        """
+        raise NotImplementedError(
+            "vLLM does not return logits, so get_dist_with_custom_mask is not supported"
+        )
+
+    # Convenience methods for common LLM training scenarios
+    def _get_sft_dist(self, tensordict: TensorDictBase, **kwargs) -> D.Distribution:
+        """Get distribution suitable for SFT loss (response tokens only).
+
+        vLLM does not return logits, so this method is not supported.
+
+        This is a provisional method that will be replaced by the `get_dist` method once we have a better masking strategy.
+        """
+        raise NotImplementedError(
+            "vLLM does not return logits, so get_sft_dist is not supported"
+        )
+
+    def _get_rlhf_dist(self, tensordict: TensorDictBase, **kwargs) -> D.Distribution:
+        """Get distribution suitable for RLHF loss (assistant tokens only).
+
+        vLLM does not return logits, so this method is not supported.
+
+        This is a provisional method that will be replaced by the `get_dist` method once we have a better masking strategy.
+        """
+        raise NotImplementedError(
+            "vLLM does not return logits, so get_rlhf_dist is not supported"
+        )
+
+    def _get_generic_dist(self, tensordict: TensorDictBase, **kwargs) -> D.Distribution:
+        """Get distribution suitable for generic losses (all tokens).
+
+        vLLM does not return logits, so this method is not supported.
+
+        This is a provisional method that will be replaced by the `get_dist` method once we have a better masking strategy.
+        """
+        raise NotImplementedError(
+            "vLLM does not return logits, so get_generic_dist is not supported"
+        )
+
+
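All of the distribution getters above are hard stubs: the vLLM backend never exposes per-token logits, so each one raises `NotImplementedError`. A minimal sketch of how calling code might guard against this is shown below; the `wrapper` object and the `("tokens", "full")` / `("log_probs", "full")` keys are illustrative assumptions, not guarantees made by this diff.

```python
# Hedged sketch: catch the NotImplementedError raised by the stubs above and
# fall back to log-probs recorded at generation time. The key names used for
# the fallback are assumptions for illustration only.
import torch
from tensordict import TensorDict


def log_probs_or_dist(wrapper, td: TensorDict) -> torch.Tensor:
    """Prefer the distribution API; fall back to stored log-probs."""
    try:
        dist = wrapper.get_dist(td)  # raises on vLLM-backed wrappers
        return dist.log_prob(td["tokens", "full"])
    except NotImplementedError:
        # vLLM does not return logits; reuse log-probs emitted during generation.
        return td["log_probs", "full"]
```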
+class _RequestOutput_tc(TensorClass["nocast"]):
+    """TensorClass wrapper for vLLM RequestOutput."""
+
+    request_id: str
+    prompt: str
+    prompt_token_ids: torch.Tensor
+    prompt_logprobs: torch.Tensor
+    outputs: Any
+    finished: str
+    metrics: str
+    lora_request: str
+    encoder_prompt: str
+    encoder_prompt_token_ids: str
+    num_cached_tokens: torch.Tensor
+
+    def __post_init__(self):
+        CompletionOutput_tc = vLLMWrapper.CompletionOutput_tc
+
+        def postproc(output):
+            def get_logprob(output):
+                t = []
+                token_ids = output.token_ids
+                if isinstance(token_ids, torch.Tensor):
+                    token_ids = token_ids.tolist()
+                for v, tid in zip(output.logprobs, token_ids):
+                    t.append(
+                        v[tid]["logprob"] if v[tid].get("logprob") is not None else 0.0
+                    )
+                return torch.tensor(t)
+
+            if output.logprobs:
+                output.logprobs = get_logprob(output)
+            output.token_ids = torch.as_tensor(output.token_ids)
+            return output
+
+        if isinstance(self.outputs, list):
+            outputs = self.outputs
+            outputs = [
+                postproc(from_dataclass(output, dest_cls=CompletionOutput_tc))
+                for output in outputs
+            ]
+            if len(outputs) == 1:
+                self.outputs = outputs[0]
+            else:
+                # Check if we can stack the outputs (they should have the same shape)
+                try:
+                    self.outputs = lazy_stack(outputs)
+                except RuntimeError:
+                    # If stacking fails (different sizes), keep as list
+                    self.outputs = outputs
+
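The `get_logprob` helper inside `__post_init__` flattens vLLM's per-step logprob records into one tensor aligned with the sampled token ids, substituting `0.0` when a value is missing. The standalone sketch below reproduces that flattening on mocked data; the dict-shaped records are an assumption made purely for illustration.

```python
# Standalone sketch of the flattening done by get_logprob above: each decoding
# step maps the sampled token id to a record, and only that token's logprob is
# kept (0.0 when missing). The mocked records are illustrative assumptions.
import torch

logprobs = [
    {17: {"logprob": -0.25}},  # step 0: token 17 sampled
    {42: {"logprob": -1.10}},  # step 1: token 42 sampled
    {7: {"logprob": None}},    # step 2: missing value falls back to 0.0
]
token_ids = [17, 42, 7]

flat = torch.tensor(
    [
        rec[tid]["logprob"] if rec[tid].get("logprob") is not None else 0.0
        for rec, tid in zip(logprobs, token_ids)
    ]
)
print(flat)  # tensor([-0.2500, -1.1000,  0.0000])
```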
+    @classmethod
+    def from_request_output(
+        cls, requests: RequestOutput | list[RequestOutput]
+    ) -> _RequestOutput_tc | list[_RequestOutput_tc]:
+        """Create _RequestOutput_tc from vLLM RequestOutput."""
+        # Type assertions
+        assert isinstance(
+            requests, (RequestOutput, list)
+        ), f"requests must be RequestOutput or list, got {type(requests)}"
+
+        # Check if we can stack the outputs
+        try:
+            out = lazy_stack(
+                [
+                    cls(
+                        request_id=request.request_id,
+                        prompt=request.prompt,
+                        prompt_token_ids=torch.as_tensor(request.prompt_token_ids),
+                        prompt_logprobs=torch.tensor(
+                            [
+                                v[int(tid)].logprob if v is not None else 0.0
+                                for v, tid in _zip_strict(
+                                    request.prompt_logprobs, request.prompt_token_ids
+                                )
+                            ]
+                        )
+                        if request.prompt_logprobs is not None
+                        else torch.tensor([]),
+                        outputs=request.outputs,
+                        finished=request.finished,
+                        metrics=request.metrics,
+                        lora_request=request.lora_request,
+                        encoder_prompt=request.encoder_prompt,
+                        encoder_prompt_token_ids=request.encoder_prompt_token_ids,
+                        num_cached_tokens=torch.as_tensor(request.num_cached_tokens),
+                    )
+                    for request in requests
+                ]
+            )
+            return out
+        except RuntimeError:
+            # If stacking fails, return a list of individual _RequestOutput_tc objects
+            return [
+                cls(
+                    request_id=request.request_id,
+                    prompt=request.prompt,
+                    prompt_token_ids=torch.as_tensor(request.prompt_token_ids),
+                    prompt_logprobs=torch.tensor(
+                        [
+                            v[int(tid)].logprob if v is not None else 0.0
+                            for v, tid in _zip_strict(
+                                request.prompt_logprobs, request.prompt_token_ids
+                            )
+                        ]
+                    )
+                    if request.prompt_logprobs is not None
+                    else torch.tensor([]),
+                    outputs=request.outputs,
+                    finished=request.finished,
+                    metrics=request.metrics,
+                    lora_request=request.lora_request,
+                    encoder_prompt=request.encoder_prompt,
+                    encoder_prompt_token_ids=request.encoder_prompt_token_ids,
+                    num_cached_tokens=torch.as_tensor(request.num_cached_tokens),
+                )
+                for request in requests
+            ]
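`from_request_output` is the entry point that converts raw vLLM `RequestOutput` objects into tensorclass instances, stacking them with `lazy_stack` when their shapes line up and returning a plain list otherwise. Below is a hedged usage sketch: the model name, sampling settings, and the import path of the private `_RequestOutput_tc` class are assumptions, not something this diff states.

```python
# Hedged usage sketch, assuming a working vLLM install. The model name,
# sampling settings, and the import path of the private class are assumptions.
from vllm import LLM, SamplingParams

# Assumed location of the class shown in this diff; adjust to your install.
from torchrl.modules.llm.policies.vllm_wrapper import _RequestOutput_tc

llm = LLM(model="Qwen/Qwen2.5-0.5B-Instruct")
params = SamplingParams(max_tokens=16, logprobs=1, prompt_logprobs=1)
requests = llm.generate(["Hello, world!"], params)  # list[RequestOutput]

# Wrap the raw outputs; lazy_stack is attempted first, with a list fallback.
wrapped = _RequestOutput_tc.from_request_output(requests)
print(wrapped)
print(wrapped.prompt_token_ids)
```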