ai-edge-torch-nightly 0.2.0.dev20240714__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ai-edge-torch-nightly might be problematic. Click here for more details.
- ai_edge_torch/__init__.py +31 -0
- ai_edge_torch/convert/__init__.py +14 -0
- ai_edge_torch/convert/conversion.py +117 -0
- ai_edge_torch/convert/conversion_utils.py +400 -0
- ai_edge_torch/convert/converter.py +202 -0
- ai_edge_torch/convert/fx_passes/__init__.py +59 -0
- ai_edge_torch/convert/fx_passes/_pass_base.py +49 -0
- ai_edge_torch/convert/fx_passes/build_aten_composite_pass.py +225 -0
- ai_edge_torch/convert/fx_passes/build_interpolate_composite_pass.py +123 -0
- ai_edge_torch/convert/fx_passes/canonicalize_pass.py +37 -0
- ai_edge_torch/convert/fx_passes/inject_mlir_debuginfo_pass.py +73 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/__init__.py +16 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_check.py +215 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_mark.py +48 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/__init__.py +17 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py +59 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py +215 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +400 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/op_func_registry.py +30 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/pass_body.py +293 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/utils.py +62 -0
- ai_edge_torch/convert/test/__init__.py +14 -0
- ai_edge_torch/convert/test/test_convert.py +311 -0
- ai_edge_torch/convert/test/test_convert_composites.py +192 -0
- ai_edge_torch/convert/test/test_convert_multisig.py +139 -0
- ai_edge_torch/convert/test/test_to_channel_last_io.py +96 -0
- ai_edge_torch/convert/to_channel_last_io.py +85 -0
- ai_edge_torch/debug/__init__.py +17 -0
- ai_edge_torch/debug/culprit.py +464 -0
- ai_edge_torch/debug/test/__init__.py +14 -0
- ai_edge_torch/debug/test/test_culprit.py +133 -0
- ai_edge_torch/debug/test/test_search_model.py +50 -0
- ai_edge_torch/debug/utils.py +48 -0
- ai_edge_torch/experimental/__init__.py +14 -0
- ai_edge_torch/generative/__init__.py +14 -0
- ai_edge_torch/generative/examples/__init__.py +14 -0
- ai_edge_torch/generative/examples/gemma/__init__.py +14 -0
- ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +66 -0
- ai_edge_torch/generative/examples/gemma/gemma.py +174 -0
- ai_edge_torch/generative/examples/phi2/__init__.py +14 -0
- ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +64 -0
- ai_edge_torch/generative/examples/phi2/phi2.py +164 -0
- ai_edge_torch/generative/examples/stable_diffusion/__init__.py +14 -0
- ai_edge_torch/generative/examples/stable_diffusion/attention.py +106 -0
- ai_edge_torch/generative/examples/stable_diffusion/clip.py +115 -0
- ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +142 -0
- ai_edge_torch/generative/examples/stable_diffusion/decoder.py +317 -0
- ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +573 -0
- ai_edge_torch/generative/examples/stable_diffusion/encoder.py +118 -0
- ai_edge_torch/generative/examples/stable_diffusion/pipeline.py +222 -0
- ai_edge_torch/generative/examples/stable_diffusion/samplers/__init__.py +19 -0
- ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler.py +61 -0
- ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler_ancestral.py +65 -0
- ai_edge_torch/generative/examples/stable_diffusion/samplers/k_lms.py +73 -0
- ai_edge_torch/generative/examples/stable_diffusion/samplers/sampler.py +38 -0
- ai_edge_torch/generative/examples/stable_diffusion/tokenizer.py +108 -0
- ai_edge_torch/generative/examples/stable_diffusion/util.py +71 -0
- ai_edge_torch/generative/examples/t5/__init__.py +14 -0
- ai_edge_torch/generative/examples/t5/convert_to_tflite.py +135 -0
- ai_edge_torch/generative/examples/t5/t5.py +608 -0
- ai_edge_torch/generative/examples/t5/t5_attention.py +231 -0
- ai_edge_torch/generative/examples/test_models/__init__.py +14 -0
- ai_edge_torch/generative/examples/test_models/toy_model.py +122 -0
- ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +161 -0
- ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +143 -0
- ai_edge_torch/generative/examples/tiny_llama/__init__.py +0 -0
- ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +66 -0
- ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +164 -0
- ai_edge_torch/generative/fx_passes/__init__.py +31 -0
- ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py +47 -0
- ai_edge_torch/generative/layers/__init__.py +14 -0
- ai_edge_torch/generative/layers/attention.py +354 -0
- ai_edge_torch/generative/layers/attention_utils.py +169 -0
- ai_edge_torch/generative/layers/builder.py +131 -0
- ai_edge_torch/generative/layers/feed_forward.py +95 -0
- ai_edge_torch/generative/layers/kv_cache.py +83 -0
- ai_edge_torch/generative/layers/model_config.py +158 -0
- ai_edge_torch/generative/layers/normalization.py +62 -0
- ai_edge_torch/generative/layers/rotary_position_embedding.py +36 -0
- ai_edge_torch/generative/layers/scaled_dot_product_attention.py +117 -0
- ai_edge_torch/generative/layers/unet/__init__.py +14 -0
- ai_edge_torch/generative/layers/unet/blocks_2d.py +711 -0
- ai_edge_torch/generative/layers/unet/builder.py +47 -0
- ai_edge_torch/generative/layers/unet/model_config.py +269 -0
- ai_edge_torch/generative/quantize/__init__.py +14 -0
- ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/__init__.py +0 -0
- ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/translate_recipe.py +148 -0
- ai_edge_torch/generative/quantize/example.py +45 -0
- ai_edge_torch/generative/quantize/quant_attrs.py +68 -0
- ai_edge_torch/generative/quantize/quant_recipe.py +151 -0
- ai_edge_torch/generative/quantize/quant_recipe_utils.py +51 -0
- ai_edge_torch/generative/quantize/quant_recipes.py +48 -0
- ai_edge_torch/generative/quantize/supported_schemes.py +32 -0
- ai_edge_torch/generative/test/__init__.py +14 -0
- ai_edge_torch/generative/test/loader_test.py +80 -0
- ai_edge_torch/generative/test/test_model_conversion.py +235 -0
- ai_edge_torch/generative/test/test_quantize.py +162 -0
- ai_edge_torch/generative/utilities/__init__.py +15 -0
- ai_edge_torch/generative/utilities/loader.py +328 -0
- ai_edge_torch/generative/utilities/stable_diffusion_loader.py +924 -0
- ai_edge_torch/generative/utilities/t5_loader.py +483 -0
- ai_edge_torch/hlfb/__init__.py +16 -0
- ai_edge_torch/hlfb/mark_pattern/__init__.py +139 -0
- ai_edge_torch/hlfb/mark_pattern/passes.py +42 -0
- ai_edge_torch/hlfb/mark_pattern/pattern.py +273 -0
- ai_edge_torch/hlfb/test/__init__.py +14 -0
- ai_edge_torch/hlfb/test/test_mark_pattern.py +133 -0
- ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py +270 -0
- ai_edge_torch/model.py +142 -0
- ai_edge_torch/quantize/__init__.py +16 -0
- ai_edge_torch/quantize/pt2e_quantizer.py +438 -0
- ai_edge_torch/quantize/pt2e_quantizer_utils.py +1041 -0
- ai_edge_torch/quantize/quant_config.py +81 -0
- ai_edge_torch/testing/__init__.py +14 -0
- ai_edge_torch/testing/model_coverage/__init__.py +16 -0
- ai_edge_torch/testing/model_coverage/model_coverage.py +132 -0
- ai_edge_torch_nightly-0.2.0.dev20240714.dist-info/LICENSE +202 -0
- ai_edge_torch_nightly-0.2.0.dev20240714.dist-info/METADATA +38 -0
- ai_edge_torch_nightly-0.2.0.dev20240714.dist-info/RECORD +121 -0
- ai_edge_torch_nightly-0.2.0.dev20240714.dist-info/WHEEL +5 -0
- ai_edge_torch_nightly-0.2.0.dev20240714.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,47 @@
|
|
|
1
|
+
# Copyright 2024 The AI Edge Torch Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
# Builder utils for individual components.
|
|
16
|
+
|
|
17
|
+
from torch import nn
|
|
18
|
+
|
|
19
|
+
import ai_edge_torch.generative.layers.unet.model_config as unet_config
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def build_upsampling(config: unet_config.UpSamplingConfig):
  """Returns the upsampling layer selected by `config.mode`.

  Args:
    config: Upsampling configuration; its `scale_factor` is forwarded to the
      constructed layer.

  Returns:
    An `nn.UpsamplingNearest2d` for NEAREST mode or an
    `nn.UpsamplingBilinear2d` for BILINEAR mode.

  Raises:
    ValueError: If `config.mode` is any other sampling type.
  """
  sampling = unet_config.SamplingType
  if config.mode == sampling.NEAREST:
    upsampler_cls = nn.UpsamplingNearest2d
  elif config.mode == sampling.BILINEAR:
    upsampler_cls = nn.UpsamplingBilinear2d
  else:
    raise ValueError("Unsupported upsampling type.")
  return upsampler_cls(scale_factor=config.scale_factor)
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class _BottomRightPaddedConv2d(nn.Conv2d):
  """`nn.Conv2d` that zero-pads one extra bottom row and right column first.

  `nn.Conv2d` only supports padding given as an int or one value per spatial
  dim. A 4-tuple such as (0, 1, 0, 1) is accepted by the constructor, but
  `forward` then raises. Padding explicitly with `nn.functional.pad` gives the
  intended asymmetric padding. Because this is still an `nn.Conv2d`, the
  parameters stay named `weight` / `bias`, so weight loading that targets a
  plain conv keeps working.
  """

  # `nn.functional.pad` order for the last two dims:
  # (left, right, top, bottom).
  _PAD = (0, 1, 0, 1)

  def forward(self, x):
    return super().forward(nn.functional.pad(x, self._PAD))


