tpu-inference 0.11.1.dev202511220812__py3-none-any.whl → 0.13.2.dev20251230__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of tpu-inference might be problematic. Click here for more details.

Files changed (257)
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_disagg_utils.py +14 -0
  4. tests/core/test_dp_scheduler.py +650 -768
  5. tests/core/test_init.py +14 -0
  6. tests/distributed/__init__.py +13 -0
  7. tests/distributed/test_distributed_utils.py +120 -0
  8. tests/distributed/test_tpu_connector.py +478 -0
  9. tests/e2e/__init__.py +13 -0
  10. tests/e2e/test_async_scheduler.py +211 -0
  11. tests/e2e/test_data_parallel.py +289 -0
  12. tests/e2e/test_hybrid_kvcache.py +219 -0
  13. tests/e2e/test_local_disagg.py +257 -0
  14. tests/e2e/test_model_loader.py +268 -0
  15. tests/e2e/test_multi_modal_inference.py +111 -0
  16. tests/e2e/test_pipeline_parallel.py +265 -0
  17. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  18. tests/e2e/test_sampling_params.py +269 -0
  19. tests/e2e/test_speculative_decoding.py +311 -0
  20. tests/e2e/test_structured_decoding.py +46 -0
  21. tests/executors/__init__.py +13 -0
  22. tests/executors/test_ray_distributed_executor.py +199 -0
  23. tests/experimental/__init__.py +13 -0
  24. tests/experimental/test_llama3_jax_stashed.py +208 -0
  25. tests/kernels/__init__.py +13 -0
  26. tests/kernels/collectives/__init__.py +13 -0
  27. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  28. tests/kernels/fused_moe_v1_test.py +317 -34
  29. tests/kernels/gmm_test.py +205 -0
  30. tests/kernels/mla_v1_test.py +143 -41
  31. tests/kernels/quantized_matmul_kernel_test.py +2 -34
  32. tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
  33. tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
  34. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +17 -1
  35. tests/kernels/ragged_paged_attention_kernel_v3_test.py +17 -1
  36. tests/layers/__init__.py +13 -0
  37. tests/layers/common/__init__.py +13 -0
  38. tests/layers/common/test_attention_interface.py +156 -0
  39. tests/layers/common/test_quantization.py +149 -0
  40. tests/layers/jax/__init__.py +13 -0
  41. tests/layers/jax/attention/__init__.py +13 -0
  42. tests/layers/jax/attention/test_common_attention.py +103 -0
  43. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  44. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  45. tests/layers/jax/moe/__init__.py +13 -0
  46. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  47. tests/layers/jax/sample/__init__.py +13 -0
  48. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  49. tests/layers/jax/sample/test_sampling.py +115 -0
  50. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  51. tests/layers/jax/test_layers.py +155 -0
  52. tests/{test_quantization.py → layers/jax/test_qwix.py} +183 -50
  53. tests/layers/jax/test_rope.py +93 -0
  54. tests/layers/jax/test_sharding.py +159 -0
  55. tests/layers/jax/test_transformer_block.py +152 -0
  56. tests/layers/vllm/__init__.py +13 -0
  57. tests/layers/vllm/test_attention.py +363 -0
  58. tests/layers/vllm/test_awq.py +406 -0
  59. tests/layers/vllm/test_compressed_tensors_moe.py +199 -0
  60. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +441 -0
  61. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +443 -0
  62. tests/layers/vllm/test_fp8.py +17 -0
  63. tests/layers/vllm/test_mxfp4.py +320 -0
  64. tests/layers/vllm/test_unquantized.py +662 -0
  65. tests/layers/vllm/utils.py +87 -0
  66. tests/lora/__init__.py +13 -0
  67. tests/lora/conftest.py +14 -0
  68. tests/lora/test_bgmv.py +14 -0
  69. tests/lora/test_layers.py +26 -6
  70. tests/lora/test_lora.py +15 -1
  71. tests/lora/test_lora_perf.py +67 -0
  72. tests/models/__init__.py +13 -0
  73. tests/models/common/__init__.py +13 -0
  74. tests/models/common/test_model_loader.py +455 -0
  75. tests/models/jax/__init__.py +13 -0
  76. tests/models/jax/test_deepseek_v3.py +401 -0
  77. tests/models/jax/test_llama3.py +184 -0
  78. tests/models/jax/test_llama4.py +298 -0
  79. tests/models/jax/test_llama_eagle3.py +197 -0
  80. tests/models/jax/test_llama_guard_4.py +242 -0
  81. tests/models/jax/test_qwen2.py +172 -0
  82. tests/models/jax/test_qwen2_5_vl.py +605 -0
  83. tests/models/jax/test_qwen3.py +169 -0
  84. tests/models/jax/test_weight_loading.py +180 -0
  85. tests/models/jax/utils/__init__.py +13 -0
  86. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  87. tests/platforms/__init__.py +13 -0
  88. tests/platforms/test_tpu_platform.py +54 -0
  89. tests/runner/__init__.py +13 -0
  90. tests/runner/test_block_table.py +395 -0
  91. tests/runner/test_input_batch.py +226 -0
  92. tests/runner/test_kv_cache.py +220 -0
  93. tests/runner/test_kv_cache_manager.py +498 -0
  94. tests/runner/test_multimodal_manager.py +429 -0
  95. tests/runner/test_persistent_batch_manager.py +84 -0
  96. tests/runner/test_speculative_decoding_manager.py +368 -0
  97. tests/runner/test_structured_decoding_manager.py +220 -0
  98. tests/runner/test_tpu_runner.py +261 -0
  99. tests/runner/test_tpu_runner_dp.py +1099 -0
  100. tests/runner/test_tpu_runner_mesh.py +200 -0
  101. tests/runner/test_utils.py +411 -0
  102. tests/spec_decode/__init__.py +13 -0
  103. tests/spec_decode/test_eagle3.py +311 -0
  104. tests/test_base.py +14 -0
  105. tests/test_envs.py +110 -12
  106. tests/test_tpu_info.py +14 -0
  107. tests/test_utils.py +2 -45
  108. tests/worker/__init__.py +13 -0
  109. tests/worker/tpu_worker_test.py +414 -0
  110. tpu_inference/__init__.py +14 -0
  111. tpu_inference/core/__init__.py +13 -0
  112. tpu_inference/core/sched/__init__.py +13 -0
  113. tpu_inference/core/sched/dp_scheduler.py +372 -56
  114. tpu_inference/distributed/__init__.py +13 -0
  115. tpu_inference/distributed/jax_parallel_state.py +14 -0
  116. tpu_inference/distributed/tpu_connector.py +15 -10
  117. tpu_inference/distributed/utils.py +56 -4
  118. tpu_inference/envs.py +92 -8
  119. tpu_inference/executors/__init__.py +13 -0
  120. tpu_inference/executors/ray_distributed_executor.py +25 -4
  121. tpu_inference/experimental/__init__.py +13 -0
  122. tpu_inference/experimental/llama3_jax_stashed.py +14 -0
  123. tpu_inference/kernels/__init__.py +13 -0
  124. tpu_inference/kernels/collectives/__init__.py +13 -0
  125. tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
  126. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
  127. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  128. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  129. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  130. tpu_inference/kernels/fused_moe/v1/kernel.py +807 -230
  131. tpu_inference/kernels/megablox/__init__.py +13 -0
  132. tpu_inference/kernels/megablox/common.py +54 -0
  133. tpu_inference/kernels/megablox/gmm.py +646 -0
  134. tpu_inference/kernels/mla/__init__.py +13 -0
  135. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  136. tpu_inference/kernels/mla/v1/kernel.py +117 -145
  137. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  138. tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
  139. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  140. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  141. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
  142. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
  143. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  144. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +194 -101
  145. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +218 -137
  146. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
  147. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +376 -195
  148. tpu_inference/kernels/ragged_paged_attention/v3/util.py +15 -1
  149. tpu_inference/layers/__init__.py +13 -0
  150. tpu_inference/layers/common/__init__.py +13 -0
  151. tpu_inference/layers/common/attention_interface.py +25 -12
  152. tpu_inference/layers/common/attention_metadata.py +14 -0
  153. tpu_inference/layers/common/fused_moe_gmm.py +506 -0
  154. tpu_inference/layers/common/quant_methods.py +15 -0
  155. tpu_inference/layers/common/quantization.py +282 -0
  156. tpu_inference/layers/common/sharding.py +32 -9
  157. tpu_inference/layers/common/utils.py +94 -0
  158. tpu_inference/layers/jax/__init__.py +13 -0
  159. tpu_inference/layers/jax/attention/__init__.py +13 -0
  160. tpu_inference/layers/jax/attention/attention.py +19 -6
  161. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +270 -77
  162. tpu_inference/layers/jax/attention/gpt_oss_attention.py +24 -11
  163. tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
  164. tpu_inference/layers/jax/base.py +14 -0
  165. tpu_inference/layers/jax/constants.py +13 -0
  166. tpu_inference/layers/jax/layers.py +14 -0
  167. tpu_inference/layers/jax/misc.py +14 -0
  168. tpu_inference/layers/jax/moe/__init__.py +13 -0
  169. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
  170. tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
  171. tpu_inference/layers/jax/moe/moe.py +43 -3
  172. tpu_inference/layers/jax/pp_utils.py +53 -0
  173. tpu_inference/layers/jax/rope.py +14 -0
  174. tpu_inference/layers/jax/rope_interface.py +14 -0
  175. tpu_inference/layers/jax/sample/__init__.py +13 -0
  176. tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
  177. tpu_inference/layers/jax/sample/sampling.py +15 -1
  178. tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
  179. tpu_inference/layers/jax/transformer_block.py +14 -0
  180. tpu_inference/layers/vllm/__init__.py +13 -0
  181. tpu_inference/layers/vllm/attention.py +4 -4
  182. tpu_inference/layers/vllm/fused_moe.py +101 -494
  183. tpu_inference/layers/vllm/linear.py +64 -0
  184. tpu_inference/layers/vllm/process_weights/__init__.py +13 -0
  185. tpu_inference/layers/vllm/{sharding.py → process_weights/cleanup_sharding.py} +24 -15
  186. tpu_inference/layers/vllm/process_weights/fused_moe_weights.py +369 -0
  187. tpu_inference/layers/vllm/process_weights/linear_weights.py +174 -0
  188. tpu_inference/layers/vllm/quantization/__init__.py +19 -3
  189. tpu_inference/layers/vllm/quantization/awq.py +96 -82
  190. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  191. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +23 -8
  192. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +172 -176
  193. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  194. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +111 -91
  195. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +79 -43
  196. tpu_inference/layers/vllm/quantization/{common.py → configs.py} +42 -25
  197. tpu_inference/layers/vllm/quantization/fp8.py +119 -0
  198. tpu_inference/layers/vllm/quantization/mxfp4.py +137 -178
  199. tpu_inference/layers/vllm/quantization/unquantized.py +157 -233
  200. tpu_inference/lora/__init__.py +13 -0
  201. tpu_inference/lora/torch_lora_ops.py +8 -13
  202. tpu_inference/models/__init__.py +13 -0
  203. tpu_inference/models/common/__init__.py +13 -0
  204. tpu_inference/models/common/model_loader.py +112 -35
  205. tpu_inference/models/jax/__init__.py +13 -0
  206. tpu_inference/models/jax/deepseek_v3.py +267 -157
  207. tpu_inference/models/jax/gpt_oss.py +26 -10
  208. tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
  209. tpu_inference/models/jax/llama3.py +99 -36
  210. tpu_inference/models/jax/llama4.py +14 -0
  211. tpu_inference/models/jax/llama_eagle3.py +18 -5
  212. tpu_inference/models/jax/llama_guard_4.py +15 -1
  213. tpu_inference/models/jax/qwen2.py +17 -2
  214. tpu_inference/models/jax/qwen2_5_vl.py +179 -51
  215. tpu_inference/models/jax/qwen3.py +17 -2
  216. tpu_inference/models/jax/utils/__init__.py +13 -0
  217. tpu_inference/models/jax/utils/file_utils.py +14 -0
  218. tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
  219. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  220. tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +92 -32
  221. tpu_inference/models/jax/utils/weight_utils.py +234 -155
  222. tpu_inference/models/vllm/__init__.py +13 -0
  223. tpu_inference/models/vllm/vllm_model_wrapper.py +32 -8
  224. tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
  225. tpu_inference/platforms/__init__.py +14 -0
  226. tpu_inference/platforms/tpu_platform.py +51 -72
  227. tpu_inference/runner/__init__.py +13 -0
  228. tpu_inference/runner/compilation_manager.py +180 -80
  229. tpu_inference/runner/kv_cache.py +54 -20
  230. tpu_inference/runner/kv_cache_manager.py +55 -33
  231. tpu_inference/runner/lora_utils.py +16 -1
  232. tpu_inference/runner/multimodal_manager.py +16 -2
  233. tpu_inference/runner/persistent_batch_manager.py +54 -2
  234. tpu_inference/runner/speculative_decoding_manager.py +14 -0
  235. tpu_inference/runner/structured_decoding_manager.py +16 -3
  236. tpu_inference/runner/tpu_runner.py +124 -61
  237. tpu_inference/runner/utils.py +2 -2
  238. tpu_inference/spec_decode/__init__.py +13 -0
  239. tpu_inference/spec_decode/jax/__init__.py +13 -0
  240. tpu_inference/spec_decode/jax/eagle3.py +84 -22
  241. tpu_inference/tpu_info.py +14 -0
  242. tpu_inference/utils.py +72 -44
  243. tpu_inference/worker/__init__.py +13 -0
  244. tpu_inference/worker/tpu_worker.py +66 -52
  245. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/METADATA +8 -9
  246. tpu_inference-0.13.2.dev20251230.dist-info/RECORD +266 -0
  247. tpu_inference/layers/vllm/linear_common.py +0 -186
  248. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  249. tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
  250. tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
  251. tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
  252. tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
  253. tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
  254. tpu_inference-0.11.1.dev202511220812.dist-info/RECORD +0 -174
  255. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/WHEEL +0 -0
  256. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/licenses/LICENSE +0 -0
  257. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/top_level.txt +0 -0
