tpu-inference 0.11.1.dev202512030818__py3-none-any.whl → 0.13.2rc3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of tpu-inference might be problematic.

Files changed (250)
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_disagg_utils.py +14 -0
  4. tests/core/test_dp_scheduler.py +650 -768
  5. tests/core/test_init.py +14 -0
  6. tests/distributed/__init__.py +13 -0
  7. tests/distributed/test_distributed_utils.py +120 -0
  8. tests/distributed/test_tpu_connector.py +478 -0
  9. tests/e2e/__init__.py +13 -0
  10. tests/e2e/test_async_scheduler.py +211 -0
  11. tests/e2e/test_data_parallel.py +289 -0
  12. tests/e2e/test_hybrid_kvcache.py +219 -0
  13. tests/e2e/test_local_disagg.py +257 -0
  14. tests/e2e/test_model_loader.py +268 -0
  15. tests/e2e/test_multi_modal_inference.py +111 -0
  16. tests/e2e/test_pipeline_parallel.py +265 -0
  17. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  18. tests/e2e/test_sampling_params.py +269 -0
  19. tests/e2e/test_speculative_decoding.py +311 -0
  20. tests/e2e/test_structured_decoding.py +46 -0
  21. tests/executors/__init__.py +13 -0
  22. tests/executors/test_ray_distributed_executor.py +199 -0
  23. tests/experimental/__init__.py +13 -0
  24. tests/experimental/test_llama3_jax_stashed.py +208 -0
  25. tests/kernels/__init__.py +13 -0
  26. tests/kernels/collectives/__init__.py +13 -0
  27. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  28. tests/kernels/fused_moe_v1_test.py +14 -0
  29. tests/kernels/gmm_test.py +205 -0
  30. tests/kernels/mla_v1_test.py +143 -41
  31. tests/kernels/quantized_matmul_kernel_test.py +2 -34
  32. tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
  33. tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
  34. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +17 -1
  35. tests/kernels/ragged_paged_attention_kernel_v3_test.py +17 -1
  36. tests/layers/__init__.py +13 -0
  37. tests/layers/common/__init__.py +13 -0
  38. tests/layers/common/test_attention_interface.py +156 -0
  39. tests/layers/common/test_quantization.py +149 -0
  40. tests/layers/jax/__init__.py +13 -0
  41. tests/layers/jax/attention/__init__.py +13 -0
  42. tests/layers/jax/attention/test_common_attention.py +103 -0
  43. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  44. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  45. tests/layers/jax/moe/__init__.py +13 -0
  46. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  47. tests/layers/jax/sample/__init__.py +13 -0
  48. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  49. tests/layers/jax/sample/test_sampling.py +115 -0
  50. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  51. tests/layers/jax/test_layers.py +155 -0
  52. tests/{test_quantization.py → layers/jax/test_qwix.py} +183 -50
  53. tests/layers/jax/test_rope.py +93 -0
  54. tests/layers/jax/test_sharding.py +159 -0
  55. tests/layers/jax/test_transformer_block.py +152 -0
  56. tests/layers/vllm/__init__.py +13 -0
  57. tests/layers/vllm/test_attention.py +363 -0
  58. tests/layers/vllm/test_awq.py +405 -0
  59. tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
  60. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +418 -0
  61. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +441 -0
  62. tests/layers/vllm/test_fp8.py +17 -0
  63. tests/layers/vllm/test_mxfp4.py +312 -0
  64. tests/layers/vllm/test_unquantized.py +651 -0
  65. tests/layers/vllm/utils.py +87 -0
  66. tests/lora/__init__.py +13 -0
  67. tests/lora/conftest.py +14 -0
  68. tests/lora/test_bgmv.py +14 -0
  69. tests/lora/test_layers.py +21 -3
  70. tests/lora/test_lora.py +15 -1
  71. tests/lora/test_lora_perf.py +67 -0
  72. tests/models/__init__.py +13 -0
  73. tests/models/common/__init__.py +13 -0
  74. tests/models/common/test_model_loader.py +455 -0
  75. tests/models/jax/__init__.py +13 -0
  76. tests/models/jax/test_deepseek_v3.py +401 -0
  77. tests/models/jax/test_llama3.py +184 -0
  78. tests/models/jax/test_llama4.py +298 -0
  79. tests/models/jax/test_llama_eagle3.py +197 -0
  80. tests/models/jax/test_llama_guard_4.py +242 -0
  81. tests/models/jax/test_qwen2.py +172 -0
  82. tests/models/jax/test_qwen2_5_vl.py +605 -0
  83. tests/models/jax/test_qwen3.py +169 -0
  84. tests/models/jax/test_weight_loading.py +180 -0
  85. tests/models/jax/utils/__init__.py +13 -0
  86. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  87. tests/platforms/__init__.py +13 -0
  88. tests/platforms/test_tpu_platform.py +54 -0
  89. tests/runner/__init__.py +13 -0
  90. tests/runner/test_block_table.py +395 -0
  91. tests/runner/test_input_batch.py +226 -0
  92. tests/runner/test_kv_cache.py +220 -0
  93. tests/runner/test_kv_cache_manager.py +498 -0
  94. tests/runner/test_multimodal_manager.py +429 -0
  95. tests/runner/test_persistent_batch_manager.py +84 -0
  96. tests/runner/test_speculative_decoding_manager.py +368 -0
  97. tests/runner/test_structured_decoding_manager.py +220 -0
  98. tests/runner/test_tpu_runner.py +261 -0
  99. tests/runner/test_tpu_runner_dp.py +1099 -0
  100. tests/runner/test_tpu_runner_mesh.py +200 -0
  101. tests/runner/test_utils.py +411 -0
  102. tests/spec_decode/__init__.py +13 -0
  103. tests/spec_decode/test_eagle3.py +311 -0
  104. tests/test_base.py +14 -0
  105. tests/test_envs.py +78 -1
  106. tests/test_tpu_info.py +14 -0
  107. tests/test_utils.py +1 -43
  108. tests/worker/__init__.py +13 -0
  109. tests/worker/tpu_worker_test.py +414 -0
  110. tpu_inference/__init__.py +14 -0
  111. tpu_inference/core/__init__.py +13 -0
  112. tpu_inference/core/sched/__init__.py +13 -0
  113. tpu_inference/core/sched/dp_scheduler.py +372 -56
  114. tpu_inference/distributed/__init__.py +13 -0
  115. tpu_inference/distributed/jax_parallel_state.py +14 -0
  116. tpu_inference/distributed/tpu_connector.py +14 -9
  117. tpu_inference/distributed/utils.py +56 -4
  118. tpu_inference/envs.py +38 -7
  119. tpu_inference/executors/__init__.py +13 -0
  120. tpu_inference/executors/ray_distributed_executor.py +17 -0
  121. tpu_inference/experimental/__init__.py +13 -0
  122. tpu_inference/experimental/llama3_jax_stashed.py +14 -0
  123. tpu_inference/kernels/__init__.py +13 -0
  124. tpu_inference/kernels/collectives/__init__.py +13 -0
  125. tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
  126. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
  127. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  128. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  129. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  130. tpu_inference/kernels/fused_moe/v1/kernel.py +370 -324
  131. tpu_inference/kernels/megablox/__init__.py +13 -0
  132. tpu_inference/kernels/megablox/common.py +54 -0
  133. tpu_inference/kernels/megablox/gmm.py +646 -0
  134. tpu_inference/kernels/mla/__init__.py +13 -0
  135. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  136. tpu_inference/kernels/mla/v1/kernel.py +117 -145
  137. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  138. tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
  139. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  140. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  141. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
  142. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
  143. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  144. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +194 -101
  145. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +95 -78
  146. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
  147. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +376 -195
  148. tpu_inference/kernels/ragged_paged_attention/v3/util.py +15 -1
  149. tpu_inference/layers/__init__.py +13 -0
  150. tpu_inference/layers/common/__init__.py +13 -0
  151. tpu_inference/layers/common/attention_interface.py +26 -19
  152. tpu_inference/layers/common/attention_metadata.py +14 -0
  153. tpu_inference/layers/common/quant_methods.py +15 -0
  154. tpu_inference/layers/common/quantization.py +270 -0
  155. tpu_inference/layers/common/sharding.py +28 -5
  156. tpu_inference/layers/jax/__init__.py +13 -0
  157. tpu_inference/layers/jax/attention/__init__.py +13 -0
  158. tpu_inference/layers/jax/attention/attention.py +19 -6
  159. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +270 -77
  160. tpu_inference/layers/jax/attention/gpt_oss_attention.py +24 -11
  161. tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
  162. tpu_inference/layers/jax/base.py +14 -0
  163. tpu_inference/layers/jax/constants.py +13 -0
  164. tpu_inference/layers/jax/layers.py +14 -0
  165. tpu_inference/layers/jax/misc.py +14 -0
  166. tpu_inference/layers/jax/moe/__init__.py +13 -0
  167. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
  168. tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
  169. tpu_inference/layers/jax/moe/moe.py +43 -3
  170. tpu_inference/layers/jax/pp_utils.py +53 -0
  171. tpu_inference/layers/jax/rope.py +14 -0
  172. tpu_inference/layers/jax/rope_interface.py +14 -0
  173. tpu_inference/layers/jax/sample/__init__.py +13 -0
  174. tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
  175. tpu_inference/layers/jax/sample/sampling.py +15 -1
  176. tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
  177. tpu_inference/layers/jax/transformer_block.py +14 -0
  178. tpu_inference/layers/vllm/__init__.py +13 -0
  179. tpu_inference/layers/vllm/attention.py +4 -4
  180. tpu_inference/layers/vllm/fused_moe.py +210 -260
  181. tpu_inference/layers/vllm/linear_common.py +57 -22
  182. tpu_inference/layers/vllm/quantization/__init__.py +16 -0
  183. tpu_inference/layers/vllm/quantization/awq.py +15 -1
  184. tpu_inference/layers/vllm/quantization/common.py +33 -18
  185. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  186. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +18 -3
  187. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +211 -148
  188. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  189. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +14 -0
  190. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +14 -0
  191. tpu_inference/layers/vllm/quantization/fp8.py +118 -0
  192. tpu_inference/layers/vllm/quantization/mxfp4.py +278 -209
  193. tpu_inference/layers/vllm/quantization/unquantized.py +134 -86
  194. tpu_inference/layers/vllm/sharding.py +21 -4
  195. tpu_inference/lora/__init__.py +13 -0
  196. tpu_inference/lora/torch_lora_ops.py +8 -13
  197. tpu_inference/models/__init__.py +13 -0
  198. tpu_inference/models/common/__init__.py +13 -0
  199. tpu_inference/models/common/model_loader.py +74 -35
  200. tpu_inference/models/jax/__init__.py +13 -0
  201. tpu_inference/models/jax/deepseek_v3.py +267 -157
  202. tpu_inference/models/jax/gpt_oss.py +26 -10
  203. tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
  204. tpu_inference/models/jax/llama3.py +99 -36
  205. tpu_inference/models/jax/llama4.py +14 -0
  206. tpu_inference/models/jax/llama_eagle3.py +14 -0
  207. tpu_inference/models/jax/llama_guard_4.py +15 -1
  208. tpu_inference/models/jax/qwen2.py +17 -2
  209. tpu_inference/models/jax/qwen2_5_vl.py +18 -4
  210. tpu_inference/models/jax/qwen3.py +17 -2
  211. tpu_inference/models/jax/utils/__init__.py +13 -0
  212. tpu_inference/models/jax/utils/file_utils.py +14 -0
  213. tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
  214. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  215. tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +89 -26
  216. tpu_inference/models/jax/utils/weight_utils.py +39 -2
  217. tpu_inference/models/vllm/__init__.py +13 -0
  218. tpu_inference/models/vllm/vllm_model_wrapper.py +20 -3
  219. tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
  220. tpu_inference/platforms/__init__.py +14 -0
  221. tpu_inference/platforms/tpu_platform.py +47 -64
  222. tpu_inference/runner/__init__.py +13 -0
  223. tpu_inference/runner/compilation_manager.py +72 -37
  224. tpu_inference/runner/kv_cache.py +54 -20
  225. tpu_inference/runner/kv_cache_manager.py +46 -17
  226. tpu_inference/runner/lora_utils.py +14 -0
  227. tpu_inference/runner/multimodal_manager.py +15 -1
  228. tpu_inference/runner/persistent_batch_manager.py +14 -0
  229. tpu_inference/runner/speculative_decoding_manager.py +14 -0
  230. tpu_inference/runner/structured_decoding_manager.py +14 -0
  231. tpu_inference/runner/tpu_runner.py +44 -17
  232. tpu_inference/spec_decode/__init__.py +13 -0
  233. tpu_inference/spec_decode/jax/__init__.py +13 -0
  234. tpu_inference/spec_decode/jax/eagle3.py +13 -0
  235. tpu_inference/tpu_info.py +14 -0
  236. tpu_inference/utils.py +42 -36
  237. tpu_inference/worker/__init__.py +13 -0
  238. tpu_inference/worker/tpu_worker.py +63 -50
  239. {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.2rc3.dist-info}/METADATA +7 -9
  240. tpu_inference-0.13.2rc3.dist-info/RECORD +261 -0
  241. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  242. tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
  243. tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
  244. tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
  245. tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
  246. tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
  247. tpu_inference-0.11.1.dev202512030818.dist-info/RECORD +0 -174
  248. {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.2rc3.dist-info}/WHEEL +0 -0
  249. {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.2rc3.dist-info}/licenses/LICENSE +0 -0
  250. {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.2rc3.dist-info}/top_level.txt +0 -0
@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import time
 from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple
 
@@ -32,6 +46,8 @@ class CompilationManager:
 
     def __init__(self, runner: "TPUModelRunner"):
         self.runner = runner
+        self._sampling_precompiled = False
+        self._gather_logprobs_precompiled = False
         if not vllm_envs.VLLM_DISABLE_COMPILE_CACHE:
             logger.info("Enabling JAX compile cache.")
             jax.config.update("jax_compilation_cache_dir",
@@ -86,9 +102,13 @@ class CompilationManager:
             return
         self._precompile_select_from_array()
         self._precompile_compute_logits()
+        # Skip sampling if already precompiled before KV cache allocation
+        if not self._sampling_precompiled:
+            self._precompile_sampling()
         self._precompile_disagg_utils()
-        self._precompile_sampling()
-        self._precompile_gather_logprobs()
+        # Skip gather_logprobs if already precompiled before KV cache allocation
+        if not self._gather_logprobs_precompiled:
+            self._precompile_gather_logprobs()
         self._precompile_structured_decoding()
         if self.runner.speculative_config:
            self._precompile_speculative_decoding()
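
For context: the two new boolean flags implement a compile-once guard. Sampling and gather_logprobs can now be precompiled early, before KV cache allocation, and the main capture pass skips them if that already happened. A minimal sketch of the pattern, with names simplified from the diff:

    # Minimal sketch of the compile-once guard (simplified from the diff).
    class Precompiler:
        def __init__(self):
            self._sampling_precompiled = False

        def precompile_sampling(self):
            ...  # expensive ahead-of-time jax.jit warm-up over padded shapes
            self._sampling_precompiled = True

        def capture_all(self):
            # Safe to call repeatedly: skips the warm-up when sampling was
            # already compiled before KV cache allocation.
            if not self._sampling_precompiled:
                self.precompile_sampling()
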
@@ -107,7 +127,7 @@
 
         self._run_compilation(
             "input_embeddings_merger",
-            self.runner.get_input_embeddings_fn,
+            self.runner.embed_input_ids_fn,
             self.runner.state,
             dummy_input_ids,
             dummy_multimodal_embeddings,
@@ -116,7 +136,7 @@
 
         self._run_compilation(
             "input_embeddings_merger_text_only",
-            self.runner.get_input_embeddings_fn,
+            self.runner.embed_input_ids_fn,
             self.runner.state,
             dummy_input_ids,
             None,
@@ -466,43 +486,48 @@
         for num_reqs in self.runner.num_reqs_paddings:
             logits_sharding = NamedSharding(
                 self.runner.mesh,
-                PartitionSpec(ShardingAxisName.ATTN_DATA, "model"))
+                PartitionSpec(ShardingAxisName.MLP_DATA,
+                              ShardingAxisName.MLP_TENSOR))
             dp_size = self.runner.vllm_config.sharding_config.total_dp_size
             sampling_metadata_sharding = NamedSharding(
                 self.runner.mesh, PartitionSpec(
-                    ShardingAxisName.ATTN_DATA)) if dp_size > 1 else None
+                    ShardingAxisName.MLP_DATA)) if dp_size > 1 else None
             logits = self._create_dummy_tensor((num_reqs, hsize), jnp.bfloat16,
                                                logits_sharding)
             for do_sampling in (True, False):
-                if do_sampling:
-                    temperature = np.full((num_reqs, ), 0.7, dtype=np.float32)
-                    top_k = np.full((num_reqs, ), 20, dtype=np.int32)
-                    top_p = np.full((num_reqs, ), 0.8, dtype=np.float32)
-                    (temperature, top_k,
-                     top_p) = device_array(self.runner.mesh,
-                                           (temperature, top_k, top_p),
-                                           sharding=sampling_metadata_sharding)
-                else:
-                    temperature = None
-                    top_k = None
-                    top_p = None
-
-                sampling_metadata = TPUSupportedSamplingMetadata(
-                    temperature=temperature,
-                    top_k=top_k,
-                    top_p=top_p,
-                    do_sampling=do_sampling,
-                )
-                self._run_compilation(
-                    f"worker{self.runner.rank} sample",
-                    sample,
-                    self.runner.rng_params_for_sampling,
-                    self.runner.mesh,
-                    logits,
-                    sampling_metadata,
-                    num_reqs=num_reqs,
-                    do_sampling=do_sampling,
-                )
+                for logprobs in (True, False):
+                    if do_sampling:
+                        temperature = np.full((num_reqs, ),
+                                              0.7,
+                                              dtype=np.float32)
+                        top_k = np.full((num_reqs, ), 20, dtype=np.int32)
+                        top_p = np.full((num_reqs, ), 0.8, dtype=np.float32)
+                        (temperature, top_k, top_p) = device_array(
+                            self.runner.mesh, (temperature, top_k, top_p),
+                            sharding=sampling_metadata_sharding)
+                    else:
+                        temperature = None
+                        top_k = None
+                        top_p = None
+
+                    sampling_metadata = TPUSupportedSamplingMetadata(
+                        temperature=temperature,
+                        top_k=top_k,
+                        top_p=top_p,
+                        do_sampling=do_sampling,
+                        logprobs=logprobs)
+                    self._run_compilation(
+                        f"worker{self.runner.rank} sample",
+                        sample,
+                        self.runner.rng_params_for_sampling,
+                        self.runner.mesh,
+                        logits,
+                        sampling_metadata,
+                        num_reqs=num_reqs,
+                        do_sampling=do_sampling,
+                    )
+
+        self._sampling_precompiled = True
 
     def _precompile_disagg_utils(self) -> None:
         if not is_disagg_enabled():
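
The new (do_sampling, logprobs) nesting doubles the number of precompiled variants: JAX builds a separate executable for every combination of input shapes and trace-affecting options, so warming each (padded num_reqs, flag) pair ahead of time keeps first-request latency flat. A self-contained illustration of the same idea (toy code, not the project's):

    # Illustration: warm a jitted function over every padded batch size and
    # flag combination so serving-time calls always hit the compile cache.
    from functools import partial

    import jax
    import jax.numpy as jnp

    @partial(jax.jit, static_argnames=("do_sampling",))
    def sample(logits, do_sampling):
        return jnp.argmax(logits, -1) if do_sampling else logits.max(-1)

    for num_reqs in (8, 16, 32):             # padded request counts
        dummy = jnp.zeros((num_reqs, 128), jnp.bfloat16)
        for do_sampling in (True, False):     # flag loop, as in the diff
            sample(dummy, do_sampling=do_sampling)  # traces and compiles once
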
@@ -532,8 +557,16 @@
         logger.info("Compiling gather_logprobs with different input shapes.")
         hsize = self.runner.model_config.get_vocab_size()
         for num_reqs in self.runner.num_reqs_paddings:
-            logits = self._create_dummy_tensor((num_reqs, hsize), jnp.bfloat16)
-            token_ids = self._create_dummy_tensor((num_reqs, ), jnp.int32)
+            logits_sharding = NamedSharding(
+                self.runner.mesh,
+                PartitionSpec(ShardingAxisName.MLP_DATA,
+                              ShardingAxisName.MLP_TENSOR))
+            token_ids_sharding = NamedSharding(
+                self.runner.mesh, PartitionSpec(ShardingAxisName.MLP_DATA, ))
+            logits = self._create_dummy_tensor((num_reqs, hsize), jnp.bfloat16,
+                                               logits_sharding)
+            token_ids = self._create_dummy_tensor((num_reqs, ), jnp.int32,
+                                                  token_ids_sharding)
             self._run_compilation(
                 f"worker{self.runner.rank} gather_logprobs",
                 self.runner._compute_and_gather_logprobs,
@@ -543,6 +576,8 @@
                 num_reqs=num_reqs,
             )
 
+        self._gather_logprobs_precompiled = True
+
     def _precompile_speculative_decoding(self) -> None:
         logger.info(
             "Compiling speculative_decoding with different input shapes.")
@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from typing import Any, List
 
 import jax
@@ -7,6 +21,7 @@ from jax._src import dtypes
 from jax.sharding import Mesh, NamedSharding, PartitionSpec
 from torchax.ops.mappings import t2j_dtype
 
+import tpu_inference.kernels.mla.v1.kernel as mla
 import tpu_inference.kernels.ragged_paged_attention.v3.kernel as rpa
 import tpu_inference.kernels.ragged_paged_attention.v3.kernel_hd64 as rpa_hd64
 from tpu_inference.layers.common.sharding import ShardingAxisName
@@ -17,9 +32,13 @@ logger = init_logger(__name__)
 DEFAULT_KV_CACHE_DTYPE = jnp.bfloat16
 
 
-def get_kv_cache_shape_with_mesh(mesh: Mesh, total_num_pages: int,
-                                 page_size: int, actual_num_kv_heads: int,
-                                 actual_head_dim: int, kv_dtype: any):
+def get_kv_cache_shape_with_mesh(mesh: Mesh,
+                                 total_num_pages: int,
+                                 page_size: int,
+                                 actual_num_kv_heads: int,
+                                 actual_head_dim: int,
+                                 kv_dtype: any,
+                                 use_mla: bool = False):
     """Gets the KV cache shape based on the mesh configuration."""
 
     model_cnt = mesh.shape["model"]
@@ -28,15 +47,21 @@ def get_kv_cache_shape_with_mesh(mesh: Mesh, total_num_pages: int,
     # specific model, rather than being determined by the head_dim. If new
     # models are introduced with a head_dim of 64, this will require additional
     # model-specific adjustments.
-    get_kv_cache_shape_fn = (
-        rpa_hd64.get_kv_cache_shape if actual_head_dim == 64 \
-        else rpa.get_kv_cache_shape
-    )
-    shape = list(
-        get_kv_cache_shape_fn(total_num_pages, page_size,
-                              actual_num_kv_heads // model_cnt,
-                              actual_head_dim, kv_dtype))
-    shape[2] *= model_cnt
+    if use_mla:
+        get_kv_cache_shape_fn = mla.get_kv_cache_shape
+        shape = list(
+            get_kv_cache_shape_fn(total_num_pages, page_size, actual_head_dim,
+                                  kv_dtype))
+    else:
+        get_kv_cache_shape_fn = (
+            rpa_hd64.get_kv_cache_shape if actual_head_dim == 64 \
+            else rpa.get_kv_cache_shape
+        )
+        shape = list(
+            get_kv_cache_shape_fn(total_num_pages, page_size,
+                                  actual_num_kv_heads // model_cnt,
+                                  actual_head_dim, kv_dtype))
+        shape[2] *= model_cnt
     return tuple(shape)
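
For intuition on the non-MLA path: KV heads are divided across the mesh's "model" axis to get the per-shard kernel shape, and the head axis is then re-expanded (shape[2] *= model_cnt) to describe the global array; the MLA path skips both steps because the latent KV is shared across heads. A runnable toy version of that arithmetic with hypothetical sizes (the factor 2 for interleaved K and V follows the "multiply the num_kv_heads by 2" note that appears later in this diff):

    # Toy version of the head split/re-expansion (hypothetical sizes).
    model_cnt = 4                     # size of the "model" mesh axis
    actual_num_kv_heads = 8           # global KV head count
    per_shard_heads = actual_num_kv_heads // model_cnt    # 2 per shard
    # pages, page_size, combined K+V heads, head_dim:
    shape = [64, 128, 2 * per_shard_heads, 128]
    shape[2] *= model_cnt             # re-expand: 4 -> 16 combined K+V heads
    assert shape[2] == 2 * actual_num_kv_heads
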
@@ -48,6 +73,7 @@ def create_kv_caches(
     mesh: Mesh,
     layer_names: List[str],
     cache_dtype: jnp.dtype = DEFAULT_KV_CACHE_DTYPE,
+    use_mla: bool = False,
 ) -> List[jax.Array]:
     """
     Creates a list of KV cache where each array mapps to single attention layer.
@@ -74,12 +100,16 @@ def create_kv_caches(
 
     cache_shape = get_kv_cache_shape_with_mesh(mesh, num_blocks, block_size,
                                                num_kv_heads, head_size,
-                                               cache_dtype)
+                                               cache_dtype, use_mla)
 
-    sharding = NamedSharding(
-        mesh,
-        PartitionSpec(ShardingAxisName.ATTN_DATA, None,
-                      ShardingAxisName.ATTN_HEAD))
+    if use_mla:
+        sharding = NamedSharding(mesh,
+                                 PartitionSpec(ShardingAxisName.MLP_TENSOR))
+    else:
+        sharding = NamedSharding(
+            mesh,
+            PartitionSpec(ShardingAxisName.ATTN_DATA, None,
+                          ShardingAxisName.ATTN_HEAD))
 
     def _allocate() -> jax.Array:
         return jnp.empty(
@@ -94,7 +124,8 @@ def create_kv_caches(
     return kv_caches
 
 
-def get_rpa_page_size_bytes(mesh: Mesh, kv_cache_specs: dict[str, Any]) -> int:
+def get_attention_page_size_bytes(mesh: Mesh,
+                                  kv_cache_specs: dict[str, Any]) -> int:
     """
     Calculate KV cache page size of RPA kernel.
 
@@ -107,14 +138,16 @@ def get_rpa_page_size_bytes(mesh: Mesh, kv_cache_specs: dict[str, Any]) -> int:
     """
 
     # Import it here to avoid circular import.
-    from vllm.v1.kv_cache_interface import AttentionSpec
+    from vllm.v1.kv_cache_interface import AttentionSpec, MLAAttentionSpec
 
     page_size_bytes_set = set()
     for kv_cache_spec in kv_cache_specs.values():
         assert isinstance(kv_cache_spec, AttentionSpec)
 
         dtype = t2j_dtype(kv_cache_spec.dtype)
-        bits = dtypes.bit_width(dtype)
+        bits = (dtypes.bit_width(dtype) if hasattr(dtypes, "bit_width") else
+                dtypes.itemsize_bits(dtype))
+        use_mla = isinstance(kv_cache_spec, MLAAttentionSpec)
 
         kv_cache_shape = get_kv_cache_shape_with_mesh(
             mesh=mesh,
@@ -123,6 +156,7 @@
             actual_num_kv_heads=kv_cache_spec.num_kv_heads,
             actual_head_dim=kv_cache_spec.head_size,
             kv_dtype=dtype,
+            use_mla=use_mla,
         )
         page_size_bytes = (bits * np.prod(kv_cache_shape)) // 8
         page_size_bytes_set.add(page_size_bytes)
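
Two details in the hunk above are worth noting: the hasattr fallback hedges across JAX versions, where the private jax._src.dtypes module exposes either bit_width or itemsize_bits, and the page size itself is plain bits-times-elements arithmetic. A worked example with hypothetical shape values:

    # Worked example of page_size_bytes = (bits * np.prod(shape)) // 8.
    # Hypothetical per-page shape: 1 page, 128 tokens per page, 16 combined
    # K+V heads, head_dim 128, stored as bfloat16 (16 bits per element).
    import numpy as np

    bits = 16
    kv_cache_shape = (1, 128, 16, 128)
    page_size_bytes = (bits * np.prod(kv_cache_shape)) // 8
    assert page_size_bytes == 524288  # 512 KiB per page
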
@@ -1,5 +1,19 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import functools
-from typing import TYPE_CHECKING, Dict, List
+from typing import TYPE_CHECKING, List
 
 import jax
 import jax.numpy as jnp
@@ -39,20 +53,30 @@ class KVCacheManager:
         # means this layer will perform attention using the keys and values
         # from the KV cache of `shared_kv_cache_layers[layer_name]`.
         self.shared_kv_cache_layers: dict[str, str] = {}
+        self.use_mla = self.runner.model_config.use_mla
 
     def get_kv_cache_spec(self):
         # TODO(xiang): this hack tricks engine core to init successfully
         block_size = self.runner.cache_config.block_size
-        use_mla = self.runner.model_config.use_mla
         kv_cache_spec: dict[str, KVCacheSpec] = {}
 
         # If use pure jax (MODEL_IMPL_TYPE=flax_nnx), we don't register
        # attention into compilation config.
         # Use FullAttentionSpec for each layer
         # TODO(pooyam): Is it possible to merge the logic for vllm and non-vllm models?
+        model_config = self.runner.model_config
+        if self.use_mla:
+            # Individually pad the RoPE and latents
+            qk_rope_head_dim = getattr(model_config.hf_text_config,
+                                       "qk_rope_head_dim", 0)
+            padded_kv_lora_rank = common_utils.align_to(
+                model_config.hf_text_config.kv_lora_rank, 128)
+            padded_qk_rope_head_dim = common_utils.align_to(
+                qk_rope_head_dim, 128)
+            mla_head_size = padded_kv_lora_rank + padded_qk_rope_head_dim
+
         if len(self.runner.vllm_config.compilation_config.
                static_forward_context) == 0:
-            model_config = self.runner.model_config
             parallel_config = self.runner.parallel_config
             # Pad num_kv_heads to multiple of TP size.
             num_kv_heads = common_utils.get_padded_num_heads(
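
The MLA spec computed above packs the compressed KV latent and the decoupled RoPE key into a single per-token vector (the MLAAttentionSpec entries below accordingly use num_kv_heads=1), padding each component to a multiple of 128 separately, presumably so both start at lane-aligned offsets. A small sketch of the arithmetic, assuming align_to rounds up to the next multiple, with DeepSeek-V3-style values for illustration:

    # Sketch of the MLA head-size padding; align_to is assumed to round up.
    def align_to(x: int, multiple: int) -> int:
        return ((x + multiple - 1) // multiple) * multiple

    kv_lora_rank, qk_rope_head_dim = 512, 64   # DeepSeek-V3-style values
    mla_head_size = (align_to(kv_lora_rank, 128) +
                     align_to(qk_rope_head_dim, 128))
    assert mla_head_size == 640                # 512 + 128
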
@@ -61,11 +85,11 @@
             head_size = common_utils.get_padded_head_dim(
                 model_config.get_head_size())
             for i in range(model_config.get_num_layers(parallel_config)):
-                if use_mla:
+                if self.use_mla:
                     kv_cache_spec[f"layer.{i}"] = MLAAttentionSpec(
                         block_size=block_size,
-                        num_kv_heads=num_kv_heads,
-                        head_size=head_size,
+                        num_kv_heads=1,
+                        head_size=mla_head_size,
                         dtype=self.runner.kv_cache_dtype,
                         cache_dtype_str=self.runner.vllm_config.cache_config.
                         cache_dtype)
@@ -83,14 +107,13 @@
                 self.runner.mesh.shape["model"])
             head_size = common_utils.get_padded_head_dim(
                 hf_config.hidden_size // hf_config.num_attention_heads)
-
             # Eagle3 has only 1 layer
             for i in range(1):
-                if use_mla:
-                    kv_cache_spec[f"layer.{i}"] = MLAAttentionSpec(
+                if self.use_mla:
+                    kv_cache_spec[f"draft_layer.{i}"] = MLAAttentionSpec(
                         block_size=block_size,
-                        num_kv_heads=num_kv_heads,
-                        head_size=head_size,
+                        num_kv_heads=1,
+                        head_size=mla_head_size,
                         dtype=self.runner.kv_cache_dtype,
                         cache_dtype_str=self.runner.vllm_config.
                         cache_config.cache_dtype)
@@ -104,6 +127,7 @@
         # Else propagate attention modules from compilation config.
         layers = get_layers_from_vllm_config(self.runner.vllm_config,
                                              Attention)
+        logger.warning(f"Compilation num_layers = {len(layers.items())}")
         for layer_name, attn_module in layers.items():
             if (kv_tgt_layer :=
                     attn_module.kv_sharing_target_layer_name) is not None:
@@ -127,11 +151,11 @@
                         attn_module.head_size),
                     dtype=self.runner.kv_cache_dtype,
                     sliding_window=attn_module.sliding_window)
-            elif use_mla:
-                kv_cache_spec[f"layer.{i}"] = MLAAttentionSpec(
+            elif self.use_mla:
+                kv_cache_spec[layer_name] = MLAAttentionSpec(
                     block_size=block_size,
-                    num_kv_heads=attn_module.num_kv_heads,
-                    head_size=attn_module.head_size,
+                    num_kv_heads=1,
+                    head_size=mla_head_size,
                     dtype=self.runner.kv_cache_dtype,
                     cache_dtype_str=self.runner.vllm_config.
                     cache_config.cache_dtype)
@@ -188,7 +212,6 @@
         # uniform page size.
         representative_spec = kv_cache_config.kv_cache_groups[0].kv_cache_spec
         page_size_bytes = representative_spec.page_size_bytes
-        self.runner.layer_name_to_kvcache_index: Dict[str, int] = {}
         kv_caches = self.runner.kv_caches
         num_blocks_list = []
         for i, kv_cache_tensor in enumerate(kv_cache_config.kv_cache_tensors):
@@ -198,14 +221,20 @@
             # num_blocks must be a multiple of dp_size
             num_blocks = (num_blocks // dp_size) * dp_size
             # NOTE: we'll multiply the num_kv_heads by 2 in the function
+            if self.use_mla:
+                head_size = self.runner.model_config.hf_config.kv_lora_rank + \
+                    self.runner.model_config.hf_config.qk_rope_head_dim
+            else:
+                head_size = representative_spec.head_size
             kv_cache = create_kv_caches(
                 num_blocks=num_blocks,
                 block_size=representative_spec.block_size,
                 num_kv_heads=representative_spec.num_kv_heads,
-                head_size=representative_spec.head_size,
+                head_size=head_size,
                 mesh=self.runner.mesh,
                 layer_names=[f'kv_cache_tensor.{i}'],
                 cache_dtype=t2j_dtype(representative_spec.dtype),
+                use_mla=self.use_mla,
             )[0]
             kv_caches.append(kv_cache)
             num_blocks_list.append(num_blocks)
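
One small invariant in the allocation hunk above: num_blocks is floored to a multiple of the data-parallel size so that every DP shard owns an equal slice of the cache. For example:

    # Flooring the block count to a multiple of dp_size (illustrative values).
    dp_size = 4
    num_blocks = 1023
    num_blocks = (num_blocks // dp_size) * dp_size
    assert num_blocks == 1020  # each of the 4 DP shards gets 255 blocks
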
@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from __future__ import annotations
 
 from typing import TYPE_CHECKING
@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from typing import TYPE_CHECKING
 
 import jax
@@ -134,7 +148,7 @@ class MultiModalManager:
         # 2. A list or tuple (length: num_items) of tensors, each of shape
         #    (feature_size, hidden_size) in case the feature size is dynamic
         #    depending on the input multimodal items.
-        curr_group_outputs = self.runner.get_multimodal_embeddings_fn(
+        curr_group_outputs = self.runner.embed_multimodal_fn(
             self.runner.state, image_grid_thw, **batched_mm_inputs)
 
         sanity_check_mm_encoder_outputs(
@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from typing import Dict
 
 import jax
@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from __future__ import annotations
 
 from dataclasses import dataclass
@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import functools
 from typing import TYPE_CHECKING, Tuple