tpu-inference 0.11.1.dev202511150811__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tpu-inference might be problematic.
- tests/__init__.py +0 -0
- tests/core/__init__.py +0 -0
- tests/core/test_core_tpu.py +513 -0
- tests/core/test_disagg_executor.py +60 -0
- tests/core/test_disagg_utils.py +53 -0
- tests/core/test_dp_scheduler.py +899 -0
- tests/core/test_init.py +49 -0
- tests/kernels/__init__.py +0 -0
- tests/kernels/fused_moe_v1_test.py +105 -0
- tests/kernels/mla_v1_test.py +396 -0
- tests/kernels/quantized_matmul_kernel_test.py +191 -0
- tests/kernels/ragged_kv_cache_update_v2_test.py +234 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +400 -0
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +549 -0
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +504 -0
- tests/lora/__init__.py +0 -0
- tests/lora/conftest.py +32 -0
- tests/lora/test_bgmv.py +43 -0
- tests/lora/test_layers.py +654 -0
- tests/lora/test_lora.py +133 -0
- tests/lora/utils.py +96 -0
- tests/test_base.py +201 -0
- tests/test_envs.py +182 -0
- tests/test_quantization.py +836 -0
- tests/test_tpu_info.py +120 -0
- tests/test_utils.py +236 -0
- tpu_inference/__init__.py +34 -0
- tpu_inference/core/__init__.py +0 -0
- tpu_inference/core/core_tpu.py +786 -0
- tpu_inference/core/disagg_executor.py +118 -0
- tpu_inference/core/disagg_utils.py +51 -0
- tpu_inference/core/sched/__init__.py +0 -0
- tpu_inference/core/sched/dp_scheduler.py +523 -0
- tpu_inference/distributed/__init__.py +0 -0
- tpu_inference/distributed/jax_parallel_state.py +67 -0
- tpu_inference/distributed/tpu_connector.py +728 -0
- tpu_inference/distributed/utils.py +59 -0
- tpu_inference/env_override.py +9 -0
- tpu_inference/envs.py +107 -0
- tpu_inference/executors/__init__.py +0 -0
- tpu_inference/executors/ray_distributed_executor.py +362 -0
- tpu_inference/experimental/__init__.py +0 -0
- tpu_inference/experimental/llama3_jax_stashed.py +258 -0
- tpu_inference/kernels/__init__.py +0 -0
- tpu_inference/kernels/collectives/__init__.py +0 -0
- tpu_inference/kernels/collectives/all_gather_matmul.py +735 -0
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +60 -0
- tpu_inference/kernels/collectives/util.py +47 -0
- tpu_inference/kernels/flash_attention/__init__.py +0 -0
- tpu_inference/kernels/flash_attention/kernel.py +772 -0
- tpu_inference/kernels/fused_moe/__init__.py +0 -0
- tpu_inference/kernels/fused_moe/v1/__init__.py +0 -0
- tpu_inference/kernels/fused_moe/v1/kernel.py +1035 -0
- tpu_inference/kernels/mla/__init__.py +0 -0
- tpu_inference/kernels/mla/v1/__init__.py +0 -0
- tpu_inference/kernels/mla/v1/kernel.py +1349 -0
- tpu_inference/kernels/quantized_matmul/__init__.py +0 -0
- tpu_inference/kernels/quantized_matmul/kernel.py +395 -0
- tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
- tpu_inference/kernels/quantized_matmul/util.py +58 -0
- tpu_inference/kernels/ragged_paged_attention/__init__.py +0 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +0 -0
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +875 -0
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +287 -0
- tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +0 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1478 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1482 -0
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4147 -0
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +367 -0
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +51 -0
- tpu_inference/layers/__init__.py +0 -0
- tpu_inference/layers/common/__init__.py +0 -0
- tpu_inference/layers/common/attention_interface.py +390 -0
- tpu_inference/layers/common/attention_metadata.py +34 -0
- tpu_inference/layers/common/binary_search.py +295 -0
- tpu_inference/layers/common/quant_methods.py +8 -0
- tpu_inference/layers/common/sharding.py +582 -0
- tpu_inference/layers/jax/__init__.py +0 -0
- tpu_inference/layers/jax/attention/__init__.py +0 -0
- tpu_inference/layers/jax/attention/attention.py +255 -0
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +354 -0
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +262 -0
- tpu_inference/layers/jax/attention/llama4_attention.py +153 -0
- tpu_inference/layers/jax/base.py +151 -0
- tpu_inference/layers/jax/constants.py +88 -0
- tpu_inference/layers/jax/layers.py +301 -0
- tpu_inference/layers/jax/misc.py +16 -0
- tpu_inference/layers/jax/moe/__init__.py +0 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +608 -0
- tpu_inference/layers/jax/moe/gpt_oss_moe.py +185 -0
- tpu_inference/layers/jax/moe/moe.py +209 -0
- tpu_inference/layers/jax/rope.py +280 -0
- tpu_inference/layers/jax/rope_interface.py +214 -0
- tpu_inference/layers/jax/sample/__init__.py +0 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +515 -0
- tpu_inference/layers/jax/sample/sampling.py +96 -0
- tpu_inference/layers/jax/sample/sampling_metadata.py +76 -0
- tpu_inference/layers/jax/transformer_block.py +107 -0
- tpu_inference/layers/vllm/__init__.py +0 -0
- tpu_inference/layers/vllm/attention.py +221 -0
- tpu_inference/layers/vllm/fused_moe.py +507 -0
- tpu_inference/layers/vllm/linear_common.py +186 -0
- tpu_inference/layers/vllm/quantization/__init__.py +39 -0
- tpu_inference/layers/vllm/quantization/awq.py +207 -0
- tpu_inference/layers/vllm/quantization/common.py +105 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +0 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +120 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +203 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +0 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +208 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +136 -0
- tpu_inference/layers/vllm/quantization/mxfp4.py +266 -0
- tpu_inference/layers/vllm/quantization/unquantized.py +386 -0
- tpu_inference/layers/vllm/sharding.py +230 -0
- tpu_inference/logger.py +10 -0
- tpu_inference/lora/__init__.py +0 -0
- tpu_inference/lora/torch_lora_ops.py +103 -0
- tpu_inference/lora/torch_punica_tpu.py +311 -0
- tpu_inference/mock/__init__.py +0 -0
- tpu_inference/mock/vllm_config_utils.py +28 -0
- tpu_inference/mock/vllm_envs.py +1219 -0
- tpu_inference/mock/vllm_logger.py +212 -0
- tpu_inference/mock/vllm_logging_utils.py +15 -0
- tpu_inference/models/__init__.py +0 -0
- tpu_inference/models/common/__init__.py +0 -0
- tpu_inference/models/common/model_loader.py +444 -0
- tpu_inference/models/jax/__init__.py +0 -0
- tpu_inference/models/jax/deepseek_v3.py +868 -0
- tpu_inference/models/jax/gpt_oss.py +492 -0
- tpu_inference/models/jax/jax_intermediate_tensor.py +79 -0
- tpu_inference/models/jax/llama3.py +375 -0
- tpu_inference/models/jax/llama4.py +629 -0
- tpu_inference/models/jax/llama_eagle3.py +333 -0
- tpu_inference/models/jax/phi3.py +376 -0
- tpu_inference/models/jax/qwen2.py +375 -0
- tpu_inference/models/jax/qwen2_5_vl.py +1103 -0
- tpu_inference/models/jax/qwen3.py +302 -0
- tpu_inference/models/jax/utils/__init__.py +0 -0
- tpu_inference/models/jax/utils/file_utils.py +96 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +163 -0
- tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
- tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +5 -0
- tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +6 -0
- tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +5 -0
- tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +6 -0
- tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +105 -0
- tpu_inference/models/jax/utils/quantization/quantization_utils.py +653 -0
- tpu_inference/models/jax/utils/weight_utils.py +529 -0
- tpu_inference/models/vllm/__init__.py +0 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +286 -0
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +45 -0
- tpu_inference/platforms/__init__.py +2 -0
- tpu_inference/platforms/tpu_platform.py +269 -0
- tpu_inference/runner/__init__.py +0 -0
- tpu_inference/runner/block_table.py +122 -0
- tpu_inference/runner/compilation_manager.py +780 -0
- tpu_inference/runner/input_batch.py +435 -0
- tpu_inference/runner/kv_cache.py +132 -0
- tpu_inference/runner/kv_cache_manager.py +479 -0
- tpu_inference/runner/lora_utils.py +92 -0
- tpu_inference/runner/multimodal_manager.py +217 -0
- tpu_inference/runner/persistent_batch_manager.py +244 -0
- tpu_inference/runner/speculative_decoding_manager.py +248 -0
- tpu_inference/runner/structured_decoding_manager.py +88 -0
- tpu_inference/runner/tpu_runner.py +1620 -0
- tpu_inference/runner/utils.py +426 -0
- tpu_inference/spec_decode/__init__.py +0 -0
- tpu_inference/spec_decode/jax/__init__.py +0 -0
- tpu_inference/spec_decode/jax/eagle3.py +367 -0
- tpu_inference/tpu_info.py +77 -0
- tpu_inference/utils.py +317 -0
- tpu_inference/worker/__init__.py +0 -0
- tpu_inference/worker/tpu_worker.py +321 -0
- tpu_inference-0.11.1.dev202511150811.dist-info/METADATA +107 -0
- tpu_inference-0.11.1.dev202511150811.dist-info/RECORD +179 -0
- tpu_inference-0.11.1.dev202511150811.dist-info/WHEEL +5 -0
- tpu_inference-0.11.1.dev202511150811.dist-info/licenses/LICENSE +201 -0
- tpu_inference-0.11.1.dev202511150811.dist-info/top_level.txt +2 -0
@@ -0,0 +1,1620 @@
|
|
|
1
|
+
import copy
|
|
2
|
+
import functools
|
|
3
|
+
import os
|
|
4
|
+
import random
|
|
5
|
+
from contextlib import nullcontext
|
|
6
|
+
from dataclasses import dataclass
|
|
7
|
+
from typing import Any, Callable, Dict, List, Optional, Tuple, cast
|
|
8
|
+
|
|
9
|
+
import jax
|
|
10
|
+
import jax.numpy as jnp
|
|
11
|
+
import jaxtyping
|
|
12
|
+
import numpy as np
|
|
13
|
+
import torch
|
|
14
|
+
import vllm.envs as envs
|
|
15
|
+
from flax import nnx
|
|
16
|
+
from jax.experimental import mesh_utils
|
|
17
|
+
from jax.sharding import NamedSharding, PartitionSpec
|
|
18
|
+
from torchax.ops.mappings import j2t_dtype
|
|
19
|
+
from vllm.config import VllmConfig
|
|
20
|
+
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
|
|
21
|
+
has_kv_transfer_group)
|
|
22
|
+
from vllm.forward_context import set_forward_context
|
|
23
|
+
from vllm.sequence import IntermediateTensors
|
|
24
|
+
from vllm.tasks import SupportedTask
|
|
25
|
+
from vllm.utils.math_utils import cdiv
|
|
26
|
+
from vllm.v1.core.sched.output import GrammarOutput
|
|
27
|
+
from vllm.v1.core.sched.output import SchedulerOutput as VllmSchedulerOutput
|
|
28
|
+
from vllm.v1.kv_cache_interface import KVCacheConfig
|
|
29
|
+
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput,
|
|
30
|
+
DraftTokenIds, KVConnectorOutput, LogprobsLists,
|
|
31
|
+
ModelRunnerOutput)
|
|
32
|
+
from vllm.v1.request import Request
|
|
33
|
+
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
|
|
34
|
+
from vllm.v1.worker.kv_connector_model_runner_mixin import \
|
|
35
|
+
KVConnectorModelRunnerMixin
|
|
36
|
+
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
|
|
37
|
+
|
|
38
|
+
from tpu_inference import utils as common_utils
|
|
39
|
+
from tpu_inference.layers.common.attention_metadata import AttentionMetadata
|
|
40
|
+
from tpu_inference.layers.common.sharding import (MESH_AXIS_NAMES,
|
|
41
|
+
MESH_AXIS_NAMES_2D,
|
|
42
|
+
ShardingAxisName,
|
|
43
|
+
ShardingConfigManager)
|
|
44
|
+
from tpu_inference.layers.jax.sample.rejection_sampler import RejectionSampler
|
|
45
|
+
from tpu_inference.layers.jax.sample.sampling import (compute_logprobs,
|
|
46
|
+
gather_logprobs, sample)
|
|
47
|
+
from tpu_inference.layers.jax.sample.sampling_metadata import \
|
|
48
|
+
TPUSupportedSamplingMetadata
|
|
49
|
+
from tpu_inference.logger import init_logger
|
|
50
|
+
from tpu_inference.models.common.model_loader import get_model
|
|
51
|
+
from tpu_inference.models.jax.utils.weight_utils import (
|
|
52
|
+
shard_put, transfer_state_with_mappings)
|
|
53
|
+
from tpu_inference.runner import utils as runner_utils
|
|
54
|
+
from tpu_inference.runner.compilation_manager import CompilationManager
|
|
55
|
+
from tpu_inference.runner.input_batch import CachedRequestState, InputBatch
|
|
56
|
+
from tpu_inference.runner.kv_cache_manager import KVCacheManager
|
|
57
|
+
from tpu_inference.runner.lora_utils import LoraUtils
|
|
58
|
+
from tpu_inference.runner.multimodal_manager import MultiModalManager
|
|
59
|
+
from tpu_inference.runner.persistent_batch_manager import \
|
|
60
|
+
PersistentBatchManager
|
|
61
|
+
from tpu_inference.runner.speculative_decoding_manager import (
|
|
62
|
+
SpecDecodeMetadata, SpeculativeDecodingManager)
|
|
63
|
+
from tpu_inference.runner.structured_decoding_manager import \
|
|
64
|
+
StructuredDecodingManager
|
|
65
|
+
from tpu_inference.spec_decode.jax.eagle3 import Eagle3Proposer
|
|
66
|
+
from tpu_inference.utils import (device_array, make_optimized_mesh,
|
|
67
|
+
time_function)
|
|
68
|
+
|
|
69
|
+
logger = init_logger(__name__)
|
|
70
|
+
|
|
71
|
+
INVALID_TOKEN_ID = -1
|
|
72
|
+
# Smallest output size
|
|
73
|
+
MIN_NUM_SEQS = 8
|
|
74
|
+
|
|
75
|
+
DUMMY_METADATA = AttentionMetadata(
|
|
76
|
+
input_positions=[],
|
|
77
|
+
block_tables=[],
|
|
78
|
+
request_distribution=[0, 0, 0],
|
|
79
|
+
)
|
|
80
|
+
|
|
81
|
+
TPU_STR_DTYPE_TO_TORCH_DTYPE = {
|
|
82
|
+
"half": torch.half,
|
|
83
|
+
"bfloat16": torch.bfloat16,
|
|
84
|
+
"float": torch.float,
|
|
85
|
+
"fp8": torch.float8_e4m3fn,
|
|
86
|
+
"fp8_e4m3": torch.float8_e4m3fn,
|
|
87
|
+
"fp8_e5m2": torch.float8_e5m2,
|
|
88
|
+
"int8": torch.int8,
|
|
89
|
+
"uint8": torch.uint8,
|
|
90
|
+
}
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
class AsyncTPUModelRunnerOutput(AsyncModelRunnerOutput):
|
|
94
|
+
"""Holds asynchronous model output specifically from a TPU runner.
|
|
95
|
+
|
|
96
|
+
This class acts as a wrapper around the standard ModelRunnerOutput. Its
|
|
97
|
+
primary purpose is to hold references to data still on the TPU device
|
|
98
|
+
(like the `next_tokens` JAX array) without blocking the main thread.
|
|
99
|
+
|
|
100
|
+
The `get_output()` method is called to resolve these async results,
|
|
101
|
+
triggering the JAX device-to-host (CPU) data transfer and populating
|
|
102
|
+
the final `ModelRunnerOutput` object.
|
|
103
|
+
"""
|
|
104
|
+
|
|
105
|
+
def __init__(
|
|
106
|
+
self,
|
|
107
|
+
model_runner_output: ModelRunnerOutput,
|
|
108
|
+
next_tokens: jax.Array,
|
|
109
|
+
num_reqs: int,
|
|
110
|
+
discard_sampled_tokens_req_indices: list[int],
|
|
111
|
+
logits_indices_selector: Optional[List[int]] = None,
|
|
112
|
+
):
|
|
113
|
+
self._model_runner_output = model_runner_output
|
|
114
|
+
self._next_tokens = next_tokens
|
|
115
|
+
self._num_reqs = num_reqs
|
|
116
|
+
self._discard_sampled_tokens_req_indices = discard_sampled_tokens_req_indices
|
|
117
|
+
self.logits_indices_selector: list[int] = logits_indices_selector
|
|
118
|
+
|
|
119
|
+
def get_output(self) -> ModelRunnerOutput:
|
|
120
|
+
next_tokens_cpu = np.asarray(jax.device_get(self._next_tokens))
|
|
121
|
+
if self.logits_indices_selector is not None:
|
|
122
|
+
next_tokens_cpu = next_tokens_cpu[self.logits_indices_selector]
|
|
123
|
+
selected_token_ids = np.expand_dims(next_tokens_cpu[:self._num_reqs],
|
|
124
|
+
1)
|
|
125
|
+
valid_sampled_token_ids = selected_token_ids.tolist()
|
|
126
|
+
for i in self._discard_sampled_tokens_req_indices:
|
|
127
|
+
valid_sampled_token_ids[i].clear()
|
|
128
|
+
self._model_runner_output.sampled_token_ids = valid_sampled_token_ids
|
|
129
|
+
return self._model_runner_output
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
@dataclass
|
|
133
|
+
class AsyncPreResults:
|
|
134
|
+
req_ids: list[str]
|
|
135
|
+
next_tokens: jax.Array
|
|
136
|
+
request_seq_lens: list[tuple[int, CachedRequestState, int]]
|
|
137
|
+
discard_sampled_tokens_req_indices: list[int]
|
|
138
|
+
placeholder_req_id_to_index: dict[str, int]
|
|
139
|
+
logits_indices_selector: Optional[List[int]] = None
|
|
140
|
+
|
|
141
|
+
|
|
142
|
+
@dataclass
|
|
143
|
+
class ExecuteModelState:
|
|
144
|
+
"""Ephemeral cached state transferred between execute_model() and
|
|
145
|
+
sample_tokens(), after execute_model() returns None."""
|
|
146
|
+
|
|
147
|
+
scheduler_output: "VllmSchedulerOutput"
|
|
148
|
+
attn_metadata: AttentionMetadata
|
|
149
|
+
input_ids: Optional[jax.Array]
|
|
150
|
+
hidden_states: jax.Array
|
|
151
|
+
logits: jax.Array
|
|
152
|
+
aux_hidden_states: Optional[jax.Array]
|
|
153
|
+
spec_decode_metadata: Optional[SpecDecodeMetadata]
|
|
154
|
+
kv_connector_output: Optional[KVConnectorOutput]
|
|
155
|
+
logits_indices_selector: Optional[List[int]] = None
|
|
156
|
+
|
|
157
|
+
|
|
158
|
+
@functools.partial(jax.jit, donate_argnums=(0, 1, 2))
|
|
159
|
+
def _substitute_placeholder_token(
|
|
160
|
+
input_ids: jax.Array, token_in_tpu_cur_input_indices: jax.Array,
|
|
161
|
+
token_in_tpu_pre_next_tokens_indices: jax.Array,
|
|
162
|
+
next_tokens: jax.Array, placeholder_num: int):
|
|
163
|
+
"""Substitute placeholder tokens from TPU for async scheduler
|
|
164
|
+
|
|
165
|
+
Padding for parallelisation of the substitute_placeholder_token_fn
|
|
166
|
+
[1, 3] => [1, 3, 0, 2, 4, 5, 6, 7, 8]
|
|
167
|
+
The reason for such a special padding instead of padding with -1 is:
|
|
168
|
+
An edge case when the end index needs to be updated and padding is required.
|
|
169
|
+
If we pad the array with -1, the _substitute_placeholder_token_fn will repeatedly update the end element with the original value
|
|
170
|
+
Although such a scenario is unlikely to happen in vLLM, it is best to eliminate any potential risks.
|
|
171
|
+
|
|
172
|
+
Args:
|
|
173
|
+
input_ids: possible input_ids size
|
|
174
|
+
token_in_tpu_cur_input_indices: replace holder idx in input_ids. Length the same to input_ids.
|
|
175
|
+
token_in_tpu_pre_next_tokens_indices: value idx in next_tokens. Length the same to input_ids.
|
|
176
|
+
next_tokens: next tokens on the TPU from previous step.
|
|
177
|
+
placeholder_num: number of placeholders. placeholder_num <= len(token_in_tpu_cur_input_indices)
|
|
178
|
+
Return:
|
|
179
|
+
input_ids after replace placeholder tokens
|
|
180
|
+
"""
|
|
181
|
+
assert input_ids.shape == token_in_tpu_cur_input_indices.shape == token_in_tpu_pre_next_tokens_indices.shape, \
|
|
182
|
+
f"Shape mismatch: input_ids and index arrays must have identical shapes due to precompilation assumptions. " \
|
|
183
|
+
f"Got: {input_ids.shape=}, {token_in_tpu_cur_input_indices.shape=}, {token_in_tpu_pre_next_tokens_indices.shape=}"
|
|
184
|
+
|
|
185
|
+
# updates the input_ids for all placeholders.
|
|
186
|
+
mask = jnp.arange(input_ids.shape[0]) < placeholder_num
|
|
187
|
+
new_token_values = next_tokens[token_in_tpu_pre_next_tokens_indices]
|
|
188
|
+
original_values = input_ids[token_in_tpu_cur_input_indices]
|
|
189
|
+
update_values = jnp.where(mask, new_token_values, original_values)
|
|
190
|
+
return input_ids.at[token_in_tpu_cur_input_indices].set(update_values)
|
|
191
|
+
|
|
192
|
+
|
|
193
|
+
def _reorder_logits_indices(logprobs_lists, logits_indices_selector):
|
|
194
|
+
return LogprobsLists(
|
|
195
|
+
logprob_token_ids=[
|
|
196
|
+
logprobs_lists.logprob_token_ids[i]
|
|
197
|
+
for i in logits_indices_selector
|
|
198
|
+
],
|
|
199
|
+
logprobs=[logprobs_lists.logprobs[i] for i in logits_indices_selector],
|
|
200
|
+
sampled_token_ranks=[
|
|
201
|
+
logprobs_lists.sampled_token_ranks[i]
|
|
202
|
+
for i in logits_indices_selector
|
|
203
|
+
],
|
|
204
|
+
cu_num_generated_tokens=logprobs_lists.cu_num_generated_tokens,
|
|
205
|
+
)
|
|
206
|
+
|
|
207
|
+
|
|
208
|
+
class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
|
|
209
|
+
|
|
210
|
+
def __init__(
|
|
211
|
+
self,
|
|
212
|
+
vllm_config: VllmConfig,
|
|
213
|
+
devices: List[Any],
|
|
214
|
+
):
|
|
215
|
+
self.vllm_config = vllm_config
|
|
216
|
+
self.model_config = vllm_config.model_config
|
|
217
|
+
# TODO(jevinjiang): override block size based on RPA v3.
|
|
218
|
+
self.cache_config = vllm_config.cache_config
|
|
219
|
+
self.lora_config = vllm_config.lora_config
|
|
220
|
+
self.load_config = vllm_config.load_config
|
|
221
|
+
self.parallel_config = vllm_config.parallel_config
|
|
222
|
+
self.scheduler_config = vllm_config.scheduler_config
|
|
223
|
+
self.speculative_config = vllm_config.speculative_config
|
|
224
|
+
self.observability_config = vllm_config.observability_config
|
|
225
|
+
self.device_config = vllm_config.device_config
|
|
226
|
+
|
|
227
|
+
self.devices = devices
|
|
228
|
+
self.dtype = self.model_config.dtype
|
|
229
|
+
self.maybe_forbid_compile = runner_utils.ForbidCompile(
|
|
230
|
+
) if envs.VLLM_XLA_CHECK_RECOMPILATION else nullcontext()
|
|
231
|
+
self.dp_size = self.vllm_config.sharding_config.total_dp_size
|
|
232
|
+
|
|
233
|
+
self._init_random()
|
|
234
|
+
self._init_mesh()
|
|
235
|
+
self._init_phased_profiling()
|
|
236
|
+
self._init_mm()
|
|
237
|
+
self._init_inputs()
|
|
238
|
+
self._init_speculative_decoding()
|
|
239
|
+
|
|
240
|
+
# Delegate functions to specific manager classes.
|
|
241
|
+
self.compilation_manager = CompilationManager(self)
|
|
242
|
+
self.speculative_decoding_manager = SpeculativeDecodingManager(self)
|
|
243
|
+
self.structured_decoding_manager = StructuredDecodingManager(self)
|
|
244
|
+
self.kv_cache_manager = KVCacheManager(self)
|
|
245
|
+
self.mm_manager = MultiModalManager(self)
|
|
246
|
+
self.persistent_batch_manager = PersistentBatchManager(
|
|
247
|
+
self.requests, self.input_batch, self.encoder_cache,
|
|
248
|
+
self.uses_mrope, self.model_config)
|
|
249
|
+
self.lora_utils = LoraUtils(self)
|
|
250
|
+
|
|
251
|
+
cache_config = self.cache_config
|
|
252
|
+
if cache_config.cache_dtype == "auto":
|
|
253
|
+
model_dtype = self.dtype
|
|
254
|
+
if isinstance(model_dtype, str):
|
|
255
|
+
self.kv_cache_dtype = TPU_STR_DTYPE_TO_TORCH_DTYPE[model_dtype]
|
|
256
|
+
elif isinstance(getattr(model_dtype, 'dtype', None), jnp.dtype):
|
|
257
|
+
self.kv_cache_dtype = j2t_dtype(model_dtype.dtype)
|
|
258
|
+
elif isinstance(model_dtype, torch.dtype):
|
|
259
|
+
self.kv_cache_dtype = model_dtype
|
|
260
|
+
else:
|
|
261
|
+
raise ValueError(
|
|
262
|
+
"KV cache is unsupported for model_dtype of %s",
|
|
263
|
+
model_dtype)
|
|
264
|
+
else:
|
|
265
|
+
self.kv_cache_dtype = TPU_STR_DTYPE_TO_TORCH_DTYPE[
|
|
266
|
+
cache_config.cache_dtype]
|
|
267
|
+
|
|
268
|
+
self._pre_async_results: AsyncPreResults | None = None
|
|
269
|
+
self._substitute_placeholder_token_fn = _substitute_placeholder_token
|
|
270
|
+
self.execute_model_state: ExecuteModelState | None = None
|
|
271
|
+
|
|
272
|
+
def _init_random(self):
|
|
273
|
+
if self.model_config.seed is None:
|
|
274
|
+
self.model_config.seed = 0
|
|
275
|
+
random.seed(self.model_config.seed)
|
|
276
|
+
np.random.seed(self.model_config.seed)
|
|
277
|
+
self.rng_key = jax.random.key(self.model_config.seed)
|
|
278
|
+
|
|
279
|
+
def _init_mesh(self) -> None:
|
|
280
|
+
if os.getenv("NEW_MODEL_DESIGN", False):
|
|
281
|
+
self.mesh = self._create_new_model_mesh()
|
|
282
|
+
else:
|
|
283
|
+
# NOTE(wenxindongwork): The new MoE kernel expects a 2D mesh, so we need
|
|
284
|
+
# to create a 2D mesh for now. We should make the new_model_mesh as the default
|
|
285
|
+
# in the future.
|
|
286
|
+
self.mesh = self._create_2d_mesh()
|
|
287
|
+
|
|
288
|
+
logger.info(f"Init mesh | mesh={self.mesh}")
|
|
289
|
+
|
|
290
|
+
def _create_new_model_mesh(self) -> jax.sharding.Mesh:
|
|
291
|
+
num_slices = int(os.environ.get('NUM_SLICES', 1))
|
|
292
|
+
|
|
293
|
+
logger.info(f"Creating new model mesh | devices={len(self.devices)}, "
|
|
294
|
+
f"num_slices={num_slices}")
|
|
295
|
+
|
|
296
|
+
if num_slices == 1:
|
|
297
|
+
devices_array = self._create_single_slice_mesh()
|
|
298
|
+
else:
|
|
299
|
+
devices_array = self._create_multi_slice_mesh(num_slices)
|
|
300
|
+
|
|
301
|
+
return jax.sharding.Mesh(devices_array, MESH_AXIS_NAMES)
|
|
302
|
+
|
|
303
|
+
def _create_single_slice_mesh(self) -> jax.Array:
|
|
304
|
+
sharding_strategy: ShardingConfigManager = self.vllm_config.sharding_config
|
|
305
|
+
mesh_shape = (
|
|
306
|
+
sharding_strategy.model_dp_size,
|
|
307
|
+
sharding_strategy.attn_dp_size,
|
|
308
|
+
sharding_strategy.expert_size,
|
|
309
|
+
sharding_strategy.tp_size,
|
|
310
|
+
)
|
|
311
|
+
|
|
312
|
+
return mesh_utils.create_device_mesh(
|
|
313
|
+
mesh_shape,
|
|
314
|
+
self.devices,
|
|
315
|
+
allow_split_physical_axes=True,
|
|
316
|
+
)
|
|
317
|
+
|
|
318
|
+
def _create_multi_slice_mesh(self, num_slices: int) -> jax.Array:
|
|
319
|
+
sharding_strategy: ShardingConfigManager = self.vllm_config.sharding_config
|
|
320
|
+
dp_inner = sharding_strategy.model_dp_size // num_slices
|
|
321
|
+
|
|
322
|
+
# Splits data parallelism across multiple slices.
|
|
323
|
+
ici_mesh_shape = (
|
|
324
|
+
dp_inner,
|
|
325
|
+
sharding_strategy.attn_dp_size,
|
|
326
|
+
sharding_strategy.expert_size,
|
|
327
|
+
sharding_strategy.tp_size,
|
|
328
|
+
)
|
|
329
|
+
dcn_mesh_shape = (num_slices, 1, 1, 1)
|
|
330
|
+
|
|
331
|
+
return mesh_utils.create_hybrid_device_mesh(
|
|
332
|
+
mesh_shape=ici_mesh_shape,
|
|
333
|
+
dcn_mesh_shape=dcn_mesh_shape,
|
|
334
|
+
devices=self.devices,
|
|
335
|
+
allow_split_physical_axes=True,
|
|
336
|
+
)
|
|
337
|
+
|
|
338
|
+
def _create_2d_mesh(self) -> jax.sharding.Mesh:
|
|
339
|
+
|
|
340
|
+
sharding_strategy: ShardingConfigManager = self.vllm_config.sharding_config
|
|
341
|
+
mesh_shape = (
|
|
342
|
+
sharding_strategy.model_dp_size,
|
|
343
|
+
sharding_strategy.tp_size,
|
|
344
|
+
)
|
|
345
|
+
|
|
346
|
+
enforce_device_order = (
|
|
347
|
+
self.vllm_config.sharding_config.device_indexes is not None
|
|
348
|
+
and len(self.vllm_config.sharding_config.device_indexes) > 0)
|
|
349
|
+
|
|
350
|
+
if enforce_device_order:
|
|
351
|
+
return jax.make_mesh(mesh_shape,
|
|
352
|
+
MESH_AXIS_NAMES_2D,
|
|
353
|
+
devices=self.devices)
|
|
354
|
+
else:
|
|
355
|
+
return make_optimized_mesh(mesh_shape,
|
|
356
|
+
MESH_AXIS_NAMES_2D,
|
|
357
|
+
devices=self.devices)
|
|
358
|
+
|
|
359
|
+
def _init_phased_profiling(self) -> None:
|
|
360
|
+
self.phased_profiling_dir = os.getenv("PHASED_PROFILING_DIR", "")
|
|
361
|
+
self.phase_based_profiler = None
|
|
362
|
+
if self.phased_profiling_dir:
|
|
363
|
+
self.phase_based_profiler = runner_utils.PhasedBasedProfiler(
|
|
364
|
+
self.phased_profiling_dir)
|
|
365
|
+
|
|
366
|
+
def _init_mm(self) -> None:
|
|
367
|
+
self.is_multimodal_model = None
|
|
368
|
+
self.uses_mrope = self.model_config.uses_mrope
|
|
369
|
+
|
|
370
|
+
def _init_speculative_decoding(self) -> None:
|
|
371
|
+
self.drafter = None
|
|
372
|
+
if self.speculative_config:
|
|
373
|
+
if self.speculative_config.method == "ngram":
|
|
374
|
+
self.drafter = NgramProposer(self.vllm_config)
|
|
375
|
+
elif self.speculative_config.method == "eagle3":
|
|
376
|
+
self.drafter = Eagle3Proposer(self.vllm_config, self)
|
|
377
|
+
else:
|
|
378
|
+
raise NotImplementedError(
|
|
379
|
+
"Unsupported speculative decoding method: "
|
|
380
|
+
f"{self.speculative_config.method}")
|
|
381
|
+
self.rejection_sampler = RejectionSampler()
|
|
382
|
+
|
|
383
|
+
def _init_inputs(self) -> None:
|
|
384
|
+
model_config = self.model_config
|
|
385
|
+
cache_config = self.cache_config
|
|
386
|
+
scheduler_config = self.scheduler_config
|
|
387
|
+
|
|
388
|
+
self.sliding_window = model_config.get_sliding_window()
|
|
389
|
+
self.block_size = cache_config.block_size
|
|
390
|
+
self.max_model_len = model_config.max_model_len
|
|
391
|
+
self.max_num_blocks_per_req = cdiv(self.max_model_len, self.block_size)
|
|
392
|
+
# InputBatch needs to work with sampling tensors greater than padding
|
|
393
|
+
# to avoid dynamic shapes. Also, avoid suboptimal alignment.
|
|
394
|
+
# The total number of requests is dp_size * max_num_seqs
|
|
395
|
+
self.max_num_reqs = max(self.dp_size * scheduler_config.max_num_seqs,
|
|
396
|
+
MIN_NUM_SEQS)
|
|
397
|
+
# [16, 32, 64, 128, 256, 512, 1024, 2048]
|
|
398
|
+
self.num_tokens_paddings = runner_utils.get_token_paddings(
|
|
399
|
+
min_token_size=max(16, self.dp_size),
|
|
400
|
+
max_token_size=scheduler_config.max_num_batched_tokens *
|
|
401
|
+
self.dp_size,
|
|
402
|
+
padding_gap=envs.VLLM_TPU_BUCKET_PADDING_GAP)
|
|
403
|
+
self.num_tokens_paddings_per_dp = [
|
|
404
|
+
padding // self.dp_size for padding in self.num_tokens_paddings
|
|
405
|
+
]
|
|
406
|
+
# In case `max_num_tokens < max(num_tokens_paddings)` use the actual
|
|
407
|
+
# padded max value to pre-allocate data structures and pre-compile.
|
|
408
|
+
self.max_num_tokens = self.num_tokens_paddings[-1]
|
|
409
|
+
|
|
410
|
+
# Request states.
|
|
411
|
+
self.requests: dict[str, CachedRequestState] = {}
|
|
412
|
+
# mm_hash -> encoder_output
|
|
413
|
+
self.encoder_cache: dict[str, jax.Array] = {}
|
|
414
|
+
self.input_batch = InputBatch(
|
|
415
|
+
max_num_reqs=self.max_num_reqs,
|
|
416
|
+
max_model_len=self.max_model_len,
|
|
417
|
+
max_num_batched_tokens=self.max_num_tokens,
|
|
418
|
+
pin_memory=False,
|
|
419
|
+
vocab_size=self.model_config.get_vocab_size(),
|
|
420
|
+
block_sizes=[self.block_size],
|
|
421
|
+
is_spec_decode=bool(self.vllm_config.speculative_config),
|
|
422
|
+
)
|
|
423
|
+
|
|
424
|
+
self.input_ids_cpu = np.zeros(self.max_num_tokens, dtype=np.int32)
|
|
425
|
+
self.positions_cpu = np.zeros(self.max_num_tokens, dtype=np.int32)
|
|
426
|
+
self.block_table_cpu = np.zeros(
|
|
427
|
+
(self.max_num_reqs, self.max_num_blocks_per_req), dtype=np.int32)
|
|
428
|
+
self.query_start_loc_cpu = np.zeros(self.max_num_reqs + self.dp_size,
|
|
429
|
+
dtype=np.int32)
|
|
430
|
+
self.seq_lens_cpu = np.zeros(self.max_num_reqs, dtype=np.int32)
|
|
431
|
+
self.logits_indices_cpu = np.zeros(self.max_num_reqs, dtype=np.int32)
|
|
432
|
+
# Range tensor with values [0 .. self.max_num_tokens - 1].
|
|
433
|
+
# Used to initialize positions / context_lens / seq_lens
|
|
434
|
+
# Keep in int64 to avoid overflow with long context
|
|
435
|
+
self.arange_cpu = np.arange(self.max_num_tokens, dtype=np.int64)
|
|
436
|
+
min_num_reqs = max(MIN_NUM_SEQS, self.dp_size)
|
|
437
|
+
self.num_reqs_paddings = runner_utils.get_req_paddings(
|
|
438
|
+
min_req_size=min_num_reqs, max_req_size=self.max_num_reqs)
|
|
439
|
+
self.num_reqs_paddings_per_dp = [
|
|
440
|
+
padding // self.dp_size for padding in self.num_reqs_paddings
|
|
441
|
+
]
|
|
442
|
+
|
|
443
|
+
# Padding for logits. Without speculative decoding, each request has one position to select from.
|
|
444
|
+
# With speculative decoding, each request has multiple positions to select from.
|
|
445
|
+
max_logits_per_req = 1
|
|
446
|
+
if self.speculative_config:
|
|
447
|
+
max_logits_per_req = self.speculative_config.num_speculative_tokens + 1 # Including bonus token
|
|
448
|
+
self.num_logits_paddings = runner_utils.get_token_paddings(
|
|
449
|
+
min_token_size=MIN_NUM_SEQS,
|
|
450
|
+
max_token_size=self.max_num_reqs * max_logits_per_req,
|
|
451
|
+
padding_gap=0)
|
|
452
|
+
else:
|
|
453
|
+
self.num_logits_paddings = None
|
|
454
|
+
|
|
455
|
+
self.temperatures_cpu = np.zeros(self.max_num_tokens, dtype=np.float32)
|
|
456
|
+
self.top_ps_cpu = np.zeros(self.max_num_tokens, dtype=np.float32)
|
|
457
|
+
self.top_ks_cpu = np.zeros(self.max_num_tokens, dtype=np.int32)
|
|
458
|
+
|
|
459
|
+
# tensors for structured decoding
|
|
460
|
+
self.vocab_size = self.model_config.get_vocab_size()
|
|
461
|
+
if self.lora_config is not None:
|
|
462
|
+
# lora_config.lora_extra_vocab_size is the "Maximum size of extra vocabulary that can be present in a LoRA adapter" per https://github.com/vanbasten23/vllm/blob/7f4a8b6705622fde952a2e633e86716f902d6e1b/vllm/config.py#L3040
|
|
463
|
+
self.vocab_size += self.lora_config.lora_extra_vocab_size
|
|
464
|
+
self.grammar_bitmask_cpu = np.zeros(
|
|
465
|
+
(self.max_num_reqs, cdiv(self.vocab_size, 32)),
|
|
466
|
+
dtype=np.int32,
|
|
467
|
+
)
|
|
468
|
+
self.require_structured_out_cpu = np.zeros(
|
|
469
|
+
(self.max_num_reqs, 1),
|
|
470
|
+
dtype=np.bool_,
|
|
471
|
+
)
|
|
472
|
+
self.structured_decode_arange = np.arange(0, 32, dtype=np.int32)
|
|
473
|
+
|
|
474
|
+
# multi-modal support
|
|
475
|
+
# Only relevant for models using M-RoPE (e.g, Qwen2-VL)
|
|
476
|
+
|
|
477
|
+
# NOTE: When M-RoPE is enabled, position ids are 3D regardless of
|
|
478
|
+
# the modality of inputs. For text-only inputs, each dimension has
|
|
479
|
+
# identical position IDs, making M-RoPE functionally equivalent to
|
|
480
|
+
# 1D-RoPE.
|
|
481
|
+
# See page 5 of https://arxiv.org/abs/2409.12191
|
|
482
|
+
self.mrope_positions_cpu = np.zeros((3, self.max_num_tokens),
|
|
483
|
+
dtype=np.int64)
|
|
484
|
+
|
|
485
|
+
def load_model(self):
|
|
486
|
+
self.model_fn, self.compute_logits_fn, self.combine_hidden_states_fn, multimodal_fns, self.state, self.lora_manager, self.model = get_model(
|
|
487
|
+
self.vllm_config,
|
|
488
|
+
self.rng_key,
|
|
489
|
+
self.mesh,
|
|
490
|
+
)
|
|
491
|
+
|
|
492
|
+
multimodal_fns = multimodal_fns or {}
|
|
493
|
+
self.precompile_vision_encoder_fn = multimodal_fns.get(
|
|
494
|
+
"precompile_vision_encoder_fn", None)
|
|
495
|
+
self.get_multimodal_embeddings_fn = multimodal_fns.get(
|
|
496
|
+
"get_multimodal_embeddings_fn", None)
|
|
497
|
+
self.get_input_embeddings_fn = multimodal_fns.get(
|
|
498
|
+
"get_input_embeddings_fn", None)
|
|
499
|
+
self.get_mrope_input_positions_fn = multimodal_fns.get(
|
|
500
|
+
"get_mrope_input_positions_fn", None)
|
|
501
|
+
|
|
502
|
+
if self.drafter is not None:
|
|
503
|
+
logger.info("Loading drafter model...")
|
|
504
|
+
self.drafter.load_model(self.state)
|
|
505
|
+
|
|
506
|
+
self.rng_params_for_sampling = nnx.Rngs(
|
|
507
|
+
jax.random.key(self.model_config.seed)).params()
|
|
508
|
+
self.is_multimodal_model = (self.model_config.is_multimodal_model
|
|
509
|
+
and self.get_multimodal_embeddings_fn
|
|
510
|
+
is not None)
|
|
511
|
+
|
|
512
|
+
logger.info(f"Init model | "
|
|
513
|
+
f"hbm={common_utils.hbm_usage_gb(self.devices)}GiB")
|
|
514
|
+
|
|
515
|
+
def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
|
|
516
|
+
return ("generate", )
|
|
517
|
+
|
|
518
|
+
def get_kv_cache_spec(self):
|
|
519
|
+
return self.kv_cache_manager.get_kv_cache_spec()
|
|
520
|
+
|
|
521
|
+
def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
|
|
522
|
+
self.kv_cache_config = kv_cache_config
|
|
523
|
+
self.kv_caches = []
|
|
524
|
+
self.kv_cache_manager.initialize_kv_cache(kv_cache_config)
|
|
525
|
+
if has_kv_transfer_group():
|
|
526
|
+
get_kv_transfer_group().register_runner(self)
|
|
527
|
+
|
|
528
|
+
def capture_model(self) -> None:
|
|
529
|
+
self.compilation_manager.capture_model()
|
|
530
|
+
|
|
531
|
+
@time_function
|
|
532
|
+
def execute_model(
|
|
533
|
+
self,
|
|
534
|
+
scheduler_output: "VllmSchedulerOutput",
|
|
535
|
+
intermediate_tensors: Optional[IntermediateTensors] = None,
|
|
536
|
+
) -> ModelRunnerOutput | None:
|
|
537
|
+
if self.execute_model_state is not None:
|
|
538
|
+
raise RuntimeError("State error: sample_tokens() must be called "
|
|
539
|
+
"after execute_model() returns None.")
|
|
540
|
+
_, output = self._execute_model(scheduler_output)
|
|
541
|
+
return output
|
|
542
|
+
|
|
543
|
+
def sample_tokens(
|
|
544
|
+
self,
|
|
545
|
+
grammar_output: "GrammarOutput | None",
|
|
546
|
+
) -> ModelRunnerOutput | AsyncTPUModelRunnerOutput:
|
|
547
|
+
if self.execute_model_state is None:
|
|
548
|
+
# This can happen in pipeline parallel case.
|
|
549
|
+
return EMPTY_MODEL_RUNNER_OUTPUT
|
|
550
|
+
|
|
551
|
+
(scheduler_output, attn_metadata, input_ids, hidden_states, logits,
|
|
552
|
+
aux_hidden_states, spec_decode_metadata, kv_connector_output,
|
|
553
|
+
logits_indices_selector) = (
|
|
554
|
+
self.execute_model_state.scheduler_output,
|
|
555
|
+
self.execute_model_state.attn_metadata,
|
|
556
|
+
self.execute_model_state.input_ids,
|
|
557
|
+
self.execute_model_state.hidden_states,
|
|
558
|
+
self.execute_model_state.logits,
|
|
559
|
+
self.execute_model_state.aux_hidden_states,
|
|
560
|
+
self.execute_model_state.spec_decode_metadata,
|
|
561
|
+
self.execute_model_state.kv_connector_output,
|
|
562
|
+
self.execute_model_state.logits_indices_selector)
|
|
563
|
+
self.execute_model_state = None
|
|
564
|
+
|
|
565
|
+
if grammar_output is not None:
|
|
566
|
+
(
|
|
567
|
+
require_struct_decoding, grammar_bitmask_padded, arange
|
|
568
|
+
) = self.structured_decoding_manager.prepare_structured_decoding_input(
|
|
569
|
+
logits, grammar_output)
|
|
570
|
+
logits = self.structured_decoding_manager.structured_decode_fn(
|
|
571
|
+
require_struct_decoding,
|
|
572
|
+
grammar_bitmask_padded,
|
|
573
|
+
logits,
|
|
574
|
+
arange,
|
|
575
|
+
)
|
|
576
|
+
return self._sample_from_logits(scheduler_output, attn_metadata,
|
|
577
|
+
input_ids, hidden_states, logits,
|
|
578
|
+
aux_hidden_states,
|
|
579
|
+
spec_decode_metadata,
|
|
580
|
+
kv_connector_output,
|
|
581
|
+
logits_indices_selector)
|
|
582
|
+
|
|
583
|
+
def _modify_prev_results(self):
|
|
584
|
+
# If copy to host has not been done, we just wait.
|
|
585
|
+
# device_get should return immediately as we have scheduled it in previous function call.
|
|
586
|
+
assert self._pre_async_results is not None, "When we call _modify_prev_results(), self._pre_async_results should already exist"
|
|
587
|
+
pre_req_ids = self._pre_async_results.req_ids
|
|
588
|
+
pre_next_tokens = self._pre_async_results.next_tokens
|
|
589
|
+
pre_request_seq_lens = self._pre_async_results.request_seq_lens
|
|
590
|
+
pre_discard_sampled_tokens_req_indices = self._pre_async_results.discard_sampled_tokens_req_indices
|
|
591
|
+
pre_logits_indices_selector = self._pre_async_results.logits_indices_selector
|
|
592
|
+
|
|
593
|
+
next_tokens_cpu = np.asarray(jax.device_get(pre_next_tokens))
|
|
594
|
+
if pre_logits_indices_selector is not None:
|
|
595
|
+
next_tokens_cpu = next_tokens_cpu[pre_logits_indices_selector]
|
|
596
|
+
selected_token_ids = np.expand_dims(next_tokens_cpu[:len(pre_req_ids)],
|
|
597
|
+
1)
|
|
598
|
+
valid_sampled_token_ids = selected_token_ids.tolist()
|
|
599
|
+
|
|
600
|
+
# Mask out the sampled tokens that should not be sampled.
|
|
601
|
+
for i in pre_discard_sampled_tokens_req_indices:
|
|
602
|
+
valid_sampled_token_ids[i].clear()
|
|
603
|
+
# Append sampled tokens
|
|
604
|
+
for pre_req_idx, req_state, _ in pre_request_seq_lens:
|
|
605
|
+
sampled_ids = valid_sampled_token_ids[pre_req_idx]
|
|
606
|
+
if not sampled_ids:
|
|
607
|
+
continue
|
|
608
|
+
|
|
609
|
+
# If request not active in the *current* batch (e.g. finished or evicted), skip it.
|
|
610
|
+
req_id = pre_req_ids[pre_req_idx]
|
|
611
|
+
if req_id not in self.input_batch.req_id_to_index:
|
|
612
|
+
continue
|
|
613
|
+
|
|
614
|
+
req_idx = self.input_batch.req_id_to_index[req_id]
|
|
615
|
+
assert req_state is self.requests[
|
|
616
|
+
req_id], "The req_state should be valid and identical"
|
|
617
|
+
|
|
618
|
+
# Updated on previous execute
|
|
619
|
+
end_idx = self.input_batch.num_tokens_no_spec[req_idx]
|
|
620
|
+
assert len(sampled_ids) == 1, "do not support spec decode yet"
|
|
621
|
+
start_idx = end_idx - 1
|
|
622
|
+
assert end_idx <= self.max_model_len, (
|
|
623
|
+
"Sampled token IDs exceed the max model length. "
|
|
624
|
+
f"Total number of tokens: {end_idx} > max_model_len: "
|
|
625
|
+
f"{self.max_model_len}")
|
|
626
|
+
|
|
627
|
+
self.input_batch.token_ids_cpu[req_idx,
|
|
628
|
+
start_idx:end_idx] = sampled_ids
|
|
629
|
+
# Replace previous placeholder
|
|
630
|
+
req_state.output_token_ids[-1] = sampled_ids[-1]
|
|
631
|
+
|
|
632
|
+
def _update_placeholder(self,
|
|
633
|
+
discard_sampled_tokens_req_indices,
|
|
634
|
+
request_seq_lens,
|
|
635
|
+
logits_indices_selector=None):
|
|
636
|
+
placeholder_req_id_to_index: dict[str, int] = {}
|
|
637
|
+
discard_sampled_tokens_req_indices_set = set(
|
|
638
|
+
discard_sampled_tokens_req_indices)
|
|
639
|
+
for req_idx, req_state, _ in request_seq_lens:
|
|
640
|
+
if req_idx in discard_sampled_tokens_req_indices_set:
|
|
641
|
+
continue
|
|
642
|
+
|
|
643
|
+
start_idx = self.input_batch.num_tokens_no_spec[req_idx]
|
|
644
|
+
# Not supporting spec decode yet, assume only 1 new token
|
|
645
|
+
end_idx = start_idx + 1
|
|
646
|
+
assert end_idx <= self.max_model_len, (
|
|
647
|
+
"Sampled token IDs exceed the max model length. "
|
|
648
|
+
f"Total number of tokens: {end_idx} > max_model_len: "
|
|
649
|
+
f"{self.max_model_len}")
|
|
650
|
+
|
|
651
|
+
# Update cpu tokens at next execute and prepare input from tpu
|
|
652
|
+
self.input_batch.num_tokens_no_spec[req_idx] = end_idx
|
|
653
|
+
self.input_batch.num_tokens[req_idx] = end_idx
|
|
654
|
+
|
|
655
|
+
# For placeholder, should be update on next execute.
|
|
656
|
+
req_state.output_token_ids.extend([0])
|
|
657
|
+
if logits_indices_selector is None:
|
|
658
|
+
placeholder_req_id_to_index[req_state.req_id] = req_idx
|
|
659
|
+
else:
|
|
660
|
+
placeholder_req_id_to_index[
|
|
661
|
+
req_state.req_id] = logits_indices_selector[req_idx]
|
|
662
|
+
return placeholder_req_id_to_index
|
|
663
|
+
|
|
664
|
+
def _execute_model(
|
|
665
|
+
self,
|
|
666
|
+
scheduler_output: "VllmSchedulerOutput",
|
|
667
|
+
) -> tuple[AttentionMetadata, ModelRunnerOutput | None]:
|
|
668
|
+
self.persistent_batch_manager.update_states(
|
|
669
|
+
scheduler_output, self.get_mrope_input_positions_fn)
|
|
670
|
+
if not scheduler_output.total_num_scheduled_tokens:
|
|
671
|
+
if has_kv_transfer_group():
|
|
672
|
+
return DUMMY_METADATA, self.kv_connector_no_forward(
|
|
673
|
+
scheduler_output, self.vllm_config)
|
|
674
|
+
|
|
675
|
+
# Return empty ModelRunnerOutput if there's no work to do.
|
|
676
|
+
# TODO(fhzhang): We rely on empty cycles to remove requests in input batch. Fix it to reduce overhead.
|
|
677
|
+
logger.debug(f"Nothing scheduled: {scheduler_output}!")
|
|
678
|
+
# NOTE(pooyam): There is no guarantee that scheduler is not sending empty output: https://github.com/vllm-project/vllm/blob/7cfea0df390c154c1026f77d3682e2733ca4aca8/vllm/v1/engine/core.py#L275
|
|
679
|
+
# Why they are not preventing that is not clear to me.
|
|
680
|
+
if len(scheduler_output.finished_req_ids) == 0:
|
|
681
|
+
logger.warning(
|
|
682
|
+
"Should not schedule a request that does nothing!")
|
|
683
|
+
# raise Exception(
|
|
684
|
+
# "Should not schedule a request that does nothing!")
|
|
685
|
+
return DUMMY_METADATA, EMPTY_MODEL_RUNNER_OUTPUT
|
|
686
|
+
|
|
687
|
+
# TODO(pooyam): I guess we can remove returning sampling_metadata in `_prepare_inputs` after https://github.com/njhill/vllm/commit/b7433ca1a47732394b1bdea4099d98389515954b
|
|
688
|
+
(
|
|
689
|
+
input_ids,
|
|
690
|
+
attn_metadata,
|
|
691
|
+
_,
|
|
692
|
+
logits_indices,
|
|
693
|
+
spec_decode_metadata,
|
|
694
|
+
logits_indices_selector,
|
|
695
|
+
) = self._prepare_inputs(scheduler_output)
|
|
696
|
+
|
|
697
|
+
# multi-modal support
|
|
698
|
+
if self.is_multimodal_model:
|
|
699
|
+
# Run the multimodal encoder if any.
|
|
700
|
+
# We have the modality embeds at this time.
|
|
701
|
+
self.mm_manager.execute_mm_encoder(scheduler_output)
|
|
702
|
+
mm_embeds = self.mm_manager.gather_mm_embeddings(
|
|
703
|
+
scheduler_output, input_ids.shape[0])
|
|
704
|
+
else:
|
|
705
|
+
mm_embeds = []
|
|
706
|
+
|
|
707
|
+
# NOTE(Wenlong): For multi-modal model,
|
|
708
|
+
# it will embed the text tokens and merge with the existing modality embeds
|
|
709
|
+
# Later, the multi-modality model will take the embedding as the input.
|
|
710
|
+
# For text-only model, this does nothing. It will input the input_ids and
|
|
711
|
+
# leave the mebedding job inside the forward pass
|
|
712
|
+
input_ids, inputs_embeds = self._get_input_ids_embeds(
|
|
713
|
+
input_ids, mm_embeds)
|
|
714
|
+
|
|
715
|
+
lora_metadata = self.lora_utils.extract_lora_metadata()
|
|
716
|
+
# TODO: make _get_input_ids_embeds within this context
|
|
717
|
+
# NOTE: right now, mm model will use embeddings as the input,
|
|
718
|
+
# but text-only model will use input_ids
|
|
719
|
+
with self.maybe_forbid_compile:
|
|
720
|
+
|
|
721
|
+
with set_forward_context(
|
|
722
|
+
None,
|
|
723
|
+
self.vllm_config,
|
|
724
|
+
), self.maybe_get_kv_connector_output(
|
|
725
|
+
scheduler_output) as kv_connector_output:
|
|
726
|
+
# NOTE(Wenlong): It takes both `input_ids` and `inputs_embeds`,
|
|
727
|
+
# but one of them would be `None`
|
|
728
|
+
|
|
729
|
+
(self.kv_caches, hidden_states,
|
|
730
|
+
aux_hidden_states) = self.model_fn(
|
|
731
|
+
self.state,
|
|
732
|
+
self.kv_caches,
|
|
733
|
+
input_ids,
|
|
734
|
+
attn_metadata,
|
|
735
|
+
inputs_embeds,
|
|
736
|
+
tuple(self.layer_name_to_kvcache_index.items()),
|
|
737
|
+
lora_metadata,
|
|
738
|
+
)
|
|
739
|
+
|
|
740
|
+
hidden_states = self._select_from_array_fn(hidden_states,
|
|
741
|
+
logits_indices)
|
|
742
|
+
logits = self.compute_logits_fn(
|
|
743
|
+
self.state,
|
|
744
|
+
hidden_states,
|
|
745
|
+
lora_metadata,
|
|
746
|
+
)
|
|
747
|
+
|
|
748
|
+
self.execute_model_state = ExecuteModelState(
|
|
749
|
+
scheduler_output=scheduler_output,
|
|
750
|
+
attn_metadata=attn_metadata,
|
|
751
|
+
input_ids=input_ids,
|
|
752
|
+
hidden_states=hidden_states,
|
|
753
|
+
logits=logits,
|
|
754
|
+
aux_hidden_states=aux_hidden_states,
|
|
755
|
+
spec_decode_metadata=spec_decode_metadata,
|
|
756
|
+
kv_connector_output=kv_connector_output,
|
|
757
|
+
logits_indices_selector=logits_indices_selector)
|
|
758
|
+
return attn_metadata, None
|
|
759
|
+
|
|
760
|
+
def _sample_from_logits(
|
|
761
|
+
self,
|
|
762
|
+
scheduler_output: "VllmSchedulerOutput",
|
|
763
|
+
attn_metadata: AttentionMetadata,
|
|
764
|
+
input_ids: Optional[jax.Array],
|
|
765
|
+
hidden_states: jax.Array,
|
|
766
|
+
logits: jax.Array,
|
|
767
|
+
aux_hidden_states: Optional[jax.Array],
|
|
768
|
+
spec_decode_metadata: Optional[SpecDecodeMetadata],
|
|
769
|
+
kv_connector_output: Optional[KVConnectorOutput],
|
|
770
|
+
logits_indices_selector: Optional[List[int]] = None,
|
|
771
|
+
) -> ModelRunnerOutput | AsyncTPUModelRunnerOutput:
|
|
772
|
+
padded_num_reqs = runner_utils.get_padded_num_reqs_with_upper_limit(
|
|
773
|
+
self.input_batch.num_reqs, self.max_num_reqs)
|
|
774
|
+
tpu_sampling_metadata = TPUSupportedSamplingMetadata.from_input_batch(
|
|
775
|
+
self.mesh, self.input_batch, padded_num_reqs)
|
|
776
|
+
if spec_decode_metadata is None:
|
|
777
|
+
next_tokens = sample(
|
|
778
|
+
self.rng_params_for_sampling,
|
|
779
|
+
self.mesh,
|
|
780
|
+
logits,
|
|
781
|
+
tpu_sampling_metadata,
|
|
782
|
+
)
|
|
783
|
+
else:
|
|
784
|
+
bonus_logits = self._select_from_array_fn(
|
|
785
|
+
logits, spec_decode_metadata.bonus_logits_indices)
|
|
786
|
+
bonus_token_ids = sample(
|
|
787
|
+
self.rng_params_for_sampling,
|
|
788
|
+
self.mesh,
|
|
789
|
+
bonus_logits,
|
|
790
|
+
tpu_sampling_metadata,
|
|
791
|
+
)
|
|
792
|
+
target_logits = self._select_from_array_fn(
|
|
793
|
+
logits, spec_decode_metadata.target_logits_indices)
|
|
794
|
+
next_tokens = self.rejection_sampler(
|
|
795
|
+
draft_token_ids=spec_decode_metadata.draft_token_ids,
|
|
796
|
+
num_draft_tokens=spec_decode_metadata.draft_lengths,
|
|
797
|
+
draft_probs=None,
|
|
798
|
+
target_logits=target_logits,
|
|
799
|
+
bonus_token_ids=bonus_token_ids,
|
|
800
|
+
sampling_metadata=tpu_sampling_metadata,
|
|
801
|
+
key=self.rng_params_for_sampling,
|
|
802
|
+
)
|
|
803
|
+
|
|
804
|
+
if tpu_sampling_metadata.logprobs:
|
|
805
|
+
logprobs = self._compute_and_gather_logprobs(
|
|
806
|
+
logits, next_tokens, self.model_config.max_logprobs)
|
|
807
|
+
else:
|
|
808
|
+
logprobs = None
|
|
809
|
+
|
|
810
|
+
num_reqs = self.input_batch.num_reqs
|
|
811
|
+
|
|
812
|
+
# Update the cache state concurrently. Code above will not block until
|
|
813
|
+
# We use `selected_token_ids`. Add mark_step if post-processing changes
|
|
814
|
+
request_seq_lens: list[tuple[int, CachedRequestState, int]] = []
|
|
815
|
+
discard_sampled_tokens_req_indices = []
|
|
816
|
+
for i, req_id in zip(range(num_reqs), self.input_batch.req_ids):
|
|
817
|
+
assert req_id is not None
|
|
818
|
+
req_state = self.requests[req_id]
|
|
819
|
+
seq_len = (req_state.num_computed_tokens +
|
|
820
|
+
scheduler_output.num_scheduled_tokens[req_id])
|
|
821
|
+
if seq_len >= req_state.num_tokens:
|
|
822
|
+
request_seq_lens.append((i, req_state, seq_len))
|
|
823
|
+
else:
|
|
824
|
+
# Ignore the sampled token from the partial request.
|
|
825
|
+
# Rewind the generator state as if the token was not sampled.
|
|
826
|
+
generator = self.input_batch.generators.get(i)
|
|
827
|
+
if generator is not None:
|
|
828
|
+
# This relies on cuda-specific torch-internal impl details
|
|
829
|
+
generator.set_offset(generator.get_offset() - 4)
|
|
830
|
+
|
|
831
|
+
# Record the index of the request that should not be sampled,
|
|
832
|
+
# so that we could clear the sampled tokens before returning.
|
|
833
|
+
discard_sampled_tokens_req_indices.append(i)
|
|
834
|
+
|
|
835
|
+
assert all(
|
|
836
|
+
req_id is not None for req_id in
|
|
837
|
+
self.input_batch.req_ids[:num_reqs]), "req_ids contains None"
|
|
838
|
+
req_ids = cast(list[str], self.input_batch.req_ids[:num_reqs])
|
|
839
|
+
|
|
840
|
+
prompt_logprobs_dict = {}
|
|
841
|
+
for req_id in self.input_batch.req_ids[:num_reqs]:
|
|
842
|
+
prompt_logprobs_dict[req_id] = None
|
|
843
|
+
|
|
844
|
+
# If async scheduler enabled
|
|
845
|
+
if self.scheduler_config.async_scheduling:
|
|
846
|
+
# Get previous results from TPU and replace the placeholder.
|
|
847
|
+
if self._pre_async_results is not None:
|
|
848
|
+
assert not self.speculative_config and spec_decode_metadata is None, "Async scheduler does not support speculative decoding yet."
|
|
849
|
+
self._modify_prev_results()
|
|
850
|
+
|
|
851
|
+
# Set placeholder for next tokens that is not yet generated
|
|
852
|
+
placeholder_req_id_to_index: dict[
|
|
853
|
+
str, int] = self._update_placeholder(
|
|
854
|
+
discard_sampled_tokens_req_indices, request_seq_lens,
|
|
855
|
+
logits_indices_selector)
|
|
856
|
+
|
|
857
|
+
if logprobs is not None:
|
|
858
|
+
# Map logprobs back to the pre-dp shuffling order
|
|
859
|
+
logprobs_lists = logprobs.tolists()
|
|
860
|
+
if logits_indices_selector is not None:
|
|
861
|
+
logprobs_lists = _reorder_logits_indices(
|
|
862
|
+
logprobs_lists, logits_indices_selector)
|
|
863
|
+
|
|
864
|
+
else:
|
|
865
|
+
logprobs_lists = None
|
|
866
|
+
|
|
867
|
+
# Save the previous results
|
|
868
|
+
next_tokens = jax.copy_to_host_async(next_tokens)
|
|
869
|
+
self._pre_async_results = AsyncPreResults(
|
|
870
|
+
req_ids=req_ids,
|
|
871
|
+
next_tokens=next_tokens,
|
|
872
|
+
request_seq_lens=request_seq_lens,
|
|
873
|
+
discard_sampled_tokens_req_indices=
|
|
874
|
+
discard_sampled_tokens_req_indices,
|
|
875
|
+
placeholder_req_id_to_index=placeholder_req_id_to_index,
|
|
876
|
+
logits_indices_selector=logits_indices_selector)
|
|
877
|
+
|
|
878
|
+
# Return Model output to executor
|
|
879
|
+
model_runner_output = ModelRunnerOutput(
|
|
880
|
+
req_ids=req_ids,
|
|
881
|
+
req_id_to_index=copy.deepcopy(
|
|
882
|
+
self.input_batch.req_id_to_index),
|
|
883
|
+
sampled_token_ids=[], # Fill in async get
|
|
884
|
+
logprobs=logprobs_lists,
|
|
885
|
+
prompt_logprobs_dict=prompt_logprobs_dict,
|
|
886
|
+
pooler_output=[],
|
|
887
|
+
kv_connector_output=kv_connector_output,
|
|
888
|
+
)
|
|
889
|
+
# Return async_model_runner_output
|
|
890
|
+
async_model_runner_output = AsyncTPUModelRunnerOutput(
|
|
891
|
+
model_runner_output, next_tokens, num_reqs,
|
|
892
|
+
discard_sampled_tokens_req_indices, logits_indices_selector)
|
|
893
|
+
return async_model_runner_output
|
|
894
|
+
|
|
895
|
+
if spec_decode_metadata is None:
|
|
896
|
+
next_tokens = np.asarray(jax.device_get(next_tokens))
|
|
897
|
+
# Map tokens back to the pre-dp shuffling order
|
|
898
|
+
if logits_indices_selector is not None:
|
|
899
|
+
next_tokens = next_tokens[logits_indices_selector]
|
|
900
|
+
selected_token_ids = np.expand_dims(next_tokens[:num_reqs], 1)
|
|
901
|
+
valid_sampled_token_ids = selected_token_ids.tolist()
|
|
902
|
+
else:
|
|
903
|
+
valid_sampled_token_ids = self.rejection_sampler.parse_output(
|
|
904
|
+
next_tokens, self.input_batch.vocab_size,
|
|
905
|
+
spec_decode_metadata.draft_lengths_cpu, num_reqs,
|
|
906
|
+
spec_decode_metadata.draft_token_ids.shape[0])
|
|
907
|
+
|
|
908
|
+
# Mask out the sampled tokens that should not be sampled.
|
|
909
|
+
for i in discard_sampled_tokens_req_indices:
|
|
910
|
+
valid_sampled_token_ids[i].clear()
|
|
911
|
+
# Append sampled tokens
|
|
912
|
+
for req_idx, req_state, _ in request_seq_lens:
|
|
913
|
+
sampled_ids = valid_sampled_token_ids[req_idx]
|
|
914
|
+
if not sampled_ids:
|
|
915
|
+
continue
|
|
916
|
+
|
|
917
|
+
start_idx = self.input_batch.num_tokens_no_spec[req_idx]
|
|
918
|
+
end_idx = start_idx + len(sampled_ids)
|
|
919
|
+
assert end_idx <= self.max_model_len, (
|
|
920
|
+
"Sampled token IDs exceed the max model length. "
|
|
921
|
+
f"Total number of tokens: {end_idx} > max_model_len: "
|
|
922
|
+
f"{self.max_model_len}")
|
|
923
|
+
|
|
924
|
+
self.input_batch.token_ids_cpu[req_idx,
|
|
925
|
+
start_idx:end_idx] = sampled_ids
|
|
926
|
+
self.input_batch.num_tokens_no_spec[req_idx] = end_idx
|
|
927
|
+
self.input_batch.num_tokens[req_idx] = end_idx
|
|
928
|
+
req_state.output_token_ids.extend(sampled_ids)
|
|
929
|
+
|
|
930
|
+
if logprobs is not None:
|
|
931
|
+
# Map logprobs back to the pre-dp shuffling order
|
|
932
|
+
logprobs_lists = logprobs.tolists()
|
|
933
|
+
if logits_indices_selector is not None:
|
|
934
|
+
logprobs_lists = _reorder_logits_indices(
|
|
935
|
+
logprobs_lists, logits_indices_selector)
|
|
936
|
+
else:
|
|
937
|
+
logprobs_lists = None
|
|
938
|
+
|
|
939
|
+
if self.speculative_config:
|
|
940
|
+
with self.maybe_forbid_compile:
|
|
941
|
+
self.speculative_decoding_manager.propose_draft_token_ids(
|
|
942
|
+
valid_sampled_token_ids,
|
|
943
|
+
aux_hidden_states,
|
|
944
|
+
attn_metadata,
|
|
945
|
+
spec_decode_metadata,
|
|
946
|
+
scheduler_output,
|
|
947
|
+
input_ids,
|
|
948
|
+
)
|
|
949
|
+
|
|
950
|
+
model_runner_output = ModelRunnerOutput(
|
|
951
|
+
req_ids=req_ids,
|
|
952
|
+
req_id_to_index=self.input_batch.req_id_to_index,
|
|
953
|
+
sampled_token_ids=valid_sampled_token_ids,
|
|
954
|
+
logprobs=logprobs_lists,
|
|
955
|
+
prompt_logprobs_dict=prompt_logprobs_dict,
|
|
956
|
+
pooler_output=[],
|
|
957
|
+
kv_connector_output=kv_connector_output,
|
|
958
|
+
)
|
|
959
|
+
return model_runner_output
|
|
960
|
+
|
|
961
|
+
    @functools.partial(jax.jit, static_argnums=(0, ))
    def _select_from_array_fn(self, array, indices_to_select):

        def select_local_fn(local_array, local_indices):
            return local_array[local_indices]

        ret = jax.shard_map(
            select_local_fn,
            mesh=self.mesh,
            in_specs=(PartitionSpec(ShardingAxisName.ATTN_DATA),
                      PartitionSpec(ShardingAxisName.ATTN_DATA)),
            out_specs=PartitionSpec(ShardingAxisName.ATTN_DATA))(
                array, indices_to_select)

        return ret

    @staticmethod
    @functools.partial(jax.jit, static_argnames=("max_logprobs", ))
    def _compute_and_gather_logprobs(logits, next_tokens, max_logprobs):
        logprobs = compute_logprobs(logits)
        return gather_logprobs(logprobs, next_tokens, max_logprobs)

    def _prepare_dp_input_metadata(self,
                                   scheduler_output: "VllmSchedulerOutput"):

        dp_size = self.dp_size
        num_reqs = self.input_batch.num_reqs
        max_num_reqs_per_dp_rank = self.max_num_reqs // dp_size
        req_ids_dp = {dp_rank: [] for dp_rank in range(dp_size)}
        req_indices_dp = {dp_rank: [] for dp_rank in range(dp_size)}
        num_scheduled_tokens_per_dp_rank = {
            dp_rank: 0
            for dp_rank in range(dp_size)
        }
        scheduled_tokens_per_dp_rank = {
            dp_rank: []
            for dp_rank in range(dp_size)
        }
        num_req_per_dp_rank = {dp_rank: 0 for dp_rank in range(dp_size)}

        for req_id in self.input_batch.req_ids[:num_reqs]:
            dp_rank = scheduler_output.assigned_dp_rank[req_id]
            req_ids_dp[dp_rank].append(req_id)
            req_indices_dp[dp_rank].append(
                self.input_batch.req_id_to_index[req_id])
            num_scheduled_tokens_per_dp_rank[
                dp_rank] += scheduler_output.num_scheduled_tokens[req_id]
            scheduled_tokens_per_dp_rank[dp_rank].append(
                scheduler_output.num_scheduled_tokens[req_id])
            num_req_per_dp_rank[dp_rank] += 1

        # Find maximum number of scheduled tokens across DP ranks
        max_num_scheduled_tokens_across_dp = max(
            num_scheduled_tokens_per_dp_rank.values())

        padded_num_scheduled_tokens_per_dp_rank = runner_utils.get_padded_token_len(
            self.num_tokens_paddings_per_dp,
            max_num_scheduled_tokens_across_dp)

        padded_total_num_scheduled_tokens = (
            padded_num_scheduled_tokens_per_dp_rank * dp_size)

        assert max_num_scheduled_tokens_across_dp > 0

        # Find maximum number of requests across DP ranks
        max_num_reqs_across_dp = max(
            len(req_ids) for req_ids in req_ids_dp.values())
        padded_num_reqs_per_dp_rank = runner_utils.get_padded_token_len(
            self.num_reqs_paddings_per_dp, max_num_reqs_across_dp)
        padded_num_reqs = padded_num_reqs_per_dp_rank * dp_size

        all_req_indices = np.concatenate(
            [req_indices_dp[dp_rank] for dp_rank in range(dp_size)])
        all_positions = np.concatenate([
            np.arange(len(req_indices_dp[dp_rank])) +
            padded_num_reqs_per_dp_rank * dp_rank for dp_rank in range(dp_size)
        ])

        # Sort positions by request indices
        sorted_indices = np.argsort(all_req_indices)
        logits_indices_selector = all_positions[sorted_indices]
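        # E.g., with dp_size = 2, padded_num_reqs_per_dp_rank = 4 and
        # req_indices_dp = {0: [1, 3], 1: [0, 2]}:
        #   all_req_indices = [1, 3, 0, 2], all_positions = [0, 1, 4, 5],
        #   sorted_indices = [2, 0, 3, 1] -> logits_indices_selector = [4, 0, 5, 1],
        # i.e. for each request, in input_batch order, the position of its
        # output in the padded, DP-rank-major layout.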

        return (req_ids_dp, req_indices_dp, num_scheduled_tokens_per_dp_rank,
                scheduled_tokens_per_dp_rank, num_req_per_dp_rank,
                padded_num_scheduled_tokens_per_dp_rank, padded_num_reqs,
                padded_total_num_scheduled_tokens, padded_num_reqs_per_dp_rank,
                logits_indices_selector, max_num_reqs_per_dp_rank)

    def _prepare_async_token_substitution_indices_dp(
            self, req_ids_dp, scheduled_tokens_per_dp_rank,
            padded_num_scheduled_tokens_per_dp_rank, dp_size):
        """Prepare token substitution indices for async scheduling in DP mode."""
        token_in_tpu_cur_input_indices_dp = {}
        token_in_tpu_pre_next_tokens_indices_dp = {}

        for dp_rank in range(dp_size):
            token_in_tpu_cur_input_indices_dp[dp_rank] = []
            token_in_tpu_pre_next_tokens_indices_dp[dp_rank] = []

            token_offset = padded_num_scheduled_tokens_per_dp_rank * dp_rank
            acc_cur_len = token_offset

            for i, req_id in enumerate(req_ids_dp[dp_rank]):
                acc_cur_len += scheduled_tokens_per_dp_rank[dp_rank][i]
                if req_id not in self._pre_async_results.placeholder_req_id_to_index:
                    continue

                token_in_tpu_cur_input_indices_dp[dp_rank].append(
                    acc_cur_len - 1)
                token_in_tpu_pre_next_tokens_indices_dp[dp_rank].append(
                    self._pre_async_results.placeholder_req_id_to_index[req_id]
                )

        return token_in_tpu_cur_input_indices_dp, token_in_tpu_pre_next_tokens_indices_dp

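    # E.g., with num_scheduled_tokens_per_req = [3, 1, 2], the last scheduled
    # token of each request sits at flat input positions 2, 3 and 5. For every
    # request that still has a placeholder token from the previous async step,
    # that position is recorded together with the index of the request's entry
    # in the previous step's next_tokens buffer.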
    def _prepare_async_token_substitution_indices_non_dp(
            self, num_reqs, num_scheduled_tokens_per_req):
        """Prepare token substitution indices for async scheduling in non-DP mode."""
        token_in_tpu_cur_input_indices_list = []
        token_in_tpu_pre_next_tokens_indices_list = []
        acc_cur_len = 0

        for i, req_id in enumerate(self.input_batch.req_ids[:num_reqs]):
            acc_cur_len += num_scheduled_tokens_per_req[i]
            assert req_id is not None
            if req_id not in self._pre_async_results.placeholder_req_id_to_index:
                continue

            token_in_tpu_cur_input_indices_list.append(acc_cur_len - 1)
            token_in_tpu_pre_next_tokens_indices_list.append(
                self._pre_async_results.placeholder_req_id_to_index[req_id])

        if len(token_in_tpu_cur_input_indices_list) > 0:
            return (np.array(token_in_tpu_cur_input_indices_list),
                    np.array(token_in_tpu_pre_next_tokens_indices_list))
        else:
            return np.array([]), np.array([])

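    # Consumes the index pairs built by the two helpers above: both arrays are
    # padded to len(input_ids) (unused "current input" slots get the remaining
    # positions, unused "previous next-token" slots get -1) and the number of
    # real entries is passed separately to _substitute_placeholder_token_fn.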
    def _apply_async_token_substitution(self, input_ids,
                                        token_in_tpu_cur_input_indices,
                                        token_in_tpu_pre_next_tokens_indices):
        """Apply async token substitution if needed."""
        if len(token_in_tpu_cur_input_indices) == 0:
            return input_ids

        idx_pad_len = len(input_ids) - len(token_in_tpu_cur_input_indices)

        # Pad according to the instructions written inside self._substitute_placeholder_token_fn
        full_range = np.arange(0, len(input_ids))
        missing_values = np.setdiff1d(full_range,
                                      token_in_tpu_cur_input_indices)
        padded_token_in_tpu_cur_input_indices = np.concatenate(
            (token_in_tpu_cur_input_indices, missing_values))

        padded_token_in_tpu_pre_next_tokens_indices = np.pad(
            token_in_tpu_pre_next_tokens_indices, (0, idx_pad_len),
            mode='constant',
            constant_values=-1)

        (padded_token_in_tpu_cur_input_indices,
         padded_token_in_tpu_pre_next_tokens_indices) = device_array(
             self.mesh, (padded_token_in_tpu_cur_input_indices,
                         padded_token_in_tpu_pre_next_tokens_indices))

        with self.maybe_forbid_compile:
            input_ids = self._substitute_placeholder_token_fn(
                input_ids, padded_token_in_tpu_cur_input_indices,
                padded_token_in_tpu_pre_next_tokens_indices,
                self._pre_async_results.next_tokens,
                len(token_in_tpu_cur_input_indices))

        return input_ids

    def _prepare_inputs(self, scheduler_output: "VllmSchedulerOutput"):
        if self.dp_size > 1:
            return self._prepare_inputs_dp(scheduler_output)
        else:
            return self._prepare_inputs_non_dp(scheduler_output)

    def _prepare_inputs_dp(self, scheduler_output: "VllmSchedulerOutput"):
        total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
        assert total_num_scheduled_tokens > 0
        num_reqs = self.input_batch.num_reqs
        assert num_reqs > 0

        dp_size = self.dp_size
        data_parallel_attn_sharding = NamedSharding(
            self.mesh, PartitionSpec(ShardingAxisName.ATTN_DATA))

        (req_ids_dp, req_indices_dp, num_scheduled_tokens_per_dp_rank,
         scheduled_tokens_per_dp_rank, num_req_per_dp_rank,
         padded_num_scheduled_tokens_per_dp_rank, padded_num_reqs,
         padded_total_num_scheduled_tokens, padded_num_reqs_per_dp_rank,
         logits_indices_selector, max_num_reqs_per_dp_rank
         ) = self._prepare_dp_input_metadata(scheduler_output)
        # Multi-modal support
        # Calculate M-RoPE positions.
        # Only relevant for models using M-RoPE (e.g., Qwen2-VL)
        if self.uses_mrope:
            self.mm_manager.calc_mrope_positions(scheduler_output)

        # Async scheduling: prepare token substitution indices for DP
        token_in_tpu_cur_input_indices_dp = {}
        token_in_tpu_pre_next_tokens_indices_dp = {}
        if self.scheduler_config.async_scheduling and self._pre_async_results is not None:
            # If previous async results exist, prepare the token substitution
            # indices here; the actual substitution is performed on the TPU
            # later in this function.
            (token_in_tpu_cur_input_indices_dp,
             token_in_tpu_pre_next_tokens_indices_dp
             ) = self._prepare_async_token_substitution_indices_dp(
                 req_ids_dp, scheduled_tokens_per_dp_rank,
                 padded_num_scheduled_tokens_per_dp_rank, dp_size)

        # Populate input_ids and positions.
        for dp_rank in range(dp_size):
            if num_req_per_dp_rank[dp_rank] == 0:
                continue
            token_offset = padded_num_scheduled_tokens_per_dp_rank * dp_rank
            num_scheduled_tokens_per_req = scheduled_tokens_per_dp_rank[
                dp_rank]
            total_num_scheduled_tokens = num_scheduled_tokens_per_dp_rank[
                dp_rank]
            input_ids_cpu = self.input_ids_cpu[
                token_offset:token_offset +
                padded_num_scheduled_tokens_per_dp_rank]
            positions_cpu = self.positions_cpu[
                token_offset:token_offset +
                padded_num_scheduled_tokens_per_dp_rank]
            # Get request indices.
            # E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2]
            # For each scheduled token, the index of the request it belongs to.
            req_indices = np.repeat(req_indices_dp[dp_rank],
                                    num_scheduled_tokens_per_req)
            # Get batched arange.
            # E.g., [2, 5, 3] -> [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
            # For each scheduled token, its position within its request.
            arange = np.concatenate(
                [self.arange_cpu[:n] for n in num_scheduled_tokens_per_req])
            # Get positions.
            positions_np = positions_cpu[:total_num_scheduled_tokens]
            np.add(
                self.input_batch.num_computed_tokens_cpu[req_indices],
                arange,
                out=positions_np,
            )
            # Get token indices.
            # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
            # -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2]
            # where M is the max_model_len.
            token_indices = (
                positions_np +
                req_indices * self.input_batch.token_ids_cpu.shape[1])
            # NOTE(woosuk): We use torch.index_select instead of np.take here
            # because torch.index_select is much faster than np.take for large
            # tensors.
            np.take(
                self.input_batch.token_ids_cpu.ravel(),
                token_indices,
                out=input_ids_cpu[:total_num_scheduled_tokens],
            )

            input_ids_cpu[total_num_scheduled_tokens:] = 0

        # Prepare the attention metadata (query_start_loc_cpu, seq_lens_cpu)
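        # query_start_loc_cpu stores (max_num_reqs_per_dp_rank + 1) entries per
        # DP rank (hence the extra `dp_rank` shift in the slice below), while
        # seq_lens_cpu stores max_num_reqs_per_dp_rank entries per rank.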
        for dp_rank in range(dp_size):
            req_offset = dp_rank * max_num_reqs_per_dp_rank
            query_start_loc_cpu = self.query_start_loc_cpu[
                req_offset + dp_rank:req_offset + max_num_reqs_per_dp_rank +
                dp_rank + 1]
            seq_lens_cpu = self.seq_lens_cpu[req_offset:req_offset +
                                             max_num_reqs_per_dp_rank]
            _num_reqs = num_req_per_dp_rank[dp_rank]
            req_indices = req_indices_dp[dp_rank]
            num_scheduled_tokens_per_req = scheduled_tokens_per_dp_rank[
                dp_rank]

            if _num_reqs == 0:
                query_start_loc_cpu[:] = 0
                seq_lens_cpu[:] = 0
                continue

            np.cumsum(
                num_scheduled_tokens_per_req,
                out=query_start_loc_cpu[1:_num_reqs + 1],
            )
            query_start_loc_cpu[_num_reqs + 1:] = 1

            seq_lens_cpu[:_num_reqs] = (
                self.input_batch.num_computed_tokens_cpu[req_indices] +
                num_scheduled_tokens_per_req)
            seq_lens_cpu[_num_reqs:] = 0

        # Populate logits_indices.
        for dp_rank in range(dp_size):
            req_offset = dp_rank * padded_num_reqs_per_dp_rank
            query_loc_req_offset = dp_rank * (max_num_reqs_per_dp_rank + 1)
            _num_reqs = num_req_per_dp_rank[dp_rank]

            logits_indices_cpu = self.logits_indices_cpu[
                req_offset:req_offset + padded_num_reqs_per_dp_rank]
            logits_indices_cpu[:_num_reqs] = (
                self.query_start_loc_cpu[query_loc_req_offset +
                                         1:query_loc_req_offset + _num_reqs +
                                         1] - 1)
            logits_indices_cpu[_num_reqs:] = -1

        logits_indices = self.logits_indices_cpu[:padded_num_reqs]

        # Please see runner_utils.PhasedBasedProfiler for details
        if self.phase_based_profiler:
            batch_composition_stats = runner_utils.get_batch_composition_stats(
                self.input_batch, total_num_scheduled_tokens, num_reqs,
                padded_total_num_scheduled_tokens, scheduler_output)

            self.phase_based_profiler.step(batch_composition_stats)

        # Inputs
        input_ids = self.input_ids_cpu[:padded_total_num_scheduled_tokens]
        positions = self.positions_cpu[:padded_total_num_scheduled_tokens]
        mrope_positions = self.mrope_positions_cpu[:, :
                                                   padded_total_num_scheduled_tokens]

        block_tables = self.block_table_cpu[:self.max_num_reqs]
        for dp_rank in range(dp_size):
            req_offset = dp_rank * max_num_reqs_per_dp_rank
            _num_reqs = num_req_per_dp_rank[dp_rank]

            block_tables[
                req_offset:req_offset + _num_reqs, :self.
                max_num_blocks_per_req] = self.input_batch.block_table[
                    0].get_cpu_tensor()[req_indices_dp[dp_rank]]

        query_start_loc = self.query_start_loc_cpu[:self.max_num_reqs +
                                                   dp_size]
        seq_lens = self.seq_lens_cpu[:self.max_num_reqs]

        _request_distribution = []
        for dp_rank in range(dp_size):
            _num_reqs = num_req_per_dp_rank[dp_rank]
            _request_distribution.append([0, 0, _num_reqs])
        request_distribution = np.array(_request_distribution).ravel()

        use_spec_decode = len(
            scheduler_output.scheduled_spec_decode_tokens) > 0
        if not use_spec_decode:
            spec_decode_metadata = None
        else:
            num_draft_tokens = np.zeros(num_reqs, dtype=np.int32)
            for (
                    req_id,
                    draft_token_ids,
            ) in scheduler_output.scheduled_spec_decode_tokens.items():
                req_idx = self.input_batch.req_id_to_index[req_id]
                num_draft_tokens[req_idx] = len(draft_token_ids)

            spec_decode_metadata = (
                self.speculative_decoding_manager.get_spec_decode_metadata(
                    num_draft_tokens,
                    self.query_start_loc_cpu[1:num_reqs + 1],
                    padded_num_reqs,
                ))
            logits_indices = spec_decode_metadata.final_logits_indices

        # Put to device
        sampling_metadata = TPUSupportedSamplingMetadata.from_input_batch(
            self.mesh,
            self.input_batch,
            padded_num_reqs,
            sharding=data_parallel_attn_sharding,
        )
        if self.uses_mrope:
            positions = mrope_positions

        # Convert block_tables to 1D on cpu.
        block_tables = block_tables.reshape(-1)

        query_start_loc_cpu = query_start_loc
        logits_indices_cpu = logits_indices
        seq_lens_cpu = seq_lens

        (input_ids, positions, block_tables, query_start_loc, seq_lens,
         logits_indices, request_distribution) = device_array(
             self.mesh,
             (input_ids, positions, block_tables, query_start_loc, seq_lens,
              logits_indices, request_distribution),
             sharding=data_parallel_attn_sharding,
         )
        # Async scheduling: substitute placeholder tokens for DP
        if self.scheduler_config.async_scheduling and self._pre_async_results is not None:
            # Collect all token indices that need substitution across all DP ranks
            all_token_indices_to_substitute = []
            all_pre_next_tokens_indices = []

            for dp_rank in range(dp_size):
                cur_indices = token_in_tpu_cur_input_indices_dp[dp_rank]
                pre_indices = token_in_tpu_pre_next_tokens_indices_dp[dp_rank]
                all_token_indices_to_substitute.extend(cur_indices)
                all_pre_next_tokens_indices.extend(pre_indices)

            if len(all_token_indices_to_substitute) > 0:
                token_in_tpu_cur_input_indices = np.array(
                    all_token_indices_to_substitute)
                token_in_tpu_pre_next_tokens_indices = np.array(
                    all_pre_next_tokens_indices)
                input_ids = self._apply_async_token_substitution(
                    input_ids, token_in_tpu_cur_input_indices,
                    token_in_tpu_pre_next_tokens_indices)

        if self.lora_config is not None:
            self.lora_utils.set_active_loras(
                num_scheduled_tokens_per_req,
                total_num_scheduled_tokens,
                padded_total_num_scheduled_tokens,
            )

        attention_metadata = AttentionMetadata(
            input_positions=positions,
            block_tables=block_tables,
            seq_lens=seq_lens,
            query_start_loc=query_start_loc,
            request_distribution=request_distribution,
        )

        # This is for making these cpu buffers hidden during tracing
        attention_metadata.query_start_loc_cpu = query_start_loc_cpu
        attention_metadata.seq_lens_cpu = seq_lens_cpu

        return (
            input_ids,
            attention_metadata,
            sampling_metadata,
            logits_indices,
            spec_decode_metadata,
            logits_indices_selector,
        )

    def _prepare_inputs_non_dp(self, scheduler_output: "VllmSchedulerOutput"):
        total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
        assert total_num_scheduled_tokens > 0
        num_reqs = self.input_batch.num_reqs
        assert num_reqs > 0

        # Get the number of scheduled tokens for each request.
        num_scheduled_tokens_per_req = []
        max_num_scheduled_tokens_all_reqs = 0
        for req_id in self.input_batch.req_ids[:num_reqs]:
            assert req_id is not None
            num_tokens = scheduler_output.num_scheduled_tokens[req_id]
            num_scheduled_tokens_per_req.append(num_tokens)
            max_num_scheduled_tokens_all_reqs = max(
                max_num_scheduled_tokens_all_reqs, num_tokens)
        num_scheduled_tokens_per_req = np.array(num_scheduled_tokens_per_req,
                                                dtype=np.int32)
        assert max_num_scheduled_tokens_all_reqs > 0
        padded_num_reqs = runner_utils.get_padded_num_reqs_with_upper_limit(
            num_reqs, self.max_num_reqs)

        # Get request indices.
        # E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2]
        # For each scheduled token, the index of the request it belongs to.
        req_indices = np.repeat(self.arange_cpu[:num_reqs],
                                num_scheduled_tokens_per_req)
        token_in_tpu_cur_input_indices = np.array([])
        token_in_tpu_pre_next_tokens_indices = np.array([])
        if self.scheduler_config.async_scheduling and self._pre_async_results is not None:
            # If previous async results exist, prepare the token substitution
            # indices here; the actual substitution is performed on the TPU
            # later in this function.
            (token_in_tpu_cur_input_indices,
             token_in_tpu_pre_next_tokens_indices
             ) = self._prepare_async_token_substitution_indices_non_dp(
                 num_reqs, num_scheduled_tokens_per_req)

        # Get batched arange.
        # E.g., [2, 5, 3] -> [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
        # For each scheduled token, its position within its request.
        arange = np.concatenate(
            [self.arange_cpu[:n] for n in num_scheduled_tokens_per_req])

        # Get positions.
        positions_np = self.positions_cpu[:total_num_scheduled_tokens]
        np.add(self.input_batch.num_computed_tokens_cpu[req_indices],
               arange,
               out=positions_np)

        # Multi-modal support
        # Calculate M-RoPE positions.
        # Only relevant for models using M-RoPE (e.g., Qwen2-VL)
        if self.uses_mrope:
            self.mm_manager.calc_mrope_positions(scheduler_output)

        # Get token indices.
        # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
        # -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2]
        # where M is the max_model_len.
        token_indices = (positions_np +
                         req_indices * self.input_batch.token_ids_cpu.shape[1])

        # NOTE(woosuk): We use torch.index_select instead of np.take here
        # because torch.index_select is much faster than np.take for large
        # tensors.
        np.take(self.input_batch.token_ids_cpu.ravel(),
                token_indices,
                out=self.input_ids_cpu[:total_num_scheduled_tokens])

        # Prepare the attention metadata.
        self.query_start_loc_cpu[0] = 0
        np.cumsum(num_scheduled_tokens_per_req,
                  out=self.query_start_loc_cpu[1:num_reqs + 1])
        self.query_start_loc_cpu[num_reqs + 1:] = 1

        self.seq_lens_cpu[:num_reqs] = (
            self.input_batch.num_computed_tokens_cpu[:num_reqs] +
            num_scheduled_tokens_per_req)

        # Do the padding and copy the tensors to the TPU.
        padded_total_num_scheduled_tokens = runner_utils.get_padded_token_len(
            self.num_tokens_paddings, total_num_scheduled_tokens)
        # Zero out to avoid spurious values from prev iteration (last cp chunk)
        self.input_ids_cpu[
            total_num_scheduled_tokens:padded_total_num_scheduled_tokens] = 0

        # Please see runner_utils.PhasedBasedProfiler for details
        if self.phase_based_profiler:
            batch_composition_stats = runner_utils.get_batch_composition_stats(
                self.input_batch, total_num_scheduled_tokens, num_reqs,
                padded_total_num_scheduled_tokens, scheduler_output)

            self.phase_based_profiler.step(batch_composition_stats)

        # Inputs
        input_ids = self.input_ids_cpu[:padded_total_num_scheduled_tokens]
        positions = self.positions_cpu[:padded_total_num_scheduled_tokens]
        mrope_positions = self.mrope_positions_cpu[:, :
                                                   padded_total_num_scheduled_tokens]
        block_tables = self.block_table_cpu[:self.max_num_reqs]
        block_tables[:num_reqs, :self.max_num_blocks_per_req] = (
            self.input_batch.block_table[0].get_cpu_tensor()[:num_reqs])

        # TODO(pooyam): Some paddings are up to `num_reqs_paddings` (spec decoding, select hidden states, etc.) and some others are up to `max_num_reqs` (block table, seq_lens). We should stick to one of them maybe?
        query_start_loc = self.query_start_loc_cpu[:self.max_num_reqs + 1]
        seq_lens = self.seq_lens_cpu[:self.max_num_reqs]
        request_distribution = np.array(self.input_batch.request_distribution)
        use_spec_decode = len(
            scheduler_output.scheduled_spec_decode_tokens) > 0
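        # Without spec decode, sample logits at the last scheduled token of
        # each request (its query_start_loc end offset minus one).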
        if not use_spec_decode:
            logits_indices = self.query_start_loc_cpu[1:padded_num_reqs +
                                                      1] - 1
            spec_decode_metadata = None
        else:
            num_draft_tokens = np.zeros(num_reqs, dtype=np.int32)
            for req_id, draft_token_ids in (
                    scheduler_output.scheduled_spec_decode_tokens.items()):
                req_idx = self.input_batch.req_id_to_index[req_id]
                num_draft_tokens[req_idx] = len(draft_token_ids)

            spec_decode_metadata = self.speculative_decoding_manager.get_spec_decode_metadata(
                num_draft_tokens, self.query_start_loc_cpu[1:num_reqs + 1],
                padded_num_reqs)
            logits_indices = spec_decode_metadata.final_logits_indices

        # Put to device
        sampling_metadata = TPUSupportedSamplingMetadata.from_input_batch(
            self.mesh, self.input_batch, padded_num_reqs)
        if self.uses_mrope:
            positions = mrope_positions

        # Convert block_tables to 1D on cpu.
        block_tables = block_tables.reshape(-1)

        query_start_loc_cpu = query_start_loc
        seq_lens_cpu = seq_lens
        (input_ids, positions, block_tables, query_start_loc, seq_lens,
         logits_indices, request_distribution) = device_array(
             self.mesh, (input_ids, positions, block_tables, query_start_loc,
                         seq_lens, logits_indices, request_distribution))

        if self.scheduler_config.async_scheduling and len(
                token_in_tpu_cur_input_indices) > 0:
            assert self._pre_async_results is not None
            input_ids = self._apply_async_token_substitution(
                input_ids, token_in_tpu_cur_input_indices,
                token_in_tpu_pre_next_tokens_indices)

        if self.lora_config is not None:
            self.lora_utils.set_active_loras(
                num_scheduled_tokens_per_req, total_num_scheduled_tokens,
                padded_total_num_scheduled_tokens)

        attention_metadata = AttentionMetadata(
            input_positions=positions,
            block_tables=block_tables,
            seq_lens=seq_lens,
            query_start_loc=query_start_loc,
            request_distribution=request_distribution)

        # This is for making these cpu buffers hidden during tracing
        attention_metadata.query_start_loc_cpu = query_start_loc_cpu
        attention_metadata.seq_lens_cpu = seq_lens_cpu
        logits_indices_selector = None
        return (input_ids, attention_metadata, sampling_metadata,
                logits_indices, spec_decode_metadata, logits_indices_selector)

    def _get_input_ids_embeds(self, input_ids: jax.Array,
                              mm_embeds: list[jax.Array]):
        if self.is_multimodal_model:
            inputs_embeds = self.get_input_embeddings_fn(
                self.state,
                input_ids,
                mm_embeds,
            )
            return None, inputs_embeds
        else:
            return input_ids, None

    def take_draft_token_ids(self) -> Optional[DraftTokenIds]:
        return self.speculative_decoding_manager.take_draft_token_ids()

    ###### Local disagg utilities ######

    def get_kv_cache_for_block_ids(
            self,
            block_ids: List[int],
    ) -> List[jax.Array]:
        return self.kv_cache_manager.get_kv_cache_for_block_ids(block_ids)

    def transfer_kv_cache(self,
                          kv_cache_slices: List[jax.Array]) -> List[jax.Array]:
        return self.kv_cache_manager.transfer_kv_cache(kv_cache_slices)

    def insert_request_with_kv_cache(
            self,
            request: "Request",
            kv_cache_slices: List[jax.Array],
            block_ids: List[List[int]],
    ):
        return self.kv_cache_manager.insert_request_with_kv_cache(
            request, kv_cache_slices, block_ids)

    ###### RL framework integration ######

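    # Pushes updated weights from an RL trainer into the live model state:
    # either a caller-provided reshard_fn reshards the incoming PyTree against
    # the current state, or shard_put places each leaf on this runner's mesh
    # while transfer_state_with_mappings applies `mappings`/`transpose_keys`.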
    def _sync_weights(
            self,
            updated_weights: jaxtyping.PyTree,
            mappings: Dict[str, Tuple[str, Tuple[str]]],
            transpose_keys: Dict[str, Tuple[int]],
            reshard_fn: Callable[[jaxtyping.PyTree, jaxtyping.PyTree],
                                 jaxtyping.PyTree] = None
    ) -> None:
        """For RL framework integration."""
        if reshard_fn is not None:
            updated_weights = reshard_fn(updated_weights, self.state)
            shard = None
        else:
            shard = functools.partial(shard_put, mesh=self.mesh)
        self.state = transfer_state_with_mappings(
            src_state=updated_weights,
            tgt_state=self.state,
            mappings=mappings,
            transpose_keys=transpose_keys,
            shard=shard)