@@ -1,13 +1,27 @@
1
- import os
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
2
15
  import time
3
16
  from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple
4
17
 
5
18
  import jax
6
19
  import jax.numpy as jnp
7
20
  import numpy as np
8
- import vllm.envs as envs
21
+ import vllm.envs as vllm_envs
9
22
  from jax.sharding import NamedSharding, PartitionSpec
10
23
 
24
+ import tpu_inference.envs as envs
11
25
  from tpu_inference.core.disagg_utils import is_disagg_enabled
12
26
  from tpu_inference.layers.common.attention_metadata import AttentionMetadata
13
27
  from tpu_inference.layers.common.sharding import ShardingAxisName
@@ -15,6 +29,8 @@ from tpu_inference.layers.jax.sample.sampling import sample
15
29
  from tpu_inference.layers.jax.sample.sampling_metadata import \
16
30
  TPUSupportedSamplingMetadata
17
31
  from tpu_inference.logger import init_logger
32
+ from tpu_inference.models.jax.jax_intermediate_tensor import \
33
+ JaxIntermediateTensors
18
34
  from tpu_inference.utils import device_array
19
35
 
20
36
  if TYPE_CHECKING:
@@ -30,10 +46,12 @@ class CompilationManager:
30
46
 
31
47
  def __init__(self, runner: "TPUModelRunner"):
