ai-edge-torch-nightly 0.3.0.dev20250105__py3-none-any.whl → 0.3.0.dev20250108__py3-none-any.whl
- ai_edge_torch/_config.py +26 -9
- ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_check.py +13 -0
- ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +36 -0
- ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py +16 -6
- ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +16 -6
- ai_edge_torch/generative/examples/gemma/gemma2.py +25 -43
- ai_edge_torch/generative/examples/llama/convert_to_tflite.py +16 -6
- ai_edge_torch/generative/examples/openelm/convert_to_tflite.py +16 -9
- ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py +11 -6
- ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py +17 -7
- ai_edge_torch/generative/examples/phi/convert_to_tflite.py +16 -6
- ai_edge_torch/generative/examples/qwen/convert_to_tflite.py +17 -9
- ai_edge_torch/generative/examples/smollm/convert_to_tflite.py +16 -7
- ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +3 -3
- ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +16 -8
- ai_edge_torch/generative/layers/attention.py +70 -12
- ai_edge_torch/generative/layers/lora.py +557 -0
- ai_edge_torch/generative/layers/normalization.py +2 -50
- ai_edge_torch/generative/layers/rotary_position_embedding.py +27 -34
- ai_edge_torch/generative/test/test_lora.py +147 -0
- ai_edge_torch/generative/utilities/converter.py +100 -47
- ai_edge_torch/generative/utilities/model_builder.py +21 -16
- ai_edge_torch/generative/utilities/verifier.py +4 -4
- ai_edge_torch/odml_torch/_torch_future.py +13 -0
- ai_edge_torch/odml_torch/export.py +6 -2
- ai_edge_torch/odml_torch/lowerings/decomp.py +4 -0
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20250105.dist-info → ai_edge_torch_nightly-0.3.0.dev20250108.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20250105.dist-info → ai_edge_torch_nightly-0.3.0.dev20250108.dist-info}/RECORD +32 -30
- {ai_edge_torch_nightly-0.3.0.dev20250105.dist-info → ai_edge_torch_nightly-0.3.0.dev20250108.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20250105.dist-info → ai_edge_torch_nightly-0.3.0.dev20250108.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20250105.dist-info → ai_edge_torch_nightly-0.3.0.dev20250108.dist-info}/top_level.txt +0 -0
ai_edge_torch/_config.py
CHANGED
@@ -22,6 +22,18 @@ import os
 __all__ = ["config"]
 
 
+def _get_bool_env_var(name: str, default: bool) -> bool:
+  var = os.environ.get(name, "false")
+  var = var.lower().strip()
+  if var in ("y", "yes", "t", "true", "on", "1"):
+    return True
+  elif var in ("n", "no", "f", "false", "off", "0"):
+    return False
+  else:
+    logging.warning("Invalid %s value is ignored: %s.", name, var)
+    return default
+
+
 class _Config:
   """ai-edge-torch global configs."""
 
@@ -33,20 +45,25 @@ class _Config:
     To use torch_xla as the lowering backend, set environment variable
     `USE_TORCH_XLA` to "true".
     """
-    var = os.environ.get("USE_TORCH_XLA", "false")
-    var = var.lower().strip()
-    if var in ("y", "yes", "t", "true", "on", "1"):
-      return True
-    elif var in ("n", "no", "f", "false", "off", "0"):
-      return False
-    else:
-      logging.warning("Invalid USE_TORCH_XLA value is ignored: %s.", var)
-      return False
+    return _get_bool_env_var("USE_TORCH_XLA", default=False)
 
   @property
   def in_oss(self) -> bool:
     """True if the code is not running in google internal environment."""
     return True
 
+  @property
+  def enable_group_norm_composite(self) -> bool:
+    """True if lowering group norm in StableHLO composite.
+
+    Currently only supports NHWC group norm generated by
+    OptimizeLayoutTransposesPass.
+    """
+    return _get_bool_env_var("ENABLE_GROUP_NORM_COMPOSITE", default=False)
+
+  @enable_group_norm_composite.setter
+  def enable_group_norm_composite(self, value: bool):
+    os.environ["ENABLE_GROUP_NORM_COMPOSITE"] = "y" if value else "n"
+
 
 config = _Config()
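The new flag can be toggled either through the `ENABLE_GROUP_NORM_COMPOSITE` environment variable or through the setter added above; a minimal sketch of both paths, grounded only in the getter and setter shown in this diff:

import os
import ai_edge_torch

# Option 1: set the environment variable that _get_bool_env_var reads.
os.environ["ENABLE_GROUP_NORM_COMPOSITE"] = "true"

# Option 2: use the property setter, which writes "y"/"n" back into the env.
ai_edge_torch.config.enable_group_norm_composite = True

assert ai_edge_torch.config.enable_group_norm_composite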
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_check.py
CHANGED
@@ -17,6 +17,7 @@
 import dataclasses
 import operator
 
+import ai_edge_torch
 from ai_edge_torch import lowertools
 from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import layout_rewrite
 from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import utils
@@ -155,6 +156,7 @@ def _qdq_layout_sensitive_inputs_getter(node: Node):
 @layout_sensitive_inputs_getters.register(
     aten._native_batch_norm_legit_no_training
 )
+@layout_sensitive_inputs_getters.register(aten.group_norm)
 @layout_sensitive_inputs_getters.register(aten.native_group_norm)
 def _first_arg_getter(node):
   return [node.args[0]]
@@ -188,6 +190,17 @@ def _aten_norm_checker(node):
   return NHWCable(can_be=len(val[0].shape) == 4, must_be=False)
 
 
+@nhwcable_node_checkers.register(aten.group_norm)
+def _aten_group_norm_checker(node):
+  val = node.meta.get("val")
+  if not hasattr(val, "shape"):
+    return NHWCable(can_be=False, must_be=False)
+
+  can_be = len(val.shape) == 4
+  must_be = can_be and ai_edge_torch.config.enable_group_norm_composite
+  return NHWCable(can_be=can_be, must_be=must_be)
+
+
 @nhwcable_node_checkers.register(aten.native_group_norm)
 def _aten_native_group_norm_checker(node):
   val = node.meta.get("val")
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py
CHANGED
@@ -16,6 +16,7 @@
 
 import operator
 
+import ai_edge_torch
 from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import layout_mark
 from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import op_func_registry
 from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import utils
@@ -23,6 +24,7 @@ import torch
 import torch.utils._pytree as pytree
 
 aten = torch.ops.aten
+StableHLOCompositeBuilder = ai_edge_torch.hlfb.StableHLOCompositeBuilder
 
 __all__ = ["rewrite_nhwc_node", "has_nhwc_rewriter"]
 
@@ -342,6 +344,39 @@ def _aten__native_batch_norm_legit_no_training(node):
   node.target = batch_norm
 
 
+@rewriters.register(aten.group_norm.default)
+def _aten_group_norm(node):
+  def group_norm(input, num_groups: int, weight=None, bias=None, eps=1e-5):
+    is_composite_supported = (
+        ai_edge_torch.config.enable_group_norm_composite
+        and weight is not None
+        and bias is not None
+    )
+
+    builder = None
+    if is_composite_supported:
+      builder = StableHLOCompositeBuilder(
+          name="odml.group_norm",
+          attr={
+              "num_groups": num_groups,
+              "epsilon": eps,
+              "reduction_axes": [3],
+              "channel_axis": 3,
+          },
+      )
+      input, weight, bias = builder.mark_inputs(input, weight, bias)
+
+    input = utils.tensor_to_nchw(input)
+    output = aten.group_norm.default(input, num_groups, weight, bias, eps=eps)
+    output = utils.tensor_to_nhwc(output)
+
+    if builder is not None:
+      output = builder.mark_outputs(output)
+    return output
+
+  node.target = group_norm
+
+
 @rewriters.register(aten.native_group_norm.default)
 def _aten_native_group_norm(node):
 
@@ -354,6 +389,7 @@ def _aten_native_group_norm(node):
       flattened_inner_size: int,
       num_groups: int,
       eps: float,
+      **kwargs,
   ):
     input_reshaped = torch.reshape(
         input,
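With the flag enabled and an affine group norm (weight and bias present), the rewriter above wraps the NHWC computation in an `odml.group_norm` StableHLO composite during the layout pass. A rough sketch of exercising that path; the toy module and input shape are illustrative assumptions, not part of the change:

import torch
import ai_edge_torch


class TinyGroupNorm(torch.nn.Module):

  def __init__(self):
    super().__init__()
    # affine=True by default, so weight and bias exist.
    self.gn = torch.nn.GroupNorm(num_groups=4, num_channels=16)

  def forward(self, x):
    return self.gn(x)


ai_edge_torch.config.enable_group_norm_composite = True
# 4-D NCHW input; OptimizeLayoutTransposesPass decides whether the norm runs
# in NHWC, in which case the rewriter above emits the composite.
edge_model = ai_edge_torch.convert(
    TinyGroupNorm().eval(), (torch.randn(1, 16, 8, 8),)
)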
ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py
CHANGED
@@ -29,10 +29,15 @@ _CHECKPOINT_PATH = flags.DEFINE_string(
     os.path.join(pathlib.Path.home(), 'Downloads/llm_data/gemma-2b'),
     'The path to the model checkpoint, or directory holding the checkpoint.',
 )
-
-    '
+_OUTPUT_PATH = flags.DEFINE_string(
+    'output_path',
     '/tmp/',
-    'The
+    'The path to export the tflite model.',
+)
+_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
+    'output_name_prefix',
+    'gemma',
+    'The prefix of the output tflite model name.',
 )
 _PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
     'prefill_seq_lens',
@@ -49,19 +54,24 @@ _QUANTIZE = flags.DEFINE_bool(
     True,
     'Whether the model should be quantized.',
 )
+_LORA_RANKS = flags.DEFINE_multi_integer(
+    'lora_ranks',
+    None,
+    'If set, the model will be converted with the provided list of LoRA ranks.',
+)
 
 
 def main(_):
   pytorch_model = gemma1.build_2b_model(
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
-  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
-  output_filename = f'gemma_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
   converter.convert_to_tflite(
       pytorch_model,
-
+      output_path=_OUTPUT_PATH.value,
+      output_name_prefix=_OUTPUT_NAME_PREFIX.value,
       prefill_seq_len=_PREFILL_SEQ_LENS.value,
       quantize=_QUANTIZE.value,
+      lora_ranks=_LORA_RANKS.value,
       export_config=ExportConfig(),
   )
 
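The output file name is now assembled inside `converter.convert_to_tflite` from `output_path`, `output_name_prefix`, and the other settings, instead of in each example script. The same conversion can be driven directly from Python with the new keyword arguments; the checkpoint path, sequence lengths, and LoRA rank below are placeholders, and the import paths assume the example package layout listed at the top of this diff:

from ai_edge_torch.generative.examples.gemma import gemma1
from ai_edge_torch.generative.utilities import converter
from ai_edge_torch.generative.utilities.model_builder import ExportConfig

pytorch_model = gemma1.build_2b_model('/path/to/gemma-2b', kv_cache_max_len=1280)
converter.convert_to_tflite(
    pytorch_model,
    output_path='/tmp/',
    output_name_prefix='gemma',
    prefill_seq_len=[256, 512],
    quantize=True,
    lora_ranks=[16],  # placeholder rank; None keeps LoRA off
    export_config=ExportConfig(),
)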
ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py
CHANGED
@@ -29,10 +29,15 @@ _CHECKPOINT_PATH = flags.DEFINE_string(
     os.path.join(pathlib.Path.home(), 'Downloads/llm_data/gemma2-2b'),
     'The path to the model checkpoint, or directory holding the checkpoint.',
 )
-
-    '
+_OUTPUT_PATH = flags.DEFINE_string(
+    'output_path',
     '/tmp/',
-    'The
+    'The path to export the tflite model.',
+)
+_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
+    'output_name_prefix',
+    'gemma2',
+    'The prefix of the output tflite model name.',
 )
 _PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
     'prefill_seq_lens',
@@ -49,19 +54,24 @@ _QUANTIZE = flags.DEFINE_bool(
     True,
     'Whether the model should be quantized.',
 )
+_LORA_RANKS = flags.DEFINE_multi_integer(
+    'lora_ranks',
+    None,
+    'If set, the model will be converted with the provided list of LoRA ranks.',
+)
 
 
 def main(_):
   pytorch_model = gemma2.build_2b_model(
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
-  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
-  output_filename = f'gemma2_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
   converter.convert_to_tflite(
       pytorch_model,
-
+      output_path=_OUTPUT_PATH.value,
+      output_name_prefix=_OUTPUT_NAME_PREFIX.value,
       prefill_seq_len=_PREFILL_SEQ_LENS.value,
       quantize=_QUANTIZE.value,
+      lora_ranks=_LORA_RANKS.value,
       export_config=ExportConfig(),
   )
 
ai_edge_torch/generative/examples/gemma/gemma2.py
CHANGED
@@ -15,14 +15,13 @@
 
 """Example of building a Gemma2 model."""
 
-from typing import
+from typing import Optional, Tuple
 
 from ai_edge_torch.generative.layers import attention
 from ai_edge_torch.generative.layers import builder
 from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
 import ai_edge_torch.generative.layers.model_config as cfg
-import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb
 from ai_edge_torch.generative.utilities import model_builder
 import ai_edge_torch.generative.utilities.loader as loading_utils
 import torch
@@ -104,12 +103,17 @@ class Gemma2(nn.Module):
         config.embedding_dim,
         config.final_norm_config,
     )
-    self.mask_cache = attn_utils.build_causal_mask_cache(
-        size=config.kv_cache_max,
-    )
     # Gemma2 has same hyper parameters for each layer except for attention
     # types. Use the first layer.
     attn_config = config.block_config(0).attn_config
+    self.rope_cache = attn_utils.build_rope_cache(
+        size=config.kv_cache_max,
+        dim=int(attn_config.rotary_percentage * attn_config.head_dim),
+        base=attn_config.rotary_base,
+    )
+    self.mask_cache = attn_utils.build_causal_mask_cache(
+        size=config.kv_cache_max,
+    )
     self.sliding_window_mask_cache = attn_utils.build_sliding_window_mask_cache(
         size=config.kv_cache_max,
         window_size=attn_config.sliding_window_size,
@@ -136,48 +140,29 @@ class Gemma2(nn.Module):
         f"Cannot forward sequence of length {seq_len}, max seq length is only"
         f" {self.config.max_seq_len}"
     )
-
-    # token embeddings of shape (b, t, n_embd)
-    input_embeds = self.tok_embedding(tokens)
-    # RoPE parameters are the same for all blocks. Use the first layer.
-    attn_config = self.config.block_config(0).attn_config
-    n_elem = int(attn_config.rotary_percentage * attn_config.head_dim)
-    rope = rotary_pos_emb.build_rope(
-        input_pos, n_elem, attn_config.head_dim, attn_config.rotary_base
-    )
-    mask = [self.get_attention_mask(
-        self.config.block_config(i).attn_config.attn_type, input_pos
-    ) for i in range(self.config.num_layers)]
-
-    return self._forward_with_embeds(
-        input_embeds, rope, mask, input_pos, kv_cache, export_config
-    )
-
-  def _forward_with_embeds(
-      self,
-      input_embeds: torch.Tensor,
-      rope: Tuple[torch.Tensor, torch.Tensor],
-      mask: List[torch.Tensor],
-      input_pos: torch.Tensor,
-      kv_cache: kv_utils.KVCache,
-      export_config: Optional[model_builder.ExportConfig] = None,
-  ) -> dict[torch.Tensor, kv_utils.KVCache]:
-    """Forwards the model with input embeddings."""
     assert len(self.transformer_blocks) == len(kv_cache.caches), (
         "The number of transformer blocks and the number of KV cache entries"
         " must be the same."
     )
 
-
-
-
-
+    cos, sin = self.rope_cache
+    cos = cos.index_select(0, input_pos)
+    sin = sin.index_select(0, input_pos)
+
+    # token embeddings of shape (b, t, n_embd)
+    x = self.tok_embedding(tokens)
+    x = x * (self.config.embedding_dim**0.5)
+
+    updated_kv_entires = []
     for i, block in enumerate(self.transformer_blocks):
+      mask = self.get_attention_mask(
+          block.config.attn_config.attn_type, input_pos
+      )
       kv_entry = kv_cache.caches[i] if kv_cache else None
-      x, kv_entry = block(x,
+      x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
       if kv_entry:
-
-    updated_kv_cache = kv_utils.KVCache(tuple(
+        updated_kv_entires.append(kv_entry)
+    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
 
     if export_config is not None:
       if (
@@ -243,13 +228,11 @@ def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
   )
 
   num_layers = 26
-  embedding_dim = 2304
   config = cfg.ModelConfig(
       vocab_size=256000,
       num_layers=num_layers,
       max_seq_len=8192,
-      embedding_dim=
-      embedding_scale=embedding_dim**0.5,
+      embedding_dim=2304,
       kv_cache_max_len=kv_cache_max_len,
       block_configs=[get_block_config(i) for i in range(num_layers)],
      final_norm_config=norm_config,
@@ -266,7 +249,6 @@ def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
   config.num_layers = 2
   config.max_seq_len = 2 * kv_cache_max_len
   config.embedding_dim = 128
-  config.embedding_scale = config.embedding_dim**0.5
   config.block_configs = config.block_configs[: config.num_layers]
   for block_config in config.block_configs:
     block_config.attn_config.num_heads = 4
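The forward pass now slices a RoPE table built once at construction time instead of recomputing cos/sin on every call. A standalone sketch of that lookup pattern; the sizes are illustrative and the cache construction below only mimics what `attn_utils.build_rope_cache` provides:

import torch

max_len, dim, base = 1024, 64, 10_000
# Precomputed once, analogous to self.rope_cache in __init__.
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
freqs = torch.outer(torch.arange(max_len).float(), inv_freq)
cos_cache, sin_cache = freqs.cos(), freqs.sin()

# Per forward call: gather only the rows for the current positions.
input_pos = torch.tensor([5, 6, 7])
cos = cos_cache.index_select(0, input_pos)
sin = sin_cache.index_select(0, input_pos)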
ai_edge_torch/generative/examples/llama/convert_to_tflite.py
CHANGED
@@ -35,10 +35,15 @@ _CHECKPOINT_PATH = flags.DEFINE_string(
     os.path.join(pathlib.Path.home(), 'Downloads/llm_data/llama'),
     'The path to the model checkpoint, or directory holding the checkpoint.',
 )
-
-    '
+_OUTPUT_PATH = flags.DEFINE_string(
+    'output_path',
     '/tmp/',
-    'The
+    'The path to export the tflite model.',
+)
+_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
+    'output_name_prefix',
+    'llama',
+    'The prefix of the output tflite model name.',
 )
 _PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
     'prefill_seq_lens',
@@ -55,6 +60,11 @@ _QUANTIZE = flags.DEFINE_bool(
     True,
     'Whether the model should be quantized.',
 )
+_LORA_RANKS = flags.DEFINE_multi_integer(
+    'lora_ranks',
+    None,
+    'If set, the model will be converted with the provided list of LoRA ranks.',
+)
 
 _BUILDER = {
     '1b': llama.build_1b_model,
@@ -66,13 +76,13 @@ def main(_):
   pytorch_model = _BUILDER[_MODEL_SIZE.value](
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
-  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
-  output_filename = f'llama_{_MODEL_SIZE.value}_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
   converter.convert_to_tflite(
       pytorch_model,
-
+      output_path=_OUTPUT_PATH.value,
+      output_name_prefix=_OUTPUT_NAME_PREFIX.value,
       prefill_seq_len=_PREFILL_SEQ_LENS.value,
       quantize=_QUANTIZE.value,
+      lora_ranks=_LORA_RANKS.value,
       export_config=ExportConfig(),
   )
 
ai_edge_torch/generative/examples/openelm/convert_to_tflite.py
CHANGED
@@ -29,10 +29,15 @@ _CHECKPOINT_PATH = flags.DEFINE_string(
     os.path.join(pathlib.Path.home(), 'Downloads/llm_data/openelm'),
     'The path to the model checkpoint, or directory holding the checkpoint.',
 )
-
-    '
+_OUTPUT_PATH = flags.DEFINE_string(
+    'output_path',
     '/tmp/',
-    'The
+    'The path to export the tflite model.',
+)
+_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
+    'output_name_prefix',
+    'openelm',
+    'The prefix of the output tflite model name.',
 )
 _PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
     'prefill_seq_lens',
@@ -49,22 +54,24 @@ _QUANTIZE = flags.DEFINE_bool(
     True,
     'Whether the model should be quantized.',
 )
+_LORA_RANKS = flags.DEFINE_multi_integer(
+    'lora_ranks',
+    None,
+    'If set, the model will be converted with the provided list of LoRA ranks.',
+)
 
 
 def main(_):
   pytorch_model = openelm.build_model(
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
-  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
-  output_filename = (
-      f'openelm_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
-  )
-
   converter.convert_to_tflite(
       pytorch_model,
-
+      output_path=_OUTPUT_PATH.value,
+      output_name_prefix=_OUTPUT_NAME_PREFIX.value,
       prefill_seq_len=_PREFILL_SEQ_LENS.value,
       quantize=_QUANTIZE.value,
+      lora_ranks=_LORA_RANKS.value,
       export_config=ExportConfig(),
   )
 
ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py
CHANGED
@@ -40,10 +40,15 @@ _CHECKPOINT_PATH = flags.DEFINE_string(
     os.path.join(pathlib.Path.home(), 'Downloads/llm_data/paligemma2-3b-224'),
     'The path to the model checkpoint, or directory holding the checkpoint.',
 )
-
-    '
+_OUTPUT_PATH = flags.DEFINE_string(
+    'output_path',
     '/tmp/',
-    'The
+    'The path to export the tflite model.',
+)
+_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
+    'output_name_prefix',
+    'paligemma',
+    'The prefix of the output tflite model name.',
 )
 _PREFILL_SEQ_LEN = flags.DEFINE_integer(
     'prefill_seq_len',
@@ -73,11 +78,11 @@ def main(_):
       version=int(_VERSION.value),
       kv_cache_max_len=_KV_CACHE_MAX_LEN.value,
   )
-
-  output_filename = f'paligemma{_VERSION.value}_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
+
   converter.convert_to_tflite(
       pytorch_model,
-
+      output_path=_OUTPUT_PATH.value,
+      output_name_prefix=f'{_OUTPUT_NAME_PREFIX.value}_{_VERSION.value}',
       prefill_seq_len=_PREFILL_SEQ_LEN.value,
       pixel_values_size=torch.Size(_PIXEL_VALUES_SIZE.value),
       quantize=_QUANTIZE.value,
ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py
CHANGED
@@ -26,13 +26,18 @@ from ai_edge_torch.generative.utilities.model_builder import ExportConfig
 
 _CHECKPOINT_PATH = flags.DEFINE_string(
     'checkpoint_path',
-    os.path.join(pathlib.Path.home(), 'Downloads/llm_data/
+    os.path.join(pathlib.Path.home(), 'Downloads/llm_data/gemma-2b'),
     'The path to the model checkpoint, or directory holding the checkpoint.',
 )
-
-    '
+_OUTPUT_PATH = flags.DEFINE_string(
+    'output_path',
     '/tmp/',
-    'The
+    'The path to export the tflite model.',
+)
+_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
+    'output_name_prefix',
+    'phi3',
+    'The prefix of the output tflite model name.',
 )
 _PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
     'prefill_seq_lens',
@@ -49,19 +54,24 @@ _QUANTIZE = flags.DEFINE_bool(
     True,
     'Whether the model should be quantized.',
 )
+_LORA_RANKS = flags.DEFINE_multi_integer(
+    'lora_ranks',
+    None,
+    'If set, the model will be converted with the provided list of LoRA ranks.',
+)
 
 
 def main(_):
   pytorch_model = phi3.build_model(
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
-  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
-  output_filename = f'phi3_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
   converter.convert_to_tflite(
       pytorch_model,
-
+      output_path=_OUTPUT_PATH.value,
+      output_name_prefix=_OUTPUT_NAME_PREFIX.value,
       prefill_seq_len=_PREFILL_SEQ_LENS.value,
       quantize=_QUANTIZE.value,
+      lora_ranks=_LORA_RANKS.value,
       export_config=ExportConfig(),
   )
 
ai_edge_torch/generative/examples/phi/convert_to_tflite.py
CHANGED
@@ -29,10 +29,15 @@ _CHECKPOINT_PATH = flags.DEFINE_string(
     os.path.join(pathlib.Path.home(), 'Downloads/llm_data/phi2'),
     'The path to the model checkpoint, or directory holding the checkpoint.',
 )
-
-    '
+_OUTPUT_PATH = flags.DEFINE_string(
+    'output_path',
     '/tmp/',
-    'The
+    'The path to export the tflite model.',
+)
+_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
+    'output_name_prefix',
+    'phi2',
+    'The prefix of the output tflite model name.',
 )
 _PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
     'prefill_seq_lens',
@@ -49,19 +54,24 @@ _QUANTIZE = flags.DEFINE_bool(
     True,
     'Whether the model should be quantized.',
 )
+_LORA_RANKS = flags.DEFINE_multi_integer(
+    'lora_ranks',
+    None,
+    'If set, the model will be converted with the provided list of LoRA ranks.',
+)
 
 
 def main(_):
   pytorch_model = phi2.build_model(
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
-  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
-  output_filename = f'phi2_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
   converter.convert_to_tflite(
       pytorch_model,
-
+      output_path=_OUTPUT_PATH.value,
+      output_name_prefix=_OUTPUT_NAME_PREFIX.value,
       prefill_seq_len=_PREFILL_SEQ_LENS.value,
       quantize=_QUANTIZE.value,
+      lora_ranks=_LORA_RANKS.value,
       export_config=ExportConfig(),
   )
 
ai_edge_torch/generative/examples/qwen/convert_to_tflite.py
CHANGED
@@ -35,10 +35,15 @@ _CHECKPOINT_PATH = flags.DEFINE_string(
     os.path.join(pathlib.Path.home(), 'Downloads/llm_data/qwen'),
     'The path to the model checkpoint, or directory holding the checkpoint.',
 )
-
-    '
+_OUTPUT_PATH = flags.DEFINE_string(
+    'output_path',
     '/tmp/',
-    'The
+    'The path to export the tflite model.',
+)
+_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
+    'output_name_prefix',
+    'qwen',
+    'The prefix of the output tflite model name.',
 )
 _PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
     'prefill_seq_lens',
@@ -55,6 +60,12 @@ _QUANTIZE = flags.DEFINE_bool(
     True,
     'Whether the model should be quantized.',
 )
+_LORA_RANKS = flags.DEFINE_multi_integer(
+    'lora_ranks',
+    None,
+    'If set, the model will be converted with the provided list of LoRA ranks.',
+)
+
 
 _BUILDER = {
     '0.5b': qwen.build_0_5b_model,
@@ -67,16 +78,13 @@ def main(_):
   pytorch_model = _BUILDER[_MODEL_SIZE.value](
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
-  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
-  model_size = _MODEL_SIZE.value.replace('.', '_')
-  output_filename = (
-      f'qwen_{model_size}_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
-  )
   converter.convert_to_tflite(
       pytorch_model,
-
+      output_path=_OUTPUT_PATH.value,
+      output_name_prefix=_OUTPUT_NAME_PREFIX.value,
       prefill_seq_len=_PREFILL_SEQ_LENS.value,
       quantize=_QUANTIZE.value,
+      lora_ranks=_LORA_RANKS.value,
       export_config=ExportConfig(),
   )
 