tpu-inference 0.12.0.dev20251222__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (260) hide show
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_core_tpu.py +513 -0
  4. tests/core/test_disagg_executor.py +60 -0
  5. tests/core/test_disagg_utils.py +67 -0
  6. tests/core/test_dp_scheduler.py +724 -0
  7. tests/core/test_init.py +63 -0
  8. tests/distributed/__init__.py +13 -0
  9. tests/distributed/test_distributed_utils.py +120 -0
  10. tests/distributed/test_tpu_connector.py +478 -0
  11. tests/e2e/__init__.py +13 -0
  12. tests/e2e/test_async_scheduler.py +211 -0
  13. tests/e2e/test_data_parallel.py +393 -0
  14. tests/e2e/test_local_disagg.py +257 -0
  15. tests/e2e/test_model_loader.py +268 -0
  16. tests/e2e/test_multi_modal_inference.py +111 -0
  17. tests/e2e/test_pipeline_parallel.py +265 -0
  18. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  19. tests/e2e/test_sampling_params.py +269 -0
  20. tests/e2e/test_speculative_decoding.py +291 -0
  21. tests/e2e/test_structured_decoding.py +46 -0
  22. tests/executors/__init__.py +13 -0
  23. tests/executors/test_ray_distributed_executor.py +199 -0
  24. tests/experimental/__init__.py +13 -0
  25. tests/experimental/test_llama3_jax_stashed.py +208 -0
  26. tests/kernels/__init__.py +13 -0
  27. tests/kernels/collectives/__init__.py +13 -0
  28. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  29. tests/kernels/fused_moe_v1_test.py +388 -0
  30. tests/kernels/gmm_test.py +205 -0
  31. tests/kernels/mla_v1_test.py +498 -0
  32. tests/kernels/quantized_matmul_kernel_test.py +159 -0
  33. tests/kernels/ragged_kv_cache_update_v2_test.py +248 -0
  34. tests/kernels/ragged_paged_attention_kernel_v2_test.py +414 -0
  35. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +565 -0
  36. tests/kernels/ragged_paged_attention_kernel_v3_test.py +520 -0
  37. tests/layers/__init__.py +13 -0
  38. tests/layers/common/__init__.py +13 -0
  39. tests/layers/common/test_attention_interface.py +156 -0
  40. tests/layers/common/test_quantization.py +149 -0
  41. tests/layers/jax/__init__.py +13 -0
  42. tests/layers/jax/attention/__init__.py +13 -0
  43. tests/layers/jax/attention/test_common_attention.py +103 -0
  44. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  45. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  46. tests/layers/jax/moe/__init__.py +13 -0
  47. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  48. tests/layers/jax/sample/__init__.py +13 -0
  49. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  50. tests/layers/jax/sample/test_sampling.py +115 -0
  51. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  52. tests/layers/jax/test_layers.py +155 -0
  53. tests/layers/jax/test_qwix.py +969 -0
  54. tests/layers/jax/test_rope.py +93 -0
  55. tests/layers/jax/test_sharding.py +159 -0
  56. tests/layers/jax/test_transformer_block.py +152 -0
  57. tests/layers/vllm/__init__.py +13 -0
  58. tests/layers/vllm/test_attention.py +363 -0
  59. tests/layers/vllm/test_awq.py +405 -0
  60. tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
  61. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +403 -0
  62. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +426 -0
  63. tests/layers/vllm/test_fp8.py +17 -0
  64. tests/layers/vllm/test_mxfp4.py +297 -0
  65. tests/layers/vllm/test_unquantized.py +621 -0
  66. tests/layers/vllm/utils.py +72 -0
  67. tests/lora/__init__.py +13 -0
  68. tests/lora/conftest.py +46 -0
  69. tests/lora/test_bgmv.py +57 -0
  70. tests/lora/test_layers.py +666 -0
  71. tests/lora/test_lora.py +147 -0
  72. tests/lora/test_lora_perf.py +67 -0
  73. tests/lora/utils.py +88 -0
  74. tests/models/__init__.py +13 -0
  75. tests/models/common/__init__.py +13 -0
  76. tests/models/common/test_model_loader.py +455 -0
  77. tests/models/jax/__init__.py +13 -0
  78. tests/models/jax/test_deepseek_v3.py +401 -0
  79. tests/models/jax/test_llama3.py +184 -0
  80. tests/models/jax/test_llama4.py +298 -0
  81. tests/models/jax/test_llama_eagle3.py +197 -0
  82. tests/models/jax/test_llama_guard_4.py +242 -0
  83. tests/models/jax/test_qwen2.py +172 -0
  84. tests/models/jax/test_qwen2_5_vl.py +606 -0
  85. tests/models/jax/test_qwen3.py +169 -0
  86. tests/models/jax/test_weight_loading.py +180 -0
  87. tests/models/jax/utils/__init__.py +13 -0
  88. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  89. tests/platforms/__init__.py +13 -0
  90. tests/platforms/test_tpu_platform.py +54 -0
  91. tests/runner/__init__.py +13 -0
  92. tests/runner/test_block_table.py +395 -0
  93. tests/runner/test_input_batch.py +226 -0
  94. tests/runner/test_kv_cache.py +220 -0
  95. tests/runner/test_kv_cache_manager.py +498 -0
  96. tests/runner/test_multimodal_manager.py +429 -0
  97. tests/runner/test_persistent_batch_manager.py +84 -0
  98. tests/runner/test_speculative_decoding_manager.py +368 -0
  99. tests/runner/test_structured_decoding_manager.py +220 -0
  100. tests/runner/test_tpu_runner.py +202 -0
  101. tests/runner/test_tpu_runner_dp.py +1033 -0
  102. tests/runner/test_tpu_runner_mesh.py +200 -0
  103. tests/runner/test_utils.py +411 -0
  104. tests/spec_decode/__init__.py +13 -0
  105. tests/spec_decode/test_eagle3.py +311 -0
  106. tests/test_base.py +215 -0
  107. tests/test_envs.py +280 -0
  108. tests/test_tpu_info.py +134 -0
  109. tests/test_utils.py +193 -0
  110. tests/worker/__init__.py +13 -0
  111. tests/worker/tpu_worker_test.py +414 -0
  112. tpu_inference/__init__.py +67 -0
  113. tpu_inference/core/__init__.py +13 -0
  114. tpu_inference/core/core_tpu.py +786 -0
  115. tpu_inference/core/disagg_executor.py +118 -0
  116. tpu_inference/core/disagg_utils.py +49 -0
  117. tpu_inference/core/sched/__init__.py +13 -0
  118. tpu_inference/core/sched/dp_scheduler.py +814 -0
  119. tpu_inference/distributed/__init__.py +13 -0
  120. tpu_inference/distributed/jax_parallel_state.py +81 -0
  121. tpu_inference/distributed/tpu_connector.py +732 -0
  122. tpu_inference/distributed/utils.py +112 -0
  123. tpu_inference/env_override.py +9 -0
  124. tpu_inference/envs.py +191 -0
  125. tpu_inference/executors/__init__.py +13 -0
  126. tpu_inference/executors/ray_distributed_executor.py +399 -0
  127. tpu_inference/experimental/__init__.py +13 -0
  128. tpu_inference/experimental/llama3_jax_stashed.py +272 -0
  129. tpu_inference/kernels/__init__.py +13 -0
  130. tpu_inference/kernels/collectives/__init__.py +13 -0
  131. tpu_inference/kernels/collectives/all_gather_matmul.py +741 -0
  132. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +65 -0
  133. tpu_inference/kernels/collectives/util.py +47 -0
  134. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  135. tpu_inference/kernels/flash_attention/kernel.py +772 -0
  136. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  137. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  138. tpu_inference/kernels/fused_moe/v1/kernel.py +1612 -0
  139. tpu_inference/kernels/megablox/__init__.py +13 -0
  140. tpu_inference/kernels/megablox/common.py +54 -0
  141. tpu_inference/kernels/megablox/gmm.py +646 -0
  142. tpu_inference/kernels/mla/__init__.py +13 -0
  143. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  144. tpu_inference/kernels/mla/v1/kernel.py +1340 -0
  145. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  146. tpu_inference/kernels/quantized_matmul/kernel.py +456 -0
  147. tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
  148. tpu_inference/kernels/quantized_matmul/util.py +58 -0
  149. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  150. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  151. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +876 -0
  152. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +288 -0
  153. tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
  154. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  155. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1594 -0
  156. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1586 -0
  157. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4460 -0
  158. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +548 -0
  159. tpu_inference/kernels/ragged_paged_attention/v3/util.py +65 -0
  160. tpu_inference/layers/__init__.py +13 -0
  161. tpu_inference/layers/common/__init__.py +13 -0
  162. tpu_inference/layers/common/attention_interface.py +403 -0
  163. tpu_inference/layers/common/attention_metadata.py +48 -0
  164. tpu_inference/layers/common/binary_search.py +295 -0
  165. tpu_inference/layers/common/quant_methods.py +23 -0
  166. tpu_inference/layers/common/quantization.py +270 -0
  167. tpu_inference/layers/common/sharding.py +600 -0
  168. tpu_inference/layers/jax/__init__.py +13 -0
  169. tpu_inference/layers/jax/attention/__init__.py +13 -0
  170. tpu_inference/layers/jax/attention/attention.py +268 -0
  171. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +547 -0
  172. tpu_inference/layers/jax/attention/gpt_oss_attention.py +275 -0
  173. tpu_inference/layers/jax/attention/llama4_attention.py +167 -0
  174. tpu_inference/layers/jax/base.py +165 -0
  175. tpu_inference/layers/jax/constants.py +101 -0
  176. tpu_inference/layers/jax/layers.py +315 -0
  177. tpu_inference/layers/jax/misc.py +30 -0
  178. tpu_inference/layers/jax/moe/__init__.py +13 -0
  179. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +615 -0
  180. tpu_inference/layers/jax/moe/gpt_oss_moe.py +199 -0
  181. tpu_inference/layers/jax/moe/moe.py +249 -0
  182. tpu_inference/layers/jax/pp_utils.py +53 -0
  183. tpu_inference/layers/jax/rope.py +294 -0
  184. tpu_inference/layers/jax/rope_interface.py +228 -0
  185. tpu_inference/layers/jax/sample/__init__.py +13 -0
  186. tpu_inference/layers/jax/sample/rejection_sampler.py +528 -0
  187. tpu_inference/layers/jax/sample/sampling.py +110 -0
  188. tpu_inference/layers/jax/sample/sampling_metadata.py +90 -0
  189. tpu_inference/layers/jax/transformer_block.py +121 -0
  190. tpu_inference/layers/vllm/__init__.py +13 -0
  191. tpu_inference/layers/vllm/attention.py +221 -0
  192. tpu_inference/layers/vllm/fused_moe.py +502 -0
  193. tpu_inference/layers/vllm/linear_common.py +221 -0
  194. tpu_inference/layers/vllm/quantization/__init__.py +55 -0
  195. tpu_inference/layers/vllm/quantization/awq.py +221 -0
  196. tpu_inference/layers/vllm/quantization/common.py +124 -0
  197. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  198. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +135 -0
  199. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +266 -0
  200. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  201. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +222 -0
  202. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +150 -0
  203. tpu_inference/layers/vllm/quantization/fp8.py +118 -0
  204. tpu_inference/layers/vllm/quantization/mxfp4.py +396 -0
  205. tpu_inference/layers/vllm/quantization/unquantized.py +416 -0
  206. tpu_inference/layers/vllm/sharding.py +244 -0
  207. tpu_inference/logger.py +10 -0
  208. tpu_inference/lora/__init__.py +13 -0
  209. tpu_inference/lora/torch_lora_ops.py +98 -0
  210. tpu_inference/lora/torch_punica_tpu.py +310 -0
  211. tpu_inference/models/__init__.py +13 -0
  212. tpu_inference/models/common/__init__.py +13 -0
  213. tpu_inference/models/common/model_loader.py +520 -0
  214. tpu_inference/models/jax/__init__.py +13 -0
  215. tpu_inference/models/jax/deepseek_v3.py +978 -0
  216. tpu_inference/models/jax/gpt_oss.py +508 -0
  217. tpu_inference/models/jax/jax_intermediate_tensor.py +93 -0
  218. tpu_inference/models/jax/llama3.py +436 -0
  219. tpu_inference/models/jax/llama4.py +643 -0
  220. tpu_inference/models/jax/llama_eagle3.py +350 -0
  221. tpu_inference/models/jax/llama_guard_4.py +375 -0
  222. tpu_inference/models/jax/qwen2.py +390 -0
  223. tpu_inference/models/jax/qwen2_5_vl.py +1232 -0
  224. tpu_inference/models/jax/qwen3.py +318 -0
  225. tpu_inference/models/jax/utils/__init__.py +13 -0
  226. tpu_inference/models/jax/utils/file_utils.py +110 -0
  227. tpu_inference/models/jax/utils/multi_modal_utils.py +177 -0
  228. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  229. tpu_inference/models/jax/utils/qwix/qwix_utils.py +713 -0
  230. tpu_inference/models/jax/utils/weight_utils.py +621 -0
  231. tpu_inference/models/vllm/__init__.py +13 -0
  232. tpu_inference/models/vllm/vllm_model_wrapper.py +307 -0
  233. tpu_inference/models/vllm/vllm_model_wrapper_context.py +59 -0
  234. tpu_inference/platforms/__init__.py +16 -0
  235. tpu_inference/platforms/tpu_platform.py +258 -0
  236. tpu_inference/runner/__init__.py +13 -0
  237. tpu_inference/runner/block_table.py +122 -0
  238. tpu_inference/runner/compilation_manager.py +890 -0
  239. tpu_inference/runner/input_batch.py +435 -0
  240. tpu_inference/runner/kv_cache.py +166 -0
  241. tpu_inference/runner/kv_cache_manager.py +508 -0
  242. tpu_inference/runner/lora_utils.py +106 -0
  243. tpu_inference/runner/multimodal_manager.py +231 -0
  244. tpu_inference/runner/persistent_batch_manager.py +296 -0
  245. tpu_inference/runner/speculative_decoding_manager.py +262 -0
  246. tpu_inference/runner/structured_decoding_manager.py +101 -0
  247. tpu_inference/runner/tpu_runner.py +1768 -0
  248. tpu_inference/runner/utils.py +426 -0
  249. tpu_inference/spec_decode/__init__.py +13 -0
  250. tpu_inference/spec_decode/jax/__init__.py +13 -0
  251. tpu_inference/spec_decode/jax/eagle3.py +430 -0
  252. tpu_inference/tpu_info.py +92 -0
  253. tpu_inference/utils.py +345 -0
  254. tpu_inference/worker/__init__.py +13 -0
  255. tpu_inference/worker/tpu_worker.py +468 -0
  256. tpu_inference-0.12.0.dev20251222.dist-info/METADATA +106 -0
  257. tpu_inference-0.12.0.dev20251222.dist-info/RECORD +260 -0
  258. tpu_inference-0.12.0.dev20251222.dist-info/WHEEL +5 -0
  259. tpu_inference-0.12.0.dev20251222.dist-info/licenses/LICENSE +201 -0
  260. tpu_inference-0.12.0.dev20251222.dist-info/top_level.txt +2 -0
@@ -0,0 +1,1594 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """TPU-Friendly Ragged Paged Attention kernel.
15
+
16
+ This kernel offers a highly optimized implementation of ragged paged attention,
17
+ specifically designed for TPU and compatible with a wide range of model
18
+ specifications. It supports mixed prefill and decoding, enhancing throughput
19
+ during inference.
20
+ """
21
+ import functools
22
+
23
+ import jax
24
+ import jax.numpy as jnp
25
+ from jax import lax
26
+ from jax.experimental import pallas as pl
27
+ from jax.experimental.pallas import tpu as pltpu
28
+
29
+ from tpu_inference.kernels.ragged_paged_attention.v3.tuned_block_sizes import \
30
+ get_tuned_block_sizes
31
+ from tpu_inference.kernels.ragged_paged_attention.v3.util import (
32
+ align_to, cdiv, get_dtype_bitwidth, get_dtype_packing)
33
+
34
+ DEFAULT_MASK_VALUE = -0.7 * float(jnp.finfo(jnp.dtype("float32")).max)
35
+
36
+ DEFAULT_VMEM_LIMIT_BYTES = 100 * 1024 * 1024
37
+
38
+
39
+ def ref_ragged_paged_attention(
40
+ queries: jax.
41
+ Array, # [max_num_tokens, actual_num_q_heads, actual_head_dim]
42
+ keys: jax.Array, # [max_num_tokens, actual_num_kv_heads, actual_head_dim]
43
+ values: jax.
44
+ Array, # [max_num_tokens, actual_num_kv_heads, actual_head_dim]
45
+ kv_cache: jax.
46
+ Array, # [total_num_pages, page_size, num_kv_heads_x2 // kv_packing, kv_packing, head_dim]
47
+ kv_lens: jax.Array, # i32[max_num_seqs]
48
+ page_indices: jax.Array, # i32[max_num_seqs * pages_per_seq]
49
+ cu_q_lens: jax.Array, # i32[max_num_seqs + 1]
50
+ distribution: jax.Array, # i32[3]
51
+ *,
52
+ sm_scale: float = 1.0,
53
+ sliding_window: int | None = None,
54
+ soft_cap: float | None = None,
55
+ mask_value: float | None = DEFAULT_MASK_VALUE,
56
+ q_scale: float | None = None,
57
+ k_scale: float | None = None,
58
+ v_scale: float | None = None,
59
+ ):
60
+ if mask_value is None:
61
+ mask_value = DEFAULT_MASK_VALUE
62
+
63
+ dynamic_validate_inputs(
64
+ queries,
65
+ keys,
66
+ values,
67
+ kv_cache,
68
+ kv_lens,
69
+ page_indices,
70
+ cu_q_lens,
71
+ distribution,
72
+ sm_scale=sm_scale,
73
+ sliding_window=sliding_window,
74
+ soft_cap=soft_cap,
75
+ mask_value=mask_value,
76
+ q_scale=q_scale,
77
+ k_scale=k_scale,
78
+ v_scale=v_scale,
79
+ )
80
+ actual_head_dim = queries.shape[2]
81
+ actual_num_q_heads = queries.shape[1]
82
+ actual_num_kv_heads = keys.shape[1]
83
+ merged_kv = merge_kv(keys, values)
84
+ assert merged_kv.shape[-3:] == kv_cache.shape[-3:]
85
+
86
+ _, page_size, num_kv_heads_x2_per_kv_packing, kv_packing, head_dim = (
87
+ kv_cache.shape)
88
+ num_kv_heads_x2 = num_kv_heads_x2_per_kv_packing * kv_packing
89
+ assert num_kv_heads_x2 % 2 == 0
90
+ assert actual_num_q_heads % actual_num_kv_heads == 0
91
+ assert head_dim % 128 == 0
92
+ assert get_dtype_packing(kv_cache.dtype) == kv_packing
93
+ assert num_kv_heads_x2 == align_to(actual_num_kv_heads * 2, kv_packing)
94
+ actual_num_q_heads_per_kv_head = actual_num_q_heads // actual_num_kv_heads
95
+ max_num_seqs = kv_lens.shape[0]
96
+ num_page_indices = page_indices.shape[0]
97
+ assert num_page_indices % max_num_seqs == 0
98
+ pages_per_seq = num_page_indices // max_num_seqs
99
+ outputs = []
100
+
101
+ for i in range(distribution[-1]):
102
+ q_start = cu_q_lens[i]
103
+ q_end = cu_q_lens[i + 1]
104
+ q_len = q_end - q_start
105
+
106
+ kv_len = kv_lens[i]
107
+ indices_start = i * pages_per_seq
108
+ indices_end = indices_start + cdiv(kv_len, page_size)
109
+ indices = page_indices[indices_start:indices_end]
110
+ q = queries[q_start:q_end, :, :actual_head_dim]
111
+
112
+ # Update the kv cache.
113
+ assert kv_len - q_len >= 0
114
+ gathered_kv = kv_cache[indices]
115
+ gathered_shape = gathered_kv.shape
116
+ gathered_kv = gathered_kv.reshape(-1, *gathered_shape[-3:])
117
+ gathered_kv = gathered_kv.at[kv_len - q_len:kv_len].set(
118
+ merged_kv[q_start:q_end])
119
+ kv_cache = kv_cache.at[indices].set(
120
+ gathered_kv.reshape(gathered_shape))
121
+
122
+ kv = gathered_kv.reshape(
123
+ -1, num_kv_heads_x2,
124
+ head_dim)[:, :actual_num_kv_heads * 2, :].reshape(
125
+ -1, actual_num_kv_heads, head_dim * 2)
126
+ k = kv[:kv_len, :, :head_dim][:, :, :actual_head_dim]
127
+ v = kv[:kv_len, :, head_dim:][:, :, :actual_head_dim]
128
+ k = jnp.repeat(k, actual_num_q_heads_per_kv_head, axis=1)
129
+ v = jnp.repeat(v, actual_num_q_heads_per_kv_head, axis=1)
130
+
131
+ if q_scale is not None:
132
+ q = q / q_scale
133
+ if jnp.issubdtype(k.dtype, jnp.floating):
134
+ dtype_info = jnp.finfo(k.dtype)
135
+ minval = float(dtype_info.min)
136
+ maxval = float(dtype_info.max)
137
+ q = jnp.clip(q, min=minval, max=maxval)
138
+ q = q.astype(k.dtype)
139
+
140
+ attn = jnp.einsum("qhd,khd->hqk",
141
+ q,
142
+ k,
143
+ preferred_element_type=jnp.float32)
144
+ attn *= sm_scale
145
+ if k_scale is not None:
146
+ attn *= k_scale
147
+ if q_scale is not None:
148
+ attn *= q_scale
149
+
150
+ q_span = (kv_len - q_len) + jax.lax.broadcasted_iota(
151
+ jnp.int32, attn.shape, 1)
152
+ kv_span = jax.lax.broadcasted_iota(jnp.int32, attn.shape, 2)
153
+ mask = q_span < kv_span
154
+ if sliding_window is not None:
155
+ mask = jnp.logical_or(mask, q_span - sliding_window >= kv_span)
156
+ if soft_cap is not None:
157
+ attn = soft_cap * jnp.tanh(attn / soft_cap)
158
+ attn += jnp.where(mask, mask_value, 0.0)
159
+ attn = jax.nn.softmax(attn, axis=-1).astype(v.dtype)
160
+
161
+ out = jnp.einsum("hqk,khd->qhd", attn, v).astype(queries.dtype)
162
+ if v_scale is not None:
163
+ out *= v_scale
164
+
165
+ outputs.append(out)
166
+
167
+ result = jnp.concatenate(outputs, axis=0)
168
+ return result, kv_cache
169
+
170
+
171
+ def get_smem_estimate_bytes(max_num_seqs, pages_per_seq):
172
+ total_bits = (
173
+ # kv_lens_ref: i32[max_num_seqs]
174
+ align_to(max_num_seqs, 128) * 32 +
175
+ # page_indices_ref: i32[max_num_seqs * pages_per_seq]
176
+ align_to(max_num_seqs * pages_per_seq, 128) * 32 +
177
+ # cu_q_lens_ref: i32[max_num_seqs + 1]
178
+ align_to(max_num_seqs + 1, 128) * 32 +
179
+ # distribution_ref: i32[3]
180
+ 128 * 32 +
181
+ # sem_ids_ref: i32[3]
182
+ 128 * 32 +
183
+ # bo_ids_ref: i32[4]
184
+ 128 * 32 +
185
+ # bkv_update_ids_ref: i32[6]
186
+ 128 * 32)
187
+ return cdiv(total_bits, 8)
188
+
189
+
190
+ def get_vmem_estimate_bytes(
191
+ actual_num_kv_heads,
192
+ actual_num_q_heads_per_kv_head,
193
+ actual_head_dim,
194
+ bq_sz,
195
+ bkv_sz,
196
+ q_dtype,
197
+ kv_dtype,
198
+ ):
199
+ q_packing = get_dtype_packing(q_dtype)
200
+ kv_packing = get_dtype_packing(kv_dtype)
201
+ num_q_heads_per_kv_head = align_to(actual_num_q_heads_per_kv_head,
202
+ q_packing)
203
+ num_kv_heads_x2 = align_to(actual_num_kv_heads * 2, kv_packing)
204
+ head_dim = align_to(actual_head_dim, 128)
205
+
206
+ total_bits = (
207
+ # bkv_x2_ref
208
+ (2 * bkv_sz * num_kv_heads_x2 * head_dim) * (32 // kv_packing) +
209
+ # bq_x2_ref + bo_x2_ref
210
+ 2 * (2 * actual_num_kv_heads * bq_sz * num_q_heads_per_kv_head *
211
+ head_dim) * (32 // q_packing) +
212
+ # l_ref + m_ref
213
+ 2 *
214
+ (actual_num_kv_heads * bq_sz * num_q_heads_per_kv_head * 128) * 32 +
215
+ # acc_ref
216
+ (actual_num_kv_heads * bq_sz * num_q_heads_per_kv_head * head_dim) *
217
+ 32)
218
+ return cdiv(total_bits, 8)
219
+
220
+
221
+ def get_kv_cache_shape(
222
+ total_num_pages,
223
+ page_size,
224
+ actual_num_kv_heads,
225
+ actual_head_dim,
226
+ kv_dtype,
227
+ ):
228
+ kv_packing = get_dtype_packing(kv_dtype)
229
+ return (
230
+ total_num_pages,
231
+ page_size,
232
+ align_to(actual_num_kv_heads * 2, kv_packing) // kv_packing,
233
+ kv_packing,
234
+ align_to(actual_head_dim, 128),
235
+ )
236
+
237
+
238
+ def _ragged_paged_attention_kernel(
239
+ # Prefetch
240
+ kv_lens_ref, # [max_num_seqs]
241
+ page_indices_ref, # [max_num_seqs * pages_per_seq]
242
+ cu_q_lens_ref, # [max_num_seqs + 1]
243
+ # TODO(jevinjiang): merge these into one so we can save SMEM.
244
+ distribution_ref, # [3] (decode_end, prefill_end, mixed_end)
245
+ sem_ids_ref, # [3] (bq_sem_idx, bkv_sem_idx, bo_sem_idx)
246
+ bo_ids_ref, # [4] (bo_sem_0_seq_idx, bo_sem_1_seq_idx, bo_sem_0_bo_idx, bo_sem_1_bo_idx)
247
+ bkv_update_ids_ref, # [6] (bkv_sem_0_seq_idx, bkv_sem_1_seq_idx, bkv_sem_0_offset, bkv_sem_1_offset, bkv_sem_0_sz, bkv_sem_1_sz)
248
+ # Input
249
+ q_hbm_ref, # [actual_num_kv_heads, max_num_tokens, num_q_heads_per_kv_head // q_packing, q_packing, head_dim]
250
+ kv_hbm_ref, # [max_num_tokens, num_kv_heads_x2 // kv_packing, kv_packing, head_dim]
251
+ kv_cache_hbm_ref, # [total_num_pages, page_size, num_kv_heads_x2 // kv_packing, kv_packing, head_dim]
252
+ # Output
253
+ o_hbm_ref, # [actual_num_kv_heads, max_num_tokens, num_q_heads_per_kv_head // q_packing, q_packing, head_dim]
254
+ updated_kv_cache_hbm_ref, # [total_num_pages, page_size, num_kv_heads_x2 // kv_packing, kv_packing, head_dim]
255
+ # Scratch
256
+ bkv_x2_ref, # [2, bkv_sz, num_kv_heads_x2 // kv_packing, kv_packing, head_dim]
257
+ bq_x2_ref, # [2, actual_num_kv_heads, bq_sz, num_q_heads_per_kv_head // q_packing, q_packing, head_dim]
258
+ bo_x2_ref, # [2, actual_num_kv_heads, bq_sz, num_q_heads_per_kv_head // q_packing, q_packing, head_dim]
259
+ sems, # [4, 2]
260
+ l_ref, # [actual_num_kv_heads, bq_sz * num_q_heads_per_kv_head, 128],
261
+ m_ref, # [actual_num_kv_heads, bq_sz * num_q_heads_per_kv_head, 128],
262
+ acc_ref, # [actual_num_kv_heads, bq_sz * num_q_heads_per_kv_head, head_dim],
263
+ *,
264
+ sm_scale: float,
265
+ sliding_window: int | None = None,
266
+ soft_cap: float | None = None,
267
+ mask_value: float = DEFAULT_MASK_VALUE,
268
+ q_scale: float | None = None,
269
+ k_scale: float | None = None,
270
+ v_scale: float | None = None,
271
+ chunk_prefill_size: int | None = None,
272
+ bkv_p,
273
+ bq_sz,
274
+ debug_mode: bool = False,
275
+ ):
276
+ assert q_hbm_ref.shape == o_hbm_ref.shape
277
+ assert q_hbm_ref.shape[-1] == kv_cache_hbm_ref.shape[-1]
278
+ (
279
+ actual_num_kv_heads,
280
+ max_num_tokens,
281
+ num_q_heads_per_kv_head_per_packing,
282
+ q_packing,
283
+ head_dim,
284
+ ) = q_hbm_ref.shape
285
+ (
286
+ total_num_pages,
287
+ page_size,
288
+ num_kv_heads_x2_per_kv_packing,
289
+ kv_packing,
290
+ _,
291
+ ) = kv_cache_hbm_ref.shape
292
+ max_num_seqs = kv_lens_ref.shape[0]
293
+ num_page_indices = page_indices_ref.shape[0]
294
+ assert num_page_indices % max_num_seqs == 0
295
+ pages_per_seq = num_page_indices // max_num_seqs
296
+ num_kv_heads_x2 = num_kv_heads_x2_per_kv_packing * kv_packing
297
+ num_q_heads_per_kv_head = num_q_heads_per_kv_head_per_packing * q_packing
298
+ q_dtype = q_hbm_ref.dtype
299
+ kv_dtype = kv_cache_hbm_ref.dtype
300
+ assert o_hbm_ref.dtype == q_dtype
301
+ assert get_dtype_packing(q_dtype) == q_packing
302
+ assert get_dtype_packing(kv_dtype) == kv_packing
303
+ assert head_dim % 128 == 0
304
+ bkv_sz = bkv_p * page_size
305
+ seq_idx = pl.program_id(0)
306
+ num_seqs = pl.num_programs(0)
307
+ decode_end = distribution_ref[0]
308
+ prefill_end = distribution_ref[1]
309
+ mixed_end = distribution_ref[2]
310
+
311
+ q_start = cu_q_lens_ref[seq_idx]
312
+ q_end = cu_q_lens_ref[seq_idx + 1]
313
+ q_len = q_end - q_start
314
+ kv_len = kv_lens_ref[seq_idx]
315
+
316
+ if sliding_window is None:
317
+ bkv_idx_start = next_seq_bkv_idx_start = 0
318
+ else:
319
+ bkv_idx_start = jnp.maximum(kv_len - q_len - sliding_window,
320
+ 0) // bkv_sz
321
+
322
+ # If seq_idx + 1 == num_seqs, kv_lens_ref[seq_idx + 1] will trigger a
323
+ # out-of-bound error. To avoid this, we set upperbound of next_seq_idx
324
+ # to be num_seqs - 1.
325
+ next_seq_idx = jnp.minimum(seq_idx + 1, num_seqs - 1)
326
+ next_kv_len = kv_lens_ref[next_seq_idx]
327
+ next_q_len = cu_q_lens_ref[next_seq_idx + 1] - q_end
328
+ next_seq_bkv_idx_start = (
329
+ jnp.maximum(next_kv_len - next_q_len - sliding_window, 0) //
330
+ bkv_sz)
331
+
332
+ def debug_print(msg, *args):
333
+ if debug_mode:
334
+ pl.debug_print(msg, *args)
335
+
336
+ debug_print("[RPA debug] ======= In loop seq_idx={}", seq_idx)
337
+ debug_print("[RPA debug] num_seqs={}", num_seqs)
338
+ debug_print("[RPA debug] decode_end={}", decode_end)
339
+ debug_print("[RPA debug] prefill_end={}", prefill_end)
340
+ debug_print("[RPA debug] mixed_end={}", mixed_end)
341
+ debug_print("[RPA debug] bkv_p={}", bkv_p)
342
+ debug_print("[RPA debug] page_size={}", page_size)
343
+ debug_print("[RPA debug] pages_per_seq={}", pages_per_seq)
344
+ debug_print("[RPA debug] bkv_sz={}", bkv_sz)
345
+ debug_print("[RPA debug] bq_sz={}", bq_sz)
346
+ debug_print("[RPA debug] q_start={}", q_start)
347
+ debug_print("[RPA debug] q_end={}", q_end)
348
+ debug_print("[RPA debug] q_len={}", q_len)
349
+ debug_print("[RPA debug] kv_len={}", kv_len)
350
+
351
+ def flash_attention_step1_qk_softmax(
352
+ q, # [actual_bq_sz * num_q_heads_per_kv_head, head_dim]
353
+ k, # [bkv_sz, head_dim]
354
+ v, # [bkv_sz, head_dim]
355
+ *,
356
+ bq_idx,
357
+ bkv_idx,
358
+ kv_head_idx,
359
+ ):
360
+ assert len(q.shape) == 2
361
+ assert q.shape[0] % num_q_heads_per_kv_head == 0
362
+ assert q.shape[1] == head_dim
363
+ assert k.shape == v.shape == (bkv_sz, head_dim)
364
+ assert k.dtype == v.dtype
365
+ head_l_ref = l_ref.at[kv_head_idx, :q.shape[0]]
366
+ head_m_ref = m_ref.at[kv_head_idx, :q.shape[0]]
367
+
368
+ def load_with_init(ref, init_val):
369
+ return jnp.where(bkv_idx == bkv_idx_start,
370
+ jnp.full_like(ref, init_val), ref[...])
371
+
372
+ # Follow FlashAttention-2 forward pass.
373
+ if q_scale is not None:
374
+ q = q / q_scale
375
+ if jnp.issubdtype(k.dtype, jnp.floating):
376
+ dtype_info = jnp.finfo(k.dtype)
377
+ minval = float(dtype_info.min)
378
+ maxval = float(dtype_info.max)
379
+ q = jnp.clip(q, min=minval, max=maxval)
380
+ q = q.astype(k.dtype)
381
+
382
+ s = jnp.einsum("nd,md->nm", q, k, preferred_element_type=jnp.float32)
383
+ s *= sm_scale
384
+ if k_scale is not None:
385
+ s *= k_scale
386
+ if q_scale is not None:
387
+ s *= q_scale
388
+ if soft_cap is not None:
389
+ s = soft_cap * jnp.tanh(s / soft_cap)
390
+
391
+ q_span = (kv_len - q_len + bq_idx * bq_sz +
392
+ lax.broadcasted_iota(jnp.int32, s.shape, 0) //
393
+ num_q_heads_per_kv_head)
394
+ k_span = bkv_idx * bkv_sz + lax.broadcasted_iota(jnp.int32, s.shape, 1)
395
+ mask = k_span <= q_span
396
+
397
+ if sliding_window is not None:
398
+ mask = jnp.logical_and(mask, q_span - sliding_window < k_span)
399
+
400
+ s = jnp.where(mask, s, mask_value)
401
+ s_rowmax = jnp.max(s, axis=1, keepdims=True)
402
+
403
+ m_prev = load_with_init(head_m_ref, -jnp.inf)
404
+ m_curr = jnp.maximum(m_prev, s_rowmax)
405
+ head_m_ref[...] = m_curr
406
+ p = jnp.exp(s - broadcast_minor(m_curr, s.shape))
407
+
408
+ p_rowsum = jnp.sum(p, axis=1, keepdims=True)
409
+ exp_m_diff = jnp.exp(m_prev - m_curr)
410
+ l_prev = load_with_init(head_l_ref, 0.0)
411
+ l_curr = exp_m_diff * l_prev + p_rowsum
412
+ head_l_ref[...] = l_curr
413
+
414
+ return p, exp_m_diff
415
+
416
+ def flash_attention_step2_pv(
417
+ q_shape_0,
418
+ v, # [bkv_sz, head_dim]
419
+ p, # from step1
420
+ exp_m_diff, # from step1
421
+ *,
422
+ bkv_idx,
423
+ kv_head_idx,
424
+ ):
425
+ head_acc_ref = acc_ref.at[kv_head_idx, :q_shape_0]
426
+
427
+ def load_with_init(ref, init_val):
428
+ return jnp.where(bkv_idx == 0, jnp.full_like(ref, init_val),
429
+ ref[...])
430
+
431
+ pv = jnp.einsum("nm,md->nd", p, v, preferred_element_type=jnp.float32)
432
+ if v_scale is not None:
433
+ pv *= v_scale
434
+ o_prev = load_with_init(head_acc_ref, 0.0)
435
+ o_curr = broadcast_minor(exp_m_diff, o_prev.shape) * o_prev + pv
436
+ head_acc_ref[...] = o_curr
437
+
438
+ def _async_copy(src, dst, sem, wait):
439
+ if debug_mode:
440
+ # Skip DMA if debug mode is enabled.
441
+ return
442
+ cp = pltpu.make_async_copy(src, dst, sem)
443
+ if wait:
444
+ cp.wait()
445
+ else:
446
+ cp.start()
447
+
448
+ def _fetch_bkv(seq_idx, bkv_idx, bkv_sem_idx, *, wait=False):
449
+ sem = sems.at[0, bkv_sem_idx]
450
+ vmem_ref = bkv_x2_ref.at[bkv_sem_idx]
451
+
452
+ cache_hbm_shape = kv_cache_hbm_ref.shape
453
+ cache_hbm_ref = kv_cache_hbm_ref.reshape(
454
+ cache_hbm_shape[0] * cache_hbm_shape[1], *cache_hbm_shape[2:])
455
+ kv_len = kv_lens_ref[seq_idx]
456
+ kv_len_start = bkv_idx * bkv_sz
457
+ kv_p_start = bkv_idx * bkv_p
458
+ q_start = cu_q_lens_ref[seq_idx]
459
+ q_end = cu_q_lens_ref[seq_idx + 1]
460
+ q_len = q_end - q_start
461
+
462
+ kv_left = kv_len - kv_len_start
463
+ kv_left_frm_cache = jnp.maximum(kv_left - q_len, 0)
464
+ kv_left_frm_new = kv_left - kv_left_frm_cache
465
+ bkv_p_frm_cache = jnp.minimum(cdiv(kv_left_frm_cache, page_size),
466
+ bkv_p)
467
+ bkv_sz_frm_new = jnp.minimum(
468
+ jnp.maximum(bkv_sz - kv_left_frm_cache, 0), kv_left_frm_new)
469
+ page_indices_offset = seq_idx * pages_per_seq + kv_p_start
470
+
471
+ # Make sure the current bkv buffer is safe to overwrite.
472
+ wait_update_kv_cache(bkv_sem_idx)
473
+
474
+ debug_print(
475
+ "[RPA debug]"
476
+ f" -----------{'wait' if wait else 'start'}_fetch_bkv-----------")
477
+ debug_print("[RPA debug] seq_idx={}", seq_idx)
478
+ debug_print("[RPA debug] bkv_idx={}", bkv_idx)
479
+ debug_print("[RPA debug] bkv_sem_idx={}", bkv_sem_idx)
480
+ debug_print("[RPA debug] kv_len_start={}", kv_len_start)
481
+ debug_print("[RPA debug] kv_p_start={}", kv_p_start)
482
+ debug_print("[RPA debug] kv_left={}", kv_left)
483
+ debug_print("[RPA debug] kv_left_frm_cache={}", kv_left_frm_cache)
484
+ debug_print("[RPA debug] kv_left_frm_new={}", kv_left_frm_new)
485
+ debug_print("[RPA debug] bkv_p_frm_cache={}", bkv_p_frm_cache)
486
+ debug_print("[RPA debug] bkv_sz_frm_new={}", bkv_sz_frm_new)
487
+ debug_print("[RPA debug] page_indices_offset={}", page_indices_offset)
488
+
489
+ if not wait:
490
+ # Fetch effective kv from kv cache.
491
+ def loop_body(i, offset):
492
+ sz = jnp.minimum(page_size, kv_left_frm_cache - i * page_size)
493
+ _async_copy(
494
+ cache_hbm_ref.at[pl.ds(
495
+ page_indices_ref[page_indices_offset + i] * page_size,
496
+ sz)],
497
+ vmem_ref.at[pl.ds(i * page_size, sz)],
498
+ sem,
499
+ wait=False,
500
+ )
501
+ debug_print("[RPA debug] loop_body i={}, sz={}", i, sz)
502
+ return offset + sz
503
+
504
+ offset = lax.fori_loop(
505
+ 0,
506
+ bkv_p_frm_cache,
507
+ loop_body,
508
+ 0, # offset
509
+ unroll=False,
510
+ )
511
+
512
+ size = lax.select(bkv_sz_frm_new > 0, bkv_sz_frm_new, 0)
513
+ new_kv_len_start = q_end - kv_left_frm_new
514
+ debug_print("[RPA debug] new_kv_len_start={}", new_kv_len_start)
515
+ debug_print("[RPA debug] offset_in_bkv={}", offset)
516
+ _async_copy(
517
+ kv_hbm_ref.at[pl.ds(new_kv_len_start, size)],
518
+ vmem_ref.at[pl.ds(offset, size)],
519
+ sem,
520
+ wait,
521
+ )
522
+
523
+ return kv_len_start + offset, bkv_sz_frm_new
524
+ else:
525
+ offset = jnp.minimum(kv_left_frm_cache, page_size * bkv_p)
526
+ dst = vmem_ref.at[pl.ds(0, offset + bkv_sz_frm_new)]
527
+ _async_copy(
528
+ src=dst,
529
+ dst=dst,
530
+ sem=sem,
531
+ wait=True,
532
+ )
533
+ return kv_len_start + offset, bkv_sz_frm_new
534
+
535
+ def _update_kv_cache(seq_idx,
536
+ bkv_sem_idx,
537
+ offset,
538
+ update_sz,
539
+ *,
540
+ wait=False):
541
+ sem = sems.at[3, bkv_sem_idx]
542
+ vmem_ref = bkv_x2_ref.at[bkv_sem_idx]
543
+ bkv_id = offset // bkv_sz
544
+ kv_p_start = offset // page_size
545
+ kv_p_end = cdiv(offset + update_sz, page_size)
546
+ ignore = offset % page_size
547
+ p_ignore = kv_p_start - bkv_id * bkv_p
548
+ page_indices_offset = seq_idx * pages_per_seq + kv_p_start
549
+
550
+ cache_hbm_shape = updated_kv_cache_hbm_ref.shape
551
+ cache_hbm_ref = updated_kv_cache_hbm_ref.reshape(
552
+ cache_hbm_shape[0] * cache_hbm_shape[1], *cache_hbm_shape[2:])
553
+
554
+ debug_print(
555
+ "[RPA debug]"
556
+ f" -----------{'wait' if wait else 'start'}_update_kv_cache-----------"
557
+ )
558
+ debug_print("[RPA debug] seq_idx={}", seq_idx)
559
+ debug_print("[RPA debug] bkv_sem_idx={}", bkv_sem_idx)
560
+ debug_print("[RPA debug] offset={}", offset)
561
+ debug_print("[RPA debug] update_sz={}", update_sz)
562
+ debug_print("[RPA debug] bkv_id={}", bkv_id)
563
+ debug_print("[RPA debug] kv_p_start={}", kv_p_start)
564
+ debug_print("[RPA debug] kv_p_end={}", kv_p_end)
565
+ debug_print("[RPA debug] ignore={}", ignore)
566
+ debug_print("[RPA debug] p_ignore={}", p_ignore)
567
+ debug_print("[RPA debug] page_indices_offset={}", page_indices_offset)
568
+
569
+ if not wait:
570
+
571
+ def loop_body(i, states):
572
+ update_sz, ignore = states
573
+ sz = jnp.minimum(page_size - ignore, update_sz)
574
+
575
+ _async_copy(
576
+ vmem_ref.at[pl.ds((p_ignore + i) * page_size + ignore,
577
+ sz)],
578
+ cache_hbm_ref.at[pl.ds(
579
+ page_indices_ref[page_indices_offset + i] * page_size +
580
+ ignore,
581
+ sz,
582
+ )],
583
+ sem,
584
+ wait=False,
585
+ )
586
+ debug_print("[RPA debug] loop_body i={}, sz={}", i, sz)
587
+ return update_sz - sz, 0
588
+
589
+ lax.fori_loop(
590
+ 0,
591
+ kv_p_end - kv_p_start,
592
+ loop_body,
593
+ (update_sz, ignore), # total transfer size
594
+ unroll=False,
595
+ )
596
+ else:
597
+ dst = cache_hbm_ref.at[pl.ds(0, update_sz)]
598
+ _async_copy(
599
+ src=dst,
600
+ dst=dst,
601
+ sem=sem,
602
+ wait=True,
603
+ )
604
+
605
+ def _fetch_bq(seq_idx, bq_idx, bq_sem_idx, *, wait=False):
606
+ sem = sems.at[1, bq_sem_idx]
607
+ vmem_ref = bq_x2_ref.at[bq_sem_idx]
608
+ q_len_start = cu_q_lens_ref[seq_idx] + bq_idx * bq_sz
609
+ q_end = cu_q_lens_ref[seq_idx + 1]
610
+ sz = jnp.minimum(bq_sz, q_end - q_len_start)
611
+
612
+ debug_print(
613
+ "[RPA debug]"
614
+ f" -----------{'wait' if wait else 'start'}_fetch_bq-----------")
615
+ debug_print("[RPA debug] seq_idx={}", seq_idx)
616
+ debug_print("[RPA debug] bq_idx={}", bq_idx)
617
+ debug_print("[RPA debug] bq_sem_idx={}", bq_sem_idx)
618
+ debug_print("[RPA debug] q_len_start={}", q_len_start)
619
+ debug_print("[RPA debug] q_end={}", q_end)
620
+ debug_print("[RPA debug] sz={}", sz)
621
+
622
+ _async_copy(
623
+ q_hbm_ref.at[:, pl.ds(q_len_start, sz)],
624
+ vmem_ref.at[:, pl.ds(0, sz)],
625
+ sem,
626
+ wait,
627
+ )
628
+
629
+ def _send_bo(seq_idx, bo_idx, bo_sem_idx, *, wait=False):
630
+ sem = sems.at[2, bo_sem_idx]
631
+ vmem_ref = bo_x2_ref.at[bo_sem_idx]
632
+ q_len_start = cu_q_lens_ref[seq_idx] + bo_idx * bq_sz
633
+ q_end = cu_q_lens_ref[seq_idx + 1]
634
+ sz = jnp.minimum(bq_sz, q_end - q_len_start)
635
+
636
+ debug_print(
637
+ "[RPA debug]"
638
+ f" -----------{'wait' if wait else 'start'}_send_bo-----------")
639
+ debug_print("[RPA debug] seq_idx={}", seq_idx)
640
+ debug_print("[RPA debug] bo_idx={}", bo_idx)
641
+ debug_print("[RPA debug] bo_sem_idx={}", bo_sem_idx)
642
+ debug_print("[RPA debug] q_len_start={}", q_len_start)
643
+ debug_print("[RPA debug] q_end={}", q_end)
644
+ debug_print("[RPA debug] sz={}", sz)
645
+
646
+ _async_copy(
647
+ vmem_ref.at[:, pl.ds(0, sz)],
648
+ o_hbm_ref.at[:, pl.ds(q_len_start, sz)],
649
+ sem,
650
+ wait,
651
+ )
652
+
653
+ def start_fetch_bkv(seq_idx, bkv_idx, bkv_sem_idx):
654
+ return _fetch_bkv(seq_idx, bkv_idx, bkv_sem_idx)
655
+
656
+ def wait_fetch_bkv(seq_idx, bkv_idx, bkv_sem_idx):
657
+ return _fetch_bkv(seq_idx, bkv_idx, bkv_sem_idx, wait=True)
658
+
659
+ def start_fetch_bq(seq_idx, bq_idx, bq_sem_idx):
660
+ return _fetch_bq(seq_idx, bq_idx, bq_sem_idx)
661
+
662
+ def wait_fetch_bq(seq_idx, bq_idx, bq_sem_idx):
663
+ return _fetch_bq(seq_idx, bq_idx, bq_sem_idx, wait=True)
664
+
665
+ def start_send_bo(seq_idx, bo_idx, bo_sem_idx):
666
+ bo_ids_ref[bo_sem_idx] = seq_idx
667
+ bo_ids_ref[bo_sem_idx + 2] = bo_idx
668
+ _send_bo(seq_idx, bo_idx, bo_sem_idx)
669
+
670
+ def wait_send_bo(bo_sem_idx):
671
+ old_seq_idx = bo_ids_ref[bo_sem_idx]
672
+ old_bo_idx = bo_ids_ref[bo_sem_idx + 2]
673
+
674
+ @pl.when(jnp.logical_and(0 <= old_seq_idx, old_seq_idx <= seq_idx))
675
+ def _():
676
+ _send_bo(old_seq_idx, old_bo_idx, bo_sem_idx, wait=True)
677
+
678
+ def start_update_kv_cache(seq_idx, bkv_sem_idx, offset, update_sz):
679
+ bkv_update_ids_ref[bkv_sem_idx] = seq_idx
680
+ bkv_update_ids_ref[bkv_sem_idx + 2] = offset
681
+ bkv_update_ids_ref[bkv_sem_idx + 4] = update_sz
682
+ _update_kv_cache(seq_idx, bkv_sem_idx, offset, update_sz)
683
+
684
+ def wait_update_kv_cache(bkv_sem_idx):
685
+ update_sz = bkv_update_ids_ref[bkv_sem_idx + 4]
686
+
687
+ @pl.when(update_sz > 0)
688
+ def _():
689
+ seq_idx = bkv_update_ids_ref[bkv_sem_idx]
690
+ offset = bkv_update_ids_ref[bkv_sem_idx + 2]
691
+ bkv_update_ids_ref[bkv_sem_idx + 4] = 0
692
+ _update_kv_cache(seq_idx,
693
+ bkv_sem_idx,
694
+ offset,
695
+ update_sz,
696
+ wait=True)
697
+
698
+ def load_bq(bq_sem_idx, kv_head_idx, *, actual_bq_sz=bq_sz):
699
+ q_ref = (bq_x2_ref.bitcast(
700
+ jnp.uint32).at[bq_sem_idx, kv_head_idx].reshape(
701
+ bq_sz * num_q_heads_per_kv_head_per_packing, head_dim))
702
+ return pltpu.bitcast(
703
+ q_ref[:actual_bq_sz * num_q_heads_per_kv_head_per_packing],
704
+ q_dtype)
705
+
706
+ def strided_load(ref, start, step):
707
+ assert get_dtype_packing(ref.dtype) == 1
708
+ assert len(ref.shape) == 2
709
+ r, l = ref.shape # noqa
710
+ assert l % 128 == 0
711
+ folds = l // 128
712
+ ref = ref.reshape(r * folds, 128)
713
+ start *= folds
714
+ step *= folds
715
+ vec = jnp.concat([ref[start + i::step] for i in range(folds)], axis=1)
716
+ return vec
717
+
718
+ def strided_load_bkv(bkv_sem_idx, start, step):
719
+ assert start % kv_packing == 0
720
+ assert step % kv_packing == 0
721
+ start //= kv_packing
722
+ step //= kv_packing
723
+ kv_ref = (bkv_x2_ref.bitcast(jnp.uint32).at[bkv_sem_idx].reshape(
724
+ bkv_sz * step, head_dim))
725
+
726
+ if kv_packing == 1:
727
+ k = strided_load(kv_ref, start, step)
728
+ v = strided_load(kv_ref, start + 1, step)
729
+
730
+ k = pltpu.bitcast(k, kv_dtype)
731
+ v = pltpu.bitcast(v, kv_dtype)
732
+ return [(k, v)]
733
+
734
+ kv = strided_load(kv_ref, start, step)
735
+ bitwidth = 32 // kv_packing
736
+
737
+ # If we want to convert 32-bits into 32//N number of N-bits value, naive
738
+ # approach would be to perform 32//N number of 32-bits to N-bits conversion.
739
+ # However, we can reduce number of instructions by utilizing binary tree.
740
+ # 0: [32]
741
+ # 1: [16, 16]
742
+ # ...
743
+ # log2(32//N): [N, N, ... N]
744
+
745
+ def _convert_to_target_bitwidth(val, target_bitwidth: int):
746
+ curr_dtype = val.dtype
747
+ curr_bitwidth = get_dtype_bitwidth(curr_dtype)
748
+ assert target_bitwidth != curr_bitwidth, "No conversion is needed."
749
+
750
+ # We split val into two vals (left and right) where each have half of the
751
+ # original bitwidth.
752
+ next_bitwidth = curr_bitwidth // 2
753
+ next_dtype = jnp.dtype(f"uint{next_bitwidth}")
754
+
755
+ left = val.astype(next_dtype)
756
+
757
+ # Bitwise shift is only supported in uint32.
758
+ val_u32 = pltpu.bitcast(val, jnp.uint32)
759
+ val_u32_shifted = val_u32 >> next_bitwidth
760
+ # Convert back to original dtype.
761
+ val_shifted = pltpu.bitcast(val_u32_shifted, curr_dtype)
762
+ right = val_shifted.astype(next_dtype)
763
+
764
+ if next_bitwidth == target_bitwidth:
765
+ k = pltpu.bitcast(left, kv_dtype)
766
+ v = pltpu.bitcast(right, kv_dtype)
767
+ return [(k, v)]
768
+ else:
769
+ left_out = _convert_to_target_bitwidth(
770
+ left,
771
+ target_bitwidth=target_bitwidth,
772
+ )
773
+ right_out = _convert_to_target_bitwidth(
774
+ right,
775
+ target_bitwidth=target_bitwidth,
776
+ )
777
+ return left_out + right_out
778
+
779
+ return _convert_to_target_bitwidth(kv, target_bitwidth=bitwidth)
780
+
781
+ def broadcast_minor(src, shape):
782
+ if src.shape == shape:
783
+ return src
784
+ assert src.shape[:-1] == shape[:-1]
785
+ assert src.shape[-1] % 128 == 0
786
+ target_minor = align_to(shape[-1], src.shape[-1])
787
+ # no-op concatenation.
788
+ return jnp.concatenate(
789
+ [src for _ in range(target_minor // src.shape[-1])],
790
+ axis=-1)[..., :shape[-1]]
791
+
792
+ def process(static_q_len=None):
793
+ num_bkv = cdiv(kv_len, bkv_sz)
794
+ if static_q_len is None:
795
+ actual_bq_sz = bq_sz
796
+ num_bq = cdiv(q_len, actual_bq_sz)
797
+ else:
798
+ actual_bq_sz = min(bq_sz, static_q_len)
799
+ num_bq = cdiv(static_q_len, actual_bq_sz)
800
+
801
+ def get_next_bq_ids(seq_idx, bq_idx, bq_sem_idx):
802
+ next_bq_idx = bq_idx + 1
803
+ is_last_bq = next_bq_idx == num_bq
804
+ next_bq_idx = lax.select(is_last_bq, 0, next_bq_idx)
805
+ next_seq_idx = lax.select(is_last_bq, seq_idx + 1, seq_idx)
806
+ next_bq_sem_idx = lax.select(bq_sem_idx == 0, 1, 0)
807
+ return next_seq_idx, next_bq_idx, next_bq_sem_idx
808
+
809
+ def get_next_bkv_ids(seq_idx, bq_idx, bkv_idx, bkv_sem_idx):
810
+ next_bkv_idx = bkv_idx + 1
811
+ is_last_bkv = next_bkv_idx == num_bkv
812
+ next_bq_idx = lax.select(is_last_bkv, bq_idx + 1, bq_idx)
813
+ is_last_bq = next_bq_idx == num_bq
814
+ next_bq_idx = lax.select(is_last_bq, 0, next_bq_idx)
815
+ next_seq_idx = lax.select(is_last_bq, seq_idx + 1, seq_idx)
816
+ next_bkv_sem_idx = lax.select(bkv_sem_idx == 0, 1, 0)
817
+
818
+ if sliding_window is None:
819
+ # When sliding window is disabled, starting bkv_idx of next request is
820
+ # always 0 regardless of seq_idx of next request.
821
+ next_bkv_idx_start = 0
822
+ else:
823
+ # Determine starting bkv_idx of next request based on whether next
824
+ # request is from the same sequence or next sequence.
825
+ next_bkv_idx_start = lax.select(
826
+ is_last_bq,
827
+ next_seq_bkv_idx_start,
828
+ bkv_idx_start,
829
+ )
830
+ next_bkv_idx = lax.select(is_last_bkv, next_bkv_idx_start,
831
+ next_bkv_idx)
832
+
833
+ return next_seq_idx, next_bq_idx, next_bkv_idx, next_bkv_sem_idx
834
+
835
+ def compute_with_bq(bq_idx, _):
836
+ bq_sem_idx = sem_ids_ref[0]
837
+ next_seq_idx, next_bq_idx, next_bq_sem_idx = get_next_bq_ids(
838
+ seq_idx, bq_idx, bq_sem_idx)
839
+
840
+ # Prefetch next bq
841
+ @pl.when(next_seq_idx < num_seqs)
842
+ def prefetch_next_bq():
843
+ sem_ids_ref[0] = next_bq_sem_idx
844
+ start_fetch_bq(next_seq_idx, next_bq_idx, next_bq_sem_idx)
845
+
846
+ def compute_with_bkv(bkv_idx, _):
847
+ # Create bitmask for KV.
848
+ assert bkv_sz % kv_packing == 0
849
+
850
+ # Get next bkv ids.
851
+ bkv_sem_idx = sem_ids_ref[1]
852
+ next_seq_idx, _, next_bkv_idx, next_bkv_sem_idx = get_next_bkv_ids(
853
+ seq_idx, bq_idx, bkv_idx, bkv_sem_idx)
854
+
855
+ # Prefetch next bkv
856
+ @pl.when(next_seq_idx < num_seqs)
857
+ def prefetch_next_bkv():
858
+ sem_ids_ref[1] = next_bkv_sem_idx
859
+ start_fetch_bkv(next_seq_idx, next_bkv_idx,
860
+ next_bkv_sem_idx)
861
+
862
+ # Wait for cur bq if not ready yet
863
+ @pl.when(bkv_idx == 0)
864
+ def wait_cur_bq():
865
+ wait_fetch_bq(seq_idx, bq_idx, bq_sem_idx)
866
+
867
+ # Wait for cur bkv
868
+ offset, update_sz = wait_fetch_bkv(seq_idx, bkv_idx,
869
+ bkv_sem_idx)
870
+
871
+ # Start updating bkv to kv cache if applicable.
872
+ # Only needed in first bq loop.
873
+ @pl.when(jnp.logical_and(update_sz > 0, bq_idx == 0))
874
+ def update_cur_bkv_to_cache():
875
+ start_update_kv_cache(seq_idx, bkv_sem_idx, offset,
876
+ update_sz)
877
+
878
+ debug_print(
879
+ "[RPA debug] -----------flash attention-----------")
880
+ debug_print("[RPA debug] seq_idx={}", seq_idx)
881
+ debug_print("[RPA debug] bq_idx={}", bq_idx)
882
+ debug_print("[RPA debug] bkv_idx={}", bkv_idx)
883
+ if debug_mode:
884
+ # Skip flash attention if debug mode is enabled.
885
+ return
886
+
887
+ # Flash attention with cur bkv and bq
888
+ # NOTE: kv_packing is divided by 2 because k and v are packed together.
889
+ prev_bq_shape_0 = None
890
+ prev_kv_head_bv = None
891
+ prev_kv_head_idx = None
892
+ prev_kv_head_p = None
893
+ prev_kv_head_exp_m_diff = None
894
+ heads_per_load = max(1, kv_packing // 2)
895
+ for kv_head_start in range(0, actual_num_kv_heads,
896
+ heads_per_load):
897
+ bkv_lst = strided_load_bkv(
898
+ bkv_sem_idx,
899
+ kv_head_start * 2,
900
+ num_kv_heads_x2,
901
+ )
902
+ assert len(bkv_lst) == heads_per_load
903
+ for i in range(heads_per_load):
904
+ cur_kv_head_idx = kv_head_start + i
905
+ if cur_kv_head_idx >= actual_num_kv_heads:
906
+ break
907
+
908
+ cur_kv_head_bq = load_bq(bq_sem_idx,
909
+ cur_kv_head_idx,
910
+ actual_bq_sz=actual_bq_sz)
911
+ bk, bv = bkv_lst[i]
912
+ # FlashAttention is divided into `flash_attention_step1_qk_softmax`
913
+ # and `flash_attention_step2_pv` to pipeline the computation.
914
+ # `step2_pv` for the previous KV head, which depends on the softmax
915
+ # output, is overlapped with `step1_qk_softmax` for the current KV
916
+ # head, reducing overall wait times.
917
+ cur_kv_head_p, cur_kv_head_exp_m_diff = (
918
+ flash_attention_step1_qk_softmax(
919
+ cur_kv_head_bq,
920
+ bk,
921
+ bv,
922
+ bq_idx=bq_idx,
923
+ bkv_idx=bkv_idx,
924
+ kv_head_idx=cur_kv_head_idx,
925
+ ))
926
+ if prev_bq_shape_0 is not None:
927
+ flash_attention_step2_pv(
928
+ prev_bq_shape_0,
929
+ prev_kv_head_bv,
930
+ prev_kv_head_p,
931
+ prev_kv_head_exp_m_diff,
932
+ bkv_idx=bkv_idx,
933
+ kv_head_idx=prev_kv_head_idx,
934
+ )
935
+ prev_bq_shape_0 = cur_kv_head_bq.shape[0]
936
+ prev_kv_head_bv = bv
937
+ prev_kv_head_p = cur_kv_head_p
938
+ prev_kv_head_exp_m_diff = cur_kv_head_exp_m_diff
939
+ prev_kv_head_idx = cur_kv_head_idx
940
+
941
+ # Execute pv of last attention head.
942
+ assert prev_bq_shape_0 is not None
943
+ flash_attention_step2_pv(
944
+ prev_bq_shape_0,
945
+ prev_kv_head_bv,
946
+ prev_kv_head_p,
947
+ prev_kv_head_exp_m_diff,
948
+ bkv_idx=bkv_idx,
949
+ kv_head_idx=prev_kv_head_idx,
950
+ )
951
+
952
+ lax.fori_loop(0, num_bkv, compute_with_bkv, None, unroll=False)
953
+
954
+ # Load acc and calculate final output.
955
+ acc = acc_ref[...]
956
+ l = broadcast_minor(l_ref[...], acc.shape) # noqa
957
+ out = (lax.div(acc, l) if q_dtype == jnp.float32 else
958
+ (acc * pl.reciprocal(l, approx=True)).astype(q_dtype))
959
+
960
+ # Wait for previous bo to be fully sent before storing new bo.
961
+ bo_sem_idx = sem_ids_ref[2]
962
+ sem_ids_ref[2] = lax.select(bo_sem_idx == 0, 1, 0)
963
+ wait_send_bo(bo_sem_idx)
964
+
965
+ # Store output from acc to bo.
966
+ bo_x2_ref.at[bo_sem_idx].bitcast(jnp.int32).reshape(
967
+ actual_num_kv_heads,
968
+ bq_sz * num_q_heads_per_kv_head_per_packing,
969
+ head_dim,
970
+ )[...] = pltpu.bitcast(out, jnp.int32)
971
+
972
+ # Send cur bo
973
+ start_send_bo(seq_idx, bq_idx, bo_sem_idx)
974
+
975
+ lax.fori_loop(0, num_bq, compute_with_bq, None, unroll=False)
976
+
977
+ ### ------- Kernel start ------- ###
978
+
979
+ @pl.when(seq_idx == 0)
980
+ def prologue():
981
+ start_fetch_bq(0, 0, 0)
982
+
983
+ # Initialize bkv_x2_ref to zeros to avoid NaN issues from accessing
984
+ # uninitialized memory. Bitcast into int32 to avoid tiling issues.
985
+ bkv_x2_int32_ref = bkv_x2_ref.bitcast(jnp.int32).reshape(
986
+ (2, -1, 8, 128))
987
+ zeros = jnp.zeros(bkv_x2_int32_ref.shape[1:], jnp.int32)
988
+
989
+ # To pipeline VST and DMA, we divide the initialization into two steps.
990
+ bkv_x2_int32_ref[0] = zeros
991
+ start_fetch_bkv(0, bkv_idx_start, 0)
992
+ bkv_x2_int32_ref[1] = zeros
993
+
994
+ @pl.when(seq_idx < decode_end)
995
+ def process_decode():
996
+ process(static_q_len=1)
997
+
998
+ @pl.when(jnp.logical_and(decode_end <= seq_idx, seq_idx < prefill_end))
999
+ def process_prefill():
1000
+ process(static_q_len=chunk_prefill_size)
1001
+
1002
+ @pl.when(jnp.logical_and(prefill_end <= seq_idx, seq_idx < mixed_end))
1003
+ def process_mixed():
1004
+ process()
1005
+
1006
+ @pl.when(seq_idx == num_seqs - 1)
1007
+ def epilogue():
1008
+ for i in range(2):
1009
+ wait_send_bo(i)
1010
+ wait_update_kv_cache(i)
1011
+
1012
+ ### ------- Kernel end ------- ###
1013
+
1014
+
1015
+ def merge_kv(
1016
+ k: jax.
1017
+ Array, # [max_num_tokens, actual_num_kv_heads, actual_head_dim],
1018
+ v: jax.
1019
+ Array, # [max_num_tokens, actual_num_kv_heads, actual_head_dim],
1020
+ ):
1021
+ assert k.shape == v.shape
1022
+ assert k.dtype == v.dtype
1023
+ max_num_tokens, actual_num_kv_heads, actual_head_dim = k.shape
1024
+ kv_packing = get_dtype_packing(k.dtype)
1025
+ actual_num_kv_heads_x2 = actual_num_kv_heads * 2
1026
+ num_kv_heads_x2 = align_to(actual_num_kv_heads_x2, kv_packing)
1027
+ head_dim = align_to(actual_head_dim, 128)
1028
+ kv = jnp.pad(
1029
+ jnp.concat([k, v],
1030
+ axis=-1).reshape(max_num_tokens, actual_num_kv_heads_x2,
1031
+ actual_head_dim),
1032
+ (
1033
+ (0, 0),
1034
+ (0, num_kv_heads_x2 - actual_num_kv_heads_x2),
1035
+ (0, head_dim - actual_head_dim),
1036
+ ),
1037
+ constant_values=0,
1038
+ ).reshape(
1039
+ max_num_tokens,
1040
+ num_kv_heads_x2 // kv_packing,
1041
+ kv_packing,
1042
+ head_dim,
1043
+ )
1044
+ return kv
1045
+
1046
+
1047
+ def prepare_inputs(
1048
+ q: jax.Array, # [max_num_tokens, actual_num_q_heads, actual_head_dim],
1049
+ k: jax.
1050
+ Array, # [max_num_tokens, actual_num_kv_heads, actual_head_dim],
1051
+ v: jax.
1052
+ Array, # [max_num_tokens, actual_num_kv_heads, actual_head_dim],
1053
+ ):
1054
+ max_num_tokens, actual_num_q_heads, actual_head_dim = q.shape
1055
+ actual_num_kv_heads = k.shape[1]
1056
+ assert actual_num_q_heads % actual_num_kv_heads == 0
1057
+ actual_num_q_heads_per_kv_head = actual_num_q_heads // actual_num_kv_heads
1058
+ q_packing = get_dtype_packing(q.dtype)
1059
+ num_q_heads_per_kv_head = align_to(actual_num_q_heads_per_kv_head,
1060
+ q_packing)
1061
+ head_dim = align_to(actual_head_dim, 128)
1062
+ q = (
1063
+ jnp.pad(
1064
+ q.reshape(
1065
+ max_num_tokens,
1066
+ actual_num_kv_heads,
1067
+ actual_num_q_heads_per_kv_head,
1068
+ actual_head_dim,
1069
+ ),
1070
+ (
1071
+ (0, 0),
1072
+ (0, 0),
1073
+ (0, num_q_heads_per_kv_head - actual_num_q_heads_per_kv_head),
1074
+ (0, head_dim - actual_head_dim),
1075
+ ),
1076
+ constant_values=0,
1077
+ ).reshape(
1078
+ max_num_tokens,
1079
+ actual_num_kv_heads,
1080
+ num_q_heads_per_kv_head // q_packing,
1081
+ q_packing,
1082
+ head_dim,
1083
+ )
1084
+ # TODO(jevinjiang): Explore fusing swapping non-tiling axis to DMA.
1085
+ .swapaxes(0, 1))
1086
+ # TODO(kyuyeunk, chengjiyao): Add kv quantization here.
1087
+ kv = merge_kv(k, v)
1088
+ return q, kv
1089
+
1090
+
1091
+ def prepare_outputs(
1092
+ out, # [actual_num_kv_heads, max_num_tokens, num_q_heads_per_kv_head // q_packing, q_packing, head_dim]
1093
+ actual_num_q_heads_per_kv_head: int,
1094
+ actual_head_dim: int,
1095
+ ):
1096
+ (
1097
+ actual_num_kv_heads,
1098
+ max_num_tokens,
1099
+ num_q_heads_per_kv_head_per_q_packing,
1100
+ q_packing,
1101
+ head_dim,
1102
+ ) = out.shape
1103
+ actual_num_q_heads = actual_num_q_heads_per_kv_head * actual_num_kv_heads
1104
+ return (out.swapaxes(0, 1).reshape(
1105
+ max_num_tokens,
1106
+ actual_num_kv_heads,
1107
+ num_q_heads_per_kv_head_per_q_packing * q_packing,
1108
+ head_dim,
1109
+ )[:, :, :actual_num_q_heads_per_kv_head, :actual_head_dim].reshape(
1110
+ max_num_tokens, actual_num_q_heads, actual_head_dim))
1111
+
1112
+
1113
+ # This validation is expected to run at runtime (on concrete values).
+ def dynamic_validate_inputs(
+     queries: jax.Array,  # [max_num_tokens, actual_num_q_heads, actual_head_dim]
+     keys: jax.Array,  # [max_num_tokens, actual_num_kv_heads, actual_head_dim]
+     values: jax.Array,  # [max_num_tokens, actual_num_kv_heads, actual_head_dim]
+     kv_cache: jax.Array,  # [total_num_pages, page_size, num_kv_heads_x2 // kv_packing, kv_packing, head_dim]
+     kv_lens: jax.Array,  # i32[max_num_seqs]
+     page_indices: jax.Array,  # i32[max_num_seqs * pages_per_seq]
+     cu_q_lens: jax.Array,  # i32[max_num_seqs + 1]
+     distribution: jax.Array,  # i32[3]
+     *,
+     sm_scale: float = 1.0,
+     sliding_window: int | None = None,
+     soft_cap: float | None = None,
+     mask_value: float | None = DEFAULT_MASK_VALUE,
+     q_scale: float | None = None,
+     k_scale: float | None = None,
+     v_scale: float | None = None,
+     # Kernel optimization params.
+     chunk_prefill_size: int | None = None,
+     # Kernel tuning params.
+     num_kv_pages_per_block: int | None = None,
+     num_queries_per_block: int | None = None,
+     vmem_limit_bytes: int | None = None,
+     # Debug params.
+     debug_mode: bool = False,
+ ):
+     q, k, v = queries, keys, values
+     static_validate_inputs(
+         q,
+         k,
+         v,
+         kv_cache,
+         kv_lens,
+         page_indices,
+         cu_q_lens,
+         distribution,
+         sm_scale=sm_scale,
+         sliding_window=sliding_window,
+         soft_cap=soft_cap,
+         mask_value=mask_value,
+         q_scale=q_scale,
+         k_scale=k_scale,
+         v_scale=v_scale,
+         chunk_prefill_size=chunk_prefill_size,
+         num_kv_pages_per_block=num_kv_pages_per_block,
+         num_queries_per_block=num_queries_per_block,
+         vmem_limit_bytes=vmem_limit_bytes,
+         debug_mode=debug_mode,
+     )
+     max_num_tokens = q.shape[0]
+     total_num_pages = kv_cache.shape[0]
+     page_size = kv_cache.shape[1]
+     max_num_seqs = kv_lens.shape[0]
+     num_page_indices = page_indices.shape[0]
+     assert num_page_indices % max_num_seqs == 0
+     pages_per_seq = num_page_indices // max_num_seqs
+
+     i, j, k = distribution
+     if not (i <= j <= k):
+         raise ValueError(f"Invalid distribution: {distribution=}")
+
+     if k > max_num_seqs:
+         raise ValueError(f"num_seqs={k} must be <= {max_num_seqs=}")
+
+     if cu_q_lens[k] > max_num_tokens:
+         raise ValueError(
+             f"Total q tokens {cu_q_lens[k]} must be <= {max_num_tokens=}.")
+     for i in range(k):
+         q_len = cu_q_lens[i + 1] - cu_q_lens[i]
+         kv_len = kv_lens[i]
+         if not (0 < q_len <= kv_len):
+             raise ValueError(
+                 f"Require 0 < {q_len=} <= {kv_len=} at sequence {i}.")
+         page_cnt = cdiv(kv_len, page_size)
+         if page_cnt > pages_per_seq:
+             raise ValueError(
+                 f"Require {page_cnt=} <= {pages_per_seq=} at sequence {i} where"
+                 f" {kv_len=} and {page_size=}.")
+         for p in range(page_cnt):
+             page_idx = page_indices[i * pages_per_seq + p]
+             if not (0 <= page_idx < total_num_pages):
+                 raise ValueError(
+                     f"Require 0 <= {page_idx=} < {total_num_pages=} at sequence"
+                     f" {i} where {kv_len=} and {page_size=}.")
+
+
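As an illustration of the expected metadata layout, the following toy batch of two decode-only sequences followed by one chunked-prefill sequence passes both validators. All sizes are hypothetical, and it assumes the module's helpers (align_to, cdiv, get_dtype_packing) behave as the checks above imply, in particular that get_dtype_packing reports two bfloat16 values per 32-bit word:

import jax.numpy as jnp

max_num_tokens, num_q_heads, num_kv_heads, head_dim = 64, 8, 2, 128
max_num_seqs, pages_per_seq, page_size, total_num_pages = 4, 4, 16, 16

queries = jnp.zeros((max_num_tokens, num_q_heads, head_dim), jnp.bfloat16)
keys = jnp.zeros((max_num_tokens, num_kv_heads, head_dim), jnp.bfloat16)
values = jnp.zeros_like(keys)
# With kv_packing == 2 the combined K+V heads (2 * 2 = 4) are stored as
# 4 // 2 = 2 packed groups of 2.
kv_cache = jnp.zeros((total_num_pages, page_size, 2, 2, head_dim),
                     jnp.bfloat16)

kv_lens = jnp.array([17, 40, 60, 0], jnp.int32)      # first 3 entries valid
cu_q_lens = jnp.array([0, 1, 2, 62, 0], jnp.int32)   # 1 + 1 + 60 query tokens
page_indices = jnp.array(
    [3, 7, 0, 0,    # seq 0: ceil(17 / 16) = 2 pages in use
     1, 4, 9, 0,    # seq 1: ceil(40 / 16) = 3 pages in use
     2, 5, 6, 8,    # seq 2: ceil(60 / 16) = 4 pages in use
     0, 0, 0, 0],   # padding
    jnp.int32)
# sequences[0:2] are decode-only, sequences[2:3] are chunked-prefill-only,
# and there are 3 sequences in total.
distribution = jnp.array([2, 3, 3], jnp.int32)

dynamic_validate_inputs(queries, keys, values, kv_cache, kv_lens,
                        page_indices, cu_q_lens, distribution)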
+ # This validation is expected to run at compile/trace time.
+ def static_validate_inputs(
+     queries: jax.Array,  # [max_num_tokens, actual_num_q_heads, actual_head_dim]
+     keys: jax.Array,  # [max_num_tokens, actual_num_kv_heads, actual_head_dim]
+     values: jax.Array,  # [max_num_tokens, actual_num_kv_heads, actual_head_dim]
+     kv_cache: jax.Array,  # [total_num_pages, page_size, num_kv_heads_x2 // kv_packing, kv_packing, head_dim]
+     kv_lens: jax.Array,  # i32[max_num_seqs]
+     page_indices: jax.Array,  # i32[max_num_seqs * pages_per_seq]
+     cu_q_lens: jax.Array,  # i32[max_num_seqs + 1]
+     distribution: jax.Array,  # i32[3]
+     *,
+     sm_scale: float = 1.0,
+     sliding_window: int | None = None,
+     soft_cap: float | None = None,
+     mask_value: float | None = DEFAULT_MASK_VALUE,
+     q_scale: float | None = None,
+     k_scale: float | None = None,
+     v_scale: float | None = None,
+     # Kernel optimization params.
+     chunk_prefill_size: int | None = None,
+     # Kernel tuning params.
+     num_kv_pages_per_block: int | None = None,
+     num_queries_per_block: int | None = None,
+     vmem_limit_bytes: int | None = None,
+     # Debug params.
+     debug_mode: bool = False,
+ ):
+     """Validate inputs to the RPA kernel statically."""
+     q, k, v = queries, keys, values
+     if not (len(q.shape) == len(k.shape) == len(v.shape) == 3):
+         raise ValueError(
+             f"Expected 3D array for {q.shape=}, {k.shape=}, {v.shape=}")
+     if k.shape != v.shape:
+         raise ValueError(f"Expected {k.shape=} to be equal to {v.shape=}")
+     if not (q.shape[0] == k.shape[0] == v.shape[0]):
+         raise ValueError(
+             f"Expected {q.shape[0]=} to be equal to {k.shape[0]=} and {v.shape[0]=}"
+         )
+     if not (q.shape[2] == k.shape[2] == v.shape[2]):
+         raise ValueError(
+             f"Expected {q.shape[2]=} to be equal to {k.shape[2]=} and {v.shape[2]=}"
+         )
+
+     actual_head_dim = q.shape[2]
+     actual_num_q_heads = q.shape[1]
+     actual_num_kv_heads = k.shape[1]
+
+     if actual_num_q_heads % actual_num_kv_heads != 0:
+         raise ValueError(f"Expected {actual_num_q_heads=} to be divisible by"
+                          f" {actual_num_kv_heads=}.")
+
+     (
+         _,
+         page_size,
+         num_kv_heads_x2_per_kv_packing,
+         kv_packing,
+         head_dim,
+     ) = kv_cache.shape
+
+     if head_dim != align_to(actual_head_dim, 128):
+         raise ValueError(
+             f"Expected {head_dim=} to be equal to {align_to(actual_head_dim, 128)=}"
+         )
+     # Note: we expect kv quantization to happen outside of the RPA kernel.
+     if not (kv_cache.dtype == k.dtype == v.dtype):
+         raise ValueError(
+             f"Expected {kv_cache.dtype=} to be equal to {k.dtype=} and {v.dtype=}."
+         )
+     # Integer kv quantization is currently not supported.
+     if not jnp.issubdtype(kv_cache.dtype, jnp.floating):
+         raise ValueError(
+             f"Expected {kv_cache.dtype=} to be a floating-point dtype.")
+     if kv_packing != get_dtype_packing(kv_cache.dtype):
+         raise ValueError(
+             f"{kv_packing=} does not match with {kv_cache.dtype=}")
+
+     num_kv_heads_x2 = num_kv_heads_x2_per_kv_packing * kv_packing
+     if num_kv_heads_x2 % 2 != 0:
+         raise ValueError(
+             f"Combined KV heads must be divisible by 2, but got {num_kv_heads_x2}"
+         )
+     if align_to(actual_num_kv_heads * 2, kv_packing) != num_kv_heads_x2:
+         raise ValueError(
+             f"Invalid {num_kv_heads_x2=}, {actual_num_kv_heads=}, {kv_packing=}"
+         )
+
+     if not (jnp.int32 == kv_lens.dtype == page_indices.dtype == cu_q_lens.dtype
+             == distribution.dtype):
+         raise ValueError(
+             f"Expected int32 dtype for {kv_lens.dtype=}, {page_indices.dtype=},"
+             f" {cu_q_lens.dtype=}, {distribution.dtype=}")
+
+     if not (len(kv_lens.shape) == len(page_indices.shape) == len(
+             cu_q_lens.shape) == 1):
+         raise ValueError(
+             f"Expected 1D array for {kv_lens.shape=}, {page_indices.shape=},"
+             f" {cu_q_lens.shape=}")
+
+     max_num_seqs = kv_lens.shape[0]
+     num_page_indices = page_indices.shape[0]
+     if num_page_indices % max_num_seqs != 0:
+         raise ValueError(
+             f"Expected {num_page_indices=} to be divisible by {max_num_seqs=}."
+         )
+     if cu_q_lens.shape != (max_num_seqs + 1, ):
+         raise ValueError(
+             f"Expected {cu_q_lens.shape=} to be ({max_num_seqs + 1},).")
+     if distribution.shape != (3, ):
+         raise ValueError(f"Expected {distribution.shape=} to be (3,).")
+
+     if page_size % kv_packing != 0:
+         raise ValueError(f"{page_size=} must be divisible by {kv_packing=}.")
+     if sliding_window is not None and sliding_window <= 0:
+         raise ValueError(f"{sliding_window=} must be positive.")
+     if soft_cap is not None and soft_cap == 0.0:
+         raise ValueError(f"{soft_cap=} must not be 0.0.")
+     if chunk_prefill_size is not None and chunk_prefill_size <= 0:
+         raise ValueError(f"{chunk_prefill_size=} must be positive.")
+     if num_kv_pages_per_block is not None:
+         if num_kv_pages_per_block <= 0:
+             raise ValueError(f"{num_kv_pages_per_block=} must be positive.")
+     if num_queries_per_block is not None:
+         if num_queries_per_block <= 0:
+             raise ValueError(f"{num_queries_per_block=} must be positive.")
+     if vmem_limit_bytes is not None and vmem_limit_bytes <= 0:
+         raise ValueError(f"{vmem_limit_bytes=} must be positive.")
+
+     # No constraints for the following inputs.
+     del sm_scale
+     del mask_value
+     del q_scale
+     del k_scale
+     del v_scale
+     del debug_mode
+
+
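The packed kv_cache dimensions checked above can be derived as in the following sketch. The helper name make_kv_cache_shape is illustrative and not part of this module, and it assumes get_dtype_packing follows the usual 32-bit-word rule (2 values for bfloat16, 4 for 8-bit floats):

import jax.numpy as jnp
import numpy as np


def make_kv_cache_shape(total_num_pages, page_size, actual_num_kv_heads,
                        actual_head_dim, dtype):
    """Illustrative only: mirrors the constraints in static_validate_inputs."""
    # Assumed packing rule: number of values per 32-bit word.
    packing = 32 // (np.dtype(dtype).itemsize * 8)
    num_kv_heads_x2 = align_to(actual_num_kv_heads * 2, packing)
    head_dim = align_to(actual_head_dim, 128)
    return (total_num_pages, page_size, num_kv_heads_x2 // packing, packing,
            head_dim)


# E.g. 4 KV heads with head_dim 96 in bfloat16 (packing of 2):
assert make_kv_cache_shape(1024, 64, 4, 96, jnp.bfloat16) == (1024, 64, 4, 2,
                                                              128)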
+ def get_kernel_scope_name(bq_size, bkv_p, page_size):
+     return f"RPA-bq_{bq_size}-bkvp_{bkv_p}-p_{page_size}-"
+
+
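For example, with hypothetical block sizes of 32 queries and 8 KV pages per block and 64-token pages:

assert get_kernel_scope_name(32, 8, 64) == "RPA-bq_32-bkvp_8-p_64-"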
+ @functools.partial(
+     jax.jit,
+     static_argnames=(
+         "sm_scale",
+         "sliding_window",
+         "soft_cap",
+         "mask_value",
+         "q_scale",
+         "k_scale",
+         "v_scale",
+         "chunk_prefill_size",
+         "num_kv_pages_per_block",
+         "num_queries_per_block",
+         "vmem_limit_bytes",
+         "debug_mode",
+     ),
+     donate_argnames=("kv_cache", ),
+ )
+ def ragged_paged_attention(
+     queries: jax.Array,  # [max_num_tokens, actual_num_q_heads, actual_head_dim]
+     keys: jax.Array,  # [max_num_tokens, actual_num_kv_heads, actual_head_dim]
+     values: jax.Array,  # [max_num_tokens, actual_num_kv_heads, actual_head_dim]
+     kv_cache: jax.Array,  # [total_num_pages, page_size, num_kv_heads_x2 // kv_packing, kv_packing, head_dim]
+     kv_lens: jax.Array,  # i32[max_num_seqs]
+     page_indices: jax.Array,  # i32[max_num_seqs * pages_per_seq]
+     cu_q_lens: jax.Array,  # i32[max_num_seqs + 1]
+     distribution: jax.Array,  # i32[3]
+     *,
+     sm_scale: float = 1.0,
+     sliding_window: int | None = None,
+     soft_cap: float | None = None,
+     mask_value: float | None = DEFAULT_MASK_VALUE,
+     q_scale: float | None = None,
+     k_scale: float | None = None,
+     v_scale: float | None = None,
+     # Kernel optimization params.
+     chunk_prefill_size: int | None = None,
+     # Kernel tuning params.
+     num_kv_pages_per_block: int | None = None,
+     num_queries_per_block: int | None = None,
+     vmem_limit_bytes: int | None = None,
+     # Debug params.
+     debug_mode: bool = False,
+ ):
+ """Ragged paged attention that supports mixed prefill and decode.
1393
+
1394
+ Args:
1395
+ queries: concatenated all sequences' queries.
1396
+ keys: concatenated all sequences' keys (quantized).
1397
+ values: concatenated all sequences' values (quantized).
1398
+ kv_cache: paged KV cache with TPU-friendly shape.
1399
+ kv_lens: padded kv lengths. Only the first num_seqs values are valid.
1400
+ page_indices: flattened page indices look-up table by (seq_id, page_id).
1401
+ cu_q_lens: the cumulative sum of the effective query lengths. Similar to
1402
+ kv_lens, only the first num_seqs+1 values are valid.
1403
+ distribution: (i, j, k) represents that sequences[0:i] are decode-only,
1404
+ sequences[i:j] are chunked-prefill-only, and sequences[j:k] are mixed. The
1405
+ k is also the total number of sequences.
1406
+ sm_scale: the softmax scale which will be applied to the Q@K^T.
1407
+ sliding_window: the sliding window size for the attention.
1408
+ soft_cap: the logit soft cap for the attention.
1409
+ mask_value: mask value for causal mask.
1410
+ q_scale: the scale for the query.
1411
+ k_scale: the scale for the key cache.
1412
+ v_scale: the scale for the value cache.
1413
+ chunk_prefill_size: the chunk prefill size for the attention.
1414
+ num_kv_pages_per_block: number of kv pages to be processed in one flash
1415
+ attention block in the pallas kernel.
1416
+ num_queries_per_block: number of kv pages to be processed in one flash
1417
+ attention block in the pallas kernel.
1418
+ vmem_limit_bytes: the vmem limit for the pallas kernel.
1419
+ debug_mode: if true, RPA does not issue any DMAs or run flash attention but
1420
+ print debug info. Need to compile with `--xla_tpu_enable_log_recorder`.
1421
+
1422
+ Returns:
1423
+ The output of the attention.
1424
+ """
1425
+     q, k, v = queries, keys, values
+     static_validate_inputs(
+         q,
+         k,
+         v,
+         kv_cache,
+         kv_lens,
+         page_indices,
+         cu_q_lens,
+         distribution,
+         sm_scale=sm_scale,
+         sliding_window=sliding_window,
+         soft_cap=soft_cap,
+         mask_value=mask_value,
+         q_scale=q_scale,
+         k_scale=k_scale,
+         v_scale=v_scale,
+         chunk_prefill_size=chunk_prefill_size,
+         num_kv_pages_per_block=num_kv_pages_per_block,
+         num_queries_per_block=num_queries_per_block,
+         vmem_limit_bytes=vmem_limit_bytes,
+     )
+
+     actual_num_q_heads = q.shape[1]
+     actual_head_dim = q.shape[2]
+     actual_num_kv_heads = k.shape[1]
+
+     actual_num_q_heads_per_kv_head = actual_num_q_heads // actual_num_kv_heads
+     q, kv = prepare_inputs(q, k, v)
+     (
+         _,
+         max_num_tokens,
+         num_q_heads_per_kv_head_per_q_packing,
+         q_packing,
+         head_dim,
+     ) = q.shape
+     page_size = kv_cache.shape[1]
+     max_num_seqs = kv_lens.shape[0]
+     num_page_indices = page_indices.shape[0]
+     assert num_page_indices % max_num_seqs == 0
+     pages_per_seq = num_page_indices // max_num_seqs
+     num_q_heads_per_kv_head = num_q_heads_per_kv_head_per_q_packing * q_packing
+
+     bkv_p = num_kv_pages_per_block
+     bq_sz = num_queries_per_block
+     if bq_sz is None or bkv_p is None:
+         bkv_p, bq_sz = get_tuned_block_sizes(
+             q.dtype,
+             kv_cache.dtype,
+             actual_num_q_heads,
+             actual_num_kv_heads,
+             head_dim,
+             page_size,
+             max_num_tokens,
+             pages_per_seq,
+             sliding_window,
+         )
+     bkv_sz = bkv_p * page_size
+     if vmem_limit_bytes is None:
+         # TODO (jevinjiang/jacobplatin): change this to use
+         # `get_vmem_estimate_bytes` when VREG spilling is fixed.
+         vmem_limit_bytes = DEFAULT_VMEM_LIMIT_BYTES
+     grid = (distribution[2], )
+
+     in_specs = [
+         pl.BlockSpec(memory_space=pltpu.HBM),
+         pl.BlockSpec(memory_space=pltpu.HBM),
+         pl.BlockSpec(memory_space=pltpu.HBM),
+     ]
+
+     out_specs = [
+         pl.BlockSpec(memory_space=pltpu.HBM),
+         pl.BlockSpec(memory_space=pltpu.HBM),
+     ]
+
+     bkv_double_buf = pltpu.VMEM(
+         (2, bkv_sz, *kv_cache.shape[2:]),
+         kv_cache.dtype,
+     )
+
+     bq_double_buf = pltpu.VMEM(
+         (2, actual_num_kv_heads, bq_sz, *q.shape[2:]),
+         q.dtype,
+     )
+
+     bo_double_buf = bq_double_buf
+
+     l_scratch = pltpu.VMEM(
+         (actual_num_kv_heads, bq_sz * num_q_heads_per_kv_head, 128),
+         jnp.float32,
+     )
+     m_scratch = l_scratch
+
+     acc_scratch = pltpu.VMEM(
+         (actual_num_kv_heads, bq_sz * num_q_heads_per_kv_head, head_dim),
+         jnp.float32,
+     )
+
+     scratch_shapes = [
+         bkv_double_buf,  # Double buffering for kv block.
+         bq_double_buf,  # Double buffering for q block.
+         bo_double_buf,  # Double buffering for output block.
+         # Semaphores for double buffering of bkv, bq, bo and bkv_update.
+         pltpu.SemaphoreType.DMA((4, 2)),
+         # Intermediate buffers per kv head for flash attention.
+         l_scratch,
+         m_scratch,
+         acc_scratch,
+     ]
+
+     scalar_prefetches = (
+         kv_lens,
+         # TODO(jevinjiang): can we use ragged page_indices to save some smem?
+         page_indices,
+         cu_q_lens,
+         distribution,
+         # (bq_sem_idx, bkv_sem_idx, bo_sem_idx)
+         jnp.zeros((3, ), jnp.int32),
+         # (bo_sem_0_seq_idx, bo_sem_1_seq_idx, bo_sem_0_bo_idx, bo_sem_1_bo_idx)
+         jnp.full((4, ), -1, jnp.int32),
+         # (bkv_sem_0_seq_idx, bkv_sem_1_seq_idx, bkv_sem_0_offset, bkv_sem_1_offset, bkv_sem_0_sz, bkv_sem_1_sz)
+         jnp.full((6, ), -1, jnp.int32),
+     )
+
+
1549
+ scope_name = get_kernel_scope_name(bq_sz, bkv_p, page_size)
1550
+ kernel = pl.pallas_call(
1551
+ functools.partial(
1552
+ _ragged_paged_attention_kernel,
1553
+ sm_scale=sm_scale,
1554
+ sliding_window=sliding_window,
1555
+ soft_cap=soft_cap,
1556
+ mask_value=mask_value,
1557
+ q_scale=q_scale,
1558
+ k_scale=k_scale,
1559
+ v_scale=v_scale,
1560
+ chunk_prefill_size=chunk_prefill_size,
1561
+ bq_sz=bq_sz,
1562
+ bkv_p=bkv_p,
1563
+ debug_mode=debug_mode,
1564
+ ),
1565
+ grid_spec=pltpu.PrefetchScalarGridSpec(
1566
+ num_scalar_prefetch=len(scalar_prefetches),
1567
+ in_specs=in_specs,
1568
+ out_specs=out_specs,
1569
+ grid=grid,
1570
+ scratch_shapes=scratch_shapes,
1571
+ ),
1572
+ compiler_params=pltpu.CompilerParams(
1573
+ # TODO(jevinjiang): since each sequence depends on the previous
1574
+ # one, we need some extra work to support Megacore mode.
1575
+ dimension_semantics=("arbitrary", ),
1576
+ vmem_limit_bytes=vmem_limit_bytes,
1577
+ ),
1578
+ out_shape=[
1579
+ jax.ShapeDtypeStruct(shape=q.shape, dtype=q.dtype),
1580
+ jax.ShapeDtypeStruct(shape=kv_cache.shape, dtype=kv_cache.dtype),
1581
+ ],
1582
+ input_output_aliases={
1583
+ 7: 0,
1584
+ 9: 1
1585
+ },
1586
+ name=scope_name,
1587
+ )
1588
+
1589
+ output, updated_kv_cache = kernel(*scalar_prefetches, q, kv, kv_cache)
1590
+ return (
1591
+ prepare_outputs(output, actual_num_q_heads_per_kv_head,
1592
+ actual_head_dim),
1593
+ updated_kv_cache,
1594
+ )
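Finally, a minimal sketch of calling the kernel for a single decode step. All sizes are hypothetical, the sm_scale value is just an example, and actually executing the pallas kernel requires a TPU backend; on other backends only the static validation above is exercised at trace time:

import jax.numpy as jnp

num_q_heads, num_kv_heads, head_dim = 8, 2, 128
max_num_tokens, max_num_seqs, pages_per_seq, page_size = 16, 4, 8, 16

queries = jnp.zeros((max_num_tokens, num_q_heads, head_dim), jnp.bfloat16)
keys = jnp.zeros((max_num_tokens, num_kv_heads, head_dim), jnp.bfloat16)
values = jnp.zeros_like(keys)
kv_cache = jnp.zeros((256, page_size, 2, 2, head_dim), jnp.bfloat16)

kv_lens = jnp.array([5, 0, 0, 0], jnp.int32)         # one active sequence
page_indices = jnp.zeros((max_num_seqs * pages_per_seq, ), jnp.int32)
cu_q_lens = jnp.array([0, 1, 1, 1, 1], jnp.int32)    # one new query token
distribution = jnp.array([1, 1, 1], jnp.int32)       # decode-only batch

output, kv_cache = ragged_paged_attention(
    queries, keys, values, kv_cache,
    kv_lens, page_indices, cu_q_lens, distribution,
    sm_scale=head_dim**-0.5)
# output: [max_num_tokens, num_q_heads, head_dim]; kv_cache comes back
# updated because it is donated to the kernel.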