32
48
  self.runner = runner
33
- if not envs.VLLM_DISABLE_COMPILE_CACHE:
49
+ self._sampling_precompiled = False
50
+ self._gather_logprobs_precompiled = False
51
+ if not vllm_envs.VLLM_DISABLE_COMPILE_CACHE:
34
52
  logger.info("Enabling JAX compile cache.")
35
53
  jax.config.update("jax_compilation_cache_dir",
36
- envs.VLLM_XLA_CACHE_PATH)
54
+ vllm_envs.VLLM_XLA_CACHE_PATH)
37
55
 
38
56
  def _create_dummy_tensor(self,
39
57
  shape: Tuple[int, ...],
@@ -67,8 +85,7 @@ class CompilationManager:
67
85
  logger.info("Compilation finished in %.2f [secs].", end - start)
68
86
 
69
87
  def capture_model(self) -> None:
70
- if os.getenv("SKIP_JAX_PRECOMPILE",
71
- False) or self.runner.model_config.enforce_eager:
88
+ if envs.SKIP_JAX_PRECOMPILE or self.runner.model_config.enforce_eager:
72
89
  return
73
90
  logger.info("Precompile all the subgraphs with possible input shapes.")
74
91
 
@@ -81,11 +98,17 @@ class CompilationManager:
81
98
  self._precompile_backbone_with_inputs_embeds()
82
99
  if self.runner.scheduler_config.async_scheduling:
83
100
  self._precompile_substitute_placeholder_token()
101
+ if not self.runner.is_last_rank:
102
+ return
84
103
  self._precompile_select_from_array()
85
104
  self._precompile_compute_logits()
105
+ # Skip sampling if already precompiled before KV cache allocation
106
+ if not self._sampling_precompiled:
107
+ self._precompile_sampling()
86
108
  self._precompile_disagg_utils()
87
- self._precompile_sampling()
88
- self._precompile_gather_logprobs()
109
+ # Skip gather_logprobs if already precompiled before KV cache allocation
110
+ if not self._gather_logprobs_precompiled:
111
+ self._precompile_gather_logprobs()
89
112
  self._precompile_structured_decoding()
90
113
  if self.runner.speculative_config:
91
114
  self._precompile_speculative_decoding()
@@ -104,7 +127,7 @@ class CompilationManager:
104
127
 
105
128
  self._run_compilation(
106
129
  "input_embeddings_merger",
107
- self.runner.get_input_embeddings_fn,
130
+ self.runner.embed_input_ids_fn,
108
131
  self.runner.state,
109
132
  dummy_input_ids,
110
133
  dummy_multimodal_embeddings,
@@ -113,15 +136,22 @@ class CompilationManager:
113
136
 
114
137
  self._run_compilation(
115
138
  "input_embeddings_merger_text_only",
116
- self.runner.get_input_embeddings_fn,
139
+ self.runner.embed_input_ids_fn,
117
140
  self.runner.state,
118
141
  dummy_input_ids,
119
142
  None,
120
143
  num_tokens=num_tokens,
121
144
  )
122
145
 
123
- def _precompile_backbone_helper(self, name, *, input_ids, positions,
124
- inputs_embeds) -> None:
146
+ def _precompile_backbone_helper(self,
147
+ name,
148
+ *,
149
+ input_ids,
150
+ positions,
151
+ inputs_embeds,
152
+ intermediate_tensors=None,
153
+ is_first_rank=True,
154
+ is_last_rank=True) -> None:
125
155
  num_tokens = None
126
156
  if input_ids is not None:
127
157
  num_tokens = input_ids.shape[0]
@@ -181,10 +211,14 @@ class CompilationManager:
181
211
  inputs_embeds,
182
212
  layer_name_to_kvcache_index,
183
213
  lora_metadata,
214
+ intermediate_tensors,
215
+ is_first_rank,
216
+ is_last_rank,
184
217
  ):
185
218
  kv_caches, hidden_states, _ = self.runner.model_fn(
186
219
  state, kv_caches, input_ids, attention_metadata, inputs_embeds,
187
- positions, layer_name_to_kvcache_index, lora_metadata)
220
+ positions, layer_name_to_kvcache_index, lora_metadata,
221
+ intermediate_tensors, is_first_rank, is_last_rank)
188
222
  self.runner.kv_caches = kv_caches
