ai-edge-torch-nightly 0.3.0.dev20240924__py3-none-any.whl → 0.3.0.dev20240928__py3-none-any.whl
This diff compares the contents of two publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.
- ai_edge_torch/generative/examples/gemma/gemma1.py +2 -6
- ai_edge_torch/generative/examples/gemma/gemma2.py +2 -10
- ai_edge_torch/generative/examples/gemma/verify_gemma1.py +3 -2
- ai_edge_torch/generative/examples/gemma/verify_gemma2.py +3 -2
- ai_edge_torch/generative/examples/gemma/verify_util.py +15 -25
- ai_edge_torch/generative/examples/llama/__init__.py +14 -0
- ai_edge_torch/generative/examples/llama/convert_3b_to_tflite.py +68 -0
- ai_edge_torch/generative/examples/llama/convert_to_tflite.py +68 -0
- ai_edge_torch/generative/examples/llama/llama.py +204 -0
- ai_edge_torch/generative/examples/llama/verify.py +73 -0
- ai_edge_torch/generative/examples/llama/verify_3b.py +73 -0
- ai_edge_torch/generative/examples/openelm/openelm.py +2 -6
- ai_edge_torch/generative/examples/openelm/verify.py +19 -11
- ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py +68 -0
- ai_edge_torch/generative/examples/phi/phi2.py +2 -6
- ai_edge_torch/generative/examples/phi/phi3.py +279 -0
- ai_edge_torch/generative/examples/phi/verify.py +13 -13
- ai_edge_torch/generative/examples/phi/verify_phi3.py +69 -0
- ai_edge_torch/generative/examples/smollm/smollm.py +1 -0
- ai_edge_torch/generative/examples/smollm/verify.py +19 -9
- ai_edge_torch/generative/examples/stable_diffusion/clip.py +54 -1
- ai_edge_torch/generative/examples/stable_diffusion/decoder.py +58 -0
- ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +71 -1
- ai_edge_torch/generative/examples/t5/t5.py +0 -2
- ai_edge_torch/generative/examples/test_models/convert_toy_model.py +105 -0
- ai_edge_torch/generative/examples/test_models/toy_model.py +7 -41
- ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +5 -61
- ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +2 -6
- ai_edge_torch/generative/examples/tiny_llama/verify.py +20 -10
- ai_edge_torch/generative/layers/model_config.py +2 -0
- ai_edge_torch/generative/layers/normalization.py +2 -2
- ai_edge_torch/generative/layers/unet/blocks_2d.py +2 -2
- ai_edge_torch/generative/test/test_model_conversion_large.py +129 -0
- ai_edge_torch/generative/utilities/transformers_verifier.py +42 -0
- ai_edge_torch/generative/utilities/verifier.py +130 -114
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20240924.dist-info → ai_edge_torch_nightly-0.3.0.dev20240928.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20240924.dist-info → ai_edge_torch_nightly-0.3.0.dev20240928.dist-info}/RECORD +41 -30
- {ai_edge_torch_nightly-0.3.0.dev20240924.dist-info → ai_edge_torch_nightly-0.3.0.dev20240928.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240924.dist-info → ai_edge_torch_nightly-0.3.0.dev20240928.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240924.dist-info → ai_edge_torch_nightly-0.3.0.dev20240928.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/examples/smollm/verify.py
@@ -15,43 +15,53 @@
 
 """Verifies the reauthored SmolLM-135M model."""
 
+import logging
 import pathlib
 
 from absl import app
 from absl import flags
 from ai_edge_torch.generative.examples.smollm import smollm
+from ai_edge_torch.generative.utilities import transformers_verifier
 from ai_edge_torch.generative.utilities import verifier
 import transformers
 
+
 _PROMPTS = flags.DEFINE_multi_string(
     "prompts",
     "What is the meaning of life?",
     "The input prompts to generate answers.",
 )
+_MAX_NEW_TOKENS = flags.DEFINE_integer(
+    "max_new_tokens",
+    30,
+    "The maximum size of the generated tokens.",
+)
 
 
 def main(_):
   checkpoint = "HuggingFaceTB/SmolLM-135M"
-
-
-
-  )
+  logging.info("Loading the original model from: %s", checkpoint)
+  original_model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint)
+
   # Locate the cached dir.
   cached_config_file = transformers.utils.cached_file(
       checkpoint, transformers.utils.CONFIG_NAME
   )
   reauthored_checkpoint = pathlib.Path(cached_config_file).parent
-
+  logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
   reauthored_model = smollm.build_model(reauthored_checkpoint)
 
-
+  logging.info("Loading the tokenizer from: %s", checkpoint)
   tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
 
   verifier.verify_reauthored_model(
-      original_model=
-
-
+      original_model=transformers_verifier.TransformersModelWrapper(
+          original_model
+      ),
+      reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
+      tokenizer=verifier.TokenizerWrapper(tokenizer),
       generate_prompts=_PROMPTS.value,
+      max_new_tokens=_MAX_NEW_TOKENS.value,
       atol=1e-04,
   )
 
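The rewritten verify scripts in this release all follow the same wrapper-based flow. Below is a minimal end-to-end sketch of that flow, assembled from the hunk above; it uses only the names imported there (transformers_verifier, verifier, smollm), and running it assumes the HuggingFace checkpoint can be downloaded.

import logging
import pathlib

import transformers

from ai_edge_torch.generative.examples.smollm import smollm
from ai_edge_torch.generative.utilities import transformers_verifier
from ai_edge_torch.generative.utilities import verifier

checkpoint = "HuggingFaceTB/SmolLM-135M"

# Load the original HuggingFace model and locate its cached checkpoint dir.
original_model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint)
cached_config_file = transformers.utils.cached_file(
    checkpoint, transformers.utils.CONFIG_NAME
)
reauthored_checkpoint = pathlib.Path(cached_config_file).parent

# Rebuild the reauthored ai_edge_torch model from the same weights.
reauthored_model = smollm.build_model(reauthored_checkpoint)
tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)

# Compare the two models; both sides are now passed through wrapper classes.
verifier.verify_reauthored_model(
    original_model=transformers_verifier.TransformersModelWrapper(original_model),
    reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
    tokenizer=verifier.TokenizerWrapper(tokenizer),
    generate_prompts=["What is the meaning of life?"],
    max_new_tokens=30,
    atol=1e-04,
)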
ai_edge_torch/generative/examples/stable_diffusion/clip.py
@@ -48,7 +48,7 @@ TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
 
 
 class CLIP(nn.Module):
-  """CLIP text encoder
+  """CLIP text encoder.
 
   For details, see https://arxiv.org/abs/2103.00020
   """
@@ -86,6 +86,7 @@ class CLIP(nn.Module):
 
 
 def get_model_config() -> cfg.ModelConfig:
+  """Get configs for the CLIP of Stable Diffusion v1.5."""
   max_seq_len = 77
   vocab_size = 49408
   num_layers = 12
@@ -97,6 +98,58 @@ def get_model_config() -> cfg.ModelConfig:
       num_heads=num_heads,
       head_dim=embedding_dim // num_heads,
       num_query_groups=num_query_groups,
+      rotary_base=0,
+      rotary_percentage=0.0,
+      qkv_use_bias=True,
+      qkv_transpose_before_split=True,
+      qkv_fused_interleaved=False,
+      output_proj_use_bias=True,
+      enable_kv_cache=False,
+  )
+
+  ff_config = cfg.FeedForwardConfig(
+      type=cfg.FeedForwardType.SEQUENTIAL,
+      activation=cfg.ActivationConfig(cfg.ActivationType.GELU_QUICK),
+      intermediate_size=embedding_dim * 4,
+      use_bias=True,
+  )
+
+  norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.LAYER_NORM)
+
+  block_config = cfg.TransformerBlockConfig(
+      attn_config=attn_config,
+      ff_config=ff_config,
+      pre_attention_norm_config=norm_config,
+      post_attention_norm_config=norm_config,
+  )
+
+  config = cfg.ModelConfig(
+      vocab_size=vocab_size,
+      num_layers=num_layers,
+      max_seq_len=max_seq_len,
+      embedding_dim=embedding_dim,
+      block_configs=block_config,
+      final_norm_config=norm_config,
+      enable_hlfb=True,
+  )
+
+  return config
+
+
+def get_fake_model_config() -> cfg.ModelConfig:
+  """Get fake configs for the CLIP of Stable Diffusion v1.5 for testing."""
+  max_seq_len = 6
+  vocab_size = 100
+  num_layers = 2
+  num_heads = 12
+  num_query_groups = 12
+  embedding_dim = 24
+
+  attn_config = cfg.AttentionConfig(
+      num_heads=num_heads,
+      head_dim=embedding_dim // num_heads,
+      num_query_groups=num_query_groups,
+      rotary_base=0,
       rotary_percentage=0.0,
       qkv_use_bias=True,
       qkv_transpose_before_split=True,
ai_edge_torch/generative/examples/stable_diffusion/decoder.py
@@ -295,6 +295,64 @@ def get_model_config() -> unet_cfg.AutoEncoderConfig:
           enable_kv_cache=False,
           qkv_transpose_before_split=True,
           qkv_fused_interleaved=False,
+          rotary_base=0,
+          rotary_percentage=0.0,
+      ),
+      enable_hlfb=False,
+  )
+
+  mid_block_config = unet_cfg.MidBlock2DConfig(
+      in_channels=block_out_channels[-1],
+      normalization_config=norm_config,
+      activation_config=layers_cfg.ActivationConfig(
+          layers_cfg.ActivationType.SILU
+      ),
+      num_layers=1,
+      attention_block_config=att_config,
+  )
+
+  config = unet_cfg.AutoEncoderConfig(
+      in_channels=in_channels,
+      latent_channels=latent_channels,
+      out_channels=out_channels,
+      activation_config=layers_cfg.ActivationConfig(
+          layers_cfg.ActivationType.SILU
+      ),
+      block_out_channels=block_out_channels,
+      scaling_factor=scaling_factor,
+      layers_per_block=layers_per_block,
+      normalization_config=norm_config,
+      mid_block_config=mid_block_config,
+  )
+  return config
+
+
+def get_fake_model_config() -> unet_cfg.AutoEncoderConfig:
+  """Get fake configs for the Decoder of Stable Diffusion v1.5 for testing."""
+  in_channels = 3
+  latent_channels = 4
+  out_channels = 3
+  block_out_channels = [2, 4]
+  scaling_factor = 0.18215
+  layers_per_block = 2
+
+  norm_config = layers_cfg.NormalizationConfig(
+      layers_cfg.NormalizationType.GROUP_NORM, group_num=2
+  )
+
+  att_config = unet_cfg.AttentionBlock2DConfig(
+      dim=block_out_channels[-1],
+      normalization_config=norm_config,
+      attention_config=layers_cfg.AttentionConfig(
+          num_heads=1,
+          head_dim=block_out_channels[-1],
+          num_query_groups=1,
+          qkv_use_bias=True,
+          output_proj_use_bias=True,
+          enable_kv_cache=False,
+          qkv_transpose_before_split=True,
+          qkv_fused_interleaved=False,
+          rotary_base=0,
           rotary_percentage=0.0,
       ),
       enable_hlfb=False,
ai_edge_torch/generative/examples/stable_diffusion/diffusion.py
@@ -199,6 +199,7 @@ def build_attention_config(
     num_heads,
     dim,
     num_query_groups,
+    rotary_base=0,
     rotary_percentage=0.0,
     qkv_transpose_before_split=True,
     qkv_use_bias=False,
@@ -211,6 +212,7 @@ def build_attention_config(
       num_heads=num_heads,
       head_dim=dim // num_heads,
       num_query_groups=num_query_groups,
+      rotary_base=rotary_base,
       rotary_percentage=rotary_percentage,
      qkv_transpose_before_split=qkv_transpose_before_split,
       qkv_use_bias=qkv_use_bias,
@@ -603,7 +605,7 @@ def get_model_config(batch_size: int) -> unet_cfg.DiffusionModelConfig:
   # Transformer configs.
   transformer_num_attention_heads = 8
   transformer_batch_size = batch_size
-  transformer_cross_attention_dim = 768  # Embedding
+  transformer_cross_attention_dim = 768  # Embedding from CLIP model
   transformer_pre_conv_norm_config = layers_cfg.NormalizationConfig(
       layers_cfg.NormalizationType.GROUP_NORM, epsilon=1e-6, group_num=32
   )
@@ -645,3 +647,71 @@ def get_model_config(batch_size: int) -> unet_cfg.DiffusionModelConfig:
       final_norm_config=final_norm_config,
       final_activation_type=final_activation_type,
   )
+
+
+def get_fake_model_config(batch_size: int) -> unet_cfg.DiffusionModelConfig:
+  """Get fake configs for the Diffusion model of Stable Diffusion v1.5 for testing.
+
+  Args:
+    batch_size (int): the batch size of input.
+
+  Retruns:
+    The configuration of diffusion model of Stable Diffusion v1.5.
+  """
+  in_channels = 4
+  out_channels = 4
+  block_out_channels = [2, 4, 8, 8]
+  layers_per_block = 1
+  downsample_padding = 1
+
+  # Residual configs.
+  residual_norm_config = layers_cfg.NormalizationConfig(
+      layers_cfg.NormalizationType.GROUP_NORM, group_num=2
+  )
+  residual_activation_type = layers_cfg.ActivationType.SILU
+
+  # Transformer configs.
+  transformer_num_attention_heads = 1
+  transformer_batch_size = batch_size
+  transformer_cross_attention_dim = 4  # Embedding from CLIP model
+  transformer_pre_conv_norm_config = layers_cfg.NormalizationConfig(
+      layers_cfg.NormalizationType.GROUP_NORM, epsilon=1e-6, group_num=2
+  )
+  transformer_norm_config = layers_cfg.NormalizationConfig(
+      layers_cfg.NormalizationType.LAYER_NORM
+  )
+  transformer_ff_activation_type = layers_cfg.ActivationType.GE_GLU
+
+  # Time embedding configs.
+  time_embedding_dim = 2
+  time_embedding_blocks_dim = 4
+
+  # Mid block configs.
+  mid_block_layers = 1
+
+  # Finaly layer configs.
+  final_norm_config = layers_cfg.NormalizationConfig(
+      layers_cfg.NormalizationType.GROUP_NORM, group_num=2
+  )
+  final_activation_type = layers_cfg.ActivationType.SILU
+
+  return unet_cfg.DiffusionModelConfig(
+      in_channels=in_channels,
+      out_channels=out_channels,
+      block_out_channels=block_out_channels,
+      layers_per_block=layers_per_block,
+      downsample_padding=downsample_padding,
+      residual_norm_config=residual_norm_config,
+      residual_activation_type=residual_activation_type,
+      transformer_batch_size=transformer_batch_size,
+      transformer_num_attention_heads=transformer_num_attention_heads,
+      transformer_cross_attention_dim=transformer_cross_attention_dim,
+      transformer_pre_conv_norm_config=transformer_pre_conv_norm_config,
+      transformer_norm_config=transformer_norm_config,
+      transformer_ff_activation_type=transformer_ff_activation_type,
+      mid_block_layers=mid_block_layers,
+      time_embedding_dim=time_embedding_dim,
+      time_embedding_blocks_dim=time_embedding_blocks_dim,
+      final_norm_config=final_norm_config,
+      final_activation_type=final_activation_type,
+  )
ai_edge_torch/generative/examples/test_models/convert_toy_model.py (new file)
@@ -0,0 +1,105 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+# A toy example which has a single-layer transformer block.
+from absl import app
+import ai_edge_torch
+from ai_edge_torch import lowertools
+from ai_edge_torch.generative.examples.test_models import toy_model
+from ai_edge_torch.generative.examples.test_models import toy_model_with_kv_cache
+from ai_edge_torch.generative.layers import kv_cache as kv_utils
+import torch
+
+KV_CACHE_MAX_LEN = 100
+
+
+def convert_toy_model(_) -> None:
+  """Converts a toy model to tflite."""
+  model = toy_model.ToySingleLayerModel(toy_model.get_model_config())
+  idx = torch.unsqueeze(torch.arange(0, KV_CACHE_MAX_LEN), 0)
+  input_pos = torch.arange(0, KV_CACHE_MAX_LEN)
+  print('running an inference')
+  print(
+      model.forward(
+          idx,
+          input_pos,
+      )
+  )
+
+  # Convert model to tflite.
+  print('converting model to tflite')
+  edge_model = ai_edge_torch.convert(
+      model,
+      (
+          idx,
+          input_pos,
+      ),
+  )
+  edge_model.export('/tmp/toy_model.tflite')
+
+
+def _export_stablehlo_mlir(model, args):
+  ep = torch.export.export(model, args)
+  return lowertools.exported_program_to_mlir_text(ep)
+
+
+def convert_toy_model_with_kv_cache(_) -> None:
+  """Converts a toy model with kv cache to tflite."""
+  dump_mlir = False
+
+  config = toy_model_with_kv_cache.get_model_config()
+  model = toy_model_with_kv_cache.ToyModelWithKVCache(config)
+  model.eval()
+  print('running an inference')
+  kv = kv_utils.KVCache.from_model_config(config)
+
+  tokens, input_pos = toy_model_with_kv_cache.get_sample_prefill_inputs()
+  decode_token, decode_input_pos = (
+      toy_model_with_kv_cache.get_sample_decode_inputs()
+  )
+  print(model.forward(tokens, input_pos, kv))
+
+  if dump_mlir:
+    mlir_text = _export_stablehlo_mlir(model, (tokens, input_pos, kv))
+    with open('/tmp/toy_model_with_external_kv.stablehlo.mlir', 'w') as f:
+      f.write(mlir_text)
+
+  # Convert model to tflite with 2 signatures (prefill + decode).
+  print('converting toy model to tflite with 2 signatures (prefill + decode)')
+  edge_model = (
+      ai_edge_torch.signature(
+          'prefill',
+          model,
+          sample_kwargs={
+              'tokens': tokens,
+              'input_pos': input_pos,
+              'kv_cache': kv,
+          },
+      )
+      .signature(
+          'decode',
+          model,
+          sample_kwargs={
+              'tokens': decode_token,
+              'input_pos': decode_input_pos,
+              'kv_cache': kv,
+          },
+      )
+      .convert()
+  )
+  edge_model.export('/tmp/toy_external_kv_cache.tflite')
+
+
+if __name__ == '__main__':
+  app.run(convert_toy_model)
ai_edge_torch/generative/examples/test_models/toy_model.py
@@ -15,13 +15,12 @@
 # A toy example which has a single-layer transformer block.
 from typing import Tuple
 
-import
+from ai_edge_torch.generative.layers import builder
 from ai_edge_torch.generative.layers.attention import TransformerBlock
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
-import ai_edge_torch.generative.layers.builder as builder
 import ai_edge_torch.generative.layers.model_config as cfg
 import torch
-
+from torch import nn
 
 RoPECache = Tuple[torch.Tensor, torch.Tensor]
 KV_CACHE_MAX_LEN = 100
@@ -45,13 +44,10 @@ class ToySingleLayerModel(torch.nn.Module):
     self.rope_cache = attn_utils.build_rope_cache(
         size=config.max_seq_len,
         dim=int(attn_config.rotary_percentage * attn_config.head_dim),
-        base=
-        condense_ratio=1,
-        dtype=torch.float32,
-        device=torch.device('cpu'),
+        base=attn_config.rotary_base,
     )
     self.mask_cache = attn_utils.build_causal_mask_cache(
-        size=config.max_seq_len,
+        size=config.max_seq_len,
     )
     self.config = config
 
@@ -94,13 +90,10 @@ class ToySingleLayerModelWeightSharing(torch.nn.Module):
     self.rope_cache = attn_utils.build_rope_cache(
         size=config.max_seq_len,
         dim=int(attn_config.rotary_percentage * attn_config.head_dim),
-        base=
-        condense_ratio=1,
-        dtype=torch.float32,
-        device=torch.device('cpu'),
+        base=attn_config.rotary_base,
     )
     self.mask_cache = attn_utils.build_causal_mask_cache(
-        size=config.max_seq_len,
+        size=config.max_seq_len,
     )
     self.config = config
 
@@ -125,6 +118,7 @@ def get_model_config() -> cfg.ModelConfig:
       num_heads=32,
       head_dim=4,
       num_query_groups=4,
+      rotary_base=10000,
       rotary_percentage=1.0,
       enable_kv_cache=False,
   )
@@ -149,31 +143,3 @@ def get_model_config() -> cfg.ModelConfig:
       final_norm_config=norm_config,
   )
   return config
-
-
-def define_and_run() -> None:
-  model = ToySingleLayerModel(get_model_config())
-  idx = torch.unsqueeze(torch.arange(0, KV_CACHE_MAX_LEN), 0)
-  input_pos = torch.arange(0, KV_CACHE_MAX_LEN)
-  print('running an inference')
-  print(
-      model.forward(
-          idx,
-          input_pos,
-      )
-  )
-
-  # Convert model to tflite.
-  print('converting model to tflite')
-  edge_model = ai_edge_torch.convert(
-      model,
-      (
-          idx,
-          input_pos,
-      ),
-  )
-  edge_model.export('/tmp/toy_model.tflite')
-
-
-if __name__ == '__main__':
-  define_and_run()
ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py
@@ -17,15 +17,14 @@
 
 from typing import Tuple
 
-import
-from ai_edge_torch import lowertools
+from absl import app
 from ai_edge_torch.generative.layers import attention
 from ai_edge_torch.generative.layers import builder
 from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
 import ai_edge_torch.generative.layers.model_config as cfg
 import torch
-
+from torch import nn
 
 RoPECache = Tuple[torch.Tensor, torch.Tensor]
 
@@ -52,13 +51,10 @@ class ToyModelWithKVCache(torch.nn.Module):
     self.rope_cache = attn_utils.build_rope_cache(
         size=config.max_seq_len,
         dim=int(attn_config.rotary_percentage * attn_config.head_dim),
-        base=
-        condense_ratio=1,
-        dtype=torch.float32,
-        device=torch.device('cpu'),
+        base=attn_config.rotary_base,
     )
     self.mask_cache = attn_utils.build_causal_mask_cache(
-        size=config.max_seq_len,
+        size=config.max_seq_len,
     )
     self.config = config
 
@@ -87,16 +83,12 @@ class ToyModelWithKVCache(torch.nn.Module):
     return {'logits': self.lm_head(x), 'kv_cache': updated_kv_cache}
 
 
-def _export_stablehlo_mlir(model, args):
-  ep = torch.export.export(model, args)
-  return lowertools.exported_program_to_mlir_text(ep)
-
-
 def get_model_config() -> cfg.ModelConfig:
   attn_config = cfg.AttentionConfig(
       num_heads=32,
       head_dim=4,
       num_query_groups=4,
+      rotary_base=10000,
       rotary_percentage=1.0,
   )
   ff_config = cfg.FeedForwardConfig(
@@ -133,51 +125,3 @@ def get_sample_decode_inputs() -> Tuple[torch.Tensor, torch.Tensor]:
   tokens = torch.tensor([[1]], dtype=torch.int)
   input_pos = torch.tensor([10])
   return tokens, input_pos
-
-
-def define_and_run() -> None:
-  dump_mlir = False
-
-  config = get_model_config()
-  model = ToyModelWithExternalKV(config)
-  model.eval()
-  print('running an inference')
-  kv = kv_utils.KVCache.from_model_config(config)
-
-  tokens, input_pos = get_sample_prefill_inputs()
-  decode_token, decode_input_pos = get_sample_decode_inputs()
-  print(model.forward(tokens, input_pos, kv))
-
-  if dump_mlir:
-    mlir_text = _export_stablehlo_mlir(model, (tokens, input_pos, kv))
-    with open('/tmp/toy_model_with_external_kv.stablehlo.mlir', 'w') as f:
-      f.write(mlir_text)
-
-  # Convert model to tflite with 2 signatures (prefill + decode).
-  print('converting toy model to tflite with 2 signatures (prefill + decode)')
-  edge_model = (
-      ai_edge_torch.signature(
-          'prefill',
-          model,
-          sample_kwargs={
-              'tokens': tokens,
-              'input_pos': input_pos,
-              'kv_cache': kv,
-          },
-      )
-      .signature(
-          'decode',
-          model,
-          sample_kwargs={
-              'tokens': decode_token,
-              'input_pos': decode_input_pos,
-              'kv_cache': kv,
-          },
-      )
-      .convert()
-  )
-  edge_model.export('/tmp/toy_external_kv_cache.tflite')
-
-
-if __name__ == '__main__':
-  define_and_run()
ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py
@@ -67,15 +67,10 @@ class TinyLlama(nn.Module):
     self.rope_cache = attn_utils.build_rope_cache(
         size=config.kv_cache_max,
         dim=int(attn_config.rotary_percentage * attn_config.head_dim),
-        base=
-        condense_ratio=1,
-        dtype=torch.float32,
-        device=torch.device("cpu"),
+        base=attn_config.rotary_base,
     )
     self.mask_cache = attn_utils.build_causal_mask_cache(
         size=config.kv_cache_max,
-        dtype=torch.float32,
-        device=torch.device("cpu"),
     )
     self.config = config
 
@@ -132,6 +127,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       num_heads=32,
       head_dim=64,
       num_query_groups=4,
+      rotary_base=10000,
       rotary_percentage=1.0,
   )
   ff_config = cfg.FeedForwardConfig(
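The rotary-base change above follows the same shape across the model files in this release: the RoPE base moves out of hard-coded call sites and into the attention config (a new rotary_base field in generative/layers/model_config.py). A minimal sketch of the resulting wiring, assuming only the field and helper names shown in these hunks:

import ai_edge_torch.generative.layers.attention_utils as attn_utils
import ai_edge_torch.generative.layers.model_config as cfg

# rotary_base is now carried on the attention config.
attn_config = cfg.AttentionConfig(
    num_heads=32,
    head_dim=64,
    num_query_groups=4,
    rotary_base=10000,
    rotary_percentage=1.0,
)

# Call sites read it back instead of passing a hard-coded base.
rope_cache = attn_utils.build_rope_cache(
    size=1024,
    dim=int(attn_config.rotary_percentage * attn_config.head_dim),
    base=attn_config.rotary_base,
)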
ai_edge_torch/generative/examples/tiny_llama/verify.py
@@ -15,45 +15,55 @@
 
 """Verifies the reauthored TinyLlama-1.1B model."""
 
+import logging
 import pathlib
 
 from absl import app
 from absl import flags
 from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
+from ai_edge_torch.generative.utilities import transformers_verifier
 from ai_edge_torch.generative.utilities import verifier
 import transformers
 
+
 _PROMPTS = flags.DEFINE_multi_string(
     "prompts",
     "Show me the program to add 2 and 3.",
     "The input prompts to generate answers.",
 )
+_MAX_NEW_TOKENS = flags.DEFINE_integer(
+    "max_new_tokens",
+    30,
+    "The maximum size of the generated tokens.",
+)
 
 
 def main(_):
   checkpoint = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
-
-
-
-          checkpoint, trust_remote_code=True
-      ),
+  logging.info("Loading the original model from: %s", checkpoint)
+  original_model = transformers.AutoModelForCausalLM.from_pretrained(
+      checkpoint, trust_remote_code=True
   )
+
   # Locate the cached dir.
   cached_config_file = transformers.utils.cached_file(
       checkpoint, transformers.utils.CONFIG_NAME
   )
   reauthored_checkpoint = pathlib.Path(cached_config_file).parent
-
+  logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
   reauthored_model = tiny_llama.build_model(reauthored_checkpoint)
 
-
+  logging.info("Loading the tokenizer from: %s", checkpoint)
   tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
 
   verifier.verify_reauthored_model(
-      original_model=
-
-
+      original_model=transformers_verifier.TransformersModelWrapper(
+          original_model
+      ),
+      reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
+      tokenizer=verifier.TokenizerWrapper(tokenizer),
       generate_prompts=_PROMPTS.value,
+      max_new_tokens=_MAX_NEW_TOKENS.value,
       atol=1e-04,
   )
 