ai-edge-torch-nightly 0.3.0.dev20250204__py3-none-any.whl → 0.3.0.dev20250207__py3-none-any.whl

@@ -13,11 +13,7 @@
  # limitations under the License.
  # ==============================================================================
 
- """Example of converting a PaliGemma model to multi-signature tflite model.
-
- DISCLAIMER: It works only with ODML Torch conversion backend. Refer to
- https://github.com/google-ai-edge/ai-edge-torch/blob/main/docs/pytorch_converter/README.md#use-odml-torch-conversion-backend-experimental.
- """
+ """Example of converting a PaliGemma model to multi-signature tflite model."""
 
  import os
  import pathlib
@@ -55,7 +55,6 @@ class Decoder(model_builder.DecoderOnlyModel):
  input_embeds: torch.Tensor = None,
  mask: Optional[torch.Tensor] = None,
  export_config: Optional[model_builder.ExportConfig] = None,
- called_by_generate: bool = True,
  ) -> dict[torch.Tensor, kv_utils.KVCache]:
  if input_embeds is None:
  return super().forward(
@@ -64,11 +63,11 @@ class Decoder(model_builder.DecoderOnlyModel):
 
  assert input_embeds is not None
 
- repo_pos = input_pos + 1 # PaliGemma position is 1-based.
+ rope_pos = input_pos + 1 # PaliGemma position is 1-based.
  # ROPE parameters for all attn_configs are the same. Take the first one.
  attn_config = self.config.block_config(0).attn_config
  n_elem = int(attn_config.rotary_percentage * attn_config.head_dim)
- rope = self.config.build_rope(repo_pos, n_elem, attn_config.rotary_base)
+ rope = self.config.build_rope(rope_pos, n_elem, attn_config.rotary_base)
 
  # The first part of input_embeds are image embeddings. Diagonal causal mask
  # doesn't work here.
@@ -58,34 +58,23 @@ class Decoder2(gemma2.Gemma2):
  input_embeds: torch.Tensor = None,
  mask: Optional[torch.Tensor] = None,
  export_config: Optional[model_builder.ExportConfig] = None,
- called_by_generate: bool = True,
  ) -> dict[torch.Tensor, kv_utils.KVCache]:
  if input_embeds is None:
  return super().forward(tokens, input_pos, kv_cache, mask, export_config)
 
  assert input_embeds is not None
 
- repo_pos = input_pos + 1 # PaliGemma2 position is 1-based.
+ rope_pos = input_pos + 1 # PaliGemma2 position is 1-based.
  # ROPE parameters for all attn_configs are the same. Take the first one.
  attn_config = self.config.block_config(0).attn_config
  n_elem = int(attn_config.rotary_percentage * attn_config.head_dim)
- rope = self.config.build_rope(repo_pos, n_elem, attn_config.rotary_base)
+ rope = self.config.build_rope(rope_pos, n_elem, attn_config.rotary_base)
 
  if mask is None:
- if called_by_generate:
- # PaliGemma2 generate() uses a diagonal causal mask even with image
- # embeds.
- mask = [
- self.get_attention_mask(
- self.config.block_config(i).attn_config.attn_type, input_pos
- )
- for i in range(self.config.num_layers)
- ]
- else:
- # By default, don't mask image embeds with a diagonal causal mask.
- embeds_len = input_embeds.shape[1]
- mask = torch.zeros(embeds_len, self.config.kv_cache_max)
- mask[:, embeds_len:] = float("-inf")
+ # By default, don't mask image embeds with a diagonal causal mask.
+ embeds_len = input_embeds.shape[1]
+ mask = torch.zeros(embeds_len, self.config.kv_cache_max)
+ mask[:, embeds_len:] = float("-inf")
 
  return self._forward_with_embeds(
  input_embeds, rope, mask, input_pos, kv_cache, export_config
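For reference, a minimal standalone sketch of the default image-embeds mask that the simplified branch above builds; the sizes are illustrative toy values and only the tensor construction mirrors the hunk:

import torch

embeds_len = 3     # number of image-embedding positions (toy value)
kv_cache_max = 5   # KV cache length (toy value)

# Image-embedding rows may attend to every position in the image prefix;
# positions past the prefix are blocked with -inf, exactly as in the hunk.
mask = torch.zeros(embeds_len, kv_cache_max)
mask[:, embeds_len:] = float("-inf")
print(mask)
# tensor([[0., 0., 0., -inf, -inf],
#         [0., 0., 0., -inf, -inf],
#         [0., 0., 0., -inf, -inf]])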
@@ -15,7 +15,7 @@
 
  """Example of building a full-stack of PaliGemma model."""
 
- from dataclasses import dataclass
+ import dataclasses
  from typing import Optional
 
  from ai_edge_torch.generative.examples.paligemma import decoder
@@ -31,7 +31,7 @@ from torch import nn
  PROJECTION_TENSOR_NAME = "multi_modal_projector.linear"
 
 
- @dataclass
+ @dataclasses.dataclass
  class PaliGemmaConfig:
  """PaliGemma model configurations."""
 
@@ -39,7 +39,6 @@ class PaliGemmaConfig:
  decoder_config: cfg.ModelConfig
 
  image_token_id: int
- image_projection_scale: float
  image_projection_use_bias: bool = False
 
 
@@ -73,7 +72,6 @@ class PaliGemma(nn.Module):
  mask: Optional[torch.Tensor] = None,
  pixel_values: torch.Tensor = None,
  export_config: Optional[model_builder.ExportConfig] = None,
- called_by_generate: bool = True,
  ) -> dict[torch.Tensor, kv_utils.KVCache]:
  if pixel_values is None:
  return self.decoder(
@@ -83,14 +81,13 @@ class PaliGemma(nn.Module):
  mask=mask,
  input_embeds=None,
  export_config=export_config,
- called_by_generate=called_by_generate,
  )
 
  input_embeds = self.decoder.tok_embedding(tokens)
 
  image_encoded = self.image_encoder(pixel_values=pixel_values)
  image_embeds = self.image_projection(image_encoded)
- image_embeds = image_embeds / self.config.image_projection_scale
+ image_embeds = image_embeds / self.config.decoder_config.embedding_scale
 
  # Merging image_embeds into text_embeds as PaliGemmaForConditionalGeneration
  # can be done like:
@@ -116,7 +113,6 @@ class PaliGemma(nn.Module):
  mask=mask,
  input_embeds=input_embeds,
  export_config=export_config,
- called_by_generate=called_by_generate,
  )
 
 
@@ -130,7 +126,6 @@ def get_model_config(get_decoder_config, **kwargs) -> PaliGemmaConfig:
  image_encoder_config=image_encoder.get_image_encoder_config(),
  decoder_config=get_decoder_config(**kwargs),
  image_token_id=257152,
- image_projection_scale=2048**0.5,
  image_projection_use_bias=True,
  )
 
@@ -140,7 +135,6 @@ def get_fake_model_config(get_decoder_config, **kwargs) -> PaliGemmaConfig:
  image_encoder_config=image_encoder.get_fake_image_encoder_config(),
  decoder_config=get_decoder_config(**kwargs),
  image_token_id=127,
- image_projection_scale=128**0.5,
  image_projection_use_bias=True,
  )
 
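The two hunks above drop the separate image_projection_scale field: image embeddings are now divided by the decoder's own embedding_scale. A minimal sketch of the consolidated scaling with illustrative values; that the decoder's embedding_scale equals sqrt(embedding_dim) = 2048**0.5 (matching the removed constant) is an assumption, not something stated in the diff:

import torch

embedding_dim = 2048                  # decoder width for the real model (assumption)
embedding_scale = embedding_dim**0.5  # assumed to match the removed 2048**0.5 constant

projected_image_embeds = torch.randn(1, 256, embedding_dim)  # toy projected image tokens
# Same normalization the decoder already applies to its text embeddings.
image_embeds = projected_image_embeds / embedding_scale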
@@ -41,7 +41,7 @@ _IMAGE_URL = flags.DEFINE_string(
  )
  _PROMPTS = flags.DEFINE_string(
  "prompts",
- "describe en",
+ "<image><bos>describe en",
  "The input prompts to generate answers.",
  )
  _MAX_NEW_TOKENS = flags.DEFINE_integer(
@@ -59,16 +59,9 @@ _CHECKPOINT = {
  class ReauthoredPaliGemmaWrapper(verifier.ReauthoredModelWrapper):
  """Reauthored PaliGemma model wrapper."""
 
- def __init__(self, model: torch.nn.Module):
- super().__init__(model)
- self.forward_called_by_generate = False
-
  def _init_kv_cache(self):
  return kv_cache.KVCache.from_model_config(self.model.config.decoder_config)
 
- def _get_extra_args_for_forward(self):
- return {"called_by_generate": self.forward_called_by_generate}
-
 
  def main(_):
  if _VERSION.value == "1":
@@ -137,7 +130,6 @@ def main(_):
  logging.info("outputs_from_original_model: [[%s]]", response_original)
 
  logging.info("Generating answer with the reauthored model...")
- wrapped_reauthored_model.forward_called_by_generate = True
  outputs_reauthored = wrapped_reauthored_model.generate(
  prompts=inputs["input_ids"],
  pixel_values=inputs["pixel_values"],
@@ -15,16 +15,61 @@
 
  """Example of building decoder for Qwen 2.5 VL models."""
 
+ from typing import Optional, Tuple
+
+ from ai_edge_torch.generative.layers import kv_cache as kv_utils
  import ai_edge_torch.generative.layers.model_config as cfg
  from ai_edge_torch.generative.utilities import model_builder
- from torch import nn
+ import torch
 
  TENSOR_NAMES = model_builder.TENSOR_NAMES
 
 
  class Decoder(model_builder.DecoderOnlyModel):
- """A decoder for Qwen-VL model built from the Edge Generative API layers."""
- pass
+ """A decoder for Qwen-VL model built from the Edge Generative API layers.
+
+ Besides a tensor of text token IDs, forward() can also take a tensor of
+ embeddings which may include text or image or both.
+ """
+
+ @torch.inference_mode
+ def forward(
+ self,
+ tokens: torch.Tensor,
+ input_pos: torch.Tensor,
+ kv_cache: kv_utils.KVCache,
+ input_embeds: torch.Tensor = None,
+ rope: Tuple[torch.Tensor, torch.Tensor] = None,
+ mask: Optional[torch.Tensor] = None,
+ export_config: Optional[model_builder.ExportConfig] = None,
+ ) -> dict[torch.Tensor, kv_utils.KVCache]:
+ if input_embeds is None:
+ _, seq_len = tokens.size()
+ assert self.config.max_seq_len >= seq_len, (
+ f"Cannot forward sequence of length {seq_len}, max seq length is only"
+ f" {self.config.max_seq_len}"
+ )
+ # token embeddings of shape (b, t, n_embd)
+ input_embeds = self.tok_embedding(tokens)
+
+ if rope is None:
+ # ROPE parameters for all attn_configs are the same. Take the first one.
+ attn_config = self.config.block_config(0).attn_config
+ n_elem = int(attn_config.rotary_percentage * attn_config.head_dim)
+ rope = self.config.build_rope(input_pos, n_elem, attn_config.rotary_base)
+
+ if mask is None:
+ mask = self.mask_cache.index_select(2, input_pos)
+ mask = mask[:, :, :, : self.config.kv_cache_max]
+
+ return self._forward_with_embeds(
+ input_embeds,
+ rope,
+ mask,
+ input_pos,
+ kv_cache,
+ export_config=export_config,
+ )
 
 
  def get_decoder_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
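The new forward() derives RoPE and the attention mask itself when the caller does not supply them. A standalone sketch of the mask path with toy sizes; how DecoderOnlyModel actually lays out its mask_cache is an assumption here, and only the index_select/truncate steps mirror the hunk above:

import torch

kv_cache_max = 8
# Assumed layout of the decoder's mask cache: (1, 1, kv_cache_max, kv_cache_max)
# with -inf strictly above the diagonal (a standard causal mask).
mask_cache = torch.triu(
    torch.full((1, 1, kv_cache_max, kv_cache_max), float("-inf")), diagonal=1
)

input_pos = torch.tensor([2, 3])  # decoding two tokens at positions 2 and 3
mask = mask_cache.index_select(2, input_pos)
mask = mask[:, :, :, :kv_cache_max]
print(mask.shape)  # torch.Size([1, 1, 2, 8])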
@@ -82,7 +127,7 @@ def get_fake_decoder_config(**kwargs) -> cfg.ModelConfig:
  return config
 
 
- def build_decoder(checkpoint_path: str, **kwargs) -> nn.Module:
+ def build_decoder(checkpoint_path: str, **kwargs) -> torch.nn.Module:
  return model_builder.build_decoder_only_model(
  checkpoint_path=checkpoint_path,
  config=get_decoder_config(**kwargs),
@@ -356,6 +356,12 @@ def get_fake_image_encoder_config() -> QwenVLImageConfig:
  def build_image_encoder(checkpoint_path: str) -> QwenVLImageEncoder:
  config = get_image_encoder_config()
  encoder = QwenVLImageEncoder(config)
+ load_image_encoder(checkpoint_path, encoder)
+ encoder.eval()
+ return encoder
+
+
+ def load_image_encoder(checkpoint_path: str, encoder: QwenVLImageEncoder):
  loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
  # Loose the strictness because only image encoder is being loaded.
  loader.load(encoder, strict=False)
@@ -365,15 +371,12 @@ def build_image_encoder(checkpoint_path: str) -> QwenVLImageEncoder:
  state = merger_loader.get_state()
  w1_state = dict()
  w1_state["weight"] = state.pop(f"{MERGER_TENSOR_NAMES.ff_up_proj}.weight")
- if config.merger_config.use_bias:
+ if encoder.config.merger_config.use_bias:
  w1_state["bias"] = state.pop(f"{MERGER_TENSOR_NAMES.ff_up_proj}.bias")
  encoder.merger.w1.load_state_dict(w1_state)
 
  w2_state = dict()
  w2_state["weight"] = state.pop(f"{MERGER_TENSOR_NAMES.ff_down_proj}.weight")
- if config.merger_config.use_bias:
+ if encoder.config.merger_config.use_bias:
  w2_state["bias"] = state.pop(f"{MERGER_TENSOR_NAMES.ff_down_proj}.bias")
  encoder.merger.w2.load_state_dict(w2_state)
-
- encoder.eval()
- return encoder
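The refactor above splits weight loading out of build_image_encoder() so the new qwen_vl.py can reuse it on an encoder it has already constructed. A hedged usage sketch of the two entry points (the checkpoint path is a placeholder):

from ai_edge_torch.generative.examples.qwen_vl import image_encoder

# Build and load in one step, as before.
encoder = image_encoder.build_image_encoder("/path/to/qwen2.5-vl-checkpoint")

# Or load weights into an encoder that already exists, the way
# qwen_vl.build_model() does for its image_encoder submodule.
config = image_encoder.get_image_encoder_config()
encoder = image_encoder.QwenVLImageEncoder(config)
image_encoder.load_image_encoder("/path/to/qwen2.5-vl-checkpoint", encoder)
encoder.eval()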
@@ -0,0 +1,211 @@
+ # Copyright 2025 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ """Example of building a full-stack of Qwen 2.5 VL model."""
+
+ import dataclasses
+ from typing import List, Optional, Tuple
+
+ from ai_edge_torch.generative.examples.qwen_vl import decoder
+ from ai_edge_torch.generative.examples.qwen_vl import image_encoder
+ import ai_edge_torch.generative.layers.kv_cache as kv_utils
+ import ai_edge_torch.generative.layers.model_config as cfg
+ from ai_edge_torch.generative.utilities import model_builder
+ import ai_edge_torch.generative.utilities.loader as loading_utils
+ import torch
+ from torch import nn
+
+
+ @dataclasses.dataclass
+ class QwenVLConfig:
+ """Qwen VL model configurations."""
+
+ image_encoder_config: image_encoder.QwenVLImageConfig
+ decoder_config: cfg.ModelConfig
+ image_token_id: int
+ mrope_section: List[int]
+
+
+ class QwenVL(nn.Module):
+ """Qwen VL model from the Edge Generative API."""
+
+ def __init__(self, config: QwenVLConfig):
+ super().__init__()
+
+ self.image_encoder = image_encoder.QwenVLImageEncoder(
+ config.image_encoder_config
+ )
+ self.decoder = decoder.Decoder(config.decoder_config)
+ # The amount of adjustment to input_pos needed to calculate RoPE properly
+ # in forward() calls after the image has been handled.
+ self.rope_pos_adjust = 0
+ self.config = config
+
+ @torch.inference_mode
+ def forward(
+ self,
+ tokens: torch.Tensor,
+ input_pos: torch.Tensor,
+ kv_cache: kv_utils.KVCache,
+ mask: Optional[torch.Tensor] = None,
+ pixel_values: torch.Tensor = None,
+ grid_thw: torch.Tensor = None,
+ export_config: Optional[model_builder.ExportConfig] = None,
+ ) -> dict[torch.Tensor, kv_utils.KVCache]:
+ if pixel_values is None:
+ return self.decoder(
+ tokens=tokens,
+ input_pos=input_pos,
+ kv_cache=kv_cache,
+ mask=mask,
+ rope=self._build_text_rope(input_pos),
+ input_embeds=None,
+ export_config=export_config,
+ )
+
+ input_embeds = self.decoder.tok_embedding(tokens)
+ image_embeds = self.image_encoder(pixel_values, grid_thw).unsqueeze(0)
+
+ # Merging image_embeds into text_embeds as PaliGemmaForConditionalGeneration
+ # can be done like:
+ #
+ # image_mask = tokens == self.config.image_token_id
+ # image_mask = image_mask.unsqueeze(-1).expand_as(input_embeds)
+ # input_embeds = input_embeds.masked_scatter(image_mask, image_embeds)
+ #
+ # Unfortunately, torch.Tensor.masked_scatter can't be lowered on CPU.
+ # Assume that image is put at the beginning of the input sequence wrapped
+ # with vision_start and vision_end tokens.
+ input_embeds = torch.cat(
+ (
+ input_embeds[:, :1, :],
+ image_embeds,
+ input_embeds[:, image_embeds.shape[1] + 1 :, :],
+ ),
+ dim=1,
+ )
+
+ return self.decoder(
+ tokens=None,
+ input_pos=input_pos,
+ kv_cache=kv_cache,
+ mask=mask,
+ input_embeds=input_embeds,
+ rope=self._build_multimodal_rope(input_pos, grid_thw),
+ export_config=export_config,
+ )
+
+ def _build_rope(
+ self, rope_pos: torch.Tensor
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ # ROPE parameters for all attn_configs are the same. Take the first one.
+ attn_config = self.config.decoder_config.block_config(0).attn_config
+ n_elem = int(attn_config.rotary_percentage * attn_config.head_dim)
+ return self.config.decoder_config.build_rope(
+ rope_pos, n_elem, attn_config.rotary_base
+ )
+
+ def _build_text_rope(
+ self, input_pos: torch.Tensor
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ # Reset rope_pos_adjust to 0 when input sequence starts from scratch, i.e.
+ # input_pos[0] = 0.
+ if input_pos[0] == 0:
+ self.rope_pos_adjust = 0
+ return self._build_rope(input_pos + self.rope_pos_adjust)
+
+ def _build_multimodal_rope(
+ self, input_pos: torch.Tensor, grid_thw: torch.Tensor
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Builds RoPE of multimodal input for the Qwen VL model.
+
+ It's copied from Qwen2_5_VLForConditionalGeneration.get_rope_index() and
+ simplified based on the assumption that an image is put at the beginning of
+ the input sequence with vision start and vision end tokens.
+ """
+ spatial_merge_size = self.config.image_encoder_config.spatial_merge_size
+ height = grid_thw[0][1] // spatial_merge_size
+ width = grid_thw[0][2] // spatial_merge_size
+ image_pos_max = max(height, width)
+ image_pos_count = height * width
+
+ # The position of the vision end token and the text tokens after the image.
+ text_pos_start = image_pos_max + 1
+ text_pos_count = len(input_pos) - image_pos_count - 1
+ text_pos = torch.arange(text_pos_start, text_pos_start + text_pos_count)
+ # Set rope_pos_adjust since text_pos_start has changed.
+ self.rope_pos_adjust = image_pos_max - image_pos_count
+
+ temporal_rope = self._build_image_text_rope(
+ torch.ones(image_pos_count, dtype=torch.int), text_pos
+ )
+ height_rope = self._build_image_text_rope(
+ torch.arange(1, height + 1).view(-1, 1).expand(-1, width).flatten(),
+ text_pos,
+ )
+ width_rope = self._build_image_text_rope(
+ torch.arange(1, width + 1).view(1, -1).expand(height, -1).flatten(),
+ text_pos,
+ )
+
+ return (
+ self._merge_ropes(temporal_rope[0], height_rope[0], width_rope[0]),
+ self._merge_ropes(temporal_rope[1], height_rope[1], width_rope[1]),
+ )
+
+ def _build_image_text_rope(
+ self, image_pos: torch.Tensor, text_pos: torch.Tensor
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ return self._build_rope(
+ torch.cat((torch.zeros(1, dtype=torch.int), image_pos, text_pos))
+ )
+
+ def _merge_ropes(self, a: torch.Tensor, b: torch.Tensor, c: torch.Tensor):
+ """Merges RoPE tensors based on apply_multimodal_rotary_pos_emb()."""
+ split = torch.stack([a, b, c]).split(self.config.mrope_section, dim=-1)
+ return torch.cat([m[i % 3] for i, m in enumerate(split)], dim=-1)
+
+
+ def get_model_config(**kwargs) -> QwenVLConfig:
+ """Returns the model config for a Qwen 2.5 VL 3B model.
+
+ Returns:
+ The model config for a Qwen 2.5 VL 3B model.
+ """
+ return QwenVLConfig(
+ image_encoder_config=image_encoder.get_image_encoder_config(),
+ decoder_config=decoder.get_decoder_config(**kwargs),
+ image_token_id=151655,
+ mrope_section=[16, 24, 24],
+ )
+
+
+ def get_fake_model_config(**kwargs) -> QwenVLConfig:
+ return QwenVLConfig(
+ image_encoder_config=image_encoder.get_fake_image_encoder_config(),
+ decoder_config=decoder.get_fake_decoder_config(**kwargs),
+ image_token_id=127,
+ )
+
+
+ def build_model(checkpoint_path: str, **kwargs) -> QwenVL:
+ config = get_model_config(**kwargs)
+ model = QwenVL(config)
+ image_encoder.load_image_encoder(checkpoint_path, model.image_encoder)
+ # Load the parameters of decoder.
+ loader = loading_utils.ModelLoader(checkpoint_path, decoder.TENSOR_NAMES)
+ loader.load(model.decoder, strict=False)
+ model.eval()
+ return model
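A small standalone sketch of the interleaving that _merge_ropes() performs, scaled down from the real mrope_section=[16, 24, 24]; the values simply mark whether each slice of the merged tensor came from the temporal, height, or width RoPE component:

import torch

mrope_section = [2, 3, 3]    # toy stand-in for the real [16, 24, 24]
a = torch.zeros(1, 8)        # temporal RoPE component
b = torch.ones(1, 8)         # height RoPE component
c = torch.full((1, 8), 2.0)  # width RoPE component

split = torch.stack([a, b, c]).split(mrope_section, dim=-1)
merged = torch.cat([m[i % 3] for i, m in enumerate(split)], dim=-1)
print(merged)  # tensor([[0., 0., 1., 1., 1., 2., 2., 2.]])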
@@ -0,0 +1,143 @@
+ # Copyright 2025 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ """Verifies the reauthored Qwen 2.5 VL model."""
+
+ import logging
+ import pathlib
+ from absl import app
+ from absl import flags
+ from ai_edge_torch.generative.examples.qwen_vl import qwen_vl
+ from ai_edge_torch.generative.layers import kv_cache
+ from ai_edge_torch.generative.utilities import verifier
+ from PIL import Image
+ import requests
+ import torch
+ import transformers
+
+ _IMAGE_URL = flags.DEFINE_string(
+ "image_url",
+ "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg?download=true",
+ "The image URI to encode.",
+ )
+ _PROMPTS = flags.DEFINE_string(
+ "prompts",
+ "<|vision_start|><|image_pad|><|vision_end|>Describe the image.<|im_end|>",
+ "The input prompts to generate answers.",
+ )
+ _MAX_NEW_TOKENS = flags.DEFINE_integer(
+ "max_new_tokens",
+ 30,
+ "The maximum size of the generated tokens.",
+ )
+
+
+ class ReauthoredQwenVLWrapper(verifier.ReauthoredModelWrapper):
+ """Reauthored Qwen VL model wrapper."""
+
+ def __init__(self, model: torch.nn.Module):
+ super().__init__(model)
+ self.grid_thw = None
+
+ def _init_kv_cache(self):
+ return kv_cache.KVCache.from_model_config(self.model.config.decoder_config)
+
+ def _get_extra_args_for_forward(self):
+ return {"grid_thw": self.grid_thw}
+
+
+ def main(_):
+ checkpoint = "Qwen/Qwen2.5-VL-3B-Instruct"
+ logging.info("Loading the original model from: %s", checkpoint)
+ original_model = (
+ transformers.Qwen2_5_VLForConditionalGeneration.from_pretrained(
+ checkpoint
+ )
+ )
+
+ # Locate the cached dir.
+ cached_config_file = transformers.utils.cached_file(
+ checkpoint, transformers.utils.CONFIG_NAME
+ )
+ reauthored_checkpoint = pathlib.Path(cached_config_file).parent
+ logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
+ reauthored_model = qwen_vl.build_model(reauthored_checkpoint)
+
+ logging.info("Loading the processor from: %s", checkpoint)
+ processor = transformers.AutoProcessor.from_pretrained(checkpoint)
+
+ logging.info("Loading the image from: %s", _IMAGE_URL.value)
+ image = Image.open(requests.get(_IMAGE_URL.value, stream=True).raw)
+ inputs = processor(text=_PROMPTS.value, images=image, return_tensors="pt")
+
+ logging.info("Verifying the reauthored model with model.forward()...")
+ logging.info("Forwarding the original model...")
+ outputs_original = original_model.forward(
+ input_ids=inputs["input_ids"],
+ pixel_values=inputs["pixel_values"],
+ image_grid_thw=inputs["image_grid_thw"],
+ )
+ outputs_original = outputs_original.logits
+ logging.info("outputs_original: %s", outputs_original)
+
+ logging.info("Forwarding the reauthored model...")
+ wrapped_reauthored_model = ReauthoredQwenVLWrapper(reauthored_model)
+ wrapped_reauthored_model.grid_thw = inputs["image_grid_thw"]
+ outputs_reauthored = wrapped_reauthored_model.forward(
+ tokens=inputs["input_ids"],
+ pixel_values=inputs["pixel_values"],
+ )
+ logging.info("outputs_reauthored: %s", outputs_reauthored)
+
+ try:
+ assert torch.allclose(outputs_original, outputs_reauthored, atol=1e-01)
+ except AssertionError as e:
+ logging.error("*** FAILED *** verify with forward()")
+ raise e
+ else:
+ logging.info("*** PASSED *** verify with forward()")
+
+ logging.info("Verifying the reauthored model with model.generate()...")
+ logging.info("Generating answer with the original model...")
+ outputs_original = original_model.generate(
+ **inputs, max_new_tokens=_MAX_NEW_TOKENS.value
+ )
+ response_original = processor.decode(
+ outputs_original[0], skip_special_tokens=True
+ )
+ logging.info("outputs_from_original_model: [[%s]]", response_original)
+
+ logging.info("Generating answer with the reauthored model...")
+ outputs_reauthored = wrapped_reauthored_model.generate(
+ prompts=inputs["input_ids"],
+ pixel_values=inputs["pixel_values"],
+ max_new_tokens=_MAX_NEW_TOKENS.value,
+ )
+ response_reauthored = processor.decode(
+ outputs_reauthored[0], skip_special_tokens=True
+ )
+ logging.info("outputs from reauthored model: [[%s]]", response_reauthored)
+
+ try:
+ assert response_original == response_reauthored
+ except AssertionError as e:
+ logging.error("*** FAILED *** verify with generate()")
+ raise e
+ else:
+ logging.info("*** PASSED *** verify with generate()")
+
+
+ if __name__ == "__main__":
+ app.run(main)
@@ -1,4 +1,4 @@
- # Copyright 2024 The AI Edge Torch Authors.
+ # Copyright 2025 The AI Edge Torch Authors.
  #
  # Licensed under the Apache License, Version 2.0 (the "License");
  # you may not use this file except in compliance with the License.
@@ -12,5 +12,5 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.
  # ==============================================================================
- from ._build import build_mlir_debuginfo
+ from ._build import build_mlir_debuginfo, build_mlir_file_debuginfo
  from ._op_polyfill import write_mlir_debuginfo_op
@@ -13,6 +13,7 @@
  # limitations under the License.
  # ==============================================================================
  import torch
+ import re
 
 
  def _class_fullname(cls):
@@ -34,6 +35,29 @@ def _get_hierarchy(node: torch.fx.Node):
  return hierachy_str
 
 
+ def _get_canonical_filename(filename):
+ """Remove unnecessary path prefix to make the filename more readable.
+
+ This should be factored out so that pattern is a global option that a user
+ can override.
+ """
+
+ # TODO: We should add a config option to provide a regex to strip from the
+ # debug info. Currently absolute path is used.
+ return filename
+
+
+ def build_mlir_file_debuginfo(node: torch.fx.Node):
+ """Build the file and line info for the given node's lowerings in MLIR."""
+
+ if not node.stack_trace:
+ return None, None
+
+ # Note: This uses internal APIs and may break in the future.
+ pt_trace = torch.fx.graph._parse_stack_trace(node.stack_trace)
+ return _get_canonical_filename(pt_trace.file), int(pt_trace.lineno)
+
+
  def build_mlir_debuginfo(node: torch.fx.Node):
  """Build the debuginfo string for the given node's lowerings in MLIR."""
 
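For a feel of the private helper the new function relies on, here is a hedged sketch of torch.fx.graph._parse_stack_trace on a synthetic stack-trace string; its exact parsing behaviour is an assumption based on the usage above, and being an internal API it may change between PyTorch releases:

import torch

# A synthetic stack trace in the format torch.fx records on fx.Node.stack_trace.
stack_trace = (
    'File "/tmp/example_model.py", line 42, in forward\n'
    "    return torch.nn.functional.relu(x)"
)
parsed = torch.fx.graph._parse_stack_trace(stack_trace)
# lineno comes back as a string, which is why build_mlir_file_debuginfo()
# wraps it in int().
print(parsed.file, int(parsed.lineno))  # /tmp/example_model.py 42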
@@ -93,7 +93,12 @@ class LoweringInterpreter(torch.fx.Interpreter):
  if info is None:
  return ir.Location.unknown()
 
- return ir.Location.name(name=info)
+ (file, line) = debuginfo.build_mlir_file_debuginfo(node)
+ fileinfo = None
+ if file is not None:
+ fileinfo = ir.Location.file(filename=file, line=line, col=0)
+
+ return ir.Location.name(name=info, childLoc=fileinfo)
 
  def run_node(self, node: torch.fx.Node):
  loc = self._build_loc(node)
ai_edge_torch/version.py CHANGED
@@ -13,4 +13,4 @@
  # limitations under the License.
  # ==============================================================================
 
- __version__ = "0.3.0.dev20250204"
+ __version__ = "0.3.0.dev20250207"
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: ai-edge-torch-nightly
- Version: 0.3.0.dev20250204
+ Version: 0.3.0.dev20250207
  Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
  Home-page: https://github.com/google-ai-edge/ai-edge-torch
  Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
@@ -2,7 +2,7 @@ ai_edge_torch/__init__.py,sha256=8sPR_5uXJA4NEE0nIwNdSl-ADOJEoR8hAgYvBQDY70Y,120
  ai_edge_torch/_config.py,sha256=AiqhbcheF7j_ozIGDLC89k1we95aVgFDa-tR6h7UI0s,2529
  ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
  ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
- ai_edge_torch/version.py,sha256=4XOGz1x6yfOnkOtBndF7qE1L3Ma12ZMJNwQ7wIWkyEs,706
+ ai_edge_torch/version.py,sha256=9V9FbxtqLT70Tzmv_G0qlbqixmVc0pPPJs22C_iBlHE,706
  ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
  ai_edge_torch/_convert/conversion.py,sha256=mckvxznKLXdF2HuJg_IxQaT5Ty-iWl_iXElHEugH3VI,5452
  ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -73,12 +73,12 @@ ai_edge_torch/generative/examples/openelm/convert_to_tflite.py,sha256=pyxRGgMxrn
  ai_edge_torch/generative/examples/openelm/openelm.py,sha256=sIJ8Ie1oxFrJM-1jvv2ukiJbQOTIUGuMEZvmwZbt3n0,4556
  ai_edge_torch/generative/examples/openelm/verify.py,sha256=VkigoqhAr8ew95neb3TifYv-SLOSheaWKv2AH0iKDrc,2441
  ai_edge_torch/generative/examples/paligemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
- ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py,sha256=scLsguzzuHfKYDWUd2uZkKYVRzdAbQHLd-kPam8QwvM,3004
- ai_edge_torch/generative/examples/paligemma/decoder.py,sha256=S_W-0ojRu2Vd5SLNPs1kC-70xHB8AdSWslm-yPxyezk,5478
- ai_edge_torch/generative/examples/paligemma/decoder2.py,sha256=W009ky-yobueTzdaybSCqBAvNyArLXW3jDyp5MarzZU,6376
+ ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py,sha256=a6ISb96xhEJc1TtaFGCUiA4msKedPTAeMvkWrfIklx4,2792
+ ai_edge_torch/generative/examples/paligemma/decoder.py,sha256=z658dW_D0Iqvo6xnh4vG7_o17-Fufndyis8Rq5yafJY,5439
+ ai_edge_torch/generative/examples/paligemma/decoder2.py,sha256=GZa0Ou_DvOijB2nTL_jRvGbn0_dvJPosQAPf47yqicw,5988
  ai_edge_torch/generative/examples/paligemma/image_encoder.py,sha256=7K1xl64UvoHaYmqWjIbahwXHfppwTQ8sN7JrpGKX1XQ,5771
- ai_edge_torch/generative/examples/paligemma/paligemma.py,sha256=CEMG9gh51ev1KXPew927a6nfampiXX9bL6m-25tNYN8,6340
- ai_edge_torch/generative/examples/paligemma/verify.py,sha256=KT3Ruy40tSESxQuy-Sw01NAI3zId1BZr6Bp7FZj1wZk,5622
+ ai_edge_torch/generative/examples/paligemma/paligemma.py,sha256=x1mgRtVLxkCTvlkPow3y7ADoGTjUh5uc5pF46mxatLw,6099
+ ai_edge_torch/generative/examples/paligemma/verify.py,sha256=HLcu1fWMtFFFONAqVW94rOBqq4XvFHtatX3JFGOsfZw,5345
  ai_edge_torch/generative/examples/paligemma/verify_decoder.py,sha256=al5wMPWri4IRVWrLmCplPi6uoCzwh0vBHMGnCt-XUqo,2690
  ai_edge_torch/generative/examples/paligemma/verify_decoder2.py,sha256=tm-UfLr0YeBRVcQsWLBOMWI9JUzHmtPEbYK2vpITpqY,2534
  ai_edge_torch/generative/examples/paligemma/verify_image_encoder.py,sha256=vNm-wTT8BD6zbX6GocfP1QrVoHl0zSvuVxoXN36eeiU,3540
@@ -94,9 +94,11 @@ ai_edge_torch/generative/examples/qwen/convert_to_tflite.py,sha256=tqvXVGNdDehda
  ai_edge_torch/generative/examples/qwen/qwen.py,sha256=Zi_qiQ1JPokXZ95jgSEnQp3F-LKzFCvWvFLKhJjnASo,4199
  ai_edge_torch/generative/examples/qwen/verify.py,sha256=9_AyEJTeUfvhhID64Rto2bflFPyXMFokdQLsseLUMiI,2775
  ai_edge_torch/generative/examples/qwen_vl/__init__.py,sha256=JaAnrFoXTl3RJX97XspklkTyqOHVyAgRJsZtzNDd10c,671
- ai_edge_torch/generative/examples/qwen_vl/decoder.py,sha256=rD_Ch5CzuXeatqv0C3z8vU-zou1z9QDUhoB6V4YTPIg,2829
- ai_edge_torch/generative/examples/qwen_vl/image_encoder.py,sha256=QIPbcturxn5OaVsF5zkRRsyAvCM2Bojyz9XBekHOaro,13405
- ai_edge_torch/generative/examples/qwen_vl/verify_decoder.py,sha256=FEY_PifD9fQGnERzSOljFLraRIbUVF3XTnCv95A30Cs,2602
+ ai_edge_torch/generative/examples/qwen_vl/decoder.py,sha256=0x4iDg2cBe3PFnjVce3nj7g2rjagGHcKqRCfbASNxA8,4402
+ ai_edge_torch/generative/examples/qwen_vl/image_encoder.py,sha256=OYyF0bLVYJno9azmKDqX3gT8ojYYWEyp_F8nLtltPWs,13544
+ ai_edge_torch/generative/examples/qwen_vl/qwen_vl.py,sha256=Uzl1ZPkdYIaHN9QxezqxNwagZiGOHf1VreWnqgRQwf8,7627
+ ai_edge_torch/generative/examples/qwen_vl/verify.py,sha256=2GPi0Vay4a69EwBSOfPMCMjE9PTwPOQus5j2KN7HE7I,5031
+ ai_edge_torch/generative/examples/qwen_vl/verify_decoder.py,sha256=xPWoOBLh2eK12KEhELLYymfL7xvc0chmYC98c6x37oo,2602
  ai_edge_torch/generative/examples/qwen_vl/verify_image_encoder.py,sha256=lQR8p6Zp7PxDN_erMf-FKLIn_Rv4BGyQHjDbModFkeY,2946
  ai_edge_torch/generative/examples/smollm/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
  ai_edge_torch/generative/examples/smollm/convert_to_tflite.py,sha256=megskv1oiPhwHSnguoG7zV-esXp1Ns_FPeMLAYKhDb0,2522
@@ -195,14 +197,14 @@ ai_edge_torch/lowertools/translate_recipe.py,sha256=ymkBpFqAUiupRWqrPOWiVphKcXR1
  ai_edge_torch/odml_torch/__init__.py,sha256=S8jOzE9nLof-6es3XDiGJRN-9H_XTxsVm9dE7lD3RWo,812
  ai_edge_torch/odml_torch/_torch_future.py,sha256=jSYHf1CMTJzMizPMbu2b39hAt0ZTR6gQLq67GMe9KTo,2336
  ai_edge_torch/odml_torch/_torch_library.py,sha256=Lw1gqL2HWNRspdTwNhIkYAHDyafHedHtkXyKKxn-Wss,805
- ai_edge_torch/odml_torch/export.py,sha256=YN7QPrQ8W6T3YVOdyIGadfSQuBroMjIqAMB9FeUa7Ho,13447
+ ai_edge_torch/odml_torch/export.py,sha256=LDyZUehM1lmT3y2bGeA94rMGRUTLxzIUm4DTlCA8tQc,13640
  ai_edge_torch/odml_torch/export_utils.py,sha256=QeA37Irlty6AiIBuqmHmJgn3lqahBQ5xsh6IKRoKm1g,4774
  ai_edge_torch/odml_torch/tf_integration.py,sha256=NN29WeXmHZ0S1RPDFHUnBi2DEjMvAtwczStPYIsQ1w8,4849
  ai_edge_torch/odml_torch/composite/__init__.py,sha256=71GM_gDZxJyo38ZSoYSwhZX3xKA9rknO93JS9kw9w_c,778
  ai_edge_torch/odml_torch/composite/mark_tensor.py,sha256=U--rwl-XkWKgkdXCXDn6yySug8FR66o1YFUAIoSaWW4,3523
  ai_edge_torch/odml_torch/composite/stablehlo_composite_builder.py,sha256=2Y52E_gLeoXpMcPpV-svXsgN3JbEIjnPVjm0xkpTUdQ,3319
- ai_edge_torch/odml_torch/debuginfo/__init__.py,sha256=9ag6-WWRG50rPCtIV7OpIokEKu2YRyGlMZZqVPWUH6g,762
- ai_edge_torch/odml_torch/debuginfo/_build.py,sha256=1xCXOs3-9UcsOyLFH0uyQwLu7c06iYFTo0NQ7Ckbl2I,1465
+ ai_edge_torch/odml_torch/debuginfo/__init__.py,sha256=DoE3HgAtV_GNKGBDGzH2Lb7JUHvyH7TUqWbDZIObr34,789
+ ai_edge_torch/odml_torch/debuginfo/_build.py,sha256=sjpYeqgdbDmD7lhp80yc8jfWq-HxX3xuQ58ND8ZeU-I,2213
  ai_edge_torch/odml_torch/debuginfo/_op_polyfill.py,sha256=IvOBQyROI9WHS3umHRxsDW-1YElU9BPWzKtJA2eKWOI,1739
  ai_edge_torch/odml_torch/jax_bridge/__init__.py,sha256=e9Oa4J3An9FYr3zM0OzjzyNNitEeg-IoBUGNfUxsDSA,798
  ai_edge_torch/odml_torch/jax_bridge/_wrap.py,sha256=LqwZ1vCJTSOzgzvH8LUAN-sAkF-l_pGj1AMEIzAqHCA,6638
@@ -227,8 +229,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
  ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
  ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
  ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
- ai_edge_torch_nightly-0.3.0.dev20250204.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
- ai_edge_torch_nightly-0.3.0.dev20250204.dist-info/METADATA,sha256=Rf4w5EMQlNWOoFIuVlXUZPU9vmXlOJW7oB4yPrtgK0c,1966
- ai_edge_torch_nightly-0.3.0.dev20250204.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
- ai_edge_torch_nightly-0.3.0.dev20250204.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
- ai_edge_torch_nightly-0.3.0.dev20250204.dist-info/RECORD,,
+ ai_edge_torch_nightly-0.3.0.dev20250207.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
+ ai_edge_torch_nightly-0.3.0.dev20250207.dist-info/METADATA,sha256=pvcJfgIOezx3rNegfvMIVrkFXmZuqnnE_zMzC9Wt37k,1966
+ ai_edge_torch_nightly-0.3.0.dev20250207.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
+ ai_edge_torch_nightly-0.3.0.dev20250207.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
+ ai_edge_torch_nightly-0.3.0.dev20250207.dist-info/RECORD,,