ai-edge-torch-nightly 0.4.0.dev20250330__py3-none-any.whl → 0.4.0.dev20250331__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (23)
  1. ai_edge_torch/generative/examples/amd_llama_135m/convert_to_tflite.py +7 -43
  2. ai_edge_torch/generative/examples/deepseek/convert_to_tflite.py +7 -42
  3. ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py +7 -45
  4. ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +7 -44
  5. ai_edge_torch/generative/examples/gemma3/convert_gemma3_to_tflite.py +10 -45
  6. ai_edge_torch/generative/examples/llama/convert_to_tflite.py +9 -43
  7. ai_edge_torch/generative/examples/openelm/convert_to_tflite.py +7 -44
  8. ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py +8 -39
  9. ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py +7 -44
  10. ai_edge_torch/generative/examples/phi/convert_phi4_to_tflite.py +7 -44
  11. ai_edge_torch/generative/examples/phi/convert_to_tflite.py +7 -42
  12. ai_edge_torch/generative/examples/qwen/convert_to_tflite.py +8 -45
  13. ai_edge_torch/generative/examples/qwen_vl/convert_to_tflite.py +8 -39
  14. ai_edge_torch/generative/examples/smollm/convert_to_tflite.py +8 -43
  15. ai_edge_torch/generative/examples/smollm/convert_v2_to_tflite.py +8 -43
  16. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +7 -44
  17. ai_edge_torch/generative/utilities/converter.py +45 -0
  18. ai_edge_torch/version.py +1 -1
  19. {ai_edge_torch_nightly-0.4.0.dev20250330.dist-info → ai_edge_torch_nightly-0.4.0.dev20250331.dist-info}/METADATA +1 -1
  20. {ai_edge_torch_nightly-0.4.0.dev20250330.dist-info → ai_edge_torch_nightly-0.4.0.dev20250331.dist-info}/RECORD +23 -23
  21. {ai_edge_torch_nightly-0.4.0.dev20250330.dist-info → ai_edge_torch_nightly-0.4.0.dev20250331.dist-info}/LICENSE +0 -0
  22. {ai_edge_torch_nightly-0.4.0.dev20250330.dist-info → ai_edge_torch_nightly-0.4.0.dev20250331.dist-info}/WHEEL +0 -0
  23. {ai_edge_torch_nightly-0.4.0.dev20250330.dist-info → ai_edge_torch_nightly-0.4.0.dev20250331.dist-info}/top_level.txt +0 -0
@@ -16,62 +16,25 @@
16
16
  """Example of converting a Phi-4 model to multi-signature tflite model."""
17
17
 
18
18
  import os
19
- import pathlib
20
-
21
19
  from absl import app
22
20
  from absl import flags
23
21
  from ai_edge_torch.generative.examples.phi import phi4
24
22
  from ai_edge_torch.generative.utilities import converter
25
23
  from ai_edge_torch.generative.utilities.model_builder import ExportConfig
26
24
 
27
- _CHECKPOINT_PATH = flags.DEFINE_string(
28
- 'checkpoint_path',
29
- os.path.join(pathlib.Path.home(), 'Downloads/llm_data/phi4'),
30
- 'The path to the model checkpoint, or directory holding the checkpoint.',
31
- )
32
- _OUTPUT_PATH = flags.DEFINE_string(
33
- 'output_path',
34
- '/tmp/',
35
- 'The path to export the tflite model.',
36
- )
37
- _OUTPUT_NAME_PREFIX = flags.DEFINE_string(
38
- 'output_name_prefix',
39
- 'phi4',
40
- 'The prefix of the output tflite model name.',
41
- )
42
- _PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
43
- 'prefill_seq_lens',
44
- (8, 64, 128, 256, 512, 1024),
45
- 'List of the maximum sizes of prefill input tensors.',
46
- )
47
- _KV_CACHE_MAX_LEN = flags.DEFINE_integer(
48
- 'kv_cache_max_len',
49
- 1280,
50
- 'The maximum size of KV cache buffer, including both prefill and decode.',
51
- )
52
- _QUANTIZE = flags.DEFINE_bool(
53
- 'quantize',
54
- True,
55
- 'Whether the model should be quantized.',
56
- )
57
- _LORA_RANKS = flags.DEFINE_multi_integer(
58
- 'lora_ranks',
59
- None,
60
- 'If set, the model will be converted with the provided list of LoRA ranks.',
61
- )
62
-
25
+ flags = converter.define_conversion_flags("phi4")
63
26
 
