ai-edge-torch-nightly 0.5.0.dev20250424__py3-none-any.whl → 0.5.0.dev20250425__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_torch/_convert/conversion.py +1 -3
- ai_edge_torch/_convert/fx_passes/__init__.py +0 -1
- ai_edge_torch/_convert/fx_passes/build_aten_composite_pass.py +63 -2
- ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +2 -1
- ai_edge_torch/generative/examples/deepseek/convert_to_tflite.py +38 -4
- ai_edge_torch/generative/examples/deepseek/deepseek.py +1 -0
- ai_edge_torch/generative/examples/qwen/convert_to_tflite.py +37 -2
- ai_edge_torch/generative/examples/qwen/qwen.py +1 -0
- ai_edge_torch/generative/layers/scaled_dot_product_attention.py +51 -0
- ai_edge_torch/generative/layers/sdpa_with_kv_update.py +4 -6
- ai_edge_torch/generative/test/test_model_conversion.py +3 -33
- ai_edge_torch/generative/test/test_model_conversion_large.py +3 -75
- ai_edge_torch/generative/utilities/converter.py +5 -0
- ai_edge_torch/odml_torch/lowerings/_decomp_registry.py +2 -0
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.5.0.dev20250424.dist-info → ai_edge_torch_nightly-0.5.0.dev20250425.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.5.0.dev20250424.dist-info → ai_edge_torch_nightly-0.5.0.dev20250425.dist-info}/RECORD +20 -22
- ai_edge_torch/_convert/fx_passes/build_interpolate_composite_pass.py +0 -129
- ai_edge_torch/generative/layers/experimental/scaled_dot_product_attention.py +0 -93
- {ai_edge_torch_nightly-0.5.0.dev20250424.dist-info → ai_edge_torch_nightly-0.5.0.dev20250425.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.5.0.dev20250424.dist-info → ai_edge_torch_nightly-0.5.0.dev20250425.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.5.0.dev20250424.dist-info → ai_edge_torch_nightly-0.5.0.dev20250425.dist-info}/top_level.txt +0 -0
ai_edge_torch/_convert/conversion.py CHANGED

@@ -35,13 +35,11 @@ def _run_convert_passes(
   )
 
   passes = [
-      fx_passes.CastInputsBf16ToF32Pass(),
-      fx_passes.BuildInterpolateCompositePass(),
-      fx_passes.CanonicalizePass(),
       fx_passes.OptimizeLayoutTransposesPass(),
       fx_passes.CanonicalizePass(),
       fx_passes.BuildAtenCompositePass(),
       fx_passes.RemoveNonUserOutputsPass(),
+      fx_passes.CastInputsBf16ToF32Pass(),
   ]
 
   # Debuginfo is not injected automatically by odml_torch. Only inject
ai_edge_torch/_convert/fx_passes/__init__.py CHANGED

@@ -16,7 +16,6 @@
 from typing import Sequence, Union
 
 from ai_edge_torch._convert.fx_passes.build_aten_composite_pass import BuildAtenCompositePass
-from ai_edge_torch._convert.fx_passes.build_interpolate_composite_pass import BuildInterpolateCompositePass
 from ai_edge_torch._convert.fx_passes.cast_inputs_bf16_to_f32_pass import CastInputsBf16ToF32Pass
 from ai_edge_torch._convert.fx_passes.inject_mlir_debuginfo_pass import InjectMlirDebuginfoPass
 from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import OptimizeLayoutTransposesPass
ai_edge_torch/_convert/fx_passes/build_aten_composite_pass.py CHANGED

@@ -20,7 +20,8 @@ import torch
 import torch.utils._pytree as pytree
 
 _composite_builders: dict[
-    Callable
+    Callable[[Any, ...], Any],
+    Callable[[torch.fx.GraphModule, torch.fx.Node], None],
 ] = {}
 
 
@@ -272,13 +273,73 @@ def _aten_embedding(gm: torch.fx.GraphModule, node: torch.fx.Node):
     output = op(**full_kwargs)
     output = builder.mark_outputs(output)
 
-    # Explicitly reshape back to the original shape. This places the ReshapeOp
+    # Explicitly reshape back to the original shape. This places the ReshapeOp
+    # outside of the HLFB.
     output = torch.reshape(output, (*(original_idx_shape), embedding_dim))
     return output
 
   node.target = embedding
 
 
+@_register_composite_builder(torch.ops.aten.upsample_bilinear2d.vec)
+def _aten_upsample_bilinear2d_vec(_, node: torch.fx.Node):
+  """Build a composite for aten.upsample_bilinear2d.vec."""
+  op = node.target
+  args_mapper = TorchOpArgumentsMapper(op)
+  # Assumes later FX passes does not change the args/kwargs of the op.
+  # Which is a valid assumption for, given that composite/mark_tensor wrapper
+  # should semantically prevents any future mutations on the op.
+  output_h, output_w = node.meta["val"].shape[-2:]
+
+  def upsample_bilinear2d_vec(*args, **kwargs):
+    nonlocal op, args_mapper
+    full_kwargs = args_mapper.get_full_kwargs(args, kwargs)
+
+    builder = lowertools.StableHLOCompositeBuilder(
+        name="odml.upsample_bilinear2d",
+        attr={
+            "size": (int(output_h), int(output_w)),
+            "align_corners": full_kwargs["align_corners"],
+            "is_nchw_op": True,
+        },
+    )
+    full_kwargs["input"] = builder.mark_inputs(full_kwargs["input"])
+    output = op(**full_kwargs)
+    output = builder.mark_outputs(output)
+    return output
+
+  node.target = upsample_bilinear2d_vec
+
+
+@_register_composite_builder(torch.ops.aten.upsample_nearest2d.vec)
+def _aten_upsample_nearest2d_vec(_, node: torch.fx.Node):
+  """Build a composite for aten.upsample_nearest2d.vec."""
+  op = node.target
+  args_mapper = TorchOpArgumentsMapper(op)
+  # Assumes later FX passes does not change the args/kwargs of the op.
+  # Which is a valid assumption for, given that composite/mark_tensor wrapper
+  # should semantically prevents any future mutations on the op.
+  output_h, output_w = node.meta["val"].shape[-2:]
+
+  def upsample_nearest2d_vec(*args, **kwargs):
+    nonlocal op, args_mapper
+    full_kwargs = args_mapper.get_full_kwargs(args, kwargs)
+
+    builder = lowertools.StableHLOCompositeBuilder(
+        name="tfl.resize_nearest_neighbor",
+        attr={
+            "size": (int(output_h), int(output_w)),
+            "is_nchw_op": True,
+        },
+    )
+    full_kwargs["input"] = builder.mark_inputs(full_kwargs["input"])
+    output = op(**full_kwargs)
+    output = builder.mark_outputs(output)
+    return output
+
+  node.target = upsample_nearest2d_vec
+
+
 class BuildAtenCompositePass(fx_infra.PassBase):
 
   def call(self, graph_module: torch.fx.GraphModule):
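As a quick illustration of what the new composite builders do, here is a minimal, hypothetical smoke test (not part of the diff) that assumes the public `ai_edge_torch.convert` entry point; module and file names are illustrative. With this release, `aten.upsample_bilinear2d.vec` / `aten.upsample_nearest2d.vec` are wrapped into composites by `BuildAtenCompositePass` instead of the removed pattern-matching `BuildInterpolateCompositePass`.

```python
# Hedged sketch: convert a module that uses F.interpolate and export it.
# Only ai_edge_torch.convert/export are assumed here; everything else is toy code.
import ai_edge_torch
import torch
import torch.nn.functional as F


class Upsample2x(torch.nn.Module):

  def forward(self, x):
    # Exports as aten.upsample_bilinear2d.vec, which the new builder wraps
    # into the "odml.upsample_bilinear2d" composite.
    return F.interpolate(x, scale_factor=2, mode="bilinear", align_corners=False)


edge_model = ai_edge_torch.convert(Upsample2x().eval(), (torch.rand(1, 3, 64, 64),))
edge_model.export("/tmp/upsample2x.tflite")
```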
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py CHANGED

@@ -17,6 +17,7 @@
 import operator
 
 import ai_edge_torch
+from ai_edge_torch import lowertools
 from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import layout_mark
 from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import op_func_registry
 from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import utils
@@ -24,7 +25,7 @@ import torch
 import torch.utils._pytree as pytree
 
 aten = torch.ops.aten
-StableHLOCompositeBuilder =
+StableHLOCompositeBuilder = lowertools.StableHLOCompositeBuilder
 
 __all__ = ["rewrite_nhwc_node", "has_nhwc_rewriter"]
 
ai_edge_torch/generative/examples/deepseek/convert_to_tflite.py CHANGED

@@ -17,11 +17,43 @@
 
 from absl import app
 from ai_edge_torch.generative.examples.deepseek import deepseek
+from ai_edge_torch.generative.layers import kv_cache
 from ai_edge_torch.generative.utilities import converter
-from ai_edge_torch.generative.utilities import
+from ai_edge_torch.generative.utilities.model_builder import export_cfg
+import torch
+
+flags = converter.define_conversion_flags('deepseek')
+ExportConfig = export_cfg.ExportConfig
+
+
+def _create_mask(mask_len, kv_cache_max_len):
+  mask = torch.full(
+      (mask_len, kv_cache_max_len), float('-inf'), dtype=torch.float32
+  )
+  mask = torch.triu(mask, diagonal=1).unsqueeze(0).unsqueeze(0)
+  return mask
+
+
+def _create_export_config(
+    prefill_seq_lens: list[int], kv_cache_max_len: int
+) -> ExportConfig:
+  """Creates the export config for the model."""
+  export_config = ExportConfig()
+  if isinstance(prefill_seq_lens, list):
+    prefill_mask = [_create_mask(i, kv_cache_max_len) for i in prefill_seq_lens]
+  else:
+    prefill_mask = _create_mask(prefill_seq_lens, kv_cache_max_len)
+
+  export_config.prefill_mask = prefill_mask
+
+  decode_mask = torch.full(
+      (1, kv_cache_max_len), float('-inf'), dtype=torch.float32
+  )
+  decode_mask = torch.triu(decode_mask, diagonal=1).unsqueeze(0).unsqueeze(0)
+  export_config.decode_mask = decode_mask
+  export_config.kvcache_layout = kv_cache.KV_LAYOUT_TRANSPOSED
+  return export_config
 
-flags = converter.define_conversion_flags("deepseek")
-ExportConfig = export_config.ExportConfig
 
 def main(_):
   pytorch_model = deepseek.build_model(
@@ -34,7 +66,9 @@ def main(_):
       prefill_seq_len=flags.FLAGS.prefill_seq_lens,
       quantize=flags.FLAGS.quantize,
       lora_ranks=flags.FLAGS.lora_ranks,
-      export_config=
+      export_config=_create_export_config(
+          flags.FLAGS.prefill_seq_lens, flags.FLAGS.kv_cache_max_len
+      ),
   )
 
 
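For reference, the causal masks built by the new `_create_mask` helper look like this for a toy configuration; the sizes below (mask_len=3, kv_cache_max_len=5) are chosen here purely for illustration.

```python
import torch

# Same construction as _create_mask above, with toy sizes.
mask = torch.full((3, 5), float('-inf'), dtype=torch.float32)
mask = torch.triu(mask, diagonal=1).unsqueeze(0).unsqueeze(0)  # shape (1, 1, 3, 5)
print(mask[0, 0])
# tensor([[0., -inf, -inf, -inf, -inf],
#         [0., 0., -inf, -inf, -inf],
#         [0., 0., 0., -inf, -inf]])
```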
ai_edge_torch/generative/examples/deepseek/deepseek.py CHANGED

@@ -53,6 +53,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
   norm_config = cfg.NormalizationConfig(
       type=cfg.NormalizationType.RMS_NORM,
       epsilon=1e-06,
+      enable_hlfb=True,
   )
   block_config = cfg.TransformerBlockConfig(
       attn_config=attn_config,
ai_edge_torch/generative/examples/qwen/convert_to_tflite.py CHANGED

@@ -17,13 +17,14 @@
 
 from absl import app
 from ai_edge_torch.generative.examples.qwen import qwen
+from ai_edge_torch.generative.layers import kv_cache
 from ai_edge_torch.generative.utilities import converter
 from ai_edge_torch.generative.utilities import export_config
+import torch
 
 flags = converter.define_conversion_flags('qwen')
 ExportConfig = export_config.ExportConfig
 
-
 _MODEL_SIZE = flags.DEFINE_enum(
     'model_size',
     '3b',
@@ -37,6 +38,36 @@ _BUILDER = {
     '3b': qwen.build_3b_model,
 }
 
+
+def _create_mask(mask_len, kv_cache_max_len):
+  mask = torch.full(
+      (mask_len, kv_cache_max_len), float('-inf'), dtype=torch.float32
+  )
+  mask = torch.triu(mask, diagonal=1).unsqueeze(0).unsqueeze(0)
+  return mask
+
+
+def _create_export_config(
+    prefill_seq_lens: list[int], kv_cache_max_len: int
+) -> ExportConfig:
+  """Creates the export config for the model."""
+  export_config = ExportConfig()
+  if isinstance(prefill_seq_lens, list):
+    prefill_mask = [_create_mask(i, kv_cache_max_len) for i in prefill_seq_lens]
+  else:
+    prefill_mask = _create_mask(prefill_seq_lens, kv_cache_max_len)
+
+  export_config.prefill_mask = prefill_mask
+
+  decode_mask = torch.full(
+      (1, kv_cache_max_len), float('-inf'), dtype=torch.float32
+  )
+  decode_mask = torch.triu(decode_mask, diagonal=1).unsqueeze(0).unsqueeze(0)
+  export_config.decode_mask = decode_mask
+  export_config.kvcache_layout = kv_cache.KV_LAYOUT_TRANSPOSED
+  return export_config
+
+
 def main(_):
   pytorch_model = _BUILDER[_MODEL_SIZE.value](
       flags.FLAGS.checkpoint_path, kv_cache_max_len=flags.FLAGS.kv_cache_max_len
@@ -48,7 +79,11 @@ def main(_):
       prefill_seq_len=flags.FLAGS.prefill_seq_lens,
       quantize=flags.FLAGS.quantize,
       lora_ranks=flags.FLAGS.lora_ranks,
-      export_config=
+      export_config=_create_export_config(
+          flags.FLAGS.prefill_seq_lens, flags.FLAGS.kv_cache_max_len
+      )
+      if flags.FLAGS.transpose_kv_cache
+      else ExportConfig(),
   )
 
 
ai_edge_torch/generative/examples/qwen/qwen.py CHANGED

@@ -53,6 +53,7 @@ def get_3b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
   norm_config = cfg.NormalizationConfig(
       type=cfg.NormalizationType.RMS_NORM,
       epsilon=1e-06,
+      enable_hlfb=True,
   )
   block_config = cfg.TransformerBlockConfig(
       attn_config=attn_config,
ai_edge_torch/generative/layers/scaled_dot_product_attention.py CHANGED

@@ -17,6 +17,8 @@
 import math
 from typing import Optional
 
+from ai_edge_torch.generative.custom_ops import bmm_4d as bmm_lib
+from ai_edge_torch.generative.layers import kv_cache as kv_utils
 from ai_edge_torch.hlfb import StableHLOCompositeBuilder
 import torch
 import torch.nn.functional as F
@@ -142,3 +144,52 @@ def scaled_dot_product_attention_with_hlfb(
   result = y.transpose(1, 2)
   result = builder.mark_outputs(result)
   return result
+
+
+def scaled_dot_product_attention_transposed(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    head_size: int,
+    mask: Optional[torch.Tensor] = None,
+    scale: Optional[float] = None,
+    softcap: Optional[float] = None,
+):
+  """Scaled dot product attention with transposed key and value.
+
+  Args:
+    query: Query tensor, with shape [B, T, N, H].
+    key: Key tensor, with shape [B, T, KV_LEN, H].
+    value: Value tensor, with shape [B, T, KV_LEN, H].
+    head_size (int): head dimension.
+    mask (torch.Tensor): the optional mask tensor.
+    scale (float): the optional scale factor.
+    softcap (float): the optional softcap for the logits.
+
+  Returns:
+    The output tensor of scaled_dot_product_attention_transposed.
+  """
+
+  if scale is None:
+    scale = 1.0 / math.sqrt(head_size)
+
+  query = query * scale
+
+  assert mask is not None, "Mask should not be None!"
+  t = mask.shape[2]
+
+  logits = bmm_lib.bmm_4d(query, key)
+
+  _, bk, gt, s = logits.shape
+  g = gt // t
+  logits = logits.reshape((bk, g, t, s))
+  if softcap is not None:
+    logits = torch.tanh(logits / softcap)
+    logits = logits * softcap
+
+  padded_logits = logits + mask
+  padded_logits = padded_logits.reshape(1, bk, gt, s)
+  probs = F.softmax(padded_logits, dim=-1).type_as(key)
+  encoded = bmm_lib.bmm_4d(probs, value)
+
+  return encoded  # 1, bk, gt, h
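The soft-capping step used in the function above is easy to check in isolation. A minimal, self-contained sketch (plain torch only; the `bmm_4d` custom op and the KV-cache plumbing are deliberately omitted, and the values are illustrative):

```python
import torch

softcap = 30.0
logits = torch.tensor([-100.0, -5.0, 0.0, 5.0, 100.0])
# tanh(logits / softcap) * softcap squashes arbitrarily large logits into (-softcap, softcap).
capped = torch.tanh(logits / softcap) * softcap
print(capped)  # roughly [-29.92, -4.95, 0.00, 4.95, 29.92]
```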
ai_edge_torch/generative/layers/sdpa_with_kv_update.py CHANGED

@@ -18,9 +18,8 @@
 from typing import Tuple
 
 from ai_edge_torch.generative.layers import kv_cache as kv_utils
-from ai_edge_torch.generative.layers import scaled_dot_product_attention as
+from ai_edge_torch.generative.layers import scaled_dot_product_attention as sdpa
 from ai_edge_torch.generative.layers.experimental import kv_cache as kv_utils_experimental
-from ai_edge_torch.generative.layers.experimental import scaled_dot_product_attention as sdpa
 import ai_edge_torch.generative.layers.model_config as cfg
 import torch
 
@@ -72,8 +71,7 @@ def _sdpa_with_kv_update_transposed(
   kv = kv_utils_experimental.update(kv, input_pos, key, value)
   key, value = kv.k_cache, kv.v_cache
 
-  sdpa_out = sdpa.
-      kv,
+  sdpa_out = sdpa.scaled_dot_product_attention_transposed(
       query,
       key,
       value,
@@ -105,9 +103,9 @@ def _sdpa_with_kv_update_default(
   key, value = kv.k_cache, kv.v_cache
 
   if enable_hlfb:
-    sdpa_func =
+    sdpa_func = sdpa.scaled_dot_product_attention_with_hlfb
   else:
-    sdpa_func =
+    sdpa_func = sdpa.scaled_dot_product_attention
   sdpa_out = sdpa_func(
       query,
       key,
ai_edge_torch/generative/test/test_model_conversion.py CHANGED

@@ -32,10 +32,8 @@ class TestModelConversion(googletest.TestCase):
 
   def setUp(self):
     super().setUp()
-    # Builder function for an Interpreter that supports custom ops.
     self._interpreter_builder = (
-        lambda tflite_model: lambda: interpreter.
-            custom_op_registerers=["GenAIOpsRegisterer"],
+        lambda tflite_model: lambda: interpreter.Interpreter(
             model_content=tflite_model,
             experimental_default_delegate_latest_features=True,
         )
@@ -85,44 +83,24 @@ class TestModelConversion(googletest.TestCase):
         )
     )
 
-  @googletest.skipIf(
-      ai_edge_torch.config.in_oss,
-      reason="tests with custom ops are not supported in oss",
-  )
   def test_toy_model_with_kv_cache(self):
     self._test_model_with_kv_cache(enable_hlfb=False)
 
-  @googletest.skipIf(
-      ai_edge_torch.config.in_oss,
-      reason="tests with custom ops are not supported in oss",
-  )
   def test_toy_model_with_kv_cache_with_hlfb(self):
     self._test_model_with_kv_cache(enable_hlfb=True)
 
-  @googletest.skipIf(
-      ai_edge_torch.config.in_oss,
-      reason="tests with custom ops are not supported in oss",
-  )
   def test_toy_model_with_kv_cache_transposed(self):
     self._test_model_with_kv_cache(kv_layout=kv_cache.KV_LAYOUT_TRANSPOSED)
 
-  @googletest.skipIf(
-      ai_edge_torch.config.in_oss,
-      reason="tests with custom ops are not supported in oss",
-  )
   def test_toy_model_has_dus_op(self):
     """Tests that the model has the dynamic update slice op."""
     _, edge_model, _ = self._get_params(
         enable_hlfb=True, kv_layout=kv_cache.KV_LAYOUT_DEFAULT
     )
-
-        custom_op_registerers=["GenAIOpsRegisterer"],
-        model_content=edge_model.tflite_model(),
-        experimental_default_delegate_latest_features=True,
-    )
+    interpreter = self._interpreter_builder(edge_model.tflite_model())()
 
     # pylint: disable=protected-access
-    op_names = [op["op_name"] for op in
+    op_names = [op["op_name"] for op in interpreter._get_ops_details()]
     self.assertIn("DYNAMIC_UPDATE_SLICE", op_names)
 
   def _test_multisig_model(
@@ -197,19 +175,11 @@ class TestModelConversion(googletest.TestCase):
         )
     )
 
-  @googletest.skipIf(
-      ai_edge_torch.config.in_oss,
-      reason="tests with custom ops are not supported in oss",
-  )
   def test_tiny_llama_multisig(self):
     config = tiny_llama.get_fake_model_config()
     pytorch_model = tiny_llama.TinyLlama(config).eval()
     self._test_multisig_model(config, pytorch_model, atol=1e-5, rtol=1e-5)
 
-  @googletest.skipIf(
-      ai_edge_torch.config.in_oss,
-      reason="tests with custom ops are not supported in oss",
-  )
   def test_tiny_llama_multisig_kv_layout_transposed(self):
     config = tiny_llama.get_fake_model_config()
     pytorch_model = tiny_llama.TinyLlama(config).eval()
ai_edge_torch/generative/test/test_model_conversion_large.py CHANGED

@@ -48,10 +48,8 @@ class TestModelConversion(googletest.TestCase):
 
   def setUp(self):
     super().setUp()
-    # Builder function for an Interpreter that supports custom ops.
     self._interpreter_builder = (
-        lambda tflite_model: lambda: interpreter.
-            custom_op_registerers=["GenAIOpsRegisterer"],
+        lambda tflite_model: lambda: interpreter.Interpreter(
             model_content=tflite_model,
             experimental_default_delegate_latest_features=True,
         )
@@ -94,110 +92,62 @@ class TestModelConversion(googletest.TestCase):
         )
     )
 
-  @googletest.skipIf(
-      ai_edge_torch.config.in_oss,
-      reason="tests with custom ops are not supported in oss",
-  )
   def test_gemma1(self):
     config = gemma1.get_fake_model_config()
     pytorch_model = gemma1.Gemma1(config).eval()
     self._test_model(config, pytorch_model, "prefill", atol=1e-3, rtol=1e-5)
 
-  @googletest.skipIf(
-      ai_edge_torch.config.in_oss,
-      reason="tests with custom ops are not supported in oss",
-  )
   def test_gemma2(self):
     config = gemma2.get_fake_model_config()
     pytorch_model = gemma2.Gemma2(config).eval()
     self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5)
 
-  @googletest.skipIf(
-      ai_edge_torch.config.in_oss,
-      reason="tests with custom ops are not supported in oss",
-  )
   def test_llama(self):
     config = llama.get_fake_model_config()
     pytorch_model = llama.Llama(config).eval()
     self._test_model(config, pytorch_model, "prefill", atol=1e-3, rtol=1e-5)
 
-  @googletest.skipIf(
-      ai_edge_torch.config.in_oss,
-      reason="tests with custom ops are not supported in oss",
-  )
   def test_phi2(self):
     config = phi2.get_fake_model_config()
     pytorch_model = phi2.Phi2(config).eval()
     # Phi-2 logits are very big, so we need a larger absolute tolerance.
     self._test_model(config, pytorch_model, "prefill", atol=1e-3, rtol=1e-5)
 
-  @googletest.skipIf(
-      ai_edge_torch.config.in_oss,
-      reason="tests with custom ops are not supported in oss",
-  )
   def test_phi3(self):
     config = phi3.get_fake_model_config()
     pytorch_model = phi3.Phi3_5Mini(config).eval()
     self._test_model(config, pytorch_model, "prefill", atol=1e-5, rtol=1e-5)
 
-  @googletest.skipIf(
-      ai_edge_torch.config.in_oss,
-      reason="tests with custom ops are not supported in oss",
-  )
   def test_phi4(self):
     config = phi4.get_fake_model_config()
     pytorch_model = phi4.Phi4Mini(config).eval()
     self._test_model(config, pytorch_model, "prefill", atol=1e-3, rtol=1e-5)
 
-  @googletest.skipIf(
-      ai_edge_torch.config.in_oss,
-      reason="tests with custom ops are not supported in oss",
-  )
   def test_smollm(self):
     config = smollm.get_fake_model_config()
     pytorch_model = smollm.SmolLM(config).eval()
     self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5)
 
-  @googletest.skipIf(
-      ai_edge_torch.config.in_oss,
-      reason="tests with custom ops are not supported in oss",
-  )
   def test_smollm2(self):
     config = smollm.get_fake_model_config_v2()
     pytorch_model = smollm.SmolLM2(config).eval()
     self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5)
 
-  @googletest.skipIf(
-      ai_edge_torch.config.in_oss,
-      reason="tests with custom ops are not supported in oss",
-  )
   def test_openelm(self):
     config = openelm.get_fake_model_config()
     pytorch_model = openelm.OpenELM(config).eval()
     self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5)
 
-  @googletest.skipIf(
-      ai_edge_torch.config.in_oss,
-      reason="tests with custom ops are not supported in oss",
-  )
   def test_qwen(self):
     config = qwen.get_fake_model_config()
     pytorch_model = qwen.Qwen(config).eval()
     self._test_model(config, pytorch_model, "prefill", atol=1e-3, rtol=1e-5)
 
-  @googletest.skipIf(
-      ai_edge_torch.config.in_oss,
-      reason="tests with custom ops are not supported in oss",
-  )
   def test_deepseek(self):
     config = deepseek.get_fake_model_config()
     pytorch_model = deepseek.DeepSeekDistillQwen(config).eval()
     self._test_model(config, pytorch_model, "prefill", atol=1e-5, rtol=1e-5)
 
-  @googletest.skipIf(
-      ai_edge_torch.config.in_oss,
-      reason="tests with custom ops are not supported in oss",
-  )
   def test_amd_llama_135m(self):
     config = amd_llama_135m.get_fake_model_config()
     pytorch_model = amd_llama_135m.AmdLlama(config).eval()
@@ -246,19 +196,11 @@ class TestModelConversion(googletest.TestCase):
         )
     )
 
-  @googletest.skipIf(
-      ai_edge_torch.config.in_oss,
-      reason="tests with custom ops are not supported in oss",
-  )
   def test_paligemma1(self):
     self._test_paligemma_model(
         decoder.Decoder, decoder.get_fake_decoder_config, atol=1e-3, rtol=1e-5
     )
 
-  @googletest.skipIf(
-      ai_edge_torch.config.in_oss,
-      reason="tests with custom ops are not supported in oss",
-  )
   def test_paligemma2(self):
     self._test_paligemma_model(
         decoder2.Decoder2,
@@ -267,10 +209,6 @@ class TestModelConversion(googletest.TestCase):
         rtol=1e-5,
     )
 
-  @googletest.skipIf(
-      ai_edge_torch.config.in_oss,
-      reason="tests with custom ops are not supported in oss",
-  )
   def test_qwen_vl_model(self):
     config = qwen_vl.get_fake_model_config()
     pytorch_model = qwen_vl.QwenVL(config).eval()
@@ -316,10 +254,7 @@ class TestModelConversion(googletest.TestCase):
         )
     )
 
-  @googletest.skipIf(
-      ai_edge_torch.config.in_oss,
-      reason="tests with custom ops are not supported in oss",
-  )
+  @googletest.skipIf(ai_edge_torch.config.in_oss, reason="flaky")
   def test_stable_diffusion_clip(self):
     config = sd_clip.get_fake_model_config()
     prompt_tokens = torch.from_numpy(
@@ -348,10 +283,7 @@ class TestModelConversion(googletest.TestCase):
         )
     )
 
-  @googletest.skipIf(
-      ai_edge_torch.config.in_oss,
-      reason="tests with custom ops are not supported in oss",
-  )
+  @googletest.skipIf(ai_edge_torch.config.in_oss, reason="b/413106901")
   def test_stable_diffusion_diffusion(self):
     config = sd_diffusion.get_fake_model_config(2)
     # Reduce stddev(scale) of input values to avoid too big output logits which
@@ -390,10 +322,6 @@ class TestModelConversion(googletest.TestCase):
         )
     )
 
-  @googletest.skipIf(
-      ai_edge_torch.config.in_oss,
-      reason="tests with custom ops are not supported in oss",
-  )
   def test_stable_diffusion_decoder(self):
     config = sd_decoder.get_fake_model_config()
     # Reduce stddev(scale) of input values to avoid too big output logits which
ai_edge_torch/generative/utilities/converter.py CHANGED

@@ -81,6 +81,11 @@ def define_conversion_flags(model_name: str):
       'If set, the model will be converted with the provided list of LoRA'
       ' ranks.',
   )
+  flags.DEFINE_bool(
+      'transpose_kv_cache',
+      False,
+      'If set, the model will be converted with transposed KV cache.',
+  )
 
   return flags
 
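A hedged sketch of exercising the new flag without a full conversion run: flags registered by `define_conversion_flags` are ordinary absl flags, so they can be parsed programmatically. The model name and the argv-style list below are illustrative, not taken from the diff.

```python
from absl import flags as absl_flags
from ai_edge_torch.generative.utilities import converter

# Registers checkpoint_path, quantize, transpose_kv_cache, etc. on the global FLAGS.
flags = converter.define_conversion_flags('toy')
absl_flags.FLAGS(['prog', '--transpose_kv_cache'])  # parse an argv-style list
# Converter scripts (e.g. qwen/convert_to_tflite.py above) read this to pick the
# transposed-KV-cache export config.
print(flags.FLAGS.transpose_kv_cache)  # True
```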
ai_edge_torch/odml_torch/lowerings/_decomp_registry.py CHANGED

@@ -34,6 +34,8 @@ fx_infra.decomp.update_pre_lower_decomp(
         torch.ops.aten.replication_pad1d,
         torch.ops.aten.replication_pad2d,
         torch.ops.aten.replication_pad3d,
+        torch.ops.aten.upsample_bilinear2d.vec,
+        torch.ops.aten.upsample_nearest2d.vec,
         torch.ops.aten.addmm,
     ])
 )
ai_edge_torch/version.py CHANGED

{ai_edge_torch_nightly-0.5.0.dev20250424.dist-info → ai_edge_torch_nightly-0.5.0.dev20250425.dist-info}/METADATA CHANGED

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: ai-edge-torch-nightly
-Version: 0.5.0.
+Version: 0.5.0.dev20250425
 Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
 Home-page: https://github.com/google-ai-edge/ai-edge-torch
 Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
{ai_edge_torch_nightly-0.5.0.dev20250424.dist-info → ai_edge_torch_nightly-0.5.0.dev20250425.dist-info}/RECORD CHANGED

@@ -2,16 +2,15 @@ ai_edge_torch/__init__.py,sha256=8sPR_5uXJA4NEE0nIwNdSl-ADOJEoR8hAgYvBQDY70Y,120
 ai_edge_torch/_config.py,sha256=AiqhbcheF7j_ozIGDLC89k1we95aVgFDa-tR6h7UI0s,2529
 ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
 ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
-ai_edge_torch/version.py,sha256=
+ai_edge_torch/version.py,sha256=_aF64u6MXH8zPBTEg6odQq2WazbUIxQYlfJNXzfkMdM,706
 ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/_convert/conversion.py,sha256=
+ai_edge_torch/_convert/conversion.py,sha256=QVugYVfbyaeBgSKKbhFzHG5oXA7t3M-40JcpcdSu6W8,5436
 ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
 ai_edge_torch/_convert/converter.py,sha256=075F8LRewk_033Ebsnft7FJr3KgtIbtZ_-8udIPy6ho,9980
 ai_edge_torch/_convert/signature.py,sha256=-YKJdLk-eNEHfhdPCtcQVtZf915SoVePEFxKXPPf16c,2572
 ai_edge_torch/_convert/to_channel_last_io.py,sha256=_31phf7TYgZY2ftpNbrdlB1RhDium1lz_BXEQ6IsMFc,2893
-ai_edge_torch/_convert/fx_passes/__init__.py,sha256=
-ai_edge_torch/_convert/fx_passes/build_aten_composite_pass.py,sha256=
-ai_edge_torch/_convert/fx_passes/build_interpolate_composite_pass.py,sha256=3JyjiHpn17Zhfq3yGQXK5LMH71DQPXHb_4GOkP9uAjY,4251
+ai_edge_torch/_convert/fx_passes/__init__.py,sha256=jbRCZmSduG_1qmngaEEtbofAyL1PKZ8P1uxzzsXQhsw,1253
+ai_edge_torch/_convert/fx_passes/build_aten_composite_pass.py,sha256=dgUO-lI9Id9hIOHP5XmegVlu5Fl79GR4_b-lDUehzoo,11428
 ai_edge_torch/_convert/fx_passes/cast_inputs_bf16_to_f32_pass.py,sha256=90YxLVAAkiA3qKr4Um__JmPeC1bTeA2PxBCj0GETq1Q,1748
 ai_edge_torch/_convert/fx_passes/inject_mlir_debuginfo_pass.py,sha256=Z6E3U7SYZvMl3Ivpqa3burVOLKFndEZuNmWKNxjq2mM,2386
 ai_edge_torch/_convert/fx_passes/remove_non_user_outputs_pass.py,sha256=HCOkj0k3NhaYbtfjE8HDXVmYhZ9fL5V_u6VunVh9mN4,2116
@@ -19,7 +18,7 @@ ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/__init__.py,sha
 ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/_decomp_registry.py,sha256=aWO_zHDF4j_hokoKJQNFIFmua4ysXztsgS6pcyBUht0,1082
 ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_check.py,sha256=7yEKSfXskXUk4tsd7c8vL155O-iU4eUjXCU5RSZHrbw,8204
 ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_mark.py,sha256=4RyGUwR22bZqkn_TnptenFJodc_Q43f4_SBG7gmTbos,1621
-ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py,sha256=
+ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py,sha256=IhEh3tTP3-AmQlpt24stKKEl0AIRyuo2REZIbhkmgJo,13940
 ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/op_func_registry.py,sha256=OhisegHY2j4cv_m9auCh9Mq9qmm1lUqpFLVO9X-oBlc,1032
 ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/pass_body.py,sha256=mr0MiLbaQmU-3S3KT-vb58kRWbNT3VJiCKY-K7_3tFg,10556
 ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/utils.py,sha256=YLMttMg5PdvXTtQ8lxpKb434UGVvYVALV1-xeuH4UGc,2131
@@ -54,8 +53,8 @@ ai_edge_torch/generative/examples/amd_llama_135m/amd_llama_135m.py,sha256=urNif8
 ai_edge_torch/generative/examples/amd_llama_135m/convert_to_tflite.py,sha256=z5MWiZLnsQzhNYMiQbcI9i0ki-dtkbimCptkiTFZxwo,1586
 ai_edge_torch/generative/examples/amd_llama_135m/verify.py,sha256=o13NkFlBgawBsjdJup05VMUjAPvDRAmig6VyEkX8q6U,2426
 ai_edge_torch/generative/examples/deepseek/__init__.py,sha256=JaAnrFoXTl3RJX97XspklkTyqOHVyAgRJsZtzNDd10c,671
-ai_edge_torch/generative/examples/deepseek/convert_to_tflite.py,sha256=
-ai_edge_torch/generative/examples/deepseek/deepseek.py,sha256=
+ai_edge_torch/generative/examples/deepseek/convert_to_tflite.py,sha256=1wz4h3bjyX2qMRZ310UKGNYTORegzxinVFmYz2Fupm4,2666
+ai_edge_torch/generative/examples/deepseek/deepseek.py,sha256=yhS_i2kR0GJWpWciCt4p9Z9nHYh6A5uJ8Ycy2ebFN9w,2909
 ai_edge_torch/generative/examples/deepseek/verify.py,sha256=iYldze-pvZGvPkkqr6zA7EmitPnH9sXkzjNVx353IcE,2403
 ai_edge_torch/generative/examples/gemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py,sha256=tSEtGeS-Ndcc_cTm7c4CT4FqRiwrHedEv1oJk4Y_zYU,1552
@@ -104,8 +103,8 @@ ai_edge_torch/generative/examples/phi/verify.py,sha256=YPFCdbnfmvq38fbpBNr0kHPfS
 ai_edge_torch/generative/examples/phi/verify_phi3.py,sha256=kVYaBVvddfQng0IyZGxyTJEzhiPO0G4VFJm2WOc2Q94,2360
 ai_edge_torch/generative/examples/phi/verify_phi4.py,sha256=BoCa5kUBRHtMQ-5ql6yD4pG4xHJMyUiQlpMOWVx-JgY,2356
 ai_edge_torch/generative/examples/qwen/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/examples/qwen/convert_to_tflite.py,sha256
-ai_edge_torch/generative/examples/qwen/qwen.py,sha256=
+ai_edge_torch/generative/examples/qwen/convert_to_tflite.py,sha256=-Xe5koexhNUkWjS2XgS9Ggg7XOQAlMO8QcBJRTNjJa4,2972
+ai_edge_torch/generative/examples/qwen/qwen.py,sha256=m8APYzo9N0SXsdvCxC8HtCcbN3W7gLKkRBL-Tg0BWXU,4223
 ai_edge_torch/generative/examples/qwen/verify.py,sha256=9_AyEJTeUfvhhID64Rto2bflFPyXMFokdQLsseLUMiI,2775
 ai_edge_torch/generative/examples/qwen_vl/__init__.py,sha256=JaAnrFoXTl3RJX97XspklkTyqOHVyAgRJsZtzNDd10c,671
 ai_edge_torch/generative/examples/qwen_vl/convert_to_tflite.py,sha256=yVebRatt2SLCsGvrYTBXOM-0S2REhkpikHTyy5MCjUw,2222
@@ -159,11 +158,10 @@ ai_edge_torch/generative/layers/lora.py,sha256=hsvWLLOnW7HQ0AysOZu30x_cetMquDd1t
 ai_edge_torch/generative/layers/model_config.py,sha256=nLXvTkDAIHJQ0PTaWODF8oxJQoJ-K8D10cKR9229SAw,8355
 ai_edge_torch/generative/layers/normalization.py,sha256=MbwH-n80Fob5YvjBzdqDjBizMHLzSJGYRDdbD-rL5C0,6174
 ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=975zR202MdIrILJ7blceAcxrNqX1ZCN0ECKG1gz-bV8,2655
-ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=
-ai_edge_torch/generative/layers/sdpa_with_kv_update.py,sha256=
+ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=efqqGRZPJ55hKn1MQJ-cXfrJD85uS1v7W_juyGyts58,5648
+ai_edge_torch/generative/layers/sdpa_with_kv_update.py,sha256=Hn8Zw-jiB9GH2uZ-yaRMcDdpmjECcW4uCy-YNH9zV8c,3693
 ai_edge_torch/generative/layers/experimental/__init__.py,sha256=nz-K0h8DfiATHzR6s1_bCw2akUmHWffU1bDRSkIzSqI,592
 ai_edge_torch/generative/layers/experimental/kv_cache.py,sha256=zgpFVftOfllvjh9-UEBSvUbm152SnQETn29rUMMMvAM,2978
-ai_edge_torch/generative/layers/experimental/scaled_dot_product_attention.py,sha256=YFcIGOkaNb-vvQKjI-G9-bC2Z1W0O_qRyIZPlsLl72U,2797
 ai_edge_torch/generative/layers/unet/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/layers/unet/blocks_2d.py,sha256=ZteHZXK6HKyxYji49DQ46sA9aIy7U3Jnz0HZp6hfevY,28996
 ai_edge_torch/generative/layers/unet/builder.py,sha256=zAqWXdimmMrQRhmE_t9XkS68mh6PSrzwb-2NZZXrR5I,1901
@@ -180,12 +178,12 @@ ai_edge_torch/generative/test/test_custom_dus.py,sha256=MjIhTvkTko872M35XMciobvI
 ai_edge_torch/generative/test/test_kv_cache.py,sha256=1sXN2RPntq0PP3IEy0NkvIbzQ0Y8JhPIwRSFwO9JLlE,5728
 ai_edge_torch/generative/test/test_loader.py,sha256=9mQUeeZKOVApOWSWl2cN9c10axZjMKM1-0Zd823CCS4,3449
 ai_edge_torch/generative/test/test_lora.py,sha256=6QIM6RLTc2HrodGpp_aS3OxM9Rco2KAzEnYgotkg41M,5310
-ai_edge_torch/generative/test/test_model_conversion.py,sha256=
-ai_edge_torch/generative/test/test_model_conversion_large.py,sha256
+ai_edge_torch/generative/test/test_model_conversion.py,sha256=mhNJikLnGVGi9NKmXB8FhnqeDy9gtrvC3yEbrTABZ4Y,6163
+ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=6LkLnFOvlnt7JVVDYKMaZClPRBEvdjq6xnSjIFYNdI8,12554
 ai_edge_torch/generative/test/test_quantize.py,sha256=bEJMhpQ9bIDUZVBXTW888728FcH-i3SyE4JSZZUgU0A,6071
 ai_edge_torch/generative/test/utils.py,sha256=tF6aCfAGJnc9dmzCnZCEOuKNVimfWOqscv9og0DDLHU,2656
 ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
-ai_edge_torch/generative/utilities/converter.py,sha256=
+ai_edge_torch/generative/utilities/converter.py,sha256=z3CvNJxKzglu1BU_5ri91RUeGHh7urhoWFbk0oq7i2M,10768
 ai_edge_torch/generative/utilities/export_config.py,sha256=8-795nyd3M34LkGhgW7hwHlJyTc2Oz1iipHK8yBhdFs,1633
 ai_edge_torch/generative/utilities/loader.py,sha256=7p__m2JryWphGlYOuRxdoT4id4_tWJEVOV7y2X4H-Ak,13737
 ai_edge_torch/generative/utilities/model_builder.py,sha256=ZYX1TxpFdj573du2QCyHJlFjx4q1m12R74fp4Gwl92A,6343
@@ -227,7 +225,7 @@ ai_edge_torch/odml_torch/lowerings/__init__.py,sha256=uJ-niilt1c-D6QJzLwgvCUf62l
 ai_edge_torch/odml_torch/lowerings/_basic.py,sha256=fEWjIdEpDIqT1EYLZE13O9A41OuaNdbfBrv3vNxS9gI,11601
 ai_edge_torch/odml_torch/lowerings/_batch_norm.py,sha256=PaLI0BB6pdBW1VyfW8VTOT_Be-ZcqYdNOsyfzKfq8Cg,2064
 ai_edge_torch/odml_torch/lowerings/_convolution.py,sha256=Q0aDzyUcZMoSzSbOU-r3LJMgPe6fble0QwdYVIOHHHk,6887
-ai_edge_torch/odml_torch/lowerings/_decomp_registry.py,sha256=
+ai_edge_torch/odml_torch/lowerings/_decomp_registry.py,sha256=ybOdoFE5HIJTkyiYcc73zpyUyUpioVnAca6k0wyJPs4,2572
 ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py,sha256=tkaDo232HjuZvJHyua0n6tdHecifUuVzclJAGq7PPYs,11428
 ai_edge_torch/odml_torch/lowerings/_layer_norm.py,sha256=khJIvDVk2s332Nd2Be-5dM6-wp5DGff61HCV5lskHmQ,3011
 ai_edge_torch/odml_torch/lowerings/_quantized_decomposed.py,sha256=XDZ0zLej_XaQDJnaAAxhNFAd7NfQm5SOVEp_nno_krA,6178
@@ -244,8 +242,8 @@ ai_edge_torch/testing/__init__.py,sha256=_yGgvnBZWb7T3IN3mc4x1sS4vM96HZwM8pwIcPG
 ai_edge_torch/testing/export.py,sha256=k5mGDGzwc23Z4zaIVDs8CNh-oOt64gsf9MS9NjhbPy4,3293
 ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
 ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
-ai_edge_torch_nightly-0.5.0.
-ai_edge_torch_nightly-0.5.0.
-ai_edge_torch_nightly-0.5.0.
-ai_edge_torch_nightly-0.5.0.
-ai_edge_torch_nightly-0.5.0.
+ai_edge_torch_nightly-0.5.0.dev20250425.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
+ai_edge_torch_nightly-0.5.0.dev20250425.dist-info/METADATA,sha256=owGeoLcv0XFf4tXFatFjXLSisoaRBBwrtyLx3LFq8PM,2051
+ai_edge_torch_nightly-0.5.0.dev20250425.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
+ai_edge_torch_nightly-0.5.0.dev20250425.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
+ai_edge_torch_nightly-0.5.0.dev20250425.dist-info/RECORD,,
ai_edge_torch/_convert/fx_passes/build_interpolate_composite_pass.py DELETED

@@ -1,129 +0,0 @@
-# Copyright 2024 The AI Edge Torch Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Build interpolate composite pass."""
-
-import functools
-
-from ai_edge_torch import fx_infra
-from ai_edge_torch.hlfb import mark_pattern
-from ai_edge_torch.hlfb.mark_pattern import pattern as pattern_module
-import torch
-
-# For torch nightly released after mid June 2024,
-# torch.nn.functional.interpolate no longer gets exported into decomposed graph
-# but a single aten op:
-# torch.ops.aten.upsample_nearest2d.vec/torch.ops.aten.upsample_bilinear2d.vec.
-# This would interefere with our pattern matching based composite builder.
-# Here we register the now missing decompositions first.
-_INTERPOLATE_DECOMPOSITIONS = torch._decomp.get_decompositions([
-    torch.ops.aten.upsample_bilinear2d.vec,
-    torch.ops.aten.upsample_nearest2d.vec,
-])
-
-
-@functools.cache
-def _get_upsample_bilinear2d_pattern():
-  pattern = pattern_module.Pattern(
-      "odml.upsample_bilinear2d",
-      lambda x: torch.nn.functional.interpolate(
-          x, scale_factor=2, mode="bilinear", align_corners=False
-      ),
-      export_args=(torch.rand(1, 3, 100, 100),),
-      extra_decomp_table=_INTERPOLATE_DECOMPOSITIONS,
-  )
-
-  @pattern.register_attr_builder
-  def attr_builder(pattern, graph_module, internal_match):
-    output = internal_match.returning_nodes[0]
-    output_h, output_w = output.meta["val"].shape[-2:]
-    return {
-        "size": (int(output_h), int(output_w)),
-        "align_corners": False,
-        "is_nchw_op": True,
-    }
-
-  return pattern
-
-
-@functools.cache
-def _get_upsample_bilinear2d_align_corners_pattern():
-  pattern = pattern_module.Pattern(
-      "odml.upsample_bilinear2d",
-      lambda x: torch.nn.functional.interpolate(
-          x, scale_factor=2, mode="bilinear", align_corners=True
-      ),
-      export_args=(torch.rand(1, 3, 100, 100),),
-      extra_decomp_table=_INTERPOLATE_DECOMPOSITIONS,
-  )
-
-  @pattern.register_attr_builder
-  def attr_builder(graph_module, pattern, internal_match):
-    output = internal_match.returning_nodes[0]
-    output_h, output_w = output.meta["val"].shape[-2:]
-    return {
-        "size": (int(output_h), int(output_w)),
-        "align_corners": True,
-        "is_nchw_op": True,
-    }
-
-  return pattern
-
-
-@functools.cache
-def _get_interpolate_nearest2d_pattern():
-  pattern = pattern_module.Pattern(
-      "tfl.resize_nearest_neighbor",
-      lambda x: torch.nn.functional.interpolate(
-          x, scale_factor=2, mode="nearest"
-      ),
-      export_args=(torch.rand(1, 3, 100, 100),),
-      extra_decomp_table=_INTERPOLATE_DECOMPOSITIONS,
-  )
-
-  @pattern.register_attr_builder
-  def attr_builder(pattern, graph_module, internal_match):
-    output = internal_match.returning_nodes[0]
-    output_h, output_w = output.meta["val"].shape[-2:]
-    return {
-        "size": (int(output_h), int(output_w)),
-        "is_nchw_op": True,
-    }
-
-  return pattern
-
-
-class BuildInterpolateCompositePass(fx_infra.ExportedProgramPassBase):
-
-  def __init__(self):
-    super().__init__()
-    self._patterns = [
-        _get_upsample_bilinear2d_pattern(),
-        _get_upsample_bilinear2d_align_corners_pattern(),
-        _get_interpolate_nearest2d_pattern(),
-    ]
-
-  def call(self, exported_program: torch.export.ExportedProgram):
-    exported_program = fx_infra.safe_run_decompositions(
-        exported_program,
-        _INTERPOLATE_DECOMPOSITIONS,
-    )
-
-    graph_module = exported_program.graph_module
-    for pattern in self._patterns:
-      graph_module = mark_pattern.mark_pattern(graph_module, pattern)
-
-    graph_module.graph.lint()
-    graph_module.recompile()
-    return fx_infra.ExportedProgramPassResult(exported_program, True)
ai_edge_torch/generative/layers/experimental/scaled_dot_product_attention.py DELETED

@@ -1,93 +0,0 @@
-# Copyright 2025 The AI Edge Torch Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-# Implements scaled dot product attention. This is experimental and
-# GPU-specific code.
-
-import math
-from typing import Optional
-
-from ai_edge_torch.generative.custom_ops import bmm_4d as bmm_lib
-from ai_edge_torch.generative.layers import kv_cache as kv_utils
-from ai_edge_torch.generative.utilities import types
-from ai_edge_torch.hlfb import StableHLOCompositeBuilder
-from multipledispatch import dispatch
-import torch
-import torch.nn.functional as F
-
-
-def scaled_dot_product_attention(
-    kv: kv_utils.KVCacheEntry,
-    query: torch.Tensor,
-    key: torch.Tensor,
-    value: torch.Tensor,
-    head_size: int,
-    mask: Optional[torch.Tensor] = None,
-    scale: Optional[float] = None,
-    softcap: Optional[float] = None,
-):
-  if hasattr(kv, "kv_layout"):
-    return _sdpa(
-        kv.kv_layout[0](),  # key layout
-        kv.kv_layout[1](),  # value layout
-        query=query,
-        key=key,
-        value=value,
-        head_size=head_size,
-        mask=mask,
-        scale=scale,
-        softcap=softcap,
-    )
-  raise ValueError("No kv_layout attribute found in kv.")
-
-
-@dispatch(types.BNTH, types.BNHT)
-def _sdpa(k_type, v_type, *args, **kwargs):
-  query = kwargs["query"]
-  key = kwargs["key"]
-  value = kwargs["value"]
-  head_size = kwargs["head_size"]
-  mask = kwargs.get("mask", None)
-  scale = kwargs.get("scale", None)
-  softcap = kwargs.get("softcap", None)
-
-  if scale is None:
-    scale = 1.0 / math.sqrt(head_size)
-
-  query = query * scale
-
-  assert mask is not None, "Mask should not be None!"
-  t = mask.shape[2]
-
-  logits = bmm_lib.bmm_4d(query, key)
-
-  _, bk, gt, s = logits.shape
-  g = gt // t
-  logits = logits.reshape((bk, g, t, s))
-  if softcap is not None:
-    logits = torch.tanh(logits / softcap)
-    logits = logits * softcap
-
-  padded_logits = logits + mask
-  padded_logits = padded_logits.reshape(1, bk, gt, s)
-  probs = F.softmax(padded_logits, dim=-1).type_as(key)
-  encoded = bmm_lib.bmm_4d(probs, value)
-
-  return encoded  # 1, bk, gt, h
-
-
-@dispatch(object, object)
-def _sdpa(k_type, v_type, *args, **kwargs):
-
-  raise ValueError(f"No implementations for k={k_type} and v={v_type}")