def build_downsampling(config: unet_config.DownSamplingConfig):
  """Returns the downsampling layer selected by `config.mode`.

  Args:
    config: Downsampling configuration.

  Returns:
    An `nn.AvgPool2d` for AVERAGE mode. For CONVOLUTION mode, an `nn.Conv2d`
    (or a subclass of it when `config.padding == 0`). Its output channels
    default to `config.in_channels` when `config.out_channels` is None.

  Raises:
    ValueError: If `config.mode` is any other sampling type.
  """
  if config.mode == unet_config.SamplingType.AVERAGE:
    return nn.AvgPool2d(config.kernel_size, config.stride, padding=config.padding)
  elif config.mode == unet_config.SamplingType.CONVOLUTION:
    out_channels = (
        config.in_channels if config.out_channels is None else config.out_channels
    )
    # padding == 0 requests asymmetric (0, 1, 0, 1) padding. nn.Conv2d cannot
    # express that itself, so the input is padded explicitly before a conv
    # with padding=0.
    conv_cls = _BottomRightPaddedConv2d if config.padding == 0 else nn.Conv2d
    return conv_cls(
        config.in_channels,
        out_channels=out_channels,
        kernel_size=config.kernel_size,
        stride=config.stride,
        padding=config.padding,
    )
  else:
    raise ValueError("Unsupported downsampling type.")
