ai-edge-torch-nightly 0.3.0.dev20241206__py3-none-any.whl → 0.3.0.dev20241214__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (56)
  1. ai_edge_torch/__init__.py +1 -1
  2. ai_edge_torch/_config.py +52 -0
  3. ai_edge_torch/_convert/test/test_convert.py +1 -2
  4. ai_edge_torch/debug/test/test_culprit.py +8 -3
  5. ai_edge_torch/generative/examples/amd_llama_135m/amd_llama_135m.py +8 -3
  6. ai_edge_torch/generative/examples/amd_llama_135m/convert_to_tflite.py +2 -0
  7. ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py +2 -0
  8. ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +2 -0
  9. ai_edge_torch/generative/examples/gemma/gemma1.py +8 -3
  10. ai_edge_torch/generative/examples/gemma/gemma2.py +15 -8
  11. ai_edge_torch/generative/examples/llama/convert_to_tflite.py +2 -0
  12. ai_edge_torch/generative/examples/llama/llama.py +11 -17
  13. ai_edge_torch/generative/examples/openelm/convert_to_tflite.py +2 -0
  14. ai_edge_torch/generative/examples/openelm/openelm.py +8 -3
  15. ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py +2 -0
  16. ai_edge_torch/generative/examples/paligemma/decoder.py +10 -9
  17. ai_edge_torch/generative/examples/paligemma/paligemma.py +11 -1
  18. ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py +2 -0
  19. ai_edge_torch/generative/examples/phi/convert_to_tflite.py +2 -0
  20. ai_edge_torch/generative/examples/phi/phi2.py +8 -3
  21. ai_edge_torch/generative/examples/phi/phi3.py +7 -9
  22. ai_edge_torch/generative/examples/qwen/convert_to_tflite.py +2 -0
  23. ai_edge_torch/generative/examples/qwen/qwen.py +12 -9
  24. ai_edge_torch/generative/examples/smollm/convert_to_tflite.py +3 -0
  25. ai_edge_torch/generative/examples/smollm/smollm.py +8 -3
  26. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +12 -2
  27. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +2 -0
  28. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +8 -3
  29. ai_edge_torch/generative/layers/attention.py +2 -6
  30. ai_edge_torch/generative/layers/kv_cache.py +24 -18
  31. ai_edge_torch/generative/layers/normalization.py +1 -3
  32. ai_edge_torch/generative/test/test_kv_cache.py +3 -3
  33. ai_edge_torch/generative/test/test_model_conversion.py +12 -14
  34. ai_edge_torch/generative/test/test_model_conversion_large.py +63 -59
  35. ai_edge_torch/generative/test/utils.py +31 -6
  36. ai_edge_torch/generative/utilities/converter.py +25 -4
  37. ai_edge_torch/generative/utilities/model_builder.py +24 -4
  38. ai_edge_torch/generative/utilities/verifier.py +16 -2
  39. ai_edge_torch/lowertools/_shim.py +4 -2
  40. ai_edge_torch/lowertools/test_utils.py +4 -2
  41. ai_edge_torch/odml_torch/lowerings/__init__.py +1 -1
  42. ai_edge_torch/odml_torch/lowerings/_basic.py +5 -3
  43. ai_edge_torch/odml_torch/lowerings/_convolution.py +3 -1
  44. ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +28 -2
  45. ai_edge_torch/odml_torch/lowerings/_layer_norm.py +11 -2
  46. ai_edge_torch/odml_torch/lowerings/_quantized_decomposed.py +9 -9
  47. ai_edge_torch/odml_torch/lowerings/decomp.py +65 -0
  48. ai_edge_torch/odml_torch/lowerings/registry.py +0 -32
  49. ai_edge_torch/version.py +1 -1
  50. {ai_edge_torch_nightly-0.3.0.dev20241206.dist-info → ai_edge_torch_nightly-0.3.0.dev20241214.dist-info}/METADATA +7 -5
  51. {ai_edge_torch_nightly-0.3.0.dev20241206.dist-info → ai_edge_torch_nightly-0.3.0.dev20241214.dist-info}/RECORD +54 -54
  52. ai_edge_torch/config.py +0 -27
  53. ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py +0 -283
  54. {ai_edge_torch_nightly-0.3.0.dev20241206.dist-info → ai_edge_torch_nightly-0.3.0.dev20241214.dist-info}/LICENSE +0 -0
  55. {ai_edge_torch_nightly-0.3.0.dev20241206.dist-info → ai_edge_torch_nightly-0.3.0.dev20241214.dist-info}/WHEEL +0 -0
  56. {ai_edge_torch_nightly-0.3.0.dev20241206.dist-info → ai_edge_torch_nightly-0.3.0.dev20241214.dist-info}/top_level.txt +0 -0
