ai-edge-torch-nightly 0.2.0.dev20240714__py3-none-any.whl → 0.3.0.dev20240926__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (169) hide show
  1. ai_edge_torch/__init__.py +5 -4
  2. ai_edge_torch/_convert/conversion.py +112 -0
  3. ai_edge_torch/_convert/conversion_utils.py +64 -0
  4. ai_edge_torch/{convert → _convert}/converter.py +94 -48
  5. ai_edge_torch/_convert/fx_passes/__init__.py +22 -0
  6. ai_edge_torch/{convert → _convert}/fx_passes/build_aten_composite_pass.py +107 -44
  7. ai_edge_torch/{convert → _convert}/fx_passes/build_interpolate_composite_pass.py +23 -20
  8. ai_edge_torch/{convert → _convert}/fx_passes/inject_mlir_debuginfo_pass.py +5 -6
  9. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/__init__.py +1 -1
  10. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_check.py +39 -9
  11. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_mark.py +2 -0
  12. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_partitioners/__init__.py +1 -0
  13. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py +17 -8
  14. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py +9 -8
  15. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +31 -18
  16. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/op_func_registry.py +2 -2
  17. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/pass_body.py +34 -24
  18. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/utils.py +2 -0
  19. ai_edge_torch/_convert/signature.py +66 -0
  20. ai_edge_torch/_convert/test/test_convert.py +495 -0
  21. ai_edge_torch/_convert/test/test_convert_composites.py +234 -0
  22. ai_edge_torch/_convert/test/test_convert_multisig.py +189 -0
  23. ai_edge_torch/{convert → _convert}/test/test_to_channel_last_io.py +5 -5
  24. ai_edge_torch/{convert → _convert}/to_channel_last_io.py +10 -3
  25. ai_edge_torch/config.py +27 -0
  26. ai_edge_torch/conftest.py +20 -0
  27. ai_edge_torch/debug/culprit.py +72 -40
  28. ai_edge_torch/debug/test/test_culprit.py +7 -5
  29. ai_edge_torch/debug/test/test_search_model.py +8 -7
  30. ai_edge_torch/debug/utils.py +14 -3
  31. ai_edge_torch/fx_pass_base.py +101 -0
  32. ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py +68 -0
  33. ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +68 -0
  34. ai_edge_torch/generative/examples/gemma/{gemma.py → gemma1.py} +69 -55
  35. ai_edge_torch/generative/examples/gemma/gemma2.py +267 -0
  36. ai_edge_torch/generative/examples/gemma/verify_gemma1.py +56 -0
  37. ai_edge_torch/generative/examples/gemma/verify_gemma2.py +57 -0
  38. ai_edge_torch/generative/examples/gemma/verify_util.py +143 -0
  39. ai_edge_torch/generative/examples/openelm/convert_to_tflite.py +68 -0
  40. ai_edge_torch/generative/examples/openelm/openelm.py +206 -0
  41. ai_edge_torch/generative/examples/openelm/verify.py +64 -0
  42. ai_edge_torch/generative/examples/phi/__init__.py +14 -0
  43. ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py +68 -0
  44. ai_edge_torch/generative/examples/phi/convert_to_tflite.py +68 -0
  45. ai_edge_torch/generative/examples/{phi2 → phi}/phi2.py +70 -51
  46. ai_edge_torch/generative/examples/phi/phi3.py +286 -0
  47. ai_edge_torch/generative/examples/phi/verify.py +65 -0
  48. ai_edge_torch/generative/examples/phi/verify_phi3.py +70 -0
  49. ai_edge_torch/generative/examples/smollm/__init__.py +14 -0
  50. ai_edge_torch/generative/examples/smollm/convert_to_tflite.py +68 -0
  51. ai_edge_torch/generative/examples/smollm/smollm.py +101 -0
  52. ai_edge_torch/generative/examples/smollm/verify.py +62 -0
  53. ai_edge_torch/generative/examples/stable_diffusion/attention.py +3 -1
  54. ai_edge_torch/generative/examples/stable_diffusion/clip.py +83 -13
  55. ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +27 -14
  56. ai_edge_torch/generative/examples/stable_diffusion/decoder.py +74 -9
  57. ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +179 -37
  58. ai_edge_torch/generative/examples/stable_diffusion/encoder.py +4 -3
  59. ai_edge_torch/generative/examples/stable_diffusion/pipeline.py +83 -58
  60. ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler.py +4 -3
  61. ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler_ancestral.py +4 -3
  62. ai_edge_torch/generative/examples/stable_diffusion/samplers/k_lms.py +4 -3
  63. ai_edge_torch/generative/examples/stable_diffusion/samplers/sampler.py +1 -0
  64. ai_edge_torch/generative/examples/stable_diffusion/tokenizer.py +4 -1
  65. ai_edge_torch/generative/examples/stable_diffusion/util.py +9 -3
  66. ai_edge_torch/generative/examples/t5/convert_to_tflite.py +28 -25
  67. ai_edge_torch/generative/examples/t5/t5.py +208 -159
  68. ai_edge_torch/generative/examples/t5/t5_attention.py +45 -30
  69. ai_edge_torch/generative/examples/test_models/convert_toy_model.py +105 -0
  70. ai_edge_torch/generative/examples/test_models/toy_model.py +69 -41
  71. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +50 -64
  72. ai_edge_torch/generative/examples/tiny_llama/__init__.py +14 -0
  73. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +41 -39
  74. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +67 -54
  75. ai_edge_torch/generative/examples/tiny_llama/verify.py +64 -0
  76. ai_edge_torch/generative/fx_passes/__init__.py +4 -5
  77. ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py +10 -7
  78. ai_edge_torch/generative/layers/attention.py +141 -102
  79. ai_edge_torch/generative/layers/attention_utils.py +53 -12
  80. ai_edge_torch/generative/layers/builder.py +37 -7
  81. ai_edge_torch/generative/layers/feed_forward.py +39 -14
  82. ai_edge_torch/generative/layers/kv_cache.py +162 -50
  83. ai_edge_torch/generative/layers/model_config.py +84 -30
  84. ai_edge_torch/generative/layers/normalization.py +185 -7
  85. ai_edge_torch/generative/layers/rotary_position_embedding.py +6 -4
  86. ai_edge_torch/generative/layers/scaled_dot_product_attention.py +48 -21
  87. ai_edge_torch/generative/layers/unet/blocks_2d.py +136 -77
  88. ai_edge_torch/generative/layers/unet/builder.py +7 -4
  89. ai_edge_torch/generative/layers/unet/model_config.py +17 -15
  90. ai_edge_torch/generative/quantize/example.py +7 -8
  91. ai_edge_torch/generative/quantize/quant_recipe.py +10 -7
  92. ai_edge_torch/generative/quantize/quant_recipe_utils.py +12 -1
  93. ai_edge_torch/generative/quantize/quant_recipes.py +8 -0
  94. ai_edge_torch/generative/test/test_kv_cache.py +120 -0
  95. ai_edge_torch/generative/test/{loader_test.py → test_loader.py} +9 -7
  96. ai_edge_torch/generative/test/test_model_conversion.py +124 -188
  97. ai_edge_torch/generative/test/test_model_conversion_large.py +251 -0
  98. ai_edge_torch/generative/test/test_quantize.py +76 -60
  99. ai_edge_torch/generative/test/utils.py +54 -0
  100. ai_edge_torch/generative/utilities/converter.py +82 -0
  101. ai_edge_torch/generative/utilities/loader.py +120 -57
  102. ai_edge_torch/generative/utilities/stable_diffusion_loader.py +165 -57
  103. ai_edge_torch/generative/utilities/t5_loader.py +110 -81
  104. ai_edge_torch/generative/utilities/verifier.py +247 -0
  105. ai_edge_torch/hlfb/__init__.py +1 -1
  106. ai_edge_torch/hlfb/mark_pattern/__init__.py +9 -7
  107. ai_edge_torch/hlfb/mark_pattern/passes.py +23 -3
  108. ai_edge_torch/hlfb/mark_pattern/pattern.py +39 -30
  109. ai_edge_torch/hlfb/test/test_mark_pattern.py +46 -20
  110. ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py +24 -11
  111. ai_edge_torch/lowertools/__init__.py +18 -0
  112. ai_edge_torch/lowertools/_shim.py +80 -0
  113. ai_edge_torch/lowertools/common_utils.py +142 -0
  114. ai_edge_torch/lowertools/odml_torch_utils.py +255 -0
  115. ai_edge_torch/lowertools/test_utils.py +60 -0
  116. ai_edge_torch/lowertools/torch_xla_utils.py +284 -0
  117. ai_edge_torch/{generative/quantize/ai_edge_quantizer_glue → lowertools}/translate_recipe.py +29 -14
  118. ai_edge_torch/model.py +53 -18
  119. ai_edge_torch/odml_torch/__init__.py +20 -0
  120. ai_edge_torch/odml_torch/_torch_future.py +61 -0
  121. ai_edge_torch/odml_torch/_torch_library.py +19 -0
  122. ai_edge_torch/odml_torch/composite/__init__.py +16 -0
  123. ai_edge_torch/odml_torch/composite/mark_tensor.py +120 -0
  124. ai_edge_torch/odml_torch/composite/stablehlo_composite_builder.py +106 -0
  125. ai_edge_torch/odml_torch/debuginfo/__init__.py +16 -0
  126. ai_edge_torch/odml_torch/debuginfo/_build.py +43 -0
  127. ai_edge_torch/odml_torch/debuginfo/_op_polyfill.py +55 -0
  128. ai_edge_torch/odml_torch/export.py +357 -0
  129. ai_edge_torch/odml_torch/export_utils.py +168 -0
  130. ai_edge_torch/odml_torch/jax_bridge/__init__.py +15 -0
  131. ai_edge_torch/odml_torch/jax_bridge/_wrap.py +150 -0
  132. ai_edge_torch/odml_torch/jax_bridge/utils.py +75 -0
  133. ai_edge_torch/odml_torch/lowerings/__init__.py +25 -0
  134. ai_edge_torch/odml_torch/lowerings/_basic.py +258 -0
  135. ai_edge_torch/odml_torch/lowerings/_batch_norm.py +65 -0
  136. ai_edge_torch/odml_torch/lowerings/_convolution.py +241 -0
  137. ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +252 -0
  138. ai_edge_torch/odml_torch/lowerings/_layer_norm.py +78 -0
  139. ai_edge_torch/odml_torch/lowerings/context.py +42 -0
  140. ai_edge_torch/odml_torch/lowerings/registry.py +96 -0
  141. ai_edge_torch/odml_torch/lowerings/utils.py +185 -0
  142. ai_edge_torch/odml_torch/passes/__init__.py +38 -0
  143. ai_edge_torch/odml_torch/tf_integration.py +194 -0
  144. ai_edge_torch/quantize/pt2e_quantizer.py +52 -24
  145. ai_edge_torch/quantize/pt2e_quantizer_utils.py +43 -23
  146. ai_edge_torch/quantize/quant_config.py +13 -9
  147. ai_edge_torch/testing/model_coverage/model_coverage.py +29 -16
  148. ai_edge_torch/version.py +16 -0
  149. {ai_edge_torch_nightly-0.2.0.dev20240714.dist-info → ai_edge_torch_nightly-0.3.0.dev20240926.dist-info}/METADATA +7 -3
  150. ai_edge_torch_nightly-0.3.0.dev20240926.dist-info/RECORD +177 -0
  151. {ai_edge_torch_nightly-0.2.0.dev20240714.dist-info → ai_edge_torch_nightly-0.3.0.dev20240926.dist-info}/WHEEL +1 -1
  152. ai_edge_torch/convert/conversion.py +0 -117
  153. ai_edge_torch/convert/conversion_utils.py +0 -400
  154. ai_edge_torch/convert/fx_passes/__init__.py +0 -59
  155. ai_edge_torch/convert/fx_passes/_pass_base.py +0 -49
  156. ai_edge_torch/convert/fx_passes/canonicalize_pass.py +0 -37
  157. ai_edge_torch/convert/test/test_convert.py +0 -311
  158. ai_edge_torch/convert/test/test_convert_composites.py +0 -192
  159. ai_edge_torch/convert/test/test_convert_multisig.py +0 -139
  160. ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +0 -66
  161. ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +0 -64
  162. ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +0 -161
  163. ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/__init__.py +0 -0
  164. ai_edge_torch_nightly-0.2.0.dev20240714.dist-info/RECORD +0 -121
  165. /ai_edge_torch/{convert → _convert}/__init__.py +0 -0
  166. /ai_edge_torch/{convert → _convert}/test/__init__.py +0 -0
  167. /ai_edge_torch/generative/examples/{phi2 → openelm}/__init__.py +0 -0
  168. {ai_edge_torch_nightly-0.2.0.dev20240714.dist-info → ai_edge_torch_nightly-0.3.0.dev20240926.dist-info}/LICENSE +0 -0
  169. {ai_edge_torch_nightly-0.2.0.dev20240714.dist-info → ai_edge_torch_nightly-0.3.0.dev20240926.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,251 @@
1
+ # Copyright 2024 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """Testing model conversion for a few gen-ai models."""
17
+
18
+ import ai_edge_torch
19
+ from ai_edge_torch import config as ai_edge_config
20
+ from ai_edge_torch.generative.examples.gemma import gemma1
21
+ from ai_edge_torch.generative.examples.gemma import gemma2
22
+ from ai_edge_torch.generative.examples.openelm import openelm
23
+ from ai_edge_torch.generative.examples.phi import phi2
24
+ from ai_edge_torch.generative.examples.phi import phi3
25
+ from ai_edge_torch.generative.examples.smollm import smollm
26
+ from ai_edge_torch.generative.examples.stable_diffusion import clip as sd_clip
27
+ from ai_edge_torch.generative.examples.stable_diffusion import decoder as sd_decoder
28
+ from ai_edge_torch.generative.examples.stable_diffusion import diffusion as sd_diffusion
29
+ from ai_edge_torch.generative.layers import kv_cache
30
+ from ai_edge_torch.generative.test import utils as test_utils
31
+ import numpy as np
32
+ import torch
33
+
34
+ from absl.testing import absltest as googletest
35
+ from ai_edge_litert import interpreter
36
+
37
+
38
+ class TestModelConversion(googletest.TestCase):
39
+ """Unit tests that check for model conversion and correctness."""
40
+
41
+ def setUp(self):
42
+ super().setUp()
43
+ # Builder function for an Interpreter that supports custom ops.
44
+ self._interpreter_builder = (
45
+ lambda tflite_model: lambda: interpreter.InterpreterWithCustomOps(
46
+ custom_op_registerers=["GenAIOpsRegisterer"],
47
+ model_content=tflite_model,
48
+ experimental_default_delegate_latest_features=True,
49
+ )
50
+ )
51
+
52
+ def _test_model(self, config, model, signature_name, atol, rtol):
53
+ idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
54
+ tokens = torch.full((1, 10), 0, dtype=torch.int, device="cpu")
55
+ tokens[0, :4] = idx
56
+ input_pos = torch.arange(0, 10, dtype=torch.int)
57
+ kv = kv_cache.KVCache.from_model_config(config)
58
+
59
+ edge_model = ai_edge_torch.signature(
60
+ signature_name,
61
+ model,
62
+ sample_kwargs={
63
+ "tokens": tokens,
64
+ "input_pos": input_pos,
65
+ "kv_cache": kv,
66
+ },
67
+ ).convert()
68
+ edge_model.set_interpreter_builder(
69
+ self._interpreter_builder(edge_model.tflite_model())
70
+ )
71
+
72
+ self.assertTrue(
73
+ test_utils.compare_tflite_torch(
74
+ edge_model,
75
+ model,
76
+ tokens,
77
+ input_pos,
78
+ kv,
79
+ signature_name=signature_name,
80
+ atol=atol,
81
+ rtol=rtol,
82
+ )
83
+ )
84
+
85
+ @googletest.skipIf(
86
+ ai_edge_config.Config.use_torch_xla,
87
+ reason="tests with custom ops are not supported on oss",
88
+ )
89
+ def test_gemma1(self):
90
+ config = gemma1.get_fake_model_config()
91
+ pytorch_model = gemma1.Gemma(config).eval()
92
+ self._test_model(
93
+ config, pytorch_model, "serving_default", atol=1e-2, rtol=1e-5
94
+ )
95
+
96
+ @googletest.skipIf(
97
+ ai_edge_config.Config.use_torch_xla,
98
+ reason="tests with custom ops are not supported on oss",
99
+ )
100
+ def test_gemma2(self):
101
+ config = gemma2.get_fake_model_config()
102
+ pytorch_model = gemma2.Gemma2(config).eval()
103
+ self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5)
104
+
105
+ @googletest.skipIf(
106
+ ai_edge_config.Config.use_torch_xla,
107
+ reason="tests with custom ops are not supported on oss",
108
+ )
109
+ def test_phi2(self):
110
+ config = phi2.get_fake_model_config()
111
+ pytorch_model = phi2.Phi2(config).eval()
112
+ self._test_model(
113
+ config, pytorch_model, "serving_default", atol=1e-3, rtol=1e-3
114
+ )
115
+
116
+ @googletest.skipIf(
117
+ ai_edge_config.Config.use_torch_xla,
118
+ reason="tests with custom ops are not supported on oss",
119
+ )
120
+ def test_phi3(self):
121
+ config = phi3.get_fake_model_config()
122
+ pytorch_model = phi3.Phi3_5Mini(config).eval()
123
+ self._test_model(
124
+ config, pytorch_model, "prefill", atol=1e-5, rtol=1e-5
125
+ )
126
+
127
+ @googletest.skipIf(
128
+ ai_edge_config.Config.use_torch_xla,
129
+ reason="tests with custom ops are not supported on oss",
130
+ )
131
+ def test_smollm(self):
132
+ config = smollm.get_fake_model_config()
133
+ pytorch_model = smollm.SmolLM(config).eval()
134
+ self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5)
135
+
136
+ @googletest.skipIf(
137
+ ai_edge_config.Config.use_torch_xla,
138
+ reason="tests with custom ops are not supported on oss",
139
+ )
140
+ def test_openelm(self):
141
+ config = openelm.get_fake_model_config()
142
+ pytorch_model = openelm.OpenELM(config).eval()
143
+ self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5)
144
+
145
+ @googletest.skipIf(
146
+ ai_edge_config.Config.use_torch_xla,
147
+ reason="tests with custom ops are not supported on oss",
148
+ )
149
+ def test_stable_diffusion_clip(self):
150
+ config = sd_clip.get_fake_model_config()
151
+ prompt_tokens = torch.from_numpy(
152
+ np.array([[1, 2, 3, 4, 5, 6]], dtype=np.int32)
153
+ )
154
+
155
+ pytorch_model = sd_clip.CLIP(config).eval()
156
+ torch_output = pytorch_model(prompt_tokens)
157
+
158
+ edge_model = ai_edge_torch.signature(
159
+ "encode", pytorch_model, (prompt_tokens,)
160
+ ).convert()
161
+ edge_model.set_interpreter_builder(
162
+ self._interpreter_builder(edge_model.tflite_model())
163
+ )
164
+ edge_output = edge_model(
165
+ prompt_tokens.numpy(),
166
+ signature_name="encode",
167
+ )
168
+ self.assertTrue(
169
+ np.allclose(
170
+ edge_output,
171
+ torch_output.detach().numpy(),
172
+ atol=1e-4,
173
+ rtol=1e-5,
174
+ )
175
+ )
176
+
177
+ @googletest.skipIf(
178
+ ai_edge_config.Config.use_torch_xla,
179
+ reason="tests with custom ops are not supported on oss",
180
+ )
181
+ def test_stable_diffusion_diffusion(self):
182
+ config = sd_diffusion.get_fake_model_config(2)
183
+ latents = torch.from_numpy(
184
+ np.random.normal(size=(2, 4, 8, 8)).astype(np.float32)
185
+ )
186
+ context = torch.from_numpy(
187
+ np.random.normal(size=(2, 4, 4)).astype(np.float32)
188
+ )
189
+ time_embedding = torch.from_numpy(
190
+ np.random.normal(size=(2, 2)).astype(np.float32)
191
+ )
192
+
193
+ pytorch_model = sd_diffusion.Diffusion(config).eval()
194
+ torch_output = pytorch_model(latents, context, time_embedding)
195
+
196
+ edge_model = ai_edge_torch.signature(
197
+ "diffusion", pytorch_model, (latents, context, time_embedding)
198
+ ).convert()
199
+ edge_model.set_interpreter_builder(
200
+ self._interpreter_builder(edge_model.tflite_model())
201
+ )
202
+ edge_output = edge_model(
203
+ latents.numpy(),
204
+ context.numpy(),
205
+ time_embedding.numpy(),
206
+ signature_name="diffusion",
207
+ )
208
+ self.assertTrue(
209
+ np.allclose(
210
+ edge_output,
211
+ torch_output.detach().numpy(),
212
+ atol=1e-4,
213
+ rtol=1e-5,
214
+ )
215
+ )
216
+
217
+ @googletest.skipIf(
218
+ ai_edge_config.Config.use_torch_xla,
219
+ reason="tests with custom ops are not supported on oss",
220
+ )
221
+ def test_stable_diffusion_decoder(self):
222
+ config = sd_decoder.get_fake_model_config()
223
+ latents = torch.from_numpy(
224
+ np.random.normal(size=(1, 4, 64, 64)).astype(np.float32)
225
+ )
226
+
227
+ pytorch_model = sd_decoder.Decoder(config).eval()
228
+ torch_output = pytorch_model(latents)
229
+
230
+ edge_model = ai_edge_torch.signature(
231
+ "decode", pytorch_model, (latents,)
232
+ ).convert()
233
+ edge_model.set_interpreter_builder(
234
+ self._interpreter_builder(edge_model.tflite_model())
235
+ )
236
+ edge_output = edge_model(
237
+ latents.numpy(),
238
+ signature_name="decode",
239
+ )
240
+ self.assertTrue(
241
+ np.allclose(
242
+ edge_output,
243
+ torch_output.detach().numpy(),
244
+ atol=1e-4,
245
+ rtol=1e-5,
246
+ )
247
+ )
248
+
249
+
250
+ if __name__ == "__main__":
251
+ googletest.main()
@@ -13,12 +13,8 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- import unittest
17
-
18
- from parameterized import parameterized
19
- import torch
20
-
21
16
  import ai_edge_torch