189
223
  return hidden_states
190
224
 
@@ -207,6 +241,9 @@ class CompilationManager:
207
241
  inputs_embeds,
208
242
  tuple(self.runner.layer_name_to_kvcache_index.items()),
209
243
  lora_metadata,
244
+ intermediate_tensors,
245
+ is_first_rank,
246
+ is_last_rank,
210
247
  num_tokens=num_tokens,
211
248
  )
212
249
 
@@ -257,6 +294,7 @@ class CompilationManager:
257
294
  )
258
295
 
259
296
  def _precompile_backbone_text_only(self) -> None:
297
+ hidden_size = self.runner.model_config.get_hidden_size()
260
298
  for num_tokens in self.runner.num_tokens_paddings:
261
299
  dp_sharding = NamedSharding(
262
300
  self.runner.mesh, PartitionSpec(ShardingAxisName.ATTN_DATA, )
@@ -266,10 +304,28 @@ class CompilationManager:
266
304
  dp_sharding)
267
305
  positions = self._create_dummy_tensor((num_tokens, ), jnp.int32,
268
306
  dp_sharding)
269
- self._precompile_backbone_helper("backbone",
270
- input_ids=input_ids,
271
- positions=positions,
272
- inputs_embeds=None)
307
+ is_first_rank = self.runner.is_first_rank
308
+ is_last_rank = self.runner.is_last_rank
309
+ if is_first_rank:
310
+ intermediate_tensors = None
311
+ else:
312
+ hidden_states = self._create_dummy_tensor(
313
+ (num_tokens, hidden_size), jnp.bfloat16)
314
+ residual = self._create_dummy_tensor((num_tokens, hidden_size),
315
+ jnp.bfloat16)
316
+ intermediate_tensors = JaxIntermediateTensors(
317
+ tensors={
318
+ "hidden_states": hidden_states,
319
+ "residual": residual
320
+ })
321
+ self._precompile_backbone_helper(
322
+ f"worker{self.runner.rank} backbone",
323
+ input_ids=input_ids,
324
+ positions=positions,
325
+ inputs_embeds=None,
326
+ intermediate_tensors=intermediate_tensors,
327
+ is_first_rank=is_first_rank,
328
+ is_last_rank=is_last_rank)
273
329
 
274
330
  def _precompile_backbone_with_inputs_embeds(self) -> None:
275
331
  hidden_size = self.runner.model_config.get_hidden_size()
@@ -283,10 +339,28 @@ class CompilationManager:
283
339
  else:
284
340
  positions = self._create_dummy_tensor((num_tokens, ),
285
341
  jnp.int32)
