tpu-inference 0.11.1.dev202511150811__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This version of tpu-inference might be problematic.

Files changed (179)
  1. tests/__init__.py +0 -0
  2. tests/core/__init__.py +0 -0
  3. tests/core/test_core_tpu.py +513 -0
  4. tests/core/test_disagg_executor.py +60 -0
  5. tests/core/test_disagg_utils.py +53 -0
  6. tests/core/test_dp_scheduler.py +899 -0
  7. tests/core/test_init.py +49 -0
  8. tests/kernels/__init__.py +0 -0
  9. tests/kernels/fused_moe_v1_test.py +105 -0
  10. tests/kernels/mla_v1_test.py +396 -0
  11. tests/kernels/quantized_matmul_kernel_test.py +191 -0
  12. tests/kernels/ragged_kv_cache_update_v2_test.py +234 -0
  13. tests/kernels/ragged_paged_attention_kernel_v2_test.py +400 -0
  14. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +549 -0
  15. tests/kernels/ragged_paged_attention_kernel_v3_test.py +504 -0
  16. tests/lora/__init__.py +0 -0
  17. tests/lora/conftest.py +32 -0
  18. tests/lora/test_bgmv.py +43 -0
  19. tests/lora/test_layers.py +654 -0
  20. tests/lora/test_lora.py +133 -0
  21. tests/lora/utils.py +96 -0
  22. tests/test_base.py +201 -0
  23. tests/test_envs.py +182 -0
  24. tests/test_quantization.py +836 -0
  25. tests/test_tpu_info.py +120 -0
  26. tests/test_utils.py +236 -0
  27. tpu_inference/__init__.py +34 -0
  28. tpu_inference/core/__init__.py +0 -0
  29. tpu_inference/core/core_tpu.py +786 -0
  30. tpu_inference/core/disagg_executor.py +118 -0
  31. tpu_inference/core/disagg_utils.py +51 -0
  32. tpu_inference/core/sched/__init__.py +0 -0
  33. tpu_inference/core/sched/dp_scheduler.py +523 -0
  34. tpu_inference/distributed/__init__.py +0 -0
  35. tpu_inference/distributed/jax_parallel_state.py +67 -0
  36. tpu_inference/distributed/tpu_connector.py +728 -0
  37. tpu_inference/distributed/utils.py +59 -0
  38. tpu_inference/env_override.py +9 -0
  39. tpu_inference/envs.py +107 -0
  40. tpu_inference/executors/__init__.py +0 -0
  41. tpu_inference/executors/ray_distributed_executor.py +362 -0
  42. tpu_inference/experimental/__init__.py +0 -0
  43. tpu_inference/experimental/llama3_jax_stashed.py +258 -0
  44. tpu_inference/kernels/__init__.py +0 -0
  45. tpu_inference/kernels/collectives/__init__.py +0 -0
  46. tpu_inference/kernels/collectives/all_gather_matmul.py +735 -0
  47. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +60 -0
  48. tpu_inference/kernels/collectives/util.py +47 -0
  49. tpu_inference/kernels/flash_attention/__init__.py +0 -0
  50. tpu_inference/kernels/flash_attention/kernel.py +772 -0
  51. tpu_inference/kernels/fused_moe/__init__.py +0 -0
  52. tpu_inference/kernels/fused_moe/v1/__init__.py +0 -0
  53. tpu_inference/kernels/fused_moe/v1/kernel.py +1035 -0
  54. tpu_inference/kernels/mla/__init__.py +0 -0
  55. tpu_inference/kernels/mla/v1/__init__.py +0 -0
  56. tpu_inference/kernels/mla/v1/kernel.py +1349 -0
  57. tpu_inference/kernels/quantized_matmul/__init__.py +0 -0
  58. tpu_inference/kernels/quantized_matmul/kernel.py +395 -0
  59. tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
  60. tpu_inference/kernels/quantized_matmul/util.py +58 -0
  61. tpu_inference/kernels/ragged_paged_attention/__init__.py +0 -0
  62. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +0 -0
  63. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +875 -0
  64. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +287 -0
  65. tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
  66. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +0 -0
  67. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1478 -0
  68. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1482 -0
  69. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4147 -0
  70. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +367 -0
  71. tpu_inference/kernels/ragged_paged_attention/v3/util.py +51 -0
  72. tpu_inference/layers/__init__.py +0 -0
  73. tpu_inference/layers/common/__init__.py +0 -0
  74. tpu_inference/layers/common/attention_interface.py +390 -0
  75. tpu_inference/layers/common/attention_metadata.py +34 -0
  76. tpu_inference/layers/common/binary_search.py +295 -0
  77. tpu_inference/layers/common/quant_methods.py +8 -0
  78. tpu_inference/layers/common/sharding.py +582 -0
  79. tpu_inference/layers/jax/__init__.py +0 -0
  80. tpu_inference/layers/jax/attention/__init__.py +0 -0
  81. tpu_inference/layers/jax/attention/attention.py +255 -0
  82. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +354 -0
  83. tpu_inference/layers/jax/attention/gpt_oss_attention.py +262 -0
  84. tpu_inference/layers/jax/attention/llama4_attention.py +153 -0
  85. tpu_inference/layers/jax/base.py +151 -0
  86. tpu_inference/layers/jax/constants.py +88 -0
  87. tpu_inference/layers/jax/layers.py +301 -0
  88. tpu_inference/layers/jax/misc.py +16 -0
  89. tpu_inference/layers/jax/moe/__init__.py +0 -0
  90. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +608 -0
  91. tpu_inference/layers/jax/moe/gpt_oss_moe.py +185 -0
  92. tpu_inference/layers/jax/moe/moe.py +209 -0
  93. tpu_inference/layers/jax/rope.py +280 -0
  94. tpu_inference/layers/jax/rope_interface.py +214 -0
  95. tpu_inference/layers/jax/sample/__init__.py +0 -0
  96. tpu_inference/layers/jax/sample/rejection_sampler.py +515 -0
  97. tpu_inference/layers/jax/sample/sampling.py +96 -0
  98. tpu_inference/layers/jax/sample/sampling_metadata.py +76 -0
  99. tpu_inference/layers/jax/transformer_block.py +107 -0
  100. tpu_inference/layers/vllm/__init__.py +0 -0
  101. tpu_inference/layers/vllm/attention.py +221 -0
  102. tpu_inference/layers/vllm/fused_moe.py +507 -0
  103. tpu_inference/layers/vllm/linear_common.py +186 -0
  104. tpu_inference/layers/vllm/quantization/__init__.py +39 -0
  105. tpu_inference/layers/vllm/quantization/awq.py +207 -0
  106. tpu_inference/layers/vllm/quantization/common.py +105 -0
  107. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +0 -0
  108. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +120 -0
  109. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +203 -0
  110. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +0 -0
  111. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +208 -0
  112. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +136 -0
  113. tpu_inference/layers/vllm/quantization/mxfp4.py +266 -0
  114. tpu_inference/layers/vllm/quantization/unquantized.py +386 -0
  115. tpu_inference/layers/vllm/sharding.py +230 -0
  116. tpu_inference/logger.py +10 -0
  117. tpu_inference/lora/__init__.py +0 -0
  118. tpu_inference/lora/torch_lora_ops.py +103 -0
  119. tpu_inference/lora/torch_punica_tpu.py +311 -0
  120. tpu_inference/mock/__init__.py +0 -0
  121. tpu_inference/mock/vllm_config_utils.py +28 -0
  122. tpu_inference/mock/vllm_envs.py +1219 -0
  123. tpu_inference/mock/vllm_logger.py +212 -0
  124. tpu_inference/mock/vllm_logging_utils.py +15 -0
  125. tpu_inference/models/__init__.py +0 -0
  126. tpu_inference/models/common/__init__.py +0 -0
  127. tpu_inference/models/common/model_loader.py +444 -0
  128. tpu_inference/models/jax/__init__.py +0 -0
  129. tpu_inference/models/jax/deepseek_v3.py +868 -0
  130. tpu_inference/models/jax/gpt_oss.py +492 -0
  131. tpu_inference/models/jax/jax_intermediate_tensor.py +79 -0
  132. tpu_inference/models/jax/llama3.py +375 -0
  133. tpu_inference/models/jax/llama4.py +629 -0
  134. tpu_inference/models/jax/llama_eagle3.py +333 -0
  135. tpu_inference/models/jax/phi3.py +376 -0
  136. tpu_inference/models/jax/qwen2.py +375 -0
  137. tpu_inference/models/jax/qwen2_5_vl.py +1103 -0
  138. tpu_inference/models/jax/qwen3.py +302 -0
  139. tpu_inference/models/jax/utils/__init__.py +0 -0
  140. tpu_inference/models/jax/utils/file_utils.py +96 -0
  141. tpu_inference/models/jax/utils/multi_modal_utils.py +163 -0
  142. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  143. tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +5 -0
  144. tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +6 -0
  145. tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +5 -0
  146. tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +6 -0
  147. tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +105 -0
  148. tpu_inference/models/jax/utils/quantization/quantization_utils.py +653 -0
  149. tpu_inference/models/jax/utils/weight_utils.py +529 -0
  150. tpu_inference/models/vllm/__init__.py +0 -0
  151. tpu_inference/models/vllm/vllm_model_wrapper.py +286 -0
  152. tpu_inference/models/vllm/vllm_model_wrapper_context.py +45 -0
  153. tpu_inference/platforms/__init__.py +2 -0
  154. tpu_inference/platforms/tpu_platform.py +269 -0
  155. tpu_inference/runner/__init__.py +0 -0
  156. tpu_inference/runner/block_table.py +122 -0
  157. tpu_inference/runner/compilation_manager.py +780 -0
  158. tpu_inference/runner/input_batch.py +435 -0
  159. tpu_inference/runner/kv_cache.py +132 -0
  160. tpu_inference/runner/kv_cache_manager.py +479 -0
  161. tpu_inference/runner/lora_utils.py +92 -0
  162. tpu_inference/runner/multimodal_manager.py +217 -0
  163. tpu_inference/runner/persistent_batch_manager.py +244 -0
  164. tpu_inference/runner/speculative_decoding_manager.py +248 -0
  165. tpu_inference/runner/structured_decoding_manager.py +88 -0
  166. tpu_inference/runner/tpu_runner.py +1620 -0
  167. tpu_inference/runner/utils.py +426 -0
  168. tpu_inference/spec_decode/__init__.py +0 -0
  169. tpu_inference/spec_decode/jax/__init__.py +0 -0
  170. tpu_inference/spec_decode/jax/eagle3.py +367 -0
  171. tpu_inference/tpu_info.py +77 -0
  172. tpu_inference/utils.py +317 -0
  173. tpu_inference/worker/__init__.py +0 -0
  174. tpu_inference/worker/tpu_worker.py +321 -0
  175. tpu_inference-0.11.1.dev202511150811.dist-info/METADATA +107 -0
  176. tpu_inference-0.11.1.dev202511150811.dist-info/RECORD +179 -0
  177. tpu_inference-0.11.1.dev202511150811.dist-info/WHEEL +5 -0
  178. tpu_inference-0.11.1.dev202511150811.dist-info/licenses/LICENSE +201 -0
  179. tpu_inference-0.11.1.dev202511150811.dist-info/top_level.txt +2 -0
tpu_inference/kernels/mla/v1/kernel.py
@@ -0,0 +1,1349 @@
"""TPU-Friendly and Data-Movement-Friendly MLA Ragged Paged Attention kernel."""

import functools

import jax
import jax.numpy as jnp
from jax import lax
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu

from tpu_inference.kernels.ragged_paged_attention.v3.util import (
    align_to, cdiv, get_dtype_packing)

DEFAULT_MASK_VALUE = -0.7 * float(jnp.finfo(jnp.dtype("float32")).max)

DEFAULT_VMEM_LIMIT_BYTES = 100 * 1024 * 1024

@functools.partial(
    jax.jit,
    donate_argnames=("cache_kv_c", "cache_k_pe"),
)
def update_kv_cache(
    new_kv_c: jax.Array,  # [num_tokens, actual_lkv_dim]
    new_k_pe: jax.Array,  # [num_tokens, actual_r_dim]
    cache_kv_c: jax.
    Array,  # [total_num_pages, page_size_per_kv_packing, kv_packing, lkv_dim]
    cache_k_pe: jax.
    Array,  # [total_num_pages, page_size_per_kv_packing, kv_packing, r_dim]
    kv_lens: jax.Array,  # i32[max_num_seqs]
    page_indices: jax.Array,  # i32[max_num_seqs * pages_per_seq]
    cu_q_lens: jax.Array,  # i32[max_num_seqs + 1]
    distribution: jax.Array,  # i32[3]
) -> tuple[jax.Array, jax.Array]:
    """Update KV cache with new tokens."""
    actual_r_dim = new_k_pe.shape[-1]
    r_dim = align_to(actual_r_dim, 128)
    if actual_r_dim != r_dim:
        new_k_pe = jnp.pad(new_k_pe, ((0, 0), (0, r_dim - actual_r_dim)),
                           constant_values=0)
    actual_lkv_dim = new_kv_c.shape[-1]
    lkv_dim = align_to(actual_lkv_dim, 128)
    if actual_lkv_dim != lkv_dim:
        new_kv_c = jnp.pad(new_kv_c, ((0, 0), (0, lkv_dim - actual_lkv_dim)),
                           constant_values=0)

    _, page_size_per_kv_packing, kv_packing, cache_lkv_dim = cache_kv_c.shape
    _, _, _, cache_r_dim = cache_k_pe.shape
    assert lkv_dim == cache_lkv_dim
    assert r_dim == cache_r_dim
    page_size = page_size_per_kv_packing * kv_packing

    max_num_seqs = kv_lens.shape[0]
    num_page_indices = page_indices.shape[0]
    pages_per_seq = num_page_indices // max_num_seqs

    def seq_loop_body(i, caches):
        cache_kv_c, cache_k_pe = caches
        q_start, q_end = cu_q_lens[i], cu_q_lens[i + 1]
        q_len = q_end - q_start
        kv_len = kv_lens[i]

        def token_loop_body(j, caches_):
            cache_kv_c_, cache_k_pe_ = caches_
            token_idx_in_seq = kv_len - q_len + j
            page_num_in_seq = token_idx_in_seq // page_size
            page_indices_start = i * pages_per_seq
            page_idx = page_indices[page_indices_start + page_num_in_seq]
            row = (token_idx_in_seq % page_size) // kv_packing
            col = (token_idx_in_seq % page_size) % kv_packing

            cache_kv_c_ = cache_kv_c_.at[page_idx, row,
                                         col].set(new_kv_c[q_start + j])
            cache_k_pe_ = cache_k_pe_.at[page_idx, row,
                                         col].set(new_k_pe[q_start + j])
            return cache_kv_c_, cache_k_pe_

        return lax.fori_loop(0, q_len, token_loop_body,
                             (cache_kv_c, cache_k_pe))

    cache_kv_c, cache_k_pe = lax.fori_loop(0, distribution[-1], seq_loop_body,
                                           (cache_kv_c, cache_k_pe))
    return cache_kv_c, cache_k_pe

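# Worked example of the indexing above (hypothetical numbers): with
# page_size = 16 and kv_packing = 2, the token at position 37 within its
# sequence lands on page_num_in_seq = 37 // 16 = 2, packed row
# (37 % 16) // 2 = 2, lane (37 % 16) % 2 = 1 of that page.
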
def ref_mla_ragged_paged_attention(
    ql_nope: jax.Array,  # [num_tokens, actual_num_q_heads, actual_lkv_dim]
    q_pe: jax.Array,  # [num_tokens, actual_num_q_heads, actual_r_dim]
    new_kv_c: jax.Array,  # [num_tokens, actual_lkv_dim]
    new_k_pe: jax.Array,  # [num_tokens, actual_r_dim]
    cache_kv_c: jax.
    Array,  # [total_num_pages, page_size_per_kv_packing, kv_packing, lkv_dim]
    cache_k_pe: jax.
    Array,  # [total_num_pages, page_size_per_kv_packing, kv_packing, r_dim]
    kv_lens: jax.Array,  # i32[max_num_seqs]
    page_indices: jax.Array,  # i32[max_num_seqs * pages_per_seq]
    cu_q_lens: jax.Array,  # i32[max_num_seqs + 1]
    distribution: jax.Array,  # i32[3]
    *,
    sm_scale: float = 1.0,
    sliding_window: int | None = None,
    soft_cap: float | None = None,
    mask_value: float | None = DEFAULT_MASK_VALUE,
):
    """Reference (non-Pallas) implementation of MLA ragged paged attention."""
    if mask_value is None:
        mask_value = DEFAULT_MASK_VALUE

    dynamic_validate_inputs(
        ql_nope,
        q_pe,
        new_kv_c,
        new_k_pe,
        cache_kv_c,
        cache_k_pe,
        kv_lens,
        page_indices,
        cu_q_lens,
        distribution,
        sm_scale=sm_scale,
        sliding_window=sliding_window,
        soft_cap=soft_cap,
        mask_value=mask_value,
    )

    cache_kv_c, cache_k_pe = update_kv_cache(
        new_kv_c,
        new_k_pe,
        cache_kv_c,
        cache_k_pe,
        kv_lens,
        page_indices,
        cu_q_lens,
        distribution,
    )
    # Pad ql_nope and q_pe to make the last dimension 128-byte aligned.
    actual_lkv_dim = ql_nope.shape[-1]
    lkv_dim = align_to(actual_lkv_dim, 128)
    if lkv_dim != actual_lkv_dim:
        ql_nope = jnp.pad(
            ql_nope,
            ((0, 0), (0, 0), (0, lkv_dim - actual_lkv_dim)),
            constant_values=0,
        )
    actual_r_dim = q_pe.shape[-1]
    r_dim = align_to(actual_r_dim, 128)
    if actual_r_dim != r_dim:
        q_pe = jnp.pad(q_pe, ((0, 0), (0, 0), (0, r_dim - actual_r_dim)),
                       constant_values=0)

    q = jnp.concatenate([ql_nope, q_pe], axis=-1)
    max_num_seqs = kv_lens.shape[0]
    num_page_indices = page_indices.shape[0]
    assert num_page_indices % max_num_seqs == 0
    pages_per_seq = num_page_indices // max_num_seqs

    total_num_pages, page_size_per_kv_packing, kv_packing, _ = cache_kv_c.shape
    page_size = page_size_per_kv_packing * kv_packing
    assert lkv_dim == ql_nope.shape[-1]
    assert r_dim == q_pe.shape[-1]

    kv_c_cache = cache_kv_c.reshape(total_num_pages, page_size, lkv_dim)
    k_pe_cache = cache_k_pe.reshape(total_num_pages, page_size, r_dim)

    outputs = []

    for i in range(distribution[-1]):
        q_start, q_end = cu_q_lens[i], cu_q_lens[i + 1]
        q_len = q_end - q_start
        kv_len = kv_lens[i]

        q_i = q[q_start:q_end]  # [q_len, actual_num_q_heads, lkv_dim+r_dim]

        indices_start = i * pages_per_seq
        num_pages_i = cdiv(kv_len, page_size)
        indices_end = indices_start + num_pages_i
        indices = page_indices[indices_start:indices_end]

        # Gather paged kv_c and k_pe
        gathered_kv_c = kv_c_cache[
            indices]  # [num_pages_i, page_size, lkv_dim]
        gathered_k_pe = k_pe_cache[indices]  # [num_pages_i, page_size, r_dim]

        # Flatten pages to sequence
        flat_kv_c = gathered_kv_c.reshape(
            -1, lkv_dim)  # [num_pages_i * page_size, lkv_dim]
        flat_k_pe = gathered_k_pe.reshape(
            -1, r_dim)  # [num_pages_i * page_size, r_dim]

        # Prepare k and v for attention
        k_i = jnp.concatenate([flat_kv_c[:kv_len], flat_k_pe[:kv_len]],
                              axis=-1)  # [kv_len, lkv_dim+r_dim]
        v_i = flat_kv_c[:kv_len]  # [kv_len, lkv_dim]

        # MQA attention:
        # q:[q_len, actual_num_q_heads, lkv_dim+r_dim]
        # k:[kv_len, lkv_dim+r_dim]
        # v:[kv_len, lkv_dim]
        # attn: [actual_num_q_heads, q_len, kv_len]
        attn = jnp.einsum("qnh,kh->nqk",
                          q_i,
                          k_i,
                          preferred_element_type=jnp.float32)
        attn *= sm_scale

        # Causal mask
        q_span = kv_len - q_len + jax.lax.broadcasted_iota(
            jnp.int32, attn.shape, 1)
        kv_span = jax.lax.broadcasted_iota(jnp.int32, attn.shape, 2)
        mask = q_span < kv_span
        if sliding_window is not None:
            mask = jnp.logical_or(mask, q_span - sliding_window >= kv_span)
        if soft_cap is not None:
            attn = soft_cap * jnp.tanh(attn / soft_cap)
        attn = jnp.where(mask, mask_value, attn)
        attn = jax.nn.softmax(attn, axis=-1).astype(v_i.dtype)

        # out_i: [q_len, actual_num_q_heads, lkv_dim]
        out_i = jnp.einsum("nqk,kl->qnl", attn, v_i).astype(q_i.dtype)
        outputs.append(out_i)

    return (
        jnp.concatenate(outputs, axis=0),
        cache_kv_c,
        cache_k_pe,
    )

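# Paging sketch for the reference implementation (hypothetical numbers): with
# page_size = 8, a sequence of kv_len = 9 gathers cdiv(9, 8) = 2 of its
# pages_per_seq page slots from page_indices, flattens them to 16 rows, and
# attends over only the first kv_len = 9 rows.
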
# Expect to run this validation during runtime.
def dynamic_validate_inputs(
    ql_nope: jax.Array,  # [max_num_tokens, actual_num_q_heads, actual_lkv_dim]
    q_pe: jax.Array,  # [max_num_tokens, actual_num_q_heads, actual_r_dim]
    new_kv_c: jax.Array,  # [max_num_tokens, actual_lkv_dim]
    new_k_pe: jax.Array,  # [max_num_tokens, actual_r_dim]
    cache_kv_c: jax.
    Array,  # [total_num_pages, page_size_per_kv_packing, kv_packing, lkv_dim]
    cache_k_pe: jax.
    Array,  # [total_num_pages, page_size_per_kv_packing, kv_packing, r_dim]
    kv_lens: jax.Array,  # i32[max_num_seqs]
    page_indices: jax.Array,  # i32[max_num_seqs * pages_per_seq]
    cu_q_lens: jax.Array,  # i32[max_num_seqs + 1]
    distribution: jax.Array,  # i32[3]
    *,
    sm_scale: float = 1.0,
    sliding_window: int | None = None,
    soft_cap: float | None = None,
    mask_value: float | None = DEFAULT_MASK_VALUE,
    # Kernel optimization params.
    chunk_prefill_size: int | None = None,
    # Kernel tuning params.
    num_kv_pages_per_block: int | None = None,
    num_queries_per_block: int | None = None,
    vmem_limit_bytes: int | None = None,
    # Debug params.
    debug_mode: bool = False,
):
    """Validate inputs to the MLA RPA kernel dynamically."""
    static_validate_inputs(
        ql_nope,
        q_pe,
        new_kv_c,
        new_k_pe,
        cache_kv_c,
        cache_k_pe,
        kv_lens,
        page_indices,
        cu_q_lens,
        distribution,
        sm_scale=sm_scale,
        sliding_window=sliding_window,
        soft_cap=soft_cap,
        mask_value=mask_value,
        chunk_prefill_size=chunk_prefill_size,
        num_kv_pages_per_block=num_kv_pages_per_block,
        num_queries_per_block=num_queries_per_block,
        vmem_limit_bytes=vmem_limit_bytes,
        debug_mode=debug_mode,
    )
    max_num_tokens = ql_nope.shape[0]
    total_num_pages = cache_kv_c.shape[0]
    _, page_size_per_kv_packing, kv_packing, _ = cache_kv_c.shape
    page_size = page_size_per_kv_packing * kv_packing
    max_num_seqs = kv_lens.shape[0]
    num_page_indices = page_indices.shape[0]
    assert num_page_indices % max_num_seqs == 0
    pages_per_seq = num_page_indices // max_num_seqs

    i, j, k = distribution
    if not (0 <= i <= j <= k):
        raise ValueError(f"Invalid distribution: {distribution=}")

    if k > max_num_seqs:
        raise ValueError(f"num_seqs={k} must be <= {max_num_seqs=}")

    if cu_q_lens[k] > max_num_tokens:
        raise ValueError(
            f"Total q tokens {cu_q_lens[k]} must be <= {max_num_tokens=}.")
    for seq_idx in range(k):
        q_len = cu_q_lens[seq_idx + 1] - cu_q_lens[seq_idx]
        kv_len = kv_lens[seq_idx]
        if not (0 < q_len <= kv_len):
            raise ValueError(
                f"Require 0 < {q_len=} <= {kv_len=} at sequence {seq_idx}.")
        page_cnt = cdiv(kv_len, page_size)
        if page_cnt > pages_per_seq:
            raise ValueError(
                f"Require {page_cnt=} <= {pages_per_seq=} at sequence {seq_idx} where"
                f" {kv_len=} and {page_size=}.")
        for p in range(page_cnt):
            page_idx = page_indices[seq_idx * pages_per_seq + p]
            if not (0 <= page_idx < total_num_pages):
                raise ValueError(
                    f"Require 0 <= {page_idx=} < {total_num_pages=} at sequence"
                    f" {seq_idx} where {kv_len=} and {page_size=}.")

# Expect to run this validation during compile time.
def static_validate_inputs(
    ql_nope: jax.Array,  # [max_num_tokens, actual_num_q_heads, actual_lkv_dim]
    q_pe: jax.Array,  # [max_num_tokens, actual_num_q_heads, actual_r_dim]
    new_kv_c: jax.Array,  # [max_num_tokens, actual_lkv_dim]
    new_k_pe: jax.Array,  # [max_num_tokens, actual_r_dim]
    cache_kv_c: jax.
    Array,  # [total_num_pages, page_size_per_kv_packing, kv_packing, lkv_dim]
    cache_k_pe: jax.
    Array,  # [total_num_pages, page_size_per_kv_packing, kv_packing, r_dim]
    kv_lens: jax.Array,  # i32[max_num_seqs]
    page_indices: jax.Array,  # i32[max_num_seqs * pages_per_seq]
    cu_q_lens: jax.Array,  # i32[max_num_seqs + 1]
    distribution: jax.Array,  # i32[3]
    *,
    sm_scale: float = 1.0,
    sliding_window: int | None = None,
    soft_cap: float | None = None,
    mask_value: float | None = DEFAULT_MASK_VALUE,
    # Kernel optimization params.
    chunk_prefill_size: int | None = None,
    # Kernel tuning params.
    num_kv_pages_per_block: int | None = None,
    num_queries_per_block: int | None = None,
    vmem_limit_bytes: int | None = None,
    # Debug params.
    debug_mode: bool = False,
):
    """Validate inputs to the MLA RPA kernel statically."""
    if len(ql_nope.shape) != 3:
        raise ValueError(f"Expected 3D array for {ql_nope.shape=}")
    if len(q_pe.shape) != 3:
        raise ValueError(f"Expected 3D array for {q_pe.shape=}")
    if len(new_kv_c.shape) != 2:
        raise ValueError(f"Expected 2D array for {new_kv_c.shape=}")
    if len(new_k_pe.shape) != 2:
        raise ValueError(f"Expected 2D array for {new_k_pe.shape=}")

    if ql_nope.shape[:2] != q_pe.shape[:2]:
        raise ValueError(
            f"Expected {ql_nope.shape[:2]=} to be equal to {q_pe.shape[:2]=}")
    if ql_nope.shape[0] != new_kv_c.shape[0]:
        raise ValueError(
            f"Expected {ql_nope.shape[0]=} to be equal to {new_kv_c.shape[0]=}")
    if new_kv_c.shape[0] != new_k_pe.shape[0]:
        raise ValueError(
            f"Expected {new_kv_c.shape[0]=} to be equal to {new_k_pe.shape[0]=}")
    if ql_nope.shape[2] != new_kv_c.shape[1]:
        raise ValueError(
            f"Expected {ql_nope.shape[2]=} to be equal to {new_kv_c.shape[1]=}")
    if q_pe.shape[2] != new_k_pe.shape[1]:
        raise ValueError(
            f"Expected {q_pe.shape[2]=} to be equal to {new_k_pe.shape[1]=}")

    actual_lkv_dim = ql_nope.shape[2]
    actual_r_dim = q_pe.shape[2]

    (
        _,
        page_size_per_kv_packing,
        kv_packing,
        lkv_dim,
    ) = cache_kv_c.shape
    _, _, _, r_dim = cache_k_pe.shape

    if lkv_dim != align_to(actual_lkv_dim, 128):
        raise ValueError(
            f"Expected {lkv_dim=} is equal to {align_to(actual_lkv_dim, 128)=}")
    if r_dim != align_to(actual_r_dim, 128):
        raise ValueError(
            f"Expected {r_dim=} is equal to {align_to(actual_r_dim, 128)=}")

    if not (cache_kv_c.dtype == new_kv_c.dtype):
        raise ValueError(
            f"Expected {cache_kv_c.dtype=} to be equal to {new_kv_c.dtype=}.")
    if not (cache_k_pe.dtype == new_k_pe.dtype):
        raise ValueError(
            f"Expected {cache_k_pe.dtype=} to be equal to {new_k_pe.dtype=}.")

    # Integer kv quantization is currently not supported.
    if not jnp.issubdtype(cache_kv_c.dtype, jnp.floating):
        raise ValueError(
            f"Expected {cache_kv_c.dtype=} to be a floating point.")
    if not jnp.issubdtype(cache_k_pe.dtype, jnp.floating):
        raise ValueError(
            f"Expected {cache_k_pe.dtype=} to be a floating point.")

    if kv_packing != get_dtype_packing(cache_kv_c.dtype):
        raise ValueError(
            f"{kv_packing=} does not match with {cache_kv_c.dtype=}")
    if kv_packing != get_dtype_packing(cache_k_pe.dtype):
        raise ValueError(
            f"{kv_packing=} does not match with {cache_k_pe.dtype=}")

    if not (jnp.int32 == kv_lens.dtype == page_indices.dtype == cu_q_lens.dtype
            == distribution.dtype):
        raise ValueError(
            f"Expected int32 dtype for {kv_lens.dtype=}, {page_indices.dtype=},"
            f" {cu_q_lens.dtype=}, {distribution.dtype=}")

    if not (len(kv_lens.shape) == len(page_indices.shape) == len(
            cu_q_lens.shape) == 1):
        raise ValueError(
            f"Expected 1D array for {kv_lens.shape=}, {page_indices.shape=},"
            f" {cu_q_lens.shape=}")

    max_num_seqs = kv_lens.shape[0]
    num_page_indices = page_indices.shape[0]
    if num_page_indices % max_num_seqs != 0:
        raise ValueError(
            f"Expected {num_page_indices=} to be divisible by {max_num_seqs=}.")
    if cu_q_lens.shape != (max_num_seqs + 1, ):
        raise ValueError(
            f"Expected {cu_q_lens.shape=} to be ({max_num_seqs + 1},).")
    if distribution.shape != (3, ):
        raise ValueError(f"Expected {distribution.shape=} to be (3,).")

    page_size = page_size_per_kv_packing * kv_packing
    if page_size % kv_packing != 0:
        raise ValueError(f"{page_size=} must be divisible by {kv_packing=}.")
    if sliding_window is not None and sliding_window <= 0:
        raise ValueError(f"{sliding_window=} must be positive.")
    if soft_cap is not None and soft_cap == 0.0:
        raise ValueError(f"{soft_cap=} must not be 0.0.")
    if chunk_prefill_size is not None and chunk_prefill_size <= 0:
        raise ValueError(f"{chunk_prefill_size=} must be positive.")
    if num_kv_pages_per_block is not None:
        if num_kv_pages_per_block <= 0:
            raise ValueError(f"{num_kv_pages_per_block=} must be positive.")
    if num_queries_per_block is not None:
        if num_queries_per_block <= 0:
            raise ValueError(f"{num_queries_per_block=} must be positive.")
    if vmem_limit_bytes is not None and vmem_limit_bytes <= 0:
        raise ValueError(f"{vmem_limit_bytes=} must be positive.")

    # No constraints for the following inputs.
    del sm_scale
    del mask_value
    del debug_mode

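# Packing background for the kernel below, derived from how load_bkv unpacks
# values (bitwidth = 32 // kv_packing): float32 packs 1 value per 32-bit
# word, bfloat16 packs 2, and an 8-bit float packs 4. So a bfloat16 cache
# with page_size = 16 stores page_size_per_kv_packing = 8 packed rows of
# kv_packing = 2 values each.
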
def _mla_ragged_paged_attention_kernel(
    # Prefetch
    kv_lens_ref,  # [max_num_seqs]
    page_indices_ref,  # [max_num_seqs * pages_per_seq]
    cu_q_lens_ref,  # [max_num_seqs + 1]
    # TODO(jevinjiang): merge these into one so we can save SMEM.
    distribution_ref,  # [3] (decode_end, prefill_end, mixed_end)
    sem_ids_ref,  # [3] (bq_sem_idx, bkv_sem_idx, bo_sem_idx)
    bo_ids_ref,  # [4] (bo_sem_0_seq_idx, bo_sem_1_seq_idx, bo_sem_0_bo_idx, bo_sem_1_bo_idx)
    bkv_update_ids_ref,  # [6] (bkv_sem_0_seq_idx, bkv_sem_1_seq_idx, bkv_sem_0_offset, bkv_sem_1_offset, bkv_sem_0_sz, bkv_sem_1_sz)
    # Input
    ql_nope_hbm_ref,  # [max_num_tokens, num_q_heads_per_q_packing, q_packing, lkv_dim]
    q_pe_hbm_ref,  # [max_num_tokens, num_q_heads_per_q_packing, q_packing, r_dim]
    new_kv_c_hbm_ref,  # [max_num_tokens_per_kv_packing, kv_packing, lkv_dim]
    new_k_pe_hbm_ref,  # [max_num_tokens_per_kv_packing, kv_packing, r_dim]
    cache_kv_c_hbm_ref,  # [total_num_pages, page_size_per_kv_packing, kv_packing, lkv_dim]
    cache_k_pe_hbm_ref,  # [total_num_pages, page_size_per_kv_packing, kv_packing, r_dim]
    # Output
    o_hbm_ref,  # [max_num_tokens, num_q_heads_per_q_packing, q_packing, lkv_dim]
    updated_cache_kv_c_hbm_ref,  # [total_num_pages, page_size_per_kv_packing, kv_packing, lkv_dim]
    updated_cache_k_pe_hbm_ref,  # [total_num_pages, page_size_per_kv_packing, kv_packing, r_dim]
    # Scratch
    bkvc_x2_ref,  # [2, bkv_sz_per_kv_packing, kv_packing, lkv_dim]
    bkpe_x2_ref,  # [2, bkv_sz_per_kv_packing, kv_packing, r_dim]
    bq_nope_x2_ref,  # [2, bq_sz, num_q_heads_per_q_packing, q_packing, lkv_dim]
    bq_rope_x2_ref,  # [2, bq_sz, num_q_heads_per_q_packing, q_packing, r_dim]
    bo_x2_ref,  # [2, bq_sz, num_q_heads_per_q_packing, q_packing, lkv_dim]
    sems,  # [4, 2]
    l_ref,  # [bq_sz * num_q_heads, 128]
    m_ref,  # [bq_sz * num_q_heads, 128]
    acc_ref,  # [bq_sz * num_q_heads, lkv_dim]
    *,
    sm_scale: float,
    sliding_window: int | None = None,
    soft_cap: float | None = None,
    mask_value: float = DEFAULT_MASK_VALUE,
    q_scale: float | None = None,
    k_scale: float | None = None,
    v_scale: float | None = None,
    chunk_prefill_size: int | None = None,
    bkv_p,
    bq_sz,
    debug_mode: bool = False,
):
    assert ql_nope_hbm_ref.shape == o_hbm_ref.shape
    assert ql_nope_hbm_ref.shape[-1] == cache_kv_c_hbm_ref.shape[-1]
    assert q_pe_hbm_ref.shape[-1] == cache_k_pe_hbm_ref.shape[-1]
    _, num_q_heads_per_q_packing, q_packing, lkv_dim = ql_nope_hbm_ref.shape
    r_dim = q_pe_hbm_ref.shape[-1]
    num_q_heads = num_q_heads_per_q_packing * q_packing
    total_num_pages, page_size_per_kv_packing, kv_packing, _ = (
        cache_kv_c_hbm_ref.shape)
    max_num_seqs = kv_lens_ref.shape[0]
    num_page_indices = page_indices_ref.shape[0]

    assert num_page_indices % max_num_seqs == 0
    pages_per_seq = num_page_indices // max_num_seqs
    q_dtype = ql_nope_hbm_ref.dtype
    kv_dtype = cache_kv_c_hbm_ref.dtype
    assert q_pe_hbm_ref.dtype == q_dtype
    assert o_hbm_ref.dtype == q_dtype
    assert get_dtype_packing(q_dtype) == q_packing
    assert get_dtype_packing(kv_dtype) == kv_packing
    assert lkv_dim % 128 == 0
    assert r_dim % 128 == 0
    bkv_sz_per_kv_packing = bkv_p * page_size_per_kv_packing
    bkv_sz = bkv_sz_per_kv_packing * kv_packing
    page_size = page_size_per_kv_packing * kv_packing
    seq_idx = pl.program_id(0)
    num_seqs = pl.num_programs(0)
    decode_end = distribution_ref[0]
    prefill_end = distribution_ref[1]
    mixed_end = distribution_ref[2]

    q_start = cu_q_lens_ref[seq_idx]
    q_end = cu_q_lens_ref[seq_idx + 1]
    q_len = q_end - q_start
    kv_len = kv_lens_ref[seq_idx]

    def debug_print(msg, *args):
        if debug_mode:
            pl.debug_print(msg, *args)

    debug_print("[RPA debug] ======= In loop seq_idx={}", seq_idx)
    debug_print("[RPA debug] num_seqs={}", num_seqs)
    debug_print("[RPA debug] decode_end={}", decode_end)
    debug_print("[RPA debug] prefill_end={}", prefill_end)
    debug_print("[RPA debug] mixed_end={}", mixed_end)
    debug_print("[RPA debug] bkv_p={}", bkv_p)
    debug_print("[RPA debug] page_size={}", page_size)
    debug_print("[RPA debug] pages_per_seq={}", pages_per_seq)
    debug_print("[RPA debug] bkv_sz_per_kv_packing={}", bkv_sz_per_kv_packing)
    debug_print("[RPA debug] bq_sz={}", bq_sz)
    debug_print("[RPA debug] q_start={}", q_start)
    debug_print("[RPA debug] q_end={}", q_end)
    debug_print("[RPA debug] q_len={}", q_len)
    debug_print("[RPA debug] kv_len={}", kv_len)

    def flash_attention(
        ql_nope,  # [actual_bq_sz * num_q_heads, lkv_dim]
        q_pe,  # [actual_bq_sz * num_q_heads, r_dim]
        kv_c,  # [bkv_sz, lkv_dim]
        k_pe,  # [bkv_sz, r_dim]
        *,
        bq_idx,
        bkv_idx,
    ):
        assert len(ql_nope.shape) == 2
        assert len(q_pe.shape) == 2
        assert len(kv_c.shape) == 2
        assert len(k_pe.shape) == 2
        assert ql_nope.shape[0] % num_q_heads == 0
        assert ql_nope.shape[0] == q_pe.shape[0]
        assert q_pe.shape[0] % bq_sz == 0
        assert ql_nope.shape[1] == lkv_dim
        assert q_pe.shape[1] == r_dim
        assert kv_c.shape == (bkv_sz, lkv_dim)
        assert k_pe.shape == (bkv_sz, r_dim)
        head_l_ref = l_ref.at[:ql_nope.shape[0]]
        head_m_ref = m_ref.at[:ql_nope.shape[0]]
        head_acc_ref = acc_ref.at[:ql_nope.shape[0]]

        def load_with_init(ref, init_val):
            return jnp.where(bkv_idx == 0, jnp.full_like(ref, init_val),
                             ref[...])

        # Follow FlashAttention-2 forward pass.
        s_nope = jnp.einsum("nd,md->nm",
                            ql_nope,
                            kv_c,
                            preferred_element_type=jnp.float32)
        s_pe = jnp.einsum("nd,md->nm",
                          q_pe,
                          k_pe,
                          preferred_element_type=jnp.float32)
        s = s_nope + s_pe
        s *= sm_scale
        if k_scale is not None:
            s *= k_scale
        if q_scale is not None:
            s *= q_scale

        q_span = (kv_len - q_len + bq_idx * bq_sz +
                  lax.broadcasted_iota(jnp.int32, s.shape, 0) // num_q_heads)
        k_span = bkv_idx * bkv_sz + lax.broadcasted_iota(jnp.int32, s.shape, 1)
        mask = q_span < k_span
        # TODO(jevinjiang, xiowei): reduce pages_per_seq based on sliding_window.
        if sliding_window is not None:
            mask = jnp.logical_or(mask, q_span - sliding_window >= k_span)

        if soft_cap is not None:
            s = soft_cap * jnp.tanh(s / soft_cap)
        s = jnp.where(mask, mask_value, s)
        s_rowmax = jnp.max(s, axis=1, keepdims=True)
        m_prev = load_with_init(head_m_ref, -jnp.inf)
        m_curr = jnp.maximum(m_prev, s_rowmax)
        head_m_ref[...] = m_curr
        p = jnp.exp(s - broadcast_minor(m_curr, s.shape))

        pv = jnp.einsum("nm,md->nd",
                        p,
                        kv_c,
                        preferred_element_type=jnp.float32)
        if v_scale is not None:
            pv *= v_scale

        p_rowsum = jnp.sum(p, axis=1, keepdims=True)
        exp_m_diff = jnp.exp(m_prev - m_curr)
        l_prev = load_with_init(head_l_ref, 0.0)
        l_curr = exp_m_diff * l_prev + p_rowsum
        head_l_ref[...] = l_curr
        o_prev = load_with_init(head_acc_ref, 0.0)
        o_curr = broadcast_minor(exp_m_diff, o_prev.shape) * o_prev + pv
        head_acc_ref[...] = o_curr

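    # Online-softmax recurrence implemented above (FlashAttention-2 style):
    #   m_i = max(m_{i-1}, rowmax(s_i))
    #   p_i = exp(s_i - m_i)
    #   l_i = exp(m_{i-1} - m_i) * l_{i-1} + rowsum(p_i)
    #   acc_i = exp(m_{i-1} - m_i) * acc_{i-1} + p_i @ V_i
    # The final output acc / l is produced once per bq block in process().
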
    def _async_copy(src, dst, sem, wait):
        if debug_mode:
            # Skip DMA if debug mode is enabled.
            return
        cp = pltpu.make_async_copy(src, dst, sem)
        if wait:
            cp.wait()
        else:
            cp.start()

    def _fetch_bkv(seq_idx, bkv_idx, bkv_sem_idx, *, wait=False):
        sem = sems.at[0, bkv_sem_idx]
        bkvc_vmem_ref = bkvc_x2_ref.at[bkv_sem_idx]
        bkvpe_vmem_ref = bkpe_x2_ref.at[bkv_sem_idx]

        reshaped_cache_kv_c_hbm_ref = cache_kv_c_hbm_ref.reshape(
            total_num_pages * page_size_per_kv_packing,
            *cache_kv_c_hbm_ref.shape[2:],
        )
        reshaped_cache_k_pe_hbm_ref = cache_k_pe_hbm_ref.reshape(
            total_num_pages * page_size_per_kv_packing,
            *cache_k_pe_hbm_ref.shape[2:],
        )
        kv_len = kv_lens_ref[seq_idx]
        kv_len_start = bkv_idx * bkv_sz
        kv_p_start = bkv_idx * bkv_p

        kv_left = kv_len - kv_len_start
        kv_left_per_kv_packing = cdiv(kv_left, kv_packing)
        page_indices_offset = seq_idx * pages_per_seq + kv_p_start

        debug_print(
            "[RPA debug]"
            f" -----------{'wait' if wait else 'start'}_fetch_bkv-----------")
        debug_print("[RPA debug] seq_idx={}", seq_idx)
        debug_print("[RPA debug] bkv_idx={}", bkv_idx)
        debug_print("[RPA debug] bkv_sem_idx={}", bkv_sem_idx)
        debug_print("[RPA debug] kv_len_start={}", kv_len_start)
        debug_print("[RPA debug] kv_p_start={}", kv_p_start)
        debug_print("[RPA debug] kv_left={}", kv_left)
        debug_print("[RPA debug] page_indices_offset={}", page_indices_offset)

        # Fetch effective kv from kv cache.
        def loop_body(i, _):
            sz_per_kv_packing = jnp.minimum(
                page_size_per_kv_packing,
                kv_left_per_kv_packing - i * page_size_per_kv_packing,
            )
            _async_copy(
                reshaped_cache_kv_c_hbm_ref.at[pl.ds(
                    page_indices_ref[page_indices_offset + i] *
                    page_size_per_kv_packing,
                    sz_per_kv_packing,
                )],
                bkvc_vmem_ref.at[pl.ds(i * page_size_per_kv_packing,
                                       sz_per_kv_packing)],
                sem,
                wait,
            )
            _async_copy(
                reshaped_cache_k_pe_hbm_ref.at[pl.ds(
                    page_indices_ref[page_indices_offset + i] *
                    page_size_per_kv_packing,
                    sz_per_kv_packing,
                )],
                bkvpe_vmem_ref.at[pl.ds(i * page_size_per_kv_packing,
                                        sz_per_kv_packing)],
                sem,
                wait,
            )
            debug_print(
                "[RPA debug] loop_body i={}, sz_per_kv_packing={}",
                i,
                sz_per_kv_packing,
            )

        actual_bkv_p = jnp.minimum(cdiv(kv_left, page_size), bkv_p)
        lax.fori_loop(
            0,
            actual_bkv_p,
            loop_body,
            None,  # init value
            unroll=False,
        )

    def _fetch_bq(seq_idx, bq_idx, bq_sem_idx, *, wait=False):
        sem = sems.at[1, bq_sem_idx]
        bq_nope_vmem_ref = bq_nope_x2_ref.at[bq_sem_idx]
        bq_rope_vmem_ref = bq_rope_x2_ref.at[bq_sem_idx]

        q_len_start = cu_q_lens_ref[seq_idx] + bq_idx * bq_sz
        q_end = cu_q_lens_ref[seq_idx + 1]
        sz = jnp.minimum(bq_sz, q_end - q_len_start)

        debug_print(
            "[RPA debug]"
            f" -----------{'wait' if wait else 'start'}_fetch_bq-----------")
        debug_print("[RPA debug] seq_idx={}", seq_idx)
        debug_print("[RPA debug] bq_idx={}", bq_idx)
        debug_print("[RPA debug] bq_sem_idx={}", bq_sem_idx)
        debug_print("[RPA debug] q_len_start={}", q_len_start)
        debug_print("[RPA debug] q_end={}", q_end)
        debug_print("[RPA debug] sz={}", sz)

        _async_copy(
            ql_nope_hbm_ref.at[pl.ds(q_len_start, sz)],
            bq_nope_vmem_ref.at[pl.ds(0, sz)],
            sem,
            wait,
        )

        _async_copy(
            q_pe_hbm_ref.at[pl.ds(q_len_start, sz)],
            bq_rope_vmem_ref.at[pl.ds(0, sz)],
            sem,
            wait,
        )

    def _send_bo(seq_idx, bo_idx, bo_sem_idx, *, wait=False):
        sem = sems.at[2, bo_sem_idx]
        vmem_ref = bo_x2_ref.at[bo_sem_idx]
        q_len_start = cu_q_lens_ref[seq_idx] + bo_idx * bq_sz
        q_end = cu_q_lens_ref[seq_idx + 1]
        sz = jnp.minimum(bq_sz, q_end - q_len_start)

        debug_print(
            "[RPA debug]"
            f" -----------{'wait' if wait else 'start'}_send_bo-----------")
        debug_print("[RPA debug] seq_idx={}", seq_idx)
        debug_print("[RPA debug] bo_idx={}", bo_idx)
        debug_print("[RPA debug] bo_sem_idx={}", bo_sem_idx)
        debug_print("[RPA debug] q_len_start={}", q_len_start)
        debug_print("[RPA debug] q_end={}", q_end)
        debug_print("[RPA debug] sz={}", sz)

        _async_copy(
            vmem_ref.at[pl.ds(0, sz)],
            o_hbm_ref.at[pl.ds(q_len_start, sz)],
            sem,
            wait,
        )

    def start_fetch_bkv(seq_idx, bkv_idx, bkv_sem_idx):
        return _fetch_bkv(seq_idx, bkv_idx, bkv_sem_idx)

    def wait_fetch_bkv(seq_idx, bkv_idx, bkv_sem_idx):
        return _fetch_bkv(seq_idx, bkv_idx, bkv_sem_idx, wait=True)

    def start_fetch_bq(seq_idx, bq_idx, bq_sem_idx):
        return _fetch_bq(seq_idx, bq_idx, bq_sem_idx)

    def wait_fetch_bq(seq_idx, bq_idx, bq_sem_idx):
        return _fetch_bq(seq_idx, bq_idx, bq_sem_idx, wait=True)

    def start_send_bo(seq_idx, bo_idx, bo_sem_idx):
        bo_ids_ref[bo_sem_idx] = seq_idx
        bo_ids_ref[bo_sem_idx + 2] = bo_idx
        _send_bo(seq_idx, bo_idx, bo_sem_idx)

    def wait_send_bo(bo_sem_idx):
        old_seq_idx = bo_ids_ref[bo_sem_idx]
        old_bo_idx = bo_ids_ref[bo_sem_idx + 2]

        @pl.when(jnp.logical_and(0 <= old_seq_idx, old_seq_idx <= seq_idx))
        def _():
            _send_bo(old_seq_idx, old_bo_idx, bo_sem_idx, wait=True)

    def load_bq(bq_sem_idx, *, actual_bq_sz=bq_sz):
        q_nope_ref = (bq_nope_x2_ref.bitcast(
            jnp.uint32).at[bq_sem_idx].reshape(
                bq_sz * num_q_heads_per_q_packing, lkv_dim))
        q_nope_vec = pltpu.bitcast(
            q_nope_ref[:actual_bq_sz * num_q_heads_per_q_packing],
            q_dtype,
        )
        q_rope_ref = (bq_rope_x2_ref.bitcast(
            jnp.uint32).at[bq_sem_idx].reshape(
                bq_sz * num_q_heads_per_q_packing, r_dim))
        q_rope_vec = pltpu.bitcast(
            q_rope_ref[:actual_bq_sz * num_q_heads_per_q_packing],
            q_dtype,
        )
        return q_nope_vec, q_rope_vec

    def load_bkv(bkv_sem_idx, *, bkvc_mask, bkpe_mask):
        # Unpack kv_packing values per 32-bit word by shifting, zero the
        # out-of-range rows via the masks, then bitcast back to the kv dtype.
        bitwidth = 32 // kv_packing
        repack_ty = jnp.dtype(f"uint{bitwidth}")
        bkvc_ref = (bkvc_x2_ref.bitcast(jnp.uint32).at[bkv_sem_idx].reshape(
            bkv_sz_per_kv_packing, lkv_dim))
        bkvc_vec = bkvc_ref[...]
        bkvc_vecs = []
        for i in range(kv_packing):
            masked_bkvc_vec = bkvc_vec >> (i * bitwidth)
            bkvc_vecs.append(masked_bkvc_vec)
        concated_bkvc_vec = jnp.concatenate(bkvc_vecs, axis=-1)
        concated_bkvc_vec = concated_bkvc_vec.reshape(bkv_sz, lkv_dim)
        concated_bkvc_vec = lax.select(bkvc_mask, concated_bkvc_vec,
                                       jnp.zeros_like(concated_bkvc_vec))
        concated_bkvc_vec = pltpu.bitcast(concated_bkvc_vec.astype(repack_ty),
                                          kv_dtype)

        bkpe_ref = (bkpe_x2_ref.bitcast(jnp.uint32).at[bkv_sem_idx].reshape(
            bkv_sz_per_kv_packing, r_dim))
        bkpe_vec = bkpe_ref[...]
        bkpe_vecs = []
        for i in range(kv_packing):
            masked_bkpe_vec = bkpe_vec >> (i * bitwidth)
            bkpe_vecs.append(masked_bkpe_vec)
        concated_bkpe_vec = jnp.concatenate(bkpe_vecs, axis=-1)
        concated_bkpe_vec = concated_bkpe_vec.reshape(bkv_sz, r_dim)
        concated_bkpe_vec = lax.select(bkpe_mask, concated_bkpe_vec,
                                       jnp.zeros_like(concated_bkpe_vec))
        concated_bkpe_vec = pltpu.bitcast(concated_bkpe_vec.astype(repack_ty),
                                          kv_dtype)

        return concated_bkvc_vec, concated_bkpe_vec

    def broadcast_minor(src, shape):
        if src.shape == shape:
            return src
        assert src.shape[:-1] == shape[:-1]
        assert src.shape[-1] % 128 == 0
        target_minor = align_to(shape[-1], src.shape[-1])
        # no-op concatenation.
        return jnp.concatenate(
            [src for _ in range(target_minor // src.shape[-1])],
            axis=-1)[..., :shape[-1]]

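    # broadcast_minor example (hypothetical shapes): a (N, 128) src broadcast
    # to (N, 512) concatenates src 4 times along the minor axis; the trailing
    # slice is a no-op whenever shape[-1] is a multiple of src.shape[-1].
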
    def process(static_q_len=None):
        num_bkv = cdiv(kv_len, bkv_sz)
        if static_q_len is None:
            actual_bq_sz = bq_sz
            num_bq = cdiv(q_len, actual_bq_sz)
        else:
            actual_bq_sz = min(bq_sz, static_q_len)
            num_bq = cdiv(static_q_len, actual_bq_sz)

        def get_next_bq_ids(seq_idx, bq_idx, bq_sem_idx):
            next_bq_idx = bq_idx + 1
            is_last_bq = next_bq_idx == num_bq
            next_bq_idx = lax.select(is_last_bq, 0, next_bq_idx)
            next_seq_idx = lax.select(is_last_bq, seq_idx + 1, seq_idx)
            next_bq_sem_idx = lax.select(bq_sem_idx == 0, 1, 0)
            return next_seq_idx, next_bq_idx, next_bq_sem_idx

        def get_next_bkv_ids(seq_idx, bq_idx, bkv_idx, bkv_sem_idx):
            next_bkv_idx = bkv_idx + 1
            is_last_bkv = next_bkv_idx == num_bkv
            next_bkv_idx = lax.select(is_last_bkv, 0, next_bkv_idx)
            next_bq_idx = lax.select(is_last_bkv, bq_idx + 1, bq_idx)
            is_last_bq = next_bq_idx == num_bq
            next_bq_idx = lax.select(is_last_bq, 0, next_bq_idx)
            next_seq_idx = lax.select(is_last_bq, seq_idx + 1, seq_idx)
            next_bkv_sem_idx = lax.select(bkv_sem_idx == 0, 1, 0)
            return next_seq_idx, next_bq_idx, next_bkv_idx, next_bkv_sem_idx

        def compute_with_bq(bq_idx, _):
            bq_sem_idx = sem_ids_ref[0]
            next_seq_idx, next_bq_idx, next_bq_sem_idx = get_next_bq_ids(
                seq_idx, bq_idx, bq_sem_idx)

            # Prefetch next bq
            @pl.when(next_seq_idx < num_seqs)
            def prefetch_next_bq():
                sem_ids_ref[0] = next_bq_sem_idx
                start_fetch_bq(next_seq_idx, next_bq_idx, next_bq_sem_idx)

            def compute_with_bkv(bkv_idx, _):
                # Create bitmask for KV.
                assert bkv_sz % kv_packing == 0
                actual_bkv_sz = jnp.minimum(bkv_sz, kv_len - bkv_idx * bkv_sz)
                bkvc_shape = (bkv_sz, lkv_dim)
                bkvc_mask = (lax.broadcasted_iota(jnp.int32, bkvc_shape, 0)
                             < actual_bkv_sz)
                bkpe_shape = (bkv_sz, r_dim)
                bkpe_mask = (lax.broadcasted_iota(jnp.int32, bkpe_shape, 0)
                             < actual_bkv_sz)

                # Get next bkv ids.
                bkv_sem_idx = sem_ids_ref[1]
                next_seq_idx, _, next_bkv_idx, next_bkv_sem_idx = get_next_bkv_ids(
                    seq_idx, bq_idx, bkv_idx, bkv_sem_idx)

                # Prefetch next bkv
                @pl.when(next_seq_idx < num_seqs)
                def prefetch_next_bkv():
                    sem_ids_ref[1] = next_bkv_sem_idx
                    start_fetch_bkv(next_seq_idx, next_bkv_idx,
                                    next_bkv_sem_idx)

                # Wait for cur bq if not ready yet
                @pl.when(bkv_idx == 0)
                def wait_cur_bq():
                    wait_fetch_bq(seq_idx, bq_idx, bq_sem_idx)

                # Wait for cur bkv
                wait_fetch_bkv(seq_idx, bkv_idx, bkv_sem_idx)

                debug_print(
                    "[RPA debug] -----------flash attention-----------")
                debug_print("[RPA debug] seq_idx={}", seq_idx)
                debug_print("[RPA debug] bq_idx={}", bq_idx)
                debug_print("[RPA debug] bkv_idx={}", bkv_idx)
                if debug_mode:
                    # Skip flash attention if debug mode is enabled.
                    return

                # Flash attention with cur bkv and bq
                bkvc, bkpe = load_bkv(bkv_sem_idx,
                                      bkvc_mask=bkvc_mask,
                                      bkpe_mask=bkpe_mask)
                bq_nope_vec, bq_pe_vec = load_bq(bq_sem_idx,
                                                 actual_bq_sz=actual_bq_sz)
                flash_attention(
                    bq_nope_vec,
                    bq_pe_vec,
                    bkvc,
                    bkpe,
                    bq_idx=bq_idx,
                    bkv_idx=bkv_idx,
                )

            lax.fori_loop(0, num_bkv, compute_with_bkv, None, unroll=False)

            # Load acc and calculate final output.
            acc = acc_ref[...]
            l = broadcast_minor(l_ref[...], acc.shape)  # noqa
            out = (lax.div(acc, l) if q_dtype == jnp.float32 else
                   (acc * pl.reciprocal(l, approx=True)).astype(q_dtype))

            # Wait for previous bo to be fully sent before storing new bo.
            bo_sem_idx = sem_ids_ref[2]
            sem_ids_ref[2] = lax.select(bo_sem_idx == 0, 1, 0)
            wait_send_bo(bo_sem_idx)

            # Store output from acc to bo.
            bo_x2_ref.at[bo_sem_idx].bitcast(jnp.int32).reshape(
                bq_sz * num_q_heads_per_q_packing,
                lkv_dim,
            )[...] = pltpu.bitcast(out, jnp.int32)

            # Send cur bo
            start_send_bo(seq_idx, bq_idx, bo_sem_idx)

        lax.fori_loop(0, num_bq, compute_with_bq, None, unroll=False)

    ### ------- Kernel start ------- ###

    @pl.when(seq_idx == 0)
    def prologue():
        start_fetch_bq(0, 0, 0)
        start_fetch_bkv(0, 0, 0)

    @pl.when(seq_idx < decode_end)
    def process_decode():
        process(static_q_len=1)

    @pl.when(jnp.logical_and(decode_end <= seq_idx, seq_idx < prefill_end))
    def process_prefill():
        process(static_q_len=chunk_prefill_size)

    @pl.when(jnp.logical_and(prefill_end <= seq_idx, seq_idx < mixed_end))
    def process_mixed():
        process()

    @pl.when(seq_idx == num_seqs - 1)
    def epilogue():
        for i in range(2):
            wait_send_bo(i)

    ### ------- Kernel end ------- ###

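# Double-buffering sketch of the kernel above: while block (bq_idx, bkv_idx)
# is being consumed, the next bq/bkv blocks are DMA-prefetched into the other
# half of the *_x2_ref scratch buffers; sem_ids_ref records which half is in
# flight, and the wait_* helpers block on the matching DMA semaphore before
# compute or buffer reuse.
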
def prepare_q_inputs(
    q: jax.Array,  # [max_num_tokens, actual_num_q_heads, actual_head_dim]
):
    max_num_tokens, actual_num_q_heads, actual_head_dim = q.shape
    q_packing = get_dtype_packing(q.dtype)
    num_q_heads = align_to(actual_num_q_heads, q_packing)
    head_dim = align_to(actual_head_dim, 128)
    q = jnp.pad(
        q.reshape(
            max_num_tokens,
            actual_num_q_heads,
            actual_head_dim,
        ),
        (
            (0, 0),
            (0, num_q_heads - actual_num_q_heads),
            (0, head_dim - actual_head_dim),
        ),
        constant_values=0,
    ).reshape(
        max_num_tokens,
        num_q_heads // q_packing,
        q_packing,
        head_dim,
    )
    return q

def prepare_kv_inputs(
    kv: jax.Array,  # [max_num_tokens, actual_head_dim]
):
    max_num_tokens, actual_head_dim = kv.shape
    kv_packing = get_dtype_packing(kv.dtype)
    assert max_num_tokens % kv_packing == 0
    head_dim = align_to(actual_head_dim, 128)

    kv = kv.reshape(max_num_tokens // kv_packing, kv_packing, actual_head_dim)
    kv = jnp.pad(kv, ((0, 0), (0, 0), (0, head_dim - actual_head_dim)),
                 constant_values=0)

    return kv

def prepare_outputs(
    out,  # [max_num_tokens, num_q_heads // q_packing, q_packing, head_dim]
    actual_num_q_heads: int,
    actual_head_dim: int,
):
    (
        max_num_tokens,
        num_q_heads_per_q_packing,
        q_packing,
        head_dim,
    ) = out.shape
    return out.reshape(
        max_num_tokens,
        num_q_heads_per_q_packing * q_packing,
        head_dim,
    )[:, :actual_num_q_heads, :actual_head_dim]

@functools.partial(
    jax.jit,
    static_argnames=(
        "sm_scale",
        "sliding_window",
        "soft_cap",
        "mask_value",
        "chunk_prefill_size",
        "num_kv_pages_per_block",
        "num_queries_per_block",
        "vmem_limit_bytes",
        "debug_mode",
    ),
    donate_argnames=("cache_kv_c", "cache_k_pe"),
)
def mla_ragged_paged_attention(
    ql_nope: jax.Array,  # [max_num_tokens, actual_num_q_heads, actual_lkv_dim]
    q_pe: jax.Array,  # [max_num_tokens, actual_num_q_heads, actual_r_dim]
    new_kv_c: jax.Array,  # [max_num_tokens, actual_lkv_dim]
    new_k_pe: jax.Array,  # [max_num_tokens, actual_r_dim]
    cache_kv_c: jax.
    Array,  # [total_num_pages, page_size_per_kv_packing, kv_packing, lkv_dim]
    cache_k_pe: jax.
    Array,  # [total_num_pages, page_size_per_kv_packing, kv_packing, r_dim]
    kv_lens: jax.Array,  # i32[max_num_seqs]
    page_indices: jax.Array,  # i32[max_num_seqs * pages_per_seq]
    cu_q_lens: jax.Array,  # i32[max_num_seqs + 1]
    distribution: jax.Array,  # i32[3]
    *,
    sm_scale: float = 1.0,
    sliding_window: int | None = None,
    soft_cap: float | None = None,
    mask_value: float | None = DEFAULT_MASK_VALUE,
    # Kernel optimization params.
    chunk_prefill_size: int | None = None,
    # Kernel tuning params.
    num_kv_pages_per_block: int | None = None,
    num_queries_per_block: int | None = None,
    vmem_limit_bytes: int | None = None,
    # Debug params.
    debug_mode: bool = False,
) -> tuple[
    jax.Array,  # [max_num_tokens, actual_num_q_heads, actual_lkv_dim]
    jax.
    Array,  # [total_num_pages, page_size_per_kv_packing, kv_packing, lkv_dim]
    jax.
    Array,  # [total_num_pages, page_size_per_kv_packing, kv_packing, r_dim]
]:
    """MLA ragged paged attention that supports mixed prefill and decode.

    Args:
      ql_nope: concatenated queries of all sequences.
      q_pe: concatenated rope queries of all sequences.
      new_kv_c: concatenated kv_c values of all sequences.
      new_k_pe: concatenated k_pe values of all sequences.
      cache_kv_c: the current kv_c cache.
      cache_k_pe: the current k_pe cache.
      kv_lens: the length of each sequence in the kv cache.
      page_indices: flattened page indices look-up table by (seq_id, page_id).
      cu_q_lens: the cumulative sum of the effective query lengths. Similar to
        kv_lens, only the first num_seqs + 1 values are valid.
      distribution: (i, j, k) represents that sequences[0:i] are decode-only,
        sequences[i:j] are chunked-prefill-only, and sequences[j:k] are mixed.
        k is also the total number of sequences.
      sm_scale: the softmax scale which will be applied to Q@K^T.
      sliding_window: the sliding window size for the attention.
      soft_cap: the logit soft cap for the attention.
      mask_value: mask value for the causal mask.
      chunk_prefill_size: the static query length used for chunked-prefill
        sequences.
      num_kv_pages_per_block: number of kv pages to be processed in one flash
        attention block in the pallas kernel.
      num_queries_per_block: number of queries to be processed in one flash
        attention block in the pallas kernel.
      vmem_limit_bytes: the vmem limit for the pallas kernel.
      debug_mode: if true, RPA does not issue any DMAs or run flash attention
        but prints debug info. Needs to be compiled with
        `--xla_tpu_enable_log_recorder`.

    Returns:
      The attention output plus the updated kv_c and k_pe caches.
    """
    # TODO(chengjiyao): Support both autotuning table and heuristic logic to
    # set these kernel block sizes.
    if num_kv_pages_per_block is None or num_queries_per_block is None:
        raise ValueError(
            "num_kv_pages_per_block and num_queries_per_block must be specified."
        )
    static_validate_inputs(
        ql_nope,
        q_pe,
        new_kv_c,
        new_k_pe,
        cache_kv_c,
        cache_k_pe,
        kv_lens,
        page_indices,
        cu_q_lens,
        distribution,
        sm_scale=sm_scale,
        sliding_window=sliding_window,
        soft_cap=soft_cap,
        mask_value=mask_value,
        chunk_prefill_size=chunk_prefill_size,
        num_kv_pages_per_block=num_kv_pages_per_block,
        num_queries_per_block=num_queries_per_block,
        vmem_limit_bytes=vmem_limit_bytes,
        debug_mode=debug_mode,
    )

    # TODO(chengjiyao): fuse kv cache update into the kernel.
    cache_kv_c, cache_k_pe = update_kv_cache(
        new_kv_c,
        new_k_pe,
        cache_kv_c,
        cache_k_pe,
        kv_lens,
        page_indices,
        cu_q_lens,
        distribution,
    )

    _, actual_num_q_heads, actual_lkv_dim = ql_nope.shape

    ql_nope = prepare_q_inputs(
        ql_nope
    )  # [max_num_tokens, num_q_heads_per_q_packing, q_packing, lkv_dim]
    q_pe = prepare_q_inputs(
        q_pe)  # [max_num_tokens, num_q_heads_per_q_packing, q_packing, r_dim]
    new_kv_c = prepare_kv_inputs(
        new_kv_c)  # [max_num_tokens_per_kv_packing, kv_packing, lkv_dim]
    new_k_pe = prepare_kv_inputs(
        new_k_pe)  # [max_num_tokens_per_kv_packing, kv_packing, r_dim]
    lkv_dim = new_kv_c.shape[-1]
    r_dim = new_k_pe.shape[-1]

    _, page_size_per_kv_packing, kv_packing, _ = cache_kv_c.shape
    page_size = page_size_per_kv_packing * kv_packing
    _, num_q_heads_per_q_packing, q_packing, _ = ql_nope.shape
    max_num_seqs = kv_lens.shape[0]
    num_page_indices = page_indices.shape[0]
    assert num_page_indices % max_num_seqs == 0
    num_q_heads = num_q_heads_per_q_packing * q_packing

    bkv_p = num_kv_pages_per_block
    bq_sz = num_queries_per_block
    bkv_sz_per_kv_packing = bkv_p * page_size_per_kv_packing
    grid = (distribution[2], )

    in_specs = [
        pl.BlockSpec(memory_space=pltpu.HBM),
        pl.BlockSpec(memory_space=pltpu.HBM),
        pl.BlockSpec(memory_space=pltpu.HBM),
        pl.BlockSpec(memory_space=pltpu.HBM),
        pl.BlockSpec(memory_space=pltpu.HBM),
        pl.BlockSpec(memory_space=pltpu.HBM),
    ]

    out_specs = [
        pl.BlockSpec(memory_space=pltpu.HBM),
        pl.BlockSpec(memory_space=pltpu.HBM),
        pl.BlockSpec(memory_space=pltpu.HBM),
    ]

    bkvc_double_buf = pltpu.VMEM(
        (2, bkv_sz_per_kv_packing, kv_packing, lkv_dim),
        cache_kv_c.dtype,
    )

    bkpe_double_buf = pltpu.VMEM(
        (2, bkv_sz_per_kv_packing, kv_packing, r_dim),
        cache_k_pe.dtype,
    )

    bq_nope_double_buf = pltpu.VMEM(
        (2, bq_sz, num_q_heads_per_q_packing, q_packing, lkv_dim),
        ql_nope.dtype,
    )

    bq_rope_double_buf = pltpu.VMEM(
        (2, bq_sz, num_q_heads_per_q_packing, q_packing, r_dim),
        q_pe.dtype,
    )

    bo_double_buf = bq_nope_double_buf

    l_scratch = pltpu.VMEM(
        (bq_sz * num_q_heads, 128),
        jnp.float32,
    )
    m_scratch = l_scratch

    acc_scratch = pltpu.VMEM(
        (bq_sz * num_q_heads, lkv_dim),
        jnp.float32,
    )

    scratch_shapes = [
        bkvc_double_buf,
        bkpe_double_buf,
        bq_nope_double_buf,
        bq_rope_double_buf,
        bo_double_buf,  # Double buffering for output block.
        # Semaphores for double buffering of bkv, bq, bo and bkv_update.
        pltpu.SemaphoreType.DMA((4, 2)),
        # Intermediate buffers per kv head for flash attention.
        l_scratch,
        m_scratch,
        acc_scratch,
    ]

    scalar_prefetches = (
        kv_lens,
        # TODO(jevinjiang): can we use ragged page_indices to save some smem?
        page_indices,
        cu_q_lens,
        distribution,
        # (bq_sem_idx, bkv_sem_idx, bo_sem_idx)
        jnp.zeros((3, ), jnp.int32),
        # (bo_sem_0_seq_idx, bo_sem_1_seq_idx, bo_sem_0_bo_idx, bo_sem_1_bo_idx)
        jnp.full((4, ), -1, jnp.int32),
        # (bkv_sem_0_seq_idx, bkv_sem_1_seq_idx, bkv_sem_0_offset,
        #  bkv_sem_1_offset, bkv_sem_0_sz, bkv_sem_1_sz)
        jnp.full((6, ), -1, jnp.int32),
    )

    scope_name = f"MLA-RPA-bq_{bq_sz}-bkvp_{bkv_p}-p_{page_size}"
    kernel = jax.named_scope(scope_name)(
        pl.pallas_call(
            functools.partial(
                _mla_ragged_paged_attention_kernel,
                sm_scale=sm_scale,
                sliding_window=sliding_window,
                soft_cap=soft_cap,
                mask_value=mask_value,
                chunk_prefill_size=chunk_prefill_size,
                bq_sz=bq_sz,
                bkv_p=bkv_p,
                debug_mode=debug_mode,
            ),
            grid_spec=pltpu.PrefetchScalarGridSpec(
                num_scalar_prefetch=len(scalar_prefetches),
                in_specs=in_specs,
                out_specs=out_specs,
                grid=grid,
                scratch_shapes=scratch_shapes,
            ),
            compiler_params=pltpu.CompilerParams(
                # TODO(jevinjiang): since each sequence depends on the previous
                # one, we need some extra work to support Megacore mode.
                dimension_semantics=("arbitrary", ),
                vmem_limit_bytes=vmem_limit_bytes,
            ),
            out_shape=[
                jax.ShapeDtypeStruct(shape=ql_nope.shape, dtype=ql_nope.dtype),
                jax.ShapeDtypeStruct(shape=cache_kv_c.shape,
                                     dtype=cache_kv_c.dtype),
                jax.ShapeDtypeStruct(shape=cache_k_pe.shape,
                                     dtype=cache_k_pe.dtype),
            ],
            input_output_aliases={
                7: 0,
                11: 1,
                12: 2
            },
            name=scope_name,
        ))

    output, updated_kv_c, updated_k_pe = kernel(
        *scalar_prefetches,
        ql_nope,
        q_pe,
        new_kv_c,
        new_k_pe,
        cache_kv_c,
        cache_k_pe,
    )
    output = prepare_outputs(
        output, actual_num_q_heads,
        actual_lkv_dim)  # [max_num_tokens, actual_num_q_heads, actual_lkv_dim]

    return output, updated_kv_c, updated_k_pe
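
For reference, a minimal invocation sketch (not part of the package; toy
shapes chosen to satisfy the validations above, assuming a TPU runtime):

import jax
import jax.numpy as jnp
from tpu_inference.kernels.mla.v1.kernel import mla_ragged_paged_attention

lkv_dim, r_dim, num_q_heads = 512, 64, 8   # r_dim is padded to 128 internally
page_size, kv_packing = 32, 2              # bfloat16 packs 2 values per word
max_num_seqs, pages_per_seq, total_pages = 2, 8, 64
dt = jnp.bfloat16
key = jax.random.PRNGKey(0)

ql_nope = jax.random.normal(key, (16, num_q_heads, lkv_dim), dt)
q_pe = jax.random.normal(key, (16, num_q_heads, r_dim), dt)
new_kv_c = jax.random.normal(key, (16, lkv_dim), dt)
new_k_pe = jax.random.normal(key, (16, r_dim), dt)
cache_kv_c = jnp.zeros(
    (total_pages, page_size // kv_packing, kv_packing, lkv_dim), dt)
cache_k_pe = jnp.zeros(
    (total_pages, page_size // kv_packing, kv_packing, 128), dt)

kv_lens = jnp.array([33, 0], jnp.int32)         # seq 0 holds 33 tokens
page_indices = jnp.arange(max_num_seqs * pages_per_seq, dtype=jnp.int32)
cu_q_lens = jnp.array([0, 1, 1], jnp.int32)     # seq 0 contributes 1 new token
distribution = jnp.array([1, 1, 1], jnp.int32)  # one decode-only sequence

out, cache_kv_c, cache_k_pe = mla_ragged_paged_attention(
    ql_nope, q_pe, new_kv_c, new_k_pe, cache_kv_c, cache_k_pe,
    kv_lens, page_indices, cu_q_lens, distribution,
    sm_scale=1.0, num_kv_pages_per_block=4, num_queries_per_block=4)
print(out.shape)  # (16, 8, 512)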