ai_edge_torch/__init__.py CHANGED
@@ -13,13 +13,13 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
+ from ai_edge_torch._config import config
16
17
  from ai_edge_torch._convert.converter import convert
17
18
  from ai_edge_torch._convert.converter import signature
18
19
  from ai_edge_torch._convert.to_channel_last_io import to_channel_last_io
19
20
  from ai_edge_torch.model import Model
20
21
  from ai_edge_torch.version import __version__
21
22
 
22
-
23
23
  def load(path: str) -> Model:
24
24
  """Imports an ai_edge_torch model from disk.
25
25
 
@@ -0,0 +1,52 @@
1
+ # Copyright 2024 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """Provides a configuration for the ai-edge-torch."""
17
+
18
+ import functools
19
+ import logging
20
+ import os
21
+
22
+ __all__ = ["config"]
23
+
24
+
25
+ class _Config:
26
+ """ai-edge-torch global configs."""
27
+
28
+ @property
29
+ @functools.cache # pylint: disable=method-cache-max-size-none
30
+ def use_torch_xla(self) -> bool:
31
+ """True if using torch_xla to lower torch ops to StableHLO.
32
+
33
+ To use torch_xla as the lowering backend, set environment variable
34
+ `USE_TORCH_XLA` to "true".
35
+ """
36
+ var = os.environ.get("USE_TORCH_XLA", "false")
37
+ var = var.lower().strip()
38
+ if var in ("y", "yes", "t", "true", "on", "1"):
39
+ return True
40
+ elif var in ("n", "no", "f", "false", "off", "0"):
41
+ return False
42
+ else:
43
+ logging.warning("Invalid USE_TORCH_XLA value is ignored: %s.", var)
44
+ return False
45
+
46
+ @property
47
+ def in_oss(self) -> bool:
48
+ """True if the code is not running in google internal environment."""
49
+ return True
50
+
51
+
52
+ config = _Config()
@@ -19,7 +19,6 @@ import os
19
19
  from typing import Tuple
20
20
 
21
21
  import ai_edge_torch
22
- from ai_edge_torch import config
23
22
  from ai_edge_torch._convert import conversion_utils
24
23
  from ai_edge_torch.quantize import pt2e_quantizer
25
24
  from ai_edge_torch.testing import model_coverage
@@ -292,7 +291,7 @@ class TestConvert(googletest.TestCase):
292
291
  self.assertTrue(result)
293
292
 
294
293
  @googletest.skipIf(
295
- not config.Config.use_torch_xla,
294
+ not ai_edge_torch.config.use_torch_xla,
296
295
  reason="Shape polymorphism is not yet support with odml_torch.",
297
296
  )
298
297
  def test_convert_model_with_dynamic_batch(self):
@@ -15,14 +15,14 @@
15
15
 
16
16
 
17
17
  import ast
18
- import io
19
- import sys
20
18
 
21
- from ai_edge_torch.debug import find_culprits
19
+ import ai_edge_torch.debug
22
20
  import torch
23
21
 
24
22
  from absl.testing import absltest as googletest
25
23
 
24
+ find_culprits = ai_edge_torch.debug.find_culprits
25
+
26
26
  _test_culprit_lib = torch.library.Library("test_culprit", "DEF")
27
27
 
28
28
  _test_culprit_lib.define("non_lowerable_op(Tensor x) -> Tensor")
@@ -52,6 +52,11 @@ class BadModel(torch.nn.Module):
52
52
 
53
53
  class TestCulprit(googletest.TestCase):
54
54
 
55
+ def setUp(self):
56
+ super().setUp()
57
+ torch.manual_seed(0)
58
+ torch._dynamo.reset()
59
+
55
60
  def test_find_culprits(self):
56
61
  model = BadModel().eval()
57
62
  args = (torch.rand(10),)
@@ -17,10 +17,16 @@
17
17
 
18
18
  import ai_edge_torch.generative.layers.model_config as cfg
19
19
  from ai_edge_torch.generative.utilities import model_builder
20
+ from torch import nn
20
21
 
21
22
  TENSOR_NAMES = model_builder.TENSOR_NAMES_WITH_SEPARATE_LM_HEAD
22
23
 
23
24
 
25
+ class AmdLlama(model_builder.DecoderOnlyModel):
26
+ """An AMD-Llama model built from the Edge Generative API layers."""
27
+ pass
28
+
29
+
24
30
  def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
25
31
  """Returns the model config for an AMD-Llama-135m model.