17
+ from ai_edge_torch import config
22
18
  from ai_edge_torch.generative.examples.test_models import toy_model # NOQA
23
19
  from ai_edge_torch.generative.quantize import quant_recipe
24
20
  from ai_edge_torch.generative.quantize import quant_recipe_utils
@@ -29,20 +25,22 @@ from ai_edge_torch.generative.quantize.quant_attrs import Granularity
29
25
  from ai_edge_torch.generative.quantize.quant_attrs import Mode
30
26
  from ai_edge_torch.quantize import quant_config
31
27
  from ai_edge_torch.testing import model_coverage
28
+ import torch
32
29
 
30
+ from absl.testing import absltest as googletest
31
+ from absl.testing import parameterized
33
32
 
34
- class TestVerifyRecipes(unittest.TestCase):
33
+
34
+ class TestVerifyRecipes(parameterized.TestCase):
35
35
  """Unit tests that check for model quantization recipes."""
36
36
 
37
- @parameterized.expand(
38
- [
39
- (Dtype.FP32, Dtype.FP32),
40
- (Dtype.INT8, Dtype.INT8),
41
- (Dtype.INT8, Dtype.FP16),
42
- (Dtype.FP16, Dtype.INT8),
43
- (Dtype.FP16, Dtype.FP16),
44
- ]
45
- )
37
+ @parameterized.parameters([
38
+ (Dtype.FP32, Dtype.FP32),
39
+ (Dtype.INT8, Dtype.INT8),
40
+ (Dtype.INT8, Dtype.FP16),
41
+ (Dtype.FP16, Dtype.INT8),
42
+ (Dtype.FP16, Dtype.FP16),
43
+ ])
46
44
  def test_verify_invalid_recipes(
47
45
  self,
48
46
  activation,
@@ -54,31 +52,29 @@ class TestVerifyRecipes(unittest.TestCase):
54
52
  with self.assertRaises(ValueError):
55
53
  quant_recipe.LayerQuantRecipe(activation, weight, m, a, g).verify()
56
54
 
57
- @parameterized.expand(
58
- [
59
- (
60
- Dtype.FP32,
61
- Dtype.INT8,
62
- Mode.DYNAMIC_RANGE,
63
- Algorithm.MIN_MAX,
64
- Granularity.CHANNELWISE,
65
- ),
66
- (
67
- Dtype.FP32,
68
- Dtype.INT8,
69
- Mode.WEIGHT_ONLY,
70
- Algorithm.MIN_MAX,
71
- Granularity.CHANNELWISE,
72
- ),
73
- (
74
- Dtype.FP32,
75
- Dtype.FP16,
76
- Mode.WEIGHT_ONLY,
77
- Algorithm.FLOAT_CAST,
78
- Granularity.NONE,
79
- ),
80
- ]
81
- )
55
+ @parameterized.parameters([
56
+ (
57
+ Dtype.FP32,
58
+ Dtype.INT8,
59
+ Mode.DYNAMIC_RANGE,
60
+ Algorithm.MIN_MAX,
61
+ Granularity.CHANNELWISE,
62
+ ),
63
+ (
64
+ Dtype.FP32,
65
+ Dtype.INT8,
66
+ Mode.WEIGHT_ONLY,
67
+ Algorithm.MIN_MAX,
68
+ Granularity.CHANNELWISE,
69
+ ),
70
+ (
71
+ Dtype.FP32,
72
+ Dtype.FP16,
73
+ Mode.WEIGHT_ONLY,
74
+ Algorithm.FLOAT_CAST,
75
+ Granularity.NONE,
76
+ ),
77
+ ])
82
78
  def test_verify_valid_recipes(
83
79
  self,
84
80
  activation,
@@ -87,10 +83,12 @@ class TestVerifyRecipes(unittest.TestCase):
87
83
  algo,
88
84
  granularity,
89
85
  ):
90
- quant_recipe.LayerQuantRecipe(activation, weight, mode, algo, granularity).verify()
86
+ quant_recipe.LayerQuantRecipe(
87
+ activation, weight, mode, algo, granularity
88
+ ).verify()
91
89
 
92
90
 
93
- class TestQuantizeConvert(unittest.TestCase):
91
+ class TestQuantizeConvert(parameterized.TestCase):
94
92
  """Test conversion with quantization."""
95
93
 
96
94
  def _attention_int8_dynamic_recipe() -> quant_config.QuantConfig:
@@ -107,35 +105,51 @@ class TestQuantizeConvert(unittest.TestCase):
107
105
  )
