tpu-inference 0.11.1.dev202511220812__py3-none-any.whl → 0.13.2.dev20251230__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of tpu-inference might be problematic; see the package registry's release page for more details.

Files changed (257)
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_disagg_utils.py +14 -0
  4. tests/core/test_dp_scheduler.py +650 -768
  5. tests/core/test_init.py +14 -0
  6. tests/distributed/__init__.py +13 -0
  7. tests/distributed/test_distributed_utils.py +120 -0
  8. tests/distributed/test_tpu_connector.py +478 -0
  9. tests/e2e/__init__.py +13 -0
  10. tests/e2e/test_async_scheduler.py +211 -0
  11. tests/e2e/test_data_parallel.py +289 -0
  12. tests/e2e/test_hybrid_kvcache.py +219 -0
  13. tests/e2e/test_local_disagg.py +257 -0
  14. tests/e2e/test_model_loader.py +268 -0
  15. tests/e2e/test_multi_modal_inference.py +111 -0
  16. tests/e2e/test_pipeline_parallel.py +265 -0
  17. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  18. tests/e2e/test_sampling_params.py +269 -0
  19. tests/e2e/test_speculative_decoding.py +311 -0
  20. tests/e2e/test_structured_decoding.py +46 -0
  21. tests/executors/__init__.py +13 -0
  22. tests/executors/test_ray_distributed_executor.py +199 -0
  23. tests/experimental/__init__.py +13 -0
  24. tests/experimental/test_llama3_jax_stashed.py +208 -0
  25. tests/kernels/__init__.py +13 -0
  26. tests/kernels/collectives/__init__.py +13 -0
  27. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  28. tests/kernels/fused_moe_v1_test.py +317 -34
  29. tests/kernels/gmm_test.py +205 -0
  30. tests/kernels/mla_v1_test.py +143 -41
  31. tests/kernels/quantized_matmul_kernel_test.py +2 -34
  32. tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
  33. tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
  34. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +17 -1
  35. tests/kernels/ragged_paged_attention_kernel_v3_test.py +17 -1
  36. tests/layers/__init__.py +13 -0
  37. tests/layers/common/__init__.py +13 -0
  38. tests/layers/common/test_attention_interface.py +156 -0
  39. tests/layers/common/test_quantization.py +149 -0
  40. tests/layers/jax/__init__.py +13 -0
  41. tests/layers/jax/attention/__init__.py +13 -0
  42. tests/layers/jax/attention/test_common_attention.py +103 -0
  43. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  44. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  45. tests/layers/jax/moe/__init__.py +13 -0
  46. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  47. tests/layers/jax/sample/__init__.py +13 -0
  48. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  49. tests/layers/jax/sample/test_sampling.py +115 -0
  50. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  51. tests/layers/jax/test_layers.py +155 -0
  52. tests/{test_quantization.py → layers/jax/test_qwix.py} +183 -50
  53. tests/layers/jax/test_rope.py +93 -0
  54. tests/layers/jax/test_sharding.py +159 -0
  55. tests/layers/jax/test_transformer_block.py +152 -0
  56. tests/layers/vllm/__init__.py +13 -0
  57. tests/layers/vllm/test_attention.py +363 -0
  58. tests/layers/vllm/test_awq.py +406 -0
  59. tests/layers/vllm/test_compressed_tensors_moe.py +199 -0
  60. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +441 -0
  61. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +443 -0
  62. tests/layers/vllm/test_fp8.py +17 -0
  63. tests/layers/vllm/test_mxfp4.py +320 -0
  64. tests/layers/vllm/test_unquantized.py +662 -0
  65. tests/layers/vllm/utils.py +87 -0
  66. tests/lora/__init__.py +13 -0
  67. tests/lora/conftest.py +14 -0
  68. tests/lora/test_bgmv.py +14 -0
  69. tests/lora/test_layers.py +26 -6
  70. tests/lora/test_lora.py +15 -1
  71. tests/lora/test_lora_perf.py +67 -0
  72. tests/models/__init__.py +13 -0
  73. tests/models/common/__init__.py +13 -0
  74. tests/models/common/test_model_loader.py +455 -0
  75. tests/models/jax/__init__.py +13 -0
  76. tests/models/jax/test_deepseek_v3.py +401 -0
  77. tests/models/jax/test_llama3.py +184 -0
  78. tests/models/jax/test_llama4.py +298 -0
  79. tests/models/jax/test_llama_eagle3.py +197 -0
  80. tests/models/jax/test_llama_guard_4.py +242 -0
  81. tests/models/jax/test_qwen2.py +172 -0
  82. tests/models/jax/test_qwen2_5_vl.py +605 -0
  83. tests/models/jax/test_qwen3.py +169 -0
  84. tests/models/jax/test_weight_loading.py +180 -0
  85. tests/models/jax/utils/__init__.py +13 -0
  86. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  87. tests/platforms/__init__.py +13 -0
  88. tests/platforms/test_tpu_platform.py +54 -0
  89. tests/runner/__init__.py +13 -0
  90. tests/runner/test_block_table.py +395 -0
  91. tests/runner/test_input_batch.py +226 -0
  92. tests/runner/test_kv_cache.py +220 -0
  93. tests/runner/test_kv_cache_manager.py +498 -0
  94. tests/runner/test_multimodal_manager.py +429 -0
  95. tests/runner/test_persistent_batch_manager.py +84 -0
  96. tests/runner/test_speculative_decoding_manager.py +368 -0
  97. tests/runner/test_structured_decoding_manager.py +220 -0
  98. tests/runner/test_tpu_runner.py +261 -0
  99. tests/runner/test_tpu_runner_dp.py +1099 -0
  100. tests/runner/test_tpu_runner_mesh.py +200 -0
  101. tests/runner/test_utils.py +411 -0
  102. tests/spec_decode/__init__.py +13 -0
  103. tests/spec_decode/test_eagle3.py +311 -0
  104. tests/test_base.py +14 -0
  105. tests/test_envs.py +110 -12
  106. tests/test_tpu_info.py +14 -0
  107. tests/test_utils.py +2 -45
  108. tests/worker/__init__.py +13 -0
  109. tests/worker/tpu_worker_test.py +414 -0
  110. tpu_inference/__init__.py +14 -0
  111. tpu_inference/core/__init__.py +13 -0
  112. tpu_inference/core/sched/__init__.py +13 -0
  113. tpu_inference/core/sched/dp_scheduler.py +372 -56
  114. tpu_inference/distributed/__init__.py +13 -0
  115. tpu_inference/distributed/jax_parallel_state.py +14 -0
  116. tpu_inference/distributed/tpu_connector.py +15 -10
  117. tpu_inference/distributed/utils.py +56 -4
  118. tpu_inference/envs.py +92 -8
  119. tpu_inference/executors/__init__.py +13 -0
  120. tpu_inference/executors/ray_distributed_executor.py +25 -4
  121. tpu_inference/experimental/__init__.py +13 -0
  122. tpu_inference/experimental/llama3_jax_stashed.py +14 -0
  123. tpu_inference/kernels/__init__.py +13 -0
  124. tpu_inference/kernels/collectives/__init__.py +13 -0
  125. tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
  126. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
  127. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  128. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  129. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  130. tpu_inference/kernels/fused_moe/v1/kernel.py +807 -230
  131. tpu_inference/kernels/megablox/__init__.py +13 -0
  132. tpu_inference/kernels/megablox/common.py +54 -0
  133. tpu_inference/kernels/megablox/gmm.py +646 -0
  134. tpu_inference/kernels/mla/__init__.py +13 -0
  135. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  136. tpu_inference/kernels/mla/v1/kernel.py +117 -145
  137. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  138. tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
  139. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  140. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  141. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
  142. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
  143. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  144. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +194 -101
  145. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +218 -137
  146. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
  147. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +376 -195
  148. tpu_inference/kernels/ragged_paged_attention/v3/util.py +15 -1
  149. tpu_inference/layers/__init__.py +13 -0
  150. tpu_inference/layers/common/__init__.py +13 -0
  151. tpu_inference/layers/common/attention_interface.py +25 -12
  152. tpu_inference/layers/common/attention_metadata.py +14 -0
  153. tpu_inference/layers/common/fused_moe_gmm.py +506 -0
  154. tpu_inference/layers/common/quant_methods.py +15 -0
  155. tpu_inference/layers/common/quantization.py +282 -0
  156. tpu_inference/layers/common/sharding.py +32 -9
  157. tpu_inference/layers/common/utils.py +94 -0
  158. tpu_inference/layers/jax/__init__.py +13 -0
  159. tpu_inference/layers/jax/attention/__init__.py +13 -0
  160. tpu_inference/layers/jax/attention/attention.py +19 -6
  161. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +270 -77
  162. tpu_inference/layers/jax/attention/gpt_oss_attention.py +24 -11
  163. tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
  164. tpu_inference/layers/jax/base.py +14 -0
  165. tpu_inference/layers/jax/constants.py +13 -0
  166. tpu_inference/layers/jax/layers.py +14 -0
  167. tpu_inference/layers/jax/misc.py +14 -0
  168. tpu_inference/layers/jax/moe/__init__.py +13 -0
  169. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
  170. tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
  171. tpu_inference/layers/jax/moe/moe.py +43 -3
  172. tpu_inference/layers/jax/pp_utils.py +53 -0
  173. tpu_inference/layers/jax/rope.py +14 -0
  174. tpu_inference/layers/jax/rope_interface.py +14 -0
  175. tpu_inference/layers/jax/sample/__init__.py +13 -0
  176. tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
  177. tpu_inference/layers/jax/sample/sampling.py +15 -1
  178. tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
  179. tpu_inference/layers/jax/transformer_block.py +14 -0
  180. tpu_inference/layers/vllm/__init__.py +13 -0
  181. tpu_inference/layers/vllm/attention.py +4 -4
  182. tpu_inference/layers/vllm/fused_moe.py +101 -494
  183. tpu_inference/layers/vllm/linear.py +64 -0
  184. tpu_inference/layers/vllm/process_weights/__init__.py +13 -0
  185. tpu_inference/layers/vllm/{sharding.py → process_weights/cleanup_sharding.py} +24 -15
  186. tpu_inference/layers/vllm/process_weights/fused_moe_weights.py +369 -0
  187. tpu_inference/layers/vllm/process_weights/linear_weights.py +174 -0
  188. tpu_inference/layers/vllm/quantization/__init__.py +19 -3
  189. tpu_inference/layers/vllm/quantization/awq.py +96 -82
  190. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  191. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +23 -8
  192. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +172 -176
  193. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  194. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +111 -91
  195. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +79 -43
  196. tpu_inference/layers/vllm/quantization/{common.py → configs.py} +42 -25
  197. tpu_inference/layers/vllm/quantization/fp8.py +119 -0
  198. tpu_inference/layers/vllm/quantization/mxfp4.py +137 -178
  199. tpu_inference/layers/vllm/quantization/unquantized.py +157 -233
  200. tpu_inference/lora/__init__.py +13 -0
  201. tpu_inference/lora/torch_lora_ops.py +8 -13
  202. tpu_inference/models/__init__.py +13 -0
  203. tpu_inference/models/common/__init__.py +13 -0
  204. tpu_inference/models/common/model_loader.py +112 -35
  205. tpu_inference/models/jax/__init__.py +13 -0
  206. tpu_inference/models/jax/deepseek_v3.py +267 -157
  207. tpu_inference/models/jax/gpt_oss.py +26 -10
  208. tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
  209. tpu_inference/models/jax/llama3.py +99 -36
  210. tpu_inference/models/jax/llama4.py +14 -0
  211. tpu_inference/models/jax/llama_eagle3.py +18 -5
  212. tpu_inference/models/jax/llama_guard_4.py +15 -1
  213. tpu_inference/models/jax/qwen2.py +17 -2
  214. tpu_inference/models/jax/qwen2_5_vl.py +179 -51
  215. tpu_inference/models/jax/qwen3.py +17 -2
  216. tpu_inference/models/jax/utils/__init__.py +13 -0
  217. tpu_inference/models/jax/utils/file_utils.py +14 -0
  218. tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
  219. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  220. tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +92 -32
  221. tpu_inference/models/jax/utils/weight_utils.py +234 -155
  222. tpu_inference/models/vllm/__init__.py +13 -0
  223. tpu_inference/models/vllm/vllm_model_wrapper.py +32 -8
  224. tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
  225. tpu_inference/platforms/__init__.py +14 -0
  226. tpu_inference/platforms/tpu_platform.py +51 -72
  227. tpu_inference/runner/__init__.py +13 -0
  228. tpu_inference/runner/compilation_manager.py +180 -80
  229. tpu_inference/runner/kv_cache.py +54 -20
  230. tpu_inference/runner/kv_cache_manager.py +55 -33
  231. tpu_inference/runner/lora_utils.py +16 -1
  232. tpu_inference/runner/multimodal_manager.py +16 -2
  233. tpu_inference/runner/persistent_batch_manager.py +54 -2
  234. tpu_inference/runner/speculative_decoding_manager.py +14 -0
  235. tpu_inference/runner/structured_decoding_manager.py +16 -3
  236. tpu_inference/runner/tpu_runner.py +124 -61
  237. tpu_inference/runner/utils.py +2 -2
  238. tpu_inference/spec_decode/__init__.py +13 -0
  239. tpu_inference/spec_decode/jax/__init__.py +13 -0
  240. tpu_inference/spec_decode/jax/eagle3.py +84 -22
  241. tpu_inference/tpu_info.py +14 -0
  242. tpu_inference/utils.py +72 -44
  243. tpu_inference/worker/__init__.py +13 -0
  244. tpu_inference/worker/tpu_worker.py +66 -52
  245. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/METADATA +8 -9
  246. tpu_inference-0.13.2.dev20251230.dist-info/RECORD +266 -0
  247. tpu_inference/layers/vllm/linear_common.py +0 -186
  248. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  249. tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
  250. tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
  251. tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
  252. tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
  253. tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
  254. tpu_inference-0.11.1.dev202511220812.dist-info/RECORD +0 -174
  255. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/WHEEL +0 -0
  256. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/licenses/LICENSE +0 -0
  257. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/top_level.txt +0 -0
