tpu-inference 0.12.0.dev20251222__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (260)
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_core_tpu.py +513 -0
  4. tests/core/test_disagg_executor.py +60 -0
  5. tests/core/test_disagg_utils.py +67 -0
  6. tests/core/test_dp_scheduler.py +724 -0
  7. tests/core/test_init.py +63 -0
  8. tests/distributed/__init__.py +13 -0
  9. tests/distributed/test_distributed_utils.py +120 -0
  10. tests/distributed/test_tpu_connector.py +478 -0
  11. tests/e2e/__init__.py +13 -0
  12. tests/e2e/test_async_scheduler.py +211 -0
  13. tests/e2e/test_data_parallel.py +393 -0
  14. tests/e2e/test_local_disagg.py +257 -0
  15. tests/e2e/test_model_loader.py +268 -0
  16. tests/e2e/test_multi_modal_inference.py +111 -0
  17. tests/e2e/test_pipeline_parallel.py +265 -0
  18. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  19. tests/e2e/test_sampling_params.py +269 -0
  20. tests/e2e/test_speculative_decoding.py +291 -0
  21. tests/e2e/test_structured_decoding.py +46 -0
  22. tests/executors/__init__.py +13 -0
  23. tests/executors/test_ray_distributed_executor.py +199 -0
  24. tests/experimental/__init__.py +13 -0
  25. tests/experimental/test_llama3_jax_stashed.py +208 -0
  26. tests/kernels/__init__.py +13 -0
  27. tests/kernels/collectives/__init__.py +13 -0
  28. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  29. tests/kernels/fused_moe_v1_test.py +388 -0
  30. tests/kernels/gmm_test.py +205 -0
  31. tests/kernels/mla_v1_test.py +498 -0
  32. tests/kernels/quantized_matmul_kernel_test.py +159 -0
  33. tests/kernels/ragged_kv_cache_update_v2_test.py +248 -0
  34. tests/kernels/ragged_paged_attention_kernel_v2_test.py +414 -0
  35. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +565 -0
  36. tests/kernels/ragged_paged_attention_kernel_v3_test.py +520 -0
  37. tests/layers/__init__.py +13 -0
  38. tests/layers/common/__init__.py +13 -0
  39. tests/layers/common/test_attention_interface.py +156 -0
  40. tests/layers/common/test_quantization.py +149 -0
  41. tests/layers/jax/__init__.py +13 -0
  42. tests/layers/jax/attention/__init__.py +13 -0
  43. tests/layers/jax/attention/test_common_attention.py +103 -0
  44. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  45. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  46. tests/layers/jax/moe/__init__.py +13 -0
  47. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  48. tests/layers/jax/sample/__init__.py +13 -0
  49. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  50. tests/layers/jax/sample/test_sampling.py +115 -0
  51. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  52. tests/layers/jax/test_layers.py +155 -0
  53. tests/layers/jax/test_qwix.py +969 -0
  54. tests/layers/jax/test_rope.py +93 -0
  55. tests/layers/jax/test_sharding.py +159 -0
  56. tests/layers/jax/test_transformer_block.py +152 -0
  57. tests/layers/vllm/__init__.py +13 -0
  58. tests/layers/vllm/test_attention.py +363 -0
  59. tests/layers/vllm/test_awq.py +405 -0
  60. tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
  61. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +403 -0
  62. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +426 -0
  63. tests/layers/vllm/test_fp8.py +17 -0
  64. tests/layers/vllm/test_mxfp4.py +297 -0
  65. tests/layers/vllm/test_unquantized.py +621 -0
  66. tests/layers/vllm/utils.py +72 -0
  67. tests/lora/__init__.py +13 -0
  68. tests/lora/conftest.py +46 -0
  69. tests/lora/test_bgmv.py +57 -0
  70. tests/lora/test_layers.py +666 -0
  71. tests/lora/test_lora.py +147 -0
  72. tests/lora/test_lora_perf.py +67 -0
  73. tests/lora/utils.py +88 -0
  74. tests/models/__init__.py +13 -0
  75. tests/models/common/__init__.py +13 -0
  76. tests/models/common/test_model_loader.py +455 -0
  77. tests/models/jax/__init__.py +13 -0
  78. tests/models/jax/test_deepseek_v3.py +401 -0
  79. tests/models/jax/test_llama3.py +184 -0
  80. tests/models/jax/test_llama4.py +298 -0
  81. tests/models/jax/test_llama_eagle3.py +197 -0
  82. tests/models/jax/test_llama_guard_4.py +242 -0
  83. tests/models/jax/test_qwen2.py +172 -0
  84. tests/models/jax/test_qwen2_5_vl.py +606 -0
  85. tests/models/jax/test_qwen3.py +169 -0
  86. tests/models/jax/test_weight_loading.py +180 -0
  87. tests/models/jax/utils/__init__.py +13 -0
  88. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  89. tests/platforms/__init__.py +13 -0
  90. tests/platforms/test_tpu_platform.py +54 -0
  91. tests/runner/__init__.py +13 -0
  92. tests/runner/test_block_table.py +395 -0
  93. tests/runner/test_input_batch.py +226 -0
  94. tests/runner/test_kv_cache.py +220 -0
  95. tests/runner/test_kv_cache_manager.py +498 -0
  96. tests/runner/test_multimodal_manager.py +429 -0
  97. tests/runner/test_persistent_batch_manager.py +84 -0
  98. tests/runner/test_speculative_decoding_manager.py +368 -0
  99. tests/runner/test_structured_decoding_manager.py +220 -0
  100. tests/runner/test_tpu_runner.py +202 -0
  101. tests/runner/test_tpu_runner_dp.py +1033 -0
  102. tests/runner/test_tpu_runner_mesh.py +200 -0
  103. tests/runner/test_utils.py +411 -0
  104. tests/spec_decode/__init__.py +13 -0
  105. tests/spec_decode/test_eagle3.py +311 -0
  106. tests/test_base.py +215 -0
  107. tests/test_envs.py +280 -0
  108. tests/test_tpu_info.py +134 -0
  109. tests/test_utils.py +193 -0
  110. tests/worker/__init__.py +13 -0
  111. tests/worker/tpu_worker_test.py +414 -0
  112. tpu_inference/__init__.py +67 -0
  113. tpu_inference/core/__init__.py +13 -0
  114. tpu_inference/core/core_tpu.py +786 -0
  115. tpu_inference/core/disagg_executor.py +118 -0
  116. tpu_inference/core/disagg_utils.py +49 -0
  117. tpu_inference/core/sched/__init__.py +13 -0
  118. tpu_inference/core/sched/dp_scheduler.py +814 -0
  119. tpu_inference/distributed/__init__.py +13 -0
  120. tpu_inference/distributed/jax_parallel_state.py +81 -0
  121. tpu_inference/distributed/tpu_connector.py +732 -0
  122. tpu_inference/distributed/utils.py +112 -0
  123. tpu_inference/env_override.py +9 -0
  124. tpu_inference/envs.py +191 -0
  125. tpu_inference/executors/__init__.py +13 -0
  126. tpu_inference/executors/ray_distributed_executor.py +399 -0
  127. tpu_inference/experimental/__init__.py +13 -0
  128. tpu_inference/experimental/llama3_jax_stashed.py +272 -0
  129. tpu_inference/kernels/__init__.py +13 -0
  130. tpu_inference/kernels/collectives/__init__.py +13 -0
  131. tpu_inference/kernels/collectives/all_gather_matmul.py +741 -0
  132. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +65 -0
  133. tpu_inference/kernels/collectives/util.py +47 -0
  134. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  135. tpu_inference/kernels/flash_attention/kernel.py +772 -0
  136. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  137. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  138. tpu_inference/kernels/fused_moe/v1/kernel.py +1612 -0
  139. tpu_inference/kernels/megablox/__init__.py +13 -0
  140. tpu_inference/kernels/megablox/common.py +54 -0
  141. tpu_inference/kernels/megablox/gmm.py +646 -0
  142. tpu_inference/kernels/mla/__init__.py +13 -0
  143. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  144. tpu_inference/kernels/mla/v1/kernel.py +1340 -0
  145. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  146. tpu_inference/kernels/quantized_matmul/kernel.py +456 -0
  147. tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
  148. tpu_inference/kernels/quantized_matmul/util.py +58 -0
  149. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  150. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  151. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +876 -0
  152. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +288 -0
  153. tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
  154. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  155. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1594 -0
  156. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1586 -0
  157. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4460 -0
  158. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +548 -0
  159. tpu_inference/kernels/ragged_paged_attention/v3/util.py +65 -0
  160. tpu_inference/layers/__init__.py +13 -0
  161. tpu_inference/layers/common/__init__.py +13 -0
  162. tpu_inference/layers/common/attention_interface.py +403 -0
  163. tpu_inference/layers/common/attention_metadata.py +48 -0
  164. tpu_inference/layers/common/binary_search.py +295 -0
  165. tpu_inference/layers/common/quant_methods.py +23 -0
  166. tpu_inference/layers/common/quantization.py +270 -0
  167. tpu_inference/layers/common/sharding.py +600 -0
  168. tpu_inference/layers/jax/__init__.py +13 -0
  169. tpu_inference/layers/jax/attention/__init__.py +13 -0
  170. tpu_inference/layers/jax/attention/attention.py +268 -0
  171. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +547 -0
  172. tpu_inference/layers/jax/attention/gpt_oss_attention.py +275 -0
  173. tpu_inference/layers/jax/attention/llama4_attention.py +167 -0
  174. tpu_inference/layers/jax/base.py +165 -0
  175. tpu_inference/layers/jax/constants.py +101 -0
  176. tpu_inference/layers/jax/layers.py +315 -0
  177. tpu_inference/layers/jax/misc.py +30 -0
  178. tpu_inference/layers/jax/moe/__init__.py +13 -0
  179. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +615 -0
  180. tpu_inference/layers/jax/moe/gpt_oss_moe.py +199 -0
  181. tpu_inference/layers/jax/moe/moe.py +249 -0
  182. tpu_inference/layers/jax/pp_utils.py +53 -0
  183. tpu_inference/layers/jax/rope.py +294 -0
  184. tpu_inference/layers/jax/rope_interface.py +228 -0
  185. tpu_inference/layers/jax/sample/__init__.py +13 -0
  186. tpu_inference/layers/jax/sample/rejection_sampler.py +528 -0
  187. tpu_inference/layers/jax/sample/sampling.py +110 -0
  188. tpu_inference/layers/jax/sample/sampling_metadata.py +90 -0
  189. tpu_inference/layers/jax/transformer_block.py +121 -0
  190. tpu_inference/layers/vllm/__init__.py +13 -0
  191. tpu_inference/layers/vllm/attention.py +221 -0
  192. tpu_inference/layers/vllm/fused_moe.py +502 -0
  193. tpu_inference/layers/vllm/linear_common.py +221 -0
  194. tpu_inference/layers/vllm/quantization/__init__.py +55 -0
  195. tpu_inference/layers/vllm/quantization/awq.py +221 -0
  196. tpu_inference/layers/vllm/quantization/common.py +124 -0
  197. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  198. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +135 -0
  199. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +266 -0
  200. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  201. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +222 -0
  202. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +150 -0
  203. tpu_inference/layers/vllm/quantization/fp8.py +118 -0
  204. tpu_inference/layers/vllm/quantization/mxfp4.py +396 -0
  205. tpu_inference/layers/vllm/quantization/unquantized.py +416 -0
  206. tpu_inference/layers/vllm/sharding.py +244 -0
  207. tpu_inference/logger.py +10 -0
  208. tpu_inference/lora/__init__.py +13 -0
  209. tpu_inference/lora/torch_lora_ops.py +98 -0
  210. tpu_inference/lora/torch_punica_tpu.py +310 -0
  211. tpu_inference/models/__init__.py +13 -0
  212. tpu_inference/models/common/__init__.py +13 -0
  213. tpu_inference/models/common/model_loader.py +520 -0
  214. tpu_inference/models/jax/__init__.py +13 -0
  215. tpu_inference/models/jax/deepseek_v3.py +978 -0
  216. tpu_inference/models/jax/gpt_oss.py +508 -0
  217. tpu_inference/models/jax/jax_intermediate_tensor.py +93 -0
  218. tpu_inference/models/jax/llama3.py +436 -0
  219. tpu_inference/models/jax/llama4.py +643 -0
  220. tpu_inference/models/jax/llama_eagle3.py +350 -0
  221. tpu_inference/models/jax/llama_guard_4.py +375 -0
  222. tpu_inference/models/jax/qwen2.py +390 -0
  223. tpu_inference/models/jax/qwen2_5_vl.py +1232 -0
  224. tpu_inference/models/jax/qwen3.py +318 -0
  225. tpu_inference/models/jax/utils/__init__.py +13 -0
  226. tpu_inference/models/jax/utils/file_utils.py +110 -0
  227. tpu_inference/models/jax/utils/multi_modal_utils.py +177 -0
  228. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  229. tpu_inference/models/jax/utils/qwix/qwix_utils.py +713 -0
  230. tpu_inference/models/jax/utils/weight_utils.py +621 -0
  231. tpu_inference/models/vllm/__init__.py +13 -0
  232. tpu_inference/models/vllm/vllm_model_wrapper.py +307 -0
  233. tpu_inference/models/vllm/vllm_model_wrapper_context.py +59 -0
  234. tpu_inference/platforms/__init__.py +16 -0
  235. tpu_inference/platforms/tpu_platform.py +258 -0
  236. tpu_inference/runner/__init__.py +13 -0
  237. tpu_inference/runner/block_table.py +122 -0
  238. tpu_inference/runner/compilation_manager.py +890 -0
  239. tpu_inference/runner/input_batch.py +435 -0
  240. tpu_inference/runner/kv_cache.py +166 -0
  241. tpu_inference/runner/kv_cache_manager.py +508 -0
  242. tpu_inference/runner/lora_utils.py +106 -0
  243. tpu_inference/runner/multimodal_manager.py +231 -0
  244. tpu_inference/runner/persistent_batch_manager.py +296 -0
  245. tpu_inference/runner/speculative_decoding_manager.py +262 -0
  246. tpu_inference/runner/structured_decoding_manager.py +101 -0
  247. tpu_inference/runner/tpu_runner.py +1768 -0
  248. tpu_inference/runner/utils.py +426 -0
  249. tpu_inference/spec_decode/__init__.py +13 -0
  250. tpu_inference/spec_decode/jax/__init__.py +13 -0
  251. tpu_inference/spec_decode/jax/eagle3.py +430 -0
  252. tpu_inference/tpu_info.py +92 -0
  253. tpu_inference/utils.py +345 -0
  254. tpu_inference/worker/__init__.py +13 -0
  255. tpu_inference/worker/tpu_worker.py +468 -0
  256. tpu_inference-0.12.0.dev20251222.dist-info/METADATA +106 -0
  257. tpu_inference-0.12.0.dev20251222.dist-info/RECORD +260 -0
  258. tpu_inference-0.12.0.dev20251222.dist-info/WHEEL +5 -0
  259. tpu_inference-0.12.0.dev20251222.dist-info/licenses/LICENSE +201 -0
  260. tpu_inference-0.12.0.dev20251222.dist-info/top_level.txt +2 -0
tpu_inference/runner/kv_cache_manager.py
@@ -0,0 +1,508 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import functools
+ from typing import TYPE_CHECKING, Dict, List
+
+ import jax
+ import jax.numpy as jnp
+ import numpy as np
+ import vllm.envs as envs
+ from jax.sharding import NamedSharding, PartitionSpec
+ from torchax.ops.mappings import t2j_dtype
+ from vllm.attention.backends.abstract import AttentionType
+ from vllm.attention.layer import Attention
+ from vllm.config import get_layers_from_vllm_config
+ from vllm.utils.math_utils import cdiv
+ from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
+                                         KVCacheSpec, MLAAttentionSpec,
+                                         SlidingWindowSpec)
+
+ from tpu_inference import utils
+ from tpu_inference import utils as common_utils
+ from tpu_inference.logger import init_logger
+ from tpu_inference.runner import utils as runner_utils
+ from tpu_inference.runner.input_batch import CachedRequestState, InputBatch
+ from tpu_inference.runner.kv_cache import create_kv_caches
+
+ if TYPE_CHECKING:
+     from vllm.v1.request import Request
+
+     from tpu_inference.runner.tpu_runner import TPUModelRunner
+
+ logger = init_logger(__name__)
+
+
+ class KVCacheManager:
+
+     def __init__(self, runner: "TPUModelRunner"):
+         self.runner = runner
+         # Layer pairings for cross-layer KV sharing.
+         # If an Attention layer `layer_name` is in the keys of this dict, it
+         # means this layer will perform attention using the keys and values
+         # from the KV cache of `shared_kv_cache_layers[layer_name]`.
+         self.shared_kv_cache_layers: dict[str, str] = {}
+         self.use_mla = self.runner.model_config.use_mla
+
+     def get_kv_cache_spec(self):
+         # TODO(xiang): this hack tricks engine core to init successfully
+         block_size = self.runner.cache_config.block_size
+         kv_cache_spec: dict[str, KVCacheSpec] = {}
+
+         # If use pure jax (MODEL_IMPL_TYPE=flax_nnx), we don't register
+         # attention into compilation config.
+         # Use FullAttentionSpec for each layer
+         # TODO(pooyam): Is it possible to merge the logic for vllm and non-vllm models?
+         model_config = self.runner.model_config
+         if self.use_mla:
+             # Individually pad the RopE and latents
+             qk_rope_head_dim = getattr(model_config.hf_text_config,
+                                        "qk_rope_head_dim", 0)
+             padded_kv_lora_rank = common_utils.align_to(
+                 model_config.hf_text_config.kv_lora_rank, 128)
+             padded_qk_rope_head_dim = common_utils.align_to(
+                 qk_rope_head_dim, 128)
+             mla_head_size = padded_kv_lora_rank + padded_qk_rope_head_dim
+
+         if len(self.runner.vllm_config.compilation_config.
+                static_forward_context) == 0:
+             parallel_config = self.runner.parallel_config
+             # Pad num_kv_heads to multiple of TP size.
+             num_kv_heads = common_utils.get_padded_num_heads(
+                 model_config.get_total_num_kv_heads(),
+                 self.runner.mesh.shape["model"])
+             head_size = common_utils.get_padded_head_dim(
+                 model_config.get_head_size())
+             for i in range(model_config.get_num_layers(parallel_config)):
+                 if self.use_mla:
+                     kv_cache_spec[f"layer.{i}"] = MLAAttentionSpec(
+                         block_size=block_size,
+                         num_kv_heads=1,
+                         head_size=mla_head_size,
+                         dtype=self.runner.kv_cache_dtype,
+                         cache_dtype_str=self.runner.vllm_config.cache_config.
+                         cache_dtype)
+                 else:
+                     kv_cache_spec[f"layer.{i}"] = FullAttentionSpec(
+                         block_size=block_size,
+                         num_kv_heads=num_kv_heads,
+                         head_size=head_size,
+                         dtype=self.runner.kv_cache_dtype)
+             if self.runner.speculative_config and self.runner.speculative_config.method == "eagle3":
+                 draft_model_config = self.runner.speculative_config.draft_model_config
+                 hf_config = draft_model_config.hf_config
+                 num_kv_heads = common_utils.get_padded_num_heads(
+                     hf_config.num_key_value_heads,
+                     self.runner.mesh.shape["model"])
+                 head_size = common_utils.get_padded_head_dim(
+                     hf_config.hidden_size // hf_config.num_attention_heads)
+                 # Eagle3 has only 1 layer
+                 for i in range(1):
+                     if self.use_mla:
+                         kv_cache_spec[f"draft_layer.{i}"] = MLAAttentionSpec(
+                             block_size=block_size,
+                             num_kv_heads=1,
+                             head_size=mla_head_size,
+                             dtype=self.runner.kv_cache_dtype,
+                             cache_dtype_str=self.runner.vllm_config.
+                             cache_config.cache_dtype)
+                     else:
+                         kv_cache_spec[f"draft_layer.{i}"] = FullAttentionSpec(
+                             block_size=block_size,
+                             num_kv_heads=num_kv_heads,
+                             head_size=head_size,
+                             dtype=self.runner.kv_cache_dtype)
+         else:
+             # Else propagate attention modules from compilation config.
+             layers = get_layers_from_vllm_config(self.runner.vllm_config,
+                                                  Attention)
+             logger.warning(f"Compilation num_layers = {len(layers.items())}")
+             for layer_name, attn_module in layers.items():
+                 if (kv_tgt_layer :=
+                         attn_module.kv_sharing_target_layer_name) is not None:
+                     # The layer doesn't need its own KV cache and will use that of
+                     # the target layer. We skip creating a KVCacheSpec for it, so
+                     # that KV cache management logic will act as this layer does
+                     # not exist, and doesn't allocate KV cache for the layer. This
+                     # enables the memory saving of cross-layer kv sharing, allowing
+                     # a given amount of memory to accommodate longer context lengths
+                     # or enable more requests to be processed simultaneously.
+                     self.shared_kv_cache_layers[layer_name] = kv_tgt_layer
+                     continue
+                 if attn_module.attn_type == AttentionType.DECODER:
+                     if attn_module.sliding_window is not None:
+                         kv_cache_spec[layer_name] = SlidingWindowSpec(
+                             block_size=block_size,
+                             num_kv_heads=common_utils.get_padded_num_heads(
+                                 attn_module.num_kv_heads,
+                                 self.runner.mesh.shape["model"]),
+                             head_size=common_utils.get_padded_head_dim(
+                                 attn_module.head_size),
+                             dtype=self.runner.kv_cache_dtype,
+                             sliding_window=attn_module.sliding_window)
+                     elif self.use_mla:
+                         kv_cache_spec[layer_name] = MLAAttentionSpec(
+                             block_size=block_size,
+                             num_kv_heads=1,
+                             head_size=mla_head_size,
+                             dtype=self.runner.kv_cache_dtype,
+                             cache_dtype_str=self.runner.vllm_config.
+                             cache_config.cache_dtype)
+                     else:
+                         kv_cache_spec[layer_name] = FullAttentionSpec(
+                             block_size=block_size,
+                             num_kv_heads=common_utils.get_padded_num_heads(
+                                 attn_module.num_kv_heads,
+                                 self.runner.mesh.shape["model"]),
+                             head_size=common_utils.get_padded_head_dim(
+                                 attn_module.head_size),
+                             dtype=self.runner.kv_cache_dtype)
+                 elif attn_module.attn_type in (AttentionType.ENCODER,
+                                                AttentionType.ENCODER_ONLY):
+                     # encoder-only attention does not need KV cache.
+                     continue
+                 elif attn_module.attn_type == AttentionType.ENCODER_DECODER:
+                     raise NotImplementedError
+                 else:
+                     raise ValueError(
+                         f"Unknown attention type: {attn_module.attn_type}")
+         return kv_cache_spec
+
+     def maybe_reinitialize_input_batch(self,
+                                        kv_cache_config: KVCacheConfig) -> None:
+         block_sizes = [
+             kv_cache_group.kv_cache_spec.block_size
+             for kv_cache_group in kv_cache_config.kv_cache_groups
+         ]
+         if block_sizes != [self.runner.cache_config.block_size]:
+             assert self.runner.cache_config.cpu_offload_gb == 0, (
+                 "Cannot re-initialize the input batch when CPU weight "
+                 "offloading is enabled. See https://github.com/vllm-project/vllm/pull/18298 "  # noqa: E501
+                 "for more details.")
+             new_input_batch = InputBatch(
+                 max_num_reqs=self.runner.max_num_reqs,
+                 max_model_len=self.runner.max_model_len,
+                 max_num_batched_tokens=self.runner.max_num_tokens,
+                 pin_memory=False,
+                 vocab_size=self.runner.model_config.get_vocab_size(),
+                 block_sizes=block_sizes,
+             )
+             self.runner.input_batch = new_input_batch
+             self.runner.persistent_batch_manager.input_batch = new_input_batch
+             self.runner.block_tables_cpu = [
+                 np.zeros((self.runner.max_num_reqs,
+                           cdiv(self.runner.max_model_len, block_size)),
+                          dtype=np.int32) for block_size in block_sizes
+             ]
+
+     def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
+         self.maybe_reinitialize_input_batch(kv_cache_config)
+
+         # uniform page size.
+         representative_spec = kv_cache_config.kv_cache_groups[0].kv_cache_spec
+         page_size_bytes = representative_spec.page_size_bytes
+         self.runner.layer_name_to_kvcache_index: Dict[str, int] = {}
+         kv_caches = self.runner.kv_caches
+         num_blocks_list = []
+         for i, kv_cache_tensor in enumerate(kv_cache_config.kv_cache_tensors):
+             assert kv_cache_tensor.size % page_size_bytes == 0
+             num_blocks = kv_cache_tensor.size // page_size_bytes
+             dp_size = self.runner.vllm_config.sharding_config.total_dp_size
+             # num_blocks must be a multiple of dp_size
+             num_blocks = (num_blocks // dp_size) * dp_size
+             # NOTE: we'll multiply the num_kv_heads by 2 in the function
+             if self.use_mla:
+                 head_size = self.runner.model_config.hf_config.kv_lora_rank + \
+                     self.runner.model_config.hf_config.qk_rope_head_dim
+             else:
+                 head_size = representative_spec.head_size
+             kv_cache = create_kv_caches(
+                 num_blocks=num_blocks,
+                 block_size=representative_spec.block_size,
+                 num_kv_heads=representative_spec.num_kv_heads,
+                 head_size=head_size,
+                 mesh=self.runner.mesh,
+                 layer_names=[f'kv_cache_tensor.{i}'],
+                 cache_dtype=t2j_dtype(representative_spec.dtype),
+                 use_mla=self.use_mla,
+             )[0]
+             kv_caches.append(kv_cache)
+             num_blocks_list.append(num_blocks)
+             for layer_name in kv_cache_tensor.shared_by:
+                 self.runner.layer_name_to_kvcache_index[layer_name] = i
+
+         if self.shared_kv_cache_layers:
+             for layer_name, target_layer_name in self.shared_kv_cache_layers.items(
+             ):
+                 self.runner.layer_name_to_kvcache_index[
+                     layer_name] = self.runner.layer_name_to_kvcache_index[
+                         target_layer_name]
+
+         logger.info(
+             f"Init kv-cache | "
+             f"num_layers={len(kv_caches)} | "
+             f"shape=(num_blocks, {kv_caches[0].shape[1:]}) | "
+             f"num_blocks={num_blocks_list} | "
+             f"sharding={kv_caches[0].sharding} | "
+             f"dtype={kv_caches[0].dtype} | "
+             f"hbm={utils.hbm_usage_gb(self.runner.mesh.devices.flatten())}Gb")
+
+     @staticmethod
+     @functools.partial(jax.jit)
+     def _jitted_gather_kv_cache(kv_caches: List[jax.Array],
+                                 block_ids: jax.Array) -> List[jax.Array]:
+         """
+         JIT-compiled function to gather KV cache slices for all layers at once.
+         This uses jax.tree.map to apply the operation across all layers.
+         """
+
+         def gather_and_reshape(layer_kv_cache):
+             return layer_kv_cache.at[block_ids].get().reshape(
+                 -1, *layer_kv_cache.shape[2:])
+
+         return jax.tree.map(gather_and_reshape, kv_caches)
+
+     @staticmethod
+     @functools.partial(
+         jax.jit,
+         static_argnames=("len_block"),
+     )
+     def _jitted_gather_continuous_kv_cache(kv_caches: List[jax.Array],
+                                            start_block,
+                                            len_block) -> List[jax.Array]:
+         """
+         JIT-compiled function to gather KV cache slices for all layers at once.
+         This uses jax.tree.map to apply the operation across all layers.
+         """
+
+         def gather_and_reshape(layer_kv_cache):
+             shape = layer_kv_cache.shape
+             return jax.lax.dynamic_slice_in_dim(layer_kv_cache,
+                                                 start_block,
+                                                 len_block,
+                                                 axis=0).reshape(
+                                                     -1, *shape[2:])
+
+         return jax.tree.map(gather_and_reshape, kv_caches)
+
+     @staticmethod
+     @functools.partial(
+         jax.jit,
+         static_argnames=("block_size"),
+         donate_argnames=(
+             "kv_caches",
+             "kv_cache_slices",
+         ),
+     )
+     def _jitted_insert_kv_cache(
+         block_size,
+         kv_caches: List[jax.Array],
+         kv_cache_slices: List[jax.Array],
+         block_numbers: jax.Array,
+     ) -> List[jax.Array]:
+         """
+         JIT-compiled function to insert KV cache slices into the physical
+         cache for all layers at once. This fuses the pad, reshape, and scatter
+         operations into a single efficient kernel.
+         """
+
+         def _update_layer(cache, slices):
+             """The function to apply to each layer's cache and slices."""
+             reshaped_slices = slices.reshape(-1, block_size, *slices.shape[1:])
+             cache.at[block_numbers].set(reshaped_slices)
+             return cache
+
+         return jax.tree.map(_update_layer, kv_caches, kv_cache_slices)
+
+     @staticmethod
+     @functools.partial(
+         jax.jit,
+         static_argnames=("block_size"),
+         donate_argnames=(
+             "kv_caches",
+             "kv_cache_slices",
+         ),
+     )
+     def _jitted_insert_continuous_kv_cache(
+         block_size,
+         kv_caches: List[jax.Array],
+         kv_cache_slices: List[jax.Array],
+         start_block,
+     ) -> List[jax.Array]:
+         """
+         JIT-compiled function to insert KV cache slices into continuous blocks.
+         Makes use of dynamic_update_slice_in_dim.
+         """
+
+         def _update_layer(cache, slices):
+             """The function to apply to each layer's cache and slices."""
+             reshaped_slices = slices.reshape(-1, block_size, *slices.shape[1:])
+
+             return jax.lax.dynamic_update_slice_in_dim(cache,
+                                                        reshaped_slices,
+                                                        start_block,
+                                                        axis=0)
+
+         return jax.tree.map(_update_layer, kv_caches, kv_cache_slices)
+
+     def get_kv_cache_for_block_ids(
+         self,
+         block_ids: List[int],
+     ) -> List[jax.Array]:
+         """
+         Extracts the KV cache slices for a given list of block IDs.
+         This assumes all provided blocks are full.
+
+         Args:
+             block_ids: A list of block IDs to extract KV cache for.
+
+         Returns:
+             A list of JAX arrays, with each array representing the KV cache
+             slices for a layer, concatenated for all blocks.
+         """
+         if block_ids == list(range(block_ids[0],
+                                    block_ids[0] + len(block_ids))):
+             batched_kv_cache_per_layer = self._jitted_gather_continuous_kv_cache(
+                 self.runner.kv_caches, block_ids[0], len(block_ids))
+
+         else:
+             batched_kv_cache_per_layer = self._jitted_gather_kv_cache(
+                 self.runner.kv_caches, jnp.array(block_ids))
+         return batched_kv_cache_per_layer
+
+     def transfer_kv_cache(self,
+                           kv_cache_slices: List[jax.Array]) -> List[jax.Array]:
+         """
+         Transfers KV cache slices to the runner's mesh.
+
+         This is used when a KV cache generated on one runner (e.g., a prefill
+         runner) needs to be used on another runner (e.g., a decode runner)
+         with a different device mesh. The transfer is asynchronous.
+
+         Args:
+             kv_cache_slices: A list of JAX arrays, where each array contains
+                 the KV cache slices for a specific layer. The shape of each
+                 slice is expected to be (num_tokens, num_kv_heads * 2, head_size).
+
+         Returns:
+             A new list of JAX arrays representing the KV cache slices, sharded
+             across the runner's device mesh.
+         """
+         # The KV cache slices have a shape of (num_tokens, num_kv_heads * 2, head_size).
+         # We shard along the num_kv_heads dimension (axis=1), which corresponds
+         # to the "model" axis of the mesh for tensor parallelism.
+         logger.debug(
+             f"Transferring kv cache shape {len(kv_cache_slices)} * {kv_cache_slices[0].shape} sharding {kv_cache_slices[0].sharding} size {kv_cache_slices[0].nbytes * len(kv_cache_slices)/1024/1024} Mbytes"
+         )
+         sharding = NamedSharding(self.runner.mesh,
+                                  PartitionSpec(None, "model"))
+         if envs.VLLM_TPU_USING_PATHWAYS:
+             from pathwaysutils.experimental import \
+                 reshard as experimental_reshard
+
+             def get_sharding(x):
+                 return sharding
+
+             sharding_spec_pytree = jax.tree.map(get_sharding, kv_cache_slices)
+             transferred_kv_cache = experimental_reshard.reshard(
+                 kv_cache_slices,
+                 sharding_spec_pytree,
+                 donate=False,
+             )
+         else:
+             transferred_kv_cache = jax.device_put(kv_cache_slices, sharding)
+
+         jax.block_until_ready(transferred_kv_cache)
+         return transferred_kv_cache
+
+     def insert_request_with_kv_cache(
+         self,
+         request: "Request",
+         kv_cache_slices: List[jax.Array],
+         block_ids: List[List[int]],
+     ):
+         """
+         Inserts a request and its KV cache into the runner. This is used to
+         transfer a request from a prefill runner to a decode runner.
+
+         The provided KV cache slices are copied into the physical blocks
+         allocated for the request. The runner's internal state is then updated
+         to include the request.
+
+         Args:
+             request: The vLLM request object, containing the state after prefill.
+             kv_cache_slices: The KV cache for the request, already transferred
+                 to this runner's mesh. This is a list of JAX arrays, one per layer.
+             block_ids: The physical block numbers allocated for this request on
+                 this runner. This is a list of lists, for each KV cache group.
+         """
+         # Assume one KV cache group for now, which is consistent with current setup.
+         if len(block_ids) > 1:
+             raise NotImplementedError(
+                 "Inserting KV cache for models with multiple KV cache groups "
+                 "is not supported yet.")
+         block_numbers = block_ids[0]
+         if block_numbers == list(
+                 range(block_numbers[0],
+                       block_numbers[0] + len(block_numbers))):
+             # For continuous blocks we use slice instead of scatter.
+             start_block = block_numbers[0]
+             with runner_utils.LatencyTracker(
+                     f"JittedInsertContinuousKVCache-b{len(block_numbers)}"):
+                 logger.debug(f"inserting to continuous blocks {block_numbers}")
+                 self.runner.kv_caches = KVCacheManager._jitted_insert_continuous_kv_cache(
+                     self.runner.block_size,
+                     self.runner.kv_caches,
+                     kv_cache_slices,
+                     start_block,
+                 )
+                 jax.block_until_ready(self.runner.kv_caches)
+         else:
+             with runner_utils.LatencyTracker(
+                     f"JittedInsertKVCache-b{len(block_numbers)}"):
+                 logger.debug(
+                     f"inserting to non continuous blocks {block_numbers}")
+                 self.runner.kv_caches = KVCacheManager._jitted_insert_kv_cache(
+                     self.runner.block_size,
+                     self.runner.kv_caches,
+                     kv_cache_slices,
+                     jnp.array(block_numbers),
+                 )
+                 jax.block_until_ready(self.runner.kv_caches)
+
+         logger.debug(
+             f"Updated kv cache entries cnt={len(self.runner.kv_caches)}")
+
+         # Update runner's internal state to track the new request.
+         req_id = request.request_id
+         if req_id in self.runner.requests:
+             logger.warning(
+                 f"Request {req_id} already exists in the runner. Overwriting.")
+
+         # Create a CachedRequestState object to add to the input batch.
+         req_state = CachedRequestState(
+             req_id=request.request_id,
+             prompt_token_ids=request.prompt_token_ids,
+             output_token_ids=[request.all_token_ids[-1]],
+             sampling_params=request.sampling_params,
+             block_ids=tuple(block_ids),
+             num_computed_tokens=request.num_computed_tokens,
+             lora_request=request.lora_request,
+             mm_features=getattr(request, "mm_features", []),
+             pooling_params=getattr(request, "pooling_params", None),
+             generator=None,
+         )
+
+         self.runner.requests[req_id] = req_state
+         self.runner.input_batch.add_request(req_state)
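The fast path in get_kv_cache_for_block_ids (and the matching insert path) hinges on the requested block IDs forming one contiguous run, in which case a single dynamic slice replaces a gather. Below is a minimal, self-contained sketch of that check and of both gather variants; it is an editor's illustration with toy shapes and values, not code from the package.

import jax
import jax.numpy as jnp

# Toy per-layer KV caches: (num_blocks, block_size, num_kv_heads * 2, head_size).
num_blocks, block_size, num_kv_heads, head_size = 8, 4, 2, 16
kv_caches = [jnp.zeros((num_blocks, block_size, 2 * num_kv_heads, head_size))
             for _ in range(2)]

block_ids = [3, 4, 5]
if block_ids == list(range(block_ids[0], block_ids[0] + len(block_ids))):
    # Contiguous run of blocks: a single dynamic slice per layer is enough.
    slices = [
        jax.lax.dynamic_slice_in_dim(c, block_ids[0], len(block_ids),
                                     axis=0).reshape(-1, *c.shape[2:])
        for c in kv_caches
    ]
else:
    # Scattered blocks: fall back to an indexed gather per layer.
    ids = jnp.array(block_ids)
    slices = [c.at[ids].get().reshape(-1, *c.shape[2:]) for c in kv_caches]

# Each per-layer result flattens blocks into tokens:
# (len(block_ids) * block_size, num_kv_heads * 2, head_size).
print(slices[0].shape)  # (12, 4, 16)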
tpu_inference/runner/lora_utils.py
@@ -0,0 +1,106 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from __future__ import annotations
+
+ from typing import TYPE_CHECKING
+
+ import numpy as np
+ from torchax.interop import jax_view
+ from vllm.lora.layers.base_linear import BaseLinearLayerWithLoRA
+ from vllm.lora.request import LoRARequest
+
+ from tpu_inference.layers.vllm.sharding import update_lora
+
+ if TYPE_CHECKING:
+     from tpu_inference.runner.tpu_runner import TPUModelRunner
+
+
+ class LoraUtils:
+
+     def __init__(self, runner: "TPUModelRunner"):
+         self.runner = runner
+
+     def set_active_loras(self, num_scheduled_tokens_per_req,
+                          total_num_scheduled_tokens,
+                          padded_total_num_scheduled_tokens):
+         # We need to respect padding when activating LoRA adapters
+         padded_num_scheduled_tokens_per_req = np.copy(
+             num_scheduled_tokens_per_req
+         )  # Copying to avoid accidental state corruption bugs
+         padded_num_scheduled_tokens_per_req[-1] += \
+             padded_total_num_scheduled_tokens - total_num_scheduled_tokens
+
+         prompt_lora_mapping: tuple[int, ...]  # of size input_batch.num_reqs
+         token_lora_mapping: tuple[int,
+                                   ...]  # of size np.sum(num_scheduled_tokens)
+         lora_requests: set[LoRARequest]
+         prompt_lora_mapping, token_lora_mapping, lora_requests = \
+             self.runner.input_batch.make_lora_inputs(padded_num_scheduled_tokens_per_req)
+         # One should not put lora_manager.set_active_loras under
+         # torchax.default_env() because set_active_loras also load lora from
+         # disk and torchax currently does not support that. Here we load the
+         # lora and set the lora weight to the linear layers.
+         self.runner._set_active_loras(prompt_lora_mapping, token_lora_mapping,
+                                       lora_requests)
+
+         params_and_buffers = update_lora(
+             self.runner.model.model, initial_params_buffers=self.runner.state)
+         self.runner.state = jax_view(params_and_buffers)
+
+     def extract_lora_metadata(self):
+         if self.runner.lora_config is None:
+             return None
+
+         metadata = {}
+         punica_wrapper = None
+         for _, m in self.runner.model.model.named_modules():
+             if isinstance(m, BaseLinearLayerWithLoRA):
+                 assert getattr(
+                     m, 'punica_wrapper', None
+                 ) is not None, 'A lora wrapper should have contained a punica_wrapper'
+                 punica_wrapper = m.punica_wrapper
+                 break
+         assert punica_wrapper is not None, 'Should have been able to find a punica wrapper from the Lora wrapper.'
+
+         # vars does not show inherited methods or class attributes but this is
+         # fine because we only care about instance attributes.
+         for k in vars(punica_wrapper):
+             v = getattr(punica_wrapper, k, None)
+             if k == 'device':  # Exclude string as it can't be traced by jax.jit
+                 continue
+             metadata[k] = v
+         return jax_view(metadata)
+
+
+ def replace_lora_metadata(model, metadata: dict, lora_config) -> dict:
+     if lora_config is None or not metadata:
+         return {}
+
+     original_metadata = {}
+     punica_wrapper = None
+     for _, m in model.named_modules():
+         if isinstance(m, BaseLinearLayerWithLoRA):
+             assert getattr(
+                 m, 'punica_wrapper', None
+             ) is not None, 'A lora wrapper should have contained a punica_wrapper'
+             punica_wrapper = m.punica_wrapper
+             break
+     assert punica_wrapper is not None, 'Should have been able to find a punica wrapper from the Lora wrapper.'
+
+     for k in vars(punica_wrapper):
+         if k in metadata:
+             original_metadata[k] = getattr(punica_wrapper, k)
+             setattr(punica_wrapper, k, metadata[k])
+     return original_metadata
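The padding bookkeeping at the top of set_active_loras can be illustrated in isolation: when the total scheduled token count is padded up, the extra tokens are attributed to the last request so that the per-request counts passed to make_lora_inputs still sum to the padded total. The following is an editor's sketch with made-up numbers, not code from the package.

import numpy as np

num_scheduled_tokens_per_req = np.array([5, 3, 7])
total_num_scheduled_tokens = int(num_scheduled_tokens_per_req.sum())  # 15
padded_total_num_scheduled_tokens = 16  # illustrative padding target

# Copy first so the scheduler's own array is not mutated.
padded = np.copy(num_scheduled_tokens_per_req)
padded[-1] += padded_total_num_scheduled_tokens - total_num_scheduled_tokens

assert padded.tolist() == [5, 3, 8]
assert padded.sum() == padded_total_num_scheduled_tokens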