ai-edge-torch-nightly 0.7.0.dev20250929__py3-none-any.whl → 0.8.0.dev20251206__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ai-edge-torch-nightly might be problematic; see the package's registry page for more details.
- ai_edge_torch/_convert/conversion.py +2 -1
- ai_edge_torch/fx_infra/_safe_run_decompositions.py +36 -1
- ai_edge_torch/generative/examples/amd_llama_135m/convert_to_tflite.py +1 -20
- ai_edge_torch/generative/examples/deepseek/convert_to_tflite.py +1 -20
- ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py +1 -20
- ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +1 -20
- ai_edge_torch/generative/examples/gemma3/convert_gemma3_to_tflite.py +3 -27
- ai_edge_torch/generative/examples/hammer/convert_to_tflite.py +1 -20
- ai_edge_torch/generative/examples/llama/convert_to_tflite.py +1 -20
- ai_edge_torch/generative/examples/openelm/convert_to_tflite.py +1 -20
- ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py +1 -20
- ai_edge_torch/generative/examples/phi/convert_phi4_to_tflite.py +1 -20
- ai_edge_torch/generative/examples/phi/convert_to_tflite.py +1 -20
- ai_edge_torch/generative/examples/qwen/convert_to_tflite.py +1 -20
- ai_edge_torch/generative/examples/qwen/convert_v3_to_tflite.py +1 -20
- ai_edge_torch/generative/examples/smollm/convert_to_tflite.py +1 -30
- ai_edge_torch/generative/examples/smollm/convert_v2_to_tflite.py +1 -30
- ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +1 -3
- ai_edge_torch/generative/examples/stable_diffusion/samplers/k_lms.py +1 -1
- ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +1 -20
- ai_edge_torch/generative/layers/attention.py +25 -2
- ai_edge_torch/generative/layers/attention_test.py +13 -1
- ai_edge_torch/generative/layers/attention_utils.py +62 -1
- ai_edge_torch/generative/layers/attention_utils_test.py +20 -0
- ai_edge_torch/generative/layers/builder.py +4 -2
- ai_edge_torch/generative/layers/model_config.py +5 -0
- ai_edge_torch/generative/layers/normalization.py +8 -2
- ai_edge_torch/generative/layers/scaled_dot_product_attention.py +35 -5
- ai_edge_torch/generative/layers/sdpa_with_kv_update.py +8 -3
- ai_edge_torch/generative/quantize/example.py +1 -1
- ai_edge_torch/generative/quantize/quant_attrs.py +8 -1
- ai_edge_torch/generative/quantize/quant_recipe.py +0 -13
- ai_edge_torch/generative/quantize/quant_recipe_utils.py +12 -19
- ai_edge_torch/generative/quantize/quant_recipes.py +16 -21
- ai_edge_torch/generative/quantize/supported_schemes.py +4 -1
- ai_edge_torch/generative/test/test_kv_cache.py +18 -6
- ai_edge_torch/generative/test/test_quantize.py +17 -26
- ai_edge_torch/generative/utilities/converter.py +183 -28
- ai_edge_torch/generative/utilities/export_config.py +2 -0
- ai_edge_torch/generative/utilities/litertlm_builder.py +61 -8
- ai_edge_torch/generative/utilities/loader.py +2 -1
- ai_edge_torch/lowertools/translate_recipe.py +8 -3
- ai_edge_torch/odml_torch/experimental/__init__.py +14 -0
- ai_edge_torch/odml_torch/experimental/torch_tfl/__init__.py +20 -0
- ai_edge_torch/odml_torch/experimental/torch_tfl/_decomps.py +438 -0
- ai_edge_torch/odml_torch/experimental/torch_tfl/_lowerings.py +728 -0
- ai_edge_torch/odml_torch/experimental/torch_tfl/_ops.py +371 -0
- ai_edge_torch/odml_torch/experimental/torch_tfl/torch_library_utils.py +37 -0
- ai_edge_torch/odml_torch/export.py +24 -7
- ai_edge_torch/odml_torch/lowerings/_basic.py +155 -0
- ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +255 -5
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.7.0.dev20250929.dist-info → ai_edge_torch_nightly-0.8.0.dev20251206.dist-info}/METADATA +15 -3
- {ai_edge_torch_nightly-0.7.0.dev20250929.dist-info → ai_edge_torch_nightly-0.8.0.dev20251206.dist-info}/RECORD +57 -51
- {ai_edge_torch_nightly-0.7.0.dev20250929.dist-info → ai_edge_torch_nightly-0.8.0.dev20251206.dist-info}/WHEEL +1 -1
- {ai_edge_torch_nightly-0.7.0.dev20250929.dist-info → ai_edge_torch_nightly-0.8.0.dev20251206.dist-info/licenses}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.7.0.dev20250929.dist-info → ai_edge_torch_nightly-0.8.0.dev20251206.dist-info}/top_level.txt +0 -0
|
@@ -133,10 +133,11 @@ def convert_signatures(
|
|
|
133
133
|
exported_program = fx_infra.safe_run_decompositions(
|
|
134
134
|
exported_program,
|
|
135
135
|
fx_infra.decomp.pre_convert_decomp(),
|
|
136
|
+
can_skip=False,
|
|
136
137
|
)
|
|
137
138
|
return exported_program
|
|
138
139
|
|
|
139
|
-
exported_programs
|
|
140
|
+
exported_programs = [
|
|
140
141
|
export(
|
|
141
142
|
mod=sig.module,
|
|
142
143
|
args=sig.args,
|
|
@@ -13,6 +13,7 @@
|
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ==============================================================================
|
|
15
15
|
"""ExportedProgram.run_decompositions wrapper to handle unexpected export behavior."""
|
|
16
|
+
import operator
|
|
16
17
|
import torch
|
|
17
18
|
|
|
18
19
|
|
|
@@ -26,8 +27,39 @@ _DUMMY_DECOMP_TABLE = {
|
|
|
26
27
|
torch._ops.OperatorBase(): lambda: None,
|
|
27
28
|
}
|
|
28
29
|
|
|
30
|
+
_BUILTIN_OPERATORS = {
|
|
31
|
+
getattr(operator, name)
|
|
32
|
+
for name in dir(operator)
|
|
33
|
+
if not name.startswith("_")
|
|
34
|
+
}
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def _require_decomp(
|
|
38
|
+
exported_program: torch.export.ExportedProgram, decomp_table
|
|
39
|
+
):
|
|
40
|
+
"""Checks if the exported program requires decompositions."""
|
|
41
|
+
for node in exported_program.graph.nodes:
|
|
42
|
+
if "call_" not in str(node.op):
|
|
43
|
+
continue
|
|
44
|
+
|
|
45
|
+
op = node.target
|
|
46
|
+
if isinstance(op, torch._ops.OpOverloadPacket):
|
|
47
|
+
op = op.default
|
|
29
48
|
|
|
30
|
-
|
|
49
|
+
if op in decomp_table:
|
|
50
|
+
return True
|
|
51
|
+
|
|
52
|
+
if (
|
|
53
|
+
not isinstance(op, (torch._ops.OpOverload, torch._ops.OperatorBase))
|
|
54
|
+
and op not in _BUILTIN_OPERATORS
|
|
55
|
+
):
|
|
56
|
+
# Python function that requires to be retraced via run_decompositions.
|
|
57
|
+
return True
|
|
58
|
+
|
|
59
|
+
return False
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
def safe_run_decompositions(exported_program, decomp_table=None, can_skip=True):
|
|
31
63
|
"""Wrapper for ExportedProgram.run_decompositions to handle unexpected export behavior."""
|
|
32
64
|
|
|
33
65
|
if decomp_table is not None and not decomp_table:
|
|
@@ -35,6 +67,9 @@ def safe_run_decompositions(exported_program, decomp_table=None):
|
|
|
35
67
|
# instead for backward compatibility.
|
|
36
68
|
decomp_table = _DUMMY_DECOMP_TABLE
|
|
37
69
|
|
|
70
|
+
if can_skip and not _require_decomp(exported_program, decomp_table):
|
|
71
|
+
return exported_program
|
|
72
|
+
|
|
38
73
|
for node in exported_program.graph.nodes:
|
|
39
74
|
if node.target == torch.ops.aten.view.default:
|
|
40
75
|
# Passes or torch.export may generate aten.view nodes not respecting the
|
|
@@ -18,31 +18,12 @@
|
|
|
18
18
|
from absl import app
|
|
19
19
|
from ai_edge_torch.generative.examples.amd_llama_135m import amd_llama_135m
|
|
20
20
|
from ai_edge_torch.generative.utilities import converter
|
|
21
|
-
from ai_edge_torch.generative.utilities import export_config
|
|
22
|
-
from ai_edge_torch.generative.utilities import loader
|
|
23
21
|
|
|
24
22
|
flags = converter.define_conversion_flags("amd-llama-135m")
|
|
25
23
|
|
|
26
24
|
|
|
27
25
|
def main(_):
|
|
28
|
-
|
|
29
|
-
pytorch_model = amd_llama_135m.build_model(
|
|
30
|
-
checkpoint_path,
|
|
31
|
-
custom_loader=loader.maybe_get_custom_loader(
|
|
32
|
-
checkpoint_path, flags.FLAGS.custom_checkpoint_loader
|
|
33
|
-
),
|
|
34
|
-
mask_cache_size=converter.get_mask_cache_size_from_flags(),
|
|
35
|
-
)
|
|
36
|
-
converter.convert_to_tflite(
|
|
37
|
-
pytorch_model,
|
|
38
|
-
output_path=flags.FLAGS.output_path,
|
|
39
|
-
output_name_prefix=flags.FLAGS.output_name_prefix,
|
|
40
|
-
prefill_seq_len=flags.FLAGS.prefill_seq_lens,
|
|
41
|
-
kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
|
|
42
|
-
quantize=flags.FLAGS.quantize,
|
|
43
|
-
lora_ranks=flags.FLAGS.lora_ranks,
|
|
44
|
-
export_config=export_config.get_from_flags(),
|
|
45
|
-
)
|
|
26
|
+
converter.build_and_convert_to_tflite_from_flags(amd_llama_135m.build_model)
|
|
46
27
|
|
|
47
28
|
|
|
48
29
|
if __name__ == '__main__':
|
|
@@ -18,8 +18,6 @@
|
|
|
18
18
|
from absl import app
|
|
19
19
|
from ai_edge_torch.generative.examples.deepseek import deepseek
|
|
20
20
|
from ai_edge_torch.generative.utilities import converter
|
|
21
|
-
from ai_edge_torch.generative.utilities import export_config
|
|
22
|
-
from ai_edge_torch.generative.utilities import loader
|
|
23
21
|
|
|
24
22
|
flags = converter.define_conversion_flags(
|
|
25
23
|
'deepseek', default_mask_as_input=True, default_transpose_kv_cache=True
|
|
@@ -27,24 +25,7 @@ flags = converter.define_conversion_flags(
|
|
|
27
25
|
|
|
28
26
|
|
|
29
27
|
def main(_):
|
|
30
|
-
|
|
31
|
-
pytorch_model = deepseek.build_model(
|
|
32
|
-
checkpoint_path,
|
|
33
|
-
custom_loader=loader.maybe_get_custom_loader(
|
|
34
|
-
checkpoint_path, flags.FLAGS.custom_checkpoint_loader
|
|
35
|
-
),
|
|
36
|
-
mask_cache_size=converter.get_mask_cache_size_from_flags(),
|
|
37
|
-
)
|
|
38
|
-
converter.convert_to_tflite(
|
|
39
|
-
pytorch_model,
|
|
40
|
-
output_path=flags.FLAGS.output_path,
|
|
41
|
-
output_name_prefix=flags.FLAGS.output_name_prefix,
|
|
42
|
-
prefill_seq_len=flags.FLAGS.prefill_seq_lens,
|
|
43
|
-
kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
|
|
44
|
-
quantize=flags.FLAGS.quantize,
|
|
45
|
-
lora_ranks=flags.FLAGS.lora_ranks,
|
|
46
|
-
export_config=export_config.get_from_flags(),
|
|
47
|
-
)
|
|
28
|
+
converter.build_and_convert_to_tflite_from_flags(deepseek.build_model)
|
|
48
29
|
|
|
49
30
|
|
|
50
31
|
if __name__ == '__main__':
|
|
@@ -18,31 +18,12 @@
|
|
|
18
18
|
from absl import app
|
|
19
19
|
from ai_edge_torch.generative.examples.gemma import gemma1
|
|
20
20
|
from ai_edge_torch.generative.utilities import converter
|
|
21
|
-
from ai_edge_torch.generative.utilities import export_config
|
|
22
|
-
from ai_edge_torch.generative.utilities import loader
|
|
23
21
|
|
|
24
22
|
flags = converter.define_conversion_flags("gemma-2b")
|
|
25
23
|
|
|
26
24
|
|
|
27
25
|
def main(_):
|
|
28
|
-
|
|
29
|
-
pytorch_model = gemma1.build_2b_model(
|
|
30
|
-
checkpoint_path,
|
|
31
|
-
custom_loader=loader.maybe_get_custom_loader(
|
|
32
|
-
checkpoint_path, flags.FLAGS.custom_checkpoint_loader
|
|
33
|
-
),
|
|
34
|
-
mask_cache_size=converter.get_mask_cache_size_from_flags(),
|
|
35
|
-
)
|
|
36
|
-
converter.convert_to_tflite(
|
|
37
|
-
pytorch_model,
|
|
38
|
-
output_path=flags.FLAGS.output_path,
|
|
39
|
-
output_name_prefix=flags.FLAGS.output_name_prefix,
|
|
40
|
-
prefill_seq_len=flags.FLAGS.prefill_seq_lens,
|
|
41
|
-
kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
|
|
42
|
-
quantize=flags.FLAGS.quantize,
|
|
43
|
-
lora_ranks=flags.FLAGS.lora_ranks,
|
|
44
|
-
export_config=export_config.get_from_flags(),
|
|
45
|
-
)
|
|
26
|
+
converter.build_and_convert_to_tflite_from_flags(gemma1.build_2b_model)
|
|
46
27
|
|
|
47
28
|
|
|
48
29
|
if __name__ == '__main__':
|
|
@@ -18,8 +18,6 @@
|
|
|
18
18
|
from absl import app
|
|
19
19
|
from ai_edge_torch.generative.examples.gemma import gemma2
|
|
20
20
|
from ai_edge_torch.generative.utilities import converter
|
|
21
|
-
from ai_edge_torch.generative.utilities import export_config
|
|
22
|
-
from ai_edge_torch.generative.utilities import loader
|
|
23
21
|
|
|
24
22
|
flags = converter.define_conversion_flags(
|
|
25
23
|
"gemma2-2b", default_mask_as_input=True, default_transpose_kv_cache=True
|
|
@@ -27,24 +25,7 @@ flags = converter.define_conversion_flags(
|
|
|
27
25
|
|
|
28
26
|
|
|
29
27
|
def main(_):
|
|
30
|
-
|
|
31
|
-
pytorch_model = gemma2.build_2b_model(
|
|
32
|
-
checkpoint_path,
|
|
33
|
-
custom_loader=loader.maybe_get_custom_loader(
|
|
34
|
-
checkpoint_path, flags.FLAGS.custom_checkpoint_loader
|
|
35
|
-
),
|
|
36
|
-
mask_cache_size=converter.get_mask_cache_size_from_flags(),
|
|
37
|
-
)
|
|
38
|
-
converter.convert_to_tflite(
|
|
39
|
-
pytorch_model,
|
|
40
|
-
output_path=flags.FLAGS.output_path,
|
|
41
|
-
output_name_prefix=flags.FLAGS.output_name_prefix,
|
|
42
|
-
prefill_seq_len=flags.FLAGS.prefill_seq_lens,
|
|
43
|
-
kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
|
|
44
|
-
quantize=flags.FLAGS.quantize,
|
|
45
|
-
lora_ranks=flags.FLAGS.lora_ranks,
|
|
46
|
-
export_config=export_config.get_from_flags(),
|
|
47
|
-
)
|
|
28
|
+
converter.build_and_convert_to_tflite_from_flags(gemma2.build_2b_model)
|
|
48
29
|
|
|
49
30
|
|
|
50
31
|
if __name__ == '__main__':
|
|
@@ -18,8 +18,6 @@
|
|
|
18
18
|
from absl import app
|
|
19
19
|
from ai_edge_torch.generative.examples.gemma3 import gemma3
|
|
20
20
|
from ai_edge_torch.generative.utilities import converter
|
|
21
|
-
from ai_edge_torch.generative.utilities import export_config
|
|
22
|
-
from ai_edge_torch.generative.utilities import loader
|
|
23
21
|
|
|
24
22
|
flags = converter.define_conversion_flags(
|
|
25
23
|
'gemma3-1b', default_mask_as_input=True, default_transpose_kv_cache=True
|
|
@@ -33,36 +31,14 @@ _MODEL_SIZE = flags.DEFINE_string(
|
|
|
33
31
|
|
|
34
32
|
|
|
35
33
|
def main(_):
|
|
36
|
-
checkpoint_path = flags.FLAGS.checkpoint_path
|
|
37
34
|
if _MODEL_SIZE.value == '1b':
|
|
38
|
-
|
|
39
|
-
checkpoint_path,
|
|
40
|
-
custom_loader=loader.maybe_get_custom_loader(
|
|
41
|
-
checkpoint_path, flags.FLAGS.custom_checkpoint_loader
|
|
42
|
-
),
|
|
43
|
-
mask_cache_size=converter.get_mask_cache_size_from_flags(),
|
|
44
|
-
)
|
|
35
|
+
model_builder = gemma3.build_model_1b
|
|
45
36
|
elif _MODEL_SIZE.value == '270m':
|
|
46
|
-
|
|
47
|
-
checkpoint_path,
|
|
48
|
-
custom_loader=loader.maybe_get_custom_loader(
|
|
49
|
-
checkpoint_path, flags.FLAGS.custom_checkpoint_loader
|
|
50
|
-
),
|
|
51
|
-
mask_cache_size=converter.get_mask_cache_size_from_flags(),
|
|
52
|
-
)
|
|
37
|
+
model_builder = gemma3.build_model_270m
|
|
53
38
|
else:
|
|
54
39
|
raise ValueError(f'Unsupported model size: {_MODEL_SIZE.value}')
|
|
55
40
|
|
|
56
|
-
converter.
|
|
57
|
-
pytorch_model,
|
|
58
|
-
output_path=flags.FLAGS.output_path,
|
|
59
|
-
output_name_prefix=flags.FLAGS.output_name_prefix,
|
|
60
|
-
prefill_seq_len=flags.FLAGS.prefill_seq_lens,
|
|
61
|
-
kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
|
|
62
|
-
quantize=flags.FLAGS.quantize,
|
|
63
|
-
lora_ranks=flags.FLAGS.lora_ranks,
|
|
64
|
-
export_config=export_config.get_from_flags(),
|
|
65
|
-
)
|
|
41
|
+
converter.build_and_convert_to_tflite_from_flags(model_builder)
|
|
66
42
|
|
|
67
43
|
|
|
68
44
|
if __name__ == '__main__':
|
|
@@ -18,8 +18,6 @@
|
|
|
18
18
|
from absl import app
|
|
19
19
|
from ai_edge_torch.generative.examples.hammer import hammer
|
|
20
20
|
from ai_edge_torch.generative.utilities import converter
|
|
21
|
-
from ai_edge_torch.generative.utilities import export_config
|
|
22
|
-
from ai_edge_torch.generative.utilities import loader
|
|
23
21
|
|
|
24
22
|
flags = converter.define_conversion_flags('hammer')
|
|
25
23
|
|
|
@@ -37,24 +35,7 @@ _BUILDER = {
|
|
|
37
35
|
|
|
38
36
|
|
|
39
37
|
def main(_):
|
|
40
|
-
|
|
41
|
-
pytorch_model = _BUILDER[_MODEL_SIZE.value](
|
|
42
|
-
checkpoint_path,
|
|
43
|
-
custom_loader=loader.maybe_get_custom_loader(
|
|
44
|
-
checkpoint_path, flags.FLAGS.custom_checkpoint_loader
|
|
45
|
-
),
|
|
46
|
-
mask_cache_size=converter.get_mask_cache_size_from_flags(),
|
|
47
|
-
)
|
|
48
|
-
converter.convert_to_tflite(
|
|
49
|
-
pytorch_model,
|
|
50
|
-
output_path=flags.FLAGS.output_path,
|
|
51
|
-
output_name_prefix=flags.FLAGS.output_name_prefix,
|
|
52
|
-
prefill_seq_len=flags.FLAGS.prefill_seq_lens,
|
|
53
|
-
kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
|
|
54
|
-
quantize=flags.FLAGS.quantize,
|
|
55
|
-
lora_ranks=flags.FLAGS.lora_ranks,
|
|
56
|
-
export_config=export_config.get_from_flags(),
|
|
57
|
-
)
|
|
38
|
+
converter.build_and_convert_to_tflite_from_flags(_BUILDER[_MODEL_SIZE.value])
|
|
58
39
|
|
|
59
40
|
|
|
60
41
|
if __name__ == '__main__':
|
|
@@ -18,8 +18,6 @@
|
|
|
18
18
|
from absl import app
|
|
19
19
|
from ai_edge_torch.generative.examples.llama import llama
|
|
20
20
|
from ai_edge_torch.generative.utilities import converter
|
|
21
|
-
from ai_edge_torch.generative.utilities import export_config
|
|
22
|
-
from ai_edge_torch.generative.utilities import loader
|
|
23
21
|
|
|
24
22
|
flags = converter.define_conversion_flags('llama')
|
|
25
23
|
|
|
@@ -37,24 +35,7 @@ _BUILDER = {
|
|
|
37
35
|
|
|
38
36
|
|
|
39
37
|
def main(_):
|
|
40
|
-
|
|
41
|
-
pytorch_model = _BUILDER[_MODEL_SIZE.value](
|
|
42
|
-
checkpoint_path,
|
|
43
|
-
custom_loader=loader.maybe_get_custom_loader(
|
|
44
|
-
checkpoint_path, flags.FLAGS.custom_checkpoint_loader
|
|
45
|
-
),
|
|
46
|
-
mask_cache_size=converter.get_mask_cache_size_from_flags(),
|
|
47
|
-
)
|
|
48
|
-
converter.convert_to_tflite(
|
|
49
|
-
pytorch_model,
|
|
50
|
-
output_path=flags.FLAGS.output_path,
|
|
51
|
-
output_name_prefix=flags.FLAGS.output_name_prefix,
|
|
52
|
-
prefill_seq_len=flags.FLAGS.prefill_seq_lens,
|
|
53
|
-
kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
|
|
54
|
-
quantize=flags.FLAGS.quantize,
|
|
55
|
-
lora_ranks=flags.FLAGS.lora_ranks,
|
|
56
|
-
export_config=export_config.get_from_flags(),
|
|
57
|
-
)
|
|
38
|
+
converter.build_and_convert_to_tflite_from_flags(_BUILDER[_MODEL_SIZE.value])
|
|
58
39
|
|
|
59
40
|
|
|
60
41
|
if __name__ == '__main__':
|
|
@@ -18,31 +18,12 @@
|
|
|
18
18
|
from absl import app
|
|
19
19
|
from ai_edge_torch.generative.examples.openelm import openelm
|
|
20
20
|
from ai_edge_torch.generative.utilities import converter
|
|
21
|
-
from ai_edge_torch.generative.utilities import export_config
|
|
22
|
-
from ai_edge_torch.generative.utilities import loader
|
|
23
21
|
|
|
24
22
|
flags = converter.define_conversion_flags("openelm")
|
|
25
23
|
|
|
26
24
|
|
|
27
25
|
def main(_):
|
|
28
|
-
|
|
29
|
-
pytorch_model = openelm.build_model(
|
|
30
|
-
checkpoint_path,
|
|
31
|
-
custom_loader=loader.maybe_get_custom_loader(
|
|
32
|
-
checkpoint_path, flags.FLAGS.custom_checkpoint_loader
|
|
33
|
-
),
|
|
34
|
-
mask_cache_size=converter.get_mask_cache_size_from_flags(),
|
|
35
|
-
)
|
|
36
|
-
converter.convert_to_tflite(
|
|
37
|
-
pytorch_model,
|
|
38
|
-
output_path=flags.FLAGS.output_path,
|
|
39
|
-
output_name_prefix=flags.FLAGS.output_name_prefix,
|
|
40
|
-
prefill_seq_len=flags.FLAGS.prefill_seq_lens,
|
|
41
|
-
kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
|
|
42
|
-
quantize=flags.FLAGS.quantize,
|
|
43
|
-
lora_ranks=flags.FLAGS.lora_ranks,
|
|
44
|
-
export_config=export_config.get_from_flags(),
|
|
45
|
-
)
|
|
26
|
+
converter.build_and_convert_to_tflite_from_flags(openelm.build_model)
|
|
46
27
|
|
|
47
28
|
|
|
48
29
|
if __name__ == '__main__':
|
|
@@ -18,31 +18,12 @@
|
|
|
18
18
|
from absl import app
|
|
19
19
|
from ai_edge_torch.generative.examples.phi import phi3
|
|
20
20
|
from ai_edge_torch.generative.utilities import converter
|
|
21
|
-
from ai_edge_torch.generative.utilities import export_config
|
|
22
|
-
from ai_edge_torch.generative.utilities import loader
|
|
23
21
|
|
|
24
22
|
flags = converter.define_conversion_flags("phi3")
|
|
25
23
|
|
|
26
24
|
|
|
27
25
|
def main(_):
|
|
28
|
-
|
|
29
|
-
pytorch_model = phi3.build_model(
|
|
30
|
-
checkpoint_path,
|
|
31
|
-
custom_loader=loader.maybe_get_custom_loader(
|
|
32
|
-
checkpoint_path, flags.FLAGS.custom_checkpoint_loader
|
|
33
|
-
),
|
|
34
|
-
mask_cache_size=converter.get_mask_cache_size_from_flags(),
|
|
35
|
-
)
|
|
36
|
-
converter.convert_to_tflite(
|
|
37
|
-
pytorch_model,
|
|
38
|
-
output_path=flags.FLAGS.output_path,
|
|
39
|
-
output_name_prefix=flags.FLAGS.output_name_prefix,
|
|
40
|
-
prefill_seq_len=flags.FLAGS.prefill_seq_lens,
|
|
41
|
-
kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
|
|
42
|
-
quantize=flags.FLAGS.quantize,
|
|
43
|
-
lora_ranks=flags.FLAGS.lora_ranks,
|
|
44
|
-
export_config=export_config.get_from_flags(),
|
|
45
|
-
)
|
|
26
|
+
converter.build_and_convert_to_tflite_from_flags(phi3.build_model)
|
|
46
27
|
|
|
47
28
|
|
|
48
29
|
if __name__ == '__main__':
|
|
@@ -18,31 +18,12 @@
|
|
|
18
18
|
from absl import app
|
|
19
19
|
from ai_edge_torch.generative.examples.phi import phi4
|
|
20
20
|
from ai_edge_torch.generative.utilities import converter
|
|
21
|
-
from ai_edge_torch.generative.utilities import export_config
|
|
22
|
-
from ai_edge_torch.generative.utilities import loader
|
|
23
21
|
|
|
24
22
|
flags = converter.define_conversion_flags("phi4")
|
|
25
23
|
|
|
26
24
|
|
|
27
25
|
def main(_):
|
|
28
|
-
|
|
29
|
-
pytorch_model = phi4.build_model(
|
|
30
|
-
checkpoint_path,
|
|
31
|
-
custom_loader=loader.maybe_get_custom_loader(
|
|
32
|
-
checkpoint_path, flags.FLAGS.custom_checkpoint_loader
|
|
33
|
-
),
|
|
34
|
-
mask_cache_size=converter.get_mask_cache_size_from_flags(),
|
|
35
|
-
)
|
|
36
|
-
converter.convert_to_tflite(
|
|
37
|
-
pytorch_model,
|
|
38
|
-
output_path=flags.FLAGS.output_path,
|
|
39
|
-
output_name_prefix=flags.FLAGS.output_name_prefix,
|
|
40
|
-
prefill_seq_len=flags.FLAGS.prefill_seq_lens,
|
|
41
|
-
kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
|
|
42
|
-
quantize=flags.FLAGS.quantize,
|
|
43
|
-
lora_ranks=flags.FLAGS.lora_ranks,
|
|
44
|
-
export_config=export_config.get_from_flags(),
|
|
45
|
-
)
|
|
26
|
+
converter.build_and_convert_to_tflite_from_flags(phi4.build_model)
|
|
46
27
|
|
|
47
28
|
|
|
48
29
|
if __name__ == '__main__':
|
|
@@ -19,31 +19,12 @@
|
|
|
19
19
|
from absl import app
|
|
20
20
|
from ai_edge_torch.generative.examples.phi import phi2
|
|
21
21
|
from ai_edge_torch.generative.utilities import converter
|
|
22
|
-
from ai_edge_torch.generative.utilities import export_config
|
|
23
|
-
from ai_edge_torch.generative.utilities import loader
|
|
24
22
|
|
|
25
23
|
flags = converter.define_conversion_flags("phi2")
|
|
26
24
|
|
|
27
25
|
|
|
28
26
|
def main(_):
|
|
29
|
-
|
|
30
|
-
pytorch_model = phi2.build_model(
|
|
31
|
-
checkpoint_path,
|
|
32
|
-
custom_loader=loader.maybe_get_custom_loader(
|
|
33
|
-
checkpoint_path, flags.FLAGS.custom_checkpoint_loader
|
|
34
|
-
),
|
|
35
|
-
mask_cache_size=converter.get_mask_cache_size_from_flags(),
|
|
36
|
-
)
|
|
37
|
-
converter.convert_to_tflite(
|
|
38
|
-
pytorch_model,
|
|
39
|
-
output_path=flags.FLAGS.output_path,
|
|
40
|
-
output_name_prefix=flags.FLAGS.output_name_prefix,
|
|
41
|
-
prefill_seq_len=flags.FLAGS.prefill_seq_lens,
|
|
42
|
-
kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
|
|
43
|
-
quantize=flags.FLAGS.quantize,
|
|
44
|
-
lora_ranks=flags.FLAGS.lora_ranks,
|
|
45
|
-
export_config=export_config.get_from_flags(),
|
|
46
|
-
)
|
|
27
|
+
converter.build_and_convert_to_tflite_from_flags(phi2.build_model)
|
|
47
28
|
|
|
48
29
|
|
|
49
30
|
if __name__ == '__main__':
|
|
@@ -18,8 +18,6 @@
|
|
|
18
18
|
from absl import app
|
|
19
19
|
from ai_edge_torch.generative.examples.qwen import qwen
|
|
20
20
|
from ai_edge_torch.generative.utilities import converter
|
|
21
|
-
from ai_edge_torch.generative.utilities import export_config
|
|
22
|
-
from ai_edge_torch.generative.utilities import loader
|
|
23
21
|
|
|
24
22
|
flags = converter.define_conversion_flags('qwen')
|
|
25
23
|
|
|
@@ -38,24 +36,7 @@ _BUILDER = {
|
|
|
38
36
|
|
|
39
37
|
|
|
40
38
|
def main(_):
|
|
41
|
-
|
|
42
|
-
pytorch_model = _BUILDER[_MODEL_SIZE.value](
|
|
43
|
-
checkpoint_path,
|
|
44
|
-
custom_loader=loader.maybe_get_custom_loader(
|
|
45
|
-
checkpoint_path, flags.FLAGS.custom_checkpoint_loader
|
|
46
|
-
),
|
|
47
|
-
mask_cache_size=converter.get_mask_cache_size_from_flags(),
|
|
48
|
-
)
|
|
49
|
-
converter.convert_to_tflite(
|
|
50
|
-
pytorch_model,
|
|
51
|
-
output_path=flags.FLAGS.output_path,
|
|
52
|
-
output_name_prefix=flags.FLAGS.output_name_prefix,
|
|
53
|
-
prefill_seq_len=flags.FLAGS.prefill_seq_lens,
|
|
54
|
-
kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
|
|
55
|
-
quantize=flags.FLAGS.quantize,
|
|
56
|
-
lora_ranks=flags.FLAGS.lora_ranks,
|
|
57
|
-
export_config=export_config.get_from_flags(),
|
|
58
|
-
)
|
|
39
|
+
converter.build_and_convert_to_tflite_from_flags(_BUILDER[_MODEL_SIZE.value])
|
|
59
40
|
|
|
60
41
|
|
|
61
42
|
if __name__ == '__main__':
|
|
@@ -18,8 +18,6 @@
|
|
|
18
18
|
from absl import app
|
|
19
19
|
from ai_edge_torch.generative.examples.qwen import qwen3
|
|
20
20
|
from ai_edge_torch.generative.utilities import converter
|
|
21
|
-
from ai_edge_torch.generative.utilities import export_config
|
|
22
|
-
from ai_edge_torch.generative.utilities import loader
|
|
23
21
|
|
|
24
22
|
flags = converter.define_conversion_flags('qwen')
|
|
25
23
|
|
|
@@ -38,24 +36,7 @@ _BUILDER = {
|
|
|
38
36
|
|
|
39
37
|
|
|
40
38
|
def main(_):
|
|
41
|
-
|
|
42
|
-
pytorch_model = _BUILDER[_MODEL_SIZE.value](
|
|
43
|
-
checkpoint_path,
|
|
44
|
-
custom_loader=loader.maybe_get_custom_loader(
|
|
45
|
-
checkpoint_path, flags.FLAGS.custom_checkpoint_loader
|
|
46
|
-
),
|
|
47
|
-
mask_cache_size=converter.get_mask_cache_size_from_flags(),
|
|
48
|
-
)
|
|
49
|
-
converter.convert_to_tflite(
|
|
50
|
-
pytorch_model,
|
|
51
|
-
output_path=flags.FLAGS.output_path,
|
|
52
|
-
output_name_prefix=flags.FLAGS.output_name_prefix,
|
|
53
|
-
prefill_seq_len=flags.FLAGS.prefill_seq_lens,
|
|
54
|
-
kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
|
|
55
|
-
quantize=flags.FLAGS.quantize,
|
|
56
|
-
lora_ranks=flags.FLAGS.lora_ranks,
|
|
57
|
-
export_config=export_config.get_from_flags(),
|
|
58
|
-
)
|
|
39
|
+
converter.build_and_convert_to_tflite_from_flags(_BUILDER[_MODEL_SIZE.value])
|
|
59
40
|
|
|
60
41
|
|
|
61
42
|
if __name__ == '__main__':
|
|
@@ -18,41 +18,12 @@
|
|
|
18
18
|
from absl import app
|
|
19
19
|
from ai_edge_torch.generative.examples.smollm import smollm
|
|
20
20
|
from ai_edge_torch.generative.utilities import converter
|
|
21
|
-
from ai_edge_torch.generative.utilities import export_config as export_cfg
|
|
22
|
-
from ai_edge_torch.generative.utilities import loader
|
|
23
21
|
|
|
24
22
|
flags = converter.define_conversion_flags('smollm')
|
|
25
23
|
|
|
26
|
-
_DECODE_BATCH_SIZE = flags.DEFINE_integer(
|
|
27
|
-
'decode_batch_size',
|
|
28
|
-
1,
|
|
29
|
-
'The batch size for the decode signature.',
|
|
30
|
-
)
|
|
31
|
-
|
|
32
24
|
|
|
33
25
|
def main(_):
|
|
34
|
-
|
|
35
|
-
pytorch_model = smollm.build_model(
|
|
36
|
-
checkpoint_path,
|
|
37
|
-
custom_loader=loader.maybe_get_custom_loader(
|
|
38
|
-
checkpoint_path, flags.FLAGS.custom_checkpoint_loader
|
|
39
|
-
),
|
|
40
|
-
mask_cache_size=converter.get_mask_cache_size_from_flags(),
|
|
41
|
-
)
|
|
42
|
-
|
|
43
|
-
export_config = export_cfg.get_from_flags()
|
|
44
|
-
export_config.decode_batch_size = _DECODE_BATCH_SIZE.value
|
|
45
|
-
|
|
46
|
-
converter.convert_to_tflite(
|
|
47
|
-
pytorch_model,
|
|
48
|
-
output_path=flags.FLAGS.output_path,
|
|
49
|
-
output_name_prefix=flags.FLAGS.output_name_prefix,
|
|
50
|
-
prefill_seq_len=flags.FLAGS.prefill_seq_lens,
|
|
51
|
-
kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
|
|
52
|
-
quantize=flags.FLAGS.quantize,
|
|
53
|
-
lora_ranks=flags.FLAGS.lora_ranks,
|
|
54
|
-
export_config=export_config,
|
|
55
|
-
)
|
|
26
|
+
converter.build_and_convert_to_tflite_from_flags(smollm.build_model)
|
|
56
27
|
|
|
57
28
|
|
|
58
29
|
if __name__ == '__main__':
|
|
@@ -18,41 +18,12 @@
|
|
|
18
18
|
from absl import app
|
|
19
19
|
from ai_edge_torch.generative.examples.smollm import smollm
|
|
20
20
|
from ai_edge_torch.generative.utilities import converter
|
|
21
|
-
from ai_edge_torch.generative.utilities import export_config as export_cfg
|
|
22
|
-
from ai_edge_torch.generative.utilities import loader
|
|
23
21
|
|
|
24
22
|
flags = converter.define_conversion_flags('smollm2')
|
|
25
23
|
|
|
26
|
-
_DECODE_BATCH_SIZE = flags.DEFINE_integer(
|
|
27
|
-
'decode_batch_size',
|
|
28
|
-
1,
|
|
29
|
-
'The batch size for the decode signature.',
|
|
30
|
-
)
|
|
31
|
-
|
|
32
24
|
|
|
33
25
|
def main(_):
|
|
34
|
-
|
|
35
|
-
pytorch_model = smollm.build_model_v2(
|
|
36
|
-
checkpoint_path,
|
|
37
|
-
custom_loader=loader.maybe_get_custom_loader(
|
|
38
|
-
checkpoint_path, flags.FLAGS.custom_checkpoint_loader
|
|
39
|
-
),
|
|
40
|
-
mask_cache_size=converter.get_mask_cache_size_from_flags(),
|
|
41
|
-
)
|
|
42
|
-
|
|
43
|
-
export_config = export_cfg.get_from_flags()
|
|
44
|
-
export_config.decode_batch_size = _DECODE_BATCH_SIZE.value
|
|
45
|
-
|
|
46
|
-
converter.convert_to_tflite(
|
|
47
|
-
pytorch_model,
|
|
48
|
-
output_path=flags.FLAGS.output_path,
|
|
49
|
-
output_name_prefix=flags.FLAGS.output_name_prefix,
|
|
50
|
-
prefill_seq_len=flags.FLAGS.prefill_seq_lens,
|
|
51
|
-
kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
|
|
52
|
-
quantize=flags.FLAGS.quantize,
|
|
53
|
-
lora_ranks=flags.FLAGS.lora_ranks,
|
|
54
|
-
export_config=export_config,
|
|
55
|
-
)
|
|
26
|
+
converter.build_and_convert_to_tflite_from_flags(smollm.build_model_v2)
|
|
56
27
|
|
|
57
28
|
|
|
58
29
|
if __name__ == '__main__':
|
|
@@ -138,9 +138,7 @@ def convert_stable_diffusion_to_tflite(
|
|
|
138
138
|
if not os.path.exists(output_dir):
|
|
139
139
|
pathlib.Path(output_dir).mkdir(parents=True, exist_ok=True)
|
|
140
140
|
|
|
141
|
-
quant_config = (
|
|
142
|
-
quant_recipes.full_int8_weight_only_recipe() if quantize else None
|
|
143
|
-
)
|
|
141
|
+
quant_config = quant_recipes.full_weight_only_recipe() if quantize else None
|
|
144
142
|
|
|
145
143
|
# TODO(yichunk): convert to multi signature tflite model.
|
|
146
144
|
# CLIP text encoder
|
|
@@ -18,31 +18,12 @@
|
|
|
18
18
|
from absl import app
|
|
19
19
|
from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
|
|
20
20
|
from ai_edge_torch.generative.utilities import converter
|
|
21
|
-
from ai_edge_torch.generative.utilities import export_config
|
|
22
|
-
from ai_edge_torch.generative.utilities import loader
|
|
23
21
|
|
|
24
22
|
flags = converter.define_conversion_flags("tiny_llama")
|
|
25
23
|
|
|
26
24
|
|
|
27
25
|
def main(_):
|
|
28
|
-
|
|
29
|
-
pytorch_model = tiny_llama.build_model(
|
|
30
|
-
checkpoint_path,
|
|
31
|
-
custom_loader=loader.maybe_get_custom_loader(
|
|
32
|
-
checkpoint_path, flags.FLAGS.custom_checkpoint_loader
|
|
33
|
-
),
|
|
34
|
-
mask_cache_size=converter.get_mask_cache_size_from_flags(),
|
|
35
|
-
)
|
|
36
|
-
converter.convert_to_tflite(
|
|
37
|
-
pytorch_model,
|
|
38
|
-
output_path=flags.FLAGS.output_path,
|
|
39
|
-
output_name_prefix=flags.FLAGS.output_name_prefix,
|
|
40
|
-
prefill_seq_len=flags.FLAGS.prefill_seq_lens,
|
|
41
|
-
kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
|
|
42
|
-
quantize=flags.FLAGS.quantize,
|
|
43
|
-
lora_ranks=flags.FLAGS.lora_ranks,
|
|
44
|
-
export_config=export_config.get_from_flags(),
|
|
45
|
-
)
|
|
26
|
+
converter.build_and_convert_to_tflite_from_flags(tiny_llama.build_model)
|
|
46
27
|
|
|
47
28
|
|
|
48
29
|
if __name__ == '__main__':
|