tpu-inference 0.11.1.dev202511150811__py3-none-any.whl

This diff shows the content of publicly available package versions as released to a supported registry, and is provided for informational purposes only.

Potentially problematic release: this version of tpu-inference might be problematic.

Files changed (179)
  1. tests/__init__.py +0 -0
  2. tests/core/__init__.py +0 -0
  3. tests/core/test_core_tpu.py +513 -0
  4. tests/core/test_disagg_executor.py +60 -0
  5. tests/core/test_disagg_utils.py +53 -0
  6. tests/core/test_dp_scheduler.py +899 -0
  7. tests/core/test_init.py +49 -0
  8. tests/kernels/__init__.py +0 -0
  9. tests/kernels/fused_moe_v1_test.py +105 -0
  10. tests/kernels/mla_v1_test.py +396 -0
  11. tests/kernels/quantized_matmul_kernel_test.py +191 -0
  12. tests/kernels/ragged_kv_cache_update_v2_test.py +234 -0
  13. tests/kernels/ragged_paged_attention_kernel_v2_test.py +400 -0
  14. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +549 -0
  15. tests/kernels/ragged_paged_attention_kernel_v3_test.py +504 -0
  16. tests/lora/__init__.py +0 -0
  17. tests/lora/conftest.py +32 -0
  18. tests/lora/test_bgmv.py +43 -0
  19. tests/lora/test_layers.py +654 -0
  20. tests/lora/test_lora.py +133 -0
  21. tests/lora/utils.py +96 -0
  22. tests/test_base.py +201 -0
  23. tests/test_envs.py +182 -0
  24. tests/test_quantization.py +836 -0
  25. tests/test_tpu_info.py +120 -0
  26. tests/test_utils.py +236 -0
  27. tpu_inference/__init__.py +34 -0
  28. tpu_inference/core/__init__.py +0 -0
  29. tpu_inference/core/core_tpu.py +786 -0
  30. tpu_inference/core/disagg_executor.py +118 -0
  31. tpu_inference/core/disagg_utils.py +51 -0
  32. tpu_inference/core/sched/__init__.py +0 -0
  33. tpu_inference/core/sched/dp_scheduler.py +523 -0
  34. tpu_inference/distributed/__init__.py +0 -0
  35. tpu_inference/distributed/jax_parallel_state.py +67 -0
  36. tpu_inference/distributed/tpu_connector.py +728 -0
  37. tpu_inference/distributed/utils.py +59 -0
  38. tpu_inference/env_override.py +9 -0
  39. tpu_inference/envs.py +107 -0
  40. tpu_inference/executors/__init__.py +0 -0
  41. tpu_inference/executors/ray_distributed_executor.py +362 -0
  42. tpu_inference/experimental/__init__.py +0 -0
  43. tpu_inference/experimental/llama3_jax_stashed.py +258 -0
  44. tpu_inference/kernels/__init__.py +0 -0
  45. tpu_inference/kernels/collectives/__init__.py +0 -0
  46. tpu_inference/kernels/collectives/all_gather_matmul.py +735 -0
  47. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +60 -0
  48. tpu_inference/kernels/collectives/util.py +47 -0
  49. tpu_inference/kernels/flash_attention/__init__.py +0 -0
  50. tpu_inference/kernels/flash_attention/kernel.py +772 -0
  51. tpu_inference/kernels/fused_moe/__init__.py +0 -0
  52. tpu_inference/kernels/fused_moe/v1/__init__.py +0 -0
  53. tpu_inference/kernels/fused_moe/v1/kernel.py +1035 -0
  54. tpu_inference/kernels/mla/__init__.py +0 -0
  55. tpu_inference/kernels/mla/v1/__init__.py +0 -0
  56. tpu_inference/kernels/mla/v1/kernel.py +1349 -0
  57. tpu_inference/kernels/quantized_matmul/__init__.py +0 -0
  58. tpu_inference/kernels/quantized_matmul/kernel.py +395 -0
  59. tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
  60. tpu_inference/kernels/quantized_matmul/util.py +58 -0
  61. tpu_inference/kernels/ragged_paged_attention/__init__.py +0 -0
  62. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +0 -0
  63. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +875 -0
  64. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +287 -0
  65. tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
  66. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +0 -0
  67. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1478 -0
  68. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1482 -0
  69. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4147 -0
  70. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +367 -0
  71. tpu_inference/kernels/ragged_paged_attention/v3/util.py +51 -0
  72. tpu_inference/layers/__init__.py +0 -0
  73. tpu_inference/layers/common/__init__.py +0 -0
  74. tpu_inference/layers/common/attention_interface.py +390 -0
  75. tpu_inference/layers/common/attention_metadata.py +34 -0
  76. tpu_inference/layers/common/binary_search.py +295 -0
  77. tpu_inference/layers/common/quant_methods.py +8 -0
  78. tpu_inference/layers/common/sharding.py +582 -0
  79. tpu_inference/layers/jax/__init__.py +0 -0
  80. tpu_inference/layers/jax/attention/__init__.py +0 -0
  81. tpu_inference/layers/jax/attention/attention.py +255 -0
  82. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +354 -0
  83. tpu_inference/layers/jax/attention/gpt_oss_attention.py +262 -0
  84. tpu_inference/layers/jax/attention/llama4_attention.py +153 -0
  85. tpu_inference/layers/jax/base.py +151 -0
  86. tpu_inference/layers/jax/constants.py +88 -0
  87. tpu_inference/layers/jax/layers.py +301 -0
  88. tpu_inference/layers/jax/misc.py +16 -0
  89. tpu_inference/layers/jax/moe/__init__.py +0 -0
  90. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +608 -0
  91. tpu_inference/layers/jax/moe/gpt_oss_moe.py +185 -0
  92. tpu_inference/layers/jax/moe/moe.py +209 -0
  93. tpu_inference/layers/jax/rope.py +280 -0
  94. tpu_inference/layers/jax/rope_interface.py +214 -0
  95. tpu_inference/layers/jax/sample/__init__.py +0 -0
  96. tpu_inference/layers/jax/sample/rejection_sampler.py +515 -0
  97. tpu_inference/layers/jax/sample/sampling.py +96 -0
  98. tpu_inference/layers/jax/sample/sampling_metadata.py +76 -0
  99. tpu_inference/layers/jax/transformer_block.py +107 -0
  100. tpu_inference/layers/vllm/__init__.py +0 -0
  101. tpu_inference/layers/vllm/attention.py +221 -0
  102. tpu_inference/layers/vllm/fused_moe.py +507 -0
  103. tpu_inference/layers/vllm/linear_common.py +186 -0
  104. tpu_inference/layers/vllm/quantization/__init__.py +39 -0
  105. tpu_inference/layers/vllm/quantization/awq.py +207 -0
  106. tpu_inference/layers/vllm/quantization/common.py +105 -0
  107. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +0 -0
  108. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +120 -0
  109. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +203 -0
  110. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +0 -0
  111. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +208 -0
  112. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +136 -0
  113. tpu_inference/layers/vllm/quantization/mxfp4.py +266 -0
  114. tpu_inference/layers/vllm/quantization/unquantized.py +386 -0
  115. tpu_inference/layers/vllm/sharding.py +230 -0
  116. tpu_inference/logger.py +10 -0
  117. tpu_inference/lora/__init__.py +0 -0
  118. tpu_inference/lora/torch_lora_ops.py +103 -0
  119. tpu_inference/lora/torch_punica_tpu.py +311 -0
  120. tpu_inference/mock/__init__.py +0 -0
  121. tpu_inference/mock/vllm_config_utils.py +28 -0
  122. tpu_inference/mock/vllm_envs.py +1219 -0
  123. tpu_inference/mock/vllm_logger.py +212 -0
  124. tpu_inference/mock/vllm_logging_utils.py +15 -0
  125. tpu_inference/models/__init__.py +0 -0
  126. tpu_inference/models/common/__init__.py +0 -0
  127. tpu_inference/models/common/model_loader.py +444 -0
  128. tpu_inference/models/jax/__init__.py +0 -0
  129. tpu_inference/models/jax/deepseek_v3.py +868 -0
  130. tpu_inference/models/jax/gpt_oss.py +492 -0
  131. tpu_inference/models/jax/jax_intermediate_tensor.py +79 -0
  132. tpu_inference/models/jax/llama3.py +375 -0
  133. tpu_inference/models/jax/llama4.py +629 -0
  134. tpu_inference/models/jax/llama_eagle3.py +333 -0
  135. tpu_inference/models/jax/phi3.py +376 -0
  136. tpu_inference/models/jax/qwen2.py +375 -0
  137. tpu_inference/models/jax/qwen2_5_vl.py +1103 -0
  138. tpu_inference/models/jax/qwen3.py +302 -0
  139. tpu_inference/models/jax/utils/__init__.py +0 -0
  140. tpu_inference/models/jax/utils/file_utils.py +96 -0
  141. tpu_inference/models/jax/utils/multi_modal_utils.py +163 -0
  142. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  143. tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +5 -0
  144. tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +6 -0
  145. tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +5 -0
  146. tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +6 -0
  147. tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +105 -0
  148. tpu_inference/models/jax/utils/quantization/quantization_utils.py +653 -0
  149. tpu_inference/models/jax/utils/weight_utils.py +529 -0
  150. tpu_inference/models/vllm/__init__.py +0 -0
  151. tpu_inference/models/vllm/vllm_model_wrapper.py +286 -0
  152. tpu_inference/models/vllm/vllm_model_wrapper_context.py +45 -0
  153. tpu_inference/platforms/__init__.py +2 -0
  154. tpu_inference/platforms/tpu_platform.py +269 -0
  155. tpu_inference/runner/__init__.py +0 -0
  156. tpu_inference/runner/block_table.py +122 -0
  157. tpu_inference/runner/compilation_manager.py +780 -0
  158. tpu_inference/runner/input_batch.py +435 -0
  159. tpu_inference/runner/kv_cache.py +132 -0
  160. tpu_inference/runner/kv_cache_manager.py +479 -0
  161. tpu_inference/runner/lora_utils.py +92 -0
  162. tpu_inference/runner/multimodal_manager.py +217 -0
  163. tpu_inference/runner/persistent_batch_manager.py +244 -0
  164. tpu_inference/runner/speculative_decoding_manager.py +248 -0
  165. tpu_inference/runner/structured_decoding_manager.py +88 -0
  166. tpu_inference/runner/tpu_runner.py +1620 -0
  167. tpu_inference/runner/utils.py +426 -0
  168. tpu_inference/spec_decode/__init__.py +0 -0
  169. tpu_inference/spec_decode/jax/__init__.py +0 -0
  170. tpu_inference/spec_decode/jax/eagle3.py +367 -0
  171. tpu_inference/tpu_info.py +77 -0
  172. tpu_inference/utils.py +317 -0
  173. tpu_inference/worker/__init__.py +0 -0
  174. tpu_inference/worker/tpu_worker.py +321 -0
  175. tpu_inference-0.11.1.dev202511150811.dist-info/METADATA +107 -0
  176. tpu_inference-0.11.1.dev202511150811.dist-info/RECORD +179 -0
  177. tpu_inference-0.11.1.dev202511150811.dist-info/WHEEL +5 -0
  178. tpu_inference-0.11.1.dev202511150811.dist-info/licenses/LICENSE +201 -0
  179. tpu_inference-0.11.1.dev202511150811.dist-info/top_level.txt +2 -0
tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py
@@ -0,0 +1,1482 @@
+ """
+ A variant of the TPU-friendly Ragged Paged Attention kernel, optimized for
+ head_dim = 64.
+ """
+
+ import functools
+
+ import jax
+ import jax.numpy as jnp
+ from jax import lax
+ from jax.experimental import pallas as pl
+ from jax.experimental.pallas import tpu as pltpu
+
+ from tpu_inference.kernels.ragged_paged_attention.v3.tuned_block_sizes_hd64 import \
+     get_tuned_block_sizes
+ from tpu_inference.kernels.ragged_paged_attention.v3.util import (
+     align_to, cdiv, get_dtype_packing)
+
+ DEFAULT_MASK_VALUE = -0.7 * float(jnp.finfo(jnp.dtype("float32")).max)
+
+ DEFAULT_VMEM_LIMIT_BYTES = 100 * 1024 * 1024
+
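A brief note on the mask constant: like other JAX flash-attention kernels, this one masks with a large finite negative number instead of -inf; the 0.7 factor plausibly leaves headroom so later arithmetic (subtracting row maxima, applying scales) cannot overflow, while masked logits still contribute exactly zero after softmax. Quick illustrative check, not package content:

    >>> import jax.numpy as jnp
    >>> jnp.exp(-0.7 * float(jnp.finfo(jnp.float32).max))
    Array(0., dtype=float32)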
+
+
+ # TODO(chengjiyao): refactor this hd64 variant and the original kernel to make
+ # sure most of the code is shared.
+ def ref_ragged_paged_attention_hd64(
+     queries: jax.Array,  # [max_num_tokens, actual_num_q_heads, actual_head_dim]
+     keys: jax.Array,  # [max_num_tokens, actual_num_kv_heads, actual_head_dim]
+     values: jax.Array,  # [max_num_tokens, actual_num_kv_heads, actual_head_dim]
+     kv_cache: jax.Array,  # [total_num_pages, page_size, num_kv_heads, kv_packing, actual_head_dim_x2]
+     kv_lens: jax.Array,  # i32[max_num_seqs]
+     page_indices: jax.Array,  # i32[max_num_seqs * pages_per_seq]
+     cu_q_lens: jax.Array,  # i32[max_num_seqs + 1]
+     distribution: jax.Array,  # i32[3]
+     attention_sink: jax.Array | None = None,  # f32[actual_num_q_heads]
+     *,
+     sm_scale: float = 1.0,
+     sliding_window: int | None = None,
+     soft_cap: float | None = None,
+     mask_value: float | None = DEFAULT_MASK_VALUE,
+     q_scale: float | None = None,
+     k_scale: float | None = None,
+     v_scale: float | None = None,
+ ):
+     if mask_value is None:
+         mask_value = DEFAULT_MASK_VALUE
+     dynamic_validate_inputs(
+         queries,
+         keys,
+         values,
+         kv_cache,
+         kv_lens,
+         page_indices,
+         cu_q_lens,
+         distribution,
+         attention_sink,
+         sm_scale=sm_scale,
+         sliding_window=sliding_window,
+         soft_cap=soft_cap,
+         mask_value=mask_value,
+         q_scale=q_scale,
+         k_scale=k_scale,
+         v_scale=v_scale,
+     )
+     actual_head_dim = queries.shape[2]
+     actual_num_q_heads = queries.shape[1]
+     actual_num_kv_heads = keys.shape[1]
+     assert actual_head_dim == 64
+     (
+         _,
+         page_size,
+         _,
+         kv_packing,
+         actual_head_dim_x2,
+     ) = kv_cache.shape
+
+     assert actual_num_q_heads % actual_num_kv_heads == 0
+     assert actual_head_dim_x2 == 128
+     assert get_dtype_packing(kv_cache.dtype) == kv_packing
+     actual_num_q_heads_per_kv_head = actual_num_q_heads // actual_num_kv_heads
+     padded_actual_num_kv_heads = align_to(actual_num_kv_heads, kv_packing)
+     max_num_seqs = kv_lens.shape[0]
+     num_page_indices = page_indices.shape[0]
+     assert num_page_indices % max_num_seqs == 0
+     pages_per_seq = num_page_indices // max_num_seqs
+
+     # prepare kv and queries
+     merged_kv = merge_kv(keys, values)
+     queries = jnp.pad(queries, ((0, 0), (0, 0), (0, 64)), constant_values=0.0)
+     outputs = []
+
+     for i in range(distribution[-1]):
+         q_start = cu_q_lens[i]
+         q_end = cu_q_lens[i + 1]
+         q_len = q_end - q_start
+
+         kv_len = kv_lens[i]
+         indices_start = i * pages_per_seq
+         indices_end = indices_start + cdiv(kv_len, page_size)
+         indices = page_indices[indices_start:indices_end]
+         q = queries[q_start:q_end, :, :]
+
+         # Update the kv cache.
+         assert kv_len - q_len >= 0
+         gathered_kv = kv_cache[indices]
+         gathered_shape = gathered_kv.shape
+         gathered_kv = gathered_kv.reshape(-1, *gathered_shape[-3:])
+         gathered_kv = gathered_kv.at[kv_len - q_len:kv_len].set(
+             merged_kv[q_start:q_end])
+         kv_cache = kv_cache.at[indices].set(
+             gathered_kv.reshape(gathered_shape))
+
+         kv = gathered_kv.reshape(
+             -1, padded_actual_num_kv_heads,
+             actual_head_dim_x2)[:, :actual_num_kv_heads, :]
+         kv = kv[:kv_len, :, :]
+         kv = jnp.repeat(kv, actual_num_q_heads_per_kv_head, axis=1)
+         if q_scale is not None:
+             q = q / q_scale
+             if jnp.issubdtype(kv.dtype, jnp.floating):
+                 dtype_info = jnp.finfo(kv.dtype)
+                 minval = float(dtype_info.min)
+                 maxval = float(dtype_info.max)
+                 q = jnp.clip(q, min=minval, max=maxval)
+             q = q.astype(kv.dtype)
+         attn = jnp.einsum("qhd,khd->hqk",
+                           q,
+                           kv,
+                           preferred_element_type=jnp.float32)
+         attn *= sm_scale
+         if k_scale is not None:
+             attn *= k_scale
+         if q_scale is not None:
+             attn *= q_scale
+
+         q_span = (kv_len - q_len) + jax.lax.broadcasted_iota(
+             jnp.int32, attn.shape, 1)
+         kv_span = jax.lax.broadcasted_iota(jnp.int32, attn.shape, 2)
+         mask = q_span < kv_span
+         if sliding_window is not None:
+             mask = jnp.logical_or(mask, q_span - sliding_window >= kv_span)
+         if soft_cap is not None:
+             attn = soft_cap * jnp.tanh(attn / soft_cap)
+         attn = jnp.where(mask, mask_value, attn)
+
+         if attention_sink is not None:
+             reshaped_attention_sink = attention_sink.reshape(
+                 actual_num_q_heads, 1, 1)
+             reshaped_attention_sink = jnp.repeat(reshaped_attention_sink,
+                                                  q_len,
+                                                  axis=1)
+             attn = jnp.concat([reshaped_attention_sink, attn], axis=2)
+             attn = jax.nn.softmax(attn, axis=-1).astype(kv.dtype)
+             attn = attn[..., 1:]
+         else:
+             attn = jax.nn.softmax(attn, axis=-1).astype(kv.dtype)
+
+         out = jnp.einsum("hqk,khd->qhd", attn, kv).astype(queries.dtype)
+         if v_scale is not None:
+             out *= v_scale
+
+         outputs.append(out)
+
+     result = jnp.concatenate(outputs, axis=0)
+     result = result[:, :, actual_head_dim:]
+     return result, kv_cache
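A reading of the final slice (annotation, not package content): `merge_kv` below packs K into lanes [0:64] and V into lanes [64:128] of the 128-wide minor dimension, and the queries are zero-padded in lanes [64:128], so the logits reduce to Q·K while the second einsum leaves the actual attention output in the upper half:

    # kv[..., :64] == K,  kv[..., 64:] == V      (head_dim = 64)
    # q[..., 64:] == 0   ->  logits = Q @ K.T only
    # out = einsum("hqk,khd->qhd", attn, kv)  ->  out[..., 64:] == attn @ V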
+
+
+
+ def get_smem_estimate_bytes(max_num_seqs, pages_per_seq):
+     total_bits = (
+         # kv_lens_ref: i32[max_num_seqs]
+         align_to(max_num_seqs, 128) * 32 +
+         # page_indices_ref: i32[max_num_seqs * pages_per_seq]
+         align_to(max_num_seqs * pages_per_seq, 128) * 32 +
+         # cu_q_lens_ref: i32[max_num_seqs + 1]
+         align_to(max_num_seqs + 1, 128) * 32 +
+         # distribution_ref: i32[3]
+         128 * 32 +
+         # sem_ids_ref: i32[3]
+         128 * 32 +
+         # bo_ids_ref: i32[4]
+         128 * 32 +
+         # bkv_update_ids_ref: i32[6]
+         128 * 32)
+     return cdiv(total_bits, 8)
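At illustrative sizes the scalar-memory footprint is small and dominated by the page_indices table (the value below is computed from the formula above, not taken from the package):

    >>> get_smem_estimate_bytes(max_num_seqs=256, pages_per_seq=32)
    37376    # ~36.5 KiB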
+
+
+ def get_vmem_estimate_bytes(
+     actual_num_kv_heads,
+     actual_num_q_heads_per_kv_head,
+     actual_head_dim,
+     bq_sz,
+     bkv_sz,
+     q_dtype,
+     kv_dtype,
+ ):
+     assert actual_head_dim == 64
+     q_packing = get_dtype_packing(q_dtype)
+     kv_packing = get_dtype_packing(kv_dtype)
+     num_q_heads_per_kv_head = align_to(actual_num_q_heads_per_kv_head,
+                                        q_packing)
+     num_kv_heads = align_to(actual_num_kv_heads, kv_packing)
+     head_dim = actual_head_dim * 2
+
+     total_bits = (
+         # bkv_x2_ref
+         (2 * bkv_sz * num_kv_heads * head_dim) * (32 // kv_packing) +
+         # bq_x2_ref + bo_x2_ref
+         2 * (2 * actual_num_kv_heads * bq_sz * num_q_heads_per_kv_head *
+              head_dim) * (32 // q_packing) +
+         # l_ref + m_ref
+         2 *
+         (actual_num_kv_heads * bq_sz * num_q_heads_per_kv_head * 128) * 32 +
+         # acc_ref
+         (actual_num_kv_heads * bq_sz * num_q_heads_per_kv_head * head_dim) *
+         32)
+     return cdiv(total_bits, 8)
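Likewise for VMEM: with, say, 8 kv heads, 4 q heads per kv head, bq_sz=128 and bkv_sz=1024 in bfloat16, the double-buffered scratch comes to 14 MiB, well under DEFAULT_VMEM_LIMIT_BYTES (illustrative numbers computed from the formula above):

    >>> get_vmem_estimate_bytes(
    ...     actual_num_kv_heads=8, actual_num_q_heads_per_kv_head=4,
    ...     actual_head_dim=64, bq_sz=128, bkv_sz=1024,
    ...     q_dtype=jnp.bfloat16, kv_dtype=jnp.bfloat16)
    14680064    # 14 MiB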
+
+
+ def get_kv_cache_shape(
+     total_num_pages,
+     page_size,
+     actual_num_kv_heads,
+     actual_head_dim,
+     kv_dtype,
+ ):
+     assert actual_head_dim == 64
+     kv_packing = get_dtype_packing(kv_dtype)
+     return (
+         total_num_pages,
+         page_size,
+         align_to(actual_num_kv_heads, kv_packing) // kv_packing,
+         kv_packing,
+         128,
+     )
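Concretely, for bfloat16 (packing = 2) eight kv heads fold into four packed groups, and each K/V pair shares the 128-wide minor dimension (illustrative call):

    >>> get_kv_cache_shape(total_num_pages=1000, page_size=16,
    ...                    actual_num_kv_heads=8, actual_head_dim=64,
    ...                    kv_dtype=jnp.bfloat16)
    (1000, 16, 4, 2, 128)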
+
+
+ def _ragged_paged_attention_kernel(
+     # Prefetch
+     kv_lens_ref,  # [max_num_seqs]
+     page_indices_ref,  # [max_num_seqs * pages_per_seq]
+     cu_q_lens_ref,  # [max_num_seqs + 1]
+     # TODO(jevinjiang): merge these into one so we can save SMEM.
+     distribution_ref,  # [3] (decode_end, prefill_end, mixed_end)
+     sem_ids_ref,  # [3] (bq_sem_idx, bkv_sem_idx, bo_sem_idx)
+     bo_ids_ref,  # [4] (bo_sem_0_seq_idx, bo_sem_1_seq_idx, bo_sem_0_bo_idx, bo_sem_1_bo_idx)
+     bkv_update_ids_ref,  # [6] (bkv_sem_0_seq_idx, bkv_sem_1_seq_idx, bkv_sem_0_offset, bkv_sem_1_offset, bkv_sem_0_sz, bkv_sem_1_sz)
+     # Input
+     q_hbm_ref,  # [actual_num_kv_heads, max_num_tokens, num_q_heads_per_kv_head // q_packing, q_packing, head_dim]
+     kv_hbm_ref,  # [max_num_tokens, num_kv_heads // kv_packing, kv_packing, actual_head_dim_x2]
+     kv_cache_hbm_ref,  # [total_num_pages, page_size, num_kv_heads // kv_packing, kv_packing, actual_head_dim_x2]
+     attention_sink_ref,  # [actual_num_kv_heads, num_q_heads_per_kv_head, 128]
+     # Output
+     o_hbm_ref,  # [actual_num_kv_heads, max_num_tokens, num_q_heads_per_kv_head // q_packing, q_packing, actual_head_dim_x2]
+     updated_kv_cache_hbm_ref,  # [total_num_pages, page_size, num_kv_heads // kv_packing, kv_packing, actual_head_dim_x2]
+     # Scratch
+     bkv_x2_ref,  # [2, bkv_sz, num_kv_heads // kv_packing, kv_packing, actual_head_dim_x2]
+     bq_x2_ref,  # [2, actual_num_kv_heads, bq_sz, num_q_heads_per_kv_head // q_packing, q_packing, actual_head_dim_x2]
+     bo_x2_ref,  # [2, actual_num_kv_heads, bq_sz, num_q_heads_per_kv_head // q_packing, q_packing, actual_head_dim_x2]
+     sems,  # [4, 2]
+     l_ref,  # [actual_num_kv_heads, bq_sz * num_q_heads_per_kv_head, 128]
+     m_ref,  # [actual_num_kv_heads, bq_sz * num_q_heads_per_kv_head, 128]
+     acc_ref,  # [actual_num_kv_heads, bq_sz * num_q_heads_per_kv_head, actual_head_dim_x2]
+     *,
+     sm_scale: float,
+     sliding_window: int | None = None,
+     soft_cap: float | None = None,
+     mask_value: float = DEFAULT_MASK_VALUE,
+     q_scale: float | None = None,
+     k_scale: float | None = None,
+     v_scale: float | None = None,
+     chunk_prefill_size: int | None = None,
+     bkv_p,
+     bq_sz,
+     debug_mode: bool = False,
+ ):
+     assert q_hbm_ref.shape == o_hbm_ref.shape
+     assert q_hbm_ref.shape[-1] == kv_cache_hbm_ref.shape[-1]
+     (
+         actual_num_kv_heads,
+         max_num_tokens,
+         num_q_heads_per_kv_head_per_packing,
+         q_packing,
+         actual_head_dim_x2,
+     ) = q_hbm_ref.shape
+     (
+         total_num_pages,
+         page_size,
+         num_kv_heads_per_kv_packing,
+         kv_packing,
+         _,
+     ) = kv_cache_hbm_ref.shape
+     max_num_seqs = kv_lens_ref.shape[0]
+     num_page_indices = page_indices_ref.shape[0]
+     assert num_page_indices % max_num_seqs == 0
+     pages_per_seq = num_page_indices // max_num_seqs
+     num_kv_heads = num_kv_heads_per_kv_packing * kv_packing
+     num_q_heads_per_kv_head = num_q_heads_per_kv_head_per_packing * q_packing
+     q_dtype = q_hbm_ref.dtype
+     kv_dtype = kv_cache_hbm_ref.dtype
+     assert o_hbm_ref.dtype == q_dtype
+     assert get_dtype_packing(q_dtype) == q_packing
+     assert get_dtype_packing(kv_dtype) == kv_packing
+     assert actual_head_dim_x2 == 128
+     bkv_sz = bkv_p * page_size
+     seq_idx = pl.program_id(0)
+     num_seqs = pl.num_programs(0)
+     decode_end = distribution_ref[0]
+     prefill_end = distribution_ref[1]
+     mixed_end = distribution_ref[2]
+
+     q_start = cu_q_lens_ref[seq_idx]
+     q_end = cu_q_lens_ref[seq_idx + 1]
+     q_len = q_end - q_start
+     kv_len = kv_lens_ref[seq_idx]
+
+     bkv_idx_start = 0 if sliding_window is None else jnp.maximum(
+         kv_len - sliding_window, 0) // bkv_sz
+
+     if sliding_window is None:
+         next_bkv_idx_start = 0
+     else:
+
+         def get_next_bkv_idx_start():
+             next_kv_len = kv_lens_ref[seq_idx + 1]
+             return jnp.maximum(next_kv_len - sliding_window, 0) // bkv_sz
+
+         next_bkv_idx_start = lax.cond(seq_idx + 1 < num_seqs,
+                                       get_next_bkv_idx_start, lambda: 0)
+
+     def debug_print(msg, *args):
+         if debug_mode:
+             pl.debug_print(msg, *args)
+
+     debug_print("[RPA debug] ======= In loop seq_idx={}", seq_idx)
+     debug_print("[RPA debug] num_seqs={}", num_seqs)
+     debug_print("[RPA debug] decode_end={}", decode_end)
+     debug_print("[RPA debug] prefill_end={}", prefill_end)
+     debug_print("[RPA debug] mixed_end={}", mixed_end)
+     debug_print("[RPA debug] bkv_p={}", bkv_p)
+     debug_print("[RPA debug] page_size={}", page_size)
+     debug_print("[RPA debug] pages_per_seq={}", pages_per_seq)
+     debug_print("[RPA debug] bkv_sz={}", bkv_sz)
+     debug_print("[RPA debug] bq_sz={}", bq_sz)
+     debug_print("[RPA debug] q_start={}", q_start)
+     debug_print("[RPA debug] q_end={}", q_end)
+     debug_print("[RPA debug] q_len={}", q_len)
+     debug_print("[RPA debug] kv_len={}", kv_len)
+
+     def flash_attention(
+         q,  # [actual_bq_sz * num_q_heads_per_kv_head, actual_head_dim_x2]
+         kv,  # [bkv_sz, actual_head_dim_x2]
+         *,
+         bq_idx,
+         bkv_idx,
+         kv_head_idx,
+     ):
+         assert len(q.shape) == 2
+         assert q.shape[0] % num_q_heads_per_kv_head == 0
+         assert q.shape[1] == actual_head_dim_x2
+         assert kv.shape == (bkv_sz, actual_head_dim_x2)
+         head_l_ref = l_ref.at[kv_head_idx, :q.shape[0]]
+         head_m_ref = m_ref.at[kv_head_idx, :q.shape[0]]
+         head_acc_ref = acc_ref.at[kv_head_idx, :q.shape[0]]
+
+         def load_with_init(ref, init_val):
+             return jnp.where(bkv_idx == bkv_idx_start,
+                              jnp.full_like(ref, init_val), ref[...])
+
+         # Follow FlashAttention-2 forward pass.
+         if q_scale is not None:
+             q = q / q_scale
+             if jnp.issubdtype(kv.dtype, jnp.floating):
+                 dtype_info = jnp.finfo(kv.dtype)
+                 minval = float(dtype_info.min)
+                 maxval = float(dtype_info.max)
+                 q = jnp.clip(q, min=minval, max=maxval)
+             q = q.astype(kv.dtype)
+
+         s = jnp.einsum("nd,md->nm", q, kv, preferred_element_type=jnp.float32)
+         s *= sm_scale
+         if k_scale is not None:
+             s *= k_scale
+         if q_scale is not None:
+             s *= q_scale
+
+         q_span = (kv_len - q_len + bq_idx * bq_sz +
+                   lax.broadcasted_iota(jnp.int32, s.shape, 0) //
+                   num_q_heads_per_kv_head)
+         k_span = bkv_idx * bkv_sz + lax.broadcasted_iota(jnp.int32, s.shape, 1)
+         mask = q_span < k_span
+
+         if soft_cap is not None:
+             s = soft_cap * jnp.tanh(s / soft_cap)
+         s += jnp.where(mask, mask_value, 0.0)
+         s_rowmax = jnp.max(s, axis=1, keepdims=True)
+
+         if attention_sink_ref is not None:
+             sinks = attention_sink_ref[kv_head_idx]
+             actual_bq_sz = q.shape[0] // num_q_heads_per_kv_head
+             m_prev_init = jnp.concat([sinks] * actual_bq_sz, axis=0)
+             m_prev = jnp.where(bkv_idx == bkv_idx_start, m_prev_init,
+                                head_m_ref[...])
+         else:
+             m_prev = load_with_init(head_m_ref, -jnp.inf)
+
+         m_curr = jnp.maximum(m_prev, s_rowmax)
+         head_m_ref[...] = m_curr
+         p = jnp.exp(s - broadcast_minor(m_curr, s.shape))
+
+         pv = jnp.einsum("nm,md->nd", p, kv, preferred_element_type=jnp.float32)
+         if v_scale is not None:
+             pv *= v_scale
+
+         p_rowsum = jnp.sum(p, axis=1, keepdims=True)
+         exp_m_diff = jnp.exp(m_prev - m_curr)
+         l_prev = load_with_init(head_l_ref, 1.0)
+         l_curr = exp_m_diff * l_prev + p_rowsum
+         head_l_ref[...] = l_curr
+         o_prev = load_with_init(head_acc_ref, 0.0)
+         o_curr = broadcast_minor(exp_m_diff, o_prev.shape) * o_prev + pv
+         head_acc_ref[...] = o_curr
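The bookkeeping above is the FlashAttention-2 forward recurrence; per query row and kv block it amounts to (a recap of the code, not extra package content):

    # m_new = max(m_prev, rowmax(s))                    running max
    # p     = exp(s - m_new)
    # l_new = exp(m_prev - m_new) * l_prev + rowsum(p)  running normalizer
    # acc   = exp(m_prev - m_new) * acc + p @ V         unnormalized output
    # out   = acc / l_new   (applied once, after the last kv block)

Note also that this in-kernel mask is causal only: `sliding_window` here advances bkv_idx_start to skip whole kv blocks, whereas ref_ragged_paged_attention_hd64 additionally masks per token with `mask = jnp.logical_or(mask, q_span - sliding_window >= kv_span)`; that same term would be the per-block refinement if exact window semantics are required (an observation, not a change to the shipped kernel).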
+
+     def _async_copy(src, dst, sem, wait):
+         if debug_mode:
+             # Skip DMA if debug mode is enabled.
+             return
+         cp = pltpu.make_async_copy(src, dst, sem)
+         if wait:
+             cp.wait()
+         else:
+             cp.start()
+
+     def _fetch_bkv(seq_idx, bkv_idx, bkv_sem_idx, *, wait=False):
+         sem = sems.at[0, bkv_sem_idx]
+         vmem_ref = bkv_x2_ref.at[bkv_sem_idx]
+
+         cache_hbm_shape = kv_cache_hbm_ref.shape
+         cache_hbm_ref = kv_cache_hbm_ref.reshape(
+             cache_hbm_shape[0] * cache_hbm_shape[1], *cache_hbm_shape[2:])
+         kv_len = kv_lens_ref[seq_idx]
+         kv_len_start = bkv_idx * bkv_sz
+         kv_p_start = bkv_idx * bkv_p
+         q_start = cu_q_lens_ref[seq_idx]
+         q_end = cu_q_lens_ref[seq_idx + 1]
+         q_len = q_end - q_start
+
+         kv_left = kv_len - kv_len_start
+         kv_left_frm_cache = jnp.maximum(kv_left - q_len, 0)
+         kv_left_frm_new = kv_left - kv_left_frm_cache
+         bkv_p_frm_cache = jnp.minimum(cdiv(kv_left_frm_cache, page_size),
+                                       bkv_p)
+         bkv_sz_frm_new = jnp.minimum(
+             jnp.maximum(bkv_sz - kv_left_frm_cache, 0), kv_left_frm_new)
+         page_indices_offset = seq_idx * pages_per_seq + kv_p_start
+
+         # Make sure the current bkv buffer is safe to overwrite.
+         wait_update_kv_cache(bkv_sem_idx)
+
+         debug_print(
+             "[RPA debug]"
+             f" -----------{'wait' if wait else 'start'}_fetch_bkv-----------")
+         debug_print("[RPA debug] seq_idx={}", seq_idx)
+         debug_print("[RPA debug] bkv_idx={}", bkv_idx)
+         debug_print("[RPA debug] bkv_sem_idx={}", bkv_sem_idx)
+         debug_print("[RPA debug] kv_len_start={}", kv_len_start)
+         debug_print("[RPA debug] kv_p_start={}", kv_p_start)
+         debug_print("[RPA debug] kv_left={}", kv_left)
+         debug_print("[RPA debug] kv_left_frm_cache={}", kv_left_frm_cache)
+         debug_print("[RPA debug] kv_left_frm_new={}", kv_left_frm_new)
+         debug_print("[RPA debug] bkv_p_frm_cache={}", bkv_p_frm_cache)
+         debug_print("[RPA debug] bkv_sz_frm_new={}", bkv_sz_frm_new)
+         debug_print("[RPA debug] page_indices_offset={}", page_indices_offset)
+
+         # Fetch effective kv from kv cache.
+         def loop_body(i, offset):
+             sz = jnp.minimum(page_size, kv_left_frm_cache - i * page_size)
+             _async_copy(
+                 cache_hbm_ref.at[pl.ds(
+                     page_indices_ref[page_indices_offset + i] * page_size,
+                     sz)],
+                 vmem_ref.at[pl.ds(i * page_size, sz)],
+                 sem,
+                 wait,
+             )
+             debug_print("[RPA debug] loop_body i={}, sz={}", i, sz)
+             return offset + sz
+
+         offset = lax.fori_loop(
+             0,
+             bkv_p_frm_cache,
+             loop_body,
+             0,  # offset
+             unroll=False,
+         )
+
+         # Fetch kv directly from new kv.
+         @pl.when(bkv_sz_frm_new > 0)
+         def _fetch_bkv_from_new_kv():
+             new_kv_len_start = q_end - kv_left_frm_new
+             debug_print("[RPA debug] new_kv_len_start={}", new_kv_len_start)
+             debug_print("[RPA debug] offset_in_bkv={}", offset)
+             _async_copy(
+                 kv_hbm_ref.at[pl.ds(new_kv_len_start, bkv_sz_frm_new)],
+                 vmem_ref.at[pl.ds(offset, bkv_sz_frm_new)],
+                 sem,
+                 wait,
+             )
+
+         return kv_len_start + offset, bkv_sz_frm_new
+
+     def _update_kv_cache(seq_idx,
+                          bkv_sem_idx,
+                          offset,
+                          update_sz,
+                          *,
+                          wait=False):
+         sem = sems.at[3, bkv_sem_idx]
+         vmem_ref = bkv_x2_ref.at[bkv_sem_idx]
+         bkv_id = offset // bkv_sz
+         kv_p_start = offset // page_size
+         kv_p_end = cdiv(offset + update_sz, page_size)
+         ignore = offset % page_size
+         p_ignore = kv_p_start - bkv_id * bkv_p
+         page_indices_offset = seq_idx * pages_per_seq + kv_p_start
+
+         cache_hbm_shape = updated_kv_cache_hbm_ref.shape
+         cache_hbm_ref = updated_kv_cache_hbm_ref.reshape(
+             cache_hbm_shape[0] * cache_hbm_shape[1], *cache_hbm_shape[2:])
+
+         debug_print(
+             "[RPA debug]"
+             f" -----------{'wait' if wait else 'start'}_update_kv_cache-----------"
+         )
+         debug_print("[RPA debug] seq_idx={}", seq_idx)
+         debug_print("[RPA debug] bkv_sem_idx={}", bkv_sem_idx)
+         debug_print("[RPA debug] offset={}", offset)
+         debug_print("[RPA debug] update_sz={}", update_sz)
+         debug_print("[RPA debug] bkv_id={}", bkv_id)
+         debug_print("[RPA debug] kv_p_start={}", kv_p_start)
+         debug_print("[RPA debug] kv_p_end={}", kv_p_end)
+         debug_print("[RPA debug] ignore={}", ignore)
+         debug_print("[RPA debug] p_ignore={}", p_ignore)
+         debug_print("[RPA debug] page_indices_offset={}", page_indices_offset)
+
+         def loop_body(i, states):
+             update_sz, ignore = states
+             sz = jnp.minimum(page_size - ignore, update_sz)
+
+             _async_copy(
+                 vmem_ref.at[pl.ds((p_ignore + i) * page_size + ignore, sz)],
+                 cache_hbm_ref.at[pl.ds(
+                     page_indices_ref[page_indices_offset + i] * page_size +
+                     ignore,
+                     sz,
+                 )],
+                 sem,
+                 wait,
+             )
+             debug_print("[RPA debug] loop_body i={}, sz={}", i, sz)
+             return update_sz - sz, 0
+
+         lax.fori_loop(
+             0,
+             kv_p_end - kv_p_start,
+             loop_body,
+             (update_sz, ignore),  # total transfer size
+             unroll=False,
+         )
+
+     def _fetch_bq(seq_idx, bq_idx, bq_sem_idx, *, wait=False):
+         sem = sems.at[1, bq_sem_idx]
+         vmem_ref = bq_x2_ref.at[bq_sem_idx]
+         q_len_start = cu_q_lens_ref[seq_idx] + bq_idx * bq_sz
+         q_end = cu_q_lens_ref[seq_idx + 1]
+         sz = jnp.minimum(bq_sz, q_end - q_len_start)
+
+         debug_print(
+             "[RPA debug]"
+             f" -----------{'wait' if wait else 'start'}_fetch_bq-----------")
+         debug_print("[RPA debug] seq_idx={}", seq_idx)
+         debug_print("[RPA debug] bq_idx={}", bq_idx)
+         debug_print("[RPA debug] bq_sem_idx={}", bq_sem_idx)
+         debug_print("[RPA debug] q_len_start={}", q_len_start)
+         debug_print("[RPA debug] q_end={}", q_end)
+         debug_print("[RPA debug] sz={}", sz)
+
+         _async_copy(
+             q_hbm_ref.at[:, pl.ds(q_len_start, sz)],
+             vmem_ref.at[:, pl.ds(0, sz)],
+             sem,
+             wait,
+         )
+
+     def _send_bo(seq_idx, bo_idx, bo_sem_idx, *, wait=False):
+         sem = sems.at[2, bo_sem_idx]
+         vmem_ref = bo_x2_ref.at[bo_sem_idx]
+         q_len_start = cu_q_lens_ref[seq_idx] + bo_idx * bq_sz
+         q_end = cu_q_lens_ref[seq_idx + 1]
+         sz = jnp.minimum(bq_sz, q_end - q_len_start)
+
+         debug_print(
+             "[RPA debug]"
+             f" -----------{'wait' if wait else 'start'}_send_bo-----------")
+         debug_print("[RPA debug] seq_idx={}", seq_idx)
+         debug_print("[RPA debug] bo_idx={}", bo_idx)
+         debug_print("[RPA debug] bo_sem_idx={}", bo_sem_idx)
+         debug_print("[RPA debug] q_len_start={}", q_len_start)
+         debug_print("[RPA debug] q_end={}", q_end)
+         debug_print("[RPA debug] sz={}", sz)
+
+         _async_copy(
+             vmem_ref.at[:, pl.ds(0, sz)],
+             o_hbm_ref.at[:, pl.ds(q_len_start, sz)],
+             sem,
+             wait,
+         )
+
+     def start_fetch_bkv(seq_idx, bkv_idx, bkv_sem_idx):
+         return _fetch_bkv(seq_idx, bkv_idx, bkv_sem_idx)
+
+     def wait_fetch_bkv(seq_idx, bkv_idx, bkv_sem_idx):
+         return _fetch_bkv(seq_idx, bkv_idx, bkv_sem_idx, wait=True)
+
+     def start_fetch_bq(seq_idx, bq_idx, bq_sem_idx):
+         return _fetch_bq(seq_idx, bq_idx, bq_sem_idx)
+
+     def wait_fetch_bq(seq_idx, bq_idx, bq_sem_idx):
+         return _fetch_bq(seq_idx, bq_idx, bq_sem_idx, wait=True)
+
+     def start_send_bo(seq_idx, bo_idx, bo_sem_idx):
+         bo_ids_ref[bo_sem_idx] = seq_idx
+         bo_ids_ref[bo_sem_idx + 2] = bo_idx
+         _send_bo(seq_idx, bo_idx, bo_sem_idx)
+
+     def wait_send_bo(bo_sem_idx):
+         old_seq_idx = bo_ids_ref[bo_sem_idx]
+         old_bo_idx = bo_ids_ref[bo_sem_idx + 2]
+
+         @pl.when(jnp.logical_and(0 <= old_seq_idx, old_seq_idx <= seq_idx))
+         def _():
+             _send_bo(old_seq_idx, old_bo_idx, bo_sem_idx, wait=True)
+
+     def start_update_kv_cache(seq_idx, bkv_sem_idx, offset, update_sz):
+         bkv_update_ids_ref[bkv_sem_idx] = seq_idx
+         bkv_update_ids_ref[bkv_sem_idx + 2] = offset
+         bkv_update_ids_ref[bkv_sem_idx + 4] = update_sz
+         _update_kv_cache(seq_idx, bkv_sem_idx, offset, update_sz)
+
+     def wait_update_kv_cache(bkv_sem_idx):
+         update_sz = bkv_update_ids_ref[bkv_sem_idx + 4]
+
+         @pl.when(update_sz > 0)
+         def _():
+             seq_idx = bkv_update_ids_ref[bkv_sem_idx]
+             offset = bkv_update_ids_ref[bkv_sem_idx + 2]
+             bkv_update_ids_ref[bkv_sem_idx + 4] = 0
+             _update_kv_cache(seq_idx,
+                              bkv_sem_idx,
+                              offset,
+                              update_sz,
+                              wait=True)
+
+     def load_bq(bq_sem_idx, kv_head_idx, *, actual_bq_sz=bq_sz):
+         q_ref = (bq_x2_ref.bitcast(
+             jnp.uint32).at[bq_sem_idx, kv_head_idx].reshape(
+                 bq_sz * num_q_heads_per_kv_head_per_packing,
+                 actual_head_dim_x2))
+         return pltpu.bitcast(
+             q_ref[:actual_bq_sz * num_q_heads_per_kv_head_per_packing],
+             q_dtype)
+
+     def strided_load(ref, start, step):
+         assert get_dtype_packing(ref.dtype) == 1
+         assert len(ref.shape) == 2
+         _, l = ref.shape  # noqa
+         assert l == 128
+         vec = ref[start::step]
+         return vec
+
+     def strided_load_bkv(bkv_sem_idx, start, step, *, bkv_mask):
+         assert start % kv_packing == 0
+         assert step % kv_packing == 0
+         start //= kv_packing
+         step //= kv_packing
+         kv_ref = (bkv_x2_ref.bitcast(jnp.uint32).at[bkv_sem_idx].reshape(
+             bkv_sz * step, actual_head_dim_x2))
+
+         kv = strided_load(kv_ref, start, step)
+         kv = lax.select(bkv_mask, kv, jnp.zeros_like(kv))
+         bitwidth = 32 // kv_packing
+         repack_ty = jnp.dtype(f"uint{bitwidth}")
+         lst = []
+         for i in range(0, kv_packing):
+             cur_kv = pltpu.bitcast((kv >> (i * bitwidth)).astype(repack_ty),
+                                    kv_dtype)
+             lst.append(cur_kv)
+         return lst
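strided_load_bkv undoes the head packing: with bfloat16, for example (kv_packing = 2), two kv heads share each 32-bit word, so the buffer is viewed as uint32, the rows of one packed head group are gathered with a stride, and each real head is recovered by shifting and bitcasting. Roughly, under that packing assumption:

    # word = [second head | first head]           one uint32 per packed pair
    # i = 0: bitcast(uint16(word >> 0),  bf16) -> first head in the pair
    # i = 1: bitcast(uint16(word >> 16), bf16) -> second head in the pair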
+
+     def broadcast_minor(src, shape):
+         if src.shape == shape:
+             return src
+         assert src.shape[:-1] == shape[:-1]
+         assert src.shape[-1] % 128 == 0
+         target_minor = align_to(shape[-1], src.shape[-1])
+         # no-op concatenation.
+         return jnp.concatenate(
+             [src for _ in range(target_minor // src.shape[-1])],
+             axis=-1)[..., :shape[-1]]
+
+     def process(static_q_len=None):
+         num_bkv = cdiv(kv_len, bkv_sz)
+         if static_q_len is None:
+             actual_bq_sz = bq_sz
+             num_bq = cdiv(q_len, actual_bq_sz)
+         else:
+             actual_bq_sz = min(bq_sz, static_q_len)
+             num_bq = cdiv(static_q_len, actual_bq_sz)
+
+         def get_next_bq_ids(seq_idx, bq_idx, bq_sem_idx):
+             next_bq_idx = bq_idx + 1
+             is_last_bq = next_bq_idx == num_bq
+             next_bq_idx = lax.select(is_last_bq, 0, next_bq_idx)
+             next_seq_idx = lax.select(is_last_bq, seq_idx + 1, seq_idx)
+             next_bq_sem_idx = lax.select(bq_sem_idx == 0, 1, 0)
+             return next_seq_idx, next_bq_idx, next_bq_sem_idx
+
+         def get_next_bkv_ids(seq_idx, bq_idx, bkv_idx, bkv_sem_idx):
+             next_bkv_idx = bkv_idx + 1
+             is_last_bkv = next_bkv_idx == num_bkv
+             next_bq_idx = lax.select(is_last_bkv, bq_idx + 1, bq_idx)
+             is_last_bq = next_bq_idx == num_bq
+             next_bq_idx = lax.select(is_last_bq, 0, next_bq_idx)
+             next_seq_idx = lax.select(is_last_bq, seq_idx + 1, seq_idx)
+             next_bkv_sem_idx = lax.select(bkv_sem_idx == 0, 1, 0)
+
+             next_bkv_idx = lax.select(
+                 is_last_bkv,
+                 lax.select(
+                     is_last_bq,
+                     next_bkv_idx_start,
+                     bkv_idx_start,
+                 ), next_bkv_idx)
+             return next_seq_idx, next_bq_idx, next_bkv_idx, next_bkv_sem_idx
+
+         def compute_with_bq(bq_idx, _):
+             bq_sem_idx = sem_ids_ref[0]
+             next_seq_idx, next_bq_idx, next_bq_sem_idx = get_next_bq_ids(
+                 seq_idx, bq_idx, bq_sem_idx)
+
+             # Prefetch next bq.
+             @pl.when(next_seq_idx < num_seqs)
+             def prefetch_next_bq():
+                 sem_ids_ref[0] = next_bq_sem_idx
+                 start_fetch_bq(next_seq_idx, next_bq_idx, next_bq_sem_idx)
+
+             def compute_with_bkv(bkv_idx, _):
+                 # Create bitmask for KV.
+                 assert bkv_sz % kv_packing == 0
+                 actual_bkv_sz = jnp.minimum(bkv_sz, kv_len - bkv_idx * bkv_sz)
+                 bkv_shape = (bkv_sz, actual_head_dim_x2)
+                 bkv_mask = lax.broadcasted_iota(jnp.int32, bkv_shape,
+                                                 0) < actual_bkv_sz
+
+                 # Get next bkv ids.
+                 bkv_sem_idx = sem_ids_ref[1]
+                 next_seq_idx, _, next_bkv_idx, next_bkv_sem_idx = get_next_bkv_ids(
+                     seq_idx, bq_idx, bkv_idx, bkv_sem_idx)
+
+                 # Prefetch next bkv.
+                 @pl.when(next_seq_idx < num_seqs)
+                 def prefetch_next_bkv():
+                     sem_ids_ref[1] = next_bkv_sem_idx
+                     start_fetch_bkv(next_seq_idx, next_bkv_idx,
+                                     next_bkv_sem_idx)
+
+                 # Wait for cur bq if not ready yet.
+                 @pl.when(bkv_idx == bkv_idx_start)
+                 def wait_cur_bq():
+                     wait_fetch_bq(seq_idx, bq_idx, bq_sem_idx)
+
+                 # Wait for cur bkv.
+                 offset, update_sz = wait_fetch_bkv(seq_idx, bkv_idx,
+                                                    bkv_sem_idx)
+
+                 # Start updating bkv to kv cache if applicable.
+                 # Only needed in first bq loop.
+                 @pl.when(jnp.logical_and(update_sz > 0, bq_idx == 0))
+                 def update_cur_bkv_to_cache():
+                     start_update_kv_cache(seq_idx, bkv_sem_idx, offset,
+                                           update_sz)
+
+                 debug_print(
+                     "[RPA debug] -----------flash attention-----------")
+                 debug_print("[RPA debug] seq_idx={}", seq_idx)
+                 debug_print("[RPA debug] bq_idx={}", bq_idx)
+                 debug_print("[RPA debug] bkv_idx={}", bkv_idx)
+                 if debug_mode:
+                     # Skip flash attention if debug mode is enabled.
+                     return
+
+                 # Flash attention with cur bkv and bq.
+                 for kv_head_start in range(0, actual_num_kv_heads, kv_packing):
+                     bkv_lst = strided_load_bkv(
+                         bkv_sem_idx,
+                         kv_head_start,
+                         num_kv_heads,
+                         bkv_mask=bkv_mask,
+                     )
+                     assert len(bkv_lst) == kv_packing
+                     for i in range(kv_packing):
+                         kv_head_idx = kv_head_start + i
+                         if kv_head_idx >= actual_num_kv_heads:
+                             break
+                         bq = load_bq(bq_sem_idx,
+                                      kv_head_idx,
+                                      actual_bq_sz=actual_bq_sz)
+                         bkv = bkv_lst[i]
+                         flash_attention(
+                             bq,
+                             bkv,
+                             bq_idx=bq_idx,
+                             bkv_idx=bkv_idx,
+                             kv_head_idx=kv_head_idx,
+                         )
+
+             lax.fori_loop(bkv_idx_start,
+                           num_bkv,
+                           compute_with_bkv,
+                           None,
+                           unroll=False)
+
+             # Load acc and calculate final output.
+             acc = acc_ref[...]
+             l = broadcast_minor(l_ref[...], acc.shape)  # noqa
+             out = (lax.div(acc, l) if q_dtype == jnp.float32 else
+                    (acc * pl.reciprocal(l, approx=True)).astype(q_dtype))
+
+             # Wait for previous bo to be fully sent before storing new bo.
+             bo_sem_idx = sem_ids_ref[2]
+             sem_ids_ref[2] = lax.select(bo_sem_idx == 0, 1, 0)
+             wait_send_bo(bo_sem_idx)
+
+             # Store output from acc to bo.
+             bo_x2_ref.at[bo_sem_idx].bitcast(jnp.int32).reshape(
+                 actual_num_kv_heads,
+                 bq_sz * num_q_heads_per_kv_head_per_packing,
+                 actual_head_dim_x2,
+             )[...] = pltpu.bitcast(out, jnp.int32)
+
+             # Send cur bo.
+             start_send_bo(seq_idx, bq_idx, bo_sem_idx)
+
+         lax.fori_loop(0, num_bq, compute_with_bq, None, unroll=False)
+
+     ### ------- Kernel start ------- ###
+
+     @pl.when(seq_idx == 0)
+     def prologue():
+         start_fetch_bq(0, 0, 0)
+         start_fetch_bkv(0, bkv_idx_start, 0)
+
+     @pl.when(seq_idx < decode_end)
+     def process_decode():
+         process(static_q_len=1)
+
+     @pl.when(jnp.logical_and(decode_end <= seq_idx, seq_idx < prefill_end))
+     def process_prefill():
+         process(static_q_len=chunk_prefill_size)
+
+     @pl.when(jnp.logical_and(prefill_end <= seq_idx, seq_idx < mixed_end))
+     def process_mixed():
+         process()
+
+     @pl.when(seq_idx == num_seqs - 1)
+     def epilogue():
+         for i in range(2):
+             wait_send_bo(i)
+             wait_update_kv_cache(i)
+
+     ### ------- Kernel end ------- ###
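Taken together, the kernel is a software pipeline over the grid's sequence index, with each of bq, bkv, and bo double-buffered (two VMEM slots and two DMA semaphores apiece), so the fetch of block i+1 overlaps compute on block i. A rough schedule (annotation, not package content):

    # prologue (seq 0):    start fetching bq[0] and the first bkv
    # each bkv step:       start fetch of next bkv -> wait current bkv
    #                      -> write-back to kv cache (first bq pass only)
    #                      -> flash_attention per kv head
    # each bq step:        start fetch of next bq -> normalize acc -> send bo
    # epilogue (last seq): drain outstanding bo sends and kv-cache updates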
+
+
+ def merge_kv(
+     k: jax.Array,  # [max_num_tokens, actual_num_kv_heads, actual_head_dim]
+     v: jax.Array,  # [max_num_tokens, actual_num_kv_heads, actual_head_dim]
+ ):
+     assert k.shape == v.shape
+     assert k.dtype == v.dtype
+     max_num_tokens, actual_num_kv_heads, actual_head_dim = k.shape
+     kv_packing = get_dtype_packing(k.dtype)
+     num_kv_heads = align_to(actual_num_kv_heads, kv_packing)
+     kv = jnp.pad(
+         jnp.concat([k, v], axis=-1),
+         (
+             (0, 0),
+             (0, num_kv_heads - actual_num_kv_heads),
+             (0, 0),
+         ),
+         constant_values=0,
+     ).reshape(
+         max_num_tokens,
+         num_kv_heads // kv_packing,
+         kv_packing,
+         actual_head_dim * 2,
+     )
+     return kv
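Shape-level effect (illustrative: 8 kv heads in bfloat16, so kv_packing = 2):

    >>> k = v = jnp.zeros((256, 8, 64), dtype=jnp.bfloat16)
    >>> merge_kv(k, v).shape    # K in lanes [:64], V in lanes [64:]
    (256, 4, 2, 128)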
+
+
+ def prepare_inputs(
+     q: jax.Array,  # [max_num_tokens, actual_num_q_heads, actual_head_dim]
+     k: jax.Array,  # [max_num_tokens, actual_num_kv_heads, actual_head_dim]
+     v: jax.Array,  # [max_num_tokens, actual_num_kv_heads, actual_head_dim]
+     attention_sink: jax.Array | None = None,  # f32[actual_num_q_heads]
+ ):
+     max_num_tokens, actual_num_q_heads, actual_head_dim = q.shape
+     actual_num_kv_heads = k.shape[1]
+     assert actual_num_q_heads % actual_num_kv_heads == 0
+     actual_num_q_heads_per_kv_head = actual_num_q_heads // actual_num_kv_heads
+     q_packing = get_dtype_packing(q.dtype)
+     num_q_heads_per_kv_head = align_to(actual_num_q_heads_per_kv_head,
+                                        q_packing)
+     head_dim = align_to(actual_head_dim, 128)
+     q = (
+         jnp.pad(
+             q.reshape(
+                 max_num_tokens,
+                 actual_num_kv_heads,
+                 actual_num_q_heads_per_kv_head,
+                 actual_head_dim,
+             ),
+             (
+                 (0, 0),
+                 (0, 0),
+                 (0, num_q_heads_per_kv_head - actual_num_q_heads_per_kv_head),
+                 (0, head_dim - actual_head_dim),
+             ),
+             constant_values=0,
+         ).reshape(
+             max_num_tokens,
+             actual_num_kv_heads,
+             num_q_heads_per_kv_head // q_packing,
+             q_packing,
+             head_dim,
+         )
+         # TODO(jevinjiang): Explore fusing swapping non-tiling axis to DMA.
+         .swapaxes(0, 1))
+     # TODO(kyuyeunk, chengjiyao): Add kv quantization here.
+     kv = merge_kv(k, v)
+
+     if attention_sink is not None:
+         attention_sink = attention_sink.reshape(
+             (-1, num_q_heads_per_kv_head, 1))
+         attention_sink = jnp.repeat(attention_sink, 128, -1)
+
+     return q, kv, attention_sink
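And the corresponding query-side transform (illustrative: 32 q heads, 8 kv heads, bfloat16, so q_packing = 2; k and v as in the merge_kv example):

    >>> q = jnp.zeros((256, 32, 64), dtype=jnp.bfloat16)
    >>> q2, kv, _ = prepare_inputs(q, k, v)
    >>> q2.shape    # (kv_heads, tokens, q_per_kv // packing, packing, 128)
    (8, 256, 2, 2, 128)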
+
+
+ def prepare_outputs(
+     out,  # [actual_num_kv_heads, max_num_tokens, num_q_heads_per_kv_head // q_packing, q_packing, actual_head_dim_x2]
+     actual_num_q_heads_per_kv_head: int,
+     actual_head_dim: int,
+ ):
+     (
+         actual_num_kv_heads,
+         max_num_tokens,
+         num_q_heads_per_kv_head_per_q_packing,
+         q_packing,
+         actual_head_dim_x2,
+     ) = out.shape
+     actual_num_q_heads = actual_num_q_heads_per_kv_head * actual_num_kv_heads
+     return (out.swapaxes(0, 1).reshape(
+         max_num_tokens,
+         actual_num_kv_heads,
+         num_q_heads_per_kv_head_per_q_packing * q_packing,
+         actual_head_dim_x2,
+     )[:, :, :actual_num_q_heads_per_kv_head,
+       actual_head_dim:].reshape(max_num_tokens, actual_num_q_heads,
+                                 actual_head_dim))
+
+
+ # Expect to run this validation during runtime.
+ def dynamic_validate_inputs(
+     queries: jax.Array,  # [max_num_tokens, actual_num_q_heads, actual_head_dim]
+     keys: jax.Array,  # [max_num_tokens, actual_num_kv_heads, actual_head_dim]
+     values: jax.Array,  # [max_num_tokens, actual_num_kv_heads, actual_head_dim]
+     kv_cache: jax.Array,  # [total_num_pages, page_size, num_kv_heads // kv_packing, kv_packing, head_dim]
+     kv_lens: jax.Array,  # i32[max_num_seqs]
+     page_indices: jax.Array,  # i32[max_num_seqs * pages_per_seq]
+     cu_q_lens: jax.Array,  # i32[max_num_seqs + 1]
+     distribution: jax.Array,  # i32[3]
+     attention_sink: jax.Array | None = None,  # f32[actual_num_q_heads]
+     *,
+     sm_scale: float = 1.0,
+     sliding_window: int | None = None,
+     soft_cap: float | None = None,
+     mask_value: float | None = DEFAULT_MASK_VALUE,
+     q_scale: float | None = None,
+     k_scale: float | None = None,
+     v_scale: float | None = None,
+     # Kernel optimization params.
+     chunk_prefill_size: int | None = None,
+     # Kernel tuning params.
+     num_kv_pages_per_block: int | None = None,
+     num_queries_per_block: int | None = None,
+     vmem_limit_bytes: int | None = None,
+     # Debug params.
+     debug_mode: bool = False,
+ ):
+     q, k, v = queries, keys, values
+     static_validate_inputs(
+         q,
+         k,
+         v,
+         kv_cache,
+         kv_lens,
+         page_indices,
+         cu_q_lens,
+         distribution,
+         attention_sink,
+         sm_scale=sm_scale,
+         sliding_window=sliding_window,
+         soft_cap=soft_cap,
+         mask_value=mask_value,
+         q_scale=q_scale,
+         k_scale=k_scale,
+         v_scale=v_scale,
+         chunk_prefill_size=chunk_prefill_size,
+         num_kv_pages_per_block=num_kv_pages_per_block,
+         num_queries_per_block=num_queries_per_block,
+         vmem_limit_bytes=vmem_limit_bytes,
+         debug_mode=debug_mode,
+     )
+     max_num_tokens = q.shape[0]
+     total_num_pages = kv_cache.shape[0]
+     page_size = kv_cache.shape[1]
+     max_num_seqs = kv_lens.shape[0]
+     num_page_indices = page_indices.shape[0]
+     assert num_page_indices % max_num_seqs == 0
+     pages_per_seq = num_page_indices // max_num_seqs
+
+     i, j, k = distribution
+     if not (i <= j <= k):
+         raise ValueError(f"Invalid distribution: {distribution=}")
+
+     if k > max_num_seqs:
+         raise ValueError(f"num_seqs={k} must be <= {max_num_seqs=}")
+
+     if cu_q_lens[k] > max_num_tokens:
+         raise ValueError(
+             f"Total q tokens {cu_q_lens[k]} must be <= {max_num_tokens=}.")
+     for i in range(k):
+         q_len = cu_q_lens[i + 1] - cu_q_lens[i]
+         kv_len = kv_lens[i]
+         if not (0 < q_len <= kv_len):
+             raise ValueError(
+                 f"Require 0 < {q_len=} <= {kv_len=} at sequence {i}.")
+         page_cnt = cdiv(kv_len, page_size)
+         if page_cnt > pages_per_seq:
+             raise ValueError(
+                 f"Require {page_cnt=} <= {pages_per_seq=} at sequence {i} where"
+                 f" {kv_len=} and {page_size=}.")
+         for p in range(page_cnt):
+             page_idx = page_indices[i * pages_per_seq + p]
+             if not (0 <= page_idx < total_num_pages):
+                 raise ValueError(
+                     f"Require 0 <= {page_idx=} < {total_num_pages=} at sequence"
+                     f" {i} where {kv_len=} and {page_size=}.")
+
+
+ # Expect to run this validation during compile time.
+ def static_validate_inputs(
+     queries: jax.Array,  # [max_num_tokens, actual_num_q_heads, actual_head_dim]
+     keys: jax.Array,  # [max_num_tokens, actual_num_kv_heads, actual_head_dim]
+     values: jax.Array,  # [max_num_tokens, actual_num_kv_heads, actual_head_dim]
+     kv_cache: jax.Array,  # [total_num_pages, page_size, num_kv_heads // kv_packing, kv_packing, actual_head_dim_x2]
+     kv_lens: jax.Array,  # i32[max_num_seqs]
+     page_indices: jax.Array,  # i32[max_num_seqs * pages_per_seq]
+     cu_q_lens: jax.Array,  # i32[max_num_seqs + 1]
+     distribution: jax.Array,  # i32[3]
+     attention_sink: jax.Array | None = None,  # f32[actual_num_q_heads]
+     *,
+     sm_scale: float = 1.0,
+     sliding_window: int | None = None,
+     soft_cap: float | None = None,
+     mask_value: float | None = DEFAULT_MASK_VALUE,
+     q_scale: float | None = None,
+     k_scale: float | None = None,
+     v_scale: float | None = None,
+     # Kernel optimization params.
+     chunk_prefill_size: int | None = None,
+     # Kernel tuning params.
+     num_kv_pages_per_block: int | None = None,
+     num_queries_per_block: int | None = None,
+     vmem_limit_bytes: int | None = None,
+     # Debug params.
+     debug_mode: bool = False,
+ ):
+     """Validate inputs to the RPA kernel statically."""
+     q, k, v = queries, keys, values
+     if not (len(q.shape) == len(k.shape) == len(v.shape) == 3):
+         raise ValueError(
+             f"Expected 3D array for {q.shape=}, {k.shape=}, {v.shape=}")
+     if k.shape != v.shape:
+         raise ValueError(f"Expected {k.shape=} to be equal to {v.shape=}")
+     if not (q.shape[0] == k.shape[0] == v.shape[0]):
+         raise ValueError(
+             f"Expected {q.shape[0]=} to be equal to {k.shape[0]=} and {v.shape[0]=}"
+         )
+     if not (q.shape[2] == k.shape[2] == v.shape[2]):
+         raise ValueError(
+             f"Expected {q.shape[2]=} to be equal to {k.shape[2]=} and {v.shape[2]=}"
+         )
+     if attention_sink is not None:
+         if attention_sink.shape[0] != q.shape[1]:
+             raise ValueError(
+                 f"Expected {attention_sink.shape[0]=} to be equal to"
+                 f" {q.shape[1]=} (num_q_heads).")
+         if attention_sink.dtype != jnp.float32:
+             raise ValueError(
+                 f"Expected {attention_sink.dtype=} to be equal to {jnp.float32=}."
+             )
+
+     actual_head_dim = q.shape[2]
+     if actual_head_dim != 64:
+         raise ValueError(f"Expected {actual_head_dim=} to be 64.")
+     actual_num_q_heads = q.shape[1]
+     actual_num_kv_heads = k.shape[1]
+
+     if actual_num_q_heads % actual_num_kv_heads != 0:
+         raise ValueError(f"Expected {actual_num_q_heads=} to be divisible by"
+                          f" {actual_num_kv_heads=}.")
+
+     (
+         _,
+         page_size,
+         num_kv_heads_per_kv_packing,
+         kv_packing,
+         actual_head_dim_x2,
+     ) = kv_cache.shape
+
+     if actual_head_dim_x2 != 128:
+         raise ValueError(f"Expected {actual_head_dim_x2=} to be equal to 128.")
+     # Note: we expect the kv quantization to happen outside of the RPA kernel.
+     if not (kv_cache.dtype == k.dtype == v.dtype):
+         raise ValueError(
+             f"Expected {kv_cache.dtype=} to be equal to {k.dtype=} and {v.dtype=}."
+         )
+     # Integer kv quantization is currently not supported.
+     if not jnp.issubdtype(kv_cache.dtype, jnp.floating):
+         raise ValueError(
+             f"Expected {kv_cache.dtype=} to be a floating-point dtype.")
+     if kv_packing != get_dtype_packing(kv_cache.dtype):
+         raise ValueError(
+             f"{kv_packing=} does not match with {kv_cache.dtype=}")
+
+     num_kv_heads = num_kv_heads_per_kv_packing * kv_packing
+     if align_to(actual_num_kv_heads, kv_packing) != num_kv_heads:
+         raise ValueError(
+             f"Invalid {num_kv_heads=}, {actual_num_kv_heads=}, {kv_packing=}")
+
+     if not (jnp.int32 == kv_lens.dtype == page_indices.dtype == cu_q_lens.dtype
+             == distribution.dtype):
+         raise ValueError(
+             f"Expected int32 dtype for {kv_lens.dtype=}, {page_indices.dtype=},"
+             f" {cu_q_lens.dtype=}, {distribution.dtype=}")
+
+     if not (len(kv_lens.shape) == len(page_indices.shape) == len(
+             cu_q_lens.shape) == 1):
+         raise ValueError(
+             f"Expected 1D array for {kv_lens.shape=}, {page_indices.shape=},"
+             f" {cu_q_lens.shape=}")
+
+     max_num_seqs = kv_lens.shape[0]
+     num_page_indices = page_indices.shape[0]
+     if num_page_indices % max_num_seqs != 0:
+         raise ValueError(
+             f"Expected {num_page_indices=} to be divisible by {max_num_seqs=}."
+         )
+     if cu_q_lens.shape != (max_num_seqs + 1, ):
+         raise ValueError(
+             f"Expected {cu_q_lens.shape=} to be ({max_num_seqs + 1},).")
+     if distribution.shape != (3, ):
+         raise ValueError(f"Expected {distribution.shape=} to be (3,).")
+
+     if page_size % kv_packing != 0:
+         raise ValueError(f"{page_size=} must be divisible by {kv_packing=}.")
+     if sliding_window is not None and sliding_window <= 0:
+         raise ValueError(f"{sliding_window=} must be positive.")
+     if soft_cap is not None and soft_cap == 0.0:
+         raise ValueError(f"{soft_cap=} must not be 0.0.")
+     if chunk_prefill_size is not None and chunk_prefill_size <= 0:
+         raise ValueError(f"{chunk_prefill_size=} must be positive.")
+     if num_kv_pages_per_block is not None:
+         if num_kv_pages_per_block <= 0:
+             raise ValueError(f"{num_kv_pages_per_block=} must be positive.")
+     if num_queries_per_block is not None:
+         if num_queries_per_block <= 0:
+             raise ValueError(f"{num_queries_per_block=} must be positive.")
+     if vmem_limit_bytes is not None and vmem_limit_bytes <= 0:
+         raise ValueError(f"{vmem_limit_bytes=} must be positive.")
+
+     # No constraints for the following inputs.
+     del sm_scale
+     del mask_value
+     del q_scale
+     del k_scale
+     del v_scale
+     del debug_mode
+
1223
+
1224
+ @functools.partial(
1225
+ jax.jit,
1226
+ static_argnames=(
1227
+ "sm_scale",
1228
+ "sliding_window",
1229
+ "soft_cap",
1230
+ "mask_value",
1231
+ "q_scale",
1232
+ "k_scale",
1233
+ "v_scale",
1234
+ "chunk_prefill_size",
1235
+ "num_kv_pages_per_block",
1236
+ "num_queries_per_block",
1237
+ "vmem_limit_bytes",
1238
+ "debug_mode",
1239
+ ),
1240
+ donate_argnames=("kv_cache", ),
1241
+ )
1242
+ def ragged_paged_attention_hd64(
1243
+ queries: jax.
1244
+ Array, # [max_num_tokens, actual_num_q_heads, actual_head_dim]
1245
+ keys: jax.Array, # [max_num_tokens, actual_num_kv_heads, actual_head_dim]
1246
+ values: jax.
1247
+ Array, # [max_num_tokens, actual_num_kv_heads, actual_head_dim]
1248
+ kv_cache: jax.
1249
+ Array, # [total_num_pages, page_size, num_kv_heads // kv_packing, kv_packing, actual_head_dim_x2]
1250
+ kv_lens: jax.Array, # i32[max_num_seqs]
1251
+ page_indices: jax.Array, # i32[max_num_seqs * pages_per_seq]
1252
+ cu_q_lens: jax.Array, # i32[max_num_seqs + 1]
1253
+ distribution: jax.Array, # i32[3]
1254
+ attention_sink: jax.Array | None = None, # f32[actual_num_q_heads]
1255
+ *,
1256
+ sm_scale: float = 1.0,
1257
+ sliding_window: int | None = None,
1258
+ soft_cap: float | None = None,
1259
+ mask_value: float | None = DEFAULT_MASK_VALUE,
1260
+ q_scale: float | None = None,
1261
+ k_scale: float | None = None,
1262
+ v_scale: float | None = None,
1263
+ # Kernel optimization params.
1264
+ chunk_prefill_size: int | None = None,
1265
+ # Kernel tuning params.
1266
+ num_kv_pages_per_block: int | None = None,
1267
+ num_queries_per_block: int | None = None,
1268
+ vmem_limit_bytes: int | None = None,
1269
+ # Debug params.
1270
+ debug_mode: bool = False,
1271
+ ):
+     """A special ragged paged attention version for head_dim=64 that supports
+     mixed prefill and decode. Assumes q, k, and v share the same actual head
+     size.
+
+     Args:
+       queries: all sequences' queries, concatenated.
+       keys: all sequences' keys (quantized), concatenated.
+       values: all sequences' values (quantized), concatenated.
+       kv_cache: paged KV cache with a TPU-friendly shape.
+       kv_lens: padded kv lengths. Only the first num_seqs values are valid.
+       page_indices: flattened page-indices look-up table, indexed by
+         (seq_id, page_id).
+       cu_q_lens: the cumulative sum of the effective query lengths. Similar to
+         kv_lens, only the first num_seqs+1 values are valid.
+       distribution: (i, j, k) represents that sequences[0:i] are decode-only,
+         sequences[i:j] are chunked-prefill-only, and sequences[j:k] are mixed.
+         k is also the total number of sequences.
+       attention_sink: optional attention sink for each q head.
+       sm_scale: the softmax scale which will be applied to Q@K^T.
+       sliding_window: the sliding window size for the attention.
+       soft_cap: the logit soft cap for the attention.
+       mask_value: mask value for the causal mask.
+       q_scale: the scale for the query.
+       k_scale: the scale for the key cache.
+       v_scale: the scale for the value cache.
+       chunk_prefill_size: optional chunk size for prefill (kernel optimization
+         param).
+       num_kv_pages_per_block: number of kv pages to be processed in one flash
+         attention block in the pallas kernel.
+       num_queries_per_block: number of queries to be processed in one flash
+         attention block in the pallas kernel.
+       vmem_limit_bytes: the vmem limit for the pallas kernel.
+       debug_mode: if true, RPA does not issue any DMAs or run flash attention,
+         but prints debug info instead. Requires compiling with
+         `--xla_tpu_enable_log_recorder`.
+
+     Returns:
+       The output of the attention.
+     """
+     q, k, v = queries, keys, values
+     static_validate_inputs(
+         q,
+         k,
+         v,
+         kv_cache,
+         kv_lens,
+         page_indices,
+         cu_q_lens,
+         distribution,
+         attention_sink,
+         sm_scale=sm_scale,
+         sliding_window=sliding_window,
+         soft_cap=soft_cap,
+         mask_value=mask_value,
+         q_scale=q_scale,
+         k_scale=k_scale,
+         v_scale=v_scale,
+         chunk_prefill_size=chunk_prefill_size,
+         num_kv_pages_per_block=num_kv_pages_per_block,
+         num_queries_per_block=num_queries_per_block,
+         vmem_limit_bytes=vmem_limit_bytes,
+     )
+
+     actual_num_q_heads = q.shape[1]
+     actual_head_dim = q.shape[2]
+     actual_num_kv_heads = k.shape[1]
+
+     actual_num_q_heads_per_kv_head = actual_num_q_heads // actual_num_kv_heads
+     q, kv, attention_sink = prepare_inputs(q, k, v, attention_sink)
+     (
+         _,
+         max_num_tokens,
+         num_q_heads_per_kv_head_per_q_packing,
+         q_packing,
+         head_dim,
+     ) = q.shape
+     page_size = kv_cache.shape[1]
+     max_num_seqs = kv_lens.shape[0]
+     num_page_indices = page_indices.shape[0]
+     assert num_page_indices % max_num_seqs == 0
+     pages_per_seq = num_page_indices // max_num_seqs
+     num_q_heads_per_kv_head = num_q_heads_per_kv_head_per_q_packing * q_packing
+
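As a concrete (made-up) instance of the bookkeeping above: with 8 q heads, 2 kv heads, and a q array that prepare_inputs packs two-way (q_packing == 2 is an assumption for bf16), the numbers work out as follows:

    actual_num_q_heads_per_kv_head = 8 // 2  # 4 q heads share each kv head
    # packed q shape: (_, max_num_tokens, 2, 2, head_dim), hence
    num_q_heads_per_kv_head = 2 * 2          # per-packing dim times q_packing
    pages_per_seq = 64 // 4                  # 64 page indices over 4 sequences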
+     bkv_p = num_kv_pages_per_block
+     bq_sz = num_queries_per_block
+     # Fall back to the tuned block-size pair when either knob is unset.
+     if bq_sz is None or bkv_p is None:
+         bkv_p, bq_sz = get_tuned_block_sizes(
+             q.dtype,
+             kv_cache.dtype,
+             actual_num_q_heads,
+             actual_num_kv_heads,
+             actual_head_dim,
+             page_size,
+             max_num_tokens,
+             pages_per_seq,
+         )
+     bkv_sz = bkv_p * page_size
+     if vmem_limit_bytes is None:
+         # TODO (jevinjiang/jacobplatin): change this to use
+         # `get_vmem_estimate_bytes` when VREG spilling is fixed.
+         vmem_limit_bytes = DEFAULT_VMEM_LIMIT_BYTES
+     grid = (distribution[2], )
+
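A quick worked example of the derived sizes, with invented numbers: with a page size of 32 and 8 kv pages per block, each grid step stages 256 kv tokens, and the grid runs one sequential step per sequence:

    page_size = 32              # kv_cache.shape[1]
    bkv_p = 8                   # tuned or user-supplied (assumed value)
    bkv_sz = bkv_p * page_size  # 256 kv tokens staged per attention block
    grid = (3, )                # distribution[2] == 3 sequences in the sketch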
+     # q, kv, and kv_cache stay in HBM; the kernel DMAs blocks into the VMEM
+     # scratch buffers below. The optional attention sink is small enough to
+     # live in VMEM directly.
+     in_specs = [
+         pl.BlockSpec(memory_space=pltpu.HBM),
+         pl.BlockSpec(memory_space=pltpu.HBM),
+         pl.BlockSpec(memory_space=pltpu.HBM),
+         None if attention_sink is None else pl.BlockSpec(
+             memory_space=pltpu.VMEM),
+     ]
+
+     out_specs = [
+         pl.BlockSpec(memory_space=pltpu.HBM),
+         pl.BlockSpec(memory_space=pltpu.HBM),
+     ]
+
+     bkv_double_buf = pltpu.VMEM(
+         (2, bkv_sz, *kv_cache.shape[2:]),
+         kv_cache.dtype,
+     )
+
+     bq_double_buf = pltpu.VMEM(
+         (2, actual_num_kv_heads, bq_sz, *q.shape[2:]),
+         q.dtype,
+     )
+
+     bo_double_buf = bq_double_buf
+
+     l_scratch = pltpu.VMEM(
+         (actual_num_kv_heads, bq_sz * num_q_heads_per_kv_head, 128),
+         jnp.float32,
+     )
+     m_scratch = l_scratch
+
+     acc_scratch = pltpu.VMEM(
+         (actual_num_kv_heads, bq_sz * num_q_heads_per_kv_head, head_dim),
+         jnp.float32,
+     )
+
+     scratch_shapes = [
+         bkv_double_buf,  # Double buffering for kv block.
+         bq_double_buf,  # Double buffering for q block.
+         bo_double_buf,  # Double buffering for output block.
+         # Semaphores for double buffering of bkv, bq, bo and bkv_update.
+         pltpu.SemaphoreType.DMA((4, 2)),
+         # Intermediate buffers per kv head for flash attention.
+         l_scratch,
+         m_scratch,
+         acc_scratch,
+     ]
+
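To give a feel for what vmem_limit_bytes is guarding, here is a rough, self-contained estimate of the scratch footprint declared above; the helper and all of its default sizes are illustrative, not part of the package:

    import math

    def approx_scratch_bytes(bkv_sz, bq_sz, kv_tail, q_tail,
                             actual_num_kv_heads, num_q_heads_per_kv_head,
                             head_dim, kv_itemsize=2, q_itemsize=2):
        """Approximate bytes used by the VMEM scratch buffers (sketch)."""
        bkv = 2 * bkv_sz * math.prod(kv_tail) * kv_itemsize   # kv double buffer
        bq = 2 * actual_num_kv_heads * bq_sz * math.prod(q_tail) * q_itemsize
        bo = bq                                               # mirrors bq
        lm = 2 * actual_num_kv_heads * bq_sz * num_q_heads_per_kv_head * 128 * 4
        acc = actual_num_kv_heads * bq_sz * num_q_heads_per_kv_head * head_dim * 4
        return bkv + bq + bo + lm + acc

    # e.g. bkv_sz=256, bq_sz=32, kv_tail=(1, 2, 128), q_tail=(2, 2, 64),
    # 2 kv heads, 4 q heads per kv head, head_dim=64 -> roughly 0.7 MiB,
    # comfortably below typical per-core VMEM budgets.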
+     scalar_prefetches = (
+         kv_lens,
+         # TODO(jevinjiang): can we use ragged page_indices to save some smem?
+         page_indices,
+         cu_q_lens,
+         distribution,
+         # (bq_sem_idx, bkv_sem_idx, bo_sem_idx)
+         jnp.zeros((3, ), jnp.int32),
+         # (bo_sem_0_seq_idx, bo_sem_1_seq_idx, bo_sem_0_bo_idx, bo_sem_1_bo_idx)
+         jnp.full((4, ), -1, jnp.int32),
+         # (bkv_sem_0_seq_idx, bkv_sem_1_seq_idx, bkv_sem_0_offset,
+         #  bkv_sem_1_offset, bkv_sem_0_sz, bkv_sem_1_sz)
+         jnp.full((6, ), -1, jnp.int32),
+     )
+
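One detail worth spelling out, since it explains the input_output_aliases below: the seven scalar prefetch arrays are passed to the kernel first, so in the final kernel(*scalar_prefetches, q, kv, kv_cache, attention_sink) invocation q lands at positional index 7 and kv_cache at index 9:

    assert len(scalar_prefetches) == 7
    # positional indices: 0..6 scalar prefetches, 7 q, 8 kv, 9 kv_cache,
    # 10 attention_sink -- hence input_output_aliases={7: 0, 9: 1}.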
+     scope_name = f"RPA-HD_64-bq_{bq_sz}-bkvp_{bkv_p}-p_{page_size}"
+     kernel = jax.named_scope(scope_name)(
+         pl.pallas_call(
+             functools.partial(
+                 _ragged_paged_attention_kernel,
+                 sm_scale=sm_scale,
+                 sliding_window=sliding_window,
+                 soft_cap=soft_cap,
+                 mask_value=mask_value,
+                 q_scale=q_scale,
+                 k_scale=k_scale,
+                 v_scale=v_scale,
+                 chunk_prefill_size=chunk_prefill_size,
+                 bq_sz=bq_sz,
+                 bkv_p=bkv_p,
+                 debug_mode=debug_mode,
+             ),
+             grid_spec=pltpu.PrefetchScalarGridSpec(
+                 num_scalar_prefetch=len(scalar_prefetches),
+                 in_specs=in_specs,
+                 out_specs=out_specs,
+                 grid=grid,
+                 scratch_shapes=scratch_shapes,
+             ),
+             compiler_params=pltpu.CompilerParams(
+                 # TODO(jevinjiang): since each sequence depends on the previous
+                 # one, we need some extra work to support Megacore mode.
+                 dimension_semantics=("arbitrary", ),
+                 vmem_limit_bytes=vmem_limit_bytes,
+             ),
+             out_shape=[
+                 jax.ShapeDtypeStruct(shape=q.shape, dtype=q.dtype),
+                 jax.ShapeDtypeStruct(shape=kv_cache.shape,
+                                      dtype=kv_cache.dtype),
+             ],
+             input_output_aliases={
+                 7: 0,  # q buffer is reused as the attention output.
+                 9: 1,  # kv_cache is updated in place (donated above).
+             },
+             name=scope_name,
+         ))
+     output, updated_kv_cache = kernel(*scalar_prefetches, q, kv, kv_cache,
+                                       attention_sink)
+     return (
+         prepare_outputs(output, actual_num_q_heads_per_kv_head,
+                         actual_head_dim),
+         updated_kv_cache,
+     )
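Finally, a note on the calling convention, reusing the variables from the earlier sketch: because the jit wrapper donates kv_cache, callers should rebind the returned cache rather than keep using the old reference:

    out, kv_cache = ragged_paged_attention_hd64(
        q, k, v, kv_cache, kv_lens, page_indices, cu_q_lens, distribution)
    # The pre-call kv_cache buffer may have been reused in place by XLA;
    # only the returned arrays are valid from here on.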