ai-edge-torch-nightly 0.3.0.dev20240909__py3-none-any.whl → 0.3.0.dev20240913__py3-none-any.whl
This diff covers publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registry.
- ai_edge_torch/_convert/test/test_convert.py +35 -13
- ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +31 -12
- ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +25 -6
- ai_edge_torch/generative/examples/gemma/gemma.py +50 -30
- ai_edge_torch/generative/examples/gemma/gemma2.py +85 -58
- ai_edge_torch/generative/examples/{experimental/phi → phi}/convert_to_tflite.py +11 -12
- ai_edge_torch/generative/examples/{experimental/phi → phi}/phi2.py +46 -43
- ai_edge_torch/generative/examples/{experimental/gemma → smallm}/convert_to_tflite.py +12 -14
- ai_edge_torch/generative/examples/smallm/smallm.py +122 -0
- ai_edge_torch/generative/examples/stable_diffusion/clip.py +11 -5
- ai_edge_torch/generative/examples/t5/t5.py +35 -22
- ai_edge_torch/generative/examples/t5/t5_attention.py +18 -13
- ai_edge_torch/generative/examples/test_models/toy_model.py +15 -13
- ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +74 -33
- ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +25 -6
- ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +55 -34
- ai_edge_torch/generative/layers/attention.py +77 -73
- ai_edge_torch/generative/layers/builder.py +5 -3
- ai_edge_torch/generative/layers/kv_cache.py +163 -51
- ai_edge_torch/generative/layers/model_config.py +38 -19
- ai_edge_torch/generative/layers/normalization.py +158 -0
- ai_edge_torch/generative/layers/unet/blocks_2d.py +0 -2
- ai_edge_torch/generative/test/{test_experimental_ekv.py → test_kv_cache.py} +12 -24
- ai_edge_torch/generative/test/test_loader.py +1 -1
- ai_edge_torch/generative/test/test_model_conversion.py +72 -34
- ai_edge_torch/generative/test/test_model_conversion_large.py +51 -23
- ai_edge_torch/generative/test/utils.py +54 -0
- ai_edge_torch/generative/utilities/loader.py +15 -15
- ai_edge_torch/generative/utilities/t5_loader.py +21 -20
- ai_edge_torch/odml_torch/lowerings/__init__.py +1 -0
- ai_edge_torch/odml_torch/lowerings/_convolution.py +196 -74
- ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +0 -2
- ai_edge_torch/odml_torch/lowerings/_layer_norm.py +78 -0
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20240909.dist-info → ai_edge_torch_nightly-0.3.0.dev20240913.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20240909.dist-info → ai_edge_torch_nightly-0.3.0.dev20240913.dist-info}/RECORD +41 -47
- ai_edge_torch/generative/examples/experimental/gemma/gemma.py +0 -219
- ai_edge_torch/generative/examples/experimental/phi/__init__.py +0 -14
- ai_edge_torch/generative/examples/experimental/tiny_llama/__init__.py +0 -14
- ai_edge_torch/generative/examples/experimental/tiny_llama/convert_to_tflite.py +0 -87
- ai_edge_torch/generative/examples/experimental/tiny_llama/tiny_llama.py +0 -205
- ai_edge_torch/generative/examples/phi2/__init__.py +0 -14
- ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +0 -67
- ai_edge_torch/generative/examples/phi2/phi2.py +0 -189
- ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +0 -176
- /ai_edge_torch/generative/examples/{experimental → phi}/__init__.py +0 -0
- /ai_edge_torch/generative/examples/{experimental/gemma → smallm}/__init__.py +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240909.dist-info → ai_edge_torch_nightly-0.3.0.dev20240913.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240909.dist-info → ai_edge_torch_nightly-0.3.0.dev20240913.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240909.dist-info → ai_edge_torch_nightly-0.3.0.dev20240913.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/examples/smallm/smallm.py (new file):

@@ -0,0 +1,122 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Example of building a SmalLM model."""
+
+import copy
+import os
+import pathlib
+
+from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
+from ai_edge_torch.generative.layers import kv_cache as kv_utils
+import ai_edge_torch.generative.layers.model_config as cfg
+import ai_edge_torch.generative.utilities.loader as loading_utils
+import numpy as np
+import torch
+from torch import nn
+
+TENSOR_NAMES = copy.copy(tiny_llama.TENSOR_NAMES)
+# SmalLM re-uses the embedding as the head projection layer.
+TENSOR_NAMES.lm_head = None
+
+
+class SmalLM(tiny_llama.TinyLlama):
+  """A SmalLM model built from the Edge Generative API layers.
+
+  SmalLM shares the same architecture as TinyLlama, but with different model
+  sizes.
+  """
+
+  def __init__(self, config: cfg.ModelConfig):
+    super().__init__(config)
+    # SmalLM re-uses the embedding as the head projection layer.
+    self.lm_head.weight.data = self.tok_embedding.weight.data
+
+
+def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
+  """Returns the model config for a SmalLM 135M model.
+
+  Args:
+    kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
+      is 1024.
+
+  Returns:
+    The model config for a SmalLM model.
+  """
+  attn_config = cfg.AttentionConfig(
+      num_heads=9,
+      head_dim=64,
+      num_query_groups=3,
+      rotary_percentage=1.0,
+  )
+  ff_config = cfg.FeedForwardConfig(
+      type=cfg.FeedForwardType.GATED,
+      activation=cfg.ActivationConfig(cfg.ActivationType.SILU),
+      intermediate_size=1536,
+  )
+  norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.RMS_NORM)
+  block_config = cfg.TransformerBlockConfig(
+      attn_config=attn_config,
+      ff_config=ff_config,
+      pre_attention_norm_config=norm_config,
+      post_attention_norm_config=norm_config,
+  )
+  config = cfg.ModelConfig(
+      vocab_size=49152,
+      num_layers=30,
+      max_seq_len=2048,
+      embedding_dim=576,
+      kv_cache_max_len=kv_cache_max_len,
+      block_configs=block_config,
+      final_norm_config=norm_config,
+      enable_hlfb=True,
+  )
+  return config
+
+
+def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
+  config = get_model_config(**kwargs)
+  model = SmalLM(config)
+  loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
+  # Since embedding and lm-head use the same weight, we need to set strict
+  # to False.
+  loader.load(model, strict=False)
+  model.eval()
+  return model
+
+
+def define_and_run(checkpoint_path: str) -> None:
+  """Instantiates and runs a SmalLM model."""
+
+  current_dir = pathlib.Path(__file__).parent.resolve()
+  smallm_goldens = torch.load(current_dir / "smallm_lm_logits.pt")
+  kv_cache_max_len = 1024
+  model = build_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
+  idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
+  tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.long, device="cpu")
+  tokens[0, :4] = idx
+  input_pos = torch.arange(0, kv_cache_max_len)
+  kv = kv_utils.KVCache.from_model_config(model.config)
+  output = model.forward(tokens, input_pos, kv)
+  assert torch.allclose(
+      smallm_goldens, output["logits"][0, idx.shape[1] - 1, :], atol=1e-05
+  )
+
+
+if __name__ == "__main__":
+  input_checkpoint_path = os.path.join(
+      pathlib.Path.home(), "Downloads/llm_data/smallm"
+  )
+  define_and_run(input_checkpoint_path)
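
For orientation, here is a minimal sketch (not part of the package) of how this new example module is exercised with the externalized KV cache, following the `define_and_run()` flow above. The checkpoint path is the same assumed location the example uses; everything else mirrors calls visible in the new file.

```python
import pathlib

import torch

from ai_edge_torch.generative.examples.smallm import smallm
from ai_edge_torch.generative.layers import kv_cache as kv_utils

# Assumed checkpoint location, matching the default used in define_and_run().
checkpoint = str(pathlib.Path.home() / "Downloads/llm_data/smallm")
model = smallm.build_model(checkpoint, kv_cache_max_len=1024)

# Prefill a 4-token prompt padded out to the KV cache length.
tokens = torch.zeros((1, 1024), dtype=torch.long)
tokens[0, :4] = torch.tensor([1, 2, 3, 4])
input_pos = torch.arange(0, 1024)

# The KV cache is now an explicit forward() argument instead of module state.
kv = kv_utils.KVCache.from_model_config(model.config)
output = model.forward(tokens, input_pos, kv)
logits = output["logits"]
```

Passing the cache explicitly is what lets the prefill and decode conversion signatures (shown later in this diff) share one cache structure.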
ai_edge_torch/generative/examples/stable_diffusion/clip.py:

@@ -61,8 +61,10 @@ class CLIP(nn.Module):
     )
 
     self.config = config
+    # CLIP has only one block config.
+    block_config = config.block_config(0)
     self.transformer_blocks = nn.ModuleList(
-        TransformerBlock(config) for _ in range(config.num_layers)
+        TransformerBlock(block_config, config) for _ in range(config.num_layers)
     )
     self.final_norm = builder.build_norm(
         config.embedding_dim, config.final_norm_config

@@ -112,15 +114,19 @@ def get_model_config() -> cfg.ModelConfig:
 
   norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.LAYER_NORM)
 
+  block_config = cfg.TransformerBlockConfig(
+      attn_config=attn_config,
+      ff_config=ff_config,
+      pre_attention_norm_config=norm_config,
+      post_attention_norm_config=norm_config,
+  )
+
   config = cfg.ModelConfig(
       vocab_size=vocab_size,
       num_layers=num_layers,
       max_seq_len=max_seq_len,
       embedding_dim=embedding_dim,
-      attn_config=attn_config,
-      ff_config=ff_config,
-      pre_attention_norm_config=norm_config,
-      post_attention_norm_config=norm_config,
+      block_configs=block_config,
       final_norm_config=norm_config,
       enable_hlfb=True,
   )
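
The CLIP change above is one instance of the release-wide config migration: per-block options move off `ModelConfig` onto a new `TransformerBlockConfig`, and models read them back through `config.block_config(0)`. A minimal sketch of the new wiring, reusing the SmalLM values shown earlier in this diff:

```python
import ai_edge_torch.generative.layers.model_config as cfg

# Per-block options (attention, feed-forward, norms) now live on a
# TransformerBlockConfig; the values mirror the SmalLM config above.
attn_config = cfg.AttentionConfig(
    num_heads=9,
    head_dim=64,
    num_query_groups=3,
    rotary_percentage=1.0,
)
ff_config = cfg.FeedForwardConfig(
    type=cfg.FeedForwardType.GATED,
    activation=cfg.ActivationConfig(cfg.ActivationType.SILU),
    intermediate_size=1536,
)
norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.RMS_NORM)
block_config = cfg.TransformerBlockConfig(
    attn_config=attn_config,
    ff_config=ff_config,
    pre_attention_norm_config=norm_config,
    post_attention_norm_config=norm_config,
)

# ModelConfig now takes block_configs in place of the old top-level
# attn_config / ff_config / *_norm_config fields.
config = cfg.ModelConfig(
    vocab_size=49152,
    num_layers=30,
    max_seq_len=2048,
    embedding_dim=576,
    block_configs=block_config,
    final_norm_config=norm_config,
    enable_hlfb=True,
)

# Models with a single shared block configuration read it back by index.
attn_heads = config.block_config(0).attn_config.num_heads
```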
ai_edge_torch/generative/examples/t5/t5.py:

@@ -52,9 +52,15 @@ class T5Stack(nn.Module):
     self.config = config
     self.embed_tokens = embed_tokens
     self.is_decoder = config.is_decoder
+    # T5 has only one block config.
+    block_config = config.block_config(0)
     self.transformer_blocks = nn.ModuleList([
-        EncoderDecoderBlock(config, has_relative_attention_bias=bool(idx == 0))
-        for idx in range(config.num_layers)
+        EncoderDecoderBlock(
+            block_config,
+            config,
+            has_relative_attention_bias=bool(idx == 0),
+        )
+        for idx in range(config.num_layers)
     ])
     self.final_norm = builder.build_norm(
         config.embedding_dim, config.final_norm_config

@@ -73,13 +79,11 @@ class T5Stack(nn.Module):
           torch.Tensor
       ] = None,  # should be for decoder case
   ):
-    input_shape = input_ids.size()
     inputs_embeds = self.embed_tokens(input_ids)
-    batch_size, seq_length = input_shape
     hidden_states = inputs_embeds
     position_bias = None
     encoder_decoder_position_bias = None
-    for i, layer_module in enumerate(self.transformer_blocks):
+    for _, layer_module in enumerate(self.transformer_blocks):
       # EncoderDecoderBlock.forward
       hidden_states, position_bias, encoder_decoder_position_bias = (
           layer_module(

@@ -111,7 +115,8 @@ class T5(nn.Module):
 
     encoder_config = copy.deepcopy(config)
     encoder_config.is_decoder = False
-    encoder_config.attn_config.enable_kv_cache = False
+    # T5 has only one block config.
+    encoder_config.block_config(0).attn_config.enable_kv_cache = False
     self.encoder = T5Stack(encoder_config, self.tok_embedding)
 
     decoder_config = copy.deepcopy(config)

@@ -137,20 +142,22 @@ class T5(nn.Module):
         device=torch.device("cpu"),
     )
 
+    # T5 has only one block config.
+    attn_config = config.block_config(0).attn_config
     self.enc_rel_pos_mask = attn_utils.build_relative_position_buckets(
         bidirectional=True,
         query_length=config.kv_cache_max,
         key_length=config.kv_cache_max,
-        num_buckets=config.attn_config.relative_attention_num_buckets,
-        max_distance=config.attn_config.relative_attention_max_distance,
+        num_buckets=attn_config.relative_attention_num_buckets,
+        max_distance=attn_config.relative_attention_max_distance,
     )
 
     self.dec_rel_pos_mask = attn_utils.build_relative_position_buckets(
         bidirectional=False,
         query_length=config.kv_cache_max,
         key_length=config.kv_cache_max,
-        num_buckets=config.attn_config.relative_attention_num_buckets,
-        max_distance=config.attn_config.relative_attention_max_distance,
+        num_buckets=attn_config.relative_attention_num_buckets,
+        max_distance=attn_config.relative_attention_max_distance,
     )
 
   @torch.inference_mode

@@ -230,7 +237,8 @@ class T5Encoder(nn.Module):
 
     encoder_config = copy.deepcopy(config)
     encoder_config.is_decoder = False
-    encoder_config.attn_config.enable_kv_cache = False
+    # T5 has only one block config.
+    encoder_config.block_config(0).attn_config.enable_kv_cache = False
     self.encoder = T5Stack(encoder_config, self.tok_embedding)
 
     self.enc_attn_mask_cache = (

@@ -243,12 +251,14 @@ class T5Encoder(nn.Module):
         .unsqueeze(0)
     )
 
+    # T5 has only one block config.
+    attn_config = config.block_config(0).attn_config
     self.enc_rel_pos_mask = attn_utils.build_relative_position_buckets(
         bidirectional=True,
         query_length=config.kv_cache_max,
         key_length=config.kv_cache_max,
-        num_buckets=config.attn_config.relative_attention_num_buckets,
-        max_distance=config.attn_config.relative_attention_max_distance,
+        num_buckets=attn_config.relative_attention_num_buckets,
+        max_distance=attn_config.relative_attention_max_distance,
     )
 
   @torch.inference_mode

@@ -313,12 +323,14 @@ class T5Decoder(nn.Module):
         .unsqueeze(0)
     )
 
+    # T5 has only one block config.
+    attn_config = config.block_config(0).attn_config
     self.enc_rel_pos_mask = attn_utils.build_relative_position_buckets(
         bidirectional=True,
         query_length=config.kv_cache_max,
         key_length=config.kv_cache_max,
-        num_buckets=config.attn_config.relative_attention_num_buckets,
-        max_distance=config.attn_config.relative_attention_max_distance,
+        num_buckets=attn_config.relative_attention_num_buckets,
+        max_distance=attn_config.relative_attention_max_distance,
     )
 
     self.dec_attn_mask_cache = attn_utils.build_causal_mask_cache(

@@ -386,19 +398,20 @@ def get_model_config_t5() -> cfg.ModelConfig:
       type=cfg.NormalizationType.RMS_NORM,
       epsilon=1e-6,
   )
-
-  config = cfg.ModelConfig(
-      vocab_size=32128,
-      num_layers=12,
-      max_seq_len=512,
-      embedding_dim=768,
+  block_config = cfg.TransformerBlockConfig(
       attn_config=attn_config,
       relative_attention=True,
       ff_config=ff_config,
       pre_attention_norm_config=norm_config,
       post_attention_norm_config=norm_config,
+  )
+  config = cfg.ModelConfig(
+      vocab_size=32128,
+      num_layers=12,
+      max_seq_len=512,
+      embedding_dim=768,
+      block_configs=block_config,
       final_norm_config=norm_config,
-      parallel_residual=False,
       lm_head_use_bias=False,
       enable_hlfb=True,
   )
ai_edge_torch/generative/examples/t5/t5_attention.py:

@@ -24,7 +24,6 @@ from ai_edge_torch.generative.layers.scaled_dot_product_attention import scaled_
 from ai_edge_torch.generative.layers.scaled_dot_product_attention import scaled_dot_product_attention_with_hlfb  # NOQA
 import torch
 from torch import nn
-import torch.nn.functional as F
 
 BATCH_SIZE = 1
 

@@ -32,13 +31,18 @@ BATCH_SIZE = 1
 class EncoderDecoderBlock(nn.Module):
 
   def __init__(
-      self, config: cfg.ModelConfig, has_relative_attention_bias: bool = False
+      self,
+      config: cfg.TransformerBlockConfig,
+      model_config: cfg.ModelConfig,
+      has_relative_attention_bias: bool = False,
   ) -> None:
     """Initialize an instance of the EncoderDecoderBlock.
 
     Args:
-      config (cfg.ModelConfig): the configuration object for this transformer
-        block.
+      config (cfg.TransformerBlockConfig): the configuration object for this
+        transformer block.
+      model_config (cfg.ModelConfig): the configuration object for the model
+        this transformer block belongs to.
       has_relative_attention_bias (bool): whether the self attention block has
         relative bias.
     """

@@ -46,22 +50,22 @@ class EncoderDecoderBlock(nn.Module):
     super().__init__()
     self.atten_func = T5Attention(
         BATCH_SIZE,
-        config.embedding_dim,
+        model_config.embedding_dim,
         config.attn_config,
         config.pre_attention_norm_config,
-        config.kv_cache_max,
-        config.enable_hlfb,
+        model_config.kv_cache_max,
+        model_config.enable_hlfb,
         has_relative_attention_bias=has_relative_attention_bias,
     )
     # For a decoder, we add a cross attention.
-    if config.is_decoder:
+    if model_config.is_decoder:
       self.cross_atten_func = T5Attention(
           BATCH_SIZE,
-          config.embedding_dim,
+          model_config.embedding_dim,
           config.attn_config,
           config.pre_attention_norm_config,
-          config.kv_cache_max,
-          config.enable_hlfb,
+          model_config.kv_cache_max,
+          model_config.enable_hlfb,
           # Cross Attention does not have relative attention bias.
           has_relative_attention_bias=False,
       )

@@ -69,9 +73,10 @@ class EncoderDecoderBlock(nn.Module):
       self.cross_atten_func = None
 
     self.post_atten_norm = builder.build_norm(
-        config.embedding_dim, config.post_attention_norm_config
+        model_config.embedding_dim,
+        config.post_attention_norm_config,
     )
-    self.ff = builder.build_ff(config.embedding_dim, config.ff_config)
+    self.ff = builder.build_ff(model_config.embedding_dim, config.ff_config)
     self.config = config
 
   def forward(
ai_edge_torch/generative/examples/test_models/toy_model.py:

@@ -20,7 +20,6 @@ from ai_edge_torch.generative.layers.attention import TransformerBlock
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
 import ai_edge_torch.generative.layers.builder as builder
 import ai_edge_torch.generative.layers.model_config as cfg
-import numpy as np
 import torch
 import torch.nn as nn
 

@@ -36,16 +35,16 @@ class ToySingleLayerModel(torch.nn.Module):
         config.embedding_dim, config.vocab_size, bias=config.lm_head_use_bias
     )
     self.tok_embedding = nn.Embedding(config.vocab_size, config.embedding_dim)
-    self.transformer_block = TransformerBlock(config)
+    self.transformer_block = TransformerBlock(config.block_config(0), config)
     self.final_norm = builder.build_norm(
         config.embedding_dim,
         config.final_norm_config,
     )
+    # Toy model has only one block config.
+    attn_config = config.block_config(0).attn_config
     self.rope_cache = attn_utils.build_rope_cache(
         size=config.max_seq_len,
-        dim=int(
-            config.attn_config.rotary_percentage * config.attn_config.head_dim
-        ),
+        dim=int(attn_config.rotary_percentage * attn_config.head_dim),
         base=10_000,
         condense_ratio=1,
         dtype=torch.float32,

@@ -85,16 +84,16 @@ class ToySingleLayerModelWeightSharing(torch.nn.Module):
         bias=config.lm_head_use_bias,
     )
     self.lm_head.weight.data = self.tok_embedding.weight.data
-    self.transformer_block = TransformerBlock(config)
+    self.transformer_block = TransformerBlock(config.block_config(0), config)
     self.final_norm = builder.build_norm(
         config.embedding_dim,
         config.final_norm_config,
     )
+    # Toy model has only one block config.
+    attn_config = config.block_config(0).attn_config
     self.rope_cache = attn_utils.build_rope_cache(
         size=config.max_seq_len,
-        dim=int(
-            config.attn_config.rotary_percentage * config.attn_config.head_dim
-        ),
+        dim=int(attn_config.rotary_percentage * attn_config.head_dim),
         base=10_000,
         condense_ratio=1,
         dtype=torch.float32,

@@ -135,15 +134,18 @@ def get_model_config() -> cfg.ModelConfig:
       intermediate_size=256,
   )
   norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.RMS_NORM)
+  block_config = cfg.TransformerBlockConfig(
+      attn_config=attn_config,
+      ff_config=ff_config,
+      pre_attention_norm_config=norm_config,
+      post_attention_norm_config=norm_config,
+  )
   config = cfg.ModelConfig(
       vocab_size=400,
       num_layers=1,
       max_seq_len=KV_CACHE_MAX_LEN,
       embedding_dim=128,
-      attn_config=attn_config,
-      ff_config=ff_config,
-      pre_attention_norm_config=norm_config,
-      post_attention_norm_config=norm_config,
+      block_configs=block_config,
       final_norm_config=norm_config,
   )
   return config
ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py:

@@ -12,14 +12,17 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-
+
+"""A toy example which has basic transformer block (w/ externalized KV-Cache)."""
+
 from typing import Tuple
 
 import ai_edge_torch
 from ai_edge_torch import lowertools
-from ai_edge_torch.generative.layers.attention import TransformerBlock
+from ai_edge_torch.generative.layers import attention
+from ai_edge_torch.generative.layers import builder
+from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
-import ai_edge_torch.generative.layers.builder as builder
 import ai_edge_torch.generative.layers.model_config as cfg
 import torch
 import torch.nn as nn

@@ -27,7 +30,7 @@ import torch.nn as nn
 RoPECache = Tuple[torch.Tensor, torch.Tensor]
 
 
-class ToyModelWithKV(torch.nn.Module):
+class ToyModelWithKVCache(torch.nn.Module):
 
   def __init__(self, config: cfg.ModelConfig) -> None:
     super().__init__()

@@ -35,18 +38,20 @@ class ToyModelWithKV(torch.nn.Module):
         config.embedding_dim, config.vocab_size, bias=config.lm_head_use_bias
     )
     self.tok_embedding = nn.Embedding(config.vocab_size, config.embedding_dim)
+    # Toy model has only one block config.
+    block_config = config.block_config(0)
     self.transformer_blocks = nn.ModuleList(
-        TransformerBlock(config) for _ in range(config.num_layers)
+        attention.TransformerBlock(block_config, config)
+        for _ in range(config.num_layers)
     )
     self.final_norm = builder.build_norm(
         config.embedding_dim,
         config.final_norm_config,
     )
+    attn_config = block_config.attn_config
     self.rope_cache = attn_utils.build_rope_cache(
         size=config.max_seq_len,
-        dim=int(
-            config.attn_config.rotary_percentage * config.attn_config.head_dim
-        ),
+        dim=int(attn_config.rotary_percentage * attn_config.head_dim),
         base=10_000,
         condense_ratio=1,
         dtype=torch.float32,

@@ -57,18 +62,29 @@ class ToyModelWithKV(torch.nn.Module):
     )
     self.config = config
 
-  @torch.inference_mode
-  def forward(self, idx: torch.Tensor, input_pos: torch.Tensor) -> torch.Tensor:
-    x = self.tok_embedding(idx)
+  def forward(
+      self,
+      tokens: torch.Tensor,
+      input_pos: torch.Tensor,
+      kv_cache: kv_utils.KVCache,
+  ) -> Tuple[torch.Tensor, kv_utils.KVCache]:
+    x = self.tok_embedding(tokens)
     cos, sin = self.rope_cache
     cos = cos.index_select(0, input_pos)
     sin = sin.index_select(0, input_pos)
     mask = self.mask_cache.index_select(2, input_pos)
     mask = mask[:, :, :, : self.config.max_seq_len]
+
+    updated_kv_entires = []
     for i, block in enumerate(self.transformer_blocks):
-      x = block(x, (cos, sin), mask, input_pos)
+      kv_entry = kv_cache.caches[i] if kv_cache else None
+      x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
+      if kv_entry:
+        updated_kv_entires.append(kv_entry)
+
     x = self.final_norm(x)
-    return self.lm_head(x)
+    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
+    return {'logits': self.lm_head(x), 'kv_cache': updated_kv_cache}
 
 
 def _export_stablehlo_mlir(model, args):

@@ -78,7 +94,10 @@ def _export_stablehlo_mlir(model, args):
 
 def get_model_config() -> cfg.ModelConfig:
   attn_config = cfg.AttentionConfig(
-      num_heads=32,
+      num_heads=32,
+      head_dim=4,
+      num_query_groups=4,
+      rotary_percentage=1.0,
   )
   ff_config = cfg.FeedForwardConfig(
       type=cfg.FeedForwardType.GATED,

@@ -86,15 +105,18 @@ def get_model_config() -> cfg.ModelConfig:
       intermediate_size=256,
   )
   norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.RMS_NORM)
-  config = cfg.ModelConfig(
-      vocab_size=150,
-      num_layers=2,
-      max_seq_len=500,
-      embedding_dim=128,
+  block_config = cfg.TransformerBlockConfig(
       attn_config=attn_config,
       ff_config=ff_config,
       pre_attention_norm_config=norm_config,
       post_attention_norm_config=norm_config,
+  )
+  config = cfg.ModelConfig(
+      vocab_size=150,
+      num_layers=2,
+      max_seq_len=100,
+      embedding_dim=128,
+      block_configs=block_config,
       final_norm_config=norm_config,
       enable_hlfb=True,
   )

@@ -102,40 +124,59 @@ def get_model_config() -> cfg.ModelConfig:
 
 
 def get_sample_prefill_inputs() -> Tuple[torch.Tensor, torch.Tensor]:
-  idx = torch.unsqueeze(torch.arange(0, 100), 0)
+  tokens = torch.unsqueeze(torch.arange(0, 100), 0)
   input_pos = torch.arange(0, 100)
-  return idx, input_pos
+  return tokens, input_pos
 
 
 def get_sample_decode_inputs() -> Tuple[torch.Tensor, torch.Tensor]:
-  idx = torch.tensor([[1]], dtype=torch.long)
-  input_pos = torch.tensor([10], dtype=torch.int64)
-  return idx, input_pos
+  tokens = torch.tensor([[1]], dtype=torch.long)
+  input_pos = torch.tensor([10])
+  return tokens, input_pos
 
 
 def define_and_run() -> None:
   dump_mlir = False
 
   config = get_model_config()
-  model = ToyModelWithKV(config)
+  model = ToyModelWithExternalKV(config)
+  model.eval()
   print('running an inference')
-  idx, input_pos = get_sample_prefill_inputs()
-  decode_idx, decode_input_pos = get_sample_decode_inputs()
-  print(model.forward(idx, input_pos))
+  kv = kv_utils.KVCache.from_model_config(config)
+
+  tokens, input_pos = get_sample_prefill_inputs()
+  decode_token, decode_input_pos = get_sample_decode_inputs()
+  print(model.forward(tokens, input_pos, kv))
 
   if dump_mlir:
-    mlir_text = _export_stablehlo_mlir(model, (idx, input_pos))
-    with open('/tmp/
+    mlir_text = _export_stablehlo_mlir(model, (tokens, input_pos, kv))
+    with open('/tmp/toy_model_with_external_kv.stablehlo.mlir', 'w') as f:
       f.write(mlir_text)
 
   # Convert model to tflite with 2 signatures (prefill + decode).
   print('converting toy model to tflite with 2 signatures (prefill + decode)')
   edge_model = (
-      ai_edge_torch.signature('prefill', model, (idx, input_pos))
-      .signature('decode', model, (decode_idx, decode_input_pos))
+      ai_edge_torch.signature(
+          'prefill',
+          model,
+          sample_kwargs={
+              'tokens': tokens,
+              'input_pos': input_pos,
+              'kv_cache': kv,
+          },
+      )
+      .signature(
+          'decode',
+          model,
+          sample_kwargs={
+              'tokens': decode_token,
+              'input_pos': decode_input_pos,
+              'kv_cache': kv,
+          },
+      )
       .convert()
   )
-  edge_model.export('/tmp/
+  edge_model.export('/tmp/toy_external_kv_cache.tflite')
 
 
 if __name__ == '__main__':
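
The toy-model rewrite above captures the new inference contract in this release: `forward()` takes the KV cache as an explicit argument and returns the updated cache alongside the logits. A minimal decode-loop sketch (not part of the package, weights are randomly initialized so the generated tokens are meaningless) showing how the explicit cache is threaded between calls, using only the interfaces visible in the diff:

```python
import torch

from ai_edge_torch.generative.examples.test_models import toy_model_with_kv_cache
from ai_edge_torch.generative.layers import kv_cache as kv_utils

config = toy_model_with_kv_cache.get_model_config()
model = toy_model_with_kv_cache.ToyModelWithKVCache(config).eval()

# Prefill a short prompt, then decode token by token, threading the returned
# cache back into the next call. This hand-off is what the externalized
# KVCache (and the separate prefill/decode TFLite signatures) is built around.
prompt = torch.tensor([[5, 6, 7, 8]], dtype=torch.long)
kv = kv_utils.KVCache.from_model_config(config)
out = model.forward(prompt, torch.arange(0, 4), kv)
kv = out['kv_cache']
next_token = out['logits'][:, -1, :].argmax(dim=-1, keepdim=True)

for step in range(3):
  out = model.forward(next_token, torch.tensor([4 + step]), kv)
  kv = out['kv_cache']
  next_token = out['logits'][:, -1, :].argmax(dim=-1, keepdim=True)
```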
|