ai-edge-torch-nightly 0.5.0.dev20250515__py3-none-any.whl → 0.5.0.dev20250517__py3-none-any.whl
This diff shows the content changes between two publicly released versions of this package, as they appear in their public registry. It is provided for informational purposes only.
- ai_edge_torch/__init__.py +1 -0
- ai_edge_torch/_convert/conversion.py +24 -0
- ai_edge_torch/_convert/converter.py +57 -3
- ai_edge_torch/_convert/fx_passes/__init__.py +1 -0
- ai_edge_torch/_convert/fx_passes/eliminate_dead_code_pass.py +40 -0
- ai_edge_torch/_convert/test/test_convert.py +25 -0
- ai_edge_torch/generative/examples/amd_llama_135m/amd_llama_135m.py +10 -6
- ai_edge_torch/generative/examples/amd_llama_135m/convert_to_tflite.py +7 -1
- ai_edge_torch/generative/examples/deepseek/convert_to_tflite.py +7 -2
- ai_edge_torch/generative/examples/deepseek/deepseek.py +9 -5
- ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py +7 -1
- ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +7 -1
- ai_edge_torch/generative/examples/gemma/gemma1.py +10 -6
- ai_edge_torch/generative/examples/gemma/gemma2.py +8 -7
- ai_edge_torch/generative/examples/gemma3/convert_gemma3_to_tflite.py +5 -14
- ai_edge_torch/generative/examples/gemma3/decoder.py +10 -10
- ai_edge_torch/generative/examples/gemma3/gemma3.py +1 -3
- ai_edge_torch/generative/examples/gemma3/image_encoder.py +1 -4
- ai_edge_torch/generative/examples/hammer/convert_to_tflite.py +7 -1
- ai_edge_torch/generative/examples/hammer/hammer.py +15 -6
- ai_edge_torch/generative/examples/llama/convert_to_tflite.py +7 -1
- ai_edge_torch/generative/examples/llama/llama.py +26 -10
- ai_edge_torch/generative/examples/moonshine/convert_moonshine_to_tflite.py +0 -1
- ai_edge_torch/generative/examples/openelm/convert_to_tflite.py +7 -1
- ai_edge_torch/generative/examples/openelm/openelm.py +9 -3
- ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py +6 -1
- ai_edge_torch/generative/examples/paligemma/decoder.py +1 -4
- ai_edge_torch/generative/examples/paligemma/decoder2.py +1 -4
- ai_edge_torch/generative/examples/paligemma/image_encoder.py +3 -5
- ai_edge_torch/generative/examples/paligemma/paligemma.py +12 -5
- ai_edge_torch/generative/examples/paligemma/verify.py +27 -5
- ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py +7 -1
- ai_edge_torch/generative/examples/phi/convert_phi4_to_tflite.py +7 -1
- ai_edge_torch/generative/examples/phi/convert_to_tflite.py +7 -1
- ai_edge_torch/generative/examples/phi/phi2.py +9 -5
- ai_edge_torch/generative/examples/phi/phi3.py +8 -6
- ai_edge_torch/generative/examples/phi/phi4.py +8 -6
- ai_edge_torch/generative/examples/qwen/convert_to_tflite.py +7 -1
- ai_edge_torch/generative/examples/qwen/qwen.py +21 -7
- ai_edge_torch/generative/examples/qwen_vl/convert_to_tflite.py +6 -1
- ai_edge_torch/generative/examples/qwen_vl/decoder.py +1 -3
- ai_edge_torch/generative/examples/qwen_vl/image_encoder.py +13 -7
- ai_edge_torch/generative/examples/qwen_vl/qwen_vl.py +12 -4
- ai_edge_torch/generative/examples/qwen_vl/verify.py +26 -5
- ai_edge_torch/generative/examples/smollm/convert_to_tflite.py +7 -2
- ai_edge_torch/generative/examples/smollm/convert_v2_to_tflite.py +7 -1
- ai_edge_torch/generative/examples/smollm/smollm.py +15 -6
- ai_edge_torch/generative/examples/smollm/verify.py +2 -2
- ai_edge_torch/generative/examples/stable_diffusion/clip.py +8 -5
- ai_edge_torch/generative/examples/t5/t5.py +1 -3
- ai_edge_torch/generative/examples/test_models/toy_model.py +4 -1
- ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +3 -2
- ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +7 -1
- ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +9 -5
- ai_edge_torch/generative/layers/model_config.py +2 -2
- ai_edge_torch/generative/utilities/converter.py +18 -5
- ai_edge_torch/generative/utilities/loader.py +19 -0
- ai_edge_torch/odml_torch/lowerings/utils.py +13 -0
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.5.0.dev20250515.dist-info → ai_edge_torch_nightly-0.5.0.dev20250517.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.5.0.dev20250515.dist-info → ai_edge_torch_nightly-0.5.0.dev20250517.dist-info}/RECORD +64 -63
- {ai_edge_torch_nightly-0.5.0.dev20250515.dist-info → ai_edge_torch_nightly-0.5.0.dev20250517.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.5.0.dev20250515.dist-info → ai_edge_torch_nightly-0.5.0.dev20250517.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.5.0.dev20250515.dist-info → ai_edge_torch_nightly-0.5.0.dev20250517.dist-info}/top_level.txt +0 -0
ai_edge_torch/__init__.py
CHANGED
@@ -14,6 +14,7 @@
 # ==============================================================================
 from ai_edge_torch._config import config
 from ai_edge_torch._convert.converter import convert
+from ai_edge_torch._convert.converter import experimental_add_compilation_backend
 from ai_edge_torch._convert.converter import signature
 from ai_edge_torch._convert.to_channel_last_io import to_channel_last_io
 from ai_edge_torch.model import Model
ai_edge_torch/_convert/conversion.py
CHANGED
@@ -26,6 +26,9 @@ from ai_edge_torch.generative import fx_passes as generative_fx_passes
 from ai_edge_torch.quantize import quant_config as qcfg
 import torch
 
+from ai_edge_litert.aot import aot_compile as aot_compile_lib
+from ai_edge_litert.aot.core import types as litert_types
+
 
 def _run_convert_passes(
     exported_program: torch.export.ExportedProgram,
@@ -35,6 +38,7 @@ def _run_convert_passes(
   )
 
   passes = [
+      fx_passes.EliminateDeadCodePass(),
       fx_passes.OptimizeLayoutTransposesPass(),
       fx_passes.CanonicalizePass(),
       fx_passes.BuildAtenCompositePass(),
@@ -153,3 +157,23 @@ def convert_signatures(
   )
 
   return model.TfLiteModel(tflite_model)
+
+
+def aot_compile(
+    compilation_configs: list[litert_types.CompilationConfig],
+    cpu_model: model.TfLiteModel,
+) -> litert_types.CompilationResult:
+  """Compiles the given CPU model.
+
+  Args:
+    compilation_configs: The list of compilation configs to use.
+    cpu_model: The CPU model to compile.
+
+  Returns:
+    The compilation result.
+  """
+  litert_model = litert_types.Model.create_from_bytes(cpu_model.tflite_model())
+  return aot_compile_lib.aot_compile(
+      litert_model,
+      config=compilation_configs,
+  )
ai_edge_torch/_convert/converter.py
CHANGED
@@ -23,6 +23,9 @@ from ai_edge_torch._convert import signature as signature_module
 from ai_edge_torch.quantize import quant_config as qcfg
 import torch
 
+from ai_edge_litert.aot.core import types as litert_types
+from ai_edge_litert.aot.vendors import import_vendor as vendor_lib
+
 
 class Converter:
   """A converter for converting PyTorch models to edge models.
@@ -32,6 +35,7 @@ class Converter:
 
   def __init__(self):
     self._signatures: list[signature_module.Signature] = []
+    self._compilation_configs: list[litert_types.CompilationConfig] = []
 
   def signature(
       self,
@@ -96,6 +100,31 @@ class Converter:
     )
     return self
 
+  def experimental_add_compilation_backend(
+      self,
+      target: litert_types.Target | None = None,
+      **kwargs: litert_types.Config,
+  ) -> Converter:
+    """Adds an AOT compilation target to the converter.
+
+    NOTE: This API is experimental and subject to change.
+
+    Args:
+      target: The target to compile for. If not specified, will compile to all
+        registered AOT targets in ai_edge_litert. See ai_edge_litert.aot.vendors
+        for more details. Adding a same target multiple times will be a no-op.
+      **kwargs: Additional arguments to pass to the backend compiler.
+
+    Returns:
+      The converter object itself.
+    """
+    if target is None:
+      target = vendor_lib.AllRegisteredTarget()
+    if isinstance(target, litert_types.Target):
+      target = litert_types.CompilationConfig(target=target, **kwargs)
+    self._compilation_configs.append(target)
+    return self
+
   def convert(
       self,
       module: torch.nn.Module = None,
@@ -107,7 +136,7 @@
       dynamic_shapes: Optional[Union[dict[str, Any], Tuple[Any, ...]]] = None,
       _ai_edge_converter_flags: Optional[dict[str, Any]] = None,
       _saved_model_dir: Optional[str] = None,
-  ) -> model.TfLiteModel:
+  ) -> model.TfLiteModel | litert_types.CompilationResult:
     """Finalizes the conversion and produces an edge model.
 
     This could be called with no arguments as follows:
@@ -144,7 +173,9 @@
         specified, a random temporary directory would be used.
 
     Returns:
-      The converted edge model.
+      The converted edge model. If compilation configs are provided, returns the
+      compilation result that contains the compiled edge models for different
+      targets.
 
     Raises:
       ValueError: If the arguments are not provided as expected. See the example
@@ -169,13 +200,16 @@
           "sample_args or sample_kwargs must be provided if a module is"
           " specified."
       )
-    return conversion.convert_signatures(
+    converted_model = conversion.convert_signatures(
        self._signatures,
        strict_export=strict_export,
        quant_config=quant_config,
        _tfl_converter_flags=_ai_edge_converter_flags,
        _saved_model_dir=_saved_model_dir,
    )
+    if self._compilation_configs:
+      return conversion.aot_compile(self._compilation_configs, converted_model)
+    return converted_model
 
 
 def signature(
@@ -211,6 +245,26 @@ def signature(
   )
 
 
+def experimental_add_compilation_backend(
+    target: litert_types.Target | None = None,
+    **kwargs: litert_types.Config,
+) -> Converter:
+  """Adds an AOT compilation target to the converter.
+
+  NOTE: This API is experimental and subject to change.
+
+  Args:
+    target: The target to compile for. If not specified, will compile to all
+      registered AOT targets in ai_edge_litert. See ai_edge_litert.aot.vendors
+      for more details. Adding a same target multiple times will be a no-op.
+    **kwargs: Additional arguments to pass to the backend compiler.
+
+  Returns:
+    The converter object itself.
+  """
+  return Converter().experimental_add_compilation_backend(target, **kwargs)
+
+
 def convert(
     module: torch.nn.Module = None,
     sample_args=None,
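Taken together, the conversion.py and converter.py hunks above expose an experimental AOT-compilation path from the top-level API: add one or more backend targets, then call convert(), which now returns a CompilationResult rather than a TfLiteModel whenever compilation configs are present. A minimal usage sketch, modeled on the test added to test_convert.py further down (the Add module and shapes are placeholders, and the fallback backend is just one possible target):

import ai_edge_torch
import torch
from torch import nn
from ai_edge_litert.aot.vendors import fallback_backend

class Add(nn.Module):

  def forward(self, a, b):
    return a + b

# With no compilation backend added, convert() returns a plain TfLiteModel;
# adding a backend switches the return type to a CompilationResult.
result = ai_edge_torch.experimental_add_compilation_backend(
    fallback_backend.FallbackTarget()
).convert(Add().eval(), (torch.randn(5, 10), torch.randn(5, 10)))
for backend, compiled_model in result.models_with_backend:
  print(backend.id())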
ai_edge_torch/_convert/fx_passes/__init__.py
CHANGED
@@ -17,6 +17,7 @@ from typing import Sequence, Union
 
 from ai_edge_torch._convert.fx_passes.build_aten_composite_pass import BuildAtenCompositePass
 from ai_edge_torch._convert.fx_passes.cast_inputs_bf16_to_f32_pass import CastInputsBf16ToF32Pass
+from ai_edge_torch._convert.fx_passes.eliminate_dead_code_pass import EliminateDeadCodePass
 from ai_edge_torch._convert.fx_passes.inject_mlir_debuginfo_pass import InjectMlirDebuginfoPass
 from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import OptimizeLayoutTransposesPass
 from ai_edge_torch._convert.fx_passes.remove_non_user_outputs_pass import RemoveNonUserOutputsPass
ai_edge_torch/_convert/fx_passes/eliminate_dead_code_pass.py
ADDED
@@ -0,0 +1,40 @@
+# Copyright 2025 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Pass to eliminate dead code for ai-edge-torch conversion."""
+
+
+from ai_edge_torch import fx_infra
+import torch
+
+
+class EliminateDeadCodePass(fx_infra.PassBase):
+  """Eliminates dead code with dedicated rules for ai-edge-torch conversion."""
+
+  def call(self, graph_module: torch.fx.GraphModule):
+    def is_impure_node(node: torch.fx.Node):
+      # Starting from torch 2.7.0, random torch ops with
+      # _nondeterministic_seeded set are no longer considered pure. However,
+      # for conversion, unused random ops/tensors should still be removed.
+      if getattr(node.target, "_nondeterministic_seeded", False):
+        return False
+      return node.is_impure()
+
+    try:
+      graph_module.graph.eliminate_dead_code(is_impure_node)
+    except TypeError:
+      # eliminate_dead_code has no is_impure_node input in old torch versions.
+      pass
+
+    return fx_infra.PassResult(graph_module, True)
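This pass exists because, as its inline comment notes, torch 2.7.0 started treating seeded random ops (those with _nondeterministic_seeded set) as impure, so an unused torch.rand result would survive ordinary dead-code elimination and leak into conversion. A toy repro of the underlying torch.fx behavior, as a sketch assuming torch >= 2.7 semantics (not code from the package):

import torch

class Toy(torch.nn.Module):

  def forward(self, x):
    _ = torch.rand(4)  # produced but never used
    return x + 1

ep = torch.export.export(Toy().eval(), (torch.randn(4),))
gm = ep.graph_module  # aten-level graph; aten.rand carries _nondeterministic_seeded

def is_impure_node(node: torch.fx.Node):
  # Same rule as the pass above: treat seeded random ops as pure so DCE
  # is allowed to drop them when their results are unused.
  if getattr(node.target, "_nondeterministic_seeded", False):
    return False
  return node.is_impure()

try:
  gm.graph.eliminate_dead_code(is_impure_node)
except TypeError:
  pass  # older torch: eliminate_dead_code() accepts no callback
gm.recompile()
print(gm.code)  # the unused rand node should be gone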
ai_edge_torch/_convert/test/test_convert.py
CHANGED
@@ -29,6 +29,8 @@ from torch.ao.quantization import quantize_pt2e
 import torchvision
 
 from absl.testing import absltest as googletest
+from ai_edge_litert.aot.core import types as litert_types
+from ai_edge_litert.aot.vendors import fallback_backend
 from ai_edge_litert import interpreter as tfl_interpreter  # pylint: disable=g-direct-tensorflow-import
 
 
@@ -574,6 +576,29 @@ class TestConvert(googletest.TestCase):
       self.fail(f"Conversion failed with bloat16 inputs: {err}")
     # pylint: enable=broad-except
 
+  def test_compile_model(self):
+    """Tests AOT compilation of a simple Add module."""
+
+    class Add(nn.Module):
+
+      def forward(self, a, b):
+        return a + b
+
+    args = (
+        torch.randn((5, 10)),
+        torch.randn((5, 10)),
+    )
+    torch_module = Add().eval()
+    compilation_result = ai_edge_torch.experimental_add_compilation_backend(
+        fallback_backend.FallbackTarget()
+    ).convert(torch_module, args)
+    assert isinstance(compilation_result, litert_types.CompilationResult)
+    self.assertLen(compilation_result.models_with_backend, 1)
+    self.assertEqual(
+        compilation_result.models_with_backend[0][0].id(),
+        fallback_backend.FallbackBackend.id(),
+    )
+
 
 if __name__ == "__main__":
   googletest.main()
ai_edge_torch/generative/examples/amd_llama_135m/amd_llama_135m.py
CHANGED
@@ -15,8 +15,10 @@
 
 """Example of building AMD-Llama-135m."""
 
+from typing import Callable, Dict
 import ai_edge_torch.generative.layers.model_config as cfg
 from ai_edge_torch.generative.utilities import model_builder
+import torch
 from torch import nn
 
 TENSOR_NAMES = model_builder.TENSOR_NAMES_WITH_SEPARATE_LM_HEAD
@@ -49,9 +51,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       activation=cfg.ActivationConfig(cfg.ActivationType.SILU),
       intermediate_size=2048,
   )
-  norm_config = cfg.NormalizationConfig(
-      type=cfg.NormalizationType.RMS_NORM, enable_hlfb=True
-  )
+  norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.RMS_NORM)
   block_config = cfg.TransformerBlockConfig(
       attn_config=attn_config,
       ff_config=ff_config,
@@ -67,7 +67,6 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       block_configs=block_config,
       final_norm_config=norm_config,
       lm_head_share_weight_with_embedding=False,
-      enable_hlfb=True,
   )
   return config
 
@@ -80,10 +79,15 @@ def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
   return config
 
 
-def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
+def build_model(
+    checkpoint_path: str,
+    custom_loader: Callable[[str], Dict[str, torch.Tensor]] | None = None,
+    **kwargs
+) -> nn.Module:
   return model_builder.build_decoder_only_model(
       checkpoint_path=checkpoint_path,
       config=get_model_config(**kwargs),
       tensor_names=TENSOR_NAMES,
-      model_class=AmdLlama
+      model_class=AmdLlama,
+      custom_loader=custom_loader,
   )
ai_edge_torch/generative/examples/amd_llama_135m/convert_to_tflite.py
CHANGED
@@ -19,13 +19,19 @@ from absl import app
 from ai_edge_torch.generative.examples.amd_llama_135m import amd_llama_135m
 from ai_edge_torch.generative.utilities import converter
 from ai_edge_torch.generative.utilities import export_config
+from ai_edge_torch.generative.utilities import loader
 
 flags = converter.define_conversion_flags("amd-llama-135m")
 
 
 def main(_):
+  checkpoint_path = flags.FLAGS.checkpoint_path
   pytorch_model = amd_llama_135m.build_model(
-      flags.FLAGS.checkpoint_path, kv_cache_max_len=flags.FLAGS.kv_cache_max_len
+      checkpoint_path,
+      custom_loader=loader.maybe_get_custom_loader(
+          checkpoint_path, flags.FLAGS.custom_checkpoint_loader
+      ),
+      kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
   )
   converter.convert_to_tflite(
       pytorch_model,
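The same plumbing is repeated in each conversion script below: the checkpoint path and the custom_checkpoint_loader flag are passed through loader.maybe_get_custom_loader(), and the resulting callable (or None) is forwarded to the model builder. For illustration, a hand-rolled loader satisfying the Callable[[str], Dict[str, torch.Tensor]] contract the builders now accept might look like the following sketch (the safetensors format is an assumption for the example, not something this diff prescribes):

from typing import Dict

import torch

def my_custom_loader(path: str) -> Dict[str, torch.Tensor]:
  # Resolve the checkpoint from any source (remote storage, a custom
  # format, ...) and return an ordinary state dict of named tensors.
  from safetensors.torch import load_file  # assumed dependency
  return load_file(path)

pytorch_model = amd_llama_135m.build_model(
    "/path/to/checkpoint", custom_loader=my_custom_loader
)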
ai_edge_torch/generative/examples/deepseek/convert_to_tflite.py
CHANGED
@@ -17,15 +17,20 @@
 
 from absl import app
 from ai_edge_torch.generative.examples.deepseek import deepseek
-from ai_edge_torch.generative.layers import kv_cache
 from ai_edge_torch.generative.utilities import converter
 from ai_edge_torch.generative.utilities import export_config
+from ai_edge_torch.generative.utilities import loader
 
 flags = converter.define_conversion_flags('deepseek')
 
 def main(_):
+  checkpoint_path = flags.FLAGS.checkpoint_path
   pytorch_model = deepseek.build_model(
-      flags.FLAGS.checkpoint_path, kv_cache_max_len=flags.FLAGS.kv_cache_max_len
+      checkpoint_path,
+      custom_loader=loader.maybe_get_custom_loader(
+          checkpoint_path, flags.FLAGS.custom_checkpoint_loader
+      ),
+      kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
   )
   converter.convert_to_tflite(
       pytorch_model,
ai_edge_torch/generative/examples/deepseek/deepseek.py
CHANGED
@@ -15,8 +15,10 @@
 
 """Example of building DeepSeek R1 distilled models."""
 
+from typing import Callable, Dict
 import ai_edge_torch.generative.layers.model_config as cfg
 from ai_edge_torch.generative.utilities import model_builder
+import torch
 from torch import nn
 
 TENSOR_NAMES = model_builder.TENSOR_NAMES_WITH_SEPARATE_LM_HEAD
@@ -51,9 +53,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       intermediate_size=8960,
   )
   norm_config = cfg.NormalizationConfig(
-      type=cfg.NormalizationType.RMS_NORM,
-      epsilon=1e-06,
-      enable_hlfb=True,
+      type=cfg.NormalizationType.RMS_NORM, epsilon=1e-06
   )
   block_config = cfg.TransformerBlockConfig(
       attn_config=attn_config,
@@ -70,7 +70,6 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       block_configs=block_config,
       final_norm_config=norm_config,
       lm_head_share_weight_with_embedding=False,
-      enable_hlfb=True,
   )
   return config
 
@@ -84,10 +83,15 @@ def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
   return config
 
 
-def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
+def build_model(
+    checkpoint_path: str,
+    custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
+    **kwargs
+) -> nn.Module:
   return model_builder.build_decoder_only_model(
       checkpoint_path=checkpoint_path,
       config=get_model_config(**kwargs),
       tensor_names=TENSOR_NAMES,
       model_class=DeepSeekDistillQwen,
+      custom_loader=custom_loader,
   )
ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py
CHANGED
@@ -19,13 +19,19 @@ from absl import app
 from ai_edge_torch.generative.examples.gemma import gemma1
 from ai_edge_torch.generative.utilities import converter
 from ai_edge_torch.generative.utilities import export_config
+from ai_edge_torch.generative.utilities import loader
 
 flags = converter.define_conversion_flags("gemma-2b")
 
 
 def main(_):
+  checkpoint_path = flags.FLAGS.checkpoint_path
   pytorch_model = gemma1.build_2b_model(
-      flags.FLAGS.checkpoint_path, kv_cache_max_len=flags.FLAGS.kv_cache_max_len
+      checkpoint_path,
+      custom_loader=loader.maybe_get_custom_loader(
+          checkpoint_path, flags.FLAGS.custom_checkpoint_loader
+      ),
+      kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
   )
   converter.convert_to_tflite(
       pytorch_model,
ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py
CHANGED
@@ -19,6 +19,7 @@ from absl import app
 from ai_edge_torch.generative.examples.gemma import gemma2
 from ai_edge_torch.generative.utilities import converter
 from ai_edge_torch.generative.utilities import export_config
+from ai_edge_torch.generative.utilities import loader
 
 flags = converter.define_conversion_flags(
     "gemma2-2b", default_mask_as_input=True, default_transpose_kv_cache=True
@@ -26,8 +27,13 @@ flags = converter.define_conversion_flags(
 
 
 def main(_):
+  checkpoint_path = flags.FLAGS.checkpoint_path
   pytorch_model = gemma2.build_2b_model(
-      flags.FLAGS.checkpoint_path, kv_cache_max_len=flags.FLAGS.kv_cache_max_len
+      checkpoint_path,
+      custom_loader=loader.maybe_get_custom_loader(
+          checkpoint_path, flags.FLAGS.custom_checkpoint_loader
+      ),
+      kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
   )
   converter.convert_to_tflite(
       pytorch_model,
ai_edge_torch/generative/examples/gemma/gemma1.py
CHANGED
@@ -15,9 +15,12 @@
 
 """Example of building a Gemma1 model."""
 
+from typing import Callable, Dict
+
 import ai_edge_torch.generative.layers.model_config as cfg
 from ai_edge_torch.generative.utilities import model_builder
 import ai_edge_torch.generative.utilities.loader as loading_utils
+import torch
 from torch import nn
 
 TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
@@ -62,10 +65,7 @@ def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       intermediate_size=16384,
   )
   norm_config = cfg.NormalizationConfig(
-      type=cfg.NormalizationType.RMS_NORM,
-      epsilon=1e-6,
-      zero_centered=True,
-      enable_hlfb=True,
+      type=cfg.NormalizationType.RMS_NORM, epsilon=1e-6, zero_centered=True
   )
   block_config = cfg.TransformerBlockConfig(
       attn_config=attn_config,
@@ -84,7 +84,6 @@ def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       block_configs=block_config,
       final_norm_config=norm_config,
       lm_head_use_bias=False,
-      enable_hlfb=True,
   )
   return config
 
@@ -99,10 +98,15 @@ def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
   return config
 
 
-def build_2b_model(checkpoint_path: str, **kwargs) -> nn.Module:
+def build_2b_model(
+    checkpoint_path: str,
+    custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
+    **kwargs
+) -> nn.Module:
   return model_builder.build_decoder_only_model(
       checkpoint_path=checkpoint_path,
       config=get_model_config_2b(**kwargs),
       tensor_names=TENSOR_NAMES,
       model_class=Gemma1,
+      custom_loader=custom_loader,
   )
ai_edge_torch/generative/examples/gemma/gemma2.py
CHANGED
@@ -15,7 +15,7 @@
 
 """Example of building a Gemma2 model."""
 
-from typing import List, Optional, Tuple
+from typing import Callable, Dict, List, Optional, Tuple
 
 from ai_edge_torch.generative.layers import attention
 from ai_edge_torch.generative.layers import builder
@@ -233,10 +233,7 @@ def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
     The model config for a Gemma 2B model.
   """
   norm_config = cfg.NormalizationConfig(
-      type=cfg.NormalizationType.RMS_NORM,
-      epsilon=1e-6,
-      zero_centered=True,
-      enable_hlfb=True,
+      type=cfg.NormalizationType.RMS_NORM, epsilon=1e-6, zero_centered=True
   )
   ff_config = cfg.FeedForwardConfig(
       type=cfg.FeedForwardType.GATED,
@@ -284,7 +281,6 @@ def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       block_configs=[get_block_config(i) for i in range(num_layers)],
       final_norm_config=norm_config,
       lm_head_use_bias=False,
-      enable_hlfb=True,
       final_logit_softcap=30.0,
   )
   return config
@@ -306,7 +302,11 @@ def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
   return config
 
 
-def build_2b_model(
+def build_2b_model(
+    checkpoint_path: str,
+    custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
+    **kwargs,
+) -> nn.Module:
   for tensor_names in TENSOR_NAMES_DICT.values():
     try:
       return model_builder.build_decoder_only_model(
@@ -314,6 +314,7 @@ def build_2b_model(checkpoint_path: str, **kwargs) -> nn.Module:
          config=get_model_config_2b(**kwargs),
          tensor_names=tensor_names,
          model_class=Gemma2,
+          custom_loader=custom_loader,
      )
    except KeyError as _:
      continue
ai_edge_torch/generative/examples/gemma3/convert_gemma3_to_tflite.py
CHANGED
@@ -25,13 +25,6 @@ flags = converter.define_conversion_flags(
     'gemma3-1b', default_mask_as_input=True, default_transpose_kv_cache=True
 )
 
-_CUSTOM_CHECKPOINT_LOADER = flags.DEFINE_bool(
-    'custom_checkpoint_loader',
-    False,
-    'If true, the conversion script will use a custom checkpoint loader which'
-    ' will read a checkpoint from a remote source.',
-)
-
 _MODEL_SIZE = flags.DEFINE_string(
     'model_size',
     '1b',
@@ -40,16 +33,14 @@ _MODEL_SIZE = flags.DEFINE_string(
 
 
 def main(_):
-
-  if flags.FLAGS.custom_checkpoint_loader:
-    # If loading from a remote source, try to get a custom loader first.
-    custom_loader = loader.get_custom_loader(flags.FLAGS.checkpoint_path)
-
+  checkpoint_path = flags.FLAGS.checkpoint_path
   if _MODEL_SIZE.value == '1b':
     pytorch_model = gemma3.build_model_1b(
-        flags.FLAGS.checkpoint_path,
+        checkpoint_path,
+        custom_loader=loader.maybe_get_custom_loader(
+            checkpoint_path, flags.FLAGS.custom_checkpoint_loader
+        ),
         kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
-        custom_loader=custom_loader,
     )
   else:
     raise ValueError(f'Unsupported model size: {_MODEL_SIZE.value}')
ai_edge_torch/generative/examples/gemma3/decoder.py
CHANGED
@@ -149,8 +149,12 @@ class Decoder(nn.Module):
           cache_len=attention_mask.shape[-1],
           sliding_window_size=sliding_window_size,
       )
-      #
-      combined_mask = torch.min(attention_mask, sliding_mask)
+      # Expand sliding_mask to match attention_mask's dimensions
+      # (e.g., [B, 1, seq_len, cache_len]).
+      # Assuming the head dimension is dim 1 for attention_mask.
+      expanded_sliding_mask = sliding_mask.unsqueeze(1)
+      # Combine masks using logical AND (min ensures -inf propagates).
+      combined_mask = torch.min(attention_mask, expanded_sliding_mask)
       return combined_mask
     return attention_mask
 
@@ -161,9 +165,9 @@
       sliding_window_size: int,
   ) -> torch.Tensor:
     """Creates mask for sliding window attention (PyTorch)."""
-    cache_positions = torch.tensor(
-        [i for i in range(cache_len)], dtype=torch.int32
-    )
+    # Use torch.arange to create a tensor with a range of integers in a
+    # Dynamo-friendly way.
+    cache_positions = torch.arange(cache_len, dtype=torch.int32)
     cache_positions = cache_positions.view(1, 1, -1)  # [1, 1, cache_len]
     segment_pos_expanded = segment_pos.clone().unsqueeze(-1)  # [B, seq_len, 1]
 
@@ -329,10 +333,7 @@ def get_decoder_config_1b(kv_cache_max_len: int = 2048) -> cfg.ModelConfig:
     The model config for a Gemma 1B model.
   """
   norm_config = cfg.NormalizationConfig(
-      type=cfg.NormalizationType.RMS_NORM,
-      epsilon=1e-6,
-      zero_centered=True,
-      enable_hlfb=True,
+      type=cfg.NormalizationType.RMS_NORM, epsilon=1e-6, zero_centered=True,
   )
   ff_config = cfg.FeedForwardConfig(
       type=cfg.FeedForwardType.GATED,
@@ -379,7 +380,6 @@ def get_decoder_config_1b(kv_cache_max_len: int = 2048) -> cfg.ModelConfig:
      block_configs=[get_block_config(i) for i in range(num_layers)],
      final_norm_config=norm_config,
      lm_head_use_bias=False,
-      enable_hlfb=True,
      final_logit_softcap=None,
  )
  return config
ai_edge_torch/generative/examples/gemma3/gemma3.py
CHANGED
@@ -158,9 +158,7 @@ def get_fake_model_config(**kwargs) -> Gemma3MMConfig:
       image_projection_scale=128**0.5,
       image_projection_use_bias=False,
       mm_norm_config=cfg.NormalizationConfig(
-          type=cfg.NormalizationType.LAYER_NORM,
-          epsilon=1e-6,
-          enable_hlfb=True,
+          type=cfg.NormalizationType.LAYER_NORM, epsilon=1e-6
       ),
       mm_extra_tokens=32,
   )
ai_edge_torch/generative/examples/gemma3/image_encoder.py
CHANGED
@@ -98,9 +98,7 @@ def get_image_encoder_config() -> cfg.ModelConfig:
       output_proj_use_bias=True,
   )
   norm_config = cfg.NormalizationConfig(
-      type=cfg.NormalizationType.LAYER_NORM,
-      epsilon=1e-6,
-      enable_hlfb=True,
+      type=cfg.NormalizationType.LAYER_NORM, epsilon=1e-6
   )
   ff_config = cfg.FeedForwardConfig(
       type=cfg.FeedForwardType.SEQUENTIAL,
@@ -123,7 +121,6 @@ def get_image_encoder_config() -> cfg.ModelConfig:
       image_embedding=image_embedding_config,
       block_configs=block_config,
       final_norm_config=norm_config,
-      enable_hlfb=True,
       num_mm_tokens_per_image=256,
   )
   return config
ai_edge_torch/generative/examples/hammer/convert_to_tflite.py
CHANGED
@@ -19,6 +19,7 @@ from absl import app
 from ai_edge_torch.generative.examples.hammer import hammer
 from ai_edge_torch.generative.utilities import converter
 from ai_edge_torch.generative.utilities import export_config
+from ai_edge_torch.generative.utilities import loader
 
 flags = converter.define_conversion_flags('hammer')
 
@@ -36,8 +37,13 @@ _BUILDER = {
 
 
 def main(_):
+  checkpoint_path = flags.FLAGS.checkpoint_path
   pytorch_model = _BUILDER[_MODEL_SIZE.value](
-      flags.FLAGS.checkpoint_path, kv_cache_max_len=flags.FLAGS.kv_cache_max_len
+      checkpoint_path,
+      custom_loader=loader.maybe_get_custom_loader(
+          checkpoint_path, flags.FLAGS.custom_checkpoint_loader
+      ),
+      kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
   )
   converter.convert_to_tflite(
       pytorch_model,