tpu-inference 0.11.1.dev202512030818__py3-none-any.whl → 0.13.2rc3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of tpu-inference might be problematic. See the package's registry page for more details.

Files changed (250)
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_disagg_utils.py +14 -0
  4. tests/core/test_dp_scheduler.py +650 -768
  5. tests/core/test_init.py +14 -0
  6. tests/distributed/__init__.py +13 -0
  7. tests/distributed/test_distributed_utils.py +120 -0
  8. tests/distributed/test_tpu_connector.py +478 -0
  9. tests/e2e/__init__.py +13 -0
  10. tests/e2e/test_async_scheduler.py +211 -0
  11. tests/e2e/test_data_parallel.py +289 -0
  12. tests/e2e/test_hybrid_kvcache.py +219 -0
  13. tests/e2e/test_local_disagg.py +257 -0
  14. tests/e2e/test_model_loader.py +268 -0
  15. tests/e2e/test_multi_modal_inference.py +111 -0
  16. tests/e2e/test_pipeline_parallel.py +265 -0
  17. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  18. tests/e2e/test_sampling_params.py +269 -0
  19. tests/e2e/test_speculative_decoding.py +311 -0
  20. tests/e2e/test_structured_decoding.py +46 -0
  21. tests/executors/__init__.py +13 -0
  22. tests/executors/test_ray_distributed_executor.py +199 -0
  23. tests/experimental/__init__.py +13 -0
  24. tests/experimental/test_llama3_jax_stashed.py +208 -0
  25. tests/kernels/__init__.py +13 -0
  26. tests/kernels/collectives/__init__.py +13 -0
  27. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  28. tests/kernels/fused_moe_v1_test.py +14 -0
  29. tests/kernels/gmm_test.py +205 -0
  30. tests/kernels/mla_v1_test.py +143 -41
  31. tests/kernels/quantized_matmul_kernel_test.py +2 -34
  32. tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
  33. tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
  34. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +17 -1
  35. tests/kernels/ragged_paged_attention_kernel_v3_test.py +17 -1
  36. tests/layers/__init__.py +13 -0
  37. tests/layers/common/__init__.py +13 -0
  38. tests/layers/common/test_attention_interface.py +156 -0
  39. tests/layers/common/test_quantization.py +149 -0
  40. tests/layers/jax/__init__.py +13 -0
  41. tests/layers/jax/attention/__init__.py +13 -0
  42. tests/layers/jax/attention/test_common_attention.py +103 -0
  43. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  44. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  45. tests/layers/jax/moe/__init__.py +13 -0
  46. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  47. tests/layers/jax/sample/__init__.py +13 -0
  48. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  49. tests/layers/jax/sample/test_sampling.py +115 -0
  50. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  51. tests/layers/jax/test_layers.py +155 -0
  52. tests/{test_quantization.py → layers/jax/test_qwix.py} +183 -50
  53. tests/layers/jax/test_rope.py +93 -0
  54. tests/layers/jax/test_sharding.py +159 -0
  55. tests/layers/jax/test_transformer_block.py +152 -0
  56. tests/layers/vllm/__init__.py +13 -0
  57. tests/layers/vllm/test_attention.py +363 -0
  58. tests/layers/vllm/test_awq.py +405 -0
  59. tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
  60. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +418 -0
  61. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +441 -0
  62. tests/layers/vllm/test_fp8.py +17 -0
  63. tests/layers/vllm/test_mxfp4.py +312 -0
  64. tests/layers/vllm/test_unquantized.py +651 -0
  65. tests/layers/vllm/utils.py +87 -0
  66. tests/lora/__init__.py +13 -0
  67. tests/lora/conftest.py +14 -0
  68. tests/lora/test_bgmv.py +14 -0
  69. tests/lora/test_layers.py +21 -3
  70. tests/lora/test_lora.py +15 -1
  71. tests/lora/test_lora_perf.py +67 -0
  72. tests/models/__init__.py +13 -0
  73. tests/models/common/__init__.py +13 -0
  74. tests/models/common/test_model_loader.py +455 -0
  75. tests/models/jax/__init__.py +13 -0
  76. tests/models/jax/test_deepseek_v3.py +401 -0
  77. tests/models/jax/test_llama3.py +184 -0
  78. tests/models/jax/test_llama4.py +298 -0
  79. tests/models/jax/test_llama_eagle3.py +197 -0
  80. tests/models/jax/test_llama_guard_4.py +242 -0
  81. tests/models/jax/test_qwen2.py +172 -0
  82. tests/models/jax/test_qwen2_5_vl.py +605 -0
  83. tests/models/jax/test_qwen3.py +169 -0
  84. tests/models/jax/test_weight_loading.py +180 -0
  85. tests/models/jax/utils/__init__.py +13 -0
  86. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  87. tests/platforms/__init__.py +13 -0
  88. tests/platforms/test_tpu_platform.py +54 -0
  89. tests/runner/__init__.py +13 -0
  90. tests/runner/test_block_table.py +395 -0
  91. tests/runner/test_input_batch.py +226 -0
  92. tests/runner/test_kv_cache.py +220 -0
  93. tests/runner/test_kv_cache_manager.py +498 -0
  94. tests/runner/test_multimodal_manager.py +429 -0
  95. tests/runner/test_persistent_batch_manager.py +84 -0
  96. tests/runner/test_speculative_decoding_manager.py +368 -0
  97. tests/runner/test_structured_decoding_manager.py +220 -0
  98. tests/runner/test_tpu_runner.py +261 -0
  99. tests/runner/test_tpu_runner_dp.py +1099 -0
  100. tests/runner/test_tpu_runner_mesh.py +200 -0
  101. tests/runner/test_utils.py +411 -0
  102. tests/spec_decode/__init__.py +13 -0
  103. tests/spec_decode/test_eagle3.py +311 -0
  104. tests/test_base.py +14 -0
  105. tests/test_envs.py +78 -1
  106. tests/test_tpu_info.py +14 -0
  107. tests/test_utils.py +1 -43
  108. tests/worker/__init__.py +13 -0
  109. tests/worker/tpu_worker_test.py +414 -0
  110. tpu_inference/__init__.py +14 -0
  111. tpu_inference/core/__init__.py +13 -0
  112. tpu_inference/core/sched/__init__.py +13 -0
  113. tpu_inference/core/sched/dp_scheduler.py +372 -56
  114. tpu_inference/distributed/__init__.py +13 -0
  115. tpu_inference/distributed/jax_parallel_state.py +14 -0
  116. tpu_inference/distributed/tpu_connector.py +14 -9
  117. tpu_inference/distributed/utils.py +56 -4
  118. tpu_inference/envs.py +38 -7
  119. tpu_inference/executors/__init__.py +13 -0
  120. tpu_inference/executors/ray_distributed_executor.py +17 -0
  121. tpu_inference/experimental/__init__.py +13 -0
  122. tpu_inference/experimental/llama3_jax_stashed.py +14 -0
  123. tpu_inference/kernels/__init__.py +13 -0
  124. tpu_inference/kernels/collectives/__init__.py +13 -0
  125. tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
  126. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
  127. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  128. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  129. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  130. tpu_inference/kernels/fused_moe/v1/kernel.py +370 -324
  131. tpu_inference/kernels/megablox/__init__.py +13 -0
  132. tpu_inference/kernels/megablox/common.py +54 -0
  133. tpu_inference/kernels/megablox/gmm.py +646 -0
  134. tpu_inference/kernels/mla/__init__.py +13 -0
  135. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  136. tpu_inference/kernels/mla/v1/kernel.py +117 -145
  137. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  138. tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
  139. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  140. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  141. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
  142. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
  143. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  144. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +194 -101
  145. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +95 -78
  146. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
  147. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +376 -195
  148. tpu_inference/kernels/ragged_paged_attention/v3/util.py +15 -1
  149. tpu_inference/layers/__init__.py +13 -0
  150. tpu_inference/layers/common/__init__.py +13 -0
  151. tpu_inference/layers/common/attention_interface.py +26 -19
  152. tpu_inference/layers/common/attention_metadata.py +14 -0
  153. tpu_inference/layers/common/quant_methods.py +15 -0
  154. tpu_inference/layers/common/quantization.py +270 -0
  155. tpu_inference/layers/common/sharding.py +28 -5
  156. tpu_inference/layers/jax/__init__.py +13 -0
  157. tpu_inference/layers/jax/attention/__init__.py +13 -0
  158. tpu_inference/layers/jax/attention/attention.py +19 -6
  159. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +270 -77
  160. tpu_inference/layers/jax/attention/gpt_oss_attention.py +24 -11
  161. tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
  162. tpu_inference/layers/jax/base.py +14 -0
  163. tpu_inference/layers/jax/constants.py +13 -0
  164. tpu_inference/layers/jax/layers.py +14 -0
  165. tpu_inference/layers/jax/misc.py +14 -0
  166. tpu_inference/layers/jax/moe/__init__.py +13 -0
  167. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
  168. tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
  169. tpu_inference/layers/jax/moe/moe.py +43 -3
  170. tpu_inference/layers/jax/pp_utils.py +53 -0
  171. tpu_inference/layers/jax/rope.py +14 -0
  172. tpu_inference/layers/jax/rope_interface.py +14 -0
  173. tpu_inference/layers/jax/sample/__init__.py +13 -0
  174. tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
  175. tpu_inference/layers/jax/sample/sampling.py +15 -1
  176. tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
  177. tpu_inference/layers/jax/transformer_block.py +14 -0
  178. tpu_inference/layers/vllm/__init__.py +13 -0
  179. tpu_inference/layers/vllm/attention.py +4 -4
  180. tpu_inference/layers/vllm/fused_moe.py +210 -260
  181. tpu_inference/layers/vllm/linear_common.py +57 -22
  182. tpu_inference/layers/vllm/quantization/__init__.py +16 -0
  183. tpu_inference/layers/vllm/quantization/awq.py +15 -1
  184. tpu_inference/layers/vllm/quantization/common.py +33 -18
  185. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  186. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +18 -3
  187. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +211 -148
  188. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  189. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +14 -0
  190. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +14 -0
  191. tpu_inference/layers/vllm/quantization/fp8.py +118 -0
  192. tpu_inference/layers/vllm/quantization/mxfp4.py +278 -209
  193. tpu_inference/layers/vllm/quantization/unquantized.py +134 -86
  194. tpu_inference/layers/vllm/sharding.py +21 -4
  195. tpu_inference/lora/__init__.py +13 -0
  196. tpu_inference/lora/torch_lora_ops.py +8 -13
  197. tpu_inference/models/__init__.py +13 -0
  198. tpu_inference/models/common/__init__.py +13 -0
  199. tpu_inference/models/common/model_loader.py +74 -35
  200. tpu_inference/models/jax/__init__.py +13 -0
  201. tpu_inference/models/jax/deepseek_v3.py +267 -157
  202. tpu_inference/models/jax/gpt_oss.py +26 -10
  203. tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
  204. tpu_inference/models/jax/llama3.py +99 -36
  205. tpu_inference/models/jax/llama4.py +14 -0
  206. tpu_inference/models/jax/llama_eagle3.py +14 -0
  207. tpu_inference/models/jax/llama_guard_4.py +15 -1
  208. tpu_inference/models/jax/qwen2.py +17 -2
  209. tpu_inference/models/jax/qwen2_5_vl.py +18 -4
  210. tpu_inference/models/jax/qwen3.py +17 -2
  211. tpu_inference/models/jax/utils/__init__.py +13 -0
  212. tpu_inference/models/jax/utils/file_utils.py +14 -0
  213. tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
  214. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  215. tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +89 -26
  216. tpu_inference/models/jax/utils/weight_utils.py +39 -2
  217. tpu_inference/models/vllm/__init__.py +13 -0
  218. tpu_inference/models/vllm/vllm_model_wrapper.py +20 -3
  219. tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
  220. tpu_inference/platforms/__init__.py +14 -0
  221. tpu_inference/platforms/tpu_platform.py +47 -64
  222. tpu_inference/runner/__init__.py +13 -0
  223. tpu_inference/runner/compilation_manager.py +72 -37
  224. tpu_inference/runner/kv_cache.py +54 -20
  225. tpu_inference/runner/kv_cache_manager.py +46 -17
  226. tpu_inference/runner/lora_utils.py +14 -0
  227. tpu_inference/runner/multimodal_manager.py +15 -1
  228. tpu_inference/runner/persistent_batch_manager.py +14 -0
  229. tpu_inference/runner/speculative_decoding_manager.py +14 -0
  230. tpu_inference/runner/structured_decoding_manager.py +14 -0
  231. tpu_inference/runner/tpu_runner.py +44 -17
  232. tpu_inference/spec_decode/__init__.py +13 -0
  233. tpu_inference/spec_decode/jax/__init__.py +13 -0
  234. tpu_inference/spec_decode/jax/eagle3.py +13 -0
  235. tpu_inference/tpu_info.py +14 -0
  236. tpu_inference/utils.py +42 -36
  237. tpu_inference/worker/__init__.py +13 -0
  238. tpu_inference/worker/tpu_worker.py +63 -50
  239. {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.2rc3.dist-info}/METADATA +7 -9
  240. tpu_inference-0.13.2rc3.dist-info/RECORD +261 -0
  241. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  242. tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
  243. tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
  244. tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
  245. tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
  246. tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
  247. tpu_inference-0.11.1.dev202512030818.dist-info/RECORD +0 -174
  248. {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.2rc3.dist-info}/WHEEL +0 -0
  249. {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.2rc3.dist-info}/licenses/LICENSE +0 -0
  250. {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.2rc3.dist-info}/top_level.txt +0 -0
