ai-edge-torch-nightly 0.3.0.dev20241110__py3-none-any.whl → 0.3.0.dev20241115__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.

ai_edge_torch/_convert/fx_passes/build_interpolate_composite_pass.py CHANGED
@@ -51,6 +51,7 @@ def _get_upsample_bilinear2d_pattern():
  return {
  "output": (int(output_h), int(output_w)),
  "align_corners": False,
+ "is_nchw_op": True,
  }

  return pattern
@@ -74,6 +75,7 @@ def _get_upsample_bilinear2d_align_corners_pattern():
  return {
  "output": (int(output_h), int(output_w)),
  "align_corners": True,
+ "is_nchw_op": True,
  }

  return pattern
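
For context, a hedged sketch of the kind of module these patterns are aimed at: a bilinear torch.nn.functional.interpolate call on an NCHW tensor (the module below is illustrative and not part of the package). The newly recorded "is_nchw_op": True attribute is carried on the matched composite alongside "output" and "align_corners".

    import torch
    from torch import nn

    class Upsample2x(nn.Module):
      """Illustrative module with a bilinear upsample such a pattern can match."""

      def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x is NCHW, e.g. (1, 3, 16, 16) -> (1, 3, 32, 32).
        return nn.functional.interpolate(
            x, scale_factor=2.0, mode="bilinear", align_corners=False
        )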

ai_edge_torch/generative/examples/gemma/verify_gemma2.py CHANGED
@@ -15,10 +15,8 @@

  """Verifies the reauthored Gemma2 model."""

- import logging
  from absl import app
  from absl import flags
- from ai_edge_torch.generative.examples.gemma import gemma2
  from ai_edge_torch.generative.examples.gemma import verify_util
  import kagglehub

@@ -38,18 +36,7 @@ _MAX_NEW_TOKENS = flags.DEFINE_integer(
  def main(_):
  checkpoint = kagglehub.model_download("google/gemma-2/pyTorch/gemma-2-2b-it")

- logging.info("Building the reauthored model from: %s", checkpoint)
- reauthored_model = gemma2.build_2b_model(checkpoint)
-
- verify_util.verify_reauthored_gemma_model(
- checkpoint=checkpoint,
- variant="2b-v2",
- reauthored_model=reauthored_model,
- generate_prompts=_PROMPTS.value,
- forward_input_ids=[[2, 651, 9456, 576, 573, 3520, 3858, 603, 235248]],
- max_new_tokens=_MAX_NEW_TOKENS.value,
- atol=1e-04,
- )
+ verify_util.verify_gemma2(checkpoint, _PROMPTS.value, _MAX_NEW_TOKENS.value)


  if __name__ == "__main__":

ai_edge_torch/generative/examples/gemma/verify_util.py CHANGED
@@ -19,6 +19,7 @@ import logging
  import os
  from typing import List, Tuple

+ from ai_edge_torch.generative.examples.gemma import gemma2
  import ai_edge_torch.generative.layers.attention_utils as attn_utils
  from ai_edge_torch.generative.utilities import verifier
  from gemma import config as gemma_config
@@ -109,8 +110,11 @@ def verify_reauthored_gemma_model(
  max_new_tokens: int = 20,
  rtol: float = 1e-05,
  atol: float = 1e-05,
- ):
- """Verifies the reauthored Gemma model against the original model."""
+ ) -> bool:
+ """Verifies the reauthored Gemma model against the original model.
+
+ Returns True if the verification passes, False otherwise.
+ """
  config = gemma_config.get_model_config(variant)
  config.tokenizer = os.path.join(checkpoint, tokenizer_filename)
  # Use float32 to be compatible with the reauthored model.
@@ -120,7 +124,7 @@ def verify_reauthored_gemma_model(
  original_model = gemma_model.GemmaForCausalLM(config).eval()
  original_model.load_weights(os.path.join(checkpoint, weight_filename))

- verifier.verify_reauthored_model(
+ return verifier.verify_reauthored_model(
  original_model=GemmaWrapper(original_model),
  reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
  tokenizer=GemmaTokenizerWrapper(original_model.tokenizer),
@@ -130,3 +134,24 @@ def verify_reauthored_gemma_model(
  rtol=rtol,
  atol=atol,
  )
+
+
+ def verify_gemma2(
+ gemma2_model_path: str, prompts: List[str], max_new_tokens: int
+ ) -> bool:
+ """Verifies the reauthored Gemma2 model.
+
+ Return True if the verification passes, False otherwise.
+ """
+ logging.info("Building the reauthored model from: %s", gemma2_model_path)
+ reauthored_model = gemma2.build_2b_model(gemma2_model_path)
+
+ return verify_reauthored_gemma_model(
+ checkpoint=gemma2_model_path,
+ variant="2b-v2",
+ reauthored_model=reauthored_model,
+ generate_prompts=prompts,
+ forward_input_ids=[[2, 651, 9456, 576, 573, 3520, 3858, 603, 235248]],
+ max_new_tokens=max_new_tokens,
+ atol=1e-04,
+ )
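
The new verify_util.verify_gemma2() helper above wraps model building and verification and reports the result as a boolean instead of raising. A minimal caller sketch, assuming a locally downloaded Gemma 2 2B checkpoint (the path and prompt below are placeholders, not part of the package):

    import sys

    from ai_edge_torch.generative.examples.gemma import verify_util

    # Hypothetical checkpoint directory, e.g. the one returned by
    # kagglehub.model_download("google/gemma-2/pyTorch/gemma-2-2b-it").
    checkpoint = "/tmp/gemma-2-2b-it"

    ok = verify_util.verify_gemma2(
        checkpoint,
        prompts=["What is the meaning of life?"],
        max_new_tokens=30,
    )
    if not ok:
      sys.exit("Reauthored Gemma2 does not match the original model.")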

ai_edge_torch/generative/examples/paligemma/image_encoder.py ADDED
@@ -0,0 +1,158 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ """Example of building an image encoder of PaliGemma model which is Siglip."""
+
+ from ai_edge_torch.generative.layers import attention
+ from ai_edge_torch.generative.layers import builder
+ import ai_edge_torch.generative.layers.model_config as cfg
+ import ai_edge_torch.generative.utilities.loader as loading_utils
+ import torch
+ from torch import nn
+
+ TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
+ ff_up_proj="vision_tower.vision_model.encoder.layers.{}.mlp.fc1",
+ ff_down_proj="vision_tower.vision_model.encoder.layers.{}.mlp.fc2",
+ attn_query_proj=(
+ "vision_tower.vision_model.encoder.layers.{}.self_attn.q_proj"
+ ),
+ attn_key_proj=(
+ "vision_tower.vision_model.encoder.layers.{}.self_attn.k_proj"
+ ),
+ attn_value_proj=(
+ "vision_tower.vision_model.encoder.layers.{}.self_attn.v_proj"
+ ),
+ attn_output_proj=(
+ "vision_tower.vision_model.encoder.layers.{}.self_attn.out_proj"
+ ),
+ pre_attn_norm="vision_tower.vision_model.encoder.layers.{}.layer_norm1",
+ post_attn_norm="vision_tower.vision_model.encoder.layers.{}.layer_norm2",
+ embedding="vision_tower.vision_model.embeddings.patch_embedding",
+ embedding_position=(
+ "vision_tower.vision_model.embeddings.position_embedding.weight"
+ ),
+ final_norm="vision_tower.vision_model.post_layernorm",
+ )
+
+
+ class SiglipVisionEncoder(nn.Module):
+ """Signlip vision encoder from the Edge Generative API."""
+
+ def __init__(self, config: cfg.ModelConfig):
+ super().__init__()
+
+ # Construct model layers.
+ self.tok_embedding = nn.Conv2d(
+ in_channels=config.image_embedding.channels,
+ out_channels=config.embedding_dim,
+ kernel_size=config.image_embedding.patch_size,
+ stride=config.image_embedding.patch_size,
+ padding="valid",
+ )
+ num_patches = (
+ config.image_embedding.image_size // config.image_embedding.patch_size
+ ) ** 2
+ self.tok_embedding_position = nn.Parameter(
+ torch.zeros((num_patches, config.embedding_dim))
+ )
+
+ self.transformer_blocks = nn.ModuleList(
+ attention.TransformerBlock(config.block_config(idx), config)
+ for idx in range(config.num_layers)
+ )
+ self.final_norm = builder.build_norm(
+ config.embedding_dim,
+ config.final_norm_config,
+ )
+ self.config = config
+
+ @torch.inference_mode
+ def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
+ # Embed the image according to SiplipVisionEmbeddings.
+ x = self.tok_embedding(pixel_values)
+ x = x.flatten(2).transpose(1, 2) + self.tok_embedding_position
+
+ # Pass a dummy mask because SDPA attention impl expects non-None mask.
+ mask = torch.zeros(x.shape[:2])
+ for _, block in enumerate(self.transformer_blocks):
+ x = block(x, mask=mask)
+ return self.final_norm(x)
+
+
+ def get_image_encoder_config() -> cfg.ModelConfig:
+ """Returns the model config for the image encoder of a PaliGemma 3B-224 model.
+
+ Returns:
+ The model config for the image encoder of a PaliGemma 3B model.
+ """
+ image_embedding_config = cfg.ImageEmbeddingConfig(
+ channels=3,
+ image_size=224,
+ patch_size=14,
+ )
+ attn_config = cfg.AttentionConfig(
+ num_heads=16,
+ head_dim=72,
+ num_query_groups=16,
+ qkv_use_bias=True,
+ output_proj_use_bias=True,
+ )
+ ff_config = cfg.FeedForwardConfig(
+ type=cfg.FeedForwardType.SEQUENTIAL,
+ activation=cfg.ActivationConfig(cfg.ActivationType.GELU_TANH),
+ intermediate_size=4304,
+ use_bias=True,
+ )
+ norm_config = cfg.NormalizationConfig(
+ type=cfg.NormalizationType.LAYER_NORM,
+ epsilon=1e-6,
+ enable_hlfb=True,
+ )
+ block_config = cfg.TransformerBlockConfig(
+ attn_config=attn_config,
+ ff_config=ff_config,
+ pre_attention_norm_config=norm_config,
+ post_attention_norm_config=norm_config,
+ )
+ config = cfg.ModelConfig(
+ vocab_size=0, # Not used in image encoder.
+ num_layers=27,
+ max_seq_len=0, # Not used in image encoder.
+ embedding_dim=1152,
+ embedding_use_bias=True,
+ image_embedding=image_embedding_config,
+ block_configs=block_config,
+ final_norm_config=norm_config,
+ enable_hlfb=True,
+ )
+ return config
+
+
+ def get_fake_image_encoder_config() -> cfg.ModelConfig:
+ config = get_image_encoder_config()
+ # PaliGemma image encoder has only one block config.
+ config.block_config(0).ff_config.intermediate_size = 128
+ config.num_layers = 2
+ return config
+
+
+ def build_image_encoder(checkpoint_path: str) -> SiglipVisionEncoder:
+ config = get_image_encoder_config()
+ encoder = SiglipVisionEncoder(config)
+ loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
+ # Loose the strictness because only image encoder is being loaded.
+ loader.load(encoder, strict=False)
+ encoder.eval()
+ return encoder
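
A rough usage sketch for the new encoder, assuming a locally cached PaliGemma checkpoint (the directory below is a placeholder); build_image_encoder() and the 3x224x224 input shape come from the file above:

    import torch

    from ai_edge_torch.generative.examples.paligemma import image_encoder

    # Hypothetical directory holding the downloaded PaliGemma weights.
    encoder = image_encoder.build_image_encoder("/tmp/paligemma-3b-mix-224")

    # One RGB image of 224x224; 14x14 patches give (224 // 14) ** 2 = 256 tokens.
    pixel_values = torch.zeros((1, 3, 224, 224))
    embeddings = encoder(pixel_values)  # expected shape: (1, 256, 1152)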

ai_edge_torch/generative/examples/paligemma/verify_image_encoder.py ADDED
@@ -0,0 +1,82 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ """Verifies the reauthored image encoder of PaliGemma 3B model."""
+
+ import logging
+ import pathlib
+ from absl import app
+ from absl import flags
+ from ai_edge_torch.generative.examples.paligemma import image_encoder
+ from PIL import Image
+ import requests
+ import torch
+ import transformers
+
+ _IMAGE_URL = flags.DEFINE_string(
+ "image_url",
+ "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg?download=true",
+ "The image URI to encode.",
+ )
+
+
+ def main(_):
+ checkpoint = "google/paligemma-3b-mix-224"
+ logging.info("Loading the original model from: %s", checkpoint)
+ original_full_model = (
+ transformers.PaliGemmaForConditionalGeneration.from_pretrained(checkpoint)
+ )
+ original_vision_model = original_full_model.eval().vision_tower
+
+ # Locate the cached dir.
+ cached_config_file = transformers.utils.cached_file(
+ checkpoint, transformers.utils.CONFIG_NAME
+ )
+ reauthored_checkpoint = pathlib.Path(cached_config_file).parent
+ logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
+ reauthored_model = image_encoder.build_image_encoder(reauthored_checkpoint)
+
+ logging.info("Loading the processor from: %s", checkpoint)
+ # It works only when GemmaTokenizerFast is available. In some environments,
+ # use_fast=False doeesn't work either if the tokenizer cannot load the
+ # sentencepiece model file properly.
+ processor = transformers.AutoProcessor.from_pretrained(checkpoint)
+
+ logging.info("Loading the image from: %s", _IMAGE_URL.value)
+ image = Image.open(requests.get(_IMAGE_URL.value, stream=True).raw)
+ pixel_values = processor(images=image, return_tensors="pt")["pixel_values"]
+
+ logging.info("Forwarding the original model...")
+ outputs_original = original_vision_model.forward(pixel_values=pixel_values)
+ outputs_original = outputs_original.last_hidden_state
+ logging.info("outputs_original: %s", outputs_original)
+
+ logging.info("Forwarding the reauthored model...")
+ outputs_reauthored = reauthored_model.forward(pixel_values=pixel_values)
+ logging.info("outputs_reauthored: %s", outputs_reauthored)
+
+ try:
+ assert torch.allclose(
+ outputs_original, outputs_reauthored, atol=1e-04, rtol=1e-04
+ )
+ except AssertionError as e:
+ logging.error("*** FAILED *** verify with an image")
+ raise e
+ else:
+ logging.info("*** PASSED *** verify with an image")
+
+
+ if __name__ == "__main__":
+ app.run(main)

ai_edge_torch/generative/layers/attention.py CHANGED
@@ -235,9 +235,10 @@ class CausalSelfAttention(nn.Module):
  k = k.reshape(B, T, -1, self.config.head_dim)
  v = v.reshape(B, T, -1, self.config.head_dim)

- # Compute rotary positional embedding for query and key.
- n_elem = int(self.config.rotary_percentage * self.config.head_dim)
- q, k = _embed_rope(q, k, n_elem, rope)
+ if rope is not None:
+ # Compute rotary positional embedding for query and key.
+ n_elem = int(self.config.rotary_percentage * self.config.head_dim)
+ q, k = _embed_rope(q, k, n_elem, rope)

  if kv_cache is not None:
  kv_cache = kv_utils.update(
@@ -372,9 +373,10 @@ class CrossAttention(nn.Module):
  k = k.view(interim_shape)
  v = v.view(interim_shape)

- # Compute rotary positional embedding for query and key.
- n_elem = int(self.config.rotary_percentage * self.config.head_dim)
- q, k = _embed_rope(q, k, n_elem, rope)
+ if rope is not None:
+ # Compute rotary positional embedding for query and key.
+ n_elem = int(self.config.rotary_percentage * self.config.head_dim)
+ q, k = _embed_rope(q, k, n_elem, rope)

  if kv_cache is not None:
  kv_cache = kv_utils.update(
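
With the rope guard above, the shared attention blocks can run without rotary embeddings, which is what the Siglip encoder earlier in this diff relies on. A small sketch reusing that encoder's config (illustrative only, not part of the change):

    import torch

    from ai_edge_torch.generative.examples.paligemma import image_encoder
    from ai_edge_torch.generative.layers import attention

    config = image_encoder.get_fake_image_encoder_config()
    block = attention.TransformerBlock(config.block_config(0), config)

    x = torch.zeros((1, 256, config.embedding_dim))
    mask = torch.zeros(x.shape[:2])  # dummy mask, as in SiglipVisionEncoder
    y = block(x, mask=mask)  # no rope argument, so _embed_rope is skipped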

ai_edge_torch/generative/layers/model_config.py CHANGED
@@ -163,6 +163,16 @@ class TransformerBlockConfig:
  relative_attention: bool = False


+ @dataclass
+ class ImageEmbeddingConfig:
+ """Image embedding parameters."""
+
+ channels: int
+ # All images should be normalized to the size of [image_size * image_size].
+ image_size: int
+ patch_size: int
+
+
  @dataclass
  class ModelConfig:
  """Base configurations for building a transformer architecture."""
@@ -183,6 +193,10 @@ class ModelConfig:

  # Scale factor of the embedding.
  embedding_scale: Optional[float] = None
+ # Use bias term within embedding.
+ embedding_use_bias: bool = False
+ # Image embedding parameters.
+ image_embedding: Optional[ImageEmbeddingConfig] = None

  # Use bias term within LLM's HEAD.
  lm_head_use_bias: bool = False
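
A short sketch of the new fields, mirroring how get_image_encoder_config() above wires them up (the values are the PaliGemma ones from this diff, shown only for illustration):

    import ai_edge_torch.generative.layers.model_config as cfg

    # Describe a 224x224 RGB input split into 14x14 patches.
    image_embedding = cfg.ImageEmbeddingConfig(
        channels=3,
        image_size=224,
        patch_size=14,
    )
    # A vision-encoder ModelConfig can then carry these via
    # image_embedding=image_embedding and embedding_use_bias=True,
    # as get_image_encoder_config() does.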

ai_edge_torch/generative/layers/unet/blocks_2d.py CHANGED
@@ -115,12 +115,15 @@ class AttentionBlock2D(nn.Module):
  """
  super().__init__()
  self.config = config
+ hidden_dim = config.hidden_dim
+ if not hidden_dim:
+ hidden_dim = config.dim
  self.norm = layers_builder.build_norm(
- config.dim, config.normalization_config
+ hidden_dim, config.normalization_config
  )
  self.attention = SelfAttention(
  config.attention_batch_size,
- config.dim,
+ hidden_dim,
  config.attention_config,
  enable_hlfb=config.enable_hlfb,
  )
@@ -172,7 +175,7 @@ class CrossAttentionBlock2D(nn.Module):
  super().__init__()
  self.config = config
  self.norm = layers_builder.build_norm(
- config.query_dim, config.normalization_config
+ config.output_dim, config.normalization_config
  )
  self.attention = CrossAttention(
  config.attention_batch_size,
@@ -294,21 +297,30 @@ class TransformerBlock2D(nn.Module):
  hidden_states
  """

- def __init__(self, config: unet_cfg.TransformerBlock2DConfig):
+ def __init__(
+ self, config: unet_cfg.TransformerBlock2DConfig, dim_override=None
+ ):
  """Initialize an instance of the TransformerBlock2D.

  Args:
  config (unet_cfg.TransformerBlock2Dconfig): the configuration of this
  block.
+ dim_override: in case specified, overrides config.attention_block_config.hidden_dim. Set to None by default.
  """
  super().__init__()
  self.config = config
+ attention_block_config_dim = config.attention_block_config.dim
+ attention_block_config_hidden_dim = config.attention_block_config.hidden_dim
+ if dim_override:
+ attention_block_config_dim = dim_override
+ if not attention_block_config_hidden_dim:
+ attention_block_config_hidden_dim = attention_block_config_dim
  self.pre_conv_norm = layers_builder.build_norm(
- config.attention_block_config.dim, config.pre_conv_normalization_config
+ attention_block_config_dim, config.pre_conv_normalization_config
  )
  self.conv_in = nn.Conv2d(
- config.attention_block_config.dim,
- config.attention_block_config.dim,
+ attention_block_config_dim,
+ attention_block_config_hidden_dim,
  kernel_size=1,
  padding=0,
  )
@@ -318,8 +330,8 @@ class TransformerBlock2D(nn.Module):
  )
  self.feed_forward = FeedForwardBlock2D(config.feed_forward_block_config)
  self.conv_out = nn.Conv2d(
- config.attention_block_config.dim,
- config.attention_block_config.dim,
+ attention_block_config_hidden_dim,
+ attention_block_config_dim,
  kernel_size=1,
  padding=0,
  )
@@ -385,14 +397,18 @@ class DownEncoderBlock2D(nn.Module):
  self.config = config
  resnets = []
  transformers = []
+ hidden_channels = config.hidden_channels
+ if not hidden_channels:
+ hidden_channels = config.out_channels
  for i in range(config.num_layers):
  input_channels = config.in_channels if i == 0 else config.out_channels
  resnets.append(
  ResidualBlock2D(
  unet_cfg.ResidualBlock2DConfig(
  in_channels=input_channels,
- hidden_channels=config.out_channels,
+ hidden_channels=hidden_channels,
  out_channels=config.out_channels,
+ residual_out_channels=config.out_channels,
  time_embedding_channels=config.time_embedding_channels,
  normalization_config=config.normalization_config,
  activation_config=config.activation_config,
@@ -589,23 +605,37 @@ class SkipUpDecoderBlock2D(nn.Module):
  """
  super().__init__()
  self.config = config
+ hidden_channels = config.hidden_channels
+ if not hidden_channels:
+ hidden_channels = config.out_channels
+ sub_block_channels = config.sub_block_channels
+ if sub_block_channels:
+ assert len(sub_block_channels) == config.num_layers, (
+ "Assertion failed: The length of 'sub_block_channels'"
+ f" ({len(sub_block_channels)}) does not match 'config.num_layers'"
+ f" ({config.num_layers})."
+ )
+ else:
+ sub_block_channels = [config.out_channels] * config.num_layers
  resnets = []
  transformers = []
  for i in range(config.num_layers):
+ resnet_in_channels = (
+ config.prev_out_channels if i == 0 else sub_block_channels[i - 1]
+ )
  res_skip_channels = (
  config.in_channels
  if (i == config.num_layers - 1)
  else config.out_channels
  )
- resnet_in_channels = (
- config.prev_out_channels if i == 0 else config.out_channels
- )
+ residual_out_channel = sub_block_channels[i]
  resnets.append(
  ResidualBlock2D(
  unet_cfg.ResidualBlock2DConfig(
  in_channels=resnet_in_channels + res_skip_channels,
- hidden_channels=config.out_channels,
- out_channels=config.out_channels,
+ hidden_channels=hidden_channels,
+ out_channels=sub_block_channels[i],
+ residual_out_channels=residual_out_channel,
  time_embedding_channels=config.time_embedding_channels,
  normalization_config=config.normalization_config,
  activation_config=config.activation_config,
@@ -613,7 +643,12 @@ class SkipUpDecoderBlock2D(nn.Module):
  )
  )
  if config.transformer_block_config:
- transformers.append(TransformerBlock2D(config.transformer_block_config))
+ transformers.append(
+ TransformerBlock2D(
+ config.transformer_block_config,
+ dim_override=sub_block_channels[i],
+ )
+ )
  self.resnets = nn.ModuleList(resnets)
  self.transformers = (
  nn.ModuleList(transformers) if len(transformers) > 0 else None
@@ -623,7 +658,7 @@ class SkipUpDecoderBlock2D(nn.Module):
  if config.upsample_conv:
  self.upsample_conv = nn.Conv2d(
  config.out_channels,
- config.out_channels,
+ sub_block_channels[0],
  kernel_size=3,
  stride=1,
  padding=1,
@@ -711,6 +746,7 @@ class MidBlock2D(nn.Module):
  in_channels=config.in_channels,
  hidden_channels=config.in_channels,
  out_channels=config.in_channels,
+ residual_out_channels=config.in_channels,
  time_embedding_channels=config.time_embedding_channels,
  normalization_config=config.normalization_config,
  activation_config=config.activation_config,

ai_edge_torch/generative/layers/unet/model_config.py CHANGED
@@ -50,10 +50,12 @@ class ResidualBlock2DConfig:
  in_channels: int
  hidden_channels: int
  out_channels: int
+ hidden_channels: int
  normalization_config: layers_cfg.NormalizationConfig
  activation_config: layers_cfg.ActivationConfig
  # Optional time embedding channels if the residual block takes a time embedding context as input
  time_embedding_channels: Optional[int] = None
+ residual_out_channels: Optional[int] = None


  @dataclasses.dataclass
@@ -63,6 +65,7 @@ class AttentionBlock2DConfig:
  attention_config: layers_cfg.AttentionConfig
  enable_hlfb: bool = True
  attention_batch_size: int = 1
+ hidden_dim: Optional[int] = None


  @dataclasses.dataclass
@@ -101,6 +104,8 @@ class UpDecoderBlock2DConfig:
  normalization_config: layers_cfg.NormalizationConfig
  activation_config: layers_cfg.ActivationConfig
  num_layers: int
+ # The dimension of output channels of previous connected block
+ prev_out_channels: Optional[int] = None
  # Optional time embedding channels if the residual blocks take a time embedding as input
  time_embedding_channels: Optional[int] = None
  # Whether to add upsample operation after residual blocks
@@ -136,6 +141,8 @@ class SkipUpDecoderBlock2DConfig:
  transformer_block_config: Optional[TransformerBlock2DConfig] = None
  # Optional dimension of context tensor if context tensor is given as input.
  context_dim: Optional[int] = None
+ sub_block_channels: Optional[tuple] = None
+ hidden_channels: Optional[int] = None


  @dataclasses.dataclass
@@ -157,6 +164,7 @@ class DownEncoderBlock2DConfig:
  transformer_block_config: Optional[TransformerBlock2DConfig] = None
  # Optional dimension of context tensor if context tensor is given as input.
  context_dim: Optional[int] = None
+ hidden_channels: Optional[int] = None


  @dataclasses.dataclass

ai_edge_torch/generative/utilities/loader.py CHANGED
@@ -157,6 +157,10 @@ class ModelLoader:
  converted_state["tok_embedding.weight"] = state.pop(
  f"{self._names.embedding}.weight"
  )
+ if model.config.embedding_use_bias:
+ converted_state["tok_embedding.bias"] = state.pop(
+ f"{self._names.embedding}.bias"
+ )
  if self._names.embedding_position is not None:
  converted_state["tok_embedding_position"] = state.pop(
  f"{self._names.embedding_position}"

ai_edge_torch/generative/utilities/verifier.py CHANGED
@@ -228,7 +228,7 @@ def verify_reauthored_model(
  rtol: float = 1e-05,
  atol: float = 1e-05,
  continue_on_failure: bool = False,
- ):
+ ) -> bool:
  """Verifies the reauthored model against the original model.

  It verifies the reauthored model with two methods:
@@ -237,7 +237,8 @@ def verify_reauthored_model(
  2. It compares the answer generated by the original and the reauthored model
  with a prompt.

- It prints out "PASS" or "FAILED" to the console.
+ It prints out "PASS" or "FAILED" to the console. It returns True if all
+ verification passes, False otherwise.

  Args:
  original_model (ModelWrapper): The original model.
@@ -253,6 +254,8 @@ def verify_reauthored_model(
  continue_on_failure (bool): If True, it continues to verify the next prompt
  or input IDs even if a previous one fails.
  """
+ failure_count = 0
+
  for input_ids in forward_input_ids:
  logging.info("Verifying the reauthored model with input IDs: %s", input_ids)
  try:
@@ -261,8 +264,9 @@ def verify_reauthored_model(
  )
  except AssertionError as e:
  logging.error("*** FAILED *** verify with input IDs: %s", input_ids)
+ failure_count += 1
  if not continue_on_failure:
- raise e
+ return False
  else:
  logging.info("*** PASSED *** verify with input IDs: %s", input_ids)

@@ -274,7 +278,15 @@ def verify_reauthored_model(
  )
  except AssertionError as e:
  logging.error("*** FAILED *** verify with prompts: %s", prompts)
+ failure_count += 1
  if not continue_on_failure:
- raise e
+ return False
  else:
  logging.info("*** PASSED *** verify with prompts: %s", prompts)
+
+ if failure_count == 0:
+ logging.info("*** PASSED *** verify_reauthored_model")
+ return True
+ else:
+ logging.error("*** FAILED *** verify_reauthored_model")
+ return False

ai_edge_torch/version.py CHANGED
@@ -13,4 +13,4 @@
  # limitations under the License.
  # ==============================================================================

- __version__ = "0.3.0.dev20241110"
+ __version__ = "0.3.0.dev20241115"

ai_edge_torch_nightly-0.3.0.dev20241115.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: ai-edge-torch-nightly
- Version: 0.3.0.dev20241110
+ Version: 0.3.0.dev20241115
  Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
  Home-page: https://github.com/google-ai-edge/ai-edge-torch
  Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI

ai_edge_torch_nightly-0.3.0.dev20241115.dist-info/RECORD CHANGED
@@ -3,7 +3,7 @@ ai_edge_torch/config.py,sha256=FMWeCH2b7HYILBvaI1iZNnYCO4WAhDOwBZBmIE-xrF0,909
  ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
  ai_edge_torch/fx_pass_base.py,sha256=518ziQ0TUxqum2qZXqlD8qr65pHPh8ZNLnwFC6zvK3k,4253
  ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
- ai_edge_torch/version.py,sha256=0kbL8PwrdMx4mw42_rj8uAYUeehe8jsFhw_tENefuGM,706
+ ai_edge_torch/version.py,sha256=pp4KVtq0a8ju4UB5nOeiv7QDkmgpHmz5XUokSR86qfI,706
  ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
  ai_edge_torch/_convert/conversion.py,sha256=HwzfRx_DX5TLtPqwEH1_NOm38_INvHzHl4_mX67KOdQ,5448
  ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -12,7 +12,7 @@ ai_edge_torch/_convert/signature.py,sha256=rGpBNss3Y9FCCCcdBwDo16KqavJi8N5P0M_6W
  ai_edge_torch/_convert/to_channel_last_io.py,sha256=_31phf7TYgZY2ftpNbrdlB1RhDium1lz_BXEQ6IsMFc,2893
  ai_edge_torch/_convert/fx_passes/__init__.py,sha256=NVe-eGcm7j8jZpP2pcMhC8j5dVjgR1pPzyXhHdvKH4E,1267
  ai_edge_torch/_convert/fx_passes/build_aten_composite_pass.py,sha256=doaww8KqrgRTD5LotBVAIRFsEqzPn9R5lcGehBJOczA,9098
- ai_edge_torch/_convert/fx_passes/build_interpolate_composite_pass.py,sha256=XCVqWg_ask0Kb64PED0ZGAODsUuIgfyO2ZJM6aK-TXI,4283
+ ai_edge_torch/_convert/fx_passes/build_interpolate_composite_pass.py,sha256=m_yj66V11LmWCYgA7yLtr__cy14IbC5WEJe0BE0_IPE,4339
  ai_edge_torch/_convert/fx_passes/inject_mlir_debuginfo_pass.py,sha256=IlZuK42kfVcRqAWZp4j2k_81T2uWo9T2558U_GPJAlU,2327
  ai_edge_torch/_convert/fx_passes/remove_non_user_outputs_pass.py,sha256=f1IUVWyhioOClsMiZzLyynoW2R17U83vA-7Q-3pGPM4,2126
  ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/__init__.py,sha256=lxnoH-WGLeiQIF8XjMGodjiZEFTxucl7g05N7MR9OPk,796
@@ -50,8 +50,8 @@ ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py,sha256=RZDs6
  ai_edge_torch/generative/examples/gemma/gemma1.py,sha256=oSbysiPvwp5efMbNYZop3HrxDMGiD15Tmz-HiQuTr2E,3315
  ai_edge_torch/generative/examples/gemma/gemma2.py,sha256=RQFQDMEnIVp8PefcCTr7P0CvllKI7FVoIJLXbPLLIsc,9056
  ai_edge_torch/generative/examples/gemma/verify_gemma1.py,sha256=ip-Gmk4CI5f0GWSdAIdrectxQWJ0t328KCsA4nfHuGg,1736
- ai_edge_torch/generative/examples/gemma/verify_gemma2.py,sha256=K77k-JpdhIwm3tbBnzpw8HQsFRwAVyszxRo82fR6-q4,1762
- ai_edge_torch/generative/examples/gemma/verify_util.py,sha256=sqltZbnyKemNvKqqi9d09i74gP-PPQFodRYfDfnhycQ,4933
+ ai_edge_torch/generative/examples/gemma/verify_gemma2.py,sha256=IoBhEMwH07-tFm5-U6F2hpCsI8xynglhq1x9tIOdaPQ,1322
+ ai_edge_torch/generative/examples/gemma/verify_util.py,sha256=tR8RflXocDZqvuStyw9aFlzuiTllEC8rNnjrxms6_Is,5727
  ai_edge_torch/generative/examples/llama/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
  ai_edge_torch/generative/examples/llama/convert_to_tflite.py,sha256=P0-pByTM5tslE23ILgo7nd0nOGE25ciBRG5wKJj0bBk,2411
  ai_edge_torch/generative/examples/llama/llama.py,sha256=AMcCbuDBxEfbO-l3KiEXbUaXEJ3RLLwkHii7to7UhVo,6854
@@ -62,7 +62,9 @@ ai_edge_torch/generative/examples/openelm/openelm.py,sha256=sFakstoPDcOHSak0IGFE
  ai_edge_torch/generative/examples/openelm/verify.py,sha256=VkigoqhAr8ew95neb3TifYv-SLOSheaWKv2AH0iKDrc,2441
  ai_edge_torch/generative/examples/paligemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
  ai_edge_torch/generative/examples/paligemma/decoder.py,sha256=JSb9h3gcIh5oYrbLU6rI8OU8FzfWeTCFJT5XRWu4btE,3675
+ ai_edge_torch/generative/examples/paligemma/image_encoder.py,sha256=v19_EKALhAP9FjkINKqpv8JsVaQ6iH_7X5FpnhE6abw,5500
  ai_edge_torch/generative/examples/paligemma/verify_decoder.py,sha256=al5wMPWri4IRVWrLmCplPi6uoCzwh0vBHMGnCt-XUqo,2690
+ ai_edge_torch/generative/examples/paligemma/verify_image_encoder.py,sha256=pSekf1BybhieQz3cQx_llbRQHxczXbTqool8fOyGj_0,3114
  ai_edge_torch/generative/examples/phi/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
  ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py,sha256=rkbTtMaqSVG48cm-NTxR_LDgZmXAEBqayTm9O49oMXc,2171
  ai_edge_torch/generative/examples/phi/convert_to_tflite.py,sha256=3go690yX6PFeXMdpY7y4JZorAwxX0HT_b_pKZieauvk,2169
@@ -108,19 +110,19 @@ ai_edge_torch/generative/examples/tiny_llama/verify.py,sha256=7Bk8z033M-BCXJ299f
  ai_edge_torch/generative/fx_passes/__init__.py,sha256=jrzCB3ZyY_t5jJM1e2Czdt3DjAIL43R0_a-T-I7wOzw,1155
  ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py,sha256=hhxSQvkDMv0isZJhmuLiod66ZODaJ8uSPSVTJVHBabQ,1931
  ai_edge_torch/generative/layers/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
- ai_edge_torch/generative/layers/attention.py,sha256=Z0Y_G8IG0LmvLX2u9D8__Fkr22szB-az6wMNnZpzhkA,13233
+ ai_edge_torch/generative/layers/attention.py,sha256=zN3BQjA25Ej_aRU0rFnyx--K74xf5ykc02zGvUpYHeE,13295
  ai_edge_torch/generative/layers/attention_utils.py,sha256=zBVwlBUTs-nStIKCZG0ks5ra7tsqc9ShfakFJKH5rds,7344
  ai_edge_torch/generative/layers/builder.py,sha256=Z5LyzCEThgnYZeyViakaE3yJVzTGHtw13acHsAQR15U,5050
  ai_edge_torch/generative/layers/feed_forward.py,sha256=hdICat-8gW7-vxDAevJQ8NQ-mynllPiqLdXQMF6JMnc,4189
  ai_edge_torch/generative/layers/kv_cache.py,sha256=lbm-yJ1jGPtcgWS4C3FmSnB1IlxqDE7g0BLRh3PN4N4,6324
- ai_edge_torch/generative/layers/model_config.py,sha256=DdsdhTP5tZAtyWim-qW2m8HDBsYbs7boqSDb83vwgmE,6998
+ ai_edge_torch/generative/layers/model_config.py,sha256=xqa7ZBEjgK4UWJAThRXb_VBFZ5KCGtDu-QaY5GXar9s,7366
  ai_edge_torch/generative/layers/normalization.py,sha256=eKAGst9rPuyRFExMcQFJO7R3iHdCtlmjeF_lITjLhwE,6498
  ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=CZqOoibLcHvUgrgaIIWAlmk3XgE2inzx340MN-npLoU,1347
  ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=gXxh3papKy4FBpGEX7VyZ7rZ1Js6aHK70Q6DKrVSckY,4154
  ai_edge_torch/generative/layers/unet/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
- ai_edge_torch/generative/layers/unet/blocks_2d.py,sha256=JwndhL3Z31TvkdGlAoTL5PQzmKfHdRWaaE1EbaMI4Gs,27540
+ ai_edge_torch/generative/layers/unet/blocks_2d.py,sha256=9jKzOfiBQ66bp1ZnVIAoREIifVNFx4aTlQeYMAx2_pA,29062
  ai_edge_torch/generative/layers/unet/builder.py,sha256=zAqWXdimmMrQRhmE_t9XkS68mh6PSrzwb-2NZZXrR5I,1901
- ai_edge_torch/generative/layers/unet/model_config.py,sha256=raYm8Ol-EFi0zs5vNqmj2ZJCFsnQW2TfwhgDcClfwFA,9356
+ ai_edge_torch/generative/layers/unet/model_config.py,sha256=pPDwLawc23pfMaPVyMJlYmxVVusjMvx-l8wBwOYOH-c,9692
  ai_edge_torch/generative/quantize/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
  ai_edge_torch/generative/quantize/example.py,sha256=1lfVNUd2cEyRUnoZ7BLbRJ9IN-FTKiWBtZNPFUzAiWE,1747
  ai_edge_torch/generative/quantize/quant_attrs.py,sha256=n1Fm8BFC8gJa_oiwwAOOghJyHtOXYZ4q-5ZRy4pHrIw,1957
@@ -137,12 +139,12 @@ ai_edge_torch/generative/test/test_quantize.py,sha256=8geJhKwYBU20m0mdGPD1BUFwQ0
  ai_edge_torch/generative/test/utils.py,sha256=YvEhO2HIj1LkBs5du1UxY-cGRW9HMyAYsOUhgsTrTpA,1796
  ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
  ai_edge_torch/generative/utilities/converter.py,sha256=17O83wVifH1vQJCI4WC3DaNiCIOyK2gys1GzohbLrRs,5554
- ai_edge_torch/generative/utilities/loader.py,sha256=b9iotIhVDX-Zc9XjIDUaLxnV395AyBnkQe3dV5YA7Co,13297
+ ai_edge_torch/generative/utilities/loader.py,sha256=k5fjCokNomte4ymy9IJrEWAuCSMhsPCJfmv1y5s0ZEc,13452
  ai_edge_torch/generative/utilities/model_builder.py,sha256=89jt80UUfDzYBi-x077HBavWeuNJuYPXym9fiKCY1Tk,5278
  ai_edge_torch/generative/utilities/stable_diffusion_loader.py,sha256=dqPD9qRXEWtU3ombslOC-BE2l_dMwHoCNu7NsIJhsso,36158
  ai_edge_torch/generative/utilities/t5_loader.py,sha256=tEsfy8-ymzbbjOIc-oesXF3yGyyWtJgFXn2s7VOavt8,16961
  ai_edge_torch/generative/utilities/transformers_verifier.py,sha256=8sp9m_FMcXn7nqOrochtu2jIANkJKhnhIBUmH0ZTDR4,1549
- ai_edge_torch/generative/utilities/verifier.py,sha256=5C2cm54d9kwL7nGRX-YfnBIJny1ICNhiU-LB3IqJq2E,10075
+ ai_edge_torch/generative/utilities/verifier.py,sha256=h5hGyIpYGyPZwvelbzpdkjy99Kpd4JkvhqWtQN9cm-M,10413
  ai_edge_torch/hlfb/__init__.py,sha256=sH4um75na-O8tzxN6chFyp6Y4xnexsE7kUQpZySv6dE,735
  ai_edge_torch/hlfb/mark_pattern/__init__.py,sha256=cjTprggj_cuktSCm7-A25e7Shop3k63ylp7sdZmtZ8o,4790
  ai_edge_torch/hlfb/mark_pattern/passes.py,sha256=pjkKcI1nHECPluAt87cFBrt1DP0f3ge7rHq1NhCkBIE,1936
@@ -189,8 +191,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
  ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
  ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
  ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
- ai_edge_torch_nightly-0.3.0.dev20241110.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
- ai_edge_torch_nightly-0.3.0.dev20241110.dist-info/METADATA,sha256=ECohBv1Uc5BzRcnT3r3yM8_sElMqIMmpYcnRP_nOp84,1897
- ai_edge_torch_nightly-0.3.0.dev20241110.dist-info/WHEEL,sha256=bFJAMchF8aTQGUgMZzHJyDDMPTO3ToJ7x23SLJa1SVo,92
- ai_edge_torch_nightly-0.3.0.dev20241110.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
- ai_edge_torch_nightly-0.3.0.dev20241110.dist-info/RECORD,,
+ ai_edge_torch_nightly-0.3.0.dev20241115.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
+ ai_edge_torch_nightly-0.3.0.dev20241115.dist-info/METADATA,sha256=epuuYZFnqVvLzIS0X27XMCFQpnc-dO8JJQ8DXVNv5IE,1897
+ ai_edge_torch_nightly-0.3.0.dev20241115.dist-info/WHEEL,sha256=bFJAMchF8aTQGUgMZzHJyDDMPTO3ToJ7x23SLJa1SVo,92
+ ai_edge_torch_nightly-0.3.0.dev20241115.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
+ ai_edge_torch_nightly-0.3.0.dev20241115.dist-info/RECORD,,