64
27
  def main(_):
65
28
  pytorch_model = phi4.build_model(
66
- _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
29
+ flags.FLAGS.checkpoint_path, kv_cache_max_len=flags.FLAGS.kv_cache_max_len
67
30
  )
68
31
  converter.convert_to_tflite(
69
32
  pytorch_model,
70
- output_path=_OUTPUT_PATH.value,
71
- output_name_prefix=_OUTPUT_NAME_PREFIX.value,
72
- prefill_seq_len=_PREFILL_SEQ_LENS.value,
73
- quantize=_QUANTIZE.value,
74
- lora_ranks=_LORA_RANKS.value,
33
+ output_path=flags.FLAGS.output_path,
34
+ output_name_prefix=flags.FLAGS.output_name_prefix,
35
+ prefill_seq_len=flags.FLAGS.prefill_seq_lens,
36
+ quantize=flags.FLAGS.quantize,
37
+ lora_ranks=flags.FLAGS.lora_ranks,
75
38
  export_config=ExportConfig(),
76
39
  )
77
40
 
@@ -24,54 +24,19 @@ from ai_edge_torch.generative.examples.phi import phi2
24
24
  from ai_edge_torch.generative.utilities import converter
25
25
  from ai_edge_torch.generative.utilities.model_builder import ExportConfig
26
26
 
27
- _CHECKPOINT_PATH = flags.DEFINE_string(
28
- 'checkpoint_path',
29
- os.path.join(pathlib.Path.home(), 'Downloads/llm_data/phi2'),
30
- 'The path to the model checkpoint, or directory holding the checkpoint.',
31
- )
32
- _OUTPUT_PATH = flags.DEFINE_string(
33
- 'output_path',
34
- '/tmp/',
35
- 'The path to export the tflite model.',
36
- )
37
- _OUTPUT_NAME_PREFIX = flags.DEFINE_string(
38
- 'output_name_prefix',
39
- 'phi2',
40
- 'The prefix of the output tflite model name.',
41
- )
42
- _PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
43
- 'prefill_seq_lens',
44
- (8, 64, 128, 256, 512, 1024),
45
- 'List of the maximum sizes of prefill input tensors.',
46
- )
47
- _KV_CACHE_MAX_LEN = flags.DEFINE_integer(
48
- 'kv_cache_max_len',
49
- 1280,
50
- 'The maximum size of KV cache buffer, including both prefill and decode.',
51
- )
52
- _QUANTIZE = flags.DEFINE_bool(
53
- 'quantize',
54
- True,
55
- 'Whether the model should be quantized.',
56
- )
57
- _LORA_RANKS = flags.DEFINE_multi_integer(
58
- 'lora_ranks',
59
- None,
60
- 'If set, the model will be converted with the provided list of LoRA ranks.',
61
- )
62
-
27
+ flags = converter.define_conversion_flags("phi2")
63
28
 
64
29
  def main(_):
65
30
  pytorch_model = phi2.build_model(
66
- _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
31
+ flags.FLAGS.checkpoint_path, kv_cache_max_len=flags.FLAGS.kv_cache_max_len
67
32
  )
68
33
  converter.convert_to_tflite(
69
34
  pytorch_model,
70
- output_path=_OUTPUT_PATH.value,
71
- output_name_prefix=_OUTPUT_NAME_PREFIX.value,
72
- prefill_seq_len=_PREFILL_SEQ_LENS.value,
73
- quantize=_QUANTIZE.value,
74
- lora_ranks=_LORA_RANKS.value,
35
+ output_path=flags.FLAGS.output_path,
36
+ output_name_prefix=flags.FLAGS.output_name_prefix,
37
+ prefill_seq_len=flags.FLAGS.prefill_seq_lens,
38
+ quantize=flags.FLAGS.quantize,
39
+ lora_ranks=flags.FLAGS.lora_ranks,
75
40
  export_config=ExportConfig(),
76
41
  )
