tpu-inference 0.11.1.dev202512030818__py3-none-any.whl → 0.13.2rc3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of tpu-inference might be problematic. Click here for more details.

Files changed (250)
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_disagg_utils.py +14 -0
  4. tests/core/test_dp_scheduler.py +650 -768
  5. tests/core/test_init.py +14 -0
  6. tests/distributed/__init__.py +13 -0
  7. tests/distributed/test_distributed_utils.py +120 -0
  8. tests/distributed/test_tpu_connector.py +478 -0
  9. tests/e2e/__init__.py +13 -0
  10. tests/e2e/test_async_scheduler.py +211 -0
  11. tests/e2e/test_data_parallel.py +289 -0
  12. tests/e2e/test_hybrid_kvcache.py +219 -0
  13. tests/e2e/test_local_disagg.py +257 -0
  14. tests/e2e/test_model_loader.py +268 -0
  15. tests/e2e/test_multi_modal_inference.py +111 -0
  16. tests/e2e/test_pipeline_parallel.py +265 -0
  17. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  18. tests/e2e/test_sampling_params.py +269 -0
  19. tests/e2e/test_speculative_decoding.py +311 -0
  20. tests/e2e/test_structured_decoding.py +46 -0
  21. tests/executors/__init__.py +13 -0
  22. tests/executors/test_ray_distributed_executor.py +199 -0
  23. tests/experimental/__init__.py +13 -0
  24. tests/experimental/test_llama3_jax_stashed.py +208 -0
  25. tests/kernels/__init__.py +13 -0
  26. tests/kernels/collectives/__init__.py +13 -0
  27. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  28. tests/kernels/fused_moe_v1_test.py +14 -0
  29. tests/kernels/gmm_test.py +205 -0
  30. tests/kernels/mla_v1_test.py +143 -41
  31. tests/kernels/quantized_matmul_kernel_test.py +2 -34
  32. tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
  33. tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
  34. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +17 -1
  35. tests/kernels/ragged_paged_attention_kernel_v3_test.py +17 -1
  36. tests/layers/__init__.py +13 -0
  37. tests/layers/common/__init__.py +13 -0
  38. tests/layers/common/test_attention_interface.py +156 -0
  39. tests/layers/common/test_quantization.py +149 -0
  40. tests/layers/jax/__init__.py +13 -0
  41. tests/layers/jax/attention/__init__.py +13 -0
  42. tests/layers/jax/attention/test_common_attention.py +103 -0
  43. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  44. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  45. tests/layers/jax/moe/__init__.py +13 -0
  46. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  47. tests/layers/jax/sample/__init__.py +13 -0
  48. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  49. tests/layers/jax/sample/test_sampling.py +115 -0
  50. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  51. tests/layers/jax/test_layers.py +155 -0
  52. tests/{test_quantization.py → layers/jax/test_qwix.py} +183 -50
  53. tests/layers/jax/test_rope.py +93 -0
  54. tests/layers/jax/test_sharding.py +159 -0
  55. tests/layers/jax/test_transformer_block.py +152 -0
  56. tests/layers/vllm/__init__.py +13 -0
  57. tests/layers/vllm/test_attention.py +363 -0
  58. tests/layers/vllm/test_awq.py +405 -0
  59. tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
  60. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +418 -0
  61. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +441 -0
  62. tests/layers/vllm/test_fp8.py +17 -0
  63. tests/layers/vllm/test_mxfp4.py +312 -0
  64. tests/layers/vllm/test_unquantized.py +651 -0
  65. tests/layers/vllm/utils.py +87 -0
  66. tests/lora/__init__.py +13 -0
  67. tests/lora/conftest.py +14 -0
  68. tests/lora/test_bgmv.py +14 -0
  69. tests/lora/test_layers.py +21 -3
  70. tests/lora/test_lora.py +15 -1
  71. tests/lora/test_lora_perf.py +67 -0
  72. tests/models/__init__.py +13 -0
  73. tests/models/common/__init__.py +13 -0
  74. tests/models/common/test_model_loader.py +455 -0
  75. tests/models/jax/__init__.py +13 -0
  76. tests/models/jax/test_deepseek_v3.py +401 -0
  77. tests/models/jax/test_llama3.py +184 -0
  78. tests/models/jax/test_llama4.py +298 -0
  79. tests/models/jax/test_llama_eagle3.py +197 -0
  80. tests/models/jax/test_llama_guard_4.py +242 -0
  81. tests/models/jax/test_qwen2.py +172 -0
  82. tests/models/jax/test_qwen2_5_vl.py +605 -0
  83. tests/models/jax/test_qwen3.py +169 -0
  84. tests/models/jax/test_weight_loading.py +180 -0
  85. tests/models/jax/utils/__init__.py +13 -0
  86. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  87. tests/platforms/__init__.py +13 -0
  88. tests/platforms/test_tpu_platform.py +54 -0
  89. tests/runner/__init__.py +13 -0
  90. tests/runner/test_block_table.py +395 -0
  91. tests/runner/test_input_batch.py +226 -0
  92. tests/runner/test_kv_cache.py +220 -0
  93. tests/runner/test_kv_cache_manager.py +498 -0
  94. tests/runner/test_multimodal_manager.py +429 -0
  95. tests/runner/test_persistent_batch_manager.py +84 -0
  96. tests/runner/test_speculative_decoding_manager.py +368 -0
  97. tests/runner/test_structured_decoding_manager.py +220 -0
  98. tests/runner/test_tpu_runner.py +261 -0
  99. tests/runner/test_tpu_runner_dp.py +1099 -0
  100. tests/runner/test_tpu_runner_mesh.py +200 -0
  101. tests/runner/test_utils.py +411 -0
  102. tests/spec_decode/__init__.py +13 -0
  103. tests/spec_decode/test_eagle3.py +311 -0
  104. tests/test_base.py +14 -0
  105. tests/test_envs.py +78 -1
  106. tests/test_tpu_info.py +14 -0
  107. tests/test_utils.py +1 -43
  108. tests/worker/__init__.py +13 -0
  109. tests/worker/tpu_worker_test.py +414 -0
  110. tpu_inference/__init__.py +14 -0
  111. tpu_inference/core/__init__.py +13 -0
  112. tpu_inference/core/sched/__init__.py +13 -0
  113. tpu_inference/core/sched/dp_scheduler.py +372 -56
  114. tpu_inference/distributed/__init__.py +13 -0
  115. tpu_inference/distributed/jax_parallel_state.py +14 -0
  116. tpu_inference/distributed/tpu_connector.py +14 -9
  117. tpu_inference/distributed/utils.py +56 -4
  118. tpu_inference/envs.py +38 -7
  119. tpu_inference/executors/__init__.py +13 -0
  120. tpu_inference/executors/ray_distributed_executor.py +17 -0
  121. tpu_inference/experimental/__init__.py +13 -0
  122. tpu_inference/experimental/llama3_jax_stashed.py +14 -0
  123. tpu_inference/kernels/__init__.py +13 -0
  124. tpu_inference/kernels/collectives/__init__.py +13 -0
  125. tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
  126. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
  127. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  128. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  129. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  130. tpu_inference/kernels/fused_moe/v1/kernel.py +370 -324
  131. tpu_inference/kernels/megablox/__init__.py +13 -0
  132. tpu_inference/kernels/megablox/common.py +54 -0
  133. tpu_inference/kernels/megablox/gmm.py +646 -0
  134. tpu_inference/kernels/mla/__init__.py +13 -0
  135. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  136. tpu_inference/kernels/mla/v1/kernel.py +117 -145
  137. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  138. tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
  139. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  140. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  141. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
  142. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
  143. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  144. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +194 -101
  145. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +95 -78
  146. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
  147. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +376 -195
  148. tpu_inference/kernels/ragged_paged_attention/v3/util.py +15 -1
  149. tpu_inference/layers/__init__.py +13 -0
  150. tpu_inference/layers/common/__init__.py +13 -0
  151. tpu_inference/layers/common/attention_interface.py +26 -19
  152. tpu_inference/layers/common/attention_metadata.py +14 -0
  153. tpu_inference/layers/common/quant_methods.py +15 -0
  154. tpu_inference/layers/common/quantization.py +270 -0
  155. tpu_inference/layers/common/sharding.py +28 -5
  156. tpu_inference/layers/jax/__init__.py +13 -0
  157. tpu_inference/layers/jax/attention/__init__.py +13 -0
  158. tpu_inference/layers/jax/attention/attention.py +19 -6
  159. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +270 -77
  160. tpu_inference/layers/jax/attention/gpt_oss_attention.py +24 -11
  161. tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
  162. tpu_inference/layers/jax/base.py +14 -0
  163. tpu_inference/layers/jax/constants.py +13 -0
  164. tpu_inference/layers/jax/layers.py +14 -0
  165. tpu_inference/layers/jax/misc.py +14 -0
  166. tpu_inference/layers/jax/moe/__init__.py +13 -0
  167. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
  168. tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
  169. tpu_inference/layers/jax/moe/moe.py +43 -3
  170. tpu_inference/layers/jax/pp_utils.py +53 -0
  171. tpu_inference/layers/jax/rope.py +14 -0
  172. tpu_inference/layers/jax/rope_interface.py +14 -0
  173. tpu_inference/layers/jax/sample/__init__.py +13 -0
  174. tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
  175. tpu_inference/layers/jax/sample/sampling.py +15 -1
  176. tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
  177. tpu_inference/layers/jax/transformer_block.py +14 -0
  178. tpu_inference/layers/vllm/__init__.py +13 -0
  179. tpu_inference/layers/vllm/attention.py +4 -4
  180. tpu_inference/layers/vllm/fused_moe.py +210 -260
  181. tpu_inference/layers/vllm/linear_common.py +57 -22
  182. tpu_inference/layers/vllm/quantization/__init__.py +16 -0
  183. tpu_inference/layers/vllm/quantization/awq.py +15 -1
  184. tpu_inference/layers/vllm/quantization/common.py +33 -18
  185. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  186. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +18 -3
  187. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +211 -148
  188. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  189. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +14 -0
  190. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +14 -0
  191. tpu_inference/layers/vllm/quantization/fp8.py +118 -0
  192. tpu_inference/layers/vllm/quantization/mxfp4.py +278 -209
  193. tpu_inference/layers/vllm/quantization/unquantized.py +134 -86
  194. tpu_inference/layers/vllm/sharding.py +21 -4
  195. tpu_inference/lora/__init__.py +13 -0
  196. tpu_inference/lora/torch_lora_ops.py +8 -13
  197. tpu_inference/models/__init__.py +13 -0
  198. tpu_inference/models/common/__init__.py +13 -0
  199. tpu_inference/models/common/model_loader.py +74 -35
  200. tpu_inference/models/jax/__init__.py +13 -0
  201. tpu_inference/models/jax/deepseek_v3.py +267 -157
  202. tpu_inference/models/jax/gpt_oss.py +26 -10
  203. tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
  204. tpu_inference/models/jax/llama3.py +99 -36
  205. tpu_inference/models/jax/llama4.py +14 -0
  206. tpu_inference/models/jax/llama_eagle3.py +14 -0
  207. tpu_inference/models/jax/llama_guard_4.py +15 -1
  208. tpu_inference/models/jax/qwen2.py +17 -2
  209. tpu_inference/models/jax/qwen2_5_vl.py +18 -4
  210. tpu_inference/models/jax/qwen3.py +17 -2
  211. tpu_inference/models/jax/utils/__init__.py +13 -0
  212. tpu_inference/models/jax/utils/file_utils.py +14 -0
  213. tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
  214. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  215. tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +89 -26
  216. tpu_inference/models/jax/utils/weight_utils.py +39 -2
  217. tpu_inference/models/vllm/__init__.py +13 -0
  218. tpu_inference/models/vllm/vllm_model_wrapper.py +20 -3
  219. tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
  220. tpu_inference/platforms/__init__.py +14 -0
  221. tpu_inference/platforms/tpu_platform.py +47 -64
  222. tpu_inference/runner/__init__.py +13 -0
  223. tpu_inference/runner/compilation_manager.py +72 -37
  224. tpu_inference/runner/kv_cache.py +54 -20
  225. tpu_inference/runner/kv_cache_manager.py +46 -17
  226. tpu_inference/runner/lora_utils.py +14 -0
  227. tpu_inference/runner/multimodal_manager.py +15 -1
  228. tpu_inference/runner/persistent_batch_manager.py +14 -0
  229. tpu_inference/runner/speculative_decoding_manager.py +14 -0
  230. tpu_inference/runner/structured_decoding_manager.py +14 -0
  231. tpu_inference/runner/tpu_runner.py +44 -17
  232. tpu_inference/spec_decode/__init__.py +13 -0
  233. tpu_inference/spec_decode/jax/__init__.py +13 -0
  234. tpu_inference/spec_decode/jax/eagle3.py +13 -0
  235. tpu_inference/tpu_info.py +14 -0
  236. tpu_inference/utils.py +42 -36
  237. tpu_inference/worker/__init__.py +13 -0
  238. tpu_inference/worker/tpu_worker.py +63 -50
  239. {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.2rc3.dist-info}/METADATA +7 -9
  240. tpu_inference-0.13.2rc3.dist-info/RECORD +261 -0
  241. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  242. tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
  243. tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
  244. tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
  245. tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
  246. tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
  247. tpu_inference-0.11.1.dev202512030818.dist-info/RECORD +0 -174
  248. {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.2rc3.dist-info}/WHEEL +0 -0
  249. {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.2rc3.dist-info}/licenses/LICENSE +0 -0
  250. {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.2rc3.dist-info}/top_level.txt +0 -0
