ai-edge-torch-nightly 0.3.0.dev20240918__py3-none-any.whl → 0.3.0.dev20240920__py3-none-any.whl
- ai_edge_torch/_convert/test/test_convert.py +7 -3
- ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +6 -4
- ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +6 -4
- ai_edge_torch/generative/examples/openelm/convert_to_tflite.py +6 -4
- ai_edge_torch/generative/examples/openelm/openelm.py +0 -29
- ai_edge_torch/generative/examples/openelm/verify.py +63 -0
- ai_edge_torch/generative/examples/phi/convert_to_tflite.py +6 -4
- ai_edge_torch/generative/examples/phi/phi2.py +4 -31
- ai_edge_torch/generative/examples/phi/verify.py +63 -0
- ai_edge_torch/generative/examples/smollm/convert_to_tflite.py +6 -4
- ai_edge_torch/generative/examples/smollm/smollm.py +0 -30
- ai_edge_torch/generative/examples/smollm/verify.py +60 -0
- ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +6 -4
- ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +0 -29
- ai_edge_torch/generative/examples/tiny_llama/verify.py +62 -0
- ai_edge_torch/generative/layers/builder.py +3 -1
- ai_edge_torch/generative/layers/model_config.py +3 -0
- ai_edge_torch/generative/layers/normalization.py +31 -20
- ai_edge_torch/generative/layers/scaled_dot_product_attention.py +19 -9
- ai_edge_torch/generative/layers/unet/blocks_2d.py +9 -4
- ai_edge_torch/generative/layers/unet/model_config.py +1 -0
- ai_edge_torch/generative/test/test_model_conversion.py +1 -1
- ai_edge_torch/generative/test/test_model_conversion_large.py +2 -2
- ai_edge_torch/generative/utilities/stable_diffusion_loader.py +4 -0
- ai_edge_torch/generative/utilities/verifier.py +249 -0
- ai_edge_torch/model.py +7 -4
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20240918.dist-info → ai_edge_torch_nightly-0.3.0.dev20240920.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20240918.dist-info → ai_edge_torch_nightly-0.3.0.dev20240920.dist-info}/RECORD +32 -27
- {ai_edge_torch_nightly-0.3.0.dev20240918.dist-info → ai_edge_torch_nightly-0.3.0.dev20240920.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240918.dist-info → ai_edge_torch_nightly-0.3.0.dev20240920.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240918.dist-info → ai_edge_torch_nightly-0.3.0.dev20240920.dist-info}/top_level.txt +0 -0
ai_edge_torch/_convert/test/test_convert.py
@@ -23,12 +23,12 @@ from ai_edge_torch import config
 from ai_edge_torch._convert import conversion_utils
 from ai_edge_torch.testing import model_coverage
 import numpy as np
-import tensorflow as tf
 import torch
 from torch import nn
 import torchvision

 from absl.testing import absltest as googletest
+from ai_edge_litert import interpreter as tfl_interpreter  # pylint: disable=g-direct-tensorflow-import


 @dataclasses.dataclass
@@ -466,7 +466,9 @@ class TestConvert(googletest.TestCase):
     np.testing.assert_almost_equal(edge_output["y_data_2_0"], args[1])
     np.testing.assert_almost_equal(edge_output["y_data_2_1"], args[2])

-    interpreter =
+    interpreter = tfl_interpreter.Interpreter(
+        model_content=edge_model._tflite_model
+    )
     runner = interpreter.get_signature_runner("serving_default")
     output_details = runner.get_output_details()
     self.assertIn("x", output_details.keys())
@@ -477,7 +479,9 @@ class TestConvert(googletest.TestCase):
   def _compare_tflite_torch_args_kwargs(self, model, args, kwargs, flat_inputs):
     model.eval()
     edge_model = ai_edge_torch.convert(model, args, kwargs)
-    interpreter =
+    interpreter = tfl_interpreter.Interpreter(
+        model_content=edge_model._tflite_model
+    )
     runner = interpreter.get_signature_runner("serving_default")
     input_details = runner.get_input_details()
     self.assertEqual(input_details.keys(), flat_inputs.keys())
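The test changes above drop the TensorFlow interpreter import and build the interpreter from ai_edge_litert instead. Below is a minimal standalone sketch of that flow; the helper name run_converted is ours, and it assumes ai_edge_torch and ai_edge_litert are installed (edge_model._tflite_model is the converted flatbuffer, as used in the diff).

    import ai_edge_torch
    import torch
    from ai_edge_litert import interpreter as tfl_interpreter


    def run_converted(model: torch.nn.Module, sample_args: tuple):
      # Convert the PyTorch module, then run it through the LiteRT interpreter,
      # mirroring the path the updated tests exercise.
      edge_model = ai_edge_torch.convert(model.eval(), sample_args)
      interp = tfl_interpreter.Interpreter(model_content=edge_model._tflite_model)
      runner = interp.get_signature_runner("serving_default")
      return runner.get_input_details(), runner.get_output_details()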
ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py
@@ -30,17 +30,17 @@ _CHECKPOINT_PATH = flags.DEFINE_string(
 )
 _TFLITE_PATH = flags.DEFINE_string(
     'tflite_path',
-    '/tmp/
+    '/tmp/',
     'The tflite file path to export.',
 )
 _PREFILL_SEQ_LEN = flags.DEFINE_integer(
     'prefill_seq_len',
-
+    1024,
     'The maximum size of prefill input tensor.',
 )
 _KV_CACHE_MAX_LEN = flags.DEFINE_integer(
     'kv_cache_max_len',
-
+    1280,
     'The maximum size of KV cache buffer, including both prefill and decode.',
 )
 _QUANTIZE = flags.DEFINE_bool(
@@ -54,9 +54,11 @@ def main(_):
   pytorch_model = gemma2.build_2b_model(
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
+  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
+  output_filename = f'gemma2_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
   converter.convert_to_tflite(
       pytorch_model,
-      tflite_path=_TFLITE_PATH.value,
+      tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
       prefill_seq_len=_PREFILL_SEQ_LEN.value,
       quantize=_QUANTIZE.value,
   )
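With this change the flag takes an output directory and the script composes the file name from the quantization mode and the sequence-length flags. A small sketch of the resulting naming scheme (the values below are only illustrative):

    import os

    quantize = True          # _QUANTIZE
    prefill_seq_len = 1024   # _PREFILL_SEQ_LEN
    kv_cache_max_len = 1280  # _KV_CACHE_MAX_LEN
    tflite_path = '/tmp/'    # _TFLITE_PATH

    quant_suffix = 'q8' if quantize else 'f32'
    output_filename = f'gemma2_{quant_suffix}_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
    print(os.path.join(tflite_path, output_filename))  # /tmp/gemma2_q8_seq1024_ekv1280.tflite

The same pattern is applied in the other convert scripts below, with only the model-name prefix changing.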
ai_edge_torch/generative/examples/gemma/convert_to_tflite.py
@@ -30,17 +30,17 @@ _CHECKPOINT_PATH = flags.DEFINE_string(
 )
 _TFLITE_PATH = flags.DEFINE_string(
     'tflite_path',
-    '/tmp/
+    '/tmp/',
     'The tflite file path to export.',
 )
 _PREFILL_SEQ_LEN = flags.DEFINE_integer(
     'prefill_seq_len',
-
+    1024,
     'The maximum size of prefill input tensor.',
 )
 _KV_CACHE_MAX_LEN = flags.DEFINE_integer(
     'kv_cache_max_len',
-
+    1280,
     'The maximum size of KV cache buffer, including both prefill and decode.',
 )
 _QUANTIZE = flags.DEFINE_bool(
@@ -54,9 +54,11 @@ def main(_):
   pytorch_model = gemma.build_2b_model(
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
+  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
+  output_filename = f'gemma_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
   converter.convert_to_tflite(
       pytorch_model,
-      tflite_path=_TFLITE_PATH.value,
+      tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
       prefill_seq_len=_PREFILL_SEQ_LEN.value,
       quantize=_QUANTIZE.value,
   )
ai_edge_torch/generative/examples/openelm/convert_to_tflite.py
@@ -30,17 +30,17 @@ _CHECKPOINT_PATH = flags.DEFINE_string(
 )
 _TFLITE_PATH = flags.DEFINE_string(
     'tflite_path',
-    '/tmp/
+    '/tmp/',
     'The tflite file path to export.',
 )
 _PREFILL_SEQ_LEN = flags.DEFINE_integer(
     'prefill_seq_len',
-
+    1024,
     'The maximum size of prefill input tensor.',
 )
 _KV_CACHE_MAX_LEN = flags.DEFINE_integer(
     'kv_cache_max_len',
-
+    1280,
     'The maximum size of KV cache buffer, including both prefill and decode.',
 )
 _QUANTIZE = flags.DEFINE_bool(
@@ -54,9 +54,11 @@ def main(_):
   pytorch_model = openelm.build_model(
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
+  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
+  output_filename = f'openelm_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
   converter.convert_to_tflite(
       pytorch_model,
-      tflite_path=_TFLITE_PATH.value,
+      tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
       prefill_seq_len=_PREFILL_SEQ_LEN.value,
       quantize=_QUANTIZE.value,
   )
ai_edge_torch/generative/examples/openelm/openelm.py
@@ -15,16 +15,12 @@

 """Example of building an OpenELM model."""

-import os
-import pathlib
-
 from ai_edge_torch.generative.layers import attention
 from ai_edge_torch.generative.layers import builder
 from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
 import ai_edge_torch.generative.layers.model_config as cfg
 import ai_edge_torch.generative.utilities.loader as loading_utils
-import numpy as np
 import torch
 from torch import nn

@@ -210,28 +206,3 @@ def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
   loader.load(model, strict=False)
   model.eval()
   return model
-
-
-def define_and_run(checkpoint_path: str) -> None:
-  """Instantiates and runs an OpenELM model."""
-
-  current_dir = pathlib.Path(__file__).parent.resolve()
-  openelm_goldens = torch.load(current_dir / "openelm_lm_logits.pt")
-  kv_cache_max_len = 1024
-  model = build_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
-  idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
-  tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.int, device="cpu")
-  tokens[0, :4] = idx
-  input_pos = torch.arange(0, kv_cache_max_len, dtype=torch.int)
-  kv = kv_utils.KVCache.from_model_config(model.config)
-  output = model.forward(tokens, input_pos, kv)
-  assert torch.allclose(
-      openelm_goldens, output["logits"][0, idx.shape[1] - 1, :], atol=1e-05
-  )
-
-
-if __name__ == "__main__":
-  input_checkpoint_path = os.path.join(
-      pathlib.Path.home(), "Downloads/llm_data/openelm"
-  )
-  define_and_run(input_checkpoint_path)
ai_edge_torch/generative/examples/openelm/verify.py
@@ -0,0 +1,63 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Verifies the reauthored OpenELM-3B model."""
+
+import pathlib
+
+from absl import app
+from absl import flags
+from ai_edge_torch.generative.examples.openelm import openelm
+from ai_edge_torch.generative.utilities import verifier
+import transformers
+
+_PROMPTS = flags.DEFINE_multi_string(
+    "prompts",
+    "What is the meaning of life?",
+    "The input prompts to generate answers.",
+)
+
+
+def main(_):
+  checkpoint = "apple/OpenELM-3B"
+  verifier.log_msg("Loading the original model from", checkpoint)
+  wrapper_model = verifier.ModelWrapper(
+      model=transformers.AutoModelForCausalLM.from_pretrained(
+          checkpoint, trust_remote_code=True
+      ),
+  )
+
+  # Locate the cached dir.
+  cached_config_file = transformers.utils.cached_file(
+      checkpoint, transformers.utils.CONFIG_NAME
+  )
+  reauthored_checkpoint = pathlib.Path(cached_config_file).parent
+  verifier.log_msg("Building the reauthored model from", reauthored_checkpoint)
+  reauthored_model = openelm.build_model(reauthored_checkpoint)
+
+  tokenizer_checkpoint = "meta-llama/Llama-2-7b-hf"
+  verifier.log_msg("Loading the tokenizer from", tokenizer_checkpoint)
+  tokenizer = transformers.AutoTokenizer.from_pretrained(tokenizer_checkpoint)
+
+  verifier.verify_reauthored_model(
+      original_model=wrapper_model,
+      reauthored_model=reauthored_model,
+      tokenizer=tokenizer,
+      prompts=_PROMPTS.value,
+  )
+
+
+if __name__ == "__main__":
+  app.run(main)
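A non-obvious step in this new verify script is how it turns a Hub checkpoint id into a local directory for openelm.build_model: it resolves the cached config.json and uses its parent folder. A standalone sketch of that lookup (assuming the checkpoint is available in, or can be downloaded to, the local Hugging Face cache):

    import pathlib
    import transformers

    checkpoint = "apple/OpenELM-3B"
    # cached_file downloads (or finds) config.json and returns its local path;
    # the parent directory then serves as the checkpoint path for build_model.
    cached_config_file = transformers.utils.cached_file(
        checkpoint, transformers.utils.CONFIG_NAME
    )
    reauthored_checkpoint = pathlib.Path(cached_config_file).parent
    print(reauthored_checkpoint)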
ai_edge_torch/generative/examples/phi/convert_to_tflite.py
@@ -30,17 +30,17 @@ _CHECKPOINT_PATH = flags.DEFINE_string(
 )
 _TFLITE_PATH = flags.DEFINE_string(
     'tflite_path',
-    '/tmp/
+    '/tmp/',
     'The tflite file path to export.',
 )
 _PREFILL_SEQ_LEN = flags.DEFINE_integer(
     'prefill_seq_len',
-
+    1024,
     'The maximum size of prefill input tensor.',
 )
 _KV_CACHE_MAX_LEN = flags.DEFINE_integer(
     'kv_cache_max_len',
-
+    1280,
     'The maximum size of KV cache buffer, including both prefill and decode.',
 )
 _QUANTIZE = flags.DEFINE_bool(
@@ -54,9 +54,11 @@ def main(_):
   pytorch_model = phi2.build_model(
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
+  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
+  output_filename = f'phi2_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
   converter.convert_to_tflite(
       pytorch_model,
-      tflite_path=_TFLITE_PATH.value,
+      tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
       prefill_seq_len=_PREFILL_SEQ_LEN.value,
       quantize=_QUANTIZE.value,
   )
ai_edge_torch/generative/examples/phi/phi2.py
@@ -15,16 +15,12 @@

 """Example of building a Phi-2 model."""

-import os
-import pathlib
-
 from ai_edge_torch.generative.layers import attention
 from ai_edge_torch.generative.layers import builder
 from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
 import ai_edge_torch.generative.layers.model_config as cfg
 import ai_edge_torch.generative.utilities.loader as loading_utils
-import numpy as np
 import torch
 from torch import nn

@@ -143,7 +139,10 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       intermediate_size=10240,
       use_bias=True,
   )
-  norm_config = cfg.NormalizationConfig(
+  norm_config = cfg.NormalizationConfig(
+      type=cfg.NormalizationType.LAYER_NORM,
+      use_input_shape=False,  # Phi-2 does layer-norm with the weight shape.
+  )
   block_config = cfg.TransformerBlockConfig(
       attn_config=attn_config,
       ff_config=ff_config,
@@ -182,29 +181,3 @@ def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
   loader.load(model)
   model.eval()
   return model
-
-
-def define_and_run(checkpoint_path: str) -> None:
-  """Instantiates and runs a Phi-2 model."""
-
-  current_dir = pathlib.Path(__file__).parent.resolve()
-  phi2_goldens = torch.load(current_dir / "phi2_lm_logits.pt")
-  kv_cache_max_len = 1024
-  model = build_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
-  idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
-  tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.int, device="cpu")
-  tokens[0, :4] = idx
-  input_pos = torch.arange(0, kv_cache_max_len, dtype=torch.int)
-  kv = kv_utils.KVCache.from_model_config(model.config)
-  output = model.forward(tokens, input_pos, kv)
-  print("comparing with goldens..")
-  assert torch.allclose(
-      phi2_goldens, output["logits"][0, idx.shape[1] - 1, :], atol=1e-02
-  )
-
-
-if __name__ == "__main__":
-  input_checkpoint_path = os.path.join(
-      pathlib.Path.home(), "Downloads/llm_data/phi2"
-  )
-  define_and_run(input_checkpoint_path)
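The Phi-2 normalization change above pairs with the new use_input_shape option (see the model_config.py and normalization.py entries in the file list at the top). As a small sketch, the expanded construction on its own, with field names taken from the diff and all other fields left at their defaults:

    import ai_edge_torch.generative.layers.model_config as cfg

    # Layer norm computed over the weight shape rather than the full input shape,
    # matching the diff's comment about Phi-2.
    norm_config = cfg.NormalizationConfig(
        type=cfg.NormalizationType.LAYER_NORM,
        use_input_shape=False,
    )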
ai_edge_torch/generative/examples/phi/verify.py
@@ -0,0 +1,63 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Verifies the reauthored Phi-2 model."""
+
+from absl import app
+from absl import flags
+from ai_edge_torch.generative.examples.phi import phi2
+from ai_edge_torch.generative.utilities import verifier
+import kagglehub
+import transformers
+
+_PROMPTS = flags.DEFINE_multi_string(
+    "prompts",
+    "Instruct: Write an email about the weather Output:",
+    "The input prompts to generate answers.",
+)
+
+_MAX_NEW_TOKENS = flags.DEFINE_integer(
+    "max_new_tokens",
+    30,
+    "The maximum size of the generated tokens.",
+)
+
+def main(_):
+  checkpoint = kagglehub.model_download("Microsoft/phi/transformers/2")
+  verifier.log_msg("Loading the original model from", checkpoint)
+  generation_config = transformers.GenerationConfig.from_pretrained(checkpoint)
+  generation_config.max_new_tokens = _MAX_NEW_TOKENS.value
+  wrapper_model = verifier.ModelWrapper(
+      model=transformers.AutoModelForCausalLM.from_pretrained(checkpoint),
+      hf_generation_config=generation_config,
+  )
+
+  verifier.log_msg("Building the reauthored model from", checkpoint)
+  reauthored_model = phi2.build_model(checkpoint)
+
+  verifier.log_msg("Loading the tokenizer from", checkpoint)
+  tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
+
+  verifier.verify_reauthored_model(
+      original_model=wrapper_model,
+      reauthored_model=reauthored_model,
+      tokenizer=tokenizer,
+      prompts=_PROMPTS.value,
+      atol=1e-03,
+  )
+
+
+if __name__ == "__main__":
+  app.run(main)
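Unlike the other verify scripts, the Phi-2 one fetches the original weights from Kaggle and caps how many tokens the comparison run may generate. A short sketch of just that setup (assuming kagglehub is installed and authenticated):

    import kagglehub
    import transformers

    # Download the Kaggle-hosted Phi-2 checkpoint and limit generation length.
    checkpoint = kagglehub.model_download("Microsoft/phi/transformers/2")
    generation_config = transformers.GenerationConfig.from_pretrained(checkpoint)
    generation_config.max_new_tokens = 30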
ai_edge_torch/generative/examples/smollm/convert_to_tflite.py
@@ -30,17 +30,17 @@ _CHECKPOINT_PATH = flags.DEFINE_string(
 )
 _TFLITE_PATH = flags.DEFINE_string(
     'tflite_path',
-    '/tmp/
+    '/tmp/',
     'The tflite file path to export.',
 )
 _PREFILL_SEQ_LEN = flags.DEFINE_integer(
     'prefill_seq_len',
-
+    1024,
     'The maximum size of prefill input tensor.',
 )
 _KV_CACHE_MAX_LEN = flags.DEFINE_integer(
     'kv_cache_max_len',
-
+    1280,
     'The maximum size of KV cache buffer, including both prefill and decode.',
 )
 _QUANTIZE = flags.DEFINE_bool(
@@ -54,9 +54,11 @@ def main(_):
   pytorch_model = smollm.build_model(
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
+  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
+  output_filename = f'smollm_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
   converter.convert_to_tflite(
       pytorch_model,
-      tflite_path=_TFLITE_PATH.value,
+      tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
       prefill_seq_len=_PREFILL_SEQ_LEN.value,
       quantize=_QUANTIZE.value,
   )
ai_edge_torch/generative/examples/smollm/smollm.py
@@ -16,15 +16,10 @@
 """Example of building a SmolLM model."""

 import copy
-import os
-import pathlib

 from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
-from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.model_config as cfg
 import ai_edge_torch.generative.utilities.loader as loading_utils
-import numpy as np
-import torch
 from torch import nn

 TENSOR_NAMES = copy.copy(tiny_llama.TENSOR_NAMES)
@@ -104,28 +99,3 @@ def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
   loader.load(model, strict=False)
   model.eval()
   return model
-
-
-def define_and_run(checkpoint_path: str) -> None:
-  """Instantiates and runs a SmolLM model."""
-
-  current_dir = pathlib.Path(__file__).parent.resolve()
-  smollm_goldens = torch.load(current_dir / "smollm_lm_logits.pt")
-  kv_cache_max_len = 1024
-  model = build_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
-  idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
-  tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.int, device="cpu")
-  tokens[0, :4] = idx
-  input_pos = torch.arange(0, kv_cache_max_len, dtype=torch.int)
-  kv = kv_utils.KVCache.from_model_config(model.config)
-  output = model.forward(tokens, input_pos, kv)
-  assert torch.allclose(
-      smollm_goldens, output["logits"][0, idx.shape[1] - 1, :], atol=1e-05
-  )
-
-
-if __name__ == "__main__":
-  input_checkpoint_path = os.path.join(
-      pathlib.Path.home(), "Downloads/llm_data/smollm"
-  )
-  define_and_run(input_checkpoint_path)
ai_edge_torch/generative/examples/smollm/verify.py
@@ -0,0 +1,60 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Verifies the reauthored SmolLM-135M model."""
+
+import pathlib
+
+from absl import app
+from absl import flags
+from ai_edge_torch.generative.examples.smollm import smollm
+from ai_edge_torch.generative.utilities import verifier
+import transformers
+
+_PROMPTS = flags.DEFINE_multi_string(
+    "prompts",
+    "What is the meaning of life?",
+    "The input prompts to generate answers.",
+)
+
+
+def main(_):
+  checkpoint = "HuggingFaceTB/SmolLM-135M"
+  verifier.log_msg("Loading the original model from", checkpoint)
+  wrapper_model = verifier.ModelWrapper(
+      model=transformers.AutoModelForCausalLM.from_pretrained(checkpoint),
+  )
+  # Locate the cached dir.
+  cached_config_file = transformers.utils.cached_file(
+      checkpoint, transformers.utils.CONFIG_NAME
+  )
+  reauthored_checkpoint = pathlib.Path(cached_config_file).parent
+  verifier.log_msg("Building the reauthored model from", reauthored_checkpoint)
+  reauthored_model = smollm.build_model(reauthored_checkpoint)
+
+  verifier.log_msg("Loading the tokenizer from", checkpoint)
+  tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
+
+  verifier.verify_reauthored_model(
+      original_model=wrapper_model,
+      reauthored_model=reauthored_model,
+      tokenizer=tokenizer,
+      prompts=_PROMPTS.value,
+      atol=1e-04,
+  )
+
+
+if __name__ == "__main__":
+  app.run(main)
ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py
@@ -30,17 +30,17 @@ _CHECKPOINT_PATH = flags.DEFINE_string(
 )
 _TFLITE_PATH = flags.DEFINE_string(
     'tflite_path',
-    '/tmp/
+    '/tmp/',
     'The tflite file path to export.',
 )
 _PREFILL_SEQ_LEN = flags.DEFINE_integer(
     'prefill_seq_len',
-
+    1024,
     'The maximum size of prefill input tensor.',
 )
 _KV_CACHE_MAX_LEN = flags.DEFINE_integer(
     'kv_cache_max_len',
-
+    1280,
     'The maximum size of KV cache buffer, including both prefill and decode.',
 )
 _QUANTIZE = flags.DEFINE_bool(
@@ -54,9 +54,11 @@ def main(_):
   pytorch_model = tiny_llama.build_model(
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
+  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
+  output_filename = f'tinyllama_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
   converter.convert_to_tflite(
       pytorch_model,
-      tflite_path=_TFLITE_PATH.value,
+      tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
       prefill_seq_len=_PREFILL_SEQ_LEN.value,
       quantize=_QUANTIZE.value,
   )
ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py
@@ -15,16 +15,12 @@

 """Example of building a TinyLlama model."""

-import os
-import pathlib
-
 from ai_edge_torch.generative.layers import attention
 from ai_edge_torch.generative.layers import builder
 from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
 import ai_edge_torch.generative.layers.model_config as cfg
 import ai_edge_torch.generative.utilities.loader as loading_utils
-import numpy as np
 import torch
 from torch import nn

@@ -179,28 +175,3 @@ def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
   loader.load(model)
   model.eval()
   return model
-
-
-def define_and_run(checkpoint_path: str) -> None:
-  """Instantiates and runs a TinyLlama model."""
-
-  current_dir = pathlib.Path(__file__).parent.resolve()
-  tiny_llama_goldens = torch.load(current_dir / "tiny_llama_lm_logits.pt")
-  kv_cache_max_len = 1024
-  model = build_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
-  idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
-  tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.int, device="cpu")
-  tokens[0, :4] = idx
-  input_pos = torch.arange(0, kv_cache_max_len, dtype=torch.int)
-  kv = kv_utils.KVCache.from_model_config(model.config)
-  output = model.forward(tokens, input_pos, kv)
-  assert torch.allclose(
-      tiny_llama_goldens, output["logits"][0, idx.shape[1] - 1, :], atol=1e-02
-  )
-
-
-if __name__ == "__main__":
-  input_checkpoint_path = os.path.join(
-      pathlib.Path.home(), "Downloads/llm_data/tiny_llama"
-  )
-  define_and_run(input_checkpoint_path)