77
42
 
@@ -16,56 +16,20 @@
16
16
  """Example of converting Qwen 2.5 models to multi-signature tflite model."""
17
17
 
18
18
  import os
19
- import pathlib
20
-
21
19
  from absl import app
22
20
  from absl import flags
23
21
  from ai_edge_torch.generative.examples.qwen import qwen
24
22
  from ai_edge_torch.generative.utilities import converter
25
23
  from ai_edge_torch.generative.utilities.model_builder import ExportConfig
26
24
 
25
+ flags = converter.define_conversion_flags('qwen')
26
+
27
27
  _MODEL_SIZE = flags.DEFINE_enum(
28
28
  'model_size',
29
29
  '3b',
30
30
  ['0.5b', '1.5b', '3b'],
31
31
  'The size of the model to convert.',
32
32
  )
33
- _CHECKPOINT_PATH = flags.DEFINE_string(
34
- 'checkpoint_path',
35
- os.path.join(pathlib.Path.home(), 'Downloads/llm_data/qwen'),
36
- 'The path to the model checkpoint, or directory holding the checkpoint.',
37
- )
38
- _OUTPUT_PATH = flags.DEFINE_string(
39
- 'output_path',
40
- '/tmp/',
41
- 'The path to export the tflite model.',
42
- )
43
- _OUTPUT_NAME_PREFIX = flags.DEFINE_string(
44
- 'output_name_prefix',
45
- 'qwen',
46
- 'The prefix of the output tflite model name.',
47
- )
48
- _PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
49
- 'prefill_seq_lens',
50
- (8, 64, 128, 256, 512, 1024),
51
- 'List of the maximum sizes of prefill input tensors.',
52
- )
53
- _KV_CACHE_MAX_LEN = flags.DEFINE_integer(
54
- 'kv_cache_max_len',
55
- 1280,
56
- 'The maximum size of KV cache buffer, including both prefill and decode.',
57
- )
58
- _QUANTIZE = flags.DEFINE_bool(
59
- 'quantize',
60
- True,
61
- 'Whether the model should be quantized.',
62
- )
63
- _LORA_RANKS = flags.DEFINE_multi_integer(
64
- 'lora_ranks',
65
- None,
66
- 'If set, the model will be converted with the provided list of LoRA ranks.',
67
- )
68
-
69
33
 
