ai-edge-torch-nightly 0.2.0.dev20240609__py3-none-any.whl → 0.2.0.dev20240611__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of ai-edge-torch-nightly might be problematic.

Files changed (29)
  1. ai_edge_torch/convert/conversion_utils.py +17 -5
  2. ai_edge_torch/generative/examples/gemma/gemma.py +1 -1
  3. ai_edge_torch/generative/examples/phi2/phi2.py +1 -1
  4. ai_edge_torch/generative/examples/stable_diffusion/clip.py +1 -1
  5. ai_edge_torch/generative/examples/stable_diffusion/decoder.py +4 -4
  6. ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +6 -1
  7. ai_edge_torch/generative/examples/t5/t5.py +1 -1
  8. ai_edge_torch/generative/examples/test_models/toy_model.py +1 -1
  9. ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +1 -1
  10. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +1 -1
  11. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +1 -1
  12. ai_edge_torch/generative/layers/builder.py +31 -9
  13. ai_edge_torch/generative/layers/model_config.py +10 -1
  14. ai_edge_torch/generative/layers/unet/blocks_2d.py +4 -4
  15. ai_edge_torch/generative/layers/unet/model_config.py +4 -4
  16. ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/__init__.py +0 -0
  17. ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/translate_recipe.py +164 -0
  18. ai_edge_torch/generative/quantize/quant_attrs.py +2 -0
  19. ai_edge_torch/generative/quantize/quant_recipe.py +49 -4
  20. ai_edge_torch/generative/quantize/quant_recipe_utils.py +2 -2
  21. ai_edge_torch/generative/quantize/quant_recipes.py +3 -3
  22. ai_edge_torch/generative/quantize/supported_schemes.py +2 -1
  23. ai_edge_torch/generative/test/test_quantize.py +74 -20
  24. ai_edge_torch/quantize/quant_config.py +11 -15
  25. {ai_edge_torch_nightly-0.2.0.dev20240609.dist-info → ai_edge_torch_nightly-0.2.0.dev20240611.dist-info}/METADATA +1 -1
  26. {ai_edge_torch_nightly-0.2.0.dev20240609.dist-info → ai_edge_torch_nightly-0.2.0.dev20240611.dist-info}/RECORD +29 -27
  27. {ai_edge_torch_nightly-0.2.0.dev20240609.dist-info → ai_edge_torch_nightly-0.2.0.dev20240611.dist-info}/LICENSE +0 -0
  28. {ai_edge_torch_nightly-0.2.0.dev20240609.dist-info → ai_edge_torch_nightly-0.2.0.dev20240611.dist-info}/WHEEL +0 -0
  29. {ai_edge_torch_nightly-0.2.0.dev20240609.dist-info → ai_edge_torch_nightly-0.2.0.dev20240611.dist-info}/top_level.txt +0 -0
ai_edge_torch/convert/conversion_utils.py

@@ -24,6 +24,7 @@ from typing import Any, Dict, Optional, Tuple, Union
 import torch
 from torch_xla import stablehlo
 
+from ai_edge_torch.generative.quantize.ai_edge_quantizer_glue import translate_recipe  # NOQA
 from ai_edge_torch.quantize import quant_config as qcfg
 
 try:
@@ -249,11 +250,6 @@ def _set_tfl_converter_quant_flags(
     converter._experimental_qdq_conversion_mode = "DYNAMIC"
   elif quantizer_mode == qcfg.QuantConfig._QuantizerMode.PT2E_STATIC:
     converter._experimental_qdq_conversion_mode = "STATIC"
-  elif quantizer_mode == qcfg.QuantConfig._QuantizerMode.TFLITE_DYNAMIC:
-    converter.optimizations = [tf.lite.Optimize.DEFAULT]
-  elif quantizer_mode == qcfg.QuantConfig._QuantizerMode.TFLITE_FP16:
-    converter.optimizations = [tf.lite.Optimize.DEFAULT]
-    converter.target_spec.supported_types = [tf.float16]
 
 
 def convert_stablehlo_to_tflite(
@@ -323,8 +319,24 @@ def convert_stablehlo_to_tflite(
   converter._experimental_enable_composite_direct_lowering = True
 
   _set_tfl_converter_quant_flags(converter, quant_config)
+  if (
+      quant_config is not None
+      and quant_config._quantizer_mode
+      == quant_config._QuantizerMode.AI_EDGE_QUANTIZER
+  ):
+    translated_recipe = translate_recipe.translate_to_ai_edge_recipe(
+        quant_config.generative_recipe
+    )
+
   _apply_tfl_backdoor_flags(converter, _tfl_converter_flags)
 
   tflite_model = converter.convert()
 
+  if (
+      quant_config is not None
+      and quant_config._quantizer_mode
+      == quant_config._QuantizerMode.AI_EDGE_QUANTIZER
+  ):
+    tflite_model = translate_recipe.quantize_model(tflite_model, translated_recipe)
+
   return tflite_model
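For orientation, a minimal sketch of how this new path is exercised from user code, assuming the public `ai_edge_torch.convert` entry point and the `quant_recipes` helpers updated later in this diff (`MyToyModel` is a hypothetical stand-in):

```
import torch
import ai_edge_torch
from ai_edge_torch.generative.quantize import quant_recipes

# MyToyModel is a placeholder for any Edge Generative API model.
model = MyToyModel().eval()
sample_inputs = (torch.tensor([[1]], dtype=torch.long),)

# The recipe inside quant_config is translated to an AI Edge Quantizer
# recipe before conversion, and the converted TFLite flatbuffer is then
# post-processed by translate_recipe.quantize_model.
edge_model = ai_edge_torch.convert(
    model, sample_inputs, quant_config=quant_recipes.full_fp16_recipe()
)
```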
ai_edge_torch/generative/examples/gemma/gemma.py

@@ -116,7 +116,7 @@ def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
   )
   ff_config = cfg.FeedForwardConfig(
       type=cfg.FeedForwardType.GATED,
-      activation=cfg.ActivationType.GELU_TANH,
+      activation=cfg.ActivationConfig(cfg.ActivationType.GELU_TANH),
       intermediate_size=16384,
   )
   norm_config = cfg.NormalizationConfig(
ai_edge_torch/generative/examples/phi2/phi2.py

@@ -112,7 +112,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
   )
   ff_config = cfg.FeedForwardConfig(
       type=cfg.FeedForwardType.SEQUENTIAL,
-      activation=cfg.ActivationType.GELU_TANH,
+      activation=cfg.ActivationConfig(cfg.ActivationType.GELU_TANH),
       intermediate_size=10240,
       use_bias=True,
   )
ai_edge_torch/generative/examples/stable_diffusion/clip.py

@@ -90,7 +90,7 @@ def get_model_config() -> cfg.ModelConfig:
 
   ff_config = cfg.FeedForwardConfig(
       type=cfg.FeedForwardType.SEQUENTIAL,
-      activation=cfg.ActivationType.GELU_QUICK,
+      activation=cfg.ActivationConfig(cfg.ActivationType.GELU_QUICK),
       intermediate_size=embedding_dim * 4,
       use_bias=True,
   )
ai_edge_torch/generative/examples/stable_diffusion/decoder.py

@@ -221,7 +221,7 @@ class Decoder(nn.Module):
             in_channels=prev_output_channel,
             out_channels=block_out_channels,
             normalization_config=config.normalization_config,
-            activation_type=config.activation_type,
+            activation_config=config.activation_config,
             num_layers=config.layers_per_block,
             add_upsample=not_final_block,
             upsample_conv=True,
@@ -235,7 +235,7 @@ class Decoder(nn.Module):
     self.final_norm = layers_builder.build_norm(
         block_out_channels, config.normalization_config
     )
-    self.act_fn = layers_builder.get_activation(config.activation_type)
+    self.act_fn = layers_builder.get_activation(config.activation_config)
     self.conv_out = nn.Conv2d(
         block_out_channels,
         config.out_channels,
@@ -287,7 +287,7 @@ def get_model_config() -> unet_cfg.AutoEncoderConfig:
   mid_block_config = unet_cfg.MidBlock2DConfig(
       in_channels=block_out_channels[-1],
       normalization_config=norm_config,
-      activation_type=layers_cfg.ActivationType.SILU,
+      activation_config=layers_cfg.ActivationConfig(layers_cfg.ActivationType.SILU),
      num_layers=1,
       attention_block_config=att_config,
   )
@@ -296,7 +296,7 @@ def get_model_config() -> unet_cfg.AutoEncoderConfig:
       in_channels=in_channels,
       latent_channels=latent_channels,
       out_channels=out_channels,
-      activation_type=layers_cfg.ActivationType.SILU,
+      activation_config=layers_cfg.ActivationConfig(layers_cfg.ActivationType.SILU),
       block_out_channels=block_out_channels,
       scaling_factor=scaling_factor,
       layers_per_block=layers_per_block,
ai_edge_torch/generative/examples/stable_diffusion/diffusion.py

@@ -130,7 +130,7 @@ class Upsample(nn.Module):
     self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
 
   def forward(self, x):
-    x = F.interpolate(x, scale_factor=2, mode='nearest')
+    x = F.interpolate(x, scale_factor=2, mode="nearest")
     return self.conv(x)
 
 
@@ -237,3 +237,8 @@ class Diffusion(nn.Module):
     output = self.unet(latent, context, time)
     output = self.final(output)
     return output
+
+
+if __name__ == "__main__":
+  diffusion = Diffusion()
+  print(diffusion.state_dict().keys())
ai_edge_torch/generative/examples/t5/t5.py

@@ -349,7 +349,7 @@ def get_model_config_t5() -> cfg.ModelConfig:
   )
   ff_config = cfg.FeedForwardConfig(
       type=cfg.FeedForwardType.SEQUENTIAL,
-      activation=cfg.ActivationType.RELU,
+      activation=cfg.ActivationConfig(cfg.ActivationType.RELU),
       intermediate_size=3072,
   )
   # T5 Confirmed as RMS Norm and eps = 1e-6 TJA.
ai_edge_torch/generative/examples/test_models/toy_model.py

@@ -76,7 +76,7 @@ def define_and_run() -> None:
   )
   ff_config = cfg.FeedForwardConfig(
       type=cfg.FeedForwardType.GATED,
-      activation=cfg.ActivationType.SILU,
+      activation=cfg.ActivationConfig(cfg.ActivationType.SILU),
       intermediate_size=256,
   )
   norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.RMS_NORM)
ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py

@@ -95,7 +95,7 @@ def get_model_config() -> cfg.ModelConfig:
   )
   ff_config = cfg.FeedForwardConfig(
       type=cfg.FeedForwardType.GATED,
-      activation=cfg.ActivationType.SILU,
+      activation=cfg.ActivationConfig(cfg.ActivationType.SILU),
       intermediate_size=256,
   )
   norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.RMS_NORM)
ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py

@@ -83,7 +83,7 @@ def get_model_config() -> cfg.ModelConfig:
   )
   ff_config = cfg.FeedForwardConfig(
       type=cfg.FeedForwardType.GATED,
-      activation=cfg.ActivationType.SILU,
+      activation=cfg.ActivationConfig(cfg.ActivationType.SILU),
       intermediate_size=256,
   )
   norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.RMS_NORM)
ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py

@@ -112,7 +112,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
   )
   ff_config = cfg.FeedForwardConfig(
       type=cfg.FeedForwardType.GATED,
-      activation=cfg.ActivationType.SILU,
+      activation=cfg.ActivationConfig(cfg.ActivationType.SILU),
       intermediate_size=5632,
   )
   norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.RMS_NORM)
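All of the example edits above apply the same mechanical migration; a before/after sketch (field values are illustrative):

```
import ai_edge_torch.generative.layers.model_config as cfg

# Before: the feed-forward activation was a bare enum value.
#   activation=cfg.ActivationType.SILU
# After: it is wrapped in an ActivationConfig (added to model_config.py in
# this diff), which leaves room for parameterized activations such as GE_GLU.
ff_config = cfg.FeedForwardConfig(
    type=cfg.FeedForwardType.GATED,
    activation=cfg.ActivationConfig(cfg.ActivationType.SILU),
    intermediate_size=256,
)
```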
ai_edge_torch/generative/layers/builder.py

@@ -13,6 +13,7 @@
 # limitations under the License.
 # ==============================================================================
 # Builder class for individual components.
+import torch
 from torch import nn
 import torch.nn.functional as F
 
@@ -21,6 +22,23 @@ import ai_edge_torch.generative.layers.model_config as cfg
 import ai_edge_torch.generative.layers.normalization as normalization
 
 
+class GeGLU(nn.Module):
+  """GeGLU is an activation function which is a variant of GELU.
+
+  GeGLU(x) = (xW+b) * GELU(xV+c)
+  See: https://arxiv.org/abs/2002.05202v1
+
+  """
+
+  def __init__(self, d_in: int, d_out: int):
+    super().__init__()
+    self.proj = nn.Linear(d_in, d_out * 2)
+
+  def forward(self, x: torch.Tensor):
+    x, gate = self.proj(x).chunk(2, dim=-1)
+    return x * F.gelu(gate)
+
+
 def build_norm(dim: int, config: cfg.NormalizationConfig):
   """Builder function for normalizers.
 
@@ -81,29 +99,33 @@ def build_ff(dim: int, config: cfg.FeedForwardConfig):
   )
 
 
-def get_activation(type_: cfg.ActivationType):
-  """Get pytorch callable activation from the name.
+def get_activation(config: cfg.ActivationConfig):
+  """Get pytorch callable activation from the activation config.
 
   Args:
-    name (string): activation's name.
+    config (cfg.ActivationConfig): activation config.
 
   Returns:
     Activation function.
 
   Raises:
-    ValueError: If activation name is not supported.
+    ValueError: If activation config is not supported.
   """
-  if type_ == cfg.ActivationType.SILU:
+  if config.type == cfg.ActivationType.LINEAR:
+    return lambda x: x
+  elif config.type == cfg.ActivationType.SILU:
     return F.silu
-  elif type_ == cfg.ActivationType.GELU:
+  elif config.type == cfg.ActivationType.GELU:
     return F.gelu
-  elif type_ == cfg.ActivationType.GELU_TANH:
+  elif config.type == cfg.ActivationType.GELU_TANH:
     return lambda x: F.gelu(x, approximate="tanh")
-  elif type_ == cfg.ActivationType.GELU_QUICK:
+  elif config.type == cfg.ActivationType.GELU_QUICK:
     # GELU approximation that is fast but somewhat inaccurate.
    # See: https://github.com/hendrycks/GELUs
     return lambda x: x * F.sigmoid(1.702 * x)
-  elif type_ == cfg.ActivationType.RELU:
+  elif config.type == cfg.ActivationType.GE_GLU:
    return GeGLU(config.dim_in, config.dim_out)
+  elif config.type == cfg.ActivationType.RELU:
     return F.relu
   else:
     raise ValueError("Unsupported activation type.")
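Unlike the other branches, the GE_GLU case returns a module with trainable weights rather than a stateless function, so it needs the projection dimensions from the config. A small usage sketch (tensor shapes are illustrative; the module path follows the package layout in the RECORD below):

```
import torch
import ai_edge_torch.generative.layers.builder as layers_builder
import ai_edge_torch.generative.layers.model_config as cfg

act = layers_builder.get_activation(
    cfg.ActivationConfig(cfg.ActivationType.GE_GLU, dim_in=64, dim_out=128)
)
x = torch.randn(2, 16, 64)
# nn.Linear projects to 2 * dim_out; chunking and gating halve it back.
y = act(x)
print(y.shape)  # torch.Size([2, 16, 128])
```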
ai_edge_torch/generative/layers/model_config.py

@@ -28,6 +28,7 @@ class ActivationType(enum.Enum):
   GELU = enum.auto()
   GELU_TANH = enum.auto()
   GELU_QUICK = enum.auto()
+  GE_GLU = enum.auto()
   RELU = enum.auto()
 
 
@@ -74,12 +75,20 @@ class AttentionConfig:
   relative_attention_max_distance: int = 0
 
 
+@dataclass
+class ActivationConfig:
+  type: ActivationType = ActivationType.LINEAR
+  # Dimension of input and output, used in GeGLU.
+  dim_in: Optional[int] = None
+  dim_out: Optional[int] = None
+
+
 @dataclass
 class FeedForwardConfig:
   """FeedForward module's parameters."""
 
   type: FeedForwardType
-  activation: ActivationType
+  activation: ActivationConfig
   intermediate_size: int
   use_bias: bool = False
 
ai_edge_torch/generative/layers/unet/blocks_2d.py

@@ -53,7 +53,7 @@ class ResidualBlock2D(nn.Module):
     self.conv_2 = nn.Conv2d(
         config.out_channels, config.out_channels, kernel_size=3, stride=1, padding=1
     )
-    self.act_fn = layers_builder.get_activation(config.activation_type)
+    self.act_fn = layers_builder.get_activation(config.activation_config)
     if config.in_channels == config.out_channels:
       self.residual_layer = nn.Identity()
     else:
@@ -167,7 +167,7 @@ class UpDecoderBlock2D(nn.Module):
             out_channels=config.out_channels,
             time_embedding_channels=config.time_embedding_channels,
             normalization_config=config.normalization_config,
-            activation_type=config.activation_type,
+            activation_config=config.activation_config,
           )
         )
       )
@@ -244,7 +244,7 @@ class MidBlock2D(nn.Module):
            out_channels=config.in_channels,
             time_embedding_channels=config.time_embedding_channels,
             normalization_config=config.normalization_config,
-            activation_type=config.activation_type,
+            activation_config=config.activation_config,
           )
         )
     ]
@@ -259,7 +259,7 @@ class MidBlock2D(nn.Module):
             out_channels=config.in_channels,
             time_embedding_channels=config.time_embedding_channels,
             normalization_config=config.normalization_config,
-            activation_type=config.activation_type,
+            activation_config=config.activation_config,
           )
         )
       )
ai_edge_torch/generative/layers/unet/model_config.py

@@ -39,7 +39,7 @@ class ResidualBlock2DConfig:
   in_channels: int
   out_channels: int
   normalization_config: layers_cfg.NormalizationConfig
-  activation_type: layers_cfg.ActivationType
+  activation_config: layers_cfg.ActivationConfig
   # Optional time embedding channels if the residual block takes a time embedding context as input
   time_embedding_channels: Optional[int] = None
 
@@ -56,7 +56,7 @@ class UpDecoderBlock2DConfig:
   in_channels: int
   out_channels: int
   normalization_config: layers_cfg.NormalizationConfig
-  activation_type: layers_cfg.ActivationType
+  activation_config: layers_cfg.ActivationConfig
   num_layers: int
   # Optional time embedding channels if the residual blocks take a time embedding context as input
   time_embedding_channels: Optional[int] = None
@@ -72,7 +72,7 @@ class UpDecoderBlock2DConfig:
 class MidBlock2DConfig:
   in_channels: int
   normalization_config: layers_cfg.NormalizationConfig
-  activation_type: layers_cfg.ActivationType
+  activation_config: layers_cfg.ActivationConfig
   num_layers: int
   # Optional time embedding channels if the residual blocks take a time embedding context as input
   time_embedding_channels: Optional[int] = None
@@ -85,7 +85,7 @@ class AutoEncoderConfig:
   """Configurations of encoder/decoder in the autoencoder model."""
 
   # The activation type of encoder/decoder blocks.
-  activation_type: layers_cfg.ActivationType
+  activation_config: layers_cfg.ActivationConfig
 
   # The output channels of each block.
   block_out_channels: List[int]
ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/translate_recipe.py

@@ -0,0 +1,164 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+import json
+
+from ai_edge_quantizer import quantizer
+
+from ai_edge_torch.generative.quantize import quant_attrs
+from ai_edge_torch.generative.quantize import quant_recipe
+
+_OpExecutionMode = quantizer.qtyping.OpExecutionMode
+_OpName = quantizer.qtyping.TFLOperationName
+_TensorQuantConfig = quantizer.qtyping.TensorQuantizationConfig
+_OpQuantConfig = quantizer.qtyping.OpQuantizationConfig
+
+_DEFAULT_REGEX_STR = '.*'
+_ATTENTION_IDX_REGEX_STR = (
+    'transformer_blocks\[{}\]/ai_edge_torch.generative.layers.attention'
+)
+_FEEDFORWARD_IDX_REGEX_STR = (
+    'transformer_blocks\[{}\]/ai_edge_torch.generative.layers.feed_forward'
+)
+_EMBEDDING_REGEX_STR = 'Embedding_tok_embedding'
+_ANY_TWO_DIGITS_REGEX_STR = '\d{1,2}'
+
+
+def _get_nbits_from_dtype(dtype: quant_attrs.Dtype) -> int:
+  if dtype == quant_attrs.Dtype.FP32:
+    return 32
+  elif dtype == quant_attrs.Dtype.FP16:
+    return 16
+  elif dtype == quant_attrs.Dtype.INT8:
+    return 8
+  raise ValueError('Unimplemented number of bits')
+
+
+def _get_dtype_from_dtype(dtype: quant_attrs.Dtype) -> quantizer.qtyping.TensorDataType:
+  if dtype == quant_attrs.Dtype.FP32 or dtype == quant_attrs.Dtype.FP16:
+    return quantizer.qtyping.TensorDataType.FLOAT
+  else:
+    return quantizer.qtyping.TensorDataType.INT
+
+
+def _get_execution_mode_from_mode(mode: quant_attrs.Mode) -> _OpExecutionMode:
+  if mode == quant_attrs.Mode.DYNAMIC_RANGE:
+    return _OpExecutionMode.DRQ
+  elif mode == quant_attrs.Mode.WEIGHT_ONLY:
+    return _OpExecutionMode.WEIGHT_ONLY
+  raise ValueError('Unimplemented execution mode')
+
+
+def _get_channelwise_from_granularity(granularity: quant_attrs.Granularity) -> bool:
+  if granularity == quant_attrs.Granularity.CHANNELWISE:
+    return True
+  elif granularity == quant_attrs.Granularity.NONE:
+    return False
+  raise ValueError('Unimplemented granularity')
+
+
+def _get_algorithm_key_from_algorithm(algo: quant_attrs.Algorithm) -> str:
+  if algo == quant_attrs.Algorithm.MIN_MAX:
+    return quantizer.algorithm_manager.AlgorithmName.MIN_MAX_UNIFORM_QUANT
+  elif algo == quant_attrs.Algorithm.FLOAT_CAST:
+    return quantizer.algorithm_manager.AlgorithmName.FLOAT_CASTING
+  raise ValueError('Unimplemented algorithm')
+
+
+def _set_quant_config(
+    rm: quantizer.recipe_manager.RecipeManager,
+    layer_recipe: quant_recipe.LayerQuantRecipe,
+    regex: str,
+):
+  support_op_list = [_OpName.FULLY_CONNECTED, _OpName.CONV_2D]
+  if layer_recipe.algorithm == quant_attrs.Algorithm.MIN_MAX:
+    support_op_list += [_OpName.BATCH_MATMUL, _OpName.EMBEDDING_LOOKUP]
+  for op_name in support_op_list:
+    rm.add_quantization_config(
+        regex=regex,
+        operation_name=op_name,
+        op_config=_OpQuantConfig(
+            weight_tensor_config=_TensorQuantConfig(
+                num_bits=_get_nbits_from_dtype(layer_recipe.weight_dtype),
+                symmetric=True,
+                channel_wise=_get_channelwise_from_granularity(
+                    layer_recipe.granularity
+                ),
+                dtype=_get_dtype_from_dtype(layer_recipe.weight_dtype),
+            ),
+            execution_mode=_get_execution_mode_from_mode(layer_recipe.mode),
+        ),
+        algorithm_key=_get_algorithm_key_from_algorithm(layer_recipe.algorithm),
+        override_algorithm=True,
+    )
+
+
+def translate_to_ai_edge_recipe(
+    recipe: quant_recipe.GenerativeQuantRecipe,
+) -> quantizer.recipe_manager.ModelQuantizationRecipe:
+  rm = quantizer.recipe_manager.RecipeManager()
+
+  if recipe.default is not None:
+    _set_quant_config(rm, recipe.default, _DEFAULT_REGEX_STR)
+
+  if recipe.embedding is not None:
+    _set_quant_config(rm, recipe.embedding, _EMBEDDING_REGEX_STR)
+
+  if recipe.attention is not None:
+    if isinstance(recipe.attention, dict):
+      for idx, layer in recipe.attention.items():
+        _set_quant_config(rm, layer, _ATTENTION_IDX_REGEX_STR.format(idx))
+    else:
+      _set_quant_config(
+          rm,
+          recipe.attention,
+          _ATTENTION_IDX_REGEX_STR.format(_ANY_TWO_DIGITS_REGEX_STR),
+      )
+
+  if recipe.feedforward is not None:
+    if isinstance(recipe.feedforward, dict):
+      for idx, layer in recipe.feedforward.items():
+        _set_quant_config(rm, layer, _FEEDFORWARD_IDX_REGEX_STR.format(idx))
+    else:
+      _set_quant_config(
+          rm,
+          recipe.feedforward,
+          _FEEDFORWARD_IDX_REGEX_STR.format(_ANY_TWO_DIGITS_REGEX_STR),
+      )
+
+  return rm.get_quantization_recipe()
+
+
+def quantize_model(
+    model: bytearray, recipe: quantizer.recipe_manager.ModelQuantizationRecipe
+) -> bytearray:
+  # TODO(b/336599483): Remove tempfile and use bytearray instead
+  tmp_model_path = '/tmp/tmp.tflite'
+  tmp_recipe_path = '/tmp/recipe.json'
+  with open(tmp_model_path, 'wb') as fp:
+    fp.write(model)
+  with open(tmp_recipe_path, 'w') as rp:
+    rp.write(json.dumps(recipe))
+
+  qt = quantizer.Quantizer(tmp_model_path, tmp_recipe_path)
+  result = qt.quantize()
+
+  # TODO(b/336599483): Remove tempfile and use bytearray instead
+  import os
+
+  os.remove(tmp_model_path)
+  os.remove(tmp_recipe_path)
+
+  return result.quantized_model
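Taken together, the new module is a thin bridge between the two recipe formats. A sketch of the intended call pattern, mirroring the two call sites added to conversion_utils.py above (`quant_config` and `tflite_model` are assumed to come from the surrounding conversion code):

```
from ai_edge_torch.generative.quantize.ai_edge_quantizer_glue import translate_recipe

# 1. Map the Edge Generative API recipe onto AI Edge Quantizer op configs.
ai_edge_recipe = translate_recipe.translate_to_ai_edge_recipe(
    quant_config.generative_recipe
)
# 2. Rewrite the already-converted TFLite flatbuffer with the quantizer.
tflite_model = translate_recipe.quantize_model(tflite_model, ai_edge_recipe)
```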
ai_edge_torch/generative/quantize/quant_attrs.py

@@ -32,9 +32,11 @@ class Algorithm(enum.Enum):
   Attributes:
     MIN_MAX: Maps the min/max of floating point space to the min/max of
       quantized space and quantize uniformly.
+    FLOAT_CAST: Casts a float to another float of a different type.
   """
 
   MIN_MAX = enum.auto()
+  FLOAT_CAST = enum.auto()
 
 
 @enum.unique
ai_edge_torch/generative/quantize/quant_recipe.py

@@ -14,8 +14,7 @@
 # ==============================================================================
 
 from dataclasses import dataclass
-import enum
-from typing import Optional
+from typing import Optional, Union
 
 from ai_edge_torch.generative.quantize import quant_attrs
 from ai_edge_torch.generative.quantize import supported_schemes
@@ -80,18 +79,50 @@ class LayerQuantRecipe:
 
 
 @dataclass
-class TransformerQuantRecipe:
+class GenerativeQuantRecipe:
   """Quantization recipe for a model composed of the Edge Generative API layers.
 
+  Some layers can be specified with different `LayerQuantRecipe` for each block by
+  providing a dictionary keyed by the TransformerBlock index, e.g. attention
+  and feedforward. For example,
+
+  ```
+  default = LayerQuantRecipeA
+  attention = { 2: LayerQuantRecipeB }
+  feedforward = { 3: LayerQuantRecipeC }
+  ```
+
+  will apply LayerQuantRecipeA to the entire model, overridden by
+  LayerQuantRecipeB for the TransformerBlock[2].attention layer and
+  LayerQuantRecipeC for the TransformerBlock[3].feedforward layer. Any config
+  with invalid indices will be ignored.
+
   Attributes:
     default: The quantization recipe for global scope of the model.
+    embedding: Recipe for the embedding table.
+    attention: Recipe for the attention blocks. This could be specified with
+      different LayerQuantRecipe for each block by providing a dictionary
+      keyed by the TransformerBlock index.
+    feedforward: Recipe for the feedforward layers. This could be specified with
+      different LayerQuantRecipe for each block by providing a dictionary
+      keyed by the TransformerBlock index.
   """
 
   default: Optional[LayerQuantRecipe] = None
+  embedding: Optional[LayerQuantRecipe] = None
+  attention: Union[
+      Optional[LayerQuantRecipe], Optional[dict[int, LayerQuantRecipe]]
+  ] = None
+  feedforward: Union[
+      Optional[LayerQuantRecipe], Optional[dict[int, LayerQuantRecipe]]
+  ] = None
 
   def __str__(self):
-    return f"""TransformerQuantRecipe(
+    return f"""GenerativeQuantRecipe(
     Default: {self.default}
+    Embedding: {self.embedding}
+    Attention: {self.attention}
+    Feedforward: {self.feedforward}
   )"""
 
   __repr__ = __str__
@@ -104,3 +135,17 @@ class TransformerQuantRecipe:
     """
     if self.default is not None:
       self.default.verify()
+    if self.embedding is not None:
+      self.embedding.verify()
+    if self.attention is not None:
+      if isinstance(self.attention, dict):
+        for recipe in self.attention.values():
+          recipe.verify()
+      else:
+        self.attention.verify()
+    if self.feedforward is not None:
+      if isinstance(self.feedforward, dict):
+        for recipe in self.feedforward.values():
+          recipe.verify()
+      else:
+        self.feedforward.verify()
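Following the docstring above, a sketch of a recipe that combines the new fields, using the helper constructors from quant_recipe_utils (the block index is illustrative):

```
from ai_edge_torch.generative.quantize import quant_recipe, quant_recipe_utils

recipe = quant_recipe.GenerativeQuantRecipe(
    default=quant_recipe_utils.create_layer_quant_fp16(),
    embedding=quant_recipe_utils.create_layer_quant_int8_dynamic(),
    # Override only TransformerBlock[2]'s attention; other blocks keep `default`.
    attention={2: quant_recipe_utils.create_layer_quant_int8_dynamic()},
)
recipe.verify()  # Raises ValueError if any LayerQuantRecipe is unsupported.
```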
ai_edge_torch/generative/quantize/quant_recipe_utils.py

@@ -22,7 +22,7 @@ Typical usage example:
 
     1. Applying a single layer recipe to the entire model
 
-    quant_recipe.TransformerQuantRecipe(
+    quant_recipe.GenerativeQuantRecipe(
         default=quant_recipe_utils.create_layer_quant_int8_dynamic()
     )
 """
@@ -46,6 +46,6 @@ def create_layer_quant_fp16() -> quant_recipe.LayerQuantRecipe:
       activation_dtype=quant_attrs.Dtype.FP32,
       weight_dtype=quant_attrs.Dtype.FP16,
       mode=quant_attrs.Mode.WEIGHT_ONLY,
-      algorithm=quant_attrs.Algorithm.MIN_MAX,
+      algorithm=quant_attrs.Algorithm.FLOAT_CAST,
       granularity=quant_attrs.Granularity.NONE,
   )
ai_edge_torch/generative/quantize/quant_recipes.py

@@ -34,15 +34,15 @@ from ai_edge_torch.quantize import quant_config
 
 def full_linear_int8_dynamic_recipe() -> quant_config.QuantConfig:
   return quant_config.QuantConfig(
-      transformer_recipe=quant_recipe.TransformerQuantRecipe(
-          default=quant_recipe_utils.create_layer_quant_int8_dynamic()
+      generative_recipe=quant_recipe.GenerativeQuantRecipe(
+          default=quant_recipe_utils.create_layer_quant_int8_dynamic(),
       )
   )
 
 
 def full_fp16_recipe() -> quant_config.QuantConfig:
   return quant_config.QuantConfig(
-      transformer_recipe=quant_recipe.TransformerQuantRecipe(
+      generative_recipe=quant_recipe.GenerativeQuantRecipe(
          default=quant_recipe_utils.create_layer_quant_fp16()
       )
   )
ai_edge_torch/generative/quantize/supported_schemes.py

@@ -27,5 +27,6 @@ def get_supported_layer_schemes():
 
   return [
       (_t.FP32, _t.INT8, _m.DYNAMIC_RANGE, _a.MIN_MAX, _g.CHANNELWISE),
-      (_t.FP32, _t.FP16, _m.WEIGHT_ONLY, _a.MIN_MAX, _g.NONE),
+      (_t.FP32, _t.INT8, _m.WEIGHT_ONLY, _a.MIN_MAX, _g.CHANNELWISE),
+      (_t.FP32, _t.FP16, _m.WEIGHT_ONLY, _a.FLOAT_CAST, _g.NONE),
   ]
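Each tuple is (activation_dtype, weight_dtype, mode, algorithm, granularity) and maps one-to-one onto the LayerQuantRecipe fields; for example, the new second row describes int8 weight-only channelwise quantization, which can be spelled out as:

```
from ai_edge_torch.generative.quantize import quant_attrs, quant_recipe

# Mirrors (_t.FP32, _t.INT8, _m.WEIGHT_ONLY, _a.MIN_MAX, _g.CHANNELWISE).
layer = quant_recipe.LayerQuantRecipe(
    activation_dtype=quant_attrs.Dtype.FP32,
    weight_dtype=quant_attrs.Dtype.INT8,
    mode=quant_attrs.Mode.WEIGHT_ONLY,
    algorithm=quant_attrs.Algorithm.MIN_MAX,
    granularity=quant_attrs.Granularity.CHANNELWISE,
)
layer.verify()  # Passes, since this scheme is now in the supported list.
```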
ai_edge_torch/generative/test/test_quantize.py

@@ -21,11 +21,13 @@ import torch
 
 import ai_edge_torch
 from ai_edge_torch.generative.examples.test_models import toy_model_with_kv_cache  # NOQA
 from ai_edge_torch.generative.quantize import quant_recipe
+from ai_edge_torch.generative.quantize import quant_recipe_utils
 from ai_edge_torch.generative.quantize import quant_recipes
 from ai_edge_torch.generative.quantize.quant_attrs import Algorithm
 from ai_edge_torch.generative.quantize.quant_attrs import Dtype
 from ai_edge_torch.generative.quantize.quant_attrs import Granularity
 from ai_edge_torch.generative.quantize.quant_attrs import Mode
+from ai_edge_torch.quantize import quant_config
 from ai_edge_torch.testing import model_coverage
 
 
@@ -34,34 +36,47 @@ class TestVerifyRecipes(unittest.TestCase):
 
   @parameterized.expand(
       [
-          (Dtype.FP32, Dtype.FP32, Mode.DYNAMIC_RANGE),
-          (Dtype.INT8, Dtype.INT8, Mode.DYNAMIC_RANGE),
-          (Dtype.INT8, Dtype.FP16, Mode.DYNAMIC_RANGE),
-          (Dtype.FP16, Dtype.INT8, Mode.DYNAMIC_RANGE),
-          (Dtype.FP32, Dtype.FP32, Mode.WEIGHT_ONLY),
-          (Dtype.INT8, Dtype.INT8, Mode.WEIGHT_ONLY),
-          (Dtype.FP16, Dtype.INT8, Mode.WEIGHT_ONLY),
-          (Dtype.INT8, Dtype.FP16, Mode.WEIGHT_ONLY),
-          (Dtype.FP16, Dtype.FP16, Mode.WEIGHT_ONLY),
+          (Dtype.FP32, Dtype.FP32),
+          (Dtype.INT8, Dtype.INT8),
+          (Dtype.INT8, Dtype.FP16),
+          (Dtype.FP16, Dtype.INT8),
+          (Dtype.FP16, Dtype.FP16),
       ]
   )
   def test_verify_invalid_recipes(
       self,
       activation,
       weight,
-      mode,
-      algo=Algorithm.MIN_MAX,
-      granularity=Granularity.CHANNELWISE,
   ):
-    with self.assertRaises(ValueError):
-      quant_recipe.LayerQuantRecipe(
-          activation, weight, mode, algo, granularity
-      ).verify()
+    for m in Mode:
+      for a in Algorithm:
+        for g in Granularity:
+          with self.assertRaises(ValueError):
+            quant_recipe.LayerQuantRecipe(activation, weight, m, a, g).verify()
 
   @parameterized.expand(
       [
-          (Dtype.FP32, Dtype.INT8, Mode.DYNAMIC_RANGE, Granularity.CHANNELWISE),
-          (Dtype.FP32, Dtype.FP16, Mode.WEIGHT_ONLY, Granularity.NONE),
+          (
+              Dtype.FP32,
+              Dtype.INT8,
+              Mode.DYNAMIC_RANGE,
+              Algorithm.MIN_MAX,
+              Granularity.CHANNELWISE,
+          ),
+          (
+              Dtype.FP32,
+              Dtype.INT8,
+              Mode.WEIGHT_ONLY,
+              Algorithm.MIN_MAX,
+              Granularity.CHANNELWISE,
+          ),
+          (
+              Dtype.FP32,
+              Dtype.FP16,
+              Mode.WEIGHT_ONLY,
+              Algorithm.FLOAT_CAST,
+              Granularity.NONE,
+          ),
       ]
   )
   def test_verify_valid_recipes(
@@ -69,8 +84,8 @@ class TestVerifyRecipes(unittest.TestCase):
       activation,
       weight,
       mode,
+      algo,
       granularity,
-      algo=Algorithm.MIN_MAX,
   ):
     quant_recipe.LayerQuantRecipe(activation, weight, mode, algo, granularity).verify()
 
@@ -78,7 +93,46 @@ class TestVerifyRecipes(unittest.TestCase):
 class TestQuantizeConvert(unittest.TestCase):
   """Test conversion with quantization."""
 
-  def test_quantize_convert_toy(self):
+  def _attention_1_int8_dynamic_recipe() -> quant_config.QuantConfig:
+    return quant_config.QuantConfig(
+        generative_recipe=quant_recipe.GenerativeQuantRecipe(
+            attention={1: quant_recipe_utils.create_layer_quant_int8_dynamic()},
+        )
+    )
+
+  def _feedforward_0_int8_dynamic_recipe() -> quant_config.QuantConfig:
+    return quant_config.QuantConfig(
+        generative_recipe=quant_recipe.GenerativeQuantRecipe(
+            feedforward={0: quant_recipe_utils.create_layer_quant_int8_dynamic()},
+        )
+    )
+
+  @parameterized.expand(
+      [
+          (quant_recipes.full_fp16_recipe(), 0.75),
+          (quant_recipes.full_linear_int8_dynamic_recipe(), 0.64),
+          (_attention_1_int8_dynamic_recipe(), 0.95),
+          (_feedforward_0_int8_dynamic_recipe(), 0.87),
+      ]
+  )
+  def test_quantize_convert_toy_sizes(self, quant_config, expected_compression):
+    config = toy_model_with_kv_cache.get_model_config()
+    pytorch_model = toy_model_with_kv_cache.ToyModelWithKV(config)
+    idx, input_pos = torch.tensor([[1]], dtype=torch.long), torch.tensor(
+        [10], dtype=torch.int64
+    )
+
+    quantized_model = ai_edge_torch.convert(
+        pytorch_model, (idx, input_pos), quant_config=quant_config
+    )
+    float_model = ai_edge_torch.convert(pytorch_model, (idx, input_pos))
+    self.assertAlmostEqual(
+        len(quantized_model._tflite_model) / len(float_model._tflite_model),
+        expected_compression,
+        delta=0.01,
+    )
+
+  def test_quantize_convert_compare_toy(self):
     self.skipTest("b/338288901")
     config = toy_model_with_kv_cache.get_model_config()
     pytorch_model = toy_model_with_kv_cache.ToyModelWithKV(config)
ai_edge_torch/quantize/quant_config.py

@@ -32,27 +32,26 @@ class QuantConfig:
     pt2e_quantizer: The instance of PT2EQuantizer used to quantize the model
       with PT2E quantization. This method of quantization is not applicable to
       models created with the Edge Generative API.
-    transformer_recipe: Quantization recipe to be applied on a model created
+    generative_recipe: Quantization recipe to be applied on a model created
       with the Edge Generative API.
   """
 
   pt2e_quantizer: pt2eq.PT2EQuantizer = None
-  transformer_recipe: quant_recipe.TransformerQuantRecipe = None
+  generative_recipe: quant_recipe.GenerativeQuantRecipe = None
 
   @enum.unique
   class _QuantizerMode(enum.Enum):
     NONE = enum.auto()
     PT2E_DYNAMIC = enum.auto()
     PT2E_STATIC = enum.auto()
-    TFLITE_DYNAMIC = enum.auto()
-    TFLITE_FP16 = enum.auto()
+    AI_EDGE_QUANTIZER = enum.auto()
 
   _quantizer_mode: _QuantizerMode = _QuantizerMode.NONE
 
   def __init__(
      self,
       pt2e_quantizer: Optional[pt2eq.PT2EQuantizer] = None,
-      transformer_recipe: Optional[quant_recipe.TransformerQuantRecipe] = None,
+      generative_recipe: Optional[quant_recipe.GenerativeQuantRecipe] = None,
   ):
     """Initializes some internal states based on selected quantization method.
 
@@ -61,8 +60,8 @@ class QuantConfig:
     is properly setup. Additionally sets up an utility enum _quantizer_mode to
     guide certain conversion processes.
     """
-    if pt2e_quantizer is not None and transformer_recipe is not None:
-      raise ValueError('Cannot set both pt2e_quantizer and transformer_recipe.')
+    if pt2e_quantizer is not None and generative_recipe is not None:
+      raise ValueError('Cannot set both pt2e_quantizer and generative_recipe.')
     elif pt2e_quantizer is not None:
       object.__setattr__(self, 'pt2e_quantizer', pt2e_quantizer)
       object.__setattr__(
@@ -74,12 +73,9 @@ class QuantConfig:
           else self._QuantizerMode.PT2E_STATIC
       ),
     )
-    elif transformer_recipe is not None:
-      transformer_recipe.verify()
-      object.__setattr__(self, 'transformer_recipe', transformer_recipe)
-      if self.transformer_recipe.default.mode == quant_attrs.Mode.DYNAMIC_RANGE:
-        object.__setattr__(self, '_quantizer_mode', self._QuantizerMode.TFLITE_DYNAMIC)
-      elif self.transformer_recipe.default.weight_dtype == quant_attrs.Dtype.FP16:
-        object.__setattr__(self, '_quantizer_mode', self._QuantizerMode.TFLITE_FP16)
+    elif generative_recipe is not None:
+      generative_recipe.verify()
+      object.__setattr__(self, 'generative_recipe', generative_recipe)
+      object.__setattr__(self, '_quantizer_mode', self._QuantizerMode.AI_EDGE_QUANTIZER)
     else:
-      raise ValueError('Either pt2e_quantizer or transformer_recipe must be set.')
+      raise ValueError('Either pt2e_quantizer or generative_recipe must be set.')
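The two configuration paths remain mutually exclusive; a minimal sketch of the valid construction under the renamed API:

```
from ai_edge_torch.quantize import quant_config
from ai_edge_torch.generative.quantize import quant_recipe, quant_recipe_utils

# Valid: exactly one of pt2e_quantizer / generative_recipe is set.
# __init__ calls generative_recipe.verify() and selects the AI_EDGE_QUANTIZER
# mode; passing both arguments, or neither, raises ValueError.
qc = quant_config.QuantConfig(
    generative_recipe=quant_recipe.GenerativeQuantRecipe(
        default=quant_recipe_utils.create_layer_quant_fp16()
    )
)
```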
{ai_edge_torch_nightly-0.2.0.dev20240609.dist-info → ai_edge_torch_nightly-0.2.0.dev20240611.dist-info}/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: ai-edge-torch-nightly
-Version: 0.2.0.dev20240609
+Version: 0.2.0.dev20240611
 Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
 Home-page: https://github.com/google-ai-edge/ai-edge-torch
 Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
{ai_edge_torch_nightly-0.2.0.dev20240609.dist-info → ai_edge_torch_nightly-0.2.0.dev20240611.dist-info}/RECORD

@@ -2,7 +2,7 @@ ai_edge_torch/__init__.py,sha256=FPMmuFU3pyMREtjB_san1fy_0PFtAsgA0VZfOYvDrb4,100
 ai_edge_torch/model.py,sha256=kmcgELjsYl8YzF8nUF6P7q4i8MWS-pLGpfsy-yTUXmE,4243
 ai_edge_torch/convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/convert/conversion.py,sha256=GN2Js232u_5Y118wg3qIfEoYewxbxLl3TpSnO6osi8c,4029
-ai_edge_torch/convert/conversion_utils.py,sha256=NpVm3Ms81_cIW5IYgGsr0BVganJJgBKWVBDe5h_ZaGE,11021
+ai_edge_torch/convert/conversion_utils.py,sha256=9BqCL38DErv1vEVGtT3BIJVhdwZjw2EQ-_m5UpvVVYE,11294
 ai_edge_torch/convert/converter.py,sha256=bjj5TV5_g4sGyuSh8ThEDydlNMqhkGSY4SzXK6vwhqI,6927
 ai_edge_torch/convert/fx_passes/__init__.py,sha256=EPs4PSIDLuRH5EBETi6deaOvaaf_Q4xD3_9NVcR7x8o,2810
 ai_edge_torch/convert/fx_passes/_pass_base.py,sha256=ijVyDclPnd6a0DWWUJkwR4igj6f82S-cE1-83QGPvgw,1652
@@ -34,16 +34,16 @@ ai_edge_torch/generative/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf
 ai_edge_torch/generative/examples/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/gemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/gemma/convert_to_tflite.py,sha256=dZv3r24uHsTMokEdnl3nf7LpmV0q7FLnVtCuHn5AuUs,2538
-ai_edge_torch/generative/examples/gemma/gemma.py,sha256=YF4Ua-1lnL3qhQnh1sY5-HlYw2Dq6ZRm227XyDe7WAw,5913
+ai_edge_torch/generative/examples/gemma/gemma.py,sha256=1lZfXGHmbII4rFu0U2B9NzlJCRhphxtmQtkCHQ39_uw,5935
 ai_edge_torch/generative/examples/phi2/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/phi2/convert_to_tflite.py,sha256=6nOuwx9q3AUlYcjXRRXSr_3M2JKqdJ-vUf-uE3VFYHE,2512
-ai_edge_torch/generative/examples/phi2/phi2.py,sha256=VvigzPQ_LJHeADTsMliwFwPe2BcnOhFgKDqr_WZ2JQ8,5540
+ai_edge_torch/generative/examples/phi2/phi2.py,sha256=PMhKC6JCAMYSj2F3UmWHWK4rTcXD-B6PuehaoDccRqk,5562
 ai_edge_torch/generative/examples/stable_diffusion/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/stable_diffusion/attention.py,sha256=Lo4Dq7a3Kg-lyH56iqGtqCo5UaClQHRCTDdNagXGTo8,3535
-ai_edge_torch/generative/examples/stable_diffusion/clip.py,sha256=yUCJemEh4n8ez-yLgVU0HZAki-PZ9nY04DFjgpx9PUc,3698
+ai_edge_torch/generative/examples/stable_diffusion/clip.py,sha256=qU1wVEcn_biwCuDguZljhlLGzpLIqgqC31Dh_lXquQc,3720
 ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py,sha256=NmgDo5uAefrhMUbYku0TKHlqzO0NVWI_M1ue8tddQR4,4024
-ai_edge_torch/generative/examples/stable_diffusion/decoder.py,sha256=Z1bnZvYtPdwNy706kixVDfL32X-R87B_WF3CcHwiz0o,11038
-ai_edge_torch/generative/examples/stable_diffusion/diffusion.py,sha256=TfbfsmuKoGsBENF9fYIAN_SMEQNhj-kjNdqQXFJGxpg,7784
+ai_edge_torch/generative/examples/stable_diffusion/decoder.py,sha256=meW8t-3BDdjFs5vCAf76cn6lGx49a_GcEvnVa9R5if4,11106
+ai_edge_torch/generative/examples/stable_diffusion/diffusion.py,sha256=_gEeUxa9Xyd3iLb_fyeUefHKuELVDorDlQs8e7wdXKg,7878
 ai_edge_torch/generative/examples/stable_diffusion/encoder.py,sha256=mgbxkeFDMkNIGmnbcFTIFPu8EWKokghiviYIOB2lE3Q,3437
 ai_edge_torch/generative/examples/stable_diffusion/pipeline.py,sha256=FCbnwlkpYYb-tF7KscbSYjNEdg7XnuLju1cDuIRoQv8,8277
 ai_edge_torch/generative/examples/stable_diffusion/tokenizer.py,sha256=r9RqbyNvuvXOGu3ojtl7ZmbC7o4Pt8aUKAhN1yCdtEc,3397
@@ -55,40 +55,42 @@ ai_edge_torch/generative/examples/stable_diffusion/samplers/k_lms.py,sha256=iPYX
 ai_edge_torch/generative/examples/stable_diffusion/samplers/sampler.py,sha256=5iRfU5MO6GR6K3WrdddIU_9U7ZZGEEb7zGKVY1WFl-8,1340
 ai_edge_torch/generative/examples/t5/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/t5/convert_to_tflite.py,sha256=bWtwtUacvJOEDUpuYvLTgkP7oTkXKJA-Tf4FPxlD1Cw,4536
-ai_edge_torch/generative/examples/t5/t5.py,sha256=q2gG5RRo7RgNzvHXYC0Juh6Tgt5d_RTMSWFaYvOKiZU,21065
+ai_edge_torch/generative/examples/t5/t5.py,sha256=L6YrVzUEzP-Imb8W28LdukFGrx1aWSzz1kyYK_9RFZM,21087
 ai_edge_torch/generative/examples/t5/t5_attention.py,sha256=rRgwCEdVtzcJEaGbbBjw8HxCxrCX3pXA5nelawdYiME,9036
 ai_edge_torch/generative/examples/test_models/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/examples/test_models/toy_model.py,sha256=EV07_MEG3fv9g0ZGu9gbBd5BjjrGkxCT1pv7dvhz4TI,3791
-ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py,sha256=rzL5h7Z5DIEgfpc1pWgYHdKt2aR8ha_CUqTKQBSPBaU,5521
-ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py,sha256=MUr6fSj2hBuYSlNbZtrBBpzqB_0WY-l_xYcd_TFFUjY,4831
+ai_edge_torch/generative/examples/test_models/toy_model.py,sha256=CUXsQ_IU96NaCg9jyfeKI0Zz2iWDkJUsPJyPR1Pgz7I,3813
+ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py,sha256=zwCmCnhr-vhBwHqv9i7xMasdBGVNqAGxZvWsncsJn58,5543
+ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py,sha256=lfYUiem_Pbn3vGgPx84BeI8n7rN3-1fImwCLm8Eo2U8,4853
 ai_edge_torch/generative/examples/tiny_llama/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py,sha256=E4I5OlC4zyl5cxiiu7uTED-zcwYRu210lP1zuT3xLBE,2566
-ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py,sha256=hVGpuI8gpj4Rn9k4otsRE22MSLFHBDlUOgioY6Ru6VI,5629
+ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py,sha256=IFRLPG9wz_aLl_zV_6CETCjSM03ukA6bZqqyDLVACuw,5651
 ai_edge_torch/generative/layers/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/layers/attention.py,sha256=Z8gXHYs6h8gaRiYAdvYUbHzg_2EmqfxiChsf_SYraAc,7902
 ai_edge_torch/generative/layers/attention_utils.py,sha256=hXhuyKblPPxKIRzlAf1YNlwHgpbj-6nReRLhRHELx5k,6350
-ai_edge_torch/generative/layers/builder.py,sha256=8cPL1NAutjT6Dwtyy2X7NSaTl9WCUJM5SIrBIDcEvVY,3520
+ai_edge_torch/generative/layers/builder.py,sha256=jAyrR5hsSI0aimKZumyvxdJ1GovERIfsK0g-dezX2gs,4163
 ai_edge_torch/generative/layers/feed_forward.py,sha256=4j2QaSCw59Jkk_ixKDpKEj7FLRauzuExTiSNRzAjAhE,2820
 ai_edge_torch/generative/layers/kv_cache.py,sha256=4uiZLO3om5G3--kT04Jt0esEYznbkJ7QLzSHfb8mjc4,3090
-ai_edge_torch/generative/layers/model_config.py,sha256=72DXOsFC0buvzZp6YyVjuTVrpphAubBJ5NJWfs3kEwk,4362
+ai_edge_torch/generative/layers/model_config.py,sha256=toWECENDWgay9hsZcy4C89qph0KI3CpaeFqFc8Fr-Xk,4584
 ai_edge_torch/generative/layers/normalization.py,sha256=M27eW3TcNK20oaXClXtfnu0lLWrAGrSKSsbegRWnj3c,1867
 ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=12SsCuoRuLNCwnFGe_pHDOZEBwBcqXs87Aj0PaWWw4E,1383
 ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=dYafGC205QE5CLIbBTCI-7eVvEGZEHzs1toPEhemeDs,3391
 ai_edge_torch/generative/layers/unet/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/layers/unet/blocks_2d.py,sha256=KuZd2oZhkCQSknSgXMBla-sfYBPUv5bZNf9RYKXHfGg,10052
+ai_edge_torch/generative/layers/unet/blocks_2d.py,sha256=7mHyJYq9lq5zVYp4mEz-R8Az3FFngi711YC20KP6ED8,10066
 ai_edge_torch/generative/layers/unet/builder.py,sha256=iH0_nuY9TF2ap5h1JbGNCOonPTfrXQHcF8U0slrIREM,1210
-ai_edge_torch/generative/layers/unet/model_config.py,sha256=LeRGB34fQ73UlknlFpjM9U-SZIRcQDnSmDltJivX-UA,4044
+ai_edge_torch/generative/layers/unet/model_config.py,sha256=sbtbDEHmMV9GLKngwjsNvqm8wovLxnlidkQbXdXkXKs,4060
 ai_edge_torch/generative/quantize/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/quantize/example.py,sha256=t-YwyKSPAG-OZC1DfH-0vfie2RHHpTSQjxUY-tmhu5g,1543
-ai_edge_torch/generative/quantize/quant_attrs.py,sha256=ffBALrrbrfiG_mrOr-f3B1Gc6PlAma9gtvVnfP7SDzI,1862
-ai_edge_torch/generative/quantize/quant_recipe.py,sha256=BOk4E0FW-_YD8Y-oPVmIDsgXx_bPtvzsP_V1av5DvgU,3327
-ai_edge_torch/generative/quantize/quant_recipe_utils.py,sha256=9ktL7fT8C5j1dnY_7fkiFL4oWNLVs1dMWXkS_EuyA3Y,1913
-ai_edge_torch/generative/quantize/quant_recipes.py,sha256=CRA2ENevS-3usHqidWDe2wrf_epILE_7Hx-XfZQ9buk,1798
-ai_edge_torch/generative/quantize/supported_schemes.py,sha256=OQ4ghQXknA1PPjuY-xBgAmOpaIBgYFM8F2YAIot06hE,1345
+ai_edge_torch/generative/quantize/quant_attrs.py,sha256=n1Fm8BFC8gJa_oiwwAOOghJyHtOXYZ4q-5ZRy4pHrIw,1957
+ai_edge_torch/generative/quantize/quant_recipe.py,sha256=Y8zahKw7b_h7ajPaJZVef4jG-MoqImRCpVSbFtV_i24,5139
+ai_edge_torch/generative/quantize/quant_recipe_utils.py,sha256=-vd6Qp0BdXJVKg4f0_hhwbKOi3QPIAPVqyXnJ-ZnISQ,1915
+ai_edge_torch/generative/quantize/quant_recipes.py,sha256=9ItD70jQRXMEhWod-nUfEeoWGJUUu6V9YOffF07VU9g,1795
+ai_edge_torch/generative/quantize/supported_schemes.py,sha256=FjdycEOvxRgBmQdZVufetPvkDoD7rUowIOSKV9oV5Kk,1418
+ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/translate_recipe.py,sha256=qUB4f2DoB14dLkNPWf6TZodpT81mfAJeWM-lCAmkuHY,5735
 ai_edge_torch/generative/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/test/loader_test.py,sha256=N88CbrLW7Q2x1EyurwdXQ6YjsA-ySQcPxpZH3QOGp-M,3317
 ai_edge_torch/generative/test/test_model_conversion.py,sha256=i_SAW-hD8SaHuopMZI9IuXXDFn5uSTJa1nKZhaC3dAQ,6811
-ai_edge_torch/generative/test/test_quantize.py,sha256=f70sH1ZFzdCwYj0MG-eg54WOC4LasR0D8CTUYpjxZYM,3728
+ai_edge_torch/generative/test/test_quantize.py,sha256=NVlMixAxVpDUabEvp6zTHHgIDgHFsMRwlf5MuyDwrPg,5355
 ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
 ai_edge_torch/generative/utilities/autoencoder_loader.py,sha256=G2Nosy33JzkjGALPR4JjvffdFX1JWOj2zjbbuaDJEgg,10065
 ai_edge_torch/generative/utilities/loader.py,sha256=Hs92478j1g4jQGvbdP1aWvOy907HjwqQZE-NFy6HELo,11326
@@ -103,12 +105,12 @@ ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py,sha256=aUAPKnH4_Jxpp
 ai_edge_torch/quantize/__init__.py,sha256=aB5dXot04bqyUhpsDFvxt9CIi15QAC4euvqOndJ0XLU,714
 ai_edge_torch/quantize/pt2e_quantizer.py,sha256=ye1f5vAZ0Vr4RWAtfrgU1o3JLs03Sa4inHRq3YxJDGo,15602
 ai_edge_torch/quantize/pt2e_quantizer_utils.py,sha256=yjzKoptnfEeW_sN7sODUfj3nCtUMXVzq3vHKxblsd5Y,36046
-ai_edge_torch/quantize/quant_config.py,sha256=ExThdTXqnWmGC3-F6sdXbXr8nYzkEe_qCziCfhsoMPA,3435
+ai_edge_torch/quantize/quant_config.py,sha256=eO9Ra160ITjQSyRBEGy6nNIVH3gYacSWDdN5XtvHwjc,3148
 ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
 ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=EIyKz-HY70DguWuSrJal8LpYXQ5ZSEUf3ZrVl7jikFM,4286
-ai_edge_torch_nightly-0.2.0.dev20240609.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
-ai_edge_torch_nightly-0.2.0.dev20240609.dist-info/METADATA,sha256=THUdn03pVtqLTQHdigvUA8B32EHwfV5UTvBRQCpR8v0,1748
-ai_edge_torch_nightly-0.2.0.dev20240609.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
-ai_edge_torch_nightly-0.2.0.dev20240609.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
-ai_edge_torch_nightly-0.2.0.dev20240609.dist-info/RECORD,,
+ai_edge_torch_nightly-0.2.0.dev20240611.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
+ai_edge_torch_nightly-0.2.0.dev20240611.dist-info/METADATA,sha256=WPGu2pq6N57fBtpunyFhunPe73UK_SVbqlZQsZwjWGo,1748
+ai_edge_torch_nightly-0.2.0.dev20240611.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
+ai_edge_torch_nightly-0.2.0.dev20240611.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
+ai_edge_torch_nightly-0.2.0.dev20240611.dist-info/RECORD,,