ai-edge-torch-nightly 0.3.0.dev20250105__py3-none-any.whl → 0.3.0.dev20250108__py3-none-any.whl

Files changed (32)
  1. ai_edge_torch/_config.py +26 -9
  2. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_check.py +13 -0
  3. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +36 -0
  4. ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py +16 -6
  5. ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +16 -6
  6. ai_edge_torch/generative/examples/gemma/gemma2.py +25 -43
  7. ai_edge_torch/generative/examples/llama/convert_to_tflite.py +16 -6
  8. ai_edge_torch/generative/examples/openelm/convert_to_tflite.py +16 -9
  9. ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py +11 -6
  10. ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py +17 -7
  11. ai_edge_torch/generative/examples/phi/convert_to_tflite.py +16 -6
  12. ai_edge_torch/generative/examples/qwen/convert_to_tflite.py +17 -9
  13. ai_edge_torch/generative/examples/smollm/convert_to_tflite.py +16 -7
  14. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +3 -3
  15. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +16 -8
  16. ai_edge_torch/generative/layers/attention.py +70 -12
  17. ai_edge_torch/generative/layers/lora.py +557 -0
  18. ai_edge_torch/generative/layers/normalization.py +2 -50
  19. ai_edge_torch/generative/layers/rotary_position_embedding.py +27 -34
  20. ai_edge_torch/generative/test/test_lora.py +147 -0
  21. ai_edge_torch/generative/utilities/converter.py +100 -47
  22. ai_edge_torch/generative/utilities/model_builder.py +21 -16
  23. ai_edge_torch/generative/utilities/verifier.py +4 -4
  24. ai_edge_torch/odml_torch/_torch_future.py +13 -0
  25. ai_edge_torch/odml_torch/export.py +6 -2
  26. ai_edge_torch/odml_torch/lowerings/decomp.py +4 -0
  27. ai_edge_torch/version.py +1 -1
  28. {ai_edge_torch_nightly-0.3.0.dev20250105.dist-info → ai_edge_torch_nightly-0.3.0.dev20250108.dist-info}/METADATA +1 -1
  29. {ai_edge_torch_nightly-0.3.0.dev20250105.dist-info → ai_edge_torch_nightly-0.3.0.dev20250108.dist-info}/RECORD +32 -30
  30. {ai_edge_torch_nightly-0.3.0.dev20250105.dist-info → ai_edge_torch_nightly-0.3.0.dev20250108.dist-info}/LICENSE +0 -0
  31. {ai_edge_torch_nightly-0.3.0.dev20250105.dist-info → ai_edge_torch_nightly-0.3.0.dev20250108.dist-info}/WHEEL +0 -0
  32. {ai_edge_torch_nightly-0.3.0.dev20250105.dist-info → ai_edge_torch_nightly-0.3.0.dev20250108.dist-info}/top_level.txt +0 -0
ai_edge_torch/_config.py CHANGED
@@ -22,6 +22,18 @@ import os
 __all__ = ["config"]
 
 
+def _get_bool_env_var(name: str, default: bool) -> bool:
+  var = os.environ.get(name, "false")
+  var = var.lower().strip()
+  if var in ("y", "yes", "t", "true", "on", "1"):
+    return True
+  elif var in ("n", "no", "f", "false", "off", "0"):
+    return False
+  else:
+    logging.warning("Invalid %s value is ignored: %s.", name, var)
+    return default
+
+
 class _Config:
   """ai-edge-torch global configs."""
 
@@ -33,20 +45,25 @@ class _Config:
     To use torch_xla as the lowering backend, set environment variable
     `USE_TORCH_XLA` to "true".
     """
-    var = os.environ.get("USE_TORCH_XLA", "false")
-    var = var.lower().strip()
-    if var in ("y", "yes", "t", "true", "on", "1"):
-      return True
-    elif var in ("n", "no", "f", "false", "off", "0"):
-      return False
-    else:
-      logging.warning("Invalid USE_TORCH_XLA value is ignored: %s.", var)
-      return False
+    return _get_bool_env_var("USE_TORCH_XLA", default=False)
 
   @property
   def in_oss(self) -> bool:
     """True if the code is not running in google internal environment."""
     return True
 
+  @property
+  def enable_group_norm_composite(self) -> bool:
+    """True if lowering group norm in StableHLO composite.
+
+    Currently only supports NHWC group norm generated by
+    OptimizeLayoutTransposesPass.
+    """
+    return _get_bool_env_var("ENABLE_GROUP_NORM_COMPOSITE", default=False)
+
+  @enable_group_norm_composite.setter
+  def enable_group_norm_composite(self, value: bool):
+    os.environ["ENABLE_GROUP_NORM_COMPOSITE"] = "y" if value else "n"
+
 
 config = _Config()
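
Note: `enable_group_norm_composite` is read from the `ENABLE_GROUP_NORM_COMPOSITE` environment variable through the new `_get_bool_env_var` helper, so it can be toggled from the shell or from Python. A minimal sketch of enabling it (not part of this diff; it only assumes the `ai_edge_torch.config` handle that layout_check.py below already uses):

  import ai_edge_torch

  # Equivalent to `export ENABLE_GROUP_NORM_COMPOSITE=y` before conversion;
  # the setter writes the same environment variable.
  ai_edge_torch.config.enable_group_norm_composite = True
  assert ai_edge_torch.config.enable_group_norm_composite
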
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_check.py CHANGED
@@ -17,6 +17,7 @@
 import dataclasses
 import operator
 
+import ai_edge_torch
 from ai_edge_torch import lowertools
 from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import layout_rewrite
 from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import utils
@@ -155,6 +156,7 @@ def _qdq_layout_sensitive_inputs_getter(node: Node):
 @layout_sensitive_inputs_getters.register(
     aten._native_batch_norm_legit_no_training
 )
+@layout_sensitive_inputs_getters.register(aten.group_norm)
 @layout_sensitive_inputs_getters.register(aten.native_group_norm)
 def _first_arg_getter(node):
   return [node.args[0]]
@@ -188,6 +190,17 @@ def _aten_norm_checker(node):
   return NHWCable(can_be=len(val[0].shape) == 4, must_be=False)
 
 
+@nhwcable_node_checkers.register(aten.group_norm)
+def _aten_group_norm_checker(node):
+  val = node.meta.get("val")
+  if not hasattr(val, "shape"):
+    return NHWCable(can_be=False, must_be=False)
+
+  can_be = len(val.shape) == 4
+  must_be = can_be and ai_edge_torch.config.enable_group_norm_composite
+  return NHWCable(can_be=can_be, must_be=must_be)
+
+
 @nhwcable_node_checkers.register(aten.native_group_norm)
 def _aten_native_group_norm_checker(node):
   val = node.meta.get("val")
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py CHANGED
@@ -16,6 +16,7 @@
 
 import operator
 
+import ai_edge_torch
 from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import layout_mark
 from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import op_func_registry
 from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import utils
@@ -23,6 +24,7 @@ import torch
 import torch.utils._pytree as pytree
 
 aten = torch.ops.aten
+StableHLOCompositeBuilder = ai_edge_torch.hlfb.StableHLOCompositeBuilder
 
 __all__ = ["rewrite_nhwc_node", "has_nhwc_rewriter"]
 
@@ -342,6 +344,39 @@ def _aten__native_batch_norm_legit_no_training(node):
   node.target = batch_norm
 
 
+@rewriters.register(aten.group_norm.default)
+def _aten_group_norm(node):
+  def group_norm(input, num_groups: int, weight=None, bias=None, eps=1e-5):
+    is_composite_supported = (
+        ai_edge_torch.config.enable_group_norm_composite
+        and weight is not None
+        and bias is not None
+    )
+
+    builder = None
+    if is_composite_supported:
+      builder = StableHLOCompositeBuilder(
+          name="odml.group_norm",
+          attr={
+              "num_groups": num_groups,
+              "epsilon": eps,
+              "reduction_axes": [3],
+              "channel_axis": 3,
+          },
+      )
+      input, weight, bias = builder.mark_inputs(input, weight, bias)
+
+    input = utils.tensor_to_nchw(input)
+    output = aten.group_norm.default(input, num_groups, weight, bias, eps=eps)
+    output = utils.tensor_to_nhwc(output)
+
+    if builder is not None:
+      output = builder.mark_outputs(output)
+    return output
+
+  node.target = group_norm
+
+
 @rewriters.register(aten.native_group_norm.default)
 def _aten_native_group_norm(node):
 
@@ -354,6 +389,7 @@ def _aten_native_group_norm(node):
       flattened_inner_size: int,
       num_groups: int,
       eps: float,
+      **kwargs,
   ):
     input_reshaped = torch.reshape(
         input,
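
With the flag on, the layout pass pins 4-D `aten.group_norm` nodes to NHWC and wraps them in an `odml.group_norm` StableHLO composite. A minimal end-to-end sketch, assuming typical `ai_edge_torch.convert` usage rather than code from this change (the model and output path are illustrative):

  import torch
  import ai_edge_torch

  ai_edge_torch.config.enable_group_norm_composite = True

  model = torch.nn.Sequential(
      torch.nn.Conv2d(3, 8, kernel_size=3),
      torch.nn.GroupNorm(num_groups=2, num_channels=8),
      torch.nn.ReLU(),
  ).eval()

  # The group norm has affine weight and bias, so the rewriter above can mark
  # it as an odml.group_norm composite during conversion.
  edge_model = ai_edge_torch.convert(model, (torch.randn(1, 3, 16, 16),))
  edge_model.export('/tmp/group_norm_composite.tflite')
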
ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py CHANGED
@@ -29,10 +29,15 @@ _CHECKPOINT_PATH = flags.DEFINE_string(
     os.path.join(pathlib.Path.home(), 'Downloads/llm_data/gemma-2b'),
     'The path to the model checkpoint, or directory holding the checkpoint.',
 )
-_TFLITE_PATH = flags.DEFINE_string(
-    'tflite_path',
+_OUTPUT_PATH = flags.DEFINE_string(
+    'output_path',
     '/tmp/',
-    'The tflite file path to export.',
+    'The path to export the tflite model.',
+)
+_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
+    'output_name_prefix',
+    'gemma',
+    'The prefix of the output tflite model name.',
 )
 _PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
     'prefill_seq_lens',
@@ -49,19 +54,24 @@ _QUANTIZE = flags.DEFINE_bool(
     True,
     'Whether the model should be quantized.',
 )
+_LORA_RANKS = flags.DEFINE_multi_integer(
+    'lora_ranks',
+    None,
+    'If set, the model will be converted with the provided list of LoRA ranks.',
+)
 
 
 def main(_):
   pytorch_model = gemma1.build_2b_model(
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
-  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
-  output_filename = f'gemma_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
   converter.convert_to_tflite(
       pytorch_model,
-      tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
+      output_path=_OUTPUT_PATH.value,
+      output_name_prefix=_OUTPUT_NAME_PREFIX.value,
       prefill_seq_len=_PREFILL_SEQ_LENS.value,
       quantize=_QUANTIZE.value,
+      lora_ranks=_LORA_RANKS.value,
       export_config=ExportConfig(),
   )
 
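
The example converters now take `--output_path` and `--output_name_prefix` instead of a single `--tflite_path`, plus an optional repeated `--lora_ranks`; the final file name is assembled inside `converter.convert_to_tflite` (see ai_edge_torch/generative/utilities/converter.py in this change). A hypothetical invocation under those assumptions:

  python convert_gemma1_to_tflite.py \
      --checkpoint_path="${HOME}/Downloads/llm_data/gemma-2b" \
      --output_path=/tmp/ \
      --output_name_prefix=gemma \
      --quantize=true \
      --lora_ranks=16 --lora_ranks=32
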
ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py CHANGED
@@ -29,10 +29,15 @@ _CHECKPOINT_PATH = flags.DEFINE_string(
     os.path.join(pathlib.Path.home(), 'Downloads/llm_data/gemma2-2b'),
     'The path to the model checkpoint, or directory holding the checkpoint.',
 )
-_TFLITE_PATH = flags.DEFINE_string(
-    'tflite_path',
+_OUTPUT_PATH = flags.DEFINE_string(
+    'output_path',
     '/tmp/',
-    'The tflite file path to export.',
+    'The path to export the tflite model.',
+)
+_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
+    'output_name_prefix',
+    'gemma2',
+    'The prefix of the output tflite model name.',
 )
 _PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
     'prefill_seq_lens',
@@ -49,19 +54,24 @@ _QUANTIZE = flags.DEFINE_bool(
     True,
     'Whether the model should be quantized.',
 )
+_LORA_RANKS = flags.DEFINE_multi_integer(
+    'lora_ranks',
+    None,
+    'If set, the model will be converted with the provided list of LoRA ranks.',
+)
 
 
 def main(_):
   pytorch_model = gemma2.build_2b_model(
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
-  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
-  output_filename = f'gemma2_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
   converter.convert_to_tflite(
       pytorch_model,
-      tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
+      output_path=_OUTPUT_PATH.value,
+      output_name_prefix=_OUTPUT_NAME_PREFIX.value,
       prefill_seq_len=_PREFILL_SEQ_LENS.value,
       quantize=_QUANTIZE.value,
+      lora_ranks=_LORA_RANKS.value,
       export_config=ExportConfig(),
   )
 
ai_edge_torch/generative/examples/gemma/gemma2.py CHANGED
@@ -15,14 +15,13 @@
 
 """Example of building a Gemma2 model."""
 
-from typing import List, Optional, Tuple
+from typing import Optional, Tuple
 
 from ai_edge_torch.generative.layers import attention
 from ai_edge_torch.generative.layers import builder
 from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
 import ai_edge_torch.generative.layers.model_config as cfg
-import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb
 from ai_edge_torch.generative.utilities import model_builder
 import ai_edge_torch.generative.utilities.loader as loading_utils
 import torch
@@ -104,12 +103,17 @@ class Gemma2(nn.Module):
         config.embedding_dim,
         config.final_norm_config,
     )
-    self.mask_cache = attn_utils.build_causal_mask_cache(
-        size=config.kv_cache_max,
-    )
     # Gemma2 has same hyper parameters for each layer except for attention
     # types. Use the first layer.
     attn_config = config.block_config(0).attn_config
+    self.rope_cache = attn_utils.build_rope_cache(
+        size=config.kv_cache_max,
+        dim=int(attn_config.rotary_percentage * attn_config.head_dim),
+        base=attn_config.rotary_base,
+    )
+    self.mask_cache = attn_utils.build_causal_mask_cache(
+        size=config.kv_cache_max,
+    )
     self.sliding_window_mask_cache = attn_utils.build_sliding_window_mask_cache(
         size=config.kv_cache_max,
         window_size=attn_config.sliding_window_size,
@@ -136,48 +140,29 @@ class Gemma2(nn.Module):
           f"Cannot forward sequence of length {seq_len}, max seq length is only"
          f" {self.config.max_seq_len}"
       )
-
-    # token embeddings of shape (b, t, n_embd)
-    input_embeds = self.tok_embedding(tokens)
-    # RoPE parameters are the same for all blocks. Use the first layer.
-    attn_config = self.config.block_config(0).attn_config
-    n_elem = int(attn_config.rotary_percentage * attn_config.head_dim)
-    rope = rotary_pos_emb.build_rope(
-        input_pos, n_elem, attn_config.head_dim, attn_config.rotary_base
-    )
-    mask = [self.get_attention_mask(
-        self.config.block_config(i).attn_config.attn_type, input_pos
-    ) for i in range(self.config.num_layers)]
-
-    return self._forward_with_embeds(
-        input_embeds, rope, mask, input_pos, kv_cache, export_config
-    )
-
-  def _forward_with_embeds(
-      self,
-      input_embeds: torch.Tensor,
-      rope: Tuple[torch.Tensor, torch.Tensor],
-      mask: List[torch.Tensor],
-      input_pos: torch.Tensor,
-      kv_cache: kv_utils.KVCache,
-      export_config: Optional[model_builder.ExportConfig] = None,
-  ) -> dict[torch.Tensor, kv_utils.KVCache]:
-    """Forwards the model with input embeddings."""
     assert len(self.transformer_blocks) == len(kv_cache.caches), (
         "The number of transformer blocks and the number of KV cache entries"
         " must be the same."
    )
 
-    if self.config.embedding_scale is not None:
-      input_embeds = input_embeds * self.config.embedding_scale
-    x = input_embeds
-    updated_kv_entries = []
+    cos, sin = self.rope_cache
+    cos = cos.index_select(0, input_pos)
+    sin = sin.index_select(0, input_pos)
+
+    # token embeddings of shape (b, t, n_embd)
+    x = self.tok_embedding(tokens)
+    x = x * (self.config.embedding_dim**0.5)
+
+    updated_kv_entires = []
     for i, block in enumerate(self.transformer_blocks):
+      mask = self.get_attention_mask(
+          block.config.attn_config.attn_type, input_pos
+      )
       kv_entry = kv_cache.caches[i] if kv_cache else None
-      x, kv_entry = block(x, rope, mask[i], input_pos, kv_entry)
+      x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
       if kv_entry:
-        updated_kv_entries.append(kv_entry)
-    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entries))
+        updated_kv_entires.append(kv_entry)
+    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
 
     if export_config is not None:
       if (
@@ -243,13 +228,11 @@ def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
   )
 
   num_layers = 26
-  embedding_dim = 2304
   config = cfg.ModelConfig(
       vocab_size=256000,
       num_layers=num_layers,
       max_seq_len=8192,
-      embedding_dim=embedding_dim,
-      embedding_scale=embedding_dim**0.5,
+      embedding_dim=2304,
       kv_cache_max_len=kv_cache_max_len,
       block_configs=[get_block_config(i) for i in range(num_layers)],
       final_norm_config=norm_config,
@@ -266,7 +249,6 @@ def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
   config.num_layers = 2
   config.max_seq_len = 2 * kv_cache_max_len
   config.embedding_dim = 128
-  config.embedding_scale = config.embedding_dim**0.5
   config.block_configs = config.block_configs[: config.num_layers]
   for block_config in config.block_configs:
     block_config.attn_config.num_heads = 4
ai_edge_torch/generative/examples/llama/convert_to_tflite.py CHANGED
@@ -35,10 +35,15 @@ _CHECKPOINT_PATH = flags.DEFINE_string(
     os.path.join(pathlib.Path.home(), 'Downloads/llm_data/llama'),
     'The path to the model checkpoint, or directory holding the checkpoint.',
 )
-_TFLITE_PATH = flags.DEFINE_string(
-    'tflite_path',
+_OUTPUT_PATH = flags.DEFINE_string(
+    'output_path',
     '/tmp/',
-    'The tflite file path to export.',
+    'The path to export the tflite model.',
+)
+_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
+    'output_name_prefix',
+    'llama',
+    'The prefix of the output tflite model name.',
 )
 _PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
     'prefill_seq_lens',
@@ -55,6 +60,11 @@ _QUANTIZE = flags.DEFINE_bool(
     True,
     'Whether the model should be quantized.',
 )
+_LORA_RANKS = flags.DEFINE_multi_integer(
+    'lora_ranks',
+    None,
+    'If set, the model will be converted with the provided list of LoRA ranks.',
+)
 
 _BUILDER = {
     '1b': llama.build_1b_model,
@@ -66,13 +76,13 @@ def main(_):
   pytorch_model = _BUILDER[_MODEL_SIZE.value](
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
-  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
-  output_filename = f'llama_{_MODEL_SIZE.value}_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
   converter.convert_to_tflite(
       pytorch_model,
-      tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
+      output_path=_OUTPUT_PATH.value,
+      output_name_prefix=_OUTPUT_NAME_PREFIX.value,
       prefill_seq_len=_PREFILL_SEQ_LENS.value,
       quantize=_QUANTIZE.value,
+      lora_ranks=_LORA_RANKS.value,
       export_config=ExportConfig(),
   )
 
ai_edge_torch/generative/examples/openelm/convert_to_tflite.py CHANGED
@@ -29,10 +29,15 @@ _CHECKPOINT_PATH = flags.DEFINE_string(
     os.path.join(pathlib.Path.home(), 'Downloads/llm_data/openelm'),
     'The path to the model checkpoint, or directory holding the checkpoint.',
 )
-_TFLITE_PATH = flags.DEFINE_string(
-    'tflite_path',
+_OUTPUT_PATH = flags.DEFINE_string(
+    'output_path',
     '/tmp/',
-    'The tflite file path to export.',
+    'The path to export the tflite model.',
+)
+_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
+    'output_name_prefix',
+    'openelm',
+    'The prefix of the output tflite model name.',
 )
 _PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
     'prefill_seq_lens',
@@ -49,22 +54,24 @@ _QUANTIZE = flags.DEFINE_bool(
     True,
     'Whether the model should be quantized.',
 )
+_LORA_RANKS = flags.DEFINE_multi_integer(
+    'lora_ranks',
+    None,
+    'If set, the model will be converted with the provided list of LoRA ranks.',
+)
 
 
 def main(_):
   pytorch_model = openelm.build_model(
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
-  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
-  output_filename = (
-      f'openelm_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
-  )
-
   converter.convert_to_tflite(
       pytorch_model,
-      tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
+      output_path=_OUTPUT_PATH.value,
+      output_name_prefix=_OUTPUT_NAME_PREFIX.value,
       prefill_seq_len=_PREFILL_SEQ_LENS.value,
       quantize=_QUANTIZE.value,
+      lora_ranks=_LORA_RANKS.value,
       export_config=ExportConfig(),
   )
 
ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py CHANGED
@@ -40,10 +40,15 @@ _CHECKPOINT_PATH = flags.DEFINE_string(
     os.path.join(pathlib.Path.home(), 'Downloads/llm_data/paligemma2-3b-224'),
     'The path to the model checkpoint, or directory holding the checkpoint.',
 )
-_TFLITE_PATH = flags.DEFINE_string(
-    'tflite_path',
+_OUTPUT_PATH = flags.DEFINE_string(
+    'output_path',
     '/tmp/',
-    'The tflite file path to export.',
+    'The path to export the tflite model.',
+)
+_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
+    'output_name_prefix',
+    'paligemma',
+    'The prefix of the output tflite model name.',
 )
 _PREFILL_SEQ_LEN = flags.DEFINE_integer(
     'prefill_seq_len',
@@ -73,11 +78,11 @@ def main(_):
       version=int(_VERSION.value),
       kv_cache_max_len=_KV_CACHE_MAX_LEN.value,
   )
-  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
-  output_filename = f'paligemma{_VERSION.value}_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
+
   converter.convert_to_tflite(
       pytorch_model,
-      tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
+      output_path=_OUTPUT_PATH.value,
+      output_name_prefix=f'{_OUTPUT_NAME_PREFIX.value}_{_VERSION.value}',
       prefill_seq_len=_PREFILL_SEQ_LEN.value,
       pixel_values_size=torch.Size(_PIXEL_VALUES_SIZE.value),
       quantize=_QUANTIZE.value,
ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py CHANGED
@@ -26,13 +26,18 @@ from ai_edge_torch.generative.utilities.model_builder import ExportConfig
 
 _CHECKPOINT_PATH = flags.DEFINE_string(
     'checkpoint_path',
-    os.path.join(pathlib.Path.home(), 'Downloads/llm_data/phi3'),
+    os.path.join(pathlib.Path.home(), 'Downloads/llm_data/gemma-2b'),
     'The path to the model checkpoint, or directory holding the checkpoint.',
 )
-_TFLITE_PATH = flags.DEFINE_string(
-    'tflite_path',
+_OUTPUT_PATH = flags.DEFINE_string(
+    'output_path',
     '/tmp/',
-    'The tflite file path to export.',
+    'The path to export the tflite model.',
+)
+_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
+    'output_name_prefix',
+    'phi3',
+    'The prefix of the output tflite model name.',
 )
 _PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
     'prefill_seq_lens',
@@ -49,19 +54,24 @@ _QUANTIZE = flags.DEFINE_bool(
     True,
     'Whether the model should be quantized.',
 )
+_LORA_RANKS = flags.DEFINE_multi_integer(
+    'lora_ranks',
+    None,
+    'If set, the model will be converted with the provided list of LoRA ranks.',
+)
 
 
 def main(_):
   pytorch_model = phi3.build_model(
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
-  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
-  output_filename = f'phi3_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
   converter.convert_to_tflite(
       pytorch_model,
-      tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
+      output_path=_OUTPUT_PATH.value,
+      output_name_prefix=_OUTPUT_NAME_PREFIX.value,
       prefill_seq_len=_PREFILL_SEQ_LENS.value,
       quantize=_QUANTIZE.value,
+      lora_ranks=_LORA_RANKS.value,
       export_config=ExportConfig(),
   )
 
ai_edge_torch/generative/examples/phi/convert_to_tflite.py CHANGED
@@ -29,10 +29,15 @@ _CHECKPOINT_PATH = flags.DEFINE_string(
     os.path.join(pathlib.Path.home(), 'Downloads/llm_data/phi2'),
     'The path to the model checkpoint, or directory holding the checkpoint.',
 )
-_TFLITE_PATH = flags.DEFINE_string(
-    'tflite_path',
+_OUTPUT_PATH = flags.DEFINE_string(
+    'output_path',
     '/tmp/',
-    'The tflite file path to export.',
+    'The path to export the tflite model.',
+)
+_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
+    'output_name_prefix',
+    'phi2',
+    'The prefix of the output tflite model name.',
 )
 _PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
     'prefill_seq_lens',
@@ -49,19 +54,24 @@ _QUANTIZE = flags.DEFINE_bool(
     True,
     'Whether the model should be quantized.',
 )
+_LORA_RANKS = flags.DEFINE_multi_integer(
+    'lora_ranks',
+    None,
+    'If set, the model will be converted with the provided list of LoRA ranks.',
+)
 
 
 def main(_):
   pytorch_model = phi2.build_model(
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
-  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
-  output_filename = f'phi2_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
   converter.convert_to_tflite(
       pytorch_model,
-      tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
+      output_path=_OUTPUT_PATH.value,
+      output_name_prefix=_OUTPUT_NAME_PREFIX.value,
       prefill_seq_len=_PREFILL_SEQ_LENS.value,
       quantize=_QUANTIZE.value,
+      lora_ranks=_LORA_RANKS.value,
       export_config=ExportConfig(),
   )
 
ai_edge_torch/generative/examples/qwen/convert_to_tflite.py CHANGED
@@ -35,10 +35,15 @@ _CHECKPOINT_PATH = flags.DEFINE_string(
     os.path.join(pathlib.Path.home(), 'Downloads/llm_data/qwen'),
     'The path to the model checkpoint, or directory holding the checkpoint.',
 )
-_TFLITE_PATH = flags.DEFINE_string(
-    'tflite_path',
+_OUTPUT_PATH = flags.DEFINE_string(
+    'output_path',
     '/tmp/',
-    'The tflite file path to export.',
+    'The path to export the tflite model.',
+)
+_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
+    'output_name_prefix',
+    'qwen',
+    'The prefix of the output tflite model name.',
 )
 _PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
     'prefill_seq_lens',
@@ -55,6 +60,12 @@ _QUANTIZE = flags.DEFINE_bool(
     True,
     'Whether the model should be quantized.',
 )
+_LORA_RANKS = flags.DEFINE_multi_integer(
+    'lora_ranks',
+    None,
+    'If set, the model will be converted with the provided list of LoRA ranks.',
+)
+
 
 _BUILDER = {
     '0.5b': qwen.build_0_5b_model,
@@ -67,16 +78,13 @@ def main(_):
   pytorch_model = _BUILDER[_MODEL_SIZE.value](
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
-  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
-  model_size = _MODEL_SIZE.value.replace('.', '_')
-  output_filename = (
-      f'qwen_{model_size}_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
-  )
   converter.convert_to_tflite(
       pytorch_model,
-      tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
+      output_path=_OUTPUT_PATH.value,
+      output_name_prefix=_OUTPUT_NAME_PREFIX.value,
       prefill_seq_len=_PREFILL_SEQ_LENS.value,
       quantize=_QUANTIZE.value,
+      lora_ranks=_LORA_RANKS.value,
       export_config=ExportConfig(),
   )
 