ai-edge-torch-nightly 0.2.0.dev20240609__py3-none-any.whl → 0.2.0.dev20240611__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ai-edge-torch-nightly might be problematic. Click here for more details.
- ai_edge_torch/convert/conversion_utils.py +17 -5
- ai_edge_torch/generative/examples/gemma/gemma.py +1 -1
- ai_edge_torch/generative/examples/phi2/phi2.py +1 -1
- ai_edge_torch/generative/examples/stable_diffusion/clip.py +1 -1
- ai_edge_torch/generative/examples/stable_diffusion/decoder.py +4 -4
- ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +6 -1
- ai_edge_torch/generative/examples/t5/t5.py +1 -1
- ai_edge_torch/generative/examples/test_models/toy_model.py +1 -1
- ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +1 -1
- ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +1 -1
- ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +1 -1
- ai_edge_torch/generative/layers/builder.py +31 -9
- ai_edge_torch/generative/layers/model_config.py +10 -1
- ai_edge_torch/generative/layers/unet/blocks_2d.py +4 -4
- ai_edge_torch/generative/layers/unet/model_config.py +4 -4
- ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/__init__.py +0 -0
- ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/translate_recipe.py +164 -0
- ai_edge_torch/generative/quantize/quant_attrs.py +2 -0
- ai_edge_torch/generative/quantize/quant_recipe.py +49 -4
- ai_edge_torch/generative/quantize/quant_recipe_utils.py +2 -2
- ai_edge_torch/generative/quantize/quant_recipes.py +3 -3
- ai_edge_torch/generative/quantize/supported_schemes.py +2 -1
- ai_edge_torch/generative/test/test_quantize.py +74 -20
- ai_edge_torch/quantize/quant_config.py +11 -15
- {ai_edge_torch_nightly-0.2.0.dev20240609.dist-info → ai_edge_torch_nightly-0.2.0.dev20240611.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.2.0.dev20240609.dist-info → ai_edge_torch_nightly-0.2.0.dev20240611.dist-info}/RECORD +29 -27
- {ai_edge_torch_nightly-0.2.0.dev20240609.dist-info → ai_edge_torch_nightly-0.2.0.dev20240611.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.2.0.dev20240609.dist-info → ai_edge_torch_nightly-0.2.0.dev20240611.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.2.0.dev20240609.dist-info → ai_edge_torch_nightly-0.2.0.dev20240611.dist-info}/top_level.txt +0 -0
|
@@ -24,6 +24,7 @@ from typing import Any, Dict, Optional, Tuple, Union
|
|
|
24
24
|
import torch
|
|
25
25
|
from torch_xla import stablehlo
|
|
26
26
|
|
|
27
|
+
from ai_edge_torch.generative.quantize.ai_edge_quantizer_glue import translate_recipe # NOQA
|
|
27
28
|
from ai_edge_torch.quantize import quant_config as qcfg
|
|
28
29
|
|
|
29
30
|
try:
|
|
@@ -249,11 +250,6 @@ def _set_tfl_converter_quant_flags(
|
|
|
249
250
|
converter._experimental_qdq_conversion_mode = "DYNAMIC"
|
|
250
251
|
elif quantizer_mode == qcfg.QuantConfig._QuantizerMode.PT2E_STATIC:
|
|
251
252
|
converter._experimental_qdq_conversion_mode = "STATIC"
|
|
252
|
-
elif quantizer_mode == qcfg.QuantConfig._QuantizerMode.TFLITE_DYNAMIC:
|
|
253
|
-
converter.optimizations = [tf.lite.Optimize.DEFAULT]
|
|
254
|
-
elif quantizer_mode == qcfg.QuantConfig._QuantizerMode.TFLITE_FP16:
|
|
255
|
-
converter.optimizations = [tf.lite.Optimize.DEFAULT]
|
|
256
|
-
converter.target_spec.supported_types = [tf.float16]
|
|
257
253
|
|
|
258
254
|
|
|
259
255
|
def convert_stablehlo_to_tflite(
|
|
@@ -323,8 +319,24 @@ def convert_stablehlo_to_tflite(
|
|
|
323
319
|
converter._experimental_enable_composite_direct_lowering = True
|
|
324
320
|
|
|
325
321
|
_set_tfl_converter_quant_flags(converter, quant_config)
|
|
322
|
+
if (
|
|
323
|
+
quant_config is not None
|
|
324
|
+
and quant_config._quantizer_mode
|
|
325
|
+
== quant_config._QuantizerMode.AI_EDGE_QUANTIZER
|
|
326
|
+
):
|
|
327
|
+
translated_recipe = translate_recipe.translate_to_ai_edge_recipe(
|
|
328
|
+
quant_config.generative_recipe
|
|
329
|
+
)
|
|
330
|
+
|
|
326
331
|
_apply_tfl_backdoor_flags(converter, _tfl_converter_flags)
|
|
327
332
|
|
|
328
333
|
tflite_model = converter.convert()
|
|
329
334
|
|
|
335
|
+
if (
|
|
336
|
+
quant_config is not None
|
|
337
|
+
and quant_config._quantizer_mode
|
|
338
|
+
== quant_config._QuantizerMode.AI_EDGE_QUANTIZER
|
|
339
|
+
):
|
|
340
|
+
tflite_model = translate_recipe.quantize_model(tflite_model, translated_recipe)
|
|
341
|
+
|
|
330
342
|
return tflite_model
|
|
@@ -116,7 +116,7 @@ def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
|
|
|
116
116
|
)
|
|
117
117
|
ff_config = cfg.FeedForwardConfig(
|
|
118
118
|
type=cfg.FeedForwardType.GATED,
|
|
119
|
-
activation=cfg.ActivationType.GELU_TANH,
|
|
119
|
+
activation=cfg.ActivationConfig(cfg.ActivationType.GELU_TANH),
|
|
120
120
|
intermediate_size=16384,
|
|
121
121
|
)
|
|
122
122
|
norm_config = cfg.NormalizationConfig(
|
|
@@ -112,7 +112,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
|
|
|
112
112
|
)
|
|
113
113
|
ff_config = cfg.FeedForwardConfig(
|
|
114
114
|
type=cfg.FeedForwardType.SEQUENTIAL,
|
|
115
|
-
activation=cfg.ActivationType.GELU_TANH,
|
|
115
|
+
activation=cfg.ActivationConfig(cfg.ActivationType.GELU_TANH),
|
|
116
116
|
intermediate_size=10240,
|
|
117
117
|
use_bias=True,
|
|
118
118
|
)
|
|
@@ -90,7 +90,7 @@ def get_model_config() -> cfg.ModelConfig:
|
|
|
90
90
|
|
|
91
91
|
ff_config = cfg.FeedForwardConfig(
|
|
92
92
|
type=cfg.FeedForwardType.SEQUENTIAL,
|
|
93
|
-
activation=cfg.ActivationType.GELU_QUICK,
|
|
93
|
+
activation=cfg.ActivationConfig(cfg.ActivationType.GELU_QUICK),
|
|
94
94
|
intermediate_size=embedding_dim * 4,
|
|
95
95
|
use_bias=True,
|
|
96
96
|
)
|
|
@@ -221,7 +221,7 @@ class Decoder(nn.Module):
|
|
|
221
221
|
in_channels=prev_output_channel,
|
|
222
222
|
out_channels=block_out_channels,
|
|
223
223
|
normalization_config=config.normalization_config,
|
|
224
|
-
|
|
224
|
+
activation_config=config.activation_config,
|
|
225
225
|
num_layers=config.layers_per_block,
|
|
226
226
|
add_upsample=not_final_block,
|
|
227
227
|
upsample_conv=True,
|
|
@@ -235,7 +235,7 @@ class Decoder(nn.Module):
|
|
|
235
235
|
self.final_norm = layers_builder.build_norm(
|
|
236
236
|
block_out_channels, config.normalization_config
|
|
237
237
|
)
|
|
238
|
-
self.act_fn = layers_builder.get_activation(config.
|
|
238
|
+
self.act_fn = layers_builder.get_activation(config.activation_config)
|
|
239
239
|
self.conv_out = nn.Conv2d(
|
|
240
240
|
block_out_channels,
|
|
241
241
|
config.out_channels,
|
|
@@ -287,7 +287,7 @@ def get_model_config() -> unet_cfg.AutoEncoderConfig:
|
|
|
287
287
|
mid_block_config = unet_cfg.MidBlock2DConfig(
|
|
288
288
|
in_channels=block_out_channels[-1],
|
|
289
289
|
normalization_config=norm_config,
|
|
290
|
-
|
|
290
|
+
activation_config=layers_cfg.ActivationConfig(layers_cfg.ActivationType.SILU),
|
|
291
291
|
num_layers=1,
|
|
292
292
|
attention_block_config=att_config,
|
|
293
293
|
)
|
|
@@ -296,7 +296,7 @@ def get_model_config() -> unet_cfg.AutoEncoderConfig:
|
|
|
296
296
|
in_channels=in_channels,
|
|
297
297
|
latent_channels=latent_channels,
|
|
298
298
|
out_channels=out_channels,
|
|
299
|
-
|
|
299
|
+
activation_config=layers_cfg.ActivationConfig(layers_cfg.ActivationType.SILU),
|
|
300
300
|
block_out_channels=block_out_channels,
|
|
301
301
|
scaling_factor=scaling_factor,
|
|
302
302
|
layers_per_block=layers_per_block,
|
|
@@ -130,7 +130,7 @@ class Upsample(nn.Module):
|
|
|
130
130
|
self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
|
|
131
131
|
|
|
132
132
|
def forward(self, x):
|
|
133
|
-
x = F.interpolate(x, scale_factor=2, mode=
|
|
133
|
+
x = F.interpolate(x, scale_factor=2, mode="nearest")
|
|
134
134
|
return self.conv(x)
|
|
135
135
|
|
|
136
136
|
|
|
@@ -237,3 +237,8 @@ class Diffusion(nn.Module):
|
|
|
237
237
|
output = self.unet(latent, context, time)
|
|
238
238
|
output = self.final(output)
|
|
239
239
|
return output
|
|
240
|
+
|
|
241
|
+
|
|
242
|
+
if __name__ == "__main__":
  # Ad-hoc debug entry point: build a Diffusion model with its default
  # constructor arguments and print the names of its state-dict entries.
  print(Diffusion().state_dict().keys())
|
|
@@ -349,7 +349,7 @@ def get_model_config_t5() -> cfg.ModelConfig:
|
|
|
349
349
|
)
|
|
350
350
|
ff_config = cfg.FeedForwardConfig(
|
|
351
351
|
type=cfg.FeedForwardType.SEQUENTIAL,
|
|
352
|
-
activation=cfg.ActivationType.RELU,
|
|
352
|
+
activation=cfg.ActivationConfig(cfg.ActivationType.RELU),
|
|
353
353
|
intermediate_size=3072,
|
|
354
354
|
)
|
|
355
355
|
# T5 Confirmed as RMS Norm and eps = 1e-6 TJA.
|
|
@@ -76,7 +76,7 @@ def define_and_run() -> None:
|
|
|
76
76
|
)
|
|
77
77
|
ff_config = cfg.FeedForwardConfig(
|
|
78
78
|
type=cfg.FeedForwardType.GATED,
|
|
79
|
-
activation=cfg.ActivationType.SILU,
|
|
79
|
+
activation=cfg.ActivationConfig(cfg.ActivationType.SILU),
|
|
80
80
|
intermediate_size=256,
|
|
81
81
|
)
|
|
82
82
|
norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.RMS_NORM)
|
|
@@ -95,7 +95,7 @@ def get_model_config() -> cfg.ModelConfig:
|
|
|
95
95
|
)
|
|
96
96
|
ff_config = cfg.FeedForwardConfig(
|
|
97
97
|
type=cfg.FeedForwardType.GATED,
|
|
98
|
-
activation=cfg.ActivationType.SILU,
|
|
98
|
+
activation=cfg.ActivationConfig(cfg.ActivationType.SILU),
|
|
99
99
|
intermediate_size=256,
|
|
100
100
|
)
|
|
101
101
|
norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.RMS_NORM)
|
|
@@ -83,7 +83,7 @@ def get_model_config() -> cfg.ModelConfig:
|
|
|
83
83
|
)
|
|
84
84
|
ff_config = cfg.FeedForwardConfig(
|
|
85
85
|
type=cfg.FeedForwardType.GATED,
|
|
86
|
-
activation=cfg.ActivationType.SILU,
|
|
86
|
+
activation=cfg.ActivationConfig(cfg.ActivationType.SILU),
|
|
87
87
|
intermediate_size=256,
|
|
88
88
|
)
|
|
89
89
|
norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.RMS_NORM)
|
|
@@ -112,7 +112,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
|
|
|
112
112
|
)
|
|
113
113
|
ff_config = cfg.FeedForwardConfig(
|
|
114
114
|
type=cfg.FeedForwardType.GATED,
|
|
115
|
-
activation=cfg.ActivationType.SILU,
|
|
115
|
+
activation=cfg.ActivationConfig(cfg.ActivationType.SILU),
|
|
116
116
|
intermediate_size=5632,
|
|
117
117
|
)
|
|
118
118
|
norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.RMS_NORM)
|
|
@@ -13,6 +13,7 @@
|
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ==============================================================================
|
|
15
15
|
# Builder class for individual components.
|
|
16
|
+
import torch
|
|
16
17
|
from torch import nn
|
|
17
18
|
import torch.nn.functional as F
|
|
18
19
|
|
|
@@ -21,6 +22,23 @@ import ai_edge_torch.generative.layers.model_config as cfg
|
|
|
21
22
|
import ai_edge_torch.generative.layers.normalization as normalization
|
|
22
23
|
|
|
23
24
|
|
|
25
|
+
class GeGLU(nn.Module):
  """Gated-GELU (GeGLU) activation with a fused input projection.

  Computes GeGLU(x) = (xW + b) * GELU(xV + c). Here W/b and V/c are the two
  halves of a single linear projection mapping `d_in` features to
  `2 * d_out` features.

  See: https://arxiv.org/abs/2002.05202v1
  """

  def __init__(self, d_in: int, d_out: int):
    super().__init__()
    # One projection produces both the value half and the gate half.
    self.proj = nn.Linear(d_in, d_out * 2)

  def forward(self, x: torch.Tensor):
    projected = self.proj(x)
    # First half along the last axis is the value, second half is the gate.
    value, gate = projected.chunk(2, dim=-1)
    return value * F.gelu(gate)
