ai-edge-torch-nightly 0.2.0.dev20240714__py3-none-any.whl → 0.3.0.dev20240926__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- ai_edge_torch/__init__.py +5 -4
- ai_edge_torch/_convert/conversion.py +112 -0
- ai_edge_torch/_convert/conversion_utils.py +64 -0
- ai_edge_torch/{convert → _convert}/converter.py +94 -48
- ai_edge_torch/_convert/fx_passes/__init__.py +22 -0
- ai_edge_torch/{convert → _convert}/fx_passes/build_aten_composite_pass.py +107 -44
- ai_edge_torch/{convert → _convert}/fx_passes/build_interpolate_composite_pass.py +23 -20
- ai_edge_torch/{convert → _convert}/fx_passes/inject_mlir_debuginfo_pass.py +5 -6
- ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/__init__.py +1 -1
- ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_check.py +39 -9
- ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_mark.py +2 -0
- ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_partitioners/__init__.py +1 -0
- ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py +17 -8
- ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py +9 -8
- ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +31 -18
- ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/op_func_registry.py +2 -2
- ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/pass_body.py +34 -24
- ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/utils.py +2 -0
- ai_edge_torch/_convert/signature.py +66 -0
- ai_edge_torch/_convert/test/test_convert.py +495 -0
- ai_edge_torch/_convert/test/test_convert_composites.py +234 -0
- ai_edge_torch/_convert/test/test_convert_multisig.py +189 -0
- ai_edge_torch/{convert → _convert}/test/test_to_channel_last_io.py +5 -5
- ai_edge_torch/{convert → _convert}/to_channel_last_io.py +10 -3
- ai_edge_torch/config.py +27 -0
- ai_edge_torch/conftest.py +20 -0
- ai_edge_torch/debug/culprit.py +72 -40
- ai_edge_torch/debug/test/test_culprit.py +7 -5
- ai_edge_torch/debug/test/test_search_model.py +8 -7
- ai_edge_torch/debug/utils.py +14 -3
- ai_edge_torch/fx_pass_base.py +101 -0
- ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py +68 -0
- ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +68 -0
- ai_edge_torch/generative/examples/gemma/{gemma.py → gemma1.py} +69 -55
- ai_edge_torch/generative/examples/gemma/gemma2.py +267 -0
- ai_edge_torch/generative/examples/gemma/verify_gemma1.py +56 -0
- ai_edge_torch/generative/examples/gemma/verify_gemma2.py +57 -0
- ai_edge_torch/generative/examples/gemma/verify_util.py +143 -0
- ai_edge_torch/generative/examples/openelm/convert_to_tflite.py +68 -0
- ai_edge_torch/generative/examples/openelm/openelm.py +206 -0
- ai_edge_torch/generative/examples/openelm/verify.py +64 -0
- ai_edge_torch/generative/examples/phi/__init__.py +14 -0
- ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py +68 -0
- ai_edge_torch/generative/examples/phi/convert_to_tflite.py +68 -0
- ai_edge_torch/generative/examples/{phi2 → phi}/phi2.py +70 -51
- ai_edge_torch/generative/examples/phi/phi3.py +286 -0
- ai_edge_torch/generative/examples/phi/verify.py +65 -0
- ai_edge_torch/generative/examples/phi/verify_phi3.py +70 -0
- ai_edge_torch/generative/examples/smollm/__init__.py +14 -0
- ai_edge_torch/generative/examples/smollm/convert_to_tflite.py +68 -0
- ai_edge_torch/generative/examples/smollm/smollm.py +101 -0
- ai_edge_torch/generative/examples/smollm/verify.py +62 -0
- ai_edge_torch/generative/examples/stable_diffusion/attention.py +3 -1
- ai_edge_torch/generative/examples/stable_diffusion/clip.py +83 -13
- ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +27 -14
- ai_edge_torch/generative/examples/stable_diffusion/decoder.py +74 -9
- ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +179 -37
- ai_edge_torch/generative/examples/stable_diffusion/encoder.py +4 -3
- ai_edge_torch/generative/examples/stable_diffusion/pipeline.py +83 -58
- ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler.py +4 -3
- ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler_ancestral.py +4 -3
- ai_edge_torch/generative/examples/stable_diffusion/samplers/k_lms.py +4 -3
- ai_edge_torch/generative/examples/stable_diffusion/samplers/sampler.py +1 -0
- ai_edge_torch/generative/examples/stable_diffusion/tokenizer.py +4 -1
- ai_edge_torch/generative/examples/stable_diffusion/util.py +9 -3
- ai_edge_torch/generative/examples/t5/convert_to_tflite.py +28 -25
- ai_edge_torch/generative/examples/t5/t5.py +208 -159
- ai_edge_torch/generative/examples/t5/t5_attention.py +45 -30
- ai_edge_torch/generative/examples/test_models/convert_toy_model.py +105 -0
- ai_edge_torch/generative/examples/test_models/toy_model.py +69 -41
- ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +50 -64
- ai_edge_torch/generative/examples/tiny_llama/__init__.py +14 -0
- ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +41 -39
- ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +67 -54
- ai_edge_torch/generative/examples/tiny_llama/verify.py +64 -0
- ai_edge_torch/generative/fx_passes/__init__.py +4 -5
- ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py +10 -7
- ai_edge_torch/generative/layers/attention.py +141 -102
- ai_edge_torch/generative/layers/attention_utils.py +53 -12
- ai_edge_torch/generative/layers/builder.py +37 -7
- ai_edge_torch/generative/layers/feed_forward.py +39 -14
- ai_edge_torch/generative/layers/kv_cache.py +162 -50
- ai_edge_torch/generative/layers/model_config.py +84 -30
- ai_edge_torch/generative/layers/normalization.py +185 -7
- ai_edge_torch/generative/layers/rotary_position_embedding.py +6 -4
- ai_edge_torch/generative/layers/scaled_dot_product_attention.py +48 -21
- ai_edge_torch/generative/layers/unet/blocks_2d.py +136 -77
- ai_edge_torch/generative/layers/unet/builder.py +7 -4
- ai_edge_torch/generative/layers/unet/model_config.py +17 -15
- ai_edge_torch/generative/quantize/example.py +7 -8
- ai_edge_torch/generative/quantize/quant_recipe.py +10 -7
- ai_edge_torch/generative/quantize/quant_recipe_utils.py +12 -1
- ai_edge_torch/generative/quantize/quant_recipes.py +8 -0
- ai_edge_torch/generative/test/test_kv_cache.py +120 -0
- ai_edge_torch/generative/test/{loader_test.py → test_loader.py} +9 -7
- ai_edge_torch/generative/test/test_model_conversion.py +124 -188
- ai_edge_torch/generative/test/test_model_conversion_large.py +251 -0
- ai_edge_torch/generative/test/test_quantize.py +76 -60
- ai_edge_torch/generative/test/utils.py +54 -0
- ai_edge_torch/generative/utilities/converter.py +82 -0
- ai_edge_torch/generative/utilities/loader.py +120 -57
- ai_edge_torch/generative/utilities/stable_diffusion_loader.py +165 -57
- ai_edge_torch/generative/utilities/t5_loader.py +110 -81
- ai_edge_torch/generative/utilities/verifier.py +247 -0
- ai_edge_torch/hlfb/__init__.py +1 -1
- ai_edge_torch/hlfb/mark_pattern/__init__.py +9 -7
- ai_edge_torch/hlfb/mark_pattern/passes.py +23 -3
- ai_edge_torch/hlfb/mark_pattern/pattern.py +39 -30
- ai_edge_torch/hlfb/test/test_mark_pattern.py +46 -20
- ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py +24 -11
- ai_edge_torch/lowertools/__init__.py +18 -0
- ai_edge_torch/lowertools/_shim.py +80 -0
- ai_edge_torch/lowertools/common_utils.py +142 -0
- ai_edge_torch/lowertools/odml_torch_utils.py +255 -0
- ai_edge_torch/lowertools/test_utils.py +60 -0
- ai_edge_torch/lowertools/torch_xla_utils.py +284 -0
- ai_edge_torch/{generative/quantize/ai_edge_quantizer_glue → lowertools}/translate_recipe.py +29 -14
- ai_edge_torch/model.py +53 -18
- ai_edge_torch/odml_torch/__init__.py +20 -0
- ai_edge_torch/odml_torch/_torch_future.py +61 -0
- ai_edge_torch/odml_torch/_torch_library.py +19 -0
- ai_edge_torch/odml_torch/composite/__init__.py +16 -0
- ai_edge_torch/odml_torch/composite/mark_tensor.py +120 -0
- ai_edge_torch/odml_torch/composite/stablehlo_composite_builder.py +106 -0
- ai_edge_torch/odml_torch/debuginfo/__init__.py +16 -0
- ai_edge_torch/odml_torch/debuginfo/_build.py +43 -0
- ai_edge_torch/odml_torch/debuginfo/_op_polyfill.py +55 -0
- ai_edge_torch/odml_torch/export.py +357 -0
- ai_edge_torch/odml_torch/export_utils.py +168 -0
- ai_edge_torch/odml_torch/jax_bridge/__init__.py +15 -0
- ai_edge_torch/odml_torch/jax_bridge/_wrap.py +150 -0
- ai_edge_torch/odml_torch/jax_bridge/utils.py +75 -0
- ai_edge_torch/odml_torch/lowerings/__init__.py +25 -0
- ai_edge_torch/odml_torch/lowerings/_basic.py +258 -0
- ai_edge_torch/odml_torch/lowerings/_batch_norm.py +65 -0
- ai_edge_torch/odml_torch/lowerings/_convolution.py +241 -0
- ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +252 -0
- ai_edge_torch/odml_torch/lowerings/_layer_norm.py +78 -0
- ai_edge_torch/odml_torch/lowerings/context.py +42 -0
- ai_edge_torch/odml_torch/lowerings/registry.py +96 -0
- ai_edge_torch/odml_torch/lowerings/utils.py +185 -0
- ai_edge_torch/odml_torch/passes/__init__.py +38 -0
- ai_edge_torch/odml_torch/tf_integration.py +194 -0
- ai_edge_torch/quantize/pt2e_quantizer.py +52 -24
- ai_edge_torch/quantize/pt2e_quantizer_utils.py +43 -23
- ai_edge_torch/quantize/quant_config.py +13 -9
- ai_edge_torch/testing/model_coverage/model_coverage.py +29 -16
- ai_edge_torch/version.py +16 -0
- {ai_edge_torch_nightly-0.2.0.dev20240714.dist-info → ai_edge_torch_nightly-0.3.0.dev20240926.dist-info}/METADATA +7 -3
- ai_edge_torch_nightly-0.3.0.dev20240926.dist-info/RECORD +177 -0
- {ai_edge_torch_nightly-0.2.0.dev20240714.dist-info → ai_edge_torch_nightly-0.3.0.dev20240926.dist-info}/WHEEL +1 -1
- ai_edge_torch/convert/conversion.py +0 -117
- ai_edge_torch/convert/conversion_utils.py +0 -400
- ai_edge_torch/convert/fx_passes/__init__.py +0 -59
- ai_edge_torch/convert/fx_passes/_pass_base.py +0 -49
- ai_edge_torch/convert/fx_passes/canonicalize_pass.py +0 -37
- ai_edge_torch/convert/test/test_convert.py +0 -311
- ai_edge_torch/convert/test/test_convert_composites.py +0 -192
- ai_edge_torch/convert/test/test_convert_multisig.py +0 -139
- ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +0 -66
- ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +0 -64
- ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +0 -161
- ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/__init__.py +0 -0
- ai_edge_torch_nightly-0.2.0.dev20240714.dist-info/RECORD +0 -121
- /ai_edge_torch/{convert → _convert}/__init__.py +0 -0
- /ai_edge_torch/{convert → _convert}/test/__init__.py +0 -0
- /ai_edge_torch/generative/examples/{phi2 → openelm}/__init__.py +0 -0
- {ai_edge_torch_nightly-0.2.0.dev20240714.dist-info → ai_edge_torch_nightly-0.3.0.dev20240926.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.2.0.dev20240714.dist-info → ai_edge_torch_nightly-0.3.0.dev20240926.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,194 @@
|
|
1
|
+
# Copyright 2024 The AI Edge Torch Authors.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
"""APIs to convert lowered MLIR from PyTorch to TensorFlow and TFLite artifacts."""
|
16
|
+
|
17
|
+
import re
|
18
|
+
import tempfile
|
19
|
+
|
20
|
+
import tensorflow as tf
|
21
|
+
import torch
|
22
|
+
|
23
|
+
from tensorflow.compiler.tf2xla.python import xla as tfxla
|
24
|
+
|
25
|
+
from . import export
|
26
|
+
from . import export_utils
|
27
|
+
|
28
|
+
|
29
|
+
def torch_dtype_to_tf(dtype):
  """Maps a torch dtype to the equivalent TensorFlow dtype.

  Args:
    dtype: a `torch.dtype`.

  Returns:
    The matching `tf.DType`, or None if `dtype` has no mapping here. The
    result is used directly as a tf.TensorSpec dtype and as a call_module
    output dtype, so unmapped dtypes surface as errors there.
  """
  return {
      torch.double: tf.float64,
      torch.float32: tf.float32,
      torch.half: tf.float16,
      # bfloat16 is common for LLM weights/activations.
      torch.bfloat16: tf.bfloat16,
      torch.long: tf.int64,
      torch.int32: tf.int32,
      torch.int16: tf.int16,
      # 8-bit integer types appear in quantized models.
      torch.int8: tf.int8,
      torch.uint8: tf.uint8,
      torch.bool: tf.bool,
      torch.complex64: tf.complex64,
      torch.complex128: tf.complex128,
  }.get(dtype)
|
39
|
+
|
40
|
+
|
41
|
+
def _get_shape_with_dynamic(signature: export.VariableSignature):
  """Returns the signature's shape as a list with dynamic dims set to None.

  TensorFlow spells an unknown dimension as None, so every dimension flagged
  by `export_utils.is_torch_dynamic` is replaced by None; static dimensions
  are kept as-is.
  """

  def _as_tf_dim(dim):
    return None if export_utils.is_torch_dynamic(dim) else dim

  return [_as_tf_dim(dim) for dim in signature.shape]
|
45
|
+
|
46
|
+
|
47
|
+
def _mangle_tf_root_scope_name(name):
|
48
|
+
r"""Build the mangled name for tf.Variable.
|
49
|
+
|
50
|
+
TF has more restricted constrain on the variable names at root scope. Root
|
51
|
+
scope name constrain: [A-Za-z0-9.][A-Za-z0-9_.\\-/]* Non-root scope name
|
52
|
+
constrain: [A-Za-z0-9_.\\-/]*
|
53
|
+
https://github.com/tensorflow/tensorflow/blob/51b601fa6bb7e801c0b6ae73c25580e40a8b5745/tensorflow/python/framework/ops.py#L3301-L3302
|
54
|
+
The state_dict key doesn't have such constrain, the name need to be mangled
|
55
|
+
when a root-scoped TF variable is created.
|
56
|
+
|
57
|
+
FX Graph Node may contain characters other than [A-Za-z0-9_.\\-/], replace
|
58
|
+
offending characters with '_'.
|
59
|
+
|
60
|
+
Args:
|
61
|
+
name: the tensor name to be mangled.
|
62
|
+
|
63
|
+
Returns:
|
64
|
+
Mangled name in str.
|
65
|
+
"""
|
66
|
+
if name[0] in "._\\-/":
|
67
|
+
name = "k" + name
|
68
|
+
name = re.sub(r"[^^\w\-/\\]+", "_", name)
|
69
|
+
return name
|
70
|
+
|
71
|
+
|
72
|
+
def _build_tf_state_dict(
    lowered: export.MlirLowered,
) -> dict[str, tf.Variable]:
  """Build a dictionary of tf.Variable from the state_dict in lowered.

  One non-trainable tf.Variable is created per entry of
  `lowered.input_signature` whose input spec is a parameter; its value is
  copied from `lowered.state_dict`.

  Args:
    lowered: the lowered MLIR module whose `state_dict` holds the parameter
      tensors.

  Returns:
    A dict keyed by the original state_dict name (the TF variable itself is
    given a mangled, TF-legal name).
  """
  tf_state_dict = {}
  for sig in lowered.input_signature:
    if not sig.input_spec.is_parameter:
      continue
    name = sig.input_spec.name
    # Tensor.numpy() only supports CPU tensors; .cpu() is a no-op for tensors
    # already on the CPU and avoids a TypeError for accelerator tensors.
    value = lowered.state_dict[name].detach().cpu().numpy()
    tf_state_dict[name] = tf.Variable(
        value,
        trainable=False,
        # TF restricts root-scope variable names; the dict key keeps the
        # original name so lookups by state_dict name still work.
        name=_mangle_tf_root_scope_name(name),
    )
  return tf_state_dict
|
86
|
+
|
87
|
+
|
88
|
+
def _extract_call_args(
    lowered: export.MlirLowered,
    args,
    tf_state_dict: dict[str, tf.Variable],
):
  """Lays out the flat argument list expected by the lowered module.

  Walks `lowered.input_signature` in order. User inputs are taken from `args`
  by their input index, and parameters from `tf_state_dict` by name. Entries
  that are neither are skipped.
  """

  def _resolve(spec):
    if spec.is_user_input:
      return args[spec.i]
    return tf_state_dict[spec.name]

  return [
      _resolve(sig.input_spec)
      for sig in lowered.input_signature
      if sig.input_spec.is_user_input or sig.input_spec.is_parameter
  ]
|
102
|
+
|
103
|
+
|
104
|
+
def _wrap_as_tf_func(lowered, tf_state_dict):
  """Returns a Python callable that executes the lowered module.

  The callable takes the user inputs positionally, splices in the variables
  from `tf_state_dict`, and runs the serialized module bytecode through
  `tfxla.call_module`. Callers wrap it in `tf.function`.
  """

  def inner(*args):
    out_sigs = lowered.output_signature
    out_dtypes = [torch_dtype_to_tf(s.dtype) for s in out_sigs]
    out_shapes = [_get_shape_with_dynamic(s) for s in out_sigs]
    flat_inputs = tuple(_extract_call_args(lowered, args, tf_state_dict))
    return tfxla.call_module(
        flat_inputs,
        version=5,
        Tout=out_dtypes,  # one dtype per output
        Sout=out_shapes,  # one shape per output; None marks a dynamic dim
        function_list=[],
        module=lowered.module_bytecode,
    )

  return inner
|
121
|
+
|
122
|
+
|
123
|
+
def _make_input_signatures(
    lowered: export.MlirLowered,
) -> list[tf.TensorSpec]:
  """Builds one tf.TensorSpec per user input, ordered by input index.

  Only user inputs are included; parameters are not. Dynamic dimensions
  become None, and each spec is named `args_<i>` after its user-input index.
  """
  user_sigs = [
      sig for sig in lowered.input_signature if sig.input_spec.is_user_input
  ]
  user_sigs.sort(key=lambda sig: sig.input_spec.i)
  return [
      tf.TensorSpec(
          shape=_get_shape_with_dynamic(sig),
          dtype=torch_dtype_to_tf(sig.dtype),
          name=f"args_{sig.input_spec.i}",
      )
      for sig in user_sigs
  ]
|
143
|
+
|
144
|
+
|
145
|
+
def mlir_to_tf_function(lowered: export.MlirLowered):
  """Converts the lowered MLIR module into an executable tf.function.

  Parameters are captured as non-trainable tf.Variables. The returned
  function takes only the user inputs, with an input signature derived from
  the lowered module.
  """
  variables = _build_tf_state_dict(lowered)
  body = _wrap_as_tf_func(lowered, variables)
  specs = _make_input_signatures(lowered)
  return tf.function(body, input_signature=specs)
|
152
|
+
|
153
|
+
|
154
|
+
def mlir_to_flatbuffer(lowered: export.MlirLowered):
  """Converts the lowered MLIR module into a TFLite flatbuffer binary.

  The module is wrapped in a tf.Module that holds the parameters as
  variables. It is exported as a SavedModel under the default serving
  signature into a temporary directory, then converted with
  `tf.lite.TFLiteConverter.from_saved_model`.

  Returns:
    The serialized model returned by `TFLiteConverter.convert()`.
  """
  state_vars = _build_tf_state_dict(lowered)

  # One (signature name, input specs, python callable) triple per exported
  # entry point; currently only the default serving signature.
  entries = [(
      tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY,
      _make_input_signatures(lowered),
      _wrap_as_tf_func(lowered, state_vars),
  )]

  module = tf.Module()
  module.f = []
  for _, specs, fn in entries:
    module.f.append(tf.function(fn, input_signature=specs))
  module._variables = list(state_vars.values())

  concrete_by_name = {}
  for (sig_name, specs, _), traced in zip(entries, module.f):
    concrete_by_name[sig_name] = traced.get_concrete_function(*specs)

  # TFLite's from_concrete_functions cannot attach a name to each concrete
  # function, so round-trip through a SavedModel whose signatures carry them.
  with tempfile.TemporaryDirectory() as saved_model_dir:
    tf.saved_model.save(module, saved_model_dir, signatures=concrete_by_name)
    return tf.lite.TFLiteConverter.from_saved_model(saved_model_dir).convert()
|
@@ -19,6 +19,12 @@ import copy
|
|
19
19
|
import functools
|
20
20
|
from typing import Any, Callable, Dict, List, Optional, Set
|
21
21
|
|
22
|
+
from ai_edge_torch.quantize.pt2e_quantizer_utils import _convert_scalars_to_attrs # NOQA
|
23
|
+
from ai_edge_torch.quantize.pt2e_quantizer_utils import OP_TO_ANNOTATOR
|
24
|
+
from ai_edge_torch.quantize.pt2e_quantizer_utils import OperatorConfig
|
25
|
+
from ai_edge_torch.quantize.pt2e_quantizer_utils import OperatorPatternType
|
26
|
+
from ai_edge_torch.quantize.pt2e_quantizer_utils import propagate_annotation
|
27
|
+
from ai_edge_torch.quantize.pt2e_quantizer_utils import QuantizationConfig
|
22
28
|
import torch
|
23
29
|
from torch.ao.quantization.fake_quantize import FusedMovingAvgObsFakeQuantize
|
24
30
|
from torch.ao.quantization.observer import HistogramObserver
|
@@ -34,20 +40,15 @@ from torch.ao.quantization.quantizer import Quantizer
|
|
34
40
|
from torch.fx import Node
|
35
41
|
import torch.nn.functional as F
|
36
42
|
|
37
|
-
from ai_edge_torch.quantize.pt2e_quantizer_utils import _convert_scalars_to_attrs # NOQA
|
38
|
-
from ai_edge_torch.quantize.pt2e_quantizer_utils import OP_TO_ANNOTATOR
|
39
|
-
from ai_edge_torch.quantize.pt2e_quantizer_utils import OperatorConfig
|
40
|
-
from ai_edge_torch.quantize.pt2e_quantizer_utils import OperatorPatternType
|
41
|
-
from ai_edge_torch.quantize.pt2e_quantizer_utils import propagate_annotation
|
42
|
-
from ai_edge_torch.quantize.pt2e_quantizer_utils import QuantizationConfig
|
43
|
-
|
44
43
|
__all__ = [
|
45
44
|
"PT2EQuantizer",
|
46
45
|
"get_symmetric_quantization_config",
|
47
46
|
]
|
48
47
|
|
49
48
|
|
50
|
-
def _supported_symmetric_quantized_operators() ->
|
49
|
+
def _supported_symmetric_quantized_operators() -> (
|
50
|
+
Dict[str, List[OperatorPatternType]]
|
51
|
+
):
|
51
52
|
supported_operators: Dict[str, List[OperatorPatternType]] = {
|
52
53
|
# Both conv and linear should be able to handle relu + hardtanh fusion since
|
53
54
|
# those are clamp ops
|
@@ -92,7 +93,9 @@ def get_symmetric_quantization_config(
|
|
92
93
|
):
|
93
94
|
if is_qat:
|
94
95
|
if is_dynamic:
|
95
|
-
raise NotImplementedError(
|
96
|
+
raise NotImplementedError(
|
97
|
+
"dynamic quantization for qat is not yet implemented."
|
98
|
+
)
|
96
99
|
act_observer_or_fake_quant_ctr = FusedMovingAvgObsFakeQuantize
|
97
100
|
else:
|
98
101
|
if is_dynamic:
|
@@ -106,12 +109,18 @@ def get_symmetric_quantization_config(
|
|
106
109
|
quant_max=127,
|
107
110
|
qscheme=torch.per_tensor_affine,
|
108
111
|
is_dynamic=is_dynamic,
|
109
|
-
observer_or_fake_quant_ctr=act_observer_or_fake_quant_ctr.with_args(
|
112
|
+
observer_or_fake_quant_ctr=act_observer_or_fake_quant_ctr.with_args(
|
113
|
+
eps=2**-12
|
114
|
+
),
|
110
115
|
)
|
111
116
|
qscheme = (
|
112
|
-
torch.per_channel_symmetric
|
117
|
+
torch.per_channel_symmetric
|
118
|
+
if is_per_channel
|
119
|
+
else torch.per_tensor_symmetric
|
120
|
+
)
|
121
|
+
weight_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = (
|
122
|
+
MinMaxObserver
|
113
123
|
)
|
114
|
-
weight_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = MinMaxObserver
|
115
124
|
if is_qat:
|
116
125
|
weight_observer_or_fake_quant_ctr = FusedMovingAvgObsFakeQuantize
|
117
126
|
elif is_per_channel:
|
@@ -179,15 +188,18 @@ def _get_supported_config_and_operators() -> List[OperatorConfig]:
|
|
179
188
|
|
180
189
|
def _get_module_name_filter(module_name: str):
|
181
190
|
"""Get the module_name_filter function for a given module name, the filter accepts
|
191
|
+
|
182
192
|
a node and checks if the node comes from a module that has certain module name
|
183
193
|
|
184
194
|
For example:
|
185
|
-
node: linear_op = call_function[...](...) # comes from a module with name
|
195
|
+
node: linear_op = call_function[...](...) # comes from a module with name
|
196
|
+
blocks.sub.linear1
|
186
197
|
|
187
198
|
|
188
199
|
>> module_name_filter = _get_module_name_filter("blocks.sub")
|
189
200
|
>> print(module_name_filter(node))
|
190
|
-
True # the node is from "blocks.sub" based on the fully qualified name
|
201
|
+
True # the node is from "blocks.sub" based on the fully qualified name
|
202
|
+
"blocks.sub.linear1"
|
191
203
|
"""
|
192
204
|
|
193
205
|
def module_name_filter(n: Node) -> bool:
|
@@ -197,7 +209,9 @@ def _get_module_name_filter(module_name: str):
|
|
197
209
|
# }
|
198
210
|
# get_attr nodes doesn't have nn_module_stack?
|
199
211
|
nn_module_stack = n.meta.get("nn_module_stack", {})
|
200
|
-
names = [
|
212
|
+
names = [
|
213
|
+
n[len("L__self___") :].replace("_", ".") for n in nn_module_stack.keys()
|
214
|
+
]
|
201
215
|
return module_name in names
|
202
216
|
|
203
217
|
return module_name_filter
|
@@ -205,15 +219,19 @@ def _get_module_name_filter(module_name: str):
|
|
205
219
|
|
206
220
|
def _get_module_type_filter(tp: Callable):
|
207
221
|
"""Get the module_type_filter function for a given module type, the filter accepts
|
222
|
+
|
208
223
|
a node and checks if the node comes from a module that has certain module type
|
209
224
|
|
210
225
|
For example:
|
211
|
-
node: linear_op = call_function[...](...) # comes from a module with type
|
226
|
+
node: linear_op = call_function[...](...) # comes from a module with type
|
227
|
+
Block -> Sub -> Linear
|
212
228
|
|
213
229
|
|
214
|
-
>> module_type_filter = _get_module_type_filter(Sub) # submodule with type
|
230
|
+
>> module_type_filter = _get_module_type_filter(Sub) # submodule with type
|
231
|
+
`Sub`, under the `Block` submodule
|
215
232
|
>> print(module_type_filter(node))
|
216
|
-
True # the node is from the submodule `Sub` (same for `Block` and `Linear` as
|
233
|
+
True # the node is from the submodule `Sub` (same for `Block` and `Linear` as
|
234
|
+
well)
|
217
235
|
"""
|
218
236
|
|
219
237
|
def module_type_filter(n: Node) -> bool:
|
@@ -232,7 +250,9 @@ def _get_not_module_type_or_name_filter(
|
|
232
250
|
tp_list: List[Callable], module_name_list: List[str]
|
233
251
|
) -> Callable[[Node], bool]:
|
234
252
|
module_type_filters = [_get_module_type_filter(tp) for tp in tp_list]
|
235
|
-
module_name_list_filters = [
|
253
|
+
module_name_list_filters = [
|
254
|
+
_get_module_name_filter(m) for m in module_name_list
|
255
|
+
]
|
236
256
|
|
237
257
|
def not_module_type_or_name_filter(n: Node) -> bool:
|
238
258
|
return not any(f(n) for f in module_type_filters + module_name_list_filters)
|
@@ -307,7 +327,9 @@ class PT2EQuantizer(Quantizer):
|
|
307
327
|
return ops
|
308
328
|
return []
|
309
329
|
|
310
|
-
def set_global(
|
330
|
+
def set_global(
|
331
|
+
self, quantization_config: QuantizationConfig
|
332
|
+
) -> PT2EQuantizer:
|
311
333
|
self.global_config = quantization_config
|
312
334
|
return self
|
313
335
|
|
@@ -323,8 +345,11 @@ class PT2EQuantizer(Quantizer):
|
|
323
345
|
self, module_type: Callable, quantization_config: QuantizationConfig
|
324
346
|
):
|
325
347
|
"""Set quantization_config for a submodule with type: `module_type`, for example:
|
326
|
-
|
327
|
-
|
348
|
+
|
349
|
+
quantizer.set_module_name(Sub) or quantizer.set_module_name(nn.Linear), it
|
350
|
+
will quantize all supported operator/operator
|
351
|
+
patterns in the submodule with this module type with the given
|
352
|
+
`quantization_config`
|
328
353
|
"""
|
329
354
|
self.module_type_config[module_type] = quantization_config
|
330
355
|
return self
|
@@ -333,8 +358,11 @@ class PT2EQuantizer(Quantizer):
|
|
333
358
|
self, module_name: str, quantization_config: Optional[QuantizationConfig]
|
334
359
|
):
|
335
360
|
"""Set quantization_config for a submodule with name: `module_name`, for example:
|
336
|
-
|
337
|
-
|
361
|
+
|
362
|
+
quantizer.set_module_name("blocks.sub"), it will quantize all supported
|
363
|
+
operator/operator
|
364
|
+
patterns in the submodule with this module name with the given
|
365
|
+
`quantization_config`
|
338
366
|
"""
|
339
367
|
assert (
|
340
368
|
quantization_config is not None
|
@@ -31,7 +31,7 @@ from torch.ao.quantization.quantizer import SharedQuantizationSpec
|
|
31
31
|
from torch.ao.quantization.quantizer.utils import _annotate_input_qspec_map
|
32
32
|
from torch.ao.quantization.quantizer.utils import _annotate_output_qspec
|
33
33
|
from torch.fx import Node
|
34
|
-
from torch.fx.passes.utils.matcher_with_name_node_map_utils import SubgraphMatcherWithNameNodeMap
|
34
|
+
from torch.fx.passes.utils.matcher_with_name_node_map_utils import SubgraphMatcherWithNameNodeMap
|
35
35
|
from torch.fx.passes.utils.source_matcher_utils import get_source_partitions
|
36
36
|
import torch.nn.functional as F
|
37
37
|
|
@@ -95,9 +95,10 @@ class OperatorConfig(NamedTuple):
|
|
95
95
|
|
96
96
|
|
97
97
|
def _is_annotated(nodes: List[Node]):
|
98
|
-
"""
|
99
|
-
|
100
|
-
|
98
|
+
"""Checks if a list of nodes is annotated.
|
99
|
+
|
100
|
+
Given a list of nodes (that represents an operator pattern), check if any of
|
101
|
+
the node is annotated, return True if any of the node
|
101
102
|
is annotated, otherwise return False
|
102
103
|
"""
|
103
104
|
annotated = False
|
@@ -154,7 +155,9 @@ def get_weight_qspec(quantization_config: Optional[QuantizationConfig]):
|
|
154
155
|
torch.per_tensor_symmetric,
|
155
156
|
torch.per_channel_symmetric,
|
156
157
|
]:
|
157
|
-
raise ValueError(
|
158
|
+
raise ValueError(
|
159
|
+
f"Unsupported quantization_spec {quantization_spec} for weight"
|
160
|
+
)
|
158
161
|
return quantization_spec
|
159
162
|
|
160
163
|
|
@@ -193,7 +196,10 @@ def _annotate_linear(
|
|
193
196
|
weight_qspec = get_weight_qspec(quantization_config)
|
194
197
|
bias_qspec = get_bias_qspec(quantization_config)
|
195
198
|
for node in gm.graph.nodes:
|
196
|
-
if
|
199
|
+
if (
|
200
|
+
node.op != "call_function"
|
201
|
+
or node.target != torch.ops.aten.linear.default
|
202
|
+
):
|
197
203
|
continue
|
198
204
|
if filter_fn and not filter_fn(node):
|
199
205
|
continue
|
@@ -413,11 +419,13 @@ def _annotate_conv_bn(
|
|
413
419
|
quantization_config: Optional[QuantizationConfig],
|
414
420
|
filter_fn: Optional[Callable[[Node], bool]] = None,
|
415
421
|
) -> Optional[List[List[Node]]]:
|
422
|
+
"""Find conv + batchnorm parititions Note: This is only used for QAT.
|
423
|
+
|
424
|
+
In PTQ, batchnorm should already be fused into the conv.
|
416
425
|
"""
|
417
|
-
|
418
|
-
|
419
|
-
|
420
|
-
return _do_annotate_conv_bn(gm, quantization_config, filter_fn, has_relu=False)
|
426
|
+
return _do_annotate_conv_bn(
|
427
|
+
gm, quantization_config, filter_fn, has_relu=False
|
428
|
+
)
|
421
429
|
|
422
430
|
|
423
431
|
@register_annotator("conv_bn_relu")
|
@@ -426,9 +434,9 @@ def _annotate_conv_bn_relu(
|
|
426
434
|
quantization_config: Optional[QuantizationConfig],
|
427
435
|
filter_fn: Optional[Callable[[Node], bool]] = None,
|
428
436
|
) -> Optional[List[List[Node]]]:
|
429
|
-
"""
|
430
|
-
|
431
|
-
|
437
|
+
"""Find conv + batchnorm + relu parititions Note: This is only used for QAT.
|
438
|
+
|
439
|
+
In PTQ, batchnorm should already be fused into the conv.
|
432
440
|
"""
|
433
441
|
return _do_annotate_conv_bn(gm, quantization_config, filter_fn, has_relu=True)
|
434
442
|
|
@@ -439,8 +447,8 @@ def _do_annotate_conv_bn(
|
|
439
447
|
filter_fn: Optional[Callable[[Node], bool]],
|
440
448
|
has_relu: bool,
|
441
449
|
) -> List[List[Node]]:
|
442
|
-
"""
|
443
|
-
|
450
|
+
"""Given a function that takes in a `conv_fn` and returns a conv-bn[-relu] pattern,
|
451
|
+
|
444
452
|
return a list of annotated partitions.
|
445
453
|
|
446
454
|
The output of the pattern must include a dictionary from string name to node
|
@@ -486,7 +494,9 @@ def _do_annotate_conv_bn(
|
|
486
494
|
# Match against all conv dimensions and cuda variants
|
487
495
|
for (conv_fn, example_inputs), is_cuda, relu_is_inplace in combinations:
|
488
496
|
pattern = get_pattern(conv_fn, relu_is_inplace)
|
489
|
-
pattern = _get_aten_graph_module_for_pattern(
|
497
|
+
pattern = _get_aten_graph_module_for_pattern(
|
498
|
+
pattern, example_inputs, is_cuda
|
499
|
+
)
|
490
500
|
pattern.graph.eliminate_dead_code()
|
491
501
|
pattern.recompile()
|
492
502
|
matcher = SubgraphMatcherWithNameNodeMap(pattern, ignore_literals=True)
|
@@ -676,7 +686,9 @@ def _annotate_adaptive_avg_pool2d(
|
|
676
686
|
and pool_node.target != torch.ops.aten.mean.dim
|
677
687
|
and pool_node.target != torch.ops.aten.as_strided_.default
|
678
688
|
):
|
679
|
-
raise ValueError(
|
689
|
+
raise ValueError(
|
690
|
+
f"{pool_node} is not an aten adaptive_avg_pool2d operator"
|
691
|
+
)
|
680
692
|
|
681
693
|
if _is_annotated([pool_node]):
|
682
694
|
continue
|
@@ -741,7 +753,8 @@ def _annotate_fixed_qparams(
|
|
741
753
|
continue
|
742
754
|
|
743
755
|
node.meta["quantization_annotation"] = QuantizationAnnotation(
|
744
|
-
output_qspec=get_fixed_qparams_qspec(quantization_config),
|
756
|
+
output_qspec=get_fixed_qparams_qspec(quantization_config),
|
757
|
+
_annotated=True,
|
745
758
|
)
|
746
759
|
_mark_nodes_as_annotated(partition)
|
747
760
|
annotated_partitions.append(partition)
|
@@ -885,7 +898,9 @@ def _annotate_mul(
|
|
885
898
|
filter_fn: Optional[Callable[[Node], bool]] = None,
|
886
899
|
) -> Optional[List[List[Node]]]:
|
887
900
|
mul_partitions = get_source_partitions(
|
888
|
-
gm.graph,
|
901
|
+
gm.graph,
|
902
|
+
["mul", "mul_", operator.mul, torch.mul, operator.imul],
|
903
|
+
filter_fn,
|
889
904
|
)
|
890
905
|
mul_partitions = list(itertools.chain(*mul_partitions.values()))
|
891
906
|
annotated_partitions = []
|
@@ -932,8 +947,9 @@ def _annotate_cat(
|
|
932
947
|
|
933
948
|
if cat_node.target != torch.ops.aten.cat.default:
|
934
949
|
raise Exception(
|
935
|
-
|
936
|
-
" please check if you are calling the correct
|
950
|
+
"Expected cat node: torch.ops.aten.cat.default, but found"
|
951
|
+
f" {cat_node.target} please check if you are calling the correct"
|
952
|
+
" capture API"
|
937
953
|
)
|
938
954
|
|
939
955
|
annotated_partitions.append(cat_partition.nodes)
|
@@ -987,7 +1003,9 @@ def propagate_annotation(model: torch.fx.GraphModule) -> None:
|
|
987
1003
|
if not isinstance(prev_node, Node):
|
988
1004
|
continue
|
989
1005
|
|
990
|
-
quantization_annotation = prev_node.meta.get(
|
1006
|
+
quantization_annotation = prev_node.meta.get(
|
1007
|
+
"quantization_annotation", None
|
1008
|
+
)
|
991
1009
|
if not quantization_annotation:
|
992
1010
|
continue
|
993
1011
|
|
@@ -1014,7 +1032,9 @@ def propagate_annotation(model: torch.fx.GraphModule) -> None:
|
|
1014
1032
|
|
1015
1033
|
|
1016
1034
|
# TODO: make the list of ops customizable
|
1017
|
-
def _convert_scalars_to_attrs(
|
1035
|
+
def _convert_scalars_to_attrs(
|
1036
|
+
model: torch.fx.GraphModule,
|
1037
|
+
) -> torch.fx.GraphModule:
|
1018
1038
|
for n in model.graph.nodes:
|
1019
1039
|
if n.op != "call_function" or n.target not in [
|
1020
1040
|
torch.ops.aten.add.Tensor,
|
@@ -13,27 +13,27 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
15
|
|
16
|
-
|
16
|
+
import dataclasses
|
17
17
|
import enum
|
18
18
|
from typing import Optional
|
19
19
|
|
20
|
-
from ai_edge_torch.generative.quantize import quant_attrs
|
21
20
|
from ai_edge_torch.generative.quantize import quant_recipe
|
22
21
|
from ai_edge_torch.quantize import pt2e_quantizer as pt2eq
|
23
22
|
|
24
23
|
|
25
|
-
@dataclass(frozen=True)
|
24
|
+
@dataclasses.dataclass(frozen=True)
|
26
25
|
class QuantConfig:
|
27
|
-
"""
|
26
|
+
"""Encapsulates a quantization configuration.
|
27
|
+
|
28
28
|
Encapsulates all different quantization methods and schemes available for
|
29
29
|
models converted with ai_edge_torch.
|
30
30
|
|
31
|
-
|
31
|
+
Attributes:
|
32
32
|
pt2e_quantizer: The instance of PT2EQuantizer used to quantize the model
|
33
33
|
with PT2E quantization. This method of quantization is not applicable to
|
34
34
|
models created with the Edge Generative API.
|
35
|
-
generative_recipe: Quantization recipe to be applied on a model created
|
36
|
-
|
35
|
+
generative_recipe: Quantization recipe to be applied on a model created with
|
36
|
+
the Edge Generative API.
|
37
37
|
"""
|
38
38
|
|
39
39
|
pt2e_quantizer: pt2eq.PT2EQuantizer = None
|
@@ -76,6 +76,10 @@ class QuantConfig:
|
|
76
76
|
elif generative_recipe is not None:
|
77
77
|
generative_recipe.verify()
|
78
78
|
object.__setattr__(self, 'generative_recipe', generative_recipe)
|
79
|
-
object.__setattr__(
|
79
|
+
object.__setattr__(
|
80
|
+
self, '_quantizer_mode', self._QuantizerMode.AI_EDGE_QUANTIZER
|
81
|
+
)
|
80
82
|
else:
|
81
|
-
raise ValueError(
|
83
|
+
raise ValueError(
|
84
|
+
'Either pt2e_quantizer or generative_recipe must be set.'
|
85
|
+
)
|
@@ -13,26 +13,33 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
15
|
|
16
|
-
"""
|
16
|
+
"""Contains utility functions to test TFLite models exported from PyTorch."""
|
17
17
|
|
18
18
|
from collections.abc import Callable
|
19
19
|
|
20
|
+
from ai_edge_torch import model
|
20
21
|
import numpy as np
|
21
22
|
import torch
|
22
23
|
from torch.utils import _pytree as pytree
|
23
24
|
|
24
|
-
from ai_edge_torch.model import Model
|
25
|
-
|
26
25
|
|
27
26
|
# Utility to flatten the order to make it deterministic.
|
28
27
|
# Ordering is done in left-to-right depth-first tree traversal.
|
29
28
|
def _flatten(data):
|
30
|
-
out,
|
29
|
+
out, _ = pytree.tree_flatten(data)
|
31
30
|
return out
|
32
31
|
|
33
32
|
|
34
33
|
# Convert a Torch Tensor to a numpy array
|
35
34
|
def _torch_tensors_to_np(*argv):
|
35
|
+
"""Converts a Torch Tensor to a numpy array.
|
36
|
+
|
37
|
+
Args:
|
38
|
+
*argv: A list of torch.tensor or a single torch.tensor.
|
39
|
+
|
40
|
+
Returns:
|
41
|
+
A list of numpy array or a single numpy array.
|
42
|
+
"""
|
36
43
|
if len(argv) > 1:
|
37
44
|
data = list(argv)
|
38
45
|
else:
|
@@ -58,7 +65,7 @@ def _torch_tensors_to_np(*argv):
|
|
58
65
|
|
59
66
|
|
60
67
|
def compare_tflite_torch(
|
61
|
-
edge_model: Model,
|
68
|
+
edge_model: model.Model,
|
62
69
|
torch_eval_func: Callable,
|
63
70
|
args=None,
|
64
71
|
kwargs=None,
|
@@ -69,15 +76,17 @@ def compare_tflite_torch(
|
|
69
76
|
rtol: float = 1e-5
|
70
77
|
):
|
71
78
|
"""Compares torch models and TFLite models.
|
79
|
+
|
72
80
|
Args:
|
73
81
|
edge_model: Serialized ai_edge_torch.model.Model object.
|
74
82
|
torch_eval_func: Callable function to evaluate torch model.
|
75
|
-
args: torch.tensor array or a callable to generate a torch.tensor array
|
76
|
-
|
83
|
+
args: torch.tensor array or a callable to generate a torch.tensor array with
|
84
|
+
random data, to pass into models during inference. (default None).
|
77
85
|
kwargs: dict of str to torch.tensor, or a callable to generate such.
|
78
|
-
num_valid_inputs: Defines the number of times the random inputs will be
|
79
|
-
|
80
|
-
|
86
|
+
num_valid_inputs: Defines the number of times the random inputs will be
|
87
|
+
generated (if a callable is provided for input_data).
|
88
|
+
signature_name: If provided, specifies the name for the signature of the
|
89
|
+
edge_model to run. Calls the default signature if not provided.
|
81
90
|
atol: Absolute tolerance (see `numpy.allclose`)
|
82
91
|
rtol: Relative tolerance (see `numpy.allclose`)
|
83
92
|
"""
|
@@ -94,7 +103,9 @@ def compare_tflite_torch(
|
|
94
103
|
)
|
95
104
|
for _ in range(num_valid_inputs)
|
96
105
|
]
|
97
|
-
torch_outputs = [
|
106
|
+
torch_outputs = [
|
107
|
+
torch_eval_func(*args, **kwargs) for args, kwargs in torch_inputs
|
108
|
+
]
|
98
109
|
np_inputs = [
|
99
110
|
(_torch_tensors_to_np(args), _torch_tensors_to_np(kwargs))
|
100
111
|
for args, kwargs in torch_inputs
|
@@ -110,12 +121,13 @@ def compare_tflite_torch(
|
|
110
121
|
if signature_name is None:
|
111
122
|
return _flatten(edge_model(*args, **kwargs))
|
112
123
|
else:
|
113
|
-
return _flatten(
|
124
|
+
return _flatten(
|
125
|
+
edge_model(*args, **kwargs, signature_name=signature_name)
|
126
|
+
)
|
114
127
|
|
115
128
|
for idx, np_input in enumerate(np_inputs):
|
116
129
|
output = get_edge_output(np_input)
|
117
130
|
golden_output = np_outputs[idx]
|
118
|
-
|
119
131
|
is_output_len_eq = len(golden_output) == len(output)
|
120
132
|
|
121
133
|
output = [v.astype(np.float32) for v in output]
|
@@ -123,9 +135,10 @@ def compare_tflite_torch(
|
|
123
135
|
|
124
136
|
# Append the results of each invoke to a function-global variable
|
125
137
|
# used to store the comparison final results
|
126
|
-
is_equal = is_output_len_eq and all(
|
127
|
-
|
128
|
-
|
138
|
+
is_equal = is_output_len_eq and all([
|
139
|
+
equal_fn(out, golden_out)
|
140
|
+
for out, golden_out in zip(output, golden_output)
|
141
|
+
])
|
129
142
|
if not is_equal:
|
130
143
|
return False
|
131
144
|
|