ai-edge-torch-nightly 0.3.0.dev20240916__py3-none-any.whl → 0.3.0.dev20240919__py3-none-any.whl
- ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +36 -56
- ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +36 -56
- ai_edge_torch/generative/examples/openelm/convert_to_tflite.py +36 -56
- ai_edge_torch/generative/examples/openelm/openelm.py +0 -29
- ai_edge_torch/generative/examples/openelm/verify.py +61 -0
- ai_edge_torch/generative/examples/phi/convert_to_tflite.py +36 -56
- ai_edge_torch/generative/examples/phi/phi2.py +4 -31
- ai_edge_torch/generative/examples/phi/verify.py +53 -0
- ai_edge_torch/generative/examples/smollm/convert_to_tflite.py +36 -56
- ai_edge_torch/generative/examples/smollm/smollm.py +0 -30
- ai_edge_torch/generative/examples/smollm/verify.py +59 -0
- ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +6 -0
- ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +36 -56
- ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +0 -29
- ai_edge_torch/generative/examples/tiny_llama/verify.py +61 -0
- ai_edge_torch/generative/layers/attention.py +8 -4
- ai_edge_torch/generative/layers/builder.py +3 -1
- ai_edge_torch/generative/layers/model_config.py +3 -0
- ai_edge_torch/generative/layers/normalization.py +31 -20
- ai_edge_torch/generative/layers/scaled_dot_product_attention.py +19 -9
- ai_edge_torch/generative/layers/unet/blocks_2d.py +11 -4
- ai_edge_torch/generative/layers/unet/model_config.py +3 -0
- ai_edge_torch/generative/test/test_model_conversion_large.py +1 -1
- ai_edge_torch/generative/utilities/converter.py +82 -0
- ai_edge_torch/generative/utilities/stable_diffusion_loader.py +10 -0
- ai_edge_torch/generative/utilities/verifier.py +200 -0
- ai_edge_torch/odml_torch/lowerings/_basic.py +14 -4
- ai_edge_torch/odml_torch/lowerings/registry.py +1 -0
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20240916.dist-info → ai_edge_torch_nightly-0.3.0.dev20240919.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20240916.dist-info → ai_edge_torch_nightly-0.3.0.dev20240919.dist-info}/RECORD +34 -28
- {ai_edge_torch_nightly-0.3.0.dev20240916.dist-info → ai_edge_torch_nightly-0.3.0.dev20240919.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240916.dist-info → ai_edge_torch_nightly-0.3.0.dev20240919.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240916.dist-info → ai_edge_torch_nightly-0.3.0.dev20240919.dist-info}/top_level.txt +0 -0
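The headline change in this nightly: every example convert_to_tflite.py script drops its hand-rolled multi-signature export in favor of the new shared helper in ai_edge_torch/generative/utilities/converter.py (+82 lines), and script parameters move from function arguments to absl flags. The helper's body is not part of this diff; judging purely from the per-model code each script deletes (visible in the hunks below), the following is a minimal sketch of what converter.convert_to_tflite presumably centralizes, reconstructed under the assumption that it simply mirrors the deleted logic:

# Hypothetical reconstruction of the shared helper; the real +82-line
# converter.py in this release is not shown on this page.
import ai_edge_torch
import torch
from ai_edge_torch.generative.layers import kv_cache as kv_utils
from ai_edge_torch.generative.quantize import quant_recipes


def convert_to_tflite(
    pytorch_model, tflite_path, prefill_seq_len=512, quantize=True
):
  # Tensors used to trace the model graph during conversion, as in the
  # deleted per-model scripts.
  prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.int)
  prefill_input_pos = torch.arange(0, prefill_seq_len, dtype=torch.int)
  decode_token = torch.tensor([[0]], dtype=torch.int)
  decode_input_pos = torch.tensor([0], dtype=torch.int)
  kv = kv_utils.KVCache.from_model_config(pytorch_model.config)

  quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
  edge_model = (
      ai_edge_torch.signature(
          'prefill',
          pytorch_model,
          sample_kwargs={
              'tokens': prefill_tokens,
              'input_pos': prefill_input_pos,
              'kv_cache': kv,
          },
      )
      .signature(
          'decode',
          pytorch_model,
          sample_kwargs={
              'tokens': decode_token,
              'input_pos': decode_input_pos,
              'kv_cache': kv,
          },
      )
      .convert(quant_config=quant_config)
  )
  edge_model.export(tflite_path)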

ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py
@@ -18,69 +18,49 @@
 import os
 import pathlib
 
-import ai_edge_torch
+from absl import app
+from absl import flags
 from ai_edge_torch.generative.examples.gemma import gemma2
-from ai_edge_torch.generative.layers import kv_cache as kv_utils
-from ai_edge_torch.generative.quantize import quant_recipes
-import torch
+from ai_edge_torch.generative.utilities import converter
 
+_CHECKPOINT_PATH = flags.DEFINE_string(
+    'checkpoint_path',
+    os.path.join(pathlib.Path.home(), 'Downloads/llm_data/gemma2-2b'),
+    'The path to the model checkpoint, or directory holding the checkpoint.',
+)
+_TFLITE_PATH = flags.DEFINE_string(
+    'tflite_path',
+    '/tmp/gemma2_q8_seq512_ekv1024.tflite',
+    'The tflite file path to export.',
+)
+_PREFILL_SEQ_LEN = flags.DEFINE_integer(
+    'prefill_seq_len',
+    512,
+    'The maximum size of prefill input tensor.',
+)
+_KV_CACHE_MAX_LEN = flags.DEFINE_integer(
+    'kv_cache_max_len',
+    1024,
+    'The maximum size of KV cache buffer, including both prefill and decode.',
+)
+_QUANTIZE = flags.DEFINE_bool(
+    'quantize',
+    True,
+    'Whether the model should be quantized.',
+)
 
-def convert_gemma2_to_tflite(
-    checkpoint_path: str,
-    prefill_seq_len: int = 512,
-    kv_cache_max_len: int = 1024,
-    quantize: bool = True,
-):
-  """Converts a Gemma2 2B model to multi-signature tflite model.
 
-  Args:
-    checkpoint_path (str): The filepath to the model checkpoint, or directory
-      holding the checkpoint.
-    prefill_seq_len (int, optional): The maximum size of prefill input tensor.
-      Defaults to 512.
-    kv_cache_max_len (int, optional): The maximum size of KV cache buffer,
-      including both prefill and decode. Defaults to 1024.
-    quantize (bool, optional): Whether the model should be quanized. Defaults
-      to True.
-  """
+def main(_):
   pytorch_model = gemma2.build_2b_model(
-      checkpoint_path, kv_cache_max_len=kv_cache_max_len
+      _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
-  # Tensors used to trace the model graph during conversion.
-  prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.int)
-  prefill_input_pos = torch.arange(0, prefill_seq_len, dtype=torch.int)
-  decode_token = torch.tensor([[0]], dtype=torch.int)
-  decode_input_pos = torch.tensor([0], dtype=torch.int)
-  kv = kv_utils.KVCache.from_model_config(pytorch_model.config)
-
-  quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
-  edge_model = (
-      ai_edge_torch.signature(
-          'prefill',
-          pytorch_model,
-          sample_kwargs={
-              'tokens': prefill_tokens,
-              'input_pos': prefill_input_pos,
-              'kv_cache': kv,
-          },
-      )
-      .signature(
-          'decode',
-          pytorch_model,
-          sample_kwargs={
-              'tokens': decode_token,
-              'input_pos': decode_input_pos,
-              'kv_cache': kv,
-          },
-      )
-      .convert(quant_config=quant_config)
-  )
-  quant_suffix = 'q8' if quantize else 'f32'
-  edge_model.export(
-      f'/tmp/gemma2_{quant_suffix}_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
+  converter.convert_to_tflite(
+      pytorch_model,
+      tflite_path=_TFLITE_PATH.value,
+      prefill_seq_len=_PREFILL_SEQ_LEN.value,
+      quantize=_QUANTIZE.value,
   )
 
 
 if __name__ == '__main__':
-  path = os.path.join(pathlib.Path.home(), 'Downloads/llm_data/gemma2-2b')
-  convert_gemma2_to_tflite(path)
+  app.run(main)
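After the migration the script is a flag-driven CLI rather than an importable helper: running it with no arguments reproduces the old defaults, and each setting can be overridden individually, e.g. the hypothetical invocation python convert_gemma2_to_tflite.py --quantize=false --prefill_seq_len=1024. One behavioral nit: the old code derived the output filename from the quantization choice (q8 vs f32), while the new --tflite_path default is a fixed name, so unquantized exports should override the path by hand. The Gemma, OpenELM, and Phi-2 conversion scripts receive the identical treatment in the hunks that follow.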

ai_edge_torch/generative/examples/gemma/convert_to_tflite.py
@@ -18,69 +18,49 @@
 import os
 import pathlib
 
-import ai_edge_torch
+from absl import app
+from absl import flags
 from ai_edge_torch.generative.examples.gemma import gemma
-from ai_edge_torch.generative.layers import kv_cache as kv_utils
-from ai_edge_torch.generative.quantize import quant_recipes
-import torch
+from ai_edge_torch.generative.utilities import converter
 
+_CHECKPOINT_PATH = flags.DEFINE_string(
+    'checkpoint_path',
+    os.path.join(pathlib.Path.home(), 'Downloads/llm_data/gemma-2b'),
+    'The path to the model checkpoint, or directory holding the checkpoint.',
+)
+_TFLITE_PATH = flags.DEFINE_string(
+    'tflite_path',
+    '/tmp/gemma_q8_seq512_ekv1024.tflite',
+    'The tflite file path to export.',
+)
+_PREFILL_SEQ_LEN = flags.DEFINE_integer(
+    'prefill_seq_len',
+    512,
+    'The maximum size of prefill input tensor.',
+)
+_KV_CACHE_MAX_LEN = flags.DEFINE_integer(
+    'kv_cache_max_len',
+    1024,
+    'The maximum size of KV cache buffer, including both prefill and decode.',
+)
+_QUANTIZE = flags.DEFINE_bool(
+    'quantize',
+    True,
+    'Whether the model should be quantized.',
+)
 
-def convert_gemma_to_tflite(
-    checkpoint_path: str,
-    prefill_seq_len: int = 512,
-    kv_cache_max_len: int = 1024,
-    quantize: bool = True,
-):
-  """Converts a Gemma 2B model to multi-signature tflite model.
 
-  Args:
-    checkpoint_path (str): The filepath to the model checkpoint, or directory
-      holding the checkpoint.
-    prefill_seq_len (int, optional): The maximum size of prefill input tensor.
-      Defaults to 512.
-    kv_cache_max_len (int, optional): The maximum size of KV cache buffer,
-      including both prefill and decode. Defaults to 1024.
-    quantize (bool, optional): Whether the model should be quanized. Defaults
-      to True.
-  """
+def main(_):
   pytorch_model = gemma.build_2b_model(
-      checkpoint_path, kv_cache_max_len=kv_cache_max_len
+      _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
-  # Tensors used to trace the model graph during conversion.
-  prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.int)
-  prefill_input_pos = torch.arange(0, prefill_seq_len, dtype=torch.int)
-  decode_token = torch.tensor([[0]], dtype=torch.int)
-  decode_input_pos = torch.tensor([0], dtype=torch.int)
-  kv = kv_utils.KVCache.from_model_config(pytorch_model.config)
-
-  quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
-  edge_model = (
-      ai_edge_torch.signature(
-          'prefill',
-          pytorch_model,
-          sample_kwargs={
-              'tokens': prefill_tokens,
-              'input_pos': prefill_input_pos,
-              'kv_cache': kv,
-          },
-      )
-      .signature(
-          'decode',
-          pytorch_model,
-          sample_kwargs={
-              'tokens': decode_token,
-              'input_pos': decode_input_pos,
-              'kv_cache': kv,
-          },
-      )
-      .convert(quant_config=quant_config)
-  )
-  quant_suffix = 'q8' if quantize else 'f32'
-  edge_model.export(
-      f'/tmp/gemma_{quant_suffix}_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
+  converter.convert_to_tflite(
+      pytorch_model,
+      tflite_path=_TFLITE_PATH.value,
+      prefill_seq_len=_PREFILL_SEQ_LEN.value,
+      quantize=_QUANTIZE.value,
   )
 
 
 if __name__ == '__main__':
-  path = os.path.join(pathlib.Path.home(), 'Downloads/llm_data/gemma-2b')
-  convert_gemma_to_tflite(path)
+  app.run(main)

ai_edge_torch/generative/examples/openelm/convert_to_tflite.py
@@ -18,69 +18,49 @@
 import os
 import pathlib
 
-import ai_edge_torch
+from absl import app
+from absl import flags
 from ai_edge_torch.generative.examples.openelm import openelm
-from ai_edge_torch.generative.layers import kv_cache as kv_utils
-from ai_edge_torch.generative.quantize import quant_recipes
-import torch
+from ai_edge_torch.generative.utilities import converter
 
+_CHECKPOINT_PATH = flags.DEFINE_string(
+    'checkpoint_path',
+    os.path.join(pathlib.Path.home(), 'Downloads/llm_data/openelm'),
+    'The path to the model checkpoint, or directory holding the checkpoint.',
+)
+_TFLITE_PATH = flags.DEFINE_string(
+    'tflite_path',
+    '/tmp/openelm_q8_seq512_ekv1024.tflite',
+    'The tflite file path to export.',
+)
+_PREFILL_SEQ_LEN = flags.DEFINE_integer(
+    'prefill_seq_len',
+    512,
+    'The maximum size of prefill input tensor.',
+)
+_KV_CACHE_MAX_LEN = flags.DEFINE_integer(
+    'kv_cache_max_len',
+    1024,
+    'The maximum size of KV cache buffer, including both prefill and decode.',
+)
+_QUANTIZE = flags.DEFINE_bool(
+    'quantize',
+    True,
+    'Whether the model should be quantized.',
+)
 
-def convert_openelm_to_tflite(
-    checkpoint_path: str,
-    prefill_seq_len: int = 512,
-    kv_cache_max_len: int = 1024,
-    quantize: bool = True,
-):
-  """Converts OpenELM model to multi-signature tflite model.
 
-  Args:
-    checkpoint_path (str): The filepath to the model checkpoint, or directory
-      holding the checkpoint.
-    prefill_seq_len (int, optional): The maximum size of prefill input tensor.
-      Defaults to 512.
-    kv_cache_max_len (int, optional): The maximum size of KV cache buffer,
-      including both prefill and decode. Defaults to 1024.
-    quantize (bool, optional): Whether the model should be quanized. Defaults
-      to True.
-  """
+def main(_):
   pytorch_model = openelm.build_model(
-      checkpoint_path, kv_cache_max_len=kv_cache_max_len
+      _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
-  # Tensors used to trace the model graph during conversion.
-  prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.int)
-  prefill_input_pos = torch.arange(0, prefill_seq_len, dtype=torch.int)
-  decode_token = torch.tensor([[0]], dtype=torch.int)
-  decode_input_pos = torch.tensor([0], dtype=torch.int)
-  kv = kv_utils.KVCache.from_model_config(pytorch_model.config)
-
-  quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
-  edge_model = (
-      ai_edge_torch.signature(
-          'prefill',
-          pytorch_model,
-          sample_kwargs={
-              'tokens': prefill_tokens,
-              'input_pos': prefill_input_pos,
-              'kv_cache': kv,
-          },
-      )
-      .signature(
-          'decode',
-          pytorch_model,
-          sample_kwargs={
-              'tokens': decode_token,
-              'input_pos': decode_input_pos,
-              'kv_cache': kv,
-          },
-      )
-      .convert(quant_config=quant_config)
-  )
-  quant_suffix = 'q8' if quantize else 'f32'
-  edge_model.export(
-      f'/tmp/openelm_{quant_suffix}_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
+  converter.convert_to_tflite(
+      pytorch_model,
+      tflite_path=_TFLITE_PATH.value,
+      prefill_seq_len=_PREFILL_SEQ_LEN.value,
+      quantize=_QUANTIZE.value,
   )
 
 
 if __name__ == '__main__':
-  path = os.path.join(pathlib.Path.home(), 'Downloads/llm_data/openelm')
-  convert_openelm_to_tflite(path)
+  app.run(main)

ai_edge_torch/generative/examples/openelm/openelm.py
@@ -15,16 +15,12 @@
 
 """Example of building an OpenELM model."""
 
-import os
-import pathlib
-
 from ai_edge_torch.generative.layers import attention
 from ai_edge_torch.generative.layers import builder
 from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
 import ai_edge_torch.generative.layers.model_config as cfg
 import ai_edge_torch.generative.utilities.loader as loading_utils
-import numpy as np
 import torch
 from torch import nn
 

ai_edge_torch/generative/examples/openelm/openelm.py
@@ -210,28 +206,3 @@ def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
   loader.load(model, strict=False)
   model.eval()
   return model
-
-
-def define_and_run(checkpoint_path: str) -> None:
-  """Instantiates and runs an OpenELM model."""
-
-  current_dir = pathlib.Path(__file__).parent.resolve()
-  openelm_goldens = torch.load(current_dir / "openelm_lm_logits.pt")
-  kv_cache_max_len = 1024
-  model = build_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
-  idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
-  tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.int, device="cpu")
-  tokens[0, :4] = idx
-  input_pos = torch.arange(0, kv_cache_max_len, dtype=torch.int)
-  kv = kv_utils.KVCache.from_model_config(model.config)
-  output = model.forward(tokens, input_pos, kv)
-  assert torch.allclose(
-      openelm_goldens, output["logits"][0, idx.shape[1] - 1, :], atol=1e-05
-  )
-
-
-if __name__ == "__main__":
-  input_checkpoint_path = os.path.join(
-      pathlib.Path.home(), "Downloads/llm_data/openelm"
-  )
-  define_and_run(input_checkpoint_path)
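Both reauthored model files lose their ad-hoc define_and_run golden tests, which asserted against logits checked in as *_lm_logits.pt files. Their replacement is the standalone verify.py scripts (next hunk for OpenELM, further below for Phi-2), which compare the reauthored models against the original Hugging Face implementations at runtime instead of against frozen goldens.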

ai_edge_torch/generative/examples/openelm/verify.py
@@ -0,0 +1,61 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Verifies the reauthored OpenELM-3B model."""
+
+import pathlib
+
+from absl import app
+from absl import flags
+from ai_edge_torch.generative.examples.openelm import openelm
+from ai_edge_torch.generative.utilities import verifier
+import transformers
+
+_PROMPTS = flags.DEFINE_multi_string(
+    "prompts",
+    "What is the meaning of life?",
+    "The input prompts to generate answers.",
+)
+
+
+def main(_):
+  checkpoint = "apple/OpenELM-3B"
+  verifier.log_msg("Loading the original model from", checkpoint)
+  original_model = transformers.AutoModelForCausalLM.from_pretrained(
+      checkpoint, trust_remote_code=True
+  )
+
+  # Locate the cached dir.
+  cached_config_file = transformers.utils.cached_file(
+      checkpoint, transformers.utils.CONFIG_NAME
+  )
+  reauthored_checkpoint = pathlib.Path(cached_config_file).parent
+  verifier.log_msg("Building the reauthored model from", reauthored_checkpoint)
+  reauthored_model = openelm.build_model(reauthored_checkpoint)
+
+  tokenizer_checkpoint = "meta-llama/Llama-2-7b-hf"
+  verifier.log_msg("Loading the tokenizer from", tokenizer_checkpoint)
+  tokenizer = transformers.AutoTokenizer.from_pretrained(tokenizer_checkpoint)
+
+  verifier.verify_reauthored_model(
+      original_model=original_model,
+      reauthored_model=reauthored_model,
+      tokenizer=tokenizer,
+      prompts=_PROMPTS.value,
+  )
+
+
+if __name__ == "__main__":
+  app.run(main)
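verifier.py (+200 lines) backs the new verify scripts, but its body is not included on this page. From the call sites, verify_reauthored_model presumably tokenizes each prompt, runs both the original transformers model and the reauthored one, and checks that their outputs agree within a tolerance (the Phi-2 verifier below passes atol=1e-03). A hypothetical sketch under those assumptions, reusing the forward-call pattern from the deleted define_and_run helpers:

# Hypothetical sketch; the shipped verifier.verify_reauthored_model
# may differ substantially.
import torch
from ai_edge_torch.generative.layers import kv_cache as kv_utils


def verify_reauthored_model(
    original_model, reauthored_model, tokenizer, prompts, atol=1e-05
):
  for prompt in prompts:
    tokens = tokenizer(prompt, return_tensors="pt")["input_ids"]
    with torch.no_grad():
      golden_logits = original_model(tokens).logits  # HF CausalLM output.

    # Run the reauthored model the way the deleted define_and_run helpers
    # did: explicit positions plus a fresh KV cache.
    input_pos = torch.arange(0, tokens.shape[1], dtype=torch.int)
    kv = kv_utils.KVCache.from_model_config(reauthored_model.config)
    output = reauthored_model.forward(tokens, input_pos, kv)

    assert torch.allclose(golden_logits, output["logits"], atol=atol)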

ai_edge_torch/generative/examples/phi/convert_to_tflite.py
@@ -18,69 +18,49 @@
 import os
 import pathlib
 
-import ai_edge_torch
+from absl import app
+from absl import flags
 from ai_edge_torch.generative.examples.phi import phi2
-from ai_edge_torch.generative.layers import kv_cache
-from ai_edge_torch.generative.quantize import quant_recipes
-import torch
+from ai_edge_torch.generative.utilities import converter
 
+_CHECKPOINT_PATH = flags.DEFINE_string(
+    'checkpoint_path',
+    os.path.join(pathlib.Path.home(), 'Downloads/llm_data/phi2'),
+    'The path to the model checkpoint, or directory holding the checkpoint.',
+)
+_TFLITE_PATH = flags.DEFINE_string(
+    'tflite_path',
+    '/tmp/phi2_q8_seq512_ekv1024.tflite',
+    'The tflite file path to export.',
+)
+_PREFILL_SEQ_LEN = flags.DEFINE_integer(
+    'prefill_seq_len',
+    512,
+    'The maximum size of prefill input tensor.',
+)
+_KV_CACHE_MAX_LEN = flags.DEFINE_integer(
+    'kv_cache_max_len',
+    1024,
+    'The maximum size of KV cache buffer, including both prefill and decode.',
+)
+_QUANTIZE = flags.DEFINE_bool(
+    'quantize',
+    True,
+    'Whether the model should be quantized.',
+)
 
-def convert_phi2_to_tflite(
-    checkpoint_path: str,
-    prefill_seq_len: int = 512,
-    kv_cache_max_len: int = 1024,
-    quantize: bool = True,
-):
-  """Converts a Phi-2 model to multi-signature tflite model.
 
-  Args:
-    checkpoint_path (str): The filepath to the model checkpoint, or directory
-      holding the checkpoint.
-    prefill_seq_len (int, optional): The maximum size of prefill input tensor.
-      Defaults to 512.
-    kv_cache_max_len (int, optional): The maximum size of KV cache buffer,
-      including both prefill and decode. Defaults to 1024.
-    quantize (bool, optional): Whether the model should be quanized. Defaults
-      to True.
-  """
+def main(_):
   pytorch_model = phi2.build_model(
-      checkpoint_path, kv_cache_max_len=kv_cache_max_len
+      _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
-  # Tensors used to trace the model graph during conversion.
-  prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.int)
-  prefill_input_pos = torch.arange(0, prefill_seq_len, dtype=torch.int)
-  decode_token = torch.tensor([[0]], dtype=torch.int)
-  decode_input_pos = torch.tensor([0], dtype=torch.int)
-  kv = kv_cache.KVCache.from_model_config(pytorch_model.config)
-
-  quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
-  edge_model = (
-      ai_edge_torch.signature(
-          'prefill',
-          pytorch_model,
-          sample_kwargs={
-              'tokens': prefill_tokens,
-              'input_pos': prefill_input_pos,
-              'kv_cache': kv,
-          },
-      )
-      .signature(
-          'decode',
-          pytorch_model,
-          sample_kwargs={
-              'tokens': decode_token,
-              'input_pos': decode_input_pos,
-              'kv_cache': kv,
-          },
-      )
-      .convert(quant_config=quant_config)
-  )
-  quant_suffix = 'q8' if quantize else 'f32'
-  edge_model.export(
-      f'/tmp/phi2_{quant_suffix}_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
+  converter.convert_to_tflite(
+      pytorch_model,
+      tflite_path=_TFLITE_PATH.value,
+      prefill_seq_len=_PREFILL_SEQ_LEN.value,
+      quantize=_QUANTIZE.value,
   )
 
 
 if __name__ == '__main__':
-  path = os.path.join(pathlib.Path.home(), 'Downloads/llm_data/phi2')
-  convert_phi2_to_tflite(path)
+  app.run(main)

ai_edge_torch/generative/examples/phi/phi2.py
@@ -15,16 +15,12 @@
 
 """Example of building a Phi-2 model."""
 
-import os
-import pathlib
-
 from ai_edge_torch.generative.layers import attention
 from ai_edge_torch.generative.layers import builder
 from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
 import ai_edge_torch.generative.layers.model_config as cfg
 import ai_edge_torch.generative.utilities.loader as loading_utils
-import numpy as np
 import torch
 from torch import nn
 

ai_edge_torch/generative/examples/phi/phi2.py
@@ -143,7 +139,10 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       intermediate_size=10240,
       use_bias=True,
   )
-  norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.LAYER_NORM)
+  norm_config = cfg.NormalizationConfig(
+      type=cfg.NormalizationType.LAYER_NORM,
+      use_input_shape=False,  # Phi-2 does layer-norm with the weight shape.
+  )
   block_config = cfg.TransformerBlockConfig(
       attn_config=attn_config,
       ff_config=ff_config,
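This hunk leans on the new use_input_shape option threaded through model_config.py, builder.py, and normalization.py in this release. A rough illustration of the distinction in plain PyTorch, assuming the flag selects the normalized_shape handed to F.layer_norm as the inline comment suggests (an inference from the diff, not the library's actual code):

import torch
import torch.nn.functional as F

x = torch.randn(1, 4, 2560)  # (batch, seq, embedding)
w = torch.ones(2560)         # layer-norm weight
b = torch.zeros(2560)        # layer-norm bias

# use_input_shape=True: normalize using the full input shape.
y_input = F.layer_norm(x, x.shape, w.expand(x.shape), b.expand(x.shape))

# use_input_shape=False: normalize over the weight shape (the last dim
# only), matching how Phi-2 applies nn.LayerNorm(hidden_size).
y_weight = F.layer_norm(x, w.shape, w, b)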

ai_edge_torch/generative/examples/phi/phi2.py
@@ -182,29 +181,3 @@ def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
   loader.load(model)
   model.eval()
   return model
-
-
-def define_and_run(checkpoint_path: str) -> None:
-  """Instantiates and runs a Phi-2 model."""
-
-  current_dir = pathlib.Path(__file__).parent.resolve()
-  phi2_goldens = torch.load(current_dir / "phi2_lm_logits.pt")
-  kv_cache_max_len = 1024
-  model = build_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
-  idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
-  tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.int, device="cpu")
-  tokens[0, :4] = idx
-  input_pos = torch.arange(0, kv_cache_max_len, dtype=torch.int)
-  kv = kv_utils.KVCache.from_model_config(model.config)
-  output = model.forward(tokens, input_pos, kv)
-  print("comparing with goldens..")
-  assert torch.allclose(
-      phi2_goldens, output["logits"][0, idx.shape[1] - 1, :], atol=1e-02
-  )
-
-
-if __name__ == "__main__":
-  input_checkpoint_path = os.path.join(
-      pathlib.Path.home(), "Downloads/llm_data/phi2"
-  )
-  define_and_run(input_checkpoint_path)

ai_edge_torch/generative/examples/phi/verify.py
@@ -0,0 +1,53 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Verifies the reauthored Phi-2 model."""
+
+from absl import app
+from absl import flags
+from ai_edge_torch.generative.examples.phi import phi2
+from ai_edge_torch.generative.utilities import verifier
+import kagglehub
+import transformers
+
+_PROMPTS = flags.DEFINE_multi_string(
+    "prompts",
+    "What is the meaning of life?",
+    "The input prompts to generate answers.",
+)
+
+
+def main(_):
+  checkpoint = kagglehub.model_download("Microsoft/phi/transformers/2")
+  verifier.log_msg("Loading the original model from", checkpoint)
+  original_model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint)
+
+  verifier.log_msg("Building the reauthored model from", checkpoint)
+  reauthored_model = phi2.build_model(checkpoint)
+
+  verifier.log_msg("Loading the tokenizer from", checkpoint)
+  tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
+
+  verifier.verify_reauthored_model(
+      original_model=original_model,
+      reauthored_model=reauthored_model,
+      tokenizer=tokenizer,
+      prompts=_PROMPTS.value,
+      atol=1e-03,
+  )
+
+
+if __name__ == "__main__":
+  app.run(main)
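Both verifiers declare --prompts with flags.DEFINE_multi_string, so the flag can be repeated on the command line and _PROMPTS.value arrives as a list of strings, which is why verify_reauthored_model takes prompts= in the plural; a hypothetical invocation: python verify.py --prompts="2+2=" --prompts="The capital of France is". Note the asymmetry between the two scripts: the OpenELM verifier borrows the Llama-2 tokenizer (OpenELM ships without its own) and rebuilds from the transformers cache directory, while the Phi-2 verifier pulls its checkpoint through kagglehub and reuses it for the original model, the reauthored model, and the tokenizer alike.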