ai-edge-torch-nightly 0.3.0.dev20250203__py3-none-any.whl → 0.3.0.dev20250205__py3-none-any.whl

@@ -13,11 +13,7 @@
  # limitations under the License.
  # ==============================================================================
 
- """Example of converting a PaliGemma model to multi-signature tflite model.
-
- DISCLAIMER: It works only with ODML Torch conversion backend. Refer to
- https://github.com/google-ai-edge/ai-edge-torch/blob/main/docs/pytorch_converter/README.md#use-odml-torch-conversion-backend-experimental.
- """
+ """Example of converting a PaliGemma model to multi-signature tflite model."""
 
  import os
  import pathlib
@@ -19,7 +19,6 @@ from typing import Optional
 
  from ai_edge_torch.generative.layers import kv_cache as kv_utils
  import ai_edge_torch.generative.layers.model_config as cfg
- import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb
  from ai_edge_torch.generative.utilities import model_builder
  import ai_edge_torch.generative.utilities.loader as loading_utils
  import torch
@@ -56,28 +55,34 @@ class Decoder(model_builder.DecoderOnlyModel):
  input_embeds: torch.Tensor = None,
  mask: Optional[torch.Tensor] = None,
  export_config: Optional[model_builder.ExportConfig] = None,
- called_by_generate: bool = True,
  ) -> dict[torch.Tensor, kv_utils.KVCache]:
  if input_embeds is None:
- return super().forward(tokens, input_pos, kv_cache)
+ return super().forward(
+ tokens, input_pos, kv_cache, mask, export_config=export_config
+ )
 
  assert input_embeds is not None
 
- repo_pos = input_pos + 1 # PaliGemma position is 1-based.
+ rope_pos = input_pos + 1 # PaliGemma position is 1-based.
  # ROPE parameters for all attn_configs are the same. Take the first one.
  attn_config = self.config.block_config(0).attn_config
  n_elem = int(attn_config.rotary_percentage * attn_config.head_dim)
- rope = rotary_pos_emb.build_rope(repo_pos, n_elem, attn_config.rotary_base)
+ rope = self.config.build_rope(rope_pos, n_elem, attn_config.rotary_base)
 
  # The first part of input_embeds are image embeddings. Diagonal causal mask
  # doesn't work here.
- embeds_len = input_embeds.shape[1]
  if mask is None:
+ embeds_len = input_embeds.shape[1]
  mask = torch.zeros(embeds_len, self.config.kv_cache_max)
  mask[:, embeds_len:] = float("-inf")
 
  return self._forward_with_embeds(
- input_embeds, rope, mask, input_pos, kv_cache
+ input_embeds,
+ rope,
+ mask,
+ input_pos,
+ kv_cache,
+ export_config=export_config,
  )
 
 
@@ -20,7 +20,6 @@ from typing import Optional
  from ai_edge_torch.generative.examples.gemma import gemma2
  from ai_edge_torch.generative.layers import kv_cache as kv_utils
  import ai_edge_torch.generative.layers.model_config as cfg
- import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb
  from ai_edge_torch.generative.utilities import model_builder
  import ai_edge_torch.generative.utilities.loader as loading_utils
  import torch
@@ -59,33 +58,23 @@ class Decoder2(gemma2.Gemma2):
  input_embeds: torch.Tensor = None,
  mask: Optional[torch.Tensor] = None,
  export_config: Optional[model_builder.ExportConfig] = None,
- called_by_generate: bool = True,
  ) -> dict[torch.Tensor, kv_utils.KVCache]:
  if input_embeds is None:
- return super().forward(tokens, input_pos, kv_cache)
+ return super().forward(tokens, input_pos, kv_cache, mask, export_config)
 
  assert input_embeds is not None
 
- repo_pos = input_pos + 1 # PaliGemma2 position is 1-based.
+ rope_pos = input_pos + 1 # PaliGemma2 position is 1-based.
  # ROPE parameters for all attn_configs are the same. Take the first one.
  attn_config = self.config.block_config(0).attn_config
  n_elem = int(attn_config.rotary_percentage * attn_config.head_dim)
- rope = rotary_pos_emb.build_rope(repo_pos, n_elem, attn_config.rotary_base)
+ rope = self.config.build_rope(rope_pos, n_elem, attn_config.rotary_base)
 
  if mask is None:
