ai-edge-torch-nightly 0.3.0.dev20240910__py3-none-any.whl → 0.3.0.dev20240914__py3-none-any.whl
This diff covers publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.
- ai_edge_torch/_convert/conversion.py +2 -1
- ai_edge_torch/_convert/fx_passes/__init__.py +5 -41
- ai_edge_torch/_convert/fx_passes/build_aten_composite_pass.py +3 -4
- ai_edge_torch/_convert/fx_passes/build_interpolate_composite_pass.py +3 -4
- ai_edge_torch/_convert/fx_passes/inject_mlir_debuginfo_pass.py +3 -4
- ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/pass_body.py +4 -5
- ai_edge_torch/config.py +4 -1
- ai_edge_torch/fx_pass_base.py +101 -0
- ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +35 -16
- ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +29 -10
- ai_edge_torch/generative/examples/gemma/gemma.py +52 -32
- ai_edge_torch/generative/examples/gemma/gemma2.py +87 -60
- ai_edge_torch/generative/examples/{experimental/gemma → openelm}/convert_to_tflite.py +16 -18
- ai_edge_torch/generative/examples/openelm/openelm.py +237 -0
- ai_edge_torch/generative/examples/{experimental/phi → phi}/convert_to_tflite.py +15 -16
- ai_edge_torch/generative/examples/{experimental/phi → phi}/phi2.py +48 -45
- ai_edge_torch/generative/examples/{experimental/tiny_llama → smollm}/convert_to_tflite.py +16 -17
- ai_edge_torch/generative/examples/smollm/smollm.py +131 -0
- ai_edge_torch/generative/examples/stable_diffusion/clip.py +12 -6
- ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +1 -1
- ai_edge_torch/generative/examples/t5/convert_to_tflite.py +20 -20
- ai_edge_torch/generative/examples/t5/t5.py +43 -30
- ai_edge_torch/generative/examples/t5/t5_attention.py +18 -13
- ai_edge_torch/generative/examples/test_models/toy_model.py +15 -13
- ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +75 -34
- ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +29 -10
- ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +57 -36
- ai_edge_torch/generative/fx_passes/__init__.py +4 -4
- ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py +3 -4
- ai_edge_torch/generative/layers/attention.py +84 -73
- ai_edge_torch/generative/layers/builder.py +38 -14
- ai_edge_torch/generative/layers/feed_forward.py +26 -8
- ai_edge_torch/generative/layers/kv_cache.py +163 -51
- ai_edge_torch/generative/layers/model_config.py +61 -33
- ai_edge_torch/generative/layers/normalization.py +158 -0
- ai_edge_torch/generative/layers/unet/blocks_2d.py +0 -2
- ai_edge_torch/generative/quantize/example.py +2 -2
- ai_edge_torch/generative/test/{test_experimental_ekv.py → test_kv_cache.py} +12 -24
- ai_edge_torch/generative/test/test_loader.py +1 -1
- ai_edge_torch/generative/test/test_model_conversion.py +77 -62
- ai_edge_torch/generative/test/test_model_conversion_large.py +61 -68
- ai_edge_torch/generative/test/test_quantize.py +5 -5
- ai_edge_torch/generative/test/utils.py +54 -0
- ai_edge_torch/generative/utilities/loader.py +28 -15
- ai_edge_torch/generative/utilities/t5_loader.py +21 -20
- ai_edge_torch/odml_torch/export.py +40 -0
- ai_edge_torch/odml_torch/lowerings/__init__.py +1 -0
- ai_edge_torch/odml_torch/lowerings/_basic.py +44 -0
- ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +0 -2
- ai_edge_torch/odml_torch/lowerings/_layer_norm.py +78 -0
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240914.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240914.dist-info}/RECORD +59 -63
- ai_edge_torch/_convert/fx_passes/_pass_base.py +0 -53
- ai_edge_torch/_convert/fx_passes/canonicalize_pass.py +0 -35
- ai_edge_torch/generative/examples/experimental/gemma/gemma.py +0 -219
- ai_edge_torch/generative/examples/experimental/tiny_llama/__init__.py +0 -14
- ai_edge_torch/generative/examples/experimental/tiny_llama/tiny_llama.py +0 -205
- ai_edge_torch/generative/examples/phi2/__init__.py +0 -14
- ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +0 -67
- ai_edge_torch/generative/examples/phi2/phi2.py +0 -189
- ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +0 -176
- /ai_edge_torch/generative/examples/{experimental → openelm}/__init__.py +0 -0
- /ai_edge_torch/generative/examples/{experimental/gemma → phi}/__init__.py +0 -0
- /ai_edge_torch/generative/examples/{experimental/phi → smollm}/__init__.py +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240914.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240914.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240914.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/examples/gemma/gemma2.py

```diff
@@ -12,14 +12,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-
+
+"""Example of building a Gemma2 model."""
 
 import os
-
+import pathlib
 from typing import Optional, Tuple
 
 from ai_edge_torch.generative.layers import attention
 from ai_edge_torch.generative.layers import builder
+from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
 import ai_edge_torch.generative.layers.model_config as cfg
 import ai_edge_torch.generative.utilities.loader as loading_utils
@@ -51,7 +53,8 @@ class Gemma2Block(attention.TransformerBlock):
       rope: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
       mask: Optional[torch.Tensor] = None,
       input_pos: Optional[torch.Tensor] = None,
-
+      kv_cache: kv_utils.KVCacheEntry = None,
+  ) -> Tuple[torch.Tensor, Optional[kv_utils.KVCacheEntry]]:
     """Forward function of the Gemma2Block.
 
     Exactly the same as TransformerBlock but we call the post-attention norm
@@ -62,17 +65,19 @@ class Gemma2Block(attention.TransformerBlock):
       rope (Tuple[torch.Tensor, torch.Tensor]): the input rope tensor.
       mask (torch.Tensor): the optional mask tensor.
       input_pos (torch.Tensor): the optional input position tensor.
+      kv_cache (KVCacheEntry): the optional kv cache entry.
 
     Returns:
-      output activation from this transformer block
+      output activation from this transformer block, and updated kv cache (if
+      passed in).
     """
 
     x_norm = self.pre_atten_norm(x)
-    attn_out = self.atten_func(x_norm, rope, mask, input_pos)
+    attn_out, kv = self.atten_func(x_norm, rope, mask, input_pos, kv_cache)
     attn_out_norm = self.post_atten_norm(attn_out)
     x = x + attn_out_norm
     output = x + self.ff(x)
-    return output
+    return output, kv
 
 
 class Gemma2(nn.Module):
@@ -81,7 +86,6 @@ class Gemma2(nn.Module):
   def __init__(self, config: cfg.ModelConfig):
     super().__init__()
 
-    self.config = config
     # Construct model layers.
     self.tok_embedding = nn.Embedding(
         config.vocab_size, config.embedding_dim, padding_idx=0
@@ -91,20 +95,22 @@ class Gemma2(nn.Module):
         config.vocab_size,
         bias=config.lm_head_use_bias,
     )
-    #
+    # Gemma2 re-uses the embedding as the head projection layer.
    self.lm_head.weight.data = self.tok_embedding.weight.data
     self.transformer_blocks = nn.ModuleList(
-        Gemma2Block(config)
+        Gemma2Block(config.block_config(idx), config)
+        for idx in range(config.num_layers)
     )
     self.final_norm = builder.build_norm(
         config.embedding_dim,
         config.final_norm_config,
     )
+    # Gemma2 has same hyper parameters for each layer except for attention
+    # types. Use the first layer.
+    attn_config = config.block_config(0).attn_config
     self.rope_cache = attn_utils.build_rope_cache(
         size=config.kv_cache_max,
-        dim=int(
-            config.attn_config.rotary_percentage * config.attn_config.head_dim
-        ),
+        dim=int(attn_config.rotary_percentage * attn_config.head_dim),
         base=10_000,
         condense_ratio=1,
         dtype=torch.float32,
@@ -115,47 +121,56 @@ class Gemma2(nn.Module):
         dtype=torch.float32,
         device=torch.device("cpu"),
     )
-
     self.sliding_window_mask_cache = attn_utils.build_sliding_window_mask_cache(
         size=config.kv_cache_max,
-        window_size=
+        window_size=attn_config.sliding_window_size,
         dtype=torch.float32,
         device=torch.device("cpu"),
     )
-
     self.config = config
 
   def get_attention_mask(
-      self,
+      self, attn_type: cfg.AttentionType, input_pos: torch.Tensor
   ) -> torch.Tensor:
-    if
-
-        self.config.attn_config.attn_types[idx]
-        == cfg.AttentionType.LOCAL_SLIDING
-    ):
-      return self.sliding_window_mask_cache.index_select(2, input_pos)
-
+    if attn_type == cfg.AttentionType.LOCAL_SLIDING:
+      return self.sliding_window_mask_cache.index_select(2, input_pos)
     return self.mask_cache.index_select(2, input_pos)
 
   @torch.inference_mode
-  def forward(
-
+  def forward(
+      self,
+      tokens: torch.Tensor,
+      input_pos: torch.Tensor,
+      kv_cache: kv_utils.KVCache,
+  ) -> dict[torch.Tensor, kv_utils.KVCache]:
+    _, seq_len = tokens.size()
     assert self.config.max_seq_len >= seq_len, (
         f"Cannot forward sequence of length {seq_len}, max seq length is only"
         f" {self.config.max_seq_len}"
     )
+    assert len(self.transformer_blocks) == len(kv_cache.caches), (
+        "The number of transformer blocks and the number of KV cache entries"
+        " must be the same."
+    )
 
     cos, sin = self.rope_cache
     cos = cos.index_select(0, input_pos)
     sin = sin.index_select(0, input_pos)
 
     # token embeddings of shape (b, t, n_embd)
-    x = self.tok_embedding(
+    x = self.tok_embedding(tokens)
     x = x * (self.config.embedding_dim**0.5)
 
+    updated_kv_entires = []
     for i, block in enumerate(self.transformer_blocks):
-      mask = self.get_attention_mask(
-
+      mask = self.get_attention_mask(
+          block.config.attn_config.attn_type, input_pos
+      )
+      kv_entry = kv_cache.caches[i] if kv_cache else None
+      x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
+      if kv_entry:
+        updated_kv_entires.append(kv_entry)
+    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
 
     x = self.final_norm(x)
     res = self.lm_head(x)  # (b, t, vocab_size)
@@ -163,7 +178,8 @@ class Gemma2(nn.Module):
       res = res / self.config.final_logit_softcap
       res = torch.tanh(res)
       res = res * self.config.final_logit_softcap
-
+
+    return {"logits": res, "kv_cache": updated_kv_cache}
 
 
 def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
@@ -176,18 +192,6 @@ def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
   Returns:
     The model config for a Gemma 2B model.
   """
-  attn_config = cfg.AttentionConfig(
-      num_heads=8,
-      head_dim=256,
-      num_query_groups=4,
-      rotary_percentage=1.0,
-      qkv_transpose_before_split=True,
-      logit_softcap=50.0,
-      sliding_window_size=4096,
-      attn_types=[cfg.AttentionType.GLOBAL, cfg.AttentionType.LOCAL_SLIDING]
-      * 13,
-  )
-
   norm_config = cfg.NormalizationConfig(
       type=cfg.NormalizationType.RMS_NORM,
       epsilon=1e-6,
@@ -200,18 +204,38 @@ def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       pre_ff_norm_config=norm_config,
       post_ff_norm_config=norm_config,
   )
+
+  def get_block_config(idx: int) -> cfg.TransformerBlockConfig:
+    attn_config = cfg.AttentionConfig(
+        num_heads=8,
+        head_dim=256,
+        num_query_groups=4,
+        rotary_percentage=1.0,
+        qkv_transpose_before_split=True,
+        logit_softcap=50.0,
+        sliding_window_size=4096,
+        attn_type=(
+            cfg.AttentionType.GLOBAL
+            if idx % 2 == 0
+            else cfg.AttentionType.LOCAL_SLIDING
+        ),
+    )
+    return cfg.TransformerBlockConfig(
+        attn_config=attn_config,
+        ff_config=ff_config,
+        pre_attention_norm_config=norm_config,
+        post_attention_norm_config=norm_config,
+    )
+
+  num_layers = 26
   config = cfg.ModelConfig(
       vocab_size=256000,
-      num_layers=
+      num_layers=num_layers,
       max_seq_len=8192,
       embedding_dim=2304,
       kv_cache_max_len=kv_cache_max_len,
-
-      ff_config=ff_config,
-      pre_attention_norm_config=norm_config,
-      post_attention_norm_config=norm_config,
+      block_configs=[get_block_config(i) for i in range(num_layers)],
       final_norm_config=norm_config,
-      parallel_residual=False,
       lm_head_use_bias=False,
       enable_hlfb=True,
       final_logit_softcap=30.0,
@@ -221,14 +245,16 @@ def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
 
 def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
   config = get_model_config_2b(kv_cache_max_len)
-  config.attn_config.num_heads = 4
-  config.attn_config.head_dim = 64
-  config.attn_config.sliding_window_size = 64
-  config.ff_config.intermediate_size = 128
   config.vocab_size = 128
   config.num_layers = 2
   config.max_seq_len = 2 * kv_cache_max_len
   config.embedding_dim = 128
+  config.block_configs = config.block_configs[: config.num_layers]
+  for block_config in config.block_configs:
+    block_config.attn_config.num_heads = 4
+    block_config.attn_config.head_dim = 64
+    block_config.attn_config.sliding_window_size = 64
+    block_config.ff_config.intermediate_size = 128
   return config
 
 
@@ -236,33 +262,34 @@ def build_2b_model(checkpoint_path: str, **kwargs) -> nn.Module:
   config = get_model_config_2b(**kwargs)
   model = Gemma2(config)
   loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
-  #
+  # Since embedding and lm-head use the same weight, we need to set strict
   # to False.
   loader.load(model, strict=False)
   model.eval()
   return model
 
 
-def define_and_run_2b() -> None:
+def define_and_run_2b(checkpoint_path: str) -> None:
   """Instantiates and runs a Gemma2 2B model."""
 
-  current_dir = Path(__file__).parent.resolve()
+  current_dir = pathlib.Path(__file__).parent.resolve()
   gemma2_goldens = torch.load(current_dir / "gemma2it_2b_golden.pt")
   print("Running GEMMA 2")
   kv_cache_max_len = 1024
-  checkpoint_path = os.path.join(Path.home(), "Downloads/llm_data/gemma2-2b")
   model = build_2b_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
   toks = torch.from_numpy(
      np.array([2, 651, 9456, 576, 573, 3520, 3858, 603, 235248])
   )
-  tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.
+  tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.int, device="cpu")
   tokens[0, :9] = toks
-  input_pos = torch.arange(0, kv_cache_max_len)
-
-
+  input_pos = torch.arange(0, kv_cache_max_len, dtype=torch.int)
+  kv = kv_utils.KVCache.from_model_config(model.config)
+  out = model.forward(tokens, input_pos, kv)
+  out_final = out["logits"][0, 8, :]
   assert torch.allclose(gemma2_goldens, out_final, atol=1e-04)
 
 
 if __name__ == "__main__":
   torch.set_printoptions(sci_mode=True)
-
+  path = os.path.join(pathlib.Path.home(), "Downloads/llm_data/gemma2-2b")
+  define_and_run_2b(path)
```
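The gemma2.py changes above move the model to an externalized KV cache: `forward` now takes a `kv_utils.KVCache` and returns a dict with the logits and the updated cache. A minimal sketch of the new calling convention, assuming a local gemma2-2b checkpoint; the path and the single greedy decode step are illustrative, not part of the package:

```python
import os
import pathlib

import torch

from ai_edge_torch.generative.examples.gemma import gemma2
from ai_edge_torch.generative.layers import kv_cache as kv_utils

# Assumed checkpoint location, mirroring the example's __main__ block.
checkpoint = os.path.join(pathlib.Path.home(), "Downloads/llm_data/gemma2-2b")
model = gemma2.build_2b_model(checkpoint, kv_cache_max_len=1024)

# Prefill: the forward pass now returns both logits and the updated cache.
tokens = torch.tensor([[2, 651, 9456]], dtype=torch.int)
input_pos = torch.arange(0, tokens.shape[1], dtype=torch.int)
kv = kv_utils.KVCache.from_model_config(model.config)
out = model.forward(tokens, input_pos, kv)
logits, kv = out["logits"], out["kv_cache"]

# Decode one step, feeding the updated cache back in (greedy pick is illustrative).
next_token = logits[0, -1, :].argmax().reshape(1, 1).to(torch.int)
next_pos = torch.tensor([tokens.shape[1]], dtype=torch.int)
out = model.forward(next_token, next_pos, kv)
```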
ai_edge_torch/generative/examples/{experimental/gemma → openelm}/convert_to_tflite.py

```diff
@@ -12,30 +12,27 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-#
-# Note: This is an experimental version of Gemma with external KV cache.
-# Please use with caution.
 
+"""Example of converting OpenELM model to multi-signature tflite model."""
 
 import os
-
+import pathlib
 
 import ai_edge_torch
-from ai_edge_torch.generative.examples.
-from ai_edge_torch.generative.layers
+from ai_edge_torch.generative.examples.openelm import openelm
+from ai_edge_torch.generative.layers import kv_cache as kv_utils
 from ai_edge_torch.generative.quantize import quant_recipes
 import torch
 
 
-def
+def convert_openelm_to_tflite(
     checkpoint_path: str,
     prefill_seq_len: int = 512,
     kv_cache_max_len: int = 1024,
     quantize: bool = True,
 ):
-  """
+  """Converts OpenELM model to multi-signature tflite model.
 
-  tflite model.
   Args:
     checkpoint_path (str): The filepath to the model checkpoint, or directory
       holding the checkpoint.
@@ -46,15 +43,15 @@ def convert_gemma_to_tflite(
     quantize (bool, optional): Whether the model should be quanized. Defaults
       to True.
   """
-  pytorch_model =
+  pytorch_model = openelm.build_model(
       checkpoint_path, kv_cache_max_len=kv_cache_max_len
   )
   # Tensors used to trace the model graph during conversion.
-  prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.
-  prefill_input_pos = torch.arange(0, prefill_seq_len)
-  decode_token = torch.tensor([[0]], dtype=torch.
-  decode_input_pos = torch.tensor([0], dtype=torch.
-  kv = kv_utils.
+  prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.int)
+  prefill_input_pos = torch.arange(0, prefill_seq_len, dtype=torch.int)
+  decode_token = torch.tensor([[0]], dtype=torch.int)
+  decode_input_pos = torch.tensor([0], dtype=torch.int)
+  kv = kv_utils.KVCache.from_model_config(pytorch_model.config)
 
   quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
   edge_model = (
@@ -78,11 +75,12 @@ def convert_gemma_to_tflite(
     )
     .convert(quant_config=quant_config)
   )
+  quant_suffix = 'q8' if quantize else 'f32'
   edge_model.export(
-      f'/tmp/
+      f'/tmp/openelm_{quant_suffix}_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
   )
 
 
 if __name__ == '__main__':
-
-
+  path = os.path.join(pathlib.Path.home(), 'Downloads/llm_data/openelm')
+  convert_openelm_to_tflite(path)
```
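For reference, a small usage sketch for the converter above; the checkpoint location mirrors the `__main__` block and the argument values are illustrative:

```python
import os
import pathlib

from ai_edge_torch.generative.examples.openelm import convert_to_tflite

checkpoint = os.path.join(pathlib.Path.home(), "Downloads/llm_data/openelm")
# With these arguments the script writes
# /tmp/openelm_q8_seq1024_ekv1280.tflite (quantize=False yields an f32 suffix).
convert_to_tflite.convert_openelm_to_tflite(
    checkpoint,
    prefill_seq_len=1024,
    kv_cache_max_len=1280,
    quantize=True,
)
```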
ai_edge_torch/generative/examples/openelm/openelm.py (new file)

```diff
@@ -0,0 +1,237 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Example of building an OpenELM model."""
+
+import os
+import pathlib
+
+from ai_edge_torch.generative.layers import attention
+from ai_edge_torch.generative.layers import builder
+from ai_edge_torch.generative.layers import kv_cache as kv_utils
+import ai_edge_torch.generative.layers.attention_utils as attn_utils
+import ai_edge_torch.generative.layers.model_config as cfg
+import ai_edge_torch.generative.utilities.loader as loading_utils
+import numpy as np
+import torch
+from torch import nn
+
+TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
+    ff_up_proj="transformer.layers.{}.ffn.proj_1",
+    ff_down_proj="transformer.layers.{}.ffn.proj_2",
+    attn_fused_qkv_proj="transformer.layers.{}.attn.qkv_proj",
+    attn_query_norm="transformer.layers.{}.attn.q_norm",
+    attn_key_norm="transformer.layers.{}.attn.k_norm",
+    attn_output_proj="transformer.layers.{}.attn.out_proj",
+    pre_attn_norm="transformer.layers.{}.attn_norm",
+    pre_ff_norm="transformer.layers.{}.ffn_norm",
+    embedding="transformer.token_embeddings",
+    final_norm="transformer.norm",
+    lm_head=None,
+)
+
+
+class OpenELM(nn.Module):
+  """An OpenELM model built from the Edge Generative API layers."""
+
+  def __init__(self, config: cfg.ModelConfig):
+    super().__init__()
+
+    # Construct model layers.
+    self.tok_embedding = nn.Embedding(
+        config.vocab_size, config.embedding_dim, padding_idx=0
+    )
+    self.lm_head = nn.Linear(
+        config.embedding_dim, config.vocab_size, bias=config.lm_head_use_bias
+    )
+    # OpenELM re-uses the embedding as the head projection layer.
+    self.lm_head.weight.data = self.tok_embedding.weight.data
+    self.transformer_blocks = nn.ModuleList(
+        attention.TransformerBlock(config.block_config(idx), config)
+        for idx in range(config.num_layers)
+    )
+    self.final_norm = builder.build_norm(
+        config.embedding_dim,
+        config.final_norm_config,
+    )
+    # OpenELM has same hyper parameters for rotary_percentage and head_dim for
+    # each layer block. Use the first block.
+    attn_config = config.block_config(0).attn_config
+    self.rope_cache = attn_utils.build_rope_cache(
+        size=config.kv_cache_max,
+        dim=int(attn_config.rotary_percentage * attn_config.head_dim),
+        base=10_000,
+        condense_ratio=1,
+        dtype=torch.float32,
+        device=torch.device("cpu"),
+    )
+    self.mask_cache = attn_utils.build_causal_mask_cache(
+        size=config.kv_cache_max,
+        dtype=torch.float32,
+        device=torch.device("cpu"),
+    )
+    self.config = config
+
+  @torch.inference_mode
+  def forward(
+      self,
+      tokens: torch.Tensor,
+      input_pos: torch.Tensor,
+      kv_cache: kv_utils.KVCache,
+  ) -> dict[torch.Tensor, kv_utils.KVCache]:
+    _, seq_len = tokens.size()
+    assert self.config.max_seq_len >= seq_len, (
+        f"Cannot forward sequence of length {seq_len}, max seq length is only"
+        f" {self.config.max_seq_len}"
+    )
+    assert len(self.transformer_blocks) == len(kv_cache.caches), (
+        "The number of transformer blocks and the number of KV cache entries"
+        " must be the same."
+    )
+
+    cos, sin = self.rope_cache
+    cos = cos.index_select(0, input_pos)
+    sin = sin.index_select(0, input_pos)
+    mask = self.mask_cache.index_select(2, input_pos)
+    mask = mask[:, :, :, : self.config.kv_cache_max]
+
+    # token embeddings of shape (b, t, n_embd)
+    x = self.tok_embedding(tokens)
+
+    updated_kv_entires = []
+    for i, block in enumerate(self.transformer_blocks):
+      kv_entry = kv_cache.caches[i] if kv_cache else None
+      x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
+      if kv_entry:
+        updated_kv_entires.append(kv_entry)
+    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
+
+    x = self.final_norm(x)
+    logits = self.lm_head(x)  # (b, t, vocab_size)
+    return {"logits": logits, "kv_cache": updated_kv_cache}
+
+
+def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
+  """Returns the model config for an OpenELM model.
+
+  Args:
+    kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
+      is 1024.
+
+  Returns:
+    The model config for an OpenELM model.
+  """
+  norm_config = cfg.NormalizationConfig(
+      type=cfg.NormalizationType.RMS_NORM, epsilon=1e-6
+  )
+  num_heads = [12] * 4 + [16] * 14 + [20] * 12 + [24] * 6
+  num_query_groups = [3] * 4 + [4] * 14 + [5] * 12 + [6] * 6
+
+  def make_divisible(v, d):
+    """Ensures that all layers have a channel number that is divisible by d."""
+    new_v = int(v + d / 2) // d * d
+    # Make sure that round down does not go down by more than 10%.
+    if new_v < 0.9 * v:
+      new_v += d
+    return new_v
+
+  # The way to get intermediate size is from
+  # https://huggingface.co/apple/OpenELM-3B/blob/main/modeling_openelm.py
+  def get_intermediate_size(idx: int) -> int:
+    return make_divisible((0.5 + 0.1 * idx) * 3072, 256)
+
+  def get_block_config(idx: int) -> cfg.TransformerBlockConfig:
+    return cfg.TransformerBlockConfig(
+        attn_config=cfg.AttentionConfig(
+            num_heads=num_heads[idx],
+            head_dim=128,
+            num_query_groups=num_query_groups[idx],
+            rotary_percentage=1.0,
+            qkv_transpose_before_split=True,
+            query_norm_config=norm_config,
+            key_norm_config=norm_config,
+        ),
+        ff_config=cfg.FeedForwardConfig(
+            type=cfg.FeedForwardType.SEQUENTIAL,
+            activation=cfg.ActivationConfig(
+                cfg.ActivationType.SILU_GLU, gate_is_front=True
+            ),
+            intermediate_size=get_intermediate_size(idx),
+            pre_ff_norm_config=norm_config,
+        ),
+        pre_attention_norm_config=norm_config,
+    )
+
+  num_layers = 36
+  config = cfg.ModelConfig(
+      vocab_size=32000,
+      num_layers=num_layers,
+      max_seq_len=2048,
+      embedding_dim=3072,
+      kv_cache_max_len=kv_cache_max_len,
+      block_configs=[get_block_config(i) for i in range(num_layers)],
+      final_norm_config=norm_config,
+  )
+  return config
+
+
+def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
+  config = get_model_config(kv_cache_max_len)
+  config.vocab_size = 128
+  config.num_layers = 2
+  config.max_seq_len = 2 * kv_cache_max_len
+  config.embedding_dim = 128
+  config.block_configs = config.block_configs[: config.num_layers]
+  for block_config in config.block_configs:
+    block_config.attn_config.num_heads = 3
+    block_config.attn_config.head_dim = 64
+    block_config.ff_config.intermediate_size = 128
+  return config
+
+
+def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
+  config = get_model_config(**kwargs)
+  model = OpenELM(config)
+  loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
+  # Since embedding and lm-head use the same weight, we need to set strict
+  # to False.
+  loader.load(model, strict=False)
+  model.eval()
+  return model
+
+
+def define_and_run(checkpoint_path: str) -> None:
+  """Instantiates and runs an OpenELM model."""
+
+  current_dir = pathlib.Path(__file__).parent.resolve()
+  openelm_goldens = torch.load(current_dir / "openelm_lm_logits.pt")
+  kv_cache_max_len = 1024
+  model = build_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
+  idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
+  tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.int, device="cpu")
+  tokens[0, :4] = idx
+  input_pos = torch.arange(0, kv_cache_max_len, dtype=torch.int)
+  kv = kv_utils.KVCache.from_model_config(model.config)
+  output = model.forward(tokens, input_pos, kv)
+  assert torch.allclose(
+      openelm_goldens, output["logits"][0, idx.shape[1] - 1, :], atol=1e-05
+  )
+
+
+if __name__ == "__main__":
+  input_checkpoint_path = os.path.join(
+      pathlib.Path.home(), "Downloads/llm_data/openelm"
+  )
+  define_and_run(input_checkpoint_path)
```
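The new openelm.py gives each of the 36 blocks its own head count and FFN width. A standalone rendering of the intermediate-size schedule used above (the helpers are copied here only for inspection; the printed values follow from the formula in the file):

```python
def make_divisible(v: float, d: int) -> int:
  """Rounds v to a multiple of d without going down by more than 10%."""
  new_v = int(v + d / 2) // d * d
  if new_v < 0.9 * v:
    new_v += d
  return new_v


def get_intermediate_size(idx: int) -> int:
  return make_divisible((0.5 + 0.1 * idx) * 3072, 256)


# First and last blocks of the 36-layer config:
print(get_intermediate_size(0))   # 1536
print(get_intermediate_size(35))  # 12288
```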
ai_edge_torch/generative/examples/{experimental/phi → phi}/convert_to_tflite.py

```diff
@@ -12,16 +12,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-
-
-# Please use with caution.
+
+"""Example of converting a Phi-2 model to multi-signature tflite model."""
 
 import os
-
+import pathlib
 
 import ai_edge_torch
-from ai_edge_torch.generative.examples.
-from ai_edge_torch.generative.layers
+from ai_edge_torch.generative.examples.phi import phi2
+from ai_edge_torch.generative.layers import kv_cache
 from ai_edge_torch.generative.quantize import quant_recipes
 import torch
 
@@ -32,9 +31,8 @@ def convert_phi2_to_tflite(
     kv_cache_max_len: int = 1024,
     quantize: bool = True,
 ):
-  """
+  """Converts a Phi-2 model to multi-signature tflite model.
 
-  tflite model.
   Args:
     checkpoint_path (str): The filepath to the model checkpoint, or directory
       holding the checkpoint.
@@ -49,11 +47,11 @@ def convert_phi2_to_tflite(
       checkpoint_path, kv_cache_max_len=kv_cache_max_len
   )
   # Tensors used to trace the model graph during conversion.
-  prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.
-  prefill_input_pos = torch.arange(0, prefill_seq_len)
-  decode_token = torch.tensor([[0]], dtype=torch.
-  decode_input_pos = torch.tensor([0], dtype=torch.
-  kv =
+  prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.int)
+  prefill_input_pos = torch.arange(0, prefill_seq_len, dtype=torch.int)
+  decode_token = torch.tensor([[0]], dtype=torch.int)
+  decode_input_pos = torch.tensor([0], dtype=torch.int)
+  kv = kv_cache.KVCache.from_model_config(pytorch_model.config)
 
   quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
   edge_model = (
@@ -77,11 +75,12 @@ def convert_phi2_to_tflite(
     )
     .convert(quant_config=quant_config)
   )
+  quant_suffix = 'q8' if quantize else 'f32'
   edge_model.export(
-      f'/tmp/
+      f'/tmp/phi2_{quant_suffix}_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
   )
 
 
 if __name__ == '__main__':
-
-  convert_phi2_to_tflite(
+  path = os.path.join(pathlib.Path.home(), 'Downloads/llm_data/phi2')
+  convert_phi2_to_tflite(path)
```
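And the equivalent usage sketch for the Phi-2 converter; the checkpoint path mirrors the `__main__` block and the explicit arguments are illustrative:

```python
import os
import pathlib

from ai_edge_torch.generative.examples.phi import convert_to_tflite

checkpoint = os.path.join(pathlib.Path.home(), "Downloads/llm_data/phi2")
# Writes /tmp/phi2_q8_seq512_ekv1024.tflite.
convert_to_tflite.convert_phi2_to_tflite(
    checkpoint, prefill_seq_len=512, kv_cache_max_len=1024, quantize=True
)
```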