ai-edge-torch-nightly 0.5.0.dev20250514__py3-none-any.whl → 0.5.0.dev20250516__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (56)
  1. ai_edge_torch/__init__.py +1 -0
  2. ai_edge_torch/_convert/conversion.py +23 -0
  3. ai_edge_torch/_convert/converter.py +57 -3
  4. ai_edge_torch/_convert/test/test_convert.py +25 -0
  5. ai_edge_torch/generative/examples/amd_llama_135m/amd_llama_135m.py +9 -2
  6. ai_edge_torch/generative/examples/amd_llama_135m/convert_to_tflite.py +7 -1
  7. ai_edge_torch/generative/examples/deepseek/convert_to_tflite.py +7 -2
  8. ai_edge_torch/generative/examples/deepseek/deepseek.py +8 -1
  9. ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py +7 -1
  10. ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +7 -1
  11. ai_edge_torch/generative/examples/gemma/gemma1.py +9 -1
  12. ai_edge_torch/generative/examples/gemma/gemma2.py +7 -2
  13. ai_edge_torch/generative/examples/gemma3/convert_gemma3_to_tflite.py +6 -1
  14. ai_edge_torch/generative/examples/hammer/convert_to_tflite.py +7 -1
  15. ai_edge_torch/generative/examples/hammer/hammer.py +14 -2
  16. ai_edge_torch/generative/examples/llama/convert_to_tflite.py +7 -1
  17. ai_edge_torch/generative/examples/llama/llama.py +25 -6
  18. ai_edge_torch/generative/examples/moonshine/convert_moonshine_to_tflite.py +0 -1
  19. ai_edge_torch/generative/examples/openelm/convert_to_tflite.py +7 -1
  20. ai_edge_torch/generative/examples/openelm/openelm.py +8 -1
  21. ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py +6 -1
  22. ai_edge_torch/generative/examples/paligemma/decoder.py +1 -0
  23. ai_edge_torch/generative/examples/paligemma/decoder2.py +1 -0
  24. ai_edge_torch/generative/examples/paligemma/image_encoder.py +2 -1
  25. ai_edge_torch/generative/examples/paligemma/paligemma.py +12 -5
  26. ai_edge_torch/generative/examples/paligemma/verify.py +27 -5
  27. ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py +7 -1
  28. ai_edge_torch/generative/examples/phi/convert_phi4_to_tflite.py +7 -1
  29. ai_edge_torch/generative/examples/phi/convert_to_tflite.py +7 -1
  30. ai_edge_torch/generative/examples/phi/phi2.py +8 -1
  31. ai_edge_torch/generative/examples/phi/phi3.py +7 -2
  32. ai_edge_torch/generative/examples/phi/phi4.py +7 -2
  33. ai_edge_torch/generative/examples/qwen/convert_to_tflite.py +7 -1
  34. ai_edge_torch/generative/examples/qwen/qwen.py +20 -3
  35. ai_edge_torch/generative/examples/qwen_vl/convert_to_tflite.py +6 -1
  36. ai_edge_torch/generative/examples/qwen_vl/decoder.py +1 -2
  37. ai_edge_torch/generative/examples/qwen_vl/image_encoder.py +12 -4
  38. ai_edge_torch/generative/examples/qwen_vl/qwen_vl.py +12 -4
  39. ai_edge_torch/generative/examples/qwen_vl/verify.py +26 -5
  40. ai_edge_torch/generative/examples/smollm/convert_to_tflite.py +7 -2
  41. ai_edge_torch/generative/examples/smollm/convert_v2_to_tflite.py +7 -1
  42. ai_edge_torch/generative/examples/smollm/smollm.py +14 -2
  43. ai_edge_torch/generative/examples/smollm/verify.py +2 -2
  44. ai_edge_torch/generative/examples/stable_diffusion/clip.py +2 -1
  45. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +7 -1
  46. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +8 -1
  47. ai_edge_torch/generative/layers/normalization.py +26 -7
  48. ai_edge_torch/generative/layers/normalization_test.py +73 -0
  49. ai_edge_torch/generative/utilities/converter.py +16 -4
  50. ai_edge_torch/generative/utilities/loader.py +45 -0
  51. ai_edge_torch/version.py +1 -1
  52. {ai_edge_torch_nightly-0.5.0.dev20250514.dist-info → ai_edge_torch_nightly-0.5.0.dev20250516.dist-info}/METADATA +1 -1
  53. {ai_edge_torch_nightly-0.5.0.dev20250514.dist-info → ai_edge_torch_nightly-0.5.0.dev20250516.dist-info}/RECORD +56 -55
  54. {ai_edge_torch_nightly-0.5.0.dev20250514.dist-info → ai_edge_torch_nightly-0.5.0.dev20250516.dist-info}/LICENSE +0 -0
  55. {ai_edge_torch_nightly-0.5.0.dev20250514.dist-info → ai_edge_torch_nightly-0.5.0.dev20250516.dist-info}/WHEEL +0 -0
  56. {ai_edge_torch_nightly-0.5.0.dev20250514.dist-info → ai_edge_torch_nightly-0.5.0.dev20250516.dist-info}/top_level.txt +0 -0
ai_edge_torch/__init__.py CHANGED
@@ -14,6 +14,7 @@
 # ==============================================================================
 from ai_edge_torch._config import config
 from ai_edge_torch._convert.converter import convert
+from ai_edge_torch._convert.converter import experimental_add_compilation_backend
 from ai_edge_torch._convert.converter import signature
 from ai_edge_torch._convert.to_channel_last_io import to_channel_last_io
 from ai_edge_torch.model import Model
ai_edge_torch/_convert/conversion.py CHANGED
@@ -26,6 +26,9 @@ from ai_edge_torch.generative import fx_passes as generative_fx_passes
 from ai_edge_torch.quantize import quant_config as qcfg
 import torch
 
+from ai_edge_litert.aot import aot_compile as aot_compile_lib
+from ai_edge_litert.aot.core import types as litert_types
+
 
 def _run_convert_passes(
     exported_program: torch.export.ExportedProgram,
@@ -153,3 +156,23 @@ def convert_signatures(
   )
 
   return model.TfLiteModel(tflite_model)
+
+
+def aot_compile(
+    compilation_configs: list[litert_types.CompilationConfig],
+    cpu_model: model.TfLiteModel,
+) -> litert_types.CompilationResult:
+  """Compiles the given CPU model.
+
+  Args:
+    compilation_configs: The list of compilation configs to use.
+    cpu_model: The CPU model to compile.
+
+  Returns:
+    The compilation result.
+  """
+  litert_model = litert_types.Model.create_from_bytes(cpu_model.tflite_model())
+  return aot_compile_lib.aot_compile(
+      litert_model,
+      config=compilation_configs,
+  )
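For orientation, here is a minimal sketch of driving the new conversion.aot_compile helper directly. It assumes ai_edge_litert is installed with at least one vendor backend registered, and it constructs the same CompilationConfig that Converter.experimental_add_compilation_backend builds later in this diff; it is an illustration, not code from the package.

from ai_edge_torch._convert import conversion
from ai_edge_litert.aot.core import types as litert_types
from ai_edge_litert.aot.vendors import import_vendor as vendor_lib


def compile_for_all_targets(cpu_model):
  # cpu_model: a model.TfLiteModel, e.g. returned by conversion.convert_signatures.
  config = litert_types.CompilationConfig(target=vendor_lib.AllRegisteredTarget())
  # Returns a litert_types.CompilationResult with one compiled model per backend.
  return conversion.aot_compile([config], cpu_model)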
ai_edge_torch/_convert/converter.py CHANGED
@@ -23,6 +23,9 @@ from ai_edge_torch._convert import signature as signature_module
 from ai_edge_torch.quantize import quant_config as qcfg
 import torch
 
+from ai_edge_litert.aot.core import types as litert_types
+from ai_edge_litert.aot.vendors import import_vendor as vendor_lib
+
 
 class Converter:
   """A converter for converting PyTorch models to edge models.
@@ -32,6 +35,7 @@ class Converter:
 
   def __init__(self):
     self._signatures: list[signature_module.Signature] = []
+    self._compilation_configs: list[litert_types.CompilationConfig] = []
 
   def signature(
       self,
@@ -96,6 +100,31 @@ class Converter:
     )
     return self
 
+  def experimental_add_compilation_backend(
+      self,
+      target: litert_types.Target | None = None,
+      **kwargs: litert_types.Config,
+  ) -> Converter:
+    """Adds an AOT compilation target to the converter.
+
+    NOTE: This API is experimental and subject to change.
+
+    Args:
+      target: The target to compile for. If not specified, will compile to all
+        registered AOT targets in ai_edge_litert. See ai_edge_litert.aot.vendors
+        for more details. Adding a same target multiple times will be a no-op.
+      **kwargs: Additional arguments to pass to the backend compiler.
+
+    Returns:
+      The converter object itself.
+    """
+    if target is None:
+      target = vendor_lib.AllRegisteredTarget()
+    if isinstance(target, litert_types.Target):
+      target = litert_types.CompilationConfig(target=target, **kwargs)
+    self._compilation_configs.append(target)
+    return self
+
   def convert(
       self,
@@ -107,7 +136,7 @@
       dynamic_shapes: Optional[Union[dict[str, Any], Tuple[Any, ...]]] = None,
       _ai_edge_converter_flags: Optional[dict[str, Any]] = None,
       _saved_model_dir: Optional[str] = None,
-  ) -> model.TfLiteModel:
+  ) -> model.TfLiteModel | litert_types.CompilationResult:
     """Finalizes the conversion and produces an edge model.
 
     This could be called with no arguments as follows:
@@ -144,7 +173,9 @@
         specified, a random temporary directory would be used.
 
     Returns:
-      The converted edge model.
+      The converted edge model. If compilation configs are provided, returns the
+      compilation result that contains the compiled edge models for different
+      targets.
 
     Raises:
       ValueError: If the arguments are not provided as expected. See the example
@@ -169,13 +200,16 @@
           "sample_args or sample_kwargs must be provided if a module is"
           " specified."
       )
-    return conversion.convert_signatures(
+    converted_model = conversion.convert_signatures(
         self._signatures,
         strict_export=strict_export,
         quant_config=quant_config,
         _tfl_converter_flags=_ai_edge_converter_flags,
         _saved_model_dir=_saved_model_dir,
     )
+    if self._compilation_configs:
+      return conversion.aot_compile(self._compilation_configs, converted_model)
+    return converted_model
 
 
 def signature(
@@ -211,6 +245,26 @@ def signature(
   )
 
 
+def experimental_add_compilation_backend(
+    target: litert_types.Target | None = None,
+    **kwargs: litert_types.Config,
+) -> Converter:
+  """Adds an AOT compilation target to the converter.
+
+  NOTE: This API is experimental and subject to change.
+
+  Args:
+    target: The target to compile for. If not specified, will compile to all
+      registered AOT targets in ai_edge_litert. See ai_edge_litert.aot.vendors
+      for more details. Adding a same target multiple times will be a no-op.
+    **kwargs: Additional arguments to pass to the backend compiler.
+
+  Returns:
+    The converter object itself.
+  """
+  return Converter().experimental_add_compilation_backend(target, **kwargs)
+
+
 def convert(
     module: torch.nn.Module = None,
     sample_args=None,
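Taken together, these converter changes make AOT compilation part of the fluent conversion flow: adding at least one compilation backend switches the return type of convert from TfLiteModel to CompilationResult. A hedged end-to-end sketch (the module and sample inputs are illustrative, not from the diff):

import torch
import ai_edge_torch


class TinyModel(torch.nn.Module):  # illustrative module

  def forward(self, x):
    return torch.nn.functional.relu(x)


# No explicit target: per the docstring above, this compiles for all
# registered AOT targets and returns a litert_types.CompilationResult.
result = ai_edge_torch.experimental_add_compilation_backend().convert(
    TinyModel().eval(), (torch.randn(1, 16),)
)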
ai_edge_torch/_convert/test/test_convert.py CHANGED
@@ -29,6 +29,8 @@ from torch.ao.quantization import quantize_pt2e
 import torchvision
 
 from absl.testing import absltest as googletest
+from ai_edge_litert.aot.core import types as litert_types
+from ai_edge_litert.aot.vendors import fallback_backend
 from ai_edge_litert import interpreter as tfl_interpreter  # pylint: disable=g-direct-tensorflow-import
 
 
@@ -574,6 +576,29 @@ class TestConvert(googletest.TestCase):
         self.fail(f"Conversion failed with bloat16 inputs: {err}")
     # pylint: enable=broad-except
 
+  def test_compile_model(self):
+    """Tests AOT compilation of a simple Add module."""
+
+    class Add(nn.Module):
+
+      def forward(self, a, b):
+        return a + b
+
+    args = (
+        torch.randn((5, 10)),
+        torch.randn((5, 10)),
+    )
+    torch_module = Add().eval()
+    compilation_result = ai_edge_torch.experimental_add_compilation_backend(
+        fallback_backend.FallbackTarget()
+    ).convert(torch_module, args)
+    assert isinstance(compilation_result, litert_types.CompilationResult)
+    self.assertLen(compilation_result.models_with_backend, 1)
+    self.assertEqual(
+        compilation_result.models_with_backend[0][0].id(),
+        fallback_backend.FallbackBackend.id(),
+    )
+
 
 if __name__ == "__main__":
   googletest.main()
ai_edge_torch/generative/examples/amd_llama_135m/amd_llama_135m.py CHANGED
@@ -15,8 +15,10 @@
 
 """Example of building AMD-Llama-135m."""
 
+from typing import Callable, Dict
 import ai_edge_torch.generative.layers.model_config as cfg
 from ai_edge_torch.generative.utilities import model_builder
+import torch
 from torch import nn
 
 TENSOR_NAMES = model_builder.TENSOR_NAMES_WITH_SEPARATE_LM_HEAD
@@ -80,10 +82,15 @@ def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
   return config
 
 
-def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
+def build_model(
+    checkpoint_path: str,
+    custom_loader: Callable[[str], Dict[str, torch.Tensor]] | None = None,
+    **kwargs
+) -> nn.Module:
   return model_builder.build_decoder_only_model(
       checkpoint_path=checkpoint_path,
       config=get_model_config(**kwargs),
       tensor_names=TENSOR_NAMES,
-      model_class=AmdLlama
+      model_class=AmdLlama,
+      custom_loader=custom_loader,
   )
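The new custom_loader parameter, threaded through all of the example builders below, is a Callable[[str], Dict[str, torch.Tensor]]: a function mapping a checkpoint path to a tensor name/value dict. A hedged sketch of a hand-rolled loader (the torch.load-based checkpoint handling and the path are illustrative assumptions, not from the diff):

from typing import Dict

import torch

from ai_edge_torch.generative.examples.amd_llama_135m import amd_llama_135m


def my_custom_loader(checkpoint_path: str) -> Dict[str, torch.Tensor]:
  # Assumption: the checkpoint is a plain torch-serialized state dict; real
  # checkpoints may need safetensors or sharded handling instead.
  return torch.load(checkpoint_path, map_location="cpu")


model = amd_llama_135m.build_model(
    "/path/to/checkpoint.pt",  # illustrative path
    custom_loader=my_custom_loader,
)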
ai_edge_torch/generative/examples/amd_llama_135m/convert_to_tflite.py CHANGED
@@ -19,13 +19,19 @@ from absl import app
 from ai_edge_torch.generative.examples.amd_llama_135m import amd_llama_135m
 from ai_edge_torch.generative.utilities import converter
 from ai_edge_torch.generative.utilities import export_config
+from ai_edge_torch.generative.utilities import loader
 
 flags = converter.define_conversion_flags("amd-llama-135m")
 
 
 def main(_):
+  checkpoint_path = flags.FLAGS.checkpoint_path
   pytorch_model = amd_llama_135m.build_model(
-      flags.FLAGS.checkpoint_path, kv_cache_max_len=flags.FLAGS.kv_cache_max_len
+      checkpoint_path,
+      custom_loader=loader.maybe_get_custom_loader(
+          checkpoint_path, flags.FLAGS.custom_checkpoint_loader
+      ),
+      kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
   )
   converter.convert_to_tflite(
       pytorch_model,
ai_edge_torch/generative/examples/deepseek/convert_to_tflite.py CHANGED
@@ -17,15 +17,20 @@
 
 from absl import app
 from ai_edge_torch.generative.examples.deepseek import deepseek
-from ai_edge_torch.generative.layers import kv_cache
 from ai_edge_torch.generative.utilities import converter
 from ai_edge_torch.generative.utilities import export_config
+from ai_edge_torch.generative.utilities import loader
 
 flags = converter.define_conversion_flags('deepseek')
 
 def main(_):
+  checkpoint_path = flags.FLAGS.checkpoint_path
   pytorch_model = deepseek.build_model(
-      flags.FLAGS.checkpoint_path, kv_cache_max_len=flags.FLAGS.kv_cache_max_len
+      checkpoint_path,
+      custom_loader=loader.maybe_get_custom_loader(
+          checkpoint_path, flags.FLAGS.custom_checkpoint_loader
+      ),
+      kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
   )
   converter.convert_to_tflite(
       pytorch_model,
ai_edge_torch/generative/examples/deepseek/deepseek.py CHANGED
@@ -15,8 +15,10 @@
 
 """Example of building DeepSeek R1 distilled models."""
 
+from typing import Callable, Dict
 import ai_edge_torch.generative.layers.model_config as cfg
 from ai_edge_torch.generative.utilities import model_builder
+import torch
 from torch import nn
 
 TENSOR_NAMES = model_builder.TENSOR_NAMES_WITH_SEPARATE_LM_HEAD
@@ -84,10 +86,15 @@ def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
   return config
 
 
-def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
+def build_model(
+    checkpoint_path: str,
+    custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
+    **kwargs
+) -> nn.Module:
   return model_builder.build_decoder_only_model(
       checkpoint_path=checkpoint_path,
       config=get_model_config(**kwargs),
       tensor_names=TENSOR_NAMES,
       model_class=DeepSeekDistillQwen,
+      custom_loader=custom_loader,
   )
ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py CHANGED
@@ -19,13 +19,19 @@ from absl import app
 from ai_edge_torch.generative.examples.gemma import gemma1
 from ai_edge_torch.generative.utilities import converter
 from ai_edge_torch.generative.utilities import export_config
+from ai_edge_torch.generative.utilities import loader
 
 flags = converter.define_conversion_flags("gemma-2b")
 
 
 def main(_):
+  checkpoint_path = flags.FLAGS.checkpoint_path
   pytorch_model = gemma1.build_2b_model(
-      flags.FLAGS.checkpoint_path, kv_cache_max_len=flags.FLAGS.kv_cache_max_len
+      checkpoint_path,
+      custom_loader=loader.maybe_get_custom_loader(
+          checkpoint_path, flags.FLAGS.custom_checkpoint_loader
+      ),
+      kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
   )
   converter.convert_to_tflite(
       pytorch_model,
ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py CHANGED
@@ -19,6 +19,7 @@ from absl import app
 from ai_edge_torch.generative.examples.gemma import gemma2
 from ai_edge_torch.generative.utilities import converter
 from ai_edge_torch.generative.utilities import export_config
+from ai_edge_torch.generative.utilities import loader
 
 flags = converter.define_conversion_flags(
     "gemma2-2b", default_mask_as_input=True, default_transpose_kv_cache=True
@@ -26,8 +27,13 @@ flags = converter.define_conversion_flags(
 
 
 def main(_):
+  checkpoint_path = flags.FLAGS.checkpoint_path
   pytorch_model = gemma2.build_2b_model(
-      flags.FLAGS.checkpoint_path, kv_cache_max_len=flags.FLAGS.kv_cache_max_len
+      checkpoint_path,
+      custom_loader=loader.maybe_get_custom_loader(
+          checkpoint_path, flags.FLAGS.custom_checkpoint_loader
+      ),
+      kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
   )
   converter.convert_to_tflite(
       pytorch_model,
ai_edge_torch/generative/examples/gemma/gemma1.py CHANGED
@@ -15,9 +15,12 @@
 
 """Example of building a Gemma1 model."""
 
+from typing import Callable, Dict
+
 import ai_edge_torch.generative.layers.model_config as cfg
 from ai_edge_torch.generative.utilities import model_builder
 import ai_edge_torch.generative.utilities.loader as loading_utils
+import torch
 from torch import nn
 
 TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
@@ -99,10 +102,15 @@ def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
   return config
 
 
-def build_2b_model(checkpoint_path: str, **kwargs) -> nn.Module:
+def build_2b_model(
+    checkpoint_path: str,
+    custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
+    **kwargs
+) -> nn.Module:
   return model_builder.build_decoder_only_model(
       checkpoint_path=checkpoint_path,
       config=get_model_config_2b(**kwargs),
       tensor_names=TENSOR_NAMES,
       model_class=Gemma1,
+      custom_loader=custom_loader,
   )
ai_edge_torch/generative/examples/gemma/gemma2.py CHANGED
@@ -15,7 +15,7 @@
 
 """Example of building a Gemma2 model."""
 
-from typing import List, Optional, Tuple
+from typing import Callable, Dict, List, Optional, Tuple
 
 from ai_edge_torch.generative.layers import attention
 from ai_edge_torch.generative.layers import builder
@@ -306,7 +306,11 @@ def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
   return config
 
 
-def build_2b_model(checkpoint_path: str, **kwargs) -> nn.Module:
+def build_2b_model(
+    checkpoint_path: str,
+    custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
+    **kwargs,
+) -> nn.Module:
   for tensor_names in TENSOR_NAMES_DICT.values():
     try:
       return model_builder.build_decoder_only_model(
@@ -314,6 +318,7 @@ def build_2b_model(checkpoint_path: str, **kwargs) -> nn.Module:
           config=get_model_config_2b(**kwargs),
           tensor_names=tensor_names,
           model_class=Gemma2,
+          custom_loader=custom_loader,
       )
     except KeyError as _:
       continue
ai_edge_torch/generative/examples/gemma3/convert_gemma3_to_tflite.py CHANGED
@@ -19,6 +19,7 @@ from absl import app
 from ai_edge_torch.generative.examples.gemma3 import gemma3
 from ai_edge_torch.generative.utilities import converter
 from ai_edge_torch.generative.utilities import export_config
+from ai_edge_torch.generative.utilities import loader
 
 flags = converter.define_conversion_flags(
     'gemma3-1b', default_mask_as_input=True, default_transpose_kv_cache=True
@@ -32,9 +33,13 @@ _MODEL_SIZE = flags.DEFINE_string(
 
 
 def main(_):
+  checkpoint_path = flags.FLAGS.checkpoint_path
   if _MODEL_SIZE.value == '1b':
     pytorch_model = gemma3.build_model_1b(
-        flags.FLAGS.checkpoint_path,
+        checkpoint_path,
+        custom_loader=loader.maybe_get_custom_loader(
+            checkpoint_path, flags.FLAGS.custom_checkpoint_loader
+        ),
         kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
     )
   else:
ai_edge_torch/generative/examples/hammer/convert_to_tflite.py CHANGED
@@ -19,6 +19,7 @@ from absl import app
 from ai_edge_torch.generative.examples.hammer import hammer
 from ai_edge_torch.generative.utilities import converter
 from ai_edge_torch.generative.utilities import export_config
+from ai_edge_torch.generative.utilities import loader
 
 flags = converter.define_conversion_flags('hammer')
 
@@ -36,8 +37,13 @@ _BUILDER = {
 
 
 def main(_):
+  checkpoint_path = flags.FLAGS.checkpoint_path
   pytorch_model = _BUILDER[_MODEL_SIZE.value](
-      flags.FLAGS.checkpoint_path, kv_cache_max_len=flags.FLAGS.kv_cache_max_len
+      checkpoint_path,
+      custom_loader=loader.maybe_get_custom_loader(
+          checkpoint_path, flags.FLAGS.custom_checkpoint_loader
+      ),
+      kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
   )
   converter.convert_to_tflite(
       pytorch_model,
ai_edge_torch/generative/examples/hammer/hammer.py CHANGED
@@ -15,8 +15,10 @@
 
 """Example of building Hammer 2.1 models."""
 
+from typing import Callable, Dict
 import ai_edge_torch.generative.layers.model_config as cfg
 from ai_edge_torch.generative.utilities import model_builder
+import torch
 from torch import nn
 
 TENSOR_NAMES = model_builder.TENSOR_NAMES
@@ -89,19 +91,29 @@ def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
   return config
 
 
-def build_1_5b_model(checkpoint_path: str, **kwargs) -> nn.Module:
+def build_1_5b_model(
+    checkpoint_path: str,
+    custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
+    **kwargs
+) -> nn.Module:
   return model_builder.build_decoder_only_model(
       checkpoint_path=checkpoint_path,
       config=get_1_5b_model_config(**kwargs),
       tensor_names=TENSOR_NAMES,
       model_class=Hammer,
+      custom_loader=custom_loader,
   )
 
 
-def build_0_5b_model(checkpoint_path: str, **kwargs) -> nn.Module:
+def build_0_5b_model(
+    checkpoint_path: str,
+    custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
+    **kwargs
+) -> nn.Module:
   return model_builder.build_decoder_only_model(
       checkpoint_path=checkpoint_path,
       config=get_0_5b_model_config(**kwargs),
       tensor_names=TENSOR_NAMES,
       model_class=Hammer,
+      custom_loader=custom_loader,
   )
ai_edge_torch/generative/examples/llama/convert_to_tflite.py CHANGED
@@ -19,6 +19,7 @@ from absl import app
 from ai_edge_torch.generative.examples.llama import llama
 from ai_edge_torch.generative.utilities import converter
 from ai_edge_torch.generative.utilities import export_config
+from ai_edge_torch.generative.utilities import loader
 
 
 flags = converter.define_conversion_flags('llama')
@@ -37,8 +38,13 @@ _BUILDER = {
 
 
 def main(_):
+  checkpoint_path = flags.FLAGS.checkpoint_path
   pytorch_model = _BUILDER[_MODEL_SIZE.value](
-      flags.FLAGS.checkpoint_path, kv_cache_max_len=flags.FLAGS.kv_cache_max_len
+      checkpoint_path,
+      custom_loader=loader.maybe_get_custom_loader(
+          checkpoint_path, flags.FLAGS.custom_checkpoint_loader
+      ),
+      kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
   )
   converter.convert_to_tflite(
       pytorch_model,
ai_edge_torch/generative/examples/llama/llama.py CHANGED
@@ -17,7 +17,7 @@
 
 from functools import partial
 import math
-from typing import Tuple
+from typing import Callable, Dict, Tuple
 
 import ai_edge_torch.generative.layers.model_config as cfg
 from ai_edge_torch.generative.utilities import model_builder
@@ -180,19 +180,38 @@ def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
 
 
 def _build_model(
-    checkpoint_path: str, config: cfg.ModelConfig
+    checkpoint_path: str,
+    config: cfg.ModelConfig,
+    custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
 ) -> torch.nn.Module:
   return model_builder.build_decoder_only_model(
       checkpoint_path=checkpoint_path,
       config=config,
       tensor_names=TENSOR_NAMES,
       model_class=Llama,
+      custom_loader=custom_loader,
   )
 
 
-def build_1b_model(checkpoint_path: str, **kwargs) -> torch.nn.Module:
-  return _build_model(checkpoint_path, get_1b_model_config(**kwargs))
+def build_1b_model(
+    checkpoint_path: str,
+    custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
+    **kwargs
+) -> torch.nn.Module:
+  return _build_model(
+      checkpoint_path,
+      get_1b_model_config(**kwargs),
+      custom_loader=custom_loader,
+  )
 
 
-def build_3b_model(checkpoint_path: str, **kwargs) -> torch.nn.Module:
-  return _build_model(checkpoint_path, get_3b_model_config(**kwargs))
+def build_3b_model(
+    checkpoint_path: str,
+    custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
+    **kwargs
+) -> torch.nn.Module:
+  return _build_model(
+      checkpoint_path,
+      get_3b_model_config(**kwargs),
+      custom_loader=custom_loader,
+  )
ai_edge_torch/generative/examples/moonshine/convert_moonshine_to_tflite.py CHANGED
@@ -22,7 +22,6 @@ from absl import app
 from absl import flags
 import ai_edge_torch
 from ai_edge_torch.generative.examples.moonshine import moonshine
-from ai_edge_torch.generative.utilities import converter
 import torch
 
 _CHECKPOINT_PATH = flags.DEFINE_string(
ai_edge_torch/generative/examples/openelm/convert_to_tflite.py CHANGED
@@ -19,13 +19,19 @@ from absl import app
 from ai_edge_torch.generative.examples.openelm import openelm
 from ai_edge_torch.generative.utilities import converter
 from ai_edge_torch.generative.utilities import export_config
+from ai_edge_torch.generative.utilities import loader
 
 flags = converter.define_conversion_flags("openelm")
 
 
 def main(_):
+  checkpoint_path = flags.FLAGS.checkpoint_path
   pytorch_model = openelm.build_model(
-      flags.FLAGS.checkpoint_path, kv_cache_max_len=flags.FLAGS.kv_cache_max_len
+      checkpoint_path,
+      custom_loader=loader.maybe_get_custom_loader(
+          checkpoint_path, flags.FLAGS.custom_checkpoint_loader
+      ),
+      kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
   )
   converter.convert_to_tflite(
       pytorch_model,
ai_edge_torch/generative/examples/openelm/openelm.py CHANGED
@@ -15,9 +15,11 @@
 
 """Example of building an OpenELM model."""
 
+from typing import Callable, Dict
 import ai_edge_torch.generative.layers.model_config as cfg
 from ai_edge_torch.generative.utilities import model_builder
 import ai_edge_torch.generative.utilities.loader as loading_utils
+import torch
 from torch import nn
 
 TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
@@ -118,10 +120,15 @@ def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
   return config
 
 
-def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
+def build_model(
+    checkpoint_path: str,
+    custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
+    **kwargs
+) -> nn.Module:
   return model_builder.build_decoder_only_model(
       checkpoint_path=checkpoint_path,
       config=get_model_config(**kwargs),
       tensor_names=TENSOR_NAMES,
       model_class=OpenELM,
+      custom_loader=custom_loader,
   )
ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py CHANGED
@@ -19,6 +19,7 @@ from absl import app
 from ai_edge_torch.generative.examples.paligemma import paligemma
 from ai_edge_torch.generative.utilities import converter
 from ai_edge_torch.generative.utilities import export_config
+from ai_edge_torch.generative.utilities import loader
 import torch
 
 flags = converter.define_conversion_flags('paligemma2-3b-224')
@@ -32,9 +33,13 @@ _VERSION = flags.DEFINE_enum(
 
 
 def main(_):
+  checkpoint_path = flags.FLAGS.checkpoint_path
   pytorch_model = paligemma.build_model(
-      flags.FLAGS.checkpoint_path,
+      checkpoint_path,
       version=int(_VERSION.value),
+      custom_loader=loader.maybe_get_custom_loader(
+          checkpoint_path, flags.FLAGS.custom_checkpoint_loader
+      ),
       kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
   )
 
ai_edge_torch/generative/examples/paligemma/decoder.py CHANGED
@@ -113,6 +113,7 @@ def get_decoder_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       type=cfg.NormalizationType.RMS_NORM,
       epsilon=1e-6,
       zero_centered=True,
+      enable_hlfb=True,
   )
   block_config = cfg.TransformerBlockConfig(
       attn_config=attn_config,
ai_edge_torch/generative/examples/paligemma/decoder2.py CHANGED
@@ -96,6 +96,7 @@ def get_decoder2_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
      type=cfg.NormalizationType.RMS_NORM,
      epsilon=1e-6,
      zero_centered=True,
+      enable_hlfb=True,
   )
   ff_config = cfg.FeedForwardConfig(
       type=cfg.FeedForwardType.GATED,