ai-edge-torch-nightly 0.5.0.dev20250514__py3-none-any.whl → 0.5.0.dev20250515__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_torch/generative/examples/gemma3/convert_gemma3_to_tflite.py +14 -0
- ai_edge_torch/generative/layers/normalization.py +26 -7
- ai_edge_torch/generative/layers/normalization_test.py +73 -0
- ai_edge_torch/generative/utilities/loader.py +26 -0
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.5.0.dev20250514.dist-info → ai_edge_torch_nightly-0.5.0.dev20250515.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.5.0.dev20250514.dist-info → ai_edge_torch_nightly-0.5.0.dev20250515.dist-info}/RECORD +10 -9
- {ai_edge_torch_nightly-0.5.0.dev20250514.dist-info → ai_edge_torch_nightly-0.5.0.dev20250515.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.5.0.dev20250514.dist-info → ai_edge_torch_nightly-0.5.0.dev20250515.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.5.0.dev20250514.dist-info → ai_edge_torch_nightly-0.5.0.dev20250515.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/examples/gemma3/convert_gemma3_to_tflite.py
CHANGED

```diff
@@ -19,11 +19,19 @@ from absl import app
 from ai_edge_torch.generative.examples.gemma3 import gemma3
 from ai_edge_torch.generative.utilities import converter
 from ai_edge_torch.generative.utilities import export_config
+from ai_edge_torch.generative.utilities import loader
 
 flags = converter.define_conversion_flags(
     'gemma3-1b', default_mask_as_input=True, default_transpose_kv_cache=True
 )
 
+_CUSTOM_CHECKPOINT_LOADER = flags.DEFINE_bool(
+    'custom_checkpoint_loader',
+    False,
+    'If true, the conversion script will use a custom checkpoint loader which'
+    ' will read a checkpoint from a remote source.',
+)
+
 _MODEL_SIZE = flags.DEFINE_string(
     'model_size',
     '1b',
@@ -32,10 +40,16 @@ _MODEL_SIZE = flags.DEFINE_string(
 
 
 def main(_):
+  custom_loader = None
+  if flags.FLAGS.custom_checkpoint_loader:
+    # If loading from a remote source, try to get a custom loader first.
+    custom_loader = loader.get_custom_loader(flags.FLAGS.checkpoint_path)
+
   if _MODEL_SIZE.value == '1b':
     pytorch_model = gemma3.build_model_1b(
         flags.FLAGS.checkpoint_path,
         kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
+        custom_loader=custom_loader,
     )
   else:
     raise ValueError(f'Unsupported model size: {_MODEL_SIZE.value}')
```
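For context, a minimal sketch of what the new flag wires up when the script runs. The checkpoint path below is a placeholder and the `kv_cache_max_len` value is an arbitrary example (the script reads both from its own flags); it is not a definitive recipe, just the wiring the hunks above describe.

```python
# Sketch only: resolve a custom loader first, then hand it to the model builder.
from ai_edge_torch.generative.examples.gemma3 import gemma3
from ai_edge_torch.generative.utilities import loader

checkpoint = "/tmp/gemma3-1b/model.safetensors"  # placeholder checkpoint file

# Equivalent of running the script with --custom_checkpoint_loader=true.
custom_loader = loader.get_custom_loader(checkpoint)
pytorch_model = gemma3.build_model_1b(
    checkpoint,
    kv_cache_max_len=1024,  # example value; the script takes this from its flag
    custom_loader=custom_loader,
)
```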
ai_edge_torch/generative/layers/normalization.py
CHANGED

```diff
@@ -28,6 +28,8 @@ class RMSNorm(torch.nn.Module):
       dim: int,
       eps: float = 1e-6,
       zero_centered_gamma=False,
+      with_scale: bool = False,
+      scale_shift: float = 1.0,
       enable_hlfb: bool = False,
   ):
     """Initialize the RMSNorm layer.
@@ -37,13 +39,22 @@ class RMSNorm(torch.nn.Module):
       eps (float): A small float value to ensure numerical stability (default:
         1e-6).
       zero_centered_gamma (bool): Whether or not gamma has an offset.
+      with_scale (bool): Whether or not to use a scale parameter.
+      scale_shift (float): The shift to apply to the scale parameter.
       enable_hlfb (bool): use HLFB in the op.
     """
     super().__init__()
+    self.dim = dim
     self.enable_hlfb = enable_hlfb
     self.eps = eps
-    self.weight = torch.nn.Parameter(torch.ones(dim))
+    self.weight = torch.nn.Parameter(torch.ones(dim), requires_grad=False)
     self.zero_centered_gamma = zero_centered_gamma
+    self.with_scale = with_scale
+    if with_scale:
+      self.scale = torch.nn.Parameter(
+          torch.zeros((dim,), dtype=torch.float32), requires_grad=False
+      )
+      self.scale_shift = scale_shift
 
   def _norm(self, x):
     """Apply RMSNorm normalization.
@@ -70,14 +81,20 @@ class RMSNorm(torch.nn.Module):
     else:
       w = self.weight
 
+    final_scale = (
+        self.scale + self.scale_shift
+        if self.with_scale
+        else torch.ones((self.dim,), dtype=torch.float32)
+    )
     if self.enable_hlfb:
       return rms_norm_with_hlfb(
           x,
           w,
           self.eps,
+          final_scale,
       )
     else:
-      output = self._norm(x.float()).type_as(x)
+      output = self._norm(x.float()).type_as(x) * final_scale
       return output * w
 
 
@@ -104,8 +121,8 @@ class GroupNorm(torch.nn.Module):
     self.enable_hlfb = enable_hlfb
     self.group_num = group_num
     self.eps = eps
-    self.weight = torch.nn.Parameter(torch.empty(dim))
-    self.bias = torch.nn.Parameter(torch.empty(dim))
+    self.weight = torch.nn.Parameter(torch.empty(dim), requires_grad=False)
+    self.bias = torch.nn.Parameter(torch.empty(dim), requires_grad=False)
 
   def forward(self, x):
     """Running the forward pass of GroupNorm layer.
@@ -140,8 +157,8 @@ class LayerNorm(torch.nn.Module):
     self.enable_hlfb = enable_hlfb
     self.normalized_shape = (dim,)
     self.eps = eps
-    self.weight = torch.nn.Parameter(torch.empty(dim))
-    self.bias = torch.nn.Parameter(torch.empty(dim))
+    self.weight = torch.nn.Parameter(torch.empty(dim), requires_grad=False)
+    self.bias = torch.nn.Parameter(torch.empty(dim), requires_grad=False)
 
   def forward(self, x):
     """Running the forward pass of LayerNorm layer.
@@ -165,6 +182,7 @@ def rms_norm_with_hlfb(
     x: torch.Tensor,
     w: torch.Tensor,
     eps: float,
+    final_scale: torch.Tensor,
 ):
   """RMS Normalization with high-level function boundary enabled.
 
@@ -172,6 +190,7 @@ def rms_norm_with_hlfb(
     x (torch.Tensor): Input tensor for RMS Normalization, with BCHW shape.
     w (torch.Tensor): The learned parameter tensor for normalization.
     eps (float): A small float value to ensure numerical stability.
+    final_scale (torch.Tensor): The final scale to apply to the normalization.
 
   Returns:
     The output tensor of RMS Normalization.
@@ -185,7 +204,7 @@ def rms_norm_with_hlfb(
   def _norm(x):
     return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
 
-  output = _norm(x.float()).type_as(x)
+  output = _norm(x.float()).type_as(x) * final_scale
   out = output * w
 
   out = builder.mark_outputs(out)
```
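To make the new RMSNorm behavior concrete, here is a minimal, self-contained restatement of the math the hunks above add (plain torch, no HLFB marking, not the package API): when `with_scale` is set, the normalized tensor is multiplied by `scale + scale_shift` before the usual weight, and with the `scale` parameter at its zero initialization and `scale_shift=1.0` this reduces to the previous behavior.

```python
# Standalone sketch of the scaled RMSNorm math added above.
import torch

def scaled_rms_norm(x, w, scale, scale_shift, eps=1e-6):
  # RMS-normalize along the last dimension, then apply the new final scale
  # (scale + scale_shift) and the existing weight, mirroring RMSNorm.forward.
  normed = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
  return normed * (scale + scale_shift) * w

dim = 4
x = torch.randn(2, dim)
w = torch.ones(dim)        # RMSNorm.weight default (ones)
scale = torch.zeros(dim)   # the new `scale` parameter's zero initialization
baseline = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6) * w
assert torch.allclose(scaled_rms_norm(x, w, scale, 1.0), baseline)
```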
ai_edge_torch/generative/layers/normalization_test.py
ADDED

```diff
@@ -0,0 +1,73 @@
+# Copyright 2025 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for normalization layers."""
+
+from ai_edge_torch.generative.layers import normalization
+import torch
+from absl.testing import absltest as googletest
+from absl.testing import parameterized
+
+
+class NormalizationTest(parameterized.TestCase):
+
+  @parameterized.named_parameters(
+      dict(
+          testcase_name="rms_norm_test_1",
+          model_dim=10,
+          with_scale=False,
+          scale_shift=1.0,
+          enable_hlfb=False,
+          expected_values=torch.ones((10,), dtype=torch.float32),
+      ),
+      dict(
+          testcase_name="rms_norm_test_2",
+          model_dim=10,
+          with_scale=True,
+          scale_shift=2.0,
+          enable_hlfb=False,
+          expected_values=torch.ones((10,), dtype=torch.float32) * 2.0,
+      ),
+      dict(
+          testcase_name="rms_norm_test_3",
+          model_dim=10,
+          with_scale=True,
+          scale_shift=2.0,
+          enable_hlfb=True,
+          expected_values=torch.ones((10,), dtype=torch.float32) * 2.0,
+      ),
+  )
+  def test_rms_norm(
+      self,
+      model_dim: int,
+      with_scale: bool,
+      scale_shift: float,
+      enable_hlfb: bool,
+      expected_values: torch.Tensor,
+  ):
+    rms_norm = normalization.RMSNorm(
+        dim=model_dim,
+        with_scale=with_scale,
+        scale_shift=scale_shift,
+        enable_hlfb=enable_hlfb,
+    )
+
+    x = torch.ones((model_dim,), dtype=torch.float32)
+    out = rms_norm(x)
+    self.assertEqual(out.shape, (model_dim,))
+    self.assertTrue(torch.allclose(out, expected_values))
+
+
+if __name__ == "__main__":
+  googletest.main()
```
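A quick worked check (not part of the package) of why the `with_scale` cases above expect an all-2.0 output: for an all-ones input the RMS is 1, the new `scale` parameter initializes to zeros so the final scale is `0 + scale_shift = 2.0`, and the default weight is ones.

```python
# Worked check of the expected values in rms_norm_test_2 / rms_norm_test_3.
import torch

x = torch.ones(10)
normed = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6)  # ~= ones
out = normed * (torch.zeros(10) + 2.0) * torch.ones(10)           # final_scale * weight
assert torch.allclose(out, torch.full((10,), 2.0), atol=1e-4)
```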
ai_edge_torch/generative/utilities/loader.py
CHANGED

```diff
@@ -19,10 +19,36 @@ import os
 from typing import Callable, Dict, List, Tuple
 
 from ai_edge_torch.generative.layers import model_config
+import safetensors
 from safetensors import safe_open
 import torch
 
 
+def get_custom_loader(
+    checkpoint_path: str,
+) -> Callable[[str], Dict[str, torch.Tensor]]:
+  """Returns a custom loader for the given checkpoint path.
+
+  Those customer loaders can either support state dictionary or safetensors, and
+  the actual data might be fetched from a remote source.
+
+  Args:
+    checkpoint_path (string): The path to the checkpoint.
+
+  Returns:
+    Callable[[str], Dict[str, torch.Tensor]]: The custom loader.
+
+  Raises:
+    ValueError: If the checkpoint format is not supported.
+  """
+
+  if os.path.splitext(checkpoint_path)[1] in [".bin", ".pt", ".ckpt"]:
+    return lambda path: torch.load(path, weights_only=True)
+  if checkpoint_path.endswith(".safetensors"):
+    return safetensors.torch.load_file
+  raise ValueError(f"Unsupported checkpoint format: {checkpoint_path}")
+
+
 def load_safetensors(full_path: str):
   """Loads safetensors into a single state dictionary.
 
```
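Illustrative use of the new helper (file names below are placeholders): as the hunk shows, `get_custom_loader` only inspects the extension and returns a callable, so no file needs to exist to pick a loader.

```python
# Sketch of how the returned loader depends on the checkpoint extension.
from ai_edge_torch.generative.utilities import loader

# ".safetensors" -> safetensors.torch.load_file
load_safetensors_fn = loader.get_custom_loader("model.safetensors")

# ".bin" / ".pt" / ".ckpt" -> a lambda around torch.load(..., weights_only=True)
load_torch_fn = loader.get_custom_loader("pytorch_model.bin")

# Any other extension raises ValueError("Unsupported checkpoint format: ...").
```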
ai_edge_torch/version.py
CHANGED

{ai_edge_torch_nightly-0.5.0.dev20250514.dist-info → ai_edge_torch_nightly-0.5.0.dev20250515.dist-info}/METADATA
CHANGED

```diff
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: ai-edge-torch-nightly
-Version: 0.5.0.dev20250514
+Version: 0.5.0.dev20250515
 Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
 Home-page: https://github.com/google-ai-edge/ai-edge-torch
 Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
```
{ai_edge_torch_nightly-0.5.0.dev20250514.dist-info → ai_edge_torch_nightly-0.5.0.dev20250515.dist-info}/RECORD
CHANGED

```diff
@@ -2,7 +2,7 @@ ai_edge_torch/__init__.py,sha256=8sPR_5uXJA4NEE0nIwNdSl-ADOJEoR8hAgYvBQDY70Y,120
 ai_edge_torch/_config.py,sha256=AiqhbcheF7j_ozIGDLC89k1we95aVgFDa-tR6h7UI0s,2529
 ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
 ai_edge_torch/model.py,sha256=wxjSFq_rBSxSqbUE8E8EJTCkgvgaRLjq_ZuAM-IZpCU,5606
-ai_edge_torch/version.py,sha256=
+ai_edge_torch/version.py,sha256=QVmEdwoLJem1gNQul_CoRyfqOc1Ljjy48x9GmKmuAOU,706
 ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/_convert/conversion.py,sha256=QVugYVfbyaeBgSKKbhFzHG5oXA7t3M-40JcpcdSu6W8,5436
 ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -68,7 +68,7 @@ ai_edge_torch/generative/examples/gemma/verify_gemma1.py,sha256=ip-Gmk4CI5f0GWSd
 ai_edge_torch/generative/examples/gemma/verify_gemma2.py,sha256=jhiyinOqPt5ZZjEadDRZt_wY5fiLSCpMo54PcxFaL_Q,1789
 ai_edge_torch/generative/examples/gemma/verify_util.py,sha256=n7f2nF6Lin_tDvPs0JVldsuaBzo7pAwi5YAHAhlIxQg,6139
 ai_edge_torch/generative/examples/gemma3/__init__.py,sha256=JaAnrFoXTl3RJX97XspklkTyqOHVyAgRJsZtzNDd10c,671
-ai_edge_torch/generative/examples/gemma3/convert_gemma3_to_tflite.py,sha256=
+ai_edge_torch/generative/examples/gemma3/convert_gemma3_to_tflite.py,sha256=SsiK9xKCyboi5y-HdoFSN02QxRo0XabyzotUq46zO0E,2357
 ai_edge_torch/generative/examples/gemma3/decoder.py,sha256=shdgLzKDUi0vyNOAsrIVAEFb3Adltsri6Rx1-wxzVf4,15089
 ai_edge_torch/generative/examples/gemma3/gemma3.py,sha256=ZorRtnbElWsctcA0nEbfwjx0C578voF7fjFEvWSR5Ck,6582
 ai_edge_torch/generative/examples/gemma3/image_encoder.py,sha256=uRoLoBWzFtQz5wFZfPCxbkvZsgPAqSkUUsV3977GbYc,5184
@@ -166,7 +166,8 @@ ai_edge_torch/generative/layers/feed_forward_test.py,sha256=8ZGy79BBpsyS6yKKDEKr
 ai_edge_torch/generative/layers/kv_cache.py,sha256=b-7shzDaKexmvQF7P3SiAmIz4ZofjYWv3m5u71GojsA,10460
 ai_edge_torch/generative/layers/lora.py,sha256=hsvWLLOnW7HQ0AysOZu30x_cetMquDd1tjfyLz8HCSU,17892
 ai_edge_torch/generative/layers/model_config.py,sha256=X_gjN5524DCDBNXsX5GrOBlkKM4UHzj_RfdCD0-VOxQ,8572
-ai_edge_torch/generative/layers/normalization.py,sha256=
+ai_edge_torch/generative/layers/normalization.py,sha256=ijwCpi22NLX-Sygwy5sK9l9WjGvbPIhZvVwoBAonWAo,7014
+ai_edge_torch/generative/layers/normalization_test.py,sha256=zwurZly-TgFxdgVVdpzu9vCpcLbd5RYt_gKg9Lfg1jI,2248
 ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=975zR202MdIrILJ7blceAcxrNqX1ZCN0ECKG1gz-bV8,2655
 ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=2_AgwENsaOgaxgiSqgoj0V0JzQ09dFtP_nBhX-lJK2g,5648
 ai_edge_torch/generative/layers/scaled_dot_product_attention_test.py,sha256=c6JBMQsq9XeMmR1XvGEIidNsoh-YIvichXo2LwVHgr4,3301
@@ -194,7 +195,7 @@ ai_edge_torch/generative/test/utils.py,sha256=tF6aCfAGJnc9dmzCnZCEOuKNVimfWOqscv
 ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
 ai_edge_torch/generative/utilities/converter.py,sha256=4zcDlhgCQQyLylH8NLgVjnelou2pW6HWJHBFYsFyHuw,15020
 ai_edge_torch/generative/utilities/export_config.py,sha256=5IvR3grlMd4mWO5c_Y4x9Fk1b1xa57MzlYNE8XUaN28,2049
-ai_edge_torch/generative/utilities/loader.py,sha256=
+ai_edge_torch/generative/utilities/loader.py,sha256=tSiew77hB_zyn6rpcfegSg1zrriqHSz63KjV9_llBxg,14893
 ai_edge_torch/generative/utilities/model_builder.py,sha256=tBfOcsI_NcneggHqkCSydYN3ZgmkzPc6nW0AJrA81wI,6461
 ai_edge_torch/generative/utilities/moonshine_loader.py,sha256=_RpFabSqtGH5PHiP3_1f6QfO14qMADUxr_HGRlVDFB0,4891
 ai_edge_torch/generative/utilities/stable_diffusion_loader.py,sha256=dqPD9qRXEWtU3ombslOC-BE2l_dMwHoCNu7NsIJhsso,36158
@@ -251,8 +252,8 @@ ai_edge_torch/testing/__init__.py,sha256=_yGgvnBZWb7T3IN3mc4x1sS4vM96HZwM8pwIcPG
 ai_edge_torch/testing/export.py,sha256=k5mGDGzwc23Z4zaIVDs8CNh-oOt64gsf9MS9NjhbPy4,3293
 ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
 ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
-ai_edge_torch_nightly-0.5.0.
-ai_edge_torch_nightly-0.5.0.
-ai_edge_torch_nightly-0.5.0.
-ai_edge_torch_nightly-0.5.0.
-ai_edge_torch_nightly-0.5.0.
+ai_edge_torch_nightly-0.5.0.dev20250515.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
+ai_edge_torch_nightly-0.5.0.dev20250515.dist-info/METADATA,sha256=FmCPouaJYszNPCOfgIx8WGFkGv5LrqR6_OGpciU2eKc,2074
+ai_edge_torch_nightly-0.5.0.dev20250515.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
+ai_edge_torch_nightly-0.5.0.dev20250515.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
+ai_edge_torch_nightly-0.5.0.dev20250515.dist-info/RECORD,,
```