ai-edge-torch-nightly 0.5.0.dev20250515__py3-none-any.whl → 0.5.0.dev20250516__py3-none-any.whl
This diff compares the contents of two publicly available versions of a package that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions as they appear in their respective public registries.
- ai_edge_torch/__init__.py +1 -0
- ai_edge_torch/_convert/conversion.py +23 -0
- ai_edge_torch/_convert/converter.py +57 -3
- ai_edge_torch/_convert/test/test_convert.py +25 -0
- ai_edge_torch/generative/examples/amd_llama_135m/amd_llama_135m.py +9 -2
- ai_edge_torch/generative/examples/amd_llama_135m/convert_to_tflite.py +7 -1
- ai_edge_torch/generative/examples/deepseek/convert_to_tflite.py +7 -2
- ai_edge_torch/generative/examples/deepseek/deepseek.py +8 -1
- ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py +7 -1
- ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +7 -1
- ai_edge_torch/generative/examples/gemma/gemma1.py +9 -1
- ai_edge_torch/generative/examples/gemma/gemma2.py +7 -2
- ai_edge_torch/generative/examples/gemma3/convert_gemma3_to_tflite.py +5 -14
- ai_edge_torch/generative/examples/hammer/convert_to_tflite.py +7 -1
- ai_edge_torch/generative/examples/hammer/hammer.py +14 -2
- ai_edge_torch/generative/examples/llama/convert_to_tflite.py +7 -1
- ai_edge_torch/generative/examples/llama/llama.py +25 -6
- ai_edge_torch/generative/examples/moonshine/convert_moonshine_to_tflite.py +0 -1
- ai_edge_torch/generative/examples/openelm/convert_to_tflite.py +7 -1
- ai_edge_torch/generative/examples/openelm/openelm.py +8 -1
- ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py +6 -1
- ai_edge_torch/generative/examples/paligemma/decoder.py +1 -0
- ai_edge_torch/generative/examples/paligemma/decoder2.py +1 -0
- ai_edge_torch/generative/examples/paligemma/image_encoder.py +2 -1
- ai_edge_torch/generative/examples/paligemma/paligemma.py +12 -5
- ai_edge_torch/generative/examples/paligemma/verify.py +27 -5
- ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py +7 -1
- ai_edge_torch/generative/examples/phi/convert_phi4_to_tflite.py +7 -1
- ai_edge_torch/generative/examples/phi/convert_to_tflite.py +7 -1
- ai_edge_torch/generative/examples/phi/phi2.py +8 -1
- ai_edge_torch/generative/examples/phi/phi3.py +7 -2
- ai_edge_torch/generative/examples/phi/phi4.py +7 -2
- ai_edge_torch/generative/examples/qwen/convert_to_tflite.py +7 -1
- ai_edge_torch/generative/examples/qwen/qwen.py +20 -3
- ai_edge_torch/generative/examples/qwen_vl/convert_to_tflite.py +6 -1
- ai_edge_torch/generative/examples/qwen_vl/decoder.py +1 -2
- ai_edge_torch/generative/examples/qwen_vl/image_encoder.py +12 -4
- ai_edge_torch/generative/examples/qwen_vl/qwen_vl.py +12 -4
- ai_edge_torch/generative/examples/qwen_vl/verify.py +26 -5
- ai_edge_torch/generative/examples/smollm/convert_to_tflite.py +7 -2
- ai_edge_torch/generative/examples/smollm/convert_v2_to_tflite.py +7 -1
- ai_edge_torch/generative/examples/smollm/smollm.py +14 -2
- ai_edge_torch/generative/examples/smollm/verify.py +2 -2
- ai_edge_torch/generative/examples/stable_diffusion/clip.py +2 -1
- ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +7 -1
- ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +8 -1
- ai_edge_torch/generative/utilities/converter.py +16 -4
- ai_edge_torch/generative/utilities/loader.py +19 -0
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.5.0.dev20250515.dist-info → ai_edge_torch_nightly-0.5.0.dev20250516.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.5.0.dev20250515.dist-info → ai_edge_torch_nightly-0.5.0.dev20250516.dist-info}/RECORD +54 -54
- {ai_edge_torch_nightly-0.5.0.dev20250515.dist-info → ai_edge_torch_nightly-0.5.0.dev20250516.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.5.0.dev20250515.dist-info → ai_edge_torch_nightly-0.5.0.dev20250516.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.5.0.dev20250515.dist-info → ai_edge_torch_nightly-0.5.0.dev20250516.dist-info}/top_level.txt +0 -0
ai_edge_torch/__init__.py
CHANGED
@@ -14,6 +14,7 @@
|
|
14
14
|
# ==============================================================================
|
15
15
|
from ai_edge_torch._config import config
|
16
16
|
from ai_edge_torch._convert.converter import convert
|
17
|
+
from ai_edge_torch._convert.converter import experimental_add_compilation_backend
|
17
18
|
from ai_edge_torch._convert.converter import signature
|
18
19
|
from ai_edge_torch._convert.to_channel_last_io import to_channel_last_io
|
19
20
|
from ai_edge_torch.model import Model
|
@@ -26,6 +26,9 @@ from ai_edge_torch.generative import fx_passes as generative_fx_passes
|
|
26
26
|
from ai_edge_torch.quantize import quant_config as qcfg
|
27
27
|
import torch
|
28
28
|
|
29
|
+
from ai_edge_litert.aot import aot_compile as aot_compile_lib
|
30
|
+
from ai_edge_litert.aot.core import types as litert_types
|
31
|
+
|
29
32
|
|
30
33
|
def _run_convert_passes(
|
31
34
|
exported_program: torch.export.ExportedProgram,
|
@@ -153,3 +156,23 @@ def convert_signatures(
|
|
153
156
|
)
|
154
157
|
|
155
158
|
return model.TfLiteModel(tflite_model)
|
159
|
+
|
160
|
+
|
161
|
+
def aot_compile(
|
162
|
+
compilation_configs: list[litert_types.CompilationConfig],
|
163
|
+
cpu_model: model.TfLiteModel,
|
164
|
+
) -> litert_types.CompilationResult:
|
165
|
+
"""Compiles the given CPU model.
|
166
|
+
|
167
|
+
Args:
|
168
|
+
compilation_configs: The list of compilation configs to use.
|
169
|
+
cpu_model: The CPU model to compile.
|
170
|
+
|
171
|
+
Returns:
|
172
|
+
The compilation result.
|
173
|
+
"""
|
174
|
+
litert_model = litert_types.Model.create_from_bytes(cpu_model.tflite_model())
|
175
|
+
return aot_compile_lib.aot_compile(
|
176
|
+
litert_model,
|
177
|
+
config=compilation_configs,
|
178
|
+
)
|
@@ -23,6 +23,9 @@ from ai_edge_torch._convert import signature as signature_module
|
|
23
23
|
from ai_edge_torch.quantize import quant_config as qcfg
|
24
24
|
import torch
|
25
25
|
|
26
|
+
from ai_edge_litert.aot.core import types as litert_types
|
27
|
+
from ai_edge_litert.aot.vendors import import_vendor as vendor_lib
|
28
|
+
|
26
29
|
|
27
30
|
class Converter:
|
28
31
|
"""A converter for converting PyTorch models to edge models.
|
@@ -32,6 +35,7 @@ class Converter:
|
|
32
35
|
|
33
36
|
def __init__(self):
|
34
37
|
self._signatures: list[signature_module.Signature] = []
|
38
|
+
self._compilation_configs: list[litert_types.CompilationConfig] = []
|
35
39
|
|
36
40
|
def signature(
|
37
41
|
self,
|
@@ -96,6 +100,31 @@ class Converter:
|
|
96
100
|
)
|
97
101
|
return self
|
98
102
|
|
103
|
+
def experimental_add_compilation_backend(
|
104
|
+
self,
|
105
|
+
target: litert_types.Target | None = None,
|
106
|
+
**kwargs: litert_types.Config,
|
107
|
+
) -> Converter:
|
108
|
+
"""Adds an AOT compilation target to the converter.
|
109
|
+
|
110
|
+
NOTE: This API is experimental and subject to change.
|
111
|
+
|
112
|
+
Args:
|
113
|
+
target: The target to compile for. If not specified, will compile to all
|
114
|
+
registered AOT targets in ai_edge_litert. See ai_edge_litert.aot.vendors
|
115
|
+
for more details. Adding the same target multiple times is a no-op.
|
116
|
+
**kwargs: Additional arguments to pass to the backend compiler.
|
117
|
+
|
118
|
+
Returns:
|
119
|
+
The converter object itself.
|
120
|
+
"""
|
121
|
+
if target is None:
|
122
|
+
target = vendor_lib.AllRegisteredTarget()
|
123
|
+
if isinstance(target, litert_types.Target):
|
124
|
+
target = litert_types.CompilationConfig(target=target, **kwargs)
|
125
|
+
self._compilation_configs.append(target)
|
126
|
+
return self
|
127
|
+
|
99
128
|
def convert(
|
100
129
|
self,
|
101
130
|
module: torch.nn.Module = None,
|
@@ -107,7 +136,7 @@ class Converter:
|
|
107
136
|
dynamic_shapes: Optional[Union[dict[str, Any], Tuple[Any, ...]]] = None,
|
108
137
|
_ai_edge_converter_flags: Optional[dict[str, Any]] = None,
|
109
138
|
_saved_model_dir: Optional[str] = None,
|
110
|
-
) -> model.TfLiteModel:
|
139
|
+
) -> model.TfLiteModel | litert_types.CompilationResult:
|
111
140
|
"""Finalizes the conversion and produces an edge model.
|
112
141
|
|
113
142
|
This could be called with no arguments as follows:
|
@@ -144,7 +173,9 @@ class Converter:
|
|
144
173
|
specified, a random temporary directory would be used.
|
145
174
|
|
146
175
|
Returns:
|
147
|
-
The converted edge model.
|
176
|
+
The converted edge model. If compilation configs are provided, returns the
|
177
|
+
compilation result that contains the compiled edge models for different
|
178
|
+
targets.
|
148
179
|
|
149
180
|
Raises:
|
150
181
|
ValueError: If the arguments are not provided as expected. See the example
|
@@ -169,13 +200,16 @@ class Converter:
|
|
169
200
|
"sample_args or sample_kwargs must be provided if a module is"
|
170
201
|
" specified."
|
171
202
|
)
|
172
|
-
|
203
|
+
converted_model = conversion.convert_signatures(
|
173
204
|
self._signatures,
|
174
205
|
strict_export=strict_export,
|
175
206
|
quant_config=quant_config,
|
176
207
|
_tfl_converter_flags=_ai_edge_converter_flags,
|
177
208
|
_saved_model_dir=_saved_model_dir,
|
178
209
|
)
|
210
|
+
if self._compilation_configs:
|
211
|
+
return conversion.aot_compile(self._compilation_configs, converted_model)
|
212
|
+
return converted_model
|
179
213
|
|
180
214
|
|
181
215
|
def signature(
|
@@ -211,6 +245,26 @@ def signature(
|
|
211
245
|
)
|
212
246
|
|
213
247
|
|
248
|
+
def experimental_add_compilation_backend(
|
249
|
+
target: litert_types.Target | None = None,
|
250
|
+
**kwargs: litert_types.Config,
|
251
|
+
) -> Converter:
|
252
|
+
"""Adds an AOT compilation target to the converter.
|
253
|
+
|
254
|
+
NOTE: This API is experimental and subject to change.
|
255
|
+
|
256
|
+
Args:
|
257
|
+
target: The target to compile for. If not specified, will compile to all
|
258
|
+
registered AOT targets in ai_edge_litert. See ai_edge_litert.aot.vendors
|
259
|
+
for more details. Adding the same target multiple times is a no-op.
|
260
|
+
**kwargs: Additional arguments to pass to the backend compiler.
|
261
|
+
|
262
|
+
Returns:
|
263
|
+
The converter object itself.
|
264
|
+
"""
|
265
|
+
return Converter().experimental_add_compilation_backend(target, **kwargs)
|
266
|
+
|
267
|
+
|
214
268
|
def convert(
|
215
269
|
module: torch.nn.Module = None,
|
216
270
|
sample_args=None,
|
@@ -29,6 +29,8 @@ from torch.ao.quantization import quantize_pt2e
|
|
29
29
|
import torchvision
|
30
30
|
|
31
31
|
from absl.testing import absltest as googletest
|
32
|
+
from ai_edge_litert.aot.core import types as litert_types
|
33
|
+
from ai_edge_litert.aot.vendors import fallback_backend
|
32
34
|
from ai_edge_litert import interpreter as tfl_interpreter # pylint: disable=g-direct-tensorflow-import
|
33
35
|
|
34
36
|
|
@@ -574,6 +576,29 @@ class TestConvert(googletest.TestCase):
|
|
574
576
|
self.fail(f"Conversion failed with bloat16 inputs: {err}")
|
575
577
|
# pylint: enable=broad-except
|
576
578
|
|
579
|
+
def test_compile_model(self):
|
580
|
+
"""Tests AOT compilation of a simple Add module."""
|
581
|
+
|
582
|
+
class Add(nn.Module):
|
583
|
+
|
584
|
+
def forward(self, a, b):
|
585
|
+
return a + b
|
586
|
+
|
587
|
+
args = (
|
588
|
+
torch.randn((5, 10)),
|
589
|
+
torch.randn((5, 10)),
|
590
|
+
)
|
591
|
+
torch_module = Add().eval()
|
592
|
+
compilation_result = ai_edge_torch.experimental_add_compilation_backend(
|
593
|
+
fallback_backend.FallbackTarget()
|
594
|
+
).convert(torch_module, args)
|
595
|
+
assert isinstance(compilation_result, litert_types.CompilationResult)
|
596
|
+
self.assertLen(compilation_result.models_with_backend, 1)
|
597
|
+
self.assertEqual(
|
598
|
+
compilation_result.models_with_backend[0][0].id(),
|
599
|
+
fallback_backend.FallbackBackend.id(),
|
600
|
+
)
|
601
|
+
|
577
602
|
|
578
603
|
if __name__ == "__main__":
|
579
604
|
googletest.main()
|
@@ -15,8 +15,10 @@
|
|
15
15
|
|
16
16
|
"""Example of building AMD-Llama-135m."""
|
17
17
|
|
18
|
+
from typing import Callable, Dict
|
18
19
|
import ai_edge_torch.generative.layers.model_config as cfg
|
19
20
|
from ai_edge_torch.generative.utilities import model_builder
|
21
|
+
import torch
|
20
22
|
from torch import nn
|
21
23
|
|
22
24
|
TENSOR_NAMES = model_builder.TENSOR_NAMES_WITH_SEPARATE_LM_HEAD
|
@@ -80,10 +82,15 @@ def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
|
|
80
82
|
return config
|
81
83
|
|
82
84
|
|
83
|
-
def build_model(
|
85
|
+
def build_model(
|
86
|
+
checkpoint_path: str,
|
87
|
+
custom_loader: Callable[[str], Dict[str, torch.Tensor]] | None = None,
|
88
|
+
**kwargs
|
89
|
+
) -> nn.Module:
|
84
90
|
return model_builder.build_decoder_only_model(
|
85
91
|
checkpoint_path=checkpoint_path,
|
86
92
|
config=get_model_config(**kwargs),
|
87
93
|
tensor_names=TENSOR_NAMES,
|
88
|
-
model_class=AmdLlama
|
94
|
+
model_class=AmdLlama,
|
95
|
+
custom_loader=custom_loader,
|
89
96
|
)
|
@@ -19,13 +19,19 @@ from absl import app
|
|
19
19
|
from ai_edge_torch.generative.examples.amd_llama_135m import amd_llama_135m
|
20
20
|
from ai_edge_torch.generative.utilities import converter
|
21
21
|
from ai_edge_torch.generative.utilities import export_config
|
22
|
+
from ai_edge_torch.generative.utilities import loader
|
22
23
|
|
23
24
|
flags = converter.define_conversion_flags("amd-llama-135m")
|
24
25
|
|
25
26
|
|
26
27
|
def main(_):
|
28
|
+
checkpoint_path = flags.FLAGS.checkpoint_path
|
27
29
|
pytorch_model = amd_llama_135m.build_model(
|
28
|
-
|
30
|
+
checkpoint_path,
|
31
|
+
custom_loader=loader.maybe_get_custom_loader(
|
32
|
+
checkpoint_path, flags.FLAGS.custom_checkpoint_loader
|
33
|
+
),
|
34
|
+
kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
|
29
35
|
)
|
30
36
|
converter.convert_to_tflite(
|
31
37
|
pytorch_model,
|
@@ -17,15 +17,20 @@
|
|
17
17
|
|
18
18
|
from absl import app
|
19
19
|
from ai_edge_torch.generative.examples.deepseek import deepseek
|
20
|
-
from ai_edge_torch.generative.layers import kv_cache
|
21
20
|
from ai_edge_torch.generative.utilities import converter
|
22
21
|
from ai_edge_torch.generative.utilities import export_config
|
22
|
+
from ai_edge_torch.generative.utilities import loader
|
23
23
|
|
24
24
|
flags = converter.define_conversion_flags('deepseek')
|
25
25
|
|
26
26
|
def main(_):
|
27
|
+
checkpoint_path = flags.FLAGS.checkpoint_path
|
27
28
|
pytorch_model = deepseek.build_model(
|
28
|
-
|
29
|
+
checkpoint_path,
|
30
|
+
custom_loader=loader.maybe_get_custom_loader(
|
31
|
+
checkpoint_path, flags.FLAGS.custom_checkpoint_loader
|
32
|
+
),
|
33
|
+
kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
|
29
34
|
)
|
30
35
|
converter.convert_to_tflite(
|
31
36
|
pytorch_model,
|
@@ -15,8 +15,10 @@
|
|
15
15
|
|
16
16
|
"""Example of building DeepSeek R1 distilled models."""
|
17
17
|
|
18
|
+
from typing import Callable, Dict
|
18
19
|
import ai_edge_torch.generative.layers.model_config as cfg
|
19
20
|
from ai_edge_torch.generative.utilities import model_builder
|
21
|
+
import torch
|
20
22
|
from torch import nn
|
21
23
|
|
22
24
|
TENSOR_NAMES = model_builder.TENSOR_NAMES_WITH_SEPARATE_LM_HEAD
|
@@ -84,10 +86,15 @@ def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
|
|
84
86
|
return config
|
85
87
|
|
86
88
|
|
87
|
-
def build_model(
|
89
|
+
def build_model(
|
90
|
+
checkpoint_path: str,
|
91
|
+
custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
|
92
|
+
**kwargs
|
93
|
+
) -> nn.Module:
|
88
94
|
return model_builder.build_decoder_only_model(
|
89
95
|
checkpoint_path=checkpoint_path,
|
90
96
|
config=get_model_config(**kwargs),
|
91
97
|
tensor_names=TENSOR_NAMES,
|
92
98
|
model_class=DeepSeekDistillQwen,
|
99
|
+
custom_loader=custom_loader,
|
93
100
|
)
|
@@ -19,13 +19,19 @@ from absl import app
|
|
19
19
|
from ai_edge_torch.generative.examples.gemma import gemma1
|
20
20
|
from ai_edge_torch.generative.utilities import converter
|
21
21
|
from ai_edge_torch.generative.utilities import export_config
|
22
|
+
from ai_edge_torch.generative.utilities import loader
|
22
23
|
|
23
24
|
flags = converter.define_conversion_flags("gemma-2b")
|
24
25
|
|
25
26
|
|
26
27
|
def main(_):
|
28
|
+
checkpoint_path = flags.FLAGS.checkpoint_path
|
27
29
|
pytorch_model = gemma1.build_2b_model(
|
28
|
-
|
30
|
+
checkpoint_path,
|
31
|
+
custom_loader=loader.maybe_get_custom_loader(
|
32
|
+
checkpoint_path, flags.FLAGS.custom_checkpoint_loader
|
33
|
+
),
|
34
|
+
kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
|
29
35
|
)
|
30
36
|
converter.convert_to_tflite(
|
31
37
|
pytorch_model,
|
@@ -19,6 +19,7 @@ from absl import app
|
|
19
19
|
from ai_edge_torch.generative.examples.gemma import gemma2
|
20
20
|
from ai_edge_torch.generative.utilities import converter
|
21
21
|
from ai_edge_torch.generative.utilities import export_config
|
22
|
+
from ai_edge_torch.generative.utilities import loader
|
22
23
|
|
23
24
|
flags = converter.define_conversion_flags(
|
24
25
|
"gemma2-2b", default_mask_as_input=True, default_transpose_kv_cache=True
|
@@ -26,8 +27,13 @@ flags = converter.define_conversion_flags(
|
|
26
27
|
|
27
28
|
|
28
29
|
def main(_):
|
30
|
+
checkpoint_path = flags.FLAGS.checkpoint_path
|
29
31
|
pytorch_model = gemma2.build_2b_model(
|
30
|
-
|
32
|
+
checkpoint_path,
|
33
|
+
custom_loader=loader.maybe_get_custom_loader(
|
34
|
+
checkpoint_path, flags.FLAGS.custom_checkpoint_loader
|
35
|
+
),
|
36
|
+
kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
|
31
37
|
)
|
32
38
|
converter.convert_to_tflite(
|
33
39
|
pytorch_model,
|
@@ -15,9 +15,12 @@
|
|
15
15
|
|
16
16
|
"""Example of building a Gemma1 model."""
|
17
17
|
|
18
|
+
from typing import Callable, Dict
|
19
|
+
|
18
20
|
import ai_edge_torch.generative.layers.model_config as cfg
|
19
21
|
from ai_edge_torch.generative.utilities import model_builder
|
20
22
|
import ai_edge_torch.generative.utilities.loader as loading_utils
|
23
|
+
import torch
|
21
24
|
from torch import nn
|
22
25
|
|
23
26
|
TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
|
@@ -99,10 +102,15 @@ def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
|
|
99
102
|
return config
|
100
103
|
|
101
104
|
|
102
|
-
def build_2b_model(
|
105
|
+
def build_2b_model(
|
106
|
+
checkpoint_path: str,
|
107
|
+
custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
|
108
|
+
**kwargs
|
109
|
+
) -> nn.Module:
|
103
110
|
return model_builder.build_decoder_only_model(
|
104
111
|
checkpoint_path=checkpoint_path,
|
105
112
|
config=get_model_config_2b(**kwargs),
|
106
113
|
tensor_names=TENSOR_NAMES,
|
107
114
|
model_class=Gemma1,
|
115
|
+
custom_loader=custom_loader,
|
108
116
|
)
|
@@ -15,7 +15,7 @@
|
|
15
15
|
|
16
16
|
"""Example of building a Gemma2 model."""
|
17
17
|
|
18
|
-
from typing import List, Optional, Tuple
|
18
|
+
from typing import Callable, Dict, List, Optional, Tuple
|
19
19
|
|
20
20
|
from ai_edge_torch.generative.layers import attention
|
21
21
|
from ai_edge_torch.generative.layers import builder
|
@@ -306,7 +306,11 @@ def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
|
|
306
306
|
return config
|
307
307
|
|
308
308
|
|
309
|
-
def build_2b_model(
|
309
|
+
def build_2b_model(
|
310
|
+
checkpoint_path: str,
|
311
|
+
custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
|
312
|
+
**kwargs,
|
313
|
+
) -> nn.Module:
|
310
314
|
for tensor_names in TENSOR_NAMES_DICT.values():
|
311
315
|
try:
|
312
316
|
return model_builder.build_decoder_only_model(
|
@@ -314,6 +318,7 @@ def build_2b_model(checkpoint_path: str, **kwargs) -> nn.Module:
|
|
314
318
|
config=get_model_config_2b(**kwargs),
|
315
319
|
tensor_names=tensor_names,
|
316
320
|
model_class=Gemma2,
|
321
|
+
custom_loader=custom_loader,
|
317
322
|
)
|
318
323
|
except KeyError as _:
|
319
324
|
continue
|
@@ -25,13 +25,6 @@ flags = converter.define_conversion_flags(
|
|
25
25
|
'gemma3-1b', default_mask_as_input=True, default_transpose_kv_cache=True
|
26
26
|
)
|
27
27
|
|
28
|
-
_CUSTOM_CHECKPOINT_LOADER = flags.DEFINE_bool(
|
29
|
-
'custom_checkpoint_loader',
|
30
|
-
False,
|
31
|
-
'If true, the conversion script will use a custom checkpoint loader which'
|
32
|
-
' will read a checkpoint from a remote source.',
|
33
|
-
)
|
34
|
-
|
35
28
|
_MODEL_SIZE = flags.DEFINE_string(
|
36
29
|
'model_size',
|
37
30
|
'1b',
|
@@ -40,16 +33,14 @@ _MODEL_SIZE = flags.DEFINE_string(
|
|
40
33
|
|
41
34
|
|
42
35
|
def main(_):
|
43
|
-
|
44
|
-
if flags.FLAGS.custom_checkpoint_loader:
|
45
|
-
# If loading from a remote source, try to get a custom loader first.
|
46
|
-
custom_loader = loader.get_custom_loader(flags.FLAGS.checkpoint_path)
|
47
|
-
|
36
|
+
checkpoint_path = flags.FLAGS.checkpoint_path
|
48
37
|
if _MODEL_SIZE.value == '1b':
|
49
38
|
pytorch_model = gemma3.build_model_1b(
|
50
|
-
|
39
|
+
checkpoint_path,
|
40
|
+
custom_loader=loader.maybe_get_custom_loader(
|
41
|
+
checkpoint_path, flags.FLAGS.custom_checkpoint_loader
|
42
|
+
),
|
51
43
|
kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
|
52
|
-
custom_loader=custom_loader,
|
53
44
|
)
|
54
45
|
else:
|
55
46
|
raise ValueError(f'Unsupported model size: {_MODEL_SIZE.value}')
|
@@ -19,6 +19,7 @@ from absl import app
|
|
19
19
|
from ai_edge_torch.generative.examples.hammer import hammer
|
20
20
|
from ai_edge_torch.generative.utilities import converter
|
21
21
|
from ai_edge_torch.generative.utilities import export_config
|
22
|
+
from ai_edge_torch.generative.utilities import loader
|
22
23
|
|
23
24
|
flags = converter.define_conversion_flags('hammer')
|
24
25
|
|
@@ -36,8 +37,13 @@ _BUILDER = {
|
|
36
37
|
|
37
38
|
|
38
39
|
def main(_):
|
40
|
+
checkpoint_path = flags.FLAGS.checkpoint_path
|
39
41
|
pytorch_model = _BUILDER[_MODEL_SIZE.value](
|
40
|
-
|
42
|
+
checkpoint_path,
|
43
|
+
custom_loader=loader.maybe_get_custom_loader(
|
44
|
+
checkpoint_path, flags.FLAGS.custom_checkpoint_loader
|
45
|
+
),
|
46
|
+
kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
|
41
47
|
)
|
42
48
|
converter.convert_to_tflite(
|
43
49
|
pytorch_model,
|
@@ -15,8 +15,10 @@
|
|
15
15
|
|
16
16
|
"""Example of building Hammer 2.1 models."""
|
17
17
|
|
18
|
+
from typing import Callable, Dict
|
18
19
|
import ai_edge_torch.generative.layers.model_config as cfg
|
19
20
|
from ai_edge_torch.generative.utilities import model_builder
|
21
|
+
import torch
|
20
22
|
from torch import nn
|
21
23
|
|
22
24
|
TENSOR_NAMES = model_builder.TENSOR_NAMES
|
@@ -89,19 +91,29 @@ def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
|
|
89
91
|
return config
|
90
92
|
|
91
93
|
|
92
|
-
def build_1_5b_model(
|
94
|
+
def build_1_5b_model(
|
95
|
+
checkpoint_path: str,
|
96
|
+
custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
|
97
|
+
**kwargs
|
98
|
+
) -> nn.Module:
|
93
99
|
return model_builder.build_decoder_only_model(
|
94
100
|
checkpoint_path=checkpoint_path,
|
95
101
|
config=get_1_5b_model_config(**kwargs),
|
96
102
|
tensor_names=TENSOR_NAMES,
|
97
103
|
model_class=Hammer,
|
104
|
+
custom_loader=custom_loader,
|
98
105
|
)
|
99
106
|
|
100
107
|
|
101
|
-
def build_0_5b_model(
|
108
|
+
def build_0_5b_model(
|
109
|
+
checkpoint_path: str,
|
110
|
+
custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
|
111
|
+
**kwargs
|
112
|
+
) -> nn.Module:
|
102
113
|
return model_builder.build_decoder_only_model(
|
103
114
|
checkpoint_path=checkpoint_path,
|
104
115
|
config=get_0_5b_model_config(**kwargs),
|
105
116
|
tensor_names=TENSOR_NAMES,
|
106
117
|
model_class=Hammer,
|
118
|
+
custom_loader=custom_loader,
|
107
119
|
)
|
@@ -19,6 +19,7 @@ from absl import app
|
|
19
19
|
from ai_edge_torch.generative.examples.llama import llama
|
20
20
|
from ai_edge_torch.generative.utilities import converter
|
21
21
|
from ai_edge_torch.generative.utilities import export_config
|
22
|
+
from ai_edge_torch.generative.utilities import loader
|
22
23
|
|
23
24
|
|
24
25
|
flags = converter.define_conversion_flags('llama')
|
@@ -37,8 +38,13 @@ _BUILDER = {
|
|
37
38
|
|
38
39
|
|
39
40
|
def main(_):
|
41
|
+
checkpoint_path = flags.FLAGS.checkpoint_path
|
40
42
|
pytorch_model = _BUILDER[_MODEL_SIZE.value](
|
41
|
-
|
43
|
+
checkpoint_path,
|
44
|
+
custom_loader=loader.maybe_get_custom_loader(
|
45
|
+
checkpoint_path, flags.FLAGS.custom_checkpoint_loader
|
46
|
+
),
|
47
|
+
kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
|
42
48
|
)
|
43
49
|
converter.convert_to_tflite(
|
44
50
|
pytorch_model,
|
@@ -17,7 +17,7 @@
|
|
17
17
|
|
18
18
|
from functools import partial
|
19
19
|
import math
|
20
|
-
from typing import Tuple
|
20
|
+
from typing import Callable, Dict, Tuple
|
21
21
|
|
22
22
|
import ai_edge_torch.generative.layers.model_config as cfg
|
23
23
|
from ai_edge_torch.generative.utilities import model_builder
|
@@ -180,19 +180,38 @@ def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
|
|
180
180
|
|
181
181
|
|
182
182
|
def _build_model(
|
183
|
-
checkpoint_path: str,
|
183
|
+
checkpoint_path: str,
|
184
|
+
config: cfg.ModelConfig,
|
185
|
+
custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
|
184
186
|
) -> torch.nn.Module:
|
185
187
|
return model_builder.build_decoder_only_model(
|
186
188
|
checkpoint_path=checkpoint_path,
|
187
189
|
config=config,
|
188
190
|
tensor_names=TENSOR_NAMES,
|
189
191
|
model_class=Llama,
|
192
|
+
custom_loader=custom_loader,
|
190
193
|
)
|
191
194
|
|
192
195
|
|
193
|
-
def build_1b_model(
|
194
|
-
|
196
|
+
def build_1b_model(
|
197
|
+
checkpoint_path: str,
|
198
|
+
custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
|
199
|
+
**kwargs
|
200
|
+
) -> torch.nn.Module:
|
201
|
+
return _build_model(
|
202
|
+
checkpoint_path,
|
203
|
+
get_1b_model_config(**kwargs),
|
204
|
+
custom_loader=custom_loader,
|
205
|
+
)
|
195
206
|
|
196
207
|
|
197
|
-
def build_3b_model(
|
198
|
-
|
208
|
+
def build_3b_model(
|
209
|
+
checkpoint_path: str,
|
210
|
+
custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
|
211
|
+
**kwargs
|
212
|
+
) -> torch.nn.Module:
|
213
|
+
return _build_model(
|
214
|
+
checkpoint_path,
|
215
|
+
get_3b_model_config(**kwargs),
|
216
|
+
custom_loader=custom_loader,
|
217
|
+
)
|
@@ -19,13 +19,19 @@ from absl import app
|
|
19
19
|
from ai_edge_torch.generative.examples.openelm import openelm
|
20
20
|
from ai_edge_torch.generative.utilities import converter
|
21
21
|
from ai_edge_torch.generative.utilities import export_config
|
22
|
+
from ai_edge_torch.generative.utilities import loader
|
22
23
|
|
23
24
|
flags = converter.define_conversion_flags("openelm")
|
24
25
|
|
25
26
|
|
26
27
|
def main(_):
|
28
|
+
checkpoint_path = flags.FLAGS.checkpoint_path
|
27
29
|
pytorch_model = openelm.build_model(
|
28
|
-
|
30
|
+
checkpoint_path,
|
31
|
+
custom_loader=loader.maybe_get_custom_loader(
|
32
|
+
checkpoint_path, flags.FLAGS.custom_checkpoint_loader
|
33
|
+
),
|
34
|
+
kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
|
29
35
|
)
|
30
36
|
converter.convert_to_tflite(
|
31
37
|
pytorch_model,
|
@@ -15,9 +15,11 @@
|
|
15
15
|
|
16
16
|
"""Example of building an OpenELM model."""
|
17
17
|
|
18
|
+
from typing import Callable, Dict
|
18
19
|
import ai_edge_torch.generative.layers.model_config as cfg
|
19
20
|
from ai_edge_torch.generative.utilities import model_builder
|
20
21
|
import ai_edge_torch.generative.utilities.loader as loading_utils
|
22
|
+
import torch
|
21
23
|
from torch import nn
|
22
24
|
|
23
25
|
TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
|
@@ -118,10 +120,15 @@ def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
|
|
118
120
|
return config
|
119
121
|
|
120
122
|
|
121
|
-
def build_model(
|
123
|
+
def build_model(
|
124
|
+
checkpoint_path: str,
|
125
|
+
custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
|
126
|
+
**kwargs
|
127
|
+
) -> nn.Module:
|
122
128
|
return model_builder.build_decoder_only_model(
|
123
129
|
checkpoint_path=checkpoint_path,
|
124
130
|
config=get_model_config(**kwargs),
|
125
131
|
tensor_names=TENSOR_NAMES,
|
126
132
|
model_class=OpenELM,
|
133
|
+
custom_loader=custom_loader,
|
127
134
|
)
|
@@ -19,6 +19,7 @@ from absl import app
|
|
19
19
|
from ai_edge_torch.generative.examples.paligemma import paligemma
|
20
20
|
from ai_edge_torch.generative.utilities import converter
|
21
21
|
from ai_edge_torch.generative.utilities import export_config
|
22
|
+
from ai_edge_torch.generative.utilities import loader
|
22
23
|
import torch
|
23
24
|
|
24
25
|
flags = converter.define_conversion_flags('paligemma2-3b-224')
|
@@ -32,9 +33,13 @@ _VERSION = flags.DEFINE_enum(
|
|
32
33
|
|
33
34
|
|
34
35
|
def main(_):
|
36
|
+
checkpoint_path = flags.FLAGS.checkpoint_path
|
35
37
|
pytorch_model = paligemma.build_model(
|
36
|
-
|
38
|
+
checkpoint_path,
|
37
39
|
version=int(_VERSION.value),
|
40
|
+
custom_loader=loader.maybe_get_custom_loader(
|
41
|
+
checkpoint_path, flags.FLAGS.custom_checkpoint_loader
|
42
|
+
),
|
38
43
|
kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
|
39
44
|
)
|
40
45
|
|
@@ -113,6 +113,7 @@ def get_decoder_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
|
|
113
113
|
type=cfg.NormalizationType.RMS_NORM,
|
114
114
|
epsilon=1e-6,
|
115
115
|
zero_centered=True,
|
116
|
+
enable_hlfb=True,
|
116
117
|
)
|
117
118
|
block_config = cfg.TransformerBlockConfig(
|
118
119
|
attn_config=attn_config,
|