tpu-inference 0.11.1.dev202511270815__py3-none-any.whl → 0.13.0rc2.post7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tpu-inference might be problematic. Click here for more details.
- tests/__init__.py +13 -0
- tests/core/__init__.py +13 -0
- tests/core/test_disagg_utils.py +14 -0
- tests/core/test_dp_scheduler.py +650 -768
- tests/core/test_init.py +14 -0
- tests/distributed/__init__.py +13 -0
- tests/distributed/test_distributed_utils.py +120 -0
- tests/distributed/test_tpu_connector.py +478 -0
- tests/e2e/__init__.py +13 -0
- tests/e2e/test_async_scheduler.py +211 -0
- tests/e2e/test_data_parallel.py +289 -0
- tests/e2e/test_hybrid_kvcache.py +219 -0
- tests/e2e/test_local_disagg.py +257 -0
- tests/e2e/test_model_loader.py +268 -0
- tests/e2e/test_multi_modal_inference.py +111 -0
- tests/e2e/test_pipeline_parallel.py +265 -0
- tests/e2e/test_runai_model_streamer_loader.py +104 -0
- tests/e2e/test_sampling_params.py +269 -0
- tests/e2e/test_speculative_decoding.py +311 -0
- tests/e2e/test_structured_decoding.py +46 -0
- tests/executors/__init__.py +13 -0
- tests/executors/test_ray_distributed_executor.py +199 -0
- tests/experimental/__init__.py +13 -0
- tests/experimental/test_llama3_jax_stashed.py +208 -0
- tests/kernels/__init__.py +13 -0
- tests/kernels/collectives/__init__.py +13 -0
- tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
- tests/kernels/fused_moe_v1_test.py +14 -0
- tests/kernels/gmm_test.py +205 -0
- tests/kernels/mla_v1_test.py +143 -41
- tests/kernels/quantized_matmul_kernel_test.py +2 -34
- tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +17 -1
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +17 -1
- tests/layers/__init__.py +13 -0
- tests/layers/common/__init__.py +13 -0
- tests/layers/common/test_attention_interface.py +156 -0
- tests/layers/common/test_quantization.py +149 -0
- tests/layers/jax/__init__.py +13 -0
- tests/layers/jax/attention/__init__.py +13 -0
- tests/layers/jax/attention/test_common_attention.py +103 -0
- tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
- tests/layers/jax/attention/test_llama4_attention.py +135 -0
- tests/layers/jax/moe/__init__.py +13 -0
- tests/layers/jax/moe/test_deepseek_moe.py +235 -0
- tests/layers/jax/sample/__init__.py +13 -0
- tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
- tests/layers/jax/sample/test_sampling.py +115 -0
- tests/layers/jax/sample/test_sampling_metadata.py +254 -0
- tests/layers/jax/test_layers.py +155 -0
- tests/{test_quantization.py → layers/jax/test_qwix.py} +183 -50
- tests/layers/jax/test_rope.py +93 -0
- tests/layers/jax/test_sharding.py +159 -0
- tests/layers/jax/test_transformer_block.py +152 -0
- tests/layers/vllm/__init__.py +13 -0
- tests/layers/vllm/test_attention.py +363 -0
- tests/layers/vllm/test_awq.py +405 -0
- tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +418 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +441 -0
- tests/layers/vllm/test_fp8.py +17 -0
- tests/layers/vllm/test_mxfp4.py +312 -0
- tests/layers/vllm/test_unquantized.py +651 -0
- tests/layers/vllm/utils.py +87 -0
- tests/lora/__init__.py +13 -0
- tests/lora/conftest.py +14 -0
- tests/lora/test_bgmv.py +14 -0
- tests/lora/test_layers.py +21 -3
- tests/lora/test_lora.py +15 -1
- tests/lora/test_lora_perf.py +67 -0
- tests/models/__init__.py +13 -0
- tests/models/common/__init__.py +13 -0
- tests/models/common/test_model_loader.py +455 -0
- tests/models/jax/__init__.py +13 -0
- tests/models/jax/test_deepseek_v3.py +401 -0
- tests/models/jax/test_llama3.py +184 -0
- tests/models/jax/test_llama4.py +298 -0
- tests/models/jax/test_llama_eagle3.py +197 -0
- tests/models/jax/test_llama_guard_4.py +242 -0
- tests/models/jax/test_qwen2.py +172 -0
- tests/models/jax/test_qwen2_5_vl.py +605 -0
- tests/models/jax/test_qwen3.py +169 -0
- tests/models/jax/test_weight_loading.py +180 -0
- tests/models/jax/utils/__init__.py +13 -0
- tests/models/jax/utils/test_multi_modal_utils.py +212 -0
- tests/platforms/__init__.py +13 -0
- tests/platforms/test_tpu_platform.py +54 -0
- tests/runner/__init__.py +13 -0
- tests/runner/test_block_table.py +395 -0
- tests/runner/test_input_batch.py +226 -0
- tests/runner/test_kv_cache.py +220 -0
- tests/runner/test_kv_cache_manager.py +498 -0
- tests/runner/test_multimodal_manager.py +429 -0
- tests/runner/test_persistent_batch_manager.py +84 -0
- tests/runner/test_speculative_decoding_manager.py +368 -0
- tests/runner/test_structured_decoding_manager.py +220 -0
- tests/runner/test_tpu_runner.py +261 -0
- tests/runner/test_tpu_runner_dp.py +1099 -0
- tests/runner/test_tpu_runner_mesh.py +200 -0
- tests/runner/test_utils.py +411 -0
- tests/spec_decode/__init__.py +13 -0
- tests/spec_decode/test_eagle3.py +311 -0
- tests/test_base.py +14 -0
- tests/test_envs.py +110 -12
- tests/test_tpu_info.py +14 -0
- tests/test_utils.py +2 -45
- tests/worker/__init__.py +13 -0
- tests/worker/tpu_worker_test.py +414 -0
- tpu_inference/__init__.py +14 -0
- tpu_inference/core/__init__.py +13 -0
- tpu_inference/core/sched/__init__.py +13 -0
- tpu_inference/core/sched/dp_scheduler.py +372 -56
- tpu_inference/distributed/__init__.py +13 -0
- tpu_inference/distributed/jax_parallel_state.py +14 -0
- tpu_inference/distributed/tpu_connector.py +15 -10
- tpu_inference/distributed/utils.py +56 -4
- tpu_inference/envs.py +92 -8
- tpu_inference/executors/__init__.py +13 -0
- tpu_inference/executors/ray_distributed_executor.py +22 -1
- tpu_inference/experimental/__init__.py +13 -0
- tpu_inference/experimental/llama3_jax_stashed.py +14 -0
- tpu_inference/kernels/__init__.py +13 -0
- tpu_inference/kernels/collectives/__init__.py +13 -0
- tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
- tpu_inference/kernels/flash_attention/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/kernel.py +370 -324
- tpu_inference/kernels/megablox/__init__.py +13 -0
- tpu_inference/kernels/megablox/common.py +54 -0
- tpu_inference/kernels/megablox/gmm.py +646 -0
- tpu_inference/kernels/mla/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/kernel.py +117 -145
- tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
- tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
- tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +194 -101
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +167 -97
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +376 -195
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +15 -1
- tpu_inference/layers/__init__.py +13 -0
- tpu_inference/layers/common/__init__.py +13 -0
- tpu_inference/layers/common/attention_interface.py +26 -19
- tpu_inference/layers/common/attention_metadata.py +14 -0
- tpu_inference/layers/common/quant_methods.py +15 -0
- tpu_inference/layers/common/quantization.py +270 -0
- tpu_inference/layers/common/sharding.py +31 -9
- tpu_inference/layers/jax/__init__.py +13 -0
- tpu_inference/layers/jax/attention/__init__.py +13 -0
- tpu_inference/layers/jax/attention/attention.py +19 -6
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +270 -77
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +24 -11
- tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
- tpu_inference/layers/jax/base.py +14 -0
- tpu_inference/layers/jax/constants.py +13 -0
- tpu_inference/layers/jax/layers.py +14 -0
- tpu_inference/layers/jax/misc.py +14 -0
- tpu_inference/layers/jax/moe/__init__.py +13 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
- tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
- tpu_inference/layers/jax/moe/moe.py +43 -3
- tpu_inference/layers/jax/pp_utils.py +53 -0
- tpu_inference/layers/jax/rope.py +14 -0
- tpu_inference/layers/jax/rope_interface.py +14 -0
- tpu_inference/layers/jax/sample/__init__.py +13 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
- tpu_inference/layers/jax/sample/sampling.py +15 -1
- tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
- tpu_inference/layers/jax/transformer_block.py +14 -0
- tpu_inference/layers/vllm/__init__.py +13 -0
- tpu_inference/layers/vllm/attention.py +4 -4
- tpu_inference/layers/vllm/fused_moe.py +210 -260
- tpu_inference/layers/vllm/linear_common.py +57 -22
- tpu_inference/layers/vllm/quantization/__init__.py +16 -0
- tpu_inference/layers/vllm/quantization/awq.py +15 -1
- tpu_inference/layers/vllm/quantization/common.py +33 -18
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +18 -3
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +211 -148
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +14 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +14 -0
- tpu_inference/layers/vllm/quantization/fp8.py +118 -0
- tpu_inference/layers/vllm/quantization/mxfp4.py +280 -210
- tpu_inference/layers/vllm/quantization/unquantized.py +134 -86
- tpu_inference/layers/vllm/sharding.py +21 -4
- tpu_inference/lora/__init__.py +13 -0
- tpu_inference/lora/torch_lora_ops.py +8 -13
- tpu_inference/models/__init__.py +13 -0
- tpu_inference/models/common/__init__.py +13 -0
- tpu_inference/models/common/model_loader.py +77 -36
- tpu_inference/models/jax/__init__.py +13 -0
- tpu_inference/models/jax/deepseek_v3.py +267 -157
- tpu_inference/models/jax/gpt_oss.py +26 -10
- tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
- tpu_inference/models/jax/llama3.py +99 -36
- tpu_inference/models/jax/llama4.py +14 -0
- tpu_inference/models/jax/llama_eagle3.py +14 -0
- tpu_inference/models/jax/llama_guard_4.py +15 -1
- tpu_inference/models/jax/qwen2.py +17 -2
- tpu_inference/models/jax/qwen2_5_vl.py +18 -4
- tpu_inference/models/jax/qwen3.py +17 -2
- tpu_inference/models/jax/utils/__init__.py +13 -0
- tpu_inference/models/jax/utils/file_utils.py +14 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
- tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
- tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +91 -31
- tpu_inference/models/jax/utils/weight_utils.py +39 -2
- tpu_inference/models/vllm/__init__.py +13 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +20 -4
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
- tpu_inference/platforms/__init__.py +14 -0
- tpu_inference/platforms/tpu_platform.py +47 -71
- tpu_inference/runner/__init__.py +13 -0
- tpu_inference/runner/compilation_manager.py +158 -63
- tpu_inference/runner/kv_cache.py +54 -20
- tpu_inference/runner/kv_cache_manager.py +53 -30
- tpu_inference/runner/lora_utils.py +14 -0
- tpu_inference/runner/multimodal_manager.py +15 -1
- tpu_inference/runner/persistent_batch_manager.py +54 -2
- tpu_inference/runner/speculative_decoding_manager.py +14 -0
- tpu_inference/runner/structured_decoding_manager.py +14 -0
- tpu_inference/runner/tpu_runner.py +105 -57
- tpu_inference/runner/utils.py +2 -2
- tpu_inference/spec_decode/__init__.py +13 -0
- tpu_inference/spec_decode/jax/__init__.py +13 -0
- tpu_inference/spec_decode/jax/eagle3.py +65 -19
- tpu_inference/tpu_info.py +14 -0
- tpu_inference/utils.py +72 -44
- tpu_inference/worker/__init__.py +13 -0
- tpu_inference/worker/tpu_worker.py +65 -52
- {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/METADATA +11 -9
- tpu_inference-0.13.0rc2.post7.dist-info/RECORD +261 -0
- tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
- tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
- tpu_inference-0.11.1.dev202511270815.dist-info/RECORD +0 -174
- {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/WHEEL +0 -0
- {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/top_level.txt +0 -0
|
@@ -1,3 +1,17 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
1
15
|
import functools
|
|
2
16
|
from typing import TYPE_CHECKING, Dict, List
|
|
3
17
|
|
|
@@ -39,20 +53,30 @@ class KVCacheManager:
|
|
|
39
53
|
# means this layer will perform attention using the keys and values
|
|
40
54
|
# from the KV cache of `shared_kv_cache_layers[layer_name]`.
|
|
41
55
|
self.shared_kv_cache_layers: dict[str, str] = {}
|
|
56
|
+
self.use_mla = self.runner.model_config.use_mla
|
|
42
57
|
|
|
43
58
|
def get_kv_cache_spec(self):
|
|
44
59
|
# TODO(xiang): this hack tricks engine core to init successfully
|
|
45
60
|
block_size = self.runner.cache_config.block_size
|
|
46
|
-
use_mla = self.runner.model_config.use_mla
|
|
47
61
|
kv_cache_spec: dict[str, KVCacheSpec] = {}
|
|
48
62
|
|
|
49
63
|
# If use pure jax (MODEL_IMPL_TYPE=flax_nnx), we don't register
|
|
50
64
|
# attention into compilation config.
|
|
51
65
|
# Use FullAttentionSpec for each layer
|
|
52
66
|
# TODO(pooyam): Is it possible to merge the logic for vllm and non-vllm models?
|
|
67
|
+
model_config = self.runner.model_config
|
|
68
|
+
if self.use_mla:
|
|
69
|
+
# Individually pad the RopE and latents
|
|
70
|
+
qk_rope_head_dim = getattr(model_config.hf_text_config,
|
|
71
|
+
"qk_rope_head_dim", 0)
|
|
72
|
+
padded_kv_lora_rank = common_utils.align_to(
|
|
73
|
+
model_config.hf_text_config.kv_lora_rank, 128)
|
|
74
|
+
padded_qk_rope_head_dim = common_utils.align_to(
|
|
75
|
+
qk_rope_head_dim, 128)
|
|
76
|
+
mla_head_size = padded_kv_lora_rank + padded_qk_rope_head_dim
|
|
77
|
+
|
|
53
78
|
if len(self.runner.vllm_config.compilation_config.
|
|
54
79
|
static_forward_context) == 0:
|
|
55
|
-
model_config = self.runner.model_config
|
|
56
80
|
parallel_config = self.runner.parallel_config
|
|
57
81
|
# Pad num_kv_heads to multiple of TP size.
|
|
58
82
|
num_kv_heads = common_utils.get_padded_num_heads(
|
|
@@ -61,11 +85,11 @@ class KVCacheManager:
|
|
|
61
85
|
head_size = common_utils.get_padded_head_dim(
|
|
62
86
|
model_config.get_head_size())
|
|
63
87
|
for i in range(model_config.get_num_layers(parallel_config)):
|
|
64
|
-
if use_mla:
|
|
88
|
+
if self.use_mla:
|
|
65
89
|
kv_cache_spec[f"layer.{i}"] = MLAAttentionSpec(
|
|
66
90
|
block_size=block_size,
|
|
67
|
-
num_kv_heads=
|
|
68
|
-
head_size=
|
|
91
|
+
num_kv_heads=1,
|
|
92
|
+
head_size=mla_head_size,
|
|
69
93
|
dtype=self.runner.kv_cache_dtype,
|
|
70
94
|
cache_dtype_str=self.runner.vllm_config.cache_config.
|
|
71
95
|
cache_dtype)
|
|
@@ -83,14 +107,13 @@ class KVCacheManager:
|
|
|
83
107
|
self.runner.mesh.shape["model"])
|
|
84
108
|
head_size = common_utils.get_padded_head_dim(
|
|
85
109
|
hf_config.hidden_size // hf_config.num_attention_heads)
|
|
86
|
-
|
|
87
110
|
# Eagle3 has only 1 layer
|
|
88
111
|
for i in range(1):
|
|
89
|
-
if use_mla:
|
|
90
|
-
kv_cache_spec[f"
|
|
112
|
+
if self.use_mla:
|
|
113
|
+
kv_cache_spec[f"draft_layer.{i}"] = MLAAttentionSpec(
|
|
91
114
|
block_size=block_size,
|
|
92
|
-
num_kv_heads=
|
|
93
|
-
head_size=
|
|
115
|
+
num_kv_heads=1,
|
|
116
|
+
head_size=mla_head_size,
|
|
94
117
|
dtype=self.runner.kv_cache_dtype,
|
|
95
118
|
cache_dtype_str=self.runner.vllm_config.
|
|
96
119
|
cache_config.cache_dtype)
|
|
@@ -104,6 +127,7 @@ class KVCacheManager:
|
|
|
104
127
|
# Else propagate attention modules from compilation config.
|
|
105
128
|
layers = get_layers_from_vllm_config(self.runner.vllm_config,
|
|
106
129
|
Attention)
|
|
130
|
+
logger.warning(f"Compilation num_layers = {len(layers.items())}")
|
|
107
131
|
for layer_name, attn_module in layers.items():
|
|
108
132
|
if (kv_tgt_layer :=
|
|
109
133
|
attn_module.kv_sharing_target_layer_name) is not None:
|
|
@@ -127,11 +151,11 @@ class KVCacheManager:
|
|
|
127
151
|
attn_module.head_size),
|
|
128
152
|
dtype=self.runner.kv_cache_dtype,
|
|
129
153
|
sliding_window=attn_module.sliding_window)
|
|
130
|
-
elif use_mla:
|
|
131
|
-
kv_cache_spec[
|
|
154
|
+
elif self.use_mla:
|
|
155
|
+
kv_cache_spec[layer_name] = MLAAttentionSpec(
|
|
132
156
|
block_size=block_size,
|
|
133
|
-
num_kv_heads=
|
|
134
|
-
head_size=
|
|
157
|
+
num_kv_heads=1,
|
|
158
|
+
head_size=mla_head_size,
|
|
135
159
|
dtype=self.runner.kv_cache_dtype,
|
|
136
160
|
cache_dtype_str=self.runner.vllm_config.
|
|
137
161
|
cache_config.cache_dtype)
|
|
@@ -198,14 +222,20 @@ class KVCacheManager:
|
|
|
198
222
|
# num_blocks must be a multiple of dp_size
|
|
199
223
|
num_blocks = (num_blocks // dp_size) * dp_size
|
|
200
224
|
# NOTE: we'll multiply the num_kv_heads by 2 in the function
|
|
225
|
+
if self.use_mla:
|
|
226
|
+
head_size = self.runner.model_config.hf_config.kv_lora_rank + \
|
|
227
|
+
self.runner.model_config.hf_config.qk_rope_head_dim
|
|
228
|
+
else:
|
|
229
|
+
head_size = representative_spec.head_size
|
|
201
230
|
kv_cache = create_kv_caches(
|
|
202
231
|
num_blocks=num_blocks,
|
|
203
232
|
block_size=representative_spec.block_size,
|
|
204
233
|
num_kv_heads=representative_spec.num_kv_heads,
|
|
205
|
-
head_size=
|
|
234
|
+
head_size=head_size,
|
|
206
235
|
mesh=self.runner.mesh,
|
|
207
236
|
layer_names=[f'kv_cache_tensor.{i}'],
|
|
208
237
|
cache_dtype=t2j_dtype(representative_spec.dtype),
|
|
238
|
+
use_mla=self.use_mla,
|
|
209
239
|
)[0]
|
|
210
240
|
kv_caches.append(kv_cache)
|
|
211
241
|
num_blocks_list.append(num_blocks)
|
|
@@ -289,13 +319,8 @@ class KVCacheManager:
|
|
|
289
319
|
|
|
290
320
|
def _update_layer(cache, slices):
|
|
291
321
|
"""The function to apply to each layer's cache and slices."""
|
|
292
|
-
reshaped_slices = slices.reshape(-1,
|
|
293
|
-
|
|
294
|
-
for (i, block_idx) in enumerate(block_numbers):
|
|
295
|
-
cache = jax.lax.dynamic_update_slice_in_dim(cache,
|
|
296
|
-
reshaped_slices[i],
|
|
297
|
-
block_idx,
|
|
298
|
-
axis=0)
|
|
322
|
+
reshaped_slices = slices.reshape(-1, block_size, *slices.shape[1:])
|
|
323
|
+
cache.at[block_numbers].set(reshaped_slices)
|
|
299
324
|
return cache
|
|
300
325
|
|
|
301
326
|
return jax.tree.map(_update_layer, kv_caches, kv_cache_slices)
|
|
@@ -348,16 +373,12 @@ class KVCacheManager:
|
|
|
348
373
|
"""
|
|
349
374
|
if block_ids == list(range(block_ids[0],
|
|
350
375
|
block_ids[0] + len(block_ids))):
|
|
351
|
-
|
|
352
|
-
|
|
353
|
-
batched_kv_cache_per_layer = self._jitted_gather_continuous_kv_cache(
|
|
354
|
-
self.runner.kv_caches, block_ids[0], len(block_ids))
|
|
376
|
+
batched_kv_cache_per_layer = self._jitted_gather_continuous_kv_cache(
|
|
377
|
+
self.runner.kv_caches, block_ids[0], len(block_ids))
|
|
355
378
|
|
|
356
379
|
else:
|
|
357
|
-
|
|
358
|
-
|
|
359
|
-
batched_kv_cache_per_layer = self._jitted_gather_kv_cache(
|
|
360
|
-
self.runner.kv_caches, jnp.array(block_ids))
|
|
380
|
+
batched_kv_cache_per_layer = self._jitted_gather_kv_cache(
|
|
381
|
+
self.runner.kv_caches, jnp.array(block_ids))
|
|
361
382
|
return batched_kv_cache_per_layer
|
|
362
383
|
|
|
363
384
|
def transfer_kv_cache(self,
|
|
@@ -446,6 +467,7 @@ class KVCacheManager:
|
|
|
446
467
|
kv_cache_slices,
|
|
447
468
|
start_block,
|
|
448
469
|
)
|
|
470
|
+
jax.block_until_ready(self.runner.kv_caches)
|
|
449
471
|
else:
|
|
450
472
|
with runner_utils.LatencyTracker(
|
|
451
473
|
f"JittedInsertKVCache-b{len(block_numbers)}"):
|
|
@@ -457,6 +479,7 @@ class KVCacheManager:
|
|
|
457
479
|
kv_cache_slices,
|
|
458
480
|
jnp.array(block_numbers),
|
|
459
481
|
)
|
|
482
|
+
jax.block_until_ready(self.runner.kv_caches)
|
|
460
483
|
|
|
461
484
|
logger.debug(
|
|
462
485
|
f"Updated kv cache entries cnt={len(self.runner.kv_caches)}")
|
|
@@ -1,3 +1,17 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
1
15
|
from __future__ import annotations
|
|
2
16
|
|
|
3
17
|
from typing import TYPE_CHECKING
|
|
@@ -1,3 +1,17 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
1
15
|
from typing import TYPE_CHECKING
|
|
2
16
|
|
|
3
17
|
import jax
|
|
@@ -134,7 +148,7 @@ class MultiModalManager:
|
|
|
134
148
|
# 2. A list or tuple (length: num_items) of tensors, each of shape
|
|
135
149
|
# (feature_size, hidden_size) in case the feature size is dynamic
|
|
136
150
|
# depending on the input multimodal items.
|
|
137
|
-
curr_group_outputs = self.runner.
|
|
151
|
+
curr_group_outputs = self.runner.embed_multimodal_fn(
|
|
138
152
|
self.runner.state, image_grid_thw, **batched_mm_inputs)
|
|
139
153
|
|
|
140
154
|
sanity_check_mm_encoder_outputs(
|
|
@@ -1,3 +1,17 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
1
15
|
from typing import Dict
|
|
2
16
|
|
|
3
17
|
import jax
|
|
@@ -14,12 +28,13 @@ class PersistentBatchManager:
|
|
|
14
28
|
def __init__(self, requests: Dict[str, CachedRequestState],
|
|
15
29
|
input_batch: InputBatch, encoder_cache: Dict[str,
|
|
16
30
|
'jax.Array'],
|
|
17
|
-
uses_mrope: bool, model_config):
|
|
31
|
+
uses_mrope: bool, model_config, is_last_rank: bool):
|
|
18
32
|
self.requests = requests
|
|
19
33
|
self.input_batch = input_batch
|
|
20
34
|
self.encoder_cache = encoder_cache
|
|
21
35
|
self.uses_mrope = uses_mrope
|
|
22
36
|
self.model_config = model_config
|
|
37
|
+
self.is_last_rank = is_last_rank
|
|
23
38
|
|
|
24
39
|
def _reorder_batch(self, scheduler_output: "VllmSchedulerOutput") -> int:
|
|
25
40
|
""" Reorder the sheduled requests to RPA kernel friendly distribution
|
|
@@ -179,9 +194,35 @@ class PersistentBatchManager:
|
|
|
179
194
|
num_computed_tokens = req_data.num_computed_tokens[i]
|
|
180
195
|
new_block_ids = req_data.new_block_ids[i]
|
|
181
196
|
resumed_from_preemption = req_data.resumed_from_preemption[i]
|
|
197
|
+
num_output_tokens = req_data.num_output_tokens[i]
|
|
182
198
|
|
|
183
199
|
# Update the cached states.
|
|
184
200
|
req_state.num_computed_tokens = num_computed_tokens
|
|
201
|
+
req_index = self.input_batch.req_id_to_index.get(req_id)
|
|
202
|
+
|
|
203
|
+
if not self.is_last_rank:
|
|
204
|
+
# When using PP, the scheduler sends the sampled tokens back,
|
|
205
|
+
# because there's no direct communication between the first-
|
|
206
|
+
# stage worker and the last-stage worker.
|
|
207
|
+
new_token_ids = req_data.new_token_ids[i]
|
|
208
|
+
# Add the sampled token(s) from the previous step (if any).
|
|
209
|
+
# This doesn't include "unverified" tokens like spec tokens.
|
|
210
|
+
num_new_tokens = (num_computed_tokens + len(new_token_ids) -
|
|
211
|
+
req_state.num_tokens)
|
|
212
|
+
if num_new_tokens == 1:
|
|
213
|
+
req_state.output_token_ids.append(new_token_ids[-1])
|
|
214
|
+
elif num_new_tokens > 0:
|
|
215
|
+
req_state.output_token_ids.extend(
|
|
216
|
+
new_token_ids[-num_new_tokens:])
|
|
217
|
+
elif num_output_tokens < len(req_state.output_token_ids):
|
|
218
|
+
del req_state.output_token_ids[num_output_tokens:]
|
|
219
|
+
if req_index is not None:
|
|
220
|
+
end_idx = (self.input_batch.num_prompt_tokens[req_index] +
|
|
221
|
+
num_output_tokens)
|
|
222
|
+
self.input_batch.num_tokens[req_index] = end_idx
|
|
223
|
+
self.input_batch.num_tokens_no_spec[req_index] = end_idx
|
|
224
|
+
|
|
225
|
+
# Update the block IDs.
|
|
185
226
|
if not resumed_from_preemption:
|
|
186
227
|
if new_block_ids is not None:
|
|
187
228
|
# Append the new blocks to the existing block IDs.
|
|
@@ -194,7 +235,6 @@ class PersistentBatchManager:
|
|
|
194
235
|
# Replace the existing block IDs with the new ones.
|
|
195
236
|
req_state.block_ids = new_block_ids
|
|
196
237
|
|
|
197
|
-
req_index = self.input_batch.req_id_to_index.get(req_id)
|
|
198
238
|
if req_index is None:
|
|
199
239
|
# The request is not in the persistent batch.
|
|
200
240
|
# The request was either preempted and resumed later, or was not
|
|
@@ -209,6 +249,18 @@ class PersistentBatchManager:
|
|
|
209
249
|
self.input_batch.block_table.append_row(
|
|
210
250
|
new_block_ids, req_index)
|
|
211
251
|
|
|
252
|
+
# For the last rank, we don't need to update the token_ids_cpu
|
|
253
|
+
# because the sampled tokens are already cached.
|
|
254
|
+
if not self.is_last_rank:
|
|
255
|
+
start_token_index = num_computed_tokens
|
|
256
|
+
end_token_index = num_computed_tokens + len(new_token_ids)
|
|
257
|
+
self.input_batch.token_ids_cpu[
|
|
258
|
+
req_index,
|
|
259
|
+
start_token_index:end_token_index] = new_token_ids
|
|
260
|
+
self.input_batch.num_tokens_no_spec[
|
|
261
|
+
req_index] = end_token_index
|
|
262
|
+
self.input_batch.num_tokens[req_index] = end_token_index
|
|
263
|
+
|
|
212
264
|
# Add spec_token_ids to token_ids_cpu.
|
|
213
265
|
spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get(
|
|
214
266
|
req_id, ())
|
|
@@ -1,3 +1,17 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
1
15
|
from __future__ import annotations
|
|
2
16
|
|
|
3
17
|
from dataclasses import dataclass
|
|
@@ -1,3 +1,17 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
1
15
|
import functools
|
|
2
16
|
from typing import TYPE_CHECKING, Tuple
|
|
3
17
|
|