tpu-inference 0.12.0.dev20251213__py3-none-any.whl → 0.13.2.dev20251230__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of tpu-inference might be problematic. Click here for more details.

Files changed (248)
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_disagg_utils.py +14 -0
  4. tests/core/test_dp_scheduler.py +650 -768
  5. tests/core/test_init.py +14 -0
  6. tests/distributed/__init__.py +13 -0
  7. tests/distributed/test_distributed_utils.py +120 -0
  8. tests/distributed/test_tpu_connector.py +478 -0
  9. tests/e2e/__init__.py +13 -0
  10. tests/e2e/test_async_scheduler.py +211 -0
  11. tests/e2e/test_data_parallel.py +289 -0
  12. tests/e2e/test_hybrid_kvcache.py +219 -0
  13. tests/e2e/test_local_disagg.py +257 -0
  14. tests/e2e/test_model_loader.py +268 -0
  15. tests/e2e/test_multi_modal_inference.py +111 -0
  16. tests/e2e/test_pipeline_parallel.py +265 -0
  17. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  18. tests/e2e/test_sampling_params.py +269 -0
  19. tests/e2e/test_speculative_decoding.py +311 -0
  20. tests/e2e/test_structured_decoding.py +46 -0
  21. tests/executors/__init__.py +13 -0
  22. tests/executors/test_ray_distributed_executor.py +199 -0
  23. tests/experimental/__init__.py +13 -0
  24. tests/experimental/test_llama3_jax_stashed.py +208 -0
  25. tests/kernels/__init__.py +13 -0
  26. tests/kernels/collectives/__init__.py +13 -0
  27. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  28. tests/kernels/fused_moe_v1_test.py +14 -0
  29. tests/kernels/gmm_test.py +205 -0
  30. tests/kernels/mla_v1_test.py +14 -0
  31. tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
  32. tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
  33. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +14 -0
  34. tests/kernels/ragged_paged_attention_kernel_v3_test.py +14 -0
  35. tests/layers/__init__.py +13 -0
  36. tests/layers/common/__init__.py +13 -0
  37. tests/layers/common/test_attention_interface.py +156 -0
  38. tests/layers/common/test_quantization.py +149 -0
  39. tests/layers/jax/__init__.py +13 -0
  40. tests/layers/jax/attention/__init__.py +13 -0
  41. tests/layers/jax/attention/test_common_attention.py +103 -0
  42. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  43. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  44. tests/layers/jax/moe/__init__.py +13 -0
  45. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  46. tests/layers/jax/sample/__init__.py +13 -0
  47. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  48. tests/layers/jax/sample/test_sampling.py +115 -0
  49. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  50. tests/layers/jax/test_layers.py +155 -0
  51. tests/{test_quantization.py → layers/jax/test_qwix.py} +180 -50
  52. tests/layers/jax/test_rope.py +93 -0
  53. tests/layers/jax/test_sharding.py +159 -0
  54. tests/layers/jax/test_transformer_block.py +152 -0
  55. tests/layers/vllm/__init__.py +13 -0
  56. tests/layers/vllm/test_attention.py +363 -0
  57. tests/layers/vllm/test_awq.py +406 -0
  58. tests/layers/vllm/test_compressed_tensors_moe.py +199 -0
  59. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +441 -0
  60. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +443 -0
  61. tests/layers/vllm/test_fp8.py +17 -0
  62. tests/layers/vllm/test_mxfp4.py +320 -0
  63. tests/layers/vllm/test_unquantized.py +662 -0
  64. tests/layers/vllm/utils.py +87 -0
  65. tests/lora/__init__.py +13 -0
  66. tests/lora/conftest.py +14 -0
  67. tests/lora/test_bgmv.py +14 -0
  68. tests/lora/test_layers.py +25 -8
  69. tests/lora/test_lora.py +15 -1
  70. tests/lora/test_lora_perf.py +14 -0
  71. tests/models/__init__.py +13 -0
  72. tests/models/common/__init__.py +13 -0
  73. tests/models/common/test_model_loader.py +455 -0
  74. tests/models/jax/__init__.py +13 -0
  75. tests/models/jax/test_deepseek_v3.py +401 -0
  76. tests/models/jax/test_llama3.py +184 -0
  77. tests/models/jax/test_llama4.py +298 -0
  78. tests/models/jax/test_llama_eagle3.py +197 -0
  79. tests/models/jax/test_llama_guard_4.py +242 -0
  80. tests/models/jax/test_qwen2.py +172 -0
  81. tests/models/jax/test_qwen2_5_vl.py +605 -0
  82. tests/models/jax/test_qwen3.py +169 -0
  83. tests/models/jax/test_weight_loading.py +180 -0
  84. tests/models/jax/utils/__init__.py +13 -0
  85. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  86. tests/platforms/__init__.py +13 -0
  87. tests/platforms/test_tpu_platform.py +54 -0
  88. tests/runner/__init__.py +13 -0
  89. tests/runner/test_block_table.py +395 -0
  90. tests/runner/test_input_batch.py +226 -0
  91. tests/runner/test_kv_cache.py +220 -0
  92. tests/runner/test_kv_cache_manager.py +498 -0
  93. tests/runner/test_multimodal_manager.py +429 -0
  94. tests/runner/test_persistent_batch_manager.py +84 -0
  95. tests/runner/test_speculative_decoding_manager.py +368 -0
  96. tests/runner/test_structured_decoding_manager.py +220 -0
  97. tests/runner/test_tpu_runner.py +261 -0
  98. tests/runner/test_tpu_runner_dp.py +1099 -0
  99. tests/runner/test_tpu_runner_mesh.py +200 -0
  100. tests/runner/test_utils.py +411 -0
  101. tests/spec_decode/__init__.py +13 -0
  102. tests/spec_decode/test_eagle3.py +311 -0
  103. tests/test_base.py +14 -0
  104. tests/test_tpu_info.py +14 -0
  105. tests/test_utils.py +1 -43
  106. tests/worker/__init__.py +13 -0
  107. tests/worker/tpu_worker_test.py +414 -0
  108. tpu_inference/__init__.py +14 -0
  109. tpu_inference/core/__init__.py +13 -0
  110. tpu_inference/core/sched/__init__.py +13 -0
  111. tpu_inference/core/sched/dp_scheduler.py +372 -56
  112. tpu_inference/distributed/__init__.py +13 -0
  113. tpu_inference/distributed/jax_parallel_state.py +14 -0
  114. tpu_inference/distributed/tpu_connector.py +14 -9
  115. tpu_inference/distributed/utils.py +56 -4
  116. tpu_inference/executors/__init__.py +13 -0
  117. tpu_inference/executors/ray_distributed_executor.py +20 -3
  118. tpu_inference/experimental/__init__.py +13 -0
  119. tpu_inference/experimental/llama3_jax_stashed.py +14 -0
  120. tpu_inference/kernels/__init__.py +13 -0
  121. tpu_inference/kernels/collectives/__init__.py +13 -0
  122. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  123. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  124. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  125. tpu_inference/kernels/fused_moe/v1/kernel.py +171 -163
  126. tpu_inference/kernels/megablox/__init__.py +13 -0
  127. tpu_inference/kernels/megablox/common.py +54 -0
  128. tpu_inference/kernels/megablox/gmm.py +646 -0
  129. tpu_inference/kernels/mla/__init__.py +13 -0
  130. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  131. tpu_inference/kernels/mla/v1/kernel.py +20 -26
  132. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  133. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  134. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  135. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  136. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +112 -69
  137. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +85 -65
  138. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
  139. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +374 -194
  140. tpu_inference/kernels/ragged_paged_attention/v3/util.py +13 -0
  141. tpu_inference/layers/__init__.py +13 -0
  142. tpu_inference/layers/common/__init__.py +13 -0
  143. tpu_inference/layers/common/attention_interface.py +26 -19
  144. tpu_inference/layers/common/attention_metadata.py +14 -0
  145. tpu_inference/layers/common/fused_moe_gmm.py +506 -0
  146. tpu_inference/layers/common/quant_methods.py +15 -0
  147. tpu_inference/layers/common/quantization.py +282 -0
  148. tpu_inference/layers/common/sharding.py +22 -3
  149. tpu_inference/layers/common/utils.py +94 -0
  150. tpu_inference/layers/jax/__init__.py +13 -0
  151. tpu_inference/layers/jax/attention/__init__.py +13 -0
  152. tpu_inference/layers/jax/attention/attention.py +19 -6
  153. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +52 -27
  154. tpu_inference/layers/jax/attention/gpt_oss_attention.py +19 -6
  155. tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
  156. tpu_inference/layers/jax/base.py +14 -0
  157. tpu_inference/layers/jax/constants.py +13 -0
  158. tpu_inference/layers/jax/layers.py +14 -0
  159. tpu_inference/layers/jax/misc.py +14 -0
  160. tpu_inference/layers/jax/moe/__init__.py +13 -0
  161. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
  162. tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
  163. tpu_inference/layers/jax/moe/moe.py +43 -3
  164. tpu_inference/layers/jax/pp_utils.py +53 -0
  165. tpu_inference/layers/jax/rope.py +14 -0
  166. tpu_inference/layers/jax/rope_interface.py +14 -0
  167. tpu_inference/layers/jax/sample/__init__.py +13 -0
  168. tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
  169. tpu_inference/layers/jax/sample/sampling.py +15 -1
  170. tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
  171. tpu_inference/layers/jax/transformer_block.py +14 -0
  172. tpu_inference/layers/vllm/__init__.py +13 -0
  173. tpu_inference/layers/vllm/attention.py +4 -4
  174. tpu_inference/layers/vllm/fused_moe.py +100 -455
  175. tpu_inference/layers/vllm/linear.py +64 -0
  176. tpu_inference/layers/vllm/process_weights/__init__.py +13 -0
  177. tpu_inference/layers/vllm/{sharding.py → process_weights/cleanup_sharding.py} +24 -15
  178. tpu_inference/layers/vllm/process_weights/fused_moe_weights.py +369 -0
  179. tpu_inference/layers/vllm/process_weights/linear_weights.py +174 -0
  180. tpu_inference/layers/vllm/quantization/__init__.py +19 -3
  181. tpu_inference/layers/vllm/quantization/awq.py +96 -82
  182. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  183. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +19 -5
  184. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +119 -132
  185. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  186. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +111 -91
  187. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +79 -43
  188. tpu_inference/layers/vllm/quantization/{common.py → configs.py} +38 -26
  189. tpu_inference/layers/vllm/quantization/fp8.py +119 -0
  190. tpu_inference/layers/vllm/quantization/mxfp4.py +133 -220
  191. tpu_inference/layers/vllm/quantization/unquantized.py +154 -253
  192. tpu_inference/lora/__init__.py +13 -0
  193. tpu_inference/lora/torch_lora_ops.py +8 -13
  194. tpu_inference/models/__init__.py +13 -0
  195. tpu_inference/models/common/__init__.py +13 -0
  196. tpu_inference/models/common/model_loader.py +37 -16
  197. tpu_inference/models/jax/__init__.py +13 -0
  198. tpu_inference/models/jax/deepseek_v3.py +113 -124
  199. tpu_inference/models/jax/gpt_oss.py +23 -7
  200. tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
  201. tpu_inference/models/jax/llama3.py +99 -36
  202. tpu_inference/models/jax/llama4.py +14 -0
  203. tpu_inference/models/jax/llama_eagle3.py +14 -0
  204. tpu_inference/models/jax/llama_guard_4.py +15 -1
  205. tpu_inference/models/jax/qwen2.py +17 -2
  206. tpu_inference/models/jax/qwen2_5_vl.py +18 -4
  207. tpu_inference/models/jax/qwen3.py +17 -2
  208. tpu_inference/models/jax/utils/__init__.py +13 -0
  209. tpu_inference/models/jax/utils/file_utils.py +14 -0
  210. tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
  211. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  212. tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +85 -24
  213. tpu_inference/models/jax/utils/weight_utils.py +32 -1
  214. tpu_inference/models/vllm/__init__.py +13 -0
  215. tpu_inference/models/vllm/vllm_model_wrapper.py +22 -4
  216. tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
  217. tpu_inference/platforms/__init__.py +14 -0
  218. tpu_inference/platforms/tpu_platform.py +27 -29
  219. tpu_inference/runner/__init__.py +13 -0
  220. tpu_inference/runner/compilation_manager.py +69 -35
  221. tpu_inference/runner/kv_cache.py +14 -0
  222. tpu_inference/runner/kv_cache_manager.py +15 -2
  223. tpu_inference/runner/lora_utils.py +16 -1
  224. tpu_inference/runner/multimodal_manager.py +16 -2
  225. tpu_inference/runner/persistent_batch_manager.py +14 -0
  226. tpu_inference/runner/speculative_decoding_manager.py +14 -0
  227. tpu_inference/runner/structured_decoding_manager.py +14 -0
  228. tpu_inference/runner/tpu_runner.py +30 -10
  229. tpu_inference/spec_decode/__init__.py +13 -0
  230. tpu_inference/spec_decode/jax/__init__.py +13 -0
  231. tpu_inference/spec_decode/jax/eagle3.py +13 -0
  232. tpu_inference/tpu_info.py +14 -0
  233. tpu_inference/utils.py +31 -30
  234. tpu_inference/worker/__init__.py +13 -0
  235. tpu_inference/worker/tpu_worker.py +23 -7
  236. {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/METADATA +1 -1
  237. tpu_inference-0.13.2.dev20251230.dist-info/RECORD +266 -0
  238. tpu_inference/layers/vllm/linear_common.py +0 -208
  239. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  240. tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
  241. tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
  242. tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
  243. tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
  244. tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
  245. tpu_inference-0.12.0.dev20251213.dist-info/RECORD +0 -175
  246. {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/WHEEL +0 -0
  247. {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/licenses/LICENSE +0 -0
  248. {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,13 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
@@ -0,0 +1,13 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
@@ -1,3 +1,16 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
1
14
  """TPU-Friendly and Data-Movement-Friendly MLA Ragged Paged Attention kernel."""
2
15
 
3
16
  import functools
@@ -809,36 +822,17 @@ def _mla_ragged_paged_attention_kernel(
809
822
  return q_nope_vec, q_rope_vec
810
823
 
811
824
  def load_bkv(bkv_sem_idx, *, bkvc_mask, bkpe_mask):
812
- bitwidth = 32 // kv_packing
813
- repack_ty = jnp.dtype(f"uint{bitwidth}")
814
825
  bkvc_ref = (bkvc_x2_ref.bitcast(jnp.uint32).at[bkv_sem_idx].reshape(
815
826
  bkv_sz_per_kv_packing, lkv_dim))
816
- bkvc_vec = bkvc_ref[...]
817
- bkvc_vecs = []
818
- for i in range(kv_packing):
819
- masked_bkvc_vec = bkvc_vec >> (i * bitwidth)
820
- bkvc_vecs.append(masked_bkvc_vec)
821
- concated_bkvc_vec = jnp.concatenate(bkvc_vecs, axis=-1)
822
- concated_bkvc_vec = concated_bkvc_vec.reshape(bkv_sz, lkv_dim)
823
- concated_bkvc_vec = lax.select(bkvc_mask, concated_bkvc_vec,
824
- jnp.zeros_like(concated_bkvc_vec))
825
- concated_bkvc_vec = pltpu.bitcast(concated_bkvc_vec.astype(repack_ty),
826
- kv_dtype)
827
+ bkvc_vec = pltpu.bitcast(bkvc_ref[...], kv_dtype)
828
+ bkvc_vec = lax.select(bkvc_mask, bkvc_vec, jnp.zeros_like(bkvc_vec))
829
+
827
830
  bkpe_ref = (bkpe_x2_ref.bitcast(jnp.uint32).at[bkv_sem_idx].reshape(
828
831
  bkv_sz_per_kv_packing, r_dim))
829
- bkpe_vec = bkpe_ref[...]
830
- bkpe_vecs = []
831
- for i in range(kv_packing):
832
- masked_bkpe_vec = bkpe_vec >> (i * bitwidth)
833
- bkpe_vecs.append(masked_bkpe_vec)
834
- concated_bkpe_vec = jnp.concatenate(bkpe_vecs, axis=-1)
835
- concated_bkpe_vec = concated_bkpe_vec.reshape(bkv_sz, r_dim)
836
- concated_bkpe_vec = lax.select(bkpe_mask, concated_bkpe_vec,
837
- jnp.zeros_like(concated_bkpe_vec))
838
- concated_bkpe_vec = pltpu.bitcast(concated_bkpe_vec.astype(repack_ty),
839
- kv_dtype)
840
-
841
- return concated_bkvc_vec, concated_bkpe_vec
832
+ bkpe_vec = pltpu.bitcast(bkpe_ref[...], kv_dtype)
833
+ bkpe_vec = lax.select(bkpe_mask, bkpe_vec, jnp.zeros_like(bkpe_vec))
834
+
835
+ return bkvc_vec, bkpe_vec
842
836
 
843
837
  def broadcast_minor(src, shape):
844
838
  if src.shape == shape:
@@ -0,0 +1,13 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
@@ -0,0 +1,13 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
@@ -0,0 +1,13 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
@@ -0,0 +1,13 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
@@ -1,3 +1,16 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
1
14
  """TPU-Friendly Ragged Paged Attention kernel.
2
15
 
3
16
  This kernel offers a highly optimized implementation of ragged paged attention,
@@ -300,6 +313,22 @@ def _ragged_paged_attention_kernel(
300
313
  q_len = q_end - q_start
301
314
  kv_len = kv_lens_ref[seq_idx]
302
315
 
316
+ if sliding_window is None:
317
+ bkv_idx_start = next_seq_bkv_idx_start = 0
318
+ else:
319
+ bkv_idx_start = jnp.maximum(kv_len - q_len - sliding_window,
320
+ 0) // bkv_sz
321
+
322
+ # If seq_idx + 1 == num_seqs, kv_lens_ref[seq_idx + 1] will trigger an
323
+ # out-of-bounds error. To avoid this, we clamp next_seq_idx to an upper
324
+ # bound of num_seqs - 1.
325
+ next_seq_idx = jnp.minimum(seq_idx + 1, num_seqs - 1)
326
+ next_kv_len = kv_lens_ref[next_seq_idx]
327
+ next_q_len = cu_q_lens_ref[next_seq_idx + 1] - q_end
328
+ next_seq_bkv_idx_start = (
329
+ jnp.maximum(next_kv_len - next_q_len - sliding_window, 0) //
330
+ bkv_sz)
331
+
303
332
  def debug_print(msg, *args):
304
333
  if debug_mode:
305
334
  pl.debug_print(msg, *args)
@@ -337,8 +366,8 @@ def _ragged_paged_attention_kernel(
337
366
  head_m_ref = m_ref.at[kv_head_idx, :q.shape[0]]
338
367
 
339
368
  def load_with_init(ref, init_val):
340
- return jnp.where(bkv_idx == 0, jnp.full_like(ref, init_val),
341
- ref[...])
369
+ return jnp.where(bkv_idx == bkv_idx_start,
370
+ jnp.full_like(ref, init_val), ref[...])
342
371
 
343
372
  # Follow FlashAttention-2 forward pass.
344
373
  if q_scale is not None:
@@ -356,20 +385,21 @@ def _ragged_paged_attention_kernel(
356
385
  s *= k_scale
357
386
  if q_scale is not None:
358
387
  s *= q_scale
388
+ if soft_cap is not None:
389
+ s = soft_cap * jnp.tanh(s / soft_cap)
359
390
 
360
391
  q_span = (kv_len - q_len + bq_idx * bq_sz +
361
392
  lax.broadcasted_iota(jnp.int32, s.shape, 0) //
362
393
  num_q_heads_per_kv_head)
363
394
  k_span = bkv_idx * bkv_sz + lax.broadcasted_iota(jnp.int32, s.shape, 1)
364
- mask = q_span < k_span
365
- # TODO(jevinjiang, xiowei): reduce pages_per_seq based on sliding_window.
395
+ mask = k_span <= q_span
396
+
366
397
  if sliding_window is not None:
367
- mask = jnp.logical_or(mask, q_span - sliding_window >= k_span)
398
+ mask = jnp.logical_and(mask, q_span - sliding_window < k_span)
368
399
 
369
- if soft_cap is not None:
370
- s = soft_cap * jnp.tanh(s / soft_cap)
371
- s += jnp.where(mask, mask_value, 0.0)
400
+ s = jnp.where(mask, s, mask_value)
372
401
  s_rowmax = jnp.max(s, axis=1, keepdims=True)
402
+
373
403
  m_prev = load_with_init(head_m_ref, -jnp.inf)
374
404
  m_curr = jnp.maximum(m_prev, s_rowmax)
375
405
  head_m_ref[...] = m_curr
@@ -685,7 +715,7 @@ def _ragged_paged_attention_kernel(
685
715
  vec = jnp.concat([ref[start + i::step] for i in range(folds)], axis=1)
686
716
  return vec
687
717
 
688
- def strided_load_bkv(bkv_sem_idx, start, step, *, bkv_mask):
718
+ def strided_load_bkv(bkv_sem_idx, start, step):
689
719
  assert start % kv_packing == 0
690
720
  assert step % kv_packing == 0
691
721
  start //= kv_packing
@@ -697,21 +727,11 @@ def _ragged_paged_attention_kernel(
697
727
  k = strided_load(kv_ref, start, step)
698
728
  v = strided_load(kv_ref, start + 1, step)
699
729
 
700
- kv_zeros = jnp.zeros_like(k)
701
- k = lax.select(bkv_mask, k, kv_zeros)
702
- v = lax.select(bkv_mask, v, kv_zeros)
703
-
704
730
  k = pltpu.bitcast(k, kv_dtype)
705
731
  v = pltpu.bitcast(v, kv_dtype)
706
732
  return [(k, v)]
707
733
 
708
734
  kv = strided_load(kv_ref, start, step)
709
- # bkv_mask holds information about where each row of bkv is valid. Because
710
- # kv is packed, a single 32-bits value might contain multiple k & v from
711
- # different kv heads. Despite this we can guarantee that all values in a
712
- # single 32-bits will map to the same bkv row. Therefore, it is safe to
713
- # apply bkv_mask to kv directly.
714
- kv = lax.select(bkv_mask, kv, jnp.zeros_like(kv))
715
735
  bitwidth = 32 // kv_packing
716
736
 
717
737
  # If we want to convert 32-bits into 32//N number of N-bits value, naive
@@ -789,12 +809,27 @@ def _ragged_paged_attention_kernel(
789
809
  def get_next_bkv_ids(seq_idx, bq_idx, bkv_idx, bkv_sem_idx):
790
810
  next_bkv_idx = bkv_idx + 1
791
811
  is_last_bkv = next_bkv_idx == num_bkv
792
- next_bkv_idx = lax.select(is_last_bkv, 0, next_bkv_idx)
793
812
  next_bq_idx = lax.select(is_last_bkv, bq_idx + 1, bq_idx)
794
813
  is_last_bq = next_bq_idx == num_bq
795
814
  next_bq_idx = lax.select(is_last_bq, 0, next_bq_idx)
796
815
  next_seq_idx = lax.select(is_last_bq, seq_idx + 1, seq_idx)
797
816
  next_bkv_sem_idx = lax.select(bkv_sem_idx == 0, 1, 0)
817
+
818
+ if sliding_window is None:
819
+ # When sliding window is disabled, starting bkv_idx of next request is
820
+ # always 0 regardless of seq_idx of next request.
821
+ next_bkv_idx_start = 0
822
+ else:
823
+ # Determine starting bkv_idx of next request based on whether next
824
+ # request is from the same sequence or next sequence.
825
+ next_bkv_idx_start = lax.select(
826
+ is_last_bq,
827
+ next_seq_bkv_idx_start,
828
+ bkv_idx_start,
829
+ )
830
+ next_bkv_idx = lax.select(is_last_bkv, next_bkv_idx_start,
831
+ next_bkv_idx)
832
+
798
833
  return next_seq_idx, next_bq_idx, next_bkv_idx, next_bkv_sem_idx
799
834
 
800
835
  def compute_with_bq(bq_idx, _):
@@ -811,10 +846,6 @@ def _ragged_paged_attention_kernel(
811
846
  def compute_with_bkv(bkv_idx, _):
812
847
  # Create bitmask for KV.
813
848
  assert bkv_sz % kv_packing == 0
814
- actual_bkv_sz = jnp.minimum(bkv_sz, kv_len - bkv_idx * bkv_sz)
815
- bkv_shape = (bkv_sz, head_dim)
816
- bkv_mask = lax.broadcasted_iota(jnp.int32, bkv_shape,
817
- 0) < actual_bkv_sz
818
849
 
819
850
  # Get next bkv ids.
820
851
  bkv_sem_idx = sem_ids_ref[1]
@@ -867,7 +898,6 @@ def _ragged_paged_attention_kernel(
867
898
  bkv_sem_idx,
868
899
  kv_head_start * 2,
869
900
  num_kv_heads_x2,
870
- bkv_mask=bkv_mask,
871
901
  )
872
902
  assert len(bkv_lst) == heads_per_load
873
903
  for i in range(heads_per_load):
@@ -949,7 +979,17 @@ def _ragged_paged_attention_kernel(
949
979
  @pl.when(seq_idx == 0)
950
980
  def prologue():
951
981
  start_fetch_bq(0, 0, 0)
952
- start_fetch_bkv(0, 0, 0)
982
+
983
+ # Initialize bkv_x2_ref to zeros to avoid NaN issues from accessing
984
+ # uninitialized memory. Bitcast into int32 to avoid tiling issues.
985
+ bkv_x2_int32_ref = bkv_x2_ref.bitcast(jnp.int32).reshape(
986
+ (2, -1, 8, 128))
987
+ zeros = jnp.zeros(bkv_x2_int32_ref.shape[1:], jnp.int32)
988
+
989
+ # To pipeline VST and DMA, we divide the initialization into two steps.
990
+ bkv_x2_int32_ref[0] = zeros
991
+ start_fetch_bkv(0, bkv_idx_start, 0)
992
+ bkv_x2_int32_ref[1] = zeros
953
993
 
954
994
  @pl.when(seq_idx < decode_end)
955
995
  def process_decode():
@@ -1298,6 +1338,10 @@ def static_validate_inputs(
1298
1338
  del debug_mode
1299
1339
 
1300
1340
 
1341
+ def get_kernel_scope_name(bq_size, bkv_p, page_size):
1342
+ return f"RPA-bq_{bq_size}-bkvp_{bkv_p}-p_{page_size}-"
1343
+
1344
+
1301
1345
  @functools.partial(
1302
1346
  jax.jit,
1303
1347
  static_argnames=(
@@ -1359,14 +1403,14 @@ def ragged_paged_attention(
1359
1403
  distribution: (i, j, k) represents that sequences[0:i] are decode-only,
1360
1404
  sequences[i:j] are chunked-prefill-only, and sequences[j:k] are mixed. The
1361
1405
  k is also the total number of sequences.
1362
- actual_head_dim: the actual head size of the attention. Here we assume k and
1363
- v have the same actual head size.
1364
1406
  sm_scale: the softmax scale which will be applied to the Q@K^T.
1365
1407
  sliding_window: the sliding window size for the attention.
1366
1408
  soft_cap: the logit soft cap for the attention.
1367
1409
  mask_value: mask value for causal mask.
1410
+ q_scale: the scale for the query.
1368
1411
  k_scale: the scale for the key cache.
1369
1412
  v_scale: the scale for the value cache.
1413
+ chunk_prefill_size: the chunk prefill size for the attention.
1370
1414
  num_kv_pages_per_block: number of kv pages to be processed in one flash
1371
1415
  attention block in the pallas kernel.
1372
1416
  num_queries_per_block: number of queries to be processed in one flash
@@ -1433,6 +1477,7 @@ def ragged_paged_attention(
1433
1477
  page_size,
1434
1478
  max_num_tokens,
1435
1479
  pages_per_seq,
1480
+ sliding_window,
1436
1481
  )
1437
1482
  bkv_sz = bkv_p * page_size
1438
1483
  if vmem_limit_bytes is None:
@@ -1501,47 +1546,45 @@ def ragged_paged_attention(
1501
1546
  jnp.full((6, ), -1, jnp.int32),
1502
1547
  )
1503
1548
 
1504
- scope_name = f"RPA-bq_{bq_sz}-bkvp_{bkv_p}-p_{page_size}"
1505
- kernel = jax.named_scope(scope_name)(
1506
- pl.pallas_call(
1507
- functools.partial(
1508
- _ragged_paged_attention_kernel,
1509
- sm_scale=sm_scale,
1510
- sliding_window=sliding_window,
1511
- soft_cap=soft_cap,
1512
- mask_value=mask_value,
1513
- q_scale=q_scale,
1514
- k_scale=k_scale,
1515
- v_scale=v_scale,
1516
- chunk_prefill_size=chunk_prefill_size,
1517
- bq_sz=bq_sz,
1518
- bkv_p=bkv_p,
1519
- debug_mode=debug_mode,
1520
- ),
1521
- grid_spec=pltpu.PrefetchScalarGridSpec(
1522
- num_scalar_prefetch=len(scalar_prefetches),
1523
- in_specs=in_specs,
1524
- out_specs=out_specs,
1525
- grid=grid,
1526
- scratch_shapes=scratch_shapes,
1527
- ),
1528
- compiler_params=pltpu.CompilerParams(
1529
- # TODO(jevinjiang): since each sequence depends on the previous
1530
- # one, we need some extra work to support Megacore mode.
1531
- dimension_semantics=("arbitrary", ),
1532
- vmem_limit_bytes=vmem_limit_bytes,
1533
- ),
1534
- out_shape=[
1535
- jax.ShapeDtypeStruct(shape=q.shape, dtype=q.dtype),
1536
- jax.ShapeDtypeStruct(shape=kv_cache.shape,
1537
- dtype=kv_cache.dtype),
1538
- ],
1539
- input_output_aliases={
1540
- 7: 0,
1541
- 9: 1
1542
- },
1543
- name=scope_name,
1544
- ))
1549
+ scope_name = get_kernel_scope_name(bq_sz, bkv_p, page_size)
1550
+ kernel = pl.pallas_call(
1551
+ functools.partial(
1552
+ _ragged_paged_attention_kernel,
1553
+ sm_scale=sm_scale,
1554
+ sliding_window=sliding_window,
1555
+ soft_cap=soft_cap,
1556
+ mask_value=mask_value,
1557
+ q_scale=q_scale,
1558
+ k_scale=k_scale,
1559
+ v_scale=v_scale,
1560
+ chunk_prefill_size=chunk_prefill_size,
1561
+ bq_sz=bq_sz,
1562
+ bkv_p=bkv_p,
1563
+ debug_mode=debug_mode,
1564
+ ),
1565
+ grid_spec=pltpu.PrefetchScalarGridSpec(
1566
+ num_scalar_prefetch=len(scalar_prefetches),
1567
+ in_specs=in_specs,
1568
+ out_specs=out_specs,
1569
+ grid=grid,
1570
+ scratch_shapes=scratch_shapes,
1571
+ ),
1572
+ compiler_params=pltpu.CompilerParams(
1573
+ # TODO(jevinjiang): since each sequence depends on the previous
1574
+ # one, we need some extra work to support Megacore mode.
1575
+ dimension_semantics=("arbitrary", ),
1576
+ vmem_limit_bytes=vmem_limit_bytes,
1577
+ ),
1578
+ out_shape=[
1579
+ jax.ShapeDtypeStruct(shape=q.shape, dtype=q.dtype),
1580
+ jax.ShapeDtypeStruct(shape=kv_cache.shape, dtype=kv_cache.dtype),
1581
+ ],
1582
+ input_output_aliases={
1583
+ 7: 0,
1584
+ 9: 1
1585
+ },
1586
+ name=scope_name,
1587
+ )
1545
1588
 
1546
1589
  output, updated_kv_cache = kernel(*scalar_prefetches, q, kv, kv_cache)
1547
1590
  return (