ai-edge-torch-nightly 0.3.0.dev20250205__py3-none-any.whl → 0.3.0.dev20250207__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_torch/generative/examples/qwen_vl/decoder.py +49 -4
- ai_edge_torch/generative/examples/qwen_vl/image_encoder.py +8 -5
- ai_edge_torch/generative/examples/qwen_vl/qwen_vl.py +211 -0
- ai_edge_torch/generative/examples/qwen_vl/verify.py +143 -0
- ai_edge_torch/generative/examples/qwen_vl/verify_decoder.py +1 -1
- ai_edge_torch/odml_torch/debuginfo/__init__.py +1 -1
- ai_edge_torch/odml_torch/debuginfo/_build.py +24 -0
- ai_edge_torch/odml_torch/export.py +6 -1
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20250205.dist-info → ai_edge_torch_nightly-0.3.0.dev20250207.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20250205.dist-info → ai_edge_torch_nightly-0.3.0.dev20250207.dist-info}/RECORD +14 -12
- {ai_edge_torch_nightly-0.3.0.dev20250205.dist-info → ai_edge_torch_nightly-0.3.0.dev20250207.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20250205.dist-info → ai_edge_torch_nightly-0.3.0.dev20250207.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20250205.dist-info → ai_edge_torch_nightly-0.3.0.dev20250207.dist-info}/top_level.txt +0 -0
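
Taken together, this nightly bumps the version from 0.3.0.dev20250205 to 0.3.0.dev20250207 and makes two sets of changes: the Qwen 2.5 VL example gains a full-stack model (qwen_vl.py) plus a verification script (verify.py), with the decoder extended to accept precomputed text/image embeddings and a caller-supplied multimodal RoPE and the image encoder's weight loading split out for reuse; and odml_torch starts attaching file/line debug info from FX node stack traces to the MLIR locations it emits during export.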

ai_edge_torch/generative/examples/qwen_vl/decoder.py
CHANGED
@@ -15,16 +15,61 @@
 
 """Example of building decoder for Qwen 2.5 VL models."""
 
+from typing import Optional, Tuple
+
+from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.model_config as cfg
 from ai_edge_torch.generative.utilities import model_builder
-
+import torch
 
 TENSOR_NAMES = model_builder.TENSOR_NAMES
 
 
 class Decoder(model_builder.DecoderOnlyModel):
-  """A decoder for Qwen-VL model built from the Edge Generative API layers.
-
+  """A decoder for Qwen-VL model built from the Edge Generative API layers.
+
+  Besides a tensor of text token IDs, forward() can also take a tensor of
+  embeddings which may include text or image or both.
+  """
+
+  @torch.inference_mode
+  def forward(
+      self,
+      tokens: torch.Tensor,
+      input_pos: torch.Tensor,
+      kv_cache: kv_utils.KVCache,
+      input_embeds: torch.Tensor = None,
+      rope: Tuple[torch.Tensor, torch.Tensor] = None,
+      mask: Optional[torch.Tensor] = None,
+      export_config: Optional[model_builder.ExportConfig] = None,
+  ) -> dict[torch.Tensor, kv_utils.KVCache]:
+    if input_embeds is None:
+      _, seq_len = tokens.size()
+      assert self.config.max_seq_len >= seq_len, (
+          f"Cannot forward sequence of length {seq_len}, max seq length is only"
+          f" {self.config.max_seq_len}"
+      )
+      # token embeddings of shape (b, t, n_embd)
+      input_embeds = self.tok_embedding(tokens)
+
+    if rope is None:
+      # ROPE parameters for all attn_configs are the same. Take the first one.
+      attn_config = self.config.block_config(0).attn_config
+      n_elem = int(attn_config.rotary_percentage * attn_config.head_dim)
+      rope = self.config.build_rope(input_pos, n_elem, attn_config.rotary_base)
+
+    if mask is None:
+      mask = self.mask_cache.index_select(2, input_pos)
+      mask = mask[:, :, :, : self.config.kv_cache_max]
+
+    return self._forward_with_embeds(
+        input_embeds,
+        rope,
+        mask,
+        input_pos,
+        kv_cache,
+        export_config=export_config,
+    )
 
 
 def get_decoder_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
@@ -82,7 +127,7 @@ def get_fake_decoder_config(**kwargs) -> cfg.ModelConfig:
   return config
 
 
-def build_decoder(checkpoint_path: str, **kwargs) -> nn.Module:
+def build_decoder(checkpoint_path: str, **kwargs) -> torch.nn.Module:
   return model_builder.build_decoder_only_model(
       checkpoint_path=checkpoint_path,
       config=get_decoder_config(**kwargs),

ai_edge_torch/generative/examples/qwen_vl/image_encoder.py
CHANGED
@@ -356,6 +356,12 @@ def get_fake_image_encoder_config() -> QwenVLImageConfig:
 def build_image_encoder(checkpoint_path: str) -> QwenVLImageEncoder:
   config = get_image_encoder_config()
   encoder = QwenVLImageEncoder(config)
+  load_image_encoder(checkpoint_path, encoder)
+  encoder.eval()
+  return encoder
+
+
+def load_image_encoder(checkpoint_path: str, encoder: QwenVLImageEncoder):
   loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
   # Loose the strictness because only image encoder is being loaded.
   loader.load(encoder, strict=False)
@@ -365,15 +371,12 @@ def build_image_encoder(checkpoint_path: str) -> QwenVLImageEncoder:
   state = merger_loader.get_state()
   w1_state = dict()
   w1_state["weight"] = state.pop(f"{MERGER_TENSOR_NAMES.ff_up_proj}.weight")
-  if config.merger_config.use_bias:
+  if encoder.config.merger_config.use_bias:
     w1_state["bias"] = state.pop(f"{MERGER_TENSOR_NAMES.ff_up_proj}.bias")
   encoder.merger.w1.load_state_dict(w1_state)
 
   w2_state = dict()
   w2_state["weight"] = state.pop(f"{MERGER_TENSOR_NAMES.ff_down_proj}.weight")
-  if config.merger_config.use_bias:
+  if encoder.config.merger_config.use_bias:
     w2_state["bias"] = state.pop(f"{MERGER_TENSOR_NAMES.ff_down_proj}.bias")
   encoder.merger.w2.load_state_dict(w2_state)
-
-  encoder.eval()
-  return encoder

ai_edge_torch/generative/examples/qwen_vl/qwen_vl.py
ADDED
@@ -0,0 +1,211 @@
+# Copyright 2025 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Example of building a full-stack of Qwen 2.5 VL model."""
+
+import dataclasses
+from typing import List, Optional, Tuple
+
+from ai_edge_torch.generative.examples.qwen_vl import decoder
+from ai_edge_torch.generative.examples.qwen_vl import image_encoder
+import ai_edge_torch.generative.layers.kv_cache as kv_utils
+import ai_edge_torch.generative.layers.model_config as cfg
+from ai_edge_torch.generative.utilities import model_builder
+import ai_edge_torch.generative.utilities.loader as loading_utils
+import torch
+from torch import nn
+
+
+@dataclasses.dataclass
+class QwenVLConfig:
+  """Qwen VL model configurations."""
+
+  image_encoder_config: image_encoder.QwenVLImageConfig
+  decoder_config: cfg.ModelConfig
+  image_token_id: int
+  mrope_section: List[int]
+
+
+class QwenVL(nn.Module):
+  """Qwen VL model from the Edge Generative API."""
+
+  def __init__(self, config: QwenVLConfig):
+    super().__init__()
+
+    self.image_encoder = image_encoder.QwenVLImageEncoder(
+        config.image_encoder_config
+    )
+    self.decoder = decoder.Decoder(config.decoder_config)
+    # The amount of adjustment in input_pos to calculate RoPE properly in
+    # forward() calls after image is handled.
+    self.rope_pos_adjust = 0
+    self.config = config
+
+  @torch.inference_mode
+  def forward(
+      self,
+      tokens: torch.Tensor,
+      input_pos: torch.Tensor,
+      kv_cache: kv_utils.KVCache,
+      mask: Optional[torch.Tensor] = None,
+      pixel_values: torch.Tensor = None,
+      grid_thw: torch.Tensor = None,
+      export_config: Optional[model_builder.ExportConfig] = None,
+  ) -> dict[torch.Tensor, kv_utils.KVCache]:
+    if pixel_values is None:
+      return self.decoder(
+          tokens=tokens,
+          input_pos=input_pos,
+          kv_cache=kv_cache,
+          mask=mask,
+          rope=self._build_text_rope(input_pos),
+          input_embeds=None,
+          export_config=export_config,
+      )
+
+    input_embeds = self.decoder.tok_embedding(tokens)
+    image_embeds = self.image_encoder(pixel_values, grid_thw).unsqueeze(0)
+
+    # Merging image_embeds into text_embeds as PaliGemmaForConditionalGeneration
+    # can be done like:
+    #
+    # image_mask = tokens == self.config.image_token_id
+    # image_mask = image_mask.unsqueeze(-1).expand_as(input_embeds)
+    # input_embeds = input_embeds.masked_scatter(image_mask, image_embeds)
+    #
+    # Unfortunately, torch.Tensor.masked_scatter can't be lowered on CPU.
+    # Assume that image is put at the beginning of the input sequence wrapped
+    # with vision_start and vision_end tokens.
+    input_embeds = torch.cat(
+        (
+            input_embeds[:, :1, :],
+            image_embeds,
+            input_embeds[:, image_embeds.shape[1] + 1 :, :],
+        ),
+        dim=1,
+    )
+
+    return self.decoder(
+        tokens=None,
+        input_pos=input_pos,
+        kv_cache=kv_cache,
+        mask=mask,
+        input_embeds=input_embeds,
+        rope=self._build_multimodal_rope(input_pos, grid_thw),
+        export_config=export_config,
+    )
+
+  def _build_rope(
+      self, rope_pos: torch.Tensor
+  ) -> Tuple[torch.Tensor, torch.Tensor]:
+    # ROPE parameters for all attn_configs are the same. Take the first one.
+    attn_config = self.config.decoder_config.block_config(0).attn_config
+    n_elem = int(attn_config.rotary_percentage * attn_config.head_dim)
+    return self.config.decoder_config.build_rope(
+        rope_pos, n_elem, attn_config.rotary_base
+    )
+
+  def _build_text_rope(
+      self, input_pos: torch.Tensor
+  ) -> Tuple[torch.Tensor, torch.Tensor]:
+    # Reset rope_pos_adjust to 0 when input sequence starts from scratch, i.e.
+    # input_pos[0] = 0.
+    if input_pos[0] == 0:
+      self.rope_pos_adjust = 0
+    return self._build_rope(input_pos + self.rope_pos_adjust)
+
+  def _build_multimodal_rope(
+      self, input_pos: torch.Tensor, grid_thw: torch.Tensor
+  ) -> Tuple[torch.Tensor, torch.Tensor]:
+    """Builds RoPE of multimodal input for the Qwen VL model.
+
+    It's copied from Qwen2_5_VLForConditionalGeneration.get_rope_index() and
+    simplified based on the assumption that an image is put at the beginning of
+    the input sequence with vision start and vision end tokens.
+    """
+    spatial_merge_size = self.config.image_encoder_config.spatial_merge_size
+    height = grid_thw[0][1] // spatial_merge_size
+    width = grid_thw[0][2] // spatial_merge_size
+    image_pos_max = max(height, width)
+    image_pos_count = height * width
+
+    # The position of vision end tokek and text tokens and after the image.
+    text_pos_start = image_pos_max + 1
+    text_pos_count = len(input_pos) - image_pos_count - 1
+    text_pos = torch.arange(text_pos_start, text_pos_start + text_pos_count)
+    # Set input_pos_adjust since text_pos_start has changed.
+    self.rope_pos_adjust = image_pos_max - image_pos_count
+
+    temporal_rope = self._build_image_text_rope(
+        torch.ones(image_pos_count, dtype=torch.int), text_pos
+    )
+    height_rope = self._build_image_text_rope(
+        torch.arange(1, height + 1).view(-1, 1).expand(-1, width).flatten(),
+        text_pos,
+    )
+    width_rope = self._build_image_text_rope(
+        torch.arange(1, width + 1).view(1, -1).expand(height, -1).flatten(),
+        text_pos,
+    )
+
+    return (
+        self._merge_ropes(temporal_rope[0], height_rope[0], width_rope[0]),
+        self._merge_ropes(temporal_rope[1], height_rope[1], width_rope[1]),
+    )
+
+  def _build_image_text_rope(
+      self, image_pos: torch.Tensor, text_pos: torch.Tensor
+  ) -> Tuple[torch.Tensor, torch.Tensor]:
+    return self._build_rope(
+        torch.cat((torch.zeros(1, dtype=torch.int), image_pos, text_pos))
+    )
+
+  def _merge_ropes(self, a: torch.Tensor, b: torch.Tensor, c: torch.Tensor):
+    """Merges RoPE tensors based on apply_multimodal_rotary_pos_emb()."""
+    split = torch.stack([a, b, c]).split(self.config.mrope_section, dim=-1)
+    return torch.cat([m[i % 3] for i, m in enumerate(split)], dim=-1)
+
+
+def get_model_config(**kwargs) -> QwenVLConfig:
+  """Returns the model config for a PaliGemma 3B-224 model.
+
+  Returns:
+    The model config for a PaliGemma 3B model.
+  """
+  return QwenVLConfig(
+      image_encoder_config=image_encoder.get_image_encoder_config(),
+      decoder_config=decoder.get_decoder_config(**kwargs),
+      image_token_id=151655,
+      mrope_section=[16, 24, 24],
+  )
+
+
+def get_fake_model_config(**kwargs) -> QwenVLConfig:
+  return QwenVLConfig(
+      image_encoder_config=image_encoder.get_fake_image_encoder_config(),
+      decoder_config=decoder.get_fake_decoder_config(**kwargs),
+      image_token_id=127,
+  )
+
+
+def build_model(checkpoint_path: str, **kwargs) -> QwenVL:
+  config = get_model_config(**kwargs)
+  model = QwenVL(config)
+  image_encoder.load_image_encoder(checkpoint_path, model.image_encoder)
+  # Load the parameters of decoder.
+  loader = loading_utils.ModelLoader(checkpoint_path, decoder.TENSOR_NAMES)
+  loader.load(model.decoder, strict=False)
+  model.eval()
+  return model

ai_edge_torch/generative/examples/qwen_vl/verify.py
ADDED
@@ -0,0 +1,143 @@
+# Copyright 2025 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Verifies the reauthored Qwen 2.5 VL model."""
+
+import logging
+import pathlib
+from absl import app
+from absl import flags
+from ai_edge_torch.generative.examples.qwen_vl import qwen_vl
+from ai_edge_torch.generative.layers import kv_cache
+from ai_edge_torch.generative.utilities import verifier
+from PIL import Image
+import requests
+import torch
+import transformers
+
+_IMAGE_URL = flags.DEFINE_string(
+    "image_url",
+    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg?download=true",
+    "The image URI to encode.",
+)
+_PROMPTS = flags.DEFINE_string(
+    "prompts",
+    "<|vision_start|><|image_pad|><|vision_end|>Describe the image.<|im_end|>",
+    "The input prompts to generate answers.",
+)
+_MAX_NEW_TOKENS = flags.DEFINE_integer(
+    "max_new_tokens",
+    30,
+    "The maximum size of the generated tokens.",
+)
+
+
+class ReauthoredQwenVLWrapper(verifier.ReauthoredModelWrapper):
+  """Reauthored Qwen VL model wrapper."""
+
+  def __init__(self, model: torch.nn.Module):
+    super().__init__(model)
+    self.grid_thw = None
+
+  def _init_kv_cache(self):
+    return kv_cache.KVCache.from_model_config(self.model.config.decoder_config)
+
+  def _get_extra_args_for_forward(self):
+    return {"grid_thw": self.grid_thw}
+
+
+def main(_):
+  checkpoint = "Qwen/Qwen2.5-VL-3B-Instruct"
+  logging.info("Loading the original model from: %s", checkpoint)
+  original_model = (
+      transformers.Qwen2_5_VLForConditionalGeneration.from_pretrained(
+          checkpoint
+      )
+  )
+
+  # Locate the cached dir.
+  cached_config_file = transformers.utils.cached_file(
+      checkpoint, transformers.utils.CONFIG_NAME
+  )
+  reauthored_checkpoint = pathlib.Path(cached_config_file).parent
+  logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
+  reauthored_model = qwen_vl.build_model(reauthored_checkpoint)
+
+  logging.info("Loading the processor from: %s", checkpoint)
+  processor = transformers.AutoProcessor.from_pretrained(checkpoint)
+
+  logging.info("Loading the image from: %s", _IMAGE_URL.value)
+  image = Image.open(requests.get(_IMAGE_URL.value, stream=True).raw)
+  inputs = processor(text=_PROMPTS.value, images=image, return_tensors="pt")
+
+  logging.info("Verifying the reauthored model with model.forward()...")
+  logging.info("Forwarding the original model...")
+  outputs_original = original_model.forward(
+      input_ids=inputs["input_ids"],
+      pixel_values=inputs["pixel_values"],
+      image_grid_thw=inputs["image_grid_thw"],
+  )
+  outputs_original = outputs_original.logits
+  logging.info("outputs_original: %s", outputs_original)
+
+  logging.info("Forwarding the reauthored model...")
+  wrapped_reauthored_model = ReauthoredQwenVLWrapper(reauthored_model)
+  wrapped_reauthored_model.grid_thw = inputs["image_grid_thw"]
+  outputs_reauthored = wrapped_reauthored_model.forward(
+      tokens=inputs["input_ids"],
+      pixel_values=inputs["pixel_values"],
+  )
+  logging.info("outputs_reauthored: %s", outputs_reauthored)
+
+  try:
+    assert torch.allclose(outputs_original, outputs_reauthored, atol=1e-01)
+  except AssertionError as e:
+    logging.error("*** FAILED *** verify with forward()")
+    raise e
+  else:
+    logging.info("*** PASSED *** verify with forward()")
+
+  logging.info("Verifying the reauthored model with model.generate()...")
+  logging.info("Generating answer with the original model...")
+  outputs_original = original_model.generate(
+      **inputs, max_new_tokens=_MAX_NEW_TOKENS.value
+  )
+  response_original = processor.decode(
+      outputs_original[0], skip_special_tokens=True
+  )
+  logging.info("outputs_from_original_model: [[%s]]", response_original)
+
+  logging.info("Generating answer with the reauthored model...")
+  outputs_reauthored = wrapped_reauthored_model.generate(
+      prompts=inputs["input_ids"],
+      pixel_values=inputs["pixel_values"],
+      max_new_tokens=_MAX_NEW_TOKENS.value,
+  )
+  response_reauthored = processor.decode(
+      outputs_reauthored[0], skip_special_tokens=True
+  )
+  logging.info("outputs from reauthored model: [[%s]]", response_reauthored)
+
+  try:
+    assert response_original == response_reauthored
+  except AssertionError as e:
+    logging.error("*** FAILED *** verify with generate()")
+    raise e
+  else:
+    logging.info("*** PASSED *** verify with generate()")
+
+
+if __name__ == "__main__":
+  app.run(main)

ai_edge_torch/odml_torch/debuginfo/__init__.py
CHANGED
@@ -12,5 +12,5 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-from ._build import build_mlir_debuginfo
+from ._build import build_mlir_debuginfo, build_mlir_file_debuginfo
 from ._op_polyfill import write_mlir_debuginfo_op

ai_edge_torch/odml_torch/debuginfo/_build.py
CHANGED
@@ -13,6 +13,7 @@
 # limitations under the License.
 # ==============================================================================
 import torch
+import re
 
 
 def _class_fullname(cls):
@@ -34,6 +35,29 @@ def _get_hierarchy(node: torch.fx.Node):
   return hierachy_str
 
 
+def _get_canonical_filename(filename):
+  """Remove unnecessary path prefix to make the filename more readable.
+
+  This should be factored out so that pattern is a global option that a user
+  can override.
+  """
+
+  # TODO: We should add a config option to provide a regex to strip from the
+  # debug info. Currently absolute path is used.
+  return filename
+
+
+def build_mlir_file_debuginfo(node: torch.fx.Node):
+  """Build the file and line info for the given node's lowerings in MLIR."""
+
+  if not node.stack_trace:
+    return None, None
+
+  # Note: This uses internal APIs and may break in the future.
+  pt_trace = torch.fx.graph._parse_stack_trace(node.stack_trace)
+  return _get_canonical_filename(pt_trace.file), int(pt_trace.lineno)
+
+
 def build_mlir_debuginfo(node: torch.fx.Node):
   """Build the debuginfo string for the given node's lowerings in MLIR."""
 
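
The newly added `import re` does not appear to be used yet in the hunks shown; _get_canonical_filename() currently returns the path unchanged and the TODO points at a future, user-configurable regex for stripping prefixes. A hypothetical sketch of what that stripping could look like (the pattern and helper name are illustrative only, not part of the package):

import re

# Hypothetical: strip a site-packages-style prefix so MLIR debug locations
# stay short and stable across machines.
_PREFIX_PATTERN = re.compile(r"^.*/site-packages/")


def _strip_filename_prefix(filename: str) -> str:
  return _PREFIX_PATTERN.sub("", filename)


print(_strip_filename_prefix("/usr/lib/python3.11/site-packages/my_pkg/model.py"))
# my_pkg/model.py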

ai_edge_torch/odml_torch/export.py
CHANGED
@@ -93,7 +93,12 @@ class LoweringInterpreter(torch.fx.Interpreter):
     if info is None:
       return ir.Location.unknown()
 
-
+    (file, line) = debuginfo.build_mlir_file_debuginfo(node)
+    fileinfo = None
+    if file is not None:
+      fileinfo = ir.Location.file(filename=file, line=line, col=0)
+
+    return ir.Location.name(name=info, childLoc=fileinfo)
 
   def run_node(self, node: torch.fx.Node):
     loc = self._build_loc(node)

ai_edge_torch/version.py
CHANGED

{ai_edge_torch_nightly-0.3.0.dev20250205.dist-info → ai_edge_torch_nightly-0.3.0.dev20250207.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: ai-edge-torch-nightly
-Version: 0.3.0.dev20250205
+Version: 0.3.0.dev20250207
 Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
 Home-page: https://github.com/google-ai-edge/ai-edge-torch
 Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI

{ai_edge_torch_nightly-0.3.0.dev20250205.dist-info → ai_edge_torch_nightly-0.3.0.dev20250207.dist-info}/RECORD
CHANGED
@@ -2,7 +2,7 @@ ai_edge_torch/__init__.py,sha256=8sPR_5uXJA4NEE0nIwNdSl-ADOJEoR8hAgYvBQDY70Y,120
 ai_edge_torch/_config.py,sha256=AiqhbcheF7j_ozIGDLC89k1we95aVgFDa-tR6h7UI0s,2529
 ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
 ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
-ai_edge_torch/version.py,sha256=
+ai_edge_torch/version.py,sha256=9V9FbxtqLT70Tzmv_G0qlbqixmVc0pPPJs22C_iBlHE,706
 ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/_convert/conversion.py,sha256=mckvxznKLXdF2HuJg_IxQaT5Ty-iWl_iXElHEugH3VI,5452
 ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -94,9 +94,11 @@ ai_edge_torch/generative/examples/qwen/convert_to_tflite.py,sha256=tqvXVGNdDehda
 ai_edge_torch/generative/examples/qwen/qwen.py,sha256=Zi_qiQ1JPokXZ95jgSEnQp3F-LKzFCvWvFLKhJjnASo,4199
 ai_edge_torch/generative/examples/qwen/verify.py,sha256=9_AyEJTeUfvhhID64Rto2bflFPyXMFokdQLsseLUMiI,2775
 ai_edge_torch/generative/examples/qwen_vl/__init__.py,sha256=JaAnrFoXTl3RJX97XspklkTyqOHVyAgRJsZtzNDd10c,671
-ai_edge_torch/generative/examples/qwen_vl/decoder.py,sha256=
-ai_edge_torch/generative/examples/qwen_vl/image_encoder.py,sha256=
-ai_edge_torch/generative/examples/qwen_vl/
+ai_edge_torch/generative/examples/qwen_vl/decoder.py,sha256=0x4iDg2cBe3PFnjVce3nj7g2rjagGHcKqRCfbASNxA8,4402
+ai_edge_torch/generative/examples/qwen_vl/image_encoder.py,sha256=OYyF0bLVYJno9azmKDqX3gT8ojYYWEyp_F8nLtltPWs,13544
+ai_edge_torch/generative/examples/qwen_vl/qwen_vl.py,sha256=Uzl1ZPkdYIaHN9QxezqxNwagZiGOHf1VreWnqgRQwf8,7627
+ai_edge_torch/generative/examples/qwen_vl/verify.py,sha256=2GPi0Vay4a69EwBSOfPMCMjE9PTwPOQus5j2KN7HE7I,5031
+ai_edge_torch/generative/examples/qwen_vl/verify_decoder.py,sha256=xPWoOBLh2eK12KEhELLYymfL7xvc0chmYC98c6x37oo,2602
 ai_edge_torch/generative/examples/qwen_vl/verify_image_encoder.py,sha256=lQR8p6Zp7PxDN_erMf-FKLIn_Rv4BGyQHjDbModFkeY,2946
 ai_edge_torch/generative/examples/smollm/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/smollm/convert_to_tflite.py,sha256=megskv1oiPhwHSnguoG7zV-esXp1Ns_FPeMLAYKhDb0,2522
@@ -195,14 +197,14 @@ ai_edge_torch/lowertools/translate_recipe.py,sha256=ymkBpFqAUiupRWqrPOWiVphKcXR1
 ai_edge_torch/odml_torch/__init__.py,sha256=S8jOzE9nLof-6es3XDiGJRN-9H_XTxsVm9dE7lD3RWo,812
 ai_edge_torch/odml_torch/_torch_future.py,sha256=jSYHf1CMTJzMizPMbu2b39hAt0ZTR6gQLq67GMe9KTo,2336
 ai_edge_torch/odml_torch/_torch_library.py,sha256=Lw1gqL2HWNRspdTwNhIkYAHDyafHedHtkXyKKxn-Wss,805
-ai_edge_torch/odml_torch/export.py,sha256=
+ai_edge_torch/odml_torch/export.py,sha256=LDyZUehM1lmT3y2bGeA94rMGRUTLxzIUm4DTlCA8tQc,13640
 ai_edge_torch/odml_torch/export_utils.py,sha256=QeA37Irlty6AiIBuqmHmJgn3lqahBQ5xsh6IKRoKm1g,4774
 ai_edge_torch/odml_torch/tf_integration.py,sha256=NN29WeXmHZ0S1RPDFHUnBi2DEjMvAtwczStPYIsQ1w8,4849
 ai_edge_torch/odml_torch/composite/__init__.py,sha256=71GM_gDZxJyo38ZSoYSwhZX3xKA9rknO93JS9kw9w_c,778
 ai_edge_torch/odml_torch/composite/mark_tensor.py,sha256=U--rwl-XkWKgkdXCXDn6yySug8FR66o1YFUAIoSaWW4,3523
 ai_edge_torch/odml_torch/composite/stablehlo_composite_builder.py,sha256=2Y52E_gLeoXpMcPpV-svXsgN3JbEIjnPVjm0xkpTUdQ,3319
-ai_edge_torch/odml_torch/debuginfo/__init__.py,sha256=
-ai_edge_torch/odml_torch/debuginfo/_build.py,sha256=
+ai_edge_torch/odml_torch/debuginfo/__init__.py,sha256=DoE3HgAtV_GNKGBDGzH2Lb7JUHvyH7TUqWbDZIObr34,789
+ai_edge_torch/odml_torch/debuginfo/_build.py,sha256=sjpYeqgdbDmD7lhp80yc8jfWq-HxX3xuQ58ND8ZeU-I,2213
 ai_edge_torch/odml_torch/debuginfo/_op_polyfill.py,sha256=IvOBQyROI9WHS3umHRxsDW-1YElU9BPWzKtJA2eKWOI,1739
 ai_edge_torch/odml_torch/jax_bridge/__init__.py,sha256=e9Oa4J3An9FYr3zM0OzjzyNNitEeg-IoBUGNfUxsDSA,798
 ai_edge_torch/odml_torch/jax_bridge/_wrap.py,sha256=LqwZ1vCJTSOzgzvH8LUAN-sAkF-l_pGj1AMEIzAqHCA,6638
@@ -227,8 +229,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
 ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
 ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
-ai_edge_torch_nightly-0.3.0.
-ai_edge_torch_nightly-0.3.0.
-ai_edge_torch_nightly-0.3.0.
-ai_edge_torch_nightly-0.3.0.
-ai_edge_torch_nightly-0.3.0.
+ai_edge_torch_nightly-0.3.0.dev20250207.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
+ai_edge_torch_nightly-0.3.0.dev20250207.dist-info/METADATA,sha256=pvcJfgIOezx3rNegfvMIVrkFXmZuqnnE_zMzC9Wt37k,1966
+ai_edge_torch_nightly-0.3.0.dev20250207.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
+ai_edge_torch_nightly-0.3.0.dev20250207.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
+ai_edge_torch_nightly-0.3.0.dev20250207.dist-info/RECORD,,

File without changes
File without changes