tpu-inference 0.12.0.dev20251222__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/__init__.py +13 -0
- tests/core/__init__.py +13 -0
- tests/core/test_core_tpu.py +513 -0
- tests/core/test_disagg_executor.py +60 -0
- tests/core/test_disagg_utils.py +67 -0
- tests/core/test_dp_scheduler.py +724 -0
- tests/core/test_init.py +63 -0
- tests/distributed/__init__.py +13 -0
- tests/distributed/test_distributed_utils.py +120 -0
- tests/distributed/test_tpu_connector.py +478 -0
- tests/e2e/__init__.py +13 -0
- tests/e2e/test_async_scheduler.py +211 -0
- tests/e2e/test_data_parallel.py +393 -0
- tests/e2e/test_local_disagg.py +257 -0
- tests/e2e/test_model_loader.py +268 -0
- tests/e2e/test_multi_modal_inference.py +111 -0
- tests/e2e/test_pipeline_parallel.py +265 -0
- tests/e2e/test_runai_model_streamer_loader.py +104 -0
- tests/e2e/test_sampling_params.py +269 -0
- tests/e2e/test_speculative_decoding.py +291 -0
- tests/e2e/test_structured_decoding.py +46 -0
- tests/executors/__init__.py +13 -0
- tests/executors/test_ray_distributed_executor.py +199 -0
- tests/experimental/__init__.py +13 -0
- tests/experimental/test_llama3_jax_stashed.py +208 -0
- tests/kernels/__init__.py +13 -0
- tests/kernels/collectives/__init__.py +13 -0
- tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
- tests/kernels/fused_moe_v1_test.py +388 -0
- tests/kernels/gmm_test.py +205 -0
- tests/kernels/mla_v1_test.py +498 -0
- tests/kernels/quantized_matmul_kernel_test.py +159 -0
- tests/kernels/ragged_kv_cache_update_v2_test.py +248 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +414 -0
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +565 -0
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +520 -0
- tests/layers/__init__.py +13 -0
- tests/layers/common/__init__.py +13 -0
- tests/layers/common/test_attention_interface.py +156 -0
- tests/layers/common/test_quantization.py +149 -0
- tests/layers/jax/__init__.py +13 -0
- tests/layers/jax/attention/__init__.py +13 -0
- tests/layers/jax/attention/test_common_attention.py +103 -0
- tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
- tests/layers/jax/attention/test_llama4_attention.py +135 -0
- tests/layers/jax/moe/__init__.py +13 -0
- tests/layers/jax/moe/test_deepseek_moe.py +235 -0
- tests/layers/jax/sample/__init__.py +13 -0
- tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
- tests/layers/jax/sample/test_sampling.py +115 -0
- tests/layers/jax/sample/test_sampling_metadata.py +254 -0
- tests/layers/jax/test_layers.py +155 -0
- tests/layers/jax/test_qwix.py +969 -0
- tests/layers/jax/test_rope.py +93 -0
- tests/layers/jax/test_sharding.py +159 -0
- tests/layers/jax/test_transformer_block.py +152 -0
- tests/layers/vllm/__init__.py +13 -0
- tests/layers/vllm/test_attention.py +363 -0
- tests/layers/vllm/test_awq.py +405 -0
- tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +403 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +426 -0
- tests/layers/vllm/test_fp8.py +17 -0
- tests/layers/vllm/test_mxfp4.py +297 -0
- tests/layers/vllm/test_unquantized.py +621 -0
- tests/layers/vllm/utils.py +72 -0
- tests/lora/__init__.py +13 -0
- tests/lora/conftest.py +46 -0
- tests/lora/test_bgmv.py +57 -0
- tests/lora/test_layers.py +666 -0
- tests/lora/test_lora.py +147 -0
- tests/lora/test_lora_perf.py +67 -0
- tests/lora/utils.py +88 -0
- tests/models/__init__.py +13 -0
- tests/models/common/__init__.py +13 -0
- tests/models/common/test_model_loader.py +455 -0
- tests/models/jax/__init__.py +13 -0
- tests/models/jax/test_deepseek_v3.py +401 -0
- tests/models/jax/test_llama3.py +184 -0
- tests/models/jax/test_llama4.py +298 -0
- tests/models/jax/test_llama_eagle3.py +197 -0
- tests/models/jax/test_llama_guard_4.py +242 -0
- tests/models/jax/test_qwen2.py +172 -0
- tests/models/jax/test_qwen2_5_vl.py +606 -0
- tests/models/jax/test_qwen3.py +169 -0
- tests/models/jax/test_weight_loading.py +180 -0
- tests/models/jax/utils/__init__.py +13 -0
- tests/models/jax/utils/test_multi_modal_utils.py +212 -0
- tests/platforms/__init__.py +13 -0
- tests/platforms/test_tpu_platform.py +54 -0
- tests/runner/__init__.py +13 -0
- tests/runner/test_block_table.py +395 -0
- tests/runner/test_input_batch.py +226 -0
- tests/runner/test_kv_cache.py +220 -0
- tests/runner/test_kv_cache_manager.py +498 -0
- tests/runner/test_multimodal_manager.py +429 -0
- tests/runner/test_persistent_batch_manager.py +84 -0
- tests/runner/test_speculative_decoding_manager.py +368 -0
- tests/runner/test_structured_decoding_manager.py +220 -0
- tests/runner/test_tpu_runner.py +202 -0
- tests/runner/test_tpu_runner_dp.py +1033 -0
- tests/runner/test_tpu_runner_mesh.py +200 -0
- tests/runner/test_utils.py +411 -0
- tests/spec_decode/__init__.py +13 -0
- tests/spec_decode/test_eagle3.py +311 -0
- tests/test_base.py +215 -0
- tests/test_envs.py +280 -0
- tests/test_tpu_info.py +134 -0
- tests/test_utils.py +193 -0
- tests/worker/__init__.py +13 -0
- tests/worker/tpu_worker_test.py +414 -0
- tpu_inference/__init__.py +67 -0
- tpu_inference/core/__init__.py +13 -0
- tpu_inference/core/core_tpu.py +786 -0
- tpu_inference/core/disagg_executor.py +118 -0
- tpu_inference/core/disagg_utils.py +49 -0
- tpu_inference/core/sched/__init__.py +13 -0
- tpu_inference/core/sched/dp_scheduler.py +814 -0
- tpu_inference/distributed/__init__.py +13 -0
- tpu_inference/distributed/jax_parallel_state.py +81 -0
- tpu_inference/distributed/tpu_connector.py +732 -0
- tpu_inference/distributed/utils.py +112 -0
- tpu_inference/env_override.py +9 -0
- tpu_inference/envs.py +191 -0
- tpu_inference/executors/__init__.py +13 -0
- tpu_inference/executors/ray_distributed_executor.py +399 -0
- tpu_inference/experimental/__init__.py +13 -0
- tpu_inference/experimental/llama3_jax_stashed.py +272 -0
- tpu_inference/kernels/__init__.py +13 -0
- tpu_inference/kernels/collectives/__init__.py +13 -0
- tpu_inference/kernels/collectives/all_gather_matmul.py +741 -0
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +65 -0
- tpu_inference/kernels/collectives/util.py +47 -0
- tpu_inference/kernels/flash_attention/__init__.py +13 -0
- tpu_inference/kernels/flash_attention/kernel.py +772 -0
- tpu_inference/kernels/fused_moe/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/kernel.py +1612 -0
- tpu_inference/kernels/megablox/__init__.py +13 -0
- tpu_inference/kernels/megablox/common.py +54 -0
- tpu_inference/kernels/megablox/gmm.py +646 -0
- tpu_inference/kernels/mla/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/kernel.py +1340 -0
- tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
- tpu_inference/kernels/quantized_matmul/kernel.py +456 -0
- tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
- tpu_inference/kernels/quantized_matmul/util.py +58 -0
- tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +876 -0
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +288 -0
- tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1594 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1586 -0
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4460 -0
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +548 -0
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +65 -0
- tpu_inference/layers/__init__.py +13 -0
- tpu_inference/layers/common/__init__.py +13 -0
- tpu_inference/layers/common/attention_interface.py +403 -0
- tpu_inference/layers/common/attention_metadata.py +48 -0
- tpu_inference/layers/common/binary_search.py +295 -0
- tpu_inference/layers/common/quant_methods.py +23 -0
- tpu_inference/layers/common/quantization.py +270 -0
- tpu_inference/layers/common/sharding.py +600 -0
- tpu_inference/layers/jax/__init__.py +13 -0
- tpu_inference/layers/jax/attention/__init__.py +13 -0
- tpu_inference/layers/jax/attention/attention.py +268 -0
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +547 -0
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +275 -0
- tpu_inference/layers/jax/attention/llama4_attention.py +167 -0
- tpu_inference/layers/jax/base.py +165 -0
- tpu_inference/layers/jax/constants.py +101 -0
- tpu_inference/layers/jax/layers.py +315 -0
- tpu_inference/layers/jax/misc.py +30 -0
- tpu_inference/layers/jax/moe/__init__.py +13 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +615 -0
- tpu_inference/layers/jax/moe/gpt_oss_moe.py +199 -0
- tpu_inference/layers/jax/moe/moe.py +249 -0
- tpu_inference/layers/jax/pp_utils.py +53 -0
- tpu_inference/layers/jax/rope.py +294 -0
- tpu_inference/layers/jax/rope_interface.py +228 -0
- tpu_inference/layers/jax/sample/__init__.py +13 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +528 -0
- tpu_inference/layers/jax/sample/sampling.py +110 -0
- tpu_inference/layers/jax/sample/sampling_metadata.py +90 -0
- tpu_inference/layers/jax/transformer_block.py +121 -0
- tpu_inference/layers/vllm/__init__.py +13 -0
- tpu_inference/layers/vllm/attention.py +221 -0
- tpu_inference/layers/vllm/fused_moe.py +502 -0
- tpu_inference/layers/vllm/linear_common.py +221 -0
- tpu_inference/layers/vllm/quantization/__init__.py +55 -0
- tpu_inference/layers/vllm/quantization/awq.py +221 -0
- tpu_inference/layers/vllm/quantization/common.py +124 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +135 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +266 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +222 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +150 -0
- tpu_inference/layers/vllm/quantization/fp8.py +118 -0
- tpu_inference/layers/vllm/quantization/mxfp4.py +396 -0
- tpu_inference/layers/vllm/quantization/unquantized.py +416 -0
- tpu_inference/layers/vllm/sharding.py +244 -0
- tpu_inference/logger.py +10 -0
- tpu_inference/lora/__init__.py +13 -0
- tpu_inference/lora/torch_lora_ops.py +98 -0
- tpu_inference/lora/torch_punica_tpu.py +310 -0
- tpu_inference/models/__init__.py +13 -0
- tpu_inference/models/common/__init__.py +13 -0
- tpu_inference/models/common/model_loader.py +520 -0
- tpu_inference/models/jax/__init__.py +13 -0
- tpu_inference/models/jax/deepseek_v3.py +978 -0
- tpu_inference/models/jax/gpt_oss.py +508 -0
- tpu_inference/models/jax/jax_intermediate_tensor.py +93 -0
- tpu_inference/models/jax/llama3.py +436 -0
- tpu_inference/models/jax/llama4.py +643 -0
- tpu_inference/models/jax/llama_eagle3.py +350 -0
- tpu_inference/models/jax/llama_guard_4.py +375 -0
- tpu_inference/models/jax/qwen2.py +390 -0
- tpu_inference/models/jax/qwen2_5_vl.py +1232 -0
- tpu_inference/models/jax/qwen3.py +318 -0
- tpu_inference/models/jax/utils/__init__.py +13 -0
- tpu_inference/models/jax/utils/file_utils.py +110 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +177 -0
- tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
- tpu_inference/models/jax/utils/qwix/qwix_utils.py +713 -0
- tpu_inference/models/jax/utils/weight_utils.py +621 -0
- tpu_inference/models/vllm/__init__.py +13 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +307 -0
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +59 -0
- tpu_inference/platforms/__init__.py +16 -0
- tpu_inference/platforms/tpu_platform.py +258 -0
- tpu_inference/runner/__init__.py +13 -0
- tpu_inference/runner/block_table.py +122 -0
- tpu_inference/runner/compilation_manager.py +890 -0
- tpu_inference/runner/input_batch.py +435 -0
- tpu_inference/runner/kv_cache.py +166 -0
- tpu_inference/runner/kv_cache_manager.py +508 -0
- tpu_inference/runner/lora_utils.py +106 -0
- tpu_inference/runner/multimodal_manager.py +231 -0
- tpu_inference/runner/persistent_batch_manager.py +296 -0
- tpu_inference/runner/speculative_decoding_manager.py +262 -0
- tpu_inference/runner/structured_decoding_manager.py +101 -0
- tpu_inference/runner/tpu_runner.py +1768 -0
- tpu_inference/runner/utils.py +426 -0
- tpu_inference/spec_decode/__init__.py +13 -0
- tpu_inference/spec_decode/jax/__init__.py +13 -0
- tpu_inference/spec_decode/jax/eagle3.py +430 -0
- tpu_inference/tpu_info.py +92 -0
- tpu_inference/utils.py +345 -0
- tpu_inference/worker/__init__.py +13 -0
- tpu_inference/worker/tpu_worker.py +468 -0
- tpu_inference-0.12.0.dev20251222.dist-info/METADATA +106 -0
- tpu_inference-0.12.0.dev20251222.dist-info/RECORD +260 -0
- tpu_inference-0.12.0.dev20251222.dist-info/WHEEL +5 -0
- tpu_inference-0.12.0.dev20251222.dist-info/licenses/LICENSE +201 -0
- tpu_inference-0.12.0.dev20251222.dist-info/top_level.txt +2 -0
tpu_inference/kernels/ragged_paged_attention/v2/kernel.py
@@ -0,0 +1,876 @@
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TPU-Friendly Ragged Paged Attention kernel.

This kernel offers a highly optimized implementation of ragged paged attention,
specifically designed for TPU and compatible with a wide range of model
specifications. It supports mixed prefill and decoding, enhancing throughput
during inference.
"""
import functools

import jax
import jax.numpy as jnp
from jax import lax
from jax._src import dtypes
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu
from jax.experimental.pallas.ops.tpu.ragged_paged_attention.tuned_block_sizes import \
    get_tuned_block_sizes

DEFAULT_MASK_VALUE = -0.7 * float(jnp.finfo(jnp.dtype("float32")).max)


class MultiPageAsyncCopyDescriptor:
    """Descriptor for async copy of multiple K/V pages from HBM."""

    def __init__(
        self,
        pages_hbm_ref,  # [total_num_pages, page_size, num_combined_kv_heads_per_blk, head_dim]
        vmem_buf,  # [num_kv_pages_per_blk, page_size, num_combined_kv_heads_per_blk, head_dim]
        sem,
        page_indices_ref,  # i32[max_num_seqs, pages_per_seq]
        metadata,  # [seq_idx, start_page_idx, end_page_idx]
    ):
        self._vmem_buf = vmem_buf
        seq_id, start_page_idx, end_page_idx = metadata
        self._async_copies = []
        # TODO(jevinjiang): Only fetch dynamic shape in need! This will insert
        # a bunch of if-ops. Check the performance when we have benchmarking setup.
        for i in range(vmem_buf.shape[0]):
            page_idx = start_page_idx + i
            page_idx = jax.lax.select(page_idx < end_page_idx, page_idx, 0)
            self._async_copies.append(
                pltpu.make_async_copy(
                    pages_hbm_ref.at[page_indices_ref[seq_id, page_idx]],
                    vmem_buf.at[i],
                    sem,
                ))

    def start(self):
        """Starts the async copies."""
        for async_copy in self._async_copies:
            async_copy.start()

    def wait(self):
        for async_copy in self._async_copies:
            async_copy.wait()
        return self._vmem_buf
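
# Illustration (not part of the original kernel source): the kernel below keeps
# two of these descriptors in flight, one per VMEM buffer. While attention is
# computed on the KV block already resident in buffer `buf_idx`, the copy for
# the next block is started into the other buffer, and `wait()` is only called
# once that block's data is actually needed.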


def ref_ragged_paged_attention(
    queries: jax.Array,  # [max_num_batched_tokens, num_q_heads, head_dim]
    kv_pages: jax.
    Array,  # [total_num_pages, page_size, num_combined_kv_heads, head_dim]
    kv_lens: jax.Array,  # i32[max_num_seqs]
    page_indices: jax.Array,  # i32[max_num_seqs, pages_per_seq]
    cu_q_lens: jax.Array,  # i32[max_num_seqs + 1]
    num_seqs: jax.Array,  # i32[1],
    *,
    sm_scale: float = 1.0,
    sliding_window: int | None = None,
    soft_cap: float | None = None,
    mask_value: float | None = DEFAULT_MASK_VALUE,
    k_scale: float | None = None,
    v_scale: float | None = None,
):
    static_validate_inputs(
        queries,
        kv_pages,
        kv_lens,
        page_indices,
        cu_q_lens,
        num_seqs,
        sm_scale=sm_scale,
        k_scale=k_scale,
        v_scale=v_scale,
        sliding_window=sliding_window,
        soft_cap=soft_cap,
        mask_value=mask_value,
    )
    if mask_value is None:
        mask_value = DEFAULT_MASK_VALUE
    _, _, num_combined_kv_heads, head_dim = kv_pages.shape
    assert num_combined_kv_heads % 2 == 0
    num_kv_heads = num_combined_kv_heads // 2
    num_q_heads = queries.shape[1]
    assert num_q_heads % num_kv_heads == 0
    num_query_per_kv = num_q_heads // num_kv_heads
    outputs = []
    for i in range(num_seqs[0]):
        q_start = cu_q_lens[i]
        q_end = cu_q_lens[i + 1]
        q_len = q_end - q_start
        kv_len = kv_lens[i]
        indices = page_indices[i]
        q = queries[q_start:q_end]
        k = kv_pages[indices, :, 0::2, :].reshape(-1, num_kv_heads,
                                                  head_dim)[:kv_len]
        v = kv_pages[indices, :, 1::2, :].reshape(-1, num_kv_heads,
                                                  head_dim)[:kv_len]
        if k_scale is not None:
            k = k.astype(jnp.float32) * k_scale
            k = k.astype(q.dtype)
        if v_scale is not None:
            v = v.astype(jnp.float32) * v_scale
            v = v.astype(q.dtype)
        k = jnp.repeat(k, num_query_per_kv, axis=1)
        v = jnp.repeat(v, num_query_per_kv, axis=1)
        attn = jnp.einsum("qhd,khd->hqk",
                          q,
                          k,
                          preferred_element_type=jnp.float32)
        attn *= sm_scale
        q_span = (kv_len - q_len) + jax.lax.broadcasted_iota(
            jnp.int32, attn.shape, 1)
        kv_span = jax.lax.broadcasted_iota(jnp.int32, attn.shape, 2)
        mask = q_span < kv_span
        if sliding_window is not None:
            mask = jnp.logical_or(mask, q_span - sliding_window >= kv_span)
        if soft_cap is not None:
            attn = soft_cap * jnp.tanh(attn / soft_cap)
        attn += jnp.where(mask, mask_value, 0.0)
        attn = jax.nn.softmax(attn, axis=-1).astype(v.dtype)
        out = jnp.einsum("hqk,khd->qhd", attn, v).astype(queries.dtype)
        outputs.append(out)

    return jnp.concatenate(outputs, axis=0)
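
# Illustration (not part of the original kernel source): the metadata arrays
# describe a mixed prefill/decode batch. For example, two prefill requests of
# 5 and 3 new tokens plus one decode request against a 27-token context could
# be encoded (before padding to max_num_seqs) as
#     cu_q_lens[:4] = [0, 5, 8, 9]   # cumulative new-token counts
#     kv_lens[:3]   = [5, 3, 27]     # total context length per sequence
#     num_seqs      = [3]
# and ref_ragged_paged_attention above loops over the first num_seqs[0] entries.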


# Expect to run these checks during runtime.
def dynamic_validate_inputs(
    q: jax.Array,  # [max_num_batched_tokens, num_q_heads, head_dim]
    kv_pages: jax.
    Array,  # [total_num_pages, page_size, num_combined_kv_heads, head_dim]
    kv_lens: jax.Array,  # i32[max_num_seqs]
    page_indices: jax.Array,  # i32[max_num_seqs, pages_per_seq]
    cu_q_lens: jax.Array,  # i32[max_num_seqs + 1]
    num_seqs: jax.Array,  # i32[1]
    *,
    # These inputs are optional. If not specified, we will not validate them.
    sm_scale: float | None = None,
    sliding_window: int | None = None,
    soft_cap: float | None = None,
    mask_value: float | None = None,
    k_scale: float | None = None,
    v_scale: float | None = None,
    # Kernel tuning params.
    num_kv_pages_per_block: int | None = None,
    num_queries_per_block: int | None = None,
    vmem_limit_bytes: int | None = None,
):
    static_validate_inputs(
        q,
        kv_pages,
        kv_lens,
        page_indices,
        cu_q_lens,
        num_seqs,
        sm_scale=sm_scale,
        sliding_window=sliding_window,
        soft_cap=soft_cap,
        mask_value=mask_value,
        k_scale=k_scale,
        v_scale=v_scale,
        num_kv_pages_per_block=num_kv_pages_per_block,
        num_queries_per_block=num_queries_per_block,
        vmem_limit_bytes=vmem_limit_bytes,
    )
    max_num_batched_tokens = q.shape[0]
    page_size = kv_pages.shape[1]
    max_num_seqs, pages_per_seq = page_indices.shape
    if num_seqs[0] > max_num_seqs:
        raise ValueError(
            f"{num_seqs[0]=} must be less or equal to {max_num_seqs=}")
    max_kv_len = jnp.max(kv_lens)
    min_pages_per_seq = cdiv(max_kv_len, page_size)
    if pages_per_seq < min_pages_per_seq:
        raise ValueError(
            f"{pages_per_seq=} must be greater or equal to"
            f" {min_pages_per_seq=} given {max_kv_len=} and {page_size=}.")
    if cu_q_lens[num_seqs[0]] > max_num_batched_tokens:
        raise ValueError(
            f"Total q tokens {cu_q_lens[num_seqs[0]]} must be less or equal to"
            f" {max_num_batched_tokens=}.")
    for i in range(num_seqs[0]):
        q_len = cu_q_lens[i + 1] - cu_q_lens[i]
        kv_len = kv_lens[i]
        if q_len > kv_len:
            raise ValueError(
                f"{q_len=} must be less or equal to {kv_len=} at sequence {i}."
            )


# Expect to run these checks during compile time.
def static_validate_inputs(
    q: jax.Array,  # [max_num_batched_tokens, num_q_heads, head_dim]
    kv_pages: jax.
    Array,  # [total_num_pages, page_size, num_combined_kv_heads, head_dim]
    kv_lens: jax.Array,  # i32[max_num_seqs]
    page_indices: jax.Array,  # i32[max_num_seqs, pages_per_seq]
    cu_q_lens: jax.Array,  # i32[max_num_seqs + 1]
    num_seqs: jax.Array,  # i32[1]
    *,
    # These inputs are optional. If not specified, we will not validate them.
    sm_scale: float | None = None,
    sliding_window: int | None = None,
    soft_cap: float | None = None,
    mask_value: float | None = None,
    k_scale: float | None = None,
    v_scale: float | None = None,
    # Kernel tuning params.
    num_kv_pages_per_block: int | None = None,
    num_queries_per_block: int | None = None,
    vmem_limit_bytes: int | None = None,
):
    _, num_q_heads, head_dim = q.shape
    _, _, num_combined_kv_heads, head_dim_k = kv_pages.shape
    assert num_combined_kv_heads % 2 == 0
    assert isinstance(k_scale, float) or k_scale is None
    assert isinstance(v_scale, float) or v_scale is None
    num_kv_heads = num_combined_kv_heads // 2
    max_num_seqs, pages_per_seq = page_indices.shape
    if num_seqs.shape != (1, ):
        raise ValueError(f"{num_seqs.shape=} must be (1,)")
    if head_dim_k != head_dim:
        raise ValueError(
            f"Q head_dim {head_dim} must be the same as that of K/V {head_dim_k}."
        )
    if kv_lens.shape != (max_num_seqs, ):
        raise ValueError(
            f"Expected {kv_lens.shape=} to be ({max_num_seqs},) where"
            " `max_num_seqs` is `page_indices.shape[0]`.")
    if cu_q_lens.shape != (max_num_seqs + 1, ):
        raise ValueError(
            f"Expected {cu_q_lens.shape=} to be ({max_num_seqs + 1},) where"
            " `max_num_seqs` is `page_indices.shape[0]`.")
    if (kv_lens.dtype != jnp.int32 or page_indices.dtype != jnp.int32
            or cu_q_lens.dtype != jnp.int32):
        raise ValueError(
            "The dtype of `kv_lens`, `page_indices`, and `cu_q_lens` must be"
            f" int32. Got {kv_lens.dtype=}, {page_indices.dtype=},"
            f" {cu_q_lens.dtype=}.")
    if num_q_heads % num_kv_heads != 0:
        raise ValueError(
            f"{num_q_heads=} must be divisible by {num_kv_heads=}")
    if sliding_window is not None and sliding_window <= 0:
        raise ValueError(f"{sliding_window=} must be positive.")
    if soft_cap is not None and soft_cap == 0.0:
        raise ValueError(f"{soft_cap=} must not be 0.0.")
    if (num_kv_pages_per_block is not None
            and not 0 < num_kv_pages_per_block <= pages_per_seq):
        raise ValueError(
            f"{num_kv_pages_per_block=} must be in range (0, {pages_per_seq}]."
        )
    if num_queries_per_block is not None and num_queries_per_block <= 0:
        raise ValueError(f"{num_queries_per_block=} must be positive.")
    if vmem_limit_bytes is not None and vmem_limit_bytes <= 0:
        raise ValueError(f"{vmem_limit_bytes=} must be positive.")
    del sm_scale  # No constraints on sm_scale.
    del mask_value  # No constraints on mask_value.


def ragged_paged_attention_kernel(
    # Prefetch
    kv_lens_ref,  # [max_num_seqs]
    page_indices_ref,  # [max_num_seqs, pages_per_seq]
    cu_q_lens_ref,  # [max_num_seqs + 1]
    seq_buf_idx_ref,
    # TODO(jevinjiang): if OOM in SMEM, consider pack to other scalar refs.
    num_seqs_ref,
    # Input
    q_ref,  # [num_q_per_blk, num_q_heads_per_blk, head_dim]
    kv_pages_hbm_ref,  # [total_num_pages, page_size, num_combined_kv_heads, head_dim]
    # Output
    o_ref,  # [num_q_per_blk, num_q_heads_per_blk, head_dim]
    # Scratch
    kv_bufs,  # [2, num_kv_pages_per_blk, page_size, num_combined_kv_heads_per_blk, head_dim]
    sems,  # [2, 2]
    l_ref,  # [num_kv_heads_per_blk, num_q_per_blk * num_q_heads_per_kv_head, 128]
    m_ref,  # [num_kv_heads_per_blk, num_q_per_blk * num_q_heads_per_kv_head, 128]
    acc_ref,  # [num_q_per_blk, num_q_heads_per_blk, head_dim]
    *,
    sm_scale: float,
    sliding_window: int | None = None,
    soft_cap: float | None = None,
    mask_value: float | None = DEFAULT_MASK_VALUE,
    k_scale: float | None = None,
    v_scale: float | None = None,
):
    if mask_value is None:
        mask_value = DEFAULT_MASK_VALUE
    num_q_per_blk, num_q_heads_per_blk, head_dim = q_ref.shape
    pages_per_seq = page_indices_ref.shape[-1]
    num_seqs = num_seqs_ref[0]
    _, num_kv_pages_per_blk, page_size, num_combined_kv_heads_per_blk, _ = (
        kv_bufs.shape)
    num_kv_heads_per_blk = num_combined_kv_heads_per_blk // 2
    num_kv_per_blk = num_kv_pages_per_blk * page_size
    num_q_heads_per_kv_head = num_q_heads_per_blk // num_kv_heads_per_blk
    heads_blk_idx, q_blk_idx = (
        pl.program_id(0),
        pl.program_id(1),
    )
    num_heads_blks = pl.num_programs(0)
    init_seq_idx = seq_buf_idx_ref[0]
    init_buf_idx = seq_buf_idx_ref[1]
    q_len_start = q_blk_idx * num_q_per_blk
    q_len_end = q_len_start + num_q_per_blk

    def create_kv_async_copy_descriptors(heads_blk_idx, seq_idx, kv_blk_idx,
                                          buf_idx):
        start_kv_page_idx = kv_blk_idx * num_kv_pages_per_blk
        end_kv_page_idx = jnp.minimum(pages_per_seq,
                                      cdiv(kv_lens_ref[seq_idx], page_size))
        metadata = (seq_idx, start_kv_page_idx, end_kv_page_idx)
        heads_start = heads_blk_idx * num_combined_kv_heads_per_blk
        async_copy_kv = MultiPageAsyncCopyDescriptor(
            kv_pages_hbm_ref.
            at[:, :,
               pl.ds(heads_start, num_combined_kv_heads_per_blk), :],
            kv_bufs.at[buf_idx],
            sems.at[buf_idx],
            page_indices_ref,
            metadata,
        )
        return async_copy_kv

    # TODO(jevinjiang): Add these to Mosaic:
    # 1. Support arbitrary strided load/store for int4 and int8 dtype.
    # 2. Support arbitrary strided load/store for any last dimension.
    def strided_load_kv(ref, start, step):
        packing = get_dtype_packing(ref.dtype)
        if packing == 1:
            return [ref[start::step, :]], [ref[start + 1::step, :]]
        assert packing in (2, 4, 8)
        assert step % packing == 0
        k_list, v_list = [], []
        b_start = start // packing
        b_step = step // packing
        b_ref = ref.bitcast(jnp.uint32)
        b = b_ref[b_start::b_step, :]

        # TODO(chengjiyao): use the general strided loading logic for bf16 after
        # fixing the issue in mosaic's infer vector layout pass
        if ref.dtype == jnp.bfloat16:
            bk = b << 16
            bv = b & jnp.uint32(0xFFFF0000)
            k = pltpu.bitcast(bk, jnp.float32).astype(jnp.bfloat16)
            v = pltpu.bitcast(bv, jnp.float32).astype(jnp.bfloat16)
            k_list.append(k)
            v_list.append(v)
        else:
            bitwidth = 32 // packing
            bitcast_dst_dtype = jnp.dtype(f"uint{bitwidth}")
            for i in range(0, packing, 2):
                bk = b >> (i * bitwidth)
                k = pltpu.bitcast(bk.astype(bitcast_dst_dtype), ref.dtype)
                k_list.append(k)
                bv = b >> ((i + 1) * bitwidth)
                v = pltpu.bitcast(bv.astype(bitcast_dst_dtype), ref.dtype)
                v_list.append(v)

        return k_list, v_list

    def fold_on_2nd_minor(vec):
        assert vec.dtype == jnp.bfloat16 or vec.dtype == jnp.float32
        assert len(vec.shape) >= 2
        last_dim = vec.shape[-1]
        packing = get_dtype_packing(vec.dtype)
        if vec.shape[-2] % packing != 0:
            vec = vec.astype(jnp.float32)
        return vec.reshape(-1, last_dim)

    @pl.when(heads_blk_idx + q_blk_idx == 0)
    def prefetch_first_kv_blk():
        async_copy_kv = create_kv_async_copy_descriptors(
            heads_blk_idx, init_seq_idx, 0, init_buf_idx)
        async_copy_kv.start()

    def is_cur_q_blk_needed(q_states):
        done, cur_seq_idx, _ = q_states
        should_run = jnp.logical_and(q_len_start < cu_q_lens_ref[num_seqs],
                                     cur_seq_idx < num_seqs)
        return jnp.logical_and(done == 0, should_run)

    def compute_with_cur_q_blk(q_states):
        done, cur_seq_idx, cur_buf_idx = q_states
        q_start = cu_q_lens_ref[cur_seq_idx]
        q_end = cu_q_lens_ref[cur_seq_idx + 1]
        q_len = q_end - q_start
        kv_len = kv_lens_ref[cur_seq_idx]

        def get_next_prefetch_ids(heads_blk_idx, cur_seq_idx, kv_blk_idx,
                                  cur_buf_idx):
            next_kv_blk_idx = kv_blk_idx + 1
            is_last_kv_blk = next_kv_blk_idx * num_kv_per_blk >= kv_len
            next_kv_blk_idx = lax.select(
                is_last_kv_blk,
                0,
                next_kv_blk_idx,
            )
            is_cur_seq_end_in_cur_q_blk = q_end <= q_len_end
            next_seq_idx = lax.select(
                is_last_kv_blk,
                lax.select(is_cur_seq_end_in_cur_q_blk, cur_seq_idx + 1,
                           cur_seq_idx),
                cur_seq_idx,
            )
            is_last_seq = next_seq_idx == num_seqs
            next_seq_idx = lax.select(
                is_last_seq,
                0,
                next_seq_idx,
            )
            next_heads_blk_idx = lax.select(
                is_last_seq,
                heads_blk_idx + 1,
                heads_blk_idx,
            )
            next_buf_idx = lax.select(cur_buf_idx == 0, 1, 0)
            return next_heads_blk_idx, next_seq_idx, next_kv_blk_idx, next_buf_idx

        def flash_attention(
            q,  # [num_q_per_blk * num_q_heads_per_kv_head, head_dim]
            k,  # [num_kv_per_blk, head_dim]
            v,  # [num_kv_per_blk, head_dim]
            head_l_ref,  # [num_q_per_blk * num_q_heads_per_kv_head, 128]
            head_m_ref,  # [num_q_per_blk * num_q_heads_per_kv_head, 128]
            head_acc_ref,  # [num_q_per_blk, num_q_heads_per_kv_head, head_dim]
            *,
            kv_blk_idx,
        ):
            assert q.shape == (
                num_q_per_blk * num_q_heads_per_kv_head,
                head_dim,
            )
            assert (k.shape == v.shape == (
                num_kv_per_blk,
                head_dim,
            ))
            assert k.dtype == v.dtype
            assert (head_m_ref.shape == head_l_ref.shape == (
                num_q_per_blk * num_q_heads_per_kv_head,
                128,
            ))
            assert head_acc_ref.shape == (
                num_q_per_blk,
                num_q_heads_per_kv_head,
                head_dim,
            )
            kv_len_start = kv_blk_idx * num_kv_per_blk

            def masked_store(ref, val, start, end, group=1):
                iota = lax.broadcasted_iota(jnp.int32, ref.shape, 0) // group
                mask = jnp.logical_and(iota >= start, iota < end)
                pl.store(ref,
                         idx=tuple(slice(None) for _ in ref.shape),
                         val=val,
                         mask=mask)

            def load_with_init(ref, init_val):
                return jnp.where(kv_blk_idx == 0, jnp.full_like(ref, init_val),
                                 ref[...])

            # kv lens will be contracting dim, we should mask out the NaNs.
            kv_mask = (lax.broadcasted_iota(jnp.int32, k.shape, 0)
                       < kv_len - kv_len_start)
            k = jnp.where(kv_mask, k.astype(jnp.float32), 0).astype(k.dtype)
            v = jnp.where(kv_mask, v.astype(jnp.float32), 0).astype(v.dtype)

            qk = (jnp.einsum(
                "nd,md->nm", q, k, preferred_element_type=jnp.float32) *
                  sm_scale)
            store_start = jnp.maximum(q_start - q_len_start, 0)
            store_end = jnp.minimum(q_end - q_len_start, num_q_per_blk)

            row_ids = (
                (kv_len - q_len) + q_len_start - q_start +
                jax.lax.broadcasted_iota(
                    jnp.int32,
                    (num_q_per_blk * num_q_heads_per_kv_head, num_kv_per_blk),
                    0,
                ) // num_q_heads_per_kv_head)
            col_ids = kv_len_start + jax.lax.broadcasted_iota(
                jnp.int32,
                (num_q_per_blk * num_q_heads_per_kv_head, num_kv_per_blk),
                1,
            )
            causal_mask = row_ids < col_ids
            if sliding_window is not None:
                causal_mask = jnp.logical_or(
                    causal_mask, row_ids - sliding_window >= col_ids)
            if soft_cap is not None:
                qk = soft_cap * jnp.tanh(qk / soft_cap)
            qk += jnp.where(causal_mask, mask_value, 0.0)
            m_curr = jnp.max(qk, axis=1, keepdims=True)
            s_curr = jnp.exp(qk - m_curr)
            qkv = jnp.dot(s_curr, v, preferred_element_type=jnp.float32)
            lm_store_shape = head_m_ref.shape
            m_curr = jnp.broadcast_to(m_curr, lm_store_shape)
            l_curr = jnp.broadcast_to(s_curr.sum(axis=1, keepdims=True),
                                      lm_store_shape)
            m_prev = load_with_init(head_m_ref, -jnp.inf)
            l_prev = load_with_init(head_l_ref, 0.0)
            m_next = jnp.maximum(m_prev, m_curr)
            masked_store(head_m_ref, m_next, store_start, store_end,
                         num_q_heads_per_kv_head)
            alpha = jnp.exp(m_prev - m_next)
            beta = jnp.exp(m_curr - m_next)
            l_alpha = alpha * l_prev
            l_next = l_alpha + beta * l_curr
            l_next_safe = jnp.where(l_next == 0.0, 1.0, l_next)
            masked_store(
                head_l_ref,
                l_next_safe,
                store_start,
                store_end,
                num_q_heads_per_kv_head,
            )

            def broadcast_to_shape(arr, shape):
                if arr.shape == shape:
                    return arr
                assert len(arr.shape) == len(shape)
                assert arr.shape[0] == shape[0]
                assert shape[1] % arr.shape[1] == 0
                # no-op concatenation.
                return jnp.concatenate(
                    [arr for _ in range(shape[1] // arr.shape[1])], axis=1)

            o_curr = load_with_init(head_acc_ref, 0.0).reshape(-1, head_dim)
            l_alpha = broadcast_to_shape(l_alpha, qkv.shape)
            beta = broadcast_to_shape(beta, qkv.shape)
            l_next_safe = broadcast_to_shape(l_next_safe, qkv.shape)
            out = lax.div(
                l_alpha * o_curr + beta * qkv,
                l_next_safe,
            )
            masked_store(
                head_acc_ref,
                out.reshape(head_acc_ref.shape),
                store_start,
                store_end,
            )

        def is_valid_kv_blk_in_cur_seq(kv_states):
            kv_blk_idx, _ = kv_states
            return kv_blk_idx * num_kv_per_blk < kv_len

        def compute_with_kv_blk_in_cur_seq(kv_states):
            kv_blk_idx, cur_buf_idx = kv_states
            next_heads_blk_idx, next_seq_idx, next_kv_blk_idx, next_buf_idx = (
                get_next_prefetch_ids(heads_blk_idx, cur_seq_idx, kv_blk_idx,
                                      cur_buf_idx))

            @pl.when(next_heads_blk_idx < num_heads_blks)
            def prefetch_next_kv_blk():
                # TODO(jevinjiang): reuse the same buffer if it is already prefetched!
                # TODO(jevinjiang): only fetch effective dynamic size to hold kv_len and
                # DMA to fixed size buffer!
                next_async_copy_kv = create_kv_async_copy_descriptors(
                    next_heads_blk_idx, next_seq_idx, next_kv_blk_idx,
                    next_buf_idx)
                next_async_copy_kv.start()

            cur_async_copy_kv = create_kv_async_copy_descriptors(
                heads_blk_idx, cur_seq_idx, kv_blk_idx, cur_buf_idx)
            kv_ref = cur_async_copy_kv.wait().reshape(
                num_kv_pages_per_blk * page_size *
                num_combined_kv_heads_per_blk,
                head_dim,
            )
            kv_packing = get_dtype_packing(kv_ref.dtype)
            # NOTE: kv_packing is divided by 2 because k and v are packed together.
            kv_load_step = max(1, kv_packing // 2)
            for kv_head_chunk_idx in range(0, num_kv_heads_per_blk,
                                           kv_load_step):
                k_list, v_list = strided_load_kv(
                    kv_ref, kv_head_chunk_idx * 2,
                    num_combined_kv_heads_per_blk)
                for step_idx in range(kv_load_step):
                    k = k_list[step_idx]
                    v = v_list[step_idx]
                    if k_scale is not None:
                        # NOTE: Conversion between arbitrary data types is not supported.
                        # That's why it is converted to float32 first.
                        k = k.astype(jnp.float32) * k_scale
                        k = k.astype(q_ref.dtype)
                    if v_scale is not None:
                        v = v.astype(jnp.float32) * v_scale
                        v = v.astype(q_ref.dtype)
                    kv_head_idx = kv_head_chunk_idx + step_idx
                    q_head_idx = kv_head_idx * num_q_heads_per_kv_head
                    # TODO(jevinjiang): extra handling for packed type that can start at
                    # unaligned position!
                    q = fold_on_2nd_minor(q_ref[:, q_head_idx:q_head_idx +
                                                num_q_heads_per_kv_head, :])
                    flash_attention(
                        q,
                        k,
                        v,
                        l_ref.at[kv_head_idx],
                        m_ref.at[kv_head_idx],
                        acc_ref.at[:, q_head_idx:q_head_idx +
                                   num_q_heads_per_kv_head, :],
                        kv_blk_idx=kv_blk_idx,
                    )
            return kv_blk_idx + 1, next_buf_idx

        _, next_buf_idx = lax.while_loop(
            is_valid_kv_blk_in_cur_seq,
            compute_with_kv_blk_in_cur_seq,
            (0, cur_buf_idx),  # (kv_blk_idx, buf_idx)
        )
        next_seq_idx = lax.select(q_end <= q_len_end, cur_seq_idx + 1,
                                  cur_seq_idx)
        done = lax.select(q_end < q_len_end, done, 1)
        return done, next_seq_idx, next_buf_idx

    _, seq_idx, buf_idx = lax.while_loop(
        is_cur_q_blk_needed,
        compute_with_cur_q_blk,
        (0, init_seq_idx, init_buf_idx),  # (done, seq_idx, buf_idx)
    )
    # Reset seq_idx for next kv_heads_blk if run out of seqs!
    seq_buf_idx_ref[0] = lax.select(seq_idx < num_seqs, seq_idx, 0)
    seq_buf_idx_ref[1] = buf_idx
    o_ref[...] = acc_ref[...].astype(q_ref.dtype)
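
# Illustration (not part of the original kernel source): flash_attention above
# keeps the usual streaming-softmax state per query row. Given running max
# m_prev and running denominator l_prev, a new score block qk with block max
# m_curr = max(qk), weights s_curr = exp(qk - m_curr) and block sum
# l_curr = sum(s_curr) is folded in as
#     m_next = max(m_prev, m_curr)
#     alpha  = exp(m_prev - m_next),  beta = exp(m_curr - m_next)
#     l_next = alpha * l_prev + beta * l_curr
#     o_next = (alpha * l_prev * o_prev + beta * (s_curr @ v)) / l_next
# where o_prev is the previously stored (already normalized) block output; this
# is exactly the l_alpha / beta / l_next_safe arithmetic written back into the
# l/m/acc scratch buffers.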


def cdiv(a, b):
    assert b != 0
    return (a + b - 1) // b


def get_dtype_packing(dtype):
    bits = (dtypes.bit_width(dtype)
            if hasattr(dtypes, "bit_width") else dtypes.itemsize_bits(dtype))
    return 32 // bits


def get_min_heads_per_blk(num_q_heads, num_combined_kv_heads, q_dtype,
                          kv_dtype):
    q_packing = get_dtype_packing(q_dtype)
    kv_packing = get_dtype_packing(kv_dtype)

    def can_be_xla_fully_tiled(x, packing):
        if x % packing != 0:
            return False
        x //= packing
        return x in (1, 2, 4, 8) or x % 8 == 0

    # TODO(jevinjiang): support unaligned number of heads!
    if not can_be_xla_fully_tiled(num_combined_kv_heads, kv_packing):
        raise ValueError(
            f"Not implemented: {num_combined_kv_heads=} can not be XLA fully tiled."
        )
    assert num_combined_kv_heads % 2 == 0
    num_kv_heads = num_combined_kv_heads // 2
    assert num_q_heads % num_kv_heads == 0
    ratio = num_q_heads // num_kv_heads
    # TODO(jevinjiang): we can choose smaller tiling for packed type if large
    # second minor tiling is not on.
    max_combined_kv_tiling = 8 * kv_packing
    min_combined_kv_heads = (max_combined_kv_tiling if num_combined_kv_heads %
                             max_combined_kv_tiling == 0 else
                             num_combined_kv_heads)
    min_q_heads = min_combined_kv_heads // 2 * ratio
    if can_be_xla_fully_tiled(min_q_heads, q_packing):
        return min_q_heads, min_combined_kv_heads
    return num_q_heads, num_combined_kv_heads
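
# Illustration (not part of the original kernel source): get_dtype_packing
# returns how many elements fit in 32 bits, e.g. float32 -> 1, bfloat16 -> 2,
# 8-bit types -> 4, int4 -> 8. get_min_heads_per_blk uses it to pick the
# smallest head-block split that still forms full XLA tiles for both Q and the
# combined K/V head dimension.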


@functools.partial(
    jax.jit,
    static_argnames=[
        "sm_scale",
        "mask_value",
        "num_kv_pages_per_block",
        "num_queries_per_block",
        "vmem_limit_bytes",
        "sliding_window",
        "soft_cap",
        "k_scale",
        "v_scale",
    ],
)
def ragged_paged_attention(
    q: jax.Array,  # [max_num_batched_tokens, num_q_heads, head_dim]
    # TODO(jevinjiang): create a write_to_kv_cache kernel!
    kv_pages: jax.
    Array,  # [total_num_pages, page_size, num_combined_kv_heads, head_dim]
    kv_lens: jax.Array,  # i32[max_num_seqs]
    page_indices: jax.Array,  # i32[max_num_seqs, pages_per_seq]
    cu_q_lens: jax.Array,  # i32[max_num_seqs + 1]
    num_seqs: jax.Array,  # i32[1]
    *,
    sm_scale: float = 1.0,
    sliding_window: int | None = None,
    soft_cap: float | None = None,
    mask_value: float | None = DEFAULT_MASK_VALUE,
    k_scale: float | None = None,
    v_scale: float | None = None,
    num_kv_pages_per_block: int | None = None,
    num_queries_per_block: int | None = None,
    vmem_limit_bytes: int | None = None,
):
    """Ragged paged attention that supports mixed prefill and decode.

    Args:
      q: concatenated all sequences' queries.
      kv_pages: paged KV cache. Normally in HBM.
      kv_lens: padded kv lengths. Only the first num_seqs values are valid.
      page_indices: the first index indicates which page to use in the kv cache
        for each sequence. Only the first num_seqs values are valid.
      cu_q_lens: the cumulative sum of the effective query lengths. Similar to
        kv_lens, only the first num_seqs+1 values are valid.
      num_seqs: the dynamic number of sequences.
      sm_scale: the softmax scale which will be applied to the Q@K^T.
      sliding_window: the sliding window size for the attention.
      soft_cap: the logit soft cap for the attention.
      mask_value: mask value for causal mask.
      k_scale: the scale for the key cache.
      v_scale: the scale for the value cache.
      num_kv_pages_per_block: number of kv pages to be processed in one flash
        attention block in the pallas kernel.
      num_queries_per_block: number of queries to be processed in one flash
        attention block in the pallas kernel.
      vmem_limit_bytes: the vmem limit for the pallas kernel.

    Returns:
      The output of the attention.
    """
    static_validate_inputs(
        q,
        kv_pages,
        kv_lens,
        page_indices,
        cu_q_lens,
        num_seqs,
        sm_scale=sm_scale,
        sliding_window=sliding_window,
        soft_cap=soft_cap,
        mask_value=mask_value,
        k_scale=k_scale,
        v_scale=v_scale,
        num_kv_pages_per_block=num_kv_pages_per_block,
        num_queries_per_block=num_queries_per_block,
        vmem_limit_bytes=vmem_limit_bytes,
    )
    if mask_value is None:
        mask_value = DEFAULT_MASK_VALUE
    num_q_tokens, num_q_heads, head_dim = q.shape
    _, page_size, num_combined_kv_heads, _ = kv_pages.shape
    assert num_combined_kv_heads % 2 == 0
    num_kv_heads = num_combined_kv_heads // 2
    _, pages_per_seq = page_indices.shape
    num_q_heads_per_blk, num_combined_kv_heads_per_blk = get_min_heads_per_blk(
        num_q_heads, num_combined_kv_heads, q.dtype, kv_pages.dtype)
    num_q_per_blk = num_queries_per_block
    num_kv_pages_per_blk = num_kv_pages_per_block
    if num_q_per_blk is None or num_kv_pages_per_blk is None:
        num_kv_pages_per_blk, num_q_per_blk = get_tuned_block_sizes(
            q.dtype,
            kv_pages.dtype,
            num_q_heads_per_blk,
            num_combined_kv_heads_per_blk // 2,
            head_dim,
            page_size,
            num_q_tokens,
            pages_per_seq,
        )
    num_q_heads_per_kv_head = num_q_heads // num_kv_heads
    num_q_blks = cdiv(num_q_tokens, num_q_per_blk)
    assert num_combined_kv_heads_per_blk % 2 == 0
    num_kv_heads_per_blk = num_combined_kv_heads_per_blk // 2
    assert num_q_heads_per_blk % num_q_heads_per_kv_head == 0
    num_heads_blks = num_q_heads // num_q_heads_per_blk
    grid = (num_heads_blks, num_q_blks)

    def q_index_map(heads_blk_idx, q_blk_idx, *_):
        return (q_blk_idx, heads_blk_idx, 0)

    q_block_spec = pl.BlockSpec(
        (num_q_per_blk, num_q_heads_per_blk, head_dim),
        q_index_map,
    )
    in_specs = [
        q_block_spec,
        pl.BlockSpec(memory_space=pltpu.ANY),
    ]
    out_specs = q_block_spec
    lm_scratch = pltpu.VMEM(
        # TODO(jevinjiang): use 128 instead of 1 is due to Mosaic does not support
        # unaligned slicing!
        (num_kv_heads_per_blk, num_q_per_blk * num_q_heads_per_kv_head, 128),
        jnp.float32,
    )
    acc_scratch = pltpu.VMEM(
        (num_q_per_blk, num_q_heads_per_blk, head_dim),
        jnp.float32,
    )
    double_buf_scratch = pltpu.VMEM(
        (
            2,  # For double buffering during DMA copies.
            num_kv_pages_per_blk,
            page_size,
            num_combined_kv_heads_per_blk,
            head_dim,
        ),
        kv_pages.dtype,
    )
    scratch_shapes = [
        double_buf_scratch,  # kv_bufs
        pltpu.SemaphoreType.DMA((2, )),  # Semaphores for double buffers.
        lm_scratch,  # l_ref
        lm_scratch,  # m_ref
        acc_scratch,
    ]
    scalar_prefetches = (
        kv_lens,
        page_indices,
        cu_q_lens,
        jnp.array((0, 0), jnp.int32),  # seq_idx, buf_idx
        num_seqs,
    )
    kernel = pl.pallas_call(
        functools.partial(
            ragged_paged_attention_kernel,
            sm_scale=sm_scale,
            sliding_window=sliding_window,
            soft_cap=soft_cap,
            mask_value=mask_value,
            k_scale=k_scale,
            v_scale=v_scale,
        ),
        grid_spec=pltpu.PrefetchScalarGridSpec(
            num_scalar_prefetch=len(scalar_prefetches),
            in_specs=in_specs,
            out_specs=out_specs,
            grid=grid,
            scratch_shapes=scratch_shapes,
        ),
        compiler_params=pltpu.CompilerParams(
            dimension_semantics=(
                "arbitrary",
                "arbitrary",
            ),
            vmem_limit_bytes=vmem_limit_bytes,
        ),
        out_shape=jax.ShapeDtypeStruct(shape=q.shape, dtype=q.dtype),
        name="ragged_paged_attention_kernel",
    )

    return kernel(*scalar_prefetches, q, kv_pages)
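
A minimal invocation sketch (not part of the package; the shapes, sizes, and values below are illustrative assumptions) showing how the inputs fit together for a batch with one prefill sequence of 15 tokens and one decode sequence with a 40-token context. Actually running it requires a TPU backend for the Pallas kernel.

import jax.numpy as jnp

max_num_seqs, pages_per_seq, page_size, head_dim = 2, 4, 16, 128
num_q_heads, num_combined_kv_heads = 4, 2  # K and V interleaved -> 1 KV head
total_num_pages = max_num_seqs * pages_per_seq

q = jnp.zeros((16, num_q_heads, head_dim), jnp.bfloat16)  # 15 prefill + 1 decode tokens
kv_pages = jnp.zeros(
    (total_num_pages, page_size, num_combined_kv_heads, head_dim), jnp.bfloat16)
kv_lens = jnp.array([15, 40], jnp.int32)  # total context length per sequence
page_indices = jnp.arange(total_num_pages, dtype=jnp.int32).reshape(
    max_num_seqs, pages_per_seq)
cu_q_lens = jnp.array([0, 15, 16], jnp.int32)  # cumulative new-token counts
num_seqs = jnp.array([2], jnp.int32)

out = ragged_paged_attention(q, kv_pages, kv_lens, page_indices, cu_q_lens,
                             num_seqs, sm_scale=head_dim**-0.5)
assert out.shape == q.shape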