ai-edge-torch-nightly 0.3.0.dev20240918__py3-none-any.whl → 0.3.0.dev20240919__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_torch/generative/examples/openelm/openelm.py +0 -29
- ai_edge_torch/generative/examples/openelm/verify.py +61 -0
- ai_edge_torch/generative/examples/phi/phi2.py +4 -31
- ai_edge_torch/generative/examples/phi/verify.py +53 -0
- ai_edge_torch/generative/examples/smollm/smollm.py +0 -30
- ai_edge_torch/generative/examples/smollm/verify.py +59 -0
- ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +0 -29
- ai_edge_torch/generative/examples/tiny_llama/verify.py +61 -0
- ai_edge_torch/generative/layers/builder.py +3 -1
- ai_edge_torch/generative/layers/model_config.py +3 -0
- ai_edge_torch/generative/layers/normalization.py +31 -20
- ai_edge_torch/generative/layers/scaled_dot_product_attention.py +19 -9
- ai_edge_torch/generative/layers/unet/blocks_2d.py +9 -4
- ai_edge_torch/generative/layers/unet/model_config.py +1 -0
- ai_edge_torch/generative/test/test_model_conversion_large.py +1 -1
- ai_edge_torch/generative/utilities/stable_diffusion_loader.py +4 -0
- ai_edge_torch/generative/utilities/verifier.py +200 -0
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20240918.dist-info → ai_edge_torch_nightly-0.3.0.dev20240919.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20240918.dist-info → ai_edge_torch_nightly-0.3.0.dev20240919.dist-info}/RECORD +23 -18
- {ai_edge_torch_nightly-0.3.0.dev20240918.dist-info → ai_edge_torch_nightly-0.3.0.dev20240919.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240918.dist-info → ai_edge_torch_nightly-0.3.0.dev20240919.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240918.dist-info → ai_edge_torch_nightly-0.3.0.dev20240919.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/examples/openelm/openelm.py
CHANGED
@@ -15,16 +15,12 @@
 
 """Example of building an OpenELM model."""
 
-import os
-import pathlib
-
 from ai_edge_torch.generative.layers import attention
 from ai_edge_torch.generative.layers import builder
 from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
 import ai_edge_torch.generative.layers.model_config as cfg
 import ai_edge_torch.generative.utilities.loader as loading_utils
-import numpy as np
 import torch
 from torch import nn
 
@@ -210,28 +206,3 @@ def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
   loader.load(model, strict=False)
   model.eval()
   return model
-
-
-def define_and_run(checkpoint_path: str) -> None:
-  """Instantiates and runs an OpenELM model."""
-
-  current_dir = pathlib.Path(__file__).parent.resolve()
-  openelm_goldens = torch.load(current_dir / "openelm_lm_logits.pt")
-  kv_cache_max_len = 1024
-  model = build_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
-  idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
-  tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.int, device="cpu")
-  tokens[0, :4] = idx
-  input_pos = torch.arange(0, kv_cache_max_len, dtype=torch.int)
-  kv = kv_utils.KVCache.from_model_config(model.config)
-  output = model.forward(tokens, input_pos, kv)
-  assert torch.allclose(
-      openelm_goldens, output["logits"][0, idx.shape[1] - 1, :], atol=1e-05
-  )
-
-
-if __name__ == "__main__":
-  input_checkpoint_path = os.path.join(
-      pathlib.Path.home(), "Downloads/llm_data/openelm"
-  )
-  define_and_run(input_checkpoint_path)
ai_edge_torch/generative/examples/openelm/verify.py
ADDED
@@ -0,0 +1,61 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Verifies the reauthored OpenELM-3B model."""
+
+import pathlib
+
+from absl import app
+from absl import flags
+from ai_edge_torch.generative.examples.openelm import openelm
+from ai_edge_torch.generative.utilities import verifier
+import transformers
+
+_PROMPTS = flags.DEFINE_multi_string(
+    "prompts",
+    "What is the meaning of life?",
+    "The input prompts to generate answers.",
+)
+
+
+def main(_):
+  checkpoint = "apple/OpenELM-3B"
+  verifier.log_msg("Loading the original model from", checkpoint)
+  original_model = transformers.AutoModelForCausalLM.from_pretrained(
+      checkpoint, trust_remote_code=True
+  )
+
+  # Locate the cached dir.
+  cached_config_file = transformers.utils.cached_file(
+      checkpoint, transformers.utils.CONFIG_NAME
+  )
+  reauthored_checkpoint = pathlib.Path(cached_config_file).parent
+  verifier.log_msg("Building the reauthored model from", reauthored_checkpoint)
+  reauthored_model = openelm.build_model(reauthored_checkpoint)
+
+  tokenizer_checkpoint = "meta-llama/Llama-2-7b-hf"
+  verifier.log_msg("Loading the tokenizer from", tokenizer_checkpoint)
+  tokenizer = transformers.AutoTokenizer.from_pretrained(tokenizer_checkpoint)
+
+  verifier.verify_reauthored_model(
+      original_model=original_model,
+      reauthored_model=reauthored_model,
+      tokenizer=tokenizer,
+      prompts=_PROMPTS.value,
+  )
+
+
+if __name__ == "__main__":
+  app.run(main)
ai_edge_torch/generative/examples/phi/phi2.py
CHANGED
@@ -15,16 +15,12 @@
 
 """Example of building a Phi-2 model."""
 
-import os
-import pathlib
-
 from ai_edge_torch.generative.layers import attention
 from ai_edge_torch.generative.layers import builder
 from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
 import ai_edge_torch.generative.layers.model_config as cfg
 import ai_edge_torch.generative.utilities.loader as loading_utils
-import numpy as np
 import torch
 from torch import nn
 
@@ -143,7 +139,10 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       intermediate_size=10240,
       use_bias=True,
   )
-  norm_config = cfg.NormalizationConfig(
+  norm_config = cfg.NormalizationConfig(
+      type=cfg.NormalizationType.LAYER_NORM,
+      use_input_shape=False,  # Phi-2 does layer-norm with the weight shape.
+  )
   block_config = cfg.TransformerBlockConfig(
       attn_config=attn_config,
       ff_config=ff_config,
@@ -182,29 +181,3 @@ def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
   loader.load(model)
   model.eval()
   return model
-
-
-def define_and_run(checkpoint_path: str) -> None:
-  """Instantiates and runs a Phi-2 model."""
-
-  current_dir = pathlib.Path(__file__).parent.resolve()
-  phi2_goldens = torch.load(current_dir / "phi2_lm_logits.pt")
-  kv_cache_max_len = 1024
-  model = build_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
-  idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
-  tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.int, device="cpu")
-  tokens[0, :4] = idx
-  input_pos = torch.arange(0, kv_cache_max_len, dtype=torch.int)
-  kv = kv_utils.KVCache.from_model_config(model.config)
-  output = model.forward(tokens, input_pos, kv)
-  print("comparing with goldens..")
-  assert torch.allclose(
-      phi2_goldens, output["logits"][0, idx.shape[1] - 1, :], atol=1e-02
-  )
-
-
-if __name__ == "__main__":
-  input_checkpoint_path = os.path.join(
-      pathlib.Path.home(), "Downloads/llm_data/phi2"
-  )
-  define_and_run(input_checkpoint_path)
ai_edge_torch/generative/examples/phi/verify.py
ADDED
@@ -0,0 +1,53 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Verifies the reauthored Phi-2 model."""
+
+from absl import app
+from absl import flags
+from ai_edge_torch.generative.examples.phi import phi2
+from ai_edge_torch.generative.utilities import verifier
+import kagglehub
+import transformers
+
+_PROMPTS = flags.DEFINE_multi_string(
+    "prompts",
+    "What is the meaning of life?",
+    "The input prompts to generate answers.",
+)
+
+
+def main(_):
+  checkpoint = kagglehub.model_download("Microsoft/phi/transformers/2")
+  verifier.log_msg("Loading the original model from", checkpoint)
+  original_model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint)
+
+  verifier.log_msg("Building the reauthored model from", checkpoint)
+  reauthored_model = phi2.build_model(checkpoint)
+
+  verifier.log_msg("Loading the tokenizer from", checkpoint)
+  tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
+
+  verifier.verify_reauthored_model(
+      original_model=original_model,
+      reauthored_model=reauthored_model,
+      tokenizer=tokenizer,
+      prompts=_PROMPTS.value,
+      atol=1e-03,
+  )
+
+
+if __name__ == "__main__":
+  app.run(main)
ai_edge_torch/generative/examples/smollm/smollm.py
CHANGED
@@ -16,15 +16,10 @@
 """Example of building a SmolLM model."""
 
 import copy
-import os
-import pathlib
 
 from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
-from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.model_config as cfg
 import ai_edge_torch.generative.utilities.loader as loading_utils
-import numpy as np
-import torch
 from torch import nn
 
 TENSOR_NAMES = copy.copy(tiny_llama.TENSOR_NAMES)
@@ -104,28 +99,3 @@ def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
   loader.load(model, strict=False)
   model.eval()
   return model
-
-
-def define_and_run(checkpoint_path: str) -> None:
-  """Instantiates and runs a SmolLM model."""
-
-  current_dir = pathlib.Path(__file__).parent.resolve()
-  smollm_goldens = torch.load(current_dir / "smollm_lm_logits.pt")
-  kv_cache_max_len = 1024
-  model = build_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
-  idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
-  tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.int, device="cpu")
-  tokens[0, :4] = idx
-  input_pos = torch.arange(0, kv_cache_max_len, dtype=torch.int)
-  kv = kv_utils.KVCache.from_model_config(model.config)
-  output = model.forward(tokens, input_pos, kv)
-  assert torch.allclose(
-      smollm_goldens, output["logits"][0, idx.shape[1] - 1, :], atol=1e-05
-  )
-
-
-if __name__ == "__main__":
-  input_checkpoint_path = os.path.join(
-      pathlib.Path.home(), "Downloads/llm_data/smollm"
-  )
-  define_and_run(input_checkpoint_path)
ai_edge_torch/generative/examples/smollm/verify.py
ADDED
@@ -0,0 +1,59 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Verifies the reauthored SmolLM-135M model."""
+
+import pathlib
+
+from absl import app
+from absl import flags
+from ai_edge_torch.generative.examples.smollm import smollm
+from ai_edge_torch.generative.utilities import verifier
+import transformers
+
+_PROMPTS = flags.DEFINE_multi_string(
+    "prompts",
+    "What is the meaning of life?",
+    "The input prompts to generate answers.",
+)
+
+
+def main(_):
+  checkpoint = "HuggingFaceTB/SmolLM-135M"
+  verifier.log_msg("Loading the original model from", checkpoint)
+  original_model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint)
+
+  # Locate the cached dir.
+  cached_config_file = transformers.utils.cached_file(
+      checkpoint, transformers.utils.CONFIG_NAME
+  )
+  reauthored_checkpoint = pathlib.Path(cached_config_file).parent
+  verifier.log_msg("Building the reauthored model from", reauthored_checkpoint)
+  reauthored_model = smollm.build_model(reauthored_checkpoint)
+
+  verifier.log_msg("Loading the tokenizer from", checkpoint)
+  tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
+
+  verifier.verify_reauthored_model(
+      original_model=original_model,
+      reauthored_model=reauthored_model,
+      tokenizer=tokenizer,
+      prompts=_PROMPTS.value,
+      atol=1e-04,
+  )
+
+
+if __name__ == "__main__":
+  app.run(main)
ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py
CHANGED
@@ -15,16 +15,12 @@
 
 """Example of building a TinyLlama model."""
 
-import os
-import pathlib
-
 from ai_edge_torch.generative.layers import attention
 from ai_edge_torch.generative.layers import builder
 from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
 import ai_edge_torch.generative.layers.model_config as cfg
 import ai_edge_torch.generative.utilities.loader as loading_utils
-import numpy as np
 import torch
 from torch import nn
 
@@ -179,28 +175,3 @@ def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
   loader.load(model)
   model.eval()
   return model
-
-
-def define_and_run(checkpoint_path: str) -> None:
-  """Instantiates and runs a TinyLlama model."""
-
-  current_dir = pathlib.Path(__file__).parent.resolve()
-  tiny_llama_goldens = torch.load(current_dir / "tiny_llama_lm_logits.pt")
-  kv_cache_max_len = 1024
-  model = build_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
-  idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
-  tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.int, device="cpu")
-  tokens[0, :4] = idx
-  input_pos = torch.arange(0, kv_cache_max_len, dtype=torch.int)
-  kv = kv_utils.KVCache.from_model_config(model.config)
-  output = model.forward(tokens, input_pos, kv)
-  assert torch.allclose(
-      tiny_llama_goldens, output["logits"][0, idx.shape[1] - 1, :], atol=1e-02
-  )
-
-
-if __name__ == "__main__":
-  input_checkpoint_path = os.path.join(
-      pathlib.Path.home(), "Downloads/llm_data/tiny_llama"
-  )
-  define_and_run(input_checkpoint_path)
ai_edge_torch/generative/examples/tiny_llama/verify.py
ADDED
@@ -0,0 +1,61 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Verifies the reauthored TinyLlama-1.1B model."""
+
+import pathlib
+
+from absl import app
+from absl import flags
+from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
+from ai_edge_torch.generative.utilities import verifier
+import transformers
+
+_PROMPTS = flags.DEFINE_multi_string(
+    "prompts",
+    "Show me the program to add 2 and 3.",
+    "The input prompts to generate answers.",
+)
+
+
+def main(_):
+  checkpoint = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
+  verifier.log_msg("Loading the original model from", checkpoint)
+  original_model = transformers.AutoModelForCausalLM.from_pretrained(
+      checkpoint, trust_remote_code=True
+  )
+
+  # Locate the cached dir.
+  cached_config_file = transformers.utils.cached_file(
+      checkpoint, transformers.utils.CONFIG_NAME
+  )
+  reauthored_checkpoint = pathlib.Path(cached_config_file).parent
+  verifier.log_msg("Building the reauthored model from", reauthored_checkpoint)
+  reauthored_model = tiny_llama.build_model(reauthored_checkpoint)
+
+  verifier.log_msg("Loading the tokenizer from", checkpoint)
+  tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
+
+  verifier.verify_reauthored_model(
+      original_model=original_model,
+      reauthored_model=reauthored_model,
+      tokenizer=tokenizer,
+      prompts=_PROMPTS.value,
+      atol=1e-04,
+  )
+
+
+if __name__ == "__main__":
+  app.run(main)
ai_edge_torch/generative/layers/builder.py
CHANGED
@@ -75,7 +75,9 @@ def build_norm(dim: int, config: cfg.NormalizationConfig):
         zero_centered_gamma=config.zero_centered,
     )
   elif config.type == cfg.NormalizationType.LAYER_NORM:
-    return normalization.LayerNorm(
+    return normalization.LayerNorm(
+        dim, config.epsilon, config.enable_hlfb, config.use_input_shape
+    )
   elif config.type == cfg.NormalizationType.GROUP_NORM:
     return normalization.GroupNorm(
         config.group_num, dim, config.epsilon, config.enable_hlfb
    )
ai_edge_torch/generative/layers/model_config.py
CHANGED
@@ -69,6 +69,9 @@ class NormalizationConfig:
   zero_centered: bool = False
   # Number of groups used in group normalization.
   group_num: Optional[float] = None
+  # Whether to use the input shape to determine the dimension of normalization
+  # when type is LAYER_NORM.
+  use_input_shape: bool = True
 
 
 @dataclass
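For illustration, a config that opts a LAYER_NORM layer into weight-shape normalization (the combination the Phi-2 example above now uses) would look like:

    import ai_edge_torch.generative.layers.model_config as cfg

    norm_config = cfg.NormalizationConfig(
        type=cfg.NormalizationType.LAYER_NORM,
        use_input_shape=False,  # normalize over the weight shape, not the input shape
    )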
ai_edge_torch/generative/layers/normalization.py
CHANGED
@@ -78,7 +78,7 @@ class GroupNorm(torch.nn.Module):
       group_num (int): Number of groups to separate the channels into.
       dim (int): Dimension of the input tensor.
       eps (float): A small float value to ensure numerical stability (default:
-        1e-
+        1e-5).
       enable_hlfb (bool): Whether to convert this normalization into a single
         op.
     """
@@ -112,7 +112,13 @@ class GroupNorm(torch.nn.Module):
 
 class LayerNorm(torch.nn.Module):
 
-  def __init__(
+  def __init__(
+      self,
+      dim: int,
+      eps: float = 1e-5,
+      enable_hlfb: bool = False,
+      use_input_shape: bool = True,
+  ):
     """Initialize the LayerNorm layer.
 
     Args:
@@ -121,9 +127,12 @@ class LayerNorm(torch.nn.Module):
         1e-6).
       enable_hlfb (bool): Whether to convert this normalization into a single
         op.
+      use_input_shape (bool): Whether to use the input shape to determine the
+        dimension of normalization (default: True).
     """
     super().__init__()
     self.enable_hlfb = enable_hlfb
+    self.use_input_shape = use_input_shape
     self.eps = eps
     self.weight = torch.nn.Parameter(torch.ones(dim))
     self.bias = torch.nn.Parameter(torch.ones(dim))
@@ -139,19 +148,18 @@ class LayerNorm(torch.nn.Module):
     """
     if self.enable_hlfb:
       return layer_norm_with_hlfb(
-          x,
-          self.weight,
-          self.bias,
-          self.eps,
+          x, self.weight, self.bias, self.eps, self.use_input_shape
       )
+
+    if self.use_input_shape:
+      normalized_shape = x.shape
+      weight = self.weight.broadcast_to(x.shape)
+      bias = self.bias.broadcast_to(x.shape)
     else:
-
-
-
-
-          self.bias.broadcast_to(x.shape),
-          self.eps,
-      )
+      normalized_shape = self.weight.shape
+      weight = self.weight
+      bias = self.bias
+    return F.layer_norm(x, normalized_shape, weight, bias, self.eps)
 
 
 def group_norm_with_hlfb(
@@ -193,6 +201,7 @@ def layer_norm_with_hlfb(
     w: torch.Tensor,
     b: torch.Tensor,
     eps: float,
+    use_input_shape: bool,
 ):
   """Layer Normalization with high-level function boundary enabled.
 
@@ -201,18 +210,20 @@ def layer_norm_with_hlfb(
     w (torch.Tensor): The weight tensor for the normalization.
     b (torch.Tensor): The bias tensor for the normalization.
    eps (float): A small float value to ensure numerical stability.
+    use_input_shape (bool): Whether to use the input shape to determine the
+      dimension of normalization.
 
   Returns:
     The output tensor of Layer Normalization.
   """
   builder = StableHLOCompositeBuilder(name="odml.layer_norm", attr={"eps": eps})
   x, w, b = builder.mark_inputs(x, w, b)
-
-
-
-
-
-
-  )
+  if use_input_shape:
+    normalized_shape = x.shape
+    w = w.broadcast_to(x.shape)
+    b = b.broadcast_to(x.shape)
+  else:
+    normalized_shape = w.shape
+  y = F.layer_norm(x, normalized_shape, w, b, eps=eps)
   y = builder.mark_outputs(y)
   return y
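Both the eager path and the HLFB path above differ only in the normalized_shape handed to F.layer_norm. A minimal standalone sketch of the distinction, assuming a [batch, seq, dim] activation:

    import torch
    import torch.nn.functional as F

    x = torch.randn(1, 4, 8)  # [batch, seq, dim]
    w, b = torch.ones(8), torch.zeros(8)

    # use_input_shape=True: normalize over the whole input shape, broadcasting
    # the per-channel weight and bias to match.
    y_input_shape = F.layer_norm(
        x, x.shape, w.broadcast_to(x.shape), b.broadcast_to(x.shape), 1e-5
    )

    # use_input_shape=False (the new Phi-2 path): normalize over the weight
    # shape, i.e. the last dimension only.
    y_weight_shape = F.layer_norm(x, w.shape, w, b, 1e-5)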
ai_edge_torch/generative/layers/scaled_dot_product_attention.py
CHANGED
@@ -119,15 +119,25 @@ def scaled_dot_product_attention_with_hlfb(
     # Handle the GQA case, where q.shape[1] % k.shape[1] == 0.
     k = k.repeat_interleave(q.shape[1] // k.shape[1], dim=1)
     v = v.repeat_interleave(q.shape[1] // v.shape[1], dim=1)
-
-
-
-
-
-
-
-
-
+  if softcap is None:
+    y = F.scaled_dot_product_attention(
+        q,
+        k,
+        v,
+        attn_mask=mask,
+        dropout_p=0.0,
+        is_causal=mask is None,
+        scale=scale,
+    )
+  else:
+    q.mul_(scale)
+    scores = q @ k.transpose(-1, -2)
+    scores = scores / softcap
+    scores = torch.tanh(scores)
+    scores = scores * softcap
+    scores = scores + mask
+    out = F.softmax(scores.float(), dim=-1).type_as(q)
+    y = torch.matmul(out, v)
 
   result = y.transpose(1, 2)
   result = builder.mark_outputs(result)
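The new else branch implements tanh logit softcapping (as used by models such as Gemma 2): attention scores are squashed into (-softcap, softcap) before the softmax. A standalone sketch of the same math, with illustrative names:

    import torch
    import torch.nn.functional as F

    def softcapped_attention(q, k, v, mask, scale, softcap):
      # Scale queries, compute raw scores, then softly cap them with tanh.
      scores = (q * scale) @ k.transpose(-1, -2)
      scores = torch.tanh(scores / softcap) * softcap
      scores = scores + mask  # additive mask, e.g. 0 or -inf entries
      probs = F.softmax(scores.float(), dim=-1).type_as(q)
      return probs @ v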
ai_edge_torch/generative/layers/unet/blocks_2d.py
CHANGED
@@ -41,22 +41,22 @@ class ResidualBlock2D(nn.Module):
     )
     self.conv_1 = nn.Conv2d(
         config.in_channels,
-        config.
+        config.hidden_channels,
         kernel_size=3,
         stride=1,
         padding=1,
     )
     if config.time_embedding_channels is not None:
       self.time_emb_proj = nn.Linear(
-          config.time_embedding_channels, config.
+          config.time_embedding_channels, config.hidden_channels
       )
     else:
       self.time_emb_proj = None
     self.norm_2 = layers_builder.build_norm(
-        config.
+        config.hidden_channels, config.normalization_config
     )
     self.conv_2 = nn.Conv2d(
-        config.
+        config.hidden_channels,
         config.out_channels,
         kernel_size=3,
         stride=1,
@@ -391,6 +391,7 @@ class DownEncoderBlock2D(nn.Module):
           ResidualBlock2D(
              unet_cfg.ResidualBlock2DConfig(
                  in_channels=input_channels,
+                 hidden_channels=config.out_channels,
                  out_channels=config.out_channels,
                  time_embedding_channels=config.time_embedding_channels,
                  normalization_config=config.normalization_config,
@@ -492,6 +493,7 @@ class UpDecoderBlock2D(nn.Module):
          ResidualBlock2D(
              unet_cfg.ResidualBlock2DConfig(
                  in_channels=input_channels,
+                 hidden_channels=config.out_channels,
                  out_channels=config.out_channels,
                  time_embedding_channels=config.time_embedding_channels,
                  normalization_config=config.normalization_config,
@@ -602,6 +604,7 @@ class SkipUpDecoderBlock2D(nn.Module):
          ResidualBlock2D(
              unet_cfg.ResidualBlock2DConfig(
                  in_channels=resnet_in_channels + res_skip_channels,
+                 hidden_channels=config.out_channels,
                  out_channels=config.out_channels,
                  time_embedding_channels=config.time_embedding_channels,
                  normalization_config=config.normalization_config,
@@ -706,6 +709,7 @@ class MidBlock2D(nn.Module):
         ResidualBlock2D(
             unet_cfg.ResidualBlock2DConfig(
                 in_channels=config.in_channels,
+                hidden_channels=config.in_channels,
                 out_channels=config.in_channels,
                 time_embedding_channels=config.time_embedding_channels,
                 normalization_config=config.normalization_config,
@@ -724,6 +728,7 @@ class MidBlock2D(nn.Module):
         ResidualBlock2D(
             unet_cfg.ResidualBlock2DConfig(
                 in_channels=config.in_channels,
+                hidden_channels=config.in_channels,
                 out_channels=config.in_channels,
                 time_embedding_channels=config.time_embedding_channels,
                 normalization_config=config.normalization_config,
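These hunks all thread the new hidden_channels field through ResidualBlock2DConfig (the field itself is added in unet/model_config.py, per the summary above). It decouples the block's internal width — the conv_1 output, the norm_2 dimension, and the time-embedding projection — from out_channels. A hedged construction sketch; fields beyond those visible in this diff may also be required:

    import ai_edge_torch.generative.layers.model_config as layers_cfg
    import ai_edge_torch.generative.layers.unet.model_config as unet_cfg

    block_config = unet_cfg.ResidualBlock2DConfig(
        in_channels=128,
        hidden_channels=256,  # internal width, previously tied to out_channels
        out_channels=256,
        time_embedding_channels=512,
        normalization_config=layers_cfg.NormalizationConfig(
            type=layers_cfg.NormalizationType.GROUP_NORM, group_num=32
        ),
    )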
ai_edge_torch/generative/test/test_model_conversion_large.py
CHANGED
@@ -96,7 +96,7 @@ class TestModelConversion(googletest.TestCase):
   def test_gemma2(self):
     config = gemma2.get_fake_model_config()
     pytorch_model = gemma2.Gemma2(config).eval()
-    self._test_model(config, pytorch_model, "prefill", atol=1e-
+    self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5)
 
   @googletest.skipIf(
       ai_edge_config.Config.use_torch_xla,
ai_edge_torch/generative/utilities/stable_diffusion_loader.py
CHANGED
@@ -412,6 +412,7 @@ class BaseLoader(loader.ModelLoader):
   ):
     residual_block_config = unet_config.ResidualBlock2DConfig(
         in_channels=config.in_channels,
+        hidden_channels=config.in_channels,
        out_channels=config.in_channels,
        time_embedding_channels=config.time_embedding_channels,
        normalization_config=config.normalization_config,
@@ -466,6 +467,7 @@ class BaseLoader(loader.ModelLoader):
          f"{converted_state_param_prefix}.resnets.{i}",
          unet_config.ResidualBlock2DConfig(
              in_channels=input_channels,
+             hidden_channels=config.out_channels,
              out_channels=config.out_channels,
              time_embedding_channels=config.time_embedding_channels,
              normalization_config=config.normalization_config,
@@ -508,6 +510,7 @@ class BaseLoader(loader.ModelLoader):
          f"{converted_state_param_prefix}.resnets.{i}",
          unet_config.ResidualBlock2DConfig(
              in_channels=input_channels,
+             hidden_channels=config.out_channels,
              out_channels=config.out_channels,
              time_embedding_channels=config.time_embedding_channels,
              normalization_config=config.normalization_config,
@@ -554,6 +557,7 @@ class BaseLoader(loader.ModelLoader):
          f"{converted_state_param_prefix}.resnets.{i}",
          unet_config.ResidualBlock2DConfig(
              in_channels=resnet_in_channels + res_skip_channels,
+             hidden_channels=config.out_channels,
              out_channels=config.out_channels,
              time_embedding_channels=config.time_embedding_channels,
              normalization_config=config.normalization_config,
ai_edge_torch/generative/utilities/verifier.py
ADDED
@@ -0,0 +1,200 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Common utility functions to verify the reauthored models."""
+
+import datetime
+from typing import List
+
+from ai_edge_torch.generative.layers import kv_cache as kv_utils
+import numpy as np
+import torch
+
+
+def log_msg(*args):
+  print("[%s]" % datetime.datetime.now(), *args)
+
+
+def forward(
+    model: torch.nn.Module,
+    tokens: torch.Tensor,
+    kv_cache: kv_utils.KVCache,
+) -> tuple[torch.Tensor, kv_utils.KVCache]:
+  """Forwards the model reauthored with ai_edge_torch Generative API.
+
+  Args:
+    model (torch.nn.Module): The model to forward. It should be a model built
+      with ai_edge_torch Generative API.
+    tokens (torch.Tensor): The input tokens to forward.
+    kv_cache (KVCache): The KV cache to forward.
+
+  Returns:
+    The output logits and the updated KV cache.
+  """
+  input_pos = torch.arange(0, tokens.shape[1], dtype=torch.int)
+  output = model.forward(tokens, input_pos, kv_cache)
+  return output["logits"], output["kv_cache"]
+
+
+def generate(
+    model: torch.nn.Module, prompts: torch.Tensor, response_len: int
+) -> torch.Tensor:
+  """Generates the response to the prompts.
+
+  It appends tokens output by the model to the prompts and feeds them back to
+  the model up to response_len.
+
+  Args:
+    model (torch.nn.Module): The model to generate. It should be a model built
+      with ai_edge_torch Generative API.
+    prompts (torch.Tensor): The prompts to generate.
+    response_len (int): The number of tokens to generate.
+
+  Returns:
+    The generated tokens.
+  """
+  input_ids = prompts[0].int().tolist()
+  kv_cache = kv_utils.KVCache.from_model_config(model.config)
+  for _ in range(response_len - len(input_ids)):
+    logits, kv_cache = forward(model, torch.tensor([input_ids]), kv_cache)
+    generated_token = logits[0][-1].argmax().item()
+    input_ids.append(generated_token)
+  return torch.tensor([input_ids])
+
+
+def verify_with_input_ids(
+    original_model: torch.nn.Module,
+    reauthored_model: torch.nn.Module,
+    input_ids: torch.Tensor = torch.from_numpy(np.array([[1, 2, 3, 4]])).int(),
+    kv_cache_max_len: int = 1024,
+    rtol: float = 1e-05,
+    atol: float = 1e-05,
+) -> bool:
+  """Verifies if the reauthored model generates the same output as the original.
+
+  It compares only one output from the original and the reauthored model.
+
+  Args:
+    original_model (torch.nn.Module): The original model.
+    reauthored_model (torch.nn.Module): The model reauthored with ai_edge_torch
+      Generative API.
+    input_ids (torch.Tensor): The input token IDs to forward.
+    kv_cache_max_len (int): The maximum sequence length of the KV cache.
+    rtol (float): The relative tolerance for the comparison.
+    atol (float): The absolute tolerance for the comparison.
+
+  Returns:
+    True if the reauthored model generates the same output as the original.
+  """
+  tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.int, device="cpu")
+  input_ids_len = input_ids.shape[1]
+  tokens[0, :input_ids_len] = input_ids
+
+  log_msg("Forwarding the original model...")
+  outputs_original = original_model.forward(tokens)
+  logits_original = outputs_original.logits[0, input_ids_len - 1, :]
+  log_msg("logits_original: ", logits_original)
+
+  log_msg("Forwarding the reauthored model...")
+  kv_cache = kv_utils.KVCache.from_model_config(reauthored_model.config)
+  outputs_reauthored = forward(reauthored_model, tokens, kv_cache)
+  logits_reauthored = outputs_reauthored[0][0, input_ids_len - 1, :]
+  log_msg("logits_reauthored:", logits_reauthored)
+
+  return torch.allclose(
+      logits_original, logits_reauthored, rtol=rtol, atol=atol
+  )
+
+
+def verify_model_with_prompts(
+    original_model: torch.nn.Module,
+    reauthored_model: torch.nn.Module,
+    tokenizer: torch.nn.Module,
+    prompts: str,
+) -> bool:
+  """Verifies if the reauthored model generates the same answer as the original.
+
+  It compares an answer, i.e. multiple continuous outputs generated by the
+  original and the reauthored model.
+
+  Args:
+    original_model (torch.nn.Module): The original model.
+    reauthored_model (torch.nn.Module): The model reauthored with ai_edge_torch
+      Generative API.
+    tokenizer (torch.nn.Module): The tokenizer.
+    prompts (str): The input prompts to generate answers.
+
+  Returns:
+    True if the reauthored model generates the same answer as the original.
+  """
+  prompt_tokens = tokenizer.encode(prompts, return_tensors="pt")
+
+  log_msg("Generating answer with the original model...")
+  outputs_original = original_model.generate(prompt_tokens)
+  response_original = tokenizer.decode(outputs_original[0])
+  log_msg("outputs_from_original_model: [[", response_original, "]]")
+
+  log_msg("Generating answer with the reauthored model...")
+  generate_len = len(outputs_original[0])
+  outputs_reauthored = generate(reauthored_model, prompt_tokens, generate_len)
+  response_reauthored = tokenizer.decode(outputs_reauthored[0])
+  log_msg("outputs from reauthored model: [[", response_reauthored, "]]")
+
+  return response_original == response_reauthored
+
+
+def verify_reauthored_model(
+    original_model: torch.nn.Module,
+    reauthored_model: torch.nn.Module,
+    tokenizer: torch.nn.Module,
+    prompts: List[str],
+    rtol: float = 1e-05,
+    atol: float = 1e-05,
+):
+  """Verifies the reauthored model against the original model.
+
+  It verifies the reauthored model with two methods:
+  1. It compares the output of the original and the reauthored model with an
+     arbitrary input.
+  2. It compares the answer generated by the original and the reauthored model
+     with a prompt.
+
+  It prints out "PASS" or "FAILED" to the console.
+
+  Args:
+    original_model (torch.nn.Module): The original model.
+    reauthored_model (torch.nn.Module): The model reauthored with ai_edge_torch
+      Generative API.
+    tokenizer (torch.nn.Module): The tokenizer.
+    prompts (List[str]): List of the input prompts to generate answers.
+    rtol (float): The relative tolerance for the comparison.
+    atol (float): The absolute tolerance for the comparison.
+  """
+  log_msg("Verifying the reauthored model with an arbitrary input...")
+  if verify_with_input_ids(
+      original_model, reauthored_model, rtol=rtol, atol=atol
+  ):
+    log_msg("PASS")
+  else:
+    log_msg("FAILED")
+
+  for p in prompts:
+    log_msg("Verifying the reauthored model with prompts:", p)
+    if verify_model_with_prompts(
+        original_model, reauthored_model, tokenizer, p
+    ):
+      log_msg("PASS")
+    else:
+      log_msg("FAILED")
ai_edge_torch/version.py
CHANGED

{ai_edge_torch_nightly-0.3.0.dev20240918.dist-info → ai_edge_torch_nightly-0.3.0.dev20240919.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: ai-edge-torch-nightly
-Version: 0.3.0.
+Version: 0.3.0.dev20240919
 Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
 Home-page: https://github.com/google-ai-edge/ai-edge-torch
 Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
{ai_edge_torch_nightly-0.3.0.dev20240918.dist-info → ai_edge_torch_nightly-0.3.0.dev20240919.dist-info}/RECORD
CHANGED
@@ -3,7 +3,7 @@ ai_edge_torch/config.py,sha256=FMWeCH2b7HYILBvaI1iZNnYCO4WAhDOwBZBmIE-xrF0,909
 ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
 ai_edge_torch/fx_pass_base.py,sha256=D86Gw3pIRcpnTebUPKlnPbPGJae1S6Fw4DZZ3ZkD0zw,3730
 ai_edge_torch/model.py,sha256=NYV6Mkaje_ditIEI_s_7nLP_-8i4kbGM8nRzieVkbUI,5397
-ai_edge_torch/version.py,sha256=
+ai_edge_torch/version.py,sha256=N5hYc9s2RU44J1_oe0UfJhTFo0d4JvMlKvxNlYtK0GI,706
 ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/_convert/conversion.py,sha256=5uPwHhmc6kwiIz-CqaiHDejf2SOWMHrb-rYEHm69wKc,3801
 ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -45,13 +45,16 @@ ai_edge_torch/generative/examples/gemma/gemma.py,sha256=hjpSPzEjPHuxwRJ-vHHtCCf2
 ai_edge_torch/generative/examples/gemma/gemma2.py,sha256=gCLOti-4xHunjphNBbx9St6faRteSakm8Oex6R1Xek0,10272
 ai_edge_torch/generative/examples/openelm/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/openelm/convert_to_tflite.py,sha256=HnqP3te1Qvy4SKaaqPrsG05eojiKDJShp4H3jPC9tYg,2023
-ai_edge_torch/generative/examples/openelm/openelm.py,sha256=
+ai_edge_torch/generative/examples/openelm/openelm.py,sha256=gGkHELNrt4xqnu11fCh3sJbZ7OsPyvoiF1J1aKCs5r8,7532
+ai_edge_torch/generative/examples/openelm/verify.py,sha256=2qFdyLfcefdA3s1KQ-ZGWo4XReMXkEQAvpUEyJE5iqM,2057
 ai_edge_torch/generative/examples/phi/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/phi/convert_to_tflite.py,sha256=viIkbAgknE3zxavTZtib87cMIG2_-jJXtxJPcmB2pGQ,2007
-ai_edge_torch/generative/examples/phi/phi2.py,sha256=
+ai_edge_torch/generative/examples/phi/phi2.py,sha256=YwAszA53aOjvaMJ5wua2-5rP79N21Un_Y5yBCfFSYNU,6189
+ai_edge_torch/generative/examples/phi/verify.py,sha256=R9BjOArnn-3svoIApmP1NwO47n8KIFikOF0_MEgTOa4,1770
 ai_edge_torch/generative/examples/smollm/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/smollm/convert_to_tflite.py,sha256=86hvBleyFXWmwy3Ke5J7x7WcCtG20D2kiBNrodE0R4w,2017
-ai_edge_torch/generative/examples/smollm/smollm.py,sha256=
+ai_edge_torch/generative/examples/smollm/smollm.py,sha256=hyhMk-b5762Q2xmjdD47g85dcbBSNJXNPIsifm1DRto,3239
+ai_edge_torch/generative/examples/smollm/verify.py,sha256=JzidfVMMFDXzDdwn7ToDPuMo6eaoENNZGpEzX3f61Jk,1976
 ai_edge_torch/generative/examples/stable_diffusion/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/stable_diffusion/attention.py,sha256=kDWG6MlIGa89zC5KSRcJlw2c4ITuw8KcchtfmF55f4g,3545
 ai_edge_torch/generative/examples/stable_diffusion/clip.py,sha256=tL6w2dr6VP66IXjSKo9StDNP-wl0RO3fh6dIliiYlFA,4656
@@ -76,23 +79,24 @@ ai_edge_torch/generative/examples/test_models/toy_model.py,sha256=QyLeCqDnk71Wvv
 ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py,sha256=e_Kqm5dStSrNE9_aIYC-vYJRsqLn-hJVkmR4QjYqZI0,5913
 ai_edge_torch/generative/examples/tiny_llama/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py,sha256=Yg5G1LePoryeTib35lqICqaDW6foLUzSRgwJ2FlklIw,2040
-ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py,sha256=
+ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py,sha256=tlWpa7Aun3u3w5b-9EBtW7olhmSf8W-tn5bKUIwC-ys,6044
+ai_edge_torch/generative/examples/tiny_llama/verify.py,sha256=jld5PlGOQXMIWc1WoDYL_1nnsoVzRfrg-WgnsxRgaEU,2041
 ai_edge_torch/generative/fx_passes/__init__.py,sha256=jrzCB3ZyY_t5jJM1e2Czdt3DjAIL43R0_a-T-I7wOzw,1155
 ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py,sha256=hhxSQvkDMv0isZJhmuLiod66ZODaJ8uSPSVTJVHBabQ,1931
 ai_edge_torch/generative/layers/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/layers/attention.py,sha256=Z0Y_G8IG0LmvLX2u9D8__Fkr22szB-az6wMNnZpzhkA,13233
 ai_edge_torch/generative/layers/attention_utils.py,sha256=68GXGR2HSWBFViTxX7cHifzVG-kcLS2IL2tQJPIpupg,7344
-ai_edge_torch/generative/layers/builder.py,sha256=
+ai_edge_torch/generative/layers/builder.py,sha256=toT9Tl1x9o5KbG-eGOEViUr4fd_4f-XLZdMQT0Ae5_8,5130
 ai_edge_torch/generative/layers/feed_forward.py,sha256=dfS1psdmomgs4EbwzkYyV_xx1xl3P1lU-3GoS8m0Avw,4221
 ai_edge_torch/generative/layers/kv_cache.py,sha256=2El7kZYnQRCRcVc63xgiAdBh9oVOksDu35p9XggvaGE,6148
-ai_edge_torch/generative/layers/model_config.py,sha256=
-ai_edge_torch/generative/layers/normalization.py,sha256=
+ai_edge_torch/generative/layers/model_config.py,sha256=d0Y-EFb4Rr7iLZ4Bsdf1i92KuhY1BXRqyeUN2kuu510,6923
+ai_edge_torch/generative/layers/normalization.py,sha256=l_36uFdruJwqqyubnBTM0M-iGiJfeFafyXKPPK8KHVo,6713
 ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=CZqOoibLcHvUgrgaIIWAlmk3XgE2inzx340MN-npLoU,1347
-ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=
+ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=gXxh3papKy4FBpGEX7VyZ7rZ1Js6aHK70Q6DKrVSckY,4154
 ai_edge_torch/generative/layers/unet/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/layers/unet/blocks_2d.py,sha256=
+ai_edge_torch/generative/layers/unet/blocks_2d.py,sha256=c8rtlfDaeKmUfiiTKPmQhNW-U5vW9jFB2pPPcvT6qsc,27527
 ai_edge_torch/generative/layers/unet/builder.py,sha256=zAqWXdimmMrQRhmE_t9XkS68mh6PSrzwb-2NZZXrR5I,1901
-ai_edge_torch/generative/layers/unet/model_config.py,sha256=
+ai_edge_torch/generative/layers/unet/model_config.py,sha256=8ze9kVWMuyZVQcgK7hWYw9TM1W9lXD-2j0iMHlxoGX4,9267
 ai_edge_torch/generative/quantize/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/quantize/example.py,sha256=n_YFFP3dpKjeNKYZicDGL5LqtjqwhYEIaDrC6-Ci2vE,1539
 ai_edge_torch/generative/quantize/quant_attrs.py,sha256=n1Fm8BFC8gJa_oiwwAOOghJyHtOXYZ4q-5ZRy4pHrIw,1957
@@ -104,14 +108,15 @@ ai_edge_torch/generative/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudj
 ai_edge_torch/generative/test/test_kv_cache.py,sha256=W6Bh0gYDzmwb0j9HdD5_D7Z7FPToP2HSyFrmwIXuFqo,3793
 ai_edge_torch/generative/test/test_loader.py,sha256=8y74ChO3CZCfEi1eCf3-w47kRgAI4qPYCXpi8rTQXMA,3378
 ai_edge_torch/generative/test/test_model_conversion.py,sha256=DBlqxW2IT-dZYzEfOMAp86Wtqiu6kgSWZ9BKZR1Clrw,5467
-ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=
+ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=dUYFarOldejqbMpa0j0vIDvXlWPAancuI8di3XkGxm8,4498
 ai_edge_torch/generative/test/test_quantize.py,sha256=8geJhKwYBU20m0mdGPD1BUFwQ0lZKNtCB04SOLO18y4,5980
 ai_edge_torch/generative/test/utils.py,sha256=YvEhO2HIj1LkBs5du1UxY-cGRW9HMyAYsOUhgsTrTpA,1796
 ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
 ai_edge_torch/generative/utilities/converter.py,sha256=MQUg2ZLmfk_2csWmQWKD_II0bXq4X3McI5i-qWraieE,2987
 ai_edge_torch/generative/utilities/loader.py,sha256=b9iotIhVDX-Zc9XjIDUaLxnV395AyBnkQe3dV5YA7Co,13297
-ai_edge_torch/generative/utilities/stable_diffusion_loader.py,sha256=
+ai_edge_torch/generative/utilities/stable_diffusion_loader.py,sha256=dqPD9qRXEWtU3ombslOC-BE2l_dMwHoCNu7NsIJhsso,36158
 ai_edge_torch/generative/utilities/t5_loader.py,sha256=tEsfy8-ymzbbjOIc-oesXF3yGyyWtJgFXn2s7VOavt8,16961
+ai_edge_torch/generative/utilities/verifier.py,sha256=QAv1uJdI5o1yfphr_DpzxhZswKa4VG3JZUpqbCCWKMk,7114
 ai_edge_torch/hlfb/__init__.py,sha256=sH4um75na-O8tzxN6chFyp6Y4xnexsE7kUQpZySv6dE,735
 ai_edge_torch/hlfb/mark_pattern/__init__.py,sha256=cjTprggj_cuktSCm7-A25e7Shop3k63ylp7sdZmtZ8o,4790
 ai_edge_torch/hlfb/mark_pattern/passes.py,sha256=pjkKcI1nHECPluAt87cFBrt1DP0f3ge7rHq1NhCkBIE,1936
@@ -158,8 +163,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
 ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
 ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
-ai_edge_torch_nightly-0.3.0.
-ai_edge_torch_nightly-0.3.0.
-ai_edge_torch_nightly-0.3.0.
-ai_edge_torch_nightly-0.3.0.
-ai_edge_torch_nightly-0.3.0.
+ai_edge_torch_nightly-0.3.0.dev20240919.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
+ai_edge_torch_nightly-0.3.0.dev20240919.dist-info/METADATA,sha256=NkHYIOMz-5DNKJuSQ8wE-3Nz1R6a9YZ59M-Nq8sAnJg,1859
+ai_edge_torch_nightly-0.3.0.dev20240919.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
+ai_edge_torch_nightly-0.3.0.dev20240919.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
+ai_edge_torch_nightly-0.3.0.dev20240919.dist-info/RECORD,,
File without changes
File without changes