tpu-inference 0.11.1.dev202511270815__py3-none-any.whl → 0.13.0rc2.post7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of tpu-inference might be problematic. See the package registry page for more details.

Files changed (251)
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_disagg_utils.py +14 -0
  4. tests/core/test_dp_scheduler.py +650 -768
  5. tests/core/test_init.py +14 -0
  6. tests/distributed/__init__.py +13 -0
  7. tests/distributed/test_distributed_utils.py +120 -0
  8. tests/distributed/test_tpu_connector.py +478 -0
  9. tests/e2e/__init__.py +13 -0
  10. tests/e2e/test_async_scheduler.py +211 -0
  11. tests/e2e/test_data_parallel.py +289 -0
  12. tests/e2e/test_hybrid_kvcache.py +219 -0
  13. tests/e2e/test_local_disagg.py +257 -0
  14. tests/e2e/test_model_loader.py +268 -0
  15. tests/e2e/test_multi_modal_inference.py +111 -0
  16. tests/e2e/test_pipeline_parallel.py +265 -0
  17. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  18. tests/e2e/test_sampling_params.py +269 -0
  19. tests/e2e/test_speculative_decoding.py +311 -0
  20. tests/e2e/test_structured_decoding.py +46 -0
  21. tests/executors/__init__.py +13 -0
  22. tests/executors/test_ray_distributed_executor.py +199 -0
  23. tests/experimental/__init__.py +13 -0
  24. tests/experimental/test_llama3_jax_stashed.py +208 -0
  25. tests/kernels/__init__.py +13 -0
  26. tests/kernels/collectives/__init__.py +13 -0
  27. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  28. tests/kernels/fused_moe_v1_test.py +14 -0
  29. tests/kernels/gmm_test.py +205 -0
  30. tests/kernels/mla_v1_test.py +143 -41
  31. tests/kernels/quantized_matmul_kernel_test.py +2 -34
  32. tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
  33. tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
  34. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +17 -1
  35. tests/kernels/ragged_paged_attention_kernel_v3_test.py +17 -1
  36. tests/layers/__init__.py +13 -0
  37. tests/layers/common/__init__.py +13 -0
  38. tests/layers/common/test_attention_interface.py +156 -0
  39. tests/layers/common/test_quantization.py +149 -0
  40. tests/layers/jax/__init__.py +13 -0
  41. tests/layers/jax/attention/__init__.py +13 -0
  42. tests/layers/jax/attention/test_common_attention.py +103 -0
  43. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  44. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  45. tests/layers/jax/moe/__init__.py +13 -0
  46. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  47. tests/layers/jax/sample/__init__.py +13 -0
  48. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  49. tests/layers/jax/sample/test_sampling.py +115 -0
  50. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  51. tests/layers/jax/test_layers.py +155 -0
  52. tests/{test_quantization.py → layers/jax/test_qwix.py} +183 -50
  53. tests/layers/jax/test_rope.py +93 -0
  54. tests/layers/jax/test_sharding.py +159 -0
  55. tests/layers/jax/test_transformer_block.py +152 -0
  56. tests/layers/vllm/__init__.py +13 -0
  57. tests/layers/vllm/test_attention.py +363 -0
  58. tests/layers/vllm/test_awq.py +405 -0
  59. tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
  60. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +418 -0
  61. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +441 -0
  62. tests/layers/vllm/test_fp8.py +17 -0
  63. tests/layers/vllm/test_mxfp4.py +312 -0
  64. tests/layers/vllm/test_unquantized.py +651 -0
  65. tests/layers/vllm/utils.py +87 -0
  66. tests/lora/__init__.py +13 -0
  67. tests/lora/conftest.py +14 -0
  68. tests/lora/test_bgmv.py +14 -0
  69. tests/lora/test_layers.py +21 -3
  70. tests/lora/test_lora.py +15 -1
  71. tests/lora/test_lora_perf.py +67 -0
  72. tests/models/__init__.py +13 -0
  73. tests/models/common/__init__.py +13 -0
  74. tests/models/common/test_model_loader.py +455 -0
  75. tests/models/jax/__init__.py +13 -0
  76. tests/models/jax/test_deepseek_v3.py +401 -0
  77. tests/models/jax/test_llama3.py +184 -0
  78. tests/models/jax/test_llama4.py +298 -0
  79. tests/models/jax/test_llama_eagle3.py +197 -0
  80. tests/models/jax/test_llama_guard_4.py +242 -0
  81. tests/models/jax/test_qwen2.py +172 -0
  82. tests/models/jax/test_qwen2_5_vl.py +605 -0
  83. tests/models/jax/test_qwen3.py +169 -0
  84. tests/models/jax/test_weight_loading.py +180 -0
  85. tests/models/jax/utils/__init__.py +13 -0
  86. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  87. tests/platforms/__init__.py +13 -0
  88. tests/platforms/test_tpu_platform.py +54 -0
  89. tests/runner/__init__.py +13 -0
  90. tests/runner/test_block_table.py +395 -0
  91. tests/runner/test_input_batch.py +226 -0
  92. tests/runner/test_kv_cache.py +220 -0
  93. tests/runner/test_kv_cache_manager.py +498 -0
  94. tests/runner/test_multimodal_manager.py +429 -0
  95. tests/runner/test_persistent_batch_manager.py +84 -0
  96. tests/runner/test_speculative_decoding_manager.py +368 -0
  97. tests/runner/test_structured_decoding_manager.py +220 -0
  98. tests/runner/test_tpu_runner.py +261 -0
  99. tests/runner/test_tpu_runner_dp.py +1099 -0
  100. tests/runner/test_tpu_runner_mesh.py +200 -0
  101. tests/runner/test_utils.py +411 -0
  102. tests/spec_decode/__init__.py +13 -0
  103. tests/spec_decode/test_eagle3.py +311 -0
  104. tests/test_base.py +14 -0
  105. tests/test_envs.py +110 -12
  106. tests/test_tpu_info.py +14 -0
  107. tests/test_utils.py +2 -45
  108. tests/worker/__init__.py +13 -0
  109. tests/worker/tpu_worker_test.py +414 -0
  110. tpu_inference/__init__.py +14 -0
  111. tpu_inference/core/__init__.py +13 -0
  112. tpu_inference/core/sched/__init__.py +13 -0
  113. tpu_inference/core/sched/dp_scheduler.py +372 -56
  114. tpu_inference/distributed/__init__.py +13 -0
  115. tpu_inference/distributed/jax_parallel_state.py +14 -0
  116. tpu_inference/distributed/tpu_connector.py +15 -10
  117. tpu_inference/distributed/utils.py +56 -4
  118. tpu_inference/envs.py +92 -8
  119. tpu_inference/executors/__init__.py +13 -0
  120. tpu_inference/executors/ray_distributed_executor.py +22 -1
  121. tpu_inference/experimental/__init__.py +13 -0
  122. tpu_inference/experimental/llama3_jax_stashed.py +14 -0
  123. tpu_inference/kernels/__init__.py +13 -0
  124. tpu_inference/kernels/collectives/__init__.py +13 -0
  125. tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
  126. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
  127. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  128. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  129. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  130. tpu_inference/kernels/fused_moe/v1/kernel.py +370 -324
  131. tpu_inference/kernels/megablox/__init__.py +13 -0
  132. tpu_inference/kernels/megablox/common.py +54 -0
  133. tpu_inference/kernels/megablox/gmm.py +646 -0
  134. tpu_inference/kernels/mla/__init__.py +13 -0
  135. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  136. tpu_inference/kernels/mla/v1/kernel.py +117 -145
  137. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  138. tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
  139. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  140. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  141. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
  142. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
  143. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  144. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +194 -101
  145. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +167 -97
  146. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
  147. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +376 -195
  148. tpu_inference/kernels/ragged_paged_attention/v3/util.py +15 -1
  149. tpu_inference/layers/__init__.py +13 -0
  150. tpu_inference/layers/common/__init__.py +13 -0
  151. tpu_inference/layers/common/attention_interface.py +26 -19
  152. tpu_inference/layers/common/attention_metadata.py +14 -0
  153. tpu_inference/layers/common/quant_methods.py +15 -0
  154. tpu_inference/layers/common/quantization.py +270 -0
  155. tpu_inference/layers/common/sharding.py +31 -9
  156. tpu_inference/layers/jax/__init__.py +13 -0
  157. tpu_inference/layers/jax/attention/__init__.py +13 -0
  158. tpu_inference/layers/jax/attention/attention.py +19 -6
  159. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +270 -77
  160. tpu_inference/layers/jax/attention/gpt_oss_attention.py +24 -11
  161. tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
  162. tpu_inference/layers/jax/base.py +14 -0
  163. tpu_inference/layers/jax/constants.py +13 -0
  164. tpu_inference/layers/jax/layers.py +14 -0
  165. tpu_inference/layers/jax/misc.py +14 -0
  166. tpu_inference/layers/jax/moe/__init__.py +13 -0
  167. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
  168. tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
  169. tpu_inference/layers/jax/moe/moe.py +43 -3
  170. tpu_inference/layers/jax/pp_utils.py +53 -0
  171. tpu_inference/layers/jax/rope.py +14 -0
  172. tpu_inference/layers/jax/rope_interface.py +14 -0
  173. tpu_inference/layers/jax/sample/__init__.py +13 -0
  174. tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
  175. tpu_inference/layers/jax/sample/sampling.py +15 -1
  176. tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
  177. tpu_inference/layers/jax/transformer_block.py +14 -0
  178. tpu_inference/layers/vllm/__init__.py +13 -0
  179. tpu_inference/layers/vllm/attention.py +4 -4
  180. tpu_inference/layers/vllm/fused_moe.py +210 -260
  181. tpu_inference/layers/vllm/linear_common.py +57 -22
  182. tpu_inference/layers/vllm/quantization/__init__.py +16 -0
  183. tpu_inference/layers/vllm/quantization/awq.py +15 -1
  184. tpu_inference/layers/vllm/quantization/common.py +33 -18
  185. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  186. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +18 -3
  187. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +211 -148
  188. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  189. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +14 -0
  190. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +14 -0
  191. tpu_inference/layers/vllm/quantization/fp8.py +118 -0
  192. tpu_inference/layers/vllm/quantization/mxfp4.py +280 -210
  193. tpu_inference/layers/vllm/quantization/unquantized.py +134 -86
  194. tpu_inference/layers/vllm/sharding.py +21 -4
  195. tpu_inference/lora/__init__.py +13 -0
  196. tpu_inference/lora/torch_lora_ops.py +8 -13
  197. tpu_inference/models/__init__.py +13 -0
  198. tpu_inference/models/common/__init__.py +13 -0
  199. tpu_inference/models/common/model_loader.py +77 -36
  200. tpu_inference/models/jax/__init__.py +13 -0
  201. tpu_inference/models/jax/deepseek_v3.py +267 -157
  202. tpu_inference/models/jax/gpt_oss.py +26 -10
  203. tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
  204. tpu_inference/models/jax/llama3.py +99 -36
  205. tpu_inference/models/jax/llama4.py +14 -0
  206. tpu_inference/models/jax/llama_eagle3.py +14 -0
  207. tpu_inference/models/jax/llama_guard_4.py +15 -1
  208. tpu_inference/models/jax/qwen2.py +17 -2
  209. tpu_inference/models/jax/qwen2_5_vl.py +18 -4
  210. tpu_inference/models/jax/qwen3.py +17 -2
  211. tpu_inference/models/jax/utils/__init__.py +13 -0
  212. tpu_inference/models/jax/utils/file_utils.py +14 -0
  213. tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
  214. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  215. tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +91 -31
  216. tpu_inference/models/jax/utils/weight_utils.py +39 -2
  217. tpu_inference/models/vllm/__init__.py +13 -0
  218. tpu_inference/models/vllm/vllm_model_wrapper.py +20 -4
  219. tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
  220. tpu_inference/platforms/__init__.py +14 -0
  221. tpu_inference/platforms/tpu_platform.py +47 -71
  222. tpu_inference/runner/__init__.py +13 -0
  223. tpu_inference/runner/compilation_manager.py +158 -63
  224. tpu_inference/runner/kv_cache.py +54 -20
  225. tpu_inference/runner/kv_cache_manager.py +53 -30
  226. tpu_inference/runner/lora_utils.py +14 -0
  227. tpu_inference/runner/multimodal_manager.py +15 -1
  228. tpu_inference/runner/persistent_batch_manager.py +54 -2
  229. tpu_inference/runner/speculative_decoding_manager.py +14 -0
  230. tpu_inference/runner/structured_decoding_manager.py +14 -0
  231. tpu_inference/runner/tpu_runner.py +105 -57
  232. tpu_inference/runner/utils.py +2 -2
  233. tpu_inference/spec_decode/__init__.py +13 -0
  234. tpu_inference/spec_decode/jax/__init__.py +13 -0
  235. tpu_inference/spec_decode/jax/eagle3.py +65 -19
  236. tpu_inference/tpu_info.py +14 -0
  237. tpu_inference/utils.py +72 -44
  238. tpu_inference/worker/__init__.py +13 -0
  239. tpu_inference/worker/tpu_worker.py +65 -52
  240. {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/METADATA +11 -9
  241. tpu_inference-0.13.0rc2.post7.dist-info/RECORD +261 -0
  242. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  243. tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
  244. tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
  245. tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
  246. tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
  247. tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
  248. tpu_inference-0.11.1.dev202511270815.dist-info/RECORD +0 -174
  249. {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/WHEEL +0 -0
  250. {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/licenses/LICENSE +0 -0
  251. {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/top_level.txt +0 -0
@@ -1,3 +1,16 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
1
14
  """
2
15
  A variant of TPU-Friendly Ragged Paged Attention kernel optimized for
3
16
  head_dim = 64.
@@ -267,7 +280,6 @@ def _ragged_paged_attention_kernel(
267
280
  *,
268
281
  sm_scale: float,
269
282
  sliding_window: int | None = None,
270
- strict_sliding_window: bool = True,
271
283
  soft_cap: float | None = None,
272
284
  mask_value: float = DEFAULT_MASK_VALUE,
273
285
  q_scale: float | None = None,
@@ -324,14 +336,15 @@ def _ragged_paged_attention_kernel(
324
336
  bkv_idx_start = jnp.maximum(kv_len - q_len - sliding_window,
325
337
  0) // bkv_sz
326
338
 
327
- def get_next_bkv_idx_start():
328
- next_kv_len = kv_lens_ref[seq_idx + 1]
329
- next_q_len = cu_q_lens_ref[seq_idx + 2] - q_end
330
- return jnp.maximum(next_kv_len - next_q_len - sliding_window,
331
- 0) // bkv_sz
332
-
333
- next_seq_bkv_idx_start = lax.cond(seq_idx + 1 < num_seqs,
334
- get_next_bkv_idx_start, lambda: 0)
339
+ # If seq_idx + 1 == num_seqs, kv_lens_ref[seq_idx + 1] will trigger a
340
+ # out-of-bound error. To avoid this, we set upperbound of next_seq_idx
341
+ # to be num_seqs - 1.
342
+ next_seq_idx = jnp.minimum(seq_idx + 1, num_seqs - 1)
343
+ next_kv_len = kv_lens_ref[next_seq_idx]
344
+ next_q_len = cu_q_lens_ref[next_seq_idx + 1] - q_end
345
+ next_seq_bkv_idx_start = (
346
+ jnp.maximum(next_kv_len - next_q_len - sliding_window, 0) //
347
+ bkv_sz)
335
348
 
336
349
  def debug_print(msg, *args):
337
350
  if debug_mode:
@@ -352,7 +365,7 @@ def _ragged_paged_attention_kernel(
352
365
  debug_print("[RPA debug] q_len={}", q_len)
353
366
  debug_print("[RPA debug] kv_len={}", kv_len)
354
367
 
355
- def flash_attention(
368
+ def flash_attention_step1_qk_softmax(
356
369
  q, # [actual_bq_sz * num_q_heads_per_kv_head, actual_head_dim_x2]
357
370
  kv, # [bkv_sz, actual_head_dim_x2]
358
371
  *,
@@ -366,7 +379,6 @@ def _ragged_paged_attention_kernel(
366
379
  assert kv.shape == (bkv_sz, actual_head_dim_x2)
367
380
  head_l_ref = l_ref.at[kv_head_idx, :q.shape[0]]
368
381
  head_m_ref = m_ref.at[kv_head_idx, :q.shape[0]]
369
- head_acc_ref = acc_ref.at[kv_head_idx, :q.shape[0]]
370
382
 
371
383
  def load_with_init(ref, init_val):
372
384
  return jnp.where(bkv_idx == bkv_idx_start,
@@ -397,7 +409,7 @@ def _ragged_paged_attention_kernel(
397
409
  k_span = bkv_idx * bkv_sz + lax.broadcasted_iota(jnp.int32, s.shape, 1)
398
410
  mask = k_span <= q_span
399
411
 
400
- if sliding_window is not None and strict_sliding_window:
412
+ if sliding_window is not None:
401
413
  mask = jnp.logical_and(mask, q_span - sliding_window < k_span)
402
414
 
403
415
  s = jnp.where(mask, s, mask_value)
@@ -416,15 +428,33 @@ def _ragged_paged_attention_kernel(
416
428
  head_m_ref[...] = m_curr
417
429
  p = jnp.exp(s - broadcast_minor(m_curr, s.shape))
418
430
 
419
- pv = jnp.einsum("nm,md->nd", p, kv, preferred_element_type=jnp.float32)
420
- if v_scale is not None:
421
- pv *= v_scale
422
-
423
431
  p_rowsum = jnp.sum(p, axis=1, keepdims=True)
424
432
  exp_m_diff = jnp.exp(m_prev - m_curr)
425
433
  l_prev = load_with_init(head_l_ref, 1.0)
426
434
  l_curr = exp_m_diff * l_prev + p_rowsum
427
435
  head_l_ref[...] = l_curr
436
+
437
+ return p, exp_m_diff
438
+
439
+ def flash_attention_step2_pv(
440
+ q_shape_0,
441
+ kv, # [bkv_sz, actual_head_dim_x2]
442
+ p, # from step1
443
+ exp_m_diff, # from step1
444
+ *,
445
+ bkv_idx,
446
+ kv_head_idx,
447
+ ):
448
+ head_acc_ref = acc_ref.at[kv_head_idx, :q_shape_0]
449
+
450
+ def load_with_init(ref, init_val):
451
+ return jnp.where(bkv_idx == bkv_idx_start,
452
+ jnp.full_like(ref, init_val), ref[...])
453
+
454
+ pv = jnp.einsum("nm,md->nd", p, kv, preferred_element_type=jnp.float32)
455
+ if v_scale is not None:
456
+ pv *= v_scale
457
+
428
458
  o_prev = load_with_init(head_acc_ref, 0.0)
429
459
  o_curr = broadcast_minor(exp_m_diff, o_prev.shape) * o_prev + pv
430
460
  head_acc_ref[...] = o_curr
@@ -503,19 +533,16 @@ def _ragged_paged_attention_kernel(
503
533
  unroll=False,
504
534
  )
505
535
 
506
- # Fetch kv directly from new kv.
507
- @pl.when(bkv_sz_frm_new > 0)
508
- def _fetch_bkv_from_new_kv():
509
- new_kv_len_start = q_end - kv_left_frm_new
510
- debug_print("[RPA debug] new_kv_len_start={}",
511
- new_kv_len_start)
512
- debug_print("[RPA debug] offset_in_bkv={}", offset)
513
- _async_copy(
514
- kv_hbm_ref.at[pl.ds(new_kv_len_start, bkv_sz_frm_new)],
515
- vmem_ref.at[pl.ds(offset, bkv_sz_frm_new)],
516
- sem,
517
- wait,
518
- )
536
+ size = lax.select(bkv_sz_frm_new > 0, bkv_sz_frm_new, 0)
537
+ new_kv_len_start = q_end - kv_left_frm_new
538
+ debug_print("[RPA debug] new_kv_len_start={}", new_kv_len_start)
539
+ debug_print("[RPA debug] offset_in_bkv={}", offset)
540
+ _async_copy(
541
+ kv_hbm_ref.at[pl.ds(new_kv_len_start, size)],
542
+ vmem_ref.at[pl.ds(offset, size)],
543
+ sem,
544
+ wait,
545
+ )
519
546
 
520
547
  return kv_len_start + offset, bkv_sz_frm_new
521
548
  else:
@@ -709,7 +736,7 @@ def _ragged_paged_attention_kernel(
709
736
  vec = ref[start::step]
710
737
  return vec
711
738
 
712
- def strided_load_bkv(bkv_sem_idx, start, step, *, bkv_mask):
739
+ def strided_load_bkv(bkv_sem_idx, start, step):
713
740
  assert start % kv_packing == 0
714
741
  assert step % kv_packing == 0
715
742
  start //= kv_packing
@@ -718,7 +745,6 @@ def _ragged_paged_attention_kernel(
718
745
  bkv_sz * step, actual_head_dim_x2))
719
746
 
720
747
  kv = strided_load(kv_ref, start, step)
721
- kv = lax.select(bkv_mask, kv, jnp.zeros_like(kv))
722
748
  bitwidth = 32 // kv_packing
723
749
  repack_ty = jnp.dtype(f"uint{bitwidth}")
724
750
  lst = []
@@ -766,14 +792,18 @@ def _ragged_paged_attention_kernel(
766
792
  next_bkv_sem_idx = lax.select(bkv_sem_idx == 0, 1, 0)
767
793
 
768
794
  if sliding_window is None:
769
- next_bkv_start_idx = 0
795
+ # When sliding window is disabled, starting bkv_idx of next request is
796
+ # always 0 regardless of seq_idx of next request.
797
+ next_bkv_idx_start = 0
770
798
  else:
771
- next_bkv_start_idx = lax.select(
799
+ # Determine starting bkv_idx of next request based on whether next
800
+ # request is from the same sequence or next sequence.
801
+ next_bkv_idx_start = lax.select(
772
802
  is_last_bq,
773
803
  next_seq_bkv_idx_start,
774
804
  bkv_idx_start,
775
805
  )
776
- next_bkv_idx = lax.select(is_last_bkv, next_bkv_start_idx,
806
+ next_bkv_idx = lax.select(is_last_bkv, next_bkv_idx_start,
777
807
  next_bkv_idx)
778
808
 
779
809
  return next_seq_idx, next_bq_idx, next_bkv_idx, next_bkv_sem_idx
@@ -792,10 +822,6 @@ def _ragged_paged_attention_kernel(
792
822
  def compute_with_bkv(bkv_idx, _):
793
823
  # Create bitmask for KV.
794
824
  assert bkv_sz % kv_packing == 0
795
- actual_bkv_sz = jnp.minimum(bkv_sz, kv_len - bkv_idx * bkv_sz)
796
- bkv_shape = (bkv_sz, actual_head_dim_x2)
797
- bkv_mask = lax.broadcasted_iota(jnp.int32, bkv_shape,
798
- 0) < actual_bkv_sz
799
825
 
800
826
  # Get next bkv ids.
801
827
  bkv_sem_idx = sem_ids_ref[1]
@@ -835,29 +861,64 @@ def _ragged_paged_attention_kernel(
835
861
  return
836
862
 
837
863
  # Flash attention with cur bkv and bq
864
+ prev_bq_shape_0 = None
865
+ prev_kv_head_bkv = None
866
+ prev_kv_head_idx = None
867
+ prev_kv_head_p = None
868
+ prev_kv_head_exp_m_diff = None
838
869
  for kv_head_start in range(0, actual_num_kv_heads, kv_packing):
839
870
  bkv_lst = strided_load_bkv(
840
871
  bkv_sem_idx,
841
872
  kv_head_start,
842
873
  num_kv_heads,
843
- bkv_mask=bkv_mask,
844
874
  )
845
875
  assert len(bkv_lst) == kv_packing
846
876
  for i in range(kv_packing):
847
- kv_head_idx = kv_head_start + i
848
- if kv_head_idx >= actual_num_kv_heads:
877
+ cur_kv_head_idx = kv_head_start + i
878
+ if cur_kv_head_idx >= actual_num_kv_heads:
849
879
  break
850
- bq = load_bq(bq_sem_idx,
851
- kv_head_idx,
852
- actual_bq_sz=actual_bq_sz)
853
- bkv = bkv_lst[i]
854
- flash_attention(
855
- bq,
856
- bkv,
857
- bq_idx=bq_idx,
858
- bkv_idx=bkv_idx,
859
- kv_head_idx=kv_head_idx,
860
- )
880
+ cur_kv_head_bq = load_bq(bq_sem_idx,
881
+ cur_kv_head_idx,
882
+ actual_bq_sz=actual_bq_sz)
883
+ cur_kv_head__bkv = bkv_lst[i]
884
+ # FlashAttention is divided into `flash_attention_step1_qk_softmax`
885
+ # and `flash_attention_step2_pv` to pipeline the computation.
886
+ # `step2_pv` for the previous KV head, which depends on the softmax
887
+ # output, is overlapped with `step1_qk_softmax` for the current KV
888
+ # head, reducing overall wait times.
889
+ cur_kv_head_p, cur_kv_head_exp_m_diff = (
890
+ flash_attention_step1_qk_softmax(
891
+ cur_kv_head_bq,
892
+ cur_kv_head__bkv,
893
+ bq_idx=bq_idx,
894
+ bkv_idx=bkv_idx,
895
+ kv_head_idx=cur_kv_head_idx,
896
+ ))
897
+ if prev_bq_shape_0 is not None:
898
+ flash_attention_step2_pv(
899
+ prev_bq_shape_0,
900
+ prev_kv_head_bkv,
901
+ prev_kv_head_p,
902
+ prev_kv_head_exp_m_diff,
903
+ bkv_idx=bkv_idx,
904
+ kv_head_idx=prev_kv_head_idx,
905
+ )
906
+ prev_bq_shape_0 = cur_kv_head_bq.shape[0]
907
+ prev_kv_head_bkv = cur_kv_head__bkv
908
+ prev_kv_head_p = cur_kv_head_p
909
+ prev_kv_head_exp_m_diff = cur_kv_head_exp_m_diff
910
+ prev_kv_head_idx = cur_kv_head_idx
911
+
912
+ # Execute pv of last attention head.
913
+ assert prev_bq_shape_0 is not None
914
+ flash_attention_step2_pv(
915
+ prev_bq_shape_0,
916
+ prev_kv_head_bkv,
917
+ prev_kv_head_p,
918
+ prev_kv_head_exp_m_diff,
919
+ bkv_idx=bkv_idx,
920
+ kv_head_idx=prev_kv_head_idx,
921
+ )
861
922
 
862
923
  lax.fori_loop(bkv_idx_start,
863
924
  num_bkv,
@@ -893,7 +954,17 @@ def _ragged_paged_attention_kernel(
893
954
  @pl.when(seq_idx == 0)
894
955
  def prologue():
895
956
  start_fetch_bq(0, 0, 0)
957
+
958
+ # Initialize bkv_x2_ref to zeros to avoid NaN issues from accessing
959
+ # uninitialized memory. Bitcast into int32 to avoid tiling issues.
960
+ bkv_x2_int32_ref = bkv_x2_ref.bitcast(jnp.int32).reshape(
961
+ (2, -1, 8, 128))
962
+ zeros = jnp.zeros(bkv_x2_int32_ref.shape[1:], jnp.int32)
963
+
964
+ # To pipeline VST and DMA, we divide the initialization into two steps.
965
+ bkv_x2_int32_ref[0] = zeros
896
966
  start_fetch_bkv(0, bkv_idx_start, 0)
967
+ bkv_x2_int32_ref[1] = zeros
897
968
 
898
969
  @pl.when(seq_idx < decode_end)
899
970
  def process_decode():
@@ -1253,12 +1324,15 @@ def static_validate_inputs(
1253
1324
  del debug_mode
1254
1325
 
1255
1326
 
1327
+ def get_kernel_scope_name(bq_size, bkv_p, page_size):
1328
+ return f"RPA-HD_64-bq_{bq_size}-bkvp_{bkv_p}-p_{page_size}-"
1329
+
1330
+
1256
1331
  @functools.partial(
1257
1332
  jax.jit,
1258
1333
  static_argnames=(
1259
1334
  "sm_scale",
1260
1335
  "sliding_window",
1261
- "strict_sliding_window",
1262
1336
  "soft_cap",
1263
1337
  "mask_value",
1264
1338
  "q_scale",
@@ -1288,7 +1362,6 @@ def ragged_paged_attention_hd64(
1288
1362
  *,
1289
1363
  sm_scale: float = 1.0,
1290
1364
  sliding_window: int | None = None,
1291
- strict_sliding_window: bool = True,
1292
1365
  soft_cap: float | None = None,
1293
1366
  mask_value: float | None = DEFAULT_MASK_VALUE,
1294
1367
  q_scale: float | None = None,
@@ -1320,7 +1393,6 @@ def ragged_paged_attention_hd64(
1320
1393
  attention_sink: optional attention sink for each q head.
1321
1394
  sm_scale: the softmax scale which will be applied to the Q@K^T.
1322
1395
  sliding_window: the sliding window size for the attention.
1323
- strict_sliding_window: compute tokens that are strictly within the window.
1324
1396
  soft_cap: the logit soft cap for the attention.
1325
1397
  mask_value: mask value for causal mask.
1326
1398
  q_scale: the scale for the query.
@@ -1394,6 +1466,7 @@ def ragged_paged_attention_hd64(
1394
1466
  page_size,
1395
1467
  max_num_tokens,
1396
1468
  pages_per_seq,
1469
+ sliding_window,
1397
1470
  )
1398
1471
  bkv_sz = bkv_p * page_size
1399
1472
  if vmem_limit_bytes is None:
@@ -1464,48 +1537,45 @@ def ragged_paged_attention_hd64(
1464
1537
  jnp.full((6, ), -1, jnp.int32),
1465
1538
  )
1466
1539
 
1467
- scope_name = f"RPA-HD_64-bq_{bq_sz}-bkvp_{bkv_p}-p_{page_size}"
1468
- kernel = jax.named_scope(scope_name)(
1469
- pl.pallas_call(
1470
- functools.partial(
1471
- _ragged_paged_attention_kernel,
1472
- sm_scale=sm_scale,
1473
- sliding_window=sliding_window,
1474
- strict_sliding_window=strict_sliding_window,
1475
- soft_cap=soft_cap,
1476
- mask_value=mask_value,
1477
- q_scale=q_scale,
1478
- k_scale=k_scale,
1479
- v_scale=v_scale,
1480
- chunk_prefill_size=chunk_prefill_size,
1481
- bq_sz=bq_sz,
1482
- bkv_p=bkv_p,
1483
- debug_mode=debug_mode,
1484
- ),
1485
- grid_spec=pltpu.PrefetchScalarGridSpec(
1486
- num_scalar_prefetch=len(scalar_prefetches),
1487
- in_specs=in_specs,
1488
- out_specs=out_specs,
1489
- grid=grid,
1490
- scratch_shapes=scratch_shapes,
1491
- ),
1492
- compiler_params=pltpu.CompilerParams(
1493
- # TODO(jevinjiang): since each sequence depends on the previous
1494
- # one, we need some extra work to support Megacore mode.
1495
- dimension_semantics=("arbitrary", ),
1496
- vmem_limit_bytes=vmem_limit_bytes,
1497
- ),
1498
- out_shape=[
1499
- jax.ShapeDtypeStruct(shape=q.shape, dtype=q.dtype),
1500
- jax.ShapeDtypeStruct(shape=kv_cache.shape,
1501
- dtype=kv_cache.dtype),
1502
- ],
1503
- input_output_aliases={
1504
- 7: 0,
1505
- 9: 1
1506
- },
1507
- name=scope_name,
1508
- ))
1540
+ scope_name = get_kernel_scope_name(bq_sz, bkv_p, page_size)
1541
+ kernel = pl.pallas_call(
1542
+ functools.partial(
1543
+ _ragged_paged_attention_kernel,
1544
+ sm_scale=sm_scale,
1545
+ sliding_window=sliding_window,
1546
+ soft_cap=soft_cap,
1547
+ mask_value=mask_value,
1548
+ q_scale=q_scale,
1549
+ k_scale=k_scale,
1550
+ v_scale=v_scale,
1551
+ chunk_prefill_size=chunk_prefill_size,
1552
+ bq_sz=bq_sz,
1553
+ bkv_p=bkv_p,
1554
+ debug_mode=debug_mode,
1555
+ ),
1556
+ grid_spec=pltpu.PrefetchScalarGridSpec(
1557
+ num_scalar_prefetch=len(scalar_prefetches),
1558
+ in_specs=in_specs,
1559
+ out_specs=out_specs,
1560
+ grid=grid,
1561
+ scratch_shapes=scratch_shapes,
1562
+ ),
1563
+ compiler_params=pltpu.CompilerParams(
1564
+ # TODO(jevinjiang): since each sequence depends on the previous
1565
+ # one, we need some extra work to support Megacore mode.
1566
+ dimension_semantics=("arbitrary", ),
1567
+ vmem_limit_bytes=vmem_limit_bytes,
1568
+ ),
1569
+ out_shape=[
1570
+ jax.ShapeDtypeStruct(shape=q.shape, dtype=q.dtype),
1571
+ jax.ShapeDtypeStruct(shape=kv_cache.shape, dtype=kv_cache.dtype),
1572
+ ],
1573
+ input_output_aliases={
1574
+ 7: 0,
1575
+ 9: 1
1576
+ },
1577
+ name=scope_name,
1578
+ )
1509
1579
 
1510
1580
  output, updated_kv_cache = kernel(*scalar_prefetches, q, kv, kv_cache,
1511
1581
  attention_sink)