tpu-inference 0.11.1.dev202511220812__py3-none-any.whl → 0.13.2.dev20251230__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tpu-inference might be problematic.
- tests/__init__.py +13 -0
- tests/core/__init__.py +13 -0
- tests/core/test_disagg_utils.py +14 -0
- tests/core/test_dp_scheduler.py +650 -768
- tests/core/test_init.py +14 -0
- tests/distributed/__init__.py +13 -0
- tests/distributed/test_distributed_utils.py +120 -0
- tests/distributed/test_tpu_connector.py +478 -0
- tests/e2e/__init__.py +13 -0
- tests/e2e/test_async_scheduler.py +211 -0
- tests/e2e/test_data_parallel.py +289 -0
- tests/e2e/test_hybrid_kvcache.py +219 -0
- tests/e2e/test_local_disagg.py +257 -0
- tests/e2e/test_model_loader.py +268 -0
- tests/e2e/test_multi_modal_inference.py +111 -0
- tests/e2e/test_pipeline_parallel.py +265 -0
- tests/e2e/test_runai_model_streamer_loader.py +104 -0
- tests/e2e/test_sampling_params.py +269 -0
- tests/e2e/test_speculative_decoding.py +311 -0
- tests/e2e/test_structured_decoding.py +46 -0
- tests/executors/__init__.py +13 -0
- tests/executors/test_ray_distributed_executor.py +199 -0
- tests/experimental/__init__.py +13 -0
- tests/experimental/test_llama3_jax_stashed.py +208 -0
- tests/kernels/__init__.py +13 -0
- tests/kernels/collectives/__init__.py +13 -0
- tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
- tests/kernels/fused_moe_v1_test.py +317 -34
- tests/kernels/gmm_test.py +205 -0
- tests/kernels/mla_v1_test.py +143 -41
- tests/kernels/quantized_matmul_kernel_test.py +2 -34
- tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +17 -1
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +17 -1
- tests/layers/__init__.py +13 -0
- tests/layers/common/__init__.py +13 -0
- tests/layers/common/test_attention_interface.py +156 -0
- tests/layers/common/test_quantization.py +149 -0
- tests/layers/jax/__init__.py +13 -0
- tests/layers/jax/attention/__init__.py +13 -0
- tests/layers/jax/attention/test_common_attention.py +103 -0
- tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
- tests/layers/jax/attention/test_llama4_attention.py +135 -0
- tests/layers/jax/moe/__init__.py +13 -0
- tests/layers/jax/moe/test_deepseek_moe.py +235 -0
- tests/layers/jax/sample/__init__.py +13 -0
- tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
- tests/layers/jax/sample/test_sampling.py +115 -0
- tests/layers/jax/sample/test_sampling_metadata.py +254 -0
- tests/layers/jax/test_layers.py +155 -0
- tests/{test_quantization.py → layers/jax/test_qwix.py} +183 -50
- tests/layers/jax/test_rope.py +93 -0
- tests/layers/jax/test_sharding.py +159 -0
- tests/layers/jax/test_transformer_block.py +152 -0
- tests/layers/vllm/__init__.py +13 -0
- tests/layers/vllm/test_attention.py +363 -0
- tests/layers/vllm/test_awq.py +406 -0
- tests/layers/vllm/test_compressed_tensors_moe.py +199 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +441 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +443 -0
- tests/layers/vllm/test_fp8.py +17 -0
- tests/layers/vllm/test_mxfp4.py +320 -0
- tests/layers/vllm/test_unquantized.py +662 -0
- tests/layers/vllm/utils.py +87 -0
- tests/lora/__init__.py +13 -0
- tests/lora/conftest.py +14 -0
- tests/lora/test_bgmv.py +14 -0
- tests/lora/test_layers.py +26 -6
- tests/lora/test_lora.py +15 -1
- tests/lora/test_lora_perf.py +67 -0
- tests/models/__init__.py +13 -0
- tests/models/common/__init__.py +13 -0
- tests/models/common/test_model_loader.py +455 -0
- tests/models/jax/__init__.py +13 -0
- tests/models/jax/test_deepseek_v3.py +401 -0
- tests/models/jax/test_llama3.py +184 -0
- tests/models/jax/test_llama4.py +298 -0
- tests/models/jax/test_llama_eagle3.py +197 -0
- tests/models/jax/test_llama_guard_4.py +242 -0
- tests/models/jax/test_qwen2.py +172 -0
- tests/models/jax/test_qwen2_5_vl.py +605 -0
- tests/models/jax/test_qwen3.py +169 -0
- tests/models/jax/test_weight_loading.py +180 -0
- tests/models/jax/utils/__init__.py +13 -0
- tests/models/jax/utils/test_multi_modal_utils.py +212 -0
- tests/platforms/__init__.py +13 -0
- tests/platforms/test_tpu_platform.py +54 -0
- tests/runner/__init__.py +13 -0
- tests/runner/test_block_table.py +395 -0
- tests/runner/test_input_batch.py +226 -0
- tests/runner/test_kv_cache.py +220 -0
- tests/runner/test_kv_cache_manager.py +498 -0
- tests/runner/test_multimodal_manager.py +429 -0
- tests/runner/test_persistent_batch_manager.py +84 -0
- tests/runner/test_speculative_decoding_manager.py +368 -0
- tests/runner/test_structured_decoding_manager.py +220 -0
- tests/runner/test_tpu_runner.py +261 -0
- tests/runner/test_tpu_runner_dp.py +1099 -0
- tests/runner/test_tpu_runner_mesh.py +200 -0
- tests/runner/test_utils.py +411 -0
- tests/spec_decode/__init__.py +13 -0
- tests/spec_decode/test_eagle3.py +311 -0
- tests/test_base.py +14 -0
- tests/test_envs.py +110 -12
- tests/test_tpu_info.py +14 -0
- tests/test_utils.py +2 -45
- tests/worker/__init__.py +13 -0
- tests/worker/tpu_worker_test.py +414 -0
- tpu_inference/__init__.py +14 -0
- tpu_inference/core/__init__.py +13 -0
- tpu_inference/core/sched/__init__.py +13 -0
- tpu_inference/core/sched/dp_scheduler.py +372 -56
- tpu_inference/distributed/__init__.py +13 -0
- tpu_inference/distributed/jax_parallel_state.py +14 -0
- tpu_inference/distributed/tpu_connector.py +15 -10
- tpu_inference/distributed/utils.py +56 -4
- tpu_inference/envs.py +92 -8
- tpu_inference/executors/__init__.py +13 -0
- tpu_inference/executors/ray_distributed_executor.py +25 -4
- tpu_inference/experimental/__init__.py +13 -0
- tpu_inference/experimental/llama3_jax_stashed.py +14 -0
- tpu_inference/kernels/__init__.py +13 -0
- tpu_inference/kernels/collectives/__init__.py +13 -0
- tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
- tpu_inference/kernels/flash_attention/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/kernel.py +807 -230
- tpu_inference/kernels/megablox/__init__.py +13 -0
- tpu_inference/kernels/megablox/common.py +54 -0
- tpu_inference/kernels/megablox/gmm.py +646 -0
- tpu_inference/kernels/mla/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/kernel.py +117 -145
- tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
- tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
- tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +194 -101
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +218 -137
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +376 -195
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +15 -1
- tpu_inference/layers/__init__.py +13 -0
- tpu_inference/layers/common/__init__.py +13 -0
- tpu_inference/layers/common/attention_interface.py +25 -12
- tpu_inference/layers/common/attention_metadata.py +14 -0
- tpu_inference/layers/common/fused_moe_gmm.py +506 -0
- tpu_inference/layers/common/quant_methods.py +15 -0
- tpu_inference/layers/common/quantization.py +282 -0
- tpu_inference/layers/common/sharding.py +32 -9
- tpu_inference/layers/common/utils.py +94 -0
- tpu_inference/layers/jax/__init__.py +13 -0
- tpu_inference/layers/jax/attention/__init__.py +13 -0
- tpu_inference/layers/jax/attention/attention.py +19 -6
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +270 -77
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +24 -11
- tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
- tpu_inference/layers/jax/base.py +14 -0
- tpu_inference/layers/jax/constants.py +13 -0
- tpu_inference/layers/jax/layers.py +14 -0
- tpu_inference/layers/jax/misc.py +14 -0
- tpu_inference/layers/jax/moe/__init__.py +13 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
- tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
- tpu_inference/layers/jax/moe/moe.py +43 -3
- tpu_inference/layers/jax/pp_utils.py +53 -0
- tpu_inference/layers/jax/rope.py +14 -0
- tpu_inference/layers/jax/rope_interface.py +14 -0
- tpu_inference/layers/jax/sample/__init__.py +13 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
- tpu_inference/layers/jax/sample/sampling.py +15 -1
- tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
- tpu_inference/layers/jax/transformer_block.py +14 -0
- tpu_inference/layers/vllm/__init__.py +13 -0
- tpu_inference/layers/vllm/attention.py +4 -4
- tpu_inference/layers/vllm/fused_moe.py +101 -494
- tpu_inference/layers/vllm/linear.py +64 -0
- tpu_inference/layers/vllm/process_weights/__init__.py +13 -0
- tpu_inference/layers/vllm/{sharding.py → process_weights/cleanup_sharding.py} +24 -15
- tpu_inference/layers/vllm/process_weights/fused_moe_weights.py +369 -0
- tpu_inference/layers/vllm/process_weights/linear_weights.py +174 -0
- tpu_inference/layers/vllm/quantization/__init__.py +19 -3
- tpu_inference/layers/vllm/quantization/awq.py +96 -82
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +23 -8
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +172 -176
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +111 -91
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +79 -43
- tpu_inference/layers/vllm/quantization/{common.py → configs.py} +42 -25
- tpu_inference/layers/vllm/quantization/fp8.py +119 -0
- tpu_inference/layers/vllm/quantization/mxfp4.py +137 -178
- tpu_inference/layers/vllm/quantization/unquantized.py +157 -233
- tpu_inference/lora/__init__.py +13 -0
- tpu_inference/lora/torch_lora_ops.py +8 -13
- tpu_inference/models/__init__.py +13 -0
- tpu_inference/models/common/__init__.py +13 -0
- tpu_inference/models/common/model_loader.py +112 -35
- tpu_inference/models/jax/__init__.py +13 -0
- tpu_inference/models/jax/deepseek_v3.py +267 -157
- tpu_inference/models/jax/gpt_oss.py +26 -10
- tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
- tpu_inference/models/jax/llama3.py +99 -36
- tpu_inference/models/jax/llama4.py +14 -0
- tpu_inference/models/jax/llama_eagle3.py +18 -5
- tpu_inference/models/jax/llama_guard_4.py +15 -1
- tpu_inference/models/jax/qwen2.py +17 -2
- tpu_inference/models/jax/qwen2_5_vl.py +179 -51
- tpu_inference/models/jax/qwen3.py +17 -2
- tpu_inference/models/jax/utils/__init__.py +13 -0
- tpu_inference/models/jax/utils/file_utils.py +14 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
- tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
- tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +92 -32
- tpu_inference/models/jax/utils/weight_utils.py +234 -155
- tpu_inference/models/vllm/__init__.py +13 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +32 -8
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
- tpu_inference/platforms/__init__.py +14 -0
- tpu_inference/platforms/tpu_platform.py +51 -72
- tpu_inference/runner/__init__.py +13 -0
- tpu_inference/runner/compilation_manager.py +180 -80
- tpu_inference/runner/kv_cache.py +54 -20
- tpu_inference/runner/kv_cache_manager.py +55 -33
- tpu_inference/runner/lora_utils.py +16 -1
- tpu_inference/runner/multimodal_manager.py +16 -2
- tpu_inference/runner/persistent_batch_manager.py +54 -2
- tpu_inference/runner/speculative_decoding_manager.py +14 -0
- tpu_inference/runner/structured_decoding_manager.py +16 -3
- tpu_inference/runner/tpu_runner.py +124 -61
- tpu_inference/runner/utils.py +2 -2
- tpu_inference/spec_decode/__init__.py +13 -0
- tpu_inference/spec_decode/jax/__init__.py +13 -0
- tpu_inference/spec_decode/jax/eagle3.py +84 -22
- tpu_inference/tpu_info.py +14 -0
- tpu_inference/utils.py +72 -44
- tpu_inference/worker/__init__.py +13 -0
- tpu_inference/worker/tpu_worker.py +66 -52
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/METADATA +8 -9
- tpu_inference-0.13.2.dev20251230.dist-info/RECORD +266 -0
- tpu_inference/layers/vllm/linear_common.py +0 -186
- tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
- tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
- tpu_inference-0.11.1.dev202511220812.dist-info/RECORD +0 -174
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/WHEEL +0 -0
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/top_level.txt +0 -0
--- a/tpu_inference/runner/kv_cache_manager.py
+++ b/tpu_inference/runner/kv_cache_manager.py
@@ -1,5 +1,19 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import functools
-from typing import TYPE_CHECKING,
+from typing import TYPE_CHECKING, List
 
 import jax
 import jax.numpy as jnp
@@ -7,8 +21,8 @@ import numpy as np
 import vllm.envs as envs
 from jax.sharding import NamedSharding, PartitionSpec
 from torchax.ops.mappings import t2j_dtype
-from vllm.attention import Attention
 from vllm.attention.backends.abstract import AttentionType
+from vllm.attention.layer import Attention
 from vllm.config import get_layers_from_vllm_config
 from vllm.utils.math_utils import cdiv
 from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
@@ -39,20 +53,30 @@ class KVCacheManager:
         # means this layer will perform attention using the keys and values
         # from the KV cache of `shared_kv_cache_layers[layer_name]`.
         self.shared_kv_cache_layers: dict[str, str] = {}
+        self.use_mla = self.runner.model_config.use_mla
 
     def get_kv_cache_spec(self):
         # TODO(xiang): this hack tricks engine core to init successfully
         block_size = self.runner.cache_config.block_size
-        use_mla = self.runner.model_config.use_mla
         kv_cache_spec: dict[str, KVCacheSpec] = {}
 
         # If use pure jax (MODEL_IMPL_TYPE=flax_nnx), we don't register
         # attention into compilation config.
         # Use FullAttentionSpec for each layer
         # TODO(pooyam): Is it possible to merge the logic for vllm and non-vllm models?
+        model_config = self.runner.model_config
+        if self.use_mla:
+            # Individually pad the RopE and latents
+            qk_rope_head_dim = getattr(model_config.hf_text_config,
+                                       "qk_rope_head_dim", 0)
+            padded_kv_lora_rank = common_utils.align_to(
+                model_config.hf_text_config.kv_lora_rank, 128)
+            padded_qk_rope_head_dim = common_utils.align_to(
+                qk_rope_head_dim, 128)
+            mla_head_size = padded_kv_lora_rank + padded_qk_rope_head_dim
+
         if len(self.runner.vllm_config.compilation_config.
                static_forward_context) == 0:
-            model_config = self.runner.model_config
             parallel_config = self.runner.parallel_config
             # Pad num_kv_heads to multiple of TP size.
             num_kv_heads = common_utils.get_padded_num_heads(
@@ -61,11 +85,11 @@ class KVCacheManager:
             head_size = common_utils.get_padded_head_dim(
                 model_config.get_head_size())
             for i in range(model_config.get_num_layers(parallel_config)):
-                if use_mla:
+                if self.use_mla:
                     kv_cache_spec[f"layer.{i}"] = MLAAttentionSpec(
                         block_size=block_size,
-                        num_kv_heads=
-                        head_size=
+                        num_kv_heads=1,
+                        head_size=mla_head_size,
                         dtype=self.runner.kv_cache_dtype,
                         cache_dtype_str=self.runner.vllm_config.cache_config.
                         cache_dtype)
@@ -83,14 +107,13 @@ class KVCacheManager:
                 self.runner.mesh.shape["model"])
             head_size = common_utils.get_padded_head_dim(
                 hf_config.hidden_size // hf_config.num_attention_heads)
-
             # Eagle3 has only 1 layer
             for i in range(1):
-                if use_mla:
-                    kv_cache_spec[f"
+                if self.use_mla:
+                    kv_cache_spec[f"draft_layer.{i}"] = MLAAttentionSpec(
                         block_size=block_size,
-                        num_kv_heads=
-                        head_size=
+                        num_kv_heads=1,
+                        head_size=mla_head_size,
                         dtype=self.runner.kv_cache_dtype,
                         cache_dtype_str=self.runner.vllm_config.
                         cache_config.cache_dtype)
@@ -104,6 +127,7 @@ class KVCacheManager:
         # Else propagate attention modules from compilation config.
         layers = get_layers_from_vllm_config(self.runner.vllm_config,
                                              Attention)
+        logger.warning(f"Compilation num_layers = {len(layers.items())}")
        for layer_name, attn_module in layers.items():
             if (kv_tgt_layer :=
                     attn_module.kv_sharing_target_layer_name) is not None:
@@ -127,11 +151,11 @@ class KVCacheManager:
                         attn_module.head_size),
                     dtype=self.runner.kv_cache_dtype,
                     sliding_window=attn_module.sliding_window)
-            elif use_mla:
-                kv_cache_spec[
+            elif self.use_mla:
+                kv_cache_spec[layer_name] = MLAAttentionSpec(
                     block_size=block_size,
-                    num_kv_heads=
-                    head_size=
+                    num_kv_heads=1,
+                    head_size=mla_head_size,
                     dtype=self.runner.kv_cache_dtype,
                     cache_dtype_str=self.runner.vllm_config.
                     cache_config.cache_dtype)
@@ -188,7 +212,6 @@ class KVCacheManager:
         # uniform page size.
         representative_spec = kv_cache_config.kv_cache_groups[0].kv_cache_spec
         page_size_bytes = representative_spec.page_size_bytes
-        self.runner.layer_name_to_kvcache_index: Dict[str, int] = {}
         kv_caches = self.runner.kv_caches
         num_blocks_list = []
         for i, kv_cache_tensor in enumerate(kv_cache_config.kv_cache_tensors):
@@ -198,14 +221,20 @@ class KVCacheManager:
             # num_blocks must be a multiple of dp_size
             num_blocks = (num_blocks // dp_size) * dp_size
             # NOTE: we'll multiply the num_kv_heads by 2 in the function
+            if self.use_mla:
+                head_size = self.runner.model_config.hf_config.kv_lora_rank + \
+                    self.runner.model_config.hf_config.qk_rope_head_dim
+            else:
+                head_size = representative_spec.head_size
             kv_cache = create_kv_caches(
                 num_blocks=num_blocks,
                 block_size=representative_spec.block_size,
                 num_kv_heads=representative_spec.num_kv_heads,
-                head_size=
+                head_size=head_size,
                 mesh=self.runner.mesh,
                 layer_names=[f'kv_cache_tensor.{i}'],
                 cache_dtype=t2j_dtype(representative_spec.dtype),
+                use_mla=self.use_mla,
             )[0]
             kv_caches.append(kv_cache)
             num_blocks_list.append(num_blocks)
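The MLA branch above sizes each cache entry by padding the KV latent width and the decoupled RoPE width separately and then summing them. A minimal sketch of that arithmetic, using a local `align_to` stand-in for the package's `common_utils.align_to` helper (assumed here to round up to the next multiple) and illustrative DeepSeek-V3-style dimensions:

```python
def align_to(x: int, multiple: int) -> int:
    # Stand-in for common_utils.align_to: round x up to a multiple.
    return ((x + multiple - 1) // multiple) * multiple

# Illustrative DeepSeek-V3-style MLA dimensions.
kv_lora_rank = 512      # compressed KV latent width
qk_rope_head_dim = 64   # decoupled RoPE width

padded_kv_lora_rank = align_to(kv_lora_rank, 128)          # 512
padded_qk_rope_head_dim = align_to(qk_rope_head_dim, 128)  # 128

# Each MLA cache entry stores latent and RoPE parts side by side.
mla_head_size = padded_kv_lora_rank + padded_qk_rope_head_dim
print(mla_head_size)  # 640
```

With `num_kv_heads=1`, this single padded dimension is what the `MLAAttentionSpec` entries above describe per token.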
@@ -289,13 +318,8 @@ class KVCacheManager:
 
         def _update_layer(cache, slices):
             """The function to apply to each layer's cache and slices."""
-            reshaped_slices = slices.reshape(-1,
-
-            for (i, block_idx) in enumerate(block_numbers):
-                cache = jax.lax.dynamic_update_slice_in_dim(cache,
-                                                            reshaped_slices[i],
-                                                            block_idx,
-                                                            axis=0)
+            reshaped_slices = slices.reshape(-1, block_size, *slices.shape[1:])
+            cache.at[block_numbers].set(reshaped_slices)
             return cache
 
         return jax.tree.map(_update_layer, kv_caches, kv_cache_slices)
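The rewritten `_update_layer` above swaps a per-block `jax.lax.dynamic_update_slice_in_dim` loop for one batched indexed update. A self-contained sketch of that pattern with made-up shapes (JAX arrays are immutable, so `.at[...].set(...)` returns a new array that must be kept):

```python
import jax.numpy as jnp

block_size, num_heads, head_dim = 4, 2, 8
cache = jnp.zeros((16, block_size, num_heads, head_dim))  # 16 blocks

# Incoming slices cover 3 blocks' worth of tokens, flattened on axis 0.
slices = jnp.ones((3 * block_size, num_heads, head_dim))
block_numbers = jnp.array([5, 9, 11])

# Reshape to (num_blocks, block_size, ...) and scatter all blocks into the
# cache in one fused update instead of looping block by block.
reshaped = slices.reshape(-1, block_size, *slices.shape[1:])
cache = cache.at[block_numbers].set(reshaped)
```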
@@ -348,16 +372,12 @@ class KVCacheManager:
         """
         if block_ids == list(range(block_ids[0],
                                    block_ids[0] + len(block_ids))):
-
-
-            batched_kv_cache_per_layer = self._jitted_gather_continuous_kv_cache(
-                self.runner.kv_caches, block_ids[0], len(block_ids))
+            batched_kv_cache_per_layer = self._jitted_gather_continuous_kv_cache(
+                self.runner.kv_caches, block_ids[0], len(block_ids))
 
         else:
-
-
-            batched_kv_cache_per_layer = self._jitted_gather_kv_cache(
-                self.runner.kv_caches, jnp.array(block_ids))
+            batched_kv_cache_per_layer = self._jitted_gather_kv_cache(
+                self.runner.kv_caches, jnp.array(block_ids))
         return batched_kv_cache_per_layer
 
     def transfer_kv_cache(self,
@@ -446,6 +466,7 @@ class KVCacheManager:
                     kv_cache_slices,
                     start_block,
                 )
+                jax.block_until_ready(self.runner.kv_caches)
         else:
             with runner_utils.LatencyTracker(
                     f"JittedInsertKVCache-b{len(block_numbers)}"):
@@ -457,6 +478,7 @@ class KVCacheManager:
                     kv_cache_slices,
                     jnp.array(block_numbers),
                 )
+                jax.block_until_ready(self.runner.kv_caches)
 
         logger.debug(
             f"Updated kv cache entries cnt={len(self.runner.kv_caches)}")
--- a/tpu_inference/runner/lora_utils.py
+++ b/tpu_inference/runner/lora_utils.py
@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from __future__ import annotations
 
 from typing import TYPE_CHECKING
@@ -7,7 +21,8 @@ from torchax.interop import jax_view
 from vllm.lora.layers.base_linear import BaseLinearLayerWithLoRA
 from vllm.lora.request import LoRARequest
 
-from tpu_inference.layers.vllm.
+from tpu_inference.layers.vllm.process_weights.cleanup_sharding import \
+    update_lora
 
 if TYPE_CHECKING:
     from tpu_inference.runner.tpu_runner import TPUModelRunner
--- a/tpu_inference/runner/multimodal_manager.py
+++ b/tpu_inference/runner/multimodal_manager.py
@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from typing import TYPE_CHECKING
 
 import jax
@@ -98,7 +112,7 @@ class MultiModalManager:
         # encoder outputs.
         encoder_outputs = []
         for _, num_items, mm_kwargs_group in group_mm_kwargs_by_modality(
-                mm_kwargs
+                mm_kwargs):
             batched_mm_inputs = mm_kwargs_group
             # Convert torch tensors to numpy arrays that JAX can handle.
             if "pixel_values" in batched_mm_inputs and isinstance(
@@ -134,7 +148,7 @@ class MultiModalManager:
             # 2. A list or tuple (length: num_items) of tensors, each of shape
             # (feature_size, hidden_size) in case the feature size is dynamic
             # depending on the input multimodal items.
-            curr_group_outputs = self.runner.
+            curr_group_outputs = self.runner.embed_multimodal_fn(
                 self.runner.state, image_grid_thw, **batched_mm_inputs)
 
             sanity_check_mm_encoder_outputs(
--- a/tpu_inference/runner/persistent_batch_manager.py
+++ b/tpu_inference/runner/persistent_batch_manager.py
@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from typing import Dict
 
 import jax
@@ -14,12 +28,13 @@ class PersistentBatchManager:
     def __init__(self, requests: Dict[str, CachedRequestState],
                  input_batch: InputBatch, encoder_cache: Dict[str,
                                                               'jax.Array'],
-                 uses_mrope: bool, model_config):
+                 uses_mrope: bool, model_config, is_last_rank: bool):
         self.requests = requests
         self.input_batch = input_batch
         self.encoder_cache = encoder_cache
         self.uses_mrope = uses_mrope
         self.model_config = model_config
+        self.is_last_rank = is_last_rank
 
     def _reorder_batch(self, scheduler_output: "VllmSchedulerOutput") -> int:
         """ Reorder the sheduled requests to RPA kernel friendly distribution
@@ -179,9 +194,35 @@ class PersistentBatchManager:
             num_computed_tokens = req_data.num_computed_tokens[i]
             new_block_ids = req_data.new_block_ids[i]
             resumed_from_preemption = req_data.resumed_from_preemption[i]
+            num_output_tokens = req_data.num_output_tokens[i]
 
             # Update the cached states.
             req_state.num_computed_tokens = num_computed_tokens
+            req_index = self.input_batch.req_id_to_index.get(req_id)
+
+            if not self.is_last_rank:
+                # When using PP, the scheduler sends the sampled tokens back,
+                # because there's no direct communication between the first-
+                # stage worker and the last-stage worker.
+                new_token_ids = req_data.new_token_ids[i]
+                # Add the sampled token(s) from the previous step (if any).
+                # This doesn't include "unverified" tokens like spec tokens.
+                num_new_tokens = (num_computed_tokens + len(new_token_ids) -
+                                  req_state.num_tokens)
+                if num_new_tokens == 1:
+                    req_state.output_token_ids.append(new_token_ids[-1])
+                elif num_new_tokens > 0:
+                    req_state.output_token_ids.extend(
+                        new_token_ids[-num_new_tokens:])
+                elif num_output_tokens < len(req_state.output_token_ids):
+                    del req_state.output_token_ids[num_output_tokens:]
+                if req_index is not None:
+                    end_idx = (self.input_batch.num_prompt_tokens[req_index] +
+                               num_output_tokens)
+                    self.input_batch.num_tokens[req_index] = end_idx
+                    self.input_batch.num_tokens_no_spec[req_index] = end_idx
+
+            # Update the block IDs.
             if not resumed_from_preemption:
                 if new_block_ids is not None:
                     # Append the new blocks to the existing block IDs.
@@ -194,7 +235,6 @@ class PersistentBatchManager:
                 # Replace the existing block IDs with the new ones.
                 req_state.block_ids = new_block_ids
 
-            req_index = self.input_batch.req_id_to_index.get(req_id)
             if req_index is None:
                 # The request is not in the persistent batch.
                 # The request was either preempted and resumed later, or was not
@@ -209,6 +249,18 @@ class PersistentBatchManager:
                 self.input_batch.block_table.append_row(
                     new_block_ids, req_index)
 
+            # For the last rank, we don't need to update the token_ids_cpu
+            # because the sampled tokens are already cached.
+            if not self.is_last_rank:
+                start_token_index = num_computed_tokens
+                end_token_index = num_computed_tokens + len(new_token_ids)
+                self.input_batch.token_ids_cpu[
+                    req_index,
+                    start_token_index:end_token_index] = new_token_ids
+                self.input_batch.num_tokens_no_spec[
+                    req_index] = end_token_index
+                self.input_batch.num_tokens[req_index] = end_token_index
+
             # Add spec_token_ids to token_ids_cpu.
             spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get(
                 req_id, ())
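On non-last pipeline-parallel ranks the sampled tokens come back from the scheduler rather than from a local sampler, and the hunks above reconcile them against the cached request state. A simplified, standalone sketch of that bookkeeping; `ReqState` and `apply_scheduler_tokens` are hypothetical stand-ins for the runner's `CachedRequestState` and update loop, not the package's API:

```python
from dataclasses import dataclass, field

@dataclass
class ReqState:  # hypothetical stand-in for CachedRequestState
    prompt_token_ids: list[int]
    output_token_ids: list[int] = field(default_factory=list)

    @property
    def num_tokens(self) -> int:
        return len(self.prompt_token_ids) + len(self.output_token_ids)

def apply_scheduler_tokens(req: ReqState, num_computed_tokens: int,
                           new_token_ids: list[int], num_output_tokens: int):
    # Tokens the scheduler has accounted for that this rank hasn't recorded yet.
    num_new = num_computed_tokens + len(new_token_ids) - req.num_tokens
    if num_new == 1:
        req.output_token_ids.append(new_token_ids[-1])
    elif num_new > 0:
        req.output_token_ids.extend(new_token_ids[-num_new:])
    elif num_output_tokens < len(req.output_token_ids):
        # The scheduler rolled the request back (e.g. rejected draft tokens).
        del req.output_token_ids[num_output_tokens:]

req = ReqState(prompt_token_ids=[1, 2, 3], output_token_ids=[7])
apply_scheduler_tokens(req, num_computed_tokens=4, new_token_ids=[9],
                       num_output_tokens=2)
print(req.output_token_ids)  # [7, 9]
```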
--- a/tpu_inference/runner/speculative_decoding_manager.py
+++ b/tpu_inference/runner/speculative_decoding_manager.py
@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from __future__ import annotations
 
 from dataclasses import dataclass
--- a/tpu_inference/runner/structured_decoding_manager.py
+++ b/tpu_inference/runner/structured_decoding_manager.py
@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import functools
 from typing import TYPE_CHECKING, Tuple
 
@@ -61,11 +75,10 @@ class StructuredDecodingManager:
         self.runner.require_structured_out_cpu.fill(0)
 
         sorted_struct_requests = sorted(
-            grammar_output.structured_output_request_ids
-            key=lambda item: item[1])
+            grammar_output.structured_output_request_ids)
 
         cumulative_mask_idx = 0
-        for req_id
+        for req_id in sorted_struct_requests:
             if req_id not in self.runner.input_batch.req_id_to_index:
                 continue
             batch_index = self.runner.input_batch.req_id_to_index[req_id]