@@ -1,5 +1,19 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
1
15
  import functools
2
- from typing import TYPE_CHECKING, Dict, List
16
+ from typing import TYPE_CHECKING, List
3
17
 
4
18
  import jax
5
19
  import jax.numpy as jnp
@@ -7,8 +21,8 @@ import numpy as np
7
21
  import vllm.envs as envs
8
22
  from jax.sharding import NamedSharding, PartitionSpec
9
23
  from torchax.ops.mappings import t2j_dtype
10
- from vllm.attention import Attention
11
24
  from vllm.attention.backends.abstract import AttentionType
25
+ from vllm.attention.layer import Attention
12
26
  from vllm.config import get_layers_from_vllm_config
13
27
  from vllm.utils.math_utils import cdiv
14
28
  from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
@@ -39,20 +53,30 @@ class KVCacheManager:
39
53
  # means this layer will perform attention using the keys and values
40
54
  # from the KV cache of `shared_kv_cache_layers[layer_name]`.
41
55
  self.shared_kv_cache_layers: dict[str, str] = {}
56
+ self.use_mla = self.runner.model_config.use_mla
42
57
 
43
58
  def get_kv_cache_spec(self):
44
59
  # TODO(xiang): this hack tricks engine core to init successfully
45
60
  block_size = self.runner.cache_config.block_size
