ai-edge-torch-nightly 0.1.dev202405131930__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Note: this release of ai-edge-torch-nightly has been flagged as potentially problematic.

Files changed (91)
  1. ai_edge_torch/__init__.py +30 -0
  2. ai_edge_torch/convert/__init__.py +14 -0
  3. ai_edge_torch/convert/conversion.py +117 -0
  4. ai_edge_torch/convert/conversion_utils.py +330 -0
  5. ai_edge_torch/convert/converter.py +171 -0
  6. ai_edge_torch/convert/fx_passes/__init__.py +59 -0
  7. ai_edge_torch/convert/fx_passes/_pass_base.py +49 -0
  8. ai_edge_torch/convert/fx_passes/build_aten_composite_pass.py +192 -0
  9. ai_edge_torch/convert/fx_passes/build_upsample_bilinear2d_composite_pass.py +84 -0
  10. ai_edge_torch/convert/fx_passes/canonicalize_pass.py +37 -0
  11. ai_edge_torch/convert/fx_passes/inject_mlir_debuginfo_pass.py +73 -0
  12. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/__init__.py +16 -0
  13. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_check.py +215 -0
  14. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_mark.py +48 -0
  15. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/__init__.py +17 -0
  16. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py +59 -0
  17. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py +196 -0
  18. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +400 -0
  19. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/op_func_registry.py +30 -0
  20. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/pass_body.py +286 -0
  21. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/utils.py +62 -0
  22. ai_edge_torch/convert/test/__init__.py +14 -0
  23. ai_edge_torch/convert/test/test_convert.py +273 -0
  24. ai_edge_torch/convert/test/test_convert_composites.py +171 -0
  25. ai_edge_torch/convert/test/test_convert_multisig.py +139 -0
  26. ai_edge_torch/debug/__init__.py +16 -0
  27. ai_edge_torch/debug/culprit.py +423 -0
  28. ai_edge_torch/debug/test/__init__.py +14 -0
  29. ai_edge_torch/debug/test/test_culprit.py +133 -0
  30. ai_edge_torch/debug/utils.py +48 -0
  31. ai_edge_torch/experimental/__init__.py +14 -0
  32. ai_edge_torch/generative/__init__.py +14 -0
  33. ai_edge_torch/generative/examples/__init__.py +14 -0
  34. ai_edge_torch/generative/examples/gemma/__init__.py +14 -0
  35. ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +66 -0
  36. ai_edge_torch/generative/examples/gemma/gemma.py +174 -0
  37. ai_edge_torch/generative/examples/phi2/__init__.py +14 -0
  38. ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +64 -0
  39. ai_edge_torch/generative/examples/phi2/phi2.py +164 -0
  40. ai_edge_torch/generative/examples/t5/__init__.py +14 -0
  41. ai_edge_torch/generative/examples/t5/convert_to_tflite.py +135 -0
  42. ai_edge_torch/generative/examples/t5/t5.py +608 -0
  43. ai_edge_torch/generative/examples/t5/t5_attention.py +255 -0
  44. ai_edge_torch/generative/examples/test_models/__init__.py +14 -0
  45. ai_edge_torch/generative/examples/test_models/toy_model.py +119 -0
  46. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +143 -0
  47. ai_edge_torch/generative/examples/tiny_llama/__init__.py +0 -0
  48. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +66 -0
  49. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +164 -0
  50. ai_edge_torch/generative/layers/__init__.py +14 -0
  51. ai_edge_torch/generative/layers/attention.py +288 -0
  52. ai_edge_torch/generative/layers/attention_utils.py +169 -0
  53. ai_edge_torch/generative/layers/builder.py +103 -0
  54. ai_edge_torch/generative/layers/feed_forward.py +95 -0
  55. ai_edge_torch/generative/layers/kv_cache.py +83 -0
  56. ai_edge_torch/generative/layers/model_config.py +135 -0
  57. ai_edge_torch/generative/layers/normalization.py +62 -0
  58. ai_edge_torch/generative/layers/rotary_position_embedding.py +36 -0
  59. ai_edge_torch/generative/quantize/__init__.py +14 -0
  60. ai_edge_torch/generative/quantize/example.py +45 -0
  61. ai_edge_torch/generative/quantize/quant_attrs.py +66 -0
  62. ai_edge_torch/generative/quantize/quant_recipe.py +106 -0
  63. ai_edge_torch/generative/quantize/quant_recipe_utils.py +51 -0
  64. ai_edge_torch/generative/quantize/quant_recipes.py +48 -0
  65. ai_edge_torch/generative/quantize/supported_schemes.py +31 -0
  66. ai_edge_torch/generative/test/__init__.py +14 -0
  67. ai_edge_torch/generative/test/test_model_conversion.py +201 -0
  68. ai_edge_torch/generative/test/test_quantize.py +109 -0
  69. ai_edge_torch/generative/utilities/__init__.py +15 -0
  70. ai_edge_torch/generative/utilities/loader.py +290 -0
  71. ai_edge_torch/generative/utilities/t5_loader.py +467 -0
  72. ai_edge_torch/hlfb/__init__.py +16 -0
  73. ai_edge_torch/hlfb/mark_pattern/__init__.py +139 -0
  74. ai_edge_torch/hlfb/mark_pattern/passes.py +42 -0
  75. ai_edge_torch/hlfb/mark_pattern/pattern.py +260 -0
  76. ai_edge_torch/hlfb/test/__init__.py +14 -0
  77. ai_edge_torch/hlfb/test/test_mark_pattern.py +133 -0
  78. ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py +270 -0
  79. ai_edge_torch/model.py +134 -0
  80. ai_edge_torch/quantize/__init__.py +16 -0
  81. ai_edge_torch/quantize/pt2e_quantizer.py +438 -0
  82. ai_edge_torch/quantize/pt2e_quantizer_utils.py +1041 -0
  83. ai_edge_torch/quantize/quant_config.py +85 -0
  84. ai_edge_torch/testing/__init__.py +14 -0
  85. ai_edge_torch/testing/model_coverage/__init__.py +16 -0
  86. ai_edge_torch/testing/model_coverage/model_coverage.py +126 -0
  87. ai_edge_torch_nightly-0.1.dev202405131930.dist-info/LICENSE +202 -0
  88. ai_edge_torch_nightly-0.1.dev202405131930.dist-info/METADATA +38 -0
  89. ai_edge_torch_nightly-0.1.dev202405131930.dist-info/RECORD +91 -0
  90. ai_edge_torch_nightly-0.1.dev202405131930.dist-info/WHEEL +5 -0
  91. ai_edge_torch_nightly-0.1.dev202405131930.dist-info/top_level.txt +1 -0
ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/utils.py
@@ -0,0 +1,62 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+ from typing import Callable
+
+ import torch
+ import torch.ao.quantization.quantize_pt2e
+
+
+ def tensor_to_nhwc(t: torch.Tensor):
+   return torch.ops.aten.permute(t.contiguous(), [0, 2, 3, 1]).contiguous()
+
+
+ def tensor_to_nchw(t: torch.Tensor):
+   return torch.ops.aten.permute(t.contiguous(), [0, 3, 1, 2]).contiguous()
+
+
+ def flatten_torch_op_overloads(op):
+   if isinstance(op, torch._ops.OpOverloadPacket):
+     return [getattr(op, overload) for overload in op.overloads()]
+   return [op]
+
+
+ _TORCH_Q_OPS = [
+     torch.ops.quantized_decomposed.quantize_per_tensor.default,
+     torch.ops.quantized_decomposed.quantize_per_tensor.tensor,
+     torch.ops.quantized_decomposed.quantize_per_tensor.tensor2,
+     torch.ops.quantized_decomposed.quantize_per_channel.default,
+ ]
+
+ _TORCH_DQ_OPS = [
+     torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+     torch.ops.quantized_decomposed.dequantize_per_tensor.tensor,
+     torch.ops.quantized_decomposed.dequantize_per_tensor.tensor2,
+     torch.ops.quantized_decomposed.dequantize_per_channel.default,
+ ]
+
+
+ def is_q_node(node: torch.fx.Node):
+   return node.target in _TORCH_Q_OPS
+
+
+ def is_dq_node(node: torch.fx.Node):
+   return node.target in _TORCH_DQ_OPS
+
+
+ def get_paired_q_dq_ops(op: Callable) -> tuple[Callable, Callable]:
+   for q, dq in zip(_TORCH_Q_OPS, _TORCH_DQ_OPS):
+     if op in (q, dq):
+       return q, dq
+   raise AssertionError(f"{op} is not a Q/DQ op.")
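
For orientation, here is a minimal usage sketch of the helpers in this hunk. It assumes, based on the file list and the +62 line count, that the hunk is optimize_layout_transposes_pass/utils.py, so the import path below is inferred rather than taken from the diff itself:

import torch

from ai_edge_torch.convert.fx_passes.optimize_layout_transposes_pass import utils

x = torch.randn(1, 3, 8, 8)     # NCHW activation
nhwc = utils.tensor_to_nhwc(x)  # permuted to (1, 8, 8, 3)
# The two permutes are inverses, so a round trip restores the original tensor.
assert torch.equal(utils.tensor_to_nchw(nhwc), x)

# Look up the dequantize op that pairs with a given quantize op.
q, dq = utils.get_paired_q_dq_ops(
    torch.ops.quantized_decomposed.quantize_per_tensor.default
)
assert dq is torch.ops.quantized_decomposed.dequantize_per_tensor.default
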
ai_edge_torch/convert/test/__init__.py
@@ -0,0 +1,14 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
ai_edge_torch/convert/test/test_convert.py
@@ -0,0 +1,273 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+
+ import os
+ import tempfile
+ import unittest
+
+ import torch
+ import torchvision
+
+ import ai_edge_torch
+ from ai_edge_torch.convert import conversion_utils as cutils
+ from ai_edge_torch.testing import model_coverage
+
+
+ class TestConvert(unittest.TestCase):
+   """Tests conversion of various modules."""
+
+   def setUp(self):
+     torch.manual_seed(0)
+
+   def test_convert_add(self):
+     """Tests conversion of a simple Add module."""
+
+     class Add(torch.nn.Module):
+
+       def forward(self, a, b):
+         return a + b
+
+     args = (
+         torch.randn((5, 10)),
+         torch.randn((5, 10)),
+     )
+     torch_module = Add().eval()
+     edge_model = ai_edge_torch.convert(torch_module, args)
+
+     self.assertTrue(model_coverage.compare_tflite_torch(edge_model, torch_module, args))
+
+   def test_convert_dot_add(self):
+     class DotAdd(torch.nn.Module):
+       """Tests conversion of a matrix multiplication followed by an add."""
+
+       def forward(self, a, b, c):
+         return a @ b + c
+
+     args = (
+         torch.randn((5, 10)),
+         torch.randn((10, 5)),
+         torch.randn((5, 5)),
+     )
+     torch_module = DotAdd().eval()
+     edge_model = ai_edge_torch.convert(torch_module, args)
+
+     self.assertTrue(model_coverage.compare_tflite_torch(edge_model, torch_module, args))
+
+   def test_convert_resnet18(self):
+     args = (torch.randn(4, 3, 224, 224),)
+     torch_module = torchvision.models.resnet18().eval()
+     edge_model = ai_edge_torch.convert(torch_module, args)
+
+     self.assertTrue(model_coverage.compare_tflite_torch(edge_model, torch_module, args))
+
+   def test_signature_args_ordering(self):
+     """Tests conversion of a model with more than 10 arguments."""
+
+     class AddChainWith11Args(torch.nn.Module):
+
+       def forward(
+           self,
+           arg0: "f32[64]",
+           arg1: "f32[64]",
+           arg2: "f32[64]",
+           arg3: "f32[64]",
+           arg4: "f32[64]",
+           arg5: "f32[64]",
+           arg6: "f32[64]",
+           arg7: "f32[64]",
+           arg8: "f32[64]",
+           arg9: "f32[64]",
+           arg10: "f32[64]",
+       ):
+         add0 = torch.add(arg0, arg1)
+         add1 = torch.add(add0, arg2)
+         add2 = torch.add(add1, arg3)
+         add3 = torch.add(add2, arg4)
+         add4 = torch.add(add3, arg5)
+         add5 = torch.add(add4, arg6)
+         add6 = torch.add(add5, arg7)
+         add7 = torch.add(add6, arg8)
+         add8 = torch.add(add7, arg9)
+         add9 = torch.add(add8, arg10)
+         return add9
+
+     sample_input = lambda: (
+         torch.rand((64,), dtype=torch.float32),
+         torch.rand((64,), dtype=torch.float32),
+         torch.rand((64,), dtype=torch.float32),
+         torch.rand((64,), dtype=torch.float32),
+         torch.rand((64,), dtype=torch.float32),
+         torch.rand((64,), dtype=torch.float32),
+         torch.rand((64,), dtype=torch.float32),
+         torch.rand((64,), dtype=torch.float32),
+         torch.rand((64,), dtype=torch.float32),
+         torch.rand((64,), dtype=torch.float32),
+         torch.rand((64,), dtype=torch.float32),
+     )
+     torch_model = AddChainWith11Args().eval()
+     edge_model = ai_edge_torch.convert(torch_model, sample_input())
+
+     result = model_coverage.compare_tflite_torch(
+         edge_model, torch_model, sample_input, num_valid_inputs=10
+     )
+     self.assertTrue(result)
+
+   def test_multi_output_model(self):
+     """Tests conversion of a model that returns multiple outputs."""
+
+     class BasicAddModelWithMultipleOutputs(torch.nn.Module):
+
+       def forward(self, arg0, arg1):
+         add0 = arg0 + arg1
+         mul0 = arg0 * arg1
+         return add0, mul0
+
+     sample_input = (
+         torch.rand((64,), dtype=torch.float32),
+         torch.rand((64,), dtype=torch.float32),
+     )
+
+     torch_model = BasicAddModelWithMultipleOutputs().eval()
+     edge_model = ai_edge_torch.convert(torch_model, sample_input)
+
+     result = model_coverage.compare_tflite_torch(edge_model, torch_model, sample_input)
+     self.assertTrue(result)
+
+   def test_12_outputs_model(self):
+     """Tests conversion of a model that returns multiple outputs."""
+
+     class BasicAddModelWithMultipleOutputs(torch.nn.Module):
+
+       def forward(self, arg0, arg1):
+         add0 = arg0 + arg1
+         mul0 = arg0 * arg1
+         add1 = add0 + mul0
+         mul1 = add0 * mul0
+         add2 = add1 + mul1
+         mul2 = add1 * mul1
+         add3 = add2 + mul2
+         mul3 = add2 * mul2
+         add4 = add3 + mul3
+         mul4 = add3 * mul3
+         add5 = add4 + mul4
+         mul5 = add4 * mul4
+
+         return (
+             add0,
+             mul0,
+             add1,
+             mul1,
+             add2,
+             mul2,
+             add3,
+             mul3,
+             add4,
+             mul4,
+             add5,
+             mul5,
+         )
+
+     sample_input = (
+         torch.rand((64,), dtype=torch.float32),
+         torch.rand((64,), dtype=torch.float32),
+     )
+
+     torch_model = BasicAddModelWithMultipleOutputs().eval()
+     edge_model = ai_edge_torch.convert(torch_model, sample_input)
+
+     result = model_coverage.compare_tflite_torch(edge_model, torch_model, sample_input)
+     self.assertTrue(result)
+
+   def test_apply_tfl_backdoor_flags(self):
+     """Tests if _apply_tfl_backdoor_flags correctly sets the values in a Converter object."""
+
+     class MockConverterInternalObject:
+
+       def __init__(self):
+         self.subkey2 = "original_subvalue2"
+
+     class MockConverter:
+
+       def __init__(self):
+         self.key1 = "original_value1"
+         self.key2 = MockConverterInternalObject()
+
+     mock_converter = MockConverter()
+     flags = {"key1": "new_value1", "key2": {"subkey2": "new_subvalue2"}}
+     cutils._apply_tfl_backdoor_flags(mock_converter, flags)
+
+     self.assertEqual(mock_converter.key1, "new_value1")
+     self.assertEqual(mock_converter.key2.subkey2, "new_subvalue2")
+
+   @unittest.skip("https://b.corp.google.com/issues/331463544")
+   def test_convert_add_backdoor_flags(self):
+     """Tests conversion of an add module setting a tflite converter flag."""
+
+     class Add(torch.nn.Module):
+
+       def forward(self, a, b):
+         return a + b
+
+     args = (
+         torch.randn((5, 10)),
+         torch.randn((5, 10)),
+     )
+     torch_module = Add().eval()
+
+     with tempfile.TemporaryDirectory() as tmp_dir_path:
+       mlir_dump_path = os.path.join(
+           tmp_dir_path, "test_convert_add_backdoor_flags_mlir_dump"
+       )
+       ai_edge_torch.convert(
+           torch_module, args, _ai_edge_converter_flags={"mlir_dump_dir": mlir_dump_path}
+       )
+       self.assertTrue(os.path.isdir(mlir_dump_path))
+
+   def test_convert_model_with_dynamic_batch(self):
+     """
+     Test converting a simple model with dynamic batch size.
+     """
+
+     class SampleModel(torch.nn.Module):
+
+       def __init__(self):
+         super().__init__()
+         self.w = torch.ones((10, 10)) * 2.7
+
+       def forward(self, x, y):
+         return x + y + self.w
+
+     sample_input = (torch.randn(4, 3, 10, 10), torch.randn(4, 3, 10, 10))
+     batch = torch.export.Dim("batch")
+     dynamic_shapes = ({0: batch}, {0: batch})
+
+     model = SampleModel().eval()
+     edge_model = ai_edge_torch.convert(
+         model, sample_input, dynamic_shapes=dynamic_shapes
+     )
+
+     for batch_size in [2, 4, 10]:
+       validate_input = (
+           torch.randn(batch_size, 3, 10, 10),
+           torch.randn(batch_size, 3, 10, 10),
+       )
+       self.assertTrue(
+           model_coverage.compare_tflite_torch(edge_model, model, validate_input)
+       )
+
+
+ if __name__ == "__main__":
+   unittest.main()
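
The tests above all exercise the same basic flow: convert a torch.nn.Module with sample inputs (optionally marking dimensions as dynamic) and compare the TFLite output against PyTorch. A condensed sketch of that flow, using only calls shown in this test file; the module below is an illustrative stand-in, not part of the package:

import torch

import ai_edge_torch
from ai_edge_torch.testing import model_coverage


class AddOffset(torch.nn.Module):  # illustrative stand-in for SampleModel above

  def __init__(self):
    super().__init__()
    self.w = torch.ones((10, 10)) * 2.7

  def forward(self, x, y):
    return x + y + self.w


model = AddOffset().eval()
sample_input = (torch.randn(4, 3, 10, 10), torch.randn(4, 3, 10, 10))

# Mark dim 0 of both inputs as dynamic so one TFLite model serves any batch size.
batch = torch.export.Dim("batch")
edge_model = ai_edge_torch.convert(
    model, sample_input, dynamic_shapes=({0: batch}, {0: batch})
)

# Validate numerically against PyTorch at a batch size not seen during conversion.
new_batch_input = (torch.randn(7, 3, 10, 10), torch.randn(7, 3, 10, 10))
assert model_coverage.compare_tflite_torch(edge_model, model, new_batch_input)
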
ai_edge_torch/convert/test/test_convert_composites.py
@@ -0,0 +1,171 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+
+ from typing import Callable
+ import unittest
+
+ import parameterized
+ import torch
+
+ import ai_edge_torch
+ from ai_edge_torch.testing import model_coverage
+
+
+ def _func_to_torch_module(func: Callable):
+   class TestModule(torch.nn.Module):
+
+     def __init__(self, func):
+       super().__init__()
+       self._func = func
+
+     def forward(self, *args, **kwargs):
+       return self._func(*args, **kwargs)
+
+   return TestModule(func).eval()
+
+
+ class TestConvertComposites(unittest.TestCase):
+   """Tests conversion modules that are meant to be wrapped as composites."""
+
+   def test_convert_hardswish(self):
+     """Tests conversion of a HardSwish module."""
+
+     args = (torch.randn((5, 10)),)
+     torch_module = torch.nn.Hardswish().eval()
+     edge_model = ai_edge_torch.convert(torch_module, args)
+
+     self.assertTrue(model_coverage.compare_tflite_torch(edge_model, torch_module, args))
+
+   @parameterized.parameterized.expand(
+       [
+           # no padding, stride = 1
+           ([1, 3, 6, 6], [3, 3], [1, 1], [0, 0], False, True, None),
+           # add stride
+           ([1, 3, 6, 6], [3, 3], [2, 2], [0, 0], False, True, None),
+           # default values
+           ([1, 3, 6, 6], [3, 3]),
+           # add padding
+           ([1, 3, 6, 6], [3, 3], [1, 1], [1, 1], False, True, None),
+           # add different padding for different dims
+           ([1, 3, 6, 6], [3, 3], [1, 1], [0, 1], False, True, None),
+           # add both stride and padding
+           ([1, 3, 6, 6], [3, 3], [2, 2], [1, 1], False, True, None),
+           # count_include_pad = False
+           ([1, 3, 6, 6], [3, 3], [1, 1], [1, 1], False, False, None),
+           # ceil_mode = True
+           ([1, 3, 6, 6], [3, 3], [1, 1], [1, 1], True, True, None),
+           # set divisor_override
+           ([1, 3, 6, 6], [3, 3], [1, 1], 0, False, True, 6),
+           # padding set to one number
+           ([1, 3, 6, 6], [3, 3], [1, 1], 1, False, True, None),
+       ]
+   )
+   def test_convert_avg_pool2d(self, input_size, *args):
+     """Tests conversion of a module containing an avg_pool2d aten."""
+     torch_module = _func_to_torch_module(
+         lambda input_tensor: torch.ops.aten.avg_pool2d(input_tensor, *args)
+     )
+     tracing_args = (torch.randn(*input_size),)
+     edge_model = ai_edge_torch.convert(torch_module, tracing_args)
+
+     self.assertTrue(
+         model_coverage.compare_tflite_torch(edge_model, torch_module, tracing_args)
+     )
+
+   @parameterized.parameterized.expand(
+       [
+           # use scale_factor with align_corners=False
+           (
+               [1, 3, 10, 10],
+               dict(scale_factor=3.0, mode='bilinear', align_corners=False),
+           ),
+           # use scale_factor with align_corners=true
+           ([1, 3, 10, 10], dict(scale_factor=3.0, mode='bilinear', align_corners=True)),
+           # use size
+           ([1, 3, 10, 10], dict(size=[15, 20], mode='bilinear')),
+           # use size with align_corners=true
+           ([1, 3, 10, 10], dict(size=[15, 20], mode='bilinear', align_corners=True)),
+       ]
+   )
+   def test_convert_upsample_bilinear_functional(self, input_size, kwargs):
+     """Tests conversion of a torch.nn.functional.upsample module."""
+     torch_module = _func_to_torch_module(
+         lambda input_tensor: torch.nn.functional.upsample(input_tensor, **kwargs)
+     )
+     tracing_args = (torch.randn(*input_size),)
+     edge_model = ai_edge_torch.convert(torch_module, tracing_args)
+
+     self.assertTrue(
+         model_coverage.compare_tflite_torch(edge_model, torch_module, tracing_args)
+     )
+
+   @parameterized.parameterized.expand(
+       [
+           # use scale_factor with align_corners=False
+           (
+               [1, 3, 10, 10],
+               dict(scale_factor=3.0, mode='bilinear', align_corners=False),
+           ),
+           # use scale_factor with align_corners=true
+           ([1, 3, 10, 10], dict(scale_factor=3.0, mode='bilinear', align_corners=True)),
+           # use size
+           ([1, 3, 10, 10], dict(size=[15, 20], mode='bilinear')),
+           # use size with align_corners=true
+           ([1, 3, 10, 10], dict(size=[15, 20], mode='bilinear', align_corners=True)),
+       ]
+   )
+   def test_convert_upsample_bilinear(self, input_size, kwargs):
+     """Tests conversion of a torch.nn.Upsample module."""
+     torch_module = _func_to_torch_module(
+         lambda input_tensor: torch.nn.Upsample(**kwargs)(input_tensor)
+     )
+     tracing_args = (torch.randn(*input_size),)
+     edge_model = ai_edge_torch.convert(torch_module, tracing_args)
+
+     self.assertTrue(
+         model_coverage.compare_tflite_torch(edge_model, torch_module, tracing_args)
+     )
+
+   @parameterized.parameterized.expand(
+       [
+           # use scale_factor with align_corners=False
+           (
+               [1, 3, 10, 10],
+               dict(scale_factor=3.0, mode='bilinear', align_corners=False),
+           ),
+           # use scale_factor with align_corners=true
+           ([1, 3, 10, 10], dict(scale_factor=3.0, mode='bilinear', align_corners=True)),
+           # use size
+           ([1, 3, 10, 10], dict(size=[15, 20], mode='bilinear')),
+           # use size with align_corners=true
+           ([1, 3, 10, 10], dict(size=[15, 20], mode='bilinear', align_corners=True)),
+       ]
+   )
+   def test_convert_interpolate_bilinear_functional(self, input_size, kwargs):
+     """Tests conversion of a torch.nn.functional.interpolate module."""
+     torch_module = _func_to_torch_module(
+         lambda input_tensor: torch.nn.functional.interpolate(input_tensor, **kwargs)
+     )
+     tracing_args = (torch.randn(*input_size),)
+     edge_model = ai_edge_torch.convert(torch_module, tracing_args)
+
+     self.assertTrue(
+         model_coverage.compare_tflite_torch(edge_model, torch_module, tracing_args)
+     )
+
+
+ if __name__ == '__main__':
+   unittest.main()
ai_edge_torch/convert/test/test_convert_multisig.py
@@ -0,0 +1,139 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ import unittest
+
+ import torch
+ import torchvision
+
+ import ai_edge_torch
+ from ai_edge_torch.testing import model_coverage
+
+
+ class TestConvertMultiSignature(unittest.TestCase):
+   """Tests conversion of various modules through multi-signature conversion."""
+
+   def setUp(self):
+     torch.manual_seed(0)
+
+   def test_convert_mobilenet_v2_with_default(self):
+     """Tests conversion of a model with two signatures one of which is the default."""
+     torch_module = torchvision.models.mobilenet_v2().eval()
+
+     args = (torch.randn(4, 3, 224, 224),)
+     large_args = (torch.randn(4, 3, 336, 336),)
+
+     signature_name = "large_input"
+
+     edge_model = ai_edge_torch.signature(
+         signature_name, torch_module, large_args
+     ).convert(torch_module, args)
+
+     self.assertTrue(model_coverage.compare_tflite_torch(edge_model, torch_module, args))
+     self.assertTrue(
+         model_coverage.compare_tflite_torch(
+             edge_model, torch_module, large_args, signature_name=signature_name
+         )
+     )
+
+   def test_convert_mobilenet_v2_no_default(self):
+     """Tests conversion of a model with two signatures none of which is the default."""
+     torch_module = torchvision.models.mobilenet_v2().eval()
+
+     args = (torch.randn(4, 3, 224, 224),)
+     large_args = (torch.randn(4, 3, 336, 336),)
+
+     signature_name_1 = "input"
+     signature_name_2 = "large_input"
+
+     edge_model = (
+         ai_edge_torch.signature(signature_name_1, torch_module, args)
+         .signature(signature_name_2, torch_module, large_args)
+         .convert()
+     )
+
+     with self.assertRaises(ValueError):
+       edge_model(*args)
+
+     self.assertTrue(
+         model_coverage.compare_tflite_torch(
+             edge_model, torch_module, args, signature_name=signature_name_1
+         )
+     )
+     self.assertTrue(
+         model_coverage.compare_tflite_torch(
+             edge_model, torch_module, large_args, signature_name=signature_name_2
+         )
+     )
+
+   def test_convert_mobilenet_v2_signature_helper(self):
+     """Tests the ai_edge_torch.signature helper function works."""
+     torch_module = torchvision.models.mobilenet_v2().eval()
+
+     args = (torch.randn(4, 3, 224, 224),)
+     large_args = (torch.randn(4, 3, 336, 336),)
+
+     signature_name = "large_input"
+
+     edge_model = ai_edge_torch.signature(signature_name, torch_module, args).convert(
+         torch_module, large_args
+     )
+
+     self.assertTrue(model_coverage.compare_tflite_torch(edge_model, torch_module, args))
+     self.assertTrue(
+         model_coverage.compare_tflite_torch(
+             edge_model, torch_module, large_args, signature_name=signature_name
+         )
+     )
+
+   def test_convert_separate_modules(self):
+     """Tests conversion of two completely different modules as separate signatures."""
+     mobilentv2 = torchvision.models.mobilenet_v2().eval()
+     resnet18 = torchvision.models.resnet18().eval()
+
+     mobilenet_args = (torch.randn(4, 3, 224, 224),)
+     resnet_args = (torch.randn(4, 3, 224, 224),)
+
+     mobilenet_signature_name = "mobilentv2"
+     resnet_signature_name = "resnet18"
+
+     edge_model = (
+         ai_edge_torch.signature(mobilenet_signature_name, mobilentv2, mobilenet_args)
+         .signature(resnet_signature_name, resnet18, resnet_args)
+         .convert(resnet18, resnet_args)
+     )
+
+     mobilenet_inference_args = (torch.randn(4, 3, 224, 224),)
+     resnet_inference_args = (torch.randn(4, 3, 224, 224),)
+     self.assertTrue(
+         model_coverage.compare_tflite_torch(
+             edge_model,
+             mobilentv2,
+             mobilenet_inference_args,
+             signature_name=mobilenet_signature_name,
+         )
+     )
+     self.assertTrue(
+         model_coverage.compare_tflite_torch(
+             edge_model,
+             resnet18,
+             resnet_inference_args,
+             signature_name=resnet_signature_name,
+         )
+     )
+
+
+ if __name__ == "__main__":
+   unittest.main()
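
A condensed sketch of the multi-signature API exercised above: chain named signature() calls, finish with convert(), and validate each signature by name. Only calls that appear in this test file are used; the signature names are illustrative:

import torch
import torchvision

import ai_edge_torch
from ai_edge_torch.testing import model_coverage

model = torchvision.models.mobilenet_v2().eval()
args = (torch.randn(4, 3, 224, 224),)
large_args = (torch.randn(4, 3, 336, 336),)

# Two named signatures over the same module; convert() is called without a
# default signature, so calling edge_model(*args) directly would raise
# ValueError (see the no-default test above).
edge_model = (
    ai_edge_torch.signature("input", model, args)
    .signature("large_input", model, large_args)
    .convert()
)

# Each signature is compared against PyTorch by name.
assert model_coverage.compare_tflite_torch(
    edge_model, model, args, signature_name="input"
)
assert model_coverage.compare_tflite_torch(
    edge_model, model, large_args, signature_name="large_input"
)
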
ai_edge_torch/debug/__init__.py
@@ -0,0 +1,16 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ from .culprit import find_culprits