|
|
40
|
+
|
|
41
|
+
|
|
24
42
|
def build_norm(dim: int, config: cfg.NormalizationConfig):
|
|
25
43
|
"""Builder function for normalizers.
|
|
26
44
|
|
|
@@ -81,29 +99,33 @@ def build_ff(dim: int, config: cfg.FeedForwardConfig):
|
|
|
81
99
|
)
|
|
82
100
|
|
|
83
101
|
|
|
84
|
-
def get_activation(
|
|
85
|
-
"""Get pytorch callable activation from the
|
|
102
|
+
def get_activation(config: cfg.ActivationConfig):
|
|
103
|
+
"""Get pytorch callable activation from the activation config.
|
|
86
104
|
|
|
87
105
|
Args:
|
|
88
|
-
|
|
106
|
+
config (cfg.ActivationConfig): activation config.
|
|
89
107
|
|
|
90
108
|
Returns:
|
|
91
109
|
Activation function.
|
|
92
110
|
|
|
93
111
|
Raises:
|
|
94
|
-
ValueError: If activation
|
|
112
|
+
ValueError: If activation config is not supported.
|
|
95
113
|
"""
|
|
96
|
-
if
|
|
114
|
+
if config.type == cfg.ActivationType.LINEAR:
|
|
115
|
+
return lambda x: x
|
|
116
|
+
elif config.type == cfg.ActivationType.SILU:
|
|
97
117
|
return F.silu
|
|
98
|
-
elif
|
|
118
|
+
elif config.type == cfg.ActivationType.GELU:
|
|
99
119
|
return F.gelu
|
|
100
|
-
elif
|
|
120
|
+
elif config.type == cfg.ActivationType.GELU_TANH:
|
|
101
121
|
return lambda x: F.gelu(x, approximate="tanh")
|
|
102
|
-
elif
|
|
122
|
+
elif config.type == cfg.ActivationType.GELU_QUICK:
|
|
103
123
|
# GELU approximation that is fast but somewhat inaccurate.
|
|
104
124
|
# See: https://github.com/hendrycks/GELUs
|
|
105
125
|
return lambda x: x * F.sigmoid(1.702 * x)
|
|
106
|
-
elif
|
|
126
|
+
elif config.type == cfg.ActivationType.GE_GLU:
|
|
127
|
+
return GeGLU(config.dim_in, config.dim_out)
|
|
128
|
+
elif config.type == cfg.ActivationType.RELU:
|
|
107
129
|
return F.relu
|
|
108
130
|
else:
|
|
109
131
|
raise ValueError("Unsupported activation type.")
|
|
@@ -28,6 +28,7 @@ class ActivationType(enum.Enum):
|
|
|
28
28
|
GELU = enum.auto()
|
|
29
29
|
GELU_TANH = enum.auto()
|
|
30
30
|
GELU_QUICK = enum.auto()
|
|
31
|
+
GE_GLU = enum.auto()
|
|
31
32
|
RELU = enum.auto()
|
|
32
33
|
|
|
33
34
|
|
|
@@ -74,12 +75,20 @@ class AttentionConfig:
|
|
|
74
75
|
relative_attention_max_distance: int = 0
|
|
75
76
|
|
|
76
77
|
|
|
78
|
+
@dataclass
class ActivationConfig:
  """Activation function parameters.

  Attributes:
    type: Which activation function to use. Defaults to LINEAR.
    dim_in: Input dimension. Only needed by activations that own weights
      (GE_GLU); otherwise left as None.
    dim_out: Output dimension. Only needed by activations that own weights
      (GE_GLU); otherwise left as None.
  """

  type: ActivationType = ActivationType.LINEAR
  # The two dimensions below are consumed when building a GeGLU activation.
  dim_in: Optional[int] = None
  dim_out: Optional[int] = None
