ai-edge-torch-nightly 0.3.0.dev20241108__py3-none-any.whl → 0.3.0.dev20241114__py3-none-any.whl
- ai_edge_torch/_convert/fx_passes/build_interpolate_composite_pass.py +2 -0
- ai_edge_torch/generative/examples/gemma/verify_gemma2.py +1 -14
- ai_edge_torch/generative/examples/gemma/verify_util.py +28 -3
- ai_edge_torch/generative/examples/openelm/openelm.py +1 -0
- ai_edge_torch/generative/examples/paligemma/__init__.py +14 -0
- ai_edge_torch/generative/examples/paligemma/decoder.py +103 -0
- ai_edge_torch/generative/examples/paligemma/image_encoder.py +158 -0
- ai_edge_torch/generative/examples/paligemma/verify_decoder.py +75 -0
- ai_edge_torch/generative/examples/paligemma/verify_image_encoder.py +82 -0
- ai_edge_torch/generative/layers/attention.py +8 -6
- ai_edge_torch/generative/layers/model_config.py +14 -0
- ai_edge_torch/generative/layers/unet/blocks_2d.py +53 -17
- ai_edge_torch/generative/layers/unet/model_config.py +8 -0
- ai_edge_torch/generative/utilities/loader.py +4 -0
- ai_edge_torch/generative/utilities/verifier.py +46 -21
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20241108.dist-info → ai_edge_torch_nightly-0.3.0.dev20241114.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20241108.dist-info → ai_edge_torch_nightly-0.3.0.dev20241114.dist-info}/RECORD +21 -16
- {ai_edge_torch_nightly-0.3.0.dev20241108.dist-info → ai_edge_torch_nightly-0.3.0.dev20241114.dist-info}/WHEEL +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20241108.dist-info → ai_edge_torch_nightly-0.3.0.dev20241114.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20241108.dist-info → ai_edge_torch_nightly-0.3.0.dev20241114.dist-info}/top_level.txt +0 -0

ai_edge_torch/_convert/fx_passes/build_interpolate_composite_pass.py CHANGED
@@ -51,6 +51,7 @@ def _get_upsample_bilinear2d_pattern():
     return {
         "output": (int(output_h), int(output_w)),
         "align_corners": False,
+        "is_nchw_op": True,
     }
 
   return pattern
@@ -74,6 +75,7 @@ def _get_upsample_bilinear2d_align_corners_pattern():
     return {
         "output": (int(output_h), int(output_w)),
         "align_corners": True,
+        "is_nchw_op": True,
     }
 
   return pattern

ai_edge_torch/generative/examples/gemma/verify_gemma2.py CHANGED
@@ -15,10 +15,8 @@
 
 """Verifies the reauthored Gemma2 model."""
 
-import logging
 from absl import app
 from absl import flags
-from ai_edge_torch.generative.examples.gemma import gemma2
 from ai_edge_torch.generative.examples.gemma import verify_util
 import kagglehub
 
@@ -38,18 +36,7 @@ _MAX_NEW_TOKENS = flags.DEFINE_integer(
 def main(_):
   checkpoint = kagglehub.model_download("google/gemma-2/pyTorch/gemma-2-2b-it")
 
-
-  reauthored_model = gemma2.build_2b_model(checkpoint)
-
-  verify_util.verify_reauthored_gemma_model(
-      checkpoint=checkpoint,
-      variant="2b-v2",
-      reauthored_model=reauthored_model,
-      generate_prompts=_PROMPTS.value,
-      forward_input_ids=[[2, 651, 9456, 576, 573, 3520, 3858, 603, 235248]],
-      max_new_tokens=_MAX_NEW_TOKENS.value,
-      atol=1e-04,
-  )
+  verify_util.verify_gemma2(checkpoint, _PROMPTS.value, _MAX_NEW_TOKENS.value)
 
 
 if __name__ == "__main__":

ai_edge_torch/generative/examples/gemma/verify_util.py CHANGED
@@ -19,6 +19,7 @@ import logging
 import os
 from typing import List, Tuple
 
+from ai_edge_torch.generative.examples.gemma import gemma2
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
 from ai_edge_torch.generative.utilities import verifier
 from gemma import config as gemma_config
@@ -109,8 +110,11 @@ def verify_reauthored_gemma_model(
     max_new_tokens: int = 20,
     rtol: float = 1e-05,
     atol: float = 1e-05,
-):
-  """Verifies the reauthored Gemma model against the original model."""
+) -> bool:
+  """Verifies the reauthored Gemma model against the original model.
+
+  Returns True if the verification passes, False otherwise.
+  """
   config = gemma_config.get_model_config(variant)
   config.tokenizer = os.path.join(checkpoint, tokenizer_filename)
   # Use float32 to be compatible with the reauthored model.
@@ -120,7 +124,7 @@ def verify_reauthored_gemma_model(
   original_model = gemma_model.GemmaForCausalLM(config).eval()
   original_model.load_weights(os.path.join(checkpoint, weight_filename))
 
-  verifier.verify_reauthored_model(
+  return verifier.verify_reauthored_model(
       original_model=GemmaWrapper(original_model),
       reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
       tokenizer=GemmaTokenizerWrapper(original_model.tokenizer),
@@ -130,3 +134,24 @@ def verify_reauthored_gemma_model(
       rtol=rtol,
       atol=atol,
   )
+
+
+def verify_gemma2(
+    gemma2_model_path: str, prompts: List[str], max_new_tokens: int
+) -> bool:
+  """Verifies the reauthored Gemma2 model.
+
+  Return True if the verification passes, False otherwise.
+  """
+  logging.info("Building the reauthored model from: %s", gemma2_model_path)
+  reauthored_model = gemma2.build_2b_model(gemma2_model_path)
+
+  return verify_reauthored_gemma_model(
+      checkpoint=gemma2_model_path,
+      variant="2b-v2",
+      reauthored_model=reauthored_model,
+      generate_prompts=prompts,
+      forward_input_ids=[[2, 651, 9456, 576, 573, 3520, 3858, 603, 235248]],
+      max_new_tokens=max_new_tokens,
+      atol=1e-04,
+  )
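
With this refactor the Gemma2 verification logic lives in one reusable helper and the script above shrinks to a single call. A minimal sketch of another caller, assuming a locally downloaded checkpoint (the path and prompt below are placeholders, not values from this change):

    # Sketch: reuse verify_gemma2 outside the absl script and turn the
    # returned bool into a process exit code.
    import sys

    from ai_edge_torch.generative.examples.gemma import verify_util

    ok = verify_util.verify_gemma2(
        "/tmp/gemma-2-2b-it",  # hypothetical checkpoint directory
        ["What is the meaning of life?"],
        30,  # max_new_tokens
    )
    sys.exit(0 if ok else 1)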

ai_edge_torch/generative/examples/paligemma/__init__.py ADDED
@@ -0,0 +1,14 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================

ai_edge_torch/generative/examples/paligemma/decoder.py ADDED
@@ -0,0 +1,103 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Example of building a decoder of PaliGemma 3B model which is Gemma1."""
+
+import ai_edge_torch.generative.layers.model_config as cfg
+from ai_edge_torch.generative.utilities import model_builder
+import ai_edge_torch.generative.utilities.loader as loading_utils
+
+TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
+    ff_up_proj="language_model.model.layers.{}.mlp.up_proj",
+    ff_down_proj="language_model.model.layers.{}.mlp.down_proj",
+    ff_gate_proj="language_model.model.layers.{}.mlp.gate_proj",
+    attn_query_proj="language_model.model.layers.{}.self_attn.q_proj",
+    attn_key_proj="language_model.model.layers.{}.self_attn.k_proj",
+    attn_value_proj="language_model.model.layers.{}.self_attn.v_proj",
+    attn_output_proj="language_model.model.layers.{}.self_attn.o_proj",
+    pre_attn_norm="language_model.model.layers.{}.input_layernorm",
+    post_attn_norm="language_model.model.layers.{}.post_attention_layernorm",
+    embedding="language_model.model.embed_tokens",
+    final_norm="language_model.model.norm",
+    lm_head=None,
+)
+
+
+def get_decoder_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
+  """Returns the model config for the decoder of a PaliGemma 3B model.
+
+  Args:
+    kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
+      is 1024.
+
+  Returns:
+    The model config for the decoder of a PaliGemma 3B model.
+  """
+  attn_config = cfg.AttentionConfig(
+      num_heads=8,
+      head_dim=256,
+      num_query_groups=1,
+      rotary_base=10000,
+      rotary_percentage=1.0,
+  )
+  ff_config = cfg.FeedForwardConfig(
+      type=cfg.FeedForwardType.GATED,
+      activation=cfg.ActivationConfig(cfg.ActivationType.GELU_TANH),
+      intermediate_size=16384,
+  )
+  norm_config = cfg.NormalizationConfig(
+      type=cfg.NormalizationType.RMS_NORM,
+      epsilon=1e-6,
+      zero_centered=True,
+  )
+  block_config = cfg.TransformerBlockConfig(
+      attn_config=attn_config,
+      ff_config=ff_config,
+      pre_attention_norm_config=norm_config,
+      post_attention_norm_config=norm_config,
+  )
+  config = cfg.ModelConfig(
+      vocab_size=257216,
+      num_layers=18,
+      max_seq_len=8192,
+      embedding_dim=2048,
+      embedding_scale=2048**0.5,
+      kv_cache_max_len=kv_cache_max_len,
+      block_configs=block_config,
+      final_norm_config=norm_config,
+      lm_head_use_bias=False,
+      enable_hlfb=True,
+  )
+  return config
+
+
+def get_fake_decoder_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
+  config = get_decoder_config(kv_cache_max_len)
+  # PaliGemma decoder has only one block config.
+  config.block_config(0).ff_config.intermediate_size = 128
+  config.vocab_size = 128
+  config.num_layers = 2
+  config.max_seq_len = 2 * kv_cache_max_len
+  return config
+
+
+def build_decoder(
+    checkpoint_path: str, **kwargs
+) -> model_builder.DecoderOnlyModel:
+  return model_builder.build_decoder_only_model(
+      checkpoint_path=checkpoint_path,
+      config=get_decoder_config(**kwargs),
+      tensor_names=TENSOR_NAMES,
+  )
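
get_fake_decoder_config exists so tests can exercise the graph without downloading the 3B checkpoint. A sketch of instantiating a tiny random-weight decoder from it, assuming model_builder.DecoderOnlyModel can be constructed directly from a ModelConfig (as the call to build_decoder_only_model suggests):

    # Sketch: a 2-layer, vocab-128 decoder for smoke tests; no weights loaded.
    from ai_edge_torch.generative.examples.paligemma import decoder
    from ai_edge_torch.generative.utilities import model_builder

    config = decoder.get_fake_decoder_config()
    tiny_decoder = model_builder.DecoderOnlyModel(config).eval()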

ai_edge_torch/generative/examples/paligemma/image_encoder.py ADDED
@@ -0,0 +1,158 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Example of building an image encoder of PaliGemma model which is Siglip."""
+
+from ai_edge_torch.generative.layers import attention
+from ai_edge_torch.generative.layers import builder
+import ai_edge_torch.generative.layers.model_config as cfg
+import ai_edge_torch.generative.utilities.loader as loading_utils
+import torch
+from torch import nn
+
+TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
+    ff_up_proj="vision_tower.vision_model.encoder.layers.{}.mlp.fc1",
+    ff_down_proj="vision_tower.vision_model.encoder.layers.{}.mlp.fc2",
+    attn_query_proj=(
+        "vision_tower.vision_model.encoder.layers.{}.self_attn.q_proj"
+    ),
+    attn_key_proj=(
+        "vision_tower.vision_model.encoder.layers.{}.self_attn.k_proj"
+    ),
+    attn_value_proj=(
+        "vision_tower.vision_model.encoder.layers.{}.self_attn.v_proj"
+    ),
+    attn_output_proj=(
+        "vision_tower.vision_model.encoder.layers.{}.self_attn.out_proj"
+    ),
+    pre_attn_norm="vision_tower.vision_model.encoder.layers.{}.layer_norm1",
+    post_attn_norm="vision_tower.vision_model.encoder.layers.{}.layer_norm2",
+    embedding="vision_tower.vision_model.embeddings.patch_embedding",
+    embedding_position=(
+        "vision_tower.vision_model.embeddings.position_embedding.weight"
+    ),
+    final_norm="vision_tower.vision_model.post_layernorm",
+)
+
+
+class SiglipVisionEncoder(nn.Module):
+  """Signlip vision encoder from the Edge Generative API."""
+
+  def __init__(self, config: cfg.ModelConfig):
+    super().__init__()
+
+    # Construct model layers.
+    self.tok_embedding = nn.Conv2d(
+        in_channels=config.image_embedding.channels,
+        out_channels=config.embedding_dim,
+        kernel_size=config.image_embedding.patch_size,
+        stride=config.image_embedding.patch_size,
+        padding="valid",
+    )
+    num_patches = (
+        config.image_embedding.image_size // config.image_embedding.patch_size
+    ) ** 2
+    self.tok_embedding_position = nn.Parameter(
+        torch.zeros((num_patches, config.embedding_dim))
+    )
+
+    self.transformer_blocks = nn.ModuleList(
+        attention.TransformerBlock(config.block_config(idx), config)
+        for idx in range(config.num_layers)
+    )
+    self.final_norm = builder.build_norm(
+        config.embedding_dim,
+        config.final_norm_config,
+    )
+    self.config = config
+
+  @torch.inference_mode
+  def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
+    # Embed the image according to SiplipVisionEmbeddings.
+    x = self.tok_embedding(pixel_values)
+    x = x.flatten(2).transpose(1, 2) + self.tok_embedding_position
+
+    # Pass a dummy mask because SDPA attention impl expects non-None mask.
+    mask = torch.zeros(x.shape[:2])
+    for _, block in enumerate(self.transformer_blocks):
+      x = block(x, mask=mask)
+    return self.final_norm(x)
+
+
+def get_image_encoder_config() -> cfg.ModelConfig:
+  """Returns the model config for the image encoder of a PaliGemma 3B-224 model.
+
+  Returns:
+    The model config for the image encoder of a PaliGemma 3B model.
+  """
+  image_embedding_config = cfg.ImageEmbeddingConfig(
+      channels=3,
+      image_size=224,
+      patch_size=14,
+  )
+  attn_config = cfg.AttentionConfig(
+      num_heads=16,
+      head_dim=72,
+      num_query_groups=16,
+      qkv_use_bias=True,
+      output_proj_use_bias=True,
+  )
+  ff_config = cfg.FeedForwardConfig(
+      type=cfg.FeedForwardType.SEQUENTIAL,
+      activation=cfg.ActivationConfig(cfg.ActivationType.GELU_TANH),
+      intermediate_size=4304,
+      use_bias=True,
+  )
+  norm_config = cfg.NormalizationConfig(
+      type=cfg.NormalizationType.LAYER_NORM,
+      epsilon=1e-6,
+      enable_hlfb=True,
+  )
+  block_config = cfg.TransformerBlockConfig(
+      attn_config=attn_config,
+      ff_config=ff_config,
+      pre_attention_norm_config=norm_config,
+      post_attention_norm_config=norm_config,
+  )
+  config = cfg.ModelConfig(
+      vocab_size=0,  # Not used in image encoder.
+      num_layers=27,
+      max_seq_len=0,  # Not used in image encoder.
+      embedding_dim=1152,
+      embedding_use_bias=True,
+      image_embedding=image_embedding_config,
+      block_configs=block_config,
+      final_norm_config=norm_config,
+      enable_hlfb=True,
+  )
+  return config
+
+
+def get_fake_image_encoder_config() -> cfg.ModelConfig:
+  config = get_image_encoder_config()
+  # PaliGemma image encoder has only one block config.
+  config.block_config(0).ff_config.intermediate_size = 128
+  config.num_layers = 2
+  return config
+
+
+def build_image_encoder(checkpoint_path: str) -> SiglipVisionEncoder:
+  config = get_image_encoder_config()
+  encoder = SiglipVisionEncoder(config)
+  loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
+  # Loose the strictness because only image encoder is being loaded.
+  loader.load(encoder, strict=False)
+  encoder.eval()
+  return encoder
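
For the 3B-224 settings the encoder maps one 224x224 RGB image to (224 // 14) ** 2 = 256 patch embeddings of width 1152. A shape-only smoke test under that assumption (random weights, no checkpoint, calling the module exactly as the file above defines it):

    import torch

    from ai_edge_torch.generative.examples.paligemma import image_encoder

    config = image_encoder.get_image_encoder_config()
    encoder = image_encoder.SiglipVisionEncoder(config).eval()

    pixel_values = torch.randn(1, 3, 224, 224)  # one normalized RGB image
    out = encoder(pixel_values)
    assert out.shape == (1, 256, 1152)  # 16 x 16 patches, embedding_dim 1152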

ai_edge_torch/generative/examples/paligemma/verify_decoder.py ADDED
@@ -0,0 +1,75 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Verifies the reauthored decoder of PaliGemma 3B model."""
+
+import logging
+import pathlib
+
+from absl import app
+from absl import flags
+from ai_edge_torch.generative.examples.paligemma import decoder
+from ai_edge_torch.generative.utilities import transformers_verifier
+from ai_edge_torch.generative.utilities import verifier
+import transformers
+
+_PROMPTS = flags.DEFINE_multi_string(
+    "prompts",
+    "What is the meaning of life?",
+    "The input prompts to generate answers.",
+)
+_MAX_NEW_TOKENS = flags.DEFINE_integer(
+    "max_new_tokens",
+    30,
+    "The maximum size of the generated tokens.",
+)
+
+
+def main(_):
+  checkpoint = "google/paligemma-3b-mix-224"
+  logging.info("Loading the original model from: %s", checkpoint)
+  original_full_model = (
+      transformers.PaliGemmaForConditionalGeneration.from_pretrained(checkpoint)
+  )
+  original_language_model = original_full_model.eval().language_model
+
+  # Locate the cached dir.
+  cached_config_file = transformers.utils.cached_file(
+      checkpoint, transformers.utils.CONFIG_NAME
+  )
+  reauthored_checkpoint = pathlib.Path(cached_config_file).parent
+  logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
+  reauthored_model = decoder.build_decoder(reauthored_checkpoint)
+
+  logging.info("Loading the tokenizer from: %s", checkpoint)
+  # It works only when GemmaTokenizerFast is available. In some environments,
+  # use_fast=False doeesn't work either if the tokenizer cannot load the
+  # sentencepiece model file properly.
+  processor = transformers.AutoProcessor.from_pretrained(checkpoint)
+
+  verifier.verify_reauthored_model(
+      original_model=transformers_verifier.TransformersModelWrapper(
+          original_language_model
+      ),
+      reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
+      tokenizer=verifier.TokenizerWrapper(processor.tokenizer),
+      generate_prompts=_PROMPTS.value,
+      max_new_tokens=_MAX_NEW_TOKENS.value,
+      atol=1e-04,
+  )
+
+
+if __name__ == "__main__":
+  app.run(main)

ai_edge_torch/generative/examples/paligemma/verify_image_encoder.py ADDED
@@ -0,0 +1,82 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Verifies the reauthored image encoder of PaliGemma 3B model."""
+
+import logging
+import pathlib
+from absl import app
+from absl import flags
+from ai_edge_torch.generative.examples.paligemma import image_encoder
+from PIL import Image
+import requests
+import torch
+import transformers
+
+_IMAGE_URL = flags.DEFINE_string(
+    "image_url",
+    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg?download=true",
+    "The image URI to encode.",
+)
+
+
+def main(_):
+  checkpoint = "google/paligemma-3b-mix-224"
+  logging.info("Loading the original model from: %s", checkpoint)
+  original_full_model = (
+      transformers.PaliGemmaForConditionalGeneration.from_pretrained(checkpoint)
+  )
+  original_vision_model = original_full_model.eval().vision_tower
+
+  # Locate the cached dir.
+  cached_config_file = transformers.utils.cached_file(
+      checkpoint, transformers.utils.CONFIG_NAME
+  )
+  reauthored_checkpoint = pathlib.Path(cached_config_file).parent
+  logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
+  reauthored_model = image_encoder.build_image_encoder(reauthored_checkpoint)
+
+  logging.info("Loading the processor from: %s", checkpoint)
+  # It works only when GemmaTokenizerFast is available. In some environments,
+  # use_fast=False doeesn't work either if the tokenizer cannot load the
+  # sentencepiece model file properly.
+  processor = transformers.AutoProcessor.from_pretrained(checkpoint)
+
+  logging.info("Loading the image from: %s", _IMAGE_URL.value)
+  image = Image.open(requests.get(_IMAGE_URL.value, stream=True).raw)
+  pixel_values = processor(images=image, return_tensors="pt")["pixel_values"]
+
+  logging.info("Forwarding the original model...")
+  outputs_original = original_vision_model.forward(pixel_values=pixel_values)
+  outputs_original = outputs_original.last_hidden_state
+  logging.info("outputs_original: %s", outputs_original)
+
+  logging.info("Forwarding the reauthored model...")
+  outputs_reauthored = reauthored_model.forward(pixel_values=pixel_values)
+  logging.info("outputs_reauthored: %s", outputs_reauthored)
+
+  try:
+    assert torch.allclose(
+        outputs_original, outputs_reauthored, atol=1e-04, rtol=1e-04
+    )
+  except AssertionError as e:
+    logging.error("*** FAILED *** verify with an image")
+    raise e
+  else:
+    logging.info("*** PASSED *** verify with an image")
+
+
+if __name__ == "__main__":
+  app.run(main)

ai_edge_torch/generative/layers/attention.py CHANGED
@@ -235,9 +235,10 @@ class CausalSelfAttention(nn.Module):
     k = k.reshape(B, T, -1, self.config.head_dim)
     v = v.reshape(B, T, -1, self.config.head_dim)
 
-    # Compute rotary positional embedding for query and key.
-    n_elem = int(self.config.rotary_percentage * self.config.head_dim)
-    q, k = _embed_rope(q, k, n_elem, rope)
+    if rope is not None:
+      # Compute rotary positional embedding for query and key.
+      n_elem = int(self.config.rotary_percentage * self.config.head_dim)
+      q, k = _embed_rope(q, k, n_elem, rope)
 
     if kv_cache is not None:
       kv_cache = kv_utils.update(
@@ -372,9 +373,10 @@ class CrossAttention(nn.Module):
     k = k.view(interim_shape)
     v = v.view(interim_shape)
 
-    # Compute rotary positional embedding for query and key.
-    n_elem = int(self.config.rotary_percentage * self.config.head_dim)
-    q, k = _embed_rope(q, k, n_elem, rope)
+    if rope is not None:
+      # Compute rotary positional embedding for query and key.
+      n_elem = int(self.config.rotary_percentage * self.config.head_dim)
+      q, k = _embed_rope(q, k, n_elem, rope)
 
     if kv_cache is not None:
       kv_cache = kv_utils.update(
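
The new guard lets the same transformer block run with or without rotary embeddings; the PaliGemma image encoder above passes no rope, while the language models do. For orientation, a generic rotate-half RoPE application looks like the sketch below — an illustrative standalone version, not the library's internal _embed_rope:

    import torch

    def apply_rope(
        x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
    ) -> torch.Tensor:
      # Standard rotate-half form: x * cos + rotate_half(x) * sin.
      x1, x2 = x.chunk(2, dim=-1)
      rotated = torch.cat((-x2, x1), dim=-1)
      return x * cos + rotated * sin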

ai_edge_torch/generative/layers/model_config.py CHANGED
@@ -163,6 +163,16 @@ class TransformerBlockConfig:
   relative_attention: bool = False
 
 
+@dataclass
+class ImageEmbeddingConfig:
+  """Image embedding parameters."""
+
+  channels: int
+  # All images should be normalized to the size of [image_size * image_size].
+  image_size: int
+  patch_size: int
+
+
 @dataclass
 class ModelConfig:
   """Base configurations for building a transformer architecture."""
@@ -183,6 +193,10 @@ class ModelConfig:
 
   # Scale factor of the embedding.
   embedding_scale: Optional[float] = None
+  # Use bias term within embedding.
+  embedding_use_bias: bool = False
+  # Image embedding parameters.
+  image_embedding: Optional[ImageEmbeddingConfig] = None
 
   # Use bias term within LLM's HEAD.
   lm_head_use_bias: bool = False
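
ImageEmbeddingConfig is what lets a ModelConfig describe a vision tower; the patch grid alone fixes the token count the transformer sees. Using the PaliGemma encoder values from this change:

    # 224x224 image cut into 14x14 patches -> 16 * 16 = 256 tokens.
    image_size, patch_size = 224, 14
    num_patches = (image_size // patch_size) ** 2
    assert num_patches == 256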

ai_edge_torch/generative/layers/unet/blocks_2d.py CHANGED
@@ -115,12 +115,15 @@ class AttentionBlock2D(nn.Module):
     """
     super().__init__()
     self.config = config
+    hidden_dim = config.hidden_dim
+    if not hidden_dim:
+      hidden_dim = config.dim
     self.norm = layers_builder.build_norm(
-        config.dim, config.normalization_config
+        hidden_dim, config.normalization_config
     )
     self.attention = SelfAttention(
         config.attention_batch_size,
-        config.dim,
+        hidden_dim,
         config.attention_config,
         enable_hlfb=config.enable_hlfb,
     )
@@ -172,7 +175,7 @@ class CrossAttentionBlock2D(nn.Module):
     super().__init__()
     self.config = config
     self.norm = layers_builder.build_norm(
-        config.query_dim, config.normalization_config
+        config.output_dim, config.normalization_config
     )
     self.attention = CrossAttention(
         config.attention_batch_size,
@@ -294,21 +297,30 @@ class TransformerBlock2D(nn.Module):
       hidden_states
   """
 
-  def __init__(self, config: unet_cfg.TransformerBlock2DConfig):
+  def __init__(
+      self, config: unet_cfg.TransformerBlock2DConfig, dim_override=None
+  ):
     """Initialize an instance of the TransformerBlock2D.
 
     Args:
       config (unet_cfg.TransformerBlock2Dconfig): the configuration of this
        block.
+      dim_override: in case specified, overrides config.attention_block_config.hidden_dim. Set to None by default.
     """
     super().__init__()
     self.config = config
+    attention_block_config_dim = config.attention_block_config.dim
+    attention_block_config_hidden_dim = config.attention_block_config.hidden_dim
+    if dim_override:
+      attention_block_config_dim = dim_override
+    if not attention_block_config_hidden_dim:
+      attention_block_config_hidden_dim = attention_block_config_dim
     self.pre_conv_norm = layers_builder.build_norm(
-        config.attention_block_config.dim, config.pre_conv_normalization_config
+        attention_block_config_dim, config.pre_conv_normalization_config
     )
     self.conv_in = nn.Conv2d(
-        config.attention_block_config.dim,
-        config.attention_block_config.dim,
+        attention_block_config_dim,
+        attention_block_config_hidden_dim,
         kernel_size=1,
         padding=0,
     )
@@ -318,8 +330,8 @@ class TransformerBlock2D(nn.Module):
     )
     self.feed_forward = FeedForwardBlock2D(config.feed_forward_block_config)
     self.conv_out = nn.Conv2d(
-        config.attention_block_config.dim,
-        config.attention_block_config.dim,
+        attention_block_config_hidden_dim,
+        attention_block_config_dim,
         kernel_size=1,
         padding=0,
     )
@@ -385,14 +397,18 @@ class DownEncoderBlock2D(nn.Module):
     self.config = config
     resnets = []
     transformers = []
+    hidden_channels = config.hidden_channels
+    if not hidden_channels:
+      hidden_channels = config.out_channels
     for i in range(config.num_layers):
       input_channels = config.in_channels if i == 0 else config.out_channels
       resnets.append(
           ResidualBlock2D(
               unet_cfg.ResidualBlock2DConfig(
                   in_channels=input_channels,
-                  hidden_channels=config.out_channels,
+                  hidden_channels=hidden_channels,
                   out_channels=config.out_channels,
+                  residual_out_channels=config.out_channels,
                   time_embedding_channels=config.time_embedding_channels,
                   normalization_config=config.normalization_config,
                   activation_config=config.activation_config,
@@ -589,23 +605,37 @@ class SkipUpDecoderBlock2D(nn.Module):
     """
     super().__init__()
     self.config = config
+    hidden_channels = config.hidden_channels
+    if not hidden_channels:
+      hidden_channels = config.out_channels
+    sub_block_channels = config.sub_block_channels
+    if sub_block_channels:
+      assert len(sub_block_channels) == config.num_layers, (
+          "Assertion failed: The length of 'sub_block_channels'"
+          f" ({len(sub_block_channels)}) does not match 'config.num_layers'"
+          f" ({config.num_layers})."
+      )
+    else:
+      sub_block_channels = [config.out_channels] * config.num_layers
     resnets = []
     transformers = []
     for i in range(config.num_layers):
+      resnet_in_channels = (
+          config.prev_out_channels if i == 0 else sub_block_channels[i - 1]
+      )
       res_skip_channels = (
           config.in_channels
           if (i == config.num_layers - 1)
           else config.out_channels
       )
-      resnet_in_channels = (
-          config.prev_out_channels if i == 0 else config.out_channels
-      )
+      residual_out_channel = sub_block_channels[i]
       resnets.append(
           ResidualBlock2D(
              unet_cfg.ResidualBlock2DConfig(
                  in_channels=resnet_in_channels + res_skip_channels,
-                  hidden_channels=config.out_channels,
-                  out_channels=config.out_channels,
+                  hidden_channels=hidden_channels,
+                  out_channels=sub_block_channels[i],
+                  residual_out_channels=residual_out_channel,
                  time_embedding_channels=config.time_embedding_channels,
                  normalization_config=config.normalization_config,
                  activation_config=config.activation_config,
@@ -613,7 +643,12 @@ class SkipUpDecoderBlock2D(nn.Module):
           )
       )
       if config.transformer_block_config:
-        transformers.append(TransformerBlock2D(config.transformer_block_config))
+        transformers.append(
+            TransformerBlock2D(
+                config.transformer_block_config,
+                dim_override=sub_block_channels[i],
+            )
+        )
     self.resnets = nn.ModuleList(resnets)
     self.transformers = (
         nn.ModuleList(transformers) if len(transformers) > 0 else None
@@ -623,7 +658,7 @@ class SkipUpDecoderBlock2D(nn.Module):
     if config.upsample_conv:
       self.upsample_conv = nn.Conv2d(
           config.out_channels,
-          config.out_channels,
+          sub_block_channels[0],
           kernel_size=3,
           stride=1,
           padding=1,
@@ -711,6 +746,7 @@ class MidBlock2D(nn.Module):
             in_channels=config.in_channels,
             hidden_channels=config.in_channels,
             out_channels=config.in_channels,
+            residual_out_channels=config.in_channels,
             time_embedding_channels=config.time_embedding_channels,
             normalization_config=config.normalization_config,
             activation_config=config.activation_config,

ai_edge_torch/generative/layers/unet/model_config.py CHANGED
@@ -50,10 +50,12 @@ class ResidualBlock2DConfig:
   in_channels: int
   hidden_channels: int
   out_channels: int
+  hidden_channels: int
   normalization_config: layers_cfg.NormalizationConfig
   activation_config: layers_cfg.ActivationConfig
   # Optional time embedding channels if the residual block takes a time embedding context as input
   time_embedding_channels: Optional[int] = None
+  residual_out_channels: Optional[int] = None
 
 
 @dataclasses.dataclass
@@ -63,6 +65,7 @@ class AttentionBlock2DConfig:
   attention_config: layers_cfg.AttentionConfig
   enable_hlfb: bool = True
   attention_batch_size: int = 1
+  hidden_dim: Optional[int] = None
 
 
 @dataclasses.dataclass
@@ -101,6 +104,8 @@ class UpDecoderBlock2DConfig:
   normalization_config: layers_cfg.NormalizationConfig
   activation_config: layers_cfg.ActivationConfig
   num_layers: int
+  # The dimension of output channels of previous connected block
+  prev_out_channels: Optional[int] = None
   # Optional time embedding channels if the residual blocks take a time embedding as input
   time_embedding_channels: Optional[int] = None
   # Whether to add upsample operation after residual blocks
@@ -136,6 +141,8 @@ class SkipUpDecoderBlock2DConfig:
   transformer_block_config: Optional[TransformerBlock2DConfig] = None
   # Optional dimension of context tensor if context tensor is given as input.
   context_dim: Optional[int] = None
+  sub_block_channels: Optional[tuple] = None
+  hidden_channels: Optional[int] = None
 
 
 @dataclasses.dataclass
@@ -157,6 +164,7 @@ class DownEncoderBlock2DConfig:
   transformer_block_config: Optional[TransformerBlock2DConfig] = None
   # Optional dimension of context tensor if context tensor is given as input.
   context_dim: Optional[int] = None
+  hidden_channels: Optional[int] = None
 
 
 @dataclasses.dataclass

ai_edge_torch/generative/utilities/loader.py CHANGED
@@ -157,6 +157,10 @@ class ModelLoader:
     converted_state["tok_embedding.weight"] = state.pop(
         f"{self._names.embedding}.weight"
     )
+    if model.config.embedding_use_bias:
+      converted_state["tok_embedding.bias"] = state.pop(
+          f"{self._names.embedding}.bias"
+      )
     if self._names.embedding_position is not None:
       converted_state["tok_embedding_position"] = state.pop(
           f"{self._names.embedding_position}"

ai_edge_torch/generative/utilities/verifier.py CHANGED
@@ -143,7 +143,7 @@ def verify_with_input_ids(
     kv_cache_max_len: int = 1024,
     rtol: float = 1e-05,
     atol: float = 1e-05,
-) -> bool:
+):
   """Verifies if the model reauthored generates the same output of the oringal.
 
   It compares only one outputs from the original and the reauthored model.
@@ -157,8 +157,9 @@ def verify_with_input_ids(
     rtol (float): The relative tolerance for the comparison.
     atol (float): The absolute tolerance for the comparison.
 
-  Returns:
-    True if the model reauthored generates the same output of the original.
+  Raises:
+    AssertError if the model reauthored fails to generate the same output of the
+    original.
   """
   tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.int, device="cpu")
   tokens[0, : len(input_ids)] = torch.tensor([input_ids]).int()
@@ -173,7 +174,7 @@ def verify_with_input_ids(
   logits_reauthored = outputs_reauthored[0, len(input_ids) - 1, :]
   logging.info("logits_reauthored: %s", logits_reauthored)
 
-  return torch.allclose(
+  assert torch.allclose(
       logits_original, logits_reauthored, rtol=rtol, atol=atol
   )
 
@@ -184,7 +185,7 @@ def verify_model_with_prompts(
     tokenizer: TokenizerWrapper,
     prompts: str,
     max_new_tokens: int,
-) -> bool:
+):
   """Verifies if the model reauthored generates the same answer of the oringal.
 
   It compares an answer, i.e. multiple continuous outputs generated by the
@@ -198,8 +199,9 @@ def verify_model_with_prompts(
     prompts (str): The input prompts to generate answers.
     max_new_tokens (int): The maximum number of new tokens to generate.
 
-  Returns:
-    True if the model reauthored generates the same answer of the original.
+  Raises:
+    AssertError if the model reauthored fails to generate the same answer of the
+    original.
   """
   prompt_tokens = tokenizer.encode(prompts)
 
@@ -213,7 +215,7 @@ def verify_model_with_prompts(
   response_reauthored = tokenizer.decode(outputs_reauthored[0])
   logging.info("outputs from reauthored model: [[%s]]", response_reauthored)
 
-  return response_original == response_reauthored
+  assert response_original == response_reauthored
 
 
 def verify_reauthored_model(
@@ -225,7 +227,8 @@ def verify_reauthored_model(
     forward_input_ids: List[List[int]] = [[1, 2, 3, 4]],
     rtol: float = 1e-05,
     atol: float = 1e-05,
-):
+    continue_on_failure: bool = False,
+) -> bool:
   """Verifies the reauthored model against the original model.
 
   It verifies the reauthored model with two methods:
@@ -234,7 +237,8 @@ def verify_reauthored_model(
   2. It compares the answer generated by the original and the reauthored model
      with a prompt.
 
-  It prints out "PASS" or "FAILED" to the console.
+  It prints out "PASS" or "FAILED" to the console. It returns True if all
+  verification passes, False otherwise.
 
   Args:
     original_model (ModelWrapper): The original model.
@@ -247,21 +251,42 @@ def verify_reauthored_model(
       forward with.
     rtol (float): The relative tolerance for the comparison.
     atol (float): The absolute tolerance for the comparison.
+    continue_on_failure (bool): If True, it continues to verify the next prompt
+      or input IDs even if a previous one fails.
   """
+  failure_count = 0
+
   for input_ids in forward_input_ids:
     logging.info("Verifying the reauthored model with input IDs: %s", input_ids)
-    if verify_with_input_ids(
-        original_model, reauthored_model, input_ids, rtol=rtol, atol=atol
-    ):
-      logging.info("*** PASSED *** verify with input IDs: %s", input_ids)
+    try:
+      verify_with_input_ids(
+          original_model, reauthored_model, input_ids, rtol=rtol, atol=atol
+      )
+    except AssertionError as e:
+      logging.error("*** FAILED *** verify with input IDs: %s", input_ids)
+      failure_count += 1
+      if not continue_on_failure:
+        return False
     else:
-      logging.error("*** FAILED *** verify with input IDs: %s", input_ids)
+      logging.info("*** PASSED *** verify with input IDs: %s", input_ids)
 
   for prompts in generate_prompts:
-    logging.info("Verifying the reauthored model with prompts: %s", prompts)
-    if verify_model_with_prompts(
-        original_model, reauthored_model, tokenizer, prompts, max_new_tokens
-    ):
-      logging.info("*** PASSED *** verify with prompts: %s", prompts)
+    logging.info("Verifying the reauthored model with prompts: %s", prompts)
+    try:
+      verify_model_with_prompts(
+          original_model, reauthored_model, tokenizer, prompts, max_new_tokens
+      )
+    except AssertionError as e:
+      logging.error("*** FAILED *** verify with prompts: %s", prompts)
+      failure_count += 1
+      if not continue_on_failure:
+        return False
    else:
-      logging.error("*** FAILED *** verify with prompts: %s", prompts)
+      logging.info("*** PASSED *** verify with prompts: %s", prompts)
+
+  if failure_count == 0:
+    logging.info("*** PASSED *** verify_reauthored_model")
+    return True
+  else:
+    logging.error("*** FAILED *** verify_reauthored_model")
+    return False
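
Because verify_reauthored_model now reports failures through its return value, callers choose the failure policy. A sketch of a driver that checks everything before failing the process; the wrapper objects are placeholders for whatever the calling script already builds:

    # Sketch: keep verifying after a failure, then exit non-zero once.
    import sys

    ok = verifier.verify_reauthored_model(
        original_model=original,      # a ModelWrapper built by the caller
        reauthored_model=reauthored,  # a verifier.ReauthoredModelWrapper
        tokenizer=tokenizer,          # a verifier.TokenizerWrapper
        generate_prompts=["What is the meaning of life?"],
        max_new_tokens=30,
        continue_on_failure=True,
    )
    sys.exit(0 if ok else 1)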

ai_edge_torch/version.py CHANGED

{ai_edge_torch_nightly-0.3.0.dev20241108.dist-info → ai_edge_torch_nightly-0.3.0.dev20241114.dist-info}/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: ai-edge-torch-nightly
-Version: 0.3.0.dev20241108
+Version: 0.3.0.dev20241114
 Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
 Home-page: https://github.com/google-ai-edge/ai-edge-torch
 Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI

{ai_edge_torch_nightly-0.3.0.dev20241108.dist-info → ai_edge_torch_nightly-0.3.0.dev20241114.dist-info}/RECORD CHANGED
@@ -3,7 +3,7 @@ ai_edge_torch/config.py,sha256=FMWeCH2b7HYILBvaI1iZNnYCO4WAhDOwBZBmIE-xrF0,909
 ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
 ai_edge_torch/fx_pass_base.py,sha256=518ziQ0TUxqum2qZXqlD8qr65pHPh8ZNLnwFC6zvK3k,4253
 ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
-ai_edge_torch/version.py,sha256=
+ai_edge_torch/version.py,sha256=kPX3PxeMGiRGAaIsITYzfPxu3_FKKLg91pGrqXOP6OY,706
 ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/_convert/conversion.py,sha256=HwzfRx_DX5TLtPqwEH1_NOm38_INvHzHl4_mX67KOdQ,5448
 ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -12,7 +12,7 @@ ai_edge_torch/_convert/signature.py,sha256=rGpBNss3Y9FCCCcdBwDo16KqavJi8N5P0M_6W
 ai_edge_torch/_convert/to_channel_last_io.py,sha256=_31phf7TYgZY2ftpNbrdlB1RhDium1lz_BXEQ6IsMFc,2893
 ai_edge_torch/_convert/fx_passes/__init__.py,sha256=NVe-eGcm7j8jZpP2pcMhC8j5dVjgR1pPzyXhHdvKH4E,1267
 ai_edge_torch/_convert/fx_passes/build_aten_composite_pass.py,sha256=doaww8KqrgRTD5LotBVAIRFsEqzPn9R5lcGehBJOczA,9098
-ai_edge_torch/_convert/fx_passes/build_interpolate_composite_pass.py,sha256=
+ai_edge_torch/_convert/fx_passes/build_interpolate_composite_pass.py,sha256=m_yj66V11LmWCYgA7yLtr__cy14IbC5WEJe0BE0_IPE,4339
 ai_edge_torch/_convert/fx_passes/inject_mlir_debuginfo_pass.py,sha256=IlZuK42kfVcRqAWZp4j2k_81T2uWo9T2558U_GPJAlU,2327
 ai_edge_torch/_convert/fx_passes/remove_non_user_outputs_pass.py,sha256=f1IUVWyhioOClsMiZzLyynoW2R17U83vA-7Q-3pGPM4,2126
 ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/__init__.py,sha256=lxnoH-WGLeiQIF8XjMGodjiZEFTxucl7g05N7MR9OPk,796
@@ -50,16 +50,21 @@ ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py,sha256=RZDs6
 ai_edge_torch/generative/examples/gemma/gemma1.py,sha256=oSbysiPvwp5efMbNYZop3HrxDMGiD15Tmz-HiQuTr2E,3315
 ai_edge_torch/generative/examples/gemma/gemma2.py,sha256=RQFQDMEnIVp8PefcCTr7P0CvllKI7FVoIJLXbPLLIsc,9056
 ai_edge_torch/generative/examples/gemma/verify_gemma1.py,sha256=ip-Gmk4CI5f0GWSdAIdrectxQWJ0t328KCsA4nfHuGg,1736
-ai_edge_torch/generative/examples/gemma/verify_gemma2.py,sha256=
-ai_edge_torch/generative/examples/gemma/verify_util.py,sha256=
+ai_edge_torch/generative/examples/gemma/verify_gemma2.py,sha256=IoBhEMwH07-tFm5-U6F2hpCsI8xynglhq1x9tIOdaPQ,1322
+ai_edge_torch/generative/examples/gemma/verify_util.py,sha256=tR8RflXocDZqvuStyw9aFlzuiTllEC8rNnjrxms6_Is,5727
 ai_edge_torch/generative/examples/llama/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/llama/convert_to_tflite.py,sha256=P0-pByTM5tslE23ILgo7nd0nOGE25ciBRG5wKJj0bBk,2411
 ai_edge_torch/generative/examples/llama/llama.py,sha256=AMcCbuDBxEfbO-l3KiEXbUaXEJ3RLLwkHii7to7UhVo,6854
 ai_edge_torch/generative/examples/llama/verify.py,sha256=X7oKQi85M789ugBrOlMvzk8eSRR3Kf1Mprfl-U-WIpo,2842
 ai_edge_torch/generative/examples/openelm/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/openelm/convert_to_tflite.py,sha256=85FVEt6cKFP2UzCLC78tAkbwGlGhAArtG7Wa75NxJik,2185
-ai_edge_torch/generative/examples/openelm/openelm.py,sha256=
+ai_edge_torch/generative/examples/openelm/openelm.py,sha256=sFakstoPDcOHSak0IGFEEq_HQMBBSMcx-WVCDZqcVDo,4411
 ai_edge_torch/generative/examples/openelm/verify.py,sha256=VkigoqhAr8ew95neb3TifYv-SLOSheaWKv2AH0iKDrc,2441
+ai_edge_torch/generative/examples/paligemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
+ai_edge_torch/generative/examples/paligemma/decoder.py,sha256=JSb9h3gcIh5oYrbLU6rI8OU8FzfWeTCFJT5XRWu4btE,3675
+ai_edge_torch/generative/examples/paligemma/image_encoder.py,sha256=v19_EKALhAP9FjkINKqpv8JsVaQ6iH_7X5FpnhE6abw,5500
+ai_edge_torch/generative/examples/paligemma/verify_decoder.py,sha256=al5wMPWri4IRVWrLmCplPi6uoCzwh0vBHMGnCt-XUqo,2690
+ai_edge_torch/generative/examples/paligemma/verify_image_encoder.py,sha256=pSekf1BybhieQz3cQx_llbRQHxczXbTqool8fOyGj_0,3114
 ai_edge_torch/generative/examples/phi/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py,sha256=rkbTtMaqSVG48cm-NTxR_LDgZmXAEBqayTm9O49oMXc,2171
 ai_edge_torch/generative/examples/phi/convert_to_tflite.py,sha256=3go690yX6PFeXMdpY7y4JZorAwxX0HT_b_pKZieauvk,2169
@@ -105,19 +110,19 @@ ai_edge_torch/generative/examples/tiny_llama/verify.py,sha256=7Bk8z033M-BCXJ299f
 ai_edge_torch/generative/fx_passes/__init__.py,sha256=jrzCB3ZyY_t5jJM1e2Czdt3DjAIL43R0_a-T-I7wOzw,1155
 ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py,sha256=hhxSQvkDMv0isZJhmuLiod66ZODaJ8uSPSVTJVHBabQ,1931
 ai_edge_torch/generative/layers/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/layers/attention.py,sha256=
+ai_edge_torch/generative/layers/attention.py,sha256=zN3BQjA25Ej_aRU0rFnyx--K74xf5ykc02zGvUpYHeE,13295
 ai_edge_torch/generative/layers/attention_utils.py,sha256=zBVwlBUTs-nStIKCZG0ks5ra7tsqc9ShfakFJKH5rds,7344
 ai_edge_torch/generative/layers/builder.py,sha256=Z5LyzCEThgnYZeyViakaE3yJVzTGHtw13acHsAQR15U,5050
 ai_edge_torch/generative/layers/feed_forward.py,sha256=hdICat-8gW7-vxDAevJQ8NQ-mynllPiqLdXQMF6JMnc,4189
 ai_edge_torch/generative/layers/kv_cache.py,sha256=lbm-yJ1jGPtcgWS4C3FmSnB1IlxqDE7g0BLRh3PN4N4,6324
-ai_edge_torch/generative/layers/model_config.py,sha256=
+ai_edge_torch/generative/layers/model_config.py,sha256=xqa7ZBEjgK4UWJAThRXb_VBFZ5KCGtDu-QaY5GXar9s,7366
 ai_edge_torch/generative/layers/normalization.py,sha256=eKAGst9rPuyRFExMcQFJO7R3iHdCtlmjeF_lITjLhwE,6498
 ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=CZqOoibLcHvUgrgaIIWAlmk3XgE2inzx340MN-npLoU,1347
 ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=gXxh3papKy4FBpGEX7VyZ7rZ1Js6aHK70Q6DKrVSckY,4154
 ai_edge_torch/generative/layers/unet/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/layers/unet/blocks_2d.py,sha256=
+ai_edge_torch/generative/layers/unet/blocks_2d.py,sha256=9jKzOfiBQ66bp1ZnVIAoREIifVNFx4aTlQeYMAx2_pA,29062
 ai_edge_torch/generative/layers/unet/builder.py,sha256=zAqWXdimmMrQRhmE_t9XkS68mh6PSrzwb-2NZZXrR5I,1901
-ai_edge_torch/generative/layers/unet/model_config.py,sha256=
+ai_edge_torch/generative/layers/unet/model_config.py,sha256=pPDwLawc23pfMaPVyMJlYmxVVusjMvx-l8wBwOYOH-c,9692
 ai_edge_torch/generative/quantize/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/quantize/example.py,sha256=1lfVNUd2cEyRUnoZ7BLbRJ9IN-FTKiWBtZNPFUzAiWE,1747
 ai_edge_torch/generative/quantize/quant_attrs.py,sha256=n1Fm8BFC8gJa_oiwwAOOghJyHtOXYZ4q-5ZRy4pHrIw,1957
@@ -134,12 +139,12 @@ ai_edge_torch/generative/test/test_quantize.py,sha256=8geJhKwYBU20m0mdGPD1BUFwQ0
 ai_edge_torch/generative/test/utils.py,sha256=YvEhO2HIj1LkBs5du1UxY-cGRW9HMyAYsOUhgsTrTpA,1796
 ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
 ai_edge_torch/generative/utilities/converter.py,sha256=17O83wVifH1vQJCI4WC3DaNiCIOyK2gys1GzohbLrRs,5554
-ai_edge_torch/generative/utilities/loader.py,sha256=
+ai_edge_torch/generative/utilities/loader.py,sha256=k5fjCokNomte4ymy9IJrEWAuCSMhsPCJfmv1y5s0ZEc,13452
 ai_edge_torch/generative/utilities/model_builder.py,sha256=89jt80UUfDzYBi-x077HBavWeuNJuYPXym9fiKCY1Tk,5278
 ai_edge_torch/generative/utilities/stable_diffusion_loader.py,sha256=dqPD9qRXEWtU3ombslOC-BE2l_dMwHoCNu7NsIJhsso,36158
 ai_edge_torch/generative/utilities/t5_loader.py,sha256=tEsfy8-ymzbbjOIc-oesXF3yGyyWtJgFXn2s7VOavt8,16961
 ai_edge_torch/generative/utilities/transformers_verifier.py,sha256=8sp9m_FMcXn7nqOrochtu2jIANkJKhnhIBUmH0ZTDR4,1549
-ai_edge_torch/generative/utilities/verifier.py,sha256=
+ai_edge_torch/generative/utilities/verifier.py,sha256=h5hGyIpYGyPZwvelbzpdkjy99Kpd4JkvhqWtQN9cm-M,10413
 ai_edge_torch/hlfb/__init__.py,sha256=sH4um75na-O8tzxN6chFyp6Y4xnexsE7kUQpZySv6dE,735
 ai_edge_torch/hlfb/mark_pattern/__init__.py,sha256=cjTprggj_cuktSCm7-A25e7Shop3k63ylp7sdZmtZ8o,4790
 ai_edge_torch/hlfb/mark_pattern/passes.py,sha256=pjkKcI1nHECPluAt87cFBrt1DP0f3ge7rHq1NhCkBIE,1936
@@ -186,8 +191,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
 ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
 ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
-ai_edge_torch_nightly-0.3.0.dev20241108.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
-ai_edge_torch_nightly-0.3.0.dev20241108.dist-info/METADATA,sha256=
-ai_edge_torch_nightly-0.3.0.dev20241108.dist-info/WHEEL,sha256=
-ai_edge_torch_nightly-0.3.0.dev20241108.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
-ai_edge_torch_nightly-0.3.0.dev20241108.dist-info/RECORD,,
+ai_edge_torch_nightly-0.3.0.dev20241114.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
+ai_edge_torch_nightly-0.3.0.dev20241114.dist-info/METADATA,sha256=LApNIpUONX7p4MYLOsx9QbWyD31PtVQDgQShvzqW-ec,1897
+ai_edge_torch_nightly-0.3.0.dev20241114.dist-info/WHEEL,sha256=bFJAMchF8aTQGUgMZzHJyDDMPTO3ToJ7x23SLJa1SVo,92
+ai_edge_torch_nightly-0.3.0.dev20241114.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
+ai_edge_torch_nightly-0.3.0.dev20241114.dist-info/RECORD,,