108
106
  )
109
107
 
110
- @parameterized.expand(
111
- [
112
- (quant_recipes.full_fp16_recipe(), 0.65),
113
- (quant_recipes.full_int8_dynamic_recipe(), 0.47),
114
- (_attention_int8_dynamic_recipe(), 0.89),
115
- (_feedforward_int8_dynamic_recipe(), 0.72),
116
- ]
117
- )
118
- def test_quantize_convert_toy_sizes(self, quant_config, expected_compression):
108
+ @parameterized.parameters([
109
+ (quant_recipes.full_fp16_recipe()),
110
+ (quant_recipes.full_int8_dynamic_recipe()),
111
+ (quant_recipes.full_int8_weight_only_recipe()),
112
+ (_attention_int8_dynamic_recipe()),
113
+ (_feedforward_int8_dynamic_recipe()),
114
+ ])
115
+ def test_quantize_convert_toy_sizes(self, quant_config):
119
116
  config = toy_model.get_model_config()
120
117
  pytorch_model = toy_model.ToySingleLayerModel(config)
121
- idx = torch.unsqueeze(torch.arange(0, 100), 0)
122
- input_pos = torch.arange(0, 100)
118
+ idx = torch.unsqueeze(torch.arange(0, 100, dtype=torch.int), 0)
119
+ input_pos = torch.arange(0, 100, dtype=torch.int)
120
+
121
+ quantized_model = ai_edge_torch.convert(
122
+ pytorch_model, (idx, input_pos), quant_config=quant_config
123
+ )
124
+ float_model = ai_edge_torch.convert(pytorch_model, (idx, input_pos))
125
+ self.assertLess(
126
+ len(quantized_model._tflite_model),
127
+ len(float_model._tflite_model),
128
+ "Quantized model isn't smaller than F32 model.",
129
+ )
123
130
 