|
|
84
|
+
|
|
85
|
+
|
|
77
86
|
@dataclass
|
|
78
87
|
class FeedForwardConfig:
|
|
79
88
|
"""FeedForward module's parameters."""
|
|
80
89
|
|
|
81
90
|
type: FeedForwardType
|
|
82
|
-
activation:
|
|
91
|
+
activation: ActivationConfig
|
|
83
92
|
intermediate_size: int
|
|
84
93
|
use_bias: bool = False
|
|
85
94
|
|
|
@@ -53,7 +53,7 @@ class ResidualBlock2D(nn.Module):
|
|
|
53
53
|
self.conv_2 = nn.Conv2d(
|
|
54
54
|
config.out_channels, config.out_channels, kernel_size=3, stride=1, padding=1
|
|
55
55
|
)
|
|
56
|
-
self.act_fn = layers_builder.get_activation(config.
|
|
56
|
+
self.act_fn = layers_builder.get_activation(config.activation_config)
|
|
57
57
|
if config.in_channels == config.out_channels:
|
|
58
58
|
self.residual_layer = nn.Identity()
|
|
59
59
|
else:
|
|
@@ -167,7 +167,7 @@ class UpDecoderBlock2D(nn.Module):
|
|
|
167
167
|
out_channels=config.out_channels,
|
|
168
168
|
time_embedding_channels=config.time_embedding_channels,
|
|
169
169
|
normalization_config=config.normalization_config,
|
|
170
|
-
|
|
170
|
+
activation_config=config.activation_config,
|
|
171
171
|
)
|
|
172
172
|
)
|
|
173
173
|
)
|
|
@@ -244,7 +244,7 @@ class MidBlock2D(nn.Module):
|
|
|
244
244
|
out_channels=config.in_channels,
|
|
245
245
|
time_embedding_channels=config.time_embedding_channels,
|
|
246
246
|
normalization_config=config.normalization_config,
|
|
247
|
-
|
|
247
|
+
activation_config=config.activation_config,
|
|
248
248
|
)
|
|
249
249
|
)
|
|
250
250
|
]
|
|
@@ -259,7 +259,7 @@ class MidBlock2D(nn.Module):
|
|
|
259
259
|
out_channels=config.in_channels,
|
|
260
260
|
time_embedding_channels=config.time_embedding_channels,
|
|
261
261
|
normalization_config=config.normalization_config,
|
|
262
|
-
|
|
262
|
+
activation_config=config.activation_config,
|
|
263
263
|
)
|
|
264
264
|
)
|
|
265
265
|
)
|
|
@@ -39,7 +39,7 @@ class ResidualBlock2DConfig:
|
|
|
39
39
|
in_channels: int
|
|
40
40
|
out_channels: int
|
|
41
41
|
normalization_config: layers_cfg.NormalizationConfig
|
|
42
|
-
|
|
42
|
+
activation_config: layers_cfg.ActivationConfig
|
|
43
43
|
# Optional time embedding channels if the residual block takes a time embedding context as input
|
|
44
44
|
time_embedding_channels: Optional[int] = None
|
|
45
45
|
|
|
@@ -56,7 +56,7 @@ class UpDecoderBlock2DConfig:
|
|
|
56
56
|
in_channels: int
|
|
57
57
|
out_channels: int
|
|
58
58
|
normalization_config: layers_cfg.NormalizationConfig
|
|
59
|
-
|
|
59
|
+
activation_config: layers_cfg.ActivationConfig
|
|
60
60
|
num_layers: int
|
|
61
61
|
# Optional time embedding channels if the residual blocks take a time embedding context as input
|
|
62
62
|
time_embedding_channels: Optional[int] = None
|
|
@@ -72,7 +72,7 @@ class UpDecoderBlock2DConfig:
|
|
|
72
72
|
class MidBlock2DConfig:
|
|
73
73
|
in_channels: int
|
|
74
74
|
normalization_config: layers_cfg.NormalizationConfig
|
|
75
|
-
|
|
75
|
+
activation_config: layers_cfg.ActivationConfig
|
|
76
76
|
num_layers: int
|
|
77
77
|
# Optional time embedding channels if the residual blocks take a time embedding context as input
|
|
78
78
|
time_embedding_channels: Optional[int] = None
|
|
@@ -85,7 +85,7 @@ class AutoEncoderConfig:
|
|
|
85
85
|
"""Configurations of encoder/decoder in the autoencoder model."""
|
|
86
86
|
|
|
87
87
|
# The activation type of encoder/decoder blocks.
|
|
88
|
-
|
|
88
|
+
activation_config: layers_cfg.ActivationConfig
|
|
89
89
|
|
|
90
90
|
# The output channels of each block.
|
|
91
91
|
block_out_channels: List[int]
|
|
File without changes
|
|
@@ -0,0 +1,164 @@
|
|
|
1
|
+
# Copyright 2024 The AI Edge Torch Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
|
|
16
|
+
import json
|
|
17
|
+
|
|
18
|
+
from ai_edge_quantizer import quantizer
|
|
19
|
+
|
|
20
|
+
from ai_edge_torch.generative.quantize import quant_attrs
|
|
21
|
+
from ai_edge_torch.generative.quantize import quant_recipe
|
|
22
|
+
|
|
23
|
+
# Short private aliases for the AI Edge Quantizer typing helpers used below.
_OpName = quantizer.qtyping.TFLOperationName
_OpExecutionMode = quantizer.qtyping.OpExecutionMode
_OpQuantConfig = quantizer.qtyping.OpQuantizationConfig
_TensorQuantConfig = quantizer.qtyping.TensorQuantizationConfig
|
|
27
|
+
|
|
28
|
+
# Regex matching every op; used for the recipe's global `default` scope.
_DEFAULT_REGEX_STR = '.*'
# Templates selecting a single TransformerBlock's attention / feedforward ops.
# The `{}` placeholder is filled with a block index or an index regex.
# Raw strings keep `\[` / `\]` as regex escapes. In plain literals they are
# invalid Python escape sequences, which warn (SyntaxWarning on Python 3.12+).
_ATTENTION_IDX_REGEX_STR = (
    r'transformer_blocks\[{}\]/ai_edge_torch.generative.layers.attention'
)
_FEEDFORWARD_IDX_REGEX_STR = (
    r'transformer_blocks\[{}\]/ai_edge_torch.generative.layers.feed_forward'
)
_EMBEDDING_REGEX_STR = 'Embedding_tok_embedding'
# Matches a one- or two-digit block index (0-99).
# NOTE(review): block indices >= 100 are not matched by this pattern; confirm
# that no supported model has that many TransformerBlocks.
_ANY_TWO_DIGITS_REGEX_STR = r'\d{1,2}'
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def _get_nbits_from_dtype(dtype: quant_attrs.Dtype) -> int:
  """Returns the bit width of `dtype`.

  Raises:
    ValueError: If `dtype` is not FP32, FP16 or INT8.
  """
  for candidate, nbits in (
      (quant_attrs.Dtype.FP32, 32),
      (quant_attrs.Dtype.FP16, 16),
      (quant_attrs.Dtype.INT8, 8),
  ):
    if dtype == candidate:
      return nbits
  raise ValueError('Unimplemented number of bits')
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def _get_dtype_from_dtype(dtype: quant_attrs.Dtype) -> quantizer.qtyping.TensorDataType:
  """Maps `dtype` to the quantizer's tensor data type.

  FP32 and FP16 map to FLOAT; any other dtype maps to INT.
  """
  is_float = dtype == quant_attrs.Dtype.FP32 or dtype == quant_attrs.Dtype.FP16
  tensor_data_type = quantizer.qtyping.TensorDataType
  return tensor_data_type.FLOAT if is_float else tensor_data_type.INT
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def _get_execution_mode_from_mode(mode: quant_attrs.Mode) -> _OpExecutionMode:
  """Maps a quantization `mode` to the quantizer's op execution mode.

  Raises:
    ValueError: If `mode` is neither DYNAMIC_RANGE nor WEIGHT_ONLY.
  """
  for candidate, execution_mode in (
      (quant_attrs.Mode.DYNAMIC_RANGE, _OpExecutionMode.DRQ),
      (quant_attrs.Mode.WEIGHT_ONLY, _OpExecutionMode.WEIGHT_ONLY),
  ):
    if mode == candidate:
      return execution_mode
  raise ValueError('Unimplemented execution mode')
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
def _get_channelwise_from_granularity(granularity: quant_attrs.Granularity) -> bool:
  """Returns whether `granularity` requests channel-wise quantization.

  Raises:
    ValueError: If `granularity` is neither CHANNELWISE nor NONE.
  """
  if granularity == quant_attrs.Granularity.CHANNELWISE:
    return True
  if granularity == quant_attrs.Granularity.NONE:
    return False
  raise ValueError('Unimplemented granularity')
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
def _get_algorithm_key_from_algorithm(algo: quant_attrs.Algorithm) -> str:
  """Maps `algo` to the quantizer's algorithm-manager key.

  Raises:
    ValueError: If `algo` is neither MIN_MAX nor FLOAT_CAST.
  """
  algorithm_name = quantizer.algorithm_manager.AlgorithmName
  for candidate, key in (
      (quant_attrs.Algorithm.MIN_MAX, algorithm_name.MIN_MAX_UNIFORM_QUANT),
      (quant_attrs.Algorithm.FLOAT_CAST, algorithm_name.FLOAT_CASTING),
  ):
    if algo == candidate:
      return key
  raise ValueError('Unimplemented algorithm')
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
def _set_quant_config(
    rm: quantizer.recipe_manager.RecipeManager,
    layer_recipe: quant_recipe.LayerQuantRecipe,
    regex: str,
):
  """Registers `layer_recipe` on `rm` for each supported op matching `regex`.

  FULLY_CONNECTED and CONV_2D ops are always configured. BATCH_MATMUL and
  EMBEDDING_LOOKUP are configured only when the recipe uses MIN_MAX.

  Args:
    rm: Recipe manager the per-op configs are added to.
    layer_recipe: Layer-level recipe translated into quantizer configs.
    regex: Regex selecting the ops the configs apply to.
  """
  op_names = [_OpName.FULLY_CONNECTED, _OpName.CONV_2D]
  if layer_recipe.algorithm == quant_attrs.Algorithm.MIN_MAX:
    op_names.extend([_OpName.BATCH_MATMUL, _OpName.EMBEDDING_LOOKUP])

  for op_name in op_names:
    # A fresh config object is built for every op, as before.
    weight_config = _TensorQuantConfig(
        num_bits=_get_nbits_from_dtype(layer_recipe.weight_dtype),
        symmetric=True,
        channel_wise=_get_channelwise_from_granularity(
            layer_recipe.granularity
        ),
        dtype=_get_dtype_from_dtype(layer_recipe.weight_dtype),
    )
    op_config = _OpQuantConfig(
        weight_tensor_config=weight_config,
        execution_mode=_get_execution_mode_from_mode(layer_recipe.mode),
    )
    rm.add_quantization_config(
        regex=regex,
        operation_name=op_name,
        op_config=op_config,
        algorithm_key=_get_algorithm_key_from_algorithm(layer_recipe.algorithm),
        override_algorithm=True,
    )
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
def translate_to_ai_edge_recipe(
    recipe: quant_recipe.GenerativeQuantRecipe,
) -> quantizer.recipe_manager.ModelQuantizationRecipe:
  """Translates a GenerativeQuantRecipe into an AI Edge Quantizer recipe.

  Configs are registered in this order: default, embedding, attention,
  feedforward. Attention/feedforward recipes given as a dict are applied per
  TransformerBlock index. A single recipe is applied to every block whose
  index matches `_ANY_TWO_DIGITS_REGEX_STR`.

  Args:
    recipe: The generative-API quantization recipe to translate.

  Returns:
    The recipe produced by the quantizer's RecipeManager.
  """
  manager = quantizer.recipe_manager.RecipeManager()

  # Scopes that map to a single fixed regex.
  for layer_recipe, regex in (
      (recipe.default, _DEFAULT_REGEX_STR),
      (recipe.embedding, _EMBEDDING_REGEX_STR),
  ):
    if layer_recipe is not None:
      _set_quant_config(manager, layer_recipe, regex)

  # Scopes that are templated on the TransformerBlock index.
  for block_recipe, regex_template in (
      (recipe.attention, _ATTENTION_IDX_REGEX_STR),
      (recipe.feedforward, _FEEDFORWARD_IDX_REGEX_STR),
  ):
    if block_recipe is None:
      continue
    if isinstance(block_recipe, dict):
      for idx, layer in block_recipe.items():
        _set_quant_config(manager, layer, regex_template.format(idx))
    else:
      _set_quant_config(
          manager,
          block_recipe,
          regex_template.format(_ANY_TWO_DIGITS_REGEX_STR),
      )

  return manager.get_quantization_recipe()
