ai-edge-torch-nightly 0.3.0.dev20241110__py3-none-any.whl → 0.3.0.dev20241115__py3-none-any.whl

ai_edge_torch/_convert/fx_passes/build_interpolate_composite_pass.py CHANGED
@@ -51,6 +51,7 @@ def _get_upsample_bilinear2d_pattern():
     return {
         "output": (int(output_h), int(output_w)),
         "align_corners": False,
+        "is_nchw_op": True,
     }
 
   return pattern
@@ -74,6 +75,7 @@ def _get_upsample_bilinear2d_align_corners_pattern():
     return {
         "output": (int(output_h), int(output_w)),
         "align_corners": True,
+        "is_nchw_op": True,
     }
 
   return pattern
ai_edge_torch/generative/examples/gemma/verify_gemma2.py CHANGED
@@ -15,10 +15,8 @@
 
 """Verifies the reauthored Gemma2 model."""
 
-import logging
 from absl import app
 from absl import flags
-from ai_edge_torch.generative.examples.gemma import gemma2
 from ai_edge_torch.generative.examples.gemma import verify_util
 import kagglehub
 
@@ -38,18 +36,7 @@ _MAX_NEW_TOKENS = flags.DEFINE_integer(
 def main(_):
   checkpoint = kagglehub.model_download("google/gemma-2/pyTorch/gemma-2-2b-it")
 
-  logging.info("Building the reauthored model from: %s", checkpoint)
-  reauthored_model = gemma2.build_2b_model(checkpoint)
-
-  verify_util.verify_reauthored_gemma_model(
-      checkpoint=checkpoint,
-      variant="2b-v2",
-      reauthored_model=reauthored_model,
-      generate_prompts=_PROMPTS.value,
-      forward_input_ids=[[2, 651, 9456, 576, 573, 3520, 3858, 603, 235248]],
-      max_new_tokens=_MAX_NEW_TOKENS.value,
-      atol=1e-04,
-  )
+  verify_util.verify_gemma2(checkpoint, _PROMPTS.value, _MAX_NEW_TOKENS.value)
 
 
 if __name__ == "__main__":
ai_edge_torch/generative/examples/gemma/verify_util.py CHANGED
@@ -19,6 +19,7 @@ import logging
 import os
 from typing import List, Tuple
 
+from ai_edge_torch.generative.examples.gemma import gemma2
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
 from ai_edge_torch.generative.utilities import verifier
 from gemma import config as gemma_config
@@ -109,8 +110,11 @@ def verify_reauthored_gemma_model(
     max_new_tokens: int = 20,
     rtol: float = 1e-05,
     atol: float = 1e-05,
-):
-  """Verifies the reauthored Gemma model against the original model."""
+) -> bool:
+  """Verifies the reauthored Gemma model against the original model.
+
+  Returns True if the verification passes, False otherwise.
+  """
   config = gemma_config.get_model_config(variant)
   config.tokenizer = os.path.join(checkpoint, tokenizer_filename)
   # Use float32 to be compatible with the reauthored model.
@@ -120,7 +124,7 @@ def verify_reauthored_gemma_model(
   original_model = gemma_model.GemmaForCausalLM(config).eval()
   original_model.load_weights(os.path.join(checkpoint, weight_filename))
 
-  verifier.verify_reauthored_model(
+  return verifier.verify_reauthored_model(
       original_model=GemmaWrapper(original_model),
       reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
      tokenizer=GemmaTokenizerWrapper(original_model.tokenizer),
@@ -130,3 +134,24 @@ def verify_reauthored_gemma_model(
       rtol=rtol,
       atol=atol,
   )
+
+
+def verify_gemma2(
+    gemma2_model_path: str, prompts: List[str], max_new_tokens: int
+) -> bool:
+  """Verifies the reauthored Gemma2 model.
+
+  Returns True if the verification passes, False otherwise.
+  """
+  logging.info("Building the reauthored model from: %s", gemma2_model_path)
+  reauthored_model = gemma2.build_2b_model(gemma2_model_path)
+
+  return verify_reauthored_gemma_model(
+      checkpoint=gemma2_model_path,
+      variant="2b-v2",
+      reauthored_model=reauthored_model,
+      generate_prompts=prompts,
+      forward_input_ids=[[2, 651, 9456, 576, 573, 3520, 3858, 603, 235248]],
+      max_new_tokens=max_new_tokens,
+      atol=1e-04,
+  )
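Since verify_util.verify_gemma2 and verify_reauthored_gemma_model now return a bool instead of raising on mismatch, a caller can decide how to react to a failure. A minimal sketch of a wrapper script (hypothetical, not part of the wheel; assumes kagglehub credentials are already configured):

    import sys

    import kagglehub
    from ai_edge_torch.generative.examples.gemma import verify_util

    # Hypothetical driver: exit non-zero when the reauthored model diverges.
    checkpoint = kagglehub.model_download("google/gemma-2/pyTorch/gemma-2-2b-it")
    ok = verify_util.verify_gemma2(
        checkpoint, prompts=["What is the meaning of life?"], max_new_tokens=30
    )
    sys.exit(0 if ok else 1)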
ai_edge_torch/generative/examples/paligemma/image_encoder.py ADDED
@@ -0,0 +1,158 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Example of building an image encoder of PaliGemma model which is Siglip."""
+
+from ai_edge_torch.generative.layers import attention
+from ai_edge_torch.generative.layers import builder
+import ai_edge_torch.generative.layers.model_config as cfg
+import ai_edge_torch.generative.utilities.loader as loading_utils
+import torch
+from torch import nn
+
+TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
+    ff_up_proj="vision_tower.vision_model.encoder.layers.{}.mlp.fc1",
+    ff_down_proj="vision_tower.vision_model.encoder.layers.{}.mlp.fc2",
+    attn_query_proj=(
+        "vision_tower.vision_model.encoder.layers.{}.self_attn.q_proj"
+    ),
+    attn_key_proj=(
+        "vision_tower.vision_model.encoder.layers.{}.self_attn.k_proj"
+    ),
+    attn_value_proj=(
+        "vision_tower.vision_model.encoder.layers.{}.self_attn.v_proj"
+    ),
+    attn_output_proj=(
+        "vision_tower.vision_model.encoder.layers.{}.self_attn.out_proj"
+    ),
+    pre_attn_norm="vision_tower.vision_model.encoder.layers.{}.layer_norm1",
+    post_attn_norm="vision_tower.vision_model.encoder.layers.{}.layer_norm2",
+    embedding="vision_tower.vision_model.embeddings.patch_embedding",
+    embedding_position=(
+        "vision_tower.vision_model.embeddings.position_embedding.weight"
+    ),
+    final_norm="vision_tower.vision_model.post_layernorm",
+)
+
+
+class SiglipVisionEncoder(nn.Module):
+  """Siglip vision encoder from the Edge Generative API."""
+
+  def __init__(self, config: cfg.ModelConfig):
+    super().__init__()
+
+    # Construct model layers.
+    self.tok_embedding = nn.Conv2d(
+        in_channels=config.image_embedding.channels,
+        out_channels=config.embedding_dim,
+        kernel_size=config.image_embedding.patch_size,
+        stride=config.image_embedding.patch_size,
+        padding="valid",
+    )
+    num_patches = (
+        config.image_embedding.image_size // config.image_embedding.patch_size
+    ) ** 2
+    self.tok_embedding_position = nn.Parameter(
+        torch.zeros((num_patches, config.embedding_dim))
+    )
+
+    self.transformer_blocks = nn.ModuleList(
+        attention.TransformerBlock(config.block_config(idx), config)
+        for idx in range(config.num_layers)
+    )
+    self.final_norm = builder.build_norm(
+        config.embedding_dim,
+        config.final_norm_config,
+    )
+    self.config = config
+
+  @torch.inference_mode
+  def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
+    # Embed the image according to SiglipVisionEmbeddings.
+    x = self.tok_embedding(pixel_values)
+    x = x.flatten(2).transpose(1, 2) + self.tok_embedding_position
+
+    # Pass a dummy mask because SDPA attention impl expects non-None mask.
+    mask = torch.zeros(x.shape[:2])
+    for _, block in enumerate(self.transformer_blocks):
+      x = block(x, mask=mask)
+    return self.final_norm(x)
+
+
+def get_image_encoder_config() -> cfg.ModelConfig:
+  """Returns the model config for the image encoder of a PaliGemma 3B-224 model.
+
+  Returns:
+    The model config for the image encoder of a PaliGemma 3B model.
+  """
+  image_embedding_config = cfg.ImageEmbeddingConfig(
+      channels=3,
+      image_size=224,
+      patch_size=14,
+  )
+  attn_config = cfg.AttentionConfig(
+      num_heads=16,
+      head_dim=72,
+      num_query_groups=16,
+      qkv_use_bias=True,
+      output_proj_use_bias=True,
+  )
+  ff_config = cfg.FeedForwardConfig(
+      type=cfg.FeedForwardType.SEQUENTIAL,
+      activation=cfg.ActivationConfig(cfg.ActivationType.GELU_TANH),
+      intermediate_size=4304,
+      use_bias=True,
+  )
+  norm_config = cfg.NormalizationConfig(
+      type=cfg.NormalizationType.LAYER_NORM,
+      epsilon=1e-6,
+      enable_hlfb=True,
+  )
+  block_config = cfg.TransformerBlockConfig(
+      attn_config=attn_config,
+      ff_config=ff_config,
+      pre_attention_norm_config=norm_config,
+      post_attention_norm_config=norm_config,
+  )
+  config = cfg.ModelConfig(
+      vocab_size=0,  # Not used in image encoder.
+      num_layers=27,
+      max_seq_len=0,  # Not used in image encoder.
+      embedding_dim=1152,
+      embedding_use_bias=True,
+      image_embedding=image_embedding_config,
+      block_configs=block_config,
+      final_norm_config=norm_config,
+      enable_hlfb=True,
+  )
+  return config
+
+
+def get_fake_image_encoder_config() -> cfg.ModelConfig:
+  config = get_image_encoder_config()
+  # PaliGemma image encoder has only one block config.
+  config.block_config(0).ff_config.intermediate_size = 128
+  config.num_layers = 2
+  return config
+
+
+def build_image_encoder(checkpoint_path: str) -> SiglipVisionEncoder:
+  config = get_image_encoder_config()
+  encoder = SiglipVisionEncoder(config)
+  loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
+  # Loosen the strictness because only the image encoder is being loaded.
+  loader.load(encoder, strict=False)
+  encoder.eval()
+  return encoder
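A minimal usage sketch for the new encoder (hypothetical; the checkpoint path is a placeholder for any local PaliGemma checkpoint directory that ModelLoader can read):

    import torch
    from ai_edge_torch.generative.examples.paligemma import image_encoder

    # Hypothetical checkpoint directory.
    encoder = image_encoder.build_image_encoder("/path/to/paligemma-3b-mix-224")
    # SigLIP-224 takes 3x224x224 pixel inputs; with patch_size=14 and
    # embedding_dim=1152 the output should be (1, 256, 1152).
    pixel_values = torch.zeros((1, 3, 224, 224))
    embeddings = encoder(pixel_values)
    print(embeddings.shape)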
ai_edge_torch/generative/examples/paligemma/verify_image_encoder.py ADDED
@@ -0,0 +1,82 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Verifies the reauthored image encoder of PaliGemma 3B model."""
+
+import logging
+import pathlib
+from absl import app
+from absl import flags
+from ai_edge_torch.generative.examples.paligemma import image_encoder
+from PIL import Image
+import requests
+import torch
+import transformers
+
+_IMAGE_URL = flags.DEFINE_string(
+    "image_url",
+    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg?download=true",
+    "The image URI to encode.",
+)
+
+
+def main(_):
+  checkpoint = "google/paligemma-3b-mix-224"
+  logging.info("Loading the original model from: %s", checkpoint)
+  original_full_model = (
+      transformers.PaliGemmaForConditionalGeneration.from_pretrained(checkpoint)
+  )
+  original_vision_model = original_full_model.eval().vision_tower
+
+  # Locate the cached dir.
+  cached_config_file = transformers.utils.cached_file(
+      checkpoint, transformers.utils.CONFIG_NAME
+  )
+  reauthored_checkpoint = pathlib.Path(cached_config_file).parent
+  logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
+  reauthored_model = image_encoder.build_image_encoder(reauthored_checkpoint)
+
+  logging.info("Loading the processor from: %s", checkpoint)
+  # It works only when GemmaTokenizerFast is available. In some environments,
+  # use_fast=False doesn't work either if the tokenizer cannot load the
+  # sentencepiece model file properly.
+  processor = transformers.AutoProcessor.from_pretrained(checkpoint)
+
+  logging.info("Loading the image from: %s", _IMAGE_URL.value)
+  image = Image.open(requests.get(_IMAGE_URL.value, stream=True).raw)
+  pixel_values = processor(images=image, return_tensors="pt")["pixel_values"]
+
+  logging.info("Forwarding the original model...")
+  outputs_original = original_vision_model.forward(pixel_values=pixel_values)
+  outputs_original = outputs_original.last_hidden_state
+  logging.info("outputs_original: %s", outputs_original)
+
+  logging.info("Forwarding the reauthored model...")
+  outputs_reauthored = reauthored_model.forward(pixel_values=pixel_values)
+  logging.info("outputs_reauthored: %s", outputs_reauthored)
+
+  try:
+    assert torch.allclose(
+        outputs_original, outputs_reauthored, atol=1e-04, rtol=1e-04
+    )
+  except AssertionError as e:
+    logging.error("*** FAILED *** verify with an image")
+    raise e
+  else:
+    logging.info("*** PASSED *** verify with an image")
+
+
+if __name__ == "__main__":
+  app.run(main)
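Hedged usage note: the verifier is an absl script, so once the wheel is installed it can be run as a module; the exact invocation below is an assumption rather than documented usage, and the image URL placeholder is hypothetical.

    python -m ai_edge_torch.generative.examples.paligemma.verify_image_encoder \
        --image_url=<url-of-an-image-to-encode>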
ai_edge_torch/generative/layers/attention.py CHANGED
@@ -235,9 +235,10 @@ class CausalSelfAttention(nn.Module):
     k = k.reshape(B, T, -1, self.config.head_dim)
     v = v.reshape(B, T, -1, self.config.head_dim)
 
-    # Compute rotary positional embedding for query and key.
-    n_elem = int(self.config.rotary_percentage * self.config.head_dim)
-    q, k = _embed_rope(q, k, n_elem, rope)
+    if rope is not None:
+      # Compute rotary positional embedding for query and key.
+      n_elem = int(self.config.rotary_percentage * self.config.head_dim)
+      q, k = _embed_rope(q, k, n_elem, rope)
 
     if kv_cache is not None:
       kv_cache = kv_utils.update(
@@ -372,9 +373,10 @@ class CrossAttention(nn.Module):
     k = k.view(interim_shape)
     v = v.view(interim_shape)
 
-    # Compute rotary positional embedding for query and key.
-    n_elem = int(self.config.rotary_percentage * self.config.head_dim)
-    q, k = _embed_rope(q, k, n_elem, rope)
+    if rope is not None:
+      # Compute rotary positional embedding for query and key.
+      n_elem = int(self.config.rotary_percentage * self.config.head_dim)
+      q, k = _embed_rope(q, k, n_elem, rope)
 
     if kv_cache is not None:
       kv_cache = kv_utils.update(
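The new guard makes RoPE optional: encoder-style callers that never build rotary embeddings can simply leave rope unset, which is exactly what the SiglipVisionEncoder above does with block(x, mask=mask). A minimal sketch (config objects are placeholders, not code from the wheel):

    # Placeholder configs; any model config whose blocks don't need RoPE works.
    block = attention.TransformerBlock(model_config.block_config(0), model_config)
    y = block(x, mask=mask)  # rope is left unset, so _embed_rope is skipped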
ai_edge_torch/generative/layers/model_config.py CHANGED
@@ -163,6 +163,16 @@ class TransformerBlockConfig:
   relative_attention: bool = False
 
 
+@dataclass
+class ImageEmbeddingConfig:
+  """Image embedding parameters."""
+
+  channels: int
+  # All images should be normalized to the size of [image_size * image_size].
+  image_size: int
+  patch_size: int
+
+
 @dataclass
 class ModelConfig:
   """Base configurations for building a transformer architecture."""
@@ -183,6 +193,10 @@ class ModelConfig:
 
   # Scale factor of the embedding.
   embedding_scale: Optional[float] = None
+  # Use bias term within embedding.
+  embedding_use_bias: bool = False
+  # Image embedding parameters.
+  image_embedding: Optional[ImageEmbeddingConfig] = None
 
   # Use bias term within LLM's HEAD.
   lm_head_use_bias: bool = False
ai_edge_torch/generative/layers/unet/blocks_2d.py CHANGED
@@ -115,12 +115,15 @@ class AttentionBlock2D(nn.Module):
     """
     super().__init__()
     self.config = config
+    hidden_dim = config.hidden_dim
+    if not hidden_dim:
+      hidden_dim = config.dim
     self.norm = layers_builder.build_norm(
-        config.dim, config.normalization_config
+        hidden_dim, config.normalization_config
     )
     self.attention = SelfAttention(
         config.attention_batch_size,
-        config.dim,
+        hidden_dim,
         config.attention_config,
         enable_hlfb=config.enable_hlfb,
     )
@@ -172,7 +175,7 @@ class CrossAttentionBlock2D(nn.Module):
     super().__init__()
     self.config = config
     self.norm = layers_builder.build_norm(
-        config.query_dim, config.normalization_config
+        config.output_dim, config.normalization_config
     )
     self.attention = CrossAttention(
         config.attention_batch_size,
@@ -294,21 +297,30 @@ class TransformerBlock2D(nn.Module):
     hidden_states
   """
 
-  def __init__(self, config: unet_cfg.TransformerBlock2DConfig):
+  def __init__(
+      self, config: unet_cfg.TransformerBlock2DConfig, dim_override=None
+  ):
     """Initialize an instance of the TransformerBlock2D.
 
     Args:
       config (unet_cfg.TransformerBlock2Dconfig): the configuration of this
        block.
+      dim_override: in case specified, overrides config.attention_block_config.hidden_dim. Set to None by default.
     """
     super().__init__()
     self.config = config
+    attention_block_config_dim = config.attention_block_config.dim
+    attention_block_config_hidden_dim = config.attention_block_config.hidden_dim
+    if dim_override:
+      attention_block_config_dim = dim_override
+    if not attention_block_config_hidden_dim:
+      attention_block_config_hidden_dim = attention_block_config_dim
     self.pre_conv_norm = layers_builder.build_norm(
-        config.attention_block_config.dim, config.pre_conv_normalization_config
+        attention_block_config_dim, config.pre_conv_normalization_config
     )
     self.conv_in = nn.Conv2d(
-        config.attention_block_config.dim,
-        config.attention_block_config.dim,
+        attention_block_config_dim,
+        attention_block_config_hidden_dim,
         kernel_size=1,
         padding=0,
     )
@@ -318,8 +330,8 @@ class TransformerBlock2D(nn.Module):
     )
     self.feed_forward = FeedForwardBlock2D(config.feed_forward_block_config)
     self.conv_out = nn.Conv2d(
-        config.attention_block_config.dim,
-        config.attention_block_config.dim,
+        attention_block_config_hidden_dim,
+        attention_block_config_dim,
         kernel_size=1,
         padding=0,
     )
@@ -385,14 +397,18 @@ class DownEncoderBlock2D(nn.Module):
     self.config = config
     resnets = []
     transformers = []
+    hidden_channels = config.hidden_channels
+    if not hidden_channels:
+      hidden_channels = config.out_channels
     for i in range(config.num_layers):
       input_channels = config.in_channels if i == 0 else config.out_channels
       resnets.append(
           ResidualBlock2D(
              unet_cfg.ResidualBlock2DConfig(
                  in_channels=input_channels,
-                 hidden_channels=config.out_channels,
+                 hidden_channels=hidden_channels,
                  out_channels=config.out_channels,
+                 residual_out_channels=config.out_channels,
                  time_embedding_channels=config.time_embedding_channels,
                  normalization_config=config.normalization_config,
                  activation_config=config.activation_config,
@@ -589,23 +605,37 @@ class SkipUpDecoderBlock2D(nn.Module):
     """
     super().__init__()
     self.config = config
+    hidden_channels = config.hidden_channels
+    if not hidden_channels:
+      hidden_channels = config.out_channels
+    sub_block_channels = config.sub_block_channels
+    if sub_block_channels:
+      assert len(sub_block_channels) == config.num_layers, (
+          "Assertion failed: The length of 'sub_block_channels'"
+          f" ({len(sub_block_channels)}) does not match 'config.num_layers'"
+          f" ({config.num_layers})."
+      )
+    else:
+      sub_block_channels = [config.out_channels] * config.num_layers
     resnets = []
     transformers = []
     for i in range(config.num_layers):
+      resnet_in_channels = (
+          config.prev_out_channels if i == 0 else sub_block_channels[i - 1]
+      )
       res_skip_channels = (
          config.in_channels
          if (i == config.num_layers - 1)
          else config.out_channels
      )
-      resnet_in_channels = (
-          config.prev_out_channels if i == 0 else config.out_channels
-      )
+      residual_out_channel = sub_block_channels[i]
      resnets.append(
          ResidualBlock2D(
              unet_cfg.ResidualBlock2DConfig(
                  in_channels=resnet_in_channels + res_skip_channels,
-                 hidden_channels=config.out_channels,
-                 out_channels=config.out_channels,
+                 hidden_channels=hidden_channels,
+                 out_channels=sub_block_channels[i],
+                 residual_out_channels=residual_out_channel,
                  time_embedding_channels=config.time_embedding_channels,
                  normalization_config=config.normalization_config,
                  activation_config=config.activation_config,
@@ -613,7 +643,12 @@ class SkipUpDecoderBlock2D(nn.Module):
          )
      )
      if config.transformer_block_config:
-        transformers.append(TransformerBlock2D(config.transformer_block_config))
+        transformers.append(
+            TransformerBlock2D(
+                config.transformer_block_config,
+                dim_override=sub_block_channels[i],
+            )
+        )
     self.resnets = nn.ModuleList(resnets)
     self.transformers = (
         nn.ModuleList(transformers) if len(transformers) > 0 else None
@@ -623,7 +658,7 @@ class SkipUpDecoderBlock2D(nn.Module):
     if config.upsample_conv:
       self.upsample_conv = nn.Conv2d(
           config.out_channels,
-          config.out_channels,
+          sub_block_channels[0],
           kernel_size=3,
           stride=1,
           padding=1,
@@ -711,6 +746,7 @@ class MidBlock2D(nn.Module):
             in_channels=config.in_channels,
             hidden_channels=config.in_channels,
             out_channels=config.in_channels,
+            residual_out_channels=config.in_channels,
             time_embedding_channels=config.time_embedding_channels,
             normalization_config=config.normalization_config,
             activation_config=config.activation_config,
ai_edge_torch/generative/layers/unet/model_config.py CHANGED
@@ -50,10 +50,12 @@ class ResidualBlock2DConfig:
   in_channels: int
   hidden_channels: int
   out_channels: int
+  hidden_channels: int
   normalization_config: layers_cfg.NormalizationConfig
   activation_config: layers_cfg.ActivationConfig
   # Optional time embedding channels if the residual block takes a time embedding context as input
   time_embedding_channels: Optional[int] = None
+  residual_out_channels: Optional[int] = None
 
 
 @dataclasses.dataclass
@@ -63,6 +65,7 @@ class AttentionBlock2DConfig:
   attention_config: layers_cfg.AttentionConfig
   enable_hlfb: bool = True
   attention_batch_size: int = 1
+  hidden_dim: Optional[int] = None
 
 
 @dataclasses.dataclass
@@ -101,6 +104,8 @@ class UpDecoderBlock2DConfig:
   normalization_config: layers_cfg.NormalizationConfig
   activation_config: layers_cfg.ActivationConfig
   num_layers: int
+  # The dimension of output channels of previous connected block
+  prev_out_channels: Optional[int] = None
   # Optional time embedding channels if the residual blocks take a time embedding as input
   time_embedding_channels: Optional[int] = None
   # Whether to add upsample operation after residual blocks
@@ -136,6 +141,8 @@ class SkipUpDecoderBlock2DConfig:
   transformer_block_config: Optional[TransformerBlock2DConfig] = None
   # Optional dimension of context tensor if context tensor is given as input.
   context_dim: Optional[int] = None
+  sub_block_channels: Optional[tuple] = None
+  hidden_channels: Optional[int] = None
 
 
 @dataclasses.dataclass
@@ -157,6 +164,7 @@ class DownEncoderBlock2DConfig:
   transformer_block_config: Optional[TransformerBlock2DConfig] = None
   # Optional dimension of context tensor if context tensor is given as input.
   context_dim: Optional[int] = None
+  hidden_channels: Optional[int] = None
 
 
 @dataclasses.dataclass
ai_edge_torch/generative/utilities/loader.py CHANGED
@@ -157,6 +157,10 @@ class ModelLoader:
       converted_state["tok_embedding.weight"] = state.pop(
           f"{self._names.embedding}.weight"
       )
+      if model.config.embedding_use_bias:
+        converted_state["tok_embedding.bias"] = state.pop(
+            f"{self._names.embedding}.bias"
+        )
       if self._names.embedding_position is not None:
         converted_state["tok_embedding_position"] = state.pop(
             f"{self._names.embedding_position}"
ai_edge_torch/generative/utilities/verifier.py CHANGED
@@ -228,7 +228,7 @@ def verify_reauthored_model(
     rtol: float = 1e-05,
     atol: float = 1e-05,
     continue_on_failure: bool = False,
-):
+) -> bool:
   """Verifies the reauthored model against the original model.
 
   It verifies the reauthored model with two methods:
@@ -237,7 +237,8 @@ def verify_reauthored_model(
   2. It compares the answer generated by the original and the reauthored model
      with a prompt.
 
-  It prints out "PASS" or "FAILED" to the console.
+  It prints out "PASS" or "FAILED" to the console. It returns True if all
+  verification passes, False otherwise.
 
   Args:
     original_model (ModelWrapper): The original model.
@@ -253,6 +254,8 @@ def verify_reauthored_model(
     continue_on_failure (bool): If True, it continues to verify the next prompt
       or input IDs even if a previous one fails.
   """
+  failure_count = 0
+
   for input_ids in forward_input_ids:
     logging.info("Verifying the reauthored model with input IDs: %s", input_ids)
     try:
@@ -261,8 +264,9 @@ def verify_reauthored_model(
       )
     except AssertionError as e:
       logging.error("*** FAILED *** verify with input IDs: %s", input_ids)
+      failure_count += 1
       if not continue_on_failure:
-        raise e
+        return False
     else:
       logging.info("*** PASSED *** verify with input IDs: %s", input_ids)
 
@@ -274,7 +278,15 @@ def verify_reauthored_model(
       )
     except AssertionError as e:
       logging.error("*** FAILED *** verify with prompts: %s", prompts)
+      failure_count += 1
      if not continue_on_failure:
-        raise e
+        return False
    else:
      logging.info("*** PASSED *** verify with prompts: %s", prompts)
+
+  if failure_count == 0:
+    logging.info("*** PASSED *** verify_reauthored_model")
+    return True
+  else:
+    logging.error("*** FAILED *** verify_reauthored_model")
+    return False
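Because verify_reauthored_model now reports failure through its return value instead of re-raising, a verification script can surface the overall result explicitly. A hedged sketch of such a caller (model, wrapper, and tokenizer construction is elided and the names are placeholders):

    # Hypothetical driver; `original`, `reauthored`, and `tokenizer` are placeholders.
    passed = verifier.verify_reauthored_model(
        original_model=original,
        reauthored_model=verifier.ReauthoredModelWrapper(reauthored),
        tokenizer=tokenizer,
        generate_prompts=["Hello"],
        forward_input_ids=[[1, 2, 3]],
        continue_on_failure=True,  # check everything, then report once at the end
    )
    if not passed:
      raise SystemExit("Reauthored model does not match the original model.")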
ai_edge_torch/version.py CHANGED
@@ -13,4 +13,4 @@
 # limitations under the License.
 # ==============================================================================
 
-__version__ = "0.3.0.dev20241110"
+__version__ = "0.3.0.dev20241115"
ai_edge_torch_nightly-0.3.0.dev20241110.dist-info/METADATA → ai_edge_torch_nightly-0.3.0.dev20241115.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: ai-edge-torch-nightly
-Version: 0.3.0.dev20241110
+Version: 0.3.0.dev20241115
 Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
 Home-page: https://github.com/google-ai-edge/ai-edge-torch
 Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
ai_edge_torch_nightly-0.3.0.dev20241110.dist-info/RECORD → ai_edge_torch_nightly-0.3.0.dev20241115.dist-info/RECORD CHANGED
@@ -3,7 +3,7 @@ ai_edge_torch/config.py,sha256=FMWeCH2b7HYILBvaI1iZNnYCO4WAhDOwBZBmIE-xrF0,909
 ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
 ai_edge_torch/fx_pass_base.py,sha256=518ziQ0TUxqum2qZXqlD8qr65pHPh8ZNLnwFC6zvK3k,4253
 ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
-ai_edge_torch/version.py,sha256=0kbL8PwrdMx4mw42_rj8uAYUeehe8jsFhw_tENefuGM,706
+ai_edge_torch/version.py,sha256=pp4KVtq0a8ju4UB5nOeiv7QDkmgpHmz5XUokSR86qfI,706
 ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/_convert/conversion.py,sha256=HwzfRx_DX5TLtPqwEH1_NOm38_INvHzHl4_mX67KOdQ,5448
 ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -12,7 +12,7 @@ ai_edge_torch/_convert/signature.py,sha256=rGpBNss3Y9FCCCcdBwDo16KqavJi8N5P0M_6W
 ai_edge_torch/_convert/to_channel_last_io.py,sha256=_31phf7TYgZY2ftpNbrdlB1RhDium1lz_BXEQ6IsMFc,2893
 ai_edge_torch/_convert/fx_passes/__init__.py,sha256=NVe-eGcm7j8jZpP2pcMhC8j5dVjgR1pPzyXhHdvKH4E,1267
 ai_edge_torch/_convert/fx_passes/build_aten_composite_pass.py,sha256=doaww8KqrgRTD5LotBVAIRFsEqzPn9R5lcGehBJOczA,9098
-ai_edge_torch/_convert/fx_passes/build_interpolate_composite_pass.py,sha256=XCVqWg_ask0Kb64PED0ZGAODsUuIgfyO2ZJM6aK-TXI,4283
+ai_edge_torch/_convert/fx_passes/build_interpolate_composite_pass.py,sha256=m_yj66V11LmWCYgA7yLtr__cy14IbC5WEJe0BE0_IPE,4339
 ai_edge_torch/_convert/fx_passes/inject_mlir_debuginfo_pass.py,sha256=IlZuK42kfVcRqAWZp4j2k_81T2uWo9T2558U_GPJAlU,2327
 ai_edge_torch/_convert/fx_passes/remove_non_user_outputs_pass.py,sha256=f1IUVWyhioOClsMiZzLyynoW2R17U83vA-7Q-3pGPM4,2126
 ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/__init__.py,sha256=lxnoH-WGLeiQIF8XjMGodjiZEFTxucl7g05N7MR9OPk,796
@@ -50,8 +50,8 @@ ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py,sha256=RZDs6
 ai_edge_torch/generative/examples/gemma/gemma1.py,sha256=oSbysiPvwp5efMbNYZop3HrxDMGiD15Tmz-HiQuTr2E,3315
 ai_edge_torch/generative/examples/gemma/gemma2.py,sha256=RQFQDMEnIVp8PefcCTr7P0CvllKI7FVoIJLXbPLLIsc,9056
 ai_edge_torch/generative/examples/gemma/verify_gemma1.py,sha256=ip-Gmk4CI5f0GWSdAIdrectxQWJ0t328KCsA4nfHuGg,1736
-ai_edge_torch/generative/examples/gemma/verify_gemma2.py,sha256=K77k-JpdhIwm3tbBnzpw8HQsFRwAVyszxRo82fR6-q4,1762
-ai_edge_torch/generative/examples/gemma/verify_util.py,sha256=sqltZbnyKemNvKqqi9d09i74gP-PPQFodRYfDfnhycQ,4933
+ai_edge_torch/generative/examples/gemma/verify_gemma2.py,sha256=IoBhEMwH07-tFm5-U6F2hpCsI8xynglhq1x9tIOdaPQ,1322
+ai_edge_torch/generative/examples/gemma/verify_util.py,sha256=tR8RflXocDZqvuStyw9aFlzuiTllEC8rNnjrxms6_Is,5727
 ai_edge_torch/generative/examples/llama/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/llama/convert_to_tflite.py,sha256=P0-pByTM5tslE23ILgo7nd0nOGE25ciBRG5wKJj0bBk,2411
 ai_edge_torch/generative/examples/llama/llama.py,sha256=AMcCbuDBxEfbO-l3KiEXbUaXEJ3RLLwkHii7to7UhVo,6854
@@ -62,7 +62,9 @@ ai_edge_torch/generative/examples/openelm/openelm.py,sha256=sFakstoPDcOHSak0IGFE
 ai_edge_torch/generative/examples/openelm/verify.py,sha256=VkigoqhAr8ew95neb3TifYv-SLOSheaWKv2AH0iKDrc,2441
 ai_edge_torch/generative/examples/paligemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/paligemma/decoder.py,sha256=JSb9h3gcIh5oYrbLU6rI8OU8FzfWeTCFJT5XRWu4btE,3675
+ai_edge_torch/generative/examples/paligemma/image_encoder.py,sha256=v19_EKALhAP9FjkINKqpv8JsVaQ6iH_7X5FpnhE6abw,5500
 ai_edge_torch/generative/examples/paligemma/verify_decoder.py,sha256=al5wMPWri4IRVWrLmCplPi6uoCzwh0vBHMGnCt-XUqo,2690
+ai_edge_torch/generative/examples/paligemma/verify_image_encoder.py,sha256=pSekf1BybhieQz3cQx_llbRQHxczXbTqool8fOyGj_0,3114
 ai_edge_torch/generative/examples/phi/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py,sha256=rkbTtMaqSVG48cm-NTxR_LDgZmXAEBqayTm9O49oMXc,2171
 ai_edge_torch/generative/examples/phi/convert_to_tflite.py,sha256=3go690yX6PFeXMdpY7y4JZorAwxX0HT_b_pKZieauvk,2169
@@ -108,19 +110,19 @@ ai_edge_torch/generative/examples/tiny_llama/verify.py,sha256=7Bk8z033M-BCXJ299f
 ai_edge_torch/generative/fx_passes/__init__.py,sha256=jrzCB3ZyY_t5jJM1e2Czdt3DjAIL43R0_a-T-I7wOzw,1155
 ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py,sha256=hhxSQvkDMv0isZJhmuLiod66ZODaJ8uSPSVTJVHBabQ,1931
 ai_edge_torch/generative/layers/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/layers/attention.py,sha256=Z0Y_G8IG0LmvLX2u9D8__Fkr22szB-az6wMNnZpzhkA,13233
+ai_edge_torch/generative/layers/attention.py,sha256=zN3BQjA25Ej_aRU0rFnyx--K74xf5ykc02zGvUpYHeE,13295
 ai_edge_torch/generative/layers/attention_utils.py,sha256=zBVwlBUTs-nStIKCZG0ks5ra7tsqc9ShfakFJKH5rds,7344
 ai_edge_torch/generative/layers/builder.py,sha256=Z5LyzCEThgnYZeyViakaE3yJVzTGHtw13acHsAQR15U,5050
 ai_edge_torch/generative/layers/feed_forward.py,sha256=hdICat-8gW7-vxDAevJQ8NQ-mynllPiqLdXQMF6JMnc,4189
 ai_edge_torch/generative/layers/kv_cache.py,sha256=lbm-yJ1jGPtcgWS4C3FmSnB1IlxqDE7g0BLRh3PN4N4,6324
-ai_edge_torch/generative/layers/model_config.py,sha256=DdsdhTP5tZAtyWim-qW2m8HDBsYbs7boqSDb83vwgmE,6998
+ai_edge_torch/generative/layers/model_config.py,sha256=xqa7ZBEjgK4UWJAThRXb_VBFZ5KCGtDu-QaY5GXar9s,7366
 ai_edge_torch/generative/layers/normalization.py,sha256=eKAGst9rPuyRFExMcQFJO7R3iHdCtlmjeF_lITjLhwE,6498
 ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=CZqOoibLcHvUgrgaIIWAlmk3XgE2inzx340MN-npLoU,1347
 ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=gXxh3papKy4FBpGEX7VyZ7rZ1Js6aHK70Q6DKrVSckY,4154
 ai_edge_torch/generative/layers/unet/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/layers/unet/blocks_2d.py,sha256=JwndhL3Z31TvkdGlAoTL5PQzmKfHdRWaaE1EbaMI4Gs,27540
+ai_edge_torch/generative/layers/unet/blocks_2d.py,sha256=9jKzOfiBQ66bp1ZnVIAoREIifVNFx4aTlQeYMAx2_pA,29062
 ai_edge_torch/generative/layers/unet/builder.py,sha256=zAqWXdimmMrQRhmE_t9XkS68mh6PSrzwb-2NZZXrR5I,1901
-ai_edge_torch/generative/layers/unet/model_config.py,sha256=raYm8Ol-EFi0zs5vNqmj2ZJCFsnQW2TfwhgDcClfwFA,9356
+ai_edge_torch/generative/layers/unet/model_config.py,sha256=pPDwLawc23pfMaPVyMJlYmxVVusjMvx-l8wBwOYOH-c,9692
 ai_edge_torch/generative/quantize/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/quantize/example.py,sha256=1lfVNUd2cEyRUnoZ7BLbRJ9IN-FTKiWBtZNPFUzAiWE,1747
 ai_edge_torch/generative/quantize/quant_attrs.py,sha256=n1Fm8BFC8gJa_oiwwAOOghJyHtOXYZ4q-5ZRy4pHrIw,1957
@@ -137,12 +139,12 @@ ai_edge_torch/generative/test/test_quantize.py,sha256=8geJhKwYBU20m0mdGPD1BUFwQ0
 ai_edge_torch/generative/test/utils.py,sha256=YvEhO2HIj1LkBs5du1UxY-cGRW9HMyAYsOUhgsTrTpA,1796
 ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
 ai_edge_torch/generative/utilities/converter.py,sha256=17O83wVifH1vQJCI4WC3DaNiCIOyK2gys1GzohbLrRs,5554
-ai_edge_torch/generative/utilities/loader.py,sha256=b9iotIhVDX-Zc9XjIDUaLxnV395AyBnkQe3dV5YA7Co,13297
+ai_edge_torch/generative/utilities/loader.py,sha256=k5fjCokNomte4ymy9IJrEWAuCSMhsPCJfmv1y5s0ZEc,13452
 ai_edge_torch/generative/utilities/model_builder.py,sha256=89jt80UUfDzYBi-x077HBavWeuNJuYPXym9fiKCY1Tk,5278
 ai_edge_torch/generative/utilities/stable_diffusion_loader.py,sha256=dqPD9qRXEWtU3ombslOC-BE2l_dMwHoCNu7NsIJhsso,36158
 ai_edge_torch/generative/utilities/t5_loader.py,sha256=tEsfy8-ymzbbjOIc-oesXF3yGyyWtJgFXn2s7VOavt8,16961
 ai_edge_torch/generative/utilities/transformers_verifier.py,sha256=8sp9m_FMcXn7nqOrochtu2jIANkJKhnhIBUmH0ZTDR4,1549
-ai_edge_torch/generative/utilities/verifier.py,sha256=5C2cm54d9kwL7nGRX-YfnBIJny1ICNhiU-LB3IqJq2E,10075
+ai_edge_torch/generative/utilities/verifier.py,sha256=h5hGyIpYGyPZwvelbzpdkjy99Kpd4JkvhqWtQN9cm-M,10413
 ai_edge_torch/hlfb/__init__.py,sha256=sH4um75na-O8tzxN6chFyp6Y4xnexsE7kUQpZySv6dE,735
 ai_edge_torch/hlfb/mark_pattern/__init__.py,sha256=cjTprggj_cuktSCm7-A25e7Shop3k63ylp7sdZmtZ8o,4790
 ai_edge_torch/hlfb/mark_pattern/passes.py,sha256=pjkKcI1nHECPluAt87cFBrt1DP0f3ge7rHq1NhCkBIE,1936
@@ -189,8 +191,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
 ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
 ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
-ai_edge_torch_nightly-0.3.0.dev20241110.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
-ai_edge_torch_nightly-0.3.0.dev20241110.dist-info/METADATA,sha256=ECohBv1Uc5BzRcnT3r3yM8_sElMqIMmpYcnRP_nOp84,1897
-ai_edge_torch_nightly-0.3.0.dev20241110.dist-info/WHEEL,sha256=bFJAMchF8aTQGUgMZzHJyDDMPTO3ToJ7x23SLJa1SVo,92
-ai_edge_torch_nightly-0.3.0.dev20241110.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
-ai_edge_torch_nightly-0.3.0.dev20241110.dist-info/RECORD,,
+ai_edge_torch_nightly-0.3.0.dev20241115.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
+ai_edge_torch_nightly-0.3.0.dev20241115.dist-info/METADATA,sha256=epuuYZFnqVvLzIS0X27XMCFQpnc-dO8JJQ8DXVNv5IE,1897
+ai_edge_torch_nightly-0.3.0.dev20241115.dist-info/WHEEL,sha256=bFJAMchF8aTQGUgMZzHJyDDMPTO3ToJ7x23SLJa1SVo,92
+ai_edge_torch_nightly-0.3.0.dev20241115.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
+ai_edge_torch_nightly-0.3.0.dev20241115.dist-info/RECORD,,