131
+ def test_quantize_convert_toy_weight_sharing(self):
132
+ config = toy_model.get_model_config()
133
+ pytorch_model = toy_model.ToySingleLayerModelWeightSharing(config)
134
+ idx = torch.unsqueeze(torch.arange(0, 100, dtype=torch.int), 0)
135
+ input_pos = torch.arange(0, 100, dtype=torch.int)
136
+
137
+ quant_config = quant_recipes.full_int8_dynamic_recipe()
124
138
  quantized_model = ai_edge_torch.convert(
125
139
  pytorch_model, (idx, input_pos), quant_config=quant_config
126
140
  )
127
141
  float_model = ai_edge_torch.convert(pytorch_model, (idx, input_pos))
128
- self.assertAlmostEqual(
129
- len(quantized_model._tflite_model) / len(float_model._tflite_model),
130
- expected_compression,
131
- delta=0.01,
142
+ self.assertLess(
143
+ len(quantized_model._tflite_model),
144
+ len(float_model._tflite_model),
145
+ "Quantized model isn't smaller than F32 model.",
132
146
  )
133
147
 
134
148
  def test_quantize_convert_compare_toy(self):
135
149
  self.skipTest("b/338288901")
136
150
  config = toy_model_with_kv_cache.get_model_config()
137
151
  pytorch_model = toy_model_with_kv_cache.ToyModelWithKV(config)
138
- idx, input_pos = torch.tensor([[1]], dtype=torch.long), torch.tensor(
152
+ idx, input_pos = torch.tensor([[1]], dtype=torch.int), torch.tensor(
139
153
  [10], dtype=torch.int64
140
154
  )
141
155
 
@@ -145,7 +159,9 @@ class TestQuantizeConvert(unittest.TestCase):
145
159
  )
146
160
  float_model = ai_edge_torch.convert(pytorch_model, (idx, input_pos))
147
161
 
148
- self.assertLess(len(quantized_model._tflite_model), len(float_model._tflite_model))
162
+ self.assertLess(
163
+ len(quantized_model._tflite_model), len(float_model._tflite_model)
164
+ )
149
165
  self.assertTrue(
150
166
  model_coverage.compare_tflite_torch(
151
167
  quantized_model,
@@ -159,4 +175,4 @@ class TestQuantizeConvert(unittest.TestCase):
159
175
 
160
176
 
161
177
  if __name__ == "__main__":
162
- unittest.main()
178
+ googletest.main()
@@ -0,0 +1,54 @@
1
+ # Copyright 2024 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """Common utils for testing."""
17
+
18
+ from ai_edge_torch import model
19
+ from ai_edge_torch.generative.layers import kv_cache as kv_utils
20
+ from ai_edge_torch.lowertools import common_utils
21
+ import numpy as np
22
+ import torch
23
+ from torch.utils import _pytree as pytree
24
+
25
+
26
+ def compare_tflite_torch(
27
+ edge_model: model.Model,
28
+ torch_model: torch.nn.Module,
29
+ tokens: torch.Tensor,
30
+ input_pos: torch.Tensor,
31
+ kv_cache: kv_utils.KVCache,
32
+ signature_name: str,
33
+ atol: float = 1e-5,
34
+ rtol: float = 1e-5,
35
+ ):
36
+ """Compares torch models and TFLite models."""
37
+ values, spec = pytree.tree_flatten({"kv_cache": kv_cache})
38
+ flat_names = common_utils.flat_dict_names(spec.children_specs, spec.context)
39
+ torch_output = torch_model(tokens, input_pos, kv_cache)
40
+
41
+ input_kv_flatten = {k: v.numpy() for k, v in zip(flat_names, values)}
42
+ edge_output = edge_model(
43
+ signature_name=signature_name,
44
+ tokens=tokens.numpy(),
45
+ input_pos=input_pos.numpy(),
46
+ **input_kv_flatten,
47
+ )
48
+
49
+ return np.allclose(
50
+ edge_output["logits"],
51
+ torch_output["logits"].detach().numpy(),
52
+ atol=atol,
53
+ rtol=rtol,
54
+ )
@@ -0,0 +1,82 @@
1
+ # Copyright 2024 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """Common utility functions for model conversion."""
17
+
18
+ import ai_edge_torch
19
+ from ai_edge_torch.generative.layers import kv_cache as kv_utils
20
+ from ai_edge_torch.generative.quantize import quant_recipes
21
+ import torch
22
+
23
+
24
+ def convert_to_tflite(
25
+ pytorch_model: torch.nn.Module,
26
+ tflite_path: str,
27
+ prefill_seq_len: int = 512,
28
+ quantize: bool = True,
29
+ ):
30
+ """Converts a nn.Module model to multi-signature tflite model.
31
+
32
+ A PyTorch model will be converted to a tflite model with two signatures:
33
+ "prefill" and "decode".
34
+
35
+ "prefill" signature takes a tensor of shape [1, prefill_seq_len] of token
36
+ sequence, a tensor of shape [prefill_seq_len] of token positions, and an
37
+ external KV cache as a sample input.
38
+
39
+ "decode" signature takes a tensor of shape [1, 1] of token sequence, a tensor
40
+ of shape [1] of the token position, and an external KV cache as a sample
41
+ input.
42
+
43
+ The final tflite model will be exported to tflite_path.
44
+
45
+ Args:
46
+ pytorch_model (torch.nn.Module): PyTorch model to convert to tflite.
47
+ tflite_path (str): The tflite file path to export.
48
+ prefill_seq_len (int, optional): The maximum size of prefill input tensor.
49
+ Defaults to 512.
50
+ quantize (bool, optional): Whether the model should be quantized. Defaults
51
+ to True.
52
+ """
53
+ # Tensors used to trace the model graph during conversion.
54
+ prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.int)
55
+ prefill_input_pos = torch.arange(0, prefill_seq_len, dtype=torch.int)
56
+ decode_token = torch.tensor([[0]], dtype=torch.int)
57
+ decode_input_pos = torch.tensor([0], dtype=torch.int)
58
+ kv = kv_utils.KVCache.from_model_config(pytorch_model.config)
59
+
60
+ quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
61
+ edge_model = (
62
+ ai_edge_torch.signature(
63
+ 'prefill',
64
+ pytorch_model,
65
+ sample_kwargs={
66
+ 'tokens': prefill_tokens,
67
+ 'input_pos': prefill_input_pos,
68
+ 'kv_cache': kv,
69
+ },
70
+ )
71
+ .signature(
72
+ 'decode',
73
+ pytorch_model,
74
+ sample_kwargs={
75
+ 'tokens': decode_token,
76
+ 'input_pos': decode_input_pos,
77
+ 'kv_cache': kv,
78
+ },
79
+ )
80
+ .convert(quant_config=quant_config)
81
+ )
82
+ edge_model.export(tflite_path)