ai-edge-torch-nightly 0.4.0.dev20250330__py3-none-any.whl → 0.4.0.dev20250331__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_torch/generative/examples/amd_llama_135m/convert_to_tflite.py +7 -43
- ai_edge_torch/generative/examples/deepseek/convert_to_tflite.py +7 -42
- ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py +7 -45
- ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +7 -44
- ai_edge_torch/generative/examples/gemma3/convert_gemma3_to_tflite.py +10 -45
- ai_edge_torch/generative/examples/llama/convert_to_tflite.py +9 -43
- ai_edge_torch/generative/examples/openelm/convert_to_tflite.py +7 -44
- ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py +8 -39
- ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py +7 -44
- ai_edge_torch/generative/examples/phi/convert_phi4_to_tflite.py +7 -44
- ai_edge_torch/generative/examples/phi/convert_to_tflite.py +7 -42
- ai_edge_torch/generative/examples/qwen/convert_to_tflite.py +8 -45
- ai_edge_torch/generative/examples/qwen_vl/convert_to_tflite.py +8 -39
- ai_edge_torch/generative/examples/smollm/convert_to_tflite.py +8 -43
- ai_edge_torch/generative/examples/smollm/convert_v2_to_tflite.py +8 -43
- ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +7 -44
- ai_edge_torch/generative/utilities/converter.py +45 -0
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.4.0.dev20250330.dist-info → ai_edge_torch_nightly-0.4.0.dev20250331.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.4.0.dev20250330.dist-info → ai_edge_torch_nightly-0.4.0.dev20250331.dist-info}/RECORD +23 -23
- {ai_edge_torch_nightly-0.4.0.dev20250330.dist-info → ai_edge_torch_nightly-0.4.0.dev20250331.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.4.0.dev20250330.dist-info → ai_edge_torch_nightly-0.4.0.dev20250331.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.4.0.dev20250330.dist-info → ai_edge_torch_nightly-0.4.0.dev20250331.dist-info}/top_level.txt +0 -0
@@ -16,61 +16,25 @@
|
|
16
16
|
"""Example of converting AMD-Llama-135m model to multi-signature tflite model."""
|
17
17
|
|
18
18
|
import os
|
19
|
-
import pathlib
|
20
|
-
|
21
19
|
from absl import app
|
22
20
|
from absl import flags
|
23
21
|
from ai_edge_torch.generative.examples.amd_llama_135m import amd_llama_135m
|
24
22
|
from ai_edge_torch.generative.utilities import converter
|
25
23
|
from ai_edge_torch.generative.utilities.model_builder import ExportConfig
|
26
24
|
|
27
|
-
|
28
|
-
'checkpoint_path',
|
29
|
-
os.path.join(pathlib.Path.home(), 'Downloads/llm_data/amd-llama-135m'),
|
30
|
-
'The path to the model checkpoint, or directory holding the checkpoint.',
|
31
|
-
)
|
32
|
-
_KV_CACHE_MAX_LEN = flags.DEFINE_integer(
|
33
|
-
'kv_cache_max_len',
|
34
|
-
1280,
|
35
|
-
'The maximum size of KV cache buffer, including both prefill and decode.',
|
36
|
-
)
|
37
|
-
_OUTPUT_PATH = flags.DEFINE_string(
|
38
|
-
'output_path',
|
39
|
-
'/tmp/',
|
40
|
-
'The path to export the tflite model.',
|
41
|
-
)
|
42
|
-
_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
|
43
|
-
'output_name_prefix',
|
44
|
-
'amd_llama',
|
45
|
-
'The prefix of the output tflite model name.',
|
46
|
-
)
|
47
|
-
_PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
|
48
|
-
'prefill_seq_lens',
|
49
|
-
(8, 64, 128, 256, 512, 1024),
|
50
|
-
'List of the maximum sizes of prefill input tensors.',
|
51
|
-
)
|
52
|
-
_QUANTIZE = flags.DEFINE_bool(
|
53
|
-
'quantize',
|
54
|
-
True,
|
55
|
-
'Whether the model should be quantized.',
|
56
|
-
)
|
57
|
-
_LORA_RANKS = flags.DEFINE_multi_integer(
|
58
|
-
'lora_ranks',
|
59
|
-
None,
|
60
|
-
'If set, the model will be converted with the provided list of LoRA ranks.',
|
61
|
-
)
|
25
|
+
flags = converter.define_conversion_flags("amd-llama-135m")
|
62
26
|
|
63
27
|
def main(_):
|
64
28
|
pytorch_model = amd_llama_135m.build_model(
|
65
|
-
|
29
|
+
flags.FLAGS.checkpoint_path, kv_cache_max_len=flags.FLAGS.kv_cache_max_len
|
66
30
|
)
|
67
31
|
converter.convert_to_tflite(
|
68
32
|
pytorch_model,
|
69
|
-
output_path=
|
70
|
-
output_name_prefix=
|
71
|
-
prefill_seq_len=
|
72
|
-
quantize=
|
73
|
-
lora_ranks=
|
33
|
+
output_path=flags.FLAGS.output_path,
|
34
|
+
output_name_prefix=flags.FLAGS.output_name_prefix,
|
35
|
+
prefill_seq_len=flags.FLAGS.prefill_seq_lens,
|
36
|
+
quantize=flags.FLAGS.quantize,
|
37
|
+
lora_ranks=flags.FLAGS.lora_ranks,
|
74
38
|
export_config=ExportConfig(),
|
75
39
|
)
|
76
40
|
|
@@ -24,54 +24,19 @@ from ai_edge_torch.generative.examples.deepseek import deepseek
|
|
24
24
|
from ai_edge_torch.generative.utilities import converter
|
25
25
|
from ai_edge_torch.generative.utilities.model_builder import ExportConfig
|
26
26
|
|
27
|
-
|
28
|
-
'checkpoint_path',
|
29
|
-
os.path.join(pathlib.Path.home(), 'Downloads/llm_data/deepseek'),
|
30
|
-
'The path to the model checkpoint, or directory holding the checkpoint.',
|
31
|
-
)
|
32
|
-
_OUTPUT_PATH = flags.DEFINE_string(
|
33
|
-
'output_path',
|
34
|
-
'/tmp/',
|
35
|
-
'The path to export the tflite model.',
|
36
|
-
)
|
37
|
-
_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
|
38
|
-
'output_name_prefix',
|
39
|
-
'deepseek',
|
40
|
-
'The prefix of the output tflite model name.',
|
41
|
-
)
|
42
|
-
_PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
|
43
|
-
'prefill_seq_lens',
|
44
|
-
(8, 64, 128, 256, 512, 1024),
|
45
|
-
'List of the maximum sizes of prefill input tensors.',
|
46
|
-
)
|
47
|
-
_KV_CACHE_MAX_LEN = flags.DEFINE_integer(
|
48
|
-
'kv_cache_max_len',
|
49
|
-
1280,
|
50
|
-
'The maximum size of KV cache buffer, including both prefill and decode.',
|
51
|
-
)
|
52
|
-
_QUANTIZE = flags.DEFINE_bool(
|
53
|
-
'quantize',
|
54
|
-
True,
|
55
|
-
'Whether the model should be quantized.',
|
56
|
-
)
|
57
|
-
_LORA_RANKS = flags.DEFINE_multi_integer(
|
58
|
-
'lora_ranks',
|
59
|
-
None,
|
60
|
-
'If set, the model will be converted with the provided list of LoRA ranks.',
|
61
|
-
)
|
62
|
-
|
27
|
+
flags = converter.define_conversion_flags("deepseek")
|
63
28
|
|
64
29
|
def main(_):
|
65
30
|
pytorch_model = deepseek.build_model(
|
66
|
-
|
31
|
+
flags.FLAGS.checkpoint_path, kv_cache_max_len=flags.FLAGS.kv_cache_max_len
|
67
32
|
)
|
68
33
|
converter.convert_to_tflite(
|
69
34
|
pytorch_model,
|
70
|
-
output_path=
|
71
|
-
output_name_prefix=
|
72
|
-
prefill_seq_len=
|
73
|
-
quantize=
|
74
|
-
lora_ranks=
|
35
|
+
output_path=flags.FLAGS.output_path,
|
36
|
+
output_name_prefix=flags.FLAGS.output_name_prefix,
|
37
|
+
prefill_seq_len=flags.FLAGS.prefill_seq_lens,
|
38
|
+
quantize=flags.FLAGS.quantize,
|
39
|
+
lora_ranks=flags.FLAGS.lora_ranks,
|
75
40
|
export_config=ExportConfig(),
|
76
41
|
)
|
77
42
|
|
@@ -16,62 +16,24 @@
|
|
16
16
|
"""Example of converting a Gemma1 model to multi-signature tflite model."""
|
17
17
|
|
18
18
|
import os
|
19
|
-
import pathlib
|
20
|
-
|
21
19
|
from absl import app
|
22
|
-
from absl import flags
|
23
20
|
from ai_edge_torch.generative.examples.gemma import gemma1
|
24
21
|
from ai_edge_torch.generative.utilities import converter
|
25
22
|
from ai_edge_torch.generative.utilities.model_builder import ExportConfig
|
26
23
|
|
27
|
-
|
28
|
-
'checkpoint_path',
|
29
|
-
os.path.join(pathlib.Path.home(), 'Downloads/llm_data/gemma-2b'),
|
30
|
-
'The path to the model checkpoint, or directory holding the checkpoint.',
|
31
|
-
)
|
32
|
-
_OUTPUT_PATH = flags.DEFINE_string(
|
33
|
-
'output_path',
|
34
|
-
'/tmp/',
|
35
|
-
'The path to export the tflite model.',
|
36
|
-
)
|
37
|
-
_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
|
38
|
-
'output_name_prefix',
|
39
|
-
'gemma',
|
40
|
-
'The prefix of the output tflite model name.',
|
41
|
-
)
|
42
|
-
_PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
|
43
|
-
'prefill_seq_lens',
|
44
|
-
(8, 64, 128, 256, 512, 1024),
|
45
|
-
'List of the maximum sizes of prefill input tensors.',
|
46
|
-
)
|
47
|
-
_KV_CACHE_MAX_LEN = flags.DEFINE_integer(
|
48
|
-
'kv_cache_max_len',
|
49
|
-
1280,
|
50
|
-
'The maximum size of KV cache buffer, including both prefill and decode.',
|
51
|
-
)
|
52
|
-
_QUANTIZE = flags.DEFINE_bool(
|
53
|
-
'quantize',
|
54
|
-
True,
|
55
|
-
'Whether the model should be quantized.',
|
56
|
-
)
|
57
|
-
_LORA_RANKS = flags.DEFINE_multi_integer(
|
58
|
-
'lora_ranks',
|
59
|
-
None,
|
60
|
-
'If set, the model will be converted with the provided list of LoRA ranks.',
|
61
|
-
)
|
62
|
-
|
24
|
+
flags = converter.define_conversion_flags("gemma-2b")
|
63
25
|
|
64
26
|
def main(_):
|
65
27
|
pytorch_model = gemma1.build_2b_model(
|
66
|
-
|
28
|
+
flags.FLAGS.checkpoint_path, kv_cache_max_len=flags.FLAGS.kv_cache_max_len
|
67
29
|
)
|
68
30
|
converter.convert_to_tflite(
|
69
31
|
pytorch_model,
|
70
|
-
output_path=
|
71
|
-
output_name_prefix=
|
72
|
-
prefill_seq_len=
|
73
|
-
quantize=
|
74
|
-
lora_ranks=
|
32
|
+
output_path=flags.FLAGS.output_path,
|
33
|
+
output_name_prefix=flags.FLAGS.output_name_prefix,
|
34
|
+
prefill_seq_len=flags.FLAGS.prefill_seq_lens,
|
35
|
+
quantize=flags.FLAGS.quantize,
|
36
|
+
lora_ranks=flags.FLAGS.lora_ranks,
|
75
37
|
export_config=ExportConfig(),
|
76
38
|
)
|
77
39
|
|
@@ -16,62 +16,25 @@
|
|
16
16
|
"""Example of converting a Gemma2 model to multi-signature tflite model."""
|
17
17
|
|
18
18
|
import os
|
19
|
-
import pathlib
|
20
|
-
|
21
19
|
from absl import app
|
22
20
|
from absl import flags
|
23
21
|
from ai_edge_torch.generative.examples.gemma import gemma2
|
24
22
|
from ai_edge_torch.generative.utilities import converter
|
25
23
|
from ai_edge_torch.generative.utilities.model_builder import ExportConfig
|
26
24
|
|
27
|
-
|
28
|
-
'checkpoint_path',
|
29
|
-
os.path.join(pathlib.Path.home(), 'Downloads/llm_data/gemma2-2b'),
|
30
|
-
'The path to the model checkpoint, or directory holding the checkpoint.',
|
31
|
-
)
|
32
|
-
_OUTPUT_PATH = flags.DEFINE_string(
|
33
|
-
'output_path',
|
34
|
-
'/tmp/',
|
35
|
-
'The path to export the tflite model.',
|
36
|
-
)
|
37
|
-
_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
|
38
|
-
'output_name_prefix',
|
39
|
-
'gemma2',
|
40
|
-
'The prefix of the output tflite model name.',
|
41
|
-
)
|
42
|
-
_PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
|
43
|
-
'prefill_seq_lens',
|
44
|
-
(8, 64, 128, 256, 512, 1024),
|
45
|
-
'List of the maximum sizes of prefill input tensors.',
|
46
|
-
)
|
47
|
-
_KV_CACHE_MAX_LEN = flags.DEFINE_integer(
|
48
|
-
'kv_cache_max_len',
|
49
|
-
1280,
|
50
|
-
'The maximum size of KV cache buffer, including both prefill and decode.',
|
51
|
-
)
|
52
|
-
_QUANTIZE = flags.DEFINE_bool(
|
53
|
-
'quantize',
|
54
|
-
True,
|
55
|
-
'Whether the model should be quantized.',
|
56
|
-
)
|
57
|
-
_LORA_RANKS = flags.DEFINE_multi_integer(
|
58
|
-
'lora_ranks',
|
59
|
-
None,
|
60
|
-
'If set, the model will be converted with the provided list of LoRA ranks.',
|
61
|
-
)
|
62
|
-
|
25
|
+
flags = converter.define_conversion_flags("gemma2-2b")
|
63
26
|
|
64
27
|
def main(_):
|
65
28
|
pytorch_model = gemma2.build_2b_model(
|
66
|
-
|
29
|
+
flags.FLAGS.checkpoint_path, kv_cache_max_len=flags.FLAGS.kv_cache_max_len
|
67
30
|
)
|
68
31
|
converter.convert_to_tflite(
|
69
32
|
pytorch_model,
|
70
|
-
output_path=
|
71
|
-
output_name_prefix=
|
72
|
-
prefill_seq_len=
|
73
|
-
quantize=
|
74
|
-
lora_ranks=
|
33
|
+
output_path=flags.FLAGS.output_path,
|
34
|
+
output_name_prefix=flags.FLAGS.output_name_prefix,
|
35
|
+
prefill_seq_len=flags.FLAGS.prefill_seq_lens,
|
36
|
+
quantize=flags.FLAGS.quantize,
|
37
|
+
lora_ranks=flags.FLAGS.lora_ranks,
|
75
38
|
export_config=ExportConfig(),
|
76
39
|
)
|
77
40
|
|
@@ -16,8 +16,6 @@
|
|
16
16
|
"""Example of converting a Gemma3 model to multi-signature tflite model."""
|
17
17
|
|
18
18
|
import os
|
19
|
-
import pathlib
|
20
|
-
|
21
19
|
from absl import app
|
22
20
|
from absl import flags
|
23
21
|
from ai_edge_torch.generative.examples.gemma3 import gemma3
|
@@ -26,48 +24,14 @@ from ai_edge_torch.generative.utilities import converter
|
|
26
24
|
from ai_edge_torch.generative.utilities.model_builder import ExportConfig
|
27
25
|
import torch
|
28
26
|
|
27
|
+
flags = converter.define_conversion_flags('gemma3-1b')
|
28
|
+
|
29
29
|
_MODEL_SIZE = flags.DEFINE_string(
|
30
30
|
'model_size',
|
31
31
|
'1b',
|
32
32
|
'The size of the model to convert.',
|
33
33
|
)
|
34
34
|
|
35
|
-
_CHECKPOINT_PATH = flags.DEFINE_string(
|
36
|
-
'checkpoint_path',
|
37
|
-
os.path.join(pathlib.Path.home(), 'Downloads/llm_data/gemma3-1b'),
|
38
|
-
'The path to the model checkpoint, or directory holding the checkpoint.',
|
39
|
-
)
|
40
|
-
_OUTPUT_PATH = flags.DEFINE_string(
|
41
|
-
'output_path',
|
42
|
-
'/tmp/',
|
43
|
-
'The path to export the tflite model.',
|
44
|
-
)
|
45
|
-
_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
|
46
|
-
'output_name_prefix',
|
47
|
-
'gemma3',
|
48
|
-
'The prefix of the output tflite model name.',
|
49
|
-
)
|
50
|
-
_PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
|
51
|
-
'prefill_seq_lens',
|
52
|
-
(32, 64, 128, 256, 512, 1024),
|
53
|
-
'List of the maximum sizes of prefill input tensors.',
|
54
|
-
)
|
55
|
-
_KV_CACHE_MAX_LEN = flags.DEFINE_integer(
|
56
|
-
'kv_cache_max_len',
|
57
|
-
2048,
|
58
|
-
'The maximum size of KV cache buffer, including both prefill and decode.',
|
59
|
-
)
|
60
|
-
_QUANTIZE = flags.DEFINE_bool(
|
61
|
-
'quantize',
|
62
|
-
False,
|
63
|
-
'Whether the model should be quantized.',
|
64
|
-
)
|
65
|
-
_LORA_RANKS = flags.DEFINE_multi_integer(
|
66
|
-
'lora_ranks',
|
67
|
-
None,
|
68
|
-
'If set, the model will be converted with the provided list of LoRA ranks.',
|
69
|
-
)
|
70
|
-
|
71
35
|
|
72
36
|
def _create_mask(mask_len, kv_cache_max_len):
|
73
37
|
mask = torch.full(
|
@@ -101,21 +65,22 @@ def _create_export_config(
|
|
101
65
|
def main(_):
|
102
66
|
if _MODEL_SIZE.value == '1b':
|
103
67
|
pytorch_model = gemma3.build_model_1b(
|
104
|
-
|
68
|
+
flags.FLAGS.checkpoint_path,
|
69
|
+
kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
|
105
70
|
)
|
106
71
|
config = pytorch_model.config
|
107
72
|
else:
|
108
73
|
raise ValueError(f'Unsupported model size: {_MODEL_SIZE.value}')
|
109
74
|
converter.convert_to_tflite(
|
110
75
|
pytorch_model,
|
111
|
-
output_path=
|
112
|
-
output_name_prefix=
|
113
|
-
prefill_seq_len=
|
114
|
-
quantize=
|
76
|
+
output_path=flags.FLAGS.output_path,
|
77
|
+
output_name_prefix=flags.FLAGS.output_name_prefix,
|
78
|
+
prefill_seq_len=flags.FLAGS.prefill_seq_lens,
|
79
|
+
quantize=flags.FLAGS.quantize,
|
115
80
|
config=config,
|
116
|
-
lora_ranks=
|
81
|
+
lora_ranks=flags.FLAGS.lora_ranks,
|
117
82
|
export_config=_create_export_config(
|
118
|
-
|
83
|
+
flags.FLAGS.prefill_seq_lens, flags.FLAGS.kv_cache_max_len
|
119
84
|
),
|
120
85
|
)
|
121
86
|
|
@@ -16,55 +16,21 @@
|
|
16
16
|
"""Example of converting Llama 3.2 1B model to multi-signature tflite model."""
|
17
17
|
|
18
18
|
import os
|
19
|
-
import pathlib
|
20
|
-
|
21
19
|
from absl import app
|
22
20
|
from absl import flags
|
23
21
|
from ai_edge_torch.generative.examples.llama import llama
|
24
22
|
from ai_edge_torch.generative.utilities import converter
|
25
23
|
from ai_edge_torch.generative.utilities.model_builder import ExportConfig
|
26
24
|
|
25
|
+
|
26
|
+
flags = converter.define_conversion_flags('llama')
|
27
|
+
|
27
28
|
_MODEL_SIZE = flags.DEFINE_enum(
|
28
29
|
'model_size',
|
29
30
|
'1b',
|
30
31
|
['1b', '3b'],
|
31
32
|
'The size of the model to verify.',
|
32
33
|
)
|
33
|
-
_CHECKPOINT_PATH = flags.DEFINE_string(
|
34
|
-
'checkpoint_path',
|
35
|
-
os.path.join(pathlib.Path.home(), 'Downloads/llm_data/llama'),
|
36
|
-
'The path to the model checkpoint, or directory holding the checkpoint.',
|
37
|
-
)
|
38
|
-
_OUTPUT_PATH = flags.DEFINE_string(
|
39
|
-
'output_path',
|
40
|
-
'/tmp/',
|
41
|
-
'The path to export the tflite model.',
|
42
|
-
)
|
43
|
-
_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
|
44
|
-
'output_name_prefix',
|
45
|
-
'llama',
|
46
|
-
'The prefix of the output tflite model name.',
|
47
|
-
)
|
48
|
-
_PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
|
49
|
-
'prefill_seq_lens',
|
50
|
-
(8, 64, 128, 256, 512, 1024),
|
51
|
-
'List of the maximum sizes of prefill input tensors.',
|
52
|
-
)
|
53
|
-
_KV_CACHE_MAX_LEN = flags.DEFINE_integer(
|
54
|
-
'kv_cache_max_len',
|
55
|
-
1280,
|
56
|
-
'The maximum size of KV cache buffer, including both prefill and decode.',
|
57
|
-
)
|
58
|
-
_QUANTIZE = flags.DEFINE_bool(
|
59
|
-
'quantize',
|
60
|
-
True,
|
61
|
-
'Whether the model should be quantized.',
|
62
|
-
)
|
63
|
-
_LORA_RANKS = flags.DEFINE_multi_integer(
|
64
|
-
'lora_ranks',
|
65
|
-
None,
|
66
|
-
'If set, the model will be converted with the provided list of LoRA ranks.',
|
67
|
-
)
|
68
34
|
|
69
35
|
_BUILDER = {
|
70
36
|
'1b': llama.build_1b_model,
|
@@ -74,15 +40,15 @@ _BUILDER = {
|
|
74
40
|
|
75
41
|
def main(_):
|
76
42
|
pytorch_model = _BUILDER[_MODEL_SIZE.value](
|
77
|
-
|
43
|
+
flags.FLAGS.checkpoint_path, kv_cache_max_len=flags.FLAGS.kv_cache_max_len
|
78
44
|
)
|
79
45
|
converter.convert_to_tflite(
|
80
46
|
pytorch_model,
|
81
|
-
output_path=
|
82
|
-
output_name_prefix=
|
83
|
-
prefill_seq_len=
|
84
|
-
quantize=
|
85
|
-
lora_ranks=
|
47
|
+
output_path=flags.FLAGS.output_path,
|
48
|
+
output_name_prefix=flags.FLAGS.output_name_prefix,
|
49
|
+
prefill_seq_len=flags.FLAGS.prefill_seq_lens,
|
50
|
+
quantize=flags.FLAGS.quantize,
|
51
|
+
lora_ranks=flags.FLAGS.lora_ranks,
|
86
52
|
export_config=ExportConfig(),
|
87
53
|
)
|
88
54
|
|
@@ -16,62 +16,25 @@
|
|
16
16
|
"""Example of converting OpenELM model to multi-signature tflite model."""
|
17
17
|
|
18
18
|
import os
|
19
|
-
import pathlib
|
20
|
-
|
21
19
|
from absl import app
|
22
20
|
from absl import flags
|
23
21
|
from ai_edge_torch.generative.examples.openelm import openelm
|
24
22
|
from ai_edge_torch.generative.utilities import converter
|
25
23
|
from ai_edge_torch.generative.utilities.model_builder import ExportConfig
|
26
24
|
|
27
|
-
|
28
|
-
'checkpoint_path',
|
29
|
-
os.path.join(pathlib.Path.home(), 'Downloads/llm_data/openelm'),
|
30
|
-
'The path to the model checkpoint, or directory holding the checkpoint.',
|
31
|
-
)
|
32
|
-
_OUTPUT_PATH = flags.DEFINE_string(
|
33
|
-
'output_path',
|
34
|
-
'/tmp/',
|
35
|
-
'The path to export the tflite model.',
|
36
|
-
)
|
37
|
-
_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
|
38
|
-
'output_name_prefix',
|
39
|
-
'openelm',
|
40
|
-
'The prefix of the output tflite model name.',
|
41
|
-
)
|
42
|
-
_PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
|
43
|
-
'prefill_seq_lens',
|
44
|
-
(8, 64, 128, 256, 512, 1024),
|
45
|
-
'List of the maximum sizes of prefill input tensors.',
|
46
|
-
)
|
47
|
-
_KV_CACHE_MAX_LEN = flags.DEFINE_integer(
|
48
|
-
'kv_cache_max_len',
|
49
|
-
1280,
|
50
|
-
'The maximum size of KV cache buffer, including both prefill and decode.',
|
51
|
-
)
|
52
|
-
_QUANTIZE = flags.DEFINE_bool(
|
53
|
-
'quantize',
|
54
|
-
True,
|
55
|
-
'Whether the model should be quantized.',
|
56
|
-
)
|
57
|
-
_LORA_RANKS = flags.DEFINE_multi_integer(
|
58
|
-
'lora_ranks',
|
59
|
-
None,
|
60
|
-
'If set, the model will be converted with the provided list of LoRA ranks.',
|
61
|
-
)
|
62
|
-
|
25
|
+
flags = converter.define_conversion_flags("openelm")
|
63
26
|
|
64
27
|
def main(_):
|
65
28
|
pytorch_model = openelm.build_model(
|
66
|
-
|
29
|
+
flags.FLAGS.checkpoint_path, kv_cache_max_len=flags.FLAGS.kv_cache_max_len
|
67
30
|
)
|
68
31
|
converter.convert_to_tflite(
|
69
32
|
pytorch_model,
|
70
|
-
output_path=
|
71
|
-
output_name_prefix=
|
72
|
-
prefill_seq_len=
|
73
|
-
quantize=
|
74
|
-
lora_ranks=
|
33
|
+
output_path=flags.FLAGS.output_path,
|
34
|
+
output_name_prefix=flags.FLAGS.output_name_prefix,
|
35
|
+
prefill_seq_len=flags.FLAGS.prefill_seq_lens,
|
36
|
+
quantize=flags.FLAGS.quantize,
|
37
|
+
lora_ranks=flags.FLAGS.lora_ranks,
|
75
38
|
export_config=ExportConfig(),
|
76
39
|
)
|
77
40
|
|
@@ -16,8 +16,6 @@
|
|
16
16
|
"""Example of converting a PaliGemma model to multi-signature tflite model."""
|
17
17
|
|
18
18
|
import os
|
19
|
-
import pathlib
|
20
|
-
|
21
19
|
from absl import app
|
22
20
|
from absl import flags
|
23
21
|
from ai_edge_torch.generative.examples.paligemma import paligemma
|
@@ -25,61 +23,32 @@ from ai_edge_torch.generative.utilities import converter
|
|
25
23
|
from ai_edge_torch.generative.utilities.model_builder import ExportConfig
|
26
24
|
import torch
|
27
25
|
|
26
|
+
flags = converter.define_conversion_flags('paligemma2-3b-224')
|
27
|
+
|
28
28
|
_VERSION = flags.DEFINE_enum(
|
29
29
|
'version',
|
30
30
|
'2',
|
31
31
|
['1', '2'],
|
32
32
|
'The version of PaliGemma model to verify.',
|
33
33
|
)
|
34
|
-
_CHECKPOINT_PATH = flags.DEFINE_string(
|
35
|
-
'checkpoint_path',
|
36
|
-
os.path.join(pathlib.Path.home(), 'Downloads/llm_data/paligemma2-3b-224'),
|
37
|
-
'The path to the model checkpoint, or directory holding the checkpoint.',
|
38
|
-
)
|
39
|
-
_OUTPUT_PATH = flags.DEFINE_string(
|
40
|
-
'output_path',
|
41
|
-
'/tmp/',
|
42
|
-
'The path to export the tflite model.',
|
43
|
-
)
|
44
|
-
_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
|
45
|
-
'output_name_prefix',
|
46
|
-
'paligemma',
|
47
|
-
'The prefix of the output tflite model name.',
|
48
|
-
)
|
49
|
-
_PREFILL_SEQ_LEN = flags.DEFINE_integer(
|
50
|
-
'prefill_seq_len',
|
51
|
-
1024,
|
52
|
-
'The maximum size of prefill input tensor.',
|
53
|
-
)
|
54
|
-
_KV_CACHE_MAX_LEN = flags.DEFINE_integer(
|
55
|
-
'kv_cache_max_len',
|
56
|
-
1280,
|
57
|
-
'The maximum size of KV cache buffer, including both prefill and decode.',
|
58
|
-
)
|
59
|
-
_QUANTIZE = flags.DEFINE_bool(
|
60
|
-
'quantize',
|
61
|
-
True,
|
62
|
-
'Whether the model should be quantized.',
|
63
|
-
)
|
64
|
-
|
65
34
|
|
66
35
|
def main(_):
|
67
36
|
pytorch_model = paligemma.build_model(
|
68
|
-
|
37
|
+
flags.FLAGS.checkpoint_path,
|
69
38
|
version=int(_VERSION.value),
|
70
|
-
kv_cache_max_len=
|
39
|
+
kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
|
71
40
|
)
|
72
41
|
|
73
42
|
config = pytorch_model.image_encoder.config.image_embedding
|
74
43
|
converter.convert_to_tflite(
|
75
44
|
pytorch_model,
|
76
|
-
output_path=
|
77
|
-
output_name_prefix=f'{
|
78
|
-
prefill_seq_len=
|
45
|
+
output_path=flags.FLAGS.output_path,
|
46
|
+
output_name_prefix=f'{flags.FLAGS.output_name_prefix}_{_VERSION.value}',
|
47
|
+
prefill_seq_len=flags.FLAGS.prefill_seq_lens,
|
79
48
|
pixel_values_size=torch.Size(
|
80
49
|
[1, config.channels, config.image_size, config.image_size]
|
81
50
|
),
|
82
|
-
quantize=
|
51
|
+
quantize=flags.FLAGS.quantize,
|
83
52
|
config=pytorch_model.config.decoder_config,
|
84
53
|
export_config=ExportConfig(),
|
85
54
|
)
|
@@ -16,62 +16,25 @@
|
|
16
16
|
"""Example of converting a Phi-3.5 model to multi-signature tflite model."""
|
17
17
|
|
18
18
|
import os
|
19
|
-
import pathlib
|
20
|
-
|
21
19
|
from absl import app
|
22
20
|
from absl import flags
|
23
21
|
from ai_edge_torch.generative.examples.phi import phi3
|
24
22
|
from ai_edge_torch.generative.utilities import converter
|
25
23
|
from ai_edge_torch.generative.utilities.model_builder import ExportConfig
|
26
24
|
|
27
|
-
|
28
|
-
'checkpoint_path',
|
29
|
-
os.path.join(pathlib.Path.home(), 'Downloads/llm_data/phi3'),
|
30
|
-
'The path to the model checkpoint, or directory holding the checkpoint.',
|
31
|
-
)
|
32
|
-
_OUTPUT_PATH = flags.DEFINE_string(
|
33
|
-
'output_path',
|
34
|
-
'/tmp/',
|
35
|
-
'The path to export the tflite model.',
|
36
|
-
)
|
37
|
-
_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
|
38
|
-
'output_name_prefix',
|
39
|
-
'phi3',
|
40
|
-
'The prefix of the output tflite model name.',
|
41
|
-
)
|
42
|
-
_PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
|
43
|
-
'prefill_seq_lens',
|
44
|
-
(8, 64, 128, 256, 512, 1024),
|
45
|
-
'List of the maximum sizes of prefill input tensors.',
|
46
|
-
)
|
47
|
-
_KV_CACHE_MAX_LEN = flags.DEFINE_integer(
|
48
|
-
'kv_cache_max_len',
|
49
|
-
1280,
|
50
|
-
'The maximum size of KV cache buffer, including both prefill and decode.',
|
51
|
-
)
|
52
|
-
_QUANTIZE = flags.DEFINE_bool(
|
53
|
-
'quantize',
|
54
|
-
True,
|
55
|
-
'Whether the model should be quantized.',
|
56
|
-
)
|
57
|
-
_LORA_RANKS = flags.DEFINE_multi_integer(
|
58
|
-
'lora_ranks',
|
59
|
-
None,
|
60
|
-
'If set, the model will be converted with the provided list of LoRA ranks.',
|
61
|
-
)
|
62
|
-
|
25
|
+
flags = converter.define_conversion_flags("phi3")
|
63
26
|
|
64
27
|
def main(_):
|
65
28
|
pytorch_model = phi3.build_model(
|
66
|
-
|
29
|
+
flags.FLAGS.checkpoint_path, kv_cache_max_len=flags.FLAGS.kv_cache_max_len
|
67
30
|
)
|
68
31
|
converter.convert_to_tflite(
|
69
32
|
pytorch_model,
|
70
|
-
output_path=
|
71
|
-
output_name_prefix=
|
72
|
-
prefill_seq_len=
|
73
|
-
quantize=
|
74
|
-
lora_ranks=
|
33
|
+
output_path=flags.FLAGS.output_path,
|
34
|
+
output_name_prefix=flags.FLAGS.output_name_prefix,
|
35
|
+
prefill_seq_len=flags.FLAGS.prefill_seq_lens,
|
36
|
+
quantize=flags.FLAGS.quantize,
|
37
|
+
lora_ranks=flags.FLAGS.lora_ranks,
|
75
38
|
export_config=ExportConfig(),
|
76
39
|
)
|
77
40
|
|