ai-edge-torch-nightly 0.5.0.dev20250513__py3-none-any.whl → 0.5.0.dev20250515__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -19,11 +19,19 @@ from absl import app
19
19
  from ai_edge_torch.generative.examples.gemma3 import gemma3
20
20
  from ai_edge_torch.generative.utilities import converter
21
21
  from ai_edge_torch.generative.utilities import export_config
22
+ from ai_edge_torch.generative.utilities import loader
22
23
 
23
24
  flags = converter.define_conversion_flags(
24
25
  'gemma3-1b', default_mask_as_input=True, default_transpose_kv_cache=True
25
26
  )
26
27
 
28
+ _CUSTOM_CHECKPOINT_LOADER = flags.DEFINE_bool(
29
+ 'custom_checkpoint_loader',
30
+ False,
31
+ 'If true, the conversion script will use a custom checkpoint loader which'
32
+ ' will read a checkpoint from a remote source.',
33
+ )
34
+
27
35
  _MODEL_SIZE = flags.DEFINE_string(
28
36
  'model_size',
29
37
  '1b',
@@ -32,10 +40,16 @@ _MODEL_SIZE = flags.DEFINE_string(
32
40
 
33
41
 
34
42
  def main(_):
43
+ custom_loader = None
44
+ if flags.FLAGS.custom_checkpoint_loader:
45
+ # If loading from a remote source, try to get a custom loader first.
46
+ custom_loader = loader.get_custom_loader(flags.FLAGS.checkpoint_path)
47
+
35
48
  if _MODEL_SIZE.value == '1b':
36
49
  pytorch_model = gemma3.build_model_1b(
37
50
  flags.FLAGS.checkpoint_path,
38
51
  kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
52
+ custom_loader=custom_loader,
39
53
  )
40
54
  else:
41
55
  raise ValueError(f'Unsupported model size: {_MODEL_SIZE.value}')
@@ -15,7 +15,7 @@
15
15
 
16
16
  """Example of building a Decoder for Gemma3 model."""
17
17
 
18
- from typing import List, Optional, Tuple
18
+ from typing import Callable, Dict, List, Optional, Tuple
19
19
 
20
20
  from ai_edge_torch.generative.layers import attention
21
21
  from ai_edge_torch.generative.layers import builder
@@ -410,7 +410,11 @@ def get_fake_decoder_config_1b(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
410
410
  return config
411
411
 
412
412
 
413
- def build_model_1b(checkpoint_path: str, **kwargs) -> nn.Module:
413
+ def build_model_1b(
414
+ checkpoint_path: str,
415
+ custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
416
+ **kwargs,
417
+ ) -> nn.Module:
414
418
  # TODO(b/403644647): Better error handling for loading checkpoints with
415
419
  # different tensor names.
416
420
  for tensor_names in TENSOR_NAMES_DICT.values():
@@ -420,6 +424,7 @@ def build_model_1b(checkpoint_path: str, **kwargs) -> nn.Module:
420
424
  config=get_decoder_config_1b(**kwargs),
421
425
  tensor_names=tensor_names,
422
426
  model_class=Decoder,
427
+ custom_loader=custom_loader,
423
428
  )
424
429
  except KeyError as ke:
425
430
  continue
@@ -16,8 +16,7 @@
16
16
  """Example of building a Gemma3 gpu model."""
17
17
 
18
18
  from dataclasses import dataclass
19
- from typing import List, Optional, Tuple
20
- import xmlrpc
19
+ from typing import List, Optional, Tuple, Callable, Dict
21
20
 
22
21
  from ai_edge_torch.generative.examples.gemma3 import decoder
23
22
  from ai_edge_torch.generative.examples.gemma3 import image_encoder
@@ -166,9 +165,14 @@ def get_fake_model_config(**kwargs) -> Gemma3MMConfig:
166
165
  mm_extra_tokens=32,
167
166
  )
168
167
 
169
- def build_model_1b(checkpoint_path: str, **kwargs) -> decoder.Decoder:
168
+
169
+ def build_model_1b(
170
+ checkpoint_path: str,
171
+ custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
172
+ **kwargs,
173
+ ) -> decoder.Decoder:
170
174
  if checkpoint_path:
171
- model = decoder.build_model_1b(checkpoint_path, **kwargs)
175
+ model = decoder.build_model_1b(checkpoint_path, custom_loader, **kwargs)
172
176
  else:
173
177
  config = decoder.get_decoder_config_1b(**kwargs)
174
178
  model = decoder.Decoder(config)
@@ -17,7 +17,7 @@
17
17
 
18
18
  import logging
19
19
  import os
20
- from typing import List, Optional, Tuple
20
+ from typing import Callable, Dict, List, Optional, Tuple
21
21
 
22
22
  from ai_edge_torch.generative.examples.gemma3 import gemma3
23
23
  from ai_edge_torch.generative.layers import kv_cache as kv_utils
@@ -167,6 +167,7 @@ def verify_reauthored_gemma_model(
167
167
  generate_prompts: List[str],
168
168
  forward_input_ids: List[List[int]],
169
169
  weight_filename: str,
170
+ custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
170
171
  tokenizer_filename: str = "tokenizer.model",
171
172
  max_new_tokens: int = 20,
172
173
  rtol: float = 1e-05,
@@ -196,7 +197,14 @@ def verify_reauthored_gemma_model(
196
197
 
197
198
  logging.info("Loading the original model from: %s", checkpoint)
198
199
  original_model = gemma_model.GemmaForCausalLM(config).eval()
199
- original_model.load_weights(os.path.join(checkpoint, weight_filename))
200
+ checkpoint_path = os.path.join(checkpoint, weight_filename)
201
+ if custom_loader is None:
202
+ original_model.load_weights(checkpoint_path)
203
+ else:
204
+ original_model.load_state_dict(
205
+ custom_loader(checkpoint_path)["model_state_dict"],
206
+ strict=False,
207
+ )
200
208
 
201
209
  return verifier.verify_reauthored_model(
202
210
  original_model=GemmaWrapper(original_model),
@@ -216,6 +224,7 @@ def verify_gemma3(
216
224
  max_new_tokens: int,
217
225
  variant: str,
218
226
  weight_filename: str,
227
+ custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
219
228
  ) -> bool:
220
229
  """Verifies the reauthored Gemma3 model.
221
230
 
@@ -225,6 +234,7 @@ def verify_gemma3(
225
234
  max_new_tokens: Maximum number of new tokens to generate.
226
235
  variant: Gemma model variant.
227
236
  weight_filename: Name of the weight file.
237
+ custom_loader: A custom loader to load the weights.
228
238
 
229
239
  Returns:
230
240
  True if the verification passes, False otherwise.
@@ -234,7 +244,7 @@ def verify_gemma3(
234
244
 
235
245
  if variant == "1b":
236
246
  reauthored_model = UnifiedGemma3Wrapper(
237
- gemma3.build_model_1b(gemma3_model_path)
247
+ gemma3.build_model_1b(gemma3_model_path, custom_loader)
238
248
  )
239
249
  else:
240
250
  raise ValueError(f"Unsupported Gemma3 variant: {variant}")
@@ -247,5 +257,6 @@ def verify_gemma3(
247
257
  forward_input_ids=[[2, 651, 9456, 576, 573, 3520, 3858, 603, 235248]],
248
258
  max_new_tokens=max_new_tokens,
249
259
  weight_filename=weight_filename,
260
+ custom_loader=custom_loader,
250
261
  atol=1e-04,
251
262
  )
@@ -22,8 +22,6 @@ from ai_edge_torch.generative.utilities import export_config
22
22
  import torch
23
23
 
24
24
  flags = converter.define_conversion_flags('paligemma2-3b-224')
25
- ExportConfig = export_config.ExportConfig
26
-
27
25
 
28
26
  _VERSION = flags.DEFINE_enum(
29
27
  'version',
@@ -32,6 +30,7 @@ _VERSION = flags.DEFINE_enum(
32
30
  'The version of PaliGemma model to verify.',
33
31
  )
34
32
 
33
+
35
34
  def main(_):
36
35
  pytorch_model = paligemma.build_model(
37
36
  flags.FLAGS.checkpoint_path,
@@ -51,7 +50,7 @@ def main(_):
51
50
  pixel_seq_len=(config.image_size // config.patch_size) ** 2,
52
51
  quantize=flags.FLAGS.quantize,
53
52
  config=pytorch_model.config.decoder_config,
54
- export_config=ExportConfig(),
53
+ export_config=export_config.get_from_flags(),
55
54
  )
56
55
 
57
56
 
@@ -21,8 +21,6 @@ from ai_edge_torch.generative.utilities import converter
21
21
  from ai_edge_torch.generative.utilities import export_config
22
22
 
23
23
  flags = converter.define_conversion_flags('qwen_vl')
24
- ExportConfig = export_config.ExportConfig
25
-
26
24
 
27
25
  _IMAGE_HEIGHT = flags.DEFINE_integer(
28
26
  'image_height',
@@ -35,6 +33,7 @@ _IMAGE_WIDTH = flags.DEFINE_integer(
35
33
  'The width of image.',
36
34
  )
37
35
 
36
+
38
37
  def main(_):
39
38
  pytorch_model = qwen_vl.build_model(
40
39
  flags.FLAGS.checkpoint_path,
@@ -60,7 +59,7 @@ def main(_):
60
59
  ),
61
60
  quantize=flags.FLAGS.quantize,
62
61
  config=pytorch_model.config.decoder_config,
63
- export_config=ExportConfig(),
62
+ export_config=export_config.get_from_flags(),
64
63
  )
65
64
 
66
65
 
@@ -28,6 +28,8 @@ class RMSNorm(torch.nn.Module):
28
28
  dim: int,
29
29
  eps: float = 1e-6,
30
30
  zero_centered_gamma=False,
31
+ with_scale: bool = False,
32
+ scale_shift: float = 1.0,
31
33
  enable_hlfb: bool = False,
32
34
  ):
33
35
  """Initialize the RMSNorm layer.
@@ -37,13 +39,22 @@ class RMSNorm(torch.nn.Module):
37
39
  eps (float): A small float value to ensure numerical stability (default:
38
40
  1e-6).
39
41
  zero_centered_gamma (bool): Whether or not gamma has an offset.
42
+ with_scale (bool): Whether or not to use a scale parameter.
43
+ scale_shift (float): The shift to apply to the scale parameter.
40
44
  enable_hlfb (bool): use HLFB in the op.
41
45
  """
42
46
  super().__init__()
47
+ self.dim = dim
43
48
  self.enable_hlfb = enable_hlfb
44
49
  self.eps = eps
45
- self.weight = torch.nn.Parameter(torch.ones(dim))
50
+ self.weight = torch.nn.Parameter(torch.ones(dim), requires_grad=False)
46
51
  self.zero_centered_gamma = zero_centered_gamma
52
+ self.with_scale = with_scale
53
+ if with_scale:
54
+ self.scale = torch.nn.Parameter(
55
+ torch.zeros((dim,), dtype=torch.float32), requires_grad=False
56
+ )
57
+ self.scale_shift = scale_shift
47
58
 
48
59
  def _norm(self, x):
49
60
  """Apply RMSNorm normalization.
@@ -70,14 +81,20 @@ class RMSNorm(torch.nn.Module):
70
81
  else:
71
82
  w = self.weight
72
83
 
84
+ final_scale = (
85
+ self.scale + self.scale_shift
86
+ if self.with_scale
87
+ else torch.ones((self.dim,), dtype=torch.float32)
88
+ )
73
89
  if self.enable_hlfb:
74
90
  return rms_norm_with_hlfb(
75
91
  x,
76
92
  w,
77
93
  self.eps,
94
+ final_scale,
78
95
  )
79
96
  else:
80
- output = self._norm(x.float()).type_as(x)
97
+ output = self._norm(x.float()).type_as(x) * final_scale
81
98
  return output * w
82
99
 
83
100
 
@@ -104,8 +121,8 @@ class GroupNorm(torch.nn.Module):
104
121
  self.enable_hlfb = enable_hlfb
105
122
  self.group_num = group_num
106
123
  self.eps = eps
107
- self.weight = torch.nn.Parameter(torch.empty(dim))
108
- self.bias = torch.nn.Parameter(torch.empty(dim))
124
+ self.weight = torch.nn.Parameter(torch.empty(dim), requires_grad=False)
125
+ self.bias = torch.nn.Parameter(torch.empty(dim), requires_grad=False)
109
126
 
110
127
  def forward(self, x):
111
128
  """Running the forward pass of GroupNorm layer.
@@ -140,8 +157,8 @@ class LayerNorm(torch.nn.Module):
140
157
  self.enable_hlfb = enable_hlfb
141
158
  self.normalized_shape = (dim,)
142
159
  self.eps = eps
143
- self.weight = torch.nn.Parameter(torch.empty(dim))
144
- self.bias = torch.nn.Parameter(torch.empty(dim))
160
+ self.weight = torch.nn.Parameter(torch.empty(dim), requires_grad=False)
161
+ self.bias = torch.nn.Parameter(torch.empty(dim), requires_grad=False)
145
162
 
146
163
  def forward(self, x):
147
164
  """Running the forward pass of LayerNorm layer.
@@ -165,6 +182,7 @@ def rms_norm_with_hlfb(
165
182
  x: torch.Tensor,
166
183
  w: torch.Tensor,
167
184
  eps: float,
185
+ final_scale: torch.Tensor,
168
186
  ):
169
187
  """RMS Normalization with high-level function boundary enabled.
170
188
 
@@ -172,6 +190,7 @@ def rms_norm_with_hlfb(
172
190
  x (torch.Tensor): Input tensor for RMS Normalization, with BCHW shape.
173
191
  w (torch.Tensor): The learned parameter tensor for normalization.
174
192
  eps (float): A small float value to ensure numerical stability.
193
+ final_scale (torch.Tensor): The final scale to apply to the normalization.
175
194
 
176
195
  Returns:
177
196
  The output tensor of RMS Normalization.
@@ -185,7 +204,7 @@ def rms_norm_with_hlfb(
185
204
  def _norm(x):
186
205
  return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
187
206
 
188
- output = _norm(x.float()).type_as(x)
207
+ output = _norm(x.float()).type_as(x) * final_scale
189
208
  out = output * w
190
209
 
191
210
  out = builder.mark_outputs(out)
@@ -0,0 +1,73 @@
1
+ # Copyright 2025 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Tests for normalization layers."""
16
+
17
+ from ai_edge_torch.generative.layers import normalization
18
+ import torch
19
+ from absl.testing import absltest as googletest
20
+ from absl.testing import parameterized
21
+
22
+
23
+ class NormalizationTest(parameterized.TestCase):
24
+
25
+ @parameterized.named_parameters(
26
+ dict(
27
+ testcase_name="rms_norm_test_1",
28
+ model_dim=10,
29
+ with_scale=False,
30
+ scale_shift=1.0,
31
+ enable_hlfb=False,
32
+ expected_values=torch.ones((10,), dtype=torch.float32),
33
+ ),
34
+ dict(
35
+ testcase_name="rms_norm_test_2",
36
+ model_dim=10,
37
+ with_scale=True,
38
+ scale_shift=2.0,
39
+ enable_hlfb=False,
40
+ expected_values=torch.ones((10,), dtype=torch.float32) * 2.0,
41
+ ),
42
+ dict(
43
+ testcase_name="rms_norm_test_3",
44
+ model_dim=10,
45
+ with_scale=True,
46
+ scale_shift=2.0,
47
+ enable_hlfb=True,
48
+ expected_values=torch.ones((10,), dtype=torch.float32) * 2.0,
49
+ ),
50
+ )
51
+ def test_rms_norm(
52
+ self,
53
+ model_dim: int,
54
+ with_scale: bool,
55
+ scale_shift: float,
56
+ enable_hlfb: bool,
57
+ expected_values: torch.Tensor,
58
+ ):
59
+ rms_norm = normalization.RMSNorm(
60
+ dim=model_dim,
61
+ with_scale=with_scale,
62
+ scale_shift=scale_shift,
63
+ enable_hlfb=enable_hlfb,
64
+ )
65
+
66
+ x = torch.ones((model_dim,), dtype=torch.float32)
67
+ out = rms_norm(x)
68
+ self.assertEqual(out.shape, (model_dim,))
69
+ self.assertTrue(torch.allclose(out, expected_values))
70
+
71
+
72
+ if __name__ == "__main__":
73
+ googletest.main()
@@ -280,6 +280,15 @@ def convert_to_tflite(
280
280
  '' if not lora_ranks else f'_lora{",".join(map(str, lora_ranks))}'
281
281
  )
282
282
 
283
+ if pixel_values_size is not None:
284
+ assert pixel_seq_len > 0, 'pixel_seq_len must be greater than 0'
285
+ max_prefill_seq_len = max(prefill_seq_lens)
286
+ assert kv_size > max_prefill_seq_len + pixel_seq_len, (
287
+ f'The KV cache size ({kv_size}) must be greater than the maximum '
288
+ f'prefill sequence length ({max_prefill_seq_len}) + pixel sequence '
289
+ f'length ({pixel_seq_len})'
290
+ )
291
+
283
292
  if export_config is not None:
284
293
  if export_config.decode_batch_size > 1:
285
294
  output_name_prefix += f'_dbs{export_config.decode_batch_size}'
@@ -19,10 +19,36 @@ import os
19
19
  from typing import Callable, Dict, List, Tuple
20
20
 
21
21
  from ai_edge_torch.generative.layers import model_config
22
+ import safetensors
22
23
  from safetensors import safe_open
23
24
  import torch
24
25
 
25
26
 
27
+ def get_custom_loader(
28
+ checkpoint_path: str,
29
+ ) -> Callable[[str], Dict[str, torch.Tensor]]:
30
+ """Returns a custom loader for the given checkpoint path.
31
+
32
+ Those customer loaders can either support state dictionary or safetensors, and
33
+ the actual data might be fetched from a remote source.
34
+
35
+ Args:
36
+ checkpoint_path (string): The path to the checkpoint.
37
+
38
+ Returns:
39
+ Callable[[str], Dict[str, torch.Tensor]]: The custom loader.
40
+
41
+ Raises:
42
+ ValueError: If the checkpoint format is not supported.
43
+ """
44
+
45
+ if os.path.splitext(checkpoint_path)[1] in [".bin", ".pt", ".ckpt"]:
46
+ return lambda path: torch.load(path, weights_only=True)
47
+ if checkpoint_path.endswith(".safetensors"):
48
+ return safetensors.torch.load_file
49
+ raise ValueError(f"Unsupported checkpoint format: {checkpoint_path}")
50
+
51
+
26
52
  def load_safetensors(full_path: str):
27
53
  """Loads safetensors into a single state dictionary.
28
54
 
@@ -117,7 +143,12 @@ class ModelLoader:
117
143
  final_norm: str = None
118
144
  lm_head: str = None
119
145
 
120
- def __init__(self, file_name: str, names: TensorNames) -> None:
146
+ def __init__(
147
+ self,
148
+ file_name: str,
149
+ names: TensorNames,
150
+ custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
151
+ ) -> None:
121
152
  """ModelLoader constructor.
122
153
 
123
154
  Can be used to load multiple models of the same type.
@@ -126,10 +157,15 @@ class ModelLoader:
126
157
  file_name (str): Path to the checkpoint. Can be a directory or an exact
127
158
  file.
128
159
  names (TensorNames): An instance of `TensorNames` to determine mappings.
160
+ custom_loader (Callable[[str], Dict[str, torch.Tensor]]): A custom
161
+ loader to be used. If not provided, the class will determine a proper
162
+ loader.
129
163
  """
130
164
  self._file_name = file_name
131
165
  self._names = names
132
- self._loader = self._get_loader()
166
+ self._loader = (
167
+ custom_loader if custom_loader is not None else self._get_loader()
168
+ )
133
169
 
134
170
  def get_state(self) -> Dict[str, torch.Tensor]:
135
171
  return self._loader(self._file_name)
@@ -16,6 +16,7 @@
16
16
  """Utilities to be used for re-authoring transformer models."""
17
17
 
18
18
  import copy
19
+ from typing import Callable, Dict
19
20
  from typing import Optional, Tuple
20
21
 
21
22
  from ai_edge_torch.generative.layers import attention
@@ -160,9 +161,12 @@ def build_decoder_only_model(
160
161
  config: cfg.ModelConfig,
161
162
  tensor_names: loading_utils.ModelLoader.TensorNames,
162
163
  model_class: type[nn.Module] = DecoderOnlyModel,
164
+ custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
163
165
  ) -> nn.Module:
164
166
  transformer = model_class(config)
165
- loader = loading_utils.ModelLoader(checkpoint_path, tensor_names)
167
+ loader = loading_utils.ModelLoader(
168
+ checkpoint_path, tensor_names, custom_loader
169
+ )
166
170
  loader.load(
167
171
  transformer, strict=not config.lm_head_share_weight_with_embedding
168
172
  )
ai_edge_torch/version.py CHANGED
@@ -13,4 +13,4 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- __version__ = "0.5.0.dev20250513"
16
+ __version__ = "0.5.0.dev20250515"
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: ai-edge-torch-nightly
3
- Version: 0.5.0.dev20250513
3
+ Version: 0.5.0.dev20250515
4
4
  Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
5
5
  Home-page: https://github.com/google-ai-edge/ai-edge-torch
6
6
  Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
@@ -2,7 +2,7 @@ ai_edge_torch/__init__.py,sha256=8sPR_5uXJA4NEE0nIwNdSl-ADOJEoR8hAgYvBQDY70Y,120
2
2
  ai_edge_torch/_config.py,sha256=AiqhbcheF7j_ozIGDLC89k1we95aVgFDa-tR6h7UI0s,2529
3
3
  ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
4
4
  ai_edge_torch/model.py,sha256=wxjSFq_rBSxSqbUE8E8EJTCkgvgaRLjq_ZuAM-IZpCU,5606
5
- ai_edge_torch/version.py,sha256=Q2u2GS0KjqxWhznlOZBgkCi4NAQcdpjJzkUYdcGYQ5o,706
5
+ ai_edge_torch/version.py,sha256=QVmEdwoLJem1gNQul_CoRyfqOc1Ljjy48x9GmKmuAOU,706
6
6
  ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
7
7
  ai_edge_torch/_convert/conversion.py,sha256=QVugYVfbyaeBgSKKbhFzHG5oXA7t3M-40JcpcdSu6W8,5436
8
8
  ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -68,12 +68,12 @@ ai_edge_torch/generative/examples/gemma/verify_gemma1.py,sha256=ip-Gmk4CI5f0GWSd
68
68
  ai_edge_torch/generative/examples/gemma/verify_gemma2.py,sha256=jhiyinOqPt5ZZjEadDRZt_wY5fiLSCpMo54PcxFaL_Q,1789
69
69
  ai_edge_torch/generative/examples/gemma/verify_util.py,sha256=n7f2nF6Lin_tDvPs0JVldsuaBzo7pAwi5YAHAhlIxQg,6139
70
70
  ai_edge_torch/generative/examples/gemma3/__init__.py,sha256=JaAnrFoXTl3RJX97XspklkTyqOHVyAgRJsZtzNDd10c,671
71
- ai_edge_torch/generative/examples/gemma3/convert_gemma3_to_tflite.py,sha256=JLXXn2mFEBs4DlHH_O6hpEG9KInJqsCdWy3DrgUjT1c,1827
72
- ai_edge_torch/generative/examples/gemma3/decoder.py,sha256=v0ZcKrIAvERQLb1wK1Vc_ewWWVZgJFUdRTyoVY0Lfus,14955
73
- ai_edge_torch/generative/examples/gemma3/gemma3.py,sha256=GACDBI_MsFowR8A3wAWrpzradPYe-AUgB9ZjXaVBG-s,6485
71
+ ai_edge_torch/generative/examples/gemma3/convert_gemma3_to_tflite.py,sha256=SsiK9xKCyboi5y-HdoFSN02QxRo0XabyzotUq46zO0E,2357
72
+ ai_edge_torch/generative/examples/gemma3/decoder.py,sha256=shdgLzKDUi0vyNOAsrIVAEFb3Adltsri6Rx1-wxzVf4,15089
73
+ ai_edge_torch/generative/examples/gemma3/gemma3.py,sha256=ZorRtnbElWsctcA0nEbfwjx0C578voF7fjFEvWSR5Ck,6582
74
74
  ai_edge_torch/generative/examples/gemma3/image_encoder.py,sha256=uRoLoBWzFtQz5wFZfPCxbkvZsgPAqSkUUsV3977GbYc,5184
75
75
  ai_edge_torch/generative/examples/gemma3/verify_gemma3.py,sha256=v8oNXFICmVOtQxfO7IhZ8GnbvotEkDi9lzYHjoQyOso,2464
76
- ai_edge_torch/generative/examples/gemma3/verify_util.py,sha256=KnE9ME3mrpQkAxFlBOJLsqcQkjsdDL1ClNhJahX5K5I,8960
76
+ ai_edge_torch/generative/examples/gemma3/verify_util.py,sha256=1vfAtayH_I_qTpqhzu6n9xnCuvhgTzhS8IzZviW2dJQ,9418
77
77
  ai_edge_torch/generative/examples/hammer/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
78
78
  ai_edge_torch/generative/examples/hammer/convert_to_tflite.py,sha256=9r8LXyaoBXYIIhhe1WQgEIjaxALQPE1dO2N6qopyWCk,1753
79
79
  ai_edge_torch/generative/examples/hammer/hammer.py,sha256=76INcjffvaNCQ02fzXcxJUW_6EKHs4sg3q1nDBbEpHE,3431
@@ -90,7 +90,7 @@ ai_edge_torch/generative/examples/openelm/convert_to_tflite.py,sha256=wRdT7bWbCX
90
90
  ai_edge_torch/generative/examples/openelm/openelm.py,sha256=hPcXYHj-nBP56TOeQQejB3HRzv6yHSftHOx0OEPP5M8,4574
91
91
  ai_edge_torch/generative/examples/openelm/verify.py,sha256=4W26ZtPF5Cb9mpHYuRM4b2QB_4W76zf4WV36KzexVjs,2446
92
92
  ai_edge_torch/generative/examples/paligemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
93
- ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py,sha256=7HHXkC-IIu7ieBvBI4RlXs_oITz7R8a6YVYQskAs_Uk,2023
93
+ ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py,sha256=fkP-mWrih1s-vgJ41fLt8v5JE-UOs8Zrngh6ElQ6PMw,1997
94
94
  ai_edge_torch/generative/examples/paligemma/decoder.py,sha256=-EYUZp55dfRY1E-N0Pr3b9i5c7Tt1XvYxvsRixguVS8,5527
95
95
  ai_edge_torch/generative/examples/paligemma/decoder2.py,sha256=WB8r-e_Crog1ItBq3Zse_nUG-foFyBcJsuEG26r_Ji8,6076
96
96
  ai_edge_torch/generative/examples/paligemma/image_encoder.py,sha256=SvuR97sjkBtfkerH7Hu1UXB8kCFLpEATNbPfCbNAyfo,5614
@@ -114,7 +114,7 @@ ai_edge_torch/generative/examples/qwen/convert_to_tflite.py,sha256=eOpv3scJr4mVs
114
114
  ai_edge_torch/generative/examples/qwen/qwen.py,sha256=m8APYzo9N0SXsdvCxC8HtCcbN3W7gLKkRBL-Tg0BWXU,4223
115
115
  ai_edge_torch/generative/examples/qwen/verify.py,sha256=9_AyEJTeUfvhhID64Rto2bflFPyXMFokdQLsseLUMiI,2775
116
116
  ai_edge_torch/generative/examples/qwen_vl/__init__.py,sha256=JaAnrFoXTl3RJX97XspklkTyqOHVyAgRJsZtzNDd10c,671
117
- ai_edge_torch/generative/examples/qwen_vl/convert_to_tflite.py,sha256=yVebRatt2SLCsGvrYTBXOM-0S2REhkpikHTyy5MCjUw,2222
117
+ ai_edge_torch/generative/examples/qwen_vl/convert_to_tflite.py,sha256=4Gntv6LBIxd0CaKkb-koLzGTdBEOGgVf3ob99lAuvuY,2196
118
118
  ai_edge_torch/generative/examples/qwen_vl/decoder.py,sha256=7RFM25tDj_b0FkpSv8RUWir8K8v9p2jMtwZmP4VAUhw,4474
119
119
  ai_edge_torch/generative/examples/qwen_vl/image_encoder.py,sha256=nHzBe_YSPnUe1d5i09v4bePQomVifzJNeUjRfprmxC0,14878
120
120
  ai_edge_torch/generative/examples/qwen_vl/qwen_vl.py,sha256=mfLFrT8NPEPh9CqlJYHwh-I2y6ST7hH_vEmbZYartHQ,7764
@@ -166,7 +166,8 @@ ai_edge_torch/generative/layers/feed_forward_test.py,sha256=8ZGy79BBpsyS6yKKDEKr
166
166
  ai_edge_torch/generative/layers/kv_cache.py,sha256=b-7shzDaKexmvQF7P3SiAmIz4ZofjYWv3m5u71GojsA,10460
167
167
  ai_edge_torch/generative/layers/lora.py,sha256=hsvWLLOnW7HQ0AysOZu30x_cetMquDd1tjfyLz8HCSU,17892
168
168
  ai_edge_torch/generative/layers/model_config.py,sha256=X_gjN5524DCDBNXsX5GrOBlkKM4UHzj_RfdCD0-VOxQ,8572
169
- ai_edge_torch/generative/layers/normalization.py,sha256=MbwH-n80Fob5YvjBzdqDjBizMHLzSJGYRDdbD-rL5C0,6174
169
+ ai_edge_torch/generative/layers/normalization.py,sha256=ijwCpi22NLX-Sygwy5sK9l9WjGvbPIhZvVwoBAonWAo,7014
170
+ ai_edge_torch/generative/layers/normalization_test.py,sha256=zwurZly-TgFxdgVVdpzu9vCpcLbd5RYt_gKg9Lfg1jI,2248
170
171
  ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=975zR202MdIrILJ7blceAcxrNqX1ZCN0ECKG1gz-bV8,2655
171
172
  ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=2_AgwENsaOgaxgiSqgoj0V0JzQ09dFtP_nBhX-lJK2g,5648
172
173
  ai_edge_torch/generative/layers/scaled_dot_product_attention_test.py,sha256=c6JBMQsq9XeMmR1XvGEIidNsoh-YIvichXo2LwVHgr4,3301
@@ -192,10 +193,10 @@ ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=vQWmpzMkJ2hP
192
193
  ai_edge_torch/generative/test/test_quantize.py,sha256=kKJ01wscTC2t_Ylr7huO5gNKES01gm3dT1gx52z15PA,7356
193
194
  ai_edge_torch/generative/test/utils.py,sha256=tF6aCfAGJnc9dmzCnZCEOuKNVimfWOqscv9og0DDLHU,2656
194
195
  ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
195
- ai_edge_torch/generative/utilities/converter.py,sha256=OMBy_nos9mEGMQOAD8o0on-gAkRk-kliodFSTthD5BE,14612
196
+ ai_edge_torch/generative/utilities/converter.py,sha256=4zcDlhgCQQyLylH8NLgVjnelou2pW6HWJHBFYsFyHuw,15020
196
197
  ai_edge_torch/generative/utilities/export_config.py,sha256=5IvR3grlMd4mWO5c_Y4x9Fk1b1xa57MzlYNE8XUaN28,2049
197
- ai_edge_torch/generative/utilities/loader.py,sha256=7p__m2JryWphGlYOuRxdoT4id4_tWJEVOV7y2X4H-Ak,13737
198
- ai_edge_torch/generative/utilities/model_builder.py,sha256=IG-88o7nWI9XrNDnwnQ-MoilsuqJ7KwrnbP3bn2EY9U,6334
198
+ ai_edge_torch/generative/utilities/loader.py,sha256=tSiew77hB_zyn6rpcfegSg1zrriqHSz63KjV9_llBxg,14893
199
+ ai_edge_torch/generative/utilities/model_builder.py,sha256=tBfOcsI_NcneggHqkCSydYN3ZgmkzPc6nW0AJrA81wI,6461
199
200
  ai_edge_torch/generative/utilities/moonshine_loader.py,sha256=_RpFabSqtGH5PHiP3_1f6QfO14qMADUxr_HGRlVDFB0,4891
200
201
  ai_edge_torch/generative/utilities/stable_diffusion_loader.py,sha256=dqPD9qRXEWtU3ombslOC-BE2l_dMwHoCNu7NsIJhsso,36158
201
202
  ai_edge_torch/generative/utilities/t5_loader.py,sha256=tEsfy8-ymzbbjOIc-oesXF3yGyyWtJgFXn2s7VOavt8,16961
@@ -251,8 +252,8 @@ ai_edge_torch/testing/__init__.py,sha256=_yGgvnBZWb7T3IN3mc4x1sS4vM96HZwM8pwIcPG
251
252
  ai_edge_torch/testing/export.py,sha256=k5mGDGzwc23Z4zaIVDs8CNh-oOt64gsf9MS9NjhbPy4,3293
252
253
  ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
253
254
  ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
254
- ai_edge_torch_nightly-0.5.0.dev20250513.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
255
- ai_edge_torch_nightly-0.5.0.dev20250513.dist-info/METADATA,sha256=qh5r3x7C0ksa3D2WriWd0yePFgxK8urh9aSsJCC_gjY,2074
256
- ai_edge_torch_nightly-0.5.0.dev20250513.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
257
- ai_edge_torch_nightly-0.5.0.dev20250513.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
258
- ai_edge_torch_nightly-0.5.0.dev20250513.dist-info/RECORD,,
255
+ ai_edge_torch_nightly-0.5.0.dev20250515.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
256
+ ai_edge_torch_nightly-0.5.0.dev20250515.dist-info/METADATA,sha256=FmCPouaJYszNPCOfgIx8WGFkGv5LrqR6_OGpciU2eKc,2074
257
+ ai_edge_torch_nightly-0.5.0.dev20250515.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
258
+ ai_edge_torch_nightly-0.5.0.dev20250515.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
259
+ ai_edge_torch_nightly-0.5.0.dev20250515.dist-info/RECORD,,