ai-edge-torch-nightly 0.2.0.dev20240801__py3-none-any.whl → 0.2.0.dev20240803__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ai-edge-torch-nightly might be problematic. Click here for more details.
- ai_edge_torch/__init__.py +1 -0
- ai_edge_torch/convert/conversion.py +12 -8
- ai_edge_torch/convert/conversion_utils.py +38 -20
- ai_edge_torch/convert/converter.py +11 -5
- ai_edge_torch/convert/fx_passes/__init__.py +3 -4
- ai_edge_torch/convert/fx_passes/_pass_base.py +6 -2
- ai_edge_torch/convert/fx_passes/build_aten_composite_pass.py +46 -40
- ai_edge_torch/convert/fx_passes/build_interpolate_composite_pass.py +11 -10
- ai_edge_torch/convert/fx_passes/canonicalize_pass.py +2 -3
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_check.py +18 -7
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py +4 -3
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py +6 -4
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +9 -5
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/op_func_registry.py +1 -2
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/pass_body.py +14 -10
- ai_edge_torch/convert/test/test_convert.py +39 -16
- ai_edge_torch/convert/test/test_convert_composites.py +115 -86
- ai_edge_torch/convert/test/test_convert_multisig.py +18 -10
- ai_edge_torch/convert/test/test_to_channel_last_io.py +1 -2
- ai_edge_torch/convert/to_channel_last_io.py +6 -2
- ai_edge_torch/debug/culprit.py +41 -16
- ai_edge_torch/debug/test/test_culprit.py +4 -3
- ai_edge_torch/debug/test/test_search_model.py +4 -3
- ai_edge_torch/debug/utils.py +3 -1
- ai_edge_torch/generative/examples/experimental/gemma/convert_to_tflite.py +4 -3
- ai_edge_torch/generative/examples/experimental/gemma/gemma.py +10 -8
- ai_edge_torch/generative/examples/experimental/phi/convert_to_tflite.py +7 -4
- ai_edge_torch/generative/examples/experimental/phi/phi2.py +10 -8
- ai_edge_torch/generative/examples/experimental/tiny_llama/convert_to_tflite.py +1 -2
- ai_edge_torch/generative/examples/experimental/tiny_llama/tiny_llama.py +10 -8
- ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +4 -3
- ai_edge_torch/generative/examples/gemma/gemma.py +13 -9
- ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +7 -4
- ai_edge_torch/generative/examples/phi2/phi2.py +13 -9
- ai_edge_torch/generative/examples/stable_diffusion/attention.py +3 -1
- ai_edge_torch/generative/examples/stable_diffusion/clip.py +20 -9
- ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +14 -6
- ai_edge_torch/generative/examples/stable_diffusion/decoder.py +14 -7
- ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +41 -16
- ai_edge_torch/generative/examples/stable_diffusion/encoder.py +4 -3
- ai_edge_torch/generative/examples/stable_diffusion/pipeline.py +36 -13
- ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler.py +4 -3
- ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler_ancestral.py +4 -3
- ai_edge_torch/generative/examples/stable_diffusion/samplers/k_lms.py +4 -3
- ai_edge_torch/generative/examples/stable_diffusion/tokenizer.py +4 -1
- ai_edge_torch/generative/examples/stable_diffusion/util.py +9 -3
- ai_edge_torch/generative/examples/t5/convert_to_tflite.py +8 -5
- ai_edge_torch/generative/examples/t5/t5.py +158 -125
- ai_edge_torch/generative/examples/t5/t5_attention.py +15 -7
- ai_edge_torch/generative/examples/test_models/toy_model.py +7 -5
- ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +3 -4
- ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +4 -5
- ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +4 -3
- ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +10 -8
- ai_edge_torch/generative/fx_passes/__init__.py +1 -2
- ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py +6 -3
- ai_edge_torch/generative/layers/attention.py +19 -11
- ai_edge_torch/generative/layers/builder.py +3 -4
- ai_edge_torch/generative/layers/kv_cache.py +4 -3
- ai_edge_torch/generative/layers/model_config.py +6 -2
- ai_edge_torch/generative/layers/rotary_position_embedding.py +3 -1
- ai_edge_torch/generative/layers/scaled_dot_product_attention.py +1 -2
- ai_edge_torch/generative/layers/unet/blocks_2d.py +69 -21
- ai_edge_torch/generative/layers/unet/builder.py +7 -4
- ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/translate_recipe.py +9 -4
- ai_edge_torch/generative/quantize/example.py +2 -3
- ai_edge_torch/generative/quantize/quant_recipe.py +2 -1
- ai_edge_torch/generative/test/loader_test.py +5 -4
- ai_edge_torch/generative/test/test_experimental_ekv.py +22 -11
- ai_edge_torch/generative/test/test_model_conversion.py +2 -3
- ai_edge_torch/generative/test/test_quantize.py +45 -48
- ai_edge_torch/generative/utilities/loader.py +55 -28
- ai_edge_torch/generative/utilities/stable_diffusion_loader.py +86 -33
- ai_edge_torch/generative/utilities/t5_loader.py +77 -48
- ai_edge_torch/hlfb/mark_pattern/__init__.py +2 -3
- ai_edge_torch/hlfb/mark_pattern/pattern.py +16 -7
- ai_edge_torch/hlfb/test/test_mark_pattern.py +4 -3
- ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py +12 -6
- ai_edge_torch/model.py +8 -5
- ai_edge_torch/quantize/pt2e_quantizer.py +30 -15
- ai_edge_torch/quantize/pt2e_quantizer_utils.py +30 -11
- ai_edge_torch/quantize/quant_config.py +6 -2
- ai_edge_torch/testing/model_coverage/model_coverage.py +11 -7
- ai_edge_torch/version.py +16 -0
- {ai_edge_torch_nightly-0.2.0.dev20240801.dist-info → ai_edge_torch_nightly-0.2.0.dev20240803.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.2.0.dev20240801.dist-info → ai_edge_torch_nightly-0.2.0.dev20240803.dist-info}/RECORD +89 -88
- {ai_edge_torch_nightly-0.2.0.dev20240801.dist-info → ai_edge_torch_nightly-0.2.0.dev20240803.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.2.0.dev20240801.dist-info → ai_edge_torch_nightly-0.2.0.dev20240803.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.2.0.dev20240801.dist-info → ai_edge_torch_nightly-0.2.0.dev20240803.dist-info}/top_level.txt +0 -0
|
@@ -16,13 +16,6 @@ import operator
|
|
|
16
16
|
import os
|
|
17
17
|
from typing import Optional, Tuple, Union
|
|
18
18
|
|
|
19
|
-
import torch
|
|
20
|
-
import torch.ao.quantization.quantize_pt2e
|
|
21
|
-
from torch.export import ExportedProgram
|
|
22
|
-
from torch.fx import GraphModule
|
|
23
|
-
from torch.fx import Node
|
|
24
|
-
import torch.utils._pytree as pytree
|
|
25
|
-
|
|
26
19
|
from ai_edge_torch.convert.fx_passes import ExportedProgramPassBase
|
|
27
20
|
from ai_edge_torch.convert.fx_passes import ExportedProgramPassResult
|
|
28
21
|
from ai_edge_torch.convert.fx_passes.optimize_layout_transposes_pass import layout_check # NOQA
|
|
@@ -30,6 +23,12 @@ from ai_edge_torch.convert.fx_passes.optimize_layout_transposes_pass import layo
|
|
|
30
23
|
from ai_edge_torch.convert.fx_passes.optimize_layout_transposes_pass import layout_partitioners # NOQA
|
|
31
24
|
from ai_edge_torch.convert.fx_passes.optimize_layout_transposes_pass import layout_rewrite # NOQA
|
|
32
25
|
from ai_edge_torch.convert.fx_passes.optimize_layout_transposes_pass import utils # NOQA
|
|
26
|
+
import torch
|
|
27
|
+
import torch.ao.quantization.quantize_pt2e
|
|
28
|
+
from torch.export import ExportedProgram
|
|
29
|
+
from torch.fx import GraphModule
|
|
30
|
+
from torch.fx import Node
|
|
31
|
+
import torch.utils._pytree as pytree
|
|
33
32
|
|
|
34
33
|
TransposeFunc = Union[utils.tensor_to_nchw, utils.tensor_to_nhwc]
|
|
35
34
|
|
|
@@ -208,7 +207,8 @@ class OptimizeLayoutTransposesPass(ExportedProgramPassBase):
|
|
|
208
207
|
|
|
209
208
|
if not layout_check.is_4d(input_node):
|
|
210
209
|
raise AssertionError(
|
|
211
|
-
|
|
210
|
+
"Attempting to convert non-NHWC compatible node to NHWC:"
|
|
211
|
+
f" {input_node}"
|
|
212
212
|
)
|
|
213
213
|
|
|
214
214
|
# Assign target node's source meta to the to_NHWC node, because the transpose
|
|
@@ -250,7 +250,9 @@ class OptimizeLayoutTransposesPass(ExportedProgramPassBase):
|
|
|
250
250
|
|
|
251
251
|
for node in graph.nodes:
|
|
252
252
|
has_input_nodes = len(node.all_input_nodes) > 0
|
|
253
|
-
all_inputs_are_const = all(
|
|
253
|
+
all_inputs_are_const = all(
|
|
254
|
+
map(layout_mark.is_const_node, node.all_input_nodes)
|
|
255
|
+
)
|
|
254
256
|
if (
|
|
255
257
|
node.name in non_user_input_names
|
|
256
258
|
or (has_input_nodes and all_inputs_are_const)
|
|
@@ -262,7 +264,9 @@ class OptimizeLayoutTransposesPass(ExportedProgramPassBase):
|
|
|
262
264
|
self.mark_const_nodes(exported_program)
|
|
263
265
|
|
|
264
266
|
graph_module = exported_program.graph_module
|
|
265
|
-
partitioner = os.environ.get(
|
|
267
|
+
partitioner = os.environ.get(
|
|
268
|
+
"AIEDGETORCH_LAYOUT_OPTIMIZE_PARTITIONER", None
|
|
269
|
+
)
|
|
266
270
|
if partitioner == "MINCUT":
|
|
267
271
|
graph_module = layout_partitioners.min_cut.partition(graph_module)
|
|
268
272
|
elif partitioner == "GREEDY":
|
|
@@ -20,15 +20,14 @@ import tempfile
|
|
|
20
20
|
from typing import Tuple
|
|
21
21
|
import unittest
|
|
22
22
|
|
|
23
|
+
import ai_edge_torch
|
|
24
|
+
from ai_edge_torch.convert import conversion_utils as cutils
|
|
25
|
+
from ai_edge_torch.testing import model_coverage
|
|
23
26
|
import numpy as np
|
|
24
27
|
import tensorflow as tf
|
|
25
28
|
import torch
|
|
26
29
|
import torchvision
|
|
27
30
|
|
|
28
|
-
import ai_edge_torch
|
|
29
|
-
from ai_edge_torch.convert import conversion_utils as cutils
|
|
30
|
-
from ai_edge_torch.testing import model_coverage
|
|
31
|
-
|
|
32
31
|
|
|
33
32
|
@dataclass
|
|
34
33
|
class TestContainer1:
|
|
@@ -36,7 +35,9 @@ class TestContainer1:
|
|
|
36
35
|
data_2: Tuple[torch.Tensor, torch.Tensor]
|
|
37
36
|
|
|
38
37
|
|
|
39
|
-
torch.export.register_dataclass(
|
|
38
|
+
torch.export.register_dataclass(
|
|
39
|
+
TestContainer1, serialized_type_name="TestContainer1"
|
|
40
|
+
)
|
|
40
41
|
|
|
41
42
|
|
|
42
43
|
class TestConvert(unittest.TestCase):
|
|
@@ -60,7 +61,9 @@ class TestConvert(unittest.TestCase):
|
|
|
60
61
|
torch_module = Add().eval()
|
|
61
62
|
edge_model = ai_edge_torch.convert(torch_module, args)
|
|
62
63
|
|
|
63
|
-
self.assertTrue(
|
|
64
|
+
self.assertTrue(
|
|
65
|
+
model_coverage.compare_tflite_torch(edge_model, torch_module, args)
|
|
66
|
+
)
|
|
64
67
|
|
|
65
68
|
def test_convert_dot_add(self):
|
|
66
69
|
class DotAdd(torch.nn.Module):
|
|
@@ -77,14 +80,18 @@ class TestConvert(unittest.TestCase):
|
|
|
77
80
|
torch_module = DotAdd().eval()
|
|
78
81
|
edge_model = ai_edge_torch.convert(torch_module, args)
|
|
79
82
|
|
|
80
|
-
self.assertTrue(
|
|
83
|
+
self.assertTrue(
|
|
84
|
+
model_coverage.compare_tflite_torch(edge_model, torch_module, args)
|
|
85
|
+
)
|
|
81
86
|
|
|
82
87
|
def test_convert_resnet18(self):
|
|
83
88
|
args = (torch.randn(4, 3, 224, 224),)
|
|
84
89
|
torch_module = torchvision.models.resnet18().eval()
|
|
85
90
|
edge_model = ai_edge_torch.convert(torch_module, args)
|
|
86
91
|
|
|
87
|
-
self.assertTrue(
|
|
92
|
+
self.assertTrue(
|
|
93
|
+
model_coverage.compare_tflite_torch(edge_model, torch_module, args)
|
|
94
|
+
)
|
|
88
95
|
|
|
89
96
|
def test_signature_args_ordering(self):
|
|
90
97
|
"""Tests conversion of a model with more than 10 arguments."""
|
|
@@ -156,7 +163,9 @@ class TestConvert(unittest.TestCase):
|
|
|
156
163
|
torch_model = BasicAddModelWithMultipleOutputs().eval()
|
|
157
164
|
edge_model = ai_edge_torch.convert(torch_model, sample_input)
|
|
158
165
|
|
|
159
|
-
result = model_coverage.compare_tflite_torch(
|
|
166
|
+
result = model_coverage.compare_tflite_torch(
|
|
167
|
+
edge_model, torch_model, sample_input
|
|
168
|
+
)
|
|
160
169
|
self.assertTrue(result)
|
|
161
170
|
|
|
162
171
|
def test_12_outputs_model(self):
|
|
@@ -201,7 +210,9 @@ class TestConvert(unittest.TestCase):
|
|
|
201
210
|
torch_model = BasicAddModelWithMultipleOutputs().eval()
|
|
202
211
|
edge_model = ai_edge_torch.convert(torch_model, sample_input)
|
|
203
212
|
|
|
204
|
-
result = model_coverage.compare_tflite_torch(
|
|
213
|
+
result = model_coverage.compare_tflite_torch(
|
|
214
|
+
edge_model, torch_model, sample_input
|
|
215
|
+
)
|
|
205
216
|
self.assertTrue(result)
|
|
206
217
|
|
|
207
218
|
def test_apply_tfl_backdoor_flags(self):
|
|
@@ -244,7 +255,9 @@ class TestConvert(unittest.TestCase):
|
|
|
244
255
|
tmp_dir_path, "test_convert_add_backdoor_flags_mlir_dump"
|
|
245
256
|
)
|
|
246
257
|
ai_edge_torch.convert(
|
|
247
|
-
torch_module,
|
|
258
|
+
torch_module,
|
|
259
|
+
args,
|
|
260
|
+
_ai_edge_converter_flags={"ir_dump_dir": ir_dump_path},
|
|
248
261
|
)
|
|
249
262
|
self.assertTrue(os.path.isdir(ir_dump_path))
|
|
250
263
|
|
|
@@ -296,7 +309,9 @@ class TestConvert(unittest.TestCase):
|
|
|
296
309
|
edge_model = ai_edge_torch.convert(model, sample_kwargs=kwargs_gen())
|
|
297
310
|
|
|
298
311
|
self.assertTrue(
|
|
299
|
-
model_coverage.compare_tflite_torch(
|
|
312
|
+
model_coverage.compare_tflite_torch(
|
|
313
|
+
edge_model, model, kwargs=kwargs_gen
|
|
314
|
+
)
|
|
300
315
|
)
|
|
301
316
|
|
|
302
317
|
def test_convert_model_with_args_kwargs(self):
|
|
@@ -316,7 +331,9 @@ class TestConvert(unittest.TestCase):
|
|
|
316
331
|
edge_model = ai_edge_torch.convert(model, args_gen(), kwargs_gen())
|
|
317
332
|
|
|
318
333
|
self.assertTrue(
|
|
319
|
-
model_coverage.compare_tflite_torch(
|
|
334
|
+
model_coverage.compare_tflite_torch(
|
|
335
|
+
edge_model, model, args_gen, kwargs_gen
|
|
336
|
+
)
|
|
320
337
|
)
|
|
321
338
|
|
|
322
339
|
def test_convert_model_with_args_nested_kwargs_1(self):
|
|
@@ -344,7 +361,9 @@ class TestConvert(unittest.TestCase):
|
|
|
344
361
|
"z_data_2_0": kwargs["z"].data_2[0].numpy(),
|
|
345
362
|
"z_data_2_1": kwargs["z"].data_2[1].numpy(),
|
|
346
363
|
}
|
|
347
|
-
self._compare_tflite_torch_args_kwargs(
|
|
364
|
+
self._compare_tflite_torch_args_kwargs(
|
|
365
|
+
SampleModel(), args, kwargs, flat_inputs
|
|
366
|
+
)
|
|
348
367
|
|
|
349
368
|
def test_convert_model_with_args_nested_kwargs_2(self):
|
|
350
369
|
"""
|
|
@@ -371,7 +390,9 @@ class TestConvert(unittest.TestCase):
|
|
|
371
390
|
"z_data_2_0_0": kwargs["z"].data_2[0][0].numpy(),
|
|
372
391
|
"z_data_2_1": kwargs["z"].data_2[1].numpy(),
|
|
373
392
|
}
|
|
374
|
-
self._compare_tflite_torch_args_kwargs(
|
|
393
|
+
self._compare_tflite_torch_args_kwargs(
|
|
394
|
+
SampleModel(), args, kwargs, flat_inputs
|
|
395
|
+
)
|
|
375
396
|
|
|
376
397
|
def test_convert_model_with_args_nested_kwargs_3(self):
|
|
377
398
|
"""
|
|
@@ -398,7 +419,9 @@ class TestConvert(unittest.TestCase):
|
|
|
398
419
|
"z_data_2_0_foo": kwargs["z"].data_2[0]["foo"].numpy(),
|
|
399
420
|
"z_data_2_1": kwargs["z"].data_2[1].numpy(),
|
|
400
421
|
}
|
|
401
|
-
self._compare_tflite_torch_args_kwargs(
|
|
422
|
+
self._compare_tflite_torch_args_kwargs(
|
|
423
|
+
SampleModel(), args, kwargs, flat_inputs
|
|
424
|
+
)
|
|
402
425
|
|
|
403
426
|
def _compare_tflite_torch_args_kwargs(self, model, args, kwargs, flat_inputs):
|
|
404
427
|
model.eval()
|
|
@@ -17,11 +17,10 @@
|
|
|
17
17
|
from typing import Callable
|
|
18
18
|
import unittest
|
|
19
19
|
|
|
20
|
-
import parameterized
|
|
21
|
-
import torch
|
|
22
|
-
|
|
23
20
|
import ai_edge_torch
|
|
24
21
|
from ai_edge_torch.testing import model_coverage
|
|
22
|
+
import parameterized
|
|
23
|
+
import torch
|
|
25
24
|
|
|
26
25
|
|
|
27
26
|
def _func_to_torch_module(func: Callable):
|
|
@@ -47,35 +46,35 @@ class TestConvertComposites(unittest.TestCase):
|
|
|
47
46
|
torch_module = torch.nn.Hardswish().eval()
|
|
48
47
|
edge_model = ai_edge_torch.convert(torch_module, args)
|
|
49
48
|
|
|
50
|
-
self.assertTrue(
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
]
|
|
78
|
-
)
|
|
49
|
+
self.assertTrue(
|
|
50
|
+
model_coverage.compare_tflite_torch(edge_model, torch_module, args)
|
|
51
|
+
)
|
|
52
|
+
|
|
53
|
+
@parameterized.parameterized.expand([
|
|
54
|
+
# input_size, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override
|
|
55
|
+
# no padding, stride = 1
|
|
56
|
+
([1, 3, 6, 6], [3, 3], [1, 1], [0, 0], False, True, None),
|
|
57
|
+
# add stride
|
|
58
|
+
([1, 3, 6, 6], [3, 3], [2, 2], [0, 0], False, True, None),
|
|
59
|
+
# default values
|
|
60
|
+
([1, 3, 6, 6], [3, 3]),
|
|
61
|
+
# add padding
|
|
62
|
+
([1, 3, 6, 6], [3, 3], [1, 1], [1, 1], False, True, None),
|
|
63
|
+
# add different padding for different dims
|
|
64
|
+
([1, 3, 6, 6], [3, 3], [1, 1], [0, 1], False, True, None),
|
|
65
|
+
# add both stride and padding
|
|
66
|
+
([1, 3, 6, 6], [3, 3], [2, 2], [1, 1], False, True, None),
|
|
67
|
+
# count_include_pad = False
|
|
68
|
+
([1, 3, 6, 6], [3, 3], [1, 1], [1, 1], False, False, None),
|
|
69
|
+
# ceil_mode = True
|
|
70
|
+
([1, 3, 6, 6], [3, 3], [1, 1], [1, 1], True, True, None),
|
|
71
|
+
# ceil_mode = True, stride=[3, 3]
|
|
72
|
+
([1, 3, 6, 6], [3, 3], [3, 3], [1, 1], True, True, None),
|
|
73
|
+
# set divisor_override
|
|
74
|
+
([1, 3, 6, 6], [3, 3], [1, 1], 0, False, True, 6),
|
|
75
|
+
# padding set to one number
|
|
76
|
+
([1, 3, 6, 6], [3, 3], [1, 1], 1, False, True, None),
|
|
77
|
+
])
|
|
79
78
|
def test_convert_avg_pool2d(self, input_size, *args):
|
|
80
79
|
"""Tests conversion of a module containing an avg_pool2d aten."""
|
|
81
80
|
torch_module = _func_to_torch_module(
|
|
@@ -85,51 +84,65 @@ class TestConvertComposites(unittest.TestCase):
|
|
|
85
84
|
edge_model = ai_edge_torch.convert(torch_module, tracing_args)
|
|
86
85
|
|
|
87
86
|
self.assertTrue(
|
|
88
|
-
model_coverage.compare_tflite_torch(
|
|
87
|
+
model_coverage.compare_tflite_torch(
|
|
88
|
+
edge_model, torch_module, tracing_args
|
|
89
|
+
)
|
|
89
90
|
)
|
|
90
91
|
|
|
91
|
-
@parameterized.parameterized.expand(
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
92
|
+
@parameterized.parameterized.expand([
|
|
93
|
+
# use scale_factor with align_corners=False
|
|
94
|
+
(
|
|
95
|
+
[1, 3, 10, 10],
|
|
96
|
+
dict(scale_factor=3.0, mode='bilinear', align_corners=False),
|
|
97
|
+
),
|
|
98
|
+
# use scale_factor with align_corners=true
|
|
99
|
+
(
|
|
100
|
+
[1, 3, 10, 10],
|
|
101
|
+
dict(scale_factor=3.0, mode='bilinear', align_corners=True),
|
|
102
|
+
),
|
|
103
|
+
# use size
|
|
104
|
+
([1, 3, 10, 10], dict(size=[15, 20], mode='bilinear')),
|
|
105
|
+
# use size with align_corners=true
|
|
106
|
+
(
|
|
107
|
+
[1, 3, 10, 10],
|
|
108
|
+
dict(size=[15, 20], mode='bilinear', align_corners=True),
|
|
109
|
+
),
|
|
110
|
+
])
|
|
106
111
|
def test_convert_upsample_bilinear_functional(self, input_size, kwargs):
|
|
107
112
|
"""Tests conversion of a torch.nn.functional.upsample module."""
|
|
108
113
|
torch_module = _func_to_torch_module(
|
|
109
|
-
lambda input_tensor: torch.nn.functional.upsample(
|
|
114
|
+
lambda input_tensor: torch.nn.functional.upsample(
|
|
115
|
+
input_tensor, **kwargs
|
|
116
|
+
)
|
|
110
117
|
)
|
|
111
118
|
tracing_args = (torch.randn(*input_size),)
|
|
112
119
|
edge_model = ai_edge_torch.convert(torch_module, tracing_args)
|
|
113
120
|
|
|
114
121
|
self.assertTrue(
|
|
115
|
-
model_coverage.compare_tflite_torch(
|
|
122
|
+
model_coverage.compare_tflite_torch(
|
|
123
|
+
edge_model, torch_module, tracing_args
|
|
124
|
+
)
|
|
116
125
|
)
|
|
117
126
|
|
|
118
|
-
@parameterized.parameterized.expand(
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
127
|
+
@parameterized.parameterized.expand([
|
|
128
|
+
# use scale_factor with align_corners=False
|
|
129
|
+
(
|
|
130
|
+
[1, 3, 10, 10],
|
|
131
|
+
dict(scale_factor=3.0, mode='bilinear', align_corners=False),
|
|
132
|
+
),
|
|
133
|
+
# use scale_factor with align_corners=true
|
|
134
|
+
(
|
|
135
|
+
[1, 3, 10, 10],
|
|
136
|
+
dict(scale_factor=3.0, mode='bilinear', align_corners=True),
|
|
137
|
+
),
|
|
138
|
+
# use size
|
|
139
|
+
([1, 3, 10, 10], dict(size=[15, 20], mode='bilinear')),
|
|
140
|
+
# use size with align_corners=true
|
|
141
|
+
(
|
|
142
|
+
[1, 3, 10, 10],
|
|
143
|
+
dict(size=[15, 20], mode='bilinear', align_corners=True),
|
|
144
|
+
),
|
|
145
|
+
])
|
|
133
146
|
def test_convert_upsample_bilinear(self, input_size, kwargs):
|
|
134
147
|
"""Tests conversion of a torch.nn.Upsample module."""
|
|
135
148
|
torch_module = _func_to_torch_module(
|
|
@@ -139,34 +152,44 @@ class TestConvertComposites(unittest.TestCase):
|
|
|
139
152
|
edge_model = ai_edge_torch.convert(torch_module, tracing_args)
|
|
140
153
|
|
|
141
154
|
self.assertTrue(
|
|
142
|
-
model_coverage.compare_tflite_torch(
|
|
155
|
+
model_coverage.compare_tflite_torch(
|
|
156
|
+
edge_model, torch_module, tracing_args
|
|
157
|
+
)
|
|
143
158
|
)
|
|
144
159
|
|
|
145
|
-
@parameterized.parameterized.expand(
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
+
@parameterized.parameterized.expand([
|
|
161
|
+
# use scale_factor with align_corners=False
|
|
162
|
+
(
|
|
163
|
+
[1, 3, 10, 10],
|
|
164
|
+
dict(scale_factor=3.0, mode='bilinear', align_corners=False),
|
|
165
|
+
),
|
|
166
|
+
# use scale_factor with align_corners=true
|
|
167
|
+
(
|
|
168
|
+
[1, 3, 10, 10],
|
|
169
|
+
dict(scale_factor=3.0, mode='bilinear', align_corners=True),
|
|
170
|
+
),
|
|
171
|
+
# use size
|
|
172
|
+
([1, 3, 10, 10], dict(size=[15, 20], mode='bilinear')),
|
|
173
|
+
# use size with align_corners=true
|
|
174
|
+
(
|
|
175
|
+
[1, 3, 10, 10],
|
|
176
|
+
dict(size=[15, 20], mode='bilinear', align_corners=True),
|
|
177
|
+
),
|
|
178
|
+
])
|
|
160
179
|
def test_convert_interpolate_bilinear_functional(self, input_size, kwargs):
|
|
161
180
|
"""Tests conversion of a torch.nn.functional.interpolate module."""
|
|
162
181
|
torch_module = _func_to_torch_module(
|
|
163
|
-
lambda input_tensor: torch.nn.functional.interpolate(
|
|
182
|
+
lambda input_tensor: torch.nn.functional.interpolate(
|
|
183
|
+
input_tensor, **kwargs
|
|
184
|
+
)
|
|
164
185
|
)
|
|
165
186
|
tracing_args = (torch.randn(*input_size),)
|
|
166
187
|
edge_model = ai_edge_torch.convert(torch_module, tracing_args)
|
|
167
188
|
|
|
168
189
|
self.assertTrue(
|
|
169
|
-
model_coverage.compare_tflite_torch(
|
|
190
|
+
model_coverage.compare_tflite_torch(
|
|
191
|
+
edge_model, torch_module, tracing_args
|
|
192
|
+
)
|
|
170
193
|
)
|
|
171
194
|
|
|
172
195
|
def test_convert_gelu(self):
|
|
@@ -176,7 +199,9 @@ class TestConvertComposites(unittest.TestCase):
|
|
|
176
199
|
torch_module = torch.nn.GELU().eval()
|
|
177
200
|
edge_model = ai_edge_torch.convert(torch_module, args)
|
|
178
201
|
|
|
179
|
-
self.assertTrue(
|
|
202
|
+
self.assertTrue(
|
|
203
|
+
model_coverage.compare_tflite_torch(edge_model, torch_module, args)
|
|
204
|
+
)
|
|
180
205
|
|
|
181
206
|
def test_convert_gelu_approximate(self):
|
|
182
207
|
"""Tests conversion of an Approximate GELU module."""
|
|
@@ -185,7 +210,9 @@ class TestConvertComposites(unittest.TestCase):
|
|
|
185
210
|
torch_module = torch.nn.GELU('tanh').eval()
|
|
186
211
|
edge_model = ai_edge_torch.convert(torch_module, args)
|
|
187
212
|
|
|
188
|
-
self.assertTrue(
|
|
213
|
+
self.assertTrue(
|
|
214
|
+
model_coverage.compare_tflite_torch(edge_model, torch_module, args)
|
|
215
|
+
)
|
|
189
216
|
|
|
190
217
|
def test_convert_embedding_lookup(self):
|
|
191
218
|
"""Tests conversion of an Embedding module."""
|
|
@@ -194,7 +221,9 @@ class TestConvertComposites(unittest.TestCase):
|
|
|
194
221
|
torch_module = torch.nn.Embedding(10, 10)
|
|
195
222
|
edge_model = ai_edge_torch.convert(torch_module, args)
|
|
196
223
|
|
|
197
|
-
self.assertTrue(
|
|
224
|
+
self.assertTrue(
|
|
225
|
+
model_coverage.compare_tflite_torch(edge_model, torch_module, args)
|
|
226
|
+
)
|
|
198
227
|
|
|
199
228
|
|
|
200
229
|
if __name__ == '__main__':
|
|
@@ -15,11 +15,10 @@
|
|
|
15
15
|
|
|
16
16
|
import unittest
|
|
17
17
|
|
|
18
|
-
import torch
|
|
19
|
-
import torchvision
|
|
20
|
-
|
|
21
18
|
import ai_edge_torch
|
|
22
19
|
from ai_edge_torch.testing import model_coverage
|
|
20
|
+
import torch
|
|
21
|
+
import torchvision
|
|
23
22
|
|
|
24
23
|
|
|
25
24
|
class TestConvertMultiSignature(unittest.TestCase):
|
|
@@ -41,7 +40,9 @@ class TestConvertMultiSignature(unittest.TestCase):
|
|
|
41
40
|
signature_name, torch_module, large_args
|
|
42
41
|
).convert(torch_module, args)
|
|
43
42
|
|
|
44
|
-
self.assertTrue(
|
|
43
|
+
self.assertTrue(
|
|
44
|
+
model_coverage.compare_tflite_torch(edge_model, torch_module, args)
|
|
45
|
+
)
|
|
45
46
|
self.assertTrue(
|
|
46
47
|
model_coverage.compare_tflite_torch(
|
|
47
48
|
edge_model, torch_module, large_args, signature_name=signature_name
|
|
@@ -74,7 +75,10 @@ class TestConvertMultiSignature(unittest.TestCase):
|
|
|
74
75
|
)
|
|
75
76
|
self.assertTrue(
|
|
76
77
|
model_coverage.compare_tflite_torch(
|
|
77
|
-
edge_model,
|
|
78
|
+
edge_model,
|
|
79
|
+
torch_module,
|
|
80
|
+
large_args,
|
|
81
|
+
signature_name=signature_name_2,
|
|
78
82
|
)
|
|
79
83
|
)
|
|
80
84
|
|
|
@@ -87,11 +91,13 @@ class TestConvertMultiSignature(unittest.TestCase):
|
|
|
87
91
|
|
|
88
92
|
signature_name = "large_input"
|
|
89
93
|
|
|
90
|
-
edge_model = ai_edge_torch.signature(
|
|
91
|
-
torch_module,
|
|
92
|
-
)
|
|
94
|
+
edge_model = ai_edge_torch.signature(
|
|
95
|
+
signature_name, torch_module, args
|
|
96
|
+
).convert(torch_module, large_args)
|
|
93
97
|
|
|
94
|
-
self.assertTrue(
|
|
98
|
+
self.assertTrue(
|
|
99
|
+
model_coverage.compare_tflite_torch(edge_model, torch_module, args)
|
|
100
|
+
)
|
|
95
101
|
self.assertTrue(
|
|
96
102
|
model_coverage.compare_tflite_torch(
|
|
97
103
|
edge_model, torch_module, large_args, signature_name=signature_name
|
|
@@ -110,7 +116,9 @@ class TestConvertMultiSignature(unittest.TestCase):
|
|
|
110
116
|
resnet_signature_name = "resnet18"
|
|
111
117
|
|
|
112
118
|
edge_model = (
|
|
113
|
-
ai_edge_torch.signature(
|
|
119
|
+
ai_edge_torch.signature(
|
|
120
|
+
mobilenet_signature_name, mobilentv2, mobilenet_args
|
|
121
|
+
)
|
|
114
122
|
.signature(resnet_signature_name, resnet18, resnet_args)
|
|
115
123
|
.convert(resnet18, resnet_args)
|
|
116
124
|
)
|
|
@@ -31,7 +31,9 @@ class ChannelLastIOWrapper(nn.Module):
|
|
|
31
31
|
if not torch.is_tensor(x):
|
|
32
32
|
raise ValueError("Input must be a torch tensor")
|
|
33
33
|
if x.ndim < 3:
|
|
34
|
-
raise ValueError(
|
|
34
|
+
raise ValueError(
|
|
35
|
+
"Input must be a tensor with rank >= 3 in layout (N, C, ...)"
|
|
36
|
+
)
|
|
35
37
|
dims = [0, *range(2, x.ndim), 1]
|
|
36
38
|
return torch.permute(x, dims)
|
|
37
39
|
|
|
@@ -39,7 +41,9 @@ class ChannelLastIOWrapper(nn.Module):
|
|
|
39
41
|
if not torch.is_tensor(x):
|
|
40
42
|
raise ValueError("Input must be a torch tensor.")
|
|
41
43
|
if x.ndim < 3:
|
|
42
|
-
raise ValueError(
|
|
44
|
+
raise ValueError(
|
|
45
|
+
"Input must be a tensor with rank >= 3 in layout (N, ..., C)"
|
|
46
|
+
)
|
|
43
47
|
dims = [0, x.ndim - 1, *range(1, x.ndim - 1)]
|
|
44
48
|
return torch.permute(x, dims)
|
|
45
49
|
|