286
- self._precompile_backbone_helper("backbone with embeds",
287
- input_ids=None,
288
- positions=positions,
289
- inputs_embeds=inputs_embeds)
342
+ is_first_rank = self.runner.is_first_rank
343
+ is_last_rank = self.runner.is_last_rank
344
+ if not is_first_rank:
345
+ hidden_states = self._create_dummy_tensor(
346
+ (num_tokens, hidden_size), jnp.bfloat16)
347
+ residual = self._create_dummy_tensor((num_tokens, hidden_size),
348
+ jnp.bfloat16)
349
+ intermediate_tensors = JaxIntermediateTensors(
350
+ tensors={
351
+ "hidden_states": hidden_states,
352
+ "residual": residual
353
+ })
354
+ else:
355
+ intermediate_tensors = None
356
+ self._precompile_backbone_helper(
357
+ f"worker{self.runner.rank} backbone with embeds",
358
+ input_ids=None,
359
+ positions=positions,
360
+ inputs_embeds=inputs_embeds,
361
+ intermediate_tensors=intermediate_tensors,
362
+ is_first_rank=is_first_rank,
363
+ is_last_rank=is_last_rank)
290
364
 
291
365
  def _precompile_select_from_array_helper(
292
366
  self,
@@ -354,7 +428,7 @@ class CompilationManager:
354
428
  self.runner.mesh, PartitionSpec(ShardingAxisName.ATTN_DATA, None))
355
429
  dp_size = self.runner.vllm_config.sharding_config.total_dp_size
356
430
  self._precompile_select_from_array_helper(
357
- name="select all logits",
431
+ name=f"worker{self.runner.rank} select all logits",
358
432
  source_paddings=self.runner.num_tokens_paddings,
359
433
  indices_paddings=index_paddings,
360
434
  hidden_dim=hsize,
@@ -365,7 +439,8 @@ class CompilationManager:
365
439
  if self.runner.speculative_config:
366
440
  vocab_size = self.runner.model_config.get_vocab_size()
367
441
  self._precompile_select_from_array_helper(
368
- name="select bonus tokens for spec decoding",
442
+ name=
443
+ f"worker{self.runner.rank} select bonus tokens for spec decoding",
369
444
  source_paddings=self.runner.num_logits_paddings,
370
445
  indices_paddings=self.runner.num_reqs_paddings,
371
446
  hidden_dim=vocab_size,
@@ -373,7 +448,8 @@ class CompilationManager:
373
448
  PartitionSpec(None, "model")),
374
449
  )
375
450
  self._precompile_select_from_array_helper(
376
- name="select target tokens for spec decoding",
451
+ name=
452
+ f"worker{self.runner.rank} select target tokens for spec decoding",
377
453
  source_paddings=self.runner.num_logits_paddings,
378
454
  indices_paddings=self.runner.num_logits_paddings,
379
455
  hidden_dim=vocab_size,
@@ -396,7 +472,7 @@ class CompilationManager:
396
472
  np.array([num_reqs], dtype=np.int32)):
397
473
  lora_metadata = self.runner.lora_utils.extract_lora_metadata()
398
474
  self._run_compilation(
399
- "compute_logits",
475
+ f"worker{self.runner.rank} compute_logits",
400
476
  self.runner.compute_logits_fn,
401
477
  self.runner.state,
402
478
  hidden_states,
@@ -410,43 +486,48 @@ class CompilationManager:
410
486
  for num_reqs in self.runner.num_reqs_paddings:
411
487
  logits_sharding = NamedSharding(
412
488
  self.runner.mesh,
413
- PartitionSpec(ShardingAxisName.ATTN_DATA, "model"))
489
+ PartitionSpec(ShardingAxisName.MLP_DATA,
490
+ ShardingAxisName.MLP_TENSOR))
414
491
  dp_size = self.runner.vllm_config.sharding_config.total_dp_size
415
492
  sampling_metadata_sharding = NamedSharding(
416
493
  self.runner.mesh, PartitionSpec(
417
- ShardingAxisName.ATTN_DATA)) if dp_size > 1 else None
494
+ ShardingAxisName.MLP_DATA)) if dp_size > 1 else None
418
495
  logits = self._create_dummy_tensor((num_reqs, hsize), jnp.bfloat16,
419
496
  logits_sharding)
420
497
  for do_sampling in (True, False):
421
- if do_sampling:
422
- temperature = np.full((num_reqs, ), 0.7, dtype=np.float32)
423
- top_k = np.full((num_reqs, ), 20, dtype=np.int32)
424
- top_p = np.full((num_reqs, ), 0.8, dtype=np.float32)
425
- (temperature, top_k,
426
- top_p) = device_array(self.runner.mesh,
427
- (temperature, top_k, top_p),
428
- sharding=sampling_metadata_sharding)
429
- else:
430
- temperature = None
431
- top_k = None
432
- top_p = None
433
-
434
- sampling_metadata = TPUSupportedSamplingMetadata(
435
- temperature=temperature,
436
- top_k=top_k,
437
- top_p=top_p,
438
- do_sampling=do_sampling,
439
- )
440
- self._run_compilation(
441
- "sample",
442
- sample,
443
- self.runner.rng_params_for_sampling,
444
- self.runner.mesh,
445
- logits,
446
- sampling_metadata,
447
- num_reqs=num_reqs,
448
- do_sampling=do_sampling,
449
- )
498
+ for logprobs in (True, False):
499
+ if do_sampling:
500
+ temperature = np.full((num_reqs, ),
501
+ 0.7,
502
+ dtype=np.float32)
503
+ top_k = np.full((num_reqs, ), 20, dtype=np.int32)
504
+ top_p = np.full((num_reqs, ), 0.8, dtype=np.float32)
505
+ (temperature, top_k, top_p) = device_array(
506
+ self.runner.mesh, (temperature, top_k, top_p),
507
+ sharding=sampling_metadata_sharding)
508
+ else:
509
+ temperature = None
510
+ top_k = None
511
+ top_p = None
512
+
513
+ sampling_metadata = TPUSupportedSamplingMetadata(
514
+ temperature=temperature,
515
+ top_k=top_k,
516
+ top_p=top_p,
517
+ do_sampling=do_sampling,
518
+ logprobs=logprobs)
519
+ self._run_compilation(
520
+ f"worker{self.runner.rank} sample",
521
+ sample,
522
+ self.runner.rng_params_for_sampling,
523
+ self.runner.mesh,
524
+ logits,
525
+ sampling_metadata,
526
+ num_reqs=num_reqs,
527
+ do_sampling=do_sampling,
528
+ )
529
+
530
+ self._sampling_precompiled = True
450
531
 