46
- use_mla = self.runner.model_config.use_mla
47
61
  kv_cache_spec: dict[str, KVCacheSpec] = {}
48
62
 
49
63
  # If use pure jax (MODEL_IMPL_TYPE=flax_nnx), we don't register
50
64
  # attention into compilation config.
51
65
  # Use FullAttentionSpec for each layer
52
66
  # TODO(pooyam): Is it possible to merge the logic for vllm and non-vllm models?
67
+ model_config = self.runner.model_config
68
+ if self.use_mla:
69
+ # Individually pad the RopE and latents
70
+ qk_rope_head_dim = getattr(model_config.hf_text_config,
71
+ "qk_rope_head_dim", 0)
72
+ padded_kv_lora_rank = common_utils.align_to(
73
+ model_config.hf_text_config.kv_lora_rank, 128)
74
+ padded_qk_rope_head_dim = common_utils.align_to(
75
+ qk_rope_head_dim, 128)
76
+ mla_head_size = padded_kv_lora_rank + padded_qk_rope_head_dim
77
+
53
78
  if len(self.runner.vllm_config.compilation_config.
54
79
  static_forward_context) == 0:
55
- model_config = self.runner.model_config
56
80
  parallel_config = self.runner.parallel_config
57
81
  # Pad num_kv_heads to multiple of TP size.
