tpu-inference 0.11.1.dev202512030818__py3-none-any.whl → 0.13.2rc3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of tpu-inference might be problematic.
- tests/__init__.py +13 -0
- tests/core/__init__.py +13 -0
- tests/core/test_disagg_utils.py +14 -0
- tests/core/test_dp_scheduler.py +650 -768
- tests/core/test_init.py +14 -0
- tests/distributed/__init__.py +13 -0
- tests/distributed/test_distributed_utils.py +120 -0
- tests/distributed/test_tpu_connector.py +478 -0
- tests/e2e/__init__.py +13 -0
- tests/e2e/test_async_scheduler.py +211 -0
- tests/e2e/test_data_parallel.py +289 -0
- tests/e2e/test_hybrid_kvcache.py +219 -0
- tests/e2e/test_local_disagg.py +257 -0
- tests/e2e/test_model_loader.py +268 -0
- tests/e2e/test_multi_modal_inference.py +111 -0
- tests/e2e/test_pipeline_parallel.py +265 -0
- tests/e2e/test_runai_model_streamer_loader.py +104 -0
- tests/e2e/test_sampling_params.py +269 -0
- tests/e2e/test_speculative_decoding.py +311 -0
- tests/e2e/test_structured_decoding.py +46 -0
- tests/executors/__init__.py +13 -0
- tests/executors/test_ray_distributed_executor.py +199 -0
- tests/experimental/__init__.py +13 -0
- tests/experimental/test_llama3_jax_stashed.py +208 -0
- tests/kernels/__init__.py +13 -0
- tests/kernels/collectives/__init__.py +13 -0
- tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
- tests/kernels/fused_moe_v1_test.py +14 -0
- tests/kernels/gmm_test.py +205 -0
- tests/kernels/mla_v1_test.py +143 -41
- tests/kernels/quantized_matmul_kernel_test.py +2 -34
- tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +17 -1
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +17 -1
- tests/layers/__init__.py +13 -0
- tests/layers/common/__init__.py +13 -0
- tests/layers/common/test_attention_interface.py +156 -0
- tests/layers/common/test_quantization.py +149 -0
- tests/layers/jax/__init__.py +13 -0
- tests/layers/jax/attention/__init__.py +13 -0
- tests/layers/jax/attention/test_common_attention.py +103 -0
- tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
- tests/layers/jax/attention/test_llama4_attention.py +135 -0
- tests/layers/jax/moe/__init__.py +13 -0
- tests/layers/jax/moe/test_deepseek_moe.py +235 -0
- tests/layers/jax/sample/__init__.py +13 -0
- tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
- tests/layers/jax/sample/test_sampling.py +115 -0
- tests/layers/jax/sample/test_sampling_metadata.py +254 -0
- tests/layers/jax/test_layers.py +155 -0
- tests/{test_quantization.py → layers/jax/test_qwix.py} +183 -50
- tests/layers/jax/test_rope.py +93 -0
- tests/layers/jax/test_sharding.py +159 -0
- tests/layers/jax/test_transformer_block.py +152 -0
- tests/layers/vllm/__init__.py +13 -0
- tests/layers/vllm/test_attention.py +363 -0
- tests/layers/vllm/test_awq.py +405 -0
- tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +418 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +441 -0
- tests/layers/vllm/test_fp8.py +17 -0
- tests/layers/vllm/test_mxfp4.py +312 -0
- tests/layers/vllm/test_unquantized.py +651 -0
- tests/layers/vllm/utils.py +87 -0
- tests/lora/__init__.py +13 -0
- tests/lora/conftest.py +14 -0
- tests/lora/test_bgmv.py +14 -0
- tests/lora/test_layers.py +21 -3
- tests/lora/test_lora.py +15 -1
- tests/lora/test_lora_perf.py +67 -0
- tests/models/__init__.py +13 -0
- tests/models/common/__init__.py +13 -0
- tests/models/common/test_model_loader.py +455 -0
- tests/models/jax/__init__.py +13 -0
- tests/models/jax/test_deepseek_v3.py +401 -0
- tests/models/jax/test_llama3.py +184 -0
- tests/models/jax/test_llama4.py +298 -0
- tests/models/jax/test_llama_eagle3.py +197 -0
- tests/models/jax/test_llama_guard_4.py +242 -0
- tests/models/jax/test_qwen2.py +172 -0
- tests/models/jax/test_qwen2_5_vl.py +605 -0
- tests/models/jax/test_qwen3.py +169 -0
- tests/models/jax/test_weight_loading.py +180 -0
- tests/models/jax/utils/__init__.py +13 -0
- tests/models/jax/utils/test_multi_modal_utils.py +212 -0
- tests/platforms/__init__.py +13 -0
- tests/platforms/test_tpu_platform.py +54 -0
- tests/runner/__init__.py +13 -0
- tests/runner/test_block_table.py +395 -0
- tests/runner/test_input_batch.py +226 -0
- tests/runner/test_kv_cache.py +220 -0
- tests/runner/test_kv_cache_manager.py +498 -0
- tests/runner/test_multimodal_manager.py +429 -0
- tests/runner/test_persistent_batch_manager.py +84 -0
- tests/runner/test_speculative_decoding_manager.py +368 -0
- tests/runner/test_structured_decoding_manager.py +220 -0
- tests/runner/test_tpu_runner.py +261 -0
- tests/runner/test_tpu_runner_dp.py +1099 -0
- tests/runner/test_tpu_runner_mesh.py +200 -0
- tests/runner/test_utils.py +411 -0
- tests/spec_decode/__init__.py +13 -0
- tests/spec_decode/test_eagle3.py +311 -0
- tests/test_base.py +14 -0
- tests/test_envs.py +78 -1
- tests/test_tpu_info.py +14 -0
- tests/test_utils.py +1 -43
- tests/worker/__init__.py +13 -0
- tests/worker/tpu_worker_test.py +414 -0
- tpu_inference/__init__.py +14 -0
- tpu_inference/core/__init__.py +13 -0
- tpu_inference/core/sched/__init__.py +13 -0
- tpu_inference/core/sched/dp_scheduler.py +372 -56
- tpu_inference/distributed/__init__.py +13 -0
- tpu_inference/distributed/jax_parallel_state.py +14 -0
- tpu_inference/distributed/tpu_connector.py +14 -9
- tpu_inference/distributed/utils.py +56 -4
- tpu_inference/envs.py +38 -7
- tpu_inference/executors/__init__.py +13 -0
- tpu_inference/executors/ray_distributed_executor.py +17 -0
- tpu_inference/experimental/__init__.py +13 -0
- tpu_inference/experimental/llama3_jax_stashed.py +14 -0
- tpu_inference/kernels/__init__.py +13 -0
- tpu_inference/kernels/collectives/__init__.py +13 -0
- tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
- tpu_inference/kernels/flash_attention/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/kernel.py +370 -324
- tpu_inference/kernels/megablox/__init__.py +13 -0
- tpu_inference/kernels/megablox/common.py +54 -0
- tpu_inference/kernels/megablox/gmm.py +646 -0
- tpu_inference/kernels/mla/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/kernel.py +117 -145
- tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
- tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
- tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +194 -101
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +95 -78
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +376 -195
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +15 -1
- tpu_inference/layers/__init__.py +13 -0
- tpu_inference/layers/common/__init__.py +13 -0
- tpu_inference/layers/common/attention_interface.py +26 -19
- tpu_inference/layers/common/attention_metadata.py +14 -0
- tpu_inference/layers/common/quant_methods.py +15 -0
- tpu_inference/layers/common/quantization.py +270 -0
- tpu_inference/layers/common/sharding.py +28 -5
- tpu_inference/layers/jax/__init__.py +13 -0
- tpu_inference/layers/jax/attention/__init__.py +13 -0
- tpu_inference/layers/jax/attention/attention.py +19 -6
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +270 -77
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +24 -11
- tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
- tpu_inference/layers/jax/base.py +14 -0
- tpu_inference/layers/jax/constants.py +13 -0
- tpu_inference/layers/jax/layers.py +14 -0
- tpu_inference/layers/jax/misc.py +14 -0
- tpu_inference/layers/jax/moe/__init__.py +13 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
- tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
- tpu_inference/layers/jax/moe/moe.py +43 -3
- tpu_inference/layers/jax/pp_utils.py +53 -0
- tpu_inference/layers/jax/rope.py +14 -0
- tpu_inference/layers/jax/rope_interface.py +14 -0
- tpu_inference/layers/jax/sample/__init__.py +13 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
- tpu_inference/layers/jax/sample/sampling.py +15 -1
- tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
- tpu_inference/layers/jax/transformer_block.py +14 -0
- tpu_inference/layers/vllm/__init__.py +13 -0
- tpu_inference/layers/vllm/attention.py +4 -4
- tpu_inference/layers/vllm/fused_moe.py +210 -260
- tpu_inference/layers/vllm/linear_common.py +57 -22
- tpu_inference/layers/vllm/quantization/__init__.py +16 -0
- tpu_inference/layers/vllm/quantization/awq.py +15 -1
- tpu_inference/layers/vllm/quantization/common.py +33 -18
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +18 -3
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +211 -148
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +14 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +14 -0
- tpu_inference/layers/vllm/quantization/fp8.py +118 -0
- tpu_inference/layers/vllm/quantization/mxfp4.py +278 -209
- tpu_inference/layers/vllm/quantization/unquantized.py +134 -86
- tpu_inference/layers/vllm/sharding.py +21 -4
- tpu_inference/lora/__init__.py +13 -0
- tpu_inference/lora/torch_lora_ops.py +8 -13
- tpu_inference/models/__init__.py +13 -0
- tpu_inference/models/common/__init__.py +13 -0
- tpu_inference/models/common/model_loader.py +74 -35
- tpu_inference/models/jax/__init__.py +13 -0
- tpu_inference/models/jax/deepseek_v3.py +267 -157
- tpu_inference/models/jax/gpt_oss.py +26 -10
- tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
- tpu_inference/models/jax/llama3.py +99 -36
- tpu_inference/models/jax/llama4.py +14 -0
- tpu_inference/models/jax/llama_eagle3.py +14 -0
- tpu_inference/models/jax/llama_guard_4.py +15 -1
- tpu_inference/models/jax/qwen2.py +17 -2
- tpu_inference/models/jax/qwen2_5_vl.py +18 -4
- tpu_inference/models/jax/qwen3.py +17 -2
- tpu_inference/models/jax/utils/__init__.py +13 -0
- tpu_inference/models/jax/utils/file_utils.py +14 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
- tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
- tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +89 -26
- tpu_inference/models/jax/utils/weight_utils.py +39 -2
- tpu_inference/models/vllm/__init__.py +13 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +20 -3
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
- tpu_inference/platforms/__init__.py +14 -0
- tpu_inference/platforms/tpu_platform.py +47 -64
- tpu_inference/runner/__init__.py +13 -0
- tpu_inference/runner/compilation_manager.py +72 -37
- tpu_inference/runner/kv_cache.py +54 -20
- tpu_inference/runner/kv_cache_manager.py +46 -17
- tpu_inference/runner/lora_utils.py +14 -0
- tpu_inference/runner/multimodal_manager.py +15 -1
- tpu_inference/runner/persistent_batch_manager.py +14 -0
- tpu_inference/runner/speculative_decoding_manager.py +14 -0
- tpu_inference/runner/structured_decoding_manager.py +14 -0
- tpu_inference/runner/tpu_runner.py +44 -17
- tpu_inference/spec_decode/__init__.py +13 -0
- tpu_inference/spec_decode/jax/__init__.py +13 -0
- tpu_inference/spec_decode/jax/eagle3.py +13 -0
- tpu_inference/tpu_info.py +14 -0
- tpu_inference/utils.py +42 -36
- tpu_inference/worker/__init__.py +13 -0
- tpu_inference/worker/tpu_worker.py +63 -50
- {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.2rc3.dist-info}/METADATA +7 -9
- tpu_inference-0.13.2rc3.dist-info/RECORD +261 -0
- tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
- tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
- tpu_inference-0.11.1.dev202512030818.dist-info/RECORD +0 -174
- {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.2rc3.dist-info}/WHEEL +0 -0
- {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.2rc3.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.2rc3.dist-info}/top_level.txt +0 -0
tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py
@@ -1,3 +1,16 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 """
 A variant of TPU-Friendly Ragged Paged Attention kernel optimized for
 head_dim = 64.
@@ -267,7 +280,6 @@ def _ragged_paged_attention_kernel(
     *,
     sm_scale: float,
     sliding_window: int | None = None,
-    strict_sliding_window: bool = True,
     soft_cap: float | None = None,
     mask_value: float = DEFAULT_MASK_VALUE,
     q_scale: float | None = None,
@@ -324,14 +336,15 @@ def _ragged_paged_attention_kernel(
     bkv_idx_start = jnp.maximum(kv_len - q_len - sliding_window,
                                 0) // bkv_sz
 
-
-
-
-
-
-
-    next_seq_bkv_idx_start =
-
+    # If seq_idx + 1 == num_seqs, kv_lens_ref[seq_idx + 1] will trigger a
+    # out-of-bound error. To avoid this, we set upperbound of next_seq_idx
+    # to be num_seqs - 1.
+    next_seq_idx = jnp.minimum(seq_idx + 1, num_seqs - 1)
+    next_kv_len = kv_lens_ref[next_seq_idx]
+    next_q_len = cu_q_lens_ref[next_seq_idx + 1] - q_end
+    next_seq_bkv_idx_start = (
+        jnp.maximum(next_kv_len - next_q_len - sliding_window, 0) //
+        bkv_sz)
 
   def debug_print(msg, *args):
     if debug_mode:
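
Note: with a sliding window, KV blocks that fall entirely before the window of the first query row can never be attended, so the first block to fetch is max(kv_len - q_len - sliding_window, 0) // bkv_sz; the added lines precompute the same quantity for the next sequence and clamp the lookup index so the final sequence never reads past kv_lens_ref. A standalone sketch of that arithmetic in plain JAX (the helper names and array arguments are illustrative, not kernel code):

    import jax.numpy as jnp

    def first_kv_block(kv_len, q_len, sliding_window, bkv_sz):
      # Earliest KV position any of this request's query rows can attend to,
      # clamped at 0, then converted to a KV block index. Mirrors the
      # sliding_window branch of the kernel above.
      return jnp.maximum(kv_len - q_len - sliding_window, 0) // bkv_sz

    def next_seq_first_kv_block(kv_lens, cu_q_lens, seq_idx, sliding_window,
                                bkv_sz):
      num_seqs = kv_lens.shape[0]
      # Clamp seq_idx + 1 to num_seqs - 1 so the final sequence stays in
      # bounds, as the added comment describes.
      next_seq_idx = jnp.minimum(seq_idx + 1, num_seqs - 1)
      next_kv_len = kv_lens[next_seq_idx]
      # The kernel subtracts q_end; the cumulative-length difference below is
      # equivalent whenever the index is not clamped.
      next_q_len = cu_q_lens[next_seq_idx + 1] - cu_q_lens[next_seq_idx]
      return first_kv_block(next_kv_len, next_q_len, sliding_window, bkv_sz)
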
@@ -396,7 +409,7 @@ def _ragged_paged_attention_kernel(
     k_span = bkv_idx * bkv_sz + lax.broadcasted_iota(jnp.int32, s.shape, 1)
     mask = k_span <= q_span
 
-    if sliding_window is not None
+    if sliding_window is not None:
       mask = jnp.logical_and(mask, q_span - sliding_window < k_span)
 
     s = jnp.where(mask, s, mask_value)
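
Note: the truncated old line presumably also tested strict_sliding_window, the parameter this release removes. The mask itself combines a causal constraint and a sliding-window constraint built from two broadcasted_iota spans; a toy reproduction with made-up tile sizes and offsets:

    import jax.numpy as jnp
    from jax import lax

    bq_sz, bkv_sz, window = 4, 8, 5    # made-up tile sizes
    q_start, bkv_idx = 12, 1           # made-up global offsets of this tile
    s = jnp.zeros((bq_sz, bkv_sz))     # stand-in for a Q @ K^T score tile

    q_span = q_start + lax.broadcasted_iota(jnp.int32, s.shape, 0)
    k_span = bkv_idx * bkv_sz + lax.broadcasted_iota(jnp.int32, s.shape, 1)

    mask = k_span <= q_span                                 # causal
    mask = jnp.logical_and(mask, q_span - window < k_span)  # sliding window
    # Fill masked lanes with a large negative value, standing in for
    # DEFAULT_MASK_VALUE in the kernel.
    s = jnp.where(mask, s, -0.7 * jnp.finfo(jnp.float32).max)
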
@@ -520,19 +533,16 @@ def _ragged_paged_attention_kernel(
           unroll=False,
       )
 
-
-
-
-
-
-
-
-
-
-
-          sem,
-          wait,
-      )
+      size = lax.select(bkv_sz_frm_new > 0, bkv_sz_frm_new, 0)
+      new_kv_len_start = q_end - kv_left_frm_new
+      debug_print("[RPA debug] new_kv_len_start={}", new_kv_len_start)
+      debug_print("[RPA debug] offset_in_bkv={}", offset)
+      _async_copy(
+          kv_hbm_ref.at[pl.ds(new_kv_len_start, size)],
+          vmem_ref.at[pl.ds(offset, size)],
+          sem,
+          wait,
+      )
 
       return kv_len_start + offset, bkv_sz_frm_new
     else:
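
Note: the rewritten copy branch clamps the token count before issuing the DMA, so a block served entirely from cache degenerates to a zero-length copy instead of a negative slice. The arithmetic in isolation (the helper is illustrative; names mirror the kernel):

    from jax import lax

    def copy_bounds(q_end, kv_left_frm_new, bkv_sz_frm_new):
      # bkv_sz_frm_new can come out negative when nothing in this block
      # needs to come from the new (uncached) KV; clamp it to a no-op copy.
      size = lax.select(bkv_sz_frm_new > 0, bkv_sz_frm_new, 0)
      # Offset of the first new token for this block within the new KV.
      new_kv_len_start = q_end - kv_left_frm_new
      return new_kv_len_start, size
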
@@ -726,7 +736,7 @@ def _ragged_paged_attention_kernel(
       vec = ref[start::step]
       return vec
 
-    def strided_load_bkv(bkv_sem_idx, start, step
+    def strided_load_bkv(bkv_sem_idx, start, step):
       assert start % kv_packing == 0
       assert step % kv_packing == 0
       start //= kv_packing
@@ -735,7 +745,6 @@ def _ragged_paged_attention_kernel(
                                  bkv_sz * step, actual_head_dim_x2))
 
       kv = strided_load(kv_ref, start, step)
-      kv = lax.select(bkv_mask, kv, jnp.zeros_like(kv))
       bitwidth = 32 // kv_packing
       repack_ty = jnp.dtype(f"uint{bitwidth}")
       lst = []
@@ -783,14 +792,18 @@ def _ragged_paged_attention_kernel(
     next_bkv_sem_idx = lax.select(bkv_sem_idx == 0, 1, 0)
 
     if sliding_window is None:
-
+      # When sliding window is disabled, starting bkv_idx of next request is
+      # always 0 regardless of seq_idx of next request.
+      next_bkv_idx_start = 0
     else:
-
+      # Determine starting bkv_idx of next request based on whether next
+      # request is from the same sequence or next sequence.
+      next_bkv_idx_start = lax.select(
           is_last_bq,
           next_seq_bkv_idx_start,
           bkv_idx_start,
       )
-    next_bkv_idx = lax.select(is_last_bkv,
+    next_bkv_idx = lax.select(is_last_bkv, next_bkv_idx_start,
                               next_bkv_idx)
 
     return next_seq_idx, next_bq_idx, next_bkv_idx, next_bkv_sem_idx
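
Note: the filled-in branches make the prefetch hand-off explicit: without a window the next request always scans KV from block 0; with a window it starts at the precomputed window-dependent block, switching to the next sequence's start block only when the last query block of the current request has just finished. A scalar sketch of that selection (the flags and indices are assumed to be traced int32/bool scalars; the helper itself is illustrative):

    import jax.numpy as jnp
    from jax import lax

    def select_next_bkv_idx(is_last_bkv, is_last_bq, bkv_idx, bkv_idx_start,
                            next_seq_bkv_idx_start, sliding_window):
      if sliding_window is None:
        # No window: every request scans its KV from block 0.
        next_bkv_idx_start = jnp.int32(0)
      else:
        # Window: stay on this sequence's start block unless we just
        # finished its last query block, then jump to the next sequence's.
        next_bkv_idx_start = lax.select(is_last_bq, next_seq_bkv_idx_start,
                                        bkv_idx_start)
      # Advance within the request, or restart at the chosen block once the
      # request's last KV block has been processed.
      return lax.select(is_last_bkv, next_bkv_idx_start, bkv_idx + 1)
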
@@ -809,10 +822,6 @@ def _ragged_paged_attention_kernel(
   def compute_with_bkv(bkv_idx, _):
     # Create bitmask for KV.
     assert bkv_sz % kv_packing == 0
-    actual_bkv_sz = jnp.minimum(bkv_sz, kv_len - bkv_idx * bkv_sz)
-    bkv_shape = (bkv_sz, actual_head_dim_x2)
-    bkv_mask = lax.broadcasted_iota(jnp.int32, bkv_shape,
-                                    0) < actual_bkv_sz
 
     # Get next bkv ids.
     bkv_sem_idx = sem_ids_ref[1]
@@ -862,7 +871,6 @@ def _ragged_paged_attention_kernel(
         bkv_sem_idx,
         kv_head_start,
         num_kv_heads,
-        bkv_mask=bkv_mask,
     )
     assert len(bkv_lst) == kv_packing
     for i in range(kv_packing):
@@ -946,7 +954,17 @@ def _ragged_paged_attention_kernel(
   @pl.when(seq_idx == 0)
   def prologue():
     start_fetch_bq(0, 0, 0)
+
+    # Initialize bkv_x2_ref to zeros to avoid NaN issues from accessing
+    # uninitialized memory. Bitcast into int32 to avoid tiling issues.
+    bkv_x2_int32_ref = bkv_x2_ref.bitcast(jnp.int32).reshape(
+        (2, -1, 8, 128))
+    zeros = jnp.zeros(bkv_x2_int32_ref.shape[1:], jnp.int32)
+
+    # To pipeline VST and DMA, we divide the initialization into two steps.
+    bkv_x2_int32_ref[0] = zeros
     start_fetch_bkv(0, bkv_idx_start, 0)
+    bkv_x2_int32_ref[1] = zeros
 
   @pl.when(seq_idx < decode_end)
   def process_decode():
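
Note: this zero-fill replaces the per-block bitmask removed in the hunks above; once the scratch is zeroed in the prologue, lanes that a block never overwrites cannot leak NaNs from stale VMEM, and interleaving the two buffer halves around start_fetch_bkv lets the vector stores overlap the DMA. A minimal Pallas illustration of the general pattern, zeroing a scratch buffer on the first grid step (a toy kernel under assumed shapes, not the attention kernel):

    import jax
    import jax.numpy as jnp
    from jax.experimental import pallas as pl
    from jax.experimental.pallas import tpu as pltpu

    def kernel(x_ref, o_ref, scratch_ref):
      # Zero the scratch exactly once, before any grid step reads it.
      @pl.when(pl.program_id(0) == 0)
      def _prologue():
        scratch_ref[...] = jnp.zeros(scratch_ref.shape, scratch_ref.dtype)

      scratch_ref[...] += x_ref[...]   # running sum across grid steps
      o_ref[...] = scratch_ref[...]

    x = jnp.ones((32, 128), jnp.float32)
    out = pl.pallas_call(
        kernel,
        grid=(4,),
        in_specs=[pl.BlockSpec((8, 128), lambda i: (i, 0))],
        out_specs=pl.BlockSpec((8, 128), lambda i: (i, 0)),
        out_shape=jax.ShapeDtypeStruct((32, 128), jnp.float32),
        scratch_shapes=[pltpu.VMEM((8, 128), jnp.float32)],
    )(x)
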
@@ -1306,12 +1324,15 @@ def static_validate_inputs(
   del debug_mode
 
 
+def get_kernel_scope_name(bq_size, bkv_p, page_size):
+  return f"RPA-HD_64-bq_{bq_size}-bkvp_{bkv_p}-p_{page_size}-"
+
+
 @functools.partial(
     jax.jit,
     static_argnames=(
         "sm_scale",
         "sliding_window",
-        "strict_sliding_window",
         "soft_cap",
         "mask_value",
         "q_scale",
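
Note: the new get_kernel_scope_name helper simply bakes the tile configuration into the kernel's profiler scope name; with hypothetical sizes:

    >>> get_kernel_scope_name(bq_size=32, bkv_p=8, page_size=64)
    'RPA-HD_64-bq_32-bkvp_8-p_64-'
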
@@ -1341,7 +1362,6 @@ def ragged_paged_attention_hd64(
     *,
     sm_scale: float = 1.0,
     sliding_window: int | None = None,
-    strict_sliding_window: bool = True,
     soft_cap: float | None = None,
     mask_value: float | None = DEFAULT_MASK_VALUE,
     q_scale: float | None = None,
@@ -1373,7 +1393,6 @@ def ragged_paged_attention_hd64(
     attention_sink: optional attention sink for each q head.
     sm_scale: the softmax scale which will be applied to the Q@K^T.
     sliding_window: the sliding window size for the attention.
-    strict_sliding_window: compute tokens that are strictly within the window.
     soft_cap: the logit soft cap for the attention.
     mask_value: mask value for causal mask.
     q_scale: the scale for the query.
@@ -1447,6 +1466,7 @@ def ragged_paged_attention_hd64(
       page_size,
       max_num_tokens,
       pages_per_seq,
+      sliding_window,
   )
   bkv_sz = bkv_p * page_size
   if vmem_limit_bytes is None:
@@ -1517,48 +1537,45 @@ def ragged_paged_attention_hd64(
       jnp.full((6, ), -1, jnp.int32),
   )
 
-  scope_name =
-  kernel =
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-      ),
-
-
-
-
-
-      )
-
-
-
-
-
-      ),
-
-
-
-
-
-
-
-
-      },
-      name=scope_name,
-  ))
+  scope_name = get_kernel_scope_name(bq_sz, bkv_p, page_size)
+  kernel = pl.pallas_call(
+      functools.partial(
+          _ragged_paged_attention_kernel,
+          sm_scale=sm_scale,
+          sliding_window=sliding_window,
+          soft_cap=soft_cap,
+          mask_value=mask_value,
+          q_scale=q_scale,
+          k_scale=k_scale,
+          v_scale=v_scale,
+          chunk_prefill_size=chunk_prefill_size,
+          bq_sz=bq_sz,
+          bkv_p=bkv_p,
+          debug_mode=debug_mode,
+      ),
+      grid_spec=pltpu.PrefetchScalarGridSpec(
+          num_scalar_prefetch=len(scalar_prefetches),
+          in_specs=in_specs,
+          out_specs=out_specs,
+          grid=grid,
+          scratch_shapes=scratch_shapes,
+      ),
+      compiler_params=pltpu.CompilerParams(
+          # TODO(jevinjiang): since each sequence depends on the previous
+          # one, we need some extra work to support Megacore mode.
+          dimension_semantics=("arbitrary", ),
+          vmem_limit_bytes=vmem_limit_bytes,
+      ),
+      out_shape=[
+          jax.ShapeDtypeStruct(shape=q.shape, dtype=q.dtype),
+          jax.ShapeDtypeStruct(shape=kv_cache.shape, dtype=kv_cache.dtype),
+      ],
+      input_output_aliases={
+          7: 0,
+          9: 1
+      },
+      name=scope_name,
+  )
 
   output, updated_kv_cache = kernel(*scalar_prefetches, q, kv, kv_cache,
                                     attention_sink)
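
Note: in the rebuilt call, input_output_aliases={7: 0, 9: 1} lines up with the invocation below it: counting the scalar-prefetch operands first, operand 7 (q) is donated as the buffer for output 0 and operand 9 (kv_cache) as the buffer for output 1, so both are updated in place rather than copied. A minimal sketch of the aliasing mechanism on a toy kernel (not the attention kernel's actual specs):

    import jax
    import jax.numpy as jnp
    from jax.experimental import pallas as pl

    def add_one_kernel(x_ref, o_ref):
      o_ref[...] = x_ref[...] + 1.0

    x = jnp.zeros((8, 128), jnp.float32)
    # {0: 0} donates input operand 0 as the buffer for output 0, so the
    # update reuses x's memory instead of allocating a fresh array.
    y = pl.pallas_call(
        add_one_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        input_output_aliases={0: 0},
    )(x)
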