ai-edge-torch-nightly 0.7.0.dev20250929__py3-none-any.whl → 0.8.0.dev20251206__py3-none-any.whl

This diff compares the contents of two publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions as they appear in their respective public registries.

Potentially problematic release.


This version of ai-edge-torch-nightly might be problematic. Click here for more details.

Files changed (57):
  1. ai_edge_torch/_convert/conversion.py +2 -1
  2. ai_edge_torch/fx_infra/_safe_run_decompositions.py +36 -1
  3. ai_edge_torch/generative/examples/amd_llama_135m/convert_to_tflite.py +1 -20
  4. ai_edge_torch/generative/examples/deepseek/convert_to_tflite.py +1 -20
  5. ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py +1 -20
  6. ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +1 -20
  7. ai_edge_torch/generative/examples/gemma3/convert_gemma3_to_tflite.py +3 -27
  8. ai_edge_torch/generative/examples/hammer/convert_to_tflite.py +1 -20
  9. ai_edge_torch/generative/examples/llama/convert_to_tflite.py +1 -20
  10. ai_edge_torch/generative/examples/openelm/convert_to_tflite.py +1 -20
  11. ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py +1 -20
  12. ai_edge_torch/generative/examples/phi/convert_phi4_to_tflite.py +1 -20
  13. ai_edge_torch/generative/examples/phi/convert_to_tflite.py +1 -20
  14. ai_edge_torch/generative/examples/qwen/convert_to_tflite.py +1 -20
  15. ai_edge_torch/generative/examples/qwen/convert_v3_to_tflite.py +1 -20
  16. ai_edge_torch/generative/examples/smollm/convert_to_tflite.py +1 -30
  17. ai_edge_torch/generative/examples/smollm/convert_v2_to_tflite.py +1 -30
  18. ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +1 -3
  19. ai_edge_torch/generative/examples/stable_diffusion/samplers/k_lms.py +1 -1
  20. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +1 -20
  21. ai_edge_torch/generative/layers/attention.py +25 -2
  22. ai_edge_torch/generative/layers/attention_test.py +13 -1
  23. ai_edge_torch/generative/layers/attention_utils.py +62 -1
  24. ai_edge_torch/generative/layers/attention_utils_test.py +20 -0
  25. ai_edge_torch/generative/layers/builder.py +4 -2
  26. ai_edge_torch/generative/layers/model_config.py +5 -0
  27. ai_edge_torch/generative/layers/normalization.py +8 -2
  28. ai_edge_torch/generative/layers/scaled_dot_product_attention.py +35 -5
  29. ai_edge_torch/generative/layers/sdpa_with_kv_update.py +8 -3
  30. ai_edge_torch/generative/quantize/example.py +1 -1
  31. ai_edge_torch/generative/quantize/quant_attrs.py +8 -1
  32. ai_edge_torch/generative/quantize/quant_recipe.py +0 -13
  33. ai_edge_torch/generative/quantize/quant_recipe_utils.py +12 -19
  34. ai_edge_torch/generative/quantize/quant_recipes.py +16 -21
  35. ai_edge_torch/generative/quantize/supported_schemes.py +4 -1
  36. ai_edge_torch/generative/test/test_kv_cache.py +18 -6
  37. ai_edge_torch/generative/test/test_quantize.py +17 -26
  38. ai_edge_torch/generative/utilities/converter.py +183 -28
  39. ai_edge_torch/generative/utilities/export_config.py +2 -0
  40. ai_edge_torch/generative/utilities/litertlm_builder.py +61 -8
  41. ai_edge_torch/generative/utilities/loader.py +2 -1
  42. ai_edge_torch/lowertools/translate_recipe.py +8 -3
  43. ai_edge_torch/odml_torch/experimental/__init__.py +14 -0
  44. ai_edge_torch/odml_torch/experimental/torch_tfl/__init__.py +20 -0
  45. ai_edge_torch/odml_torch/experimental/torch_tfl/_decomps.py +438 -0
  46. ai_edge_torch/odml_torch/experimental/torch_tfl/_lowerings.py +728 -0
  47. ai_edge_torch/odml_torch/experimental/torch_tfl/_ops.py +371 -0
  48. ai_edge_torch/odml_torch/experimental/torch_tfl/torch_library_utils.py +37 -0
  49. ai_edge_torch/odml_torch/export.py +24 -7
  50. ai_edge_torch/odml_torch/lowerings/_basic.py +155 -0
  51. ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +255 -5
  52. ai_edge_torch/version.py +1 -1
  53. {ai_edge_torch_nightly-0.7.0.dev20250929.dist-info → ai_edge_torch_nightly-0.8.0.dev20251206.dist-info}/METADATA +15 -3
  54. {ai_edge_torch_nightly-0.7.0.dev20250929.dist-info → ai_edge_torch_nightly-0.8.0.dev20251206.dist-info}/RECORD +57 -51
  55. {ai_edge_torch_nightly-0.7.0.dev20250929.dist-info → ai_edge_torch_nightly-0.8.0.dev20251206.dist-info}/WHEEL +1 -1
  56. {ai_edge_torch_nightly-0.7.0.dev20250929.dist-info → ai_edge_torch_nightly-0.8.0.dev20251206.dist-info/licenses}/LICENSE +0 -0
  57. {ai_edge_torch_nightly-0.7.0.dev20250929.dist-info → ai_edge_torch_nightly-0.8.0.dev20251206.dist-info}/top_level.txt +0 -0
