tpu-inference 0.11.1.dev202511270815__py3-none-any.whl → 0.13.0rc2.post7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

Files changed (251)
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_disagg_utils.py +14 -0
  4. tests/core/test_dp_scheduler.py +650 -768
  5. tests/core/test_init.py +14 -0
  6. tests/distributed/__init__.py +13 -0
  7. tests/distributed/test_distributed_utils.py +120 -0
  8. tests/distributed/test_tpu_connector.py +478 -0
  9. tests/e2e/__init__.py +13 -0
  10. tests/e2e/test_async_scheduler.py +211 -0
  11. tests/e2e/test_data_parallel.py +289 -0
  12. tests/e2e/test_hybrid_kvcache.py +219 -0
  13. tests/e2e/test_local_disagg.py +257 -0
  14. tests/e2e/test_model_loader.py +268 -0
  15. tests/e2e/test_multi_modal_inference.py +111 -0
  16. tests/e2e/test_pipeline_parallel.py +265 -0
  17. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  18. tests/e2e/test_sampling_params.py +269 -0
  19. tests/e2e/test_speculative_decoding.py +311 -0
  20. tests/e2e/test_structured_decoding.py +46 -0
  21. tests/executors/__init__.py +13 -0
  22. tests/executors/test_ray_distributed_executor.py +199 -0
  23. tests/experimental/__init__.py +13 -0
  24. tests/experimental/test_llama3_jax_stashed.py +208 -0
  25. tests/kernels/__init__.py +13 -0
  26. tests/kernels/collectives/__init__.py +13 -0
  27. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  28. tests/kernels/fused_moe_v1_test.py +14 -0
  29. tests/kernels/gmm_test.py +205 -0
  30. tests/kernels/mla_v1_test.py +143 -41
  31. tests/kernels/quantized_matmul_kernel_test.py +2 -34
  32. tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
  33. tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
  34. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +17 -1
  35. tests/kernels/ragged_paged_attention_kernel_v3_test.py +17 -1
  36. tests/layers/__init__.py +13 -0
  37. tests/layers/common/__init__.py +13 -0
  38. tests/layers/common/test_attention_interface.py +156 -0
  39. tests/layers/common/test_quantization.py +149 -0
  40. tests/layers/jax/__init__.py +13 -0
  41. tests/layers/jax/attention/__init__.py +13 -0
  42. tests/layers/jax/attention/test_common_attention.py +103 -0
  43. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  44. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  45. tests/layers/jax/moe/__init__.py +13 -0
  46. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  47. tests/layers/jax/sample/__init__.py +13 -0
  48. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  49. tests/layers/jax/sample/test_sampling.py +115 -0
  50. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  51. tests/layers/jax/test_layers.py +155 -0
  52. tests/{test_quantization.py → layers/jax/test_qwix.py} +183 -50
  53. tests/layers/jax/test_rope.py +93 -0
  54. tests/layers/jax/test_sharding.py +159 -0
  55. tests/layers/jax/test_transformer_block.py +152 -0
  56. tests/layers/vllm/__init__.py +13 -0
  57. tests/layers/vllm/test_attention.py +363 -0
  58. tests/layers/vllm/test_awq.py +405 -0
  59. tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
  60. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +418 -0
  61. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +441 -0
  62. tests/layers/vllm/test_fp8.py +17 -0
  63. tests/layers/vllm/test_mxfp4.py +312 -0
  64. tests/layers/vllm/test_unquantized.py +651 -0
  65. tests/layers/vllm/utils.py +87 -0
  66. tests/lora/__init__.py +13 -0
  67. tests/lora/conftest.py +14 -0
  68. tests/lora/test_bgmv.py +14 -0
  69. tests/lora/test_layers.py +21 -3
  70. tests/lora/test_lora.py +15 -1
  71. tests/lora/test_lora_perf.py +67 -0
  72. tests/models/__init__.py +13 -0
  73. tests/models/common/__init__.py +13 -0
  74. tests/models/common/test_model_loader.py +455 -0
  75. tests/models/jax/__init__.py +13 -0
  76. tests/models/jax/test_deepseek_v3.py +401 -0
  77. tests/models/jax/test_llama3.py +184 -0
  78. tests/models/jax/test_llama4.py +298 -0
  79. tests/models/jax/test_llama_eagle3.py +197 -0
  80. tests/models/jax/test_llama_guard_4.py +242 -0
  81. tests/models/jax/test_qwen2.py +172 -0
  82. tests/models/jax/test_qwen2_5_vl.py +605 -0
  83. tests/models/jax/test_qwen3.py +169 -0
  84. tests/models/jax/test_weight_loading.py +180 -0
  85. tests/models/jax/utils/__init__.py +13 -0
  86. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  87. tests/platforms/__init__.py +13 -0
  88. tests/platforms/test_tpu_platform.py +54 -0
  89. tests/runner/__init__.py +13 -0
  90. tests/runner/test_block_table.py +395 -0
  91. tests/runner/test_input_batch.py +226 -0
  92. tests/runner/test_kv_cache.py +220 -0
  93. tests/runner/test_kv_cache_manager.py +498 -0
  94. tests/runner/test_multimodal_manager.py +429 -0
  95. tests/runner/test_persistent_batch_manager.py +84 -0
  96. tests/runner/test_speculative_decoding_manager.py +368 -0
  97. tests/runner/test_structured_decoding_manager.py +220 -0
  98. tests/runner/test_tpu_runner.py +261 -0
  99. tests/runner/test_tpu_runner_dp.py +1099 -0
  100. tests/runner/test_tpu_runner_mesh.py +200 -0
  101. tests/runner/test_utils.py +411 -0
  102. tests/spec_decode/__init__.py +13 -0
  103. tests/spec_decode/test_eagle3.py +311 -0
  104. tests/test_base.py +14 -0
  105. tests/test_envs.py +110 -12
  106. tests/test_tpu_info.py +14 -0
  107. tests/test_utils.py +2 -45
  108. tests/worker/__init__.py +13 -0
  109. tests/worker/tpu_worker_test.py +414 -0
  110. tpu_inference/__init__.py +14 -0
  111. tpu_inference/core/__init__.py +13 -0
  112. tpu_inference/core/sched/__init__.py +13 -0
  113. tpu_inference/core/sched/dp_scheduler.py +372 -56
  114. tpu_inference/distributed/__init__.py +13 -0
  115. tpu_inference/distributed/jax_parallel_state.py +14 -0
  116. tpu_inference/distributed/tpu_connector.py +15 -10
  117. tpu_inference/distributed/utils.py +56 -4
  118. tpu_inference/envs.py +92 -8
  119. tpu_inference/executors/__init__.py +13 -0
  120. tpu_inference/executors/ray_distributed_executor.py +22 -1
  121. tpu_inference/experimental/__init__.py +13 -0
  122. tpu_inference/experimental/llama3_jax_stashed.py +14 -0
  123. tpu_inference/kernels/__init__.py +13 -0
  124. tpu_inference/kernels/collectives/__init__.py +13 -0
  125. tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
  126. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
  127. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  128. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  129. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  130. tpu_inference/kernels/fused_moe/v1/kernel.py +370 -324
  131. tpu_inference/kernels/megablox/__init__.py +13 -0
  132. tpu_inference/kernels/megablox/common.py +54 -0
  133. tpu_inference/kernels/megablox/gmm.py +646 -0
  134. tpu_inference/kernels/mla/__init__.py +13 -0
  135. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  136. tpu_inference/kernels/mla/v1/kernel.py +117 -145
  137. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  138. tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
  139. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  140. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  141. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
  142. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
  143. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  144. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +194 -101
  145. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +167 -97
  146. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
  147. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +376 -195
  148. tpu_inference/kernels/ragged_paged_attention/v3/util.py +15 -1
  149. tpu_inference/layers/__init__.py +13 -0
  150. tpu_inference/layers/common/__init__.py +13 -0
  151. tpu_inference/layers/common/attention_interface.py +26 -19
  152. tpu_inference/layers/common/attention_metadata.py +14 -0
  153. tpu_inference/layers/common/quant_methods.py +15 -0
  154. tpu_inference/layers/common/quantization.py +270 -0
  155. tpu_inference/layers/common/sharding.py +31 -9
  156. tpu_inference/layers/jax/__init__.py +13 -0
  157. tpu_inference/layers/jax/attention/__init__.py +13 -0
  158. tpu_inference/layers/jax/attention/attention.py +19 -6
  159. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +270 -77
  160. tpu_inference/layers/jax/attention/gpt_oss_attention.py +24 -11
  161. tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
  162. tpu_inference/layers/jax/base.py +14 -0
  163. tpu_inference/layers/jax/constants.py +13 -0
  164. tpu_inference/layers/jax/layers.py +14 -0
  165. tpu_inference/layers/jax/misc.py +14 -0
  166. tpu_inference/layers/jax/moe/__init__.py +13 -0
  167. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
  168. tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
  169. tpu_inference/layers/jax/moe/moe.py +43 -3
  170. tpu_inference/layers/jax/pp_utils.py +53 -0
  171. tpu_inference/layers/jax/rope.py +14 -0
  172. tpu_inference/layers/jax/rope_interface.py +14 -0
  173. tpu_inference/layers/jax/sample/__init__.py +13 -0
  174. tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
  175. tpu_inference/layers/jax/sample/sampling.py +15 -1
  176. tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
  177. tpu_inference/layers/jax/transformer_block.py +14 -0
  178. tpu_inference/layers/vllm/__init__.py +13 -0
  179. tpu_inference/layers/vllm/attention.py +4 -4
  180. tpu_inference/layers/vllm/fused_moe.py +210 -260
  181. tpu_inference/layers/vllm/linear_common.py +57 -22
  182. tpu_inference/layers/vllm/quantization/__init__.py +16 -0
  183. tpu_inference/layers/vllm/quantization/awq.py +15 -1
  184. tpu_inference/layers/vllm/quantization/common.py +33 -18
  185. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  186. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +18 -3
  187. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +211 -148
  188. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  189. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +14 -0
  190. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +14 -0
  191. tpu_inference/layers/vllm/quantization/fp8.py +118 -0
  192. tpu_inference/layers/vllm/quantization/mxfp4.py +280 -210
  193. tpu_inference/layers/vllm/quantization/unquantized.py +134 -86
  194. tpu_inference/layers/vllm/sharding.py +21 -4
  195. tpu_inference/lora/__init__.py +13 -0
  196. tpu_inference/lora/torch_lora_ops.py +8 -13
  197. tpu_inference/models/__init__.py +13 -0
  198. tpu_inference/models/common/__init__.py +13 -0
  199. tpu_inference/models/common/model_loader.py +77 -36
  200. tpu_inference/models/jax/__init__.py +13 -0
  201. tpu_inference/models/jax/deepseek_v3.py +267 -157
  202. tpu_inference/models/jax/gpt_oss.py +26 -10
  203. tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
  204. tpu_inference/models/jax/llama3.py +99 -36
  205. tpu_inference/models/jax/llama4.py +14 -0
  206. tpu_inference/models/jax/llama_eagle3.py +14 -0
  207. tpu_inference/models/jax/llama_guard_4.py +15 -1
  208. tpu_inference/models/jax/qwen2.py +17 -2
  209. tpu_inference/models/jax/qwen2_5_vl.py +18 -4
  210. tpu_inference/models/jax/qwen3.py +17 -2
  211. tpu_inference/models/jax/utils/__init__.py +13 -0
  212. tpu_inference/models/jax/utils/file_utils.py +14 -0
  213. tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
  214. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  215. tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +91 -31
  216. tpu_inference/models/jax/utils/weight_utils.py +39 -2
  217. tpu_inference/models/vllm/__init__.py +13 -0
  218. tpu_inference/models/vllm/vllm_model_wrapper.py +20 -4
  219. tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
  220. tpu_inference/platforms/__init__.py +14 -0
  221. tpu_inference/platforms/tpu_platform.py +47 -71
  222. tpu_inference/runner/__init__.py +13 -0
  223. tpu_inference/runner/compilation_manager.py +158 -63
  224. tpu_inference/runner/kv_cache.py +54 -20
  225. tpu_inference/runner/kv_cache_manager.py +53 -30
  226. tpu_inference/runner/lora_utils.py +14 -0
  227. tpu_inference/runner/multimodal_manager.py +15 -1
  228. tpu_inference/runner/persistent_batch_manager.py +54 -2
  229. tpu_inference/runner/speculative_decoding_manager.py +14 -0
  230. tpu_inference/runner/structured_decoding_manager.py +14 -0
  231. tpu_inference/runner/tpu_runner.py +105 -57
  232. tpu_inference/runner/utils.py +2 -2
  233. tpu_inference/spec_decode/__init__.py +13 -0
  234. tpu_inference/spec_decode/jax/__init__.py +13 -0
  235. tpu_inference/spec_decode/jax/eagle3.py +65 -19
  236. tpu_inference/tpu_info.py +14 -0
  237. tpu_inference/utils.py +72 -44
  238. tpu_inference/worker/__init__.py +13 -0
  239. tpu_inference/worker/tpu_worker.py +65 -52
  240. {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/METADATA +11 -9
  241. tpu_inference-0.13.0rc2.post7.dist-info/RECORD +261 -0
  242. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  243. tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
  244. tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
  245. tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
  246. tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
  247. tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
  248. tpu_inference-0.11.1.dev202511270815.dist-info/RECORD +0 -174
  249. {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/WHEEL +0 -0
  250. {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/licenses/LICENSE +0 -0
  251. {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/top_level.txt +0 -0
tpu_inference/runner/kv_cache_manager.py

@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import functools
 from typing import TYPE_CHECKING, Dict, List
 
@@ -39,20 +53,30 @@ class KVCacheManager:
         # means this layer will perform attention using the keys and values
         # from the KV cache of `shared_kv_cache_layers[layer_name]`.
         self.shared_kv_cache_layers: dict[str, str] = {}
+        self.use_mla = self.runner.model_config.use_mla
 
     def get_kv_cache_spec(self):
         # TODO(xiang): this hack tricks engine core to init successfully
         block_size = self.runner.cache_config.block_size
-        use_mla = self.runner.model_config.use_mla
         kv_cache_spec: dict[str, KVCacheSpec] = {}
 
         # If use pure jax (MODEL_IMPL_TYPE=flax_nnx), we don't register
         # attention into compilation config.
         # Use FullAttentionSpec for each layer
         # TODO(pooyam): Is it possible to merge the logic for vllm and non-vllm models?
+        model_config = self.runner.model_config
+        if self.use_mla:
+            # Individually pad the RopE and latents
+            qk_rope_head_dim = getattr(model_config.hf_text_config,
+                                       "qk_rope_head_dim", 0)
+            padded_kv_lora_rank = common_utils.align_to(
+                model_config.hf_text_config.kv_lora_rank, 128)
+            padded_qk_rope_head_dim = common_utils.align_to(
+                qk_rope_head_dim, 128)
+            mla_head_size = padded_kv_lora_rank + padded_qk_rope_head_dim
+
         if len(self.runner.vllm_config.compilation_config.
                static_forward_context) == 0:
-            model_config = self.runner.model_config
             parallel_config = self.runner.parallel_config
             # Pad num_kv_heads to multiple of TP size.
             num_kv_heads = common_utils.get_padded_num_heads(
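Editor's note on the MLA hunk above: the spec now models the compressed KV latent plus the decoupled RoPE key as a single logical head (`num_kv_heads=1`), padding each component to a 128-wide boundary before summing. A minimal sketch of the arithmetic, assuming `common_utils.align_to` rounds up to the nearest multiple (the helper itself is not shown in this diff) and using DeepSeek-V3-style sizes (kv_lora_rank=512, qk_rope_head_dim=64):

```python
def align_to(x: int, multiple: int) -> int:
    # Assumed behavior: round x up to the nearest multiple of `multiple`.
    return (x + multiple - 1) // multiple * multiple

padded_kv_lora_rank = align_to(512, 128)       # 512 (already aligned)
padded_qk_rope_head_dim = align_to(64, 128)    # 128 (padded up from 64)
mla_head_size = padded_kv_lora_rank + padded_qk_rope_head_dim  # 640
```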
@@ -61,11 +85,11 @@ class KVCacheManager:
             head_size = common_utils.get_padded_head_dim(
                 model_config.get_head_size())
             for i in range(model_config.get_num_layers(parallel_config)):
-                if use_mla:
+                if self.use_mla:
                     kv_cache_spec[f"layer.{i}"] = MLAAttentionSpec(
                         block_size=block_size,
-                        num_kv_heads=num_kv_heads,
-                        head_size=head_size,
+                        num_kv_heads=1,
+                        head_size=mla_head_size,
                         dtype=self.runner.kv_cache_dtype,
                         cache_dtype_str=self.runner.vllm_config.cache_config.
                         cache_dtype)
@@ -83,14 +107,13 @@ class KVCacheManager:
                     self.runner.mesh.shape["model"])
                 head_size = common_utils.get_padded_head_dim(
                     hf_config.hidden_size // hf_config.num_attention_heads)
-
                 # Eagle3 has only 1 layer
                 for i in range(1):
-                    if use_mla:
-                        kv_cache_spec[f"layer.{i}"] = MLAAttentionSpec(
+                    if self.use_mla:
+                        kv_cache_spec[f"draft_layer.{i}"] = MLAAttentionSpec(
                             block_size=block_size,
-                            num_kv_heads=num_kv_heads,
-                            head_size=head_size,
+                            num_kv_heads=1,
+                            head_size=mla_head_size,
                             dtype=self.runner.kv_cache_dtype,
                             cache_dtype_str=self.runner.vllm_config.
                             cache_config.cache_dtype)
@@ -104,6 +127,7 @@ class KVCacheManager:
         # Else propagate attention modules from compilation config.
         layers = get_layers_from_vllm_config(self.runner.vllm_config,
                                              Attention)
+        logger.warning(f"Compilation num_layers = {len(layers.items())}")
         for layer_name, attn_module in layers.items():
             if (kv_tgt_layer :=
                     attn_module.kv_sharing_target_layer_name) is not None:
@@ -127,11 +151,11 @@ class KVCacheManager:
                         attn_module.head_size),
                     dtype=self.runner.kv_cache_dtype,
                     sliding_window=attn_module.sliding_window)
-            elif use_mla:
-                kv_cache_spec[f"layer.{i}"] = MLAAttentionSpec(
+            elif self.use_mla:
+                kv_cache_spec[layer_name] = MLAAttentionSpec(
                     block_size=block_size,
-                    num_kv_heads=attn_module.num_kv_heads,
-                    head_size=attn_module.head_size,
+                    num_kv_heads=1,
+                    head_size=mla_head_size,
                     dtype=self.runner.kv_cache_dtype,
                     cache_dtype_str=self.runner.vllm_config.
                     cache_config.cache_dtype)
@@ -198,14 +222,20 @@ class KVCacheManager:
             # num_blocks must be a multiple of dp_size
             num_blocks = (num_blocks // dp_size) * dp_size
             # NOTE: we'll multiply the num_kv_heads by 2 in the function
+            if self.use_mla:
+                head_size = self.runner.model_config.hf_config.kv_lora_rank + \
+                    self.runner.model_config.hf_config.qk_rope_head_dim
+            else:
+                head_size = representative_spec.head_size
             kv_cache = create_kv_caches(
                 num_blocks=num_blocks,
                 block_size=representative_spec.block_size,
                 num_kv_heads=representative_spec.num_kv_heads,
-                head_size=representative_spec.head_size,
+                head_size=head_size,
                 mesh=self.runner.mesh,
                 layer_names=[f'kv_cache_tensor.{i}'],
                 cache_dtype=t2j_dtype(representative_spec.dtype),
+                use_mla=self.use_mla,
             )[0]
             kv_caches.append(kv_cache)
             num_blocks_list.append(num_blocks)
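An editorial observation, not part of the diff: this hunk sizes the allocated cache with the raw sum kv_lora_rank + qk_rope_head_dim, while get_kv_cache_spec() above reports the 128-aligned mla_head_size; whether create_kv_caches re-pads internally (it now receives use_mla=True) is not visible in this diff. With the same DeepSeek-V3-style sizes:

```python
raw_head_size = 512 + 64       # kv_lora_rank + qk_rope_head_dim = 576 (this hunk)
padded_head_size = 512 + 128   # 128-aligned value used in the spec = 640
```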
@@ -289,13 +319,8 @@ class KVCacheManager:
 
         def _update_layer(cache, slices):
             """The function to apply to each layer's cache and slices."""
-            reshaped_slices = slices.reshape(-1, 1, block_size,
-                                             *slices.shape[1:])
-            for (i, block_idx) in enumerate(block_numbers):
-                cache = jax.lax.dynamic_update_slice_in_dim(cache,
-                                                            reshaped_slices[i],
-                                                            block_idx,
-                                                            axis=0)
+            reshaped_slices = slices.reshape(-1, block_size, *slices.shape[1:])
+            cache.at[block_numbers].set(reshaped_slices)
             return cache
 
         return jax.tree.map(_update_layer, kv_caches, kv_cache_slices)
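One caution about the rewritten _update_layer: jax.Array is immutable, so .at[...].set(...) returns a new array rather than mutating in place. As rendered here the result is discarded and the original cache is returned, which would make the update a no-op; presumably the released source rebinds the result (or the diff rendering dropped the assignment). A minimal sketch of the scatter-update idiom, with illustrative shapes:

```python
import jax.numpy as jnp

cache = jnp.zeros((8, 2, 4, 16))     # (num_blocks, block_size, heads, head_dim), illustrative
block_numbers = jnp.array([2, 5])    # blocks to overwrite
slices = jnp.ones((2, 2, 4, 16))     # one (block_size, heads, head_dim) slab per block

# .at[].set() is functional: rebind the result for the write to take effect.
cache = cache.at[block_numbers].set(slices)
```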
@@ -348,16 +373,12 @@ class KVCacheManager:
         """
         if block_ids == list(range(block_ids[0],
                                    block_ids[0] + len(block_ids))):
-            with runner_utils.LatencyTracker(
-                    "BatchedGatherKVSlices-for-blocks"):
-                batched_kv_cache_per_layer = self._jitted_gather_continuous_kv_cache(
-                    self.runner.kv_caches, block_ids[0], len(block_ids))
+            batched_kv_cache_per_layer = self._jitted_gather_continuous_kv_cache(
+                self.runner.kv_caches, block_ids[0], len(block_ids))
 
         else:
-            with runner_utils.LatencyTracker(
-                    "BatchedGatherKVSlices-for-blocks"):
-                batched_kv_cache_per_layer = self._jitted_gather_kv_cache(
-                    self.runner.kv_caches, jnp.array(block_ids))
+            batched_kv_cache_per_layer = self._jitted_gather_kv_cache(
+                self.runner.kv_caches, jnp.array(block_ids))
         return batched_kv_cache_per_layer
 
     def transfer_kv_cache(self,
@@ -446,6 +467,7 @@ class KVCacheManager:
                     kv_cache_slices,
                     start_block,
                 )
+                jax.block_until_ready(self.runner.kv_caches)
         else:
             with runner_utils.LatencyTracker(
                     f"JittedInsertKVCache-b{len(block_numbers)}"):
@@ -457,6 +479,7 @@ class KVCacheManager:
                     kv_cache_slices,
                     jnp.array(block_numbers),
                 )
+                jax.block_until_ready(self.runner.kv_caches)
 
         logger.debug(
             f"Updated kv cache entries cnt={len(self.runner.kv_caches)}")
tpu_inference/runner/lora_utils.py

@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from __future__ import annotations
 
 from typing import TYPE_CHECKING
tpu_inference/runner/multimodal_manager.py

@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from typing import TYPE_CHECKING
 
 import jax
@@ -134,7 +148,7 @@ class MultiModalManager:
             # 2. A list or tuple (length: num_items) of tensors, each of shape
             # (feature_size, hidden_size) in case the feature size is dynamic
             # depending on the input multimodal items.
-            curr_group_outputs = self.runner.get_multimodal_embeddings_fn(
+            curr_group_outputs = self.runner.embed_multimodal_fn(
                 self.runner.state, image_grid_thw, **batched_mm_inputs)
 
             sanity_check_mm_encoder_outputs(
tpu_inference/runner/persistent_batch_manager.py

@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from typing import Dict
 
 import jax
@@ -14,12 +28,13 @@ class PersistentBatchManager:
     def __init__(self, requests: Dict[str, CachedRequestState],
                  input_batch: InputBatch, encoder_cache: Dict[str,
                                                               'jax.Array'],
-                 uses_mrope: bool, model_config):
+                 uses_mrope: bool, model_config, is_last_rank: bool):
         self.requests = requests
         self.input_batch = input_batch
        self.encoder_cache = encoder_cache
         self.uses_mrope = uses_mrope
         self.model_config = model_config
+        self.is_last_rank = is_last_rank
 
     def _reorder_batch(self, scheduler_output: "VllmSchedulerOutput") -> int:
         """ Reorder the sheduled requests to RPA kernel friendly distribution
@@ -179,9 +194,35 @@ class PersistentBatchManager:
             num_computed_tokens = req_data.num_computed_tokens[i]
             new_block_ids = req_data.new_block_ids[i]
             resumed_from_preemption = req_data.resumed_from_preemption[i]
+            num_output_tokens = req_data.num_output_tokens[i]
 
             # Update the cached states.
             req_state.num_computed_tokens = num_computed_tokens
+            req_index = self.input_batch.req_id_to_index.get(req_id)
+
+            if not self.is_last_rank:
+                # When using PP, the scheduler sends the sampled tokens back,
+                # because there's no direct communication between the first-
+                # stage worker and the last-stage worker.
+                new_token_ids = req_data.new_token_ids[i]
+                # Add the sampled token(s) from the previous step (if any).
+                # This doesn't include "unverified" tokens like spec tokens.
+                num_new_tokens = (num_computed_tokens + len(new_token_ids) -
+                                  req_state.num_tokens)
+                if num_new_tokens == 1:
+                    req_state.output_token_ids.append(new_token_ids[-1])
+                elif num_new_tokens > 0:
+                    req_state.output_token_ids.extend(
+                        new_token_ids[-num_new_tokens:])
+                elif num_output_tokens < len(req_state.output_token_ids):
+                    del req_state.output_token_ids[num_output_tokens:]
+                    if req_index is not None:
+                        end_idx = (self.input_batch.num_prompt_tokens[req_index] +
+                                   num_output_tokens)
+                        self.input_batch.num_tokens[req_index] = end_idx
+                        self.input_batch.num_tokens_no_spec[req_index] = end_idx
+
+            # Update the block IDs.
             if not resumed_from_preemption:
                 if new_block_ids is not None:
                     # Append the new blocks to the existing block IDs.
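To trace the pipeline-parallel bookkeeping above with illustrative numbers (not taken from the source): suppose a non-last rank has 12 tokens cached for a request when the scheduler reports 11 computed tokens and returns 2 sampled ids.

```python
req_state_num_tokens = 12        # tokens already cached on this rank (req_state.num_tokens)
num_computed_tokens = 11
new_token_ids = [451, 982]       # sampled ids sent back by the scheduler

num_new_tokens = num_computed_tokens + len(new_token_ids) - req_state_num_tokens
# 11 + 2 - 12 = 1: only the last id is genuinely new, so append new_token_ids[-1].
```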
@@ -194,7 +235,6 @@ class PersistentBatchManager:
                 # Replace the existing block IDs with the new ones.
                 req_state.block_ids = new_block_ids
 
-            req_index = self.input_batch.req_id_to_index.get(req_id)
             if req_index is None:
                 # The request is not in the persistent batch.
                 # The request was either preempted and resumed later, or was not
@@ -209,6 +249,18 @@ class PersistentBatchManager:
                 self.input_batch.block_table.append_row(
                     new_block_ids, req_index)
 
+            # For the last rank, we don't need to update the token_ids_cpu
+            # because the sampled tokens are already cached.
+            if not self.is_last_rank:
+                start_token_index = num_computed_tokens
+                end_token_index = num_computed_tokens + len(new_token_ids)
+                self.input_batch.token_ids_cpu[
+                    req_index,
+                    start_token_index:end_token_index] = new_token_ids
+                self.input_batch.num_tokens_no_spec[
+                    req_index] = end_token_index
+                self.input_batch.num_tokens[req_index] = end_token_index
+
             # Add spec_token_ids to token_ids_cpu.
             spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get(
                 req_id, ())
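The write-back into the host-side batch buffers amounts to a slice assignment; a sketch with assumed shapes (token_ids_cpu as a (max_num_reqs, max_model_len) NumPy buffer; its actual dtype and shape are not shown in this diff):

```python
import numpy as np

token_ids_cpu = np.zeros((4, 16), dtype=np.int32)   # assumed host buffer
req_index, num_computed_tokens = 1, 5
new_token_ids = [7, 8, 9]

start = num_computed_tokens
end = num_computed_tokens + len(new_token_ids)
token_ids_cpu[req_index, start:end] = new_token_ids  # non-last PP ranks only
```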
tpu_inference/runner/speculative_decoding_manager.py

@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from __future__ import annotations
 
 from dataclasses import dataclass
tpu_inference/runner/structured_decoding_manager.py

@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import functools
 from typing import TYPE_CHECKING, Tuple
17