|
|
@@ -0,0 +1,269 @@
|
|
|
1
|
+
# Copyright 2024 The AI Edge Torch Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
|
|
16
|
+
# UNet configuration class.
|
|
17
|
+
from dataclasses import dataclass
|
|
18
|
+
from dataclasses import field
|
|
19
|
+
import enum
|
|
20
|
+
from typing import List, Optional
|
|
21
|
+
|
|
22
|
+
import ai_edge_torch.generative.layers.model_config as layers_cfg
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
@enum.unique
class SamplingType(enum.Enum):
  """Kind of spatial resampling performed by an up/down-sampling layer.

  Attributes:
    NEAREST: Nearest-neighbour upsampling.
    BILINEAR: Bilinear upsampling.
    AVERAGE: Average-pooling downsampling.
    CONVOLUTION: Strided-convolution downsampling.
  """

  NEAREST = enum.auto()
  BILINEAR = enum.auto()
  AVERAGE = enum.auto()
  CONVOLUTION = enum.auto()
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
@dataclass
class UpSamplingConfig:
  """Parameters of an upsampling layer.

  Attributes:
    mode: Interpolation kind (the upsampling builder handles NEAREST and
      BILINEAR).
    scale_factor: Spatial scale factor handed to the upsampling layer.
  """

  mode: SamplingType
  scale_factor: float
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
@dataclass
class DownSamplingConfig:
  """Parameters of a downsampling layer.

  Attributes:
    mode: Downsampling kind (the downsampling builder handles AVERAGE and
      CONVOLUTION).
    in_channels: Number of input channels.
    kernel_size: Kernel size of the pooling window or convolution.
    stride: Stride of the pooling window or convolution.
    padding: Padding of the pooling window or convolution. In CONVOLUTION
      mode, the builder special-cases 0 as asymmetric (0, 1, 0, 1) padding.
    out_channels: Output channels for CONVOLUTION mode; None means "same as
      `in_channels`".
  """

  mode: SamplingType
  in_channels: int
  kernel_size: int
  stride: int
  padding: int
  out_channels: Optional[int] = None
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
@dataclass
class ResidualBlock2DConfig:
  """Parameters of a 2D residual block.

  Attributes:
    in_channels: Number of input channels.
    out_channels: Number of output channels.
    normalization_config: Normalization configuration for the block.
    activation_config: Activation configuration for the block.
    time_embedding_channels: Channel count of the time-embedding context, or
      None when the block does not take a time embedding as input.
  """

  in_channels: int
  out_channels: int
  normalization_config: layers_cfg.NormalizationConfig
  activation_config: layers_cfg.ActivationConfig
  time_embedding_channels: Optional[int] = None
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
@dataclass
class AttentionBlock2DConfig:
  """Parameters of a 2D self-attention block.

  Attributes:
    dim: Feature dimension of the block.
    normalization_config: Normalization configuration for the block.
    attention_config: Attention layer configuration.
    enable_hlfb: Whether HLFB is enabled (presumably the high-level function
      boundary marking from `ai_edge_torch.hlfb` — confirm in blocks_2d).
    attention_batch_size: Batch size used for the attention computation.
  """

  dim: int
  normalization_config: layers_cfg.NormalizationConfig
  attention_config: layers_cfg.AttentionConfig
  enable_hlfb: bool = True
  attention_batch_size: int = 1
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
@dataclass
class CrossAttentionBlock2DConfig:
  """Parameters of a 2D cross-attention block.

  Attributes:
    query_dim: Feature dimension of the query input.
    cross_dim: Feature dimension of the cross (context) input.
    normalization_config: Normalization configuration for the block.
    attention_config: Attention layer configuration.
    enable_hlfb: Whether HLFB is enabled (presumably the high-level function
      boundary marking from `ai_edge_torch.hlfb` — confirm in blocks_2d).
    attention_batch_size: Batch size used for the attention computation.
  """

  query_dim: int
  cross_dim: int
  normalization_config: layers_cfg.NormalizationConfig
  attention_config: layers_cfg.AttentionConfig
  enable_hlfb: bool = True
  attention_batch_size: int = 1
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
@dataclass
class FeedForwardBlock2DConfig:
  """Parameters of a 2D feed-forward block.

  Attributes:
    dim: Input/output feature dimension of the block.
    hidden_dim: Hidden feature dimension of the block.
    normalization_config: Normalization configuration for the block.
    activation_config: Activation configuration for the block.
    use_bias: Whether the block's layers use a bias term.
  """

  dim: int
  hidden_dim: int
  normalization_config: layers_cfg.NormalizationConfig
  activation_config: layers_cfg.ActivationConfig
  use_bias: bool
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
@dataclass
class TransformerBlock2DConfig:
  """Parameters of a 2D transformer block.

  The block combines a self-attention block, a cross-attention block and a
  feed-forward block, each configured by its own sub-config.

  Attributes:
    pre_conv_normalization_config: Normalization used by the block's pre-conv
      layer.
    attention_block_config: Self-attention sub-block configuration.
    cross_attention_block_config: Cross-attention sub-block configuration.
    feed_forward_block_config: Feed-forward sub-block configuration.
  """

  pre_conv_normalization_config: layers_cfg.NormalizationConfig
  attention_block_config: AttentionBlock2DConfig
  cross_attention_block_config: CrossAttentionBlock2DConfig
  feed_forward_block_config: FeedForwardBlock2DConfig
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
@dataclass
class UpDecoderBlock2DConfig:
  """Parameters of an up-decoder block built around a stack of residual blocks.

  Attributes:
    in_channels: Number of input channels.
    out_channels: Number of output channels.
    normalization_config: Normalization configuration for the block.
    activation_config: Activation configuration for the block.
    num_layers: Number of residual blocks in the stack.
    time_embedding_channels: Time-embedding channel count fed to the residual
      blocks, or None when they take no time embedding.
    add_upsample: Whether an upsample operation follows the residual blocks.
    upsample_conv: Whether a conv2d layer follows the upsample operation.
    sampling_config: Upsampling configuration, expected when `add_upsample`
      is True.
    transformer_block_config: Optional transformer blocks interleaved with the
      residual blocks.
    context_dim: Dimension of the context tensor when one is given as input.
  """

  in_channels: int
  out_channels: int
  normalization_config: layers_cfg.NormalizationConfig
  activation_config: layers_cfg.ActivationConfig
  num_layers: int
  time_embedding_channels: Optional[int] = None
  add_upsample: bool = True
  upsample_conv: bool = True
  sampling_config: Optional[UpSamplingConfig] = None
  transformer_block_config: Optional[TransformerBlock2DConfig] = None
  context_dim: Optional[int] = None
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
@dataclass
class SkipUpDecoderBlock2DConfig:
  """Parameters of an up-decoder block that also takes skip connections.

  Attributes:
    in_channels: Number of input channels.
    out_channels: Number of output channels.
    prev_out_channels: Output channel count of the previously connected block.
    normalization_config: Normalization configuration for the block.
    activation_config: Activation configuration for the block.
    num_layers: Number of residual blocks in the stack.
    time_embedding_channels: Time-embedding channel count fed to the residual
      blocks, or None when they take no time embedding.
    add_upsample: Whether an upsample operation follows the residual blocks.
    upsample_conv: Whether a conv2d layer follows the upsample operation.
    sampling_config: Upsampling configuration, expected when `add_upsample`
      is True.
    transformer_block_config: Optional transformer blocks interleaved with the
      residual blocks.
    context_dim: Dimension of the context tensor when one is given as input.
  """

  in_channels: int
  out_channels: int
  prev_out_channels: int
  normalization_config: layers_cfg.NormalizationConfig
  activation_config: layers_cfg.ActivationConfig
  num_layers: int
  time_embedding_channels: Optional[int] = None
  add_upsample: bool = True
  upsample_conv: bool = True
  sampling_config: Optional[UpSamplingConfig] = None
  transformer_block_config: Optional[TransformerBlock2DConfig] = None
  context_dim: Optional[int] = None
