ai-edge-torch-nightly 0.3.0.dev20240916__py3-none-any.whl → 0.3.0.dev20240919__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
- ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +36 -56
- ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +36 -56
- ai_edge_torch/generative/examples/openelm/convert_to_tflite.py +36 -56
- ai_edge_torch/generative/examples/openelm/openelm.py +0 -29
- ai_edge_torch/generative/examples/openelm/verify.py +61 -0
- ai_edge_torch/generative/examples/phi/convert_to_tflite.py +36 -56
- ai_edge_torch/generative/examples/phi/phi2.py +4 -31
- ai_edge_torch/generative/examples/phi/verify.py +53 -0
- ai_edge_torch/generative/examples/smollm/convert_to_tflite.py +36 -56
- ai_edge_torch/generative/examples/smollm/smollm.py +0 -30
- ai_edge_torch/generative/examples/smollm/verify.py +59 -0
- ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +6 -0
- ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +36 -56
- ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +0 -29
- ai_edge_torch/generative/examples/tiny_llama/verify.py +61 -0
- ai_edge_torch/generative/layers/attention.py +8 -4
- ai_edge_torch/generative/layers/builder.py +3 -1
- ai_edge_torch/generative/layers/model_config.py +3 -0
- ai_edge_torch/generative/layers/normalization.py +31 -20
- ai_edge_torch/generative/layers/scaled_dot_product_attention.py +19 -9
- ai_edge_torch/generative/layers/unet/blocks_2d.py +11 -4
- ai_edge_torch/generative/layers/unet/model_config.py +3 -0
- ai_edge_torch/generative/test/test_model_conversion_large.py +1 -1
- ai_edge_torch/generative/utilities/converter.py +82 -0
- ai_edge_torch/generative/utilities/stable_diffusion_loader.py +10 -0
- ai_edge_torch/generative/utilities/verifier.py +200 -0
- ai_edge_torch/odml_torch/lowerings/_basic.py +14 -4
- ai_edge_torch/odml_torch/lowerings/registry.py +1 -0
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20240916.dist-info → ai_edge_torch_nightly-0.3.0.dev20240919.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20240916.dist-info → ai_edge_torch_nightly-0.3.0.dev20240919.dist-info}/RECORD +34 -28
- {ai_edge_torch_nightly-0.3.0.dev20240916.dist-info → ai_edge_torch_nightly-0.3.0.dev20240919.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240916.dist-info → ai_edge_torch_nightly-0.3.0.dev20240919.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240916.dist-info → ai_edge_torch_nightly-0.3.0.dev20240919.dist-info}/top_level.txt +0 -0
```diff
--- a/ai_edge_torch/generative/examples/smollm/convert_to_tflite.py
+++ b/ai_edge_torch/generative/examples/smollm/convert_to_tflite.py
@@ -18,69 +18,49 @@
 import os
 import pathlib
 
-import ai_edge_torch
+from absl import app
+from absl import flags
 from ai_edge_torch.generative.examples.smollm import smollm
-from ai_edge_torch.generative.layers import kv_cache as kv_utils
-from ai_edge_torch.generative.quantize import quant_recipes
-import torch
+from ai_edge_torch.generative.utilities import converter
 
+_CHECKPOINT_PATH = flags.DEFINE_string(
+    'checkpoint_path',
+    os.path.join(pathlib.Path.home(), 'Downloads/llm_data/smollm'),
+    'The path to the model checkpoint, or directory holding the checkpoint.',
+)
+_TFLITE_PATH = flags.DEFINE_string(
+    'tflite_path',
+    '/tmp/smollm_q8_seq512_ekv1024.tflite',
+    'The tflite file path to export.',
+)
+_PREFILL_SEQ_LEN = flags.DEFINE_integer(
+    'prefill_seq_len',
+    512,
+    'The maximum size of prefill input tensor.',
+)
+_KV_CACHE_MAX_LEN = flags.DEFINE_integer(
+    'kv_cache_max_len',
+    1024,
+    'The maximum size of KV cache buffer, including both prefill and decode.',
+)
+_QUANTIZE = flags.DEFINE_bool(
+    'quantize',
+    True,
+    'Whether the model should be quantized.',
+)
 
-def convert_smollm_to_tflite(
-    checkpoint_path: str,
-    prefill_seq_len: int = 512,
-    kv_cache_max_len: int = 1024,
-    quantize: bool = True,
-):
-  """Converts SmolLM model to multi-signature tflite model.
 
-  Args:
-      checkpoint_path (str): The filepath to the model checkpoint, or directory
-        holding the checkpoint.
-      prefill_seq_len (int, optional): The maximum size of prefill input tensor.
-        Defaults to 512.
-      kv_cache_max_len (int, optional): The maximum size of KV cache buffer,
-        including both prefill and decode. Defaults to 1024.
-      quantize (bool, optional): Whether the model should be quanized. Defaults
-        to True.
-  """
+def main(_):
   pytorch_model = smollm.build_model(
-      checkpoint_path, kv_cache_max_len=kv_cache_max_len
+      _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
-  # Tensors used to trace the model graph during conversion.
-  prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.int)
-  prefill_input_pos = torch.arange(0, prefill_seq_len, dtype=torch.int)
-  decode_token = torch.tensor([[0]], dtype=torch.int)
-  decode_input_pos = torch.tensor([0], dtype=torch.int)
-  kv = kv_utils.KVCache.from_model_config(pytorch_model.config)
-
-  quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
-  edge_model = (
-      ai_edge_torch.signature(
-          'prefill',
-          pytorch_model,
-          sample_kwargs={
-              'tokens': prefill_tokens,
-              'input_pos': prefill_input_pos,
-              'kv_cache': kv,
-          },
-      )
-      .signature(
-          'decode',
-          pytorch_model,
-          sample_kwargs={
-              'tokens': decode_token,
-              'input_pos': decode_input_pos,
-              'kv_cache': kv,
-          },
-      )
-      .convert(quant_config=quant_config)
-  )
-  quant_suffix = 'q8' if quantize else 'f32'
-  edge_model.export(
-      f'/tmp/smollm_{quant_suffix}_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
+  converter.convert_to_tflite(
+      pytorch_model,
+      tflite_path=_TFLITE_PATH.value,
+      prefill_seq_len=_PREFILL_SEQ_LEN.value,
+      quantize=_QUANTIZE.value,
   )
 
 
 if __name__ == '__main__':
-  path = os.path.join(pathlib.Path.home(), 'Downloads/llm_data/smollm')
-  convert_smollm_to_tflite(path)
+  app.run(main)
```
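The per-model conversion boilerplate is replaced by a single call into the new shared `converter.convert_to_tflite` helper (82 lines added in `ai_edge_torch/generative/utilities/converter.py`, whose body is not shown in this diff). Judging from the call site and from the inline code it replaces, a minimal sketch of such a helper might look like the following; the tracing tensors and the quantization recipe are assumptions carried over from the removed code, not the actual implementation:

```python
# Hypothetical sketch of the shared helper now in
# ai_edge_torch/generative/utilities/converter.py, reconstructed from the
# removed per-model code above; the real implementation may differ.
import ai_edge_torch
import torch
from ai_edge_torch.generative.layers import kv_cache as kv_utils
from ai_edge_torch.generative.quantize import quant_recipes


def convert_to_tflite(
    pytorch_model: torch.nn.Module,
    tflite_path: str,
    prefill_seq_len: int = 512,
    quantize: bool = True,
):
  """Exports a model with 'prefill' and 'decode' signatures (sketch only)."""
  # Tracing tensors, mirroring the removed inline code above.
  prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.int)
  prefill_input_pos = torch.arange(0, prefill_seq_len, dtype=torch.int)
  decode_token = torch.tensor([[0]], dtype=torch.int)
  decode_input_pos = torch.tensor([0], dtype=torch.int)
  kv = kv_utils.KVCache.from_model_config(pytorch_model.config)

  quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
  edge_model = (
      ai_edge_torch.signature(
          'prefill',
          pytorch_model,
          sample_kwargs={
              'tokens': prefill_tokens,
              'input_pos': prefill_input_pos,
              'kv_cache': kv,
          },
      )
      .signature(
          'decode',
          pytorch_model,
          sample_kwargs={
              'tokens': decode_token,
              'input_pos': decode_input_pos,
              'kv_cache': kv,
          },
      )
      .convert(quant_config=quant_config)
  )
  edge_model.export(tflite_path)
```

Each example script now only builds its model and delegates; the same refactor repeats for the gemma, openelm, phi, and tiny_llama examples in this release.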
```diff
--- a/ai_edge_torch/generative/examples/smollm/smollm.py
+++ b/ai_edge_torch/generative/examples/smollm/smollm.py
@@ -16,15 +16,10 @@
 """Example of building a SmolLM model."""
 
 import copy
-import os
-import pathlib
 
 from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
-from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.model_config as cfg
 import ai_edge_torch.generative.utilities.loader as loading_utils
-import numpy as np
-import torch
 from torch import nn
 
 TENSOR_NAMES = copy.copy(tiny_llama.TENSOR_NAMES)
@@ -104,28 +99,3 @@ def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
   loader.load(model, strict=False)
   model.eval()
   return model
-
-
-def define_and_run(checkpoint_path: str) -> None:
-  """Instantiates and runs a SmolLM model."""
-
-  current_dir = pathlib.Path(__file__).parent.resolve()
-  smollm_goldens = torch.load(current_dir / "smollm_lm_logits.pt")
-  kv_cache_max_len = 1024
-  model = build_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
-  idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
-  tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.int, device="cpu")
-  tokens[0, :4] = idx
-  input_pos = torch.arange(0, kv_cache_max_len, dtype=torch.int)
-  kv = kv_utils.KVCache.from_model_config(model.config)
-  output = model.forward(tokens, input_pos, kv)
-  assert torch.allclose(
-      smollm_goldens, output["logits"][0, idx.shape[1] - 1, :], atol=1e-05
-  )
-
-
-if __name__ == "__main__":
-  input_checkpoint_path = os.path.join(
-      pathlib.Path.home(), "Downloads/llm_data/smollm"
-  )
-  define_and_run(input_checkpoint_path)
```
```diff
--- /dev/null
+++ b/ai_edge_torch/generative/examples/smollm/verify.py
@@ -0,0 +1,59 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Verifies the reauthored SmolLM-135M model."""
+
+import pathlib
+
+from absl import app
+from absl import flags
+from ai_edge_torch.generative.examples.smollm import smollm
+from ai_edge_torch.generative.utilities import verifier
+import transformers
+
+_PROMPTS = flags.DEFINE_multi_string(
+    "prompts",
+    "What is the meaning of life?",
+    "The input prompts to generate answers.",
+)
+
+
+def main(_):
+  checkpoint = "HuggingFaceTB/SmolLM-135M"
+  verifier.log_msg("Loading the original model from", checkpoint)
+  original_model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint)
+
+  # Locate the cached dir.
+  cached_config_file = transformers.utils.cached_file(
+      checkpoint, transformers.utils.CONFIG_NAME
+  )
+  reauthored_checkpoint = pathlib.Path(cached_config_file).parent
+  verifier.log_msg("Building the reauthored model from", reauthored_checkpoint)
+  reauthored_model = smollm.build_model(reauthored_checkpoint)
+
+  verifier.log_msg("Loading the tokenizer from", checkpoint)
+  tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
+
+  verifier.verify_reauthored_model(
+      original_model=original_model,
+      reauthored_model=reauthored_model,
+      tokenizer=tokenizer,
+      prompts=_PROMPTS.value,
+      atol=1e-04,
+  )
+
+
+if __name__ == "__main__":
+  app.run(main)
```
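The removed `define_and_run` golden-logit checks are superseded by these `verify.py` scripts, which compare the reauthored model against the original `transformers` checkpoint at runtime (the shared logic lives in the new 200-line `utilities/verifier.py`, not shown in this diff); the script is run directly, e.g. `python verify.py --prompts="What is the meaning of life?"`. A minimal sketch of the kind of check `verify_reauthored_model` presumably performs, given the `atol` argument at the call site; the name below is hypothetical, not the utility's real API:

```python
# Hypothetical sketch only: utilities/verifier.py is not shown in this diff.
# The atol=1e-04 argument above suggests an element-wise logits comparison
# along these lines.
import torch


def logits_close(original_logits: torch.Tensor,
                 reauthored_logits: torch.Tensor,
                 atol: float = 1e-4) -> bool:
  # True when every logit agrees within the absolute tolerance.
  return torch.allclose(original_logits, reauthored_logits, atol=atol)
```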
```diff
--- a/ai_edge_torch/generative/examples/stable_diffusion/diffusion.py
+++ b/ai_edge_torch/generative/examples/stable_diffusion/diffusion.py
@@ -336,6 +336,8 @@ class Diffusion(nn.Module):
             cross_attention_block_config=unet_cfg.CrossAttentionBlock2DConfig(
                 query_dim=output_channel,
                 cross_dim=config.transformer_cross_attention_dim,
+                hidden_dim=output_channel,
+                output_dim=output_channel,
                 attention_batch_size=config.transformer_batch_size,
                 normalization_config=config.transformer_norm_config,
                 attention_config=build_attention_config(
@@ -406,6 +408,8 @@ class Diffusion(nn.Module):
             cross_attention_block_config=unet_cfg.CrossAttentionBlock2DConfig(
                 query_dim=mid_block_channels,
                 cross_dim=config.transformer_cross_attention_dim,
+                hidden_dim=mid_block_channels,
+                output_dim=mid_block_channels,
                 attention_batch_size=config.transformer_batch_size,
                 normalization_config=config.transformer_norm_config,
                 attention_config=build_attention_config(
@@ -477,6 +481,8 @@ class Diffusion(nn.Module):
             cross_attention_block_config=unet_cfg.CrossAttentionBlock2DConfig(
                 query_dim=output_channel,
                 cross_dim=config.transformer_cross_attention_dim,
+                hidden_dim=output_channel,
+                output_dim=output_channel,
                 attention_batch_size=config.transformer_batch_size,
                 normalization_config=config.transformer_norm_config,
                 attention_config=build_attention_config(
```
```diff
--- a/ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py
+++ b/ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py
@@ -18,69 +18,49 @@
 import os
 import pathlib
 
-import ai_edge_torch
+from absl import app
+from absl import flags
 from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
-from ai_edge_torch.generative.layers import kv_cache as kv_utils
-from ai_edge_torch.generative.quantize import quant_recipes
-import torch
+from ai_edge_torch.generative.utilities import converter
 
+_CHECKPOINT_PATH = flags.DEFINE_string(
+    'checkpoint_path',
+    os.path.join(pathlib.Path.home(), 'Downloads/llm_data/tiny_llama'),
+    'The path to the model checkpoint, or directory holding the checkpoint.',
+)
+_TFLITE_PATH = flags.DEFINE_string(
+    'tflite_path',
+    '/tmp/tiny_llama_q8_seq512_ekv1024.tflite',
+    'The tflite file path to export.',
+)
+_PREFILL_SEQ_LEN = flags.DEFINE_integer(
+    'prefill_seq_len',
+    512,
+    'The maximum size of prefill input tensor.',
+)
+_KV_CACHE_MAX_LEN = flags.DEFINE_integer(
+    'kv_cache_max_len',
+    1024,
+    'The maximum size of KV cache buffer, including both prefill and decode.',
+)
+_QUANTIZE = flags.DEFINE_bool(
+    'quantize',
+    True,
+    'Whether the model should be quantized.',
+)
 
-def convert_tiny_llama_to_tflite(
-    checkpoint_path: str,
-    prefill_seq_len: int = 512,
-    kv_cache_max_len: int = 1024,
-    quantize: bool = True,
-):
-  """Converts TinyLlama model to multi-signature tflite model.
 
-  Args:
-      checkpoint_path (str): The filepath to the model checkpoint, or directory
-        holding the checkpoint.
-      prefill_seq_len (int, optional): The maximum size of prefill input tensor.
-        Defaults to 512.
-      kv_cache_max_len (int, optional): The maximum size of KV cache buffer,
-        including both prefill and decode. Defaults to 1024.
-      quantize (bool, optional): Whether the model should be quanized. Defaults
-        to True.
-  """
+def main(_):
   pytorch_model = tiny_llama.build_model(
-      checkpoint_path, kv_cache_max_len=kv_cache_max_len
+      _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
-  # Tensors used to trace the model graph during conversion.
-  prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.int)
-  prefill_input_pos = torch.arange(0, prefill_seq_len, dtype=torch.int)
-  decode_token = torch.tensor([[0]], dtype=torch.int)
-  decode_input_pos = torch.tensor([0], dtype=torch.int)
-  kv = kv_utils.KVCache.from_model_config(pytorch_model.config)
-
-  quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
-  edge_model = (
-      ai_edge_torch.signature(
-          'prefill',
-          pytorch_model,
-          sample_kwargs={
-              'tokens': prefill_tokens,
-              'input_pos': prefill_input_pos,
-              'kv_cache': kv,
-          },
-      )
-      .signature(
-          'decode',
-          pytorch_model,
-          sample_kwargs={
-              'tokens': decode_token,
-              'input_pos': decode_input_pos,
-              'kv_cache': kv,
-          },
-      )
-      .convert(quant_config=quant_config)
-  )
-  quant_suffix = 'q8' if quantize else 'f32'
-  edge_model.export(
-      f'/tmp/tiny_llama_{quant_suffix}_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
+  converter.convert_to_tflite(
+      pytorch_model,
+      tflite_path=_TFLITE_PATH.value,
+      prefill_seq_len=_PREFILL_SEQ_LEN.value,
+      quantize=_QUANTIZE.value,
   )
 
 
 if __name__ == '__main__':
-  path = os.path.join(pathlib.Path.home(), 'Downloads/llm_data/tiny_llama')
-  convert_tiny_llama_to_tflite(path)
+  app.run(main)
```
```diff
--- a/ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py
+++ b/ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py
@@ -15,16 +15,12 @@
 
 """Example of building a TinyLlama model."""
 
-import os
-import pathlib
-
 from ai_edge_torch.generative.layers import attention
 from ai_edge_torch.generative.layers import builder
 from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
 import ai_edge_torch.generative.layers.model_config as cfg
 import ai_edge_torch.generative.utilities.loader as loading_utils
-import numpy as np
 import torch
 from torch import nn
 
@@ -179,28 +175,3 @@ def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
   loader.load(model)
   model.eval()
   return model
-
-
-def define_and_run(checkpoint_path: str) -> None:
-  """Instantiates and runs a TinyLlama model."""
-
-  current_dir = pathlib.Path(__file__).parent.resolve()
-  tiny_llama_goldens = torch.load(current_dir / "tiny_llama_lm_logits.pt")
-  kv_cache_max_len = 1024
-  model = build_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
-  idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
-  tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.int, device="cpu")
-  tokens[0, :4] = idx
-  input_pos = torch.arange(0, kv_cache_max_len, dtype=torch.int)
-  kv = kv_utils.KVCache.from_model_config(model.config)
-  output = model.forward(tokens, input_pos, kv)
-  assert torch.allclose(
-      tiny_llama_goldens, output["logits"][0, idx.shape[1] - 1, :], atol=1e-02
-  )
-
-
-if __name__ == "__main__":
-  input_checkpoint_path = os.path.join(
-      pathlib.Path.home(), "Downloads/llm_data/tiny_llama"
-  )
-  define_and_run(input_checkpoint_path)
```
```diff
--- /dev/null
+++ b/ai_edge_torch/generative/examples/tiny_llama/verify.py
@@ -0,0 +1,61 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Verifies the reauthored TinyLlama-1.1B model."""
+
+import pathlib
+
+from absl import app
+from absl import flags
+from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
+from ai_edge_torch.generative.utilities import verifier
+import transformers
+
+_PROMPTS = flags.DEFINE_multi_string(
+    "prompts",
+    "Show me the program to add 2 and 3.",
+    "The input prompts to generate answers.",
+)
+
+
+def main(_):
+  checkpoint = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
+  verifier.log_msg("Loading the original model from", checkpoint)
+  original_model = transformers.AutoModelForCausalLM.from_pretrained(
+      checkpoint, trust_remote_code=True
+  )
+
+  # Locate the cached dir.
+  cached_config_file = transformers.utils.cached_file(
+      checkpoint, transformers.utils.CONFIG_NAME
+  )
+  reauthored_checkpoint = pathlib.Path(cached_config_file).parent
+  verifier.log_msg("Building the reauthored model from", reauthored_checkpoint)
+  reauthored_model = tiny_llama.build_model(reauthored_checkpoint)
+
+  verifier.log_msg("Loading the tokenizer from", checkpoint)
+  tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
+
+  verifier.verify_reauthored_model(
+      original_model=original_model,
+      reauthored_model=reauthored_model,
+      tokenizer=tokenizer,
+      prompts=_PROMPTS.value,
+      atol=1e-04,
+  )
+
+
+if __name__ == "__main__":
+  app.run(main)
```
```diff
--- a/ai_edge_torch/generative/layers/attention.py
+++ b/ai_edge_torch/generative/layers/attention.py
@@ -298,6 +298,8 @@ class CrossAttention(nn.Module):
       batch_size: int,
       query_dim: int,
      cross_dim: int,
+      hidden_dim: int,
+      output_dim: int,
       config: cfg.AttentionConfig,
       enable_hlfb: bool,
   ):
@@ -307,6 +309,8 @@ class CrossAttention(nn.Module):
       batch_size (int): batch size of the input tensor.
      query_dim (int): query tensor's dimension.
       cross_dim (int): cross attention's dimensions, for key and value tensors.
+      hidden_dim (int): hidden dimension that q, k, v tensors project to.
+      output_dim (int): output tensor's dimension.
       config (cfg.AttentionConfig): attention specific configurations.
       enable_hlfb (bool): whether hlfb is enabled or not.
     """
@@ -314,16 +318,16 @@ class CrossAttention(nn.Module):
     self.config = config
     self.n_heads = config.num_heads
     self.q_projection = nn.Linear(
-        query_dim, query_dim, bias=config.qkv_use_bias
+        query_dim, hidden_dim, bias=config.qkv_use_bias
     )
     self.k_projection = nn.Linear(
-        cross_dim, query_dim, bias=config.qkv_use_bias
+        cross_dim, hidden_dim, bias=config.qkv_use_bias
     )
     self.v_projection = nn.Linear(
-        cross_dim, query_dim, bias=config.qkv_use_bias
+        cross_dim, hidden_dim, bias=config.qkv_use_bias
     )
     self.output_projection = nn.Linear(
-        query_dim, query_dim, bias=config.output_proj_use_bias
+        hidden_dim, output_dim, bias=config.output_proj_use_bias
     )
 
     self.sdpa_func = (
```
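Previously the q/k/v projections were hard-wired to project into `query_dim`; the two new parameters decouple the attention width (`hidden_dim`) and the output width (`output_dim`) from the query width. A standalone shape walk-through of the new projection layout follows; the concrete sizes are illustrative assumptions, not values from this diff:

```python
# Standalone illustration of the new CrossAttention projection shapes.
# The concrete sizes below are made-up examples, not library defaults.
import torch
from torch import nn

batch, q_len, kv_len = 1, 16, 77
query_dim, cross_dim, hidden_dim, output_dim = 320, 768, 320, 320

q_proj = nn.Linear(query_dim, hidden_dim)     # queries project to hidden_dim
k_proj = nn.Linear(cross_dim, hidden_dim)     # keys project to hidden_dim
v_proj = nn.Linear(cross_dim, hidden_dim)     # values project to hidden_dim
out_proj = nn.Linear(hidden_dim, output_dim)  # output leaves at output_dim

x = torch.randn(batch, q_len, query_dim)      # e.g. image latents
ctx = torch.randn(batch, kv_len, cross_dim)   # e.g. text embeddings

q, k, v = q_proj(x), k_proj(ctx), v_proj(ctx)
assert q.shape == (batch, q_len, hidden_dim)
assert k.shape == v.shape == (batch, kv_len, hidden_dim)
assert out_proj(v).shape == (batch, kv_len, output_dim)
```

In the `diffusion.py` call sites above, both new fields are simply set to the block's channel count, so behavior is unchanged there; the extra flexibility matters for configs where the attention width differs from the query width.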
```diff
--- a/ai_edge_torch/generative/layers/builder.py
+++ b/ai_edge_torch/generative/layers/builder.py
@@ -75,7 +75,9 @@ def build_norm(dim: int, config: cfg.NormalizationConfig):
         zero_centered_gamma=config.zero_centered,
     )
   elif config.type == cfg.NormalizationType.LAYER_NORM:
-    return normalization.LayerNorm(dim, config.epsilon, config.enable_hlfb)
+    return normalization.LayerNorm(
+        dim, config.epsilon, config.enable_hlfb, config.use_input_shape
+    )
   elif config.type == cfg.NormalizationType.GROUP_NORM:
     return normalization.GroupNorm(
         config.group_num, dim, config.epsilon, config.enable_hlfb
```
```diff
--- a/ai_edge_torch/generative/layers/model_config.py
+++ b/ai_edge_torch/generative/layers/model_config.py
@@ -69,6 +69,9 @@ class NormalizationConfig:
   zero_centered: bool = False
   # Number of groups used in group normalization.
   group_num: Optional[float] = None
+  # Whether to use the input shape to determine the dimension of normalization
+  # when type is LAYER_NORM.
+  use_input_shape: bool = True
 
 
 @dataclass
```
```diff
--- a/ai_edge_torch/generative/layers/normalization.py
+++ b/ai_edge_torch/generative/layers/normalization.py
@@ -78,7 +78,7 @@ class GroupNorm(torch.nn.Module):
     group_num (int): Number of groups to separate the channels into.
     dim (int): Dimension of the input tensor.
     eps (float): A small float value to ensure numerical stability (default:
-      1e-6).
+      1e-5).
     enable_hlfb (bool): Whether to convert this normalization into a single
       op.
   """
@@ -112,7 +112,13 @@ class GroupNorm(torch.nn.Module):
 
 class LayerNorm(torch.nn.Module):
 
-  def __init__(self, dim: int, eps: float = 1e-5, enable_hlfb: bool = False):
+  def __init__(
+      self,
+      dim: int,
+      eps: float = 1e-5,
+      enable_hlfb: bool = False,
+      use_input_shape: bool = True,
+  ):
     """Initialize the LayerNorm layer.
 
     Args:
@@ -121,9 +127,12 @@ class LayerNorm(torch.nn.Module):
         1e-6).
       enable_hlfb (bool): Whether to convert this normalization into a single
         op.
+      use_input_shape (bool): Whether to use the input shape to determine the
+        dimension of normalization (default: True).
     """
     super().__init__()
     self.enable_hlfb = enable_hlfb
+    self.use_input_shape = use_input_shape
     self.eps = eps
     self.weight = torch.nn.Parameter(torch.ones(dim))
     self.bias = torch.nn.Parameter(torch.ones(dim))
@@ -139,19 +148,18 @@ class LayerNorm(torch.nn.Module):
     """
     if self.enable_hlfb:
       return layer_norm_with_hlfb(
-          x,
-          self.weight,
-          self.bias,
-          self.eps,
+          x, self.weight, self.bias, self.eps, self.use_input_shape
       )
+
+    if self.use_input_shape:
+      normalized_shape = x.shape
+      weight = self.weight.broadcast_to(x.shape)
+      bias = self.bias.broadcast_to(x.shape)
     else:
-      return F.layer_norm(
-          x,
-          x.shape,
-          self.weight.broadcast_to(x.shape),
-          self.bias.broadcast_to(x.shape),
-          self.eps,
-      )
+      normalized_shape = self.weight.shape
+      weight = self.weight
+      bias = self.bias
+    return F.layer_norm(x, normalized_shape, weight, bias, self.eps)
 
 
 def group_norm_with_hlfb(
@@ -193,6 +201,7 @@ def layer_norm_with_hlfb(
     w: torch.Tensor,
     b: torch.Tensor,
     eps: float,
+    use_input_shape: bool,
 ):
   """Layer Normalization with high-level function boundary enabled.
 
@@ -201,18 +210,20 @@ def layer_norm_with_hlfb(
     w (torch.Tensor): The weight tensor for the normalization.
     b (torch.Tensor): The bias tensor for the normalization.
     eps (float): A small float value to ensure numerical stability.
+    use_input_shape (bool): Whether to use the input shape to determine the
+      dimension of normalization.
 
   Returns:
     The output tensor of Layer Normalization.
   """
   builder = StableHLOCompositeBuilder(name="odml.layer_norm", attr={"eps": eps})
   x, w, b = builder.mark_inputs(x, w, b)
-  y = F.layer_norm(
-      x,
-      x.shape,
-      w.broadcast_to(x.shape),
-      b.broadcast_to(x.shape),
-      eps=eps,
-  )
+  if use_input_shape:
+    normalized_shape = x.shape
+    w = w.broadcast_to(x.shape)
+    b = b.broadcast_to(x.shape)
+  else:
+    normalized_shape = w.shape
+  y = F.layer_norm(x, normalized_shape, w, b, eps=eps)
   y = builder.mark_outputs(y)
   return y
```
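The refactor makes both the HLFB and plain paths share one `F.layer_norm` call, parameterized by the new `use_input_shape` flag: when true (the default, preserving the old behavior), normalization runs over the entire input shape with weight and bias broadcast to match; when false, it is a conventional LayerNorm over the trailing `weight.shape` dimensions. A standalone comparison of the two modes:

```python
# Standalone illustration of the two normalization modes toggled by
# use_input_shape; this mirrors the branches added in LayerNorm.forward.
import torch
import torch.nn.functional as F

x = torch.randn(2, 4, 8)
w = torch.ones(8)
b = torch.zeros(8)

# use_input_shape=True: normalize across all of x's dimensions, with the
# per-channel weight and bias broadcast up to x's full shape.
y_full = F.layer_norm(
    x, x.shape, w.broadcast_to(x.shape), b.broadcast_to(x.shape), eps=1e-5
)

# use_input_shape=False: classic LayerNorm over the last dimension only.
y_last = F.layer_norm(x, w.shape, w, b, eps=1e-5)

print(y_full.shape, y_last.shape)  # both torch.Size([2, 4, 8])
```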
```diff
--- a/ai_edge_torch/generative/layers/scaled_dot_product_attention.py
+++ b/ai_edge_torch/generative/layers/scaled_dot_product_attention.py
@@ -119,15 +119,25 @@ def scaled_dot_product_attention_with_hlfb(
     # Handle the GQA case, where q.shape[1] % k.shape[1] == 0.
     k = k.repeat_interleave(q.shape[1] // k.shape[1], dim=1)
     v = v.repeat_interleave(q.shape[1] // v.shape[1], dim=1)
-  y = F.scaled_dot_product_attention(
-      q,
-      k,
-      v,
-      attn_mask=mask,
-      dropout_p=0.0,
-      is_causal=mask is None,
-      scale=scale,
-  )
+  if softcap is None:
+    y = F.scaled_dot_product_attention(
+        q,
+        k,
+        v,
+        attn_mask=mask,
+        dropout_p=0.0,
+        is_causal=mask is None,
+        scale=scale,
+    )
+  else:
+    q.mul_(scale)
+    scores = q @ k.transpose(-1, -2)
+    scores = scores / softcap
+    scores = torch.tanh(scores)
+    scores = scores * softcap
+    scores = scores + mask
+    out = F.softmax(scores.float(), dim=-1).type_as(q)
+    y = torch.matmul(out, v)
 
   result = y.transpose(1, 2)
   result = builder.mark_outputs(result)
```