ai-edge-torch-nightly 0.5.0.dev20250508__py3-none-any.whl → 0.5.0.dev20250510__py3-none-any.whl

ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py CHANGED
@@ -20,7 +20,9 @@ from ai_edge_torch.generative.examples.gemma import gemma2
  from ai_edge_torch.generative.utilities import converter
  from ai_edge_torch.generative.utilities import export_config
 
- flags = converter.define_conversion_flags("gemma2-2b")
+ flags = converter.define_conversion_flags(
+     "gemma2-2b", default_mask_as_input=True, default_transpose_kv_cache=True
+ )
 
 
  def main(_):
ai_edge_torch/generative/examples/gemma3/convert_gemma3_to_tflite.py CHANGED
@@ -20,7 +20,9 @@ from ai_edge_torch.generative.examples.gemma3 import gemma3
  from ai_edge_torch.generative.utilities import converter
  from ai_edge_torch.generative.utilities import export_config
 
- flags = converter.define_conversion_flags('gemma3-1b')
+ flags = converter.define_conversion_flags(
+     'gemma3-1b', default_mask_as_input=True, default_transpose_kv_cache=True
+ )
 
  _MODEL_SIZE = flags.DEFINE_string(
      'model_size',
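
A hedged sketch of what the new keyword defaults mean, assuming the absl flag
definitions shown in converter.py further down (explicit command-line values
still override these defaults):

    from ai_edge_torch.generative.utilities import converter

    flags = converter.define_conversion_flags(
        'gemma3-1b', default_mask_as_input=True, default_transpose_kv_cache=True
    )
    # Equivalent to passing --mask_as_input=true --transpose_kv_cache=true
    # on every invocation of the conversion script.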
ai_edge_torch/generative/examples/gemma3/decoder.py CHANGED
@@ -119,9 +119,7 @@ class Decoder(nn.Module):
          config.vocab_size, config.embedding_dim, padding_idx=0
      )
      self.lm_head = nn.Linear(
-         config.embedding_dim,
-         config.vocab_size,
-         bias=config.lm_head_use_bias,
+         config.embedding_dim, config.vocab_size, bias=config.lm_head_use_bias
      )
      # Gemma3 re-uses the embedding as the head projection layer.
      self.lm_head.weight.data = self.tok_embedding.weight.data
@@ -130,30 +128,13 @@ class Decoder(nn.Module):
          for idx in range(config.num_layers)
      )
      self.final_norm = builder.build_norm(
-         config.embedding_dim,
-         config.final_norm_config,
+         config.embedding_dim, config.final_norm_config
      )
      self.mask_cache = attn_utils.build_causal_mask_cache(
          size=config.kv_cache_max,
      )
-     # Gemma3 has same hyper parameters for each layer except for attention
-     # types. Use the first layer.
-     attn_config = config.block_config(0).attn_config
-     self.sliding_window_mask_cache = attn_utils.build_sliding_window_mask_cache(
-         size=config.kv_cache_max,
-         window_size=attn_config.sliding_window_size,
-     )
      self.config = config
 
-   def get_attention_mask(
-       self,
-       attn_type: cfg.AttentionType,
-       input_pos: torch.Tensor,
-   ) -> torch.Tensor:
-     if attn_type == cfg.AttentionType.LOCAL_SLIDING:
-       return self.sliding_window_mask_cache.index_select(2, input_pos)
-     return self.mask_cache.index_select(2, input_pos)
-
    def get_local_global_attention_mask(
        self,
        attention_mask: torch.Tensor,
@@ -200,9 +181,7 @@ class Decoder(nn.Module):
          sliding_mask_bool,
          torch.zeros_like(sliding_mask_bool, dtype=torch.float),
          torch.full_like(
-             sliding_mask_bool,
-             self.config.causal_mask_value,
-             dtype=torch.float,
+             sliding_mask_bool, self.config.causal_mask_value, dtype=torch.float
          ),
      )
 
@@ -261,7 +240,6 @@ class Decoder(nn.Module):
      pixel_mask = self.build_pixel_mask(image_indices)
      # RoPE parameters are the same for all blocks. Use the first layer.
      attn_config = self.config.block_config(0).attn_config
-     n_elem = int(attn_config.rotary_percentage * attn_config.head_dim)
      # Different rotary base for global and local attention
      # based on attention pattern
      rope = [
@@ -273,12 +251,8 @@ class Decoder(nn.Module):
          for i in range(self.config.num_layers)
      ]
      if mask is None:
-       mask = [
-           self.get_attention_mask(
-               self.config.block_config(i).attn_config.attn_type, input_pos
-           )
-           for i in range(self.config.num_layers)
-       ]
+       mask = self.mask_cache.index_select(2, input_pos)
+       mask = mask[:, :, :, : self.config.kv_cache_max]
 
      return self._forward_with_embeds(
          input_embeds, rope, mask, input_pos, kv_cache, pixel_mask, export_config
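
The per-layer mask list built via get_attention_mask is gone; the forward pass
now starts from one shared causal mask, and the local/global split happens
later in _forward_with_embeds. A minimal runnable sketch of the indexing
(tensor shapes assumed from attn_utils.build_causal_mask_cache):

    import torch

    S = 8  # kv_cache_max, kept tiny for illustration
    # [1, 1, S, S] causal mask: -inf strictly above the diagonal, 0 elsewhere.
    mask_cache = torch.triu(
        torch.full((1, 1, S, S), float('-inf')), diagonal=1
    )
    input_pos = torch.tensor([3, 4])  # current token positions
    mask = mask_cache.index_select(2, input_pos)[:, :, :, :S]
    print(mask.shape)  # torch.Size([1, 1, 2, 8])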
@@ -305,7 +279,7 @@ class Decoder(nn.Module):
      if pixel_mask is None:
        mask = [
            self.get_local_global_attention_mask(
-               mask,
+               mask[i] if isinstance(mask, list) else mask,
                self.config.block_config(i).attn_config.attn_type,
                input_pos,
                self.config.block_config(i).attn_config.sliding_window_size,
@@ -316,7 +290,7 @@ class Decoder(nn.Module):
        pixel_mask = pixel_mask.index_select(2, input_pos)
        mask = [
            self.compose_mask(
-               mask[i],
+               mask[i] if isinstance(mask, list) else mask,
                pixel_mask,
                self.config.block_config(i).attn_config.attn_type,
            )
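
Both call sites now normalize the mask the same way, since callers may pass
either one shared tensor (the new internal default) or a per-layer list (e.g.
when masks are supplied as inputs). A sketch of the pattern, with a
hypothetical helper name:

    def mask_for_layer(mask, i):
      # list: externally supplied per-layer masks; tensor: shared causal mask.
      return mask[i] if isinstance(mask, list) else mask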
@@ -330,6 +304,7 @@ class Decoder(nn.Module):
        if kv_entry:
          updated_kv_entries.append(kv_entry)
      updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entries))
+
      if export_config is not None:
        if (
            torch.numel(input_pos) > 1
ai_edge_torch/generative/quantize/quant_recipe.py CHANGED
@@ -16,9 +16,12 @@
  from dataclasses import dataclass
  from typing import Optional, Union
 
+ from ai_edge_torch.generative.layers import model_config
  from ai_edge_torch.generative.quantize import quant_attrs
  from ai_edge_torch.generative.quantize import supported_schemes
 
+ ModelConfig = model_config.ModelConfig
+
 
  @dataclass
  class LayerQuantRecipe:
@@ -52,7 +55,7 @@ class LayerQuantRecipe:
        f'w:{self.weight_dtype.name}, '
        f'{self.mode.name}, '
        f'{self.algorithm.name}, '
-       f'{self.granularity.name}'
+       f'{self.granularity.name}, '
        f'{self.block_size}'
    )
    return f'{base_str})'
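
The fix adds the separator that was missing between the last two fields:
adjacent f-string literals concatenate, so granularity and block size were
fused in the repr. A standalone illustration (field values invented):

    granularity, block_size = 'BLOCKWISE', 32
    before = f'{granularity}' f'{block_size}'   # 'BLOCKWISE32'
    after = f'{granularity}, ' f'{block_size}'  # 'BLOCKWISE, 32'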
@@ -133,6 +136,7 @@ class GenerativeQuantRecipe:
    feedforward: Union[
        Optional[LayerQuantRecipe], Optional[dict[int, LayerQuantRecipe]]
    ] = None
+   _model_config: Optional[ModelConfig] = None
 
    def __str__(self):
      return f"""GenerativeQuantRecipe(
ai_edge_torch/generative/quantize/quant_recipes.py CHANGED
@@ -63,6 +63,7 @@ def all_supported_int4_dynamic_block_recipe(
      generative_recipe=quant_recipe.GenerativeQuantRecipe(
          default=quant_recipe_utils.create_layer_quant_int4_dynamic_block(
              block_size
-         )
+         ),
+         embedding=quant_recipe_utils.create_layer_quant_int8_dynamic(),
      )
  )
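
With this change the int4 dynamic-block recipe no longer leaves the embedding
table in float; it is quantized with dynamic int8. A hedged usage sketch
(function name taken from the hunk header above):

    from ai_edge_torch.generative.quantize import quant_recipes

    quant_config = quant_recipes.all_supported_int4_dynamic_block_recipe(32)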
ai_edge_torch/generative/test/test_quantize.py CHANGED
@@ -14,7 +14,6 @@
  # ==============================================================================
 
  import ai_edge_torch
- from ai_edge_torch import config
  from ai_edge_torch.generative.examples.test_models import toy_model  # NOQA
  from ai_edge_torch.generative.quantize import quant_recipe
  from ai_edge_torch.generative.quantize import quant_recipe_utils
ai_edge_torch/generative/utilities/converter.py CHANGED
@@ -15,6 +15,7 @@
 
  """Common utility functions for model conversion."""
 
+ import enum
  import os
  import pathlib
  from typing import Optional, Union
@@ -42,7 +43,32 @@ class ExportableModule(torch.nn.Module):
      return self.module(*export_args, **full_kwargs)
 
 
- def define_conversion_flags(model_name: str):
+ class QuantizationName(str, enum.Enum):
+   """Strings for all supported quantization recipes.
+
+   none: No quantization.
+   dynamic_int8: Dynamic range quantization with int8 weights.
+   weight_only_int8: Weight only quantization with int8 weights.
+   fp16: Float16 quantization.
+   dynamic_int4_block32: Dynamic range quantization with int4 weights and block
+     size of 32, better model quality but slower inference.
+   dynamic_int4_block128: Dynamic range quantization with int4 weights and block
+     size of 128, faster inference but worse model quality.
+   """
+
+   NONE = 'none'
+   DYNAMIC_INT8 = 'dynamic_int8'
+   WEIGHT_ONLY_INT8 = 'weight_only_int8'
+   FP16 = 'fp16'
+   DYNAMIC_INT4_BLOCK32 = 'dynamic_int4_block32'
+   DYNAMIC_INT4_BLOCK128 = 'dynamic_int4_block128'
+
+
+ def define_conversion_flags(
+     model_name: str,
+     default_mask_as_input: bool = False,
+     default_transpose_kv_cache: bool = False,
+ ):
    """Defines common flags used for model conversion."""
 
    flags.DEFINE_string(
@@ -70,10 +96,10 @@ def define_conversion_flags(model_name: str):
      1280,
      'The maximum size of KV cache buffer, including both prefill and decode.',
  )
- flags.DEFINE_bool(
+ flags.DEFINE_string(
      'quantize',
-     True,
-     'Whether the model should be quantized.',
+     'dynamic_int8',
+     'How the model should be quantized.',
  )
  flags.DEFINE_multi_integer(
      'lora_ranks',
@@ -83,18 +109,78 @@ def define_conversion_flags(model_name: str):
  )
  flags.DEFINE_bool(
      'mask_as_input',
-     False,
+     default_mask_as_input,
      'If true, the mask will be passed in as input. Otherwise, mask will be '
      'built by the model internally.',
  )
  flags.DEFINE_bool(
      'transpose_kv_cache',
-     False,
+     default_transpose_kv_cache,
      'If true, the model will be converted with transposed KV cache.',
  )
  return flags
 
 
+ def get_quant_recipe_from_flag(
+     quantize: str,
+ ) -> Optional[quant_recipes.QuantizationRecipe]:
+   """Processes the quantization flag and returns the corresponding recipe.
+
+   Args:
+     quantize: The quantization type.
+
+   Returns:
+     The quantization recipe, or None if no quantization is needed.
+
+   Raises:
+     ValueError: If the quantization type is not supported.
+   """
+   match quantize:
+     case QuantizationName.NONE:
+       return None
+     case QuantizationName.DYNAMIC_INT8:
+       return quant_recipes.full_int8_dynamic_recipe()
+     case QuantizationName.WEIGHT_ONLY_INT8:
+       return quant_recipes.full_int8_weight_only_recipe()
+     case QuantizationName.FP16:
+       return quant_recipes.full_fp16_recipe()
+     case QuantizationName.DYNAMIC_INT4_BLOCK32:
+       return quant_recipes.full_int4_dynamic_block_recipe(32)
+     case QuantizationName.DYNAMIC_INT4_BLOCK128:
+       return quant_recipes.full_int4_dynamic_block_recipe(128)
+     case _:
+       raise ValueError(f'Unsupported quantization flag: {quantize}')
+
+
+ def create_quantize_suffix(quantize: str) -> str:
+   """Creates a suffix for the output file name based on the quantization type.
+
+   Args:
+     quantize: The quantization type.
+
+   Returns:
+     A string representing the quantization suffix.
+
+   Raises:
+     ValueError: If the quantization type is not supported.
+   """
+   match quantize:
+     case QuantizationName.NONE:
+       return 'f32'
+     case QuantizationName.DYNAMIC_INT8:
+       return 'q8'
+     case QuantizationName.WEIGHT_ONLY_INT8:
+       return 'q8_wo'
+     case QuantizationName.FP16:
+       return 'fp16'
+     case QuantizationName.DYNAMIC_INT4_BLOCK32:
+       return 'q4_block32'
+     case QuantizationName.DYNAMIC_INT4_BLOCK128:
+       return 'q4_block128'
+     case _:
+       raise ValueError(f'Unsupported quantization flag: {quantize}')
+
+
  def _build_mask(mask_len, kv_cache_max_len, causal_mask_value) -> torch.Tensor:
    if isinstance(mask_len, list):
      return [
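
A hedged sketch of how the new string-valued flag flows through the helpers
just added: 'none' plays the role of the old --quantize=false, and the new
default 'dynamic_int8' reproduces the old --quantize=true behavior. Plain
strings match the case patterns because QuantizationName subclasses str:

    from ai_edge_torch.generative.utilities import converter

    recipe = converter.get_quant_recipe_from_flag('dynamic_int4_block32')
    suffix = converter.create_quantize_suffix('dynamic_int4_block32')
    print(suffix)  # 'q4_block32'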
@@ -114,7 +200,7 @@ def convert_to_tflite(
      prefill_seq_len: Union[int, list[int]],
      pixel_values_size: torch.Size = None,
      pixel_seq_len: int = 0,
-     quantize: bool = True,
+     quantize: str = 'dynamic_int8',
      config: cfg.ModelConfig = None,
      lora_ranks: Optional[list[int]] = None,
      export_config: ExportConfig = None,
@@ -160,8 +246,8 @@ def convert_to_tflite(
        embeddings generated by the image encoder with pixel values. The actual
        length of prefill_seq_len will be added by pixel_seq_len when pixel
        values are passed.
-     quantize (bool, optional): Whether the model should be quanized. Defaults
-       to True.
+     quantize (str, optional): The quantization type. Defaults to
+       'dynamic_int8'.
      config (cfg.ModelConfig, optional): The model config used to configure KV
        cache. If None, it uses the config of the pytorch_model.
      lora_ranks (list[int], optional): The ranks of the LORA layers. If None,
@@ -182,7 +268,7 @@ def convert_to_tflite(
        lora = lora_utils.LoRA.zeros(rank, config)
        loras.append(lora)
 
-   quant_suffix = 'q8' if quantize else 'f32'
+   quant_suffix = create_quantize_suffix(quantize)
    kv_size = config.kv_cache_max_len
    lora_suffix = (
        '' if not lora_ranks else f'_lora{",".join(map(str, lora_ranks))}'
@@ -216,7 +302,7 @@ def _export_helper(
      prefill_seq_lens: list[int],
      pixel_values_size: torch.Size,
      pixel_seq_len: int,
-     quantize: bool,
+     quantize: str,
      config: cfg.ModelConfig,
      loras: list[None | lora_utils.LoRA],
      export_config: ExportConfig,
@@ -265,7 +351,8 @@ def _export_helper(
        kv_layout=export_config.kvcache_layout,
    )
 
-   quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
+   quant_config = get_quant_recipe_from_flag(quantize)
+   quant_config._model_config = config
 
    # For export, we create a module that captures any non-exportable,
    # arugments, e.g. the generation config object.
@@ -330,5 +417,7 @@ def _export_helper(
        sample_kwargs=sample_kwargs,
    )
 
-   edge_model = converter.convert(quant_config=quant_config)
+   edge_model = converter.convert(
+       quant_config=quant_config,
+   )
    edge_model.export(output_file)
ai_edge_torch/generative/utilities/model_builder.py CHANGED
@@ -75,8 +75,7 @@ class DecoderOnlyModel(nn.Module):
        for idx in range(config.num_layers)
    )
    self.final_norm = builder.build_norm(
-       config.embedding_dim,
-       config.final_norm_config,
+       config.embedding_dim, config.final_norm_config
    )
    self.mask_cache = attn_utils.build_causal_mask_cache(
        size=config.kv_cache_max,
ai_edge_torch/lowertools/_shim.py CHANGED
@@ -50,7 +50,7 @@ def exported_programs_to_tflite(
      *,
      quant_config: Optional[qcfg.QuantConfig] = None,
      _tfl_converter_flags: Optional[dict[str, Any]] = None,
-     _saved_model_dir: Optional[str] = None
+     _saved_model_dir: Optional[str] = None,
  ):
    """Converts a list of ExportedProgram to a TFLite model.
 
ai_edge_torch/lowertools/translate_recipe.py CHANGED
@@ -29,6 +29,8 @@ _IDX_TRANSFORMER_BLOCKS_REGEX_STR = 'transformer_blocks\[{}\]'
  _ATTENTION_REGEX_STR = 'ai_edge_torch.generative.layers.attention'
  _FEEDFORWARD_REGEX_STR = 'ai_edge_torch.generative.layers.feed_forward'
  _EMBEDDING_REGEX_STR = 'Embedding_tok_embedding'
+ # TODO: b/415833584 - Improve the regex for pre-softmax layer.
+ _DECODE_LOGITS_REGEX_STR = 'StatefulPartitionedCall'
  _ANY_TWO_DIGITS_REGEX_STR = '\d{1,2}'
 
 
@@ -95,10 +97,11 @@ def _set_quant_config(
      rm: quantizer.recipe_manager.RecipeManager,
      layer_recipe: quant_recipe.LayerQuantRecipe,
      regex: str,
+     operation_name: _OpName = _OpName.ALL_SUPPORTED,
  ):
    rm.add_quantization_config(
        regex=regex,
-       operation_name=_OpName.ALL_SUPPORTED,
+       operation_name=operation_name,
        op_config=_OpQuantConfig(
            weight_tensor_config=_TensorQuantConfig(
                num_bits=_get_nbits_from_dtype(layer_recipe.weight_dtype),
@@ -126,6 +129,16 @@ def translate_to_ai_edge_recipe(
 
    if recipe.embedding is not None:
      _set_quant_config(rm, recipe.embedding, _EMBEDDING_REGEX_STR)
+     if (
+         recipe._model_config is not None
+         and recipe._model_config.lm_head_share_weight_with_embedding
+     ):
+       _set_quant_config(
+           rm,
+           recipe.embedding,
+           _DECODE_LOGITS_REGEX_STR,
+           _OpName.FULLY_CONNECTED,
+       )
 
    if recipe.attention is not None:
      if isinstance(recipe.attention, dict):
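
The motivation for the new branch is the weight sharing shown in the decoder.py
hunk above: lm_head reuses the embedding table, so the pre-softmax
FULLY_CONNECTED op must receive the same quantization recipe as the embedding
(matched, per the TODO, by the 'StatefulPartitionedCall' regex for now). A
minimal illustration of that sharing:

    from torch import nn

    emb = nn.Embedding(100, 16)
    lm_head = nn.Linear(16, 100, bias=False)
    lm_head.weight.data = emb.weight.data  # same storage, not a copy
    assert lm_head.weight.data_ptr() == emb.weight.data_ptr()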
ai_edge_torch/version.py CHANGED
@@ -13,4 +13,4 @@
  # limitations under the License.
  # ==============================================================================
 
- __version__ = "0.5.0.dev20250508"
+ __version__ = "0.5.0.dev20250510"
ai_edge_torch_nightly-0.5.0.dev20250510.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: ai-edge-torch-nightly
- Version: 0.5.0.dev20250508
+ Version: 0.5.0.dev20250510
  Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
  Home-page: https://github.com/google-ai-edge/ai-edge-torch
  Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
@@ -22,6 +22,7 @@ Classifier: Topic :: Software Development :: Libraries :: Python Modules
  Requires-Python: >=3.10
  Description-Content-Type: text/markdown
  License-File: LICENSE
+ Requires-Dist: absl-py
  Requires-Dist: numpy
  Requires-Dist: scipy
  Requires-Dist: safetensors
ai_edge_torch_nightly-0.5.0.dev20250510.dist-info/RECORD CHANGED
@@ -2,7 +2,7 @@ ai_edge_torch/__init__.py,sha256=8sPR_5uXJA4NEE0nIwNdSl-ADOJEoR8hAgYvBQDY70Y,120
  ai_edge_torch/_config.py,sha256=AiqhbcheF7j_ozIGDLC89k1we95aVgFDa-tR6h7UI0s,2529
  ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
  ai_edge_torch/model.py,sha256=wxjSFq_rBSxSqbUE8E8EJTCkgvgaRLjq_ZuAM-IZpCU,5606
- ai_edge_torch/version.py,sha256=7lrbHHeWyBpqJdwFYYooOGJss4Rvg3UAdFSo9K0uzek,706
+ ai_edge_torch/version.py,sha256=03QthwiMre1vVY49We8vVzXhxe0zkOzzsTMQZv3hDrk,706
  ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
  ai_edge_torch/_convert/conversion.py,sha256=QVugYVfbyaeBgSKKbhFzHG5oXA7t3M-40JcpcdSu6W8,5436
  ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -61,15 +61,15 @@ ai_edge_torch/generative/examples/deepseek/deepseek.py,sha256=yhS_i2kR0GJWpWciCt
  ai_edge_torch/generative/examples/deepseek/verify.py,sha256=iYldze-pvZGvPkkqr6zA7EmitPnH9sXkzjNVx353IcE,2403
  ai_edge_torch/generative/examples/gemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
  ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py,sha256=RRilUl2Ui08R9gy1Ua0jnaXNCrIJJb-oztgP62G3mX4,1526
- ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py,sha256=7IlF-4NEfZAzIfkOUHR-HeCSLSUGEu7wnO52UtERCa4,1527
+ ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py,sha256=9ozSw2-xuf5Wfh1HeLDTP3wJxxUZmrD3An1njJPMpdI,1594
  ai_edge_torch/generative/examples/gemma/gemma1.py,sha256=6ImjTzJcq6JoKz2Z-z8pjv5BsRu5nUeEsTK3IPs3xgI,3521
  ai_edge_torch/generative/examples/gemma/gemma2.py,sha256=JQLLiHNVBM9jOrZqUF0EmgAwtDD0yTRlmIbLaWM7qTg,11557
  ai_edge_torch/generative/examples/gemma/verify_gemma1.py,sha256=ip-Gmk4CI5f0GWSdAIdrectxQWJ0t328KCsA4nfHuGg,1736
  ai_edge_torch/generative/examples/gemma/verify_gemma2.py,sha256=jhiyinOqPt5ZZjEadDRZt_wY5fiLSCpMo54PcxFaL_Q,1789
  ai_edge_torch/generative/examples/gemma/verify_util.py,sha256=n7f2nF6Lin_tDvPs0JVldsuaBzo7pAwi5YAHAhlIxQg,6139
  ai_edge_torch/generative/examples/gemma3/__init__.py,sha256=JaAnrFoXTl3RJX97XspklkTyqOHVyAgRJsZtzNDd10c,671
- ai_edge_torch/generative/examples/gemma3/convert_gemma3_to_tflite.py,sha256=MjkQDVynaw9C5z9ODzKfb85xW5JfxHUWBJ_Aco05FHo,1760
- ai_edge_torch/generative/examples/gemma3/decoder.py,sha256=xGxeNKQvgyrENmUQMu0uKymL3qthvbdoxdMbAzwiLz0,15725
+ ai_edge_torch/generative/examples/gemma3/convert_gemma3_to_tflite.py,sha256=JLXXn2mFEBs4DlHH_O6hpEG9KInJqsCdWy3DrgUjT1c,1827
+ ai_edge_torch/generative/examples/gemma3/decoder.py,sha256=v0ZcKrIAvERQLb1wK1Vc_ewWWVZgJFUdRTyoVY0Lfus,14955
  ai_edge_torch/generative/examples/gemma3/gemma3.py,sha256=GACDBI_MsFowR8A3wAWrpzradPYe-AUgB9ZjXaVBG-s,6485
  ai_edge_torch/generative/examples/gemma3/image_encoder.py,sha256=uRoLoBWzFtQz5wFZfPCxbkvZsgPAqSkUUsV3977GbYc,5184
  ai_edge_torch/generative/examples/gemma3/verify_gemma3.py,sha256=v8oNXFICmVOtQxfO7IhZ8GnbvotEkDi9lzYHjoQyOso,2464
@@ -178,9 +178,9 @@ ai_edge_torch/generative/layers/unet/model_config.py,sha256=pPDwLawc23pfMaPVyMJl
  ai_edge_torch/generative/quantize/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
  ai_edge_torch/generative/quantize/example.py,sha256=1lfVNUd2cEyRUnoZ7BLbRJ9IN-FTKiWBtZNPFUzAiWE,1747
  ai_edge_torch/generative/quantize/quant_attrs.py,sha256=plMsd7JBi98r2NHsAdMdvS6TPTXAoRFLCwOXu8H3-24,2004
- ai_edge_torch/generative/quantize/quant_recipe.py,sha256=3xT4N5tfggXJqgwKW4ntIkwsrNVtkG2SIUHeiSF5yOs,5579
+ ai_edge_torch/generative/quantize/quant_recipe.py,sha256=CEW-ewHxwb59x_GISx4jr7WMihvn-jKWVcBonllzDS4,5724
  ai_edge_torch/generative/quantize/quant_recipe_utils.py,sha256=h3k_na6rbR08Ip79-2JbkeH8RDk_rrnEGiytuzFDhqc,2678
- ai_edge_torch/generative/quantize/quant_recipes.py,sha256=a71KFHVbjJdBDpYshbUI69NxGNOmPuqp_NZvNSrf00c,2349
+ ai_edge_torch/generative/quantize/quant_recipes.py,sha256=5UkUAT0qsWzLtNAeX-M5hEMi-kqoLV70_F76QiXmVZ4,2424
  ai_edge_torch/generative/quantize/supported_schemes.py,sha256=TwR2FpQuBEORy6FshEyHNBMKARWlA2MVtTfX9tXV5aE,1488
  ai_edge_torch/generative/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
  ai_edge_torch/generative/test/test_custom_dus.py,sha256=MjIhTvkTko872M35XMciobvICcDWTcIDJ3rociko-wM,3267
@@ -189,13 +189,13 @@ ai_edge_torch/generative/test/test_loader.py,sha256=9mQUeeZKOVApOWSWl2cN9c10axZj
  ai_edge_torch/generative/test/test_lora.py,sha256=6QIM6RLTc2HrodGpp_aS3OxM9Rco2KAzEnYgotkg41M,5310
  ai_edge_torch/generative/test/test_model_conversion.py,sha256=mhNJikLnGVGi9NKmXB8FhnqeDy9gtrvC3yEbrTABZ4Y,6163
  ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=vQWmpzMkJ2hPmWpg41ZMWwBsngTykRVzRPHtpbkwiLM,12811
- ai_edge_torch/generative/test/test_quantize.py,sha256=TG6vTF9yOZWe2wW7v8-hmuaQoODwJC1Z-2d5xv3zgfI,7389
+ ai_edge_torch/generative/test/test_quantize.py,sha256=kKJ01wscTC2t_Ylr7huO5gNKES01gm3dT1gx52z15PA,7356
  ai_edge_torch/generative/test/utils.py,sha256=tF6aCfAGJnc9dmzCnZCEOuKNVimfWOqscv9og0DDLHU,2656
  ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
- ai_edge_torch/generative/utilities/converter.py,sha256=d0JOWN5l2vbvt8RzFFiRoulkWiejyEZ21xKv5LdLIyc,11675
+ ai_edge_torch/generative/utilities/converter.py,sha256=LrBqxXVxkOWh4abcHfY4QXRpYxjjfEYd4ifrpGGbebI,14441
  ai_edge_torch/generative/utilities/export_config.py,sha256=5IvR3grlMd4mWO5c_Y4x9Fk1b1xa57MzlYNE8XUaN28,2049
  ai_edge_torch/generative/utilities/loader.py,sha256=7p__m2JryWphGlYOuRxdoT4id4_tWJEVOV7y2X4H-Ak,13737
- ai_edge_torch/generative/utilities/model_builder.py,sha256=ZYX1TxpFdj573du2QCyHJlFjx4q1m12R74fp4Gwl92A,6343
+ ai_edge_torch/generative/utilities/model_builder.py,sha256=IG-88o7nWI9XrNDnwnQ-MoilsuqJ7KwrnbP3bn2EY9U,6334
  ai_edge_torch/generative/utilities/moonshine_loader.py,sha256=_RpFabSqtGH5PHiP3_1f6QfO14qMADUxr_HGRlVDFB0,4891
  ai_edge_torch/generative/utilities/stable_diffusion_loader.py,sha256=dqPD9qRXEWtU3ombslOC-BE2l_dMwHoCNu7NsIJhsso,36158
  ai_edge_torch/generative/utilities/t5_loader.py,sha256=tEsfy8-ymzbbjOIc-oesXF3yGyyWtJgFXn2s7VOavt8,16961
@@ -209,12 +209,12 @@ ai_edge_torch/hlfb/mark_pattern/pattern.py,sha256=Ui6BrehF3zJJN7uTxKwbO2yCY9mYjb
  ai_edge_torch/hlfb/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
  ai_edge_torch/hlfb/test/test_mark_pattern.py,sha256=5kmOJWCc7sU1Hrqr1y17BtShUrssTfaV1sMyUvdMbsg,5573
  ai_edge_torch/lowertools/__init__.py,sha256=A8WBXvWtuFYYWtNTqPD7waVntLaSVAnSMwx5ugjZBIw,761
- ai_edge_torch/lowertools/_shim.py,sha256=Mbg16tnCVK0YsHowfbpqpNX1qySuMLvpGI_-I5SIrG0,3276
+ ai_edge_torch/lowertools/_shim.py,sha256=rEZkOdHiz7CPvPL0WATIYnH4K6wF1YBtcv3oFEx2ZeQ,3277
  ai_edge_torch/lowertools/common_utils.py,sha256=4HQtquPZ6oiId8vR_1ykW_uK4ELnyo5zo3MlX1QYW4c,4513
  ai_edge_torch/lowertools/odml_torch_utils.py,sha256=QRuS7S5lULRWEh3J1sWIsnKh-rbX7rd9tt6JJHbMPfo,8317
  ai_edge_torch/lowertools/test_utils.py,sha256=mdxTlhqHABZEQ_GEmPFCL8LIAWtqRtYZUGdSY1ieZjw,1949
  ai_edge_torch/lowertools/torch_xla_utils.py,sha256=1EytIw2R6dthhLhf69wN1L9BaQTeybCD0wga-PhHcMI,9518
- ai_edge_torch/lowertools/translate_recipe.py,sha256=kUVCe69_DzvfbNYVB0MY2rCZwWaN8t3NoNu8Vh4x5bQ,5849
+ ai_edge_torch/lowertools/translate_recipe.py,sha256=JNsRc1Jmpj5W6PBww8KRMkbtxcv7ssl8Rr1R3x5_7to,6283
  ai_edge_torch/odml_torch/__init__.py,sha256=S8jOzE9nLof-6es3XDiGJRN-9H_XTxsVm9dE7lD3RWo,812
  ai_edge_torch/odml_torch/_torch_future.py,sha256=jSYHf1CMTJzMizPMbu2b39hAt0ZTR6gQLq67GMe9KTo,2336
  ai_edge_torch/odml_torch/_torch_library.py,sha256=Lw1gqL2HWNRspdTwNhIkYAHDyafHedHtkXyKKxn-Wss,805
@@ -251,8 +251,8 @@ ai_edge_torch/testing/__init__.py,sha256=_yGgvnBZWb7T3IN3mc4x1sS4vM96HZwM8pwIcPG
  ai_edge_torch/testing/export.py,sha256=k5mGDGzwc23Z4zaIVDs8CNh-oOt64gsf9MS9NjhbPy4,3293
  ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
  ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
- ai_edge_torch_nightly-0.5.0.dev20250508.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
- ai_edge_torch_nightly-0.5.0.dev20250508.dist-info/METADATA,sha256=GGDJl2Fya8gLr9RIfSLCmm1K1xA3qzBrrEOy1hwR2dQ,2051
- ai_edge_torch_nightly-0.5.0.dev20250508.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
- ai_edge_torch_nightly-0.5.0.dev20250508.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
- ai_edge_torch_nightly-0.5.0.dev20250508.dist-info/RECORD,,
+ ai_edge_torch_nightly-0.5.0.dev20250510.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
+ ai_edge_torch_nightly-0.5.0.dev20250510.dist-info/METADATA,sha256=1fA2DwpzLkPWBS-gV86ik7v9m39lO_RUaU4k7qAEvkM,2074
+ ai_edge_torch_nightly-0.5.0.dev20250510.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
+ ai_edge_torch_nightly-0.5.0.dev20250510.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
+ ai_edge_torch_nightly-0.5.0.dev20250510.dist-info/RECORD,,