tpu-inference 0.12.0.dev20251222__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (260)
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_core_tpu.py +513 -0
  4. tests/core/test_disagg_executor.py +60 -0
  5. tests/core/test_disagg_utils.py +67 -0
  6. tests/core/test_dp_scheduler.py +724 -0
  7. tests/core/test_init.py +63 -0
  8. tests/distributed/__init__.py +13 -0
  9. tests/distributed/test_distributed_utils.py +120 -0
  10. tests/distributed/test_tpu_connector.py +478 -0
  11. tests/e2e/__init__.py +13 -0
  12. tests/e2e/test_async_scheduler.py +211 -0
  13. tests/e2e/test_data_parallel.py +393 -0
  14. tests/e2e/test_local_disagg.py +257 -0
  15. tests/e2e/test_model_loader.py +268 -0
  16. tests/e2e/test_multi_modal_inference.py +111 -0
  17. tests/e2e/test_pipeline_parallel.py +265 -0
  18. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  19. tests/e2e/test_sampling_params.py +269 -0
  20. tests/e2e/test_speculative_decoding.py +291 -0
  21. tests/e2e/test_structured_decoding.py +46 -0
  22. tests/executors/__init__.py +13 -0
  23. tests/executors/test_ray_distributed_executor.py +199 -0
  24. tests/experimental/__init__.py +13 -0
  25. tests/experimental/test_llama3_jax_stashed.py +208 -0
  26. tests/kernels/__init__.py +13 -0
  27. tests/kernels/collectives/__init__.py +13 -0
  28. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  29. tests/kernels/fused_moe_v1_test.py +388 -0
  30. tests/kernels/gmm_test.py +205 -0
  31. tests/kernels/mla_v1_test.py +498 -0
  32. tests/kernels/quantized_matmul_kernel_test.py +159 -0
  33. tests/kernels/ragged_kv_cache_update_v2_test.py +248 -0
  34. tests/kernels/ragged_paged_attention_kernel_v2_test.py +414 -0
  35. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +565 -0
  36. tests/kernels/ragged_paged_attention_kernel_v3_test.py +520 -0
  37. tests/layers/__init__.py +13 -0
  38. tests/layers/common/__init__.py +13 -0
  39. tests/layers/common/test_attention_interface.py +156 -0
  40. tests/layers/common/test_quantization.py +149 -0
  41. tests/layers/jax/__init__.py +13 -0
  42. tests/layers/jax/attention/__init__.py +13 -0
  43. tests/layers/jax/attention/test_common_attention.py +103 -0
  44. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  45. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  46. tests/layers/jax/moe/__init__.py +13 -0
  47. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  48. tests/layers/jax/sample/__init__.py +13 -0
  49. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  50. tests/layers/jax/sample/test_sampling.py +115 -0
  51. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  52. tests/layers/jax/test_layers.py +155 -0
  53. tests/layers/jax/test_qwix.py +969 -0
  54. tests/layers/jax/test_rope.py +93 -0
  55. tests/layers/jax/test_sharding.py +159 -0
  56. tests/layers/jax/test_transformer_block.py +152 -0
  57. tests/layers/vllm/__init__.py +13 -0
  58. tests/layers/vllm/test_attention.py +363 -0
  59. tests/layers/vllm/test_awq.py +405 -0
  60. tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
  61. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +403 -0
  62. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +426 -0
  63. tests/layers/vllm/test_fp8.py +17 -0
  64. tests/layers/vllm/test_mxfp4.py +297 -0
  65. tests/layers/vllm/test_unquantized.py +621 -0
  66. tests/layers/vllm/utils.py +72 -0
  67. tests/lora/__init__.py +13 -0
  68. tests/lora/conftest.py +46 -0
  69. tests/lora/test_bgmv.py +57 -0
  70. tests/lora/test_layers.py +666 -0
  71. tests/lora/test_lora.py +147 -0
  72. tests/lora/test_lora_perf.py +67 -0
  73. tests/lora/utils.py +88 -0
  74. tests/models/__init__.py +13 -0
  75. tests/models/common/__init__.py +13 -0
  76. tests/models/common/test_model_loader.py +455 -0
  77. tests/models/jax/__init__.py +13 -0
  78. tests/models/jax/test_deepseek_v3.py +401 -0
  79. tests/models/jax/test_llama3.py +184 -0
  80. tests/models/jax/test_llama4.py +298 -0
  81. tests/models/jax/test_llama_eagle3.py +197 -0
  82. tests/models/jax/test_llama_guard_4.py +242 -0
  83. tests/models/jax/test_qwen2.py +172 -0
  84. tests/models/jax/test_qwen2_5_vl.py +606 -0
  85. tests/models/jax/test_qwen3.py +169 -0
  86. tests/models/jax/test_weight_loading.py +180 -0
  87. tests/models/jax/utils/__init__.py +13 -0
  88. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  89. tests/platforms/__init__.py +13 -0
  90. tests/platforms/test_tpu_platform.py +54 -0
  91. tests/runner/__init__.py +13 -0
  92. tests/runner/test_block_table.py +395 -0
  93. tests/runner/test_input_batch.py +226 -0
  94. tests/runner/test_kv_cache.py +220 -0
  95. tests/runner/test_kv_cache_manager.py +498 -0
  96. tests/runner/test_multimodal_manager.py +429 -0
  97. tests/runner/test_persistent_batch_manager.py +84 -0
  98. tests/runner/test_speculative_decoding_manager.py +368 -0
  99. tests/runner/test_structured_decoding_manager.py +220 -0
  100. tests/runner/test_tpu_runner.py +202 -0
  101. tests/runner/test_tpu_runner_dp.py +1033 -0
  102. tests/runner/test_tpu_runner_mesh.py +200 -0
  103. tests/runner/test_utils.py +411 -0
  104. tests/spec_decode/__init__.py +13 -0
  105. tests/spec_decode/test_eagle3.py +311 -0
  106. tests/test_base.py +215 -0
  107. tests/test_envs.py +280 -0
  108. tests/test_tpu_info.py +134 -0
  109. tests/test_utils.py +193 -0
  110. tests/worker/__init__.py +13 -0
  111. tests/worker/tpu_worker_test.py +414 -0
  112. tpu_inference/__init__.py +67 -0
  113. tpu_inference/core/__init__.py +13 -0
  114. tpu_inference/core/core_tpu.py +786 -0
  115. tpu_inference/core/disagg_executor.py +118 -0
  116. tpu_inference/core/disagg_utils.py +49 -0
  117. tpu_inference/core/sched/__init__.py +13 -0
  118. tpu_inference/core/sched/dp_scheduler.py +814 -0
  119. tpu_inference/distributed/__init__.py +13 -0
  120. tpu_inference/distributed/jax_parallel_state.py +81 -0
  121. tpu_inference/distributed/tpu_connector.py +732 -0
  122. tpu_inference/distributed/utils.py +112 -0
  123. tpu_inference/env_override.py +9 -0
  124. tpu_inference/envs.py +191 -0
  125. tpu_inference/executors/__init__.py +13 -0
  126. tpu_inference/executors/ray_distributed_executor.py +399 -0
  127. tpu_inference/experimental/__init__.py +13 -0
  128. tpu_inference/experimental/llama3_jax_stashed.py +272 -0
  129. tpu_inference/kernels/__init__.py +13 -0
  130. tpu_inference/kernels/collectives/__init__.py +13 -0
  131. tpu_inference/kernels/collectives/all_gather_matmul.py +741 -0
  132. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +65 -0
  133. tpu_inference/kernels/collectives/util.py +47 -0
  134. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  135. tpu_inference/kernels/flash_attention/kernel.py +772 -0
  136. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  137. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  138. tpu_inference/kernels/fused_moe/v1/kernel.py +1612 -0
  139. tpu_inference/kernels/megablox/__init__.py +13 -0
  140. tpu_inference/kernels/megablox/common.py +54 -0
  141. tpu_inference/kernels/megablox/gmm.py +646 -0
  142. tpu_inference/kernels/mla/__init__.py +13 -0
  143. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  144. tpu_inference/kernels/mla/v1/kernel.py +1340 -0
  145. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  146. tpu_inference/kernels/quantized_matmul/kernel.py +456 -0
  147. tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
  148. tpu_inference/kernels/quantized_matmul/util.py +58 -0
  149. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  150. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  151. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +876 -0
  152. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +288 -0
  153. tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
  154. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  155. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1594 -0
  156. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1586 -0
  157. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4460 -0
  158. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +548 -0
  159. tpu_inference/kernels/ragged_paged_attention/v3/util.py +65 -0
  160. tpu_inference/layers/__init__.py +13 -0
  161. tpu_inference/layers/common/__init__.py +13 -0
  162. tpu_inference/layers/common/attention_interface.py +403 -0
  163. tpu_inference/layers/common/attention_metadata.py +48 -0
  164. tpu_inference/layers/common/binary_search.py +295 -0
  165. tpu_inference/layers/common/quant_methods.py +23 -0
  166. tpu_inference/layers/common/quantization.py +270 -0
  167. tpu_inference/layers/common/sharding.py +600 -0
  168. tpu_inference/layers/jax/__init__.py +13 -0
  169. tpu_inference/layers/jax/attention/__init__.py +13 -0
  170. tpu_inference/layers/jax/attention/attention.py +268 -0
  171. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +547 -0
  172. tpu_inference/layers/jax/attention/gpt_oss_attention.py +275 -0
  173. tpu_inference/layers/jax/attention/llama4_attention.py +167 -0
  174. tpu_inference/layers/jax/base.py +165 -0
  175. tpu_inference/layers/jax/constants.py +101 -0
  176. tpu_inference/layers/jax/layers.py +315 -0
  177. tpu_inference/layers/jax/misc.py +30 -0
  178. tpu_inference/layers/jax/moe/__init__.py +13 -0
  179. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +615 -0
  180. tpu_inference/layers/jax/moe/gpt_oss_moe.py +199 -0
  181. tpu_inference/layers/jax/moe/moe.py +249 -0
  182. tpu_inference/layers/jax/pp_utils.py +53 -0
  183. tpu_inference/layers/jax/rope.py +294 -0
  184. tpu_inference/layers/jax/rope_interface.py +228 -0
  185. tpu_inference/layers/jax/sample/__init__.py +13 -0
  186. tpu_inference/layers/jax/sample/rejection_sampler.py +528 -0
  187. tpu_inference/layers/jax/sample/sampling.py +110 -0
  188. tpu_inference/layers/jax/sample/sampling_metadata.py +90 -0
  189. tpu_inference/layers/jax/transformer_block.py +121 -0
  190. tpu_inference/layers/vllm/__init__.py +13 -0
  191. tpu_inference/layers/vllm/attention.py +221 -0
  192. tpu_inference/layers/vllm/fused_moe.py +502 -0
  193. tpu_inference/layers/vllm/linear_common.py +221 -0
  194. tpu_inference/layers/vllm/quantization/__init__.py +55 -0
  195. tpu_inference/layers/vllm/quantization/awq.py +221 -0
  196. tpu_inference/layers/vllm/quantization/common.py +124 -0
  197. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  198. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +135 -0
  199. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +266 -0
  200. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  201. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +222 -0
  202. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +150 -0
  203. tpu_inference/layers/vllm/quantization/fp8.py +118 -0
  204. tpu_inference/layers/vllm/quantization/mxfp4.py +396 -0
  205. tpu_inference/layers/vllm/quantization/unquantized.py +416 -0
  206. tpu_inference/layers/vllm/sharding.py +244 -0
  207. tpu_inference/logger.py +10 -0
  208. tpu_inference/lora/__init__.py +13 -0
  209. tpu_inference/lora/torch_lora_ops.py +98 -0
  210. tpu_inference/lora/torch_punica_tpu.py +310 -0
  211. tpu_inference/models/__init__.py +13 -0
  212. tpu_inference/models/common/__init__.py +13 -0
  213. tpu_inference/models/common/model_loader.py +520 -0
  214. tpu_inference/models/jax/__init__.py +13 -0
  215. tpu_inference/models/jax/deepseek_v3.py +978 -0
  216. tpu_inference/models/jax/gpt_oss.py +508 -0
  217. tpu_inference/models/jax/jax_intermediate_tensor.py +93 -0
  218. tpu_inference/models/jax/llama3.py +436 -0
  219. tpu_inference/models/jax/llama4.py +643 -0
  220. tpu_inference/models/jax/llama_eagle3.py +350 -0
  221. tpu_inference/models/jax/llama_guard_4.py +375 -0
  222. tpu_inference/models/jax/qwen2.py +390 -0
  223. tpu_inference/models/jax/qwen2_5_vl.py +1232 -0
  224. tpu_inference/models/jax/qwen3.py +318 -0
  225. tpu_inference/models/jax/utils/__init__.py +13 -0
  226. tpu_inference/models/jax/utils/file_utils.py +110 -0
  227. tpu_inference/models/jax/utils/multi_modal_utils.py +177 -0
  228. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  229. tpu_inference/models/jax/utils/qwix/qwix_utils.py +713 -0
  230. tpu_inference/models/jax/utils/weight_utils.py +621 -0
  231. tpu_inference/models/vllm/__init__.py +13 -0
  232. tpu_inference/models/vllm/vllm_model_wrapper.py +307 -0
  233. tpu_inference/models/vllm/vllm_model_wrapper_context.py +59 -0
  234. tpu_inference/platforms/__init__.py +16 -0
  235. tpu_inference/platforms/tpu_platform.py +258 -0
  236. tpu_inference/runner/__init__.py +13 -0
  237. tpu_inference/runner/block_table.py +122 -0
  238. tpu_inference/runner/compilation_manager.py +890 -0
  239. tpu_inference/runner/input_batch.py +435 -0
  240. tpu_inference/runner/kv_cache.py +166 -0
  241. tpu_inference/runner/kv_cache_manager.py +508 -0
  242. tpu_inference/runner/lora_utils.py +106 -0
  243. tpu_inference/runner/multimodal_manager.py +231 -0
  244. tpu_inference/runner/persistent_batch_manager.py +296 -0
  245. tpu_inference/runner/speculative_decoding_manager.py +262 -0
  246. tpu_inference/runner/structured_decoding_manager.py +101 -0
  247. tpu_inference/runner/tpu_runner.py +1768 -0
  248. tpu_inference/runner/utils.py +426 -0
  249. tpu_inference/spec_decode/__init__.py +13 -0
  250. tpu_inference/spec_decode/jax/__init__.py +13 -0
  251. tpu_inference/spec_decode/jax/eagle3.py +430 -0
  252. tpu_inference/tpu_info.py +92 -0
  253. tpu_inference/utils.py +345 -0
  254. tpu_inference/worker/__init__.py +13 -0
  255. tpu_inference/worker/tpu_worker.py +468 -0
  256. tpu_inference-0.12.0.dev20251222.dist-info/METADATA +106 -0
  257. tpu_inference-0.12.0.dev20251222.dist-info/RECORD +260 -0
  258. tpu_inference-0.12.0.dev20251222.dist-info/WHEEL +5 -0
  259. tpu_inference-0.12.0.dev20251222.dist-info/licenses/LICENSE +201 -0
  260. tpu_inference-0.12.0.dev20251222.dist-info/top_level.txt +2 -0
tpu_inference/kernels/mla/v1/kernel.py (new file)
@@ -0,0 +1,1340 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TPU-Friendly and Data-Movement-Friendly MLA Ragged Paged Attention kernel."""

import functools

import jax
import jax.numpy as jnp
from jax import lax
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu

from tpu_inference.kernels.ragged_paged_attention.v3.util import (
    align_to, cdiv, get_dtype_packing)

DEFAULT_MASK_VALUE = -0.7 * float(jnp.finfo(jnp.dtype("float32")).max)

DEFAULT_VMEM_LIMIT_BYTES = 100 * 1024 * 1024


def get_kv_cache_shape(
    total_num_pages,
    page_size,
    kv_dim,
    kv_dtype,
):
    kv_packing = get_dtype_packing(kv_dtype)
    return (
        total_num_pages,
        align_to(page_size, kv_packing) // kv_packing,
        kv_packing,
        align_to(kv_dim, 128),
    )


@functools.partial(
    jax.jit,
    donate_argnames=("cache_kv", ),
)
def update_kv_cache(
    new_kv_c: jax.Array,  # [num_tokens, actual_lkv_dim]
    new_k_pe: jax.Array,  # [num_tokens, actual_r_dim]
    cache_kv: jax.
    Array,  # [total_num_pages, page_size_per_kv_packing, kv_packing, lkv_dim+r_dim]
    kv_lens: jax.Array,  # i32[max_num_seqs]
    page_indices: jax.Array,  # i32[max_num_seqs * pages_per_seq]
    cu_q_lens: jax.Array,  # i32[max_num_seqs + 1]
    distribution: jax.Array,  # i32[3]
) -> jax.Array:
    """Update KV cache with new tokens."""
    actual_r_dim = new_k_pe.shape[-1]
    r_dim = align_to(actual_r_dim, 128)
    if actual_r_dim != r_dim:
        new_k_pe = jnp.pad(new_k_pe, ((0, 0), (0, r_dim - actual_r_dim)),
                           constant_values=0)
    actual_lkv_dim = new_kv_c.shape[-1]
    lkv_dim = align_to(actual_lkv_dim, 128)
    if actual_lkv_dim != lkv_dim:
        new_kv_c = jnp.pad(new_kv_c, ((0, 0), (0, lkv_dim - actual_lkv_dim)),
                           constant_values=0)
    kv_dim = r_dim + lkv_dim
    _, page_size_per_kv_packing, kv_packing, cache_kv_dim = cache_kv.shape
    assert kv_dim == cache_kv_dim
    page_size = page_size_per_kv_packing * kv_packing

    max_num_seqs = kv_lens.shape[0]
    num_page_indices = page_indices.shape[0]
    pages_per_seq = num_page_indices // max_num_seqs

    def seq_loop_body(i, cache_kv):
        q_start, q_end = cu_q_lens[i], cu_q_lens[i + 1]
        q_len = q_end - q_start
        kv_len = kv_lens[i]

        def token_loop_body(j, cache_kv_):
            token_idx_in_seq = kv_len - q_len + j
            page_num_in_seq = token_idx_in_seq // page_size
            page_indices_start = i * pages_per_seq
            page_idx = page_indices[page_indices_start + page_num_in_seq]
            row = (token_idx_in_seq % page_size) // kv_packing
            col = (token_idx_in_seq % page_size) % kv_packing

            cache_kv_ = cache_kv_.at[page_idx, row, col,
                                     ..., :lkv_dim].set(new_kv_c[q_start + j])
            cache_kv_ = cache_kv_.at[page_idx, row, col, ...,
                                     lkv_dim:].set(new_k_pe[q_start + j])
            return cache_kv_

        return lax.fori_loop(0, q_len, token_loop_body, cache_kv)

    cache_kv = lax.fori_loop(0, distribution[-1], seq_loop_body, cache_kv)

    return cache_kv
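
# --- Illustrative sketch (not part of the released file) ---
# How a flat token position maps into the packed cache layout used by
# update_kv_cache. A minimal worked example, assuming page_size=16 and a
# bfloat16 cache (kv_packing=2), so each page holds 8 rows of 2 packed tokens:
#
#   token_idx_in_seq = 21
#   page_num_in_seq  = 21 // 16 = 1      # second page of this sequence
#   row              = (21 % 16) // 2 = 2
#   col              = (21 % 16) % 2 = 1
#
# i.e. token 21 of sequence i lands at
# cache_kv[page_indices[i * pages_per_seq + 1], 2, 1, :].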


def ref_mla_ragged_paged_attention(
    ql_nope: jax.Array,  # [num_tokens, actual_num_q_heads, actual_lkv_dim]
    q_pe: jax.Array,  # [num_tokens, actual_num_q_heads, actual_r_dim]
    new_kv_c: jax.Array,  # [num_tokens, actual_lkv_dim]
    new_k_pe: jax.Array,  # [num_tokens, actual_r_dim]
    cache_kv: jax.
    Array,  # [total_num_pages, page_size_per_kv_packing, kv_packing, lkv_dim+r_dim]
    kv_lens: jax.Array,  # i32[max_num_seqs]
    page_indices: jax.Array,  # i32[max_num_seqs * pages_per_seq]
    cu_q_lens: jax.Array,  # i32[max_num_seqs + 1]
    distribution: jax.Array,  # i32[3]
    *,
    sm_scale: float = 1.0,
    sliding_window: int | None = None,
    soft_cap: float | None = None,
    mask_value: float | None = DEFAULT_MASK_VALUE,
):

    if mask_value is None:
        mask_value = DEFAULT_MASK_VALUE

    dynamic_validate_inputs(
        ql_nope,
        q_pe,
        new_kv_c,
        new_k_pe,
        cache_kv,
        kv_lens,
        page_indices,
        cu_q_lens,
        distribution,
        sm_scale=sm_scale,
        sliding_window=sliding_window,
        soft_cap=soft_cap,
        mask_value=mask_value,
    )

    updated_cache_kv = update_kv_cache(
        new_kv_c,
        new_k_pe,
        cache_kv,
        kv_lens,
        page_indices,
        cu_q_lens,
        distribution,
    )
    # Pad ql_nope and q_pe to make the last dimension 128-aligned.
    actual_lkv_dim = ql_nope.shape[-1]
    lkv_dim = align_to(actual_lkv_dim, 128)
    if lkv_dim != actual_lkv_dim:
        ql_nope = jnp.pad(
            ql_nope,
            ((0, 0), (0, 0), (0, lkv_dim - actual_lkv_dim)),
            constant_values=0,
        )
    actual_r_dim = q_pe.shape[-1]
    r_dim = align_to(actual_r_dim, 128)
    if actual_r_dim != r_dim:
        q_pe = jnp.pad(q_pe, ((0, 0), (0, 0), (0, r_dim - actual_r_dim)),
                       constant_values=0)

    q = jnp.concatenate([ql_nope, q_pe], axis=-1)
    max_num_seqs = kv_lens.shape[0]
    num_page_indices = page_indices.shape[0]
    assert num_page_indices % max_num_seqs == 0
    pages_per_seq = num_page_indices // max_num_seqs

    total_num_pages, page_size_per_kv_packing, kv_packing, _ = updated_cache_kv.shape
    page_size = page_size_per_kv_packing * kv_packing
    assert lkv_dim == ql_nope.shape[-1]
    assert r_dim == q_pe.shape[-1]
    assert lkv_dim + r_dim == updated_cache_kv.shape[-1]

    kv_c_cache = updated_cache_kv[..., :lkv_dim].reshape(
        total_num_pages, page_size, lkv_dim)
    k_pe_cache = updated_cache_kv[...,
                                  lkv_dim:].reshape(total_num_pages, page_size,
                                                    r_dim)

    outputs = []

    for i in range(distribution[-1]):
        q_start, q_end = cu_q_lens[i], cu_q_lens[i + 1]
        q_len = q_end - q_start
        kv_len = kv_lens[i]

        q_i = q[q_start:q_end]  # [q_len, actual_num_q_heads, lkv_dim+r_dim]

        indices_start = i * pages_per_seq
        num_pages_i = cdiv(kv_len, page_size)
        indices_end = indices_start + num_pages_i
        indices = page_indices[indices_start:indices_end]

        # Gather paged kv_c and k_pe.
        gathered_kv_c = kv_c_cache[
            indices]  # [num_pages_i, page_size, lkv_dim]
        gathered_k_pe = k_pe_cache[indices]  # [num_pages_i, page_size, r_dim]

        # Flatten pages to sequence.
        flat_kv_c = gathered_kv_c.reshape(
            -1, lkv_dim)  # [num_pages_i * page_size, lkv_dim]
        flat_k_pe = gathered_k_pe.reshape(
            -1, r_dim)  # [num_pages_i * page_size, r_dim]

        # Prepare k and v for attention.
        k_i = jnp.concatenate([flat_kv_c[:kv_len], flat_k_pe[:kv_len]],
                              axis=-1)  # [kv_len, lkv_dim+r_dim]
        v_i = flat_kv_c[:kv_len]  # [kv_len, lkv_dim]

        # MQA attention:
        # q:[q_len, actual_num_q_heads, lkv_dim+r_dim]
        # k:[kv_len, lkv_dim+r_dim]
        # v:[kv_len, lkv_dim]
        # attn: [actual_num_q_heads, q_len, kv_len]
        attn = jnp.einsum("qnh,kh->nqk",
                          q_i,
                          k_i,
                          preferred_element_type=jnp.float32)
        attn *= sm_scale

        # Causal mask
        q_span = kv_len - q_len + jax.lax.broadcasted_iota(
            jnp.int32, attn.shape, 1)
        kv_span = jax.lax.broadcasted_iota(jnp.int32, attn.shape, 2)
        mask = q_span < kv_span
        if sliding_window is not None:
            mask = jnp.logical_or(mask, q_span - sliding_window >= kv_span)
        if soft_cap is not None:
            attn = soft_cap * jnp.tanh(attn / soft_cap)
        attn = jnp.where(mask, mask_value, attn)
        attn = jax.nn.softmax(attn, axis=-1).astype(v_i.dtype)

        # out_i: [q_len, actual_num_q_heads, lkv_dim]
        out_i = jnp.einsum("nqk,kl->qnl", attn, v_i).astype(q_i.dtype)
        outputs.append(out_i)

    return (
        jnp.concatenate(outputs, axis=0),
        updated_cache_kv,
    )
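
# --- Illustrative sketch (not part of the released file) ---
# The causal mask above compares absolute token positions. A minimal worked
# example with toy sizes q_len=2 and kv_len=4:
#
#   q_span  = kv_len - q_len + iota over the q axis  = [[2, 2, 2, 2],
#                                                       [3, 3, 3, 3]]
#   kv_span = iota over the kv axis                  = [[0, 1, 2, 3],
#                                                       [0, 1, 2, 3]]
#   mask    = q_span < kv_span   (True marks *disallowed*, i.e. future, keys)
#           = [[F, F, F, T],
#              [F, F, F, F]]
#
# so the first query token attends to kv positions 0..2 and the second
# (newest) token attends to all 4.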


# Expect to run this validation during runtime.
def dynamic_validate_inputs(
    ql_nope: jax.Array,  # [max_num_tokens, actual_num_q_heads, actual_lkv_dim]
    q_pe: jax.Array,  # [max_num_tokens, actual_num_q_heads, actual_r_dim]
    new_kv_c: jax.Array,  # [max_num_tokens, actual_lkv_dim]
    new_k_pe: jax.Array,  # [max_num_tokens, actual_r_dim]
    cache_kv: jax.
    Array,  # [total_num_pages, page_size_per_kv_packing, kv_packing, lkv_dim+r_dim]
    kv_lens: jax.Array,  # i32[max_num_seqs]
    page_indices: jax.Array,  # i32[max_num_seqs * pages_per_seq]
    cu_q_lens: jax.Array,  # i32[max_num_seqs + 1]
    distribution: jax.Array,  # i32[3]
    *,
    sm_scale: float = 1.0,
    sliding_window: int | None = None,
    soft_cap: float | None = None,
    mask_value: float | None = DEFAULT_MASK_VALUE,
    # Kernel optimization params.
    chunk_prefill_size: int | None = None,
    # Kernel tuning params.
    num_kv_pages_per_block: int | None = None,
    num_queries_per_block: int | None = None,
    vmem_limit_bytes: int | None = None,
    # Debug params.
    debug_mode: bool = False,
):
    """Validate inputs to the MLA RPA kernel dynamically."""
    static_validate_inputs(
        ql_nope,
        q_pe,
        new_kv_c,
        new_k_pe,
        cache_kv,
        kv_lens,
        page_indices,
        cu_q_lens,
        distribution,
        sm_scale=sm_scale,
        sliding_window=sliding_window,
        soft_cap=soft_cap,
        mask_value=mask_value,
        chunk_prefill_size=chunk_prefill_size,
        num_kv_pages_per_block=num_kv_pages_per_block,
        num_queries_per_block=num_queries_per_block,
        vmem_limit_bytes=vmem_limit_bytes,
        debug_mode=debug_mode,
    )
    max_num_tokens = ql_nope.shape[0]
    total_num_pages = cache_kv.shape[0]
    _, page_size_per_kv_packing, kv_packing, _ = cache_kv.shape
    page_size = page_size_per_kv_packing * kv_packing
    max_num_seqs = kv_lens.shape[0]
    num_page_indices = page_indices.shape[0]
    assert num_page_indices % max_num_seqs == 0
    pages_per_seq = num_page_indices // max_num_seqs

    i, j, k = distribution
    if not (0 <= i <= j <= k):
        raise ValueError(f"Invalid distribution: {distribution=}")

    if k > max_num_seqs:
        raise ValueError(f"num_seqs={k} must be <= {max_num_seqs=}")

    if cu_q_lens[k] > max_num_tokens:
        raise ValueError(
            f"Total q tokens {cu_q_lens[k]} must be <= {max_num_tokens=}.")
    for seq_idx in range(k):
        q_len = cu_q_lens[seq_idx + 1] - cu_q_lens[seq_idx]
        kv_len = kv_lens[seq_idx]
        if not (0 < q_len <= kv_len):
            raise ValueError(
                f"Require 0 < {q_len=} <= {kv_len=} at sequence {seq_idx}.")
        page_cnt = cdiv(kv_len, page_size)
        if page_cnt > pages_per_seq:
            raise ValueError(
                f"Require {page_cnt=} <= {pages_per_seq=} at sequence {seq_idx} where"
                f" {kv_len=} and {page_size=}.")
        for p in range(page_cnt):
            page_idx = page_indices[seq_idx * pages_per_seq + p]
            if not (0 <= page_idx < total_num_pages):
                raise ValueError(
                    f"Require 0 <= {page_idx=} < {total_num_pages=} at sequence"
                    f" {seq_idx} where {kv_len=} and {page_size=}.")

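# --- Illustrative sketch (not part of the released file) ---
# A minimal, hand-built metadata set that passes the checks above, assuming
# toy values page_size=16, pages_per_seq=4, max_num_seqs=2:
#
#   kv_lens      = jnp.array([20, 7], jnp.int32)      # seq0: 20 kv, seq1: 7 kv
#   cu_q_lens    = jnp.array([0, 1, 8], jnp.int32)    # seq0 decodes 1 token,
#                                                     # seq1 prefills 7
#   page_indices = jnp.array([3, 5, 0, 0,             # seq0 uses pages 3 and 5
#                             8, 0, 0, 0], jnp.int32)  # seq1 fits in page 8
#   distribution = jnp.array([1, 2, 2], jnp.int32)    # [decode_end,
#                                                     #  prefill_end, num_seqs]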

# Expect to run this validation during compile time.
def static_validate_inputs(
    ql_nope: jax.Array,  # [max_num_tokens, actual_num_q_heads, actual_lkv_dim]
    q_pe: jax.Array,  # [max_num_tokens, actual_num_q_heads, actual_r_dim]
    new_kv_c: jax.Array,  # [max_num_tokens, actual_lkv_dim]
    new_k_pe: jax.Array,  # [max_num_tokens, actual_r_dim]
    cache_kv: jax.
    Array,  # [total_num_pages, page_size_per_kv_packing, kv_packing, lkv_dim+r_dim]
    kv_lens: jax.Array,  # i32[max_num_seqs]
    page_indices: jax.Array,  # i32[max_num_seqs * pages_per_seq]
    cu_q_lens: jax.Array,  # i32[max_num_seqs + 1]
    distribution: jax.Array,  # i32[3]
    *,
    sm_scale: float = 1.0,
    sliding_window: int | None = None,
    soft_cap: float | None = None,
    mask_value: float | None = DEFAULT_MASK_VALUE,
    # Kernel optimization params.
    chunk_prefill_size: int | None = None,
    # Kernel tuning params.
    num_kv_pages_per_block: int | None = None,
    num_queries_per_block: int | None = None,
    vmem_limit_bytes: int | None = None,
    # Debug params.
    debug_mode: bool = False,
):
    """Validate inputs to the MLA RPA kernel statically."""
    if len(ql_nope.shape) != 3:
        raise ValueError(f"Expected 3D array for {ql_nope.shape=}")
    if len(q_pe.shape) != 3:
        raise ValueError(f"Expected 3D array for {q_pe.shape=}")
    if len(new_kv_c.shape) != 2:
        raise ValueError(f"Expected 2D array for {new_kv_c.shape=}")
    if len(new_k_pe.shape) != 2:
        raise ValueError(f"Expected 2D array for {new_k_pe.shape=}")

    if ql_nope.shape[:2] != q_pe.shape[:2]:
        raise ValueError(
            f"Expected {ql_nope.shape[:2]=} to be equal to {q_pe.shape[:2]=}")
    if ql_nope.shape[0] != new_kv_c.shape[0]:
        raise ValueError(
            f"Expected {ql_nope.shape[0]=} to be equal to {new_kv_c.shape[0]=}"
        )
    if new_kv_c.shape[0] != new_k_pe.shape[0]:
        raise ValueError(
            f"Expected {new_kv_c.shape[0]=} to be equal to {new_k_pe.shape[0]=}"
        )
    if ql_nope.shape[2] != new_kv_c.shape[1]:
        raise ValueError(
            f"Expected {ql_nope.shape[2]=} to be equal to {new_kv_c.shape[1]=}"
        )
    if q_pe.shape[2] != new_k_pe.shape[1]:
        raise ValueError(
            f"Expected {q_pe.shape[2]=} to be equal to {new_k_pe.shape[1]=}")

    actual_lkv_dim = ql_nope.shape[2]
    actual_r_dim = q_pe.shape[2]
    lkv_dim = align_to(actual_lkv_dim, 128)
    r_dim = align_to(actual_r_dim, 128)

    (
        _,
        page_size_per_kv_packing,
        kv_packing,
        kv_dim,
    ) = cache_kv.shape

    if lkv_dim + r_dim != kv_dim:
        raise ValueError(
            f"Expected {lkv_dim=} + {r_dim=} to be equal to {kv_dim=}")

    if not (cache_kv.dtype == new_kv_c.dtype):
        raise ValueError(
            f"Expected {cache_kv.dtype=} to be equal to {new_kv_c.dtype=}.")
    if not (cache_kv.dtype == new_k_pe.dtype):
        raise ValueError(
            f"Expected {cache_kv.dtype=} to be equal to {new_k_pe.dtype=}.")

    # Integer kv quantization is currently not supported.
    if not jnp.issubdtype(cache_kv.dtype, jnp.floating):
        raise ValueError(f"Expected {cache_kv.dtype=} to be a floating point.")

    if kv_packing != get_dtype_packing(cache_kv.dtype):
        raise ValueError(
            f"{kv_packing=} does not match with {cache_kv.dtype=}")

    if not (jnp.int32 == kv_lens.dtype == page_indices.dtype == cu_q_lens.dtype
            == distribution.dtype):
        raise ValueError(
            f"Expected int32 dtype for {kv_lens.dtype=}, {page_indices.dtype=},"
            f" {cu_q_lens.dtype=}, {distribution.dtype=}")

    if not (len(kv_lens.shape) == len(page_indices.shape) == len(
            cu_q_lens.shape) == 1):
        raise ValueError(
            f"Expected 1D array for {kv_lens.shape=}, {page_indices.shape=},"
            f" {cu_q_lens.shape=}")

    max_num_seqs = kv_lens.shape[0]
    num_page_indices = page_indices.shape[0]
    if num_page_indices % max_num_seqs != 0:
        raise ValueError(
            f"Expected {num_page_indices=} to be divisible by {max_num_seqs=}."
        )
    if cu_q_lens.shape != (max_num_seqs + 1, ):
        raise ValueError(
            f"Expected {cu_q_lens.shape=} to be ({max_num_seqs + 1},).")
    if distribution.shape != (3, ):
        raise ValueError(f"Expected {distribution.shape=} to be (3,).")

    page_size = page_size_per_kv_packing * kv_packing
    if page_size % kv_packing != 0:
        raise ValueError(f"{page_size=} must be divisible by {kv_packing=}.")
    if sliding_window is not None and sliding_window <= 0:
        raise ValueError(f"{sliding_window=} must be positive.")
    if soft_cap is not None and soft_cap == 0.0:
        raise ValueError(f"{soft_cap=} must not be 0.0.")
    if chunk_prefill_size is not None and chunk_prefill_size <= 0:
        raise ValueError(f"{chunk_prefill_size=} must be positive.")
    if num_kv_pages_per_block is not None:
        if num_kv_pages_per_block <= 0:
            raise ValueError(f"{num_kv_pages_per_block=} must be positive.")
    if num_queries_per_block is not None:
        if num_queries_per_block <= 0:
            raise ValueError(f"{num_queries_per_block=} must be positive.")
    if vmem_limit_bytes is not None and vmem_limit_bytes <= 0:
        raise ValueError(f"{vmem_limit_bytes=} must be positive.")

    # No constraints for the following inputs.
    del sm_scale
    del mask_value
    del debug_mode

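# --- Illustrative sketch (not part of the released file) ---
# get_dtype_packing reports how many values share one 32-bit word (1 for
# float32, 2 for bfloat16, 4 for an 8-bit float; see `bitwidth = 32 //
# kv_packing` in the kernel below). Assuming bfloat16 and DeepSeek-style MLA
# dims (lkv_dim=512, rope=64, so the padded r_dim is 128), the cache shape is:
#
#   get_kv_cache_shape(total_num_pages=64, page_size=16,
#                      kv_dim=512 + 128, kv_dtype=jnp.bfloat16)
#   == (64, 8, 2, 640)   # 16 tokens/page stored as 8 rows x 2 packed values
#
# which satisfies the `lkv_dim + r_dim == kv_dim` check above.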

def _mla_ragged_paged_attention_kernel(
    # Prefetch
    kv_lens_ref,  # [max_num_seqs]
    page_indices_ref,  # [max_num_seqs * pages_per_seq]
    cu_q_lens_ref,  # [max_num_seqs + 1]
    # TODO(jevinjiang): merge these into one so we can save SMEM.
    distribution_ref,  # [3] (decode_end, prefill_end, mixed_end)
    sem_ids_ref,  # [3] (bq_sem_idx, bkv_sem_idx, bo_sem_idx)
    bo_ids_ref,  # [4] (bo_sem_0_seq_idx, bo_sem_1_seq_idx, bo_sem_0_bo_idx, bo_sem_1_bo_idx)
    bkv_update_ids_ref,  # [6] (bkv_sem_0_seq_idx, bkv_sem_1_seq_idx, bkv_sem_0_offset, bkv_sem_1_offset, bkv_sem_0_sz, bkv_sem_1_sz)
    # Input
    ql_nope_hbm_ref,  # [max_num_tokens, num_q_heads_per_q_packing, q_packing, lkv_dim]
    q_pe_hbm_ref,  # [max_num_tokens, num_q_heads_per_q_packing, q_packing, r_dim]
    new_kv_c_hbm_ref,  # [max_num_tokens_per_kv_packing, kv_packing, lkv_dim]
    new_k_pe_hbm_ref,  # [max_num_tokens_per_kv_packing, kv_packing, r_dim]
    cache_kv_hbm_ref,  # [total_num_pages, page_size_per_kv_packing, kv_packing, align_to(lkv_dim + r_dim, 128)]
    # Output
    o_hbm_ref,  # [max_num_tokens, num_q_heads_per_q_packing, q_packing, lkv_dim]
    updated_cache_kv_hbm_ref,  # [total_num_pages, page_size_per_kv_packing, kv_packing, align_to(lkv_dim + r_dim, 128)]
    # Scratch
    bkvc_x2_ref,  # [2, bkv_sz_per_kv_packing, kv_packing, lkv_dim]
    bkpe_x2_ref,  # [2, bkv_sz_per_kv_packing, kv_packing, r_dim]
    bq_nope_x2_ref,  # [2, bq_sz, num_q_heads_per_q_packing, q_packing, lkv_dim]
    bq_rope_x2_ref,  # [2, bq_sz, num_q_heads_per_q_packing, q_packing, r_dim]
    bo_x2_ref,  # [2, bq_sz, num_q_heads_per_q_packing, q_packing, lkv_dim]
    sems,  # [4, 2]
    l_ref,  # [bq_sz * num_q_heads, 128]
    m_ref,  # [bq_sz * num_q_heads, 128]
    acc_ref,  # [bq_sz * num_q_heads, lkv_dim]
    *,
    sm_scale: float,
    sliding_window: int | None = None,
    soft_cap: float | None = None,
    mask_value: float = DEFAULT_MASK_VALUE,
    q_scale: float | None = None,
    k_scale: float | None = None,
    v_scale: float | None = None,
    chunk_prefill_size: int | None = None,
    bkv_p,
    bq_sz,
    debug_mode: bool = False,
):
    assert ql_nope_hbm_ref.shape == o_hbm_ref.shape
    # Validation checks on the dimensions.
    nope_dim = ql_nope_hbm_ref.shape[-1]
    pe_dim = q_pe_hbm_ref.shape[-1]
    assert nope_dim + pe_dim == cache_kv_hbm_ref.shape[-1]

    _, num_q_heads_per_q_packing, q_packing, lkv_dim = ql_nope_hbm_ref.shape
    r_dim = q_pe_hbm_ref.shape[-1]
    num_q_heads = num_q_heads_per_q_packing * q_packing
    total_num_pages, page_size_per_kv_packing, kv_packing, _ = (
        cache_kv_hbm_ref.shape)
    max_num_seqs = kv_lens_ref.shape[0]
    num_page_indices = page_indices_ref.shape[0]

    assert num_page_indices % max_num_seqs == 0
    pages_per_seq = num_page_indices // max_num_seqs
    q_dtype = ql_nope_hbm_ref.dtype
    # Validate against the KV dtype.
    kv_dtype = cache_kv_hbm_ref.dtype
    assert q_pe_hbm_ref.dtype == q_dtype
    assert o_hbm_ref.dtype == q_dtype
    assert get_dtype_packing(q_dtype) == q_packing
    assert get_dtype_packing(kv_dtype) == kv_packing
    assert lkv_dim % 128 == 0
    assert r_dim % 128 == 0
    bkv_sz_per_kv_packing = bkv_p * page_size_per_kv_packing
    bkv_sz = bkv_sz_per_kv_packing * kv_packing
    page_size = page_size_per_kv_packing * kv_packing
    seq_idx = pl.program_id(0)
    num_seqs = pl.num_programs(0)
    decode_end = distribution_ref[0]
    prefill_end = distribution_ref[1]
    mixed_end = distribution_ref[2]

    q_start = cu_q_lens_ref[seq_idx]
    q_end = cu_q_lens_ref[seq_idx + 1]
    q_len = q_end - q_start
    kv_len = kv_lens_ref[seq_idx]

    def debug_print(msg, *args):
        if debug_mode:
            pl.debug_print(msg, *args)

    debug_print("[RPA debug] ======= In loop seq_idx={}", seq_idx)
    debug_print("[RPA debug] num_seqs={}", num_seqs)
    debug_print("[RPA debug] decode_end={}", decode_end)
    debug_print("[RPA debug] prefill_end={}", prefill_end)
    debug_print("[RPA debug] mixed_end={}", mixed_end)
    debug_print("[RPA debug] bkv_p={}", bkv_p)
    debug_print("[RPA debug] page_size={}", page_size)
    debug_print("[RPA debug] pages_per_seq={}", pages_per_seq)
    debug_print("[RPA debug] bkv_sz_per_kv_packing={}", bkv_sz_per_kv_packing)
    debug_print("[RPA debug] bq_sz={}", bq_sz)
    debug_print("[RPA debug] q_start={}", q_start)
    debug_print("[RPA debug] q_end={}", q_end)
    debug_print("[RPA debug] q_len={}", q_len)
    debug_print("[RPA debug] kv_len={}", kv_len)

    def flash_attention(
        ql_nope,  # [actual_bq_sz * num_q_heads, lkv_dim]
        q_pe,  # [actual_bq_sz * num_q_heads, r_dim]
        kv_c,  # [bkv_sz, lkv_dim] <- Corresponds to data from bkvc_x2_ref
        k_pe,  # [bkv_sz, r_dim] <- Corresponds to data from bkpe_x2_ref
        *,
        bq_idx,
        bkv_idx,
    ):
        assert len(ql_nope.shape) == 2
        assert len(q_pe.shape) == 2
        assert len(kv_c.shape) == 2
        assert len(k_pe.shape) == 2
        assert ql_nope.shape[0] % num_q_heads == 0
        assert ql_nope.shape[0] == q_pe.shape[0]
        assert q_pe.shape[0] % bq_sz == 0
        assert ql_nope.shape[1] == lkv_dim
        assert q_pe.shape[1] == r_dim
        assert kv_c.shape == (bkv_sz, lkv_dim)
        assert k_pe.shape == (bkv_sz, r_dim)
        head_l_ref = l_ref.at[:ql_nope.shape[0]]
        head_m_ref = m_ref.at[:ql_nope.shape[0]]
        head_acc_ref = acc_ref.at[:ql_nope.shape[0]]

        def load_with_init(ref, init_val):
            return jnp.where(bkv_idx == 0, jnp.full_like(ref, init_val),
                             ref[...])

        # Follow FlashAttention-2 forward pass.
        s_nope = jnp.einsum("nd,md->nm",
                            ql_nope,
                            kv_c,
                            preferred_element_type=jnp.float32)
        s_pe = jnp.einsum("nd,md->nm",
                          q_pe,
                          k_pe,
                          preferred_element_type=jnp.float32)
        s = s_nope + s_pe
        s *= sm_scale
        if k_scale is not None:
            s *= k_scale
        if q_scale is not None:
            s *= q_scale

        q_span = (kv_len - q_len + bq_idx * bq_sz +
                  lax.broadcasted_iota(jnp.int32, s.shape, 0) // num_q_heads)
        k_span = bkv_idx * bkv_sz + lax.broadcasted_iota(jnp.int32, s.shape, 1)
        mask = q_span < k_span
        # TODO(jevinjiang, xiowei): reduce pages_per_seq based on sliding_window.
        if sliding_window is not None:
            mask = jnp.logical_or(mask, q_span - sliding_window >= k_span)

        if soft_cap is not None:
            s = soft_cap * jnp.tanh(s / soft_cap)
        s = jnp.where(mask, mask_value, s)
        s_rowmax = jnp.max(s, axis=1, keepdims=True)
        m_prev = load_with_init(head_m_ref, -jnp.inf)
        m_curr = jnp.maximum(m_prev, s_rowmax)
        head_m_ref[...] = m_curr
        p = jnp.exp(s - broadcast_minor(m_curr, s.shape))

        pv = jnp.einsum("nm,md->nd",
                        p,
                        kv_c,
                        preferred_element_type=jnp.float32)
        if v_scale is not None:
            pv *= v_scale

        p_rowsum = jnp.sum(p, axis=1, keepdims=True)
        exp_m_diff = jnp.exp(m_prev - m_curr)
        l_prev = load_with_init(head_l_ref, 0.0)
        l_curr = exp_m_diff * l_prev + p_rowsum
        head_l_ref[...] = l_curr
        o_prev = load_with_init(head_acc_ref, 0.0)
        o_curr = broadcast_minor(exp_m_diff, o_prev.shape) * o_prev + pv
        head_acc_ref[...] = o_curr

    def _async_copy(src, dst, sem, wait):
        if debug_mode:
            # Skip DMA if debug mode is enabled.
            return
        cp = pltpu.make_async_copy(src, dst, sem)
        if wait:
            cp.wait()
        else:
            cp.start()

    def _fetch_bkv(seq_idx, bkv_idx, bkv_sem_idx, *, wait=False):
        sem = sems.at[0, bkv_sem_idx]
        bkvc_vmem_ref = bkvc_x2_ref.at[bkv_sem_idx]
        bkvpe_vmem_ref = bkpe_x2_ref.at[bkv_sem_idx]
        reshaped_cache_hbm_ref = cache_kv_hbm_ref.reshape(
            total_num_pages * page_size_per_kv_packing,
            *cache_kv_hbm_ref.shape[2:],
        )
        kv_len = kv_lens_ref[seq_idx]
        kv_len_start = bkv_idx * bkv_sz
        kv_p_start = bkv_idx * bkv_p

        kv_left = kv_len - kv_len_start
        kv_left_per_kv_packing = cdiv(kv_left, kv_packing)
        page_indices_offset = seq_idx * pages_per_seq + kv_p_start

        debug_print(
            "[RPA debug]"
            f" -----------{'wait' if wait else 'start'}_fetch_bkv-----------")
        debug_print("[RPA debug] seq_idx={}", seq_idx)
        debug_print("[RPA debug] bkv_idx={}", bkv_idx)
        debug_print("[RPA debug] bkv_sem_idx={}", bkv_sem_idx)
        debug_print("[RPA debug] kv_len_start={}", kv_len_start)
        debug_print("[RPA debug] kv_p_start={}", kv_p_start)
        debug_print("[RPA debug] kv_left={}", kv_left)
        debug_print("[RPA debug] page_indices_offset={}", page_indices_offset)

        # Fetch effective kv from kv cache.
        def loop_body(i, _):
            sz_per_kv_packing = jnp.minimum(
                page_size_per_kv_packing,
                kv_left_per_kv_packing - i * page_size_per_kv_packing,
            )
            _async_copy(
                reshaped_cache_hbm_ref.at[pl.ds(
                    page_indices_ref[page_indices_offset + i] *
                    page_size_per_kv_packing,
                    sz_per_kv_packing,
                ), ..., :nope_dim],
                bkvc_vmem_ref.at[pl.ds(i * page_size_per_kv_packing,
                                       sz_per_kv_packing)],
                sem,
                wait,
            )
            _async_copy(
                reshaped_cache_hbm_ref.at[pl.ds(
                    page_indices_ref[page_indices_offset + i] *
                    page_size_per_kv_packing,
                    sz_per_kv_packing,
                ), ..., nope_dim:],
                bkvpe_vmem_ref.at[pl.ds(i * page_size_per_kv_packing,
                                        sz_per_kv_packing)],
                sem,
                wait,
            )
            debug_print(
                "[RPA debug] loop_body i={}, sz_per_kv_packing={}",
                i,
                sz_per_kv_packing,
            )

        actual_bkv_p = jnp.minimum(cdiv(kv_left, page_size), bkv_p)
        lax.fori_loop(
            0,
            actual_bkv_p,
            loop_body,
            None,  # init value
            unroll=False,
        )

    def _fetch_bq(seq_idx, bq_idx, bq_sem_idx, *, wait=False):
        sem = sems.at[1, bq_sem_idx]
        bq_nope_vmem_ref = bq_nope_x2_ref.at[bq_sem_idx]
        bq_rope_vmem_ref = bq_rope_x2_ref.at[bq_sem_idx]

        q_len_start = cu_q_lens_ref[seq_idx] + bq_idx * bq_sz
        q_end = cu_q_lens_ref[seq_idx + 1]
        sz = jnp.minimum(bq_sz, q_end - q_len_start)

        debug_print(
            "[RPA debug]"
            f" -----------{'wait' if wait else 'start'}_fetch_bq-----------")
        debug_print("[RPA debug] seq_idx={}", seq_idx)
        debug_print("[RPA debug] bq_idx={}", bq_idx)
        debug_print("[RPA debug] bq_sem_idx={}", bq_sem_idx)
        debug_print("[RPA debug] q_len_start={}", q_len_start)
        debug_print("[RPA debug] q_end={}", q_end)
        debug_print("[RPA debug] sz={}", sz)

        _async_copy(
            ql_nope_hbm_ref.at[pl.ds(q_len_start, sz)],
            bq_nope_vmem_ref.at[pl.ds(0, sz)],
            sem,
            wait,
        )

        _async_copy(
            q_pe_hbm_ref.at[pl.ds(q_len_start, sz)],
            bq_rope_vmem_ref.at[pl.ds(0, sz)],
            sem,
            wait,
        )

    def _send_bo(seq_idx, bo_idx, bo_sem_idx, *, wait=False):
        sem = sems.at[2, bo_sem_idx]
        vmem_ref = bo_x2_ref.at[bo_sem_idx]
        q_len_start = cu_q_lens_ref[seq_idx] + bo_idx * bq_sz
        q_end = cu_q_lens_ref[seq_idx + 1]
        sz = jnp.minimum(bq_sz, q_end - q_len_start)

        debug_print(
            "[RPA debug]"
            f" -----------{'wait' if wait else 'start'}_send_bo-----------")
        debug_print("[RPA debug] seq_idx={}", seq_idx)
        debug_print("[RPA debug] bo_idx={}", bo_idx)
        debug_print("[RPA debug] bo_sem_idx={}", bo_sem_idx)
        debug_print("[RPA debug] q_len_start={}", q_len_start)
        debug_print("[RPA debug] q_end={}", q_end)
        debug_print("[RPA debug] sz={}", sz)

        _async_copy(
            vmem_ref.at[pl.ds(0, sz)],
            o_hbm_ref.at[pl.ds(q_len_start, sz)],
            sem,
            wait,
        )

    def start_fetch_bkv(seq_idx, bkv_idx, bkv_sem_idx):
        return _fetch_bkv(seq_idx, bkv_idx, bkv_sem_idx)

    def wait_fetch_bkv(seq_idx, bkv_idx, bkv_sem_idx):
        return _fetch_bkv(seq_idx, bkv_idx, bkv_sem_idx, wait=True)

    def start_fetch_bq(seq_idx, bq_idx, bq_sem_idx):
        return _fetch_bq(seq_idx, bq_idx, bq_sem_idx)

    def wait_fetch_bq(seq_idx, bq_idx, bq_sem_idx):
        return _fetch_bq(seq_idx, bq_idx, bq_sem_idx, wait=True)

    def start_send_bo(seq_idx, bo_idx, bo_sem_idx):
        bo_ids_ref[bo_sem_idx] = seq_idx
        bo_ids_ref[bo_sem_idx + 2] = bo_idx
        _send_bo(seq_idx, bo_idx, bo_sem_idx)

    def wait_send_bo(bo_sem_idx):
        old_seq_idx = bo_ids_ref[bo_sem_idx]
        old_bo_idx = bo_ids_ref[bo_sem_idx + 2]

        @pl.when(jnp.logical_and(0 <= old_seq_idx, old_seq_idx <= seq_idx))
        def _():
            _send_bo(old_seq_idx, old_bo_idx, bo_sem_idx, wait=True)

    def load_bq(bq_sem_idx, *, actual_bq_sz=bq_sz):
        q_nope_ref = (bq_nope_x2_ref.bitcast(
            jnp.uint32).at[bq_sem_idx].reshape(
                bq_sz * num_q_heads_per_q_packing, lkv_dim))
        q_nope_vec = pltpu.bitcast(
            q_nope_ref[:actual_bq_sz * num_q_heads_per_q_packing],
            q_dtype,
        )
        q_rope_ref = (bq_rope_x2_ref.bitcast(
            jnp.uint32).at[bq_sem_idx].reshape(
                bq_sz * num_q_heads_per_q_packing, r_dim))
        q_rope_vec = pltpu.bitcast(
            q_rope_ref[:actual_bq_sz * num_q_heads_per_q_packing],
            q_dtype,
        )
        return q_nope_vec, q_rope_vec

    def load_bkv(bkv_sem_idx, *, bkvc_mask, bkpe_mask):
        bitwidth = 32 // kv_packing
        repack_ty = jnp.dtype(f"uint{bitwidth}")
        bkvc_ref = (bkvc_x2_ref.bitcast(jnp.uint32).at[bkv_sem_idx].reshape(
            bkv_sz_per_kv_packing, lkv_dim))
        bkvc_vec = bkvc_ref[...]
        bkvc_vecs = []
        for i in range(kv_packing):
            masked_bkvc_vec = bkvc_vec >> (i * bitwidth)
            bkvc_vecs.append(masked_bkvc_vec)
        concated_bkvc_vec = jnp.concatenate(bkvc_vecs, axis=-1)
        concated_bkvc_vec = concated_bkvc_vec.reshape(bkv_sz, lkv_dim)
        concated_bkvc_vec = lax.select(bkvc_mask, concated_bkvc_vec,
                                       jnp.zeros_like(concated_bkvc_vec))
        concated_bkvc_vec = pltpu.bitcast(concated_bkvc_vec.astype(repack_ty),
                                          kv_dtype)
        bkpe_ref = (bkpe_x2_ref.bitcast(jnp.uint32).at[bkv_sem_idx].reshape(
            bkv_sz_per_kv_packing, r_dim))
        bkpe_vec = bkpe_ref[...]
        bkpe_vecs = []
        for i in range(kv_packing):
            masked_bkpe_vec = bkpe_vec >> (i * bitwidth)
            bkpe_vecs.append(masked_bkpe_vec)
        concated_bkpe_vec = jnp.concatenate(bkpe_vecs, axis=-1)
        concated_bkpe_vec = concated_bkpe_vec.reshape(bkv_sz, r_dim)
        concated_bkpe_vec = lax.select(bkpe_mask, concated_bkpe_vec,
                                       jnp.zeros_like(concated_bkpe_vec))
        concated_bkpe_vec = pltpu.bitcast(concated_bkpe_vec.astype(repack_ty),
                                          kv_dtype)

        return concated_bkvc_vec, concated_bkpe_vec

    def broadcast_minor(src, shape):
        if src.shape == shape:
            return src
        assert src.shape[:-1] == shape[:-1]
        assert src.shape[-1] % 128 == 0
        target_minor = align_to(shape[-1], src.shape[-1])
        # no-op concatenation.
        return jnp.concatenate(
            [src for _ in range(target_minor // src.shape[-1])],
            axis=-1)[..., :shape[-1]]

    def process(static_q_len=None):
        num_bkv = cdiv(kv_len, bkv_sz)
        if static_q_len is None:
            actual_bq_sz = bq_sz
            num_bq = cdiv(q_len, actual_bq_sz)
        else:
            actual_bq_sz = min(bq_sz, static_q_len)
            num_bq = cdiv(static_q_len, actual_bq_sz)

        def get_next_bq_ids(seq_idx, bq_idx, bq_sem_idx):
            next_bq_idx = bq_idx + 1
            is_last_bq = next_bq_idx == num_bq
            next_bq_idx = lax.select(is_last_bq, 0, next_bq_idx)
            next_seq_idx = lax.select(is_last_bq, seq_idx + 1, seq_idx)
            next_bq_sem_idx = lax.select(bq_sem_idx == 0, 1, 0)
            return next_seq_idx, next_bq_idx, next_bq_sem_idx

        def get_next_bkv_ids(seq_idx, bq_idx, bkv_idx, bkv_sem_idx):
            next_bkv_idx = bkv_idx + 1
            is_last_bkv = next_bkv_idx == num_bkv
            next_bkv_idx = lax.select(is_last_bkv, 0, next_bkv_idx)
            next_bq_idx = lax.select(is_last_bkv, bq_idx + 1, bq_idx)
            is_last_bq = next_bq_idx == num_bq
            next_bq_idx = lax.select(is_last_bq, 0, next_bq_idx)
            next_seq_idx = lax.select(is_last_bq, seq_idx + 1, seq_idx)
            next_bkv_sem_idx = lax.select(bkv_sem_idx == 0, 1, 0)
            return next_seq_idx, next_bq_idx, next_bkv_idx, next_bkv_sem_idx

        def compute_with_bq(bq_idx, _):
            bq_sem_idx = sem_ids_ref[0]
            next_seq_idx, next_bq_idx, next_bq_sem_idx = get_next_bq_ids(
                seq_idx, bq_idx, bq_sem_idx)

            # Prefetch next bq.
            @pl.when(next_seq_idx < num_seqs)
            def prefetch_next_bq():
                sem_ids_ref[0] = next_bq_sem_idx
                start_fetch_bq(next_seq_idx, next_bq_idx, next_bq_sem_idx)

            def compute_with_bkv(bkv_idx, _):
                # Create bitmask for KV.
                assert bkv_sz % kv_packing == 0
                actual_bkv_sz = jnp.minimum(bkv_sz, kv_len - bkv_idx * bkv_sz)
                bkvc_shape = (bkv_sz, lkv_dim)
                bkvc_mask = (lax.broadcasted_iota(jnp.int32, bkvc_shape, 0)
                             < actual_bkv_sz)
                bkpe_shape = (bkv_sz, r_dim)
                bkpe_mask = (lax.broadcasted_iota(jnp.int32, bkpe_shape, 0)
                             < actual_bkv_sz)

                # Get next bkv ids.
                bkv_sem_idx = sem_ids_ref[1]
                next_seq_idx, _, next_bkv_idx, next_bkv_sem_idx = get_next_bkv_ids(
                    seq_idx, bq_idx, bkv_idx, bkv_sem_idx)

                # Prefetch next bkv.
                @pl.when(next_seq_idx < num_seqs)
                def prefetch_next_bkv():
                    sem_ids_ref[1] = next_bkv_sem_idx
                    start_fetch_bkv(next_seq_idx, next_bkv_idx,
                                    next_bkv_sem_idx)

                # Wait for cur bq if not ready yet.
                @pl.when(bkv_idx == 0)
                def wait_cur_bq():
                    wait_fetch_bq(seq_idx, bq_idx, bq_sem_idx)

                # Wait for cur bkv.
                wait_fetch_bkv(seq_idx, bkv_idx, bkv_sem_idx)

                debug_print(
                    "[RPA debug] -----------flash attention-----------")
                debug_print("[RPA debug] seq_idx={}", seq_idx)
                debug_print("[RPA debug] bq_idx={}", bq_idx)
                debug_print("[RPA debug] bkv_idx={}", bkv_idx)
                if debug_mode:
                    # Skip flash attention if debug mode is enabled.
                    return

                # Flash attention with cur bkv and bq.
                bkvc, bkpe = load_bkv(bkv_sem_idx,
                                      bkvc_mask=bkvc_mask,
                                      bkpe_mask=bkpe_mask)
                bq_nope_vec, bq_pe_vec = load_bq(bq_sem_idx,
                                                 actual_bq_sz=actual_bq_sz)
                flash_attention(
                    bq_nope_vec,
                    bq_pe_vec,
                    bkvc,
                    bkpe,
                    bq_idx=bq_idx,
                    bkv_idx=bkv_idx,
                )

            lax.fori_loop(0, num_bkv, compute_with_bkv, None, unroll=False)

            # Load acc and calculate final output.
            acc = acc_ref[...]
            l = broadcast_minor(l_ref[...], acc.shape)  # noqa
            out = (lax.div(acc, l) if q_dtype == jnp.float32 else
                   (acc * pl.reciprocal(l, approx=True)).astype(q_dtype))

            # Wait for previous bo to be fully sent before storing new bo.
            bo_sem_idx = sem_ids_ref[2]
            sem_ids_ref[2] = lax.select(bo_sem_idx == 0, 1, 0)
            wait_send_bo(bo_sem_idx)

            # Store output from acc to bo.
            bo_x2_ref.at[bo_sem_idx].bitcast(jnp.int32).reshape(
                bq_sz * num_q_heads_per_q_packing,
                lkv_dim,
            )[...] = pltpu.bitcast(out, jnp.int32)

            # Send cur bo.
            start_send_bo(seq_idx, bq_idx, bo_sem_idx)

        lax.fori_loop(0, num_bq, compute_with_bq, None, unroll=False)

    ### ------- Kernel start ------- ###

    @pl.when(seq_idx == 0)
    def prologue():
        start_fetch_bq(0, 0, 0)
        start_fetch_bkv(0, 0, 0)

    @pl.when(seq_idx < decode_end)
    def process_decode():
        process(static_q_len=1)

    @pl.when(jnp.logical_and(decode_end <= seq_idx, seq_idx < prefill_end))
    def process_prefill():
        process(static_q_len=chunk_prefill_size)

    @pl.when(jnp.logical_and(prefill_end <= seq_idx, seq_idx < mixed_end))
    def process_mixed():
        process()

    @pl.when(seq_idx == num_seqs - 1)
    def epilogue():
        for i in range(2):
            wait_send_bo(i)

    ### ------- Kernel end ------- ###

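# --- Illustrative sketch (not part of the released file) ---
# The l/m/acc updates in flash_attention above implement the
# FlashAttention-2 online softmax. A minimal standalone re-derivation over
# kv blocks, assuming float32 throughout (toy code, not the kernel):
#
#   def online_softmax_attention(q, k, v, block=128):
#       m = jnp.full((q.shape[0], 1), -jnp.inf)   # running row max
#       l = jnp.zeros((q.shape[0], 1))            # running row sum
#       acc = jnp.zeros((q.shape[0], v.shape[-1]))
#       for s in range(0, k.shape[0], block):
#           sc = q @ k[s:s + block].T
#           m_new = jnp.maximum(m, sc.max(axis=-1, keepdims=True))
#           p = jnp.exp(sc - m_new)
#           scale = jnp.exp(m - m_new)            # rescale old partial sums
#           l = scale * l + p.sum(axis=-1, keepdims=True)
#           acc = scale * acc + p @ v[s:s + block]
#           m = m_new
#       return acc / l  # equals jax.nn.softmax(q @ k.T) @ v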
def prepare_q_inputs(
    q: jax.Array,  # [max_num_tokens, actual_num_q_heads, actual_head_dim]
):
    max_num_tokens, actual_num_q_heads, actual_head_dim = q.shape
    q_packing = get_dtype_packing(q.dtype)
    num_q_heads = align_to(actual_num_q_heads, q_packing)
    head_dim = align_to(actual_head_dim, 128)
    q = jnp.pad(
        q.reshape(
            max_num_tokens,
            actual_num_q_heads,
            actual_head_dim,
        ),
        (
            (0, 0),
            (0, num_q_heads - actual_num_q_heads),
            (0, head_dim - actual_head_dim),
        ),
        constant_values=0,
    ).reshape(
        max_num_tokens,
        num_q_heads // q_packing,
        q_packing,
        head_dim,
    )
    return q
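
# --- Illustrative sketch (not part of the released file) ---
# Shape round-trip for prepare_q_inputs, assuming bfloat16 (q_packing=2),
# 128 q heads, and a 512-wide nope dim (toy values):
#
#   in : [max_num_tokens, 128, 512]
#   out: [max_num_tokens, 64, 2, 512]   # heads folded into (64 pairs, 2)
#
# prepare_outputs below undoes this fold and slices back to the actual head
# count and head dim.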


def prepare_kv_inputs(
    kv: jax.Array,  # [max_num_tokens, actual_head_dim]
):
    max_num_tokens, actual_head_dim = kv.shape
    kv_packing = get_dtype_packing(kv.dtype)
    assert max_num_tokens % kv_packing == 0
    head_dim = align_to(actual_head_dim, 128)

    kv = kv.reshape(max_num_tokens // kv_packing, kv_packing, actual_head_dim)
    kv = jnp.pad(kv, ((0, 0), (0, 0), (0, head_dim - actual_head_dim)),
                 constant_values=0)

    return kv


def prepare_outputs(
    out,  # [max_num_tokens, num_q_heads // q_packing, q_packing, head_dim]
    actual_num_q_heads: int,
    actual_head_dim: int,
):
    (
        max_num_tokens,
        num_q_heads_per_q_packing,
        q_packing,
        head_dim,
    ) = out.shape
    return out.reshape(
        max_num_tokens,
        num_q_heads_per_q_packing * q_packing,
        head_dim,
    )[:, :actual_num_q_heads, :actual_head_dim]


@functools.partial(
    jax.jit,
    static_argnames=(
        "sm_scale",
        "sliding_window",
        "soft_cap",
        "mask_value",
        "chunk_prefill_size",
        "num_kv_pages_per_block",
        "num_queries_per_block",
        "vmem_limit_bytes",
        "debug_mode",
    ),
    donate_argnames=("cache_kv", ),
)
def mla_ragged_paged_attention(
    ql_nope: jax.Array,  # [max_num_tokens, actual_num_q_heads, actual_lkv_dim]
    q_pe: jax.Array,  # [max_num_tokens, actual_num_q_heads, actual_r_dim]
    new_kv_c: jax.Array,  # [max_num_tokens, actual_lkv_dim]
    new_k_pe: jax.Array,  # [max_num_tokens, actual_r_dim]
    # TODO(gpolovets): Explore separating out into lkv & pe KV caches.
    cache_kv: jax.
    Array,  # [total_num_pages, page_size_per_kv_packing, kv_packing, align_to(lkv_dim + r_dim, 128)]
    kv_lens: jax.Array,  # i32[max_num_seqs]
    page_indices: jax.Array,  # i32[max_num_seqs * pages_per_seq]
    cu_q_lens: jax.Array,  # i32[max_num_seqs + 1]
    distribution: jax.Array,  # i32[3]
    *,
    sm_scale: float = 1.0,
    sliding_window: int | None = None,
    soft_cap: float | None = None,
    mask_value: float | None = DEFAULT_MASK_VALUE,
    # Kernel optimization params.
    chunk_prefill_size: int | None = None,
    # Kernel tuning params.
    num_kv_pages_per_block: int | None = None,
    num_queries_per_block: int | None = None,
    vmem_limit_bytes: int | None = None,
    # Debug params.
    debug_mode: bool = False,
) -> tuple[
        jax.Array,  # [max_num_tokens, actual_num_q_heads, actual_lkv_dim]
        jax.
        Array,  # [total_num_pages, page_size_per_kv_packing, kv_packing, lkv_dim + r_dim]
]:
    """MLA Ragged paged attention that supports mixed prefill and decode.

    Args:
      ql_nope: concatenated queries of all sequences.
      q_pe: concatenated rope of all sequences.
      new_kv_c: concatenated kv_c values of all sequences.
      new_k_pe: concatenated k_pe values of all sequences.
      cache_kv: the current kv cache.
      kv_lens: the length of each sequence in the kv cache.
      page_indices: flattened page indices look-up table by (seq_id, page_id).
      cu_q_lens: the cumulative sum of the effective query lengths. Similar to
        kv_lens, only the first num_seqs+1 values are valid.
      distribution: (i, j, k) represents that sequences[0:i] are decode-only,
        sequences[i:j] are chunked-prefill-only, and sequences[j:k] are mixed.
        The k is also the total number of sequences.
      sm_scale: the softmax scale which will be applied to Q@K^T.
      sliding_window: the sliding window size for the attention.
      soft_cap: the logit soft cap for the attention.
      mask_value: mask value for the causal mask.
      chunk_prefill_size: the static query chunk size assumed for
        chunked-prefill sequences.
      num_kv_pages_per_block: number of kv pages to be processed in one flash
        attention block in the pallas kernel.
      num_queries_per_block: number of queries to be processed in one flash
        attention block in the pallas kernel.
      vmem_limit_bytes: the vmem limit for the pallas kernel.
      debug_mode: if true, RPA does not issue any DMAs or run flash attention
        but prints debug info. Needs to be compiled with
        `--xla_tpu_enable_log_recorder`.

    Returns:
      The output of the attention and the updated kv cache.
    """
    # TODO(chengjiyao): Support both autotuning table and heuristic logic to
    # set these kernel block sizes.
    if num_kv_pages_per_block is None or num_queries_per_block is None:
        raise ValueError(
            "num_kv_pages_per_block and num_queries_per_block must be specified."
        )
    static_validate_inputs(
        ql_nope,
        q_pe,
        new_kv_c,
        new_k_pe,
        cache_kv,
        kv_lens,
        page_indices,
        cu_q_lens,
        distribution,
        sm_scale=sm_scale,
        sliding_window=sliding_window,
        soft_cap=soft_cap,
        mask_value=mask_value,
        chunk_prefill_size=chunk_prefill_size,
        num_kv_pages_per_block=num_kv_pages_per_block,
        num_queries_per_block=num_queries_per_block,
        vmem_limit_bytes=vmem_limit_bytes,
        debug_mode=debug_mode,
    )

    # TODO(chengjiyao): fuse kv cache update into the kernel.
    cache_kv = update_kv_cache(
        new_kv_c,
        new_k_pe,
        cache_kv,
        kv_lens,
        page_indices,
        cu_q_lens,
        distribution,
    )

    _, actual_num_q_heads, actual_lkv_dim = ql_nope.shape

    ql_nope = prepare_q_inputs(
        ql_nope
    )  # [max_num_tokens, num_q_heads_per_q_packing, q_packing, lkv_dim]
    q_pe = prepare_q_inputs(
        q_pe)  # [max_num_tokens, num_q_heads_per_q_packing, q_packing, r_dim]
    new_kv_c = prepare_kv_inputs(
        new_kv_c)  # [max_num_tokens_per_kv_packing, kv_packing, lkv_dim]
    new_k_pe = prepare_kv_inputs(
        new_k_pe)  # [max_num_tokens_per_kv_packing, kv_packing, r_dim]
    lkv_dim = new_kv_c.shape[-1]
    r_dim = new_k_pe.shape[-1]

    _, page_size_per_kv_packing, kv_packing, _ = cache_kv.shape
    page_size = page_size_per_kv_packing * kv_packing
    _, num_q_heads_per_q_packing, q_packing, _ = ql_nope.shape
    max_num_seqs = kv_lens.shape[0]
    num_page_indices = page_indices.shape[0]
    assert num_page_indices % max_num_seqs == 0
    num_q_heads = num_q_heads_per_q_packing * q_packing

    bkv_p = num_kv_pages_per_block
    bq_sz = num_queries_per_block
    bkv_sz_per_kv_packing = bkv_p * page_size_per_kv_packing
    grid = (distribution[2], )

    in_specs = [
        pl.BlockSpec(memory_space=pltpu.HBM),
        pl.BlockSpec(memory_space=pltpu.HBM),
        pl.BlockSpec(memory_space=pltpu.HBM),
        pl.BlockSpec(memory_space=pltpu.HBM),
        pl.BlockSpec(memory_space=pltpu.HBM),
    ]

    out_specs = [
        pl.BlockSpec(memory_space=pltpu.HBM),
        pl.BlockSpec(memory_space=pltpu.HBM),
    ]

    bkvc_double_buf = pltpu.VMEM(
        (2, bkv_sz_per_kv_packing, kv_packing, lkv_dim),
        cache_kv.dtype,
    )

    bkpe_double_buf = pltpu.VMEM(
        (2, bkv_sz_per_kv_packing, kv_packing, r_dim),
        cache_kv.dtype,
    )

    bq_nope_double_buf = pltpu.VMEM(
        (2, bq_sz, num_q_heads_per_q_packing, q_packing, lkv_dim),
        ql_nope.dtype,
    )

    bq_rope_double_buf = pltpu.VMEM(
        (2, bq_sz, num_q_heads_per_q_packing, q_packing, r_dim),
        q_pe.dtype,
    )

    bo_double_buf = bq_nope_double_buf

    l_scratch = pltpu.VMEM(
        (bq_sz * num_q_heads, 128),
        jnp.float32,
    )
    m_scratch = l_scratch

    acc_scratch = pltpu.VMEM(
        (bq_sz * num_q_heads, lkv_dim),
        jnp.float32,
    )

    scratch_shapes = [
        bkvc_double_buf,
        bkpe_double_buf,
        bq_nope_double_buf,
        bq_rope_double_buf,
        bo_double_buf,  # Double buffering for output block.
        # Semaphores for double buffering of bkv, bq, bo and bkv_update.
        pltpu.SemaphoreType.DMA((4, 2)),
        # Intermediate buffers per kv head for flash attention.
        l_scratch,
        m_scratch,
        acc_scratch,
    ]

    scalar_prefetches = (
        kv_lens,
        # TODO(jevinjiang): can we use ragged page_indices to save some smem?
        page_indices,
        cu_q_lens,
        distribution,
        # (bq_sem_idx, bkv_sem_idx, bo_sem_idx)
        jnp.zeros((3, ), jnp.int32),
        # (bo_sem_0_seq_idx, bo_sem_1_seq_idx, bo_sem_0_bo_idx, bo_sem_1_bo_idx)
        jnp.full((4, ), -1, jnp.int32),
        # (bkv_sem_0_seq_idx, bkv_sem_1_seq_idx, bkv_sem_0_offset, bkv_sem_1_offset, bkv_sem_0_sz, bkv_sem_1_sz)
        jnp.full((6, ), -1, jnp.int32),
    )

    scope_name = f"MLA-RPA-bq_{bq_sz}-bkvp_{bkv_p}-p_{page_size}"
    kernel = jax.named_scope(scope_name)(
        pl.pallas_call(
            functools.partial(
                _mla_ragged_paged_attention_kernel,
                sm_scale=sm_scale,
                sliding_window=sliding_window,
                soft_cap=soft_cap,
                mask_value=mask_value,
                chunk_prefill_size=chunk_prefill_size,
                bq_sz=bq_sz,
                bkv_p=bkv_p,
                debug_mode=debug_mode,
            ),
            grid_spec=pltpu.PrefetchScalarGridSpec(
                num_scalar_prefetch=len(scalar_prefetches),
                in_specs=in_specs,
                out_specs=out_specs,
                grid=grid,
                scratch_shapes=scratch_shapes,
            ),
            compiler_params=pltpu.CompilerParams(
                # TODO(jevinjiang): since each sequence depends on the previous
                # one, we need some extra work to support Megacore mode.
                dimension_semantics=("arbitrary", ),
                vmem_limit_bytes=vmem_limit_bytes,
            ),
            out_shape=[
                jax.ShapeDtypeStruct(shape=ql_nope.shape, dtype=ql_nope.dtype),
                jax.ShapeDtypeStruct(shape=cache_kv.shape,
                                     dtype=cache_kv.dtype),
            ],
            input_output_aliases={
                7: 0,
                11: 1,
            },
            name=scope_name,
        ))

    output, updated_kv = kernel(
        *scalar_prefetches,
        ql_nope,
        q_pe,
        new_kv_c,
        new_k_pe,
        cache_kv,
    )
    output = prepare_outputs(
        output, actual_num_q_heads,
        actual_lkv_dim)  # [max_num_tokens, actual_num_q_heads, actual_lkv_dim]

    return output, updated_kv
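
# --- Illustrative usage sketch (not part of the released file) ---
# A hedged end-to-end call with toy values, assuming DeepSeek-style MLA dims
# (lkv=512, rope=64 -> padded r_dim=128) and hand-picked tuning params; real
# callers should take the block sizes from the tuned tables instead:
#
#   import jax.numpy as jnp
#   from tpu_inference.kernels.mla.v1.kernel import (
#       get_kv_cache_shape, mla_ragged_paged_attention)
#
#   kv_dtype = jnp.bfloat16
#   num_tokens, num_q_heads, lkv, rope = 8, 16, 512, 64
#   ql_nope = jnp.zeros((num_tokens, num_q_heads, lkv), kv_dtype)
#   q_pe = jnp.zeros((num_tokens, num_q_heads, rope), kv_dtype)
#   new_kv_c = jnp.zeros((num_tokens, lkv), kv_dtype)
#   new_k_pe = jnp.zeros((num_tokens, rope), kv_dtype)
#   cache_kv = jnp.zeros(
#       get_kv_cache_shape(128, 16, 512 + 128, kv_dtype), kv_dtype)
#   kv_lens = jnp.array([8, 0], jnp.int32)           # one active sequence
#   page_indices = jnp.zeros((2 * 4, ), jnp.int32)   # 2 seqs x 4 pages each
#   cu_q_lens = jnp.array([0, 8, 8], jnp.int32)
#   distribution = jnp.array([0, 0, 1], jnp.int32)   # one mixed sequence
#   out, cache_kv = mla_ragged_paged_attention(
#       ql_nope, q_pe, new_kv_c, new_k_pe, cache_kv,
#       kv_lens, page_indices, cu_q_lens, distribution,
#       sm_scale=1.0 / (lkv + rope) ** 0.5,
#       num_kv_pages_per_block=4,   # assumed tuning values; real callers
#       num_queries_per_block=8,    # should use the tuned tables
#   )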