ai-edge-torch-nightly 0.2.0.dev20240714__py3-none-any.whl → 0.3.0.dev20240926__py3-none-any.whl

Files changed (169)
  1. ai_edge_torch/__init__.py +5 -4
  2. ai_edge_torch/_convert/conversion.py +112 -0
  3. ai_edge_torch/_convert/conversion_utils.py +64 -0
  4. ai_edge_torch/{convert → _convert}/converter.py +94 -48
  5. ai_edge_torch/_convert/fx_passes/__init__.py +22 -0
  6. ai_edge_torch/{convert → _convert}/fx_passes/build_aten_composite_pass.py +107 -44
  7. ai_edge_torch/{convert → _convert}/fx_passes/build_interpolate_composite_pass.py +23 -20
  8. ai_edge_torch/{convert → _convert}/fx_passes/inject_mlir_debuginfo_pass.py +5 -6
  9. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/__init__.py +1 -1
  10. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_check.py +39 -9
  11. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_mark.py +2 -0
  12. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_partitioners/__init__.py +1 -0
  13. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py +17 -8
  14. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py +9 -8
  15. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +31 -18
  16. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/op_func_registry.py +2 -2
  17. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/pass_body.py +34 -24
  18. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/utils.py +2 -0
  19. ai_edge_torch/_convert/signature.py +66 -0
  20. ai_edge_torch/_convert/test/test_convert.py +495 -0
  21. ai_edge_torch/_convert/test/test_convert_composites.py +234 -0
  22. ai_edge_torch/_convert/test/test_convert_multisig.py +189 -0
  23. ai_edge_torch/{convert → _convert}/test/test_to_channel_last_io.py +5 -5
  24. ai_edge_torch/{convert → _convert}/to_channel_last_io.py +10 -3
  25. ai_edge_torch/config.py +27 -0
  26. ai_edge_torch/conftest.py +20 -0
  27. ai_edge_torch/debug/culprit.py +72 -40
  28. ai_edge_torch/debug/test/test_culprit.py +7 -5
  29. ai_edge_torch/debug/test/test_search_model.py +8 -7
  30. ai_edge_torch/debug/utils.py +14 -3
  31. ai_edge_torch/fx_pass_base.py +101 -0
  32. ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py +68 -0
  33. ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +68 -0
  34. ai_edge_torch/generative/examples/gemma/{gemma.py → gemma1.py} +69 -55
  35. ai_edge_torch/generative/examples/gemma/gemma2.py +267 -0
  36. ai_edge_torch/generative/examples/gemma/verify_gemma1.py +56 -0
  37. ai_edge_torch/generative/examples/gemma/verify_gemma2.py +57 -0
  38. ai_edge_torch/generative/examples/gemma/verify_util.py +143 -0
  39. ai_edge_torch/generative/examples/openelm/convert_to_tflite.py +68 -0
  40. ai_edge_torch/generative/examples/openelm/openelm.py +206 -0
  41. ai_edge_torch/generative/examples/openelm/verify.py +64 -0
  42. ai_edge_torch/generative/examples/phi/__init__.py +14 -0
  43. ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py +68 -0
  44. ai_edge_torch/generative/examples/phi/convert_to_tflite.py +68 -0
  45. ai_edge_torch/generative/examples/{phi2 → phi}/phi2.py +70 -51
  46. ai_edge_torch/generative/examples/phi/phi3.py +286 -0
  47. ai_edge_torch/generative/examples/phi/verify.py +65 -0
  48. ai_edge_torch/generative/examples/phi/verify_phi3.py +70 -0
  49. ai_edge_torch/generative/examples/smollm/__init__.py +14 -0
  50. ai_edge_torch/generative/examples/smollm/convert_to_tflite.py +68 -0
  51. ai_edge_torch/generative/examples/smollm/smollm.py +101 -0
  52. ai_edge_torch/generative/examples/smollm/verify.py +62 -0
  53. ai_edge_torch/generative/examples/stable_diffusion/attention.py +3 -1
  54. ai_edge_torch/generative/examples/stable_diffusion/clip.py +83 -13
  55. ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +27 -14
  56. ai_edge_torch/generative/examples/stable_diffusion/decoder.py +74 -9
  57. ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +179 -37
  58. ai_edge_torch/generative/examples/stable_diffusion/encoder.py +4 -3
  59. ai_edge_torch/generative/examples/stable_diffusion/pipeline.py +83 -58
  60. ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler.py +4 -3
  61. ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler_ancestral.py +4 -3
  62. ai_edge_torch/generative/examples/stable_diffusion/samplers/k_lms.py +4 -3
  63. ai_edge_torch/generative/examples/stable_diffusion/samplers/sampler.py +1 -0
  64. ai_edge_torch/generative/examples/stable_diffusion/tokenizer.py +4 -1
  65. ai_edge_torch/generative/examples/stable_diffusion/util.py +9 -3
  66. ai_edge_torch/generative/examples/t5/convert_to_tflite.py +28 -25
  67. ai_edge_torch/generative/examples/t5/t5.py +208 -159
  68. ai_edge_torch/generative/examples/t5/t5_attention.py +45 -30
  69. ai_edge_torch/generative/examples/test_models/convert_toy_model.py +105 -0
  70. ai_edge_torch/generative/examples/test_models/toy_model.py +69 -41
  71. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +50 -64
  72. ai_edge_torch/generative/examples/tiny_llama/__init__.py +14 -0
  73. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +41 -39
  74. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +67 -54
  75. ai_edge_torch/generative/examples/tiny_llama/verify.py +64 -0
  76. ai_edge_torch/generative/fx_passes/__init__.py +4 -5
  77. ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py +10 -7
  78. ai_edge_torch/generative/layers/attention.py +141 -102
  79. ai_edge_torch/generative/layers/attention_utils.py +53 -12
  80. ai_edge_torch/generative/layers/builder.py +37 -7
  81. ai_edge_torch/generative/layers/feed_forward.py +39 -14
  82. ai_edge_torch/generative/layers/kv_cache.py +162 -50
  83. ai_edge_torch/generative/layers/model_config.py +84 -30
  84. ai_edge_torch/generative/layers/normalization.py +185 -7
  85. ai_edge_torch/generative/layers/rotary_position_embedding.py +6 -4
  86. ai_edge_torch/generative/layers/scaled_dot_product_attention.py +48 -21
  87. ai_edge_torch/generative/layers/unet/blocks_2d.py +136 -77
  88. ai_edge_torch/generative/layers/unet/builder.py +7 -4
  89. ai_edge_torch/generative/layers/unet/model_config.py +17 -15
  90. ai_edge_torch/generative/quantize/example.py +7 -8
  91. ai_edge_torch/generative/quantize/quant_recipe.py +10 -7
  92. ai_edge_torch/generative/quantize/quant_recipe_utils.py +12 -1
  93. ai_edge_torch/generative/quantize/quant_recipes.py +8 -0
  94. ai_edge_torch/generative/test/test_kv_cache.py +120 -0
  95. ai_edge_torch/generative/test/{loader_test.py → test_loader.py} +9 -7
  96. ai_edge_torch/generative/test/test_model_conversion.py +124 -188
  97. ai_edge_torch/generative/test/test_model_conversion_large.py +251 -0
  98. ai_edge_torch/generative/test/test_quantize.py +76 -60
  99. ai_edge_torch/generative/test/utils.py +54 -0
  100. ai_edge_torch/generative/utilities/converter.py +82 -0
  101. ai_edge_torch/generative/utilities/loader.py +120 -57
  102. ai_edge_torch/generative/utilities/stable_diffusion_loader.py +165 -57
  103. ai_edge_torch/generative/utilities/t5_loader.py +110 -81
  104. ai_edge_torch/generative/utilities/verifier.py +247 -0
  105. ai_edge_torch/hlfb/__init__.py +1 -1
  106. ai_edge_torch/hlfb/mark_pattern/__init__.py +9 -7
  107. ai_edge_torch/hlfb/mark_pattern/passes.py +23 -3
  108. ai_edge_torch/hlfb/mark_pattern/pattern.py +39 -30
  109. ai_edge_torch/hlfb/test/test_mark_pattern.py +46 -20
  110. ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py +24 -11
  111. ai_edge_torch/lowertools/__init__.py +18 -0
  112. ai_edge_torch/lowertools/_shim.py +80 -0
  113. ai_edge_torch/lowertools/common_utils.py +142 -0
  114. ai_edge_torch/lowertools/odml_torch_utils.py +255 -0
  115. ai_edge_torch/lowertools/test_utils.py +60 -0
  116. ai_edge_torch/lowertools/torch_xla_utils.py +284 -0
  117. ai_edge_torch/{generative/quantize/ai_edge_quantizer_glue → lowertools}/translate_recipe.py +29 -14
  118. ai_edge_torch/model.py +53 -18
  119. ai_edge_torch/odml_torch/__init__.py +20 -0
  120. ai_edge_torch/odml_torch/_torch_future.py +61 -0
  121. ai_edge_torch/odml_torch/_torch_library.py +19 -0
  122. ai_edge_torch/odml_torch/composite/__init__.py +16 -0
  123. ai_edge_torch/odml_torch/composite/mark_tensor.py +120 -0
  124. ai_edge_torch/odml_torch/composite/stablehlo_composite_builder.py +106 -0
  125. ai_edge_torch/odml_torch/debuginfo/__init__.py +16 -0
  126. ai_edge_torch/odml_torch/debuginfo/_build.py +43 -0
  127. ai_edge_torch/odml_torch/debuginfo/_op_polyfill.py +55 -0
  128. ai_edge_torch/odml_torch/export.py +357 -0
  129. ai_edge_torch/odml_torch/export_utils.py +168 -0
  130. ai_edge_torch/odml_torch/jax_bridge/__init__.py +15 -0
  131. ai_edge_torch/odml_torch/jax_bridge/_wrap.py +150 -0
  132. ai_edge_torch/odml_torch/jax_bridge/utils.py +75 -0
  133. ai_edge_torch/odml_torch/lowerings/__init__.py +25 -0
  134. ai_edge_torch/odml_torch/lowerings/_basic.py +258 -0
  135. ai_edge_torch/odml_torch/lowerings/_batch_norm.py +65 -0
  136. ai_edge_torch/odml_torch/lowerings/_convolution.py +241 -0
  137. ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +252 -0
  138. ai_edge_torch/odml_torch/lowerings/_layer_norm.py +78 -0
  139. ai_edge_torch/odml_torch/lowerings/context.py +42 -0
  140. ai_edge_torch/odml_torch/lowerings/registry.py +96 -0
  141. ai_edge_torch/odml_torch/lowerings/utils.py +185 -0
  142. ai_edge_torch/odml_torch/passes/__init__.py +38 -0
  143. ai_edge_torch/odml_torch/tf_integration.py +194 -0
  144. ai_edge_torch/quantize/pt2e_quantizer.py +52 -24
  145. ai_edge_torch/quantize/pt2e_quantizer_utils.py +43 -23
  146. ai_edge_torch/quantize/quant_config.py +13 -9
  147. ai_edge_torch/testing/model_coverage/model_coverage.py +29 -16
  148. ai_edge_torch/version.py +16 -0
  149. {ai_edge_torch_nightly-0.2.0.dev20240714.dist-info → ai_edge_torch_nightly-0.3.0.dev20240926.dist-info}/METADATA +7 -3
  150. ai_edge_torch_nightly-0.3.0.dev20240926.dist-info/RECORD +177 -0
  151. {ai_edge_torch_nightly-0.2.0.dev20240714.dist-info → ai_edge_torch_nightly-0.3.0.dev20240926.dist-info}/WHEEL +1 -1
  152. ai_edge_torch/convert/conversion.py +0 -117
  153. ai_edge_torch/convert/conversion_utils.py +0 -400
  154. ai_edge_torch/convert/fx_passes/__init__.py +0 -59
  155. ai_edge_torch/convert/fx_passes/_pass_base.py +0 -49
  156. ai_edge_torch/convert/fx_passes/canonicalize_pass.py +0 -37
  157. ai_edge_torch/convert/test/test_convert.py +0 -311
  158. ai_edge_torch/convert/test/test_convert_composites.py +0 -192
  159. ai_edge_torch/convert/test/test_convert_multisig.py +0 -139
  160. ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +0 -66
  161. ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +0 -64
  162. ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +0 -161
  163. ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/__init__.py +0 -0
  164. ai_edge_torch_nightly-0.2.0.dev20240714.dist-info/RECORD +0 -121
  165. /ai_edge_torch/{convert → _convert}/__init__.py +0 -0
  166. /ai_edge_torch/{convert → _convert}/test/__init__.py +0 -0
  167. /ai_edge_torch/generative/examples/{phi2 → openelm}/__init__.py +0 -0
  168. {ai_edge_torch_nightly-0.2.0.dev20240714.dist-info → ai_edge_torch_nightly-0.3.0.dev20240926.dist-info}/LICENSE +0 -0
  169. {ai_edge_torch_nightly-0.2.0.dev20240714.dist-info → ai_edge_torch_nightly-0.3.0.dev20240926.dist-info}/top_level.txt +0 -0
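
The remaining sections are the full diffs of four modules deleted in this release: the old ai_edge_torch/convert test suite and the original single-file Gemma conversion example. They exercised the top-level ai_edge_torch.convert and ai_edge_torch.signature entry points; as a quick orientation, a minimal conversion sketch in that style is shown below (the Add module and the output path are illustrative, not taken from the package):

import torch

import ai_edge_torch


class Add(torch.nn.Module):

  def forward(self, a, b):
    return a + b


# Sample inputs are only used to trace the model graph during conversion.
sample_args = (torch.randn(5, 10), torch.randn(5, 10))
edge_model = ai_edge_torch.convert(Add().eval(), sample_args)
# Illustrative output path; any writable location works.
edge_model.export('/tmp/add.tflite')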
ai_edge_torch/convert/test/test_convert.py (deleted)
@@ -1,311 +0,0 @@
- # Copyright 2024 The AI Edge Torch Authors.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ==============================================================================
-
-
- import os
- import tempfile
- import unittest
-
- import torch
- import torchvision
-
- import ai_edge_torch
- from ai_edge_torch.convert import conversion_utils as cutils
- from ai_edge_torch.testing import model_coverage
-
-
- class TestConvert(unittest.TestCase):
-   """Tests conversion of various modules."""
-
-   def setUp(self):
-     torch.manual_seed(0)
-
-   def test_convert_add(self):
-     """Tests conversion of a simple Add module."""
-
-     class Add(torch.nn.Module):
-
-       def forward(self, a, b):
-         return a + b
-
-     args = (
-         torch.randn((5, 10)),
-         torch.randn((5, 10)),
-     )
-     torch_module = Add().eval()
-     edge_model = ai_edge_torch.convert(torch_module, args)
-
-     self.assertTrue(model_coverage.compare_tflite_torch(edge_model, torch_module, args))
-
-   def test_convert_dot_add(self):
-     class DotAdd(torch.nn.Module):
-       """Tests conversion of a matrix multiplication followed by an add."""
-
-       def forward(self, a, b, c):
-         return a @ b + c
-
-     args = (
-         torch.randn((5, 10)),
-         torch.randn((10, 5)),
-         torch.randn((5, 5)),
-     )
-     torch_module = DotAdd().eval()
-     edge_model = ai_edge_torch.convert(torch_module, args)
-
-     self.assertTrue(model_coverage.compare_tflite_torch(edge_model, torch_module, args))
-
-   def test_convert_resnet18(self):
-     args = (torch.randn(4, 3, 224, 224),)
-     torch_module = torchvision.models.resnet18().eval()
-     edge_model = ai_edge_torch.convert(torch_module, args)
-
-     self.assertTrue(model_coverage.compare_tflite_torch(edge_model, torch_module, args))
-
-   def test_signature_args_ordering(self):
-     """Tests conversion of a model with more than 10 arguments."""
-
-     class AddChainWith11Args(torch.nn.Module):
-
-       def forward(
-           self,
-           arg0: "f32[64]",
-           arg1: "f32[64]",
-           arg2: "f32[64]",
-           arg3: "f32[64]",
-           arg4: "f32[64]",
-           arg5: "f32[64]",
-           arg6: "f32[64]",
-           arg7: "f32[64]",
-           arg8: "f32[64]",
-           arg9: "f32[64]",
-           arg10: "f32[64]",
-       ):
-         add0 = torch.add(arg0, arg1)
-         add1 = torch.add(add0, arg2)
-         add2 = torch.add(add1, arg3)
-         add3 = torch.add(add2, arg4)
-         add4 = torch.add(add3, arg5)
-         add5 = torch.add(add4, arg6)
-         add6 = torch.add(add5, arg7)
-         add7 = torch.add(add6, arg8)
-         add8 = torch.add(add7, arg9)
-         add9 = torch.add(add8, arg10)
-         return add9
-
-     sample_input = lambda: (
-         torch.rand((64,), dtype=torch.float32),
-         torch.rand((64,), dtype=torch.float32),
-         torch.rand((64,), dtype=torch.float32),
-         torch.rand((64,), dtype=torch.float32),
-         torch.rand((64,), dtype=torch.float32),
-         torch.rand((64,), dtype=torch.float32),
-         torch.rand((64,), dtype=torch.float32),
-         torch.rand((64,), dtype=torch.float32),
-         torch.rand((64,), dtype=torch.float32),
-         torch.rand((64,), dtype=torch.float32),
-         torch.rand((64,), dtype=torch.float32),
-     )
-     torch_model = AddChainWith11Args().eval()
-     edge_model = ai_edge_torch.convert(torch_model, sample_input())
-
-     result = model_coverage.compare_tflite_torch(
-         edge_model, torch_model, sample_input, num_valid_inputs=10
-     )
-     self.assertTrue(result)
-
-   def test_multi_output_model(self):
-     """Tests conversion of a model that returns multiple outputs."""
-
-     class BasicAddModelWithMultipleOutputs(torch.nn.Module):
-
-       def forward(self, arg0, arg1):
-         add0 = arg0 + arg1
-         mul0 = arg0 * arg1
-         return add0, mul0
-
-     sample_input = (
-         torch.rand((64,), dtype=torch.float32),
-         torch.rand((64,), dtype=torch.float32),
-     )
-
-     torch_model = BasicAddModelWithMultipleOutputs().eval()
-     edge_model = ai_edge_torch.convert(torch_model, sample_input)
-
-     result = model_coverage.compare_tflite_torch(edge_model, torch_model, sample_input)
-     self.assertTrue(result)
-
-   def test_12_outputs_model(self):
-     """Tests conversion of a model that returns multiple outputs."""
-
-     class BasicAddModelWithMultipleOutputs(torch.nn.Module):
-
-       def forward(self, arg0, arg1):
-         add0 = arg0 + arg1
-         mul0 = arg0 * arg1
-         add1 = add0 + mul0
-         mul1 = add0 * mul0
-         add2 = add1 + mul1
-         mul2 = add1 * mul1
-         add3 = add2 + mul2
-         mul3 = add2 * mul2
-         add4 = add3 + mul3
-         mul4 = add3 * mul3
-         add5 = add4 + mul4
-         mul5 = add4 * mul4
-
-         return (
-             add0,
-             mul0,
-             add1,
-             mul1,
-             add2,
-             mul2,
-             add3,
-             mul3,
-             add4,
-             mul4,
-             add5,
-             mul5,
-         )
-
-     sample_input = (
-         torch.rand((64,), dtype=torch.float32),
-         torch.rand((64,), dtype=torch.float32),
-     )
-
-     torch_model = BasicAddModelWithMultipleOutputs().eval()
-     edge_model = ai_edge_torch.convert(torch_model, sample_input)
-
-     result = model_coverage.compare_tflite_torch(edge_model, torch_model, sample_input)
-     self.assertTrue(result)
-
-   def test_apply_tfl_backdoor_flags(self):
-     """Tests if _apply_tfl_backdoor_flags correctly sets the values in a Converter object."""
-
-     class MockConverterInternalObject:
-
-       def __init__(self):
-         self.subkey2 = "original_subvalue2"
-
-     class MockConverter:
-
-       def __init__(self):
-         self.key1 = "original_value1"
-         self.key2 = MockConverterInternalObject()
-
-     mock_converter = MockConverter()
-     flags = {"key1": "new_value1", "key2": {"subkey2": "new_subvalue2"}}
-     cutils._apply_tfl_backdoor_flags(mock_converter, flags)
-
-     self.assertTrue(flags["key1"], "new_value1")
-     self.assertTrue(flags["key2"]["subkey2"], "new_subvalue2")
-
-   def test_convert_add_backdoor_flags(self):
-     """Tests conversion of an add module setting a tflite converter flag."""
-
-     class Add(torch.nn.Module):
-
-       def forward(self, a, b):
-         return a + b
-
-     args = (
-         torch.randn((5, 10)),
-         torch.randn((5, 10)),
-     )
-     torch_module = Add().eval()
-
-     with tempfile.TemporaryDirectory() as tmp_dir_path:
-       ir_dump_path = os.path.join(
-           tmp_dir_path, "test_convert_add_backdoor_flags_mlir_dump"
-       )
-       ai_edge_torch.convert(
-           torch_module, args, _ai_edge_converter_flags={"ir_dump_dir": ir_dump_path}
-       )
-       self.assertTrue(os.path.isdir(ir_dump_path))
-
-   def test_convert_model_with_dynamic_batch(self):
-     """
-     Test converting a simple model with dynamic batch size.
-     """
-
-     class SampleModel(torch.nn.Module):
-
-       def __init__(self):
-         super().__init__()
-         self.w = torch.ones((10, 10)) * 2.7
-
-       def forward(self, x, y):
-         return x + y + self.w
-
-     sample_input = (torch.randn(4, 3, 10, 10), torch.randn(4, 3, 10, 10))
-     batch = torch.export.Dim("batch")
-     dynamic_shapes = ({0: batch}, {0: batch})
-
-     model = SampleModel().eval()
-     edge_model = ai_edge_torch.convert(
-         model, sample_input, dynamic_shapes=dynamic_shapes
-     )
-
-     for batch_size in [2, 4, 10]:
-       validate_input = (
-           torch.randn(batch_size, 3, 10, 10),
-           torch.randn(batch_size, 3, 10, 10),
-       )
-       self.assertTrue(
-           model_coverage.compare_tflite_torch(edge_model, model, validate_input)
-       )
-
-   def test_convert_model_with_kwargs(self):
-     """
-     Test converting a simple model with sample_kwargs.
-     """
-
-     class SampleModel(torch.nn.Module):
-
-       def forward(self, x, y):
-         return x + y
-
-     kwargs_gen = lambda: dict(x=torch.randn(10, 10), y=torch.randn(10, 10))
-
-     model = SampleModel().eval()
-     edge_model = ai_edge_torch.convert(model, sample_kwargs=kwargs_gen())
-
-     self.assertTrue(
-         model_coverage.compare_tflite_torch(edge_model, model, kwargs=kwargs_gen)
-     )
-
-   def test_convert_model_with_args_kwargs(self):
-     """
-     Test converting a simple model with both sample_args and sample_kwargs.
-     """
-
-     class SampleModel(torch.nn.Module):
-
-       def forward(self, x, y):
-         return x + y
-
-     args_gen = lambda: (torch.randn(10, 10),)
-     kwargs_gen = lambda: dict(y=torch.randn(10, 10))
-
-     model = SampleModel().eval()
-     edge_model = ai_edge_torch.convert(model, args_gen(), kwargs_gen())
-
-     self.assertTrue(
-         model_coverage.compare_tflite_torch(edge_model, model, args_gen, kwargs_gen)
-     )
-
-
- if __name__ == "__main__":
-   unittest.main()
ai_edge_torch/convert/test/test_convert_composites.py (deleted)
@@ -1,192 +0,0 @@
- # Copyright 2024 The AI Edge Torch Authors.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ==============================================================================
-
-
- from typing import Callable
- import unittest
-
- import parameterized
- import torch
-
- import ai_edge_torch
- from ai_edge_torch.testing import model_coverage
-
-
- def _func_to_torch_module(func: Callable):
-   class TestModule(torch.nn.Module):
-
-     def __init__(self, func):
-       super().__init__()
-       self._func = func
-
-     def forward(self, *args, **kwargs):
-       return self._func(*args, **kwargs)
-
-   return TestModule(func).eval()
-
-
- class TestConvertComposites(unittest.TestCase):
-   """Tests conversion modules that are meant to be wrapped as composites."""
-
-   def test_convert_hardswish(self):
-     """Tests conversion of a HardSwish module."""
-
-     args = (torch.randn((5, 10)),)
-     torch_module = torch.nn.Hardswish().eval()
-     edge_model = ai_edge_torch.convert(torch_module, args)
-
-     self.assertTrue(model_coverage.compare_tflite_torch(edge_model, torch_module, args))
-
-   @parameterized.parameterized.expand(
-       [
-           # input_size, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override
-           # no padding, stride = 1
-           ([1, 3, 6, 6], [3, 3], [1, 1], [0, 0], False, True, None),
-           # add stride
-           ([1, 3, 6, 6], [3, 3], [2, 2], [0, 0], False, True, None),
-           # default values
-           ([1, 3, 6, 6], [3, 3]),
-           # add padding
-           ([1, 3, 6, 6], [3, 3], [1, 1], [1, 1], False, True, None),
-           # add different padding for different dims
-           ([1, 3, 6, 6], [3, 3], [1, 1], [0, 1], False, True, None),
-           # add both stride and padding
-           ([1, 3, 6, 6], [3, 3], [2, 2], [1, 1], False, True, None),
-           # count_include_pad = False
-           ([1, 3, 6, 6], [3, 3], [1, 1], [1, 1], False, False, None),
-           # ceil_mode = True
-           ([1, 3, 6, 6], [3, 3], [1, 1], [1, 1], True, True, None),
-           # ceil_mode = True, stride=[3, 3]
-           ([1, 3, 6, 6], [3, 3], [3, 3], [1, 1], True, True, None),
-           # set divisor_override
-           ([1, 3, 6, 6], [3, 3], [1, 1], 0, False, True, 6),
-           # padding set to one number
-           ([1, 3, 6, 6], [3, 3], [1, 1], 1, False, True, None),
-       ]
-   )
-   def test_convert_avg_pool2d(self, input_size, *args):
-     """Tests conversion of a module containing an avg_pool2d aten."""
-     torch_module = _func_to_torch_module(
-         lambda input_tensor: torch.ops.aten.avg_pool2d(input_tensor, *args)
-     )
-     tracing_args = (torch.randn(*input_size),)
-     edge_model = ai_edge_torch.convert(torch_module, tracing_args)
-
-     self.assertTrue(
-         model_coverage.compare_tflite_torch(edge_model, torch_module, tracing_args)
-     )
-
-   @parameterized.parameterized.expand(
-       [
-           # use scale_factor with align_corners=False
-           (
-               [1, 3, 10, 10],
-               dict(scale_factor=3.0, mode='bilinear', align_corners=False),
-           ),
-           # use scale_factor with align_corners=true
-           ([1, 3, 10, 10], dict(scale_factor=3.0, mode='bilinear', align_corners=True)),
-           # use size
-           ([1, 3, 10, 10], dict(size=[15, 20], mode='bilinear')),
-           # use size with align_corners=true
-           ([1, 3, 10, 10], dict(size=[15, 20], mode='bilinear', align_corners=True)),
-       ]
-   )
-   def test_convert_upsample_bilinear_functional(self, input_size, kwargs):
-     """Tests conversion of a torch.nn.functional.upsample module."""
-     torch_module = _func_to_torch_module(
-         lambda input_tensor: torch.nn.functional.upsample(input_tensor, **kwargs)
-     )
-     tracing_args = (torch.randn(*input_size),)
-     edge_model = ai_edge_torch.convert(torch_module, tracing_args)
-
-     self.assertTrue(
-         model_coverage.compare_tflite_torch(edge_model, torch_module, tracing_args)
-     )
-
-   @parameterized.parameterized.expand(
-       [
-           # use scale_factor with align_corners=False
-           (
-               [1, 3, 10, 10],
-               dict(scale_factor=3.0, mode='bilinear', align_corners=False),
-           ),
-           # use scale_factor with align_corners=true
-           ([1, 3, 10, 10], dict(scale_factor=3.0, mode='bilinear', align_corners=True)),
-           # use size
-           ([1, 3, 10, 10], dict(size=[15, 20], mode='bilinear')),
-           # use size with align_corners=true
-           ([1, 3, 10, 10], dict(size=[15, 20], mode='bilinear', align_corners=True)),
-       ]
-   )
-   def test_convert_upsample_bilinear(self, input_size, kwargs):
-     """Tests conversion of a torch.nn.Upsample module."""
-     torch_module = _func_to_torch_module(
-         lambda input_tensor: torch.nn.Upsample(**kwargs)(input_tensor)
-     )
-     tracing_args = (torch.randn(*input_size),)
-     edge_model = ai_edge_torch.convert(torch_module, tracing_args)
-
-     self.assertTrue(
-         model_coverage.compare_tflite_torch(edge_model, torch_module, tracing_args)
-     )
-
-   @parameterized.parameterized.expand(
-       [
-           # use scale_factor with align_corners=False
-           (
-               [1, 3, 10, 10],
-               dict(scale_factor=3.0, mode='bilinear', align_corners=False),
-           ),
-           # use scale_factor with align_corners=true
-           ([1, 3, 10, 10], dict(scale_factor=3.0, mode='bilinear', align_corners=True)),
-           # use size
-           ([1, 3, 10, 10], dict(size=[15, 20], mode='bilinear')),
-           # use size with align_corners=true
-           ([1, 3, 10, 10], dict(size=[15, 20], mode='bilinear', align_corners=True)),
-       ]
-   )
-   def test_convert_interpolate_bilinear_functional(self, input_size, kwargs):
-     """Tests conversion of a torch.nn.functional.interpolate module."""
-     torch_module = _func_to_torch_module(
-         lambda input_tensor: torch.nn.functional.interpolate(input_tensor, **kwargs)
-     )
-     tracing_args = (torch.randn(*input_size),)
-     edge_model = ai_edge_torch.convert(torch_module, tracing_args)
-
-     self.assertTrue(
-         model_coverage.compare_tflite_torch(edge_model, torch_module, tracing_args)
-     )
-
-   def test_convert_gelu(self):
-     """Tests conversion of a GELU module."""
-
-     args = (torch.randn((5, 10)),)
-     torch_module = torch.nn.GELU().eval()
-     edge_model = ai_edge_torch.convert(torch_module, args)
-
-     self.assertTrue(model_coverage.compare_tflite_torch(edge_model, torch_module, args))
-
-   def test_convert_gelu_approximate(self):
-     """Tests conversion of an Approximate GELU module."""
-
-     args = (torch.randn((5, 10)),)
-     torch_module = torch.nn.GELU('tanh').eval()
-     edge_model = ai_edge_torch.convert(torch_module, args)
-
-     self.assertTrue(model_coverage.compare_tflite_torch(edge_model, torch_module, args))
-
-
- if __name__ == '__main__':
-   unittest.main()
ai_edge_torch/convert/test/test_convert_multisig.py (deleted)
@@ -1,139 +0,0 @@
- # Copyright 2024 The AI Edge Torch Authors.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ==============================================================================
-
- import unittest
-
- import torch
- import torchvision
-
- import ai_edge_torch
- from ai_edge_torch.testing import model_coverage
-
-
- class TestConvertMultiSignature(unittest.TestCase):
-   """Tests conversion of various modules through multi-signature conversion."""
-
-   def setUp(self):
-     torch.manual_seed(0)
-
-   def test_convert_mobilenet_v2_with_default(self):
-     """Tests conversion of a model with two signatures one of which is the default."""
-     torch_module = torchvision.models.mobilenet_v2().eval()
-
-     args = (torch.randn(4, 3, 224, 224),)
-     large_args = (torch.randn(4, 3, 336, 336),)
-
-     signature_name = "large_input"
-
-     edge_model = ai_edge_torch.signature(
-         signature_name, torch_module, large_args
-     ).convert(torch_module, args)
-
-     self.assertTrue(model_coverage.compare_tflite_torch(edge_model, torch_module, args))
-     self.assertTrue(
-         model_coverage.compare_tflite_torch(
-             edge_model, torch_module, large_args, signature_name=signature_name
-         )
-     )
-
-   def test_convert_mobilenet_v2_no_default(self):
-     """Tests conversion of a model with two signatures none of which is the default."""
-     torch_module = torchvision.models.mobilenet_v2().eval()
-
-     args = (torch.randn(4, 3, 224, 224),)
-     large_args = (torch.randn(4, 3, 336, 336),)
-
-     signature_name_1 = "input"
-     signature_name_2 = "large_input"
-
-     edge_model = (
-         ai_edge_torch.signature(signature_name_1, torch_module, args)
-         .signature(signature_name_2, torch_module, large_args)
-         .convert()
-     )
-
-     with self.assertRaises(ValueError):
-       edge_model(*args)
-
-     self.assertTrue(
-         model_coverage.compare_tflite_torch(
-             edge_model, torch_module, args, signature_name=signature_name_1
-         )
-     )
-     self.assertTrue(
-         model_coverage.compare_tflite_torch(
-             edge_model, torch_module, large_args, signature_name=signature_name_2
-         )
-     )
-
-   def test_convert_mobilenet_v2_signature_helper(self):
-     """Tests the ai_edge_torch.signature helper function works."""
-     torch_module = torchvision.models.mobilenet_v2().eval()
-
-     args = (torch.randn(4, 3, 224, 224),)
-     large_args = (torch.randn(4, 3, 336, 336),)
-
-     signature_name = "large_input"
-
-     edge_model = ai_edge_torch.signature(signature_name, torch_module, args).convert(
-         torch_module, large_args
-     )
-
-     self.assertTrue(model_coverage.compare_tflite_torch(edge_model, torch_module, args))
-     self.assertTrue(
-         model_coverage.compare_tflite_torch(
-             edge_model, torch_module, large_args, signature_name=signature_name
-         )
-     )
-
-   def test_convert_separate_modules(self):
-     """Tests conversion of two completely different modules as separate signatures."""
-     mobilentv2 = torchvision.models.mobilenet_v2().eval()
-     resnet18 = torchvision.models.resnet18().eval()
-
-     mobilenet_args = (torch.randn(4, 3, 224, 224),)
-     resnet_args = (torch.randn(4, 3, 224, 224),)
-
-     mobilenet_signature_name = "mobilentv2"
-     resnet_signature_name = "resnet18"
-
-     edge_model = (
-         ai_edge_torch.signature(mobilenet_signature_name, mobilentv2, mobilenet_args)
-         .signature(resnet_signature_name, resnet18, resnet_args)
-         .convert(resnet18, resnet_args)
-     )
-
-     mobilenet_inference_args = (torch.randn(4, 3, 224, 224),)
-     resnet_inference_args = (torch.randn(4, 3, 224, 224),)
-     self.assertTrue(
-         model_coverage.compare_tflite_torch(
-             edge_model,
-             mobilentv2,
-             mobilenet_inference_args,
-             signature_name=mobilenet_signature_name,
-         )
-     )
-     self.assertTrue(
-         model_coverage.compare_tflite_torch(
-             edge_model,
-             resnet18,
-             resnet_inference_args,
-             signature_name=resnet_signature_name,
-         )
-     )
-
-
- if __name__ == "__main__":
-   unittest.main()
ai_edge_torch/generative/examples/gemma/convert_to_tflite.py (deleted)
@@ -1,66 +0,0 @@
- # Copyright 2024 The AI Edge Torch Authors.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ==============================================================================
-
- import os
- from pathlib import Path
-
- import torch
-
- import ai_edge_torch
- from ai_edge_torch.generative.examples.gemma import gemma
- from ai_edge_torch.generative.quantize import quant_recipes
-
-
- def convert_gemma_to_tflite(
-     checkpoint_path: str,
-     prefill_seq_len: int = 512,
-     kv_cache_max_len: int = 1024,
-     quantize: bool = True,
- ):
-   """An example method for converting a Gemma 2B model to multi-signature
-   tflite model.
-
-   Args:
-     checkpoint_path (str): The filepath to the model checkpoint, or directory holding the checkpoint.
-     prefill_seq_len (int, optional): The maximum size of prefill input tensor.
-       Defaults to 512.
-     kv_cache_max_len (int, optional): The maximum size of KV cache buffer,
-       including both prefill and decode. Defaults to 1024.
-     quantize (bool, optional): Whether the model should be quanized.
-       Defaults to True.
-   """
-   pytorch_model = gemma.build_2b_model(
-       checkpoint_path, kv_cache_max_len=kv_cache_max_len
-   )
-   # Tensors used to trace the model graph during conversion.
-   prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.long)
-   prefill_input_pos = torch.arange(0, prefill_seq_len)
-   decode_token = torch.tensor([[0]], dtype=torch.long)
-   decode_input_pos = torch.tensor([0], dtype=torch.int64)
-
-   quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
-   edge_model = (
-       ai_edge_torch.signature(
-           'prefill', pytorch_model, (prefill_tokens, prefill_input_pos)
-       )
-       .signature('decode', pytorch_model, (decode_token, decode_input_pos))
-       .convert(quant_config=quant_config)
-   )
-   edge_model.export(f'/tmp/gemma_seq{prefill_seq_len}_kv{kv_cache_max_len}.tflite')
-
-
- if __name__ == '__main__':
-   checkpoint_path = os.path.join(Path.home(), 'Downloads/llm_data/gemma-2b')
-   convert_gemma_to_tflite(checkpoint_path)