tpu-inference 0.11.1.dev202511150811__py3-none-any.whl
This diff shows the content of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in those public registries.
Potentially problematic release: this version of tpu-inference might be problematic.
- tests/__init__.py +0 -0
- tests/core/__init__.py +0 -0
- tests/core/test_core_tpu.py +513 -0
- tests/core/test_disagg_executor.py +60 -0
- tests/core/test_disagg_utils.py +53 -0
- tests/core/test_dp_scheduler.py +899 -0
- tests/core/test_init.py +49 -0
- tests/kernels/__init__.py +0 -0
- tests/kernels/fused_moe_v1_test.py +105 -0
- tests/kernels/mla_v1_test.py +396 -0
- tests/kernels/quantized_matmul_kernel_test.py +191 -0
- tests/kernels/ragged_kv_cache_update_v2_test.py +234 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +400 -0
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +549 -0
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +504 -0
- tests/lora/__init__.py +0 -0
- tests/lora/conftest.py +32 -0
- tests/lora/test_bgmv.py +43 -0
- tests/lora/test_layers.py +654 -0
- tests/lora/test_lora.py +133 -0
- tests/lora/utils.py +96 -0
- tests/test_base.py +201 -0
- tests/test_envs.py +182 -0
- tests/test_quantization.py +836 -0
- tests/test_tpu_info.py +120 -0
- tests/test_utils.py +236 -0
- tpu_inference/__init__.py +34 -0
- tpu_inference/core/__init__.py +0 -0
- tpu_inference/core/core_tpu.py +786 -0
- tpu_inference/core/disagg_executor.py +118 -0
- tpu_inference/core/disagg_utils.py +51 -0
- tpu_inference/core/sched/__init__.py +0 -0
- tpu_inference/core/sched/dp_scheduler.py +523 -0
- tpu_inference/distributed/__init__.py +0 -0
- tpu_inference/distributed/jax_parallel_state.py +67 -0
- tpu_inference/distributed/tpu_connector.py +728 -0
- tpu_inference/distributed/utils.py +59 -0
- tpu_inference/env_override.py +9 -0
- tpu_inference/envs.py +107 -0
- tpu_inference/executors/__init__.py +0 -0
- tpu_inference/executors/ray_distributed_executor.py +362 -0
- tpu_inference/experimental/__init__.py +0 -0
- tpu_inference/experimental/llama3_jax_stashed.py +258 -0
- tpu_inference/kernels/__init__.py +0 -0
- tpu_inference/kernels/collectives/__init__.py +0 -0
- tpu_inference/kernels/collectives/all_gather_matmul.py +735 -0
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +60 -0
- tpu_inference/kernels/collectives/util.py +47 -0
- tpu_inference/kernels/flash_attention/__init__.py +0 -0
- tpu_inference/kernels/flash_attention/kernel.py +772 -0
- tpu_inference/kernels/fused_moe/__init__.py +0 -0
- tpu_inference/kernels/fused_moe/v1/__init__.py +0 -0
- tpu_inference/kernels/fused_moe/v1/kernel.py +1035 -0
- tpu_inference/kernels/mla/__init__.py +0 -0
- tpu_inference/kernels/mla/v1/__init__.py +0 -0
- tpu_inference/kernels/mla/v1/kernel.py +1349 -0
- tpu_inference/kernels/quantized_matmul/__init__.py +0 -0
- tpu_inference/kernels/quantized_matmul/kernel.py +395 -0
- tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
- tpu_inference/kernels/quantized_matmul/util.py +58 -0
- tpu_inference/kernels/ragged_paged_attention/__init__.py +0 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +0 -0
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +875 -0
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +287 -0
- tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +0 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1478 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1482 -0
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4147 -0
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +367 -0
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +51 -0
- tpu_inference/layers/__init__.py +0 -0
- tpu_inference/layers/common/__init__.py +0 -0
- tpu_inference/layers/common/attention_interface.py +390 -0
- tpu_inference/layers/common/attention_metadata.py +34 -0
- tpu_inference/layers/common/binary_search.py +295 -0
- tpu_inference/layers/common/quant_methods.py +8 -0
- tpu_inference/layers/common/sharding.py +582 -0
- tpu_inference/layers/jax/__init__.py +0 -0
- tpu_inference/layers/jax/attention/__init__.py +0 -0
- tpu_inference/layers/jax/attention/attention.py +255 -0
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +354 -0
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +262 -0
- tpu_inference/layers/jax/attention/llama4_attention.py +153 -0
- tpu_inference/layers/jax/base.py +151 -0
- tpu_inference/layers/jax/constants.py +88 -0
- tpu_inference/layers/jax/layers.py +301 -0
- tpu_inference/layers/jax/misc.py +16 -0
- tpu_inference/layers/jax/moe/__init__.py +0 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +608 -0
- tpu_inference/layers/jax/moe/gpt_oss_moe.py +185 -0
- tpu_inference/layers/jax/moe/moe.py +209 -0
- tpu_inference/layers/jax/rope.py +280 -0
- tpu_inference/layers/jax/rope_interface.py +214 -0
- tpu_inference/layers/jax/sample/__init__.py +0 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +515 -0
- tpu_inference/layers/jax/sample/sampling.py +96 -0
- tpu_inference/layers/jax/sample/sampling_metadata.py +76 -0
- tpu_inference/layers/jax/transformer_block.py +107 -0
- tpu_inference/layers/vllm/__init__.py +0 -0
- tpu_inference/layers/vllm/attention.py +221 -0
- tpu_inference/layers/vllm/fused_moe.py +507 -0
- tpu_inference/layers/vllm/linear_common.py +186 -0
- tpu_inference/layers/vllm/quantization/__init__.py +39 -0
- tpu_inference/layers/vllm/quantization/awq.py +207 -0
- tpu_inference/layers/vllm/quantization/common.py +105 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +0 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +120 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +203 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +0 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +208 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +136 -0
- tpu_inference/layers/vllm/quantization/mxfp4.py +266 -0
- tpu_inference/layers/vllm/quantization/unquantized.py +386 -0
- tpu_inference/layers/vllm/sharding.py +230 -0
- tpu_inference/logger.py +10 -0
- tpu_inference/lora/__init__.py +0 -0
- tpu_inference/lora/torch_lora_ops.py +103 -0
- tpu_inference/lora/torch_punica_tpu.py +311 -0
- tpu_inference/mock/__init__.py +0 -0
- tpu_inference/mock/vllm_config_utils.py +28 -0
- tpu_inference/mock/vllm_envs.py +1219 -0
- tpu_inference/mock/vllm_logger.py +212 -0
- tpu_inference/mock/vllm_logging_utils.py +15 -0
- tpu_inference/models/__init__.py +0 -0
- tpu_inference/models/common/__init__.py +0 -0
- tpu_inference/models/common/model_loader.py +444 -0
- tpu_inference/models/jax/__init__.py +0 -0
- tpu_inference/models/jax/deepseek_v3.py +868 -0
- tpu_inference/models/jax/gpt_oss.py +492 -0
- tpu_inference/models/jax/jax_intermediate_tensor.py +79 -0
- tpu_inference/models/jax/llama3.py +375 -0
- tpu_inference/models/jax/llama4.py +629 -0
- tpu_inference/models/jax/llama_eagle3.py +333 -0
- tpu_inference/models/jax/phi3.py +376 -0
- tpu_inference/models/jax/qwen2.py +375 -0
- tpu_inference/models/jax/qwen2_5_vl.py +1103 -0
- tpu_inference/models/jax/qwen3.py +302 -0
- tpu_inference/models/jax/utils/__init__.py +0 -0
- tpu_inference/models/jax/utils/file_utils.py +96 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +163 -0
- tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
- tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +5 -0
- tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +6 -0
- tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +5 -0
- tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +6 -0
- tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +105 -0
- tpu_inference/models/jax/utils/quantization/quantization_utils.py +653 -0
- tpu_inference/models/jax/utils/weight_utils.py +529 -0
- tpu_inference/models/vllm/__init__.py +0 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +286 -0
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +45 -0
- tpu_inference/platforms/__init__.py +2 -0
- tpu_inference/platforms/tpu_platform.py +269 -0
- tpu_inference/runner/__init__.py +0 -0
- tpu_inference/runner/block_table.py +122 -0
- tpu_inference/runner/compilation_manager.py +780 -0
- tpu_inference/runner/input_batch.py +435 -0
- tpu_inference/runner/kv_cache.py +132 -0
- tpu_inference/runner/kv_cache_manager.py +479 -0
- tpu_inference/runner/lora_utils.py +92 -0
- tpu_inference/runner/multimodal_manager.py +217 -0
- tpu_inference/runner/persistent_batch_manager.py +244 -0
- tpu_inference/runner/speculative_decoding_manager.py +248 -0
- tpu_inference/runner/structured_decoding_manager.py +88 -0
- tpu_inference/runner/tpu_runner.py +1620 -0
- tpu_inference/runner/utils.py +426 -0
- tpu_inference/spec_decode/__init__.py +0 -0
- tpu_inference/spec_decode/jax/__init__.py +0 -0
- tpu_inference/spec_decode/jax/eagle3.py +367 -0
- tpu_inference/tpu_info.py +77 -0
- tpu_inference/utils.py +317 -0
- tpu_inference/worker/__init__.py +0 -0
- tpu_inference/worker/tpu_worker.py +321 -0
- tpu_inference-0.11.1.dev202511150811.dist-info/METADATA +107 -0
- tpu_inference-0.11.1.dev202511150811.dist-info/RECORD +179 -0
- tpu_inference-0.11.1.dev202511150811.dist-info/WHEEL +5 -0
- tpu_inference-0.11.1.dev202511150811.dist-info/licenses/LICENSE +201 -0
- tpu_inference-0.11.1.dev202511150811.dist-info/top_level.txt +2 -0

tpu_inference/runner/multimodal_manager.py (new file, +217 lines):

```diff
@@ -0,0 +1,217 @@
+from typing import TYPE_CHECKING
+
+import jax
+import jax.numpy as jnp
+from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
+from vllm.multimodal.inputs import MultiModalKwargsItem, PlaceholderRange
+from vllm.multimodal.utils import group_mm_kwargs_by_modality
+from vllm.v1.core.sched.output import SchedulerOutput as VllmSchedulerOutput
+from vllm.v1.worker.utils import (gather_mm_placeholders,
+                                  scatter_mm_placeholders)
+
+from tpu_inference.models.jax.utils.multi_modal_utils import (
+    flatten_embeddings, sanity_check_mm_encoder_outputs)
+
+if TYPE_CHECKING:
+    from tpu_inference.runner.tpu_runner import TPUModelRunner
+
+
+class MultiModalManager:
+
+    def __init__(self, runner: "TPUModelRunner"):
+        self.runner = runner
+
+    def calc_mrope_positions(self, scheduler_output: "VllmSchedulerOutput"):
+        mrope_pos_ptr = 0
+        for index, req_id in enumerate(self.runner.input_batch.req_ids):
+            req = self.runner.requests[req_id]
+            assert req.mrope_positions is not None
+
+            num_computed_tokens = \
+                self.runner.input_batch.num_computed_tokens_cpu[index]
+            num_scheduled_tokens = \
+                scheduler_output.num_scheduled_tokens[req_id]
+            num_prompt_tokens = len(req.prompt_token_ids)
+
+            if num_computed_tokens + num_scheduled_tokens > num_prompt_tokens:
+                prompt_part_len = max(0,
+                                      num_prompt_tokens - num_computed_tokens)
+                completion_part_len = max(
+                    0, num_scheduled_tokens - prompt_part_len)
+            else:
+                prompt_part_len = num_scheduled_tokens
+                completion_part_len = 0
+
+            assert num_scheduled_tokens == prompt_part_len + completion_part_len
+
+            if prompt_part_len > 0:
+                # prompt's mrope_positions are pre-computed
+                dst_start = mrope_pos_ptr
+                dst_end = mrope_pos_ptr + prompt_part_len
+                src_start = num_computed_tokens
+                src_end = num_computed_tokens + prompt_part_len
+
+                self.runner.mrope_positions_cpu[:, dst_start:dst_end] = \
+                    req.mrope_positions[:, src_start:src_end]
+
+                mrope_pos_ptr += prompt_part_len
+
+            if completion_part_len > 0:
+                # compute completion's mrope_positions on-the-fly
+                dst_start = mrope_pos_ptr
+                dst_end = mrope_pos_ptr + completion_part_len
+
+                MRotaryEmbedding.get_next_input_positions_tensor(
+                    out=self.runner.mrope_positions_cpu,
+                    out_offset=dst_start,
+                    mrope_position_delta=req.mrope_position_delta,
+                    context_len=num_computed_tokens + prompt_part_len,
+                    num_new_tokens=completion_part_len,
+                )
+
+                mrope_pos_ptr += completion_part_len
+
+    def execute_mm_encoder(self, scheduler_output: "VllmSchedulerOutput"):
+        import torch
+        scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs
+        if not scheduled_encoder_inputs:
+            return
+
+        # Batch the multi-modal inputs.
+        mm_kwargs = list[MultiModalKwargsItem]()
+        # List of tuple (mm_hash, pos_info)
+        mm_hashes_pos = list[tuple[str, PlaceholderRange]]()
+        for req_id, encoder_input_ids in scheduled_encoder_inputs.items():
+            req_state = self.runner.requests[req_id]
+            for mm_input_id in encoder_input_ids:
+                mm_feature = req_state.mm_features[mm_input_id]
+                mm_hash = mm_feature.identifier
+                mm_kwargs.append(mm_feature.data)
+                mm_hashes_pos.append((mm_hash, mm_feature.mm_position))
+
+        # Batch mm inputs as much as we can: if a request in the batch has
+        # multiple modalities or a different modality than the previous one,
+        # we process it separately to preserve item order.
+        # FIXME(ywang96): This is a hacky way to deal with multiple modalities
+        # in the same batch while still being able to benefit from batching
+        # multimodal inputs. The proper solution should be reordering the
+        # encoder outputs.
+        encoder_outputs = []
+        for _, num_items, mm_kwargs_group in group_mm_kwargs_by_modality(
+                mm_kwargs, merge_by_field_config=False):
+            batched_mm_inputs = mm_kwargs_group
+            # Convert torch tensors to numpy arrays that JAX can handle.
+            if "pixel_values" in batched_mm_inputs and isinstance(
+                    batched_mm_inputs["pixel_values"], list):
+                batched_mm_inputs["pixel_values"] = torch.cat(
+                    batched_mm_inputs["pixel_values"], dim=0)
+
+            image_grid_thw = ()
+            for key, value in batched_mm_inputs.items():
+                if isinstance(value, torch.Tensor):
+                    if key == 'image_grid_thw':
+                        # change it to tuple of tuples to make it hashable for JIT
+
+                        # Shape: (B, N, 3) -> (B*N, 3) -> tuple of tuples
+                        grid_thw_tensor = batched_mm_inputs[key]
+                        grid_thw_reshaped = grid_thw_tensor.reshape(-1, 3)
+                        image_grid_thw = tuple(
+                            tuple(row) for row in grid_thw_reshaped.tolist())
+
+                        continue
+
+                    if value.dtype == torch.bfloat16:
+                        batched_mm_inputs[key] = value.to(
+                            torch.float32).numpy().astype(jnp.bfloat16)
+                    else:
+                        batched_mm_inputs[key] = value.numpy()
+            batched_mm_inputs.pop('image_grid_thw')
+
+            # Run the encoder.
+            # `curr_group_outputs` is either of the following:
+            # 1. A tensor of shape (num_items, feature_size, hidden_size)
+            # in case feature_size is fixed across all multimodal items.
+            # 2. A list or tuple (length: num_items) of tensors, each of shape
+            # (feature_size, hidden_size) in case the feature size is dynamic
+            # depending on the input multimodal items.
+            curr_group_outputs = self.runner.get_multimodal_embeddings_fn(
+                self.runner.state, image_grid_thw, **batched_mm_inputs)
+
+            sanity_check_mm_encoder_outputs(
+                curr_group_outputs,
+                expected_num_items=num_items,
+            )
+
+            for output in curr_group_outputs:
+                encoder_outputs.append(output)
+
+        # Cache the encoder outputs.
+        for (mm_hash, pos_info), output in zip(
+                mm_hashes_pos,
+                encoder_outputs,
+        ):
+            if req_id not in self.runner.encoder_cache:
+                self.runner.encoder_cache[req_id] = {}
+
+            self.runner.encoder_cache[mm_hash] = scatter_mm_placeholders(
+                output,
+                is_embed=pos_info.is_embed,
+            )
+
+    def gather_mm_embeddings(self, scheduler_output: "VllmSchedulerOutput",
+                             target_pad_len: int) -> list[jax.Array]:
+        mm_embeds: list[jax.Array] = []
+        for req_id in self.runner.input_batch.req_ids:
+            num_scheduled_tokens = scheduler_output.num_scheduled_tokens[
+                req_id]
+            req_state = self.runner.requests[req_id]
+            num_computed_tokens = req_state.num_computed_tokens
+            mm_features = req_state.mm_features
+            for _, mm_feature in enumerate(mm_features):
+                pos_info = mm_feature.mm_position
+                start_pos = pos_info.offset
+                num_encoder_tokens = pos_info.length
+
+                # The encoder output is needed if the two ranges overlap:
+                # [num_computed_tokens,
+                #  num_computed_tokens + num_scheduled_tokens) and
+                # [start_pos, start_pos + num_encoder_tokens)
+                if start_pos >= num_computed_tokens + num_scheduled_tokens:
+                    # The encoder output is not needed in this step.
+                    break
+                if start_pos + num_encoder_tokens <= num_computed_tokens:
+                    # The encoder output is already processed and stored
+                    # in the decoder's KV cache.
+                    continue
+
+                start_idx = max(num_computed_tokens - start_pos, 0)
+                end_idx = min(
+                    num_computed_tokens - start_pos + num_scheduled_tokens,
+                    num_encoder_tokens)
+                assert start_idx < end_idx
+                mm_hash = mm_feature.identifier
+                encoder_output = self.runner.encoder_cache.get(mm_hash, None)
+                assert encoder_output is not None,\
+                    f"Encoder cache miss for {mm_hash}."
+                encoder_output = self.runner.encoder_cache[mm_hash]
+
+                if (is_embed := pos_info.is_embed) is not None:
+                    is_embed = is_embed[start_idx:end_idx]
+
+                mm_embeds_item = gather_mm_placeholders(
+                    encoder_output[start_idx:end_idx],
+                    is_embed=is_embed,
+                )
+                mm_embeds.append(mm_embeds_item)
+        if not mm_embeds:
+            return None
+        flattened_embeds = flatten_embeddings(mm_embeds)
+        if flattened_embeds.shape[0] == 0:
+            return None
+
+        padding = jnp.zeros((target_pad_len - flattened_embeds.shape[0],
+                             flattened_embeds.shape[1]),
+                            dtype=flattened_embeds.dtype)
+        flattened_embeds = jnp.concatenate([flattened_embeds, padding], axis=0)
+
+        return flattened_embeds
```
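
The core of `gather_mm_embeddings` above is a range-overlap computation: for each multimodal item, it intersects the token window scheduled this step with the item's placeholder range and slices only the needed portion of the cached encoder output. A minimal standalone sketch of that window arithmetic (the helper name `needed_encoder_slice` is hypothetical and not part of the package):

```python
# Standalone sketch (hypothetical helper) of the window computation used in
# gather_mm_embeddings: given the token range scheduled this step and the
# placeholder range of one multimodal item, return the slice of the cached
# encoder output that is needed now, or None if nothing is needed.
def needed_encoder_slice(num_computed_tokens: int,
                         num_scheduled_tokens: int,
                         start_pos: int,
                         num_encoder_tokens: int):
    # Scheduled window: [num_computed_tokens,
    #                    num_computed_tokens + num_scheduled_tokens)
    # Encoder window:   [start_pos, start_pos + num_encoder_tokens)
    if start_pos >= num_computed_tokens + num_scheduled_tokens:
        return None  # Item starts after this step; nothing is needed yet.
    if start_pos + num_encoder_tokens <= num_computed_tokens:
        return None  # Item was fully consumed and now lives in the KV cache.
    start_idx = max(num_computed_tokens - start_pos, 0)
    end_idx = min(num_computed_tokens - start_pos + num_scheduled_tokens,
                  num_encoder_tokens)
    return start_idx, end_idx

# Example: an 8-token placeholder starting at prompt position 10, with 12
# tokens already computed and 6 scheduled this step -> slice (2, 8).
assert needed_encoder_slice(12, 6, 10, 8) == (2, 8)
```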
tpu_inference/runner/persistent_batch_manager.py (new file, +244 lines):

```diff
@@ -0,0 +1,244 @@
+from typing import Dict
+
+import jax
+from vllm.v1.core.sched.output import SchedulerOutput as VllmSchedulerOutput
+
+from tpu_inference.logger import init_logger
+from tpu_inference.runner.input_batch import CachedRequestState, InputBatch
+
+logger = init_logger(__name__)
+
+
+class PersistentBatchManager:
+
+    def __init__(self, requests: Dict[str, CachedRequestState],
+                 input_batch: InputBatch, encoder_cache: Dict[str,
+                                                              'jax.Array'],
+                 uses_mrope: bool, model_config):
+        self.requests = requests
+        self.input_batch = input_batch
+        self.encoder_cache = encoder_cache
+        self.uses_mrope = uses_mrope
+        self.model_config = model_config
+
+    def _reorder_batch(self, scheduler_output: "VllmSchedulerOutput") -> int:
+        """ Reorder the sheduled requests to RPA kernel friendly distribution
+        (decode_only, fixed_chunked_prefill_only, mixed) and set the request
+        distribution accordingly.
+
+        Returns:
+            The number of swaps in requests.
+        """
+        # Note(jevinjiang): currently we only consider decode_only.
+        num_reqs = self.input_batch.num_reqs
+        swap_cnt = 0
+        if num_reqs <= 0:
+            return swap_cnt
+        # Use two-pointer approach to reorder the decode requests to front.
+        i, j = 0, num_reqs - 1
+        while i < j:
+            i_req_id = self.input_batch.req_ids[i]
+            j_req_id = self.input_batch.req_ids[j]
+
+            if scheduler_output.num_scheduled_tokens[i_req_id] == 1:
+                # i is a decode request, move to the next one.
+                i += 1
+            elif scheduler_output.num_scheduled_tokens[j_req_id] > 1:
+                # j is a prefill request, move to the previous one.
+                j -= 1
+            else:
+                # Swap i and j.
+                self.input_batch.swap_states(i, j)
+                i += 1
+                j -= 1
+                swap_cnt += 1
+
+        num_decode = i + int(scheduler_output.num_scheduled_tokens[
+            self.input_batch.req_ids[i]] == 1)
+
+        self.input_batch.request_distribution = [
+            num_decode, num_decode, num_reqs
+        ]
+
+        return swap_cnt
+
+    def update_states(self, scheduler_output: "VllmSchedulerOutput",
+                      get_mrope_input_positions_fn) -> bool:
+        """Update the cached states and the persistent batch with the scheduler
+        output.
+
+        The updated states are used by the `_prepare_inputs` function to create
+        the input TPU tensors for the model.
+
+        Returns:
+            True if there is a new/resumed/paused/finished request.
+            If False, we can skip copying SamplingMetadata to the TPU.
+        """
+        # Remove finished requests from the cached states.
+        for req_id in scheduler_output.finished_req_ids:
+            self.requests.pop(req_id, None)
+
+        # Remove the finished requests from the persistent batch.
+        # NOTE(woosuk): There could be an edge case where finished_req_ids and
+        # scheduled_req_ids overlap. This happens when a request is aborted and
+        # then resubmitted with the same ID. In this case, we treat them as two
+        # distinct requests - clearing the cached states for the first request
+        # and handling the second as a new request.
+        removed_req_indices: list[int] = []
+        for req_id in scheduler_output.finished_req_ids:
+            req_index = self.input_batch.remove_request(req_id)
+            if req_index is not None:
+                removed_req_indices.append(req_index)
+
+        # Free the cached encoder outputs.
+        for mm_hash in scheduler_output.free_encoder_mm_hashes:
+            self.encoder_cache.pop(mm_hash, None)
+
+        # Remove the unscheduled requests from the persistent batch.
+        # NOTE(woosuk): The unscheduled requests are either preempted requests
+        # or running requests that are not scheduled in this step. We remove
+        # them from the persistent batch but keep their cached states since
+        # they will be scheduled again sometime in the future.
+        scheduled_req_ids = scheduler_output.num_scheduled_tokens.keys()
+        cached_req_ids = self.input_batch.req_id_to_index.keys()
+        unscheduled_req_ids = cached_req_ids - scheduled_req_ids
+        # NOTE(woosuk): The persistent batch optimization assumes that
+        # consecutive batches contain mostly the same requests. If batches
+        # have low request overlap (e.g., alternating between two distinct
+        # sets of requests), this optimization becomes very inefficient.
+        for req_id in unscheduled_req_ids:
+            req_index = self.input_batch.remove_request(req_id)
+            assert req_index is not None
+            removed_req_indices.append(req_index)
+
+        req_ids_to_add: list[str] = []
+        # Add new requests to the cached states.
+        for new_req_data in scheduler_output.scheduled_new_reqs:
+            req_id = new_req_data.req_id
+            sampling_params = new_req_data.sampling_params
+
+            self.requests[req_id] = CachedRequestState(
+                req_id=req_id,
+                prompt_token_ids=new_req_data.prompt_token_ids,
+                mm_features=new_req_data.mm_features,
+                sampling_params=sampling_params,
+                pooling_params=None,
+                generator=None,
+                block_ids=new_req_data.block_ids,
+                num_computed_tokens=new_req_data.num_computed_tokens,
+                output_token_ids=[],
+                lora_request=new_req_data.lora_request,
+            )
+
+            req_ids_to_add.append(req_id)
+
+            # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
+            if self.uses_mrope:
+                image_grid_thw = []
+                video_grid_thw = []
+                second_per_grid_ts = []
+                audio_feature_lengths = []
+                use_audio_in_video = False
+                for mm_feature in self.requests[req_id].mm_features:
+                    item = mm_feature.data
+                    if item is None:
+                        continue
+                    mm_input = item.get_data()
+                    if mm_input.get("image_grid_thw") is not None:
+                        image_grid_thw.append(
+                            mm_input["image_grid_thw"].tolist())
+                    if mm_input.get("video_grid_thw") is not None:
+                        video_grid_thw.append(
+                            mm_input["video_grid_thw"].tolist())
+                    if mm_input.get("second_per_grid_ts") is not None:
+                        second_per_grid_ts.append(
+                            mm_input["second_per_grid_ts"])
+                    if mm_input.get("audio_feature_lengths") is not None:
+                        audio_feature_lengths.append(
+                            mm_input["audio_feature_lengths"])
+                    if mm_input.get("use_audio_in_video") is True:
+                        use_audio_in_video = True
+
+                hf_config = self.model_config.hf_config
+
+                self.requests[req_id].mrope_positions, self.requests[
+                    req_id].mrope_position_delta = get_mrope_input_positions_fn(
+                        self.requests[req_id].prompt_token_ids,
+                        hf_config=hf_config,
+                        image_grid_thw=image_grid_thw,
+                        video_grid_thw=video_grid_thw,
+                        second_per_grid_ts=second_per_grid_ts,
+                        audio_feature_lengths=audio_feature_lengths,
+                        use_audio_in_video=use_audio_in_video,
+                    )
+
+        # Update the states of the running/resumed requests.
+        req_data = scheduler_output.scheduled_cached_reqs
+        for i, req_id in enumerate(req_data.req_ids):
+            req_state = self.requests[req_id]
+            num_computed_tokens = req_data.num_computed_tokens[i]
+            new_block_ids = req_data.new_block_ids[i]
+            resumed_from_preemption = req_data.resumed_from_preemption[i]
+
+            # Update the cached states.
+            req_state.num_computed_tokens = num_computed_tokens
+            if not resumed_from_preemption:
+                if new_block_ids is not None:
+                    # Append the new blocks to the existing block IDs.
+                    for block_ids, new_ids in zip(req_state.block_ids,
+                                                  new_block_ids):
+                        block_ids.extend(new_ids)
+            else:
+                assert new_block_ids is not None
+                # The request is resumed from preemption.
+                # Replace the existing block IDs with the new ones.
+                req_state.block_ids = new_block_ids
+
+            req_index = self.input_batch.req_id_to_index.get(req_id)
+            if req_index is None:
+                # The request is not in the persistent batch.
+                # The request was either preempted and resumed later, or was not
+                # scheduled in the previous step and needs to be added again.
+                req_ids_to_add.append(req_id)
+                continue
+
+            # Update the persistent batch.
+            self.input_batch.num_computed_tokens_cpu[
+                req_index] = num_computed_tokens
+            if new_block_ids is not None:
+                self.input_batch.block_table.append_row(
+                    new_block_ids, req_index)
+
+            # Add spec_token_ids to token_ids_cpu.
+            spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get(
+                req_id, ())
+            if spec_token_ids:
+                num_spec_tokens = len(spec_token_ids)
+                start_index = self.input_batch.num_tokens_no_spec[req_index]
+                end_token_index = start_index + num_spec_tokens
+                self.input_batch.token_ids_cpu[
+                    req_index, start_index:end_token_index] = spec_token_ids
+                # NOTE(woosuk): `num_tokens` here may include spec tokens.
+                self.input_batch.num_tokens[req_index] += num_spec_tokens
+
+        # Add the new or resumed requests to the persistent batch.
+        # The smaller empty indices are filled first.
+        removed_req_indices = sorted(removed_req_indices, reverse=True)
+        for req_id in req_ids_to_add:
+            req_state = self.requests[req_id]
+            if removed_req_indices:
+                # Fill the empty index.
+                req_index = removed_req_indices.pop()
+            else:
+                # Append to the end.
+                req_index = None
+            self.input_batch.add_request(req_state, req_index)
+
+        # Condense the batched states if there are empty indices.
+        if removed_req_indices:
+            self.input_batch.condense(removed_req_indices)
+
+        batch_changed = len(unscheduled_req_ids) > 0 or len(req_ids_to_add) > 0
+        # TODO(jevinjiang): I assume we do not need to set batch_changed to true if just swapping requests.
+        self._reorder_batch(scheduler_output)
+        return batch_changed
```
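
`_reorder_batch` above depends only on the per-request scheduled-token counts: requests with exactly one scheduled token (decodes) are moved to the front with a two-pointer sweep so the ragged-paged-attention kernel sees a decode-first request distribution. A minimal standalone sketch of that sweep on a plain list (the helper name `reorder_decodes_first` is hypothetical and not part of the package):

```python
# Standalone sketch (hypothetical helper) of the two-pointer reordering in
# PersistentBatchManager._reorder_batch: move decode requests (exactly one
# scheduled token) to the front, count the swaps, and report how many leading
# requests are decode-only.
def reorder_decodes_first(
        scheduled_tokens: list[int]) -> tuple[list[int], int, int]:
    num_reqs = len(scheduled_tokens)
    swap_cnt = 0
    if num_reqs == 0:
        return scheduled_tokens, 0, 0
    i, j = 0, num_reqs - 1
    while i < j:
        if scheduled_tokens[i] == 1:    # i is already a decode request
            i += 1
        elif scheduled_tokens[j] > 1:   # j is already a prefill request
            j -= 1
        else:                           # swap the prefill at i with the decode at j
            scheduled_tokens[i], scheduled_tokens[j] = (scheduled_tokens[j],
                                                        scheduled_tokens[i])
            i += 1
            j -= 1
            swap_cnt += 1
    num_decode = i + int(scheduled_tokens[i] == 1)
    return scheduled_tokens, num_decode, swap_cnt

# Example: two prefills mixed with three decodes.
print(reorder_decodes_first([5, 1, 1, 7, 1]))   # ([1, 1, 1, 7, 5], 3, 1)
```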