|
|
142
|
+
|
|
143
|
+
|
|
144
|
+
def quantize_model(
    model: bytearray, recipe: quantizer.recipe_manager.ModelQuantizationRecipe
) -> bytearray:
  """Quantizes a serialized TFLite model with AI Edge Quantizer.

  Args:
    model: Serialized TFLite model to quantize.
    recipe: AI Edge Quantizer recipe (e.g. from `translate_to_ai_edge_recipe`).
      It is serialized with `json.dumps` before being handed to the quantizer.

  Returns:
    The `quantized_model` of the quantization result.
  """
  import os
  import tempfile

  # TODO(b/336599483): Remove tempfile and use bytearray instead
  # The quantizer currently takes file paths. Stage the inputs in a private,
  # uniquely named directory instead of fixed /tmp paths, so concurrent
  # conversions cannot clobber each other. The directory is also removed
  # when quantization raises.
  with tempfile.TemporaryDirectory() as tmp_dir:
    tmp_model_path = os.path.join(tmp_dir, 'tmp.tflite')
    tmp_recipe_path = os.path.join(tmp_dir, 'recipe.json')
    with open(tmp_model_path, 'wb') as fp:
      fp.write(model)
    with open(tmp_recipe_path, 'w') as rp:
      rp.write(json.dumps(recipe))

    qt = quantizer.Quantizer(tmp_model_path, tmp_recipe_path)
    result = qt.quantize()

  return result.quantized_model