@@ -1,3 +1,16 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
1
14
  """TPU-Friendly and Data-Movement-Friendly MLA Ragged Paged Attention kernel."""
2
15
 
3
16
  import functools
@@ -16,17 +29,30 @@ DEFAULT_MASK_VALUE = -0.7 * float(jnp.finfo(jnp.dtype("float32")).max)
16
29
  DEFAULT_VMEM_LIMIT_BYTES = 100 * 1024 * 1024
17
30
 
18
31
 
32
+ def get_kv_cache_shape(
33
+ total_num_pages,
34
+ page_size,
35
+ kv_dim,
36
+ kv_dtype,
37
+ ):
38
+ kv_packing = get_dtype_packing(kv_dtype)
39
+ return (
40
+ total_num_pages,
41
+ align_to(page_size, kv_packing) // kv_packing,
42
+ kv_packing,
43
+ align_to(kv_dim, 128),
44
+ )
45
+
46
+
19
47
  @functools.partial(
20
48
  jax.jit,
21
- donate_argnames=("cache_kv_c", "cache_k_pe"),
49
+ donate_argnames=("cache_kv"),
22
50
  )
23
51
  def update_kv_cache(
24
52
  new_kv_c: jax.Array, # [num_tokens, actual_lkv_dim]
25
53
  new_k_pe: jax.Array, # [num_tokens, actual_r_dim]
26
- cache_kv_c: jax.
27
- Array, # [total_num_pages, page_size_per_kv_packing, kv_packing, lkv_dim]
28
- cache_k_pe: jax.
29
- Array, # [total_num_pages, page_size_per_kv_packing, kv_packing, r_dim]
54
+ cache_kv: jax.
55
+ Array, # [total_num_pages, page_size_per_kv_packing, kv_packing, lkv_dim+r_dim]
30
56
  kv_lens: jax.Array, # i32[max_num_seqs]
31
57
  page_indices: jax.Array, # i32[max_num_seqs * pages_per_seq]
32
58
  cu_q_lens: jax.Array, # i32[max_num_seqs + 1]
@@ -43,25 +69,21 @@ def update_kv_cache(
43
69
  if actual_lkv_dim != lkv_dim:
44
70
  new_kv_c = jnp.pad(new_kv_c, ((0, 0), (0, lkv_dim - actual_lkv_dim)),
45
71
  constant_values=0)
46
-
47
- _, page_size_per_kv_packing, kv_packing, cache_lkv_dim = cache_kv_c.shape
48
- _, _, _, cache_r_dim = cache_k_pe.shape
49
- assert lkv_dim == cache_lkv_dim
50
- assert r_dim == cache_r_dim
72
+ kv_dim = r_dim + lkv_dim
73
+ _, page_size_per_kv_packing, kv_packing, cache_kv_dim = cache_kv.shape
74
+ assert kv_dim == cache_kv_dim
51
75
  page_size = page_size_per_kv_packing * kv_packing
52
76
 
53
77
  max_num_seqs = kv_lens.shape[0]
54
78
  num_page_indices = page_indices.shape[0]
55
79
  pages_per_seq = num_page_indices // max_num_seqs
56
80
 
57
- def seq_loop_body(i, caches):
58
- cache_kv_c, cache_k_pe = caches
81
+ def seq_loop_body(i, cache_kv):
59
82
  q_start, q_end = cu_q_lens[i], cu_q_lens[i + 1]
60
83
  q_len = q_end - q_start
61
84
  kv_len = kv_lens[i]
62
85
 
63
- def token_loop_body(j, caches_):
64
- cache_kv_c_, cache_k_pe_ = caches_
86
+ def token_loop_body(j, cache_kv_):
65
87
  token_idx_in_seq = kv_len - q_len + j
66
88
  page_num_in_seq = token_idx_in_seq // page_size
67
89
  page_indices_start = i * pages_per_seq
@@ -69,18 +91,17 @@ def update_kv_cache(
69
91
  row = (token_idx_in_seq % page_size) // kv_packing
70
92
  col = (token_idx_in_seq % page_size) % kv_packing
71
93
 
72
- cache_kv_c_ = cache_kv_c_.at[page_idx, row,
73
- col].set(new_kv_c[q_start + j])
74
- cache_k_pe_ = cache_k_pe_.at[page_idx, row,
75
- col].set(new_k_pe[q_start + j])
76
- return cache_kv_c_, cache_k_pe_
94
+ cache_kv_ = cache_kv_.at[page_idx, row, col,
95
+ ..., :lkv_dim].set(new_kv_c[q_start + j])
96
+ cache_kv_ = cache_kv_.at[page_idx, row, col, ...,
97
+ lkv_dim:].set(new_k_pe[q_start + j])
98
+ return cache_kv_
99
+
100
+ return lax.fori_loop(0, q_len, token_loop_body, cache_kv)
77
101
 
78
- return lax.fori_loop(0, q_len, token_loop_body,
79
- (cache_kv_c, cache_k_pe))
102
+ cache_kv = lax.fori_loop(0, distribution[-1], seq_loop_body, cache_kv)
80
103
 
81
- cache_kv_c, cache_k_pe = lax.fori_loop(0, distribution[-1], seq_loop_body,
82
- (cache_kv_c, cache_k_pe))
83
- return cache_kv_c, cache_k_pe
104
+ return cache_kv
84
105
 
85
106
 
86
107
  def ref_mla_ragged_paged_attention(
@@ -88,10 +109,8 @@ def ref_mla_ragged_paged_attention(
88
109
  q_pe: jax.Array, # [num_tokens, actual_num_q_heads, actual_r_dim]
89
110
  new_kv_c: jax.Array, # [num_tokens, actual_lkv_dim]
90
111
  new_k_pe: jax.Array, # [num_tokens, actual_r_dim]
91
- cache_kv_c: jax.
112
+ cache_kv: jax.
92
113
  Array, # [total_num_pages, page_size_per_kv_packing, kv_packing, lkv_dim]
93
- cache_k_pe: jax.
94
- Array, # [total_num_pages, page_size_per_kv_packing, kv_packing, r_dim]
95
114
  kv_lens: jax.Array, # i32[max_num_seqs]
96
115
  page_indices: jax.Array, # i32[max_num_seqs * pages_per_seq]
97
116
  cu_q_lens: jax.Array, # i32[max_num_seqs + 1]
@@ -111,8 +130,7 @@ def ref_mla_ragged_paged_attention(
111
130
  q_pe,
112
131
  new_kv_c,
113
132
  new_k_pe,
114
- cache_kv_c,
115
- cache_k_pe,
133
+ cache_kv,
116
134
  kv_lens,
117
135
  page_indices,
118
136
  cu_q_lens,
@@ -123,11 +141,10 @@ def ref_mla_ragged_paged_attention(
123
141
  mask_value=mask_value,
124
142
  )
125
143
 
126
- cache_kv_c, cache_k_pe = update_kv_cache(
144
+ updated_cache_kv = update_kv_cache(
127
145
  new_kv_c,
128
146
  new_k_pe,
129
- cache_kv_c,
130
- cache_k_pe,
147
+ cache_kv,
131
148
  kv_lens,
132
149
  page_indices,
133
150
  cu_q_lens,
@@ -154,13 +171,17 @@ def ref_mla_ragged_paged_attention(
154
171
  assert num_page_indices % max_num_seqs == 0
155
172
  pages_per_seq = num_page_indices // max_num_seqs
156
173
 
157
- total_num_pages, page_size_per_kv_packing, kv_packing, _ = cache_kv_c.shape
174
+ total_num_pages, page_size_per_kv_packing, kv_packing, _ = updated_cache_kv.shape
158
175
  page_size = page_size_per_kv_packing * kv_packing
159
176
  assert lkv_dim == ql_nope.shape[-1]
160
177
  assert r_dim == q_pe.shape[-1]
178
+ assert lkv_dim + r_dim == updated_cache_kv.shape[-1]
161
179
 
162
- kv_c_cache = cache_kv_c.reshape(total_num_pages, page_size, lkv_dim)
163
- k_pe_cache = cache_k_pe.reshape(total_num_pages, page_size, r_dim)
180
+ kv_c_cache = updated_cache_kv[..., :lkv_dim].reshape(
181
+ total_num_pages, page_size, lkv_dim)
182
+ k_pe_cache = updated_cache_kv[...,
183
+ lkv_dim:].reshape(total_num_pages, page_size,
184
+ r_dim)
164
185
 
165
186
  outputs = []
166
187
 
@@ -221,8 +242,7 @@ def ref_mla_ragged_paged_attention(
221
242
 
222
243
  return (
223
244
  jnp.concatenate(outputs, axis=0),
224
- cache_kv_c,
225
- cache_k_pe,
245
+ updated_cache_kv,
226
246
  )
227
247
 
228
248
 
@@ -232,10 +252,8 @@ def dynamic_validate_inputs(
232
252
  q_pe: jax.Array, # [max_num_tokens, actual_num_q_heads, actual_r_dim]
233
253
  new_kv_c: jax.Array, # [max_num_tokens, actual_lkv_dim]
234
254
  new_k_pe: jax.Array, # [max_num_tokens, actual_r_dim]
235
- cache_kv_c: jax.
255
+ cache_kv: jax.
236
256
  Array, # [total_num_pages, page_size_per_kv_packing, kv_packing, lkv_dim]
237
- cache_k_pe: jax.
238
- Array, # [total_num_pages, page_size_per_kv_packing, kv_packing, r_dim]
239
257
  kv_lens: jax.Array, # i32[max_num_seqs]
240
258
  page_indices: jax.Array, # i32[max_num_seqs * pages_per_seq]
241
259
  cu_q_lens: jax.Array, # i32[max_num_seqs + 1]
@@ -260,8 +278,7 @@ def dynamic_validate_inputs(
260
278
  q_pe,
261
279
  new_kv_c,
262
280
  new_k_pe,
263
- cache_kv_c,
264
- cache_k_pe,
281
+ cache_kv,
265
282
  kv_lens,
266
283
  page_indices,
267
284
  cu_q_lens,
@@ -277,8 +294,8 @@ def dynamic_validate_inputs(
277
294
  debug_mode=debug_mode,
278
295
  )
279
296
  max_num_tokens = ql_nope.shape[0]
280
- total_num_pages = cache_kv_c.shape[0]
281
- _, page_size_per_kv_packing, kv_packing, _ = cache_kv_c.shape
297
+ total_num_pages = cache_kv.shape[0]
298
+ _, page_size_per_kv_packing, kv_packing, _ = cache_kv.shape
282
299
  page_size = page_size_per_kv_packing * kv_packing
283
300
  max_num_seqs = kv_lens.shape[0]
284
301
  num_page_indices = page_indices.shape[0]
@@ -320,10 +337,8 @@ def static_validate_inputs(
320
337
  q_pe: jax.Array, # [max_num_tokens, actual_num_q_heads, actual_r_dim]
321
338
  new_kv_c: jax.Array, # [max_num_tokens, actual_lkv_dim]
322
339
  new_k_pe: jax.Array, # [max_num_tokens, actual_r_dim]
323
- cache_kv_c: jax.
340
+ cache_kv: jax.
324
341
  Array, # [total_num_pages, page_size_per_kv_packing, kv_packing, lkv_dim]
325
- cache_k_pe: jax.
326
- Array, # [total_num_pages, page_size_per_kv_packing, kv_packing, r_dim]
327
342
  kv_lens: jax.Array, # i32[max_num_seqs]
328
343
  page_indices: jax.Array, # i32[max_num_seqs * pages_per_seq]
329
344
  cu_q_lens: jax.Array, # i32[max_num_seqs + 1]
@@ -373,44 +388,34 @@ def static_validate_inputs(
373
388
 
374
389
  actual_lkv_dim = ql_nope.shape[2]
375
390
  actual_r_dim = q_pe.shape[2]
391
+ lkv_dim = align_to(actual_lkv_dim, 128)
392
+ r_dim = align_to(actual_r_dim, 128)
376
393
 
377
394
  (
378
395
  _,
379
396
  page_size_per_kv_packing,
380
397
  kv_packing,
381
- lkv_dim,
382
- ) = cache_kv_c.shape
383
- _, _, _, r_dim = cache_k_pe.shape
398
+ kv_dim,
399
+ ) = cache_kv.shape
384
400
 
385
- if lkv_dim != align_to(actual_lkv_dim, 128):
386
- raise ValueError(
387
- f"Expected {lkv_dim=} is equal to {align_to(actual_lkv_dim, 128)=}"
388
- )
389
- if r_dim != align_to(actual_r_dim, 128):
401
+ if lkv_dim + r_dim != kv_dim:
390
402
  raise ValueError(
391
- f"Expected {r_dim=} is equal to {align_to(actual_r_dim, 128)=}")
403
+ f"Expected {lkv_dim=} + {r_dim=} to be equal to {kv_dim=}")
392
404
 
393
- if not (cache_kv_c.dtype == new_kv_c.dtype):
405
+ if not (cache_kv.dtype == new_kv_c.dtype):
394
406
  raise ValueError(
395
- f"Expected {cache_kv_c.dtype=} to be equal to {new_kv_c.dtype=}.")
396
- if not (cache_k_pe.dtype == new_k_pe.dtype):
407
+ f"Expected {cache_kv.dtype=} to be equal to {new_kv_c.dtype=}.")
408
+ if not (cache_kv.dtype == new_k_pe.dtype):
397
409
  raise ValueError(
398
- f"Expected {cache_k_pe.dtype=} to be equal to {new_k_pe.dtype=}.")
410
+ f"Expected {cache_kv.dtype=} to be equal to {new_k_pe.dtype=}.")
399
411
 
400
412
  # Integer kv quantization is currently not supported.
401
- if not jnp.issubdtype(cache_kv_c.dtype, jnp.floating):
402
- raise ValueError(
403
- f"Expected {cache_kv_c.dtype=} to be a floating point.")
404
- if not jnp.issubdtype(cache_k_pe.dtype, jnp.floating):
405
- raise ValueError(
406
- f"Expected {cache_k_pe.dtype=} to be a floating point.")
413
+ if not jnp.issubdtype(cache_kv.dtype, jnp.floating):
414
+ raise ValueError(f"Expected {cache_kv.dtype=} to be a floating point.")
407
415
 
408
- if kv_packing != get_dtype_packing(cache_kv_c.dtype):
416
+ if kv_packing != get_dtype_packing(cache_kv.dtype):
409
417
  raise ValueError(
410
- f"{kv_packing=} does not match with {cache_kv_c.dtype=}")
411
- if kv_packing != get_dtype_packing(cache_k_pe.dtype):
412
- raise ValueError(
413
- f"{kv_packing=} does not match with {cache_k_pe.dtype=}")
418
+ f"{kv_packing=} does not match with {cache_kv.dtype=}")
414
419
 
415
420
  if not (jnp.int32 == kv_lens.dtype == page_indices.dtype == cu_q_lens.dtype
416
421
  == distribution.dtype):
@@ -475,14 +480,12 @@ def _mla_ragged_paged_attention_kernel(
475
480
  q_pe_hbm_ref, # [max_num_tokens, num_q_heads_per_q_packing, q_packing, r_dim]
476
481
  new_kv_c_hbm_ref, # [max_num_tokens_per_kv_packing, kv_packing, lkv_dim]
477
482
  new_k_pe_hbm_ref, # [max_num_tokens_per_kv_packing, kv_packing, r_dim]
478
- cache_kv_c_hbm_ref, # [total_num_pages, page_size_per_kv_packing, kv_packing, lkv_dim]
479
- cache_k_pe_hbm_ref, # [total_num_pages, page_size_per_kv_packing, kv_packing, r_dim]
483
+ cache_kv_hbm_ref, # [total_num_pages, page_size_per_kv_packing, kv_packing, align_to(lkv_dim + r_dim, 128)]
480
484
  # Output
481
485
  o_hbm_ref, # [max_num_tokens, num_q_heads_per_q_packing, q_packing, lkv_dim]
482
- updated_cache_kv_c_hbm_ref, # [total_num_pages, page_size_per_kv_packing, kv_packing, lkv_dim]
483
- updated_cache_k_pe_hbm_ref, # [total_num_pages, page_size_per_kv_packing, kv_packing, r_dim]
486
+ updated_cache_kv_hbm_ref, # [total_num_pages, page_size_per_kv_packing, kv_packing, align_to(lkv_dim + r_dim, 128)]
484
487
  # Scratch
485
- bkvc_x2_ref, # [2, bkv_sz_per_kv_packing, kv_packing, lkv_dim]
488
+ bkvc_x2_ref, # [2, bkv_sz_per_kv_packing, kv_packing, lkv_dim].
486
489
  bkpe_x2_ref, # [2, bkv_sz_per_kv_packing, kv_packing, r_dim]
487
490
  bq_nope_x2_ref, # [2, bq_sz, num_q_heads_per_q_packing, q_packing, lkv_dim]
488
491
  bq_rope_x2_ref, # [2, bq_sz, num_q_heads_per_q_packing, q_packing, r_dim]
@@ -505,20 +508,24 @@ def _mla_ragged_paged_attention_kernel(
505
508
  debug_mode: bool = False,
506
509
  ):
507
510
  assert ql_nope_hbm_ref.shape == o_hbm_ref.shape
508
- assert ql_nope_hbm_ref.shape[-1] == cache_kv_c_hbm_ref.shape[-1]
509
- assert q_pe_hbm_ref.shape[-1] == cache_k_pe_hbm_ref.shape[-1]
511
+ # Validation checks on the dimensions
512
+ nope_dim = ql_nope_hbm_ref.shape[-1]
513
+ pe_dim = q_pe_hbm_ref.shape[-1]
514
+ assert nope_dim + pe_dim == cache_kv_hbm_ref.shape[-1]
515
+
510
516
  _, num_q_heads_per_q_packing, q_packing, lkv_dim = ql_nope_hbm_ref.shape
511
517
  r_dim = q_pe_hbm_ref.shape[-1]
512
518
  num_q_heads = num_q_heads_per_q_packing * q_packing
513
519
  total_num_pages, page_size_per_kv_packing, kv_packing, _ = (
514
- cache_kv_c_hbm_ref.shape)
520
+ cache_kv_hbm_ref.shape)
515
521
  max_num_seqs = kv_lens_ref.shape[0]
516
522
  num_page_indices = page_indices_ref.shape[0]
517
523
 
518
524
  assert num_page_indices % max_num_seqs == 0
519
525
  pages_per_seq = num_page_indices // max_num_seqs
520
526
  q_dtype = ql_nope_hbm_ref.dtype
521
- kv_dtype = cache_kv_c_hbm_ref.dtype
527
+ # Validate against the KV dtype.
528
+ kv_dtype = cache_kv_hbm_ref.dtype
522
529
  assert q_pe_hbm_ref.dtype == q_dtype
523
530
  assert o_hbm_ref.dtype == q_dtype
524
531
  assert get_dtype_packing(q_dtype) == q_packing
@@ -561,8 +568,8 @@ def _mla_ragged_paged_attention_kernel(
561
568
  def flash_attention(
562
569
  ql_nope, # [actual_bq_sz * num_q_heads, lkv_dim]
563
570
  q_pe, # [actual_bq_sz * num_q_heads, r_dim]
564
- kv_c, # [bkv_sz, lkv_dim]
565
- k_pe, # [bkv_sz, r_dim]
571
+ kv_c, # [bkv_sz, lkv_dim] <- Correspond to data from bkvc_x2_ref
572
+ k_pe, # [bkv_sz, r_dim] <- Correspond to data from bpe_x2_ref
566
573
  *,
567
574
  bq_idx,
568
575
  bkv_idx,
@@ -649,14 +656,9 @@ def _mla_ragged_paged_attention_kernel(
649
656
  sem = sems.at[0, bkv_sem_idx]
650
657
  bkvc_vmem_ref = bkvc_x2_ref.at[bkv_sem_idx]
651
658
  bkvpe_vmem_ref = bkpe_x2_ref.at[bkv_sem_idx]
652
-
653
- reshaped_cache_kv_c_hbm_ref = cache_kv_c_hbm_ref.reshape(
659
+ reshaped_cache_hbm_ref = cache_kv_hbm_ref.reshape(
654
660
  total_num_pages * page_size_per_kv_packing,
655
- *cache_kv_c_hbm_ref.shape[2:],
656
- )
657
- reshaped_cache_k_pe_hbm_ref = cache_k_pe_hbm_ref.reshape(
658
- total_num_pages * page_size_per_kv_packing,
659
- *cache_k_pe_hbm_ref.shape[2:],
661
+ *cache_kv_hbm_ref.shape[2:],
660
662
  )
661
663
  kv_len = kv_lens_ref[seq_idx]
662
664
  kv_len_start = bkv_idx * bkv_sz
@@ -684,22 +686,22 @@ def _mla_ragged_paged_attention_kernel(
684
686
  kv_left_per_kv_packing - i * page_size_per_kv_packing,
685
687
  )
686
688
  _async_copy(
687
- reshaped_cache_kv_c_hbm_ref.at[pl.ds(
689
+ reshaped_cache_hbm_ref.at[pl.ds(
688
690
  page_indices_ref[page_indices_offset + i] *
689
691
  page_size_per_kv_packing,
690
692
  sz_per_kv_packing,
691
- )],
693
+ ), ..., :nope_dim],
692
694
  bkvc_vmem_ref.at[pl.ds(i * page_size_per_kv_packing,
693
695
  sz_per_kv_packing)],
694
696
  sem,
695
697
  wait,
696
698
  )
697
699
  _async_copy(
698
- reshaped_cache_k_pe_hbm_ref.at[pl.ds(
700
+ reshaped_cache_hbm_ref.at[pl.ds(
699
701
  page_indices_ref[page_indices_offset + i] *
700
702
  page_size_per_kv_packing,
701
703
  sz_per_kv_packing,
702
- )],
704
+ ), ..., nope_dim:],
703
705
  bkvpe_vmem_ref.at[pl.ds(i * page_size_per_kv_packing,
704
706
  sz_per_kv_packing)],
705
707
  sem,
@@ -820,37 +822,17 @@ def _mla_ragged_paged_attention_kernel(
820
822
  return q_nope_vec, q_rope_vec
821
823
 
822
824
  def load_bkv(bkv_sem_idx, *, bkvc_mask, bkpe_mask):
823
- bitwidth = 32 // kv_packing
824
- repack_ty = jnp.dtype(f"uint{bitwidth}")
825
825
  bkvc_ref = (bkvc_x2_ref.bitcast(jnp.uint32).at[bkv_sem_idx].reshape(
826
826
  bkv_sz_per_kv_packing, lkv_dim))
827
- bkvc_vec = bkvc_ref[...]
828
- bkvc_vecs = []
829
- for i in range(kv_packing):
830
- masked_bkvc_vec = bkvc_vec >> (i * bitwidth)
831
- bkvc_vecs.append(masked_bkvc_vec)
832
- concated_bkvc_vec = jnp.concatenate(bkvc_vecs, axis=-1)
833
- concated_bkvc_vec = concated_bkvc_vec.reshape(bkv_sz, lkv_dim)
834
- concated_bkvc_vec = lax.select(bkvc_mask, concated_bkvc_vec,
835
- jnp.zeros_like(concated_bkvc_vec))
836
- concated_bkvc_vec = pltpu.bitcast(concated_bkvc_vec.astype(repack_ty),
837
- kv_dtype)
827
+ bkvc_vec = pltpu.bitcast(bkvc_ref[...], kv_dtype)
828
+ bkvc_vec = lax.select(bkvc_mask, bkvc_vec, jnp.zeros_like(bkvc_vec))
838
829
 
839
830
  bkpe_ref = (bkpe_x2_ref.bitcast(jnp.uint32).at[bkv_sem_idx].reshape(
840
831
  bkv_sz_per_kv_packing, r_dim))
841
- bkpe_vec = bkpe_ref[...]
842
- bkpe_vecs = []
843
- for i in range(kv_packing):
844
- masked_bkpe_vec = bkpe_vec >> (i * bitwidth)
845
- bkpe_vecs.append(masked_bkpe_vec)
846
- concated_bkpe_vec = jnp.concatenate(bkpe_vecs, axis=-1)
847
- concated_bkpe_vec = concated_bkpe_vec.reshape(bkv_sz, r_dim)
848
- concated_bkpe_vec = lax.select(bkpe_mask, concated_bkpe_vec,
849
- jnp.zeros_like(concated_bkpe_vec))
850
- concated_bkpe_vec = pltpu.bitcast(concated_bkpe_vec.astype(repack_ty),
851
- kv_dtype)
852
-
853
- return concated_bkvc_vec, concated_bkpe_vec
832
+ bkpe_vec = pltpu.bitcast(bkpe_ref[...], kv_dtype)
833
+ bkpe_vec = lax.select(bkpe_mask, bkpe_vec, jnp.zeros_like(bkpe_vec))
834
+
835
+ return bkvc_vec, bkpe_vec
854
836
 
855
837
  def broadcast_minor(src, shape):
856
838
  if src.shape == shape:
@@ -1082,17 +1064,16 @@ def prepare_outputs(
1082
1064
  "vmem_limit_bytes",
1083
1065
  "debug_mode",
1084
1066
  ),
1085
- donate_argnames=("cache_kv_c", "cache_k_pe"),
1067
+ donate_argnames=("cache_kv"),
1086
1068
  )
1087
1069
  def mla_ragged_paged_attention(
1088
1070
  ql_nope: jax.Array, # [max_num_tokens, actual_num_q_heads, actual_lkv_dim]
1089
1071
  q_pe: jax.Array, # [max_num_tokens, actual_num_q_heads, actual_r_dim]
1090
1072
  new_kv_c: jax.Array, # [max_num_tokens, actual_lkv_dim]
1091
1073
  new_k_pe: jax.Array, # [max_num_tokens, actual_r_dim]
1092
- cache_kv_c: jax.
1093
- Array, # [total_num_pages, page_size_per_kv_packing, kv_packing, lkv_dim]
1094
- cache_k_pe: jax.
1095
- Array, # [total_num_pages, page_size_per_kv_packing, kv_packing, r_dim]
1074
+ # TODO(gpolovets): Explore separating out into lkv & pe KV caches.
1075
+ cache_kv: jax.
1076
+ Array, # [total_num_pages, page_size_per_kv_packing, kv_packing, align_to(lkv_dim, 128)]
1096
1077
  kv_lens: jax.Array, # i32[max_num_seqs]
1097
1078
  page_indices: jax.Array, # i32[max_num_seqs * pages_per_seq]
1098
1079
  cu_q_lens: jax.Array, # i32[max_num_seqs + 1]
@@ -1124,8 +1105,7 @@ def mla_ragged_paged_attention(
1124
1105
  q_pe: concatenated all sequences' rope.
1125
1106
  new_kv_c: concatenated all sequences' kv_c values
1126
1107
  new_k_pe: concatenated all sequences' k_pe values
1127
- cache_kv_c: the current kv_c cache.
1128
- cache_k_pe: the current k_pe cache.
1108
+ cache_kv: the current kv cache.
1129
1109
  kv_lens: the length of each sequence in the kv cache.
1130
1110
  page_indices: flattened page indices look-up table by (seq_id, page_id).
1131
1111
  cu_q_lens: the cumulative sum of the effective query lengths. Similar to
@@ -1159,8 +1139,7 @@ def mla_ragged_paged_attention(
1159
1139
  q_pe,
1160
1140
  new_kv_c,
1161
1141
  new_k_pe,
1162
- cache_kv_c,
1163
- cache_k_pe,
1142
+ cache_kv,
1164
1143
  kv_lens,
1165
1144
  page_indices,
1166
1145
  cu_q_lens,
@@ -1177,11 +1156,10 @@ def mla_ragged_paged_attention(
1177
1156
  )
1178
1157
 
1179
1158
  # TODO(chengjiyao): fuse kv cache update into the kernel.
1180
- cache_kv_c, cache_k_pe = update_kv_cache(
1159
+ cache_kv = update_kv_cache(
1181
1160
  new_kv_c,
1182
1161
  new_k_pe,
1183
- cache_kv_c,
1184
- cache_k_pe,
1162
+ cache_kv,
1185
1163
  kv_lens,
1186
1164
  page_indices,
1187
1165
  cu_q_lens,
@@ -1202,7 +1180,7 @@ def mla_ragged_paged_attention(
1202
1180
  lkv_dim = new_kv_c.shape[-1]
1203
1181
  r_dim = new_k_pe.shape[-1]
1204
1182
 
1205
- _, page_size_per_kv_packing, kv_packing, _ = cache_kv_c.shape
1183
+ _, page_size_per_kv_packing, kv_packing, _ = cache_kv.shape
1206
1184
  page_size = page_size_per_kv_packing * kv_packing
1207
1185
  _, num_q_heads_per_q_packing, q_packing, _ = ql_nope.shape
1208
1186
  max_num_seqs = kv_lens.shape[0]
@@ -1221,23 +1199,21 @@ def mla_ragged_paged_attention(
1221
1199
  pl.BlockSpec(memory_space=pltpu.HBM),
1222
1200
  pl.BlockSpec(memory_space=pltpu.HBM),
1223
1201
  pl.BlockSpec(memory_space=pltpu.HBM),
1224
- pl.BlockSpec(memory_space=pltpu.HBM),
1225
1202
  ]
1226
1203
 
1227
1204
  out_specs = [
1228
1205
  pl.BlockSpec(memory_space=pltpu.HBM),
1229
1206
  pl.BlockSpec(memory_space=pltpu.HBM),
1230
- pl.BlockSpec(memory_space=pltpu.HBM),
1231
1207
  ]
1232
1208
 
1233
1209
  bkvc_double_buf = pltpu.VMEM(
1234
1210
  (2, bkv_sz_per_kv_packing, kv_packing, lkv_dim),
1235
- cache_kv_c.dtype,
1211
+ cache_kv.dtype,
1236
1212
  )
1237
1213
 
1238
1214
  bkpe_double_buf = pltpu.VMEM(
1239
1215
  (2, bkv_sz_per_kv_packing, kv_packing, r_dim),
1240
- cache_k_pe.dtype,
1216
+ cache_kv.dtype,
1241
1217
  )
1242
1218
 
1243
1219
  bq_nope_double_buf = pltpu.VMEM(
@@ -1320,30 +1296,26 @@ def mla_ragged_paged_attention(
1320
1296
  ),
1321
1297
  out_shape=[
1322
1298
  jax.ShapeDtypeStruct(shape=ql_nope.shape, dtype=ql_nope.dtype),
1323
- jax.ShapeDtypeStruct(shape=cache_kv_c.shape,
1324
- dtype=cache_kv_c.dtype),
1325
- jax.ShapeDtypeStruct(shape=cache_k_pe.shape,
1326
- dtype=cache_k_pe.dtype),
1299
+ jax.ShapeDtypeStruct(shape=cache_kv.shape,
1300
+ dtype=cache_kv.dtype),
1327
1301
  ],
1328
1302
  input_output_aliases={
1329
1303
  7: 0,
1330
1304
  11: 1,
1331
- 12: 2
1332
1305
  },
1333
1306
  name=scope_name,
1334
1307
  ))
1335
1308
 
1336
- output, updated_kv_c, updated_k_pe = kernel(
1309
+ output, updated_kv = kernel(
1337
1310
  *scalar_prefetches,
1338
1311
  ql_nope,
1339
1312
  q_pe,
1340
1313
  new_kv_c,
1341
1314
  new_k_pe,
1342
- cache_kv_c,
1343
- cache_k_pe,
1315
+ cache_kv,
1344
1316
  )
1345
1317
  output = prepare_outputs(
1346
1318
  output, actual_num_q_heads,
1347
1319
  actual_lkv_dim) # [max_num_tokens, actual_num_q_heads, actual_lkv_dim]
1348
1320
 
1349
- return output, updated_kv_c, updated_k_pe
1321
+ return output, updated_kv
@@ -0,0 +1,13 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
@@ -9,12 +9,58 @@ from jax._src import dtypes
9
9
  from jax.experimental import pallas as pl
10
10
  from jax.experimental.pallas import tpu as pltpu
11
11
 
12
+ from tpu_inference.kernels.quantized_matmul import util
12
13
  from tpu_inference.kernels.quantized_matmul.tuned_block_sizes import (
13
14
  TunedValue, get_device_vmem_limit, get_tuned_block_sizes)
14
15
  from tpu_inference.kernels.quantized_matmul.util import (get_kernel_name,
15
16
  next_multiple,
16
17
  unfold_args)
17
18
 
19
+ quantize_tensor = util.quantize_tensor
20
+
21
+
22
+ def xla_quantized_matmul(
23
+ x: jax.Array,
24
+ w_q: jax.Array,
25
+ w_scale: jax.Array,
26
+ quantize_activation=True,
27
+ ) -> jax.Array:
28
+ """
29
+ Reference (pure JAX) implementation of the quantized matmul kernel below.
30
+
31
+ Args:
32
+ x: Activation.
33
+ w_q: Weight quantized array. [n_output_features, n_input_features]
34
+ w_s: Weight quantization scale. [n_output_features]
35
+ mesh: Mesh to shard on.
36
+ weight_sharding: PartitionSpec for the weight tensor.
37
+
38
+ Returns:
39
+ Output of the quantized matmul.
40
+ """
41
+ if quantize_activation:
42
+ acc_dtype = jnp.float32
43
+ if quantize_activation and jnp.issubdtype(w_q.dtype, jnp.integer):
44
+ acc_dtype = jnp.int32
45
+
46
+ x_q, x_scale = quantize_tensor(x, w_q.dtype)
47
+ out = jax.lax.dot_general(
48
+ x_q,
49
+ w_q,
50
+ dimension_numbers=(((1, ), (1, )), ((), ())),
51
+ preferred_element_type=acc_dtype,
52
+ ).astype(jnp.float32)
53
+ out *= x_scale
54
+ else:
55
+ out = jax.lax.dot_general(
56
+ x,
57
+ w_q,
58
+ dimension_numbers=(((1, ), (1, )), ((), ())),
59
+ preferred_element_type=jnp.float32,
60
+ )
61
+ out *= jnp.expand_dims(w_scale, 0)
62
+ return out.astype(x.dtype)
63
+
18
64
 
19
65
  def quantize_array(
20
66
  x: jax.Array, # [bs_block_size, in_block_size]
@@ -50,11 +96,20 @@ def get_vmem_limit(
50
96
  """Calculate VMEM limit for the kernel."""
51
97
 
52
98
  # Calculate in/out VMEM size.
53
- x_size = batch_block_size * in_block_size * dtypes.bit_width(x_dtype)
54
- x_abs_max_size = batch_block_size * dtypes.bit_width(scale_dtype)
55
- w_q_size = out_block_size * in_block_size * dtypes.bit_width(w_q_dtype)
56
- w_scale_size = out_block_size * dtypes.bit_width(scale_dtype)
57
- out_size = batch_block_size * out_block_size * dtypes.bit_width(out_dtype)
99
+ x_size = (batch_block_size *
100
+ in_block_size * (dtypes.bit_width(x_dtype) if hasattr(
101
+ dtypes, "bit_width") else dtypes.itemsize_bits(x_dtype)))
102
+ x_abs_max_size = (
103
+ batch_block_size * (dtypes.bit_width(scale_dtype) if hasattr(
104
+ dtypes, "bit_width") else dtypes.itemsize_bits(scale_dtype)))
105
+ w_q_size = (out_block_size *
106
+ in_block_size * (dtypes.bit_width(w_q_dtype) if hasattr(
107
+ dtypes, "bit_width") else dtypes.itemsize_bits(w_q_dtype)))
108
+ w_scale_size = (out_block_size * (dtypes.bit_width(scale_dtype) if hasattr(
109
+ dtypes, "bit_width") else dtypes.itemsize_bits(scale_dtype)))
110
+ out_size = (batch_block_size *
111
+ out_block_size * (dtypes.bit_width(out_dtype) if hasattr(
112
+ dtypes, "bit_width") else dtypes.itemsize_bits(out_dtype)))
58
113
 
59
114
  vmem_in_out = x_size + x_abs_max_size + w_q_size + w_scale_size + out_size
60
115
  vmem_in_out *= 2 # Account for compute and vreg spills.
@@ -68,9 +123,15 @@ def get_vmem_limit(
68
123
  vmem_in_out += out_size if (n_batch > 1 or n_out > 1) else 0
69
124
 
70
125
  # Calculate scratch VMEM size.
71
- acc_size = batch_block_size * out_block_size * dtypes.bit_width(acc_dtype)
72
- x_q_size = batch_block_size * in_block_size * dtypes.bit_width(x_q_dtype)
73
- x_scale_size = batch_block_size * dtypes.bit_width(scale_dtype)
126
+ acc_size = (batch_block_size *
127
+ out_block_size * (dtypes.bit_width(acc_dtype) if hasattr(
128
+ dtypes, "bit_width") else dtypes.itemsize_bits(acc_dtype)))
129
+ x_q_size = (batch_block_size *
130
+ in_block_size * (dtypes.bit_width(x_q_dtype) if hasattr(
131
+ dtypes, "bit_width") else dtypes.itemsize_bits(x_q_dtype)))
132
+ x_scale_size = (
133
+ batch_block_size * (dtypes.bit_width(scale_dtype) if hasattr(
134
+ dtypes, "bit_width") else dtypes.itemsize_bits(scale_dtype)))
74
135
 
75
136
  vmem_scratch = acc_size if save_acc else 0
76
137
  vmem_scratch += x_q_size + x_scale_size if save_x_q else 0
@@ -0,0 +1,13 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.