ai-edge-torch-nightly 0.2.0.dev20240714__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of ai-edge-torch-nightly might be problematic.

Files changed (121)
  1. ai_edge_torch/__init__.py +31 -0
  2. ai_edge_torch/convert/__init__.py +14 -0
  3. ai_edge_torch/convert/conversion.py +117 -0
  4. ai_edge_torch/convert/conversion_utils.py +400 -0
  5. ai_edge_torch/convert/converter.py +202 -0
  6. ai_edge_torch/convert/fx_passes/__init__.py +59 -0
  7. ai_edge_torch/convert/fx_passes/_pass_base.py +49 -0
  8. ai_edge_torch/convert/fx_passes/build_aten_composite_pass.py +225 -0
  9. ai_edge_torch/convert/fx_passes/build_interpolate_composite_pass.py +123 -0
  10. ai_edge_torch/convert/fx_passes/canonicalize_pass.py +37 -0
  11. ai_edge_torch/convert/fx_passes/inject_mlir_debuginfo_pass.py +73 -0
  12. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/__init__.py +16 -0
  13. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_check.py +215 -0
  14. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_mark.py +48 -0
  15. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/__init__.py +17 -0
  16. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py +59 -0
  17. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py +215 -0
  18. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +400 -0
  19. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/op_func_registry.py +30 -0
  20. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/pass_body.py +293 -0
  21. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/utils.py +62 -0
  22. ai_edge_torch/convert/test/__init__.py +14 -0
  23. ai_edge_torch/convert/test/test_convert.py +311 -0
  24. ai_edge_torch/convert/test/test_convert_composites.py +192 -0
  25. ai_edge_torch/convert/test/test_convert_multisig.py +139 -0
  26. ai_edge_torch/convert/test/test_to_channel_last_io.py +96 -0
  27. ai_edge_torch/convert/to_channel_last_io.py +85 -0
  28. ai_edge_torch/debug/__init__.py +17 -0
  29. ai_edge_torch/debug/culprit.py +464 -0
  30. ai_edge_torch/debug/test/__init__.py +14 -0
  31. ai_edge_torch/debug/test/test_culprit.py +133 -0
  32. ai_edge_torch/debug/test/test_search_model.py +50 -0
  33. ai_edge_torch/debug/utils.py +48 -0
  34. ai_edge_torch/experimental/__init__.py +14 -0
  35. ai_edge_torch/generative/__init__.py +14 -0
  36. ai_edge_torch/generative/examples/__init__.py +14 -0
  37. ai_edge_torch/generative/examples/gemma/__init__.py +14 -0
  38. ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +66 -0
  39. ai_edge_torch/generative/examples/gemma/gemma.py +174 -0
  40. ai_edge_torch/generative/examples/phi2/__init__.py +14 -0
  41. ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +64 -0
  42. ai_edge_torch/generative/examples/phi2/phi2.py +164 -0
  43. ai_edge_torch/generative/examples/stable_diffusion/__init__.py +14 -0
  44. ai_edge_torch/generative/examples/stable_diffusion/attention.py +106 -0
  45. ai_edge_torch/generative/examples/stable_diffusion/clip.py +115 -0
  46. ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +142 -0
  47. ai_edge_torch/generative/examples/stable_diffusion/decoder.py +317 -0
  48. ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +573 -0
  49. ai_edge_torch/generative/examples/stable_diffusion/encoder.py +118 -0
  50. ai_edge_torch/generative/examples/stable_diffusion/pipeline.py +222 -0
  51. ai_edge_torch/generative/examples/stable_diffusion/samplers/__init__.py +19 -0
  52. ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler.py +61 -0
  53. ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler_ancestral.py +65 -0
  54. ai_edge_torch/generative/examples/stable_diffusion/samplers/k_lms.py +73 -0
  55. ai_edge_torch/generative/examples/stable_diffusion/samplers/sampler.py +38 -0
  56. ai_edge_torch/generative/examples/stable_diffusion/tokenizer.py +108 -0
  57. ai_edge_torch/generative/examples/stable_diffusion/util.py +71 -0
  58. ai_edge_torch/generative/examples/t5/__init__.py +14 -0
  59. ai_edge_torch/generative/examples/t5/convert_to_tflite.py +135 -0
  60. ai_edge_torch/generative/examples/t5/t5.py +608 -0
  61. ai_edge_torch/generative/examples/t5/t5_attention.py +231 -0
  62. ai_edge_torch/generative/examples/test_models/__init__.py +14 -0
  63. ai_edge_torch/generative/examples/test_models/toy_model.py +122 -0
  64. ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +161 -0
  65. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +143 -0
  66. ai_edge_torch/generative/examples/tiny_llama/__init__.py +0 -0
  67. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +66 -0
  68. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +164 -0
  69. ai_edge_torch/generative/fx_passes/__init__.py +31 -0
  70. ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py +47 -0
  71. ai_edge_torch/generative/layers/__init__.py +14 -0
  72. ai_edge_torch/generative/layers/attention.py +354 -0
  73. ai_edge_torch/generative/layers/attention_utils.py +169 -0
  74. ai_edge_torch/generative/layers/builder.py +131 -0
  75. ai_edge_torch/generative/layers/feed_forward.py +95 -0
  76. ai_edge_torch/generative/layers/kv_cache.py +83 -0
  77. ai_edge_torch/generative/layers/model_config.py +158 -0
  78. ai_edge_torch/generative/layers/normalization.py +62 -0
  79. ai_edge_torch/generative/layers/rotary_position_embedding.py +36 -0
  80. ai_edge_torch/generative/layers/scaled_dot_product_attention.py +117 -0
  81. ai_edge_torch/generative/layers/unet/__init__.py +14 -0
  82. ai_edge_torch/generative/layers/unet/blocks_2d.py +711 -0
  83. ai_edge_torch/generative/layers/unet/builder.py +47 -0
  84. ai_edge_torch/generative/layers/unet/model_config.py +269 -0
  85. ai_edge_torch/generative/quantize/__init__.py +14 -0
  86. ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/__init__.py +0 -0
  87. ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/translate_recipe.py +148 -0
  88. ai_edge_torch/generative/quantize/example.py +45 -0
  89. ai_edge_torch/generative/quantize/quant_attrs.py +68 -0
  90. ai_edge_torch/generative/quantize/quant_recipe.py +151 -0
  91. ai_edge_torch/generative/quantize/quant_recipe_utils.py +51 -0
  92. ai_edge_torch/generative/quantize/quant_recipes.py +48 -0
  93. ai_edge_torch/generative/quantize/supported_schemes.py +32 -0
  94. ai_edge_torch/generative/test/__init__.py +14 -0
  95. ai_edge_torch/generative/test/loader_test.py +80 -0
  96. ai_edge_torch/generative/test/test_model_conversion.py +235 -0
  97. ai_edge_torch/generative/test/test_quantize.py +162 -0
  98. ai_edge_torch/generative/utilities/__init__.py +15 -0
  99. ai_edge_torch/generative/utilities/loader.py +328 -0
  100. ai_edge_torch/generative/utilities/stable_diffusion_loader.py +924 -0
  101. ai_edge_torch/generative/utilities/t5_loader.py +483 -0
  102. ai_edge_torch/hlfb/__init__.py +16 -0
  103. ai_edge_torch/hlfb/mark_pattern/__init__.py +139 -0
  104. ai_edge_torch/hlfb/mark_pattern/passes.py +42 -0
  105. ai_edge_torch/hlfb/mark_pattern/pattern.py +273 -0
  106. ai_edge_torch/hlfb/test/__init__.py +14 -0
  107. ai_edge_torch/hlfb/test/test_mark_pattern.py +133 -0
  108. ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py +270 -0
  109. ai_edge_torch/model.py +142 -0
  110. ai_edge_torch/quantize/__init__.py +16 -0
  111. ai_edge_torch/quantize/pt2e_quantizer.py +438 -0
  112. ai_edge_torch/quantize/pt2e_quantizer_utils.py +1041 -0
  113. ai_edge_torch/quantize/quant_config.py +81 -0
  114. ai_edge_torch/testing/__init__.py +14 -0
  115. ai_edge_torch/testing/model_coverage/__init__.py +16 -0
  116. ai_edge_torch/testing/model_coverage/model_coverage.py +132 -0
  117. ai_edge_torch_nightly-0.2.0.dev20240714.dist-info/LICENSE +202 -0
  118. ai_edge_torch_nightly-0.2.0.dev20240714.dist-info/METADATA +38 -0
  119. ai_edge_torch_nightly-0.2.0.dev20240714.dist-info/RECORD +121 -0
  120. ai_edge_torch_nightly-0.2.0.dev20240714.dist-info/WHEEL +5 -0
  121. ai_edge_torch_nightly-0.2.0.dev20240714.dist-info/top_level.txt +1 -0
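
The hunks below cover the consecutive files ai_edge_torch/convert/test/test_convert_composites.py through ai_edge_torch/debug/__init__.py from the list above. As a quick orientation before the raw diffs, the following sketch shows the conversion and multi-signature APIs exactly as the diffed tests exercise them; the MobileNetV2 model and input shapes are taken from test_convert_multisig.py, and no call goes beyond what appears verbatim in the diff.

import torch
import torchvision

import ai_edge_torch
from ai_edge_torch.testing import model_coverage

torch_module = torchvision.models.mobilenet_v2().eval()
args = (torch.randn(4, 3, 224, 224),)

# Single-signature conversion, as in test_convert_composites.py.
edge_model = ai_edge_torch.convert(torch_module, args)

# Multi-signature conversion, as in test_convert_multisig.py: register a named
# signature, then convert with the default signature's example inputs.
large_args = (torch.randn(4, 3, 336, 336),)
multi_sig_model = ai_edge_torch.signature(
    "large_input", torch_module, large_args
).convert(torch_module, args)

# Numerical comparison between the converted TFLite model and the PyTorch module.
assert model_coverage.compare_tflite_torch(edge_model, torch_module, args)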
ai_edge_torch/convert/test/test_convert_composites.py
@@ -0,0 +1,192 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+
+ from typing import Callable
+ import unittest
+
+ import parameterized
+ import torch
+
+ import ai_edge_torch
+ from ai_edge_torch.testing import model_coverage
+
+
+ def _func_to_torch_module(func: Callable):
+   class TestModule(torch.nn.Module):
+
+     def __init__(self, func):
+       super().__init__()
+       self._func = func
+
+     def forward(self, *args, **kwargs):
+       return self._func(*args, **kwargs)
+
+   return TestModule(func).eval()
+
+
+ class TestConvertComposites(unittest.TestCase):
+   """Tests conversion modules that are meant to be wrapped as composites."""
+
+   def test_convert_hardswish(self):
+     """Tests conversion of a HardSwish module."""
+
+     args = (torch.randn((5, 10)),)
+     torch_module = torch.nn.Hardswish().eval()
+     edge_model = ai_edge_torch.convert(torch_module, args)
+
+     self.assertTrue(model_coverage.compare_tflite_torch(edge_model, torch_module, args))
+
+   @parameterized.parameterized.expand(
+       [
+           # input_size, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override
+           # no padding, stride = 1
+           ([1, 3, 6, 6], [3, 3], [1, 1], [0, 0], False, True, None),
+           # add stride
+           ([1, 3, 6, 6], [3, 3], [2, 2], [0, 0], False, True, None),
+           # default values
+           ([1, 3, 6, 6], [3, 3]),
+           # add padding
+           ([1, 3, 6, 6], [3, 3], [1, 1], [1, 1], False, True, None),
+           # add different padding for different dims
+           ([1, 3, 6, 6], [3, 3], [1, 1], [0, 1], False, True, None),
+           # add both stride and padding
+           ([1, 3, 6, 6], [3, 3], [2, 2], [1, 1], False, True, None),
+           # count_include_pad = False
+           ([1, 3, 6, 6], [3, 3], [1, 1], [1, 1], False, False, None),
+           # ceil_mode = True
+           ([1, 3, 6, 6], [3, 3], [1, 1], [1, 1], True, True, None),
+           # ceil_mode = True, stride=[3, 3]
+           ([1, 3, 6, 6], [3, 3], [3, 3], [1, 1], True, True, None),
+           # set divisor_override
+           ([1, 3, 6, 6], [3, 3], [1, 1], 0, False, True, 6),
+           # padding set to one number
+           ([1, 3, 6, 6], [3, 3], [1, 1], 1, False, True, None),
+       ]
+   )
+   def test_convert_avg_pool2d(self, input_size, *args):
+     """Tests conversion of a module containing an avg_pool2d aten."""
+     torch_module = _func_to_torch_module(
+         lambda input_tensor: torch.ops.aten.avg_pool2d(input_tensor, *args)
+     )
+     tracing_args = (torch.randn(*input_size),)
+     edge_model = ai_edge_torch.convert(torch_module, tracing_args)
+
+     self.assertTrue(
+         model_coverage.compare_tflite_torch(edge_model, torch_module, tracing_args)
+     )
+
+   @parameterized.parameterized.expand(
+       [
+           # use scale_factor with align_corners=False
+           (
+               [1, 3, 10, 10],
+               dict(scale_factor=3.0, mode='bilinear', align_corners=False),
+           ),
+           # use scale_factor with align_corners=true
+           ([1, 3, 10, 10], dict(scale_factor=3.0, mode='bilinear', align_corners=True)),
+           # use size
+           ([1, 3, 10, 10], dict(size=[15, 20], mode='bilinear')),
+           # use size with align_corners=true
+           ([1, 3, 10, 10], dict(size=[15, 20], mode='bilinear', align_corners=True)),
+       ]
+   )
+   def test_convert_upsample_bilinear_functional(self, input_size, kwargs):
+     """Tests conversion of a torch.nn.functional.upsample module."""
+     torch_module = _func_to_torch_module(
+         lambda input_tensor: torch.nn.functional.upsample(input_tensor, **kwargs)
+     )
+     tracing_args = (torch.randn(*input_size),)
+     edge_model = ai_edge_torch.convert(torch_module, tracing_args)
+
+     self.assertTrue(
+         model_coverage.compare_tflite_torch(edge_model, torch_module, tracing_args)
+     )
+
+   @parameterized.parameterized.expand(
+       [
+           # use scale_factor with align_corners=False
+           (
+               [1, 3, 10, 10],
+               dict(scale_factor=3.0, mode='bilinear', align_corners=False),
+           ),
+           # use scale_factor with align_corners=true
+           ([1, 3, 10, 10], dict(scale_factor=3.0, mode='bilinear', align_corners=True)),
+           # use size
+           ([1, 3, 10, 10], dict(size=[15, 20], mode='bilinear')),
+           # use size with align_corners=true
+           ([1, 3, 10, 10], dict(size=[15, 20], mode='bilinear', align_corners=True)),
+       ]
+   )
+   def test_convert_upsample_bilinear(self, input_size, kwargs):
+     """Tests conversion of a torch.nn.Upsample module."""
+     torch_module = _func_to_torch_module(
+         lambda input_tensor: torch.nn.Upsample(**kwargs)(input_tensor)
+     )
+     tracing_args = (torch.randn(*input_size),)
+     edge_model = ai_edge_torch.convert(torch_module, tracing_args)
+
+     self.assertTrue(
+         model_coverage.compare_tflite_torch(edge_model, torch_module, tracing_args)
+     )
+
+   @parameterized.parameterized.expand(
+       [
+           # use scale_factor with align_corners=False
+           (
+               [1, 3, 10, 10],
+               dict(scale_factor=3.0, mode='bilinear', align_corners=False),
+           ),
+           # use scale_factor with align_corners=true
+           ([1, 3, 10, 10], dict(scale_factor=3.0, mode='bilinear', align_corners=True)),
+           # use size
+           ([1, 3, 10, 10], dict(size=[15, 20], mode='bilinear')),
+           # use size with align_corners=true
+           ([1, 3, 10, 10], dict(size=[15, 20], mode='bilinear', align_corners=True)),
+       ]
+   )
+   def test_convert_interpolate_bilinear_functional(self, input_size, kwargs):
+     """Tests conversion of a torch.nn.functional.interpolate module."""
+     torch_module = _func_to_torch_module(
+         lambda input_tensor: torch.nn.functional.interpolate(input_tensor, **kwargs)
+     )
+     tracing_args = (torch.randn(*input_size),)
+     edge_model = ai_edge_torch.convert(torch_module, tracing_args)
+
+     self.assertTrue(
+         model_coverage.compare_tflite_torch(edge_model, torch_module, tracing_args)
+     )
+
+   def test_convert_gelu(self):
+     """Tests conversion of a GELU module."""
+
+     args = (torch.randn((5, 10)),)
+     torch_module = torch.nn.GELU().eval()
+     edge_model = ai_edge_torch.convert(torch_module, args)
+
+     self.assertTrue(model_coverage.compare_tflite_torch(edge_model, torch_module, args))
+
+   def test_convert_gelu_approximate(self):
+     """Tests conversion of an Approximate GELU module."""
+
+     args = (torch.randn((5, 10)),)
+     torch_module = torch.nn.GELU('tanh').eval()
+     edge_model = ai_edge_torch.convert(torch_module, args)
+
+     self.assertTrue(model_coverage.compare_tflite_torch(edge_model, torch_module, args))
+
+
+ if __name__ == '__main__':
+   unittest.main()
ai_edge_torch/convert/test/test_convert_multisig.py
@@ -0,0 +1,139 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ import unittest
+
+ import torch
+ import torchvision
+
+ import ai_edge_torch
+ from ai_edge_torch.testing import model_coverage
+
+
+ class TestConvertMultiSignature(unittest.TestCase):
+   """Tests conversion of various modules through multi-signature conversion."""
+
+   def setUp(self):
+     torch.manual_seed(0)
+
+   def test_convert_mobilenet_v2_with_default(self):
+     """Tests conversion of a model with two signatures one of which is the default."""
+     torch_module = torchvision.models.mobilenet_v2().eval()
+
+     args = (torch.randn(4, 3, 224, 224),)
+     large_args = (torch.randn(4, 3, 336, 336),)
+
+     signature_name = "large_input"
+
+     edge_model = ai_edge_torch.signature(
+         signature_name, torch_module, large_args
+     ).convert(torch_module, args)
+
+     self.assertTrue(model_coverage.compare_tflite_torch(edge_model, torch_module, args))
+     self.assertTrue(
+         model_coverage.compare_tflite_torch(
+             edge_model, torch_module, large_args, signature_name=signature_name
+         )
+     )
+
+   def test_convert_mobilenet_v2_no_default(self):
+     """Tests conversion of a model with two signatures none of which is the default."""
+     torch_module = torchvision.models.mobilenet_v2().eval()
+
+     args = (torch.randn(4, 3, 224, 224),)
+     large_args = (torch.randn(4, 3, 336, 336),)
+
+     signature_name_1 = "input"
+     signature_name_2 = "large_input"
+
+     edge_model = (
+         ai_edge_torch.signature(signature_name_1, torch_module, args)
+         .signature(signature_name_2, torch_module, large_args)
+         .convert()
+     )
+
+     with self.assertRaises(ValueError):
+       edge_model(*args)
+
+     self.assertTrue(
+         model_coverage.compare_tflite_torch(
+             edge_model, torch_module, args, signature_name=signature_name_1
+         )
+     )
+     self.assertTrue(
+         model_coverage.compare_tflite_torch(
+             edge_model, torch_module, large_args, signature_name=signature_name_2
+         )
+     )
+
+   def test_convert_mobilenet_v2_signature_helper(self):
+     """Tests the ai_edge_torch.signature helper function works."""
+     torch_module = torchvision.models.mobilenet_v2().eval()
+
+     args = (torch.randn(4, 3, 224, 224),)
+     large_args = (torch.randn(4, 3, 336, 336),)
+
+     signature_name = "large_input"
+
+     edge_model = ai_edge_torch.signature(signature_name, torch_module, args).convert(
+         torch_module, large_args
+     )
+
+     self.assertTrue(model_coverage.compare_tflite_torch(edge_model, torch_module, args))
+     self.assertTrue(
+         model_coverage.compare_tflite_torch(
+             edge_model, torch_module, large_args, signature_name=signature_name
+         )
+     )
+
+   def test_convert_separate_modules(self):
+     """Tests conversion of two completely different modules as separate signatures."""
+     mobilentv2 = torchvision.models.mobilenet_v2().eval()
+     resnet18 = torchvision.models.resnet18().eval()
+
+     mobilenet_args = (torch.randn(4, 3, 224, 224),)
+     resnet_args = (torch.randn(4, 3, 224, 224),)
+
+     mobilenet_signature_name = "mobilentv2"
+     resnet_signature_name = "resnet18"
+
+     edge_model = (
+         ai_edge_torch.signature(mobilenet_signature_name, mobilentv2, mobilenet_args)
+         .signature(resnet_signature_name, resnet18, resnet_args)
+         .convert(resnet18, resnet_args)
+     )
+
+     mobilenet_inference_args = (torch.randn(4, 3, 224, 224),)
+     resnet_inference_args = (torch.randn(4, 3, 224, 224),)
+     self.assertTrue(
+         model_coverage.compare_tflite_torch(
+             edge_model,
+             mobilentv2,
+             mobilenet_inference_args,
+             signature_name=mobilenet_signature_name,
+         )
+     )
+     self.assertTrue(
+         model_coverage.compare_tflite_torch(
+             edge_model,
+             resnet18,
+             resnet_inference_args,
+             signature_name=resnet_signature_name,
+         )
+     )
+
+
+ if __name__ == "__main__":
+   unittest.main()
ai_edge_torch/convert/test/test_to_channel_last_io.py
@@ -0,0 +1,96 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ import unittest
+
+ import torch
+
+ import ai_edge_torch
+
+
+ class Identity(torch.nn.Module):
+
+   def forward(self, x):
+     return x
+
+
+ class TestToChannelLastIO(unittest.TestCase):
+   """Tests to_channel_last_io API and module wrapper."""
+
+   def test_no_transformations(self):
+     x = torch.rand(1, 3, 10, 10)
+     y = ai_edge_torch.to_channel_last_io(Identity())(x)
+     self.assertEqual(y.shape, (1, 3, 10, 10))
+
+   def test_args(self):
+     x = torch.rand(1, 10, 10, 3)
+     y = ai_edge_torch.to_channel_last_io(Identity(), args=[0])(x)
+     self.assertEqual(y.shape, (1, 3, 10, 10))
+
+   def test_outputs(self):
+     x = torch.rand(1, 3, 10, 10)
+     y = ai_edge_torch.to_channel_last_io(Identity(), outputs=[0])(x)
+     self.assertEqual(y.shape, (1, 10, 10, 3))
+
+   def test_args_outputs(self):
+     x = torch.rand(1, 10, 10, 3)
+     y = ai_edge_torch.to_channel_last_io(Identity(), args=[0], outputs=[0])(x)
+     self.assertEqual(y.shape, (1, 10, 10, 3))
+
+   def test_args_5d(self):
+     x = torch.rand(1, 10, 10, 10, 3)
+     y = ai_edge_torch.to_channel_last_io(Identity(), args=[0])(x)
+     self.assertEqual(y.shape, (1, 3, 10, 10, 10))
+
+   def test_outputs_5d(self):
+     x = torch.rand(1, 3, 10, 10, 10)
+     y = ai_edge_torch.to_channel_last_io(Identity(), outputs=[0])(x)
+     self.assertEqual(y.shape, (1, 10, 10, 10, 3))
+
+   def test_chained_wrappers(self):
+     x = torch.rand(1, 10, 10, 3)
+
+     m = Identity()
+     m = ai_edge_torch.to_channel_last_io(m, args=[0])
+     m = ai_edge_torch.to_channel_last_io(m, outputs=[0])
+
+     y = m(x)
+     self.assertEqual(y.shape, (1, 10, 10, 3))
+
+   def test_list_args(self):
+     class Add(torch.nn.Module):
+
+       def forward(self, x, y):
+         return x + y
+
+     x = (torch.rand(1, 10, 10, 3), torch.rand(1, 10, 10, 3))
+     y = ai_edge_torch.to_channel_last_io(Add(), args=[0, 1])(*x)
+     self.assertEqual(y.shape, (1, 3, 10, 10))
+
+   def test_list_outputs(self):
+     class TwoIdentity(torch.nn.Module):
+
+       def forward(self, x):
+         return x, x
+
+     x = torch.rand(1, 3, 10, 10)
+     y = ai_edge_torch.to_channel_last_io(TwoIdentity(), outputs=[0])(x)
+     self.assertIsInstance(y, tuple)
+     self.assertEqual(y[0].shape, (1, 10, 10, 3))
+     self.assertEqual(y[1].shape, (1, 3, 10, 10))
+
+
+ if __name__ == "__main__":
+   unittest.main()
ai_edge_torch/convert/to_channel_last_io.py
@@ -0,0 +1,85 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ from typing import Optional
+
+ import torch
+ from torch import nn
+
+
+ class ChannelLastIOWrapper(nn.Module):
+
+   def __init__(self, wrapped, *, args=None, outputs=None):
+     super().__init__()
+     self.wrapped = wrapped
+     self._args = args or []
+     self._outputs = outputs or []
+
+   def _to_channel_last(self, x):
+     if not torch.is_tensor(x):
+       raise ValueError("Input must be a torch tensor")
+     if x.ndim < 3:
+       raise ValueError("Input must be a tensor with rank >= 3 in layout (N, C, ...)")
+     dims = [0, *range(2, x.ndim), 1]
+     return torch.permute(x, dims)
+
+   def _to_channel_first(self, x):
+     if not torch.is_tensor(x):
+       raise ValueError("Input must be a torch tensor.")
+     if x.ndim < 3:
+       raise ValueError("Input must be a tensor with rank >= 3 in layout (N, ..., C)")
+     dims = [0, x.ndim - 1, *range(1, x.ndim - 1)]
+     return torch.permute(x, dims)
+
+   def forward(self, *args, **kwargs):
+     args = list(args)
+     for i in self._args:
+       args[i] = self._to_channel_first(args[i])
+
+     outputs = self.wrapped(*args, **kwargs)
+
+     if not isinstance(outputs, (list, tuple)):
+       outputs_is_list = False
+       output_list = [outputs]
+     else:
+       outputs_is_list = True
+       output_list = list(outputs)
+
+     for i in self._outputs:
+       output_list[i] = self._to_channel_last(output_list[i])
+
+     if not outputs_is_list:
+       return output_list[0]
+     else:
+       return type(outputs)(output_list)
+
+
+ def to_channel_last_io(
+     module: nn.Module,
+     args: Optional[list[int]] = None,
+     outputs: Optional[list[int]] = None,
+ ):
+   """Wraps the module with channel first to channel last layout transformations.
+
+   Args:
+     args (list[int]): Transform args with indices in the list from channel first
+       (N, C, ...) to channel last (N, ..., C).
+     outputs (list[int]): Transform outputs with indices in the list from channel
+       first (N, C, ...) to channel last (N, ..., C).
+   Returns:
+     The wrapped nn.Module with additional layout transposes after inputs and/or before
+     outputs.
+   """
+   return ChannelLastIOWrapper(module, args=args, outputs=outputs)
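
The wrapper above is exercised directly by test_to_channel_last_io.py earlier in this diff. Below is a minimal usage sketch that combines it with conversion; the torchvision model choice and input shape are illustrative assumptions, while the to_channel_last_io and convert calls match the API surface shown in the diffed code.

import torch
import torchvision

import ai_edge_torch

# An ordinary channel-first (NCHW) PyTorch model.
nchw_model = torchvision.models.mobilenet_v2().eval()

# Wrap it so the signature accepts a channel-last (NHWC) tensor at argument
# index 0; ChannelLastIOWrapper.forward permutes it back to NCHW before
# calling the wrapped module.
nhwc_model = ai_edge_torch.to_channel_last_io(nchw_model, args=[0])

# Convert with an NHWC example input.
nhwc_args = (torch.randn(1, 224, 224, 3),)
edge_model = ai_edge_torch.convert(nhwc_model, nhwc_args)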
ai_edge_torch/debug/__init__.py
@@ -0,0 +1,17 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ from .culprit import _search_model
+ from .culprit import find_culprits
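
This init file only re-exports _search_model and find_culprits from culprit.py (see ai_edge_torch/debug/culprit.py in the file list). A hedged sketch of the intended debugging flow follows; the exact signature of find_culprits and the print_code helper are assumptions based on the project's documented usage, since only the import surface appears in this diff.

import torch
import torchvision

from ai_edge_torch.debug import find_culprits

# Hypothetical model and example inputs; find_culprits is expected to search
# for minimal failing subgraphs when conversion breaks and yield culprit
# objects one at a time.
model = torchvision.models.mobilenet_v2().eval()
args = (torch.randn(1, 3, 224, 224),)

culprits = find_culprits(model, args)
culprit = next(culprits)
culprit.print_code()  # Assumed helper for dumping a reproducible snippet.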