tpu-inference 0.12.0.dev20251213__py3-none-any.whl → 0.13.2.dev20251230__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of tpu-inference might be problematic; see the registry's release page for more details.

Files changed (248):
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_disagg_utils.py +14 -0
  4. tests/core/test_dp_scheduler.py +650 -768
  5. tests/core/test_init.py +14 -0
  6. tests/distributed/__init__.py +13 -0
  7. tests/distributed/test_distributed_utils.py +120 -0
  8. tests/distributed/test_tpu_connector.py +478 -0
  9. tests/e2e/__init__.py +13 -0
  10. tests/e2e/test_async_scheduler.py +211 -0
  11. tests/e2e/test_data_parallel.py +289 -0
  12. tests/e2e/test_hybrid_kvcache.py +219 -0
  13. tests/e2e/test_local_disagg.py +257 -0
  14. tests/e2e/test_model_loader.py +268 -0
  15. tests/e2e/test_multi_modal_inference.py +111 -0
  16. tests/e2e/test_pipeline_parallel.py +265 -0
  17. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  18. tests/e2e/test_sampling_params.py +269 -0
  19. tests/e2e/test_speculative_decoding.py +311 -0
  20. tests/e2e/test_structured_decoding.py +46 -0
  21. tests/executors/__init__.py +13 -0
  22. tests/executors/test_ray_distributed_executor.py +199 -0
  23. tests/experimental/__init__.py +13 -0
  24. tests/experimental/test_llama3_jax_stashed.py +208 -0
  25. tests/kernels/__init__.py +13 -0
  26. tests/kernels/collectives/__init__.py +13 -0
  27. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  28. tests/kernels/fused_moe_v1_test.py +14 -0
  29. tests/kernels/gmm_test.py +205 -0
  30. tests/kernels/mla_v1_test.py +14 -0
  31. tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
  32. tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
  33. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +14 -0
  34. tests/kernels/ragged_paged_attention_kernel_v3_test.py +14 -0
  35. tests/layers/__init__.py +13 -0
  36. tests/layers/common/__init__.py +13 -0
  37. tests/layers/common/test_attention_interface.py +156 -0
  38. tests/layers/common/test_quantization.py +149 -0
  39. tests/layers/jax/__init__.py +13 -0
  40. tests/layers/jax/attention/__init__.py +13 -0
  41. tests/layers/jax/attention/test_common_attention.py +103 -0
  42. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  43. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  44. tests/layers/jax/moe/__init__.py +13 -0
  45. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  46. tests/layers/jax/sample/__init__.py +13 -0
  47. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  48. tests/layers/jax/sample/test_sampling.py +115 -0
  49. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  50. tests/layers/jax/test_layers.py +155 -0
  51. tests/{test_quantization.py → layers/jax/test_qwix.py} +180 -50
  52. tests/layers/jax/test_rope.py +93 -0
  53. tests/layers/jax/test_sharding.py +159 -0
  54. tests/layers/jax/test_transformer_block.py +152 -0
  55. tests/layers/vllm/__init__.py +13 -0
  56. tests/layers/vllm/test_attention.py +363 -0
  57. tests/layers/vllm/test_awq.py +406 -0
  58. tests/layers/vllm/test_compressed_tensors_moe.py +199 -0
  59. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +441 -0
  60. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +443 -0
  61. tests/layers/vllm/test_fp8.py +17 -0
  62. tests/layers/vllm/test_mxfp4.py +320 -0
  63. tests/layers/vllm/test_unquantized.py +662 -0
  64. tests/layers/vllm/utils.py +87 -0
  65. tests/lora/__init__.py +13 -0
  66. tests/lora/conftest.py +14 -0
  67. tests/lora/test_bgmv.py +14 -0
  68. tests/lora/test_layers.py +25 -8
  69. tests/lora/test_lora.py +15 -1
  70. tests/lora/test_lora_perf.py +14 -0
  71. tests/models/__init__.py +13 -0
  72. tests/models/common/__init__.py +13 -0
  73. tests/models/common/test_model_loader.py +455 -0
  74. tests/models/jax/__init__.py +13 -0
  75. tests/models/jax/test_deepseek_v3.py +401 -0
  76. tests/models/jax/test_llama3.py +184 -0
  77. tests/models/jax/test_llama4.py +298 -0
  78. tests/models/jax/test_llama_eagle3.py +197 -0
  79. tests/models/jax/test_llama_guard_4.py +242 -0
  80. tests/models/jax/test_qwen2.py +172 -0
  81. tests/models/jax/test_qwen2_5_vl.py +605 -0
  82. tests/models/jax/test_qwen3.py +169 -0
  83. tests/models/jax/test_weight_loading.py +180 -0
  84. tests/models/jax/utils/__init__.py +13 -0
  85. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  86. tests/platforms/__init__.py +13 -0
  87. tests/platforms/test_tpu_platform.py +54 -0
  88. tests/runner/__init__.py +13 -0
  89. tests/runner/test_block_table.py +395 -0
  90. tests/runner/test_input_batch.py +226 -0
  91. tests/runner/test_kv_cache.py +220 -0
  92. tests/runner/test_kv_cache_manager.py +498 -0
  93. tests/runner/test_multimodal_manager.py +429 -0
  94. tests/runner/test_persistent_batch_manager.py +84 -0
  95. tests/runner/test_speculative_decoding_manager.py +368 -0
  96. tests/runner/test_structured_decoding_manager.py +220 -0
  97. tests/runner/test_tpu_runner.py +261 -0
  98. tests/runner/test_tpu_runner_dp.py +1099 -0
  99. tests/runner/test_tpu_runner_mesh.py +200 -0
  100. tests/runner/test_utils.py +411 -0
  101. tests/spec_decode/__init__.py +13 -0
  102. tests/spec_decode/test_eagle3.py +311 -0
  103. tests/test_base.py +14 -0
  104. tests/test_tpu_info.py +14 -0
  105. tests/test_utils.py +1 -43
  106. tests/worker/__init__.py +13 -0
  107. tests/worker/tpu_worker_test.py +414 -0
  108. tpu_inference/__init__.py +14 -0
  109. tpu_inference/core/__init__.py +13 -0
  110. tpu_inference/core/sched/__init__.py +13 -0
  111. tpu_inference/core/sched/dp_scheduler.py +372 -56
  112. tpu_inference/distributed/__init__.py +13 -0
  113. tpu_inference/distributed/jax_parallel_state.py +14 -0
  114. tpu_inference/distributed/tpu_connector.py +14 -9
  115. tpu_inference/distributed/utils.py +56 -4
  116. tpu_inference/executors/__init__.py +13 -0
  117. tpu_inference/executors/ray_distributed_executor.py +20 -3
  118. tpu_inference/experimental/__init__.py +13 -0
  119. tpu_inference/experimental/llama3_jax_stashed.py +14 -0
  120. tpu_inference/kernels/__init__.py +13 -0
  121. tpu_inference/kernels/collectives/__init__.py +13 -0
  122. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  123. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  124. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  125. tpu_inference/kernels/fused_moe/v1/kernel.py +171 -163
  126. tpu_inference/kernels/megablox/__init__.py +13 -0
  127. tpu_inference/kernels/megablox/common.py +54 -0
  128. tpu_inference/kernels/megablox/gmm.py +646 -0
  129. tpu_inference/kernels/mla/__init__.py +13 -0
  130. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  131. tpu_inference/kernels/mla/v1/kernel.py +20 -26
  132. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  133. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  134. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  135. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  136. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +112 -69
  137. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +85 -65
  138. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
  139. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +374 -194
  140. tpu_inference/kernels/ragged_paged_attention/v3/util.py +13 -0
  141. tpu_inference/layers/__init__.py +13 -0
  142. tpu_inference/layers/common/__init__.py +13 -0
  143. tpu_inference/layers/common/attention_interface.py +26 -19
  144. tpu_inference/layers/common/attention_metadata.py +14 -0
  145. tpu_inference/layers/common/fused_moe_gmm.py +506 -0
  146. tpu_inference/layers/common/quant_methods.py +15 -0
  147. tpu_inference/layers/common/quantization.py +282 -0
  148. tpu_inference/layers/common/sharding.py +22 -3
  149. tpu_inference/layers/common/utils.py +94 -0
  150. tpu_inference/layers/jax/__init__.py +13 -0
  151. tpu_inference/layers/jax/attention/__init__.py +13 -0
  152. tpu_inference/layers/jax/attention/attention.py +19 -6
  153. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +52 -27
  154. tpu_inference/layers/jax/attention/gpt_oss_attention.py +19 -6
  155. tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
  156. tpu_inference/layers/jax/base.py +14 -0
  157. tpu_inference/layers/jax/constants.py +13 -0
  158. tpu_inference/layers/jax/layers.py +14 -0
  159. tpu_inference/layers/jax/misc.py +14 -0
  160. tpu_inference/layers/jax/moe/__init__.py +13 -0
  161. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
  162. tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
  163. tpu_inference/layers/jax/moe/moe.py +43 -3
  164. tpu_inference/layers/jax/pp_utils.py +53 -0
  165. tpu_inference/layers/jax/rope.py +14 -0
  166. tpu_inference/layers/jax/rope_interface.py +14 -0
  167. tpu_inference/layers/jax/sample/__init__.py +13 -0
  168. tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
  169. tpu_inference/layers/jax/sample/sampling.py +15 -1
  170. tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
  171. tpu_inference/layers/jax/transformer_block.py +14 -0
  172. tpu_inference/layers/vllm/__init__.py +13 -0
  173. tpu_inference/layers/vllm/attention.py +4 -4
  174. tpu_inference/layers/vllm/fused_moe.py +100 -455
  175. tpu_inference/layers/vllm/linear.py +64 -0
  176. tpu_inference/layers/vllm/process_weights/__init__.py +13 -0
  177. tpu_inference/layers/vllm/{sharding.py → process_weights/cleanup_sharding.py} +24 -15
  178. tpu_inference/layers/vllm/process_weights/fused_moe_weights.py +369 -0
  179. tpu_inference/layers/vllm/process_weights/linear_weights.py +174 -0
  180. tpu_inference/layers/vllm/quantization/__init__.py +19 -3
  181. tpu_inference/layers/vllm/quantization/awq.py +96 -82
  182. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  183. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +19 -5
  184. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +119 -132
  185. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  186. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +111 -91
  187. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +79 -43
  188. tpu_inference/layers/vllm/quantization/{common.py → configs.py} +38 -26
  189. tpu_inference/layers/vllm/quantization/fp8.py +119 -0
  190. tpu_inference/layers/vllm/quantization/mxfp4.py +133 -220
  191. tpu_inference/layers/vllm/quantization/unquantized.py +154 -253
  192. tpu_inference/lora/__init__.py +13 -0
  193. tpu_inference/lora/torch_lora_ops.py +8 -13
  194. tpu_inference/models/__init__.py +13 -0
  195. tpu_inference/models/common/__init__.py +13 -0
  196. tpu_inference/models/common/model_loader.py +37 -16
  197. tpu_inference/models/jax/__init__.py +13 -0
  198. tpu_inference/models/jax/deepseek_v3.py +113 -124
  199. tpu_inference/models/jax/gpt_oss.py +23 -7
  200. tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
  201. tpu_inference/models/jax/llama3.py +99 -36
  202. tpu_inference/models/jax/llama4.py +14 -0
  203. tpu_inference/models/jax/llama_eagle3.py +14 -0
  204. tpu_inference/models/jax/llama_guard_4.py +15 -1
  205. tpu_inference/models/jax/qwen2.py +17 -2
  206. tpu_inference/models/jax/qwen2_5_vl.py +18 -4
  207. tpu_inference/models/jax/qwen3.py +17 -2
  208. tpu_inference/models/jax/utils/__init__.py +13 -0
  209. tpu_inference/models/jax/utils/file_utils.py +14 -0
  210. tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
  211. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  212. tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +85 -24
  213. tpu_inference/models/jax/utils/weight_utils.py +32 -1
  214. tpu_inference/models/vllm/__init__.py +13 -0
  215. tpu_inference/models/vllm/vllm_model_wrapper.py +22 -4
  216. tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
  217. tpu_inference/platforms/__init__.py +14 -0
  218. tpu_inference/platforms/tpu_platform.py +27 -29
  219. tpu_inference/runner/__init__.py +13 -0
  220. tpu_inference/runner/compilation_manager.py +69 -35
  221. tpu_inference/runner/kv_cache.py +14 -0
  222. tpu_inference/runner/kv_cache_manager.py +15 -2
  223. tpu_inference/runner/lora_utils.py +16 -1
  224. tpu_inference/runner/multimodal_manager.py +16 -2
  225. tpu_inference/runner/persistent_batch_manager.py +14 -0
  226. tpu_inference/runner/speculative_decoding_manager.py +14 -0
  227. tpu_inference/runner/structured_decoding_manager.py +14 -0
  228. tpu_inference/runner/tpu_runner.py +30 -10
  229. tpu_inference/spec_decode/__init__.py +13 -0
  230. tpu_inference/spec_decode/jax/__init__.py +13 -0
  231. tpu_inference/spec_decode/jax/eagle3.py +13 -0
  232. tpu_inference/tpu_info.py +14 -0
  233. tpu_inference/utils.py +31 -30
  234. tpu_inference/worker/__init__.py +13 -0
  235. tpu_inference/worker/tpu_worker.py +23 -7
  236. {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/METADATA +1 -1
  237. tpu_inference-0.13.2.dev20251230.dist-info/RECORD +266 -0
  238. tpu_inference/layers/vllm/linear_common.py +0 -208
  239. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  240. tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
  241. tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
  242. tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
  243. tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
  244. tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
  245. tpu_inference-0.12.0.dev20251213.dist-info/RECORD +0 -175
  246. {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/WHEEL +0 -0
  247. {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/licenses/LICENSE +0 -0
  248. {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/top_level.txt +0 -0
@@ -1,3 +1,16 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
1
14
  """
2
15
  A variant of TPU-Friendly Ragged Paged Attention kernel optimized for
3
16
  head_dim = 64.
@@ -267,7 +280,6 @@ def _ragged_paged_attention_kernel(
267
280
  *,
268
281
  sm_scale: float,
269
282
  sliding_window: int | None = None,
270
- strict_sliding_window: bool = True,
271
283
  soft_cap: float | None = None,
272
284
  mask_value: float = DEFAULT_MASK_VALUE,
273
285
  q_scale: float | None = None,
@@ -324,14 +336,15 @@ def _ragged_paged_attention_kernel(
324
336
  bkv_idx_start = jnp.maximum(kv_len - q_len - sliding_window,
325
337
  0) // bkv_sz
326
338
 
327
- def get_next_bkv_idx_start():
328
- next_kv_len = kv_lens_ref[seq_idx + 1]
329
- next_q_len = cu_q_lens_ref[seq_idx + 2] - q_end
330
- return jnp.maximum(next_kv_len - next_q_len - sliding_window,
331
- 0) // bkv_sz
332
-
333
- next_seq_bkv_idx_start = lax.cond(seq_idx + 1 < num_seqs,
334
- get_next_bkv_idx_start, lambda: 0)
339
+ # If seq_idx + 1 == num_seqs, kv_lens_ref[seq_idx + 1] will trigger a
340
+ # out-of-bound error. To avoid this, we set upperbound of next_seq_idx
341
+ # to be num_seqs - 1.
342
+ next_seq_idx = jnp.minimum(seq_idx + 1, num_seqs - 1)
343
+ next_kv_len = kv_lens_ref[next_seq_idx]
344
+ next_q_len = cu_q_lens_ref[next_seq_idx + 1] - q_end
345
+ next_seq_bkv_idx_start = (
346
+ jnp.maximum(next_kv_len - next_q_len - sliding_window, 0) //
347
+ bkv_sz)
335
348
 
336
349
  def debug_print(msg, *args):
337
350
  if debug_mode:
@@ -396,7 +409,7 @@ def _ragged_paged_attention_kernel(
396
409
  k_span = bkv_idx * bkv_sz + lax.broadcasted_iota(jnp.int32, s.shape, 1)
397
410
  mask = k_span <= q_span
398
411
 
399
- if sliding_window is not None and strict_sliding_window:
412
+ if sliding_window is not None:
400
413
  mask = jnp.logical_and(mask, q_span - sliding_window < k_span)
401
414
 
402
415
  s = jnp.where(mask, s, mask_value)
@@ -723,7 +736,7 @@ def _ragged_paged_attention_kernel(
723
736
  vec = ref[start::step]
724
737
  return vec
725
738
 
726
- def strided_load_bkv(bkv_sem_idx, start, step, *, bkv_mask):
739
+ def strided_load_bkv(bkv_sem_idx, start, step):
727
740
  assert start % kv_packing == 0
728
741
  assert step % kv_packing == 0
729
742
  start //= kv_packing
@@ -732,7 +745,6 @@ def _ragged_paged_attention_kernel(
732
745
  bkv_sz * step, actual_head_dim_x2))
733
746
 
734
747
  kv = strided_load(kv_ref, start, step)
735
- kv = lax.select(bkv_mask, kv, jnp.zeros_like(kv))
736
748
  bitwidth = 32 // kv_packing
737
749
  repack_ty = jnp.dtype(f"uint{bitwidth}")
738
750
  lst = []
@@ -780,14 +792,18 @@ def _ragged_paged_attention_kernel(
780
792
  next_bkv_sem_idx = lax.select(bkv_sem_idx == 0, 1, 0)
781
793
 
782
794
  if sliding_window is None:
783
- next_bkv_start_idx = 0
795
+ # When sliding window is disabled, starting bkv_idx of next request is
796
+ # always 0 regardless of seq_idx of next request.
797
+ next_bkv_idx_start = 0
784
798
  else:
785
- next_bkv_start_idx = lax.select(
799
+ # Determine starting bkv_idx of next request based on whether next
800
+ # request is from the same sequence or next sequence.
801
+ next_bkv_idx_start = lax.select(
786
802
  is_last_bq,
787
803
  next_seq_bkv_idx_start,
788
804
  bkv_idx_start,
789
805
  )
790
- next_bkv_idx = lax.select(is_last_bkv, next_bkv_start_idx,
806
+ next_bkv_idx = lax.select(is_last_bkv, next_bkv_idx_start,
791
807
  next_bkv_idx)
792
808
 
793
809
  return next_seq_idx, next_bq_idx, next_bkv_idx, next_bkv_sem_idx
@@ -806,10 +822,6 @@ def _ragged_paged_attention_kernel(
806
822
  def compute_with_bkv(bkv_idx, _):
807
823
  # Create bitmask for KV.
808
824
  assert bkv_sz % kv_packing == 0
809
- actual_bkv_sz = jnp.minimum(bkv_sz, kv_len - bkv_idx * bkv_sz)
810
- bkv_shape = (bkv_sz, actual_head_dim_x2)
811
- bkv_mask = lax.broadcasted_iota(jnp.int32, bkv_shape,
812
- 0) < actual_bkv_sz
813
825
 
814
826
  # Get next bkv ids.
815
827
  bkv_sem_idx = sem_ids_ref[1]
@@ -859,7 +871,6 @@ def _ragged_paged_attention_kernel(
859
871
  bkv_sem_idx,
860
872
  kv_head_start,
861
873
  num_kv_heads,
862
- bkv_mask=bkv_mask,
863
874
  )
864
875
  assert len(bkv_lst) == kv_packing
865
876
  for i in range(kv_packing):
@@ -943,7 +954,17 @@ def _ragged_paged_attention_kernel(
943
954
  @pl.when(seq_idx == 0)
944
955
  def prologue():
945
956
  start_fetch_bq(0, 0, 0)
957
+
958
+ # Initialize bkv_x2_ref to zeros to avoid NaN issues from accessing
959
+ # uninitialized memory. Bitcast into int32 to avoid tiling issues.
960
+ bkv_x2_int32_ref = bkv_x2_ref.bitcast(jnp.int32).reshape(
961
+ (2, -1, 8, 128))
962
+ zeros = jnp.zeros(bkv_x2_int32_ref.shape[1:], jnp.int32)
963
+
964
+ # To pipeline VST and DMA, we divide the initialization into two steps.
965
+ bkv_x2_int32_ref[0] = zeros
946
966
  start_fetch_bkv(0, bkv_idx_start, 0)
967
+ bkv_x2_int32_ref[1] = zeros
947
968
 
948
969
  @pl.when(seq_idx < decode_end)
949
970
  def process_decode():
@@ -1303,12 +1324,15 @@ def static_validate_inputs(
1303
1324
  del debug_mode
1304
1325
 
1305
1326
 
1327
+ def get_kernel_scope_name(bq_size, bkv_p, page_size):
1328
+ return f"RPA-HD_64-bq_{bq_size}-bkvp_{bkv_p}-p_{page_size}-"
1329
+
1330
+
1306
1331
  @functools.partial(
1307
1332
  jax.jit,
1308
1333
  static_argnames=(
1309
1334
  "sm_scale",
1310
1335
  "sliding_window",
1311
- "strict_sliding_window",
1312
1336
  "soft_cap",
1313
1337
  "mask_value",
1314
1338
  "q_scale",
@@ -1338,7 +1362,6 @@ def ragged_paged_attention_hd64(
1338
1362
  *,
1339
1363
  sm_scale: float = 1.0,
1340
1364
  sliding_window: int | None = None,
1341
- strict_sliding_window: bool = True,
1342
1365
  soft_cap: float | None = None,
1343
1366
  mask_value: float | None = DEFAULT_MASK_VALUE,
1344
1367
  q_scale: float | None = None,
@@ -1370,7 +1393,6 @@ def ragged_paged_attention_hd64(
1370
1393
  attention_sink: optional attention sink for each q head.
1371
1394
  sm_scale: the softmax scale which will be applied to the Q@K^T.
1372
1395
  sliding_window: the sliding window size for the attention.
1373
- strict_sliding_window: compute tokens that are strictly within the window.
1374
1396
  soft_cap: the logit soft cap for the attention.
1375
1397
  mask_value: mask value for causal mask.
1376
1398
  q_scale: the scale for the query.
@@ -1444,6 +1466,7 @@ def ragged_paged_attention_hd64(
1444
1466
  page_size,
1445
1467
  max_num_tokens,
1446
1468
  pages_per_seq,
1469
+ sliding_window,
1447
1470
  )
1448
1471
  bkv_sz = bkv_p * page_size
1449
1472
  if vmem_limit_bytes is None:
@@ -1514,48 +1537,45 @@ def ragged_paged_attention_hd64(
1514
1537
  jnp.full((6, ), -1, jnp.int32),
1515
1538
  )
1516
1539
 
1517
- scope_name = f"RPA-HD_64-bq_{bq_sz}-bkvp_{bkv_p}-p_{page_size}"
1518
- kernel = jax.named_scope(scope_name)(
1519
- pl.pallas_call(
1520
- functools.partial(
1521
- _ragged_paged_attention_kernel,
1522
- sm_scale=sm_scale,
1523
- sliding_window=sliding_window,
1524
- strict_sliding_window=strict_sliding_window,
1525
- soft_cap=soft_cap,
1526
- mask_value=mask_value,
1527
- q_scale=q_scale,
1528
- k_scale=k_scale,
1529
- v_scale=v_scale,
1530
- chunk_prefill_size=chunk_prefill_size,
1531
- bq_sz=bq_sz,
1532
- bkv_p=bkv_p,
1533
- debug_mode=debug_mode,
1534
- ),
1535
- grid_spec=pltpu.PrefetchScalarGridSpec(
1536
- num_scalar_prefetch=len(scalar_prefetches),
1537
- in_specs=in_specs,
1538
- out_specs=out_specs,
1539
- grid=grid,
1540
- scratch_shapes=scratch_shapes,
1541
- ),
1542
- compiler_params=pltpu.CompilerParams(
1543
- # TODO(jevinjiang): since each sequence depends on the previous
1544
- # one, we need some extra work to support Megacore mode.
1545
- dimension_semantics=("arbitrary", ),
1546
- vmem_limit_bytes=vmem_limit_bytes,
1547
- ),
1548
- out_shape=[
1549
- jax.ShapeDtypeStruct(shape=q.shape, dtype=q.dtype),
1550
- jax.ShapeDtypeStruct(shape=kv_cache.shape,
1551
- dtype=kv_cache.dtype),
1552
- ],
1553
- input_output_aliases={
1554
- 7: 0,
1555
- 9: 1
1556
- },
1557
- name=scope_name,
1558
- ))
1540
+ scope_name = get_kernel_scope_name(bq_sz, bkv_p, page_size)
1541
+ kernel = pl.pallas_call(
1542
+ functools.partial(
1543
+ _ragged_paged_attention_kernel,
1544
+ sm_scale=sm_scale,
1545
+ sliding_window=sliding_window,
1546
+ soft_cap=soft_cap,
1547
+ mask_value=mask_value,
1548
+ q_scale=q_scale,
1549
+ k_scale=k_scale,
1550
+ v_scale=v_scale,
1551
+ chunk_prefill_size=chunk_prefill_size,
1552
+ bq_sz=bq_sz,
1553
+ bkv_p=bkv_p,
1554
+ debug_mode=debug_mode,
1555
+ ),
1556
+ grid_spec=pltpu.PrefetchScalarGridSpec(
1557
+ num_scalar_prefetch=len(scalar_prefetches),
1558
+ in_specs=in_specs,
1559
+ out_specs=out_specs,
1560
+ grid=grid,
1561
+ scratch_shapes=scratch_shapes,
1562
+ ),
1563
+ compiler_params=pltpu.CompilerParams(
1564
+ # TODO(jevinjiang): since each sequence depends on the previous
1565
+ # one, we need some extra work to support Megacore mode.
1566
+ dimension_semantics=("arbitrary", ),
1567
+ vmem_limit_bytes=vmem_limit_bytes,
1568
+ ),
1569
+ out_shape=[
1570
+ jax.ShapeDtypeStruct(shape=q.shape, dtype=q.dtype),
1571
+ jax.ShapeDtypeStruct(shape=kv_cache.shape, dtype=kv_cache.dtype),
1572
+ ],
1573
+ input_output_aliases={
1574
+ 7: 0,
1575
+ 9: 1
1576
+ },
1577
+ name=scope_name,
1578
+ )
1559
1579
 
1560
1580
  output, updated_kv_cache = kernel(*scalar_prefetches, q, kv, kv_cache,
1561
1581
  attention_sink)