ai-edge-torch-nightly 0.3.0.dev20240910__py3-none-any.whl → 0.3.0.dev20240914__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_torch/_convert/conversion.py +2 -1
- ai_edge_torch/_convert/fx_passes/__init__.py +5 -41
- ai_edge_torch/_convert/fx_passes/build_aten_composite_pass.py +3 -4
- ai_edge_torch/_convert/fx_passes/build_interpolate_composite_pass.py +3 -4
- ai_edge_torch/_convert/fx_passes/inject_mlir_debuginfo_pass.py +3 -4
- ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/pass_body.py +4 -5
- ai_edge_torch/config.py +4 -1
- ai_edge_torch/fx_pass_base.py +101 -0
- ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +35 -16
- ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +29 -10
- ai_edge_torch/generative/examples/gemma/gemma.py +52 -32
- ai_edge_torch/generative/examples/gemma/gemma2.py +87 -60
- ai_edge_torch/generative/examples/{experimental/gemma → openelm}/convert_to_tflite.py +16 -18
- ai_edge_torch/generative/examples/openelm/openelm.py +237 -0
- ai_edge_torch/generative/examples/{experimental/phi → phi}/convert_to_tflite.py +15 -16
- ai_edge_torch/generative/examples/{experimental/phi → phi}/phi2.py +48 -45
- ai_edge_torch/generative/examples/{experimental/tiny_llama → smollm}/convert_to_tflite.py +16 -17
- ai_edge_torch/generative/examples/smollm/smollm.py +131 -0
- ai_edge_torch/generative/examples/stable_diffusion/clip.py +12 -6
- ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +1 -1
- ai_edge_torch/generative/examples/t5/convert_to_tflite.py +20 -20
- ai_edge_torch/generative/examples/t5/t5.py +43 -30
- ai_edge_torch/generative/examples/t5/t5_attention.py +18 -13
- ai_edge_torch/generative/examples/test_models/toy_model.py +15 -13
- ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +75 -34
- ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +29 -10
- ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +57 -36
- ai_edge_torch/generative/fx_passes/__init__.py +4 -4
- ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py +3 -4
- ai_edge_torch/generative/layers/attention.py +84 -73
- ai_edge_torch/generative/layers/builder.py +38 -14
- ai_edge_torch/generative/layers/feed_forward.py +26 -8
- ai_edge_torch/generative/layers/kv_cache.py +163 -51
- ai_edge_torch/generative/layers/model_config.py +61 -33
- ai_edge_torch/generative/layers/normalization.py +158 -0
- ai_edge_torch/generative/layers/unet/blocks_2d.py +0 -2
- ai_edge_torch/generative/quantize/example.py +2 -2
- ai_edge_torch/generative/test/{test_experimental_ekv.py → test_kv_cache.py} +12 -24
- ai_edge_torch/generative/test/test_loader.py +1 -1
- ai_edge_torch/generative/test/test_model_conversion.py +77 -62
- ai_edge_torch/generative/test/test_model_conversion_large.py +61 -68
- ai_edge_torch/generative/test/test_quantize.py +5 -5
- ai_edge_torch/generative/test/utils.py +54 -0
- ai_edge_torch/generative/utilities/loader.py +28 -15
- ai_edge_torch/generative/utilities/t5_loader.py +21 -20
- ai_edge_torch/odml_torch/export.py +40 -0
- ai_edge_torch/odml_torch/lowerings/__init__.py +1 -0
- ai_edge_torch/odml_torch/lowerings/_basic.py +44 -0
- ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +0 -2
- ai_edge_torch/odml_torch/lowerings/_layer_norm.py +78 -0
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240914.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240914.dist-info}/RECORD +59 -63
- ai_edge_torch/_convert/fx_passes/_pass_base.py +0 -53
- ai_edge_torch/_convert/fx_passes/canonicalize_pass.py +0 -35
- ai_edge_torch/generative/examples/experimental/gemma/gemma.py +0 -219
- ai_edge_torch/generative/examples/experimental/tiny_llama/__init__.py +0 -14
- ai_edge_torch/generative/examples/experimental/tiny_llama/tiny_llama.py +0 -205
- ai_edge_torch/generative/examples/phi2/__init__.py +0 -14
- ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +0 -67
- ai_edge_torch/generative/examples/phi2/phi2.py +0 -189
- ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +0 -176
- /ai_edge_torch/generative/examples/{experimental → openelm}/__init__.py +0 -0
- /ai_edge_torch/generative/examples/{experimental/gemma → phi}/__init__.py +0 -0
- /ai_edge_torch/generative/examples/{experimental/phi → smollm}/__init__.py +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240914.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240914.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240914.dist-info}/top_level.txt +0 -0
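The hunks below (from the updated Phi-2, SmolLM, CLIP, Stable Diffusion and T5 examples) center on one API change: the generative examples move off the experimental `ekv_cache` onto `ai_edge_torch.generative.layers.kv_cache.KVCache`, and model `forward` now returns a dict with `"logits"` and `"kv_cache"`. The sketch below mirrors that usage as seen in the example code in this diff; the checkpoint path is an assumption, not a documented requirement.

```python
# Minimal sketch of the KV-cache usage pattern shown in the updated examples.
# Assumptions: this nightly is installed and a Phi-2 checkpoint exists at the
# (hypothetical) path below.
import pathlib

import torch
from ai_edge_torch.generative.examples.phi import phi2
from ai_edge_torch.generative.layers import kv_cache as kv_utils

checkpoint = pathlib.Path.home() / "Downloads/llm_data/phi2"
model = phi2.build_model(str(checkpoint), kv_cache_max_len=1024)

tokens = torch.full((1, 1024), 0, dtype=torch.int)
tokens[0, :4] = torch.tensor([1, 2, 3, 4], dtype=torch.int)
input_pos = torch.arange(0, 1024, dtype=torch.int)

# The KV cache is now an explicit, external object built from the model config.
kv = kv_utils.KVCache.from_model_config(model.config)
output = model.forward(tokens, input_pos, kv)
logits, kv = output["logits"], output["kv_cache"]
```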
ai_edge_torch/generative/examples/{experimental/phi → phi}/phi2.py

@@ -12,26 +12,22 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-
-
-# Note: This is an experimental version of phi2 with external KV cache.
-# Please use with caution.
+
+"""Example of building a Phi-2 model."""
 
 import os
-
-from typing import Tuple
+import pathlib
 
+from ai_edge_torch.generative.layers import attention
 from ai_edge_torch.generative.layers import builder
+from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
-from ai_edge_torch.generative.layers.experimental import attention
-from ai_edge_torch.generative.layers.experimental import ekv_cache as kv_utils
 import ai_edge_torch.generative.layers.model_config as cfg
 import ai_edge_torch.generative.utilities.loader as loading_utils
 import numpy as np
 import torch
 from torch import nn
 
-
 TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
     ff_up_proj="model.layers.{}.mlp.fc1",
     ff_down_proj="model.layers.{}.mlp.fc2",
@@ -52,7 +48,6 @@ class Phi2(nn.Module):
   def __init__(self, config: cfg.ModelConfig):
     super().__init__()
 
-    self.config = config
     # Construct model layers.
     self.lm_head = nn.Linear(
         config.embedding_dim, config.vocab_size, bias=config.lm_head_use_bias
@@ -60,18 +55,20 @@ class Phi2(nn.Module):
     self.tok_embedding = nn.Embedding(
         config.vocab_size, config.embedding_dim, padding_idx=0
     )
+    # Phi-2 has only one block config.
+    block_config = config.block_config(0)
     self.transformer_blocks = nn.ModuleList(
-        attention.TransformerBlock(
+        attention.TransformerBlock(block_config, config)
+        for _ in range(config.num_layers)
     )
     self.final_norm = builder.build_norm(
         config.embedding_dim,
         config.final_norm_config,
     )
+    attn_config = block_config.attn_config
     self.rope_cache = attn_utils.build_rope_cache(
         size=config.kv_cache_max,
-        dim=int(
-            config.attn_config.rotary_percentage * config.attn_config.head_dim
-        ),
+        dim=int(attn_config.rotary_percentage * attn_config.head_dim),
         base=10_000,
         condense_ratio=1,
         dtype=torch.float32,
@@ -89,13 +86,17 @@ class Phi2(nn.Module):
       self,
       tokens: torch.Tensor,
       input_pos: torch.Tensor,
-      kv_cache: kv_utils.
-  ) ->
+      kv_cache: kv_utils.KVCache,
+  ) -> dict[torch.Tensor, kv_utils.KVCache]:
     _, seq_len = tokens.size()
     assert self.config.max_seq_len >= seq_len, (
         f"Cannot forward sequence of length {seq_len}, max seq length is only"
         f" {self.config.max_seq_len}"
     )
+    assert len(self.transformer_blocks) == len(kv_cache.caches), (
+        "The number of transformer blocks and the number of KV cache entries"
+        " must be the same."
+    )
 
     cos, sin = self.rope_cache
     cos = cos.index_select(0, input_pos)
@@ -111,11 +112,11 @@ class Phi2(nn.Module):
       x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
       if kv_entry:
         updated_kv_entires.append(kv_entry)
-    updated_kv_cache = kv_utils.
+    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
 
     x = self.final_norm(x)
-
-    return
+    logits = self.lm_head(x)  # (b, t, vocab_size)
+    return {"logits": logits, "kv_cache": updated_kv_cache}
 
 
 def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
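Under the new contract shown above, callers thread the cache through prefill and decode themselves. The following is a hedged sketch of such a loop, written against the `forward()` shape in this diff (dict output with `"logits"` and `"kv_cache"`); it is not code from the package.

```python
# Sketch: greedy decoding that threads the externalized KV cache, assuming
# forward(tokens, input_pos, kv) -> {"logits": ..., "kv_cache": ...}.
import torch

def greedy_decode(model, prompt_tokens, kv, steps: int = 16):
  # Prefill: run the whole prompt once and keep the returned cache.
  input_pos = torch.arange(0, prompt_tokens.shape[1], dtype=torch.int)
  out = model.forward(prompt_tokens, input_pos, kv)
  kv = out["kv_cache"]
  next_token = out["logits"][:, -1, :].argmax(dim=-1, keepdim=True).int()

  generated = [next_token]
  pos = prompt_tokens.shape[1]
  for _ in range(steps):
    # Decode: one token at a time, replacing the cache object each step.
    out = model.forward(next_token, torch.tensor([pos], dtype=torch.int), kv)
    kv = out["kv_cache"]
    next_token = out["logits"][:, -1, :].argmax(dim=-1, keepdim=True).int()
    generated.append(next_token)
    pos += 1
  return torch.cat(generated, dim=1)
```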
@@ -143,17 +144,20 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       use_bias=True,
   )
   norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.LAYER_NORM)
+  block_config = cfg.TransformerBlockConfig(
+      attn_config=attn_config,
+      ff_config=ff_config,
+      pre_attention_norm_config=norm_config,
+      parallel_residual=True,
+  )
   config = cfg.ModelConfig(
       vocab_size=51200,
       num_layers=32,
       max_seq_len=2048,
       kv_cache_max_len=kv_cache_max_len,
       embedding_dim=2560,
-
-      ff_config=ff_config,
-      pre_attention_norm_config=norm_config,
+      block_configs=block_config,
       final_norm_config=norm_config,
-      parallel_residual=True,
       lm_head_use_bias=True,
       enable_hlfb=True,
   )
@@ -165,43 +169,42 @@ def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
   config.vocab_size = 128
   config.num_layers = 2
   config.max_seq_len = 2 * kv_cache_max_len
-  config.
+  # Phi-2 has only one block config.
+  config.block_config(0).ff_config.intermediate_size = 128
   return config
 
 
-def build_model(
-    checkpoint_path: str, test_model: bool = False, **kwargs
-) -> nn.Module:
+def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
   """Instantiates the model instance and load checkpoint if provided."""
-  config = (
-      get_fake_model_config(**kwargs)
-      if test_model
-      else get_model_config(**kwargs)
-  )
+  config = get_model_config(**kwargs)
   model = Phi2(config)
-
-
-    loader.load(model)
+  loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
+  loader.load(model)
   model.eval()
   return model
 
 
-def define_and_run(checkpoint_path: str
+def define_and_run(checkpoint_path: str) -> None:
   """Instantiates and runs a Phi-2 model."""
 
+  current_dir = pathlib.Path(__file__).parent.resolve()
+  phi2_goldens = torch.load(current_dir / "phi2_lm_logits.pt")
   kv_cache_max_len = 1024
-  model = build_model(
-      checkpoint_path, test_model=test_model, kv_cache_max_len=kv_cache_max_len
-  )
+  model = build_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
   idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
-  tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.
+  tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.int, device="cpu")
   tokens[0, :4] = idx
-  input_pos = torch.arange(0, kv_cache_max_len)
-  kv = kv_utils.
-
-  print(
+  input_pos = torch.arange(0, kv_cache_max_len, dtype=torch.int)
+  kv = kv_utils.KVCache.from_model_config(model.config)
+  output = model.forward(tokens, input_pos, kv)
+  print("comparing with goldens..")
+  assert torch.allclose(
+      phi2_goldens, output["logits"][0, idx.shape[1] - 1, :], atol=1e-02
+  )
 
 
 if __name__ == "__main__":
-  input_checkpoint_path = os.path.join(
+  input_checkpoint_path = os.path.join(
+      pathlib.Path.home(), "Downloads/llm_data/phi2"
+  )
   define_and_run(input_checkpoint_path)
ai_edge_torch/generative/examples/{experimental/tiny_llama → smollm}/convert_to_tflite.py

@@ -12,28 +12,26 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-#
-# Note: This is an experimental version of TinyLlama with external KV cache.
-# Please use with caution.
 
+"""Example of converting SmolLM model to multi-signature tflite model."""
 
 import os
-
+import pathlib
 
 import ai_edge_torch
-from ai_edge_torch.generative.examples.
-from ai_edge_torch.generative.layers
+from ai_edge_torch.generative.examples.smollm import smollm
+from ai_edge_torch.generative.layers import kv_cache as kv_utils
 from ai_edge_torch.generative.quantize import quant_recipes
 import torch
 
 
-def
+def convert_smollm_to_tflite(
     checkpoint_path: str,
     prefill_seq_len: int = 512,
     kv_cache_max_len: int = 1024,
     quantize: bool = True,
 ):
-  """
+  """Converts SmolLM model to multi-signature tflite model.
 
   Args:
     checkpoint_path (str): The filepath to the model checkpoint, or directory
@@ -45,15 +43,15 @@ def convert_tiny_llama_to_tflite(
     quantize (bool, optional): Whether the model should be quanized. Defaults
       to True.
   """
-  pytorch_model =
+  pytorch_model = smollm.build_model(
       checkpoint_path, kv_cache_max_len=kv_cache_max_len
   )
   # Tensors used to trace the model graph during conversion.
-  prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.
-  prefill_input_pos = torch.arange(0, prefill_seq_len)
-  decode_token = torch.tensor([[0]], dtype=torch.
-  decode_input_pos = torch.tensor([0], dtype=torch.
-  kv = kv_utils.
+  prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.int)
+  prefill_input_pos = torch.arange(0, prefill_seq_len, dtype=torch.int)
+  decode_token = torch.tensor([[0]], dtype=torch.int)
+  decode_input_pos = torch.tensor([0], dtype=torch.int)
+  kv = kv_utils.KVCache.from_model_config(pytorch_model.config)
 
   quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
   edge_model = (
@@ -77,11 +75,12 @@ def convert_tiny_llama_to_tflite(
     )
     .convert(quant_config=quant_config)
   )
+  quant_suffix = 'q8' if quantize else 'f32'
   edge_model.export(
-      f'/tmp/
+      f'/tmp/smollm_{quant_suffix}_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
   )
 
 
 if __name__ == '__main__':
-
-
+  path = os.path.join(pathlib.Path.home(), 'Downloads/llm_data/smollm')
+  convert_smollm_to_tflite(path)
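A hypothetical invocation of the new converter shown above. The checkpoint directory is an assumption (it must hold SmolLM weights matching `smollm.TENSOR_NAMES`); the parameter names come directly from the hunk.

```python
# Sketch: calling the SmolLM converter with explicit, non-default arguments.
import os
import pathlib

from ai_edge_torch.generative.examples.smollm import convert_to_tflite

checkpoint = os.path.join(pathlib.Path.home(), "Downloads/llm_data/smollm")
convert_to_tflite.convert_smollm_to_tflite(
    checkpoint,
    prefill_seq_len=1024,   # length of the traced prefill signature
    kv_cache_max_len=2048,  # size of the externalized KV cache
    quantize=False,         # writes .../smollm_f32_seq1024_ekv2048.tflite
)
```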
ai_edge_torch/generative/examples/smollm/smollm.py

@@ -0,0 +1,131 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Example of building a SmolLM model."""
+
+import copy
+import os
+import pathlib
+
+from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
+from ai_edge_torch.generative.layers import kv_cache as kv_utils
+import ai_edge_torch.generative.layers.model_config as cfg
+import ai_edge_torch.generative.utilities.loader as loading_utils
+import numpy as np
+import torch
+from torch import nn
+
+TENSOR_NAMES = copy.copy(tiny_llama.TENSOR_NAMES)
+# SmolLM re-uses the embedding as the head projection layer.
+TENSOR_NAMES.lm_head = None
+
+
+class SmolLM(tiny_llama.TinyLlama):
+  """A SmolLM model built from the Edge Generative API layers.
+
+  SmolLM shares the same architecture as TinyLlama, but with different model
+  sizes.
+  """
+
+  def __init__(self, config: cfg.ModelConfig):
+    super().__init__(config)
+    # SmolLM re-uses the embedding as the head projection layer.
+    self.lm_head.weight.data = self.tok_embedding.weight.data
+
+
+def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
+  """Returns the model config for a SmolLM 135M model.
+
+  Args:
+    kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
+      is 1024.
+
+  Returns:
+    The model config for a SmolLM model.
+  """
+  attn_config = cfg.AttentionConfig(
+      num_heads=9,
+      head_dim=64,
+      num_query_groups=3,
+      rotary_percentage=1.0,
+  )
+  ff_config = cfg.FeedForwardConfig(
+      type=cfg.FeedForwardType.GATED,
+      activation=cfg.ActivationConfig(cfg.ActivationType.SILU),
+      intermediate_size=1536,
+  )
+  norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.RMS_NORM)
+  block_config = cfg.TransformerBlockConfig(
+      attn_config=attn_config,
+      ff_config=ff_config,
+      pre_attention_norm_config=norm_config,
+      post_attention_norm_config=norm_config,
+  )
+  config = cfg.ModelConfig(
+      vocab_size=49152,
+      num_layers=30,
+      max_seq_len=2048,
+      embedding_dim=576,
+      kv_cache_max_len=kv_cache_max_len,
+      block_configs=block_config,
+      final_norm_config=norm_config,
+      enable_hlfb=True,
+  )
+  return config
+
+
+def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
+  config = get_model_config(**kwargs)
+  config.vocab_size = 128
+  config.num_layers = 2
+  # SmolLM has only one block config.
+  config.block_config(0).ff_config.intermediate_size = 64
+  return config
+
+
+def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
+  config = get_model_config(**kwargs)
+  model = SmolLM(config)
+  loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
+  # Since embedding and lm-head use the same weight, we need to set strict
+  # to False.
+  loader.load(model, strict=False)
+  model.eval()
+  return model
+
+
+def define_and_run(checkpoint_path: str) -> None:
+  """Instantiates and runs a SmolLM model."""
+
+  current_dir = pathlib.Path(__file__).parent.resolve()
+  smollm_goldens = torch.load(current_dir / "smollm_lm_logits.pt")
+  kv_cache_max_len = 1024
+  model = build_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
+  idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
+  tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.int, device="cpu")
+  tokens[0, :4] = idx
+  input_pos = torch.arange(0, kv_cache_max_len, dtype=torch.int)
+  kv = kv_utils.KVCache.from_model_config(model.config)
+  output = model.forward(tokens, input_pos, kv)
+  assert torch.allclose(
+      smollm_goldens, output["logits"][0, idx.shape[1] - 1, :], atol=1e-05
+  )
+
+
+if __name__ == "__main__":
+  input_checkpoint_path = os.path.join(
+      pathlib.Path.home(), "Downloads/llm_data/smollm"
+  )
+  define_and_run(input_checkpoint_path)
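SmolLM ties the output head to the token embedding, which is why `TENSOR_NAMES.lm_head` is cleared and the loader runs with `strict=False` above. A plain-PyTorch sketch of that weight-tying idea (shapes taken from the SmolLM config above, used purely as illustration):

```python
# Sketch of weight tying: the output projection re-uses the embedding matrix,
# so the checkpoint stores one copy and the loader must tolerate the "missing"
# lm_head weight.
import torch
from torch import nn

vocab_size, embedding_dim = 49152, 576
tok_embedding = nn.Embedding(vocab_size, embedding_dim)
lm_head = nn.Linear(embedding_dim, vocab_size, bias=False)

# Share the underlying storage, mirroring SmolLM.__init__ above.
lm_head.weight.data = tok_embedding.weight.data

hidden = torch.randn(1, 4, embedding_dim)
logits = lm_head(hidden)  # (1, 4, vocab_size)
assert lm_head.weight.data_ptr() == tok_embedding.weight.data_ptr()
```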
ai_edge_torch/generative/examples/stable_diffusion/clip.py

@@ -61,8 +61,10 @@ class CLIP(nn.Module):
     )
 
     self.config = config
+    # CLIP has only one block config.
+    block_config = config.block_config(0)
     self.transformer_blocks = nn.ModuleList(
-        TransformerBlock(config) for _ in range(config.num_layers)
+        TransformerBlock(block_config, config) for _ in range(config.num_layers)
     )
     self.final_norm = builder.build_norm(
         config.embedding_dim, config.final_norm_config
@@ -74,7 +76,7 @@ class CLIP(nn.Module):
 
   @torch.inference_mode
   def forward(self, tokens: torch.LongTensor) -> torch.FloatTensor:
-    tokens = tokens.type(torch.
+    tokens = tokens.type(torch.int)
 
     state = self.tok_embedding(tokens) + self.tok_embedding_position
     for layer in self.transformer_blocks:
@@ -112,15 +114,19 @@ def get_model_config() -> cfg.ModelConfig:
 
   norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.LAYER_NORM)
 
+  block_config = cfg.TransformerBlockConfig(
+      attn_config=attn_config,
+      ff_config=ff_config,
+      pre_attention_norm_config=norm_config,
+      post_attention_norm_config=norm_config,
+  )
+
   config = cfg.ModelConfig(
       vocab_size=vocab_size,
       num_layers=num_layers,
       max_seq_len=max_seq_len,
       embedding_dim=embedding_dim,
-
-      ff_config=ff_config,
-      pre_attention_norm_config=norm_config,
-      post_attention_norm_config=norm_config,
+      block_configs=block_config,
       final_norm_config=norm_config,
       enable_hlfb=True,
   )
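The recurring pattern across these hunks: `ModelConfig` no longer takes flat `attn_config`/`ff_config`/norm fields; they move into a per-block `TransformerBlockConfig` passed as `block_configs`, and layers read them back through `config.block_config(i)`. A hedged sketch of that access pattern (names taken from the hunks above; the values printed are illustrative only):

```python
# Sketch of the per-block config access pattern introduced in this release.
import ai_edge_torch.generative.layers.model_config as cfg


def describe_blocks(config: cfg.ModelConfig) -> None:
  for i in range(config.num_layers):
    # Models defined with a single shared TransformerBlockConfig return the
    # same block config for every index.
    block = config.block_config(i)
    print(
        f"layer {i}: heads={block.attn_config.num_heads}, "
        f"ff_intermediate={block.ff_config.intermediate_size}"
    )
```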
ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py

@@ -94,7 +94,7 @@ def convert_stable_diffusion_to_tflite(
   n_tokens = 77
   timestamp = 0
   len_prompt = 1
-  prompt_tokens = torch.full((1, n_tokens), 0, dtype=torch.
+  prompt_tokens = torch.full((1, n_tokens), 0, dtype=torch.int)
   input_image = torch.full(
       (1, 3, image_height, image_width), 0, dtype=torch.float32
   )
ai_edge_torch/generative/examples/t5/convert_to_tflite.py

@@ -29,24 +29,24 @@ def convert_t5_to_tflite_singlesig(checkpoint_path: str):
 
   # encoder
   seq_len = 512
-  prefill_e_tokens = torch.full((1, seq_len), 0, dtype=torch.
+  prefill_e_tokens = torch.full((1, seq_len), 0, dtype=torch.int)
   prompt_e_token = [1, 2, 3, 4, 5, 6]
   prefill_e_tokens[0, : len(prompt_e_token)] = torch.tensor(
-      prompt_e_token, dtype=torch.
+      prompt_e_token, dtype=torch.int
   )
-  prefill_e_input_pos = torch.arange(0, seq_len)
-  prefill_d_tokens = torch.full((1, seq_len), 0, dtype=torch.
+  prefill_e_input_pos = torch.arange(0, seq_len, dtype=torch.int)
+  prefill_d_tokens = torch.full((1, seq_len), 0, dtype=torch.int)
   prompt_d_token = [1, 2, 3, 4, 5, 6]
   prefill_d_tokens[0, : len(prompt_d_token)] = torch.tensor(
-      prompt_d_token, dtype=torch.
+      prompt_d_token, dtype=torch.int
   )
-  prefill_d_input_pos = torch.arange(0, seq_len)
+  prefill_d_input_pos = torch.arange(0, seq_len, dtype=torch.int)
 
   # decoder
-  decode_token = torch.tensor([[1]], dtype=torch.
-  decode_input_pos = torch.tensor([0], dtype=torch.
-  decode_d_token = torch.tensor([[1]], dtype=torch.
-  decode_d_input_pos = torch.tensor([0], dtype=torch.
+  decode_token = torch.tensor([[1]], dtype=torch.int)
+  decode_input_pos = torch.tensor([0], dtype=torch.int)
+  decode_d_token = torch.tensor([[1]], dtype=torch.int)
+  decode_d_input_pos = torch.tensor([0], dtype=torch.int)
 
   # Pad mask for self attention only on "real" tokens.
   # Pad with `-inf` for any tokens indices that aren't desired.
@@ -81,24 +81,24 @@ def convert_t5_to_tflite_multisig(checkpoint_path: str):
 
   # encoder
   seq_len = 512
-  prefill_e_tokens = torch.full((1, seq_len), 0, dtype=torch.
+  prefill_e_tokens = torch.full((1, seq_len), 0, dtype=torch.int)
   prompt_e_token = [1, 2, 3, 4, 5, 6]
   prefill_e_tokens[0, : len(prompt_e_token)] = torch.tensor(
-      prompt_e_token, dtype=torch.
+      prompt_e_token, dtype=torch.int
   )
-  prefill_e_input_pos = torch.arange(0, seq_len)
-  prefill_d_tokens = torch.full((1, seq_len), 0, dtype=torch.
+  prefill_e_input_pos = torch.arange(0, seq_len, dtype=torch.int)
+  prefill_d_tokens = torch.full((1, seq_len), 0, dtype=torch.int)
   prompt_d_token = [1, 2, 3, 4, 5, 6]
   prefill_d_tokens[0, : len(prompt_d_token)] = torch.tensor(
-      prompt_d_token, dtype=torch.
+      prompt_d_token, dtype=torch.int
   )
-  prefill_d_input_pos = torch.arange(0, seq_len)
+  prefill_d_input_pos = torch.arange(0, seq_len, dtype=torch.int)
 
   # decoder
-  decode_token = torch.tensor([[1]], dtype=torch.
-  decode_input_pos = torch.tensor([0], dtype=torch.
-  decode_d_token = torch.tensor([[1]], dtype=torch.
-  decode_d_input_pos = torch.tensor([0], dtype=torch.
+  decode_token = torch.tensor([[1]], dtype=torch.int)
+  decode_input_pos = torch.tensor([0], dtype=torch.int)
+  decode_d_token = torch.tensor([[1]], dtype=torch.int)
+  decode_d_input_pos = torch.tensor([0], dtype=torch.int)
 
   # Pad mask for self attention only on "real" tokens.
   # Pad with `-inf` for any tokens indices that aren't desired.