26
32
 
@@ -72,11 +78,10 @@ def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
72
78
  return config
73
79
 
74
80
 
75
- def build_model(
76
- checkpoint_path: str, **kwargs
77
- ) -> model_builder.DecoderOnlyModel:
81
+ def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
78
82
  return model_builder.build_decoder_only_model(
79
83
  checkpoint_path=checkpoint_path,
80
84
  config=get_model_config(**kwargs),
81
85
  tensor_names=TENSOR_NAMES,
86
+ model_class=AmdLlama
82
87
  )
@@ -22,6 +22,7 @@ from absl import app
22
22
  from absl import flags
23
23
  from ai_edge_torch.generative.examples.amd_llama_135m import amd_llama_135m
24
24
  from ai_edge_torch.generative.utilities import converter
25
+ from ai_edge_torch.generative.utilities.model_builder import ExportConfig
25
26
 
26
27
  _CHECKPOINT_PATH = flags.DEFINE_string(
27
28
  'checkpoint_path',
@@ -61,6 +62,7 @@ def main(_):
61
62
  tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
62
63
  prefill_seq_len=_PREFILL_SEQ_LEN.value,
63
64
  quantize=_QUANTIZE.value,
65
+ export_config=ExportConfig(),
64
66
  )
65
67
 
66
68
 
@@ -22,6 +22,7 @@ from absl import app
22
22
  from absl import flags
23
23
  from ai_edge_torch.generative.examples.gemma import gemma1
24
24
  from ai_edge_torch.generative.utilities import converter
25
+ from ai_edge_torch.generative.utilities.model_builder import ExportConfig
25
26
 
26
27
  _CHECKPOINT_PATH = flags.DEFINE_string(
27
28
  'checkpoint_path',
@@ -61,6 +62,7 @@ def main(_):
61
62
  tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
62
63
  prefill_seq_len=_PREFILL_SEQ_LENS.value,
63
64
  quantize=_QUANTIZE.value,
65
+ export_config=ExportConfig(),
64
66
  )
65
67
 
66
68
 
@@ -22,6 +22,7 @@ from absl import app
22
22
  from absl import flags
23
23
  from ai_edge_torch.generative.examples.gemma import gemma2
24
24
  from ai_edge_torch.generative.utilities import converter
25
+ from ai_edge_torch.generative.utilities.model_builder import ExportConfig
25
26
 
26
27
  _CHECKPOINT_PATH = flags.DEFINE_string(
27
28
  'checkpoint_path',
@@ -61,6 +62,7 @@ def main(_):
61
62
  tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
62
63
  prefill_seq_len=_PREFILL_SEQ_LENS.value,
63
64
  quantize=_QUANTIZE.value,
65
+ export_config=ExportConfig(),
64
66
  )
65
67
 
66
68
 
@@ -18,6 +18,7 @@
18
18
  import ai_edge_torch.generative.layers.model_config as cfg