58
82
  num_kv_heads = common_utils.get_padded_num_heads(
@@ -61,11 +85,11 @@ class KVCacheManager:
61
85
  head_size = common_utils.get_padded_head_dim(
62
86
  model_config.get_head_size())
63
87
  for i in range(model_config.get_num_layers(parallel_config)):
64
- if use_mla:
88
+ if self.use_mla:
65
89
  kv_cache_spec[f"layer.{i}"] = MLAAttentionSpec(
66
90
  block_size=block_size,
67
- num_kv_heads=num_kv_heads,
68
- head_size=head_size,
91
+ num_kv_heads=1,
92
+ head_size=mla_head_size,
69
93
  dtype=self.runner.kv_cache_dtype,
70
94
  cache_dtype_str=self.runner.vllm_config.cache_config.
71
95
  cache_dtype)
@@ -83,14 +107,13 @@ class KVCacheManager:
83
107
  self.runner.mesh.shape["model"])
84
108
  head_size = common_utils.get_padded_head_dim(
85
109
  hf_config.hidden_size // hf_config.num_attention_heads)
86
-
87
110
  # Eagle3 has only 1 layer
88
111
  for i in range(1):
89
- if use_mla:
90
- kv_cache_spec[f"layer.{i}"] = MLAAttentionSpec(
112
+ if self.use_mla:
113
+ kv_cache_spec[f"draft_layer.{i}"] = MLAAttentionSpec(
91
114
  block_size=block_size,
92
- num_kv_heads=num_kv_heads,
93
- head_size=head_size,
115
+ num_kv_heads=1,
116
+ head_size=mla_head_size,
94
117
  dtype=self.runner.kv_cache_dtype,
95
118
  cache_dtype_str=self.runner.vllm_config.
96
119
  cache_config.cache_dtype)
@@ -104,6 +127,7 @@ class KVCacheManager:
104
127
  # Else propagate attention modules from compilation config.
105
128
  layers = get_layers_from_vllm_config(self.runner.vllm_config,
106
129
  Attention)
130
+ logger.warning(f"Compilation num_layers = {len(layers.items())}")
107
131
  for layer_name, attn_module in layers.items():
108
132
  if (kv_tgt_layer :=
109
133
  attn_module.kv_sharing_target_layer_name) is not None:
@@ -127,11 +151,11 @@ class KVCacheManager:
127
151
  attn_module.head_size),
128
152
  dtype=self.runner.kv_cache_dtype,
129
153
  sliding_window=attn_module.sliding_window)