|
|
@@ -32,9 +32,11 @@ class Algorithm(enum.Enum):
|
|
|
32
32
|
Attributes:
|
|
33
33
|
MIN_MAX: Maps the min/max of floating point space to the min/max of
|
|
34
34
|
quantized space and quantize uniformly.
|
|
35
|
+
FLOAT_CAST: Casts a float to another float of a different type.
|
|
35
36
|
"""
|
|
36
37
|
|
|
37
38
|
MIN_MAX = enum.auto()
|
|
39
|
+
FLOAT_CAST = enum.auto()
|
|
38
40
|
|
|
39
41
|
|
|
40
42
|
@enum.unique
|
|
@@ -14,8 +14,7 @@
|
|
|
14
14
|
# ==============================================================================
|
|
15
15
|
|
|
16
16
|
from dataclasses import dataclass
|
|
17
|
-
import
|
|
18
|
-
from typing import Optional
|
|
17
|
+
from typing import Optional, Union
|
|
19
18
|
|
|
20
19
|
from ai_edge_torch.generative.quantize import quant_attrs
|
|
21
20
|
from ai_edge_torch.generative.quantize import supported_schemes
|
|
@@ -80,18 +79,50 @@ class LayerQuantRecipe:
|
|
|
80
79
|
|
|
81
80
|
|
|
82
81
|
@dataclass
|
|
83
|
-
class
|
|
82
|
+
class GenerativeQuantRecipe:
|
|
84
83
|
"""Quantization recipe for a model composed of the Edge Generative API layers.
|
|
85
84
|
|
|
85
|
+
Some layers can be specified with different `LayerQuantRecipe` for each block by
|
|
86
|
+
providing a dictionary keyed by the TransformerBlock index, e.g. attention
|
|
87
|
+
and feedforward. For example,
|
|
88
|
+
|
|
89
|
+
```
|
|
90
|
+
default = LayerQuantRecipeA
|
|
91
|
+
attention = { 2: LayerQuantRecipeB }
|
|
92
|
+
feedforward = { 3: LayerQuantRecipeC }
|
|
93
|
+
```
|
|
94
|
+
|
|
95
|
+
will apply LayerQuantRecipeA to the entire model, overriden by
|
|
96
|
+
LayerQuantRecipeB for the TransformerBlock[2].attention layer and
|
|
97
|
+
LayerQuantRecipeC for the TransformerBlock[3].feedforward layer. Any config
|
|
98
|
+
with invalid indices will be ignored.
|
|
99
|
+
|
|
86
100
|
Attributes:
|
|
87
101
|
default: The quantization recipe for global scope of the model.
|
|
102
|
+
embedding: Recipe for the embedding table.
|
|
103
|
+
attention: Recipe for the attention blocks. This could be specified with
|
|
104
|
+
different LayerQuantRecipe for each block by providing a dictionary
|
|
105
|
+
keyed by the TransformerBlock index.
|
|
106
|
+
feedforward: Recipe for the feedforward layers. This could be specified with
|
|
107
|
+
different LayerQuantRecipe for each block by providing a dictionary
|
|
108
|
+
keyed by the TransformerBlock index.
|
|
88
109
|
"""
|
|
89
110
|
|
|
90
111
|
default: Optional[LayerQuantRecipe] = None
|
|
112
|
+
embedding: Optional[LayerQuantRecipe] = None
|
|
113
|
+
attention: Union[
|
|
114
|
+
Optional[LayerQuantRecipe], Optional[dict[int, LayerQuantRecipe]]
|
|
115
|
+
] = None
|
|
116
|
+
feedforward: Union[
|
|
117
|
+
Optional[LayerQuantRecipe], Optional[dict[int, LayerQuantRecipe]]
|
|
118
|
+
] = None
|
|
91
119
|
|
|
92
120
|
def __str__(self):
|
|
93
|
-
return f"""
|
|
121
|
+
return f"""GenerativeQuantRecipe(
|
|
94
122
|
Default: {self.default}
|
|
123
|
+
Embedding: {self.embedding}
|
|
124
|
+
Attention: {self.attention}
|
|
125
|
+
Feedforward: {self.feedforward}
|
|
95
126
|
)"""
|
|
96
127
|
|
|
97
128
|
__repr__ = __str__
|
|
@@ -104,3 +135,17 @@ class TransformerQuantRecipe:
|
|
|
104
135
|
"""
|
|
105
136
|
if self.default is not None:
|
|
106
137
|
self.default.verify()
|
|
138
|
+
if self.embedding is not None:
|
|
139
|
+
self.embedding.verify()
|
|
140
|
+
if self.attention is not None:
|
|
141
|
+
if isinstance(self.attention, dict):
|
|
142
|
+
for recipe in self.attention.values():
|
|
143
|
+
recipe.verify()
|
|
144
|
+
else:
|
|
145
|
+
self.attention.verify()
|
|
146
|
+
if self.feedforward is not None:
|
|
147
|
+
if isinstance(self.feedforward, dict):
|
|
148
|
+
for recipe in self.feedforward.values():
|
|
149
|
+
recipe.verify()
|
|
150
|
+
else:
|
|
151
|
+
self.feedforward.verify()
|
|
@@ -22,7 +22,7 @@ Typical usage example:
|
|
|
22
22
|
|
|
23
23
|
1. Applying a single layer recipe to the entire model
|
|
24
24
|
|
|
25
|
-
quant_recipe.
|
|
25
|
+
quant_recipe.GenerativeQuantRecipe(
|
|
26
26
|
default=quant_recipe_utils.create_layer_quant_int8_dynamic()
|
|
27
27
|
)
|
|
28
28
|
"""
|
|
@@ -46,6 +46,6 @@ def create_layer_quant_fp16() -> quant_recipe.LayerQuantRecipe:
|
|
|
46
46
|
activation_dtype=quant_attrs.Dtype.FP32,
|
|
47
47
|
weight_dtype=quant_attrs.Dtype.FP16,
|
|
48
48
|
mode=quant_attrs.Mode.WEIGHT_ONLY,
|
|
49
|
-
algorithm=quant_attrs.Algorithm.
|
|
49
|
+
algorithm=quant_attrs.Algorithm.FLOAT_CAST,
|
|
50
50
|
granularity=quant_attrs.Granularity.NONE,
|
|
51
51
|
)
|
|
@@ -34,15 +34,15 @@ from ai_edge_torch.quantize import quant_config
|
|
|
34
34
|
|
|
35
35
|
def full_linear_int8_dynamic_recipe() -> quant_config.QuantConfig:
|
|
36
36
|
return quant_config.QuantConfig(
|
|
37
|
-
|
|
38
|
-
default=quant_recipe_utils.create_layer_quant_int8_dynamic()
|
|
37
|
+
generative_recipe=quant_recipe.GenerativeQuantRecipe(
|
|
38
|
+
default=quant_recipe_utils.create_layer_quant_int8_dynamic(),
|
|
39
39
|
)
|
|
40
40
|
)
|
|
41
41
|
|
|
42
42
|
|
|
43
43
|
def full_fp16_recipe() -> quant_config.QuantConfig:
|
|
44
44
|
return quant_config.QuantConfig(
|
|
45
|
-
|
|
45
|
+
generative_recipe=quant_recipe.GenerativeQuantRecipe(
|
|
46
46
|
default=quant_recipe_utils.create_layer_quant_fp16()
|
|
47
47
|
)
|
|
48
48
|
)
|
|
@@ -27,5 +27,6 @@ def get_supported_layer_schemes():
|
|
|
27
27
|
|
|
28
28
|
return [
|
|
29
29
|
(_t.FP32, _t.INT8, _m.DYNAMIC_RANGE, _a.MIN_MAX, _g.CHANNELWISE),
|
|
30
|
-
(_t.FP32, _t.
|
|
30
|
+
(_t.FP32, _t.INT8, _m.WEIGHT_ONLY, _a.MIN_MAX, _g.CHANNELWISE),
|
|
31
|
+
(_t.FP32, _t.FP16, _m.WEIGHT_ONLY, _a.FLOAT_CAST, _g.NONE),
|
|
31
32
|
]
|
|
@@ -21,11 +21,13 @@ import torch
|
|
|
21
21
|
import ai_edge_torch
|
|
22
22
|
from ai_edge_torch.generative.examples.test_models import toy_model_with_kv_cache # NOQA
|
|
23
23
|
from ai_edge_torch.generative.quantize import quant_recipe
|
|
24
|
+
from ai_edge_torch.generative.quantize import quant_recipe_utils
|
|
24
25
|
from ai_edge_torch.generative.quantize import quant_recipes
|
|
25
26
|
from ai_edge_torch.generative.quantize.quant_attrs import Algorithm
|
|
26
27
|
from ai_edge_torch.generative.quantize.quant_attrs import Dtype
|
|
27
28
|
from ai_edge_torch.generative.quantize.quant_attrs import Granularity
|
|
28
29
|
from ai_edge_torch.generative.quantize.quant_attrs import Mode
|
|
30
|
+
from ai_edge_torch.quantize import quant_config
|
|
29
31
|
from ai_edge_torch.testing import model_coverage
|
|
30
32
|
|
|
31
33
|
|
|
@@ -34,34 +36,47 @@ class TestVerifyRecipes(unittest.TestCase):
|
|
|
34
36
|
|
|
35
37
|
@parameterized.expand(
|
|
36
38
|
[
|
|
37
|
-
(Dtype.FP32, Dtype.FP32
|
|
38
|
-
(Dtype.INT8, Dtype.INT8
|
|
39
|
-
(Dtype.INT8, Dtype.FP16
|
|
40
|
-
(Dtype.FP16, Dtype.INT8
|
|
41
|
-
(Dtype.
|
|
42
|
-
(Dtype.INT8, Dtype.INT8, Mode.WEIGHT_ONLY),
|
|
43
|
-
(Dtype.FP16, Dtype.INT8, Mode.WEIGHT_ONLY),
|
|
44
|
-
(Dtype.INT8, Dtype.FP16, Mode.WEIGHT_ONLY),
|
|
45
|
-
(Dtype.FP16, Dtype.FP16, Mode.WEIGHT_ONLY),
|
|
39
|
+
(Dtype.FP32, Dtype.FP32),
|
|
40
|
+
(Dtype.INT8, Dtype.INT8),
|
|
41
|
+
(Dtype.INT8, Dtype.FP16),
|
|
42
|
+
(Dtype.FP16, Dtype.INT8),
|
|
43
|
+
(Dtype.FP16, Dtype.FP16),
|
|
46
44
|
]
|
|
47
45
|
)
|
|
48
46
|
def test_verify_invalid_recipes(
|
|
49
47
|
self,
|
|
50
48
|
activation,
|
|
51
49
|
weight,
|
|
52
|
-
mode,
|
|
53
|
-
algo=Algorithm.MIN_MAX,
|
|
54
|
-
granularity=Granularity.CHANNELWISE,
|
|
55
50
|
):
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
51
|
+
for m in Mode:
|
|
52
|
+
for a in Algorithm:
|
|
53
|
+
for g in Granularity:
|
|
54
|
+
with self.assertRaises(ValueError):
|
|
55
|
+
quant_recipe.LayerQuantRecipe(activation, weight, m, a, g).verify()
|
|
60
56
|
|
|
61
57
|
@parameterized.expand(
|
|
62
58
|
[
|
|
63
|
-
(
|
|
64
|
-
|
|
59
|
+
(
|
|
60
|
+
Dtype.FP32,
|
|
61
|
+
Dtype.INT8,
|
|
62
|
+
Mode.DYNAMIC_RANGE,
|
|
63
|
+
Algorithm.MIN_MAX,
|
|
64
|
+
Granularity.CHANNELWISE,
|
|
65
|
+
),
|
|
66
|
+
(
|
|
67
|
+
Dtype.FP32,
|
|
68
|
+
Dtype.INT8,
|
|
69
|
+
Mode.WEIGHT_ONLY,
|
|
70
|
+
Algorithm.MIN_MAX,
|
|
71
|
+
Granularity.CHANNELWISE,
|
|
72
|
+
),
|
|
73
|
+
(
|
|
74
|
+
Dtype.FP32,
|
|
75
|
+
Dtype.FP16,
|
|
76
|
+
Mode.WEIGHT_ONLY,
|
|
77
|
+
Algorithm.FLOAT_CAST,
|
|
78
|
+
Granularity.NONE,
|
|
79
|
+
),
|
|
65
80
|
]
|
|
66
81
|
)
|
|
67
82
|
def test_verify_valid_recipes(
|
|
@@ -69,8 +84,8 @@ class TestVerifyRecipes(unittest.TestCase):
|
|
|
69
84
|
activation,
|
|
70
85
|
weight,
|
|
71
86
|
mode,
|
|
87
|
+
algo,
|
|
72
88
|
granularity,
|
|
73
|
-
algo=Algorithm.MIN_MAX,
|
|
74
89
|
):
|
|
75
90
|
quant_recipe.LayerQuantRecipe(activation, weight, mode, algo, granularity).verify()
|
|
76
91
|
|
|
@@ -78,7 +93,46 @@ class TestVerifyRecipes(unittest.TestCase):
|
|
|
78
93
|
class TestQuantizeConvert(unittest.TestCase):
|
|
79
94
|
"""Test conversion with quantization."""
|
|
80
95
|
|
|
81
|
-
def
|
|
96
|
+
def _attention_1_int8_dynamic_recipe() -> quant_config.QuantConfig:
|
|
97
|
+
return quant_config.QuantConfig(
|
|
98
|
+
generative_recipe=quant_recipe.GenerativeQuantRecipe(
|
|
99
|
+
attention={1: quant_recipe_utils.create_layer_quant_int8_dynamic()},
|
|
100
|
+
)
|
|
101
|
+
)
|
|
102
|
+
|
|
103
|
+
def _feedforward_0_int8_dynamic_recipe() -> quant_config.QuantConfig:
|
|
104
|
+
return quant_config.QuantConfig(
|
|
105
|
+
generative_recipe=quant_recipe.GenerativeQuantRecipe(
|
|
106
|
+
feedforward={0: quant_recipe_utils.create_layer_quant_int8_dynamic()},
|
|
107
|
+
)
|
|
108
|
+
)
|
|
109
|
+
|
|
110
|
+
@parameterized.expand(
|
|
111
|
+
[
|
|
112
|
+
(quant_recipes.full_fp16_recipe(), 0.75),
|
|
113
|
+
(quant_recipes.full_linear_int8_dynamic_recipe(), 0.64),
|
|
114
|
+
(_attention_1_int8_dynamic_recipe(), 0.95),
|
|
115
|
+
(_feedforward_0_int8_dynamic_recipe(), 0.87),
|
|
116
|
+
]
|
|
117
|
+
)
|
|
118
|
+
def test_quantize_convert_toy_sizes(self, quant_config, expected_compression):
|
|
119
|
+
config = toy_model_with_kv_cache.get_model_config()
|
|
120
|
+
pytorch_model = toy_model_with_kv_cache.ToyModelWithKV(config)
|
|
121
|
+
idx, input_pos = torch.tensor([[1]], dtype=torch.long), torch.tensor(
|
|
122
|
+
[10], dtype=torch.int64
|
|
123
|
+
)
|
|
124
|
+
|
|
125
|
+
quantized_model = ai_edge_torch.convert(
|
|
126
|
+
pytorch_model, (idx, input_pos), quant_config=quant_config
|
|
127
|
+
)
|
|
128
|
+
float_model = ai_edge_torch.convert(pytorch_model, (idx, input_pos))
|
|
129
|
+
self.assertAlmostEqual(
|
|
130
|
+
len(quantized_model._tflite_model) / len(float_model._tflite_model),
|
|
131
|
+
expected_compression,
|
|
132
|
+
delta=0.01,
|
|
133
|
+
)
|
|
134
|
+
|
|
135
|
+
def test_quantize_convert_compare_toy(self):
|
|
82
136
|
self.skipTest("b/338288901")
|
|
83
137
|
config = toy_model_with_kv_cache.get_model_config()
|
|
84
138
|
pytorch_model = toy_model_with_kv_cache.ToyModelWithKV(config)
|
|
@@ -32,27 +32,26 @@ class QuantConfig:
|
|
|
32
32
|
pt2e_quantizer: The instance of PT2EQuantizer used to quantize the model
|
|
33
33
|
with PT2E quantization. This method of quantization is not applicable to
|
|
34
34
|
models created with the Edge Generative API.
|
|
35
|
-
|
|
35
|
+
generative_recipe: Quantization recipe to be applied on a model created
|
|
36
36
|
with the Edge Generative API.
|
|
37
37
|
"""
|
|
38
38
|
|
|
39
39
|
pt2e_quantizer: pt2eq.PT2EQuantizer = None
|
|
40
|
-
|
|
40
|
+
generative_recipe: quant_recipe.GenerativeQuantRecipe = None
|
|
41
41
|
|
|
42
42
|
@enum.unique
|
|
43
43
|
class _QuantizerMode(enum.Enum):
|
|
44
44
|
NONE = enum.auto()
|
|
45
45
|
PT2E_DYNAMIC = enum.auto()
|
|
46
46
|
PT2E_STATIC = enum.auto()
|
|
47
|
-
|
|
48
|
-
TFLITE_FP16 = enum.auto()
|
|
47
|
+
AI_EDGE_QUANTIZER = enum.auto()
|
|
49
48
|
|
|
50
49
|
_quantizer_mode: _QuantizerMode = _QuantizerMode.NONE
|
|
51
50
|
|
|
52
51
|
def __init__(
|
|
53
52
|
self,
|
|
54
53
|
pt2e_quantizer: Optional[pt2eq.PT2EQuantizer] = None,
|
|
55
|
-
|
|
54
|
+
generative_recipe: Optional[quant_recipe.GenerativeQuantRecipe] = None,
|
|
56
55
|
):
|
|
57
56
|
"""Initializes some internal states based on selected quantization method.
|
|
58
57
|
|
|
@@ -61,8 +60,8 @@ class QuantConfig:
|
|
|
61
60
|
is properly setup. Additionally sets up an utility enum _quantizer_mode to
|
|
62
61
|
guide certain conversion processes.
|
|
63
62
|
"""
|
|
64
|
-
if pt2e_quantizer is not None and
|
|
65
|
-
raise ValueError('Cannot set both pt2e_quantizer and
|
|
63
|
+
if pt2e_quantizer is not None and generative_recipe is not None:
|
|
64
|
+
raise ValueError('Cannot set both pt2e_quantizer and generative_recipe.')
|
|
66
65
|
elif pt2e_quantizer is not None:
|
|
67
66
|
object.__setattr__(self, 'pt2e_quantizer', pt2e_quantizer)
|
|
68
67
|
object.__setattr__(
|
|
@@ -74,12 +73,9 @@ class QuantConfig:
|
|
|
74
73
|
else self._QuantizerMode.PT2E_STATIC
|
|
75
74
|
),
|
|
76
75
|
)
|
|
77
|
-
elif
|
|
78
|
-
|
|
79
|
-
object.__setattr__(self, '
|
|
80
|
-
|
|
81
|
-
object.__setattr__(self, '_quantizer_mode', self._QuantizerMode.TFLITE_DYNAMIC)
|
|
82
|
-
elif self.transformer_recipe.default.weight_dtype == quant_attrs.Dtype.FP16:
|
|
83
|
-
object.__setattr__(self, '_quantizer_mode', self._QuantizerMode.TFLITE_FP16)
|
|
76
|
+
elif generative_recipe is not None:
|
|
77
|
+
generative_recipe.verify()
|
|
78
|
+
object.__setattr__(self, 'generative_recipe', generative_recipe)
|
|
79
|
+
object.__setattr__(self, '_quantizer_mode', self._QuantizerMode.AI_EDGE_QUANTIZER)
|
|
84
80
|
else:
|
|
85
|
-
raise ValueError('Either pt2e_quantizer or
|
|
81
|
+
raise ValueError('Either pt2e_quantizer or generative_recipe must be set.')
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: ai-edge-torch-nightly
|
|
3
|
-
Version: 0.2.0.
|
|
3
|
+
Version: 0.2.0.dev20240611
|
|
4
4
|
Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
|
|
5
5
|
Home-page: https://github.com/google-ai-edge/ai-edge-torch
|
|
6
6
|
Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
|
|
@@ -2,7 +2,7 @@ ai_edge_torch/__init__.py,sha256=FPMmuFU3pyMREtjB_san1fy_0PFtAsgA0VZfOYvDrb4,100
|
|
|
2
2
|
ai_edge_torch/model.py,sha256=kmcgELjsYl8YzF8nUF6P7q4i8MWS-pLGpfsy-yTUXmE,4243
|
|
3
3
|
ai_edge_torch/convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
|
4
4
|
ai_edge_torch/convert/conversion.py,sha256=GN2Js232u_5Y118wg3qIfEoYewxbxLl3TpSnO6osi8c,4029
|
|
5
|
-
ai_edge_torch/convert/conversion_utils.py,sha256=
|
|
5
|
+
ai_edge_torch/convert/conversion_utils.py,sha256=9BqCL38DErv1vEVGtT3BIJVhdwZjw2EQ-_m5UpvVVYE,11294
|
|
6
6
|
ai_edge_torch/convert/converter.py,sha256=bjj5TV5_g4sGyuSh8ThEDydlNMqhkGSY4SzXK6vwhqI,6927
|
|
7
7
|
ai_edge_torch/convert/fx_passes/__init__.py,sha256=EPs4PSIDLuRH5EBETi6deaOvaaf_Q4xD3_9NVcR7x8o,2810
|
|
8
8
|
ai_edge_torch/convert/fx_passes/_pass_base.py,sha256=ijVyDclPnd6a0DWWUJkwR4igj6f82S-cE1-83QGPvgw,1652
|
|
@@ -34,16 +34,16 @@ ai_edge_torch/generative/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf
|
|
|
34
34
|
ai_edge_torch/generative/examples/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
|
35
35
|
ai_edge_torch/generative/examples/gemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
|
36
36
|
ai_edge_torch/generative/examples/gemma/convert_to_tflite.py,sha256=dZv3r24uHsTMokEdnl3nf7LpmV0q7FLnVtCuHn5AuUs,2538
|
|
37
|
-
ai_edge_torch/generative/examples/gemma/gemma.py,sha256=
|
|
37
|
+
ai_edge_torch/generative/examples/gemma/gemma.py,sha256=1lZfXGHmbII4rFu0U2B9NzlJCRhphxtmQtkCHQ39_uw,5935
|
|
38
38
|
ai_edge_torch/generative/examples/phi2/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
|
39
39
|
ai_edge_torch/generative/examples/phi2/convert_to_tflite.py,sha256=6nOuwx9q3AUlYcjXRRXSr_3M2JKqdJ-vUf-uE3VFYHE,2512
|
|
40
|
-
ai_edge_torch/generative/examples/phi2/phi2.py,sha256=
|
|
40
|
+
ai_edge_torch/generative/examples/phi2/phi2.py,sha256=PMhKC6JCAMYSj2F3UmWHWK4rTcXD-B6PuehaoDccRqk,5562
|
|
41
41
|
ai_edge_torch/generative/examples/stable_diffusion/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
|
42
42
|
ai_edge_torch/generative/examples/stable_diffusion/attention.py,sha256=Lo4Dq7a3Kg-lyH56iqGtqCo5UaClQHRCTDdNagXGTo8,3535
|
|
43
|
-
ai_edge_torch/generative/examples/stable_diffusion/clip.py,sha256=
|
|
43
|
+
ai_edge_torch/generative/examples/stable_diffusion/clip.py,sha256=qU1wVEcn_biwCuDguZljhlLGzpLIqgqC31Dh_lXquQc,3720
|
|
44
44
|
ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py,sha256=NmgDo5uAefrhMUbYku0TKHlqzO0NVWI_M1ue8tddQR4,4024
|
|
45
|
-
ai_edge_torch/generative/examples/stable_diffusion/decoder.py,sha256=
|
|
46
|
-
ai_edge_torch/generative/examples/stable_diffusion/diffusion.py,sha256=
|
|
45
|
+
ai_edge_torch/generative/examples/stable_diffusion/decoder.py,sha256=meW8t-3BDdjFs5vCAf76cn6lGx49a_GcEvnVa9R5if4,11106
|
|
46
|
+
ai_edge_torch/generative/examples/stable_diffusion/diffusion.py,sha256=_gEeUxa9Xyd3iLb_fyeUefHKuELVDorDlQs8e7wdXKg,7878
|
|
47
47
|
ai_edge_torch/generative/examples/stable_diffusion/encoder.py,sha256=mgbxkeFDMkNIGmnbcFTIFPu8EWKokghiviYIOB2lE3Q,3437
|
|
48
48
|
ai_edge_torch/generative/examples/stable_diffusion/pipeline.py,sha256=FCbnwlkpYYb-tF7KscbSYjNEdg7XnuLju1cDuIRoQv8,8277
|
|
49
49
|
ai_edge_torch/generative/examples/stable_diffusion/tokenizer.py,sha256=r9RqbyNvuvXOGu3ojtl7ZmbC7o4Pt8aUKAhN1yCdtEc,3397
|
|
@@ -55,40 +55,42 @@ ai_edge_torch/generative/examples/stable_diffusion/samplers/k_lms.py,sha256=iPYX
|
|
|
55
55
|
ai_edge_torch/generative/examples/stable_diffusion/samplers/sampler.py,sha256=5iRfU5MO6GR6K3WrdddIU_9U7ZZGEEb7zGKVY1WFl-8,1340
|
|
56
56
|
ai_edge_torch/generative/examples/t5/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
|
57
57
|
ai_edge_torch/generative/examples/t5/convert_to_tflite.py,sha256=bWtwtUacvJOEDUpuYvLTgkP7oTkXKJA-Tf4FPxlD1Cw,4536
|
|
58
|
-
ai_edge_torch/generative/examples/t5/t5.py,sha256=
|
|
58
|
+
ai_edge_torch/generative/examples/t5/t5.py,sha256=L6YrVzUEzP-Imb8W28LdukFGrx1aWSzz1kyYK_9RFZM,21087
|
|
59
59
|
ai_edge_torch/generative/examples/t5/t5_attention.py,sha256=rRgwCEdVtzcJEaGbbBjw8HxCxrCX3pXA5nelawdYiME,9036
|
|
60
60
|
ai_edge_torch/generative/examples/test_models/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
|
61
|
-
ai_edge_torch/generative/examples/test_models/toy_model.py,sha256=
|
|
62
|
-
ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py,sha256=
|
|
63
|
-
ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py,sha256=
|
|
61
|
+
ai_edge_torch/generative/examples/test_models/toy_model.py,sha256=CUXsQ_IU96NaCg9jyfeKI0Zz2iWDkJUsPJyPR1Pgz7I,3813
|
|
62
|
+
ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py,sha256=zwCmCnhr-vhBwHqv9i7xMasdBGVNqAGxZvWsncsJn58,5543
|
|
63
|
+
ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py,sha256=lfYUiem_Pbn3vGgPx84BeI8n7rN3-1fImwCLm8Eo2U8,4853
|
|
64
64
|
ai_edge_torch/generative/examples/tiny_llama/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
65
65
|
ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py,sha256=E4I5OlC4zyl5cxiiu7uTED-zcwYRu210lP1zuT3xLBE,2566
|
|
66
|
-
ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py,sha256=
|
|
66
|
+
ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py,sha256=IFRLPG9wz_aLl_zV_6CETCjSM03ukA6bZqqyDLVACuw,5651
|
|
67
67
|
ai_edge_torch/generative/layers/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
|
68
68
|
ai_edge_torch/generative/layers/attention.py,sha256=Z8gXHYs6h8gaRiYAdvYUbHzg_2EmqfxiChsf_SYraAc,7902
|
|
69
69
|
ai_edge_torch/generative/layers/attention_utils.py,sha256=hXhuyKblPPxKIRzlAf1YNlwHgpbj-6nReRLhRHELx5k,6350
|
|
70
|
-
ai_edge_torch/generative/layers/builder.py,sha256=
|
|
70
|
+
ai_edge_torch/generative/layers/builder.py,sha256=jAyrR5hsSI0aimKZumyvxdJ1GovERIfsK0g-dezX2gs,4163
|
|
71
71
|
ai_edge_torch/generative/layers/feed_forward.py,sha256=4j2QaSCw59Jkk_ixKDpKEj7FLRauzuExTiSNRzAjAhE,2820
|
|
72
72
|
ai_edge_torch/generative/layers/kv_cache.py,sha256=4uiZLO3om5G3--kT04Jt0esEYznbkJ7QLzSHfb8mjc4,3090
|
|
73
|
-
ai_edge_torch/generative/layers/model_config.py,sha256=
|
|
73
|
+
ai_edge_torch/generative/layers/model_config.py,sha256=toWECENDWgay9hsZcy4C89qph0KI3CpaeFqFc8Fr-Xk,4584
|
|
74
74
|
ai_edge_torch/generative/layers/normalization.py,sha256=M27eW3TcNK20oaXClXtfnu0lLWrAGrSKSsbegRWnj3c,1867
|
|
75
75
|
ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=12SsCuoRuLNCwnFGe_pHDOZEBwBcqXs87Aj0PaWWw4E,1383
|
|
76
76
|
ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=dYafGC205QE5CLIbBTCI-7eVvEGZEHzs1toPEhemeDs,3391
|
|
77
77
|
ai_edge_torch/generative/layers/unet/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
|
78
|
-
ai_edge_torch/generative/layers/unet/blocks_2d.py,sha256=
|
|
78
|
+
ai_edge_torch/generative/layers/unet/blocks_2d.py,sha256=7mHyJYq9lq5zVYp4mEz-R8Az3FFngi711YC20KP6ED8,10066
|
|
79
79
|
ai_edge_torch/generative/layers/unet/builder.py,sha256=iH0_nuY9TF2ap5h1JbGNCOonPTfrXQHcF8U0slrIREM,1210
|
|
80
|
-
ai_edge_torch/generative/layers/unet/model_config.py,sha256=
|
|
80
|
+
ai_edge_torch/generative/layers/unet/model_config.py,sha256=sbtbDEHmMV9GLKngwjsNvqm8wovLxnlidkQbXdXkXKs,4060
|
|
81
81
|
ai_edge_torch/generative/quantize/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
|
82
82
|
ai_edge_torch/generative/quantize/example.py,sha256=t-YwyKSPAG-OZC1DfH-0vfie2RHHpTSQjxUY-tmhu5g,1543
|
|
83
|
-
ai_edge_torch/generative/quantize/quant_attrs.py,sha256=
|
|
84
|
-
ai_edge_torch/generative/quantize/quant_recipe.py,sha256=
|
|
85
|
-
ai_edge_torch/generative/quantize/quant_recipe_utils.py,sha256
|
|
86
|
-
ai_edge_torch/generative/quantize/quant_recipes.py,sha256=
|
|
87
|
-
ai_edge_torch/generative/quantize/supported_schemes.py,sha256=
|
|
83
|
+
ai_edge_torch/generative/quantize/quant_attrs.py,sha256=n1Fm8BFC8gJa_oiwwAOOghJyHtOXYZ4q-5ZRy4pHrIw,1957
|
|
84
|
+
ai_edge_torch/generative/quantize/quant_recipe.py,sha256=Y8zahKw7b_h7ajPaJZVef4jG-MoqImRCpVSbFtV_i24,5139
|
|
85
|
+
ai_edge_torch/generative/quantize/quant_recipe_utils.py,sha256=-vd6Qp0BdXJVKg4f0_hhwbKOi3QPIAPVqyXnJ-ZnISQ,1915
|
|
86
|
+
ai_edge_torch/generative/quantize/quant_recipes.py,sha256=9ItD70jQRXMEhWod-nUfEeoWGJUUu6V9YOffF07VU9g,1795
|
|
87
|
+
ai_edge_torch/generative/quantize/supported_schemes.py,sha256=FjdycEOvxRgBmQdZVufetPvkDoD7rUowIOSKV9oV5Kk,1418
|
|
88
|
+
ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
89
|
+
ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/translate_recipe.py,sha256=qUB4f2DoB14dLkNPWf6TZodpT81mfAJeWM-lCAmkuHY,5735
|
|
88
90
|
ai_edge_torch/generative/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
|
89
91
|
ai_edge_torch/generative/test/loader_test.py,sha256=N88CbrLW7Q2x1EyurwdXQ6YjsA-ySQcPxpZH3QOGp-M,3317
|
|
90
92
|
ai_edge_torch/generative/test/test_model_conversion.py,sha256=i_SAW-hD8SaHuopMZI9IuXXDFn5uSTJa1nKZhaC3dAQ,6811
|
|
91
|
-
ai_edge_torch/generative/test/test_quantize.py,sha256=
|
|
93
|
+
ai_edge_torch/generative/test/test_quantize.py,sha256=NVlMixAxVpDUabEvp6zTHHgIDgHFsMRwlf5MuyDwrPg,5355
|
|
92
94
|
ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
|
|
93
95
|
ai_edge_torch/generative/utilities/autoencoder_loader.py,sha256=G2Nosy33JzkjGALPR4JjvffdFX1JWOj2zjbbuaDJEgg,10065
|
|
94
96
|
ai_edge_torch/generative/utilities/loader.py,sha256=Hs92478j1g4jQGvbdP1aWvOy907HjwqQZE-NFy6HELo,11326
|
|
@@ -103,12 +105,12 @@ ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py,sha256=aUAPKnH4_Jxpp
|
|
|
103
105
|
ai_edge_torch/quantize/__init__.py,sha256=aB5dXot04bqyUhpsDFvxt9CIi15QAC4euvqOndJ0XLU,714
|
|
104
106
|
ai_edge_torch/quantize/pt2e_quantizer.py,sha256=ye1f5vAZ0Vr4RWAtfrgU1o3JLs03Sa4inHRq3YxJDGo,15602
|
|
105
107
|
ai_edge_torch/quantize/pt2e_quantizer_utils.py,sha256=yjzKoptnfEeW_sN7sODUfj3nCtUMXVzq3vHKxblsd5Y,36046
|
|
106
|
-
ai_edge_torch/quantize/quant_config.py,sha256=
|
|
108
|
+
ai_edge_torch/quantize/quant_config.py,sha256=eO9Ra160ITjQSyRBEGy6nNIVH3gYacSWDdN5XtvHwjc,3148
|
|
107
109
|
ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
|
108
110
|
ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
|
|
109
111
|
ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=EIyKz-HY70DguWuSrJal8LpYXQ5ZSEUf3ZrVl7jikFM,4286
|
|
110
|
-
ai_edge_torch_nightly-0.2.0.
|
|
111
|
-
ai_edge_torch_nightly-0.2.0.
|
|
112
|
-
ai_edge_torch_nightly-0.2.0.
|
|
113
|
-
ai_edge_torch_nightly-0.2.0.
|
|
114
|
-
ai_edge_torch_nightly-0.2.0.
|
|
112
|
+
ai_edge_torch_nightly-0.2.0.dev20240611.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
|
|
113
|
+
ai_edge_torch_nightly-0.2.0.dev20240611.dist-info/METADATA,sha256=WPGu2pq6N57fBtpunyFhunPe73UK_SVbqlZQsZwjWGo,1748
|
|
114
|
+
ai_edge_torch_nightly-0.2.0.dev20240611.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
|
|
115
|
+
ai_edge_torch_nightly-0.2.0.dev20240611.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
|
|
116
|
+
ai_edge_torch_nightly-0.2.0.dev20240611.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|