19
19
  from ai_edge_torch.generative.utilities import model_builder
20
20
  import ai_edge_torch.generative.utilities.loader as loading_utils
21
+ from torch import nn
21
22
 
22
23
  TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
23
24
  ff_up_proj="model.layers.{}.mlp.up_proj",
@@ -33,6 +34,11 @@ TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
33
34
  )
34
35
 
35
36
 
37
+ class Gemma1(model_builder.DecoderOnlyModel):
38
+ """A Gemma1 model built from the Edge Generative API layers."""
39
+ pass
40
+
41
+
36
42
  def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
37
43
  """Returns the model config for a Gemma 2B model.
38
44
 
@@ -91,11 +97,10 @@ def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
91
97
  return config
92
98
 
93
99
 
94
- def build_2b_model(
95
- checkpoint_path: str, **kwargs
96
- ) -> model_builder.DecoderOnlyModel:
100
+ def build_2b_model(checkpoint_path: str, **kwargs) -> nn.Module:
97
101
  return model_builder.build_decoder_only_model(
98
102
  checkpoint_path=checkpoint_path,
99
103
  config=get_model_config_2b(**kwargs),
100
104
  tensor_names=TENSOR_NAMES,
105
+ model_class=Gemma1,
101
106
  )
@@ -22,6 +22,7 @@ from ai_edge_torch.generative.layers import builder
22
22
  from ai_edge_torch.generative.layers import kv_cache as kv_utils
23
23
  import ai_edge_torch.generative.layers.attention_utils as attn_utils
24
24
  import ai_edge_torch.generative.layers.model_config as cfg
25
+ from ai_edge_torch.generative.utilities import model_builder
25
26
  import ai_edge_torch.generative.utilities.loader as loading_utils
26
27
  import torch
27
28
  from torch import nn
@@ -132,6 +133,7 @@ class Gemma2(nn.Module):
132
133
  tokens: torch.Tensor,
133
134
  input_pos: torch.Tensor,
134
135
  kv_cache: kv_utils.KVCache,
136
+ export_config: Optional[model_builder.ExportConfig] = None,
135
137
  ) -> dict[torch.Tensor, kv_utils.KVCache]:
136
138
  _, seq_len = tokens.size()
137
139
  assert self.config.max_seq_len >= seq_len, (
@@ -162,6 +164,13 @@ class Gemma2(nn.Module):
162
164
  updated_kv_entires.append(kv_entry)
163
165
  updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
164
166
 
167
+ if export_config is not None:
168
+ if (
169
+ torch.numel(input_pos) > 1
170
+ and not export_config.output_logits_on_prefill
171
+ ):
172
+ return {"kv_cache": updated_kv_cache}
173
+
165
174
  x = self.final_norm(x)
166
175
  res = self.lm_head(x) # (b, t, vocab_size)
167
176
  if self.config.final_logit_softcap is not None:
@@ -250,11 +259,9 @@ def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
250
259
 
251
260
 
252
261
  def build_2b_model(checkpoint_path: str, **kwargs) -> nn.Module:
253
- config = get_model_config_2b(**kwargs)
254
- model = Gemma2(config)
255
- loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
256
- # Since embedding and lm-head use the same weight, we need to set strict
257
- # to False.
258
- loader.load(model, strict=False)
259
- model.eval()
260
- return model
262
+ return model_builder.build_decoder_only_model(
263
+ checkpoint_path=checkpoint_path,
264
+ config=get_model_config_2b(**kwargs),
265
+ tensor_names=TENSOR_NAMES,
266
+ model_class=Gemma2,
267
+ )
@@ -22,6 +22,7 @@ from absl import app
22
22
  from absl import flags
23
23
  from ai_edge_torch.generative.examples.llama import llama
24
24
  from ai_edge_torch.generative.utilities import converter
25
+ from ai_edge_torch.generative.utilities.model_builder import ExportConfig
25
26
 
26
27
  _MODEL_SIZE = flags.DEFINE_enum(
27
28
  'model_size',
@@ -72,6 +73,7 @@ def main(_):
72
73
  tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
73
74
  prefill_seq_len=_PREFILL_SEQ_LENS.value,
74
75
  quantize=_QUANTIZE.value,
76
+ export_config=ExportConfig(),
75
77
  )
76
78
 
77
79
 
@@ -20,7 +20,6 @@ from typing import Tuple
20
20
 
21
21
  import ai_edge_torch.generative.layers.model_config as cfg
22
22
  from ai_edge_torch.generative.utilities import model_builder
23
- import ai_edge_torch.generative.utilities.loader as loading_utils
24
23
  import torch
25
24
 
26
25
  TENSOR_NAMES = model_builder.TENSOR_NAMES
@@ -177,23 +176,18 @@ def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
177
176
 
178
177
  def _build_model(
179
178
  checkpoint_path: str, config: cfg.ModelConfig
180
- ) -> model_builder.DecoderOnlyModel:
181
- model = Llama(config)
182
- loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
183
- # Since embedding and lm-head use the same weight, we need to set strict
184
- # to False.
185
- loader.load(model, strict=False)
186
- model.eval()
187
- return model
188
-
189
-
190
- def build_1b_model(
191
- checkpoint_path: str, **kwargs
192
- ) -> model_builder.DecoderOnlyModel:
179
+ ) -> torch.nn.Module:
180
+ return model_builder.build_decoder_only_model(
181
+ checkpoint_path=checkpoint_path,
182
+ config=config,
183
+ tensor_names=TENSOR_NAMES,
184
+ model_class=Llama,
185
+ )
186
+
187
+
188
+ def build_1b_model(checkpoint_path: str, **kwargs) -> torch.nn.Module:
193
189
  return _build_model(checkpoint_path, get_1b_model_config(**kwargs))
194
190
 
195
191
 
196
- def build_3b_model(
197
- checkpoint_path: str, **kwargs
198
- ) -> model_builder.DecoderOnlyModel:
192
+ def build_3b_model(checkpoint_path: str, **kwargs) -> torch.nn.Module:
199
193
  return _build_model(checkpoint_path, get_3b_model_config(**kwargs))
@@ -22,6 +22,7 @@ from absl import app
22
22
  from absl import flags
23
23
  from ai_edge_torch.generative.examples.openelm import openelm
24
24
  from ai_edge_torch.generative.utilities import converter
25
+ from ai_edge_torch.generative.utilities.model_builder import ExportConfig
25
26
 
26
27
  _CHECKPOINT_PATH = flags.DEFINE_string(
27
28
  'checkpoint_path',
@@ -64,6 +65,7 @@ def main(_):
64
65
  tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
65
66
  prefill_seq_len=_PREFILL_SEQ_LENS.value,
66
67
  quantize=_QUANTIZE.value,
68
+ export_config=ExportConfig(),
67
69
  )
68
70
 
69
71
 
@@ -18,6 +18,7 @@
18
18
  import ai_edge_torch.generative.layers.model_config as cfg
19
19
  from ai_edge_torch.generative.utilities import model_builder
20
20
  import ai_edge_torch.generative.utilities.loader as loading_utils
21
+ from torch import nn
21
22
 
22
23
  TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
23
24
  ff_up_proj="transformer.layers.{}.ffn.proj_1",
@@ -34,6 +35,11 @@ TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
34
35
  )
35
36
 
36
37
 
38
+ class OpenELM(model_builder.DecoderOnlyModel):
39
+ """An OpenELM model built from the Edge Generative API layers."""
40
+ pass
41
+
42
+
37
43
  def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
38
44
  """Returns the model config for an OpenELM model.
