tpu-inference 0.0.1rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of tpu-inference might be problematic.

Files changed (174)
  1. tests/__init__.py +0 -0
  2. tests/core/__init__.py +0 -0
  3. tests/core/test_core_tpu.py +513 -0
  4. tests/core/test_disagg_executor.py +60 -0
  5. tests/core/test_disagg_utils.py +53 -0
  6. tests/core/test_dp_scheduler.py +899 -0
  7. tests/core/test_init.py +49 -0
  8. tests/kernels/__init__.py +0 -0
  9. tests/kernels/fused_moe_v1_test.py +374 -0
  10. tests/kernels/mla_v1_test.py +396 -0
  11. tests/kernels/quantized_matmul_kernel_test.py +191 -0
  12. tests/kernels/ragged_kv_cache_update_v2_test.py +234 -0
  13. tests/kernels/ragged_paged_attention_kernel_v2_test.py +400 -0
  14. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +549 -0
  15. tests/kernels/ragged_paged_attention_kernel_v3_test.py +504 -0
  16. tests/lora/__init__.py +0 -0
  17. tests/lora/conftest.py +32 -0
  18. tests/lora/test_bgmv.py +43 -0
  19. tests/lora/test_layers.py +648 -0
  20. tests/lora/test_lora.py +133 -0
  21. tests/lora/utils.py +88 -0
  22. tests/test_base.py +201 -0
  23. tests/test_envs.py +203 -0
  24. tests/test_quantization.py +836 -0
  25. tests/test_tpu_info.py +120 -0
  26. tests/test_utils.py +235 -0
  27. tpu_inference/__init__.py +53 -0
  28. tpu_inference/core/__init__.py +0 -0
  29. tpu_inference/core/core_tpu.py +786 -0
  30. tpu_inference/core/disagg_executor.py +118 -0
  31. tpu_inference/core/disagg_utils.py +49 -0
  32. tpu_inference/core/sched/__init__.py +0 -0
  33. tpu_inference/core/sched/dp_scheduler.py +523 -0
  34. tpu_inference/distributed/__init__.py +0 -0
  35. tpu_inference/distributed/jax_parallel_state.py +67 -0
  36. tpu_inference/distributed/tpu_connector.py +727 -0
  37. tpu_inference/distributed/utils.py +60 -0
  38. tpu_inference/env_override.py +9 -0
  39. tpu_inference/envs.py +160 -0
  40. tpu_inference/executors/__init__.py +0 -0
  41. tpu_inference/executors/ray_distributed_executor.py +382 -0
  42. tpu_inference/experimental/__init__.py +0 -0
  43. tpu_inference/experimental/llama3_jax_stashed.py +258 -0
  44. tpu_inference/kernels/__init__.py +0 -0
  45. tpu_inference/kernels/collectives/__init__.py +0 -0
  46. tpu_inference/kernels/collectives/all_gather_matmul.py +735 -0
  47. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +60 -0
  48. tpu_inference/kernels/collectives/util.py +47 -0
  49. tpu_inference/kernels/flash_attention/__init__.py +0 -0
  50. tpu_inference/kernels/flash_attention/kernel.py +772 -0
  51. tpu_inference/kernels/fused_moe/__init__.py +0 -0
  52. tpu_inference/kernels/fused_moe/v1/__init__.py +0 -0
  53. tpu_inference/kernels/fused_moe/v1/kernel.py +1566 -0
  54. tpu_inference/kernels/mla/__init__.py +0 -0
  55. tpu_inference/kernels/mla/v1/__init__.py +0 -0
  56. tpu_inference/kernels/mla/v1/kernel.py +1349 -0
  57. tpu_inference/kernels/quantized_matmul/__init__.py +0 -0
  58. tpu_inference/kernels/quantized_matmul/kernel.py +395 -0
  59. tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
  60. tpu_inference/kernels/quantized_matmul/util.py +58 -0
  61. tpu_inference/kernels/ragged_paged_attention/__init__.py +0 -0
  62. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +0 -0
  63. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +875 -0
  64. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +287 -0
  65. tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
  66. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +0 -0
  67. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1501 -0
  68. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1603 -0
  69. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4147 -0
  70. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +367 -0
  71. tpu_inference/kernels/ragged_paged_attention/v3/util.py +51 -0
  72. tpu_inference/layers/__init__.py +0 -0
  73. tpu_inference/layers/common/__init__.py +0 -0
  74. tpu_inference/layers/common/attention_interface.py +396 -0
  75. tpu_inference/layers/common/attention_metadata.py +34 -0
  76. tpu_inference/layers/common/binary_search.py +295 -0
  77. tpu_inference/layers/common/quant_methods.py +8 -0
  78. tpu_inference/layers/common/sharding.py +582 -0
  79. tpu_inference/layers/jax/__init__.py +0 -0
  80. tpu_inference/layers/jax/attention/__init__.py +0 -0
  81. tpu_inference/layers/jax/attention/attention.py +255 -0
  82. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +354 -0
  83. tpu_inference/layers/jax/attention/gpt_oss_attention.py +262 -0
  84. tpu_inference/layers/jax/attention/llama4_attention.py +153 -0
  85. tpu_inference/layers/jax/base.py +151 -0
  86. tpu_inference/layers/jax/constants.py +88 -0
  87. tpu_inference/layers/jax/layers.py +301 -0
  88. tpu_inference/layers/jax/misc.py +16 -0
  89. tpu_inference/layers/jax/moe/__init__.py +0 -0
  90. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +608 -0
  91. tpu_inference/layers/jax/moe/gpt_oss_moe.py +185 -0
  92. tpu_inference/layers/jax/moe/moe.py +209 -0
  93. tpu_inference/layers/jax/rope.py +280 -0
  94. tpu_inference/layers/jax/rope_interface.py +214 -0
  95. tpu_inference/layers/jax/sample/__init__.py +0 -0
  96. tpu_inference/layers/jax/sample/rejection_sampler.py +515 -0
  97. tpu_inference/layers/jax/sample/sampling.py +96 -0
  98. tpu_inference/layers/jax/sample/sampling_metadata.py +76 -0
  99. tpu_inference/layers/jax/transformer_block.py +107 -0
  100. tpu_inference/layers/vllm/__init__.py +0 -0
  101. tpu_inference/layers/vllm/attention.py +221 -0
  102. tpu_inference/layers/vllm/fused_moe.py +469 -0
  103. tpu_inference/layers/vllm/linear_common.py +186 -0
  104. tpu_inference/layers/vllm/quantization/__init__.py +39 -0
  105. tpu_inference/layers/vllm/quantization/awq.py +207 -0
  106. tpu_inference/layers/vllm/quantization/common.py +110 -0
  107. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +0 -0
  108. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +120 -0
  109. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +203 -0
  110. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +0 -0
  111. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +208 -0
  112. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +136 -0
  113. tpu_inference/layers/vllm/quantization/mxfp4.py +331 -0
  114. tpu_inference/layers/vllm/quantization/unquantized.py +368 -0
  115. tpu_inference/layers/vllm/sharding.py +230 -0
  116. tpu_inference/logger.py +10 -0
  117. tpu_inference/lora/__init__.py +0 -0
  118. tpu_inference/lora/torch_lora_ops.py +103 -0
  119. tpu_inference/lora/torch_punica_tpu.py +310 -0
  120. tpu_inference/models/__init__.py +0 -0
  121. tpu_inference/models/common/__init__.py +0 -0
  122. tpu_inference/models/common/model_loader.py +478 -0
  123. tpu_inference/models/jax/__init__.py +0 -0
  124. tpu_inference/models/jax/deepseek_v3.py +868 -0
  125. tpu_inference/models/jax/gpt_oss.py +492 -0
  126. tpu_inference/models/jax/jax_intermediate_tensor.py +79 -0
  127. tpu_inference/models/jax/llama3.py +376 -0
  128. tpu_inference/models/jax/llama4.py +629 -0
  129. tpu_inference/models/jax/llama_eagle3.py +336 -0
  130. tpu_inference/models/jax/llama_guard_4.py +361 -0
  131. tpu_inference/models/jax/qwen2.py +376 -0
  132. tpu_inference/models/jax/qwen2_5_vl.py +1218 -0
  133. tpu_inference/models/jax/qwen3.py +303 -0
  134. tpu_inference/models/jax/utils/__init__.py +0 -0
  135. tpu_inference/models/jax/utils/file_utils.py +96 -0
  136. tpu_inference/models/jax/utils/multi_modal_utils.py +163 -0
  137. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  138. tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +5 -0
  139. tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +6 -0
  140. tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +5 -0
  141. tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +6 -0
  142. tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +105 -0
  143. tpu_inference/models/jax/utils/quantization/quantization_utils.py +650 -0
  144. tpu_inference/models/jax/utils/weight_utils.py +584 -0
  145. tpu_inference/models/vllm/__init__.py +0 -0
  146. tpu_inference/models/vllm/vllm_model_wrapper.py +293 -0
  147. tpu_inference/models/vllm/vllm_model_wrapper_context.py +45 -0
  148. tpu_inference/platforms/__init__.py +2 -0
  149. tpu_inference/platforms/tpu_platform.py +275 -0
  150. tpu_inference/runner/__init__.py +0 -0
  151. tpu_inference/runner/block_table.py +122 -0
  152. tpu_inference/runner/compilation_manager.py +865 -0
  153. tpu_inference/runner/input_batch.py +435 -0
  154. tpu_inference/runner/kv_cache.py +132 -0
  155. tpu_inference/runner/kv_cache_manager.py +478 -0
  156. tpu_inference/runner/lora_utils.py +92 -0
  157. tpu_inference/runner/multimodal_manager.py +217 -0
  158. tpu_inference/runner/persistent_batch_manager.py +282 -0
  159. tpu_inference/runner/speculative_decoding_manager.py +248 -0
  160. tpu_inference/runner/structured_decoding_manager.py +87 -0
  161. tpu_inference/runner/tpu_runner.py +1744 -0
  162. tpu_inference/runner/utils.py +426 -0
  163. tpu_inference/spec_decode/__init__.py +0 -0
  164. tpu_inference/spec_decode/jax/__init__.py +0 -0
  165. tpu_inference/spec_decode/jax/eagle3.py +417 -0
  166. tpu_inference/tpu_info.py +78 -0
  167. tpu_inference/utils.py +340 -0
  168. tpu_inference/worker/__init__.py +0 -0
  169. tpu_inference/worker/tpu_worker.py +458 -0
  170. tpu_inference-0.0.1rc1.dist-info/METADATA +108 -0
  171. tpu_inference-0.0.1rc1.dist-info/RECORD +174 -0
  172. tpu_inference-0.0.1rc1.dist-info/WHEEL +5 -0
  173. tpu_inference-0.0.1rc1.dist-info/licenses/LICENSE +201 -0
  174. tpu_inference-0.0.1rc1.dist-info/top_level.txt +2 -0
@@ -0,0 +1,1501 @@
1
+ """TPU-Friendly Ragged Paged Attention kernel.
2
+
3
+ This kernel offers a highly optimized implementation of ragged paged attention,
4
+ specifically designed for TPU and compatible with a wide range of model
5
+ specifications. It supports mixed prefill and decoding, enhancing throughput
6
+ during inference.
7
+ """
8
+ import functools
9
+
10
+ import jax
11
+ import jax.numpy as jnp
12
+ from jax import lax
13
+ from jax.experimental import pallas as pl
14
+ from jax.experimental.pallas import tpu as pltpu
15
+
16
+ from tpu_inference.kernels.ragged_paged_attention.v3.tuned_block_sizes import \
17
+ get_tuned_block_sizes
18
+ from tpu_inference.kernels.ragged_paged_attention.v3.util import (
19
+ align_to, cdiv, get_dtype_bitwidth, get_dtype_packing)
20
+
21
+ DEFAULT_MASK_VALUE = -0.7 * float(jnp.finfo(jnp.dtype("float32")).max)
22
+
23
+ DEFAULT_VMEM_LIMIT_BYTES = 100 * 1024 * 1024
24
+
25
+
26
+ def ref_ragged_paged_attention(
27
+ queries: jax.
28
+ Array, # [max_num_tokens, actual_num_q_heads, actual_head_dim]
29
+ keys: jax.Array, # [max_num_tokens, actual_num_kv_heads, actual_head_dim]
30
+ values: jax.
31
+ Array, # [max_num_tokens, actual_num_kv_heads, actual_head_dim]
32
+ kv_cache: jax.
33
+ Array, # [total_num_pages, page_size, num_kv_heads_x2 // kv_packing, kv_packing, head_dim]
34
+ kv_lens: jax.Array, # i32[max_num_seqs]
35
+ page_indices: jax.Array, # i32[max_num_seqs * pages_per_seq]
36
+ cu_q_lens: jax.Array, # i32[max_num_seqs + 1]
37
+ distribution: jax.Array, # i32[3]
38
+ *,
39
+ sm_scale: float = 1.0,
40
+ sliding_window: int | None = None,
41
+ soft_cap: float | None = None,
42
+ mask_value: float | None = DEFAULT_MASK_VALUE,
43
+ q_scale: float | None = None,
44
+ k_scale: float | None = None,
45
+ v_scale: float | None = None,
46
+ ):
47
+ if mask_value is None:
48
+ mask_value = DEFAULT_MASK_VALUE
49
+
50
+ dynamic_validate_inputs(
51
+ queries,
52
+ keys,
53
+ values,
54
+ kv_cache,
55
+ kv_lens,
56
+ page_indices,
57
+ cu_q_lens,
58
+ distribution,
59
+ sm_scale=sm_scale,
60
+ sliding_window=sliding_window,
61
+ soft_cap=soft_cap,
62
+ mask_value=mask_value,
63
+ q_scale=q_scale,
64
+ k_scale=k_scale,
65
+ v_scale=v_scale,
66
+ )
67
+ actual_head_dim = queries.shape[2]
68
+ actual_num_q_heads = queries.shape[1]
69
+ actual_num_kv_heads = keys.shape[1]
70
+ merged_kv = merge_kv(keys, values)
71
+ assert merged_kv.shape[-3:] == kv_cache.shape[-3:]
72
+
73
+ _, page_size, num_kv_heads_x2_per_kv_packing, kv_packing, head_dim = (
74
+ kv_cache.shape)
75
+ num_kv_heads_x2 = num_kv_heads_x2_per_kv_packing * kv_packing
76
+ assert num_kv_heads_x2 % 2 == 0
77
+ assert actual_num_q_heads % actual_num_kv_heads == 0
78
+ assert head_dim % 128 == 0
79
+ assert get_dtype_packing(kv_cache.dtype) == kv_packing
80
+ assert num_kv_heads_x2 == align_to(actual_num_kv_heads * 2, kv_packing)
81
+ actual_num_q_heads_per_kv_head = actual_num_q_heads // actual_num_kv_heads
82
+ max_num_seqs = kv_lens.shape[0]
83
+ num_page_indices = page_indices.shape[0]
84
+ assert num_page_indices % max_num_seqs == 0
85
+ pages_per_seq = num_page_indices // max_num_seqs
86
+ outputs = []
87
+
88
+ for i in range(distribution[-1]):
89
+ q_start = cu_q_lens[i]
90
+ q_end = cu_q_lens[i + 1]
91
+ q_len = q_end - q_start
92
+
93
+ kv_len = kv_lens[i]
94
+ indices_start = i * pages_per_seq
95
+ indices_end = indices_start + cdiv(kv_len, page_size)
96
+ indices = page_indices[indices_start:indices_end]
97
+ q = queries[q_start:q_end, :, :actual_head_dim]
98
+
99
+ # Update the kv cache.
100
+ assert kv_len - q_len >= 0
101
+ gathered_kv = kv_cache[indices]
102
+ gathered_shape = gathered_kv.shape
103
+ gathered_kv = gathered_kv.reshape(-1, *gathered_shape[-3:])
104
+ gathered_kv = gathered_kv.at[kv_len - q_len:kv_len].set(
105
+ merged_kv[q_start:q_end])
106
+ kv_cache = kv_cache.at[indices].set(
107
+ gathered_kv.reshape(gathered_shape))
108
+
109
+ kv = gathered_kv.reshape(
110
+ -1, num_kv_heads_x2,
111
+ head_dim)[:, :actual_num_kv_heads * 2, :].reshape(
112
+ -1, actual_num_kv_heads, head_dim * 2)
113
+ k = kv[:kv_len, :, :head_dim][:, :, :actual_head_dim]
114
+ v = kv[:kv_len, :, head_dim:][:, :, :actual_head_dim]
115
+ k = jnp.repeat(k, actual_num_q_heads_per_kv_head, axis=1)
116
+ v = jnp.repeat(v, actual_num_q_heads_per_kv_head, axis=1)
117
+
118
+ if q_scale is not None:
119
+ q = q / q_scale
120
+ if jnp.issubdtype(k.dtype, jnp.floating):
121
+ dtype_info = jnp.finfo(k.dtype)
122
+ minval = float(dtype_info.min)
123
+ maxval = float(dtype_info.max)
124
+ q = jnp.clip(q, min=minval, max=maxval)
125
+ q = q.astype(k.dtype)
126
+
127
+ attn = jnp.einsum("qhd,khd->hqk",
128
+ q,
129
+ k,
130
+ preferred_element_type=jnp.float32)
131
+ attn *= sm_scale
132
+ if k_scale is not None:
133
+ attn *= k_scale
134
+ if q_scale is not None:
135
+ attn *= q_scale
136
+
137
+ q_span = (kv_len - q_len) + jax.lax.broadcasted_iota(
138
+ jnp.int32, attn.shape, 1)
139
+ kv_span = jax.lax.broadcasted_iota(jnp.int32, attn.shape, 2)
140
+ mask = q_span < kv_span
141
+ if sliding_window is not None:
142
+ mask = jnp.logical_or(mask, q_span - sliding_window >= kv_span)
143
+ if soft_cap is not None:
144
+ attn = soft_cap * jnp.tanh(attn / soft_cap)
145
+ attn += jnp.where(mask, mask_value, 0.0)
146
+ attn = jax.nn.softmax(attn, axis=-1).astype(v.dtype)
147
+
148
+ out = jnp.einsum("hqk,khd->qhd", attn, v).astype(queries.dtype)
149
+ if v_scale is not None:
150
+ out *= v_scale
151
+
152
+ outputs.append(out)
153
+
154
+ result = jnp.concatenate(outputs, axis=0)
155
+ return result, kv_cache
156
+
157
+
158
+ def get_smem_estimate_bytes(max_num_seqs, pages_per_seq):
159
+ total_bits = (
160
+ # kv_lens_ref: i32[max_num_seqs]
161
+ align_to(max_num_seqs, 128) * 32 +
162
+ # page_indices_ref: i32[max_num_seqs * pages_per_seq]
163
+ align_to(max_num_seqs * pages_per_seq, 128) * 32 +
164
+ # cu_q_lens_ref: i32[max_num_seqs + 1]
165
+ align_to(max_num_seqs + 1, 128) * 32 +
166
+ # distribution_ref: i32[3]
167
+ 128 * 32 +
168
+ # sem_ids_ref: i32[3]
169
+ 128 * 32 +
170
+ # bo_ids_ref: i32[4]
171
+ 128 * 32 +
172
+ # bkv_update_ids_ref: i32[6]
173
+ 128 * 32)
174
+ return cdiv(total_bits, 8)
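# Editor's note, worked example with assumed sizes: for max_num_seqs=256 and
# pages_per_seq=128, the estimate is
#   (align_to(256, 128) + align_to(256 * 128, 128) + align_to(257, 128) + 4 * 128) * 32 bits
#   = (256 + 32768 + 384 + 512) * 32 bits = 1,085,440 bits ≈ 135,680 bytes (~132 KiB),
# dominated by the flattened page-indices table.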
175
+
176
+
177
+ def get_vmem_estimate_bytes(
178
+ actual_num_kv_heads,
179
+ actual_num_q_heads_per_kv_head,
180
+ actual_head_dim,
181
+ bq_sz,
182
+ bkv_sz,
183
+ q_dtype,
184
+ kv_dtype,
185
+ ):
186
+ q_packing = get_dtype_packing(q_dtype)
187
+ kv_packing = get_dtype_packing(kv_dtype)
188
+ num_q_heads_per_kv_head = align_to(actual_num_q_heads_per_kv_head,
189
+ q_packing)
190
+ num_kv_heads_x2 = align_to(actual_num_kv_heads * 2, kv_packing)
191
+ head_dim = align_to(actual_head_dim, 128)
192
+
193
+ total_bits = (
194
+ # bkv_x2_ref
195
+ (2 * bkv_sz * num_kv_heads_x2 * head_dim) * (32 // kv_packing) +
196
+ # bq_x2_ref + bo_x2_ref
197
+ 2 * (2 * actual_num_kv_heads * bq_sz * num_q_heads_per_kv_head *
198
+ head_dim) * (32 // q_packing) +
199
+ # l_ref + m_ref
200
+ 2 *
201
+ (actual_num_kv_heads * bq_sz * num_q_heads_per_kv_head * 128) * 32 +
202
+ # acc_ref
203
+ (actual_num_kv_heads * bq_sz * num_q_heads_per_kv_head * head_dim) *
204
+ 32)
205
+ return cdiv(total_bits, 8)
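# Editor's note, worked example with assumed sizes: with 8 KV heads, 4 q heads per
# KV head, head_dim=128, bq_sz=32, bkv_sz=128, and bfloat16 q/kv (packing 2), the
# double-buffered bkv, bq/bo, l/m, and acc buffers total 29,360,128 bits
# ≈ 3.5 MiB of VMEM, well under DEFAULT_VMEM_LIMIT_BYTES (100 MiB).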
206
+
207
+
208
+ def get_kv_cache_shape(
209
+ total_num_pages,
210
+ page_size,
211
+ actual_num_kv_heads,
212
+ actual_head_dim,
213
+ kv_dtype,
214
+ ):
215
+ kv_packing = get_dtype_packing(kv_dtype)
216
+ return (
217
+ total_num_pages,
218
+ page_size,
219
+ align_to(actual_num_kv_heads * 2, kv_packing) // kv_packing,
220
+ kv_packing,
221
+ align_to(actual_head_dim, 128),
222
+ )
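# Editor's note, example with assumed values: for bfloat16 (kv_packing == 2),
# 8 KV heads, head_dim=128, and page_size=32, this returns
# (total_num_pages, 32, 8, 2, 128); K and V of the same head end up sharing one
# packing group of size 2.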
223
+
224
+
225
+ def _ragged_paged_attention_kernel(
226
+ # Prefetch
227
+ kv_lens_ref, # [max_num_seqs]
228
+ page_indices_ref, # [max_num_seqs * pages_per_seq]
229
+ cu_q_lens_ref, # [max_num_seqs + 1]
230
+ # TODO(jevinjiang): merge these into one so we can save SMEM.
231
+ distribution_ref, # [3] (decode_end, prefill_end, mixed_end)
232
+ sem_ids_ref, # [3] (bq_sem_idx, bkv_sem_idx, bo_sem_idx)
233
+ bo_ids_ref, # [4] (bo_sem_0_seq_idx, bo_sem_1_seq_idx, bo_sem_0_bo_idx, bo_sem_1_bo_idx)
234
+ bkv_update_ids_ref, # [6] (bkv_sem_0_seq_idx, bkv_sem_1_seq_idx, bkv_sem_0_offset, bkv_sem_1_offset, bkv_sem_0_sz, bkv_sem_1_sz)
235
+ # Input
236
+ q_hbm_ref, # [actual_num_kv_heads, max_num_tokens, num_q_heads_per_kv_head // q_packing, q_packing, head_dim]
237
+ kv_hbm_ref, # [max_num_tokens, num_kv_heads_x2 // kv_packing, kv_packing, head_dim]
238
+ kv_cache_hbm_ref, # [total_num_pages, page_size, num_kv_heads_x2 // kv_packing, kv_packing, head_dim]
239
+ # Output
240
+ o_hbm_ref, # [actual_num_kv_heads, max_num_tokens, num_q_heads_per_kv_head // q_packing, q_packing, head_dim]
241
+ updated_kv_cache_hbm_ref, # [total_num_pages, page_size, num_kv_heads_x2 // kv_packing, kv_packing, head_dim]
242
+ # Scratch
243
+ bkv_x2_ref, # [2, bkv_sz, num_kv_heads_x2 // kv_packing, kv_packing, head_dim]
244
+ bq_x2_ref, # [2, actual_num_kv_heads, bq_sz, num_q_heads_per_kv_head // q_packing, q_packing, head_dim]
245
+ bo_x2_ref, # [2, actual_num_kv_heads, bq_sz, num_q_heads_per_kv_head // q_packing, q_packing, head_dim]
246
+ sems, # [4, 2]
247
+ l_ref, # [actual_num_kv_heads, bq_sz * num_q_heads_per_kv_head, 128],
248
+ m_ref, # [actual_num_kv_heads, bq_sz * num_q_heads_per_kv_head, 128],
249
+ acc_ref, # [actual_num_kv_heads, bq_sz * num_q_heads_per_kv_head, head_dim],
250
+ *,
251
+ sm_scale: float,
252
+ sliding_window: int | None = None,
253
+ soft_cap: float | None = None,
254
+ mask_value: float = DEFAULT_MASK_VALUE,
255
+ q_scale: float | None = None,
256
+ k_scale: float | None = None,
257
+ v_scale: float | None = None,
258
+ chunk_prefill_size: int | None = None,
259
+ bkv_p,
260
+ bq_sz,
261
+ debug_mode: bool = False,
262
+ ):
263
+ assert q_hbm_ref.shape == o_hbm_ref.shape
264
+ assert q_hbm_ref.shape[-1] == kv_cache_hbm_ref.shape[-1]
265
+ (
266
+ actual_num_kv_heads,
267
+ max_num_tokens,
268
+ num_q_heads_per_kv_head_per_packing,
269
+ q_packing,
270
+ head_dim,
271
+ ) = q_hbm_ref.shape
272
+ (
273
+ total_num_pages,
274
+ page_size,
275
+ num_kv_heads_x2_per_kv_packing,
276
+ kv_packing,
277
+ _,
278
+ ) = kv_cache_hbm_ref.shape
279
+ max_num_seqs = kv_lens_ref.shape[0]
280
+ num_page_indices = page_indices_ref.shape[0]
281
+ assert num_page_indices % max_num_seqs == 0
282
+ pages_per_seq = num_page_indices // max_num_seqs
283
+ num_kv_heads_x2 = num_kv_heads_x2_per_kv_packing * kv_packing
284
+ num_q_heads_per_kv_head = num_q_heads_per_kv_head_per_packing * q_packing
285
+ q_dtype = q_hbm_ref.dtype
286
+ kv_dtype = kv_cache_hbm_ref.dtype
287
+ assert o_hbm_ref.dtype == q_dtype
288
+ assert get_dtype_packing(q_dtype) == q_packing
289
+ assert get_dtype_packing(kv_dtype) == kv_packing
290
+ assert head_dim % 128 == 0
291
+ bkv_sz = bkv_p * page_size
292
+ seq_idx = pl.program_id(0)
293
+ num_seqs = pl.num_programs(0)
294
+ decode_end = distribution_ref[0]
295
+ prefill_end = distribution_ref[1]
296
+ mixed_end = distribution_ref[2]
297
+
298
+ q_start = cu_q_lens_ref[seq_idx]
299
+ q_end = cu_q_lens_ref[seq_idx + 1]
300
+ q_len = q_end - q_start
301
+ kv_len = kv_lens_ref[seq_idx]
302
+
303
+ def debug_print(msg, *args):
304
+ if debug_mode:
305
+ pl.debug_print(msg, *args)
306
+
307
+ debug_print("[RPA debug] ======= In loop seq_idx={}", seq_idx)
308
+ debug_print("[RPA debug] num_seqs={}", num_seqs)
309
+ debug_print("[RPA debug] decode_end={}", decode_end)
310
+ debug_print("[RPA debug] prefill_end={}", prefill_end)
311
+ debug_print("[RPA debug] mixed_end={}", mixed_end)
312
+ debug_print("[RPA debug] bkv_p={}", bkv_p)
313
+ debug_print("[RPA debug] page_size={}", page_size)
314
+ debug_print("[RPA debug] pages_per_seq={}", pages_per_seq)
315
+ debug_print("[RPA debug] bkv_sz={}", bkv_sz)
316
+ debug_print("[RPA debug] bq_sz={}", bq_sz)
317
+ debug_print("[RPA debug] q_start={}", q_start)
318
+ debug_print("[RPA debug] q_end={}", q_end)
319
+ debug_print("[RPA debug] q_len={}", q_len)
320
+ debug_print("[RPA debug] kv_len={}", kv_len)
321
+
322
+ def flash_attention(
323
+ q, # [actual_bq_sz * num_q_heads_per_kv_head, head_dim]
324
+ k, # [bkv_sz, head_dim]
325
+ v, # [bkv_sz, head_dim]
326
+ *,
327
+ bq_idx,
328
+ bkv_idx,
329
+ kv_head_idx,
330
+ ):
331
+ assert len(q.shape) == 2
332
+ assert q.shape[0] % num_q_heads_per_kv_head == 0
333
+ assert q.shape[1] == head_dim
334
+ assert k.shape == v.shape == (bkv_sz, head_dim)
335
+ assert k.dtype == v.dtype
336
+ head_l_ref = l_ref.at[kv_head_idx, :q.shape[0]]
337
+ head_m_ref = m_ref.at[kv_head_idx, :q.shape[0]]
338
+ head_acc_ref = acc_ref.at[kv_head_idx, :q.shape[0]]
339
+
340
+ def load_with_init(ref, init_val):
341
+ return jnp.where(bkv_idx == 0, jnp.full_like(ref, init_val),
342
+ ref[...])
343
+
344
+ # Follow FlashAttention-2 forward pass.
345
+ if q_scale is not None:
346
+ q = q / q_scale
347
+ if jnp.issubdtype(k.dtype, jnp.floating):
348
+ dtype_info = jnp.finfo(k.dtype)
349
+ minval = float(dtype_info.min)
350
+ maxval = float(dtype_info.max)
351
+ q = jnp.clip(q, min=minval, max=maxval)
352
+ q = q.astype(k.dtype)
353
+
354
+ s = jnp.einsum("nd,md->nm", q, k, preferred_element_type=jnp.float32)
355
+ s *= sm_scale
356
+ if k_scale is not None:
357
+ s *= k_scale
358
+ if q_scale is not None:
359
+ s *= q_scale
360
+
361
+ q_span = (kv_len - q_len + bq_idx * bq_sz +
362
+ lax.broadcasted_iota(jnp.int32, s.shape, 0) //
363
+ num_q_heads_per_kv_head)
364
+ k_span = bkv_idx * bkv_sz + lax.broadcasted_iota(jnp.int32, s.shape, 1)
365
+ mask = q_span < k_span
366
+ # TODO(jevinjiang, xiowei): reduce pages_per_seq based on sliding_window.
367
+ if sliding_window is not None:
368
+ mask = jnp.logical_or(mask, q_span - sliding_window >= k_span)
369
+
370
+ if soft_cap is not None:
371
+ s = soft_cap * jnp.tanh(s / soft_cap)
372
+ s += jnp.where(mask, mask_value, 0.0)
373
+ s_rowmax = jnp.max(s, axis=1, keepdims=True)
374
+ m_prev = load_with_init(head_m_ref, -jnp.inf)
375
+ m_curr = jnp.maximum(m_prev, s_rowmax)
376
+ head_m_ref[...] = m_curr
377
+ p = jnp.exp(s - broadcast_minor(m_curr, s.shape))
378
+
379
+ pv = jnp.einsum("nm,md->nd", p, v, preferred_element_type=jnp.float32)
380
+ if v_scale is not None:
381
+ pv *= v_scale
382
+
383
+ p_rowsum = jnp.sum(p, axis=1, keepdims=True)
384
+ exp_m_diff = jnp.exp(m_prev - m_curr)
385
+ l_prev = load_with_init(head_l_ref, 0.0)
386
+ l_curr = exp_m_diff * l_prev + p_rowsum
387
+ head_l_ref[...] = l_curr
388
+ o_prev = load_with_init(head_acc_ref, 0.0)
389
+ o_curr = broadcast_minor(exp_m_diff, o_prev.shape) * o_prev + pv
390
+ head_acc_ref[...] = o_curr
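# Editor's note: the updates above are the standard online-softmax recurrence,
#   m_i = max(m_{i-1}, rowmax(S_i))
#   l_i = exp(m_{i-1} - m_i) * l_{i-1} + rowsum(exp(S_i - m_i))
#   O_i = exp(m_{i-1} - m_i) * O_{i-1} + exp(S_i - m_i) @ V_i
# and the final output O / l is computed once per bq block in `process` below.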
391
+
392
+ def _async_copy(src, dst, sem, wait):
393
+ if debug_mode:
394
+ # Skip DMA if debug mode is enabled.
395
+ return
396
+ cp = pltpu.make_async_copy(src, dst, sem)
397
+ if wait:
398
+ cp.wait()
399
+ else:
400
+ cp.start()
401
+
402
+ def _fetch_bkv(seq_idx, bkv_idx, bkv_sem_idx, *, wait=False):
403
+ sem = sems.at[0, bkv_sem_idx]
404
+ vmem_ref = bkv_x2_ref.at[bkv_sem_idx]
405
+
406
+ cache_hbm_shape = kv_cache_hbm_ref.shape
407
+ cache_hbm_ref = kv_cache_hbm_ref.reshape(
408
+ cache_hbm_shape[0] * cache_hbm_shape[1], *cache_hbm_shape[2:])
409
+ kv_len = kv_lens_ref[seq_idx]
410
+ kv_len_start = bkv_idx * bkv_sz
411
+ kv_p_start = bkv_idx * bkv_p
412
+ q_start = cu_q_lens_ref[seq_idx]
413
+ q_end = cu_q_lens_ref[seq_idx + 1]
414
+ q_len = q_end - q_start
415
+
416
+ kv_left = kv_len - kv_len_start
417
+ kv_left_frm_cache = jnp.maximum(kv_left - q_len, 0)
418
+ kv_left_frm_new = kv_left - kv_left_frm_cache
419
+ bkv_p_frm_cache = jnp.minimum(cdiv(kv_left_frm_cache, page_size),
420
+ bkv_p)
421
+ bkv_sz_frm_new = jnp.minimum(
422
+ jnp.maximum(bkv_sz - kv_left_frm_cache, 0), kv_left_frm_new)
423
+ page_indices_offset = seq_idx * pages_per_seq + kv_p_start
424
+
425
+ # Make sure the current bkv buffer is safe to overwrite.
426
+ wait_update_kv_cache(bkv_sem_idx)
427
+
428
+ debug_print(
429
+ "[RPA debug]"
430
+ f" -----------{'wait' if wait else 'start'}_fetch_bkv-----------")
431
+ debug_print("[RPA debug] seq_idx={}", seq_idx)
432
+ debug_print("[RPA debug] bkv_idx={}", bkv_idx)
433
+ debug_print("[RPA debug] bkv_sem_idx={}", bkv_sem_idx)
434
+ debug_print("[RPA debug] kv_len_start={}", kv_len_start)
435
+ debug_print("[RPA debug] kv_p_start={}", kv_p_start)
436
+ debug_print("[RPA debug] kv_left={}", kv_left)
437
+ debug_print("[RPA debug] kv_left_frm_cache={}", kv_left_frm_cache)
438
+ debug_print("[RPA debug] kv_left_frm_new={}", kv_left_frm_new)
439
+ debug_print("[RPA debug] bkv_p_frm_cache={}", bkv_p_frm_cache)
440
+ debug_print("[RPA debug] bkv_sz_frm_new={}", bkv_sz_frm_new)
441
+ debug_print("[RPA debug] page_indices_offset={}", page_indices_offset)
442
+
443
+ if not wait:
444
+ # Fetch effective kv from kv cache.
445
+ def loop_body(i, offset):
446
+ sz = jnp.minimum(page_size, kv_left_frm_cache - i * page_size)
447
+ _async_copy(
448
+ cache_hbm_ref.at[pl.ds(
449
+ page_indices_ref[page_indices_offset + i] * page_size,
450
+ sz)],
451
+ vmem_ref.at[pl.ds(i * page_size, sz)],
452
+ sem,
453
+ wait=False,
454
+ )
455
+ debug_print("[RPA debug] loop_body i={}, sz={}", i, sz)
456
+ return offset + sz
457
+
458
+ offset = lax.fori_loop(
459
+ 0,
460
+ bkv_p_frm_cache,
461
+ loop_body,
462
+ 0, # offset
463
+ unroll=False,
464
+ )
465
+
466
+ # Fetch kv directly from new kv.
467
+ @pl.when(bkv_sz_frm_new > 0)
468
+ def _fetch_bkv_from_new_kv():
469
+ new_kv_len_start = q_end - kv_left_frm_new
470
+ debug_print("[RPA debug] new_kv_len_start={}",
471
+ new_kv_len_start)
472
+ debug_print("[RPA debug] offset_in_bkv={}", offset)
473
+ _async_copy(
474
+ kv_hbm_ref.at[pl.ds(new_kv_len_start, bkv_sz_frm_new)],
475
+ vmem_ref.at[pl.ds(offset, bkv_sz_frm_new)],
476
+ sem,
477
+ wait,
478
+ )
479
+
480
+ return kv_len_start + offset, bkv_sz_frm_new
481
+ else:
482
+ offset = jnp.minimum(kv_left_frm_cache, page_size * bkv_p)
483
+ dst = vmem_ref.at[pl.ds(0, offset + bkv_sz_frm_new)]
484
+ _async_copy(
485
+ src=dst,
486
+ dst=dst,
487
+ sem=sem,
488
+ wait=True,
489
+ )
490
+ return kv_len_start + offset, bkv_sz_frm_new
491
+
492
+ def _update_kv_cache(seq_idx,
493
+ bkv_sem_idx,
494
+ offset,
495
+ update_sz,
496
+ *,
497
+ wait=False):
498
+ sem = sems.at[3, bkv_sem_idx]
499
+ vmem_ref = bkv_x2_ref.at[bkv_sem_idx]
500
+ bkv_id = offset // bkv_sz
501
+ kv_p_start = offset // page_size
502
+ kv_p_end = cdiv(offset + update_sz, page_size)
503
+ ignore = offset % page_size
504
+ p_ignore = kv_p_start - bkv_id * bkv_p
505
+ page_indices_offset = seq_idx * pages_per_seq + kv_p_start
506
+
507
+ cache_hbm_shape = updated_kv_cache_hbm_ref.shape
508
+ cache_hbm_ref = updated_kv_cache_hbm_ref.reshape(
509
+ cache_hbm_shape[0] * cache_hbm_shape[1], *cache_hbm_shape[2:])
510
+
511
+ debug_print(
512
+ "[RPA debug]"
513
+ f" -----------{'wait' if wait else 'start'}_update_kv_cache-----------"
514
+ )
515
+ debug_print("[RPA debug] seq_idx={}", seq_idx)
516
+ debug_print("[RPA debug] bkv_sem_idx={}", bkv_sem_idx)
517
+ debug_print("[RPA debug] offset={}", offset)
518
+ debug_print("[RPA debug] update_sz={}", update_sz)
519
+ debug_print("[RPA debug] bkv_id={}", bkv_id)
520
+ debug_print("[RPA debug] kv_p_start={}", kv_p_start)
521
+ debug_print("[RPA debug] kv_p_end={}", kv_p_end)
522
+ debug_print("[RPA debug] ignore={}", ignore)
523
+ debug_print("[RPA debug] p_ignore={}", p_ignore)
524
+ debug_print("[RPA debug] page_indices_offset={}", page_indices_offset)
525
+
526
+ if not wait:
527
+
528
+ def loop_body(i, states):
529
+ update_sz, ignore = states
530
+ sz = jnp.minimum(page_size - ignore, update_sz)
531
+
532
+ _async_copy(
533
+ vmem_ref.at[pl.ds((p_ignore + i) * page_size + ignore,
534
+ sz)],
535
+ cache_hbm_ref.at[pl.ds(
536
+ page_indices_ref[page_indices_offset + i] * page_size +
537
+ ignore,
538
+ sz,
539
+ )],
540
+ sem,
541
+ wait=False,
542
+ )
543
+ debug_print("[RPA debug] loop_body i={}, sz={}", i, sz)
544
+ return update_sz - sz, 0
545
+
546
+ lax.fori_loop(
547
+ 0,
548
+ kv_p_end - kv_p_start,
549
+ loop_body,
550
+ (update_sz, ignore), # initial carry: (remaining update size, in-page start offset)
551
+ unroll=False,
552
+ )
553
+ else:
554
+ dst = cache_hbm_ref.at[pl.ds(0, update_sz)]
555
+ _async_copy(
556
+ src=dst,
557
+ dst=dst,
558
+ sem=sem,
559
+ wait=True,
560
+ )
561
+
562
+ def _fetch_bq(seq_idx, bq_idx, bq_sem_idx, *, wait=False):
563
+ sem = sems.at[1, bq_sem_idx]
564
+ vmem_ref = bq_x2_ref.at[bq_sem_idx]
565
+ q_len_start = cu_q_lens_ref[seq_idx] + bq_idx * bq_sz
566
+ q_end = cu_q_lens_ref[seq_idx + 1]
567
+ sz = jnp.minimum(bq_sz, q_end - q_len_start)
568
+
569
+ debug_print(
570
+ "[RPA debug]"
571
+ f" -----------{'wait' if wait else 'start'}_fetch_bq-----------")
572
+ debug_print("[RPA debug] seq_idx={}", seq_idx)
573
+ debug_print("[RPA debug] bq_idx={}", bq_idx)
574
+ debug_print("[RPA debug] bq_sem_idx={}", bq_sem_idx)
575
+ debug_print("[RPA debug] q_len_start={}", q_len_start)
576
+ debug_print("[RPA debug] q_end={}", q_end)
577
+ debug_print("[RPA debug] sz={}", sz)
578
+
579
+ _async_copy(
580
+ q_hbm_ref.at[:, pl.ds(q_len_start, sz)],
581
+ vmem_ref.at[:, pl.ds(0, sz)],
582
+ sem,
583
+ wait,
584
+ )
585
+
586
+ def _send_bo(seq_idx, bo_idx, bo_sem_idx, *, wait=False):
587
+ sem = sems.at[2, bo_sem_idx]
588
+ vmem_ref = bo_x2_ref.at[bo_sem_idx]
589
+ q_len_start = cu_q_lens_ref[seq_idx] + bo_idx * bq_sz
590
+ q_end = cu_q_lens_ref[seq_idx + 1]
591
+ sz = jnp.minimum(bq_sz, q_end - q_len_start)
592
+
593
+ debug_print(
594
+ "[RPA debug]"
595
+ f" -----------{'wait' if wait else 'start'}_send_bo-----------")
596
+ debug_print("[RPA debug] seq_idx={}", seq_idx)
597
+ debug_print("[RPA debug] bo_idx={}", bo_idx)
598
+ debug_print("[RPA debug] bo_sem_idx={}", bo_sem_idx)
599
+ debug_print("[RPA debug] q_len_start={}", q_len_start)
600
+ debug_print("[RPA debug] q_end={}", q_end)
601
+ debug_print("[RPA debug] sz={}", sz)
602
+
603
+ _async_copy(
604
+ vmem_ref.at[:, pl.ds(0, sz)],
605
+ o_hbm_ref.at[:, pl.ds(q_len_start, sz)],
606
+ sem,
607
+ wait,
608
+ )
609
+
610
+ def start_fetch_bkv(seq_idx, bkv_idx, bkv_sem_idx):
611
+ return _fetch_bkv(seq_idx, bkv_idx, bkv_sem_idx)
612
+
613
+ def wait_fetch_bkv(seq_idx, bkv_idx, bkv_sem_idx):
614
+ return _fetch_bkv(seq_idx, bkv_idx, bkv_sem_idx, wait=True)
615
+
616
+ def start_fetch_bq(seq_idx, bq_idx, bq_sem_idx):
617
+ return _fetch_bq(seq_idx, bq_idx, bq_sem_idx)
618
+
619
+ def wait_fetch_bq(seq_idx, bq_idx, bq_sem_idx):
620
+ return _fetch_bq(seq_idx, bq_idx, bq_sem_idx, wait=True)
621
+
622
+ def start_send_bo(seq_idx, bo_idx, bo_sem_idx):
623
+ bo_ids_ref[bo_sem_idx] = seq_idx
624
+ bo_ids_ref[bo_sem_idx + 2] = bo_idx
625
+ _send_bo(seq_idx, bo_idx, bo_sem_idx)
626
+
627
+ def wait_send_bo(bo_sem_idx):
628
+ old_seq_idx = bo_ids_ref[bo_sem_idx]
629
+ old_bo_idx = bo_ids_ref[bo_sem_idx + 2]
630
+
631
+ @pl.when(jnp.logical_and(0 <= old_seq_idx, old_seq_idx <= seq_idx))
632
+ def _():
633
+ _send_bo(old_seq_idx, old_bo_idx, bo_sem_idx, wait=True)
634
+
635
+ def start_update_kv_cache(seq_idx, bkv_sem_idx, offset, update_sz):
636
+ bkv_update_ids_ref[bkv_sem_idx] = seq_idx
637
+ bkv_update_ids_ref[bkv_sem_idx + 2] = offset
638
+ bkv_update_ids_ref[bkv_sem_idx + 4] = update_sz
639
+ _update_kv_cache(seq_idx, bkv_sem_idx, offset, update_sz)
640
+
641
+ def wait_update_kv_cache(bkv_sem_idx):
642
+ update_sz = bkv_update_ids_ref[bkv_sem_idx + 4]
643
+
644
+ @pl.when(update_sz > 0)
645
+ def _():
646
+ seq_idx = bkv_update_ids_ref[bkv_sem_idx]
647
+ offset = bkv_update_ids_ref[bkv_sem_idx + 2]
648
+ bkv_update_ids_ref[bkv_sem_idx + 4] = 0
649
+ _update_kv_cache(seq_idx,
650
+ bkv_sem_idx,
651
+ offset,
652
+ update_sz,
653
+ wait=True)
654
+
655
+ def load_bq(bq_sem_idx, kv_head_idx, *, actual_bq_sz=bq_sz):
656
+ q_ref = (bq_x2_ref.bitcast(
657
+ jnp.uint32).at[bq_sem_idx, kv_head_idx].reshape(
658
+ bq_sz * num_q_heads_per_kv_head_per_packing, head_dim))
659
+ return pltpu.bitcast(
660
+ q_ref[:actual_bq_sz * num_q_heads_per_kv_head_per_packing],
661
+ q_dtype)
662
+
663
+ def strided_load(ref, start, step):
664
+ assert get_dtype_packing(ref.dtype) == 1
665
+ assert len(ref.shape) == 2
666
+ r, l = ref.shape # noqa
667
+ assert l % 128 == 0
668
+ folds = l // 128
669
+ ref = ref.reshape(r * folds, 128)
670
+ start *= folds
671
+ step *= folds
672
+ vec = jnp.concat([ref[start + i::step] for i in range(folds)], axis=1)
673
+ return vec
674
+
675
+ def strided_load_bkv(bkv_sem_idx, start, step, *, bkv_mask):
676
+ assert start % kv_packing == 0
677
+ assert step % kv_packing == 0
678
+ start //= kv_packing
679
+ step //= kv_packing
680
+ kv_ref = (bkv_x2_ref.bitcast(jnp.uint32).at[bkv_sem_idx].reshape(
681
+ bkv_sz * step, head_dim))
682
+
683
+ if kv_packing == 1:
684
+ k = strided_load(kv_ref, start, step)
685
+ v = strided_load(kv_ref, start + 1, step)
686
+
687
+ kv_zeros = jnp.zeros_like(k)
688
+ k = lax.select(bkv_mask, k, kv_zeros)
689
+ v = lax.select(bkv_mask, v, kv_zeros)
690
+
691
+ k = pltpu.bitcast(k, kv_dtype)
692
+ v = pltpu.bitcast(v, kv_dtype)
693
+ return [(k, v)]
694
+
695
+ kv = strided_load(kv_ref, start, step)
696
+ # bkv_mask marks which rows of bkv are valid. Because kv is packed, a single
+ # 32-bit word may contain multiple k and v values from different kv heads.
+ # Even so, all values packed into one 32-bit word map to the same bkv row,
+ # so it is safe to apply bkv_mask to kv directly.
701
+ kv = lax.select(bkv_mask, kv, jnp.zeros_like(kv))
702
+ bitwidth = 32 // kv_packing
703
+
704
+ # To convert one 32-bit word into 32//N values of N bits each, the naive
+ # approach would perform 32//N separate 32-bit-to-N-bit conversions. Instead,
+ # we reduce the instruction count by splitting recursively, as a binary tree:
707
+ # 0: [32]
708
+ # 1: [16, 16]
709
+ # ...
710
+ # log2(32//N): [N, N, ... N]
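# Editor's note, example with an assumed fp8 kv dtype (N=8): 32 -> [16, 16]
# -> [8, 8, 8, 8], i.e. 3 recursive splits instead of 4 direct conversions.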
711
+
712
+ def _convert_to_target_bitwidth(val, target_bitwidth: int):
713
+ curr_dtype = val.dtype
714
+ curr_bitwidth = get_dtype_bitwidth(curr_dtype)
715
+ assert target_bitwidth != curr_bitwidth, "No conversion is needed."
716
+
717
+ # We split val into two values (left and right), each with half of the
+ # original bitwidth.
719
+ next_bitwidth = curr_bitwidth // 2
720
+ next_dtype = jnp.dtype(f"uint{next_bitwidth}")
721
+
722
+ left = val.astype(next_dtype)
723
+
724
+ # Bitwise shift is only supported in uint32.
725
+ val_u32 = pltpu.bitcast(val, jnp.uint32)
726
+ val_u32_shifted = val_u32 >> next_bitwidth
727
+ # Convert back to original dtype.
728
+ val_shifted = pltpu.bitcast(val_u32_shifted, curr_dtype)
729
+ right = val_shifted.astype(next_dtype)
730
+
731
+ if next_bitwidth == target_bitwidth:
732
+ k = pltpu.bitcast(left, kv_dtype)
733
+ v = pltpu.bitcast(right, kv_dtype)
734
+ return [(k, v)]
735
+ else:
736
+ left_out = _convert_to_target_bitwidth(
737
+ left,
738
+ target_bitwidth=target_bitwidth,
739
+ )
740
+ right_out = _convert_to_target_bitwidth(
741
+ right,
742
+ target_bitwidth=target_bitwidth,
743
+ )
744
+ return left_out + right_out
745
+
746
+ return _convert_to_target_bitwidth(kv, target_bitwidth=bitwidth)
747
+
748
+ def broadcast_minor(src, shape):
749
+ if src.shape == shape:
750
+ return src
751
+ assert src.shape[:-1] == shape[:-1]
752
+ assert src.shape[-1] % 128 == 0
753
+ target_minor = align_to(shape[-1], src.shape[-1])
754
+ # no-op concatenation.
755
+ return jnp.concatenate(
756
+ [src for _ in range(target_minor // src.shape[-1])],
757
+ axis=-1)[..., :shape[-1]]
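# Editor's note, example with assumed shapes: broadcast_minor(x, (8, 256)) for x of
# shape (8, 128) concatenates two copies of x along the last axis and slices to 256;
# when shape == x.shape it is a no-op.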
758
+
759
+ def process(static_q_len=None):
760
+ num_bkv = cdiv(kv_len, bkv_sz)
761
+ if static_q_len is None:
762
+ actual_bq_sz = bq_sz
763
+ num_bq = cdiv(q_len, actual_bq_sz)
764
+ else:
765
+ actual_bq_sz = min(bq_sz, static_q_len)
766
+ num_bq = cdiv(static_q_len, actual_bq_sz)
767
+
768
+ def get_next_bq_ids(seq_idx, bq_idx, bq_sem_idx):
769
+ next_bq_idx = bq_idx + 1
770
+ is_last_bq = next_bq_idx == num_bq
771
+ next_bq_idx = lax.select(is_last_bq, 0, next_bq_idx)
772
+ next_seq_idx = lax.select(is_last_bq, seq_idx + 1, seq_idx)
773
+ next_bq_sem_idx = lax.select(bq_sem_idx == 0, 1, 0)
774
+ return next_seq_idx, next_bq_idx, next_bq_sem_idx
775
+
776
+ def get_next_bkv_ids(seq_idx, bq_idx, bkv_idx, bkv_sem_idx):
777
+ next_bkv_idx = bkv_idx + 1
778
+ is_last_bkv = next_bkv_idx == num_bkv
779
+ next_bkv_idx = lax.select(is_last_bkv, 0, next_bkv_idx)
780
+ next_bq_idx = lax.select(is_last_bkv, bq_idx + 1, bq_idx)
781
+ is_last_bq = next_bq_idx == num_bq
782
+ next_bq_idx = lax.select(is_last_bq, 0, next_bq_idx)
783
+ next_seq_idx = lax.select(is_last_bq, seq_idx + 1, seq_idx)
784
+ next_bkv_sem_idx = lax.select(bkv_sem_idx == 0, 1, 0)
785
+ return next_seq_idx, next_bq_idx, next_bkv_idx, next_bkv_sem_idx
786
+
787
+ def compute_with_bq(bq_idx, _):
788
+ bq_sem_idx = sem_ids_ref[0]
789
+ next_seq_idx, next_bq_idx, next_bq_sem_idx = get_next_bq_ids(
790
+ seq_idx, bq_idx, bq_sem_idx)
791
+
792
+ # Prefetch next bq
793
+ @pl.when(next_seq_idx < num_seqs)
794
+ def prefetch_next_bq():
795
+ sem_ids_ref[0] = next_bq_sem_idx
796
+ start_fetch_bq(next_seq_idx, next_bq_idx, next_bq_sem_idx)
797
+
798
+ def compute_with_bkv(bkv_idx, _):
799
+ # Create bitmask for KV.
800
+ assert bkv_sz % kv_packing == 0
801
+ actual_bkv_sz = jnp.minimum(bkv_sz, kv_len - bkv_idx * bkv_sz)
802
+ bkv_shape = (bkv_sz, head_dim)
803
+ bkv_mask = lax.broadcasted_iota(jnp.int32, bkv_shape,
804
+ 0) < actual_bkv_sz
805
+
806
+ # Get next bkv ids.
807
+ bkv_sem_idx = sem_ids_ref[1]
808
+ next_seq_idx, _, next_bkv_idx, next_bkv_sem_idx = get_next_bkv_ids(
809
+ seq_idx, bq_idx, bkv_idx, bkv_sem_idx)
810
+
811
+ # Prefetch next bkv
812
+ @pl.when(next_seq_idx < num_seqs)
813
+ def prefetch_next_bkv():
814
+ sem_ids_ref[1] = next_bkv_sem_idx
815
+ start_fetch_bkv(next_seq_idx, next_bkv_idx,
816
+ next_bkv_sem_idx)
817
+
818
+ # Wait for cur bq if not ready yet
819
+ @pl.when(bkv_idx == 0)
820
+ def wait_cur_bq():
821
+ wait_fetch_bq(seq_idx, bq_idx, bq_sem_idx)
822
+
823
+ # Wait for cur bkv
824
+ offset, update_sz = wait_fetch_bkv(seq_idx, bkv_idx,
825
+ bkv_sem_idx)
826
+
827
+ # Start updating bkv to kv cache if applicable.
828
+ # Only needed in first bq loop.
829
+ @pl.when(jnp.logical_and(update_sz > 0, bq_idx == 0))
830
+ def update_cur_bkv_to_cache():
831
+ start_update_kv_cache(seq_idx, bkv_sem_idx, offset,
832
+ update_sz)
833
+
834
+ debug_print(
835
+ "[RPA debug] -----------flash attention-----------")
836
+ debug_print("[RPA debug] seq_idx={}", seq_idx)
837
+ debug_print("[RPA debug] bq_idx={}", bq_idx)
838
+ debug_print("[RPA debug] bkv_idx={}", bkv_idx)
839
+ if debug_mode:
840
+ # Skip flash attention if debug mode is enabled.
841
+ return
842
+
843
+ # Flash attention with cur bkv and bq
844
+ # NOTE: kv_packing is divided by 2 because k and v are packed together.
845
+ heads_per_load = max(1, kv_packing // 2)
846
+ for kv_head_start in range(0, actual_num_kv_heads,
847
+ heads_per_load):
848
+ bkv_lst = strided_load_bkv(
849
+ bkv_sem_idx,
850
+ kv_head_start * 2,
851
+ num_kv_heads_x2,
852
+ bkv_mask=bkv_mask,
853
+ )
854
+ assert len(bkv_lst) == heads_per_load
855
+ for i in range(heads_per_load):
856
+ kv_head_idx = kv_head_start + i
857
+ if kv_head_idx >= actual_num_kv_heads:
858
+ break
859
+ bq = load_bq(bq_sem_idx,
860
+ kv_head_idx,
861
+ actual_bq_sz=actual_bq_sz)
862
+ bk, bv = bkv_lst[i]
863
+ flash_attention(
864
+ bq,
865
+ bk,
866
+ bv,
867
+ bq_idx=bq_idx,
868
+ bkv_idx=bkv_idx,
869
+ kv_head_idx=kv_head_idx,
870
+ )
871
+
872
+ lax.fori_loop(0, num_bkv, compute_with_bkv, None, unroll=False)
873
+
874
+ # Load acc and calculate final output.
875
+ acc = acc_ref[...]
876
+ l = broadcast_minor(l_ref[...], acc.shape) # noqa
877
+ out = (lax.div(acc, l) if q_dtype == jnp.float32 else
878
+ (acc * pl.reciprocal(l, approx=True)).astype(q_dtype))
879
+
880
+ # Wait for previous bo to be fully sent before storing new bo.
881
+ bo_sem_idx = sem_ids_ref[2]
882
+ sem_ids_ref[2] = lax.select(bo_sem_idx == 0, 1, 0)
883
+ wait_send_bo(bo_sem_idx)
884
+
885
+ # Store output from acc to bo.
886
+ bo_x2_ref.at[bo_sem_idx].bitcast(jnp.int32).reshape(
887
+ actual_num_kv_heads,
888
+ bq_sz * num_q_heads_per_kv_head_per_packing,
889
+ head_dim,
890
+ )[...] = pltpu.bitcast(out, jnp.int32)
891
+
892
+ # Send cur bo
893
+ start_send_bo(seq_idx, bq_idx, bo_sem_idx)
894
+
895
+ lax.fori_loop(0, num_bq, compute_with_bq, None, unroll=False)
896
+
897
+ ### ------- Kernel start ------- ###
898
+
899
+ @pl.when(seq_idx == 0)
900
+ def prologue():
901
+ start_fetch_bq(0, 0, 0)
902
+ start_fetch_bkv(0, 0, 0)
903
+
904
+ @pl.when(seq_idx < decode_end)
905
+ def process_decode():
906
+ process(static_q_len=1)
907
+
908
+ @pl.when(jnp.logical_and(decode_end <= seq_idx, seq_idx < prefill_end))
909
+ def process_prefill():
910
+ process(static_q_len=chunk_prefill_size)
911
+
912
+ @pl.when(jnp.logical_and(prefill_end <= seq_idx, seq_idx < mixed_end))
913
+ def process_mixed():
914
+ process()
915
+
916
+ @pl.when(seq_idx == num_seqs - 1)
917
+ def epilogue():
918
+ for i in range(2):
919
+ wait_send_bo(i)
920
+ wait_update_kv_cache(i)
921
+
922
+ ### ------- Kernel end ------- ###
923
+
924
+
925
+ def merge_kv(
926
+ k: jax.
927
+ Array, # [max_num_tokens, actual_num_kv_heads, actual_head_dim],
928
+ v: jax.
929
+ Array, # [max_num_tokens, actual_num_kv_heads, actual_head_dim],
930
+ ):
931
+ assert k.shape == v.shape
932
+ assert k.dtype == v.dtype
933
+ max_num_tokens, actual_num_kv_heads, actual_head_dim = k.shape
934
+ kv_packing = get_dtype_packing(k.dtype)
935
+ actual_num_kv_heads_x2 = actual_num_kv_heads * 2
936
+ num_kv_heads_x2 = align_to(actual_num_kv_heads_x2, kv_packing)
937
+ head_dim = align_to(actual_head_dim, 128)
938
+ kv = jnp.pad(
939
+ jnp.concat([k, v],
940
+ axis=-1).reshape(max_num_tokens, actual_num_kv_heads_x2,
941
+ actual_head_dim),
942
+ (
943
+ (0, 0),
944
+ (0, num_kv_heads_x2 - actual_num_kv_heads_x2),
945
+ (0, head_dim - actual_head_dim),
946
+ ),
947
+ constant_values=0,
948
+ ).reshape(
949
+ max_num_tokens,
950
+ num_kv_heads_x2 // kv_packing,
951
+ kv_packing,
952
+ head_dim,
953
+ )
954
+ return kv
955
+
956
+
957
+ def prepare_inputs(
958
+ q: jax.Array, # [max_num_tokens, actual_num_q_heads, actual_head_dim],
959
+ k: jax.
960
+ Array, # [max_num_tokens, actual_num_kv_heads, actual_head_dim],
961
+ v: jax.
962
+ Array, # [max_num_tokens, actual_num_kv_heads, actual_head_dim],
963
+ ):
964
+ max_num_tokens, actual_num_q_heads, actual_head_dim = q.shape
965
+ actual_num_kv_heads = k.shape[1]
966
+ assert actual_num_q_heads % actual_num_kv_heads == 0
967
+ actual_num_q_heads_per_kv_head = actual_num_q_heads // actual_num_kv_heads
968
+ q_packing = get_dtype_packing(q.dtype)
969
+ num_q_heads_per_kv_head = align_to(actual_num_q_heads_per_kv_head,
970
+ q_packing)
971
+ head_dim = align_to(actual_head_dim, 128)
972
+ q = (
973
+ jnp.pad(
974
+ q.reshape(
975
+ max_num_tokens,
976
+ actual_num_kv_heads,
977
+ actual_num_q_heads_per_kv_head,
978
+ actual_head_dim,
979
+ ),
980
+ (
981
+ (0, 0),
982
+ (0, 0),
983
+ (0, num_q_heads_per_kv_head - actual_num_q_heads_per_kv_head),
984
+ (0, head_dim - actual_head_dim),
985
+ ),
986
+ constant_values=0,
987
+ ).reshape(
988
+ max_num_tokens,
989
+ actual_num_kv_heads,
990
+ num_q_heads_per_kv_head // q_packing,
991
+ q_packing,
992
+ head_dim,
993
+ )
994
+ # TODO(jevinjiang): Explore fusing swapping non-tiling axis to DMA.
995
+ .swapaxes(0, 1))
996
+ # TODO(kyuyeunk, chengjiyao): Add kv quantization here.
997
+ kv = merge_kv(k, v)
998
+ return q, kv
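# Editor's note, example with assumed shapes: for bfloat16 q of shape (T, 8, 128)
# with 2 KV heads (so 4 q heads per KV head, q_packing == 2), prepare_inputs
# returns q of shape (2, T, 2, 2, 128) (KV-head major, packed q heads) and
# kv of shape (T, 2, 2, 128) with K and V of each head packed together.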
999
+
1000
+
1001
+ def prepare_outputs(
1002
+ out, # [actual_num_kv_heads, max_num_tokens, num_q_heads_per_kv_head // q_packing, q_packing, head_dim]
1003
+ actual_num_q_heads_per_kv_head: int,
1004
+ actual_head_dim: int,
1005
+ ):
1006
+ (
1007
+ actual_num_kv_heads,
1008
+ max_num_tokens,
1009
+ num_q_heads_per_kv_head_per_q_packing,
1010
+ q_packing,
1011
+ head_dim,
1012
+ ) = out.shape
1013
+ actual_num_q_heads = actual_num_q_heads_per_kv_head * actual_num_kv_heads
1014
+ return (out.swapaxes(0, 1).reshape(
1015
+ max_num_tokens,
1016
+ actual_num_kv_heads,
1017
+ num_q_heads_per_kv_head_per_q_packing * q_packing,
1018
+ head_dim,
1019
+ )[:, :, :actual_num_q_heads_per_kv_head, :actual_head_dim].reshape(
1020
+ max_num_tokens, actual_num_q_heads, actual_head_dim))
1021
+
1022
+
1023
+ # This validation is expected to run at runtime.
1024
+ def dynamic_validate_inputs(
1025
+ queries: jax.
1026
+ Array, # [max_num_tokens, actual_num_q_heads, actual_head_dim]
1027
+ keys: jax.Array, # [max_num_tokens, actual_num_kv_heads, actual_head_dim]
1028
+ values: jax.
1029
+ Array, # [max_num_tokens, actual_num_kv_heads, actual_head_dim]
1030
+ kv_cache: jax.
1031
+ Array, # [total_num_pages, page_size, num_kv_heads_x2 // kv_packing, kv_packing, head_dim]
1032
+ kv_lens: jax.Array, # i32[max_num_seqs]
1033
+ page_indices: jax.Array, # i32[max_num_seqs * pages_per_seq]
1034
+ cu_q_lens: jax.Array, # i32[max_num_seqs + 1]
1035
+ distribution: jax.Array, # i32[3]
1036
+ *,
1037
+ sm_scale: float = 1.0,
1038
+ sliding_window: int | None = None,
1039
+ soft_cap: float | None = None,
1040
+ mask_value: float | None = DEFAULT_MASK_VALUE,
1041
+ q_scale: float | None = None,
1042
+ k_scale: float | None = None,
1043
+ v_scale: float | None = None,
1044
+ # Kernel optimization params.
1045
+ chunk_prefill_size: int | None = None,
1046
+ # Kernel tuning params.
1047
+ num_kv_pages_per_block: int | None = None,
1048
+ num_queries_per_block: int | None = None,
1049
+ vmem_limit_bytes: int | None = None,
1050
+ # Debug params.
1051
+ debug_mode: bool = False,
1052
+ ):
1053
+ q, k, v = queries, keys, values
1054
+ static_validate_inputs(
1055
+ q,
1056
+ k,
1057
+ v,
1058
+ kv_cache,
1059
+ kv_lens,
1060
+ page_indices,
1061
+ cu_q_lens,
1062
+ distribution,
1063
+ sm_scale=sm_scale,
1064
+ sliding_window=sliding_window,
1065
+ soft_cap=soft_cap,
1066
+ mask_value=mask_value,
1067
+ q_scale=q_scale,
1068
+ k_scale=k_scale,
1069
+ v_scale=v_scale,
1070
+ chunk_prefill_size=chunk_prefill_size,
1071
+ num_kv_pages_per_block=num_kv_pages_per_block,
1072
+ num_queries_per_block=num_queries_per_block,
1073
+ vmem_limit_bytes=vmem_limit_bytes,
1074
+ debug_mode=debug_mode,
1075
+ )
1076
+ max_num_tokens = q.shape[0]
1077
+ total_num_pages = kv_cache.shape[0]
1078
+ page_size = kv_cache.shape[1]
1079
+ max_num_seqs = kv_lens.shape[0]
1080
+ num_page_indices = page_indices.shape[0]
1081
+ assert num_page_indices % max_num_seqs == 0
1082
+ pages_per_seq = num_page_indices // max_num_seqs
1083
+
1084
+ i, j, k = distribution
1085
+ if not (i <= j <= k):
1086
+ raise ValueError(f"Invalid distribution: {distribution=}")
1087
+
1088
+ if k > max_num_seqs:
1089
+ raise ValueError(f"num_seqs={k} must be <= {max_num_seqs=}")
1090
+
1091
+ if cu_q_lens[k] > max_num_tokens:
1092
+ raise ValueError(
1093
+ f"Total q tokens {cu_q_lens[k]} must be <= {max_num_tokens=}.")
1094
+ for i in range(k):
1095
+ q_len = cu_q_lens[i + 1] - cu_q_lens[i]
1096
+ kv_len = kv_lens[i]
1097
+ if not (0 < q_len <= kv_len):
1098
+ raise ValueError(
1099
+ f"Require 0 < {q_len=} <= {kv_len=} at sequence {i}.")
1100
+ page_cnt = cdiv(kv_len, page_size)
1101
+ if page_cnt > pages_per_seq:
1102
+ raise ValueError(
1103
+ f"Require {page_cnt=} <= {pages_per_seq=} at sequence {i} where"
1104
+ f" {kv_len=} and {page_size=}.")
1105
+ for p in range(page_cnt):
1106
+ page_idx = page_indices[i * pages_per_seq + p]
1107
+ if not (0 <= page_idx < total_num_pages):
1108
+ raise ValueError(
1109
+ f"Require 0 <= {page_idx=} < {total_num_pages=} at sequence"
1110
+ f" {i} where {kv_len=} and {page_size=}.")
1111
+
1112
+
1113
+ # This validation is expected to run at compile (trace) time.
1114
+ def static_validate_inputs(
1115
+ queries: jax.
1116
+ Array, # [max_num_tokens, actual_num_q_heads, actual_head_dim]
1117
+ keys: jax.Array, # [max_num_tokens, actual_num_kv_heads, actual_head_dim]
1118
+ values: jax.
1119
+ Array, # [max_num_tokens, actual_num_kv_heads, actual_head_dim]
1120
+ kv_cache: jax.
1121
+ Array, # [total_num_pages, page_size, num_kv_heads_x2 // kv_packing, kv_packing, head_dim]
1122
+ kv_lens: jax.Array, # i32[max_num_seqs]
1123
+ page_indices: jax.Array, # i32[max_num_seqs * pages_per_seq]
1124
+ cu_q_lens: jax.Array, # i32[max_num_seqs + 1]
1125
+ distribution: jax.Array, # i32[3]
1126
+ *,
1127
+ sm_scale: float = 1.0,
1128
+ sliding_window: int | None = None,
1129
+ soft_cap: float | None = None,
1130
+ mask_value: float | None = DEFAULT_MASK_VALUE,
1131
+ q_scale: float | None = None,
1132
+ k_scale: float | None = None,
1133
+ v_scale: float | None = None,
1134
+ # Kernel optimization params.
1135
+ chunk_prefill_size: int | None = None,
1136
+ # Kernel tuning params.
1137
+ num_kv_pages_per_block: int | None = None,
1138
+ num_queries_per_block: int | None = None,
1139
+ vmem_limit_bytes: int | None = None,
1140
+ # Debug params.
1141
+ debug_mode: bool = False,
1142
+ ):
1143
+ """Validate inputs to the RPA kernel statically."""
1144
+ q, k, v = queries, keys, values
1145
+ if not (len(q.shape) == len(k.shape) == len(v.shape) == 3):
1146
+ raise ValueError(
1147
+ f"Expected 3D array for {q.shape=}, {k.shape=}, {v.shape=}")
1148
+ if k.shape != v.shape:
1149
+ raise ValueError(f"Expected {k.shape=} to be equal to {v.shape=}")
1150
+ if not (q.shape[0] == k.shape[0] == v.shape[0]):
1151
+ raise ValueError(
1152
+ f"Expected {q.shape[0]=} to be equal to {k.shape[0]=} and {v.shape[0]=}"
1153
+ )
1154
+ if not (q.shape[2] == k.shape[2] == v.shape[2]):
1155
+ raise ValueError(
1156
+ f"Expected {q.shape[2]=} to be equal to {k.shape[2]=} and {v.shape[2]=}"
1157
+ )
1158
+
1159
+ actual_head_dim = q.shape[2]
1160
+ actual_num_q_heads = q.shape[1]
1161
+ actual_num_kv_heads = k.shape[1]
1162
+
1163
+ if actual_num_q_heads % actual_num_kv_heads != 0:
1164
+ raise ValueError(f"Expected {actual_num_q_heads=} to be divisible by"
1165
+ f" {actual_num_kv_heads=}.")
1166
+
1167
+ (
1168
+ _,
1169
+ page_size,
1170
+ num_kv_heads_x2_per_kv_packing,
1171
+ kv_packing,
1172
+ head_dim,
1173
+ ) = kv_cache.shape
1174
+
1175
+ if head_dim != align_to(actual_head_dim, 128):
1176
+ raise ValueError(
1177
+ f"Expected {head_dim=} is equal to {align_to(actual_head_dim, 128)=}"
1178
+ )
1179
+ # Note: we expect the kv quantization happens outside of the RPA kernel.
1180
+ if not (kv_cache.dtype == k.dtype == v.dtype):
1181
+ raise ValueError(
1182
+ f"Expected {kv_cache.dtype=} to be equal to {k.dtype=} and {v.dtype=}."
1183
+ )
1184
+ # Integer kv quantization is currently not supported.
1185
+ if not jnp.issubdtype(kv_cache.dtype, jnp.floating):
1186
+ raise ValueError(f"Expected {kv_cache.dtype=} to be a floating point.")
1187
+ if kv_packing != get_dtype_packing(kv_cache.dtype):
1188
+ raise ValueError(
1189
+ f"{kv_packing=} does not match with {kv_cache.dtype=}")
1190
+
1191
+ num_kv_heads_x2 = num_kv_heads_x2_per_kv_packing * kv_packing
1192
+ if num_kv_heads_x2 % 2 != 0:
1193
+ raise ValueError(
1194
+ f"Combined KV heads must be divisible by 2, but got {num_kv_heads_x2}"
1195
+ )
1196
+ if align_to(actual_num_kv_heads * 2, kv_packing) != num_kv_heads_x2:
1197
+ raise ValueError(
1198
+ f"Invalid {num_kv_heads_x2=}, {actual_num_kv_heads=}, {kv_packing=}"
1199
+ )
1200
+
1201
+ if not (jnp.int32 == kv_lens.dtype == page_indices.dtype == cu_q_lens.dtype
1202
+ == distribution.dtype):
1203
+ raise ValueError(
1204
+ f"Expected int32 dtype for {kv_lens.dtype=}, {page_indices.dtype=},"
1205
+ f" {cu_q_lens.dtype=}, {distribution.dtype=}")
1206
+
1207
+ if not (len(kv_lens.shape) == len(page_indices.shape) == len(
1208
+ cu_q_lens.shape) == 1):
1209
+ raise ValueError(
1210
+ f"Expected 1D array for {kv_lens.shape=}, {page_indices.shape=},"
1211
+ f" {cu_q_lens.shape=}")
1212
+
1213
+ max_num_seqs = kv_lens.shape[0]
1214
+ num_page_indices = page_indices.shape[0]
1215
+ if num_page_indices % max_num_seqs != 0:
1216
+ raise ValueError(
1217
+ f"Expected {num_page_indices=} to be divisible by {max_num_seqs=}."
1218
+ )
1219
+ if cu_q_lens.shape != (max_num_seqs + 1, ):
1220
+ raise ValueError(
1221
+ f"Expected {cu_q_lens.shape=} to be ({max_num_seqs + 1},).")
1222
+ if distribution.shape != (3, ):
1223
+ raise ValueError(f"Expected {distribution.shape=} to be (3,).")
1224
+
1225
+ if page_size % kv_packing != 0:
1226
+ raise ValueError(f"{page_size=} must be divisible by {kv_packing=}.")
1227
+ if sliding_window is not None and sliding_window <= 0:
1228
+ raise ValueError(f"{sliding_window=} must be positive.")
1229
+ if soft_cap is not None and soft_cap == 0.0:
1230
+ raise ValueError(f"{soft_cap=} must not be 0.0.")
1231
+ if chunk_prefill_size is not None and chunk_prefill_size <= 0:
1232
+ raise ValueError(f"{chunk_prefill_size=} must be positive.")
1233
+ if num_kv_pages_per_block is not None:
1234
+ if num_kv_pages_per_block <= 0:
1235
+ raise ValueError(f"{num_kv_pages_per_block=} must be positive.")
1236
+ if num_queries_per_block is not None:
1237
+ if num_queries_per_block <= 0:
1238
+ raise ValueError(f"{num_queries_per_block=} must be positive.")
1239
+ if vmem_limit_bytes is not None and vmem_limit_bytes <= 0:
1240
+ raise ValueError(f"{vmem_limit_bytes=} must be positive.")
1241
+
1242
+ # No constraints for the following inputs.
1243
+ del sm_scale
1244
+ del mask_value
1245
+ del q_scale
1246
+ del k_scale
1247
+ del v_scale
1248
+ del debug_mode
1249
+
1250
+
1251
+ @functools.partial(
+     jax.jit,
+     static_argnames=(
+         "sm_scale",
+         "sliding_window",
+         "soft_cap",
+         "mask_value",
+         "q_scale",
+         "k_scale",
+         "v_scale",
+         "chunk_prefill_size",
+         "num_kv_pages_per_block",
+         "num_queries_per_block",
+         "vmem_limit_bytes",
+         "debug_mode",
+     ),
+     donate_argnames=("kv_cache", ),
+ )
+ def ragged_paged_attention(
+     queries: jax.Array,  # [max_num_tokens, actual_num_q_heads, actual_head_dim]
+     keys: jax.Array,  # [max_num_tokens, actual_num_kv_heads, actual_head_dim]
+     values: jax.Array,  # [max_num_tokens, actual_num_kv_heads, actual_head_dim]
+     kv_cache: jax.Array,  # [total_num_pages, page_size, num_kv_heads_x2 // kv_packing, kv_packing, head_dim]
+     kv_lens: jax.Array,  # i32[max_num_seqs]
+     page_indices: jax.Array,  # i32[max_num_seqs * pages_per_seq]
+     cu_q_lens: jax.Array,  # i32[max_num_seqs + 1]
+     distribution: jax.Array,  # i32[3]
+     *,
+     sm_scale: float = 1.0,
+     sliding_window: int | None = None,
+     soft_cap: float | None = None,
+     mask_value: float | None = DEFAULT_MASK_VALUE,
+     q_scale: float | None = None,
+     k_scale: float | None = None,
+     v_scale: float | None = None,
+     # Kernel optimization params.
+     chunk_prefill_size: int | None = None,
+     # Kernel tuning params.
+     num_kv_pages_per_block: int | None = None,
+     num_queries_per_block: int | None = None,
+     vmem_limit_bytes: int | None = None,
+     # Debug params.
+     debug_mode: bool = False,
+ ):
+ """Ragged paged attention that supports mixed prefill and decode.
1299
+
1300
+ Args:
1301
+ queries: concatenated all sequences' queries.
1302
+ keys: concatenated all sequences' keys (quantized).
1303
+ values: concatenated all sequences' values (quantized).
1304
+ kv_cache: paged KV cache with TPU-friendly shape.
1305
+ kv_lens: padded kv lengths. Only the first num_seqs values are valid.
1306
+ page_indices: flattened page indices look-up table by (seq_id, page_id).
1307
+ cu_q_lens: the cumulative sum of the effective query lengths. Similar to
1308
+ kv_lens, only the first num_seqs+1 values are valid.
1309
+ distribution: (i, j, k) represents that sequences[0:i] are decode-only,
1310
+ sequences[i:j] are chunked-prefill-only, and sequences[j:k] are mixed. The
1311
+ k is also the total number of sequences.
1312
+ actual_head_dim: the actual head size of the attention. Here we assume k and
1313
+ v have the same actual head size.
1314
+ sm_scale: the softmax scale which will be applied to the Q@K^T.
1315
+ sliding_window: the sliding window size for the attention.
1316
+ soft_cap: the logit soft cap for the attention.
1317
+ mask_value: mask value for causal mask.
1318
+ k_scale: the scale for the key cache.
1319
+ v_scale: the scale for the value cache.
1320
+ num_kv_pages_per_block: number of kv pages to be processed in one flash
1321
+ attention block in the pallas kernel.
1322
+ num_queries_per_block: number of kv pages to be processed in one flash
1323
+ attention block in the pallas kernel.
1324
+ vmem_limit_bytes: the vmem limit for the pallas kernel.
1325
+ debug_mode: if true, RPA does not issue any DMAs or run flash attention but
1326
+ print debug info. Need to compile with `--xla_tpu_enable_log_recorder`.
1327
+
1328
+ Returns:
1329
+ The output of the attention.
1330
+ """
1331
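+     # Illustrative metadata layout for three sequences, assuming sequence 0 is
+     # decode-only (1 new token, 40 total KV tokens), sequence 1 is
+     # chunked-prefill-only (16 new tokens), and sequence 2 is mixed (8 new
+     # tokens, 30 total KV tokens):
+     #   cu_q_lens[:4] = [0, 1, 17, 25]
+     #   kv_lens[:3]   = [40, 16, 30]
+     #   distribution  = [1, 2, 3]  # i=1, j=2, k=3 (k = total sequences)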
+     q, k, v = queries, keys, values
+     static_validate_inputs(
+         q,
+         k,
+         v,
+         kv_cache,
+         kv_lens,
+         page_indices,
+         cu_q_lens,
+         distribution,
+         sm_scale=sm_scale,
+         sliding_window=sliding_window,
+         soft_cap=soft_cap,
+         mask_value=mask_value,
+         q_scale=q_scale,
+         k_scale=k_scale,
+         v_scale=v_scale,
+         chunk_prefill_size=chunk_prefill_size,
+         num_kv_pages_per_block=num_kv_pages_per_block,
+         num_queries_per_block=num_queries_per_block,
+         vmem_limit_bytes=vmem_limit_bytes,
+     )
+
+     actual_num_q_heads = q.shape[1]
+     actual_head_dim = q.shape[2]
+     actual_num_kv_heads = k.shape[1]
+
+     actual_num_q_heads_per_kv_head = actual_num_q_heads // actual_num_kv_heads
+     q, kv = prepare_inputs(q, k, v)
+     (
+         _,
+         max_num_tokens,
+         num_q_heads_per_kv_head_per_q_packing,
+         q_packing,
+         head_dim,
+     ) = q.shape
+     page_size = kv_cache.shape[1]
+     max_num_seqs = kv_lens.shape[0]
+     num_page_indices = page_indices.shape[0]
+     assert num_page_indices % max_num_seqs == 0
+     pages_per_seq = num_page_indices // max_num_seqs
+     num_q_heads_per_kv_head = num_q_heads_per_kv_head_per_q_packing * q_packing
+
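+     # Fall back to tuned block sizes when the caller does not pin
+     # num_kv_pages_per_block / num_queries_per_block explicitly.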
+     bkv_p = num_kv_pages_per_block
+     bq_sz = num_queries_per_block
+     if bq_sz is None or bkv_p is None:
+         bkv_p, bq_sz = get_tuned_block_sizes(
+             q.dtype,
+             kv_cache.dtype,
+             actual_num_q_heads,
+             actual_num_kv_heads,
+             head_dim,
+             page_size,
+             max_num_tokens,
+             pages_per_seq,
+         )
+     bkv_sz = bkv_p * page_size
+     if vmem_limit_bytes is None:
+         # TODO (jevinjiang/jacobplatin): change this to use
+         # `get_vmem_estimate_bytes` when VREG spilling is fixed.
+         vmem_limit_bytes = DEFAULT_VMEM_LIMIT_BYTES
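+     # distribution[2] is the total number of sequences, so the kernel takes one
+     # (sequential) grid step per sequence.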
+     grid = (distribution[2], )
+
+     in_specs = [
+         pl.BlockSpec(memory_space=pltpu.HBM),
+         pl.BlockSpec(memory_space=pltpu.HBM),
+         pl.BlockSpec(memory_space=pltpu.HBM),
+     ]
+
+     out_specs = [
+         pl.BlockSpec(memory_space=pltpu.HBM),
+         pl.BlockSpec(memory_space=pltpu.HBM),
+     ]
+
+     bkv_double_buf = pltpu.VMEM(
+         (2, bkv_sz, *kv_cache.shape[2:]),
+         kv_cache.dtype,
+     )
+
+     bq_double_buf = pltpu.VMEM(
+         (2, actual_num_kv_heads, bq_sz, *q.shape[2:]),
+         q.dtype,
+     )
+
+     bo_double_buf = bq_double_buf
+
+     l_scratch = pltpu.VMEM(
+         (actual_num_kv_heads, bq_sz * num_q_heads_per_kv_head, 128),
+         jnp.float32,
+     )
+     m_scratch = l_scratch
+
+     acc_scratch = pltpu.VMEM(
+         (actual_num_kv_heads, bq_sz * num_q_heads_per_kv_head, head_dim),
+         jnp.float32,
+     )
+
+     scratch_shapes = [
+         bkv_double_buf,  # Double buffering for kv block.
+         bq_double_buf,  # Double buffering for q block.
+         bo_double_buf,  # Double buffering for output block.
+         # Semaphores for double buffering of bkv, bq, bo and bkv_update.
+         pltpu.SemaphoreType.DMA((4, 2)),
+         # Intermediate buffers per kv head for flash attention.
+         l_scratch,
+         m_scratch,
+         acc_scratch,
+     ]
+
+     scalar_prefetches = (
+         kv_lens,
+         # TODO(jevinjiang): can we use ragged page_indices to save some smem?
+         page_indices,
+         cu_q_lens,
+         distribution,
+         # (bq_sem_idx, bkv_sem_idx, bo_sem_idx)
+         jnp.zeros((3, ), jnp.int32),
+         # (bo_sem_0_seq_idx, bo_sem_1_seq_idx, bo_sem_0_bo_idx, bo_sem_1_bo_idx)
+         jnp.full((4, ), -1, jnp.int32),
+         # (bkv_sem_0_seq_idx, bkv_sem_1_seq_idx, bkv_sem_0_offset, bkv_sem_1_offset, bkv_sem_0_sz, bkv_sem_1_sz)
+         jnp.full((6, ), -1, jnp.int32),
+     )
+
+     scope_name = f"RPA-bq_{bq_sz}-bkvp_{bkv_p}-p_{page_size}"
+     kernel = jax.named_scope(scope_name)(
+         pl.pallas_call(
+             functools.partial(
+                 _ragged_paged_attention_kernel,
+                 sm_scale=sm_scale,
+                 sliding_window=sliding_window,
+                 soft_cap=soft_cap,
+                 mask_value=mask_value,
+                 q_scale=q_scale,
+                 k_scale=k_scale,
+                 v_scale=v_scale,
+                 chunk_prefill_size=chunk_prefill_size,
+                 bq_sz=bq_sz,
+                 bkv_p=bkv_p,
+                 debug_mode=debug_mode,
+             ),
+             grid_spec=pltpu.PrefetchScalarGridSpec(
+                 num_scalar_prefetch=len(scalar_prefetches),
+                 in_specs=in_specs,
+                 out_specs=out_specs,
+                 grid=grid,
+                 scratch_shapes=scratch_shapes,
+             ),
+             compiler_params=pltpu.CompilerParams(
+                 # TODO(jevinjiang): since each sequence depends on the previous
+                 # one, we need some extra work to support Megacore mode.
+                 dimension_semantics=("arbitrary", ),
+                 vmem_limit_bytes=vmem_limit_bytes,
+             ),
+             out_shape=[
+                 jax.ShapeDtypeStruct(shape=q.shape, dtype=q.dtype),
+                 jax.ShapeDtypeStruct(shape=kv_cache.shape, dtype=kv_cache.dtype),
+             ],
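+             # Alias output 0 (the attention output) with input 7 (q) and
+             # output 1 (the updated KV cache) with input 9 (kv_cache);
+             # inputs 0-6 are the scalar prefetches.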
+             input_output_aliases={
+                 7: 0,
+                 9: 1
+             },
+             name=scope_name,
+     ))
+
+     output, updated_kv_cache = kernel(*scalar_prefetches, q, kv, kv_cache)
+     return (
+         prepare_outputs(output, actual_num_q_heads_per_kv_head, actual_head_dim),
+         updated_kv_cache,
+     )