tpu-inference 0.12.0.dev20251213__py3-none-any.whl → 0.13.2.dev20251230__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tpu-inference might be problematic.
- tests/__init__.py +13 -0
- tests/core/__init__.py +13 -0
- tests/core/test_disagg_utils.py +14 -0
- tests/core/test_dp_scheduler.py +650 -768
- tests/core/test_init.py +14 -0
- tests/distributed/__init__.py +13 -0
- tests/distributed/test_distributed_utils.py +120 -0
- tests/distributed/test_tpu_connector.py +478 -0
- tests/e2e/__init__.py +13 -0
- tests/e2e/test_async_scheduler.py +211 -0
- tests/e2e/test_data_parallel.py +289 -0
- tests/e2e/test_hybrid_kvcache.py +219 -0
- tests/e2e/test_local_disagg.py +257 -0
- tests/e2e/test_model_loader.py +268 -0
- tests/e2e/test_multi_modal_inference.py +111 -0
- tests/e2e/test_pipeline_parallel.py +265 -0
- tests/e2e/test_runai_model_streamer_loader.py +104 -0
- tests/e2e/test_sampling_params.py +269 -0
- tests/e2e/test_speculative_decoding.py +311 -0
- tests/e2e/test_structured_decoding.py +46 -0
- tests/executors/__init__.py +13 -0
- tests/executors/test_ray_distributed_executor.py +199 -0
- tests/experimental/__init__.py +13 -0
- tests/experimental/test_llama3_jax_stashed.py +208 -0
- tests/kernels/__init__.py +13 -0
- tests/kernels/collectives/__init__.py +13 -0
- tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
- tests/kernels/fused_moe_v1_test.py +14 -0
- tests/kernels/gmm_test.py +205 -0
- tests/kernels/mla_v1_test.py +14 -0
- tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +14 -0
- tests/layers/__init__.py +13 -0
- tests/layers/common/__init__.py +13 -0
- tests/layers/common/test_attention_interface.py +156 -0
- tests/layers/common/test_quantization.py +149 -0
- tests/layers/jax/__init__.py +13 -0
- tests/layers/jax/attention/__init__.py +13 -0
- tests/layers/jax/attention/test_common_attention.py +103 -0
- tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
- tests/layers/jax/attention/test_llama4_attention.py +135 -0
- tests/layers/jax/moe/__init__.py +13 -0
- tests/layers/jax/moe/test_deepseek_moe.py +235 -0
- tests/layers/jax/sample/__init__.py +13 -0
- tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
- tests/layers/jax/sample/test_sampling.py +115 -0
- tests/layers/jax/sample/test_sampling_metadata.py +254 -0
- tests/layers/jax/test_layers.py +155 -0
- tests/{test_quantization.py → layers/jax/test_qwix.py} +180 -50
- tests/layers/jax/test_rope.py +93 -0
- tests/layers/jax/test_sharding.py +159 -0
- tests/layers/jax/test_transformer_block.py +152 -0
- tests/layers/vllm/__init__.py +13 -0
- tests/layers/vllm/test_attention.py +363 -0
- tests/layers/vllm/test_awq.py +406 -0
- tests/layers/vllm/test_compressed_tensors_moe.py +199 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +441 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +443 -0
- tests/layers/vllm/test_fp8.py +17 -0
- tests/layers/vllm/test_mxfp4.py +320 -0
- tests/layers/vllm/test_unquantized.py +662 -0
- tests/layers/vllm/utils.py +87 -0
- tests/lora/__init__.py +13 -0
- tests/lora/conftest.py +14 -0
- tests/lora/test_bgmv.py +14 -0
- tests/lora/test_layers.py +25 -8
- tests/lora/test_lora.py +15 -1
- tests/lora/test_lora_perf.py +14 -0
- tests/models/__init__.py +13 -0
- tests/models/common/__init__.py +13 -0
- tests/models/common/test_model_loader.py +455 -0
- tests/models/jax/__init__.py +13 -0
- tests/models/jax/test_deepseek_v3.py +401 -0
- tests/models/jax/test_llama3.py +184 -0
- tests/models/jax/test_llama4.py +298 -0
- tests/models/jax/test_llama_eagle3.py +197 -0
- tests/models/jax/test_llama_guard_4.py +242 -0
- tests/models/jax/test_qwen2.py +172 -0
- tests/models/jax/test_qwen2_5_vl.py +605 -0
- tests/models/jax/test_qwen3.py +169 -0
- tests/models/jax/test_weight_loading.py +180 -0
- tests/models/jax/utils/__init__.py +13 -0
- tests/models/jax/utils/test_multi_modal_utils.py +212 -0
- tests/platforms/__init__.py +13 -0
- tests/platforms/test_tpu_platform.py +54 -0
- tests/runner/__init__.py +13 -0
- tests/runner/test_block_table.py +395 -0
- tests/runner/test_input_batch.py +226 -0
- tests/runner/test_kv_cache.py +220 -0
- tests/runner/test_kv_cache_manager.py +498 -0
- tests/runner/test_multimodal_manager.py +429 -0
- tests/runner/test_persistent_batch_manager.py +84 -0
- tests/runner/test_speculative_decoding_manager.py +368 -0
- tests/runner/test_structured_decoding_manager.py +220 -0
- tests/runner/test_tpu_runner.py +261 -0
- tests/runner/test_tpu_runner_dp.py +1099 -0
- tests/runner/test_tpu_runner_mesh.py +200 -0
- tests/runner/test_utils.py +411 -0
- tests/spec_decode/__init__.py +13 -0
- tests/spec_decode/test_eagle3.py +311 -0
- tests/test_base.py +14 -0
- tests/test_tpu_info.py +14 -0
- tests/test_utils.py +1 -43
- tests/worker/__init__.py +13 -0
- tests/worker/tpu_worker_test.py +414 -0
- tpu_inference/__init__.py +14 -0
- tpu_inference/core/__init__.py +13 -0
- tpu_inference/core/sched/__init__.py +13 -0
- tpu_inference/core/sched/dp_scheduler.py +372 -56
- tpu_inference/distributed/__init__.py +13 -0
- tpu_inference/distributed/jax_parallel_state.py +14 -0
- tpu_inference/distributed/tpu_connector.py +14 -9
- tpu_inference/distributed/utils.py +56 -4
- tpu_inference/executors/__init__.py +13 -0
- tpu_inference/executors/ray_distributed_executor.py +20 -3
- tpu_inference/experimental/__init__.py +13 -0
- tpu_inference/experimental/llama3_jax_stashed.py +14 -0
- tpu_inference/kernels/__init__.py +13 -0
- tpu_inference/kernels/collectives/__init__.py +13 -0
- tpu_inference/kernels/flash_attention/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/kernel.py +171 -163
- tpu_inference/kernels/megablox/__init__.py +13 -0
- tpu_inference/kernels/megablox/common.py +54 -0
- tpu_inference/kernels/megablox/gmm.py +646 -0
- tpu_inference/kernels/mla/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/kernel.py +20 -26
- tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +112 -69
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +85 -65
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +374 -194
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +13 -0
- tpu_inference/layers/__init__.py +13 -0
- tpu_inference/layers/common/__init__.py +13 -0
- tpu_inference/layers/common/attention_interface.py +26 -19
- tpu_inference/layers/common/attention_metadata.py +14 -0
- tpu_inference/layers/common/fused_moe_gmm.py +506 -0
- tpu_inference/layers/common/quant_methods.py +15 -0
- tpu_inference/layers/common/quantization.py +282 -0
- tpu_inference/layers/common/sharding.py +22 -3
- tpu_inference/layers/common/utils.py +94 -0
- tpu_inference/layers/jax/__init__.py +13 -0
- tpu_inference/layers/jax/attention/__init__.py +13 -0
- tpu_inference/layers/jax/attention/attention.py +19 -6
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +52 -27
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +19 -6
- tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
- tpu_inference/layers/jax/base.py +14 -0
- tpu_inference/layers/jax/constants.py +13 -0
- tpu_inference/layers/jax/layers.py +14 -0
- tpu_inference/layers/jax/misc.py +14 -0
- tpu_inference/layers/jax/moe/__init__.py +13 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
- tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
- tpu_inference/layers/jax/moe/moe.py +43 -3
- tpu_inference/layers/jax/pp_utils.py +53 -0
- tpu_inference/layers/jax/rope.py +14 -0
- tpu_inference/layers/jax/rope_interface.py +14 -0
- tpu_inference/layers/jax/sample/__init__.py +13 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
- tpu_inference/layers/jax/sample/sampling.py +15 -1
- tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
- tpu_inference/layers/jax/transformer_block.py +14 -0
- tpu_inference/layers/vllm/__init__.py +13 -0
- tpu_inference/layers/vllm/attention.py +4 -4
- tpu_inference/layers/vllm/fused_moe.py +100 -455
- tpu_inference/layers/vllm/linear.py +64 -0
- tpu_inference/layers/vllm/process_weights/__init__.py +13 -0
- tpu_inference/layers/vllm/{sharding.py → process_weights/cleanup_sharding.py} +24 -15
- tpu_inference/layers/vllm/process_weights/fused_moe_weights.py +369 -0
- tpu_inference/layers/vllm/process_weights/linear_weights.py +174 -0
- tpu_inference/layers/vllm/quantization/__init__.py +19 -3
- tpu_inference/layers/vllm/quantization/awq.py +96 -82
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +19 -5
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +119 -132
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +111 -91
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +79 -43
- tpu_inference/layers/vllm/quantization/{common.py → configs.py} +38 -26
- tpu_inference/layers/vllm/quantization/fp8.py +119 -0
- tpu_inference/layers/vllm/quantization/mxfp4.py +133 -220
- tpu_inference/layers/vllm/quantization/unquantized.py +154 -253
- tpu_inference/lora/__init__.py +13 -0
- tpu_inference/lora/torch_lora_ops.py +8 -13
- tpu_inference/models/__init__.py +13 -0
- tpu_inference/models/common/__init__.py +13 -0
- tpu_inference/models/common/model_loader.py +37 -16
- tpu_inference/models/jax/__init__.py +13 -0
- tpu_inference/models/jax/deepseek_v3.py +113 -124
- tpu_inference/models/jax/gpt_oss.py +23 -7
- tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
- tpu_inference/models/jax/llama3.py +99 -36
- tpu_inference/models/jax/llama4.py +14 -0
- tpu_inference/models/jax/llama_eagle3.py +14 -0
- tpu_inference/models/jax/llama_guard_4.py +15 -1
- tpu_inference/models/jax/qwen2.py +17 -2
- tpu_inference/models/jax/qwen2_5_vl.py +18 -4
- tpu_inference/models/jax/qwen3.py +17 -2
- tpu_inference/models/jax/utils/__init__.py +13 -0
- tpu_inference/models/jax/utils/file_utils.py +14 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
- tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
- tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +85 -24
- tpu_inference/models/jax/utils/weight_utils.py +32 -1
- tpu_inference/models/vllm/__init__.py +13 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +22 -4
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
- tpu_inference/platforms/__init__.py +14 -0
- tpu_inference/platforms/tpu_platform.py +27 -29
- tpu_inference/runner/__init__.py +13 -0
- tpu_inference/runner/compilation_manager.py +69 -35
- tpu_inference/runner/kv_cache.py +14 -0
- tpu_inference/runner/kv_cache_manager.py +15 -2
- tpu_inference/runner/lora_utils.py +16 -1
- tpu_inference/runner/multimodal_manager.py +16 -2
- tpu_inference/runner/persistent_batch_manager.py +14 -0
- tpu_inference/runner/speculative_decoding_manager.py +14 -0
- tpu_inference/runner/structured_decoding_manager.py +14 -0
- tpu_inference/runner/tpu_runner.py +30 -10
- tpu_inference/spec_decode/__init__.py +13 -0
- tpu_inference/spec_decode/jax/__init__.py +13 -0
- tpu_inference/spec_decode/jax/eagle3.py +13 -0
- tpu_inference/tpu_info.py +14 -0
- tpu_inference/utils.py +31 -30
- tpu_inference/worker/__init__.py +13 -0
- tpu_inference/worker/tpu_worker.py +23 -7
- {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/METADATA +1 -1
- tpu_inference-0.13.2.dev20251230.dist-info/RECORD +266 -0
- tpu_inference/layers/vllm/linear_common.py +0 -208
- tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
- tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
- tpu_inference-0.12.0.dev20251213.dist-info/RECORD +0 -175
- {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/WHEEL +0 -0
- {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,13 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
@@ -0,0 +1,13 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
@@ -1,3 +1,16 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 """TPU-Friendly and Data-Movement-Friendly MLA Ragged Paged Attention kernel."""
 
 import functools
@@ -809,36 +822,17 @@ def _mla_ragged_paged_attention_kernel(
     return q_nope_vec, q_rope_vec
 
   def load_bkv(bkv_sem_idx, *, bkvc_mask, bkpe_mask):
-    bitwidth = 32 // kv_packing
-    repack_ty = jnp.dtype(f"uint{bitwidth}")
     bkvc_ref = (bkvc_x2_ref.bitcast(jnp.uint32).at[bkv_sem_idx].reshape(
         bkv_sz_per_kv_packing, lkv_dim))
-    bkvc_vec = bkvc_ref[...]
-
-
-      masked_bkvc_vec = bkvc_vec >> (i * bitwidth)
-      bkvc_vecs.append(masked_bkvc_vec)
-    concated_bkvc_vec = jnp.concatenate(bkvc_vecs, axis=-1)
-    concated_bkvc_vec = concated_bkvc_vec.reshape(bkv_sz, lkv_dim)
-    concated_bkvc_vec = lax.select(bkvc_mask, concated_bkvc_vec,
-                                   jnp.zeros_like(concated_bkvc_vec))
-    concated_bkvc_vec = pltpu.bitcast(concated_bkvc_vec.astype(repack_ty),
-                                      kv_dtype)
+    bkvc_vec = pltpu.bitcast(bkvc_ref[...], kv_dtype)
+    bkvc_vec = lax.select(bkvc_mask, bkvc_vec, jnp.zeros_like(bkvc_vec))
+
     bkpe_ref = (bkpe_x2_ref.bitcast(jnp.uint32).at[bkv_sem_idx].reshape(
         bkv_sz_per_kv_packing, r_dim))
-    bkpe_vec = bkpe_ref[...]
-
-
-
-      bkpe_vecs.append(masked_bkpe_vec)
-    concated_bkpe_vec = jnp.concatenate(bkpe_vecs, axis=-1)
-    concated_bkpe_vec = concated_bkpe_vec.reshape(bkv_sz, r_dim)
-    concated_bkpe_vec = lax.select(bkpe_mask, concated_bkpe_vec,
-                                   jnp.zeros_like(concated_bkpe_vec))
-    concated_bkpe_vec = pltpu.bitcast(concated_bkpe_vec.astype(repack_ty),
-                                      kv_dtype)
-
-    return concated_bkvc_vec, concated_bkpe_vec
+    bkpe_vec = pltpu.bitcast(bkpe_ref[...], kv_dtype)
+    bkpe_vec = lax.select(bkpe_mask, bkpe_vec, jnp.zeros_like(bkpe_vec))
+
+    return bkvc_vec, bkpe_vec
 
   def broadcast_minor(src, shape):
     if src.shape == shape:
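The rewritten load_bkv above drops the old per-sub-word shift-and-repack loop: the raw 32-bit buffer is reinterpreted straight into the KV dtype with a single bitcast and then masked. A toy sketch of the same reinterpret-then-mask idea in plain JAX (jax.lax.bitcast_convert_type stands in for pltpu.bitcast, which only exists inside Pallas kernels; the shapes and validity cutoff are made up):

import jax.numpy as jnp
from jax import lax

# Fake packed KV buffer: 4 rows of raw 32-bit words (last row is padding).
packed = jnp.arange(8, dtype=jnp.uint32).reshape(4, 2)
# Reinterpret the raw bits as the KV dtype (float32 here); no arithmetic happens.
kv = lax.bitcast_convert_type(packed, jnp.float32)
# Zero out rows beyond the valid KV length, as the kernel's mask does.
valid = lax.broadcasted_iota(jnp.int32, kv.shape, 0) < 3
kv = lax.select(valid, kv, jnp.zeros_like(kv))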
@@ -0,0 +1,13 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
@@ -0,0 +1,13 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
@@ -0,0 +1,13 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
@@ -0,0 +1,13 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
@@ -1,3 +1,16 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 """TPU-Friendly Ragged Paged Attention kernel.
 
 This kernel offers a highly optimized implementation of ragged paged attention,
@@ -300,6 +313,22 @@ def _ragged_paged_attention_kernel(
     q_len = q_end - q_start
     kv_len = kv_lens_ref[seq_idx]
 
+    if sliding_window is None:
+      bkv_idx_start = next_seq_bkv_idx_start = 0
+    else:
+      bkv_idx_start = jnp.maximum(kv_len - q_len - sliding_window,
+                                  0) // bkv_sz
+
+      # If seq_idx + 1 == num_seqs, kv_lens_ref[seq_idx + 1] will trigger a
+      # out-of-bound error. To avoid this, we set upperbound of next_seq_idx
+      # to be num_seqs - 1.
+      next_seq_idx = jnp.minimum(seq_idx + 1, num_seqs - 1)
+      next_kv_len = kv_lens_ref[next_seq_idx]
+      next_q_len = cu_q_lens_ref[next_seq_idx + 1] - q_end
+      next_seq_bkv_idx_start = (
+          jnp.maximum(next_kv_len - next_q_len - sliding_window, 0) //
+          bkv_sz)
+
     def debug_print(msg, *args):
       if debug_mode:
         pl.debug_print(msg, *args)
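The block added above computes, per request, the first KV block that can still fall inside the sliding window: positions older than kv_len - q_len - sliding_window can never be attended, so their blocks are skipped entirely. A minimal sketch of that arithmetic outside the kernel, using standalone integers in place of the prefetched kv_lens_ref/cu_q_lens_ref values:

import jax.numpy as jnp

def first_kv_block(kv_len, q_len, sliding_window, bkv_sz):
  # Oldest KV position visible to the oldest query token of this request.
  oldest_visible = jnp.maximum(kv_len - q_len - sliding_window, 0)
  # KV blocks before this index hold only tokens outside the window.
  return oldest_visible // bkv_sz

# e.g. 1024 cached tokens, 128 new queries, window 256, 64-token KV blocks -> block 10
print(first_kv_block(1024, 128, 256, 64))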
@@ -337,8 +366,8 @@ def _ragged_paged_attention_kernel(
       head_m_ref = m_ref.at[kv_head_idx, :q.shape[0]]
 
       def load_with_init(ref, init_val):
-        return jnp.where(bkv_idx ==
-                         ref[...])
+        return jnp.where(bkv_idx == bkv_idx_start,
+                         jnp.full_like(ref, init_val), ref[...])
 
       # Follow FlashAttention-2 forward pass.
       if q_scale is not None:
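load_with_init now resets the running softmax statistics at bkv_idx_start, the first KV block actually processed for the sequence, rather than always at block 0. A tiny sketch of the pattern with an ordinary array standing in for the scratch ref (names mirror the kernel; the values are arbitrary):

import jax.numpy as jnp

def load_with_init(ref, init_val, bkv_idx, bkv_idx_start):
  # On the first processed KV block, ignore the stale scratch contents and
  # start from init_val; on later blocks, keep accumulating into the scratch.
  return jnp.where(bkv_idx == bkv_idx_start, jnp.full_like(ref, init_val), ref)

m_scratch = jnp.zeros((4, 1), jnp.float32)
m_prev = load_with_init(m_scratch, -jnp.inf, bkv_idx=2, bkv_idx_start=2)  # all -inf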
@@ -356,20 +385,21 @@ def _ragged_paged_attention_kernel(
         s *= k_scale
       if q_scale is not None:
         s *= q_scale
+      if soft_cap is not None:
+        s = soft_cap * jnp.tanh(s / soft_cap)
 
       q_span = (kv_len - q_len + bq_idx * bq_sz +
                 lax.broadcasted_iota(jnp.int32, s.shape, 0) //
                 num_q_heads_per_kv_head)
       k_span = bkv_idx * bkv_sz + lax.broadcasted_iota(jnp.int32, s.shape, 1)
-      mask =
-
+      mask = k_span <= q_span
+
       if sliding_window is not None:
-        mask = jnp.
+        mask = jnp.logical_and(mask, q_span - sliding_window < k_span)
 
-
-        s = soft_cap * jnp.tanh(s / soft_cap)
-      s += jnp.where(mask, mask_value, 0.0)
+      s = jnp.where(mask, s, mask_value)
       s_rowmax = jnp.max(s, axis=1, keepdims=True)
+
       m_prev = load_with_init(head_m_ref, -jnp.inf)
       m_curr = jnp.maximum(m_prev, s_rowmax)
       head_m_ref[...] = m_curr
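The reordered score path above applies the logit soft cap before the mask is built and then overwrites masked scores with mask_value instead of adding it. A self-contained sketch of the same transform on a toy score matrix in plain JAX (the mask_value default shown here is an assumption, not taken from the kernel):

import jax.numpy as jnp
from jax import lax

def mask_scores(s, q_span, k_span, soft_cap=None, sliding_window=None,
                mask_value=-1e30):
  if soft_cap is not None:
    s = soft_cap * jnp.tanh(s / soft_cap)  # cap logits before masking
  mask = k_span <= q_span                  # causal mask
  if sliding_window is not None:
    mask = jnp.logical_and(mask, q_span - sliding_window < k_span)
  return jnp.where(mask, s, mask_value)    # overwrite rather than add

q_span = lax.broadcasted_iota(jnp.int32, (4, 8), 0) + 4  # toy query positions
k_span = lax.broadcasted_iota(jnp.int32, (4, 8), 1)      # toy key positions
capped = mask_scores(jnp.ones((4, 8)), q_span, k_span, soft_cap=30.0, sliding_window=3)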
@@ -685,7 +715,7 @@ def _ragged_paged_attention_kernel(
     vec = jnp.concat([ref[start + i::step] for i in range(folds)], axis=1)
     return vec
 
-  def strided_load_bkv(bkv_sem_idx, start, step
+  def strided_load_bkv(bkv_sem_idx, start, step):
     assert start % kv_packing == 0
     assert step % kv_packing == 0
     start //= kv_packing
@@ -697,21 +727,11 @@ def _ragged_paged_attention_kernel(
       k = strided_load(kv_ref, start, step)
       v = strided_load(kv_ref, start + 1, step)
 
-      kv_zeros = jnp.zeros_like(k)
-      k = lax.select(bkv_mask, k, kv_zeros)
-      v = lax.select(bkv_mask, v, kv_zeros)
-
       k = pltpu.bitcast(k, kv_dtype)
       v = pltpu.bitcast(v, kv_dtype)
       return [(k, v)]
 
     kv = strided_load(kv_ref, start, step)
-    # bkv_mask holds information about where each row of bkv is valid. Because
-    # kv is packed, a single 32-bits value might contain multiple k & v from
-    # different kv heads. Despite this we can guarantee that all values in a
-    # single 32-bits will map to the same bkv row. Therefore, it is safe to
-    # apply bkv_mask to kv directly.
-    kv = lax.select(bkv_mask, kv, jnp.zeros_like(kv))
     bitwidth = 32 // kv_packing
 
     # If we want to convert 32-bits into 32//N number of N-bits value, naive
@@ -789,12 +809,27 @@ def _ragged_paged_attention_kernel(
   def get_next_bkv_ids(seq_idx, bq_idx, bkv_idx, bkv_sem_idx):
     next_bkv_idx = bkv_idx + 1
     is_last_bkv = next_bkv_idx == num_bkv
-    next_bkv_idx = lax.select(is_last_bkv, 0, next_bkv_idx)
     next_bq_idx = lax.select(is_last_bkv, bq_idx + 1, bq_idx)
     is_last_bq = next_bq_idx == num_bq
     next_bq_idx = lax.select(is_last_bq, 0, next_bq_idx)
     next_seq_idx = lax.select(is_last_bq, seq_idx + 1, seq_idx)
     next_bkv_sem_idx = lax.select(bkv_sem_idx == 0, 1, 0)
+
+    if sliding_window is None:
+      # When sliding window is disabled, starting bkv_idx of next request is
+      # always 0 regardless of seq_idx of next request.
+      next_bkv_idx_start = 0
+    else:
+      # Determine starting bkv_idx of next request based on whether next
+      # request is from the same sequence or next sequence.
+      next_bkv_idx_start = lax.select(
+          is_last_bq,
+          next_seq_bkv_idx_start,
+          bkv_idx_start,
+      )
+    next_bkv_idx = lax.select(is_last_bkv, next_bkv_idx_start,
+                              next_bkv_idx)
+
     return next_seq_idx, next_bq_idx, next_bkv_idx, next_bkv_sem_idx
 
   def compute_with_bq(bq_idx, _):
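get_next_bkv_ids now wraps the KV-block counter back to the per-request start block instead of 0, so a sliding-window request never prefetches blocks it would immediately skip. A rough pure-Python sketch of the counter advance (plain conditionals in place of lax.select; names follow the kernel):

def next_ids(seq_idx, bq_idx, bkv_idx, num_bq, num_bkv,
             bkv_idx_start, next_seq_bkv_idx_start):
  next_bkv_idx = bkv_idx + 1
  is_last_bkv = next_bkv_idx == num_bkv
  next_bq_idx = bq_idx + 1 if is_last_bkv else bq_idx
  is_last_bq = next_bq_idx == num_bq
  next_bq_idx = 0 if is_last_bq else next_bq_idx
  next_seq_idx = seq_idx + 1 if is_last_bq else seq_idx
  if is_last_bkv:
    # Restart from the first useful block: the next request's start block if
    # the last query block is also done, otherwise this request's start block.
    next_bkv_idx = next_seq_bkv_idx_start if is_last_bq else bkv_idx_start
  return next_seq_idx, next_bq_idx, next_bkv_idx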
@@ -811,10 +846,6 @@ def _ragged_paged_attention_kernel(
     def compute_with_bkv(bkv_idx, _):
       # Create bitmask for KV.
       assert bkv_sz % kv_packing == 0
-      actual_bkv_sz = jnp.minimum(bkv_sz, kv_len - bkv_idx * bkv_sz)
-      bkv_shape = (bkv_sz, head_dim)
-      bkv_mask = lax.broadcasted_iota(jnp.int32, bkv_shape,
-                                      0) < actual_bkv_sz
 
       # Get next bkv ids.
       bkv_sem_idx = sem_ids_ref[1]
@@ -867,7 +898,6 @@ def _ragged_paged_attention_kernel(
           bkv_sem_idx,
           kv_head_start * 2,
           num_kv_heads_x2,
-          bkv_mask=bkv_mask,
       )
       assert len(bkv_lst) == heads_per_load
       for i in range(heads_per_load):
@@ -949,7 +979,17 @@ def _ragged_paged_attention_kernel(
   @pl.when(seq_idx == 0)
   def prologue():
     start_fetch_bq(0, 0, 0)
-
+
+    # Initialize bkv_x2_ref to zeros to avoid NaN issues from accessing
+    # uninitialized memory. Bitcast into int32 to avoid tiling issues.
+    bkv_x2_int32_ref = bkv_x2_ref.bitcast(jnp.int32).reshape(
+        (2, -1, 8, 128))
+    zeros = jnp.zeros(bkv_x2_int32_ref.shape[1:], jnp.int32)
+
+    # To pipeline VST and DMA, we divide the initialization into two steps.
+    bkv_x2_int32_ref[0] = zeros
+    start_fetch_bkv(0, bkv_idx_start, 0)
+    bkv_x2_int32_ref[1] = zeros
 
   @pl.when(seq_idx < decode_end)
   def process_decode():
@@ -1298,6 +1338,10 @@ def static_validate_inputs(
   del debug_mode
 
 
+def get_kernel_scope_name(bq_size, bkv_p, page_size):
+  return f"RPA-bq_{bq_size}-bkvp_{bkv_p}-p_{page_size}-"
+
+
 @functools.partial(
     jax.jit,
     static_argnames=(
@@ -1359,14 +1403,14 @@ def ragged_paged_attention(
     distribution: (i, j, k) represents that sequences[0:i] are decode-only,
       sequences[i:j] are chunked-prefill-only, and sequences[j:k] are mixed. The
       k is also the total number of sequences.
-    actual_head_dim: the actual head size of the attention. Here we assume k and
-      v have the same actual head size.
     sm_scale: the softmax scale which will be applied to the Q@K^T.
     sliding_window: the sliding window size for the attention.
     soft_cap: the logit soft cap for the attention.
     mask_value: mask value for causal mask.
+    q_scale: the scale for the query.
     k_scale: the scale for the key cache.
     v_scale: the scale for the value cache.
+    chunk_prefill_size: the chunk prefill size for the attention.
     num_kv_pages_per_block: number of kv pages to be processed in one flash
       attention block in the pallas kernel.
     num_queries_per_block: number of kv pages to be processed in one flash
@@ -1433,6 +1477,7 @@ def ragged_paged_attention(
         page_size,
         max_num_tokens,
         pages_per_seq,
+        sliding_window,
     )
   bkv_sz = bkv_p * page_size
   if vmem_limit_bytes is None:
@@ -1501,47 +1546,45 @@ def ragged_paged_attention(
       jnp.full((6, ), -1, jnp.int32),
   )
 
-  scope_name =
-  kernel =
-      name=scope_name,
-  ))
+  scope_name = get_kernel_scope_name(bq_sz, bkv_p, page_size)
+  kernel = pl.pallas_call(
+      functools.partial(
+          _ragged_paged_attention_kernel,
+          sm_scale=sm_scale,
+          sliding_window=sliding_window,
+          soft_cap=soft_cap,
+          mask_value=mask_value,
+          q_scale=q_scale,
+          k_scale=k_scale,
+          v_scale=v_scale,
+          chunk_prefill_size=chunk_prefill_size,
+          bq_sz=bq_sz,
+          bkv_p=bkv_p,
+          debug_mode=debug_mode,
+      ),
+      grid_spec=pltpu.PrefetchScalarGridSpec(
+          num_scalar_prefetch=len(scalar_prefetches),
+          in_specs=in_specs,
+          out_specs=out_specs,
+          grid=grid,
+          scratch_shapes=scratch_shapes,
+      ),
+      compiler_params=pltpu.CompilerParams(
+          # TODO(jevinjiang): since each sequence depends on the previous
+          # one, we need some extra work to support Megacore mode.
+          dimension_semantics=("arbitrary", ),
+          vmem_limit_bytes=vmem_limit_bytes,
+      ),
+      out_shape=[
+          jax.ShapeDtypeStruct(shape=q.shape, dtype=q.dtype),
+          jax.ShapeDtypeStruct(shape=kv_cache.shape, dtype=kv_cache.dtype),
+      ],
+      input_output_aliases={
+          7: 0,
+          9: 1
+      },
+      name=scope_name,
+  )
 
   output, updated_kv_cache = kernel(*scalar_prefetches, q, kv, kv_cache)
   return (