39
45
 
@@ -112,11 +118,10 @@ def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
112
118
  return config
113
119
 
114
120
 
115
- def build_model(
116
- checkpoint_path: str, **kwargs
117
- ) -> model_builder.DecoderOnlyModel:
121
+ def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
118
122
  return model_builder.build_decoder_only_model(
119
123
  checkpoint_path=checkpoint_path,
120
124
  config=get_model_config(**kwargs),
121
125
  tensor_names=TENSOR_NAMES,
126
+ model_class=OpenELM,
122
127
  )
@@ -26,6 +26,7 @@ from absl import app
26
26
  from absl import flags
27
27
  from ai_edge_torch.generative.examples.paligemma import paligemma
28
28
  from ai_edge_torch.generative.utilities import converter
29
+ from ai_edge_torch.generative.utilities.model_builder import ExportConfig
29
30
  import torch
30
31
 
31
32
  _CHECKPOINT_PATH = flags.DEFINE_string(
@@ -73,6 +74,7 @@ def main(_):
73
74
  pixel_values_size=torch.Size(_PIXEL_VALUES_SIZE.value),
74
75
  quantize=_QUANTIZE.value,
75
76
  config=pytorch_model.config.decoder_config,
77
+ export_config=ExportConfig(),
76
78
  )
77
79
 