70
34
  _BUILDER = {
71
35
  '0.5b': qwen.build_0_5b_model,
@@ -73,18 +37,17 @@ _BUILDER = {
73
37
  '3b': qwen.build_3b_model,
74
38
  }
75
39
 
76
-
77
40
  def main(_):
78
41
  pytorch_model = _BUILDER[_MODEL_SIZE.value](
79
- _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
42
+ flags.FLAGS.checkpoint_path, kv_cache_max_len=flags.FLAGS.kv_cache_max_len
80
43
  )
81
44
  converter.convert_to_tflite(
82
45
  pytorch_model,
83
- output_path=_OUTPUT_PATH.value,
84
- output_name_prefix=_OUTPUT_NAME_PREFIX.value,
85
- prefill_seq_len=_PREFILL_SEQ_LENS.value,
86
- quantize=_QUANTIZE.value,
87
- lora_ranks=_LORA_RANKS.value,
46
+ output_path=flags.FLAGS.output_path,
47
+ output_name_prefix=flags.FLAGS.output_name_prefix,
48
+ prefill_seq_len=flags.FLAGS.prefill_seq_lens,
49
+ quantize=flags.FLAGS.quantize,
50
+ lora_ranks=flags.FLAGS.lora_ranks,
88
51
  export_config=ExportConfig(),
89
52
  )
90
53
 
@@ -16,39 +16,14 @@
16
16
  """Example of converting a Qwen 2.5 VL model to multi-signature tflite model."""
17
17
 
18
18
  import os
19
- import pathlib
20
-
21
19
  from absl import app
22
20
  from absl import flags
23
21
  from ai_edge_torch.generative.examples.qwen_vl import qwen_vl
24
22
  from ai_edge_torch.generative.utilities import converter
25
23
  from ai_edge_torch.generative.utilities.model_builder import ExportConfig
26
24
 
27
- _CHECKPOINT_PATH = flags.DEFINE_string(
28
- 'checkpoint_path',
29
- os.path.join(pathlib.Path.home(), 'Downloads/llm_data/qwen-vl'),
30
- 'The path to the model checkpoint, or directory holding the checkpoint.',
31
- )
32
- _OUTPUT_PATH = flags.DEFINE_string(
33
- 'output_path',
34
- '/tmp/',
35
- 'The path to export the tflite model.',
36
- )
37
- _OUTPUT_NAME_PREFIX = flags.DEFINE_string(
38
- 'output_name_prefix',
39
- 'qwen_vl',
40
- 'The prefix of the output tflite model name.',
41
- )
42
- _PREFILL_SEQ_LEN = flags.DEFINE_integer(
43
- 'prefill_seq_len',
44
- 1024,
45
- 'The maximum size of prefill input tensor.',
46
- )
47
- _KV_CACHE_MAX_LEN = flags.DEFINE_integer(
48
- 'kv_cache_max_len',
49
- 1280,
50
- 'The maximum size of KV cache buffer, including both prefill and decode.',
51
- )
25
+ flags = converter.define_conversion_flags('qwen_vl')
26
+
52
27
  _IMAGE_HEIGHT = flags.DEFINE_integer(
53
28
  'image_height',
54
29
  34 * 14,
@@ -59,30 +34,24 @@ _IMAGE_WIDTH = flags.DEFINE_integer(
59
34
  46 * 14,
60
35
  'The width of image.',
61
36
  )
62
- _QUANTIZE = flags.DEFINE_bool(
63
- 'quantize',
64
- True,
65
- 'Whether the model should be quantized.',
66
- )
67
-
68
37
 
69
38
  def main(_):
70
39
  pytorch_model = qwen_vl.build_model(
71
- _CHECKPOINT_PATH.value,
72
- kv_cache_max_len=_KV_CACHE_MAX_LEN.value,
40
+ flags.FLAGS.checkpoint_path,
41
+ kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
73
42
  image_size=(_IMAGE_HEIGHT.value, _IMAGE_WIDTH.value),
74
43
  )
75
44
 
76
45
  grid_thw = pytorch_model.image_encoder.get_grid_thw()
77
46
  converter.convert_to_tflite(
78
47
  pytorch_model,
79
- output_path=_OUTPUT_PATH.value,
80
- output_name_prefix=_OUTPUT_NAME_PREFIX.value,
81
- prefill_seq_len=_PREFILL_SEQ_LEN.value,
48
+ output_path=flags.FLAGS.output_path,
49
+ output_name_prefix=flags.FLAGS.output_name_prefix,
50
+ prefill_seq_len=flags.FLAGS.prefill_seq_lens,
82
51
  pixel_values_size=(
83
52
  pytorch_model.image_encoder.get_pixel_values_size(grid_thw)
84
53
  ),
85
- quantize=_QUANTIZE.value,
54
+ quantize=flags.FLAGS.quantize,
86
55
  config=pytorch_model.config.decoder_config,
87
56
  export_config=ExportConfig(),
88
57
  )
@@ -16,49 +16,14 @@
16
16
  """Example of converting SmolLM model to multi-signature tflite model."""
17
17
 
18
18
  import os
19
- import pathlib
20
-
21
19
  from absl import app
22
20
  from absl import flags
23
21
  from ai_edge_torch.generative.examples.smollm import smollm
24
22
  from ai_edge_torch.generative.utilities import converter
25
23
  from ai_edge_torch.generative.utilities import model_builder
26
24
 
27
- _CHECKPOINT_PATH = flags.DEFINE_string(
28
- 'checkpoint_path',
29
- os.path.join(pathlib.Path.home(), 'Downloads/llm_data/smollm'),
30
- 'The path to the model checkpoint, or directory holding the checkpoint.',
31
- )
32
- _OUTPUT_PATH = flags.DEFINE_string(
33
- 'output_path',
34
- '/tmp/',
35
- 'The path to export the tflite model.',
36
- )
37
- _OUTPUT_NAME_PREFIX = flags.DEFINE_string(
38
- 'output_name_prefix',
39
- 'smollm',
40
- 'The prefix of the output tflite model name.',
41
- )
42
- _PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
43
- 'prefill_seq_lens',
44
- (8, 64, 128, 256, 512, 1024),
45
- 'List of the maximum sizes of prefill input tensors.',
46
- )
47
- _KV_CACHE_MAX_LEN = flags.DEFINE_integer(
48
- 'kv_cache_max_len',
49
- 1280,
50
- 'The maximum size of KV cache buffer, including both prefill and decode.',
51
- )
52
- _QUANTIZE = flags.DEFINE_bool(
53
- 'quantize',
54
- True,
55
- 'Whether the model should be quantized.',
56
- )
57
- _LORA_RANKS = flags.DEFINE_multi_integer(
58
- 'lora_ranks',
59
- None,
60
- 'If set, the model will be converted with the provided list of LoRA ranks.',
61
- )
25
+ flags = converter.define_conversion_flags('smollm')
26
+
62
27
  _DECODE_BATCH_SIZE = flags.DEFINE_integer(
63
28
  'decode_batch_size',
64
29
  1,
@@ -68,15 +33,15 @@ _DECODE_BATCH_SIZE = flags.DEFINE_integer(
68
33
 
69
34
  def main(_):
70
35
  pytorch_model = smollm.build_model(
71
- _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
36
+ flags.FLAGS.checkpoint_path, kv_cache_max_len=flags.FLAGS.kv_cache_max_len
72
37
  )
73
38
  converter.convert_to_tflite(
74
39
  pytorch_model,
75
- output_path=_OUTPUT_PATH.value,
76
- output_name_prefix=_OUTPUT_NAME_PREFIX.value,
77
- prefill_seq_len=_PREFILL_SEQ_LENS.value,
78
- quantize=_QUANTIZE.value,
79
- lora_ranks=_LORA_RANKS.value,
40
+ output_path=flags.FLAGS.output_path,
41
+ output_name_prefix=flags.FLAGS.output_name_prefix,
42
+ prefill_seq_len=flags.FLAGS.prefill_seq_lens,
43
+ quantize=flags.FLAGS.quantize,
44
+ lora_ranks=flags.FLAGS.lora_ranks,
80
45
  export_config=model_builder.ExportConfig(
81
46
  decode_batch_size=_DECODE_BATCH_SIZE.value
82
47
  ),
@@ -16,49 +16,14 @@
16
16
  """Example of converting SmolLM2 model to multi-signature tflite model."""
17
17
 
18
18
  import os
19
- import pathlib
20
-
21
19
  from absl import app
22
20
  from absl import flags
23
21
  from ai_edge_torch.generative.examples.smollm import smollm
24
22
  from ai_edge_torch.generative.utilities import converter
25
23
  from ai_edge_torch.generative.utilities import model_builder
26
24
 
27
- _CHECKPOINT_PATH = flags.DEFINE_string(
28
- 'checkpoint_path',
29
- os.path.join(pathlib.Path.home(), 'Downloads/llm_data/smollm2'),
30
- 'The path to the model checkpoint, or directory holding the checkpoint.',
31
- )
32
- _OUTPUT_PATH = flags.DEFINE_string(
33
- 'output_path',
34
- '/tmp/',
35
- 'The path to export the tflite model.',
36
- )
37
- _OUTPUT_NAME_PREFIX = flags.DEFINE_string(
38
- 'output_name_prefix',
39
- 'smollm2',
40
- 'The prefix of the output tflite model name.',
41
- )
42
- _PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
43
- 'prefill_seq_lens',
44
- (8, 64, 128, 256, 512, 1024),
45
- 'List of the maximum sizes of prefill input tensors.',
46
- )
47
- _KV_CACHE_MAX_LEN = flags.DEFINE_integer(
48
- 'kv_cache_max_len',
49
- 1280,
50
- 'The maximum size of KV cache buffer, including both prefill and decode.',
51
- )
52
- _QUANTIZE = flags.DEFINE_bool(
53
- 'quantize',
54
- True,
55
- 'Whether the model should be quantized.',
56
- )
57
- _LORA_RANKS = flags.DEFINE_multi_integer(
58
- 'lora_ranks',
59
- None,
60
- 'If set, the model will be converted with the provided list of LoRA ranks.',
61
- )
25
+ flags = converter.define_conversion_flags('smollm2')
26
+
62
27
  _DECODE_BATCH_SIZE = flags.DEFINE_integer(
63
28
  'decode_batch_size',
64
29
  1,
@@ -68,16 +33,16 @@ _DECODE_BATCH_SIZE = flags.DEFINE_integer(
68
33
 
69
34
  def main(_):
70
35
  pytorch_model = smollm.build_model_v2(
71
- _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
36
+ flags.FLAGS.checkpoint_path, kv_cache_max_len=flags.FLAGS.kv_cache_max_len
72
37
  )
73
38
 
74
39
  converter.convert_to_tflite(
75
40
  pytorch_model,
76
- output_path=_OUTPUT_PATH.value,
77
- output_name_prefix=_OUTPUT_NAME_PREFIX.value,
78
- prefill_seq_len=_PREFILL_SEQ_LENS.value,
79
- quantize=_QUANTIZE.value,
80
- lora_ranks=_LORA_RANKS.value,
41
+ output_path=flags.FLAGS.output_path,
42
+ output_name_prefix=flags.FLAGS.output_name_prefix,
43
+ prefill_seq_len=flags.FLAGS.prefill_seq_lens,
44
+ quantize=flags.FLAGS.quantize,
45
+ lora_ranks=flags.FLAGS.lora_ranks,
81
46
  export_config=model_builder.ExportConfig(
82
47
  decode_batch_size=_DECODE_BATCH_SIZE.value
83
48
  ),
@@ -16,62 +16,25 @@
16
16
  """Example of converting TinyLlama model to multi-signature tflite model."""
17
17
 
18
18
  import os
19
- import pathlib
20
-
21
19
  from absl import app
22
20
  from absl import flags
23
21
  from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
24
22
  from ai_edge_torch.generative.utilities import converter
25
23
  from ai_edge_torch.generative.utilities.model_builder import ExportConfig
26
24
 
27
- _CHECKPOINT_PATH = flags.DEFINE_string(
28
- 'checkpoint_path',
29
- os.path.join(pathlib.Path.home(), 'Downloads/llm_data/tiny_llama'),
30
- 'The path to the model checkpoint, or directory holding the checkpoint.',
31
- )
32
- _OUTPUT_PATH = flags.DEFINE_string(
33
- 'output_path',
34
- '/tmp/',
35
- 'The path to export the tflite model.',
36
- )
37
- _OUTPUT_NAME_PREFIX = flags.DEFINE_string(
38
- 'output_name_prefix',
39
- 'tinyllama',
40
- 'The prefix of the output tflite model name.',
41
- )
42
- _PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
43
- 'prefill_seq_lens',
44
- (8, 64, 128, 256, 512, 1024),
45
- 'List of the maximum sizes of prefill input tensors.',
46
- )
47
- _KV_CACHE_MAX_LEN = flags.DEFINE_integer(
48
- 'kv_cache_max_len',
49
- 1280,
50
- 'The maximum size of KV cache buffer, including both prefill and decode.',
51
- )
52
- _QUANTIZE = flags.DEFINE_bool(
53
- 'quantize',
54
- True,
55
- 'Whether the model should be quantized.',
56
- )
57
- _LORA_RANKS = flags.DEFINE_multi_integer(
58
- 'lora_ranks',
59
- None,
60
- 'If set, the model will be converted with the provided list of LoRA ranks.',
61
- )
62
-
25
+ flags = converter.define_conversion_flags("tiny_llama")
63
26
 
64
27
  def main(_):
65
28
  pytorch_model = tiny_llama.build_model(
66
- _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
29
+ flags.FLAGS.checkpoint_path, kv_cache_max_len=flags.FLAGS.kv_cache_max_len
67
30
  )
68
31
  converter.convert_to_tflite(
69
32
  pytorch_model,
70
- output_path=_OUTPUT_PATH.value,
71
- output_name_prefix=_OUTPUT_NAME_PREFIX.value,
72
- prefill_seq_len=_PREFILL_SEQ_LENS.value,
73
- quantize=_QUANTIZE.value,
74
- lora_ranks=_LORA_RANKS.value,
33
+ output_path=flags.FLAGS.output_path,
34
+ output_name_prefix=flags.FLAGS.output_name_prefix,
35
+ prefill_seq_len=flags.FLAGS.prefill_seq_lens,
36
+ quantize=flags.FLAGS.quantize,
37
+ lora_ranks=flags.FLAGS.lora_ranks,
75
38
  export_config=ExportConfig(),
76
39
  )
77
40
 
@@ -16,7 +16,9 @@
16
16
  """Common utility functions for model conversion."""
17
17
 
18
18
  import os
19
+ import pathlib
19
20
  from typing import Optional, Union
21
+ from absl import flags
20
22
  from ai_edge_torch._convert import converter as converter_utils
21
23
  from ai_edge_torch.generative.layers import lora as lora_utils
22
24
  import ai_edge_torch.generative.layers.model_config as cfg
@@ -37,6 +39,49 @@ class ExportableModule(torch.nn.Module):
37
39
  return self.module(*export_args, **full_kwargs)
38
40
 
39
41
 
42
+ def define_conversion_flags(model_name: str):
43
+ """Defines common flags used for model conversion."""
44
+
45
+ flags.DEFINE_string(
46
+ 'checkpoint_path',
47
+ os.path.join(pathlib.Path.home(), f'Downloads/llm_data/{model_name}'),
48
+ 'The path to the model checkpoint, or directory holding the checkpoint.',
49
+ )
50
+ flags.DEFINE_string(
51
+ 'output_path',
52
+ '/tmp/',
53
+ 'The path to export the tflite model.',
54
+ )
55
+ flags.DEFINE_string(
56
+ 'output_name_prefix',
57
+ 'qwen',
58
+ 'The prefix of the output tflite model name.',
59
+ )
60
+ flags.DEFINE_multi_integer(
61
+ 'prefill_seq_lens',
62
+ (8, 64, 128, 256, 512, 1024),
63
+ 'List of the maximum sizes of prefill input tensors.',
64
+ )
65
+ flags.DEFINE_integer(
66
+ 'kv_cache_max_len',
67
+ 1280,
68
+ 'The maximum size of KV cache buffer, including both prefill and decode.',
69
+ )
70
+ flags.DEFINE_bool(
71
+ 'quantize',
72
+ True,
73
+ 'Whether the model should be quantized.',
74
+ )
75
+ flags.DEFINE_multi_integer(
76
+ 'lora_ranks',
77
+ None,
78
+ 'If set, the model will be converted with the provided list of LoRA'
79
+ ' ranks.',
80
+ )
81
+
82
+ return flags
83
+
84
+
40
85
  def convert_to_tflite(
41
86
  pytorch_model: torch.nn.Module,
42
87
  output_path: str,
ai_edge_torch/version.py CHANGED
@@ -13,4 +13,4 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- __version__ = "0.4.0.dev20250330"
16
+ __version__ = "0.4.0.dev20250331"
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: ai-edge-torch-nightly
3
- Version: 0.4.0.dev20250330
3
+ Version: 0.4.0.dev20250331
4
4
  Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
5
5
  Home-page: https://github.com/google-ai-edge/ai-edge-torch
6
6
  Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI