tpu-inference 0.11.1.dev202511220812__py3-none-any.whl → 0.13.2.dev20251230__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tpu-inference might be problematic. Click here for more details.
- tests/__init__.py +13 -0
- tests/core/__init__.py +13 -0
- tests/core/test_disagg_utils.py +14 -0
- tests/core/test_dp_scheduler.py +650 -768
- tests/core/test_init.py +14 -0
- tests/distributed/__init__.py +13 -0
- tests/distributed/test_distributed_utils.py +120 -0
- tests/distributed/test_tpu_connector.py +478 -0
- tests/e2e/__init__.py +13 -0
- tests/e2e/test_async_scheduler.py +211 -0
- tests/e2e/test_data_parallel.py +289 -0
- tests/e2e/test_hybrid_kvcache.py +219 -0
- tests/e2e/test_local_disagg.py +257 -0
- tests/e2e/test_model_loader.py +268 -0
- tests/e2e/test_multi_modal_inference.py +111 -0
- tests/e2e/test_pipeline_parallel.py +265 -0
- tests/e2e/test_runai_model_streamer_loader.py +104 -0
- tests/e2e/test_sampling_params.py +269 -0
- tests/e2e/test_speculative_decoding.py +311 -0
- tests/e2e/test_structured_decoding.py +46 -0
- tests/executors/__init__.py +13 -0
- tests/executors/test_ray_distributed_executor.py +199 -0
- tests/experimental/__init__.py +13 -0
- tests/experimental/test_llama3_jax_stashed.py +208 -0
- tests/kernels/__init__.py +13 -0
- tests/kernels/collectives/__init__.py +13 -0
- tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
- tests/kernels/fused_moe_v1_test.py +317 -34
- tests/kernels/gmm_test.py +205 -0
- tests/kernels/mla_v1_test.py +143 -41
- tests/kernels/quantized_matmul_kernel_test.py +2 -34
- tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +17 -1
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +17 -1
- tests/layers/__init__.py +13 -0
- tests/layers/common/__init__.py +13 -0
- tests/layers/common/test_attention_interface.py +156 -0
- tests/layers/common/test_quantization.py +149 -0
- tests/layers/jax/__init__.py +13 -0
- tests/layers/jax/attention/__init__.py +13 -0
- tests/layers/jax/attention/test_common_attention.py +103 -0
- tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
- tests/layers/jax/attention/test_llama4_attention.py +135 -0
- tests/layers/jax/moe/__init__.py +13 -0
- tests/layers/jax/moe/test_deepseek_moe.py +235 -0
- tests/layers/jax/sample/__init__.py +13 -0
- tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
- tests/layers/jax/sample/test_sampling.py +115 -0
- tests/layers/jax/sample/test_sampling_metadata.py +254 -0
- tests/layers/jax/test_layers.py +155 -0
- tests/{test_quantization.py → layers/jax/test_qwix.py} +183 -50
- tests/layers/jax/test_rope.py +93 -0
- tests/layers/jax/test_sharding.py +159 -0
- tests/layers/jax/test_transformer_block.py +152 -0
- tests/layers/vllm/__init__.py +13 -0
- tests/layers/vllm/test_attention.py +363 -0
- tests/layers/vllm/test_awq.py +406 -0
- tests/layers/vllm/test_compressed_tensors_moe.py +199 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +441 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +443 -0
- tests/layers/vllm/test_fp8.py +17 -0
- tests/layers/vllm/test_mxfp4.py +320 -0
- tests/layers/vllm/test_unquantized.py +662 -0
- tests/layers/vllm/utils.py +87 -0
- tests/lora/__init__.py +13 -0
- tests/lora/conftest.py +14 -0
- tests/lora/test_bgmv.py +14 -0
- tests/lora/test_layers.py +26 -6
- tests/lora/test_lora.py +15 -1
- tests/lora/test_lora_perf.py +67 -0
- tests/models/__init__.py +13 -0
- tests/models/common/__init__.py +13 -0
- tests/models/common/test_model_loader.py +455 -0
- tests/models/jax/__init__.py +13 -0
- tests/models/jax/test_deepseek_v3.py +401 -0
- tests/models/jax/test_llama3.py +184 -0
- tests/models/jax/test_llama4.py +298 -0
- tests/models/jax/test_llama_eagle3.py +197 -0
- tests/models/jax/test_llama_guard_4.py +242 -0
- tests/models/jax/test_qwen2.py +172 -0
- tests/models/jax/test_qwen2_5_vl.py +605 -0
- tests/models/jax/test_qwen3.py +169 -0
- tests/models/jax/test_weight_loading.py +180 -0
- tests/models/jax/utils/__init__.py +13 -0
- tests/models/jax/utils/test_multi_modal_utils.py +212 -0
- tests/platforms/__init__.py +13 -0
- tests/platforms/test_tpu_platform.py +54 -0
- tests/runner/__init__.py +13 -0
- tests/runner/test_block_table.py +395 -0
- tests/runner/test_input_batch.py +226 -0
- tests/runner/test_kv_cache.py +220 -0
- tests/runner/test_kv_cache_manager.py +498 -0
- tests/runner/test_multimodal_manager.py +429 -0
- tests/runner/test_persistent_batch_manager.py +84 -0
- tests/runner/test_speculative_decoding_manager.py +368 -0
- tests/runner/test_structured_decoding_manager.py +220 -0
- tests/runner/test_tpu_runner.py +261 -0
- tests/runner/test_tpu_runner_dp.py +1099 -0
- tests/runner/test_tpu_runner_mesh.py +200 -0
- tests/runner/test_utils.py +411 -0
- tests/spec_decode/__init__.py +13 -0
- tests/spec_decode/test_eagle3.py +311 -0
- tests/test_base.py +14 -0
- tests/test_envs.py +110 -12
- tests/test_tpu_info.py +14 -0
- tests/test_utils.py +2 -45
- tests/worker/__init__.py +13 -0
- tests/worker/tpu_worker_test.py +414 -0
- tpu_inference/__init__.py +14 -0
- tpu_inference/core/__init__.py +13 -0
- tpu_inference/core/sched/__init__.py +13 -0
- tpu_inference/core/sched/dp_scheduler.py +372 -56
- tpu_inference/distributed/__init__.py +13 -0
- tpu_inference/distributed/jax_parallel_state.py +14 -0
- tpu_inference/distributed/tpu_connector.py +15 -10
- tpu_inference/distributed/utils.py +56 -4
- tpu_inference/envs.py +92 -8
- tpu_inference/executors/__init__.py +13 -0
- tpu_inference/executors/ray_distributed_executor.py +25 -4
- tpu_inference/experimental/__init__.py +13 -0
- tpu_inference/experimental/llama3_jax_stashed.py +14 -0
- tpu_inference/kernels/__init__.py +13 -0
- tpu_inference/kernels/collectives/__init__.py +13 -0
- tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
- tpu_inference/kernels/flash_attention/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/kernel.py +807 -230
- tpu_inference/kernels/megablox/__init__.py +13 -0
- tpu_inference/kernels/megablox/common.py +54 -0
- tpu_inference/kernels/megablox/gmm.py +646 -0
- tpu_inference/kernels/mla/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/kernel.py +117 -145
- tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
- tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
- tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +194 -101
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +218 -137
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +376 -195
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +15 -1
- tpu_inference/layers/__init__.py +13 -0
- tpu_inference/layers/common/__init__.py +13 -0
- tpu_inference/layers/common/attention_interface.py +25 -12
- tpu_inference/layers/common/attention_metadata.py +14 -0
- tpu_inference/layers/common/fused_moe_gmm.py +506 -0
- tpu_inference/layers/common/quant_methods.py +15 -0
- tpu_inference/layers/common/quantization.py +282 -0
- tpu_inference/layers/common/sharding.py +32 -9
- tpu_inference/layers/common/utils.py +94 -0
- tpu_inference/layers/jax/__init__.py +13 -0
- tpu_inference/layers/jax/attention/__init__.py +13 -0
- tpu_inference/layers/jax/attention/attention.py +19 -6
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +270 -77
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +24 -11
- tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
- tpu_inference/layers/jax/base.py +14 -0
- tpu_inference/layers/jax/constants.py +13 -0
- tpu_inference/layers/jax/layers.py +14 -0
- tpu_inference/layers/jax/misc.py +14 -0
- tpu_inference/layers/jax/moe/__init__.py +13 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
- tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
- tpu_inference/layers/jax/moe/moe.py +43 -3
- tpu_inference/layers/jax/pp_utils.py +53 -0
- tpu_inference/layers/jax/rope.py +14 -0
- tpu_inference/layers/jax/rope_interface.py +14 -0
- tpu_inference/layers/jax/sample/__init__.py +13 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
- tpu_inference/layers/jax/sample/sampling.py +15 -1
- tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
- tpu_inference/layers/jax/transformer_block.py +14 -0
- tpu_inference/layers/vllm/__init__.py +13 -0
- tpu_inference/layers/vllm/attention.py +4 -4
- tpu_inference/layers/vllm/fused_moe.py +101 -494
- tpu_inference/layers/vllm/linear.py +64 -0
- tpu_inference/layers/vllm/process_weights/__init__.py +13 -0
- tpu_inference/layers/vllm/{sharding.py → process_weights/cleanup_sharding.py} +24 -15
- tpu_inference/layers/vllm/process_weights/fused_moe_weights.py +369 -0
- tpu_inference/layers/vllm/process_weights/linear_weights.py +174 -0
- tpu_inference/layers/vllm/quantization/__init__.py +19 -3
- tpu_inference/layers/vllm/quantization/awq.py +96 -82
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +23 -8
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +172 -176
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +111 -91
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +79 -43
- tpu_inference/layers/vllm/quantization/{common.py → configs.py} +42 -25
- tpu_inference/layers/vllm/quantization/fp8.py +119 -0
- tpu_inference/layers/vllm/quantization/mxfp4.py +137 -178
- tpu_inference/layers/vllm/quantization/unquantized.py +157 -233
- tpu_inference/lora/__init__.py +13 -0
- tpu_inference/lora/torch_lora_ops.py +8 -13
- tpu_inference/models/__init__.py +13 -0
- tpu_inference/models/common/__init__.py +13 -0
- tpu_inference/models/common/model_loader.py +112 -35
- tpu_inference/models/jax/__init__.py +13 -0
- tpu_inference/models/jax/deepseek_v3.py +267 -157
- tpu_inference/models/jax/gpt_oss.py +26 -10
- tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
- tpu_inference/models/jax/llama3.py +99 -36
- tpu_inference/models/jax/llama4.py +14 -0
- tpu_inference/models/jax/llama_eagle3.py +18 -5
- tpu_inference/models/jax/llama_guard_4.py +15 -1
- tpu_inference/models/jax/qwen2.py +17 -2
- tpu_inference/models/jax/qwen2_5_vl.py +179 -51
- tpu_inference/models/jax/qwen3.py +17 -2
- tpu_inference/models/jax/utils/__init__.py +13 -0
- tpu_inference/models/jax/utils/file_utils.py +14 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
- tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
- tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +92 -32
- tpu_inference/models/jax/utils/weight_utils.py +234 -155
- tpu_inference/models/vllm/__init__.py +13 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +32 -8
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
- tpu_inference/platforms/__init__.py +14 -0
- tpu_inference/platforms/tpu_platform.py +51 -72
- tpu_inference/runner/__init__.py +13 -0
- tpu_inference/runner/compilation_manager.py +180 -80
- tpu_inference/runner/kv_cache.py +54 -20
- tpu_inference/runner/kv_cache_manager.py +55 -33
- tpu_inference/runner/lora_utils.py +16 -1
- tpu_inference/runner/multimodal_manager.py +16 -2
- tpu_inference/runner/persistent_batch_manager.py +54 -2
- tpu_inference/runner/speculative_decoding_manager.py +14 -0
- tpu_inference/runner/structured_decoding_manager.py +16 -3
- tpu_inference/runner/tpu_runner.py +124 -61
- tpu_inference/runner/utils.py +2 -2
- tpu_inference/spec_decode/__init__.py +13 -0
- tpu_inference/spec_decode/jax/__init__.py +13 -0
- tpu_inference/spec_decode/jax/eagle3.py +84 -22
- tpu_inference/tpu_info.py +14 -0
- tpu_inference/utils.py +72 -44
- tpu_inference/worker/__init__.py +13 -0
- tpu_inference/worker/tpu_worker.py +66 -52
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/METADATA +8 -9
- tpu_inference-0.13.2.dev20251230.dist-info/RECORD +266 -0
- tpu_inference/layers/vllm/linear_common.py +0 -186
- tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
- tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
- tpu_inference-0.11.1.dev202511220812.dist-info/RECORD +0 -174
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/WHEEL +0 -0
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
@@ -200,7 +200,8 @@ def _prev_power_of_2(n: int) -> int:
|
|
|
200
200
|
def _get_page_size_bytes(block_size: int, num_combined_kv_heads: int,
|
|
201
201
|
head_size: int, kv_cache_dtype) -> int:
|
|
202
202
|
"""Returns the size in bytes of one page of the KV cache."""
|
|
203
|
-
kv_cache_dtype_bit_size = dtypes.bit_width(kv_cache_dtype)
|
|
203
|
+
kv_cache_dtype_bit_size = (dtypes.bit_width(kv_cache_dtype) if hasattr(
|
|
204
|
+
dtypes, "bit_width") else dtypes.itemsize_bits(kv_cache_dtype))
|
|
204
205
|
padded_head_size = _ceil_div(
|
|
205
206
|
head_size, TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT
|
|
206
207
|
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
@@ -1,3 +1,16 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
1
14
|
"""TPU-Friendly Ragged Paged Attention kernel.
|
|
2
15
|
|
|
3
16
|
This kernel offers a highly optimized implementation of ragged paged attention,
|
|
@@ -300,6 +313,22 @@ def _ragged_paged_attention_kernel(
|
|
|
300
313
|
q_len = q_end - q_start
|
|
301
314
|
kv_len = kv_lens_ref[seq_idx]
|
|
302
315
|
|
|
316
|
+
if sliding_window is None:
|
|
317
|
+
bkv_idx_start = next_seq_bkv_idx_start = 0
|
|
318
|
+
else:
|
|
319
|
+
bkv_idx_start = jnp.maximum(kv_len - q_len - sliding_window,
|
|
320
|
+
0) // bkv_sz
|
|
321
|
+
|
|
322
|
+
# If seq_idx + 1 == num_seqs, kv_lens_ref[seq_idx + 1] will trigger an
|
|
323
|
+
# out-of-bounds error. To avoid this, we set the upper bound of next_seq_idx
|
|
324
|
+
# to be num_seqs - 1.
|
|
325
|
+
next_seq_idx = jnp.minimum(seq_idx + 1, num_seqs - 1)
|
|
326
|
+
next_kv_len = kv_lens_ref[next_seq_idx]
|
|
327
|
+
next_q_len = cu_q_lens_ref[next_seq_idx + 1] - q_end
|
|
328
|
+
next_seq_bkv_idx_start = (
|
|
329
|
+
jnp.maximum(next_kv_len - next_q_len - sliding_window, 0) //
|
|
330
|
+
bkv_sz)
|
|
331
|
+
|
|
303
332
|
def debug_print(msg, *args):
|
|
304
333
|
if debug_mode:
|
|
305
334
|
pl.debug_print(msg, *args)
|
|
@@ -319,7 +348,7 @@ def _ragged_paged_attention_kernel(
|
|
|
319
348
|
debug_print("[RPA debug] q_len={}", q_len)
|
|
320
349
|
debug_print("[RPA debug] kv_len={}", kv_len)
|
|
321
350
|
|
|
322
|
-
def
|
|
351
|
+
def flash_attention_step1_qk_softmax(
|
|
323
352
|
q, # [actual_bq_sz * num_q_heads_per_kv_head, head_dim]
|
|
324
353
|
k, # [bkv_sz, head_dim]
|
|
325
354
|
v, # [bkv_sz, head_dim]
|
|
@@ -335,11 +364,10 @@ def _ragged_paged_attention_kernel(
|
|
|
335
364
|
assert k.dtype == v.dtype
|
|
336
365
|
head_l_ref = l_ref.at[kv_head_idx, :q.shape[0]]
|
|
337
366
|
head_m_ref = m_ref.at[kv_head_idx, :q.shape[0]]
|
|
338
|
-
head_acc_ref = acc_ref.at[kv_head_idx, :q.shape[0]]
|
|
339
367
|
|
|
340
368
|
def load_with_init(ref, init_val):
|
|
341
|
-
return jnp.where(bkv_idx == 0, jnp.full_like(ref, init_val),
|
|
342
|
-
ref[...])
|
|
369
|
+
return jnp.where(bkv_idx == bkv_idx_start,
|
|
370
|
+
jnp.full_like(ref, init_val), ref[...])
|
|
343
371
|
|
|
344
372
|
# Follow FlashAttention-2 forward pass.
|
|
345
373
|
if q_scale is not None:
|
|
@@ -357,34 +385,52 @@ def _ragged_paged_attention_kernel(
|
|
|
357
385
|
s *= k_scale
|
|
358
386
|
if q_scale is not None:
|
|
359
387
|
s *= q_scale
|
|
388
|
+
if soft_cap is not None:
|
|
389
|
+
s = soft_cap * jnp.tanh(s / soft_cap)
|
|
360
390
|
|
|
361
391
|
q_span = (kv_len - q_len + bq_idx * bq_sz +
|
|
362
392
|
lax.broadcasted_iota(jnp.int32, s.shape, 0) //
|
|
363
393
|
num_q_heads_per_kv_head)
|
|
364
394
|
k_span = bkv_idx * bkv_sz + lax.broadcasted_iota(jnp.int32, s.shape, 1)
|
|
365
|
-
mask = q_span < k_span
|
|
366
|
-
|
|
395
|
+
mask = k_span <= q_span
|
|
396
|
+
|
|
367
397
|
if sliding_window is not None:
|
|
368
|
-
mask = jnp.logical_or(mask, q_span - sliding_window >= k_span)
|
|
398
|
+
mask = jnp.logical_and(mask, q_span - sliding_window < k_span)
|
|
369
399
|
|
|
370
|
-
|
|
371
|
-
s = soft_cap * jnp.tanh(s / soft_cap)
|
|
372
|
-
s += jnp.where(mask, mask_value, 0.0)
|
|
400
|
+
s = jnp.where(mask, s, mask_value)
|
|
373
401
|
s_rowmax = jnp.max(s, axis=1, keepdims=True)
|
|
402
|
+
|
|
374
403
|
m_prev = load_with_init(head_m_ref, -jnp.inf)
|
|
375
404
|
m_curr = jnp.maximum(m_prev, s_rowmax)
|
|
376
405
|
head_m_ref[...] = m_curr
|
|
377
406
|
p = jnp.exp(s - broadcast_minor(m_curr, s.shape))
|
|
378
407
|
|
|
379
|
-
pv = jnp.einsum("nm,md->nd", p, v, preferred_element_type=jnp.float32)
|
|
380
|
-
if v_scale is not None:
|
|
381
|
-
pv *= v_scale
|
|
382
|
-
|
|
383
408
|
p_rowsum = jnp.sum(p, axis=1, keepdims=True)
|
|
384
409
|
exp_m_diff = jnp.exp(m_prev - m_curr)
|
|
385
410
|
l_prev = load_with_init(head_l_ref, 0.0)
|
|
386
411
|
l_curr = exp_m_diff * l_prev + p_rowsum
|
|
387
412
|
head_l_ref[...] = l_curr
|
|
413
|
+
|
|
414
|
+
return p, exp_m_diff
|
|
415
|
+
|
|
416
|
+
def flash_attention_step2_pv(
|
|
417
|
+
q_shape_0,
|
|
418
|
+
v, # [bkv_sz, head_dim]
|
|
419
|
+
p, # from step1
|
|
420
|
+
exp_m_diff, # from step1
|
|
421
|
+
*,
|
|
422
|
+
bkv_idx,
|
|
423
|
+
kv_head_idx,
|
|
424
|
+
):
|
|
425
|
+
head_acc_ref = acc_ref.at[kv_head_idx, :q_shape_0]
|
|
426
|
+
|
|
427
|
+
def load_with_init(ref, init_val):
|
|
428
|
+
return jnp.where(bkv_idx == 0, jnp.full_like(ref, init_val),
|
|
429
|
+
ref[...])
|
|
430
|
+
|
|
431
|
+
pv = jnp.einsum("nm,md->nd", p, v, preferred_element_type=jnp.float32)
|
|
432
|
+
if v_scale is not None:
|
|
433
|
+
pv *= v_scale
|
|
388
434
|
o_prev = load_with_init(head_acc_ref, 0.0)
|
|
389
435
|
o_curr = broadcast_minor(exp_m_diff, o_prev.shape) * o_prev + pv
|
|
390
436
|
head_acc_ref[...] = o_curr
|
|
@@ -463,19 +509,16 @@ def _ragged_paged_attention_kernel(
|
|
|
463
509
|
unroll=False,
|
|
464
510
|
)
|
|
465
511
|
|
|
466
|
-
|
|
467
|
-
|
|
468
|
-
|
|
469
|
-
|
|
470
|
-
|
|
471
|
-
|
|
472
|
-
|
|
473
|
-
|
|
474
|
-
|
|
475
|
-
|
|
476
|
-
sem,
|
|
477
|
-
wait,
|
|
478
|
-
)
|
|
512
|
+
size = lax.select(bkv_sz_frm_new > 0, bkv_sz_frm_new, 0)
|
|
513
|
+
new_kv_len_start = q_end - kv_left_frm_new
|
|
514
|
+
debug_print("[RPA debug] new_kv_len_start={}", new_kv_len_start)
|
|
515
|
+
debug_print("[RPA debug] offset_in_bkv={}", offset)
|
|
516
|
+
_async_copy(
|
|
517
|
+
kv_hbm_ref.at[pl.ds(new_kv_len_start, size)],
|
|
518
|
+
vmem_ref.at[pl.ds(offset, size)],
|
|
519
|
+
sem,
|
|
520
|
+
wait,
|
|
521
|
+
)
|
|
479
522
|
|
|
480
523
|
return kv_len_start + offset, bkv_sz_frm_new
|
|
481
524
|
else:
|
|
@@ -672,7 +715,7 @@ def _ragged_paged_attention_kernel(
|
|
|
672
715
|
vec = jnp.concat([ref[start + i::step] for i in range(folds)], axis=1)
|
|
673
716
|
return vec
|
|
674
717
|
|
|
675
|
-
def strided_load_bkv(bkv_sem_idx, start, step, *, bkv_mask):
|
|
718
|
+
def strided_load_bkv(bkv_sem_idx, start, step):
|
|
676
719
|
assert start % kv_packing == 0
|
|
677
720
|
assert step % kv_packing == 0
|
|
678
721
|
start //= kv_packing
|
|
@@ -684,21 +727,11 @@ def _ragged_paged_attention_kernel(
|
|
|
684
727
|
k = strided_load(kv_ref, start, step)
|
|
685
728
|
v = strided_load(kv_ref, start + 1, step)
|
|
686
729
|
|
|
687
|
-
kv_zeros = jnp.zeros_like(k)
|
|
688
|
-
k = lax.select(bkv_mask, k, kv_zeros)
|
|
689
|
-
v = lax.select(bkv_mask, v, kv_zeros)
|
|
690
|
-
|
|
691
730
|
k = pltpu.bitcast(k, kv_dtype)
|
|
692
731
|
v = pltpu.bitcast(v, kv_dtype)
|
|
693
732
|
return [(k, v)]
|
|
694
733
|
|
|
695
734
|
kv = strided_load(kv_ref, start, step)
|
|
696
|
-
# bkv_mask holds information about where each row of bkv is valid. Because
|
|
697
|
-
# kv is packed, a single 32-bits value might contain multiple k & v from
|
|
698
|
-
# different kv heads. Despite this we can guarantee that all values in a
|
|
699
|
-
# single 32-bits will map to the same bkv row. Therefore, it is safe to
|
|
700
|
-
# apply bkv_mask to kv directly.
|
|
701
|
-
kv = lax.select(bkv_mask, kv, jnp.zeros_like(kv))
|
|
702
735
|
bitwidth = 32 // kv_packing
|
|
703
736
|
|
|
704
737
|
# If we want to convert 32-bits into 32//N number of N-bits value, naive
|
|
@@ -776,12 +809,27 @@ def _ragged_paged_attention_kernel(
|
|
|
776
809
|
def get_next_bkv_ids(seq_idx, bq_idx, bkv_idx, bkv_sem_idx):
|
|
777
810
|
next_bkv_idx = bkv_idx + 1
|
|
778
811
|
is_last_bkv = next_bkv_idx == num_bkv
|
|
779
|
-
next_bkv_idx = lax.select(is_last_bkv, 0, next_bkv_idx)
|
|
780
812
|
next_bq_idx = lax.select(is_last_bkv, bq_idx + 1, bq_idx)
|
|
781
813
|
is_last_bq = next_bq_idx == num_bq
|
|
782
814
|
next_bq_idx = lax.select(is_last_bq, 0, next_bq_idx)
|
|
783
815
|
next_seq_idx = lax.select(is_last_bq, seq_idx + 1, seq_idx)
|
|
784
816
|
next_bkv_sem_idx = lax.select(bkv_sem_idx == 0, 1, 0)
|
|
817
|
+
|
|
818
|
+
if sliding_window is None:
|
|
819
|
+
# When sliding window is disabled, starting bkv_idx of next request is
|
|
820
|
+
# always 0 regardless of seq_idx of next request.
|
|
821
|
+
next_bkv_idx_start = 0
|
|
822
|
+
else:
|
|
823
|
+
# Determine starting bkv_idx of next request based on whether next
|
|
824
|
+
# request is from the same sequence or next sequence.
|
|
825
|
+
next_bkv_idx_start = lax.select(
|
|
826
|
+
is_last_bq,
|
|
827
|
+
next_seq_bkv_idx_start,
|
|
828
|
+
bkv_idx_start,
|
|
829
|
+
)
|
|
830
|
+
next_bkv_idx = lax.select(is_last_bkv, next_bkv_idx_start,
|
|
831
|
+
next_bkv_idx)
|
|
832
|
+
|
|
785
833
|
return next_seq_idx, next_bq_idx, next_bkv_idx, next_bkv_sem_idx
|
|
786
834
|
|
|
787
835
|
def compute_with_bq(bq_idx, _):
|
|
@@ -798,10 +846,6 @@ def _ragged_paged_attention_kernel(
|
|
|
798
846
|
def compute_with_bkv(bkv_idx, _):
|
|
799
847
|
# Create bitmask for KV.
|
|
800
848
|
assert bkv_sz % kv_packing == 0
|
|
801
|
-
actual_bkv_sz = jnp.minimum(bkv_sz, kv_len - bkv_idx * bkv_sz)
|
|
802
|
-
bkv_shape = (bkv_sz, head_dim)
|
|
803
|
-
bkv_mask = lax.broadcasted_iota(jnp.int32, bkv_shape,
|
|
804
|
-
0) < actual_bkv_sz
|
|
805
849
|
|
|
806
850
|
# Get next bkv ids.
|
|
807
851
|
bkv_sem_idx = sem_ids_ref[1]
|
|
@@ -842,6 +886,11 @@ def _ragged_paged_attention_kernel(
|
|
|
842
886
|
|
|
843
887
|
# Flash attention with cur bkv and bq
|
|
844
888
|
# NOTE: kv_packing is divided by 2 because k and v are packed together.
|
|
889
|
+
prev_bq_shape_0 = None
|
|
890
|
+
prev_kv_head_bv = None
|
|
891
|
+
prev_kv_head_idx = None
|
|
892
|
+
prev_kv_head_p = None
|
|
893
|
+
prev_kv_head_exp_m_diff = None
|
|
845
894
|
heads_per_load = max(1, kv_packing // 2)
|
|
846
895
|
for kv_head_start in range(0, actual_num_kv_heads,
|
|
847
896
|
heads_per_load):
|
|
@@ -849,25 +898,56 @@ def _ragged_paged_attention_kernel(
|
|
|
849
898
|
bkv_sem_idx,
|
|
850
899
|
kv_head_start * 2,
|
|
851
900
|
num_kv_heads_x2,
|
|
852
|
-
bkv_mask=bkv_mask,
|
|
853
901
|
)
|
|
854
902
|
assert len(bkv_lst) == heads_per_load
|
|
855
903
|
for i in range(heads_per_load):
|
|
856
|
-
|
|
857
|
-
if
|
|
904
|
+
cur_kv_head_idx = kv_head_start + i
|
|
905
|
+
if cur_kv_head_idx >= actual_num_kv_heads:
|
|
858
906
|
break
|
|
859
|
-
|
|
860
|
-
|
|
861
|
-
|
|
907
|
+
|
|
908
|
+
cur_kv_head_bq = load_bq(bq_sem_idx,
|
|
909
|
+
cur_kv_head_idx,
|
|
910
|
+
actual_bq_sz=actual_bq_sz)
|
|
862
911
|
bk, bv = bkv_lst[i]
|
|
863
|
-
|
|
864
|
-
|
|
865
|
-
|
|
866
|
-
|
|
867
|
-
|
|
868
|
-
|
|
869
|
-
|
|
870
|
-
|
|
912
|
+
# FlashAttention is divided into `flash_attention_step1_qk_softmax`
|
|
913
|
+
# and `flash_attention_step2_pv` to pipeline the computation.
|
|
914
|
+
# `step2_pv` for the previous KV head, which depends on the softmax
|
|
915
|
+
# output, is overlapped with `step1_qk_softmax` for the current KV
|
|
916
|
+
# head, reducing overall wait times.
|
|
917
|
+
cur_kv_head_p, cur_kv_head_exp_m_diff = (
|
|
918
|
+
flash_attention_step1_qk_softmax(
|
|
919
|
+
cur_kv_head_bq,
|
|
920
|
+
bk,
|
|
921
|
+
bv,
|
|
922
|
+
bq_idx=bq_idx,
|
|
923
|
+
bkv_idx=bkv_idx,
|
|
924
|
+
kv_head_idx=cur_kv_head_idx,
|
|
925
|
+
))
|
|
926
|
+
if prev_bq_shape_0 is not None:
|
|
927
|
+
flash_attention_step2_pv(
|
|
928
|
+
prev_bq_shape_0,
|
|
929
|
+
prev_kv_head_bv,
|
|
930
|
+
prev_kv_head_p,
|
|
931
|
+
prev_kv_head_exp_m_diff,
|
|
932
|
+
bkv_idx=bkv_idx,
|
|
933
|
+
kv_head_idx=prev_kv_head_idx,
|
|
934
|
+
)
|
|
935
|
+
prev_bq_shape_0 = cur_kv_head_bq.shape[0]
|
|
936
|
+
prev_kv_head_bv = bv
|
|
937
|
+
prev_kv_head_p = cur_kv_head_p
|
|
938
|
+
prev_kv_head_exp_m_diff = cur_kv_head_exp_m_diff
|
|
939
|
+
prev_kv_head_idx = cur_kv_head_idx
|
|
940
|
+
|
|
941
|
+
# Execute pv of last attention head.
|
|
942
|
+
assert prev_bq_shape_0 is not None
|
|
943
|
+
flash_attention_step2_pv(
|
|
944
|
+
prev_bq_shape_0,
|
|
945
|
+
prev_kv_head_bv,
|
|
946
|
+
prev_kv_head_p,
|
|
947
|
+
prev_kv_head_exp_m_diff,
|
|
948
|
+
bkv_idx=bkv_idx,
|
|
949
|
+
kv_head_idx=prev_kv_head_idx,
|
|
950
|
+
)
|
|
871
951
|
|
|
872
952
|
lax.fori_loop(0, num_bkv, compute_with_bkv, None, unroll=False)
|
|
873
953
|
|
|
@@ -899,7 +979,17 @@ def _ragged_paged_attention_kernel(
|
|
|
899
979
|
@pl.when(seq_idx == 0)
|
|
900
980
|
def prologue():
|
|
901
981
|
start_fetch_bq(0, 0, 0)
|
|
902
|
-
|
|
982
|
+
|
|
983
|
+
# Initialize bkv_x2_ref to zeros to avoid NaN issues from accessing
|
|
984
|
+
# uninitialized memory. Bitcast into int32 to avoid tiling issues.
|
|
985
|
+
bkv_x2_int32_ref = bkv_x2_ref.bitcast(jnp.int32).reshape(
|
|
986
|
+
(2, -1, 8, 128))
|
|
987
|
+
zeros = jnp.zeros(bkv_x2_int32_ref.shape[1:], jnp.int32)
|
|
988
|
+
|
|
989
|
+
# To pipeline VST and DMA, we divide the initialization into two steps.
|
|
990
|
+
bkv_x2_int32_ref[0] = zeros
|
|
991
|
+
start_fetch_bkv(0, bkv_idx_start, 0)
|
|
992
|
+
bkv_x2_int32_ref[1] = zeros
|
|
903
993
|
|
|
904
994
|
@pl.when(seq_idx < decode_end)
|
|
905
995
|
def process_decode():
|
|
@@ -1248,6 +1338,10 @@ def static_validate_inputs(
|
|
|
1248
1338
|
del debug_mode
|
|
1249
1339
|
|
|
1250
1340
|
|
|
1341
|
+
def get_kernel_scope_name(bq_size, bkv_p, page_size):
|
|
1342
|
+
return f"RPA-bq_{bq_size}-bkvp_{bkv_p}-p_{page_size}-"
|
|
1343
|
+
|
|
1344
|
+
|
|
1251
1345
|
@functools.partial(
|
|
1252
1346
|
jax.jit,
|
|
1253
1347
|
static_argnames=(
|
|
@@ -1309,14 +1403,14 @@ def ragged_paged_attention(
|
|
|
1309
1403
|
distribution: (i, j, k) represents that sequences[0:i] are decode-only,
|
|
1310
1404
|
sequences[i:j] are chunked-prefill-only, and sequences[j:k] are mixed. The
|
|
1311
1405
|
k is also the total number of sequences.
|
|
1312
|
-
actual_head_dim: the actual head size of the attention. Here we assume k and
|
|
1313
|
-
v have the same actual head size.
|
|
1314
1406
|
sm_scale: the softmax scale which will be applied to the Q@K^T.
|
|
1315
1407
|
sliding_window: the sliding window size for the attention.
|
|
1316
1408
|
soft_cap: the logit soft cap for the attention.
|
|
1317
1409
|
mask_value: mask value for causal mask.
|
|
1410
|
+
q_scale: the scale for the query.
|
|
1318
1411
|
k_scale: the scale for the key cache.
|
|
1319
1412
|
v_scale: the scale for the value cache.
|
|
1413
|
+
chunk_prefill_size: the chunk prefill size for the attention.
|
|
1320
1414
|
num_kv_pages_per_block: number of kv pages to be processed in one flash
|
|
1321
1415
|
attention block in the pallas kernel.
|
|
1322
1416
|
num_queries_per_block: number of queries to be processed in one flash
|
|
@@ -1383,6 +1477,7 @@ def ragged_paged_attention(
|
|
|
1383
1477
|
page_size,
|
|
1384
1478
|
max_num_tokens,
|
|
1385
1479
|
pages_per_seq,
|
|
1480
|
+
sliding_window,
|
|
1386
1481
|
)
|
|
1387
1482
|
bkv_sz = bkv_p * page_size
|
|
1388
1483
|
if vmem_limit_bytes is None:
|
|
@@ -1451,47 +1546,45 @@ def ragged_paged_attention(
|
|
|
1451
1546
|
jnp.full((6, ), -1, jnp.int32),
|
|
1452
1547
|
)
|
|
1453
1548
|
|
|
1454
|
-
scope_name =
|
|
1455
|
-
kernel =
|
|
1456
|
-
|
|
1457
|
-
|
|
1458
|
-
|
|
1459
|
-
|
|
1460
|
-
|
|
1461
|
-
|
|
1462
|
-
|
|
1463
|
-
|
|
1464
|
-
|
|
1465
|
-
|
|
1466
|
-
|
|
1467
|
-
|
|
1468
|
-
|
|
1469
|
-
|
|
1470
|
-
|
|
1471
|
-
|
|
1472
|
-
|
|
1473
|
-
|
|
1474
|
-
|
|
1475
|
-
|
|
1476
|
-
|
|
1477
|
-
|
|
1478
|
-
|
|
1479
|
-
|
|
1480
|
-
|
|
1481
|
-
|
|
1482
|
-
|
|
1483
|
-
|
|
1484
|
-
|
|
1485
|
-
|
|
1486
|
-
|
|
1487
|
-
|
|
1488
|
-
|
|
1489
|
-
|
|
1490
|
-
|
|
1491
|
-
|
|
1492
|
-
|
|
1493
|
-
name=scope_name,
|
|
1494
|
-
))
|
|
1549
|
+
scope_name = get_kernel_scope_name(bq_sz, bkv_p, page_size)
|
|
1550
|
+
kernel = pl.pallas_call(
|
|
1551
|
+
functools.partial(
|
|
1552
|
+
_ragged_paged_attention_kernel,
|
|
1553
|
+
sm_scale=sm_scale,
|
|
1554
|
+
sliding_window=sliding_window,
|
|
1555
|
+
soft_cap=soft_cap,
|
|
1556
|
+
mask_value=mask_value,
|
|
1557
|
+
q_scale=q_scale,
|
|
1558
|
+
k_scale=k_scale,
|
|
1559
|
+
v_scale=v_scale,
|
|
1560
|
+
chunk_prefill_size=chunk_prefill_size,
|
|
1561
|
+
bq_sz=bq_sz,
|
|
1562
|
+
bkv_p=bkv_p,
|
|
1563
|
+
debug_mode=debug_mode,
|
|
1564
|
+
),
|
|
1565
|
+
grid_spec=pltpu.PrefetchScalarGridSpec(
|
|
1566
|
+
num_scalar_prefetch=len(scalar_prefetches),
|
|
1567
|
+
in_specs=in_specs,
|
|
1568
|
+
out_specs=out_specs,
|
|
1569
|
+
grid=grid,
|
|
1570
|
+
scratch_shapes=scratch_shapes,
|
|
1571
|
+
),
|
|
1572
|
+
compiler_params=pltpu.CompilerParams(
|
|
1573
|
+
# TODO(jevinjiang): since each sequence depends on the previous
|
|
1574
|
+
# one, we need some extra work to support Megacore mode.
|
|
1575
|
+
dimension_semantics=("arbitrary", ),
|
|
1576
|
+
vmem_limit_bytes=vmem_limit_bytes,
|
|
1577
|
+
),
|
|
1578
|
+
out_shape=[
|
|
1579
|
+
jax.ShapeDtypeStruct(shape=q.shape, dtype=q.dtype),
|
|
1580
|
+
jax.ShapeDtypeStruct(shape=kv_cache.shape, dtype=kv_cache.dtype),
|
|
1581
|
+
],
|
|
1582
|
+
input_output_aliases={
|
|
1583
|
+
7: 0,
|
|
1584
|
+
9: 1
|
|
1585
|
+
},
|
|
1586
|
+
name=scope_name,
|
|
1587
|
+
)
|
|
1495
1588
|
|
|
1496
1589
|
output, updated_kv_cache = kernel(*scalar_prefetches, q, kv, kv_cache)
|
|
1497
1590
|
return (
|