tpu-inference 0.11.1.dev202511150811__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tpu-inference might be problematic.
- tests/__init__.py +0 -0
- tests/core/__init__.py +0 -0
- tests/core/test_core_tpu.py +513 -0
- tests/core/test_disagg_executor.py +60 -0
- tests/core/test_disagg_utils.py +53 -0
- tests/core/test_dp_scheduler.py +899 -0
- tests/core/test_init.py +49 -0
- tests/kernels/__init__.py +0 -0
- tests/kernels/fused_moe_v1_test.py +105 -0
- tests/kernels/mla_v1_test.py +396 -0
- tests/kernels/quantized_matmul_kernel_test.py +191 -0
- tests/kernels/ragged_kv_cache_update_v2_test.py +234 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +400 -0
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +549 -0
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +504 -0
- tests/lora/__init__.py +0 -0
- tests/lora/conftest.py +32 -0
- tests/lora/test_bgmv.py +43 -0
- tests/lora/test_layers.py +654 -0
- tests/lora/test_lora.py +133 -0
- tests/lora/utils.py +96 -0
- tests/test_base.py +201 -0
- tests/test_envs.py +182 -0
- tests/test_quantization.py +836 -0
- tests/test_tpu_info.py +120 -0
- tests/test_utils.py +236 -0
- tpu_inference/__init__.py +34 -0
- tpu_inference/core/__init__.py +0 -0
- tpu_inference/core/core_tpu.py +786 -0
- tpu_inference/core/disagg_executor.py +118 -0
- tpu_inference/core/disagg_utils.py +51 -0
- tpu_inference/core/sched/__init__.py +0 -0
- tpu_inference/core/sched/dp_scheduler.py +523 -0
- tpu_inference/distributed/__init__.py +0 -0
- tpu_inference/distributed/jax_parallel_state.py +67 -0
- tpu_inference/distributed/tpu_connector.py +728 -0
- tpu_inference/distributed/utils.py +59 -0
- tpu_inference/env_override.py +9 -0
- tpu_inference/envs.py +107 -0
- tpu_inference/executors/__init__.py +0 -0
- tpu_inference/executors/ray_distributed_executor.py +362 -0
- tpu_inference/experimental/__init__.py +0 -0
- tpu_inference/experimental/llama3_jax_stashed.py +258 -0
- tpu_inference/kernels/__init__.py +0 -0
- tpu_inference/kernels/collectives/__init__.py +0 -0
- tpu_inference/kernels/collectives/all_gather_matmul.py +735 -0
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +60 -0
- tpu_inference/kernels/collectives/util.py +47 -0
- tpu_inference/kernels/flash_attention/__init__.py +0 -0
- tpu_inference/kernels/flash_attention/kernel.py +772 -0
- tpu_inference/kernels/fused_moe/__init__.py +0 -0
- tpu_inference/kernels/fused_moe/v1/__init__.py +0 -0
- tpu_inference/kernels/fused_moe/v1/kernel.py +1035 -0
- tpu_inference/kernels/mla/__init__.py +0 -0
- tpu_inference/kernels/mla/v1/__init__.py +0 -0
- tpu_inference/kernels/mla/v1/kernel.py +1349 -0
- tpu_inference/kernels/quantized_matmul/__init__.py +0 -0
- tpu_inference/kernels/quantized_matmul/kernel.py +395 -0
- tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
- tpu_inference/kernels/quantized_matmul/util.py +58 -0
- tpu_inference/kernels/ragged_paged_attention/__init__.py +0 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +0 -0
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +875 -0
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +287 -0
- tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +0 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1478 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1482 -0
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4147 -0
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +367 -0
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +51 -0
- tpu_inference/layers/__init__.py +0 -0
- tpu_inference/layers/common/__init__.py +0 -0
- tpu_inference/layers/common/attention_interface.py +390 -0
- tpu_inference/layers/common/attention_metadata.py +34 -0
- tpu_inference/layers/common/binary_search.py +295 -0
- tpu_inference/layers/common/quant_methods.py +8 -0
- tpu_inference/layers/common/sharding.py +582 -0
- tpu_inference/layers/jax/__init__.py +0 -0
- tpu_inference/layers/jax/attention/__init__.py +0 -0
- tpu_inference/layers/jax/attention/attention.py +255 -0
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +354 -0
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +262 -0
- tpu_inference/layers/jax/attention/llama4_attention.py +153 -0
- tpu_inference/layers/jax/base.py +151 -0
- tpu_inference/layers/jax/constants.py +88 -0
- tpu_inference/layers/jax/layers.py +301 -0
- tpu_inference/layers/jax/misc.py +16 -0
- tpu_inference/layers/jax/moe/__init__.py +0 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +608 -0
- tpu_inference/layers/jax/moe/gpt_oss_moe.py +185 -0
- tpu_inference/layers/jax/moe/moe.py +209 -0
- tpu_inference/layers/jax/rope.py +280 -0
- tpu_inference/layers/jax/rope_interface.py +214 -0
- tpu_inference/layers/jax/sample/__init__.py +0 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +515 -0
- tpu_inference/layers/jax/sample/sampling.py +96 -0
- tpu_inference/layers/jax/sample/sampling_metadata.py +76 -0
- tpu_inference/layers/jax/transformer_block.py +107 -0
- tpu_inference/layers/vllm/__init__.py +0 -0
- tpu_inference/layers/vllm/attention.py +221 -0
- tpu_inference/layers/vllm/fused_moe.py +507 -0
- tpu_inference/layers/vllm/linear_common.py +186 -0
- tpu_inference/layers/vllm/quantization/__init__.py +39 -0
- tpu_inference/layers/vllm/quantization/awq.py +207 -0
- tpu_inference/layers/vllm/quantization/common.py +105 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +0 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +120 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +203 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +0 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +208 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +136 -0
- tpu_inference/layers/vllm/quantization/mxfp4.py +266 -0
- tpu_inference/layers/vllm/quantization/unquantized.py +386 -0
- tpu_inference/layers/vllm/sharding.py +230 -0
- tpu_inference/logger.py +10 -0
- tpu_inference/lora/__init__.py +0 -0
- tpu_inference/lora/torch_lora_ops.py +103 -0
- tpu_inference/lora/torch_punica_tpu.py +311 -0
- tpu_inference/mock/__init__.py +0 -0
- tpu_inference/mock/vllm_config_utils.py +28 -0
- tpu_inference/mock/vllm_envs.py +1219 -0
- tpu_inference/mock/vllm_logger.py +212 -0
- tpu_inference/mock/vllm_logging_utils.py +15 -0
- tpu_inference/models/__init__.py +0 -0
- tpu_inference/models/common/__init__.py +0 -0
- tpu_inference/models/common/model_loader.py +444 -0
- tpu_inference/models/jax/__init__.py +0 -0
- tpu_inference/models/jax/deepseek_v3.py +868 -0
- tpu_inference/models/jax/gpt_oss.py +492 -0
- tpu_inference/models/jax/jax_intermediate_tensor.py +79 -0
- tpu_inference/models/jax/llama3.py +375 -0
- tpu_inference/models/jax/llama4.py +629 -0
- tpu_inference/models/jax/llama_eagle3.py +333 -0
- tpu_inference/models/jax/phi3.py +376 -0
- tpu_inference/models/jax/qwen2.py +375 -0
- tpu_inference/models/jax/qwen2_5_vl.py +1103 -0
- tpu_inference/models/jax/qwen3.py +302 -0
- tpu_inference/models/jax/utils/__init__.py +0 -0
- tpu_inference/models/jax/utils/file_utils.py +96 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +163 -0
- tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
- tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +5 -0
- tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +6 -0
- tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +5 -0
- tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +6 -0
- tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +105 -0
- tpu_inference/models/jax/utils/quantization/quantization_utils.py +653 -0
- tpu_inference/models/jax/utils/weight_utils.py +529 -0
- tpu_inference/models/vllm/__init__.py +0 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +286 -0
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +45 -0
- tpu_inference/platforms/__init__.py +2 -0
- tpu_inference/platforms/tpu_platform.py +269 -0
- tpu_inference/runner/__init__.py +0 -0
- tpu_inference/runner/block_table.py +122 -0
- tpu_inference/runner/compilation_manager.py +780 -0
- tpu_inference/runner/input_batch.py +435 -0
- tpu_inference/runner/kv_cache.py +132 -0
- tpu_inference/runner/kv_cache_manager.py +479 -0
- tpu_inference/runner/lora_utils.py +92 -0
- tpu_inference/runner/multimodal_manager.py +217 -0
- tpu_inference/runner/persistent_batch_manager.py +244 -0
- tpu_inference/runner/speculative_decoding_manager.py +248 -0
- tpu_inference/runner/structured_decoding_manager.py +88 -0
- tpu_inference/runner/tpu_runner.py +1620 -0
- tpu_inference/runner/utils.py +426 -0
- tpu_inference/spec_decode/__init__.py +0 -0
- tpu_inference/spec_decode/jax/__init__.py +0 -0
- tpu_inference/spec_decode/jax/eagle3.py +367 -0
- tpu_inference/tpu_info.py +77 -0
- tpu_inference/utils.py +317 -0
- tpu_inference/worker/__init__.py +0 -0
- tpu_inference/worker/tpu_worker.py +321 -0
- tpu_inference-0.11.1.dev202511150811.dist-info/METADATA +107 -0
- tpu_inference-0.11.1.dev202511150811.dist-info/RECORD +179 -0
- tpu_inference-0.11.1.dev202511150811.dist-info/WHEEL +5 -0
- tpu_inference-0.11.1.dev202511150811.dist-info/licenses/LICENSE +201 -0
- tpu_inference-0.11.1.dev202511150811.dist-info/top_level.txt +2 -0
tpu_inference/kernels/flash_attention/kernel.py

@@ -0,0 +1,772 @@
# SPDX-License-Identifier: Apache-2.0
"""Flash Attention TPU kernel."""
from __future__ import annotations

import dataclasses
import functools
import math
from typing import Any, NamedTuple

import jax
import jax.numpy as jnp
from jax import lax
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu

DEFAULT_MASK_VALUE = -0.7 * float(jnp.finfo(jnp.dtype("float32")).max)
NUM_LANES = 128
NUM_SUBLANES = 8


class SegmentIds(NamedTuple):
    """SegmentIds for Q and KV sequences.

    SegmentIds are used to generate segment mask, which prevents attention between
    different segments in the input sequence. Each array is a list of ids
    (integers).
    Only the tokens with the same id can attend to each other.

    Attributes:
      q: segment ids along the Q sequence.
      kv: segment ids along the KV sequence.
    """

    q: jax.Array  # [batch_size, q_seq_len]
    kv: jax.Array  # [batch_size, kv_seq_len]


@dataclasses.dataclass(frozen=True)
class BlockSizes:
    """Tile sizes parameterizing FlashAttention kernels.

    Those parameters have negligible effect on numerics, but affect performance
    greatly.
    """
    block_q: int
    block_k_major: int
    block_k: int
    block_b: int

    def __post_init__(self):

        def verify_major_minor(prefix, suffix, major, minor):
            if minor > major:
                raise ValueError(
                    f"{prefix}{suffix}={minor} should be smaller than"
                    f" {prefix}_major{suffix}={major}")
            if major % minor != 0:
                raise ValueError(f"{prefix}{suffix}={minor} should divide"
                                 f" {prefix}_major{suffix}={major}")

        verify_major_minor("block_k", "", self.block_k_major, self.block_k)

    @classmethod
    def get_default(cls, batch_size, num_heads, q_seq_len, kv_len, d_model):
        # TODO(apaszke,sharadmv): Select better parameters based on a heuristic.
        del batch_size, num_heads, q_seq_len, kv_len, d_model  # Unused.
        return BlockSizes(
            block_q=128,
            block_k_major=128,
            block_k=128,
            block_b=1,
        )


@functools.partial(
    jax.jit,
    static_argnames=[
        "causal",
        "sm_scale",
        "block_sizes",
        "vmem_limit_bytes",
        "debug",
    ],
)
def flash_attention(
        q,  # [batch_size, num_heads, q_seq_len, d_model]
        k,  # [batch_size, num_heads, kv_seq_len, d_model]
        v,  # [batch_size, num_heads, kv_seq_len, d_model]
        ab=None,  # [batch_size, num_heads, q_seq_len, kv_seq_len]
        segment_ids=None,  # q of [batch_size, q_seq_len] and kv of [batch_size, kv_seq_len]
        *,
        causal: bool = False,
        sm_scale: float = 1.0,
        block_sizes: BlockSizes | None = None,
        vmem_limit_bytes: int,
        debug: bool = False,
):
    batch_size, num_heads, q_seq_len, d_model = q.shape
    batch_size_k, num_heads_k, kv_seq_len, d_model_k = k.shape
    batch_size_v, num_heads_v, kv_seq_len_v, d_model_v = v.shape
    if batch_size != batch_size_k or batch_size != batch_size_v:
        raise ValueError(
            f"Batch size mismatch: got {batch_size}, {batch_size_k} and"
            f" {batch_size_v} (for q, k, v respectively)")
    if num_heads != num_heads_k or num_heads != num_heads_v:
        raise ValueError(
            f"Head count mismatch: got {num_heads}, {num_heads_k},"
            f" {num_heads_v} (for q, k, v respectively)")
    if d_model != d_model_k:
        raise ValueError(
            f"Model dimension mismatch: got {d_model} and {d_model_k} (for q and k"
            " respectively)")
    if d_model != d_model_v:
        raise NotImplementedError(
            "V model dimension unequal to KV model dimension unsupported")
    if kv_seq_len != kv_seq_len_v:
        raise ValueError(
            f"KV sequence length mismatch: got {kv_seq_len} and {kv_seq_len_v}"
        )
    if ab is not None:
        if ab.shape != (batch_size, num_heads, q_seq_len, kv_seq_len):
            raise ValueError(
                f"Attention bias shape mismatch: expected ({batch_size=},"
                f" {num_heads=}, {q_seq_len=}, {kv_seq_len=}), got {ab.shape}")
    if segment_ids is not None:
        if segment_ids.q.shape != (batch_size, q_seq_len):
            raise ValueError(
                f"Q segment ids shape mismatch: expected ({batch_size=},"
                f" {q_seq_len=},), got {segment_ids.q.shape}")
        if segment_ids.kv.shape != (batch_size, kv_seq_len):
            raise ValueError(
                f"KV segment ids shape mismatch: expected ({batch_size=},"
                f" {kv_seq_len=},), got {segment_ids.kv.shape}")
    if block_sizes is None:
        block_sizes = BlockSizes.get_default(batch_size, num_heads, q_seq_len,
                                             kv_seq_len, d_model)
    # TODO (KWang1998 & hfan): tune the block sizes properly.
    if kv_seq_len <= 92800:
        # Override block_k/block_k_major to use `_flash_attention_kernel_single_batch_single_step`.
        block_sizes = BlockSizes(block_q=block_sizes.block_q,
                                 block_b=block_sizes.block_b,
                                 block_k_major=kv_seq_len,
                                 block_k=kv_seq_len)
    return _flash_attention(q, k, v, ab, segment_ids, False, causal, sm_scale,
                            block_sizes, vmem_limit_bytes, debug)


def _flash_attention(
        q,
        k,
        v,
        ab,
        segment_ids,
        save_residuals,
        causal,
        sm_scale,
        block_sizes,
        vmem_limit_bytes,
        debug,
):
    return _flash_attention_impl(
        q,
        k,
        v,
        ab,
        segment_ids,
        save_residuals,
        causal,
        sm_scale,
        block_sizes.block_b,
        block_sizes.block_q,
        block_sizes.block_k_major,
        block_sizes.block_k,
        vmem_limit_bytes,
        debug,
    )


MIN_BLOCK_SIZE = 128
TRANS_B_DIM_NUMBERS = (((1, ), (1, )), ((), ()))


def below_or_on_diag(r, r_blk_size, c, c_blk_size):
    # A block is considered below or on diagonal as long as the bottom left
    # corner of the block is below or on diagonal.
    return ((r + 1) * r_blk_size - 1) > (c * c_blk_size)


def _flash_attention_kernel(q_tile_ref, *args, **kwargs):
    block_b = q_tile_ref.shape[0]
    # If we're not going to tile the softmax, then we can avoid a bunch of VPU ops.
    if kwargs["block_k"] == kwargs["kv_seq_len"]:
        kernel = _flash_attention_kernel_single_batch_single_step
    else:
        kernel = _flash_attention_kernel_single_batch
    for batch_idx in range(block_b):
        kernel((batch_idx, 0), q_tile_ref, *args, **kwargs)


def _flash_attention_kernel_single_batch(
        batch_idx: tuple[int, ...],
        q_tile_ref,
        k_tile_ref,
        v_tile_ref,
        ab_tile_ref,
        q_segment_ids_tile_ref,
        kv_segment_ids_tile_ref,  # Input arrays
        o_tile_ref,  # Output arrays
        l_ref,
        m_ref,
        m_scratch_ref,
        l_scratch_ref,
        acc_scratch_ref,
        *,
        causal,
        sm_scale,
        block_k,
        kv_seq_len,
        mask_value,
):
    block_k_major = k_tile_ref.shape[2]
    block_q = q_tile_ref.shape[2]
    head_dim = q_tile_ref.shape[-1]

    kv_seq_idx = pl.program_id(3)

    @pl.when(kv_seq_idx == 0)
    def start_new_sequence():
        m_scratch_ref[batch_idx] = jnp.full(m_scratch_ref.shape[2:], -jnp.inf,
                                            jnp.float32)
        l_scratch_ref[batch_idx] = jnp.zeros(l_scratch_ref.shape[2:],
                                             jnp.float32)
        acc_scratch_ref[batch_idx] = jnp.zeros(acc_scratch_ref.shape[2:],
                                               jnp.float32)

    q_seq_idx = pl.program_id(2)
    if causal:
        should_run = below_or_on_diag(q_seq_idx, block_q, kv_seq_idx,
                                      block_k_major)
    else:
        should_run = True

    @pl.when(should_run)
    def run():

        @pl.loop(0, block_k_major, step=block_k, unroll=True)
        def _body(start_k):
            m_prev = m_scratch_ref[batch_idx]
            l_prev = l_scratch_ref[batch_idx]
            q = q_tile_ref[batch_idx]  # [block_q, head_dim]
            k = k_tile_ref[(*batch_idx, pl.dslice(start_k, block_k),
                            slice(None))]  # [block_k, head_dim]

            s = jax.lax.dot_general(
                q, k, TRANS_B_DIM_NUMBERS,
                preferred_element_type=jnp.float32)  # [block_q, block_k]

            # Add attention bias if needed.
            # TODO(tanburn) Should the attention bias be added before or after
            # multiplication by sm_scale?
            if ab_tile_ref is not None:
                ab = ab_tile_ref[(*batch_idx, pl.dslice(None),
                                  pl.dslice(start_k,
                                            block_k))].astype(jnp.float32)
                s += ab

            if sm_scale != 1.0:
                s *= sm_scale

            mask = None
            if q_segment_ids_tile_ref is not None:
                repeats, rem = divmod(block_k, NUM_LANES)
                if rem:
                    raise NotImplementedError(
                        f"kv block size must be a multiple of {NUM_LANES}")
                q_segment_ids = pltpu.repeat(
                    q_segment_ids_tile_ref[batch_idx[0]], repeats,
                    axis=1)  # [block_q, block_k].
                kv_segment_ids = kv_segment_ids_tile_ref[
                    batch_idx[0], :1,
                    pl.dslice(start_k, block_k)]  # [1, block_k].
                mask = jnp.equal(q_segment_ids,
                                 kv_segment_ids).astype(jnp.bool_)

            if causal:
                mask_shape = (block_q, block_k)
                row_ids = jax.lax.broadcasted_iota(jnp.int32, mask_shape, 0)
                row_ids += q_seq_idx * block_q
                col_ids = jax.lax.broadcasted_iota(jnp.int32, mask_shape, 1)
                col_ids += kv_seq_idx * block_k_major + start_k
                causal_mask = col_ids <= row_ids
                mask = (causal_mask if mask is None else jnp.logical_and(
                    mask, causal_mask))

            s = s if mask is None else s + jnp.where(mask, 0.0, mask_value)

            m_curr = jnp.max(s, axis=1)[:, None]  # Row max, shape [block_q, 1].
            m_next = jnp.maximum(m_prev, m_curr)  # Shape [block_q, 128].

            block_k_repeats, rem = divmod(block_k, MIN_BLOCK_SIZE)
            if rem:
                raise NotImplementedError(
                    f"{block_k=} should be a multiple of {MIN_BLOCK_SIZE}")
            p = jnp.exp(s - pltpu.repeat(m_next, block_k_repeats, 1))

            alpha = jnp.exp(m_prev - m_next)  # Shape [block_q, 128].

            l_corr = alpha * l_prev

            l_next = jnp.sum(p, axis=1)[:, None] + l_corr  # Shape [block_q, 128]

            head_dim_repeats, rem = divmod(head_dim, MIN_BLOCK_SIZE)
            l_broadcast = lambda l: pltpu.repeat(l, head_dim_repeats, 1)
            if rem:
                if head_dim_repeats == 0:
                    l_broadcast = lambda l: l[:, :head_dim]
                else:
                    raise NotImplementedError(
                        f"{head_dim=} should be a multiple of {MIN_BLOCK_SIZE} if larger"
                    )
            l_scratch_ref[batch_idx] = l_next
            m_scratch_ref[batch_idx] = m_next

            l_next_inv_safe = jnp.where(l_next == 0.0, 1.0, 1.0 / l_next)
            acc_scratch_ref[batch_idx] *= l_broadcast(l_corr * l_next_inv_safe)
            v = v_tile_ref[(*batch_idx, pl.dslice(start_k,
                                                  block_k), slice(None))]
            o_curr = jax.lax.dot(p.astype(v.dtype),
                                 v,
                                 preferred_element_type=jnp.float32)
            acc_scratch_ref[batch_idx] += o_curr * l_broadcast(l_next_inv_safe)

    @pl.when(kv_seq_idx == (kv_seq_len // block_k_major) - 1)
    def store_output():
        o_tile_ref[batch_idx] = acc_scratch_ref[batch_idx].astype(
            o_tile_ref.dtype)
        if l_ref is not None:
            l_ref[batch_idx] = l_scratch_ref[batch_idx].astype(l_ref.dtype)
        if m_ref is not None:
            m_ref[batch_idx] = m_scratch_ref[batch_idx].astype(m_ref.dtype)


# ruff: noqa #731
# ruff: noqa #741
def _flash_attention_kernel_single_batch_single_step(
        batch_idx: tuple[int, ...],
        q_tile_ref,
        k_tile_ref,
        v_tile_ref,
        ab_tile_ref,
        q_segment_ids_tile_ref,
        kv_segment_ids_tile_ref,  # Input arrays
        o_tile_ref,  # Output arrays
        l_ref: Any | None = None,
        m_ref: Any | None = None,
        *,
        causal,
        sm_scale,
        block_k,
        kv_seq_len,
        mask_value,
):
    block_k_major = k_tile_ref.shape[2]
    block_q = q_tile_ref.shape[2]

    assert kv_seq_len == block_k_major == block_k

    q = q_tile_ref[batch_idx]  # [block_q, head_dim]
    k = k_tile_ref[batch_idx]  # [block_k, head_dim]
    s = jax.lax.dot_general(
        q, k, TRANS_B_DIM_NUMBERS,
        preferred_element_type=jnp.float32)  # [block_q, block_k]

    if ab_tile_ref is not None:
        s += ab_tile_ref[batch_idx].astype(jnp.float32)
    if sm_scale != 1.0:
        s *= sm_scale

    mask = None
    if q_segment_ids_tile_ref is not None:
        repeats, rem = divmod(block_k, NUM_LANES)
        if rem:
            raise NotImplementedError(
                f"kv block size must be a multiple of {NUM_LANES}")
        q_segment_ids = q_segment_ids_tile_ref[
            batch_idx[0]]  # [block_q, NUM_LANES].
        q_segment_ids = pltpu.repeat(q_segment_ids, repeats,
                                     axis=1)  # [block_q, block_k].
        kv_segment_ids = kv_segment_ids_tile_ref[batch_idx[0], :1]  # [1, block_k].
        mask = jnp.equal(q_segment_ids, kv_segment_ids).astype(jnp.bool_)

    if causal:
        q_seq_idx = pl.program_id(2)
        mask_shape = (block_q, block_k)
        row_ids = jax.lax.broadcasted_iota(jnp.int32, mask_shape, 0)
        row_ids += q_seq_idx * block_q
        col_ids = jax.lax.broadcasted_iota(jnp.int32, mask_shape, 1)
        causal_mask = col_ids <= row_ids
        mask = causal_mask if mask is None else jnp.logical_and(
            mask, causal_mask)
    s = s if mask is None else s + jnp.where(mask, 0.0, mask_value)

    m = jnp.max(s, axis=1)[:, None]
    p = jnp.exp(s - m)
    l = jnp.sum(p, axis=1)[:, None]
    p /= l

    if m_ref is not None:
        m_ref[batch_idx] = lax.broadcast_in_dim(m, m_ref.shape[2:], range(2))
    if l_ref is not None:
        l_ref[batch_idx] = lax.broadcast_in_dim(l, l_ref.shape[2:], range(2))

    v = v_tile_ref[batch_idx]
    o_tile_ref[batch_idx] = jax.lax.dot(
        p.astype(v.dtype), v,
        preferred_element_type=jnp.float32).astype(o_tile_ref.dtype)


def _bytes(x: jax.Array | jax.ShapeDtypeStruct) -> int:
    return math.prod(x.shape) * x.dtype.itemsize


def _fwd_cost_estimate(
        q: jax.Array,
        k: jax.Array,
        v: jax.Array,
        ab: jax.Array | None,
        segment_ids: SegmentIds | None,
        *,
        causal: bool,
        sm_scale: jax.Array | None,
        kernel_inputs_specs,
        kernel_outputs_specs,
) -> pl.CostEstimate | None:
    body_cost = pl.estimate_cost(mha_reference,
                                 q,
                                 k,
                                 v,
                                 ab,
                                 segment_ids,
                                 causal=causal,
                                 sm_scale=sm_scale)
    input_bytes = sum(_bytes(x) for x in jax.tree.leaves(kernel_inputs_specs))
    output_bytes = sum(
        _bytes(x) for x in jax.tree.leaves(kernel_outputs_specs))
    return pl.CostEstimate(
        flops=body_cost.flops,
        transcendentals=body_cost.transcendentals,
        bytes_accessed=input_bytes + output_bytes,
    )


def _flash_attention_impl(
        q,
        k,
        v,
        ab,
        segment_ids,
        save_residuals,
        causal,
        sm_scale,
        block_b,
        block_q,
        block_k_major,
        block_k,
        vmem_limit_bytes,
        debug,
):
    batch_size, num_heads, q_seq_len, head_dim = q.shape
    _, _, kv_seq_len, _ = k.shape
    _verify_block("block_q",
                  "q_seq_len",
                  block_q,
                  q_seq_len,
                  should_divide=False)
    _verify_block("block_k_major", "kv_seq_len", block_k_major, kv_seq_len)
    _verify_block("block_k", "kv_seq_len", block_k, kv_seq_len)
    _verify_block("block_b", "batch", block_b, batch_size, should_divide=False)

    # TODO(apaszke): Tile over heads as well.
    grid = (
        pl.cdiv(batch_size, block_b),
        num_heads,
        pl.cdiv(q_seq_len, block_q),
        kv_seq_len // block_k_major,
    )

    def q_index_map(batch_index, head_index, q_seq_index, _):
        return (batch_index, head_index, q_seq_index, 0)

    def kv_index_map(batch_index, head_index, q_seq_index, kv_seq_index):
        if causal:
            # If the kv block is skipped, prefetch the next valid kv block, i.e. the
            # 0th one to be used for the next block_q rows.
            next_kv_index = lax.select(
                below_or_on_diag(q_seq_index, block_q, kv_seq_index,
                                 block_k_major),
                kv_seq_index,
                0,
            )
        else:
            next_kv_index = kv_seq_index
        return (batch_index, head_index, next_kv_index, 0)

    def ab_index_map(batch_index, head_index, q_seq_index, kv_seq_index):
        if causal:
            should_run = below_or_on_diag(q_seq_index, block_q, kv_seq_index,
                                          block_k_major)
            # If the ab block is skipped, prefetch the next valid ab block, i.e. the
            # 0th kv to be used for the next block_q rows.
            next_q_index = lax.select(
                should_run,
                q_seq_index,
                lax.select(q_seq_index == (q_seq_len // block_q) - 1, 0,
                           q_seq_index + 1),
            )
            next_kv_index = lax.select(should_run, kv_seq_index, 0)
        else:
            next_q_index = q_seq_index
            next_kv_index = kv_seq_index

        return (batch_index, head_index, next_q_index, next_kv_index)

    def o_index_map(batch_index, head_index, q_seq_index, _):
        return (batch_index, head_index, q_seq_index, 0)

    def lm_index_map(batch_index, head_index, q_seq_index, _):
        return (batch_index, head_index, q_seq_index, 0)

    kernel = functools.partial(
        _flash_attention_kernel,
        causal=causal,
        mask_value=DEFAULT_MASK_VALUE,
        sm_scale=sm_scale,
        block_k=block_k,
        kv_seq_len=kv_seq_len,
    )
    out_shape = jax.ShapeDtypeStruct(shape=q.shape, dtype=q.dtype)
    out_shape = [out_shape]
    out_specs = [pl.BlockSpec((block_b, 1, block_q, head_dim), o_index_map)]

    if block_k != kv_seq_len:
        m_scratch = pltpu.VMEM((block_b, 1, block_q, MIN_BLOCK_SIZE),
                               jnp.float32)
        l_scratch = pltpu.VMEM((block_b, 1, block_q, MIN_BLOCK_SIZE),
                               jnp.float32)
        acc_scratch = pltpu.VMEM((block_b, 1, block_q, head_dim), jnp.float32)
        scratch_shapes = [m_scratch, l_scratch, acc_scratch]
    else:
        scratch_shapes = []

    if save_residuals:
        out_specs = [
            *out_specs,
            pl.BlockSpec((block_b, 1, block_q, MIN_BLOCK_SIZE), lm_index_map),
            pl.BlockSpec((block_b, 1, block_q, MIN_BLOCK_SIZE), lm_index_map),
        ]
        l = jax.ShapeDtypeStruct(
            (batch_size, num_heads, q_seq_len, MIN_BLOCK_SIZE),
            dtype=jnp.float32)
        m = jax.ShapeDtypeStruct(
            (batch_size, num_heads, q_seq_len, MIN_BLOCK_SIZE),
            dtype=jnp.float32)
        out_shape = (*out_shape, l, m)
    else:
        out_specs = [*out_specs, None, None]
        out_shape = (*out_shape, None, None)

    ab_block_spec = (pl.BlockSpec(
        (block_b, 1, block_q,
         block_k_major), ab_index_map) if ab is not None else None)

    q_segment_ids_spec = kv_segment_ids_spec = None
    q_segment_ids = kv_segment_ids = None
    if segment_ids is not None:

        def q_segment_ids_index_map(batch_index, head_index, q_seq_index, _):
            del head_index
            return (batch_index, q_seq_index, 0)

        def kv_segment_ids_index_map(batch_index, head_index, q_seq_index,
                                     kv_seq_index):
            del head_index
            if causal:
                next_kv_index = lax.select(
                    below_or_on_diag(q_seq_index, block_q, kv_seq_index,
                                     block_k_major),
                    kv_seq_index,
                    0,
                )
            else:
                next_kv_index = kv_seq_index
            return (batch_index, 0, next_kv_index)

        q_segment_ids_spec = pl.BlockSpec((block_b, block_q, NUM_LANES),
                                          q_segment_ids_index_map)
        kv_segment_ids_spec = pl.BlockSpec(
            (block_b, NUM_SUBLANES, block_k_major), kv_segment_ids_index_map)

        q_segment_ids = jax.lax.broadcast_in_dim(
            segment_ids.q,
            (batch_size, q_seq_len, NUM_LANES),
            (
                0,
                1,
            ),
        )
        kv_segment_ids = jax.lax.broadcast_in_dim(
            segment_ids.kv,
            (batch_size, NUM_SUBLANES, kv_seq_len),
            (
                0,
                2,
            ),
        )

    in_specs = [
        pl.BlockSpec((block_b, 1, block_q, head_dim), q_index_map),
        pl.BlockSpec((block_b, 1, block_k_major, head_dim), kv_index_map),
        pl.BlockSpec((block_b, 1, block_k_major, head_dim), kv_index_map),
        ab_block_spec,
        q_segment_ids_spec,
        kv_segment_ids_spec,
    ]

    o, *aux = pl.pallas_call(
        kernel,
        grid_spec=pltpu.PrefetchScalarGridSpec(
            num_scalar_prefetch=0,
            grid=grid,
            in_specs=in_specs,
            out_specs=out_specs,
            scratch_shapes=scratch_shapes,
        ),
        out_shape=out_shape,
        debug=debug,
        compiler_params=pltpu.CompilerParams(
            dimension_semantics=(
                "parallel",
                "parallel",
                "parallel",
                "arbitrary",
            ),
            vmem_limit_bytes=vmem_limit_bytes,
        ),
        cost_estimate=_fwd_cost_estimate(
            q,
            k,
            v,
            ab,
            segment_ids,
            causal=causal,
            sm_scale=sm_scale,
            kernel_inputs_specs=(q, k, v, ab, q_segment_ids, kv_segment_ids),
            kernel_outputs_specs=out_shape,
        ),
    )(q, k, v, ab, q_segment_ids, kv_segment_ids)
    if save_residuals:
        l, m = (v[..., 0] for v in aux[-2:])
        return (o, l, m)
    else:
        return o


# For autograd testing.
def mha_reference_no_custom_vjp(
        q,
        k,
        v,
        ab: jax.Array | None = None,
        segment_ids: SegmentIds | None = None,
        *,
        causal: bool = False,
        mask_value: float = DEFAULT_MASK_VALUE,
        sm_scale: float = 1.0,
        save_residuals: bool = False,
):
    logits = jnp.einsum("bhqc,bhkc->bhqk", q, k)
    if ab is not None:
        logits += ab
    if sm_scale != 1.0:
        logits *= sm_scale

    mask = None
    if segment_ids is not None:
        mask = segment_ids.q[:, :, None] == segment_ids.kv[:, None, :]
        mask = mask[:, None, :, :]

    if causal:
        _, _, q_seq_len, _ = q.shape
        _, _, kv_seq_len, _ = k.shape
        mask_shape = (q_seq_len, kv_seq_len)
        row_ids = jax.lax.broadcasted_iota(jnp.int32, mask_shape, 0)
        col_ids = jax.lax.broadcasted_iota(jnp.int32, mask_shape, 1)
        causal_mask = (col_ids <= row_ids)[None, None, :, :]
        mask = causal_mask if mask is None else jnp.logical_and(
            mask, causal_mask)

    logits = logits if mask is None else logits + jnp.where(
        mask, 0.0, mask_value)

    m = logits.max(axis=-1)
    unnormalized = jnp.exp(logits - m[..., None])
    l = unnormalized.sum(axis=-1)
    weights = unnormalized / l[..., None]
    out = jnp.einsum("bhqk,bhkc->bhqc", weights, v)
    if save_residuals:
        return out, l, m
    return out


@functools.partial(jax.jit,
                   static_argnames=["causal", "mask_value", "sm_scale"])
@jax.default_matmul_precision("bfloat16")
def mha_reference(
        q,
        k,
        v,
        ab,
        segment_ids: SegmentIds | None = None,
        causal: bool = False,
        mask_value: float = DEFAULT_MASK_VALUE,
        sm_scale=1.0,
):
    return _mha_reference(
        q,
        k,
        v,
        ab,
        segment_ids,
        causal=causal,
        mask_value=mask_value,
        sm_scale=sm_scale,
        save_residuals=False,
    )


def _mha_reference(
        q,
        k,
        v,
        ab,
        segment_ids: SegmentIds | None,
        causal: bool,
        mask_value: float,
        sm_scale: float,
        save_residuals: bool,
):
    return mha_reference_no_custom_vjp(
        q,
        k,
        v,
        ab,
        segment_ids,
        causal=causal,
        mask_value=mask_value,
        sm_scale=sm_scale,
        save_residuals=save_residuals,
    )


def _verify_block(block_name, dim_name, block, dim, should_divide=True):
    if block > dim:
        raise ValueError(
            f"{block_name}={block} should be smaller or equal to {dim_name}={dim}"
        )
    if should_divide and dim % block != 0:
        raise ValueError(
            f"{dim_name}={dim} should be divisible by {block_name}={block}")
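
For reference, a minimal sketch of how the flash_attention entry point above might be invoked. The shapes, dtype, and vmem_limit_bytes budget are illustrative assumptions rather than values taken from this package, and the call assumes a TPU runtime where the Pallas kernel can compile:

    import jax
    import jax.numpy as jnp

    from tpu_inference.kernels.flash_attention.kernel import flash_attention

    # Illustrative dimensions only: [batch, heads, seq_len, head_dim].
    batch, heads, seq_len, head_dim = 1, 8, 1024, 128
    keys = jax.random.split(jax.random.PRNGKey(0), 3)
    q = jax.random.normal(keys[0], (batch, heads, seq_len, head_dim), jnp.bfloat16)
    k = jax.random.normal(keys[1], (batch, heads, seq_len, head_dim), jnp.bfloat16)
    v = jax.random.normal(keys[2], (batch, heads, seq_len, head_dim), jnp.bfloat16)

    out = flash_attention(
        q, k, v,
        causal=True,
        sm_scale=1.0 / head_dim**0.5,
        vmem_limit_bytes=128 * 1024 * 1024,  # assumed VMEM budget, tune for the target chip
    )
    # `out` has the same shape and dtype as `q`; block sizes default to 128 and,
    # for kv_seq_len <= 92800, collapse to the single-step kernel path shown above.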