|
|
137
|
+
|
|
138
|
+
|
|
139
|
+
@dataclass
class DownEncoderBlock2DConfig:
  """Parameters of a down-encoder block built around a stack of residual blocks."""

  in_channels: int
  out_channels: int
  normalization_config: layers_cfg.NormalizationConfig
  activation_config: layers_cfg.ActivationConfig
  num_layers: int
  # Padding for the downsampling convolution.
  padding: int = 1
  # Optional time embedding channels if the residual blocks take a time
  # embedding as input.
  time_embedding_channels: Optional[int] = None
  # Whether to add a downsample operation after the residual blocks.
  add_downsample: bool = True
  # Optional sampling config, used when add_downsample is True.
  sampling_config: Optional[DownSamplingConfig] = None
  # Optional config of transformer blocks interleaved with residual blocks.
  transformer_block_config: Optional[TransformerBlock2DConfig] = None
  # Optional dimension of the context tensor if one is given as input.
  context_dim: Optional[int] = None
|
|
158
|
+
|
|
159
|
+
|
|
160
|
+
@dataclass
class MidBlock2DConfig:
  """Parameters of a middle block built around a stack of residual blocks.

  Attributes:
    in_channels: Number of input channels.
    normalization_config: Normalization configuration for the block.
    activation_config: Activation configuration for the block.
    num_layers: Number of residual blocks in the stack.
    time_embedding_channels: Time-embedding channel count fed to the residual
      blocks, or None when they take no time-embedding context.
    attention_block_config: Optional attention blocks interleaved with the
      residual blocks.
    transformer_block_config: Optional transformer blocks interleaved with the
      residual blocks.
    context_dim: Dimension of the context tensor when one is given as input.
  """

  in_channels: int
  normalization_config: layers_cfg.NormalizationConfig
  activation_config: layers_cfg.ActivationConfig
  num_layers: int
  time_embedding_channels: Optional[int] = None
  attention_block_config: Optional[AttentionBlock2DConfig] = None
  transformer_block_config: Optional[TransformerBlock2DConfig] = None
  context_dim: Optional[int] = None
