ai-edge-torch-nightly 0.5.0.dev20250423__py3-none-any.whl → 0.5.0.dev20250425__py3-none-any.whl
This diff compares two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the packages' contents as they appear in the public registry.
- ai_edge_torch/_convert/conversion.py +1 -3
- ai_edge_torch/_convert/fx_passes/__init__.py +0 -1
- ai_edge_torch/_convert/fx_passes/build_aten_composite_pass.py +63 -2
- ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +2 -1
- ai_edge_torch/generative/examples/deepseek/convert_to_tflite.py +38 -4
- ai_edge_torch/generative/examples/deepseek/deepseek.py +1 -0
- ai_edge_torch/generative/examples/gemma3/decoder.py +1 -1
- ai_edge_torch/generative/examples/qwen/convert_to_tflite.py +37 -2
- ai_edge_torch/generative/examples/qwen/qwen.py +1 -0
- ai_edge_torch/generative/layers/attention.py +4 -18
- ai_edge_torch/generative/layers/scaled_dot_product_attention.py +51 -0
- ai_edge_torch/generative/layers/sdpa_with_kv_update.py +38 -44
- ai_edge_torch/generative/test/test_model_conversion.py +38 -33
- ai_edge_torch/generative/test/test_model_conversion_large.py +3 -75
- ai_edge_torch/generative/utilities/converter.py +5 -0
- ai_edge_torch/odml_torch/lowerings/_decomp_registry.py +2 -0
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.5.0.dev20250423.dist-info → ai_edge_torch_nightly-0.5.0.dev20250425.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.5.0.dev20250423.dist-info → ai_edge_torch_nightly-0.5.0.dev20250425.dist-info}/RECORD +22 -25
- ai_edge_torch/_convert/fx_passes/build_interpolate_composite_pass.py +0 -129
- ai_edge_torch/generative/layers/experimental/attention.py +0 -231
- ai_edge_torch/generative/layers/experimental/scaled_dot_product_attention.py +0 -93
- {ai_edge_torch_nightly-0.5.0.dev20250423.dist-info → ai_edge_torch_nightly-0.5.0.dev20250425.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.5.0.dev20250423.dist-info → ai_edge_torch_nightly-0.5.0.dev20250425.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.5.0.dev20250423.dist-info → ai_edge_torch_nightly-0.5.0.dev20250425.dist-info}/top_level.txt +0 -0
ai_edge_torch/_convert/conversion.py
@@ -35,13 +35,11 @@ def _run_convert_passes(
   )
 
   passes = [
-      fx_passes.CastInputsBf16ToF32Pass(),
-      fx_passes.BuildInterpolateCompositePass(),
-      fx_passes.CanonicalizePass(),
       fx_passes.OptimizeLayoutTransposesPass(),
       fx_passes.CanonicalizePass(),
       fx_passes.BuildAtenCompositePass(),
       fx_passes.RemoveNonUserOutputsPass(),
+      fx_passes.CastInputsBf16ToF32Pass(),
   ]
 
   # Debuginfo is not injected automatically by odml_torch. Only inject
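For orientation, the pass list after this hunk is:

    passes = [
        fx_passes.OptimizeLayoutTransposesPass(),
        fx_passes.CanonicalizePass(),
        fx_passes.BuildAtenCompositePass(),
        fx_passes.RemoveNonUserOutputsPass(),
        fx_passes.CastInputsBf16ToF32Pass(),
    ]

BuildInterpolateCompositePass is gone entirely (the upsample composites it built are now registered in BuildAtenCompositePass, see below), and CastInputsBf16ToF32Pass moves from first to last.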
ai_edge_torch/_convert/fx_passes/__init__.py
@@ -16,7 +16,6 @@
 from typing import Sequence, Union
 
 from ai_edge_torch._convert.fx_passes.build_aten_composite_pass import BuildAtenCompositePass
-from ai_edge_torch._convert.fx_passes.build_interpolate_composite_pass import BuildInterpolateCompositePass
 from ai_edge_torch._convert.fx_passes.cast_inputs_bf16_to_f32_pass import CastInputsBf16ToF32Pass
 from ai_edge_torch._convert.fx_passes.inject_mlir_debuginfo_pass import InjectMlirDebuginfoPass
 from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import OptimizeLayoutTransposesPass
ai_edge_torch/_convert/fx_passes/build_aten_composite_pass.py
@@ -20,7 +20,8 @@ import torch
 import torch.utils._pytree as pytree
 
 _composite_builders: dict[
-    Callable, Callable[[torch.fx.GraphModule, torch.fx.Node], None]
+    Callable[[Any, ...], Any],
+    Callable[[torch.fx.GraphModule, torch.fx.Node], None],
 ] = {}
 
 
@@ -272,13 +273,73 @@ def _aten_embedding(gm: torch.fx.GraphModule, node: torch.fx.Node):
     output = op(**full_kwargs)
     output = builder.mark_outputs(output)
 
-    # Explicitly reshape back to the original shape. This places the ReshapeOp
+    # Explicitly reshape back to the original shape. This places the ReshapeOp
+    # outside of the HLFB.
     output = torch.reshape(output, (*(original_idx_shape), embedding_dim))
     return output
 
   node.target = embedding
 
 
+@_register_composite_builder(torch.ops.aten.upsample_bilinear2d.vec)
+def _aten_upsample_bilinear2d_vec(_, node: torch.fx.Node):
+  """Build a composite for aten.upsample_bilinear2d.vec."""
+  op = node.target
+  args_mapper = TorchOpArgumentsMapper(op)
+  # Assumes later FX passes do not change the args/kwargs of the op,
+  # which is a valid assumption given that the composite/mark_tensor
+  # wrapper should semantically prevent any future mutations on the op.
+  output_h, output_w = node.meta["val"].shape[-2:]
+
+  def upsample_bilinear2d_vec(*args, **kwargs):
+    nonlocal op, args_mapper
+    full_kwargs = args_mapper.get_full_kwargs(args, kwargs)
+
+    builder = lowertools.StableHLOCompositeBuilder(
+        name="odml.upsample_bilinear2d",
+        attr={
+            "size": (int(output_h), int(output_w)),
+            "align_corners": full_kwargs["align_corners"],
+            "is_nchw_op": True,
+        },
+    )
+    full_kwargs["input"] = builder.mark_inputs(full_kwargs["input"])
+    output = op(**full_kwargs)
+    output = builder.mark_outputs(output)
+    return output
+
+  node.target = upsample_bilinear2d_vec
+
+
+@_register_composite_builder(torch.ops.aten.upsample_nearest2d.vec)
+def _aten_upsample_nearest2d_vec(_, node: torch.fx.Node):
+  """Build a composite for aten.upsample_nearest2d.vec."""
+  op = node.target
+  args_mapper = TorchOpArgumentsMapper(op)
+  # Assumes later FX passes do not change the args/kwargs of the op,
+  # which is a valid assumption given that the composite/mark_tensor
+  # wrapper should semantically prevent any future mutations on the op.
+  output_h, output_w = node.meta["val"].shape[-2:]
+
+  def upsample_nearest2d_vec(*args, **kwargs):
+    nonlocal op, args_mapper
+    full_kwargs = args_mapper.get_full_kwargs(args, kwargs)
+
+    builder = lowertools.StableHLOCompositeBuilder(
+        name="tfl.resize_nearest_neighbor",
+        attr={
+            "size": (int(output_h), int(output_w)),
+            "is_nchw_op": True,
+        },
+    )
+    full_kwargs["input"] = builder.mark_inputs(full_kwargs["input"])
+    output = op(**full_kwargs)
+    output = builder.mark_outputs(output)
+    return output
+
+  node.target = upsample_nearest2d_vec
+
+
 class BuildAtenCompositePass(fx_infra.PassBase):
 
   def call(self, graph_module: torch.fx.GraphModule):
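Both new builders follow the same high-level-function-boundary (HLFB) pattern: swap the node's target for a wrapper that marks the op's inputs and outputs on a StableHLOCompositeBuilder, so the op gets outlined as a named StableHLO composite. A minimal sketch of that pattern, distilled from the two builders above (`make_composite_wrapper` and its arguments are hypothetical; the builder API is as used in this diff):

    from ai_edge_torch import lowertools

    def make_composite_wrapper(op, name, attr):
      # Returns a callable that runs `op` between composite markers, so the
      # lowering outlines it as a StableHLO composite with `name` and `attr`.
      def wrapped(x, **kwargs):
        builder = lowertools.StableHLOCompositeBuilder(name=name, attr=attr)
        x = builder.mark_inputs(x)        # composite boundary: inputs
        out = op(x, **kwargs)             # the original aten op
        return builder.mark_outputs(out)  # composite boundary: outputs
      return wrapped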
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py
@@ -17,6 +17,7 @@
 import operator
 
 import ai_edge_torch
+from ai_edge_torch import lowertools
 from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import layout_mark
 from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import op_func_registry
 from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import utils
@@ -24,7 +25,7 @@ import torch
 import torch.utils._pytree as pytree
 
 aten = torch.ops.aten
-StableHLOCompositeBuilder = ai_edge_torch.lowertools.StableHLOCompositeBuilder
+StableHLOCompositeBuilder = lowertools.StableHLOCompositeBuilder
 
 __all__ = ["rewrite_nhwc_node", "has_nhwc_rewriter"]
 
ai_edge_torch/generative/examples/deepseek/convert_to_tflite.py
@@ -17,11 +17,43 @@
 
 from absl import app
 from ai_edge_torch.generative.examples.deepseek import deepseek
+from ai_edge_torch.generative.layers import kv_cache
 from ai_edge_torch.generative.utilities import converter
-from ai_edge_torch.generative.utilities import export_config
+from ai_edge_torch.generative.utilities.model_builder import export_cfg
+import torch
+
+flags = converter.define_conversion_flags('deepseek')
+ExportConfig = export_cfg.ExportConfig
+
+
+def _create_mask(mask_len, kv_cache_max_len):
+  mask = torch.full(
+      (mask_len, kv_cache_max_len), float('-inf'), dtype=torch.float32
+  )
+  mask = torch.triu(mask, diagonal=1).unsqueeze(0).unsqueeze(0)
+  return mask
+
+
+def _create_export_config(
+    prefill_seq_lens: list[int], kv_cache_max_len: int
+) -> ExportConfig:
+  """Creates the export config for the model."""
+  export_config = ExportConfig()
+  if isinstance(prefill_seq_lens, list):
+    prefill_mask = [_create_mask(i, kv_cache_max_len) for i in prefill_seq_lens]
+  else:
+    prefill_mask = _create_mask(prefill_seq_lens, kv_cache_max_len)
+
+  export_config.prefill_mask = prefill_mask
+
+  decode_mask = torch.full(
+      (1, kv_cache_max_len), float('-inf'), dtype=torch.float32
+  )
+  decode_mask = torch.triu(decode_mask, diagonal=1).unsqueeze(0).unsqueeze(0)
+  export_config.decode_mask = decode_mask
+  export_config.kvcache_layout = kv_cache.KV_LAYOUT_TRANSPOSED
+  return export_config
 
-flags = converter.define_conversion_flags("deepseek")
-ExportConfig = export_config.ExportConfig
 
 def main(_):
   pytorch_model = deepseek.build_model(
@@ -34,7 +66,9 @@ def main(_):
       prefill_seq_len=flags.FLAGS.prefill_seq_lens,
       quantize=flags.FLAGS.quantize,
       lora_ranks=flags.FLAGS.lora_ranks,
-      export_config=ExportConfig(),
+      export_config=_create_export_config(
+          flags.FLAGS.prefill_seq_lens, flags.FLAGS.kv_cache_max_len
+      ),
   )
 
 
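For reference, `_create_mask` builds a causal mask of shape [1, 1, mask_len, kv_cache_max_len]: zero on and below the diagonal, -inf strictly above it. A small worked example:

    import torch

    mask = torch.full((3, 5), float('-inf'), dtype=torch.float32)
    mask = torch.triu(mask, diagonal=1).unsqueeze(0).unsqueeze(0)
    # mask.shape == torch.Size([1, 1, 3, 5])
    # mask[0, 0]:
    # tensor([[0., -inf, -inf, -inf, -inf],
    #         [0., 0., -inf, -inf, -inf],
    #         [0., 0., 0., -inf, -inf]])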
ai_edge_torch/generative/examples/deepseek/deepseek.py
@@ -53,6 +53,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
   norm_config = cfg.NormalizationConfig(
       type=cfg.NormalizationType.RMS_NORM,
       epsilon=1e-06,
+      enable_hlfb=True,
   )
   block_config = cfg.TransformerBlockConfig(
       attn_config=attn_config,
ai_edge_torch/generative/examples/gemma3/decoder.py
@@ -17,10 +17,10 @@
 
 from typing import List, Optional, Tuple
 
+from ai_edge_torch.generative.layers import attention
 from ai_edge_torch.generative.layers import builder
 from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
-from ai_edge_torch.generative.layers.experimental import attention
 import ai_edge_torch.generative.layers.model_config as cfg
 import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb
 from ai_edge_torch.generative.utilities import export_config as export_cfg
ai_edge_torch/generative/examples/qwen/convert_to_tflite.py
@@ -17,13 +17,14 @@
 
 from absl import app
 from ai_edge_torch.generative.examples.qwen import qwen
+from ai_edge_torch.generative.layers import kv_cache
 from ai_edge_torch.generative.utilities import converter
 from ai_edge_torch.generative.utilities import export_config
+import torch
 
 flags = converter.define_conversion_flags('qwen')
 ExportConfig = export_config.ExportConfig
 
-
 _MODEL_SIZE = flags.DEFINE_enum(
     'model_size',
     '3b',
@@ -37,6 +38,36 @@ _BUILDER = {
     '3b': qwen.build_3b_model,
 }
 
+
+def _create_mask(mask_len, kv_cache_max_len):
+  mask = torch.full(
+      (mask_len, kv_cache_max_len), float('-inf'), dtype=torch.float32
+  )
+  mask = torch.triu(mask, diagonal=1).unsqueeze(0).unsqueeze(0)
+  return mask
+
+
+def _create_export_config(
+    prefill_seq_lens: list[int], kv_cache_max_len: int
+) -> ExportConfig:
+  """Creates the export config for the model."""
+  export_config = ExportConfig()
+  if isinstance(prefill_seq_lens, list):
+    prefill_mask = [_create_mask(i, kv_cache_max_len) for i in prefill_seq_lens]
+  else:
+    prefill_mask = _create_mask(prefill_seq_lens, kv_cache_max_len)
+
+  export_config.prefill_mask = prefill_mask
+
+  decode_mask = torch.full(
+      (1, kv_cache_max_len), float('-inf'), dtype=torch.float32
+  )
+  decode_mask = torch.triu(decode_mask, diagonal=1).unsqueeze(0).unsqueeze(0)
+  export_config.decode_mask = decode_mask
+  export_config.kvcache_layout = kv_cache.KV_LAYOUT_TRANSPOSED
+  return export_config
+
+
 def main(_):
   pytorch_model = _BUILDER[_MODEL_SIZE.value](
       flags.FLAGS.checkpoint_path, kv_cache_max_len=flags.FLAGS.kv_cache_max_len
@@ -48,7 +79,11 @@ def main(_):
       prefill_seq_len=flags.FLAGS.prefill_seq_lens,
      quantize=flags.FLAGS.quantize,
      lora_ranks=flags.FLAGS.lora_ranks,
-      export_config=ExportConfig(),
+      export_config=_create_export_config(
+          flags.FLAGS.prefill_seq_lens, flags.FLAGS.kv_cache_max_len
+      )
+      if flags.FLAGS.transpose_kv_cache
+      else ExportConfig(),
   )
 
 
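Unlike the deepseek converter, which now always returns a transposed-layout export config, the qwen converter gates it behind the `transpose_kv_cache` flag (presumably one of the +5 lines added to converter.py, see the file list above). A hypothetical invocation with illustrative values, using only flag names that appear in this diff:

    python convert_to_tflite.py \
        --checkpoint_path=/path/to/qwen-checkpoint \
        --prefill_seq_lens=1024 \
        --kv_cache_max_len=1280 \
        --transpose_kv_cache=true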
ai_edge_torch/generative/examples/qwen/qwen.py
@@ -53,6 +53,7 @@ def get_3b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
   norm_config = cfg.NormalizationConfig(
       type=cfg.NormalizationType.RMS_NORM,
       epsilon=1e-06,
+      enable_hlfb=True,
   )
   block_config = cfg.TransformerBlockConfig(
       attn_config=attn_config,
ai_edge_torch/generative/layers/attention.py
@@ -21,6 +21,7 @@ from ai_edge_torch.generative.layers import builder
 from ai_edge_torch.generative.layers import kv_cache as kv_utils
 from ai_edge_torch.generative.layers import lora as lora_utils
 from ai_edge_torch.generative.layers import scaled_dot_product_attention as sdpa
+from ai_edge_torch.generative.layers import sdpa_with_kv_update
 import ai_edge_torch.generative.layers.model_config as cfg
 import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb
 import torch
@@ -142,11 +143,6 @@ class CausalSelfAttention(nn.Module):
     self.key_norm = builder.build_norm(config.head_dim, config.key_norm_config)
     self.config = config
     self.enable_hlfb = enable_hlfb
-    self.sdpa_func = (
-        sdpa.scaled_dot_product_attention_with_hlfb
-        if enable_hlfb
-        else sdpa.scaled_dot_product_attention
-    )
 
   def forward(
       self,
@@ -174,7 +170,7 @@ class CausalSelfAttention(nn.Module):
       KV Cache Entry (if passed in).
     """
     # Batch size, sequence length, embedding dimensionality.
-    B, T, E = x.size()
+    B, T, _ = x.size()
     qkv = self.qkv_projection(x)
 
     # Assemble into a number of query groups to support MHA, MQA and GQA.
@@ -218,19 +214,9 @@ class CausalSelfAttention(nn.Module):
     cos, sin = rope
     q, k = rotary_pos_emb.apply_rope_inline(q, k, cos, sin)
 
-    if kv_cache is not None:
-      kv_cache = kv_utils.update(kv_cache, input_pos, k, v)
-      k, v = kv_cache.k_cache, kv_cache.v_cache
-
-    sdpa_out = self.sdpa_func(
-        q,
-        k,
-        v,
-        self.config.head_dim,
-        mask=mask,
-        softcap=self.config.logit_softcap,
+    sdpa_out, kv_cache = sdpa_with_kv_update.sdpa_with_kv_update(
+        q, k, v, kv_cache, input_pos, mask, self.config, self.enable_hlfb
     )
-    sdpa_out = sdpa_out.reshape(B, T, -1)
 
     # Compute the output projection.
     y = self.output_projection(sdpa_out)
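Net effect in `CausalSelfAttention.forward`: cache update, SDPA-variant selection, and (judging by the dropped `reshape(B, T, -1)`) output flattening all move behind a single call:

    # One call now updates the KV cache, picks the SDPA implementation
    # (HLFB or plain, default or transposed KV layout), and returns the
    # attention output together with the updated cache.
    sdpa_out, kv_cache = sdpa_with_kv_update.sdpa_with_kv_update(
        q, k, v, kv_cache, input_pos, mask, self.config, self.enable_hlfb
    )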
ai_edge_torch/generative/layers/scaled_dot_product_attention.py
@@ -17,6 +17,8 @@
 import math
 from typing import Optional
 
+from ai_edge_torch.generative.custom_ops import bmm_4d as bmm_lib
+from ai_edge_torch.generative.layers import kv_cache as kv_utils
 from ai_edge_torch.hlfb import StableHLOCompositeBuilder
 import torch
 import torch.nn.functional as F
@@ -142,3 +144,52 @@ def scaled_dot_product_attention_with_hlfb(
   result = y.transpose(1, 2)
   result = builder.mark_outputs(result)
   return result
+
+
+def scaled_dot_product_attention_transposed(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    head_size: int,
+    mask: Optional[torch.Tensor] = None,
+    scale: Optional[float] = None,
+    softcap: Optional[float] = None,
+):
+  """Scaled dot product attention with transposed key and value.
+
+  Args:
+    query: Query tensor, with shape [B, T, N, H].
+    key: Key tensor, with shape [B, T, KV_LEN, H].
+    value: Value tensor, with shape [B, T, KV_LEN, H].
+    head_size (int): head dimension.
+    mask (torch.Tensor): the optional mask tensor.
+    scale (float): the optional scale factor.
+    softcap (float): the optional softcap for the logits.
+
+  Returns:
+    The output tensor of scaled_dot_product_attention_transposed.
+  """
+
+  if scale is None:
+    scale = 1.0 / math.sqrt(head_size)
+
+  query = query * scale
+
+  assert mask is not None, "Mask should not be None!"
+  t = mask.shape[2]
+
+  logits = bmm_lib.bmm_4d(query, key)
+
+  _, bk, gt, s = logits.shape
+  g = gt // t
+  logits = logits.reshape((bk, g, t, s))
+  if softcap is not None:
+    logits = torch.tanh(logits / softcap)
+    logits = logits * softcap
+
+  padded_logits = logits + mask
+  padded_logits = padded_logits.reshape(1, bk, gt, s)
+  probs = F.softmax(padded_logits, dim=-1).type_as(key)
+  encoded = bmm_lib.bmm_4d(probs, value)
+
+  return encoded  # 1, bk, gt, h
ai_edge_torch/generative/layers/sdpa_with_kv_update.py
@@ -12,16 +12,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-
-
+
+"""Common utility functions for scaled dot product attention with KV cache update."""
+
 from typing import Tuple
+
 from ai_edge_torch.generative.layers import kv_cache as kv_utils
-from ai_edge_torch.generative.layers import scaled_dot_product_attention as sdpa_default
+from ai_edge_torch.generative.layers import scaled_dot_product_attention as sdpa
 from ai_edge_torch.generative.layers.experimental import kv_cache as kv_utils_experimental
-from ai_edge_torch.generative.layers.experimental import scaled_dot_product_attention as sdpa
 import ai_edge_torch.generative.layers.model_config as cfg
-from ai_edge_torch.generative.utilities import types
-from multipledispatch import dispatch
 import torch
 
 
@@ -33,32 +32,27 @@ def sdpa_with_kv_update(
     input_pos: torch.Tensor,
     mask: torch.Tensor,
     config: cfg.AttentionConfig,
+    enable_hlfb: bool,
 ) -> Tuple[torch.Tensor, kv_utils.KVCacheEntry]:
-  """Wrapper function for scaled dot product attention with KV cache update."""
-  return sdpa_with_kv_update_impl(
-      kv.kv_layout,
-      query=query,
-      key=key,
-      value=value,
-      kv=kv,
-      input_pos=input_pos,
-      mask=mask,
-      config=config,
+  """Wrapper function for scaled dot product attention with KV cache update."""
+  if kv is not None and kv.kv_layout == kv_utils.KV_LAYOUT_TRANSPOSED:
+    return _sdpa_with_kv_update_transposed(
+        query, key, value, kv, input_pos, mask, config
+    )
+  return _sdpa_with_kv_update_default(
+      query, key, value, kv, input_pos, mask, config, enable_hlfb
   )
 
 
-
-
-
+def _sdpa_with_kv_update_transposed(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    kv: kv_utils.KVCacheEntry,
+    input_pos: torch.Tensor,
+    mask: torch.Tensor,
+    config: cfg.AttentionConfig,
 ) -> Tuple[torch.Tensor, kv_utils.KVCacheEntry]:
-  query = kwargs["query"]
-  key = kwargs["key"]
-  value = kwargs["value"]
-  kv = kwargs["kv"]
-  input_pos = kwargs["input_pos"]
-  mask = kwargs["mask"]
-  config = kwargs["config"]
-
   # Transpose k/v to specific layout for GPU implementation.
   b, seq_len, n, h = query.shape
   g = n // config.num_query_groups
@@ -74,12 +68,10 @@ def sdpa_with_kv_update_impl(
       1, -1, config.head_dim, seq_len
   )  # 1, bk, h, s
 
-  if kv is not None:
-    kv = kv_utils_experimental.update(kv, input_pos, key, value)
-    key, value = kv.k_cache, kv.v_cache
+  kv = kv_utils_experimental.update(kv, input_pos, key, value)
+  key, value = kv.k_cache, kv.v_cache
 
-  sdpa_out = sdpa.scaled_dot_product_attention_transposed(
-      kv,
+  sdpa_out = sdpa.scaled_dot_product_attention_transposed(
       query,
       key,
       value,
@@ -95,24 +87,26 @@ def sdpa_with_kv_update_impl(
   return sdpa_out, kv
 
 
-
-
-
+def _sdpa_with_kv_update_default(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    kv: kv_utils.KVCacheEntry,
+    input_pos: torch.Tensor,
+    mask: torch.Tensor,
+    config: cfg.AttentionConfig,
+    enable_hlfb: bool,
 ) -> Tuple[torch.Tensor, kv_utils.KVCacheEntry]:
-  query = kwargs["query"]
-  key = kwargs["key"]
-  value = kwargs["value"]
-  kv = kwargs["kv"]
-  input_pos = kwargs["input_pos"]
-  mask = kwargs["mask"]
-  config = kwargs["config"]
-
   b, seq_len, _, _ = query.shape
   if kv is not None:
     kv = kv_utils.update(kv, input_pos, key, value)
     key, value = kv.k_cache, kv.v_cache
 
-  sdpa_out = sdpa_default.scaled_dot_product_attention(
+  if enable_hlfb:
+    sdpa_func = sdpa.scaled_dot_product_attention_with_hlfb
+  else:
+    sdpa_func = sdpa.scaled_dot_product_attention
+  sdpa_out = sdpa_func(
       query,
       key,
       value,
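Callers opt into the transposed path purely through the cache layout, as the updated tests below exercise; e.g. (layout constant and `from_model_config` keyword as used in this diff):

    from ai_edge_torch.generative.layers import kv_cache

    # Build the cache in the transposed layout so sdpa_with_kv_update takes
    # the _sdpa_with_kv_update_transposed branch.
    kv = kv_cache.KVCache.from_model_config(
        config, kv_layout=kv_cache.KV_LAYOUT_TRANSPOSED
    )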
ai_edge_torch/generative/test/test_model_conversion.py
@@ -32,16 +32,14 @@ class TestModelConversion(googletest.TestCase):
 
   def setUp(self):
     super().setUp()
-    # Builder function for an Interpreter that supports custom ops.
     self._interpreter_builder = (
-        lambda tflite_model: lambda: interpreter.InterpreterWithCustomOps(
-            custom_op_registerers=["GenAIOpsRegisterer"],
+        lambda tflite_model: lambda: interpreter.Interpreter(
            model_content=tflite_model,
            experimental_default_delegate_latest_features=True,
        )
    )
 
-  def _get_params(self, enable_hlfb: bool):
+  def _get_params(self, enable_hlfb: bool, kv_layout: kv_cache.KVLayout):
     """Returns a model, edge model and the kwargs to use for testing."""
     config = toy_model_with_kv_cache.get_model_config()
     config.enable_hlfb = enable_hlfb
@@ -49,7 +47,7 @@ class TestModelConversion(googletest.TestCase):
     tokens, input_pos = torch.tensor([[1]], dtype=torch.int), torch.tensor(
         [10], dtype=torch.int
     )
-    kv = kv_cache.KVCache.from_model_config(config)
+    kv = kv_cache.KVCache.from_model_config(config, kv_layout=kv_layout)
     kwargs = {
         "tokens": tokens,
         "input_pos": input_pos,
@@ -65,8 +63,12 @@ class TestModelConversion(googletest.TestCase):
     )
     return pytorch_model, edge_model, kwargs
 
-  def _test_model_with_kv_cache(self, enable_hlfb: bool):
-    pytorch_model, edge_model, kwargs = self._get_params(enable_hlfb)
+  def _test_model_with_kv_cache(
+      self,
+      enable_hlfb: bool = False,
+      kv_layout: kv_cache.KVLayout = kv_cache.KV_LAYOUT_DEFAULT,
+  ):
+    pytorch_model, edge_model, kwargs = self._get_params(enable_hlfb, kv_layout)
 
     self.assertTrue(
         test_utils.compare_tflite_torch(
@@ -81,38 +83,34 @@ class TestModelConversion(googletest.TestCase):
         )
     )
 
-  @googletest.skipIf(
-      ai_edge_torch.config.in_oss,
-      reason="tests with custom ops are not supported in oss",
-  )
   def test_toy_model_with_kv_cache(self):
     self._test_model_with_kv_cache(enable_hlfb=False)
 
-  @googletest.skipIf(
-      ai_edge_torch.config.in_oss,
-      reason="tests with custom ops are not supported in oss",
-  )
   def test_toy_model_with_kv_cache_with_hlfb(self):
     self._test_model_with_kv_cache(enable_hlfb=True)
 
-  @googletest.skipIf(
-      ai_edge_torch.config.in_oss,
-      reason="tests with custom ops are not supported in oss",
-  )
+  def test_toy_model_with_kv_cache_transposed(self):
+    self._test_model_with_kv_cache(kv_layout=kv_cache.KV_LAYOUT_TRANSPOSED)
+
   def test_toy_model_has_dus_op(self):
     """Tests that the model has the dynamic update slice op."""
-    _, edge_model, _ = self._get_params(enable_hlfb=True)
-    interpreter_ = interpreter.InterpreterWithCustomOps(
-        custom_op_registerers=["GenAIOpsRegisterer"],
-        model_content=edge_model.tflite_model(),
-        experimental_default_delegate_latest_features=True,
+    _, edge_model, _ = self._get_params(
+        enable_hlfb=True, kv_layout=kv_cache.KV_LAYOUT_DEFAULT
     )
+    interpreter = self._interpreter_builder(edge_model.tflite_model())()
 
     # pylint: disable=protected-access
-    op_names = [op["op_name"] for op in interpreter_._get_ops_details()]
+    op_names = [op["op_name"] for op in interpreter._get_ops_details()]
     self.assertIn("DYNAMIC_UPDATE_SLICE", op_names)
 
-  def _test_multisig_model(self, config, pytorch_model, atol, rtol):
+  def _test_multisig_model(
+      self,
+      config,
+      pytorch_model,
+      atol,
+      rtol,
+      kv_layout=kv_cache.KV_LAYOUT_DEFAULT,
+  ):
     # prefill
     seq_len = 10
     prefill_tokens = torch.zeros((1, seq_len), dtype=torch.int, device="cpu")
@@ -124,7 +122,7 @@ class TestModelConversion(googletest.TestCase):
     decode_token = torch.tensor([[1]], dtype=torch.int)
     decode_input_pos = torch.tensor([5], dtype=torch.int)
 
-    kv = kv_cache.KVCache.from_model_config(config)
+    kv = kv_cache.KVCache.from_model_config(config, kv_layout=kv_layout)
 
     edge_model = (
         ai_edge_torch.signature(
@@ -160,7 +158,7 @@ class TestModelConversion(googletest.TestCase):
             kv,
             signature_name="prefill",
             atol=atol,
-            rtol=atol,
+            rtol=rtol,
         )
     )
 
@@ -173,19 +171,26 @@ class TestModelConversion(googletest.TestCase):
             kv,
             signature_name="decode",
             atol=atol,
-            rtol=atol,
+            rtol=rtol,
         )
     )
 
-  @googletest.skipIf(
-      ai_edge_torch.config.in_oss,
-      reason="tests with custom ops are not supported in oss",
-  )
   def test_tiny_llama_multisig(self):
     config = tiny_llama.get_fake_model_config()
     pytorch_model = tiny_llama.TinyLlama(config).eval()
     self._test_multisig_model(config, pytorch_model, atol=1e-5, rtol=1e-5)
 
+  def test_tiny_llama_multisig_kv_layout_transposed(self):
+    config = tiny_llama.get_fake_model_config()
+    pytorch_model = tiny_llama.TinyLlama(config).eval()
+    self._test_multisig_model(
+        config,
+        pytorch_model,
+        atol=1e-5,
+        rtol=1e-5,
+        kv_layout=kv_cache.KV_LAYOUT_TRANSPOSED,
+    )
+
 
 if __name__ == "__main__":
   googletest.main()