ai-edge-torch-nightly 0.3.0.dev20240910__py3-none-any.whl → 0.3.0.dev20240914__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to their respective public registries. It is provided for informational purposes only.
- ai_edge_torch/_convert/conversion.py +2 -1
- ai_edge_torch/_convert/fx_passes/__init__.py +5 -41
- ai_edge_torch/_convert/fx_passes/build_aten_composite_pass.py +3 -4
- ai_edge_torch/_convert/fx_passes/build_interpolate_composite_pass.py +3 -4
- ai_edge_torch/_convert/fx_passes/inject_mlir_debuginfo_pass.py +3 -4
- ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/pass_body.py +4 -5
- ai_edge_torch/config.py +4 -1
- ai_edge_torch/fx_pass_base.py +101 -0
- ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +35 -16
- ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +29 -10
- ai_edge_torch/generative/examples/gemma/gemma.py +52 -32
- ai_edge_torch/generative/examples/gemma/gemma2.py +87 -60
- ai_edge_torch/generative/examples/{experimental/gemma → openelm}/convert_to_tflite.py +16 -18
- ai_edge_torch/generative/examples/openelm/openelm.py +237 -0
- ai_edge_torch/generative/examples/{experimental/phi → phi}/convert_to_tflite.py +15 -16
- ai_edge_torch/generative/examples/{experimental/phi → phi}/phi2.py +48 -45
- ai_edge_torch/generative/examples/{experimental/tiny_llama → smollm}/convert_to_tflite.py +16 -17
- ai_edge_torch/generative/examples/smollm/smollm.py +131 -0
- ai_edge_torch/generative/examples/stable_diffusion/clip.py +12 -6
- ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +1 -1
- ai_edge_torch/generative/examples/t5/convert_to_tflite.py +20 -20
- ai_edge_torch/generative/examples/t5/t5.py +43 -30
- ai_edge_torch/generative/examples/t5/t5_attention.py +18 -13
- ai_edge_torch/generative/examples/test_models/toy_model.py +15 -13
- ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +75 -34
- ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +29 -10
- ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +57 -36
- ai_edge_torch/generative/fx_passes/__init__.py +4 -4
- ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py +3 -4
- ai_edge_torch/generative/layers/attention.py +84 -73
- ai_edge_torch/generative/layers/builder.py +38 -14
- ai_edge_torch/generative/layers/feed_forward.py +26 -8
- ai_edge_torch/generative/layers/kv_cache.py +163 -51
- ai_edge_torch/generative/layers/model_config.py +61 -33
- ai_edge_torch/generative/layers/normalization.py +158 -0
- ai_edge_torch/generative/layers/unet/blocks_2d.py +0 -2
- ai_edge_torch/generative/quantize/example.py +2 -2
- ai_edge_torch/generative/test/{test_experimental_ekv.py → test_kv_cache.py} +12 -24
- ai_edge_torch/generative/test/test_loader.py +1 -1
- ai_edge_torch/generative/test/test_model_conversion.py +77 -62
- ai_edge_torch/generative/test/test_model_conversion_large.py +61 -68
- ai_edge_torch/generative/test/test_quantize.py +5 -5
- ai_edge_torch/generative/test/utils.py +54 -0
- ai_edge_torch/generative/utilities/loader.py +28 -15
- ai_edge_torch/generative/utilities/t5_loader.py +21 -20
- ai_edge_torch/odml_torch/export.py +40 -0
- ai_edge_torch/odml_torch/lowerings/__init__.py +1 -0
- ai_edge_torch/odml_torch/lowerings/_basic.py +44 -0
- ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +0 -2
- ai_edge_torch/odml_torch/lowerings/_layer_norm.py +78 -0
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240914.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240914.dist-info}/RECORD +59 -63
- ai_edge_torch/_convert/fx_passes/_pass_base.py +0 -53
- ai_edge_torch/_convert/fx_passes/canonicalize_pass.py +0 -35
- ai_edge_torch/generative/examples/experimental/gemma/gemma.py +0 -219
- ai_edge_torch/generative/examples/experimental/tiny_llama/__init__.py +0 -14
- ai_edge_torch/generative/examples/experimental/tiny_llama/tiny_llama.py +0 -205
- ai_edge_torch/generative/examples/phi2/__init__.py +0 -14
- ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +0 -67
- ai_edge_torch/generative/examples/phi2/phi2.py +0 -189
- ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +0 -176
- /ai_edge_torch/generative/examples/{experimental → openelm}/__init__.py +0 -0
- /ai_edge_torch/generative/examples/{experimental/gemma → phi}/__init__.py +0 -0
- /ai_edge_torch/generative/examples/{experimental/phi → smollm}/__init__.py +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240914.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240914.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240914.dist-info}/top_level.txt +0 -0
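The renames above trace this release's main theme: the `examples/experimental` tree and the old `phi2` example are removed in favor of stable `phi`, `openelm`, and `smollm` packages, `layers/kv_cache.py` absorbs the experimental external-KV-cache work (+163 -51), and a dedicated `normalization.py` plus an odml_torch layer-norm lowering land alongside. The import migration implied by the renames looks roughly like the sketch below; the symbol names under the new paths are assumptions inferred from the file list, not shown in this diff.

    # Before (0.3.0.dev20240910): experimental paths.
    from ai_edge_torch.generative.layers.experimental import ekv_cache as kv_utils
    from ai_edge_torch.generative.examples.experimental.phi import phi2

    # After (0.3.0.dev20240914): graduated paths, assuming the modules keep
    # equivalent entry points at their new locations.
    from ai_edge_torch.generative.layers import kv_cache as kv_utils
    from ai_edge_torch.generative.examples.phi import phi2

The deleted files reproduced below show what the experimental tree looked like before the move.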
--- a/ai_edge_torch/generative/examples/experimental/tiny_llama/tiny_llama.py
+++ /dev/null
@@ -1,205 +0,0 @@
-# Copyright 2024 The AI Edge Torch Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-# Example of building a TinyLlama model from the Edge Generative API layers.
-#
-# Note: This is an experimental version of TinyLlama with external KV cache.
-# Please use with caution.
-
-import os
-from pathlib import Path
-from typing import Tuple
-
-from ai_edge_torch.generative.layers import builder
-import ai_edge_torch.generative.layers.attention_utils as attn_utils
-from ai_edge_torch.generative.layers.experimental import attention
-from ai_edge_torch.generative.layers.experimental import ekv_cache as kv_utils
-import ai_edge_torch.generative.layers.model_config as cfg
-import ai_edge_torch.generative.utilities.loader as loading_utils
-import numpy as np
-import torch
-from torch import nn
-
-
-TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
-    ff_up_proj="model.layers.{}.mlp.up_proj",
-    ff_down_proj="model.layers.{}.mlp.down_proj",
-    ff_gate_proj="model.layers.{}.mlp.gate_proj",
-    attn_query_proj="model.layers.{}.self_attn.q_proj",
-    attn_key_proj="model.layers.{}.self_attn.k_proj",
-    attn_value_proj="model.layers.{}.self_attn.v_proj",
-    attn_output_proj="model.layers.{}.self_attn.o_proj",
-    pre_attn_norm="model.layers.{}.input_layernorm",
-    post_attn_norm="model.layers.{}.post_attention_layernorm",
-    embedding="model.embed_tokens",
-    final_norm="model.norm",
-    lm_head="lm_head",
-)
-
-
-class TinyLLamma(nn.Module):
-  """A TinyLlama model built from the Edge Generative API layers."""
-
-  def __init__(self, config: cfg.ModelConfig):
-    super().__init__()
-
-    self.config = config
-    # Construct model layers.
-    self.lm_head = nn.Linear(
-        config.embedding_dim, config.vocab_size, bias=config.lm_head_use_bias
-    )
-    self.tok_embedding = nn.Embedding(
-        config.vocab_size, config.embedding_dim, padding_idx=0
-    )
-    self.transformer_blocks = nn.ModuleList(
-        attention.TransformerBlock(config) for _ in range(config.num_layers)
-    )
-    self.final_norm = builder.build_norm(
-        config.embedding_dim,
-        config.final_norm_config,
-    )
-    self.rope_cache = attn_utils.build_rope_cache(
-        size=config.kv_cache_max,
-        dim=int(
-            config.attn_config.rotary_percentage * config.attn_config.head_dim
-        ),
-        base=10_000,
-        condense_ratio=1,
-        dtype=torch.float32,
-        device=torch.device("cpu"),
-    )
-    self.mask_cache = attn_utils.build_causal_mask_cache(
-        size=config.kv_cache_max,
-        dtype=torch.float32,
-        device=torch.device("cpu"),
-    )
-    self.config = config
-
-  @torch.inference_mode
-  def forward(
-      self,
-      tokens: torch.Tensor,
-      input_pos: torch.Tensor,
-      kv_cache: kv_utils.EKVCache,
-  ) -> Tuple[torch.Tensor, kv_utils.EKVCache]:
-    _, seq_len = tokens.size()
-    assert self.config.max_seq_len >= seq_len, (
-        f"Cannot forward sequence of length {seq_len}, max seq length is only"
-        f" {self.config.max_seq_len}"
-    )
-
-    cos, sin = self.rope_cache
-    cos = cos.index_select(0, input_pos)
-    sin = sin.index_select(0, input_pos)
-    mask = self.mask_cache.index_select(2, input_pos)
-    mask = mask[:, :, :, : self.config.kv_cache_max]
-
-    # token embeddings of shape (b, t, n_embd)
-    x = self.tok_embedding(tokens)
-
-    updated_kv_entires = []
-    for i, block in enumerate(self.transformer_blocks):
-      kv_entry = kv_cache.caches[i] if kv_cache else None
-      x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
-      if kv_entry:
-        updated_kv_entires.append(kv_entry)
-    updated_kv_cache = kv_utils.EKVCache(tuple(updated_kv_entires))
-
-    x = self.final_norm(x)
-    res = self.lm_head(x)  # (b, t, vocab_size)
-    return res, updated_kv_cache
-
-
-def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
-  """Returns the model config for a TinyLlama model.
-
-  Args:
-    kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
-      is 1024.
-
-  Returns:
-    The model config for a TinyLlama model.
-  """
-  attn_config = cfg.AttentionConfig(
-      num_heads=32,
-      head_dim=64,
-      num_query_groups=4,
-      rotary_percentage=1.0,
-  )
-  ff_config = cfg.FeedForwardConfig(
-      type=cfg.FeedForwardType.GATED,
-      activation=cfg.ActivationConfig(cfg.ActivationType.SILU),
-      intermediate_size=5632,
-  )
-  norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.RMS_NORM)
-  config = cfg.ModelConfig(
-      vocab_size=32000,
-      num_layers=22,
-      max_seq_len=2048,
-      embedding_dim=2048,
-      kv_cache_max_len=kv_cache_max_len,
-      attn_config=attn_config,
-      ff_config=ff_config,
-      pre_attention_norm_config=norm_config,
-      post_attention_norm_config=norm_config,
-      final_norm_config=norm_config,
-      enable_hlfb=True,
-  )
-  return config
-
-
-def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
-  config = get_model_config(**kwargs)
-  config.vocab_size = 128
-  config.num_layers = 2
-  config.ff_config.intermediate_size = 256
-  return config
-
-
-def build_model(
-    checkpoint_path: str, test_model: bool = False, **kwargs
-) -> nn.Module:
-  """Instantiates the model instance and load checkpoint if provided."""
-  config = (
-      get_fake_model_config(**kwargs)
-      if test_model
-      else get_model_config(**kwargs)
-  )
-  model = TinyLLamma(config)
-  if checkpoint_path is not None:
-    loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
-    loader.load(model)
-  model.eval()
-  return model
-
-
-def define_and_run(checkpoint_path: str, test_model: bool = False) -> None:
-  """Instantiates and runs a TinyLlama model."""
-
-  kv_cache_max_len = 1024
-  model = build_model(
-      checkpoint_path, test_model=test_model, kv_cache_max_len=kv_cache_max_len
-  )
-  idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
-  tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.long, device="cpu")
-  tokens[0, :4] = idx
-  input_pos = torch.arange(0, kv_cache_max_len)
-  kv = kv_utils.EKVCache.from_model_config(model.config)
-  print("running an inference")
-  print(model.forward(tokens, input_pos, kv))
-
-
-if __name__ == "__main__":
-  input_checkpoint_path = os.path.join(Path.home(), "Downloads/tiny_llama")
-  define_and_run(input_checkpoint_path)
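The surviving `tiny_llama/tiny_llama.py` (+57 -36 in the list above) takes over this external-KV-cache design, now backed by the graduated `generative/layers/kv_cache.py`. A minimal sketch of the equivalent run against the new layout, assuming `build_model` and a `KVCache.from_model_config` constructor carry over from the deleted module (neither is shown in this diff):

    import os
    from pathlib import Path

    import torch
    from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
    from ai_edge_torch.generative.layers import kv_cache as kv_utils

    # Checkpoint location reused from the deleted example above.
    checkpoint_path = os.path.join(Path.home(), "Downloads/tiny_llama")
    model = tiny_llama.build_model(checkpoint_path, kv_cache_max_len=1024)

    tokens = torch.zeros((1, 1024), dtype=torch.long)
    input_pos = torch.arange(0, 1024)
    # Assumption: the graduated KVCache keeps the from_model_config
    # constructor that EKVCache exposed in the deleted module.
    kv = kv_utils.KVCache.from_model_config(model.config)
    logits, kv = model.forward(tokens, input_pos, kv)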
--- a/ai_edge_torch/generative/examples/phi2/__init__.py
+++ /dev/null
@@ -1,14 +0,0 @@
-# Copyright 2024 The AI Edge Torch Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
--- a/ai_edge_torch/generative/examples/phi2/convert_to_tflite.py
+++ /dev/null
@@ -1,67 +0,0 @@
-# Copyright 2024 The AI Edge Torch Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-
-import os
-from pathlib import Path
-
-import ai_edge_torch
-from ai_edge_torch.generative.examples.phi2 import phi2
-from ai_edge_torch.generative.quantize import quant_recipes
-import torch
-
-
-def convert_phi2_to_tflite(
-    checkpoint_path: str,
-    prefill_seq_len: int = 512,
-    kv_cache_max_len: int = 1024,
-    quantize: bool = True,
-):
-  """Converts a Phi-2 model to multi-signature tflite model.
-
-  Args:
-    checkpoint_path (str): The filepath to the model checkpoint, or directory
-      holding the checkpoint.
-    prefill_seq_len (int, optional): The maximum size of prefill input tensor.
-      Defaults to 512.
-    kv_cache_max_len (int, optional): The maximum size of KV cache buffer,
-      including both prefill and decode. Defaults to 1024.
-    quantize (bool, optional): Whether the model should be quanized. Defaults
-      to True.
-  """
-  pytorch_model = phi2.build_model(
-      checkpoint_path, kv_cache_max_len=kv_cache_max_len
-  )
-  # Tensors used to trace the model graph during conversion.
-  prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.long)
-  prefill_input_pos = torch.arange(0, prefill_seq_len)
-  decode_token = torch.tensor([[0]], dtype=torch.long)
-  decode_input_pos = torch.tensor([0], dtype=torch.int64)
-
-  quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
-  edge_model = (
-      ai_edge_torch.signature(
-          'prefill', pytorch_model, (prefill_tokens, prefill_input_pos)
-      )
-      .signature('decode', pytorch_model, (decode_token, decode_input_pos))
-      .convert(quant_config=quant_config)
-  )
-  edge_model.export(
-      f'/tmp/phi2_seq{prefill_seq_len}_kv{kv_cache_max_len}.tflite'
-  )
-
-
-if __name__ == '__main__':
-  checkpoint_path = os.path.join(Path.home(), 'Downloads/llm_data/phi2')
-  convert_phi2_to_tflite(checkpoint_path)
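Only the location changed for the converter: the file list shows `convert_to_tflite.py` moving from `examples/experimental/phi` to `examples/phi` (+15 -16). A hedged sketch of calling it at the new path, assuming the `convert_phi2_to_tflite` entry point survives the move:

    import os
    from pathlib import Path

    # Assumption: the entry-point name carries over from the deleted
    # examples/phi2 converter shown above.
    from ai_edge_torch.generative.examples.phi import convert_to_tflite

    checkpoint_path = os.path.join(Path.home(), 'Downloads/llm_data/phi2')
    convert_to_tflite.convert_phi2_to_tflite(checkpoint_path)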
--- a/ai_edge_torch/generative/examples/phi2/phi2.py
+++ /dev/null
@@ -1,189 +0,0 @@
-# Copyright 2024 The AI Edge Torch Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-# Example of building phi-2 model from the Edge Generative API layers.
-
-
-import os
-from pathlib import Path
-
-from ai_edge_torch.generative.layers import attention
-from ai_edge_torch.generative.layers import builder
-import ai_edge_torch.generative.layers.attention_utils as attn_utils
-import ai_edge_torch.generative.layers.model_config as cfg
-import ai_edge_torch.generative.utilities.loader as loading_utils
-import numpy as np
-import torch
-from torch import nn
-
-TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
-    ff_up_proj="model.layers.{}.mlp.fc1",
-    ff_down_proj="model.layers.{}.mlp.fc2",
-    attn_query_proj="model.layers.{}.self_attn.q_proj",
-    attn_key_proj="model.layers.{}.self_attn.k_proj",
-    attn_value_proj="model.layers.{}.self_attn.v_proj",
-    attn_output_proj="model.layers.{}.self_attn.dense",
-    pre_attn_norm="model.layers.{}.input_layernorm",
-    embedding="model.embed_tokens",
-    final_norm="model.final_layernorm",
-    lm_head="lm_head",
-)
-
-
-class Phi2(nn.Module):
-  """A Phi-2 model built from the Edge Generative API layers."""
-
-  def __init__(self, config: cfg.ModelConfig):
-    super().__init__()
-
-    self.config = config
-    # Construct model layers.
-    self.lm_head = nn.Linear(
-        config.embedding_dim, config.vocab_size, bias=config.lm_head_use_bias
-    )
-    self.tok_embedding = nn.Embedding(
-        config.vocab_size, config.embedding_dim, padding_idx=0
-    )
-    self.transformer_blocks = nn.ModuleList(
-        attention.TransformerBlock(config) for _ in range(config.num_layers)
-    )
-    self.final_norm = builder.build_norm(
-        config.embedding_dim,
-        config.final_norm_config,
-    )
-    self.rope_cache = attn_utils.build_rope_cache(
-        size=config.kv_cache_max,
-        dim=int(
-            config.attn_config.rotary_percentage * config.attn_config.head_dim
-        ),
-        base=10_000,
-        condense_ratio=1,
-        dtype=torch.float32,
-        device=torch.device("cpu"),
-    )
-    self.mask_cache = attn_utils.build_causal_mask_cache(
-        size=config.kv_cache_max,
-        dtype=torch.float32,
-        device=torch.device("cpu"),
-    )
-    self.config = config
-
-  # The model's forward function takes in additional k/v cache tensors
-  # and returns the updated k/v cache tensors to the caller.
-  # This can be eliminated if we handle k/v cache updates inside the model itself.
-  @torch.inference_mode
-  def forward(self, idx: torch.Tensor, input_pos: torch.Tensor) -> torch.Tensor:
-    _, seq_len = idx.size()
-    assert self.config.max_seq_len >= seq_len, (
-        f"Cannot forward sequence of length {seq_len}, max seq length is only"
-        f" {self.config.max_seq_len}"
-    )
-
-    cos, sin = self.rope_cache
-    cos = cos.index_select(0, input_pos)
-    sin = sin.index_select(0, input_pos)
-    mask = self.mask_cache.index_select(2, input_pos)
-    mask = mask[:, :, :, : self.config.kv_cache_max]
-
-    # forward the model itself
-    x = self.tok_embedding(idx)  # token embeddings of shape (b, t, n_embd)
-
-    for _, block in enumerate(self.transformer_blocks):
-      x = block(x, (cos, sin), mask, input_pos)
-
-    x = self.final_norm(x)
-    res = self.lm_head(x)  # (b, t, vocab_size)
-    return res
-
-
-def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
-  """Returns the model config for a Phi-2 model.
-
-  Args:
-    kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
-      is 1024.
-
-  Returns:
-    The model config for a Phi-2 model.
-  """
-  attn_config = cfg.AttentionConfig(
-      num_heads=32,
-      head_dim=80,
-      num_query_groups=32,
-      rotary_percentage=0.4,
-      qkv_use_bias=True,
-      output_proj_use_bias=True,
-  )
-  ff_config = cfg.FeedForwardConfig(
-      type=cfg.FeedForwardType.SEQUENTIAL,
-      activation=cfg.ActivationConfig(cfg.ActivationType.GELU_TANH),
-      intermediate_size=10240,
-      use_bias=True,
-  )
-  norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.LAYER_NORM)
-  config = cfg.ModelConfig(
-      vocab_size=51200,
-      num_layers=32,
-      max_seq_len=2048,
-      kv_cache_max_len=kv_cache_max_len,
-      embedding_dim=2560,
-      attn_config=attn_config,
-      ff_config=ff_config,
-      pre_attention_norm_config=norm_config,
-      final_norm_config=norm_config,
-      parallel_residual=True,
-      lm_head_use_bias=True,
-      enable_hlfb=True,
-  )
-  return config
-
-
-def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
-  config = get_model_config(kv_cache_max_len)
-  config.vocab_size = 128
-  config.num_layers = 2
-  config.max_seq_len = 2 * kv_cache_max_len
-  config.ff_config.intermediate_size = 128
-  return config
-
-
-def build_model(checkpoint_path, **kwargs) -> nn.Module:
-  config = get_model_config(**kwargs)
-  model = Phi2(config)
-  loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
-  loader.load(model)
-  return model
-
-
-def define_and_run() -> None:
-  """Instantiates and runs a Phi-2 model."""
-
-  current_dir = Path(__file__).parent.resolve()
-  phi2_goldens = torch.load(current_dir / "phi2_lm_logits.pt")
-  kv_cache_max_len = 1024
-  checkpoint_path = os.path.join(Path.home(), "Downloads/llm_data/phi2")
-  model = build_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
-  idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
-  tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.long, device="cpu")
-  tokens[0, :4] = idx
-  input_pos = torch.arange(0, kv_cache_max_len)
-  lm_logits = model.forward(tokens, input_pos)
-  print("comparing with goldens..")
-  assert torch.allclose(
-      phi2_goldens, lm_logits[0, idx.shape[1] - 1, :], atol=1e-05
-  )
-
-
-if __name__ == "__main__":
-  define_and_run()
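The contrast with the experimental modules above highlights what changed in the surviving examples: this deleted `phi2.py` keeps KV state implicit (`forward(idx, input_pos) -> logits`), while the external-cache contract threads the cache through every call. A sketch of that contract as a decode step; the `KVCache` name is an assumption about the graduated `layers/kv_cache.py`, and the forward signature is inferred from the deleted experimental modules rather than from this diff's added lines:

    import torch
    from ai_edge_torch.generative.layers import kv_cache as kv_utils


    def decode_one_token(model, token: torch.Tensor, pos: torch.Tensor,
                         kv: "kv_utils.KVCache"):
      # The caller owns the cache: each step passes the previous cache in
      # and gets an updated copy back, so the model stays stateless.
      logits, kv = model.forward(token, pos, kv)
      return logits[:, -1, :], kv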
--- a/ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py
+++ /dev/null
@@ -1,176 +0,0 @@
-# Copyright 2024 The AI Edge Torch Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-# A toy example which has basic transformer block (w/ externalized KV-Cache).
-
-from typing import Tuple
-
-import ai_edge_torch
-from ai_edge_torch import lowertools
-import ai_edge_torch.generative.layers.attention_utils as attn_utils
-import ai_edge_torch.generative.layers.builder as builder
-from ai_edge_torch.generative.layers.experimental import ekv_cache as kv_utils
-from ai_edge_torch.generative.layers.experimental.attention import TransformerBlock  # NOQA
-import ai_edge_torch.generative.layers.model_config as cfg
-import torch
-import torch.nn as nn
-
-RoPECache = Tuple[torch.Tensor, torch.Tensor]
-
-
-class ToyModelWithExternalKV(torch.nn.Module):
-
-  def __init__(self, config: cfg.ModelConfig) -> None:
-    super().__init__()
-    self.lm_head = nn.Linear(
-        config.embedding_dim, config.vocab_size, bias=config.lm_head_use_bias
-    )
-    self.tok_embedding = nn.Embedding(config.vocab_size, config.embedding_dim)
-    self.transformer_blocks = nn.ModuleList(
-        TransformerBlock(config) for _ in range(config.num_layers)
-    )
-    self.final_norm = builder.build_norm(
-        config.embedding_dim,
-        config.final_norm_config,
-    )
-    self.rope_cache = attn_utils.build_rope_cache(
-        size=config.max_seq_len,
-        dim=int(
-            config.attn_config.rotary_percentage * config.attn_config.head_dim
-        ),
-        base=10_000,
-        condense_ratio=1,
-        dtype=torch.float32,
-        device=torch.device('cpu'),
-    )
-    self.mask_cache = attn_utils.build_causal_mask_cache(
-        size=config.max_seq_len, dtype=torch.float32, device=torch.device('cpu')
-    )
-    self.config = config
-
-  def forward(
-      self,
-      tokens: torch.Tensor,
-      input_pos: torch.Tensor,
-      kv_cache: kv_utils.EKVCache,
-  ) -> Tuple[torch.Tensor, kv_utils.EKVCache]:
-    x = self.tok_embedding(tokens)
-    cos, sin = self.rope_cache
-    cos = cos.index_select(0, input_pos)
-    sin = sin.index_select(0, input_pos)
-    mask = self.mask_cache.index_select(2, input_pos)
-    mask = mask[:, :, :, : self.config.max_seq_len]
-
-    updated_kv_entires = []
-    for i, block in enumerate(self.transformer_blocks):
-      kv_entry = kv_cache.caches[i] if kv_cache else None
-      x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
-      if kv_entry:
-        updated_kv_entires.append(kv_entry)
-
-    x = self.final_norm(x)
-    updated_kv_cache = kv_utils.EKVCache(tuple(updated_kv_entires))
-    return self.lm_head(x), updated_kv_cache
-
-
-def _export_stablehlo_mlir(model, args):
-  ep = torch.export.export(model, args)
-  return lowertools.exported_program_to_mlir_text(ep)
-
-
-def get_model_config() -> cfg.ModelConfig:
-  attn_config = cfg.AttentionConfig(
-      num_heads=32, head_dim=4, num_query_groups=4, rotary_percentage=1.0
-  )
-  ff_config = cfg.FeedForwardConfig(
-      type=cfg.FeedForwardType.GATED,
-      activation=cfg.ActivationConfig(cfg.ActivationType.SILU),
-      intermediate_size=256,
-  )
-  norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.RMS_NORM)
-  config = cfg.ModelConfig(
-      vocab_size=150,
-      num_layers=2,
-      max_seq_len=100,
-      embedding_dim=128,
-      attn_config=attn_config,
-      ff_config=ff_config,
-      pre_attention_norm_config=norm_config,
-      post_attention_norm_config=norm_config,
-      final_norm_config=norm_config,
-      enable_hlfb=True,
-  )
-  return config
-
-
-def get_sample_prefill_inputs() -> Tuple[torch.Tensor, torch.Tensor]:
-  tokens = torch.unsqueeze(torch.arange(0, 100), 0)
-  input_pos = torch.arange(0, 100)
-  return tokens, input_pos
-
-
-def get_sample_decode_inputs() -> Tuple[torch.Tensor, torch.Tensor]:
-  tokens = torch.tensor([[1]], dtype=torch.long)
-  input_pos = torch.tensor([10])
-  return tokens, input_pos
-
-
-def define_and_run() -> None:
-  dump_mlir = False
-
-  config = get_model_config()
-  model = ToyModelWithExternalKV(config)
-  model.eval()
-  print('running an inference')
-  kv = kv_utils.EKVCache.from_model_config(config)
-
-  tokens, input_pos = get_sample_prefill_inputs()
-  decode_token, decode_input_pos = get_sample_decode_inputs()
-  print(model.forward(tokens, input_pos, kv))
-
-  if dump_mlir:
-    mlir_text = _export_stablehlo_mlir(model, (tokens, input_pos, kv))
-    with open('/tmp/toy_model_with_external_kv.stablehlo.mlir', 'w') as f:
-      f.write(mlir_text)
-
-  # Convert model to tflite with 2 signatures (prefill + decode).
-  # TODO(b/344014416): currently conversion will fail, because we generate int64 index
-  # in dynamic update slice op.
-  print('converting toy model to tflite with 2 signatures (prefill + decode)')
-  edge_model = (
-      ai_edge_torch.signature(
-          'prefill',
-          model,
-          sample_kwargs={
-              'tokens': tokens,
-              'input_pos': input_pos,
-              'kv_cache': kv,
-          },
-      )
-      .signature(
-          'decode',
-          model,
-          sample_kwargs={
-              'tokens': decode_token,
-              'input_pos': decode_input_pos,
-              'kv_cache': kv,
-          },
-      )
-      .convert()
-  )
-  edge_model.export('/tmp/toy_external_kv_cache.tflite')
-
-
-if __name__ == '__main__':
-  define_and_run()
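Its replacement is `test_models/toy_model_with_kv_cache.py` (+75 -34 above), dropping the "external" qualifier now that the cache-as-IO pattern is the default. A sketch of the same two-signature conversion against the renamed module; the class and helper names under the new path are assumptions carried over from the deleted file:

    import ai_edge_torch
    import torch
    from ai_edge_torch.generative.examples.test_models import (
        toy_model_with_kv_cache as toy,
    )
    from ai_edge_torch.generative.layers import kv_cache as kv_utils

    config = toy.get_model_config()
    model = toy.ToyModelWithKVCache(config).eval()  # assumed class name
    kv = kv_utils.KVCache.from_model_config(config)
    tokens = torch.unsqueeze(torch.arange(0, 100), 0)
    input_pos = torch.arange(0, 100)

    edge_model = ai_edge_torch.signature(
        'prefill',
        model,
        sample_kwargs={'tokens': tokens, 'input_pos': input_pos, 'kv_cache': kv},
    ).convert()
    edge_model.export('/tmp/toy_kv_cache.tflite')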
File without changes (×5)