|
|
174
|
+
|
|
175
|
+
|
|
176
|
+
@dataclass
class AutoEncoderConfig:
  """Configurations of encoder/decoder in the autoencoder model."""

  # The activation type of encoder/decoder blocks.
  activation_config: layers_cfg.ActivationConfig

  # The output channels of each block.
  block_out_channels: List[int]

  # Number of channels in the input image.
  in_channels: int

  # Number of channels in the output.
  out_channels: int

  # Number of channels in the latent space.
  latent_channels: int

  # The component-wise standard deviation of the trained latent space, computed
  # on the first batch of the training set. It is used to scale the latent
  # space to unit variance when training the diffusion model: latents are
  # scaled with `z = z * scaling_factor` before being passed to the diffusion
  # model, and scaled back when decoding with `z = 1 / scaling_factor * z`.
  # For details, see sections 4.3.2 and D.1 of "High-Resolution Image
  # Synthesis with Latent Diffusion Models" (https://arxiv.org/abs/2112.10752).
  scaling_factor: float

  # The number of layers in each encoder/decoder block.
  layers_per_block: int

  # The normalization config.
  normalization_config: layers_cfg.NormalizationConfig

  # The configuration of the middle blocks, i.e. after the last block of the
  # encoder and before the first block of the decoder.
  mid_block_config: MidBlock2DConfig
