tpu-inference 0.11.1.dev202511150811__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tpu-inference might be problematic.
- tests/__init__.py +0 -0
- tests/core/__init__.py +0 -0
- tests/core/test_core_tpu.py +513 -0
- tests/core/test_disagg_executor.py +60 -0
- tests/core/test_disagg_utils.py +53 -0
- tests/core/test_dp_scheduler.py +899 -0
- tests/core/test_init.py +49 -0
- tests/kernels/__init__.py +0 -0
- tests/kernels/fused_moe_v1_test.py +105 -0
- tests/kernels/mla_v1_test.py +396 -0
- tests/kernels/quantized_matmul_kernel_test.py +191 -0
- tests/kernels/ragged_kv_cache_update_v2_test.py +234 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +400 -0
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +549 -0
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +504 -0
- tests/lora/__init__.py +0 -0
- tests/lora/conftest.py +32 -0
- tests/lora/test_bgmv.py +43 -0
- tests/lora/test_layers.py +654 -0
- tests/lora/test_lora.py +133 -0
- tests/lora/utils.py +96 -0
- tests/test_base.py +201 -0
- tests/test_envs.py +182 -0
- tests/test_quantization.py +836 -0
- tests/test_tpu_info.py +120 -0
- tests/test_utils.py +236 -0
- tpu_inference/__init__.py +34 -0
- tpu_inference/core/__init__.py +0 -0
- tpu_inference/core/core_tpu.py +786 -0
- tpu_inference/core/disagg_executor.py +118 -0
- tpu_inference/core/disagg_utils.py +51 -0
- tpu_inference/core/sched/__init__.py +0 -0
- tpu_inference/core/sched/dp_scheduler.py +523 -0
- tpu_inference/distributed/__init__.py +0 -0
- tpu_inference/distributed/jax_parallel_state.py +67 -0
- tpu_inference/distributed/tpu_connector.py +728 -0
- tpu_inference/distributed/utils.py +59 -0
- tpu_inference/env_override.py +9 -0
- tpu_inference/envs.py +107 -0
- tpu_inference/executors/__init__.py +0 -0
- tpu_inference/executors/ray_distributed_executor.py +362 -0
- tpu_inference/experimental/__init__.py +0 -0
- tpu_inference/experimental/llama3_jax_stashed.py +258 -0
- tpu_inference/kernels/__init__.py +0 -0
- tpu_inference/kernels/collectives/__init__.py +0 -0
- tpu_inference/kernels/collectives/all_gather_matmul.py +735 -0
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +60 -0
- tpu_inference/kernels/collectives/util.py +47 -0
- tpu_inference/kernels/flash_attention/__init__.py +0 -0
- tpu_inference/kernels/flash_attention/kernel.py +772 -0
- tpu_inference/kernels/fused_moe/__init__.py +0 -0
- tpu_inference/kernels/fused_moe/v1/__init__.py +0 -0
- tpu_inference/kernels/fused_moe/v1/kernel.py +1035 -0
- tpu_inference/kernels/mla/__init__.py +0 -0
- tpu_inference/kernels/mla/v1/__init__.py +0 -0
- tpu_inference/kernels/mla/v1/kernel.py +1349 -0
- tpu_inference/kernels/quantized_matmul/__init__.py +0 -0
- tpu_inference/kernels/quantized_matmul/kernel.py +395 -0
- tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
- tpu_inference/kernels/quantized_matmul/util.py +58 -0
- tpu_inference/kernels/ragged_paged_attention/__init__.py +0 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +0 -0
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +875 -0
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +287 -0
- tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +0 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1478 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1482 -0
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4147 -0
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +367 -0
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +51 -0
- tpu_inference/layers/__init__.py +0 -0
- tpu_inference/layers/common/__init__.py +0 -0
- tpu_inference/layers/common/attention_interface.py +390 -0
- tpu_inference/layers/common/attention_metadata.py +34 -0
- tpu_inference/layers/common/binary_search.py +295 -0
- tpu_inference/layers/common/quant_methods.py +8 -0
- tpu_inference/layers/common/sharding.py +582 -0
- tpu_inference/layers/jax/__init__.py +0 -0
- tpu_inference/layers/jax/attention/__init__.py +0 -0
- tpu_inference/layers/jax/attention/attention.py +255 -0
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +354 -0
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +262 -0
- tpu_inference/layers/jax/attention/llama4_attention.py +153 -0
- tpu_inference/layers/jax/base.py +151 -0
- tpu_inference/layers/jax/constants.py +88 -0
- tpu_inference/layers/jax/layers.py +301 -0
- tpu_inference/layers/jax/misc.py +16 -0
- tpu_inference/layers/jax/moe/__init__.py +0 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +608 -0
- tpu_inference/layers/jax/moe/gpt_oss_moe.py +185 -0
- tpu_inference/layers/jax/moe/moe.py +209 -0
- tpu_inference/layers/jax/rope.py +280 -0
- tpu_inference/layers/jax/rope_interface.py +214 -0
- tpu_inference/layers/jax/sample/__init__.py +0 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +515 -0
- tpu_inference/layers/jax/sample/sampling.py +96 -0
- tpu_inference/layers/jax/sample/sampling_metadata.py +76 -0
- tpu_inference/layers/jax/transformer_block.py +107 -0
- tpu_inference/layers/vllm/__init__.py +0 -0
- tpu_inference/layers/vllm/attention.py +221 -0
- tpu_inference/layers/vllm/fused_moe.py +507 -0
- tpu_inference/layers/vllm/linear_common.py +186 -0
- tpu_inference/layers/vllm/quantization/__init__.py +39 -0
- tpu_inference/layers/vllm/quantization/awq.py +207 -0
- tpu_inference/layers/vllm/quantization/common.py +105 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +0 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +120 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +203 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +0 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +208 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +136 -0
- tpu_inference/layers/vllm/quantization/mxfp4.py +266 -0
- tpu_inference/layers/vllm/quantization/unquantized.py +386 -0
- tpu_inference/layers/vllm/sharding.py +230 -0
- tpu_inference/logger.py +10 -0
- tpu_inference/lora/__init__.py +0 -0
- tpu_inference/lora/torch_lora_ops.py +103 -0
- tpu_inference/lora/torch_punica_tpu.py +311 -0
- tpu_inference/mock/__init__.py +0 -0
- tpu_inference/mock/vllm_config_utils.py +28 -0
- tpu_inference/mock/vllm_envs.py +1219 -0
- tpu_inference/mock/vllm_logger.py +212 -0
- tpu_inference/mock/vllm_logging_utils.py +15 -0
- tpu_inference/models/__init__.py +0 -0
- tpu_inference/models/common/__init__.py +0 -0
- tpu_inference/models/common/model_loader.py +444 -0
- tpu_inference/models/jax/__init__.py +0 -0
- tpu_inference/models/jax/deepseek_v3.py +868 -0
- tpu_inference/models/jax/gpt_oss.py +492 -0
- tpu_inference/models/jax/jax_intermediate_tensor.py +79 -0
- tpu_inference/models/jax/llama3.py +375 -0
- tpu_inference/models/jax/llama4.py +629 -0
- tpu_inference/models/jax/llama_eagle3.py +333 -0
- tpu_inference/models/jax/phi3.py +376 -0
- tpu_inference/models/jax/qwen2.py +375 -0
- tpu_inference/models/jax/qwen2_5_vl.py +1103 -0
- tpu_inference/models/jax/qwen3.py +302 -0
- tpu_inference/models/jax/utils/__init__.py +0 -0
- tpu_inference/models/jax/utils/file_utils.py +96 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +163 -0
- tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
- tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +5 -0
- tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +6 -0
- tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +5 -0
- tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +6 -0
- tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +105 -0
- tpu_inference/models/jax/utils/quantization/quantization_utils.py +653 -0
- tpu_inference/models/jax/utils/weight_utils.py +529 -0
- tpu_inference/models/vllm/__init__.py +0 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +286 -0
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +45 -0
- tpu_inference/platforms/__init__.py +2 -0
- tpu_inference/platforms/tpu_platform.py +269 -0
- tpu_inference/runner/__init__.py +0 -0
- tpu_inference/runner/block_table.py +122 -0
- tpu_inference/runner/compilation_manager.py +780 -0
- tpu_inference/runner/input_batch.py +435 -0
- tpu_inference/runner/kv_cache.py +132 -0
- tpu_inference/runner/kv_cache_manager.py +479 -0
- tpu_inference/runner/lora_utils.py +92 -0
- tpu_inference/runner/multimodal_manager.py +217 -0
- tpu_inference/runner/persistent_batch_manager.py +244 -0
- tpu_inference/runner/speculative_decoding_manager.py +248 -0
- tpu_inference/runner/structured_decoding_manager.py +88 -0
- tpu_inference/runner/tpu_runner.py +1620 -0
- tpu_inference/runner/utils.py +426 -0
- tpu_inference/spec_decode/__init__.py +0 -0
- tpu_inference/spec_decode/jax/__init__.py +0 -0
- tpu_inference/spec_decode/jax/eagle3.py +367 -0
- tpu_inference/tpu_info.py +77 -0
- tpu_inference/utils.py +317 -0
- tpu_inference/worker/__init__.py +0 -0
- tpu_inference/worker/tpu_worker.py +321 -0
- tpu_inference-0.11.1.dev202511150811.dist-info/METADATA +107 -0
- tpu_inference-0.11.1.dev202511150811.dist-info/RECORD +179 -0
- tpu_inference-0.11.1.dev202511150811.dist-info/WHEEL +5 -0
- tpu_inference-0.11.1.dev202511150811.dist-info/licenses/LICENSE +201 -0
- tpu_inference-0.11.1.dev202511150811.dist-info/top_level.txt +2 -0
tpu_inference/runner/input_batch.py
@@ -0,0 +1,435 @@
+# SPDX-License-Identifier: Apache-2.0
+# Datastructures defining an input batch
+
+from dataclasses import dataclass
+from typing import Any, Optional, cast
+
+import jax
+import jax.numpy as jnp
+import numpy as np
+from vllm.lora.request import LoRARequest
+from vllm.sampling_params import SamplingType
+from vllm.utils.collection_utils import swap_dict_values
+from vllm.v1.core.sched.output import NewRequestData
+from vllm.v1.spec_decode.utils import is_spec_decode_unsupported
+
+from tpu_inference.runner.block_table import MultiGroupBlockTable
+
+_SAMPLING_EPS = 1e-5
+
+# TODO(xiang): fix cpu tensor init
+
+
+@dataclass
+class CachedRequestState(NewRequestData):
+
+    output_token_ids: Optional[list[int]] = None
+    generator: Optional[Any] = None
+    mrope_positions: Optional[jax.Array] = None
+    mrope_position_delta: Optional[int] = None
+
+    def __post_init__(self):
+        self.num_prompt_tokens = len(self.prompt_token_ids)
+
+    @property
+    def num_tokens(self) -> int:
+        return self.num_prompt_tokens + len(self.output_token_ids)
+
+    def get_token_id(self, idx: int) -> int:
+        if idx < self.num_prompt_tokens:
+            return self.prompt_token_ids[idx]
+        else:
+            return self.output_token_ids[idx - self.num_prompt_tokens]
+
+
+class InputBatch:
+
+    def __init__(
+        self,
+        max_num_reqs: int,
+        max_model_len: int,
+        max_num_batched_tokens: int,
+        pin_memory: bool,
+        vocab_size: int,
+        block_sizes: list[int],
+        is_spec_decode: bool = False,
+    ):
+        self.is_spec_decode = is_spec_decode
+        self.max_num_reqs = max_num_reqs
+        self.max_model_len = max_model_len
+        self.max_num_batched_tokens = max_num_batched_tokens
+        self.pin_memory = pin_memory
+        self.vocab_size = vocab_size
+
+        self._req_ids: list[Optional[str]] = []
+        self.req_id_to_index: dict[str, int] = {}
+
+        # TODO(woosuk): This buffer could be too large if max_model_len is big.
+        # Find a way to reduce the CPU memory usage.
+        # This buffer is not directly transferred to the GPU, so it does not
+        # need to be pinned.
+        self.token_ids_cpu = np.zeros(
+            (max_num_reqs, max_model_len),
+            dtype=np.int32,
+        )
+        self.num_tokens = np.zeros(max_num_reqs, dtype=np.int32)
+        self.num_tokens_no_spec = np.zeros(max_num_reqs, dtype=np.int32)
+        self.num_prompt_tokens = np.zeros(max_num_reqs, dtype=np.int32)
+        self.num_computed_tokens_cpu = np.zeros(
+            (max_num_reqs, ),
+            dtype=np.int32,
+        )
+
+        # Block table.
+        self.block_table = MultiGroupBlockTable(
+            max_num_reqs=max_num_reqs,
+            max_model_len=max_model_len,
+            max_num_batched_tokens=max_num_batched_tokens,
+            pin_memory=pin_memory,
+            block_sizes=block_sizes,
+        )
+
+        # Sampling-related.
+        self.temperature_cpu = np.empty((max_num_reqs, ), dtype=np.float32)
+        self.greedy_reqs: set[str] = set()
+        self.random_reqs: set[str] = set()
+
+        self.top_p_cpu = np.empty((max_num_reqs, ), dtype=np.float32)
+
+        self.top_k_cpu = np.empty((max_num_reqs, ), dtype=np.int32)
+
+        # IDs of requests which do not support spec decoding
+        self.spec_decode_unsupported_reqs: set[str] = set()
+
+        # req_index -> (min_tokens, stop_token_ids)
+        self.min_tokens: dict[int, tuple[int, set[int]]] = {}
+
+        # lora related
+        self.request_lora_mapping = np.zeros((self.max_num_reqs, ),
+                                             dtype=np.int32)
+        self.lora_id_to_request_ids: dict[int, set[str]] = {}
+        self.lora_id_to_lora_request: dict[int, LoRARequest] = {}
+
+        # req_index -> generator
+        # NOTE(woosuk): The indices of the requests that do not have their own
+        # generator should not be included in the dictionary.
+        self.generators: dict[int, Any] = {}
+
+        self.num_logprobs: dict[str, int] = {}
+
+        self.logit_bias: list[Optional[dict[int,
+                                            float]]] = [None] * max_num_reqs
+        self.has_allowed_token_ids: set[str] = set()
+        # NOTE(lufang): In the mask tensor, if the corresponding token allowed,
+        # the value is False. Since we use masked_fill_ to set -inf.
+        self.allowed_token_ids_mask: Optional[jax.Array] = None
+        self.allowed_token_ids_mask_cpu: Optional[jax.Array] = None
+
+        # req_index -> bad_words_token_ids
+        self.bad_words_token_ids: dict[int, list[list[int]]] = {}
+
+        self.req_output_token_ids: list[Optional[list[int]]] = []
+
+        self.request_distribution: list[int] = [0, 0, 0]
+
+    @property
+    def req_ids(self) -> list[str]:
+        # None elements should only be present transiently
+        # while performing state updates to the batch.
+        return cast(list[str], self._req_ids)
+
+    def add_request(
+        self,
+        request: "CachedRequestState",
+        req_index: Optional[int] = None,
+    ) -> None:
+        if req_index is None:
+            req_index = self.num_reqs
+        assert req_index < self.max_num_reqs, f"{req_index} < {self.max_num_reqs} failed!"
+
+        req_id = request.req_id
+        if req_index == len(self._req_ids):
+            self._req_ids.append(req_id)
+            self.req_output_token_ids.append(request.output_token_ids)
+        else:
+            self._req_ids[req_index] = req_id
+            self.req_output_token_ids[req_index] = request.output_token_ids
+
+        self.req_id_to_index[req_id] = req_index
+
+        # Copy the prompt token ids and output token ids.
+        num_prompt_tokens = len(request.prompt_token_ids)
+        self.num_prompt_tokens[req_index] = num_prompt_tokens
+        self.token_ids_cpu[
+            req_index, :num_prompt_tokens] = request.prompt_token_ids
+        start_idx = num_prompt_tokens
+        end_idx = start_idx + len(request.output_token_ids)
+        self.token_ids_cpu[req_index,
+                           start_idx:end_idx] = request.output_token_ids
+        # Number of token ids in token_ids_cpu.
+        # NOTE(woosuk): This may include spec decode tokens.
+        self.num_tokens[req_index] = request.num_tokens
+        # Number of tokens without spec decode tokens.
+        self.num_tokens_no_spec[req_index] = request.num_tokens
+
+        self.num_computed_tokens_cpu[req_index] = request.num_computed_tokens
+        self.block_table.add_row(request.block_ids, req_index)
+
+        sampling_params = request.sampling_params
+
+        if (self.is_spec_decode
+                and is_spec_decode_unsupported(sampling_params)):
+            self.spec_decode_unsupported_reqs.add(req_id)
+
+        if sampling_params.sampling_type == SamplingType.GREEDY:
+            # Avoid later division by zero.
+            self.temperature_cpu[req_index] = -1.0
+            self.greedy_reqs.add(req_id)
+        else:
+            self.temperature_cpu[req_index] = sampling_params.temperature
+            self.random_reqs.add(req_id)
+
+        self.top_p_cpu[req_index] = sampling_params.top_p
+        top_k = sampling_params.top_k
+        if top_k <= 0 or top_k >= self.vocab_size:
+            top_k = 1
+        self.top_k_cpu[req_index] = top_k
+        if sampling_params.min_tokens:
+            self.min_tokens[req_index] = (sampling_params.min_tokens,
+                                          sampling_params.all_stop_token_ids)
+
+        # NOTE(woosuk): self.generators should not include the requests that
+        # do not have their own generator.
+        if request.generator is not None:
+            self.generators[req_index] = request.generator
+
+        if sampling_params.logprobs is not None:
+            self.num_logprobs[req_id] = sampling_params.logprobs
+        if sampling_params.logit_bias is not None:
+            self.logit_bias[req_index] = sampling_params.logit_bias
+
+        if sampling_params.allowed_token_ids:
+            self.has_allowed_token_ids.add(req_id)
+            if self.allowed_token_ids_mask_cpu is None:
+                # Lazy allocation for this tensor, which can be large.
+                # False means we don't fill with -inf.
+                self.allowed_token_ids_mask = jnp.zeros(self.max_num_reqs,
+                                                        self.vocab_size,
+                                                        dtype=jnp.bool)
+                self.allowed_token_ids_mask_cpu = np.zeros(self.max_num_reqs,
+                                                           self.vocab_size,
+                                                           dtype=np.bool)
+            self.allowed_token_ids_mask_cpu[req_index] = True
+            # False means we don't fill with -inf.
+            self.allowed_token_ids_mask_cpu[req_index][
+                sampling_params.allowed_token_ids] = False
+
+        if sampling_params.bad_words_token_ids:
+            self.bad_words_token_ids[
+                req_index] = sampling_params.bad_words_token_ids
+
+        # Add request lora ID
+        if request.lora_request:
+            lora_id = request.lora_request.lora_int_id
+            if lora_id not in self.lora_id_to_request_ids:
+                self.lora_id_to_request_ids[lora_id] = set()
+
+            self.request_lora_mapping[req_index] = lora_id
+            self.lora_id_to_request_ids[lora_id].add(request.req_id)
+            self.lora_id_to_lora_request[lora_id] = request.lora_request
+        else:
+            # No LoRA
+            self.request_lora_mapping[req_index] = 0
+
+    def remove_request(self, req_id: str) -> Optional[int]:
+        """This method must always be followed by a call to condense()."""
+
+        req_index = self.req_id_to_index.pop(req_id, None)
+        if req_index is None:
+            return None
+        self._req_ids[req_index] = None
+        self.req_output_token_ids[req_index] = None
+
+        self.greedy_reqs.discard(req_id)
+        self.random_reqs.discard(req_id)
+        self.spec_decode_unsupported_reqs.discard(req_id)
+        self.min_tokens.pop(req_index, None)
+        self.generators.pop(req_index, None)
+        self.num_logprobs.pop(req_id, None)
+
+        # LoRA
+        lora_id = self.request_lora_mapping[req_index]
+        if lora_id != 0:
+            self.lora_id_to_request_ids[lora_id].discard(req_id)
+            if len(self.lora_id_to_request_ids[lora_id]) == 0:
+                self.lora_id_to_request_ids.pop(lora_id)
+                self.lora_id_to_lora_request.pop(lora_id)
+            self.request_lora_mapping[req_index] = 0
+
+        self.logit_bias[req_index] = None
+        self.has_allowed_token_ids.discard(req_id)
+        if self.allowed_token_ids_mask_cpu is not None:
+            # False means we don't fill with -inf.
+            self.allowed_token_ids_mask_cpu[req_index].fill_(False)
+        self.bad_words_token_ids.pop(req_index, None)
+        return req_index
+
+    def swap_states(self, i1: int, i2: int) -> None:
+        old_id_i1 = self._req_ids[i1]
+        old_id_i2 = self._req_ids[i2]
+        self._req_ids[i1], self._req_ids[i2] =\
+            self._req_ids[i2], self._req_ids[i1]  # noqa
+        self.req_output_token_ids[i1], self.req_output_token_ids[i2] =\
+            self.req_output_token_ids[i2], self.req_output_token_ids[i1]
+        assert old_id_i1 is not None and old_id_i2 is not None
+        self.req_id_to_index[old_id_i1], self.req_id_to_index[old_id_i2] =\
+            self.req_id_to_index[old_id_i2], self.req_id_to_index[old_id_i1]
+        self.num_tokens[i1], self.num_tokens[i2] =\
+            self.num_tokens[i2], self.num_tokens[i1]
+        self.num_tokens_no_spec[i1], self.num_tokens_no_spec[i2] =\
+            self.num_tokens_no_spec[i2], self.num_tokens_no_spec[i1]
+        self.num_prompt_tokens[i1], self.num_prompt_tokens[i2] =\
+            self.num_prompt_tokens[i2], self.num_prompt_tokens[i1]
+        self.num_computed_tokens_cpu[i1], self.num_computed_tokens_cpu[i2] =\
+            self.num_computed_tokens_cpu[i2], self.num_computed_tokens_cpu[i1]
+        self.temperature_cpu[i1], self.temperature_cpu[i2] =\
+            self.temperature_cpu[i2], self.temperature_cpu[i1]
+        self.top_p_cpu[i1], self.top_p_cpu[i2] =\
+            self.top_p_cpu[i2], self.top_p_cpu[i1]
+        self.top_k_cpu[i1], self.top_k_cpu[i2] =\
+            self.top_k_cpu[i2], self.top_k_cpu[i1]
+
+        # NOTE: the following is unsafe
+        # self.token_ids_cpu[i1, ...], self.token_ids_cpu[i2, ...], =\
+        #     self.token_ids_cpu[i2, ...], self.token_ids_cpu[i1, ...]
+        # instead, we need to temporiarily copy the data for one of the indices
+        # TODO(lucas): optimize this by only copying valid indices
+        tmp = self.token_ids_cpu[i1, ...].copy()
+        self.token_ids_cpu[i1, ...] = self.token_ids_cpu[i2, ...]
+        self.token_ids_cpu[i2, ...] = tmp
+
+        swap_dict_values(self.generators, i1, i2)
+        swap_dict_values(self.min_tokens, i1, i2)
+        swap_dict_values(self.bad_words_token_ids, i1, i2)
+
+        self.request_lora_mapping[i1], self.request_lora_mapping[i2] =\
+            self.request_lora_mapping[i2], self.request_lora_mapping[i1]
+        self.logit_bias[i1], self.logit_bias[i2] =\
+            self.logit_bias[i2], self.logit_bias[i1]
+
+        if self.allowed_token_ids_mask_cpu is not None:
+            self.allowed_token_ids_mask_cpu[i1], \
+                self.allowed_token_ids_mask_cpu[i2] =\
+                self.allowed_token_ids_mask_cpu[i2], \
+                self.allowed_token_ids_mask_cpu[i1]
+        self.block_table.swap_row(i1, i2)
+
+    def condense(self, empty_req_indices: list[int]) -> None:
+        num_reqs = self.num_reqs
+        if num_reqs == 0:
+            # The batched states are empty.
+            self._req_ids.clear()
+            self.req_output_token_ids.clear()
+            return
+
+        # NOTE(woosuk): This function assumes that the empty_req_indices
+        # is sorted in descending order.
+        last_req_index = num_reqs + len(empty_req_indices) - 1
+        while empty_req_indices:
+            # Find the largest non-empty index.
+            while last_req_index in empty_req_indices:
+                last_req_index -= 1
+
+            # Find the smallest empty index.
+            empty_index = empty_req_indices.pop()
+            if empty_index >= last_req_index:
+                break
+
+            # Swap the states.
+            req_id = self._req_ids[last_req_index]
+            output_token_ids = self.req_output_token_ids[last_req_index]
+            assert req_id is not None
+            self._req_ids[empty_index] = req_id
+            self._req_ids[last_req_index] = None
+            self.req_output_token_ids[empty_index] = output_token_ids
+            self.req_output_token_ids[last_req_index] = None
+            self.req_id_to_index[req_id] = empty_index
+
+            num_tokens = self.num_tokens[last_req_index]
+            self.token_ids_cpu[empty_index, :num_tokens] = self.token_ids_cpu[
+                last_req_index, :num_tokens]
+            self.num_tokens[empty_index] = num_tokens
+            self.num_tokens_no_spec[empty_index] = self.num_tokens_no_spec[
+                last_req_index]
+            self.num_prompt_tokens[empty_index] = self.num_prompt_tokens[
+                last_req_index]
+            self.num_computed_tokens_cpu[
+                empty_index] = self.num_computed_tokens_cpu[last_req_index]
+            self.block_table.move_row(last_req_index, empty_index)
+            self.temperature_cpu[empty_index] = self.temperature_cpu[
+                last_req_index]
+            self.top_p_cpu[empty_index] = self.top_p_cpu[last_req_index]
+            self.top_k_cpu[empty_index] = self.top_k_cpu[last_req_index]
+            generator = self.generators.pop(last_req_index, None)
+            if generator is not None:
+                self.generators[empty_index] = generator
+
+            min_token = self.min_tokens.pop(last_req_index, None)
+            if min_token is not None:
+                self.min_tokens[empty_index] = min_token
+
+            self.request_lora_mapping[empty_index] = self.request_lora_mapping[
+                last_req_index]
+
+            self.logit_bias[empty_index] = self.logit_bias[last_req_index]
+
+            if self.allowed_token_ids_mask_cpu is not None:
+                self.allowed_token_ids_mask_cpu[
+                    empty_index] = self.allowed_token_ids_mask_cpu[
+                        last_req_index]
+
+            bad_words_token_ids = self.bad_words_token_ids.pop(
+                last_req_index, None)
+            if bad_words_token_ids is not None:
+                self.bad_words_token_ids[empty_index] = bad_words_token_ids
+            # Decrement last_req_index since it is now empty.
+            last_req_index -= 1
+
+        # Trim lists to the batch size.
+        del self._req_ids[self.num_reqs:]
+        del self.req_output_token_ids[self.num_reqs:]
+
+    @property
+    def num_reqs(self) -> int:
+        return len(self.req_id_to_index)
+
+    @property
+    def all_greedy(self) -> bool:
+        return len(self.random_reqs) == 0
+
+    @property
+    def max_num_logprobs(self) -> Optional[int]:
+        return max(self.num_logprobs.values()) if self.num_logprobs else None
+
+    def make_lora_inputs(
+        self, num_scheduled_tokens: np.ndarray
+    ) -> tuple[tuple[int, ...], tuple[int, ...], set[LoRARequest]]:
+        """
+        Given the num_scheduled_tokens for each request in the batch, return
+        datastructures used to activate the current LoRAs.
+        Returns:
+            1. prompt_lora_mapping: A tuple of size self.num_reqs where,
+               prompt_lora_mapping[i] is the LoRA id to use for the ith prompt.
+            2. token_lora_mapping: A tuple of size np.sum(num_scheduled_tokens)
+               where, token_lora_mapping[i] is the LoRA id to use for ith token.
+            3. lora_requests: Set of relevant LoRA requests.
+        """
+
+        req_lora_mapping = self.request_lora_mapping[:self.num_reqs]
+        prompt_lora_mapping = tuple(req_lora_mapping)
+        token_lora_mapping = tuple(
+            req_lora_mapping.repeat(num_scheduled_tokens))
+        active_lora_requests: set[LoRARequest] = set(
+            self.lora_id_to_lora_request.values())
+
+        return prompt_lora_mapping, token_lora_mapping, active_lora_requests
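For context on the make_lora_inputs docstring above, the sketch below shows the repeat-based expansion it describes, in plain NumPy. The request count, LoRA ids, and scheduled-token counts are made-up illustration values, not data from this package.

import numpy as np

# Hypothetical per-request LoRA ids for a 3-request batch (0 means "no LoRA").
request_lora_mapping = np.array([2, 0, 5], dtype=np.int32)
# Hypothetical number of tokens scheduled for each request in this step.
num_scheduled_tokens = np.array([3, 1, 2], dtype=np.int32)

# One LoRA id per request (prompt-level mapping): 2, 0, 5
prompt_lora_mapping = tuple(request_lora_mapping)
# One LoRA id per scheduled token: each request's id repeated by its
# scheduled-token count, giving 2, 2, 2, 0, 5, 5
token_lora_mapping = tuple(request_lora_mapping.repeat(num_scheduled_tokens))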
tpu_inference/runner/kv_cache.py
@@ -0,0 +1,132 @@
+from typing import Any, List
+
+import jax
+import jax.numpy as jnp
+import numpy as np
+from jax._src import dtypes
+from jax.sharding import Mesh, NamedSharding, PartitionSpec
+from torchax.ops.mappings import t2j_dtype
+
+import tpu_inference.kernels.ragged_paged_attention.v3.kernel as rpa
+import tpu_inference.kernels.ragged_paged_attention.v3.kernel_hd64 as rpa_hd64
+from tpu_inference.layers.common.sharding import ShardingAxisName
+from tpu_inference.logger import init_logger
+
+logger = init_logger(__name__)
+
+DEFAULT_KV_CACHE_DTYPE = jnp.bfloat16
+
+
+def get_kv_cache_shape_with_mesh(mesh: Mesh, total_num_pages: int,
+                                 page_size: int, actual_num_kv_heads: int,
+                                 actual_head_dim: int, kv_dtype: any):
+    """Gets the KV cache shape based on the mesh configuration."""
+
+    model_cnt = mesh.shape["model"]
+    assert actual_num_kv_heads % model_cnt == 0
+    # NOTE(chengjiyao): Currently, the attention kernel is tailored to the
+    # specific model, rather than being determined by the head_dim. If new
+    # models are introduced with a head_dim of 64, this will require additional
+    # model-specific adjustments.
+    get_kv_cache_shape_fn = (
+        rpa_hd64.get_kv_cache_shape if actual_head_dim == 64 \
+        else rpa.get_kv_cache_shape
+    )
+    shape = list(
+        get_kv_cache_shape_fn(total_num_pages, page_size,
+                              actual_num_kv_heads // model_cnt,
+                              actual_head_dim, kv_dtype))
+    shape[2] *= model_cnt
+    return tuple(shape)
+
+
+def create_kv_caches(
+    num_blocks: int,
+    block_size: int,
+    num_kv_heads: int,
+    head_size: int,
+    mesh: Mesh,
+    layer_names: List[str],
+    cache_dtype: jnp.dtype = DEFAULT_KV_CACHE_DTYPE,
+) -> List[jax.Array]:
+    """
+    Creates a list of KV cache where each array mapps to single attention layer.
+
+    The shape of the KV cache per layer is:
+    (num_blocks, block_size, cdiv(num_kv_heads * 2, packing), packing, head_dim)
+    where packing = (32 // dtype bits)
+
+    Args:
+        num_blocks: The number of blocks in the KV cache.
+        block_size: The size of each block in the KV cache.
+        num_kv_heads: The number of KV heads in the KV cache.
+        head_size: The size of each head in the KV cache.
+        mesh: The mesh to shard the KV caches across.
+        layer_names: The names of the decoder layers in the model.
+        cache_dtype: The datatype of KV cache.
+
+    Returns:
+        A list of KV caches, one per each decoder layer in the model.
+
+    """
+    # TODO(xiang): fix this together with get_kv_cache_spec
+    # cache_dtype = kv_cache_spec.dtype
+
+    cache_shape = get_kv_cache_shape_with_mesh(mesh, num_blocks, block_size,
+                                               num_kv_heads, head_size,
+                                               cache_dtype)
+
+    sharding = NamedSharding(
+        mesh,
+        PartitionSpec(ShardingAxisName.ATTN_DATA, None,
+                      ShardingAxisName.ATTN_HEAD))
+
+    def _allocate() -> jax.Array:
+        return jnp.empty(
+            shape=cache_shape,
+            dtype=cache_dtype,
+        )
+
+    sharded_allocate = jax.jit(_allocate, out_shardings=sharding)
+    kv_caches = []
+    for _ in layer_names:
+        kv_caches.append(sharded_allocate())
+    return kv_caches
+
+
+def get_rpa_page_size_bytes(mesh: Mesh, kv_cache_specs: dict[str, Any]) -> int:
+    """
+    Calculate KV cache page size of RPA kernel.
+
+    Args:
+        mesh: The mesh to shard the KV caches across.
+        kv_cache_specs: Dictionary of KV cache specs.
+
+    Returns:
+        KV cache page size in bytes.
+    """
+
+    # Import it here to avoid circular import.
+    from vllm.v1.kv_cache_interface import AttentionSpec
+
+    page_size_bytes_set = set()
+    for kv_cache_spec in kv_cache_specs.values():
+        assert isinstance(kv_cache_spec, AttentionSpec)
+
+        dtype = t2j_dtype(kv_cache_spec.dtype)
+        bits = dtypes.bit_width(dtype)
+
+        kv_cache_shape = get_kv_cache_shape_with_mesh(
+            mesh=mesh,
+            total_num_pages=1,  # Pass 1 to get shape of a single page.
+            page_size=kv_cache_spec.block_size,
+            actual_num_kv_heads=kv_cache_spec.num_kv_heads,
+            actual_head_dim=kv_cache_spec.head_size,
+            kv_dtype=dtype,
+        )
+        page_size_bytes = (bits * np.prod(kv_cache_shape)) // 8
+        page_size_bytes_set.add(page_size_bytes)
+
+    # Ensure that page size is the same for all kv caches.
+    assert len(page_size_bytes_set) == 1
+    return page_size_bytes_set.pop()