451
532
  def _precompile_disagg_utils(self) -> None:
452
533
  if not is_disagg_enabled():
@@ -476,10 +557,18 @@ class CompilationManager:
476
557
  logger.info("Compiling gather_logprobs with different input shapes.")
477
558
  hsize = self.runner.model_config.get_vocab_size()
478
559
  for num_reqs in self.runner.num_reqs_paddings:
479
- logits = self._create_dummy_tensor((num_reqs, hsize), jnp.bfloat16)
480
- token_ids = self._create_dummy_tensor((num_reqs, ), jnp.int32)
560
+ logits_sharding = NamedSharding(
561
+ self.runner.mesh,
562
+ PartitionSpec(ShardingAxisName.MLP_DATA,
563
+ ShardingAxisName.MLP_TENSOR))
564
+ token_ids_sharding = NamedSharding(
565
+ self.runner.mesh, PartitionSpec(ShardingAxisName.MLP_DATA, ))
566
+ logits = self._create_dummy_tensor((num_reqs, hsize), jnp.bfloat16,
567
+ logits_sharding)
568
+ token_ids = self._create_dummy_tensor((num_reqs, ), jnp.int32,
569
+ token_ids_sharding)
481
570
  self._run_compilation(
482
- "gather_logprobs",
571
+ f"worker{self.runner.rank} gather_logprobs",
483
572
  self.runner._compute_and_gather_logprobs,
484
573
  logits,
485
574
  token_ids,
@@ -487,6 +576,8 @@ class CompilationManager:
487
576
  num_reqs=num_reqs,
488
577
  )
489
578
 
579
+ self._gather_logprobs_precompiled = True
580
+
490
581
  def _precompile_speculative_decoding(self) -> None:
491
582
  logger.info(
492
583
  "Compiling speculative_decoding with different input shapes.")
@@ -531,7 +622,7 @@ class CompilationManager:
531
622
  do_sampling=do_sampling)
532
623
 
533
624
  self._run_compilation(
534
- compilation_name,
625
+ f"worker{self.runner.rank} {compilation_name}",
535
626
  self.runner.rejection_sampler,
536
627
  draft_token_ids,
537
628
  num_draft_tokens,
@@ -548,7 +639,9 @@ class CompilationManager:
548
639
  def _precompile_eagle3_helpers(self) -> None:
549
640
  logger.info(
550
641
  "Compiling eagle3 jitted helpers with different input shapes.")
551
- hidden_size = self.runner.model_config.get_hidden_size()
642
+ target_hidden_size = self.runner.model_config.get_hidden_size()
643
+ draft_hidden_size = self.runner.speculative_config.draft_model_config.get_hidden_size(
644
+ )
552
645
  dtype = self.runner.model_config.dtype
553
646
 
554
647
  num_kv_cache_groups = len(self.runner.kv_cache_config.kv_cache_groups)
@@ -595,10 +688,11 @@ class CompilationManager:
595
688
 
596
689
  for num_logits in self.runner.num_logits_paddings:
597
690
  hidden_states = self._create_dummy_tensor(
598
- (num_logits, hidden_size), jnp.bfloat16)
691
+ (num_logits, draft_hidden_size), jnp.bfloat16)
599
692
  self._run_compilation(
600
693
  "eagle3_get_draft_token_ids",
601
694
  self.runner.drafter._get_draft_token_ids,
695
+ self.runner.drafter.state,
602
696
  hidden_states,
603
697
  num_logits=num_logits,
604
698
  )
@@ -606,8 +700,8 @@ class CompilationManager:
606
700
  input_ids_loop = self._create_dummy_tensor(
607
701
  (self.runner.max_num_reqs, ), jnp.int32,
608
702
  NamedSharding(self.runner.mesh, PartitionSpec()))
609
- target_hidden_state_loop = self._create_dummy_tensor(
610
- (self.runner.max_num_reqs, hidden_size), dtype,
703
+ draft_hidden_state_loop = self._create_dummy_tensor(
704
+ (self.runner.max_num_reqs, draft_hidden_size), dtype,
611
705
  NamedSharding(self.runner.mesh, PartitionSpec(None, None)))
612
706
  next_token_ids = self._create_dummy_tensor(
613
707
  (self.runner.max_num_reqs, ), jnp.int32)
@@ -615,9 +709,12 @@ class CompilationManager:
615
709
  (self.runner.max_num_reqs, ), jnp.int32)
616
710
  for num_tokens in self.runner.num_tokens_paddings:
617
711
  aux_hidden_states = [
618
- self._create_dummy_tensor((num_tokens, hidden_size), dtype),
619
- self._create_dummy_tensor((num_tokens, hidden_size), dtype),
620
- self._create_dummy_tensor((num_tokens, hidden_size), dtype),
712
+ self._create_dummy_tensor((num_tokens, target_hidden_size),
713
+ dtype),
714
+ self._create_dummy_tensor((num_tokens, target_hidden_size),
715
+ dtype),
716
+ self._create_dummy_tensor((num_tokens, target_hidden_size),
717
+ dtype),
621
718
  ]
622
719
 
623
720
  positions = self._create_dummy_tensor((num_tokens, ), jnp.int32)
