ai-edge-torch-nightly 0.5.0.dev20250424__py3-none-any.whl → 0.5.0.dev20250426__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_torch/_convert/conversion.py +1 -3
- ai_edge_torch/_convert/fx_passes/__init__.py +0 -1
- ai_edge_torch/_convert/fx_passes/build_aten_composite_pass.py +63 -2
- ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +2 -1
- ai_edge_torch/generative/examples/deepseek/convert_to_tflite.py +3 -3
- ai_edge_torch/generative/examples/deepseek/deepseek.py +1 -0
- ai_edge_torch/generative/examples/gemma3/convert_gemma3_to_tflite.py +2 -38
- ai_edge_torch/generative/examples/hammer/__init__.py +14 -0
- ai_edge_torch/generative/examples/hammer/convert_to_tflite.py +92 -0
- ai_edge_torch/generative/examples/hammer/hammer.py +107 -0
- ai_edge_torch/generative/examples/hammer/verify.py +86 -0
- ai_edge_torch/generative/examples/llama/convert_to_tflite.py +1 -3
- ai_edge_torch/generative/examples/llama/llama.py +3 -1
- ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py +1 -2
- ai_edge_torch/generative/examples/phi/convert_phi4_to_tflite.py +1 -2
- ai_edge_torch/generative/examples/phi/convert_to_tflite.py +1 -2
- ai_edge_torch/generative/examples/phi/phi2.py +1 -1
- ai_edge_torch/generative/examples/phi/phi3.py +3 -1
- ai_edge_torch/generative/examples/phi/phi4.py +3 -1
- ai_edge_torch/generative/examples/qwen/convert_to_tflite.py +2 -3
- ai_edge_torch/generative/examples/qwen/qwen.py +1 -0
- ai_edge_torch/generative/examples/smollm/convert_to_tflite.py +5 -3
- ai_edge_torch/generative/examples/smollm/convert_v2_to_tflite.py +4 -3
- ai_edge_torch/generative/examples/smollm/smollm.py +3 -1
- ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +1 -2
- ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +3 -1
- ai_edge_torch/generative/layers/kv_cache.py +2 -4
- ai_edge_torch/generative/layers/scaled_dot_product_attention.py +51 -0
- ai_edge_torch/generative/layers/sdpa_with_kv_update.py +4 -6
- ai_edge_torch/generative/test/test_model_conversion.py +3 -33
- ai_edge_torch/generative/test/test_model_conversion_large.py +10 -75
- ai_edge_torch/generative/utilities/converter.py +11 -1
- ai_edge_torch/generative/utilities/export_config.py +30 -0
- ai_edge_torch/model.py +2 -0
- ai_edge_torch/odml_torch/lowerings/_decomp_registry.py +2 -0
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.5.0.dev20250424.dist-info → ai_edge_torch_nightly-0.5.0.dev20250426.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.5.0.dev20250424.dist-info → ai_edge_torch_nightly-0.5.0.dev20250426.dist-info}/RECORD +41 -39
- ai_edge_torch/_convert/fx_passes/build_interpolate_composite_pass.py +0 -129
- ai_edge_torch/generative/layers/experimental/scaled_dot_product_attention.py +0 -93
- {ai_edge_torch_nightly-0.5.0.dev20250424.dist-info → ai_edge_torch_nightly-0.5.0.dev20250426.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.5.0.dev20250424.dist-info → ai_edge_torch_nightly-0.5.0.dev20250426.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.5.0.dev20250424.dist-info → ai_edge_torch_nightly-0.5.0.dev20250426.dist-info}/top_level.txt +0 -0
@@ -53,6 +53,7 @@ def get_3b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
   norm_config = cfg.NormalizationConfig(
       type=cfg.NormalizationType.RMS_NORM,
       epsilon=1e-06,
+      enable_hlfb=True,
   )
   block_config = cfg.TransformerBlockConfig(
       attn_config=attn_config,
@@ -35,6 +35,10 @@ def main(_):
   pytorch_model = smollm.build_model(
       flags.FLAGS.checkpoint_path, kv_cache_max_len=flags.FLAGS.kv_cache_max_len
   )
+
+  export_config = export_cfg.get_from_flags()
+  export_config.decode_batch_size = _DECODE_BATCH_SIZE.value
+
   converter.convert_to_tflite(
       pytorch_model,
       output_path=flags.FLAGS.output_path,
@@ -42,9 +46,7 @@ def main(_):
       prefill_seq_len=flags.FLAGS.prefill_seq_lens,
       quantize=flags.FLAGS.quantize,
       lora_ranks=flags.FLAGS.lora_ranks,
-      export_config=
-          decode_batch_size=_DECODE_BATCH_SIZE.value
-      ),
+      export_config=export_config,
   )

@@ -34,6 +34,9 @@ def main(_):
       flags.FLAGS.checkpoint_path, kv_cache_max_len=flags.FLAGS.kv_cache_max_len
   )

+  export_config = export_cfg.get_from_flags()
+  export_config.decode_batch_size = _DECODE_BATCH_SIZE.value
+
   converter.convert_to_tflite(
       pytorch_model,
       output_path=flags.FLAGS.output_path,
@@ -41,9 +44,7 @@ def main(_):
       prefill_seq_len=flags.FLAGS.prefill_seq_lens,
       quantize=flags.FLAGS.quantize,
       lora_ranks=flags.FLAGS.lora_ranks,
-      export_config=
-          decode_batch_size=_DECODE_BATCH_SIZE.value
-      ),
+      export_config=export_config,
   )
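Note: taken together, these two conversion scripts now share one entry-point pattern: build an ExportConfig from the conversion flags, override the decode batch size, and hand the config to the converter. A condensed sketch of that pattern is below; the "smollm" flag-prefix string and the literal batch size are illustrative, the export_cfg alias follows the usage in the hunks above, and the absl app boilerplate is the usual one rather than copied from this diff.

from absl import app

from ai_edge_torch.generative.examples.smollm import smollm
from ai_edge_torch.generative.utilities import converter
from ai_edge_torch.generative.utilities import export_config as export_cfg

flags = converter.define_conversion_flags("smollm")


def main(_):
  pytorch_model = smollm.build_model(
      flags.FLAGS.checkpoint_path, kv_cache_max_len=flags.FLAGS.kv_cache_max_len
  )

  # New pattern: derive the export config from flags, then tweak it in place.
  export_config = export_cfg.get_from_flags()
  export_config.decode_batch_size = 2  # previously passed inline to the converter

  converter.convert_to_tflite(
      pytorch_model,
      output_path=flags.FLAGS.output_path,
      prefill_seq_len=flags.FLAGS.prefill_seq_lens,
      quantize=flags.FLAGS.quantize,
      lora_ranks=flags.FLAGS.lora_ranks,
      export_config=export_config,
  )


if __name__ == '__main__':
  app.run(main)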
|
@@ -49,7 +49,9 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       activation=cfg.ActivationConfig(cfg.ActivationType.SILU),
       intermediate_size=1536,
   )
-  norm_config = cfg.NormalizationConfig(
+  norm_config = cfg.NormalizationConfig(
+      type=cfg.NormalizationType.RMS_NORM, enable_hlfb=True
+  )
   block_config = cfg.TransformerBlockConfig(
       attn_config=attn_config,
       ff_config=ff_config,
@@ -21,7 +21,6 @@ from ai_edge_torch.generative.utilities import converter
 from ai_edge_torch.generative.utilities import export_config

 flags = converter.define_conversion_flags("tiny_llama")
-ExportConfig = export_config.ExportConfig


 def main(_):
@@ -35,7 +34,7 @@ def main(_):
       prefill_seq_len=flags.FLAGS.prefill_seq_lens,
       quantize=flags.FLAGS.quantize,
       lora_ranks=flags.FLAGS.lora_ranks,
-      export_config=
+      export_config=export_config.get_from_flags(),
   )

@@ -49,7 +49,9 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       activation=cfg.ActivationConfig(cfg.ActivationType.SILU),
       intermediate_size=5632,
   )
-  norm_config = cfg.NormalizationConfig(
+  norm_config = cfg.NormalizationConfig(
+      type=cfg.NormalizationType.RMS_NORM, enable_hlfb=True
+  )
   block_config = cfg.TransformerBlockConfig(
       attn_config=attn_config,
       ff_config=ff_config,
@@ -51,10 +51,7 @@ class KVCacheEntry:
       config: model_config.AttentionConfig,
       batch_size: int,
   ) -> List[int]:
-    """
-
-    the specified layout.
-    """
+    """Construct the shape of KV cache entry based on the specified layout."""
     output_shape = []
     for dim_spec in shape_spec:
       if dim_spec is types.TensorDims.BATCH:
@@ -213,6 +210,7 @@ pytree.register_pytree_node(
     serialized_type_name="",
 )

+
 def update(
     cache: KVCacheEntry,
     input_pos: torch.Tensor,
@@ -17,6 +17,8 @@
 import math
 from typing import Optional

+from ai_edge_torch.generative.custom_ops import bmm_4d as bmm_lib
+from ai_edge_torch.generative.layers import kv_cache as kv_utils
 from ai_edge_torch.hlfb import StableHLOCompositeBuilder
 import torch
 import torch.nn.functional as F
@@ -142,3 +144,52 @@ def scaled_dot_product_attention_with_hlfb(
   result = y.transpose(1, 2)
   result = builder.mark_outputs(result)
   return result
+
+
+def scaled_dot_product_attention_transposed(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    head_size: int,
+    mask: Optional[torch.Tensor] = None,
+    scale: Optional[float] = None,
+    softcap: Optional[float] = None,
+):
+  """Scaled dot product attention with transposed key and value.
+
+  Args:
+    query: Query tensor, with shape [B, T, N, H].
+    key: Key tensor, with shape [B, T, KV_LEN, H].
+    value: Value tensor, with shape [B, T, KV_LEN, H].
+    head_size (int): head dimension.
+    mask (torch.Tensor): the optional mask tensor.
+    scale (float): the optional scale factor.
+    softcap (float): the optional softcap for the logits.
+
+  Returns:
+    The output tensor of scaled_dot_product_attention_transposed.
+  """
+
+  if scale is None:
+    scale = 1.0 / math.sqrt(head_size)
+
+  query = query * scale
+
+  assert mask is not None, "Mask should not be None!"
+  t = mask.shape[2]
+
+  logits = bmm_lib.bmm_4d(query, key)
+
+  _, bk, gt, s = logits.shape
+  g = gt // t
+  logits = logits.reshape((bk, g, t, s))
+  if softcap is not None:
+    logits = torch.tanh(logits / softcap)
+    logits = logits * softcap
+
+  padded_logits = logits + mask
+  padded_logits = padded_logits.reshape(1, bk, gt, s)
+  probs = F.softmax(padded_logits, dim=-1).type_as(key)
+  encoded = bmm_lib.bmm_4d(probs, value)
+
+  return encoded  # 1, bk, gt, h
@@ -18,9 +18,8 @@
 from typing import Tuple

 from ai_edge_torch.generative.layers import kv_cache as kv_utils
-from ai_edge_torch.generative.layers import scaled_dot_product_attention as
+from ai_edge_torch.generative.layers import scaled_dot_product_attention as sdpa
 from ai_edge_torch.generative.layers.experimental import kv_cache as kv_utils_experimental
-from ai_edge_torch.generative.layers.experimental import scaled_dot_product_attention as sdpa
 import ai_edge_torch.generative.layers.model_config as cfg
 import torch

@@ -72,8 +71,7 @@ def _sdpa_with_kv_update_transposed(
   kv = kv_utils_experimental.update(kv, input_pos, key, value)
   key, value = kv.k_cache, kv.v_cache

-  sdpa_out = sdpa.
-      kv,
+  sdpa_out = sdpa.scaled_dot_product_attention_transposed(
       query,
       key,
       value,
@@ -105,9 +103,9 @@ def _sdpa_with_kv_update_default(
   key, value = kv.k_cache, kv.v_cache

   if enable_hlfb:
-    sdpa_func =
+    sdpa_func = sdpa.scaled_dot_product_attention_with_hlfb
   else:
-    sdpa_func =
+    sdpa_func = sdpa.scaled_dot_product_attention
   sdpa_out = sdpa_func(
       query,
       key,
@@ -32,10 +32,8 @@ class TestModelConversion(googletest.TestCase):

   def setUp(self):
     super().setUp()
-    # Builder function for an Interpreter that supports custom ops.
     self._interpreter_builder = (
-        lambda tflite_model: lambda: interpreter.
-            custom_op_registerers=["GenAIOpsRegisterer"],
+        lambda tflite_model: lambda: interpreter.Interpreter(
             model_content=tflite_model,
             experimental_default_delegate_latest_features=True,
         )
@@ -85,44 +83,24 @@ class TestModelConversion(googletest.TestCase):
         )
     )

-  @googletest.skipIf(
-      ai_edge_torch.config.in_oss,
-      reason="tests with custom ops are not supported in oss",
-  )
   def test_toy_model_with_kv_cache(self):
     self._test_model_with_kv_cache(enable_hlfb=False)

-  @googletest.skipIf(
-      ai_edge_torch.config.in_oss,
-      reason="tests with custom ops are not supported in oss",
-  )
   def test_toy_model_with_kv_cache_with_hlfb(self):
     self._test_model_with_kv_cache(enable_hlfb=True)

-  @googletest.skipIf(
-      ai_edge_torch.config.in_oss,
-      reason="tests with custom ops are not supported in oss",
-  )
   def test_toy_model_with_kv_cache_transposed(self):
     self._test_model_with_kv_cache(kv_layout=kv_cache.KV_LAYOUT_TRANSPOSED)

-  @googletest.skipIf(
-      ai_edge_torch.config.in_oss,
-      reason="tests with custom ops are not supported in oss",
-  )
   def test_toy_model_has_dus_op(self):
     """Tests that the model has the dynamic update slice op."""
     _, edge_model, _ = self._get_params(
         enable_hlfb=True, kv_layout=kv_cache.KV_LAYOUT_DEFAULT
     )
-
-        custom_op_registerers=["GenAIOpsRegisterer"],
-        model_content=edge_model.tflite_model(),
-        experimental_default_delegate_latest_features=True,
-    )
+    interpreter = self._interpreter_builder(edge_model.tflite_model())()

     # pylint: disable=protected-access
-    op_names = [op["op_name"] for op in
+    op_names = [op["op_name"] for op in interpreter._get_ops_details()]
     self.assertIn("DYNAMIC_UPDATE_SLICE", op_names)

   def _test_multisig_model(
@@ -197,19 +175,11 @@ class TestModelConversion(googletest.TestCase):
         )
     )

-  @googletest.skipIf(
-      ai_edge_torch.config.in_oss,
-      reason="tests with custom ops are not supported in oss",
-  )
   def test_tiny_llama_multisig(self):
     config = tiny_llama.get_fake_model_config()
     pytorch_model = tiny_llama.TinyLlama(config).eval()
     self._test_multisig_model(config, pytorch_model, atol=1e-5, rtol=1e-5)

-  @googletest.skipIf(
-      ai_edge_torch.config.in_oss,
-      reason="tests with custom ops are not supported in oss",
-  )
   def test_tiny_llama_multisig_kv_layout_transposed(self):
     config = tiny_llama.get_fake_model_config()
     pytorch_model = tiny_llama.TinyLlama(config).eval()
@@ -20,6 +20,7 @@ from ai_edge_torch.generative.examples.amd_llama_135m import amd_llama_135m
 from ai_edge_torch.generative.examples.deepseek import deepseek
 from ai_edge_torch.generative.examples.gemma import gemma1
 from ai_edge_torch.generative.examples.gemma import gemma2
+from ai_edge_torch.generative.examples.hammer import hammer
 from ai_edge_torch.generative.examples.llama import llama
 from ai_edge_torch.generative.examples.openelm import openelm
 from ai_edge_torch.generative.examples.paligemma import decoder
@@ -48,10 +49,8 @@ class TestModelConversion(googletest.TestCase):

   def setUp(self):
     super().setUp()
-    # Builder function for an Interpreter that supports custom ops.
     self._interpreter_builder = (
-        lambda tflite_model: lambda: interpreter.
-            custom_op_registerers=["GenAIOpsRegisterer"],
+        lambda tflite_model: lambda: interpreter.Interpreter(
             model_content=tflite_model,
             experimental_default_delegate_latest_features=True,
         )
@@ -94,110 +93,68 @@ class TestModelConversion(googletest.TestCase):
         )
     )

-  @googletest.skipIf(
-      ai_edge_torch.config.in_oss,
-      reason="tests with custom ops are not supported in oss",
-  )
   def test_gemma1(self):
     config = gemma1.get_fake_model_config()
     pytorch_model = gemma1.Gemma1(config).eval()
     self._test_model(config, pytorch_model, "prefill", atol=1e-3, rtol=1e-5)

-  @googletest.skipIf(
-      ai_edge_torch.config.in_oss,
-      reason="tests with custom ops are not supported in oss",
-  )
   def test_gemma2(self):
     config = gemma2.get_fake_model_config()
     pytorch_model = gemma2.Gemma2(config).eval()
     self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5)

-  @googletest.skipIf(
-      ai_edge_torch.config.in_oss,
-      reason="tests with custom ops are not supported in oss",
-  )
   def test_llama(self):
     config = llama.get_fake_model_config()
     pytorch_model = llama.Llama(config).eval()
     self._test_model(config, pytorch_model, "prefill", atol=1e-3, rtol=1e-5)

-  @googletest.skipIf(
-      ai_edge_torch.config.in_oss,
-      reason="tests with custom ops are not supported in oss",
-  )
   def test_phi2(self):
     config = phi2.get_fake_model_config()
     pytorch_model = phi2.Phi2(config).eval()
     # Phi-2 logits are very big, so we need a larger absolute tolerance.
     self._test_model(config, pytorch_model, "prefill", atol=1e-3, rtol=1e-5)

-  @googletest.skipIf(
-      ai_edge_torch.config.in_oss,
-      reason="tests with custom ops are not supported in oss",
-  )
   def test_phi3(self):
     config = phi3.get_fake_model_config()
     pytorch_model = phi3.Phi3_5Mini(config).eval()
     self._test_model(config, pytorch_model, "prefill", atol=1e-5, rtol=1e-5)

-  @googletest.skipIf(
-      ai_edge_torch.config.in_oss,
-      reason="tests with custom ops are not supported in oss",
-  )
   def test_phi4(self):
     config = phi4.get_fake_model_config()
     pytorch_model = phi4.Phi4Mini(config).eval()
     self._test_model(config, pytorch_model, "prefill", atol=1e-3, rtol=1e-5)

-  @googletest.skipIf(
-      ai_edge_torch.config.in_oss,
-      reason="tests with custom ops are not supported in oss",
-  )
   def test_smollm(self):
     config = smollm.get_fake_model_config()
     pytorch_model = smollm.SmolLM(config).eval()
     self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5)

-  @googletest.skipIf(
-      ai_edge_torch.config.in_oss,
-      reason="tests with custom ops are not supported in oss",
-  )
   def test_smollm2(self):
     config = smollm.get_fake_model_config_v2()
     pytorch_model = smollm.SmolLM2(config).eval()
     self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5)

-  @googletest.skipIf(
-      ai_edge_torch.config.in_oss,
-      reason="tests with custom ops are not supported in oss",
-  )
   def test_openelm(self):
     config = openelm.get_fake_model_config()
     pytorch_model = openelm.OpenELM(config).eval()
     self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5)

-  @googletest.skipIf(
-      ai_edge_torch.config.in_oss,
-      reason="tests with custom ops are not supported in oss",
-  )
   def test_qwen(self):
     config = qwen.get_fake_model_config()
     pytorch_model = qwen.Qwen(config).eval()
     self._test_model(config, pytorch_model, "prefill", atol=1e-3, rtol=1e-5)

-  @googletest.skipIf(
-      ai_edge_torch.config.in_oss,
-      reason="tests with custom ops are not supported in oss",
-  )
   def test_deepseek(self):
     config = deepseek.get_fake_model_config()
     pytorch_model = deepseek.DeepSeekDistillQwen(config).eval()
     self._test_model(config, pytorch_model, "prefill", atol=1e-5, rtol=1e-5)

-
-
-
-
+  def test_hammer(self):
+    config = hammer.get_fake_model_config()
+    pytorch_model = hammer.Hammer(config).eval()
+    self._test_model(config, pytorch_model, "prefill", atol=1e-5, rtol=1e-5)
+
+
   def test_amd_llama_135m(self):
     config = amd_llama_135m.get_fake_model_config()
     pytorch_model = amd_llama_135m.AmdLlama(config).eval()
@@ -246,19 +203,11 @@ class TestModelConversion(googletest.TestCase):
         )
     )

-  @googletest.skipIf(
-      ai_edge_torch.config.in_oss,
-      reason="tests with custom ops are not supported in oss",
-  )
   def test_paligemma1(self):
     self._test_paligemma_model(
         decoder.Decoder, decoder.get_fake_decoder_config, atol=1e-3, rtol=1e-5
     )

-  @googletest.skipIf(
-      ai_edge_torch.config.in_oss,
-      reason="tests with custom ops are not supported in oss",
-  )
   def test_paligemma2(self):
     self._test_paligemma_model(
         decoder2.Decoder2,
@@ -267,10 +216,6 @@ class TestModelConversion(googletest.TestCase):
         rtol=1e-5,
     )

-  @googletest.skipIf(
-      ai_edge_torch.config.in_oss,
-      reason="tests with custom ops are not supported in oss",
-  )
   def test_qwen_vl_model(self):
     config = qwen_vl.get_fake_model_config()
     pytorch_model = qwen_vl.QwenVL(config).eval()
@@ -316,10 +261,7 @@ class TestModelConversion(googletest.TestCase):
         )
     )

-  @googletest.skipIf(
-      ai_edge_torch.config.in_oss,
-      reason="tests with custom ops are not supported in oss",
-  )
+  @googletest.skipIf(ai_edge_torch.config.in_oss, reason="flaky")
   def test_stable_diffusion_clip(self):
     config = sd_clip.get_fake_model_config()
     prompt_tokens = torch.from_numpy(
@@ -348,10 +290,7 @@ class TestModelConversion(googletest.TestCase):
         )
     )

-  @googletest.skipIf(
-      ai_edge_torch.config.in_oss,
-      reason="tests with custom ops are not supported in oss",
-  )
+  @googletest.skipIf(ai_edge_torch.config.in_oss, reason="b/413106901")
   def test_stable_diffusion_diffusion(self):
     config = sd_diffusion.get_fake_model_config(2)
     # Reduce stddev(scale) of input values to avoid too big output logits which
@@ -390,10 +329,6 @@ class TestModelConversion(googletest.TestCase):
         )
     )

-  @googletest.skipIf(
-      ai_edge_torch.config.in_oss,
-      reason="tests with custom ops are not supported in oss",
-  )
   def test_stable_diffusion_decoder(self):
     config = sd_decoder.get_fake_model_config()
     # Reduce stddev(scale) of input values to avoid too big output logits which
@@ -81,7 +81,17 @@ def define_conversion_flags(model_name: str):
       'If set, the model will be converted with the provided list of LoRA'
       ' ranks.',
   )
-
+  flags.DEFINE_bool(
+      'mask_as_input',
+      False,
+      'If true, the mask will be passed in as input. Otherwise, mask will be '
+      'built by the model internally.',
+  )
+  flags.DEFINE_bool(
+      'transpose_kv_cache',
+      False,
+      'If true, the model will be converted with transposed KV cache.',
+  )
   return flags

@@ -14,8 +14,11 @@
 # ==============================================================================

 """Config for customizing model export process."""
+
 import dataclasses
 from typing import List, Optional
+
+from absl import flags
 from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import torch

@@ -38,3 +41,30 @@ class ExportConfig:
   kvcache_cls: type = kv_utils.KVCache
   # The batch size of the decode signature.
   decode_batch_size: int = 1
+
+
+def _build_mask(mask_len, kv_cache_max_len) -> torch.Tensor:
+  if isinstance(mask_len, list):
+    return [_build_mask(i, kv_cache_max_len) for i in mask_len]
+
+  mask = torch.full(
+      (mask_len, kv_cache_max_len), float('-inf'), dtype=torch.float32
+  )
+  mask = torch.triu(mask, diagonal=1).unsqueeze(0).unsqueeze(0)
+  return mask
+
+
+def get_from_flags() -> ExportConfig:
+  """Builds an export config according to the commandline flags."""
+  export_config = ExportConfig()
+
+  if flags.FLAGS.mask_as_input:
+    export_config.prefill_mask = _build_mask(
+        flags.FLAGS.prefill_seq_lens, flags.FLAGS.kv_cache_max_len
+    )
+    export_config.decode_mask = _build_mask(1, flags.FLAGS.kv_cache_max_len)
+
+  if flags.FLAGS.transpose_kv_cache:
+    export_config.kvcache_layout = kv_utils.KV_LAYOUT_TRANSPOSED
+
+  return export_config
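Note: the mask that _build_mask produces is an ordinary causal mask broadcastable against [B, N, T, KV]: -inf above the diagonal, 0 elsewhere, so softmax zeroes out attention to future cache positions. A standalone sketch using the same torch.full/torch.triu construction as the code above:

import torch


def build_causal_mask(mask_len: int, kv_cache_max_len: int) -> torch.Tensor:
  # Mirrors _build_mask for a single sequence length: -inf above the diagonal.
  mask = torch.full(
      (mask_len, kv_cache_max_len), float('-inf'), dtype=torch.float32
  )
  return torch.triu(mask, diagonal=1).unsqueeze(0).unsqueeze(0)


mask = build_causal_mask(4, 8)
print(mask.shape)  # torch.Size([1, 1, 4, 8])
print(mask[0, 0])  # row i has -inf only at columns j > i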
ai_edge_torch/model.py
CHANGED
@@ -22,6 +22,7 @@ from __future__ import annotations

 import abc
 import re
+import os
 from typing import Callable

 import numpy.typing as npt
@@ -154,6 +155,7 @@ class TfLiteModel(Model):
     Args:
       path: The path to file to which the model is serialized.
     """
+    os.makedirs(os.path.dirname(path), exist_ok=True)
     with open(path, 'wb') as file_handle:
       file_handle.write(self._tflite_model)
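Note: the one-line fix above creates the parent directory before writing the serialized model. A minimal standalone sketch of the same pattern, with an extra guard because os.path.dirname returns '' for a bare filename and os.makedirs('') raises:

import os


def export_bytes(payload: bytes, path: str) -> None:
  # Create the parent directory if the path has one.
  parent = os.path.dirname(path)
  if parent:
    os.makedirs(parent, exist_ok=True)
  with open(path, 'wb') as file_handle:
    file_handle.write(payload)


export_bytes(b'\x00\x01', '/tmp/ai_edge_demo/model.tflite')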
|
@@ -34,6 +34,8 @@ fx_infra.decomp.update_pre_lower_decomp(
     torch.ops.aten.replication_pad1d,
     torch.ops.aten.replication_pad2d,
     torch.ops.aten.replication_pad3d,
+    torch.ops.aten.upsample_bilinear2d.vec,
+    torch.ops.aten.upsample_nearest2d.vec,
     torch.ops.aten.addmm,
 ])
 )
ai_edge_torch/version.py
CHANGED
{ai_edge_torch_nightly-0.5.0.dev20250424.dist-info → ai_edge_torch_nightly-0.5.0.dev20250426.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: ai-edge-torch-nightly
-Version: 0.5.0.dev20250424
+Version: 0.5.0.dev20250426
 Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
 Home-page: https://github.com/google-ai-edge/ai-edge-torch
 Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI