ai-edge-torch-nightly 0.4.0.dev20250329__py3-none-any.whl → 0.4.0.dev20250331__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_torch/generative/examples/amd_llama_135m/convert_to_tflite.py +7 -43
- ai_edge_torch/generative/examples/deepseek/convert_to_tflite.py +7 -42
- ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py +7 -45
- ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +7 -44
- ai_edge_torch/generative/examples/gemma3/convert_gemma3_to_tflite.py +10 -45
- ai_edge_torch/generative/examples/gemma3/verify_gemma3.py +90 -0
- ai_edge_torch/generative/examples/gemma3/verify_util.py +247 -0
- ai_edge_torch/generative/examples/llama/convert_to_tflite.py +9 -43
- ai_edge_torch/generative/examples/openelm/convert_to_tflite.py +7 -44
- ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py +8 -39
- ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py +7 -44
- ai_edge_torch/generative/examples/phi/convert_phi4_to_tflite.py +7 -44
- ai_edge_torch/generative/examples/phi/convert_to_tflite.py +7 -42
- ai_edge_torch/generative/examples/qwen/convert_to_tflite.py +8 -45
- ai_edge_torch/generative/examples/qwen_vl/convert_to_tflite.py +8 -39
- ai_edge_torch/generative/examples/smollm/convert_to_tflite.py +8 -43
- ai_edge_torch/generative/examples/smollm/convert_v2_to_tflite.py +8 -43
- ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +7 -44
- ai_edge_torch/generative/utilities/converter.py +45 -0
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.4.0.dev20250329.dist-info → ai_edge_torch_nightly-0.4.0.dev20250331.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.4.0.dev20250329.dist-info → ai_edge_torch_nightly-0.4.0.dev20250331.dist-info}/RECORD +25 -23
- {ai_edge_torch_nightly-0.4.0.dev20250329.dist-info → ai_edge_torch_nightly-0.4.0.dev20250331.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.4.0.dev20250329.dist-info → ai_edge_torch_nightly-0.4.0.dev20250331.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.4.0.dev20250329.dist-info → ai_edge_torch_nightly-0.4.0.dev20250331.dist-info}/top_level.txt +0 -0
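Two changes dominate this nightly: the example `convert_to_tflite.py` scripts drop their per-script absl flag definitions in favor of a shared `converter.define_conversion_flags()` helper (the +45 lines in `ai_edge_torch/generative/utilities/converter.py`, whose hunk is not reproduced below), and a Gemma 3 verification utility is added (`verify_gemma3.py` and `verify_util.py`). The sketch below is not the shipped `converter.py` code; it only illustrates, as an assumption inferred from the deleted per-script definitions and from the `flags.FLAGS.*` reads in the updated scripts, roughly what the shared helper registers.

```python
# Hypothetical sketch only: flag names are taken from the call sites in the
# hunks below; defaults and the exact signature are assumptions, not the
# shipped implementation of converter.define_conversion_flags().
import os
import pathlib

from absl import flags


def define_conversion_flags(model_name: str):
  """Registers the flag set shared by the example conversion scripts."""
  flags.DEFINE_string(
      'checkpoint_path',
      os.path.join(pathlib.Path.home(), f'Downloads/llm_data/{model_name}'),
      'The path to the model checkpoint, or directory holding the checkpoint.',
  )
  flags.DEFINE_string(
      'output_path', '/tmp/', 'The path to export the tflite model.'
  )
  flags.DEFINE_string(
      'output_name_prefix',
      model_name,
      'The prefix of the output tflite model name.',
  )
  flags.DEFINE_multi_integer(
      'prefill_seq_lens',
      (8, 64, 128, 256, 512, 1024),
      'List of the maximum sizes of prefill input tensors.',
  )
  flags.DEFINE_integer(
      'kv_cache_max_len',
      1280,
      'The maximum size of KV cache buffer, including both prefill and decode.',
  )
  flags.DEFINE_bool('quantize', True, 'Whether the model should be quantized.')
  flags.DEFINE_multi_integer(
      'lora_ranks',
      None,
      'If set, the model will be converted with the provided list of LoRA ranks.',
  )
  # The call sites assign the return value and keep calling flags.DEFINE_* and
  # reading flags.FLAGS on it, so returning the absl flags module matches usage.
  return flags
```

Each updated script then reads these values back through `flags.FLAGS.<name>`, as the hunks below show.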
--- /dev/null
+++ b/ai_edge_torch/generative/examples/gemma3/verify_util.py
@@ -0,0 +1,247 @@
+# Copyright 2025 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Utility functions to verify the reauthored Gemma model."""
+
+import logging
+import os
+from typing import List, Optional, Tuple
+
+from ai_edge_torch.generative.examples.gemma3 import gemma3
+import ai_edge_torch.generative.layers.attention_utils as attn_utils
+from ai_edge_torch.generative.layers.experimental import kv_cache as kv_utils
+from ai_edge_torch.generative.utilities.experimental import verifier
+from gemma import config as gemma_config
+from gemma import model as gemma_model
+import torch
+
+
+def _get_actual_input_len(tokens: torch.Tensor) -> int:
+  for i in range(tokens.shape[1]):
+    if tokens[0, i] == 0:
+      return i
+  return tokens.shape[1]
+
+
+class GemmaWrapper(verifier.ModelWrapper):
+  """Gemma model wrapper for verification.
+
+  Verifier calls model.forward() with maxium sequence length (1024) expecting
+  the output is logits while Gemma gets the input tokens with the actual length
+  and returns logits in a tuple.
+
+  Verifier runs tokenizer before model.generate() while Gemma runs the tokenizer
+  inside model.generate().
+  """
+
+  def _get_kv_caches(
+      self, max_seq_len: int
+  ) -> List[Tuple[torch.Tensor, torch.Tensor]]:
+    config = self.model.config
+    cache_size = (1, max_seq_len, config.num_key_value_heads, config.head_dim)
+    cache = torch.zeros(cache_size)
+    return [
+        (cache.clone(), cache.clone()) for _ in range(config.num_hidden_layers)
+    ]
+
+  def forward(self, tokens: torch.Tensor) -> torch.Tensor:
+    """Forwards the model after reducing input tokens to the actual length."""
+    actual_input_len = _get_actual_input_len(tokens)
+    input_pos = torch.arange(0, actual_input_len, dtype=torch.long)
+    mask_cache = attn_utils.build_causal_mask_cache(tokens.shape[1])
+    local_mask_cache = attn_utils.build_sliding_window_mask_cache(
+        tokens.shape[1], self.model.config.sliding_window_size
+    )
+    _, logits = self.model.forward(
+        input_token_ids=tokens[0, :actual_input_len].unsqueeze(0),
+        input_positions=input_pos,
+        kv_write_indices=None,
+        kv_caches=self._get_kv_caches(tokens.shape[1]),
+        mask=mask_cache.index_select(2, input_pos),
+        output_positions=input_pos,
+        temperatures=None,
+        top_ps=torch.tensor([1.0], dtype=torch.float),
+        top_ks=torch.tensor([1], dtype=torch.long),
+        local_mask=local_mask_cache.index_select(2, input_pos),
+    )
+    return logits
+
+  def generate(
+      self, tokens: torch.Tensor, max_new_tokens: int
+  ) -> torch.IntTensor:
+    """Generates the response after decoding the tokens into a string."""
+    prompts = self.model.tokenizer.decode(tokens[0].tolist())
+    response = self.model.generate(
+        prompts, device="cpu", output_len=max_new_tokens, top_k=1
+    )
+    return torch.tensor([self.model.tokenizer.encode(prompts + response)])
+
+
+class UnifiedGemma3Wrapper(verifier.ReauthoredModelWrapper):
+  """Unified Gemma3 model wrapper for verification."""
+
+  def _init_kv_cache(self):
+    """Returns an initialized KV cache."""
+    return kv_utils.KVCacheTransposed.from_model_config(self.model.model.config)
+
+  def forward(
+      self, tokens: torch.Tensor, pixel_values: torch.Tensor = None
+  ) -> torch.Tensor:
+    """Forwards the model."""
+    mask = attn_utils.build_causal_mask_cache(
+        self.model.model.config.kv_cache_max_len
+    )
+    input_pos = torch.arange(0, tokens.shape[1], dtype=torch.int)
+    mask = mask.index_select(2, input_pos)
+    output = self.model.model.forward(
+        tokens, input_pos, self._init_kv_cache(), mask=mask
+    )
+    return output["logits"]
+
+  def generate(
+      self,
+      prompts: torch.Tensor,
+      max_new_tokens: int,
+      pixel_values: torch.Tensor = None,
+      eos_token_id: Optional[int] = None,
+  ) -> torch.IntTensor:
+    """Generates the response."""
+    input_ids = prompts[0].int().tolist()
+    tokens = torch.tensor([input_ids])
+    input_pos = torch.arange(0, tokens.shape[1], dtype=torch.int)
+    kv_cache = self._init_kv_cache()
+    mask_cache = attn_utils.build_causal_mask_cache(
+        self.model.model.config.kv_cache_max_len
+    )
+    for _ in range(max_new_tokens):
+      mask = mask_cache.index_select(2, input_pos)
+      output = self.model.model.forward(
+          tokens, input_pos, kv_cache, mask=mask
+      )
+      logits, kv_cache = output["logits"], output["kv_cache"]
+      generated_token = logits[0][-1].argmax().item()
+      input_ids.append(generated_token)
+      if eos_token_id is not None and generated_token == eos_token_id:
+        break
+      tokens = torch.tensor([[generated_token]])
+      input_pos = torch.tensor([len(input_ids) - 1])
+    return torch.tensor([input_ids])
+
+
+class GemmaTokenizerWrapper(verifier.TokenizerWrapper):
+  """Tokenizer wrapper for verification.
+
+  Verifier expects the tokenizer to handle tokens in torch.Tensor while Gemma
+  tokenizer expects tokens in a list.
+  """
+
+  def encode(self, text: str, **_) -> torch.Tensor:
+    """Adds one more dimension to the output of the tokenizer."""
+    return torch.tensor([self.tokenizer.encode(text)])
+
+  def decode(self, tokens: torch.Tensor) -> str:
+    """Decodes the token sequence after converting to a list."""
+    return self.tokenizer.decode(tokens.tolist())
+
+
+def verify_reauthored_gemma_model(
+    checkpoint: str,
+    variant: str,
+    reauthored_model: torch.nn.Module,
+    generate_prompts: List[str],
+    forward_input_ids: List[List[int]],
+    weight_filename: str,
+    tokenizer_filename: str = "tokenizer.model",
+    max_new_tokens: int = 20,
+    rtol: float = 1e-05,
+    atol: float = 1e-05,
+) -> bool:
+  """Verifies the reauthored Gemma model against the original model.
+
+  Args:
+    checkpoint: Path to the Gemma checkpoint.
+    variant: Gemma model variant.
+    reauthored_model: The reauthored model to verify.
+    generate_prompts: List of prompts for generation.
+    forward_input_ids: List of input ids for forward pass.
+    weight_filename: Name of the weight file.
+    tokenizer_filename: Name of the tokenizer file.
+    max_new_tokens: Maximum number of new tokens to generate.
+    rtol: Relative tolerance for comparison.
+    atol: Absolute tolerance for comparison.
+
+  Returns:
+    True if the verification passes, False otherwise.
+  """
+  config = gemma_config.get_model_config(variant)
+  config.tokenizer = os.path.join(checkpoint, tokenizer_filename)
+  # Use float32 to be compatible with the reauthored model.
+  config.dtype = torch.float32
+
+  logging.info("Loading the original model from: %s", checkpoint)
+  original_model = gemma_model.GemmaForCausalLM(config).eval()
+  original_model.load_weights(os.path.join(checkpoint, weight_filename))
+
+  return verifier.verify_reauthored_model(
+      original_model=GemmaWrapper(original_model),
+      reauthored_model=UnifiedGemma3Wrapper(reauthored_model),
+      tokenizer=GemmaTokenizerWrapper(original_model.tokenizer),
+      generate_prompts=generate_prompts,
+      max_new_tokens=max_new_tokens,
+      forward_input_ids=forward_input_ids,
+      rtol=rtol,
+      atol=atol,
+  )
+
+
+def verify_gemma3(
+    checkpoint: str,
+    prompts: List[str],
+    max_new_tokens: int,
+    variant: str,
+    weight_filename: str,
+) -> bool:
+  """Verifies the reauthored Gemma3 model.
+
+  Args:
+    checkpoint: Path to the Gemma checkpoint.
+    prompts: List of prompts for generation.
+    max_new_tokens: Maximum number of new tokens to generate.
+    variant: Gemma model variant.
+    weight_filename: Name of the weight file.
+
+  Returns:
+    True if the verification passes, False otherwise.
+  """
+  gemma3_model_path = os.path.join(checkpoint, weight_filename)
+  logging.info("Building the reauthored model from: %s", gemma3_model_path)
+
+  if variant == "1b":
+    reauthored_model = UnifiedGemma3Wrapper(
+        gemma3.build_model_1b(gemma3_model_path)
+    )
+  else:
+    raise ValueError(f"Unsupported Gemma3 variant: {variant}")
+
+  return verify_reauthored_gemma_model(
+      checkpoint=checkpoint,
+      variant=variant,
+      reauthored_model=reauthored_model,
+      generate_prompts=prompts,
+      forward_input_ids=[[2, 651, 9456, 576, 573, 3520, 3858, 603, 235248]],
+      max_new_tokens=max_new_tokens,
+      weight_filename=weight_filename,
+      atol=1e-04,
+  )
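The new `verify_util.py` wraps the upstream `gemma` PyTorch model (`GemmaWrapper`) and the reauthored Gemma 3 model (`UnifiedGemma3Wrapper`) and compares their logits and generated text through `verifier.verify_reauthored_model`. The companion `verify_gemma3.py` (+90 lines, its hunk is not reproduced here) presumably just drives `verify_util.verify_gemma3()` from the command line; the sketch below shows one plausible way to call it, with the flag names, defaults, and weight filename being assumptions rather than that script's actual contents.

```python
# Hypothetical driver sketch: verify_gemma3.py is not shown in this diff, but
# verify_util.verify_gemma3()'s signature above suggests an invocation like
# this. Flag names, defaults, and 'model.ckpt' are assumptions.
from absl import app
from absl import flags

from ai_edge_torch.generative.examples.gemma3 import verify_util

_CHECKPOINT = flags.DEFINE_string(
    'checkpoint', '/tmp/gemma3-1b', 'Path to the Gemma 3 checkpoint directory.'
)
_PROMPTS = flags.DEFINE_multi_string(
    'prompts', 'What is the meaning of life?', 'Prompts to run generation on.'
)
_MAX_NEW_TOKENS = flags.DEFINE_integer(
    'max_new_tokens', 30, 'Maximum number of new tokens to generate.'
)


def main(_):
  # Returns True when the reauthored model matches the original gemma
  # PyTorch implementation within the configured tolerances.
  ok = verify_util.verify_gemma3(
      checkpoint=_CHECKPOINT.value,
      prompts=_PROMPTS.value,
      max_new_tokens=_MAX_NEW_TOKENS.value,
      variant='1b',
      weight_filename='model.ckpt',
  )
  print('PASS' if ok else 'FAIL')


if __name__ == '__main__':
  app.run(main)
```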
--- a/ai_edge_torch/generative/examples/llama/convert_to_tflite.py
+++ b/ai_edge_torch/generative/examples/llama/convert_to_tflite.py
@@ -16,55 +16,21 @@
 """Example of converting Llama 3.2 1B model to multi-signature tflite model."""
 
 import os
-import pathlib
-
 from absl import app
 from absl import flags
 from ai_edge_torch.generative.examples.llama import llama
 from ai_edge_torch.generative.utilities import converter
 from ai_edge_torch.generative.utilities.model_builder import ExportConfig
 
+
+flags = converter.define_conversion_flags('llama')
+
 _MODEL_SIZE = flags.DEFINE_enum(
     'model_size',
     '1b',
     ['1b', '3b'],
     'The size of the model to verify.',
 )
-_CHECKPOINT_PATH = flags.DEFINE_string(
-    'checkpoint_path',
-    os.path.join(pathlib.Path.home(), 'Downloads/llm_data/llama'),
-    'The path to the model checkpoint, or directory holding the checkpoint.',
-)
-_OUTPUT_PATH = flags.DEFINE_string(
-    'output_path',
-    '/tmp/',
-    'The path to export the tflite model.',
-)
-_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
-    'output_name_prefix',
-    'llama',
-    'The prefix of the output tflite model name.',
-)
-_PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
-    'prefill_seq_lens',
-    (8, 64, 128, 256, 512, 1024),
-    'List of the maximum sizes of prefill input tensors.',
-)
-_KV_CACHE_MAX_LEN = flags.DEFINE_integer(
-    'kv_cache_max_len',
-    1280,
-    'The maximum size of KV cache buffer, including both prefill and decode.',
-)
-_QUANTIZE = flags.DEFINE_bool(
-    'quantize',
-    True,
-    'Whether the model should be quantized.',
-)
-_LORA_RANKS = flags.DEFINE_multi_integer(
-    'lora_ranks',
-    None,
-    'If set, the model will be converted with the provided list of LoRA ranks.',
-)
 
 _BUILDER = {
     '1b': llama.build_1b_model,
@@ -74,15 +40,15 @@ _BUILDER = {
 
 def main(_):
   pytorch_model = _BUILDER[_MODEL_SIZE.value](
-      _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
+      flags.FLAGS.checkpoint_path, kv_cache_max_len=flags.FLAGS.kv_cache_max_len
   )
   converter.convert_to_tflite(
       pytorch_model,
-      output_path=_OUTPUT_PATH.value,
-      output_name_prefix=_OUTPUT_NAME_PREFIX.value,
-      prefill_seq_len=_PREFILL_SEQ_LENS.value,
-      quantize=_QUANTIZE.value,
-      lora_ranks=_LORA_RANKS.value,
+      output_path=flags.FLAGS.output_path,
+      output_name_prefix=flags.FLAGS.output_name_prefix,
+      prefill_seq_len=flags.FLAGS.prefill_seq_lens,
+      quantize=flags.FLAGS.quantize,
+      lora_ranks=flags.FLAGS.lora_ranks,
       export_config=ExportConfig(),
   )
 
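Note in the hunk above that the Llama script keeps its own `_MODEL_SIZE = flags.DEFINE_enum(...)` next to the shared set: `define_conversion_flags()` evidently hands back the absl `flags` module, so a script can still register model-specific flags and read the shared ones through `flags.FLAGS.<name>`. A minimal sketch of that pattern, under the assumption (inferred from the call sites, not confirmed by the shipped code) that the return value is the `flags` module:

```python
# Minimal sketch of the call-site pattern; define_conversion_flags() behavior
# is assumed from how the updated scripts use its return value. 'my_model' and
# 'variant' are hypothetical names for illustration.
from absl import app

from ai_edge_torch.generative.utilities import converter

flags = converter.define_conversion_flags('my_model')  # shared conversion flags

_VARIANT = flags.DEFINE_enum(  # script-specific flag registered alongside them
    'variant', 'small', ['small', 'large'], 'Which variant to convert.'
)


def main(_):
  # Shared flags are read back as plain attributes on flags.FLAGS.
  print(flags.FLAGS.output_path, flags.FLAGS.quantize, _VARIANT.value)


if __name__ == '__main__':
  app.run(main)
```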
--- a/ai_edge_torch/generative/examples/openelm/convert_to_tflite.py
+++ b/ai_edge_torch/generative/examples/openelm/convert_to_tflite.py
@@ -16,62 +16,25 @@
 """Example of converting OpenELM model to multi-signature tflite model."""
 
 import os
-import pathlib
-
 from absl import app
 from absl import flags
 from ai_edge_torch.generative.examples.openelm import openelm
 from ai_edge_torch.generative.utilities import converter
 from ai_edge_torch.generative.utilities.model_builder import ExportConfig
 
-_CHECKPOINT_PATH = flags.DEFINE_string(
-    'checkpoint_path',
-    os.path.join(pathlib.Path.home(), 'Downloads/llm_data/openelm'),
-    'The path to the model checkpoint, or directory holding the checkpoint.',
-)
-_OUTPUT_PATH = flags.DEFINE_string(
-    'output_path',
-    '/tmp/',
-    'The path to export the tflite model.',
-)
-_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
-    'output_name_prefix',
-    'openelm',
-    'The prefix of the output tflite model name.',
-)
-_PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
-    'prefill_seq_lens',
-    (8, 64, 128, 256, 512, 1024),
-    'List of the maximum sizes of prefill input tensors.',
-)
-_KV_CACHE_MAX_LEN = flags.DEFINE_integer(
-    'kv_cache_max_len',
-    1280,
-    'The maximum size of KV cache buffer, including both prefill and decode.',
-)
-_QUANTIZE = flags.DEFINE_bool(
-    'quantize',
-    True,
-    'Whether the model should be quantized.',
-)
-_LORA_RANKS = flags.DEFINE_multi_integer(
-    'lora_ranks',
-    None,
-    'If set, the model will be converted with the provided list of LoRA ranks.',
-)
-
+flags = converter.define_conversion_flags("openelm")
 
 def main(_):
   pytorch_model = openelm.build_model(
-      _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
+      flags.FLAGS.checkpoint_path, kv_cache_max_len=flags.FLAGS.kv_cache_max_len
   )
   converter.convert_to_tflite(
       pytorch_model,
-      output_path=_OUTPUT_PATH.value,
-      output_name_prefix=_OUTPUT_NAME_PREFIX.value,
-      prefill_seq_len=_PREFILL_SEQ_LENS.value,
-      quantize=_QUANTIZE.value,
-      lora_ranks=_LORA_RANKS.value,
+      output_path=flags.FLAGS.output_path,
+      output_name_prefix=flags.FLAGS.output_name_prefix,
+      prefill_seq_len=flags.FLAGS.prefill_seq_lens,
+      quantize=flags.FLAGS.quantize,
+      lora_ranks=flags.FLAGS.lora_ranks,
       export_config=ExportConfig(),
   )
 
--- a/ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py
+++ b/ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py
@@ -16,8 +16,6 @@
 """Example of converting a PaliGemma model to multi-signature tflite model."""
 
 import os
-import pathlib
-
 from absl import app
 from absl import flags
 from ai_edge_torch.generative.examples.paligemma import paligemma
@@ -25,61 +23,32 @@ from ai_edge_torch.generative.utilities import converter
 from ai_edge_torch.generative.utilities.model_builder import ExportConfig
 import torch
 
+flags = converter.define_conversion_flags('paligemma2-3b-224')
+
 _VERSION = flags.DEFINE_enum(
     'version',
     '2',
     ['1', '2'],
     'The version of PaliGemma model to verify.',
 )
-_CHECKPOINT_PATH = flags.DEFINE_string(
-    'checkpoint_path',
-    os.path.join(pathlib.Path.home(), 'Downloads/llm_data/paligemma2-3b-224'),
-    'The path to the model checkpoint, or directory holding the checkpoint.',
-)
-_OUTPUT_PATH = flags.DEFINE_string(
-    'output_path',
-    '/tmp/',
-    'The path to export the tflite model.',
-)
-_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
-    'output_name_prefix',
-    'paligemma',
-    'The prefix of the output tflite model name.',
-)
-_PREFILL_SEQ_LEN = flags.DEFINE_integer(
-    'prefill_seq_len',
-    1024,
-    'The maximum size of prefill input tensor.',
-)
-_KV_CACHE_MAX_LEN = flags.DEFINE_integer(
-    'kv_cache_max_len',
-    1280,
-    'The maximum size of KV cache buffer, including both prefill and decode.',
-)
-_QUANTIZE = flags.DEFINE_bool(
-    'quantize',
-    True,
-    'Whether the model should be quantized.',
-)
-
 
 def main(_):
   pytorch_model = paligemma.build_model(
-      _CHECKPOINT_PATH.value,
+      flags.FLAGS.checkpoint_path,
       version=int(_VERSION.value),
-      kv_cache_max_len=_KV_CACHE_MAX_LEN.value,
+      kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
   )
 
   config = pytorch_model.image_encoder.config.image_embedding
   converter.convert_to_tflite(
       pytorch_model,
-      output_path=_OUTPUT_PATH.value,
-      output_name_prefix=f'{_OUTPUT_NAME_PREFIX.value}_{_VERSION.value}',
-      prefill_seq_len=_PREFILL_SEQ_LEN.value,
+      output_path=flags.FLAGS.output_path,
+      output_name_prefix=f'{flags.FLAGS.output_name_prefix}_{_VERSION.value}',
+      prefill_seq_len=flags.FLAGS.prefill_seq_lens,
       pixel_values_size=torch.Size(
           [1, config.channels, config.image_size, config.image_size]
       ),
-      quantize=_QUANTIZE.value,
+      quantize=flags.FLAGS.quantize,
       config=pytorch_model.config.decoder_config,
       export_config=ExportConfig(),
   )
--- a/ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py
+++ b/ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py
@@ -16,62 +16,25 @@
 """Example of converting a Phi-3.5 model to multi-signature tflite model."""
 
 import os
-import pathlib
-
 from absl import app
 from absl import flags
 from ai_edge_torch.generative.examples.phi import phi3
 from ai_edge_torch.generative.utilities import converter
 from ai_edge_torch.generative.utilities.model_builder import ExportConfig
 
-_CHECKPOINT_PATH = flags.DEFINE_string(
-    'checkpoint_path',
-    os.path.join(pathlib.Path.home(), 'Downloads/llm_data/phi3'),
-    'The path to the model checkpoint, or directory holding the checkpoint.',
-)
-_OUTPUT_PATH = flags.DEFINE_string(
-    'output_path',
-    '/tmp/',
-    'The path to export the tflite model.',
-)
-_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
-    'output_name_prefix',
-    'phi3',
-    'The prefix of the output tflite model name.',
-)
-_PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
-    'prefill_seq_lens',
-    (8, 64, 128, 256, 512, 1024),
-    'List of the maximum sizes of prefill input tensors.',
-)
-_KV_CACHE_MAX_LEN = flags.DEFINE_integer(
-    'kv_cache_max_len',
-    1280,
-    'The maximum size of KV cache buffer, including both prefill and decode.',
-)
-_QUANTIZE = flags.DEFINE_bool(
-    'quantize',
-    True,
-    'Whether the model should be quantized.',
-)
-_LORA_RANKS = flags.DEFINE_multi_integer(
-    'lora_ranks',
-    None,
-    'If set, the model will be converted with the provided list of LoRA ranks.',
-)
-
+flags = converter.define_conversion_flags("phi3")
 
 def main(_):
   pytorch_model = phi3.build_model(
-      _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
+      flags.FLAGS.checkpoint_path, kv_cache_max_len=flags.FLAGS.kv_cache_max_len
   )
   converter.convert_to_tflite(
       pytorch_model,
-      output_path=_OUTPUT_PATH.value,
-      output_name_prefix=_OUTPUT_NAME_PREFIX.value,
-      prefill_seq_len=_PREFILL_SEQ_LENS.value,
-      quantize=_QUANTIZE.value,
-      lora_ranks=_LORA_RANKS.value,
+      output_path=flags.FLAGS.output_path,
+      output_name_prefix=flags.FLAGS.output_name_prefix,
+      prefill_seq_len=flags.FLAGS.prefill_seq_lens,
+      quantize=flags.FLAGS.quantize,
+      lora_ranks=flags.FLAGS.lora_ranks,
       export_config=ExportConfig(),
   )
 
--- a/ai_edge_torch/generative/examples/phi/convert_phi4_to_tflite.py
+++ b/ai_edge_torch/generative/examples/phi/convert_phi4_to_tflite.py
@@ -16,62 +16,25 @@
 """Example of converting a Phi-4 model to multi-signature tflite model."""
 
 import os
-import pathlib
-
 from absl import app
 from absl import flags
 from ai_edge_torch.generative.examples.phi import phi4
 from ai_edge_torch.generative.utilities import converter
 from ai_edge_torch.generative.utilities.model_builder import ExportConfig
 
-_CHECKPOINT_PATH = flags.DEFINE_string(
-    'checkpoint_path',
-    os.path.join(pathlib.Path.home(), 'Downloads/llm_data/phi4'),
-    'The path to the model checkpoint, or directory holding the checkpoint.',
-)
-_OUTPUT_PATH = flags.DEFINE_string(
-    'output_path',
-    '/tmp/',
-    'The path to export the tflite model.',
-)
-_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
-    'output_name_prefix',
-    'phi4',
-    'The prefix of the output tflite model name.',
-)
-_PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
-    'prefill_seq_lens',
-    (8, 64, 128, 256, 512, 1024),
-    'List of the maximum sizes of prefill input tensors.',
-)
-_KV_CACHE_MAX_LEN = flags.DEFINE_integer(
-    'kv_cache_max_len',
-    1280,
-    'The maximum size of KV cache buffer, including both prefill and decode.',
-)
-_QUANTIZE = flags.DEFINE_bool(
-    'quantize',
-    True,
-    'Whether the model should be quantized.',
-)
-_LORA_RANKS = flags.DEFINE_multi_integer(
-    'lora_ranks',
-    None,
-    'If set, the model will be converted with the provided list of LoRA ranks.',
-)
-
+flags = converter.define_conversion_flags("phi4")
 
 def main(_):
   pytorch_model = phi4.build_model(
-      _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
+      flags.FLAGS.checkpoint_path, kv_cache_max_len=flags.FLAGS.kv_cache_max_len
   )
   converter.convert_to_tflite(
       pytorch_model,
-      output_path=_OUTPUT_PATH.value,
-      output_name_prefix=_OUTPUT_NAME_PREFIX.value,
-      prefill_seq_len=_PREFILL_SEQ_LENS.value,
-      quantize=_QUANTIZE.value,
-      lora_ranks=_LORA_RANKS.value,
+      output_path=flags.FLAGS.output_path,
+      output_name_prefix=flags.FLAGS.output_name_prefix,
+      prefill_seq_len=flags.FLAGS.prefill_seq_lens,
+      quantize=flags.FLAGS.quantize,
+      lora_ranks=flags.FLAGS.lora_ranks,
       export_config=ExportConfig(),
   )
 
--- a/ai_edge_torch/generative/examples/phi/convert_to_tflite.py
+++ b/ai_edge_torch/generative/examples/phi/convert_to_tflite.py
@@ -24,54 +24,19 @@ from ai_edge_torch.generative.examples.phi import phi2
 from ai_edge_torch.generative.utilities import converter
 from ai_edge_torch.generative.utilities.model_builder import ExportConfig
 
-_CHECKPOINT_PATH = flags.DEFINE_string(
-    'checkpoint_path',
-    os.path.join(pathlib.Path.home(), 'Downloads/llm_data/phi2'),
-    'The path to the model checkpoint, or directory holding the checkpoint.',
-)
-_OUTPUT_PATH = flags.DEFINE_string(
-    'output_path',
-    '/tmp/',
-    'The path to export the tflite model.',
-)
-_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
-    'output_name_prefix',
-    'phi2',
-    'The prefix of the output tflite model name.',
-)
-_PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
-    'prefill_seq_lens',
-    (8, 64, 128, 256, 512, 1024),
-    'List of the maximum sizes of prefill input tensors.',
-)
-_KV_CACHE_MAX_LEN = flags.DEFINE_integer(
-    'kv_cache_max_len',
-    1280,
-    'The maximum size of KV cache buffer, including both prefill and decode.',
-)
-_QUANTIZE = flags.DEFINE_bool(
-    'quantize',
-    True,
-    'Whether the model should be quantized.',
-)
-_LORA_RANKS = flags.DEFINE_multi_integer(
-    'lora_ranks',
-    None,
-    'If set, the model will be converted with the provided list of LoRA ranks.',
-)
-
+flags = converter.define_conversion_flags("phi2")
 
 def main(_):
   pytorch_model = phi2.build_model(
-      _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
+      flags.FLAGS.checkpoint_path, kv_cache_max_len=flags.FLAGS.kv_cache_max_len
   )
   converter.convert_to_tflite(
       pytorch_model,
-      output_path=_OUTPUT_PATH.value,
-      output_name_prefix=_OUTPUT_NAME_PREFIX.value,
-      prefill_seq_len=_PREFILL_SEQ_LENS.value,
-      quantize=_QUANTIZE.value,
-      lora_ranks=_LORA_RANKS.value,
+      output_path=flags.FLAGS.output_path,
+      output_name_prefix=flags.FLAGS.output_name_prefix,
+      prefill_seq_len=flags.FLAGS.prefill_seq_lens,
+      quantize=flags.FLAGS.quantize,
+      lora_ranks=flags.FLAGS.lora_ranks,
       export_config=ExportConfig(),
   )
 