tpu-inference 0.11.1.dev202511180814__py3-none-any.whl → 0.12.0.dev20251213__py3-none-any.whl
This diff shows the changes between publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
Potentially problematic release: this version of tpu-inference may be problematic.
- tests/kernels/fused_moe_v1_test.py +303 -34
- tests/kernels/mla_v1_test.py +129 -41
- tests/kernels/quantized_matmul_kernel_test.py +2 -34
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +3 -1
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +3 -1
- tests/lora/test_layers.py +4 -7
- tests/lora/test_lora_perf.py +53 -0
- tests/lora/utils.py +0 -8
- tests/test_envs.py +110 -12
- tests/test_quantization.py +3 -0
- tests/test_utils.py +1 -2
- tpu_inference/__init__.py +22 -3
- tpu_inference/core/disagg_utils.py +6 -8
- tpu_inference/distributed/tpu_connector.py +3 -4
- tpu_inference/distributed/utils.py +3 -2
- tpu_inference/envs.py +93 -9
- tpu_inference/executors/ray_distributed_executor.py +9 -2
- tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
- tpu_inference/kernels/fused_moe/v1/kernel.py +712 -143
- tpu_inference/kernels/mla/v1/kernel.py +98 -120
- tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +140 -67
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +204 -120
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +2 -1
- tpu_inference/layers/common/attention_interface.py +7 -1
- tpu_inference/layers/common/sharding.py +11 -7
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +232 -64
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +5 -5
- tpu_inference/layers/vllm/fused_moe.py +170 -208
- tpu_inference/layers/vllm/linear_common.py +43 -21
- tpu_inference/layers/vllm/quantization/common.py +11 -6
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +4 -3
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +74 -65
- tpu_inference/layers/vllm/quantization/mxfp4.py +140 -94
- tpu_inference/layers/vllm/quantization/unquantized.py +103 -80
- tpu_inference/layers/vllm/sharding.py +2 -2
- tpu_inference/lora/torch_punica_tpu.py +1 -2
- tpu_inference/models/common/model_loader.py +84 -28
- tpu_inference/models/jax/deepseek_v3.py +185 -64
- tpu_inference/models/jax/gpt_oss.py +3 -3
- tpu_inference/models/jax/llama3.py +2 -1
- tpu_inference/models/jax/llama_eagle3.py +8 -5
- tpu_inference/models/jax/llama_guard_4.py +361 -0
- tpu_inference/models/jax/qwen2.py +2 -1
- tpu_inference/models/jax/qwen2_5_vl.py +163 -48
- tpu_inference/models/jax/qwen3.py +2 -1
- tpu_inference/models/jax/utils/quantization/quantization_utils.py +7 -8
- tpu_inference/models/jax/utils/weight_utils.py +205 -144
- tpu_inference/models/vllm/vllm_model_wrapper.py +14 -8
- tpu_inference/platforms/tpu_platform.py +34 -50
- tpu_inference/runner/compilation_manager.py +144 -60
- tpu_inference/runner/kv_cache.py +40 -20
- tpu_inference/runner/kv_cache_manager.py +48 -33
- tpu_inference/runner/persistent_batch_manager.py +40 -2
- tpu_inference/runner/structured_decoding_manager.py +2 -3
- tpu_inference/runner/tpu_runner.py +280 -149
- tpu_inference/runner/utils.py +2 -2
- tpu_inference/spec_decode/jax/eagle3.py +71 -21
- tpu_inference/tpu_info.py +4 -3
- tpu_inference/utils.py +46 -18
- tpu_inference/worker/tpu_worker.py +197 -63
- {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/METADATA +9 -10
- {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/RECORD +70 -74
- tpu_inference/mock/__init__.py +0 -0
- tpu_inference/mock/vllm_config_utils.py +0 -28
- tpu_inference/mock/vllm_envs.py +0 -1219
- tpu_inference/mock/vllm_logger.py +0 -212
- tpu_inference/mock/vllm_logging_utils.py +0 -15
- tpu_inference/models/jax/phi3.py +0 -376
- {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/WHEEL +0 -0
- {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/top_level.txt +0 -0
tpu_inference/runner/kv_cache_manager.py

@@ -1,15 +1,16 @@
 import functools
-import math
 from typing import TYPE_CHECKING, Dict, List
 
 import jax
 import jax.numpy as jnp
+import numpy as np
 import vllm.envs as envs
 from jax.sharding import NamedSharding, PartitionSpec
 from torchax.ops.mappings import t2j_dtype
-from vllm.attention import Attention
 from vllm.attention.backends.abstract import AttentionType
+from vllm.attention.layer import Attention
 from vllm.config import get_layers_from_vllm_config
+from vllm.utils.math_utils import cdiv
 from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                         KVCacheSpec, MLAAttentionSpec,
                                         SlidingWindowSpec)
@@ -38,20 +39,30 @@ class KVCacheManager:
         # means this layer will perform attention using the keys and values
         # from the KV cache of `shared_kv_cache_layers[layer_name]`.
         self.shared_kv_cache_layers: dict[str, str] = {}
+        self.use_mla = self.runner.model_config.use_mla
 
     def get_kv_cache_spec(self):
         # TODO(xiang): this hack tricks engine core to init successfully
         block_size = self.runner.cache_config.block_size
-        use_mla = self.runner.model_config.use_mla
         kv_cache_spec: dict[str, KVCacheSpec] = {}
 
         # If use pure jax (MODEL_IMPL_TYPE=flax_nnx), we don't register
         # attention into compilation config.
         # Use FullAttentionSpec for each layer
         # TODO(pooyam): Is it possible to merge the logic for vllm and non-vllm models?
+        model_config = self.runner.model_config
+        if self.use_mla:
+            # Individually pad the RopE and latents
+            qk_rope_head_dim = getattr(model_config.hf_text_config,
+                                       "qk_rope_head_dim", 0)
+            padded_kv_lora_rank = common_utils.align_to(
+                model_config.hf_text_config.kv_lora_rank, 128)
+            padded_qk_rope_head_dim = common_utils.align_to(
+                qk_rope_head_dim, 128)
+            mla_head_size = padded_kv_lora_rank + padded_qk_rope_head_dim
+
         if len(self.runner.vllm_config.compilation_config.
                static_forward_context) == 0:
-            model_config = self.runner.model_config
             parallel_config = self.runner.parallel_config
             # Pad num_kv_heads to multiple of TP size.
             num_kv_heads = common_utils.get_padded_num_heads(
@@ -60,11 +71,11 @@ class KVCacheManager:
             head_size = common_utils.get_padded_head_dim(
                 model_config.get_head_size())
             for i in range(model_config.get_num_layers(parallel_config)):
-                if use_mla:
+                if self.use_mla:
                     kv_cache_spec[f"layer.{i}"] = MLAAttentionSpec(
                         block_size=block_size,
-                        num_kv_heads=
-                        head_size=
+                        num_kv_heads=1,
+                        head_size=mla_head_size,
                         dtype=self.runner.kv_cache_dtype,
                         cache_dtype_str=self.runner.vllm_config.cache_config.
                         cache_dtype)
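For a concrete sense of the MLA spec sizing above, a minimal sketch of the padding arithmetic. The DeepSeek-V3-style values and the local `align_to` helper are illustrative stand-ins for the config fields and `common_utils.align_to` used in the diff:

```python
def align_to(x: int, multiple: int) -> int:
    # Round x up to the nearest multiple; stand-in for common_utils.align_to.
    return ((x + multiple - 1) // multiple) * multiple

# Assumed DeepSeek-V3-style values, for illustration only.
kv_lora_rank = 512
qk_rope_head_dim = 64

padded_kv_lora_rank = align_to(kv_lora_rank, 128)          # 512 (already aligned)
padded_qk_rope_head_dim = align_to(qk_rope_head_dim, 128)  # 128
mla_head_size = padded_kv_lora_rank + padded_qk_rope_head_dim

print(mla_head_size)  # 640 -> used as head_size in MLAAttentionSpec, with num_kv_heads=1
```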
@@ -82,14 +93,13 @@ class KVCacheManager:
                     self.runner.mesh.shape["model"])
                 head_size = common_utils.get_padded_head_dim(
                     hf_config.hidden_size // hf_config.num_attention_heads)
-
                 # Eagle3 has only 1 layer
                 for i in range(1):
-                    if use_mla:
-                        kv_cache_spec[f"
+                    if self.use_mla:
+                        kv_cache_spec[f"draft_layer.{i}"] = MLAAttentionSpec(
                             block_size=block_size,
-                            num_kv_heads=
-                            head_size=
+                            num_kv_heads=1,
+                            head_size=mla_head_size,
                             dtype=self.runner.kv_cache_dtype,
                             cache_dtype_str=self.runner.vllm_config.
                             cache_config.cache_dtype)
@@ -103,6 +113,7 @@ class KVCacheManager:
         # Else propagate attention modules from compilation config.
         layers = get_layers_from_vllm_config(self.runner.vllm_config,
                                              Attention)
+        logger.warning(f"Compilation num_layers = {len(layers.items())}")
         for layer_name, attn_module in layers.items():
             if (kv_tgt_layer :=
                     attn_module.kv_sharing_target_layer_name) is not None:
@@ -126,11 +137,11 @@ class KVCacheManager:
                         attn_module.head_size),
                     dtype=self.runner.kv_cache_dtype,
                     sliding_window=attn_module.sliding_window)
-            elif use_mla:
-                kv_cache_spec[
+            elif self.use_mla:
+                kv_cache_spec[layer_name] = MLAAttentionSpec(
                     block_size=block_size,
-                    num_kv_heads=
-                    head_size=
+                    num_kv_heads=1,
+                    head_size=mla_head_size,
                     dtype=self.runner.kv_cache_dtype,
                     cache_dtype_str=self.runner.vllm_config.
                     cache_config.cache_dtype)
@@ -175,6 +186,11 @@ class KVCacheManager:
         )
         self.runner.input_batch = new_input_batch
         self.runner.persistent_batch_manager.input_batch = new_input_batch
+        self.runner.block_tables_cpu = [
+            np.zeros((self.runner.max_num_reqs,
+                      cdiv(self.runner.max_model_len, block_size)),
+                     dtype=np.int32) for block_size in block_sizes
+        ]
 
     def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
         self.maybe_reinitialize_input_batch(kv_cache_config)
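A minimal sketch of the per-block-size CPU block tables allocated above. The request and sequence-length limits are made-up example values; `cdiv` here mirrors the ceiling-division helper imported from `vllm.utils.math_utils`:

```python
import numpy as np

def cdiv(a: int, b: int) -> int:
    # Ceiling division, matching the helper imported in the diff.
    return -(-a // b)

# Assumed values, for illustration only.
max_num_reqs = 256
max_model_len = 8192
block_sizes = [16, 32]

block_tables_cpu = [
    np.zeros((max_num_reqs, cdiv(max_model_len, block_size)), dtype=np.int32)
    for block_size in block_sizes
]
print([t.shape for t in block_tables_cpu])  # [(256, 512), (256, 256)]
```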
@@ -190,16 +206,22 @@ class KVCacheManager:
             num_blocks = kv_cache_tensor.size // page_size_bytes
             dp_size = self.runner.vllm_config.sharding_config.total_dp_size
             # num_blocks must be a multiple of dp_size
-            num_blocks =
+            num_blocks = (num_blocks // dp_size) * dp_size
             # NOTE: we'll multiply the num_kv_heads by 2 in the function
+            if self.use_mla:
+                head_size = self.runner.model_config.hf_config.kv_lora_rank + \
+                    self.runner.model_config.hf_config.qk_rope_head_dim
+            else:
+                head_size = representative_spec.head_size
             kv_cache = create_kv_caches(
                 num_blocks=num_blocks,
                 block_size=representative_spec.block_size,
                 num_kv_heads=representative_spec.num_kv_heads,
-                head_size=
+                head_size=head_size,
                 mesh=self.runner.mesh,
                 layer_names=[f'kv_cache_tensor.{i}'],
                 cache_dtype=t2j_dtype(representative_spec.dtype),
+                use_mla=self.use_mla,
             )[0]
             kv_caches.append(kv_cache)
             num_blocks_list.append(num_blocks)
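A quick numeric check of the rounding above, with arbitrary example values:

```python
num_blocks = 1003   # blocks that fit in the allocated KV cache tensor (example)
dp_size = 4         # total data-parallel size (example)

# Round down so every DP shard owns the same number of blocks.
num_blocks = (num_blocks // dp_size) * dp_size
print(num_blocks)   # 1000 -> 250 blocks per DP shard
```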
@@ -283,13 +305,8 @@ class KVCacheManager:
 
         def _update_layer(cache, slices):
             """The function to apply to each layer's cache and slices."""
-            reshaped_slices = slices.reshape(-1,
-
-            for (i, block_idx) in enumerate(block_numbers):
-                cache = jax.lax.dynamic_update_slice_in_dim(cache,
-                    reshaped_slices[i],
-                    block_idx,
-                    axis=0)
+            reshaped_slices = slices.reshape(-1, block_size, *slices.shape[1:])
+            cache.at[block_numbers].set(reshaped_slices)
             return cache
 
         return jax.tree.map(_update_layer, kv_caches, kv_cache_slices)
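For context on the `_update_layer` rewrite above, a self-contained sketch contrasting the old per-block `jax.lax.dynamic_update_slice_in_dim` loop with a single batched `.at[...].set(...)` scatter. The shapes are made up; note that JAX arrays are immutable, so `.at[...].set(...)` returns a new array rather than mutating the cache in place:

```python
import jax
import jax.numpy as jnp

num_blocks, block_size, heads, dim = 16, 4, 2, 8
cache = jnp.zeros((num_blocks, block_size, heads, dim))
slices = jnp.arange(3 * block_size * heads * dim, dtype=jnp.float32).reshape(
    3 * block_size, heads, dim)          # 3 blocks' worth of new KV entries
block_numbers = [5, 9, 2]                # destination block indices

reshaped = slices.reshape(-1, block_size, heads, dim)

# Old approach: one dynamic_update_slice_in_dim call per destination block.
looped = cache
for i, block_idx in enumerate(block_numbers):
    looped = jax.lax.dynamic_update_slice_in_dim(
        looped, reshaped[i][None], block_idx, axis=0)

# New approach: a single batched scatter over all destination blocks.
batched = cache.at[jnp.array(block_numbers)].set(reshaped)

print(bool(jnp.allclose(looped, batched)))  # True
```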
@@ -342,16 +359,12 @@ class KVCacheManager:
         """
         if block_ids == list(range(block_ids[0],
                                    block_ids[0] + len(block_ids))):
-
-
-            batched_kv_cache_per_layer = self._jitted_gather_continuous_kv_cache(
-                self.runner.kv_caches, block_ids[0], len(block_ids))
+            batched_kv_cache_per_layer = self._jitted_gather_continuous_kv_cache(
+                self.runner.kv_caches, block_ids[0], len(block_ids))
 
         else:
-
-
-            batched_kv_cache_per_layer = self._jitted_gather_kv_cache(
-                self.runner.kv_caches, jnp.array(block_ids))
+            batched_kv_cache_per_layer = self._jitted_gather_kv_cache(
+                self.runner.kv_caches, jnp.array(block_ids))
         return batched_kv_cache_per_layer
 
     def transfer_kv_cache(self,
@@ -440,6 +453,7 @@ class KVCacheManager:
                     kv_cache_slices,
                     start_block,
                 )
+                jax.block_until_ready(self.runner.kv_caches)
         else:
             with runner_utils.LatencyTracker(
                     f"JittedInsertKVCache-b{len(block_numbers)}"):
@@ -451,6 +465,7 @@ class KVCacheManager:
                     kv_cache_slices,
                     jnp.array(block_numbers),
                 )
+                jax.block_until_ready(self.runner.kv_caches)
 
         logger.debug(
             f"Updated kv cache entries cnt={len(self.runner.kv_caches)}")
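A small illustration of why the added `jax.block_until_ready` calls matter: JAX dispatches device work asynchronously, so a wall-clock timer (such as the surrounding `LatencyTracker`) can otherwise stop before the update has actually completed. This standalone sketch times a plain matmul rather than the KV-cache insert:

```python
import time
import jax
import jax.numpy as jnp

x = jnp.ones((4096, 4096))

start = time.perf_counter()
y = x @ x                      # dispatched asynchronously; returns almost immediately
dispatched = time.perf_counter() - start

jax.block_until_ready(y)       # wait until the device result actually exists
finished = time.perf_counter() - start

print(f"dispatch: {dispatched*1e3:.2f} ms, complete: {finished*1e3:.2f} ms")
```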
tpu_inference/runner/persistent_batch_manager.py

@@ -14,12 +14,13 @@ class PersistentBatchManager:
     def __init__(self, requests: Dict[str, CachedRequestState],
                  input_batch: InputBatch, encoder_cache: Dict[str,
                                                               'jax.Array'],
-                 uses_mrope: bool, model_config):
+                 uses_mrope: bool, model_config, is_last_rank: bool):
         self.requests = requests
         self.input_batch = input_batch
         self.encoder_cache = encoder_cache
         self.uses_mrope = uses_mrope
         self.model_config = model_config
+        self.is_last_rank = is_last_rank
 
     def _reorder_batch(self, scheduler_output: "VllmSchedulerOutput") -> int:
         """ Reorder the sheduled requests to RPA kernel friendly distribution
@@ -179,9 +180,35 @@ class PersistentBatchManager:
             num_computed_tokens = req_data.num_computed_tokens[i]
             new_block_ids = req_data.new_block_ids[i]
             resumed_from_preemption = req_data.resumed_from_preemption[i]
+            num_output_tokens = req_data.num_output_tokens[i]
 
             # Update the cached states.
             req_state.num_computed_tokens = num_computed_tokens
+            req_index = self.input_batch.req_id_to_index.get(req_id)
+
+            if not self.is_last_rank:
+                # When using PP, the scheduler sends the sampled tokens back,
+                # because there's no direct communication between the first-
+                # stage worker and the last-stage worker.
+                new_token_ids = req_data.new_token_ids[i]
+                # Add the sampled token(s) from the previous step (if any).
+                # This doesn't include "unverified" tokens like spec tokens.
+                num_new_tokens = (num_computed_tokens + len(new_token_ids) -
+                                  req_state.num_tokens)
+                if num_new_tokens == 1:
+                    req_state.output_token_ids.append(new_token_ids[-1])
+                elif num_new_tokens > 0:
+                    req_state.output_token_ids.extend(
+                        new_token_ids[-num_new_tokens:])
+                elif num_output_tokens < len(req_state.output_token_ids):
+                    del req_state.output_token_ids[num_output_tokens:]
+                if req_index is not None:
+                    end_idx = (self.input_batch.num_prompt_tokens[req_index] +
+                               num_output_tokens)
+                    self.input_batch.num_tokens[req_index] = end_idx
+                    self.input_batch.num_tokens_no_spec[req_index] = end_idx
+
+            # Update the block IDs.
             if not resumed_from_preemption:
                 if new_block_ids is not None:
                     # Append the new blocks to the existing block IDs.
@@ -194,7 +221,6 @@ class PersistentBatchManager:
                 # Replace the existing block IDs with the new ones.
                 req_state.block_ids = new_block_ids
 
-            req_index = self.input_batch.req_id_to_index.get(req_id)
             if req_index is None:
                 # The request is not in the persistent batch.
                 # The request was either preempted and resumed later, or was not
@@ -209,6 +235,18 @@ class PersistentBatchManager:
             self.input_batch.block_table.append_row(
                 new_block_ids, req_index)
 
+            # For the last rank, we don't need to update the token_ids_cpu
+            # because the sampled tokens are already cached.
+            if not self.is_last_rank:
+                start_token_index = num_computed_tokens
+                end_token_index = num_computed_tokens + len(new_token_ids)
+                self.input_batch.token_ids_cpu[
+                    req_index,
+                    start_token_index:end_token_index] = new_token_ids
+                self.input_batch.num_tokens_no_spec[
+                    req_index] = end_token_index
+                self.input_batch.num_tokens[req_index] = end_token_index
+
             # Add spec_token_ids to token_ids_cpu.
             spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get(
                 req_id, ())
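A worked example, with made-up counts, of the non-last-rank bookkeeping added in the two hunks above (the `output_token_ids` append and the `token_ids_cpu` slice):

```python
# Made-up values for illustration only.
num_computed_tokens = 99      # tokens the scheduler reports as computed
new_token_ids = [7, 11]       # sampled tokens echoed back by the scheduler
req_state_num_tokens = 100    # tokens this worker had recorded for the request

num_new_tokens = num_computed_tokens + len(new_token_ids) - req_state_num_tokens
print(num_new_tokens)         # 1 -> only new_token_ids[-1] (token 11) is appended

# The echoed tokens are also written into the CPU-side token buffer:
start_token_index = num_computed_tokens
end_token_index = num_computed_tokens + len(new_token_ids)
print(start_token_index, end_token_index)  # 99 101 -> positions 99:101 receive [7, 11]
```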
tpu_inference/runner/structured_decoding_manager.py

@@ -61,11 +61,10 @@ class StructuredDecodingManager:
         self.runner.require_structured_out_cpu.fill(0)
 
         sorted_struct_requests = sorted(
-            grammar_output.structured_output_request_ids
-            key=lambda item: item[1])
+            grammar_output.structured_output_request_ids)
 
         cumulative_mask_idx = 0
-        for req_id
+        for req_id in sorted_struct_requests:
             if req_id not in self.runner.input_batch.req_id_to_index:
                 continue
             batch_index = self.runner.input_batch.req_id_to_index[req_id]