ai-edge-torch-nightly 0.5.0.dev20250514__py3-none-any.whl → 0.5.0.dev20250515__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -19,11 +19,19 @@ from absl import app
19
19
  from ai_edge_torch.generative.examples.gemma3 import gemma3
20
20
  from ai_edge_torch.generative.utilities import converter
21
21
  from ai_edge_torch.generative.utilities import export_config
22
+ from ai_edge_torch.generative.utilities import loader
22
23
 
23
24
  flags = converter.define_conversion_flags(
24
25
  'gemma3-1b', default_mask_as_input=True, default_transpose_kv_cache=True
25
26
  )
26
27
 
28
+ _CUSTOM_CHECKPOINT_LOADER = flags.DEFINE_bool(
29
+ 'custom_checkpoint_loader',
30
+ False,
31
+ 'If true, the conversion script will use a custom checkpoint loader which'
32
+ ' will read a checkpoint from a remote source.',
33
+ )
34
+
27
35
  _MODEL_SIZE = flags.DEFINE_string(
28
36
  'model_size',
29
37
  '1b',
@@ -32,10 +40,16 @@ _MODEL_SIZE = flags.DEFINE_string(
32
40
 
33
41
 
34
42
  def main(_):
43
+ custom_loader = None
44
+ if flags.FLAGS.custom_checkpoint_loader:
45
+ # If loading from a remote source, try to get a custom loader first.
46
+ custom_loader = loader.get_custom_loader(flags.FLAGS.checkpoint_path)
47
+
35
48
  if _MODEL_SIZE.value == '1b':
36
49
  pytorch_model = gemma3.build_model_1b(
37
50
  flags.FLAGS.checkpoint_path,
38
51
  kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
52
+ custom_loader=custom_loader,
39
53
  )
40
54
  else:
41
55
  raise ValueError(f'Unsupported model size: {_MODEL_SIZE.value}')
@@ -28,6 +28,8 @@ class RMSNorm(torch.nn.Module):
28
28
  dim: int,
29
29
  eps: float = 1e-6,
30
30
  zero_centered_gamma=False,
31
+ with_scale: bool = False,
32
+ scale_shift: float = 1.0,
31
33
  enable_hlfb: bool = False,
32
34
  ):
33
35
  """Initialize the RMSNorm layer.
@@ -37,13 +39,22 @@ class RMSNorm(torch.nn.Module):
37
39
  eps (float): A small float value to ensure numerical stability (default:
38
40
  1e-6).
39
41
  zero_centered_gamma (bool): Whether or not gamma has an offset.
42
+ with_scale (bool): Whether or not to use a scale parameter.
43
+ scale_shift (float): The shift to apply to the scale parameter.
40
44
  enable_hlfb (bool): use HLFB in the op.
41
45
  """
42
46
  super().__init__()
47
+ self.dim = dim
43
48
  self.enable_hlfb = enable_hlfb
44
49
  self.eps = eps
45
- self.weight = torch.nn.Parameter(torch.ones(dim))
50
+ self.weight = torch.nn.Parameter(torch.ones(dim), requires_grad=False)
46
51
  self.zero_centered_gamma = zero_centered_gamma
52
+ self.with_scale = with_scale
53
+ if with_scale:
54
+ self.scale = torch.nn.Parameter(
55
+ torch.zeros((dim,), dtype=torch.float32), requires_grad=False
56
+ )
57
+ self.scale_shift = scale_shift
47
58
 
48
59
  def _norm(self, x):
49
60
  """Apply RMSNorm normalization.
@@ -70,14 +81,20 @@ class RMSNorm(torch.nn.Module):
70
81
  else:
71
82
  w = self.weight
72
83
 
84
+ final_scale = (
85
+ self.scale + self.scale_shift
86
+ if self.with_scale
87
+ else torch.ones((self.dim,), dtype=torch.float32)
88
+ )
73
89
  if self.enable_hlfb:
74
90
  return rms_norm_with_hlfb(
75
91
  x,
76
92
  w,
77
93
  self.eps,
94
+ final_scale,
78
95
  )
79
96
  else:
80
- output = self._norm(x.float()).type_as(x)
97
+ output = self._norm(x.float()).type_as(x) * final_scale
81
98
  return output * w
82
99
 
83
100
 
@@ -104,8 +121,8 @@ class GroupNorm(torch.nn.Module):
104
121
  self.enable_hlfb = enable_hlfb
105
122
  self.group_num = group_num
106
123
  self.eps = eps
107
- self.weight = torch.nn.Parameter(torch.empty(dim))
108
- self.bias = torch.nn.Parameter(torch.empty(dim))
124
+ self.weight = torch.nn.Parameter(torch.empty(dim), requires_grad=False)
125
+ self.bias = torch.nn.Parameter(torch.empty(dim), requires_grad=False)
109
126
 
110
127
  def forward(self, x):
111
128
  """Running the forward pass of GroupNorm layer.
@@ -140,8 +157,8 @@ class LayerNorm(torch.nn.Module):
140
157
  self.enable_hlfb = enable_hlfb
141
158
  self.normalized_shape = (dim,)
142
159
  self.eps = eps
143
- self.weight = torch.nn.Parameter(torch.empty(dim))
144
- self.bias = torch.nn.Parameter(torch.empty(dim))
160
+ self.weight = torch.nn.Parameter(torch.empty(dim), requires_grad=False)
161
+ self.bias = torch.nn.Parameter(torch.empty(dim), requires_grad=False)
145
162
 
146
163
  def forward(self, x):
147
164
  """Running the forward pass of LayerNorm layer.
@@ -165,6 +182,7 @@ def rms_norm_with_hlfb(
165
182
  x: torch.Tensor,
166
183
  w: torch.Tensor,
167
184
  eps: float,
185
+ final_scale: torch.Tensor,
168
186
  ):
169
187
  """RMS Normalization with high-level function boundary enabled.
170
188
 
@@ -172,6 +190,7 @@ def rms_norm_with_hlfb(
172
190
  x (torch.Tensor): Input tensor for RMS Normalization, with BCHW shape.
173
191
  w (torch.Tensor): The learned parameter tensor for normalization.
174
192
  eps (float): A small float value to ensure numerical stability.
193
+ final_scale (torch.Tensor): The final scale to apply to the normalization.
175
194
 
176
195
  Returns:
177
196
  The output tensor of RMS Normalization.
@@ -185,7 +204,7 @@ def rms_norm_with_hlfb(
185
204
  def _norm(x):
186
205
  return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
187
206
 
188
- output = _norm(x.float()).type_as(x)
207
+ output = _norm(x.float()).type_as(x) * final_scale
189
208
  out = output * w
190
209
 
191
210
  out = builder.mark_outputs(out)
@@ -0,0 +1,73 @@
1
+ # Copyright 2025 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Tests for normalization layers."""
16
+
17
+ from ai_edge_torch.generative.layers import normalization
18
+ import torch
19
+ from absl.testing import absltest as googletest
20
+ from absl.testing import parameterized
21
+
22
+
23
class NormalizationTest(parameterized.TestCase):
  """Tests for `normalization.RMSNorm` run eagerly (with and without HLFB)."""

  @parameterized.named_parameters(
      dict(
          testcase_name="rms_norm_test_1",
          model_dim=10,
          with_scale=False,
          scale_shift=1.0,
          enable_hlfb=False,
          expected_values=torch.ones((10,), dtype=torch.float32),
      ),
      dict(
          testcase_name="rms_norm_test_2",
          model_dim=10,
          with_scale=True,
          scale_shift=2.0,
          enable_hlfb=False,
          expected_values=torch.ones((10,), dtype=torch.float32) * 2.0,
      ),
      dict(
          testcase_name="rms_norm_test_3",
          model_dim=10,
          with_scale=True,
          scale_shift=2.0,
          enable_hlfb=True,
          expected_values=torch.ones((10,), dtype=torch.float32) * 2.0,
      ),
      # The HLFB path without a learned scale was previously not covered.
      dict(
          testcase_name="rms_norm_test_4",
          model_dim=10,
          with_scale=False,
          scale_shift=1.0,
          enable_hlfb=True,
          expected_values=torch.ones((10,), dtype=torch.float32),
      ),
  )
  def test_rms_norm(
      self,
      model_dim: int,
      with_scale: bool,
      scale_shift: float,
      enable_hlfb: bool,
      expected_values: torch.Tensor,
  ):
    """Checks output shape and values of RMSNorm on an all-ones input.

    An all-ones input has a root-mean-square of 1, so normalization leaves it
    (approximately, because of eps) unchanged. The output therefore reduces to
    the effective scale: `scale_shift` when `with_scale` is set (the scale
    parameter starts at zeros), otherwise 1.
    """
    rms_norm = normalization.RMSNorm(
        dim=model_dim,
        with_scale=with_scale,
        scale_shift=scale_shift,
        enable_hlfb=enable_hlfb,
    )

    x = torch.ones((model_dim,), dtype=torch.float32)
    out = rms_norm(x)
    self.assertEqual(out.shape, (model_dim,))
    # assert_close reports the mismatching elements on failure, unlike
    # assertTrue(torch.allclose(...)).
    torch.testing.assert_close(out, expected_values)

  @parameterized.parameters(False, True)
  def test_rms_norm_non_uniform_input(self, enable_hlfb: bool):
    """Compares RMSNorm with a reference formula on a non-constant input.

    Unlike the all-ones cases above, this input has RMS != 1, so the test
    actually exercises the normalization math together with the scale.
    """
    model_dim = 8
    scale_shift = 2.0
    rms_norm = normalization.RMSNorm(
        dim=model_dim,
        with_scale=True,
        scale_shift=scale_shift,
        enable_hlfb=enable_hlfb,
    )

    x = torch.arange(1, model_dim + 1, dtype=torch.float32)
    # Default weight is ones and zero_centered_gamma is False, so the weight
    # factor is 1; the scale parameter starts at zeros, so the scale factor
    # equals scale_shift.
    expected = (
        x
        * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + rms_norm.eps)
        * scale_shift
    )
    out = rms_norm(x)
    self.assertEqual(out.shape, (model_dim,))
    torch.testing.assert_close(out, expected)


if __name__ == "__main__":
  googletest.main()
@@ -19,10 +19,36 @@ import os
19
19
  from typing import Callable, Dict, List, Tuple
20
20
 
21
21
  from ai_edge_torch.generative.layers import model_config
22
+ import safetensors
22
23
  from safetensors import safe_open
23
24
  import torch
24
25
 
25
26
 
27
def get_custom_loader(
    checkpoint_path: str,
) -> Callable[[str], Dict[str, torch.Tensor]]:
  """Returns a checkpoint-loading function chosen by the path's file extension.

  The returned callable takes a path and returns a state dictionary mapping
  tensor names to tensors:

  * `.bin`, `.pt`, `.ckpt`: `torch.load(path, weights_only=True)`.
  * `.safetensors`: `safetensors.torch.load_file(path)`.

  NOTE(review): these loaders read from a path visible to the local
  filesystem. Any "remote" access has to come from how that path is provided
  (e.g. a mounted filesystem); the loaders do not fetch data themselves.

  Args:
    checkpoint_path (string): The path to the checkpoint. Only its extension
      is inspected here; the file is not opened.

  Returns:
    Callable[[str], Dict[str, torch.Tensor]]: The custom loader.

  Raises:
    ValueError: If the checkpoint format is not supported.
  """
  extension = os.path.splitext(checkpoint_path)[1]
  if extension in (".bin", ".pt", ".ckpt"):
    # weights_only=True restricts unpickling to tensors and plain containers,
    # so loading an untrusted checkpoint cannot execute arbitrary code.
    return lambda path: torch.load(path, weights_only=True)
  if extension == ".safetensors":
    # `import safetensors` alone does not import the `safetensors.torch`
    # submodule, so import it explicitly before referencing `load_file`.
    from safetensors import torch as safetensors_torch

    return safetensors_torch.load_file
  raise ValueError(f"Unsupported checkpoint format: {checkpoint_path}")
50
+
51
+
26
52
  def load_safetensors(full_path: str):
27
53
  """Loads safetensors into a single state dictionary.
28
54
 
ai_edge_torch/version.py CHANGED
@@ -13,4 +13,4 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- __version__ = "0.5.0.dev20250514"
16
+ __version__ = "0.5.0.dev20250515"
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: ai-edge-torch-nightly
3
- Version: 0.5.0.dev20250514
3
+ Version: 0.5.0.dev20250515
4
4
  Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
5
5
  Home-page: https://github.com/google-ai-edge/ai-edge-torch
6
6
  Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
@@ -2,7 +2,7 @@ ai_edge_torch/__init__.py,sha256=8sPR_5uXJA4NEE0nIwNdSl-ADOJEoR8hAgYvBQDY70Y,120
2
2
  ai_edge_torch/_config.py,sha256=AiqhbcheF7j_ozIGDLC89k1we95aVgFDa-tR6h7UI0s,2529
3
3
  ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
4
4
  ai_edge_torch/model.py,sha256=wxjSFq_rBSxSqbUE8E8EJTCkgvgaRLjq_ZuAM-IZpCU,5606
5
- ai_edge_torch/version.py,sha256=ZvSDZpKkUslpMEN4pPp4xI6n8g3mHZMdfIcYeWth5Dg,706
5
+ ai_edge_torch/version.py,sha256=QVmEdwoLJem1gNQul_CoRyfqOc1Ljjy48x9GmKmuAOU,706
6
6
  ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
7
7
  ai_edge_torch/_convert/conversion.py,sha256=QVugYVfbyaeBgSKKbhFzHG5oXA7t3M-40JcpcdSu6W8,5436
8
8
  ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -68,7 +68,7 @@ ai_edge_torch/generative/examples/gemma/verify_gemma1.py,sha256=ip-Gmk4CI5f0GWSd
68
68
  ai_edge_torch/generative/examples/gemma/verify_gemma2.py,sha256=jhiyinOqPt5ZZjEadDRZt_wY5fiLSCpMo54PcxFaL_Q,1789
69
69
  ai_edge_torch/generative/examples/gemma/verify_util.py,sha256=n7f2nF6Lin_tDvPs0JVldsuaBzo7pAwi5YAHAhlIxQg,6139
70
70
  ai_edge_torch/generative/examples/gemma3/__init__.py,sha256=JaAnrFoXTl3RJX97XspklkTyqOHVyAgRJsZtzNDd10c,671
71
- ai_edge_torch/generative/examples/gemma3/convert_gemma3_to_tflite.py,sha256=JLXXn2mFEBs4DlHH_O6hpEG9KInJqsCdWy3DrgUjT1c,1827
71
+ ai_edge_torch/generative/examples/gemma3/convert_gemma3_to_tflite.py,sha256=SsiK9xKCyboi5y-HdoFSN02QxRo0XabyzotUq46zO0E,2357
72
72
  ai_edge_torch/generative/examples/gemma3/decoder.py,sha256=shdgLzKDUi0vyNOAsrIVAEFb3Adltsri6Rx1-wxzVf4,15089
73
73
  ai_edge_torch/generative/examples/gemma3/gemma3.py,sha256=ZorRtnbElWsctcA0nEbfwjx0C578voF7fjFEvWSR5Ck,6582
74
74
  ai_edge_torch/generative/examples/gemma3/image_encoder.py,sha256=uRoLoBWzFtQz5wFZfPCxbkvZsgPAqSkUUsV3977GbYc,5184
@@ -166,7 +166,8 @@ ai_edge_torch/generative/layers/feed_forward_test.py,sha256=8ZGy79BBpsyS6yKKDEKr
166
166
  ai_edge_torch/generative/layers/kv_cache.py,sha256=b-7shzDaKexmvQF7P3SiAmIz4ZofjYWv3m5u71GojsA,10460
167
167
  ai_edge_torch/generative/layers/lora.py,sha256=hsvWLLOnW7HQ0AysOZu30x_cetMquDd1tjfyLz8HCSU,17892
168
168
  ai_edge_torch/generative/layers/model_config.py,sha256=X_gjN5524DCDBNXsX5GrOBlkKM4UHzj_RfdCD0-VOxQ,8572
169
- ai_edge_torch/generative/layers/normalization.py,sha256=MbwH-n80Fob5YvjBzdqDjBizMHLzSJGYRDdbD-rL5C0,6174
169
+ ai_edge_torch/generative/layers/normalization.py,sha256=ijwCpi22NLX-Sygwy5sK9l9WjGvbPIhZvVwoBAonWAo,7014
170
+ ai_edge_torch/generative/layers/normalization_test.py,sha256=zwurZly-TgFxdgVVdpzu9vCpcLbd5RYt_gKg9Lfg1jI,2248
170
171
  ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=975zR202MdIrILJ7blceAcxrNqX1ZCN0ECKG1gz-bV8,2655
171
172
  ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=2_AgwENsaOgaxgiSqgoj0V0JzQ09dFtP_nBhX-lJK2g,5648
172
173
  ai_edge_torch/generative/layers/scaled_dot_product_attention_test.py,sha256=c6JBMQsq9XeMmR1XvGEIidNsoh-YIvichXo2LwVHgr4,3301
@@ -194,7 +195,7 @@ ai_edge_torch/generative/test/utils.py,sha256=tF6aCfAGJnc9dmzCnZCEOuKNVimfWOqscv
194
195
  ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
195
196
  ai_edge_torch/generative/utilities/converter.py,sha256=4zcDlhgCQQyLylH8NLgVjnelou2pW6HWJHBFYsFyHuw,15020
196
197
  ai_edge_torch/generative/utilities/export_config.py,sha256=5IvR3grlMd4mWO5c_Y4x9Fk1b1xa57MzlYNE8XUaN28,2049
197
- ai_edge_torch/generative/utilities/loader.py,sha256=nw2REQ9sGWDwphShfRqNFICYmwIjqLp6bDcwVmsNTtg,14067
198
+ ai_edge_torch/generative/utilities/loader.py,sha256=tSiew77hB_zyn6rpcfegSg1zrriqHSz63KjV9_llBxg,14893
198
199
  ai_edge_torch/generative/utilities/model_builder.py,sha256=tBfOcsI_NcneggHqkCSydYN3ZgmkzPc6nW0AJrA81wI,6461
199
200
  ai_edge_torch/generative/utilities/moonshine_loader.py,sha256=_RpFabSqtGH5PHiP3_1f6QfO14qMADUxr_HGRlVDFB0,4891
200
201
  ai_edge_torch/generative/utilities/stable_diffusion_loader.py,sha256=dqPD9qRXEWtU3ombslOC-BE2l_dMwHoCNu7NsIJhsso,36158
@@ -251,8 +252,8 @@ ai_edge_torch/testing/__init__.py,sha256=_yGgvnBZWb7T3IN3mc4x1sS4vM96HZwM8pwIcPG
251
252
  ai_edge_torch/testing/export.py,sha256=k5mGDGzwc23Z4zaIVDs8CNh-oOt64gsf9MS9NjhbPy4,3293
252
253
  ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
253
254
  ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
254
- ai_edge_torch_nightly-0.5.0.dev20250514.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
255
- ai_edge_torch_nightly-0.5.0.dev20250514.dist-info/METADATA,sha256=4_d1LvNhvXOHKlqYZDcBYSLdYDmoGvWMgCK5PJasNiU,2074
256
- ai_edge_torch_nightly-0.5.0.dev20250514.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
257
- ai_edge_torch_nightly-0.5.0.dev20250514.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
258
- ai_edge_torch_nightly-0.5.0.dev20250514.dist-info/RECORD,,
255
+ ai_edge_torch_nightly-0.5.0.dev20250515.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
256
+ ai_edge_torch_nightly-0.5.0.dev20250515.dist-info/METADATA,sha256=FmCPouaJYszNPCOfgIx8WGFkGv5LrqR6_OGpciU2eKc,2074
257
+ ai_edge_torch_nightly-0.5.0.dev20250515.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
258
+ ai_edge_torch_nightly-0.5.0.dev20250515.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
259
+ ai_edge_torch_nightly-0.5.0.dev20250515.dist-info/RECORD,,