ai-edge-torch-nightly 0.5.0.dev20250513__py3-none-any.whl → 0.5.0.dev20250515__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_torch/generative/examples/gemma3/convert_gemma3_to_tflite.py +14 -0
- ai_edge_torch/generative/examples/gemma3/decoder.py +7 -2
- ai_edge_torch/generative/examples/gemma3/gemma3.py +8 -4
- ai_edge_torch/generative/examples/gemma3/verify_util.py +14 -3
- ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py +2 -3
- ai_edge_torch/generative/examples/qwen_vl/convert_to_tflite.py +2 -3
- ai_edge_torch/generative/layers/normalization.py +26 -7
- ai_edge_torch/generative/layers/normalization_test.py +73 -0
- ai_edge_torch/generative/utilities/converter.py +9 -0
- ai_edge_torch/generative/utilities/loader.py +38 -2
- ai_edge_torch/generative/utilities/model_builder.py +5 -1
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.5.0.dev20250513.dist-info → ai_edge_torch_nightly-0.5.0.dev20250515.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.5.0.dev20250513.dist-info → ai_edge_torch_nightly-0.5.0.dev20250515.dist-info}/RECORD +17 -16
- {ai_edge_torch_nightly-0.5.0.dev20250513.dist-info → ai_edge_torch_nightly-0.5.0.dev20250515.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.5.0.dev20250513.dist-info → ai_edge_torch_nightly-0.5.0.dev20250515.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.5.0.dev20250513.dist-info → ai_edge_torch_nightly-0.5.0.dev20250515.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/examples/gemma3/convert_gemma3_to_tflite.py
CHANGED
@@ -19,11 +19,19 @@ from absl import app
 from ai_edge_torch.generative.examples.gemma3 import gemma3
 from ai_edge_torch.generative.utilities import converter
 from ai_edge_torch.generative.utilities import export_config
+from ai_edge_torch.generative.utilities import loader

 flags = converter.define_conversion_flags(
     'gemma3-1b', default_mask_as_input=True, default_transpose_kv_cache=True
 )

+_CUSTOM_CHECKPOINT_LOADER = flags.DEFINE_bool(
+    'custom_checkpoint_loader',
+    False,
+    'If true, the conversion script will use a custom checkpoint loader which'
+    ' will read a checkpoint from a remote source.',
+)
+
 _MODEL_SIZE = flags.DEFINE_string(
     'model_size',
     '1b',
@@ -32,10 +40,16 @@ _MODEL_SIZE = flags.DEFINE_string(


 def main(_):
+  custom_loader = None
+  if flags.FLAGS.custom_checkpoint_loader:
+    # If loading from a remote source, try to get a custom loader first.
+    custom_loader = loader.get_custom_loader(flags.FLAGS.checkpoint_path)
+
   if _MODEL_SIZE.value == '1b':
     pytorch_model = gemma3.build_model_1b(
         flags.FLAGS.checkpoint_path,
         kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
+        custom_loader=custom_loader,
     )
   else:
     raise ValueError(f'Unsupported model size: {_MODEL_SIZE.value}')
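The new flag only routes the script through the loader hook introduced in this release; a minimal programmatic sketch of the same flow, assuming a local checkpoint file (the path and the kv_cache_max_len value are placeholders):

```python
from ai_edge_torch.generative.examples.gemma3 import gemma3
from ai_edge_torch.generative.utilities import loader

# Placeholder path; get_custom_loader picks a loader from the file extension.
checkpoint_path = '/tmp/gemma3-1b/model.safetensors'
custom_loader = loader.get_custom_loader(checkpoint_path)
pytorch_model = gemma3.build_model_1b(
    checkpoint_path,
    kv_cache_max_len=1280,  # illustrative value
    custom_loader=custom_loader,
)
```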
ai_edge_torch/generative/examples/gemma3/decoder.py
CHANGED
@@ -15,7 +15,7 @@

 """Example of building a Decoder for Gemma3 model."""

-from typing import List, Optional, Tuple
+from typing import Callable, Dict, List, Optional, Tuple

 from ai_edge_torch.generative.layers import attention
 from ai_edge_torch.generative.layers import builder
@@ -410,7 +410,11 @@ def get_fake_decoder_config_1b(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
   return config


-def build_model_1b(checkpoint_path: str, **kwargs) -> nn.Module:
+def build_model_1b(
+    checkpoint_path: str,
+    custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
+    **kwargs,
+) -> nn.Module:
   # TODO(b/403644647): Better error handling for loading checkpoints with
   # different tensor names.
   for tensor_names in TENSOR_NAMES_DICT.values():
@@ -420,6 +424,7 @@ def build_model_1b(checkpoint_path: str, **kwargs) -> nn.Module:
           config=get_decoder_config_1b(**kwargs),
           tensor_names=tensor_names,
           model_class=Decoder,
+          custom_loader=custom_loader,
       )
     except KeyError as ke:
       continue
ai_edge_torch/generative/examples/gemma3/gemma3.py
CHANGED
@@ -16,8 +16,7 @@
 """Example of building a Gemma3 gpu model."""

 from dataclasses import dataclass
-from typing import List, Optional, Tuple
-import xmlrpc
+from typing import List, Optional, Tuple, Callable, Dict

 from ai_edge_torch.generative.examples.gemma3 import decoder
 from ai_edge_torch.generative.examples.gemma3 import image_encoder
@@ -166,9 +165,14 @@ def get_fake_model_config(**kwargs) -> Gemma3MMConfig:
       mm_extra_tokens=32,
   )

-def build_model_1b(checkpoint_path: str, **kwargs) -> decoder.Decoder:
+
+def build_model_1b(
+    checkpoint_path: str,
+    custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
+    **kwargs,
+) -> decoder.Decoder:
   if checkpoint_path:
-    model = decoder.build_model_1b(checkpoint_path, **kwargs)
+    model = decoder.build_model_1b(checkpoint_path, custom_loader, **kwargs)
   else:
     config = decoder.get_decoder_config_1b(**kwargs)
     model = decoder.Decoder(config)
ai_edge_torch/generative/examples/gemma3/verify_util.py
CHANGED
@@ -17,7 +17,7 @@

 import logging
 import os
-from typing import List, Optional, Tuple
+from typing import Callable, Dict, List, Optional, Tuple

 from ai_edge_torch.generative.examples.gemma3 import gemma3
 from ai_edge_torch.generative.layers import kv_cache as kv_utils
@@ -167,6 +167,7 @@ verify_reauthored_gemma_model(
     generate_prompts: List[str],
     forward_input_ids: List[List[int]],
     weight_filename: str,
+    custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
     tokenizer_filename: str = "tokenizer.model",
     max_new_tokens: int = 20,
     rtol: float = 1e-05,
@@ -196,7 +197,14 @@ def verify_reauthored_gemma_model(

   logging.info("Loading the original model from: %s", checkpoint)
   original_model = gemma_model.GemmaForCausalLM(config).eval()
-  original_model.load_weights(os.path.join(checkpoint, weight_filename))
+  checkpoint_path = os.path.join(checkpoint, weight_filename)
+  if custom_loader is None:
+    original_model.load_weights(checkpoint_path)
+  else:
+    original_model.load_state_dict(
+        custom_loader(checkpoint_path)["model_state_dict"],
+        strict=False,
+    )

   return verifier.verify_reauthored_model(
       original_model=GemmaWrapper(original_model),
@@ -216,6 +224,7 @@ def verify_gemma3(
     max_new_tokens: int,
     variant: str,
     weight_filename: str,
+    custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
 ) -> bool:
   """Verifies the reauthored Gemma3 model.

@@ -225,6 +234,7 @@ def verify_gemma3(
     max_new_tokens: Maximum number of new tokens to generate.
     variant: Gemma model variant.
     weight_filename: Name of the weight file.
+    custom_loader: A custom loader to load the weights.

   Returns:
     True if the verification passes, False otherwise.
@@ -234,7 +244,7 @@ def verify_gemma3(

   if variant == "1b":
     reauthored_model = UnifiedGemma3Wrapper(
-        gemma3.build_model_1b(gemma3_model_path)
+        gemma3.build_model_1b(gemma3_model_path, custom_loader)
     )
   else:
     raise ValueError(f"Unsupported Gemma3 variant: {variant}")
@@ -247,5 +257,6 @@ def verify_gemma3(
       forward_input_ids=[[2, 651, 9456, 576, 573, 3520, 3858, 603, 235248]],
       max_new_tokens=max_new_tokens,
       weight_filename=weight_filename,
+      custom_loader=custom_loader,
       atol=1e-04,
   )
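As the hunks above show, when a custom_loader is supplied the original Gemma implementation is fed custom_loader(checkpoint_path)["model_state_dict"], while the reauthored model consumes the flat tensor dict through ModelLoader. A minimal sketch of a compatible callable; the wrapping fallback is an assumption for checkpoints that store a bare state dict:

```python
import torch


def gemma_ckpt_loader(path: str) -> dict:
  """Hypothetical loader for the verify_util path (not part of this diff)."""
  state = torch.load(path, weights_only=True)
  # Wrap a bare state dict so the "model_state_dict" lookup succeeds (assumption).
  if "model_state_dict" not in state:
    state = {"model_state_dict": state}
  return state
```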
ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py
CHANGED
@@ -22,8 +22,6 @@ from ai_edge_torch.generative.utilities import export_config
 import torch

 flags = converter.define_conversion_flags('paligemma2-3b-224')
-ExportConfig = export_config.ExportConfig
-

 _VERSION = flags.DEFINE_enum(
     'version',
@@ -32,6 +30,7 @@ _VERSION = flags.DEFINE_enum(
     'The version of PaliGemma model to verify.',
 )

+
 def main(_):
   pytorch_model = paligemma.build_model(
       flags.FLAGS.checkpoint_path,
@@ -51,7 +50,7 @@ def main(_):
       pixel_seq_len=(config.image_size // config.patch_size) ** 2,
       quantize=flags.FLAGS.quantize,
       config=pytorch_model.config.decoder_config,
-      export_config=
+      export_config=export_config.get_from_flags(),
   )
ai_edge_torch/generative/examples/qwen_vl/convert_to_tflite.py
CHANGED
@@ -21,8 +21,6 @@ from ai_edge_torch.generative.utilities import converter
 from ai_edge_torch.generative.utilities import export_config

 flags = converter.define_conversion_flags('qwen_vl')
-ExportConfig = export_config.ExportConfig
-

 _IMAGE_HEIGHT = flags.DEFINE_integer(
     'image_height',
@@ -35,6 +33,7 @@ _IMAGE_WIDTH = flags.DEFINE_integer(
     'The width of image.',
 )

+
 def main(_):
   pytorch_model = qwen_vl.build_model(
       flags.FLAGS.checkpoint_path,
@@ -60,7 +59,7 @@ def main(_):
       ),
       quantize=flags.FLAGS.quantize,
       config=pytorch_model.config.decoder_config,
-      export_config=
+      export_config=export_config.get_from_flags(),
   )
ai_edge_torch/generative/layers/normalization.py
CHANGED
@@ -28,6 +28,8 @@ class RMSNorm(torch.nn.Module):
       dim: int,
       eps: float = 1e-6,
       zero_centered_gamma=False,
+      with_scale: bool = False,
+      scale_shift: float = 1.0,
       enable_hlfb: bool = False,
   ):
     """Initialize the RMSNorm layer.
@@ -37,13 +39,22 @@ class RMSNorm(torch.nn.Module):
       eps (float): A small float value to ensure numerical stability (default:
         1e-6).
       zero_centered_gamma (bool): Whether or not gamma has an offset.
+      with_scale (bool): Whether or not to use a scale parameter.
+      scale_shift (float): The shift to apply to the scale parameter.
       enable_hlfb (bool): use HLFB in the op.
     """
     super().__init__()
+    self.dim = dim
     self.enable_hlfb = enable_hlfb
     self.eps = eps
-    self.weight = torch.nn.Parameter(torch.ones(dim))
+    self.weight = torch.nn.Parameter(torch.ones(dim), requires_grad=False)
     self.zero_centered_gamma = zero_centered_gamma
+    self.with_scale = with_scale
+    if with_scale:
+      self.scale = torch.nn.Parameter(
+          torch.zeros((dim,), dtype=torch.float32), requires_grad=False
+      )
+    self.scale_shift = scale_shift

   def _norm(self, x):
     """Apply RMSNorm normalization.
@@ -70,14 +81,20 @@ class RMSNorm(torch.nn.Module):
     else:
       w = self.weight

+    final_scale = (
+        self.scale + self.scale_shift
+        if self.with_scale
+        else torch.ones((self.dim,), dtype=torch.float32)
+    )
     if self.enable_hlfb:
       return rms_norm_with_hlfb(
           x,
           w,
           self.eps,
+          final_scale,
       )
     else:
-      output = self._norm(x.float()).type_as(x)
+      output = self._norm(x.float()).type_as(x) * final_scale
       return output * w


@@ -104,8 +121,8 @@ class GroupNorm(torch.nn.Module):
     self.enable_hlfb = enable_hlfb
     self.group_num = group_num
     self.eps = eps
-    self.weight = torch.nn.Parameter(torch.empty(dim))
-    self.bias = torch.nn.Parameter(torch.empty(dim))
+    self.weight = torch.nn.Parameter(torch.empty(dim), requires_grad=False)
+    self.bias = torch.nn.Parameter(torch.empty(dim), requires_grad=False)

   def forward(self, x):
     """Running the forward pass of GroupNorm layer.
@@ -140,8 +157,8 @@ class LayerNorm(torch.nn.Module):
     self.enable_hlfb = enable_hlfb
     self.normalized_shape = (dim,)
     self.eps = eps
-    self.weight = torch.nn.Parameter(torch.empty(dim))
-    self.bias = torch.nn.Parameter(torch.empty(dim))
+    self.weight = torch.nn.Parameter(torch.empty(dim), requires_grad=False)
+    self.bias = torch.nn.Parameter(torch.empty(dim), requires_grad=False)

   def forward(self, x):
     """Running the forward pass of LayerNorm layer.
@@ -165,6 +182,7 @@ def rms_norm_with_hlfb(
     x: torch.Tensor,
     w: torch.Tensor,
     eps: float,
+    final_scale: torch.Tensor,
 ):
   """RMS Normalization with high-level function boundary enabled.

@@ -172,6 +190,7 @@ def rms_norm_with_hlfb(
     x (torch.Tensor): Input tensor for RMS Normalization, with BCHW shape.
     w (torch.Tensor): The learned parameter tensor for normalization.
     eps (float): A small float value to ensure numerical stability.
+    final_scale (torch.Tensor): The final scale to apply to the normalization.

   Returns:
     The output tensor of RMS Normalization.
@@ -185,7 +204,7 @@ def rms_norm_with_hlfb(
   def _norm(x):
     return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)

-  output = _norm(x.float()).type_as(x)
+  output = _norm(x.float()).type_as(x) * final_scale
   out = output * w

   out = builder.mark_outputs(out)
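With with_scale enabled, the forward pass above amounts to rmsnorm(x) * (scale + scale_shift) * w. A small standalone sketch of the non-HLFB branch (values are illustrative):

```python
import torch

x = torch.randn(2, 8)      # illustrative input
w = torch.ones(8)          # self.weight (gamma), zero_centered_gamma=False
scale = torch.zeros(8)     # self.scale, initialized to zeros in the diff
scale_shift = 1.0
eps = 1e-6

norm = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
final_scale = scale + scale_shift
out = norm * final_scale * w  # matches `output * w` with output = _norm(x) * final_scale
```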
ai_edge_torch/generative/layers/normalization_test.py
ADDED
@@ -0,0 +1,73 @@
+# Copyright 2025 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for normalization layers."""
+
+from ai_edge_torch.generative.layers import normalization
+import torch
+from absl.testing import absltest as googletest
+from absl.testing import parameterized
+
+
+class NormalizationTest(parameterized.TestCase):
+
+  @parameterized.named_parameters(
+      dict(
+          testcase_name="rms_norm_test_1",
+          model_dim=10,
+          with_scale=False,
+          scale_shift=1.0,
+          enable_hlfb=False,
+          expected_values=torch.ones((10,), dtype=torch.float32),
+      ),
+      dict(
+          testcase_name="rms_norm_test_2",
+          model_dim=10,
+          with_scale=True,
+          scale_shift=2.0,
+          enable_hlfb=False,
+          expected_values=torch.ones((10,), dtype=torch.float32) * 2.0,
+      ),
+      dict(
+          testcase_name="rms_norm_test_3",
+          model_dim=10,
+          with_scale=True,
+          scale_shift=2.0,
+          enable_hlfb=True,
+          expected_values=torch.ones((10,), dtype=torch.float32) * 2.0,
+      ),
+  )
+  def test_rms_norm(
+      self,
+      model_dim: int,
+      with_scale: bool,
+      scale_shift: float,
+      enable_hlfb: bool,
+      expected_values: torch.Tensor,
+  ):
+    rms_norm = normalization.RMSNorm(
+        dim=model_dim,
+        with_scale=with_scale,
+        scale_shift=scale_shift,
+        enable_hlfb=enable_hlfb,
+    )
+
+    x = torch.ones((model_dim,), dtype=torch.float32)
+    out = rms_norm(x)
+    self.assertEqual(out.shape, (model_dim,))
+    self.assertTrue(torch.allclose(out, expected_values))
+
+
+if __name__ == "__main__":
+  googletest.main()
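The expected values follow directly from the initializers: for an all-ones input the RMS is ~1, so _norm(x) ≈ x; the weight starts at ones and the scale at zeros, so the output is ≈ scale_shift everywhere (1.0 and 2.0 in the cases above). A quick check, assuming the package is installed:

```python
import torch
from ai_edge_torch.generative.layers import normalization

rms_norm = normalization.RMSNorm(dim=10, with_scale=True, scale_shift=2.0)
print(rms_norm(torch.ones(10)))  # ≈ tensor of 2.0s
```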
ai_edge_torch/generative/utilities/converter.py
CHANGED
@@ -280,6 +280,15 @@ def convert_to_tflite(
       '' if not lora_ranks else f'_lora{",".join(map(str, lora_ranks))}'
   )

+  if pixel_values_size is not None:
+    assert pixel_seq_len > 0, 'pixel_seq_len must be greater than 0'
+    max_prefill_seq_len = max(prefill_seq_lens)
+    assert kv_size > max_prefill_seq_len + pixel_seq_len, (
+        f'The KV cache size ({kv_size}) must be greater than the maximum '
+        f'prefill sequence length ({max_prefill_seq_len}) + pixel sequence '
+        f'length ({pixel_seq_len})'
+    )
+
   if export_config is not None:
     if export_config.decode_batch_size > 1:
       output_name_prefix += f'_dbs{export_config.decode_batch_size}'
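The new check ties the KV cache size to the prefill and image-token budget. A worked example, assuming a 224x224 input with 14-pixel patches (the PaliGemma-style geometry computed elsewhere in this diff; the concrete numbers are illustrative):

```python
image_size, patch_size = 224, 14                  # assumed values, not from this diff
pixel_seq_len = (image_size // patch_size) ** 2   # 256 image tokens
max_prefill_seq_len = 1024                        # illustrative prefill length

assert 2048 > max_prefill_seq_len + pixel_seq_len        # kv_size=2048 passes the check
assert not (1280 > max_prefill_seq_len + pixel_seq_len)  # kv_size=1280 would be rejected
```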
ai_edge_torch/generative/utilities/loader.py
CHANGED
@@ -19,10 +19,36 @@ import os
 from typing import Callable, Dict, List, Tuple

 from ai_edge_torch.generative.layers import model_config
+import safetensors
 from safetensors import safe_open
 import torch


+def get_custom_loader(
+    checkpoint_path: str,
+) -> Callable[[str], Dict[str, torch.Tensor]]:
+  """Returns a custom loader for the given checkpoint path.
+
+  Those customer loaders can either support state dictionary or safetensors, and
+  the actual data might be fetched from a remote source.
+
+  Args:
+    checkpoint_path (string): The path to the checkpoint.
+
+  Returns:
+    Callable[[str], Dict[str, torch.Tensor]]: The custom loader.
+
+  Raises:
+    ValueError: If the checkpoint format is not supported.
+  """
+
+  if os.path.splitext(checkpoint_path)[1] in [".bin", ".pt", ".ckpt"]:
+    return lambda path: torch.load(path, weights_only=True)
+  if checkpoint_path.endswith(".safetensors"):
+    return safetensors.torch.load_file
+  raise ValueError(f"Unsupported checkpoint format: {checkpoint_path}")
+
+
 def load_safetensors(full_path: str):
   """Loads safetensors into a single state dictionary.

@@ -117,7 +143,12 @@ class ModelLoader:
     final_norm: str = None
     lm_head: str = None

-  def __init__(self, file_name: str, names: TensorNames) -> None:
+  def __init__(
+      self,
+      file_name: str,
+      names: TensorNames,
+      custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
+  ) -> None:
     """ModelLoader constructor.

     Can be used to load multiple models of the same type.
@@ -126,10 +157,15 @@ class ModelLoader:
       file_name (str): Path to the checkpoint. Can be a directory or an exact
        file.
       names (TensorNames): An instance of `TensorNames` to determine mappings.
+      custom_loader (Callable[[str], Dict[str, torch.Tensor]]): A custom
+        loader to be used. If not provided, the class will determine a proper
+        loader.
     """
     self._file_name = file_name
     self._names = names
-    self._loader = self._get_loader()
+    self._loader = (
+        custom_loader if custom_loader is not None else self._get_loader()
+    )

   def get_state(self) -> Dict[str, torch.Tensor]:
     return self._loader(self._file_name)
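A minimal usage sketch of the new helper; the paths are placeholders, and dispatch is purely on the file extension:

```python
from ai_edge_torch.generative.utilities import loader

load_fn = loader.get_custom_loader('/tmp/model.safetensors')  # placeholder path
state_dict = load_fn('/tmp/model.safetensors')                # Dict[str, torch.Tensor]

# A .bin/.pt/.ckpt path instead returns a torch.load(..., weights_only=True)
# wrapper; any other extension raises ValueError.
```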
ai_edge_torch/generative/utilities/model_builder.py
CHANGED
@@ -16,6 +16,7 @@
 """Utilities to be used for re-authoring transformer models."""

 import copy
+from typing import Callable, Dict
 from typing import Optional, Tuple

 from ai_edge_torch.generative.layers import attention
@@ -160,9 +161,12 @@ def build_decoder_only_model(
     config: cfg.ModelConfig,
     tensor_names: loading_utils.ModelLoader.TensorNames,
     model_class: type[nn.Module] = DecoderOnlyModel,
+    custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
 ) -> nn.Module:
   transformer = model_class(config)
-  loader = loading_utils.ModelLoader(checkpoint_path, tensor_names)
+  loader = loading_utils.ModelLoader(
+      checkpoint_path, tensor_names, custom_loader
+  )
   loader.load(
       transformer, strict=not config.lm_head_share_weight_with_embedding
   )
ai_edge_torch/version.py
CHANGED
{ai_edge_torch_nightly-0.5.0.dev20250513.dist-info → ai_edge_torch_nightly-0.5.0.dev20250515.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: ai-edge-torch-nightly
-Version: 0.5.0.dev20250513
+Version: 0.5.0.dev20250515
 Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
 Home-page: https://github.com/google-ai-edge/ai-edge-torch
 Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
{ai_edge_torch_nightly-0.5.0.dev20250513.dist-info → ai_edge_torch_nightly-0.5.0.dev20250515.dist-info}/RECORD
CHANGED
@@ -2,7 +2,7 @@ ai_edge_torch/__init__.py,sha256=8sPR_5uXJA4NEE0nIwNdSl-ADOJEoR8hAgYvBQDY70Y,120
 ai_edge_torch/_config.py,sha256=AiqhbcheF7j_ozIGDLC89k1we95aVgFDa-tR6h7UI0s,2529
 ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
 ai_edge_torch/model.py,sha256=wxjSFq_rBSxSqbUE8E8EJTCkgvgaRLjq_ZuAM-IZpCU,5606
-ai_edge_torch/version.py,sha256=
+ai_edge_torch/version.py,sha256=QVmEdwoLJem1gNQul_CoRyfqOc1Ljjy48x9GmKmuAOU,706
 ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/_convert/conversion.py,sha256=QVugYVfbyaeBgSKKbhFzHG5oXA7t3M-40JcpcdSu6W8,5436
 ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -68,12 +68,12 @@ ai_edge_torch/generative/examples/gemma/verify_gemma1.py,sha256=ip-Gmk4CI5f0GWSd
 ai_edge_torch/generative/examples/gemma/verify_gemma2.py,sha256=jhiyinOqPt5ZZjEadDRZt_wY5fiLSCpMo54PcxFaL_Q,1789
 ai_edge_torch/generative/examples/gemma/verify_util.py,sha256=n7f2nF6Lin_tDvPs0JVldsuaBzo7pAwi5YAHAhlIxQg,6139
 ai_edge_torch/generative/examples/gemma3/__init__.py,sha256=JaAnrFoXTl3RJX97XspklkTyqOHVyAgRJsZtzNDd10c,671
-ai_edge_torch/generative/examples/gemma3/convert_gemma3_to_tflite.py,sha256=
-ai_edge_torch/generative/examples/gemma3/decoder.py,sha256=
-ai_edge_torch/generative/examples/gemma3/gemma3.py,sha256=
+ai_edge_torch/generative/examples/gemma3/convert_gemma3_to_tflite.py,sha256=SsiK9xKCyboi5y-HdoFSN02QxRo0XabyzotUq46zO0E,2357
+ai_edge_torch/generative/examples/gemma3/decoder.py,sha256=shdgLzKDUi0vyNOAsrIVAEFb3Adltsri6Rx1-wxzVf4,15089
+ai_edge_torch/generative/examples/gemma3/gemma3.py,sha256=ZorRtnbElWsctcA0nEbfwjx0C578voF7fjFEvWSR5Ck,6582
 ai_edge_torch/generative/examples/gemma3/image_encoder.py,sha256=uRoLoBWzFtQz5wFZfPCxbkvZsgPAqSkUUsV3977GbYc,5184
 ai_edge_torch/generative/examples/gemma3/verify_gemma3.py,sha256=v8oNXFICmVOtQxfO7IhZ8GnbvotEkDi9lzYHjoQyOso,2464
-ai_edge_torch/generative/examples/gemma3/verify_util.py,sha256=
+ai_edge_torch/generative/examples/gemma3/verify_util.py,sha256=1vfAtayH_I_qTpqhzu6n9xnCuvhgTzhS8IzZviW2dJQ,9418
 ai_edge_torch/generative/examples/hammer/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/hammer/convert_to_tflite.py,sha256=9r8LXyaoBXYIIhhe1WQgEIjaxALQPE1dO2N6qopyWCk,1753
 ai_edge_torch/generative/examples/hammer/hammer.py,sha256=76INcjffvaNCQ02fzXcxJUW_6EKHs4sg3q1nDBbEpHE,3431
@@ -90,7 +90,7 @@ ai_edge_torch/generative/examples/openelm/convert_to_tflite.py,sha256=wRdT7bWbCX
 ai_edge_torch/generative/examples/openelm/openelm.py,sha256=hPcXYHj-nBP56TOeQQejB3HRzv6yHSftHOx0OEPP5M8,4574
 ai_edge_torch/generative/examples/openelm/verify.py,sha256=4W26ZtPF5Cb9mpHYuRM4b2QB_4W76zf4WV36KzexVjs,2446
 ai_edge_torch/generative/examples/paligemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py,sha256=
+ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py,sha256=fkP-mWrih1s-vgJ41fLt8v5JE-UOs8Zrngh6ElQ6PMw,1997
 ai_edge_torch/generative/examples/paligemma/decoder.py,sha256=-EYUZp55dfRY1E-N0Pr3b9i5c7Tt1XvYxvsRixguVS8,5527
 ai_edge_torch/generative/examples/paligemma/decoder2.py,sha256=WB8r-e_Crog1ItBq3Zse_nUG-foFyBcJsuEG26r_Ji8,6076
 ai_edge_torch/generative/examples/paligemma/image_encoder.py,sha256=SvuR97sjkBtfkerH7Hu1UXB8kCFLpEATNbPfCbNAyfo,5614
@@ -114,7 +114,7 @@ ai_edge_torch/generative/examples/qwen/convert_to_tflite.py,sha256=eOpv3scJr4mVs
 ai_edge_torch/generative/examples/qwen/qwen.py,sha256=m8APYzo9N0SXsdvCxC8HtCcbN3W7gLKkRBL-Tg0BWXU,4223
 ai_edge_torch/generative/examples/qwen/verify.py,sha256=9_AyEJTeUfvhhID64Rto2bflFPyXMFokdQLsseLUMiI,2775
 ai_edge_torch/generative/examples/qwen_vl/__init__.py,sha256=JaAnrFoXTl3RJX97XspklkTyqOHVyAgRJsZtzNDd10c,671
-ai_edge_torch/generative/examples/qwen_vl/convert_to_tflite.py,sha256=
+ai_edge_torch/generative/examples/qwen_vl/convert_to_tflite.py,sha256=4Gntv6LBIxd0CaKkb-koLzGTdBEOGgVf3ob99lAuvuY,2196
 ai_edge_torch/generative/examples/qwen_vl/decoder.py,sha256=7RFM25tDj_b0FkpSv8RUWir8K8v9p2jMtwZmP4VAUhw,4474
 ai_edge_torch/generative/examples/qwen_vl/image_encoder.py,sha256=nHzBe_YSPnUe1d5i09v4bePQomVifzJNeUjRfprmxC0,14878
 ai_edge_torch/generative/examples/qwen_vl/qwen_vl.py,sha256=mfLFrT8NPEPh9CqlJYHwh-I2y6ST7hH_vEmbZYartHQ,7764
@@ -166,7 +166,8 @@ ai_edge_torch/generative/layers/feed_forward_test.py,sha256=8ZGy79BBpsyS6yKKDEKr
 ai_edge_torch/generative/layers/kv_cache.py,sha256=b-7shzDaKexmvQF7P3SiAmIz4ZofjYWv3m5u71GojsA,10460
 ai_edge_torch/generative/layers/lora.py,sha256=hsvWLLOnW7HQ0AysOZu30x_cetMquDd1tjfyLz8HCSU,17892
 ai_edge_torch/generative/layers/model_config.py,sha256=X_gjN5524DCDBNXsX5GrOBlkKM4UHzj_RfdCD0-VOxQ,8572
-ai_edge_torch/generative/layers/normalization.py,sha256=
+ai_edge_torch/generative/layers/normalization.py,sha256=ijwCpi22NLX-Sygwy5sK9l9WjGvbPIhZvVwoBAonWAo,7014
+ai_edge_torch/generative/layers/normalization_test.py,sha256=zwurZly-TgFxdgVVdpzu9vCpcLbd5RYt_gKg9Lfg1jI,2248
 ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=975zR202MdIrILJ7blceAcxrNqX1ZCN0ECKG1gz-bV8,2655
 ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=2_AgwENsaOgaxgiSqgoj0V0JzQ09dFtP_nBhX-lJK2g,5648
 ai_edge_torch/generative/layers/scaled_dot_product_attention_test.py,sha256=c6JBMQsq9XeMmR1XvGEIidNsoh-YIvichXo2LwVHgr4,3301
@@ -192,10 +193,10 @@ ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=vQWmpzMkJ2hP
 ai_edge_torch/generative/test/test_quantize.py,sha256=kKJ01wscTC2t_Ylr7huO5gNKES01gm3dT1gx52z15PA,7356
 ai_edge_torch/generative/test/utils.py,sha256=tF6aCfAGJnc9dmzCnZCEOuKNVimfWOqscv9og0DDLHU,2656
 ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
-ai_edge_torch/generative/utilities/converter.py,sha256=
+ai_edge_torch/generative/utilities/converter.py,sha256=4zcDlhgCQQyLylH8NLgVjnelou2pW6HWJHBFYsFyHuw,15020
 ai_edge_torch/generative/utilities/export_config.py,sha256=5IvR3grlMd4mWO5c_Y4x9Fk1b1xa57MzlYNE8XUaN28,2049
-ai_edge_torch/generative/utilities/loader.py,sha256=
-ai_edge_torch/generative/utilities/model_builder.py,sha256=
+ai_edge_torch/generative/utilities/loader.py,sha256=tSiew77hB_zyn6rpcfegSg1zrriqHSz63KjV9_llBxg,14893
+ai_edge_torch/generative/utilities/model_builder.py,sha256=tBfOcsI_NcneggHqkCSydYN3ZgmkzPc6nW0AJrA81wI,6461
 ai_edge_torch/generative/utilities/moonshine_loader.py,sha256=_RpFabSqtGH5PHiP3_1f6QfO14qMADUxr_HGRlVDFB0,4891
 ai_edge_torch/generative/utilities/stable_diffusion_loader.py,sha256=dqPD9qRXEWtU3ombslOC-BE2l_dMwHoCNu7NsIJhsso,36158
 ai_edge_torch/generative/utilities/t5_loader.py,sha256=tEsfy8-ymzbbjOIc-oesXF3yGyyWtJgFXn2s7VOavt8,16961
@@ -251,8 +252,8 @@ ai_edge_torch/testing/__init__.py,sha256=_yGgvnBZWb7T3IN3mc4x1sS4vM96HZwM8pwIcPG
 ai_edge_torch/testing/export.py,sha256=k5mGDGzwc23Z4zaIVDs8CNh-oOt64gsf9MS9NjhbPy4,3293
 ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
 ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
-ai_edge_torch_nightly-0.5.0.
-ai_edge_torch_nightly-0.5.0.
-ai_edge_torch_nightly-0.5.0.
-ai_edge_torch_nightly-0.5.0.
-ai_edge_torch_nightly-0.5.0.
+ai_edge_torch_nightly-0.5.0.dev20250515.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
+ai_edge_torch_nightly-0.5.0.dev20250515.dist-info/METADATA,sha256=FmCPouaJYszNPCOfgIx8WGFkGv5LrqR6_OGpciU2eKc,2074
+ai_edge_torch_nightly-0.5.0.dev20250515.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
+ai_edge_torch_nightly-0.5.0.dev20250515.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
+ai_edge_torch_nightly-0.5.0.dev20250515.dist-info/RECORD,,