|
|
211
|
+
|
|
212
|
+
|
|
213
|
+
@dataclass
class DiffusionModelConfig:
  """Configurations of Diffusion model."""

  # Number of channels in the input tensor.
  in_channels: int

  # Number of channels in the output tensor.
  out_channels: int

  # The output channels of each block.
  block_out_channels: List[int]

  # The number of layers in each block.
  layers_per_block: int

  # The padding to use for the downsampling.
  downsample_padding: int

  # Normalization config used in residual blocks.
  residual_norm_config: layers_cfg.NormalizationConfig

  # Activation type used in residual blocks.
  residual_activation_type: layers_cfg.ActivationType

  # The batch size used in transformer blocks, for attention layers.
  transformer_batch_size: int

  # The number of attention heads used in transformer blocks.
  transformer_num_attention_heads: int

  # The dimension of cross attention used in transformer blocks.
  transformer_cross_attention_dim: int

  # Normalization config used in the pre-conv layer of transformer blocks.
  transformer_pre_conv_norm_config: layers_cfg.NormalizationConfig

  # Normalization config used in transformer blocks.
  transformer_norm_config: layers_cfg.NormalizationConfig

  # Activation type of the feed forward layer used in transformer blocks.
  transformer_ff_activation_type: layers_cfg.ActivationType

  # Number of layers in the mid block.
  mid_block_layers: int

  # Dimension of the time embedding.
  time_embedding_dim: int

  # Time embedding dimensions for blocks.
  time_embedding_blocks_dim: int

  # Normalization config used for the final layer.
  final_norm_config: layers_cfg.NormalizationConfig

  # Activation type used in the final layer.
  final_activation_type: layers_cfg.ActivationType
|
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
# Copyright 2024 The AI Edge Torch Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
File without changes
|
|
@@ -0,0 +1,148 @@
|
|
|
1
|
+
# Copyright 2024 The AI Edge Torch Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
|
|
16
|
+
import json
|
|
17
|
+
|
|
18
|
+
from ai_edge_quantizer import quantizer
|
|
19
|
+
|
|
20
|
+
from ai_edge_torch.generative.quantize import quant_attrs
|
|
21
|
+
from ai_edge_torch.generative.quantize import quant_recipe
|
|
22
|
+
|
|
23
|
+
# Short aliases for AI Edge Quantizer types used throughout this module.
_OpExecutionMode = quantizer.qtyping.OpExecutionMode
_OpName = quantizer.qtyping.TFLOperationName
_TensorQuantConfig = quantizer.qtyping.TensorQuantizationConfig
_OpQuantConfig = quantizer.qtyping.OpQuantizationConfig

