tpu-inference 0.11.1.dev202511270815__py3-none-any.whl → 0.13.0rc2.post7__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release. This version of tpu-inference might be problematic.
Files changed (251)
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_disagg_utils.py +14 -0
  4. tests/core/test_dp_scheduler.py +650 -768
  5. tests/core/test_init.py +14 -0
  6. tests/distributed/__init__.py +13 -0
  7. tests/distributed/test_distributed_utils.py +120 -0
  8. tests/distributed/test_tpu_connector.py +478 -0
  9. tests/e2e/__init__.py +13 -0
  10. tests/e2e/test_async_scheduler.py +211 -0
  11. tests/e2e/test_data_parallel.py +289 -0
  12. tests/e2e/test_hybrid_kvcache.py +219 -0
  13. tests/e2e/test_local_disagg.py +257 -0
  14. tests/e2e/test_model_loader.py +268 -0
  15. tests/e2e/test_multi_modal_inference.py +111 -0
  16. tests/e2e/test_pipeline_parallel.py +265 -0
  17. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  18. tests/e2e/test_sampling_params.py +269 -0
  19. tests/e2e/test_speculative_decoding.py +311 -0
  20. tests/e2e/test_structured_decoding.py +46 -0
  21. tests/executors/__init__.py +13 -0
  22. tests/executors/test_ray_distributed_executor.py +199 -0
  23. tests/experimental/__init__.py +13 -0
  24. tests/experimental/test_llama3_jax_stashed.py +208 -0
  25. tests/kernels/__init__.py +13 -0
  26. tests/kernels/collectives/__init__.py +13 -0
  27. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  28. tests/kernels/fused_moe_v1_test.py +14 -0
  29. tests/kernels/gmm_test.py +205 -0
  30. tests/kernels/mla_v1_test.py +143 -41
  31. tests/kernels/quantized_matmul_kernel_test.py +2 -34
  32. tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
  33. tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
  34. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +17 -1
  35. tests/kernels/ragged_paged_attention_kernel_v3_test.py +17 -1
  36. tests/layers/__init__.py +13 -0
  37. tests/layers/common/__init__.py +13 -0
  38. tests/layers/common/test_attention_interface.py +156 -0
  39. tests/layers/common/test_quantization.py +149 -0
  40. tests/layers/jax/__init__.py +13 -0
  41. tests/layers/jax/attention/__init__.py +13 -0
  42. tests/layers/jax/attention/test_common_attention.py +103 -0
  43. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  44. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  45. tests/layers/jax/moe/__init__.py +13 -0
  46. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  47. tests/layers/jax/sample/__init__.py +13 -0
  48. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  49. tests/layers/jax/sample/test_sampling.py +115 -0
  50. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  51. tests/layers/jax/test_layers.py +155 -0
  52. tests/{test_quantization.py → layers/jax/test_qwix.py} +183 -50
  53. tests/layers/jax/test_rope.py +93 -0
  54. tests/layers/jax/test_sharding.py +159 -0
  55. tests/layers/jax/test_transformer_block.py +152 -0
  56. tests/layers/vllm/__init__.py +13 -0
  57. tests/layers/vllm/test_attention.py +363 -0
  58. tests/layers/vllm/test_awq.py +405 -0
  59. tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
  60. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +418 -0
  61. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +441 -0
  62. tests/layers/vllm/test_fp8.py +17 -0
  63. tests/layers/vllm/test_mxfp4.py +312 -0
  64. tests/layers/vllm/test_unquantized.py +651 -0
  65. tests/layers/vllm/utils.py +87 -0
  66. tests/lora/__init__.py +13 -0
  67. tests/lora/conftest.py +14 -0
  68. tests/lora/test_bgmv.py +14 -0
  69. tests/lora/test_layers.py +21 -3
  70. tests/lora/test_lora.py +15 -1
  71. tests/lora/test_lora_perf.py +67 -0
  72. tests/models/__init__.py +13 -0
  73. tests/models/common/__init__.py +13 -0
  74. tests/models/common/test_model_loader.py +455 -0
  75. tests/models/jax/__init__.py +13 -0
  76. tests/models/jax/test_deepseek_v3.py +401 -0
  77. tests/models/jax/test_llama3.py +184 -0
  78. tests/models/jax/test_llama4.py +298 -0
  79. tests/models/jax/test_llama_eagle3.py +197 -0
  80. tests/models/jax/test_llama_guard_4.py +242 -0
  81. tests/models/jax/test_qwen2.py +172 -0
  82. tests/models/jax/test_qwen2_5_vl.py +605 -0
  83. tests/models/jax/test_qwen3.py +169 -0
  84. tests/models/jax/test_weight_loading.py +180 -0
  85. tests/models/jax/utils/__init__.py +13 -0
  86. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  87. tests/platforms/__init__.py +13 -0
  88. tests/platforms/test_tpu_platform.py +54 -0
  89. tests/runner/__init__.py +13 -0
  90. tests/runner/test_block_table.py +395 -0
  91. tests/runner/test_input_batch.py +226 -0
  92. tests/runner/test_kv_cache.py +220 -0
  93. tests/runner/test_kv_cache_manager.py +498 -0
  94. tests/runner/test_multimodal_manager.py +429 -0
  95. tests/runner/test_persistent_batch_manager.py +84 -0
  96. tests/runner/test_speculative_decoding_manager.py +368 -0
  97. tests/runner/test_structured_decoding_manager.py +220 -0
  98. tests/runner/test_tpu_runner.py +261 -0
  99. tests/runner/test_tpu_runner_dp.py +1099 -0
  100. tests/runner/test_tpu_runner_mesh.py +200 -0
  101. tests/runner/test_utils.py +411 -0
  102. tests/spec_decode/__init__.py +13 -0
  103. tests/spec_decode/test_eagle3.py +311 -0
  104. tests/test_base.py +14 -0
  105. tests/test_envs.py +110 -12
  106. tests/test_tpu_info.py +14 -0
  107. tests/test_utils.py +2 -45
  108. tests/worker/__init__.py +13 -0
  109. tests/worker/tpu_worker_test.py +414 -0
  110. tpu_inference/__init__.py +14 -0
  111. tpu_inference/core/__init__.py +13 -0
  112. tpu_inference/core/sched/__init__.py +13 -0
  113. tpu_inference/core/sched/dp_scheduler.py +372 -56
  114. tpu_inference/distributed/__init__.py +13 -0
  115. tpu_inference/distributed/jax_parallel_state.py +14 -0
  116. tpu_inference/distributed/tpu_connector.py +15 -10
  117. tpu_inference/distributed/utils.py +56 -4
  118. tpu_inference/envs.py +92 -8
  119. tpu_inference/executors/__init__.py +13 -0
  120. tpu_inference/executors/ray_distributed_executor.py +22 -1
  121. tpu_inference/experimental/__init__.py +13 -0
  122. tpu_inference/experimental/llama3_jax_stashed.py +14 -0
  123. tpu_inference/kernels/__init__.py +13 -0
  124. tpu_inference/kernels/collectives/__init__.py +13 -0
  125. tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
  126. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
  127. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  128. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  129. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  130. tpu_inference/kernels/fused_moe/v1/kernel.py +370 -324
  131. tpu_inference/kernels/megablox/__init__.py +13 -0
  132. tpu_inference/kernels/megablox/common.py +54 -0
  133. tpu_inference/kernels/megablox/gmm.py +646 -0
  134. tpu_inference/kernels/mla/__init__.py +13 -0
  135. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  136. tpu_inference/kernels/mla/v1/kernel.py +117 -145
  137. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  138. tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
  139. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  140. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  141. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
  142. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
  143. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  144. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +194 -101
  145. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +167 -97
  146. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
  147. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +376 -195
  148. tpu_inference/kernels/ragged_paged_attention/v3/util.py +15 -1
  149. tpu_inference/layers/__init__.py +13 -0
  150. tpu_inference/layers/common/__init__.py +13 -0
  151. tpu_inference/layers/common/attention_interface.py +26 -19
  152. tpu_inference/layers/common/attention_metadata.py +14 -0
  153. tpu_inference/layers/common/quant_methods.py +15 -0
  154. tpu_inference/layers/common/quantization.py +270 -0
  155. tpu_inference/layers/common/sharding.py +31 -9
  156. tpu_inference/layers/jax/__init__.py +13 -0
  157. tpu_inference/layers/jax/attention/__init__.py +13 -0
  158. tpu_inference/layers/jax/attention/attention.py +19 -6
  159. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +270 -77
  160. tpu_inference/layers/jax/attention/gpt_oss_attention.py +24 -11
  161. tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
  162. tpu_inference/layers/jax/base.py +14 -0
  163. tpu_inference/layers/jax/constants.py +13 -0
  164. tpu_inference/layers/jax/layers.py +14 -0
  165. tpu_inference/layers/jax/misc.py +14 -0
  166. tpu_inference/layers/jax/moe/__init__.py +13 -0
  167. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
  168. tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
  169. tpu_inference/layers/jax/moe/moe.py +43 -3
  170. tpu_inference/layers/jax/pp_utils.py +53 -0
  171. tpu_inference/layers/jax/rope.py +14 -0
  172. tpu_inference/layers/jax/rope_interface.py +14 -0
  173. tpu_inference/layers/jax/sample/__init__.py +13 -0
  174. tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
  175. tpu_inference/layers/jax/sample/sampling.py +15 -1
  176. tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
  177. tpu_inference/layers/jax/transformer_block.py +14 -0
  178. tpu_inference/layers/vllm/__init__.py +13 -0
  179. tpu_inference/layers/vllm/attention.py +4 -4
  180. tpu_inference/layers/vllm/fused_moe.py +210 -260
  181. tpu_inference/layers/vllm/linear_common.py +57 -22
  182. tpu_inference/layers/vllm/quantization/__init__.py +16 -0
  183. tpu_inference/layers/vllm/quantization/awq.py +15 -1
  184. tpu_inference/layers/vllm/quantization/common.py +33 -18
  185. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  186. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +18 -3
  187. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +211 -148
  188. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  189. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +14 -0
  190. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +14 -0
  191. tpu_inference/layers/vllm/quantization/fp8.py +118 -0
  192. tpu_inference/layers/vllm/quantization/mxfp4.py +280 -210
  193. tpu_inference/layers/vllm/quantization/unquantized.py +134 -86
  194. tpu_inference/layers/vllm/sharding.py +21 -4
  195. tpu_inference/lora/__init__.py +13 -0
  196. tpu_inference/lora/torch_lora_ops.py +8 -13
  197. tpu_inference/models/__init__.py +13 -0
  198. tpu_inference/models/common/__init__.py +13 -0
  199. tpu_inference/models/common/model_loader.py +77 -36
  200. tpu_inference/models/jax/__init__.py +13 -0
  201. tpu_inference/models/jax/deepseek_v3.py +267 -157
  202. tpu_inference/models/jax/gpt_oss.py +26 -10
  203. tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
  204. tpu_inference/models/jax/llama3.py +99 -36
  205. tpu_inference/models/jax/llama4.py +14 -0
  206. tpu_inference/models/jax/llama_eagle3.py +14 -0
  207. tpu_inference/models/jax/llama_guard_4.py +15 -1
  208. tpu_inference/models/jax/qwen2.py +17 -2
  209. tpu_inference/models/jax/qwen2_5_vl.py +18 -4
  210. tpu_inference/models/jax/qwen3.py +17 -2
  211. tpu_inference/models/jax/utils/__init__.py +13 -0
  212. tpu_inference/models/jax/utils/file_utils.py +14 -0
  213. tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
  214. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  215. tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +91 -31
  216. tpu_inference/models/jax/utils/weight_utils.py +39 -2
  217. tpu_inference/models/vllm/__init__.py +13 -0
  218. tpu_inference/models/vllm/vllm_model_wrapper.py +20 -4
  219. tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
  220. tpu_inference/platforms/__init__.py +14 -0
  221. tpu_inference/platforms/tpu_platform.py +47 -71
  222. tpu_inference/runner/__init__.py +13 -0
  223. tpu_inference/runner/compilation_manager.py +158 -63
  224. tpu_inference/runner/kv_cache.py +54 -20
  225. tpu_inference/runner/kv_cache_manager.py +53 -30
  226. tpu_inference/runner/lora_utils.py +14 -0
  227. tpu_inference/runner/multimodal_manager.py +15 -1
  228. tpu_inference/runner/persistent_batch_manager.py +54 -2
  229. tpu_inference/runner/speculative_decoding_manager.py +14 -0
  230. tpu_inference/runner/structured_decoding_manager.py +14 -0
  231. tpu_inference/runner/tpu_runner.py +105 -57
  232. tpu_inference/runner/utils.py +2 -2
  233. tpu_inference/spec_decode/__init__.py +13 -0
  234. tpu_inference/spec_decode/jax/__init__.py +13 -0
  235. tpu_inference/spec_decode/jax/eagle3.py +65 -19
  236. tpu_inference/tpu_info.py +14 -0
  237. tpu_inference/utils.py +72 -44
  238. tpu_inference/worker/__init__.py +13 -0
  239. tpu_inference/worker/tpu_worker.py +65 -52
  240. {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/METADATA +11 -9
  241. tpu_inference-0.13.0rc2.post7.dist-info/RECORD +261 -0
  242. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  243. tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
  244. tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
  245. tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
  246. tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
  247. tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
  248. tpu_inference-0.11.1.dev202511270815.dist-info/RECORD +0 -174
  249. {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/WHEEL +0 -0
  250. {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/licenses/LICENSE +0 -0
  251. {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/top_level.txt +0 -0
@@ -1,13 +1,27 @@
- import os
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
  import time
  from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple

  import jax
  import jax.numpy as jnp
  import numpy as np
- import vllm.envs as envs
+ import vllm.envs as vllm_envs
  from jax.sharding import NamedSharding, PartitionSpec

+ import tpu_inference.envs as envs
  from tpu_inference.core.disagg_utils import is_disagg_enabled
  from tpu_inference.layers.common.attention_metadata import AttentionMetadata
  from tpu_inference.layers.common.sharding import ShardingAxisName
@@ -15,6 +29,8 @@ from tpu_inference.layers.jax.sample.sampling import sample
  from tpu_inference.layers.jax.sample.sampling_metadata import \
  TPUSupportedSamplingMetadata
  from tpu_inference.logger import init_logger
+ from tpu_inference.models.jax.jax_intermediate_tensor import \
+ JaxIntermediateTensors
  from tpu_inference.utils import device_array

  if TYPE_CHECKING:
@@ -30,10 +46,12 @@ class CompilationManager:

  def __init__(self, runner: "TPUModelRunner"):
  self.runner = runner
- if not envs.VLLM_DISABLE_COMPILE_CACHE:
+ self._sampling_precompiled = False
+ self._gather_logprobs_precompiled = False
+ if not vllm_envs.VLLM_DISABLE_COMPILE_CACHE:
  logger.info("Enabling JAX compile cache.")
  jax.config.update("jax_compilation_cache_dir",
- envs.VLLM_XLA_CACHE_PATH)
+ vllm_envs.VLLM_XLA_CACHE_PATH)

  def _create_dummy_tensor(self,
  shape: Tuple[int, ...],
@@ -67,8 +85,7 @@ class CompilationManager:
  logger.info("Compilation finished in %.2f [secs].", end - start)

  def capture_model(self) -> None:
- if os.getenv("SKIP_JAX_PRECOMPILE",
- False) or self.runner.model_config.enforce_eager:
+ if envs.SKIP_JAX_PRECOMPILE or self.runner.model_config.enforce_eager:
  return
  logger.info("Precompile all the subgraphs with possible input shapes.")

@@ -81,11 +98,17 @@ class CompilationManager:
  self._precompile_backbone_with_inputs_embeds()
  if self.runner.scheduler_config.async_scheduling:
  self._precompile_substitute_placeholder_token()
+ if not self.runner.is_last_rank:
+ return
  self._precompile_select_from_array()
  self._precompile_compute_logits()
+ # Skip sampling if already precompiled before KV cache allocation
+ if not self._sampling_precompiled:
+ self._precompile_sampling()
  self._precompile_disagg_utils()
- self._precompile_sampling()
- self._precompile_gather_logprobs()
+ # Skip gather_logprobs if already precompiled before KV cache allocation
+ if not self._gather_logprobs_precompiled:
+ self._precompile_gather_logprobs()
  self._precompile_structured_decoding()
  if self.runner.speculative_config:
  self._precompile_speculative_decoding()
@@ -104,7 +127,7 @@ class CompilationManager:

  self._run_compilation(
  "input_embeddings_merger",
- self.runner.get_input_embeddings_fn,
+ self.runner.embed_input_ids_fn,
  self.runner.state,
  dummy_input_ids,
  dummy_multimodal_embeddings,
@@ -113,15 +136,22 @@ class CompilationManager:

  self._run_compilation(
  "input_embeddings_merger_text_only",
- self.runner.get_input_embeddings_fn,
+ self.runner.embed_input_ids_fn,
  self.runner.state,
  dummy_input_ids,
  None,
  num_tokens=num_tokens,
  )

- def _precompile_backbone_helper(self, name, *, input_ids, positions,
- inputs_embeds) -> None:
+ def _precompile_backbone_helper(self,
+ name,
+ *,
+ input_ids,
+ positions,
+ inputs_embeds,
+ intermediate_tensors=None,
+ is_first_rank=True,
+ is_last_rank=True) -> None:
  num_tokens = None
  if input_ids is not None:
  num_tokens = input_ids.shape[0]
@@ -181,10 +211,14 @@ class CompilationManager:
  inputs_embeds,
  layer_name_to_kvcache_index,
  lora_metadata,
+ intermediate_tensors,
+ is_first_rank,
+ is_last_rank,
  ):
  kv_caches, hidden_states, _ = self.runner.model_fn(
  state, kv_caches, input_ids, attention_metadata, inputs_embeds,
- positions, layer_name_to_kvcache_index, lora_metadata)
+ positions, layer_name_to_kvcache_index, lora_metadata,
+ intermediate_tensors, is_first_rank, is_last_rank)
  self.runner.kv_caches = kv_caches
  return hidden_states

@@ -207,6 +241,9 @@ class CompilationManager:
  inputs_embeds,
  tuple(self.runner.layer_name_to_kvcache_index.items()),
  lora_metadata,
+ intermediate_tensors,
+ is_first_rank,
+ is_last_rank,
  num_tokens=num_tokens,
  )

@@ -257,6 +294,7 @@ class CompilationManager:
  )

  def _precompile_backbone_text_only(self) -> None:
+ hidden_size = self.runner.model_config.get_hidden_size()
  for num_tokens in self.runner.num_tokens_paddings:
  dp_sharding = NamedSharding(
  self.runner.mesh, PartitionSpec(ShardingAxisName.ATTN_DATA, )
@@ -266,10 +304,28 @@ class CompilationManager:
  dp_sharding)
  positions = self._create_dummy_tensor((num_tokens, ), jnp.int32,
  dp_sharding)
- self._precompile_backbone_helper("backbone",
- input_ids=input_ids,
- positions=positions,
- inputs_embeds=None)
+ is_first_rank = self.runner.is_first_rank
+ is_last_rank = self.runner.is_last_rank
+ if is_first_rank:
+ intermediate_tensors = None
+ else:
+ hidden_states = self._create_dummy_tensor(
+ (num_tokens, hidden_size), jnp.bfloat16)
+ residual = self._create_dummy_tensor((num_tokens, hidden_size),
+ jnp.bfloat16)
+ intermediate_tensors = JaxIntermediateTensors(
+ tensors={
+ "hidden_states": hidden_states,
+ "residual": residual
+ })
+ self._precompile_backbone_helper(
+ f"worker{self.runner.rank} backbone",
+ input_ids=input_ids,
+ positions=positions,
+ inputs_embeds=None,
+ intermediate_tensors=intermediate_tensors,
+ is_first_rank=is_first_rank,
+ is_last_rank=is_last_rank)

  def _precompile_backbone_with_inputs_embeds(self) -> None:
  hidden_size = self.runner.model_config.get_hidden_size()
@@ -283,10 +339,28 @@ class CompilationManager:
  else:
  positions = self._create_dummy_tensor((num_tokens, ),
  jnp.int32)
- self._precompile_backbone_helper("backbone with embeds",
- input_ids=None,
- positions=positions,
- inputs_embeds=inputs_embeds)
+ is_first_rank = self.runner.is_first_rank
+ is_last_rank = self.runner.is_last_rank
+ if not is_first_rank:
+ hidden_states = self._create_dummy_tensor(
+ (num_tokens, hidden_size), jnp.bfloat16)
+ residual = self._create_dummy_tensor((num_tokens, hidden_size),
+ jnp.bfloat16)
+ intermediate_tensors = JaxIntermediateTensors(
+ tensors={
+ "hidden_states": hidden_states,
+ "residual": residual
+ })
+ else:
+ intermediate_tensors = None
+ self._precompile_backbone_helper(
+ f"worker{self.runner.rank} backbone with embeds",
+ input_ids=None,
+ positions=positions,
+ inputs_embeds=inputs_embeds,
+ intermediate_tensors=intermediate_tensors,
+ is_first_rank=is_first_rank,
+ is_last_rank=is_last_rank)

  def _precompile_select_from_array_helper(
  self,
@@ -354,7 +428,7 @@ class CompilationManager:
  self.runner.mesh, PartitionSpec(ShardingAxisName.ATTN_DATA, None))
  dp_size = self.runner.vllm_config.sharding_config.total_dp_size
  self._precompile_select_from_array_helper(
- name="select all logits",
+ name=f"worker{self.runner.rank} select all logits",
  source_paddings=self.runner.num_tokens_paddings,
  indices_paddings=index_paddings,
  hidden_dim=hsize,
@@ -365,7 +439,8 @@ class CompilationManager:
  if self.runner.speculative_config:
  vocab_size = self.runner.model_config.get_vocab_size()
  self._precompile_select_from_array_helper(
- name="select bonus tokens for spec decoding",
+ name=
+ f"worker{self.runner.rank} select bonus tokens for spec decoding",
  source_paddings=self.runner.num_logits_paddings,
  indices_paddings=self.runner.num_reqs_paddings,
  hidden_dim=vocab_size,
@@ -373,7 +448,8 @@ class CompilationManager:
  PartitionSpec(None, "model")),
  )
  self._precompile_select_from_array_helper(
- name="select target tokens for spec decoding",
+ name=
+ f"worker{self.runner.rank} select target tokens for spec decoding",
  source_paddings=self.runner.num_logits_paddings,
  indices_paddings=self.runner.num_logits_paddings,
  hidden_dim=vocab_size,
@@ -396,7 +472,7 @@ class CompilationManager:
  np.array([num_reqs], dtype=np.int32)):
  lora_metadata = self.runner.lora_utils.extract_lora_metadata()
  self._run_compilation(
- "compute_logits",
+ f"worker{self.runner.rank} compute_logits",
  self.runner.compute_logits_fn,
  self.runner.state,
  hidden_states,
@@ -410,43 +486,48 @@ class CompilationManager:
  for num_reqs in self.runner.num_reqs_paddings:
  logits_sharding = NamedSharding(
  self.runner.mesh,
- PartitionSpec(ShardingAxisName.ATTN_DATA, "model"))
+ PartitionSpec(ShardingAxisName.MLP_DATA,
+ ShardingAxisName.MLP_TENSOR))
  dp_size = self.runner.vllm_config.sharding_config.total_dp_size
  sampling_metadata_sharding = NamedSharding(
  self.runner.mesh, PartitionSpec(
- ShardingAxisName.ATTN_DATA)) if dp_size > 1 else None
+ ShardingAxisName.MLP_DATA)) if dp_size > 1 else None
  logits = self._create_dummy_tensor((num_reqs, hsize), jnp.bfloat16,
  logits_sharding)
  for do_sampling in (True, False):
- if do_sampling:
- temperature = np.full((num_reqs, ), 0.7, dtype=np.float32)
- top_k = np.full((num_reqs, ), 20, dtype=np.int32)
- top_p = np.full((num_reqs, ), 0.8, dtype=np.float32)
- (temperature, top_k,
- top_p) = device_array(self.runner.mesh,
- (temperature, top_k, top_p),
- sharding=sampling_metadata_sharding)
- else:
- temperature = None
- top_k = None
- top_p = None
-
- sampling_metadata = TPUSupportedSamplingMetadata(
- temperature=temperature,
- top_k=top_k,
- top_p=top_p,
- do_sampling=do_sampling,
- )
- self._run_compilation(
- "sample",
- sample,
- self.runner.rng_params_for_sampling,
- self.runner.mesh,
- logits,
- sampling_metadata,
- num_reqs=num_reqs,
- do_sampling=do_sampling,
- )
+ for logprobs in (True, False):
+ if do_sampling:
+ temperature = np.full((num_reqs, ),
+ 0.7,
+ dtype=np.float32)
+ top_k = np.full((num_reqs, ), 20, dtype=np.int32)
+ top_p = np.full((num_reqs, ), 0.8, dtype=np.float32)
+ (temperature, top_k, top_p) = device_array(
+ self.runner.mesh, (temperature, top_k, top_p),
+ sharding=sampling_metadata_sharding)
+ else:
+ temperature = None
+ top_k = None
+ top_p = None
+
+ sampling_metadata = TPUSupportedSamplingMetadata(
+ temperature=temperature,
+ top_k=top_k,
+ top_p=top_p,
+ do_sampling=do_sampling,
+ logprobs=logprobs)
+ self._run_compilation(
+ f"worker{self.runner.rank} sample",
+ sample,
+ self.runner.rng_params_for_sampling,
+ self.runner.mesh,
+ logits,
+ sampling_metadata,
+ num_reqs=num_reqs,
+ do_sampling=do_sampling,
+ )
+
+ self._sampling_precompiled = True

  def _precompile_disagg_utils(self) -> None:
  if not is_disagg_enabled():
@@ -476,10 +557,18 @@ class CompilationManager:
  logger.info("Compiling gather_logprobs with different input shapes.")
  hsize = self.runner.model_config.get_vocab_size()
  for num_reqs in self.runner.num_reqs_paddings:
- logits = self._create_dummy_tensor((num_reqs, hsize), jnp.bfloat16)
- token_ids = self._create_dummy_tensor((num_reqs, ), jnp.int32)
+ logits_sharding = NamedSharding(
+ self.runner.mesh,
+ PartitionSpec(ShardingAxisName.MLP_DATA,
+ ShardingAxisName.MLP_TENSOR))
+ token_ids_sharding = NamedSharding(
+ self.runner.mesh, PartitionSpec(ShardingAxisName.MLP_DATA, ))
+ logits = self._create_dummy_tensor((num_reqs, hsize), jnp.bfloat16,
+ logits_sharding)
+ token_ids = self._create_dummy_tensor((num_reqs, ), jnp.int32,
+ token_ids_sharding)
  self._run_compilation(
- "gather_logprobs",
+ f"worker{self.runner.rank} gather_logprobs",
  self.runner._compute_and_gather_logprobs,
  logits,
  token_ids,
@@ -487,6 +576,8 @@ class CompilationManager:
  num_reqs=num_reqs,
  )

+ self._gather_logprobs_precompiled = True
+

  def _precompile_speculative_decoding(self) -> None:
  "Compiling speculative_decoding with different input shapes.")
@@ -531,7 +622,7 @@ class CompilationManager:
  do_sampling=do_sampling)

  self._run_compilation(
- compilation_name,
+ f"worker{self.runner.rank} {compilation_name}",
  self.runner.rejection_sampler,
  draft_token_ids,
  num_draft_tokens,
@@ -601,6 +692,7 @@ class CompilationManager:
  self._run_compilation(
  "eagle3_get_draft_token_ids",
  self.runner.drafter._get_draft_token_ids,
+ self.runner.drafter.state,
  hidden_states,
  num_logits=num_logits,
  )
@@ -645,9 +737,9 @@ class CompilationManager:
  num_reqs,
  ):
  target_hidden_states, input_ids, last_token_indices, _ = self.runner.drafter._filter_token_and_prepare_initial_inputs(
- token_indices, query_start_loc, seq_lens, input_ids,
- aux_hidden_states, attention_metadata, next_token_ids,
- num_reqs)
+ self.runner.drafter.state, token_indices, query_start_loc,
+ seq_lens, input_ids, aux_hidden_states, attention_metadata,
+ next_token_ids, num_reqs)
  return target_hidden_states, input_ids, last_token_indices

  input_ids = self._create_dummy_tensor((num_tokens, ), jnp.int32)
@@ -724,6 +816,7 @@ class CompilationManager:
  self._run_compilation(
  "eagle3_prepare_hidden_states_and_input_ids",
  self.runner.drafter._prepare_hidden_states_and_input_ids,
+ self.runner.drafter.state,
  aux_hidden_states,
  query_start_loc,
  target_token_ids,
@@ -758,6 +851,7 @@ class CompilationManager:
  self._run_compilation(
  "eagle3_select_inputs_for_loop_speculation",
  self.runner.drafter._select_inputs_for_loop_speculation,
+ self.runner.drafter.state,
  positions,
  hidden_states,
  hidden_states,
@@ -768,6 +862,7 @@ class CompilationManager:
  self._run_compilation(
  "eagle3_select_draft_token_ids",
  self.runner.drafter._select_draft_token_ids,
+ self.runner.drafter.state,
  hidden_states,
  last_token_indices,
  num_tokens=num_tokens,
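
The hunks above appear to come from tpu_inference/runner/compilation_manager.py (the +158/-63 entry in the file list). Two changes stand out: environment flags are now read through the package's own tpu_inference.envs module (for example envs.SKIP_JAX_PRECOMPILE) instead of ad-hoc os.getenv calls, and the backbone precompilation paths become pipeline-parallel aware via JaxIntermediateTensors and the is_first_rank / is_last_rank flags. As a minimal sketch of why a typed accessor helps (the helper below is hypothetical, not the actual tpu_inference/envs.py implementation), recall that os.getenv returns a string, so the old os.getenv("SKIP_JAX_PRECOMPILE", False) pattern treats SKIP_JAX_PRECOMPILE=0 as truthy:

import os

def _env_bool(name: str, default: bool = False) -> bool:
    # Hypothetical helper: parse the variable explicitly instead of relying
    # on the truthiness of the raw string returned by os.getenv.
    raw = os.getenv(name)
    if raw is None:
        return default
    return raw.strip().lower() in ("1", "true", "yes", "on")

# A module-level flag in the spirit of the envs.SKIP_JAX_PRECOMPILE used above.
SKIP_JAX_PRECOMPILE = _env_bool("SKIP_JAX_PRECOMPILE")
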
@@ -1,3 +1,17 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
  from typing import Any, List

  import jax
@@ -7,6 +21,7 @@ from jax._src import dtypes
  from jax.sharding import Mesh, NamedSharding, PartitionSpec
  from torchax.ops.mappings import t2j_dtype

+ import tpu_inference.kernels.mla.v1.kernel as mla
  import tpu_inference.kernels.ragged_paged_attention.v3.kernel as rpa
  import tpu_inference.kernels.ragged_paged_attention.v3.kernel_hd64 as rpa_hd64
  from tpu_inference.layers.common.sharding import ShardingAxisName
@@ -17,9 +32,13 @@ logger = init_logger(__name__)
  DEFAULT_KV_CACHE_DTYPE = jnp.bfloat16


- def get_kv_cache_shape_with_mesh(mesh: Mesh, total_num_pages: int,
- page_size: int, actual_num_kv_heads: int,
- actual_head_dim: int, kv_dtype: any):
+ def get_kv_cache_shape_with_mesh(mesh: Mesh,
+ total_num_pages: int,
+ page_size: int,
+ actual_num_kv_heads: int,
+ actual_head_dim: int,
+ kv_dtype: any,
+ use_mla: bool = False):
  """Gets the KV cache shape based on the mesh configuration."""

  model_cnt = mesh.shape["model"]
@@ -28,15 +47,21 @@ def get_kv_cache_shape_with_mesh(mesh: Mesh, total_num_pages: int,
  # specific model, rather than being determined by the head_dim. If new
  # models are introduced with a head_dim of 64, this will require additional
  # model-specific adjustments.
- get_kv_cache_shape_fn = (
- rpa_hd64.get_kv_cache_shape if actual_head_dim == 64 \
- else rpa.get_kv_cache_shape
- )
- shape = list(
- get_kv_cache_shape_fn(total_num_pages, page_size,
- actual_num_kv_heads // model_cnt,
- actual_head_dim, kv_dtype))
- shape[2] *= model_cnt
+ if use_mla:
+ get_kv_cache_shape_fn = mla.get_kv_cache_shape
+ shape = list(
+ get_kv_cache_shape_fn(total_num_pages, page_size, actual_head_dim,
+ kv_dtype))
+ else:
+ get_kv_cache_shape_fn = (
+ rpa_hd64.get_kv_cache_shape if actual_head_dim == 64 \
+ else rpa.get_kv_cache_shape
+ )
+ shape = list(
+ get_kv_cache_shape_fn(total_num_pages, page_size,
+ actual_num_kv_heads // model_cnt,
+ actual_head_dim, kv_dtype))
+ shape[2] *= model_cnt
  return tuple(shape)


@@ -48,6 +73,7 @@ def create_kv_caches(
  mesh: Mesh,
  layer_names: List[str],
  cache_dtype: jnp.dtype = DEFAULT_KV_CACHE_DTYPE,
+ use_mla: bool = False,
  ) -> List[jax.Array]:
  """
  Creates a list of KV cache where each array mapps to single attention layer.
@@ -74,12 +100,16 @@

  cache_shape = get_kv_cache_shape_with_mesh(mesh, num_blocks, block_size,
  num_kv_heads, head_size,
- cache_dtype)
+ cache_dtype, use_mla)

- sharding = NamedSharding(
- mesh,
- PartitionSpec(ShardingAxisName.ATTN_DATA, None,
- ShardingAxisName.ATTN_HEAD))
+ if use_mla:
+ sharding = NamedSharding(mesh,
+ PartitionSpec(ShardingAxisName.MLP_TENSOR))
+ else:
+ sharding = NamedSharding(
+ mesh,
+ PartitionSpec(ShardingAxisName.ATTN_DATA, None,
+ ShardingAxisName.ATTN_HEAD))

  def _allocate() -> jax.Array:
  return jnp.empty(
@@ -94,7 +124,8 @@
  return kv_caches


- def get_rpa_page_size_bytes(mesh: Mesh, kv_cache_specs: dict[str, Any]) -> int:
+ def get_attention_page_size_bytes(mesh: Mesh,
+ kv_cache_specs: dict[str, Any]) -> int:
  """
  Calculate KV cache page size of RPA kernel.

@@ -107,14 +138,16 @@ def get_rpa_page_size_bytes(mesh: Mesh, kv_cache_specs: dict[str, Any]) -> int:
  """

  # Import it here to avoid circular import.
- from vllm.v1.kv_cache_interface import AttentionSpec
+ from vllm.v1.kv_cache_interface import AttentionSpec, MLAAttentionSpec

  page_size_bytes_set = set()
  for kv_cache_spec in kv_cache_specs.values():
  assert isinstance(kv_cache_spec, AttentionSpec)

  dtype = t2j_dtype(kv_cache_spec.dtype)
- bits = dtypes.bit_width(dtype)
+ bits = (dtypes.bit_width(dtype) if hasattr(dtypes, "bit_width") else
+ dtypes.itemsize_bits(dtype))
+ use_mla = isinstance(kv_cache_spec, MLAAttentionSpec)

  kv_cache_shape = get_kv_cache_shape_with_mesh(
  mesh=mesh,
@@ -123,6 +156,7 @@ def get_rpa_page_size_bytes(mesh: Mesh, kv_cache_specs: dict[str, Any]) -> int:
  actual_num_kv_heads=kv_cache_spec.num_kv_heads,
  actual_head_dim=kv_cache_spec.head_size,
  kv_dtype=dtype,
+ use_mla=use_mla,
  )
  page_size_bytes = (bits * np.prod(kv_cache_shape)) // 8
  page_size_bytes_set.add(page_size_bytes)
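
The hunks above appear to belong to tpu_inference/runner/kv_cache.py (the +54/-20 entry in the file list): a use_mla flag is threaded through get_kv_cache_shape_with_mesh and create_kv_caches, get_rpa_page_size_bytes is renamed to get_attention_page_size_bytes, and caches described by an MLAAttentionSpec now pick the MLA kernel's cache layout instead of the ragged-paged-attention one. The sketch below mirrors only the dispatch logic visible in the diff; the two stub shape functions are hypothetical placeholders for mla.get_kv_cache_shape and rpa.get_kv_cache_shape, whose real layouts are defined by the kernels themselves.

from typing import Tuple

def _mla_shape_stub(pages: int, page_size: int, head_dim: int) -> Tuple[int, ...]:
    # Placeholder: MLA caches a single latent vector per token, so there is
    # no per-KV-head axis to size.
    return (pages, page_size, head_dim)

def _rpa_shape_stub(pages: int, page_size: int, kv_heads: int,
                    head_dim: int) -> Tuple[int, ...]:
    # Placeholder: one combined K/V head axis, sized here as 2 * kv_heads.
    return (pages, page_size, 2 * kv_heads, head_dim)

def kv_cache_shape(model_cnt: int, pages: int, page_size: int, kv_heads: int,
                   head_dim: int, use_mla: bool = False) -> Tuple[int, ...]:
    if use_mla:
        # MLA path: the shape does not depend on the number of KV heads.
        return _mla_shape_stub(pages, page_size, head_dim)
    # RPA path: size the head axis per "model" shard, then scale it back up,
    # mirroring the shape[2] *= model_cnt step in the diff above.
    shape = list(_rpa_shape_stub(pages, page_size, kv_heads // model_cnt, head_dim))
    shape[2] *= model_cnt
    return tuple(shape)

print(kv_cache_shape(model_cnt=4, pages=128, page_size=16, kv_heads=8, head_dim=128))
print(kv_cache_shape(model_cnt=4, pages=128, page_size=16, kv_heads=8, head_dim=576, use_mla=True))
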