ai-edge-torch-nightly 0.2.0.dev20240604__py3-none-any.whl → 0.2.0.dev20240606__py3-none-any.whl
This diff compares two publicly available versions of the package as released to a supported registry. It is provided for informational purposes only and reflects the changes between the package versions exactly as they appear in the public registry.
Potentially problematic release.
This version of ai-edge-torch-nightly might be problematic.
- ai_edge_torch/convert/fx_passes/build_aten_composite_pass.py +30 -0
- ai_edge_torch/convert/test/test_convert_composites.py +18 -0
- ai_edge_torch/generative/examples/stable_diffusion/clip.py +83 -49
- ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +7 -5
- ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +0 -260
- ai_edge_torch/generative/examples/t5/t5_attention.py +2 -2
- ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +161 -0
- ai_edge_torch/generative/layers/attention.py +27 -114
- ai_edge_torch/generative/layers/builder.py +4 -0
- ai_edge_torch/generative/layers/model_config.py +5 -1
- ai_edge_torch/generative/layers/scaled_dot_product_attention.py +117 -0
- ai_edge_torch/generative/test/test_model_conversion.py +90 -80
- ai_edge_torch/generative/utilities/loader.py +56 -27
- {ai_edge_torch_nightly-0.2.0.dev20240604.dist-info → ai_edge_torch_nightly-0.2.0.dev20240606.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.2.0.dev20240604.dist-info → ai_edge_torch_nightly-0.2.0.dev20240606.dist-info}/RECORD +18 -16
- {ai_edge_torch_nightly-0.2.0.dev20240604.dist-info → ai_edge_torch_nightly-0.2.0.dev20240606.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.2.0.dev20240604.dist-info → ai_edge_torch_nightly-0.2.0.dev20240606.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.2.0.dev20240604.dist-info → ai_edge_torch_nightly-0.2.0.dev20240606.dist-info}/top_level.txt +0 -0

ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py
@@ -0,0 +1,161 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+# A toy example which has basic transformer block (w/ externalized KV-Cache).
+
+from typing import List, Tuple
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch_xla
+
+import ai_edge_torch
+import ai_edge_torch.generative.layers.attention_utils as attn_utils
+import ai_edge_torch.generative.layers.builder as builder
+from ai_edge_torch.generative.layers.experimental.attention import TransformerBlock  # NOQA
+import ai_edge_torch.generative.layers.model_config as cfg
+
+RoPECache = Tuple[torch.Tensor, torch.Tensor]
+
+
+class ToyModelWithExternalKV(torch.nn.Module):
+
+  def __init__(self, config: cfg.ModelConfig) -> None:
+    super().__init__()
+    self.lm_head = nn.Linear(
+        config.embedding_dim, config.vocab_size, bias=config.lm_head_use_bias
+    )
+    self.tok_embedding = nn.Embedding(config.vocab_size, config.embedding_dim)
+    self.transformer_blocks = nn.ModuleList(
+        TransformerBlock(config) for _ in range(config.num_layers)
+    )
+    self.final_norm = builder.build_norm(
+        config.embedding_dim,
+        config.final_norm_config,
+    )
+    self.rope_cache = attn_utils.build_rope_cache(
+        size=config.max_seq_len,
+        dim=int(config.attn_config.rotary_percentage * config.head_dim),
+        base=10_000,
+        condense_ratio=1,
+        dtype=torch.float32,
+        device=torch.device('cpu'),
+    )
+    self.mask_cache = attn_utils.build_causal_mask_cache(
+        size=config.max_seq_len, dtype=torch.float32, device=torch.device('cpu')
+    )
+    self.config = config
+
+  def forward(
+      self,
+      idx: torch.Tensor,
+      input_pos: torch.Tensor,
+      k_caches: torch.Tensor,
+      v_caches: torch.Tensor,
+  ) -> (torch.Tensor, torch.Tensor, torch.Tensor):
+    x = self.tok_embedding(idx)
+    cos, sin = self.rope_cache
+    cos = cos.index_select(0, input_pos)
+    sin = sin.index_select(0, input_pos)
+    mask = self.mask_cache.index_select(2, input_pos)
+    mask = mask[:, :, :, : self.config.max_seq_len]
+
+    for i, block in enumerate(self.transformer_blocks):
+      input_k, input_v = k_caches[i], v_caches[i]
+      x, (updated_k, updated_v) = block(
+          x, (cos, sin), mask, input_pos, (input_k, input_v)
+      )
+      k_caches[i], v_caches[i] = updated_k, updated_v
+
+    x = self.final_norm(x)
+    return self.lm_head(x), k_caches, v_caches
+
+
+def _export_stablehlo_mlir(model, args):
+  ep = torch.export.export(model, args)
+  stablehlo_gm = torch_xla.stablehlo.exported_program_to_stablehlo(ep)
+  return stablehlo_gm.get_stablehlo_text()
+
+
+def get_model_config() -> cfg.ModelConfig:
+  attn_config = cfg.AttentionConfig(
+      num_heads=32, num_query_groups=4, rotary_percentage=1.0
+  )
+  ff_config = cfg.FeedForwardConfig(
+      type=cfg.FeedForwardType.GATED,
+      activation=cfg.ActivationType.SILU,
+      intermediate_size=256,
+  )
+  norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.RMS_NORM)
+  config = cfg.ModelConfig(
+      vocab_size=150,
+      num_layers=2,
+      max_seq_len=100,
+      embedding_dim=128,
+      attn_config=attn_config,
+      ff_config=ff_config,
+      pre_attention_norm_config=norm_config,
+      pre_ff_norm_config=norm_config,
+      final_norm_config=norm_config,
+      enable_hlfb=True,
+  )
+  return config
+
+
+def get_sample_prefill_inputs() -> Tuple[torch.Tensor, torch.Tensor]:
+  idx = torch.unsqueeze(torch.arange(0, 100), 0)
+  input_pos = torch.arange(0, 100)
+  return idx, input_pos
+
+
+def get_sample_decode_inputs() -> Tuple[torch.Tensor, torch.Tensor]:
+  idx = torch.tensor([[1]], dtype=torch.long)
+  input_pos = torch.tensor([10])
+  return idx, input_pos
+
+
+def define_and_run() -> None:
+  dump_mlir = False
+
+  config = get_model_config()
+  model = ToyModelWithExternalKV(config)
+  print('running an inference')
+  k_caches = torch.zeros((2, 1, 100, 4, 4), dtype=torch.float32)
+  v_caches = torch.zeros((2, 1, 100, 4, 4), dtype=torch.float32)
+
+  idx, input_pos = get_sample_prefill_inputs()
+  decode_idx, decode_input_pos = get_sample_decode_inputs()
+  print(model.forward(idx, input_pos, k_caches, v_caches))
+
+  if dump_mlir:
+    mlir_text = _export_stablehlo_mlir(model, (idx, input_pos, k_caches, v_caches))
+    with open('/tmp/toy_model_with_external_kv.stablehlo.mlir', 'w') as f:
+      f.write(mlir_text)
+
+  # Convert model to tflite with 2 signatures (prefill + decode).
+  # TODO(b/344014416): currently conversion will fail, because we generate int64 index
+  # in dynamic update slice op.
+  print('converting toy model to tflite with 2 signatures (prefill + decode)')
+  edge_model = (
+      ai_edge_torch.signature('prefill', model, (idx, input_pos, k_caches, v_caches))
+      .signature('decode', model, (decode_idx, decode_input_pos, k_caches, v_caches))
+      .convert()
+  )
+  edge_model.export('/tmp/toy_external_kv_cache.tflite')
+
+
+if __name__ == '__main__':
+  with torch.inference_mode():
+    define_and_run()
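
The toy model above treats the KV-cache as plain tensors that the caller owns and passes back in on every call, which is what makes separate prefill and decode signatures possible. The sketch below is not part of the package; the module import path is assumed from the file listing above, and it simply replays that flow in plain PyTorch using the helpers defined in the new file.

# Sketch only: prefill fills the caller-owned caches, decode reuses them.
import torch

from ai_edge_torch.generative.examples.test_models import (
    toy_model_with_external_kv_cache as toy,  # assumed module path
)

config = toy.get_model_config()
model = toy.ToyModelWithExternalKV(config)

# Caches are ordinary tensors owned by the caller, shaped
# (num_layers, batch, max_seq_len, num_query_groups, head_dim) as in define_and_run().
k_caches = torch.zeros((2, 1, 100, 4, 4), dtype=torch.float32)
v_caches = torch.zeros((2, 1, 100, 4, 4), dtype=torch.float32)

with torch.inference_mode():
  # Prefill: process the whole prompt and populate the caches.
  idx, input_pos = toy.get_sample_prefill_inputs()
  logits, k_caches, v_caches = model(idx, input_pos, k_caches, v_caches)

  # Decode: feed one token at its position, reusing the updated caches.
  decode_idx, decode_input_pos = toy.get_sample_decode_inputs()
  logits, k_caches, v_caches = model(decode_idx, decode_input_pos, k_caches, v_caches)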

ai_edge_torch/generative/layers/attention.py
@@ -14,7 +14,6 @@
 # ==============================================================================
 # Common building blocks for Attention layer.
 
-import math
 from typing import Optional, Tuple
 
 import torch
@@ -25,101 +24,8 @@ import ai_edge_torch.generative.layers.builder as builder
 from ai_edge_torch.generative.layers.kv_cache import KVCache
 import ai_edge_torch.generative.layers.model_config as cfg
 import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb
-from ai_edge_torch.hlfb import StableHLOCompositeBuilder
-
-
-def scaled_dot_product_attention(
-    q: torch.Tensor,
-    k: torch.Tensor,
-    v: torch.Tensor,
-    head_size: int,
-    mask: Optional[torch.Tensor] = None,
-    scale: Optional[float] = None,
-):
-  """Scaled dot product attention.
-
-  Args:
-    q (torch.Tensor): Query tensor, with shape [B, T, N, H].
-    k (torch.Tensor): Key tensor, with shape [B, T, KV_LEN, H].
-    v (torch.Tensor): Value tensor, with shape [B, T, KV_LEN, H].
-    head_size (int): head dimension.
-    mask (torch.Tensor): the optional mask tensor.
-
-  Returns:
-    The output tensor of scaled_dot_product_attention.
-  """
-
-  if scale is None:
-    scale = 1.0 / math.sqrt(head_size)
-
-  q = q.transpose(1, 2)
-  k = k.transpose(1, 2)
-  v = v.transpose(1, 2)
-  if q.size() != k.size():
-    # Handle the GQA case, where q.shape[1] % k.shape[1] == 0.
-    k = k.repeat_interleave(q.shape[1] // k.shape[1], dim=1)
-    v = v.repeat_interleave(q.shape[1] // v.shape[1], dim=1)
-  y = F.scaled_dot_product_attention(
-      q,
-      k,
-      v,
-      attn_mask=mask,
-      dropout_p=0.0,
-      is_causal=mask is None,
-      scale=scale,
-  )
-  return y.transpose(1, 2)
-
-
-def scaled_dot_product_attention_with_hlfb(
-    q: torch.Tensor,
-    k: torch.Tensor,
-    v: torch.Tensor,
-    head_size: int,
-    mask: Optional[torch.Tensor] = None,
-    scale: Optional[float] = None,
-):
-  """Scaled dot product attention with high-level function boundary enabled.
-
-  Args:
-    q (torch.Tensor): Query tensor, with shape [B, T, N, H].
-    k (torch.Tensor): Key tensor, with shape [B, T, KV_LEN, H].
-    v (torch.Tensor): Value tensor, with shape [B, T, KV_LEN, H].
-    head_size (int): head dimension.
-    mask (torch.Tensor): the optional mask tensor.
-
-  Returns:
-    The output tensor of scaled_dot_product_attention.
-  """
-
-  if scale is None:
-    scale = 1.0 / math.sqrt(head_size)
-
-  builder = StableHLOCompositeBuilder(
-      name="odml.scaled_dot_product_attention", attr={"scale": scale}
-  )
-  q, k, v, mask = builder.mark_inputs(q, k, v, mask)
-
-  q = q.transpose(1, 2)
-  k = k.transpose(1, 2)
-  v = v.transpose(1, 2)
-  if q.size() != k.size():
-    # Handle the GQA case, where q.shape[1] % k.shape[1] == 0.
-    k = k.repeat_interleave(q.shape[1] // k.shape[1], dim=1)
-    v = v.repeat_interleave(q.shape[1] // v.shape[1], dim=1)
-  y = F.scaled_dot_product_attention(
-      q,
-      k,
-      v,
-      attn_mask=mask,
-      dropout_p=0.0,
-      is_causal=mask is None,
-      scale=scale,
-  )
-
-  result = y.transpose(1, 2)
-  result = builder.mark_outputs(result)
-  return result
+from ai_edge_torch.generative.layers.scaled_dot_product_attention import scaled_dot_product_attention  # NOQA
+from ai_edge_torch.generative.layers.scaled_dot_product_attention import scaled_dot_product_attention_with_hlfb  # NOQA
 
 
 class TransformerBlock(nn.Module):
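
This hunk moves the two scaled-dot-product-attention helpers out of attention.py and re-imports them there, so the refactor should be transparent to existing callers. A quick sanity check, assuming the re-export lines above behave as written:

# Both import paths should resolve to the same function object after the move.
from ai_edge_torch.generative.layers.attention import (
    scaled_dot_product_attention as sdpa_via_attention,
)
from ai_edge_torch.generative.layers.scaled_dot_product_attention import (
    scaled_dot_product_attention as sdpa_direct,
)

assert sdpa_via_attention is sdpa_direct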
@@ -151,7 +57,7 @@ class TransformerBlock(nn.Module):
   def forward(
       self,
       x: torch.Tensor,
-      rope: Tuple[torch.Tensor, torch.Tensor],
+      rope: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
       mask: Optional[torch.Tensor] = None,
       input_pos: Optional[torch.Tensor] = None,
   ) -> torch.Tensor:
@@ -182,7 +88,6 @@ class TransformerBlock(nn.Module):
     return output
 
 
-# CausalSelfAttention which can support MHQ, MQA or GQA.
 class CausalSelfAttention(nn.Module):
 
   def __init__(
@@ -229,11 +134,12 @@ class CausalSelfAttention(nn.Module):
   def forward(
       self,
       x: torch.Tensor,
-      rope: Tuple[torch.Tensor, torch.Tensor],
+      rope: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
       mask: Optional[torch.Tensor] = None,
       input_pos: Optional[torch.Tensor] = None,
   ) -> torch.Tensor:
-    """Forward function of the CausalSelfAttention layer
+    """Forward function of the CausalSelfAttention layer, which can support
+    MQA, GQA and MHA.
 
     Args:
       x (torch.Tensor): the input tensor.
@@ -253,28 +159,35 @@ class CausalSelfAttention(nn.Module):
     # Assemble into a number of query groups to support MHA, MQA and GQA.
     q_per_kv = self.config.num_heads // self.config.num_query_groups
     total_qkv = q_per_kv + 2  # Each group has >=1 queries, 1 key, and 1 value.
-
-
-
+    if self.config.qkv_transpose_before_split:
+      qkv = qkv.view(
+          B, T, total_qkv, self.config.num_query_groups, self.head_dim
+      )  # (B, T, total_qkv, num_query_groups, head_dim)
+      qkv_axis = -3
+    else:
+      qkv = qkv.view(
+          B, T, self.config.num_query_groups, total_qkv, self.head_dim
+      )  # (B, T, num_query_groups, total_qkv, head_dim)
+      qkv_axis = -2
 
     # Split batched computation into three.
-    q, k, v = qkv.split((q_per_kv, 1, 1), dim=-2)
-
+    q, k, v = qkv.split((q_per_kv, 1, 1), dim=qkv_axis)
     q = q.reshape(B, T, -1, self.head_dim)
     k = k.reshape(B, T, -1, self.head_dim)
     v = v.reshape(B, T, -1, self.head_dim)
 
     # Compute rotary positional embedding for query and key.
     n_elem = int(self.config.rotary_percentage * self.head_dim)
-
-
-
-
-
-
-
-
-
+    if n_elem > 0:
+      cos, sin = rope
+      q_roped = rotary_pos_emb.apply_rope(
+          q[..., :n_elem], cos.repeat(1, 2), sin.repeat(1, 2)
+      )
+      k_roped = rotary_pos_emb.apply_rope(
+          k[..., :n_elem], cos.repeat(1, 2), sin.repeat(1, 2)
+      )
+      q = torch.cat((q_roped, q[..., n_elem:]), dim=-1)
+      k = torch.cat((k_roped, k[..., n_elem:]), dim=-1)
 
     if self.kv_cache is not None:
       # TODO(haoliang): Handle when execeeding max sequence length.
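
The new qkv_transpose_before_split flag only changes how the fused qkv projection is viewed before it is split into query, key, and value groups. A standalone sketch of the two layouts, using made-up dimensions rather than package code:

# Both layouts view the same fused qkv tensor; only the axis order differs.
import torch

B, T, head_dim = 1, 6, 4
num_heads, num_query_groups = 8, 2
q_per_kv = num_heads // num_query_groups   # 4 query heads share each KV head
total_qkv = q_per_kv + 2                   # queries + 1 key + 1 value per group

qkv = torch.randn(B, T, (num_heads + 2 * num_query_groups) * head_dim)

# Layout A (qkv_transpose_before_split=True): qkv axis before the group axis.
qkv_a = qkv.view(B, T, total_qkv, num_query_groups, head_dim)
q_a, k_a, v_a = qkv_a.split((q_per_kv, 1, 1), dim=-3)

# Layout B (default): group axis before the qkv axis, split on dim=-2.
qkv_b = qkv.view(B, T, num_query_groups, total_qkv, head_dim)
q_b, k_b, v_b = qkv_b.split((q_per_kv, 1, 1), dim=-2)

# Both layouts are then reshaped to (B, T, num_heads, head_dim).
print(q_a.reshape(B, T, -1, head_dim).shape, q_b.reshape(B, T, -1, head_dim).shape)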

ai_edge_torch/generative/layers/builder.py
@@ -97,6 +97,10 @@ def _get_activation(type_: cfg.ActivationType):
     return F.gelu
   elif type_ == cfg.ActivationType.GELU_TANH:
     return lambda x: F.gelu(x, approximate="tanh")
+  elif type_ == cfg.ActivationType.GELU_QUICK:
+    # GELU approximation that is fast but somewhat inaccurate.
+    # See: https://github.com/hendrycks/GELUs
+    return lambda x: x * F.sigmoid(1.702 * x)
   elif type_ == cfg.ActivationType.RELU:
     return F.relu
   else:
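
GELU_QUICK trades accuracy for a cheaper sigmoid-based approximation, x * sigmoid(1.702 * x). A standalone comparison against torch's exact and tanh-approximated GELU (the error figures in the comments are rough estimates, not from the package):

# Rough numerical check of the quick-GELU approximation error.
import torch
import torch.nn.functional as F

x = torch.linspace(-4, 4, steps=9)
gelu_exact = F.gelu(x)
gelu_tanh = F.gelu(x, approximate="tanh")
gelu_quick = x * torch.sigmoid(1.702 * x)

print((gelu_quick - gelu_exact).abs().max())  # roughly 0.02 on this grid, worst near |x| ~ 2
print((gelu_tanh - gelu_exact).abs().max())   # orders of magnitude smaller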

ai_edge_torch/generative/layers/model_config.py
@@ -27,6 +27,7 @@ class ActivationType(enum.Enum):
   SILU = enum.auto()
   GELU = enum.auto()
   GELU_TANH = enum.auto()
+  GELU_QUICK = enum.auto()
   RELU = enum.auto()
 
 
@@ -46,7 +47,7 @@ class FeedForwardType(enum.Enum):
 
   # `output = linear(act(linear(x)))`.
   SEQUENTIAL = enum.auto()
-  # `output =
+  # `output = linear_2(act(linear_1(x)) * lienar_3(x))`.
   GATED = enum.auto()
 
 
@@ -60,6 +61,9 @@ class AttentionConfig:
   num_query_groups: Optional[int]
   # Percentage of Rotary Positional Embedding added Q and K projections.
   rotary_percentage: Optional[float] = None
+  # Whether to transpose the query groups of qkv bundled tensor before
+  # splitting into separated tensors.
+  qkv_transpose_before_split: bool = False
   # Whether to use bias with Query, Key, and Value projection.
   qkv_use_bias: bool = False
   # Whether to use bias with attention output projection.
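
Both model_config.py additions are opt-in fields whose defaults preserve existing behavior. A hedged configuration sketch exercising the two new knobs; field names other than those two are taken from the toy example earlier in this diff:

# Sketch only: building configs that use GELU_QUICK and qkv_transpose_before_split.
import ai_edge_torch.generative.layers.model_config as cfg

attn_config = cfg.AttentionConfig(
    num_heads=32,
    num_query_groups=4,
    rotary_percentage=1.0,
    qkv_transpose_before_split=True,  # new in this release; defaults to False
)
ff_config = cfg.FeedForwardConfig(
    type=cfg.FeedForwardType.GATED,
    activation=cfg.ActivationType.GELU_QUICK,  # new enum value in this release
    intermediate_size=256,
)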

ai_edge_torch/generative/layers/scaled_dot_product_attention.py
@@ -0,0 +1,117 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+# Implements scaled dot product attention.
+
+import math
+from typing import Optional
+
+import torch
+import torch.nn.functional as F
+
+from ai_edge_torch.hlfb import StableHLOCompositeBuilder
+
+
+def scaled_dot_product_attention(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    head_size: int,
+    mask: Optional[torch.Tensor] = None,
+    scale: Optional[float] = None,
+):
+  """Scaled dot product attention.
+
+  Args:
+    q (torch.Tensor): Query tensor, with shape [B, T, N, H].
+    k (torch.Tensor): Key tensor, with shape [B, T, KV_LEN, H].
+    v (torch.Tensor): Value tensor, with shape [B, T, KV_LEN, H].
+    head_size (int): head dimension.
+    mask (torch.Tensor): the optional mask tensor.
+
+  Returns:
+    The output tensor of scaled_dot_product_attention.
+  """
+
+  if scale is None:
+    scale = 1.0 / math.sqrt(head_size)
+
+  q = q.transpose(1, 2)
+  k = k.transpose(1, 2)
+  v = v.transpose(1, 2)
+  if q.size() != k.size():
+    # Handle the GQA case, where q.shape[1] % k.shape[1] == 0.
+    k = k.repeat_interleave(q.shape[1] // k.shape[1], dim=1)
+    v = v.repeat_interleave(q.shape[1] // v.shape[1], dim=1)
+  y = F.scaled_dot_product_attention(
+      q,
+      k,
+      v,
+      attn_mask=mask,
+      dropout_p=0.0,
+      is_causal=mask is None,
+      scale=scale,
+  )
+  return y.transpose(1, 2)
+
+
+def scaled_dot_product_attention_with_hlfb(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    head_size: int,
+    mask: Optional[torch.Tensor] = None,
+    scale: Optional[float] = None,
+):
+  """Scaled dot product attention with high-level function boundary enabled.
+
+  Args:
+    q (torch.Tensor): Query tensor, with shape [B, T, N, H].
+    k (torch.Tensor): Key tensor, with shape [B, T, KV_LEN, H].
+    v (torch.Tensor): Value tensor, with shape [B, T, KV_LEN, H].
+    head_size (int): head dimension.
+    mask (torch.Tensor): the optional mask tensor.
+
+  Returns:
+    The output tensor of scaled_dot_product_attention.
+  """
+
+  if scale is None:
+    scale = 1.0 / math.sqrt(head_size)
+
+  builder = StableHLOCompositeBuilder(
+      name="odml.scaled_dot_product_attention", attr={"scale": scale}
+  )
+  q, k, v, mask = builder.mark_inputs(q, k, v, mask)
+
+  q = q.transpose(1, 2)
+  k = k.transpose(1, 2)
+  v = v.transpose(1, 2)
+  if q.size() != k.size():
+    # Handle the GQA case, where q.shape[1] % k.shape[1] == 0.
+    k = k.repeat_interleave(q.shape[1] // k.shape[1], dim=1)
+    v = v.repeat_interleave(q.shape[1] // v.shape[1], dim=1)
+  y = F.scaled_dot_product_attention(
+      q,
+      k,
+      v,
+      attn_mask=mask,
+      dropout_p=0.0,
+      is_causal=mask is None,
+      scale=scale,
+  )
+
+  result = y.transpose(1, 2)
+  result = builder.mark_outputs(result)
+  return result
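
The relocated helper keeps the [B, T, N, H] calling convention and expands grouped key/value heads internally via repeat_interleave. A usage sketch with assumed tensor shapes, not package code:

# Calling the relocated helper with grouped-query shaped inputs.
import torch

from ai_edge_torch.generative.layers.scaled_dot_product_attention import (
    scaled_dot_product_attention,
)

B, T, head_size = 1, 6, 16
num_q_heads, num_kv_heads = 8, 2           # GQA: 4 query heads share each KV head

q = torch.randn(B, T, num_q_heads, head_size)
k = torch.randn(B, T, num_kv_heads, head_size)
v = torch.randn(B, T, num_kv_heads, head_size)

# With mask=None the helper falls back to causal masking inside
# F.scaled_dot_product_attention (is_causal=mask is None).
out = scaled_dot_product_attention(q, k, v, head_size)
print(out.shape)  # torch.Size([1, 6, 8, 16]), i.e. [B, T, N, H]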