78
80
 
@@ -15,6 +15,8 @@
15
15
 
16
16
  """Example of building a decoder of PaliGemma 3B model which is Gemma1."""
17
17
 
18
+ from typing import Optional
19
+
18
20
  from ai_edge_torch.generative.layers import kv_cache as kv_utils
19
21
  import ai_edge_torch.generative.layers.model_config as cfg
20
22
  from ai_edge_torch.generative.utilities import model_builder
@@ -51,6 +53,7 @@ class Decoder(model_builder.DecoderOnlyModel):
51
53
  input_pos: torch.Tensor,
52
54
  kv_cache: kv_utils.KVCache,
53
55
  input_embeds: torch.Tensor = None,
56
+ export_config: Optional[model_builder.ExportConfig] = None,
54
57
  ) -> dict[torch.Tensor, kv_utils.KVCache]:
55
58
  if input_embeds is None:
56
59
  return super().forward(tokens, input_pos, kv_cache)
@@ -130,12 +133,10 @@ def get_fake_decoder_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
130
133
  return config
131
134
 
132
135
 
133
- def build_decoder(
134
- checkpoint_path: str, **kwargs
135
- ) -> model_builder.DecoderOnlyModel:
136
- decoder = Decoder(get_decoder_config(**kwargs))
137
- loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
138
- # Loose the strictness because only decoder is being loaded.
139
- loader.load(decoder, strict=False)
140
- decoder.eval()
141
- return decoder
136
+ def build_decoder(checkpoint_path: str, **kwargs) -> torch.nn.Module:
137
+ return model_builder.build_decoder_only_model(
138
+ checkpoint_path=checkpoint_path,
139
+ config=get_decoder_config(**kwargs),
140
+ tensor_names=TENSOR_NAMES,
141
+ model_class=Decoder,
142
+ )
@@ -16,11 +16,13 @@
16
16
  """Example of building a full-stack of PaliGemma model."""
17
17
 
18
18
  from dataclasses import dataclass
19
+ from typing import Optional
19
20
 
20
21
  from ai_edge_torch.generative.examples.paligemma import decoder
21
22
  from ai_edge_torch.generative.examples.paligemma import image_encoder
22
23
  import ai_edge_torch.generative.layers.kv_cache as kv_utils
23
24
  import ai_edge_torch.generative.layers.model_config as cfg
25
+ from ai_edge_torch.generative.utilities import model_builder
24
26
  import ai_edge_torch.generative.utilities.loader as loading_utils
25
27
  import torch
26
28
  from torch import nn
@@ -67,9 +69,16 @@ class PaliGemma(nn.Module):
67
69
  input_pos: torch.Tensor,
68
70
  kv_cache: kv_utils.KVCache,
69
71
  pixel_values: torch.Tensor = None,
72
+ export_config: Optional[model_builder.ExportConfig] = None,
70
73
  ) -> dict[torch.Tensor, kv_utils.KVCache]:
71
74
  if pixel_values is None:
72
- return self.decoder(tokens, input_pos, kv_cache)
75
+ return self.decoder(
76
+ tokens=tokens,
77
+ input_pos=input_pos,
78
+ kv_cache=kv_cache,
79
+ input_embeds=None,
80
+ export_config=export_config
81
+ )
73
82
 
