tpu-inference 0.12.0.dev20251222__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (260)
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_core_tpu.py +513 -0
  4. tests/core/test_disagg_executor.py +60 -0
  5. tests/core/test_disagg_utils.py +67 -0
  6. tests/core/test_dp_scheduler.py +724 -0
  7. tests/core/test_init.py +63 -0
  8. tests/distributed/__init__.py +13 -0
  9. tests/distributed/test_distributed_utils.py +120 -0
  10. tests/distributed/test_tpu_connector.py +478 -0
  11. tests/e2e/__init__.py +13 -0
  12. tests/e2e/test_async_scheduler.py +211 -0
  13. tests/e2e/test_data_parallel.py +393 -0
  14. tests/e2e/test_local_disagg.py +257 -0
  15. tests/e2e/test_model_loader.py +268 -0
  16. tests/e2e/test_multi_modal_inference.py +111 -0
  17. tests/e2e/test_pipeline_parallel.py +265 -0
  18. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  19. tests/e2e/test_sampling_params.py +269 -0
  20. tests/e2e/test_speculative_decoding.py +291 -0
  21. tests/e2e/test_structured_decoding.py +46 -0
  22. tests/executors/__init__.py +13 -0
  23. tests/executors/test_ray_distributed_executor.py +199 -0
  24. tests/experimental/__init__.py +13 -0
  25. tests/experimental/test_llama3_jax_stashed.py +208 -0
  26. tests/kernels/__init__.py +13 -0
  27. tests/kernels/collectives/__init__.py +13 -0
  28. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  29. tests/kernels/fused_moe_v1_test.py +388 -0
  30. tests/kernels/gmm_test.py +205 -0
  31. tests/kernels/mla_v1_test.py +498 -0
  32. tests/kernels/quantized_matmul_kernel_test.py +159 -0
  33. tests/kernels/ragged_kv_cache_update_v2_test.py +248 -0
  34. tests/kernels/ragged_paged_attention_kernel_v2_test.py +414 -0
  35. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +565 -0
  36. tests/kernels/ragged_paged_attention_kernel_v3_test.py +520 -0
  37. tests/layers/__init__.py +13 -0
  38. tests/layers/common/__init__.py +13 -0
  39. tests/layers/common/test_attention_interface.py +156 -0
  40. tests/layers/common/test_quantization.py +149 -0
  41. tests/layers/jax/__init__.py +13 -0
  42. tests/layers/jax/attention/__init__.py +13 -0
  43. tests/layers/jax/attention/test_common_attention.py +103 -0
  44. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  45. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  46. tests/layers/jax/moe/__init__.py +13 -0
  47. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  48. tests/layers/jax/sample/__init__.py +13 -0
  49. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  50. tests/layers/jax/sample/test_sampling.py +115 -0
  51. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  52. tests/layers/jax/test_layers.py +155 -0
  53. tests/layers/jax/test_qwix.py +969 -0
  54. tests/layers/jax/test_rope.py +93 -0
  55. tests/layers/jax/test_sharding.py +159 -0
  56. tests/layers/jax/test_transformer_block.py +152 -0
  57. tests/layers/vllm/__init__.py +13 -0
  58. tests/layers/vllm/test_attention.py +363 -0
  59. tests/layers/vllm/test_awq.py +405 -0
  60. tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
  61. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +403 -0
  62. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +426 -0
  63. tests/layers/vllm/test_fp8.py +17 -0
  64. tests/layers/vllm/test_mxfp4.py +297 -0
  65. tests/layers/vllm/test_unquantized.py +621 -0
  66. tests/layers/vllm/utils.py +72 -0
  67. tests/lora/__init__.py +13 -0
  68. tests/lora/conftest.py +46 -0
  69. tests/lora/test_bgmv.py +57 -0
  70. tests/lora/test_layers.py +666 -0
  71. tests/lora/test_lora.py +147 -0
  72. tests/lora/test_lora_perf.py +67 -0
  73. tests/lora/utils.py +88 -0
  74. tests/models/__init__.py +13 -0
  75. tests/models/common/__init__.py +13 -0
  76. tests/models/common/test_model_loader.py +455 -0
  77. tests/models/jax/__init__.py +13 -0
  78. tests/models/jax/test_deepseek_v3.py +401 -0
  79. tests/models/jax/test_llama3.py +184 -0
  80. tests/models/jax/test_llama4.py +298 -0
  81. tests/models/jax/test_llama_eagle3.py +197 -0
  82. tests/models/jax/test_llama_guard_4.py +242 -0
  83. tests/models/jax/test_qwen2.py +172 -0
  84. tests/models/jax/test_qwen2_5_vl.py +606 -0
  85. tests/models/jax/test_qwen3.py +169 -0
  86. tests/models/jax/test_weight_loading.py +180 -0
  87. tests/models/jax/utils/__init__.py +13 -0
  88. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  89. tests/platforms/__init__.py +13 -0
  90. tests/platforms/test_tpu_platform.py +54 -0
  91. tests/runner/__init__.py +13 -0
  92. tests/runner/test_block_table.py +395 -0
  93. tests/runner/test_input_batch.py +226 -0
  94. tests/runner/test_kv_cache.py +220 -0
  95. tests/runner/test_kv_cache_manager.py +498 -0
  96. tests/runner/test_multimodal_manager.py +429 -0
  97. tests/runner/test_persistent_batch_manager.py +84 -0
  98. tests/runner/test_speculative_decoding_manager.py +368 -0
  99. tests/runner/test_structured_decoding_manager.py +220 -0
  100. tests/runner/test_tpu_runner.py +202 -0
  101. tests/runner/test_tpu_runner_dp.py +1033 -0
  102. tests/runner/test_tpu_runner_mesh.py +200 -0
  103. tests/runner/test_utils.py +411 -0
  104. tests/spec_decode/__init__.py +13 -0
  105. tests/spec_decode/test_eagle3.py +311 -0
  106. tests/test_base.py +215 -0
  107. tests/test_envs.py +280 -0
  108. tests/test_tpu_info.py +134 -0
  109. tests/test_utils.py +193 -0
  110. tests/worker/__init__.py +13 -0
  111. tests/worker/tpu_worker_test.py +414 -0
  112. tpu_inference/__init__.py +67 -0
  113. tpu_inference/core/__init__.py +13 -0
  114. tpu_inference/core/core_tpu.py +786 -0
  115. tpu_inference/core/disagg_executor.py +118 -0
  116. tpu_inference/core/disagg_utils.py +49 -0
  117. tpu_inference/core/sched/__init__.py +13 -0
  118. tpu_inference/core/sched/dp_scheduler.py +814 -0
  119. tpu_inference/distributed/__init__.py +13 -0
  120. tpu_inference/distributed/jax_parallel_state.py +81 -0
  121. tpu_inference/distributed/tpu_connector.py +732 -0
  122. tpu_inference/distributed/utils.py +112 -0
  123. tpu_inference/env_override.py +9 -0
  124. tpu_inference/envs.py +191 -0
  125. tpu_inference/executors/__init__.py +13 -0
  126. tpu_inference/executors/ray_distributed_executor.py +399 -0
  127. tpu_inference/experimental/__init__.py +13 -0
  128. tpu_inference/experimental/llama3_jax_stashed.py +272 -0
  129. tpu_inference/kernels/__init__.py +13 -0
  130. tpu_inference/kernels/collectives/__init__.py +13 -0
  131. tpu_inference/kernels/collectives/all_gather_matmul.py +741 -0
  132. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +65 -0
  133. tpu_inference/kernels/collectives/util.py +47 -0
  134. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  135. tpu_inference/kernels/flash_attention/kernel.py +772 -0
  136. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  137. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  138. tpu_inference/kernels/fused_moe/v1/kernel.py +1612 -0
  139. tpu_inference/kernels/megablox/__init__.py +13 -0
  140. tpu_inference/kernels/megablox/common.py +54 -0
  141. tpu_inference/kernels/megablox/gmm.py +646 -0
  142. tpu_inference/kernels/mla/__init__.py +13 -0
  143. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  144. tpu_inference/kernels/mla/v1/kernel.py +1340 -0
  145. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  146. tpu_inference/kernels/quantized_matmul/kernel.py +456 -0
  147. tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
  148. tpu_inference/kernels/quantized_matmul/util.py +58 -0
  149. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  150. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  151. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +876 -0
  152. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +288 -0
  153. tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
  154. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  155. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1594 -0
  156. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1586 -0
  157. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4460 -0
  158. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +548 -0
  159. tpu_inference/kernels/ragged_paged_attention/v3/util.py +65 -0
  160. tpu_inference/layers/__init__.py +13 -0
  161. tpu_inference/layers/common/__init__.py +13 -0
  162. tpu_inference/layers/common/attention_interface.py +403 -0
  163. tpu_inference/layers/common/attention_metadata.py +48 -0
  164. tpu_inference/layers/common/binary_search.py +295 -0
  165. tpu_inference/layers/common/quant_methods.py +23 -0
  166. tpu_inference/layers/common/quantization.py +270 -0
  167. tpu_inference/layers/common/sharding.py +600 -0
  168. tpu_inference/layers/jax/__init__.py +13 -0
  169. tpu_inference/layers/jax/attention/__init__.py +13 -0
  170. tpu_inference/layers/jax/attention/attention.py +268 -0
  171. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +547 -0
  172. tpu_inference/layers/jax/attention/gpt_oss_attention.py +275 -0
  173. tpu_inference/layers/jax/attention/llama4_attention.py +167 -0
  174. tpu_inference/layers/jax/base.py +165 -0
  175. tpu_inference/layers/jax/constants.py +101 -0
  176. tpu_inference/layers/jax/layers.py +315 -0
  177. tpu_inference/layers/jax/misc.py +30 -0
  178. tpu_inference/layers/jax/moe/__init__.py +13 -0
  179. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +615 -0
  180. tpu_inference/layers/jax/moe/gpt_oss_moe.py +199 -0
  181. tpu_inference/layers/jax/moe/moe.py +249 -0
  182. tpu_inference/layers/jax/pp_utils.py +53 -0
  183. tpu_inference/layers/jax/rope.py +294 -0
  184. tpu_inference/layers/jax/rope_interface.py +228 -0
  185. tpu_inference/layers/jax/sample/__init__.py +13 -0
  186. tpu_inference/layers/jax/sample/rejection_sampler.py +528 -0
  187. tpu_inference/layers/jax/sample/sampling.py +110 -0
  188. tpu_inference/layers/jax/sample/sampling_metadata.py +90 -0
  189. tpu_inference/layers/jax/transformer_block.py +121 -0
  190. tpu_inference/layers/vllm/__init__.py +13 -0
  191. tpu_inference/layers/vllm/attention.py +221 -0
  192. tpu_inference/layers/vllm/fused_moe.py +502 -0
  193. tpu_inference/layers/vllm/linear_common.py +221 -0
  194. tpu_inference/layers/vllm/quantization/__init__.py +55 -0
  195. tpu_inference/layers/vllm/quantization/awq.py +221 -0
  196. tpu_inference/layers/vllm/quantization/common.py +124 -0
  197. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  198. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +135 -0
  199. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +266 -0
  200. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  201. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +222 -0
  202. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +150 -0
  203. tpu_inference/layers/vllm/quantization/fp8.py +118 -0
  204. tpu_inference/layers/vllm/quantization/mxfp4.py +396 -0
  205. tpu_inference/layers/vllm/quantization/unquantized.py +416 -0
  206. tpu_inference/layers/vllm/sharding.py +244 -0
  207. tpu_inference/logger.py +10 -0
  208. tpu_inference/lora/__init__.py +13 -0
  209. tpu_inference/lora/torch_lora_ops.py +98 -0
  210. tpu_inference/lora/torch_punica_tpu.py +310 -0
  211. tpu_inference/models/__init__.py +13 -0
  212. tpu_inference/models/common/__init__.py +13 -0
  213. tpu_inference/models/common/model_loader.py +520 -0
  214. tpu_inference/models/jax/__init__.py +13 -0
  215. tpu_inference/models/jax/deepseek_v3.py +978 -0
  216. tpu_inference/models/jax/gpt_oss.py +508 -0
  217. tpu_inference/models/jax/jax_intermediate_tensor.py +93 -0
  218. tpu_inference/models/jax/llama3.py +436 -0
  219. tpu_inference/models/jax/llama4.py +643 -0
  220. tpu_inference/models/jax/llama_eagle3.py +350 -0
  221. tpu_inference/models/jax/llama_guard_4.py +375 -0
  222. tpu_inference/models/jax/qwen2.py +390 -0
  223. tpu_inference/models/jax/qwen2_5_vl.py +1232 -0
  224. tpu_inference/models/jax/qwen3.py +318 -0
  225. tpu_inference/models/jax/utils/__init__.py +13 -0
  226. tpu_inference/models/jax/utils/file_utils.py +110 -0
  227. tpu_inference/models/jax/utils/multi_modal_utils.py +177 -0
  228. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  229. tpu_inference/models/jax/utils/qwix/qwix_utils.py +713 -0
  230. tpu_inference/models/jax/utils/weight_utils.py +621 -0
  231. tpu_inference/models/vllm/__init__.py +13 -0
  232. tpu_inference/models/vllm/vllm_model_wrapper.py +307 -0
  233. tpu_inference/models/vllm/vllm_model_wrapper_context.py +59 -0
  234. tpu_inference/platforms/__init__.py +16 -0
  235. tpu_inference/platforms/tpu_platform.py +258 -0
  236. tpu_inference/runner/__init__.py +13 -0
  237. tpu_inference/runner/block_table.py +122 -0
  238. tpu_inference/runner/compilation_manager.py +890 -0
  239. tpu_inference/runner/input_batch.py +435 -0
  240. tpu_inference/runner/kv_cache.py +166 -0
  241. tpu_inference/runner/kv_cache_manager.py +508 -0
  242. tpu_inference/runner/lora_utils.py +106 -0
  243. tpu_inference/runner/multimodal_manager.py +231 -0
  244. tpu_inference/runner/persistent_batch_manager.py +296 -0
  245. tpu_inference/runner/speculative_decoding_manager.py +262 -0
  246. tpu_inference/runner/structured_decoding_manager.py +101 -0
  247. tpu_inference/runner/tpu_runner.py +1768 -0
  248. tpu_inference/runner/utils.py +426 -0
  249. tpu_inference/spec_decode/__init__.py +13 -0
  250. tpu_inference/spec_decode/jax/__init__.py +13 -0
  251. tpu_inference/spec_decode/jax/eagle3.py +430 -0
  252. tpu_inference/tpu_info.py +92 -0
  253. tpu_inference/utils.py +345 -0
  254. tpu_inference/worker/__init__.py +13 -0
  255. tpu_inference/worker/tpu_worker.py +468 -0
  256. tpu_inference-0.12.0.dev20251222.dist-info/METADATA +106 -0
  257. tpu_inference-0.12.0.dev20251222.dist-info/RECORD +260 -0
  258. tpu_inference-0.12.0.dev20251222.dist-info/WHEEL +5 -0
  259. tpu_inference-0.12.0.dev20251222.dist-info/licenses/LICENSE +201 -0
  260. tpu_inference-0.12.0.dev20251222.dist-info/top_level.txt +2 -0
@@ -0,0 +1,876 @@
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TPU-Friendly Ragged Paged Attention kernel.

This kernel offers a highly optimized implementation of ragged paged attention,
specifically designed for TPU and compatible with a wide range of model
specifications. It supports mixed prefill and decoding, enhancing throughput
during inference.
"""
import functools

import jax
import jax.numpy as jnp
from jax import lax
from jax._src import dtypes
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu
from jax.experimental.pallas.ops.tpu.ragged_paged_attention.tuned_block_sizes import \
    get_tuned_block_sizes

DEFAULT_MASK_VALUE = -0.7 * float(jnp.finfo(jnp.dtype("float32")).max)


class MultiPageAsyncCopyDescriptor:
    """Descriptor for async copy of multiple K/V pages from HBM."""

    def __init__(
        self,
        pages_hbm_ref,  # [total_num_pages, page_size, num_combined_kv_heads_per_blk, head_dim]
        vmem_buf,  # [num_kv_pages_per_blk, page_size, num_combined_kv_heads_per_blk, head_dim]
        sem,
        page_indices_ref,  # i32[max_num_seqs, pages_per_seq]
        metadata,  # [seq_idx, start_page_idx, end_page_idx]
    ):
        self._vmem_buf = vmem_buf
        seq_id, start_page_idx, end_page_idx = metadata
        self._async_copies = []
        # TODO(jevinjiang): Only fetch dynamic shape in need! This will insert
        # a bunch of if-ops. Check the performance when we have benchmarking setup.
        for i in range(vmem_buf.shape[0]):
            page_idx = start_page_idx + i
            page_idx = jax.lax.select(page_idx < end_page_idx, page_idx, 0)
            self._async_copies.append(
                pltpu.make_async_copy(
                    pages_hbm_ref.at[page_indices_ref[seq_id, page_idx]],
                    vmem_buf.at[i],
                    sem,
                ))

    def start(self):
        """Starts the async copies."""
        for async_copy in self._async_copies:
            async_copy.start()

    def wait(self):
        for async_copy in self._async_copies:
            async_copy.wait()
        return self._vmem_buf


def ref_ragged_paged_attention(
    queries: jax.Array,  # [max_num_batched_tokens, num_q_heads, head_dim]
    kv_pages: jax.Array,  # [total_num_pages, page_size, num_combined_kv_heads, head_dim]
    kv_lens: jax.Array,  # i32[max_num_seqs]
    page_indices: jax.Array,  # i32[max_num_seqs, pages_per_seq]
    cu_q_lens: jax.Array,  # i32[max_num_seqs + 1]
    num_seqs: jax.Array,  # i32[1]
    *,
    sm_scale: float = 1.0,
    sliding_window: int | None = None,
    soft_cap: float | None = None,
    mask_value: float | None = DEFAULT_MASK_VALUE,
    k_scale: float | None = None,
    v_scale: float | None = None,
):
    static_validate_inputs(
        queries,
        kv_pages,
        kv_lens,
        page_indices,
        cu_q_lens,
        num_seqs,
        sm_scale=sm_scale,
        k_scale=k_scale,
        v_scale=v_scale,
        sliding_window=sliding_window,
        soft_cap=soft_cap,
        mask_value=mask_value,
    )
    if mask_value is None:
        mask_value = DEFAULT_MASK_VALUE
    _, _, num_combined_kv_heads, head_dim = kv_pages.shape
    assert num_combined_kv_heads % 2 == 0
    num_kv_heads = num_combined_kv_heads // 2
    num_q_heads = queries.shape[1]
    assert num_q_heads % num_kv_heads == 0
    num_query_per_kv = num_q_heads // num_kv_heads
    outputs = []
    for i in range(num_seqs[0]):
        q_start = cu_q_lens[i]
        q_end = cu_q_lens[i + 1]
        q_len = q_end - q_start
        kv_len = kv_lens[i]
        indices = page_indices[i]
        q = queries[q_start:q_end]
        k = kv_pages[indices, :, 0::2, :].reshape(-1, num_kv_heads,
                                                  head_dim)[:kv_len]
        v = kv_pages[indices, :, 1::2, :].reshape(-1, num_kv_heads,
                                                  head_dim)[:kv_len]
        if k_scale is not None:
            k = k.astype(jnp.float32) * k_scale
            k = k.astype(q.dtype)
        if v_scale is not None:
            v = v.astype(jnp.float32) * v_scale
            v = v.astype(q.dtype)
        k = jnp.repeat(k, num_query_per_kv, axis=1)
        v = jnp.repeat(v, num_query_per_kv, axis=1)
        attn = jnp.einsum("qhd,khd->hqk", q, k,
                          preferred_element_type=jnp.float32)
        attn *= sm_scale
        q_span = (kv_len - q_len) + jax.lax.broadcasted_iota(
            jnp.int32, attn.shape, 1)
        kv_span = jax.lax.broadcasted_iota(jnp.int32, attn.shape, 2)
        mask = q_span < kv_span
        if sliding_window is not None:
            mask = jnp.logical_or(mask, q_span - sliding_window >= kv_span)
        if soft_cap is not None:
            attn = soft_cap * jnp.tanh(attn / soft_cap)
        attn += jnp.where(mask, mask_value, 0.0)
        attn = jax.nn.softmax(attn, axis=-1).astype(v.dtype)
        out = jnp.einsum("hqk,khd->qhd", attn, v).astype(queries.dtype)
        outputs.append(out)

    return jnp.concatenate(outputs, axis=0)


# Expect to run these checks during runtime.
def dynamic_validate_inputs(
    q: jax.Array,  # [max_num_batched_tokens, num_q_heads, head_dim]
    kv_pages: jax.Array,  # [total_num_pages, page_size, num_combined_kv_heads, head_dim]
    kv_lens: jax.Array,  # i32[max_num_seqs]
    page_indices: jax.Array,  # i32[max_num_seqs, pages_per_seq]
    cu_q_lens: jax.Array,  # i32[max_num_seqs + 1]
    num_seqs: jax.Array,  # i32[1]
    *,
    # These inputs are optional. If not specified, we will not validate them.
    sm_scale: float | None = None,
    sliding_window: int | None = None,
    soft_cap: float | None = None,
    mask_value: float | None = None,
    k_scale: float | None = None,
    v_scale: float | None = None,
    # Kernel tuning params.
    num_kv_pages_per_block: int | None = None,
    num_queries_per_block: int | None = None,
    vmem_limit_bytes: int | None = None,
):
    static_validate_inputs(
        q,
        kv_pages,
        kv_lens,
        page_indices,
        cu_q_lens,
        num_seqs,
        sm_scale=sm_scale,
        sliding_window=sliding_window,
        soft_cap=soft_cap,
        mask_value=mask_value,
        k_scale=k_scale,
        v_scale=v_scale,
        num_kv_pages_per_block=num_kv_pages_per_block,
        num_queries_per_block=num_queries_per_block,
        vmem_limit_bytes=vmem_limit_bytes,
    )
    max_num_batched_tokens = q.shape[0]
    page_size = kv_pages.shape[1]
    max_num_seqs, pages_per_seq = page_indices.shape
    if num_seqs[0] > max_num_seqs:
        raise ValueError(
            f"{num_seqs[0]=} must be less or equal to {max_num_seqs=}")
    max_kv_len = jnp.max(kv_lens)
    min_pages_per_seq = cdiv(max_kv_len, page_size)
    if pages_per_seq < min_pages_per_seq:
        raise ValueError(
            f"{pages_per_seq=} must be greater or equal to"
            f" {min_pages_per_seq=} given {max_kv_len=} and {page_size=}.")
    if cu_q_lens[num_seqs[0]] > max_num_batched_tokens:
        raise ValueError(
            f"Total q tokens {cu_q_lens[num_seqs[0]]} must be less or equal to"
            f" {max_num_batched_tokens=}.")
    for i in range(num_seqs[0]):
        q_len = cu_q_lens[i + 1] - cu_q_lens[i]
        kv_len = kv_lens[i]
        if q_len > kv_len:
            raise ValueError(
                f"{q_len=} must be less or equal to {kv_len=} at sequence {i}."
            )


# Expect to run these checks during compile time.
def static_validate_inputs(
    q: jax.Array,  # [max_num_batched_tokens, num_q_heads, head_dim]
    kv_pages: jax.Array,  # [total_num_pages, page_size, num_combined_kv_heads, head_dim]
    kv_lens: jax.Array,  # i32[max_num_seqs]
    page_indices: jax.Array,  # i32[max_num_seqs, pages_per_seq]
    cu_q_lens: jax.Array,  # i32[max_num_seqs + 1]
    num_seqs: jax.Array,  # i32[1]
    *,
    # These inputs are optional. If not specified, we will not validate them.
    sm_scale: float | None = None,
    sliding_window: int | None = None,
    soft_cap: float | None = None,
    mask_value: float | None = None,
    k_scale: float | None = None,
    v_scale: float | None = None,
    # Kernel tuning params.
    num_kv_pages_per_block: int | None = None,
    num_queries_per_block: int | None = None,
    vmem_limit_bytes: int | None = None,
):
    _, num_q_heads, head_dim = q.shape
    _, _, num_combined_kv_heads, head_dim_k = kv_pages.shape
    assert num_combined_kv_heads % 2 == 0
    assert isinstance(k_scale, float) or k_scale is None
    assert isinstance(v_scale, float) or v_scale is None
    num_kv_heads = num_combined_kv_heads // 2
    max_num_seqs, pages_per_seq = page_indices.shape
    if num_seqs.shape != (1, ):
        raise ValueError(f"{num_seqs.shape=} must be (1,)")
    if head_dim_k != head_dim:
        raise ValueError(
            f"Q head_dim {head_dim} must be the same as that of K/V {head_dim_k}."
        )
    if kv_lens.shape != (max_num_seqs, ):
        raise ValueError(
            f"Expected {kv_lens.shape=} to be ({max_num_seqs},) where"
            " `max_num_seqs` is `page_indices.shape[0]`.")
    if cu_q_lens.shape != (max_num_seqs + 1, ):
        raise ValueError(
            f"Expected {cu_q_lens.shape=} to be ({max_num_seqs + 1},) where"
            " `max_num_seqs` is `page_indices.shape[0]`.")
    if (kv_lens.dtype != jnp.int32 or page_indices.dtype != jnp.int32
            or cu_q_lens.dtype != jnp.int32):
        raise ValueError(
            "The dtype of `kv_lens`, `page_indices`, and `cu_q_lens` must be"
            f" int32. Got {kv_lens.dtype=}, {page_indices.dtype=},"
            f" {cu_q_lens.dtype=}.")
    if num_q_heads % num_kv_heads != 0:
        raise ValueError(
            f"{num_q_heads=} must be divisible by {num_kv_heads=}")
    if sliding_window is not None and sliding_window <= 0:
        raise ValueError(f"{sliding_window=} must be positive.")
    if soft_cap is not None and soft_cap == 0.0:
        raise ValueError(f"{soft_cap=} must not be 0.0.")
    if (num_kv_pages_per_block is not None
            and not 0 < num_kv_pages_per_block <= pages_per_seq):
        raise ValueError(
            f"{num_kv_pages_per_block=} must be in range (0, {pages_per_seq}]."
        )
    if num_queries_per_block is not None and num_queries_per_block <= 0:
        raise ValueError(f"{num_queries_per_block=} must be positive.")
    if vmem_limit_bytes is not None and vmem_limit_bytes <= 0:
        raise ValueError(f"{vmem_limit_bytes=} must be positive.")
    del sm_scale  # No constraints on sm_scale.
    del mask_value  # No constraints on mask_value.


def ragged_paged_attention_kernel(
    # Prefetch
    kv_lens_ref,  # [max_num_seqs]
    page_indices_ref,  # [max_num_seqs, pages_per_seq]
    cu_q_lens_ref,  # [max_num_seqs + 1]
    seq_buf_idx_ref,
    # TODO(jevinjiang): if OOM in SMEM, consider pack to other scalar refs.
    num_seqs_ref,
    # Input
    q_ref,  # [num_q_per_blk, num_q_heads_per_blk, head_dim]
    kv_pages_hbm_ref,  # [total_num_pages, page_size, num_combined_kv_heads, head_dim]
    # Output
    o_ref,  # [num_q_per_blk, num_q_heads_per_blk, head_dim]
    # Scratch
    kv_bufs,  # [2, num_kv_pages_per_blk, page_size, num_combined_kv_heads_per_blk, head_dim]
    sems,  # [2, 2]
    l_ref,  # [num_kv_heads_per_blk, num_q_per_blk * num_q_heads_per_kv_head, 128]
    m_ref,  # [num_kv_heads_per_blk, num_q_per_blk * num_q_heads_per_kv_head, 128]
    acc_ref,  # [num_q_per_blk, num_q_heads_per_blk, head_dim]
    *,
    sm_scale: float,
    sliding_window: int | None = None,
    soft_cap: float | None = None,
    mask_value: float | None = DEFAULT_MASK_VALUE,
    k_scale: float | None = None,
    v_scale: float | None = None,
):
    if mask_value is None:
        mask_value = DEFAULT_MASK_VALUE
    num_q_per_blk, num_q_heads_per_blk, head_dim = q_ref.shape
    pages_per_seq = page_indices_ref.shape[-1]
    num_seqs = num_seqs_ref[0]
    _, num_kv_pages_per_blk, page_size, num_combined_kv_heads_per_blk, _ = (
        kv_bufs.shape)
    num_kv_heads_per_blk = num_combined_kv_heads_per_blk // 2
    num_kv_per_blk = num_kv_pages_per_blk * page_size
    num_q_heads_per_kv_head = num_q_heads_per_blk // num_kv_heads_per_blk
    heads_blk_idx, q_blk_idx = (
        pl.program_id(0),
        pl.program_id(1),
    )
    num_heads_blks = pl.num_programs(0)
    init_seq_idx = seq_buf_idx_ref[0]
    init_buf_idx = seq_buf_idx_ref[1]
    q_len_start = q_blk_idx * num_q_per_blk
    q_len_end = q_len_start + num_q_per_blk

    def create_kv_async_copy_descriptors(heads_blk_idx, seq_idx, kv_blk_idx,
                                         buf_idx):
        start_kv_page_idx = kv_blk_idx * num_kv_pages_per_blk
        end_kv_page_idx = jnp.minimum(pages_per_seq,
                                      cdiv(kv_lens_ref[seq_idx], page_size))
        metadata = (seq_idx, start_kv_page_idx, end_kv_page_idx)
        heads_start = heads_blk_idx * num_combined_kv_heads_per_blk
        async_copy_kv = MultiPageAsyncCopyDescriptor(
            kv_pages_hbm_ref.at[:, :,
                                pl.ds(heads_start,
                                      num_combined_kv_heads_per_blk), :],
            kv_bufs.at[buf_idx],
            sems.at[buf_idx],
            page_indices_ref,
            metadata,
        )
        return async_copy_kv

    # TODO(jevinjiang): Add these to Mosaic:
    # 1. Support arbitrary strided load/store for int4 and int8 dtype.
    # 2. Support arbitrary strided load/store for any last dimension.
    def strided_load_kv(ref, start, step):
        packing = get_dtype_packing(ref.dtype)
        if packing == 1:
            return [ref[start::step, :]], [ref[start + 1::step, :]]
        assert packing in (2, 4, 8)
        assert step % packing == 0
        k_list, v_list = [], []
        b_start = start // packing
        b_step = step // packing
        b_ref = ref.bitcast(jnp.uint32)
        b = b_ref[b_start::b_step, :]

        # TODO(chengjiyao): use the general strided loading logic for bf16 after
        # fixing the issue in mosaic's infer vector layout pass
        if ref.dtype == jnp.bfloat16:
            bk = b << 16
            bv = b & jnp.uint32(0xFFFF0000)
            k = pltpu.bitcast(bk, jnp.float32).astype(jnp.bfloat16)
            v = pltpu.bitcast(bv, jnp.float32).astype(jnp.bfloat16)
            k_list.append(k)
            v_list.append(v)
        else:
            bitwidth = 32 // packing
            bitcast_dst_dtype = jnp.dtype(f"uint{bitwidth}")
            for i in range(0, packing, 2):
                bk = b >> (i * bitwidth)
                k = pltpu.bitcast(bk.astype(bitcast_dst_dtype), ref.dtype)
                k_list.append(k)
                bv = b >> ((i + 1) * bitwidth)
                v = pltpu.bitcast(bv.astype(bitcast_dst_dtype), ref.dtype)
                v_list.append(v)

        return k_list, v_list

    def fold_on_2nd_minor(vec):
        assert vec.dtype == jnp.bfloat16 or vec.dtype == jnp.float32
        assert len(vec.shape) >= 2
        last_dim = vec.shape[-1]
        packing = get_dtype_packing(vec.dtype)
        if vec.shape[-2] % packing != 0:
            vec = vec.astype(jnp.float32)
        return vec.reshape(-1, last_dim)

    @pl.when(heads_blk_idx + q_blk_idx == 0)
    def prefetch_first_kv_blk():
        async_copy_kv = create_kv_async_copy_descriptors(
            heads_blk_idx, init_seq_idx, 0, init_buf_idx)
        async_copy_kv.start()

    def is_cur_q_blk_needed(q_states):
        done, cur_seq_idx, _ = q_states
        should_run = jnp.logical_and(q_len_start < cu_q_lens_ref[num_seqs],
                                     cur_seq_idx < num_seqs)
        return jnp.logical_and(done == 0, should_run)

    def compute_with_cur_q_blk(q_states):
        done, cur_seq_idx, cur_buf_idx = q_states
        q_start = cu_q_lens_ref[cur_seq_idx]
        q_end = cu_q_lens_ref[cur_seq_idx + 1]
        q_len = q_end - q_start
        kv_len = kv_lens_ref[cur_seq_idx]

        def get_next_prefetch_ids(heads_blk_idx, cur_seq_idx, kv_blk_idx,
                                  cur_buf_idx):
            next_kv_blk_idx = kv_blk_idx + 1
            is_last_kv_blk = next_kv_blk_idx * num_kv_per_blk >= kv_len
            next_kv_blk_idx = lax.select(
                is_last_kv_blk,
                0,
                next_kv_blk_idx,
            )
            is_cur_seq_end_in_cur_q_blk = q_end <= q_len_end
            next_seq_idx = lax.select(
                is_last_kv_blk,
                lax.select(is_cur_seq_end_in_cur_q_blk, cur_seq_idx + 1,
                           cur_seq_idx),
                cur_seq_idx,
            )
            is_last_seq = next_seq_idx == num_seqs
            next_seq_idx = lax.select(
                is_last_seq,
                0,
                next_seq_idx,
            )
            next_heads_blk_idx = lax.select(
                is_last_seq,
                heads_blk_idx + 1,
                heads_blk_idx,
            )
            next_buf_idx = lax.select(cur_buf_idx == 0, 1, 0)
            return next_heads_blk_idx, next_seq_idx, next_kv_blk_idx, next_buf_idx

        def flash_attention(
            q,  # [num_q_per_blk * num_q_heads_per_kv_head, head_dim]
            k,  # [num_kv_per_blk, head_dim]
            v,  # [num_kv_per_blk, head_dim]
            head_l_ref,  # [num_q_per_blk * num_q_heads_per_kv_head, 128]
            head_m_ref,  # [num_q_per_blk * num_q_heads_per_kv_head, 128]
            head_acc_ref,  # [num_q_per_blk, num_q_heads_per_kv_head, head_dim]
            *,
            kv_blk_idx,
        ):
            assert q.shape == (
                num_q_per_blk * num_q_heads_per_kv_head,
                head_dim,
            )
            assert (k.shape == v.shape == (
                num_kv_per_blk,
                head_dim,
            ))
            assert k.dtype == v.dtype
            assert (head_m_ref.shape == head_l_ref.shape == (
                num_q_per_blk * num_q_heads_per_kv_head,
                128,
            ))
            assert head_acc_ref.shape == (
                num_q_per_blk,
                num_q_heads_per_kv_head,
                head_dim,
            )
            kv_len_start = kv_blk_idx * num_kv_per_blk

            def masked_store(ref, val, start, end, group=1):
                iota = lax.broadcasted_iota(jnp.int32, ref.shape, 0) // group
                mask = jnp.logical_and(iota >= start, iota < end)
                pl.store(ref,
                         idx=tuple(slice(None) for _ in ref.shape),
                         val=val,
                         mask=mask)

            def load_with_init(ref, init_val):
                return jnp.where(kv_blk_idx == 0, jnp.full_like(ref, init_val),
                                 ref[...])

            # kv lens will be contracting dim, we should mask out the NaNs.
            kv_mask = (lax.broadcasted_iota(jnp.int32, k.shape, 0)
                       < kv_len - kv_len_start)
            k = jnp.where(kv_mask, k.astype(jnp.float32), 0).astype(k.dtype)
            v = jnp.where(kv_mask, v.astype(jnp.float32), 0).astype(v.dtype)

            qk = (jnp.einsum("nd,md->nm", q, k,
                             preferred_element_type=jnp.float32) * sm_scale)
            store_start = jnp.maximum(q_start - q_len_start, 0)
            store_end = jnp.minimum(q_end - q_len_start, num_q_per_blk)

            row_ids = (
                (kv_len - q_len) + q_len_start - q_start +
                jax.lax.broadcasted_iota(
                    jnp.int32,
                    (num_q_per_blk * num_q_heads_per_kv_head, num_kv_per_blk),
                    0,
                ) // num_q_heads_per_kv_head)
            col_ids = kv_len_start + jax.lax.broadcasted_iota(
                jnp.int32,
                (num_q_per_blk * num_q_heads_per_kv_head, num_kv_per_blk),
                1,
            )
            causal_mask = row_ids < col_ids
            if sliding_window is not None:
                causal_mask = jnp.logical_or(
                    causal_mask, row_ids - sliding_window >= col_ids)
            if soft_cap is not None:
                qk = soft_cap * jnp.tanh(qk / soft_cap)
            qk += jnp.where(causal_mask, mask_value, 0.0)
            m_curr = jnp.max(qk, axis=1, keepdims=True)
            s_curr = jnp.exp(qk - m_curr)
            qkv = jnp.dot(s_curr, v, preferred_element_type=jnp.float32)
            lm_store_shape = head_m_ref.shape
            m_curr = jnp.broadcast_to(m_curr, lm_store_shape)
            l_curr = jnp.broadcast_to(s_curr.sum(axis=1, keepdims=True),
                                      lm_store_shape)
            m_prev = load_with_init(head_m_ref, -jnp.inf)
            l_prev = load_with_init(head_l_ref, 0.0)
            m_next = jnp.maximum(m_prev, m_curr)
            masked_store(head_m_ref, m_next, store_start, store_end,
                         num_q_heads_per_kv_head)
            alpha = jnp.exp(m_prev - m_next)
            beta = jnp.exp(m_curr - m_next)
            l_alpha = alpha * l_prev
            l_next = l_alpha + beta * l_curr
            l_next_safe = jnp.where(l_next == 0.0, 1.0, l_next)
            masked_store(
                head_l_ref,
                l_next_safe,
                store_start,
                store_end,
                num_q_heads_per_kv_head,
            )

            def broadcast_to_shape(arr, shape):
                if arr.shape == shape:
                    return arr
                assert len(arr.shape) == len(shape)
                assert arr.shape[0] == shape[0]
                assert shape[1] % arr.shape[1] == 0
                # no-op concatenation.
                return jnp.concatenate(
                    [arr for _ in range(shape[1] // arr.shape[1])], axis=1)

            o_curr = load_with_init(head_acc_ref, 0.0).reshape(-1, head_dim)
            l_alpha = broadcast_to_shape(l_alpha, qkv.shape)
            beta = broadcast_to_shape(beta, qkv.shape)
            l_next_safe = broadcast_to_shape(l_next_safe, qkv.shape)
            out = lax.div(
                l_alpha * o_curr + beta * qkv,
                l_next_safe,
            )
            masked_store(
                head_acc_ref,
                out.reshape(head_acc_ref.shape),
                store_start,
                store_end,
            )

        def is_valid_kv_blk_in_cur_seq(kv_states):
            kv_blk_idx, _ = kv_states
            return kv_blk_idx * num_kv_per_blk < kv_len

        def compute_with_kv_blk_in_cur_seq(kv_states):
            kv_blk_idx, cur_buf_idx = kv_states
            next_heads_blk_idx, next_seq_idx, next_kv_blk_idx, next_buf_idx = (
                get_next_prefetch_ids(heads_blk_idx, cur_seq_idx, kv_blk_idx,
                                      cur_buf_idx))

            @pl.when(next_heads_blk_idx < num_heads_blks)
            def prefetch_next_kv_blk():
                # TODO(jevinjiang): reuse the same buffer if it is already prefetched!
                # TODO(jevinjiang): only fetch effective dynamic size to hold kv_len and
                # DMA to fixed size buffer!
                next_async_copy_kv = create_kv_async_copy_descriptors(
                    next_heads_blk_idx, next_seq_idx, next_kv_blk_idx,
                    next_buf_idx)
                next_async_copy_kv.start()

            cur_async_copy_kv = create_kv_async_copy_descriptors(
                heads_blk_idx, cur_seq_idx, kv_blk_idx, cur_buf_idx)
            kv_ref = cur_async_copy_kv.wait().reshape(
                num_kv_pages_per_blk * page_size *
                num_combined_kv_heads_per_blk,
                head_dim,
            )
            kv_packing = get_dtype_packing(kv_ref.dtype)
            # NOTE: kv_packing is divided by 2 because k and v are packed together.
            kv_load_step = max(1, kv_packing // 2)
            for kv_head_chunk_idx in range(0, num_kv_heads_per_blk,
                                           kv_load_step):
                k_list, v_list = strided_load_kv(
                    kv_ref, kv_head_chunk_idx * 2,
                    num_combined_kv_heads_per_blk)
                for step_idx in range(kv_load_step):
                    k = k_list[step_idx]
                    v = v_list[step_idx]
                    if k_scale is not None:
                        # NOTE: Conversion between arbitrary data types is not supported.
                        # That's why it is converted to float32 first.
                        k = k.astype(jnp.float32) * k_scale
                        k = k.astype(q_ref.dtype)
                    if v_scale is not None:
                        v = v.astype(jnp.float32) * v_scale
                        v = v.astype(q_ref.dtype)
                    kv_head_idx = kv_head_chunk_idx + step_idx
                    q_head_idx = kv_head_idx * num_q_heads_per_kv_head
                    # TODO(jevinjiang): extra handling for packed type that can start at
                    # unaligned position!
                    q = fold_on_2nd_minor(
                        q_ref[:, q_head_idx:q_head_idx +
                              num_q_heads_per_kv_head, :])
                    flash_attention(
                        q,
                        k,
                        v,
                        l_ref.at[kv_head_idx],
                        m_ref.at[kv_head_idx],
                        acc_ref.at[:, q_head_idx:q_head_idx +
                                   num_q_heads_per_kv_head, :],
                        kv_blk_idx=kv_blk_idx,
                    )
            return kv_blk_idx + 1, next_buf_idx

        _, next_buf_idx = lax.while_loop(
            is_valid_kv_blk_in_cur_seq,
            compute_with_kv_blk_in_cur_seq,
            (0, cur_buf_idx),  # (kv_blk_idx, buf_idx)
        )
        next_seq_idx = lax.select(q_end <= q_len_end, cur_seq_idx + 1,
                                  cur_seq_idx)
        done = lax.select(q_end < q_len_end, done, 1)
        return done, next_seq_idx, next_buf_idx

    _, seq_idx, buf_idx = lax.while_loop(
        is_cur_q_blk_needed,
        compute_with_cur_q_blk,
        (0, init_seq_idx, init_buf_idx),  # (done, seq_idx, buf_idx)
    )
    # Reset seq_idx for next kv_heads_blk if run out of seqs!
    seq_buf_idx_ref[0] = lax.select(seq_idx < num_seqs, seq_idx, 0)
    seq_buf_idx_ref[1] = buf_idx
    o_ref[...] = acc_ref[...].astype(q_ref.dtype)


def cdiv(a, b):
    assert b != 0
    return (a + b - 1) // b


def get_dtype_packing(dtype):
    bits = (dtypes.bit_width(dtype)
            if hasattr(dtypes, "bit_width") else dtypes.itemsize_bits(dtype))
    return 32 // bits


def get_min_heads_per_blk(num_q_heads, num_combined_kv_heads, q_dtype,
                          kv_dtype):
    q_packing = get_dtype_packing(q_dtype)
    kv_packing = get_dtype_packing(kv_dtype)

    def can_be_xla_fully_tiled(x, packing):
        if x % packing != 0:
            return False
        x //= packing
        return x in (1, 2, 4, 8) or x % 8 == 0

    # TODO(jevinjiang): support unaligned number of heads!
    if not can_be_xla_fully_tiled(num_combined_kv_heads, kv_packing):
        raise ValueError(
            f"Not implemented: {num_combined_kv_heads=} can not be XLA fully tiled."
        )
    assert num_combined_kv_heads % 2 == 0
    num_kv_heads = num_combined_kv_heads // 2
    assert num_q_heads % num_kv_heads == 0
    ratio = num_q_heads // num_kv_heads
    # TODO(jevinjiang): we can choose smaller tiling for packed type if large
    # second minor tiling is not on.
    max_combined_kv_tiling = 8 * kv_packing
    min_combined_kv_heads = (max_combined_kv_tiling if num_combined_kv_heads %
                             max_combined_kv_tiling == 0 else
                             num_combined_kv_heads)
    min_q_heads = min_combined_kv_heads // 2 * ratio
    if can_be_xla_fully_tiled(min_q_heads, q_packing):
        return min_q_heads, min_combined_kv_heads
    return num_q_heads, num_combined_kv_heads


@functools.partial(
    jax.jit,
    static_argnames=[
        "sm_scale",
        "mask_value",
        "num_kv_pages_per_block",
        "num_queries_per_block",
        "vmem_limit_bytes",
        "sliding_window",
        "soft_cap",
        "k_scale",
        "v_scale",
    ],
)
def ragged_paged_attention(
    q: jax.Array,  # [max_num_batched_tokens, num_q_heads, head_dim]
    # TODO(jevinjiang): create a write_to_kv_cache kernel!
    kv_pages: jax.Array,  # [total_num_pages, page_size, num_combined_kv_heads, head_dim]
    kv_lens: jax.Array,  # i32[max_num_seqs]
    page_indices: jax.Array,  # i32[max_num_seqs, pages_per_seq]
    cu_q_lens: jax.Array,  # i32[max_num_seqs + 1]
    num_seqs: jax.Array,  # i32[1]
    *,
    sm_scale: float = 1.0,
    sliding_window: int | None = None,
    soft_cap: float | None = None,
    mask_value: float | None = DEFAULT_MASK_VALUE,
    k_scale: float | None = None,
    v_scale: float | None = None,
    num_kv_pages_per_block: int | None = None,
    num_queries_per_block: int | None = None,
    vmem_limit_bytes: int | None = None,
):
    """Ragged paged attention that supports mixed prefill and decode.

    Args:
      q: concatenated all sequences' queries.
      kv_pages: paged KV cache. Normally in HBM.
      kv_lens: padded kv lengths. Only the first num_seqs values are valid.
      page_indices: the first index indicates which page to use in the kv cache
        for each sequence. Only the first num_seqs values are valid.
      cu_q_lens: the cumulative sum of the effective query lengths. Similar to
        kv_lens, only the first num_seqs+1 values are valid.
      num_seqs: the dynamic number of sequences.
      sm_scale: the softmax scale which will be applied to the Q@K^T.
      sliding_window: the sliding window size for the attention.
      soft_cap: the logit soft cap for the attention.
      mask_value: mask value for causal mask.
      k_scale: the scale for the key cache.
      v_scale: the scale for the value cache.
      num_kv_pages_per_block: number of kv pages to be processed in one flash
        attention block in the pallas kernel.
      num_queries_per_block: number of query tokens to be processed in one flash
        attention block in the pallas kernel.
      vmem_limit_bytes: the vmem limit for the pallas kernel.

    Returns:
      The output of the attention.
    """
    static_validate_inputs(
        q,
        kv_pages,
        kv_lens,
        page_indices,
        cu_q_lens,
        num_seqs,
        sm_scale=sm_scale,
        sliding_window=sliding_window,
        soft_cap=soft_cap,
        mask_value=mask_value,
        k_scale=k_scale,
        v_scale=v_scale,
        num_kv_pages_per_block=num_kv_pages_per_block,
        num_queries_per_block=num_queries_per_block,
        vmem_limit_bytes=vmem_limit_bytes,
    )
    if mask_value is None:
        mask_value = DEFAULT_MASK_VALUE
    num_q_tokens, num_q_heads, head_dim = q.shape
    _, page_size, num_combined_kv_heads, _ = kv_pages.shape
    assert num_combined_kv_heads % 2 == 0
    num_kv_heads = num_combined_kv_heads // 2
    _, pages_per_seq = page_indices.shape
    num_q_heads_per_blk, num_combined_kv_heads_per_blk = get_min_heads_per_blk(
        num_q_heads, num_combined_kv_heads, q.dtype, kv_pages.dtype)
    num_q_per_blk = num_queries_per_block
    num_kv_pages_per_blk = num_kv_pages_per_block
    if num_q_per_blk is None or num_kv_pages_per_blk is None:
        num_kv_pages_per_blk, num_q_per_blk = get_tuned_block_sizes(
            q.dtype,
            kv_pages.dtype,
            num_q_heads_per_blk,
            num_combined_kv_heads_per_blk // 2,
            head_dim,
            page_size,
            num_q_tokens,
            pages_per_seq,
        )
    num_q_heads_per_kv_head = num_q_heads // num_kv_heads
    num_q_blks = cdiv(num_q_tokens, num_q_per_blk)
    assert num_combined_kv_heads_per_blk % 2 == 0
    num_kv_heads_per_blk = num_combined_kv_heads_per_blk // 2
    assert num_q_heads_per_blk % num_q_heads_per_kv_head == 0
    num_heads_blks = num_q_heads // num_q_heads_per_blk
    grid = (num_heads_blks, num_q_blks)

    def q_index_map(heads_blk_idx, q_blk_idx, *_):
        return (q_blk_idx, heads_blk_idx, 0)

    q_block_spec = pl.BlockSpec(
        (num_q_per_blk, num_q_heads_per_blk, head_dim),
        q_index_map,
    )
    in_specs = [
        q_block_spec,
        pl.BlockSpec(memory_space=pltpu.ANY),
    ]
    out_specs = q_block_spec
    lm_scratch = pltpu.VMEM(
        # TODO(jevinjiang): use 128 instead of 1 is due to Mosaic does not support
        # unaligned slicing!
        (num_kv_heads_per_blk, num_q_per_blk * num_q_heads_per_kv_head, 128),
        jnp.float32,
    )
    acc_scratch = pltpu.VMEM(
        (num_q_per_blk, num_q_heads_per_blk, head_dim),
        jnp.float32,
    )
    double_buf_scratch = pltpu.VMEM(
        (
            2,  # For double buffering during DMA copies.
            num_kv_pages_per_blk,
            page_size,
            num_combined_kv_heads_per_blk,
            head_dim,
        ),
        kv_pages.dtype,
    )
    scratch_shapes = [
        double_buf_scratch,  # kv_bufs
        pltpu.SemaphoreType.DMA((2, )),  # Semaphores for double buffers.
        lm_scratch,  # l_ref
        lm_scratch,  # m_ref
        acc_scratch,
    ]
    scalar_prefetches = (
        kv_lens,
        page_indices,
        cu_q_lens,
        jnp.array((0, 0), jnp.int32),  # seq_idx, buf_idx
        num_seqs,
    )
    kernel = pl.pallas_call(
        functools.partial(
            ragged_paged_attention_kernel,
            sm_scale=sm_scale,
            sliding_window=sliding_window,
            soft_cap=soft_cap,
            mask_value=mask_value,
            k_scale=k_scale,
            v_scale=v_scale,
        ),
        grid_spec=pltpu.PrefetchScalarGridSpec(
            num_scalar_prefetch=len(scalar_prefetches),
            in_specs=in_specs,
            out_specs=out_specs,
            grid=grid,
            scratch_shapes=scratch_shapes,
        ),
        compiler_params=pltpu.CompilerParams(
            dimension_semantics=(
                "arbitrary",
                "arbitrary",
            ),
            vmem_limit_bytes=vmem_limit_bytes,
        ),
        out_shape=jax.ShapeDtypeStruct(shape=q.shape, dtype=q.dtype),
        name="ragged_paged_attention_kernel",
    )

    return kernel(*scalar_prefetches, q, kv_pages)
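
For context, a rough usage sketch of the file above (not part of the package): the shapes, lengths, and bf16 dtype below are made-up assumptions chosen only to satisfy the kernel's input checks, and the pallas kernel itself requires a TPU device. The reference implementation in the same file can serve as a correctness check.

    import jax.numpy as jnp

    # Illustrative toy sizes: 2 sequences, 4 q heads, 2 kv heads (4 combined),
    # head_dim 128, 8 pages of 16 tokens, up to 16 batched query tokens.
    num_q_heads, num_kv_heads, head_dim = 4, 2, 128
    total_pages, page_size, pages_per_seq = 8, 16, 4
    max_num_seqs, max_num_batched_tokens = 2, 16

    q = jnp.zeros((max_num_batched_tokens, num_q_heads, head_dim), jnp.bfloat16)
    kv_pages = jnp.zeros((total_pages, page_size, 2 * num_kv_heads, head_dim),
                         jnp.bfloat16)
    kv_lens = jnp.array([24, 1], jnp.int32)      # padded KV lengths per sequence
    page_indices = jnp.arange(max_num_seqs * pages_per_seq,
                              dtype=jnp.int32).reshape(max_num_seqs, pages_per_seq)
    cu_q_lens = jnp.array([0, 8, 9], jnp.int32)  # 8 prefill tokens + 1 decode token
    num_seqs = jnp.array([2], jnp.int32)

    # Kernel output (TPU) compared against the pure-JAX reference.
    out = ragged_paged_attention(q, kv_pages, kv_lens, page_indices, cu_q_lens,
                                 num_seqs, sm_scale=head_dim**-0.5)
    ref = ref_ragged_paged_attention(q, kv_pages, kv_lens, page_indices, cu_q_lens,
                                     num_seqs, sm_scale=head_dim**-0.5)
    print(out.shape, jnp.max(jnp.abs(out - ref)))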