ai-edge-torch-nightly 0.3.0.dev20250105__py3-none-any.whl → 0.3.0.dev20250108__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (32)
  1. ai_edge_torch/_config.py +26 -9
  2. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_check.py +13 -0
  3. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +36 -0
  4. ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py +16 -6
  5. ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +16 -6
  6. ai_edge_torch/generative/examples/gemma/gemma2.py +25 -43
  7. ai_edge_torch/generative/examples/llama/convert_to_tflite.py +16 -6
  8. ai_edge_torch/generative/examples/openelm/convert_to_tflite.py +16 -9
  9. ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py +11 -6
  10. ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py +17 -7
  11. ai_edge_torch/generative/examples/phi/convert_to_tflite.py +16 -6
  12. ai_edge_torch/generative/examples/qwen/convert_to_tflite.py +17 -9
  13. ai_edge_torch/generative/examples/smollm/convert_to_tflite.py +16 -7
  14. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +3 -3
  15. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +16 -8
  16. ai_edge_torch/generative/layers/attention.py +70 -12
  17. ai_edge_torch/generative/layers/lora.py +557 -0
  18. ai_edge_torch/generative/layers/normalization.py +2 -50
  19. ai_edge_torch/generative/layers/rotary_position_embedding.py +27 -34
  20. ai_edge_torch/generative/test/test_lora.py +147 -0
  21. ai_edge_torch/generative/utilities/converter.py +100 -47
  22. ai_edge_torch/generative/utilities/model_builder.py +21 -16
  23. ai_edge_torch/generative/utilities/verifier.py +4 -4
  24. ai_edge_torch/odml_torch/_torch_future.py +13 -0
  25. ai_edge_torch/odml_torch/export.py +6 -2
  26. ai_edge_torch/odml_torch/lowerings/decomp.py +4 -0
  27. ai_edge_torch/version.py +1 -1
  28. {ai_edge_torch_nightly-0.3.0.dev20250105.dist-info → ai_edge_torch_nightly-0.3.0.dev20250108.dist-info}/METADATA +1 -1
  29. {ai_edge_torch_nightly-0.3.0.dev20250105.dist-info → ai_edge_torch_nightly-0.3.0.dev20250108.dist-info}/RECORD +32 -30
  30. {ai_edge_torch_nightly-0.3.0.dev20250105.dist-info → ai_edge_torch_nightly-0.3.0.dev20250108.dist-info}/LICENSE +0 -0
  31. {ai_edge_torch_nightly-0.3.0.dev20250105.dist-info → ai_edge_torch_nightly-0.3.0.dev20250108.dist-info}/WHEEL +0 -0
  32. {ai_edge_torch_nightly-0.3.0.dev20250105.dist-info → ai_edge_torch_nightly-0.3.0.dev20250108.dist-info}/top_level.txt +0 -0
ai_edge_torch/_config.py CHANGED
@@ -22,6 +22,18 @@ import os
 __all__ = ["config"]


+def _get_bool_env_var(name: str, default: bool) -> bool:
+  var = os.environ.get(name, "false")
+  var = var.lower().strip()
+  if var in ("y", "yes", "t", "true", "on", "1"):
+    return True
+  elif var in ("n", "no", "f", "false", "off", "0"):
+    return False
+  else:
+    logging.warning("Invalid %s value is ignored: %s.", name, var)
+    return default
+
+
 class _Config:
   """ai-edge-torch global configs."""

@@ -33,20 +45,25 @@ class _Config:
     To use torch_xla as the lowering backend, set environment variable
     `USE_TORCH_XLA` to "true".
     """
-    var = os.environ.get("USE_TORCH_XLA", "false")
-    var = var.lower().strip()
-    if var in ("y", "yes", "t", "true", "on", "1"):
-      return True
-    elif var in ("n", "no", "f", "false", "off", "0"):
-      return False
-    else:
-      logging.warning("Invalid USE_TORCH_XLA value is ignored: %s.", var)
-      return False
+    return _get_bool_env_var("USE_TORCH_XLA", default=False)

   @property
   def in_oss(self) -> bool:
     """True if the code is not running in google internal environment."""
     return True

+  @property
+  def enable_group_norm_composite(self) -> bool:
+    """True if lowering group norm in StableHLO composite.
+
+    Currently only supports NHWC group norm generated by
+    OptimizeLayoutTransposesPass.
+    """
+    return _get_bool_env_var("ENABLE_GROUP_NORM_COMPOSITE", default=False)
+
+  @enable_group_norm_composite.setter
+  def enable_group_norm_composite(self, value: bool):
+    os.environ["ENABLE_GROUP_NORM_COMPOSITE"] = "y" if value else "n"
+

 config = _Config()
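With this refactor the boolean parsing is shared by every config flag, and the new group-norm composite lowering can be switched on either through the ENABLE_GROUP_NORM_COMPOSITE environment variable or through the property setter added above. A minimal sketch of both paths (not part of the diff; it only assumes the config object shown above is importable as ai_edge_torch.config):

    import os
    import ai_edge_torch

    # Option 1: environment variable; _get_bool_env_var accepts
    # "y"/"yes"/"t"/"true"/"on"/"1" and their negative counterparts.
    os.environ["ENABLE_GROUP_NORM_COMPOSITE"] = "true"
    assert ai_edge_torch.config.enable_group_norm_composite

    # Option 2: the setter, which writes "y"/"n" back to the same variable.
    ai_edge_torch.config.enable_group_norm_composite = False
    assert not ai_edge_torch.config.enable_group_norm_composite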
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_check.py CHANGED
@@ -17,6 +17,7 @@
 import dataclasses
 import operator

+import ai_edge_torch
 from ai_edge_torch import lowertools
 from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import layout_rewrite
 from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import utils
@@ -155,6 +156,7 @@ def _qdq_layout_sensitive_inputs_getter(node: Node):
 @layout_sensitive_inputs_getters.register(
     aten._native_batch_norm_legit_no_training
 )
+@layout_sensitive_inputs_getters.register(aten.group_norm)
 @layout_sensitive_inputs_getters.register(aten.native_group_norm)
 def _first_arg_getter(node):
   return [node.args[0]]
@@ -188,6 +190,17 @@ def _aten_norm_checker(node):
   return NHWCable(can_be=len(val[0].shape) == 4, must_be=False)


+@nhwcable_node_checkers.register(aten.group_norm)
+def _aten_group_norm_checker(node):
+  val = node.meta.get("val")
+  if not hasattr(val, "shape"):
+    return NHWCable(can_be=False, must_be=False)
+
+  can_be = len(val.shape) == 4
+  must_be = can_be and ai_edge_torch.config.enable_group_norm_composite
+  return NHWCable(can_be=can_be, must_be=must_be)
+
+
 @nhwcable_node_checkers.register(aten.native_group_norm)
 def _aten_native_group_norm_checker(node):
   val = node.meta.get("val")
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py CHANGED
@@ -16,6 +16,7 @@

 import operator

+import ai_edge_torch
 from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import layout_mark
 from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import op_func_registry
 from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import utils
@@ -23,6 +24,7 @@ import torch
 import torch.utils._pytree as pytree

 aten = torch.ops.aten
+StableHLOCompositeBuilder = ai_edge_torch.hlfb.StableHLOCompositeBuilder

 __all__ = ["rewrite_nhwc_node", "has_nhwc_rewriter"]

@@ -342,6 +344,39 @@ def _aten__native_batch_norm_legit_no_training(node):
   node.target = batch_norm


+@rewriters.register(aten.group_norm.default)
+def _aten_group_norm(node):
+  def group_norm(input, num_groups: int, weight=None, bias=None, eps=1e-5):
+    is_composite_supported = (
+        ai_edge_torch.config.enable_group_norm_composite
+        and weight is not None
+        and bias is not None
+    )
+
+    builder = None
+    if is_composite_supported:
+      builder = StableHLOCompositeBuilder(
+          name="odml.group_norm",
+          attr={
+              "num_groups": num_groups,
+              "epsilon": eps,
+              "reduction_axes": [3],
+              "channel_axis": 3,
+          },
+      )
+      input, weight, bias = builder.mark_inputs(input, weight, bias)
+
+    input = utils.tensor_to_nchw(input)
+    output = aten.group_norm.default(input, num_groups, weight, bias, eps=eps)
+    output = utils.tensor_to_nhwc(output)
+
+    if builder is not None:
+      output = builder.mark_outputs(output)
+    return output
+
+  node.target = group_norm
+
+
 @rewriters.register(aten.native_group_norm.default)
 def _aten_native_group_norm(node):

@@ -354,6 +389,7 @@ def _aten_native_group_norm(node):
       flattened_inner_size: int,
       num_groups: int,
       eps: float,
+      **kwargs,
   ):
     input_reshaped = torch.reshape(
         input,
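Together, the two pass changes mark aten.group_norm as layout-sensitive and, when the config flag is enabled, wrap the NHWC group norm in an "odml.group_norm" StableHLO composite carrying num_groups, epsilon, reduction_axes, and channel_axis attributes. A hedged end-to-end sketch of how a 4D GroupNorm model might pick this up during conversion (the toy model, shapes, and output path are illustrative, not from the diff; it assumes the standard ai_edge_torch.convert API):

    import torch
    import ai_edge_torch

    class ToyGroupNormModel(torch.nn.Module):
      """Illustrative model whose GroupNorm runs on a 4D activation."""

      def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(16, 16, kernel_size=3, padding=1)
        self.gn = torch.nn.GroupNorm(num_groups=4, num_channels=16)

      def forward(self, x):
        return self.gn(self.conv(x))

    # Enable the composite, then convert; the layout pass may rewrite the
    # group norm to NHWC and emit the "odml.group_norm" composite around it.
    ai_edge_torch.config.enable_group_norm_composite = True
    sample_inputs = (torch.randn(1, 16, 32, 32),)  # NCHW sample input
    edge_model = ai_edge_torch.convert(ToyGroupNormModel().eval(), sample_inputs)
    edge_model.export("/tmp/toy_group_norm.tflite")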
ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py CHANGED
@@ -29,10 +29,15 @@ _CHECKPOINT_PATH = flags.DEFINE_string(
     os.path.join(pathlib.Path.home(), 'Downloads/llm_data/gemma-2b'),
     'The path to the model checkpoint, or directory holding the checkpoint.',
 )
-_TFLITE_PATH = flags.DEFINE_string(
-    'tflite_path',
+_OUTPUT_PATH = flags.DEFINE_string(
+    'output_path',
     '/tmp/',
-    'The tflite file path to export.',
+    'The path to export the tflite model.',
+)
+_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
+    'output_name_prefix',
+    'gemma',
+    'The prefix of the output tflite model name.',
 )
 _PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
     'prefill_seq_lens',
@@ -49,19 +54,24 @@ _QUANTIZE = flags.DEFINE_bool(
     True,
     'Whether the model should be quantized.',
 )
+_LORA_RANKS = flags.DEFINE_multi_integer(
+    'lora_ranks',
+    None,
+    'If set, the model will be converted with the provided list of LoRA ranks.',
+)


 def main(_):
   pytorch_model = gemma1.build_2b_model(
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
-  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
-  output_filename = f'gemma_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
   converter.convert_to_tflite(
       pytorch_model,
-      tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
+      output_path=_OUTPUT_PATH.value,
+      output_name_prefix=_OUTPUT_NAME_PREFIX.value,
       prefill_seq_len=_PREFILL_SEQ_LENS.value,
       quantize=_QUANTIZE.value,
+      lora_ranks=_LORA_RANKS.value,
       export_config=ExportConfig(),
   )

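The same flag rework repeats in every example converter below: tflite_path is replaced by output_path plus output_name_prefix (the final file name is now assembled by converter.convert_to_tflite, which also changed in this release), and an optional lora_ranks multi-flag is threaded through to the converter. A hedged sketch of the equivalent direct call with the new keyword arguments (checkpoint path, sequence lengths, and LoRA ranks are placeholders):

    from ai_edge_torch.generative.examples.gemma import gemma1
    from ai_edge_torch.generative.utilities import converter
    from ai_edge_torch.generative.utilities.model_builder import ExportConfig

    # Placeholder checkpoint location and values, for illustration only.
    pytorch_model = gemma1.build_2b_model('/path/to/gemma-2b', kv_cache_max_len=1024)
    converter.convert_to_tflite(
        pytorch_model,
        output_path='/tmp/',
        output_name_prefix='gemma',
        prefill_seq_len=[256, 512],
        quantize=True,
        lora_ranks=[4, 16],
        export_config=ExportConfig(),
    )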
ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py CHANGED
@@ -29,10 +29,15 @@ _CHECKPOINT_PATH = flags.DEFINE_string(
     os.path.join(pathlib.Path.home(), 'Downloads/llm_data/gemma2-2b'),
     'The path to the model checkpoint, or directory holding the checkpoint.',
 )
-_TFLITE_PATH = flags.DEFINE_string(
-    'tflite_path',
+_OUTPUT_PATH = flags.DEFINE_string(
+    'output_path',
     '/tmp/',
-    'The tflite file path to export.',
+    'The path to export the tflite model.',
+)
+_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
+    'output_name_prefix',
+    'gemma2',
+    'The prefix of the output tflite model name.',
 )
 _PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
     'prefill_seq_lens',
@@ -49,19 +54,24 @@ _QUANTIZE = flags.DEFINE_bool(
     True,
     'Whether the model should be quantized.',
 )
+_LORA_RANKS = flags.DEFINE_multi_integer(
+    'lora_ranks',
+    None,
+    'If set, the model will be converted with the provided list of LoRA ranks.',
+)


 def main(_):
   pytorch_model = gemma2.build_2b_model(
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
-  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
-  output_filename = f'gemma2_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
   converter.convert_to_tflite(
       pytorch_model,
-      tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
+      output_path=_OUTPUT_PATH.value,
+      output_name_prefix=_OUTPUT_NAME_PREFIX.value,
       prefill_seq_len=_PREFILL_SEQ_LENS.value,
       quantize=_QUANTIZE.value,
+      lora_ranks=_LORA_RANKS.value,
       export_config=ExportConfig(),
   )

ai_edge_torch/generative/examples/gemma/gemma2.py CHANGED
@@ -15,14 +15,13 @@

 """Example of building a Gemma2 model."""

-from typing import List, Optional, Tuple
+from typing import Optional, Tuple

 from ai_edge_torch.generative.layers import attention
 from ai_edge_torch.generative.layers import builder
 from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
 import ai_edge_torch.generative.layers.model_config as cfg
-import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb
 from ai_edge_torch.generative.utilities import model_builder
 import ai_edge_torch.generative.utilities.loader as loading_utils
 import torch
@@ -104,12 +103,17 @@ class Gemma2(nn.Module):
         config.embedding_dim,
         config.final_norm_config,
     )
-    self.mask_cache = attn_utils.build_causal_mask_cache(
-        size=config.kv_cache_max,
-    )
     # Gemma2 has same hyper parameters for each layer except for attention
     # types. Use the first layer.
     attn_config = config.block_config(0).attn_config
+    self.rope_cache = attn_utils.build_rope_cache(
+        size=config.kv_cache_max,
+        dim=int(attn_config.rotary_percentage * attn_config.head_dim),
+        base=attn_config.rotary_base,
+    )
+    self.mask_cache = attn_utils.build_causal_mask_cache(
+        size=config.kv_cache_max,
+    )
     self.sliding_window_mask_cache = attn_utils.build_sliding_window_mask_cache(
         size=config.kv_cache_max,
         window_size=attn_config.sliding_window_size,
@@ -136,48 +140,29 @@ class Gemma2(nn.Module):
         f"Cannot forward sequence of length {seq_len}, max seq length is only"
         f" {self.config.max_seq_len}"
     )
-
-    # token embeddings of shape (b, t, n_embd)
-    input_embeds = self.tok_embedding(tokens)
-    # RoPE parameters are the same for all blocks. Use the first layer.
-    attn_config = self.config.block_config(0).attn_config
-    n_elem = int(attn_config.rotary_percentage * attn_config.head_dim)
-    rope = rotary_pos_emb.build_rope(
-        input_pos, n_elem, attn_config.head_dim, attn_config.rotary_base
-    )
-    mask = [self.get_attention_mask(
-        self.config.block_config(i).attn_config.attn_type, input_pos
-    ) for i in range(self.config.num_layers)]
-
-    return self._forward_with_embeds(
-        input_embeds, rope, mask, input_pos, kv_cache, export_config
-    )
-
-  def _forward_with_embeds(
-      self,
-      input_embeds: torch.Tensor,
-      rope: Tuple[torch.Tensor, torch.Tensor],
-      mask: List[torch.Tensor],
-      input_pos: torch.Tensor,
-      kv_cache: kv_utils.KVCache,
-      export_config: Optional[model_builder.ExportConfig] = None,
-  ) -> dict[torch.Tensor, kv_utils.KVCache]:
-    """Forwards the model with input embeddings."""
     assert len(self.transformer_blocks) == len(kv_cache.caches), (
         "The number of transformer blocks and the number of KV cache entries"
         " must be the same."
     )

-    if self.config.embedding_scale is not None:
-      input_embeds = input_embeds * self.config.embedding_scale
-    x = input_embeds
-    updated_kv_entries = []
+    cos, sin = self.rope_cache
+    cos = cos.index_select(0, input_pos)
+    sin = sin.index_select(0, input_pos)
+
+    # token embeddings of shape (b, t, n_embd)
+    x = self.tok_embedding(tokens)
+    x = x * (self.config.embedding_dim**0.5)
+
+    updated_kv_entires = []
     for i, block in enumerate(self.transformer_blocks):
+      mask = self.get_attention_mask(
+          block.config.attn_config.attn_type, input_pos
+      )
       kv_entry = kv_cache.caches[i] if kv_cache else None
-      x, kv_entry = block(x, rope, mask[i], input_pos, kv_entry)
+      x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
       if kv_entry:
-        updated_kv_entries.append(kv_entry)
-    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entries))
+        updated_kv_entires.append(kv_entry)
+    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))

     if export_config is not None:
       if (
@@ -243,13 +228,11 @@ def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
   )

   num_layers = 26
-  embedding_dim = 2304
   config = cfg.ModelConfig(
       vocab_size=256000,
       num_layers=num_layers,
       max_seq_len=8192,
-      embedding_dim=embedding_dim,
-      embedding_scale=embedding_dim**0.5,
+      embedding_dim=2304,
       kv_cache_max_len=kv_cache_max_len,
       block_configs=[get_block_config(i) for i in range(num_layers)],
       final_norm_config=norm_config,
@@ -266,7 +249,6 @@ def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
   config.num_layers = 2
   config.max_seq_len = 2 * kv_cache_max_len
   config.embedding_dim = 128
-  config.embedding_scale = config.embedding_dim**0.5
   config.block_configs = config.block_configs[: config.num_layers]
   for block_config in config.block_configs:
     block_config.attn_config.num_heads = 4
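Gemma2 thus goes back to a per-model RoPE cache: cos/sin tables sized to kv_cache_max are built once in __init__ and sliced with index_select(0, input_pos) on every call, the per-layer attention mask is picked inside the block loop, and the embedding scale is folded back into forward as embedding_dim**0.5. A small sketch of the precompute-then-slice pattern (the cache builder below is an illustrative stand-in, not the real attn_utils.build_rope_cache):

    import torch

    def build_rope_cache(size: int, dim: int, base: int = 10_000):
      # Precompute cos/sin for every position up to the KV cache length once.
      theta = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
      positions = torch.arange(size).float()
      freqs = torch.outer(positions, theta)
      return torch.cos(freqs), torch.sin(freqs)

    cos_cache, sin_cache = build_rope_cache(size=1024, dim=256)

    # Per step, select only the rows for the current positions, mirroring
    # `cos.index_select(0, input_pos)` in the new Gemma2.forward.
    input_pos = torch.tensor([5, 6, 7])
    cos = cos_cache.index_select(0, input_pos)
    sin = sin_cache.index_select(0, input_pos)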
ai_edge_torch/generative/examples/llama/convert_to_tflite.py CHANGED
@@ -35,10 +35,15 @@ _CHECKPOINT_PATH = flags.DEFINE_string(
     os.path.join(pathlib.Path.home(), 'Downloads/llm_data/llama'),
     'The path to the model checkpoint, or directory holding the checkpoint.',
 )
-_TFLITE_PATH = flags.DEFINE_string(
-    'tflite_path',
+_OUTPUT_PATH = flags.DEFINE_string(
+    'output_path',
     '/tmp/',
-    'The tflite file path to export.',
+    'The path to export the tflite model.',
+)
+_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
+    'output_name_prefix',
+    'llama',
+    'The prefix of the output tflite model name.',
 )
 _PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
     'prefill_seq_lens',
@@ -55,6 +60,11 @@ _QUANTIZE = flags.DEFINE_bool(
     True,
     'Whether the model should be quantized.',
 )
+_LORA_RANKS = flags.DEFINE_multi_integer(
+    'lora_ranks',
+    None,
+    'If set, the model will be converted with the provided list of LoRA ranks.',
+)

 _BUILDER = {
     '1b': llama.build_1b_model,
@@ -66,13 +76,13 @@ def main(_):
   pytorch_model = _BUILDER[_MODEL_SIZE.value](
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
-  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
-  output_filename = f'llama_{_MODEL_SIZE.value}_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
   converter.convert_to_tflite(
       pytorch_model,
-      tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
+      output_path=_OUTPUT_PATH.value,
+      output_name_prefix=_OUTPUT_NAME_PREFIX.value,
       prefill_seq_len=_PREFILL_SEQ_LENS.value,
       quantize=_QUANTIZE.value,
+      lora_ranks=_LORA_RANKS.value,
       export_config=ExportConfig(),
   )

ai_edge_torch/generative/examples/openelm/convert_to_tflite.py CHANGED
@@ -29,10 +29,15 @@ _CHECKPOINT_PATH = flags.DEFINE_string(
     os.path.join(pathlib.Path.home(), 'Downloads/llm_data/openelm'),
     'The path to the model checkpoint, or directory holding the checkpoint.',
 )
-_TFLITE_PATH = flags.DEFINE_string(
-    'tflite_path',
+_OUTPUT_PATH = flags.DEFINE_string(
+    'output_path',
     '/tmp/',
-    'The tflite file path to export.',
+    'The path to export the tflite model.',
+)
+_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
+    'output_name_prefix',
+    'openelm',
+    'The prefix of the output tflite model name.',
 )
 _PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
     'prefill_seq_lens',
@@ -49,22 +54,24 @@ _QUANTIZE = flags.DEFINE_bool(
     True,
     'Whether the model should be quantized.',
 )
+_LORA_RANKS = flags.DEFINE_multi_integer(
+    'lora_ranks',
+    None,
+    'If set, the model will be converted with the provided list of LoRA ranks.',
+)


 def main(_):
   pytorch_model = openelm.build_model(
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
-  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
-  output_filename = (
-      f'openelm_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
-  )
-
   converter.convert_to_tflite(
       pytorch_model,
-      tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
+      output_path=_OUTPUT_PATH.value,
+      output_name_prefix=_OUTPUT_NAME_PREFIX.value,
       prefill_seq_len=_PREFILL_SEQ_LENS.value,
       quantize=_QUANTIZE.value,
+      lora_ranks=_LORA_RANKS.value,
       export_config=ExportConfig(),
   )

ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py CHANGED
@@ -40,10 +40,15 @@ _CHECKPOINT_PATH = flags.DEFINE_string(
     os.path.join(pathlib.Path.home(), 'Downloads/llm_data/paligemma2-3b-224'),
     'The path to the model checkpoint, or directory holding the checkpoint.',
 )
-_TFLITE_PATH = flags.DEFINE_string(
-    'tflite_path',
+_OUTPUT_PATH = flags.DEFINE_string(
+    'output_path',
     '/tmp/',
-    'The tflite file path to export.',
+    'The path to export the tflite model.',
+)
+_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
+    'output_name_prefix',
+    'paligemma',
+    'The prefix of the output tflite model name.',
 )
 _PREFILL_SEQ_LEN = flags.DEFINE_integer(
     'prefill_seq_len',
@@ -73,11 +78,11 @@ def main(_):
       version=int(_VERSION.value),
       kv_cache_max_len=_KV_CACHE_MAX_LEN.value,
   )
-  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
-  output_filename = f'paligemma{_VERSION.value}_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
+
   converter.convert_to_tflite(
       pytorch_model,
-      tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
+      output_path=_OUTPUT_PATH.value,
+      output_name_prefix=f'{_OUTPUT_NAME_PREFIX.value}_{_VERSION.value}',
       prefill_seq_len=_PREFILL_SEQ_LEN.value,
       pixel_values_size=torch.Size(_PIXEL_VALUES_SIZE.value),
       quantize=_QUANTIZE.value,
ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py CHANGED
@@ -26,13 +26,18 @@ from ai_edge_torch.generative.utilities.model_builder import ExportConfig

 _CHECKPOINT_PATH = flags.DEFINE_string(
     'checkpoint_path',
-    os.path.join(pathlib.Path.home(), 'Downloads/llm_data/phi3'),
+    os.path.join(pathlib.Path.home(), 'Downloads/llm_data/gemma-2b'),
     'The path to the model checkpoint, or directory holding the checkpoint.',
 )
-_TFLITE_PATH = flags.DEFINE_string(
-    'tflite_path',
+_OUTPUT_PATH = flags.DEFINE_string(
+    'output_path',
     '/tmp/',
-    'The tflite file path to export.',
+    'The path to export the tflite model.',
+)
+_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
+    'output_name_prefix',
+    'phi3',
+    'The prefix of the output tflite model name.',
 )
 _PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
     'prefill_seq_lens',
@@ -49,19 +54,24 @@ _QUANTIZE = flags.DEFINE_bool(
     True,
     'Whether the model should be quantized.',
 )
+_LORA_RANKS = flags.DEFINE_multi_integer(
+    'lora_ranks',
+    None,
+    'If set, the model will be converted with the provided list of LoRA ranks.',
+)


 def main(_):
   pytorch_model = phi3.build_model(
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
-  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
-  output_filename = f'phi3_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
   converter.convert_to_tflite(
       pytorch_model,
-      tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
+      output_path=_OUTPUT_PATH.value,
+      output_name_prefix=_OUTPUT_NAME_PREFIX.value,
       prefill_seq_len=_PREFILL_SEQ_LENS.value,
       quantize=_QUANTIZE.value,
+      lora_ranks=_LORA_RANKS.value,
       export_config=ExportConfig(),
   )

 
@@ -29,10 +29,15 @@ _CHECKPOINT_PATH = flags.DEFINE_string(
29
29
  os.path.join(pathlib.Path.home(), 'Downloads/llm_data/phi2'),
30
30
  'The path to the model checkpoint, or directory holding the checkpoint.',
31
31
  )
32
- _TFLITE_PATH = flags.DEFINE_string(
33
- 'tflite_path',
32
+ _OUTPUT_PATH = flags.DEFINE_string(
33
+ 'output_path',
34
34
  '/tmp/',
35
- 'The tflite file path to export.',
35
+ 'The path to export the tflite model.',
36
+ )
37
+ _OUTPUT_NAME_PREFIX = flags.DEFINE_string(
38
+ 'output_name_prefix',
39
+ 'phi2',
40
+ 'The prefix of the output tflite model name.',
36
41
  )
37
42
  _PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
38
43
  'prefill_seq_lens',
@@ -49,19 +54,24 @@ _QUANTIZE = flags.DEFINE_bool(
49
54
  True,
50
55
  'Whether the model should be quantized.',
51
56
  )
57
+ _LORA_RANKS = flags.DEFINE_multi_integer(
58
+ 'lora_ranks',
59
+ None,
60
+ 'If set, the model will be converted with the provided list of LoRA ranks.',
61
+ )
52
62
 
53
63
 
54
64
  def main(_):
55
65
  pytorch_model = phi2.build_model(
56
66
  _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
57
67
  )
58
- quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
59
- output_filename = f'phi2_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
60
68
  converter.convert_to_tflite(
61
69
  pytorch_model,
62
- tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
70
+ output_path=_OUTPUT_PATH.value,
71
+ output_name_prefix=_OUTPUT_NAME_PREFIX.value,
63
72
  prefill_seq_len=_PREFILL_SEQ_LENS.value,
64
73
  quantize=_QUANTIZE.value,
74
+ lora_ranks=_LORA_RANKS.value,
65
75
  export_config=ExportConfig(),
66
76
  )
67
77
 
ai_edge_torch/generative/examples/qwen/convert_to_tflite.py CHANGED
@@ -35,10 +35,15 @@ _CHECKPOINT_PATH = flags.DEFINE_string(
     os.path.join(pathlib.Path.home(), 'Downloads/llm_data/qwen'),
     'The path to the model checkpoint, or directory holding the checkpoint.',
 )
-_TFLITE_PATH = flags.DEFINE_string(
-    'tflite_path',
+_OUTPUT_PATH = flags.DEFINE_string(
+    'output_path',
     '/tmp/',
-    'The tflite file path to export.',
+    'The path to export the tflite model.',
+)
+_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
+    'output_name_prefix',
+    'qwen',
+    'The prefix of the output tflite model name.',
 )
 _PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
     'prefill_seq_lens',
@@ -55,6 +60,12 @@ _QUANTIZE = flags.DEFINE_bool(
     True,
     'Whether the model should be quantized.',
 )
+_LORA_RANKS = flags.DEFINE_multi_integer(
+    'lora_ranks',
+    None,
+    'If set, the model will be converted with the provided list of LoRA ranks.',
+)
+

 _BUILDER = {
     '0.5b': qwen.build_0_5b_model,
@@ -67,16 +78,13 @@ def main(_):
   pytorch_model = _BUILDER[_MODEL_SIZE.value](
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
-  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
-  model_size = _MODEL_SIZE.value.replace('.', '_')
-  output_filename = (
-      f'qwen_{model_size}_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
-  )
   converter.convert_to_tflite(
       pytorch_model,
-      tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
+      output_path=_OUTPUT_PATH.value,
+      output_name_prefix=_OUTPUT_NAME_PREFIX.value,
       prefill_seq_len=_PREFILL_SEQ_LENS.value,
       quantize=_QUANTIZE.value,
+      lora_ranks=_LORA_RANKS.value,
       export_config=ExportConfig(),
   )
