ai-edge-torch-nightly 0.3.0.dev20240926__py3-none-any.whl → 0.3.0.dev20240927__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_torch/generative/examples/gemma/verify_gemma2.py +0 -1
- ai_edge_torch/generative/examples/gemma/verify_util.py +13 -24
- ai_edge_torch/generative/examples/llama/__init__.py +14 -0
- ai_edge_torch/generative/examples/llama/convert_3b_to_tflite.py +68 -0
- ai_edge_torch/generative/examples/llama/convert_to_tflite.py +68 -0
- ai_edge_torch/generative/examples/llama/llama.py +203 -0
- ai_edge_torch/generative/examples/llama/verify.py +73 -0
- ai_edge_torch/generative/examples/llama/verify_3b.py +73 -0
- ai_edge_torch/generative/examples/openelm/verify.py +14 -7
- ai_edge_torch/generative/examples/phi/phi3.py +15 -21
- ai_edge_torch/generative/examples/phi/verify.py +8 -9
- ai_edge_torch/generative/examples/phi/verify_phi3.py +8 -9
- ai_edge_torch/generative/examples/smollm/verify.py +14 -6
- ai_edge_torch/generative/examples/tiny_llama/verify.py +15 -7
- ai_edge_torch/generative/test/test_model_conversion_large.py +10 -0
- ai_edge_torch/generative/utilities/transformers_verifier.py +42 -0
- ai_edge_torch/generative/utilities/verifier.py +117 -97
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20240926.dist-info → ai_edge_torch_nightly-0.3.0.dev20240927.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20240926.dist-info → ai_edge_torch_nightly-0.3.0.dev20240927.dist-info}/RECORD +23 -16
- {ai_edge_torch_nightly-0.3.0.dev20240926.dist-info → ai_edge_torch_nightly-0.3.0.dev20240927.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240926.dist-info → ai_edge_torch_nightly-0.3.0.dev20240927.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240926.dist-info → ai_edge_torch_nightly-0.3.0.dev20240927.dist-info}/top_level.txt +0 -0
@@ -97,15 +97,15 @@ ROPE_SHORT_FACTOR = [
|
|
97
97
|
]
|
98
98
|
|
99
99
|
|
100
|
-
def
|
100
|
+
def _build_rope_cache(
|
101
101
|
size: int,
|
102
102
|
dim: int,
|
103
|
-
base: int
|
104
|
-
condense_ratio: int
|
105
|
-
dtype: torch.dtype
|
106
|
-
device: torch.device
|
107
|
-
theta_factors: torch.Tensor
|
108
|
-
scale: float
|
103
|
+
base: int,
|
104
|
+
condense_ratio: int,
|
105
|
+
dtype: torch.dtype,
|
106
|
+
device: torch.device,
|
107
|
+
theta_factors: torch.Tensor,
|
108
|
+
scale: float,
|
109
109
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
110
110
|
"""Precomputes Rotary Positional Embeddings for Phi-3.5 model.
|
111
111
|
|
@@ -116,26 +116,20 @@ def build_rope_cache(
|
|
116
116
|
Args:
|
117
117
|
size (int): The size of the built cache.
|
118
118
|
dim (int): Each sequence's dimension.
|
119
|
-
base (int, optional): Rope base value.
|
119
|
+
base (int, optional): Rope base value.
|
120
120
|
condense_ratio (int, optional): The ratio by which sequence indices are
|
121
|
-
condensed.
|
122
|
-
dtype (torch.dtype, optional): Output tensor's data type.
|
123
|
-
|
124
|
-
device (torch.device, optional): Output tensor's data type. Defaults to
|
125
|
-
None in which case "cpu" is used.
|
121
|
+
condensed.
|
122
|
+
dtype (torch.dtype, optional): Output tensor's data type.
|
123
|
+
device (torch.device, optional): Output tensor's device.
|
126
124
|
theta_factors (torch.Tensor, optional): A tensor of shape (dim,) used to
|
127
|
-
scale the theta values.
|
128
|
-
scale (float, optional): A float used to scale the rope values.
|
129
|
-
to 1.0.
|
125
|
+
scale the theta values.
|
126
|
+
scale (float, optional): A float used to scale the rope values.
|
130
127
|
|
131
128
|
Returns:
|
132
129
|
Tuple[torch.Tensor, torch.Tensor]: Rope's Cosine and Sine waves.
|
133
130
|
"""
|
134
|
-
if device is None:
|
135
|
-
device = torch.device('cpu')
|
136
131
|
theta = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
|
137
|
-
|
138
|
-
theta = theta / theta_factors
|
132
|
+
theta = theta / theta_factors
|
139
133
|
seq_idx = torch.arange(size) / condense_ratio
|
140
134
|
idx_theta = torch.outer(seq_idx, theta)
|
141
135
|
cos = torch.cos(idx_theta).to(dtype=dtype, device=device) * scale
|
@@ -167,7 +161,7 @@ class Phi3_5Mini(nn.Module):
|
|
167
161
|
config.final_norm_config,
|
168
162
|
)
|
169
163
|
attn_config = block_config.attn_config
|
170
|
-
self.rope_cache =
|
164
|
+
self.rope_cache = _build_rope_cache(
|
171
165
|
size=config.kv_cache_max,
|
172
166
|
dim=int(attn_config.rotary_percentage * attn_config.head_dim),
|
173
167
|
base=10_000,
|
@@ -19,6 +19,7 @@ import logging
|
|
19
19
|
from absl import app
|
20
20
|
from absl import flags
|
21
21
|
from ai_edge_torch.generative.examples.phi import phi2
|
22
|
+
from ai_edge_torch.generative.utilities import transformers_verifier
|
22
23
|
from ai_edge_torch.generative.utilities import verifier
|
23
24
|
import kagglehub
|
24
25
|
import transformers
|
@@ -39,12 +40,7 @@ _MAX_NEW_TOKENS = flags.DEFINE_integer(
|
|
39
40
|
def main(_):
|
40
41
|
checkpoint = kagglehub.model_download("Microsoft/phi/transformers/2")
|
41
42
|
logging.info("Loading the original model from: %s", checkpoint)
|
42
|
-
|
43
|
-
generation_config.max_new_tokens = _MAX_NEW_TOKENS.value
|
44
|
-
wrapper_model = verifier.ModelWrapper(
|
45
|
-
model=transformers.AutoModelForCausalLM.from_pretrained(checkpoint),
|
46
|
-
hf_generation_config=generation_config,
|
47
|
-
)
|
43
|
+
original_model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint)
|
48
44
|
|
49
45
|
logging.info("Building the reauthored model from: %s", checkpoint)
|
50
46
|
reauthored_model = phi2.build_model(checkpoint)
|
@@ -53,10 +49,13 @@ def main(_):
|
|
53
49
|
tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
|
54
50
|
|
55
51
|
verifier.verify_reauthored_model(
|
56
|
-
original_model=
|
57
|
-
|
58
|
-
|
52
|
+
original_model=transformers_verifier.TransformersModelWrapper(
|
53
|
+
original_model
|
54
|
+
),
|
55
|
+
reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
|
56
|
+
tokenizer=verifier.TokenizerWrapper(tokenizer),
|
59
57
|
generate_prompts=_PROMPTS.value,
|
58
|
+
max_new_tokens=_MAX_NEW_TOKENS.value,
|
60
59
|
atol=1e-03,
|
61
60
|
)
|
62
61
|
|
@@ -21,6 +21,7 @@ import pathlib
|
|
21
21
|
from absl import app
|
22
22
|
from absl import flags
|
23
23
|
from ai_edge_torch.generative.examples.phi import phi3
|
24
|
+
from ai_edge_torch.generative.utilities import transformers_verifier
|
24
25
|
from ai_edge_torch.generative.utilities import verifier
|
25
26
|
import transformers
|
26
27
|
|
@@ -40,12 +41,7 @@ _MAX_NEW_TOKENS = flags.DEFINE_integer(
|
|
40
41
|
def main(_):
|
41
42
|
checkpoint = "microsoft/Phi-3.5-mini-instruct"
|
42
43
|
logging.info("Loading the original model from: %s", checkpoint)
|
43
|
-
|
44
|
-
generation_config.max_new_tokens = _MAX_NEW_TOKENS.value
|
45
|
-
wrapper_model = verifier.ModelWrapper(
|
46
|
-
model=transformers.AutoModelForCausalLM.from_pretrained(checkpoint),
|
47
|
-
hf_generation_config=generation_config,
|
48
|
-
)
|
44
|
+
original_model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint)
|
49
45
|
|
50
46
|
# Locate the cached dir.
|
51
47
|
cached_config_file = transformers.utils.cached_file(
|
@@ -59,10 +55,13 @@ def main(_):
|
|
59
55
|
tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
|
60
56
|
|
61
57
|
verifier.verify_reauthored_model(
|
62
|
-
original_model=
|
63
|
-
|
64
|
-
|
58
|
+
original_model=transformers_verifier.TransformersModelWrapper(
|
59
|
+
original_model
|
60
|
+
),
|
61
|
+
reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
|
62
|
+
tokenizer=verifier.TokenizerWrapper(tokenizer),
|
65
63
|
generate_prompts=_PROMPTS.value,
|
64
|
+
max_new_tokens=_MAX_NEW_TOKENS.value,
|
66
65
|
)
|
67
66
|
|
68
67
|
|
@@ -21,6 +21,7 @@ import pathlib
|
|
21
21
|
from absl import app
|
22
22
|
from absl import flags
|
23
23
|
from ai_edge_torch.generative.examples.smollm import smollm
|
24
|
+
from ai_edge_torch.generative.utilities import transformers_verifier
|
24
25
|
from ai_edge_torch.generative.utilities import verifier
|
25
26
|
import transformers
|
26
27
|
|
@@ -30,14 +31,18 @@ _PROMPTS = flags.DEFINE_multi_string(
|
|
30
31
|
"What is the meaning of life?",
|
31
32
|
"The input prompts to generate answers.",
|
32
33
|
)
|
34
|
+
_MAX_NEW_TOKENS = flags.DEFINE_integer(
|
35
|
+
"max_new_tokens",
|
36
|
+
30,
|
37
|
+
"The maximum size of the generated tokens.",
|
38
|
+
)
|
33
39
|
|
34
40
|
|
35
41
|
def main(_):
|
36
42
|
checkpoint = "HuggingFaceTB/SmolLM-135M"
|
37
43
|
logging.info("Loading the original model from: %s", checkpoint)
|
38
|
-
|
39
|
-
|
40
|
-
)
|
44
|
+
original_model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint)
|
45
|
+
|
41
46
|
# Locate the cached dir.
|
42
47
|
cached_config_file = transformers.utils.cached_file(
|
43
48
|
checkpoint, transformers.utils.CONFIG_NAME
|
@@ -50,10 +55,13 @@ def main(_):
|
|
50
55
|
tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
|
51
56
|
|
52
57
|
verifier.verify_reauthored_model(
|
53
|
-
original_model=
|
54
|
-
|
55
|
-
|
58
|
+
original_model=transformers_verifier.TransformersModelWrapper(
|
59
|
+
original_model
|
60
|
+
),
|
61
|
+
reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
|
62
|
+
tokenizer=verifier.TokenizerWrapper(tokenizer),
|
56
63
|
generate_prompts=_PROMPTS.value,
|
64
|
+
max_new_tokens=_MAX_NEW_TOKENS.value,
|
57
65
|
atol=1e-04,
|
58
66
|
)
|
59
67
|
|
@@ -21,6 +21,7 @@ import pathlib
|
|
21
21
|
from absl import app
|
22
22
|
from absl import flags
|
23
23
|
from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
|
24
|
+
from ai_edge_torch.generative.utilities import transformers_verifier
|
24
25
|
from ai_edge_torch.generative.utilities import verifier
|
25
26
|
import transformers
|
26
27
|
|
@@ -30,16 +31,20 @@ _PROMPTS = flags.DEFINE_multi_string(
|
|
30
31
|
"Show me the program to add 2 and 3.",
|
31
32
|
"The input prompts to generate answers.",
|
32
33
|
)
|
34
|
+
_MAX_NEW_TOKENS = flags.DEFINE_integer(
|
35
|
+
"max_new_tokens",
|
36
|
+
30,
|
37
|
+
"The maximum size of the generated tokens.",
|
38
|
+
)
|
33
39
|
|
34
40
|
|
35
41
|
def main(_):
|
36
42
|
checkpoint = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
|
37
43
|
logging.info("Loading the original model from: %s", checkpoint)
|
38
|
-
|
39
|
-
|
40
|
-
checkpoint, trust_remote_code=True
|
41
|
-
),
|
44
|
+
original_model = transformers.AutoModelForCausalLM.from_pretrained(
|
45
|
+
checkpoint, trust_remote_code=True
|
42
46
|
)
|
47
|
+
|
43
48
|
# Locate the cached dir.
|
44
49
|
cached_config_file = transformers.utils.cached_file(
|
45
50
|
checkpoint, transformers.utils.CONFIG_NAME
|
@@ -52,10 +57,13 @@ def main(_):
|
|
52
57
|
tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
|
53
58
|
|
54
59
|
verifier.verify_reauthored_model(
|
55
|
-
original_model=
|
56
|
-
|
57
|
-
|
60
|
+
original_model=transformers_verifier.TransformersModelWrapper(
|
61
|
+
original_model
|
62
|
+
),
|
63
|
+
reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
|
64
|
+
tokenizer=verifier.TokenizerWrapper(tokenizer),
|
58
65
|
generate_prompts=_PROMPTS.value,
|
66
|
+
max_new_tokens=_MAX_NEW_TOKENS.value,
|
59
67
|
atol=1e-04,
|
60
68
|
)
|
61
69
|
|
@@ -19,6 +19,7 @@ import ai_edge_torch
|
|
19
19
|
from ai_edge_torch import config as ai_edge_config
|
20
20
|
from ai_edge_torch.generative.examples.gemma import gemma1
|
21
21
|
from ai_edge_torch.generative.examples.gemma import gemma2
|
22
|
+
from ai_edge_torch.generative.examples.llama import llama
|
22
23
|
from ai_edge_torch.generative.examples.openelm import openelm
|
23
24
|
from ai_edge_torch.generative.examples.phi import phi2
|
24
25
|
from ai_edge_torch.generative.examples.phi import phi3
|
@@ -102,6 +103,15 @@ class TestModelConversion(googletest.TestCase):
|
|
102
103
|
pytorch_model = gemma2.Gemma2(config).eval()
|
103
104
|
self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5)
|
104
105
|
|
106
|
+
@googletest.skipIf(
|
107
|
+
ai_edge_config.Config.use_torch_xla,
|
108
|
+
reason="tests with custom ops are not supported on oss",
|
109
|
+
)
|
110
|
+
def test_llama(self):
|
111
|
+
config = llama.get_fake_model_config()
|
112
|
+
pytorch_model = llama.Llama(config).eval()
|
113
|
+
self._test_model(config, pytorch_model, "prefill", atol=1e-3, rtol=1e-5)
|
114
|
+
|
105
115
|
@googletest.skipIf(
|
106
116
|
ai_edge_config.Config.use_torch_xla,
|
107
117
|
reason="tests with custom ops are not supported on oss",
|
@@ -0,0 +1,42 @@
|
|
1
|
+
# Copyright 2024 The AI Edge Torch Authors.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
"""Utilities for the models predefined in HuggingFace transformers."""
|
17
|
+
|
18
|
+
from typing import cast
|
19
|
+
|
20
|
+
from ai_edge_torch.generative.utilities import verifier
|
21
|
+
import torch
|
22
|
+
import transformers
|
23
|
+
|
24
|
+
|
25
|
+
class TransformersModelWrapper(verifier.ModelWrapper):
|
26
|
+
"""A wrapper for the model predefined in HuggingFace transformers.
|
27
|
+
|
28
|
+
Verifier expects forward() to return logits while Transformers models return
|
29
|
+
an object with `logits` field.
|
30
|
+
|
31
|
+
Transformers models get `max_new_tokens` settings for generate() via
|
32
|
+
GenerationConfig.
|
33
|
+
"""
|
34
|
+
|
35
|
+
def forward(self, tokens: torch.Tensor) -> torch.Tensor:
|
36
|
+
return self.model.forward(tokens).logits
|
37
|
+
|
38
|
+
def generate(
|
39
|
+
self, inputs: torch.Tensor, max_new_tokens: int
|
40
|
+
) -> torch.IntTensor:
|
41
|
+
gen_config = transformers.GenerationConfig(max_new_tokens=max_new_tokens)
|
42
|
+
return self.model.generate(inputs=inputs, generation_config=gen_config)
|
@@ -16,111 +16,129 @@
|
|
16
16
|
"""Common utility functions to verify the reauthored models."""
|
17
17
|
|
18
18
|
import logging
|
19
|
-
from typing import List
|
19
|
+
from typing import List
|
20
20
|
|
21
21
|
from ai_edge_torch.generative.layers import kv_cache as kv_utils
|
22
22
|
import torch
|
23
|
-
import transformers
|
24
23
|
|
25
24
|
|
26
25
|
class ModelWrapper(torch.nn.Module):
|
27
|
-
"""A wrapper for the model to be verified
|
26
|
+
"""A wrapper for the model to be verified.
|
28
27
|
|
29
|
-
|
28
|
+
It unifies the interface of forward() and generate() of models for the
|
29
|
+
verification to call.
|
30
30
|
"""
|
31
31
|
|
32
|
-
def __init__(
|
33
|
-
self,
|
34
|
-
model: torch.nn.Module,
|
35
|
-
model_format: str = "huggingface",
|
36
|
-
hf_generation_config: Optional[transformers.GenerationConfig] = None,
|
37
|
-
):
|
32
|
+
def __init__(self, model: torch.nn.Module):
|
38
33
|
"""Initializes the wrapper.
|
39
34
|
|
40
35
|
Args:
|
41
|
-
model (torch.nn.Module): The
|
42
|
-
|
43
|
-
|
44
|
-
|
45
|
-
hf_generation_config (transformers.GenerationConfig): The HuggingFace
|
46
|
-
generation config. This config will only be used if the underlying model
|
47
|
-
is built from HuggingFace transformers.
|
36
|
+
model (torch.nn.Module): The model which might have different interfaces
|
37
|
+
of forward() and generate(). It could be a model built from HuggingFace
|
38
|
+
transformers, a regular PyTorch model, or a model re-authored with
|
39
|
+
ai_edge_torch Generative API.
|
48
40
|
"""
|
49
41
|
super().__init__()
|
50
42
|
self.model = model
|
51
|
-
|
52
|
-
|
43
|
+
|
44
|
+
def forward(self, tokens: torch.Tensor) -> torch.Tensor:
|
45
|
+
"""Gets output logits by forwarding the input tokens.
|
46
|
+
|
47
|
+
Args:
|
48
|
+
tokens (torch.Tensor): The input tokens to forward. Its dimension is
|
49
|
+
expected to be (batch_size=1, kv_cache_max_len).
|
50
|
+
|
51
|
+
Returns:
|
52
|
+
The output logits.
|
53
|
+
"""
|
54
|
+
raise NotImplementedError("forward() is not implemented.")
|
53
55
|
|
54
56
|
def generate(
|
55
|
-
self,
|
56
|
-
) ->
|
57
|
-
|
58
|
-
return self.model.generate(
|
59
|
-
inputs=inputs, generation_config=self.hf_generation_config
|
60
|
-
)
|
61
|
-
else:
|
62
|
-
raise NotImplementedError(
|
63
|
-
"generate() is not implemented for model format: %s"
|
64
|
-
% self.model_format
|
65
|
-
)
|
57
|
+
self, prompts: torch.Tensor, max_new_tokens: int
|
58
|
+
) -> torch.IntTensor:
|
59
|
+
"""Returns the response token IDs to the given prompts tensor.
|
66
60
|
|
67
|
-
|
68
|
-
self,
|
69
|
-
inputs: torch.Tensor,
|
70
|
-
):
|
71
|
-
return self.model.forward(inputs)
|
61
|
+
The maximum number of tokens to generate might be set by subclasses.
|
72
62
|
|
63
|
+
Args:
|
64
|
+
prompts (torch.Tensor): The input token IDs to generate with. Its shape is
|
65
|
+
expected to be (batch_size=1, input_ids_len).
|
66
|
+
max_new_tokens (int): The maximum number of response token IDs to
|
67
|
+
generate.
|
68
|
+
|
69
|
+
Returns:
|
70
|
+
The tensor of response token IDs with shape of (batch_size=1,
|
71
|
+
response_ids_len).
|
72
|
+
"""
|
73
|
+
raise NotImplementedError("generate() is not implemented.")
|
73
74
|
|
74
|
-
def forward(
|
75
|
-
model: torch.nn.Module,
|
76
|
-
tokens: torch.Tensor,
|
77
|
-
kv_cache: kv_utils.KVCache,
|
78
|
-
) -> tuple[torch.Tensor, kv_utils.KVCache]:
|
79
|
-
"""Forwards the model reauthored with ai_edge_torch Generative API.
|
80
75
|
|
81
|
-
|
82
|
-
|
83
|
-
with ai_edge_torch Generative API.
|
84
|
-
tokens (torch.Tensor): The input tokens to forward.
|
85
|
-
kv_cache (KVCache): The KV cache to forward.
|
76
|
+
class ReauthoredModelWrapper(ModelWrapper):
|
77
|
+
"""A wrapper for the model reauthored with ai_edge_torch Generative API."""
|
86
78
|
|
87
|
-
|
88
|
-
|
89
|
-
|
90
|
-
input_pos = torch.arange(0, tokens.shape[1], dtype=torch.int)
|
91
|
-
output = model.forward(tokens, input_pos, kv_cache)
|
92
|
-
return output["logits"], output["kv_cache"]
|
79
|
+
def _init_kv_cache(self):
|
80
|
+
"""Returns an initialized KV cache."""
|
81
|
+
return kv_utils.KVCache.from_model_config(self.model.config)
|
93
82
|
|
83
|
+
def _forward_with_kv_cache(
|
84
|
+
self,
|
85
|
+
tokens: torch.Tensor,
|
86
|
+
kv_cache: kv_utils.KVCache,
|
87
|
+
) -> tuple[torch.Tensor, kv_utils.KVCache]:
|
88
|
+
"""Forwards the model and updates an external KV cache.
|
94
89
|
|
95
|
-
|
96
|
-
|
97
|
-
)
|
98
|
-
"""Generates the response to the prompts.
|
90
|
+
Args:
|
91
|
+
tokens (torch.Tensor): The input tokens to forward.
|
92
|
+
kv_cache (KVCache): The KV cache to forward.
|
99
93
|
|
100
|
-
|
101
|
-
|
94
|
+
Returns:
|
95
|
+
The output logits and the updated KV cache.
|
96
|
+
"""
|
97
|
+
input_pos = torch.arange(0, tokens.shape[1], dtype=torch.int)
|
98
|
+
output = self.model.forward(tokens, input_pos, kv_cache)
|
99
|
+
return output["logits"], output["kv_cache"]
|
102
100
|
|
103
|
-
|
104
|
-
|
105
|
-
|
106
|
-
prompts (torch.Tensor): The prompts to generate.
|
107
|
-
response_len (int): The number of tokens to generate.
|
101
|
+
def forward(self, tokens: torch.Tensor) -> torch.Tensor:
|
102
|
+
logits, _ = self._forward_with_kv_cache(tokens, self._init_kv_cache())
|
103
|
+
return logits
|
108
104
|
|
109
|
-
|
110
|
-
|
111
|
-
|
112
|
-
|
113
|
-
|
114
|
-
|
115
|
-
|
116
|
-
|
117
|
-
|
118
|
-
|
105
|
+
def generate(
|
106
|
+
self, prompts: torch.Tensor, max_new_tokens: int
|
107
|
+
) -> torch.IntTensor:
|
108
|
+
input_ids = prompts[0].int().tolist()
|
109
|
+
kv_cache = self._init_kv_cache()
|
110
|
+
for _ in range(max_new_tokens):
|
111
|
+
tokens = torch.tensor([input_ids])
|
112
|
+
logits, kv_cache = self._forward_with_kv_cache(tokens, kv_cache)
|
113
|
+
generated_token = logits[0][-1].argmax().item()
|
114
|
+
input_ids.append(generated_token)
|
115
|
+
return torch.tensor([input_ids])
|
116
|
+
|
117
|
+
|
118
|
+
class TokenizerWrapper(torch.nn.Module):
|
119
|
+
"""A wrapper for the tokenizer used for verification."""
|
120
|
+
|
121
|
+
def __init__(self, tokenizer: torch.nn.Module):
|
122
|
+
"""Initializes the wrapper.
|
123
|
+
|
124
|
+
Args:
|
125
|
+
tokenizer (torch.nn.Module): The tokenizer to wrap.
|
126
|
+
"""
|
127
|
+
super().__init__()
|
128
|
+
self.tokenizer = tokenizer
|
129
|
+
|
130
|
+
def encode(self, prompts: str) -> torch.Tensor:
|
131
|
+
"""Encodes the prompts to token IDs."""
|
132
|
+
return self.tokenizer.encode(prompts, return_tensors="pt")
|
133
|
+
|
134
|
+
def decode(self, token_ids: torch.Tensor) -> str:
|
135
|
+
"""Decodes the token IDs to a string."""
|
136
|
+
return self.tokenizer.decode(token_ids)
|
119
137
|
|
120
138
|
|
121
139
|
def verify_with_input_ids(
|
122
140
|
original_model: ModelWrapper,
|
123
|
-
reauthored_model:
|
141
|
+
reauthored_model: ReauthoredModelWrapper,
|
124
142
|
input_ids: List[int],
|
125
143
|
kv_cache_max_len: int = 1024,
|
126
144
|
rtol: float = 1e-05,
|
@@ -132,8 +150,8 @@ def verify_with_input_ids(
|
|
132
150
|
|
133
151
|
Args:
|
134
152
|
original_model (ModelWrapper): The original model.
|
135
|
-
reauthored_model (
|
136
|
-
Generative API.
|
153
|
+
reauthored_model (ReauthoredModelWrapper): The model reauthored with
|
154
|
+
ai_edge_torch Generative API.
|
137
155
|
input_ids (List[int]): The input token IDs to forward with.
|
138
156
|
kv_cache_max_len (int): The maximum sequence length of the KV cache.
|
139
157
|
rtol (float): The relative tolerance for the comparison.
|
@@ -147,13 +165,12 @@ def verify_with_input_ids(
|
|
147
165
|
|
148
166
|
logging.info("Forwarding the original model...")
|
149
167
|
outputs_original = original_model.forward(tokens)
|
150
|
-
logits_original = outputs_original
|
168
|
+
logits_original = outputs_original[0, len(input_ids) - 1, :]
|
151
169
|
logging.info("logits_original: %s", logits_original)
|
152
170
|
|
153
171
|
logging.info("Forwarding the reauthored model...")
|
154
|
-
|
155
|
-
|
156
|
-
logits_reauthored = outputs_reauthored[0][0, len(input_ids) - 1, :]
|
172
|
+
outputs_reauthored = reauthored_model.forward(tokens)
|
173
|
+
logits_reauthored = outputs_reauthored[0, len(input_ids) - 1, :]
|
157
174
|
logging.info("logits_reauthored: %s", logits_reauthored)
|
158
175
|
|
159
176
|
return torch.allclose(
|
@@ -163,9 +180,10 @@ def verify_with_input_ids(
|
|
163
180
|
|
164
181
|
def verify_model_with_prompts(
|
165
182
|
original_model: ModelWrapper,
|
166
|
-
reauthored_model:
|
167
|
-
tokenizer:
|
183
|
+
reauthored_model: ReauthoredModelWrapper,
|
184
|
+
tokenizer: TokenizerWrapper,
|
168
185
|
prompts: str,
|
186
|
+
max_new_tokens: int,
|
169
187
|
) -> bool:
|
170
188
|
"""Verifies if the reauthored model generates the same answer as the original.
|
171
189
|
|
@@ -174,24 +192,24 @@ def verify_model_with_prompts(
|
|
174
192
|
|
175
193
|
Args:
|
176
194
|
original_model (ModelWrapper): The original model.
|
177
|
-
reauthored_model (
|
178
|
-
Generative API.
|
179
|
-
tokenizer (
|
195
|
+
reauthored_model (ReauthoredModelWrapper): The model reauthored with
|
196
|
+
ai_edge_torch Generative API.
|
197
|
+
tokenizer (TokenizerWrapper): The tokenizer.
|
180
198
|
prompts (str): The input prompts to generate answers.
|
199
|
+
max_new_tokens (int): The maximum number of new tokens to generate.
|
181
200
|
|
182
201
|
Returns:
|
183
202
|
True if the model reauthored generates the same answer of the original.
|
184
203
|
"""
|
185
|
-
prompt_tokens = tokenizer.encode(prompts
|
204
|
+
prompt_tokens = tokenizer.encode(prompts)
|
186
205
|
|
187
206
|
logging.info("Generating answer with the original model...")
|
188
|
-
outputs_original = original_model.generate(prompt_tokens)
|
207
|
+
outputs_original = original_model.generate(prompt_tokens, max_new_tokens)
|
189
208
|
response_original = tokenizer.decode(outputs_original[0])
|
190
209
|
logging.info("outputs_from_original_model: [[%s]]", response_original)
|
191
210
|
|
192
211
|
logging.info("Generating answer with the reauthored model...")
|
193
|
-
|
194
|
-
outputs_reauthored = generate(reauthored_model, prompt_tokens, generate_len)
|
212
|
+
outputs_reauthored = reauthored_model.generate(prompt_tokens, max_new_tokens)
|
195
213
|
response_reauthored = tokenizer.decode(outputs_reauthored[0])
|
196
214
|
logging.info("outputs from reauthored model: [[%s]]", response_reauthored)
|
197
215
|
|
@@ -200,9 +218,10 @@ def verify_model_with_prompts(
|
|
200
218
|
|
201
219
|
def verify_reauthored_model(
|
202
220
|
original_model: ModelWrapper,
|
203
|
-
reauthored_model:
|
204
|
-
tokenizer:
|
221
|
+
reauthored_model: ReauthoredModelWrapper,
|
222
|
+
tokenizer: TokenizerWrapper,
|
205
223
|
generate_prompts: List[str],
|
224
|
+
max_new_tokens: int = 30,
|
206
225
|
forward_input_ids: List[List[int]] = [[1, 2, 3, 4]],
|
207
226
|
rtol: float = 1e-05,
|
208
227
|
atol: float = 1e-05,
|
@@ -219,10 +238,11 @@ def verify_reauthored_model(
|
|
219
238
|
|
220
239
|
Args:
|
221
240
|
original_model (ModelWrapper): The original model.
|
222
|
-
reauthored_model (
|
223
|
-
Generative API.
|
224
|
-
tokenizer (
|
241
|
+
reauthored_model (ReauthoredModelWrapper): The model reauthored with
|
242
|
+
ai_edge_torch Generative API.
|
243
|
+
tokenizer (TokenizerWrapper): The tokenizer.
|
225
244
|
generate_prompts (List[str]): List of the input prompts to generate answers.
|
245
|
+
max_new_tokens (int): The maximum number of new tokens to generate.
|
226
246
|
forward_input_ids (List[List[int]]): List of the input token IDs to
|
227
247
|
forward with.
|
228
248
|
rtol (float): The relative tolerance for the comparison.
|
@@ -235,13 +255,13 @@ def verify_reauthored_model(
|
|
235
255
|
):
|
236
256
|
logging.info("PASS")
|
237
257
|
else:
|
238
|
-
logging.
|
258
|
+
logging.error("FAILED")
|
239
259
|
|
240
260
|
for prompts in generate_prompts:
|
241
261
|
logging.info("Verifying the reauthored model with prompts:%s", prompts)
|
242
262
|
if verify_model_with_prompts(
|
243
|
-
original_model, reauthored_model, tokenizer, prompts
|
263
|
+
original_model, reauthored_model, tokenizer, prompts, max_new_tokens
|
244
264
|
):
|
245
265
|
logging.info("PASS")
|
246
266
|
else:
|
247
|
-
logging.
|
267
|
+
logging.error("FAILED")
|
ai_edge_torch/version.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: ai-edge-torch-nightly
|
3
|
-
Version: 0.3.0.
|
3
|
+
Version: 0.3.0.dev20240927
|
4
4
|
Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
|
5
5
|
Home-page: https://github.com/google-ai-edge/ai-edge-torch
|
6
6
|
Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
|