ai-edge-torch-nightly 0.3.0.dev20241108__py3-none-any.whl → 0.3.0.dev20241114__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_torch/_convert/fx_passes/build_interpolate_composite_pass.py +2 -0
- ai_edge_torch/generative/examples/gemma/verify_gemma2.py +1 -14
- ai_edge_torch/generative/examples/gemma/verify_util.py +28 -3
- ai_edge_torch/generative/examples/openelm/openelm.py +1 -0
- ai_edge_torch/generative/examples/paligemma/__init__.py +14 -0
- ai_edge_torch/generative/examples/paligemma/decoder.py +103 -0
- ai_edge_torch/generative/examples/paligemma/image_encoder.py +158 -0
- ai_edge_torch/generative/examples/paligemma/verify_decoder.py +75 -0
- ai_edge_torch/generative/examples/paligemma/verify_image_encoder.py +82 -0
- ai_edge_torch/generative/layers/attention.py +8 -6
- ai_edge_torch/generative/layers/model_config.py +14 -0
- ai_edge_torch/generative/layers/unet/blocks_2d.py +53 -17
- ai_edge_torch/generative/layers/unet/model_config.py +8 -0
- ai_edge_torch/generative/utilities/loader.py +4 -0
- ai_edge_torch/generative/utilities/verifier.py +46 -21
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20241108.dist-info → ai_edge_torch_nightly-0.3.0.dev20241114.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20241108.dist-info → ai_edge_torch_nightly-0.3.0.dev20241114.dist-info}/RECORD +21 -16
- {ai_edge_torch_nightly-0.3.0.dev20241108.dist-info → ai_edge_torch_nightly-0.3.0.dev20241114.dist-info}/WHEEL +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20241108.dist-info → ai_edge_torch_nightly-0.3.0.dev20241114.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20241108.dist-info → ai_edge_torch_nightly-0.3.0.dev20241114.dist-info}/top_level.txt +0 -0
ai_edge_torch/_convert/fx_passes/build_interpolate_composite_pass.py CHANGED
@@ -51,6 +51,7 @@ def _get_upsample_bilinear2d_pattern():
     return {
         "output": (int(output_h), int(output_w)),
         "align_corners": False,
+        "is_nchw_op": True,
     }
 
   return pattern
@@ -74,6 +75,7 @@ def _get_upsample_bilinear2d_align_corners_pattern():
     return {
         "output": (int(output_h), int(output_w)),
         "align_corners": True,
+        "is_nchw_op": True,
     }
 
   return pattern
ai_edge_torch/generative/examples/gemma/verify_gemma2.py CHANGED
@@ -15,10 +15,8 @@
 
 """Verifies the reauthored Gemma2 model."""
 
-import logging
 from absl import app
 from absl import flags
-from ai_edge_torch.generative.examples.gemma import gemma2
 from ai_edge_torch.generative.examples.gemma import verify_util
 import kagglehub
 
@@ -38,18 +36,7 @@ _MAX_NEW_TOKENS = flags.DEFINE_integer(
 def main(_):
   checkpoint = kagglehub.model_download("google/gemma-2/pyTorch/gemma-2-2b-it")
 
-
-  reauthored_model = gemma2.build_2b_model(checkpoint)
-
-  verify_util.verify_reauthored_gemma_model(
-      checkpoint=checkpoint,
-      variant="2b-v2",
-      reauthored_model=reauthored_model,
-      generate_prompts=_PROMPTS.value,
-      forward_input_ids=[[2, 651, 9456, 576, 573, 3520, 3858, 603, 235248]],
-      max_new_tokens=_MAX_NEW_TOKENS.value,
-      atol=1e-04,
-  )
+  verify_util.verify_gemma2(checkpoint, _PROMPTS.value, _MAX_NEW_TOKENS.value)
 
 
 if __name__ == "__main__":
ai_edge_torch/generative/examples/gemma/verify_util.py CHANGED
@@ -19,6 +19,7 @@ import logging
 import os
 from typing import List, Tuple
 
+from ai_edge_torch.generative.examples.gemma import gemma2
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
 from ai_edge_torch.generative.utilities import verifier
 from gemma import config as gemma_config
@@ -109,8 +110,11 @@ def verify_reauthored_gemma_model(
     max_new_tokens: int = 20,
     rtol: float = 1e-05,
     atol: float = 1e-05,
-):
-  """Verifies the reauthored Gemma model against the original model."""
+) -> bool:
+  """Verifies the reauthored Gemma model against the original model.
+
+  Returns True if the verification passes, False otherwise.
+  """
   config = gemma_config.get_model_config(variant)
   config.tokenizer = os.path.join(checkpoint, tokenizer_filename)
   # Use float32 to be compatible with the reauthored model.
@@ -120,7 +124,7 @@ def verify_reauthored_gemma_model(
   original_model = gemma_model.GemmaForCausalLM(config).eval()
   original_model.load_weights(os.path.join(checkpoint, weight_filename))
 
-  verifier.verify_reauthored_model(
+  return verifier.verify_reauthored_model(
       original_model=GemmaWrapper(original_model),
       reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
       tokenizer=GemmaTokenizerWrapper(original_model.tokenizer),
@@ -130,3 +134,24 @@ def verify_reauthored_gemma_model(
       rtol=rtol,
       atol=atol,
   )
+
+
+def verify_gemma2(
+    gemma2_model_path: str, prompts: List[str], max_new_tokens: int
+) -> bool:
+  """Verifies the reauthored Gemma2 model.
+
+  Return True if the verification passes, False otherwise.
+  """
+  logging.info("Building the reauthored model from: %s", gemma2_model_path)
+  reauthored_model = gemma2.build_2b_model(gemma2_model_path)
+
+  return verify_reauthored_gemma_model(
+      checkpoint=gemma2_model_path,
+      variant="2b-v2",
+      reauthored_model=reauthored_model,
+      generate_prompts=prompts,
+      forward_input_ids=[[2, 651, 9456, 576, 573, 3520, 3858, 603, 235248]],
+      max_new_tokens=max_new_tokens,
+      atol=1e-04,
+  )
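The net effect of this pair of changes is that the Gemma2-specific wiring (variant string, golden `forward_input_ids`, tolerance) moves out of the example script into `verify_util.verify_gemma2`, which reports success as a bool. A minimal sketch of driving the new helper directly; the checkpoint path is a hypothetical local download, not from the diff:

```python
# Sketch only: assumes a local copy of google/gemma-2/pyTorch/gemma-2-2b-it.
from ai_edge_torch.generative.examples.gemma import verify_util

ok = verify_util.verify_gemma2(
    "/tmp/gemma-2-2b-it",              # hypothetical checkpoint dir
    ["What is the meaning of life?"],  # prompts
    30,                                # max_new_tokens
)
print("PASSED" if ok else "FAILED")
```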
ai_edge_torch/generative/examples/paligemma/__init__.py ADDED
@@ -0,0 +1,14 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
ai_edge_torch/generative/examples/paligemma/decoder.py ADDED
@@ -0,0 +1,103 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Example of building a decoder of PaliGemma 3B model which is Gemma1."""
+
+import ai_edge_torch.generative.layers.model_config as cfg
+from ai_edge_torch.generative.utilities import model_builder
+import ai_edge_torch.generative.utilities.loader as loading_utils
+
+TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
+    ff_up_proj="language_model.model.layers.{}.mlp.up_proj",
+    ff_down_proj="language_model.model.layers.{}.mlp.down_proj",
+    ff_gate_proj="language_model.model.layers.{}.mlp.gate_proj",
+    attn_query_proj="language_model.model.layers.{}.self_attn.q_proj",
+    attn_key_proj="language_model.model.layers.{}.self_attn.k_proj",
+    attn_value_proj="language_model.model.layers.{}.self_attn.v_proj",
+    attn_output_proj="language_model.model.layers.{}.self_attn.o_proj",
+    pre_attn_norm="language_model.model.layers.{}.input_layernorm",
+    post_attn_norm="language_model.model.layers.{}.post_attention_layernorm",
+    embedding="language_model.model.embed_tokens",
+    final_norm="language_model.model.norm",
+    lm_head=None,
+)
+
+
+def get_decoder_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
+  """Returns the model config for the decoder of a PaliGemma 3B model.
+
+  Args:
+    kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
+      is 1024.
+
+  Returns:
+    The model config for the decoder of a PaliGemma 3B model.
+  """
+  attn_config = cfg.AttentionConfig(
+      num_heads=8,
+      head_dim=256,
+      num_query_groups=1,
+      rotary_base=10000,
+      rotary_percentage=1.0,
+  )
+  ff_config = cfg.FeedForwardConfig(
+      type=cfg.FeedForwardType.GATED,
+      activation=cfg.ActivationConfig(cfg.ActivationType.GELU_TANH),
+      intermediate_size=16384,
+  )
+  norm_config = cfg.NormalizationConfig(
+      type=cfg.NormalizationType.RMS_NORM,
+      epsilon=1e-6,
+      zero_centered=True,
+  )
+  block_config = cfg.TransformerBlockConfig(
+      attn_config=attn_config,
+      ff_config=ff_config,
+      pre_attention_norm_config=norm_config,
+      post_attention_norm_config=norm_config,
+  )
+  config = cfg.ModelConfig(
+      vocab_size=257216,
+      num_layers=18,
+      max_seq_len=8192,
+      embedding_dim=2048,
+      embedding_scale=2048**0.5,
+      kv_cache_max_len=kv_cache_max_len,
+      block_configs=block_config,
+      final_norm_config=norm_config,
+      lm_head_use_bias=False,
+      enable_hlfb=True,
+  )
+  return config
+
+
+def get_fake_decoder_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
+  config = get_decoder_config(kv_cache_max_len)
+  # PaliGemma decoder has only one block config.
+  config.block_config(0).ff_config.intermediate_size = 128
+  config.vocab_size = 128
+  config.num_layers = 2
+  config.max_seq_len = 2 * kv_cache_max_len
+  return config
+
+
+def build_decoder(
+    checkpoint_path: str, **kwargs
+) -> model_builder.DecoderOnlyModel:
+  return model_builder.build_decoder_only_model(
+      checkpoint_path=checkpoint_path,
+      config=get_decoder_config(**kwargs),
+      tensor_names=TENSOR_NAMES,
+  )
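Since the decoder is plain Gemma 1 under PaliGemma's tensor naming, its configs can be smoke-tested without any checkpoint. A sketch assuming only the module above:

```python
# Sketch: inspecting the new PaliGemma decoder configs without weights.
from ai_edge_torch.generative.examples.paligemma import decoder

config = decoder.get_decoder_config(kv_cache_max_len=1024)
assert config.vocab_size == 257216  # PaliGemma's enlarged Gemma vocabulary
assert config.block_config(0).ff_config.intermediate_size == 16384

fake = decoder.get_fake_decoder_config()  # tiny variant for unit tests
assert fake.num_layers == 2 and fake.vocab_size == 128
```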
ai_edge_torch/generative/examples/paligemma/image_encoder.py ADDED
@@ -0,0 +1,158 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Example of building an image encoder of PaliGemma model which is Siglip."""
+
+from ai_edge_torch.generative.layers import attention
+from ai_edge_torch.generative.layers import builder
+import ai_edge_torch.generative.layers.model_config as cfg
+import ai_edge_torch.generative.utilities.loader as loading_utils
+import torch
+from torch import nn
+
+TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
+    ff_up_proj="vision_tower.vision_model.encoder.layers.{}.mlp.fc1",
+    ff_down_proj="vision_tower.vision_model.encoder.layers.{}.mlp.fc2",
+    attn_query_proj=(
+        "vision_tower.vision_model.encoder.layers.{}.self_attn.q_proj"
+    ),
+    attn_key_proj=(
+        "vision_tower.vision_model.encoder.layers.{}.self_attn.k_proj"
+    ),
+    attn_value_proj=(
+        "vision_tower.vision_model.encoder.layers.{}.self_attn.v_proj"
+    ),
+    attn_output_proj=(
+        "vision_tower.vision_model.encoder.layers.{}.self_attn.out_proj"
+    ),
+    pre_attn_norm="vision_tower.vision_model.encoder.layers.{}.layer_norm1",
+    post_attn_norm="vision_tower.vision_model.encoder.layers.{}.layer_norm2",
+    embedding="vision_tower.vision_model.embeddings.patch_embedding",
+    embedding_position=(
+        "vision_tower.vision_model.embeddings.position_embedding.weight"
+    ),
+    final_norm="vision_tower.vision_model.post_layernorm",
+)
+
+
+class SiglipVisionEncoder(nn.Module):
+  """Signlip vision encoder from the Edge Generative API."""
+
+  def __init__(self, config: cfg.ModelConfig):
+    super().__init__()
+
+    # Construct model layers.
+    self.tok_embedding = nn.Conv2d(
+        in_channels=config.image_embedding.channels,
+        out_channels=config.embedding_dim,
+        kernel_size=config.image_embedding.patch_size,
+        stride=config.image_embedding.patch_size,
+        padding="valid",
+    )
+    num_patches = (
+        config.image_embedding.image_size // config.image_embedding.patch_size
+    ) ** 2
+    self.tok_embedding_position = nn.Parameter(
+        torch.zeros((num_patches, config.embedding_dim))
+    )
+
+    self.transformer_blocks = nn.ModuleList(
+        attention.TransformerBlock(config.block_config(idx), config)
+        for idx in range(config.num_layers)
+    )
+    self.final_norm = builder.build_norm(
+        config.embedding_dim,
+        config.final_norm_config,
+    )
+    self.config = config
+
+  @torch.inference_mode
+  def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
+    # Embed the image according to SiplipVisionEmbeddings.
+    x = self.tok_embedding(pixel_values)
+    x = x.flatten(2).transpose(1, 2) + self.tok_embedding_position
+
+    # Pass a dummy mask because SDPA attention impl expects non-None mask.
+    mask = torch.zeros(x.shape[:2])
+    for _, block in enumerate(self.transformer_blocks):
+      x = block(x, mask=mask)
+    return self.final_norm(x)
+
+
+def get_image_encoder_config() -> cfg.ModelConfig:
+  """Returns the model config for the image encoder of a PaliGemma 3B-224 model.
+
+  Returns:
+    The model config for the image encoder of a PaliGemma 3B model.
+  """
+  image_embedding_config = cfg.ImageEmbeddingConfig(
+      channels=3,
+      image_size=224,
+      patch_size=14,
+  )
+  attn_config = cfg.AttentionConfig(
+      num_heads=16,
+      head_dim=72,
+      num_query_groups=16,
+      qkv_use_bias=True,
+      output_proj_use_bias=True,
+  )
+  ff_config = cfg.FeedForwardConfig(
+      type=cfg.FeedForwardType.SEQUENTIAL,
+      activation=cfg.ActivationConfig(cfg.ActivationType.GELU_TANH),
+      intermediate_size=4304,
+      use_bias=True,
+  )
+  norm_config = cfg.NormalizationConfig(
+      type=cfg.NormalizationType.LAYER_NORM,
+      epsilon=1e-6,
+      enable_hlfb=True,
+  )
+  block_config = cfg.TransformerBlockConfig(
+      attn_config=attn_config,
+      ff_config=ff_config,
+      pre_attention_norm_config=norm_config,
+      post_attention_norm_config=norm_config,
+  )
+  config = cfg.ModelConfig(
+      vocab_size=0,  # Not used in image encoder.
+      num_layers=27,
+      max_seq_len=0,  # Not used in image encoder.
+      embedding_dim=1152,
+      embedding_use_bias=True,
+      image_embedding=image_embedding_config,
+      block_configs=block_config,
+      final_norm_config=norm_config,
+      enable_hlfb=True,
+  )
+  return config
+
+
+def get_fake_image_encoder_config() -> cfg.ModelConfig:
+  config = get_image_encoder_config()
+  # PaliGemma image encoder has only one block config.
+  config.block_config(0).ff_config.intermediate_size = 128
+  config.num_layers = 2
+  return config
+
+
+def build_image_encoder(checkpoint_path: str) -> SiglipVisionEncoder:
+  config = get_image_encoder_config()
+  encoder = SiglipVisionEncoder(config)
+  loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
+  # Loose the strictness because only image encoder is being loaded.
+  loader.load(encoder, strict=False)
+  encoder.eval()
+  return encoder
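A shape-level sanity check of the encoder should be possible with random weights through the fake config, since the patch geometry (224x224 images, 14x14 patches) is unchanged there. A sketch under that assumption:

```python
# Sketch: shape check of SiglipVisionEncoder using random weights.
import torch
from ai_edge_torch.generative.examples.paligemma import image_encoder

config = image_encoder.get_fake_image_encoder_config()
encoder = image_encoder.SiglipVisionEncoder(config).eval()
pixels = torch.randn(1, 3, 224, 224)  # (batch, channels, height, width)
out = encoder(pixels)
# (224 / 14) ** 2 = 256 patches, each embedded to embedding_dim.
assert out.shape == (1, 256, config.embedding_dim)
```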
ai_edge_torch/generative/examples/paligemma/verify_decoder.py ADDED
@@ -0,0 +1,75 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Verifies the reauthored decoder of PaliGemma 3B model."""
+
+import logging
+import pathlib
+
+from absl import app
+from absl import flags
+from ai_edge_torch.generative.examples.paligemma import decoder
+from ai_edge_torch.generative.utilities import transformers_verifier
+from ai_edge_torch.generative.utilities import verifier
+import transformers
+
+_PROMPTS = flags.DEFINE_multi_string(
+    "prompts",
+    "What is the meaning of life?",
+    "The input prompts to generate answers.",
+)
+_MAX_NEW_TOKENS = flags.DEFINE_integer(
+    "max_new_tokens",
+    30,
+    "The maximum size of the generated tokens.",
+)
+
+
+def main(_):
+  checkpoint = "google/paligemma-3b-mix-224"
+  logging.info("Loading the original model from: %s", checkpoint)
+  original_full_model = (
+      transformers.PaliGemmaForConditionalGeneration.from_pretrained(checkpoint)
+  )
+  original_language_model = original_full_model.eval().language_model
+
+  # Locate the cached dir.
+  cached_config_file = transformers.utils.cached_file(
+      checkpoint, transformers.utils.CONFIG_NAME
+  )
+  reauthored_checkpoint = pathlib.Path(cached_config_file).parent
+  logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
+  reauthored_model = decoder.build_decoder(reauthored_checkpoint)
+
+  logging.info("Loading the tokenizer from: %s", checkpoint)
+  # It works only when GemmaTokenizerFast is available. In some environments,
+  # use_fast=False doeesn't work either if the tokenizer cannot load the
+  # sentencepiece model file properly.
+  processor = transformers.AutoProcessor.from_pretrained(checkpoint)
+
+  verifier.verify_reauthored_model(
+      original_model=transformers_verifier.TransformersModelWrapper(
+          original_language_model
+      ),
+      reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
+      tokenizer=verifier.TokenizerWrapper(processor.tokenizer),
+      generate_prompts=_PROMPTS.value,
+      max_new_tokens=_MAX_NEW_TOKENS.value,
+      atol=1e-04,
+  )
+
+
+if __name__ == "__main__":
+  app.run(main)
ai_edge_torch/generative/examples/paligemma/verify_image_encoder.py ADDED
@@ -0,0 +1,82 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Verifies the reauthored image encoder of PaliGemma 3B model."""
+
+import logging
+import pathlib
+from absl import app
+from absl import flags
+from ai_edge_torch.generative.examples.paligemma import image_encoder
+from PIL import Image
+import requests
+import torch
+import transformers
+
+_IMAGE_URL = flags.DEFINE_string(
+    "image_url",
+    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg?download=true",
+    "The image URI to encode.",
+)
+
+
+def main(_):
+  checkpoint = "google/paligemma-3b-mix-224"
+  logging.info("Loading the original model from: %s", checkpoint)
+  original_full_model = (
+      transformers.PaliGemmaForConditionalGeneration.from_pretrained(checkpoint)
+  )
+  original_vision_model = original_full_model.eval().vision_tower
+
+  # Locate the cached dir.
+  cached_config_file = transformers.utils.cached_file(
+      checkpoint, transformers.utils.CONFIG_NAME
+  )
+  reauthored_checkpoint = pathlib.Path(cached_config_file).parent
+  logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
+  reauthored_model = image_encoder.build_image_encoder(reauthored_checkpoint)
+
+  logging.info("Loading the processor from: %s", checkpoint)
+  # It works only when GemmaTokenizerFast is available. In some environments,
+  # use_fast=False doeesn't work either if the tokenizer cannot load the
+  # sentencepiece model file properly.
+  processor = transformers.AutoProcessor.from_pretrained(checkpoint)
+
+  logging.info("Loading the image from: %s", _IMAGE_URL.value)
+  image = Image.open(requests.get(_IMAGE_URL.value, stream=True).raw)
+  pixel_values = processor(images=image, return_tensors="pt")["pixel_values"]
+
+  logging.info("Forwarding the original model...")
+  outputs_original = original_vision_model.forward(pixel_values=pixel_values)
+  outputs_original = outputs_original.last_hidden_state
+  logging.info("outputs_original: %s", outputs_original)
+
+  logging.info("Forwarding the reauthored model...")
+  outputs_reauthored = reauthored_model.forward(pixel_values=pixel_values)
+  logging.info("outputs_reauthored: %s", outputs_reauthored)
+
+  try:
+    assert torch.allclose(
+        outputs_original, outputs_reauthored, atol=1e-04, rtol=1e-04
+    )
+  except AssertionError as e:
+    logging.error("*** FAILED *** verify with an image")
+    raise e
+  else:
+    logging.info("*** PASSED *** verify with an image")
+
+
+if __name__ == "__main__":
+  app.run(main)
ai_edge_torch/generative/layers/attention.py CHANGED
@@ -235,9 +235,10 @@ class CausalSelfAttention(nn.Module):
     k = k.reshape(B, T, -1, self.config.head_dim)
     v = v.reshape(B, T, -1, self.config.head_dim)
 
-    # Compute rotary positional embedding for query and key.
-    n_elem = int(self.config.rotary_percentage * self.config.head_dim)
-    q, k = _embed_rope(q, k, n_elem, rope)
+    if rope is not None:
+      # Compute rotary positional embedding for query and key.
+      n_elem = int(self.config.rotary_percentage * self.config.head_dim)
+      q, k = _embed_rope(q, k, n_elem, rope)
 
     if kv_cache is not None:
       kv_cache = kv_utils.update(
@@ -372,9 +373,10 @@ class CrossAttention(nn.Module):
     k = k.view(interim_shape)
     v = v.view(interim_shape)
 
-    # Compute rotary positional embedding for query and key.
-    n_elem = int(self.config.rotary_percentage * self.config.head_dim)
-    q, k = _embed_rope(q, k, n_elem, rope)
+    if rope is not None:
+      # Compute rotary positional embedding for query and key.
+      n_elem = int(self.config.rotary_percentage * self.config.head_dim)
+      q, k = _embed_rope(q, k, n_elem, rope)
 
     if kv_cache is not None:
       kv_cache = kv_utils.update(
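The guard makes `rope` a genuinely optional argument: an encoder-style caller such as the Siglip block loop above passes no rotary embedding, and q and k flow through untouched. An illustrative standalone helper mirroring that control flow (this is not the library's `_embed_rope`; the rotate-half math is a common RoPE formulation shown only for context):

```python
# Illustrative only: "apply RoPE only when rope is given", as in the diff above.
import torch

def apply_rope_if_given(q, k, rope):
  if rope is None:  # e.g. an image encoder with no rotary positions
    return q, k
  cos, sin = rope
  def rotate(x):
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)
  return q * cos + rotate(q) * sin, k * cos + rotate(k) * sin

q = k = torch.randn(1, 4, 8, 16)
assert apply_rope_if_given(q, k, None)[0] is q  # passthrough when rope is None
```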
ai_edge_torch/generative/layers/model_config.py CHANGED
@@ -163,6 +163,16 @@ class TransformerBlockConfig:
   relative_attention: bool = False
 
 
+@dataclass
+class ImageEmbeddingConfig:
+  """Image embedding parameters."""
+
+  channels: int
+  # All images should be normalized to the size of [image_size * image_size].
+  image_size: int
+  patch_size: int
+
+
 @dataclass
 class ModelConfig:
   """Base configurations for building a transformer architecture."""
@@ -183,6 +193,10 @@ class ModelConfig:
 
   # Scale factor of the embedding.
   embedding_scale: Optional[float] = None
+  # Use bias term within embedding.
+  embedding_use_bias: bool = False
+  # Image embedding parameters.
+  image_embedding: Optional[ImageEmbeddingConfig] = None
 
   # Use bias term within LLM's HEAD.
   lm_head_use_bias: bool = False
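The new fields are enough to derive the patch count that the image encoder needs for its position embedding. A short check using the PaliGemma values from `get_image_encoder_config` above:

```python
# Values copied from get_image_encoder_config in the diff above.
import ai_edge_torch.generative.layers.model_config as cfg

image_cfg = cfg.ImageEmbeddingConfig(channels=3, image_size=224, patch_size=14)
num_patches = (image_cfg.image_size // image_cfg.patch_size) ** 2
assert num_patches == 256  # a 16 x 16 patch grid
```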
ai_edge_torch/generative/layers/unet/blocks_2d.py CHANGED
@@ -115,12 +115,15 @@ class AttentionBlock2D(nn.Module):
     """
     super().__init__()
     self.config = config
+    hidden_dim = config.hidden_dim
+    if not hidden_dim:
+      hidden_dim = config.dim
     self.norm = layers_builder.build_norm(
-        config.dim, config.normalization_config
+        hidden_dim, config.normalization_config
     )
     self.attention = SelfAttention(
         config.attention_batch_size,
-        config.dim,
+        hidden_dim,
         config.attention_config,
         enable_hlfb=config.enable_hlfb,
     )
@@ -172,7 +175,7 @@ class CrossAttentionBlock2D(nn.Module):
     super().__init__()
     self.config = config
     self.norm = layers_builder.build_norm(
-        config.query_dim, config.normalization_config
+        config.output_dim, config.normalization_config
     )
     self.attention = CrossAttention(
         config.attention_batch_size,
@@ -294,21 +297,30 @@ class TransformerBlock2D(nn.Module):
       hidden_states
   """
 
-  def __init__(self, config: unet_cfg.TransformerBlock2DConfig):
+  def __init__(
+      self, config: unet_cfg.TransformerBlock2DConfig, dim_override=None
+  ):
     """Initialize an instance of the TransformerBlock2D.
 
     Args:
      config (unet_cfg.TransformerBlock2Dconfig): the configuration of this
        block.
+      dim_override: in case specified, overrides config.attention_block_config.hidden_dim. Set to None by default.
    """
    super().__init__()
    self.config = config
+    attention_block_config_dim = config.attention_block_config.dim
+    attention_block_config_hidden_dim = config.attention_block_config.hidden_dim
+    if dim_override:
+      attention_block_config_dim = dim_override
+    if not attention_block_config_hidden_dim:
+      attention_block_config_hidden_dim = attention_block_config_dim
    self.pre_conv_norm = layers_builder.build_norm(
-        config.attention_block_config.dim, config.pre_conv_normalization_config
+        attention_block_config_dim, config.pre_conv_normalization_config
    )
    self.conv_in = nn.Conv2d(
-        config.attention_block_config.dim,
-        config.attention_block_config.dim,
+        attention_block_config_dim,
+        attention_block_config_hidden_dim,
        kernel_size=1,
        padding=0,
    )
@@ -318,8 +330,8 @@ class TransformerBlock2D(nn.Module):
     )
     self.feed_forward = FeedForwardBlock2D(config.feed_forward_block_config)
     self.conv_out = nn.Conv2d(
-        config.attention_block_config.dim,
-        config.attention_block_config.dim,
+        attention_block_config_hidden_dim,
+        attention_block_config_dim,
         kernel_size=1,
         padding=0,
     )
@@ -385,14 +397,18 @@ class DownEncoderBlock2D(nn.Module):
     self.config = config
     resnets = []
     transformers = []
+    hidden_channels = config.hidden_channels
+    if not hidden_channels:
+      hidden_channels = config.out_channels
     for i in range(config.num_layers):
       input_channels = config.in_channels if i == 0 else config.out_channels
       resnets.append(
           ResidualBlock2D(
               unet_cfg.ResidualBlock2DConfig(
                   in_channels=input_channels,
-                  hidden_channels=config.out_channels,
+                  hidden_channels=hidden_channels,
                   out_channels=config.out_channels,
+                  residual_out_channels=config.out_channels,
                   time_embedding_channels=config.time_embedding_channels,
                   normalization_config=config.normalization_config,
                   activation_config=config.activation_config,
@@ -589,23 +605,37 @@ class SkipUpDecoderBlock2D(nn.Module):
     """
     super().__init__()
     self.config = config
+    hidden_channels = config.hidden_channels
+    if not hidden_channels:
+      hidden_channels = config.out_channels
+    sub_block_channels = config.sub_block_channels
+    if sub_block_channels:
+      assert len(sub_block_channels) == config.num_layers, (
+          "Assertion failed: The length of 'sub_block_channels'"
+          f" ({len(sub_block_channels)}) does not match 'config.num_layers'"
+          f" ({config.num_layers})."
+      )
+    else:
+      sub_block_channels = [config.out_channels] * config.num_layers
     resnets = []
     transformers = []
     for i in range(config.num_layers):
+      resnet_in_channels = (
+          config.prev_out_channels if i == 0 else sub_block_channels[i - 1]
+      )
       res_skip_channels = (
           config.in_channels
           if (i == config.num_layers - 1)
          else config.out_channels
       )
-      resnet_in_channels = (
-          config.prev_out_channels if i == 0 else config.out_channels
-      )
+      residual_out_channel = sub_block_channels[i]
      resnets.append(
          ResidualBlock2D(
              unet_cfg.ResidualBlock2DConfig(
                  in_channels=resnet_in_channels + res_skip_channels,
-                  hidden_channels=config.out_channels,
-                  out_channels=config.out_channels,
+                  hidden_channels=hidden_channels,
+                  out_channels=sub_block_channels[i],
+                  residual_out_channels=residual_out_channel,
                  time_embedding_channels=config.time_embedding_channels,
                  normalization_config=config.normalization_config,
                  activation_config=config.activation_config,
@@ -613,7 +643,12 @@ class SkipUpDecoderBlock2D(nn.Module):
         )
       )
       if config.transformer_block_config:
-        transformers.append(TransformerBlock2D(config.transformer_block_config))
+        transformers.append(
+            TransformerBlock2D(
+                config.transformer_block_config,
+                dim_override=sub_block_channels[i],
+            )
+        )
     self.resnets = nn.ModuleList(resnets)
     self.transformers = (
         nn.ModuleList(transformers) if len(transformers) > 0 else None
@@ -623,7 +658,7 @@ class SkipUpDecoderBlock2D(nn.Module):
     if config.upsample_conv:
       self.upsample_conv = nn.Conv2d(
           config.out_channels,
-          config.out_channels,
+          sub_block_channels[0],
           kernel_size=3,
           stride=1,
           padding=1,
@@ -711,6 +746,7 @@ class MidBlock2D(nn.Module):
             in_channels=config.in_channels,
             hidden_channels=config.in_channels,
             out_channels=config.in_channels,
+            residual_out_channels=config.in_channels,
             time_embedding_channels=config.time_embedding_channels,
             normalization_config=config.normalization_config,
             activation_config=config.activation_config,
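The `sub_block_channels` plumbing lets each resnet/transformer pair in `SkipUpDecoderBlock2D` produce its own channel count instead of a single `out_channels`. A worked example of the bookkeeping with illustrative numbers (not taken from the diff):

```python
# Illustrative channel bookkeeping for the new SkipUpDecoderBlock2D path.
num_layers = 3
in_channels, out_channels, prev_out_channels = 320, 640, 1280
sub_block_channels = (640, 640, 320)  # must have num_layers entries

for i in range(num_layers):
  resnet_in = prev_out_channels if i == 0 else sub_block_channels[i - 1]
  res_skip = in_channels if i == num_layers - 1 else out_channels
  print(f"layer {i}: {resnet_in + res_skip} -> {sub_block_channels[i]}")
# layer 0: 1920 -> 640
# layer 1: 1280 -> 640
# layer 2: 960 -> 320
```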
ai_edge_torch/generative/layers/unet/model_config.py CHANGED
@@ -50,10 +50,12 @@ class ResidualBlock2DConfig:
   in_channels: int
   hidden_channels: int
   out_channels: int
+  hidden_channels: int
   normalization_config: layers_cfg.NormalizationConfig
   activation_config: layers_cfg.ActivationConfig
   # Optional time embedding channels if the residual block takes a time embedding context as input
   time_embedding_channels: Optional[int] = None
+  residual_out_channels: Optional[int] = None
 
 
 @dataclasses.dataclass
@@ -63,6 +65,7 @@ class AttentionBlock2DConfig:
   attention_config: layers_cfg.AttentionConfig
   enable_hlfb: bool = True
   attention_batch_size: int = 1
+  hidden_dim: Optional[int] = None
 
 
 @dataclasses.dataclass
@@ -101,6 +104,8 @@ class UpDecoderBlock2DConfig:
   normalization_config: layers_cfg.NormalizationConfig
   activation_config: layers_cfg.ActivationConfig
   num_layers: int
+  # The dimension of output channels of previous connected block
+  prev_out_channels: Optional[int] = None
   # Optional time embedding channels if the residual blocks take a time embedding as input
   time_embedding_channels: Optional[int] = None
   # Whether to add upsample operation after residual blocks
@@ -136,6 +141,8 @@ class SkipUpDecoderBlock2DConfig:
   transformer_block_config: Optional[TransformerBlock2DConfig] = None
   # Optional dimension of context tensor if context tensor is given as input.
   context_dim: Optional[int] = None
+  sub_block_channels: Optional[tuple] = None
+  hidden_channels: Optional[int] = None
 
 
 @dataclasses.dataclass
@@ -157,6 +164,7 @@ class DownEncoderBlock2DConfig:
   transformer_block_config: Optional[TransformerBlock2DConfig] = None
   # Optional dimension of context tensor if context tensor is given as input.
   context_dim: Optional[int] = None
+  hidden_channels: Optional[int] = None
 
 
 @dataclasses.dataclass
ai_edge_torch/generative/utilities/loader.py CHANGED
@@ -157,6 +157,10 @@ class ModelLoader:
       converted_state["tok_embedding.weight"] = state.pop(
           f"{self._names.embedding}.weight"
       )
+      if model.config.embedding_use_bias:
+        converted_state["tok_embedding.bias"] = state.pop(
+            f"{self._names.embedding}.bias"
+        )
       if self._names.embedding_position is not None:
         converted_state["tok_embedding_position"] = state.pop(
             f"{self._names.embedding_position}"
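This branch means that when `embedding_use_bias` is set (as the Siglip encoder's Conv2d patch embedding requires), the source state dict must carry a `.bias` tensor next to the `.weight`. A minimal sketch with dummy tensors; the names are the Siglip patch-embedding names from `TENSOR_NAMES` above:

```python
# Sketch of the loader's new bias handling, with dummy tensors.
import torch

embedding = "vision_tower.vision_model.embeddings.patch_embedding"
state = {
    f"{embedding}.weight": torch.zeros(1152, 3, 14, 14),
    f"{embedding}.bias": torch.zeros(1152),
}
embedding_use_bias = True  # mirrors ModelConfig(embedding_use_bias=True)

converted_state = {"tok_embedding.weight": state.pop(f"{embedding}.weight")}
if embedding_use_bias:
  converted_state["tok_embedding.bias"] = state.pop(f"{embedding}.bias")
assert not state  # every source tensor was consumed
```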
ai_edge_torch/generative/utilities/verifier.py CHANGED
@@ -143,7 +143,7 @@ def verify_with_input_ids(
     kv_cache_max_len: int = 1024,
     rtol: float = 1e-05,
     atol: float = 1e-05,
-) -> bool:
+):
   """Verifies if the model reauthored generates the same output of the oringal.
 
   It compares only one outputs from the original and the reauthored model.
@@ -157,8 +157,9 @@ def verify_with_input_ids(
     rtol (float): The relative tolerance for the comparison.
     atol (float): The absolute tolerance for the comparison.
 
-  Returns:
-    True if the model reauthored generates the same output of the original.
+  Raises:
+    AssertError if the model reauthored fails to generate the same output of the
+    original.
   """
   tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.int, device="cpu")
   tokens[0, : len(input_ids)] = torch.tensor([input_ids]).int()
@@ -173,7 +174,7 @@ def verify_with_input_ids(
   logits_reauthored = outputs_reauthored[0, len(input_ids) - 1, :]
   logging.info("logits_reauthored: %s", logits_reauthored)
 
-  return torch.allclose(
+  assert torch.allclose(
       logits_original, logits_reauthored, rtol=rtol, atol=atol
   )
 
@@ -184,7 +185,7 @@ def verify_model_with_prompts(
     tokenizer: TokenizerWrapper,
     prompts: str,
     max_new_tokens: int,
-) -> bool:
+):
   """Verifies if the model reauthored generates the same answer of the oringal.
 
   It compares an answer, i.e. multiple continuous outputs generated by the
@@ -198,8 +199,9 @@ def verify_model_with_prompts(
     prompts (str): The input prompts to generate answers.
     max_new_tokens (int): The maximum number of new tokens to generate.
 
-  Returns:
-    True if the model reauthored generates the same answer of the original.
+  Raises:
+    AssertError if the model reauthored fails to generate the same answer of the
+    original.
   """
   prompt_tokens = tokenizer.encode(prompts)
 
@@ -213,7 +215,7 @@ def verify_model_with_prompts(
   response_reauthored = tokenizer.decode(outputs_reauthored[0])
   logging.info("outputs from reauthored model: [[%s]]", response_reauthored)
 
-  return response_original == response_reauthored
+  assert response_original == response_reauthored
 
 
 def verify_reauthored_model(
@@ -225,7 +227,8 @@ def verify_reauthored_model(
     forward_input_ids: List[List[int]] = [[1, 2, 3, 4]],
     rtol: float = 1e-05,
     atol: float = 1e-05,
-):
+    continue_on_failure: bool = False,
+) -> bool:
   """Verifies the reauthored model against the original model.
 
   It verifies the reauthored model with two methods:
@@ -234,7 +237,8 @@ def verify_reauthored_model(
   2. It compares the answer generated by the original and the reauthored model
      with a prompt.
 
-  It prints out "PASS" or "FAILED" to the console.
+  It prints out "PASS" or "FAILED" to the console. It returns True if all
+  verification passes, False otherwise.
 
   Args:
     original_model (ModelWrapper): The original model.
@@ -247,21 +251,42 @@ def verify_reauthored_model(
       forward with.
     rtol (float): The relative tolerance for the comparison.
     atol (float): The absolute tolerance for the comparison.
+    continue_on_failure (bool): If True, it continues to verify the next prompt
+      or input IDs even if a previous one fails.
   """
+  failure_count = 0
+
   for input_ids in forward_input_ids:
     logging.info("Verifying the reauthored model with input IDs: %s", input_ids)
-    if verify_with_input_ids(
-        original_model, reauthored_model, input_ids, rtol=rtol, atol=atol
-    ):
-      logging.info("*** PASSED *** verify with input IDs: %s", input_ids)
+    try:
+      verify_with_input_ids(
+          original_model, reauthored_model, input_ids, rtol=rtol, atol=atol
+      )
+    except AssertionError as e:
+      logging.error("*** FAILED *** verify with input IDs: %s", input_ids)
+      failure_count += 1
+      if not continue_on_failure:
+        return False
     else:
-      logging.error("*** FAILED *** verify with input IDs: %s", input_ids)
+      logging.info("*** PASSED *** verify with input IDs: %s", input_ids)
 
   for prompts in generate_prompts:
-    logging.info("Verifying the reauthored model with prompts:%s", prompts)
-    if verify_model_with_prompts(
-        original_model, reauthored_model, tokenizer, prompts, max_new_tokens
-    ):
-      logging.info("*** PASSED *** verify with prompts: %s", prompts)
+    logging.info("Verifying the reauthored model with prompts: %s", prompts)
+    try:
+      verify_model_with_prompts(
+          original_model, reauthored_model, tokenizer, prompts, max_new_tokens
+      )
+    except AssertionError as e:
+      logging.error("*** FAILED *** verify with prompts: %s", prompts)
+      failure_count += 1
+      if not continue_on_failure:
+        return False
     else:
-      logging.error("*** FAILED *** verify with prompts: %s", prompts)
+      logging.info("*** PASSED *** verify with prompts: %s", prompts)
+
+  if failure_count == 0:
+    logging.info("*** PASSED *** verify_reauthored_model")
+    return True
+  else:
+    logging.error("*** FAILED *** verify_reauthored_model")
+    return False
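The verifier now separates detection (assertions raised inside the two `verify_*` helpers) from policy (`continue_on_failure` plus a bool result), instead of returning a bool from each helper. A self-contained sketch of the same control flow, reduced to its skeleton:

```python
# Standalone sketch of the new verify_reauthored_model control flow.
import logging

def run_checks(checks, continue_on_failure=False) -> bool:
  """Runs (name, fn) pairs; a check fails by raising AssertionError."""
  failure_count = 0
  for name, check in checks:
    try:
      check()
    except AssertionError:
      logging.error("*** FAILED *** %s", name)
      failure_count += 1
      if not continue_on_failure:
        return False
    else:
      logging.info("*** PASSED *** %s", name)
  return failure_count == 0

def failing():
  assert False, "demo failure"

assert not run_checks(
    [("ok", lambda: None), ("bad", failing)], continue_on_failure=True
)
```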
ai_edge_torch/version.py CHANGED
-__version__ = "0.3.0.dev20241108"
+__version__ = "0.3.0.dev20241114"

{ai_edge_torch_nightly-0.3.0.dev20241108.dist-info → ai_edge_torch_nightly-0.3.0.dev20241114.dist-info}/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: ai-edge-torch-nightly
-Version: 0.3.0.dev20241108
+Version: 0.3.0.dev20241114
 Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
 Home-page: https://github.com/google-ai-edge/ai-edge-torch
 Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
{ai_edge_torch_nightly-0.3.0.dev20241108.dist-info → ai_edge_torch_nightly-0.3.0.dev20241114.dist-info}/RECORD CHANGED
@@ -3,7 +3,7 @@ ai_edge_torch/config.py,sha256=FMWeCH2b7HYILBvaI1iZNnYCO4WAhDOwBZBmIE-xrF0,909
 ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
 ai_edge_torch/fx_pass_base.py,sha256=518ziQ0TUxqum2qZXqlD8qr65pHPh8ZNLnwFC6zvK3k,4253
 ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
-ai_edge_torch/version.py,sha256=
+ai_edge_torch/version.py,sha256=kPX3PxeMGiRGAaIsITYzfPxu3_FKKLg91pGrqXOP6OY,706
 ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/_convert/conversion.py,sha256=HwzfRx_DX5TLtPqwEH1_NOm38_INvHzHl4_mX67KOdQ,5448
 ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -12,7 +12,7 @@ ai_edge_torch/_convert/signature.py,sha256=rGpBNss3Y9FCCCcdBwDo16KqavJi8N5P0M_6W
 ai_edge_torch/_convert/to_channel_last_io.py,sha256=_31phf7TYgZY2ftpNbrdlB1RhDium1lz_BXEQ6IsMFc,2893
 ai_edge_torch/_convert/fx_passes/__init__.py,sha256=NVe-eGcm7j8jZpP2pcMhC8j5dVjgR1pPzyXhHdvKH4E,1267
 ai_edge_torch/_convert/fx_passes/build_aten_composite_pass.py,sha256=doaww8KqrgRTD5LotBVAIRFsEqzPn9R5lcGehBJOczA,9098
-ai_edge_torch/_convert/fx_passes/build_interpolate_composite_pass.py,sha256=
+ai_edge_torch/_convert/fx_passes/build_interpolate_composite_pass.py,sha256=m_yj66V11LmWCYgA7yLtr__cy14IbC5WEJe0BE0_IPE,4339
 ai_edge_torch/_convert/fx_passes/inject_mlir_debuginfo_pass.py,sha256=IlZuK42kfVcRqAWZp4j2k_81T2uWo9T2558U_GPJAlU,2327
 ai_edge_torch/_convert/fx_passes/remove_non_user_outputs_pass.py,sha256=f1IUVWyhioOClsMiZzLyynoW2R17U83vA-7Q-3pGPM4,2126
 ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/__init__.py,sha256=lxnoH-WGLeiQIF8XjMGodjiZEFTxucl7g05N7MR9OPk,796
@@ -50,16 +50,21 @@ ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py,sha256=RZDs6
 ai_edge_torch/generative/examples/gemma/gemma1.py,sha256=oSbysiPvwp5efMbNYZop3HrxDMGiD15Tmz-HiQuTr2E,3315
 ai_edge_torch/generative/examples/gemma/gemma2.py,sha256=RQFQDMEnIVp8PefcCTr7P0CvllKI7FVoIJLXbPLLIsc,9056
 ai_edge_torch/generative/examples/gemma/verify_gemma1.py,sha256=ip-Gmk4CI5f0GWSdAIdrectxQWJ0t328KCsA4nfHuGg,1736
-ai_edge_torch/generative/examples/gemma/verify_gemma2.py,sha256=
-ai_edge_torch/generative/examples/gemma/verify_util.py,sha256=
+ai_edge_torch/generative/examples/gemma/verify_gemma2.py,sha256=IoBhEMwH07-tFm5-U6F2hpCsI8xynglhq1x9tIOdaPQ,1322
+ai_edge_torch/generative/examples/gemma/verify_util.py,sha256=tR8RflXocDZqvuStyw9aFlzuiTllEC8rNnjrxms6_Is,5727
 ai_edge_torch/generative/examples/llama/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/llama/convert_to_tflite.py,sha256=P0-pByTM5tslE23ILgo7nd0nOGE25ciBRG5wKJj0bBk,2411
 ai_edge_torch/generative/examples/llama/llama.py,sha256=AMcCbuDBxEfbO-l3KiEXbUaXEJ3RLLwkHii7to7UhVo,6854
 ai_edge_torch/generative/examples/llama/verify.py,sha256=X7oKQi85M789ugBrOlMvzk8eSRR3Kf1Mprfl-U-WIpo,2842
 ai_edge_torch/generative/examples/openelm/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/openelm/convert_to_tflite.py,sha256=85FVEt6cKFP2UzCLC78tAkbwGlGhAArtG7Wa75NxJik,2185
-ai_edge_torch/generative/examples/openelm/openelm.py,sha256=
+ai_edge_torch/generative/examples/openelm/openelm.py,sha256=sFakstoPDcOHSak0IGFEEq_HQMBBSMcx-WVCDZqcVDo,4411
 ai_edge_torch/generative/examples/openelm/verify.py,sha256=VkigoqhAr8ew95neb3TifYv-SLOSheaWKv2AH0iKDrc,2441
+ai_edge_torch/generative/examples/paligemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
+ai_edge_torch/generative/examples/paligemma/decoder.py,sha256=JSb9h3gcIh5oYrbLU6rI8OU8FzfWeTCFJT5XRWu4btE,3675
+ai_edge_torch/generative/examples/paligemma/image_encoder.py,sha256=v19_EKALhAP9FjkINKqpv8JsVaQ6iH_7X5FpnhE6abw,5500
+ai_edge_torch/generative/examples/paligemma/verify_decoder.py,sha256=al5wMPWri4IRVWrLmCplPi6uoCzwh0vBHMGnCt-XUqo,2690
+ai_edge_torch/generative/examples/paligemma/verify_image_encoder.py,sha256=pSekf1BybhieQz3cQx_llbRQHxczXbTqool8fOyGj_0,3114
 ai_edge_torch/generative/examples/phi/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py,sha256=rkbTtMaqSVG48cm-NTxR_LDgZmXAEBqayTm9O49oMXc,2171
 ai_edge_torch/generative/examples/phi/convert_to_tflite.py,sha256=3go690yX6PFeXMdpY7y4JZorAwxX0HT_b_pKZieauvk,2169
@@ -105,19 +110,19 @@ ai_edge_torch/generative/examples/tiny_llama/verify.py,sha256=7Bk8z033M-BCXJ299f
 ai_edge_torch/generative/fx_passes/__init__.py,sha256=jrzCB3ZyY_t5jJM1e2Czdt3DjAIL43R0_a-T-I7wOzw,1155
 ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py,sha256=hhxSQvkDMv0isZJhmuLiod66ZODaJ8uSPSVTJVHBabQ,1931
 ai_edge_torch/generative/layers/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/layers/attention.py,sha256=
+ai_edge_torch/generative/layers/attention.py,sha256=zN3BQjA25Ej_aRU0rFnyx--K74xf5ykc02zGvUpYHeE,13295
 ai_edge_torch/generative/layers/attention_utils.py,sha256=zBVwlBUTs-nStIKCZG0ks5ra7tsqc9ShfakFJKH5rds,7344
 ai_edge_torch/generative/layers/builder.py,sha256=Z5LyzCEThgnYZeyViakaE3yJVzTGHtw13acHsAQR15U,5050
 ai_edge_torch/generative/layers/feed_forward.py,sha256=hdICat-8gW7-vxDAevJQ8NQ-mynllPiqLdXQMF6JMnc,4189
 ai_edge_torch/generative/layers/kv_cache.py,sha256=lbm-yJ1jGPtcgWS4C3FmSnB1IlxqDE7g0BLRh3PN4N4,6324
-ai_edge_torch/generative/layers/model_config.py,sha256=
+ai_edge_torch/generative/layers/model_config.py,sha256=xqa7ZBEjgK4UWJAThRXb_VBFZ5KCGtDu-QaY5GXar9s,7366
 ai_edge_torch/generative/layers/normalization.py,sha256=eKAGst9rPuyRFExMcQFJO7R3iHdCtlmjeF_lITjLhwE,6498
 ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=CZqOoibLcHvUgrgaIIWAlmk3XgE2inzx340MN-npLoU,1347
 ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=gXxh3papKy4FBpGEX7VyZ7rZ1Js6aHK70Q6DKrVSckY,4154
 ai_edge_torch/generative/layers/unet/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/layers/unet/blocks_2d.py,sha256=
+ai_edge_torch/generative/layers/unet/blocks_2d.py,sha256=9jKzOfiBQ66bp1ZnVIAoREIifVNFx4aTlQeYMAx2_pA,29062
 ai_edge_torch/generative/layers/unet/builder.py,sha256=zAqWXdimmMrQRhmE_t9XkS68mh6PSrzwb-2NZZXrR5I,1901
-ai_edge_torch/generative/layers/unet/model_config.py,sha256=
+ai_edge_torch/generative/layers/unet/model_config.py,sha256=pPDwLawc23pfMaPVyMJlYmxVVusjMvx-l8wBwOYOH-c,9692
 ai_edge_torch/generative/quantize/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/quantize/example.py,sha256=1lfVNUd2cEyRUnoZ7BLbRJ9IN-FTKiWBtZNPFUzAiWE,1747
 ai_edge_torch/generative/quantize/quant_attrs.py,sha256=n1Fm8BFC8gJa_oiwwAOOghJyHtOXYZ4q-5ZRy4pHrIw,1957
@@ -134,12 +139,12 @@ ai_edge_torch/generative/test/test_quantize.py,sha256=8geJhKwYBU20m0mdGPD1BUFwQ0
 ai_edge_torch/generative/test/utils.py,sha256=YvEhO2HIj1LkBs5du1UxY-cGRW9HMyAYsOUhgsTrTpA,1796
 ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
 ai_edge_torch/generative/utilities/converter.py,sha256=17O83wVifH1vQJCI4WC3DaNiCIOyK2gys1GzohbLrRs,5554
-ai_edge_torch/generative/utilities/loader.py,sha256=
+ai_edge_torch/generative/utilities/loader.py,sha256=k5fjCokNomte4ymy9IJrEWAuCSMhsPCJfmv1y5s0ZEc,13452
 ai_edge_torch/generative/utilities/model_builder.py,sha256=89jt80UUfDzYBi-x077HBavWeuNJuYPXym9fiKCY1Tk,5278
 ai_edge_torch/generative/utilities/stable_diffusion_loader.py,sha256=dqPD9qRXEWtU3ombslOC-BE2l_dMwHoCNu7NsIJhsso,36158
 ai_edge_torch/generative/utilities/t5_loader.py,sha256=tEsfy8-ymzbbjOIc-oesXF3yGyyWtJgFXn2s7VOavt8,16961
 ai_edge_torch/generative/utilities/transformers_verifier.py,sha256=8sp9m_FMcXn7nqOrochtu2jIANkJKhnhIBUmH0ZTDR4,1549
-ai_edge_torch/generative/utilities/verifier.py,sha256=
+ai_edge_torch/generative/utilities/verifier.py,sha256=h5hGyIpYGyPZwvelbzpdkjy99Kpd4JkvhqWtQN9cm-M,10413
 ai_edge_torch/hlfb/__init__.py,sha256=sH4um75na-O8tzxN6chFyp6Y4xnexsE7kUQpZySv6dE,735
 ai_edge_torch/hlfb/mark_pattern/__init__.py,sha256=cjTprggj_cuktSCm7-A25e7Shop3k63ylp7sdZmtZ8o,4790
 ai_edge_torch/hlfb/mark_pattern/passes.py,sha256=pjkKcI1nHECPluAt87cFBrt1DP0f3ge7rHq1NhCkBIE,1936
@@ -186,8 +191,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
 ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
 ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
-ai_edge_torch_nightly-0.3.0.
-ai_edge_torch_nightly-0.3.0.
-ai_edge_torch_nightly-0.3.0.
-ai_edge_torch_nightly-0.3.0.
-ai_edge_torch_nightly-0.3.0.
+ai_edge_torch_nightly-0.3.0.dev20241114.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
+ai_edge_torch_nightly-0.3.0.dev20241114.dist-info/METADATA,sha256=LApNIpUONX7p4MYLOsx9QbWyD31PtVQDgQShvzqW-ec,1897
+ai_edge_torch_nightly-0.3.0.dev20241114.dist-info/WHEEL,sha256=bFJAMchF8aTQGUgMZzHJyDDMPTO3ToJ7x23SLJa1SVo,92
+ai_edge_torch_nightly-0.3.0.dev20241114.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
+ai_edge_torch_nightly-0.3.0.dev20241114.dist-info/RECORD,,