74
83
  input_embeds = self.decoder.tok_embedding(tokens)
75
84
 
@@ -100,6 +109,7 @@ class PaliGemma(nn.Module):
100
109
  input_pos=input_pos,
101
110
  kv_cache=kv_cache,
102
111
  input_embeds=input_embeds,
112
+ export_config=export_config,
103
113
  )
104
114
 
105
115
 
@@ -22,6 +22,7 @@ from absl import app
22
22
  from absl import flags
23
23
  from ai_edge_torch.generative.examples.phi import phi3
24
24
  from ai_edge_torch.generative.utilities import converter
25
+ from ai_edge_torch.generative.utilities.model_builder import ExportConfig
25
26
 
26
27
  _CHECKPOINT_PATH = flags.DEFINE_string(
27
28
  'checkpoint_path',
@@ -61,6 +62,7 @@ def main(_):
61
62
  tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
62
63
  prefill_seq_len=_PREFILL_SEQ_LENS.value,
63
64
  quantize=_QUANTIZE.value,
65
+ export_config=ExportConfig(),
64
66
  )
65
67
 
66
68
 
@@ -22,6 +22,7 @@ from absl import app
22
22
  from absl import flags
23
23
  from ai_edge_torch.generative.examples.phi import phi2
24
24
  from ai_edge_torch.generative.utilities import converter
25
+ from ai_edge_torch.generative.utilities.model_builder import ExportConfig
25
26
 
26
27
  _CHECKPOINT_PATH = flags.DEFINE_string(
27
28
  'checkpoint_path',
@@ -61,6 +62,7 @@ def main(_):
61
62
  tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
62
63
  prefill_seq_len=_PREFILL_SEQ_LENS.value,
63
64
  quantize=_QUANTIZE.value,
65
+ export_config=ExportConfig(),
64
66
  )
65
67
 
66
68
 
@@ -18,6 +18,7 @@
18
18
  import ai_edge_torch.generative.layers.model_config as cfg
19
19
  from ai_edge_torch.generative.utilities import model_builder
20
20
  import ai_edge_torch.generative.utilities.loader as loading_utils
21
+ from torch import nn
21
22
 
22
23
  TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
23
24
  ff_up_proj="model.layers.{}.mlp.fc1",
@@ -33,6 +34,11 @@ TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
33
34
  )
34
35
 
35
36
 
37
+ class Phi2(model_builder.DecoderOnlyModel):
38
+ """A Phi-2 model built from the Edge Generative API layers."""
39
+ pass
40
+
41
+
36
42
  def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
37
43
  """Returns the model config for a Phi-2 model.
38
44
 
@@ -92,11 +98,10 @@ def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
92
98
  return config
93
99
 
94
100
 
95
- def build_model(
96
- checkpoint_path: str, **kwargs
97
- ) -> model_builder.DecoderOnlyModel:
101
+ def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
98
102
  return model_builder.build_decoder_only_model(
99
103
  checkpoint_path=checkpoint_path,
100
104
  config=get_model_config(**kwargs),
101
105
  tensor_names=TENSOR_NAMES,
106
+ model_class=Phi2,
102
107
  )
@@ -207,13 +207,11 @@ def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
207
207
  return config
208
208
 
209
209
 
210
- def build_model(
211
- checkpoint_path: str, **kwargs
212
- ) -> model_builder.DecoderOnlyModel:
210
+ def build_model(checkpoint_path: str, **kwargs) -> torch.nn.Module:
213
211
  """Instantiates the model instance and load checkpoint if provided."""
214
- config = get_model_config(**kwargs)
215
- model = Phi3_5Mini(config)
216
- loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
217
- loader.load(model)
218
- model.eval()
219
- return model
212
+ return model_builder.build_decoder_only_model(
213
+ checkpoint_path=checkpoint_path,
214
+ config=get_model_config(**kwargs),
215
+ tensor_names=TENSOR_NAMES,
216
+ model_class=Phi3_5Mini,
217
+ )
@@ -22,6 +22,7 @@ from absl import app
22
22
  from absl import flags
23
23
  from ai_edge_torch.generative.examples.qwen import qwen
24
24
  from ai_edge_torch.generative.utilities import converter
25
+ from ai_edge_torch.generative.utilities.model_builder import ExportConfig
25
26
 
26
27
  _MODEL_SIZE = flags.DEFINE_enum(
27
28
  'model_size',
@@ -76,6 +77,7 @@ def main(_):
76
77
  tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
77
78
  prefill_seq_len=_PREFILL_SEQ_LENS.value,
78
79
  quantize=_QUANTIZE.value,
80
+ export_config=ExportConfig(),
79
81
  )
80
82
 
81
83
 
@@ -17,10 +17,16 @@
17
17
 
18
18
  import ai_edge_torch.generative.layers.model_config as cfg
19
19
  from ai_edge_torch.generative.utilities import model_builder
20
+ from torch import nn
20
21
 
21
22
  TENSOR_NAMES = model_builder.TENSOR_NAMES
22
23
 
23
24
 
25
+ class Qwen(model_builder.DecoderOnlyModel):
26
+ """A Qwen model built from the Edge Generative API layers."""
27
+ pass
28
+
29
+
24
30
  def get_3b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
25
31
  """Returns the model config for a Qwen 2.5 3B model.
