tpu-inference 0.11.1.dev202511270815__py3-none-any.whl → 0.13.0rc2.post7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tpu-inference might be problematic. Click here for more details.
- tests/__init__.py +13 -0
- tests/core/__init__.py +13 -0
- tests/core/test_disagg_utils.py +14 -0
- tests/core/test_dp_scheduler.py +650 -768
- tests/core/test_init.py +14 -0
- tests/distributed/__init__.py +13 -0
- tests/distributed/test_distributed_utils.py +120 -0
- tests/distributed/test_tpu_connector.py +478 -0
- tests/e2e/__init__.py +13 -0
- tests/e2e/test_async_scheduler.py +211 -0
- tests/e2e/test_data_parallel.py +289 -0
- tests/e2e/test_hybrid_kvcache.py +219 -0
- tests/e2e/test_local_disagg.py +257 -0
- tests/e2e/test_model_loader.py +268 -0
- tests/e2e/test_multi_modal_inference.py +111 -0
- tests/e2e/test_pipeline_parallel.py +265 -0
- tests/e2e/test_runai_model_streamer_loader.py +104 -0
- tests/e2e/test_sampling_params.py +269 -0
- tests/e2e/test_speculative_decoding.py +311 -0
- tests/e2e/test_structured_decoding.py +46 -0
- tests/executors/__init__.py +13 -0
- tests/executors/test_ray_distributed_executor.py +199 -0
- tests/experimental/__init__.py +13 -0
- tests/experimental/test_llama3_jax_stashed.py +208 -0
- tests/kernels/__init__.py +13 -0
- tests/kernels/collectives/__init__.py +13 -0
- tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
- tests/kernels/fused_moe_v1_test.py +14 -0
- tests/kernels/gmm_test.py +205 -0
- tests/kernels/mla_v1_test.py +143 -41
- tests/kernels/quantized_matmul_kernel_test.py +2 -34
- tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +17 -1
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +17 -1
- tests/layers/__init__.py +13 -0
- tests/layers/common/__init__.py +13 -0
- tests/layers/common/test_attention_interface.py +156 -0
- tests/layers/common/test_quantization.py +149 -0
- tests/layers/jax/__init__.py +13 -0
- tests/layers/jax/attention/__init__.py +13 -0
- tests/layers/jax/attention/test_common_attention.py +103 -0
- tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
- tests/layers/jax/attention/test_llama4_attention.py +135 -0
- tests/layers/jax/moe/__init__.py +13 -0
- tests/layers/jax/moe/test_deepseek_moe.py +235 -0
- tests/layers/jax/sample/__init__.py +13 -0
- tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
- tests/layers/jax/sample/test_sampling.py +115 -0
- tests/layers/jax/sample/test_sampling_metadata.py +254 -0
- tests/layers/jax/test_layers.py +155 -0
- tests/{test_quantization.py → layers/jax/test_qwix.py} +183 -50
- tests/layers/jax/test_rope.py +93 -0
- tests/layers/jax/test_sharding.py +159 -0
- tests/layers/jax/test_transformer_block.py +152 -0
- tests/layers/vllm/__init__.py +13 -0
- tests/layers/vllm/test_attention.py +363 -0
- tests/layers/vllm/test_awq.py +405 -0
- tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +418 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +441 -0
- tests/layers/vllm/test_fp8.py +17 -0
- tests/layers/vllm/test_mxfp4.py +312 -0
- tests/layers/vllm/test_unquantized.py +651 -0
- tests/layers/vllm/utils.py +87 -0
- tests/lora/__init__.py +13 -0
- tests/lora/conftest.py +14 -0
- tests/lora/test_bgmv.py +14 -0
- tests/lora/test_layers.py +21 -3
- tests/lora/test_lora.py +15 -1
- tests/lora/test_lora_perf.py +67 -0
- tests/models/__init__.py +13 -0
- tests/models/common/__init__.py +13 -0
- tests/models/common/test_model_loader.py +455 -0
- tests/models/jax/__init__.py +13 -0
- tests/models/jax/test_deepseek_v3.py +401 -0
- tests/models/jax/test_llama3.py +184 -0
- tests/models/jax/test_llama4.py +298 -0
- tests/models/jax/test_llama_eagle3.py +197 -0
- tests/models/jax/test_llama_guard_4.py +242 -0
- tests/models/jax/test_qwen2.py +172 -0
- tests/models/jax/test_qwen2_5_vl.py +605 -0
- tests/models/jax/test_qwen3.py +169 -0
- tests/models/jax/test_weight_loading.py +180 -0
- tests/models/jax/utils/__init__.py +13 -0
- tests/models/jax/utils/test_multi_modal_utils.py +212 -0
- tests/platforms/__init__.py +13 -0
- tests/platforms/test_tpu_platform.py +54 -0
- tests/runner/__init__.py +13 -0
- tests/runner/test_block_table.py +395 -0
- tests/runner/test_input_batch.py +226 -0
- tests/runner/test_kv_cache.py +220 -0
- tests/runner/test_kv_cache_manager.py +498 -0
- tests/runner/test_multimodal_manager.py +429 -0
- tests/runner/test_persistent_batch_manager.py +84 -0
- tests/runner/test_speculative_decoding_manager.py +368 -0
- tests/runner/test_structured_decoding_manager.py +220 -0
- tests/runner/test_tpu_runner.py +261 -0
- tests/runner/test_tpu_runner_dp.py +1099 -0
- tests/runner/test_tpu_runner_mesh.py +200 -0
- tests/runner/test_utils.py +411 -0
- tests/spec_decode/__init__.py +13 -0
- tests/spec_decode/test_eagle3.py +311 -0
- tests/test_base.py +14 -0
- tests/test_envs.py +110 -12
- tests/test_tpu_info.py +14 -0
- tests/test_utils.py +2 -45
- tests/worker/__init__.py +13 -0
- tests/worker/tpu_worker_test.py +414 -0
- tpu_inference/__init__.py +14 -0
- tpu_inference/core/__init__.py +13 -0
- tpu_inference/core/sched/__init__.py +13 -0
- tpu_inference/core/sched/dp_scheduler.py +372 -56
- tpu_inference/distributed/__init__.py +13 -0
- tpu_inference/distributed/jax_parallel_state.py +14 -0
- tpu_inference/distributed/tpu_connector.py +15 -10
- tpu_inference/distributed/utils.py +56 -4
- tpu_inference/envs.py +92 -8
- tpu_inference/executors/__init__.py +13 -0
- tpu_inference/executors/ray_distributed_executor.py +22 -1
- tpu_inference/experimental/__init__.py +13 -0
- tpu_inference/experimental/llama3_jax_stashed.py +14 -0
- tpu_inference/kernels/__init__.py +13 -0
- tpu_inference/kernels/collectives/__init__.py +13 -0
- tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
- tpu_inference/kernels/flash_attention/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/kernel.py +370 -324
- tpu_inference/kernels/megablox/__init__.py +13 -0
- tpu_inference/kernels/megablox/common.py +54 -0
- tpu_inference/kernels/megablox/gmm.py +646 -0
- tpu_inference/kernels/mla/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/kernel.py +117 -145
- tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
- tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
- tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +194 -101
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +167 -97
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +376 -195
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +15 -1
- tpu_inference/layers/__init__.py +13 -0
- tpu_inference/layers/common/__init__.py +13 -0
- tpu_inference/layers/common/attention_interface.py +26 -19
- tpu_inference/layers/common/attention_metadata.py +14 -0
- tpu_inference/layers/common/quant_methods.py +15 -0
- tpu_inference/layers/common/quantization.py +270 -0
- tpu_inference/layers/common/sharding.py +31 -9
- tpu_inference/layers/jax/__init__.py +13 -0
- tpu_inference/layers/jax/attention/__init__.py +13 -0
- tpu_inference/layers/jax/attention/attention.py +19 -6
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +270 -77
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +24 -11
- tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
- tpu_inference/layers/jax/base.py +14 -0
- tpu_inference/layers/jax/constants.py +13 -0
- tpu_inference/layers/jax/layers.py +14 -0
- tpu_inference/layers/jax/misc.py +14 -0
- tpu_inference/layers/jax/moe/__init__.py +13 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
- tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
- tpu_inference/layers/jax/moe/moe.py +43 -3
- tpu_inference/layers/jax/pp_utils.py +53 -0
- tpu_inference/layers/jax/rope.py +14 -0
- tpu_inference/layers/jax/rope_interface.py +14 -0
- tpu_inference/layers/jax/sample/__init__.py +13 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
- tpu_inference/layers/jax/sample/sampling.py +15 -1
- tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
- tpu_inference/layers/jax/transformer_block.py +14 -0
- tpu_inference/layers/vllm/__init__.py +13 -0
- tpu_inference/layers/vllm/attention.py +4 -4
- tpu_inference/layers/vllm/fused_moe.py +210 -260
- tpu_inference/layers/vllm/linear_common.py +57 -22
- tpu_inference/layers/vllm/quantization/__init__.py +16 -0
- tpu_inference/layers/vllm/quantization/awq.py +15 -1
- tpu_inference/layers/vllm/quantization/common.py +33 -18
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +18 -3
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +211 -148
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +14 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +14 -0
- tpu_inference/layers/vllm/quantization/fp8.py +118 -0
- tpu_inference/layers/vllm/quantization/mxfp4.py +280 -210
- tpu_inference/layers/vllm/quantization/unquantized.py +134 -86
- tpu_inference/layers/vllm/sharding.py +21 -4
- tpu_inference/lora/__init__.py +13 -0
- tpu_inference/lora/torch_lora_ops.py +8 -13
- tpu_inference/models/__init__.py +13 -0
- tpu_inference/models/common/__init__.py +13 -0
- tpu_inference/models/common/model_loader.py +77 -36
- tpu_inference/models/jax/__init__.py +13 -0
- tpu_inference/models/jax/deepseek_v3.py +267 -157
- tpu_inference/models/jax/gpt_oss.py +26 -10
- tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
- tpu_inference/models/jax/llama3.py +99 -36
- tpu_inference/models/jax/llama4.py +14 -0
- tpu_inference/models/jax/llama_eagle3.py +14 -0
- tpu_inference/models/jax/llama_guard_4.py +15 -1
- tpu_inference/models/jax/qwen2.py +17 -2
- tpu_inference/models/jax/qwen2_5_vl.py +18 -4
- tpu_inference/models/jax/qwen3.py +17 -2
- tpu_inference/models/jax/utils/__init__.py +13 -0
- tpu_inference/models/jax/utils/file_utils.py +14 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
- tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
- tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +91 -31
- tpu_inference/models/jax/utils/weight_utils.py +39 -2
- tpu_inference/models/vllm/__init__.py +13 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +20 -4
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
- tpu_inference/platforms/__init__.py +14 -0
- tpu_inference/platforms/tpu_platform.py +47 -71
- tpu_inference/runner/__init__.py +13 -0
- tpu_inference/runner/compilation_manager.py +158 -63
- tpu_inference/runner/kv_cache.py +54 -20
- tpu_inference/runner/kv_cache_manager.py +53 -30
- tpu_inference/runner/lora_utils.py +14 -0
- tpu_inference/runner/multimodal_manager.py +15 -1
- tpu_inference/runner/persistent_batch_manager.py +54 -2
- tpu_inference/runner/speculative_decoding_manager.py +14 -0
- tpu_inference/runner/structured_decoding_manager.py +14 -0
- tpu_inference/runner/tpu_runner.py +105 -57
- tpu_inference/runner/utils.py +2 -2
- tpu_inference/spec_decode/__init__.py +13 -0
- tpu_inference/spec_decode/jax/__init__.py +13 -0
- tpu_inference/spec_decode/jax/eagle3.py +65 -19
- tpu_inference/tpu_info.py +14 -0
- tpu_inference/utils.py +72 -44
- tpu_inference/worker/__init__.py +13 -0
- tpu_inference/worker/tpu_worker.py +65 -52
- {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/METADATA +11 -9
- tpu_inference-0.13.0rc2.post7.dist-info/RECORD +261 -0
- tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
- tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
- tpu_inference-0.11.1.dev202511270815.dist-info/RECORD +0 -174
- {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/WHEEL +0 -0
- {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/top_level.txt +0 -0
|
@@ -1,3 +1,16 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
1
14
|
"""
|
|
2
15
|
A variant of TPU-Friendly Ragged Paged Attention kernel optimized for
|
|
3
16
|
head_dim = 64.
|
|
@@ -267,7 +280,6 @@ def _ragged_paged_attention_kernel(
|
|
|
267
280
|
*,
|
|
268
281
|
sm_scale: float,
|
|
269
282
|
sliding_window: int | None = None,
|
|
270
|
-
strict_sliding_window: bool = True,
|
|
271
283
|
soft_cap: float | None = None,
|
|
272
284
|
mask_value: float = DEFAULT_MASK_VALUE,
|
|
273
285
|
q_scale: float | None = None,
|
|
@@ -324,14 +336,15 @@ def _ragged_paged_attention_kernel(
|
|
|
324
336
|
bkv_idx_start = jnp.maximum(kv_len - q_len - sliding_window,
|
|
325
337
|
0) // bkv_sz
|
|
326
338
|
|
|
327
|
-
|
|
328
|
-
|
|
329
|
-
|
|
330
|
-
|
|
331
|
-
|
|
332
|
-
|
|
333
|
-
next_seq_bkv_idx_start =
|
|
334
|
-
|
|
339
|
+
# If seq_idx + 1 == num_seqs, kv_lens_ref[seq_idx + 1] will trigger a
|
|
340
|
+
# out-of-bound error. To avoid this, we set upperbound of next_seq_idx
|
|
341
|
+
# to be num_seqs - 1.
|
|
342
|
+
next_seq_idx = jnp.minimum(seq_idx + 1, num_seqs - 1)
|
|
343
|
+
next_kv_len = kv_lens_ref[next_seq_idx]
|
|
344
|
+
next_q_len = cu_q_lens_ref[next_seq_idx + 1] - q_end
|
|
345
|
+
next_seq_bkv_idx_start = (
|
|
346
|
+
jnp.maximum(next_kv_len - next_q_len - sliding_window, 0) //
|
|
347
|
+
bkv_sz)
|
|
335
348
|
|
|
336
349
|
def debug_print(msg, *args):
|
|
337
350
|
if debug_mode:
|
|
@@ -352,7 +365,7 @@ def _ragged_paged_attention_kernel(
|
|
|
352
365
|
debug_print("[RPA debug] q_len={}", q_len)
|
|
353
366
|
debug_print("[RPA debug] kv_len={}", kv_len)
|
|
354
367
|
|
|
355
|
-
def
|
|
368
|
+
def flash_attention_step1_qk_softmax(
|
|
356
369
|
q, # [actual_bq_sz * num_q_heads_per_kv_head, actual_head_dim_x2]
|
|
357
370
|
kv, # [bkv_sz, actual_head_dim_x2]
|
|
358
371
|
*,
|
|
@@ -366,7 +379,6 @@ def _ragged_paged_attention_kernel(
|
|
|
366
379
|
assert kv.shape == (bkv_sz, actual_head_dim_x2)
|
|
367
380
|
head_l_ref = l_ref.at[kv_head_idx, :q.shape[0]]
|
|
368
381
|
head_m_ref = m_ref.at[kv_head_idx, :q.shape[0]]
|
|
369
|
-
head_acc_ref = acc_ref.at[kv_head_idx, :q.shape[0]]
|
|
370
382
|
|
|
371
383
|
def load_with_init(ref, init_val):
|
|
372
384
|
return jnp.where(bkv_idx == bkv_idx_start,
|
|
@@ -397,7 +409,7 @@ def _ragged_paged_attention_kernel(
|
|
|
397
409
|
k_span = bkv_idx * bkv_sz + lax.broadcasted_iota(jnp.int32, s.shape, 1)
|
|
398
410
|
mask = k_span <= q_span
|
|
399
411
|
|
|
400
|
-
if sliding_window is not None
|
|
412
|
+
if sliding_window is not None:
|
|
401
413
|
mask = jnp.logical_and(mask, q_span - sliding_window < k_span)
|
|
402
414
|
|
|
403
415
|
s = jnp.where(mask, s, mask_value)
|
|
@@ -416,15 +428,33 @@ def _ragged_paged_attention_kernel(
|
|
|
416
428
|
head_m_ref[...] = m_curr
|
|
417
429
|
p = jnp.exp(s - broadcast_minor(m_curr, s.shape))
|
|
418
430
|
|
|
419
|
-
pv = jnp.einsum("nm,md->nd", p, kv, preferred_element_type=jnp.float32)
|
|
420
|
-
if v_scale is not None:
|
|
421
|
-
pv *= v_scale
|
|
422
|
-
|
|
423
431
|
p_rowsum = jnp.sum(p, axis=1, keepdims=True)
|
|
424
432
|
exp_m_diff = jnp.exp(m_prev - m_curr)
|
|
425
433
|
l_prev = load_with_init(head_l_ref, 1.0)
|
|
426
434
|
l_curr = exp_m_diff * l_prev + p_rowsum
|
|
427
435
|
head_l_ref[...] = l_curr
|
|
436
|
+
|
|
437
|
+
return p, exp_m_diff
|
|
438
|
+
|
|
439
|
+
def flash_attention_step2_pv(
|
|
440
|
+
q_shape_0,
|
|
441
|
+
kv, # [bkv_sz, actual_head_dim_x2]
|
|
442
|
+
p, # from step1
|
|
443
|
+
exp_m_diff, # from step1
|
|
444
|
+
*,
|
|
445
|
+
bkv_idx,
|
|
446
|
+
kv_head_idx,
|
|
447
|
+
):
|
|
448
|
+
head_acc_ref = acc_ref.at[kv_head_idx, :q_shape_0]
|
|
449
|
+
|
|
450
|
+
def load_with_init(ref, init_val):
|
|
451
|
+
return jnp.where(bkv_idx == bkv_idx_start,
|
|
452
|
+
jnp.full_like(ref, init_val), ref[...])
|
|
453
|
+
|
|
454
|
+
pv = jnp.einsum("nm,md->nd", p, kv, preferred_element_type=jnp.float32)
|
|
455
|
+
if v_scale is not None:
|
|
456
|
+
pv *= v_scale
|
|
457
|
+
|
|
428
458
|
o_prev = load_with_init(head_acc_ref, 0.0)
|
|
429
459
|
o_curr = broadcast_minor(exp_m_diff, o_prev.shape) * o_prev + pv
|
|
430
460
|
head_acc_ref[...] = o_curr
|
|
@@ -503,19 +533,16 @@ def _ragged_paged_attention_kernel(
|
|
|
503
533
|
unroll=False,
|
|
504
534
|
)
|
|
505
535
|
|
|
506
|
-
|
|
507
|
-
|
|
508
|
-
|
|
509
|
-
|
|
510
|
-
|
|
511
|
-
|
|
512
|
-
|
|
513
|
-
|
|
514
|
-
|
|
515
|
-
|
|
516
|
-
sem,
|
|
517
|
-
wait,
|
|
518
|
-
)
|
|
536
|
+
size = lax.select(bkv_sz_frm_new > 0, bkv_sz_frm_new, 0)
|
|
537
|
+
new_kv_len_start = q_end - kv_left_frm_new
|
|
538
|
+
debug_print("[RPA debug] new_kv_len_start={}", new_kv_len_start)
|
|
539
|
+
debug_print("[RPA debug] offset_in_bkv={}", offset)
|
|
540
|
+
_async_copy(
|
|
541
|
+
kv_hbm_ref.at[pl.ds(new_kv_len_start, size)],
|
|
542
|
+
vmem_ref.at[pl.ds(offset, size)],
|
|
543
|
+
sem,
|
|
544
|
+
wait,
|
|
545
|
+
)
|
|
519
546
|
|
|
520
547
|
return kv_len_start + offset, bkv_sz_frm_new
|
|
521
548
|
else:
|
|
@@ -709,7 +736,7 @@ def _ragged_paged_attention_kernel(
|
|
|
709
736
|
vec = ref[start::step]
|
|
710
737
|
return vec
|
|
711
738
|
|
|
712
|
-
def strided_load_bkv(bkv_sem_idx, start, step
|
|
739
|
+
def strided_load_bkv(bkv_sem_idx, start, step):
|
|
713
740
|
assert start % kv_packing == 0
|
|
714
741
|
assert step % kv_packing == 0
|
|
715
742
|
start //= kv_packing
|
|
@@ -718,7 +745,6 @@ def _ragged_paged_attention_kernel(
|
|
|
718
745
|
bkv_sz * step, actual_head_dim_x2))
|
|
719
746
|
|
|
720
747
|
kv = strided_load(kv_ref, start, step)
|
|
721
|
-
kv = lax.select(bkv_mask, kv, jnp.zeros_like(kv))
|
|
722
748
|
bitwidth = 32 // kv_packing
|
|
723
749
|
repack_ty = jnp.dtype(f"uint{bitwidth}")
|
|
724
750
|
lst = []
|
|
@@ -766,14 +792,18 @@ def _ragged_paged_attention_kernel(
|
|
|
766
792
|
next_bkv_sem_idx = lax.select(bkv_sem_idx == 0, 1, 0)
|
|
767
793
|
|
|
768
794
|
if sliding_window is None:
|
|
769
|
-
|
|
795
|
+
# When sliding window is disabled, starting bkv_idx of next request is
|
|
796
|
+
# always 0 regardless of seq_idx of next request.
|
|
797
|
+
next_bkv_idx_start = 0
|
|
770
798
|
else:
|
|
771
|
-
|
|
799
|
+
# Determine starting bkv_idx of next request based on whether next
|
|
800
|
+
# request is from the same sequence or next sequence.
|
|
801
|
+
next_bkv_idx_start = lax.select(
|
|
772
802
|
is_last_bq,
|
|
773
803
|
next_seq_bkv_idx_start,
|
|
774
804
|
bkv_idx_start,
|
|
775
805
|
)
|
|
776
|
-
next_bkv_idx = lax.select(is_last_bkv,
|
|
806
|
+
next_bkv_idx = lax.select(is_last_bkv, next_bkv_idx_start,
|
|
777
807
|
next_bkv_idx)
|
|
778
808
|
|
|
779
809
|
return next_seq_idx, next_bq_idx, next_bkv_idx, next_bkv_sem_idx
|
|
@@ -792,10 +822,6 @@ def _ragged_paged_attention_kernel(
|
|
|
792
822
|
def compute_with_bkv(bkv_idx, _):
|
|
793
823
|
# Create bitmask for KV.
|
|
794
824
|
assert bkv_sz % kv_packing == 0
|
|
795
|
-
actual_bkv_sz = jnp.minimum(bkv_sz, kv_len - bkv_idx * bkv_sz)
|
|
796
|
-
bkv_shape = (bkv_sz, actual_head_dim_x2)
|
|
797
|
-
bkv_mask = lax.broadcasted_iota(jnp.int32, bkv_shape,
|
|
798
|
-
0) < actual_bkv_sz
|
|
799
825
|
|
|
800
826
|
# Get next bkv ids.
|
|
801
827
|
bkv_sem_idx = sem_ids_ref[1]
|
|
@@ -835,29 +861,64 @@ def _ragged_paged_attention_kernel(
|
|
|
835
861
|
return
|
|
836
862
|
|
|
837
863
|
# Flash attention with cur bkv and bq
|
|
864
|
+
prev_bq_shape_0 = None
|
|
865
|
+
prev_kv_head_bkv = None
|
|
866
|
+
prev_kv_head_idx = None
|
|
867
|
+
prev_kv_head_p = None
|
|
868
|
+
prev_kv_head_exp_m_diff = None
|
|
838
869
|
for kv_head_start in range(0, actual_num_kv_heads, kv_packing):
|
|
839
870
|
bkv_lst = strided_load_bkv(
|
|
840
871
|
bkv_sem_idx,
|
|
841
872
|
kv_head_start,
|
|
842
873
|
num_kv_heads,
|
|
843
|
-
bkv_mask=bkv_mask,
|
|
844
874
|
)
|
|
845
875
|
assert len(bkv_lst) == kv_packing
|
|
846
876
|
for i in range(kv_packing):
|
|
847
|
-
|
|
848
|
-
if
|
|
877
|
+
cur_kv_head_idx = kv_head_start + i
|
|
878
|
+
if cur_kv_head_idx >= actual_num_kv_heads:
|
|
849
879
|
break
|
|
850
|
-
|
|
851
|
-
|
|
852
|
-
|
|
853
|
-
|
|
854
|
-
|
|
855
|
-
|
|
856
|
-
|
|
857
|
-
|
|
858
|
-
|
|
859
|
-
|
|
860
|
-
|
|
880
|
+
cur_kv_head_bq = load_bq(bq_sem_idx,
|
|
881
|
+
cur_kv_head_idx,
|
|
882
|
+
actual_bq_sz=actual_bq_sz)
|
|
883
|
+
cur_kv_head__bkv = bkv_lst[i]
|
|
884
|
+
# FlashAttention is divided into `flash_attention_step1_qk_softmax`
|
|
885
|
+
# and `flash_attention_step2_pv` to pipeline the computation.
|
|
886
|
+
# `step2_pv` for the previous KV head, which depends on the softmax
|
|
887
|
+
# output, is overlapped with `step1_qk_softmax` for the current KV
|
|
888
|
+
# head, reducing overall wait times.
|
|
889
|
+
cur_kv_head_p, cur_kv_head_exp_m_diff = (
|
|
890
|
+
flash_attention_step1_qk_softmax(
|
|
891
|
+
cur_kv_head_bq,
|
|
892
|
+
cur_kv_head__bkv,
|
|
893
|
+
bq_idx=bq_idx,
|
|
894
|
+
bkv_idx=bkv_idx,
|
|
895
|
+
kv_head_idx=cur_kv_head_idx,
|
|
896
|
+
))
|
|
897
|
+
if prev_bq_shape_0 is not None:
|
|
898
|
+
flash_attention_step2_pv(
|
|
899
|
+
prev_bq_shape_0,
|
|
900
|
+
prev_kv_head_bkv,
|
|
901
|
+
prev_kv_head_p,
|
|
902
|
+
prev_kv_head_exp_m_diff,
|
|
903
|
+
bkv_idx=bkv_idx,
|
|
904
|
+
kv_head_idx=prev_kv_head_idx,
|
|
905
|
+
)
|
|
906
|
+
prev_bq_shape_0 = cur_kv_head_bq.shape[0]
|
|
907
|
+
prev_kv_head_bkv = cur_kv_head__bkv
|
|
908
|
+
prev_kv_head_p = cur_kv_head_p
|
|
909
|
+
prev_kv_head_exp_m_diff = cur_kv_head_exp_m_diff
|
|
910
|
+
prev_kv_head_idx = cur_kv_head_idx
|
|
911
|
+
|
|
912
|
+
# Execute pv of last attention head.
|
|
913
|
+
assert prev_bq_shape_0 is not None
|
|
914
|
+
flash_attention_step2_pv(
|
|
915
|
+
prev_bq_shape_0,
|
|
916
|
+
prev_kv_head_bkv,
|
|
917
|
+
prev_kv_head_p,
|
|
918
|
+
prev_kv_head_exp_m_diff,
|
|
919
|
+
bkv_idx=bkv_idx,
|
|
920
|
+
kv_head_idx=prev_kv_head_idx,
|
|
921
|
+
)
|
|
861
922
|
|
|
862
923
|
lax.fori_loop(bkv_idx_start,
|
|
863
924
|
num_bkv,
|
|
@@ -893,7 +954,17 @@ def _ragged_paged_attention_kernel(
|
|
|
893
954
|
@pl.when(seq_idx == 0)
|
|
894
955
|
def prologue():
|
|
895
956
|
start_fetch_bq(0, 0, 0)
|
|
957
|
+
|
|
958
|
+
# Initialize bkv_x2_ref to zeros to avoid NaN issues from accessing
|
|
959
|
+
# uninitialized memory. Bitcast into int32 to avoid tiling issues.
|
|
960
|
+
bkv_x2_int32_ref = bkv_x2_ref.bitcast(jnp.int32).reshape(
|
|
961
|
+
(2, -1, 8, 128))
|
|
962
|
+
zeros = jnp.zeros(bkv_x2_int32_ref.shape[1:], jnp.int32)
|
|
963
|
+
|
|
964
|
+
# To pipeline VST and DMA, we divide the initialization into two steps.
|
|
965
|
+
bkv_x2_int32_ref[0] = zeros
|
|
896
966
|
start_fetch_bkv(0, bkv_idx_start, 0)
|
|
967
|
+
bkv_x2_int32_ref[1] = zeros
|
|
897
968
|
|
|
898
969
|
@pl.when(seq_idx < decode_end)
|
|
899
970
|
def process_decode():
|
|
@@ -1253,12 +1324,15 @@ def static_validate_inputs(
|
|
|
1253
1324
|
del debug_mode
|
|
1254
1325
|
|
|
1255
1326
|
|
|
1327
|
+
def get_kernel_scope_name(bq_size, bkv_p, page_size):
|
|
1328
|
+
return f"RPA-HD_64-bq_{bq_size}-bkvp_{bkv_p}-p_{page_size}-"
|
|
1329
|
+
|
|
1330
|
+
|
|
1256
1331
|
@functools.partial(
|
|
1257
1332
|
jax.jit,
|
|
1258
1333
|
static_argnames=(
|
|
1259
1334
|
"sm_scale",
|
|
1260
1335
|
"sliding_window",
|
|
1261
|
-
"strict_sliding_window",
|
|
1262
1336
|
"soft_cap",
|
|
1263
1337
|
"mask_value",
|
|
1264
1338
|
"q_scale",
|
|
@@ -1288,7 +1362,6 @@ def ragged_paged_attention_hd64(
|
|
|
1288
1362
|
*,
|
|
1289
1363
|
sm_scale: float = 1.0,
|
|
1290
1364
|
sliding_window: int | None = None,
|
|
1291
|
-
strict_sliding_window: bool = True,
|
|
1292
1365
|
soft_cap: float | None = None,
|
|
1293
1366
|
mask_value: float | None = DEFAULT_MASK_VALUE,
|
|
1294
1367
|
q_scale: float | None = None,
|
|
@@ -1320,7 +1393,6 @@ def ragged_paged_attention_hd64(
|
|
|
1320
1393
|
attention_sink: optional attention sink for each q head.
|
|
1321
1394
|
sm_scale: the softmax scale which will be applied to the Q@K^T.
|
|
1322
1395
|
sliding_window: the sliding window size for the attention.
|
|
1323
|
-
strict_sliding_window: compute tokens that are strictly within the window.
|
|
1324
1396
|
soft_cap: the logit soft cap for the attention.
|
|
1325
1397
|
mask_value: mask value for causal mask.
|
|
1326
1398
|
q_scale: the scale for the query.
|
|
@@ -1394,6 +1466,7 @@ def ragged_paged_attention_hd64(
|
|
|
1394
1466
|
page_size,
|
|
1395
1467
|
max_num_tokens,
|
|
1396
1468
|
pages_per_seq,
|
|
1469
|
+
sliding_window,
|
|
1397
1470
|
)
|
|
1398
1471
|
bkv_sz = bkv_p * page_size
|
|
1399
1472
|
if vmem_limit_bytes is None:
|
|
@@ -1464,48 +1537,45 @@ def ragged_paged_attention_hd64(
|
|
|
1464
1537
|
jnp.full((6, ), -1, jnp.int32),
|
|
1465
1538
|
)
|
|
1466
1539
|
|
|
1467
|
-
scope_name =
|
|
1468
|
-
kernel =
|
|
1469
|
-
|
|
1470
|
-
|
|
1471
|
-
|
|
1472
|
-
|
|
1473
|
-
|
|
1474
|
-
|
|
1475
|
-
|
|
1476
|
-
|
|
1477
|
-
|
|
1478
|
-
|
|
1479
|
-
|
|
1480
|
-
|
|
1481
|
-
|
|
1482
|
-
|
|
1483
|
-
|
|
1484
|
-
),
|
|
1485
|
-
|
|
1486
|
-
|
|
1487
|
-
|
|
1488
|
-
|
|
1489
|
-
|
|
1490
|
-
|
|
1491
|
-
)
|
|
1492
|
-
|
|
1493
|
-
|
|
1494
|
-
|
|
1495
|
-
|
|
1496
|
-
|
|
1497
|
-
),
|
|
1498
|
-
|
|
1499
|
-
|
|
1500
|
-
|
|
1501
|
-
|
|
1502
|
-
|
|
1503
|
-
|
|
1504
|
-
|
|
1505
|
-
|
|
1506
|
-
},
|
|
1507
|
-
name=scope_name,
|
|
1508
|
-
))
|
|
1540
|
+
scope_name = get_kernel_scope_name(bq_sz, bkv_p, page_size)
|
|
1541
|
+
kernel = pl.pallas_call(
|
|
1542
|
+
functools.partial(
|
|
1543
|
+
_ragged_paged_attention_kernel,
|
|
1544
|
+
sm_scale=sm_scale,
|
|
1545
|
+
sliding_window=sliding_window,
|
|
1546
|
+
soft_cap=soft_cap,
|
|
1547
|
+
mask_value=mask_value,
|
|
1548
|
+
q_scale=q_scale,
|
|
1549
|
+
k_scale=k_scale,
|
|
1550
|
+
v_scale=v_scale,
|
|
1551
|
+
chunk_prefill_size=chunk_prefill_size,
|
|
1552
|
+
bq_sz=bq_sz,
|
|
1553
|
+
bkv_p=bkv_p,
|
|
1554
|
+
debug_mode=debug_mode,
|
|
1555
|
+
),
|
|
1556
|
+
grid_spec=pltpu.PrefetchScalarGridSpec(
|
|
1557
|
+
num_scalar_prefetch=len(scalar_prefetches),
|
|
1558
|
+
in_specs=in_specs,
|
|
1559
|
+
out_specs=out_specs,
|
|
1560
|
+
grid=grid,
|
|
1561
|
+
scratch_shapes=scratch_shapes,
|
|
1562
|
+
),
|
|
1563
|
+
compiler_params=pltpu.CompilerParams(
|
|
1564
|
+
# TODO(jevinjiang): since each sequence depends on the previous
|
|
1565
|
+
# one, we need some extra work to support Megacore mode.
|
|
1566
|
+
dimension_semantics=("arbitrary", ),
|
|
1567
|
+
vmem_limit_bytes=vmem_limit_bytes,
|
|
1568
|
+
),
|
|
1569
|
+
out_shape=[
|
|
1570
|
+
jax.ShapeDtypeStruct(shape=q.shape, dtype=q.dtype),
|
|
1571
|
+
jax.ShapeDtypeStruct(shape=kv_cache.shape, dtype=kv_cache.dtype),
|
|
1572
|
+
],
|
|
1573
|
+
input_output_aliases={
|
|
1574
|
+
7: 0,
|
|
1575
|
+
9: 1
|
|
1576
|
+
},
|
|
1577
|
+
name=scope_name,
|
|
1578
|
+
)
|
|
1509
1579
|
|
|
1510
1580
|
output, updated_kv_cache = kernel(*scalar_prefetches, q, kv, kv_cache,
|
|
1511
1581
|
attention_sink)
|