ai-edge-torch-nightly 0.3.0.dev20250205__py3-none-any.whl → 0.3.0.dev20250207__py3-none-any.whl
- ai_edge_torch/generative/examples/qwen_vl/decoder.py +49 -4
- ai_edge_torch/generative/examples/qwen_vl/image_encoder.py +8 -5
- ai_edge_torch/generative/examples/qwen_vl/qwen_vl.py +211 -0
- ai_edge_torch/generative/examples/qwen_vl/verify.py +143 -0
- ai_edge_torch/generative/examples/qwen_vl/verify_decoder.py +1 -1
- ai_edge_torch/odml_torch/debuginfo/__init__.py +1 -1
- ai_edge_torch/odml_torch/debuginfo/_build.py +24 -0
- ai_edge_torch/odml_torch/export.py +6 -1
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20250205.dist-info → ai_edge_torch_nightly-0.3.0.dev20250207.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20250205.dist-info → ai_edge_torch_nightly-0.3.0.dev20250207.dist-info}/RECORD +14 -12
- {ai_edge_torch_nightly-0.3.0.dev20250205.dist-info → ai_edge_torch_nightly-0.3.0.dev20250207.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20250205.dist-info → ai_edge_torch_nightly-0.3.0.dev20250207.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20250205.dist-info → ai_edge_torch_nightly-0.3.0.dev20250207.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/examples/qwen_vl/decoder.py CHANGED
@@ -15,16 +15,61 @@
 
 """Example of building decoder for Qwen 2.5 VL models."""
 
+from typing import Optional, Tuple
+
+from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.model_config as cfg
 from ai_edge_torch.generative.utilities import model_builder
-
+import torch
 
 TENSOR_NAMES = model_builder.TENSOR_NAMES
 
 
 class Decoder(model_builder.DecoderOnlyModel):
-  """A decoder for Qwen-VL model built from the Edge Generative API layers.
-
+  """A decoder for Qwen-VL model built from the Edge Generative API layers.
+
+  Besides a tensor of text token IDs, forward() can also take a tensor of
+  embeddings which may include text or image or both.
+  """
+
+  @torch.inference_mode
+  def forward(
+      self,
+      tokens: torch.Tensor,
+      input_pos: torch.Tensor,
+      kv_cache: kv_utils.KVCache,
+      input_embeds: torch.Tensor = None,
+      rope: Tuple[torch.Tensor, torch.Tensor] = None,
+      mask: Optional[torch.Tensor] = None,
+      export_config: Optional[model_builder.ExportConfig] = None,
+  ) -> dict[torch.Tensor, kv_utils.KVCache]:
+    if input_embeds is None:
+      _, seq_len = tokens.size()
+      assert self.config.max_seq_len >= seq_len, (
+          f"Cannot forward sequence of length {seq_len}, max seq length is only"
+          f" {self.config.max_seq_len}"
+      )
+      # token embeddings of shape (b, t, n_embd)
+      input_embeds = self.tok_embedding(tokens)
+
+    if rope is None:
+      # ROPE parameters for all attn_configs are the same. Take the first one.
+      attn_config = self.config.block_config(0).attn_config
+      n_elem = int(attn_config.rotary_percentage * attn_config.head_dim)
+      rope = self.config.build_rope(input_pos, n_elem, attn_config.rotary_base)
+
+    if mask is None:
+      mask = self.mask_cache.index_select(2, input_pos)
+      mask = mask[:, :, :, : self.config.kv_cache_max]
+
+    return self._forward_with_embeds(
+        input_embeds,
+        rope,
+        mask,
+        input_pos,
+        kv_cache,
+        export_config=export_config,
+    )
 
 
 def get_decoder_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
@@ -82,7 +127,7 @@ def get_fake_decoder_config(**kwargs) -> cfg.ModelConfig:
   return config
 
 
-def build_decoder(checkpoint_path: str, **kwargs) -> nn.Module:
+def build_decoder(checkpoint_path: str, **kwargs) -> torch.nn.Module:
   return model_builder.build_decoder_only_model(
       checkpoint_path=checkpoint_path,
       config=get_decoder_config(**kwargs),
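The reworked decoder can now be driven either from token IDs or from precomputed embeddings. A minimal sketch (not part of the diff), using the fake config and random weights purely to exercise the shapes:

import torch
from ai_edge_torch.generative.examples.qwen_vl import decoder
from ai_edge_torch.generative.layers import kv_cache as kv_utils

model = decoder.Decoder(decoder.get_fake_decoder_config())
tokens = torch.zeros((1, 8), dtype=torch.int)
input_pos = torch.arange(8, dtype=torch.int)
cache = kv_utils.KVCache.from_model_config(model.config)

# Text-only path: embeddings, RoPE and the causal mask are derived internally.
out = model(tokens, input_pos, cache)

# Embedding path: the caller supplies merged text/image embeddings (and, in the
# full-stack model below, a matching multimodal RoPE); tokens are not re-embedded.
embeds = model.tok_embedding(tokens)
out = model(tokens, input_pos, cache, input_embeds=embeds)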
ai_edge_torch/generative/examples/qwen_vl/image_encoder.py CHANGED
@@ -356,6 +356,12 @@ def get_fake_image_encoder_config() -> QwenVLImageConfig:
 def build_image_encoder(checkpoint_path: str) -> QwenVLImageEncoder:
   config = get_image_encoder_config()
   encoder = QwenVLImageEncoder(config)
+  load_image_encoder(checkpoint_path, encoder)
+  encoder.eval()
+  return encoder
+
+
+def load_image_encoder(checkpoint_path: str, encoder: QwenVLImageEncoder):
   loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
   # Loose the strictness because only image encoder is being loaded.
   loader.load(encoder, strict=False)
@@ -365,15 +371,12 @@ def build_image_encoder(checkpoint_path: str) -> QwenVLImageEncoder:
   state = merger_loader.get_state()
   w1_state = dict()
   w1_state["weight"] = state.pop(f"{MERGER_TENSOR_NAMES.ff_up_proj}.weight")
-  if config.merger_config.use_bias:
+  if encoder.config.merger_config.use_bias:
     w1_state["bias"] = state.pop(f"{MERGER_TENSOR_NAMES.ff_up_proj}.bias")
   encoder.merger.w1.load_state_dict(w1_state)
 
   w2_state = dict()
   w2_state["weight"] = state.pop(f"{MERGER_TENSOR_NAMES.ff_down_proj}.weight")
-  if config.merger_config.use_bias:
+  if encoder.config.merger_config.use_bias:
     w2_state["bias"] = state.pop(f"{MERGER_TENSOR_NAMES.ff_down_proj}.bias")
   encoder.merger.w2.load_state_dict(w2_state)
-
-  encoder.eval()
-  return encoder
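The refactor splits construction from weight loading so that the new full-stack model can reuse the loader on an encoder it already constructed. A short sketch of the two call paths (the checkpoint path is a placeholder):

from ai_edge_torch.generative.examples.qwen_vl import image_encoder

# Standalone use keeps the previous one-call behavior: build, load, eval().
encoder = image_encoder.build_image_encoder("/path/to/qwen_vl_checkpoint")

# qwen_vl.build_model() constructs the encoder itself and only borrows the
# loading step:
encoder = image_encoder.QwenVLImageEncoder(image_encoder.get_image_encoder_config())
image_encoder.load_image_encoder("/path/to/qwen_vl_checkpoint", encoder)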
ai_edge_torch/generative/examples/qwen_vl/qwen_vl.py ADDED
@@ -0,0 +1,211 @@
+# Copyright 2025 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Example of building a full-stack of Qwen 2.5 VL model."""
+
+import dataclasses
+from typing import List, Optional, Tuple
+
+from ai_edge_torch.generative.examples.qwen_vl import decoder
+from ai_edge_torch.generative.examples.qwen_vl import image_encoder
+import ai_edge_torch.generative.layers.kv_cache as kv_utils
+import ai_edge_torch.generative.layers.model_config as cfg
+from ai_edge_torch.generative.utilities import model_builder
+import ai_edge_torch.generative.utilities.loader as loading_utils
+import torch
+from torch import nn
+
+
+@dataclasses.dataclass
+class QwenVLConfig:
+  """Qwen VL model configurations."""
+
+  image_encoder_config: image_encoder.QwenVLImageConfig
+  decoder_config: cfg.ModelConfig
+  image_token_id: int
+  mrope_section: List[int]
+
+
+class QwenVL(nn.Module):
+  """Qwen VL model from the Edge Generative API."""
+
+  def __init__(self, config: QwenVLConfig):
+    super().__init__()
+
+    self.image_encoder = image_encoder.QwenVLImageEncoder(
+        config.image_encoder_config
+    )
+    self.decoder = decoder.Decoder(config.decoder_config)
+    # The amount of adjustment in input_pos to calculate RoPE properly in
+    # forward() calls after image is handled.
+    self.rope_pos_adjust = 0
+    self.config = config
+
+  @torch.inference_mode
+  def forward(
+      self,
+      tokens: torch.Tensor,
+      input_pos: torch.Tensor,
+      kv_cache: kv_utils.KVCache,
+      mask: Optional[torch.Tensor] = None,
+      pixel_values: torch.Tensor = None,
+      grid_thw: torch.Tensor = None,
+      export_config: Optional[model_builder.ExportConfig] = None,
+  ) -> dict[torch.Tensor, kv_utils.KVCache]:
+    if pixel_values is None:
+      return self.decoder(
+          tokens=tokens,
+          input_pos=input_pos,
+          kv_cache=kv_cache,
+          mask=mask,
+          rope=self._build_text_rope(input_pos),
+          input_embeds=None,
+          export_config=export_config,
+      )
+
+    input_embeds = self.decoder.tok_embedding(tokens)
+    image_embeds = self.image_encoder(pixel_values, grid_thw).unsqueeze(0)
+
+    # Merging image_embeds into text_embeds as PaliGemmaForConditionalGeneration
+    # can be done like:
+    #
+    #   image_mask = tokens == self.config.image_token_id
+    #   image_mask = image_mask.unsqueeze(-1).expand_as(input_embeds)
+    #   input_embeds = input_embeds.masked_scatter(image_mask, image_embeds)
+    #
+    # Unfortunately, torch.Tensor.masked_scatter can't be lowered on CPU.
+    # Assume that image is put at the beginning of the input sequence wrapped
+    # with vision_start and vision_end tokens.
+    input_embeds = torch.cat(
+        (
+            input_embeds[:, :1, :],
+            image_embeds,
+            input_embeds[:, image_embeds.shape[1] + 1 :, :],
+        ),
+        dim=1,
+    )
+
+    return self.decoder(
+        tokens=None,
+        input_pos=input_pos,
+        kv_cache=kv_cache,
+        mask=mask,
+        input_embeds=input_embeds,
+        rope=self._build_multimodal_rope(input_pos, grid_thw),
+        export_config=export_config,
+    )
+
+  def _build_rope(
+      self, rope_pos: torch.Tensor
+  ) -> Tuple[torch.Tensor, torch.Tensor]:
+    # ROPE parameters for all attn_configs are the same. Take the first one.
+    attn_config = self.config.decoder_config.block_config(0).attn_config
+    n_elem = int(attn_config.rotary_percentage * attn_config.head_dim)
+    return self.config.decoder_config.build_rope(
+        rope_pos, n_elem, attn_config.rotary_base
+    )
+
+  def _build_text_rope(
+      self, input_pos: torch.Tensor
+  ) -> Tuple[torch.Tensor, torch.Tensor]:
+    # Reset rope_pos_adjust to 0 when input sequence starts from scratch, i.e.
+    # input_pos[0] = 0.
+    if input_pos[0] == 0:
+      self.rope_pos_adjust = 0
+    return self._build_rope(input_pos + self.rope_pos_adjust)
+
+  def _build_multimodal_rope(
+      self, input_pos: torch.Tensor, grid_thw: torch.Tensor
+  ) -> Tuple[torch.Tensor, torch.Tensor]:
+    """Builds RoPE of multimodal input for the Qwen VL model.
+
+    It's copied from Qwen2_5_VLForConditionalGeneration.get_rope_index() and
+    simplified based on the assumption that an image is put at the beginning of
+    the input sequence with vision start and vision end tokens.
+    """
+    spatial_merge_size = self.config.image_encoder_config.spatial_merge_size
+    height = grid_thw[0][1] // spatial_merge_size
+    width = grid_thw[0][2] // spatial_merge_size
+    image_pos_max = max(height, width)
+    image_pos_count = height * width
+
+    # The position of vision end token and text tokens after the image.
+    text_pos_start = image_pos_max + 1
+    text_pos_count = len(input_pos) - image_pos_count - 1
+    text_pos = torch.arange(text_pos_start, text_pos_start + text_pos_count)
+    # Set rope_pos_adjust since text_pos_start has changed.
+    self.rope_pos_adjust = image_pos_max - image_pos_count
+
+    temporal_rope = self._build_image_text_rope(
+        torch.ones(image_pos_count, dtype=torch.int), text_pos
+    )
+    height_rope = self._build_image_text_rope(
+        torch.arange(1, height + 1).view(-1, 1).expand(-1, width).flatten(),
+        text_pos,
+    )
+    width_rope = self._build_image_text_rope(
+        torch.arange(1, width + 1).view(1, -1).expand(height, -1).flatten(),
+        text_pos,
+    )
+
+    return (
+        self._merge_ropes(temporal_rope[0], height_rope[0], width_rope[0]),
+        self._merge_ropes(temporal_rope[1], height_rope[1], width_rope[1]),
+    )
+
+  def _build_image_text_rope(
+      self, image_pos: torch.Tensor, text_pos: torch.Tensor
+  ) -> Tuple[torch.Tensor, torch.Tensor]:
+    return self._build_rope(
+        torch.cat((torch.zeros(1, dtype=torch.int), image_pos, text_pos))
+    )
+
+  def _merge_ropes(self, a: torch.Tensor, b: torch.Tensor, c: torch.Tensor):
+    """Merges RoPE tensors based on apply_multimodal_rotary_pos_emb()."""
+    split = torch.stack([a, b, c]).split(self.config.mrope_section, dim=-1)
+    return torch.cat([m[i % 3] for i, m in enumerate(split)], dim=-1)
+
+
+def get_model_config(**kwargs) -> QwenVLConfig:
+  """Returns the model config for a Qwen 2.5 VL 3B model.
+
+  Returns:
+    The model config for a Qwen 2.5 VL 3B model.
+  """
+  return QwenVLConfig(
+      image_encoder_config=image_encoder.get_image_encoder_config(),
+      decoder_config=decoder.get_decoder_config(**kwargs),
+      image_token_id=151655,
+      mrope_section=[16, 24, 24],
+  )
+
+
+def get_fake_model_config(**kwargs) -> QwenVLConfig:
+  return QwenVLConfig(
+      image_encoder_config=image_encoder.get_fake_image_encoder_config(),
+      decoder_config=decoder.get_fake_decoder_config(**kwargs),
+      image_token_id=127,
+  )
+
+
+def build_model(checkpoint_path: str, **kwargs) -> QwenVL:
+  config = get_model_config(**kwargs)
+  model = QwenVL(config)
+  image_encoder.load_image_encoder(checkpoint_path, model.image_encoder)
+  # Load the parameters of decoder.
+  loader = loading_utils.ModelLoader(checkpoint_path, decoder.TENSOR_NAMES)
+  loader.load(model.decoder, strict=False)
+  model.eval()
+  return model
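QwenVL.forward() splices the image embeddings into the text embeddings with torch.cat under the stated assumption that the image sits right after the first token, and _merge_ropes() then interleaves the temporal, height and width RoPE tables per mrope_section. A self-contained illustration of that splitting (not from the diff; shapes are illustrative, with mrope_section [16, 24, 24] summing to the 64 columns of the cos/sin tables):

import torch

t = torch.full((1, 10, 64), 0.0)  # temporal cos (or sin) table
h = torch.full((1, 10, 64), 1.0)  # height table
w = torch.full((1, 10, 64), 2.0)  # width table

split = torch.stack([t, h, w]).split([16, 24, 24], dim=-1)
merged = torch.cat([m[i % 3] for i, m in enumerate(split)], dim=-1)

# merged[..., :16] is taken from t, merged[..., 16:40] from h, and
# merged[..., 40:64] from w, mirroring apply_multimodal_rotary_pos_emb().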
ai_edge_torch/generative/examples/qwen_vl/verify.py ADDED
@@ -0,0 +1,143 @@
+# Copyright 2025 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Verifies the reauthored Qwen 2.5 VL model."""
+
+import logging
+import pathlib
+from absl import app
+from absl import flags
+from ai_edge_torch.generative.examples.qwen_vl import qwen_vl
+from ai_edge_torch.generative.layers import kv_cache
+from ai_edge_torch.generative.utilities import verifier
+from PIL import Image
+import requests
+import torch
+import transformers
+
+_IMAGE_URL = flags.DEFINE_string(
+    "image_url",
+    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg?download=true",
+    "The image URI to encode.",
+)
+_PROMPTS = flags.DEFINE_string(
+    "prompts",
+    "<|vision_start|><|image_pad|><|vision_end|>Describe the image.<|im_end|>",
+    "The input prompts to generate answers.",
+)
+_MAX_NEW_TOKENS = flags.DEFINE_integer(
+    "max_new_tokens",
+    30,
+    "The maximum size of the generated tokens.",
+)
+
+
+class ReauthoredQwenVLWrapper(verifier.ReauthoredModelWrapper):
+  """Reauthored Qwen VL model wrapper."""
+
+  def __init__(self, model: torch.nn.Module):
+    super().__init__(model)
+    self.grid_thw = None
+
+  def _init_kv_cache(self):
+    return kv_cache.KVCache.from_model_config(self.model.config.decoder_config)
+
+  def _get_extra_args_for_forward(self):
+    return {"grid_thw": self.grid_thw}
+
+
+def main(_):
+  checkpoint = "Qwen/Qwen2.5-VL-3B-Instruct"
+  logging.info("Loading the original model from: %s", checkpoint)
+  original_model = (
+      transformers.Qwen2_5_VLForConditionalGeneration.from_pretrained(
+          checkpoint
+      )
+  )
+
+  # Locate the cached dir.
+  cached_config_file = transformers.utils.cached_file(
+      checkpoint, transformers.utils.CONFIG_NAME
+  )
+  reauthored_checkpoint = pathlib.Path(cached_config_file).parent
+  logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
+  reauthored_model = qwen_vl.build_model(reauthored_checkpoint)
+
+  logging.info("Loading the processor from: %s", checkpoint)
+  processor = transformers.AutoProcessor.from_pretrained(checkpoint)
+
+  logging.info("Loading the image from: %s", _IMAGE_URL.value)
+  image = Image.open(requests.get(_IMAGE_URL.value, stream=True).raw)
+  inputs = processor(text=_PROMPTS.value, images=image, return_tensors="pt")
+
+  logging.info("Verifying the reauthored model with model.forward()...")
+  logging.info("Forwarding the original model...")
+  outputs_original = original_model.forward(
+      input_ids=inputs["input_ids"],
+      pixel_values=inputs["pixel_values"],
+      image_grid_thw=inputs["image_grid_thw"],
+  )
+  outputs_original = outputs_original.logits
+  logging.info("outputs_original: %s", outputs_original)
+
+  logging.info("Forwarding the reauthored model...")
+  wrapped_reauthored_model = ReauthoredQwenVLWrapper(reauthored_model)
+  wrapped_reauthored_model.grid_thw = inputs["image_grid_thw"]
+  outputs_reauthored = wrapped_reauthored_model.forward(
+      tokens=inputs["input_ids"],
+      pixel_values=inputs["pixel_values"],
+  )
+  logging.info("outputs_reauthored: %s", outputs_reauthored)
+
+  try:
+    assert torch.allclose(outputs_original, outputs_reauthored, atol=1e-01)
+  except AssertionError as e:
+    logging.error("*** FAILED *** verify with forward()")
+    raise e
+  else:
+    logging.info("*** PASSED *** verify with forward()")
+
+  logging.info("Verifying the reauthored model with model.generate()...")
+  logging.info("Generating answer with the original model...")
+  outputs_original = original_model.generate(
+      **inputs, max_new_tokens=_MAX_NEW_TOKENS.value
+  )
+  response_original = processor.decode(
+      outputs_original[0], skip_special_tokens=True
+  )
+  logging.info("outputs_from_original_model: [[%s]]", response_original)
+
+  logging.info("Generating answer with the reauthored model...")
+  outputs_reauthored = wrapped_reauthored_model.generate(
+      prompts=inputs["input_ids"],
+      pixel_values=inputs["pixel_values"],
+      max_new_tokens=_MAX_NEW_TOKENS.value,
+  )
+  response_reauthored = processor.decode(
+      outputs_reauthored[0], skip_special_tokens=True
+  )
+  logging.info("outputs from reauthored model: [[%s]]", response_reauthored)
+
+  try:
+    assert response_original == response_reauthored
+  except AssertionError as e:
+    logging.error("*** FAILED *** verify with generate()")
+    raise e
+  else:
+    logging.info("*** PASSED *** verify with generate()")
+
+
+if __name__ == "__main__":
+  app.run(main)
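The verifier is a regular absl script, so from a source checkout it is typically launched as python ai_edge_torch/generative/examples/qwen_vl/verify.py --max_new_tokens=30. A hedged sketch of driving the same flow programmatically (flag names come from the definitions above; running it downloads Qwen/Qwen2.5-VL-3B-Instruct):

from absl import flags

from ai_edge_torch.generative.examples.qwen_vl import verify

# Parse the script's flags the way absl would, then run its main().
flags.FLAGS(["verify", "--max_new_tokens=16"])
verify.main(None)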
ai_edge_torch/odml_torch/debuginfo/__init__.py CHANGED
@@ -12,5 +12,5 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-from ._build import build_mlir_debuginfo
+from ._build import build_mlir_debuginfo, build_mlir_file_debuginfo
 from ._op_polyfill import write_mlir_debuginfo_op
ai_edge_torch/odml_torch/debuginfo/_build.py CHANGED
@@ -13,6 +13,7 @@
 # limitations under the License.
 # ==============================================================================
 import torch
+import re
 
 
 def _class_fullname(cls):
@@ -34,6 +35,29 @@ def _get_hierarchy(node: torch.fx.Node):
   return hierachy_str
 
 
+def _get_canonical_filename(filename):
+  """Remove unnecessary path prefix to make the filename more readable.
+
+  This should be factored out so that pattern is a global option that a user
+  can override.
+  """
+
+  # TODO: We should add a config option to provide a regex to strip from the
+  # debug info. Currently absolute path is used.
+  return filename
+
+
+def build_mlir_file_debuginfo(node: torch.fx.Node):
+  """Build the file and line info for the given node's lowerings in MLIR."""
+
+  if not node.stack_trace:
+    return None, None
+
+  # Note: This uses internal APIs and may break in the future.
+  pt_trace = torch.fx.graph._parse_stack_trace(node.stack_trace)
+  return _get_canonical_filename(pt_trace.file), int(pt_trace.lineno)
+
+
 def build_mlir_debuginfo(node: torch.fx.Node):
   """Build the debuginfo string for the given node's lowerings in MLIR."""
 
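build_mlir_file_debuginfo() pulls the originating Python file and line out of an FX node's recorded stack trace, falling back to (None, None) when the node has no trace. A minimal sketch of exercising it (not from the diff; the toy module is illustrative):

import torch
from ai_edge_torch.odml_torch import debuginfo

class TinyModel(torch.nn.Module):
  def forward(self, x):
    return torch.sin(x)

ep = torch.export.export(TinyModel(), (torch.ones(4),))
# Pick any node that carries a stack trace (placeholders usually do not).
node = next(n for n in ep.graph.nodes if n.stack_trace)

file, line = debuginfo.build_mlir_file_debuginfo(node)
# file is the (currently absolute, per the TODO above) path of this script and
# line points at the torch.sin(x) call; export.py below attaches them as a
# file location nested under the op's named MLIR location.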
ai_edge_torch/odml_torch/export.py CHANGED
@@ -93,7 +93,12 @@ class LoweringInterpreter(torch.fx.Interpreter):
     if info is None:
       return ir.Location.unknown()
 
-    return ir.Location.name(name=info)
+    (file, line) = debuginfo.build_mlir_file_debuginfo(node)
+    fileinfo = None
+    if file is not None:
+      fileinfo = ir.Location.file(filename=file, line=line, col=0)
+
+    return ir.Location.name(name=info, childLoc=fileinfo)
 
   def run_node(self, node: torch.fx.Node):
     loc = self._build_loc(node)
ai_edge_torch/version.py CHANGED

{ai_edge_torch_nightly-0.3.0.dev20250205.dist-info → ai_edge_torch_nightly-0.3.0.dev20250207.dist-info}/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: ai-edge-torch-nightly
-Version: 0.3.0.dev20250205
+Version: 0.3.0.dev20250207
 Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
 Home-page: https://github.com/google-ai-edge/ai-edge-torch
 Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
{ai_edge_torch_nightly-0.3.0.dev20250205.dist-info → ai_edge_torch_nightly-0.3.0.dev20250207.dist-info}/RECORD CHANGED
@@ -2,7 +2,7 @@ ai_edge_torch/__init__.py,sha256=8sPR_5uXJA4NEE0nIwNdSl-ADOJEoR8hAgYvBQDY70Y,120
 ai_edge_torch/_config.py,sha256=AiqhbcheF7j_ozIGDLC89k1we95aVgFDa-tR6h7UI0s,2529
 ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
 ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
-ai_edge_torch/version.py,sha256=
+ai_edge_torch/version.py,sha256=9V9FbxtqLT70Tzmv_G0qlbqixmVc0pPPJs22C_iBlHE,706
 ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/_convert/conversion.py,sha256=mckvxznKLXdF2HuJg_IxQaT5Ty-iWl_iXElHEugH3VI,5452
 ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -94,9 +94,11 @@ ai_edge_torch/generative/examples/qwen/convert_to_tflite.py,sha256=tqvXVGNdDehda
 ai_edge_torch/generative/examples/qwen/qwen.py,sha256=Zi_qiQ1JPokXZ95jgSEnQp3F-LKzFCvWvFLKhJjnASo,4199
 ai_edge_torch/generative/examples/qwen/verify.py,sha256=9_AyEJTeUfvhhID64Rto2bflFPyXMFokdQLsseLUMiI,2775
 ai_edge_torch/generative/examples/qwen_vl/__init__.py,sha256=JaAnrFoXTl3RJX97XspklkTyqOHVyAgRJsZtzNDd10c,671
-ai_edge_torch/generative/examples/qwen_vl/decoder.py,sha256=
-ai_edge_torch/generative/examples/qwen_vl/image_encoder.py,sha256=
-ai_edge_torch/generative/examples/qwen_vl/
+ai_edge_torch/generative/examples/qwen_vl/decoder.py,sha256=0x4iDg2cBe3PFnjVce3nj7g2rjagGHcKqRCfbASNxA8,4402
+ai_edge_torch/generative/examples/qwen_vl/image_encoder.py,sha256=OYyF0bLVYJno9azmKDqX3gT8ojYYWEyp_F8nLtltPWs,13544
+ai_edge_torch/generative/examples/qwen_vl/qwen_vl.py,sha256=Uzl1ZPkdYIaHN9QxezqxNwagZiGOHf1VreWnqgRQwf8,7627
+ai_edge_torch/generative/examples/qwen_vl/verify.py,sha256=2GPi0Vay4a69EwBSOfPMCMjE9PTwPOQus5j2KN7HE7I,5031
+ai_edge_torch/generative/examples/qwen_vl/verify_decoder.py,sha256=xPWoOBLh2eK12KEhELLYymfL7xvc0chmYC98c6x37oo,2602
 ai_edge_torch/generative/examples/qwen_vl/verify_image_encoder.py,sha256=lQR8p6Zp7PxDN_erMf-FKLIn_Rv4BGyQHjDbModFkeY,2946
 ai_edge_torch/generative/examples/smollm/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/smollm/convert_to_tflite.py,sha256=megskv1oiPhwHSnguoG7zV-esXp1Ns_FPeMLAYKhDb0,2522
@@ -195,14 +197,14 @@ ai_edge_torch/lowertools/translate_recipe.py,sha256=ymkBpFqAUiupRWqrPOWiVphKcXR1
 ai_edge_torch/odml_torch/__init__.py,sha256=S8jOzE9nLof-6es3XDiGJRN-9H_XTxsVm9dE7lD3RWo,812
 ai_edge_torch/odml_torch/_torch_future.py,sha256=jSYHf1CMTJzMizPMbu2b39hAt0ZTR6gQLq67GMe9KTo,2336
 ai_edge_torch/odml_torch/_torch_library.py,sha256=Lw1gqL2HWNRspdTwNhIkYAHDyafHedHtkXyKKxn-Wss,805
-ai_edge_torch/odml_torch/export.py,sha256=
+ai_edge_torch/odml_torch/export.py,sha256=LDyZUehM1lmT3y2bGeA94rMGRUTLxzIUm4DTlCA8tQc,13640
 ai_edge_torch/odml_torch/export_utils.py,sha256=QeA37Irlty6AiIBuqmHmJgn3lqahBQ5xsh6IKRoKm1g,4774
 ai_edge_torch/odml_torch/tf_integration.py,sha256=NN29WeXmHZ0S1RPDFHUnBi2DEjMvAtwczStPYIsQ1w8,4849
 ai_edge_torch/odml_torch/composite/__init__.py,sha256=71GM_gDZxJyo38ZSoYSwhZX3xKA9rknO93JS9kw9w_c,778
 ai_edge_torch/odml_torch/composite/mark_tensor.py,sha256=U--rwl-XkWKgkdXCXDn6yySug8FR66o1YFUAIoSaWW4,3523
 ai_edge_torch/odml_torch/composite/stablehlo_composite_builder.py,sha256=2Y52E_gLeoXpMcPpV-svXsgN3JbEIjnPVjm0xkpTUdQ,3319
-ai_edge_torch/odml_torch/debuginfo/__init__.py,sha256=
-ai_edge_torch/odml_torch/debuginfo/_build.py,sha256=
+ai_edge_torch/odml_torch/debuginfo/__init__.py,sha256=DoE3HgAtV_GNKGBDGzH2Lb7JUHvyH7TUqWbDZIObr34,789
+ai_edge_torch/odml_torch/debuginfo/_build.py,sha256=sjpYeqgdbDmD7lhp80yc8jfWq-HxX3xuQ58ND8ZeU-I,2213
 ai_edge_torch/odml_torch/debuginfo/_op_polyfill.py,sha256=IvOBQyROI9WHS3umHRxsDW-1YElU9BPWzKtJA2eKWOI,1739
 ai_edge_torch/odml_torch/jax_bridge/__init__.py,sha256=e9Oa4J3An9FYr3zM0OzjzyNNitEeg-IoBUGNfUxsDSA,798
 ai_edge_torch/odml_torch/jax_bridge/_wrap.py,sha256=LqwZ1vCJTSOzgzvH8LUAN-sAkF-l_pGj1AMEIzAqHCA,6638
@@ -227,8 +229,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
 ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
 ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
-ai_edge_torch_nightly-0.3.0.
-ai_edge_torch_nightly-0.3.0.
-ai_edge_torch_nightly-0.3.0.
-ai_edge_torch_nightly-0.3.0.
-ai_edge_torch_nightly-0.3.0.
+ai_edge_torch_nightly-0.3.0.dev20250207.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
+ai_edge_torch_nightly-0.3.0.dev20250207.dist-info/METADATA,sha256=pvcJfgIOezx3rNegfvMIVrkFXmZuqnnE_zMzC9Wt37k,1966
+ai_edge_torch_nightly-0.3.0.dev20250207.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
+ai_edge_torch_nightly-0.3.0.dev20250207.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
+ai_edge_torch_nightly-0.3.0.dev20250207.dist-info/RECORD,,