26
32
 
@@ -101,31 +107,28 @@ def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
101
107
  return config
102
108
 
103
109
 
104
- def build_3b_model(
105
- checkpoint_path: str, **kwargs
106
- ) -> model_builder.DecoderOnlyModel:
110
+ def build_3b_model(checkpoint_path: str, **kwargs) -> nn.Module:
107
111
  return model_builder.build_decoder_only_model(
108
112
  checkpoint_path=checkpoint_path,
109
113
  config=get_3b_model_config(**kwargs),
110
114
  tensor_names=TENSOR_NAMES,
115
+ model_class=Qwen,
111
116
  )
112
117
 
113
118
 
114
- def build_1_5b_model(
115
- checkpoint_path: str, **kwargs
116
- ) -> model_builder.DecoderOnlyModel:
119
+ def build_1_5b_model(checkpoint_path: str, **kwargs) -> nn.Module:
117
120
  return model_builder.build_decoder_only_model(
118
121
  checkpoint_path=checkpoint_path,
119
122
  config=get_1_5b_model_config(**kwargs),
120
123
  tensor_names=TENSOR_NAMES,
124
+ model_class=Qwen,
121
125
  )
122
126
 
123
127
 
124
- def build_0_5b_model(
125
- checkpoint_path: str, **kwargs
126
- ) -> model_builder.DecoderOnlyModel:
128
+ def build_0_5b_model(checkpoint_path: str, **kwargs) -> nn.Module:
127
129
  return model_builder.build_decoder_only_model(
128
130
  checkpoint_path=checkpoint_path,
129
131
  config=get_0_5b_model_config(**kwargs),
130
132
  tensor_names=TENSOR_NAMES,
133
+ model_class=Qwen,
131
134
  )
@@ -22,6 +22,7 @@ from absl import app
22
22
  from absl import flags
23
23
  from ai_edge_torch.generative.examples.smollm import smollm
24
24
  from ai_edge_torch.generative.utilities import converter
25
+ from ai_edge_torch.generative.utilities.model_builder import ExportConfig
25
26
 
26
27
  _CHECKPOINT_PATH = flags.DEFINE_string(
27
28
  'checkpoint_path',
@@ -54,6 +55,7 @@ def main(_):
54
55
  pytorch_model = smollm.build_model(
55
56
  _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
56
57
  )
58
+
57
59
  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
58
60
  output_filename = f'smollm_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
59
61
  converter.convert_to_tflite(
@@ -61,6 +63,7 @@ def main(_):
61
63
  tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
62
64
  prefill_seq_len=_PREFILL_SEQ_LENS.value,
63
65
  quantize=_QUANTIZE.value,
66
+ export_config=ExportConfig(),
64
67
  )
65
68
 
66
69