ai-edge-torch-nightly 0.4.0.dev20250330__py3-none-any.whl → 0.4.0.dev20250401__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (24)
  1. ai_edge_torch/generative/examples/amd_llama_135m/convert_to_tflite.py +7 -43
  2. ai_edge_torch/generative/examples/deepseek/convert_to_tflite.py +7 -42
  3. ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py +7 -45
  4. ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +7 -44
  5. ai_edge_torch/generative/examples/gemma3/convert_gemma3_to_tflite.py +10 -45
  6. ai_edge_torch/generative/examples/llama/convert_to_tflite.py +9 -43
  7. ai_edge_torch/generative/examples/openelm/convert_to_tflite.py +7 -44
  8. ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py +8 -39
  9. ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py +7 -44
  10. ai_edge_torch/generative/examples/phi/convert_phi4_to_tflite.py +7 -44
  11. ai_edge_torch/generative/examples/phi/convert_to_tflite.py +7 -42
  12. ai_edge_torch/generative/examples/qwen/convert_to_tflite.py +8 -45
  13. ai_edge_torch/generative/examples/qwen_vl/convert_to_tflite.py +8 -39
  14. ai_edge_torch/generative/examples/smollm/convert_to_tflite.py +8 -43
  15. ai_edge_torch/generative/examples/smollm/convert_v2_to_tflite.py +8 -43
  16. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +7 -44
  17. ai_edge_torch/generative/utilities/converter.py +45 -0
  18. ai_edge_torch/generative/utilities/loader.py +5 -1
  19. ai_edge_torch/version.py +1 -1
  20. {ai_edge_torch_nightly-0.4.0.dev20250330.dist-info → ai_edge_torch_nightly-0.4.0.dev20250401.dist-info}/METADATA +1 -1
  21. {ai_edge_torch_nightly-0.4.0.dev20250330.dist-info → ai_edge_torch_nightly-0.4.0.dev20250401.dist-info}/RECORD +24 -24
  22. {ai_edge_torch_nightly-0.4.0.dev20250330.dist-info → ai_edge_torch_nightly-0.4.0.dev20250401.dist-info}/LICENSE +0 -0
  23. {ai_edge_torch_nightly-0.4.0.dev20250330.dist-info → ai_edge_torch_nightly-0.4.0.dev20250401.dist-info}/WHEEL +0 -0
  24. {ai_edge_torch_nightly-0.4.0.dev20250330.dist-info → ai_edge_torch_nightly-0.4.0.dev20250401.dist-info}/top_level.txt +0 -0
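
The common thread in the diffs below: each example script deletes its block of per-script absl flag definitions in favor of a single call to converter.define_conversion_flags(...), the helper added to ai_edge_torch/generative/utilities/converter.py (+45 -0 above). That converter.py hunk is not reproduced in this diff, but the call sites constrain its shape: the scripts rebind the name flags to the helper's return value, then keep calling flags.DEFINE_enum(...) and reading flags.FLAGS.checkpoint_path, so the helper very likely registers the shared flags and returns the absl flags module. A minimal sketch under that assumption — defaults copied from the removed per-script definitions; the exact signature is unverified, and gemma3's old defaults differed (kv_cache_max_len=2048, quantize=False, prefill starting at 32), so the real helper may accept overrides:

# Hypothetical sketch of converter.define_conversion_flags; the actual
# implementation (+45 lines in converter.py) is not shown in this diff.
import os
import pathlib

from absl import flags as absl_flags


def define_conversion_flags(model_name: str):
  """Registers the flags shared by every convert_to_tflite.py example.

  Defaults mirror the per-script definitions removed in this diff. Returns
  the absl flags module so callers can add script-specific flags (e.g.
  gemma3's 'model_size') and read values via flags.FLAGS.
  """
  absl_flags.DEFINE_string(
      'checkpoint_path',
      os.path.join(pathlib.Path.home(), 'Downloads/llm_data', model_name),
      'The path to the model checkpoint, or directory holding the checkpoint.',
  )
  absl_flags.DEFINE_string(
      'output_path', '/tmp/', 'The path to export the tflite model.'
  )
  absl_flags.DEFINE_string(
      'output_name_prefix',
      model_name,
      'The prefix of the output tflite model name.',
  )
  absl_flags.DEFINE_multi_integer(
      'prefill_seq_lens',
      (8, 64, 128, 256, 512, 1024),
      'List of the maximum sizes of prefill input tensors.',
  )
  absl_flags.DEFINE_integer(
      'kv_cache_max_len',
      1280,
      'The maximum size of KV cache buffer, including both prefill and decode.',
  )
  absl_flags.DEFINE_bool(
      'quantize', True, 'Whether the model should be quantized.'
  )
  absl_flags.DEFINE_multi_integer(
      'lora_ranks',
      None,
      'If set, the model will be converted with the provided list of LoRA ranks.',
  )
  return absl_flags
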
ai_edge_torch/generative/examples/amd_llama_135m/convert_to_tflite.py
@@ -16,61 +16,25 @@
 """Example of converting AMD-Llama-135m model to multi-signature tflite model."""
 
 import os
-import pathlib
-
 from absl import app
 from absl import flags
 from ai_edge_torch.generative.examples.amd_llama_135m import amd_llama_135m
 from ai_edge_torch.generative.utilities import converter
 from ai_edge_torch.generative.utilities.model_builder import ExportConfig
 
-_CHECKPOINT_PATH = flags.DEFINE_string(
-    'checkpoint_path',
-    os.path.join(pathlib.Path.home(), 'Downloads/llm_data/amd-llama-135m'),
-    'The path to the model checkpoint, or directory holding the checkpoint.',
-)
-_KV_CACHE_MAX_LEN = flags.DEFINE_integer(
-    'kv_cache_max_len',
-    1280,
-    'The maximum size of KV cache buffer, including both prefill and decode.',
-)
-_OUTPUT_PATH = flags.DEFINE_string(
-    'output_path',
-    '/tmp/',
-    'The path to export the tflite model.',
-)
-_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
-    'output_name_prefix',
-    'amd_llama',
-    'The prefix of the output tflite model name.',
-)
-_PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
-    'prefill_seq_lens',
-    (8, 64, 128, 256, 512, 1024),
-    'List of the maximum sizes of prefill input tensors.',
-)
-_QUANTIZE = flags.DEFINE_bool(
-    'quantize',
-    True,
-    'Whether the model should be quantized.',
-)
-_LORA_RANKS = flags.DEFINE_multi_integer(
-    'lora_ranks',
-    None,
-    'If set, the model will be converted with the provided list of LoRA ranks.',
-)
+flags = converter.define_conversion_flags("amd-llama-135m")
 
 def main(_):
   pytorch_model = amd_llama_135m.build_model(
-      _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
+      flags.FLAGS.checkpoint_path, kv_cache_max_len=flags.FLAGS.kv_cache_max_len
   )
   converter.convert_to_tflite(
       pytorch_model,
-      output_path=_OUTPUT_PATH.value,
-      output_name_prefix=_OUTPUT_NAME_PREFIX.value,
-      prefill_seq_len=_PREFILL_SEQ_LENS.value,
-      quantize=_QUANTIZE.value,
-      lora_ranks=_LORA_RANKS.value,
+      output_path=flags.FLAGS.output_path,
+      output_name_prefix=flags.FLAGS.output_name_prefix,
+      prefill_seq_len=flags.FLAGS.prefill_seq_lens,
+      quantize=flags.FLAGS.quantize,
+      lora_ranks=flags.FLAGS.lora_ranks,
       export_config=ExportConfig(),
   )
 
ai_edge_torch/generative/examples/deepseek/convert_to_tflite.py
@@ -24,54 +24,19 @@ from ai_edge_torch.generative.examples.deepseek import deepseek
 from ai_edge_torch.generative.utilities import converter
 from ai_edge_torch.generative.utilities.model_builder import ExportConfig
 
-_CHECKPOINT_PATH = flags.DEFINE_string(
-    'checkpoint_path',
-    os.path.join(pathlib.Path.home(), 'Downloads/llm_data/deepseek'),
-    'The path to the model checkpoint, or directory holding the checkpoint.',
-)
-_OUTPUT_PATH = flags.DEFINE_string(
-    'output_path',
-    '/tmp/',
-    'The path to export the tflite model.',
-)
-_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
-    'output_name_prefix',
-    'deepseek',
-    'The prefix of the output tflite model name.',
-)
-_PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
-    'prefill_seq_lens',
-    (8, 64, 128, 256, 512, 1024),
-    'List of the maximum sizes of prefill input tensors.',
-)
-_KV_CACHE_MAX_LEN = flags.DEFINE_integer(
-    'kv_cache_max_len',
-    1280,
-    'The maximum size of KV cache buffer, including both prefill and decode.',
-)
-_QUANTIZE = flags.DEFINE_bool(
-    'quantize',
-    True,
-    'Whether the model should be quantized.',
-)
-_LORA_RANKS = flags.DEFINE_multi_integer(
-    'lora_ranks',
-    None,
-    'If set, the model will be converted with the provided list of LoRA ranks.',
-)
-
+flags = converter.define_conversion_flags("deepseek")
 
 def main(_):
   pytorch_model = deepseek.build_model(
-      _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
+      flags.FLAGS.checkpoint_path, kv_cache_max_len=flags.FLAGS.kv_cache_max_len
   )
   converter.convert_to_tflite(
       pytorch_model,
-      output_path=_OUTPUT_PATH.value,
-      output_name_prefix=_OUTPUT_NAME_PREFIX.value,
-      prefill_seq_len=_PREFILL_SEQ_LENS.value,
-      quantize=_QUANTIZE.value,
-      lora_ranks=_LORA_RANKS.value,
+      output_path=flags.FLAGS.output_path,
+      output_name_prefix=flags.FLAGS.output_name_prefix,
+      prefill_seq_len=flags.FLAGS.prefill_seq_lens,
+      quantize=flags.FLAGS.quantize,
+      lora_ranks=flags.FLAGS.lora_ranks,
       export_config=ExportConfig(),
   )
 
ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py
@@ -16,62 +16,24 @@
 """Example of converting a Gemma1 model to multi-signature tflite model."""
 
 import os
-import pathlib
-
 from absl import app
-from absl import flags
 from ai_edge_torch.generative.examples.gemma import gemma1
 from ai_edge_torch.generative.utilities import converter
 from ai_edge_torch.generative.utilities.model_builder import ExportConfig
 
-_CHECKPOINT_PATH = flags.DEFINE_string(
-    'checkpoint_path',
-    os.path.join(pathlib.Path.home(), 'Downloads/llm_data/gemma-2b'),
-    'The path to the model checkpoint, or directory holding the checkpoint.',
-)
-_OUTPUT_PATH = flags.DEFINE_string(
-    'output_path',
-    '/tmp/',
-    'The path to export the tflite model.',
-)
-_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
-    'output_name_prefix',
-    'gemma',
-    'The prefix of the output tflite model name.',
-)
-_PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
-    'prefill_seq_lens',
-    (8, 64, 128, 256, 512, 1024),
-    'List of the maximum sizes of prefill input tensors.',
-)
-_KV_CACHE_MAX_LEN = flags.DEFINE_integer(
-    'kv_cache_max_len',
-    1280,
-    'The maximum size of KV cache buffer, including both prefill and decode.',
-)
-_QUANTIZE = flags.DEFINE_bool(
-    'quantize',
-    True,
-    'Whether the model should be quantized.',
-)
-_LORA_RANKS = flags.DEFINE_multi_integer(
-    'lora_ranks',
-    None,
-    'If set, the model will be converted with the provided list of LoRA ranks.',
-)
-
+flags = converter.define_conversion_flags("gemma-2b")
 
 def main(_):
   pytorch_model = gemma1.build_2b_model(
-      _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
+      flags.FLAGS.checkpoint_path, kv_cache_max_len=flags.FLAGS.kv_cache_max_len
   )
   converter.convert_to_tflite(
       pytorch_model,
-      output_path=_OUTPUT_PATH.value,
-      output_name_prefix=_OUTPUT_NAME_PREFIX.value,
-      prefill_seq_len=_PREFILL_SEQ_LENS.value,
-      quantize=_QUANTIZE.value,
-      lora_ranks=_LORA_RANKS.value,
+      output_path=flags.FLAGS.output_path,
+      output_name_prefix=flags.FLAGS.output_name_prefix,
+      prefill_seq_len=flags.FLAGS.prefill_seq_lens,
+      quantize=flags.FLAGS.quantize,
+      lora_ranks=flags.FLAGS.lora_ranks,
       export_config=ExportConfig(),
   )
 
ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py
@@ -16,62 +16,25 @@
 """Example of converting a Gemma2 model to multi-signature tflite model."""
 
 import os
-import pathlib
-
 from absl import app
 from absl import flags
 from ai_edge_torch.generative.examples.gemma import gemma2
 from ai_edge_torch.generative.utilities import converter
 from ai_edge_torch.generative.utilities.model_builder import ExportConfig
 
-_CHECKPOINT_PATH = flags.DEFINE_string(
-    'checkpoint_path',
-    os.path.join(pathlib.Path.home(), 'Downloads/llm_data/gemma2-2b'),
-    'The path to the model checkpoint, or directory holding the checkpoint.',
-)
-_OUTPUT_PATH = flags.DEFINE_string(
-    'output_path',
-    '/tmp/',
-    'The path to export the tflite model.',
-)
-_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
-    'output_name_prefix',
-    'gemma2',
-    'The prefix of the output tflite model name.',
-)
-_PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
-    'prefill_seq_lens',
-    (8, 64, 128, 256, 512, 1024),
-    'List of the maximum sizes of prefill input tensors.',
-)
-_KV_CACHE_MAX_LEN = flags.DEFINE_integer(
-    'kv_cache_max_len',
-    1280,
-    'The maximum size of KV cache buffer, including both prefill and decode.',
-)
-_QUANTIZE = flags.DEFINE_bool(
-    'quantize',
-    True,
-    'Whether the model should be quantized.',
-)
-_LORA_RANKS = flags.DEFINE_multi_integer(
-    'lora_ranks',
-    None,
-    'If set, the model will be converted with the provided list of LoRA ranks.',
-)
-
+flags = converter.define_conversion_flags("gemma2-2b")
 
 def main(_):
   pytorch_model = gemma2.build_2b_model(
-      _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
+      flags.FLAGS.checkpoint_path, kv_cache_max_len=flags.FLAGS.kv_cache_max_len
   )
   converter.convert_to_tflite(
       pytorch_model,
-      output_path=_OUTPUT_PATH.value,
-      output_name_prefix=_OUTPUT_NAME_PREFIX.value,
-      prefill_seq_len=_PREFILL_SEQ_LENS.value,
-      quantize=_QUANTIZE.value,
-      lora_ranks=_LORA_RANKS.value,
+      output_path=flags.FLAGS.output_path,
+      output_name_prefix=flags.FLAGS.output_name_prefix,
+      prefill_seq_len=flags.FLAGS.prefill_seq_lens,
+      quantize=flags.FLAGS.quantize,
+      lora_ranks=flags.FLAGS.lora_ranks,
       export_config=ExportConfig(),
   )
 
ai_edge_torch/generative/examples/gemma3/convert_gemma3_to_tflite.py
@@ -16,8 +16,6 @@
 """Example of converting a Gemma3 model to multi-signature tflite model."""
 
 import os
-import pathlib
-
 from absl import app
 from absl import flags
 from ai_edge_torch.generative.examples.gemma3 import gemma3
@@ -26,48 +24,14 @@ from ai_edge_torch.generative.utilities import converter
 from ai_edge_torch.generative.utilities.model_builder import ExportConfig
 import torch
 
+flags = converter.define_conversion_flags('gemma3-1b')
+
 _MODEL_SIZE = flags.DEFINE_string(
     'model_size',
     '1b',
     'The size of the model to convert.',
 )
 
-_CHECKPOINT_PATH = flags.DEFINE_string(
-    'checkpoint_path',
-    os.path.join(pathlib.Path.home(), 'Downloads/llm_data/gemma3-1b'),
-    'The path to the model checkpoint, or directory holding the checkpoint.',
-)
-_OUTPUT_PATH = flags.DEFINE_string(
-    'output_path',
-    '/tmp/',
-    'The path to export the tflite model.',
-)
-_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
-    'output_name_prefix',
-    'gemma3',
-    'The prefix of the output tflite model name.',
-)
-_PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
-    'prefill_seq_lens',
-    (32, 64, 128, 256, 512, 1024),
-    'List of the maximum sizes of prefill input tensors.',
-)
-_KV_CACHE_MAX_LEN = flags.DEFINE_integer(
-    'kv_cache_max_len',
-    2048,
-    'The maximum size of KV cache buffer, including both prefill and decode.',
-)
-_QUANTIZE = flags.DEFINE_bool(
-    'quantize',
-    False,
-    'Whether the model should be quantized.',
-)
-_LORA_RANKS = flags.DEFINE_multi_integer(
-    'lora_ranks',
-    None,
-    'If set, the model will be converted with the provided list of LoRA ranks.',
-)
-
 
 def _create_mask(mask_len, kv_cache_max_len):
   mask = torch.full(
@@ -101,21 +65,22 @@ def _create_export_config(
 def main(_):
   if _MODEL_SIZE.value == '1b':
     pytorch_model = gemma3.build_model_1b(
-        _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
+        flags.FLAGS.checkpoint_path,
+        kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
     )
     config = pytorch_model.config
   else:
     raise ValueError(f'Unsupported model size: {_MODEL_SIZE.value}')
   converter.convert_to_tflite(
       pytorch_model,
-      output_path=_OUTPUT_PATH.value,
-      output_name_prefix=_OUTPUT_NAME_PREFIX.value,
-      prefill_seq_len=_PREFILL_SEQ_LENS.value,
-      quantize=_QUANTIZE.value,
+      output_path=flags.FLAGS.output_path,
+      output_name_prefix=flags.FLAGS.output_name_prefix,
+      prefill_seq_len=flags.FLAGS.prefill_seq_lens,
+      quantize=flags.FLAGS.quantize,
       config=config,
-      lora_ranks=_LORA_RANKS.value,
+      lora_ranks=flags.FLAGS.lora_ranks,
       export_config=_create_export_config(
-          _PREFILL_SEQ_LENS.value, _KV_CACHE_MAX_LEN.value
+          flags.FLAGS.prefill_seq_lens, flags.FLAGS.kv_cache_max_len
      ),
   )
 
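The gemma3 hunks above show only the first two lines of _create_mask, whose output _create_export_config builds per prefill length from the shared flags. A plausible completion of the truncated helper, assuming the standard causal-mask construction (everything past the torch.full( call is outside this diff):

# Hypothetical completion of _create_mask from the hunk context above;
# only the signature and the opening torch.full( appear in this diff.
import torch


def _create_mask(mask_len, kv_cache_max_len):
  # Start from an all -inf mask spanning the full KV cache width...
  mask = torch.full(
      (mask_len, kv_cache_max_len), float('-inf'), dtype=torch.float32
  )
  # ...then keep -inf only above the diagonal, so each token attends to
  # itself and to earlier positions (causal attention).
  mask = torch.triu(mask, diagonal=1)
  return mask
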
ai_edge_torch/generative/examples/llama/convert_to_tflite.py
@@ -16,55 +16,21 @@
 """Example of converting Llama 3.2 1B model to multi-signature tflite model."""
 
 import os
-import pathlib
-
 from absl import app
 from absl import flags
 from ai_edge_torch.generative.examples.llama import llama
 from ai_edge_torch.generative.utilities import converter
 from ai_edge_torch.generative.utilities.model_builder import ExportConfig
 
+
+flags = converter.define_conversion_flags('llama')
+
 _MODEL_SIZE = flags.DEFINE_enum(
     'model_size',
     '1b',
     ['1b', '3b'],
     'The size of the model to verify.',
 )
-_CHECKPOINT_PATH = flags.DEFINE_string(
-    'checkpoint_path',
-    os.path.join(pathlib.Path.home(), 'Downloads/llm_data/llama'),
-    'The path to the model checkpoint, or directory holding the checkpoint.',
-)
-_OUTPUT_PATH = flags.DEFINE_string(
-    'output_path',
-    '/tmp/',
-    'The path to export the tflite model.',
-)
-_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
-    'output_name_prefix',
-    'llama',
-    'The prefix of the output tflite model name.',
-)
-_PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
-    'prefill_seq_lens',
-    (8, 64, 128, 256, 512, 1024),
-    'List of the maximum sizes of prefill input tensors.',
-)
-_KV_CACHE_MAX_LEN = flags.DEFINE_integer(
-    'kv_cache_max_len',
-    1280,
-    'The maximum size of KV cache buffer, including both prefill and decode.',
-)
-_QUANTIZE = flags.DEFINE_bool(
-    'quantize',
-    True,
-    'Whether the model should be quantized.',
-)
-_LORA_RANKS = flags.DEFINE_multi_integer(
-    'lora_ranks',
-    None,
-    'If set, the model will be converted with the provided list of LoRA ranks.',
-)
 
 _BUILDER = {
     '1b': llama.build_1b_model,
@@ -74,15 +40,15 @@ _BUILDER = {
 
 def main(_):
   pytorch_model = _BUILDER[_MODEL_SIZE.value](
-      _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
+      flags.FLAGS.checkpoint_path, kv_cache_max_len=flags.FLAGS.kv_cache_max_len
   )
   converter.convert_to_tflite(
       pytorch_model,
-      output_path=_OUTPUT_PATH.value,
-      output_name_prefix=_OUTPUT_NAME_PREFIX.value,
-      prefill_seq_len=_PREFILL_SEQ_LENS.value,
-      quantize=_QUANTIZE.value,
-      lora_ranks=_LORA_RANKS.value,
+      output_path=flags.FLAGS.output_path,
+      output_name_prefix=flags.FLAGS.output_name_prefix,
+      prefill_seq_len=flags.FLAGS.prefill_seq_lens,
+      quantize=flags.FLAGS.quantize,
+      lora_ranks=flags.FLAGS.lora_ranks,
       export_config=ExportConfig(),
   )
 
ai_edge_torch/generative/examples/openelm/convert_to_tflite.py
@@ -16,62 +16,25 @@
 """Example of converting OpenELM model to multi-signature tflite model."""
 
 import os
-import pathlib
-
 from absl import app
 from absl import flags
 from ai_edge_torch.generative.examples.openelm import openelm
 from ai_edge_torch.generative.utilities import converter
 from ai_edge_torch.generative.utilities.model_builder import ExportConfig
 
-_CHECKPOINT_PATH = flags.DEFINE_string(
-    'checkpoint_path',
-    os.path.join(pathlib.Path.home(), 'Downloads/llm_data/openelm'),
-    'The path to the model checkpoint, or directory holding the checkpoint.',
-)
-_OUTPUT_PATH = flags.DEFINE_string(
-    'output_path',
-    '/tmp/',
-    'The path to export the tflite model.',
-)
-_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
-    'output_name_prefix',
-    'openelm',
-    'The prefix of the output tflite model name.',
-)
-_PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
-    'prefill_seq_lens',
-    (8, 64, 128, 256, 512, 1024),
-    'List of the maximum sizes of prefill input tensors.',
-)
-_KV_CACHE_MAX_LEN = flags.DEFINE_integer(
-    'kv_cache_max_len',
-    1280,
-    'The maximum size of KV cache buffer, including both prefill and decode.',
-)
-_QUANTIZE = flags.DEFINE_bool(
-    'quantize',
-    True,
-    'Whether the model should be quantized.',
-)
-_LORA_RANKS = flags.DEFINE_multi_integer(
-    'lora_ranks',
-    None,
-    'If set, the model will be converted with the provided list of LoRA ranks.',
-)
-
+flags = converter.define_conversion_flags("openelm")
 
 def main(_):
   pytorch_model = openelm.build_model(
-      _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
+      flags.FLAGS.checkpoint_path, kv_cache_max_len=flags.FLAGS.kv_cache_max_len
   )
   converter.convert_to_tflite(
       pytorch_model,
-      output_path=_OUTPUT_PATH.value,
-      output_name_prefix=_OUTPUT_NAME_PREFIX.value,
-      prefill_seq_len=_PREFILL_SEQ_LENS.value,
-      quantize=_QUANTIZE.value,
-      lora_ranks=_LORA_RANKS.value,
+      output_path=flags.FLAGS.output_path,
+      output_name_prefix=flags.FLAGS.output_name_prefix,
+      prefill_seq_len=flags.FLAGS.prefill_seq_lens,
+      quantize=flags.FLAGS.quantize,
+      lora_ranks=flags.FLAGS.lora_ranks,
       export_config=ExportConfig(),
   )
 
ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py
@@ -16,8 +16,6 @@
 """Example of converting a PaliGemma model to multi-signature tflite model."""
 
 import os
-import pathlib
-
 from absl import app
 from absl import flags
 from ai_edge_torch.generative.examples.paligemma import paligemma
@@ -25,61 +23,32 @@ from ai_edge_torch.generative.utilities import converter
 from ai_edge_torch.generative.utilities.model_builder import ExportConfig
 import torch
 
+flags = converter.define_conversion_flags('paligemma2-3b-224')
+
 _VERSION = flags.DEFINE_enum(
     'version',
     '2',
     ['1', '2'],
     'The version of PaliGemma model to verify.',
 )
-_CHECKPOINT_PATH = flags.DEFINE_string(
-    'checkpoint_path',
-    os.path.join(pathlib.Path.home(), 'Downloads/llm_data/paligemma2-3b-224'),
-    'The path to the model checkpoint, or directory holding the checkpoint.',
-)
-_OUTPUT_PATH = flags.DEFINE_string(
-    'output_path',
-    '/tmp/',
-    'The path to export the tflite model.',
-)
-_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
-    'output_name_prefix',
-    'paligemma',
-    'The prefix of the output tflite model name.',
-)
-_PREFILL_SEQ_LEN = flags.DEFINE_integer(
-    'prefill_seq_len',
-    1024,
-    'The maximum size of prefill input tensor.',
-)
-_KV_CACHE_MAX_LEN = flags.DEFINE_integer(
-    'kv_cache_max_len',
-    1280,
-    'The maximum size of KV cache buffer, including both prefill and decode.',
-)
-_QUANTIZE = flags.DEFINE_bool(
-    'quantize',
-    True,
-    'Whether the model should be quantized.',
-)
-
 
 def main(_):
   pytorch_model = paligemma.build_model(
-      _CHECKPOINT_PATH.value,
+      flags.FLAGS.checkpoint_path,
       version=int(_VERSION.value),
-      kv_cache_max_len=_KV_CACHE_MAX_LEN.value,
+      kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
   )
 
   config = pytorch_model.image_encoder.config.image_embedding
   converter.convert_to_tflite(
       pytorch_model,
-      output_path=_OUTPUT_PATH.value,
-      output_name_prefix=f'{_OUTPUT_NAME_PREFIX.value}_{_VERSION.value}',
-      prefill_seq_len=_PREFILL_SEQ_LEN.value,
+      output_path=flags.FLAGS.output_path,
+      output_name_prefix=f'{flags.FLAGS.output_name_prefix}_{_VERSION.value}',
+      prefill_seq_len=flags.FLAGS.prefill_seq_lens,
       pixel_values_size=torch.Size(
           [1, config.channels, config.image_size, config.image_size]
       ),
-      quantize=_QUANTIZE.value,
+      quantize=flags.FLAGS.quantize,
       config=pytorch_model.config.decoder_config,
       export_config=ExportConfig(),
   )
ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py
@@ -16,62 +16,25 @@
 """Example of converting a Phi-3.5 model to multi-signature tflite model."""
 
 import os
-import pathlib
-
 from absl import app
 from absl import flags
 from ai_edge_torch.generative.examples.phi import phi3
 from ai_edge_torch.generative.utilities import converter
 from ai_edge_torch.generative.utilities.model_builder import ExportConfig
 
-_CHECKPOINT_PATH = flags.DEFINE_string(
-    'checkpoint_path',
-    os.path.join(pathlib.Path.home(), 'Downloads/llm_data/phi3'),
-    'The path to the model checkpoint, or directory holding the checkpoint.',
-)
-_OUTPUT_PATH = flags.DEFINE_string(
-    'output_path',
-    '/tmp/',
-    'The path to export the tflite model.',
-)
-_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
-    'output_name_prefix',
-    'phi3',
-    'The prefix of the output tflite model name.',
-)
-_PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
-    'prefill_seq_lens',
-    (8, 64, 128, 256, 512, 1024),
-    'List of the maximum sizes of prefill input tensors.',
-)
-_KV_CACHE_MAX_LEN = flags.DEFINE_integer(
-    'kv_cache_max_len',
-    1280,
-    'The maximum size of KV cache buffer, including both prefill and decode.',
-)
-_QUANTIZE = flags.DEFINE_bool(
-    'quantize',
-    True,
-    'Whether the model should be quantized.',
-)
-_LORA_RANKS = flags.DEFINE_multi_integer(
-    'lora_ranks',
-    None,
-    'If set, the model will be converted with the provided list of LoRA ranks.',
-)
-
+flags = converter.define_conversion_flags("phi3")
 
 def main(_):
   pytorch_model = phi3.build_model(
-      _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
+      flags.FLAGS.checkpoint_path, kv_cache_max_len=flags.FLAGS.kv_cache_max_len
   )
   converter.convert_to_tflite(
       pytorch_model,
-      output_path=_OUTPUT_PATH.value,
-      output_name_prefix=_OUTPUT_NAME_PREFIX.value,
-      prefill_seq_len=_PREFILL_SEQ_LENS.value,
-      quantize=_QUANTIZE.value,
-      lora_ranks=_LORA_RANKS.value,
+      output_path=flags.FLAGS.output_path,
+      output_name_prefix=flags.FLAGS.output_name_prefix,
+      prefill_seq_len=flags.FLAGS.prefill_seq_lens,
+      quantize=flags.FLAGS.quantize,
+      lora_ranks=flags.FLAGS.lora_ranks,
       export_config=ExportConfig(),
   )
 
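
With the shared flags in place, every example converter accepts the same command line; only the defaults baked in by define_conversion_flags differ per model. An illustrative invocation (paths hypothetical; flag names taken from the flags.FLAGS.* accesses above; absl multi-integer flags are passed by repeating the flag):

python ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py \
    --checkpoint_path=/path/to/phi3 \
    --output_path=/tmp/ \
    --output_name_prefix=phi3 \
    --prefill_seq_lens=128 --prefill_seq_lens=1024 \
    --kv_cache_max_len=1280 \
    --quantize=true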