ai-edge-torch-nightly 0.3.0.dev20241206__py3-none-any.whl → 0.3.0.dev20241214__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
- ai_edge_torch/__init__.py +1 -1
- ai_edge_torch/_config.py +52 -0
- ai_edge_torch/_convert/test/test_convert.py +1 -2
- ai_edge_torch/debug/test/test_culprit.py +8 -3
- ai_edge_torch/generative/examples/amd_llama_135m/amd_llama_135m.py +8 -3
- ai_edge_torch/generative/examples/amd_llama_135m/convert_to_tflite.py +2 -0
- ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py +2 -0
- ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +2 -0
- ai_edge_torch/generative/examples/gemma/gemma1.py +8 -3
- ai_edge_torch/generative/examples/gemma/gemma2.py +15 -8
- ai_edge_torch/generative/examples/llama/convert_to_tflite.py +2 -0
- ai_edge_torch/generative/examples/llama/llama.py +11 -17
- ai_edge_torch/generative/examples/openelm/convert_to_tflite.py +2 -0
- ai_edge_torch/generative/examples/openelm/openelm.py +8 -3
- ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py +2 -0
- ai_edge_torch/generative/examples/paligemma/decoder.py +10 -9
- ai_edge_torch/generative/examples/paligemma/paligemma.py +11 -1
- ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py +2 -0
- ai_edge_torch/generative/examples/phi/convert_to_tflite.py +2 -0
- ai_edge_torch/generative/examples/phi/phi2.py +8 -3
- ai_edge_torch/generative/examples/phi/phi3.py +7 -9
- ai_edge_torch/generative/examples/qwen/convert_to_tflite.py +2 -0
- ai_edge_torch/generative/examples/qwen/qwen.py +12 -9
- ai_edge_torch/generative/examples/smollm/convert_to_tflite.py +3 -0
- ai_edge_torch/generative/examples/smollm/smollm.py +8 -3
- ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +12 -2
- ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +2 -0
- ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +8 -3
- ai_edge_torch/generative/layers/attention.py +2 -6
- ai_edge_torch/generative/layers/kv_cache.py +24 -18
- ai_edge_torch/generative/layers/normalization.py +1 -3
- ai_edge_torch/generative/test/test_kv_cache.py +3 -3
- ai_edge_torch/generative/test/test_model_conversion.py +12 -14
- ai_edge_torch/generative/test/test_model_conversion_large.py +63 -59
- ai_edge_torch/generative/test/utils.py +31 -6
- ai_edge_torch/generative/utilities/converter.py +25 -4
- ai_edge_torch/generative/utilities/model_builder.py +24 -4
- ai_edge_torch/generative/utilities/verifier.py +16 -2
- ai_edge_torch/lowertools/_shim.py +4 -2
- ai_edge_torch/lowertools/test_utils.py +4 -2
- ai_edge_torch/odml_torch/lowerings/__init__.py +1 -1
- ai_edge_torch/odml_torch/lowerings/_basic.py +5 -3
- ai_edge_torch/odml_torch/lowerings/_convolution.py +3 -1
- ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +28 -2
- ai_edge_torch/odml_torch/lowerings/_layer_norm.py +11 -2
- ai_edge_torch/odml_torch/lowerings/_quantized_decomposed.py +9 -9
- ai_edge_torch/odml_torch/lowerings/decomp.py +65 -0
- ai_edge_torch/odml_torch/lowerings/registry.py +0 -32
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20241206.dist-info → ai_edge_torch_nightly-0.3.0.dev20241214.dist-info}/METADATA +7 -5
- {ai_edge_torch_nightly-0.3.0.dev20241206.dist-info → ai_edge_torch_nightly-0.3.0.dev20241214.dist-info}/RECORD +54 -54
- ai_edge_torch/config.py +0 -27
- ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py +0 -283
- {ai_edge_torch_nightly-0.3.0.dev20241206.dist-info → ai_edge_torch_nightly-0.3.0.dev20241214.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20241206.dist-info → ai_edge_torch_nightly-0.3.0.dev20241214.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20241206.dist-info → ai_edge_torch_nightly-0.3.0.dev20241214.dist-info}/top_level.txt +0 -0
ai_edge_torch/__init__.py
CHANGED
@@ -13,13 +13,13 @@
 # limitations under the License.
 # ==============================================================================
 
+from ai_edge_torch._config import config
 from ai_edge_torch._convert.converter import convert
 from ai_edge_torch._convert.converter import signature
 from ai_edge_torch._convert.to_channel_last_io import to_channel_last_io
 from ai_edge_torch.model import Model
 from ai_edge_torch.version import __version__
 
-
 def load(path: str) -> Model:
   """Imports an ai_edge_torch model from disk.
 
ai_edge_torch/_config.py
ADDED
@@ -0,0 +1,52 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Provides a configuration for the ai-edge-torch."""
+
+import functools
+import logging
+import os
+
+__all__ = ["config"]
+
+
+class _Config:
+  """ai-edge-torch global configs."""
+
+  @property
+  @functools.cache  # pylint: disable=method-cache-max-size-none
+  def use_torch_xla(self) -> bool:
+    """True if using torch_xla to lower torch ops to StableHLO.
+
+    To use torch_xla as the lowering backend, set environment variable
+    `USE_TORCH_XLA` to "true".
+    """
+    var = os.environ.get("USE_TORCH_XLA", "false")
+    var = var.lower().strip()
+    if var in ("y", "yes", "t", "true", "on", "1"):
+      return True
+    elif var in ("n", "no", "f", "false", "off", "0"):
+      return False
+    else:
+      logging.warning("Invalid USE_TORCH_XLA value is ignored: %s.", var)
+      return False
+
+  @property
+  def in_oss(self) -> bool:
+    """True if the code is not running in google internal environment."""
+    return True
+
+
+config = _Config()
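The new module is re-exported at the package root (see the `__init__.py` change above), so downstream code reads it as `ai_edge_torch.config`. A minimal usage sketch, assuming this nightly is installed:

    import ai_edge_torch

    # Reflects the USE_TORCH_XLA environment variable on first read; the
    # property is cached, so later changes to the variable are not seen.
    if ai_edge_torch.config.use_torch_xla:
      print("lowering torch ops via torch_xla")
    else:
      print("lowering torch ops via odml_torch")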
ai_edge_torch/_convert/test/test_convert.py
CHANGED
@@ -19,7 +19,6 @@ import os
 from typing import Tuple
 
 import ai_edge_torch
-from ai_edge_torch import config
 from ai_edge_torch._convert import conversion_utils
 from ai_edge_torch.quantize import pt2e_quantizer
 from ai_edge_torch.testing import model_coverage
@@ -292,7 +291,7 @@ class TestConvert(googletest.TestCase):
     self.assertTrue(result)
 
   @googletest.skipIf(
-      not config.use_torch_xla,
+      not ai_edge_torch.config.use_torch_xla,
       reason="Shape polymorphism is not yet support with odml_torch.",
   )
   def test_convert_model_with_dynamic_batch(self):
ai_edge_torch/debug/test/test_culprit.py
CHANGED
@@ -15,14 +15,14 @@
 
 
 import ast
-import io
-import sys
 
-
+import ai_edge_torch.debug
 import torch
 
 from absl.testing import absltest as googletest
 
+find_culprits = ai_edge_torch.debug.find_culprits
+
 _test_culprit_lib = torch.library.Library("test_culprit", "DEF")
 
 _test_culprit_lib.define("non_lowerable_op(Tensor x) -> Tensor")
@@ -52,6 +52,11 @@ class BadModel(torch.nn.Module):
 
 class TestCulprit(googletest.TestCase):
 
+  def setUp(self):
+    super().setUp()
+    torch.manual_seed(0)
+    torch._dynamo.reset()
+
   def test_find_culprits(self):
     model = BadModel().eval()
     args = (torch.rand(10),)
ai_edge_torch/generative/examples/amd_llama_135m/amd_llama_135m.py
CHANGED
@@ -17,10 +17,16 @@
 
 import ai_edge_torch.generative.layers.model_config as cfg
 from ai_edge_torch.generative.utilities import model_builder
+from torch import nn
 
 TENSOR_NAMES = model_builder.TENSOR_NAMES_WITH_SEPARATE_LM_HEAD
 
 
+class AmdLlama(model_builder.DecoderOnlyModel):
+  """An AMD-Llama model built from the Edge Generative API layers."""
+  pass
+
+
 def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
   """Returns the model config for an AMD-Llama-135m model.
 
@@ -72,11 +78,10 @@ def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
   return config
 
 
-def build_model(
-    checkpoint_path: str, **kwargs
-) -> model_builder.DecoderOnlyModel:
+def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
   return model_builder.build_decoder_only_model(
       checkpoint_path=checkpoint_path,
       config=get_model_config(**kwargs),
       tensor_names=TENSOR_NAMES,
+      model_class=AmdLlama
   )
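The pattern above, a thin `DecoderOnlyModel` subclass handed to `build_decoder_only_model` through the new `model_class` argument, repeats across the example models in this release (Gemma1, OpenELM, Phi2, Qwen, Llama). A minimal caller sketch; the checkpoint path is a placeholder:

    from ai_edge_torch.generative.examples.amd_llama_135m import amd_llama_135m

    # build_model now returns an AmdLlama instance (an nn.Module) with the
    # checkpoint weights loaded; kwargs flow through to get_model_config.
    model = amd_llama_135m.build_model(
        "/path/to/amd_llama_135m_checkpoint",  # placeholder path
        kv_cache_max_len=1024,
    )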
ai_edge_torch/generative/examples/amd_llama_135m/convert_to_tflite.py
CHANGED
@@ -22,6 +22,7 @@ from absl import app
 from absl import flags
 from ai_edge_torch.generative.examples.amd_llama_135m import amd_llama_135m
 from ai_edge_torch.generative.utilities import converter
+from ai_edge_torch.generative.utilities.model_builder import ExportConfig
 
 _CHECKPOINT_PATH = flags.DEFINE_string(
     'checkpoint_path',
@@ -61,6 +62,7 @@ def main(_):
       tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
       prefill_seq_len=_PREFILL_SEQ_LEN.value,
       quantize=_QUANTIZE.value,
+      export_config=ExportConfig(),
   )
 
 
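Each conversion script below picks up the same two-line change: an `ExportConfig` import and an `export_config=ExportConfig()` argument to `converter.convert_to_tflite`. A sketch of the resulting call shape; the model object and paths are placeholders:

    from ai_edge_torch.generative.utilities import converter
    from ai_edge_torch.generative.utilities.model_builder import ExportConfig

    converter.convert_to_tflite(
        pytorch_model,  # a model built by one of the example build_model calls
        tflite_path="/tmp/model_q8_ekv1024.tflite",  # placeholder output path
        prefill_seq_len=1024,
        quantize=True,
        export_config=ExportConfig(),  # new in this release; default settings
    )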
ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py
CHANGED
@@ -22,6 +22,7 @@ from absl import app
 from absl import flags
 from ai_edge_torch.generative.examples.gemma import gemma1
 from ai_edge_torch.generative.utilities import converter
+from ai_edge_torch.generative.utilities.model_builder import ExportConfig
 
 _CHECKPOINT_PATH = flags.DEFINE_string(
     'checkpoint_path',
@@ -61,6 +62,7 @@ def main(_):
       tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
       prefill_seq_len=_PREFILL_SEQ_LENS.value,
       quantize=_QUANTIZE.value,
+      export_config=ExportConfig(),
   )
 
 
ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py
CHANGED
@@ -22,6 +22,7 @@ from absl import app
 from absl import flags
 from ai_edge_torch.generative.examples.gemma import gemma2
 from ai_edge_torch.generative.utilities import converter
+from ai_edge_torch.generative.utilities.model_builder import ExportConfig
 
 _CHECKPOINT_PATH = flags.DEFINE_string(
     'checkpoint_path',
@@ -61,6 +62,7 @@ def main(_):
       tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
       prefill_seq_len=_PREFILL_SEQ_LENS.value,
      quantize=_QUANTIZE.value,
+      export_config=ExportConfig(),
   )
 
 
ai_edge_torch/generative/examples/gemma/gemma1.py
CHANGED
@@ -18,6 +18,7 @@
 import ai_edge_torch.generative.layers.model_config as cfg
 from ai_edge_torch.generative.utilities import model_builder
 import ai_edge_torch.generative.utilities.loader as loading_utils
+from torch import nn
 
 TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
     ff_up_proj="model.layers.{}.mlp.up_proj",
@@ -33,6 +34,11 @@ TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
 )
 
 
+class Gemma1(model_builder.DecoderOnlyModel):
+  """A Gemma1 model built from the Edge Generative API layers."""
+  pass
+
+
 def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
   """Returns the model config for a Gemma 2B model.
 
@@ -91,11 +97,10 @@ def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
   return config
 
 
-def build_2b_model(
-    checkpoint_path: str, **kwargs
-) -> model_builder.DecoderOnlyModel:
+def build_2b_model(checkpoint_path: str, **kwargs) -> nn.Module:
   return model_builder.build_decoder_only_model(
       checkpoint_path=checkpoint_path,
       config=get_model_config_2b(**kwargs),
       tensor_names=TENSOR_NAMES,
+      model_class=Gemma1,
   )
ai_edge_torch/generative/examples/gemma/gemma2.py
CHANGED
@@ -22,6 +22,7 @@ from ai_edge_torch.generative.layers import builder
 from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
 import ai_edge_torch.generative.layers.model_config as cfg
+from ai_edge_torch.generative.utilities import model_builder
 import ai_edge_torch.generative.utilities.loader as loading_utils
 import torch
 from torch import nn
@@ -132,6 +133,7 @@ class Gemma2(nn.Module):
       tokens: torch.Tensor,
       input_pos: torch.Tensor,
       kv_cache: kv_utils.KVCache,
+      export_config: Optional[model_builder.ExportConfig] = None,
   ) -> dict[torch.Tensor, kv_utils.KVCache]:
     _, seq_len = tokens.size()
     assert self.config.max_seq_len >= seq_len, (
@@ -162,6 +164,13 @@ class Gemma2(nn.Module):
       updated_kv_entires.append(kv_entry)
     updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
 
+    if export_config is not None:
+      if (
+          torch.numel(input_pos) > 1
+          and not export_config.output_logits_on_prefill
+      ):
+        return {"kv_cache": updated_kv_cache}
+
     x = self.final_norm(x)
     res = self.lm_head(x)  # (b, t, vocab_size)
     if self.config.final_logit_softcap is not None:
@@ -250,11 +259,9 @@ def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
 
 
 def build_2b_model(checkpoint_path: str, **kwargs) -> nn.Module:
-  ...
-  model.eval()
-  return model
+  return model_builder.build_decoder_only_model(
+      checkpoint_path=checkpoint_path,
+      config=get_model_config_2b(**kwargs),
+      tensor_names=TENSOR_NAMES,
+      model_class=Gemma2,
+  )
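The new gate means a prefill call (more than one position in `input_pos`) can skip the logits computation entirely and return only the updated cache. A behavior sketch under stated assumptions: `Gemma2(config)` and `kv_utils.KVCache.from_model_config` keep their signatures from earlier releases, and `ExportConfig` exposes an `output_logits_on_prefill` field:

    import torch

    from ai_edge_torch.generative.examples.gemma import gemma2
    from ai_edge_torch.generative.layers import kv_cache as kv_utils
    from ai_edge_torch.generative.utilities.model_builder import ExportConfig

    config = gemma2.get_fake_model_config()
    model = gemma2.Gemma2(config)  # assumed constructor, as in earlier releases
    kv = kv_utils.KVCache.from_model_config(config)  # assumed factory

    tokens = torch.zeros((1, 8), dtype=torch.int)  # prefill: 8 positions
    input_pos = torch.arange(0, 8, dtype=torch.int)

    out = model.forward(
        tokens,
        input_pos,
        kv,
        export_config=ExportConfig(output_logits_on_prefill=False),  # assumed field
    )
    assert "logits" not in out  # only {"kv_cache": ...} comes back on prefill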
ai_edge_torch/generative/examples/llama/convert_to_tflite.py
CHANGED
@@ -22,6 +22,7 @@ from absl import app
 from absl import flags
 from ai_edge_torch.generative.examples.llama import llama
 from ai_edge_torch.generative.utilities import converter
+from ai_edge_torch.generative.utilities.model_builder import ExportConfig
 
 _MODEL_SIZE = flags.DEFINE_enum(
     'model_size',
@@ -72,6 +73,7 @@ def main(_):
       tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
       prefill_seq_len=_PREFILL_SEQ_LENS.value,
       quantize=_QUANTIZE.value,
+      export_config=ExportConfig(),
   )
 
 
ai_edge_torch/generative/examples/llama/llama.py
CHANGED
@@ -20,7 +20,6 @@ from typing import Tuple
 
 import ai_edge_torch.generative.layers.model_config as cfg
 from ai_edge_torch.generative.utilities import model_builder
-import ai_edge_torch.generative.utilities.loader as loading_utils
 import torch
 
 TENSOR_NAMES = model_builder.TENSOR_NAMES
@@ -177,23 +176,18 @@ def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
 
 def _build_model(
     checkpoint_path: str, config: cfg.ModelConfig
-) -> model_builder.DecoderOnlyModel:
-  ...
-
-
-def build_1b_model(
-    checkpoint_path: str, **kwargs
-) -> model_builder.DecoderOnlyModel:
+) -> torch.nn.Module:
+  return model_builder.build_decoder_only_model(
+      checkpoint_path=checkpoint_path,
+      config=config,
+      tensor_names=TENSOR_NAMES,
+      model_class=Llama,
+  )
+
+
+def build_1b_model(checkpoint_path: str, **kwargs) -> torch.nn.Module:
   return _build_model(checkpoint_path, get_1b_model_config(**kwargs))
 
 
-def build_3b_model(
-    checkpoint_path: str, **kwargs
-) -> model_builder.DecoderOnlyModel:
+def build_3b_model(checkpoint_path: str, **kwargs) -> torch.nn.Module:
   return _build_model(checkpoint_path, get_3b_model_config(**kwargs))
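With the shared helper in place, the two size variants differ only in their configs. A usage sketch; the checkpoint path is a placeholder and `kv_cache_max_len` is assumed to be an accepted kwarg, as in the other examples:

    from ai_edge_torch.generative.examples.llama import llama

    # build_1b_model and build_3b_model both delegate to _build_model,
    # passing their respective model configs.
    model = llama.build_1b_model(
        "/path/to/llama_1b_checkpoint",  # placeholder path
        kv_cache_max_len=1024,  # assumed kwarg
    )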
ai_edge_torch/generative/examples/openelm/convert_to_tflite.py
CHANGED
@@ -22,6 +22,7 @@ from absl import app
 from absl import flags
 from ai_edge_torch.generative.examples.openelm import openelm
 from ai_edge_torch.generative.utilities import converter
+from ai_edge_torch.generative.utilities.model_builder import ExportConfig
 
 _CHECKPOINT_PATH = flags.DEFINE_string(
     'checkpoint_path',
@@ -64,6 +65,7 @@ def main(_):
       tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
       prefill_seq_len=_PREFILL_SEQ_LENS.value,
       quantize=_QUANTIZE.value,
+      export_config=ExportConfig(),
   )
 
 
ai_edge_torch/generative/examples/openelm/openelm.py
CHANGED
@@ -18,6 +18,7 @@
 import ai_edge_torch.generative.layers.model_config as cfg
 from ai_edge_torch.generative.utilities import model_builder
 import ai_edge_torch.generative.utilities.loader as loading_utils
+from torch import nn
 
 TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
     ff_up_proj="transformer.layers.{}.ffn.proj_1",
@@ -34,6 +35,11 @@ TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
 )
 
 
+class OpenELM(model_builder.DecoderOnlyModel):
+  """An OpenELM model built from the Edge Generative API layers."""
+  pass
+
+
 def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
   """Returns the model config for an OpenELM model.
 
@@ -112,11 +118,10 @@ def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
   return config
 
 
-def build_model(
-    checkpoint_path: str, **kwargs
-) -> model_builder.DecoderOnlyModel:
+def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
   return model_builder.build_decoder_only_model(
       checkpoint_path=checkpoint_path,
       config=get_model_config(**kwargs),
       tensor_names=TENSOR_NAMES,
+      model_class=OpenELM,
   )
ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py
CHANGED
@@ -26,6 +26,7 @@ from absl import app
 from absl import flags
 from ai_edge_torch.generative.examples.paligemma import paligemma
 from ai_edge_torch.generative.utilities import converter
+from ai_edge_torch.generative.utilities.model_builder import ExportConfig
 import torch
 
 _CHECKPOINT_PATH = flags.DEFINE_string(
@@ -73,6 +74,7 @@ def main(_):
       pixel_values_size=torch.Size(_PIXEL_VALUES_SIZE.value),
       quantize=_QUANTIZE.value,
       config=pytorch_model.config.decoder_config,
+      export_config=ExportConfig(),
   )
 
 
ai_edge_torch/generative/examples/paligemma/decoder.py
CHANGED
@@ -15,6 +15,8 @@
 
 """Example of building a decoder of PaliGemma 3B model which is Gemma1."""
 
+from typing import Optional
+
 from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.model_config as cfg
 from ai_edge_torch.generative.utilities import model_builder
@@ -51,6 +53,7 @@ class Decoder(model_builder.DecoderOnlyModel):
       input_pos: torch.Tensor,
       kv_cache: kv_utils.KVCache,
       input_embeds: torch.Tensor = None,
+      export_config: Optional[model_builder.ExportConfig] = None,
   ) -> dict[torch.Tensor, kv_utils.KVCache]:
     if input_embeds is None:
       return super().forward(tokens, input_pos, kv_cache)
@@ -130,12 +133,10 @@ def get_fake_decoder_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
   return config
 
 
-def build_decoder(
-  ...
-  decoder.eval()
-  return decoder
+def build_decoder(checkpoint_path: str, **kwargs) -> torch.nn.Module:
+  return model_builder.build_decoder_only_model(
+      checkpoint_path=checkpoint_path,
+      config=get_decoder_config(**kwargs),
+      tensor_names=TENSOR_NAMES,
+      model_class=Decoder,
+  )
ai_edge_torch/generative/examples/paligemma/paligemma.py
CHANGED
@@ -16,11 +16,13 @@
 """Example of building a full-stack of PaliGemma model."""
 
 from dataclasses import dataclass
+from typing import Optional
 
 from ai_edge_torch.generative.examples.paligemma import decoder
 from ai_edge_torch.generative.examples.paligemma import image_encoder
 import ai_edge_torch.generative.layers.kv_cache as kv_utils
 import ai_edge_torch.generative.layers.model_config as cfg
+from ai_edge_torch.generative.utilities import model_builder
 import ai_edge_torch.generative.utilities.loader as loading_utils
 import torch
 from torch import nn
@@ -67,9 +69,16 @@ class PaliGemma(nn.Module):
       input_pos: torch.Tensor,
       kv_cache: kv_utils.KVCache,
       pixel_values: torch.Tensor = None,
+      export_config: Optional[model_builder.ExportConfig] = None,
   ) -> dict[torch.Tensor, kv_utils.KVCache]:
     if pixel_values is None:
-      return self.decoder(
+      return self.decoder(
+          tokens=tokens,
+          input_pos=input_pos,
+          kv_cache=kv_cache,
+          input_embeds=None,
+          export_config=export_config
+      )
 
     input_embeds = self.decoder.tok_embedding(tokens)
 
@@ -100,6 +109,7 @@ class PaliGemma(nn.Module):
         input_pos=input_pos,
         kv_cache=kv_cache,
         input_embeds=input_embeds,
+        export_config=export_config,
     )
 
 
ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py
CHANGED
@@ -22,6 +22,7 @@ from absl import app
 from absl import flags
 from ai_edge_torch.generative.examples.phi import phi3
 from ai_edge_torch.generative.utilities import converter
+from ai_edge_torch.generative.utilities.model_builder import ExportConfig
 
 _CHECKPOINT_PATH = flags.DEFINE_string(
     'checkpoint_path',
@@ -61,6 +62,7 @@ def main(_):
       tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
       prefill_seq_len=_PREFILL_SEQ_LENS.value,
       quantize=_QUANTIZE.value,
+      export_config=ExportConfig(),
   )
 
 
ai_edge_torch/generative/examples/phi/convert_to_tflite.py
CHANGED
@@ -22,6 +22,7 @@ from absl import app
 from absl import flags
 from ai_edge_torch.generative.examples.phi import phi2
 from ai_edge_torch.generative.utilities import converter
+from ai_edge_torch.generative.utilities.model_builder import ExportConfig
 
 _CHECKPOINT_PATH = flags.DEFINE_string(
     'checkpoint_path',
@@ -61,6 +62,7 @@ def main(_):
       tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
       prefill_seq_len=_PREFILL_SEQ_LENS.value,
       quantize=_QUANTIZE.value,
+      export_config=ExportConfig(),
   )
 
 
ai_edge_torch/generative/examples/phi/phi2.py
CHANGED
@@ -18,6 +18,7 @@
 import ai_edge_torch.generative.layers.model_config as cfg
 from ai_edge_torch.generative.utilities import model_builder
 import ai_edge_torch.generative.utilities.loader as loading_utils
+from torch import nn
 
 TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
     ff_up_proj="model.layers.{}.mlp.fc1",
@@ -33,6 +34,11 @@ TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
 )
 
 
+class Phi2(model_builder.DecoderOnlyModel):
+  """A Phi-2 model built from the Edge Generative API layers."""
+  pass
+
+
 def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
   """Returns the model config for a Phi-2 model.
 
@@ -92,11 +98,10 @@ def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
   return config
 
 
-def build_model(
-    checkpoint_path: str, **kwargs
-) -> model_builder.DecoderOnlyModel:
+def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
   return model_builder.build_decoder_only_model(
       checkpoint_path=checkpoint_path,
       config=get_model_config(**kwargs),
       tensor_names=TENSOR_NAMES,
+      model_class=Phi2,
   )
ai_edge_torch/generative/examples/phi/phi3.py
CHANGED
@@ -207,13 +207,11 @@ def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
   return config
 
 
-def build_model(
-    checkpoint_path: str, **kwargs
-) -> model_builder.DecoderOnlyModel:
+def build_model(checkpoint_path: str, **kwargs) -> torch.nn.Module:
   """Instantiates the model instance and load checkpoint if provided."""
-  ...
+  return model_builder.build_decoder_only_model(
+      checkpoint_path=checkpoint_path,
+      config=get_model_config(**kwargs),
+      tensor_names=TENSOR_NAMES,
+      model_class=Phi3_5Mini,
+  )
ai_edge_torch/generative/examples/qwen/convert_to_tflite.py
CHANGED
@@ -22,6 +22,7 @@ from absl import app
 from absl import flags
 from ai_edge_torch.generative.examples.qwen import qwen
 from ai_edge_torch.generative.utilities import converter
+from ai_edge_torch.generative.utilities.model_builder import ExportConfig
 
 _MODEL_SIZE = flags.DEFINE_enum(
     'model_size',
@@ -76,6 +77,7 @@ def main(_):
       tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
       prefill_seq_len=_PREFILL_SEQ_LENS.value,
       quantize=_QUANTIZE.value,
+      export_config=ExportConfig(),
   )
 
 
ai_edge_torch/generative/examples/qwen/qwen.py
CHANGED
@@ -17,10 +17,16 @@
 
 import ai_edge_torch.generative.layers.model_config as cfg
 from ai_edge_torch.generative.utilities import model_builder
+from torch import nn
 
 TENSOR_NAMES = model_builder.TENSOR_NAMES
 
 
+class Qwen(model_builder.DecoderOnlyModel):
+  """A Qwen model built from the Edge Generative API layers."""
+  pass
+
+
 def get_3b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
   """Returns the model config for a Qwen 2.5 3B model.
 
@@ -101,31 +107,28 @@ def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
   return config
 
 
-def build_3b_model(
-    checkpoint_path: str, **kwargs
-) -> model_builder.DecoderOnlyModel:
+def build_3b_model(checkpoint_path: str, **kwargs) -> nn.Module:
   return model_builder.build_decoder_only_model(
       checkpoint_path=checkpoint_path,
       config=get_3b_model_config(**kwargs),
       tensor_names=TENSOR_NAMES,
+      model_class=Qwen,
   )
 
 
-def build_1_5b_model(
-    checkpoint_path: str, **kwargs
-) -> model_builder.DecoderOnlyModel:
+def build_1_5b_model(checkpoint_path: str, **kwargs) -> nn.Module:
   return model_builder.build_decoder_only_model(
       checkpoint_path=checkpoint_path,
       config=get_1_5b_model_config(**kwargs),
       tensor_names=TENSOR_NAMES,
+      model_class=Qwen,
   )
 
 
-def build_0_5b_model(
-    checkpoint_path: str, **kwargs
-) -> model_builder.DecoderOnlyModel:
+def build_0_5b_model(checkpoint_path: str, **kwargs) -> nn.Module:
  return model_builder.build_decoder_only_model(
       checkpoint_path=checkpoint_path,
       config=get_0_5b_model_config(**kwargs),
       tensor_names=TENSOR_NAMES,
+      model_class=Qwen,
   )
ai_edge_torch/generative/examples/smollm/convert_to_tflite.py
CHANGED
@@ -22,6 +22,7 @@ from absl import app
 from absl import flags
 from ai_edge_torch.generative.examples.smollm import smollm
 from ai_edge_torch.generative.utilities import converter
+from ai_edge_torch.generative.utilities.model_builder import ExportConfig
 
 _CHECKPOINT_PATH = flags.DEFINE_string(
     'checkpoint_path',
@@ -54,6 +55,7 @@ def main(_):
   pytorch_model = smollm.build_model(
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
+
   quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
   output_filename = f'smollm_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
   converter.convert_to_tflite(
@@ -61,6 +63,7 @@ def main(_):
       tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
       prefill_seq_len=_PREFILL_SEQ_LENS.value,
       quantize=_QUANTIZE.value,
+      export_config=ExportConfig(),
   )
 
 