ai-edge-torch-nightly 0.3.0.dev20250105__py3-none-any.whl → 0.3.0.dev20250108__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_torch/_config.py +26 -9
- ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_check.py +13 -0
- ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +36 -0
- ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py +16 -6
- ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +16 -6
- ai_edge_torch/generative/examples/gemma/gemma2.py +25 -43
- ai_edge_torch/generative/examples/llama/convert_to_tflite.py +16 -6
- ai_edge_torch/generative/examples/openelm/convert_to_tflite.py +16 -9
- ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py +11 -6
- ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py +17 -7
- ai_edge_torch/generative/examples/phi/convert_to_tflite.py +16 -6
- ai_edge_torch/generative/examples/qwen/convert_to_tflite.py +17 -9
- ai_edge_torch/generative/examples/smollm/convert_to_tflite.py +16 -7
- ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +3 -3
- ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +16 -8
- ai_edge_torch/generative/layers/attention.py +70 -12
- ai_edge_torch/generative/layers/lora.py +557 -0
- ai_edge_torch/generative/layers/normalization.py +2 -50
- ai_edge_torch/generative/layers/rotary_position_embedding.py +27 -34
- ai_edge_torch/generative/test/test_lora.py +147 -0
- ai_edge_torch/generative/utilities/converter.py +100 -47
- ai_edge_torch/generative/utilities/model_builder.py +21 -16
- ai_edge_torch/generative/utilities/verifier.py +4 -4
- ai_edge_torch/odml_torch/_torch_future.py +13 -0
- ai_edge_torch/odml_torch/export.py +6 -2
- ai_edge_torch/odml_torch/lowerings/decomp.py +4 -0
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20250105.dist-info → ai_edge_torch_nightly-0.3.0.dev20250108.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20250105.dist-info → ai_edge_torch_nightly-0.3.0.dev20250108.dist-info}/RECORD +32 -30
- {ai_edge_torch_nightly-0.3.0.dev20250105.dist-info → ai_edge_torch_nightly-0.3.0.dev20250108.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20250105.dist-info → ai_edge_torch_nightly-0.3.0.dev20250108.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20250105.dist-info → ai_edge_torch_nightly-0.3.0.dev20250108.dist-info}/top_level.txt +0 -0
ai_edge_torch/_config.py
CHANGED
@@ -22,6 +22,18 @@ import os
 __all__ = ["config"]
 
 
+def _get_bool_env_var(name: str, default: bool) -> bool:
+  var = os.environ.get(name, "false")
+  var = var.lower().strip()
+  if var in ("y", "yes", "t", "true", "on", "1"):
+    return True
+  elif var in ("n", "no", "f", "false", "off", "0"):
+    return False
+  else:
+    logging.warning("Invalid %s value is ignored: %s.", name, var)
+    return default
+
+
 class _Config:
   """ai-edge-torch global configs."""
 
@@ -33,20 +45,25 @@ class _Config:
     To use torch_xla as the lowering backend, set environment variable
     `USE_TORCH_XLA` to "true".
     """
-
-    var = var.lower().strip()
-    if var in ("y", "yes", "t", "true", "on", "1"):
-      return True
-    elif var in ("n", "no", "f", "false", "off", "0"):
-      return False
-    else:
-      logging.warning("Invalid USE_TORCH_XLA value is ignored: %s.", var)
-      return False
+    return _get_bool_env_var("USE_TORCH_XLA", default=False)
 
   @property
   def in_oss(self) -> bool:
     """True if the code is not running in google internal environment."""
     return True
 
+  @property
+  def enable_group_norm_composite(self) -> bool:
+    """True if lowering group norm in StableHLO composite.
+
+    Currently only supports NHWC group norm generated by
+    OptimizeLayoutTransposesPass.
+    """
+    return _get_bool_env_var("ENABLE_GROUP_NORM_COMPOSITE", default=False)
+
+  @enable_group_norm_composite.setter
+  def enable_group_norm_composite(self, value: bool):
+    os.environ["ENABLE_GROUP_NORM_COMPOSITE"] = "y" if value else "n"
+
 
 config = _Config()
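The new `_get_bool_env_var` helper centralizes the boolean parsing that `use_torch_xla` previously did inline, and `enable_group_norm_composite` adds an opt-in knob backed by the `ENABLE_GROUP_NORM_COMPOSITE` environment variable. A minimal sketch of toggling the flag, assuming only the public `ai_edge_torch.config` object shown above:

import os
import ai_edge_torch

# Either set the environment variable before conversion...
os.environ["ENABLE_GROUP_NORM_COMPOSITE"] = "true"

# ...or use the property setter added in this release, which writes
# "y"/"n" to the same environment variable under the hood.
ai_edge_torch.config.enable_group_norm_composite = True
assert ai_edge_torch.config.enable_group_norm_composite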
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_check.py
CHANGED
@@ -17,6 +17,7 @@
 import dataclasses
 import operator
 
+import ai_edge_torch
 from ai_edge_torch import lowertools
 from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import layout_rewrite
 from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import utils
@@ -155,6 +156,7 @@ def _qdq_layout_sensitive_inputs_getter(node: Node):
 @layout_sensitive_inputs_getters.register(
     aten._native_batch_norm_legit_no_training
 )
+@layout_sensitive_inputs_getters.register(aten.group_norm)
 @layout_sensitive_inputs_getters.register(aten.native_group_norm)
 def _first_arg_getter(node):
   return [node.args[0]]
@@ -188,6 +190,17 @@ def _aten_norm_checker(node):
   return NHWCable(can_be=len(val[0].shape) == 4, must_be=False)
 
 
+@nhwcable_node_checkers.register(aten.group_norm)
+def _aten_group_norm_checker(node):
+  val = node.meta.get("val")
+  if not hasattr(val, "shape"):
+    return NHWCable(can_be=False, must_be=False)
+
+  can_be = len(val.shape) == 4
+  must_be = can_be and ai_edge_torch.config.enable_group_norm_composite
+  return NHWCable(can_be=can_be, must_be=must_be)
+
+
 @nhwcable_node_checkers.register(aten.native_group_norm)
 def _aten_native_group_norm_checker(node):
   val = node.meta.get("val")
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py
CHANGED
@@ -16,6 +16,7 @@
 
 import operator
 
+import ai_edge_torch
 from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import layout_mark
 from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import op_func_registry
 from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import utils
@@ -23,6 +24,7 @@ import torch
 import torch.utils._pytree as pytree
 
 aten = torch.ops.aten
+StableHLOCompositeBuilder = ai_edge_torch.hlfb.StableHLOCompositeBuilder
 
 __all__ = ["rewrite_nhwc_node", "has_nhwc_rewriter"]
 
@@ -342,6 +344,39 @@ def _aten__native_batch_norm_legit_no_training(node):
   node.target = batch_norm
 
 
+@rewriters.register(aten.group_norm.default)
+def _aten_group_norm(node):
+  def group_norm(input, num_groups: int, weight=None, bias=None, eps=1e-5):
+    is_composite_supported = (
+        ai_edge_torch.config.enable_group_norm_composite
+        and weight is not None
+        and bias is not None
+    )
+
+    builder = None
+    if is_composite_supported:
+      builder = StableHLOCompositeBuilder(
+          name="odml.group_norm",
+          attr={
+              "num_groups": num_groups,
+              "epsilon": eps,
+              "reduction_axes": [3],
+              "channel_axis": 3,
+          },
+      )
+      input, weight, bias = builder.mark_inputs(input, weight, bias)
+
+    input = utils.tensor_to_nchw(input)
+    output = aten.group_norm.default(input, num_groups, weight, bias, eps=eps)
+    output = utils.tensor_to_nhwc(output)
+
+    if builder is not None:
+      output = builder.mark_outputs(output)
+    return output
+
+  node.target = group_norm
+
+
 @rewriters.register(aten.native_group_norm.default)
 def _aten_native_group_norm(node):
 
@@ -354,6 +389,7 @@ def _aten_native_group_norm(node):
       flattened_inner_size: int,
      num_groups: int,
      eps: float,
+      **kwargs,
   ):
     input_reshaped = torch.reshape(
         input,
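With the checker and rewriter above, a 4-D `aten.group_norm` that the layout-transpose pass moves to NHWC can be wrapped in an `odml.group_norm` StableHLO composite when the flag is enabled. A hedged sketch of exercising that path end to end; the module, shapes, and output path below are illustrative and not taken from the diff:

import torch
import ai_edge_torch


class TinyGroupNormModel(torch.nn.Module):
  """Hypothetical example model containing a single GroupNorm layer."""

  def __init__(self):
    super().__init__()
    # affine=True by default, so weight and bias exist, which the
    # composite path above requires.
    self.gn = torch.nn.GroupNorm(num_groups=4, num_channels=16)

  def forward(self, x):
    return self.gn(x)


# Opt in to the composite lowering added in this release.
ai_edge_torch.config.enable_group_norm_composite = True

sample_args = (torch.randn(1, 16, 32, 32),)  # NCHW input
edge_model = ai_edge_torch.convert(TinyGroupNormModel().eval(), sample_args)
edge_model.export("/tmp/group_norm.tflite")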
ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py
CHANGED
@@ -29,10 +29,15 @@ _CHECKPOINT_PATH = flags.DEFINE_string(
     os.path.join(pathlib.Path.home(), 'Downloads/llm_data/gemma-2b'),
     'The path to the model checkpoint, or directory holding the checkpoint.',
 )
-
-    '
+_OUTPUT_PATH = flags.DEFINE_string(
+    'output_path',
     '/tmp/',
-    'The
+    'The path to export the tflite model.',
+)
+_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
+    'output_name_prefix',
+    'gemma',
+    'The prefix of the output tflite model name.',
 )
 _PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
     'prefill_seq_lens',
@@ -49,19 +54,24 @@ _QUANTIZE = flags.DEFINE_bool(
     True,
     'Whether the model should be quantized.',
 )
+_LORA_RANKS = flags.DEFINE_multi_integer(
+    'lora_ranks',
+    None,
+    'If set, the model will be converted with the provided list of LoRA ranks.',
+)
 
 
 def main(_):
   pytorch_model = gemma1.build_2b_model(
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
-  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
-  output_filename = f'gemma_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
   converter.convert_to_tflite(
       pytorch_model,
-
+      output_path=_OUTPUT_PATH.value,
+      output_name_prefix=_OUTPUT_NAME_PREFIX.value,
       prefill_seq_len=_PREFILL_SEQ_LENS.value,
       quantize=_QUANTIZE.value,
+      lora_ranks=_LORA_RANKS.value,
       export_config=ExportConfig(),
   )
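The example converters no longer compose a `.tflite` filename themselves; they pass `output_path`, `output_name_prefix`, and the new `lora_ranks` list straight to the shared converter utility. A hedged sketch of calling that utility directly with the new keywords; the checkpoint path, sequence lengths, and LoRA rank below are illustrative, only the argument names come from the call shown above:

from ai_edge_torch.generative.examples.gemma import gemma1
from ai_edge_torch.generative.utilities import converter
from ai_edge_torch.generative.utilities.model_builder import ExportConfig

# Build the PyTorch model from a local checkpoint (path and cache size
# are illustrative).
pytorch_model = gemma1.build_2b_model(
    '/path/to/gemma-2b', kv_cache_max_len=1024
)

converter.convert_to_tflite(
    pytorch_model,
    output_path='/tmp/',
    output_name_prefix='gemma',
    prefill_seq_len=[256, 512],
    quantize=True,
    lora_ranks=[16],  # optional: also emit LoRA-enabled variants
    export_config=ExportConfig(),
)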
ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py
CHANGED
@@ -29,10 +29,15 @@ _CHECKPOINT_PATH = flags.DEFINE_string(
     os.path.join(pathlib.Path.home(), 'Downloads/llm_data/gemma2-2b'),
     'The path to the model checkpoint, or directory holding the checkpoint.',
 )
-
-    '
+_OUTPUT_PATH = flags.DEFINE_string(
+    'output_path',
     '/tmp/',
-    'The
+    'The path to export the tflite model.',
+)
+_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
+    'output_name_prefix',
+    'gemma2',
+    'The prefix of the output tflite model name.',
 )
 _PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
     'prefill_seq_lens',
@@ -49,19 +54,24 @@ _QUANTIZE = flags.DEFINE_bool(
     True,
     'Whether the model should be quantized.',
 )
+_LORA_RANKS = flags.DEFINE_multi_integer(
+    'lora_ranks',
+    None,
+    'If set, the model will be converted with the provided list of LoRA ranks.',
+)
 
 
 def main(_):
   pytorch_model = gemma2.build_2b_model(
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
-  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
-  output_filename = f'gemma2_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
   converter.convert_to_tflite(
       pytorch_model,
-
+      output_path=_OUTPUT_PATH.value,
+      output_name_prefix=_OUTPUT_NAME_PREFIX.value,
       prefill_seq_len=_PREFILL_SEQ_LENS.value,
       quantize=_QUANTIZE.value,
+      lora_ranks=_LORA_RANKS.value,
       export_config=ExportConfig(),
   )
ai_edge_torch/generative/examples/gemma/gemma2.py
CHANGED
@@ -15,14 +15,13 @@
 
 """Example of building a Gemma2 model."""
 
-from typing import
+from typing import Optional, Tuple
 
 from ai_edge_torch.generative.layers import attention
 from ai_edge_torch.generative.layers import builder
 from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
 import ai_edge_torch.generative.layers.model_config as cfg
-import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb
 from ai_edge_torch.generative.utilities import model_builder
 import ai_edge_torch.generative.utilities.loader as loading_utils
 import torch
@@ -104,12 +103,17 @@ class Gemma2(nn.Module):
         config.embedding_dim,
         config.final_norm_config,
     )
-    self.mask_cache = attn_utils.build_causal_mask_cache(
-        size=config.kv_cache_max,
-    )
     # Gemma2 has same hyper parameters for each layer except for attention
     # types. Use the first layer.
     attn_config = config.block_config(0).attn_config
+    self.rope_cache = attn_utils.build_rope_cache(
+        size=config.kv_cache_max,
+        dim=int(attn_config.rotary_percentage * attn_config.head_dim),
+        base=attn_config.rotary_base,
+    )
+    self.mask_cache = attn_utils.build_causal_mask_cache(
+        size=config.kv_cache_max,
+    )
     self.sliding_window_mask_cache = attn_utils.build_sliding_window_mask_cache(
         size=config.kv_cache_max,
         window_size=attn_config.sliding_window_size,
@@ -136,48 +140,29 @@ class Gemma2(nn.Module):
           f"Cannot forward sequence of length {seq_len}, max seq length is only"
          f" {self.config.max_seq_len}"
       )
-
-    # token embeddings of shape (b, t, n_embd)
-    input_embeds = self.tok_embedding(tokens)
-    # RoPE parameters are the same for all blocks. Use the first layer.
-    attn_config = self.config.block_config(0).attn_config
-    n_elem = int(attn_config.rotary_percentage * attn_config.head_dim)
-    rope = rotary_pos_emb.build_rope(
-        input_pos, n_elem, attn_config.head_dim, attn_config.rotary_base
-    )
-    mask = [self.get_attention_mask(
-        self.config.block_config(i).attn_config.attn_type, input_pos
-    ) for i in range(self.config.num_layers)]
-
-    return self._forward_with_embeds(
-        input_embeds, rope, mask, input_pos, kv_cache, export_config
-    )
-
-  def _forward_with_embeds(
-      self,
-      input_embeds: torch.Tensor,
-      rope: Tuple[torch.Tensor, torch.Tensor],
-      mask: List[torch.Tensor],
-      input_pos: torch.Tensor,
-      kv_cache: kv_utils.KVCache,
-      export_config: Optional[model_builder.ExportConfig] = None,
-  ) -> dict[torch.Tensor, kv_utils.KVCache]:
-    """Forwards the model with input embeddings."""
     assert len(self.transformer_blocks) == len(kv_cache.caches), (
         "The number of transformer blocks and the number of KV cache entries"
         " must be the same."
     )
 
-
-
-
-
+    cos, sin = self.rope_cache
+    cos = cos.index_select(0, input_pos)
+    sin = sin.index_select(0, input_pos)
+
+    # token embeddings of shape (b, t, n_embd)
+    x = self.tok_embedding(tokens)
+    x = x * (self.config.embedding_dim**0.5)
+
+    updated_kv_entires = []
     for i, block in enumerate(self.transformer_blocks):
+      mask = self.get_attention_mask(
+          block.config.attn_config.attn_type, input_pos
+      )
       kv_entry = kv_cache.caches[i] if kv_cache else None
-      x, kv_entry = block(x,
+      x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
       if kv_entry:
-
-    updated_kv_cache = kv_utils.KVCache(tuple(
+        updated_kv_entires.append(kv_entry)
+    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
 
     if export_config is not None:
       if (
@@ -243,13 +228,11 @@ def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
   )
 
   num_layers = 26
-  embedding_dim = 2304
   config = cfg.ModelConfig(
       vocab_size=256000,
       num_layers=num_layers,
       max_seq_len=8192,
-      embedding_dim=
-      embedding_scale=embedding_dim**0.5,
+      embedding_dim=2304,
       kv_cache_max_len=kv_cache_max_len,
       block_configs=[get_block_config(i) for i in range(num_layers)],
       final_norm_config=norm_config,
@@ -266,7 +249,6 @@ def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
   config.num_layers = 2
   config.max_seq_len = 2 * kv_cache_max_len
   config.embedding_dim = 128
-  config.embedding_scale = config.embedding_dim**0.5
   config.block_configs = config.block_configs[: config.num_layers]
   for block_config in config.block_configs:
     block_config.attn_config.num_heads = 4
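The Gemma2 refactor goes back to a RoPE table precomputed in `__init__` (`self.rope_cache`) and sliced per step with `index_select`, instead of calling `rotary_pos_emb.build_rope` inside every forward; the embedding scale likewise moves from `ModelConfig.embedding_scale` into the forward pass. A minimal sketch of the per-position lookup pattern; `precompute_rope_cache` below is an illustrative stand-in for `attn_utils.build_rope_cache`, not the library function:

import torch

def precompute_rope_cache(max_len: int, dim: int, base: float = 10000.0):
  # One row of cos/sin angles per position, so a forward pass only has
  # to gather rows instead of recomputing frequencies.
  inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
  angles = torch.arange(max_len).float()[:, None] * inv_freq[None, :]
  return torch.cos(angles), torch.sin(angles)

cos_cache, sin_cache = precompute_rope_cache(max_len=1024, dim=64)

# Per step, gather only the rows for the current positions, mirroring
# `cos.index_select(0, input_pos)` in the refactored forward().
input_pos = torch.tensor([5, 6, 7])
cos = cos_cache.index_select(0, input_pos)
sin = sin_cache.index_select(0, input_pos)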
ai_edge_torch/generative/examples/llama/convert_to_tflite.py
CHANGED
@@ -35,10 +35,15 @@ _CHECKPOINT_PATH = flags.DEFINE_string(
     os.path.join(pathlib.Path.home(), 'Downloads/llm_data/llama'),
     'The path to the model checkpoint, or directory holding the checkpoint.',
 )
-
-    '
+_OUTPUT_PATH = flags.DEFINE_string(
+    'output_path',
     '/tmp/',
-    'The
+    'The path to export the tflite model.',
+)
+_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
+    'output_name_prefix',
+    'llama',
+    'The prefix of the output tflite model name.',
 )
 _PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
     'prefill_seq_lens',
@@ -55,6 +60,11 @@ _QUANTIZE = flags.DEFINE_bool(
     True,
     'Whether the model should be quantized.',
 )
+_LORA_RANKS = flags.DEFINE_multi_integer(
+    'lora_ranks',
+    None,
+    'If set, the model will be converted with the provided list of LoRA ranks.',
+)
 
 _BUILDER = {
     '1b': llama.build_1b_model,
@@ -66,13 +76,13 @@ def main(_):
   pytorch_model = _BUILDER[_MODEL_SIZE.value](
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
-  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
-  output_filename = f'llama_{_MODEL_SIZE.value}_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
   converter.convert_to_tflite(
       pytorch_model,
-
+      output_path=_OUTPUT_PATH.value,
+      output_name_prefix=_OUTPUT_NAME_PREFIX.value,
       prefill_seq_len=_PREFILL_SEQ_LENS.value,
       quantize=_QUANTIZE.value,
+      lora_ranks=_LORA_RANKS.value,
       export_config=ExportConfig(),
   )
ai_edge_torch/generative/examples/openelm/convert_to_tflite.py
CHANGED
@@ -29,10 +29,15 @@ _CHECKPOINT_PATH = flags.DEFINE_string(
     os.path.join(pathlib.Path.home(), 'Downloads/llm_data/openelm'),
     'The path to the model checkpoint, or directory holding the checkpoint.',
 )
-
-    '
+_OUTPUT_PATH = flags.DEFINE_string(
+    'output_path',
     '/tmp/',
-    'The
+    'The path to export the tflite model.',
+)
+_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
+    'output_name_prefix',
+    'openelm',
+    'The prefix of the output tflite model name.',
 )
 _PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
     'prefill_seq_lens',
@@ -49,22 +54,24 @@ _QUANTIZE = flags.DEFINE_bool(
     True,
     'Whether the model should be quantized.',
 )
+_LORA_RANKS = flags.DEFINE_multi_integer(
+    'lora_ranks',
+    None,
+    'If set, the model will be converted with the provided list of LoRA ranks.',
+)
 
 
 def main(_):
   pytorch_model = openelm.build_model(
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
-  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
-  output_filename = (
-      f'openelm_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
-  )
-
   converter.convert_to_tflite(
       pytorch_model,
-
+      output_path=_OUTPUT_PATH.value,
+      output_name_prefix=_OUTPUT_NAME_PREFIX.value,
       prefill_seq_len=_PREFILL_SEQ_LENS.value,
       quantize=_QUANTIZE.value,
+      lora_ranks=_LORA_RANKS.value,
       export_config=ExportConfig(),
   )
ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py
CHANGED
@@ -40,10 +40,15 @@ _CHECKPOINT_PATH = flags.DEFINE_string(
     os.path.join(pathlib.Path.home(), 'Downloads/llm_data/paligemma2-3b-224'),
     'The path to the model checkpoint, or directory holding the checkpoint.',
 )
-
-    '
+_OUTPUT_PATH = flags.DEFINE_string(
+    'output_path',
     '/tmp/',
-    'The
+    'The path to export the tflite model.',
+)
+_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
+    'output_name_prefix',
+    'paligemma',
+    'The prefix of the output tflite model name.',
 )
 _PREFILL_SEQ_LEN = flags.DEFINE_integer(
     'prefill_seq_len',
@@ -73,11 +78,11 @@ def main(_):
       version=int(_VERSION.value),
       kv_cache_max_len=_KV_CACHE_MAX_LEN.value,
   )
-
-  output_filename = f'paligemma{_VERSION.value}_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
+
   converter.convert_to_tflite(
       pytorch_model,
-
+      output_path=_OUTPUT_PATH.value,
+      output_name_prefix=f'{_OUTPUT_NAME_PREFIX.value}_{_VERSION.value}',
       prefill_seq_len=_PREFILL_SEQ_LEN.value,
       pixel_values_size=torch.Size(_PIXEL_VALUES_SIZE.value),
       quantize=_QUANTIZE.value,
ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py
CHANGED
@@ -26,13 +26,18 @@ from ai_edge_torch.generative.utilities.model_builder import ExportConfig
 
 _CHECKPOINT_PATH = flags.DEFINE_string(
     'checkpoint_path',
-    os.path.join(pathlib.Path.home(), 'Downloads/llm_data/
+    os.path.join(pathlib.Path.home(), 'Downloads/llm_data/gemma-2b'),
     'The path to the model checkpoint, or directory holding the checkpoint.',
 )
-
-    '
+_OUTPUT_PATH = flags.DEFINE_string(
+    'output_path',
     '/tmp/',
-    'The
+    'The path to export the tflite model.',
+)
+_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
+    'output_name_prefix',
+    'phi3',
+    'The prefix of the output tflite model name.',
 )
 _PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
     'prefill_seq_lens',
@@ -49,19 +54,24 @@ _QUANTIZE = flags.DEFINE_bool(
     True,
     'Whether the model should be quantized.',
 )
+_LORA_RANKS = flags.DEFINE_multi_integer(
+    'lora_ranks',
+    None,
+    'If set, the model will be converted with the provided list of LoRA ranks.',
+)
 
 
 def main(_):
   pytorch_model = phi3.build_model(
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
-  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
-  output_filename = f'phi3_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
   converter.convert_to_tflite(
       pytorch_model,
-
+      output_path=_OUTPUT_PATH.value,
+      output_name_prefix=_OUTPUT_NAME_PREFIX.value,
       prefill_seq_len=_PREFILL_SEQ_LENS.value,
       quantize=_QUANTIZE.value,
+      lora_ranks=_LORA_RANKS.value,
       export_config=ExportConfig(),
   )
ai_edge_torch/generative/examples/phi/convert_to_tflite.py
CHANGED
@@ -29,10 +29,15 @@ _CHECKPOINT_PATH = flags.DEFINE_string(
     os.path.join(pathlib.Path.home(), 'Downloads/llm_data/phi2'),
     'The path to the model checkpoint, or directory holding the checkpoint.',
 )
-
-    '
+_OUTPUT_PATH = flags.DEFINE_string(
+    'output_path',
     '/tmp/',
-    'The
+    'The path to export the tflite model.',
+)
+_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
+    'output_name_prefix',
+    'phi2',
+    'The prefix of the output tflite model name.',
 )
 _PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
     'prefill_seq_lens',
@@ -49,19 +54,24 @@ _QUANTIZE = flags.DEFINE_bool(
     True,
     'Whether the model should be quantized.',
 )
+_LORA_RANKS = flags.DEFINE_multi_integer(
+    'lora_ranks',
+    None,
+    'If set, the model will be converted with the provided list of LoRA ranks.',
+)
 
 
 def main(_):
   pytorch_model = phi2.build_model(
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
-  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
-  output_filename = f'phi2_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
   converter.convert_to_tflite(
       pytorch_model,
-
+      output_path=_OUTPUT_PATH.value,
+      output_name_prefix=_OUTPUT_NAME_PREFIX.value,
       prefill_seq_len=_PREFILL_SEQ_LENS.value,
       quantize=_QUANTIZE.value,
+      lora_ranks=_LORA_RANKS.value,
       export_config=ExportConfig(),
   )
ai_edge_torch/generative/examples/qwen/convert_to_tflite.py
CHANGED
@@ -35,10 +35,15 @@ _CHECKPOINT_PATH = flags.DEFINE_string(
     os.path.join(pathlib.Path.home(), 'Downloads/llm_data/qwen'),
     'The path to the model checkpoint, or directory holding the checkpoint.',
 )
-
-    '
+_OUTPUT_PATH = flags.DEFINE_string(
+    'output_path',
     '/tmp/',
-    'The
+    'The path to export the tflite model.',
+)
+_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
+    'output_name_prefix',
+    'qwen',
+    'The prefix of the output tflite model name.',
 )
 _PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
     'prefill_seq_lens',
@@ -55,6 +60,12 @@ _QUANTIZE = flags.DEFINE_bool(
     True,
     'Whether the model should be quantized.',
 )
+_LORA_RANKS = flags.DEFINE_multi_integer(
+    'lora_ranks',
+    None,
+    'If set, the model will be converted with the provided list of LoRA ranks.',
+)
+
 
 _BUILDER = {
     '0.5b': qwen.build_0_5b_model,
@@ -67,16 +78,13 @@ def main(_):
   pytorch_model = _BUILDER[_MODEL_SIZE.value](
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
-  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
-  model_size = _MODEL_SIZE.value.replace('.', '_')
-  output_filename = (
-      f'qwen_{model_size}_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
-  )
   converter.convert_to_tflite(
       pytorch_model,
-
+      output_path=_OUTPUT_PATH.value,
+      output_name_prefix=_OUTPUT_NAME_PREFIX.value,
       prefill_seq_len=_PREFILL_SEQ_LENS.value,
       quantize=_QUANTIZE.value,
+      lora_ranks=_LORA_RANKS.value,
       export_config=ExportConfig(),
   )