# Regex fragments passed to `RecipeManager.add_quantization_config(regex=...)`.
# Patterns containing backslashes are raw strings so that `\[`, `\]` and `\d`
# reach the regex engine verbatim instead of being (invalid) Python string
# escapes; the resulting values are unchanged.
_DEFAULT_REGEX_STR = '.*'
_SINGULAR_TRANSFORMER_BLOCK_REGEX_STR = 'transformer_block'
_IDX_TRANSFORMER_BLOCKS_REGEX_STR = r'transformer_blocks\[{}\]'
_ATTENTION_REGEX_STR = 'ai_edge_torch.generative.layers.attention'
_FEEDFORWARD_REGEX_STR = 'ai_edge_torch.generative.layers.feed_forward'
_EMBEDDING_REGEX_STR = 'Embedding_tok_embedding'
_ANY_TWO_DIGITS_REGEX_STR = r'\d{1,2}'
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def _get_nbits_from_dtype(dtype: quant_attrs.Dtype) -> int:
  """Returns the bit width of `dtype`.

  Raises:
    ValueError: If `dtype` is not FP32, FP16 or INT8.
  """
  for candidate, bits in (
      (quant_attrs.Dtype.FP32, 32),
      (quant_attrs.Dtype.FP16, 16),
      (quant_attrs.Dtype.INT8, 8),
  ):
    if dtype == candidate:
      return bits
  raise ValueError('Unimplemented number of bits')
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def _get_dtype_from_dtype(dtype: quant_attrs.Dtype) -> quantizer.qtyping.TensorDataType:
  """Maps a generative-API dtype to the quantizer's tensor data type.

  FP32 and FP16 map to FLOAT; INT8 maps to INT.

  Raises:
    ValueError: If `dtype` has no known mapping. This matches the other
      `_get_*` helpers in this module instead of silently treating unknown
      dtypes as INT.
  """
  if dtype in (quant_attrs.Dtype.FP32, quant_attrs.Dtype.FP16):
    return quantizer.qtyping.TensorDataType.FLOAT
  if dtype == quant_attrs.Dtype.INT8:
    return quantizer.qtyping.TensorDataType.INT
  raise ValueError('Unimplemented dtype')
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def _get_execution_mode_from_mode(mode: quant_attrs.Mode) -> _OpExecutionMode:
  """Translates a quantization `Mode` into the quantizer's op execution mode.

  DYNAMIC_RANGE maps to DRQ; WEIGHT_ONLY maps to WEIGHT_ONLY.

  Raises:
    ValueError: For any other mode.
  """
  if mode == quant_attrs.Mode.DYNAMIC_RANGE:
    return _OpExecutionMode.DRQ
  if mode != quant_attrs.Mode.WEIGHT_ONLY:
    raise ValueError('Unimplemented execution mode')
  return _OpExecutionMode.WEIGHT_ONLY
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
def _get_channelwise_from_granularity(granularity: quant_attrs.Granularity) -> bool:
  """Returns whether `granularity` requests per-channel quantization.

  Raises:
    ValueError: If `granularity` is neither CHANNELWISE nor NONE.
  """
  if granularity == quant_attrs.Granularity.CHANNELWISE:
    return True
  if granularity != quant_attrs.Granularity.NONE:
    raise ValueError('Unimplemented granularity')
  return False
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
def _get_algorithm_key_from_algorithm(algo: quant_attrs.Algorithm) -> str:
  """Returns the quantizer `AlgorithmName` entry corresponding to `algo`.

  MIN_MAX maps to MIN_MAX_UNIFORM_QUANT; FLOAT_CAST maps to FLOAT_CASTING.

  Raises:
    ValueError: For any other algorithm.
  """
  names = quantizer.algorithm_manager.AlgorithmName
  if algo == quant_attrs.Algorithm.MIN_MAX:
    return names.MIN_MAX_UNIFORM_QUANT
  if algo == quant_attrs.Algorithm.FLOAT_CAST:
    return names.FLOAT_CASTING
  raise ValueError('Unimplemented algorithm')
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
def _set_quant_config(
    rm: quantizer.recipe_manager.RecipeManager,
    layer_recipe: quant_recipe.LayerQuantRecipe,
    regex: str,
):
  """Adds `layer_recipe` to `rm` for all supported ops matched by `regex`.

  The weight tensor config is always symmetric. Its bit width, data type and
  channel-wise flag come from the recipe's `weight_dtype` and `granularity`.
  The execution mode comes from `mode`, and the algorithm key from
  `algorithm`.

  Args:
    rm: Recipe manager that receives the new quantization config.
    layer_recipe: Per-layer recipe to translate.
    regex: Name pattern selecting the ops the config applies to.
  """
  weight_config = _TensorQuantConfig(
      num_bits=_get_nbits_from_dtype(layer_recipe.weight_dtype),
      symmetric=True,
      channel_wise=_get_channelwise_from_granularity(layer_recipe.granularity),
      dtype=_get_dtype_from_dtype(layer_recipe.weight_dtype),
  )
  op_config = _OpQuantConfig(
      weight_tensor_config=weight_config,
      execution_mode=_get_execution_mode_from_mode(layer_recipe.mode),
  )
  algorithm_key = _get_algorithm_key_from_algorithm(layer_recipe.algorithm)
  rm.add_quantization_config(
      regex=regex,
      operation_name=_OpName.ALL_SUPPORTED,
      op_config=op_config,
      algorithm_key=algorithm_key,
  )
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
def _set_transformer_quant_config(
    rm: quantizer.recipe_manager.RecipeManager,
    layer_recipes,
    module_regex: str,
):
  """Registers configs for one kind of sub-module inside transformer blocks.

  Args:
    rm: Recipe manager that receives the configs.
    layer_recipes: Either a single `quant_recipe.LayerQuantRecipe`, applied
      under `_SINGULAR_TRANSFORMER_BLOCK_REGEX_STR`, or a dict mapping a
      transformer-block index (formatted into
      `_IDX_TRANSFORMER_BLOCKS_REGEX_STR`) to the recipe for that block.
    module_regex: Regex fragment naming the sub-module (e.g. attention or
      feed-forward) within a transformer block.
  """
  if isinstance(layer_recipes, dict):
    for idx, layer in layer_recipes.items():
      block_regex = _IDX_TRANSFORMER_BLOCKS_REGEX_STR.format(idx)
      _set_quant_config(rm, layer, f'{block_regex}/{module_regex}')
  else:
    _set_quant_config(
        rm,
        layer_recipes,
        f'{_SINGULAR_TRANSFORMER_BLOCK_REGEX_STR}/{module_regex}',
    )


def translate_to_ai_edge_recipe(
    recipe: quant_recipe.GenerativeQuantRecipe,
) -> quantizer.recipe_manager.ModelQuantizationRecipe:
  """Translates a generative quantization recipe into an AI Edge Quantizer recipe.

  Sections of `recipe` that are None are skipped. The others are registered in
  this order: default, embedding, attention, feed-forward.

  Args:
    recipe: Generative-API recipe to translate.

  Returns:
    The result of `RecipeManager.get_quantization_recipe()` after all configs
    have been added.
  """
  rm = quantizer.recipe_manager.RecipeManager()

  # Registration order is kept as-is. NOTE(review): presumably later, more
  # specific rules refine the default; confirm with RecipeManager semantics.
  if recipe.default is not None:
    _set_quant_config(rm, recipe.default, _DEFAULT_REGEX_STR)

  if recipe.embedding is not None:
    _set_quant_config(rm, recipe.embedding, _EMBEDDING_REGEX_STR)

  if recipe.attention is not None:
    _set_transformer_quant_config(rm, recipe.attention, _ATTENTION_REGEX_STR)

  if recipe.feedforward is not None:
    _set_transformer_quant_config(
        rm, recipe.feedforward, _FEEDFORWARD_REGEX_STR
    )

  return rm.get_quantization_recipe()
