tpu-inference 0.11.1.dev202512030818__py3-none-any.whl → 0.13.0rc2.post7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tpu-inference might be problematic.
- tests/__init__.py +13 -0
- tests/core/__init__.py +13 -0
- tests/core/test_disagg_utils.py +14 -0
- tests/core/test_dp_scheduler.py +650 -768
- tests/core/test_init.py +14 -0
- tests/distributed/__init__.py +13 -0
- tests/distributed/test_distributed_utils.py +120 -0
- tests/distributed/test_tpu_connector.py +478 -0
- tests/e2e/__init__.py +13 -0
- tests/e2e/test_async_scheduler.py +211 -0
- tests/e2e/test_data_parallel.py +289 -0
- tests/e2e/test_hybrid_kvcache.py +219 -0
- tests/e2e/test_local_disagg.py +257 -0
- tests/e2e/test_model_loader.py +268 -0
- tests/e2e/test_multi_modal_inference.py +111 -0
- tests/e2e/test_pipeline_parallel.py +265 -0
- tests/e2e/test_runai_model_streamer_loader.py +104 -0
- tests/e2e/test_sampling_params.py +269 -0
- tests/e2e/test_speculative_decoding.py +311 -0
- tests/e2e/test_structured_decoding.py +46 -0
- tests/executors/__init__.py +13 -0
- tests/executors/test_ray_distributed_executor.py +199 -0
- tests/experimental/__init__.py +13 -0
- tests/experimental/test_llama3_jax_stashed.py +208 -0
- tests/kernels/__init__.py +13 -0
- tests/kernels/collectives/__init__.py +13 -0
- tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
- tests/kernels/fused_moe_v1_test.py +14 -0
- tests/kernels/gmm_test.py +205 -0
- tests/kernels/mla_v1_test.py +143 -41
- tests/kernels/quantized_matmul_kernel_test.py +2 -34
- tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +17 -1
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +17 -1
- tests/layers/__init__.py +13 -0
- tests/layers/common/__init__.py +13 -0
- tests/layers/common/test_attention_interface.py +156 -0
- tests/layers/common/test_quantization.py +149 -0
- tests/layers/jax/__init__.py +13 -0
- tests/layers/jax/attention/__init__.py +13 -0
- tests/layers/jax/attention/test_common_attention.py +103 -0
- tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
- tests/layers/jax/attention/test_llama4_attention.py +135 -0
- tests/layers/jax/moe/__init__.py +13 -0
- tests/layers/jax/moe/test_deepseek_moe.py +235 -0
- tests/layers/jax/sample/__init__.py +13 -0
- tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
- tests/layers/jax/sample/test_sampling.py +115 -0
- tests/layers/jax/sample/test_sampling_metadata.py +254 -0
- tests/layers/jax/test_layers.py +155 -0
- tests/{test_quantization.py → layers/jax/test_qwix.py} +183 -50
- tests/layers/jax/test_rope.py +93 -0
- tests/layers/jax/test_sharding.py +159 -0
- tests/layers/jax/test_transformer_block.py +152 -0
- tests/layers/vllm/__init__.py +13 -0
- tests/layers/vllm/test_attention.py +363 -0
- tests/layers/vllm/test_awq.py +405 -0
- tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +418 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +441 -0
- tests/layers/vllm/test_fp8.py +17 -0
- tests/layers/vllm/test_mxfp4.py +312 -0
- tests/layers/vllm/test_unquantized.py +651 -0
- tests/layers/vllm/utils.py +87 -0
- tests/lora/__init__.py +13 -0
- tests/lora/conftest.py +14 -0
- tests/lora/test_bgmv.py +14 -0
- tests/lora/test_layers.py +21 -3
- tests/lora/test_lora.py +15 -1
- tests/lora/test_lora_perf.py +67 -0
- tests/models/__init__.py +13 -0
- tests/models/common/__init__.py +13 -0
- tests/models/common/test_model_loader.py +455 -0
- tests/models/jax/__init__.py +13 -0
- tests/models/jax/test_deepseek_v3.py +401 -0
- tests/models/jax/test_llama3.py +184 -0
- tests/models/jax/test_llama4.py +298 -0
- tests/models/jax/test_llama_eagle3.py +197 -0
- tests/models/jax/test_llama_guard_4.py +242 -0
- tests/models/jax/test_qwen2.py +172 -0
- tests/models/jax/test_qwen2_5_vl.py +605 -0
- tests/models/jax/test_qwen3.py +169 -0
- tests/models/jax/test_weight_loading.py +180 -0
- tests/models/jax/utils/__init__.py +13 -0
- tests/models/jax/utils/test_multi_modal_utils.py +212 -0
- tests/platforms/__init__.py +13 -0
- tests/platforms/test_tpu_platform.py +54 -0
- tests/runner/__init__.py +13 -0
- tests/runner/test_block_table.py +395 -0
- tests/runner/test_input_batch.py +226 -0
- tests/runner/test_kv_cache.py +220 -0
- tests/runner/test_kv_cache_manager.py +498 -0
- tests/runner/test_multimodal_manager.py +429 -0
- tests/runner/test_persistent_batch_manager.py +84 -0
- tests/runner/test_speculative_decoding_manager.py +368 -0
- tests/runner/test_structured_decoding_manager.py +220 -0
- tests/runner/test_tpu_runner.py +261 -0
- tests/runner/test_tpu_runner_dp.py +1099 -0
- tests/runner/test_tpu_runner_mesh.py +200 -0
- tests/runner/test_utils.py +411 -0
- tests/spec_decode/__init__.py +13 -0
- tests/spec_decode/test_eagle3.py +311 -0
- tests/test_base.py +14 -0
- tests/test_envs.py +78 -1
- tests/test_tpu_info.py +14 -0
- tests/test_utils.py +1 -43
- tests/worker/__init__.py +13 -0
- tests/worker/tpu_worker_test.py +414 -0
- tpu_inference/__init__.py +14 -0
- tpu_inference/core/__init__.py +13 -0
- tpu_inference/core/sched/__init__.py +13 -0
- tpu_inference/core/sched/dp_scheduler.py +372 -56
- tpu_inference/distributed/__init__.py +13 -0
- tpu_inference/distributed/jax_parallel_state.py +14 -0
- tpu_inference/distributed/tpu_connector.py +14 -9
- tpu_inference/distributed/utils.py +56 -4
- tpu_inference/envs.py +38 -7
- tpu_inference/executors/__init__.py +13 -0
- tpu_inference/executors/ray_distributed_executor.py +17 -0
- tpu_inference/experimental/__init__.py +13 -0
- tpu_inference/experimental/llama3_jax_stashed.py +14 -0
- tpu_inference/kernels/__init__.py +13 -0
- tpu_inference/kernels/collectives/__init__.py +13 -0
- tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
- tpu_inference/kernels/flash_attention/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/kernel.py +370 -324
- tpu_inference/kernels/megablox/__init__.py +13 -0
- tpu_inference/kernels/megablox/common.py +54 -0
- tpu_inference/kernels/megablox/gmm.py +646 -0
- tpu_inference/kernels/mla/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/kernel.py +117 -145
- tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
- tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
- tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +194 -101
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +95 -78
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +376 -195
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +15 -1
- tpu_inference/layers/__init__.py +13 -0
- tpu_inference/layers/common/__init__.py +13 -0
- tpu_inference/layers/common/attention_interface.py +26 -19
- tpu_inference/layers/common/attention_metadata.py +14 -0
- tpu_inference/layers/common/quant_methods.py +15 -0
- tpu_inference/layers/common/quantization.py +270 -0
- tpu_inference/layers/common/sharding.py +28 -5
- tpu_inference/layers/jax/__init__.py +13 -0
- tpu_inference/layers/jax/attention/__init__.py +13 -0
- tpu_inference/layers/jax/attention/attention.py +19 -6
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +270 -77
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +24 -11
- tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
- tpu_inference/layers/jax/base.py +14 -0
- tpu_inference/layers/jax/constants.py +13 -0
- tpu_inference/layers/jax/layers.py +14 -0
- tpu_inference/layers/jax/misc.py +14 -0
- tpu_inference/layers/jax/moe/__init__.py +13 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
- tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
- tpu_inference/layers/jax/moe/moe.py +43 -3
- tpu_inference/layers/jax/pp_utils.py +53 -0
- tpu_inference/layers/jax/rope.py +14 -0
- tpu_inference/layers/jax/rope_interface.py +14 -0
- tpu_inference/layers/jax/sample/__init__.py +13 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
- tpu_inference/layers/jax/sample/sampling.py +15 -1
- tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
- tpu_inference/layers/jax/transformer_block.py +14 -0
- tpu_inference/layers/vllm/__init__.py +13 -0
- tpu_inference/layers/vllm/attention.py +4 -4
- tpu_inference/layers/vllm/fused_moe.py +210 -260
- tpu_inference/layers/vllm/linear_common.py +57 -22
- tpu_inference/layers/vllm/quantization/__init__.py +16 -0
- tpu_inference/layers/vllm/quantization/awq.py +15 -1
- tpu_inference/layers/vllm/quantization/common.py +33 -18
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +18 -3
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +211 -148
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +14 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +14 -0
- tpu_inference/layers/vllm/quantization/fp8.py +118 -0
- tpu_inference/layers/vllm/quantization/mxfp4.py +278 -209
- tpu_inference/layers/vllm/quantization/unquantized.py +134 -86
- tpu_inference/layers/vllm/sharding.py +21 -4
- tpu_inference/lora/__init__.py +13 -0
- tpu_inference/lora/torch_lora_ops.py +8 -13
- tpu_inference/models/__init__.py +13 -0
- tpu_inference/models/common/__init__.py +13 -0
- tpu_inference/models/common/model_loader.py +74 -35
- tpu_inference/models/jax/__init__.py +13 -0
- tpu_inference/models/jax/deepseek_v3.py +267 -157
- tpu_inference/models/jax/gpt_oss.py +26 -10
- tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
- tpu_inference/models/jax/llama3.py +99 -36
- tpu_inference/models/jax/llama4.py +14 -0
- tpu_inference/models/jax/llama_eagle3.py +14 -0
- tpu_inference/models/jax/llama_guard_4.py +15 -1
- tpu_inference/models/jax/qwen2.py +17 -2
- tpu_inference/models/jax/qwen2_5_vl.py +18 -4
- tpu_inference/models/jax/qwen3.py +17 -2
- tpu_inference/models/jax/utils/__init__.py +13 -0
- tpu_inference/models/jax/utils/file_utils.py +14 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
- tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
- tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +88 -25
- tpu_inference/models/jax/utils/weight_utils.py +39 -2
- tpu_inference/models/vllm/__init__.py +13 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +20 -3
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
- tpu_inference/platforms/__init__.py +14 -0
- tpu_inference/platforms/tpu_platform.py +47 -64
- tpu_inference/runner/__init__.py +13 -0
- tpu_inference/runner/compilation_manager.py +72 -37
- tpu_inference/runner/kv_cache.py +54 -20
- tpu_inference/runner/kv_cache_manager.py +45 -15
- tpu_inference/runner/lora_utils.py +14 -0
- tpu_inference/runner/multimodal_manager.py +15 -1
- tpu_inference/runner/persistent_batch_manager.py +14 -0
- tpu_inference/runner/speculative_decoding_manager.py +14 -0
- tpu_inference/runner/structured_decoding_manager.py +14 -0
- tpu_inference/runner/tpu_runner.py +41 -16
- tpu_inference/spec_decode/__init__.py +13 -0
- tpu_inference/spec_decode/jax/__init__.py +13 -0
- tpu_inference/spec_decode/jax/eagle3.py +13 -0
- tpu_inference/tpu_info.py +14 -0
- tpu_inference/utils.py +42 -36
- tpu_inference/worker/__init__.py +13 -0
- tpu_inference/worker/tpu_worker.py +63 -50
- {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/METADATA +11 -9
- tpu_inference-0.13.0rc2.post7.dist-info/RECORD +261 -0
- tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
- tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
- tpu_inference-0.11.1.dev202512030818.dist-info/RECORD +0 -174
- {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/WHEEL +0 -0
- {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/top_level.txt +0 -0
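The hunk below appears to correspond to the newly added tests/runner/test_kv_cache_manager.py (+498 lines).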
@@ -0,0 +1,498 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from unittest.mock import MagicMock, patch
+
+import jax
+import jax.numpy as jnp
+import numpy as np
+import pytest
+import torch
+from vllm.attention.backends.abstract import AttentionType
+from vllm.attention.layer import Attention
+from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
+                         SchedulerConfig, VllmConfig)
+from vllm.sampling_params import SamplingType
+from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
+                                        KVCacheGroupSpec, KVCacheTensor,
+                                        MLAAttentionSpec, SlidingWindowSpec)
+from vllm.v1.request import Request
+
+from tpu_inference import utils as common_utils
+from tpu_inference.runner.input_batch import CachedRequestState
+from tpu_inference.runner.tpu_runner import TPUModelRunner
+
+
+class TestKVCacheManager:
+
+    def _setup_runner(self, use_mla: bool = False):
+        # Mock JAX dependencies
+        self.mock_rng_key = MagicMock()
+
+        self.mock_devices = [MagicMock(coords=i) for i in range(4)]
+        self.mock_rng_key = MagicMock()
+
+        # create 1x1 mesh
+        devices = np.asarray(jax.devices()[:1])
+        axis_names = ('data', 'attn_dp', 'model', 'expert')
+        mesh_shape = (1, 1, 1, 1)
+        self.mock_mesh = jax.sharding.Mesh(devices.reshape(mesh_shape),
+                                           axis_names)
+
+        with patch('jax.devices', return_value=self.mock_devices), \
+                patch('jax.make_mesh', return_value=self.mock_mesh), \
+                patch('jax.experimental.mesh_utils.create_device_mesh', return_value=self.mock_mesh), \
+                patch('tpu_inference.runner.tpu_runner.TPUModelRunner._create_new_model_mesh', return_value=self.mock_mesh), \
+                patch('tpu_inference.runner.tpu_runner.TPUModelRunner._init_mesh', return_value=self.mock_mesh), \
+                patch('jax.random.key', return_value=self.mock_rng_key), \
+                patch('tpu_inference.runner.tpu_runner.get_model', return_value=MagicMock()):
+
+            model_config = ModelConfig(tokenizer_mode="auto",
+                                       trust_remote_code=False,
+                                       seed=0,
+                                       dtype='bfloat16',
+                                       use_mla=use_mla)
+            cache_config = CacheConfig(
+                block_size=16,
+                gpu_memory_utilization=0.9,
+                swap_space=4,
+                cache_dtype="auto",
+            )
+            scheduler_config = SchedulerConfig(max_num_seqs=16,
+                                               max_model_len=1024,
+                                               is_encoder_decoder=False)
+            parallel_config = ParallelConfig(
+                pipeline_parallel_size=1,
+                tensor_parallel_size=1,
+                worker_use_ray=False,
+            )
+            vllm_config = VllmConfig(
+                model_config=model_config,
+                cache_config=cache_config,
+                scheduler_config=scheduler_config,
+                parallel_config=parallel_config,
+                observability_config={},
+                additional_config={},
+            )
+            self.runner = TPUModelRunner(vllm_config,
+                                         devices=self.mock_devices)
+            self.runner.mesh = self.mock_mesh
+
+    def setup_method(self):
+        self._setup_runner(use_mla=False)
+
+    def test_insert_request_with_kv_cache(self):
+        # This test refines the insertion test by first extracting a KV cache
+        # using get_kv_cache_for_block_ids, simulating a prefill->decode
+        # transfer, and then inserting it. This ensures the extraction and
+        # insertion logic are compatible.
+
+        # 1. ===== Setup source runner for prefill simulation =====
+        self.runner.block_size = 64
+        num_layers = 2
+        num_kv_heads = 16
+        head_size = 128
+        num_blocks = 50
+        # This is needed for the padding logic in insert_request_with_kv_cache
+        self.runner.vllm_config.cache_config.num_gpu_blocks = num_blocks
+
+        prompt_len = 64
+
+        # Populate a source KV cache with data. This represents the state
+        # of the prefill runner's KV cache.
+        source_kv_cache_shape = (num_blocks, self.runner.block_size,
+                                 2 * num_kv_heads // 2, 2, head_size)
+        prod_val = int(np.prod(source_kv_cache_shape))
+        source_kv_caches = [
+            jnp.arange(prod_val,
+                       dtype=jnp.bfloat16).reshape(source_kv_cache_shape),
+            jnp.arange(prod_val, 2 * prod_val,
+                       dtype=jnp.bfloat16).reshape(source_kv_cache_shape)
+        ]
+        self.runner.kv_caches = source_kv_caches
+
+        # Create a mock for sampling_params to avoid TypeErrors in add_request
+        mock_sampling_params = MagicMock()
+        mock_sampling_params.sampling_type = SamplingType.GREEDY
+        mock_sampling_params.temperature = 0.0
+        mock_sampling_params.top_p = 1.0
+        mock_sampling_params.top_k = -1  # Common value for greedy
+        mock_sampling_params.min_tokens = 0
+        mock_sampling_params.logprobs = None
+        mock_sampling_params.logit_bias = None
+        mock_sampling_params.allowed_token_ids = set()
+        mock_sampling_params.bad_words_token_ids = None
+        mock_sampling_params.all_stop_token_ids = set()
+
+        # 2. ===== Simulate prefill execution state =====
+        prefill_block_ids = [5]
+        # Create a request state for prefill.
+        prefill_request_state = CachedRequestState(
+            req_id="test_req_1",
+            prompt_token_ids=list(range(prompt_len)),
+            output_token_ids=[],
+            sampling_params=mock_sampling_params,
+            block_ids=tuple([prefill_block_ids]),
+            num_computed_tokens=0,
+            lora_request=None,
+            mm_features=[],
+            pooling_params=None,
+            generator=None,
+        )
+
+        # Add the request to the input_batch to simulate it being scheduled.
+        self.runner.input_batch.add_request(prefill_request_state)
+
+        # 3. ===== Extract KV cache using get_kv_cache_for_block_ids =====
+        # Extract the full KV cache for the allocated block.
+        full_block_kv_cache = self.runner.get_kv_cache_for_block_ids(
+            block_ids=prefill_block_ids)
+
+        # Since get_kv_cache_for_block_ids returns the full block, but the
+        # prompt only fills part of it, we need to slice it to the actual
+        # prompt length for the insertion test to be accurate.
+        extracted_kv_cache_slices = [
+            layer_cache[:prompt_len] for layer_cache in full_block_kv_cache
+        ]
+
+        # 4. ===== Setup destination runner for decode simulation =====
+        # Reset runner state to simulate a fresh decode runner.
+        self.runner.requests = {}
+        req_index = self.runner.input_batch.remove_request("test_req_1")
+        if req_index is not None:
+            self.runner.input_batch.condense([req_index])
+
+        # Initialize destination KV caches with zeros.
+        dest_kv_cache_shape = (num_blocks, self.runner.block_size,
+                               2 * num_kv_heads // 2, 2, head_size)
+        self.runner.kv_caches = [
+            jnp.zeros(dest_kv_cache_shape, dtype=jnp.bfloat16)
+            for _ in range(num_layers)
+        ]
+
+        # Create a mock request as it would be after prefill + 1 token.
+        decode_request = MagicMock(spec=Request)
+        decode_request.request_id = "test_req_1"
+        decode_request.num_tokens = prompt_len + 1  # Total tokens
+        decode_request.num_computed_tokens = prompt_len
+        decode_request.prompt_token_ids = list(range(prompt_len))
+        decode_request.all_token_ids = [123, 232, 908]
+        decode_request.output_token_ids = [100]
+        decode_request.sampling_params = mock_sampling_params
+
+        decode_request.lora_request = None
+        decode_request.mm_kwargs, decode_request.mm_positions = [], []
+        decode_request.pooling_params, decode_request.generator = None, None
+
+        # Prepare the KV cache slices for insertion. They must be padded to the
+        # full block size and have a leading dimension for the number of blocks.
+
+        # Allocate new block IDs for the decode runner.
+        decode_block_ids = [[10]]
+        # 5. ===== Call the method to be tested =====
+        self.runner.insert_request_with_kv_cache(decode_request,
+                                                 extracted_kv_cache_slices,
+                                                 decode_block_ids)
+
+        # 6. ===== Assertions =====
+        assert "test_req_1" in self.runner.requests
+        assert "test_req_1" in self.runner.input_batch.req_id_to_index
+        assert self.runner.requests[
+            "test_req_1"].num_computed_tokens == prompt_len
+        assert self.runner.requests["test_req_1"].output_token_ids == [908]
+
+        # Verify the content of the inserted KV cache.
+        target_block_id = decode_block_ids[0][0]
+        for i, layer_kv_cache in enumerate(self.runner.kv_caches):
+            updated_block_content = layer_kv_cache[target_block_id]
+
+            # The extracted slice should be padded to the block size.
+            padding_size = self.runner.block_size - prompt_len
+            expected_padded_slice = jnp.pad(extracted_kv_cache_slices[i],
+                                            ((0, padding_size), (0, 0), (0, 0),
+                                             (0, 0)),
+                                            mode='constant')
+            np.testing.assert_array_equal(updated_block_content,
+                                          expected_padded_slice)
+
+    @pytest.mark.parametrize("num_kv_heads", [16, 32])
+    @pytest.mark.parametrize("head_size", [64, 100, 200])
+    def test_get_kv_cache_spec_with_compilation_cfg(self, num_kv_heads,
+                                                    head_size):
+        # tests we create kv cache spec from compilation config
+        # create a static forward context with
+        # 10 full attention layers +
+        # 10 sliding window attention layers
+        # 1 layer with shared kv cache.
+        attn_type = AttentionType.DECODER
+        sliding_window = 10
+        static_forward_context = {}
+        for i in range(10):
+            static_forward_context[f'layer.{i}'] = MagicMock(
+                spec=Attention,
+                num_kv_heads=num_kv_heads,
+                head_size=head_size,
+                attn_type=attn_type,
+                sliding_window=None,
+                kv_sharing_target_layer_name=None,
+            )
+        for i in range(10, 20):
+            static_forward_context[f'layer.{i}'] = MagicMock(
+                spec=Attention,
+                num_kv_heads=num_kv_heads,
+                head_size=head_size,
+                attn_type=attn_type,
+                sliding_window=sliding_window,
+                kv_sharing_target_layer_name=None,
+            )
+        static_forward_context['layer.20'] = MagicMock(
+            spec=Attention,
+            num_kv_heads=num_kv_heads,
+            head_size=head_size,
+            attn_type=attn_type,
+            sliding_window=None,
+            kv_sharing_target_layer_name='layer.0',
+        )
+        self.runner.vllm_config.compilation_config.static_forward_context = \
+            static_forward_context
+
+        kv_cache_spec = self.runner.get_kv_cache_spec()
+
+        expected_full_attn_spec = FullAttentionSpec(
+            block_size=self.runner.vllm_config.cache_config.block_size,
+            num_kv_heads=common_utils.get_padded_num_heads(
+                num_kv_heads, self.runner.mesh.shape["model"]),
+            head_size=common_utils.get_padded_head_dim(head_size),
+            dtype=torch.bfloat16)
+        expected_sliding_window_spec = SlidingWindowSpec(
+            block_size=self.runner.vllm_config.cache_config.block_size,
+            num_kv_heads=common_utils.get_padded_num_heads(
+                num_kv_heads, self.runner.mesh.shape["model"]),
+            head_size=common_utils.get_padded_head_dim(head_size),
+            dtype=torch.bfloat16,
+            sliding_window=sliding_window)
+        assert len(kv_cache_spec) == 20
+        for i in range(10):
+            assert kv_cache_spec[f'layer.{i}'] == expected_full_attn_spec
+        for i in range(10, 20):
+            assert kv_cache_spec[f'layer.{i}'] == expected_sliding_window_spec
+        assert 'layer.20' not in kv_cache_spec
+        assert self.runner.kv_cache_manager.shared_kv_cache_layers == {
+            'layer.20': 'layer.0'
+        }
+
+    def test_get_kv_cache_spec_with_compilation_cfg_mla(self):
+        # tests we create kv cache spec from compilation config with mla
+        self.runner.kv_cache_manager.use_mla = True
+
+        # Mock hf_text_config to have kv_lora_rank and qk_rope_head_dim
+        mock_hf_text_config = MagicMock()
+        mock_hf_text_config.kv_lora_rank = 400
+        mock_hf_text_config.qk_rope_head_dim = 40
+        self.runner.model_config.hf_text_config = mock_hf_text_config
+
+        num_kv_heads = 16
+        head_size = 512  # Aggregated padding amount may be passed to the model instead.
+        expected_head_size = 640  # 640 = align(512, 128) + alignto(40, 128)
+        attn_type = AttentionType.DECODER
+        static_forward_context = {}
+        # Mock one layer, as the logic is the same for all
+        mock_attn_module = MagicMock(
+            spec=Attention,
+            num_kv_heads=num_kv_heads,
+            head_size=head_size,
+            attn_type=attn_type,
+            sliding_window=None,
+            kv_sharing_target_layer_name=None,
+        )
+        mock_attn_module.use_mla = True
+        static_forward_context['layer.0'] = mock_attn_module
+        self.runner.vllm_config.compilation_config.static_forward_context = \
+            static_forward_context
+
+        kv_cache_spec = self.runner.get_kv_cache_spec()
+
+        assert len(kv_cache_spec) == 1
+        spec = kv_cache_spec['layer.0']
+        assert isinstance(spec, MLAAttentionSpec)
+        assert spec.num_kv_heads == 1
+        assert spec.head_size == expected_head_size
+
+    def test_get_kv_cache_spec_without_compilation_cfg(self):
+        # tests if there's no compilation config, we use full attention kv
+        # cache for each layer.
+        model_config = self.runner.vllm_config.model_config
+        parallel_config = self.runner.vllm_config.parallel_config
+        head_size = model_config.get_head_size()
+        num_kv_heads = model_config.get_total_num_kv_heads()
+        num_layers = model_config.get_num_layers(parallel_config)
+
+        self.runner.vllm_config.compilation_config.static_forward_context = {}
+        kv_cache_spec = self.runner.get_kv_cache_spec()
+
+        assert len(kv_cache_spec) == num_layers
+        expected_full_attn_spec = FullAttentionSpec(
+            block_size=self.runner.vllm_config.cache_config.block_size,
+            num_kv_heads=common_utils.get_padded_num_heads(
+                num_kv_heads, self.runner.mesh.shape["model"]),
+            head_size=common_utils.get_padded_head_dim(head_size),
+            dtype=torch.bfloat16)
+        for i in range(num_layers):
+            assert kv_cache_spec[f'layer.{i}'] == expected_full_attn_spec
+        assert len(self.runner.kv_cache_manager.shared_kv_cache_layers) == 0
+
+    def test_get_kv_cache_spec_without_compilation_cfg_mla(self):
+        self.runner.kv_cache_manager.use_mla = True
+        model_config = self.runner.vllm_config.model_config
+        parallel_config = self.runner.vllm_config.parallel_config
+        num_layers = model_config.get_num_layers(parallel_config)
+
+        mock_hf_text_config = MagicMock()
+        mock_hf_text_config.kv_lora_rank = 400
+        mock_hf_text_config.qk_rope_head_dim = 40
+        self.runner.model_config.hf_text_config = mock_hf_text_config
+        expected_head_size = 640  # 640 = align(512, 128) + alignto(40, 128)
+
+        self.runner.vllm_config.compilation_config.static_forward_context = {}
+        with patch('vllm.config.ModelConfig.get_num_layers',
+                   return_value=num_layers):
+            kv_cache_spec = self.runner.get_kv_cache_spec()
+
+        assert len(kv_cache_spec) == num_layers
+        for i in range(num_layers):
+            spec = kv_cache_spec[f"layer.{i}"]
+            assert isinstance(spec, MLAAttentionSpec)
+            assert spec.num_kv_heads == 1
+            assert spec.head_size == expected_head_size
+
+    def test_initialize_kv_cache(self):
+        # create a kv cache config with 10 layers full attention and 10 layers
+        # sliding window attention.
+        block_size = self.runner.vllm_config.cache_config.block_size
+        num_kv_heads = 8
+        head_size = 128
+        sliding_window = 100
+        num_blocks = 100
+        kv_packing = 2  # bf16
+        sliding_window_spec = SlidingWindowSpec(
+            block_size=block_size,
+            num_kv_heads=num_kv_heads,
+            head_size=head_size,
+            dtype=torch.bfloat16,
+            sliding_window=sliding_window,
+        )
+        full_attn_spec = FullAttentionSpec(
+            block_size=block_size,
+            num_kv_heads=num_kv_heads,
+            head_size=head_size,
+            dtype=torch.bfloat16,
+        )
+        kv_cache_groups = [
+            KVCacheGroupSpec(layer_names=[f'layer.{i}' for i in range(10)],
+                             kv_cache_spec=full_attn_spec),
+            KVCacheGroupSpec(layer_names=[f'layer.{i}' for i in range(10, 20)],
+                             kv_cache_spec=sliding_window_spec),
+        ]
+        kv_cache_tensors = []
+        page_size_bytes = full_attn_spec.page_size_bytes
+        for i in range(10):
+            kv_cache_tensors.append(
+                KVCacheTensor(
+                    size=num_blocks * page_size_bytes,
+                    shared_by=[f'layer.{i}', f'layer.{i+10}'],
+                ))
+        kv_cache_config = KVCacheConfig(
+            num_blocks=num_blocks,
+            kv_cache_tensors=kv_cache_tensors,
+            kv_cache_groups=kv_cache_groups,
+        )
+
+        original_input_batch = self.runner.input_batch
+        self.runner.initialize_kv_cache(kv_cache_config)
+
+        # assert kv cache config with multiple kv cache groups will reinit
+        # input batch.
+        assert original_input_batch != self.runner.input_batch
+        assert len(self.runner.kv_caches) == 10
+        for i in range(10):
+            assert self.runner.kv_caches[i].shape == (num_blocks, block_size,
+                                                      num_kv_heads * 2 //
+                                                      kv_packing, kv_packing,
+                                                      head_size)
+            assert self.runner.layer_name_to_kvcache_index[f'layer.{i}'] == i
+            assert self.runner.layer_name_to_kvcache_index[
+                f'layer.{i + 10}'] == i
+
+    def test_get_kv_cache_spec_with_eagle3(self):
+        # tests we create kv cache spec for eagle3 draft model
+        self.runner.vllm_config.compilation_config.static_forward_context = {}
+        mock_speculative_config = MagicMock()
+        mock_speculative_config.method = "eagle3"
+        mock_draft_model_config = MagicMock()
+        mock_hf_config = MagicMock()
+        mock_hf_config.num_key_value_heads = 4
+        mock_hf_config.hidden_size = 1024
+        mock_hf_config.num_attention_heads = 8
+        mock_draft_model_config.hf_config = mock_hf_config
+        mock_speculative_config.draft_model_config = mock_draft_model_config
+        self.runner.speculative_config = mock_speculative_config
+
+        kv_cache_spec = self.runner.get_kv_cache_spec()
+
+        assert "draft_layer.0" in kv_cache_spec
+        draft_spec = kv_cache_spec["draft_layer.0"]
+        assert isinstance(draft_spec, FullAttentionSpec)
+        assert draft_spec.block_size == self.runner.vllm_config.cache_config.block_size
+        assert draft_spec.num_kv_heads == common_utils.get_padded_num_heads(
+            4, self.runner.mesh.shape["model"])
+        assert draft_spec.head_size == common_utils.get_padded_head_dim(128)
+        assert draft_spec.dtype == torch.bfloat16
+
+    def test_get_kv_cache_spec_with_eagle3_mla(self):
+        # tests we create kv cache spec for eagle3 draft model with mla
+        self.runner.kv_cache_manager.use_mla = True
+
+        self.runner.vllm_config.compilation_config.static_forward_context = {}
+        mock_speculative_config = MagicMock()
+        mock_speculative_config.method = "eagle3"
+        mock_draft_model_config = MagicMock()
+        mock_hf_config = MagicMock()
+        mock_hf_config.num_key_value_heads = 4
+        mock_hf_config.hidden_size = 1024
+        mock_hf_config.num_attention_heads = 8
+        mock_hf_config.num_layers = 16
+        model_layers = 1
+        mock_hf_text_config = MagicMock()
+        mock_hf_text_config.kv_lora_rank = 400
+        mock_hf_text_config.qk_rope_head_dim = 40
+        self.runner.model_config.hf_text_config = mock_hf_text_config
+        mock_draft_model_config.hf_config = mock_hf_config
+        mock_speculative_config.draft_model_config = mock_draft_model_config
+        self.runner.speculative_config = mock_speculative_config
+
+        kv_cache_spec = self.runner.get_kv_cache_spec()
+
+        # Without compilation context, it will create specs for the main model layers
+        # as well as the draft model layer.
+        assert len(kv_cache_spec) > model_layers
+
+        assert "draft_layer.0" in kv_cache_spec
+        draft_spec = kv_cache_spec["draft_layer.0"]
+        assert isinstance(draft_spec, FullAttentionSpec)
+
+        for i in range(model_layers):
+            assert f"layer.{i}" in kv_cache_spec
+            spec = kv_cache_spec[f"layer.{i}"]
+            assert isinstance(spec, MLAAttentionSpec)
+            assert spec.num_kv_heads == 1