@@ -640,23 +737,23 @@ class CompilationManager:
640
737
  num_reqs,
641
738
  ):
642
739
  target_hidden_states, input_ids, last_token_indices, _ = self.runner.drafter._filter_token_and_prepare_initial_inputs(
643
- token_indices, query_start_loc, seq_lens, input_ids,
644
- aux_hidden_states, attention_metadata, next_token_ids,
645
- num_reqs)
740
+ self.runner.drafter.state, token_indices, query_start_loc,
741
+ seq_lens, input_ids, aux_hidden_states, attention_metadata,
742
+ next_token_ids, num_reqs)
646
743
  return target_hidden_states, input_ids, last_token_indices
647
744
 
648
745
  input_ids = self._create_dummy_tensor((num_tokens, ), jnp.int32)
649
746
  aux_hidden_states = [
650
747
  self._create_dummy_tensor(
651
- (num_tokens, hidden_size), jnp.bfloat16,
748
+ (num_tokens, target_hidden_size), jnp.bfloat16,
652
749
  NamedSharding(self.runner.mesh, PartitionSpec(None,
653
750
  None))),
654
751
  self._create_dummy_tensor(
655
- (num_tokens, hidden_size), jnp.bfloat16,
752
+ (num_tokens, target_hidden_size), jnp.bfloat16,
656
753
  NamedSharding(self.runner.mesh, PartitionSpec(None,
657
754
  None))),
658
755
  self._create_dummy_tensor(
659
- (num_tokens, hidden_size), jnp.bfloat16,
756
+ (num_tokens, target_hidden_size), jnp.bfloat16,
660
757
  NamedSharding(self.runner.mesh, PartitionSpec(None,
661
758
  None))),
662
759
  ]
@@ -688,17 +785,17 @@ class CompilationManager:
688
785
  state,
689
786
  kv_caches,
690
787
  input_ids,
691
- target_hidden_states,
788
+ draft_hidden_states,
692
789
  attention_metadata,
693
790
  ):
694
791
  kv_caches, hidden_states, _ = self.runner.drafter.model_fn(
695
- state, kv_caches, input_ids, target_hidden_states,
792
+ state, kv_caches, input_ids, draft_hidden_states,
696
793
  attention_metadata)
697
794
  self.runner.kv_caches = kv_caches
698
795
  return hidden_states
699
796
 
