ai-edge-torch-nightly 0.3.0.dev20240910__py3-none-any.whl → 0.3.0.dev20240913__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +31 -12
- ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +25 -6
- ai_edge_torch/generative/examples/gemma/gemma.py +50 -30
- ai_edge_torch/generative/examples/gemma/gemma2.py +85 -58
- ai_edge_torch/generative/examples/{experimental/phi → phi}/convert_to_tflite.py +11 -12
- ai_edge_torch/generative/examples/{experimental/phi → phi}/phi2.py +46 -43
- ai_edge_torch/generative/examples/{experimental/gemma → smallm}/convert_to_tflite.py +12 -14
- ai_edge_torch/generative/examples/smallm/smallm.py +122 -0
- ai_edge_torch/generative/examples/stable_diffusion/clip.py +11 -5
- ai_edge_torch/generative/examples/t5/t5.py +35 -22
- ai_edge_torch/generative/examples/t5/t5_attention.py +18 -13
- ai_edge_torch/generative/examples/test_models/toy_model.py +15 -13
- ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +74 -33
- ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +25 -6
- ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +55 -34
- ai_edge_torch/generative/layers/attention.py +77 -73
- ai_edge_torch/generative/layers/builder.py +5 -3
- ai_edge_torch/generative/layers/kv_cache.py +163 -51
- ai_edge_torch/generative/layers/model_config.py +38 -19
- ai_edge_torch/generative/layers/normalization.py +158 -0
- ai_edge_torch/generative/layers/unet/blocks_2d.py +0 -2
- ai_edge_torch/generative/test/{test_experimental_ekv.py → test_kv_cache.py} +12 -24
- ai_edge_torch/generative/test/test_loader.py +1 -1
- ai_edge_torch/generative/test/test_model_conversion.py +72 -34
- ai_edge_torch/generative/test/test_model_conversion_large.py +51 -23
- ai_edge_torch/generative/test/utils.py +54 -0
- ai_edge_torch/generative/utilities/loader.py +15 -15
- ai_edge_torch/generative/utilities/t5_loader.py +21 -20
- ai_edge_torch/odml_torch/lowerings/__init__.py +1 -0
- ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +0 -1
- ai_edge_torch/odml_torch/lowerings/_layer_norm.py +78 -0
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240913.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240913.dist-info}/RECORD +39 -45
- ai_edge_torch/generative/examples/experimental/gemma/gemma.py +0 -219
- ai_edge_torch/generative/examples/experimental/phi/__init__.py +0 -14
- ai_edge_torch/generative/examples/experimental/tiny_llama/__init__.py +0 -14
- ai_edge_torch/generative/examples/experimental/tiny_llama/convert_to_tflite.py +0 -87
- ai_edge_torch/generative/examples/experimental/tiny_llama/tiny_llama.py +0 -205
- ai_edge_torch/generative/examples/phi2/__init__.py +0 -14
- ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +0 -67
- ai_edge_torch/generative/examples/phi2/phi2.py +0 -189
- ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +0 -176
- /ai_edge_torch/generative/examples/{experimental → phi}/__init__.py +0 -0
- /ai_edge_torch/generative/examples/{experimental/gemma → smallm}/__init__.py +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240913.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240913.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240913.dist-info}/top_level.txt +0 -0
@@ -12,26 +12,22 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
|
-
|
16
|
-
|
17
|
-
# Note: This is an experimental version of phi2 with external KV cache.
|
18
|
-
# Please use with caution.
|
15
|
+
|
16
|
+
"""Example of building a Phi-2 model."""
|
19
17
|
|
20
18
|
import os
|
21
|
-
|
22
|
-
from typing import Tuple
|
19
|
+
import pathlib
|
23
20
|
|
21
|
+
from ai_edge_torch.generative.layers import attention
|
24
22
|
from ai_edge_torch.generative.layers import builder
|
23
|
+
from ai_edge_torch.generative.layers import kv_cache as kv_utils
|
25
24
|
import ai_edge_torch.generative.layers.attention_utils as attn_utils
|
26
|
-
from ai_edge_torch.generative.layers.experimental import attention
|
27
|
-
from ai_edge_torch.generative.layers.experimental import ekv_cache as kv_utils
|
28
25
|
import ai_edge_torch.generative.layers.model_config as cfg
|
29
26
|
import ai_edge_torch.generative.utilities.loader as loading_utils
|
30
27
|
import numpy as np
|
31
28
|
import torch
|
32
29
|
from torch import nn
|
33
30
|
|
34
|
-
|
35
31
|
TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
|
36
32
|
ff_up_proj="model.layers.{}.mlp.fc1",
|
37
33
|
ff_down_proj="model.layers.{}.mlp.fc2",
|
@@ -52,7 +48,6 @@ class Phi2(nn.Module):
|
|
52
48
|
def __init__(self, config: cfg.ModelConfig):
|
53
49
|
super().__init__()
|
54
50
|
|
55
|
-
self.config = config
|
56
51
|
# Construct model layers.
|
57
52
|
self.lm_head = nn.Linear(
|
58
53
|
config.embedding_dim, config.vocab_size, bias=config.lm_head_use_bias
|
@@ -60,18 +55,20 @@ class Phi2(nn.Module):
|
|
60
55
|
self.tok_embedding = nn.Embedding(
|
61
56
|
config.vocab_size, config.embedding_dim, padding_idx=0
|
62
57
|
)
|
58
|
+
# Phi-2 has only one block config.
|
59
|
+
block_config = config.block_config(0)
|
63
60
|
self.transformer_blocks = nn.ModuleList(
|
64
|
-
attention.TransformerBlock(
|
61
|
+
attention.TransformerBlock(block_config, config)
|
62
|
+
for _ in range(config.num_layers)
|
65
63
|
)
|
66
64
|
self.final_norm = builder.build_norm(
|
67
65
|
config.embedding_dim,
|
68
66
|
config.final_norm_config,
|
69
67
|
)
|
68
|
+
attn_config = block_config.attn_config
|
70
69
|
self.rope_cache = attn_utils.build_rope_cache(
|
71
70
|
size=config.kv_cache_max,
|
72
|
-
dim=int(
|
73
|
-
config.attn_config.rotary_percentage * config.attn_config.head_dim
|
74
|
-
),
|
71
|
+
dim=int(attn_config.rotary_percentage * attn_config.head_dim),
|
75
72
|
base=10_000,
|
76
73
|
condense_ratio=1,
|
77
74
|
dtype=torch.float32,
|
@@ -89,13 +86,17 @@ class Phi2(nn.Module):
|
|
89
86
|
self,
|
90
87
|
tokens: torch.Tensor,
|
91
88
|
input_pos: torch.Tensor,
|
92
|
-
kv_cache: kv_utils.
|
93
|
-
) ->
|
89
|
+
kv_cache: kv_utils.KVCache,
|
90
|
+
) -> dict[torch.Tensor, kv_utils.KVCache]:
|
94
91
|
_, seq_len = tokens.size()
|
95
92
|
assert self.config.max_seq_len >= seq_len, (
|
96
93
|
f"Cannot forward sequence of length {seq_len}, max seq length is only"
|
97
94
|
f" {self.config.max_seq_len}"
|
98
95
|
)
|
96
|
+
assert len(self.transformer_blocks) == len(kv_cache.caches), (
|
97
|
+
"The number of transformer blocks and the number of KV cache entries"
|
98
|
+
" must be the same."
|
99
|
+
)
|
99
100
|
|
100
101
|
cos, sin = self.rope_cache
|
101
102
|
cos = cos.index_select(0, input_pos)
|
@@ -111,11 +112,11 @@ class Phi2(nn.Module):
|
|
111
112
|
x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
|
112
113
|
if kv_entry:
|
113
114
|
updated_kv_entires.append(kv_entry)
|
114
|
-
updated_kv_cache = kv_utils.
|
115
|
+
updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
|
115
116
|
|
116
117
|
x = self.final_norm(x)
|
117
|
-
|
118
|
-
return
|
118
|
+
logits = self.lm_head(x) # (b, t, vocab_size)
|
119
|
+
return {"logits": logits, "kv_cache": updated_kv_cache}
|
119
120
|
|
120
121
|
|
121
122
|
def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
|
@@ -143,17 +144,20 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
|
|
143
144
|
use_bias=True,
|
144
145
|
)
|
145
146
|
norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.LAYER_NORM)
|
147
|
+
block_config = cfg.TransformerBlockConfig(
|
148
|
+
attn_config=attn_config,
|
149
|
+
ff_config=ff_config,
|
150
|
+
pre_attention_norm_config=norm_config,
|
151
|
+
parallel_residual=True,
|
152
|
+
)
|
146
153
|
config = cfg.ModelConfig(
|
147
154
|
vocab_size=51200,
|
148
155
|
num_layers=32,
|
149
156
|
max_seq_len=2048,
|
150
157
|
kv_cache_max_len=kv_cache_max_len,
|
151
158
|
embedding_dim=2560,
|
152
|
-
|
153
|
-
ff_config=ff_config,
|
154
|
-
pre_attention_norm_config=norm_config,
|
159
|
+
block_configs=block_config,
|
155
160
|
final_norm_config=norm_config,
|
156
|
-
parallel_residual=True,
|
157
161
|
lm_head_use_bias=True,
|
158
162
|
enable_hlfb=True,
|
159
163
|
)
|
@@ -165,43 +169,42 @@ def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
|
|
165
169
|
config.vocab_size = 128
|
166
170
|
config.num_layers = 2
|
167
171
|
config.max_seq_len = 2 * kv_cache_max_len
|
168
|
-
config.
|
172
|
+
# Phi-2 has only one block config.
|
173
|
+
config.block_config(0).ff_config.intermediate_size = 128
|
169
174
|
return config
|
170
175
|
|
171
176
|
|
172
|
-
def build_model(
|
173
|
-
checkpoint_path: str, test_model: bool = False, **kwargs
|
174
|
-
) -> nn.Module:
|
177
|
+
def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
|
175
178
|
"""Instantiates the model instance and load checkpoint if provided."""
|
176
|
-
config = (
|
177
|
-
get_fake_model_config(**kwargs)
|
178
|
-
if test_model
|
179
|
-
else get_model_config(**kwargs)
|
180
|
-
)
|
179
|
+
config = get_model_config(**kwargs)
|
181
180
|
model = Phi2(config)
|
182
|
-
|
183
|
-
|
184
|
-
loader.load(model)
|
181
|
+
loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
|
182
|
+
loader.load(model)
|
185
183
|
model.eval()
|
186
184
|
return model
|
187
185
|
|
188
186
|
|
189
|
-
def define_and_run(checkpoint_path: str
|
187
|
+
def define_and_run(checkpoint_path: str) -> None:
|
190
188
|
"""Instantiates and runs a Phi-2 model."""
|
191
189
|
|
190
|
+
current_dir = pathlib.Path(__file__).parent.resolve()
|
191
|
+
phi2_goldens = torch.load(current_dir / "phi2_lm_logits.pt")
|
192
192
|
kv_cache_max_len = 1024
|
193
|
-
model = build_model(
|
194
|
-
checkpoint_path, test_model=test_model, kv_cache_max_len=kv_cache_max_len
|
195
|
-
)
|
193
|
+
model = build_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
|
196
194
|
idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
|
197
195
|
tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.long, device="cpu")
|
198
196
|
tokens[0, :4] = idx
|
199
197
|
input_pos = torch.arange(0, kv_cache_max_len)
|
200
|
-
kv = kv_utils.
|
201
|
-
|
202
|
-
print(
|
198
|
+
kv = kv_utils.KVCache.from_model_config(model.config)
|
199
|
+
output = model.forward(tokens, input_pos, kv)
|
200
|
+
print("comparing with goldens..")
|
201
|
+
assert torch.allclose(
|
202
|
+
phi2_goldens, output["logits"][0, idx.shape[1] - 1, :], atol=1e-02
|
203
|
+
)
|
203
204
|
|
204
205
|
|
205
206
|
if __name__ == "__main__":
|
206
|
-
input_checkpoint_path = os.path.join(
|
207
|
+
input_checkpoint_path = os.path.join(
|
208
|
+
pathlib.Path.home(), "Downloads/llm_data/phi2"
|
209
|
+
)
|
207
210
|
define_and_run(input_checkpoint_path)
|
@@ -12,30 +12,27 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
|
-
#
|
16
|
-
# Note: This is an experimental version of Gemma with external KV cache.
|
17
|
-
# Please use with caution.
|
18
15
|
|
16
|
+
"""Example of converting SmalLM model to multi-signature tflite model."""
|
19
17
|
|
20
18
|
import os
|
21
|
-
|
19
|
+
import pathlib
|
22
20
|
|
23
21
|
import ai_edge_torch
|
24
|
-
from ai_edge_torch.generative.examples.
|
25
|
-
from ai_edge_torch.generative.layers
|
22
|
+
from ai_edge_torch.generative.examples.smallm import smallm
|
23
|
+
from ai_edge_torch.generative.layers import kv_cache as kv_utils
|
26
24
|
from ai_edge_torch.generative.quantize import quant_recipes
|
27
25
|
import torch
|
28
26
|
|
29
27
|
|
30
|
-
def
|
28
|
+
def convert_smallm_to_tflite(
|
31
29
|
checkpoint_path: str,
|
32
30
|
prefill_seq_len: int = 512,
|
33
31
|
kv_cache_max_len: int = 1024,
|
34
32
|
quantize: bool = True,
|
35
33
|
):
|
36
|
-
"""
|
34
|
+
"""Converts SmalLM model to multi-signature tflite model.
|
37
35
|
|
38
|
-
tflite model.
|
39
36
|
Args:
|
40
37
|
checkpoint_path (str): The filepath to the model checkpoint, or directory
|
41
38
|
holding the checkpoint.
|
@@ -46,7 +43,7 @@ def convert_gemma_to_tflite(
|
|
46
43
|
quantize (bool, optional): Whether the model should be quanized. Defaults
|
47
44
|
to True.
|
48
45
|
"""
|
49
|
-
pytorch_model =
|
46
|
+
pytorch_model = smallm.build_model(
|
50
47
|
checkpoint_path, kv_cache_max_len=kv_cache_max_len
|
51
48
|
)
|
52
49
|
# Tensors used to trace the model graph during conversion.
|
@@ -54,7 +51,7 @@ def convert_gemma_to_tflite(
|
|
54
51
|
prefill_input_pos = torch.arange(0, prefill_seq_len)
|
55
52
|
decode_token = torch.tensor([[0]], dtype=torch.long)
|
56
53
|
decode_input_pos = torch.tensor([0], dtype=torch.int64)
|
57
|
-
kv = kv_utils.
|
54
|
+
kv = kv_utils.KVCache.from_model_config(pytorch_model.config)
|
58
55
|
|
59
56
|
quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
|
60
57
|
edge_model = (
|
@@ -78,11 +75,12 @@ def convert_gemma_to_tflite(
|
|
78
75
|
)
|
79
76
|
.convert(quant_config=quant_config)
|
80
77
|
)
|
78
|
+
quant_suffix = 'q8' if quantize else 'f32'
|
81
79
|
edge_model.export(
|
82
|
-
f'/tmp/
|
80
|
+
f'/tmp/smallm_{quant_suffix}_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
|
83
81
|
)
|
84
82
|
|
85
83
|
|
86
84
|
if __name__ == '__main__':
|
87
|
-
|
88
|
-
|
85
|
+
path = os.path.join(pathlib.Path.home(), 'Downloads/llm_data/smallm')
|
86
|
+
convert_smallm_to_tflite(path)
|
@@ -0,0 +1,122 @@
|
|
1
|
+
# Copyright 2024 The AI Edge Torch Authors.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
"""Example of building a SmalLM model."""
|
17
|
+
|
18
|
+
import copy
|
19
|
+
import os
|
20
|
+
import pathlib
|
21
|
+
|
22
|
+
from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
|
23
|
+
from ai_edge_torch.generative.layers import kv_cache as kv_utils
|
24
|
+
import ai_edge_torch.generative.layers.model_config as cfg
|
25
|
+
import ai_edge_torch.generative.utilities.loader as loading_utils
|
26
|
+
import numpy as np
|
27
|
+
import torch
|
28
|
+
from torch import nn
|
29
|
+
|
30
|
+
TENSOR_NAMES = copy.copy(tiny_llama.TENSOR_NAMES)
|
31
|
+
# SmalLM re-uses the embedding as the head projection layer.
|
32
|
+
TENSOR_NAMES.lm_head = None
|
33
|
+
|
34
|
+
|
35
|
+
class SmalLM(tiny_llama.TinyLlama):
|
36
|
+
"""A SmalLM model built from the Edge Generative API layers.
|
37
|
+
|
38
|
+
SmalLM shares the same architecture as TinyLlama, but with different model
|
39
|
+
sizes.
|
40
|
+
"""
|
41
|
+
|
42
|
+
def __init__(self, config: cfg.ModelConfig):
|
43
|
+
super().__init__(config)
|
44
|
+
# SmalLM re-uses the embedding as the head projection layer.
|
45
|
+
self.lm_head.weight.data = self.tok_embedding.weight.data
|
46
|
+
|
47
|
+
|
48
|
+
def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
|
49
|
+
"""Returns the model config for a SmalLM 135M model.
|
50
|
+
|
51
|
+
Args:
|
52
|
+
kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
|
53
|
+
is 1024.
|
54
|
+
|
55
|
+
Returns:
|
56
|
+
The model config for a SmalLM model.
|
57
|
+
"""
|
58
|
+
attn_config = cfg.AttentionConfig(
|
59
|
+
num_heads=9,
|
60
|
+
head_dim=64,
|
61
|
+
num_query_groups=3,
|
62
|
+
rotary_percentage=1.0,
|
63
|
+
)
|
64
|
+
ff_config = cfg.FeedForwardConfig(
|
65
|
+
type=cfg.FeedForwardType.GATED,
|
66
|
+
activation=cfg.ActivationConfig(cfg.ActivationType.SILU),
|
67
|
+
intermediate_size=1536,
|
68
|
+
)
|
69
|
+
norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.RMS_NORM)
|
70
|
+
block_config = cfg.TransformerBlockConfig(
|
71
|
+
attn_config=attn_config,
|
72
|
+
ff_config=ff_config,
|
73
|
+
pre_attention_norm_config=norm_config,
|
74
|
+
post_attention_norm_config=norm_config,
|
75
|
+
)
|
76
|
+
config = cfg.ModelConfig(
|
77
|
+
vocab_size=49152,
|
78
|
+
num_layers=30,
|
79
|
+
max_seq_len=2048,
|
80
|
+
embedding_dim=576,
|
81
|
+
kv_cache_max_len=kv_cache_max_len,
|
82
|
+
block_configs=block_config,
|
83
|
+
final_norm_config=norm_config,
|
84
|
+
enable_hlfb=True,
|
85
|
+
)
|
86
|
+
return config
|
87
|
+
|
88
|
+
|
89
|
+
def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
|
90
|
+
config = get_model_config(**kwargs)
|
91
|
+
model = SmalLM(config)
|
92
|
+
loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
|
93
|
+
# Since embedding and lm-head use the same weight, we need to set strict
|
94
|
+
# to False.
|
95
|
+
loader.load(model, strict=False)
|
96
|
+
model.eval()
|
97
|
+
return model
|
98
|
+
|
99
|
+
|
100
|
+
def define_and_run(checkpoint_path: str) -> None:
|
101
|
+
"""Instantiates and runs a SmalLM model."""
|
102
|
+
|
103
|
+
current_dir = pathlib.Path(__file__).parent.resolve()
|
104
|
+
smallm_goldens = torch.load(current_dir / "smallm_lm_logits.pt")
|
105
|
+
kv_cache_max_len = 1024
|
106
|
+
model = build_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
|
107
|
+
idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
|
108
|
+
tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.long, device="cpu")
|
109
|
+
tokens[0, :4] = idx
|
110
|
+
input_pos = torch.arange(0, kv_cache_max_len)
|
111
|
+
kv = kv_utils.KVCache.from_model_config(model.config)
|
112
|
+
output = model.forward(tokens, input_pos, kv)
|
113
|
+
assert torch.allclose(
|
114
|
+
smallm_goldens, output["logits"][0, idx.shape[1] - 1, :], atol=1e-05
|
115
|
+
)
|
116
|
+
|
117
|
+
|
118
|
+
if __name__ == "__main__":
|
119
|
+
input_checkpoint_path = os.path.join(
|
120
|
+
pathlib.Path.home(), "Downloads/llm_data/smallm"
|
121
|
+
)
|
122
|
+
define_and_run(input_checkpoint_path)
|
@@ -61,8 +61,10 @@ class CLIP(nn.Module):
|
|
61
61
|
)
|
62
62
|
|
63
63
|
self.config = config
|
64
|
+
# CLIP has only one block config.
|
65
|
+
block_config = config.block_config(0)
|
64
66
|
self.transformer_blocks = nn.ModuleList(
|
65
|
-
TransformerBlock(config) for _ in range(config.num_layers)
|
67
|
+
TransformerBlock(block_config, config) for _ in range(config.num_layers)
|
66
68
|
)
|
67
69
|
self.final_norm = builder.build_norm(
|
68
70
|
config.embedding_dim, config.final_norm_config
|
@@ -112,15 +114,19 @@ def get_model_config() -> cfg.ModelConfig:
|
|
112
114
|
|
113
115
|
norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.LAYER_NORM)
|
114
116
|
|
117
|
+
block_config = cfg.TransformerBlockConfig(
|
118
|
+
attn_config=attn_config,
|
119
|
+
ff_config=ff_config,
|
120
|
+
pre_attention_norm_config=norm_config,
|
121
|
+
post_attention_norm_config=norm_config,
|
122
|
+
)
|
123
|
+
|
115
124
|
config = cfg.ModelConfig(
|
116
125
|
vocab_size=vocab_size,
|
117
126
|
num_layers=num_layers,
|
118
127
|
max_seq_len=max_seq_len,
|
119
128
|
embedding_dim=embedding_dim,
|
120
|
-
|
121
|
-
ff_config=ff_config,
|
122
|
-
pre_attention_norm_config=norm_config,
|
123
|
-
post_attention_norm_config=norm_config,
|
129
|
+
block_configs=block_config,
|
124
130
|
final_norm_config=norm_config,
|
125
131
|
enable_hlfb=True,
|
126
132
|
)
|
@@ -52,9 +52,15 @@ class T5Stack(nn.Module):
|
|
52
52
|
self.config = config
|
53
53
|
self.embed_tokens = embed_tokens
|
54
54
|
self.is_decoder = config.is_decoder
|
55
|
+
# T5 has only one block config.
|
56
|
+
block_config = config.block_config(0)
|
55
57
|
self.transformer_blocks = nn.ModuleList([
|
56
|
-
EncoderDecoderBlock(
|
57
|
-
|
58
|
+
EncoderDecoderBlock(
|
59
|
+
block_config,
|
60
|
+
config,
|
61
|
+
has_relative_attention_bias=bool(idx == 0),
|
62
|
+
)
|
63
|
+
for idx in range(config.num_layers)
|
58
64
|
])
|
59
65
|
self.final_norm = builder.build_norm(
|
60
66
|
config.embedding_dim, config.final_norm_config
|
@@ -73,13 +79,11 @@ class T5Stack(nn.Module):
|
|
73
79
|
torch.Tensor
|
74
80
|
] = None, # should be for decoder case
|
75
81
|
):
|
76
|
-
input_shape = input_ids.size()
|
77
82
|
inputs_embeds = self.embed_tokens(input_ids)
|
78
|
-
batch_size, seq_length = input_shape
|
79
83
|
hidden_states = inputs_embeds
|
80
84
|
position_bias = None
|
81
85
|
encoder_decoder_position_bias = None
|
82
|
-
for
|
86
|
+
for _, layer_module in enumerate(self.transformer_blocks):
|
83
87
|
# EncoderDecoderBlock.forward
|
84
88
|
hidden_states, position_bias, encoder_decoder_position_bias = (
|
85
89
|
layer_module(
|
@@ -111,7 +115,8 @@ class T5(nn.Module):
|
|
111
115
|
|
112
116
|
encoder_config = copy.deepcopy(config)
|
113
117
|
encoder_config.is_decoder = False
|
114
|
-
|
118
|
+
# T5 has only one block config.
|
119
|
+
encoder_config.block_config(0).attn_config.enable_kv_cache = False
|
115
120
|
self.encoder = T5Stack(encoder_config, self.tok_embedding)
|
116
121
|
|
117
122
|
decoder_config = copy.deepcopy(config)
|
@@ -137,20 +142,22 @@ class T5(nn.Module):
|
|
137
142
|
device=torch.device("cpu"),
|
138
143
|
)
|
139
144
|
|
145
|
+
# T5 has only one block config.
|
146
|
+
attn_config = config.block_config(0).attn_config
|
140
147
|
self.enc_rel_pos_mask = attn_utils.build_relative_position_buckets(
|
141
148
|
bidirectional=True,
|
142
149
|
query_length=config.kv_cache_max,
|
143
150
|
key_length=config.kv_cache_max,
|
144
|
-
num_buckets=
|
145
|
-
max_distance=
|
151
|
+
num_buckets=attn_config.relative_attention_num_buckets,
|
152
|
+
max_distance=attn_config.relative_attention_max_distance,
|
146
153
|
)
|
147
154
|
|
148
155
|
self.dec_rel_pos_mask = attn_utils.build_relative_position_buckets(
|
149
156
|
bidirectional=False,
|
150
157
|
query_length=config.kv_cache_max,
|
151
158
|
key_length=config.kv_cache_max,
|
152
|
-
num_buckets=
|
153
|
-
max_distance=
|
159
|
+
num_buckets=attn_config.relative_attention_num_buckets,
|
160
|
+
max_distance=attn_config.relative_attention_max_distance,
|
154
161
|
)
|
155
162
|
|
156
163
|
@torch.inference_mode
|
@@ -230,7 +237,8 @@ class T5Encoder(nn.Module):
|
|
230
237
|
|
231
238
|
encoder_config = copy.deepcopy(config)
|
232
239
|
encoder_config.is_decoder = False
|
233
|
-
|
240
|
+
# T5 has only one block config.
|
241
|
+
encoder_config.block_config(0).attn_config.enable_kv_cache = False
|
234
242
|
self.encoder = T5Stack(encoder_config, self.tok_embedding)
|
235
243
|
|
236
244
|
self.enc_attn_mask_cache = (
|
@@ -243,12 +251,14 @@ class T5Encoder(nn.Module):
|
|
243
251
|
.unsqueeze(0)
|
244
252
|
)
|
245
253
|
|
254
|
+
# T5 has only one block config.
|
255
|
+
attn_config = config.block_config(0).attn_config
|
246
256
|
self.enc_rel_pos_mask = attn_utils.build_relative_position_buckets(
|
247
257
|
bidirectional=True,
|
248
258
|
query_length=config.kv_cache_max,
|
249
259
|
key_length=config.kv_cache_max,
|
250
|
-
num_buckets=
|
251
|
-
max_distance=
|
260
|
+
num_buckets=attn_config.relative_attention_num_buckets,
|
261
|
+
max_distance=attn_config.relative_attention_max_distance,
|
252
262
|
)
|
253
263
|
|
254
264
|
@torch.inference_mode
|
@@ -313,12 +323,14 @@ class T5Decoder(nn.Module):
|
|
313
323
|
.unsqueeze(0)
|
314
324
|
)
|
315
325
|
|
326
|
+
# T5 has only one block config.
|
327
|
+
attn_config = config.block_config(0).attn_config
|
316
328
|
self.enc_rel_pos_mask = attn_utils.build_relative_position_buckets(
|
317
329
|
bidirectional=True,
|
318
330
|
query_length=config.kv_cache_max,
|
319
331
|
key_length=config.kv_cache_max,
|
320
|
-
num_buckets=
|
321
|
-
max_distance=
|
332
|
+
num_buckets=attn_config.relative_attention_num_buckets,
|
333
|
+
max_distance=attn_config.relative_attention_max_distance,
|
322
334
|
)
|
323
335
|
|
324
336
|
self.dec_attn_mask_cache = attn_utils.build_causal_mask_cache(
|
@@ -386,19 +398,20 @@ def get_model_config_t5() -> cfg.ModelConfig:
|
|
386
398
|
type=cfg.NormalizationType.RMS_NORM,
|
387
399
|
epsilon=1e-6,
|
388
400
|
)
|
389
|
-
|
390
|
-
config = cfg.ModelConfig(
|
391
|
-
vocab_size=32128,
|
392
|
-
num_layers=12,
|
393
|
-
max_seq_len=512,
|
394
|
-
embedding_dim=768,
|
401
|
+
block_config = cfg.TransformerBlockConfig(
|
395
402
|
attn_config=attn_config,
|
396
403
|
relative_attention=True,
|
397
404
|
ff_config=ff_config,
|
398
405
|
pre_attention_norm_config=norm_config,
|
399
406
|
post_attention_norm_config=norm_config,
|
407
|
+
)
|
408
|
+
config = cfg.ModelConfig(
|
409
|
+
vocab_size=32128,
|
410
|
+
num_layers=12,
|
411
|
+
max_seq_len=512,
|
412
|
+
embedding_dim=768,
|
413
|
+
block_configs=block_config,
|
400
414
|
final_norm_config=norm_config,
|
401
|
-
parallel_residual=False,
|
402
415
|
lm_head_use_bias=False,
|
403
416
|
enable_hlfb=True,
|
404
417
|
)
|
@@ -24,7 +24,6 @@ from ai_edge_torch.generative.layers.scaled_dot_product_attention import scaled_
|
|
24
24
|
from ai_edge_torch.generative.layers.scaled_dot_product_attention import scaled_dot_product_attention_with_hlfb # NOQA
|
25
25
|
import torch
|
26
26
|
from torch import nn
|
27
|
-
import torch.nn.functional as F
|
28
27
|
|
29
28
|
BATCH_SIZE = 1
|
30
29
|
|
@@ -32,13 +31,18 @@ BATCH_SIZE = 1
|
|
32
31
|
class EncoderDecoderBlock(nn.Module):
|
33
32
|
|
34
33
|
def __init__(
|
35
|
-
self,
|
34
|
+
self,
|
35
|
+
config: cfg.TransformerBlockConfig,
|
36
|
+
model_config: cfg.ModelConfig,
|
37
|
+
has_relative_attention_bias: bool = False,
|
36
38
|
) -> None:
|
37
39
|
"""Initialize an instance of the EncoderDecoderBlock.
|
38
40
|
|
39
41
|
Args:
|
40
|
-
config (cfg.
|
41
|
-
block.
|
42
|
+
config (cfg.TransformerBlockConfig): the configuration object for this
|
43
|
+
transformer block.
|
44
|
+
model_config (cfg.ModelConfig): the configuration object for the model
|
45
|
+
this transformer block belongs to.
|
42
46
|
has_relative_attention_bias (bool): whether the self attention block has
|
43
47
|
relative bias.
|
44
48
|
"""
|
@@ -46,22 +50,22 @@ class EncoderDecoderBlock(nn.Module):
|
|
46
50
|
super().__init__()
|
47
51
|
self.atten_func = T5Attention(
|
48
52
|
BATCH_SIZE,
|
49
|
-
|
53
|
+
model_config.embedding_dim,
|
50
54
|
config.attn_config,
|
51
55
|
config.pre_attention_norm_config,
|
52
|
-
|
53
|
-
|
56
|
+
model_config.kv_cache_max,
|
57
|
+
model_config.enable_hlfb,
|
54
58
|
has_relative_attention_bias=has_relative_attention_bias,
|
55
59
|
)
|
56
60
|
# For a decoder, we add a cross attention.
|
57
|
-
if
|
61
|
+
if model_config.is_decoder:
|
58
62
|
self.cross_atten_func = T5Attention(
|
59
63
|
BATCH_SIZE,
|
60
|
-
|
64
|
+
model_config.embedding_dim,
|
61
65
|
config.attn_config,
|
62
66
|
config.pre_attention_norm_config,
|
63
|
-
|
64
|
-
|
67
|
+
model_config.kv_cache_max,
|
68
|
+
model_config.enable_hlfb,
|
65
69
|
# Cross Attention does not have relative attention bias.
|
66
70
|
has_relative_attention_bias=False,
|
67
71
|
)
|
@@ -69,9 +73,10 @@ class EncoderDecoderBlock(nn.Module):
|
|
69
73
|
self.cross_atten_func = None
|
70
74
|
|
71
75
|
self.post_atten_norm = builder.build_norm(
|
72
|
-
|
76
|
+
model_config.embedding_dim,
|
77
|
+
config.post_attention_norm_config,
|
73
78
|
)
|
74
|
-
self.ff = builder.build_ff(
|
79
|
+
self.ff = builder.build_ff(model_config.embedding_dim, config.ff_config)
|
75
80
|
self.config = config
|
76
81
|
|
77
82
|
def forward(
|