|
|
141
|
+
|
|
142
|
+
|
|
143
|
+
def quantize_model(
    model: bytearray, recipe: quantizer.recipe_manager.ModelQuantizationRecipe
) -> bytearray:
  """Quantizes a serialized model with AI Edge Quantizer.

  Args:
    model: Serialized model bytes; a `bytearray` copy is handed to the
      quantizer.
    recipe: Quantization recipe, e.g. from `translate_to_ai_edge_recipe`.

  Returns:
    The `quantized_model` field of the quantization result.
  """
  model_copy = bytearray(model)
  quantization_result = quantizer.Quantizer(model_copy, recipe).quantize()
  return quantization_result.quantized_model
|
|
@@ -0,0 +1,45 @@
|
|
|
1
|
+
# Copyright 2024 The AI Edge Torch Authors. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
|
|
16
|
+
import numpy as np
|
|
17
|
+
import torch
|
|
18
|
+
|
|
19
|
+
import ai_edge_torch
|
|
20
|
+
from ai_edge_torch.generative.examples.gemma import gemma
|
|
21
|
+
from ai_edge_torch.generative.quantize import quant_recipes
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def main():
  """Converts a test-sized Gemma model with a quantization recipe.

  Builds `gemma.Gemma` from `get_fake_model_config_2b_for_test()`, converts it
  with the recipe returned by `quant_recipes.full_int8_dynamic_recipe()`, and
  writes the result to /tmp/gemma_2b_quantized.tflite.
  """
  # Build a PyTorch model as usual.
  model = gemma.Gemma(gemma.get_fake_model_config_2b_for_test())

  # Sample inputs: 10 token slots whose first four hold ids 1-4 (the rest stay
  # 0), plus positions 0..9.
  prompt = torch.from_numpy(np.array([[1, 2, 3, 4]]))
  tokens = torch.zeros((1, 10), dtype=torch.long, device="cpu")
  tokens[0, :4] = prompt
  input_pos = torch.arange(10)

  # Quantization recipe applied during conversion.
  quant_config = quant_recipes.full_int8_dynamic_recipe()
  print(quant_config)

  # Convert with quantization and serialize the result.
  edge_model = ai_edge_torch.convert(
      model, (tokens, input_pos), quant_config=quant_config
  )
  edge_model.export("/tmp/gemma_2b_quantized.tflite")


if __name__ == "__main__":
  main()
|
|
@@ -0,0 +1,68 @@
|
|
|
1
|
+
# Copyright 2024 The AI Edge Torch Authors. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
|
|
16
|
+
import enum
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
@enum.unique
class Dtype(enum.Enum):
  """Tensor data types (and hence precision) usable in quantization recipes.

  Attributes:
    FP32: 32-bit floating point.
    FP16: 16-bit floating point.
    INT8: 8-bit integer.
  """

  FP32 = enum.auto()
  FP16 = enum.auto()
  INT8 = enum.auto()
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
@enum.unique
class Algorithm(enum.Enum):
  """How quantization parameters are computed.

  Attributes:
    MIN_MAX: Uniform quantization that maps the floating-point min/max onto
      the min/max of the quantized range.
    FLOAT_CAST: Plain cast from one floating-point type to another.
  """

  MIN_MAX = enum.auto()
  FLOAT_CAST = enum.auto()
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
@enum.unique
class Mode(enum.Enum):
  """When and how tensors are quantized.

  Attributes:
    DYNAMIC_RANGE: Weights are quantized ahead of time and activations at
      runtime, so computation happens in integers.
    WEIGHT_ONLY: Weights are quantized ahead of time and dequantized at
      runtime, so computation happens in floating point.
  """

  DYNAMIC_RANGE = enum.auto()
  WEIGHT_ONLY = enum.auto()
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
@enum.unique
class Granularity(enum.Enum):
  """Scope over which a single set of quantization parameters applies.

  Attributes:
    NONE: Granularity does not apply to the chosen quantization scheme.
    CHANNELWISE: Per-channel quantization; every channel of the relevant
      tensors gets its own parameters.
  """

  NONE = enum.auto()
  CHANNELWISE = enum.auto()
|