700
- target_hidden_states = self._create_dummy_tensor(
701
- (num_tokens, hidden_size), dtype,
797
+ draft_hidden_states = self._create_dummy_tensor(
798
+ (num_tokens, draft_hidden_size), dtype,
702
799
  NamedSharding(self.runner.mesh, PartitionSpec(None, "model")))
703
800
  input_ids = self._create_dummy_tensor(
704
801
  (num_tokens, ), jnp.int32,
@@ -709,7 +806,7 @@ class CompilationManager:
709
806
  self.runner.drafter.state,
710
807
  self.runner.kv_caches,
711
808
  input_ids,
712
- target_hidden_states,
809
+ draft_hidden_states,
713
810
  attention_metadata,
714
811
  num_tokens=num_tokens,
715
812
  )
@@ -719,6 +816,7 @@ class CompilationManager:
719
816
  self._run_compilation(
720
817
  "eagle3_prepare_hidden_states_and_input_ids",
721
818
  self.runner.drafter._prepare_hidden_states_and_input_ids,
819
+ self.runner.drafter.state,
722
820
  aux_hidden_states,
723
821
  query_start_loc,
724
822
  target_token_ids,
@@ -741,18 +839,19 @@ class CompilationManager:
741
839
  self.runner.drafter.state,
742
840
  self.runner.kv_caches,
743
841
  input_ids_loop,
744
- target_hidden_state_loop,
842
+ draft_hidden_state_loop,
745
843
  attention_metadata,
746
844
  num_tokens=num_tokens,
747
845
  )
748
846
 
749
847
  hidden_states = self._create_dummy_tensor(
750
- (num_tokens, hidden_size), jnp.bfloat16,
848
+ (num_tokens, draft_hidden_size), jnp.bfloat16,
751
849
  NamedSharding(self.runner.mesh, PartitionSpec(None, None)))
752
850
 
753
851
  self._run_compilation(
754
852
  "eagle3_select_inputs_for_loop_speculation",
755
853
  self.runner.drafter._select_inputs_for_loop_speculation,
854
+ self.runner.drafter.state,
756
855
  positions,
757
856
  hidden_states,
758
857
  hidden_states,
@@ -763,6 +862,7 @@ class CompilationManager:
763
862
  self._run_compilation(
764
863
  "eagle3_select_draft_token_ids",
765
864
  self.runner.drafter._select_draft_token_ids,
865
+ self.runner.drafter.state,
766
866
  hidden_states,
767
867
  last_token_indices,
768
868
  num_tokens=num_tokens,
@@ -1,3 +1,17 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
1
15
  from typing import Any, List
2
16
 
3
17
  import jax
@@ -7,6 +21,7 @@ from jax._src import dtypes
7
21
  from jax.sharding import Mesh, NamedSharding, PartitionSpec
8
22
  from torchax.ops.mappings import t2j_dtype
9
23
 
24
+ import tpu_inference.kernels.mla.v1.kernel as mla
10
25
  import tpu_inference.kernels.ragged_paged_attention.v3.kernel as rpa
11
26
  import tpu_inference.kernels.ragged_paged_attention.v3.kernel_hd64 as rpa_hd64
12
27
  from tpu_inference.layers.common.sharding import ShardingAxisName
@@ -17,9 +32,13 @@ logger = init_logger(__name__)
17
32
  DEFAULT_KV_CACHE_DTYPE = jnp.bfloat16
18
33
 
19
34
 
20
- def get_kv_cache_shape_with_mesh(mesh: Mesh, total_num_pages: int,
21
- page_size: int, actual_num_kv_heads: int,
22
- actual_head_dim: int, kv_dtype: any):
35
+ def get_kv_cache_shape_with_mesh(mesh: Mesh,
36
+ total_num_pages: int,
37
+ page_size: int,
38
+ actual_num_kv_heads: int,
39
+ actual_head_dim: int,
40
+ kv_dtype: any,
41
+ use_mla: bool = False):
23
42
  """Gets the KV cache shape based on the mesh configuration."""
24
43
 
25
44
  model_cnt = mesh.shape["model"]
@@ -28,15 +47,21 @@ def get_kv_cache_shape_with_mesh(mesh: Mesh, total_num_pages: int,
28
47
  # specific model, rather than being determined by the head_dim. If new
29
48
  # models are introduced with a head_dim of 64, this will require additional
30
49
  # model-specific adjustments.
31
- get_kv_cache_shape_fn = (
32
- rpa_hd64.get_kv_cache_shape if actual_head_dim == 64 \
33
- else rpa.get_kv_cache_shape
34
- )
35
- shape = list(
36
- get_kv_cache_shape_fn(total_num_pages, page_size,
37
- actual_num_kv_heads // model_cnt,
38
- actual_head_dim, kv_dtype))
39
- shape[2] *= model_cnt
50
+ if use_mla:
51
+ get_kv_cache_shape_fn = mla.get_kv_cache_shape
52
+ shape = list(
53
+ get_kv_cache_shape_fn(total_num_pages, page_size, actual_head_dim,
54
+ kv_dtype))
55
+ else:
56
+ get_kv_cache_shape_fn = (
57
+ rpa_hd64.get_kv_cache_shape if actual_head_dim == 64 \
58
+ else rpa.get_kv_cache_shape
59
+ )
60
+ shape = list(
61
+ get_kv_cache_shape_fn(total_num_pages, page_size,
62
+ actual_num_kv_heads // model_cnt,
63
+ actual_head_dim, kv_dtype))
64
+ shape[2] *= model_cnt
40
65
  return tuple(shape)
41
66
 
42
67
 
@@ -48,6 +73,7 @@ def create_kv_caches(
48
73
  mesh: Mesh,
49
74
  layer_names: List[str],
50
75
  cache_dtype: jnp.dtype = DEFAULT_KV_CACHE_DTYPE,
76
+ use_mla: bool = False,
51
77
  ) -> List[jax.Array]:
52
78
  """
53
79
  Creates a list of KV caches where each array maps to a single attention layer.
@@ -74,12 +100,16 @@ def create_kv_caches(
74
100
 
75
101
  cache_shape = get_kv_cache_shape_with_mesh(mesh, num_blocks, block_size,
76
102
  num_kv_heads, head_size,
77
- cache_dtype)
103
+ cache_dtype, use_mla)
78
104
 
79
- sharding = NamedSharding(
80
- mesh,
81
- PartitionSpec(ShardingAxisName.ATTN_DATA, None,
82
- ShardingAxisName.ATTN_HEAD))
105
+ if use_mla:
106
+ sharding = NamedSharding(mesh,
107
+ PartitionSpec(ShardingAxisName.MLP_TENSOR))
108
+ else:
109
+ sharding = NamedSharding(
110
+ mesh,
111
+ PartitionSpec(ShardingAxisName.ATTN_DATA, None,
112
+ ShardingAxisName.ATTN_HEAD))
83
113
 
84
114
  def _allocate() -> jax.Array:
85
115
  return jnp.empty(
@@ -94,7 +124,8 @@ def create_kv_caches(
94
124
  return kv_caches
95
125
 
96
126
 
97
- def get_rpa_page_size_bytes(mesh: Mesh, kv_cache_specs: dict[str, Any]) -> int:
127
+ def get_attention_page_size_bytes(mesh: Mesh,
128
+ kv_cache_specs: dict[str, Any]) -> int:
98
129
  """
99
130
  Calculate the KV cache page size in bytes for the attention kernel (RPA or MLA).
100
131
 
@@ -107,14 +138,16 @@ def get_rpa_page_size_bytes(mesh: Mesh, kv_cache_specs: dict[str, Any]) -> int:
107
138
  """
108
139
 
109
140
  # Import it here to avoid circular import.
110
- from vllm.v1.kv_cache_interface import AttentionSpec
141
+ from vllm.v1.kv_cache_interface import AttentionSpec, MLAAttentionSpec
111
142
 
112
143
  page_size_bytes_set = set()
113
144
  for kv_cache_spec in kv_cache_specs.values():
114
145
  assert isinstance(kv_cache_spec, AttentionSpec)
115
146
 
116
147
  dtype = t2j_dtype(kv_cache_spec.dtype)
117
- bits = dtypes.bit_width(dtype)
148
+ bits = (dtypes.bit_width(dtype) if hasattr(dtypes, "bit_width") else
149
+ dtypes.itemsize_bits(dtype))
150
+ use_mla = isinstance(kv_cache_spec, MLAAttentionSpec)
118
151
 
119
152
  kv_cache_shape = get_kv_cache_shape_with_mesh(
120
153
  mesh=mesh,
@@ -123,6 +156,7 @@ def get_rpa_page_size_bytes(mesh: Mesh, kv_cache_specs: dict[str, Any]) -> int:
123
156
  actual_num_kv_heads=kv_cache_spec.num_kv_heads,
124
157
  actual_head_dim=kv_cache_spec.head_size,
125
158
  kv_dtype=dtype,
159
+ use_mla=use_mla,
126
160
  )
127
161
  page_size_bytes = (bits * np.prod(kv_cache_shape)) // 8
128
162
  page_size_bytes_set.add(page_size_bytes)