ai-edge-torch-nightly 0.1.dev202405131930__py3-none-any.whl

This diff shows the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects changes between package versions as they appear in their public registries.

Potentially problematic release: this version of ai-edge-torch-nightly might be problematic.

Files changed (91)
  1. ai_edge_torch/__init__.py +30 -0
  2. ai_edge_torch/convert/__init__.py +14 -0
  3. ai_edge_torch/convert/conversion.py +117 -0
  4. ai_edge_torch/convert/conversion_utils.py +330 -0
  5. ai_edge_torch/convert/converter.py +171 -0
  6. ai_edge_torch/convert/fx_passes/__init__.py +59 -0
  7. ai_edge_torch/convert/fx_passes/_pass_base.py +49 -0
  8. ai_edge_torch/convert/fx_passes/build_aten_composite_pass.py +192 -0
  9. ai_edge_torch/convert/fx_passes/build_upsample_bilinear2d_composite_pass.py +84 -0
  10. ai_edge_torch/convert/fx_passes/canonicalize_pass.py +37 -0
  11. ai_edge_torch/convert/fx_passes/inject_mlir_debuginfo_pass.py +73 -0
  12. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/__init__.py +16 -0
  13. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_check.py +215 -0
  14. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_mark.py +48 -0
  15. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/__init__.py +17 -0
  16. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py +59 -0
  17. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py +196 -0
  18. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +400 -0
  19. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/op_func_registry.py +30 -0
  20. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/pass_body.py +286 -0
  21. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/utils.py +62 -0
  22. ai_edge_torch/convert/test/__init__.py +14 -0
  23. ai_edge_torch/convert/test/test_convert.py +273 -0
  24. ai_edge_torch/convert/test/test_convert_composites.py +171 -0
  25. ai_edge_torch/convert/test/test_convert_multisig.py +139 -0
  26. ai_edge_torch/debug/__init__.py +16 -0
  27. ai_edge_torch/debug/culprit.py +423 -0
  28. ai_edge_torch/debug/test/__init__.py +14 -0
  29. ai_edge_torch/debug/test/test_culprit.py +133 -0
  30. ai_edge_torch/debug/utils.py +48 -0
  31. ai_edge_torch/experimental/__init__.py +14 -0
  32. ai_edge_torch/generative/__init__.py +14 -0
  33. ai_edge_torch/generative/examples/__init__.py +14 -0
  34. ai_edge_torch/generative/examples/gemma/__init__.py +14 -0
  35. ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +66 -0
  36. ai_edge_torch/generative/examples/gemma/gemma.py +174 -0
  37. ai_edge_torch/generative/examples/phi2/__init__.py +14 -0
  38. ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +64 -0
  39. ai_edge_torch/generative/examples/phi2/phi2.py +164 -0
  40. ai_edge_torch/generative/examples/t5/__init__.py +14 -0
  41. ai_edge_torch/generative/examples/t5/convert_to_tflite.py +135 -0
  42. ai_edge_torch/generative/examples/t5/t5.py +608 -0
  43. ai_edge_torch/generative/examples/t5/t5_attention.py +255 -0
  44. ai_edge_torch/generative/examples/test_models/__init__.py +14 -0
  45. ai_edge_torch/generative/examples/test_models/toy_model.py +119 -0
  46. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +143 -0
  47. ai_edge_torch/generative/examples/tiny_llama/__init__.py +0 -0
  48. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +66 -0
  49. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +164 -0
  50. ai_edge_torch/generative/layers/__init__.py +14 -0
  51. ai_edge_torch/generative/layers/attention.py +288 -0
  52. ai_edge_torch/generative/layers/attention_utils.py +169 -0
  53. ai_edge_torch/generative/layers/builder.py +103 -0
  54. ai_edge_torch/generative/layers/feed_forward.py +95 -0
  55. ai_edge_torch/generative/layers/kv_cache.py +83 -0
  56. ai_edge_torch/generative/layers/model_config.py +135 -0
  57. ai_edge_torch/generative/layers/normalization.py +62 -0
  58. ai_edge_torch/generative/layers/rotary_position_embedding.py +36 -0
  59. ai_edge_torch/generative/quantize/__init__.py +14 -0
  60. ai_edge_torch/generative/quantize/example.py +45 -0
  61. ai_edge_torch/generative/quantize/quant_attrs.py +66 -0
  62. ai_edge_torch/generative/quantize/quant_recipe.py +106 -0
  63. ai_edge_torch/generative/quantize/quant_recipe_utils.py +51 -0
  64. ai_edge_torch/generative/quantize/quant_recipes.py +48 -0
  65. ai_edge_torch/generative/quantize/supported_schemes.py +31 -0
  66. ai_edge_torch/generative/test/__init__.py +14 -0
  67. ai_edge_torch/generative/test/test_model_conversion.py +201 -0
  68. ai_edge_torch/generative/test/test_quantize.py +109 -0
  69. ai_edge_torch/generative/utilities/__init__.py +15 -0
  70. ai_edge_torch/generative/utilities/loader.py +290 -0
  71. ai_edge_torch/generative/utilities/t5_loader.py +467 -0
  72. ai_edge_torch/hlfb/__init__.py +16 -0
  73. ai_edge_torch/hlfb/mark_pattern/__init__.py +139 -0
  74. ai_edge_torch/hlfb/mark_pattern/passes.py +42 -0
  75. ai_edge_torch/hlfb/mark_pattern/pattern.py +260 -0
  76. ai_edge_torch/hlfb/test/__init__.py +14 -0
  77. ai_edge_torch/hlfb/test/test_mark_pattern.py +133 -0
  78. ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py +270 -0
  79. ai_edge_torch/model.py +134 -0
  80. ai_edge_torch/quantize/__init__.py +16 -0
  81. ai_edge_torch/quantize/pt2e_quantizer.py +438 -0
  82. ai_edge_torch/quantize/pt2e_quantizer_utils.py +1041 -0
  83. ai_edge_torch/quantize/quant_config.py +85 -0
  84. ai_edge_torch/testing/__init__.py +14 -0
  85. ai_edge_torch/testing/model_coverage/__init__.py +16 -0
  86. ai_edge_torch/testing/model_coverage/model_coverage.py +126 -0
  87. ai_edge_torch_nightly-0.1.dev202405131930.dist-info/LICENSE +202 -0
  88. ai_edge_torch_nightly-0.1.dev202405131930.dist-info/METADATA +38 -0
  89. ai_edge_torch_nightly-0.1.dev202405131930.dist-info/RECORD +91 -0
  90. ai_edge_torch_nightly-0.1.dev202405131930.dist-info/WHEEL +5 -0
  91. ai_edge_torch_nightly-0.1.dev202405131930.dist-info/top_level.txt +1 -0
ai_edge_torch/generative/quantize/quant_recipe.py
@@ -0,0 +1,106 @@
+ # Copyright 2024 The AI Edge Torch Authors. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ from dataclasses import dataclass
+ import enum
+ from typing import Optional
+
+ from ai_edge_torch.generative.quantize import quant_attrs
+ from ai_edge_torch.generative.quantize import supported_schemes
+
+
+ @dataclass
+ class LayerQuantRecipe:
+   """Quantization recipe for a single Edge Generative API layer (e.g. Attention).
+
+   Generic layer-scoped quantization recipe that specifies how this layer
+   should be quantized by the Edge Generative API. This is applicable to
+   layers implemented in ai_edge_torch/generative/layers/. Combinations of
+   attributes that are not supported at runtime are detected when .verify()
+   is called.
+
+   Attributes:
+     activation_dtype: Desired data type of activation tensors.
+     weight_dtype: Desired data type of weight tensors.
+     mode: Type of quantization.
+     algorithm: Algorithm for calculating quantization parameters.
+     granularity: Granularity of quantization.
+   """
+
+   activation_dtype: quant_attrs.Dtype
+   weight_dtype: quant_attrs.Dtype
+   mode: quant_attrs.Mode
+   algorithm: quant_attrs.Algorithm
+   granularity: quant_attrs.Granularity
+
+   def __str__(self):
+     return (
+         f'(a:{self.activation_dtype.name}, '
+         f'w:{self.weight_dtype.name}, '
+         f'{self.mode.name}, '
+         f'{self.algorithm.name}, '
+         f'{self.granularity.name})'
+     )
+
+   __repr__ = __str__
+
+   def verify(self):
+     """Checks that all configured attributes are supported at runtime.
+
+     Raises:
+       ValueError: If any attributes are incompatible.
+     """
+     is_valid = False
+     for supported in supported_schemes.get_supported_layer_schemes():
+       if (
+           self.activation_dtype == supported[0]
+           and self.weight_dtype == supported[1]
+           and self.mode == supported[2]
+           and self.algorithm == supported[3]
+           and self.granularity == supported[4]
+       ):
+         is_valid = True
+         break
+
+     if not is_valid:
+       raise ValueError(
+           'Unsupported LayerQuantRecipe configuration. See'
+           ' supported_schemes.get_supported_layer_schemes().'
+       )
+
+
+ @dataclass
+ class TransformerQuantRecipe:
+   """Quantization recipe for a model composed of the Edge Generative API layers.
+
+   Attributes:
+     default: The quantization recipe for the global scope of the model.
+   """
+
+   default: Optional[LayerQuantRecipe] = None
+
+   def __str__(self):
+     return f"""TransformerQuantRecipe(
+   Default: {self.default}
+ )"""
+
+   __repr__ = __str__
+
+   def verify(self):
+     """Checks that the configured recipe is supported at runtime.
+
+     Raises:
+       ValueError: If the configured recipe is invalid or unsupported.
+     """
+     if self.default is not None:
+       self.default.verify()
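
To make the recipe API concrete, here is a minimal sketch (editorial, not part of the diff) that builds a layer recipe matching the int8 dynamic-range scheme listed in supported_schemes.py, wraps it in a TransformerQuantRecipe, and verifies it:

    from ai_edge_torch.generative.quantize import quant_attrs
    from ai_edge_torch.generative.quantize import quant_recipe

    # Matches the (FP32, INT8, DYNAMIC_RANGE, MIN_MAX, CHANNELWISE) scheme.
    layer = quant_recipe.LayerQuantRecipe(
        activation_dtype=quant_attrs.Dtype.FP32,
        weight_dtype=quant_attrs.Dtype.INT8,
        mode=quant_attrs.Mode.DYNAMIC_RANGE,
        algorithm=quant_attrs.Algorithm.MIN_MAX,
        granularity=quant_attrs.Granularity.CHANNELWISE,
    )
    recipe = quant_recipe.TransformerQuantRecipe(default=layer)
    recipe.verify()  # Raises ValueError for unsupported combinations.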
ai_edge_torch/generative/quantize/quant_recipe_utils.py
@@ -0,0 +1,51 @@
+ # Copyright 2024 The AI Edge Torch Authors. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ """Helper functions to construct custom quantization recipes.
+
+ These are intended for more advanced users who want to configure their own
+ quantization recipes. For pre-constructed recipes, use `quant_recipes.py`
+ instead.
+
+ Typical usage example:
+
+ 1. Applying a single layer recipe to the entire model
+
+   quant_recipe.TransformerQuantRecipe(
+       default=quant_recipe_utils.create_layer_quant_int8_dynamic()
+   )
+ """
+
+ from ai_edge_torch.generative.quantize import quant_attrs
+ from ai_edge_torch.generative.quantize import quant_recipe
+
+
+ def create_layer_quant_int8_dynamic() -> quant_recipe.LayerQuantRecipe:
+   return quant_recipe.LayerQuantRecipe(
+       activation_dtype=quant_attrs.Dtype.FP32,
+       weight_dtype=quant_attrs.Dtype.INT8,
+       mode=quant_attrs.Mode.DYNAMIC_RANGE,
+       algorithm=quant_attrs.Algorithm.MIN_MAX,
+       granularity=quant_attrs.Granularity.CHANNELWISE,
+   )
+
+
+ def create_layer_quant_fp16() -> quant_recipe.LayerQuantRecipe:
+   return quant_recipe.LayerQuantRecipe(
+       activation_dtype=quant_attrs.Dtype.FP32,
+       weight_dtype=quant_attrs.Dtype.FP16,
+       mode=quant_attrs.Mode.WEIGHT_ONLY,
+       algorithm=quant_attrs.Algorithm.MIN_MAX,
+       granularity=quant_attrs.Granularity.NONE,
+   )
ai_edge_torch/generative/quantize/quant_recipes.py
@@ -0,0 +1,48 @@
+ # Copyright 2024 The AI Edge Torch Authors. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ """Helper functions to create common and supported quantization recipes.
+
+ These recipes work only with models created via the Edge Generative API and
+ assume a Transformer architecture congruent with
+ ai_edge_torch/generative/layers/model_config.py:ModelConfig.
+
+ Typical usage example:
+
+   quant_config = quant_recipes.full_linear_int8_dynamic_recipe()
+   edge_model = ai_edge_torch.convert(
+       model, (tokens, input_pos), quant_config=quant_config
+   )
+ """
+
+ from ai_edge_torch.generative.quantize import quant_recipe
+ from ai_edge_torch.generative.quantize import quant_recipe_utils
+ from ai_edge_torch.quantize import quant_config
+
+
+ def full_linear_int8_dynamic_recipe() -> quant_config.QuantConfig:
+   return quant_config.QuantConfig(
+       transformer_recipe=quant_recipe.TransformerQuantRecipe(
+           default=quant_recipe_utils.create_layer_quant_int8_dynamic()
+       )
+   )
+
+
+ def full_fp16_recipe() -> quant_config.QuantConfig:
+   return quant_config.QuantConfig(
+       transformer_recipe=quant_recipe.TransformerQuantRecipe(
+           default=quant_recipe_utils.create_layer_quant_fp16()
+       )
+   )
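
Putting these pre-built recipes to work end to end looks roughly like the sketch below; `MyGenerativeModel`, `tokens`, and `input_pos` are placeholders, and the `export` call assumes the serialization API in ai_edge_torch/model.py:

    import torch
    import ai_edge_torch
    from ai_edge_torch.generative.quantize import quant_recipes

    pytorch_model = MyGenerativeModel()  # Hypothetical Edge Generative API model.
    tokens = torch.full((1, 10), 0, dtype=torch.long)
    input_pos = torch.arange(0, 10)

    quant_config = quant_recipes.full_linear_int8_dynamic_recipe()
    edge_model = ai_edge_torch.convert(
        pytorch_model, (tokens, input_pos), quant_config=quant_config
    )
    edge_model.export("/tmp/model_int8.tflite")  # Assumed export API.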
ai_edge_torch/generative/quantize/supported_schemes.py
@@ -0,0 +1,31 @@
+ # Copyright 2024 The AI Edge Torch Authors. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+
+ def get_supported_layer_schemes():
+   """Returns the layer-scoped quantization schemes supported at runtime.
+
+   Returns:
+     A list of tuples (activation_dtype, weight_dtype, mode, algorithm,
+     granularity).
+   """
+   from ai_edge_torch.generative.quantize.quant_attrs import Algorithm as _a
+   from ai_edge_torch.generative.quantize.quant_attrs import Dtype as _t
+   from ai_edge_torch.generative.quantize.quant_attrs import Granularity as _g
+   from ai_edge_torch.generative.quantize.quant_attrs import Mode as _m
+
+   return [
+       (_t.FP32, _t.INT8, _m.DYNAMIC_RANGE, _a.MIN_MAX, _g.CHANNELWISE),
+       (_t.FP32, _t.FP16, _m.WEIGHT_ONLY, _a.MIN_MAX, _g.NONE),
+   ]
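
As a sanity check (a sketch, not part of the package), each helper in quant_recipe_utils.py should produce a recipe whose attribute tuple appears in this list:

    from ai_edge_torch.generative.quantize import quant_recipe_utils
    from ai_edge_torch.generative.quantize import supported_schemes

    for make_recipe in (
        quant_recipe_utils.create_layer_quant_int8_dynamic,
        quant_recipe_utils.create_layer_quant_fp16,
    ):
      r = make_recipe()
      # Repack the recipe's attributes in the tuple order used by
      # get_supported_layer_schemes().
      scheme = (r.activation_dtype, r.weight_dtype, r.mode, r.algorithm,
                r.granularity)
      assert scheme in supported_schemes.get_supported_layer_schemes()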
ai_edge_torch/generative/test/__init__.py
@@ -0,0 +1,14 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
ai_edge_torch/generative/test/test_model_conversion.py
@@ -0,0 +1,201 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+ # Tests model conversion for a few generative AI models.
+ import copy
+ import os
+ import tempfile
+ import unittest
+
+ import numpy as np
+ import torch
+
+ import ai_edge_torch
+ from ai_edge_torch.generative.examples.gemma import gemma
+ from ai_edge_torch.generative.examples.phi2 import phi2
+ from ai_edge_torch.generative.examples.test_models import toy_model_with_kv_cache  # NOQA
+ from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
+ from ai_edge_torch.testing import model_coverage
+
+
+ class TestModelConversion(unittest.TestCase):
+   """Unit tests that check for model conversion and correctness."""
+
+   def test_toy_model_with_kv_cache(self):
+     self.skipTest("b/338288901")
+     config = toy_model_with_kv_cache.get_model_config()
+     pytorch_model = toy_model_with_kv_cache.ToyModelWithKV(config)
+     idx, input_pos = torch.tensor([[1]], dtype=torch.long), torch.tensor(
+         [10], dtype=torch.int64
+     )
+
+     edge_model = ai_edge_torch.convert(pytorch_model, (idx, input_pos))
+
+     self.assertTrue(
+         model_coverage.compare_tflite_torch(
+             edge_model,
+             pytorch_model,
+             (idx, input_pos),
+             num_valid_inputs=1,
+             atol=1e-5,
+             rtol=1e-5,
+         )
+     )
+
+   def test_toy_model_with_kv_cache_with_hlfb(self):
+     self.skipTest("b/338288901")
+     config = toy_model_with_kv_cache.get_model_config()
+     config.enable_hlfb = True
+     pytorch_model = toy_model_with_kv_cache.ToyModelWithKV(config)
+     idx, input_pos = torch.tensor([[1]], dtype=torch.long), torch.tensor(
+         [10], dtype=torch.int64
+     )
+
+     edge_model = ai_edge_torch.convert(pytorch_model, (idx, input_pos))
+
+     self.assertTrue(
+         model_coverage.compare_tflite_torch(
+             edge_model,
+             pytorch_model,
+             (idx, input_pos),
+             num_valid_inputs=1,
+             atol=1e-5,
+             rtol=1e-5,
+         )
+     )
+
+   def test_tiny_llama(self):
+     self.skipTest("b/338288901")
+     config = tiny_llama.get_fake_model_config_for_test()
+     pytorch_model = tiny_llama.TinyLLamma(config)
+
+     idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
+     tokens = torch.full((1, 10), 0, dtype=torch.long, device="cpu")
+     tokens[0, :4] = idx
+     input_pos = torch.arange(0, 10)
+
+     edge_model = ai_edge_torch.convert(pytorch_model, (tokens, input_pos))
+
+     self.assertTrue(
+         model_coverage.compare_tflite_torch(
+             edge_model,
+             pytorch_model,
+             (tokens, input_pos),
+             num_valid_inputs=1,
+             atol=1e-5,
+             rtol=1e-5,
+         )
+     )
+
+   def test_tiny_llama_multisig(self):
+     self.skipTest("b/338288901")
+     config = tiny_llama.get_fake_model_config_for_test()
+     pytorch_model = tiny_llama.TinyLLamma(config)
+
+     # Prefill inputs.
+     seq_len = 10
+     prefill_tokens = torch.full((1, seq_len), 0, dtype=torch.long, device="cpu")
+     prompt_token = torch.from_numpy(np.array([1, 2, 3, 4]))
+     prefill_tokens[0, : len(prompt_token)] = prompt_token
+     prefill_input_pos = torch.arange(0, seq_len)
+
+     # Decode inputs.
+     decode_token = torch.tensor([[1]], dtype=torch.long)
+     decode_input_pos = torch.tensor([5], dtype=torch.int64)
+
+     edge_model = (
+         ai_edge_torch.signature(
+             "prefill", pytorch_model, (prefill_tokens, prefill_input_pos)
+         )
+         .signature("decode", pytorch_model, (decode_token, decode_input_pos))
+         .convert()
+     )
+
+     # For the PyTorch model, the KV cache is persistent state internal to the
+     # model and is shared between prefill and decode. For TFLite, the KV cache
+     # currently can't be shared between the two signatures: prefill mutates
+     # the cache, but the mutation is not visible to the decode signature.
+     # Running `decode` after `prefill` in PyTorch therefore produces different
+     # output than running `decode` after `prefill` via ai_edge_torch, so
+     # decode is compared against a fresh copy of the model.
+     copied_model = copy.deepcopy(pytorch_model)
+
+     self.assertTrue(
+         model_coverage.compare_tflite_torch(
+             edge_model,
+             pytorch_model,
+             (prefill_tokens, prefill_input_pos),
+             signature_name="prefill",
+             num_valid_inputs=1,
+         )
+     )
+
+     self.assertTrue(
+         model_coverage.compare_tflite_torch(
+             edge_model,
+             copied_model,
+             (decode_token, decode_input_pos),
+             signature_name="decode",
+             num_valid_inputs=1,
+         )
+     )
+
+   def test_gemma(self):
+     self.skipTest("b/338288901")
+     config = gemma.get_fake_model_config_2b_for_test()
+     model = gemma.Gemma(config)
+
+     idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
+     tokens = torch.full((1, 10), 0, dtype=torch.long, device="cpu")
+     tokens[0, :4] = idx
+     input_pos = torch.arange(0, 10)
+
+     edge_model = ai_edge_torch.convert(model, (tokens, input_pos))
+
+     # TODO(talumbau, haoliang): debug numerical diff.
+     self.assertTrue(
+         model_coverage.compare_tflite_torch(
+             edge_model,
+             model,
+             (tokens, input_pos),
+             num_valid_inputs=1,
+             atol=1e-2,
+             rtol=1e-5,
+         )
+     )
+
+   def test_phi2(self):
+     self.skipTest("b/338288901")
+     config = phi2.get_fake_model_config_for_test()
+     pytorch_model = phi2.Phi2(config)
+
+     idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
+     tokens = torch.full((1, 10), 0, dtype=torch.long, device="cpu")
+     tokens[0, :4] = idx
+     input_pos = torch.arange(0, 10)
+
+     edge_model = ai_edge_torch.convert(pytorch_model, (tokens, input_pos))
+
+     self.assertTrue(
+         model_coverage.compare_tflite_torch(
+             edge_model,
+             pytorch_model,
+             (tokens, input_pos),
+             num_valid_inputs=1,
+             atol=1e-5,
+             rtol=1e-5,
+         )
+     )
+
+
+ if __name__ == "__main__":
+   unittest.main()
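
Distilled from test_tiny_llama_multisig above, the multi-signature export pattern is the following sketch, with `model`, `prefill_args`, and `decode_args` as placeholders:

    import ai_edge_torch

    # Register two named entry points on the same PyTorch model, then convert
    # once into a single multi-signature TFLite model.
    edge_model = (
        ai_edge_torch.signature("prefill", model, prefill_args)
        .signature("decode", model, decode_args)
        .convert()
    )

Each signature is then exercised by name, e.g. compare_tflite_torch(..., signature_name="prefill") as in the test.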
ai_edge_torch/generative/test/test_quantize.py
@@ -0,0 +1,109 @@
+ # Copyright 2024 The AI Edge Torch Authors. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ import unittest
+
+ from parameterized import parameterized
+ import torch
+
+ import ai_edge_torch
+ from ai_edge_torch.generative.examples.test_models import toy_model_with_kv_cache  # NOQA
+ from ai_edge_torch.generative.quantize import quant_recipe
+ from ai_edge_torch.generative.quantize import quant_recipes
+ from ai_edge_torch.generative.quantize.quant_attrs import Algorithm
+ from ai_edge_torch.generative.quantize.quant_attrs import Dtype
+ from ai_edge_torch.generative.quantize.quant_attrs import Granularity
+ from ai_edge_torch.generative.quantize.quant_attrs import Mode
+ from ai_edge_torch.testing import model_coverage
+
+
+ class TestVerifyRecipes(unittest.TestCase):
+   """Unit tests that check for model quantization recipes."""
+
+   @parameterized.expand(
+       [
+           (Dtype.FP32, Dtype.FP32, Mode.DYNAMIC_RANGE),
+           (Dtype.INT8, Dtype.INT8, Mode.DYNAMIC_RANGE),
+           (Dtype.INT8, Dtype.FP16, Mode.DYNAMIC_RANGE),
+           (Dtype.FP16, Dtype.INT8, Mode.DYNAMIC_RANGE),
+           (Dtype.FP32, Dtype.FP32, Mode.WEIGHT_ONLY),
+           (Dtype.INT8, Dtype.INT8, Mode.WEIGHT_ONLY),
+           (Dtype.FP16, Dtype.INT8, Mode.WEIGHT_ONLY),
+           (Dtype.INT8, Dtype.FP16, Mode.WEIGHT_ONLY),
+           (Dtype.FP16, Dtype.FP16, Mode.WEIGHT_ONLY),
+       ]
+   )
+   def test_verify_invalid_recipes(
+       self,
+       activation,
+       weight,
+       mode,
+       algo=Algorithm.MIN_MAX,
+       granularity=Granularity.CHANNELWISE,
+   ):
+     with self.assertRaises(ValueError):
+       quant_recipe.LayerQuantRecipe(
+           activation, weight, mode, algo, granularity
+       ).verify()
+
+   @parameterized.expand(
+       [
+           (Dtype.FP32, Dtype.INT8, Mode.DYNAMIC_RANGE, Granularity.CHANNELWISE),
+           (Dtype.FP32, Dtype.FP16, Mode.WEIGHT_ONLY, Granularity.NONE),
+       ]
+   )
+   def test_verify_valid_recipes(
+       self,
+       activation,
+       weight,
+       mode,
+       granularity,
+       algo=Algorithm.MIN_MAX,
+   ):
+     quant_recipe.LayerQuantRecipe(
+         activation, weight, mode, algo, granularity
+     ).verify()
+
+
+ class TestQuantizeConvert(unittest.TestCase):
+   """Test conversion with quantization."""
+
+   def test_quantize_convert_toy(self):
+     self.skipTest("b/338288901")
+     config = toy_model_with_kv_cache.get_model_config()
+     pytorch_model = toy_model_with_kv_cache.ToyModelWithKV(config)
+     idx, input_pos = torch.tensor([[1]], dtype=torch.long), torch.tensor(
+         [10], dtype=torch.int64
+     )
+
+     quant_config = quant_recipes.full_fp16_recipe()
+     quantized_model = ai_edge_torch.convert(
+         pytorch_model, (idx, input_pos), quant_config=quant_config
+     )
+     float_model = ai_edge_torch.convert(pytorch_model, (idx, input_pos))
+
+     self.assertLess(
+         len(quantized_model._tflite_model), len(float_model._tflite_model)
+     )
+     self.assertTrue(
+         model_coverage.compare_tflite_torch(
+             quantized_model,
+             pytorch_model,
+             (idx, input_pos),
+             num_valid_inputs=1,
+             atol=1e-3,
+             rtol=1e-3,
+         )
+     )
+
+
+ if __name__ == "__main__":
+   unittest.main()
ai_edge_torch/generative/utilities/__init__.py
@@ -0,0 +1,15 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+ # This module contains common utility functions.