@@ -133,10 +133,11 @@ def convert_signatures(
133
133
  exported_program = fx_infra.safe_run_decompositions(
134
134
  exported_program,
135
135
  fx_infra.decomp.pre_convert_decomp(),
136
+ can_skip=False,
136
137
  )
137
138
  return exported_program
138
139
 
139
- exported_programs: torch.export.ExportedProgram = [
140
+ exported_programs = [
140
141
  export(
141
142
  mod=sig.module,
142
143
  args=sig.args,
@@ -13,6 +13,7 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
  """ExportedProgram.run_decompositions wrapper to handle unexpected export behavior."""
16
+ import operator
16
17
  import torch
17
18
 
18
19
 
@@ -26,8 +27,39 @@ _DUMMY_DECOMP_TABLE = {
26
27
  torch._ops.OperatorBase(): lambda: None,
27
28
  }
28
29
 
30
+ _BUILTIN_OPERATORS = {
31
+ getattr(operator, name)
32
+ for name in dir(operator)
33
+ if not name.startswith("_")
34
+ }
35
+
36
+
37
+ def _require_decomp(
38
+ exported_program: torch.export.ExportedProgram, decomp_table
39
+ ):
40
+ """Checks if the exported program requires decompositions."""
41
+ for node in exported_program.graph.nodes:
42
+ if "call_" not in str(node.op):
43
+ continue
44
+
45
+ op = node.target
46
+ if isinstance(op, torch._ops.OpOverloadPacket):
47
+ op = op.default
29
48
 
30
- def safe_run_decompositions(exported_program, decomp_table=None):
49
+ if op in decomp_table:
50
+ return True
51
+
52
+ if (
53
+ not isinstance(op, (torch._ops.OpOverload, torch._ops.OperatorBase))
54
+ and op not in _BUILTIN_OPERATORS
55
+ ):
56
+ # Python function that requires to be retraced via run_decompositions.
57
+ return True
58
+
59
+ return False
60
+
61
+
62
+ def safe_run_decompositions(exported_program, decomp_table=None, can_skip=True):
31
63
  """Wrapper for ExportedProgram.run_decompositions to handle unexpected export behavior."""
32
64
 
33
65
  if decomp_table is not None and not decomp_table:
@@ -35,6 +67,9 @@ def safe_run_decompositions(exported_program, decomp_table=None):
35
67
  # instead for backward compatibility.
36
68
  decomp_table = _DUMMY_DECOMP_TABLE
37
69
 
70
+ if can_skip and not _require_decomp(exported_program, decomp_table):
71
+ return exported_program
72
+
38
73
  for node in exported_program.graph.nodes:
39
74
  if node.target == torch.ops.aten.view.default:
40
75
  # Passes or torch.export may generate aten.view nodes not respecting the
@@ -18,31 +18,12 @@
18
18
  from absl import app
19
19
  from ai_edge_torch.generative.examples.amd_llama_135m import amd_llama_135m
20
20
  from ai_edge_torch.generative.utilities import converter
21
- from ai_edge_torch.generative.utilities import export_config
22
- from ai_edge_torch.generative.utilities import loader
23
21
 
24
22
  flags = converter.define_conversion_flags("amd-llama-135m")
25
23
 
26
24
 
27
25
  def main(_):
28
- checkpoint_path = flags.FLAGS.checkpoint_path
29
- pytorch_model = amd_llama_135m.build_model(
30
- checkpoint_path,
31
- custom_loader=loader.maybe_get_custom_loader(
32
- checkpoint_path, flags.FLAGS.custom_checkpoint_loader
33
- ),
34
- mask_cache_size=converter.get_mask_cache_size_from_flags(),
35
- )
36
- converter.convert_to_tflite(
37
- pytorch_model,
38
- output_path=flags.FLAGS.output_path,
39
- output_name_prefix=flags.FLAGS.output_name_prefix,
40
- prefill_seq_len=flags.FLAGS.prefill_seq_lens,
41
- kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
42
- quantize=flags.FLAGS.quantize,
43
- lora_ranks=flags.FLAGS.lora_ranks,
44
- export_config=export_config.get_from_flags(),
45
- )
26
+ converter.build_and_convert_to_tflite_from_flags(amd_llama_135m.build_model)
46
27
 
47
28
 
48
29
  if __name__ == '__main__':
@@ -18,8 +18,6 @@
18
18
  from absl import app
19
19
  from ai_edge_torch.generative.examples.deepseek import deepseek
20
20
  from ai_edge_torch.generative.utilities import converter
21
- from ai_edge_torch.generative.utilities import export_config
22
- from ai_edge_torch.generative.utilities import loader
23
21
 
24
22
  flags = converter.define_conversion_flags(
25
23
  'deepseek', default_mask_as_input=True, default_transpose_kv_cache=True
@@ -27,24 +25,7 @@ flags = converter.define_conversion_flags(
27
25
 
28
26
 
29
27
  def main(_):
30
- checkpoint_path = flags.FLAGS.checkpoint_path
31
- pytorch_model = deepseek.build_model(
32
- checkpoint_path,
33
- custom_loader=loader.maybe_get_custom_loader(
34
- checkpoint_path, flags.FLAGS.custom_checkpoint_loader
35
- ),
36
- mask_cache_size=converter.get_mask_cache_size_from_flags(),
37
- )
38
- converter.convert_to_tflite(
39
- pytorch_model,
40
- output_path=flags.FLAGS.output_path,
41
- output_name_prefix=flags.FLAGS.output_name_prefix,
42
- prefill_seq_len=flags.FLAGS.prefill_seq_lens,
43
- kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
44
- quantize=flags.FLAGS.quantize,
45
- lora_ranks=flags.FLAGS.lora_ranks,
46
- export_config=export_config.get_from_flags(),
47
- )
28
+ converter.build_and_convert_to_tflite_from_flags(deepseek.build_model)
48
29
 
49
30
 
50
31
  if __name__ == '__main__':
@@ -18,31 +18,12 @@
18
18
  from absl import app
19
19
  from ai_edge_torch.generative.examples.gemma import gemma1
20
20
  from ai_edge_torch.generative.utilities import converter
21
- from ai_edge_torch.generative.utilities import export_config
22
- from ai_edge_torch.generative.utilities import loader
23
21
 
24
22
  flags = converter.define_conversion_flags("gemma-2b")
25
23
 
26
24
 
27
25
  def main(_):
28
- checkpoint_path = flags.FLAGS.checkpoint_path
29
- pytorch_model = gemma1.build_2b_model(
30
- checkpoint_path,
31
- custom_loader=loader.maybe_get_custom_loader(
32
- checkpoint_path, flags.FLAGS.custom_checkpoint_loader
33
- ),
34
- mask_cache_size=converter.get_mask_cache_size_from_flags(),
35
- )
36
- converter.convert_to_tflite(
37
- pytorch_model,
38
- output_path=flags.FLAGS.output_path,
39
- output_name_prefix=flags.FLAGS.output_name_prefix,
40
- prefill_seq_len=flags.FLAGS.prefill_seq_lens,
41
- kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
42
- quantize=flags.FLAGS.quantize,
43
- lora_ranks=flags.FLAGS.lora_ranks,
44
- export_config=export_config.get_from_flags(),
45
- )
26
+ converter.build_and_convert_to_tflite_from_flags(gemma1.build_2b_model)
46
27
 
47
28
 
48
29
  if __name__ == '__main__':
@@ -18,8 +18,6 @@
18
18
  from absl import app
19
19
  from ai_edge_torch.generative.examples.gemma import gemma2
20
20
  from ai_edge_torch.generative.utilities import converter
21
- from ai_edge_torch.generative.utilities import export_config
22
- from ai_edge_torch.generative.utilities import loader
23
21
 
24
22
  flags = converter.define_conversion_flags(
25
23
  "gemma2-2b", default_mask_as_input=True, default_transpose_kv_cache=True
@@ -27,24 +25,7 @@ flags = converter.define_conversion_flags(
27
25
 
28
26
 
29
27
  def main(_):
30
- checkpoint_path = flags.FLAGS.checkpoint_path
31
- pytorch_model = gemma2.build_2b_model(
32
- checkpoint_path,
33
- custom_loader=loader.maybe_get_custom_loader(
34
- checkpoint_path, flags.FLAGS.custom_checkpoint_loader
35
- ),
36
- mask_cache_size=converter.get_mask_cache_size_from_flags(),
37
- )
38
- converter.convert_to_tflite(
39
- pytorch_model,
40
- output_path=flags.FLAGS.output_path,
41
- output_name_prefix=flags.FLAGS.output_name_prefix,
42
- prefill_seq_len=flags.FLAGS.prefill_seq_lens,
43
- kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
44
- quantize=flags.FLAGS.quantize,
45
- lora_ranks=flags.FLAGS.lora_ranks,
46
- export_config=export_config.get_from_flags(),
47
- )
28
+ converter.build_and_convert_to_tflite_from_flags(gemma2.build_2b_model)
48
29
 
49
30
 
50
31
  if __name__ == '__main__':
@@ -18,8 +18,6 @@
18
18
  from absl import app
19
19
  from ai_edge_torch.generative.examples.gemma3 import gemma3
20
20
  from ai_edge_torch.generative.utilities import converter
21
- from ai_edge_torch.generative.utilities import export_config
22
- from ai_edge_torch.generative.utilities import loader
23
21
 
24
22
  flags = converter.define_conversion_flags(
25
23
  'gemma3-1b', default_mask_as_input=True, default_transpose_kv_cache=True
@@ -33,36 +31,14 @@ _MODEL_SIZE = flags.DEFINE_string(
33
31
 
34
32
 
35
33
  def main(_):
36
- checkpoint_path = flags.FLAGS.checkpoint_path
37
34
  if _MODEL_SIZE.value == '1b':
38
- pytorch_model = gemma3.build_model_1b(
39
- checkpoint_path,
40
- custom_loader=loader.maybe_get_custom_loader(
41
- checkpoint_path, flags.FLAGS.custom_checkpoint_loader
42
- ),
43
- mask_cache_size=converter.get_mask_cache_size_from_flags(),
44
- )
35
+ model_builder = gemma3.build_model_1b
45
36
  elif _MODEL_SIZE.value == '270m':
46
- pytorch_model = gemma3.build_model_270m(
47
- checkpoint_path,
48
- custom_loader=loader.maybe_get_custom_loader(
49
- checkpoint_path, flags.FLAGS.custom_checkpoint_loader
50
- ),
51
- mask_cache_size=converter.get_mask_cache_size_from_flags(),
52
- )
37
+ model_builder = gemma3.build_model_270m
53
38
  else:
54
39
  raise ValueError(f'Unsupported model size: {_MODEL_SIZE.value}')
55
40
 
56
- converter.convert_to_tflite(
57
- pytorch_model,
58
- output_path=flags.FLAGS.output_path,
59
- output_name_prefix=flags.FLAGS.output_name_prefix,
60
- prefill_seq_len=flags.FLAGS.prefill_seq_lens,
61
- kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
62
- quantize=flags.FLAGS.quantize,
63
- lora_ranks=flags.FLAGS.lora_ranks,
64
- export_config=export_config.get_from_flags(),
65
- )
41
+ converter.build_and_convert_to_tflite_from_flags(model_builder)
66
42
 
67
43
 
68
44
  if __name__ == '__main__':
@@ -18,8 +18,6 @@
18
18
  from absl import app
19
19
  from ai_edge_torch.generative.examples.hammer import hammer
20
20
  from ai_edge_torch.generative.utilities import converter
21
- from ai_edge_torch.generative.utilities import export_config
22
- from ai_edge_torch.generative.utilities import loader
23
21
 
24
22
  flags = converter.define_conversion_flags('hammer')
25
23
 
@@ -37,24 +35,7 @@ _BUILDER = {
37
35
 
38
36
 
39
37
  def main(_):
40
- checkpoint_path = flags.FLAGS.checkpoint_path
41
- pytorch_model = _BUILDER[_MODEL_SIZE.value](
42
- checkpoint_path,
43
- custom_loader=loader.maybe_get_custom_loader(
44
- checkpoint_path, flags.FLAGS.custom_checkpoint_loader
45
- ),
46
- mask_cache_size=converter.get_mask_cache_size_from_flags(),
47
- )
48
- converter.convert_to_tflite(
49
- pytorch_model,
50
- output_path=flags.FLAGS.output_path,
51
- output_name_prefix=flags.FLAGS.output_name_prefix,
52
- prefill_seq_len=flags.FLAGS.prefill_seq_lens,
53
- kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
54
- quantize=flags.FLAGS.quantize,
55
- lora_ranks=flags.FLAGS.lora_ranks,
56
- export_config=export_config.get_from_flags(),
57
- )
38
+ converter.build_and_convert_to_tflite_from_flags(_BUILDER[_MODEL_SIZE.value])
58
39
 
59
40
 
60
41
  if __name__ == '__main__':
@@ -18,8 +18,6 @@
18
18
  from absl import app
19
19
  from ai_edge_torch.generative.examples.llama import llama
20
20
  from ai_edge_torch.generative.utilities import converter
21
- from ai_edge_torch.generative.utilities import export_config
22
- from ai_edge_torch.generative.utilities import loader
23
21
 
24
22
  flags = converter.define_conversion_flags('llama')
25
23
 
@@ -37,24 +35,7 @@ _BUILDER = {
37
35
 
38
36
 
39
37
  def main(_):
40
- checkpoint_path = flags.FLAGS.checkpoint_path
41
- pytorch_model = _BUILDER[_MODEL_SIZE.value](
42
- checkpoint_path,
43
- custom_loader=loader.maybe_get_custom_loader(
44
- checkpoint_path, flags.FLAGS.custom_checkpoint_loader
45
- ),
46
- mask_cache_size=converter.get_mask_cache_size_from_flags(),
47
- )
48
- converter.convert_to_tflite(
49
- pytorch_model,
50
- output_path=flags.FLAGS.output_path,
51
- output_name_prefix=flags.FLAGS.output_name_prefix,
52
- prefill_seq_len=flags.FLAGS.prefill_seq_lens,
53
- kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
54
- quantize=flags.FLAGS.quantize,
55
- lora_ranks=flags.FLAGS.lora_ranks,
56
- export_config=export_config.get_from_flags(),
57
- )
38
+ converter.build_and_convert_to_tflite_from_flags(_BUILDER[_MODEL_SIZE.value])
58
39
 
59
40
 
60
41
  if __name__ == '__main__':
@@ -18,31 +18,12 @@
18
18
  from absl import app
19
19
  from ai_edge_torch.generative.examples.openelm import openelm
20
20
  from ai_edge_torch.generative.utilities import converter
21
- from ai_edge_torch.generative.utilities import export_config
22
- from ai_edge_torch.generative.utilities import loader
23
21
 
24
22
  flags = converter.define_conversion_flags("openelm")
25
23
 
26
24
 
27
25
  def main(_):
28
- checkpoint_path = flags.FLAGS.checkpoint_path
29
- pytorch_model = openelm.build_model(
30
- checkpoint_path,
31
- custom_loader=loader.maybe_get_custom_loader(
32
- checkpoint_path, flags.FLAGS.custom_checkpoint_loader
33
- ),
34
- mask_cache_size=converter.get_mask_cache_size_from_flags(),
35
- )
36
- converter.convert_to_tflite(
37
- pytorch_model,
38
- output_path=flags.FLAGS.output_path,
39
- output_name_prefix=flags.FLAGS.output_name_prefix,
40
- prefill_seq_len=flags.FLAGS.prefill_seq_lens,
41
- kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
42
- quantize=flags.FLAGS.quantize,
43
- lora_ranks=flags.FLAGS.lora_ranks,
44
- export_config=export_config.get_from_flags(),
45
- )
26
+ converter.build_and_convert_to_tflite_from_flags(openelm.build_model)
46
27
 
47
28
 
48
29
  if __name__ == '__main__':
@@ -18,31 +18,12 @@
18
18
  from absl import app
19
19
  from ai_edge_torch.generative.examples.phi import phi3
20
20
  from ai_edge_torch.generative.utilities import converter
21
- from ai_edge_torch.generative.utilities import export_config
22
- from ai_edge_torch.generative.utilities import loader
23
21
 
24
22
  flags = converter.define_conversion_flags("phi3")
25
23
 
26
24
 
27
25
  def main(_):
28
- checkpoint_path = flags.FLAGS.checkpoint_path
29
- pytorch_model = phi3.build_model(
30
- checkpoint_path,
31
- custom_loader=loader.maybe_get_custom_loader(
32
- checkpoint_path, flags.FLAGS.custom_checkpoint_loader
33
- ),
34
- mask_cache_size=converter.get_mask_cache_size_from_flags(),
35
- )
36
- converter.convert_to_tflite(
37
- pytorch_model,
38
- output_path=flags.FLAGS.output_path,
39
- output_name_prefix=flags.FLAGS.output_name_prefix,
40
- prefill_seq_len=flags.FLAGS.prefill_seq_lens,
41
- kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
42
- quantize=flags.FLAGS.quantize,
43
- lora_ranks=flags.FLAGS.lora_ranks,
44
- export_config=export_config.get_from_flags(),
45
- )
26
+ converter.build_and_convert_to_tflite_from_flags(phi3.build_model)
46
27
 
47
28
 
48
29
  if __name__ == '__main__':
@@ -18,31 +18,12 @@
18
18
  from absl import app
19
19
  from ai_edge_torch.generative.examples.phi import phi4
20
20
  from ai_edge_torch.generative.utilities import converter
21
- from ai_edge_torch.generative.utilities import export_config
22
- from ai_edge_torch.generative.utilities import loader
23
21
 
24
22
  flags = converter.define_conversion_flags("phi4")
25
23
 
26
24
 
27
25
  def main(_):
28
- checkpoint_path = flags.FLAGS.checkpoint_path
29
- pytorch_model = phi4.build_model(
30
- checkpoint_path,
31
- custom_loader=loader.maybe_get_custom_loader(
32
- checkpoint_path, flags.FLAGS.custom_checkpoint_loader
33
- ),
34
- mask_cache_size=converter.get_mask_cache_size_from_flags(),
35
- )
36
- converter.convert_to_tflite(
37
- pytorch_model,
38
- output_path=flags.FLAGS.output_path,
39
- output_name_prefix=flags.FLAGS.output_name_prefix,
40
- prefill_seq_len=flags.FLAGS.prefill_seq_lens,
41
- kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
42
- quantize=flags.FLAGS.quantize,
43
- lora_ranks=flags.FLAGS.lora_ranks,
44
- export_config=export_config.get_from_flags(),
45
- )
26
+ converter.build_and_convert_to_tflite_from_flags(phi4.build_model)
46
27
 
47
28
 
48
29
  if __name__ == '__main__':
@@ -19,31 +19,12 @@
19
19
  from absl import app
20
20
  from ai_edge_torch.generative.examples.phi import phi2
21
21
  from ai_edge_torch.generative.utilities import converter
22
- from ai_edge_torch.generative.utilities import export_config
23
- from ai_edge_torch.generative.utilities import loader
24
22
 
25
23
  flags = converter.define_conversion_flags("phi2")
26
24
 
27
25
 
28
26
  def main(_):
29
- checkpoint_path = flags.FLAGS.checkpoint_path
30
- pytorch_model = phi2.build_model(
31
- checkpoint_path,
32
- custom_loader=loader.maybe_get_custom_loader(
33
- checkpoint_path, flags.FLAGS.custom_checkpoint_loader
34
- ),
35
- mask_cache_size=converter.get_mask_cache_size_from_flags(),
36
- )
37
- converter.convert_to_tflite(
38
- pytorch_model,
39
- output_path=flags.FLAGS.output_path,
40
- output_name_prefix=flags.FLAGS.output_name_prefix,
41
- prefill_seq_len=flags.FLAGS.prefill_seq_lens,
42
- kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
43
- quantize=flags.FLAGS.quantize,
44
- lora_ranks=flags.FLAGS.lora_ranks,
45
- export_config=export_config.get_from_flags(),
46
- )
27
+ converter.build_and_convert_to_tflite_from_flags(phi2.build_model)
47
28
 
48
29
 
49
30
  if __name__ == '__main__':
@@ -18,8 +18,6 @@
18
18
  from absl import app
19
19
  from ai_edge_torch.generative.examples.qwen import qwen
20
20
  from ai_edge_torch.generative.utilities import converter
21
- from ai_edge_torch.generative.utilities import export_config
22
- from ai_edge_torch.generative.utilities import loader
23
21
 
24
22
  flags = converter.define_conversion_flags('qwen')
25
23
 
@@ -38,24 +36,7 @@ _BUILDER = {
38
36
 
39
37
 
40
38
  def main(_):
41
- checkpoint_path = flags.FLAGS.checkpoint_path
42
- pytorch_model = _BUILDER[_MODEL_SIZE.value](
43
- checkpoint_path,
44
- custom_loader=loader.maybe_get_custom_loader(
45
- checkpoint_path, flags.FLAGS.custom_checkpoint_loader
46
- ),
47
- mask_cache_size=converter.get_mask_cache_size_from_flags(),
48
- )
49
- converter.convert_to_tflite(
50
- pytorch_model,
51
- output_path=flags.FLAGS.output_path,
52
- output_name_prefix=flags.FLAGS.output_name_prefix,
53
- prefill_seq_len=flags.FLAGS.prefill_seq_lens,
54
- kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
55
- quantize=flags.FLAGS.quantize,
56
- lora_ranks=flags.FLAGS.lora_ranks,
57
- export_config=export_config.get_from_flags(),
58
- )
39
+ converter.build_and_convert_to_tflite_from_flags(_BUILDER[_MODEL_SIZE.value])
59
40
 
60
41
 
61
42
  if __name__ == '__main__':
@@ -18,8 +18,6 @@
18
18
  from absl import app
19
19
  from ai_edge_torch.generative.examples.qwen import qwen3
20
20
  from ai_edge_torch.generative.utilities import converter
21
- from ai_edge_torch.generative.utilities import export_config
22
- from ai_edge_torch.generative.utilities import loader
23
21
 
24
22
  flags = converter.define_conversion_flags('qwen')
25
23
 
@@ -38,24 +36,7 @@ _BUILDER = {
38
36
 
39
37
 
40
38
  def main(_):
41
- checkpoint_path = flags.FLAGS.checkpoint_path
42
- pytorch_model = _BUILDER[_MODEL_SIZE.value](
43
- checkpoint_path,
44
- custom_loader=loader.maybe_get_custom_loader(
45
- checkpoint_path, flags.FLAGS.custom_checkpoint_loader
46
- ),
47
- mask_cache_size=converter.get_mask_cache_size_from_flags(),
48
- )
49
- converter.convert_to_tflite(
50
- pytorch_model,
51
- output_path=flags.FLAGS.output_path,
52
- output_name_prefix=flags.FLAGS.output_name_prefix,
53
- prefill_seq_len=flags.FLAGS.prefill_seq_lens,
54
- kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
55
- quantize=flags.FLAGS.quantize,
56
- lora_ranks=flags.FLAGS.lora_ranks,
57
- export_config=export_config.get_from_flags(),
58
- )
39
+ converter.build_and_convert_to_tflite_from_flags(_BUILDER[_MODEL_SIZE.value])
59
40
 
60
41
 
61
42
  if __name__ == '__main__':
@@ -18,41 +18,12 @@
18
18
  from absl import app
19
19
  from ai_edge_torch.generative.examples.smollm import smollm
20
20
  from ai_edge_torch.generative.utilities import converter
21
- from ai_edge_torch.generative.utilities import export_config as export_cfg
22
- from ai_edge_torch.generative.utilities import loader
23
21
 
24
22
  flags = converter.define_conversion_flags('smollm')
25
23
 
26
- _DECODE_BATCH_SIZE = flags.DEFINE_integer(
27
- 'decode_batch_size',
28
- 1,
29
- 'The batch size for the decode signature.',
30
- )
31
-
32
24
 
33
25
  def main(_):
34
- checkpoint_path = flags.FLAGS.checkpoint_path
35
- pytorch_model = smollm.build_model(
36
- checkpoint_path,
37
- custom_loader=loader.maybe_get_custom_loader(
38
- checkpoint_path, flags.FLAGS.custom_checkpoint_loader
39
- ),
40
- mask_cache_size=converter.get_mask_cache_size_from_flags(),
41
- )
42
-
43
- export_config = export_cfg.get_from_flags()
44
- export_config.decode_batch_size = _DECODE_BATCH_SIZE.value
45
-
46
- converter.convert_to_tflite(
47
- pytorch_model,
48
- output_path=flags.FLAGS.output_path,
49
- output_name_prefix=flags.FLAGS.output_name_prefix,
50
- prefill_seq_len=flags.FLAGS.prefill_seq_lens,
51
- kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
52
- quantize=flags.FLAGS.quantize,
53
- lora_ranks=flags.FLAGS.lora_ranks,
54
- export_config=export_config,
55
- )
26
+ converter.build_and_convert_to_tflite_from_flags(smollm.build_model)
56
27
 
57
28
 
58
29
  if __name__ == '__main__':
@@ -18,41 +18,12 @@
18
18
  from absl import app
19
19
  from ai_edge_torch.generative.examples.smollm import smollm
20
20
  from ai_edge_torch.generative.utilities import converter
21
- from ai_edge_torch.generative.utilities import export_config as export_cfg
22
- from ai_edge_torch.generative.utilities import loader
23
21
 
24
22
  flags = converter.define_conversion_flags('smollm2')
25
23
 
26
- _DECODE_BATCH_SIZE = flags.DEFINE_integer(
27
- 'decode_batch_size',
28
- 1,
29
- 'The batch size for the decode signature.',
30
- )
31
-
32
24
 
33
25
  def main(_):
34
- checkpoint_path = flags.FLAGS.checkpoint_path
35
- pytorch_model = smollm.build_model_v2(
36
- checkpoint_path,
37
- custom_loader=loader.maybe_get_custom_loader(
38
- checkpoint_path, flags.FLAGS.custom_checkpoint_loader
39
- ),
40
- mask_cache_size=converter.get_mask_cache_size_from_flags(),
41
- )
42
-
43
- export_config = export_cfg.get_from_flags()
44
- export_config.decode_batch_size = _DECODE_BATCH_SIZE.value
45
-
46
- converter.convert_to_tflite(
47
- pytorch_model,
48
- output_path=flags.FLAGS.output_path,
49
- output_name_prefix=flags.FLAGS.output_name_prefix,
50
- prefill_seq_len=flags.FLAGS.prefill_seq_lens,
51
- kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
52
- quantize=flags.FLAGS.quantize,
53
- lora_ranks=flags.FLAGS.lora_ranks,
54
- export_config=export_config,
55
- )
26
+ converter.build_and_convert_to_tflite_from_flags(smollm.build_model_v2)
56
27
 
57
28
 
58
29
  if __name__ == '__main__':
@@ -138,9 +138,7 @@ def convert_stable_diffusion_to_tflite(
138
138
  if not os.path.exists(output_dir):
139
139
  pathlib.Path(output_dir).mkdir(parents=True, exist_ok=True)
140
140
 
141
- quant_config = (
142
- quant_recipes.full_int8_weight_only_recipe() if quantize else None
143
- )
141
+ quant_config = quant_recipes.full_weight_only_recipe() if quantize else None
144
142
 
145
143
  # TODO(yichunk): convert to multi signature tflite model.
146
144
  # CLIP text encoder
@@ -69,6 +69,6 @@ class KLMSSampler(SamplerInterface):
69
69
  continue
70
70
  y *= x - self.sigmas[t - j]
71
71
  y /= self.sigmas[t - i] - self.sigmas[t - j]
72
- lms_coeff = np.trapz(y=y, x=x)
72
+ lms_coeff = np.trapezoid(y=y, x=x)
73
73
  latents += lms_coeff * output
74
74
  return latents
@@ -18,31 +18,12 @@
18
18
  from absl import app
19
19
  from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
20
20
  from ai_edge_torch.generative.utilities import converter
21
- from ai_edge_torch.generative.utilities import export_config
22
- from ai_edge_torch.generative.utilities import loader
23
21
 
24
22
  flags = converter.define_conversion_flags("tiny_llama")
25
23
 
26
24
 
27
25
  def main(_):
28
- checkpoint_path = flags.FLAGS.checkpoint_path
29
- pytorch_model = tiny_llama.build_model(
30
- checkpoint_path,
31
- custom_loader=loader.maybe_get_custom_loader(
32
- checkpoint_path, flags.FLAGS.custom_checkpoint_loader
33
- ),
34
- mask_cache_size=converter.get_mask_cache_size_from_flags(),
35
- )
36
- converter.convert_to_tflite(
37
- pytorch_model,
38
- output_path=flags.FLAGS.output_path,
39
- output_name_prefix=flags.FLAGS.output_name_prefix,
40
- prefill_seq_len=flags.FLAGS.prefill_seq_lens,
41
- kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
42
- quantize=flags.FLAGS.quantize,
43
- lora_ranks=flags.FLAGS.lora_ranks,
44
- export_config=export_config.get_from_flags(),
45
- )
26
+ converter.build_and_convert_to_tflite_from_flags(tiny_llama.build_model)
46
27
 
47
28
 
48
29
  if __name__ == '__main__':