ai-edge-torch-nightly 0.3.0.dev20240924__py3-none-any.whl → 0.3.0.dev20240928__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions as they appear in the public registry.
- ai_edge_torch/generative/examples/gemma/gemma1.py +2 -6
- ai_edge_torch/generative/examples/gemma/gemma2.py +2 -10
- ai_edge_torch/generative/examples/gemma/verify_gemma1.py +3 -2
- ai_edge_torch/generative/examples/gemma/verify_gemma2.py +3 -2
- ai_edge_torch/generative/examples/gemma/verify_util.py +15 -25
- ai_edge_torch/generative/examples/llama/__init__.py +14 -0
- ai_edge_torch/generative/examples/llama/convert_3b_to_tflite.py +68 -0
- ai_edge_torch/generative/examples/llama/convert_to_tflite.py +68 -0
- ai_edge_torch/generative/examples/llama/llama.py +204 -0
- ai_edge_torch/generative/examples/llama/verify.py +73 -0
- ai_edge_torch/generative/examples/llama/verify_3b.py +73 -0
- ai_edge_torch/generative/examples/openelm/openelm.py +2 -6
- ai_edge_torch/generative/examples/openelm/verify.py +19 -11
- ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py +68 -0
- ai_edge_torch/generative/examples/phi/phi2.py +2 -6
- ai_edge_torch/generative/examples/phi/phi3.py +279 -0
- ai_edge_torch/generative/examples/phi/verify.py +13 -13
- ai_edge_torch/generative/examples/phi/verify_phi3.py +69 -0
- ai_edge_torch/generative/examples/smollm/smollm.py +1 -0
- ai_edge_torch/generative/examples/smollm/verify.py +19 -9
- ai_edge_torch/generative/examples/stable_diffusion/clip.py +54 -1
- ai_edge_torch/generative/examples/stable_diffusion/decoder.py +58 -0
- ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +71 -1
- ai_edge_torch/generative/examples/t5/t5.py +0 -2
- ai_edge_torch/generative/examples/test_models/convert_toy_model.py +105 -0
- ai_edge_torch/generative/examples/test_models/toy_model.py +7 -41
- ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +5 -61
- ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +2 -6
- ai_edge_torch/generative/examples/tiny_llama/verify.py +20 -10
- ai_edge_torch/generative/layers/model_config.py +2 -0
- ai_edge_torch/generative/layers/normalization.py +2 -2
- ai_edge_torch/generative/layers/unet/blocks_2d.py +2 -2
- ai_edge_torch/generative/test/test_model_conversion_large.py +129 -0
- ai_edge_torch/generative/utilities/transformers_verifier.py +42 -0
- ai_edge_torch/generative/utilities/verifier.py +130 -114
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20240924.dist-info → ai_edge_torch_nightly-0.3.0.dev20240928.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20240924.dist-info → ai_edge_torch_nightly-0.3.0.dev20240928.dist-info}/RECORD +41 -30
- {ai_edge_torch_nightly-0.3.0.dev20240924.dist-info → ai_edge_torch_nightly-0.3.0.dev20240928.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240924.dist-info → ai_edge_torch_nightly-0.3.0.dev20240928.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240924.dist-info → ai_edge_torch_nightly-0.3.0.dev20240928.dist-info}/top_level.txt +0 -0
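All of the verify scripts in the hunks below follow the same pattern built on the new transformers_verifier utility: the original Hugging Face model, the reauthored model, and the tokenizer are each wrapped and handed to verifier.verify_reauthored_model. A condensed sketch of that flow, using only calls that appear in this diff (the verify_checkpoint helper itself is illustrative and not part of the package):

import logging

import transformers

from ai_edge_torch.generative.utilities import transformers_verifier
from ai_edge_torch.generative.utilities import verifier


def verify_checkpoint(checkpoint, build_model, prompts, max_new_tokens=30, atol=1e-04):
  """Wraps the original model, the reauthored model, and the tokenizer, then compares them."""
  logging.info("Loading the original model from: %s", checkpoint)
  original_model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint)
  tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
  # build_model is one of the reauthored builders, e.g. llama.build_3b_model or phi3.build_model.
  reauthored_model = build_model(checkpoint)
  verifier.verify_reauthored_model(
      original_model=transformers_verifier.TransformersModelWrapper(original_model),
      reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
      tokenizer=verifier.TokenizerWrapper(tokenizer),
      generate_prompts=prompts,
      max_new_tokens=max_new_tokens,
      atol=atol,
  )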
ai_edge_torch/generative/examples/llama/verify_3b.py
@@ -0,0 +1,73 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Verifies the reauthored Llama 3.2-3B model."""
+
+import logging
+import pathlib
+
+from absl import app
+from absl import flags
+from ai_edge_torch.generative.examples.llama import llama
+from ai_edge_torch.generative.utilities import transformers_verifier
+from ai_edge_torch.generative.utilities import verifier
+import transformers
+
+
+_PROMPTS = flags.DEFINE_multi_string(
+    "prompts",
+    "What is the meaning of life?",
+    "The input prompts to generate answers.",
+)
+_MAX_NEW_TOKENS = flags.DEFINE_integer(
+    "max_new_tokens",
+    30,
+    "The maximum size of the generated tokens.",
+)
+
+
+def main(_):
+  checkpoint = "meta-llama/Llama-3.2-3B-Instruct"
+  logging.info("Loading the original model from: %s", checkpoint)
+  original_model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint)
+
+  # Locate the cached dir.
+  cached_config_file = transformers.utils.cached_file(
+      checkpoint, transformers.utils.CONFIG_NAME
+  )
+  reauthored_checkpoint = pathlib.Path(cached_config_file).parent
+  logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
+  reauthored_model = llama.build_3b_model(reauthored_checkpoint)
+
+  logging.info("Loading the tokenizer from: %s", checkpoint)
+  # Llama tokenizer_config.json sets a fast tokenizer class explicitly,
+  # "PreTrainedTokenizerFast". It works only when the fast tokenizer is
+  # available.
+  tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
+
+  verifier.verify_reauthored_model(
+      original_model=transformers_verifier.TransformersModelWrapper(
+          original_model
+      ),
+      reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
+      tokenizer=verifier.TokenizerWrapper(tokenizer),
+      generate_prompts=_PROMPTS.value,
+      max_new_tokens=_MAX_NEW_TOKENS.value,
+      atol=1e-04,
+  )
+
+
+if __name__ == "__main__":
+  app.run(main)
ai_edge_torch/generative/examples/openelm/openelm.py
@@ -68,15 +68,10 @@ class OpenELM(nn.Module):
     self.rope_cache = attn_utils.build_rope_cache(
         size=config.kv_cache_max,
         dim=int(attn_config.rotary_percentage * attn_config.head_dim),
-        base=10_000,
-        condense_ratio=1,
-        dtype=torch.float32,
-        device=torch.device("cpu"),
+        base=attn_config.rotary_base,
     )
     self.mask_cache = attn_utils.build_causal_mask_cache(
         size=config.kv_cache_max,
-        dtype=torch.float32,
-        device=torch.device("cpu"),
     )
     self.config = config
 
@@ -154,6 +149,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
         num_heads=num_heads[idx],
         head_dim=128,
         num_query_groups=num_query_groups[idx],
+        rotary_base=10000,
         rotary_percentage=1.0,
         qkv_transpose_before_split=True,
         query_norm_config=norm_config,
ai_edge_torch/generative/examples/openelm/verify.py
@@ -15,28 +15,33 @@
 
 """Verifies the reauthored OpenELM-3B model."""
 
+import logging
 import pathlib
-
 from absl import app
 from absl import flags
 from ai_edge_torch.generative.examples.openelm import openelm
+from ai_edge_torch.generative.utilities import transformers_verifier
 from ai_edge_torch.generative.utilities import verifier
 import transformers
 
+
 _PROMPTS = flags.DEFINE_multi_string(
     "prompts",
     "What is the meaning of life?",
     "The input prompts to generate answers.",
 )
+_MAX_NEW_TOKENS = flags.DEFINE_integer(
+    "max_new_tokens",
+    30,
+    "The maximum size of the generated tokens.",
+)
 
 
 def main(_):
   checkpoint = "apple/OpenELM-3B"
-
-
-
-      checkpoint, trust_remote_code=True
-  ),
+  logging.info("Loading the original model from: %s", checkpoint)
+  original_model = transformers.AutoModelForCausalLM.from_pretrained(
+      checkpoint, trust_remote_code=True
   )
 
   # Locate the cached dir.
@@ -44,18 +49,21 @@ def main(_):
       checkpoint, transformers.utils.CONFIG_NAME
   )
   reauthored_checkpoint = pathlib.Path(cached_config_file).parent
-
+  logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
   reauthored_model = openelm.build_model(reauthored_checkpoint)
 
   tokenizer_checkpoint = "meta-llama/Llama-2-7b-hf"
-
+  logging.info("Loading the tokenizer from: %s", tokenizer_checkpoint)
   tokenizer = transformers.AutoTokenizer.from_pretrained(tokenizer_checkpoint)
 
   verifier.verify_reauthored_model(
-      original_model=
-
-
+      original_model=transformers_verifier.TransformersModelWrapper(
+          original_model
+      ),
+      reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
+      tokenizer=verifier.TokenizerWrapper(tokenizer),
       generate_prompts=_PROMPTS.value,
+      max_new_tokens=_MAX_NEW_TOKENS.value,
   )
 
 
ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py
@@ -0,0 +1,68 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Example of converting a Phi-3.5 model to multi-signature tflite model."""
+
+import os
+import pathlib
+
+from absl import app
+from absl import flags
+from ai_edge_torch.generative.examples.phi import phi3
+from ai_edge_torch.generative.utilities import converter
+
+_CHECKPOINT_PATH = flags.DEFINE_string(
+    'checkpoint_path',
+    os.path.join(pathlib.Path.home(), 'Downloads/llm_data/phi3'),
+    'The path to the model checkpoint, or directory holding the checkpoint.',
+)
+_TFLITE_PATH = flags.DEFINE_string(
+    'tflite_path',
+    '/tmp/',
+    'The tflite file path to export.',
+)
+_PREFILL_SEQ_LEN = flags.DEFINE_integer(
+    'prefill_seq_len',
+    1024,
+    'The maximum size of prefill input tensor.',
+)
+_KV_CACHE_MAX_LEN = flags.DEFINE_integer(
+    'kv_cache_max_len',
+    1280,
+    'The maximum size of KV cache buffer, including both prefill and decode.',
+)
+_QUANTIZE = flags.DEFINE_bool(
+    'quantize',
+    True,
+    'Whether the model should be quantized.',
+)
+
+
+def main(_):
+  pytorch_model = phi3.build_model(
+      _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
+  )
+  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
+  output_filename = f'phi3_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
+  converter.convert_to_tflite(
+      pytorch_model,
+      tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
+      prefill_seq_len=_PREFILL_SEQ_LEN.value,
+      quantize=_QUANTIZE.value,
+  )
+
+
+if __name__ == '__main__':
+  app.run(main)
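With the flag defaults above (quantize=True, prefill_seq_len=1024, kv_cache_max_len=1280), the script writes its output to /tmp/phi3_q8_seq1024_ekv1280.tflite. The filename logic can be checked directly by evaluating the script's own f-string with those defaults:

quantize, prefill_seq_len, kv_cache_max_len = True, 1024, 1280
quant_suffix = 'q8' if quantize else 'f32'
output_filename = f'phi3_{quant_suffix}_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
print(output_filename)  # phi3_q8_seq1024_ekv1280.tflite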
ai_edge_torch/generative/examples/phi/phi2.py
@@ -65,15 +65,10 @@ class Phi2(nn.Module):
     self.rope_cache = attn_utils.build_rope_cache(
         size=config.kv_cache_max,
         dim=int(attn_config.rotary_percentage * attn_config.head_dim),
-        base=10_000,
-        condense_ratio=1,
-        dtype=torch.float32,
-        device=torch.device("cpu"),
+        base=attn_config.rotary_base,
    )
     self.mask_cache = attn_utils.build_causal_mask_cache(
         size=config.kv_cache_max,
-        dtype=torch.float32,
-        device=torch.device("cpu"),
     )
     self.config = config
 
@@ -129,6 +124,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       num_heads=32,
       head_dim=80,
       num_query_groups=32,
+      rotary_base=10000,
       rotary_percentage=0.4,
       qkv_use_bias=True,
       output_proj_use_bias=True,
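The net effect of this pair of changes is that Phi-2's rotary embedding base now comes from the config (rotary_base=10000) instead of a hard-coded constant, while the rotary dimension is still int(rotary_percentage * head_dim) = int(0.4 * 80) = 32. A small sketch of the standard RoPE inverse-frequency computation these values feed into (the same formula used by _build_rope_cache in phi3.py below; attn_utils.build_rope_cache is assumed to do the equivalent):

import torch

rotary_base, rotary_percentage, head_dim = 10000, 0.4, 80
dim = int(rotary_percentage * head_dim)  # 32 rotary dimensions per head
# Inverse frequencies theta_i = base^(-2i/dim), one per pair of rotary dims.
theta = 1.0 / (rotary_base ** (torch.arange(0, dim, 2).float() / dim))
print(theta.shape)  # torch.Size([16])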
ai_edge_torch/generative/examples/phi/phi3.py
@@ -0,0 +1,279 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Example of building a Phi-3.5 model up to 4K tokens, not to 128K tokens."""
+
+import math
+from typing import Tuple
+
+from ai_edge_torch.generative.layers import attention
+from ai_edge_torch.generative.layers import builder
+from ai_edge_torch.generative.layers import kv_cache as kv_utils
+import ai_edge_torch.generative.layers.attention_utils as attn_utils
+import ai_edge_torch.generative.layers.model_config as cfg
+import ai_edge_torch.generative.utilities.loader as loading_utils
+import torch
+from torch import nn
+
+TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
+    ff_up_proj="model.layers.{}.mlp.gate_up_proj",
+    ff_down_proj="model.layers.{}.mlp.down_proj",
+    attn_fused_qkv_proj="model.layers.{}.self_attn.qkv_proj",
+    attn_output_proj="model.layers.{}.self_attn.o_proj",
+    pre_attn_norm="model.layers.{}.input_layernorm",
+    post_attn_norm="model.layers.{}.post_attention_layernorm",
+    embedding="model.embed_tokens",
+    final_norm="model.norm",
+    lm_head="lm_head",
+)
+
+# max_position_embeddings / original_max_position_embeddings in Phi-3.5 config.
+ROPE_SCALE_FACTOR = 32
+
+# ROPE short factor in Phi-3.5 config. According to LOPE paper and its code in
+# https://github.com/microsoft/LongRoPE, these values had been searched with
+# min=1.0, step-0.01 to optimize the errors of sample dataset.
+ROPE_SHORT_FACTOR = [
+    1.0,
+    1.0199999809265137,
+    1.0299999713897705,
+    1.0299999713897705,
+    1.0499999523162842,
+    1.0499999523162842,
+    1.0499999523162842,
+    1.0499999523162842,
+    1.0499999523162842,
+    1.0699999332427979,
+    1.0999999046325684,
+    1.1099998950958252,
+    1.1599998474121094,
+    1.1599998474121094,
+    1.1699998378753662,
+    1.2899998426437378,
+    1.339999794960022,
+    1.679999828338623,
+    1.7899998426437378,
+    1.8199998140335083,
+    1.8499997854232788,
+    1.8799997568130493,
+    1.9099997282028198,
+    1.9399996995925903,
+    1.9899996519088745,
+    2.0199997425079346,
+    2.0199997425079346,
+    2.0199997425079346,
+    2.0199997425079346,
+    2.0199997425079346,
+    2.0199997425079346,
+    2.0299997329711914,
+    2.0299997329711914,
+    2.0299997329711914,
+    2.0299997329711914,
+    2.0299997329711914,
+    2.0299997329711914,
+    2.0299997329711914,
+    2.0299997329711914,
+    2.0299997329711914,
+    2.0799996852874756,
+    2.0899996757507324,
+    2.189999580383301,
+    2.2199995517730713,
+    2.5899994373321533,
+    2.729999542236328,
+    2.749999523162842,
+    2.8399994373321533,
+]
+
+
+def _build_rope_cache(
+    size: int,
+    dim: int,
+    base: int,
+    condense_ratio: int,
+    dtype: torch.dtype,
+    device: torch.device,
+    theta_factors: torch.Tensor,
+    scale: float,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+  """Precomputes Rotary Positional Embeddings for Phi-3.5 model.
+
+  It's a modified version of attn_utils.build_rope_cache with additional
+  arguments for Phi-3.5 model. It precompute Rotary Positional Embedding Sin and
+  Cos values with scaling factors for quick lookup during the inference.
+
+  Args:
+    size (int): The size of the built cache.
+    dim (int): Each sequence's dimmension.
+    base (int, optional): Rope base value.
+    condense_ratio (int, optional): The ratio by which sequence indicies are
+      condensed.
+    dtype (torch.dtype, optional): Output tensor's data type.
+    device (torch.device, optional): Output tensor's data type.
+    theta_factors (torch.Tensor, optional): A tensor of shape (dim,) used to
+      scale the theta values.
+    scale (float, optional): A float used to scale the rope values.
+
+  Returns:
+    Tuple[torch.Tensor, torch.Tensor]: Rope's Cosine and Sine waves.
+  """
+  theta = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
+  theta = theta / theta_factors
+  seq_idx = torch.arange(size) / condense_ratio
+  idx_theta = torch.outer(seq_idx, theta)
+  cos = torch.cos(idx_theta).to(dtype=dtype, device=device) * scale
+  sin = torch.sin(idx_theta).to(dtype=dtype, device=device) * scale
+  return cos, sin
+
+
+class Phi3_5Mini(nn.Module):
+  """A Phi-3.5 model built from the Edge Generative API layers."""
+
+  def __init__(self, config: cfg.ModelConfig):
+    super().__init__()
+
+    # Construct model layers.
+    self.lm_head = nn.Linear(
+        config.embedding_dim, config.vocab_size, bias=config.lm_head_use_bias
+    )
+    self.tok_embedding = nn.Embedding(
+        config.vocab_size, config.embedding_dim, padding_idx=0
+    )
+    # Phi-3.5 has only one block config.
+    block_config = config.block_config(0)
+    self.transformer_blocks = nn.ModuleList(
+        attention.TransformerBlock(block_config, config)
+        for _ in range(config.num_layers)
+    )
+    self.final_norm = builder.build_norm(
+        config.embedding_dim,
+        config.final_norm_config,
+    )
+    attn_config = block_config.attn_config
+    self.rope_cache = _build_rope_cache(
+        size=config.kv_cache_max,
+        dim=int(attn_config.rotary_percentage * attn_config.head_dim),
+        base=attn_config.rotary_base,
+        condense_ratio=1,
+        dtype=torch.float32,
+        device=torch.device("cpu"),
+        theta_factors=torch.tensor(ROPE_SHORT_FACTOR),
+        scale=math.sqrt(
+            1 + math.log(ROPE_SCALE_FACTOR) / math.log(config.max_seq_len)
+        ),
+    )
+    self.mask_cache = attn_utils.build_causal_mask_cache(
+        size=config.kv_cache_max,
+    )
+    self.config = config
+
+  @torch.inference_mode
+  def forward(
+      self,
+      tokens: torch.Tensor,
+      input_pos: torch.Tensor,
+      kv_cache: kv_utils.KVCache,
+  ) -> dict[torch.Tensor, kv_utils.KVCache]:
+    _, seq_len = tokens.size()
+    assert self.config.max_seq_len >= seq_len, (
+        f"Cannot forward sequence of length {seq_len}, max seq length is only"
+        f" {self.config.max_seq_len}"
+    )
+    assert len(self.transformer_blocks) == len(kv_cache.caches), (
+        "The number of transformer blocks and the number of KV cache entries"
+        " must be the same."
+    )
+
+    cos, sin = self.rope_cache
+    cos = cos.index_select(0, input_pos)
+    sin = sin.index_select(0, input_pos)
+    mask = self.mask_cache.index_select(2, input_pos)
+    mask = mask[:, :, :, : self.config.kv_cache_max]
+
+    x = self.tok_embedding(tokens)
+
+    updated_kv_entires = []
+    for i, block in enumerate(self.transformer_blocks):
+      kv_entry = kv_cache.caches[i] if kv_cache else None
+      x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
+      if kv_entry:
+        updated_kv_entires.append(kv_entry)
+    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
+
+    x = self.final_norm(x)
+    logits = self.lm_head(x)  # (b, t, vocab_size)
+    return {"logits": logits, "kv_cache": updated_kv_cache}
+
+
+def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
+  """Returns the model config for a Phi-3.5 model.
+
+  Args:
+    kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
+      is 1024.
+
+  Returns:
+    The model config for a Phi-2 model.
+  """
+  attn_config = cfg.AttentionConfig(
+      num_heads=32,
+      head_dim=96,
+      num_query_groups=32,
+      rotary_base=10000,
+      rotary_percentage=1.0,
+      qkv_transpose_before_split=True,
+  )
+  ff_config = cfg.FeedForwardConfig(
+      type=cfg.FeedForwardType.SEQUENTIAL,
+      activation=cfg.ActivationConfig(cfg.ActivationType.SILU_GLU),
+      intermediate_size=8192,
+  )
+  norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.RMS_NORM)
+  block_config = cfg.TransformerBlockConfig(
+      attn_config=attn_config,
+      ff_config=ff_config,
+      pre_attention_norm_config=norm_config,
+      post_attention_norm_config=norm_config,
+  )
+  config = cfg.ModelConfig(
+      vocab_size=32064,
+      num_layers=32,
+      max_seq_len=4096,
+      kv_cache_max_len=kv_cache_max_len,
+      embedding_dim=3072,
+      block_configs=block_config,
+      final_norm_config=norm_config,
+      enable_hlfb=True,
+  )
+  return config
+
+
+def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
+  config = get_model_config(kv_cache_max_len)
+  config.vocab_size = 128
+  config.num_layers = 2
+  config.max_seq_len = 2 * kv_cache_max_len
+  # Phi-3.5 has only one block config.
+  config.block_config(0).ff_config.intermediate_size = 128
+  return config
+
+
+def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
+  """Instantiates the model instance and load checkpoint if provided."""
+  config = get_model_config(**kwargs)
+  model = Phi3_5Mini(config)
+  loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
+  loader.load(model)
+  model.eval()
+  return model
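Two numbers in the rope cache above are easy to sanity-check. theta_factors=torch.tensor(ROPE_SHORT_FACTOR) supplies one LongRoPE scaling value per inverse frequency (48 values for dim=96), and for the 4K configuration the cos/sin scale evaluates to sqrt(1 + ln(32)/ln(4096)):

import math

ROPE_SCALE_FACTOR = 32  # max_position_embeddings / original_max_position_embeddings
max_seq_len = 4096      # this port targets the 4K context, not 128K
scale = math.sqrt(1 + math.log(ROPE_SCALE_FACTOR) / math.log(max_seq_len))
print(round(scale, 4))  # 1.1902, multiplied into both the cos and sin caches
print(96 // 2)          # 48 inverse frequencies, matching len(ROPE_SHORT_FACTOR)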
ai_edge_torch/generative/examples/phi/verify.py
@@ -14,20 +14,22 @@
 # ==============================================================================
 
 """Verifies the reauthored Phi-2 model."""
+import logging
 
 from absl import app
 from absl import flags
 from ai_edge_torch.generative.examples.phi import phi2
+from ai_edge_torch.generative.utilities import transformers_verifier
 from ai_edge_torch.generative.utilities import verifier
 import kagglehub
 import transformers
 
+
 _PROMPTS = flags.DEFINE_multi_string(
     "prompts",
     "Instruct: Write an email about the weather Output:",
     "The input prompts to generate answers.",
 )
-
 _MAX_NEW_TOKENS = flags.DEFINE_integer(
     "max_new_tokens",
     30,
@@ -37,25 +39,23 @@ _MAX_NEW_TOKENS = flags.DEFINE_integer(
 
 def main(_):
   checkpoint = kagglehub.model_download("Microsoft/phi/transformers/2")
-
-
-  generation_config.max_new_tokens = _MAX_NEW_TOKENS.value
-  wrapper_model = verifier.ModelWrapper(
-      model=transformers.AutoModelForCausalLM.from_pretrained(checkpoint),
-      hf_generation_config=generation_config,
-  )
+  logging.info("Loading the original model from: %s", checkpoint)
+  original_model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint)
 
-
+  logging.info("Building the reauthored model from: %s", checkpoint)
   reauthored_model = phi2.build_model(checkpoint)
 
-
+  logging.info("Loading the tokenizer from: %s", checkpoint)
   tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
 
   verifier.verify_reauthored_model(
-      original_model=
-
-
+      original_model=transformers_verifier.TransformersModelWrapper(
+          original_model
+      ),
+      reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
+      tokenizer=verifier.TokenizerWrapper(tokenizer),
       generate_prompts=_PROMPTS.value,
+      max_new_tokens=_MAX_NEW_TOKENS.value,
       atol=1e-03,
   )
 
ai_edge_torch/generative/examples/phi/verify_phi3.py
@@ -0,0 +1,69 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Verifies the reauthored Phi-3.5 model."""
+
+import logging
+import pathlib
+
+from absl import app
+from absl import flags
+from ai_edge_torch.generative.examples.phi import phi3
+from ai_edge_torch.generative.utilities import transformers_verifier
+from ai_edge_torch.generative.utilities import verifier
+import transformers
+
+
+_PROMPTS = flags.DEFINE_multi_string(
+    "prompts",
+    "Instruct: Write an email about the weather Output:",
+    "The input prompts to generate answers.",
+)
+_MAX_NEW_TOKENS = flags.DEFINE_integer(
+    "max_new_tokens",
+    30,
+    "The maximum size of the generated tokens.",
+)
+
+
+def main(_):
+  checkpoint = "microsoft/Phi-3.5-mini-instruct"
+  logging.info("Loading the original model from: %s", checkpoint)
+  original_model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint)
+
+  # Locate the cached dir.
+  cached_config_file = transformers.utils.cached_file(
+      checkpoint, transformers.utils.CONFIG_NAME
+  )
+  reauthored_checkpoint = pathlib.Path(cached_config_file).parent
+  logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
+  reauthored_model = phi3.build_model(reauthored_checkpoint)
+
+  logging.info("Loading the tokenizer from: %s", checkpoint)
+  tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
+
+  verifier.verify_reauthored_model(
+      original_model=transformers_verifier.TransformersModelWrapper(
+          original_model
+      ),
+      reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
+      tokenizer=verifier.TokenizerWrapper(tokenizer),
+      generate_prompts=_PROMPTS.value,
+      max_new_tokens=_MAX_NEW_TOKENS.value,
+  )
+
+
+if __name__ == "__main__":
+  app.run(main)