ai-edge-torch-nightly 0.3.0.dev20240910__py3-none-any.whl → 0.3.0.dev20240913__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +31 -12
- ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +25 -6
- ai_edge_torch/generative/examples/gemma/gemma.py +50 -30
- ai_edge_torch/generative/examples/gemma/gemma2.py +85 -58
- ai_edge_torch/generative/examples/{experimental/phi → phi}/convert_to_tflite.py +11 -12
- ai_edge_torch/generative/examples/{experimental/phi → phi}/phi2.py +46 -43
- ai_edge_torch/generative/examples/{experimental/gemma → smallm}/convert_to_tflite.py +12 -14
- ai_edge_torch/generative/examples/smallm/smallm.py +122 -0
- ai_edge_torch/generative/examples/stable_diffusion/clip.py +11 -5
- ai_edge_torch/generative/examples/t5/t5.py +35 -22
- ai_edge_torch/generative/examples/t5/t5_attention.py +18 -13
- ai_edge_torch/generative/examples/test_models/toy_model.py +15 -13
- ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +74 -33
- ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +25 -6
- ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +55 -34
- ai_edge_torch/generative/layers/attention.py +77 -73
- ai_edge_torch/generative/layers/builder.py +5 -3
- ai_edge_torch/generative/layers/kv_cache.py +163 -51
- ai_edge_torch/generative/layers/model_config.py +38 -19
- ai_edge_torch/generative/layers/normalization.py +158 -0
- ai_edge_torch/generative/layers/unet/blocks_2d.py +0 -2
- ai_edge_torch/generative/test/{test_experimental_ekv.py → test_kv_cache.py} +12 -24
- ai_edge_torch/generative/test/test_loader.py +1 -1
- ai_edge_torch/generative/test/test_model_conversion.py +72 -34
- ai_edge_torch/generative/test/test_model_conversion_large.py +51 -23
- ai_edge_torch/generative/test/utils.py +54 -0
- ai_edge_torch/generative/utilities/loader.py +15 -15
- ai_edge_torch/generative/utilities/t5_loader.py +21 -20
- ai_edge_torch/odml_torch/lowerings/__init__.py +1 -0
- ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +0 -1
- ai_edge_torch/odml_torch/lowerings/_layer_norm.py +78 -0
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240913.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240913.dist-info}/RECORD +39 -45
- ai_edge_torch/generative/examples/experimental/gemma/gemma.py +0 -219
- ai_edge_torch/generative/examples/experimental/phi/__init__.py +0 -14
- ai_edge_torch/generative/examples/experimental/tiny_llama/__init__.py +0 -14
- ai_edge_torch/generative/examples/experimental/tiny_llama/convert_to_tflite.py +0 -87
- ai_edge_torch/generative/examples/experimental/tiny_llama/tiny_llama.py +0 -205
- ai_edge_torch/generative/examples/phi2/__init__.py +0 -14
- ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +0 -67
- ai_edge_torch/generative/examples/phi2/phi2.py +0 -189
- ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +0 -176
- /ai_edge_torch/generative/examples/{experimental → phi}/__init__.py +0 -0
- /ai_edge_torch/generative/examples/{experimental/gemma → smallm}/__init__.py +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240913.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240913.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240913.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py

@@ -13,32 +13,35 @@
 # limitations under the License.
 # ==============================================================================
 
+"""Example of converting a Gemma2 model to multi-signature tflite model."""
+
 import os
-
+import pathlib
 
 import ai_edge_torch
 from ai_edge_torch.generative.examples.gemma import gemma2
+from ai_edge_torch.generative.layers import kv_cache as kv_utils
 from ai_edge_torch.generative.quantize import quant_recipes
 import torch
 
 
-def
+def convert_gemma2_to_tflite(
     checkpoint_path: str,
     prefill_seq_len: int = 512,
     kv_cache_max_len: int = 1024,
     quantize: bool = True,
 ):
-  """
-      tflite model.
+  """Converts a Gemma2 2B model to multi-signature tflite model.
 
   Args:
-    checkpoint_path (str): The filepath to the model checkpoint, or directory
+    checkpoint_path (str): The filepath to the model checkpoint, or directory
+      holding the checkpoint.
     prefill_seq_len (int, optional): The maximum size of prefill input tensor.
       Defaults to 512.
     kv_cache_max_len (int, optional): The maximum size of KV cache buffer,
       including both prefill and decode. Defaults to 1024.
-    quantize (bool, optional): Whether the model should be quanized.
-
+    quantize (bool, optional): Whether the model should be quanized. Defaults
+      to True.
   """
   pytorch_model = gemma2.build_2b_model(
       checkpoint_path, kv_cache_max_len=kv_cache_max_len
@@ -48,20 +51,36 @@ def convert_gemma_to_tflite(
   prefill_input_pos = torch.arange(0, prefill_seq_len)
   decode_token = torch.tensor([[0]], dtype=torch.long)
   decode_input_pos = torch.tensor([0], dtype=torch.int64)
+  kv = kv_utils.KVCache.from_model_config(pytorch_model.config)
 
   quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
   edge_model = (
       ai_edge_torch.signature(
-          'prefill',
+          'prefill',
+          pytorch_model,
+          sample_kwargs={
+              'tokens': prefill_tokens,
+              'input_pos': prefill_input_pos,
+              'kv_cache': kv,
+          },
+      )
+      .signature(
+          'decode',
+          pytorch_model,
+          sample_kwargs={
+              'tokens': decode_token,
+              'input_pos': decode_input_pos,
+              'kv_cache': kv,
+          },
       )
-      .signature('decode', pytorch_model, (decode_token, decode_input_pos))
       .convert(quant_config=quant_config)
   )
+  quant_suffix = 'q8' if quantize else 'f32'
   edge_model.export(
-      f'/tmp/
+      f'/tmp/gemma2_{quant_suffix}_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
   )
 
 
 if __name__ == '__main__':
-
-
+  path = os.path.join(pathlib.Path.home(), 'Downloads/llm_data/gemma2-2b')
+  convert_gemma2_to_tflite(path)
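Note on the API shift captured in this hunk: example inputs are now passed to ai_edge_torch.signature() by keyword through sample_kwargs, matching the model's forward arguments, and the KV cache is an explicit kv_cache input built with kv_utils.KVCache.from_model_config. A minimal sketch of that pattern, assuming a local Gemma2 checkpoint at the illustrative path below (the prefill tensor shapes follow the defaults shown above):

import pathlib

import ai_edge_torch
from ai_edge_torch.generative.examples.gemma import gemma2
from ai_edge_torch.generative.layers import kv_cache as kv_utils
import torch

# Hypothetical checkpoint location; point this at real Gemma2 2B weights.
checkpoint = str(pathlib.Path.home() / "Downloads/llm_data/gemma2-2b")
model = gemma2.build_2b_model(checkpoint, kv_cache_max_len=1024)

# Example inputs are keyed by the forward() argument names.
prefill_tokens = torch.full((1, 512), 0, dtype=torch.long)
prefill_input_pos = torch.arange(0, 512)
kv = kv_utils.KVCache.from_model_config(model.config)

edge_model = ai_edge_torch.signature(
    "prefill",
    model,
    sample_kwargs={
        "tokens": prefill_tokens,
        "input_pos": prefill_input_pos,
        "kv_cache": kv,
    },
).convert()
edge_model.export("/tmp/gemma2_prefill_only.tflite")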
ai_edge_torch/generative/examples/gemma/convert_to_tflite.py

@@ -13,11 +13,14 @@
 # limitations under the License.
 # ==============================================================================
 
+"""Example of converting a Gemma model to multi-signature tflite model."""
+
 import os
-
+import pathlib
 
 import ai_edge_torch
 from ai_edge_torch.generative.examples.gemma import gemma
+from ai_edge_torch.generative.layers import kv_cache as kv_utils
 from ai_edge_torch.generative.quantize import quant_recipes
 import torch
 
@@ -48,20 +51,36 @@ def convert_gemma_to_tflite(
   prefill_input_pos = torch.arange(0, prefill_seq_len)
   decode_token = torch.tensor([[0]], dtype=torch.long)
   decode_input_pos = torch.tensor([0], dtype=torch.int64)
+  kv = kv_utils.KVCache.from_model_config(pytorch_model.config)
 
   quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
   edge_model = (
       ai_edge_torch.signature(
-          'prefill',
+          'prefill',
+          pytorch_model,
+          sample_kwargs={
+              'tokens': prefill_tokens,
+              'input_pos': prefill_input_pos,
+              'kv_cache': kv,
+          },
+      )
+      .signature(
+          'decode',
+          pytorch_model,
+          sample_kwargs={
+              'tokens': decode_token,
+              'input_pos': decode_input_pos,
+              'kv_cache': kv,
+          },
       )
-      .signature('decode', pytorch_model, (decode_token, decode_input_pos))
       .convert(quant_config=quant_config)
   )
+  quant_suffix = 'q8' if quantize else 'f32'
   edge_model.export(
-      f'/tmp/
+      f'/tmp/gemma_{quant_suffix}_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
   )
 
 
 if __name__ == '__main__':
-
-  convert_gemma_to_tflite(
+  path = os.path.join(pathlib.Path.home(), 'Downloads/llm_data/gemma-2b')
+  convert_gemma_to_tflite(path)
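The exported file now carries both a 'prefill' and a 'decode' signature. A hedged sanity check with the stock TensorFlow Lite interpreter, assuming tensorflow is installed and the export above has already produced the file (the exact filename depends on the quantize, prefill_seq_len, and kv_cache_max_len arguments used):

import tensorflow as tf

# Filename follows the naming scheme in the hunk above; adjust as needed.
model_path = "/tmp/gemma_q8_seq512_ekv1024.tflite"

interpreter = tf.lite.Interpreter(model_path=model_path)
# Expect 'prefill' and 'decode' entries, each listing its flattened inputs.
print(interpreter.get_signature_list())

# Each signature can be driven independently through a signature runner.
decode = interpreter.get_signature_runner("decode")
print(decode.get_input_details().keys())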
ai_edge_torch/generative/examples/gemma/gemma.py

@@ -12,13 +12,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-
+
+"""Example of building a Gemma model."""
 
 import os
-
+import pathlib
 
 from ai_edge_torch.generative.layers import attention
 from ai_edge_torch.generative.layers import builder
+from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
 import ai_edge_torch.generative.layers.model_config as cfg
 import ai_edge_torch.generative.utilities.loader as loading_utils
@@ -48,7 +50,6 @@ class Gemma(nn.Module):
   def __init__(self, config: cfg.ModelConfig):
     super().__init__()
 
-    self.config = config
     # Construct model layers.
     self.tok_embedding = nn.Embedding(
         config.vocab_size, config.embedding_dim, padding_idx=0
@@ -60,18 +61,20 @@ class Gemma(nn.Module):
     )
     # Gemma re-uses the embedding as the head projection layer.
     self.lm_head.weight.data = self.tok_embedding.weight.data
+    # Gemma has only one block config.
+    block_config = config.block_config(0)
     self.transformer_blocks = nn.ModuleList(
-        attention.TransformerBlock(
+        attention.TransformerBlock(block_config, config)
+        for _ in range(config.num_layers)
     )
     self.final_norm = builder.build_norm(
         config.embedding_dim,
         config.final_norm_config,
     )
+    attn_config = block_config.attn_config
     self.rope_cache = attn_utils.build_rope_cache(
         size=config.kv_cache_max,
-        dim=int(
-            config.attn_config.rotary_percentage * config.attn_config.head_dim
-        ),
+        dim=int(attn_config.rotary_percentage * attn_config.head_dim),
         base=10_000,
         condense_ratio=1,
         dtype=torch.float32,
@@ -84,16 +87,22 @@ class Gemma(nn.Module):
     )
     self.config = config
 
-    # The model's forward function takes in additional k/v cache tensors
-    # and returns the updated k/v cache tensors to the caller.
-    # This can be eliminated if we handle k/v cache updates inside the model itself.
   @torch.inference_mode
-  def forward(
-
+  def forward(
+      self,
+      tokens: torch.Tensor,
+      input_pos: torch.Tensor,
+      kv_cache: kv_utils.KVCache,
+  ) -> dict[torch.Tensor, kv_utils.KVCache]:
+    _, seq_len = tokens.size()
     assert self.config.max_seq_len >= seq_len, (
         f"Cannot forward sequence of length {seq_len}, max seq length is only"
         f" {self.config.max_seq_len}"
     )
+    assert len(self.transformer_blocks) == len(kv_cache.caches), (
+        "The number of transformer blocks and the number of KV cache entries"
+        " must be the same."
+    )
 
     cos, sin = self.rope_cache
     cos = cos.index_select(0, input_pos)
@@ -102,15 +111,20 @@ class Gemma(nn.Module):
     mask = mask[:, :, :, : self.config.kv_cache_max]
 
     # token embeddings of shape (b, t, n_embd)
-    x = self.tok_embedding(
+    x = self.tok_embedding(tokens)
     x = x * (self.config.embedding_dim**0.5)
 
-
-
+    updated_kv_entires = []
+    for i, block in enumerate(self.transformer_blocks):
+      kv_entry = kv_cache.caches[i] if kv_cache else None
+      x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
+      if kv_entry:
+        updated_kv_entires.append(kv_entry)
+    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
 
     x = self.final_norm(x)
-
-    return
+    logits = self.lm_head(x)  # (b, t, vocab_size)
+    return {"logits": logits, "kv_cache": updated_kv_cache}
 
 
 def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
@@ -139,18 +153,20 @@ def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       epsilon=1e-6,
       zero_centered=True,
   )
+  block_config = cfg.TransformerBlockConfig(
+      attn_config=attn_config,
+      ff_config=ff_config,
+      pre_attention_norm_config=norm_config,
+      post_attention_norm_config=norm_config,
+  )
   config = cfg.ModelConfig(
       vocab_size=256000,
       num_layers=18,
       max_seq_len=8192,
       embedding_dim=2048,
       kv_cache_max_len=kv_cache_max_len,
-
-      ff_config=ff_config,
-      pre_attention_norm_config=norm_config,
-      post_attention_norm_config=norm_config,
+      block_configs=block_config,
       final_norm_config=norm_config,
-      parallel_residual=False,
       lm_head_use_bias=False,
       enable_hlfb=True,
   )
@@ -159,7 +175,8 @@ def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
 
 def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
   config = get_model_config_2b(kv_cache_max_len)
-  config.
+  # Gemma has only one block config.
+  config.block_config(0).ff_config.intermediate_size = 128
   config.vocab_size = 128
   config.num_layers = 2
   config.max_seq_len = 2 * kv_cache_max_len
@@ -170,32 +187,35 @@ def build_2b_model(checkpoint_path: str, **kwargs) -> nn.Module:
   config = get_model_config_2b(**kwargs)
   model = Gemma(config)
   loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
-  #
+  # Since embedding and lm-head use the same weight, we need to set strict
   # to False.
   loader.load(model, strict=False)
   model.eval()
   return model
 
 
-def define_and_run_2b() -> None:
+def define_and_run_2b(checkpoint_path: str) -> None:
   """Instantiates and runs a Gemma 2B model."""
 
-  current_dir = Path(__file__).parent.resolve()
+  current_dir = pathlib.Path(__file__).parent.resolve()
   gemma_goldens = torch.load(current_dir / "gemma_lm_logits.pt")
 
   kv_cache_max_len = 1024
-  checkpoint_path = os.path.join(Path.home(), "Downloads/llm_data/gemma-2b")
   model = build_2b_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
   idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
   tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.long, device="cpu")
   tokens[0, :4] = idx
   input_pos = torch.arange(0, kv_cache_max_len)
-
+  kv = kv_utils.KVCache.from_model_config(model.config)
+  output = model.forward(tokens, input_pos, kv)
   print("comparing with goldens..")
   assert torch.allclose(
-      gemma_goldens,
+      gemma_goldens, output["logits"][0, idx.shape[1] - 1, :], atol=1e-02
   )
 
 
 if __name__ == "__main__":
-
+  input_checkpoint_path = os.path.join(
+      pathlib.Path.home(), "Downloads/llm_data/gemma-2b"
+  )
+  define_and_run_2b(input_checkpoint_path)
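The change to Gemma.forward above replaces the old tensor-returning interface with an explicit contract: the caller passes a kv_utils.KVCache and receives a dict holding "logits" and the updated "kv_cache". A minimal sketch of a prefill plus one greedy decode step under that contract, assuming a checkpoint at an illustrative path:

import torch

from ai_edge_torch.generative.examples.gemma import gemma
from ai_edge_torch.generative.layers import kv_cache as kv_utils

# Illustrative path; substitute a real Gemma 2B checkpoint directory.
model = gemma.build_2b_model("/path/to/gemma-2b", kv_cache_max_len=1024)

# Prefill: run the prompt once and keep the cache returned by the model.
prompt = torch.tensor([[1, 2, 3, 4]], dtype=torch.long)
input_pos = torch.arange(0, prompt.shape[1])
kv = kv_utils.KVCache.from_model_config(model.config)
out = model.forward(prompt, input_pos, kv)
kv = out["kv_cache"]

# Decode: feed one token at a time, threading the updated cache through.
next_token = out["logits"][0, -1, :].argmax().reshape(1, 1)
decode_pos = torch.tensor([prompt.shape[1]], dtype=torch.int64)
out = model.forward(next_token, decode_pos, kv)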
ai_edge_torch/generative/examples/gemma/gemma2.py

@@ -12,14 +12,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-
+
+"""Example of building a Gemma2 model."""
 
 import os
-
+import pathlib
 from typing import Optional, Tuple
 
 from ai_edge_torch.generative.layers import attention
 from ai_edge_torch.generative.layers import builder
+from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
 import ai_edge_torch.generative.layers.model_config as cfg
 import ai_edge_torch.generative.utilities.loader as loading_utils
@@ -51,7 +53,8 @@ class Gemma2Block(attention.TransformerBlock):
       rope: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
       mask: Optional[torch.Tensor] = None,
       input_pos: Optional[torch.Tensor] = None,
-
+      kv_cache: kv_utils.KVCacheEntry = None,
+  ) -> Tuple[torch.Tensor, Optional[kv_utils.KVCacheEntry]]:
     """Forward function of the Gemma2Block.
 
     Exactly the same as TransformerBlock but we call the post-attention norm
@@ -62,17 +65,19 @@ class Gemma2Block(attention.TransformerBlock):
       rope (Tuple[torch.Tensor, torch.Tensor]): the input rope tensor.
       mask (torch.Tensor): the optional mask tensor.
       input_pos (torch.Tensor): the optional input position tensor.
+      kv_cache (KVCacheEntry): the optional kv cache entry.
 
     Returns:
-      output activation from this transformer block
+      output activation from this transformer block, and updated kv cache (if
+        passed in).
     """
 
     x_norm = self.pre_atten_norm(x)
-    attn_out = self.atten_func(x_norm, rope, mask, input_pos)
+    attn_out, kv = self.atten_func(x_norm, rope, mask, input_pos, kv_cache)
     attn_out_norm = self.post_atten_norm(attn_out)
     x = x + attn_out_norm
     output = x + self.ff(x)
-    return output
+    return output, kv
 
 
 class Gemma2(nn.Module):
@@ -81,7 +86,6 @@ class Gemma2(nn.Module):
   def __init__(self, config: cfg.ModelConfig):
     super().__init__()
 
-    self.config = config
     # Construct model layers.
     self.tok_embedding = nn.Embedding(
         config.vocab_size, config.embedding_dim, padding_idx=0
@@ -91,20 +95,22 @@ class Gemma2(nn.Module):
         config.vocab_size,
         bias=config.lm_head_use_bias,
     )
-    #
+    # Gemma2 re-uses the embedding as the head projection layer.
     self.lm_head.weight.data = self.tok_embedding.weight.data
     self.transformer_blocks = nn.ModuleList(
-        Gemma2Block(config)
+        Gemma2Block(config.block_config(idx), config)
+        for idx in range(config.num_layers)
     )
     self.final_norm = builder.build_norm(
         config.embedding_dim,
         config.final_norm_config,
     )
+    # Gemma2 has same hyper parameters for each layer except for attention
+    # types. Use the first layer.
+    attn_config = config.block_config(0).attn_config
     self.rope_cache = attn_utils.build_rope_cache(
         size=config.kv_cache_max,
-        dim=int(
-            config.attn_config.rotary_percentage * config.attn_config.head_dim
-        ),
+        dim=int(attn_config.rotary_percentage * attn_config.head_dim),
         base=10_000,
         condense_ratio=1,
         dtype=torch.float32,
@@ -115,47 +121,56 @@ class Gemma2(nn.Module):
         dtype=torch.float32,
         device=torch.device("cpu"),
     )
-
     self.sliding_window_mask_cache = attn_utils.build_sliding_window_mask_cache(
         size=config.kv_cache_max,
-        window_size=
+        window_size=attn_config.sliding_window_size,
        dtype=torch.float32,
        device=torch.device("cpu"),
     )
-
     self.config = config
 
   def get_attention_mask(
-      self,
+      self, attn_type: cfg.AttentionType, input_pos: torch.Tensor
   ) -> torch.Tensor:
-    if
-
-        self.config.attn_config.attn_types[idx]
-        == cfg.AttentionType.LOCAL_SLIDING
-    ):
-      return self.sliding_window_mask_cache.index_select(2, input_pos)
-
+    if attn_type == cfg.AttentionType.LOCAL_SLIDING:
+      return self.sliding_window_mask_cache.index_select(2, input_pos)
     return self.mask_cache.index_select(2, input_pos)
 
   @torch.inference_mode
-  def forward(
-
+  def forward(
+      self,
+      tokens: torch.Tensor,
+      input_pos: torch.Tensor,
+      kv_cache: kv_utils.KVCache,
+  ) -> dict[torch.Tensor, kv_utils.KVCache]:
+    _, seq_len = tokens.size()
     assert self.config.max_seq_len >= seq_len, (
         f"Cannot forward sequence of length {seq_len}, max seq length is only"
         f" {self.config.max_seq_len}"
     )
+    assert len(self.transformer_blocks) == len(kv_cache.caches), (
+        "The number of transformer blocks and the number of KV cache entries"
+        " must be the same."
+    )
 
     cos, sin = self.rope_cache
     cos = cos.index_select(0, input_pos)
     sin = sin.index_select(0, input_pos)
 
     # token embeddings of shape (b, t, n_embd)
-    x = self.tok_embedding(
+    x = self.tok_embedding(tokens)
     x = x * (self.config.embedding_dim**0.5)
 
+    updated_kv_entires = []
     for i, block in enumerate(self.transformer_blocks):
-      mask = self.get_attention_mask(
-
+      mask = self.get_attention_mask(
+          block.config.attn_config.attn_type, input_pos
+      )
+      kv_entry = kv_cache.caches[i] if kv_cache else None
+      x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
+      if kv_entry:
+        updated_kv_entires.append(kv_entry)
+    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
 
     x = self.final_norm(x)
     res = self.lm_head(x)  # (b, t, vocab_size)
@@ -163,7 +178,8 @@ class Gemma2(nn.Module):
       res = res / self.config.final_logit_softcap
       res = torch.tanh(res)
       res = res * self.config.final_logit_softcap
-
+
+    return {"logits": res, "kv_cache": updated_kv_cache}
 
 
 def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
@@ -176,18 +192,6 @@ def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
   Returns:
     The model config for a Gemma 2B model.
   """
-  attn_config = cfg.AttentionConfig(
-      num_heads=8,
-      head_dim=256,
-      num_query_groups=4,
-      rotary_percentage=1.0,
-      qkv_transpose_before_split=True,
-      logit_softcap=50.0,
-      sliding_window_size=4096,
-      attn_types=[cfg.AttentionType.GLOBAL, cfg.AttentionType.LOCAL_SLIDING]
-      * 13,
-  )
-
   norm_config = cfg.NormalizationConfig(
       type=cfg.NormalizationType.RMS_NORM,
       epsilon=1e-6,
@@ -200,18 +204,38 @@ def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       pre_ff_norm_config=norm_config,
       post_ff_norm_config=norm_config,
   )
+
+  def get_block_config(idx: int) -> cfg.TransformerBlockConfig:
+    attn_config = cfg.AttentionConfig(
+        num_heads=8,
+        head_dim=256,
+        num_query_groups=4,
+        rotary_percentage=1.0,
+        qkv_transpose_before_split=True,
+        logit_softcap=50.0,
+        sliding_window_size=4096,
+        attn_type=(
+            cfg.AttentionType.GLOBAL
+            if idx % 2 == 0
+            else cfg.AttentionType.LOCAL_SLIDING
+        ),
+    )
+    return cfg.TransformerBlockConfig(
+        attn_config=attn_config,
+        ff_config=ff_config,
+        pre_attention_norm_config=norm_config,
+        post_attention_norm_config=norm_config,
+    )
+
+  num_layers = 26
   config = cfg.ModelConfig(
       vocab_size=256000,
-      num_layers=
+      num_layers=num_layers,
       max_seq_len=8192,
       embedding_dim=2304,
       kv_cache_max_len=kv_cache_max_len,
-
-      ff_config=ff_config,
-      pre_attention_norm_config=norm_config,
-      post_attention_norm_config=norm_config,
+      block_configs=[get_block_config(i) for i in range(num_layers)],
       final_norm_config=norm_config,
-      parallel_residual=False,
       lm_head_use_bias=False,
       enable_hlfb=True,
       final_logit_softcap=30.0,
@@ -221,14 +245,16 @@ def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
 
 def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
   config = get_model_config_2b(kv_cache_max_len)
-  config.attn_config.num_heads = 4
-  config.attn_config.head_dim = 64
-  config.attn_config.sliding_window_size = 64
-  config.ff_config.intermediate_size = 128
   config.vocab_size = 128
   config.num_layers = 2
   config.max_seq_len = 2 * kv_cache_max_len
   config.embedding_dim = 128
+  config.block_configs = config.block_configs[: config.num_layers]
+  for block_config in config.block_configs:
+    block_config.attn_config.num_heads = 4
+    block_config.attn_config.head_dim = 64
+    block_config.attn_config.sliding_window_size = 64
+    block_config.ff_config.intermediate_size = 128
   return config
 
 
@@ -236,21 +262,20 @@ def build_2b_model(checkpoint_path: str, **kwargs) -> nn.Module:
   config = get_model_config_2b(**kwargs)
   model = Gemma2(config)
   loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
-  #
+  # Since embedding and lm-head use the same weight, we need to set strict
   # to False.
   loader.load(model, strict=False)
   model.eval()
   return model
 
 
-def define_and_run_2b() -> None:
+def define_and_run_2b(checkpoint_path: str) -> None:
   """Instantiates and runs a Gemma2 2B model."""
 
-  current_dir = Path(__file__).parent.resolve()
+  current_dir = pathlib.Path(__file__).parent.resolve()
   gemma2_goldens = torch.load(current_dir / "gemma2it_2b_golden.pt")
   print("Running GEMMA 2")
   kv_cache_max_len = 1024
-  checkpoint_path = os.path.join(Path.home(), "Downloads/llm_data/gemma2-2b")
   model = build_2b_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
   toks = torch.from_numpy(
       np.array([2, 651, 9456, 576, 573, 3520, 3858, 603, 235248])
@@ -258,11 +283,13 @@ def define_and_run_2b() -> None:
   tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.long, device="cpu")
   tokens[0, :9] = toks
   input_pos = torch.arange(0, kv_cache_max_len)
-
-
+  kv = kv_utils.KVCache.from_model_config(model.config)
+  out = model.forward(tokens, input_pos, kv)
+  out_final = out["logits"][0, 8, :]
   assert torch.allclose(gemma2_goldens, out_final, atol=1e-04)
 
 
 if __name__ == "__main__":
   torch.set_printoptions(sci_mode=True)
-
+  path = os.path.join(pathlib.Path.home(), "Downloads/llm_data/gemma2-2b")
+  define_and_run_2b(path)
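The Gemma2 config change above moves the attention settings from a single shared AttentionConfig into one cfg.TransformerBlockConfig per layer, with get_block_config(idx) alternating global and local sliding-window attention. A small sketch, grounded in the accessors shown in this diff, that reads the per-layer attention types back from the built config:

import ai_edge_torch.generative.layers.model_config as cfg
from ai_edge_torch.generative.examples.gemma import gemma2

config = gemma2.get_model_config_2b()
# Even layers use global attention, odd layers use local sliding-window
# attention, per get_block_config(idx) above.
for idx in range(4):
  attn_type = config.block_config(idx).attn_config.attn_type
  expected = (
      cfg.AttentionType.GLOBAL
      if idx % 2 == 0
      else cfg.AttentionType.LOCAL_SLIDING
  )
  assert attn_type == expected
  print(idx, attn_type)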
ai_edge_torch/generative/examples/phi/convert_to_tflite.py

@@ -12,16 +12,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-
-
-# Please use with caution.
+
+"""Example of converting a Phi-2 model to multi-signature tflite model."""
 
 import os
-
+import pathlib
 
 import ai_edge_torch
-from ai_edge_torch.generative.examples.
-from ai_edge_torch.generative.layers
+from ai_edge_torch.generative.examples.phi import phi2
+from ai_edge_torch.generative.layers import kv_cache
 from ai_edge_torch.generative.quantize import quant_recipes
 import torch
 
@@ -32,9 +31,8 @@ def convert_phi2_to_tflite(
     kv_cache_max_len: int = 1024,
     quantize: bool = True,
 ):
-  """
+  """Converts a Phi-2 model to multi-signature tflite model.
 
-      tflite model.
   Args:
     checkpoint_path (str): The filepath to the model checkpoint, or directory
       holding the checkpoint.
@@ -53,7 +51,7 @@ def convert_phi2_to_tflite(
   prefill_input_pos = torch.arange(0, prefill_seq_len)
   decode_token = torch.tensor([[0]], dtype=torch.long)
   decode_input_pos = torch.tensor([0], dtype=torch.int64)
-  kv =
+  kv = kv_cache.KVCache.from_model_config(pytorch_model.config)
 
   quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
   edge_model = (
@@ -77,11 +75,12 @@ def convert_phi2_to_tflite(
       )
       .convert(quant_config=quant_config)
   )
+  quant_suffix = 'q8' if quantize else 'f32'
   edge_model.export(
-      f'/tmp/
+      f'/tmp/phi2_{quant_suffix}_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
   )
 
 
 if __name__ == '__main__':
-
-  convert_phi2_to_tflite(
+  path = os.path.join(pathlib.Path.home(), 'Downloads/llm_data/phi2')
+  convert_phi2_to_tflite(path)
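As with the Gemma converters, the Phi-2 export path now encodes the quantization mode and buffer sizes through quant_suffix. A hedged invocation sketch, assuming Phi-2 weights at an illustrative local path:

import pathlib

from ai_edge_torch.generative.examples.phi import convert_to_tflite

# Illustrative checkpoint location; substitute the real Phi-2 weights directory.
checkpoint = str(pathlib.Path.home() / "Downloads/llm_data/phi2")

# With quantize=False the export lands at
# /tmp/phi2_f32_seq1024_ekv1280.tflite per the naming scheme above.
convert_to_tflite.convert_phi2_to_tflite(
    checkpoint,
    prefill_seq_len=1024,
    kv_cache_max_len=1280,
    quantize=False,
)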