ai-edge-torch-nightly 0.2.0.dev20240711__py3-none-any.whl → 0.2.0.dev20240712__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ai-edge-torch-nightly might be problematic. Click here for more details.
- ai_edge_torch/convert/conversion.py +2 -4
- ai_edge_torch/convert/conversion_utils.py +61 -3
- ai_edge_torch/convert/converter.py +47 -16
- ai_edge_torch/convert/test/test_convert.py +39 -0
- ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py +1 -1
- ai_edge_torch/model.py +11 -3
- ai_edge_torch/testing/model_coverage/model_coverage.py +19 -13
- {ai_edge_torch_nightly-0.2.0.dev20240711.dist-info → ai_edge_torch_nightly-0.2.0.dev20240712.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.2.0.dev20240711.dist-info → ai_edge_torch_nightly-0.2.0.dev20240712.dist-info}/RECORD +12 -12
- {ai_edge_torch_nightly-0.2.0.dev20240711.dist-info → ai_edge_torch_nightly-0.2.0.dev20240712.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.2.0.dev20240711.dist-info → ai_edge_torch_nightly-0.2.0.dev20240712.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.2.0.dev20240711.dist-info → ai_edge_torch_nightly-0.2.0.dev20240712.dist-info}/top_level.txt +0 -0
|
@@ -88,16 +88,14 @@ def convert_signatures(
|
|
|
88
88
|
_warn_training_modules(signatures)
|
|
89
89
|
|
|
90
90
|
exported_programs: torch.export.ExportedProgram = [
|
|
91
|
-
torch.export.export(
|
|
92
|
-
sig.module, sig.sample_args, dynamic_shapes=sig.dynamic_shapes
|
|
93
|
-
)
|
|
91
|
+
torch.export.export(sig.module, sig.flat_args, dynamic_shapes=sig.dynamic_shapes)
|
|
94
92
|
for sig in signatures
|
|
95
93
|
]
|
|
96
94
|
|
|
97
95
|
# Apply default fx passes
|
|
98
96
|
exported_programs = list(map(_run_convert_passes, exported_programs))
|
|
99
97
|
shlo_bundles: list[stablehlo.StableHLOModelBundle] = [
|
|
100
|
-
cutils.exported_program_to_stablehlo_bundle(exported, sig.
|
|
98
|
+
cutils.exported_program_to_stablehlo_bundle(exported, sig.flat_args)
|
|
101
99
|
for exported, sig in zip(exported_programs, signatures)
|
|
102
100
|
]
|
|
103
101
|
|
|
@@ -13,6 +13,7 @@
|
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ==============================================================================
|
|
15
15
|
|
|
16
|
+
import collections
|
|
16
17
|
import copy
|
|
17
18
|
from dataclasses import dataclass
|
|
18
19
|
import gc
|
|
@@ -22,6 +23,7 @@ import tempfile
|
|
|
22
23
|
from typing import Any, Dict, Optional, Tuple, Union
|
|
23
24
|
|
|
24
25
|
import torch
|
|
26
|
+
import torch.utils._pytree as pytree
|
|
25
27
|
from torch_xla import stablehlo
|
|
26
28
|
|
|
27
29
|
from ai_edge_torch.generative.quantize.ai_edge_quantizer_glue import translate_recipe # NOQA
|
|
@@ -47,8 +49,59 @@ class Signature:
|
|
|
47
49
|
name: str
|
|
48
50
|
module: torch.nn.Module
|
|
49
51
|
sample_args: tuple[torch.Tensor]
|
|
52
|
+
sample_kwargs: dict[str, torch.Tensor]
|
|
50
53
|
dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None
|
|
51
54
|
|
|
55
|
+
@property
|
|
56
|
+
def _normalized_sample_args_kwargs(self):
|
|
57
|
+
args, kwargs = self.sample_args, self.sample_kwargs
|
|
58
|
+
if args is not None:
|
|
59
|
+
if not isinstance(args, tuple):
|
|
60
|
+
# TODO(b/352584188): Check value types
|
|
61
|
+
raise ValueError("sample_args must be a tuple of torch tensors.")
|
|
62
|
+
if kwargs is not None:
|
|
63
|
+
if not isinstance(kwargs, dict) or not all(
|
|
64
|
+
isinstance(key, str) for key in kwargs.keys()
|
|
65
|
+
):
|
|
66
|
+
# TODO(b/352584188): Check value types
|
|
67
|
+
raise ValueError("sample_kwargs must be a dict of string to tensor.")
|
|
68
|
+
|
|
69
|
+
args = args if args is not None else tuple()
|
|
70
|
+
kwargs = kwargs if kwargs is not None else {}
|
|
71
|
+
return args, kwargs
|
|
72
|
+
|
|
73
|
+
@property
|
|
74
|
+
def flat_arg_names(self) -> list[str]:
|
|
75
|
+
spec = pytree.tree_flatten(self._normalized_sample_args_kwargs)[1]
|
|
76
|
+
args_spec, kwargs_spec = spec.children_specs
|
|
77
|
+
|
|
78
|
+
names = []
|
|
79
|
+
for i in range(args_spec.num_leaves):
|
|
80
|
+
names.append(f"args_{i}")
|
|
81
|
+
|
|
82
|
+
dict_context = (
|
|
83
|
+
kwargs_spec.context
|
|
84
|
+
if kwargs_spec.type is not collections.defaultdict
|
|
85
|
+
# ignore mismatch of `default_factory` for defaultdict
|
|
86
|
+
else kwargs_spec.context[1]
|
|
87
|
+
)
|
|
88
|
+
|
|
89
|
+
for name, value_spec in zip(dict_context, kwargs_spec.children_specs):
|
|
90
|
+
if value_spec.num_leaves == 1:
|
|
91
|
+
names.append(name)
|
|
92
|
+
else:
|
|
93
|
+
# value_spec.num_leaves may be greater than 1 when the value is a (nested)
|
|
94
|
+
# tuple of tensors. We haven't decided how we should support flattenable
|
|
95
|
+
# tensor containers as inputs.
|
|
96
|
+
# TODO(b/352584188): Decide the behavior of tensor container as input (flatten or reject)
|
|
97
|
+
for i in range(value_spec.num_leaves):
|
|
98
|
+
names.append(f"{name}_{i}")
|
|
99
|
+
return names
|
|
100
|
+
|
|
101
|
+
@property
|
|
102
|
+
def flat_args(self) -> tuple[torch.Tensor]:
|
|
103
|
+
return tuple(pytree.tree_flatten(self._normalized_sample_args_kwargs)[0])
|
|
104
|
+
|
|
52
105
|
|
|
53
106
|
def exported_program_to_stablehlo_bundle(
|
|
54
107
|
exported_program: torch.export.ExportedProgram, sample_args: tuple[torch.Tensor]
|
|
@@ -189,7 +242,9 @@ def _make_tf_function(
|
|
|
189
242
|
|
|
190
243
|
def _make_tf_signature(
|
|
191
244
|
meta: stablehlo.StableHLOFunctionMeta,
|
|
245
|
+
signature: Signature,
|
|
192
246
|
) -> list[tf.TensorSpec]:
|
|
247
|
+
input_names = signature.flat_arg_names
|
|
193
248
|
input_pos_to_spec = {
|
|
194
249
|
loc.position: spec
|
|
195
250
|
for loc, spec in itertools.chain(
|
|
@@ -197,9 +252,11 @@ def _make_tf_signature(
|
|
|
197
252
|
)
|
|
198
253
|
if loc.type_ == stablehlo.VariableType.INPUT_ARG
|
|
199
254
|
}
|
|
255
|
+
assert len(input_pos_to_spec) == len(input_names)
|
|
256
|
+
|
|
200
257
|
primitive_type_to_tf_type = {"int": "int32", "float": "float32"}
|
|
201
258
|
ret: list[tf.TensorSpec] = []
|
|
202
|
-
for i in
|
|
259
|
+
for i, name in enumerate(input_names):
|
|
203
260
|
spec = input_pos_to_spec[i]
|
|
204
261
|
shape = _get_shape_with_dynamic(spec)
|
|
205
262
|
ret.append(
|
|
@@ -208,7 +265,7 @@ def _make_tf_signature(
|
|
|
208
265
|
dtype=primitive_type_to_tf_type[spec.dtype]
|
|
209
266
|
if spec.dtype in primitive_type_to_tf_type
|
|
210
267
|
else spec.dtype,
|
|
211
|
-
name=
|
|
268
|
+
name=name,
|
|
212
269
|
)
|
|
213
270
|
)
|
|
214
271
|
return ret
|
|
@@ -276,7 +333,8 @@ def convert_stablehlo_to_tflite(
|
|
|
276
333
|
tf.Variable(v, trainable=False) for v in bundle.additional_constants
|
|
277
334
|
]
|
|
278
335
|
tf_signatures: list[list[tf.TensorSpec]] = list(
|
|
279
|
-
_make_tf_signature(func.meta)
|
|
336
|
+
_make_tf_signature(func.meta, sig)
|
|
337
|
+
for func, sig in zip(bundle.stablehlo_funcs, signatures)
|
|
280
338
|
)
|
|
281
339
|
|
|
282
340
|
tf_functions = _make_tf_function(shlo_graph_module, bundle)
|
|
@@ -34,17 +34,23 @@ class Converter:
|
|
|
34
34
|
self,
|
|
35
35
|
name: str,
|
|
36
36
|
module: torch.nn.Module,
|
|
37
|
-
sample_args
|
|
37
|
+
sample_args=None,
|
|
38
|
+
sample_kwargs=None,
|
|
39
|
+
*,
|
|
38
40
|
dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
|
|
39
41
|
) -> Converter:
|
|
40
42
|
"""Alias to `add_signature`"""
|
|
41
|
-
return self.add_signature(
|
|
43
|
+
return self.add_signature(
|
|
44
|
+
name, module, sample_args, sample_kwargs, dynamic_shapes=dynamic_shapes
|
|
45
|
+
)
|
|
42
46
|
|
|
43
47
|
def add_signature(
|
|
44
48
|
self,
|
|
45
49
|
name: str,
|
|
46
50
|
module: torch.nn.Module,
|
|
47
|
-
sample_args
|
|
51
|
+
sample_args=None,
|
|
52
|
+
sample_kwargs=None,
|
|
53
|
+
*,
|
|
48
54
|
dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
|
|
49
55
|
) -> Converter:
|
|
50
56
|
"""Allows adding a new named torch model along with sample args to the conversion.
|
|
@@ -52,7 +58,8 @@ class Converter:
|
|
|
52
58
|
Args:
|
|
53
59
|
name: The name of the signature included in the converted edge model.
|
|
54
60
|
module: The torch module to be converted.
|
|
55
|
-
sample_args: Tuple of
|
|
61
|
+
sample_args: Tuple of tensors by which the torch module will be traced with prior to conversion.
|
|
62
|
+
sample_kwargs: Dict of str to tensor by which the torch module will be traced with prior to conversion.
|
|
56
63
|
dynamic_shapes: Optional dict or tuple that specify dynamic shape specifications for each input in original order.
|
|
57
64
|
See https://pytorch.org/docs/stable/export.html#expressing-dynamism for more details.
|
|
58
65
|
|
|
@@ -63,13 +70,21 @@ class Converter:
|
|
|
63
70
|
if name in [sig.name for sig in self._signatures]:
|
|
64
71
|
raise ValueError(f"A signature with the provided name ({name}) is already added.")
|
|
65
72
|
|
|
66
|
-
|
|
73
|
+
if sample_args is None and sample_kwargs is None:
|
|
74
|
+
raise ValueError("sample_args or sample_kwargs must be provided.")
|
|
75
|
+
|
|
76
|
+
self._signatures.append(
|
|
77
|
+
cutils.Signature(
|
|
78
|
+
name, module, sample_args, sample_kwargs, dynamic_shapes=dynamic_shapes
|
|
79
|
+
)
|
|
80
|
+
)
|
|
67
81
|
return self
|
|
68
82
|
|
|
69
83
|
def convert(
|
|
70
84
|
self,
|
|
71
85
|
module: torch.nn.Module = None,
|
|
72
|
-
sample_args
|
|
86
|
+
sample_args=None,
|
|
87
|
+
sample_kwargs=None,
|
|
73
88
|
*,
|
|
74
89
|
quant_config: Optional[qcfg.QuantConfig] = None,
|
|
75
90
|
dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
|
|
@@ -88,7 +103,8 @@ class Converter:
|
|
|
88
103
|
Args:
|
|
89
104
|
name: The name of the signature included in the converted edge model.
|
|
90
105
|
module: The torch module to be converted.
|
|
91
|
-
sample_args: Tuple of
|
|
106
|
+
sample_args: Tuple of tensors by which the torch module will be traced with prior to conversion.
|
|
107
|
+
sample_kwargs: Dict of str to tensor by which the torch module will be traced with prior to conversion.
|
|
92
108
|
quant_config: User-defined quantization method and scheme of the model.
|
|
93
109
|
dynamic_shapes: Optional dict or tuple that specify dynamic shape specifications for each input in original order.
|
|
94
110
|
See https://pytorch.org/docs/stable/export.html#expressing-dynamism for more details.
|
|
@@ -100,12 +116,20 @@ class Converter:
|
|
|
100
116
|
ValueError: If the arguments are not provided as expected. See the example in this functions's comment.
|
|
101
117
|
"""
|
|
102
118
|
if module is not None:
|
|
103
|
-
if
|
|
119
|
+
if (
|
|
120
|
+
sample_args is not None or sample_kwargs is not None
|
|
121
|
+
): # both module and args provided
|
|
104
122
|
self.add_signature(
|
|
105
|
-
cutils.DEFAULT_SIGNATURE_NAME,
|
|
123
|
+
cutils.DEFAULT_SIGNATURE_NAME,
|
|
124
|
+
module,
|
|
125
|
+
sample_args,
|
|
126
|
+
sample_kwargs,
|
|
127
|
+
dynamic_shapes=dynamic_shapes,
|
|
128
|
+
)
|
|
129
|
+
else: # module is provided but not args
|
|
130
|
+
raise ValueError(
|
|
131
|
+
"sample_args or sample_kwargs must be provided if a module is specified."
|
|
106
132
|
)
|
|
107
|
-
else: # module is provided but not sample_args
|
|
108
|
-
raise ValueError("sample_args needs to be provided if a module is specified.")
|
|
109
133
|
|
|
110
134
|
return conversion.convert_signatures(
|
|
111
135
|
self._signatures,
|
|
@@ -117,7 +141,8 @@ class Converter:
|
|
|
117
141
|
def signature(
|
|
118
142
|
name: str,
|
|
119
143
|
module: torch.nn.Module,
|
|
120
|
-
sample_args
|
|
144
|
+
sample_args=None,
|
|
145
|
+
sample_kwargs=None,
|
|
121
146
|
dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
|
|
122
147
|
) -> Converter:
|
|
123
148
|
"""Initiates a Converter object with the provided signature.
|
|
@@ -125,7 +150,8 @@ def signature(
|
|
|
125
150
|
Args:
|
|
126
151
|
name: The name of the signature included in the converted edge model.
|
|
127
152
|
module: The torch module to be converted.
|
|
128
|
-
sample_args: Tuple of
|
|
153
|
+
sample_args: Tuple of tensors by which the torch module will be traced with prior to conversion.
|
|
154
|
+
sample_kwargs: Dict of str to tensor by which the torch module will be traced with prior to conversion.
|
|
129
155
|
dynamic_shapes: Optional dict or tuple that specify dynamic shape specifications for each input in original order.
|
|
130
156
|
See https://pytorch.org/docs/stable/export.html#expressing-dynamism for more details.
|
|
131
157
|
|
|
@@ -134,12 +160,15 @@ def signature(
|
|
|
134
160
|
edge_model = converter.convert()
|
|
135
161
|
|
|
136
162
|
"""
|
|
137
|
-
return Converter().signature(
|
|
163
|
+
return Converter().signature(
|
|
164
|
+
name, module, sample_args, sample_kwargs, dynamic_shapes=dynamic_shapes
|
|
165
|
+
)
|
|
138
166
|
|
|
139
167
|
|
|
140
168
|
def convert(
|
|
141
169
|
module: torch.nn.Module = None,
|
|
142
|
-
sample_args
|
|
170
|
+
sample_args=None,
|
|
171
|
+
sample_kwargs=None,
|
|
143
172
|
*,
|
|
144
173
|
quant_config: Optional[qcfg.QuantConfig] = None,
|
|
145
174
|
dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
|
|
@@ -149,7 +178,8 @@ def convert(
|
|
|
149
178
|
|
|
150
179
|
Args:
|
|
151
180
|
module: The torch module to be converted.
|
|
152
|
-
sample_args: Tuple of
|
|
181
|
+
sample_args: Tuple of tensors by which the torch module will be traced with prior to conversion.
|
|
182
|
+
sample_kwargs: Dict of str to tensor by which the torch module will be traced with prior to conversion.
|
|
153
183
|
quant_config: User-defined quantization method and scheme of the model.
|
|
154
184
|
dynamic_shapes: Optional dict or tuple that specify dynamic shape specifications for each input in original order.
|
|
155
185
|
See https://pytorch.org/docs/stable/export.html#expressing-dynamism for more details.
|
|
@@ -165,6 +195,7 @@ def convert(
|
|
|
165
195
|
return Converter().convert(
|
|
166
196
|
module,
|
|
167
197
|
sample_args,
|
|
198
|
+
sample_kwargs,
|
|
168
199
|
quant_config=quant_config,
|
|
169
200
|
dynamic_shapes=dynamic_shapes,
|
|
170
201
|
_ai_edge_converter_flags=_ai_edge_converter_flags,
|
|
@@ -267,6 +267,45 @@ class TestConvert(unittest.TestCase):
|
|
|
267
267
|
model_coverage.compare_tflite_torch(edge_model, model, validate_input)
|
|
268
268
|
)
|
|
269
269
|
|
|
270
|
+
def test_convert_model_with_kwargs(self):
|
|
271
|
+
"""
|
|
272
|
+
Test converting a simple model with sample_kwargs.
|
|
273
|
+
"""
|
|
274
|
+
|
|
275
|
+
class SampleModel(torch.nn.Module):
|
|
276
|
+
|
|
277
|
+
def forward(self, x, y):
|
|
278
|
+
return x + y
|
|
279
|
+
|
|
280
|
+
kwargs_gen = lambda: dict(x=torch.randn(10, 10), y=torch.randn(10, 10))
|
|
281
|
+
|
|
282
|
+
model = SampleModel().eval()
|
|
283
|
+
edge_model = ai_edge_torch.convert(model, sample_kwargs=kwargs_gen())
|
|
284
|
+
|
|
285
|
+
self.assertTrue(
|
|
286
|
+
model_coverage.compare_tflite_torch(edge_model, model, kwargs=kwargs_gen)
|
|
287
|
+
)
|
|
288
|
+
|
|
289
|
+
def test_convert_model_with_args_kwargs(self):
|
|
290
|
+
"""
|
|
291
|
+
Test converting a simple model with both sample_args and sample_kwargs.
|
|
292
|
+
"""
|
|
293
|
+
|
|
294
|
+
class SampleModel(torch.nn.Module):
|
|
295
|
+
|
|
296
|
+
def forward(self, x, y):
|
|
297
|
+
return x + y
|
|
298
|
+
|
|
299
|
+
args_gen = lambda: (torch.randn(10, 10),)
|
|
300
|
+
kwargs_gen = lambda: dict(y=torch.randn(10, 10))
|
|
301
|
+
|
|
302
|
+
model = SampleModel().eval()
|
|
303
|
+
edge_model = ai_edge_torch.convert(model, args_gen(), kwargs_gen())
|
|
304
|
+
|
|
305
|
+
self.assertTrue(
|
|
306
|
+
model_coverage.compare_tflite_torch(edge_model, model, args_gen, kwargs_gen)
|
|
307
|
+
)
|
|
308
|
+
|
|
270
309
|
|
|
271
310
|
if __name__ == "__main__":
|
|
272
311
|
unittest.main()
|
|
@@ -40,7 +40,7 @@ class RemoveSDPACompositeZeroMaskPass(ExportedProgramPassBase):
|
|
|
40
40
|
if self.is_zero_tensor_node(source):
|
|
41
41
|
# Remove the mark_tensor call on the mask input by
|
|
42
42
|
# replacing the target with an identity function.
|
|
43
|
-
node.target = lambda *args, **kwargs: args[0]
|
|
43
|
+
node.target = lambda *args, **kwargs: torch.zeros_like(args[0])
|
|
44
44
|
|
|
45
45
|
exported_program.graph_module.graph.lint()
|
|
46
46
|
exported_program.graph_module.recompile()
|
ai_edge_torch/model.py
CHANGED
|
@@ -33,7 +33,10 @@ class Model(abc.ABC):
|
|
|
33
33
|
|
|
34
34
|
@abc.abstractmethod
|
|
35
35
|
def __call__(
|
|
36
|
-
self,
|
|
36
|
+
self,
|
|
37
|
+
*args: npt.ArrayLike,
|
|
38
|
+
signature_name: str = cutils.DEFAULT_SIGNATURE_NAME,
|
|
39
|
+
**kwargs,
|
|
37
40
|
) -> npt.ArrayLike | tuple[npt.ArrayLike]:
|
|
38
41
|
raise NotImplementedError()
|
|
39
42
|
|
|
@@ -62,12 +65,16 @@ class TfLiteModel(Model):
|
|
|
62
65
|
self._tflite_model = tflite_model
|
|
63
66
|
|
|
64
67
|
def __call__(
|
|
65
|
-
self,
|
|
68
|
+
self,
|
|
69
|
+
*args: npt.ArrayLike,
|
|
70
|
+
signature_name: str = cutils.DEFAULT_SIGNATURE_NAME,
|
|
71
|
+
**kwargs,
|
|
66
72
|
) -> npt.ArrayLike | tuple[npt.ArrayLike]:
|
|
67
73
|
"""Runs inference on the edge model using the provided arguments.
|
|
68
74
|
|
|
69
75
|
Args:
|
|
70
76
|
*args: The arguments to be passed to the model for inference.
|
|
77
|
+
**kwargs: The arguments with specific names to be passed to the model for inference.
|
|
71
78
|
signature_name: The name of the signature to be used for inference.
|
|
72
79
|
The default signature is used if not provided.
|
|
73
80
|
"""
|
|
@@ -90,13 +97,14 @@ class TfLiteModel(Model):
|
|
|
90
97
|
else:
|
|
91
98
|
raise exception
|
|
92
99
|
|
|
93
|
-
if len(signature_list[signature_name]['inputs']) != len(args):
|
|
100
|
+
if len(signature_list[signature_name]['inputs']) != len(args) + len(kwargs):
|
|
94
101
|
raise ValueError(
|
|
95
102
|
f"The model requires {len(signature_list[signature_name]['inputs'])} arguments but {len(args)} was provided."
|
|
96
103
|
)
|
|
97
104
|
|
|
98
105
|
# Gather the input dictionary based on the signature.
|
|
99
106
|
inputs = {f'args_{idx}': args[idx] for idx in range(len(args))}
|
|
107
|
+
inputs = {**inputs, **kwargs}
|
|
100
108
|
outputs = runner(**inputs)
|
|
101
109
|
|
|
102
110
|
return (
|
|
@@ -60,7 +60,8 @@ def _torch_tensors_to_np(*argv):
|
|
|
60
60
|
def compare_tflite_torch(
|
|
61
61
|
edge_model: Model,
|
|
62
62
|
torch_eval_func: Callable,
|
|
63
|
-
|
|
63
|
+
args=None,
|
|
64
|
+
kwargs=None,
|
|
64
65
|
*,
|
|
65
66
|
num_valid_inputs: int = 1,
|
|
66
67
|
signature_name: str = None,
|
|
@@ -71,8 +72,9 @@ def compare_tflite_torch(
|
|
|
71
72
|
Args:
|
|
72
73
|
edge_model: Serialized ai_edge_torch.model.Model object.
|
|
73
74
|
torch_eval_func: Callable function to evaluate torch model.
|
|
74
|
-
|
|
75
|
+
args: torch.tensor array or a callable to generate a torch.tensor array
|
|
75
76
|
with random data, to pass into models during inference. (default None).
|
|
77
|
+
kwargs: dict of str to torch.tensor, or a callable to generate such.
|
|
76
78
|
num_valid_inputs: Defines the number of times the random inputs will be generated (if a callable is provided for input_data).
|
|
77
79
|
signature_name: If provided, specifies the name for the signature of the edge_model to run.
|
|
78
80
|
Calls the default signature if not provided.
|
|
@@ -86,29 +88,33 @@ def compare_tflite_torch(
|
|
|
86
88
|
# The supplied model_def.forward_args() will be executed num_valid_inputs
|
|
87
89
|
# times to generate num_valid_inputs random inputs.
|
|
88
90
|
torch_inputs = [
|
|
89
|
-
|
|
91
|
+
(
|
|
92
|
+
(args() if callable(args) else args) or tuple(),
|
|
93
|
+
(kwargs() if callable(kwargs) else kwargs) or {},
|
|
94
|
+
)
|
|
90
95
|
for _ in range(num_valid_inputs)
|
|
91
96
|
]
|
|
92
|
-
torch_outputs = [torch_eval_func(*
|
|
93
|
-
np_inputs = [
|
|
97
|
+
torch_outputs = [torch_eval_func(*args, **kwargs) for args, kwargs in torch_inputs]
|
|
98
|
+
np_inputs = [
|
|
99
|
+
(_torch_tensors_to_np(args), _torch_tensors_to_np(kwargs))
|
|
100
|
+
for args, kwargs in torch_inputs
|
|
101
|
+
]
|
|
94
102
|
np_outputs = [_torch_tensors_to_np(_flatten(ys)) for ys in torch_outputs]
|
|
95
103
|
|
|
96
104
|
# Define inline utility function used throughout the function.
|
|
97
105
|
def equal_fn(actual, expected):
|
|
98
106
|
return np.allclose(actual, expected, atol=atol, rtol=rtol)
|
|
99
107
|
|
|
100
|
-
def
|
|
108
|
+
def get_edge_output(inputs):
|
|
109
|
+
args, kwargs = inputs
|
|
101
110
|
if signature_name is None:
|
|
102
|
-
return _flatten(edge_model(*
|
|
111
|
+
return _flatten(edge_model(*args, **kwargs))
|
|
103
112
|
else:
|
|
104
|
-
return _flatten(edge_model(*
|
|
105
|
-
|
|
106
|
-
def get_expected_fn(input=None, idx=0):
|
|
107
|
-
return np_outputs[idx]
|
|
113
|
+
return _flatten(edge_model(*args, **kwargs, signature_name=signature_name))
|
|
108
114
|
|
|
109
115
|
for idx, np_input in enumerate(np_inputs):
|
|
110
|
-
output =
|
|
111
|
-
golden_output =
|
|
116
|
+
output = get_edge_output(np_input)
|
|
117
|
+
golden_output = np_outputs[idx]
|
|
112
118
|
|
|
113
119
|
is_output_len_eq = len(golden_output) == len(output)
|
|
114
120
|
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: ai-edge-torch-nightly
|
|
3
|
-
Version: 0.2.0.
|
|
3
|
+
Version: 0.2.0.dev20240712
|
|
4
4
|
Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
|
|
5
5
|
Home-page: https://github.com/google-ai-edge/ai-edge-torch
|
|
6
6
|
Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
|
|
@@ -1,9 +1,9 @@
|
|
|
1
1
|
ai_edge_torch/__init__.py,sha256=CNDboRP4zQBpz2hznNCQWcQCARvNXUm3DMa1Dw_XXFg,1067
|
|
2
|
-
ai_edge_torch/model.py,sha256=
|
|
2
|
+
ai_edge_torch/model.py,sha256=8Ba9ia7TCM_fciulw6qObmzdcxL3IaLQKDqpR7Lxp-Q,4440
|
|
3
3
|
ai_edge_torch/convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
|
4
|
-
ai_edge_torch/convert/conversion.py,sha256=
|
|
5
|
-
ai_edge_torch/convert/conversion_utils.py,sha256=
|
|
6
|
-
ai_edge_torch/convert/converter.py,sha256=
|
|
4
|
+
ai_edge_torch/convert/conversion.py,sha256=StJHglvx6cii36oi8sj-tZda009e9UqR6ufZOZkP1SY,4137
|
|
5
|
+
ai_edge_torch/convert/conversion_utils.py,sha256=PKXIlSCU-8DhppNBh9ICDNUlEOpV0HgCbt85jDVe3rA,13394
|
|
6
|
+
ai_edge_torch/convert/converter.py,sha256=hSrW6A-kix9cjdD6CuLL7rseWrLKoV6GRy-iUSW_nZc,7875
|
|
7
7
|
ai_edge_torch/convert/to_channel_last_io.py,sha256=zo5tY3yDhY_EPCkrL1XSXs2uRFS8B4_qu08dSjNsUGk,2778
|
|
8
8
|
ai_edge_torch/convert/fx_passes/__init__.py,sha256=EPs4PSIDLuRH5EBETi6deaOvaaf_Q4xD3_9NVcR7x8o,2810
|
|
9
9
|
ai_edge_torch/convert/fx_passes/_pass_base.py,sha256=ijVyDclPnd6a0DWWUJkwR4igj6f82S-cE1-83QGPvgw,1652
|
|
@@ -22,7 +22,7 @@ ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partition
|
|
|
22
22
|
ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py,sha256=8uHJbIwPMTgeSfYVba163pkXSQkHLxFwar_8A1AhgAM,2279
|
|
23
23
|
ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py,sha256=lklGxE1R32vsjFbhLLBDEFL4pfLi_iTgI9Ftb6Grezk,7156
|
|
24
24
|
ai_edge_torch/convert/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
|
25
|
-
ai_edge_torch/convert/test/test_convert.py,sha256=
|
|
25
|
+
ai_edge_torch/convert/test/test_convert.py,sha256=h0vOffr8saDQRkiXljNWDZ17EBjnS4xAtxd8DxETleY,9081
|
|
26
26
|
ai_edge_torch/convert/test/test_convert_composites.py,sha256=_Ojc-H6GOS5s8ek3_8eRBL_AiCs-k3srziPJ2R4Ulrg,7255
|
|
27
27
|
ai_edge_torch/convert/test/test_convert_multisig.py,sha256=kMaGnHe9ylfyU68qCifYcaGwJqyejKz--QQt9jS2oUA,4537
|
|
28
28
|
ai_edge_torch/convert/test/test_to_channel_last_io.py,sha256=I8c4ZG3v1vo0yxQYzLK_BTId4AOL9vadHGDtfCUZ4UI,2930
|
|
@@ -68,7 +68,7 @@ ai_edge_torch/generative/examples/tiny_llama/__init__.py,sha256=47DEQpj8HBSa-_TI
|
|
|
68
68
|
ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py,sha256=nT7Fh-f5ZdwaK3dPoCvZflpJ4fRHjLdFMjk1_uw3-b8,2559
|
|
69
69
|
ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py,sha256=IFRLPG9wz_aLl_zV_6CETCjSM03ukA6bZqqyDLVACuw,5651
|
|
70
70
|
ai_edge_torch/generative/fx_passes/__init__.py,sha256=aXvYiaHDvETIrh0Q9DDZA_ZBiazGk80DT6nt7lLtC1o,1172
|
|
71
|
-
ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py,sha256=
|
|
71
|
+
ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py,sha256=BCAcc_OcEjvbaXQSbc8vlKeMad7E3gCA4BNsUdWRwBI,1966
|
|
72
72
|
ai_edge_torch/generative/layers/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
|
73
73
|
ai_edge_torch/generative/layers/attention.py,sha256=AW0Qo3uOIe6p1rJNJ6zR_r4fqL2y-6QJHh0yUd-5Yb0,11966
|
|
74
74
|
ai_edge_torch/generative/layers/attention_utils.py,sha256=hXhuyKblPPxKIRzlAf1YNlwHgpbj-6nReRLhRHELx5k,6350
|
|
@@ -113,9 +113,9 @@ ai_edge_torch/quantize/pt2e_quantizer_utils.py,sha256=yjzKoptnfEeW_sN7sODUfj3nCt
|
|
|
113
113
|
ai_edge_torch/quantize/quant_config.py,sha256=eO9Ra160ITjQSyRBEGy6nNIVH3gYacSWDdN5XtvHwjc,3148
|
|
114
114
|
ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
|
115
115
|
ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
|
|
116
|
-
ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=
|
|
117
|
-
ai_edge_torch_nightly-0.2.0.
|
|
118
|
-
ai_edge_torch_nightly-0.2.0.
|
|
119
|
-
ai_edge_torch_nightly-0.2.0.
|
|
120
|
-
ai_edge_torch_nightly-0.2.0.
|
|
121
|
-
ai_edge_torch_nightly-0.2.0.
|
|
116
|
+
ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=kzIulTldq8R9E-lAZsvfSTvLu3FYEX7b9DyYM3qisXM,4485
|
|
117
|
+
ai_edge_torch_nightly-0.2.0.dev20240712.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
|
|
118
|
+
ai_edge_torch_nightly-0.2.0.dev20240712.dist-info/METADATA,sha256=BGHmRLYo3ko7KRysDP59YexTpPn45jrtpRwqQkPAM5s,1745
|
|
119
|
+
ai_edge_torch_nightly-0.2.0.dev20240712.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
|
|
120
|
+
ai_edge_torch_nightly-0.2.0.dev20240712.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
|
|
121
|
+
ai_edge_torch_nightly-0.2.0.dev20240712.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|