ai-edge-torch-nightly 0.3.0.dev20240925__py3-none-any.whl → 0.3.0.dev20240927__py3-none-any.whl
- ai_edge_torch/generative/examples/gemma/gemma2.py +0 -2
- ai_edge_torch/generative/examples/gemma/verify_gemma1.py +3 -2
- ai_edge_torch/generative/examples/gemma/verify_gemma2.py +3 -2
- ai_edge_torch/generative/examples/gemma/verify_util.py +15 -25
- ai_edge_torch/generative/examples/llama/__init__.py +14 -0
- ai_edge_torch/generative/examples/llama/convert_3b_to_tflite.py +68 -0
- ai_edge_torch/generative/examples/llama/convert_to_tflite.py +68 -0
- ai_edge_torch/generative/examples/llama/llama.py +203 -0
- ai_edge_torch/generative/examples/llama/verify.py +73 -0
- ai_edge_torch/generative/examples/llama/verify_3b.py +73 -0
- ai_edge_torch/generative/examples/openelm/verify.py +19 -11
- ai_edge_torch/generative/examples/phi/phi3.py +15 -21
- ai_edge_torch/generative/examples/phi/verify.py +13 -12
- ai_edge_torch/generative/examples/phi/verify_phi3.py +13 -12
- ai_edge_torch/generative/examples/smollm/verify.py +19 -9
- ai_edge_torch/generative/examples/tiny_llama/verify.py +20 -10
- ai_edge_torch/generative/test/test_model_conversion_large.py +10 -0
- ai_edge_torch/generative/utilities/transformers_verifier.py +42 -0
- ai_edge_torch/generative/utilities/verifier.py +130 -114
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20240925.dist-info → ai_edge_torch_nightly-0.3.0.dev20240927.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20240925.dist-info → ai_edge_torch_nightly-0.3.0.dev20240927.dist-info}/RECORD +25 -18
- {ai_edge_torch_nightly-0.3.0.dev20240925.dist-info → ai_edge_torch_nightly-0.3.0.dev20240927.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240925.dist-info → ai_edge_torch_nightly-0.3.0.dev20240927.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240925.dist-info → ai_edge_torch_nightly-0.3.0.dev20240927.dist-info}/top_level.txt +0 -0
--- a/ai_edge_torch/generative/examples/openelm/verify.py
+++ b/ai_edge_torch/generative/examples/openelm/verify.py
@@ -15,28 +15,33 @@
 
 """Verifies the reauthored OpenELM-3B model."""
 
+import logging
 import pathlib
-
 from absl import app
 from absl import flags
 from ai_edge_torch.generative.examples.openelm import openelm
+from ai_edge_torch.generative.utilities import transformers_verifier
 from ai_edge_torch.generative.utilities import verifier
 import transformers
 
+
 _PROMPTS = flags.DEFINE_multi_string(
     "prompts",
     "What is the meaning of life?",
     "The input prompts to generate answers.",
 )
+_MAX_NEW_TOKENS = flags.DEFINE_integer(
+    "max_new_tokens",
+    30,
+    "The maximum size of the generated tokens.",
+)
 
 
 def main(_):
   checkpoint = "apple/OpenELM-3B"
-
-
-
-          checkpoint, trust_remote_code=True
-      ),
+  logging.info("Loading the original model from: %s", checkpoint)
+  original_model = transformers.AutoModelForCausalLM.from_pretrained(
+      checkpoint, trust_remote_code=True
   )
 
   # Locate the cached dir.
@@ -44,18 +49,21 @@ def main(_):
       checkpoint, transformers.utils.CONFIG_NAME
   )
   reauthored_checkpoint = pathlib.Path(cached_config_file).parent
-
+  logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
   reauthored_model = openelm.build_model(reauthored_checkpoint)
 
   tokenizer_checkpoint = "meta-llama/Llama-2-7b-hf"
-
+  logging.info("Loading the tokenizer from: %s", tokenizer_checkpoint)
   tokenizer = transformers.AutoTokenizer.from_pretrained(tokenizer_checkpoint)
 
   verifier.verify_reauthored_model(
-      original_model=
-
-
+      original_model=transformers_verifier.TransformersModelWrapper(
+          original_model
+      ),
+      reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
+      tokenizer=verifier.TokenizerWrapper(tokenizer),
       generate_prompts=_PROMPTS.value,
+      max_new_tokens=_MAX_NEW_TOKENS.value,
   )
 
 
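Read without diff markers, the updated script follows a single wrapper-based pattern that repeats in every verify script below. The following sketch reassembles it from the two hunks above; the license header, the `transformers.utils.cached_file(` line between the hunks, and the `app.run(main)` entry point fall outside the hunks and are assumed from the parallel scripts in this release.

"""Verifies the reauthored OpenELM-3B model."""

import logging
import pathlib

from absl import app
from absl import flags
from ai_edge_torch.generative.examples.openelm import openelm
from ai_edge_torch.generative.utilities import transformers_verifier
from ai_edge_torch.generative.utilities import verifier
import transformers

_PROMPTS = flags.DEFINE_multi_string(
    "prompts",
    "What is the meaning of life?",
    "The input prompts to generate answers.",
)
_MAX_NEW_TOKENS = flags.DEFINE_integer(
    "max_new_tokens",
    30,
    "The maximum size of the generated tokens.",
)


def main(_):
  checkpoint = "apple/OpenELM-3B"
  logging.info("Loading the original model from: %s", checkpoint)
  original_model = transformers.AutoModelForCausalLM.from_pretrained(
      checkpoint, trust_remote_code=True
  )

  # Locate the cached dir.
  cached_config_file = transformers.utils.cached_file(
      checkpoint, transformers.utils.CONFIG_NAME
  )
  reauthored_checkpoint = pathlib.Path(cached_config_file).parent
  logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
  reauthored_model = openelm.build_model(reauthored_checkpoint)

  tokenizer_checkpoint = "meta-llama/Llama-2-7b-hf"
  logging.info("Loading the tokenizer from: %s", tokenizer_checkpoint)
  tokenizer = transformers.AutoTokenizer.from_pretrained(tokenizer_checkpoint)

  verifier.verify_reauthored_model(
      original_model=transformers_verifier.TransformersModelWrapper(
          original_model
      ),
      reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
      tokenizer=verifier.TokenizerWrapper(tokenizer),
      generate_prompts=_PROMPTS.value,
      max_new_tokens=_MAX_NEW_TOKENS.value,
  )


if __name__ == "__main__":
  app.run(main)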
--- a/ai_edge_torch/generative/examples/phi/phi3.py
+++ b/ai_edge_torch/generative/examples/phi/phi3.py
@@ -97,15 +97,15 @@ ROPE_SHORT_FACTOR = [
 ]
 
 
-def build_rope_cache(
+def _build_rope_cache(
     size: int,
     dim: int,
-    base: int
-    condense_ratio: int
-    dtype: torch.dtype
-    device: torch.device
-    theta_factors: torch.Tensor
-    scale: float
+    base: int,
+    condense_ratio: int,
+    dtype: torch.dtype,
+    device: torch.device,
+    theta_factors: torch.Tensor,
+    scale: float,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
   """Precomputes Rotary Positional Embeddings for Phi-3.5 model.
 
@@ -116,26 +116,20 @@ def build_rope_cache(
   Args:
     size (int): The size of the built cache.
     dim (int): Each sequence's dimmension.
-    base (int, optional): Rope base value.
+    base (int, optional): Rope base value.
     condense_ratio (int, optional): The ratio by which sequence indicies are
-      condensed.
-    dtype (torch.dtype, optional): Output tensor's data type.
-
-    device (torch.device, optional): Output tensor's data type. Defaults to
-      None in which case "cpu" is used.
+      condensed.
+    dtype (torch.dtype, optional): Output tensor's data type.
+    device (torch.device, optional): Output tensor's data type.
     theta_factors (torch.Tensor, optional): A tensor of shape (dim,) used to
-      scale the theta values.
-    scale (float, optional): A float used to scale the rope values.
-      to 1.0.
+      scale the theta values.
+    scale (float, optional): A float used to scale the rope values.
 
   Returns:
     Tuple[torch.Tensor, torch.Tensor]: Rope's Cosine and Sine waves.
   """
-  if device is None:
-    device = torch.device('cpu')
   theta = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
-
-  theta = theta / theta_factors
+  theta = theta / theta_factors
   seq_idx = torch.arange(size) / condense_ratio
   idx_theta = torch.outer(seq_idx, theta)
   cos = torch.cos(idx_theta).to(dtype=dtype, device=device) * scale
@@ -167,7 +161,7 @@ class Phi3_5Mini(nn.Module):
         config.final_norm_config,
     )
     attn_config = block_config.attn_config
-    self.rope_cache = build_rope_cache(
+    self.rope_cache = _build_rope_cache(
         size=config.kv_cache_max,
         dim=int(attn_config.rotary_percentage * attn_config.head_dim),
         base=10_000,
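For reference, here is a self-contained sketch of the computation these hunks describe: all parameters are now required and the `theta_factors` division is unconditional. The body is reassembled from the hunk above; the final `sin` line and the imports are assumed by symmetry with the `cos` line, which is the last line the hunk shows.

from typing import Tuple

import torch


def _build_rope_cache(
    size: int,
    dim: int,
    base: int,
    condense_ratio: int,
    dtype: torch.dtype,
    device: torch.device,
    theta_factors: torch.Tensor,
    scale: float,
) -> Tuple[torch.Tensor, torch.Tensor]:
  """Precomputes the RoPE cosine and sine caches for Phi-3.5."""
  # One inverse frequency per pair of embedding dimensions.
  theta = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
  # Now unconditional: callers must always pass the model's RoPE
  # scaling factors (e.g. ROPE_SHORT_FACTOR above).
  theta = theta / theta_factors
  # Position indices, optionally condensed.
  seq_idx = torch.arange(size) / condense_ratio
  # Outer product gives the angle for every (position, frequency) pair.
  idx_theta = torch.outer(seq_idx, theta)
  cos = torch.cos(idx_theta).to(dtype=dtype, device=device) * scale
  # Assumed by symmetry with the cos line shown in the hunk above.
  sin = torch.sin(idx_theta).to(dtype=dtype, device=device) * scale
  return cos, sin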
--- a/ai_edge_torch/generative/examples/phi/verify.py
+++ b/ai_edge_torch/generative/examples/phi/verify.py
@@ -14,14 +14,17 @@
 # ==============================================================================
 
 """Verifies the reauthored Phi-2 model."""
+import logging
 
 from absl import app
 from absl import flags
 from ai_edge_torch.generative.examples.phi import phi2
+from ai_edge_torch.generative.utilities import transformers_verifier
 from ai_edge_torch.generative.utilities import verifier
 import kagglehub
 import transformers
 
+
 _PROMPTS = flags.DEFINE_multi_string(
     "prompts",
     "Instruct: Write an email about the weather Output:",
@@ -36,25 +39,23 @@ _MAX_NEW_TOKENS = flags.DEFINE_integer(
 
 def main(_):
   checkpoint = kagglehub.model_download("Microsoft/phi/transformers/2")
-
-
-  generation_config.max_new_tokens = _MAX_NEW_TOKENS.value
-  wrapper_model = verifier.ModelWrapper(
-      model=transformers.AutoModelForCausalLM.from_pretrained(checkpoint),
-      hf_generation_config=generation_config,
-  )
+  logging.info("Loading the original model from: %s", checkpoint)
+  original_model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint)
 
-
+  logging.info("Building the reauthored model from: %s", checkpoint)
   reauthored_model = phi2.build_model(checkpoint)
 
-
+  logging.info("Loading the tokenizer from: %s", checkpoint)
   tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
 
   verifier.verify_reauthored_model(
-      original_model=
-
-
+      original_model=transformers_verifier.TransformersModelWrapper(
+          original_model
+      ),
+      reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
+      tokenizer=verifier.TokenizerWrapper(tokenizer),
       generate_prompts=_PROMPTS.value,
+      max_new_tokens=_MAX_NEW_TOKENS.value,
       atol=1e-03,
   )
 
--- a/ai_edge_torch/generative/examples/phi/verify_phi3.py
+++ b/ai_edge_torch/generative/examples/phi/verify_phi3.py
@@ -15,14 +15,17 @@
 
 """Verifies the reauthored Phi-3.5 model."""
 
+import logging
 import pathlib
 
 from absl import app
 from absl import flags
 from ai_edge_torch.generative.examples.phi import phi3
+from ai_edge_torch.generative.utilities import transformers_verifier
 from ai_edge_torch.generative.utilities import verifier
 import transformers
 
+
 _PROMPTS = flags.DEFINE_multi_string(
     "prompts",
     "Instruct: Write an email about the weather Output:",
@@ -37,30 +40,28 @@ _MAX_NEW_TOKENS = flags.DEFINE_integer(
 
 def main(_):
   checkpoint = "microsoft/Phi-3.5-mini-instruct"
-
-
-  generation_config.max_new_tokens = _MAX_NEW_TOKENS.value
-  wrapper_model = verifier.ModelWrapper(
-      model=transformers.AutoModelForCausalLM.from_pretrained(checkpoint),
-      hf_generation_config=generation_config,
-  )
+  logging.info("Loading the original model from: %s", checkpoint)
+  original_model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint)
 
   # Locate the cached dir.
   cached_config_file = transformers.utils.cached_file(
       checkpoint, transformers.utils.CONFIG_NAME
   )
   reauthored_checkpoint = pathlib.Path(cached_config_file).parent
-
+  logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
   reauthored_model = phi3.build_model(reauthored_checkpoint)
 
-
+  logging.info("Loading the tokenizer from: %s", checkpoint)
   tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
 
   verifier.verify_reauthored_model(
-      original_model=
-
-
+      original_model=transformers_verifier.TransformersModelWrapper(
+          original_model
+      ),
+      reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
+      tokenizer=verifier.TokenizerWrapper(tokenizer),
       generate_prompts=_PROMPTS.value,
+      max_new_tokens=_MAX_NEW_TOKENS.value,
   )
 
 
|
66
67
|
|
@@ -15,43 +15,53 @@
|
|
15
15
|
|
16
16
|
"""Verifies the reauthored SmolLM-135M model."""
|
17
17
|
|
18
|
+
import logging
|
18
19
|
import pathlib
|
19
20
|
|
20
21
|
from absl import app
|
21
22
|
from absl import flags
|
22
23
|
from ai_edge_torch.generative.examples.smollm import smollm
|
24
|
+
from ai_edge_torch.generative.utilities import transformers_verifier
|
23
25
|
from ai_edge_torch.generative.utilities import verifier
|
24
26
|
import transformers
|
25
27
|
|
28
|
+
|
26
29
|
_PROMPTS = flags.DEFINE_multi_string(
|
27
30
|
"prompts",
|
28
31
|
"What is the meaning of life?",
|
29
32
|
"The input prompts to generate answers.",
|
30
33
|
)
|
34
|
+
_MAX_NEW_TOKENS = flags.DEFINE_integer(
|
35
|
+
"max_new_tokens",
|
36
|
+
30,
|
37
|
+
"The maximum size of the generated tokens.",
|
38
|
+
)
|
31
39
|
|
32
40
|
|
33
41
|
def main(_):
|
34
42
|
checkpoint = "HuggingFaceTB/SmolLM-135M"
|
35
|
-
|
36
|
-
|
37
|
-
|
38
|
-
)
|
43
|
+
logging.info("Loading the original model from: %s", checkpoint)
|
44
|
+
original_model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint)
|
45
|
+
|
39
46
|
# Locate the cached dir.
|
40
47
|
cached_config_file = transformers.utils.cached_file(
|
41
48
|
checkpoint, transformers.utils.CONFIG_NAME
|
42
49
|
)
|
43
50
|
reauthored_checkpoint = pathlib.Path(cached_config_file).parent
|
44
|
-
|
51
|
+
logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
|
45
52
|
reauthored_model = smollm.build_model(reauthored_checkpoint)
|
46
53
|
|
47
|
-
|
54
|
+
logging.info("Loading the tokenizer from: %s", checkpoint)
|
48
55
|
tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
|
49
56
|
|
50
57
|
verifier.verify_reauthored_model(
|
51
|
-
original_model=
|
52
|
-
|
53
|
-
|
58
|
+
original_model=transformers_verifier.TransformersModelWrapper(
|
59
|
+
original_model
|
60
|
+
),
|
61
|
+
reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
|
62
|
+
tokenizer=verifier.TokenizerWrapper(tokenizer),
|
54
63
|
generate_prompts=_PROMPTS.value,
|
64
|
+
max_new_tokens=_MAX_NEW_TOKENS.value,
|
55
65
|
atol=1e-04,
|
56
66
|
)
|
57
67
|
|
--- a/ai_edge_torch/generative/examples/tiny_llama/verify.py
+++ b/ai_edge_torch/generative/examples/tiny_llama/verify.py
@@ -15,45 +15,55 @@
 
 """Verifies the reauthored TinyLlama-1.1B model."""
 
+import logging
 import pathlib
 
 from absl import app
 from absl import flags
 from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
+from ai_edge_torch.generative.utilities import transformers_verifier
 from ai_edge_torch.generative.utilities import verifier
 import transformers
 
+
 _PROMPTS = flags.DEFINE_multi_string(
     "prompts",
     "Show me the program to add 2 and 3.",
     "The input prompts to generate answers.",
 )
+_MAX_NEW_TOKENS = flags.DEFINE_integer(
+    "max_new_tokens",
+    30,
+    "The maximum size of the generated tokens.",
+)
 
 
 def main(_):
   checkpoint = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
-
-
-
-          checkpoint, trust_remote_code=True
-      ),
+  logging.info("Loading the original model from: %s", checkpoint)
+  original_model = transformers.AutoModelForCausalLM.from_pretrained(
+      checkpoint, trust_remote_code=True
   )
+
   # Locate the cached dir.
   cached_config_file = transformers.utils.cached_file(
       checkpoint, transformers.utils.CONFIG_NAME
   )
   reauthored_checkpoint = pathlib.Path(cached_config_file).parent
-
+  logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
   reauthored_model = tiny_llama.build_model(reauthored_checkpoint)
 
-
+  logging.info("Loading the tokenizer from: %s", checkpoint)
   tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
 
   verifier.verify_reauthored_model(
-      original_model=
-
-
+      original_model=transformers_verifier.TransformersModelWrapper(
+          original_model
+      ),
+      reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
+      tokenizer=verifier.TokenizerWrapper(tokenizer),
       generate_prompts=_PROMPTS.value,
+      max_new_tokens=_MAX_NEW_TOKENS.value,
       atol=1e-04,
   )
 
--- a/ai_edge_torch/generative/test/test_model_conversion_large.py
+++ b/ai_edge_torch/generative/test/test_model_conversion_large.py
@@ -19,6 +19,7 @@ import ai_edge_torch
 from ai_edge_torch import config as ai_edge_config
 from ai_edge_torch.generative.examples.gemma import gemma1
 from ai_edge_torch.generative.examples.gemma import gemma2
+from ai_edge_torch.generative.examples.llama import llama
 from ai_edge_torch.generative.examples.openelm import openelm
 from ai_edge_torch.generative.examples.phi import phi2
 from ai_edge_torch.generative.examples.phi import phi3
@@ -102,6 +103,15 @@ class TestModelConversion(googletest.TestCase):
     pytorch_model = gemma2.Gemma2(config).eval()
     self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5)
 
+  @googletest.skipIf(
+      ai_edge_config.Config.use_torch_xla,
+      reason="tests with custom ops are not supported on oss",
+  )
+  def test_llama(self):
+    config = llama.get_fake_model_config()
+    pytorch_model = llama.Llama(config).eval()
+    self._test_model(config, pytorch_model, "prefill", atol=1e-3, rtol=1e-5)
+
   @googletest.skipIf(
       ai_edge_config.Config.use_torch_xla,
       reason="tests with custom ops are not supported on oss",
--- /dev/null
+++ b/ai_edge_torch/generative/utilities/transformers_verifier.py
@@ -0,0 +1,42 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Utilities for the models predefined in HuggingFace transformers."""
+
+from typing import cast
+
+from ai_edge_torch.generative.utilities import verifier
+import torch
+import transformers
+
+
+class TransformersModelWrapper(verifier.ModelWrapper):
+  """A wrapper for the model predefined in HuggingFace transformers.
+
+  Verifier expects forward() to return logits while Transformers models return
+  an object with `logits` field.
+
+  Transformers models get `max_new_tokens` settings for generate() via
+  GenerationConfig.
+  """
+
+  def forward(self, tokens: torch.Tensor) -> torch.Tensor:
+    return self.model.forward(tokens).logits
+
+  def generate(
+      self, inputs: torch.Tensor, max_new_tokens: int
+  ) -> torch.IntTensor:
+    gen_config = transformers.GenerationConfig(max_new_tokens=max_new_tokens)
+    return self.model.generate(inputs=inputs, generation_config=gen_config)
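Putting the new module together with the verify scripts above, the wrapper is used roughly as follows. This is a minimal sketch: the SmolLM checkpoint is borrowed from the verify scripts, and the constructor signature follows the `TransformersModelWrapper(original_model)` calls shown there.

import torch
import transformers

from ai_edge_torch.generative.utilities import transformers_verifier

# Wrap a HuggingFace causal LM so the verifier sees plain logits tensors
# instead of transformers' ModelOutput objects.
hf_model = transformers.AutoModelForCausalLM.from_pretrained(
    "HuggingFaceTB/SmolLM-135M"
)
wrapped = transformers_verifier.TransformersModelWrapper(hf_model)

tokens = torch.tensor([[1, 2, 3]])
logits = wrapped.forward(tokens)  # torch.Tensor of shape (1, 3, vocab_size)
# max_new_tokens is routed through a GenerationConfig internally.
output_ids = wrapped.generate(tokens, max_new_tokens=30)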