tpu-inference 0.12.0.dev20251222__py3-none-any.whl
This diff shows the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects changes between package versions as they appear in their public registries.
- tests/__init__.py +13 -0
- tests/core/__init__.py +13 -0
- tests/core/test_core_tpu.py +513 -0
- tests/core/test_disagg_executor.py +60 -0
- tests/core/test_disagg_utils.py +67 -0
- tests/core/test_dp_scheduler.py +724 -0
- tests/core/test_init.py +63 -0
- tests/distributed/__init__.py +13 -0
- tests/distributed/test_distributed_utils.py +120 -0
- tests/distributed/test_tpu_connector.py +478 -0
- tests/e2e/__init__.py +13 -0
- tests/e2e/test_async_scheduler.py +211 -0
- tests/e2e/test_data_parallel.py +393 -0
- tests/e2e/test_local_disagg.py +257 -0
- tests/e2e/test_model_loader.py +268 -0
- tests/e2e/test_multi_modal_inference.py +111 -0
- tests/e2e/test_pipeline_parallel.py +265 -0
- tests/e2e/test_runai_model_streamer_loader.py +104 -0
- tests/e2e/test_sampling_params.py +269 -0
- tests/e2e/test_speculative_decoding.py +291 -0
- tests/e2e/test_structured_decoding.py +46 -0
- tests/executors/__init__.py +13 -0
- tests/executors/test_ray_distributed_executor.py +199 -0
- tests/experimental/__init__.py +13 -0
- tests/experimental/test_llama3_jax_stashed.py +208 -0
- tests/kernels/__init__.py +13 -0
- tests/kernels/collectives/__init__.py +13 -0
- tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
- tests/kernels/fused_moe_v1_test.py +388 -0
- tests/kernels/gmm_test.py +205 -0
- tests/kernels/mla_v1_test.py +498 -0
- tests/kernels/quantized_matmul_kernel_test.py +159 -0
- tests/kernels/ragged_kv_cache_update_v2_test.py +248 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +414 -0
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +565 -0
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +520 -0
- tests/layers/__init__.py +13 -0
- tests/layers/common/__init__.py +13 -0
- tests/layers/common/test_attention_interface.py +156 -0
- tests/layers/common/test_quantization.py +149 -0
- tests/layers/jax/__init__.py +13 -0
- tests/layers/jax/attention/__init__.py +13 -0
- tests/layers/jax/attention/test_common_attention.py +103 -0
- tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
- tests/layers/jax/attention/test_llama4_attention.py +135 -0
- tests/layers/jax/moe/__init__.py +13 -0
- tests/layers/jax/moe/test_deepseek_moe.py +235 -0
- tests/layers/jax/sample/__init__.py +13 -0
- tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
- tests/layers/jax/sample/test_sampling.py +115 -0
- tests/layers/jax/sample/test_sampling_metadata.py +254 -0
- tests/layers/jax/test_layers.py +155 -0
- tests/layers/jax/test_qwix.py +969 -0
- tests/layers/jax/test_rope.py +93 -0
- tests/layers/jax/test_sharding.py +159 -0
- tests/layers/jax/test_transformer_block.py +152 -0
- tests/layers/vllm/__init__.py +13 -0
- tests/layers/vllm/test_attention.py +363 -0
- tests/layers/vllm/test_awq.py +405 -0
- tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +403 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +426 -0
- tests/layers/vllm/test_fp8.py +17 -0
- tests/layers/vllm/test_mxfp4.py +297 -0
- tests/layers/vllm/test_unquantized.py +621 -0
- tests/layers/vllm/utils.py +72 -0
- tests/lora/__init__.py +13 -0
- tests/lora/conftest.py +46 -0
- tests/lora/test_bgmv.py +57 -0
- tests/lora/test_layers.py +666 -0
- tests/lora/test_lora.py +147 -0
- tests/lora/test_lora_perf.py +67 -0
- tests/lora/utils.py +88 -0
- tests/models/__init__.py +13 -0
- tests/models/common/__init__.py +13 -0
- tests/models/common/test_model_loader.py +455 -0
- tests/models/jax/__init__.py +13 -0
- tests/models/jax/test_deepseek_v3.py +401 -0
- tests/models/jax/test_llama3.py +184 -0
- tests/models/jax/test_llama4.py +298 -0
- tests/models/jax/test_llama_eagle3.py +197 -0
- tests/models/jax/test_llama_guard_4.py +242 -0
- tests/models/jax/test_qwen2.py +172 -0
- tests/models/jax/test_qwen2_5_vl.py +606 -0
- tests/models/jax/test_qwen3.py +169 -0
- tests/models/jax/test_weight_loading.py +180 -0
- tests/models/jax/utils/__init__.py +13 -0
- tests/models/jax/utils/test_multi_modal_utils.py +212 -0
- tests/platforms/__init__.py +13 -0
- tests/platforms/test_tpu_platform.py +54 -0
- tests/runner/__init__.py +13 -0
- tests/runner/test_block_table.py +395 -0
- tests/runner/test_input_batch.py +226 -0
- tests/runner/test_kv_cache.py +220 -0
- tests/runner/test_kv_cache_manager.py +498 -0
- tests/runner/test_multimodal_manager.py +429 -0
- tests/runner/test_persistent_batch_manager.py +84 -0
- tests/runner/test_speculative_decoding_manager.py +368 -0
- tests/runner/test_structured_decoding_manager.py +220 -0
- tests/runner/test_tpu_runner.py +202 -0
- tests/runner/test_tpu_runner_dp.py +1033 -0
- tests/runner/test_tpu_runner_mesh.py +200 -0
- tests/runner/test_utils.py +411 -0
- tests/spec_decode/__init__.py +13 -0
- tests/spec_decode/test_eagle3.py +311 -0
- tests/test_base.py +215 -0
- tests/test_envs.py +280 -0
- tests/test_tpu_info.py +134 -0
- tests/test_utils.py +193 -0
- tests/worker/__init__.py +13 -0
- tests/worker/tpu_worker_test.py +414 -0
- tpu_inference/__init__.py +67 -0
- tpu_inference/core/__init__.py +13 -0
- tpu_inference/core/core_tpu.py +786 -0
- tpu_inference/core/disagg_executor.py +118 -0
- tpu_inference/core/disagg_utils.py +49 -0
- tpu_inference/core/sched/__init__.py +13 -0
- tpu_inference/core/sched/dp_scheduler.py +814 -0
- tpu_inference/distributed/__init__.py +13 -0
- tpu_inference/distributed/jax_parallel_state.py +81 -0
- tpu_inference/distributed/tpu_connector.py +732 -0
- tpu_inference/distributed/utils.py +112 -0
- tpu_inference/env_override.py +9 -0
- tpu_inference/envs.py +191 -0
- tpu_inference/executors/__init__.py +13 -0
- tpu_inference/executors/ray_distributed_executor.py +399 -0
- tpu_inference/experimental/__init__.py +13 -0
- tpu_inference/experimental/llama3_jax_stashed.py +272 -0
- tpu_inference/kernels/__init__.py +13 -0
- tpu_inference/kernels/collectives/__init__.py +13 -0
- tpu_inference/kernels/collectives/all_gather_matmul.py +741 -0
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +65 -0
- tpu_inference/kernels/collectives/util.py +47 -0
- tpu_inference/kernels/flash_attention/__init__.py +13 -0
- tpu_inference/kernels/flash_attention/kernel.py +772 -0
- tpu_inference/kernels/fused_moe/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/kernel.py +1612 -0
- tpu_inference/kernels/megablox/__init__.py +13 -0
- tpu_inference/kernels/megablox/common.py +54 -0
- tpu_inference/kernels/megablox/gmm.py +646 -0
- tpu_inference/kernels/mla/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/kernel.py +1340 -0
- tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
- tpu_inference/kernels/quantized_matmul/kernel.py +456 -0
- tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
- tpu_inference/kernels/quantized_matmul/util.py +58 -0
- tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +876 -0
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +288 -0
- tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1594 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1586 -0
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4460 -0
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +548 -0
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +65 -0
- tpu_inference/layers/__init__.py +13 -0
- tpu_inference/layers/common/__init__.py +13 -0
- tpu_inference/layers/common/attention_interface.py +403 -0
- tpu_inference/layers/common/attention_metadata.py +48 -0
- tpu_inference/layers/common/binary_search.py +295 -0
- tpu_inference/layers/common/quant_methods.py +23 -0
- tpu_inference/layers/common/quantization.py +270 -0
- tpu_inference/layers/common/sharding.py +600 -0
- tpu_inference/layers/jax/__init__.py +13 -0
- tpu_inference/layers/jax/attention/__init__.py +13 -0
- tpu_inference/layers/jax/attention/attention.py +268 -0
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +547 -0
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +275 -0
- tpu_inference/layers/jax/attention/llama4_attention.py +167 -0
- tpu_inference/layers/jax/base.py +165 -0
- tpu_inference/layers/jax/constants.py +101 -0
- tpu_inference/layers/jax/layers.py +315 -0
- tpu_inference/layers/jax/misc.py +30 -0
- tpu_inference/layers/jax/moe/__init__.py +13 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +615 -0
- tpu_inference/layers/jax/moe/gpt_oss_moe.py +199 -0
- tpu_inference/layers/jax/moe/moe.py +249 -0
- tpu_inference/layers/jax/pp_utils.py +53 -0
- tpu_inference/layers/jax/rope.py +294 -0
- tpu_inference/layers/jax/rope_interface.py +228 -0
- tpu_inference/layers/jax/sample/__init__.py +13 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +528 -0
- tpu_inference/layers/jax/sample/sampling.py +110 -0
- tpu_inference/layers/jax/sample/sampling_metadata.py +90 -0
- tpu_inference/layers/jax/transformer_block.py +121 -0
- tpu_inference/layers/vllm/__init__.py +13 -0
- tpu_inference/layers/vllm/attention.py +221 -0
- tpu_inference/layers/vllm/fused_moe.py +502 -0
- tpu_inference/layers/vllm/linear_common.py +221 -0
- tpu_inference/layers/vllm/quantization/__init__.py +55 -0
- tpu_inference/layers/vllm/quantization/awq.py +221 -0
- tpu_inference/layers/vllm/quantization/common.py +124 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +135 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +266 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +222 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +150 -0
- tpu_inference/layers/vllm/quantization/fp8.py +118 -0
- tpu_inference/layers/vllm/quantization/mxfp4.py +396 -0
- tpu_inference/layers/vllm/quantization/unquantized.py +416 -0
- tpu_inference/layers/vllm/sharding.py +244 -0
- tpu_inference/logger.py +10 -0
- tpu_inference/lora/__init__.py +13 -0
- tpu_inference/lora/torch_lora_ops.py +98 -0
- tpu_inference/lora/torch_punica_tpu.py +310 -0
- tpu_inference/models/__init__.py +13 -0
- tpu_inference/models/common/__init__.py +13 -0
- tpu_inference/models/common/model_loader.py +520 -0
- tpu_inference/models/jax/__init__.py +13 -0
- tpu_inference/models/jax/deepseek_v3.py +978 -0
- tpu_inference/models/jax/gpt_oss.py +508 -0
- tpu_inference/models/jax/jax_intermediate_tensor.py +93 -0
- tpu_inference/models/jax/llama3.py +436 -0
- tpu_inference/models/jax/llama4.py +643 -0
- tpu_inference/models/jax/llama_eagle3.py +350 -0
- tpu_inference/models/jax/llama_guard_4.py +375 -0
- tpu_inference/models/jax/qwen2.py +390 -0
- tpu_inference/models/jax/qwen2_5_vl.py +1232 -0
- tpu_inference/models/jax/qwen3.py +318 -0
- tpu_inference/models/jax/utils/__init__.py +13 -0
- tpu_inference/models/jax/utils/file_utils.py +110 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +177 -0
- tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
- tpu_inference/models/jax/utils/qwix/qwix_utils.py +713 -0
- tpu_inference/models/jax/utils/weight_utils.py +621 -0
- tpu_inference/models/vllm/__init__.py +13 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +307 -0
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +59 -0
- tpu_inference/platforms/__init__.py +16 -0
- tpu_inference/platforms/tpu_platform.py +258 -0
- tpu_inference/runner/__init__.py +13 -0
- tpu_inference/runner/block_table.py +122 -0
- tpu_inference/runner/compilation_manager.py +890 -0
- tpu_inference/runner/input_batch.py +435 -0
- tpu_inference/runner/kv_cache.py +166 -0
- tpu_inference/runner/kv_cache_manager.py +508 -0
- tpu_inference/runner/lora_utils.py +106 -0
- tpu_inference/runner/multimodal_manager.py +231 -0
- tpu_inference/runner/persistent_batch_manager.py +296 -0
- tpu_inference/runner/speculative_decoding_manager.py +262 -0
- tpu_inference/runner/structured_decoding_manager.py +101 -0
- tpu_inference/runner/tpu_runner.py +1768 -0
- tpu_inference/runner/utils.py +426 -0
- tpu_inference/spec_decode/__init__.py +13 -0
- tpu_inference/spec_decode/jax/__init__.py +13 -0
- tpu_inference/spec_decode/jax/eagle3.py +430 -0
- tpu_inference/tpu_info.py +92 -0
- tpu_inference/utils.py +345 -0
- tpu_inference/worker/__init__.py +13 -0
- tpu_inference/worker/tpu_worker.py +468 -0
- tpu_inference-0.12.0.dev20251222.dist-info/METADATA +106 -0
- tpu_inference-0.12.0.dev20251222.dist-info/RECORD +260 -0
- tpu_inference-0.12.0.dev20251222.dist-info/WHEEL +5 -0
- tpu_inference-0.12.0.dev20251222.dist-info/licenses/LICENSE +201 -0
- tpu_inference-0.12.0.dev20251222.dist-info/top_level.txt +2 -0
|
@@ -0,0 +1,1586 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
"""
|
|
15
|
+
A variant of TPU-Friendly Ragged Paged Attention kernel optimized for
|
|
16
|
+
head_dim = 64.
|
|
17
|
+
"""
|
|
18
|
+
|
|
19
|
+
import functools
|
|
20
|
+
|
|
21
|
+
import jax
|
|
22
|
+
import jax.numpy as jnp
|
|
23
|
+
from jax import lax
|
|
24
|
+
from jax.experimental import pallas as pl
|
|
25
|
+
from jax.experimental.pallas import tpu as pltpu
|
|
26
|
+
|
|
27
|
+
from tpu_inference.kernels.ragged_paged_attention.v3.tuned_block_sizes_hd64 import \
|
|
28
|
+
get_tuned_block_sizes
|
|
29
|
+
from tpu_inference.kernels.ragged_paged_attention.v3.util import (
|
|
30
|
+
align_to, cdiv, get_dtype_packing)
|
|
31
|
+
|
|
32
|
+
DEFAULT_MASK_VALUE = -0.7 * float(jnp.finfo(jnp.dtype("float32")).max)
|
|
33
|
+
|
|
34
|
+
DEFAULT_VMEM_LIMIT_BYTES = 100 * 1024 * 1024
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
# TODO(chengjiyao): refactor this hd64 variant and the original kernel to make
|
|
38
|
+
# sure most of the code is shared.
|
|
39
|
+
def ref_ragged_paged_attention_hd64(
|
|
40
|
+
queries: jax.
|
|
41
|
+
Array, # [max_num_tokens, actual_num_q_heads, actual_head_dim]
|
|
42
|
+
keys: jax.Array, # [max_num_tokens, actual_num_kv_heads, actual_head_dim]
|
|
43
|
+
values: jax.
|
|
44
|
+
Array, # [max_num_tokens, actual_num_kv_heads, actual_head_dim]
|
|
45
|
+
kv_cache: jax.
|
|
46
|
+
Array, # [total_num_pages, page_size, num_kv_heads, kv_packing, actual_head_dim_x2]
|
|
47
|
+
kv_lens: jax.Array, # i32[max_num_seqs]
|
|
48
|
+
page_indices: jax.Array, # i32[max_num_seqs * pages_per_seq]
|
|
49
|
+
cu_q_lens: jax.Array, # i32[max_num_seqs + 1]
|
|
50
|
+
distribution: jax.Array, # i32[3]
|
|
51
|
+
attention_sink: jax.Array | None = None, # f32[actual_num_q_heads]
|
|
52
|
+
*,
|
|
53
|
+
sm_scale: float = 1.0,
|
|
54
|
+
sliding_window: int | None = None,
|
|
55
|
+
soft_cap: float | None = None,
|
|
56
|
+
mask_value: float | None = DEFAULT_MASK_VALUE,
|
|
57
|
+
q_scale: float | None = None,
|
|
58
|
+
k_scale: float | None = None,
|
|
59
|
+
v_scale: float | None = None,
|
|
60
|
+
):
|
|
61
|
+
if mask_value is None:
|
|
62
|
+
mask_value = DEFAULT_MASK_VALUE
|
|
63
|
+
dynamic_validate_inputs(
|
|
64
|
+
queries,
|
|
65
|
+
keys,
|
|
66
|
+
values,
|
|
67
|
+
kv_cache,
|
|
68
|
+
kv_lens,
|
|
69
|
+
page_indices,
|
|
70
|
+
cu_q_lens,
|
|
71
|
+
distribution,
|
|
72
|
+
attention_sink,
|
|
73
|
+
sm_scale=sm_scale,
|
|
74
|
+
sliding_window=sliding_window,
|
|
75
|
+
soft_cap=soft_cap,
|
|
76
|
+
mask_value=mask_value,
|
|
77
|
+
q_scale=q_scale,
|
|
78
|
+
k_scale=k_scale,
|
|
79
|
+
v_scale=v_scale,
|
|
80
|
+
)
|
|
81
|
+
actual_head_dim = queries.shape[2]
|
|
82
|
+
actual_num_q_heads = queries.shape[1]
|
|
83
|
+
actual_num_kv_heads = keys.shape[1]
|
|
84
|
+
assert actual_head_dim == 64
|
|
85
|
+
(
|
|
86
|
+
_,
|
|
87
|
+
page_size,
|
|
88
|
+
_,
|
|
89
|
+
kv_packing,
|
|
90
|
+
actual_head_dim_x2,
|
|
91
|
+
) = kv_cache.shape
|
|
92
|
+
|
|
93
|
+
assert actual_num_q_heads % actual_num_kv_heads == 0
|
|
94
|
+
assert actual_head_dim_x2 == 128
|
|
95
|
+
assert get_dtype_packing(kv_cache.dtype) == kv_packing
|
|
96
|
+
actual_num_q_heads_per_kv_head = actual_num_q_heads // actual_num_kv_heads
|
|
97
|
+
padded_actual_num_kv_heads = align_to(actual_num_kv_heads, kv_packing)
|
|
98
|
+
max_num_seqs = kv_lens.shape[0]
|
|
99
|
+
num_page_indices = page_indices.shape[0]
|
|
100
|
+
assert num_page_indices % max_num_seqs == 0
|
|
101
|
+
pages_per_seq = num_page_indices // max_num_seqs
|
|
102
|
+
|
|
103
|
+
# prepare kv and queries
|
|
104
|
+
merged_kv = merge_kv(keys, values)
|
|
105
|
+
queries = jnp.pad(queries, ((0, 0), (0, 0), (0, 64)), constant_values=0.0)
|
|
106
|
+
outputs = []
|
|
107
|
+
|
|
108
|
+
for i in range(distribution[-1]):
|
|
109
|
+
q_start = cu_q_lens[i]
|
|
110
|
+
q_end = cu_q_lens[i + 1]
|
|
111
|
+
q_len = q_end - q_start
|
|
112
|
+
|
|
113
|
+
kv_len = kv_lens[i]
|
|
114
|
+
indices_start = i * pages_per_seq
|
|
115
|
+
indices_end = indices_start + cdiv(kv_len, page_size)
|
|
116
|
+
indices = page_indices[indices_start:indices_end]
|
|
117
|
+
q = queries[q_start:q_end, :, :]
|
|
118
|
+
|
|
119
|
+
# Update the kv cache.
|
|
120
|
+
assert kv_len - q_len >= 0
|
|
121
|
+
gathered_kv = kv_cache[indices]
|
|
122
|
+
gathered_shape = gathered_kv.shape
|
|
123
|
+
gathered_kv = gathered_kv.reshape(-1, *gathered_shape[-3:])
|
|
124
|
+
gathered_kv = gathered_kv.at[kv_len - q_len:kv_len].set(
|
|
125
|
+
merged_kv[q_start:q_end])
|
|
126
|
+
kv_cache = kv_cache.at[indices].set(
|
|
127
|
+
gathered_kv.reshape(gathered_shape))
|
|
128
|
+
|
|
129
|
+
kv = gathered_kv.reshape(
|
|
130
|
+
-1, padded_actual_num_kv_heads,
|
|
131
|
+
actual_head_dim_x2)[:, :actual_num_kv_heads, :]
|
|
132
|
+
kv = kv[:kv_len, :, :]
|
|
133
|
+
kv = jnp.repeat(kv, actual_num_q_heads_per_kv_head, axis=1)
|
|
134
|
+
if q_scale is not None:
|
|
135
|
+
q = q / q_scale
|
|
136
|
+
if jnp.issubdtype(kv.dtype, jnp.floating):
|
|
137
|
+
dtype_info = jnp.finfo(kv.dtype)
|
|
138
|
+
minval = float(dtype_info.min)
|
|
139
|
+
maxval = float(dtype_info.max)
|
|
140
|
+
q = jnp.clip(q, min=minval, max=maxval)
|
|
141
|
+
q = q.astype(kv.dtype)
|
|
142
|
+
attn = jnp.einsum("qhd,khd->hqk",
|
|
143
|
+
q,
|
|
144
|
+
kv,
|
|
145
|
+
preferred_element_type=jnp.float32)
|
|
146
|
+
attn *= sm_scale
|
|
147
|
+
if k_scale is not None:
|
|
148
|
+
attn *= k_scale
|
|
149
|
+
if q_scale is not None:
|
|
150
|
+
attn *= q_scale
|
|
151
|
+
|
|
152
|
+
q_span = (kv_len - q_len) + jax.lax.broadcasted_iota(
|
|
153
|
+
jnp.int32, attn.shape, 1)
|
|
154
|
+
kv_span = jax.lax.broadcasted_iota(jnp.int32, attn.shape, 2)
|
|
155
|
+
mask = q_span < kv_span
|
|
156
|
+
if sliding_window is not None:
|
|
157
|
+
mask = jnp.logical_or(mask, q_span - sliding_window >= kv_span)
|
|
158
|
+
if soft_cap is not None:
|
|
159
|
+
attn = soft_cap * jnp.tanh(attn / soft_cap)
|
|
160
|
+
attn = jnp.where(mask, mask_value, attn)
|
|
161
|
+
|
|
162
|
+
if attention_sink is not None:
|
|
163
|
+
reshaped_attention_sink = attention_sink.reshape(
|
|
164
|
+
actual_num_q_heads, 1, 1)
|
|
165
|
+
reshaped_attention_sink = jnp.repeat(reshaped_attention_sink,
|
|
166
|
+
q_len,
|
|
167
|
+
axis=1)
|
|
168
|
+
attn = jnp.concat([reshaped_attention_sink, attn], axis=2)
|
|
169
|
+
attn = jax.nn.softmax(attn, axis=-1).astype(kv.dtype)
|
|
170
|
+
attn = attn[..., 1:]
|
|
171
|
+
else:
|
|
172
|
+
attn = jax.nn.softmax(attn, axis=-1).astype(kv.dtype)
|
|
173
|
+
|
|
174
|
+
out = jnp.einsum("hqk,khd->qhd", attn, kv).astype(queries.dtype)
|
|
175
|
+
if v_scale is not None:
|
|
176
|
+
out *= v_scale
|
|
177
|
+
|
|
178
|
+
outputs.append(out)
|
|
179
|
+
|
|
180
|
+
result = jnp.concatenate(outputs, axis=0)
|
|
181
|
+
result = result[:, :, actual_head_dim:]
|
|
182
|
+
return result, kv_cache
|
|
183
|
+
|
|
184
|
+
|
|
185
|
+
def get_smem_estimate_bytes(max_num_seqs, pages_per_seq):
|
|
186
|
+
total_bits = (
|
|
187
|
+
# kv_lens_ref: i32[max_num_seqs]
|
|
188
|
+
align_to(max_num_seqs, 128) * 32 +
|
|
189
|
+
# page_indices_ref: i32[max_num_seqs * pages_per_seq]
|
|
190
|
+
align_to(max_num_seqs * pages_per_seq, 128) * 32 +
|
|
191
|
+
# cu_q_lens_ref: i32[max_num_seqs + 1]
|
|
192
|
+
align_to(max_num_seqs + 1, 128) * 32 +
|
|
193
|
+
# distribution_ref: i32[3]
|
|
194
|
+
128 * 32 +
|
|
195
|
+
# sem_ids_ref: i32[3]
|
|
196
|
+
128 * 32 +
|
|
197
|
+
# bo_ids_ref: i32[4]
|
|
198
|
+
128 * 32 +
|
|
199
|
+
# bkv_update_ids_ref: i32[6]
|
|
200
|
+
128 * 32)
|
|
201
|
+
return cdiv(total_bits, 8)
|
|
202
|
+
|
|
203
|
+
|
|
204
|
+
def get_vmem_estimate_bytes(
|
|
205
|
+
actual_num_kv_heads,
|
|
206
|
+
actual_num_q_heads_per_kv_head,
|
|
207
|
+
actual_head_dim,
|
|
208
|
+
bq_sz,
|
|
209
|
+
bkv_sz,
|
|
210
|
+
q_dtype,
|
|
211
|
+
kv_dtype,
|
|
212
|
+
):
|
|
213
|
+
assert actual_head_dim == 64
|
|
214
|
+
q_packing = get_dtype_packing(q_dtype)
|
|
215
|
+
kv_packing = get_dtype_packing(kv_dtype)
|
|
216
|
+
num_q_heads_per_kv_head = align_to(actual_num_q_heads_per_kv_head,
|
|
217
|
+
q_packing)
|
|
218
|
+
num_kv_heads = align_to(actual_num_kv_heads, kv_packing)
|
|
219
|
+
head_dim = actual_head_dim * 2
|
|
220
|
+
|
|
221
|
+
total_bits = (
|
|
222
|
+
# bkv_x2_ref
|
|
223
|
+
(2 * bkv_sz * num_kv_heads * head_dim) * (32 // kv_packing) +
|
|
224
|
+
# bq_x2_ref + bo_x2_ref
|
|
225
|
+
2 * (2 * actual_num_kv_heads * bq_sz * num_q_heads_per_kv_head *
|
|
226
|
+
head_dim) * (32 // q_packing) +
|
|
227
|
+
# l_ref + m_ref
|
|
228
|
+
2 *
|
|
229
|
+
(actual_num_kv_heads * bq_sz * num_q_heads_per_kv_head * 128) * 32 +
|
|
230
|
+
# acc_ref
|
|
231
|
+
(actual_num_kv_heads * bq_sz * num_q_heads_per_kv_head * head_dim) *
|
|
232
|
+
32)
|
|
233
|
+
return cdiv(total_bits, 8)
|
|
234
|
+
|
|
235
|
+
|
|
236
|
+
def get_kv_cache_shape(
|
|
237
|
+
total_num_pages,
|
|
238
|
+
page_size,
|
|
239
|
+
actual_num_kv_heads,
|
|
240
|
+
actual_head_dim,
|
|
241
|
+
kv_dtype,
|
|
242
|
+
):
|
|
243
|
+
assert actual_head_dim == 64
|
|
244
|
+
kv_packing = get_dtype_packing(kv_dtype)
|
|
245
|
+
return (
|
|
246
|
+
total_num_pages,
|
|
247
|
+
page_size,
|
|
248
|
+
align_to(actual_num_kv_heads, kv_packing) // kv_packing,
|
|
249
|
+
kv_packing,
|
|
250
|
+
128,
|
|
251
|
+
)
|
|
252
|
+
|
|
253
|
+
|
|
254
|
+
def _ragged_paged_attention_kernel(
|
|
255
|
+
# Prefetch
|
|
256
|
+
kv_lens_ref, # [max_num_seqs]
|
|
257
|
+
page_indices_ref, # [max_num_seqs * pages_per_seq]
|
|
258
|
+
cu_q_lens_ref, # [max_num_seqs + 1]
|
|
259
|
+
# TODO(jevinjiang): merge these into one so we can save SMEM.
|
|
260
|
+
distribution_ref, # [3] (decode_end, prefill_end, mixed_end)
|
|
261
|
+
sem_ids_ref, # [3] (bq_sem_idx, bkv_sem_idx, bo_sem_idx)
|
|
262
|
+
bo_ids_ref, # [4] (bo_sem_0_seq_idx, bo_sem_1_seq_idx, bo_sem_0_bo_idx, bo_sem_1_bo_idx)
|
|
263
|
+
bkv_update_ids_ref, # [6] (bkv_sem_0_seq_idx, bkv_sem_1_seq_idx, bkv_sem_0_offset, bkv_sem_1_offset, bkv_sem_0_sz, bkv_sem_1_sz)
|
|
264
|
+
# Input
|
|
265
|
+
q_hbm_ref, # [actual_num_kv_heads, max_num_tokens, num_q_heads_per_kv_head // q_packing, q_packing, head_dim]
|
|
266
|
+
kv_hbm_ref, # [max_num_tokens, num_kv_heads // kv_packing, kv_packing, actual_head_dim_x2]
|
|
267
|
+
kv_cache_hbm_ref, # [total_num_pages, page_size, num_kv_heads // kv_packing, kv_packing, actual_head_dim_x2]
|
|
268
|
+
attention_sink_ref, # [actual_num_kv_heads, num_q_heads_per_kv_head, 128]
|
|
269
|
+
# Output
|
|
270
|
+
o_hbm_ref, # [actual_num_kv_heads, max_num_tokens, num_q_heads_per_kv_head // q_packing, q_packing, actual_head_dim_x2]
|
|
271
|
+
updated_kv_cache_hbm_ref, # [total_num_pages, page_size, num_kv_heads // kv_packing, kv_packing, actual_head_dim_x2]
|
|
272
|
+
# Scratch
|
|
273
|
+
bkv_x2_ref, # [2, bkv_sz, num_kv_heads // kv_packing, kv_packing, actual_head_dim_x2]
|
|
274
|
+
bq_x2_ref, # [2, actual_num_kv_heads, bq_sz, num_q_heads_per_kv_head // q_packing, q_packing, actual_head_dim_x2]
|
|
275
|
+
bo_x2_ref, # [2, actual_num_kv_heads, bq_sz, num_q_heads_per_kv_head // q_packing, q_packing, actual_head_dim_x2]
|
|
276
|
+
sems, # [4, 2]
|
|
277
|
+
l_ref, # [actual_num_kv_heads, bq_sz * num_q_heads_per_kv_head, 128],
|
|
278
|
+
m_ref, # [actual_num_kv_heads, bq_sz * num_q_heads_per_kv_head, 128],
|
|
279
|
+
acc_ref, # [actual_num_kv_heads, bq_sz * num_q_heads_per_kv_head, actual_head_dim_x2],
|
|
280
|
+
*,
|
|
281
|
+
sm_scale: float,
|
|
282
|
+
sliding_window: int | None = None,
|
|
283
|
+
soft_cap: float | None = None,
|
|
284
|
+
mask_value: float = DEFAULT_MASK_VALUE,
|
|
285
|
+
q_scale: float | None = None,
|
|
286
|
+
k_scale: float | None = None,
|
|
287
|
+
v_scale: float | None = None,
|
|
288
|
+
chunk_prefill_size: int | None = None,
|
|
289
|
+
bkv_p,
|
|
290
|
+
bq_sz,
|
|
291
|
+
debug_mode: bool = False,
|
|
292
|
+
):
|
|
293
|
+
assert q_hbm_ref.shape == o_hbm_ref.shape
|
|
294
|
+
assert q_hbm_ref.shape[-1] == kv_cache_hbm_ref.shape[-1]
|
|
295
|
+
(
|
|
296
|
+
actual_num_kv_heads,
|
|
297
|
+
max_num_tokens,
|
|
298
|
+
num_q_heads_per_kv_head_per_packing,
|
|
299
|
+
q_packing,
|
|
300
|
+
actual_head_dim_x2,
|
|
301
|
+
) = q_hbm_ref.shape
|
|
302
|
+
(
|
|
303
|
+
total_num_pages,
|
|
304
|
+
page_size,
|
|
305
|
+
num_kv_heads_per_kv_packing,
|
|
306
|
+
kv_packing,
|
|
307
|
+
_,
|
|
308
|
+
) = kv_cache_hbm_ref.shape
|
|
309
|
+
max_num_seqs = kv_lens_ref.shape[0]
|
|
310
|
+
num_page_indices = page_indices_ref.shape[0]
|
|
311
|
+
assert num_page_indices % max_num_seqs == 0
|
|
312
|
+
pages_per_seq = num_page_indices // max_num_seqs
|
|
313
|
+
num_kv_heads = num_kv_heads_per_kv_packing * kv_packing
|
|
314
|
+
num_q_heads_per_kv_head = num_q_heads_per_kv_head_per_packing * q_packing
|
|
315
|
+
q_dtype = q_hbm_ref.dtype
|
|
316
|
+
kv_dtype = kv_cache_hbm_ref.dtype
|
|
317
|
+
assert o_hbm_ref.dtype == q_dtype
|
|
318
|
+
assert get_dtype_packing(q_dtype) == q_packing
|
|
319
|
+
assert get_dtype_packing(kv_dtype) == kv_packing
|
|
320
|
+
assert actual_head_dim_x2 == 128
|
|
321
|
+
bkv_sz = bkv_p * page_size
|
|
322
|
+
seq_idx = pl.program_id(0)
|
|
323
|
+
num_seqs = pl.num_programs(0)
|
|
324
|
+
decode_end = distribution_ref[0]
|
|
325
|
+
prefill_end = distribution_ref[1]
|
|
326
|
+
mixed_end = distribution_ref[2]
|
|
327
|
+
|
|
328
|
+
q_start = cu_q_lens_ref[seq_idx]
|
|
329
|
+
q_end = cu_q_lens_ref[seq_idx + 1]
|
|
330
|
+
q_len = q_end - q_start
|
|
331
|
+
kv_len = kv_lens_ref[seq_idx]
|
|
332
|
+
|
|
333
|
+
if sliding_window is None:
|
|
334
|
+
bkv_idx_start = next_seq_bkv_idx_start = 0
|
|
335
|
+
else:
|
|
336
|
+
bkv_idx_start = jnp.maximum(kv_len - q_len - sliding_window,
|
|
337
|
+
0) // bkv_sz
|
|
338
|
+
|
|
339
|
+
# If seq_idx + 1 == num_seqs, kv_lens_ref[seq_idx + 1] will trigger a
|
|
340
|
+
# out-of-bound error. To avoid this, we set upperbound of next_seq_idx
|
|
341
|
+
# to be num_seqs - 1.
|
|
342
|
+
next_seq_idx = jnp.minimum(seq_idx + 1, num_seqs - 1)
|
|
343
|
+
next_kv_len = kv_lens_ref[next_seq_idx]
|
|
344
|
+
next_q_len = cu_q_lens_ref[next_seq_idx + 1] - q_end
|
|
345
|
+
next_seq_bkv_idx_start = (
|
|
346
|
+
jnp.maximum(next_kv_len - next_q_len - sliding_window, 0) //
|
|
347
|
+
bkv_sz)
|
|
348
|
+
|
|
349
|
+
def debug_print(msg, *args):
|
|
350
|
+
if debug_mode:
|
|
351
|
+
pl.debug_print(msg, *args)
|
|
352
|
+
|
|
353
|
+
debug_print("[RPA debug] ======= In loop seq_idx={}", seq_idx)
|
|
354
|
+
debug_print("[RPA debug] num_seqs={}", num_seqs)
|
|
355
|
+
debug_print("[RPA debug] decode_end={}", decode_end)
|
|
356
|
+
debug_print("[RPA debug] prefill_end={}", prefill_end)
|
|
357
|
+
debug_print("[RPA debug] mixed_end={}", mixed_end)
|
|
358
|
+
debug_print("[RPA debug] bkv_p={}", bkv_p)
|
|
359
|
+
debug_print("[RPA debug] page_size={}", page_size)
|
|
360
|
+
debug_print("[RPA debug] pages_per_seq={}", pages_per_seq)
|
|
361
|
+
debug_print("[RPA debug] bkv_sz={}", bkv_sz)
|
|
362
|
+
debug_print("[RPA debug] bq_sz={}", bq_sz)
|
|
363
|
+
debug_print("[RPA debug] q_start={}", q_start)
|
|
364
|
+
debug_print("[RPA debug] q_end={}", q_end)
|
|
365
|
+
debug_print("[RPA debug] q_len={}", q_len)
|
|
366
|
+
debug_print("[RPA debug] kv_len={}", kv_len)
|
|
367
|
+
|
|
368
|
+
def flash_attention_step1_qk_softmax(
|
|
369
|
+
q, # [actual_bq_sz * num_q_heads_per_kv_head, actual_head_dim_x2]
|
|
370
|
+
kv, # [bkv_sz, actual_head_dim_x2]
|
|
371
|
+
*,
|
|
372
|
+
bq_idx,
|
|
373
|
+
bkv_idx,
|
|
374
|
+
kv_head_idx,
|
|
375
|
+
):
|
|
376
|
+
assert len(q.shape) == 2
|
|
377
|
+
assert q.shape[0] % num_q_heads_per_kv_head == 0
|
|
378
|
+
assert q.shape[1] == actual_head_dim_x2
|
|
379
|
+
assert kv.shape == (bkv_sz, actual_head_dim_x2)
|
|
380
|
+
head_l_ref = l_ref.at[kv_head_idx, :q.shape[0]]
|
|
381
|
+
head_m_ref = m_ref.at[kv_head_idx, :q.shape[0]]
|
|
382
|
+
|
|
383
|
+
def load_with_init(ref, init_val):
|
|
384
|
+
return jnp.where(bkv_idx == bkv_idx_start,
|
|
385
|
+
jnp.full_like(ref, init_val), ref[...])
|
|
386
|
+
|
|
387
|
+
# Follow FlashAttention-2 forward pass.
|
|
388
|
+
if q_scale is not None:
|
|
389
|
+
q = q / q_scale
|
|
390
|
+
if jnp.issubdtype(kv.dtype, jnp.floating):
|
|
391
|
+
dtype_info = jnp.finfo(kv.dtype)
|
|
392
|
+
minval = float(dtype_info.min)
|
|
393
|
+
maxval = float(dtype_info.max)
|
|
394
|
+
q = jnp.clip(q, min=minval, max=maxval)
|
|
395
|
+
q = q.astype(kv.dtype)
|
|
396
|
+
|
|
397
|
+
s = jnp.einsum("nd,md->nm", q, kv, preferred_element_type=jnp.float32)
|
|
398
|
+
s *= sm_scale
|
|
399
|
+
if k_scale is not None:
|
|
400
|
+
s *= k_scale
|
|
401
|
+
if q_scale is not None:
|
|
402
|
+
s *= q_scale
|
|
403
|
+
if soft_cap is not None:
|
|
404
|
+
s = soft_cap * jnp.tanh(s / soft_cap)
|
|
405
|
+
|
|
406
|
+
q_span = (kv_len - q_len + bq_idx * bq_sz +
|
|
407
|
+
lax.broadcasted_iota(jnp.int32, s.shape, 0) //
|
|
408
|
+
num_q_heads_per_kv_head)
|
|
409
|
+
k_span = bkv_idx * bkv_sz + lax.broadcasted_iota(jnp.int32, s.shape, 1)
|
|
410
|
+
mask = k_span <= q_span
|
|
411
|
+
|
|
412
|
+
if sliding_window is not None:
|
|
413
|
+
mask = jnp.logical_and(mask, q_span - sliding_window < k_span)
|
|
414
|
+
|
|
415
|
+
s = jnp.where(mask, s, mask_value)
|
|
416
|
+
s_rowmax = jnp.max(s, axis=1, keepdims=True)
|
|
417
|
+
|
|
418
|
+
if attention_sink_ref is not None:
|
|
419
|
+
sinks = attention_sink_ref[kv_head_idx]
|
|
420
|
+
actual_bq_sz = q.shape[0] // num_q_heads_per_kv_head
|
|
421
|
+
m_prev_init = jnp.concat([sinks] * actual_bq_sz, axis=0)
|
|
422
|
+
m_prev = jnp.where(bkv_idx == bkv_idx_start, m_prev_init,
|
|
423
|
+
head_m_ref[...])
|
|
424
|
+
else:
|
|
425
|
+
m_prev = load_with_init(head_m_ref, -jnp.inf)
|
|
426
|
+
|
|
427
|
+
m_curr = jnp.maximum(m_prev, s_rowmax)
|
|
428
|
+
head_m_ref[...] = m_curr
|
|
429
|
+
p = jnp.exp(s - broadcast_minor(m_curr, s.shape))
|
|
430
|
+
|
|
431
|
+
p_rowsum = jnp.sum(p, axis=1, keepdims=True)
|
|
432
|
+
exp_m_diff = jnp.exp(m_prev - m_curr)
|
|
433
|
+
l_prev = load_with_init(head_l_ref, 1.0)
|
|
434
|
+
l_curr = exp_m_diff * l_prev + p_rowsum
|
|
435
|
+
head_l_ref[...] = l_curr
|
|
436
|
+
|
|
437
|
+
return p, exp_m_diff
|
|
438
|
+
|
|
439
|
+
def flash_attention_step2_pv(
|
|
440
|
+
q_shape_0,
|
|
441
|
+
kv, # [bkv_sz, actual_head_dim_x2]
|
|
442
|
+
p, # from step1
|
|
443
|
+
exp_m_diff, # from step1
|
|
444
|
+
*,
|
|
445
|
+
bkv_idx,
|
|
446
|
+
kv_head_idx,
|
|
447
|
+
):
|
|
448
|
+
head_acc_ref = acc_ref.at[kv_head_idx, :q_shape_0]
|
|
449
|
+
|
|
450
|
+
def load_with_init(ref, init_val):
|
|
451
|
+
return jnp.where(bkv_idx == bkv_idx_start,
|
|
452
|
+
jnp.full_like(ref, init_val), ref[...])
|
|
453
|
+
|
|
454
|
+
pv = jnp.einsum("nm,md->nd", p, kv, preferred_element_type=jnp.float32)
|
|
455
|
+
if v_scale is not None:
|
|
456
|
+
pv *= v_scale
|
|
457
|
+
|
|
458
|
+
o_prev = load_with_init(head_acc_ref, 0.0)
|
|
459
|
+
o_curr = broadcast_minor(exp_m_diff, o_prev.shape) * o_prev + pv
|
|
460
|
+
head_acc_ref[...] = o_curr
|
|
461
|
+
|
|
462
|
+
def _async_copy(src, dst, sem, wait):
|
|
463
|
+
if debug_mode:
|
|
464
|
+
# Skip DMA if debug mode is enabled.
|
|
465
|
+
return
|
|
466
|
+
cp = pltpu.make_async_copy(src, dst, sem)
|
|
467
|
+
if wait:
|
|
468
|
+
cp.wait()
|
|
469
|
+
else:
|
|
470
|
+
cp.start()
|
|
471
|
+
|
|
472
|
+
def _fetch_bkv(seq_idx, bkv_idx, bkv_sem_idx, *, wait=False):
|
|
473
|
+
sem = sems.at[0, bkv_sem_idx]
|
|
474
|
+
vmem_ref = bkv_x2_ref.at[bkv_sem_idx]
|
|
475
|
+
|
|
476
|
+
cache_hbm_shape = kv_cache_hbm_ref.shape
|
|
477
|
+
cache_hbm_ref = kv_cache_hbm_ref.reshape(
|
|
478
|
+
cache_hbm_shape[0] * cache_hbm_shape[1], *cache_hbm_shape[2:])
|
|
479
|
+
kv_len = kv_lens_ref[seq_idx]
|
|
480
|
+
kv_len_start = bkv_idx * bkv_sz
|
|
481
|
+
kv_p_start = bkv_idx * bkv_p
|
|
482
|
+
q_start = cu_q_lens_ref[seq_idx]
|
|
483
|
+
q_end = cu_q_lens_ref[seq_idx + 1]
|
|
484
|
+
q_len = q_end - q_start
|
|
485
|
+
|
|
486
|
+
kv_left = kv_len - kv_len_start
|
|
487
|
+
kv_left_frm_cache = jnp.maximum(kv_left - q_len, 0)
|
|
488
|
+
kv_left_frm_new = kv_left - kv_left_frm_cache
|
|
489
|
+
bkv_p_frm_cache = jnp.minimum(cdiv(kv_left_frm_cache, page_size),
|
|
490
|
+
bkv_p)
|
|
491
|
+
bkv_sz_frm_new = jnp.minimum(
|
|
492
|
+
jnp.maximum(bkv_sz - kv_left_frm_cache, 0), kv_left_frm_new)
|
|
493
|
+
page_indices_offset = seq_idx * pages_per_seq + kv_p_start
|
|
494
|
+
|
|
495
|
+
# Make sure the current bkv buffer is safe to overwrite.
|
|
496
|
+
wait_update_kv_cache(bkv_sem_idx)
|
|
497
|
+
|
|
498
|
+
debug_print(
|
|
499
|
+
"[RPA debug]"
|
|
500
|
+
f" -----------{'wait' if wait else 'start'}_fetch_bkv-----------")
|
|
501
|
+
debug_print("[RPA debug] seq_idx={}", seq_idx)
|
|
502
|
+
debug_print("[RPA debug] bkv_idx={}", bkv_idx)
|
|
503
|
+
debug_print("[RPA debug] bkv_sem_idx={}", bkv_sem_idx)
|
|
504
|
+
debug_print("[RPA debug] kv_len_start={}", kv_len_start)
|
|
505
|
+
debug_print("[RPA debug] kv_p_start={}", kv_p_start)
|
|
506
|
+
debug_print("[RPA debug] kv_left={}", kv_left)
|
|
507
|
+
debug_print("[RPA debug] kv_left_frm_cache={}", kv_left_frm_cache)
|
|
508
|
+
debug_print("[RPA debug] kv_left_frm_new={}", kv_left_frm_new)
|
|
509
|
+
debug_print("[RPA debug] bkv_p_frm_cache={}", bkv_p_frm_cache)
|
|
510
|
+
debug_print("[RPA debug] bkv_sz_frm_new={}", bkv_sz_frm_new)
|
|
511
|
+
debug_print("[RPA debug] page_indices_offset={}", page_indices_offset)
|
|
512
|
+
|
|
513
|
+
if not wait:
|
|
514
|
+
# Fetch effective kv from kv cache.
|
|
515
|
+
def loop_body(i, offset):
|
|
516
|
+
sz = jnp.minimum(page_size, kv_left_frm_cache - i * page_size)
|
|
517
|
+
_async_copy(
|
|
518
|
+
cache_hbm_ref.at[pl.ds(
|
|
519
|
+
page_indices_ref[page_indices_offset + i] * page_size,
|
|
520
|
+
sz)],
|
|
521
|
+
vmem_ref.at[pl.ds(i * page_size, sz)],
|
|
522
|
+
sem,
|
|
523
|
+
wait=False,
|
|
524
|
+
)
|
|
525
|
+
debug_print("[RPA debug] loop_body i={}, sz={}", i, sz)
|
|
526
|
+
return offset + sz
|
|
527
|
+
|
|
528
|
+
offset = lax.fori_loop(
|
|
529
|
+
0,
|
|
530
|
+
bkv_p_frm_cache,
|
|
531
|
+
loop_body,
|
|
532
|
+
0, # offset
|
|
533
|
+
unroll=False,
|
|
534
|
+
)
|
|
535
|
+
|
|
536
|
+
size = lax.select(bkv_sz_frm_new > 0, bkv_sz_frm_new, 0)
|
|
537
|
+
new_kv_len_start = q_end - kv_left_frm_new
|
|
538
|
+
debug_print("[RPA debug] new_kv_len_start={}", new_kv_len_start)
|
|
539
|
+
debug_print("[RPA debug] offset_in_bkv={}", offset)
|
|
540
|
+
_async_copy(
|
|
541
|
+
kv_hbm_ref.at[pl.ds(new_kv_len_start, size)],
|
|
542
|
+
vmem_ref.at[pl.ds(offset, size)],
|
|
543
|
+
sem,
|
|
544
|
+
wait,
|
|
545
|
+
)
|
|
546
|
+
|
|
547
|
+
return kv_len_start + offset, bkv_sz_frm_new
|
|
548
|
+
else:
|
|
549
|
+
offset = jnp.minimum(kv_left_frm_cache, page_size * bkv_p)
|
|
550
|
+
dst = vmem_ref.at[pl.ds(0, offset + bkv_sz_frm_new)]
|
|
551
|
+
_async_copy(
|
|
552
|
+
src=dst,
|
|
553
|
+
dst=dst,
|
|
554
|
+
sem=sem,
|
|
555
|
+
wait=True,
|
|
556
|
+
)
|
|
557
|
+
return kv_len_start + offset, bkv_sz_frm_new
|
|
558
|
+
|
|
559
|
+
def _update_kv_cache(seq_idx,
|
|
560
|
+
bkv_sem_idx,
|
|
561
|
+
offset,
|
|
562
|
+
update_sz,
|
|
563
|
+
*,
|
|
564
|
+
wait=False):
|
|
565
|
+
sem = sems.at[3, bkv_sem_idx]
|
|
566
|
+
vmem_ref = bkv_x2_ref.at[bkv_sem_idx]
|
|
567
|
+
bkv_id = offset // bkv_sz
|
|
568
|
+
kv_p_start = offset // page_size
|
|
569
|
+
kv_p_end = cdiv(offset + update_sz, page_size)
|
|
570
|
+
ignore = offset % page_size
|
|
571
|
+
p_ignore = kv_p_start - bkv_id * bkv_p
|
|
572
|
+
page_indices_offset = seq_idx * pages_per_seq + kv_p_start
|
|
573
|
+
|
|
574
|
+
cache_hbm_shape = updated_kv_cache_hbm_ref.shape
|
|
575
|
+
cache_hbm_ref = updated_kv_cache_hbm_ref.reshape(
|
|
576
|
+
cache_hbm_shape[0] * cache_hbm_shape[1], *cache_hbm_shape[2:])
|
|
577
|
+
|
|
578
|
+
debug_print(
|
|
579
|
+
"[RPA debug]"
|
|
580
|
+
f" -----------{'wait' if wait else 'start'}_update_kv_cache-----------"
|
|
581
|
+
)
|
|
582
|
+
debug_print("[RPA debug] seq_idx={}", seq_idx)
|
|
583
|
+
debug_print("[RPA debug] bkv_sem_idx={}", bkv_sem_idx)
|
|
584
|
+
debug_print("[RPA debug] offset={}", offset)
|
|
585
|
+
debug_print("[RPA debug] update_sz={}", update_sz)
|
|
586
|
+
debug_print("[RPA debug] bkv_id={}", bkv_id)
|
|
587
|
+
debug_print("[RPA debug] kv_p_start={}", kv_p_start)
|
|
588
|
+
debug_print("[RPA debug] kv_p_end={}", kv_p_end)
|
|
589
|
+
debug_print("[RPA debug] ignore={}", ignore)
|
|
590
|
+
debug_print("[RPA debug] p_ignore={}", p_ignore)
|
|
591
|
+
debug_print("[RPA debug] page_indices_offset={}", page_indices_offset)
|
|
592
|
+
|
|
593
|
+
if not wait:
|
|
594
|
+
|
|
595
|
+
def loop_body(i, states):
|
|
596
|
+
update_sz, ignore = states
|
|
597
|
+
sz = jnp.minimum(page_size - ignore, update_sz)
|
|
598
|
+
|
|
599
|
+
_async_copy(
|
|
600
|
+
vmem_ref.at[pl.ds((p_ignore + i) * page_size + ignore,
|
|
601
|
+
sz)],
|
|
602
|
+
cache_hbm_ref.at[pl.ds(
|
|
603
|
+
page_indices_ref[page_indices_offset + i] * page_size +
|
|
604
|
+
ignore,
|
|
605
|
+
sz,
|
|
606
|
+
)],
|
|
607
|
+
sem,
|
|
608
|
+
wait=False,
|
|
609
|
+
)
|
|
610
|
+
debug_print("[RPA debug] loop_body i={}, sz={}", i, sz)
|
|
611
|
+
return update_sz - sz, 0
|
|
612
|
+
|
|
613
|
+
lax.fori_loop(
|
|
614
|
+
0,
|
|
615
|
+
kv_p_end - kv_p_start,
|
|
616
|
+
loop_body,
|
|
617
|
+
(update_sz, ignore), # total transfer size
|
|
618
|
+
unroll=False,
|
|
619
|
+
)
|
|
620
|
+
else:
|
|
621
|
+
dst = cache_hbm_ref.at[pl.ds(0, update_sz)]
|
|
622
|
+
_async_copy(
|
|
623
|
+
src=dst,
|
|
624
|
+
dst=dst,
|
|
625
|
+
sem=sem,
|
|
626
|
+
wait=True,
|
|
627
|
+
)
|
|
628
|
+
|
|
629
|
+
def _fetch_bq(seq_idx, bq_idx, bq_sem_idx, *, wait=False):
|
|
630
|
+
sem = sems.at[1, bq_sem_idx]
|
|
631
|
+
vmem_ref = bq_x2_ref.at[bq_sem_idx]
|
|
632
|
+
q_len_start = cu_q_lens_ref[seq_idx] + bq_idx * bq_sz
|
|
633
|
+
q_end = cu_q_lens_ref[seq_idx + 1]
|
|
634
|
+
sz = jnp.minimum(bq_sz, q_end - q_len_start)
|
|
635
|
+
|
|
636
|
+
debug_print(
|
|
637
|
+
"[RPA debug]"
|
|
638
|
+
f" -----------{'wait' if wait else 'start'}_fetch_bq-----------")
|
|
639
|
+
debug_print("[RPA debug] seq_idx={}", seq_idx)
|
|
640
|
+
debug_print("[RPA debug] bq_idx={}", bq_idx)
|
|
641
|
+
debug_print("[RPA debug] bq_sem_idx={}", bq_sem_idx)
|
|
642
|
+
debug_print("[RPA debug] q_len_start={}", q_len_start)
|
|
643
|
+
debug_print("[RPA debug] q_end={}", q_end)
|
|
644
|
+
debug_print("[RPA debug] sz={}", sz)
|
|
645
|
+
|
|
646
|
+
_async_copy(
|
|
647
|
+
q_hbm_ref.at[:, pl.ds(q_len_start, sz)],
|
|
648
|
+
vmem_ref.at[:, pl.ds(0, sz)],
|
|
649
|
+
sem,
|
|
650
|
+
wait,
|
|
651
|
+
)
|
|
652
|
+
|
|
653
|
+
def _send_bo(seq_idx, bo_idx, bo_sem_idx, *, wait=False):
|
|
654
|
+
sem = sems.at[2, bo_sem_idx]
|
|
655
|
+
vmem_ref = bo_x2_ref.at[bo_sem_idx]
|
|
656
|
+
q_len_start = cu_q_lens_ref[seq_idx] + bo_idx * bq_sz
|
|
657
|
+
q_end = cu_q_lens_ref[seq_idx + 1]
|
|
658
|
+
sz = jnp.minimum(bq_sz, q_end - q_len_start)
|
|
659
|
+
|
|
660
|
+
debug_print(
|
|
661
|
+
"[RPA debug]"
|
|
662
|
+
f" -----------{'wait' if wait else 'start'}_send_bo-----------")
|
|
663
|
+
debug_print("[RPA debug] seq_idx={}", seq_idx)
|
|
664
|
+
debug_print("[RPA debug] bo_idx={}", bo_idx)
|
|
665
|
+
debug_print("[RPA debug] bo_sem_idx={}", bo_sem_idx)
|
|
666
|
+
debug_print("[RPA debug] q_len_start={}", q_len_start)
|
|
667
|
+
debug_print("[RPA debug] q_end={}", q_end)
|
|
668
|
+
debug_print("[RPA debug] sz={}", sz)
|
|
669
|
+
|
|
670
|
+
_async_copy(
|
|
671
|
+
vmem_ref.at[:, pl.ds(0, sz)],
|
|
672
|
+
o_hbm_ref.at[:, pl.ds(q_len_start, sz)],
|
|
673
|
+
sem,
|
|
674
|
+
wait,
|
|
675
|
+
)
|
|
676
|
+
|
|
677
|
+
def start_fetch_bkv(seq_idx, bkv_idx, bkv_sem_idx):
|
|
678
|
+
return _fetch_bkv(seq_idx, bkv_idx, bkv_sem_idx)
|
|
679
|
+
|
|
680
|
+
def wait_fetch_bkv(seq_idx, bkv_idx, bkv_sem_idx):
|
|
681
|
+
return _fetch_bkv(seq_idx, bkv_idx, bkv_sem_idx, wait=True)
|
|
682
|
+
|
|
683
|
+
def start_fetch_bq(seq_idx, bq_idx, bq_sem_idx):
|
|
684
|
+
return _fetch_bq(seq_idx, bq_idx, bq_sem_idx)
|
|
685
|
+
|
|
686
|
+
def wait_fetch_bq(seq_idx, bq_idx, bq_sem_idx):
|
|
687
|
+
return _fetch_bq(seq_idx, bq_idx, bq_sem_idx, wait=True)
|
|
688
|
+
|
|
689
|
+
def start_send_bo(seq_idx, bo_idx, bo_sem_idx):
|
|
690
|
+
bo_ids_ref[bo_sem_idx] = seq_idx
|
|
691
|
+
bo_ids_ref[bo_sem_idx + 2] = bo_idx
|
|
692
|
+
_send_bo(seq_idx, bo_idx, bo_sem_idx)
|
|
693
|
+
|
|
694
|
+
def wait_send_bo(bo_sem_idx):
|
|
695
|
+
old_seq_idx = bo_ids_ref[bo_sem_idx]
|
|
696
|
+
old_bo_idx = bo_ids_ref[bo_sem_idx + 2]
|
|
697
|
+
|
|
698
|
+
@pl.when(jnp.logical_and(0 <= old_seq_idx, old_seq_idx <= seq_idx))
|
|
699
|
+
def _():
|
|
700
|
+
_send_bo(old_seq_idx, old_bo_idx, bo_sem_idx, wait=True)
|
|
701
|
+
|
|
702
|
+
def start_update_kv_cache(seq_idx, bkv_sem_idx, offset, update_sz):
|
|
703
|
+
bkv_update_ids_ref[bkv_sem_idx] = seq_idx
|
|
704
|
+
bkv_update_ids_ref[bkv_sem_idx + 2] = offset
|
|
705
|
+
bkv_update_ids_ref[bkv_sem_idx + 4] = update_sz
|
|
706
|
+
_update_kv_cache(seq_idx, bkv_sem_idx, offset, update_sz)
|
|
707
|
+
|
|
708
|
+
def wait_update_kv_cache(bkv_sem_idx):
|
|
709
|
+
update_sz = bkv_update_ids_ref[bkv_sem_idx + 4]
|
|
710
|
+
|
|
711
|
+
@pl.when(update_sz > 0)
|
|
712
|
+
def _():
|
|
713
|
+
seq_idx = bkv_update_ids_ref[bkv_sem_idx]
|
|
714
|
+
offset = bkv_update_ids_ref[bkv_sem_idx + 2]
|
|
715
|
+
bkv_update_ids_ref[bkv_sem_idx + 4] = 0
|
|
716
|
+
_update_kv_cache(seq_idx,
|
|
717
|
+
bkv_sem_idx,
|
|
718
|
+
offset,
|
|
719
|
+
update_sz,
|
|
720
|
+
wait=True)
|
|
721
|
+
|
|
722
|
+
def load_bq(bq_sem_idx, kv_head_idx, *, actual_bq_sz=bq_sz):
|
|
723
|
+
q_ref = (bq_x2_ref.bitcast(
|
|
724
|
+
jnp.uint32).at[bq_sem_idx, kv_head_idx].reshape(
|
|
725
|
+
bq_sz * num_q_heads_per_kv_head_per_packing,
|
|
726
|
+
actual_head_dim_x2))
|
|
727
|
+
return pltpu.bitcast(
|
|
728
|
+
q_ref[:actual_bq_sz * num_q_heads_per_kv_head_per_packing],
|
|
729
|
+
q_dtype)
|
|
730
|
+
|
|
731
|
+
def strided_load(ref, start, step):
|
|
732
|
+
assert get_dtype_packing(ref.dtype) == 1
|
|
733
|
+
assert len(ref.shape) == 2
|
|
734
|
+
_, l = ref.shape # noqa
|
|
735
|
+
assert l == 128
|
|
736
|
+
vec = ref[start::step]
|
|
737
|
+
return vec
|
|
738
|
+
|
|
739
|
+
def strided_load_bkv(bkv_sem_idx, start, step):
|
|
740
|
+
assert start % kv_packing == 0
|
|
741
|
+
assert step % kv_packing == 0
|
|
742
|
+
start //= kv_packing
|
|
743
|
+
step //= kv_packing
|
|
744
|
+
kv_ref = (bkv_x2_ref.bitcast(jnp.uint32).at[bkv_sem_idx].reshape(
|
|
745
|
+
bkv_sz * step, actual_head_dim_x2))
|
|
746
|
+
|
|
747
|
+
kv = strided_load(kv_ref, start, step)
|
|
748
|
+
bitwidth = 32 // kv_packing
|
|
749
|
+
repack_ty = jnp.dtype(f"uint{bitwidth}")
|
|
750
|
+
lst = []
|
|
751
|
+
for i in range(0, kv_packing):
|
|
752
|
+
cur_kv = pltpu.bitcast((kv >> (i * bitwidth)).astype(repack_ty),
|
|
753
|
+
kv_dtype)
|
|
754
|
+
lst.append(cur_kv)
|
|
755
|
+
return lst
|
|
756
|
+
|
|
757
|
+
def broadcast_minor(src, shape):
|
|
758
|
+
if src.shape == shape:
|
|
759
|
+
return src
|
|
760
|
+
assert src.shape[:-1] == shape[:-1]
|
|
761
|
+
assert src.shape[-1] % 128 == 0
|
|
762
|
+
target_minor = align_to(shape[-1], src.shape[-1])
|
|
763
|
+
# no-op concatenation.
|
|
764
|
+
return jnp.concatenate(
|
|
765
|
+
[src for _ in range(target_minor // src.shape[-1])],
|
|
766
|
+
axis=-1)[..., :shape[-1]]
|
|
767
|
+
|
|
768
|
+
def process(static_q_len=None):
|
|
769
|
+
num_bkv = cdiv(kv_len, bkv_sz)
|
|
770
|
+
if static_q_len is None:
|
|
771
|
+
actual_bq_sz = bq_sz
|
|
772
|
+
num_bq = cdiv(q_len, actual_bq_sz)
|
|
773
|
+
else:
|
|
774
|
+
actual_bq_sz = min(bq_sz, static_q_len)
|
|
775
|
+
num_bq = cdiv(static_q_len, actual_bq_sz)
|
|
776
|
+
|
|
777
|
+
def get_next_bq_ids(seq_idx, bq_idx, bq_sem_idx):
|
|
778
|
+
next_bq_idx = bq_idx + 1
|
|
779
|
+
is_last_bq = next_bq_idx == num_bq
|
|
780
|
+
next_bq_idx = lax.select(is_last_bq, 0, next_bq_idx)
|
|
781
|
+
next_seq_idx = lax.select(is_last_bq, seq_idx + 1, seq_idx)
|
|
782
|
+
next_bq_sem_idx = lax.select(bq_sem_idx == 0, 1, 0)
|
|
783
|
+
return next_seq_idx, next_bq_idx, next_bq_sem_idx
|
|
784
|
+
|
|
785
|
+
def get_next_bkv_ids(seq_idx, bq_idx, bkv_idx, bkv_sem_idx):
|
|
786
|
+
next_bkv_idx = bkv_idx + 1
|
|
787
|
+
is_last_bkv = next_bkv_idx == num_bkv
|
|
788
|
+
next_bq_idx = lax.select(is_last_bkv, bq_idx + 1, bq_idx)
|
|
789
|
+
is_last_bq = next_bq_idx == num_bq
|
|
790
|
+
next_bq_idx = lax.select(is_last_bq, 0, next_bq_idx)
|
|
791
|
+
next_seq_idx = lax.select(is_last_bq, seq_idx + 1, seq_idx)
|
|
792
|
+
next_bkv_sem_idx = lax.select(bkv_sem_idx == 0, 1, 0)
|
|
793
|
+
|
|
794
|
+
if sliding_window is None:
|
|
795
|
+
# When sliding window is disabled, starting bkv_idx of next request is
|
|
796
|
+
# always 0 regardless of seq_idx of next request.
|
|
797
|
+
next_bkv_idx_start = 0
|
|
798
|
+
else:
|
|
799
|
+
# Determine starting bkv_idx of next request based on whether next
|
|
800
|
+
# request is from the same sequence or next sequence.
|
|
801
|
+
next_bkv_idx_start = lax.select(
|
|
802
|
+
is_last_bq,
|
|
803
|
+
next_seq_bkv_idx_start,
|
|
804
|
+
bkv_idx_start,
|
|
805
|
+
)
|
|
806
|
+
next_bkv_idx = lax.select(is_last_bkv, next_bkv_idx_start,
|
|
807
|
+
next_bkv_idx)
|
|
808
|
+
|
|
809
|
+
return next_seq_idx, next_bq_idx, next_bkv_idx, next_bkv_sem_idx
|
|
810
|
+
|
|
811
|
+
def compute_with_bq(bq_idx, _):
|
|
812
|
+
bq_sem_idx = sem_ids_ref[0]
|
|
813
|
+
next_seq_idx, next_bq_idx, next_bq_sem_idx = get_next_bq_ids(
|
|
814
|
+
seq_idx, bq_idx, bq_sem_idx)
|
|
815
|
+
|
|
816
|
+
# Prefetch next bq
|
|
817
|
+
@pl.when(next_seq_idx < num_seqs)
|
|
818
|
+
def prefetch_next_bq():
|
|
819
|
+
sem_ids_ref[0] = next_bq_sem_idx
|
|
820
|
+
start_fetch_bq(next_seq_idx, next_bq_idx, next_bq_sem_idx)
|
|
821
|
+
|
|
822
|
+
def compute_with_bkv(bkv_idx, _):
|
|
823
|
+
# Create bitmask for KV.
|
|
824
|
+
assert bkv_sz % kv_packing == 0
|
|
825
|
+
|
|
826
|
+
# Get next bkv ids.
|
|
827
|
+
bkv_sem_idx = sem_ids_ref[1]
|
|
828
|
+
next_seq_idx, _, next_bkv_idx, next_bkv_sem_idx = get_next_bkv_ids(
|
|
829
|
+
seq_idx, bq_idx, bkv_idx, bkv_sem_idx)
|
|
830
|
+
|
|
831
|
+
# Prefetch next bkv
|
|
832
|
+
@pl.when(next_seq_idx < num_seqs)
|
|
833
|
+
def prefetch_next_bkv():
|
|
834
|
+
sem_ids_ref[1] = next_bkv_sem_idx
|
|
835
|
+
start_fetch_bkv(next_seq_idx, next_bkv_idx,
|
|
836
|
+
next_bkv_sem_idx)
|
|
837
|
+
|
|
838
|
+
# Wait for cur bq if not ready yet
|
|
839
|
+
@pl.when(bkv_idx == bkv_idx_start)
|
|
840
|
+
def wait_cur_bq():
|
|
841
|
+
wait_fetch_bq(seq_idx, bq_idx, bq_sem_idx)
|
|
842
|
+
|
|
843
|
+
# Wait for cur bkv
|
|
844
|
+
offset, update_sz = wait_fetch_bkv(seq_idx, bkv_idx,
|
|
845
|
+
bkv_sem_idx)
|
|
846
|
+
|
|
847
|
+
# Start updating bkv to kv cache if applicable.
|
|
848
|
+
# Only needed in first bq loop.
|
|
849
|
+
@pl.when(jnp.logical_and(update_sz > 0, bq_idx == 0))
|
|
850
|
+
def update_cur_bkv_to_cache():
|
|
851
|
+
start_update_kv_cache(seq_idx, bkv_sem_idx, offset,
|
|
852
|
+
update_sz)
|
|
853
|
+
|
|
854
|
+
debug_print(
|
|
855
|
+
"[RPA debug] -----------flash attention-----------")
|
|
856
|
+
debug_print("[RPA debug] seq_idx={}", seq_idx)
|
|
857
|
+
debug_print("[RPA debug] bq_idx={}", bq_idx)
|
|
858
|
+
debug_print("[RPA debug] bkv_idx={}", bkv_idx)
|
|
859
|
+
if debug_mode:
|
|
860
|
+
# Skip flash attention if debug mode is enabled.
|
|
861
|
+
return
|
|
862
|
+
|
|
863
|
+
# Flash attention with cur bkv and bq
|
|
864
|
+
prev_bq_shape_0 = None
|
|
865
|
+
prev_kv_head_bkv = None
|
|
866
|
+
prev_kv_head_idx = None
|
|
867
|
+
prev_kv_head_p = None
|
|
868
|
+
prev_kv_head_exp_m_diff = None
|
|
869
|
+
for kv_head_start in range(0, actual_num_kv_heads, kv_packing):
|
|
870
|
+
bkv_lst = strided_load_bkv(
|
|
871
|
+
bkv_sem_idx,
|
|
872
|
+
kv_head_start,
|
|
873
|
+
num_kv_heads,
|
|
874
|
+
)
|
|
875
|
+
assert len(bkv_lst) == kv_packing
|
|
876
|
+
for i in range(kv_packing):
|
|
877
|
+
cur_kv_head_idx = kv_head_start + i
|
|
878
|
+
if cur_kv_head_idx >= actual_num_kv_heads:
|
|
879
|
+
break
|
|
880
|
+
cur_kv_head_bq = load_bq(bq_sem_idx,
|
|
881
|
+
cur_kv_head_idx,
|
|
882
|
+
actual_bq_sz=actual_bq_sz)
|
|
883
|
+
cur_kv_head__bkv = bkv_lst[i]
|
|
884
|
+
# FlashAttention is divided into `flash_attention_step1_qk_softmax`
|
|
885
|
+
# and `flash_attention_step2_pv` to pipeline the computation.
|
|
886
|
+
# `step2_pv` for the previous KV head, which depends on the softmax
|
|
887
|
+
# output, is overlapped with `step1_qk_softmax` for the current KV
|
|
888
|
+
# head, reducing overall wait times.
|
|
889
|
+
cur_kv_head_p, cur_kv_head_exp_m_diff = (
|
|
890
|
+
flash_attention_step1_qk_softmax(
|
|
891
|
+
cur_kv_head_bq,
|
|
892
|
+
cur_kv_head__bkv,
|
|
893
|
+
bq_idx=bq_idx,
|
|
894
|
+
bkv_idx=bkv_idx,
|
|
895
|
+
kv_head_idx=cur_kv_head_idx,
|
|
896
|
+
))
|
|
897
|
+
if prev_bq_shape_0 is not None:
|
|
898
|
+
flash_attention_step2_pv(
|
|
899
|
+
prev_bq_shape_0,
|
|
900
|
+
prev_kv_head_bkv,
|
|
901
|
+
prev_kv_head_p,
|
|
902
|
+
prev_kv_head_exp_m_diff,
|
|
903
|
+
bkv_idx=bkv_idx,
|
|
904
|
+
kv_head_idx=prev_kv_head_idx,
|
|
905
|
+
)
|
|
906
|
+
prev_bq_shape_0 = cur_kv_head_bq.shape[0]
|
|
907
|
+
prev_kv_head_bkv = cur_kv_head__bkv
|
|
908
|
+
prev_kv_head_p = cur_kv_head_p
|
|
909
|
+
prev_kv_head_exp_m_diff = cur_kv_head_exp_m_diff
|
|
910
|
+
prev_kv_head_idx = cur_kv_head_idx
|
|
911
|
+
|
|
912
|
+
# Execute pv of last attention head.
|
|
913
|
+
assert prev_bq_shape_0 is not None
|
|
914
|
+
flash_attention_step2_pv(
|
|
915
|
+
prev_bq_shape_0,
|
|
916
|
+
prev_kv_head_bkv,
|
|
917
|
+
prev_kv_head_p,
|
|
918
|
+
prev_kv_head_exp_m_diff,
|
|
919
|
+
bkv_idx=bkv_idx,
|
|
920
|
+
kv_head_idx=prev_kv_head_idx,
|
|
921
|
+
)
|
|
922
|
+
|
|
923
|
+
lax.fori_loop(bkv_idx_start,
|
|
924
|
+
num_bkv,
|
|
925
|
+
compute_with_bkv,
|
|
926
|
+
None,
|
|
927
|
+
unroll=False)
|
|
928
|
+
|
|
929
|
+
# Load acc and calculate final output.
|
|
930
|
+
acc = acc_ref[...]
|
|
931
|
+
l = broadcast_minor(l_ref[...], acc.shape) # noqa
|
|
932
|
+
out = (lax.div(acc, l) if q_dtype == jnp.float32 else
|
|
933
|
+
(acc * pl.reciprocal(l, approx=True)).astype(q_dtype))
|
|
934
|
+
|
|
935
|
+
# Wait for previous bo to be fully sent before storing new bo.
|
|
936
|
+
bo_sem_idx = sem_ids_ref[2]
|
|
937
|
+
sem_ids_ref[2] = lax.select(bo_sem_idx == 0, 1, 0)
|
|
938
|
+
wait_send_bo(bo_sem_idx)
|
|
939
|
+
|
|
940
|
+
# Store output from acc to bo.
|
|
941
|
+
bo_x2_ref.at[bo_sem_idx].bitcast(jnp.int32).reshape(
|
|
942
|
+
actual_num_kv_heads,
|
|
943
|
+
bq_sz * num_q_heads_per_kv_head_per_packing,
|
|
944
|
+
actual_head_dim_x2,
|
|
945
|
+
)[...] = pltpu.bitcast(out, jnp.int32)
|
|
946
|
+
|
|
947
|
+
# Send cur bo
|
|
948
|
+
start_send_bo(seq_idx, bq_idx, bo_sem_idx)
|
|
949
|
+
|
|
950
|
+
lax.fori_loop(0, num_bq, compute_with_bq, None, unroll=False)
|
|
951
|
+
|
|
952
|
+
### ------- Kernel start ------- ###
|
|
953
|
+
|
|
954
|
+
@pl.when(seq_idx == 0)
|
|
955
|
+
def prologue():
|
|
956
|
+
start_fetch_bq(0, 0, 0)
|
|
957
|
+
|
|
958
|
+
# Initialize bkv_x2_ref to zeros to avoid NaN issues from accessing
|
|
959
|
+
# uninitialized memory. Bitcast into int32 to avoid tiling issues.
|
|
960
|
+
bkv_x2_int32_ref = bkv_x2_ref.bitcast(jnp.int32).reshape(
|
|
961
|
+
(2, -1, 8, 128))
|
|
962
|
+
zeros = jnp.zeros(bkv_x2_int32_ref.shape[1:], jnp.int32)
|
|
963
|
+
|
|
964
|
+
# To pipeline VST and DMA, we divide the initialization into two steps.
|
|
965
|
+
bkv_x2_int32_ref[0] = zeros
|
|
966
|
+
start_fetch_bkv(0, bkv_idx_start, 0)
|
|
967
|
+
bkv_x2_int32_ref[1] = zeros
|
|
968
|
+
|
|
969
|
+
@pl.when(seq_idx < decode_end)
|
|
970
|
+
def process_decode():
|
|
971
|
+
process(static_q_len=1)
|
|
972
|
+
|
|
973
|
+
@pl.when(jnp.logical_and(decode_end <= seq_idx, seq_idx < prefill_end))
|
|
974
|
+
def process_prefill():
|
|
975
|
+
process(static_q_len=chunk_prefill_size)
|
|
976
|
+
|
|
977
|
+
@pl.when(jnp.logical_and(prefill_end <= seq_idx, seq_idx < mixed_end))
|
|
978
|
+
def process_mixed():
|
|
979
|
+
process()
|
|
980
|
+
|
|
981
|
+
@pl.when(seq_idx == num_seqs - 1)
|
|
982
|
+
def epilogue():
|
|
983
|
+
for i in range(2):
|
|
984
|
+
wait_send_bo(i)
|
|
985
|
+
wait_update_kv_cache(i)
|
|
986
|
+
|
|
987
|
+
### ------- Kernel end ------- ###
|
|
988
|
+
|
|
989
|
+
|
|
990
|
+
def merge_kv(
|
|
991
|
+
k: jax.
|
|
992
|
+
Array, # [max_num_tokens, actual_num_kv_heads, actual_head_dim],
|
|
993
|
+
v: jax.
|
|
994
|
+
Array, # [max_num_tokens, actual_num_kv_heads, actual_head_dim],
|
|
995
|
+
):
|
|
996
|
+
assert k.shape == v.shape
|
|
997
|
+
assert k.dtype == v.dtype
|
|
998
|
+
max_num_tokens, actual_num_kv_heads, actual_head_dim = k.shape
|
|
999
|
+
kv_packing = get_dtype_packing(k.dtype)
|
|
1000
|
+
num_kv_heads = align_to(actual_num_kv_heads, kv_packing)
|
|
1001
|
+
kv = jnp.pad(
|
|
1002
|
+
jnp.concat([k, v], axis=-1),
|
|
1003
|
+
(
|
|
1004
|
+
(0, 0),
|
|
1005
|
+
(0, num_kv_heads - actual_num_kv_heads),
|
|
1006
|
+
(0, 0),
|
|
1007
|
+
),
|
|
1008
|
+
constant_values=0,
|
|
1009
|
+
).reshape(
|
|
1010
|
+
max_num_tokens,
|
|
1011
|
+
num_kv_heads // kv_packing,
|
|
1012
|
+
kv_packing,
|
|
1013
|
+
actual_head_dim * 2,
|
|
1014
|
+
)
|
|
1015
|
+
return kv
|
|
1016
|
+
|
|
1017
|
+
|
|
1018
|
+
def prepare_inputs(
    q: jax.Array,  # [max_num_tokens, actual_num_q_heads, actual_head_dim]
    k: jax.Array,  # [max_num_tokens, actual_num_kv_heads, actual_head_dim]
    v: jax.Array,  # [max_num_tokens, actual_num_kv_heads, actual_head_dim]
    attention_sink: jax.Array | None = None,  # f32[actual_num_q_heads]
):
    max_num_tokens, actual_num_q_heads, actual_head_dim = q.shape
    actual_num_kv_heads = k.shape[1]
    assert actual_num_q_heads % actual_num_kv_heads == 0
    actual_num_q_heads_per_kv_head = actual_num_q_heads // actual_num_kv_heads
    q_packing = get_dtype_packing(q.dtype)
    num_q_heads_per_kv_head = align_to(actual_num_q_heads_per_kv_head,
                                       q_packing)
    head_dim = align_to(actual_head_dim, 128)
    # Group q heads by their kv head, pad both the grouped q-head axis (up to
    # the packing factor) and head_dim (up to a full 128-lane tile), then fold
    # the packing into its own axis.
    q = (
        jnp.pad(
            q.reshape(
                max_num_tokens,
                actual_num_kv_heads,
                actual_num_q_heads_per_kv_head,
                actual_head_dim,
            ),
            (
                (0, 0),
                (0, 0),
                (0, num_q_heads_per_kv_head - actual_num_q_heads_per_kv_head),
                (0, head_dim - actual_head_dim),
            ),
            constant_values=0,
        ).reshape(
            max_num_tokens,
            actual_num_kv_heads,
            num_q_heads_per_kv_head // q_packing,
            q_packing,
            head_dim,
        )
        # TODO(jevinjiang): Explore fusing swapping non-tiling axis to DMA.
        .swapaxes(0, 1))
    # TODO(kyuyeunk, chengjiyao): Add kv quantization here.
    kv = merge_kv(k, v)

    if attention_sink is not None:
        attention_sink = attention_sink.reshape(
            (-1, num_q_heads_per_kv_head, 1))
        attention_sink = jnp.repeat(attention_sink, 128, -1)

    return q, kv, attention_sink
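
# Illustrative usage (assumed toy shapes): 8 q heads sharing 1 kv head in
# bf16 gives q_packing == 2, so the 8 q heads per kv head split into 4 packed
# pairs and head_dim is padded from 64 to a full 128-lane tile.
def _sketch_prepare_inputs_layout():
    q = jnp.zeros((16, 8, 64), jnp.bfloat16)
    k = jnp.zeros((16, 1, 64), jnp.bfloat16)
    v = jnp.zeros((16, 1, 64), jnp.bfloat16)
    q2, kv, sink = prepare_inputs(q, k, v)
    # [kv_heads, tokens, q_heads_per_kv_head // q_packing, q_packing, head_dim]
    assert q2.shape == (1, 16, 4, 2, 128)
    assert kv.shape == (16, 1, 2, 128)
    assert sink is None
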
def prepare_outputs(
    out,  # [actual_num_kv_heads, max_num_tokens, num_q_heads_per_kv_head // q_packing, q_packing, actual_head_dim_x2]
    actual_num_q_heads_per_kv_head: int,
    actual_head_dim: int,
):
    (
        actual_num_kv_heads,
        max_num_tokens,
        num_q_heads_per_kv_head_per_q_packing,
        q_packing,
        actual_head_dim_x2,
    ) = out.shape
    actual_num_q_heads = actual_num_q_heads_per_kv_head * actual_num_kv_heads
    # Undo the input layout: move tokens back to the leading axis, merge the
    # packing axis back into q heads, drop the padded q heads, and keep the
    # upper half of the doubled head_dim axis.
    return (out.swapaxes(0, 1).reshape(
        max_num_tokens,
        actual_num_kv_heads,
        num_q_heads_per_kv_head_per_q_packing * q_packing,
        actual_head_dim_x2,
    )[:, :, :actual_num_q_heads_per_kv_head,
      actual_head_dim:].reshape(max_num_tokens, actual_num_q_heads,
                                actual_head_dim))
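
# Illustrative round trip (assumed toy shapes): the kernel writes its output
# in the same layout as the prepared q, so prepare_outputs restores the
# original [tokens, q_heads, head_dim] view, inverting the layout built by
# prepare_inputs above.
def _sketch_prepare_outputs_roundtrip():
    out = jnp.zeros((1, 16, 4, 2, 128), jnp.bfloat16)  # kernel-layout output
    o = prepare_outputs(out,
                        actual_num_q_heads_per_kv_head=8,
                        actual_head_dim=64)
    assert o.shape == (16, 8, 64)  # back to [tokens, q_heads, head_dim]
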
# Expect to run this validation during runtime.
def dynamic_validate_inputs(
    queries: jax.Array,  # [max_num_tokens, actual_num_q_heads, actual_head_dim]
    keys: jax.Array,  # [max_num_tokens, actual_num_kv_heads, actual_head_dim]
    values: jax.Array,  # [max_num_tokens, actual_num_kv_heads, actual_head_dim]
    kv_cache: jax.Array,  # [total_num_pages, page_size, num_kv_heads // kv_packing, kv_packing, head_dim]
    kv_lens: jax.Array,  # i32[max_num_seqs]
    page_indices: jax.Array,  # i32[max_num_seqs * pages_per_seq]
    cu_q_lens: jax.Array,  # i32[max_num_seqs + 1]
    distribution: jax.Array,  # i32[3]
    attention_sink: jax.Array | None = None,  # f32[actual_num_q_heads]
    *,
    sm_scale: float = 1.0,
    sliding_window: int | None = None,
    soft_cap: float | None = None,
    mask_value: float | None = DEFAULT_MASK_VALUE,
    q_scale: float | None = None,
    k_scale: float | None = None,
    v_scale: float | None = None,
    # Kernel optimization params.
    chunk_prefill_size: int | None = None,
    # Kernel tuning params.
    num_kv_pages_per_block: int | None = None,
    num_queries_per_block: int | None = None,
    vmem_limit_bytes: int | None = None,
    # Debug params.
    debug_mode: bool = False,
):
    q, k, v = queries, keys, values
    static_validate_inputs(
        q,
        k,
        v,
        kv_cache,
        kv_lens,
        page_indices,
        cu_q_lens,
        distribution,
        attention_sink,
        sm_scale=sm_scale,
        sliding_window=sliding_window,
        soft_cap=soft_cap,
        mask_value=mask_value,
        q_scale=q_scale,
        k_scale=k_scale,
        v_scale=v_scale,
        chunk_prefill_size=chunk_prefill_size,
        num_kv_pages_per_block=num_kv_pages_per_block,
        num_queries_per_block=num_queries_per_block,
        vmem_limit_bytes=vmem_limit_bytes,
        debug_mode=debug_mode,
    )
    max_num_tokens = q.shape[0]
    total_num_pages = kv_cache.shape[0]
    page_size = kv_cache.shape[1]
    max_num_seqs = kv_lens.shape[0]
    num_page_indices = page_indices.shape[0]
    assert num_page_indices % max_num_seqs == 0
    pages_per_seq = num_page_indices // max_num_seqs

    i, j, k = distribution
    if not (i <= j <= k):
        raise ValueError(f"Invalid distribution: {distribution=}")

    if k > max_num_seqs:
        raise ValueError(f"num_seqs={k} must be <= {max_num_seqs=}")

    if cu_q_lens[k] > max_num_tokens:
        raise ValueError(
            f"Total q tokens {cu_q_lens[k]} must be <= {max_num_tokens=}.")
    for i in range(k):
        q_len = cu_q_lens[i + 1] - cu_q_lens[i]
        kv_len = kv_lens[i]
        if not (0 < q_len <= kv_len):
            raise ValueError(
                f"Require 0 < {q_len=} <= {kv_len=} at sequence {i}.")
        page_cnt = cdiv(kv_len, page_size)
        if page_cnt > pages_per_seq:
            raise ValueError(
                f"Require {page_cnt=} <= {pages_per_seq=} at sequence {i} where"
                f" {kv_len=} and {page_size=}.")
        for p in range(page_cnt):
            page_idx = page_indices[i * pages_per_seq + p]
            if not (0 <= page_idx < total_num_pages):
                raise ValueError(
                    f"Require 0 <= {page_idx=} < {total_num_pages=} at sequence"
                    f" {i} where {kv_len=} and {page_size=}.")
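
# Illustrative usage (assumed toy shapes and metadata): dynamic validation
# indexes the arrays with Python control flow, so it is meant to run on
# concrete (non-traced) values. Here sequence 0 is decode-only (q_len 1) and
# sequence 1 is mixed, matching distribution (1, 1, 2).
def _sketch_dynamic_validation():
    q = jnp.zeros((16, 8, 64), jnp.bfloat16)
    k = jnp.zeros((16, 1, 64), jnp.bfloat16)
    v = jnp.zeros((16, 1, 64), jnp.bfloat16)
    # 8 pages of 16 tokens, 1 packed pair of kv heads, head_dim_x2 = 128.
    kv_cache = jnp.zeros((8, 16, 1, 2, 128), jnp.bfloat16)
    dynamic_validate_inputs(
        q,
        k,
        v,
        kv_cache,
        kv_lens=jnp.array([10, 24], jnp.int32),
        page_indices=jnp.array([0, 0, 0, 0, 1, 2, 0, 0], jnp.int32),
        cu_q_lens=jnp.array([0, 1, 9], jnp.int32),
        distribution=jnp.array([1, 1, 2], jnp.int32),
    )
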
# Expect to run this validation during compile time.
def static_validate_inputs(
    queries: jax.Array,  # [max_num_tokens, actual_num_q_heads, actual_head_dim]
    keys: jax.Array,  # [max_num_tokens, actual_num_kv_heads, actual_head_dim]
    values: jax.Array,  # [max_num_tokens, actual_num_kv_heads, actual_head_dim]
    kv_cache: jax.Array,  # [total_num_pages, page_size, num_kv_heads // kv_packing, kv_packing, actual_head_dim_x2]
    kv_lens: jax.Array,  # i32[max_num_seqs]
    page_indices: jax.Array,  # i32[max_num_seqs * pages_per_seq]
    cu_q_lens: jax.Array,  # i32[max_num_seqs + 1]
    distribution: jax.Array,  # i32[3]
    attention_sink: jax.Array | None = None,  # f32[actual_num_q_heads]
    *,
    sm_scale: float = 1.0,
    sliding_window: int | None = None,
    soft_cap: float | None = None,
    mask_value: float | None = DEFAULT_MASK_VALUE,
    q_scale: float | None = None,
    k_scale: float | None = None,
    v_scale: float | None = None,
    # Kernel optimization params.
    chunk_prefill_size: int | None = None,
    # Kernel tuning params.
    num_kv_pages_per_block: int | None = None,
    num_queries_per_block: int | None = None,
    vmem_limit_bytes: int | None = None,
    # Debug params.
    debug_mode: bool = False,
):
    """Validate inputs to the RPA kernel statically."""
    q, k, v = queries, keys, values
    if not (len(q.shape) == len(k.shape) == len(v.shape) == 3):
        raise ValueError(
            f"Expected 3D array for {q.shape=}, {k.shape=}, {v.shape=}")
    if k.shape != v.shape:
        raise ValueError(f"Expected {k.shape=} to be equal to {v.shape=}")
    if not (q.shape[0] == k.shape[0] == v.shape[0]):
        raise ValueError(
            f"Expected {q.shape[0]=} to be equal to {k.shape[0]=} and {v.shape[0]=}"
        )
    if not (q.shape[2] == k.shape[2] == v.shape[2]):
        raise ValueError(
            f"Expected {q.shape[2]=} to be equal to {k.shape[2]=} and {v.shape[2]=}"
        )
    if attention_sink is not None:
        if attention_sink.shape[0] != q.shape[1]:
            raise ValueError(
                f"Expected {attention_sink.shape[0]=} to be equal to"
                f" {q.shape[1]=} (num_q_heads).")
        if attention_sink.dtype != jnp.float32:
            raise ValueError(
                f"Expected {attention_sink.dtype=} to be equal to {jnp.float32=}."
            )

    actual_head_dim = q.shape[2]
    if actual_head_dim != 64:
        raise ValueError(f"Expected {actual_head_dim=} to be 64.")
    actual_num_q_heads = q.shape[1]
    actual_num_kv_heads = k.shape[1]

    if actual_num_q_heads % actual_num_kv_heads != 0:
        raise ValueError(f"Expected {actual_num_q_heads=} to be divisible by"
                         f" {actual_num_kv_heads=}.")

    (
        _,
        page_size,
        num_kv_heads_per_kv_packing,
        kv_packing,
        actual_head_dim_x2,
    ) = kv_cache.shape

    if actual_head_dim_x2 != 128:
        raise ValueError(f"Expected {actual_head_dim_x2=} to be equal to 128.")
    # Note: we expect the kv quantization to happen outside of the RPA kernel.
    if not (kv_cache.dtype == k.dtype == v.dtype):
        raise ValueError(
            f"Expected {kv_cache.dtype=} to be equal to {k.dtype=} and {v.dtype=}."
        )
    # Integer kv quantization is currently not supported.
    if not jnp.issubdtype(kv_cache.dtype, jnp.floating):
        raise ValueError(f"Expected {kv_cache.dtype=} to be a floating point.")
    if kv_packing != get_dtype_packing(kv_cache.dtype):
        raise ValueError(
            f"{kv_packing=} does not match with {kv_cache.dtype=}")

    num_kv_heads = num_kv_heads_per_kv_packing * kv_packing
    if align_to(actual_num_kv_heads, kv_packing) != num_kv_heads:
        raise ValueError(
            f"Invalid {num_kv_heads=}, {actual_num_kv_heads=}, {kv_packing=}")

    if not (jnp.int32 == kv_lens.dtype == page_indices.dtype == cu_q_lens.dtype
            == distribution.dtype):
        raise ValueError(
            f"Expected int32 dtype for {kv_lens.dtype=}, {page_indices.dtype=},"
            f" {cu_q_lens.dtype=}, {distribution.dtype=}")

    if not (len(kv_lens.shape) == len(page_indices.shape) == len(
            cu_q_lens.shape) == 1):
        raise ValueError(
            f"Expected 1D array for {kv_lens.shape=}, {page_indices.shape=},"
            f" {cu_q_lens.shape=}")

    max_num_seqs = kv_lens.shape[0]
    num_page_indices = page_indices.shape[0]
    if num_page_indices % max_num_seqs != 0:
        raise ValueError(
            f"Expected {num_page_indices=} to be divisible by {max_num_seqs=}."
        )
    if cu_q_lens.shape != (max_num_seqs + 1, ):
        raise ValueError(
            f"Expected {cu_q_lens.shape=} to be ({max_num_seqs + 1},).")
    if distribution.shape != (3, ):
        raise ValueError(f"Expected {distribution.shape=} to be (3,).")

    if page_size % kv_packing != 0:
        raise ValueError(f"{page_size=} must be divisible by {kv_packing=}.")
    if sliding_window is not None and sliding_window <= 0:
        raise ValueError(f"{sliding_window=} must be positive.")
    if soft_cap is not None and soft_cap == 0.0:
        raise ValueError(f"{soft_cap=} must not be 0.0.")
    if chunk_prefill_size is not None and chunk_prefill_size <= 0:
        raise ValueError(f"{chunk_prefill_size=} must be positive.")
    if num_kv_pages_per_block is not None:
        if num_kv_pages_per_block <= 0:
            raise ValueError(f"{num_kv_pages_per_block=} must be positive.")
    if num_queries_per_block is not None:
        if num_queries_per_block <= 0:
            raise ValueError(f"{num_queries_per_block=} must be positive.")
    if vmem_limit_bytes is not None and vmem_limit_bytes <= 0:
        raise ValueError(f"{vmem_limit_bytes=} must be positive.")

    # No constraints for the following inputs.
    del sm_scale
    del mask_value
    del q_scale
    del k_scale
    del v_scale
    del debug_mode
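
# For reference, a plausible reading (an assumption, not this file's
# definition) of the packing invariant checked above: get_dtype_packing
# returns how many values of a dtype fit into one 32-bit word, e.g. 1 for
# float32, 2 for bfloat16, 4 for an 8-bit float.
def _sketch_dtype_packing(dtype) -> int:
    return 32 // (jnp.dtype(dtype).itemsize * 8)
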
def get_kernel_scope_name(bq_size, bkv_p, page_size):
    return f"RPA-HD_64-bq_{bq_size}-bkvp_{bkv_p}-p_{page_size}-"
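
# For example, get_kernel_scope_name(32, 8, 64) returns
# "RPA-HD_64-bq_32-bkvp_8-p_64-". This string is passed as the pallas_call
# name below, so it encodes the block sizes the kernel was launched with.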
@functools.partial(
    jax.jit,
    static_argnames=(
        "sm_scale",
        "sliding_window",
        "soft_cap",
        "mask_value",
        "q_scale",
        "k_scale",
        "v_scale",
        "chunk_prefill_size",
        "num_kv_pages_per_block",
        "num_queries_per_block",
        "vmem_limit_bytes",
        "debug_mode",
    ),
    donate_argnames=("kv_cache", ),
)
def ragged_paged_attention_hd64(
    queries: jax.Array,  # [max_num_tokens, actual_num_q_heads, actual_head_dim]
    keys: jax.Array,  # [max_num_tokens, actual_num_kv_heads, actual_head_dim]
    values: jax.Array,  # [max_num_tokens, actual_num_kv_heads, actual_head_dim]
    kv_cache: jax.Array,  # [total_num_pages, page_size, num_kv_heads // kv_packing, kv_packing, actual_head_dim_x2]
    kv_lens: jax.Array,  # i32[max_num_seqs]
    page_indices: jax.Array,  # i32[max_num_seqs * pages_per_seq]
    cu_q_lens: jax.Array,  # i32[max_num_seqs + 1]
    distribution: jax.Array,  # i32[3]
    attention_sink: jax.Array | None = None,  # f32[actual_num_q_heads]
    *,
    sm_scale: float = 1.0,
    sliding_window: int | None = None,
    soft_cap: float | None = None,
    mask_value: float | None = DEFAULT_MASK_VALUE,
    q_scale: float | None = None,
    k_scale: float | None = None,
    v_scale: float | None = None,
    # Kernel optimization params.
    chunk_prefill_size: int | None = None,
    # Kernel tuning params.
    num_kv_pages_per_block: int | None = None,
    num_queries_per_block: int | None = None,
    vmem_limit_bytes: int | None = None,
    # Debug params.
    debug_mode: bool = False,
):
    """A variant of ragged paged attention for head_dim=64.

    Args:
      queries: concatenated all sequences' queries.
      keys: concatenated all sequences' keys (quantized).
      values: concatenated all sequences' values (quantized).
      kv_cache: paged KV cache with TPU-friendly shape.
      kv_lens: padded kv lengths. Only the first num_seqs values are valid.
      page_indices: flattened page indices look-up table by (seq_id, page_id).
      cu_q_lens: the cumulative sum of the effective query lengths. Similar to
        kv_lens, only the first num_seqs+1 values are valid.
      distribution: (i, j, k) represents that sequences[0:i] are decode-only,
        sequences[i:j] are chunked-prefill-only, and sequences[j:k] are mixed.
        k is also the total number of sequences.
      attention_sink: optional attention sink for each q head.
      sm_scale: the softmax scale which will be applied to Q@K^T.
      sliding_window: the sliding window size for the attention.
      soft_cap: the logit soft cap for the attention.
      mask_value: mask value for the causal mask.
      q_scale: the scale for the query.
      k_scale: the scale for the key cache.
      v_scale: the scale for the value cache.
      chunk_prefill_size: the chunk prefill size for the attention.
      num_kv_pages_per_block: number of kv pages to be processed in one flash
        attention block in the pallas kernel.
      num_queries_per_block: number of queries to be processed in one flash
        attention block in the pallas kernel.
      vmem_limit_bytes: the vmem limit for the pallas kernel.
      debug_mode: if true, RPA does not issue any DMAs or run flash attention
        but prints debug info. Needs to be compiled with
        `--xla_tpu_enable_log_recorder`.

    Returns:
      The output of the attention.
    """
    q, k, v = queries, keys, values
    static_validate_inputs(
        q,
        k,
        v,
        kv_cache,
        kv_lens,
        page_indices,
        cu_q_lens,
        distribution,
        attention_sink,
        sm_scale=sm_scale,
        sliding_window=sliding_window,
        soft_cap=soft_cap,
        mask_value=mask_value,
        q_scale=q_scale,
        k_scale=k_scale,
        v_scale=v_scale,
        chunk_prefill_size=chunk_prefill_size,
        num_kv_pages_per_block=num_kv_pages_per_block,
        num_queries_per_block=num_queries_per_block,
        vmem_limit_bytes=vmem_limit_bytes,
    )

    actual_num_q_heads = q.shape[1]
    actual_head_dim = q.shape[2]
    actual_num_kv_heads = k.shape[1]

    actual_num_q_heads_per_kv_head = actual_num_q_heads // actual_num_kv_heads
    q, kv, attention_sink = prepare_inputs(q, k, v, attention_sink)
    (
        _,
        max_num_tokens,
        num_q_heads_per_kv_head_per_q_packing,
        q_packing,
        head_dim,
    ) = q.shape
    page_size = kv_cache.shape[1]
    max_num_seqs = kv_lens.shape[0]
    num_page_indices = page_indices.shape[0]
    assert num_page_indices % max_num_seqs == 0
    pages_per_seq = num_page_indices // max_num_seqs
    num_q_heads_per_kv_head = num_q_heads_per_kv_head_per_q_packing * q_packing

    bkv_p = num_kv_pages_per_block
    bq_sz = num_queries_per_block
    if bq_sz is None or bkv_p is None:
        bkv_p, bq_sz = get_tuned_block_sizes(
            q.dtype,
            kv_cache.dtype,
            actual_num_q_heads,
            actual_num_kv_heads,
            actual_head_dim,
            page_size,
            max_num_tokens,
            pages_per_seq,
            sliding_window,
        )
    bkv_sz = bkv_p * page_size
    if vmem_limit_bytes is None:
        # TODO (jevinjiang/jacobplatin): change this to use
        # `get_vmem_estimate_bytes` when VREG spilling is fixed.
        vmem_limit_bytes = DEFAULT_VMEM_LIMIT_BYTES
    grid = (distribution[2], )

    in_specs = [
        pl.BlockSpec(memory_space=pltpu.HBM),
        pl.BlockSpec(memory_space=pltpu.HBM),
        pl.BlockSpec(memory_space=pltpu.HBM),
        None if attention_sink is None else pl.BlockSpec(
            memory_space=pltpu.VMEM),
    ]

    out_specs = [
        pl.BlockSpec(memory_space=pltpu.HBM),
        pl.BlockSpec(memory_space=pltpu.HBM),
    ]

    bkv_double_buf = pltpu.VMEM(
        (2, bkv_sz, *kv_cache.shape[2:]),
        kv_cache.dtype,
    )

    bq_double_buf = pltpu.VMEM(
        (2, actual_num_kv_heads, bq_sz, *q.shape[2:]),
        q.dtype,
    )

    bo_double_buf = bq_double_buf

    l_scratch = pltpu.VMEM(
        (actual_num_kv_heads, bq_sz * num_q_heads_per_kv_head, 128),
        jnp.float32,
    )
    m_scratch = l_scratch

    acc_scratch = pltpu.VMEM(
        (actual_num_kv_heads, bq_sz * num_q_heads_per_kv_head, head_dim),
        jnp.float32,
    )

    scratch_shapes = [
        bkv_double_buf,  # Double buffering for kv block.
        bq_double_buf,  # Double buffering for q block.
        bo_double_buf,  # Double buffering for output block.
        # Semaphores for double buffering of bkv, bq, bo and bkv_update.
        pltpu.SemaphoreType.DMA((4, 2)),
        # Intermediate buffers per kv head for flash attention.
        l_scratch,
        m_scratch,
        acc_scratch,
    ]

    scalar_prefetches = (
        kv_lens,
        # TODO(jevinjiang): can we use ragged page_indices to save some smem?
        page_indices,
        cu_q_lens,
        distribution,
        # (bq_sem_idx, bkv_sem_idx, bo_sem_idx)
        jnp.zeros((3, ), jnp.int32),
        # (bo_sem_0_seq_idx, bo_sem_1_seq_idx, bo_sem_0_bo_idx, bo_sem_1_bo_idx)
        jnp.full((4, ), -1, jnp.int32),
        # (bkv_sem_0_seq_idx, bkv_sem_1_seq_idx, bkv_sem_0_offset, bkv_sem_1_offset, bkv_sem_0_sz, bkv_sem_1_sz)
        jnp.full((6, ), -1, jnp.int32),
    )

    scope_name = get_kernel_scope_name(bq_sz, bkv_p, page_size)
    kernel = pl.pallas_call(
        functools.partial(
            _ragged_paged_attention_kernel,
            sm_scale=sm_scale,
            sliding_window=sliding_window,
            soft_cap=soft_cap,
            mask_value=mask_value,
            q_scale=q_scale,
            k_scale=k_scale,
            v_scale=v_scale,
            chunk_prefill_size=chunk_prefill_size,
            bq_sz=bq_sz,
            bkv_p=bkv_p,
            debug_mode=debug_mode,
        ),
        grid_spec=pltpu.PrefetchScalarGridSpec(
            num_scalar_prefetch=len(scalar_prefetches),
            in_specs=in_specs,
            out_specs=out_specs,
            grid=grid,
            scratch_shapes=scratch_shapes,
        ),
        compiler_params=pltpu.CompilerParams(
            # TODO(jevinjiang): since each sequence depends on the previous
            # one, we need some extra work to support Megacore mode.
            dimension_semantics=("arbitrary", ),
            vmem_limit_bytes=vmem_limit_bytes,
        ),
        out_shape=[
            jax.ShapeDtypeStruct(shape=q.shape, dtype=q.dtype),
            jax.ShapeDtypeStruct(shape=kv_cache.shape, dtype=kv_cache.dtype),
        ],
        input_output_aliases={
            7: 0,  # q (operand 7, after the 7 scalar prefetches) -> output
            9: 1,  # kv_cache -> updated kv_cache
        },
        name=scope_name,
    )

    output, updated_kv_cache = kernel(*scalar_prefetches, q, kv, kv_cache,
                                      attention_sink)
    return (
        prepare_outputs(output, actual_num_q_heads_per_kv_head,
                        actual_head_dim),
        updated_kv_cache,
    )
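
# End-to-end sketch (assumptions: a TPU backend, and the same toy shapes as
# the dynamic-validation example above; not a tested invocation). kv_cache is
# donated to the kernel, so the caller should rebind it to the returned,
# updated cache.
def _sketch_rpa_hd64_call():
    q = jnp.zeros((16, 8, 64), jnp.bfloat16)
    k = jnp.zeros((16, 1, 64), jnp.bfloat16)
    v = jnp.zeros((16, 1, 64), jnp.bfloat16)
    kv_cache = jnp.zeros((8, 16, 1, 2, 128), jnp.bfloat16)
    out, kv_cache = ragged_paged_attention_hd64(
        q,
        k,
        v,
        kv_cache,
        jnp.array([10, 24], jnp.int32),  # kv_lens
        jnp.array([0, 0, 0, 0, 1, 2, 0, 0], jnp.int32),  # page_indices
        jnp.array([0, 1, 9], jnp.int32),  # cu_q_lens
        jnp.array([1, 1, 2], jnp.int32),  # distribution
        sm_scale=64**-0.5,
    )
    assert out.shape == q.shape  # [max_num_tokens, num_q_heads, head_dim]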