ai-edge-torch-nightly 0.3.0.dev20241108-py3-none-any.whl → 0.3.0.dev20241114-py3-none-any.whl

Files changed (21)
  1. ai_edge_torch/_convert/fx_passes/build_interpolate_composite_pass.py +2 -0
  2. ai_edge_torch/generative/examples/gemma/verify_gemma2.py +1 -14
  3. ai_edge_torch/generative/examples/gemma/verify_util.py +28 -3
  4. ai_edge_torch/generative/examples/openelm/openelm.py +1 -0
  5. ai_edge_torch/generative/examples/paligemma/__init__.py +14 -0
  6. ai_edge_torch/generative/examples/paligemma/decoder.py +103 -0
  7. ai_edge_torch/generative/examples/paligemma/image_encoder.py +158 -0
  8. ai_edge_torch/generative/examples/paligemma/verify_decoder.py +75 -0
  9. ai_edge_torch/generative/examples/paligemma/verify_image_encoder.py +82 -0
  10. ai_edge_torch/generative/layers/attention.py +8 -6
  11. ai_edge_torch/generative/layers/model_config.py +14 -0
  12. ai_edge_torch/generative/layers/unet/blocks_2d.py +53 -17
  13. ai_edge_torch/generative/layers/unet/model_config.py +8 -0
  14. ai_edge_torch/generative/utilities/loader.py +4 -0
  15. ai_edge_torch/generative/utilities/verifier.py +46 -21
  16. ai_edge_torch/version.py +1 -1
  17. {ai_edge_torch_nightly-0.3.0.dev20241108.dist-info → ai_edge_torch_nightly-0.3.0.dev20241114.dist-info}/METADATA +1 -1
  18. {ai_edge_torch_nightly-0.3.0.dev20241108.dist-info → ai_edge_torch_nightly-0.3.0.dev20241114.dist-info}/RECORD +21 -16
  19. {ai_edge_torch_nightly-0.3.0.dev20241108.dist-info → ai_edge_torch_nightly-0.3.0.dev20241114.dist-info}/WHEEL +1 -1
  20. {ai_edge_torch_nightly-0.3.0.dev20241108.dist-info → ai_edge_torch_nightly-0.3.0.dev20241114.dist-info}/LICENSE +0 -0
  21. {ai_edge_torch_nightly-0.3.0.dev20241108.dist-info → ai_edge_torch_nightly-0.3.0.dev20241114.dist-info}/top_level.txt +0 -0
ai_edge_torch/_convert/fx_passes/build_interpolate_composite_pass.py CHANGED
@@ -51,6 +51,7 @@ def _get_upsample_bilinear2d_pattern():
     return {
         "output": (int(output_h), int(output_w)),
         "align_corners": False,
+        "is_nchw_op": True,
     }
 
   return pattern
@@ -74,6 +75,7 @@ def _get_upsample_bilinear2d_align_corners_pattern():
     return {
         "output": (int(output_h), int(output_w)),
         "align_corners": True,
+        "is_nchw_op": True,
     }
 
   return pattern
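For context, the composites these patterns mark come from ordinary torch.nn.functional.interpolate calls. A minimal sketch of a module whose bilinear upsample would be matched (the module name is illustrative, not from the package):

import torch
import torch.nn.functional as F

class Upsample2x(torch.nn.Module):
  # Bilinear 2x upsample on NCHW input; the pass marks this as an
  # odml.upsample_bilinear2d composite, now annotated with is_nchw_op=True.
  def forward(self, x):
    return F.interpolate(x, scale_factor=2, mode="bilinear", align_corners=False)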
ai_edge_torch/generative/examples/gemma/verify_gemma2.py CHANGED
@@ -15,10 +15,8 @@
 
 """Verifies the reauthored Gemma2 model."""
 
-import logging
 from absl import app
 from absl import flags
-from ai_edge_torch.generative.examples.gemma import gemma2
 from ai_edge_torch.generative.examples.gemma import verify_util
 import kagglehub
 
@@ -38,18 +36,7 @@ _MAX_NEW_TOKENS = flags.DEFINE_integer(
 def main(_):
   checkpoint = kagglehub.model_download("google/gemma-2/pyTorch/gemma-2-2b-it")
 
-  logging.info("Building the reauthored model from: %s", checkpoint)
-  reauthored_model = gemma2.build_2b_model(checkpoint)
-
-  verify_util.verify_reauthored_gemma_model(
-      checkpoint=checkpoint,
-      variant="2b-v2",
-      reauthored_model=reauthored_model,
-      generate_prompts=_PROMPTS.value,
-      forward_input_ids=[[2, 651, 9456, 576, 573, 3520, 3858, 603, 235248]],
-      max_new_tokens=_MAX_NEW_TOKENS.value,
-      atol=1e-04,
-  )
+  verify_util.verify_gemma2(checkpoint, _PROMPTS.value, _MAX_NEW_TOKENS.value)
 
 
 if __name__ == "__main__":
ai_edge_torch/generative/examples/gemma/verify_util.py CHANGED
@@ -19,6 +19,7 @@ import logging
 import os
 from typing import List, Tuple
 
+from ai_edge_torch.generative.examples.gemma import gemma2
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
 from ai_edge_torch.generative.utilities import verifier
 from gemma import config as gemma_config
@@ -109,8 +110,11 @@ def verify_reauthored_gemma_model(
     max_new_tokens: int = 20,
     rtol: float = 1e-05,
     atol: float = 1e-05,
-):
-  """Verifies the reauthored Gemma model against the original model."""
+) -> bool:
+  """Verifies the reauthored Gemma model against the original model.
+
+  Returns True if the verification passes, False otherwise.
+  """
   config = gemma_config.get_model_config(variant)
   config.tokenizer = os.path.join(checkpoint, tokenizer_filename)
   # Use float32 to be compatible with the reauthored model.
@@ -120,7 +124,7 @@ def verify_reauthored_gemma_model(
   original_model = gemma_model.GemmaForCausalLM(config).eval()
   original_model.load_weights(os.path.join(checkpoint, weight_filename))
 
-  verifier.verify_reauthored_model(
+  return verifier.verify_reauthored_model(
       original_model=GemmaWrapper(original_model),
       reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
       tokenizer=GemmaTokenizerWrapper(original_model.tokenizer),
@@ -130,3 +134,24 @@ def verify_reauthored_gemma_model(
       rtol=rtol,
       atol=atol,
   )
+
+
+def verify_gemma2(
+    gemma2_model_path: str, prompts: List[str], max_new_tokens: int
+) -> bool:
+  """Verifies the reauthored Gemma2 model.
+
+  Returns True if the verification passes, False otherwise.
+  """
+  logging.info("Building the reauthored model from: %s", gemma2_model_path)
+  reauthored_model = gemma2.build_2b_model(gemma2_model_path)
+
+  return verify_reauthored_gemma_model(
+      checkpoint=gemma2_model_path,
+      variant="2b-v2",
+      reauthored_model=reauthored_model,
+      generate_prompts=prompts,
+      forward_input_ids=[[2, 651, 9456, 576, 573, 3520, 3858, 603, 235248]],
+      max_new_tokens=max_new_tokens,
+      atol=1e-04,
+  )
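Together, these two Gemma changes shrink the verify script to one call. A hedged usage sketch (the prompt text and token budget mirror the script defaults; they are not requirements):

import kagglehub
from ai_edge_torch.generative.examples.gemma import verify_util

checkpoint = kagglehub.model_download("google/gemma-2/pyTorch/gemma-2-2b-it")
ok = verify_util.verify_gemma2(checkpoint, ["What is the meaning of life?"], 30)
print("verification passed" if ok else "verification failed")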
ai_edge_torch/generative/examples/openelm/openelm.py CHANGED
@@ -93,6 +93,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       kv_cache_max_len=kv_cache_max_len,
       block_configs=[get_block_config(i) for i in range(num_layers)],
       final_norm_config=norm_config,
+      enable_hlfb=True,
   )
   return config
 
ai_edge_torch/generative/examples/paligemma/__init__.py ADDED
@@ -0,0 +1,14 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
ai_edge_torch/generative/examples/paligemma/decoder.py ADDED
@@ -0,0 +1,103 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Example of building a decoder of PaliGemma 3B model which is Gemma1."""
+
+import ai_edge_torch.generative.layers.model_config as cfg
+from ai_edge_torch.generative.utilities import model_builder
+import ai_edge_torch.generative.utilities.loader as loading_utils
+
+TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
+    ff_up_proj="language_model.model.layers.{}.mlp.up_proj",
+    ff_down_proj="language_model.model.layers.{}.mlp.down_proj",
+    ff_gate_proj="language_model.model.layers.{}.mlp.gate_proj",
+    attn_query_proj="language_model.model.layers.{}.self_attn.q_proj",
+    attn_key_proj="language_model.model.layers.{}.self_attn.k_proj",
+    attn_value_proj="language_model.model.layers.{}.self_attn.v_proj",
+    attn_output_proj="language_model.model.layers.{}.self_attn.o_proj",
+    pre_attn_norm="language_model.model.layers.{}.input_layernorm",
+    post_attn_norm="language_model.model.layers.{}.post_attention_layernorm",
+    embedding="language_model.model.embed_tokens",
+    final_norm="language_model.model.norm",
+    lm_head=None,
+)
+
+
+def get_decoder_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
+  """Returns the model config for the decoder of a PaliGemma 3B model.
+
+  Args:
+    kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
+      is 1024.
+
+  Returns:
+    The model config for the decoder of a PaliGemma 3B model.
+  """
+  attn_config = cfg.AttentionConfig(
+      num_heads=8,
+      head_dim=256,
+      num_query_groups=1,
+      rotary_base=10000,
+      rotary_percentage=1.0,
+  )
+  ff_config = cfg.FeedForwardConfig(
+      type=cfg.FeedForwardType.GATED,
+      activation=cfg.ActivationConfig(cfg.ActivationType.GELU_TANH),
+      intermediate_size=16384,
+  )
+  norm_config = cfg.NormalizationConfig(
+      type=cfg.NormalizationType.RMS_NORM,
+      epsilon=1e-6,
+      zero_centered=True,
+  )
+  block_config = cfg.TransformerBlockConfig(
+      attn_config=attn_config,
+      ff_config=ff_config,
+      pre_attention_norm_config=norm_config,
+      post_attention_norm_config=norm_config,
+  )
+  config = cfg.ModelConfig(
+      vocab_size=257216,
+      num_layers=18,
+      max_seq_len=8192,
+      embedding_dim=2048,
+      embedding_scale=2048**0.5,
+      kv_cache_max_len=kv_cache_max_len,
+      block_configs=block_config,
+      final_norm_config=norm_config,
+      lm_head_use_bias=False,
+      enable_hlfb=True,
+  )
+  return config
+
+
+def get_fake_decoder_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
+  config = get_decoder_config(kv_cache_max_len)
+  # PaliGemma decoder has only one block config.
+  config.block_config(0).ff_config.intermediate_size = 128
+  config.vocab_size = 128
+  config.num_layers = 2
+  config.max_seq_len = 2 * kv_cache_max_len
+  return config
+
+
+def build_decoder(
+    checkpoint_path: str, **kwargs
+) -> model_builder.DecoderOnlyModel:
+  return model_builder.build_decoder_only_model(
+      checkpoint_path=checkpoint_path,
+      config=get_decoder_config(**kwargs),
+      tensor_names=TENSOR_NAMES,
+  )
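A hedged usage sketch; the checkpoint path is a placeholder for a locally downloaded PaliGemma checkpoint directory:

from ai_edge_torch.generative.examples.paligemma import decoder

# Placeholder path; point it at a local PaliGemma 3B checkpoint directory.
model = decoder.build_decoder("/path/to/paligemma-3b-mix-224", kv_cache_max_len=1024)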
ai_edge_torch/generative/examples/paligemma/image_encoder.py ADDED
@@ -0,0 +1,158 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Example of building an image encoder of PaliGemma model which is Siglip."""
+
+from ai_edge_torch.generative.layers import attention
+from ai_edge_torch.generative.layers import builder
+import ai_edge_torch.generative.layers.model_config as cfg
+import ai_edge_torch.generative.utilities.loader as loading_utils
+import torch
+from torch import nn
+
+TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
+    ff_up_proj="vision_tower.vision_model.encoder.layers.{}.mlp.fc1",
+    ff_down_proj="vision_tower.vision_model.encoder.layers.{}.mlp.fc2",
+    attn_query_proj=(
+        "vision_tower.vision_model.encoder.layers.{}.self_attn.q_proj"
+    ),
+    attn_key_proj=(
+        "vision_tower.vision_model.encoder.layers.{}.self_attn.k_proj"
+    ),
+    attn_value_proj=(
+        "vision_tower.vision_model.encoder.layers.{}.self_attn.v_proj"
+    ),
+    attn_output_proj=(
+        "vision_tower.vision_model.encoder.layers.{}.self_attn.out_proj"
+    ),
+    pre_attn_norm="vision_tower.vision_model.encoder.layers.{}.layer_norm1",
+    post_attn_norm="vision_tower.vision_model.encoder.layers.{}.layer_norm2",
+    embedding="vision_tower.vision_model.embeddings.patch_embedding",
+    embedding_position=(
+        "vision_tower.vision_model.embeddings.position_embedding.weight"
+    ),
+    final_norm="vision_tower.vision_model.post_layernorm",
+)
+
+
+class SiglipVisionEncoder(nn.Module):
+  """Siglip vision encoder from the Edge Generative API."""
+
+  def __init__(self, config: cfg.ModelConfig):
+    super().__init__()
+
+    # Construct model layers.
+    self.tok_embedding = nn.Conv2d(
+        in_channels=config.image_embedding.channels,
+        out_channels=config.embedding_dim,
+        kernel_size=config.image_embedding.patch_size,
+        stride=config.image_embedding.patch_size,
+        padding="valid",
+    )
+    num_patches = (
+        config.image_embedding.image_size // config.image_embedding.patch_size
+    ) ** 2
+    self.tok_embedding_position = nn.Parameter(
+        torch.zeros((num_patches, config.embedding_dim))
+    )
+
+    self.transformer_blocks = nn.ModuleList(
+        attention.TransformerBlock(config.block_config(idx), config)
+        for idx in range(config.num_layers)
+    )
+    self.final_norm = builder.build_norm(
+        config.embedding_dim,
+        config.final_norm_config,
+    )
+    self.config = config
+
+  @torch.inference_mode
+  def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
+    # Embed the image according to SiglipVisionEmbeddings.
+    x = self.tok_embedding(pixel_values)
+    x = x.flatten(2).transpose(1, 2) + self.tok_embedding_position
+
+    # Pass a dummy mask because SDPA attention impl expects non-None mask.
+    mask = torch.zeros(x.shape[:2])
+    for _, block in enumerate(self.transformer_blocks):
+      x = block(x, mask=mask)
+    return self.final_norm(x)
+
+
+def get_image_encoder_config() -> cfg.ModelConfig:
+  """Returns the model config for the image encoder of a PaliGemma 3B-224 model.
+
+  Returns:
+    The model config for the image encoder of a PaliGemma 3B model.
+  """
+  image_embedding_config = cfg.ImageEmbeddingConfig(
+      channels=3,
+      image_size=224,
+      patch_size=14,
+  )
+  attn_config = cfg.AttentionConfig(
+      num_heads=16,
+      head_dim=72,
+      num_query_groups=16,
+      qkv_use_bias=True,
+      output_proj_use_bias=True,
+  )
+  ff_config = cfg.FeedForwardConfig(
+      type=cfg.FeedForwardType.SEQUENTIAL,
+      activation=cfg.ActivationConfig(cfg.ActivationType.GELU_TANH),
+      intermediate_size=4304,
+      use_bias=True,
+  )
+  norm_config = cfg.NormalizationConfig(
+      type=cfg.NormalizationType.LAYER_NORM,
+      epsilon=1e-6,
+      enable_hlfb=True,
+  )
+  block_config = cfg.TransformerBlockConfig(
+      attn_config=attn_config,
+      ff_config=ff_config,
+      pre_attention_norm_config=norm_config,
+      post_attention_norm_config=norm_config,
+  )
+  config = cfg.ModelConfig(
+      vocab_size=0,  # Not used in image encoder.
+      num_layers=27,
+      max_seq_len=0,  # Not used in image encoder.
+      embedding_dim=1152,
+      embedding_use_bias=True,
+      image_embedding=image_embedding_config,
+      block_configs=block_config,
+      final_norm_config=norm_config,
+      enable_hlfb=True,
+  )
+  return config
+
+
+def get_fake_image_encoder_config() -> cfg.ModelConfig:
+  config = get_image_encoder_config()
+  # PaliGemma image encoder has only one block config.
+  config.block_config(0).ff_config.intermediate_size = 128
+  config.num_layers = 2
+  return config
+
+
+def build_image_encoder(checkpoint_path: str) -> SiglipVisionEncoder:
+  config = get_image_encoder_config()
+  encoder = SiglipVisionEncoder(config)
+  loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
+  # Loosen the strictness because only the image encoder is being loaded.
+  loader.load(encoder, strict=False)
+  encoder.eval()
+  return encoder
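A hedged usage sketch; the checkpoint path is a placeholder for a locally cached PaliGemma checkpoint, and the input is a random stand-in for a preprocessed image batch:

import torch
from ai_edge_torch.generative.examples.paligemma import image_encoder

encoder = image_encoder.build_image_encoder("/path/to/paligemma-3b-mix-224")
pixel_values = torch.randn(1, 3, 224, 224)  # one 224x224 RGB image
embeddings = encoder(pixel_values)  # (1, 256, 1152): (224/14)**2 patch embeddings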
ai_edge_torch/generative/examples/paligemma/verify_decoder.py ADDED
@@ -0,0 +1,75 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Verifies the reauthored decoder of PaliGemma 3B model."""
+
+import logging
+import pathlib
+
+from absl import app
+from absl import flags
+from ai_edge_torch.generative.examples.paligemma import decoder
+from ai_edge_torch.generative.utilities import transformers_verifier
+from ai_edge_torch.generative.utilities import verifier
+import transformers
+
+_PROMPTS = flags.DEFINE_multi_string(
+    "prompts",
+    "What is the meaning of life?",
+    "The input prompts to generate answers.",
+)
+_MAX_NEW_TOKENS = flags.DEFINE_integer(
+    "max_new_tokens",
+    30,
+    "The maximum size of the generated tokens.",
+)
+
+
+def main(_):
+  checkpoint = "google/paligemma-3b-mix-224"
+  logging.info("Loading the original model from: %s", checkpoint)
+  original_full_model = (
+      transformers.PaliGemmaForConditionalGeneration.from_pretrained(checkpoint)
+  )
+  original_language_model = original_full_model.eval().language_model
+
+  # Locate the cached dir.
+  cached_config_file = transformers.utils.cached_file(
+      checkpoint, transformers.utils.CONFIG_NAME
+  )
+  reauthored_checkpoint = pathlib.Path(cached_config_file).parent
+  logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
+  reauthored_model = decoder.build_decoder(reauthored_checkpoint)
+
+  logging.info("Loading the tokenizer from: %s", checkpoint)
+  # It works only when GemmaTokenizerFast is available. In some environments,
+  # use_fast=False doesn't work either if the tokenizer cannot load the
+  # sentencepiece model file properly.
+  processor = transformers.AutoProcessor.from_pretrained(checkpoint)
+
+  verifier.verify_reauthored_model(
+      original_model=transformers_verifier.TransformersModelWrapper(
+          original_language_model
+      ),
+      reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
+      tokenizer=verifier.TokenizerWrapper(processor.tokenizer),
+      generate_prompts=_PROMPTS.value,
+      max_new_tokens=_MAX_NEW_TOKENS.value,
+      atol=1e-04,
+  )
+
+
+if __name__ == "__main__":
+  app.run(main)
ai_edge_torch/generative/examples/paligemma/verify_image_encoder.py ADDED
@@ -0,0 +1,82 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Verifies the reauthored image encoder of PaliGemma 3B model."""
+
+import logging
+import pathlib
+from absl import app
+from absl import flags
+from ai_edge_torch.generative.examples.paligemma import image_encoder
+from PIL import Image
+import requests
+import torch
+import transformers
+
+_IMAGE_URL = flags.DEFINE_string(
+    "image_url",
+    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg?download=true",
+    "The image URI to encode.",
+)
+
+
+def main(_):
+  checkpoint = "google/paligemma-3b-mix-224"
+  logging.info("Loading the original model from: %s", checkpoint)
+  original_full_model = (
+      transformers.PaliGemmaForConditionalGeneration.from_pretrained(checkpoint)
+  )
+  original_vision_model = original_full_model.eval().vision_tower
+
+  # Locate the cached dir.
+  cached_config_file = transformers.utils.cached_file(
+      checkpoint, transformers.utils.CONFIG_NAME
+  )
+  reauthored_checkpoint = pathlib.Path(cached_config_file).parent
+  logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
+  reauthored_model = image_encoder.build_image_encoder(reauthored_checkpoint)
+
+  logging.info("Loading the processor from: %s", checkpoint)
+  # It works only when GemmaTokenizerFast is available. In some environments,
+  # use_fast=False doesn't work either if the tokenizer cannot load the
+  # sentencepiece model file properly.
+  processor = transformers.AutoProcessor.from_pretrained(checkpoint)
+
+  logging.info("Loading the image from: %s", _IMAGE_URL.value)
+  image = Image.open(requests.get(_IMAGE_URL.value, stream=True).raw)
+  pixel_values = processor(images=image, return_tensors="pt")["pixel_values"]
+
+  logging.info("Forwarding the original model...")
+  outputs_original = original_vision_model.forward(pixel_values=pixel_values)
+  outputs_original = outputs_original.last_hidden_state
+  logging.info("outputs_original: %s", outputs_original)
+
+  logging.info("Forwarding the reauthored model...")
+  outputs_reauthored = reauthored_model.forward(pixel_values=pixel_values)
+  logging.info("outputs_reauthored: %s", outputs_reauthored)
+
+  try:
+    assert torch.allclose(
+        outputs_original, outputs_reauthored, atol=1e-04, rtol=1e-04
+    )
+  except AssertionError as e:
+    logging.error("*** FAILED *** verify with an image")
+    raise e
+  else:
+    logging.info("*** PASSED *** verify with an image")
+
+
+if __name__ == "__main__":
+  app.run(main)
ai_edge_torch/generative/layers/attention.py CHANGED
@@ -235,9 +235,10 @@ class CausalSelfAttention(nn.Module):
     k = k.reshape(B, T, -1, self.config.head_dim)
     v = v.reshape(B, T, -1, self.config.head_dim)
 
-    # Compute rotary positional embedding for query and key.
-    n_elem = int(self.config.rotary_percentage * self.config.head_dim)
-    q, k = _embed_rope(q, k, n_elem, rope)
+    if rope is not None:
+      # Compute rotary positional embedding for query and key.
+      n_elem = int(self.config.rotary_percentage * self.config.head_dim)
+      q, k = _embed_rope(q, k, n_elem, rope)
 
     if kv_cache is not None:
       kv_cache = kv_utils.update(
@@ -372,9 +373,10 @@ class CrossAttention(nn.Module):
     k = k.view(interim_shape)
     v = v.view(interim_shape)
 
-    # Compute rotary positional embedding for query and key.
-    n_elem = int(self.config.rotary_percentage * self.config.head_dim)
-    q, k = _embed_rope(q, k, n_elem, rope)
+    if rope is not None:
+      # Compute rotary positional embedding for query and key.
+      n_elem = int(self.config.rotary_percentage * self.config.head_dim)
+      q, k = _embed_rope(q, k, n_elem, rope)
 
     if kv_cache is not None:
       kv_cache = kv_utils.update(
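These guards are what let the Siglip encoder above call its transformer blocks with rope left as None. A small sketch exercising that path with the fake (random-weight) config from this diff; whether it runs in a given environment depends on the package version, so treat it as illustrative:

import torch
from ai_edge_torch.generative.examples.paligemma import image_encoder

config = image_encoder.get_fake_image_encoder_config()
encoder = image_encoder.SiglipVisionEncoder(config).eval()
out = encoder(torch.randn(1, 3, 224, 224))  # blocks run with rope=None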
ai_edge_torch/generative/layers/model_config.py CHANGED
@@ -163,6 +163,16 @@ class TransformerBlockConfig:
   relative_attention: bool = False
 
 
+@dataclass
+class ImageEmbeddingConfig:
+  """Image embedding parameters."""
+
+  channels: int
+  # All images should be normalized to the size of [image_size * image_size].
+  image_size: int
+  patch_size: int
+
+
 @dataclass
 class ModelConfig:
   """Base configurations for building a transformer architecture."""
@@ -183,6 +193,10 @@ class ModelConfig:
 
   # Scale factor of the embedding.
   embedding_scale: Optional[float] = None
+  # Use bias term within embedding.
+  embedding_use_bias: bool = False
+  # Image embedding parameters.
+  image_embedding: Optional[ImageEmbeddingConfig] = None
 
   # Use bias term within LLM's HEAD.
   lm_head_use_bias: bool = False
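A minimal sketch of the new fields, mirroring what get_image_encoder_config() above does:

import ai_edge_torch.generative.layers.model_config as cfg

image_embedding = cfg.ImageEmbeddingConfig(
    channels=3,      # RGB input
    image_size=224,  # images are normalized to 224x224
    patch_size=14,   # (224 / 14)**2 = 256 patches per image
)
# Then pass image_embedding=... to ModelConfig, with embedding_use_bias=True
# when the patch-embedding convolution carries a bias term.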
ai_edge_torch/generative/layers/unet/blocks_2d.py CHANGED
@@ -115,12 +115,15 @@ class AttentionBlock2D(nn.Module):
     """
     super().__init__()
     self.config = config
+    hidden_dim = config.hidden_dim
+    if not hidden_dim:
+      hidden_dim = config.dim
     self.norm = layers_builder.build_norm(
-        config.dim, config.normalization_config
+        hidden_dim, config.normalization_config
     )
     self.attention = SelfAttention(
         config.attention_batch_size,
-        config.dim,
+        hidden_dim,
         config.attention_config,
         enable_hlfb=config.enable_hlfb,
     )
@@ -172,7 +175,7 @@ class CrossAttentionBlock2D(nn.Module):
     super().__init__()
     self.config = config
     self.norm = layers_builder.build_norm(
-        config.query_dim, config.normalization_config
+        config.output_dim, config.normalization_config
     )
     self.attention = CrossAttention(
         config.attention_batch_size,
@@ -294,21 +297,30 @@ class TransformerBlock2D(nn.Module):
       hidden_states
   """
 
-  def __init__(self, config: unet_cfg.TransformerBlock2DConfig):
+  def __init__(
+      self, config: unet_cfg.TransformerBlock2DConfig, dim_override=None
+  ):
     """Initialize an instance of the TransformerBlock2D.
 
     Args:
       config (unet_cfg.TransformerBlock2Dconfig): the configuration of this
        block.
+      dim_override: in case specified, overrides config.attention_block_config.hidden_dim. Set to None by default.
     """
     super().__init__()
     self.config = config
+    attention_block_config_dim = config.attention_block_config.dim
+    attention_block_config_hidden_dim = config.attention_block_config.hidden_dim
+    if dim_override:
+      attention_block_config_dim = dim_override
+    if not attention_block_config_hidden_dim:
+      attention_block_config_hidden_dim = attention_block_config_dim
     self.pre_conv_norm = layers_builder.build_norm(
-        config.attention_block_config.dim, config.pre_conv_normalization_config
+        attention_block_config_dim, config.pre_conv_normalization_config
     )
     self.conv_in = nn.Conv2d(
-        config.attention_block_config.dim,
-        config.attention_block_config.dim,
+        attention_block_config_dim,
+        attention_block_config_hidden_dim,
         kernel_size=1,
         padding=0,
     )
@@ -318,8 +330,8 @@ class TransformerBlock2D(nn.Module):
     )
     self.feed_forward = FeedForwardBlock2D(config.feed_forward_block_config)
     self.conv_out = nn.Conv2d(
-        config.attention_block_config.dim,
-        config.attention_block_config.dim,
+        attention_block_config_hidden_dim,
+        attention_block_config_dim,
         kernel_size=1,
         padding=0,
     )
@@ -385,14 +397,18 @@ class DownEncoderBlock2D(nn.Module):
     self.config = config
     resnets = []
     transformers = []
+    hidden_channels = config.hidden_channels
+    if not hidden_channels:
+      hidden_channels = config.out_channels
     for i in range(config.num_layers):
       input_channels = config.in_channels if i == 0 else config.out_channels
       resnets.append(
           ResidualBlock2D(
               unet_cfg.ResidualBlock2DConfig(
                   in_channels=input_channels,
-                  hidden_channels=config.out_channels,
+                  hidden_channels=hidden_channels,
                   out_channels=config.out_channels,
+                  residual_out_channels=config.out_channels,
                   time_embedding_channels=config.time_embedding_channels,
                   normalization_config=config.normalization_config,
                   activation_config=config.activation_config,
@@ -589,23 +605,37 @@ class SkipUpDecoderBlock2D(nn.Module):
     """
     super().__init__()
     self.config = config
+    hidden_channels = config.hidden_channels
+    if not hidden_channels:
+      hidden_channels = config.out_channels
+    sub_block_channels = config.sub_block_channels
+    if sub_block_channels:
+      assert len(sub_block_channels) == config.num_layers, (
+          "Assertion failed: The length of 'sub_block_channels'"
+          f" ({len(sub_block_channels)}) does not match 'config.num_layers'"
+          f" ({config.num_layers})."
+      )
+    else:
+      sub_block_channels = [config.out_channels] * config.num_layers
     resnets = []
     transformers = []
     for i in range(config.num_layers):
+      resnet_in_channels = (
+          config.prev_out_channels if i == 0 else sub_block_channels[i - 1]
+      )
       res_skip_channels = (
           config.in_channels
           if (i == config.num_layers - 1)
          else config.out_channels
       )
-      resnet_in_channels = (
-          config.prev_out_channels if i == 0 else config.out_channels
-      )
+      residual_out_channel = sub_block_channels[i]
       resnets.append(
           ResidualBlock2D(
              unet_cfg.ResidualBlock2DConfig(
                  in_channels=resnet_in_channels + res_skip_channels,
-                  hidden_channels=config.out_channels,
-                  out_channels=config.out_channels,
+                  hidden_channels=hidden_channels,
+                  out_channels=sub_block_channels[i],
+                  residual_out_channels=residual_out_channel,
                  time_embedding_channels=config.time_embedding_channels,
                  normalization_config=config.normalization_config,
                  activation_config=config.activation_config,
@@ -613,7 +643,12 @@ class SkipUpDecoderBlock2D(nn.Module):
           )
       )
       if config.transformer_block_config:
-        transformers.append(TransformerBlock2D(config.transformer_block_config))
+        transformers.append(
+            TransformerBlock2D(
+                config.transformer_block_config,
+                dim_override=sub_block_channels[i],
+            )
+        )
     self.resnets = nn.ModuleList(resnets)
     self.transformers = (
         nn.ModuleList(transformers) if len(transformers) > 0 else None
@@ -623,7 +658,7 @@ class SkipUpDecoderBlock2D(nn.Module):
     if config.upsample_conv:
       self.upsample_conv = nn.Conv2d(
           config.out_channels,
-          config.out_channels,
+          sub_block_channels[0],
           kernel_size=3,
           stride=1,
           padding=1,
@@ -711,6 +746,7 @@ class MidBlock2D(nn.Module):
             in_channels=config.in_channels,
             hidden_channels=config.in_channels,
             out_channels=config.in_channels,
+            residual_out_channels=config.in_channels,
             time_embedding_channels=config.time_embedding_channels,
             normalization_config=config.normalization_config,
             activation_config=config.activation_config,
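A standalone sketch of the channel bookkeeping SkipUpDecoderBlock2D now performs (an illustrative helper, not library code):

def resolve_sub_block_channels(sub_block_channels, out_channels, num_layers):
  # Mirrors the new logic above: explicit per-layer widths must match
  # num_layers; otherwise every sub-block keeps the block's out_channels.
  if sub_block_channels:
    assert len(sub_block_channels) == num_layers
    return list(sub_block_channels)
  return [out_channels] * num_layers

print(resolve_sub_block_channels(None, 320, 3))             # [320, 320, 320]
print(resolve_sub_block_channels((640, 480, 320), 320, 3))  # [640, 480, 320]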
ai_edge_torch/generative/layers/unet/model_config.py CHANGED
@@ -50,10 +50,12 @@ class ResidualBlock2DConfig:
   in_channels: int
   hidden_channels: int
   out_channels: int
+  hidden_channels: int
   normalization_config: layers_cfg.NormalizationConfig
   activation_config: layers_cfg.ActivationConfig
   # Optional time embedding channels if the residual block takes a time embedding context as input
   time_embedding_channels: Optional[int] = None
+  residual_out_channels: Optional[int] = None
 
 
 @dataclasses.dataclass
@@ -63,6 +65,7 @@ class AttentionBlock2DConfig:
   attention_config: layers_cfg.AttentionConfig
   enable_hlfb: bool = True
   attention_batch_size: int = 1
+  hidden_dim: Optional[int] = None
 
 
 @dataclasses.dataclass
@@ -101,6 +104,8 @@ class UpDecoderBlock2DConfig:
   normalization_config: layers_cfg.NormalizationConfig
   activation_config: layers_cfg.ActivationConfig
   num_layers: int
+  # The dimension of output channels of previous connected block
+  prev_out_channels: Optional[int] = None
   # Optional time embedding channels if the residual blocks take a time embedding as input
   time_embedding_channels: Optional[int] = None
   # Whether to add upsample operation after residual blocks
@@ -136,6 +141,8 @@ class SkipUpDecoderBlock2DConfig:
   transformer_block_config: Optional[TransformerBlock2DConfig] = None
   # Optional dimension of context tensor if context tensor is given as input.
   context_dim: Optional[int] = None
+  sub_block_channels: Optional[tuple] = None
+  hidden_channels: Optional[int] = None
 
 
 @dataclasses.dataclass
@@ -157,6 +164,7 @@ class DownEncoderBlock2DConfig:
   transformer_block_config: Optional[TransformerBlock2DConfig] = None
   # Optional dimension of context tensor if context tensor is given as input.
   context_dim: Optional[int] = None
+  hidden_channels: Optional[int] = None
 
 
 @dataclasses.dataclass
ai_edge_torch/generative/utilities/loader.py CHANGED
@@ -157,6 +157,10 @@ class ModelLoader:
       converted_state["tok_embedding.weight"] = state.pop(
           f"{self._names.embedding}.weight"
       )
+      if model.config.embedding_use_bias:
+        converted_state["tok_embedding.bias"] = state.pop(
+            f"{self._names.embedding}.bias"
+        )
       if self._names.embedding_position is not None:
         converted_state["tok_embedding_position"] = state.pop(
             f"{self._names.embedding_position}"
ai_edge_torch/generative/utilities/verifier.py CHANGED
@@ -143,7 +143,7 @@ def verify_with_input_ids(
     kv_cache_max_len: int = 1024,
     rtol: float = 1e-05,
     atol: float = 1e-05,
-) -> bool:
+):
   """Verifies if the reauthored model generates the same output as the original.
 
   It compares only one output from the original and the reauthored model.
@@ -157,8 +157,9 @@ def verify_with_input_ids(
     rtol (float): The relative tolerance for the comparison.
     atol (float): The absolute tolerance for the comparison.
 
-  Returns:
-    True if the model reauthored generates the same output of the original.
+  Raises:
+    AssertionError if the reauthored model fails to generate the same output as
+    the original.
   """
   tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.int, device="cpu")
   tokens[0, : len(input_ids)] = torch.tensor([input_ids]).int()
@@ -173,7 +174,7 @@ def verify_with_input_ids(
   logits_reauthored = outputs_reauthored[0, len(input_ids) - 1, :]
   logging.info("logits_reauthored: %s", logits_reauthored)
 
-  return torch.allclose(
+  assert torch.allclose(
       logits_original, logits_reauthored, rtol=rtol, atol=atol
   )
 
@@ -184,7 +185,7 @@ def verify_model_with_prompts(
     tokenizer: TokenizerWrapper,
     prompts: str,
     max_new_tokens: int,
-) -> bool:
+):
   """Verifies if the reauthored model generates the same answer as the original.
 
   It compares an answer, i.e. multiple continuous outputs generated by the
@@ -198,8 +199,9 @@ def verify_model_with_prompts(
     prompts (str): The input prompts to generate answers.
     max_new_tokens (int): The maximum number of new tokens to generate.
 
-  Returns:
-    True if the model reauthored generates the same answer of the original.
+  Raises:
+    AssertionError if the reauthored model fails to generate the same answer as
+    the original.
   """
   prompt_tokens = tokenizer.encode(prompts)
 
@@ -213,7 +215,7 @@ def verify_model_with_prompts(
   response_reauthored = tokenizer.decode(outputs_reauthored[0])
   logging.info("outputs from reauthored model: [[%s]]", response_reauthored)
 
-  return response_original == response_reauthored
+  assert response_original == response_reauthored
 
 
 def verify_reauthored_model(
@@ -225,7 +227,8 @@ def verify_reauthored_model(
     forward_input_ids: List[List[int]] = [[1, 2, 3, 4]],
     rtol: float = 1e-05,
     atol: float = 1e-05,
-):
+    continue_on_failure: bool = False,
+) -> bool:
   """Verifies the reauthored model against the original model.
 
   It verifies the reauthored model with two methods:
@@ -234,7 +237,8 @@ def verify_reauthored_model(
   2. It compares the answer generated by the original and the reauthored model
      with a prompt.
 
-  It prints out "PASS" or "FAILED" to the console.
+  It prints out "PASS" or "FAILED" to the console. It returns True if all
+  verifications pass, False otherwise.
 
   Args:
     original_model (ModelWrapper): The original model.
@@ -247,21 +251,42 @@ def verify_reauthored_model(
       forward with.
     rtol (float): The relative tolerance for the comparison.
     atol (float): The absolute tolerance for the comparison.
+    continue_on_failure (bool): If True, it continues to verify the next prompt
+      or input IDs even if a previous one fails.
   """
+  failure_count = 0
+
   for input_ids in forward_input_ids:
     logging.info("Verifying the reauthored model with input IDs: %s", input_ids)
-    if verify_with_input_ids(
-        original_model, reauthored_model, input_ids, rtol=rtol, atol=atol
-    ):
-      logging.info("PASS")
+    try:
+      verify_with_input_ids(
+          original_model, reauthored_model, input_ids, rtol=rtol, atol=atol
+      )
+    except AssertionError as e:
+      logging.error("*** FAILED *** verify with input IDs: %s", input_ids)
+      failure_count += 1
+      if not continue_on_failure:
+        return False
     else:
-      logging.error("FAILED")
+      logging.info("*** PASSED *** verify with input IDs: %s", input_ids)
 
   for prompts in generate_prompts:
-    logging.info("Verifying the reauthored model with prompts:%s", prompts)
-    if verify_model_with_prompts(
-        original_model, reauthored_model, tokenizer, prompts, max_new_tokens
-    ):
-      logging.info("PASS")
+    logging.info("Verifying the reauthored model with prompts: %s", prompts)
+    try:
+      verify_model_with_prompts(
+          original_model, reauthored_model, tokenizer, prompts, max_new_tokens
+      )
+    except AssertionError as e:
+      logging.error("*** FAILED *** verify with prompts: %s", prompts)
+      failure_count += 1
+      if not continue_on_failure:
+        return False
    else:
-      logging.error("FAILED")
+      logging.info("*** PASSED *** verify with prompts: %s", prompts)
+
+  if failure_count == 0:
+    logging.info("*** PASSED *** verify_reauthored_model")
+    return True
+  else:
+    logging.error("*** FAILED *** verify_reauthored_model")
+    return False
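A hedged sketch of the new API surface; the wrapper and tokenizer objects stand in for the ones the verify scripts above construct:

from ai_edge_torch.generative.utilities import verifier

# original_wrapper, reauthored_model, and tokenizer are built as in the
# verify_decoder.py example above (placeholders here).
passed = verifier.verify_reauthored_model(
    original_model=original_wrapper,
    reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
    tokenizer=verifier.TokenizerWrapper(tokenizer),
    generate_prompts=["What is the meaning of life?"],
    max_new_tokens=30,
    atol=1e-04,
    continue_on_failure=True,  # new: report every failure, then return False
)
assert passed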
ai_edge_torch/version.py CHANGED
@@ -13,4 +13,4 @@
 # limitations under the License.
 # ==============================================================================
 
-__version__ = "0.3.0.dev20241108"
+__version__ = "0.3.0.dev20241114"
{ai_edge_torch_nightly-0.3.0.dev20241108.dist-info → ai_edge_torch_nightly-0.3.0.dev20241114.dist-info}/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: ai-edge-torch-nightly
-Version: 0.3.0.dev20241108
+Version: 0.3.0.dev20241114
 Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
 Home-page: https://github.com/google-ai-edge/ai-edge-torch
 Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
{ai_edge_torch_nightly-0.3.0.dev20241108.dist-info → ai_edge_torch_nightly-0.3.0.dev20241114.dist-info}/RECORD RENAMED
@@ -3,7 +3,7 @@ ai_edge_torch/config.py,sha256=FMWeCH2b7HYILBvaI1iZNnYCO4WAhDOwBZBmIE-xrF0,909
 ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
 ai_edge_torch/fx_pass_base.py,sha256=518ziQ0TUxqum2qZXqlD8qr65pHPh8ZNLnwFC6zvK3k,4253
 ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
-ai_edge_torch/version.py,sha256=sBOl0mgVPJtokiP8qTbTtY0R_qIaF0KNiALh7P3AJEk,706
+ai_edge_torch/version.py,sha256=kPX3PxeMGiRGAaIsITYzfPxu3_FKKLg91pGrqXOP6OY,706
 ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/_convert/conversion.py,sha256=HwzfRx_DX5TLtPqwEH1_NOm38_INvHzHl4_mX67KOdQ,5448
 ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -12,7 +12,7 @@ ai_edge_torch/_convert/signature.py,sha256=rGpBNss3Y9FCCCcdBwDo16KqavJi8N5P0M_6W
 ai_edge_torch/_convert/to_channel_last_io.py,sha256=_31phf7TYgZY2ftpNbrdlB1RhDium1lz_BXEQ6IsMFc,2893
 ai_edge_torch/_convert/fx_passes/__init__.py,sha256=NVe-eGcm7j8jZpP2pcMhC8j5dVjgR1pPzyXhHdvKH4E,1267
 ai_edge_torch/_convert/fx_passes/build_aten_composite_pass.py,sha256=doaww8KqrgRTD5LotBVAIRFsEqzPn9R5lcGehBJOczA,9098
-ai_edge_torch/_convert/fx_passes/build_interpolate_composite_pass.py,sha256=XCVqWg_ask0Kb64PED0ZGAODsUuIgfyO2ZJM6aK-TXI,4283
+ai_edge_torch/_convert/fx_passes/build_interpolate_composite_pass.py,sha256=m_yj66V11LmWCYgA7yLtr__cy14IbC5WEJe0BE0_IPE,4339
 ai_edge_torch/_convert/fx_passes/inject_mlir_debuginfo_pass.py,sha256=IlZuK42kfVcRqAWZp4j2k_81T2uWo9T2558U_GPJAlU,2327
 ai_edge_torch/_convert/fx_passes/remove_non_user_outputs_pass.py,sha256=f1IUVWyhioOClsMiZzLyynoW2R17U83vA-7Q-3pGPM4,2126
 ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/__init__.py,sha256=lxnoH-WGLeiQIF8XjMGodjiZEFTxucl7g05N7MR9OPk,796
@@ -50,16 +50,21 @@ ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py,sha256=RZDs6
 ai_edge_torch/generative/examples/gemma/gemma1.py,sha256=oSbysiPvwp5efMbNYZop3HrxDMGiD15Tmz-HiQuTr2E,3315
 ai_edge_torch/generative/examples/gemma/gemma2.py,sha256=RQFQDMEnIVp8PefcCTr7P0CvllKI7FVoIJLXbPLLIsc,9056
 ai_edge_torch/generative/examples/gemma/verify_gemma1.py,sha256=ip-Gmk4CI5f0GWSdAIdrectxQWJ0t328KCsA4nfHuGg,1736
-ai_edge_torch/generative/examples/gemma/verify_gemma2.py,sha256=K77k-JpdhIwm3tbBnzpw8HQsFRwAVyszxRo82fR6-q4,1762
-ai_edge_torch/generative/examples/gemma/verify_util.py,sha256=sqltZbnyKemNvKqqi9d09i74gP-PPQFodRYfDfnhycQ,4933
+ai_edge_torch/generative/examples/gemma/verify_gemma2.py,sha256=IoBhEMwH07-tFm5-U6F2hpCsI8xynglhq1x9tIOdaPQ,1322
+ai_edge_torch/generative/examples/gemma/verify_util.py,sha256=tR8RflXocDZqvuStyw9aFlzuiTllEC8rNnjrxms6_Is,5727
 ai_edge_torch/generative/examples/llama/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/llama/convert_to_tflite.py,sha256=P0-pByTM5tslE23ILgo7nd0nOGE25ciBRG5wKJj0bBk,2411
 ai_edge_torch/generative/examples/llama/llama.py,sha256=AMcCbuDBxEfbO-l3KiEXbUaXEJ3RLLwkHii7to7UhVo,6854
 ai_edge_torch/generative/examples/llama/verify.py,sha256=X7oKQi85M789ugBrOlMvzk8eSRR3Kf1Mprfl-U-WIpo,2842
 ai_edge_torch/generative/examples/openelm/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/openelm/convert_to_tflite.py,sha256=85FVEt6cKFP2UzCLC78tAkbwGlGhAArtG7Wa75NxJik,2185
-ai_edge_torch/generative/examples/openelm/openelm.py,sha256=JsrtuUY4q1Rovxsht2cGCuANUj1sUKnah6bAoSe8AoU,4387
+ai_edge_torch/generative/examples/openelm/openelm.py,sha256=sFakstoPDcOHSak0IGFEEq_HQMBBSMcx-WVCDZqcVDo,4411
 ai_edge_torch/generative/examples/openelm/verify.py,sha256=VkigoqhAr8ew95neb3TifYv-SLOSheaWKv2AH0iKDrc,2441
+ai_edge_torch/generative/examples/paligemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
+ai_edge_torch/generative/examples/paligemma/decoder.py,sha256=JSb9h3gcIh5oYrbLU6rI8OU8FzfWeTCFJT5XRWu4btE,3675
+ai_edge_torch/generative/examples/paligemma/image_encoder.py,sha256=v19_EKALhAP9FjkINKqpv8JsVaQ6iH_7X5FpnhE6abw,5500
+ai_edge_torch/generative/examples/paligemma/verify_decoder.py,sha256=al5wMPWri4IRVWrLmCplPi6uoCzwh0vBHMGnCt-XUqo,2690
+ai_edge_torch/generative/examples/paligemma/verify_image_encoder.py,sha256=pSekf1BybhieQz3cQx_llbRQHxczXbTqool8fOyGj_0,3114
 ai_edge_torch/generative/examples/phi/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py,sha256=rkbTtMaqSVG48cm-NTxR_LDgZmXAEBqayTm9O49oMXc,2171
 ai_edge_torch/generative/examples/phi/convert_to_tflite.py,sha256=3go690yX6PFeXMdpY7y4JZorAwxX0HT_b_pKZieauvk,2169
@@ -105,19 +110,19 @@ ai_edge_torch/generative/examples/tiny_llama/verify.py,sha256=7Bk8z033M-BCXJ299f
 ai_edge_torch/generative/fx_passes/__init__.py,sha256=jrzCB3ZyY_t5jJM1e2Czdt3DjAIL43R0_a-T-I7wOzw,1155
 ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py,sha256=hhxSQvkDMv0isZJhmuLiod66ZODaJ8uSPSVTJVHBabQ,1931
 ai_edge_torch/generative/layers/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/layers/attention.py,sha256=Z0Y_G8IG0LmvLX2u9D8__Fkr22szB-az6wMNnZpzhkA,13233
+ai_edge_torch/generative/layers/attention.py,sha256=zN3BQjA25Ej_aRU0rFnyx--K74xf5ykc02zGvUpYHeE,13295
 ai_edge_torch/generative/layers/attention_utils.py,sha256=zBVwlBUTs-nStIKCZG0ks5ra7tsqc9ShfakFJKH5rds,7344
 ai_edge_torch/generative/layers/builder.py,sha256=Z5LyzCEThgnYZeyViakaE3yJVzTGHtw13acHsAQR15U,5050
 ai_edge_torch/generative/layers/feed_forward.py,sha256=hdICat-8gW7-vxDAevJQ8NQ-mynllPiqLdXQMF6JMnc,4189
 ai_edge_torch/generative/layers/kv_cache.py,sha256=lbm-yJ1jGPtcgWS4C3FmSnB1IlxqDE7g0BLRh3PN4N4,6324
-ai_edge_torch/generative/layers/model_config.py,sha256=DdsdhTP5tZAtyWim-qW2m8HDBsYbs7boqSDb83vwgmE,6998
+ai_edge_torch/generative/layers/model_config.py,sha256=xqa7ZBEjgK4UWJAThRXb_VBFZ5KCGtDu-QaY5GXar9s,7366
 ai_edge_torch/generative/layers/normalization.py,sha256=eKAGst9rPuyRFExMcQFJO7R3iHdCtlmjeF_lITjLhwE,6498
 ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=CZqOoibLcHvUgrgaIIWAlmk3XgE2inzx340MN-npLoU,1347
 ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=gXxh3papKy4FBpGEX7VyZ7rZ1Js6aHK70Q6DKrVSckY,4154
 ai_edge_torch/generative/layers/unet/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/layers/unet/blocks_2d.py,sha256=JwndhL3Z31TvkdGlAoTL5PQzmKfHdRWaaE1EbaMI4Gs,27540
+ai_edge_torch/generative/layers/unet/blocks_2d.py,sha256=9jKzOfiBQ66bp1ZnVIAoREIifVNFx4aTlQeYMAx2_pA,29062
 ai_edge_torch/generative/layers/unet/builder.py,sha256=zAqWXdimmMrQRhmE_t9XkS68mh6PSrzwb-2NZZXrR5I,1901
-ai_edge_torch/generative/layers/unet/model_config.py,sha256=raYm8Ol-EFi0zs5vNqmj2ZJCFsnQW2TfwhgDcClfwFA,9356
+ai_edge_torch/generative/layers/unet/model_config.py,sha256=pPDwLawc23pfMaPVyMJlYmxVVusjMvx-l8wBwOYOH-c,9692
 ai_edge_torch/generative/quantize/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/quantize/example.py,sha256=1lfVNUd2cEyRUnoZ7BLbRJ9IN-FTKiWBtZNPFUzAiWE,1747
 ai_edge_torch/generative/quantize/quant_attrs.py,sha256=n1Fm8BFC8gJa_oiwwAOOghJyHtOXYZ4q-5ZRy4pHrIw,1957
@@ -134,12 +139,12 @@ ai_edge_torch/generative/test/test_quantize.py,sha256=8geJhKwYBU20m0mdGPD1BUFwQ0
 ai_edge_torch/generative/test/utils.py,sha256=YvEhO2HIj1LkBs5du1UxY-cGRW9HMyAYsOUhgsTrTpA,1796
 ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
 ai_edge_torch/generative/utilities/converter.py,sha256=17O83wVifH1vQJCI4WC3DaNiCIOyK2gys1GzohbLrRs,5554
-ai_edge_torch/generative/utilities/loader.py,sha256=b9iotIhVDX-Zc9XjIDUaLxnV395AyBnkQe3dV5YA7Co,13297
+ai_edge_torch/generative/utilities/loader.py,sha256=k5fjCokNomte4ymy9IJrEWAuCSMhsPCJfmv1y5s0ZEc,13452
 ai_edge_torch/generative/utilities/model_builder.py,sha256=89jt80UUfDzYBi-x077HBavWeuNJuYPXym9fiKCY1Tk,5278
 ai_edge_torch/generative/utilities/stable_diffusion_loader.py,sha256=dqPD9qRXEWtU3ombslOC-BE2l_dMwHoCNu7NsIJhsso,36158
 ai_edge_torch/generative/utilities/t5_loader.py,sha256=tEsfy8-ymzbbjOIc-oesXF3yGyyWtJgFXn2s7VOavt8,16961
 ai_edge_torch/generative/utilities/transformers_verifier.py,sha256=8sp9m_FMcXn7nqOrochtu2jIANkJKhnhIBUmH0ZTDR4,1549
-ai_edge_torch/generative/utilities/verifier.py,sha256=wQ4EtIED_a6FRsaOXeoQVZiHNx07esOYCQYbDVLgZ2o,9520
+ai_edge_torch/generative/utilities/verifier.py,sha256=h5hGyIpYGyPZwvelbzpdkjy99Kpd4JkvhqWtQN9cm-M,10413
 ai_edge_torch/hlfb/__init__.py,sha256=sH4um75na-O8tzxN6chFyp6Y4xnexsE7kUQpZySv6dE,735
 ai_edge_torch/hlfb/mark_pattern/__init__.py,sha256=cjTprggj_cuktSCm7-A25e7Shop3k63ylp7sdZmtZ8o,4790
 ai_edge_torch/hlfb/mark_pattern/passes.py,sha256=pjkKcI1nHECPluAt87cFBrt1DP0f3ge7rHq1NhCkBIE,1936
@@ -186,8 +191,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
 ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
 ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
-ai_edge_torch_nightly-0.3.0.dev20241108.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
-ai_edge_torch_nightly-0.3.0.dev20241108.dist-info/METADATA,sha256=gp2VN_X4YPdK8axZYIhqafgiJhCwfiN_tOWT-yL3lW0,1897
-ai_edge_torch_nightly-0.3.0.dev20241108.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
-ai_edge_torch_nightly-0.3.0.dev20241108.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
-ai_edge_torch_nightly-0.3.0.dev20241108.dist-info/RECORD,,
+ai_edge_torch_nightly-0.3.0.dev20241114.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
+ai_edge_torch_nightly-0.3.0.dev20241114.dist-info/METADATA,sha256=LApNIpUONX7p4MYLOsx9QbWyD31PtVQDgQShvzqW-ec,1897
+ai_edge_torch_nightly-0.3.0.dev20241114.dist-info/WHEEL,sha256=bFJAMchF8aTQGUgMZzHJyDDMPTO3ToJ7x23SLJa1SVo,92
+ai_edge_torch_nightly-0.3.0.dev20241114.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
+ai_edge_torch_nightly-0.3.0.dev20241114.dist-info/RECORD,,
{ai_edge_torch_nightly-0.3.0.dev20241108.dist-info → ai_edge_torch_nightly-0.3.0.dev20241114.dist-info}/WHEEL RENAMED
@@ -1,5 +1,5 @@
 Wheel-Version: 1.0
-Generator: bdist_wheel (0.44.0)
+Generator: bdist_wheel (0.45.0)
 Root-Is-Purelib: true
 Tag: py3-none-any
 