ai-edge-torch-nightly 0.3.0.dev20241218__py3-none-any.whl → 0.3.0.dev20241224__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_torch/generative/examples/gemma/gemma1.py +3 -2
- ai_edge_torch/generative/examples/gemma/gemma2.py +43 -25
- ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py +11 -3
- ai_edge_torch/generative/examples/paligemma/decoder.py +14 -5
- ai_edge_torch/generative/examples/paligemma/decoder2.py +174 -0
- ai_edge_torch/generative/examples/paligemma/paligemma.py +30 -15
- ai_edge_torch/generative/examples/paligemma/verify.py +36 -9
- ai_edge_torch/generative/examples/paligemma/verify_decoder2.py +72 -0
- ai_edge_torch/generative/examples/paligemma/verify_image_encoder.py +24 -7
- ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +3 -3
- ai_edge_torch/generative/layers/attention.py +4 -29
- ai_edge_torch/generative/layers/rotary_position_embedding.py +34 -27
- ai_edge_torch/generative/test/test_model_conversion_large.py +28 -9
- ai_edge_torch/generative/utilities/model_builder.py +14 -14
- ai_edge_torch/generative/utilities/verifier.py +22 -22
- ai_edge_torch/odml_torch/export.py +6 -1
- ai_edge_torch/odml_torch/jax_bridge/__init__.py +4 -1
- ai_edge_torch/odml_torch/lowerings/__init__.py +1 -0
- ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +1 -2
- ai_edge_torch/odml_torch/lowerings/_rand.py +142 -0
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20241218.dist-info → ai_edge_torch_nightly-0.3.0.dev20241224.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20241218.dist-info → ai_edge_torch_nightly-0.3.0.dev20241224.dist-info}/RECORD +26 -23
- {ai_edge_torch_nightly-0.3.0.dev20241218.dist-info → ai_edge_torch_nightly-0.3.0.dev20241224.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20241218.dist-info → ai_edge_torch_nightly-0.3.0.dev20241224.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20241218.dist-info → ai_edge_torch_nightly-0.3.0.dev20241224.dist-info}/top_level.txt +0 -0

ai_edge_torch/generative/examples/paligemma/verify_decoder2.py
@@ -0,0 +1,72 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Verifies the reauthored decoder of PaliGemma2 3B model."""
+
+import logging
+
+from absl import app
+from absl import flags
+from ai_edge_torch.generative.examples.paligemma import decoder2
+from ai_edge_torch.generative.utilities import transformers_verifier
+from ai_edge_torch.generative.utilities import verifier
+import kagglehub
+import transformers
+
+_PROMPTS = flags.DEFINE_multi_string(
+    "prompts",
+    "What is the meaning of life?",
+    "The input prompts to generate answers.",
+)
+_MAX_NEW_TOKENS = flags.DEFINE_integer(
+    "max_new_tokens",
+    30,
+    "The maximum size of the generated tokens.",
+)
+
+
+def main(_):
+  checkpoint = kagglehub.model_download(
+      "google/paligemma-2/transformers/paligemma2-3b-pt-224"
+  )
+  logging.info("Loading the original model from: %s", checkpoint)
+  original_full_model = (
+      transformers.PaliGemmaForConditionalGeneration.from_pretrained(checkpoint)
+  )
+  original_language_model = original_full_model.eval().language_model
+
+  logging.info("Building the reauthored model from: %s", checkpoint)
+  reauthored_model = decoder2.build_decoder2(checkpoint)
+
+  logging.info("Loading the tokenizer from: %s", checkpoint)
+  # It works only when GemmaTokenizerFast is available. In some environments,
+  # use_fast=False doeesn't work either if the tokenizer cannot load the
+  # sentencepiece model file properly.
+  processor = transformers.AutoProcessor.from_pretrained(checkpoint)
+
+  verifier.verify_reauthored_model(
+      original_model=transformers_verifier.TransformersModelWrapper(
+          original_language_model
+      ),
+      reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
+      tokenizer=verifier.TokenizerWrapper(processor.tokenizer),
+      generate_prompts=_PROMPTS.value,
+      max_new_tokens=_MAX_NEW_TOKENS.value,
+      atol=1e-04,
+  )
+
+
+if __name__ == "__main__":
+  app.run(main)

ai_edge_torch/generative/examples/paligemma/verify_image_encoder.py
@@ -20,31 +20,48 @@ import pathlib
 from absl import app
 from absl import flags
 from ai_edge_torch.generative.examples.paligemma import image_encoder
+import kagglehub
 from PIL import Image
 import requests
 import torch
 import transformers

+_VERSION = flags.DEFINE_enum(
+    "version",
+    "2",
+    ["1", "2"],
+    "The version of PaliGemma vision model to verify.",
+)
 _IMAGE_URL = flags.DEFINE_string(
     "image_url",
     "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg?download=true",
     "The image URI to encode.",
 )

+_CHECKPOINT = {
+    "1": "google/paligemma-3b-mix-224",
+    "2": "google/paligemma-2/transformers/paligemma2-3b-pt-224",
+}
+

 def main(_):
-  checkpoint = "google/paligemma-3b-mix-224"
+  if _VERSION.value == "1":
+    checkpoint = _CHECKPOINT[_VERSION.value]
+    # Locate the cached dir.
+    cached_config_file = transformers.utils.cached_file(
+        checkpoint, transformers.utils.CONFIG_NAME
+    )
+    reauthored_checkpoint = pathlib.Path(cached_config_file).parent
+  else:
+    checkpoint = kagglehub.model_download(_CHECKPOINT[_VERSION.value])
+    reauthored_checkpoint = checkpoint
+
   logging.info("Loading the original model from: %s", checkpoint)
   original_full_model = (
       transformers.PaliGemmaForConditionalGeneration.from_pretrained(checkpoint)
   )
   original_vision_model = original_full_model.eval().vision_tower

-  # Locate the cached dir.
-  cached_config_file = transformers.utils.cached_file(
-      checkpoint, transformers.utils.CONFIG_NAME
-  )
-  reauthored_checkpoint = pathlib.Path(cached_config_file).parent
   logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
   reauthored_model = image_encoder.build_image_encoder(reauthored_checkpoint)


@@ -69,7 +86,7 @@ def main(_):

   try:
     assert torch.allclose(
-        outputs_original, outputs_reauthored, atol=1e-…
+        outputs_original, outputs_reauthored, atol=1e-03, rtol=1e-04
     )
   except AssertionError as e:
     logging.error("*** FAILED *** verify with an image")

ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py
@@ -72,14 +72,14 @@ class ToyModelWithKVCache(torch.nn.Module):
     mask = self.mask_cache.index_select(2, input_pos)
     mask = mask[:, :, :, : self.config.max_seq_len]

-    updated_kv_entires = []
+    updated_kv_entries = []
     for i, block in enumerate(self.transformer_blocks):
       kv_entry = kv_cache.caches[i] if kv_cache else None
       x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
       if kv_entry:
-        updated_kv_entires.append(kv_entry)
+        updated_kv_entries.append(kv_entry)

-    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
+    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entries))

     if export_config is not None:
       if (

ai_edge_torch/generative/layers/attention.py
@@ -26,33 +26,6 @@ import torch
 from torch import nn


-def _embed_rope(
-    q: torch.Tensor,
-    k: torch.Tensor,
-    n_elem: int,
-    rope: Tuple[torch.Tensor, torch.Tensor],
-) -> Tuple[torch.Tensor, torch.Tensor]:
-  """Embed rotary positional embedding for query and key.
-
-  Args:
-    q (torch.Tensor): query tensor.
-    k (torch.Tensor): key tensor.
-    n_elem (int): number of elements to embed rotarty positional embedding.
-    rope (Tuple[torch.Tensor, torch.Tensor]): the input rope tensor.
-  """
-  if n_elem > 0:
-    cos, sin = rope
-    q_roped = rotary_pos_emb.apply_rope(
-        q[..., :n_elem], cos.repeat(1, 2), sin.repeat(1, 2)
-    )
-    k_roped = rotary_pos_emb.apply_rope(
-        k[..., :n_elem], cos.repeat(1, 2), sin.repeat(1, 2)
-    )
-    q = torch.cat((q_roped, q[..., n_elem:]), dim=-1)
-    k = torch.cat((k_roped, k[..., n_elem:]), dim=-1)
-  return q, k
-
-
 class TransformerBlock(nn.Module):

   def __init__(

@@ -238,7 +211,8 @@ class CausalSelfAttention(nn.Module):
     if rope is not None:
       # Compute rotary positional embedding for query and key.
       n_elem = int(self.config.rotary_percentage * self.config.head_dim)
-      q, k = _embed_rope(q, k, n_elem, rope)
+      cos, sin = rope
+      q, k = rotary_pos_emb.apply_rope_inline(q, k, cos, sin)

     if kv_cache is not None:
       kv_cache = kv_utils.update(kv_cache, input_pos, k, v)

@@ -374,7 +348,8 @@ class CrossAttention(nn.Module):
     if rope is not None:
       # Compute rotary positional embedding for query and key.
       n_elem = int(self.config.rotary_percentage * self.config.head_dim)
-      q, k = _embed_rope(q, k, n_elem, rope)
+      cos, sin = rope
+      q, k = rotary_pos_emb.apply_rope_inline(q, k, cos, sin)

     if kv_cache is not None:
       kv_cache = kv_utils.update(kv_cache, input_pos, k, v)

ai_edge_torch/generative/layers/rotary_position_embedding.py
@@ -32,57 +32,64 @@ def apply_rope(
   """
   x = x.transpose(1, 2)
   head_size = x.size(-1)
-  x1 = x[..., : head_size // 2]
-  x2 = x[..., head_size // 2 :]
-  rotated = torch.cat((-x2, x1), dim=-1)
-  roped = (x * cos) + (rotated * sin)
+  x1, x2 = torch.split(x, head_size // 2, dim=-1)
+  left = x1 * cos - x2 * sin
+  right = x2 * cos + x1 * sin
+  roped = torch.cat([left, right], dim=-1)
   return roped.transpose(1, 2).type_as(x)


-def apply_rope_inline(
-    q: torch.Tensor,
-    k: torch.Tensor,
+def build_rope(
     input_pos: torch.Tensor,
     n_elem: int,
+    head_dim: int,
     base: int = 10_000,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
-  """Computes rotary positional embedding inline for a query and key.
+  """Computes rotary positional embedding cosine and sine tensors.

   Args:
-    q: the query tensor.
-    k: the key tensor.
     input_pos: the sequence indices for the query and key
     n_elem: number of elements of the head dimension for RoPE computation
+    base: the base of the exponentiated value for RoPE.

   Returns:
-    output the RoPE'd query and key.
+    cos, sin tensors
   """

   if n_elem <= 0:
-    return q, k
+    return None, None

   theta = 1.0 / (base ** (torch.arange(0, n_elem, 2).float() / n_elem))
   freq_exponents = (2.0 / n_elem) * torch.arange(
-      q.shape[-1] // 2, dtype=torch.float32
+      head_dim // 2, dtype=torch.float32
   )
   timescale = float(base) ** freq_exponents
   radians = input_pos.clone().unsqueeze(0).unsqueeze(-1) / timescale.unsqueeze(
       0
   ).unsqueeze(0)
-  cos = torch.cos(radians)
-  sin = torch.sin(radians)
+  cos = torch.cos(radians)
+  sin = torch.sin(radians)
+  return cos, sin
+

-  def apply(x, sin, cos):
-    x = x.transpose(1, 2)
-    b, h, s, d = x.shape
-    ans = torch.split(x, d // 2, dim=-1)
-    x1, x2 = ans
-    left = x1 * cos - x2 * sin
-    right = x2 * cos + x1 * sin
-    res = torch.cat([left, right], dim=-1)
-    res = res.transpose(1, 2)
-    return res
+def apply_rope_inline(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    cos: torch.Tensor,
+    sin: torch.Tensor,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+  """Computes rotary positional embedding inline for a query and key.
+
+  Args:
+    q: the query tensor.
+    k: the key tensor.
+    cos: the cosine tensor.
+    sin: the sine tensor.
+
+  Returns:
+    output the RoPE'd query and key.
+  """

-  q_roped = apply(q, sin, cos)
-  k_roped = apply(k, sin, cos)
+  q_roped = apply_rope(q, cos, sin)
+  k_roped = apply_rope(k, cos, sin)
   return q_roped, k_roped

ai_edge_torch/generative/test/test_model_conversion_large.py
@@ -21,6 +21,8 @@ from ai_edge_torch.generative.examples.gemma import gemma1
 from ai_edge_torch.generative.examples.gemma import gemma2
 from ai_edge_torch.generative.examples.llama import llama
 from ai_edge_torch.generative.examples.openelm import openelm
+from ai_edge_torch.generative.examples.paligemma import decoder
+from ai_edge_torch.generative.examples.paligemma import decoder2
 from ai_edge_torch.generative.examples.paligemma import paligemma
 from ai_edge_torch.generative.examples.phi import phi2
 from ai_edge_torch.generative.examples.phi import phi3

@@ -171,13 +173,9 @@ class TestModelConversion(googletest.TestCase):
     pytorch_model = amd_llama_135m.AmdLlama(config).eval()
     self._test_model(config, pytorch_model, "prefill", atol=1e-5, rtol=1e-5)

-  @googletest.skipIf(
-      ai_edge_torch.config.in_oss,
-      reason="tests with custom ops are not supported in oss",
-  )
-  def disabled_test_paligemma(self):
-    config = paligemma.get_fake_model_config()
-    pytorch_model = paligemma.PaliGemma(config).eval()
+  def _test_paligemma_model(self, decoder_class, decoder_config, atol, rtol):
+    config = paligemma.get_fake_model_config(decoder_config)
+    pytorch_model = paligemma.PaliGemma(config, decoder_class).eval()

     image_embedding_config = config.image_encoder_config.image_embedding
     num_patches = (

@@ -215,11 +213,32 @@ class TestModelConversion(googletest.TestCase):
             kv,
             pixel_values=pixel_values,
             signature_name="prefill_pixel",
-            atol=1e-3,
-            rtol=1e-5,
+            atol=atol,
+            rtol=rtol,
         )
     )

+  @googletest.skipIf(
+      ai_edge_torch.config.in_oss,
+      reason="tests with custom ops are not supported in oss",
+  )
+  def disabled_test_paligemma1(self):
+    self._test_paligemma_model(
+        decoder.Decoder, decoder.get_fake_decoder_config, atol=1e-3, rtol=1e-5
+    )
+
+  @googletest.skipIf(
+      ai_edge_torch.config.in_oss,
+      reason="tests with custom ops are not supported in oss",
+  )
+  def disabled_test_paligemma2(self):
+    self._test_paligemma_model(
+        decoder2.Decoder2,
+        decoder2.get_fake_decoder2_config,
+        atol=1e-3,
+        rtol=1e-5,
+    )
+
   @googletest.skipIf(
       ai_edge_torch.config.in_oss,
       reason="tests with custom ops are not supported in oss",

ai_edge_torch/generative/utilities/model_builder.py
@@ -24,6 +24,7 @@ from ai_edge_torch.generative.layers import builder
 from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
 import ai_edge_torch.generative.layers.model_config as cfg
+import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb
 import ai_edge_torch.generative.utilities.loader as loading_utils
 import torch
 from torch import nn

@@ -85,13 +86,6 @@ class DecoderOnlyModel(nn.Module):
         config.embedding_dim,
         config.final_norm_config,
     )
-    # ROPE parameters for all attn_configs are the same. Take the first one.
-    attn_config = config.block_config(0).attn_config
-    self.rope_cache = attn_utils.build_rope_cache(
-        size=config.kv_cache_max,
-        dim=int(attn_config.rotary_percentage * attn_config.head_dim),
-        base=attn_config.rotary_base,
-    )
     self.mask_cache = attn_utils.build_causal_mask_cache(
         size=config.kv_cache_max,
     )

@@ -113,16 +107,22 @@ class DecoderOnlyModel(nn.Module):

     # token embeddings of shape (b, t, n_embd)
     input_embeds = self.tok_embedding(tokens)
-    cos, sin = self.rope_cache
-    rope = (cos.index_select(0, input_pos), sin.index_select(0, input_pos))
+
+    # ROPE parameters for all attn_configs are the same. Take the first one.
+    attn_config = self.config.block_config(0).attn_config
+    n_elem = int(attn_config.rotary_percentage * attn_config.head_dim)
+    rope = rotary_pos_emb.build_rope(
+        input_pos, n_elem, attn_config.head_dim, attn_config.rotary_base
+    )
+
     mask = self.mask_cache.index_select(2, input_pos)
     mask = mask[:, :, :, : self.config.kv_cache_max]

-    return self.forward_with_embeds(
+    return self._forward_with_embeds(
         input_embeds, rope, mask, input_pos, kv_cache, export_config
     )

-  def forward_with_embeds(
+  def _forward_with_embeds(
       self,
       input_embeds: torch.Tensor,
       rope: Tuple[torch.Tensor, torch.Tensor],

@@ -141,13 +141,13 @@ class DecoderOnlyModel(nn.Module):
     if self.config.embedding_scale is not None:
       x = x * self.config.embedding_scale

-    updated_kv_entires = []
+    updated_kv_entries = []
     for i, block in enumerate(self.transformer_blocks):
       kv_entry = kv_cache.caches[i] if kv_cache else None
       x, kv_entry = block(x, rope, mask, input_pos, kv_entry)
       if kv_entry:
-        updated_kv_entires.append(kv_entry)
-    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
+        updated_kv_entries.append(kv_entry)
+    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entries))

     if export_config is not None:
       if (

ai_edge_torch/generative/utilities/verifier.py
@@ -16,7 +16,7 @@
 """Common utility functions to verify the reauthored models."""

 import logging
-from typing import List
+from typing import Any, List

 from ai_edge_torch.generative.layers import kv_cache as kv_utils
 from ai_edge_torch.generative.utilities.model_builder import ExportConfig

@@ -87,6 +87,10 @@ class ReauthoredModelWrapper(ModelWrapper):
     """Returns an initialized KV cache."""
     return kv_utils.KVCache.from_model_config(self.model.config)

+  def _get_extra_args_for_forward(self) -> dict[str, Any]:
+    """Returns extra arguments for the forward() method."""
+    return {}
+
   def _forward_with_kv_cache(
       self,
       tokens: torch.Tensor,
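
The new _get_extra_args_for_forward hook lets model-specific wrapper subclasses inject extra keyword arguments into self.model.forward(); the base implementation returns an empty dict, and _forward_with_kv_cache (next hunk) merges export_config and pixel_values on top of it. A hypothetical override as a sketch only; the subclass name and the extra argument are invented for illustration and are not part of the package:

    from typing import Any

    from ai_edge_torch.generative.utilities import verifier


    class WrapperWithExtraArgs(verifier.ReauthoredModelWrapper):
      """Hypothetical wrapper that passes one extra keyword argument to forward()."""

      def _get_extra_args_for_forward(self) -> dict[str, Any]:
        # Everything returned here ends up as **kwargs on self.model.forward().
        return {"called_by_generate": True}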

@@ -105,26 +109,15 @@ class ReauthoredModelWrapper(ModelWrapper):
     Returns:
       The output logits and the updated KV cache.
     """
-    …
-    if …
-    …
-      output = self.model.forward(
-          tokens, input_pos, kv_cache, export_config=self.export_config
-      )
-    else:
-      output = self.model.forward(
-          tokens,
-          input_pos,
-          kv_cache,
-          pixel_values=pixel_values,
-          export_config=self.export_config,
-      )
+    extra_args = self._get_extra_args_for_forward()
+    if self.export_config is not None:
+      # Verification requires logit outputs on prefill for comparison.
+      if not self.export_config.output_logits_on_prefill:
+        raise ValueError("Verifier requires logit output on prefill.")
+      extra_args["export_config"] = self.export_config
+    if pixel_values is not None:
+      extra_args["pixel_values"] = pixel_values
+    output = self.model.forward(tokens, input_pos, kv_cache, **extra_args)
     return output["logits"], output["kv_cache"]

   def forward(

@@ -141,6 +134,7 @@ class ReauthoredModelWrapper(ModelWrapper):
       prompts: torch.Tensor,
       max_new_tokens: int,
       pixel_values: torch.Tensor = None,
+      eos_token_id: int = 1,
   ) -> torch.IntTensor:
     input_ids = prompts[0].int().tolist()
     tokens = torch.tensor([input_ids])

@@ -152,6 +146,8 @@ class ReauthoredModelWrapper(ModelWrapper):
       )
       generated_token = logits[0][-1].argmax().item()
       input_ids.append(generated_token)
+      if generated_token == eos_token_id:
+        break
       tokens = torch.tensor([[generated_token]])
       input_pos = torch.tensor([len(input_ids) - 1])
       pixel_values = None  # Pass only for the first time.

@@ -254,7 +250,11 @@ def verify_model_with_prompts(
   logging.info("outputs_from_original_model: [[%s]]", response_original)

   logging.info("Generating answer with the reauthored model...")
-  outputs_reauthored = reauthored_model.generate(prompt_tokens, max_new_tokens)
+  outputs_reauthored = reauthored_model.generate(
+      prompt_tokens,
+      max_new_tokens,
+      eos_token_id=tokenizer.tokenizer.eos_token_id,
+  )
   response_reauthored = tokenizer.decode(outputs_reauthored[0])
   logging.info("outputs from reauthored model: [[%s]]", response_reauthored)


ai_edge_torch/odml_torch/export.py
@@ -198,7 +198,12 @@ class MlirLowered:
     # build, which may not have the same StableHLO version as what used in
     # TFLite converter. Therefore we always serialize MLIR module in VHLO.
     # TODO(b/362798610) Build MLIR pybinding in ai-edge-torch release.
-    target_version = stablehlo.get_minimum_version()
+    if stablehlo.get_api_version() < 9:
+      target_version = stablehlo.get_minimum_version()
+    else:
+      target_version = stablehlo.get_version_from_compatibility_requirement(
+          stablehlo.StablehloCompatibilityRequirement.WEEK_4
+      )
     module_bytecode = xla_extension.mlir.serialize_portable_artifact(
         self.module_bytecode, target_version
     )

ai_edge_torch/odml_torch/jax_bridge/__init__.py
@@ -12,4 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-from ai_edge_torch.odml_torch.jax_bridge …
+from ai_edge_torch.odml_torch.jax_bridge import _wrap
+from ai_edge_torch.odml_torch.jax_bridge import utils
+
+wrap = _wrap.wrap

ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py
@@ -26,6 +26,7 @@ import torch_xla2.ops.ops_registry # Import to load torch_xla2 ops

 LoweringContext = context.LoweringContext

+
 @functools.cache
 def _log_usage(op):
   logging.warning("Use jax lowering: %s", str(op))

@@ -184,8 +185,6 @@ lower_by_torch_xla2(torch.ops.aten.permute_copy)
 lower_by_torch_xla2(torch.ops.aten.pixel_shuffle)
 lower_by_torch_xla2(torch.ops.aten.pow)
 lower_by_torch_xla2(torch.ops.aten.prod)
-lower_by_torch_xla2(torch.ops.aten.rand)
-lower_by_torch_xla2(torch.ops.aten.randn)
 lower_by_torch_xla2(torch.ops.aten.reciprocal)
 lower_by_torch_xla2(torch.ops.aten.reflection_pad1d)
 lower_by_torch_xla2(torch.ops.aten.relu)