- if called_by_generate:
- # PaliGemma2 generate() use a diagonal causal mask even with image embeds.
- mask = [
- self.get_attention_mask(
- self.config.block_config(i).attn_config.attn_type, input_pos
- )
- for i in range(self.config.num_layers)
- ]
- else:
- # By default, don't mask image embeds with a diagonal causal mask.
- embeds_len = input_embeds.shape[1]
- mask = torch.zeros(embeds_len, self.config.kv_cache_max)
- mask[:, embeds_len:] = float("-inf")
+ # By default, don't mask image embeds with a diagonal causal mask.
+ embeds_len = input_embeds.shape[1]
+ mask = torch.zeros(embeds_len, self.config.kv_cache_max)
+ mask[:, embeds_len:] = float("-inf")
 
  return self._forward_with_embeds(
  input_embeds, rope, mask, input_pos, kv_cache, export_config
@@ -60,6 +60,7 @@ class SiglipVisionEncoder(nn.Module):
  kernel_size=config.image_embedding.patch_size,
  stride=config.image_embedding.patch_size,
  padding=0,
+ bias=config.embedding_use_bias,
  )
  num_patches = (
  config.image_embedding.image_size // config.image_embedding.patch_size
@@ -15,7 +15,7 @@
 
  """Example of building a full-stack of PaliGemma model."""
 
- from dataclasses import dataclass
+ import dataclasses
  from typing import Optional
 
  from ai_edge_torch.generative.examples.paligemma import decoder
@@ -31,7 +31,7 @@ from torch import nn
  PROJECTION_TENSOR_NAME = "multi_modal_projector.linear"
 
 
- @dataclass
+ @dataclasses.dataclass
  class PaliGemmaConfig:
  """PaliGemma model configurations."""
 
@@ -39,7 +39,6 @@ class PaliGemmaConfig:
  decoder_config: cfg.ModelConfig
 
  image_token_id: int
- image_projection_scale: float
  image_projection_use_bias: bool = False
 
 
@@ -73,7 +72,6 @@ class PaliGemma(nn.Module):
  mask: Optional[torch.Tensor] = None,
  pixel_values: torch.Tensor = None,
  export_config: Optional[model_builder.ExportConfig] = None,
- called_by_generate: bool = True,
  ) -> dict[torch.Tensor, kv_utils.KVCache]:
  if pixel_values is None:
  return self.decoder(
@@ -83,14 +81,13 @@ class PaliGemma(nn.Module):
  mask=mask,
  input_embeds=None,
  export_config=export_config,
- called_by_generate=called_by_generate,
  )
 
  input_embeds = self.decoder.tok_embedding(tokens)
 
  image_encoded = self.image_encoder(pixel_values=pixel_values)
  image_embeds = self.image_projection(image_encoded)
- image_embeds = image_embeds / self.config.image_projection_scale
+ image_embeds = image_embeds / self.config.decoder_config.embedding_scale
 
  # Merging image_embeds into text_embeds as PaliGemmaForConditionalGeneration
  # can be done like:
@@ -116,7 +113,6 @@ class PaliGemma(nn.Module):
  mask=mask,
  input_embeds=input_embeds,
  export_config=export_config,
- called_by_generate=called_by_generate,
  )
 
 
@@ -130,7 +126,6 @@ def get_model_config(get_decoder_config, **kwargs) -> PaliGemmaConfig:
  image_encoder_config=image_encoder.get_image_encoder_config(),
  decoder_config=get_decoder_config(**kwargs),
  image_token_id=257152,
- image_projection_scale=2048**0.5,
  image_projection_use_bias=True,
  )
 
@@ -140,7 +135,6 @@ def get_fake_model_config(get_decoder_config, **kwargs) -> PaliGemmaConfig:
  image_encoder_config=image_encoder.get_fake_image_encoder_config(),
  decoder_config=get_decoder_config(**kwargs),
  image_token_id=127,
- image_projection_scale=128**0.5,
  image_projection_use_bias=True,
  )
 