130
- elif use_mla:
131
- kv_cache_spec[f"layer.{i}"] = MLAAttentionSpec(
154
+ elif self.use_mla:
155
+ kv_cache_spec[layer_name] = MLAAttentionSpec(
132
156
  block_size=block_size,
133
- num_kv_heads=attn_module.num_kv_heads,
134
- head_size=attn_module.head_size,
157
+ num_kv_heads=1,
158
+ head_size=mla_head_size,
135
159
  dtype=self.runner.kv_cache_dtype,
136
160
  cache_dtype_str=self.runner.vllm_config.
137
161
  cache_config.cache_dtype)
@@ -188,7 +212,6 @@ class KVCacheManager:
188
212
  # uniform page size.
189
213
  representative_spec = kv_cache_config.kv_cache_groups[0].kv_cache_spec
190
214
  page_size_bytes = representative_spec.page_size_bytes
191
- self.runner.layer_name_to_kvcache_index: Dict[str, int] = {}
192
215
  kv_caches = self.runner.kv_caches
193
216
  num_blocks_list = []
194
217
  for i, kv_cache_tensor in enumerate(kv_cache_config.kv_cache_tensors):
@@ -198,14 +221,20 @@ class KVCacheManager:
198
221
  # num_blocks must be a multiple of dp_size
199
222
  num_blocks = (num_blocks // dp_size) * dp_size
200
223
  # NOTE: we'll multiply the num_kv_heads by 2 in the function
224
+ if self.use_mla:
225
+ head_size = self.runner.model_config.hf_config.kv_lora_rank + \
226
+ self.runner.model_config.hf_config.qk_rope_head_dim
227
+ else:
228
+ head_size = representative_spec.head_size
201
229
  kv_cache = create_kv_caches(
202
230
  num_blocks=num_blocks,
203
231
  block_size=representative_spec.block_size,
204
232
  num_kv_heads=representative_spec.num_kv_heads,
205
- head_size=representative_spec.head_size,
233
+ head_size=head_size,
206
234
  mesh=self.runner.mesh,
207
235
  layer_names=[f'kv_cache_tensor.{i}'],
208
236
  cache_dtype=t2j_dtype(representative_spec.dtype),
237
+ use_mla=self.use_mla,
209
238
  )[0]
210
239
  kv_caches.append(kv_cache)
211
240
  num_blocks_list.append(num_blocks)
@@ -289,13 +318,8 @@ class KVCacheManager:
289
318
 
290
319
  def _update_layer(cache, slices):
291
320
  """The function to apply to each layer's cache and slices."""
292
- reshaped_slices = slices.reshape(-1, 1, block_size,
293
- *slices.shape[1:])
294
- for (i, block_idx) in enumerate(block_numbers):
295
- cache = jax.lax.dynamic_update_slice_in_dim(cache,
296
- reshaped_slices[i],
297
- block_idx,
298
- axis=0)
321
+ reshaped_slices = slices.reshape(-1, block_size, *slices.shape[1:])
322
+ cache.at[block_numbers].set(reshaped_slices)
299
323
  return cache
300
324
 
301
325
  return jax.tree.map(_update_layer, kv_caches, kv_cache_slices)
@@ -348,16 +372,12 @@ class KVCacheManager:
348
372
  """
349
373
  if block_ids == list(range(block_ids[0],
350
374
  block_ids[0] + len(block_ids))):
351
- with runner_utils.LatencyTracker(
352
- "BatchedGatherKVSlices-for-blocks"):
353
- batched_kv_cache_per_layer = self._jitted_gather_continuous_kv_cache(
354
- self.runner.kv_caches, block_ids[0], len(block_ids))
375
+ batched_kv_cache_per_layer = self._jitted_gather_continuous_kv_cache(
376
+ self.runner.kv_caches, block_ids[0], len(block_ids))
355
377
 
356
378
  else:
357
- with runner_utils.LatencyTracker(
358
- "BatchedGatherKVSlices-for-blocks"):
359
- batched_kv_cache_per_layer = self._jitted_gather_kv_cache(
360
- self.runner.kv_caches, jnp.array(block_ids))
379
+ batched_kv_cache_per_layer = self._jitted_gather_kv_cache(
380
+ self.runner.kv_caches, jnp.array(block_ids))
361
381
  return batched_kv_cache_per_layer
362
382
 
363
383
  def transfer_kv_cache(self,
@@ -446,6 +466,7 @@ class KVCacheManager:
446
466
  kv_cache_slices,
447
467
  start_block,
448
468
  )
469
+ jax.block_until_ready(self.runner.kv_caches)
449
470
  else:
450
471
  with runner_utils.LatencyTracker(
451
472
  f"JittedInsertKVCache-b{len(block_numbers)}"):
@@ -457,6 +478,7 @@ class KVCacheManager:
457
478
  kv_cache_slices,
458
479
  jnp.array(block_numbers),
459
480
  )
481
+ jax.block_until_ready(self.runner.kv_caches)
460
482
 
461
483
  logger.debug(
462
484
  f"Updated kv cache entries cnt={len(self.runner.kv_caches)}")
@@ -1,3 +1,17 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
1
15
  from __future__ import annotations
2
16
 
3
17
  from typing import TYPE_CHECKING
@@ -7,7 +21,8 @@ from torchax.interop import jax_view
7
21
  from vllm.lora.layers.base_linear import BaseLinearLayerWithLoRA
8
22
  from vllm.lora.request import LoRARequest
9
23
 
10
- from tpu_inference.layers.vllm.sharding import update_lora
24
+ from tpu_inference.layers.vllm.process_weights.cleanup_sharding import \
25
+ update_lora
11
26
 
12
27
  if TYPE_CHECKING:
13
28
  from tpu_inference.runner.tpu_runner import TPUModelRunner
@@ -1,3 +1,17 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
1
15
  from typing import TYPE_CHECKING
2
16
 
3
17
  import jax
@@ -98,7 +112,7 @@ class MultiModalManager:
98
112
  # encoder outputs.
99
113
  encoder_outputs = []
100
114
  for _, num_items, mm_kwargs_group in group_mm_kwargs_by_modality(
101
- mm_kwargs, merge_by_field_config=False):
115
+ mm_kwargs):
102
116
  batched_mm_inputs = mm_kwargs_group
103
117
  # Convert torch tensors to numpy arrays that JAX can handle.
104
118
  if "pixel_values" in batched_mm_inputs and isinstance(
@@ -134,7 +148,7 @@ class MultiModalManager:
134
148
  # 2. A list or tuple (length: num_items) of tensors, each of shape
135
149
  # (feature_size, hidden_size) in case the feature size is dynamic
136
150
  # depending on the input multimodal items.
137
- curr_group_outputs = self.runner.get_multimodal_embeddings_fn(
151
+ curr_group_outputs = self.runner.embed_multimodal_fn(
138
152
  self.runner.state, image_grid_thw, **batched_mm_inputs)
139
153
 
140
154
  sanity_check_mm_encoder_outputs(
@@ -1,3 +1,17 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
1
15
  from typing import Dict
2
16
 
3
17
  import jax
@@ -14,12 +28,13 @@ class PersistentBatchManager:
14
28
  def __init__(self, requests: Dict[str, CachedRequestState],
15
29
  input_batch: InputBatch, encoder_cache: Dict[str,
16
30
  'jax.Array'],
17
- uses_mrope: bool, model_config):
31
+ uses_mrope: bool, model_config, is_last_rank: bool):
18
32
  self.requests = requests
19
33
  self.input_batch = input_batch
20
34
  self.encoder_cache = encoder_cache
21
35
  self.uses_mrope = uses_mrope
22
36
  self.model_config = model_config
37
+ self.is_last_rank = is_last_rank
23
38
 
24
39
  def _reorder_batch(self, scheduler_output: "VllmSchedulerOutput") -> int:
25
40
  """ Reorder the sheduled requests to RPA kernel friendly distribution
@@ -179,9 +194,35 @@ class PersistentBatchManager:
179
194
  num_computed_tokens = req_data.num_computed_tokens[i]
180
195
  new_block_ids = req_data.new_block_ids[i]
181
196
  resumed_from_preemption = req_data.resumed_from_preemption[i]
197
+ num_output_tokens = req_data.num_output_tokens[i]
182
198
 
183
199
  # Update the cached states.
184
200
  req_state.num_computed_tokens = num_computed_tokens
201
+ req_index = self.input_batch.req_id_to_index.get(req_id)
202
+
203
+ if not self.is_last_rank:
204
+ # When using PP, the scheduler sends the sampled tokens back,
205
+ # because there's no direct communication between the first-
206
+ # stage worker and the last-stage worker.
207
+ new_token_ids = req_data.new_token_ids[i]
208
+ # Add the sampled token(s) from the previous step (if any).
209
+ # This doesn't include "unverified" tokens like spec tokens.
210
+ num_new_tokens = (num_computed_tokens + len(new_token_ids) -
211
+ req_state.num_tokens)
212
+ if num_new_tokens == 1:
213
+ req_state.output_token_ids.append(new_token_ids[-1])
214
+ elif num_new_tokens > 0:
215
+ req_state.output_token_ids.extend(
216
+ new_token_ids[-num_new_tokens:])
217
+ elif num_output_tokens < len(req_state.output_token_ids):
218
+ del req_state.output_token_ids[num_output_tokens:]
219
+ if req_index is not None:
220
+ end_idx = (self.input_batch.num_prompt_tokens[req_index] +
221
+ num_output_tokens)
222
+ self.input_batch.num_tokens[req_index] = end_idx
223
+ self.input_batch.num_tokens_no_spec[req_index] = end_idx
224
+
225
+ # Update the block IDs.
185
226
  if not resumed_from_preemption:
186
227
  if new_block_ids is not None:
187
228
  # Append the new blocks to the existing block IDs.
@@ -194,7 +235,6 @@ class PersistentBatchManager:
194
235
  # Replace the existing block IDs with the new ones.
195
236
  req_state.block_ids = new_block_ids
196
237
 
197
- req_index = self.input_batch.req_id_to_index.get(req_id)
198
238
  if req_index is None:
199
239
  # The request is not in the persistent batch.
200
240
  # The request was either preempted and resumed later, or was not
@@ -209,6 +249,18 @@ class PersistentBatchManager:
209
249
  self.input_batch.block_table.append_row(
210
250
  new_block_ids, req_index)
211
251
 
252
+ # For the last rank, we don't need to update the token_ids_cpu
253
+ # because the sampled tokens are already cached.
254
+ if not self.is_last_rank:
255
+ start_token_index = num_computed_tokens
256
+ end_token_index = num_computed_tokens + len(new_token_ids)
257
+ self.input_batch.token_ids_cpu[
258
+ req_index,
259
+ start_token_index:end_token_index] = new_token_ids
260
+ self.input_batch.num_tokens_no_spec[
261
+ req_index] = end_token_index
262
+ self.input_batch.num_tokens[req_index] = end_token_index
263
+
212
264
  # Add spec_token_ids to token_ids_cpu.
213
265
  spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get(
214
266
  req_id, ())
@@ -1,3 +1,17 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
1
15
  from __future__ import annotations
2
16
 
3
17
  from dataclasses import dataclass
@@ -1,3 +1,17 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
1
15
  import functools
2
16
  from typing import TYPE_CHECKING, Tuple
3
17
 
@@ -61,11 +75,10 @@ class StructuredDecodingManager:
61
75
  self.runner.require_structured_out_cpu.fill(0)
62
76
 
63
77
  sorted_struct_requests = sorted(
64
- grammar_output.structured_output_request_ids.items(),
65
- key=lambda item: item[1])
78
+ grammar_output.structured_output_request_ids)
66
79
 
67
80
  cumulative_mask_idx = 0
68
- for req_id, _ in sorted_struct_requests:
81
+ for req_id in sorted_struct_requests:
69
82
  if req_id not in self.runner.input_batch.req_id_to_index:
70
83
  continue
71
84
  batch_index = self.runner.input_batch.req_id_to_index[req_id]