tpu-inference 0.12.0.dev20251213__py3-none-any.whl → 0.13.2.dev20251230__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This version of tpu-inference has been flagged as potentially problematic.

Files changed (248)
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_disagg_utils.py +14 -0
  4. tests/core/test_dp_scheduler.py +650 -768
  5. tests/core/test_init.py +14 -0
  6. tests/distributed/__init__.py +13 -0
  7. tests/distributed/test_distributed_utils.py +120 -0
  8. tests/distributed/test_tpu_connector.py +478 -0
  9. tests/e2e/__init__.py +13 -0
  10. tests/e2e/test_async_scheduler.py +211 -0
  11. tests/e2e/test_data_parallel.py +289 -0
  12. tests/e2e/test_hybrid_kvcache.py +219 -0
  13. tests/e2e/test_local_disagg.py +257 -0
  14. tests/e2e/test_model_loader.py +268 -0
  15. tests/e2e/test_multi_modal_inference.py +111 -0
  16. tests/e2e/test_pipeline_parallel.py +265 -0
  17. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  18. tests/e2e/test_sampling_params.py +269 -0
  19. tests/e2e/test_speculative_decoding.py +311 -0
  20. tests/e2e/test_structured_decoding.py +46 -0
  21. tests/executors/__init__.py +13 -0
  22. tests/executors/test_ray_distributed_executor.py +199 -0
  23. tests/experimental/__init__.py +13 -0
  24. tests/experimental/test_llama3_jax_stashed.py +208 -0
  25. tests/kernels/__init__.py +13 -0
  26. tests/kernels/collectives/__init__.py +13 -0
  27. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  28. tests/kernels/fused_moe_v1_test.py +14 -0
  29. tests/kernels/gmm_test.py +205 -0
  30. tests/kernels/mla_v1_test.py +14 -0
  31. tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
  32. tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
  33. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +14 -0
  34. tests/kernels/ragged_paged_attention_kernel_v3_test.py +14 -0
  35. tests/layers/__init__.py +13 -0
  36. tests/layers/common/__init__.py +13 -0
  37. tests/layers/common/test_attention_interface.py +156 -0
  38. tests/layers/common/test_quantization.py +149 -0
  39. tests/layers/jax/__init__.py +13 -0
  40. tests/layers/jax/attention/__init__.py +13 -0
  41. tests/layers/jax/attention/test_common_attention.py +103 -0
  42. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  43. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  44. tests/layers/jax/moe/__init__.py +13 -0
  45. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  46. tests/layers/jax/sample/__init__.py +13 -0
  47. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  48. tests/layers/jax/sample/test_sampling.py +115 -0
  49. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  50. tests/layers/jax/test_layers.py +155 -0
  51. tests/{test_quantization.py → layers/jax/test_qwix.py} +180 -50
  52. tests/layers/jax/test_rope.py +93 -0
  53. tests/layers/jax/test_sharding.py +159 -0
  54. tests/layers/jax/test_transformer_block.py +152 -0
  55. tests/layers/vllm/__init__.py +13 -0
  56. tests/layers/vllm/test_attention.py +363 -0
  57. tests/layers/vllm/test_awq.py +406 -0
  58. tests/layers/vllm/test_compressed_tensors_moe.py +199 -0
  59. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +441 -0
  60. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +443 -0
  61. tests/layers/vllm/test_fp8.py +17 -0
  62. tests/layers/vllm/test_mxfp4.py +320 -0
  63. tests/layers/vllm/test_unquantized.py +662 -0
  64. tests/layers/vllm/utils.py +87 -0
  65. tests/lora/__init__.py +13 -0
  66. tests/lora/conftest.py +14 -0
  67. tests/lora/test_bgmv.py +14 -0
  68. tests/lora/test_layers.py +25 -8
  69. tests/lora/test_lora.py +15 -1
  70. tests/lora/test_lora_perf.py +14 -0
  71. tests/models/__init__.py +13 -0
  72. tests/models/common/__init__.py +13 -0
  73. tests/models/common/test_model_loader.py +455 -0
  74. tests/models/jax/__init__.py +13 -0
  75. tests/models/jax/test_deepseek_v3.py +401 -0
  76. tests/models/jax/test_llama3.py +184 -0
  77. tests/models/jax/test_llama4.py +298 -0
  78. tests/models/jax/test_llama_eagle3.py +197 -0
  79. tests/models/jax/test_llama_guard_4.py +242 -0
  80. tests/models/jax/test_qwen2.py +172 -0
  81. tests/models/jax/test_qwen2_5_vl.py +605 -0
  82. tests/models/jax/test_qwen3.py +169 -0
  83. tests/models/jax/test_weight_loading.py +180 -0
  84. tests/models/jax/utils/__init__.py +13 -0
  85. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  86. tests/platforms/__init__.py +13 -0
  87. tests/platforms/test_tpu_platform.py +54 -0
  88. tests/runner/__init__.py +13 -0
  89. tests/runner/test_block_table.py +395 -0
  90. tests/runner/test_input_batch.py +226 -0
  91. tests/runner/test_kv_cache.py +220 -0
  92. tests/runner/test_kv_cache_manager.py +498 -0
  93. tests/runner/test_multimodal_manager.py +429 -0
  94. tests/runner/test_persistent_batch_manager.py +84 -0
  95. tests/runner/test_speculative_decoding_manager.py +368 -0
  96. tests/runner/test_structured_decoding_manager.py +220 -0
  97. tests/runner/test_tpu_runner.py +261 -0
  98. tests/runner/test_tpu_runner_dp.py +1099 -0
  99. tests/runner/test_tpu_runner_mesh.py +200 -0
  100. tests/runner/test_utils.py +411 -0
  101. tests/spec_decode/__init__.py +13 -0
  102. tests/spec_decode/test_eagle3.py +311 -0
  103. tests/test_base.py +14 -0
  104. tests/test_tpu_info.py +14 -0
  105. tests/test_utils.py +1 -43
  106. tests/worker/__init__.py +13 -0
  107. tests/worker/tpu_worker_test.py +414 -0
  108. tpu_inference/__init__.py +14 -0
  109. tpu_inference/core/__init__.py +13 -0
  110. tpu_inference/core/sched/__init__.py +13 -0
  111. tpu_inference/core/sched/dp_scheduler.py +372 -56
  112. tpu_inference/distributed/__init__.py +13 -0
  113. tpu_inference/distributed/jax_parallel_state.py +14 -0
  114. tpu_inference/distributed/tpu_connector.py +14 -9
  115. tpu_inference/distributed/utils.py +56 -4
  116. tpu_inference/executors/__init__.py +13 -0
  117. tpu_inference/executors/ray_distributed_executor.py +20 -3
  118. tpu_inference/experimental/__init__.py +13 -0
  119. tpu_inference/experimental/llama3_jax_stashed.py +14 -0
  120. tpu_inference/kernels/__init__.py +13 -0
  121. tpu_inference/kernels/collectives/__init__.py +13 -0
  122. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  123. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  124. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  125. tpu_inference/kernels/fused_moe/v1/kernel.py +171 -163
  126. tpu_inference/kernels/megablox/__init__.py +13 -0
  127. tpu_inference/kernels/megablox/common.py +54 -0
  128. tpu_inference/kernels/megablox/gmm.py +646 -0
  129. tpu_inference/kernels/mla/__init__.py +13 -0
  130. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  131. tpu_inference/kernels/mla/v1/kernel.py +20 -26
  132. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  133. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  134. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  135. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  136. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +112 -69
  137. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +85 -65
  138. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
  139. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +374 -194
  140. tpu_inference/kernels/ragged_paged_attention/v3/util.py +13 -0
  141. tpu_inference/layers/__init__.py +13 -0
  142. tpu_inference/layers/common/__init__.py +13 -0
  143. tpu_inference/layers/common/attention_interface.py +26 -19
  144. tpu_inference/layers/common/attention_metadata.py +14 -0
  145. tpu_inference/layers/common/fused_moe_gmm.py +506 -0
  146. tpu_inference/layers/common/quant_methods.py +15 -0
  147. tpu_inference/layers/common/quantization.py +282 -0
  148. tpu_inference/layers/common/sharding.py +22 -3
  149. tpu_inference/layers/common/utils.py +94 -0
  150. tpu_inference/layers/jax/__init__.py +13 -0
  151. tpu_inference/layers/jax/attention/__init__.py +13 -0
  152. tpu_inference/layers/jax/attention/attention.py +19 -6
  153. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +52 -27
  154. tpu_inference/layers/jax/attention/gpt_oss_attention.py +19 -6
  155. tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
  156. tpu_inference/layers/jax/base.py +14 -0
  157. tpu_inference/layers/jax/constants.py +13 -0
  158. tpu_inference/layers/jax/layers.py +14 -0
  159. tpu_inference/layers/jax/misc.py +14 -0
  160. tpu_inference/layers/jax/moe/__init__.py +13 -0
  161. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
  162. tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
  163. tpu_inference/layers/jax/moe/moe.py +43 -3
  164. tpu_inference/layers/jax/pp_utils.py +53 -0
  165. tpu_inference/layers/jax/rope.py +14 -0
  166. tpu_inference/layers/jax/rope_interface.py +14 -0
  167. tpu_inference/layers/jax/sample/__init__.py +13 -0
  168. tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
  169. tpu_inference/layers/jax/sample/sampling.py +15 -1
  170. tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
  171. tpu_inference/layers/jax/transformer_block.py +14 -0
  172. tpu_inference/layers/vllm/__init__.py +13 -0
  173. tpu_inference/layers/vllm/attention.py +4 -4
  174. tpu_inference/layers/vllm/fused_moe.py +100 -455
  175. tpu_inference/layers/vllm/linear.py +64 -0
  176. tpu_inference/layers/vllm/process_weights/__init__.py +13 -0
  177. tpu_inference/layers/vllm/{sharding.py → process_weights/cleanup_sharding.py} +24 -15
  178. tpu_inference/layers/vllm/process_weights/fused_moe_weights.py +369 -0
  179. tpu_inference/layers/vllm/process_weights/linear_weights.py +174 -0
  180. tpu_inference/layers/vllm/quantization/__init__.py +19 -3
  181. tpu_inference/layers/vllm/quantization/awq.py +96 -82
  182. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  183. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +19 -5
  184. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +119 -132
  185. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  186. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +111 -91
  187. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +79 -43
  188. tpu_inference/layers/vllm/quantization/{common.py → configs.py} +38 -26
  189. tpu_inference/layers/vllm/quantization/fp8.py +119 -0
  190. tpu_inference/layers/vllm/quantization/mxfp4.py +133 -220
  191. tpu_inference/layers/vllm/quantization/unquantized.py +154 -253
  192. tpu_inference/lora/__init__.py +13 -0
  193. tpu_inference/lora/torch_lora_ops.py +8 -13
  194. tpu_inference/models/__init__.py +13 -0
  195. tpu_inference/models/common/__init__.py +13 -0
  196. tpu_inference/models/common/model_loader.py +37 -16
  197. tpu_inference/models/jax/__init__.py +13 -0
  198. tpu_inference/models/jax/deepseek_v3.py +113 -124
  199. tpu_inference/models/jax/gpt_oss.py +23 -7
  200. tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
  201. tpu_inference/models/jax/llama3.py +99 -36
  202. tpu_inference/models/jax/llama4.py +14 -0
  203. tpu_inference/models/jax/llama_eagle3.py +14 -0
  204. tpu_inference/models/jax/llama_guard_4.py +15 -1
  205. tpu_inference/models/jax/qwen2.py +17 -2
  206. tpu_inference/models/jax/qwen2_5_vl.py +18 -4
  207. tpu_inference/models/jax/qwen3.py +17 -2
  208. tpu_inference/models/jax/utils/__init__.py +13 -0
  209. tpu_inference/models/jax/utils/file_utils.py +14 -0
  210. tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
  211. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  212. tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +85 -24
  213. tpu_inference/models/jax/utils/weight_utils.py +32 -1
  214. tpu_inference/models/vllm/__init__.py +13 -0
  215. tpu_inference/models/vllm/vllm_model_wrapper.py +22 -4
  216. tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
  217. tpu_inference/platforms/__init__.py +14 -0
  218. tpu_inference/platforms/tpu_platform.py +27 -29
  219. tpu_inference/runner/__init__.py +13 -0
  220. tpu_inference/runner/compilation_manager.py +69 -35
  221. tpu_inference/runner/kv_cache.py +14 -0
  222. tpu_inference/runner/kv_cache_manager.py +15 -2
  223. tpu_inference/runner/lora_utils.py +16 -1
  224. tpu_inference/runner/multimodal_manager.py +16 -2
  225. tpu_inference/runner/persistent_batch_manager.py +14 -0
  226. tpu_inference/runner/speculative_decoding_manager.py +14 -0
  227. tpu_inference/runner/structured_decoding_manager.py +14 -0
  228. tpu_inference/runner/tpu_runner.py +30 -10
  229. tpu_inference/spec_decode/__init__.py +13 -0
  230. tpu_inference/spec_decode/jax/__init__.py +13 -0
  231. tpu_inference/spec_decode/jax/eagle3.py +13 -0
  232. tpu_inference/tpu_info.py +14 -0
  233. tpu_inference/utils.py +31 -30
  234. tpu_inference/worker/__init__.py +13 -0
  235. tpu_inference/worker/tpu_worker.py +23 -7
  236. {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/METADATA +1 -1
  237. tpu_inference-0.13.2.dev20251230.dist-info/RECORD +266 -0
  238. tpu_inference/layers/vllm/linear_common.py +0 -208
  239. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  240. tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
  241. tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
  242. tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
  243. tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
  244. tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
  245. tpu_inference-0.12.0.dev20251213.dist-info/RECORD +0 -175
  246. {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/WHEEL +0 -0
  247. {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/licenses/LICENSE +0 -0
  248. {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/top_level.txt +0 -0
tests/runner/test_kv_cache_manager.py
@@ -0,0 +1,498 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from unittest.mock import MagicMock, patch
+
+ import jax
+ import jax.numpy as jnp
+ import numpy as np
+ import pytest
+ import torch
+ from vllm.attention.backends.abstract import AttentionType
+ from vllm.attention.layer import Attention
+ from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
+                          SchedulerConfig, VllmConfig)
+ from vllm.sampling_params import SamplingType
+ from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
+                                         KVCacheGroupSpec, KVCacheTensor,
+                                         MLAAttentionSpec, SlidingWindowSpec)
+ from vllm.v1.request import Request
+
+ from tpu_inference import utils as common_utils
+ from tpu_inference.runner.input_batch import CachedRequestState
+ from tpu_inference.runner.tpu_runner import TPUModelRunner
+
+
+ class TestKVCacheManager:
+
+     def _setup_runner(self, use_mla: bool = False):
+         # Mock JAX dependencies.
+         self.mock_rng_key = MagicMock()
+         self.mock_devices = [MagicMock(coords=i) for i in range(4)]
+
+         # Create a single-device 1x1x1x1 mesh.
+         devices = np.asarray(jax.devices()[:1])
+         axis_names = ('data', 'attn_dp', 'model', 'expert')
+         mesh_shape = (1, 1, 1, 1)
+         self.mock_mesh = jax.sharding.Mesh(devices.reshape(mesh_shape),
+                                            axis_names)
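+         # Axis names mirror tpu-inference's sharding dimensions
+         # (data, attn_dp, model, expert); each axis has size 1 here, so
+         # nothing is actually sharded on the single mock device.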
+
+         with patch('jax.devices', return_value=self.mock_devices), \
+                 patch('jax.make_mesh', return_value=self.mock_mesh), \
+                 patch('jax.experimental.mesh_utils.create_device_mesh', return_value=self.mock_mesh), \
+                 patch('tpu_inference.runner.tpu_runner.TPUModelRunner._create_new_model_mesh', return_value=self.mock_mesh), \
+                 patch('tpu_inference.runner.tpu_runner.TPUModelRunner._init_mesh', return_value=self.mock_mesh), \
+                 patch('jax.random.key', return_value=self.mock_rng_key), \
+                 patch('tpu_inference.runner.tpu_runner.get_model', return_value=MagicMock()):
+
+             model_config = ModelConfig(tokenizer_mode="auto",
+                                        trust_remote_code=False,
+                                        seed=0,
+                                        dtype='bfloat16',
+                                        use_mla=use_mla)
+             cache_config = CacheConfig(
+                 block_size=16,
+                 gpu_memory_utilization=0.9,
+                 swap_space=4,
+                 cache_dtype="auto",
+             )
+             scheduler_config = SchedulerConfig(max_num_seqs=16,
+                                                max_model_len=1024,
+                                                is_encoder_decoder=False)
+             parallel_config = ParallelConfig(
+                 pipeline_parallel_size=1,
+                 tensor_parallel_size=1,
+                 worker_use_ray=False,
+             )
+             vllm_config = VllmConfig(
+                 model_config=model_config,
+                 cache_config=cache_config,
+                 scheduler_config=scheduler_config,
+                 parallel_config=parallel_config,
+                 observability_config={},
+                 additional_config={},
+             )
+             self.runner = TPUModelRunner(vllm_config,
+                                          devices=self.mock_devices)
+             self.runner.mesh = self.mock_mesh
+
+     def setup_method(self):
+         self._setup_runner(use_mla=False)
+
+     def test_insert_request_with_kv_cache(self):
+         # This test refines the insertion test by first extracting a KV
+         # cache with get_kv_cache_for_block_ids, simulating a
+         # prefill->decode transfer, and then inserting it. This ensures
+         # the extraction and insertion logic are compatible.
+
+         # 1. ===== Set up the source runner for the prefill simulation =====
+         self.runner.block_size = 64
+         num_layers = 2
+         num_kv_heads = 16
+         head_size = 128
+         num_blocks = 50
+         # Needed for the padding logic in insert_request_with_kv_cache.
+         self.runner.vllm_config.cache_config.num_gpu_blocks = num_blocks
+
+         prompt_len = 64
+
+         # Populate a source KV cache with data. This represents the state
+         # of the prefill runner's KV cache.
+         source_kv_cache_shape = (num_blocks, self.runner.block_size,
+                                  2 * num_kv_heads // 2, 2, head_size)
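+         # KV layout note: K and V are packed along a fused head axis as
+         # (num_blocks, block_size, num_kv_heads * 2 // packing, packing,
+         # head_dim), with packing = 2 for bf16; that is what the
+         # "2 * num_kv_heads // 2, 2" dims above spell out.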
+         prod_val = int(np.prod(source_kv_cache_shape))
+         source_kv_caches = [
+             jnp.arange(prod_val,
+                        dtype=jnp.bfloat16).reshape(source_kv_cache_shape),
+             jnp.arange(prod_val, 2 * prod_val,
+                        dtype=jnp.bfloat16).reshape(source_kv_cache_shape)
+         ]
+         self.runner.kv_caches = source_kv_caches
+
+         # Create a mock for sampling_params to avoid TypeErrors in add_request.
+         mock_sampling_params = MagicMock()
+         mock_sampling_params.sampling_type = SamplingType.GREEDY
+         mock_sampling_params.temperature = 0.0
+         mock_sampling_params.top_p = 1.0
+         mock_sampling_params.top_k = -1  # Common value for greedy.
+         mock_sampling_params.min_tokens = 0
+         mock_sampling_params.logprobs = None
+         mock_sampling_params.logit_bias = None
+         mock_sampling_params.allowed_token_ids = set()
+         mock_sampling_params.bad_words_token_ids = None
+         mock_sampling_params.all_stop_token_ids = set()
+
+         # 2. ===== Simulate prefill execution state =====
+         prefill_block_ids = [5]
+         # Create a request state for prefill.
+         prefill_request_state = CachedRequestState(
+             req_id="test_req_1",
+             prompt_token_ids=list(range(prompt_len)),
+             output_token_ids=[],
+             sampling_params=mock_sampling_params,
+             block_ids=tuple([prefill_block_ids]),
+             num_computed_tokens=0,
+             lora_request=None,
+             mm_features=[],
+             pooling_params=None,
+             generator=None,
+         )
+
+         # Add the request to the input_batch to simulate it being scheduled.
+         self.runner.input_batch.add_request(prefill_request_state)
+
+         # 3. ===== Extract the KV cache via get_kv_cache_for_block_ids =====
+         # Extract the full KV cache for the allocated block.
+         full_block_kv_cache = self.runner.get_kv_cache_for_block_ids(
+             block_ids=prefill_block_ids)
+
+         # get_kv_cache_for_block_ids returns the full block, but the
+         # prompt only fills part of it, so slice to the actual prompt
+         # length for the insertion test to be accurate.
+         extracted_kv_cache_slices = [
+             layer_cache[:prompt_len] for layer_cache in full_block_kv_cache
+         ]
+
+         # 4. ===== Set up the destination runner for the decode simulation =====
+         # Reset runner state to simulate a fresh decode runner.
+         self.runner.requests = {}
+         req_index = self.runner.input_batch.remove_request("test_req_1")
+         if req_index is not None:
+             self.runner.input_batch.condense([req_index])
+
+         # Initialize destination KV caches with zeros.
+         dest_kv_cache_shape = (num_blocks, self.runner.block_size,
+                                2 * num_kv_heads // 2, 2, head_size)
+         self.runner.kv_caches = [
+             jnp.zeros(dest_kv_cache_shape, dtype=jnp.bfloat16)
+             for _ in range(num_layers)
+         ]
+
+         # Create a mock request as it would be after prefill + 1 token.
+         decode_request = MagicMock(spec=Request)
+         decode_request.request_id = "test_req_1"
+         decode_request.num_tokens = prompt_len + 1  # Total tokens.
+         decode_request.num_computed_tokens = prompt_len
+         decode_request.prompt_token_ids = list(range(prompt_len))
+         decode_request.all_token_ids = [123, 232, 908]
+         decode_request.output_token_ids = [100]
+         decode_request.sampling_params = mock_sampling_params
+
+         decode_request.lora_request = None
+         decode_request.mm_kwargs, decode_request.mm_positions = [], []
+         decode_request.pooling_params, decode_request.generator = None, None
+
+         # Prepare the KV cache slices for insertion: they must be padded to
+         # the full block size and given a leading num-blocks dimension.
+
+         # Allocate new block IDs for the decode runner.
+         decode_block_ids = [[10]]
+         # 5. ===== Call the method under test =====
+         self.runner.insert_request_with_kv_cache(decode_request,
+                                                  extracted_kv_cache_slices,
+                                                  decode_block_ids)
+
+         # 6. ===== Assertions =====
+         assert "test_req_1" in self.runner.requests
+         assert "test_req_1" in self.runner.input_batch.req_id_to_index
+         assert self.runner.requests[
+             "test_req_1"].num_computed_tokens == prompt_len
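+         # The runner appears to rebuild output_token_ids from the last
+         # (num_tokens - num_computed_tokens) entries of all_token_ids,
+         # hence [908].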
+         assert self.runner.requests["test_req_1"].output_token_ids == [908]
+
+         # Verify the content of the inserted KV cache.
+         target_block_id = decode_block_ids[0][0]
+         for i, layer_kv_cache in enumerate(self.runner.kv_caches):
+             updated_block_content = layer_kv_cache[target_block_id]
+
+             # The extracted slice should be padded to the block size.
+             padding_size = self.runner.block_size - prompt_len
+             expected_padded_slice = jnp.pad(extracted_kv_cache_slices[i],
+                                             ((0, padding_size), (0, 0), (0, 0),
+                                              (0, 0)),
+                                             mode='constant')
+             np.testing.assert_array_equal(updated_block_content,
+                                           expected_padded_slice)
+
+     @pytest.mark.parametrize("num_kv_heads", [16, 32])
+     @pytest.mark.parametrize("head_size", [64, 100, 200])
+     def test_get_kv_cache_spec_with_compilation_cfg(self, num_kv_heads,
+                                                     head_size):
+         # Tests that the KV cache spec is built from the compilation
+         # config. Create a static forward context with 10 full-attention
+         # layers, 10 sliding-window attention layers, and 1 layer with a
+         # shared KV cache.
+         attn_type = AttentionType.DECODER
+         sliding_window = 10
+         static_forward_context = {}
+         for i in range(10):
+             static_forward_context[f'layer.{i}'] = MagicMock(
+                 spec=Attention,
+                 num_kv_heads=num_kv_heads,
+                 head_size=head_size,
+                 attn_type=attn_type,
+                 sliding_window=None,
+                 kv_sharing_target_layer_name=None,
+             )
+         for i in range(10, 20):
+             static_forward_context[f'layer.{i}'] = MagicMock(
+                 spec=Attention,
+                 num_kv_heads=num_kv_heads,
+                 head_size=head_size,
+                 attn_type=attn_type,
+                 sliding_window=sliding_window,
+                 kv_sharing_target_layer_name=None,
+             )
+         static_forward_context['layer.20'] = MagicMock(
+             spec=Attention,
+             num_kv_heads=num_kv_heads,
+             head_size=head_size,
+             attn_type=attn_type,
+             sliding_window=None,
+             kv_sharing_target_layer_name='layer.0',
+         )
+         self.runner.vllm_config.compilation_config.static_forward_context = \
+             static_forward_context
+
+         kv_cache_spec = self.runner.get_kv_cache_spec()
+
+         expected_full_attn_spec = FullAttentionSpec(
+             block_size=self.runner.vllm_config.cache_config.block_size,
+             num_kv_heads=common_utils.get_padded_num_heads(
+                 num_kv_heads, self.runner.mesh.shape["model"]),
+             head_size=common_utils.get_padded_head_dim(head_size),
+             dtype=torch.bfloat16)
+         expected_sliding_window_spec = SlidingWindowSpec(
+             block_size=self.runner.vllm_config.cache_config.block_size,
+             num_kv_heads=common_utils.get_padded_num_heads(
+                 num_kv_heads, self.runner.mesh.shape["model"]),
+             head_size=common_utils.get_padded_head_dim(head_size),
+             dtype=torch.bfloat16,
+             sliding_window=sliding_window)
+         assert len(kv_cache_spec) == 20
+         for i in range(10):
+             assert kv_cache_spec[f'layer.{i}'] == expected_full_attn_spec
+         for i in range(10, 20):
+             assert kv_cache_spec[f'layer.{i}'] == expected_sliding_window_spec
+         # The KV-sharing layer gets no spec of its own; it is recorded in
+         # shared_kv_cache_layers instead.
+         assert 'layer.20' not in kv_cache_spec
+         assert self.runner.kv_cache_manager.shared_kv_cache_layers == {
+             'layer.20': 'layer.0'
+         }
+
+     def test_get_kv_cache_spec_with_compilation_cfg_mla(self):
+         # Tests that the KV cache spec is built from the compilation
+         # config when MLA is enabled.
+         self.runner.kv_cache_manager.use_mla = True
+
+         # Mock hf_text_config to carry kv_lora_rank and qk_rope_head_dim.
+         mock_hf_text_config = MagicMock()
+         mock_hf_text_config.kv_lora_rank = 400
+         mock_hf_text_config.qk_rope_head_dim = 40
+         self.runner.model_config.hf_text_config = mock_hf_text_config
+
+         num_kv_heads = 16
+         head_size = 512  # The aggregated, padded size may be passed to the model instead.
+         expected_head_size = 640  # align(400, 128) + align(40, 128) = 512 + 128
+         attn_type = AttentionType.DECODER
+         static_forward_context = {}
+         # Mock one layer, as the logic is the same for all.
+         mock_attn_module = MagicMock(
+             spec=Attention,
+             num_kv_heads=num_kv_heads,
+             head_size=head_size,
+             attn_type=attn_type,
+             sliding_window=None,
+             kv_sharing_target_layer_name=None,
+         )
+         mock_attn_module.use_mla = True
+         static_forward_context['layer.0'] = mock_attn_module
+         self.runner.vllm_config.compilation_config.static_forward_context = \
+             static_forward_context
+
+         kv_cache_spec = self.runner.get_kv_cache_spec()
+
+         assert len(kv_cache_spec) == 1
+         spec = kv_cache_spec['layer.0']
+         assert isinstance(spec, MLAAttentionSpec)
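+         # MLA caches a single shared latent KV per token, so the spec
+         # reports one KV head regardless of the module's num_kv_heads.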
+         assert spec.num_kv_heads == 1
+         assert spec.head_size == expected_head_size
+
+     def test_get_kv_cache_spec_without_compilation_cfg(self):
+         # Tests that, if there is no compilation config, a full-attention
+         # KV cache spec is used for each layer.
+         model_config = self.runner.vllm_config.model_config
+         parallel_config = self.runner.vllm_config.parallel_config
+         head_size = model_config.get_head_size()
+         num_kv_heads = model_config.get_total_num_kv_heads()
+         num_layers = model_config.get_num_layers(parallel_config)
+
+         self.runner.vllm_config.compilation_config.static_forward_context = {}
+         kv_cache_spec = self.runner.get_kv_cache_spec()
+
+         assert len(kv_cache_spec) == num_layers
+         expected_full_attn_spec = FullAttentionSpec(
+             block_size=self.runner.vllm_config.cache_config.block_size,
+             num_kv_heads=common_utils.get_padded_num_heads(
+                 num_kv_heads, self.runner.mesh.shape["model"]),
+             head_size=common_utils.get_padded_head_dim(head_size),
+             dtype=torch.bfloat16)
+         for i in range(num_layers):
+             assert kv_cache_spec[f'layer.{i}'] == expected_full_attn_spec
+         assert len(self.runner.kv_cache_manager.shared_kv_cache_layers) == 0
+
+     def test_get_kv_cache_spec_without_compilation_cfg_mla(self):
+         self.runner.kv_cache_manager.use_mla = True
+         model_config = self.runner.vllm_config.model_config
+         parallel_config = self.runner.vllm_config.parallel_config
+         num_layers = model_config.get_num_layers(parallel_config)
+
+         mock_hf_text_config = MagicMock()
+         mock_hf_text_config.kv_lora_rank = 400
+         mock_hf_text_config.qk_rope_head_dim = 40
+         self.runner.model_config.hf_text_config = mock_hf_text_config
+         expected_head_size = 640  # align(400, 128) + align(40, 128) = 512 + 128
+
+         self.runner.vllm_config.compilation_config.static_forward_context = {}
+         with patch('vllm.config.ModelConfig.get_num_layers',
+                    return_value=num_layers):
+             kv_cache_spec = self.runner.get_kv_cache_spec()
+
+         assert len(kv_cache_spec) == num_layers
+         for i in range(num_layers):
+             spec = kv_cache_spec[f"layer.{i}"]
+             assert isinstance(spec, MLAAttentionSpec)
+             assert spec.num_kv_heads == 1
+             assert spec.head_size == expected_head_size
+
+     def test_initialize_kv_cache(self):
+         # Create a KV cache config with 10 full-attention layers and
+         # 10 sliding-window attention layers.
+         block_size = self.runner.vllm_config.cache_config.block_size
+         num_kv_heads = 8
+         head_size = 128
+         sliding_window = 100
+         num_blocks = 100
+         kv_packing = 2  # bf16
+         sliding_window_spec = SlidingWindowSpec(
+             block_size=block_size,
+             num_kv_heads=num_kv_heads,
+             head_size=head_size,
+             dtype=torch.bfloat16,
+             sliding_window=sliding_window,
+         )
+         full_attn_spec = FullAttentionSpec(
+             block_size=block_size,
+             num_kv_heads=num_kv_heads,
+             head_size=head_size,
+             dtype=torch.bfloat16,
+         )
+         kv_cache_groups = [
+             KVCacheGroupSpec(layer_names=[f'layer.{i}' for i in range(10)],
+                              kv_cache_spec=full_attn_spec),
+             KVCacheGroupSpec(layer_names=[f'layer.{i}' for i in range(10, 20)],
+                              kv_cache_spec=sliding_window_spec),
+         ]
+         kv_cache_tensors = []
+         page_size_bytes = full_attn_spec.page_size_bytes
+         for i in range(10):
+             kv_cache_tensors.append(
+                 KVCacheTensor(
+                     size=num_blocks * page_size_bytes,
+                     shared_by=[f'layer.{i}', f'layer.{i+10}'],
+                 ))
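+         # Each backing tensor is shared by one full-attention layer and one
+         # sliding-window layer, so 20 layers map onto 10 physical caches.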
+         kv_cache_config = KVCacheConfig(
+             num_blocks=num_blocks,
+             kv_cache_tensors=kv_cache_tensors,
+             kv_cache_groups=kv_cache_groups,
+         )
+
+         original_input_batch = self.runner.input_batch
+         self.runner.initialize_kv_cache(kv_cache_config)
+
+         # A KV cache config with multiple KV cache groups should reinit
+         # the input batch.
+         assert original_input_batch != self.runner.input_batch
+         assert len(self.runner.kv_caches) == 10
+         for i in range(10):
+             assert self.runner.kv_caches[i].shape == (num_blocks, block_size,
+                                                       num_kv_heads * 2 //
+                                                       kv_packing, kv_packing,
+                                                       head_size)
+             assert self.runner.layer_name_to_kvcache_index[f'layer.{i}'] == i
+             assert self.runner.layer_name_to_kvcache_index[
+                 f'layer.{i + 10}'] == i
+
+     def test_get_kv_cache_spec_with_eagle3(self):
+         # Tests that a KV cache spec is created for the eagle3 draft model.
+         self.runner.vllm_config.compilation_config.static_forward_context = {}
+         mock_speculative_config = MagicMock()
+         mock_speculative_config.method = "eagle3"
+         mock_draft_model_config = MagicMock()
+         mock_hf_config = MagicMock()
+         mock_hf_config.num_key_value_heads = 4
+         mock_hf_config.hidden_size = 1024
+         mock_hf_config.num_attention_heads = 8
+         mock_draft_model_config.hf_config = mock_hf_config
+         mock_speculative_config.draft_model_config = mock_draft_model_config
+         self.runner.speculative_config = mock_speculative_config
+
+         kv_cache_spec = self.runner.get_kv_cache_spec()
+
+         assert "draft_layer.0" in kv_cache_spec
+         draft_spec = kv_cache_spec["draft_layer.0"]
+         assert isinstance(draft_spec, FullAttentionSpec)
+         assert draft_spec.block_size == self.runner.vllm_config.cache_config.block_size
+         assert draft_spec.num_kv_heads == common_utils.get_padded_num_heads(
+             4, self.runner.mesh.shape["model"])
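+         # Draft head_dim = hidden_size // num_attention_heads = 1024 // 8 = 128.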
+         assert draft_spec.head_size == common_utils.get_padded_head_dim(128)
+         assert draft_spec.dtype == torch.bfloat16
+
+     def test_get_kv_cache_spec_with_eagle3_mla(self):
+         # Tests that a KV cache spec is created for the eagle3 draft model
+         # when the target model uses MLA.
+         self.runner.kv_cache_manager.use_mla = True
+
+         self.runner.vllm_config.compilation_config.static_forward_context = {}
+         mock_speculative_config = MagicMock()
+         mock_speculative_config.method = "eagle3"
+         mock_draft_model_config = MagicMock()
+         mock_hf_config = MagicMock()
+         mock_hf_config.num_key_value_heads = 4
+         mock_hf_config.hidden_size = 1024
+         mock_hf_config.num_attention_heads = 8
+         mock_hf_config.num_layers = 16
+         model_layers = 1
+         mock_hf_text_config = MagicMock()
+         mock_hf_text_config.kv_lora_rank = 400
+         mock_hf_text_config.qk_rope_head_dim = 40
+         self.runner.model_config.hf_text_config = mock_hf_text_config
+         mock_draft_model_config.hf_config = mock_hf_config
+         mock_speculative_config.draft_model_config = mock_draft_model_config
+         self.runner.speculative_config = mock_speculative_config
+
+         kv_cache_spec = self.runner.get_kv_cache_spec()
+
+         # Without a compilation context, specs are created for the main
+         # model layers as well as the draft model layer.
+         assert len(kv_cache_spec) > model_layers
+
+         assert "draft_layer.0" in kv_cache_spec
+         draft_spec = kv_cache_spec["draft_layer.0"]
+         assert isinstance(draft_spec, FullAttentionSpec)
+
+         for i in range(model_layers):
+             assert f"layer.{i}" in kv_cache_spec
+             spec = kv_cache_spec[f"layer.{i}"]
+             assert isinstance(spec, MLAAttentionSpec)
+             assert spec.num_kv_heads == 1