ai-edge-torch-nightly 0.3.0.dev20240924__py3-none-any.whl → 0.3.0.dev20240928__py3-none-any.whl
- ai_edge_torch/generative/examples/gemma/gemma1.py +2 -6
- ai_edge_torch/generative/examples/gemma/gemma2.py +2 -10
- ai_edge_torch/generative/examples/gemma/verify_gemma1.py +3 -2
- ai_edge_torch/generative/examples/gemma/verify_gemma2.py +3 -2
- ai_edge_torch/generative/examples/gemma/verify_util.py +15 -25
- ai_edge_torch/generative/examples/llama/__init__.py +14 -0
- ai_edge_torch/generative/examples/llama/convert_3b_to_tflite.py +68 -0
- ai_edge_torch/generative/examples/llama/convert_to_tflite.py +68 -0
- ai_edge_torch/generative/examples/llama/llama.py +204 -0
- ai_edge_torch/generative/examples/llama/verify.py +73 -0
- ai_edge_torch/generative/examples/llama/verify_3b.py +73 -0
- ai_edge_torch/generative/examples/openelm/openelm.py +2 -6
- ai_edge_torch/generative/examples/openelm/verify.py +19 -11
- ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py +68 -0
- ai_edge_torch/generative/examples/phi/phi2.py +2 -6
- ai_edge_torch/generative/examples/phi/phi3.py +279 -0
- ai_edge_torch/generative/examples/phi/verify.py +13 -13
- ai_edge_torch/generative/examples/phi/verify_phi3.py +69 -0
- ai_edge_torch/generative/examples/smollm/smollm.py +1 -0
- ai_edge_torch/generative/examples/smollm/verify.py +19 -9
- ai_edge_torch/generative/examples/stable_diffusion/clip.py +54 -1
- ai_edge_torch/generative/examples/stable_diffusion/decoder.py +58 -0
- ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +71 -1
- ai_edge_torch/generative/examples/t5/t5.py +0 -2
- ai_edge_torch/generative/examples/test_models/convert_toy_model.py +105 -0
- ai_edge_torch/generative/examples/test_models/toy_model.py +7 -41
- ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +5 -61
- ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +2 -6
- ai_edge_torch/generative/examples/tiny_llama/verify.py +20 -10
- ai_edge_torch/generative/layers/model_config.py +2 -0
- ai_edge_torch/generative/layers/normalization.py +2 -2
- ai_edge_torch/generative/layers/unet/blocks_2d.py +2 -2
- ai_edge_torch/generative/test/test_model_conversion_large.py +129 -0
- ai_edge_torch/generative/utilities/transformers_verifier.py +42 -0
- ai_edge_torch/generative/utilities/verifier.py +130 -114
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20240924.dist-info → ai_edge_torch_nightly-0.3.0.dev20240928.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20240924.dist-info → ai_edge_torch_nightly-0.3.0.dev20240928.dist-info}/RECORD +41 -30
- {ai_edge_torch_nightly-0.3.0.dev20240924.dist-info → ai_edge_torch_nightly-0.3.0.dev20240928.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240924.dist-info → ai_edge_torch_nightly-0.3.0.dev20240928.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240924.dist-info → ai_edge_torch_nightly-0.3.0.dev20240928.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/layers/model_config.py CHANGED
@@ -83,6 +83,8 @@ class AttentionConfig:
   # Used to determine number of groups in grouped query attention (GQA)
   # https://arxiv.org/pdf/2305.13245.pdf
   num_query_groups: Optional[int]
+  # Base of rotary positional embedding.
+  rotary_base: int = 10_000
   # Percentage of Rotary Positional Embedding added Q and K projections.
   rotary_percentage: Optional[float] = None
   # Whether to transpose the query groups of qkv bundled tensor before
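The new rotary_base field sets the frequency base of rotary positional embeddings (RoPE). As a hedged illustration only (the helper name and shapes below are assumptions, not code from this diff), a base like this is conventionally turned into per-channel-pair inverse frequencies:

# Illustrative sketch: how a RoPE base such as `rotary_base` typically
# becomes rotation frequencies. Not the ai_edge_torch implementation.
import torch

def rope_inv_freq(head_dim: int, rotary_base: int = 10_000) -> torch.Tensor:
  # One inverse frequency per channel pair; a larger base gives longer
  # wavelengths, which helps encode distant positions.
  exponent = torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim
  return 1.0 / (rotary_base ** exponent)

# Example: a 64-dim attention head yields 32 frequencies.
print(rope_inv_freq(64).shape)  # torch.Size([32])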
ai_edge_torch/generative/layers/normalization.py CHANGED
@@ -189,7 +189,7 @@ def group_norm_with_hlfb(
       name="odml.group_norm",
       attr={
           "num_groups": num_groups,
-          "
+          "epsilon": eps,
           "reduction_axes": 3,
           "channel_axis": 3,
       },
@@ -226,7 +226,7 @@ def layer_norm_with_hlfb(
   """
   builder = StableHLOCompositeBuilder(
       name="odml.group_norm",
-      attr={"num_groups": 1, "
+      attr={"num_groups": 1, "epsilon": eps, "channel_axis": 1},
   )
   x, w, b = builder.mark_inputs(x, w, b)
   if use_input_shape:
ai_edge_torch/generative/layers/unet/blocks_2d.py CHANGED
@@ -13,7 +13,7 @@
 # limitations under the License.
 # ==============================================================================
 
-from typing import List, Optional, Tuple
+from typing import List, Optional, Tuple, Union
 
 from ai_edge_torch.generative.layers.attention import CrossAttention
 from ai_edge_torch.generative.layers.attention import SelfAttention
@@ -416,7 +416,7 @@ class DownEncoderBlock2D(nn.Module):
       time_emb: Optional[torch.Tensor] = None,
       context_tensor: Optional[torch.Tensor] = None,
       output_hidden_states: bool = False,
-  ) -> torch.Tensor
+  ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
     """Forward function of the DownEncoderBlock2D.
 
     Args:
ai_edge_torch/generative/test/test_model_conversion_large.py CHANGED
@@ -19,9 +19,14 @@ import ai_edge_torch
 from ai_edge_torch import config as ai_edge_config
 from ai_edge_torch.generative.examples.gemma import gemma1
 from ai_edge_torch.generative.examples.gemma import gemma2
+from ai_edge_torch.generative.examples.llama import llama
 from ai_edge_torch.generative.examples.openelm import openelm
 from ai_edge_torch.generative.examples.phi import phi2
+from ai_edge_torch.generative.examples.phi import phi3
 from ai_edge_torch.generative.examples.smollm import smollm
+from ai_edge_torch.generative.examples.stable_diffusion import clip as sd_clip
+from ai_edge_torch.generative.examples.stable_diffusion import decoder as sd_decoder
+from ai_edge_torch.generative.examples.stable_diffusion import diffusion as sd_diffusion
 from ai_edge_torch.generative.layers import kv_cache
 from ai_edge_torch.generative.test import utils as test_utils
 import numpy as np
@@ -98,6 +103,15 @@ class TestModelConversion(googletest.TestCase):
     pytorch_model = gemma2.Gemma2(config).eval()
     self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5)
 
+  @googletest.skipIf(
+      ai_edge_config.Config.use_torch_xla,
+      reason="tests with custom ops are not supported on oss",
+  )
+  def test_llama(self):
+    config = llama.get_fake_model_config()
+    pytorch_model = llama.Llama(config).eval()
+    self._test_model(config, pytorch_model, "prefill", atol=1e-3, rtol=1e-5)
+
   @googletest.skipIf(
       ai_edge_config.Config.use_torch_xla,
       reason="tests with custom ops are not supported on oss",
@@ -109,6 +123,17 @@ class TestModelConversion(googletest.TestCase):
         config, pytorch_model, "serving_default", atol=1e-3, rtol=1e-3
     )
 
+  @googletest.skipIf(
+      ai_edge_config.Config.use_torch_xla,
+      reason="tests with custom ops are not supported on oss",
+  )
+  def test_phi3(self):
+    config = phi3.get_fake_model_config()
+    pytorch_model = phi3.Phi3_5Mini(config).eval()
+    self._test_model(
+        config, pytorch_model, "prefill", atol=1e-5, rtol=1e-5
+    )
+
   @googletest.skipIf(
       ai_edge_config.Config.use_torch_xla,
       reason="tests with custom ops are not supported on oss",
@@ -127,6 +152,110 @@ class TestModelConversion(googletest.TestCase):
     pytorch_model = openelm.OpenELM(config).eval()
     self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5)
 
+  @googletest.skipIf(
+      ai_edge_config.Config.use_torch_xla,
+      reason="tests with custom ops are not supported on oss",
+  )
+  def test_stable_diffusion_clip(self):
+    config = sd_clip.get_fake_model_config()
+    prompt_tokens = torch.from_numpy(
+        np.array([[1, 2, 3, 4, 5, 6]], dtype=np.int32)
+    )
+
+    pytorch_model = sd_clip.CLIP(config).eval()
+    torch_output = pytorch_model(prompt_tokens)
+
+    edge_model = ai_edge_torch.signature(
+        "encode", pytorch_model, (prompt_tokens,)
+    ).convert()
+    edge_model.set_interpreter_builder(
+        self._interpreter_builder(edge_model.tflite_model())
+    )
+    edge_output = edge_model(
+        prompt_tokens.numpy(),
+        signature_name="encode",
+    )
+    self.assertTrue(
+        np.allclose(
+            edge_output,
+            torch_output.detach().numpy(),
+            atol=1e-4,
+            rtol=1e-5,
+        )
+    )
+
+  @googletest.skipIf(
+      ai_edge_config.Config.use_torch_xla,
+      reason="tests with custom ops are not supported on oss",
+  )
+  def test_stable_diffusion_diffusion(self):
+    config = sd_diffusion.get_fake_model_config(2)
+    latents = torch.from_numpy(
+        np.random.normal(size=(2, 4, 8, 8)).astype(np.float32)
+    )
+    context = torch.from_numpy(
+        np.random.normal(size=(2, 4, 4)).astype(np.float32)
+    )
+    time_embedding = torch.from_numpy(
+        np.random.normal(size=(2, 2)).astype(np.float32)
+    )
+
+    pytorch_model = sd_diffusion.Diffusion(config).eval()
+    torch_output = pytorch_model(latents, context, time_embedding)
+
+    edge_model = ai_edge_torch.signature(
+        "diffusion", pytorch_model, (latents, context, time_embedding)
+    ).convert()
+    edge_model.set_interpreter_builder(
+        self._interpreter_builder(edge_model.tflite_model())
+    )
+    edge_output = edge_model(
+        latents.numpy(),
+        context.numpy(),
+        time_embedding.numpy(),
+        signature_name="diffusion",
+    )
+    self.assertTrue(
+        np.allclose(
+            edge_output,
+            torch_output.detach().numpy(),
+            atol=1e-4,
+            rtol=1e-5,
+        )
+    )
+
+  @googletest.skipIf(
+      ai_edge_config.Config.use_torch_xla,
+      reason="tests with custom ops are not supported on oss",
+  )
+  def test_stable_diffusion_decoder(self):
+    config = sd_decoder.get_fake_model_config()
+    latents = torch.from_numpy(
+        np.random.normal(size=(1, 4, 64, 64)).astype(np.float32)
+    )
+
+    pytorch_model = sd_decoder.Decoder(config).eval()
+    torch_output = pytorch_model(latents)
+
+    edge_model = ai_edge_torch.signature(
+        "decode", pytorch_model, (latents,)
+    ).convert()
+    edge_model.set_interpreter_builder(
+        self._interpreter_builder(edge_model.tflite_model())
+    )
+    edge_output = edge_model(
+        latents.numpy(),
+        signature_name="decode",
+    )
+    self.assertTrue(
+        np.allclose(
+            edge_output,
+            torch_output.detach().numpy(),
+            atol=1e-4,
+            rtol=1e-5,
+        )
+    )
+
 
 if __name__ == "__main__":
   googletest.main()
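These new tests exercise the ai_edge_torch.signature(...).convert() flow. As a hedged sketch of the same pattern outside the test harness (the module, sample input shape, and output path below are placeholders, not part of this change), one can convert an arbitrary nn.Module under a named signature and then run or export it:

# Sketch only: the module, sample input, and output path are assumptions.
import ai_edge_torch
import torch

class TinyModule(torch.nn.Module):
  def forward(self, x: torch.Tensor) -> torch.Tensor:
    return torch.nn.functional.relu(x)

sample_args = (torch.randn(1, 4, 8, 8),)
edge_model = ai_edge_torch.signature(
    "encode", TinyModule().eval(), sample_args
).convert()
# Invoke the converted model through the bundled TFLite interpreter.
edge_output = edge_model(sample_args[0].numpy(), signature_name="encode")
# Serialize the flatbuffer for deployment (path is illustrative).
edge_model.export("/tmp/tiny_module.tflite")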
ai_edge_torch/generative/utilities/transformers_verifier.py ADDED
@@ -0,0 +1,42 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Utilities for the models predefined in HuggingFace transformers."""
+
+from typing import cast
+
+from ai_edge_torch.generative.utilities import verifier
+import torch
+import transformers
+
+
+class TransformersModelWrapper(verifier.ModelWrapper):
+  """A wrapper for the model predefined in HuggingFace transformers.
+
+  Verifier expects forward() to return logits while Transformers models return
+  an object with `logits` field.
+
+  Transformers models get `max_new_tokens` settings for generate() via
+  GenerationConfig.
+  """
+
+  def forward(self, tokens: torch.Tensor) -> torch.Tensor:
+    return self.model.forward(tokens).logits
+
+  def generate(
+      self, inputs: torch.Tensor, max_new_tokens: int
+  ) -> torch.IntTensor:
+    gen_config = transformers.GenerationConfig(max_new_tokens=max_new_tokens)
+    return self.model.generate(inputs=inputs, generation_config=gen_config)
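A hedged usage sketch for the new wrapper (the checkpoint name and input tokens are placeholders, not prescribed by this change): wrapping a Hugging Face causal LM gives the verifier a plain-logits forward() and a max_new_tokens-aware generate().

# Sketch only: the checkpoint name is a placeholder.
import torch
import transformers
from ai_edge_torch.generative.utilities import transformers_verifier

hf_model = transformers.AutoModelForCausalLM.from_pretrained(
    "some-org/some-causal-lm"  # placeholder checkpoint
)
wrapper = transformers_verifier.TransformersModelWrapper(hf_model)

tokens = torch.tensor([[1, 2, 3, 4]])
logits = wrapper.forward(tokens)                   # raw logits tensor
ids = wrapper.generate(tokens, max_new_tokens=8)   # bounded via GenerationConfig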
ai_edge_torch/generative/utilities/verifier.py CHANGED
@@ -15,116 +15,130 @@
 
 """Common utility functions to verify the reauthored models."""
 
-import
-from typing import List
+import logging
+from typing import List
 
 from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import torch
-import transformers
-
-
-def log_msg(*args):
-  print("[%s]" % datetime.datetime.now(), *args)
 
 
 class ModelWrapper(torch.nn.Module):
-  """A wrapper for the model to be verified
+  """A wrapper for the model to be verified.
 
-
+  It unifies the interface of forward() and generate() of models for the
+  verification to call.
   """
 
-  def __init__(
-      self,
-      model: torch.nn.Module,
-      model_format: str = "huggingface",
-      hf_generation_config: Optional[transformers.GenerationConfig] = None,
-  ):
+  def __init__(self, model: torch.nn.Module):
     """Initializes the wrapper.
 
     Args:
-      model (torch.nn.Module): The
-
-
-
-      hf_generation_config (transformers.GenerationConfig): The HuggingFace
-        generation config. This config will only be used if the underlying model
-        is built from HuggingFace transformers.
+      model (torch.nn.Module): The model which might have different interfaces
+        of forward() and generate(). It could be a model built from HuggingFace
+        transformers, a regular PyTorch model, or a model re-authored with
+        ai_edge_torch Generative API.
     """
     super().__init__()
     self.model = model
-
-
+
+  def forward(self, tokens: torch.Tensor) -> torch.Tensor:
+    """Gets output logits by forwarding the input tokens.
+
+    Args:
+      tokens (torch.Tensor): The input tokens to forward. Its dimension is
+        expected to be (batch_size=1, kv_cache_max_len).
+
+    Returns:
+      The output logits.
+    """
+    raise NotImplementedError("forward() is not implemented.")
 
   def generate(
-      self,
-  ) ->
-
-      return self.model.generate(
-          inputs=inputs, generation_config=self.hf_generation_config
-      )
-    else:
-      raise NotImplementedError(
-          "generate() is not implemented for model format: %s"
-          % self.model_format
-      )
+      self, prompts: torch.Tensor, max_new_tokens: int
+  ) -> torch.IntTensor:
+    """Returns the response token IDs to the given prompts tensor.
 
-
-      self,
-      inputs: torch.Tensor,
-  ):
-    return self.model.forward(inputs)
+    The maximum number of tokens to generate might be set by subclasses.
 
+    Args:
+      prompts (torch.Tensor): The input token IDs to generate with. Its shape is
+        expected to be (batch_size=1, input_ids_len).
+      max_new_tokens (int): The maximum number of response token IDs to
+        generate.
+
+    Returns:
+      The tensor of response token IDs with shape of (batch_size=1,
+      response_ids_len).
+    """
+    raise NotImplementedError("generate() is not implemented.")
 
-  def forward(
-      model: torch.nn.Module,
-      tokens: torch.Tensor,
-      kv_cache: kv_utils.KVCache,
-  ) -> tuple[torch.Tensor, kv_utils.KVCache]:
-    """Forwards the model reauthored with ai_edge_torch Generative API.
 
-
-
-      with ai_edge_torch Generative API.
-      tokens (torch.Tensor): The input tokens to forward.
-      kv_cache (KVCache): The KV cache to forward.
+class ReauthoredModelWrapper(ModelWrapper):
+  """A wrapper for the model reauthored with ai_edge_torch Generative API."""
 
-
-
-
-      input_pos = torch.arange(0, tokens.shape[1], dtype=torch.int)
-      output = model.forward(tokens, input_pos, kv_cache)
-      return output["logits"], output["kv_cache"]
+  def _init_kv_cache(self):
+    """Returns an initialized KV cache."""
+    return kv_utils.KVCache.from_model_config(self.model.config)
 
+  def _forward_with_kv_cache(
+      self,
+      tokens: torch.Tensor,
+      kv_cache: kv_utils.KVCache,
+  ) -> tuple[torch.Tensor, kv_utils.KVCache]:
+    """Forwards the model and updates an external KV cache.
 
-
-
-  )
-  """Generates the response to the prompts.
+    Args:
+      tokens (torch.Tensor): The input tokens to forward.
+      kv_cache (KVCache): The KV cache to forward.
 
-
-
+    Returns:
+      The output logits and the updated KV cache.
+    """
+    input_pos = torch.arange(0, tokens.shape[1], dtype=torch.int)
+    output = self.model.forward(tokens, input_pos, kv_cache)
+    return output["logits"], output["kv_cache"]
 
-
-
-
-    prompts (torch.Tensor): The prompts to generate.
-    response_len (int): The number of tokens to generate.
+  def forward(self, tokens: torch.Tensor) -> torch.Tensor:
+    logits, _ = self._forward_with_kv_cache(tokens, self._init_kv_cache())
+    return logits
 
-
-
-
-
-
-
-
-
-
-
+  def generate(
+      self, prompts: torch.Tensor, max_new_tokens: int
+  ) -> torch.IntTensor:
+    input_ids = prompts[0].int().tolist()
+    kv_cache = self._init_kv_cache()
+    for _ in range(max_new_tokens):
+      tokens = torch.tensor([input_ids])
+      logits, kv_cache = self._forward_with_kv_cache(tokens, kv_cache)
+      generated_token = logits[0][-1].argmax().item()
+      input_ids.append(generated_token)
+    return torch.tensor([input_ids])
+
+
+class TokenizerWrapper(torch.nn.Module):
+  """A wrapper for the tokenizer used for verification."""
+
+  def __init__(self, tokenizer: torch.nn.Module):
+    """Initializes the wrapper.
+
+    Args:
+      tokenizer (torch.nn.Module): The tokenizer to wrap.
+    """
+    super().__init__()
+    self.tokenizer = tokenizer
+
+  def encode(self, prompts: str) -> torch.Tensor:
+    """Encodes the prompts to token IDs."""
+    return self.tokenizer.encode(prompts, return_tensors="pt")
+
+  def decode(self, token_ids: torch.Tensor) -> str:
+    """Decodes the token IDs to a string."""
+    return self.tokenizer.decode(token_ids)
 
 
 def verify_with_input_ids(
     original_model: ModelWrapper,
-    reauthored_model:
+    reauthored_model: ReauthoredModelWrapper,
     input_ids: List[int],
     kv_cache_max_len: int = 1024,
     rtol: float = 1e-05,
@@ -136,8 +150,8 @@ def verify_with_input_ids(
 
   Args:
     original_model (ModelWrapper): The original model.
-    reauthored_model (
-      Generative API.
+    reauthored_model (ReauthoredModelWrapper): The model reauthored with
+      ai_edge_torch Generative API.
     input_ids (List[int]): The input token IDs to forward with.
     kv_cache_max_len (int): The maximum sequence length of the KV cache.
     rtol (float): The relative tolerance for the comparison.
@@ -149,16 +163,15 @@ def verify_with_input_ids(
   tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.int, device="cpu")
   tokens[0, : len(input_ids)] = torch.tensor([input_ids]).int()
 
-
+  logging.info("Forwarding the original model...")
   outputs_original = original_model.forward(tokens)
-  logits_original = outputs_original
-
+  logits_original = outputs_original[0, len(input_ids) - 1, :]
+  logging.info("logits_original: %s", logits_original)
 
-
-
-
-  logits_reauthored
-  log_msg("logits_reauthored:", logits_reauthored)
+  logging.info("Forwarding the reauthored model...")
+  outputs_reauthored = reauthored_model.forward(tokens)
+  logits_reauthored = outputs_reauthored[0, len(input_ids) - 1, :]
+  logging.info("logits_reauthored: %s", logits_reauthored)
 
   return torch.allclose(
       logits_original, logits_reauthored, rtol=rtol, atol=atol
@@ -167,9 +180,10 @@ def verify_with_input_ids(
 
 def verify_model_with_prompts(
     original_model: ModelWrapper,
-    reauthored_model:
-    tokenizer:
+    reauthored_model: ReauthoredModelWrapper,
+    tokenizer: TokenizerWrapper,
     prompts: str,
+    max_new_tokens: int,
 ) -> bool:
   """Verifies if the model reauthored generates the same answer of the oringal.
 
@@ -178,35 +192,36 @@ def verify_model_with_prompts(
 
   Args:
     original_model (ModelWrapper): The original model.
-    reauthored_model (
-      Generative API.
-    tokenizer (
+    reauthored_model (ReauthoredModelWrapper): The model reauthored with
+      ai_edge_torch Generative API.
+    tokenizer (TokenizerWrapper): The tokenizer.
     prompts (str): The input prompts to generate answers.
+    max_new_tokens (int): The maximum number of new tokens to generate.
 
   Returns:
     True if the model reauthored generates the same answer of the original.
   """
-  prompt_tokens = tokenizer.encode(prompts
+  prompt_tokens = tokenizer.encode(prompts)
 
-
-  outputs_original = original_model.generate(prompt_tokens)
+  logging.info("Generating answer with the original model...")
+  outputs_original = original_model.generate(prompt_tokens, max_new_tokens)
   response_original = tokenizer.decode(outputs_original[0])
-
+  logging.info("outputs_from_original_model: [[%s]]", response_original)
 
-
-
-  outputs_reauthored = generate(reauthored_model, prompt_tokens, generate_len)
+  logging.info("Generating answer with the reauthored model...")
+  outputs_reauthored = reauthored_model.generate(prompt_tokens, max_new_tokens)
   response_reauthored = tokenizer.decode(outputs_reauthored[0])
-
+  logging.info("outputs from reauthored model: [[%s]]", response_reauthored)
 
   return response_original == response_reauthored
 
 
 def verify_reauthored_model(
     original_model: ModelWrapper,
-    reauthored_model:
-    tokenizer:
+    reauthored_model: ReauthoredModelWrapper,
+    tokenizer: TokenizerWrapper,
     generate_prompts: List[str],
+    max_new_tokens: int = 30,
     forward_input_ids: List[List[int]] = [[1, 2, 3, 4]],
     rtol: float = 1e-05,
     atol: float = 1e-05,
@@ -223,29 +238,30 @@ def verify_reauthored_model(
 
   Args:
     original_model (ModelWrapper): The original model.
-    reauthored_model (
-      Generative API.
-    tokenizer (
+    reauthored_model (ReauthoredModelWrapper): The model reauthored with
+      ai_edge_torch Generative API.
+    tokenizer (TokenizerWrapper): The tokenizer.
     generate_prompts (List[str]): List of the input prompts to generate answers.
+    max_new_tokens (int): The maximum number of new tokens to generate.
     forward_input_ids (List[torch.Tensor]): List if ihe input token IDs to
       forward with.
     rtol (float): The relative tolerance for the comparison.
     atol (float): The absolute tolerance for the comparison.
   """
   for input_ids in forward_input_ids:
-
+    logging.info("Verifying the reauthored model with input IDs: %s", input_ids)
     if verify_with_input_ids(
         original_model, reauthored_model, input_ids, rtol=rtol, atol=atol
     ):
-
+      logging.info("PASS")
     else:
-
+      logging.error("FAILED")
 
   for prompts in generate_prompts:
-
+    logging.info("Verifying the reauthored model with prompts:%s", prompts)
     if verify_model_with_prompts(
-        original_model, reauthored_model, tokenizer, prompts
+        original_model, reauthored_model, tokenizer, prompts, max_new_tokens
     ):
-
+      logging.info("PASS")
     else:
-
+      logging.error("FAILED")
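Taken together, the reworked wrapper classes keep a verification script short. A hedged end-to-end sketch (the checkpoint, the reauthored-model builder, and the prompt below are placeholders, not part of this diff):

# Sketch only: build_reauthored_model() and the checkpoint are placeholders.
import transformers
from ai_edge_torch.generative.utilities import transformers_verifier
from ai_edge_torch.generative.utilities import verifier

checkpoint = "some-org/some-causal-lm"  # placeholder checkpoint
original = transformers_verifier.TransformersModelWrapper(
    transformers.AutoModelForCausalLM.from_pretrained(checkpoint)
)
reauthored = verifier.ReauthoredModelWrapper(
    build_reauthored_model()  # placeholder: an ai_edge_torch Generative API model
)
tokenizer = verifier.TokenizerWrapper(
    transformers.AutoTokenizer.from_pretrained(checkpoint)
)

verifier.verify_reauthored_model(
    original_model=original,
    reauthored_model=reauthored,
    tokenizer=tokenizer,
    generate_prompts=["What is the meaning of life?"],
    max_new_tokens=30,
)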
{ai_edge_torch_nightly-0.3.0.dev20240924.dist-info → ai_edge_torch_nightly-0.3.0.dev20240928.dist-info}/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: ai-edge-torch-nightly
-Version: 0.3.0.dev20240924
+Version: 0.3.0.dev20240928
 Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
 Home-page: https://github.com/google-ai-edge/ai-edge-torch
 Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI