ai-edge-torch-nightly 0.3.0.dev20241206__py3-none-any.whl → 0.3.0.dev20241214__py3-none-any.whl

Files changed (56)
  1. ai_edge_torch/__init__.py +1 -1
  2. ai_edge_torch/_config.py +52 -0
  3. ai_edge_torch/_convert/test/test_convert.py +1 -2
  4. ai_edge_torch/debug/test/test_culprit.py +8 -3
  5. ai_edge_torch/generative/examples/amd_llama_135m/amd_llama_135m.py +8 -3
  6. ai_edge_torch/generative/examples/amd_llama_135m/convert_to_tflite.py +2 -0
  7. ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py +2 -0
  8. ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +2 -0
  9. ai_edge_torch/generative/examples/gemma/gemma1.py +8 -3
  10. ai_edge_torch/generative/examples/gemma/gemma2.py +15 -8
  11. ai_edge_torch/generative/examples/llama/convert_to_tflite.py +2 -0
  12. ai_edge_torch/generative/examples/llama/llama.py +11 -17
  13. ai_edge_torch/generative/examples/openelm/convert_to_tflite.py +2 -0
  14. ai_edge_torch/generative/examples/openelm/openelm.py +8 -3
  15. ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py +2 -0
  16. ai_edge_torch/generative/examples/paligemma/decoder.py +10 -9
  17. ai_edge_torch/generative/examples/paligemma/paligemma.py +11 -1
  18. ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py +2 -0
  19. ai_edge_torch/generative/examples/phi/convert_to_tflite.py +2 -0
  20. ai_edge_torch/generative/examples/phi/phi2.py +8 -3
  21. ai_edge_torch/generative/examples/phi/phi3.py +7 -9
  22. ai_edge_torch/generative/examples/qwen/convert_to_tflite.py +2 -0
  23. ai_edge_torch/generative/examples/qwen/qwen.py +12 -9
  24. ai_edge_torch/generative/examples/smollm/convert_to_tflite.py +3 -0
  25. ai_edge_torch/generative/examples/smollm/smollm.py +8 -3
  26. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +12 -2
  27. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +2 -0
  28. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +8 -3
  29. ai_edge_torch/generative/layers/attention.py +2 -6
  30. ai_edge_torch/generative/layers/kv_cache.py +24 -18
  31. ai_edge_torch/generative/layers/normalization.py +1 -3
  32. ai_edge_torch/generative/test/test_kv_cache.py +3 -3
  33. ai_edge_torch/generative/test/test_model_conversion.py +12 -14
  34. ai_edge_torch/generative/test/test_model_conversion_large.py +63 -59
  35. ai_edge_torch/generative/test/utils.py +31 -6
  36. ai_edge_torch/generative/utilities/converter.py +25 -4
  37. ai_edge_torch/generative/utilities/model_builder.py +24 -4
  38. ai_edge_torch/generative/utilities/verifier.py +16 -2
  39. ai_edge_torch/lowertools/_shim.py +4 -2
  40. ai_edge_torch/lowertools/test_utils.py +4 -2
  41. ai_edge_torch/odml_torch/lowerings/__init__.py +1 -1
  42. ai_edge_torch/odml_torch/lowerings/_basic.py +5 -3
  43. ai_edge_torch/odml_torch/lowerings/_convolution.py +3 -1
  44. ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +28 -2
  45. ai_edge_torch/odml_torch/lowerings/_layer_norm.py +11 -2
  46. ai_edge_torch/odml_torch/lowerings/_quantized_decomposed.py +9 -9
  47. ai_edge_torch/odml_torch/lowerings/decomp.py +65 -0
  48. ai_edge_torch/odml_torch/lowerings/registry.py +0 -32
  49. ai_edge_torch/version.py +1 -1
  50. {ai_edge_torch_nightly-0.3.0.dev20241206.dist-info → ai_edge_torch_nightly-0.3.0.dev20241214.dist-info}/METADATA +7 -5
  51. {ai_edge_torch_nightly-0.3.0.dev20241206.dist-info → ai_edge_torch_nightly-0.3.0.dev20241214.dist-info}/RECORD +54 -54
  52. ai_edge_torch/config.py +0 -27
  53. ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py +0 -283
  54. {ai_edge_torch_nightly-0.3.0.dev20241206.dist-info → ai_edge_torch_nightly-0.3.0.dev20241214.dist-info}/LICENSE +0 -0
  55. {ai_edge_torch_nightly-0.3.0.dev20241206.dist-info → ai_edge_torch_nightly-0.3.0.dev20241214.dist-info}/WHEEL +0 -0
  56. {ai_edge_torch_nightly-0.3.0.dev20241206.dist-info → ai_edge_torch_nightly-0.3.0.dev20241214.dist-info}/top_level.txt +0 -0
ai_edge_torch/__init__.py CHANGED
@@ -13,13 +13,13 @@
 # limitations under the License.
 # ==============================================================================
 
+from ai_edge_torch._config import config
 from ai_edge_torch._convert.converter import convert
 from ai_edge_torch._convert.converter import signature
 from ai_edge_torch._convert.to_channel_last_io import to_channel_last_io
 from ai_edge_torch.model import Model
 from ai_edge_torch.version import __version__
 
-
 def load(path: str) -> Model:
   """Imports an ai_edge_torch model from disk.
 
ai_edge_torch/_config.py ADDED
@@ -0,0 +1,52 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Provides a configuration for the ai-edge-torch."""
+
+import functools
+import logging
+import os
+
+__all__ = ["config"]
+
+
+class _Config:
+  """ai-edge-torch global configs."""
+
+  @property
+  @functools.cache  # pylint: disable=method-cache-max-size-none
+  def use_torch_xla(self) -> bool:
+    """True if using torch_xla to lower torch ops to StableHLO.
+
+    To use torch_xla as the lowering backend, set environment variable
+    `USE_TORCH_XLA` to "true".
+    """
+    var = os.environ.get("USE_TORCH_XLA", "false")
+    var = var.lower().strip()
+    if var in ("y", "yes", "t", "true", "on", "1"):
+      return True
+    elif var in ("n", "no", "f", "false", "off", "0"):
+      return False
+    else:
+      logging.warning("Invalid USE_TORCH_XLA value is ignored: %s.", var)
+      return False
+
+  @property
+  def in_oss(self) -> bool:
+    """True if the code is not running in google internal environment."""
+    return True
+
+
+config = _Config()
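Taken together with the __init__.py hunk above, this module replaces the old ai_edge_torch/config.py (deleted later in this diff) with a lazily evaluated singleton re-exported as ai_edge_torch.config. A minimal usage sketch, assuming a nightly build that includes this change:

    import os

    # Set before the property is first read: functools.cache memoizes the
    # result of the first access, so later env changes have no effect.
    os.environ["USE_TORCH_XLA"] = "true"

    import ai_edge_torch

    if ai_edge_torch.config.use_torch_xla:
      print("lowering torch ops to StableHLO via torch_xla")
    else:
      print("lowering via odml_torch")
    print(ai_edge_torch.config.in_oss)  # always True in the OSS release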
ai_edge_torch/_convert/test/test_convert.py CHANGED
@@ -19,7 +19,6 @@ import os
 from typing import Tuple
 
 import ai_edge_torch
-from ai_edge_torch import config
 from ai_edge_torch._convert import conversion_utils
 from ai_edge_torch.quantize import pt2e_quantizer
 from ai_edge_torch.testing import model_coverage
@@ -292,7 +291,7 @@ class TestConvert(googletest.TestCase):
     self.assertTrue(result)
 
   @googletest.skipIf(
-      not config.Config.use_torch_xla,
+      not ai_edge_torch.config.use_torch_xla,
       reason="Shape polymorphism is not yet support with odml_torch.",
   )
   def test_convert_model_with_dynamic_batch(self):
ai_edge_torch/debug/test/test_culprit.py CHANGED
@@ -15,14 +15,14 @@
 
 
 import ast
-import io
-import sys
 
-from ai_edge_torch.debug import find_culprits
+import ai_edge_torch.debug
 import torch
 
 from absl.testing import absltest as googletest
 
+find_culprits = ai_edge_torch.debug.find_culprits
+
 _test_culprit_lib = torch.library.Library("test_culprit", "DEF")
 
 _test_culprit_lib.define("non_lowerable_op(Tensor x) -> Tensor")
@@ -52,6 +52,11 @@ class BadModel(torch.nn.Module):
 
 class TestCulprit(googletest.TestCase):
 
+  def setUp(self):
+    super().setUp()
+    torch.manual_seed(0)
+    torch._dynamo.reset()
+
   def test_find_culprits(self):
     model = BadModel().eval()
     args = (torch.rand(10),)
ai_edge_torch/generative/examples/amd_llama_135m/amd_llama_135m.py CHANGED
@@ -17,10 +17,16 @@
 
 import ai_edge_torch.generative.layers.model_config as cfg
 from ai_edge_torch.generative.utilities import model_builder
+from torch import nn
 
 TENSOR_NAMES = model_builder.TENSOR_NAMES_WITH_SEPARATE_LM_HEAD
 
 
+class AmdLlama(model_builder.DecoderOnlyModel):
+  """An AMD-Llama model built from the Edge Generative API layers."""
+  pass
+
+
 def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
   """Returns the model config for an AMD-Llama-135m model.
 
@@ -72,11 +78,10 @@ def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
   return config
 
 
-def build_model(
-    checkpoint_path: str, **kwargs
-) -> model_builder.DecoderOnlyModel:
+def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
   return model_builder.build_decoder_only_model(
       checkpoint_path=checkpoint_path,
       config=get_model_config(**kwargs),
       tensor_names=TENSOR_NAMES,
+      model_class=AmdLlama
   )
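This is the first instance of a refactor repeated across gemma1, gemma2, llama, openelm, phi2, phi3, qwen, and the PaliGemma decoder below: each example becomes a thin DecoderOnlyModel subclass, and checkpoint loading is centralized in model_builder.build_decoder_only_model via the new model_class argument. A sketch of the pattern as it would apply to a hypothetical new example (MyModel and the get_model_config factory are placeholders):

    import ai_edge_torch.generative.layers.model_config as cfg
    from ai_edge_torch.generative.utilities import model_builder
    from torch import nn


    class MyModel(model_builder.DecoderOnlyModel):
      """A hypothetical decoder-only model built from the Edge Generative API layers."""
      pass


    def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
      # build_decoder_only_model instantiates model_class with the config and
      # loads the mapped checkpoint tensors, subsuming the per-example
      # loader.load(...)/model.eval() boilerplate removed throughout this diff.
      return model_builder.build_decoder_only_model(
          checkpoint_path=checkpoint_path,
          config=get_model_config(**kwargs),  # placeholder: per-model config factory
          tensor_names=model_builder.TENSOR_NAMES,
          model_class=MyModel,
      )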
ai_edge_torch/generative/examples/amd_llama_135m/convert_to_tflite.py CHANGED
@@ -22,6 +22,7 @@ from absl import app
 from absl import flags
 from ai_edge_torch.generative.examples.amd_llama_135m import amd_llama_135m
 from ai_edge_torch.generative.utilities import converter
+from ai_edge_torch.generative.utilities.model_builder import ExportConfig
 
 _CHECKPOINT_PATH = flags.DEFINE_string(
     'checkpoint_path',
@@ -61,6 +62,7 @@ def main(_):
       tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
       prefill_seq_len=_PREFILL_SEQ_LEN.value,
       quantize=_QUANTIZE.value,
+      export_config=ExportConfig(),
   )
 
 
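Every conversion script in this release gains the same two lines: the ExportConfig import and an export_config=ExportConfig() argument to converter.convert_to_tflite (whose signature change is the +25 -4 to utilities/converter.py in the file list). A usage sketch assembled from this script's own flags; the positional pytorch_model argument and the paths are illustrative assumptions, not confirmed by this diff:

    import os

    from ai_edge_torch.generative.examples.amd_llama_135m import amd_llama_135m
    from ai_edge_torch.generative.utilities import converter
    from ai_edge_torch.generative.utilities.model_builder import ExportConfig

    # Build the PyTorch model from a local checkpoint (placeholder path).
    pytorch_model = amd_llama_135m.build_model(
        '/tmp/amd_llama_135m',
        kv_cache_max_len=1024,
    )
    converter.convert_to_tflite(
        pytorch_model,
        tflite_path=os.path.join('/tmp', 'amd_llama_135m_f32_ekv1024.tflite'),
        prefill_seq_len=1024,
        quantize=False,
        export_config=ExportConfig(),  # default export behavior
    )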
ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py CHANGED
@@ -22,6 +22,7 @@ from absl import app
 from absl import flags
 from ai_edge_torch.generative.examples.gemma import gemma1
 from ai_edge_torch.generative.utilities import converter
+from ai_edge_torch.generative.utilities.model_builder import ExportConfig
 
 _CHECKPOINT_PATH = flags.DEFINE_string(
     'checkpoint_path',
@@ -61,6 +62,7 @@ def main(_):
       tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
       prefill_seq_len=_PREFILL_SEQ_LENS.value,
       quantize=_QUANTIZE.value,
+      export_config=ExportConfig(),
   )
 
 
ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py CHANGED
@@ -22,6 +22,7 @@ from absl import app
 from absl import flags
 from ai_edge_torch.generative.examples.gemma import gemma2
 from ai_edge_torch.generative.utilities import converter
+from ai_edge_torch.generative.utilities.model_builder import ExportConfig
 
 _CHECKPOINT_PATH = flags.DEFINE_string(
     'checkpoint_path',
@@ -61,6 +62,7 @@ def main(_):
       tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
       prefill_seq_len=_PREFILL_SEQ_LENS.value,
       quantize=_QUANTIZE.value,
+      export_config=ExportConfig(),
   )
 
 
ai_edge_torch/generative/examples/gemma/gemma1.py CHANGED
@@ -18,6 +18,7 @@
 import ai_edge_torch.generative.layers.model_config as cfg
 from ai_edge_torch.generative.utilities import model_builder
 import ai_edge_torch.generative.utilities.loader as loading_utils
+from torch import nn
 
 TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
     ff_up_proj="model.layers.{}.mlp.up_proj",
@@ -33,6 +34,11 @@ TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
 )
 
 
+class Gemma1(model_builder.DecoderOnlyModel):
+  """A Gemma1 model built from the Edge Generative API layers."""
+  pass
+
+
 def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
   """Returns the model config for a Gemma 2B model.
 
@@ -91,11 +97,10 @@ def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
   return config
 
 
-def build_2b_model(
-    checkpoint_path: str, **kwargs
-) -> model_builder.DecoderOnlyModel:
+def build_2b_model(checkpoint_path: str, **kwargs) -> nn.Module:
   return model_builder.build_decoder_only_model(
       checkpoint_path=checkpoint_path,
       config=get_model_config_2b(**kwargs),
       tensor_names=TENSOR_NAMES,
+      model_class=Gemma1,
   )
ai_edge_torch/generative/examples/gemma/gemma2.py CHANGED
@@ -22,6 +22,7 @@ from ai_edge_torch.generative.layers import builder
 from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
 import ai_edge_torch.generative.layers.model_config as cfg
+from ai_edge_torch.generative.utilities import model_builder
 import ai_edge_torch.generative.utilities.loader as loading_utils
 import torch
 from torch import nn
@@ -132,6 +133,7 @@ class Gemma2(nn.Module):
       tokens: torch.Tensor,
       input_pos: torch.Tensor,
       kv_cache: kv_utils.KVCache,
+      export_config: Optional[model_builder.ExportConfig] = None,
   ) -> dict[torch.Tensor, kv_utils.KVCache]:
     _, seq_len = tokens.size()
     assert self.config.max_seq_len >= seq_len, (
@@ -162,6 +164,13 @@ class Gemma2(nn.Module):
       updated_kv_entires.append(kv_entry)
     updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
 
+    if export_config is not None:
+      if (
+          torch.numel(input_pos) > 1
+          and not export_config.output_logits_on_prefill
+      ):
+        return {"kv_cache": updated_kv_cache}
+
     x = self.final_norm(x)
     res = self.lm_head(x)  # (b, t, vocab_size)
     if self.config.final_logit_softcap is not None:
@@ -250,11 +259,9 @@ def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
 
 
 def build_2b_model(checkpoint_path: str, **kwargs) -> nn.Module:
-  config = get_model_config_2b(**kwargs)
-  model = Gemma2(config)
-  loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
-  # Since embedding and lm-head use the same weight, we need to set strict
-  # to False.
-  loader.load(model, strict=False)
-  model.eval()
-  return model
+  return model_builder.build_decoder_only_model(
+      checkpoint_path=checkpoint_path,
+      config=get_model_config_2b(**kwargs),
+      tensor_names=TENSOR_NAMES,
+      model_class=Gemma2,
+  )
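The export_config branch above is the behavioral payload of this release: when input_pos covers more than one position (a prefill call), the model can return only the updated KV cache and skip the final_norm/lm_head computation, unless output_logits_on_prefill asks for logits. A small self-contained sketch of that predicate in isolation, mirroring the condition above:

    import torch


    def returns_logits(input_pos: torch.Tensor, output_logits_on_prefill: bool) -> bool:
      """Mirrors the early-return condition in Gemma2.forward above."""
      # Decode steps operate on a single position and always produce logits;
      # prefill steps produce them only when explicitly requested.
      is_prefill = torch.numel(input_pos) > 1
      return (not is_prefill) or output_logits_on_prefill


    assert returns_logits(torch.tensor([17]), False)     # decode step
    assert not returns_logits(torch.arange(16), False)   # prefill: kv_cache only
    assert returns_logits(torch.arange(16), True)        # prefill with logits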
ai_edge_torch/generative/examples/llama/convert_to_tflite.py CHANGED
@@ -22,6 +22,7 @@ from absl import app
 from absl import flags
 from ai_edge_torch.generative.examples.llama import llama
 from ai_edge_torch.generative.utilities import converter
+from ai_edge_torch.generative.utilities.model_builder import ExportConfig
 
 _MODEL_SIZE = flags.DEFINE_enum(
     'model_size',
@@ -72,6 +73,7 @@ def main(_):
       tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
       prefill_seq_len=_PREFILL_SEQ_LENS.value,
       quantize=_QUANTIZE.value,
+      export_config=ExportConfig(),
   )
 
 
ai_edge_torch/generative/examples/llama/llama.py CHANGED
@@ -20,7 +20,6 @@ from typing import Tuple
 
 import ai_edge_torch.generative.layers.model_config as cfg
 from ai_edge_torch.generative.utilities import model_builder
-import ai_edge_torch.generative.utilities.loader as loading_utils
 import torch
 
 TENSOR_NAMES = model_builder.TENSOR_NAMES
@@ -177,23 +176,18 @@ def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
 
 def _build_model(
     checkpoint_path: str, config: cfg.ModelConfig
-) -> model_builder.DecoderOnlyModel:
-  model = Llama(config)
-  loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
-  # Since embedding and lm-head use the same weight, we need to set strict
-  # to False.
-  loader.load(model, strict=False)
-  model.eval()
-  return model
-
-
-def build_1b_model(
-    checkpoint_path: str, **kwargs
-) -> model_builder.DecoderOnlyModel:
+) -> torch.nn.Module:
+  return model_builder.build_decoder_only_model(
+      checkpoint_path=checkpoint_path,
+      config=config,
+      tensor_names=TENSOR_NAMES,
+      model_class=Llama,
+  )
+
+
+def build_1b_model(checkpoint_path: str, **kwargs) -> torch.nn.Module:
   return _build_model(checkpoint_path, get_1b_model_config(**kwargs))
 
 
-def build_3b_model(
-    checkpoint_path: str, **kwargs
-) -> model_builder.DecoderOnlyModel:
+def build_3b_model(checkpoint_path: str, **kwargs) -> torch.nn.Module:
   return _build_model(checkpoint_path, get_3b_model_config(**kwargs))
ai_edge_torch/generative/examples/openelm/convert_to_tflite.py CHANGED
@@ -22,6 +22,7 @@ from absl import app
 from absl import flags
 from ai_edge_torch.generative.examples.openelm import openelm
 from ai_edge_torch.generative.utilities import converter
+from ai_edge_torch.generative.utilities.model_builder import ExportConfig
 
 _CHECKPOINT_PATH = flags.DEFINE_string(
     'checkpoint_path',
@@ -64,6 +65,7 @@ def main(_):
       tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
       prefill_seq_len=_PREFILL_SEQ_LENS.value,
       quantize=_QUANTIZE.value,
+      export_config=ExportConfig(),
   )
 
 
ai_edge_torch/generative/examples/openelm/openelm.py CHANGED
@@ -18,6 +18,7 @@
 import ai_edge_torch.generative.layers.model_config as cfg
 from ai_edge_torch.generative.utilities import model_builder
 import ai_edge_torch.generative.utilities.loader as loading_utils
+from torch import nn
 
 TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
     ff_up_proj="transformer.layers.{}.ffn.proj_1",
@@ -34,6 +35,11 @@ TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
 )
 
 
+class OpenELM(model_builder.DecoderOnlyModel):
+  """An OpenELM model built from the Edge Generative API layers."""
+  pass
+
+
 def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
   """Returns the model config for an OpenELM model.
 
@@ -112,11 +118,10 @@ def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
   return config
 
 
-def build_model(
-    checkpoint_path: str, **kwargs
-) -> model_builder.DecoderOnlyModel:
+def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
  return model_builder.build_decoder_only_model(
      checkpoint_path=checkpoint_path,
      config=get_model_config(**kwargs),
      tensor_names=TENSOR_NAMES,
+      model_class=OpenELM,
  )
ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py CHANGED
@@ -26,6 +26,7 @@ from absl import app
 from absl import flags
 from ai_edge_torch.generative.examples.paligemma import paligemma
 from ai_edge_torch.generative.utilities import converter
+from ai_edge_torch.generative.utilities.model_builder import ExportConfig
 import torch
 
 _CHECKPOINT_PATH = flags.DEFINE_string(
@@ -73,6 +74,7 @@ def main(_):
       pixel_values_size=torch.Size(_PIXEL_VALUES_SIZE.value),
       quantize=_QUANTIZE.value,
       config=pytorch_model.config.decoder_config,
+      export_config=ExportConfig(),
   )
 
 
ai_edge_torch/generative/examples/paligemma/decoder.py CHANGED
@@ -15,6 +15,8 @@
 
 """Example of building a decoder of PaliGemma 3B model which is Gemma1."""
 
+from typing import Optional
+
 from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.model_config as cfg
 from ai_edge_torch.generative.utilities import model_builder
@@ -51,6 +53,7 @@ class Decoder(model_builder.DecoderOnlyModel):
       input_pos: torch.Tensor,
       kv_cache: kv_utils.KVCache,
       input_embeds: torch.Tensor = None,
+      export_config: Optional[model_builder.ExportConfig] = None,
   ) -> dict[torch.Tensor, kv_utils.KVCache]:
     if input_embeds is None:
       return super().forward(tokens, input_pos, kv_cache)
@@ -130,12 +133,10 @@ def get_fake_decoder_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
   return config
 
 
-def build_decoder(
-    checkpoint_path: str, **kwargs
-) -> model_builder.DecoderOnlyModel:
-  decoder = Decoder(get_decoder_config(**kwargs))
-  loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
-  # Loose the strictness because only decoder is being loaded.
-  loader.load(decoder, strict=False)
-  decoder.eval()
-  return decoder
+def build_decoder(checkpoint_path: str, **kwargs) -> torch.nn.Module:
+  return model_builder.build_decoder_only_model(
+      checkpoint_path=checkpoint_path,
+      config=get_decoder_config(**kwargs),
+      tensor_names=TENSOR_NAMES,
+      model_class=Decoder,
+  )
ai_edge_torch/generative/examples/paligemma/paligemma.py CHANGED
@@ -16,11 +16,13 @@
 """Example of building a full-stack of PaliGemma model."""
 
 from dataclasses import dataclass
+from typing import Optional
 
 from ai_edge_torch.generative.examples.paligemma import decoder
 from ai_edge_torch.generative.examples.paligemma import image_encoder
 import ai_edge_torch.generative.layers.kv_cache as kv_utils
 import ai_edge_torch.generative.layers.model_config as cfg
+from ai_edge_torch.generative.utilities import model_builder
 import ai_edge_torch.generative.utilities.loader as loading_utils
 import torch
 from torch import nn
@@ -67,9 +69,16 @@ class PaliGemma(nn.Module):
       input_pos: torch.Tensor,
       kv_cache: kv_utils.KVCache,
       pixel_values: torch.Tensor = None,
+      export_config: Optional[model_builder.ExportConfig] = None,
   ) -> dict[torch.Tensor, kv_utils.KVCache]:
     if pixel_values is None:
-      return self.decoder(tokens, input_pos, kv_cache)
+      return self.decoder(
+          tokens=tokens,
+          input_pos=input_pos,
+          kv_cache=kv_cache,
+          input_embeds=None,
+          export_config=export_config
+      )
 
     input_embeds = self.decoder.tok_embedding(tokens)
 
@@ -100,6 +109,7 @@ class PaliGemma(nn.Module):
         input_pos=input_pos,
         kv_cache=kv_cache,
         input_embeds=input_embeds,
+        export_config=export_config,
     )
 
 
ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py CHANGED
@@ -22,6 +22,7 @@ from absl import app
 from absl import flags
 from ai_edge_torch.generative.examples.phi import phi3
 from ai_edge_torch.generative.utilities import converter
+from ai_edge_torch.generative.utilities.model_builder import ExportConfig
 
 _CHECKPOINT_PATH = flags.DEFINE_string(
     'checkpoint_path',
@@ -61,6 +62,7 @@ def main(_):
       tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
       prefill_seq_len=_PREFILL_SEQ_LENS.value,
       quantize=_QUANTIZE.value,
+      export_config=ExportConfig(),
   )
 
 
ai_edge_torch/generative/examples/phi/convert_to_tflite.py CHANGED
@@ -22,6 +22,7 @@ from absl import app
 from absl import flags
 from ai_edge_torch.generative.examples.phi import phi2
 from ai_edge_torch.generative.utilities import converter
+from ai_edge_torch.generative.utilities.model_builder import ExportConfig
 
 _CHECKPOINT_PATH = flags.DEFINE_string(
     'checkpoint_path',
@@ -61,6 +62,7 @@ def main(_):
       tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
       prefill_seq_len=_PREFILL_SEQ_LENS.value,
       quantize=_QUANTIZE.value,
+      export_config=ExportConfig(),
   )
 
 
ai_edge_torch/generative/examples/phi/phi2.py CHANGED
@@ -18,6 +18,7 @@
 import ai_edge_torch.generative.layers.model_config as cfg
 from ai_edge_torch.generative.utilities import model_builder
 import ai_edge_torch.generative.utilities.loader as loading_utils
+from torch import nn
 
 TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
     ff_up_proj="model.layers.{}.mlp.fc1",
@@ -33,6 +34,11 @@ TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
 )
 
 
+class Phi2(model_builder.DecoderOnlyModel):
+  """A Phi-2 model built from the Edge Generative API layers."""
+  pass
+
+
 def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
   """Returns the model config for a Phi-2 model.
 
@@ -92,11 +98,10 @@ def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
   return config
 
 
-def build_model(
-    checkpoint_path: str, **kwargs
-) -> model_builder.DecoderOnlyModel:
+def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
  return model_builder.build_decoder_only_model(
      checkpoint_path=checkpoint_path,
      config=get_model_config(**kwargs),
      tensor_names=TENSOR_NAMES,
+      model_class=Phi2,
  )
ai_edge_torch/generative/examples/phi/phi3.py CHANGED
@@ -207,13 +207,11 @@ def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
   return config
 
 
-def build_model(
-    checkpoint_path: str, **kwargs
-) -> model_builder.DecoderOnlyModel:
+def build_model(checkpoint_path: str, **kwargs) -> torch.nn.Module:
   """Instantiates the model instance and load checkpoint if provided."""
-  config = get_model_config(**kwargs)
-  model = Phi3_5Mini(config)
-  loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
-  loader.load(model)
-  model.eval()
-  return model
+  return model_builder.build_decoder_only_model(
+      checkpoint_path=checkpoint_path,
+      config=get_model_config(**kwargs),
+      tensor_names=TENSOR_NAMES,
+      model_class=Phi3_5Mini,
+  )
ai_edge_torch/generative/examples/qwen/convert_to_tflite.py CHANGED
@@ -22,6 +22,7 @@ from absl import app
 from absl import flags
 from ai_edge_torch.generative.examples.qwen import qwen
 from ai_edge_torch.generative.utilities import converter
+from ai_edge_torch.generative.utilities.model_builder import ExportConfig
 
 _MODEL_SIZE = flags.DEFINE_enum(
     'model_size',
@@ -76,6 +77,7 @@ def main(_):
       tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
       prefill_seq_len=_PREFILL_SEQ_LENS.value,
       quantize=_QUANTIZE.value,
+      export_config=ExportConfig(),
   )
 
 
ai_edge_torch/generative/examples/qwen/qwen.py CHANGED
@@ -17,10 +17,16 @@
 
 import ai_edge_torch.generative.layers.model_config as cfg
 from ai_edge_torch.generative.utilities import model_builder
+from torch import nn
 
 TENSOR_NAMES = model_builder.TENSOR_NAMES
 
 
+class Qwen(model_builder.DecoderOnlyModel):
+  """A Qwen model built from the Edge Generative API layers."""
+  pass
+
+
 def get_3b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
   """Returns the model config for a Qwen 2.5 3B model.
 
@@ -101,31 +107,28 @@
   return config
 
 
-def build_3b_model(
-    checkpoint_path: str, **kwargs
-) -> model_builder.DecoderOnlyModel:
+def build_3b_model(checkpoint_path: str, **kwargs) -> nn.Module:
   return model_builder.build_decoder_only_model(
       checkpoint_path=checkpoint_path,
       config=get_3b_model_config(**kwargs),
       tensor_names=TENSOR_NAMES,
+      model_class=Qwen,
   )
 
 
-def build_1_5b_model(
-    checkpoint_path: str, **kwargs
-) -> model_builder.DecoderOnlyModel:
+def build_1_5b_model(checkpoint_path: str, **kwargs) -> nn.Module:
   return model_builder.build_decoder_only_model(
       checkpoint_path=checkpoint_path,
       config=get_1_5b_model_config(**kwargs),
       tensor_names=TENSOR_NAMES,
+      model_class=Qwen,
   )
 
 
-def build_0_5b_model(
-    checkpoint_path: str, **kwargs
-) -> model_builder.DecoderOnlyModel:
+def build_0_5b_model(checkpoint_path: str, **kwargs) -> nn.Module:
   return model_builder.build_decoder_only_model(
       checkpoint_path=checkpoint_path,
       config=get_0_5b_model_config(**kwargs),
       tensor_names=TENSOR_NAMES,
+      model_class=Qwen,
   )
ai_edge_torch/generative/examples/smollm/convert_to_tflite.py CHANGED
@@ -22,6 +22,7 @@ from absl import app
 from absl import flags
 from ai_edge_torch.generative.examples.smollm import smollm
 from ai_edge_torch.generative.utilities import converter
+from ai_edge_torch.generative.utilities.model_builder import ExportConfig
 
 _CHECKPOINT_PATH = flags.DEFINE_string(
     'checkpoint_path',
@@ -54,6 +55,7 @@ def main(_):
   pytorch_model = smollm.build_model(
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
+
   quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
   output_filename = f'smollm_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
   converter.convert_to_tflite(
@@ -61,6 +63,7 @@ def main(_):
       tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
       prefill_seq_len=_PREFILL_SEQ_LENS.value,
       quantize=_QUANTIZE.value,
+      export_config=ExportConfig(),
   )
 
 