tpu-inference 0.11.1.dev202511220812__py3-none-any.whl → 0.13.2.dev20251230__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of tpu-inference might be problematic.

Files changed (257)
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_disagg_utils.py +14 -0
  4. tests/core/test_dp_scheduler.py +650 -768
  5. tests/core/test_init.py +14 -0
  6. tests/distributed/__init__.py +13 -0
  7. tests/distributed/test_distributed_utils.py +120 -0
  8. tests/distributed/test_tpu_connector.py +478 -0
  9. tests/e2e/__init__.py +13 -0
  10. tests/e2e/test_async_scheduler.py +211 -0
  11. tests/e2e/test_data_parallel.py +289 -0
  12. tests/e2e/test_hybrid_kvcache.py +219 -0
  13. tests/e2e/test_local_disagg.py +257 -0
  14. tests/e2e/test_model_loader.py +268 -0
  15. tests/e2e/test_multi_modal_inference.py +111 -0
  16. tests/e2e/test_pipeline_parallel.py +265 -0
  17. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  18. tests/e2e/test_sampling_params.py +269 -0
  19. tests/e2e/test_speculative_decoding.py +311 -0
  20. tests/e2e/test_structured_decoding.py +46 -0
  21. tests/executors/__init__.py +13 -0
  22. tests/executors/test_ray_distributed_executor.py +199 -0
  23. tests/experimental/__init__.py +13 -0
  24. tests/experimental/test_llama3_jax_stashed.py +208 -0
  25. tests/kernels/__init__.py +13 -0
  26. tests/kernels/collectives/__init__.py +13 -0
  27. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  28. tests/kernels/fused_moe_v1_test.py +317 -34
  29. tests/kernels/gmm_test.py +205 -0
  30. tests/kernels/mla_v1_test.py +143 -41
  31. tests/kernels/quantized_matmul_kernel_test.py +2 -34
  32. tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
  33. tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
  34. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +17 -1
  35. tests/kernels/ragged_paged_attention_kernel_v3_test.py +17 -1
  36. tests/layers/__init__.py +13 -0
  37. tests/layers/common/__init__.py +13 -0
  38. tests/layers/common/test_attention_interface.py +156 -0
  39. tests/layers/common/test_quantization.py +149 -0
  40. tests/layers/jax/__init__.py +13 -0
  41. tests/layers/jax/attention/__init__.py +13 -0
  42. tests/layers/jax/attention/test_common_attention.py +103 -0
  43. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  44. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  45. tests/layers/jax/moe/__init__.py +13 -0
  46. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  47. tests/layers/jax/sample/__init__.py +13 -0
  48. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  49. tests/layers/jax/sample/test_sampling.py +115 -0
  50. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  51. tests/layers/jax/test_layers.py +155 -0
  52. tests/{test_quantization.py → layers/jax/test_qwix.py} +183 -50
  53. tests/layers/jax/test_rope.py +93 -0
  54. tests/layers/jax/test_sharding.py +159 -0
  55. tests/layers/jax/test_transformer_block.py +152 -0
  56. tests/layers/vllm/__init__.py +13 -0
  57. tests/layers/vllm/test_attention.py +363 -0
  58. tests/layers/vllm/test_awq.py +406 -0
  59. tests/layers/vllm/test_compressed_tensors_moe.py +199 -0
  60. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +441 -0
  61. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +443 -0
  62. tests/layers/vllm/test_fp8.py +17 -0
  63. tests/layers/vllm/test_mxfp4.py +320 -0
  64. tests/layers/vllm/test_unquantized.py +662 -0
  65. tests/layers/vllm/utils.py +87 -0
  66. tests/lora/__init__.py +13 -0
  67. tests/lora/conftest.py +14 -0
  68. tests/lora/test_bgmv.py +14 -0
  69. tests/lora/test_layers.py +26 -6
  70. tests/lora/test_lora.py +15 -1
  71. tests/lora/test_lora_perf.py +67 -0
  72. tests/models/__init__.py +13 -0
  73. tests/models/common/__init__.py +13 -0
  74. tests/models/common/test_model_loader.py +455 -0
  75. tests/models/jax/__init__.py +13 -0
  76. tests/models/jax/test_deepseek_v3.py +401 -0
  77. tests/models/jax/test_llama3.py +184 -0
  78. tests/models/jax/test_llama4.py +298 -0
  79. tests/models/jax/test_llama_eagle3.py +197 -0
  80. tests/models/jax/test_llama_guard_4.py +242 -0
  81. tests/models/jax/test_qwen2.py +172 -0
  82. tests/models/jax/test_qwen2_5_vl.py +605 -0
  83. tests/models/jax/test_qwen3.py +169 -0
  84. tests/models/jax/test_weight_loading.py +180 -0
  85. tests/models/jax/utils/__init__.py +13 -0
  86. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  87. tests/platforms/__init__.py +13 -0
  88. tests/platforms/test_tpu_platform.py +54 -0
  89. tests/runner/__init__.py +13 -0
  90. tests/runner/test_block_table.py +395 -0
  91. tests/runner/test_input_batch.py +226 -0
  92. tests/runner/test_kv_cache.py +220 -0
  93. tests/runner/test_kv_cache_manager.py +498 -0
  94. tests/runner/test_multimodal_manager.py +429 -0
  95. tests/runner/test_persistent_batch_manager.py +84 -0
  96. tests/runner/test_speculative_decoding_manager.py +368 -0
  97. tests/runner/test_structured_decoding_manager.py +220 -0
  98. tests/runner/test_tpu_runner.py +261 -0
  99. tests/runner/test_tpu_runner_dp.py +1099 -0
  100. tests/runner/test_tpu_runner_mesh.py +200 -0
  101. tests/runner/test_utils.py +411 -0
  102. tests/spec_decode/__init__.py +13 -0
  103. tests/spec_decode/test_eagle3.py +311 -0
  104. tests/test_base.py +14 -0
  105. tests/test_envs.py +110 -12
  106. tests/test_tpu_info.py +14 -0
  107. tests/test_utils.py +2 -45
  108. tests/worker/__init__.py +13 -0
  109. tests/worker/tpu_worker_test.py +414 -0
  110. tpu_inference/__init__.py +14 -0
  111. tpu_inference/core/__init__.py +13 -0
  112. tpu_inference/core/sched/__init__.py +13 -0
  113. tpu_inference/core/sched/dp_scheduler.py +372 -56
  114. tpu_inference/distributed/__init__.py +13 -0
  115. tpu_inference/distributed/jax_parallel_state.py +14 -0
  116. tpu_inference/distributed/tpu_connector.py +15 -10
  117. tpu_inference/distributed/utils.py +56 -4
  118. tpu_inference/envs.py +92 -8
  119. tpu_inference/executors/__init__.py +13 -0
  120. tpu_inference/executors/ray_distributed_executor.py +25 -4
  121. tpu_inference/experimental/__init__.py +13 -0
  122. tpu_inference/experimental/llama3_jax_stashed.py +14 -0
  123. tpu_inference/kernels/__init__.py +13 -0
  124. tpu_inference/kernels/collectives/__init__.py +13 -0
  125. tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
  126. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
  127. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  128. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  129. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  130. tpu_inference/kernels/fused_moe/v1/kernel.py +807 -230
  131. tpu_inference/kernels/megablox/__init__.py +13 -0
  132. tpu_inference/kernels/megablox/common.py +54 -0
  133. tpu_inference/kernels/megablox/gmm.py +646 -0
  134. tpu_inference/kernels/mla/__init__.py +13 -0
  135. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  136. tpu_inference/kernels/mla/v1/kernel.py +117 -145
  137. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  138. tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
  139. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  140. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  141. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
  142. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
  143. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  144. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +194 -101
  145. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +218 -137
  146. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
  147. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +376 -195
  148. tpu_inference/kernels/ragged_paged_attention/v3/util.py +15 -1
  149. tpu_inference/layers/__init__.py +13 -0
  150. tpu_inference/layers/common/__init__.py +13 -0
  151. tpu_inference/layers/common/attention_interface.py +25 -12
  152. tpu_inference/layers/common/attention_metadata.py +14 -0
  153. tpu_inference/layers/common/fused_moe_gmm.py +506 -0
  154. tpu_inference/layers/common/quant_methods.py +15 -0
  155. tpu_inference/layers/common/quantization.py +282 -0
  156. tpu_inference/layers/common/sharding.py +32 -9
  157. tpu_inference/layers/common/utils.py +94 -0
  158. tpu_inference/layers/jax/__init__.py +13 -0
  159. tpu_inference/layers/jax/attention/__init__.py +13 -0
  160. tpu_inference/layers/jax/attention/attention.py +19 -6
  161. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +270 -77
  162. tpu_inference/layers/jax/attention/gpt_oss_attention.py +24 -11
  163. tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
  164. tpu_inference/layers/jax/base.py +14 -0
  165. tpu_inference/layers/jax/constants.py +13 -0
  166. tpu_inference/layers/jax/layers.py +14 -0
  167. tpu_inference/layers/jax/misc.py +14 -0
  168. tpu_inference/layers/jax/moe/__init__.py +13 -0
  169. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
  170. tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
  171. tpu_inference/layers/jax/moe/moe.py +43 -3
  172. tpu_inference/layers/jax/pp_utils.py +53 -0
  173. tpu_inference/layers/jax/rope.py +14 -0
  174. tpu_inference/layers/jax/rope_interface.py +14 -0
  175. tpu_inference/layers/jax/sample/__init__.py +13 -0
  176. tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
  177. tpu_inference/layers/jax/sample/sampling.py +15 -1
  178. tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
  179. tpu_inference/layers/jax/transformer_block.py +14 -0
  180. tpu_inference/layers/vllm/__init__.py +13 -0
  181. tpu_inference/layers/vllm/attention.py +4 -4
  182. tpu_inference/layers/vllm/fused_moe.py +101 -494
  183. tpu_inference/layers/vllm/linear.py +64 -0
  184. tpu_inference/layers/vllm/process_weights/__init__.py +13 -0
  185. tpu_inference/layers/vllm/{sharding.py → process_weights/cleanup_sharding.py} +24 -15
  186. tpu_inference/layers/vllm/process_weights/fused_moe_weights.py +369 -0
  187. tpu_inference/layers/vllm/process_weights/linear_weights.py +174 -0
  188. tpu_inference/layers/vllm/quantization/__init__.py +19 -3
  189. tpu_inference/layers/vllm/quantization/awq.py +96 -82
  190. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  191. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +23 -8
  192. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +172 -176
  193. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  194. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +111 -91
  195. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +79 -43
  196. tpu_inference/layers/vllm/quantization/{common.py → configs.py} +42 -25
  197. tpu_inference/layers/vllm/quantization/fp8.py +119 -0
  198. tpu_inference/layers/vllm/quantization/mxfp4.py +137 -178
  199. tpu_inference/layers/vllm/quantization/unquantized.py +157 -233
  200. tpu_inference/lora/__init__.py +13 -0
  201. tpu_inference/lora/torch_lora_ops.py +8 -13
  202. tpu_inference/models/__init__.py +13 -0
  203. tpu_inference/models/common/__init__.py +13 -0
  204. tpu_inference/models/common/model_loader.py +112 -35
  205. tpu_inference/models/jax/__init__.py +13 -0
  206. tpu_inference/models/jax/deepseek_v3.py +267 -157
  207. tpu_inference/models/jax/gpt_oss.py +26 -10
  208. tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
  209. tpu_inference/models/jax/llama3.py +99 -36
  210. tpu_inference/models/jax/llama4.py +14 -0
  211. tpu_inference/models/jax/llama_eagle3.py +18 -5
  212. tpu_inference/models/jax/llama_guard_4.py +15 -1
  213. tpu_inference/models/jax/qwen2.py +17 -2
  214. tpu_inference/models/jax/qwen2_5_vl.py +179 -51
  215. tpu_inference/models/jax/qwen3.py +17 -2
  216. tpu_inference/models/jax/utils/__init__.py +13 -0
  217. tpu_inference/models/jax/utils/file_utils.py +14 -0
  218. tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
  219. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  220. tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +92 -32
  221. tpu_inference/models/jax/utils/weight_utils.py +234 -155
  222. tpu_inference/models/vllm/__init__.py +13 -0
  223. tpu_inference/models/vllm/vllm_model_wrapper.py +32 -8
  224. tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
  225. tpu_inference/platforms/__init__.py +14 -0
  226. tpu_inference/platforms/tpu_platform.py +51 -72
  227. tpu_inference/runner/__init__.py +13 -0
  228. tpu_inference/runner/compilation_manager.py +180 -80
  229. tpu_inference/runner/kv_cache.py +54 -20
  230. tpu_inference/runner/kv_cache_manager.py +55 -33
  231. tpu_inference/runner/lora_utils.py +16 -1
  232. tpu_inference/runner/multimodal_manager.py +16 -2
  233. tpu_inference/runner/persistent_batch_manager.py +54 -2
  234. tpu_inference/runner/speculative_decoding_manager.py +14 -0
  235. tpu_inference/runner/structured_decoding_manager.py +16 -3
  236. tpu_inference/runner/tpu_runner.py +124 -61
  237. tpu_inference/runner/utils.py +2 -2
  238. tpu_inference/spec_decode/__init__.py +13 -0
  239. tpu_inference/spec_decode/jax/__init__.py +13 -0
  240. tpu_inference/spec_decode/jax/eagle3.py +84 -22
  241. tpu_inference/tpu_info.py +14 -0
  242. tpu_inference/utils.py +72 -44
  243. tpu_inference/worker/__init__.py +13 -0
  244. tpu_inference/worker/tpu_worker.py +66 -52
  245. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/METADATA +8 -9
  246. tpu_inference-0.13.2.dev20251230.dist-info/RECORD +266 -0
  247. tpu_inference/layers/vllm/linear_common.py +0 -186
  248. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  249. tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
  250. tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
  251. tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
  252. tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
  253. tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
  254. tpu_inference-0.11.1.dev202511220812.dist-info/RECORD +0 -174
  255. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/WHEEL +0 -0
  256. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/licenses/LICENSE +0 -0
  257. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/top_level.txt +0 -0
@@ -1,3 +1,16 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 """
 A variant of TPU-Friendly Ragged Paged Attention kernel optimized for
 head_dim = 64.
@@ -317,19 +330,21 @@ def _ragged_paged_attention_kernel(
   q_len = q_end - q_start
   kv_len = kv_lens_ref[seq_idx]

-  bkv_idx_start = 0 if sliding_window is None else jnp.maximum(
-      kv_len - sliding_window, 0) // bkv_sz
-
   if sliding_window is None:
-    next_bkv_idx_start = 0
+    bkv_idx_start = next_seq_bkv_idx_start = 0
   else:
-
-    def get_next_bkv_idx_start():
-      next_kv_len = kv_lens_ref[seq_idx + 1]
-      return jnp.maximum(next_kv_len - sliding_window, 0) // bkv_sz
-
-    next_bkv_idx_start = lax.cond(seq_idx + 1 < num_seqs,
-                                  get_next_bkv_idx_start, lambda: 0)
+    bkv_idx_start = jnp.maximum(kv_len - q_len - sliding_window,
+                                0) // bkv_sz
+
+    # If seq_idx + 1 == num_seqs, kv_lens_ref[seq_idx + 1] would trigger an
+    # out-of-bounds access. To avoid this, we cap next_seq_idx at
+    # num_seqs - 1.
+    next_seq_idx = jnp.minimum(seq_idx + 1, num_seqs - 1)
+    next_kv_len = kv_lens_ref[next_seq_idx]
+    next_q_len = cu_q_lens_ref[next_seq_idx + 1] - q_end
+    next_seq_bkv_idx_start = (
+        jnp.maximum(next_kv_len - next_q_len - sliding_window, 0) //
+        bkv_sz)

   def debug_print(msg, *args):
     if debug_mode:
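
For intuition, the reworked start-block arithmetic can be reproduced outside the kernel. A minimal standalone sketch in plain JAX (all sizes hypothetical, for illustration only):

import jax.numpy as jnp

kv_len, q_len = 1000, 64   # tokens already in cache vs. tokens in this chunk
sliding_window = 256       # attention window size
bkv_sz = 128               # KV tokens per kernel block

# The oldest query token in this chunk sits at position kv_len - q_len, so
# KV blocks that lie entirely before (kv_len - q_len - sliding_window) can
# never be attended and are skipped.
bkv_idx_start = jnp.maximum(kv_len - q_len - sliding_window, 0) // bkv_sz
print(bkv_idx_start)  # Array(5): blocks 0..4 are skipped
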
@@ -350,7 +365,7 @@ def _ragged_paged_attention_kernel(
     debug_print("[RPA debug] q_len={}", q_len)
     debug_print("[RPA debug] kv_len={}", kv_len)

-  def flash_attention(
+  def flash_attention_step1_qk_softmax(
       q,  # [actual_bq_sz * num_q_heads_per_kv_head, actual_head_dim_x2]
       kv,  # [bkv_sz, actual_head_dim_x2]
       *,
@@ -364,7 +379,6 @@ def _ragged_paged_attention_kernel(
     assert kv.shape == (bkv_sz, actual_head_dim_x2)
     head_l_ref = l_ref.at[kv_head_idx, :q.shape[0]]
     head_m_ref = m_ref.at[kv_head_idx, :q.shape[0]]
-    head_acc_ref = acc_ref.at[kv_head_idx, :q.shape[0]]

     def load_with_init(ref, init_val):
       return jnp.where(bkv_idx == bkv_idx_start,
@@ -386,16 +400,19 @@ def _ragged_paged_attention_kernel(
       s *= k_scale
     if q_scale is not None:
       s *= q_scale
+    if soft_cap is not None:
+      s = soft_cap * jnp.tanh(s / soft_cap)

     q_span = (kv_len - q_len + bq_idx * bq_sz +
               lax.broadcasted_iota(jnp.int32, s.shape, 0) //
               num_q_heads_per_kv_head)
     k_span = bkv_idx * bkv_sz + lax.broadcasted_iota(jnp.int32, s.shape, 1)
-    mask = q_span < k_span
+    mask = k_span <= q_span

-    if soft_cap is not None:
-      s = soft_cap * jnp.tanh(s / soft_cap)
-    s += jnp.where(mask, mask_value, 0.0)
+    if sliding_window is not None:
+      mask = jnp.logical_and(mask, q_span - sliding_window < k_span)
+
+    s = jnp.where(mask, s, mask_value)
     s_rowmax = jnp.max(s, axis=1, keepdims=True)

     if attention_sink_ref is not None:
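
The new mask construction is easy to check in isolation. A toy-sized sketch (shapes and values made up) showing how the causal and sliding-window keep-masks are combined before masked scores are replaced with mask_value:

import jax.numpy as jnp
from jax import lax

bq_sz, bkv_sz = 4, 8                 # toy block sizes
num_q_heads_per_kv_head = 1
kv_len, q_len, bq_idx, bkv_idx = 8, 4, 0, 0
sliding_window = 3
mask_value = -2.3819764e+38          # a large negative float

s = jnp.zeros((bq_sz, bkv_sz))       # stand-in for the Q @ K^T scores
q_span = (kv_len - q_len + bq_idx * bq_sz +
          lax.broadcasted_iota(jnp.int32, s.shape, 0) //
          num_q_heads_per_kv_head)
k_span = bkv_idx * bkv_sz + lax.broadcasted_iota(jnp.int32, s.shape, 1)

mask = k_span <= q_span                                         # causal
mask = jnp.logical_and(mask, q_span - sliding_window < k_span)  # window
s = jnp.where(mask, s, mask_value)
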
@@ -411,15 +428,33 @@ def _ragged_paged_attention_kernel(
     head_m_ref[...] = m_curr
     p = jnp.exp(s - broadcast_minor(m_curr, s.shape))

-    pv = jnp.einsum("nm,md->nd", p, kv, preferred_element_type=jnp.float32)
-    if v_scale is not None:
-      pv *= v_scale
-
     p_rowsum = jnp.sum(p, axis=1, keepdims=True)
     exp_m_diff = jnp.exp(m_prev - m_curr)
     l_prev = load_with_init(head_l_ref, 1.0)
     l_curr = exp_m_diff * l_prev + p_rowsum
     head_l_ref[...] = l_curr
+
+    return p, exp_m_diff
+
+  def flash_attention_step2_pv(
+      q_shape_0,
+      kv,  # [bkv_sz, actual_head_dim_x2]
+      p,  # from step1
+      exp_m_diff,  # from step1
+      *,
+      bkv_idx,
+      kv_head_idx,
+  ):
+    head_acc_ref = acc_ref.at[kv_head_idx, :q_shape_0]
+
+    def load_with_init(ref, init_val):
+      return jnp.where(bkv_idx == bkv_idx_start,
+                       jnp.full_like(ref, init_val), ref[...])
+
+    pv = jnp.einsum("nm,md->nd", p, kv, preferred_element_type=jnp.float32)
+    if v_scale is not None:
+      pv *= v_scale
+
     o_prev = load_with_init(head_acc_ref, 0.0)
     o_curr = broadcast_minor(exp_m_diff, o_prev.shape) * o_prev + pv
     head_acc_ref[...] = o_curr
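
Numerically, the two steps still implement a single online-softmax block update: step 1 owns the running max/sum statistics, step 2 owns the rescaled accumulator. A generic reference version in plain JAX (ignoring the kernel's scaling knobs, Pallas refs, and its kernel-specific init values):

import jax.numpy as jnp

def step1_qk_softmax(s, m_prev, l_prev):
  # s: [bq, bkv] scores for one KV block; m/l: running max and sum.
  m_curr = jnp.maximum(jnp.max(s, axis=1, keepdims=True), m_prev)
  p = jnp.exp(s - m_curr)
  exp_m_diff = jnp.exp(m_prev - m_curr)
  l_curr = exp_m_diff * l_prev + jnp.sum(p, axis=1, keepdims=True)
  return p, exp_m_diff, m_curr, l_curr

def step2_pv(p, exp_m_diff, v, o_prev):
  # Rescale the previous accumulator, then add this block's P @ V.
  return exp_m_diff * o_prev + p @ v
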
@@ -498,19 +533,16 @@ def _ragged_paged_attention_kernel(
             unroll=False,
         )

-        # Fetch kv directly from new kv.
-        @pl.when(bkv_sz_frm_new > 0)
-        def _fetch_bkv_from_new_kv():
-          new_kv_len_start = q_end - kv_left_frm_new
-          debug_print("[RPA debug] new_kv_len_start={}",
-                      new_kv_len_start)
-          debug_print("[RPA debug] offset_in_bkv={}", offset)
-          _async_copy(
-              kv_hbm_ref.at[pl.ds(new_kv_len_start, bkv_sz_frm_new)],
-              vmem_ref.at[pl.ds(offset, bkv_sz_frm_new)],
-              sem,
-              wait,
-          )
+        size = lax.select(bkv_sz_frm_new > 0, bkv_sz_frm_new, 0)
+        new_kv_len_start = q_end - kv_left_frm_new
+        debug_print("[RPA debug] new_kv_len_start={}", new_kv_len_start)
+        debug_print("[RPA debug] offset_in_bkv={}", offset)
+        _async_copy(
+            kv_hbm_ref.at[pl.ds(new_kv_len_start, size)],
+            vmem_ref.at[pl.ds(offset, size)],
+            sem,
+            wait,
+        )

         return kv_len_start + offset, bkv_sz_frm_new
       else:
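
This hunk trades a predicated copy (under pl.when) for an unconditional one whose length is clamped to zero when there is nothing to fetch, so the zero-sized DMA becomes a harmless no-op. The clamp itself is plain lax arithmetic (standalone sketch, toy value):

import jax.numpy as jnp
from jax import lax

bkv_sz_frm_new = jnp.int32(-32)  # hypothetical: nothing left to fetch
size = lax.select(bkv_sz_frm_new > 0, bkv_sz_frm_new, jnp.int32(0))
print(size)  # Array(0): the subsequent copy moves no data
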
@@ -704,7 +736,7 @@ def _ragged_paged_attention_kernel(
       vec = ref[start::step]
       return vec

-    def strided_load_bkv(bkv_sem_idx, start, step, *, bkv_mask):
+    def strided_load_bkv(bkv_sem_idx, start, step):
       assert start % kv_packing == 0
       assert step % kv_packing == 0
       start //= kv_packing
@@ -713,7 +745,6 @@ def _ragged_paged_attention_kernel(
                                  bkv_sz * step, actual_head_dim_x2))

       kv = strided_load(kv_ref, start, step)
-      kv = lax.select(bkv_mask, kv, jnp.zeros_like(kv))
       bitwidth = 32 // kv_packing
       repack_ty = jnp.dtype(f"uint{bitwidth}")
       lst = []
@@ -760,13 +791,21 @@ def _ragged_paged_attention_kernel(
     next_seq_idx = lax.select(is_last_bq, seq_idx + 1, seq_idx)
     next_bkv_sem_idx = lax.select(bkv_sem_idx == 0, 1, 0)

-    next_bkv_idx = lax.select(
-        is_last_bkv,
-        lax.select(
+    if sliding_window is None:
+      # When sliding window is disabled, the starting bkv_idx of the next
+      # request is always 0, regardless of the seq_idx of the next request.
+      next_bkv_idx_start = 0
+    else:
+      # Determine the starting bkv_idx of the next request based on whether
+      # it comes from the same sequence or from the next sequence.
+      next_bkv_idx_start = lax.select(
           is_last_bq,
-          next_bkv_idx_start,
+          next_seq_bkv_idx_start,
           bkv_idx_start,
-        ), next_bkv_idx)
+      )
+    next_bkv_idx = lax.select(is_last_bkv, next_bkv_idx_start,
+                              next_bkv_idx)
+
     return next_seq_idx, next_bq_idx, next_bkv_idx, next_bkv_sem_idx

   def compute_with_bq(bq_idx, _):
@@ -783,10 +822,6 @@ def _ragged_paged_attention_kernel(
     def compute_with_bkv(bkv_idx, _):
       # Create bitmask for KV.
       assert bkv_sz % kv_packing == 0
-      actual_bkv_sz = jnp.minimum(bkv_sz, kv_len - bkv_idx * bkv_sz)
-      bkv_shape = (bkv_sz, actual_head_dim_x2)
-      bkv_mask = lax.broadcasted_iota(jnp.int32, bkv_shape,
-                                      0) < actual_bkv_sz

       # Get next bkv ids.
       bkv_sem_idx = sem_ids_ref[1]
@@ -826,29 +861,64 @@ def _ragged_paged_attention_kernel(
         return

       # Flash attention with cur bkv and bq
+      prev_bq_shape_0 = None
+      prev_kv_head_bkv = None
+      prev_kv_head_idx = None
+      prev_kv_head_p = None
+      prev_kv_head_exp_m_diff = None
       for kv_head_start in range(0, actual_num_kv_heads, kv_packing):
         bkv_lst = strided_load_bkv(
             bkv_sem_idx,
             kv_head_start,
             num_kv_heads,
-            bkv_mask=bkv_mask,
         )
         assert len(bkv_lst) == kv_packing
         for i in range(kv_packing):
-          kv_head_idx = kv_head_start + i
-          if kv_head_idx >= actual_num_kv_heads:
+          cur_kv_head_idx = kv_head_start + i
+          if cur_kv_head_idx >= actual_num_kv_heads:
            break
-          bq = load_bq(bq_sem_idx,
-                       kv_head_idx,
-                       actual_bq_sz=actual_bq_sz)
-          bkv = bkv_lst[i]
-          flash_attention(
-              bq,
-              bkv,
-              bq_idx=bq_idx,
-              bkv_idx=bkv_idx,
-              kv_head_idx=kv_head_idx,
-          )
+          cur_kv_head_bq = load_bq(bq_sem_idx,
+                                   cur_kv_head_idx,
+                                   actual_bq_sz=actual_bq_sz)
+          cur_kv_head__bkv = bkv_lst[i]
+          # FlashAttention is divided into `flash_attention_step1_qk_softmax`
+          # and `flash_attention_step2_pv` to pipeline the computation.
+          # `step2_pv` for the previous KV head, which depends on the softmax
+          # output, is overlapped with `step1_qk_softmax` for the current KV
+          # head, reducing overall wait times.
+          cur_kv_head_p, cur_kv_head_exp_m_diff = (
+              flash_attention_step1_qk_softmax(
+                  cur_kv_head_bq,
+                  cur_kv_head__bkv,
+                  bq_idx=bq_idx,
+                  bkv_idx=bkv_idx,
+                  kv_head_idx=cur_kv_head_idx,
+              ))
+          if prev_bq_shape_0 is not None:
+            flash_attention_step2_pv(
+                prev_bq_shape_0,
+                prev_kv_head_bkv,
+                prev_kv_head_p,
+                prev_kv_head_exp_m_diff,
+                bkv_idx=bkv_idx,
+                kv_head_idx=prev_kv_head_idx,
+            )
+          prev_bq_shape_0 = cur_kv_head_bq.shape[0]
+          prev_kv_head_bkv = cur_kv_head__bkv
+          prev_kv_head_p = cur_kv_head_p
+          prev_kv_head_exp_m_diff = cur_kv_head_exp_m_diff
+          prev_kv_head_idx = cur_kv_head_idx
+
+      # Execute pv of the last attention head.
+      assert prev_bq_shape_0 is not None
+      flash_attention_step2_pv(
+          prev_bq_shape_0,
+          prev_kv_head_bkv,
+          prev_kv_head_p,
+          prev_kv_head_exp_m_diff,
+          bkv_idx=bkv_idx,
+          kv_head_idx=prev_kv_head_idx,
+      )

       lax.fori_loop(bkv_idx_start,
                     num_bkv,
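
The staging in this hunk is classic two-stage software pipelining: stage 2 of head i-1 runs between stage 1 of heads i-1 and i, and a final drain handles the last head. Stripped of the attention details (stage1 and stage2 are placeholders for the two kernel steps):

def pipelined(items, stage1, stage2):
  prev = None
  for item in items:
    cur = stage1(item)    # e.g. QK^T + online softmax for this KV head
    if prev is not None:
      stage2(prev)        # e.g. P @ V for the previous KV head
    prev = cur
  assert prev is not None
  stage2(prev)            # drain: second stage of the last head
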
@@ -884,7 +954,17 @@ def _ragged_paged_attention_kernel(
   @pl.when(seq_idx == 0)
   def prologue():
     start_fetch_bq(0, 0, 0)
+
+    # Initialize bkv_x2_ref to zeros to avoid NaN issues from accessing
+    # uninitialized memory. Bitcast into int32 to avoid tiling issues.
+    bkv_x2_int32_ref = bkv_x2_ref.bitcast(jnp.int32).reshape(
+        (2, -1, 8, 128))
+    zeros = jnp.zeros(bkv_x2_int32_ref.shape[1:], jnp.int32)
+
+    # To pipeline VST and DMA, we divide the initialization into two steps.
+    bkv_x2_int32_ref[0] = zeros
     start_fetch_bkv(0, bkv_idx_start, 0)
+    bkv_x2_int32_ref[1] = zeros

   @pl.when(seq_idx < decode_end)
   def process_decode():
@@ -1244,6 +1324,10 @@ def static_validate_inputs(
   del debug_mode


+def get_kernel_scope_name(bq_size, bkv_p, page_size):
+  return f"RPA-HD_64-bq_{bq_size}-bkvp_{bkv_p}-p_{page_size}-"
+
+
 @functools.partial(
     jax.jit,
     static_argnames=(
@@ -1292,42 +1376,40 @@ def ragged_paged_attention_hd64(
     # Debug params.
     debug_mode: bool = False,
 ):
-  """A special Ragged paged attention version for head_dim=64 that supports mixed
-
-  prefill and decode.
-
-  Args:
-    queries: concatenated all sequences' queries.
-    keys: concatenated all sequences' keys (quantized).
-    values: concatenated all sequences' values (quantized).
-    kv_cache: paged KV cache with TPU-friendly shape.
-    kv_lens: padded kv lengths. Only the first num_seqs values are valid.
-    page_indices: flattened page indices look-up table by (seq_id, page_id).
-    cu_q_lens: the cumulative sum of the effective query lengths. Similar to
-      kv_lens, only the first num_seqs+1 values are valid.
-    distribution: (i, j, k) represents that sequences[0:i] are decode-only,
-      sequences[i:j] are chunked-prefill-only, and sequences[j:k] are mixed. The
-      k is also the total number of sequences.
-    attention_sink: optional attention sink for each q head.
-    actual_head_dim: the actual head size of the attention. Here we assume k and
-      v have the same actual head size.
-    sm_scale: the softmax scale which will be applied to the Q@K^T.
-    sliding_window: the sliding window size for the attention.
-    soft_cap: the logit soft cap for the attention.
-    mask_value: mask value for causal mask.
-    k_scale: the scale for the key cache.
-    v_scale: the scale for the value cache.
-    num_kv_pages_per_block: number of kv pages to be processed in one flash
-      attention block in the pallas kernel.
-    num_queries_per_block: number of kv pages to be processed in one flash
-      attention block in the pallas kernel.
-    vmem_limit_bytes: the vmem limit for the pallas kernel.
-    debug_mode: if true, RPA does not issue any DMAs or run flash attention but
-      print debug info. Need to compile with `--xla_tpu_enable_log_recorder`.
-
-  Returns:
-    The output of the attention.
-  """
+  """A variant of ragged paged attention for head_dim=64.
+
+  Args:
+    queries: concatenated all sequences' queries.
+    keys: concatenated all sequences' keys (quantized).
+    values: concatenated all sequences' values (quantized).
+    kv_cache: paged KV cache with TPU-friendly shape.
+    kv_lens: padded kv lengths. Only the first num_seqs values are valid.
+    page_indices: flattened page indices look-up table by (seq_id, page_id).
+    cu_q_lens: the cumulative sum of the effective query lengths. Similar to
+      kv_lens, only the first num_seqs+1 values are valid.
+    distribution: (i, j, k) represents that sequences[0:i] are decode-only,
+      sequences[i:j] are chunked-prefill-only, and sequences[j:k] are mixed.
+      The k is also the total number of sequences.
+    attention_sink: optional attention sink for each q head.
+    sm_scale: the softmax scale which will be applied to the Q@K^T.
+    sliding_window: the sliding window size for the attention.
+    soft_cap: the logit soft cap for the attention.
+    mask_value: mask value for causal mask.
+    q_scale: the scale for the query.
+    k_scale: the scale for the key cache.
+    v_scale: the scale for the value cache.
+    chunk_prefill_size: the chunk prefill size for the attention.
+    num_kv_pages_per_block: number of kv pages to be processed in one flash
+      attention block in the pallas kernel.
+    num_queries_per_block: number of queries to be processed in one flash
+      attention block in the pallas kernel.
+    vmem_limit_bytes: the vmem limit for the pallas kernel.
+    debug_mode: if true, RPA does not issue any DMAs or run flash attention
+      but prints debug info. Requires `--xla_tpu_enable_log_recorder`.
+
+  Returns:
+    The output of the attention.
+  """
   q, k, v = queries, keys, values
   static_validate_inputs(
       q,
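
For concreteness, the distribution tuple described in the docstring partitions the batch like this (numbers hypothetical):

# 7 sequences: the first 3 decode-only, the next 2 chunked-prefill-only,
# and the last 2 mixed.
distribution = (3, 5, 7)
i, j, k = distribution
# sequences[0:i] -> decode-only
# sequences[i:j] -> chunked-prefill-only
# sequences[j:k] -> mixed; k is also the total number of sequences
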
@@ -1384,6 +1466,7 @@ def ragged_paged_attention_hd64(
       page_size,
       max_num_tokens,
       pages_per_seq,
+      sliding_window,
   )
   bkv_sz = bkv_p * page_size
   if vmem_limit_bytes is None:
@@ -1397,7 +1480,7 @@ def ragged_paged_attention_hd64(
       pl.BlockSpec(memory_space=pltpu.HBM),
       pl.BlockSpec(memory_space=pltpu.HBM),
       None if attention_sink is None else pl.BlockSpec(
-          memory_space=pltpu.VMEM)
+          memory_space=pltpu.VMEM),
   ]

   out_specs = [
@@ -1454,47 +1537,45 @@ def ragged_paged_attention_hd64(
       jnp.full((6, ), -1, jnp.int32),
   )

-  scope_name = f"RPA-HD_64-bq_{bq_sz}-bkvp_{bkv_p}-p_{page_size}"
-  kernel = jax.named_scope(scope_name)(
-      pl.pallas_call(
-          functools.partial(
-              _ragged_paged_attention_kernel,
-              sm_scale=sm_scale,
-              sliding_window=sliding_window,
-              soft_cap=soft_cap,
-              mask_value=mask_value,
-              q_scale=q_scale,
-              k_scale=k_scale,
-              v_scale=v_scale,
-              chunk_prefill_size=chunk_prefill_size,
-              bq_sz=bq_sz,
-              bkv_p=bkv_p,
-              debug_mode=debug_mode,
-          ),
-          grid_spec=pltpu.PrefetchScalarGridSpec(
-              num_scalar_prefetch=len(scalar_prefetches),
-              in_specs=in_specs,
-              out_specs=out_specs,
-              grid=grid,
-              scratch_shapes=scratch_shapes,
-          ),
-          compiler_params=pltpu.CompilerParams(
-              # TODO(jevinjiang): since each sequence depends on the previous
-              # one, we need some extra work to support Megacore mode.
-              dimension_semantics=("arbitrary", ),
-              vmem_limit_bytes=vmem_limit_bytes,
-          ),
-          out_shape=[
-              jax.ShapeDtypeStruct(shape=q.shape, dtype=q.dtype),
-              jax.ShapeDtypeStruct(shape=kv_cache.shape,
-                                   dtype=kv_cache.dtype),
-          ],
-          input_output_aliases={
-              7: 0,
-              9: 1
-          },
-          name=scope_name,
-      ))
+  scope_name = get_kernel_scope_name(bq_sz, bkv_p, page_size)
+  kernel = pl.pallas_call(
+      functools.partial(
+          _ragged_paged_attention_kernel,
+          sm_scale=sm_scale,
+          sliding_window=sliding_window,
+          soft_cap=soft_cap,
+          mask_value=mask_value,
+          q_scale=q_scale,
+          k_scale=k_scale,
+          v_scale=v_scale,
+          chunk_prefill_size=chunk_prefill_size,
+          bq_sz=bq_sz,
+          bkv_p=bkv_p,
+          debug_mode=debug_mode,
+      ),
+      grid_spec=pltpu.PrefetchScalarGridSpec(
+          num_scalar_prefetch=len(scalar_prefetches),
+          in_specs=in_specs,
+          out_specs=out_specs,
+          grid=grid,
+          scratch_shapes=scratch_shapes,
+      ),
+      compiler_params=pltpu.CompilerParams(
+          # TODO(jevinjiang): since each sequence depends on the previous
+          # one, we need some extra work to support Megacore mode.
+          dimension_semantics=("arbitrary", ),
+          vmem_limit_bytes=vmem_limit_bytes,
+      ),
+      out_shape=[
+          jax.ShapeDtypeStruct(shape=q.shape, dtype=q.dtype),
+          jax.ShapeDtypeStruct(shape=kv_cache.shape, dtype=kv_cache.dtype),
+      ],
+      input_output_aliases={
+          7: 0,
+          9: 1
+      },
+      name=scope_name,
+  )

   output, updated_kv_cache = kernel(*scalar_prefetches, q, kv, kv_cache,
                                     attention_sink)