ai-edge-torch-nightly 0.2.0.dev20240714__py3-none-any.whl

This diff shows the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Note: this version of ai-edge-torch-nightly has been flagged as a potentially problematic release.

Files changed (121)
  1. ai_edge_torch/__init__.py +31 -0
  2. ai_edge_torch/convert/__init__.py +14 -0
  3. ai_edge_torch/convert/conversion.py +117 -0
  4. ai_edge_torch/convert/conversion_utils.py +400 -0
  5. ai_edge_torch/convert/converter.py +202 -0
  6. ai_edge_torch/convert/fx_passes/__init__.py +59 -0
  7. ai_edge_torch/convert/fx_passes/_pass_base.py +49 -0
  8. ai_edge_torch/convert/fx_passes/build_aten_composite_pass.py +225 -0
  9. ai_edge_torch/convert/fx_passes/build_interpolate_composite_pass.py +123 -0
  10. ai_edge_torch/convert/fx_passes/canonicalize_pass.py +37 -0
  11. ai_edge_torch/convert/fx_passes/inject_mlir_debuginfo_pass.py +73 -0
  12. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/__init__.py +16 -0
  13. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_check.py +215 -0
  14. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_mark.py +48 -0
  15. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/__init__.py +17 -0
  16. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py +59 -0
  17. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py +215 -0
  18. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +400 -0
  19. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/op_func_registry.py +30 -0
  20. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/pass_body.py +293 -0
  21. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/utils.py +62 -0
  22. ai_edge_torch/convert/test/__init__.py +14 -0
  23. ai_edge_torch/convert/test/test_convert.py +311 -0
  24. ai_edge_torch/convert/test/test_convert_composites.py +192 -0
  25. ai_edge_torch/convert/test/test_convert_multisig.py +139 -0
  26. ai_edge_torch/convert/test/test_to_channel_last_io.py +96 -0
  27. ai_edge_torch/convert/to_channel_last_io.py +85 -0
  28. ai_edge_torch/debug/__init__.py +17 -0
  29. ai_edge_torch/debug/culprit.py +464 -0
  30. ai_edge_torch/debug/test/__init__.py +14 -0
  31. ai_edge_torch/debug/test/test_culprit.py +133 -0
  32. ai_edge_torch/debug/test/test_search_model.py +50 -0
  33. ai_edge_torch/debug/utils.py +48 -0
  34. ai_edge_torch/experimental/__init__.py +14 -0
  35. ai_edge_torch/generative/__init__.py +14 -0
  36. ai_edge_torch/generative/examples/__init__.py +14 -0
  37. ai_edge_torch/generative/examples/gemma/__init__.py +14 -0
  38. ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +66 -0
  39. ai_edge_torch/generative/examples/gemma/gemma.py +174 -0
  40. ai_edge_torch/generative/examples/phi2/__init__.py +14 -0
  41. ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +64 -0
  42. ai_edge_torch/generative/examples/phi2/phi2.py +164 -0
  43. ai_edge_torch/generative/examples/stable_diffusion/__init__.py +14 -0
  44. ai_edge_torch/generative/examples/stable_diffusion/attention.py +106 -0
  45. ai_edge_torch/generative/examples/stable_diffusion/clip.py +115 -0
  46. ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +142 -0
  47. ai_edge_torch/generative/examples/stable_diffusion/decoder.py +317 -0
  48. ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +573 -0
  49. ai_edge_torch/generative/examples/stable_diffusion/encoder.py +118 -0
  50. ai_edge_torch/generative/examples/stable_diffusion/pipeline.py +222 -0
  51. ai_edge_torch/generative/examples/stable_diffusion/samplers/__init__.py +19 -0
  52. ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler.py +61 -0
  53. ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler_ancestral.py +65 -0
  54. ai_edge_torch/generative/examples/stable_diffusion/samplers/k_lms.py +73 -0
  55. ai_edge_torch/generative/examples/stable_diffusion/samplers/sampler.py +38 -0
  56. ai_edge_torch/generative/examples/stable_diffusion/tokenizer.py +108 -0
  57. ai_edge_torch/generative/examples/stable_diffusion/util.py +71 -0
  58. ai_edge_torch/generative/examples/t5/__init__.py +14 -0
  59. ai_edge_torch/generative/examples/t5/convert_to_tflite.py +135 -0
  60. ai_edge_torch/generative/examples/t5/t5.py +608 -0
  61. ai_edge_torch/generative/examples/t5/t5_attention.py +231 -0
  62. ai_edge_torch/generative/examples/test_models/__init__.py +14 -0
  63. ai_edge_torch/generative/examples/test_models/toy_model.py +122 -0
  64. ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +161 -0
  65. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +143 -0
  66. ai_edge_torch/generative/examples/tiny_llama/__init__.py +0 -0
  67. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +66 -0
  68. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +164 -0
  69. ai_edge_torch/generative/fx_passes/__init__.py +31 -0
  70. ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py +47 -0
  71. ai_edge_torch/generative/layers/__init__.py +14 -0
  72. ai_edge_torch/generative/layers/attention.py +354 -0
  73. ai_edge_torch/generative/layers/attention_utils.py +169 -0
  74. ai_edge_torch/generative/layers/builder.py +131 -0
  75. ai_edge_torch/generative/layers/feed_forward.py +95 -0
  76. ai_edge_torch/generative/layers/kv_cache.py +83 -0
  77. ai_edge_torch/generative/layers/model_config.py +158 -0
  78. ai_edge_torch/generative/layers/normalization.py +62 -0
  79. ai_edge_torch/generative/layers/rotary_position_embedding.py +36 -0
  80. ai_edge_torch/generative/layers/scaled_dot_product_attention.py +117 -0
  81. ai_edge_torch/generative/layers/unet/__init__.py +14 -0
  82. ai_edge_torch/generative/layers/unet/blocks_2d.py +711 -0
  83. ai_edge_torch/generative/layers/unet/builder.py +47 -0
  84. ai_edge_torch/generative/layers/unet/model_config.py +269 -0
  85. ai_edge_torch/generative/quantize/__init__.py +14 -0
  86. ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/__init__.py +0 -0
  87. ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/translate_recipe.py +148 -0
  88. ai_edge_torch/generative/quantize/example.py +45 -0
  89. ai_edge_torch/generative/quantize/quant_attrs.py +68 -0
  90. ai_edge_torch/generative/quantize/quant_recipe.py +151 -0
  91. ai_edge_torch/generative/quantize/quant_recipe_utils.py +51 -0
  92. ai_edge_torch/generative/quantize/quant_recipes.py +48 -0
  93. ai_edge_torch/generative/quantize/supported_schemes.py +32 -0
  94. ai_edge_torch/generative/test/__init__.py +14 -0
  95. ai_edge_torch/generative/test/loader_test.py +80 -0
  96. ai_edge_torch/generative/test/test_model_conversion.py +235 -0
  97. ai_edge_torch/generative/test/test_quantize.py +162 -0
  98. ai_edge_torch/generative/utilities/__init__.py +15 -0
  99. ai_edge_torch/generative/utilities/loader.py +328 -0
  100. ai_edge_torch/generative/utilities/stable_diffusion_loader.py +924 -0
  101. ai_edge_torch/generative/utilities/t5_loader.py +483 -0
  102. ai_edge_torch/hlfb/__init__.py +16 -0
  103. ai_edge_torch/hlfb/mark_pattern/__init__.py +139 -0
  104. ai_edge_torch/hlfb/mark_pattern/passes.py +42 -0
  105. ai_edge_torch/hlfb/mark_pattern/pattern.py +273 -0
  106. ai_edge_torch/hlfb/test/__init__.py +14 -0
  107. ai_edge_torch/hlfb/test/test_mark_pattern.py +133 -0
  108. ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py +270 -0
  109. ai_edge_torch/model.py +142 -0
  110. ai_edge_torch/quantize/__init__.py +16 -0
  111. ai_edge_torch/quantize/pt2e_quantizer.py +438 -0
  112. ai_edge_torch/quantize/pt2e_quantizer_utils.py +1041 -0
  113. ai_edge_torch/quantize/quant_config.py +81 -0
  114. ai_edge_torch/testing/__init__.py +14 -0
  115. ai_edge_torch/testing/model_coverage/__init__.py +16 -0
  116. ai_edge_torch/testing/model_coverage/model_coverage.py +132 -0
  117. ai_edge_torch_nightly-0.2.0.dev20240714.dist-info/LICENSE +202 -0
  118. ai_edge_torch_nightly-0.2.0.dev20240714.dist-info/METADATA +38 -0
  119. ai_edge_torch_nightly-0.2.0.dev20240714.dist-info/RECORD +121 -0
  120. ai_edge_torch_nightly-0.2.0.dev20240714.dist-info/WHEEL +5 -0
  121. ai_edge_torch_nightly-0.2.0.dev20240714.dist-info/top_level.txt +1 -0
ai_edge_torch/generative/layers/unet/builder.py
@@ -0,0 +1,47 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+ # Builder utils for individual components.
+
+ from torch import nn
+
+ import ai_edge_torch.generative.layers.unet.model_config as unet_config
+
+
+ def build_upsampling(config: unet_config.UpSamplingConfig):
+   if config.mode == unet_config.SamplingType.NEAREST:
+     return nn.UpsamplingNearest2d(scale_factor=config.scale_factor)
+   elif config.mode == unet_config.SamplingType.BILINEAR:
+     return nn.UpsamplingBilinear2d(scale_factor=config.scale_factor)
+   else:
+     raise ValueError("Unsupported upsampling type.")
+
+
+ def build_downsampling(config: unet_config.DownSamplingConfig):
+   if config.mode == unet_config.SamplingType.AVERAGE:
+     return nn.AvgPool2d(config.kernel_size, config.stride, padding=config.padding)
+   elif config.mode == unet_config.SamplingType.CONVOLUTION:
+     out_channels = (
+         config.in_channels if config.out_channels is None else config.out_channels
+     )
+     padding = (0, 1, 0, 1) if config.padding == 0 else config.padding
+     return nn.Conv2d(
+         config.in_channels,
+         out_channels=out_channels,
+         kernel_size=config.kernel_size,
+         stride=config.stride,
+         padding=padding,
+     )
+   else:
+     raise ValueError("Unsupported downsampling type.")
ai_edge_torch/generative/layers/unet/model_config.py
@@ -0,0 +1,269 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ # UNet configuration class.
+ from dataclasses import dataclass
+ from dataclasses import field
+ import enum
+ from typing import List, Optional
+
+ import ai_edge_torch.generative.layers.model_config as layers_cfg
+
+
+ @enum.unique
+ class SamplingType(enum.Enum):
+   NEAREST = enum.auto()
+   BILINEAR = enum.auto()
+   AVERAGE = enum.auto()
+   CONVOLUTION = enum.auto()
+
+
+ @dataclass
+ class UpSamplingConfig:
+   mode: SamplingType
+   scale_factor: float
+
+
+ @dataclass
+ class DownSamplingConfig:
+   mode: SamplingType
+   in_channels: int
+   kernel_size: int
+   stride: int
+   padding: int
+   out_channels: Optional[int] = None
+
+
+ @dataclass
+ class ResidualBlock2DConfig:
+   in_channels: int
+   out_channels: int
+   normalization_config: layers_cfg.NormalizationConfig
+   activation_config: layers_cfg.ActivationConfig
+   # Optional time embedding channels if the residual block takes a time embedding context as input.
+   time_embedding_channels: Optional[int] = None
+
+
+ @dataclass
+ class AttentionBlock2DConfig:
+   dim: int
+   normalization_config: layers_cfg.NormalizationConfig
+   attention_config: layers_cfg.AttentionConfig
+   enable_hlfb: bool = True
+   attention_batch_size: int = 1
+
+
+ @dataclass
+ class CrossAttentionBlock2DConfig:
+   query_dim: int
+   cross_dim: int
+   normalization_config: layers_cfg.NormalizationConfig
+   attention_config: layers_cfg.AttentionConfig
+   enable_hlfb: bool = True
+   attention_batch_size: int = 1
+
+
+ @dataclass
+ class FeedForwardBlock2DConfig:
+   dim: int
+   hidden_dim: int
+   normalization_config: layers_cfg.NormalizationConfig
+   activation_config: layers_cfg.ActivationConfig
+   use_bias: bool
+
+
+ @dataclass
+ class TransformerBlock2DConfig:
+   pre_conv_normalization_config: layers_cfg.NormalizationConfig
+   attention_block_config: AttentionBlock2DConfig
+   cross_attention_block_config: CrossAttentionBlock2DConfig
+   feed_forward_block_config: FeedForwardBlock2DConfig
+
+
+ @dataclass
+ class UpDecoderBlock2DConfig:
+   in_channels: int
+   out_channels: int
+   normalization_config: layers_cfg.NormalizationConfig
+   activation_config: layers_cfg.ActivationConfig
+   num_layers: int
+   # Optional time embedding channels if the residual blocks take a time embedding as input.
+   time_embedding_channels: Optional[int] = None
+   # Whether to add an upsample operation after the residual blocks.
+   add_upsample: bool = True
+   # Whether to add a conv2d layer after upsample.
+   upsample_conv: bool = True
+   # Optional sampling config if add_upsample is True.
+   sampling_config: Optional[UpSamplingConfig] = None
+   # Optional config of transformer blocks interleaved with residual blocks.
+   transformer_block_config: Optional[TransformerBlock2DConfig] = None
+   # Optional dimension of context tensor if a context tensor is given as input.
+   context_dim: Optional[int] = None
+
+
+ @dataclass
+ class SkipUpDecoderBlock2DConfig:
+   in_channels: int
+   out_channels: int
+   # The dimension of output channels of the previously connected block.
+   prev_out_channels: int
+   normalization_config: layers_cfg.NormalizationConfig
+   activation_config: layers_cfg.ActivationConfig
+   num_layers: int
+   # Optional time embedding channels if the residual blocks take a time embedding as input.
+   time_embedding_channels: Optional[int] = None
+   # Whether to add an upsample operation after the residual blocks.
+   add_upsample: bool = True
+   # Whether to add a conv2d layer after upsample.
+   upsample_conv: bool = True
+   # Optional sampling config if add_upsample is True.
+   sampling_config: Optional[UpSamplingConfig] = None
+   # Optional config of transformer blocks interleaved with residual blocks.
+   transformer_block_config: Optional[TransformerBlock2DConfig] = None
+   # Optional dimension of context tensor if a context tensor is given as input.
+   context_dim: Optional[int] = None
+
+
+ @dataclass
+ class DownEncoderBlock2DConfig:
+   in_channels: int
+   out_channels: int
+   normalization_config: layers_cfg.NormalizationConfig
+   activation_config: layers_cfg.ActivationConfig
+   num_layers: int
+   # Padding for the downsampling convolution.
+   padding: int = 1
+   # Optional time embedding channels if the residual blocks take a time embedding as input.
+   time_embedding_channels: Optional[int] = None
+   # Whether to add a downsample operation after the residual blocks.
+   add_downsample: bool = True
+   # Optional sampling config if add_downsample is True.
+   sampling_config: Optional[DownSamplingConfig] = None
+   # Optional config of transformer blocks interleaved with residual blocks.
+   transformer_block_config: Optional[TransformerBlock2DConfig] = None
+   # Optional dimension of context tensor if a context tensor is given as input.
+   context_dim: Optional[int] = None
+
+
+ @dataclass
+ class MidBlock2DConfig:
+   in_channels: int
+   normalization_config: layers_cfg.NormalizationConfig
+   activation_config: layers_cfg.ActivationConfig
+   num_layers: int
+   # Optional time embedding channels if the residual blocks take a time embedding context as input.
+   time_embedding_channels: Optional[int] = None
+   # Optional config of attention blocks interleaved with residual blocks.
+   attention_block_config: Optional[AttentionBlock2DConfig] = None
+   # Optional config of transformer blocks interleaved with residual blocks.
+   transformer_block_config: Optional[TransformerBlock2DConfig] = None
+   # Optional dimension of context tensor if a context tensor is given as input.
+   context_dim: Optional[int] = None
+
+
+ @dataclass
+ class AutoEncoderConfig:
+   """Configurations of encoder/decoder in the autoencoder model."""
+
+   # The activation type of encoder/decoder blocks.
+   activation_config: layers_cfg.ActivationConfig
+
+   # The output channels of each block.
+   block_out_channels: List[int]
+
+   # Number of channels in the input image.
+   in_channels: int
+
+   # Number of channels in the output.
+   out_channels: int
+
+   # Number of channels in the latent space.
+   latent_channels: int
+
+   # The component-wise standard deviation of the trained latent space computed using the first batch of
+   # the training set. This is used to scale the latent space to have unit variance when training the
+   # diffusion model. The latents are scaled with the formula `z = z * scaling_factor` before being passed
+   # to the diffusion model. When decoding, the latents are scaled back to the original scale with the
+   # formula `z = 1 / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the
+   # [High-Resolution Image Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
+   scaling_factor: float
+
+   # The number of layers in each encoder/decoder block.
+   layers_per_block: int
+
+   # The normalization config.
+   normalization_config: layers_cfg.NormalizationConfig
+
+   # The configuration of the middle block, that is, after the last encoder block and before the first decoder block.
+   mid_block_config: MidBlock2DConfig
+
+
+ @dataclass
+ class DiffusionModelConfig:
+   """Configurations of the diffusion model."""
+
+   # Number of channels in the input tensor.
+   in_channels: int
+
+   # Number of channels in the output tensor.
+   out_channels: int
+
+   # The output channels of each block.
+   block_out_channels: List[int]
+
+   # The number of layers in each block.
+   layers_per_block: int
+
+   # The padding to use for the downsampling.
+   downsample_padding: int
+
+   # Normalization config used in residual blocks.
+   residual_norm_config: layers_cfg.NormalizationConfig
+
+   # Activation config used in residual blocks.
+   residual_activation_type: layers_cfg.ActivationType
+
+   # The batch size used in transformer blocks, for attention layers.
+   transformer_batch_size: int
+
+   # The number of attention heads used in transformer blocks.
+   transformer_num_attention_heads: int
+
+   # The dimension of cross attention used in transformer blocks.
+   transformer_cross_attention_dim: int
+
+   # Normalization config used in the pre-conv layer of transformer blocks.
+   transformer_pre_conv_norm_config: layers_cfg.NormalizationConfig
+
+   # Normalization config used in transformer blocks.
+   transformer_norm_config: layers_cfg.NormalizationConfig
+
+   # Activation type of the feed forward used in transformer blocks.
+   transformer_ff_activation_type: layers_cfg.ActivationType
+
+   # Number of layers in the mid block.
+   mid_block_layers: int
+
+   # Dimension of the time embedding.
+   time_embedding_dim: int
+
+   # Time embedding dimensions for blocks.
+   time_embedding_blocks_dim: int
+
+   # Normalization config used for the final layer.
+   final_norm_config: layers_cfg.NormalizationConfig
+
+   # Activation type used in the final layer.
+   final_activation_type: layers_cfg.ActivationType
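A note on how these configs meet the builders earlier in this diff: for the CONVOLUTION mode, build_downsampling reuses in_channels when out_channels is left as None, and remaps a padding of 0 to the asymmetric (0, 1, 0, 1) tuple. A minimal sketch, with illustrative values:

import ai_edge_torch.generative.layers.unet.model_config as unet_config

# Strided-convolution downsampling; out_channels may be omitted.
cfg = unet_config.DownSamplingConfig(
    mode=unet_config.SamplingType.CONVOLUTION,
    in_channels=64,
    kernel_size=3,
    stride=2,
    padding=1,
)
assert cfg.out_channels is None  # build_downsampling falls back to in_channels (64)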
@@ -0,0 +1,14 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/translate_recipe.py
@@ -0,0 +1,148 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ import json
+
+ from ai_edge_quantizer import quantizer
+
+ from ai_edge_torch.generative.quantize import quant_attrs
+ from ai_edge_torch.generative.quantize import quant_recipe
+
+ _OpExecutionMode = quantizer.qtyping.OpExecutionMode
+ _OpName = quantizer.qtyping.TFLOperationName
+ _TensorQuantConfig = quantizer.qtyping.TensorQuantizationConfig
+ _OpQuantConfig = quantizer.qtyping.OpQuantizationConfig
+
+ _DEFAULT_REGEX_STR = '.*'
+ _SINGULAR_TRANSFORMER_BLOCK_REGEX_STR = 'transformer_block'
+ _IDX_TRANSFORMER_BLOCKS_REGEX_STR = 'transformer_blocks\[{}\]'
+ _ATTENTION_REGEX_STR = 'ai_edge_torch.generative.layers.attention'
+ _FEEDFORWARD_REGEX_STR = 'ai_edge_torch.generative.layers.feed_forward'
+ _EMBEDDING_REGEX_STR = 'Embedding_tok_embedding'
+ _ANY_TWO_DIGITS_REGEX_STR = '\d{1,2}'
+
+
+ def _get_nbits_from_dtype(dtype: quant_attrs.Dtype) -> int:
+   if dtype == quant_attrs.Dtype.FP32:
+     return 32
+   elif dtype == quant_attrs.Dtype.FP16:
+     return 16
+   elif dtype == quant_attrs.Dtype.INT8:
+     return 8
+   raise ValueError('Unimplemented number of bits')
+
+
+ def _get_dtype_from_dtype(dtype: quant_attrs.Dtype) -> quantizer.qtyping.TensorDataType:
+   if dtype == quant_attrs.Dtype.FP32 or dtype == quant_attrs.Dtype.FP16:
+     return quantizer.qtyping.TensorDataType.FLOAT
+   else:
+     return quantizer.qtyping.TensorDataType.INT
+
+
+ def _get_execution_mode_from_mode(mode: quant_attrs.Mode) -> _OpExecutionMode:
+   if mode == quant_attrs.Mode.DYNAMIC_RANGE:
+     return _OpExecutionMode.DRQ
+   elif mode == quant_attrs.Mode.WEIGHT_ONLY:
+     return _OpExecutionMode.WEIGHT_ONLY
+   raise ValueError('Unimplemented execution mode')
+
+
+ def _get_channelwise_from_granularity(granularity: quant_attrs.Granularity) -> bool:
+   if granularity == quant_attrs.Granularity.CHANNELWISE:
+     return True
+   elif granularity == quant_attrs.Granularity.NONE:
+     return False
+   raise ValueError('Unimplemented granularity')
+
+
+ def _get_algorithm_key_from_algorithm(algo: quant_attrs.Algorithm) -> str:
+   if algo == quant_attrs.Algorithm.MIN_MAX:
+     return quantizer.algorithm_manager.AlgorithmName.MIN_MAX_UNIFORM_QUANT
+   elif algo == quant_attrs.Algorithm.FLOAT_CAST:
+     return quantizer.algorithm_manager.AlgorithmName.FLOAT_CASTING
+   raise ValueError('Unimplemented algorithm')
+
+
+ def _set_quant_config(
+     rm: quantizer.recipe_manager.RecipeManager,
+     layer_recipe: quant_recipe.LayerQuantRecipe,
+     regex: str,
+ ):
+   rm.add_quantization_config(
+       regex=regex,
+       operation_name=_OpName.ALL_SUPPORTED,
+       op_config=_OpQuantConfig(
+           weight_tensor_config=_TensorQuantConfig(
+               num_bits=_get_nbits_from_dtype(layer_recipe.weight_dtype),
+               symmetric=True,
+               channel_wise=_get_channelwise_from_granularity(layer_recipe.granularity),
+               dtype=_get_dtype_from_dtype(layer_recipe.weight_dtype),
+           ),
+           execution_mode=_get_execution_mode_from_mode(layer_recipe.mode),
+       ),
+       algorithm_key=_get_algorithm_key_from_algorithm(layer_recipe.algorithm),
+   )
+
+
+ def translate_to_ai_edge_recipe(
+     recipe: quant_recipe.GenerativeQuantRecipe,
+ ) -> quantizer.recipe_manager.ModelQuantizationRecipe:
+   rm = quantizer.recipe_manager.RecipeManager()
+
+   if recipe.default is not None:
+     _set_quant_config(rm, recipe.default, _DEFAULT_REGEX_STR)
+
+   if recipe.embedding is not None:
+     _set_quant_config(rm, recipe.embedding, _EMBEDDING_REGEX_STR)
+
+   if recipe.attention is not None:
+     if isinstance(recipe.attention, dict):
+       for idx, layer in recipe.attention.items():
+         _set_quant_config(
+             rm,
+             layer,
+             f'{_IDX_TRANSFORMER_BLOCKS_REGEX_STR.format(idx)}/{_ATTENTION_REGEX_STR}',
+         )
+     else:
+       _set_quant_config(
+           rm,
+           recipe.attention,
+           f'{_SINGULAR_TRANSFORMER_BLOCK_REGEX_STR}/{_ATTENTION_REGEX_STR}',
+       )
+
+   if recipe.feedforward is not None:
+     if isinstance(recipe.feedforward, dict):
+       for idx, layer in recipe.feedforward.items():
+         _set_quant_config(
+             rm,
+             layer,
+             f'{_IDX_TRANSFORMER_BLOCKS_REGEX_STR.format(idx)}/{_FEEDFORWARD_REGEX_STR}',
+         )
+     else:
+       _set_quant_config(
+           rm,
+           recipe.feedforward,
+           f'{_SINGULAR_TRANSFORMER_BLOCK_REGEX_STR}/{_FEEDFORWARD_REGEX_STR}',
+       )
+
+   return rm.get_quantization_recipe()
+
+
+ def quantize_model(
+     model: bytearray, recipe: quantizer.recipe_manager.ModelQuantizationRecipe
+ ) -> bytearray:
+   qt = quantizer.Quantizer(bytearray(model), recipe)
+   result = qt.quantize()
+   return result.quantized_model
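For orientation, a hedged end-to-end sketch of the glue flow these functions enable. The LayerQuantRecipe constructor arguments below are assumptions inferred from the attributes _set_quant_config reads (weight_dtype, mode, algorithm, granularity); see quant_recipe.py in this diff for the exact signature. The .tflite paths are placeholders.

from ai_edge_torch.generative.quantize import quant_attrs
from ai_edge_torch.generative.quantize import quant_recipe
from ai_edge_torch.generative.quantize.ai_edge_quantizer_glue import translate_recipe

# Assumed LayerQuantRecipe fields; only the four read above are relied on.
layer = quant_recipe.LayerQuantRecipe(
    weight_dtype=quant_attrs.Dtype.INT8,
    mode=quant_attrs.Mode.DYNAMIC_RANGE,
    algorithm=quant_attrs.Algorithm.MIN_MAX,
    granularity=quant_attrs.Granularity.CHANNELWISE,
)
recipe = quant_recipe.GenerativeQuantRecipe(default=layer)

# Translate to an ai_edge_quantizer recipe, then quantize a float TFLite flatbuffer.
ai_edge_recipe = translate_recipe.translate_to_ai_edge_recipe(recipe)
with open('/tmp/model_f32.tflite', 'rb') as f:
    float_model = bytearray(f.read())
quantized = translate_recipe.quantize_model(float_model, ai_edge_recipe)
with open('/tmp/model_int8.tflite', 'wb') as f:
    f.write(quantized)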
ai_edge_torch/generative/quantize/example.py
@@ -0,0 +1,45 @@
+ # Copyright 2024 The AI Edge Torch Authors. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ import numpy as np
+ import torch
+
+ import ai_edge_torch
+ from ai_edge_torch.generative.examples.gemma import gemma
+ from ai_edge_torch.generative.quantize import quant_recipes
+
+
+ def main():
+   # Build a PyTorch model as usual
+   config = gemma.get_fake_model_config_2b_for_test()
+   model = gemma.Gemma(config)
+   idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
+   tokens = torch.full((1, 10), 0, dtype=torch.long, device="cpu")
+   tokens[0, :4] = idx
+   input_pos = torch.arange(0, 10)
+
+   # Create a quantization recipe to be applied to the model
+   quant_config = quant_recipes.full_int8_dynamic_recipe()
+   print(quant_config)
+
+   # Convert with quantization
+   edge_model = ai_edge_torch.convert(
+       model, (tokens, input_pos), quant_config=quant_config
+   )
+   edge_model.export("/tmp/gemma_2b_quantized.tflite")
+
+
+ if __name__ == "__main__":
+   main()
ai_edge_torch/generative/quantize/quant_attrs.py
@@ -0,0 +1,68 @@
+ # Copyright 2024 The AI Edge Torch Authors. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ import enum
+
+
+ @enum.unique
+ class Dtype(enum.Enum):
+   """Data types and precision of tensors."""
+
+   FP32 = enum.auto()
+   FP16 = enum.auto()
+   INT8 = enum.auto()
+
+
+ @enum.unique
+ class Algorithm(enum.Enum):
+   """Algorithm used to calculate quantization parameters.
+
+   Attributes:
+     MIN_MAX: Maps the min/max of the floating point space to the min/max of
+       the quantized space and quantizes uniformly.
+     FLOAT_CAST: Casts a float to another float of a different type.
+   """
+
+   MIN_MAX = enum.auto()
+   FLOAT_CAST = enum.auto()
+
+
+ @enum.unique
+ class Mode(enum.Enum):
+   """Mode of quantization.
+
+   Attributes:
+     DYNAMIC_RANGE: Quantizes activations at runtime and weights statically to
+       perform computation in integers.
+     WEIGHT_ONLY: Quantizes weights statically and dequantizes at runtime to
+       perform computation in floating point.
+   """
+
+   DYNAMIC_RANGE = enum.auto()
+   WEIGHT_ONLY = enum.auto()
+
+
+ @enum.unique
+ class Granularity(enum.Enum):
+   """Granularity of quantization parameters.
+
+   Attributes:
+     NONE: Granularity is not applicable to this quantization scheme.
+     CHANNELWISE: Per-channel quantization; each channel of the relevant
+       tensors is quantized independently of the others.
+   """
+
+   NONE = enum.auto()
+   CHANNELWISE = enum.auto()
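Taken together, these enums are the vocabulary that translate_recipe.py (earlier in this diff) maps onto ai_edge_quantizer types. A small sketch of the correspondence, calling that module's helpers directly (they are private, so this is for illustration only):

from ai_edge_torch.generative.quantize import quant_attrs
from ai_edge_torch.generative.quantize.ai_edge_quantizer_glue import translate_recipe

# Dtype -> bit width: FP32 -> 32, FP16 -> 16, INT8 -> 8.
print(translate_recipe._get_nbits_from_dtype(quant_attrs.Dtype.INT8))  # 8

# Mode -> execution mode: DYNAMIC_RANGE -> DRQ, WEIGHT_ONLY -> WEIGHT_ONLY.
print(translate_recipe._get_execution_mode_from_mode(quant_attrs.Mode.DYNAMIC_RANGE))

# Granularity -> channel_wise flag: CHANNELWISE -> True, NONE -> False.
print(translate_recipe._get_channelwise_from_granularity(quant_attrs.Granularity.NONE))  # False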