ai-edge-torch-nightly 0.3.0.dev20240924__py3-none-any.whl → 0.3.0.dev20240928__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_torch/generative/examples/gemma/gemma1.py +2 -6
- ai_edge_torch/generative/examples/gemma/gemma2.py +2 -10
- ai_edge_torch/generative/examples/gemma/verify_gemma1.py +3 -2
- ai_edge_torch/generative/examples/gemma/verify_gemma2.py +3 -2
- ai_edge_torch/generative/examples/gemma/verify_util.py +15 -25
- ai_edge_torch/generative/examples/llama/__init__.py +14 -0
- ai_edge_torch/generative/examples/llama/convert_3b_to_tflite.py +68 -0
- ai_edge_torch/generative/examples/llama/convert_to_tflite.py +68 -0
- ai_edge_torch/generative/examples/llama/llama.py +204 -0
- ai_edge_torch/generative/examples/llama/verify.py +73 -0
- ai_edge_torch/generative/examples/llama/verify_3b.py +73 -0
- ai_edge_torch/generative/examples/openelm/openelm.py +2 -6
- ai_edge_torch/generative/examples/openelm/verify.py +19 -11
- ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py +68 -0
- ai_edge_torch/generative/examples/phi/phi2.py +2 -6
- ai_edge_torch/generative/examples/phi/phi3.py +279 -0
- ai_edge_torch/generative/examples/phi/verify.py +13 -13
- ai_edge_torch/generative/examples/phi/verify_phi3.py +69 -0
- ai_edge_torch/generative/examples/smollm/smollm.py +1 -0
- ai_edge_torch/generative/examples/smollm/verify.py +19 -9
- ai_edge_torch/generative/examples/stable_diffusion/clip.py +54 -1
- ai_edge_torch/generative/examples/stable_diffusion/decoder.py +58 -0
- ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +71 -1
- ai_edge_torch/generative/examples/t5/t5.py +0 -2
- ai_edge_torch/generative/examples/test_models/convert_toy_model.py +105 -0
- ai_edge_torch/generative/examples/test_models/toy_model.py +7 -41
- ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +5 -61
- ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +2 -6
- ai_edge_torch/generative/examples/tiny_llama/verify.py +20 -10
- ai_edge_torch/generative/layers/model_config.py +2 -0
- ai_edge_torch/generative/layers/normalization.py +2 -2
- ai_edge_torch/generative/layers/unet/blocks_2d.py +2 -2
- ai_edge_torch/generative/test/test_model_conversion_large.py +129 -0
- ai_edge_torch/generative/utilities/transformers_verifier.py +42 -0
- ai_edge_torch/generative/utilities/verifier.py +130 -114
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20240924.dist-info → ai_edge_torch_nightly-0.3.0.dev20240928.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20240924.dist-info → ai_edge_torch_nightly-0.3.0.dev20240928.dist-info}/RECORD +41 -30
- {ai_edge_torch_nightly-0.3.0.dev20240924.dist-info → ai_edge_torch_nightly-0.3.0.dev20240928.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240924.dist-info → ai_edge_torch_nightly-0.3.0.dev20240928.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240924.dist-info → ai_edge_torch_nightly-0.3.0.dev20240928.dist-info}/top_level.txt +0 -0
@@ -83,6 +83,8 @@ class AttentionConfig:
|
|
83
83
|
# Used to determine number of groups in grouped query attention (GQA)
|
84
84
|
# https://arxiv.org/pdf/2305.13245.pdf
|
85
85
|
num_query_groups: Optional[int]
|
86
|
+
# Base of rotary positional embedding.
|
87
|
+
rotary_base: int = 10_000
|
86
88
|
# Percentage of Rotary Positional Embedding added Q and K projections.
|
87
89
|
rotary_percentage: Optional[float] = None
|
88
90
|
# Whether to transpose the query groups of qkv bundled tensor before
|
@@ -189,7 +189,7 @@ def group_norm_with_hlfb(
|
|
189
189
|
name="odml.group_norm",
|
190
190
|
attr={
|
191
191
|
"num_groups": num_groups,
|
192
|
-
"
|
192
|
+
"epsilon": eps,
|
193
193
|
"reduction_axes": 3,
|
194
194
|
"channel_axis": 3,
|
195
195
|
},
|
@@ -226,7 +226,7 @@ def layer_norm_with_hlfb(
|
|
226
226
|
"""
|
227
227
|
builder = StableHLOCompositeBuilder(
|
228
228
|
name="odml.group_norm",
|
229
|
-
attr={"num_groups": 1, "
|
229
|
+
attr={"num_groups": 1, "epsilon": eps, "channel_axis": 1},
|
230
230
|
)
|
231
231
|
x, w, b = builder.mark_inputs(x, w, b)
|
232
232
|
if use_input_shape:
|
@@ -13,7 +13,7 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
15
|
|
16
|
-
from typing import List, Optional, Tuple
|
16
|
+
from typing import List, Optional, Tuple, Union
|
17
17
|
|
18
18
|
from ai_edge_torch.generative.layers.attention import CrossAttention
|
19
19
|
from ai_edge_torch.generative.layers.attention import SelfAttention
|
@@ -416,7 +416,7 @@ class DownEncoderBlock2D(nn.Module):
|
|
416
416
|
time_emb: Optional[torch.Tensor] = None,
|
417
417
|
context_tensor: Optional[torch.Tensor] = None,
|
418
418
|
output_hidden_states: bool = False,
|
419
|
-
) -> torch.Tensor
|
419
|
+
) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
|
420
420
|
"""Forward function of the DownEncoderBlock2D.
|
421
421
|
|
422
422
|
Args:
|
@@ -19,9 +19,14 @@ import ai_edge_torch
|
|
19
19
|
from ai_edge_torch import config as ai_edge_config
|
20
20
|
from ai_edge_torch.generative.examples.gemma import gemma1
|
21
21
|
from ai_edge_torch.generative.examples.gemma import gemma2
|
22
|
+
from ai_edge_torch.generative.examples.llama import llama
|
22
23
|
from ai_edge_torch.generative.examples.openelm import openelm
|
23
24
|
from ai_edge_torch.generative.examples.phi import phi2
|
25
|
+
from ai_edge_torch.generative.examples.phi import phi3
|
24
26
|
from ai_edge_torch.generative.examples.smollm import smollm
|
27
|
+
from ai_edge_torch.generative.examples.stable_diffusion import clip as sd_clip
|
28
|
+
from ai_edge_torch.generative.examples.stable_diffusion import decoder as sd_decoder
|
29
|
+
from ai_edge_torch.generative.examples.stable_diffusion import diffusion as sd_diffusion
|
25
30
|
from ai_edge_torch.generative.layers import kv_cache
|
26
31
|
from ai_edge_torch.generative.test import utils as test_utils
|
27
32
|
import numpy as np
|
@@ -98,6 +103,15 @@ class TestModelConversion(googletest.TestCase):
|
|
98
103
|
pytorch_model = gemma2.Gemma2(config).eval()
|
99
104
|
self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5)
|
100
105
|
|
106
|
+
@googletest.skipIf(
|
107
|
+
ai_edge_config.Config.use_torch_xla,
|
108
|
+
reason="tests with custom ops are not supported on oss",
|
109
|
+
)
|
110
|
+
def test_llama(self):
|
111
|
+
config = llama.get_fake_model_config()
|
112
|
+
pytorch_model = llama.Llama(config).eval()
|
113
|
+
self._test_model(config, pytorch_model, "prefill", atol=1e-3, rtol=1e-5)
|
114
|
+
|
101
115
|
@googletest.skipIf(
|
102
116
|
ai_edge_config.Config.use_torch_xla,
|
103
117
|
reason="tests with custom ops are not supported on oss",
|
@@ -109,6 +123,17 @@ class TestModelConversion(googletest.TestCase):
|
|
109
123
|
config, pytorch_model, "serving_default", atol=1e-3, rtol=1e-3
|
110
124
|
)
|
111
125
|
|
126
|
+
@googletest.skipIf(
|
127
|
+
ai_edge_config.Config.use_torch_xla,
|
128
|
+
reason="tests with custom ops are not supported on oss",
|
129
|
+
)
|
130
|
+
def test_phi3(self):
|
131
|
+
config = phi3.get_fake_model_config()
|
132
|
+
pytorch_model = phi3.Phi3_5Mini(config).eval()
|
133
|
+
self._test_model(
|
134
|
+
config, pytorch_model, "prefill", atol=1e-5, rtol=1e-5
|
135
|
+
)
|
136
|
+
|
112
137
|
@googletest.skipIf(
|
113
138
|
ai_edge_config.Config.use_torch_xla,
|
114
139
|
reason="tests with custom ops are not supported on oss",
|
@@ -127,6 +152,110 @@ class TestModelConversion(googletest.TestCase):
|
|
127
152
|
pytorch_model = openelm.OpenELM(config).eval()
|
128
153
|
self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5)
|
129
154
|
|
155
|
+
@googletest.skipIf(
|
156
|
+
ai_edge_config.Config.use_torch_xla,
|
157
|
+
reason="tests with custom ops are not supported on oss",
|
158
|
+
)
|
159
|
+
def test_stable_diffusion_clip(self):
|
160
|
+
config = sd_clip.get_fake_model_config()
|
161
|
+
prompt_tokens = torch.from_numpy(
|
162
|
+
np.array([[1, 2, 3, 4, 5, 6]], dtype=np.int32)
|
163
|
+
)
|
164
|
+
|
165
|
+
pytorch_model = sd_clip.CLIP(config).eval()
|
166
|
+
torch_output = pytorch_model(prompt_tokens)
|
167
|
+
|
168
|
+
edge_model = ai_edge_torch.signature(
|
169
|
+
"encode", pytorch_model, (prompt_tokens,)
|
170
|
+
).convert()
|
171
|
+
edge_model.set_interpreter_builder(
|
172
|
+
self._interpreter_builder(edge_model.tflite_model())
|
173
|
+
)
|
174
|
+
edge_output = edge_model(
|
175
|
+
prompt_tokens.numpy(),
|
176
|
+
signature_name="encode",
|
177
|
+
)
|
178
|
+
self.assertTrue(
|
179
|
+
np.allclose(
|
180
|
+
edge_output,
|
181
|
+
torch_output.detach().numpy(),
|
182
|
+
atol=1e-4,
|
183
|
+
rtol=1e-5,
|
184
|
+
)
|
185
|
+
)
|
186
|
+
|
187
|
+
@googletest.skipIf(
|
188
|
+
ai_edge_config.Config.use_torch_xla,
|
189
|
+
reason="tests with custom ops are not supported on oss",
|
190
|
+
)
|
191
|
+
def test_stable_diffusion_diffusion(self):
|
192
|
+
config = sd_diffusion.get_fake_model_config(2)
|
193
|
+
latents = torch.from_numpy(
|
194
|
+
np.random.normal(size=(2, 4, 8, 8)).astype(np.float32)
|
195
|
+
)
|
196
|
+
context = torch.from_numpy(
|
197
|
+
np.random.normal(size=(2, 4, 4)).astype(np.float32)
|
198
|
+
)
|
199
|
+
time_embedding = torch.from_numpy(
|
200
|
+
np.random.normal(size=(2, 2)).astype(np.float32)
|
201
|
+
)
|
202
|
+
|
203
|
+
pytorch_model = sd_diffusion.Diffusion(config).eval()
|
204
|
+
torch_output = pytorch_model(latents, context, time_embedding)
|
205
|
+
|
206
|
+
edge_model = ai_edge_torch.signature(
|
207
|
+
"diffusion", pytorch_model, (latents, context, time_embedding)
|
208
|
+
).convert()
|
209
|
+
edge_model.set_interpreter_builder(
|
210
|
+
self._interpreter_builder(edge_model.tflite_model())
|
211
|
+
)
|
212
|
+
edge_output = edge_model(
|
213
|
+
latents.numpy(),
|
214
|
+
context.numpy(),
|
215
|
+
time_embedding.numpy(),
|
216
|
+
signature_name="diffusion",
|
217
|
+
)
|
218
|
+
self.assertTrue(
|
219
|
+
np.allclose(
|
220
|
+
edge_output,
|
221
|
+
torch_output.detach().numpy(),
|
222
|
+
atol=1e-4,
|
223
|
+
rtol=1e-5,
|
224
|
+
)
|
225
|
+
)
|
226
|
+
|
227
|
+
@googletest.skipIf(
|
228
|
+
ai_edge_config.Config.use_torch_xla,
|
229
|
+
reason="tests with custom ops are not supported on oss",
|
230
|
+
)
|
231
|
+
def test_stable_diffusion_decoder(self):
|
232
|
+
config = sd_decoder.get_fake_model_config()
|
233
|
+
latents = torch.from_numpy(
|
234
|
+
np.random.normal(size=(1, 4, 64, 64)).astype(np.float32)
|
235
|
+
)
|
236
|
+
|
237
|
+
pytorch_model = sd_decoder.Decoder(config).eval()
|
238
|
+
torch_output = pytorch_model(latents)
|
239
|
+
|
240
|
+
edge_model = ai_edge_torch.signature(
|
241
|
+
"decode", pytorch_model, (latents,)
|
242
|
+
).convert()
|
243
|
+
edge_model.set_interpreter_builder(
|
244
|
+
self._interpreter_builder(edge_model.tflite_model())
|
245
|
+
)
|
246
|
+
edge_output = edge_model(
|
247
|
+
latents.numpy(),
|
248
|
+
signature_name="decode",
|
249
|
+
)
|
250
|
+
self.assertTrue(
|
251
|
+
np.allclose(
|
252
|
+
edge_output,
|
253
|
+
torch_output.detach().numpy(),
|
254
|
+
atol=1e-4,
|
255
|
+
rtol=1e-5,
|
256
|
+
)
|
257
|
+
)
|
258
|
+
|
130
259
|
|
131
260
|
if __name__ == "__main__":
|
132
261
|
googletest.main()
|
@@ -0,0 +1,42 @@
|
|
1
|
+
# Copyright 2024 The AI Edge Torch Authors.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
"""Utilities for the models predefined in HuggingFace transformers."""
|
17
|
+
|
18
|
+
from typing import cast
|
19
|
+
|
20
|
+
from ai_edge_torch.generative.utilities import verifier
|
21
|
+
import torch
|
22
|
+
import transformers
|
23
|
+
|
24
|
+
|
25
|
+
class TransformersModelWrapper(verifier.ModelWrapper):
|
26
|
+
"""A wrapper for the model predefined in HuggingFace transformers.
|
27
|
+
|
28
|
+
Verifier expects forward() to return logits while Transformers models return
|
29
|
+
an object with `logits` field.
|
30
|
+
|
31
|
+
Transformers models get `max_new_tokens` settings for generate() via
|
32
|
+
GenerationConfig.
|
33
|
+
"""
|
34
|
+
|
35
|
+
def forward(self, tokens: torch.Tensor) -> torch.Tensor:
|
36
|
+
return self.model.forward(tokens).logits
|
37
|
+
|
38
|
+
def generate(
|
39
|
+
self, inputs: torch.Tensor, max_new_tokens: int
|
40
|
+
) -> torch.IntTensor:
|
41
|
+
gen_config = transformers.GenerationConfig(max_new_tokens=max_new_tokens)
|
42
|
+
return self.model.generate(inputs=inputs, generation_config=gen_config)
|
@@ -15,116 +15,130 @@
|
|
15
15
|
|
16
16
|
"""Common utility functions to verify the reauthored models."""
|
17
17
|
|
18
|
-
import
|
19
|
-
from typing import List
|
18
|
+
import logging
|
19
|
+
from typing import List
|
20
20
|
|
21
21
|
from ai_edge_torch.generative.layers import kv_cache as kv_utils
|
22
22
|
import torch
|
23
|
-
import transformers
|
24
|
-
|
25
|
-
|
26
|
-
def log_msg(*args):
|
27
|
-
print("[%s]" % datetime.datetime.now(), *args)
|
28
23
|
|
29
24
|
|
30
25
|
class ModelWrapper(torch.nn.Module):
|
31
|
-
"""A wrapper for the model to be verified
|
26
|
+
"""A wrapper for the model to be verified.
|
32
27
|
|
33
|
-
|
28
|
+
It unifies the interface of forward() and generate() of models for the
|
29
|
+
verification to call.
|
34
30
|
"""
|
35
31
|
|
36
|
-
def __init__(
|
37
|
-
self,
|
38
|
-
model: torch.nn.Module,
|
39
|
-
model_format: str = "huggingface",
|
40
|
-
hf_generation_config: Optional[transformers.GenerationConfig] = None,
|
41
|
-
):
|
32
|
+
def __init__(self, model: torch.nn.Module):
|
42
33
|
"""Initializes the wrapper.
|
43
34
|
|
44
35
|
Args:
|
45
|
-
model (torch.nn.Module): The
|
46
|
-
|
47
|
-
|
48
|
-
|
49
|
-
hf_generation_config (transformers.GenerationConfig): The HuggingFace
|
50
|
-
generation config. This config will only be used if the underlying model
|
51
|
-
is built from HuggingFace transformers.
|
36
|
+
model (torch.nn.Module): The model which might have different interfaces
|
37
|
+
of forward() and generate(). It could be a model built from HuggingFace
|
38
|
+
transformers, a regular PyTorch model, or a model re-authored with
|
39
|
+
ai_edge_torch Generative API.
|
52
40
|
"""
|
53
41
|
super().__init__()
|
54
42
|
self.model = model
|
55
|
-
|
56
|
-
|
43
|
+
|
44
|
+
def forward(self, tokens: torch.Tensor) -> torch.Tensor:
|
45
|
+
"""Gets output logits by forwarding the input tokens.
|
46
|
+
|
47
|
+
Args:
|
48
|
+
tokens (torch.Tensor): The input tokens to forward. Its dimension is
|
49
|
+
expected to be (batch_size=1, kv_cache_max_len).
|
50
|
+
|
51
|
+
Returns:
|
52
|
+
The output logits.
|
53
|
+
"""
|
54
|
+
raise NotImplementedError("forward() is not implemented.")
|
57
55
|
|
58
56
|
def generate(
|
59
|
-
self,
|
60
|
-
) ->
|
61
|
-
|
62
|
-
return self.model.generate(
|
63
|
-
inputs=inputs, generation_config=self.hf_generation_config
|
64
|
-
)
|
65
|
-
else:
|
66
|
-
raise NotImplementedError(
|
67
|
-
"generate() is not implemented for model format: %s"
|
68
|
-
% self.model_format
|
69
|
-
)
|
57
|
+
self, prompts: torch.Tensor, max_new_tokens: int
|
58
|
+
) -> torch.IntTensor:
|
59
|
+
"""Returns the response token IDs to the given prompts tensor.
|
70
60
|
|
71
|
-
|
72
|
-
self,
|
73
|
-
inputs: torch.Tensor,
|
74
|
-
):
|
75
|
-
return self.model.forward(inputs)
|
61
|
+
The maximum number of tokens to generate might be set by subclasses.
|
76
62
|
|
63
|
+
Args:
|
64
|
+
prompts (torch.Tensor): The input token IDs to generate with. Its shape is
|
65
|
+
expected to be (batch_size=1, input_ids_len).
|
66
|
+
max_new_tokens (int): The maximum number of response token IDs to
|
67
|
+
generate.
|
68
|
+
|
69
|
+
Returns:
|
70
|
+
The tensor of response token IDs with shape of (batch_size=1,
|
71
|
+
response_ids_len).
|
72
|
+
"""
|
73
|
+
raise NotImplementedError("generate() is not implemented.")
|
77
74
|
|
78
|
-
def forward(
|
79
|
-
model: torch.nn.Module,
|
80
|
-
tokens: torch.Tensor,
|
81
|
-
kv_cache: kv_utils.KVCache,
|
82
|
-
) -> tuple[torch.Tensor, kv_utils.KVCache]:
|
83
|
-
"""Forwards the model reauthored with ai_edge_torch Generative API.
|
84
75
|
|
85
|
-
|
86
|
-
|
87
|
-
with ai_edge_torch Generative API.
|
88
|
-
tokens (torch.Tensor): The input tokens to forward.
|
89
|
-
kv_cache (KVCache): The KV cache to forward.
|
76
|
+
class ReauthoredModelWrapper(ModelWrapper):
|
77
|
+
"""A wrapper for the model reauthored with ai_edge_torch Generative API."""
|
90
78
|
|
91
|
-
|
92
|
-
|
93
|
-
|
94
|
-
input_pos = torch.arange(0, tokens.shape[1], dtype=torch.int)
|
95
|
-
output = model.forward(tokens, input_pos, kv_cache)
|
96
|
-
return output["logits"], output["kv_cache"]
|
79
|
+
def _init_kv_cache(self):
|
80
|
+
"""Returns an initialized KV cache."""
|
81
|
+
return kv_utils.KVCache.from_model_config(self.model.config)
|
97
82
|
|
83
|
+
def _forward_with_kv_cache(
|
84
|
+
self,
|
85
|
+
tokens: torch.Tensor,
|
86
|
+
kv_cache: kv_utils.KVCache,
|
87
|
+
) -> tuple[torch.Tensor, kv_utils.KVCache]:
|
88
|
+
"""Forwards the model and updates an external KV cache.
|
98
89
|
|
99
|
-
|
100
|
-
|
101
|
-
)
|
102
|
-
"""Generates the response to the prompts.
|
90
|
+
Args:
|
91
|
+
tokens (torch.Tensor): The input tokens to forward.
|
92
|
+
kv_cache (KVCache): The KV cache to forward.
|
103
93
|
|
104
|
-
|
105
|
-
|
94
|
+
Returns:
|
95
|
+
The output logits and the updated KV cache.
|
96
|
+
"""
|
97
|
+
input_pos = torch.arange(0, tokens.shape[1], dtype=torch.int)
|
98
|
+
output = self.model.forward(tokens, input_pos, kv_cache)
|
99
|
+
return output["logits"], output["kv_cache"]
|
106
100
|
|
107
|
-
|
108
|
-
|
109
|
-
|
110
|
-
prompts (torch.Tensor): The prompts to generate.
|
111
|
-
response_len (int): The number of tokens to generate.
|
101
|
+
def forward(self, tokens: torch.Tensor) -> torch.Tensor:
|
102
|
+
logits, _ = self._forward_with_kv_cache(tokens, self._init_kv_cache())
|
103
|
+
return logits
|
112
104
|
|
113
|
-
|
114
|
-
|
115
|
-
|
116
|
-
|
117
|
-
|
118
|
-
|
119
|
-
|
120
|
-
|
121
|
-
|
122
|
-
|
105
|
+
def generate(
|
106
|
+
self, prompts: torch.Tensor, max_new_tokens: int
|
107
|
+
) -> torch.IntTensor:
|
108
|
+
input_ids = prompts[0].int().tolist()
|
109
|
+
kv_cache = self._init_kv_cache()
|
110
|
+
for _ in range(max_new_tokens):
|
111
|
+
tokens = torch.tensor([input_ids])
|
112
|
+
logits, kv_cache = self._forward_with_kv_cache(tokens, kv_cache)
|
113
|
+
generated_token = logits[0][-1].argmax().item()
|
114
|
+
input_ids.append(generated_token)
|
115
|
+
return torch.tensor([input_ids])
|
116
|
+
|
117
|
+
|
118
|
+
class TokenizerWrapper(torch.nn.Module):
|
119
|
+
"""A wrapper for the tokenizer used for verification."""
|
120
|
+
|
121
|
+
def __init__(self, tokenizer: torch.nn.Module):
|
122
|
+
"""Initializes the wrapper.
|
123
|
+
|
124
|
+
Args:
|
125
|
+
tokenizer (torch.nn.Module): The tokenizer to wrap.
|
126
|
+
"""
|
127
|
+
super().__init__()
|
128
|
+
self.tokenizer = tokenizer
|
129
|
+
|
130
|
+
def encode(self, prompts: str) -> torch.Tensor:
|
131
|
+
"""Encodes the prompts to token IDs."""
|
132
|
+
return self.tokenizer.encode(prompts, return_tensors="pt")
|
133
|
+
|
134
|
+
def decode(self, token_ids: torch.Tensor) -> str:
|
135
|
+
"""Decodes the token IDs to a string."""
|
136
|
+
return self.tokenizer.decode(token_ids)
|
123
137
|
|
124
138
|
|
125
139
|
def verify_with_input_ids(
|
126
140
|
original_model: ModelWrapper,
|
127
|
-
reauthored_model:
|
141
|
+
reauthored_model: ReauthoredModelWrapper,
|
128
142
|
input_ids: List[int],
|
129
143
|
kv_cache_max_len: int = 1024,
|
130
144
|
rtol: float = 1e-05,
|
@@ -136,8 +150,8 @@ def verify_with_input_ids(
|
|
136
150
|
|
137
151
|
Args:
|
138
152
|
original_model (ModelWrapper): The original model.
|
139
|
-
reauthored_model (
|
140
|
-
Generative API.
|
153
|
+
reauthored_model (ReauthoredModelWrapper): The model reauthored with
|
154
|
+
ai_edge_torch Generative API.
|
141
155
|
input_ids (List[int]): The input token IDs to forward with.
|
142
156
|
kv_cache_max_len (int): The maximum sequence length of the KV cache.
|
143
157
|
rtol (float): The relative tolerance for the comparison.
|
@@ -149,16 +163,15 @@ def verify_with_input_ids(
|
|
149
163
|
tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.int, device="cpu")
|
150
164
|
tokens[0, : len(input_ids)] = torch.tensor([input_ids]).int()
|
151
165
|
|
152
|
-
|
166
|
+
logging.info("Forwarding the original model...")
|
153
167
|
outputs_original = original_model.forward(tokens)
|
154
|
-
logits_original = outputs_original
|
155
|
-
|
168
|
+
logits_original = outputs_original[0, len(input_ids) - 1, :]
|
169
|
+
logging.info("logits_original: %s", logits_original)
|
156
170
|
|
157
|
-
|
158
|
-
|
159
|
-
|
160
|
-
logits_reauthored
|
161
|
-
log_msg("logits_reauthored:", logits_reauthored)
|
171
|
+
logging.info("Forwarding the reauthored model...")
|
172
|
+
outputs_reauthored = reauthored_model.forward(tokens)
|
173
|
+
logits_reauthored = outputs_reauthored[0, len(input_ids) - 1, :]
|
174
|
+
logging.info("logits_reauthored: %s", logits_reauthored)
|
162
175
|
|
163
176
|
return torch.allclose(
|
164
177
|
logits_original, logits_reauthored, rtol=rtol, atol=atol
|
@@ -167,9 +180,10 @@ def verify_with_input_ids(
|
|
167
180
|
|
168
181
|
def verify_model_with_prompts(
|
169
182
|
original_model: ModelWrapper,
|
170
|
-
reauthored_model:
|
171
|
-
tokenizer:
|
183
|
+
reauthored_model: ReauthoredModelWrapper,
|
184
|
+
tokenizer: TokenizerWrapper,
|
172
185
|
prompts: str,
|
186
|
+
max_new_tokens: int,
|
173
187
|
) -> bool:
|
174
188
|
"""Verifies if the model reauthored generates the same answer of the oringal.
|
175
189
|
|
@@ -178,35 +192,36 @@ def verify_model_with_prompts(
|
|
178
192
|
|
179
193
|
Args:
|
180
194
|
original_model (ModelWrapper): The original model.
|
181
|
-
reauthored_model (
|
182
|
-
Generative API.
|
183
|
-
tokenizer (
|
195
|
+
reauthored_model (ReauthoredModelWrapper): The model reauthored with
|
196
|
+
ai_edge_torch Generative API.
|
197
|
+
tokenizer (TokenizerWrapper): The tokenizer.
|
184
198
|
prompts (str): The input prompts to generate answers.
|
199
|
+
max_new_tokens (int): The maximum number of new tokens to generate.
|
185
200
|
|
186
201
|
Returns:
|
187
202
|
True if the model reauthored generates the same answer of the original.
|
188
203
|
"""
|
189
|
-
prompt_tokens = tokenizer.encode(prompts
|
204
|
+
prompt_tokens = tokenizer.encode(prompts)
|
190
205
|
|
191
|
-
|
192
|
-
outputs_original = original_model.generate(prompt_tokens)
|
206
|
+
logging.info("Generating answer with the original model...")
|
207
|
+
outputs_original = original_model.generate(prompt_tokens, max_new_tokens)
|
193
208
|
response_original = tokenizer.decode(outputs_original[0])
|
194
|
-
|
209
|
+
logging.info("outputs_from_original_model: [[%s]]", response_original)
|
195
210
|
|
196
|
-
|
197
|
-
|
198
|
-
outputs_reauthored = generate(reauthored_model, prompt_tokens, generate_len)
|
211
|
+
logging.info("Generating answer with the reauthored model...")
|
212
|
+
outputs_reauthored = reauthored_model.generate(prompt_tokens, max_new_tokens)
|
199
213
|
response_reauthored = tokenizer.decode(outputs_reauthored[0])
|
200
|
-
|
214
|
+
logging.info("outputs from reauthored model: [[%s]]", response_reauthored)
|
201
215
|
|
202
216
|
return response_original == response_reauthored
|
203
217
|
|
204
218
|
|
205
219
|
def verify_reauthored_model(
|
206
220
|
original_model: ModelWrapper,
|
207
|
-
reauthored_model:
|
208
|
-
tokenizer:
|
221
|
+
reauthored_model: ReauthoredModelWrapper,
|
222
|
+
tokenizer: TokenizerWrapper,
|
209
223
|
generate_prompts: List[str],
|
224
|
+
max_new_tokens: int = 30,
|
210
225
|
forward_input_ids: List[List[int]] = [[1, 2, 3, 4]],
|
211
226
|
rtol: float = 1e-05,
|
212
227
|
atol: float = 1e-05,
|
@@ -223,29 +238,30 @@ def verify_reauthored_model(
|
|
223
238
|
|
224
239
|
Args:
|
225
240
|
original_model (ModelWrapper): The original model.
|
226
|
-
reauthored_model (
|
227
|
-
Generative API.
|
228
|
-
tokenizer (
|
241
|
+
reauthored_model (ReauthoredModelWrapper): The model reauthored with
|
242
|
+
ai_edge_torch Generative API.
|
243
|
+
tokenizer (TokenizerWrapper): The tokenizer.
|
229
244
|
generate_prompts (List[str]): List of the input prompts to generate answers.
|
245
|
+
max_new_tokens (int): The maximum number of new tokens to generate.
|
230
246
|
forward_input_ids (List[torch.Tensor]): List if ihe input token IDs to
|
231
247
|
forward with.
|
232
248
|
rtol (float): The relative tolerance for the comparison.
|
233
249
|
atol (float): The absolute tolerance for the comparison.
|
234
250
|
"""
|
235
251
|
for input_ids in forward_input_ids:
|
236
|
-
|
252
|
+
logging.info("Verifying the reauthored model with input IDs: %s", input_ids)
|
237
253
|
if verify_with_input_ids(
|
238
254
|
original_model, reauthored_model, input_ids, rtol=rtol, atol=atol
|
239
255
|
):
|
240
|
-
|
256
|
+
logging.info("PASS")
|
241
257
|
else:
|
242
|
-
|
258
|
+
logging.error("FAILED")
|
243
259
|
|
244
260
|
for prompts in generate_prompts:
|
245
|
-
|
261
|
+
logging.info("Verifying the reauthored model with prompts:%s", prompts)
|
246
262
|
if verify_model_with_prompts(
|
247
|
-
original_model, reauthored_model, tokenizer, prompts
|
263
|
+
original_model, reauthored_model, tokenizer, prompts, max_new_tokens
|
248
264
|
):
|
249
|
-
|
265
|
+
logging.info("PASS")
|
250
266
|
else:
|
251
|
-
|
267
|
+
logging.error("FAILED")
|
{ai_edge_torch_nightly-0.3.0.dev20240924.dist-info → ai_edge_torch_nightly-0.3.0.dev20240928.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: ai-edge-torch-nightly
|
3
|
-
Version: 0.3.0.dev20240924
|
3
|
+
Version: 0.3.0.dev20240928
|
4
4
|
Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
|
5
5
|
Home-page: https://github.com/google-ai-edge/ai-edge-torch
|
6
6
|
Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
|