@@ -41,7 +41,7 @@ _IMAGE_URL = flags.DEFINE_string(
  )
  _PROMPTS = flags.DEFINE_string(
  "prompts",
- "describe en",
+ "<image><bos>describe en",
  "The input prompts to generate answers.",
  )
  _MAX_NEW_TOKENS = flags.DEFINE_integer(
@@ -59,16 +59,9 @@ _CHECKPOINT = {
  class ReauthoredPaliGemmaWrapper(verifier.ReauthoredModelWrapper):
  """Reauthored PaliGemma model wrapper."""
 
- def __init__(self, model: torch.nn.Module):
- super().__init__(model)
- self.forward_called_by_generate = False
-
  def _init_kv_cache(self):
  return kv_cache.KVCache.from_model_config(self.model.config.decoder_config)
 
- def _get_extra_args_for_forward(self):
- return {"called_by_generate": self.forward_called_by_generate}
-
 
  def main(_):
  if _VERSION.value == "1":
@@ -137,7 +130,6 @@ def main(_):
  logging.info("outputs_from_original_model: [[%s]]", response_original)
 
  logging.info("Generating answer with the reauthored model...")
- wrapped_reauthored_model.forward_called_by_generate = True
  outputs_reauthored = wrapped_reauthored_model.generate(
  prompts=inputs["input_ids"],
  pixel_values=inputs["pixel_values"],
@@ -0,0 +1,379 @@
+ # Copyright 2025 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ """Example of building an image encoder of Qwen 2.5 VL model."""
+
+ import dataclasses
+ from typing import Optional
+
+ from ai_edge_torch.generative.layers import attention
+ from ai_edge_torch.generative.layers import attention_utils
+ from ai_edge_torch.generative.layers import builder
+ import ai_edge_torch.generative.layers.model_config as cfg
+ import ai_edge_torch.generative.utilities.loader as loading_utils
+ import torch
+ from torch import nn
+ import torch.nn.functional as F
+
+ TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
+ ff_up_proj="visual.blocks.{}.mlp.up_proj",
+ ff_down_proj="visual.blocks.{}.mlp.down_proj",
+ ff_gate_proj="visual.blocks.{}.mlp.gate_proj",
+ attn_fused_qkv_proj="visual.blocks.{}.attn.qkv",
+ attn_output_proj="visual.blocks.{}.attn.proj",
+ pre_attn_norm="visual.blocks.{}.norm1",
+ post_attn_norm="visual.blocks.{}.norm2",
+ embedding="visual.patch_embed.proj",
+ final_norm="visual.merger.ln_q",
+ )
+
+ MERGER_TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
+ ff_up_proj="visual.merger.mlp.0",
+ ff_down_proj="visual.merger.mlp.2",
+ )
+
+
+ @dataclasses.dataclass
+ class QwenVLMergerConfig:
+ """Merger parameters."""
+
+ activation: cfg.ActivationConfig
+ intermediate_size: int
+ out_embedding_dim: int
+ use_bias: bool = False
+
+
+ @dataclasses.dataclass
+ class QwenVLImageConfig(cfg.ModelConfig):
+ """model config for Qwen 2.5 VL model."""
+
+ merger_config: Optional[QwenVLMergerConfig] = None
+ window_size: Optional[int] = None
+ spatial_merge_size: Optional[int] = None
+ full_atten_block_indexes: Optional[list[int]] = None
+
+
+ class QwenVLMerger(nn.Module):
+ """Merger of Qwen 2.5 VL models from the Edge Generative API.
+
+ It's based on Qwen2_5_VLPatchMerger.
+ """
+
+ def __init__(self, config: QwenVLImageConfig):
+ super().__init__()
+ self.intermediate_size = config.merger_config.intermediate_size
+ self.w1 = nn.Linear(self.intermediate_size, self.intermediate_size)
+ self.act = builder.get_activation(config.merger_config.activation)
+ self.w2 = nn.Linear(
+ self.intermediate_size, config.merger_config.out_embedding_dim
+ )
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x_reshaped = x.view(-1, self.intermediate_size)
+ return self.w2(self.act(self.w1(x_reshaped)))
+
+
+ class QwenVLImageEncoder(nn.Module):
+ """Image encoder of Qwen 2.5 VL models from the Edge Generative API."""
+
+ def __init__(self, config: QwenVLImageConfig):
+ super().__init__()
+
+ # Tensor shape used to reshape pixel_values in forward() and various places.
+ self.kernel_size = (
+ -1, # batch size
+ config.image_embedding.channels,
+ config.image_embedding.temporal_patch_size,
+ config.image_embedding.patch_size,
+ config.image_embedding.patch_size,
+ )
+ self.tok_embedding = nn.Conv3d(
+ in_channels=self.kernel_size[1],
+ out_channels=config.embedding_dim,
+ kernel_size=self.kernel_size[2:],
+ stride=self.kernel_size[2:],
+ padding=0,
+ bias=config.embedding_use_bias,
+ )
+
+ self.transformer_blocks = nn.ModuleList(
+ attention.TransformerBlock(config.block_config(idx), config)
+ for idx in range(config.num_layers)
+ )
+ self.final_norm = builder.build_norm(
+ config.embedding_dim,
+ config.final_norm_config,
+ )
+ self.merger = QwenVLMerger(config)
+ self.config = config
+
+ @torch.inference_mode
+ def forward(
+ self, pixel_values: torch.Tensor, grid_thw: torch.Tensor
+ ) -> torch.Tensor:
+ # Get window index and sequence lengths to rearrange the input tensor.
+ window_index, cu_seqlens = self._get_window_index(grid_thw)
+
+ # Embed the image and rearrange the embedding tensor.
+ pixel_reshaped = pixel_values.view(self.kernel_size)
+ x = self.tok_embedding(pixel_reshaped)
+ x = x.view(-1, self.config.embedding_dim)
+ x = self._rearrange(x, window_index).unsqueeze(0)
+
+ # Get RoPE and attention mask arranged according to the window index.
+ cos, sin = self._get_rope(grid_thw)
+ rope = (
+ self._rearrange(cos, window_index),
+ self._rearrange(sin, window_index),
+ )
+
+ mask = self._get_mask(x.shape[1], cu_seqlens)
+ full_mask = torch.zeros(x.shape[:2])
+ for i, block in enumerate(self.transformer_blocks):
+ x = block(
+ x,
+ rope=rope,
+ mask=full_mask if i in self.config.full_atten_block_indexes else mask,
+ )
+
+ y = self.merger.forward(self.final_norm(x))
+ # Arrange the output back to the original order.
+ reverse_index = torch.argsort(window_index)
+ return y[reverse_index, ...]
+
+ def _get_rope(self, grid_thw: torch.Tensor) -> torch.Tensor:
+ """Get RoPE for Qwen VL model based on image grid information.
+
+ It's copied from Qwen2_5_VisionTransformerPretrainedModel.rot_pos_emb() and
+ modified accordingly.
+ """
+ pos_ids = []
+ for t, h, w in grid_thw:
+ hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
+ hpos_ids = hpos_ids.reshape(
+ h // self.config.spatial_merge_size,
+ self.config.spatial_merge_size,
+ w // self.config.spatial_merge_size,
+ self.config.spatial_merge_size,
+ )
+ hpos_ids = hpos_ids.permute(0, 2, 1, 3)
+ hpos_ids = hpos_ids.flatten()
+
+ wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
+ wpos_ids = wpos_ids.reshape(
+ h // self.config.spatial_merge_size,
+ self.config.spatial_merge_size,
+ w // self.config.spatial_merge_size,
+ self.config.spatial_merge_size,
+ )
+ wpos_ids = wpos_ids.permute(0, 2, 1, 3)
+ wpos_ids = wpos_ids.flatten()
+ pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
+ pos_ids = torch.cat(pos_ids, dim=0)
+ max_grid_size = grid_thw[:, 1:].max()
+
+ cos, sin = attention_utils.build_rope_cache(
+ max_grid_size,
+ # ROPE parameters for all attn_configs are the same. Take the first one.
+ self.config.block_config(0).attn_config.head_dim // 2,
+ )
+ return cos[pos_ids].flatten(1), sin[pos_ids].flatten(1)
+
+ def _get_window_index(self, grid_thw: torch.Tensor):
+ """Get window index for Qwen VL model to rearrange the input tensor.
+
+ It's copied from Qwen2_5_VisionTransformerPretrainedModel.get_window_index()
+ and modified accordingly.
+ """
+ window_index: list = []
+ cu_window_seqlens: list = [0]
+ window_index_id = 0
+ vit_merger_window_size = (
+ self.config.window_size
+ // self.config.spatial_merge_size
+ // self.config.image_embedding.patch_size
+ )
+
+ for grid_t, grid_h, grid_w in grid_thw:
+ llm_grid_h, llm_grid_w = (
+ grid_h // self.config.spatial_merge_size,
+ grid_w // self.config.spatial_merge_size,
+ )
+ index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(
+ grid_t, llm_grid_h, llm_grid_w
+ )
+ pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size
+ pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size
+ num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size
+ num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size
+ index_padded = F.pad(index, (0, pad_w, 0, pad_h), "constant", -100)
+ index_padded = index_padded.reshape(
+ grid_t,
+ num_windows_h,
+ vit_merger_window_size,
+ num_windows_w,
+ vit_merger_window_size,
+ )
+ index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape(
+ grid_t,
+ num_windows_h * num_windows_w,
+ vit_merger_window_size,
+ vit_merger_window_size,
+ )
+ seqlens = (index_padded != -100).sum([2, 3]).reshape(-1)
+ index_padded = index_padded.reshape(-1)
+ index_new = index_padded[index_padded != -100]
+ window_index.append(index_new + window_index_id)
+ spatial_merge_unit = (
+ self.config.spatial_merge_size * self.config.spatial_merge_size
+ )
+ cu_seqlens_tmp = (
+ seqlens.cumsum(0) * spatial_merge_unit + cu_window_seqlens[-1]
+ )
+ cu_window_seqlens.extend(cu_seqlens_tmp.tolist())
+ window_index_id += (grid_t * llm_grid_h * llm_grid_w).item()
+
+ window_index = torch.cat(window_index, dim=0)
+ cu_window_seqlens = torch.tensor(cu_window_seqlens)
+ cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)
+ return window_index, cu_window_seqlens
+
+ def _rearrange(
+ self, x: torch.Tensor, window_index: torch.Tensor
+ ) -> torch.Tensor:
+ """Rearrange the tensor according to window_index.
+
+ It's copied from Qwen2_5_VisionTransformerPretrainedModel.forward() and
+ modified accordingly.
+ """
+ size = x.shape[0]
+ spatial_merge_unit = (
+ self.config.spatial_merge_size * self.config.spatial_merge_size
+ )
+ x_reshaped = x.view(size // spatial_merge_unit, spatial_merge_unit, -1)
+ x_rearranged = x_reshaped[window_index, ...]
+ return x_rearranged.view(size, -1)
+
+ def _get_mask(self, seqlen: int, cu_seqlens: torch.Tensor) -> torch.Tensor:
+ """Get attention mask for Qwen VL model.
+
+ It's copied from Qwen2_5_VLVisionAttention.forward() and modified
+ accordingly.
+ """
+ mask = torch.full([1, 1, seqlen, seqlen], float("-inf"))
+ for i in range(1, len(cu_seqlens)):
+ mask[
+ ...,
+ cu_seqlens[i - 1] : cu_seqlens[i],
+ cu_seqlens[i - 1] : cu_seqlens[i],
+ ] = 0
+ return mask
+
+
+ def get_image_encoder_config() -> QwenVLImageConfig:
+ """Returns the model config for the image encoder of a Qwen 2.5 VL model.
+
+ Returns:
+ The model config for the image encoder of a Qwen 2.5 VL model.
+ """
+ image_embedding_config = cfg.ImageEmbeddingConfig(
+ channels=3,
+ image_size=0, # Not used in image encoder.
+ patch_size=14,
+ temporal_patch_size=2,
+ )
+ attn_config = cfg.AttentionConfig(
+ num_heads=16,
+ head_dim=80,
+ num_query_groups=16,
+ qkv_transpose_before_split=True,
+ qkv_use_bias=True,
+ output_proj_use_bias=True,
+ )
+ ff_config = cfg.FeedForwardConfig(
+ type=cfg.FeedForwardType.GATED,
+ activation=cfg.ActivationConfig(cfg.ActivationType.SILU),
+ intermediate_size=3420,
+ use_bias=True,
+ )
+ norm_config = cfg.NormalizationConfig(
+ type=cfg.NormalizationType.RMS_NORM,
+ epsilon=1e-6,
+ )
+ block_config = cfg.TransformerBlockConfig(
+ attn_config=attn_config,
+ ff_config=ff_config,
+ pre_attention_norm_config=norm_config,
+ post_attention_norm_config=norm_config,
+ )
+ merger_config = QwenVLMergerConfig(
+ activation=cfg.ActivationConfig(cfg.ActivationType.GELU),
+ intermediate_size=5120, # embedding_dim(1280) * spatial_merge_size(2)^2
+ out_embedding_dim=2048, # embedding_dim of decoder config.
+ use_bias=True,
+ )
+ config = QwenVLImageConfig(
+ vocab_size=0, # Not used in image encoder.
+ num_layers=32,
+ max_seq_len=0, # Not used in image encoder.
+ embedding_dim=1280,
+ image_embedding=image_embedding_config,
+ block_configs=block_config,
+ final_norm_config=norm_config,
+ merger_config=merger_config,
+ window_size=112,
+ spatial_merge_size=2,
+ full_atten_block_indexes=[7, 15, 23, 31],
+ # TODO: b/377051577 - Once RemoveSDPACompositeZeroMaskPass is removed,
+ # enable_hlfb can be set to True. See b/383865404#comment3 for details.
+ # enable_hlfb=True,
+ )
+ return config
+
+
+ def get_fake_image_encoder_config() -> QwenVLImageConfig:
+ config = get_image_encoder_config()
+ # PaliGemma image encoder has only one block config.
+ config.block_config(0).ff_config.intermediate_size = 128
+ config.image_embedding.patch_size = 2
+ config.num_layers = 2
+ config.merger_config.intermediate_size = 128
+ return config
+
+
+ def build_image_encoder(checkpoint_path: str) -> QwenVLImageEncoder:
+ config = get_image_encoder_config()
+ encoder = QwenVLImageEncoder(config)
+ loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
+ # Loose the strictness because only image encoder is being loaded.
+ loader.load(encoder, strict=False)
+
+ # Load merger weights.
+ merger_loader = loading_utils.ModelLoader(checkpoint_path, None)
+ state = merger_loader.get_state()
+ w1_state = dict()
+ w1_state["weight"] = state.pop(f"{MERGER_TENSOR_NAMES.ff_up_proj}.weight")
+ if config.merger_config.use_bias:
+ w1_state["bias"] = state.pop(f"{MERGER_TENSOR_NAMES.ff_up_proj}.bias")
+ encoder.merger.w1.load_state_dict(w1_state)
+
+ w2_state = dict()
+ w2_state["weight"] = state.pop(f"{MERGER_TENSOR_NAMES.ff_down_proj}.weight")
+ if config.merger_config.use_bias:
+ w2_state["bias"] = state.pop(f"{MERGER_TENSOR_NAMES.ff_down_proj}.bias")
+ encoder.merger.w2.load_state_dict(w2_state)
+
+ encoder.eval()
+ return encoder
@@ -0,0 +1,84 @@
+ # Copyright 2025 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ """Verifies the reauthored image encoder of Qwen 2.5 VL model."""
+
+ import logging
+ import pathlib
+ from absl import app
+ from absl import flags
+ from ai_edge_torch.generative.examples.qwen_vl import image_encoder
+ from PIL import Image
+ import requests
+ import torch
+ import transformers
+
+ _IMAGE_URL = flags.DEFINE_string(
+ "image_url",
+ "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg?download=true",
+ "The image URI to encode.",
+ )
+
+
+ def main(_):
+ checkpoint = "Qwen/Qwen2.5-VL-3B-Instruct"
+ logging.info("Loading the original model from: %s", checkpoint)
+ original_model = (
+ transformers.Qwen2_5_VLForConditionalGeneration.from_pretrained(
+ checkpoint
+ )
+ )
+ original_vision_model = original_model.eval().visual
+
+ # Locate the cached dir.
+ cached_config_file = transformers.utils.cached_file(
+ checkpoint, transformers.utils.CONFIG_NAME
+ )
+ reauthored_checkpoint = pathlib.Path(cached_config_file).parent
+ logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
+ reauthored_model = image_encoder.build_image_encoder(reauthored_checkpoint)
+
+ logging.info("Loading the processor from: %s", checkpoint)
+ processor = transformers.AutoProcessor.from_pretrained(checkpoint)
+
+ logging.info("Loading the image from: %s", _IMAGE_URL.value)
+ image = Image.open(requests.get(_IMAGE_URL.value, stream=True).raw)
+ image_input = processor(images=image, text="", return_tensors="pt")
+
+ logging.info("Forwarding the original model...")
+ outputs_original = original_vision_model.forward(
+ image_input["pixel_values"], image_input["image_grid_thw"]
+ )
+ logging.info("outputs_original: %s", outputs_original)
+
+ logging.info("Forwarding the reauthored model...")
+ outputs_reauthored = reauthored_model.forward(
+ image_input["pixel_values"], image_input["image_grid_thw"]
+ )
+ logging.info("outputs_reauthored: %s", outputs_reauthored)
+
+ try:
+ assert torch.allclose(
+ outputs_original, outputs_reauthored, atol=1e-03, rtol=1e-05
+ )
+ except AssertionError as e:
+ logging.error("*** FAILED *** verify with an image")
+ raise e
+ else:
+ logging.info("*** PASSED *** verify with an image")
+
+
+ if __name__ == "__main__":
+ app.run(main)
@@ -224,7 +224,6 @@ class CausalSelfAttention(nn.Module):
 
  if rope is not None:
  # Compute rotary positional embedding for query and key.
- n_elem = int(self.config.rotary_percentage * self.config.head_dim)
  cos, sin = rope
  q, k = rotary_pos_emb.apply_rope_inline(q, k, cos, sin)
 
@@ -177,6 +177,8 @@ class ImageEmbeddingConfig:
  # All images should be normalized to the size of [image_size * image_size].
  image_size: int
  patch_size: int
+ # Meaningful only when image embedding is Conv3d.
+ temporal_patch_size: Optional[int] = None
 
 
  @dataclasses.dataclass
ai_edge_torch/version.py CHANGED
@@ -13,4 +13,4 @@
  # limitations under the License.
  # ==============================================================================
 
- __version__ = "0.3.0.dev20250203"
+ __version__ = "0.3.0.dev20250205"
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: ai-edge-torch-nightly
- Version: 0.3.0.dev20250203
+ Version: 0.3.0.dev20250205
  Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
  Home-page: https://github.com/google-ai-edge/ai-edge-torch
  Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
@@ -2,7 +2,7 @@ ai_edge_torch/__init__.py,sha256=8sPR_5uXJA4NEE0nIwNdSl-ADOJEoR8hAgYvBQDY70Y,120
  ai_edge_torch/_config.py,sha256=AiqhbcheF7j_ozIGDLC89k1we95aVgFDa-tR6h7UI0s,2529
  ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
  ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
- ai_edge_torch/version.py,sha256=L1vAuA2I33V8wn9Ar4IzXsPvJ7RWvbgm1MNqp1h-H0E,706
+ ai_edge_torch/version.py,sha256=3qCqU6b85lrBJn0A7eFSW9dGx1TkEsCXhffIwwFwUv4,706
  ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
  ai_edge_torch/_convert/conversion.py,sha256=mckvxznKLXdF2HuJg_IxQaT5Ty-iWl_iXElHEugH3VI,5452
  ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -73,12 +73,12 @@ ai_edge_torch/generative/examples/openelm/convert_to_tflite.py,sha256=pyxRGgMxrn
  ai_edge_torch/generative/examples/openelm/openelm.py,sha256=sIJ8Ie1oxFrJM-1jvv2ukiJbQOTIUGuMEZvmwZbt3n0,4556
  ai_edge_torch/generative/examples/openelm/verify.py,sha256=VkigoqhAr8ew95neb3TifYv-SLOSheaWKv2AH0iKDrc,2441
  ai_edge_torch/generative/examples/paligemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
- ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py,sha256=scLsguzzuHfKYDWUd2uZkKYVRzdAbQHLd-kPam8QwvM,3004
- ai_edge_torch/generative/examples/paligemma/decoder.py,sha256=g0Fbtf9WigOzQij7W1ksUca4eZTwVdCO2RcuFO2GD3M,5439
- ai_edge_torch/generative/examples/paligemma/decoder2.py,sha256=craPUFxlBniBz9a0Jc7VjK01jROMg5a47xJiEA1brnw,6430
- ai_edge_torch/generative/examples/paligemma/image_encoder.py,sha256=u4hEHjvLaMu-UnRrISOFXKMEJIMSLa9CfpjjmSIrlSY,5731
- ai_edge_torch/generative/examples/paligemma/paligemma.py,sha256=CEMG9gh51ev1KXPew927a6nfampiXX9bL6m-25tNYN8,6340
- ai_edge_torch/generative/examples/paligemma/verify.py,sha256=KT3Ruy40tSESxQuy-Sw01NAI3zId1BZr6Bp7FZj1wZk,5622
+ ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py,sha256=a6ISb96xhEJc1TtaFGCUiA4msKedPTAeMvkWrfIklx4,2792
+ ai_edge_torch/generative/examples/paligemma/decoder.py,sha256=z658dW_D0Iqvo6xnh4vG7_o17-Fufndyis8Rq5yafJY,5439
+ ai_edge_torch/generative/examples/paligemma/decoder2.py,sha256=GZa0Ou_DvOijB2nTL_jRvGbn0_dvJPosQAPf47yqicw,5988
+ ai_edge_torch/generative/examples/paligemma/image_encoder.py,sha256=7K1xl64UvoHaYmqWjIbahwXHfppwTQ8sN7JrpGKX1XQ,5771
+ ai_edge_torch/generative/examples/paligemma/paligemma.py,sha256=x1mgRtVLxkCTvlkPow3y7ADoGTjUh5uc5pF46mxatLw,6099
+ ai_edge_torch/generative/examples/paligemma/verify.py,sha256=HLcu1fWMtFFFONAqVW94rOBqq4XvFHtatX3JFGOsfZw,5345
  ai_edge_torch/generative/examples/paligemma/verify_decoder.py,sha256=al5wMPWri4IRVWrLmCplPi6uoCzwh0vBHMGnCt-XUqo,2690
  ai_edge_torch/generative/examples/paligemma/verify_decoder2.py,sha256=tm-UfLr0YeBRVcQsWLBOMWI9JUzHmtPEbYK2vpITpqY,2534
  ai_edge_torch/generative/examples/paligemma/verify_image_encoder.py,sha256=vNm-wTT8BD6zbX6GocfP1QrVoHl0zSvuVxoXN36eeiU,3540
@@ -95,7 +95,9 @@ ai_edge_torch/generative/examples/qwen/qwen.py,sha256=Zi_qiQ1JPokXZ95jgSEnQp3F-L
  ai_edge_torch/generative/examples/qwen/verify.py,sha256=9_AyEJTeUfvhhID64Rto2bflFPyXMFokdQLsseLUMiI,2775
  ai_edge_torch/generative/examples/qwen_vl/__init__.py,sha256=JaAnrFoXTl3RJX97XspklkTyqOHVyAgRJsZtzNDd10c,671
  ai_edge_torch/generative/examples/qwen_vl/decoder.py,sha256=rD_Ch5CzuXeatqv0C3z8vU-zou1z9QDUhoB6V4YTPIg,2829
+ ai_edge_torch/generative/examples/qwen_vl/image_encoder.py,sha256=QIPbcturxn5OaVsF5zkRRsyAvCM2Bojyz9XBekHOaro,13405
  ai_edge_torch/generative/examples/qwen_vl/verify_decoder.py,sha256=FEY_PifD9fQGnERzSOljFLraRIbUVF3XTnCv95A30Cs,2602
+ ai_edge_torch/generative/examples/qwen_vl/verify_image_encoder.py,sha256=lQR8p6Zp7PxDN_erMf-FKLIn_Rv4BGyQHjDbModFkeY,2946
  ai_edge_torch/generative/examples/smollm/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
  ai_edge_torch/generative/examples/smollm/convert_to_tflite.py,sha256=megskv1oiPhwHSnguoG7zV-esXp1Ns_FPeMLAYKhDb0,2522
  ai_edge_torch/generative/examples/smollm/convert_v2_to_tflite.py,sha256=CjY1i0iCYxFSjhCpQZwxkmVxILgeo0zu1m0oBrHqyDU,2311
@@ -131,13 +133,13 @@ ai_edge_torch/generative/examples/tiny_llama/verify.py,sha256=7Bk8z033M-BCXJ299f
  ai_edge_torch/generative/fx_passes/__init__.py,sha256=4rFrppMRKlTwwZeX1ON_cdp4yUqoTOES161IZQkJF6c,1143
  ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py,sha256=myGjal5A8yIBoqgArd2k40rZmCgD1Ya369KR7182bhI,2129
  ai_edge_torch/generative/layers/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
- ai_edge_torch/generative/layers/attention.py,sha256=GrAy8CT1pEsgRoB8JQP6PlnNYk8kQ4U3YANfSiTJKn8,13776
+ ai_edge_torch/generative/layers/attention.py,sha256=Pm8FLKh-NnOvUjqQC9oX5oghPbdivZvlPVkgOVTShoU,13703
  ai_edge_torch/generative/layers/attention_utils.py,sha256=zBVwlBUTs-nStIKCZG0ks5ra7tsqc9ShfakFJKH5rds,7344
  ai_edge_torch/generative/layers/builder.py,sha256=LXGuSHIx6QZAzLFm7aJvlzoMPgQwbXLFchGEKYwOOUA,5090
  ai_edge_torch/generative/layers/feed_forward.py,sha256=hdICat-8gW7-vxDAevJQ8NQ-mynllPiqLdXQMF6JMnc,4189
  ai_edge_torch/generative/layers/kv_cache.py,sha256=sGGAZD0mWYuO4FukZfDbHXoxpBOBE9lTYICvZzDj5F8,6400
  ai_edge_torch/generative/layers/lora.py,sha256=hsvWLLOnW7HQ0AysOZu30x_cetMquDd1tjfyLz8HCSU,17892
- ai_edge_torch/generative/layers/model_config.py,sha256=ZVRWEGw1BnLbLCuoR71kWGqQteKp-UM1YvMbbWYlkNw,7999
+ ai_edge_torch/generative/layers/model_config.py,sha256=Yqa3wqZLBe0Lj4PPTIaVFaZ--sV6NJ6k8KPjRguDvCc,8095
  ai_edge_torch/generative/layers/normalization.py,sha256=MbwH-n80Fob5YvjBzdqDjBizMHLzSJGYRDdbD-rL5C0,6174
  ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=975zR202MdIrILJ7blceAcxrNqX1ZCN0ECKG1gz-bV8,2655
  ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=vp8dVx6tOe99neJhpbrtIt5fvN5NFw19JVH1v0yi5Mg,4154
@@ -225,8 +227,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
  ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
  ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
  ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
- ai_edge_torch_nightly-0.3.0.dev20250203.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
- ai_edge_torch_nightly-0.3.0.dev20250203.dist-info/METADATA,sha256=Jybn0dpOrId6u1ZcmYrWnjHnjLE3tk7Opt4XZ2nvGYg,1966
- ai_edge_torch_nightly-0.3.0.dev20250203.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
- ai_edge_torch_nightly-0.3.0.dev20250203.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
- ai_edge_torch_nightly-0.3.0.dev20250203.dist-info/RECORD,,
+ ai_edge_torch_nightly-0.3.0.dev20250205.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
+ ai_edge_torch_nightly-0.3.0.dev20250205.dist-info/METADATA,sha256=F9YG6dtQw7Vh9T4m0C2z4JAiddvpobcdY-Rxjmh4WX4,1966
+ ai_edge_torch_nightly-0.3.0.dev20250205.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
+ ai_edge_torch_nightly-0.3.0.dev20250205.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
+ ai_edge_torch_nightly-0.3.0.dev20250205.dist-info/RECORD,,