tpu-inference 0.11.1.dev202511220812__py3-none-any.whl → 0.13.2.dev20251230__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

Files changed (257)
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_disagg_utils.py +14 -0
  4. tests/core/test_dp_scheduler.py +650 -768
  5. tests/core/test_init.py +14 -0
  6. tests/distributed/__init__.py +13 -0
  7. tests/distributed/test_distributed_utils.py +120 -0
  8. tests/distributed/test_tpu_connector.py +478 -0
  9. tests/e2e/__init__.py +13 -0
  10. tests/e2e/test_async_scheduler.py +211 -0
  11. tests/e2e/test_data_parallel.py +289 -0
  12. tests/e2e/test_hybrid_kvcache.py +219 -0
  13. tests/e2e/test_local_disagg.py +257 -0
  14. tests/e2e/test_model_loader.py +268 -0
  15. tests/e2e/test_multi_modal_inference.py +111 -0
  16. tests/e2e/test_pipeline_parallel.py +265 -0
  17. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  18. tests/e2e/test_sampling_params.py +269 -0
  19. tests/e2e/test_speculative_decoding.py +311 -0
  20. tests/e2e/test_structured_decoding.py +46 -0
  21. tests/executors/__init__.py +13 -0
  22. tests/executors/test_ray_distributed_executor.py +199 -0
  23. tests/experimental/__init__.py +13 -0
  24. tests/experimental/test_llama3_jax_stashed.py +208 -0
  25. tests/kernels/__init__.py +13 -0
  26. tests/kernels/collectives/__init__.py +13 -0
  27. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  28. tests/kernels/fused_moe_v1_test.py +317 -34
  29. tests/kernels/gmm_test.py +205 -0
  30. tests/kernels/mla_v1_test.py +143 -41
  31. tests/kernels/quantized_matmul_kernel_test.py +2 -34
  32. tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
  33. tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
  34. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +17 -1
  35. tests/kernels/ragged_paged_attention_kernel_v3_test.py +17 -1
  36. tests/layers/__init__.py +13 -0
  37. tests/layers/common/__init__.py +13 -0
  38. tests/layers/common/test_attention_interface.py +156 -0
  39. tests/layers/common/test_quantization.py +149 -0
  40. tests/layers/jax/__init__.py +13 -0
  41. tests/layers/jax/attention/__init__.py +13 -0
  42. tests/layers/jax/attention/test_common_attention.py +103 -0
  43. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  44. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  45. tests/layers/jax/moe/__init__.py +13 -0
  46. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  47. tests/layers/jax/sample/__init__.py +13 -0
  48. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  49. tests/layers/jax/sample/test_sampling.py +115 -0
  50. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  51. tests/layers/jax/test_layers.py +155 -0
  52. tests/{test_quantization.py → layers/jax/test_qwix.py} +183 -50
  53. tests/layers/jax/test_rope.py +93 -0
  54. tests/layers/jax/test_sharding.py +159 -0
  55. tests/layers/jax/test_transformer_block.py +152 -0
  56. tests/layers/vllm/__init__.py +13 -0
  57. tests/layers/vllm/test_attention.py +363 -0
  58. tests/layers/vllm/test_awq.py +406 -0
  59. tests/layers/vllm/test_compressed_tensors_moe.py +199 -0
  60. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +441 -0
  61. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +443 -0
  62. tests/layers/vllm/test_fp8.py +17 -0
  63. tests/layers/vllm/test_mxfp4.py +320 -0
  64. tests/layers/vllm/test_unquantized.py +662 -0
  65. tests/layers/vllm/utils.py +87 -0
  66. tests/lora/__init__.py +13 -0
  67. tests/lora/conftest.py +14 -0
  68. tests/lora/test_bgmv.py +14 -0
  69. tests/lora/test_layers.py +26 -6
  70. tests/lora/test_lora.py +15 -1
  71. tests/lora/test_lora_perf.py +67 -0
  72. tests/models/__init__.py +13 -0
  73. tests/models/common/__init__.py +13 -0
  74. tests/models/common/test_model_loader.py +455 -0
  75. tests/models/jax/__init__.py +13 -0
  76. tests/models/jax/test_deepseek_v3.py +401 -0
  77. tests/models/jax/test_llama3.py +184 -0
  78. tests/models/jax/test_llama4.py +298 -0
  79. tests/models/jax/test_llama_eagle3.py +197 -0
  80. tests/models/jax/test_llama_guard_4.py +242 -0
  81. tests/models/jax/test_qwen2.py +172 -0
  82. tests/models/jax/test_qwen2_5_vl.py +605 -0
  83. tests/models/jax/test_qwen3.py +169 -0
  84. tests/models/jax/test_weight_loading.py +180 -0
  85. tests/models/jax/utils/__init__.py +13 -0
  86. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  87. tests/platforms/__init__.py +13 -0
  88. tests/platforms/test_tpu_platform.py +54 -0
  89. tests/runner/__init__.py +13 -0
  90. tests/runner/test_block_table.py +395 -0
  91. tests/runner/test_input_batch.py +226 -0
  92. tests/runner/test_kv_cache.py +220 -0
  93. tests/runner/test_kv_cache_manager.py +498 -0
  94. tests/runner/test_multimodal_manager.py +429 -0
  95. tests/runner/test_persistent_batch_manager.py +84 -0
  96. tests/runner/test_speculative_decoding_manager.py +368 -0
  97. tests/runner/test_structured_decoding_manager.py +220 -0
  98. tests/runner/test_tpu_runner.py +261 -0
  99. tests/runner/test_tpu_runner_dp.py +1099 -0
  100. tests/runner/test_tpu_runner_mesh.py +200 -0
  101. tests/runner/test_utils.py +411 -0
  102. tests/spec_decode/__init__.py +13 -0
  103. tests/spec_decode/test_eagle3.py +311 -0
  104. tests/test_base.py +14 -0
  105. tests/test_envs.py +110 -12
  106. tests/test_tpu_info.py +14 -0
  107. tests/test_utils.py +2 -45
  108. tests/worker/__init__.py +13 -0
  109. tests/worker/tpu_worker_test.py +414 -0
  110. tpu_inference/__init__.py +14 -0
  111. tpu_inference/core/__init__.py +13 -0
  112. tpu_inference/core/sched/__init__.py +13 -0
  113. tpu_inference/core/sched/dp_scheduler.py +372 -56
  114. tpu_inference/distributed/__init__.py +13 -0
  115. tpu_inference/distributed/jax_parallel_state.py +14 -0
  116. tpu_inference/distributed/tpu_connector.py +15 -10
  117. tpu_inference/distributed/utils.py +56 -4
  118. tpu_inference/envs.py +92 -8
  119. tpu_inference/executors/__init__.py +13 -0
  120. tpu_inference/executors/ray_distributed_executor.py +25 -4
  121. tpu_inference/experimental/__init__.py +13 -0
  122. tpu_inference/experimental/llama3_jax_stashed.py +14 -0
  123. tpu_inference/kernels/__init__.py +13 -0
  124. tpu_inference/kernels/collectives/__init__.py +13 -0
  125. tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
  126. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
  127. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  128. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  129. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  130. tpu_inference/kernels/fused_moe/v1/kernel.py +807 -230
  131. tpu_inference/kernels/megablox/__init__.py +13 -0
  132. tpu_inference/kernels/megablox/common.py +54 -0
  133. tpu_inference/kernels/megablox/gmm.py +646 -0
  134. tpu_inference/kernels/mla/__init__.py +13 -0
  135. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  136. tpu_inference/kernels/mla/v1/kernel.py +117 -145
  137. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  138. tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
  139. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  140. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  141. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
  142. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
  143. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  144. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +194 -101
  145. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +218 -137
  146. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
  147. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +376 -195
  148. tpu_inference/kernels/ragged_paged_attention/v3/util.py +15 -1
  149. tpu_inference/layers/__init__.py +13 -0
  150. tpu_inference/layers/common/__init__.py +13 -0
  151. tpu_inference/layers/common/attention_interface.py +25 -12
  152. tpu_inference/layers/common/attention_metadata.py +14 -0
  153. tpu_inference/layers/common/fused_moe_gmm.py +506 -0
  154. tpu_inference/layers/common/quant_methods.py +15 -0
  155. tpu_inference/layers/common/quantization.py +282 -0
  156. tpu_inference/layers/common/sharding.py +32 -9
  157. tpu_inference/layers/common/utils.py +94 -0
  158. tpu_inference/layers/jax/__init__.py +13 -0
  159. tpu_inference/layers/jax/attention/__init__.py +13 -0
  160. tpu_inference/layers/jax/attention/attention.py +19 -6
  161. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +270 -77
  162. tpu_inference/layers/jax/attention/gpt_oss_attention.py +24 -11
  163. tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
  164. tpu_inference/layers/jax/base.py +14 -0
  165. tpu_inference/layers/jax/constants.py +13 -0
  166. tpu_inference/layers/jax/layers.py +14 -0
  167. tpu_inference/layers/jax/misc.py +14 -0
  168. tpu_inference/layers/jax/moe/__init__.py +13 -0
  169. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
  170. tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
  171. tpu_inference/layers/jax/moe/moe.py +43 -3
  172. tpu_inference/layers/jax/pp_utils.py +53 -0
  173. tpu_inference/layers/jax/rope.py +14 -0
  174. tpu_inference/layers/jax/rope_interface.py +14 -0
  175. tpu_inference/layers/jax/sample/__init__.py +13 -0
  176. tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
  177. tpu_inference/layers/jax/sample/sampling.py +15 -1
  178. tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
  179. tpu_inference/layers/jax/transformer_block.py +14 -0
  180. tpu_inference/layers/vllm/__init__.py +13 -0
  181. tpu_inference/layers/vllm/attention.py +4 -4
  182. tpu_inference/layers/vllm/fused_moe.py +101 -494
  183. tpu_inference/layers/vllm/linear.py +64 -0
  184. tpu_inference/layers/vllm/process_weights/__init__.py +13 -0
  185. tpu_inference/layers/vllm/{sharding.py → process_weights/cleanup_sharding.py} +24 -15
  186. tpu_inference/layers/vllm/process_weights/fused_moe_weights.py +369 -0
  187. tpu_inference/layers/vllm/process_weights/linear_weights.py +174 -0
  188. tpu_inference/layers/vllm/quantization/__init__.py +19 -3
  189. tpu_inference/layers/vllm/quantization/awq.py +96 -82
  190. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  191. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +23 -8
  192. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +172 -176
  193. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  194. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +111 -91
  195. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +79 -43
  196. tpu_inference/layers/vllm/quantization/{common.py → configs.py} +42 -25
  197. tpu_inference/layers/vllm/quantization/fp8.py +119 -0
  198. tpu_inference/layers/vllm/quantization/mxfp4.py +137 -178
  199. tpu_inference/layers/vllm/quantization/unquantized.py +157 -233
  200. tpu_inference/lora/__init__.py +13 -0
  201. tpu_inference/lora/torch_lora_ops.py +8 -13
  202. tpu_inference/models/__init__.py +13 -0
  203. tpu_inference/models/common/__init__.py +13 -0
  204. tpu_inference/models/common/model_loader.py +112 -35
  205. tpu_inference/models/jax/__init__.py +13 -0
  206. tpu_inference/models/jax/deepseek_v3.py +267 -157
  207. tpu_inference/models/jax/gpt_oss.py +26 -10
  208. tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
  209. tpu_inference/models/jax/llama3.py +99 -36
  210. tpu_inference/models/jax/llama4.py +14 -0
  211. tpu_inference/models/jax/llama_eagle3.py +18 -5
  212. tpu_inference/models/jax/llama_guard_4.py +15 -1
  213. tpu_inference/models/jax/qwen2.py +17 -2
  214. tpu_inference/models/jax/qwen2_5_vl.py +179 -51
  215. tpu_inference/models/jax/qwen3.py +17 -2
  216. tpu_inference/models/jax/utils/__init__.py +13 -0
  217. tpu_inference/models/jax/utils/file_utils.py +14 -0
  218. tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
  219. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  220. tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +92 -32
  221. tpu_inference/models/jax/utils/weight_utils.py +234 -155
  222. tpu_inference/models/vllm/__init__.py +13 -0
  223. tpu_inference/models/vllm/vllm_model_wrapper.py +32 -8
  224. tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
  225. tpu_inference/platforms/__init__.py +14 -0
  226. tpu_inference/platforms/tpu_platform.py +51 -72
  227. tpu_inference/runner/__init__.py +13 -0
  228. tpu_inference/runner/compilation_manager.py +180 -80
  229. tpu_inference/runner/kv_cache.py +54 -20
  230. tpu_inference/runner/kv_cache_manager.py +55 -33
  231. tpu_inference/runner/lora_utils.py +16 -1
  232. tpu_inference/runner/multimodal_manager.py +16 -2
  233. tpu_inference/runner/persistent_batch_manager.py +54 -2
  234. tpu_inference/runner/speculative_decoding_manager.py +14 -0
  235. tpu_inference/runner/structured_decoding_manager.py +16 -3
  236. tpu_inference/runner/tpu_runner.py +124 -61
  237. tpu_inference/runner/utils.py +2 -2
  238. tpu_inference/spec_decode/__init__.py +13 -0
  239. tpu_inference/spec_decode/jax/__init__.py +13 -0
  240. tpu_inference/spec_decode/jax/eagle3.py +84 -22
  241. tpu_inference/tpu_info.py +14 -0
  242. tpu_inference/utils.py +72 -44
  243. tpu_inference/worker/__init__.py +13 -0
  244. tpu_inference/worker/tpu_worker.py +66 -52
  245. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/METADATA +8 -9
  246. tpu_inference-0.13.2.dev20251230.dist-info/RECORD +266 -0
  247. tpu_inference/layers/vllm/linear_common.py +0 -186
  248. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  249. tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
  250. tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
  251. tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
  252. tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
  253. tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
  254. tpu_inference-0.11.1.dev202511220812.dist-info/RECORD +0 -174
  255. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/WHEEL +0 -0
  256. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/licenses/LICENSE +0 -0
  257. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,13 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
@@ -655,7 +655,8 @@ def cdiv(a, b):
 
 
 def get_dtype_packing(dtype):
-    bits = dtypes.bit_width(dtype)
+    bits = (dtypes.bit_width(dtype)
+            if hasattr(dtypes, "bit_width") else dtypes.itemsize_bits(dtype))
     return 32 // bits
 
 
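
The hasattr guard above keeps the helper working across JAX versions that expose the dtype bit width under different names (dtypes.bit_width vs. dtypes.itemsize_bits), presumably so sub-byte dtypes like int4 also report a correct width. A minimal sketch of the packing rule itself, using numpy's itemsize as a stand-in for those helpers (itemsize cannot express sub-byte dtypes, so this toy only covers whole-byte types):

import numpy as np

def get_dtype_packing(dtype):
    # Number of values of `dtype` that fit into one 32-bit word.
    bits = np.dtype(dtype).itemsize * 8
    return 32 // bits

assert get_dtype_packing(np.float32) == 1
assert get_dtype_packing(np.float16) == 2  # 2 x 16-bit per 32-bit word
assert get_dtype_packing(np.int8) == 4     # 4 x 8-bit per 32-bit word
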
@@ -200,7 +200,8 @@ def _prev_power_of_2(n: int) -> int:
 def _get_page_size_bytes(block_size: int, num_combined_kv_heads: int,
                          head_size: int, kv_cache_dtype) -> int:
     """Returns the size in bytes of one page of the KV cache."""
-    kv_cache_dtype_bit_size = dtypes.bit_width(kv_cache_dtype)
+    kv_cache_dtype_bit_size = (dtypes.bit_width(kv_cache_dtype) if hasattr(
+        dtypes, "bit_width") else dtypes.itemsize_bits(kv_cache_dtype))
     padded_head_size = _ceil_div(
         head_size, TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT
 
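
Only the bit-width lookup changes in this hunk; the surrounding code pads the head size up to the TPU alignment before sizing a page. A worked example under assumed values: TPU_HEAD_SIZE_ALIGNMENT = 128 and a final block_size * heads * padded_head_size * bytes product, neither of which this hunk shows in full:

TPU_HEAD_SIZE_ALIGNMENT = 128  # assumed alignment constant

def _ceil_div(a, b):
    return (a + b - 1) // b

block_size = 16               # tokens per KV-cache page
num_combined_kv_heads = 16    # K and V heads stored together
head_size = 96                # padded up to 128 below
kv_cache_dtype_bit_size = 16  # e.g. bfloat16

padded_head_size = _ceil_div(
    head_size, TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT
page_size_bytes = (block_size * num_combined_kv_heads * padded_head_size *
                   kv_cache_dtype_bit_size // 8)
print(page_size_bytes)  # 16 * 16 * 128 * 2 bytes = 65536
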
@@ -0,0 +1,13 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
@@ -1,3 +1,16 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 """TPU-Friendly Ragged Paged Attention kernel.
 
 This kernel offers a highly optimized implementation of ragged paged attention,
@@ -300,6 +313,22 @@ def _ragged_paged_attention_kernel(
     q_len = q_end - q_start
     kv_len = kv_lens_ref[seq_idx]
 
+    if sliding_window is None:
+        bkv_idx_start = next_seq_bkv_idx_start = 0
+    else:
+        bkv_idx_start = jnp.maximum(kv_len - q_len - sliding_window,
+                                    0) // bkv_sz
+
+        # If seq_idx + 1 == num_seqs, kv_lens_ref[seq_idx + 1] would trigger
+        # an out-of-bounds error. To avoid this, we cap next_seq_idx at
+        # num_seqs - 1.
+        next_seq_idx = jnp.minimum(seq_idx + 1, num_seqs - 1)
+        next_kv_len = kv_lens_ref[next_seq_idx]
+        next_q_len = cu_q_lens_ref[next_seq_idx + 1] - q_end
+        next_seq_bkv_idx_start = (
+            jnp.maximum(next_kv_len - next_q_len - sliding_window, 0) //
+            bkv_sz)
+
     def debug_print(msg, *args):
         if debug_mode:
             pl.debug_print(msg, *args)
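
With a sliding window, KV blocks that fall entirely before the window can never contribute to the output, so the kernel now starts at a later block index instead of always at 0. A plain-Python version of the same arithmetic with example numbers:

kv_len, q_len, sliding_window, bkv_sz = 1000, 100, 256, 128

# The earliest query token in this request sits at KV position
# kv_len - q_len, so nothing before kv_len - q_len - sliding_window can
# ever be attended to; whole blocks before that point are skipped.
bkv_idx_start = max(kv_len - q_len - sliding_window, 0) // bkv_sz
print(bkv_idx_start)  # 5 -> blocks 0..4 are never fetched or computed
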
@@ -319,7 +348,7 @@ def _ragged_paged_attention_kernel(
     debug_print("[RPA debug] q_len={}", q_len)
     debug_print("[RPA debug] kv_len={}", kv_len)
 
-    def flash_attention(
+    def flash_attention_step1_qk_softmax(
         q,  # [actual_bq_sz * num_q_heads_per_kv_head, head_dim]
         k,  # [bkv_sz, head_dim]
         v,  # [bkv_sz, head_dim]
@@ -335,11 +364,10 @@ def _ragged_paged_attention_kernel(
         assert k.dtype == v.dtype
         head_l_ref = l_ref.at[kv_head_idx, :q.shape[0]]
         head_m_ref = m_ref.at[kv_head_idx, :q.shape[0]]
-        head_acc_ref = acc_ref.at[kv_head_idx, :q.shape[0]]
 
         def load_with_init(ref, init_val):
-            return jnp.where(bkv_idx == 0, jnp.full_like(ref, init_val),
-                             ref[...])
+            return jnp.where(bkv_idx == bkv_idx_start,
+                             jnp.full_like(ref, init_val), ref[...])
 
         # Follow FlashAttention-2 forward pass.
         if q_scale is not None:
@@ -357,34 +385,52 @@ def _ragged_paged_attention_kernel(
             s *= k_scale
         if q_scale is not None:
             s *= q_scale
+        if soft_cap is not None:
+            s = soft_cap * jnp.tanh(s / soft_cap)
 
         q_span = (kv_len - q_len + bq_idx * bq_sz +
                   lax.broadcasted_iota(jnp.int32, s.shape, 0) //
                   num_q_heads_per_kv_head)
         k_span = bkv_idx * bkv_sz + lax.broadcasted_iota(jnp.int32, s.shape, 1)
-        mask = q_span < k_span
-        # TODO(jevinjiang, xiowei): reduce pages_per_seq based on sliding_window.
+        mask = k_span <= q_span
+
         if sliding_window is not None:
-            mask = jnp.logical_or(mask, q_span - sliding_window >= k_span)
+            mask = jnp.logical_and(mask, q_span - sliding_window < k_span)
 
-        if soft_cap is not None:
-            s = soft_cap * jnp.tanh(s / soft_cap)
-        s += jnp.where(mask, mask_value, 0.0)
+        s = jnp.where(mask, s, mask_value)
         s_rowmax = jnp.max(s, axis=1, keepdims=True)
+
         m_prev = load_with_init(head_m_ref, -jnp.inf)
         m_curr = jnp.maximum(m_prev, s_rowmax)
         head_m_ref[...] = m_curr
         p = jnp.exp(s - broadcast_minor(m_curr, s.shape))
 
-        pv = jnp.einsum("nm,md->nd", p, v, preferred_element_type=jnp.float32)
-        if v_scale is not None:
-            pv *= v_scale
-
         p_rowsum = jnp.sum(p, axis=1, keepdims=True)
         exp_m_diff = jnp.exp(m_prev - m_curr)
         l_prev = load_with_init(head_l_ref, 0.0)
         l_curr = exp_m_diff * l_prev + p_rowsum
         head_l_ref[...] = l_curr
+
+        return p, exp_m_diff
+
+    def flash_attention_step2_pv(
+        q_shape_0,
+        v,  # [bkv_sz, head_dim]
+        p,  # from step1
+        exp_m_diff,  # from step1
+        *,
+        bkv_idx,
+        kv_head_idx,
+    ):
+        head_acc_ref = acc_ref.at[kv_head_idx, :q_shape_0]
+
+        def load_with_init(ref, init_val):
+            return jnp.where(bkv_idx == 0, jnp.full_like(ref, init_val),
+                             ref[...])
+
+        pv = jnp.einsum("nm,md->nd", p, v, preferred_element_type=jnp.float32)
+        if v_scale is not None:
+            pv *= v_scale
         o_prev = load_with_init(head_acc_ref, 0.0)
         o_curr = broadcast_minor(exp_m_diff, o_prev.shape) * o_prev + pv
         head_acc_ref[...] = o_curr
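
The step1/step2 split leaves the standard FlashAttention-2 online-softmax recurrence intact: step1 produces the probabilities p and the rescale factor exp(m_prev - m_curr), and step2 applies them to the running accumulator. A runnable numpy sketch of that recurrence (toy shapes, no masking or scaling), checked against the direct softmax:

import numpy as np

rng = np.random.default_rng(0)
q = rng.normal(size=(4, 8))   # [q_rows, head_dim]
k = rng.normal(size=(32, 8))  # [kv_rows, head_dim]
v = rng.normal(size=(32, 8))

m = np.full((4, 1), -np.inf)  # running row max       (head_m_ref)
l = np.zeros((4, 1))          # running softmax sum   (head_l_ref)
acc = np.zeros((4, 8))        # running weighted sum  (head_acc_ref)

for blk in range(0, 32, 8):   # stream KV in blocks, like bkv blocks
    s = q @ k[blk:blk + 8].T                           # step1: scores
    m_curr = np.maximum(m, s.max(axis=1, keepdims=True))
    p = np.exp(s - m_curr)
    exp_m_diff = np.exp(m - m_curr)                    # rescale factor
    l = exp_m_diff * l + p.sum(axis=1, keepdims=True)
    m = m_curr
    acc = exp_m_diff * acc + p @ v[blk:blk + 8]        # step2: pv update

expected = np.exp(q @ k.T - (q @ k.T).max(axis=1, keepdims=True))
expected = (expected / expected.sum(axis=1, keepdims=True)) @ v
assert np.allclose(acc / l, expected)
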
@@ -463,19 +509,16 @@ def _ragged_paged_attention_kernel(
                 unroll=False,
             )
 
-            # Fetch kv directly from new kv.
-            @pl.when(bkv_sz_frm_new > 0)
-            def _fetch_bkv_from_new_kv():
-                new_kv_len_start = q_end - kv_left_frm_new
-                debug_print("[RPA debug] new_kv_len_start={}",
-                            new_kv_len_start)
-                debug_print("[RPA debug] offset_in_bkv={}", offset)
-                _async_copy(
-                    kv_hbm_ref.at[pl.ds(new_kv_len_start, bkv_sz_frm_new)],
-                    vmem_ref.at[pl.ds(offset, bkv_sz_frm_new)],
-                    sem,
-                    wait,
-                )
+            size = lax.select(bkv_sz_frm_new > 0, bkv_sz_frm_new, 0)
+            new_kv_len_start = q_end - kv_left_frm_new
+            debug_print("[RPA debug] new_kv_len_start={}", new_kv_len_start)
+            debug_print("[RPA debug] offset_in_bkv={}", offset)
+            _async_copy(
+                kv_hbm_ref.at[pl.ds(new_kv_len_start, size)],
+                vmem_ref.at[pl.ds(offset, size)],
+                sem,
+                wait,
+            )
 
             return kv_len_start + offset, bkv_sz_frm_new
         else:
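
Replacing the @pl.when branch with a size clamped to zero means the copy is always issued but degenerates to a no-op transfer when there is nothing to fetch, avoiding a divergent control-flow path around the DMA. The same idea in plain numpy, where max(..., 0) stands in for the lax.select clamp and a zero-length slice copies nothing:

import numpy as np

src = np.arange(10)
dst = np.zeros(10)

bkv_sz_frm_new = -3            # nothing left to fetch from new KV
size = max(bkv_sz_frm_new, 0)  # mirrors lax.select(bkv_sz_frm_new > 0, ...)
dst[:size] = src[:size]        # zero-length slice: the copy is a no-op
assert (dst == 0).all()
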
@@ -672,7 +715,7 @@ def _ragged_paged_attention_kernel(
         vec = jnp.concat([ref[start + i::step] for i in range(folds)], axis=1)
         return vec
 
-    def strided_load_bkv(bkv_sem_idx, start, step, *, bkv_mask):
+    def strided_load_bkv(bkv_sem_idx, start, step):
         assert start % kv_packing == 0
         assert step % kv_packing == 0
         start //= kv_packing
@@ -684,21 +727,11 @@ def _ragged_paged_attention_kernel(
             k = strided_load(kv_ref, start, step)
             v = strided_load(kv_ref, start + 1, step)
 
-            kv_zeros = jnp.zeros_like(k)
-            k = lax.select(bkv_mask, k, kv_zeros)
-            v = lax.select(bkv_mask, v, kv_zeros)
-
             k = pltpu.bitcast(k, kv_dtype)
             v = pltpu.bitcast(v, kv_dtype)
             return [(k, v)]
 
         kv = strided_load(kv_ref, start, step)
-        # bkv_mask holds information about where each row of bkv is valid. Because
-        # kv is packed, a single 32-bits value might contain multiple k & v from
-        # different kv heads. Despite this we can guarantee that all values in a
-        # single 32-bits will map to the same bkv row. Therefore, it is safe to
-        # apply bkv_mask to kv directly.
-        kv = lax.select(bkv_mask, kv, jnp.zeros_like(kv))
         bitwidth = 32 // kv_packing
 
         # If we want to convert 32-bits into 32//N number of N-bits value, naive
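
strided_load gathers rows start, start+step, ... and widens them along the minor axis; strided_load_bkv then uses even and odd offsets of the interleaved buffer to separate K from V. A toy jnp version of that access pattern (using jnp.concatenate for portability; the kernel uses the jnp.concat alias):

import jax.numpy as jnp

def strided_load(ref, start, step, folds=1):
    # Rows start+i, start+i+step, ... concatenated along the minor axis.
    return jnp.concatenate([ref[start + i::step] for i in range(folds)],
                           axis=1)

packed = jnp.arange(8 * 2).reshape(8, 2)   # toy buffer with K/V interleaved
k = strided_load(packed, start=0, step=2)  # even rows -> K
v = strided_load(packed, start=1, step=2)  # odd rows  -> V
print(k.shape, v.shape)  # (4, 2) (4, 2)
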
@@ -776,12 +809,27 @@ def _ragged_paged_attention_kernel(
     def get_next_bkv_ids(seq_idx, bq_idx, bkv_idx, bkv_sem_idx):
         next_bkv_idx = bkv_idx + 1
         is_last_bkv = next_bkv_idx == num_bkv
-        next_bkv_idx = lax.select(is_last_bkv, 0, next_bkv_idx)
         next_bq_idx = lax.select(is_last_bkv, bq_idx + 1, bq_idx)
         is_last_bq = next_bq_idx == num_bq
         next_bq_idx = lax.select(is_last_bq, 0, next_bq_idx)
         next_seq_idx = lax.select(is_last_bq, seq_idx + 1, seq_idx)
         next_bkv_sem_idx = lax.select(bkv_sem_idx == 0, 1, 0)
+
+        if sliding_window is None:
+            # When sliding window is disabled, starting bkv_idx of next
+            # request is always 0 regardless of seq_idx of next request.
+            next_bkv_idx_start = 0
+        else:
+            # Determine starting bkv_idx of next request based on whether next
+            # request is from the same sequence or next sequence.
+            next_bkv_idx_start = lax.select(
+                is_last_bq,
+                next_seq_bkv_idx_start,
+                bkv_idx_start,
+            )
+        next_bkv_idx = lax.select(is_last_bkv, next_bkv_idx_start,
+                                  next_bkv_idx)
+
         return next_seq_idx, next_bq_idx, next_bkv_idx, next_bkv_sem_idx
 
     def compute_with_bq(bq_idx, _):
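
The three loop indices advance like an odometer: bkv rolls into bq, bq rolls into seq, and with a sliding window the rolled-over bkv counter restarts at the precomputed per-request start instead of 0. A plain-Python rendering of the same logic:

def get_next_bkv_ids(seq_idx, bq_idx, bkv_idx, *, num_bq, num_bkv,
                     bkv_idx_start=0, next_seq_bkv_idx_start=0):
    next_bkv_idx = bkv_idx + 1
    is_last_bkv = next_bkv_idx == num_bkv
    next_bq_idx = bq_idx + 1 if is_last_bkv else bq_idx
    is_last_bq = next_bq_idx == num_bq
    next_bq_idx = 0 if is_last_bq else next_bq_idx
    next_seq_idx = seq_idx + 1 if is_last_bq else seq_idx
    # Restart point for the rolled-over bkv counter.
    restart = next_seq_bkv_idx_start if is_last_bq else bkv_idx_start
    next_bkv_idx = restart if is_last_bkv else next_bkv_idx
    return next_seq_idx, next_bq_idx, next_bkv_idx

# Last bkv block of the last bq block rolls over to the next sequence,
# whose first in-window block is 5.
print(get_next_bkv_ids(0, 1, 7, num_bq=2, num_bkv=8,
                       next_seq_bkv_idx_start=5))  # (1, 0, 5)
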
@@ -798,10 +846,6 @@ def _ragged_paged_attention_kernel(
         def compute_with_bkv(bkv_idx, _):
             # Create bitmask for KV.
             assert bkv_sz % kv_packing == 0
-            actual_bkv_sz = jnp.minimum(bkv_sz, kv_len - bkv_idx * bkv_sz)
-            bkv_shape = (bkv_sz, head_dim)
-            bkv_mask = lax.broadcasted_iota(jnp.int32, bkv_shape,
-                                            0) < actual_bkv_sz
 
             # Get next bkv ids.
             bkv_sem_idx = sem_ids_ref[1]
@@ -842,6 +886,11 @@ def _ragged_paged_attention_kernel(
 
             # Flash attention with cur bkv and bq
            # NOTE: kv_packing is divided by 2 because k and v are packed together.
+            prev_bq_shape_0 = None
+            prev_kv_head_bv = None
+            prev_kv_head_idx = None
+            prev_kv_head_p = None
+            prev_kv_head_exp_m_diff = None
             heads_per_load = max(1, kv_packing // 2)
             for kv_head_start in range(0, actual_num_kv_heads,
                                        heads_per_load):
@@ -849,25 +898,56 @@ def _ragged_paged_attention_kernel(
                 bkv_sem_idx,
                 kv_head_start * 2,
                 num_kv_heads_x2,
-                bkv_mask=bkv_mask,
             )
             assert len(bkv_lst) == heads_per_load
             for i in range(heads_per_load):
-                kv_head_idx = kv_head_start + i
-                if kv_head_idx >= actual_num_kv_heads:
+                cur_kv_head_idx = kv_head_start + i
+                if cur_kv_head_idx >= actual_num_kv_heads:
                     break
-                bq = load_bq(bq_sem_idx,
-                             kv_head_idx,
-                             actual_bq_sz=actual_bq_sz)
+
+                cur_kv_head_bq = load_bq(bq_sem_idx,
+                                         cur_kv_head_idx,
+                                         actual_bq_sz=actual_bq_sz)
                 bk, bv = bkv_lst[i]
-                flash_attention(
-                    bq,
-                    bk,
-                    bv,
-                    bq_idx=bq_idx,
-                    bkv_idx=bkv_idx,
-                    kv_head_idx=kv_head_idx,
-                )
+                # FlashAttention is divided into `flash_attention_step1_qk_softmax`
+                # and `flash_attention_step2_pv` to pipeline the computation.
+                # `step2_pv` for the previous KV head, which depends on the softmax
+                # output, is overlapped with `step1_qk_softmax` for the current KV
+                # head, reducing overall wait times.
+                cur_kv_head_p, cur_kv_head_exp_m_diff = (
+                    flash_attention_step1_qk_softmax(
+                        cur_kv_head_bq,
+                        bk,
+                        bv,
+                        bq_idx=bq_idx,
+                        bkv_idx=bkv_idx,
+                        kv_head_idx=cur_kv_head_idx,
+                    ))
+                if prev_bq_shape_0 is not None:
+                    flash_attention_step2_pv(
+                        prev_bq_shape_0,
+                        prev_kv_head_bv,
+                        prev_kv_head_p,
+                        prev_kv_head_exp_m_diff,
+                        bkv_idx=bkv_idx,
+                        kv_head_idx=prev_kv_head_idx,
+                    )
+                prev_bq_shape_0 = cur_kv_head_bq.shape[0]
+                prev_kv_head_bv = bv
+                prev_kv_head_p = cur_kv_head_p
+                prev_kv_head_exp_m_diff = cur_kv_head_exp_m_diff
+                prev_kv_head_idx = cur_kv_head_idx
+
+            # Execute pv of last attention head.
+            assert prev_bq_shape_0 is not None
+            flash_attention_step2_pv(
+                prev_bq_shape_0,
+                prev_kv_head_bv,
+                prev_kv_head_p,
+                prev_kv_head_exp_m_diff,
+                bkv_idx=bkv_idx,
+                kv_head_idx=prev_kv_head_idx,
+            )
 
         lax.fori_loop(0, num_bkv, compute_with_bkv, None, unroll=False)
 
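
The per-head loop is a classic two-stage software pipeline: step1 for head i is issued, then step2 for head i-1, then the prev_* slots are rotated, with a final drain call after the loop. A minimal sketch of that control flow with the head computations abstracted away:

def pipelined_heads(heads, step1, step2):
    prev = None
    for h in heads:
        cur = step1(h)       # softmax stats for the current head
        if prev is not None:
            step2(prev)      # pv for the previous head overlaps step1
        prev = (h, cur)
    assert prev is not None
    step2(prev)              # drain: pv for the last head

order = []
pipelined_heads(
    heads=[0, 1, 2],
    step1=lambda h: order.append(("step1", h)),
    step2=lambda p: order.append(("step2", p[0])),
)
print(order)
# [('step1', 0), ('step1', 1), ('step2', 0),
#  ('step1', 2), ('step2', 1), ('step2', 2)]
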
@@ -899,7 +979,17 @@ def _ragged_paged_attention_kernel(
     @pl.when(seq_idx == 0)
     def prologue():
         start_fetch_bq(0, 0, 0)
-        start_fetch_bkv(0, 0, 0)
+
+        # Initialize bkv_x2_ref to zeros to avoid NaN issues from accessing
+        # uninitialized memory. Bitcast into int32 to avoid tiling issues.
+        bkv_x2_int32_ref = bkv_x2_ref.bitcast(jnp.int32).reshape(
+            (2, -1, 8, 128))
+        zeros = jnp.zeros(bkv_x2_int32_ref.shape[1:], jnp.int32)
+
+        # To pipeline VST and DMA, we divide the initialization into two steps.
+        bkv_x2_int32_ref[0] = zeros
+        start_fetch_bkv(0, bkv_idx_start, 0)
+        bkv_x2_int32_ref[1] = zeros
 
     @pl.when(seq_idx < decode_end)
     def process_decode():
@@ -1248,6 +1338,10 @@ def static_validate_inputs(
     del debug_mode
 
 
+def get_kernel_scope_name(bq_size, bkv_p, page_size):
+    return f"RPA-bq_{bq_size}-bkvp_{bkv_p}-p_{page_size}-"
+
+
 @functools.partial(
     jax.jit,
     static_argnames=(
@@ -1309,14 +1403,14 @@ def ragged_paged_attention(
     distribution: (i, j, k) represents that sequences[0:i] are decode-only,
       sequences[i:j] are chunked-prefill-only, and sequences[j:k] are mixed. The
       k is also the total number of sequences.
-    actual_head_dim: the actual head size of the attention. Here we assume k and
-      v have the same actual head size.
     sm_scale: the softmax scale which will be applied to the Q@K^T.
     sliding_window: the sliding window size for the attention.
     soft_cap: the logit soft cap for the attention.
     mask_value: mask value for causal mask.
+    q_scale: the scale for the query.
     k_scale: the scale for the key cache.
     v_scale: the scale for the value cache.
+    chunk_prefill_size: the chunk prefill size for the attention.
     num_kv_pages_per_block: number of kv pages to be processed in one flash
       attention block in the pallas kernel.
     num_queries_per_block: number of kv pages to be processed in one flash
@@ -1383,6 +1477,7 @@ def ragged_paged_attention(
         page_size,
         max_num_tokens,
         pages_per_seq,
+        sliding_window,
     )
     bkv_sz = bkv_p * page_size
     if vmem_limit_bytes is None:
@@ -1451,47 +1546,45 @@ def ragged_paged_attention(
         jnp.full((6, ), -1, jnp.int32),
     )
 
-    scope_name = f"RPA-bq_{bq_sz}-bkvp_{bkv_p}-p_{page_size}"
-    kernel = jax.named_scope(scope_name)(
-        pl.pallas_call(
-            functools.partial(
-                _ragged_paged_attention_kernel,
-                sm_scale=sm_scale,
-                sliding_window=sliding_window,
-                soft_cap=soft_cap,
-                mask_value=mask_value,
-                q_scale=q_scale,
-                k_scale=k_scale,
-                v_scale=v_scale,
-                chunk_prefill_size=chunk_prefill_size,
-                bq_sz=bq_sz,
-                bkv_p=bkv_p,
-                debug_mode=debug_mode,
-            ),
-            grid_spec=pltpu.PrefetchScalarGridSpec(
-                num_scalar_prefetch=len(scalar_prefetches),
-                in_specs=in_specs,
-                out_specs=out_specs,
-                grid=grid,
-                scratch_shapes=scratch_shapes,
-            ),
-            compiler_params=pltpu.CompilerParams(
-                # TODO(jevinjiang): since each sequence depends on the previous
-                # one, we need some extra work to support Megacore mode.
-                dimension_semantics=("arbitrary", ),
-                vmem_limit_bytes=vmem_limit_bytes,
-            ),
-            out_shape=[
-                jax.ShapeDtypeStruct(shape=q.shape, dtype=q.dtype),
-                jax.ShapeDtypeStruct(shape=kv_cache.shape,
-                                     dtype=kv_cache.dtype),
-            ],
-            input_output_aliases={
-                7: 0,
-                9: 1
-            },
-            name=scope_name,
-        ))
+    scope_name = get_kernel_scope_name(bq_sz, bkv_p, page_size)
+    kernel = pl.pallas_call(
+        functools.partial(
+            _ragged_paged_attention_kernel,
+            sm_scale=sm_scale,
+            sliding_window=sliding_window,
+            soft_cap=soft_cap,
+            mask_value=mask_value,
+            q_scale=q_scale,
+            k_scale=k_scale,
+            v_scale=v_scale,
+            chunk_prefill_size=chunk_prefill_size,
+            bq_sz=bq_sz,
+            bkv_p=bkv_p,
+            debug_mode=debug_mode,
+        ),
+        grid_spec=pltpu.PrefetchScalarGridSpec(
+            num_scalar_prefetch=len(scalar_prefetches),
+            in_specs=in_specs,
+            out_specs=out_specs,
+            grid=grid,
+            scratch_shapes=scratch_shapes,
+        ),
+        compiler_params=pltpu.CompilerParams(
+            # TODO(jevinjiang): since each sequence depends on the previous
+            # one, we need some extra work to support Megacore mode.
+            dimension_semantics=("arbitrary", ),
+            vmem_limit_bytes=vmem_limit_bytes,
+        ),
+        out_shape=[
+            jax.ShapeDtypeStruct(shape=q.shape, dtype=q.dtype),
+            jax.ShapeDtypeStruct(shape=kv_cache.shape, dtype=kv_cache.dtype),
+        ],
+        input_output_aliases={
+            7: 0,
+            9: 1
+        },
+        name=scope_name,
+    )
 
     output, updated_kv_cache = kernel(*scalar_prefetches, q, kv, kv_cache)
     return (