ai-edge-torch-nightly 0.3.0.dev20240924__py3-none-any.whl → 0.3.0.dev20240928__py3-none-any.whl
- ai_edge_torch/generative/examples/gemma/gemma1.py +2 -6
- ai_edge_torch/generative/examples/gemma/gemma2.py +2 -10
- ai_edge_torch/generative/examples/gemma/verify_gemma1.py +3 -2
- ai_edge_torch/generative/examples/gemma/verify_gemma2.py +3 -2
- ai_edge_torch/generative/examples/gemma/verify_util.py +15 -25
- ai_edge_torch/generative/examples/llama/__init__.py +14 -0
- ai_edge_torch/generative/examples/llama/convert_3b_to_tflite.py +68 -0
- ai_edge_torch/generative/examples/llama/convert_to_tflite.py +68 -0
- ai_edge_torch/generative/examples/llama/llama.py +204 -0
- ai_edge_torch/generative/examples/llama/verify.py +73 -0
- ai_edge_torch/generative/examples/llama/verify_3b.py +73 -0
- ai_edge_torch/generative/examples/openelm/openelm.py +2 -6
- ai_edge_torch/generative/examples/openelm/verify.py +19 -11
- ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py +68 -0
- ai_edge_torch/generative/examples/phi/phi2.py +2 -6
- ai_edge_torch/generative/examples/phi/phi3.py +279 -0
- ai_edge_torch/generative/examples/phi/verify.py +13 -13
- ai_edge_torch/generative/examples/phi/verify_phi3.py +69 -0
- ai_edge_torch/generative/examples/smollm/smollm.py +1 -0
- ai_edge_torch/generative/examples/smollm/verify.py +19 -9
- ai_edge_torch/generative/examples/stable_diffusion/clip.py +54 -1
- ai_edge_torch/generative/examples/stable_diffusion/decoder.py +58 -0
- ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +71 -1
- ai_edge_torch/generative/examples/t5/t5.py +0 -2
- ai_edge_torch/generative/examples/test_models/convert_toy_model.py +105 -0
- ai_edge_torch/generative/examples/test_models/toy_model.py +7 -41
- ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +5 -61
- ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +2 -6
- ai_edge_torch/generative/examples/tiny_llama/verify.py +20 -10
- ai_edge_torch/generative/layers/model_config.py +2 -0
- ai_edge_torch/generative/layers/normalization.py +2 -2
- ai_edge_torch/generative/layers/unet/blocks_2d.py +2 -2
- ai_edge_torch/generative/test/test_model_conversion_large.py +129 -0
- ai_edge_torch/generative/utilities/transformers_verifier.py +42 -0
- ai_edge_torch/generative/utilities/verifier.py +130 -114
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20240924.dist-info → ai_edge_torch_nightly-0.3.0.dev20240928.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20240924.dist-info → ai_edge_torch_nightly-0.3.0.dev20240928.dist-info}/RECORD +41 -30
- {ai_edge_torch_nightly-0.3.0.dev20240924.dist-info → ai_edge_torch_nightly-0.3.0.dev20240928.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240924.dist-info → ai_edge_torch_nightly-0.3.0.dev20240928.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240924.dist-info → ai_edge_torch_nightly-0.3.0.dev20240928.dist-info}/top_level.txt +0 -0

ai_edge_torch/generative/examples/smollm/verify.py
@@ -15,43 +15,53 @@
 
 """Verifies the reauthored SmolLM-135M model."""
 
+import logging
 import pathlib
 
 from absl import app
 from absl import flags
 from ai_edge_torch.generative.examples.smollm import smollm
+from ai_edge_torch.generative.utilities import transformers_verifier
 from ai_edge_torch.generative.utilities import verifier
 import transformers
 
+
 _PROMPTS = flags.DEFINE_multi_string(
     "prompts",
     "What is the meaning of life?",
     "The input prompts to generate answers.",
 )
+_MAX_NEW_TOKENS = flags.DEFINE_integer(
+    "max_new_tokens",
+    30,
+    "The maximum size of the generated tokens.",
+)
 
 
 def main(_):
   checkpoint = "HuggingFaceTB/SmolLM-135M"
-
-
-
-  )
+  logging.info("Loading the original model from: %s", checkpoint)
+  original_model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint)
+
   # Locate the cached dir.
   cached_config_file = transformers.utils.cached_file(
       checkpoint, transformers.utils.CONFIG_NAME
   )
   reauthored_checkpoint = pathlib.Path(cached_config_file).parent
-
+  logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
   reauthored_model = smollm.build_model(reauthored_checkpoint)
 
-
+  logging.info("Loading the tokenizer from: %s", checkpoint)
   tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
 
   verifier.verify_reauthored_model(
-      original_model=
-
-
+      original_model=transformers_verifier.TransformersModelWrapper(
+          original_model
+      ),
+      reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
+      tokenizer=verifier.TokenizerWrapper(tokenizer),
       generate_prompts=_PROMPTS.value,
+      max_new_tokens=_MAX_NEW_TOKENS.value,
      atol=1e-04,
   )
 

ai_edge_torch/generative/examples/stable_diffusion/clip.py
@@ -48,7 +48,7 @@ TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
 
 
 class CLIP(nn.Module):
-  """CLIP text encoder
+  """CLIP text encoder.
 
   For details, see https://arxiv.org/abs/2103.00020
   """
@@ -86,6 +86,7 @@ class CLIP(nn.Module):
 
 
 def get_model_config() -> cfg.ModelConfig:
+  """Get configs for the CLIP of Stable Diffusion v1.5."""
   max_seq_len = 77
   vocab_size = 49408
   num_layers = 12
@@ -97,6 +98,58 @@ def get_model_config() -> cfg.ModelConfig:
       num_heads=num_heads,
       head_dim=embedding_dim // num_heads,
       num_query_groups=num_query_groups,
+      rotary_base=0,
+      rotary_percentage=0.0,
+      qkv_use_bias=True,
+      qkv_transpose_before_split=True,
+      qkv_fused_interleaved=False,
+      output_proj_use_bias=True,
+      enable_kv_cache=False,
+  )
+
+  ff_config = cfg.FeedForwardConfig(
+      type=cfg.FeedForwardType.SEQUENTIAL,
+      activation=cfg.ActivationConfig(cfg.ActivationType.GELU_QUICK),
+      intermediate_size=embedding_dim * 4,
+      use_bias=True,
+  )
+
+  norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.LAYER_NORM)
+
+  block_config = cfg.TransformerBlockConfig(
+      attn_config=attn_config,
+      ff_config=ff_config,
+      pre_attention_norm_config=norm_config,
+      post_attention_norm_config=norm_config,
+  )
+
+  config = cfg.ModelConfig(
+      vocab_size=vocab_size,
+      num_layers=num_layers,
+      max_seq_len=max_seq_len,
+      embedding_dim=embedding_dim,
+      block_configs=block_config,
+      final_norm_config=norm_config,
+      enable_hlfb=True,
+  )
+
+  return config
+
+
+def get_fake_model_config() -> cfg.ModelConfig:
+  """Get fake configs for the CLIP of Stable Diffusion v1.5 for testing."""
+  max_seq_len = 6
+  vocab_size = 100
+  num_layers = 2
+  num_heads = 12
+  num_query_groups = 12
+  embedding_dim = 24
+
+  attn_config = cfg.AttentionConfig(
+      num_heads=num_heads,
+      head_dim=embedding_dim // num_heads,
+      num_query_groups=num_query_groups,
+      rotary_base=0,
       rotary_percentage=0.0,
       qkv_use_bias=True,
       qkv_transpose_before_split=True,
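
The new `get_fake_model_config()` pairs with the conversion tests added in this release (`test_model_conversion_large.py` in the file list above). Below is a minimal sketch of how such a test might exercise it; the assumption that `CLIP` is constructed directly from a `cfg.ModelConfig` and that its forward takes a batch of token ids is inferred from the class shown in this file, not stated in the diff.

```python
import torch
from ai_edge_torch.generative.examples.stable_diffusion import clip

config = clip.get_fake_model_config()  # tiny config: 2 layers, 24-dim embeddings
model = clip.CLIP(config)              # assumption: CLIP(config) as defined above
tokens = torch.zeros((1, config.max_seq_len), dtype=torch.long)
embeddings = model(tokens)             # assumption: forward(token_ids) -> embeddings
print(embeddings.shape)
```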

ai_edge_torch/generative/examples/stable_diffusion/decoder.py
@@ -295,6 +295,64 @@ def get_model_config() -> unet_cfg.AutoEncoderConfig:
           enable_kv_cache=False,
           qkv_transpose_before_split=True,
           qkv_fused_interleaved=False,
+          rotary_base=0,
+          rotary_percentage=0.0,
+      ),
+      enable_hlfb=False,
+  )
+
+  mid_block_config = unet_cfg.MidBlock2DConfig(
+      in_channels=block_out_channels[-1],
+      normalization_config=norm_config,
+      activation_config=layers_cfg.ActivationConfig(
+          layers_cfg.ActivationType.SILU
+      ),
+      num_layers=1,
+      attention_block_config=att_config,
+  )
+
+  config = unet_cfg.AutoEncoderConfig(
+      in_channels=in_channels,
+      latent_channels=latent_channels,
+      out_channels=out_channels,
+      activation_config=layers_cfg.ActivationConfig(
+          layers_cfg.ActivationType.SILU
+      ),
+      block_out_channels=block_out_channels,
+      scaling_factor=scaling_factor,
+      layers_per_block=layers_per_block,
+      normalization_config=norm_config,
+      mid_block_config=mid_block_config,
+  )
+  return config
+
+
+def get_fake_model_config() -> unet_cfg.AutoEncoderConfig:
+  """Get fake configs for the Decoder of Stable Diffusion v1.5 for testing."""
+  in_channels = 3
+  latent_channels = 4
+  out_channels = 3
+  block_out_channels = [2, 4]
+  scaling_factor = 0.18215
+  layers_per_block = 2
+
+  norm_config = layers_cfg.NormalizationConfig(
+      layers_cfg.NormalizationType.GROUP_NORM, group_num=2
+  )
+
+  att_config = unet_cfg.AttentionBlock2DConfig(
+      dim=block_out_channels[-1],
+      normalization_config=norm_config,
+      attention_config=layers_cfg.AttentionConfig(
+          num_heads=1,
+          head_dim=block_out_channels[-1],
+          num_query_groups=1,
+          qkv_use_bias=True,
+          output_proj_use_bias=True,
+          enable_kv_cache=False,
+          qkv_transpose_before_split=True,
+          qkv_fused_interleaved=False,
+          rotary_base=0,
           rotary_percentage=0.0,
       ),
       enable_hlfb=False,
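
As with CLIP, the Decoder gains a tiny test configuration (2 output-channel levels, 4 latent channels). A sketch of the assumed test usage follows; the `Decoder` class name, its constructor argument, and its forward signature are inferred from the surrounding module rather than shown in this hunk.

```python
import torch
from ai_edge_torch.generative.examples.stable_diffusion import decoder

config = decoder.get_fake_model_config()
model = decoder.Decoder(config)            # assumption: Decoder(config)
latents = torch.randn(1, config.latent_channels, 8, 8)
image = model(latents)                     # assumption: forward(latents) -> image tensor
print(image.shape)
```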

ai_edge_torch/generative/examples/stable_diffusion/diffusion.py
@@ -199,6 +199,7 @@ def build_attention_config(
     num_heads,
     dim,
     num_query_groups,
+    rotary_base=0,
     rotary_percentage=0.0,
     qkv_transpose_before_split=True,
     qkv_use_bias=False,
@@ -211,6 +212,7 @@ def build_attention_config(
       num_heads=num_heads,
       head_dim=dim // num_heads,
       num_query_groups=num_query_groups,
+      rotary_base=rotary_base,
       rotary_percentage=rotary_percentage,
       qkv_transpose_before_split=qkv_transpose_before_split,
       qkv_use_bias=qkv_use_bias,
@@ -603,7 +605,7 @@ def get_model_config(batch_size: int) -> unet_cfg.DiffusionModelConfig:
   # Transformer configs.
   transformer_num_attention_heads = 8
   transformer_batch_size = batch_size
-  transformer_cross_attention_dim = 768  # Embedding
+  transformer_cross_attention_dim = 768  # Embedding from CLIP model
   transformer_pre_conv_norm_config = layers_cfg.NormalizationConfig(
       layers_cfg.NormalizationType.GROUP_NORM, epsilon=1e-6, group_num=32
   )
@@ -645,3 +647,71 @@ def get_model_config(batch_size: int) -> unet_cfg.DiffusionModelConfig:
       final_norm_config=final_norm_config,
       final_activation_type=final_activation_type,
   )
+
+
+def get_fake_model_config(batch_size: int) -> unet_cfg.DiffusionModelConfig:
+  """Get fake configs for the Diffusion model of Stable Diffusion v1.5 for testing.
+
+  Args:
+    batch_size (int): the batch size of input.
+
+  Retruns:
+    The configuration of diffusion model of Stable Diffusion v1.5.
+  """
+  in_channels = 4
+  out_channels = 4
+  block_out_channels = [2, 4, 8, 8]
+  layers_per_block = 1
+  downsample_padding = 1
+
+  # Residual configs.
+  residual_norm_config = layers_cfg.NormalizationConfig(
+      layers_cfg.NormalizationType.GROUP_NORM, group_num=2
+  )
+  residual_activation_type = layers_cfg.ActivationType.SILU
+
+  # Transformer configs.
+  transformer_num_attention_heads = 1
+  transformer_batch_size = batch_size
+  transformer_cross_attention_dim = 4  # Embedding from CLIP model
+  transformer_pre_conv_norm_config = layers_cfg.NormalizationConfig(
+      layers_cfg.NormalizationType.GROUP_NORM, epsilon=1e-6, group_num=2
+  )
+  transformer_norm_config = layers_cfg.NormalizationConfig(
+      layers_cfg.NormalizationType.LAYER_NORM
+  )
+  transformer_ff_activation_type = layers_cfg.ActivationType.GE_GLU
+
+  # Time embedding configs.
+  time_embedding_dim = 2
+  time_embedding_blocks_dim = 4
+
+  # Mid block configs.
+  mid_block_layers = 1
+
+  # Finaly layer configs.
+  final_norm_config = layers_cfg.NormalizationConfig(
+      layers_cfg.NormalizationType.GROUP_NORM, group_num=2
+  )
+  final_activation_type = layers_cfg.ActivationType.SILU
+
+  return unet_cfg.DiffusionModelConfig(
+      in_channels=in_channels,
+      out_channels=out_channels,
+      block_out_channels=block_out_channels,
+      layers_per_block=layers_per_block,
+      downsample_padding=downsample_padding,
+      residual_norm_config=residual_norm_config,
+      residual_activation_type=residual_activation_type,
+      transformer_batch_size=transformer_batch_size,
+      transformer_num_attention_heads=transformer_num_attention_heads,
+      transformer_cross_attention_dim=transformer_cross_attention_dim,
+      transformer_pre_conv_norm_config=transformer_pre_conv_norm_config,
+      transformer_norm_config=transformer_norm_config,
+      transformer_ff_activation_type=transformer_ff_activation_type,
+      mid_block_layers=mid_block_layers,
+      time_embedding_dim=time_embedding_dim,
+      time_embedding_blocks_dim=time_embedding_blocks_dim,
+      final_norm_config=final_norm_config,
+      final_activation_type=final_activation_type,
+  )
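
The diffusion UNet follows the same two themes as the other files: `build_attention_config` now threads a `rotary_base` value through to `cfg.AttentionConfig`, and a fake config enables small, fast conversion tests. A sketch of the assumed test usage; the `Diffusion` class name and its single-config constructor are taken from the surrounding module, not from this diff.

```python
from ai_edge_torch.generative.examples.stable_diffusion import diffusion

config = diffusion.get_fake_model_config(batch_size=2)
model = diffusion.Diffusion(config)   # assumption: Diffusion(config)
num_params = sum(p.numel() for p in model.parameters())
print(f"fake diffusion model has {num_params} parameters")
```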

ai_edge_torch/generative/examples/test_models/convert_toy_model.py (new file)
@@ -0,0 +1,105 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+# A toy example which has a single-layer transformer block.
+from absl import app
+import ai_edge_torch
+from ai_edge_torch import lowertools
+from ai_edge_torch.generative.examples.test_models import toy_model
+from ai_edge_torch.generative.examples.test_models import toy_model_with_kv_cache
+from ai_edge_torch.generative.layers import kv_cache as kv_utils
+import torch
+
+KV_CACHE_MAX_LEN = 100
+
+
+def convert_toy_model(_) -> None:
+  """Converts a toy model to tflite."""
+  model = toy_model.ToySingleLayerModel(toy_model.get_model_config())
+  idx = torch.unsqueeze(torch.arange(0, KV_CACHE_MAX_LEN), 0)
+  input_pos = torch.arange(0, KV_CACHE_MAX_LEN)
+  print('running an inference')
+  print(
+      model.forward(
+          idx,
+          input_pos,
+      )
+  )
+
+  # Convert model to tflite.
+  print('converting model to tflite')
+  edge_model = ai_edge_torch.convert(
+      model,
+      (
+          idx,
+          input_pos,
+      ),
+  )
+  edge_model.export('/tmp/toy_model.tflite')
+
+
+def _export_stablehlo_mlir(model, args):
+  ep = torch.export.export(model, args)
+  return lowertools.exported_program_to_mlir_text(ep)
+
+
+def convert_toy_model_with_kv_cache(_) -> None:
+  """Converts a toy model with kv cache to tflite."""
+  dump_mlir = False
+
+  config = toy_model_with_kv_cache.get_model_config()
+  model = toy_model_with_kv_cache.ToyModelWithKVCache(config)
+  model.eval()
+  print('running an inference')
+  kv = kv_utils.KVCache.from_model_config(config)
+
+  tokens, input_pos = toy_model_with_kv_cache.get_sample_prefill_inputs()
+  decode_token, decode_input_pos = (
+      toy_model_with_kv_cache.get_sample_decode_inputs()
+  )
+  print(model.forward(tokens, input_pos, kv))
+
+  if dump_mlir:
+    mlir_text = _export_stablehlo_mlir(model, (tokens, input_pos, kv))
+    with open('/tmp/toy_model_with_external_kv.stablehlo.mlir', 'w') as f:
+      f.write(mlir_text)
+
+  # Convert model to tflite with 2 signatures (prefill + decode).
+  print('converting toy model to tflite with 2 signatures (prefill + decode)')
+  edge_model = (
+      ai_edge_torch.signature(
+          'prefill',
+          model,
+          sample_kwargs={
+              'tokens': tokens,
+              'input_pos': input_pos,
+              'kv_cache': kv,
+          },
+      )
+      .signature(
+          'decode',
+          model,
+          sample_kwargs={
+              'tokens': decode_token,
+              'input_pos': decode_input_pos,
+              'kv_cache': kv,
+          },
+      )
+      .convert()
+  )
+  edge_model.export('/tmp/toy_external_kv_cache.tflite')
+
+
+if __name__ == '__main__':
+  app.run(convert_toy_model)
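
This new script consolidates the runner code removed from toy_model.py and toy_model_with_kv_cache.py (see the removals below). As written, `app.run(convert_toy_model)` only exports the plain toy model; a sketch of swapping the entry point to export the external-KV-cache variant instead (not part of the diff, both functions are defined in the file above):

```python
# Export the toy model with external KV cache instead of the plain toy model.
if __name__ == '__main__':
  app.run(convert_toy_model_with_kv_cache)
```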

ai_edge_torch/generative/examples/test_models/toy_model.py
@@ -15,13 +15,12 @@
 # A toy example which has a single-layer transformer block.
 from typing import Tuple
 
-import
+from ai_edge_torch.generative.layers import builder
 from ai_edge_torch.generative.layers.attention import TransformerBlock
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
-import ai_edge_torch.generative.layers.builder as builder
 import ai_edge_torch.generative.layers.model_config as cfg
 import torch
-
+from torch import nn
 
 RoPECache = Tuple[torch.Tensor, torch.Tensor]
 KV_CACHE_MAX_LEN = 100
@@ -45,13 +44,10 @@ class ToySingleLayerModel(torch.nn.Module):
     self.rope_cache = attn_utils.build_rope_cache(
         size=config.max_seq_len,
         dim=int(attn_config.rotary_percentage * attn_config.head_dim),
-        base=
-        condense_ratio=1,
-        dtype=torch.float32,
-        device=torch.device('cpu'),
+        base=attn_config.rotary_base,
     )
     self.mask_cache = attn_utils.build_causal_mask_cache(
-        size=config.max_seq_len,
+        size=config.max_seq_len,
     )
     self.config = config
 
@@ -94,13 +90,10 @@ class ToySingleLayerModelWeightSharing(torch.nn.Module):
     self.rope_cache = attn_utils.build_rope_cache(
         size=config.max_seq_len,
         dim=int(attn_config.rotary_percentage * attn_config.head_dim),
-        base=
-        condense_ratio=1,
-        dtype=torch.float32,
-        device=torch.device('cpu'),
+        base=attn_config.rotary_base,
     )
     self.mask_cache = attn_utils.build_causal_mask_cache(
-        size=config.max_seq_len,
+        size=config.max_seq_len,
     )
     self.config = config
 
@@ -125,6 +118,7 @@ def get_model_config() -> cfg.ModelConfig:
       num_heads=32,
       head_dim=4,
       num_query_groups=4,
+      rotary_base=10000,
      rotary_percentage=1.0,
       enable_kv_cache=False,
   )
@@ -149,31 +143,3 @@ def get_model_config() -> cfg.ModelConfig:
       final_norm_config=norm_config,
   )
   return config
-
-
-def define_and_run() -> None:
-  model = ToySingleLayerModel(get_model_config())
-  idx = torch.unsqueeze(torch.arange(0, KV_CACHE_MAX_LEN), 0)
-  input_pos = torch.arange(0, KV_CACHE_MAX_LEN)
-  print('running an inference')
-  print(
-      model.forward(
-          idx,
-          input_pos,
-      )
-  )
-
-  # Convert model to tflite.
-  print('converting model to tflite')
-  edge_model = ai_edge_torch.convert(
-      model,
-      (
-          idx,
-          input_pos,
-      ),
-  )
-  edge_model.export('/tmp/toy_model.tflite')
-
-
-if __name__ == '__main__':
-  define_and_run()
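
Across these model files the recurring change is the same: the RoPE base frequency is no longer hard-coded at the `build_rope_cache` call site but carried on the attention config via the new `rotary_base` field (also added to `model_config.py`, per the file list). A minimal sketch of the resulting pattern, using only names visible in the diff; the concrete values are illustrative:

```python
import ai_edge_torch.generative.layers.attention_utils as attn_utils
import ai_edge_torch.generative.layers.model_config as cfg

attn_config = cfg.AttentionConfig(
    num_heads=32,
    head_dim=4,
    num_query_groups=4,
    rotary_base=10000,   # new field: the RoPE base frequency lives on the config
    rotary_percentage=1.0,
    enable_kv_cache=False,
)
# The rope cache is built from the config instead of a hard-coded base value.
rope_cache = attn_utils.build_rope_cache(
    size=100,
    dim=int(attn_config.rotary_percentage * attn_config.head_dim),
    base=attn_config.rotary_base,
)
```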

ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py
@@ -17,15 +17,14 @@
 
 from typing import Tuple
 
-import
-from ai_edge_torch import lowertools
+from absl import app
 from ai_edge_torch.generative.layers import attention
 from ai_edge_torch.generative.layers import builder
 from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
 import ai_edge_torch.generative.layers.model_config as cfg
 import torch
-
+from torch import nn
 
 RoPECache = Tuple[torch.Tensor, torch.Tensor]
 
@@ -52,13 +51,10 @@ class ToyModelWithKVCache(torch.nn.Module):
     self.rope_cache = attn_utils.build_rope_cache(
         size=config.max_seq_len,
         dim=int(attn_config.rotary_percentage * attn_config.head_dim),
-        base=
-        condense_ratio=1,
-        dtype=torch.float32,
-        device=torch.device('cpu'),
+        base=attn_config.rotary_base,
     )
     self.mask_cache = attn_utils.build_causal_mask_cache(
-        size=config.max_seq_len,
+        size=config.max_seq_len,
     )
     self.config = config
 
@@ -87,16 +83,12 @@ class ToyModelWithKVCache(torch.nn.Module):
     return {'logits': self.lm_head(x), 'kv_cache': updated_kv_cache}
 
 
-def _export_stablehlo_mlir(model, args):
-  ep = torch.export.export(model, args)
-  return lowertools.exported_program_to_mlir_text(ep)
-
-
 def get_model_config() -> cfg.ModelConfig:
   attn_config = cfg.AttentionConfig(
       num_heads=32,
       head_dim=4,
       num_query_groups=4,
+      rotary_base=10000,
       rotary_percentage=1.0,
   )
   ff_config = cfg.FeedForwardConfig(
@@ -133,51 +125,3 @@ def get_sample_decode_inputs() -> Tuple[torch.Tensor, torch.Tensor]:
   tokens = torch.tensor([[1]], dtype=torch.int)
   input_pos = torch.tensor([10])
   return tokens, input_pos
-
-
-def define_and_run() -> None:
-  dump_mlir = False
-
-  config = get_model_config()
-  model = ToyModelWithExternalKV(config)
-  model.eval()
-  print('running an inference')
-  kv = kv_utils.KVCache.from_model_config(config)
-
-  tokens, input_pos = get_sample_prefill_inputs()
-  decode_token, decode_input_pos = get_sample_decode_inputs()
-  print(model.forward(tokens, input_pos, kv))
-
-  if dump_mlir:
-    mlir_text = _export_stablehlo_mlir(model, (tokens, input_pos, kv))
-    with open('/tmp/toy_model_with_external_kv.stablehlo.mlir', 'w') as f:
-      f.write(mlir_text)
-
-  # Convert model to tflite with 2 signatures (prefill + decode).
-  print('converting toy model to tflite with 2 signatures (prefill + decode)')
-  edge_model = (
-      ai_edge_torch.signature(
-          'prefill',
-          model,
-          sample_kwargs={
-              'tokens': tokens,
-              'input_pos': input_pos,
-              'kv_cache': kv,
-          },
-      )
-      .signature(
-          'decode',
-          model,
-          sample_kwargs={
-              'tokens': decode_token,
-              'input_pos': decode_input_pos,
-              'kv_cache': kv,
-          },
-      )
-      .convert()
-  )
-  edge_model.export('/tmp/toy_external_kv_cache.tflite')
-
-
-if __name__ == '__main__':
-  define_and_run()
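
With `define_and_run()` and `_export_stablehlo_mlir()` removed from both toy-model modules, the conversion flow now lives entirely in the new convert_toy_model.py shown earlier. A sketch of invoking the relocated KV-cache conversion directly; the functions take and ignore a single positional argument, matching the absl `app.run` convention used in that file:

```python
from ai_edge_torch.generative.examples.test_models import convert_toy_model

# Equivalent to the removed define_and_run(), now driven from the new module.
convert_toy_model.convert_toy_model_with_kv_cache(None)
```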

ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py
@@ -67,15 +67,10 @@ class TinyLlama(nn.Module):
     self.rope_cache = attn_utils.build_rope_cache(
         size=config.kv_cache_max,
         dim=int(attn_config.rotary_percentage * attn_config.head_dim),
-        base=
-        condense_ratio=1,
-        dtype=torch.float32,
-        device=torch.device("cpu"),
+        base=attn_config.rotary_base,
     )
     self.mask_cache = attn_utils.build_causal_mask_cache(
         size=config.kv_cache_max,
-        dtype=torch.float32,
-        device=torch.device("cpu"),
     )
     self.config = config
 
@@ -132,6 +127,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       num_heads=32,
       head_dim=64,
       num_query_groups=4,
+      rotary_base=10000,
      rotary_percentage=1.0,
   )
   ff_config = cfg.FeedForwardConfig(

ai_edge_torch/generative/examples/tiny_llama/verify.py
@@ -15,45 +15,55 @@
 
 """Verifies the reauthored TinyLlama-1.1B model."""
 
+import logging
 import pathlib
 
 from absl import app
 from absl import flags
 from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
+from ai_edge_torch.generative.utilities import transformers_verifier
 from ai_edge_torch.generative.utilities import verifier
 import transformers
 
+
 _PROMPTS = flags.DEFINE_multi_string(
     "prompts",
     "Show me the program to add 2 and 3.",
     "The input prompts to generate answers.",
 )
+_MAX_NEW_TOKENS = flags.DEFINE_integer(
+    "max_new_tokens",
+    30,
+    "The maximum size of the generated tokens.",
+)
 
 
 def main(_):
   checkpoint = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
-
-
-
-          checkpoint, trust_remote_code=True
-      ),
+  logging.info("Loading the original model from: %s", checkpoint)
+  original_model = transformers.AutoModelForCausalLM.from_pretrained(
+      checkpoint, trust_remote_code=True
   )
+
   # Locate the cached dir.
   cached_config_file = transformers.utils.cached_file(
       checkpoint, transformers.utils.CONFIG_NAME
   )
   reauthored_checkpoint = pathlib.Path(cached_config_file).parent
-
+  logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
   reauthored_model = tiny_llama.build_model(reauthored_checkpoint)
 
-
+  logging.info("Loading the tokenizer from: %s", checkpoint)
   tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
 
   verifier.verify_reauthored_model(
-      original_model=
-
-
+      original_model=transformers_verifier.TransformersModelWrapper(
+          original_model
+      ),
+      reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
+      tokenizer=verifier.TokenizerWrapper(tokenizer),
       generate_prompts=_PROMPTS.value,
+      max_new_tokens=_MAX_NEW_TOKENS.value,
       atol=1e-04,
   )
 
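
Both verification scripts (SmolLM and TinyLlama) now take a `--max_new_tokens` flag alongside `--prompts`. A sketch of launching the TinyLlama verifier from Python with these flags overridden; the module path comes from the file list above and the flag names from the diff, while the argv-based override is an assumption about how one might drive the absl entry point:

```python
import sys

from absl import app
from ai_edge_torch.generative.examples.tiny_llama import verify

if __name__ == "__main__":
  # Override the flags defined in verify.py before absl parses sys.argv.
  sys.argv += [
      "--prompts=Show me the program to add 2 and 3.",
      "--max_new_tokens=16",
  ]
  app.run(verify.main)
```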