tpu-inference 0.12.0.dev20251213__py3-none-any.whl → 0.13.2.dev20251230__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of tpu-inference might be problematic.
- tests/__init__.py +13 -0
- tests/core/__init__.py +13 -0
- tests/core/test_disagg_utils.py +14 -0
- tests/core/test_dp_scheduler.py +650 -768
- tests/core/test_init.py +14 -0
- tests/distributed/__init__.py +13 -0
- tests/distributed/test_distributed_utils.py +120 -0
- tests/distributed/test_tpu_connector.py +478 -0
- tests/e2e/__init__.py +13 -0
- tests/e2e/test_async_scheduler.py +211 -0
- tests/e2e/test_data_parallel.py +289 -0
- tests/e2e/test_hybrid_kvcache.py +219 -0
- tests/e2e/test_local_disagg.py +257 -0
- tests/e2e/test_model_loader.py +268 -0
- tests/e2e/test_multi_modal_inference.py +111 -0
- tests/e2e/test_pipeline_parallel.py +265 -0
- tests/e2e/test_runai_model_streamer_loader.py +104 -0
- tests/e2e/test_sampling_params.py +269 -0
- tests/e2e/test_speculative_decoding.py +311 -0
- tests/e2e/test_structured_decoding.py +46 -0
- tests/executors/__init__.py +13 -0
- tests/executors/test_ray_distributed_executor.py +199 -0
- tests/experimental/__init__.py +13 -0
- tests/experimental/test_llama3_jax_stashed.py +208 -0
- tests/kernels/__init__.py +13 -0
- tests/kernels/collectives/__init__.py +13 -0
- tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
- tests/kernels/fused_moe_v1_test.py +14 -0
- tests/kernels/gmm_test.py +205 -0
- tests/kernels/mla_v1_test.py +14 -0
- tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +14 -0
- tests/layers/__init__.py +13 -0
- tests/layers/common/__init__.py +13 -0
- tests/layers/common/test_attention_interface.py +156 -0
- tests/layers/common/test_quantization.py +149 -0
- tests/layers/jax/__init__.py +13 -0
- tests/layers/jax/attention/__init__.py +13 -0
- tests/layers/jax/attention/test_common_attention.py +103 -0
- tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
- tests/layers/jax/attention/test_llama4_attention.py +135 -0
- tests/layers/jax/moe/__init__.py +13 -0
- tests/layers/jax/moe/test_deepseek_moe.py +235 -0
- tests/layers/jax/sample/__init__.py +13 -0
- tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
- tests/layers/jax/sample/test_sampling.py +115 -0
- tests/layers/jax/sample/test_sampling_metadata.py +254 -0
- tests/layers/jax/test_layers.py +155 -0
- tests/{test_quantization.py → layers/jax/test_qwix.py} +180 -50
- tests/layers/jax/test_rope.py +93 -0
- tests/layers/jax/test_sharding.py +159 -0
- tests/layers/jax/test_transformer_block.py +152 -0
- tests/layers/vllm/__init__.py +13 -0
- tests/layers/vllm/test_attention.py +363 -0
- tests/layers/vllm/test_awq.py +406 -0
- tests/layers/vllm/test_compressed_tensors_moe.py +199 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +441 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +443 -0
- tests/layers/vllm/test_fp8.py +17 -0
- tests/layers/vllm/test_mxfp4.py +320 -0
- tests/layers/vllm/test_unquantized.py +662 -0
- tests/layers/vllm/utils.py +87 -0
- tests/lora/__init__.py +13 -0
- tests/lora/conftest.py +14 -0
- tests/lora/test_bgmv.py +14 -0
- tests/lora/test_layers.py +25 -8
- tests/lora/test_lora.py +15 -1
- tests/lora/test_lora_perf.py +14 -0
- tests/models/__init__.py +13 -0
- tests/models/common/__init__.py +13 -0
- tests/models/common/test_model_loader.py +455 -0
- tests/models/jax/__init__.py +13 -0
- tests/models/jax/test_deepseek_v3.py +401 -0
- tests/models/jax/test_llama3.py +184 -0
- tests/models/jax/test_llama4.py +298 -0
- tests/models/jax/test_llama_eagle3.py +197 -0
- tests/models/jax/test_llama_guard_4.py +242 -0
- tests/models/jax/test_qwen2.py +172 -0
- tests/models/jax/test_qwen2_5_vl.py +605 -0
- tests/models/jax/test_qwen3.py +169 -0
- tests/models/jax/test_weight_loading.py +180 -0
- tests/models/jax/utils/__init__.py +13 -0
- tests/models/jax/utils/test_multi_modal_utils.py +212 -0
- tests/platforms/__init__.py +13 -0
- tests/platforms/test_tpu_platform.py +54 -0
- tests/runner/__init__.py +13 -0
- tests/runner/test_block_table.py +395 -0
- tests/runner/test_input_batch.py +226 -0
- tests/runner/test_kv_cache.py +220 -0
- tests/runner/test_kv_cache_manager.py +498 -0
- tests/runner/test_multimodal_manager.py +429 -0
- tests/runner/test_persistent_batch_manager.py +84 -0
- tests/runner/test_speculative_decoding_manager.py +368 -0
- tests/runner/test_structured_decoding_manager.py +220 -0
- tests/runner/test_tpu_runner.py +261 -0
- tests/runner/test_tpu_runner_dp.py +1099 -0
- tests/runner/test_tpu_runner_mesh.py +200 -0
- tests/runner/test_utils.py +411 -0
- tests/spec_decode/__init__.py +13 -0
- tests/spec_decode/test_eagle3.py +311 -0
- tests/test_base.py +14 -0
- tests/test_tpu_info.py +14 -0
- tests/test_utils.py +1 -43
- tests/worker/__init__.py +13 -0
- tests/worker/tpu_worker_test.py +414 -0
- tpu_inference/__init__.py +14 -0
- tpu_inference/core/__init__.py +13 -0
- tpu_inference/core/sched/__init__.py +13 -0
- tpu_inference/core/sched/dp_scheduler.py +372 -56
- tpu_inference/distributed/__init__.py +13 -0
- tpu_inference/distributed/jax_parallel_state.py +14 -0
- tpu_inference/distributed/tpu_connector.py +14 -9
- tpu_inference/distributed/utils.py +56 -4
- tpu_inference/executors/__init__.py +13 -0
- tpu_inference/executors/ray_distributed_executor.py +20 -3
- tpu_inference/experimental/__init__.py +13 -0
- tpu_inference/experimental/llama3_jax_stashed.py +14 -0
- tpu_inference/kernels/__init__.py +13 -0
- tpu_inference/kernels/collectives/__init__.py +13 -0
- tpu_inference/kernels/flash_attention/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/kernel.py +171 -163
- tpu_inference/kernels/megablox/__init__.py +13 -0
- tpu_inference/kernels/megablox/common.py +54 -0
- tpu_inference/kernels/megablox/gmm.py +646 -0
- tpu_inference/kernels/mla/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/kernel.py +20 -26
- tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +112 -69
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +85 -65
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +374 -194
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +13 -0
- tpu_inference/layers/__init__.py +13 -0
- tpu_inference/layers/common/__init__.py +13 -0
- tpu_inference/layers/common/attention_interface.py +26 -19
- tpu_inference/layers/common/attention_metadata.py +14 -0
- tpu_inference/layers/common/fused_moe_gmm.py +506 -0
- tpu_inference/layers/common/quant_methods.py +15 -0
- tpu_inference/layers/common/quantization.py +282 -0
- tpu_inference/layers/common/sharding.py +22 -3
- tpu_inference/layers/common/utils.py +94 -0
- tpu_inference/layers/jax/__init__.py +13 -0
- tpu_inference/layers/jax/attention/__init__.py +13 -0
- tpu_inference/layers/jax/attention/attention.py +19 -6
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +52 -27
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +19 -6
- tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
- tpu_inference/layers/jax/base.py +14 -0
- tpu_inference/layers/jax/constants.py +13 -0
- tpu_inference/layers/jax/layers.py +14 -0
- tpu_inference/layers/jax/misc.py +14 -0
- tpu_inference/layers/jax/moe/__init__.py +13 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
- tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
- tpu_inference/layers/jax/moe/moe.py +43 -3
- tpu_inference/layers/jax/pp_utils.py +53 -0
- tpu_inference/layers/jax/rope.py +14 -0
- tpu_inference/layers/jax/rope_interface.py +14 -0
- tpu_inference/layers/jax/sample/__init__.py +13 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
- tpu_inference/layers/jax/sample/sampling.py +15 -1
- tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
- tpu_inference/layers/jax/transformer_block.py +14 -0
- tpu_inference/layers/vllm/__init__.py +13 -0
- tpu_inference/layers/vllm/attention.py +4 -4
- tpu_inference/layers/vllm/fused_moe.py +100 -455
- tpu_inference/layers/vllm/linear.py +64 -0
- tpu_inference/layers/vllm/process_weights/__init__.py +13 -0
- tpu_inference/layers/vllm/{sharding.py → process_weights/cleanup_sharding.py} +24 -15
- tpu_inference/layers/vllm/process_weights/fused_moe_weights.py +369 -0
- tpu_inference/layers/vllm/process_weights/linear_weights.py +174 -0
- tpu_inference/layers/vllm/quantization/__init__.py +19 -3
- tpu_inference/layers/vllm/quantization/awq.py +96 -82
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +19 -5
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +119 -132
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +111 -91
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +79 -43
- tpu_inference/layers/vllm/quantization/{common.py → configs.py} +38 -26
- tpu_inference/layers/vllm/quantization/fp8.py +119 -0
- tpu_inference/layers/vllm/quantization/mxfp4.py +133 -220
- tpu_inference/layers/vllm/quantization/unquantized.py +154 -253
- tpu_inference/lora/__init__.py +13 -0
- tpu_inference/lora/torch_lora_ops.py +8 -13
- tpu_inference/models/__init__.py +13 -0
- tpu_inference/models/common/__init__.py +13 -0
- tpu_inference/models/common/model_loader.py +37 -16
- tpu_inference/models/jax/__init__.py +13 -0
- tpu_inference/models/jax/deepseek_v3.py +113 -124
- tpu_inference/models/jax/gpt_oss.py +23 -7
- tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
- tpu_inference/models/jax/llama3.py +99 -36
- tpu_inference/models/jax/llama4.py +14 -0
- tpu_inference/models/jax/llama_eagle3.py +14 -0
- tpu_inference/models/jax/llama_guard_4.py +15 -1
- tpu_inference/models/jax/qwen2.py +17 -2
- tpu_inference/models/jax/qwen2_5_vl.py +18 -4
- tpu_inference/models/jax/qwen3.py +17 -2
- tpu_inference/models/jax/utils/__init__.py +13 -0
- tpu_inference/models/jax/utils/file_utils.py +14 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
- tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
- tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +85 -24
- tpu_inference/models/jax/utils/weight_utils.py +32 -1
- tpu_inference/models/vllm/__init__.py +13 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +22 -4
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
- tpu_inference/platforms/__init__.py +14 -0
- tpu_inference/platforms/tpu_platform.py +27 -29
- tpu_inference/runner/__init__.py +13 -0
- tpu_inference/runner/compilation_manager.py +69 -35
- tpu_inference/runner/kv_cache.py +14 -0
- tpu_inference/runner/kv_cache_manager.py +15 -2
- tpu_inference/runner/lora_utils.py +16 -1
- tpu_inference/runner/multimodal_manager.py +16 -2
- tpu_inference/runner/persistent_batch_manager.py +14 -0
- tpu_inference/runner/speculative_decoding_manager.py +14 -0
- tpu_inference/runner/structured_decoding_manager.py +14 -0
- tpu_inference/runner/tpu_runner.py +30 -10
- tpu_inference/spec_decode/__init__.py +13 -0
- tpu_inference/spec_decode/jax/__init__.py +13 -0
- tpu_inference/spec_decode/jax/eagle3.py +13 -0
- tpu_inference/tpu_info.py +14 -0
- tpu_inference/utils.py +31 -30
- tpu_inference/worker/__init__.py +13 -0
- tpu_inference/worker/tpu_worker.py +23 -7
- {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/METADATA +1 -1
- tpu_inference-0.13.2.dev20251230.dist-info/RECORD +266 -0
- tpu_inference/layers/vllm/linear_common.py +0 -208
- tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
- tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
- tpu_inference-0.12.0.dev20251213.dist-info/RECORD +0 -175
- {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/WHEEL +0 -0
- {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/top_level.txt +0 -0

tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py (+85 -65)

@@ -1,3 +1,16 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 """
 A variant of TPU-Friendly Ragged Paged Attention kernel optimized for
 head_dim = 64.
@@ -267,7 +280,6 @@ def _ragged_paged_attention_kernel(
     *,
     sm_scale: float,
     sliding_window: int | None = None,
-    strict_sliding_window: bool = True,
     soft_cap: float | None = None,
     mask_value: float = DEFAULT_MASK_VALUE,
     q_scale: float | None = None,
@@ -324,14 +336,15 @@ def _ragged_paged_attention_kernel(
     bkv_idx_start = jnp.maximum(kv_len - q_len - sliding_window,
                                 0) // bkv_sz

-
-
-
-
-
-
-    next_seq_bkv_idx_start =
-
+    # If seq_idx + 1 == num_seqs, kv_lens_ref[seq_idx + 1] will trigger a
+    # out-of-bound error. To avoid this, we set upperbound of next_seq_idx
+    # to be num_seqs - 1.
+    next_seq_idx = jnp.minimum(seq_idx + 1, num_seqs - 1)
+    next_kv_len = kv_lens_ref[next_seq_idx]
+    next_q_len = cu_q_lens_ref[next_seq_idx + 1] - q_end
+    next_seq_bkv_idx_start = (
+        jnp.maximum(next_kv_len - next_q_len - sliding_window, 0) //
+        bkv_sz)

   def debug_print(msg, *args):
     if debug_mode:
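The added code above clamps the lookahead index so that prefetching metadata for the "next" sequence never reads past the end of kv_lens_ref, then derives the first KV block that can still fall inside the sliding window. As a rough standalone illustration of the same arithmetic (plain jax.numpy with made-up lengths and block sizes, not the kernel's scalar-prefetch refs):

import jax.numpy as jnp

kv_lens = jnp.array([100, 37, 512])      # hypothetical per-sequence KV lengths
cu_q_lens = jnp.array([0, 4, 5, 133])    # hypothetical cumulative query lengths
num_seqs = kv_lens.shape[0]
sliding_window, bkv_sz = 128, 64         # window size and KV block size

seq_idx = 2                              # last sequence in the batch
q_end = cu_q_lens[seq_idx + 1]

# Clamp so kv_lens[seq_idx + 1] is never read out of bounds on the last sequence.
next_seq_idx = jnp.minimum(seq_idx + 1, num_seqs - 1)
next_kv_len = kv_lens[next_seq_idx]
next_q_len = cu_q_lens[next_seq_idx + 1] - q_end
# First KV block index that can still fall inside the sliding window.
next_seq_bkv_idx_start = jnp.maximum(next_kv_len - next_q_len - sliding_window, 0) // bkv_sz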
@@ -396,7 +409,7 @@ def _ragged_paged_attention_kernel(
     k_span = bkv_idx * bkv_sz + lax.broadcasted_iota(jnp.int32, s.shape, 1)
     mask = k_span <= q_span

-    if sliding_window is not None
+    if sliding_window is not None:
       mask = jnp.logical_and(mask, q_span - sliding_window < k_span)

     s = jnp.where(mask, s, mask_value)
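The masking logic touched above combines the causal constraint (k_span <= q_span) with a sliding-window lower bound. A small self-contained sketch of the same masking math, using toy sizes rather than the kernel's block layout:

import jax.numpy as jnp
from jax import lax

q_len, kv_len, window = 4, 6, 3
# Absolute positions: the queries are the last q_len tokens of the sequence.
q_span = (kv_len - q_len) + lax.broadcasted_iota(jnp.int32, (q_len, kv_len), 0)
k_span = lax.broadcasted_iota(jnp.int32, (q_len, kv_len), 1)

mask = k_span <= q_span                                  # causal
mask = jnp.logical_and(mask, q_span - window < k_span)   # sliding window
print(mask.astype(jnp.int32))
# Row i attends to keys in (q_pos - window, q_pos]; e.g. row 0 (q_pos = 2) keeps keys 0..2.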
@@ -723,7 +736,7 @@ def _ragged_paged_attention_kernel(
       vec = ref[start::step]
       return vec

-  def strided_load_bkv(bkv_sem_idx, start, step
+  def strided_load_bkv(bkv_sem_idx, start, step):
     assert start % kv_packing == 0
     assert step % kv_packing == 0
     start //= kv_packing
@@ -732,7 +745,6 @@ def _ragged_paged_attention_kernel(
                           bkv_sz * step, actual_head_dim_x2))

     kv = strided_load(kv_ref, start, step)
-    kv = lax.select(bkv_mask, kv, jnp.zeros_like(kv))
     bitwidth = 32 // kv_packing
     repack_ty = jnp.dtype(f"uint{bitwidth}")
     lst = []
@@ -780,14 +792,18 @@ def _ragged_paged_attention_kernel(
     next_bkv_sem_idx = lax.select(bkv_sem_idx == 0, 1, 0)

     if sliding_window is None:
-
+      # When sliding window is disabled, starting bkv_idx of next request is
+      # always 0 regardless of seq_idx of next request.
+      next_bkv_idx_start = 0
     else:
-
+      # Determine starting bkv_idx of next request based on whether next
+      # request is from the same sequence or next sequence.
+      next_bkv_idx_start = lax.select(
           is_last_bq,
           next_seq_bkv_idx_start,
           bkv_idx_start,
       )
-    next_bkv_idx = lax.select(is_last_bkv,
+    next_bkv_idx = lax.select(is_last_bkv, next_bkv_idx_start,
                               next_bkv_idx)

     return next_seq_idx, next_bq_idx, next_bkv_idx, next_bkv_sem_idx
@@ -806,10 +822,6 @@ def _ragged_paged_attention_kernel(
   def compute_with_bkv(bkv_idx, _):
     # Create bitmask for KV.
     assert bkv_sz % kv_packing == 0
-    actual_bkv_sz = jnp.minimum(bkv_sz, kv_len - bkv_idx * bkv_sz)
-    bkv_shape = (bkv_sz, actual_head_dim_x2)
-    bkv_mask = lax.broadcasted_iota(jnp.int32, bkv_shape,
-                                    0) < actual_bkv_sz

     # Get next bkv ids.
     bkv_sem_idx = sem_ids_ref[1]
@@ -859,7 +871,6 @@ def _ragged_paged_attention_kernel(
         bkv_sem_idx,
         kv_head_start,
         num_kv_heads,
-        bkv_mask=bkv_mask,
     )
     assert len(bkv_lst) == kv_packing
     for i in range(kv_packing):
@@ -943,7 +954,17 @@ def _ragged_paged_attention_kernel(
   @pl.when(seq_idx == 0)
   def prologue():
     start_fetch_bq(0, 0, 0)
+
+    # Initialize bkv_x2_ref to zeros to avoid NaN issues from accessing
+    # uninitialized memory. Bitcast into int32 to avoid tiling issues.
+    bkv_x2_int32_ref = bkv_x2_ref.bitcast(jnp.int32).reshape(
+        (2, -1, 8, 128))
+    zeros = jnp.zeros(bkv_x2_int32_ref.shape[1:], jnp.int32)
+
+    # To pipeline VST and DMA, we divide the initialization into two steps.
+    bkv_x2_int32_ref[0] = zeros
     start_fetch_bkv(0, bkv_idx_start, 0)
+    bkv_x2_int32_ref[1] = zeros

   @pl.when(seq_idx < decode_end)
   def process_decode():
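The prologue change zero-fills the double-buffered KV scratch through an int32 view, splitting the work around the first KV fetch so the vector stores overlap with the DMA, and ensuring the kernel never reads uninitialized memory that could contain NaNs. The fact it relies on (an all-zero int32 bit pattern also decodes to 0.0 in a 16-bit KV dtype such as bfloat16) can be checked with a tiny standalone snippet using plain lax.bitcast_convert_type, not the Pallas ref bitcast used in the kernel:

import jax.numpy as jnp
from jax import lax

# Write zeros as 32-bit words, then reinterpret them as pairs of bfloat16 values.
words = jnp.zeros((4,), jnp.int32)
as_bf16 = lax.bitcast_convert_type(words, jnp.bfloat16)   # shape (4, 2)
assert as_bf16.shape == (4, 2)
assert bool(jnp.all(as_bf16 == 0.0))   # all-zero bits are 0.0 in bfloat16 as well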
@@ -1303,12 +1324,15 @@ def static_validate_inputs(
   del debug_mode


+def get_kernel_scope_name(bq_size, bkv_p, page_size):
+  return f"RPA-HD_64-bq_{bq_size}-bkvp_{bkv_p}-p_{page_size}-"
+
+
 @functools.partial(
     jax.jit,
     static_argnames=(
         "sm_scale",
         "sliding_window",
-        "strict_sliding_window",
         "soft_cap",
         "mask_value",
         "q_scale",
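The new get_kernel_scope_name helper simply bakes the block sizes into the kernel name later passed to pallas_call (useful when reading profiles and traces). A usage example, assuming the helper defined above and arbitrarily chosen values:

get_kernel_scope_name(bq_size=32, bkv_p=8, page_size=64)
# -> 'RPA-HD_64-bq_32-bkvp_8-p_64-'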
@@ -1338,7 +1362,6 @@ def ragged_paged_attention_hd64(
     *,
     sm_scale: float = 1.0,
     sliding_window: int | None = None,
-    strict_sliding_window: bool = True,
     soft_cap: float | None = None,
     mask_value: float | None = DEFAULT_MASK_VALUE,
     q_scale: float | None = None,
@@ -1370,7 +1393,6 @@ def ragged_paged_attention_hd64(
     attention_sink: optional attention sink for each q head.
     sm_scale: the softmax scale which will be applied to the Q@K^T.
     sliding_window: the sliding window size for the attention.
-    strict_sliding_window: compute tokens that are strictly within the window.
     soft_cap: the logit soft cap for the attention.
     mask_value: mask value for causal mask.
     q_scale: the scale for the query.
@@ -1444,6 +1466,7 @@ def ragged_paged_attention_hd64(
       page_size,
       max_num_tokens,
       pages_per_seq,
+      sliding_window,
   )
   bkv_sz = bkv_p * page_size
   if vmem_limit_bytes is None:
@@ -1514,48 +1537,45 @@ def ragged_paged_attention_hd64(
       jnp.full((6, ), -1, jnp.int32),
   )

-  scope_name =
-  kernel =
-  [remaining removed lines not captured in this rendering]
+  scope_name = get_kernel_scope_name(bq_sz, bkv_p, page_size)
+  kernel = pl.pallas_call(
+      functools.partial(
+          _ragged_paged_attention_kernel,
+          sm_scale=sm_scale,
+          sliding_window=sliding_window,
+          soft_cap=soft_cap,
+          mask_value=mask_value,
+          q_scale=q_scale,
+          k_scale=k_scale,
+          v_scale=v_scale,
+          chunk_prefill_size=chunk_prefill_size,
+          bq_sz=bq_sz,
+          bkv_p=bkv_p,
+          debug_mode=debug_mode,
+      ),
+      grid_spec=pltpu.PrefetchScalarGridSpec(
+          num_scalar_prefetch=len(scalar_prefetches),
+          in_specs=in_specs,
+          out_specs=out_specs,
+          grid=grid,
+          scratch_shapes=scratch_shapes,
+      ),
+      compiler_params=pltpu.CompilerParams(
+          # TODO(jevinjiang): since each sequence depends on the previous
+          # one, we need some extra work to support Megacore mode.
+          dimension_semantics=("arbitrary", ),
+          vmem_limit_bytes=vmem_limit_bytes,
+      ),
+      out_shape=[
+          jax.ShapeDtypeStruct(shape=q.shape, dtype=q.dtype),
+          jax.ShapeDtypeStruct(shape=kv_cache.shape, dtype=kv_cache.dtype),
+      ],
+      input_output_aliases={
+          7: 0,
+          9: 1
+      },
+      name=scope_name,
+  )

   output, updated_kv_cache = kernel(*scalar_prefetches, q, kv, kv_cache,
                                     attention_sink)
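The rebuilt pl.pallas_call above aliases two operands onto its two outputs (input_output_aliases={7: 0, 9: 1}), so one output reuses an input buffer and the KV cache is updated in place rather than copied. A minimal, self-contained Pallas sketch of the same aliasing mechanism, using a toy kernel unrelated to attention (interpret=True so it also runs off-TPU):

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def _inc_kernel(x_ref, o_ref):
    # o_ref shares the donated buffer of x_ref, so this writes back in place.
    o_ref[...] = x_ref[...] + 1.0

x = jnp.zeros((8, 128), jnp.float32)
inc = pl.pallas_call(
    _inc_kernel,
    out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    input_output_aliases={0: 0},   # alias input 0 with output 0
    interpret=True,
)
print(inc(x)[0, :4])   # [1. 1. 1. 1.]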