@@ -1,3 +1,16 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
1
14
  """
2
15
  A variant of TPU-Friendly Ragged Paged Attention kernel optimized for
3
16
  head_dim = 64.
@@ -267,7 +280,6 @@ def _ragged_paged_attention_kernel(
267
280
  *,
268
281
  sm_scale: float,
269
282
  sliding_window: int | None = None,
270
- strict_sliding_window: bool = True,
271
283
  soft_cap: float | None = None,
272
284
  mask_value: float = DEFAULT_MASK_VALUE,
273
285
  q_scale: float | None = None,
@@ -324,14 +336,15 @@ def _ragged_paged_attention_kernel(
324
336
  bkv_idx_start = jnp.maximum(kv_len - q_len - sliding_window,
325
337
  0) // bkv_sz
326
338
 
327
- def get_next_bkv_idx_start():
328
- next_kv_len = kv_lens_ref[seq_idx + 1]
329
- next_q_len = cu_q_lens_ref[seq_idx + 2] - q_end
330
- return jnp.maximum(next_kv_len - next_q_len - sliding_window,
331
- 0) // bkv_sz
332
-
333
- next_seq_bkv_idx_start = lax.cond(seq_idx + 1 < num_seqs,
334
- get_next_bkv_idx_start, lambda: 0)
339
+ # If seq_idx + 1 == num_seqs, kv_lens_ref[seq_idx + 1] will trigger a
340
+ # out-of-bound error. To avoid this, we set upperbound of next_seq_idx
341
+ # to be num_seqs - 1.
342
+ next_seq_idx = jnp.minimum(seq_idx + 1, num_seqs - 1)
343
+ next_kv_len = kv_lens_ref[next_seq_idx]
344
+ next_q_len = cu_q_lens_ref[next_seq_idx + 1] - q_end
345
+ next_seq_bkv_idx_start = (
346
+ jnp.maximum(next_kv_len - next_q_len - sliding_window, 0) //
347
+ bkv_sz)
335
348
 
336
349
  def debug_print(msg, *args):
337
350
  if debug_mode:
@@ -396,7 +409,7 @@ def _ragged_paged_attention_kernel(
396
409
  k_span = bkv_idx * bkv_sz + lax.broadcasted_iota(jnp.int32, s.shape, 1)
397
410
  mask = k_span <= q_span
398
411
 
399
- if sliding_window is not None and strict_sliding_window:
412
+ if sliding_window is not None:
400
413
  mask = jnp.logical_and(mask, q_span - sliding_window < k_span)
401
414
 
402
415
  s = jnp.where(mask, s, mask_value)
@@ -520,19 +533,16 @@ def _ragged_paged_attention_kernel(
520
533
  unroll=False,
521
534
  )
522
535
 
523
- # Fetch kv directly from new kv.
524
- @pl.when(bkv_sz_frm_new > 0)
525
- def _fetch_bkv_from_new_kv():
526
- new_kv_len_start = q_end - kv_left_frm_new
527
- debug_print("[RPA debug] new_kv_len_start={}",
528
- new_kv_len_start)
529
- debug_print("[RPA debug] offset_in_bkv={}", offset)
530
- _async_copy(
531
- kv_hbm_ref.at[pl.ds(new_kv_len_start, bkv_sz_frm_new)],
532
- vmem_ref.at[pl.ds(offset, bkv_sz_frm_new)],
533
- sem,
534
- wait,
535
- )
536
+ size = lax.select(bkv_sz_frm_new > 0, bkv_sz_frm_new, 0)
537
+ new_kv_len_start = q_end - kv_left_frm_new
538
+ debug_print("[RPA debug] new_kv_len_start={}", new_kv_len_start)
539
+ debug_print("[RPA debug] offset_in_bkv={}", offset)
540
+ _async_copy(
541
+ kv_hbm_ref.at[pl.ds(new_kv_len_start, size)],
542
+ vmem_ref.at[pl.ds(offset, size)],
543
+ sem,
544
+ wait,
545
+ )
536
546
 
537
547
  return kv_len_start + offset, bkv_sz_frm_new
538
548
  else:
@@ -726,7 +736,7 @@ def _ragged_paged_attention_kernel(
726
736
  vec = ref[start::step]
727
737
  return vec
728
738
 
729
- def strided_load_bkv(bkv_sem_idx, start, step, *, bkv_mask):
739
+ def strided_load_bkv(bkv_sem_idx, start, step):
730
740
  assert start % kv_packing == 0
731
741
  assert step % kv_packing == 0
732
742
  start //= kv_packing
@@ -735,7 +745,6 @@ def _ragged_paged_attention_kernel(
735
745
  bkv_sz * step, actual_head_dim_x2))
736
746
 
737
747
  kv = strided_load(kv_ref, start, step)
738
- kv = lax.select(bkv_mask, kv, jnp.zeros_like(kv))
739
748
  bitwidth = 32 // kv_packing
740
749
  repack_ty = jnp.dtype(f"uint{bitwidth}")
741
750
  lst = []
@@ -783,14 +792,18 @@ def _ragged_paged_attention_kernel(
783
792
  next_bkv_sem_idx = lax.select(bkv_sem_idx == 0, 1, 0)
784
793
 
785
794
  if sliding_window is None:
786
- next_bkv_start_idx = 0
795
+ # When sliding window is disabled, starting bkv_idx of next request is
796
+ # always 0 regardless of seq_idx of next request.
797
+ next_bkv_idx_start = 0
787
798
  else:
788
- next_bkv_start_idx = lax.select(
799
+ # Determine starting bkv_idx of next request based on whether next
800
+ # request is from the same sequence or next sequence.
801
+ next_bkv_idx_start = lax.select(
789
802
  is_last_bq,
790
803
  next_seq_bkv_idx_start,
791
804
  bkv_idx_start,
792
805
  )
793
- next_bkv_idx = lax.select(is_last_bkv, next_bkv_start_idx,
806
+ next_bkv_idx = lax.select(is_last_bkv, next_bkv_idx_start,
794
807
  next_bkv_idx)
795
808
 
796
809
  return next_seq_idx, next_bq_idx, next_bkv_idx, next_bkv_sem_idx
@@ -809,10 +822,6 @@ def _ragged_paged_attention_kernel(
809
822
  def compute_with_bkv(bkv_idx, _):
810
823
  # Create bitmask for KV.
811
824
  assert bkv_sz % kv_packing == 0
812
- actual_bkv_sz = jnp.minimum(bkv_sz, kv_len - bkv_idx * bkv_sz)
813
- bkv_shape = (bkv_sz, actual_head_dim_x2)
814
- bkv_mask = lax.broadcasted_iota(jnp.int32, bkv_shape,
815
- 0) < actual_bkv_sz
816
825
 
817
826
  # Get next bkv ids.
818
827
  bkv_sem_idx = sem_ids_ref[1]
@@ -862,7 +871,6 @@ def _ragged_paged_attention_kernel(
862
871
  bkv_sem_idx,
863
872
  kv_head_start,
864
873
  num_kv_heads,
865
- bkv_mask=bkv_mask,
866
874
  )
867
875
  assert len(bkv_lst) == kv_packing
868
876
  for i in range(kv_packing):
@@ -946,7 +954,17 @@ def _ragged_paged_attention_kernel(
946
954
  @pl.when(seq_idx == 0)
947
955
  def prologue():
948
956
  start_fetch_bq(0, 0, 0)
957
+
958
+ # Initialize bkv_x2_ref to zeros to avoid NaN issues from accessing
959
+ # uninitialized memory. Bitcast into int32 to avoid tiling issues.
960
+ bkv_x2_int32_ref = bkv_x2_ref.bitcast(jnp.int32).reshape(
961
+ (2, -1, 8, 128))
962
+ zeros = jnp.zeros(bkv_x2_int32_ref.shape[1:], jnp.int32)
963
+
964
+ # To pipeline VST and DMA, we divide the initialization into two steps.
965
+ bkv_x2_int32_ref[0] = zeros
949
966
  start_fetch_bkv(0, bkv_idx_start, 0)
967
+ bkv_x2_int32_ref[1] = zeros
950
968
 
951
969
  @pl.when(seq_idx < decode_end)
952
970
  def process_decode():
@@ -1306,12 +1324,15 @@ def static_validate_inputs(
1306
1324
  del debug_mode
1307
1325
 
1308
1326
 
1327
+ def get_kernel_scope_name(bq_size, bkv_p, page_size):
1328
+ return f"RPA-HD_64-bq_{bq_size}-bkvp_{bkv_p}-p_{page_size}-"
1329
+
1330
+
1309
1331
  @functools.partial(
1310
1332
  jax.jit,
1311
1333
  static_argnames=(
1312
1334
  "sm_scale",
1313
1335
  "sliding_window",
1314
- "strict_sliding_window",
1315
1336
  "soft_cap",
1316
1337
  "mask_value",
1317
1338
  "q_scale",
@@ -1341,7 +1362,6 @@ def ragged_paged_attention_hd64(
1341
1362
  *,
1342
1363
  sm_scale: float = 1.0,
1343
1364
  sliding_window: int | None = None,
1344
- strict_sliding_window: bool = True,
1345
1365
  soft_cap: float | None = None,
1346
1366
  mask_value: float | None = DEFAULT_MASK_VALUE,
1347
1367
  q_scale: float | None = None,
@@ -1373,7 +1393,6 @@ def ragged_paged_attention_hd64(
1373
1393
  attention_sink: optional attention sink for each q head.
1374
1394
  sm_scale: the softmax scale which will be applied to the Q@K^T.
1375
1395
  sliding_window: the sliding window size for the attention.
1376
- strict_sliding_window: compute tokens that are strictly within the window.
1377
1396
  soft_cap: the logit soft cap for the attention.
1378
1397
  mask_value: mask value for causal mask.
1379
1398
  q_scale: the scale for the query.
@@ -1447,6 +1466,7 @@ def ragged_paged_attention_hd64(
1447
1466
  page_size,
1448
1467
  max_num_tokens,
1449
1468
  pages_per_seq,
1469
+ sliding_window,
1450
1470
  )
1451
1471
  bkv_sz = bkv_p * page_size
1452
1472
  if vmem_limit_bytes is None:
@@ -1517,48 +1537,45 @@ def ragged_paged_attention_hd64(
1517
1537
  jnp.full((6, ), -1, jnp.int32),
1518
1538
  )
1519
1539
 
1520
- scope_name = f"RPA-HD_64-bq_{bq_sz}-bkvp_{bkv_p}-p_{page_size}"
1521
- kernel = jax.named_scope(scope_name)(
1522
- pl.pallas_call(
1523
- functools.partial(
1524
- _ragged_paged_attention_kernel,
1525
- sm_scale=sm_scale,
1526
- sliding_window=sliding_window,
1527
- strict_sliding_window=strict_sliding_window,
1528
- soft_cap=soft_cap,
1529
- mask_value=mask_value,
1530
- q_scale=q_scale,
1531
- k_scale=k_scale,
1532
- v_scale=v_scale,
1533
- chunk_prefill_size=chunk_prefill_size,
1534
- bq_sz=bq_sz,
1535
- bkv_p=bkv_p,
1536
- debug_mode=debug_mode,
1537
- ),
1538
- grid_spec=pltpu.PrefetchScalarGridSpec(
1539
- num_scalar_prefetch=len(scalar_prefetches),
1540
- in_specs=in_specs,
1541
- out_specs=out_specs,
1542
- grid=grid,
1543
- scratch_shapes=scratch_shapes,
1544
- ),
1545
- compiler_params=pltpu.CompilerParams(
1546
- # TODO(jevinjiang): since each sequence depends on the previous
1547
- # one, we need some extra work to support Megacore mode.
1548
- dimension_semantics=("arbitrary", ),
1549
- vmem_limit_bytes=vmem_limit_bytes,
1550
- ),
1551
- out_shape=[
1552
- jax.ShapeDtypeStruct(shape=q.shape, dtype=q.dtype),
1553
- jax.ShapeDtypeStruct(shape=kv_cache.shape,
1554
- dtype=kv_cache.dtype),
1555
- ],
1556
- input_output_aliases={
1557
- 7: 0,
1558
- 9: 1
1559
- },
1560
- name=scope_name,
1561
- ))
1540
+ scope_name = get_kernel_scope_name(bq_sz, bkv_p, page_size)
1541
+ kernel = pl.pallas_call(
1542
+ functools.partial(
1543
+ _ragged_paged_attention_kernel,
1544
+ sm_scale=sm_scale,
1545
+ sliding_window=sliding_window,
1546
+ soft_cap=soft_cap,
1547
+ mask_value=mask_value,
1548
+ q_scale=q_scale,
1549
+ k_scale=k_scale,
1550
+ v_scale=v_scale,
1551
+ chunk_prefill_size=chunk_prefill_size,
1552
+ bq_sz=bq_sz,
1553
+ bkv_p=bkv_p,
1554
+ debug_mode=debug_mode,
1555
+ ),
1556
+ grid_spec=pltpu.PrefetchScalarGridSpec(
1557
+ num_scalar_prefetch=len(scalar_prefetches),
1558
+ in_specs=in_specs,
1559
+ out_specs=out_specs,
1560
+ grid=grid,
1561
+ scratch_shapes=scratch_shapes,
1562
+ ),
1563
+ compiler_params=pltpu.CompilerParams(
1564
+ # TODO(jevinjiang): since each sequence depends on the previous
1565
+ # one, we need some extra work to support Megacore mode.
1566
+ dimension_semantics=("arbitrary", ),
1567
+ vmem_limit_bytes=vmem_limit_bytes,
1568
+ ),
1569
+ out_shape=[
1570
+ jax.ShapeDtypeStruct(shape=q.shape, dtype=q.dtype),
1571
+ jax.ShapeDtypeStruct(shape=kv_cache.shape, dtype=kv_cache.dtype),
1572
+ ],
1573
+ input_output_aliases={
1574
+ 7: 0,
1575
+ 9: 1
1576
+ },
1577
+ name=scope_name,
1578
+ )
1562
1579
 
1563
1580
  output, updated_kv_cache = kernel(*scalar_prefetches, q, kv, kv_cache,
1564
1581
  attention_sink)