ai-edge-torch-nightly 0.3.0.dev20240913__py3-none-any.whl → 0.3.0.dev20240915__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- ai_edge_torch/_convert/conversion.py +2 -1
- ai_edge_torch/_convert/fx_passes/__init__.py +5 -41
- ai_edge_torch/_convert/fx_passes/build_aten_composite_pass.py +3 -4
- ai_edge_torch/_convert/fx_passes/build_interpolate_composite_pass.py +3 -4
- ai_edge_torch/_convert/fx_passes/inject_mlir_debuginfo_pass.py +3 -4
- ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/pass_body.py +4 -5
- ai_edge_torch/config.py +4 -1
- ai_edge_torch/fx_pass_base.py +101 -0
- ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +4 -4
- ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +4 -4
- ai_edge_torch/generative/examples/gemma/gemma.py +2 -2
- ai_edge_torch/generative/examples/gemma/gemma2.py +2 -2
- ai_edge_torch/generative/examples/openelm/convert_to_tflite.py +86 -0
- ai_edge_torch/generative/examples/openelm/openelm.py +237 -0
- ai_edge_torch/generative/examples/phi/convert_to_tflite.py +4 -4
- ai_edge_torch/generative/examples/phi/phi2.py +2 -2
- ai_edge_torch/generative/examples/smollm/__init__.py +14 -0
- ai_edge_torch/generative/examples/{smallm → smollm}/convert_to_tflite.py +12 -12
- ai_edge_torch/generative/examples/{smallm/smallm.py → smollm/smollm.py} +24 -15
- ai_edge_torch/generative/examples/stable_diffusion/clip.py +1 -1
- ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +1 -1
- ai_edge_torch/generative/examples/t5/convert_to_tflite.py +20 -20
- ai_edge_torch/generative/examples/t5/t5.py +8 -8
- ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +3 -3
- ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +4 -4
- ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +2 -2
- ai_edge_torch/generative/fx_passes/__init__.py +4 -4
- ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py +3 -4
- ai_edge_torch/generative/layers/attention.py +7 -0
- ai_edge_torch/generative/layers/builder.py +33 -11
- ai_edge_torch/generative/layers/feed_forward.py +26 -8
- ai_edge_torch/generative/layers/kv_cache.py +4 -4
- ai_edge_torch/generative/layers/model_config.py +24 -15
- ai_edge_torch/generative/quantize/example.py +2 -2
- ai_edge_torch/generative/test/test_model_conversion.py +28 -51
- ai_edge_torch/generative/test/test_model_conversion_large.py +43 -78
- ai_edge_torch/generative/test/test_quantize.py +5 -5
- ai_edge_torch/generative/utilities/loader.py +13 -0
- ai_edge_torch/odml_torch/export.py +40 -0
- ai_edge_torch/odml_torch/lowerings/_basic.py +44 -0
- ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +0 -1
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20240913.dist-info → ai_edge_torch_nightly-0.3.0.dev20240915.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20240913.dist-info → ai_edge_torch_nightly-0.3.0.dev20240915.dist-info}/RECORD +48 -46
- ai_edge_torch/_convert/fx_passes/_pass_base.py +0 -53
- ai_edge_torch/_convert/fx_passes/canonicalize_pass.py +0 -35
- /ai_edge_torch/generative/examples/{smallm → openelm}/__init__.py +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240913.dist-info → ai_edge_torch_nightly-0.3.0.dev20240915.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240913.dist-info → ai_edge_torch_nightly-0.3.0.dev20240915.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240913.dist-info → ai_edge_torch_nightly-0.3.0.dev20240915.dist-info}/top_level.txt +0 -0
@@ -47,10 +47,10 @@ def convert_phi2_to_tflite(
|
|
47
47
|
checkpoint_path, kv_cache_max_len=kv_cache_max_len
|
48
48
|
)
|
49
49
|
# Tensors used to trace the model graph during conversion.
|
50
|
-
prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.
|
51
|
-
prefill_input_pos = torch.arange(0, prefill_seq_len)
|
52
|
-
decode_token = torch.tensor([[0]], dtype=torch.
|
53
|
-
decode_input_pos = torch.tensor([0], dtype=torch.
|
50
|
+
prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.int)
|
51
|
+
prefill_input_pos = torch.arange(0, prefill_seq_len, dtype=torch.int)
|
52
|
+
decode_token = torch.tensor([[0]], dtype=torch.int)
|
53
|
+
decode_input_pos = torch.tensor([0], dtype=torch.int)
|
54
54
|
kv = kv_cache.KVCache.from_model_config(pytorch_model.config)
|
55
55
|
|
56
56
|
quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
|
@@ -192,9 +192,9 @@ def define_and_run(checkpoint_path: str) -> None:
|
|
192
192
|
kv_cache_max_len = 1024
|
193
193
|
model = build_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
|
194
194
|
idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
|
195
|
-
tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.
|
195
|
+
tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.int, device="cpu")
|
196
196
|
tokens[0, :4] = idx
|
197
|
-
input_pos = torch.arange(0, kv_cache_max_len)
|
197
|
+
input_pos = torch.arange(0, kv_cache_max_len, dtype=torch.int)
|
198
198
|
kv = kv_utils.KVCache.from_model_config(model.config)
|
199
199
|
output = model.forward(tokens, input_pos, kv)
|
200
200
|
print("comparing with goldens..")
|
@@ -0,0 +1,14 @@
|
|
1
|
+
# Copyright 2024 The AI Edge Torch Authors.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
@@ -13,25 +13,25 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
15
|
|
16
|
-
"""Example of converting
|
16
|
+
"""Example of converting SmolLM model to multi-signature tflite model."""
|
17
17
|
|
18
18
|
import os
|
19
19
|
import pathlib
|
20
20
|
|
21
21
|
import ai_edge_torch
|
22
|
-
from ai_edge_torch.generative.examples.
|
22
|
+
from ai_edge_torch.generative.examples.smollm import smollm
|
23
23
|
from ai_edge_torch.generative.layers import kv_cache as kv_utils
|
24
24
|
from ai_edge_torch.generative.quantize import quant_recipes
|
25
25
|
import torch
|
26
26
|
|
27
27
|
|
28
|
-
def
|
28
|
+
def convert_smollm_to_tflite(
|
29
29
|
checkpoint_path: str,
|
30
30
|
prefill_seq_len: int = 512,
|
31
31
|
kv_cache_max_len: int = 1024,
|
32
32
|
quantize: bool = True,
|
33
33
|
):
|
34
|
-
"""Converts
|
34
|
+
"""Converts SmolLM model to multi-signature tflite model.
|
35
35
|
|
36
36
|
Args:
|
37
37
|
checkpoint_path (str): The filepath to the model checkpoint, or directory
|
@@ -43,14 +43,14 @@ def convert_smallm_to_tflite(
|
|
43
43
|
quantize (bool, optional): Whether the model should be quantized. Defaults
|
44
44
|
to True.
|
45
45
|
"""
|
46
|
-
pytorch_model =
|
46
|
+
pytorch_model = smollm.build_model(
|
47
47
|
checkpoint_path, kv_cache_max_len=kv_cache_max_len
|
48
48
|
)
|
49
49
|
# Tensors used to trace the model graph during conversion.
|
50
|
-
prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.
|
51
|
-
prefill_input_pos = torch.arange(0, prefill_seq_len)
|
52
|
-
decode_token = torch.tensor([[0]], dtype=torch.
|
53
|
-
decode_input_pos = torch.tensor([0], dtype=torch.
|
50
|
+
prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.int)
|
51
|
+
prefill_input_pos = torch.arange(0, prefill_seq_len, dtype=torch.int)
|
52
|
+
decode_token = torch.tensor([[0]], dtype=torch.int)
|
53
|
+
decode_input_pos = torch.tensor([0], dtype=torch.int)
|
54
54
|
kv = kv_utils.KVCache.from_model_config(pytorch_model.config)
|
55
55
|
|
56
56
|
quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
|
@@ -77,10 +77,10 @@ def convert_smallm_to_tflite(
|
|
77
77
|
)
|
78
78
|
quant_suffix = 'q8' if quantize else 'f32'
|
79
79
|
edge_model.export(
|
80
|
-
f'/tmp/
|
80
|
+
f'/tmp/smollm_{quant_suffix}_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
|
81
81
|
)
|
82
82
|
|
83
83
|
|
84
84
|
if __name__ == '__main__':
|
85
|
-
path = os.path.join(pathlib.Path.home(), 'Downloads/llm_data/
|
86
|
-
|
85
|
+
path = os.path.join(pathlib.Path.home(), 'Downloads/llm_data/smollm')
|
86
|
+
convert_smollm_to_tflite(path)
|
@@ -13,7 +13,7 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
15
|
|
16
|
-
"""Example of building a
|
16
|
+
"""Example of building a SmolLM model."""
|
17
17
|
|
18
18
|
import copy
|
19
19
|
import os
|
@@ -28,32 +28,32 @@ import torch
|
|
28
28
|
from torch import nn
|
29
29
|
|
30
30
|
TENSOR_NAMES = copy.copy(tiny_llama.TENSOR_NAMES)
|
31
|
-
#
|
31
|
+
# SmolLM re-uses the embedding as the head projection layer.
|
32
32
|
TENSOR_NAMES.lm_head = None
|
33
33
|
|
34
34
|
|
35
|
-
class
|
36
|
-
"""A
|
35
|
+
class SmolLM(tiny_llama.TinyLlama):
|
36
|
+
"""A SmolLM model built from the Edge Generative API layers.
|
37
37
|
|
38
|
-
|
38
|
+
SmolLM shares the same architecture as TinyLlama, but with different model
|
39
39
|
sizes.
|
40
40
|
"""
|
41
41
|
|
42
42
|
def __init__(self, config: cfg.ModelConfig):
|
43
43
|
super().__init__(config)
|
44
|
-
#
|
44
|
+
# SmolLM re-uses the embedding as the head projection layer.
|
45
45
|
self.lm_head.weight.data = self.tok_embedding.weight.data
|
46
46
|
|
47
47
|
|
48
48
|
def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
|
49
|
-
"""Returns the model config for a
|
49
|
+
"""Returns the model config for a SmolLM 135M model.
|
50
50
|
|
51
51
|
Args:
|
52
52
|
kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
|
53
53
|
is 1024.
|
54
54
|
|
55
55
|
Returns:
|
56
|
-
The model config for a
|
56
|
+
The model config for a SmolLM model.
|
57
57
|
"""
|
58
58
|
attn_config = cfg.AttentionConfig(
|
59
59
|
num_heads=9,
|
@@ -86,9 +86,18 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
|
|
86
86
|
return config
|
87
87
|
|
88
88
|
|
89
|
+
def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
|
90
|
+
config = get_model_config(**kwargs)
|
91
|
+
config.vocab_size = 128
|
92
|
+
config.num_layers = 2
|
93
|
+
# SmolLM has only one block config.
|
94
|
+
config.block_config(0).ff_config.intermediate_size = 64
|
95
|
+
return config
|
96
|
+
|
97
|
+
|
89
98
|
def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
|
90
99
|
config = get_model_config(**kwargs)
|
91
|
-
model =
|
100
|
+
model = SmolLM(config)
|
92
101
|
loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
|
93
102
|
# Since embedding and lm-head use the same weight, we need to set strict
|
94
103
|
# to False.
|
@@ -98,25 +107,25 @@ def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
|
|
98
107
|
|
99
108
|
|
100
109
|
def define_and_run(checkpoint_path: str) -> None:
|
101
|
-
"""Instantiates and runs a
|
110
|
+
"""Instantiates and runs a SmolLM model."""
|
102
111
|
|
103
112
|
current_dir = pathlib.Path(__file__).parent.resolve()
|
104
|
-
|
113
|
+
smollm_goldens = torch.load(current_dir / "smollm_lm_logits.pt")
|
105
114
|
kv_cache_max_len = 1024
|
106
115
|
model = build_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
|
107
116
|
idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
|
108
|
-
tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.
|
117
|
+
tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.int, device="cpu")
|
109
118
|
tokens[0, :4] = idx
|
110
|
-
input_pos = torch.arange(0, kv_cache_max_len)
|
119
|
+
input_pos = torch.arange(0, kv_cache_max_len, dtype=torch.int)
|
111
120
|
kv = kv_utils.KVCache.from_model_config(model.config)
|
112
121
|
output = model.forward(tokens, input_pos, kv)
|
113
122
|
assert torch.allclose(
|
114
|
-
|
123
|
+
smollm_goldens, output["logits"][0, idx.shape[1] - 1, :], atol=1e-05
|
115
124
|
)
|
116
125
|
|
117
126
|
|
118
127
|
if __name__ == "__main__":
|
119
128
|
input_checkpoint_path = os.path.join(
|
120
|
-
pathlib.Path.home(), "Downloads/llm_data/
|
129
|
+
pathlib.Path.home(), "Downloads/llm_data/smollm"
|
121
130
|
)
|
122
131
|
define_and_run(input_checkpoint_path)
|
@@ -76,7 +76,7 @@ class CLIP(nn.Module):
|
|
76
76
|
|
77
77
|
@torch.inference_mode
|
78
78
|
def forward(self, tokens: torch.LongTensor) -> torch.FloatTensor:
|
79
|
-
tokens = tokens.type(torch.
|
79
|
+
tokens = tokens.type(torch.int)
|
80
80
|
|
81
81
|
state = self.tok_embedding(tokens) + self.tok_embedding_position
|
82
82
|
for layer in self.transformer_blocks:
|
@@ -94,7 +94,7 @@ def convert_stable_diffusion_to_tflite(
|
|
94
94
|
n_tokens = 77
|
95
95
|
timestamp = 0
|
96
96
|
len_prompt = 1
|
97
|
-
prompt_tokens = torch.full((1, n_tokens), 0, dtype=torch.
|
97
|
+
prompt_tokens = torch.full((1, n_tokens), 0, dtype=torch.int)
|
98
98
|
input_image = torch.full(
|
99
99
|
(1, 3, image_height, image_width), 0, dtype=torch.float32
|
100
100
|
)
|
@@ -29,24 +29,24 @@ def convert_t5_to_tflite_singlesig(checkpoint_path: str):
|
|
29
29
|
|
30
30
|
# encoder
|
31
31
|
seq_len = 512
|
32
|
-
prefill_e_tokens = torch.full((1, seq_len), 0, dtype=torch.
|
32
|
+
prefill_e_tokens = torch.full((1, seq_len), 0, dtype=torch.int)
|
33
33
|
prompt_e_token = [1, 2, 3, 4, 5, 6]
|
34
34
|
prefill_e_tokens[0, : len(prompt_e_token)] = torch.tensor(
|
35
|
-
prompt_e_token, dtype=torch.
|
35
|
+
prompt_e_token, dtype=torch.int
|
36
36
|
)
|
37
|
-
prefill_e_input_pos = torch.arange(0, seq_len)
|
38
|
-
prefill_d_tokens = torch.full((1, seq_len), 0, dtype=torch.
|
37
|
+
prefill_e_input_pos = torch.arange(0, seq_len, dtype=torch.int)
|
38
|
+
prefill_d_tokens = torch.full((1, seq_len), 0, dtype=torch.int)
|
39
39
|
prompt_d_token = [1, 2, 3, 4, 5, 6]
|
40
40
|
prefill_d_tokens[0, : len(prompt_d_token)] = torch.tensor(
|
41
|
-
prompt_d_token, dtype=torch.
|
41
|
+
prompt_d_token, dtype=torch.int
|
42
42
|
)
|
43
|
-
prefill_d_input_pos = torch.arange(0, seq_len)
|
43
|
+
prefill_d_input_pos = torch.arange(0, seq_len, dtype=torch.int)
|
44
44
|
|
45
45
|
# decoder
|
46
|
-
decode_token = torch.tensor([[1]], dtype=torch.
|
47
|
-
decode_input_pos = torch.tensor([0], dtype=torch.
|
48
|
-
decode_d_token = torch.tensor([[1]], dtype=torch.
|
49
|
-
decode_d_input_pos = torch.tensor([0], dtype=torch.
|
46
|
+
decode_token = torch.tensor([[1]], dtype=torch.int)
|
47
|
+
decode_input_pos = torch.tensor([0], dtype=torch.int)
|
48
|
+
decode_d_token = torch.tensor([[1]], dtype=torch.int)
|
49
|
+
decode_d_input_pos = torch.tensor([0], dtype=torch.int)
|
50
50
|
|
51
51
|
# Pad mask for self attention only on "real" tokens.
|
52
52
|
# Pad with `-inf` for any tokens indices that aren't desired.
|
@@ -81,24 +81,24 @@ def convert_t5_to_tflite_multisig(checkpoint_path: str):
|
|
81
81
|
|
82
82
|
# encoder
|
83
83
|
seq_len = 512
|
84
|
-
prefill_e_tokens = torch.full((1, seq_len), 0, dtype=torch.
|
84
|
+
prefill_e_tokens = torch.full((1, seq_len), 0, dtype=torch.int)
|
85
85
|
prompt_e_token = [1, 2, 3, 4, 5, 6]
|
86
86
|
prefill_e_tokens[0, : len(prompt_e_token)] = torch.tensor(
|
87
|
-
prompt_e_token, dtype=torch.
|
87
|
+
prompt_e_token, dtype=torch.int
|
88
88
|
)
|
89
|
-
prefill_e_input_pos = torch.arange(0, seq_len)
|
90
|
-
prefill_d_tokens = torch.full((1, seq_len), 0, dtype=torch.
|
89
|
+
prefill_e_input_pos = torch.arange(0, seq_len, dtype=torch.int)
|
90
|
+
prefill_d_tokens = torch.full((1, seq_len), 0, dtype=torch.int)
|
91
91
|
prompt_d_token = [1, 2, 3, 4, 5, 6]
|
92
92
|
prefill_d_tokens[0, : len(prompt_d_token)] = torch.tensor(
|
93
|
-
prompt_d_token, dtype=torch.
|
93
|
+
prompt_d_token, dtype=torch.int
|
94
94
|
)
|
95
|
-
prefill_d_input_pos = torch.arange(0, seq_len)
|
95
|
+
prefill_d_input_pos = torch.arange(0, seq_len, dtype=torch.int)
|
96
96
|
|
97
97
|
# decoder
|
98
|
-
decode_token = torch.tensor([[1]], dtype=torch.
|
99
|
-
decode_input_pos = torch.tensor([0], dtype=torch.
|
100
|
-
decode_d_token = torch.tensor([[1]], dtype=torch.
|
101
|
-
decode_d_input_pos = torch.tensor([0], dtype=torch.
|
98
|
+
decode_token = torch.tensor([[1]], dtype=torch.int)
|
99
|
+
decode_input_pos = torch.tensor([0], dtype=torch.int)
|
100
|
+
decode_d_token = torch.tensor([[1]], dtype=torch.int)
|
101
|
+
decode_d_input_pos = torch.tensor([0], dtype=torch.int)
|
102
102
|
|
103
103
|
# Pad mask for self attention only on "real" tokens.
|
104
104
|
# Pad with `-inf` for any tokens indices that aren't desired.
|
@@ -601,12 +601,12 @@ def define_and_run_t5(checkpoint_path: str) -> None:
|
|
601
601
|
model = build_t5_model(checkpoint_path)
|
602
602
|
|
603
603
|
idx = get_sample_encoder_input_ids()
|
604
|
-
tokens = torch.full((1, 512), 0, dtype=torch.
|
604
|
+
tokens = torch.full((1, 512), 0, dtype=torch.int, device="cpu")
|
605
605
|
tokens[0, :77] = idx
|
606
|
-
input_pos = torch.arange(0, 512)
|
606
|
+
input_pos = torch.arange(0, 512, dtype=torch.int)
|
607
607
|
|
608
|
-
decode_d_token = torch.tensor([[0]], dtype=torch.
|
609
|
-
decode_d_input_pos = torch.tensor([0], dtype=torch.
|
608
|
+
decode_d_token = torch.tensor([[0]], dtype=torch.int)
|
609
|
+
decode_d_input_pos = torch.tensor([0], dtype=torch.int)
|
610
610
|
pad_mask = torch.zeros([model.config.kv_cache_max], dtype=torch.float32)
|
611
611
|
pad_mask[77:] = float("-inf")
|
612
612
|
lm_logits = model.forward(
|
@@ -633,12 +633,12 @@ def define_and_run_t5_split(checkpoint_path: str) -> None:
|
|
633
633
|
)
|
634
634
|
idx = get_sample_encoder_input_ids()
|
635
635
|
|
636
|
-
tokens = torch.full((1, 512), 0, dtype=torch.
|
636
|
+
tokens = torch.full((1, 512), 0, dtype=torch.int, device="cpu")
|
637
637
|
tokens[0, :77] = idx
|
638
|
-
input_pos = torch.arange(0, 512)
|
638
|
+
input_pos = torch.arange(0, 512, dtype=torch.int)
|
639
639
|
|
640
|
-
decode_d_token = torch.tensor([[0]], dtype=torch.
|
641
|
-
decode_d_input_pos = torch.tensor([0], dtype=torch.
|
640
|
+
decode_d_token = torch.tensor([[0]], dtype=torch.int)
|
641
|
+
decode_d_input_pos = torch.tensor([0], dtype=torch.int)
|
642
642
|
pad_mask = torch.zeros(
|
643
643
|
[t5_encoder_model.config.kv_cache_max], dtype=torch.float32
|
644
644
|
)
|
@@ -124,13 +124,13 @@ def get_model_config() -> cfg.ModelConfig:
|
|
124
124
|
|
125
125
|
|
126
126
|
def get_sample_prefill_inputs() -> Tuple[torch.Tensor, torch.Tensor]:
|
127
|
-
tokens = torch.unsqueeze(torch.arange(0, 100), 0)
|
128
|
-
input_pos = torch.arange(0, 100)
|
127
|
+
tokens = torch.unsqueeze(torch.arange(0, 100, dtype=torch.int), 0)
|
128
|
+
input_pos = torch.arange(0, 100, dtype=torch.int)
|
129
129
|
return tokens, input_pos
|
130
130
|
|
131
131
|
|
132
132
|
def get_sample_decode_inputs() -> Tuple[torch.Tensor, torch.Tensor]:
|
133
|
-
tokens = torch.tensor([[1]], dtype=torch.
|
133
|
+
tokens = torch.tensor([[1]], dtype=torch.int)
|
134
134
|
input_pos = torch.tensor([10])
|
135
135
|
return tokens, input_pos
|
136
136
|
|
@@ -47,10 +47,10 @@ def convert_tiny_llama_to_tflite(
|
|
47
47
|
checkpoint_path, kv_cache_max_len=kv_cache_max_len
|
48
48
|
)
|
49
49
|
# Tensors used to trace the model graph during conversion.
|
50
|
-
prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.
|
51
|
-
prefill_input_pos = torch.arange(0, prefill_seq_len)
|
52
|
-
decode_token = torch.tensor([[0]], dtype=torch.
|
53
|
-
decode_input_pos = torch.tensor([0], dtype=torch.
|
50
|
+
prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.int)
|
51
|
+
prefill_input_pos = torch.arange(0, prefill_seq_len, dtype=torch.int)
|
52
|
+
decode_token = torch.tensor([[0]], dtype=torch.int)
|
53
|
+
decode_input_pos = torch.tensor([0], dtype=torch.int)
|
54
54
|
kv = kv_utils.KVCache.from_model_config(pytorch_model.config)
|
55
55
|
|
56
56
|
quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
|
@@ -189,9 +189,9 @@ def define_and_run(checkpoint_path: str) -> None:
|
|
189
189
|
kv_cache_max_len = 1024
|
190
190
|
model = build_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
|
191
191
|
idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
|
192
|
-
tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.
|
192
|
+
tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.int, device="cpu")
|
193
193
|
tokens[0, :4] = idx
|
194
|
-
input_pos = torch.arange(0, kv_cache_max_len)
|
194
|
+
input_pos = torch.arange(0, kv_cache_max_len, dtype=torch.int)
|
195
195
|
kv = kv_utils.KVCache.from_model_config(model.config)
|
196
196
|
output = model.forward(tokens, input_pos, kv)
|
197
197
|
assert torch.allclose(
|
@@ -12,16 +12,16 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
|
-
from ai_edge_torch
|
16
|
-
from ai_edge_torch.
|
17
|
-
from ai_edge_torch.generative.fx_passes.remove_sdpa_zero_mask_pass import RemoveSDPACompositeZeroMaskPass
|
15
|
+
from ai_edge_torch import fx_pass_base
|
16
|
+
from ai_edge_torch.fx_pass_base import CanonicalizePass
|
17
|
+
from ai_edge_torch.generative.fx_passes.remove_sdpa_zero_mask_pass import RemoveSDPACompositeZeroMaskPass
|
18
18
|
import torch
|
19
19
|
|
20
20
|
|
21
21
|
def run_generative_passes(
|
22
22
|
exported_program: torch.export.ExportedProgram,
|
23
23
|
) -> torch.export.ExportedProgram:
|
24
|
-
return run_passes(
|
24
|
+
return fx_pass_base.run_passes(
|
25
25
|
exported_program,
|
26
26
|
[
|
27
27
|
RemoveSDPACompositeZeroMaskPass(),
|
@@ -12,13 +12,12 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
|
+
from ai_edge_torch import fx_pass_base
|
15
16
|
from ai_edge_torch import lowertools
|
16
|
-
from ai_edge_torch._convert.fx_passes._pass_base import ExportedProgramPassBase
|
17
|
-
from ai_edge_torch._convert.fx_passes._pass_base import ExportedProgramPassResult
|
18
17
|
import torch
|
19
18
|
|
20
19
|
|
21
|
-
class RemoveSDPACompositeZeroMaskPass(ExportedProgramPassBase):
|
20
|
+
class RemoveSDPACompositeZeroMaskPass(fx_pass_base.ExportedProgramPassBase):
|
22
21
|
|
23
22
|
def is_zero_tensor_node(self, node: torch.fx.Node):
|
24
23
|
return node.target == torch.ops.aten.zeros.default
|
@@ -48,4 +47,4 @@ class RemoveSDPACompositeZeroMaskPass(ExportedProgramPassBase):
|
|
48
47
|
|
49
48
|
exported_program.graph_module.graph.lint()
|
50
49
|
exported_program.graph_module.recompile()
|
51
|
-
return ExportedProgramPassResult(exported_program, True)
|
50
|
+
return fx_pass_base.ExportedProgramPassResult(exported_program, True)
|
@@ -160,6 +160,10 @@ class CausalSelfAttention(nn.Module):
|
|
160
160
|
self.output_projection = nn.Linear(
|
161
161
|
output_shape, dim, bias=config.output_proj_use_bias
|
162
162
|
)
|
163
|
+
self.query_norm = builder.build_norm(
|
164
|
+
config.head_dim, config.query_norm_config
|
165
|
+
)
|
166
|
+
self.key_norm = builder.build_norm(config.head_dim, config.key_norm_config)
|
163
167
|
self.config = config
|
164
168
|
self.enable_hlfb = enable_hlfb
|
165
169
|
self.sdpa_func = (
|
@@ -224,6 +228,9 @@ class CausalSelfAttention(nn.Module):
|
|
224
228
|
dim=-1,
|
225
229
|
)
|
226
230
|
|
231
|
+
q = self.query_norm(q)
|
232
|
+
k = self.key_norm(k)
|
233
|
+
|
227
234
|
q = q.reshape(B, T, -1, self.config.head_dim)
|
228
235
|
k = k.reshape(B, T, -1, self.config.head_dim)
|
229
236
|
v = v.reshape(B, T, -1, self.config.head_dim)
|
@@ -13,6 +13,8 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
15
|
# Builder class for individual components.
|
16
|
+
from typing import Callable
|
17
|
+
|
16
18
|
import ai_edge_torch.generative.layers.feed_forward as feed_forward
|
17
19
|
import ai_edge_torch.generative.layers.model_config as cfg
|
18
20
|
import ai_edge_torch.generative.layers.normalization as normalization
|
@@ -21,20 +23,34 @@ from torch import nn
|
|
21
23
|
import torch.nn.functional as F
|
22
24
|
|
23
25
|
|
24
|
-
|
25
|
-
|
26
|
+
def build_glu(
|
27
|
+
act: Callable[[torch.Tensor], torch.Tensor], gate_is_front: bool = False
|
28
|
+
) -> Callable[[torch.Tensor], torch.Tensor]:
|
29
|
+
"""Builds an activation function with GLU (Gated Linear Unit).
|
30
|
+
|
31
|
+
If gate_is_front is True,
|
32
|
+
f(x) = act(x) * y
|
33
|
+
otherwise,
|
34
|
+
f(x) = x * act(y),
|
35
|
+
where x is the first half of the input and y is the second half of the input.
|
26
36
|
|
27
|
-
|
28
|
-
|
37
|
+
Args:
|
38
|
+
act (Callable[[torch.Tensor], torch.Tensor]): activation function to apply
|
39
|
+
to the gate.
|
40
|
+
gate_is_front: whether the gate is in the front half of the input. The other part is
|
41
|
+
the output in GLU.
|
42
|
+
|
43
|
+
Returns:
|
44
|
+
A callable activation function with GLU.
|
29
45
|
"""
|
30
46
|
|
31
|
-
def
|
32
|
-
|
33
|
-
|
47
|
+
def _glu(x):
|
48
|
+
x, y = x.chunk(2, dim=-1)
|
49
|
+
if gate_is_front:
|
50
|
+
return act(x) * y
|
51
|
+
return x * act(y)
|
34
52
|
|
35
|
-
|
36
|
-
x, gate = self.proj(x).chunk(2, dim=-1)
|
37
|
-
return x * F.gelu(gate)
|
53
|
+
return _glu
|
38
54
|
|
39
55
|
|
40
56
|
def build_norm(dim: int, config: cfg.NormalizationConfig):
|
@@ -99,6 +115,10 @@ def build_ff(dim: int, config: cfg.FeedForwardConfig):
|
|
99
115
|
hidden_dim=config.intermediate_size,
|
100
116
|
activation=activation,
|
101
117
|
use_bias=config.use_bias,
|
118
|
+
use_glu=(
|
119
|
+
config.activation.type == cfg.ActivationType.GE_GLU
|
120
|
+
or config.activation.type == cfg.ActivationType.SILU_GLU
|
121
|
+
),
|
102
122
|
pre_ff_norm=pre_ff_norm,
|
103
123
|
post_ff_norm=post_ff_norm,
|
104
124
|
)
|
@@ -129,8 +149,10 @@ def get_activation(config: cfg.ActivationConfig):
|
|
129
149
|
# See: https://github.com/hendrycks/GELUs
|
130
150
|
return lambda x: x * F.sigmoid(1.702 * x)
|
131
151
|
elif config.type == cfg.ActivationType.GE_GLU:
|
132
|
-
return
|
152
|
+
return build_glu(F.gelu, config.gate_is_front)
|
133
153
|
elif config.type == cfg.ActivationType.RELU:
|
134
154
|
return F.relu
|
155
|
+
elif config.type == cfg.ActivationType.SILU_GLU:
|
156
|
+
return build_glu(F.silu, config.gate_is_front)
|
135
157
|
else:
|
136
158
|
raise ValueError("Unsupported activation type.")
|
@@ -30,18 +30,27 @@ class SequentialFeedForward(nn.Module):
|
|
30
30
|
hidden_dim: int,
|
31
31
|
activation: Callable[[torch.Tensor], torch.Tensor],
|
32
32
|
use_bias=False,
|
33
|
+
use_glu=False,
|
33
34
|
pre_ff_norm: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
|
34
35
|
post_ff_norm: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
|
35
36
|
):
|
36
37
|
"""Init function for feedforward layer.
|
37
38
|
|
38
|
-
Args:
|
39
|
-
|
40
|
-
|
39
|
+
Args:
|
40
|
+
dim (int): embedding size.
|
41
|
+
hidden_dim (int): hidden dim size of the feedforward layer.
|
42
|
+
activation (Callable): activation function used in this block.
|
43
|
+
use_bias (Boolean): whether to use bias. Default is false.
|
44
|
+
use_glu (Boolean): whether to use glu in activation. Default is false.
|
45
|
+
pre_ff_norm (Callable): pre feedforward norm. Default is None.
|
46
|
+
post_ff_norm (Callable): post feedforward norm. Default is None.
|
41
47
|
"""
|
42
48
|
super().__init__()
|
43
49
|
self.act = activation
|
44
|
-
|
50
|
+
if use_glu:
|
51
|
+
self.w1 = nn.Linear(dim, hidden_dim * 2, bias=use_bias)
|
52
|
+
else:
|
53
|
+
self.w1 = nn.Linear(dim, hidden_dim, bias=use_bias)
|
45
54
|
self.w2 = nn.Linear(hidden_dim, dim, bias=use_bias)
|
46
55
|
self.pre_ff_norm = pre_ff_norm if pre_ff_norm else lambda x: x
|
47
56
|
self.post_ff_norm = post_ff_norm if post_ff_norm else lambda x: x
|
@@ -72,18 +81,27 @@ class GatedFeedForward(nn.Module):
|
|
72
81
|
hidden_dim: int,
|
73
82
|
activation: Callable[[torch.Tensor], torch.Tensor],
|
74
83
|
use_bias=False,
|
84
|
+
use_glu=False,
|
75
85
|
pre_ff_norm: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
|
76
86
|
post_ff_norm: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
|
77
87
|
):
|
78
88
|
"""Init function for feedforward layer.
|
79
89
|
|
80
|
-
Args:
|
81
|
-
|
82
|
-
|
90
|
+
Args:
|
91
|
+
dim (int): embedding size.
|
92
|
+
hidden_dim (int): hidden dim size of the feedforward layer.
|
93
|
+
activation (Callable): activation function used in this block.
|
94
|
+
use_bias (Boolean): whether to use bias. Default is false.
|
95
|
+
use_glu (Boolean): whether to use glu in activation. Default is false.
|
96
|
+
pre_ff_norm (Callable): pre feedforward norm. Default is None.
|
97
|
+
post_ff_norm (Callable): post feedforward norm. Default is None.
|
83
98
|
"""
|
84
99
|
super().__init__()
|
85
100
|
self.act = activation
|
86
|
-
|
101
|
+
if use_glu:
|
102
|
+
self.w1 = nn.Linear(dim, hidden_dim * 2, bias=use_bias)
|
103
|
+
else:
|
104
|
+
self.w1 = nn.Linear(dim, hidden_dim, bias=use_bias)
|
87
105
|
self.w2 = nn.Linear(hidden_dim, dim, bias=use_bias)
|
88
106
|
self.w3 = nn.Linear(dim, hidden_dim, bias=use_bias)
|
89
107
|
self.pre_ff_norm = pre_ff_norm if pre_ff_norm else lambda x: x
|
@@ -172,8 +172,8 @@ def _update_kv_base_impl(
|
|
172
172
|
v_slice: torch.Tensor,
|
173
173
|
) -> KVCacheEntry:
|
174
174
|
"""Update the cache buffer without High Level Function Boundary annotation."""
|
175
|
-
k = cache.k_cache.index_copy(1, input_pos, k_slice)
|
176
|
-
v = cache.v_cache.index_copy(1, input_pos, v_slice)
|
175
|
+
k = cache.k_cache.index_copy(1, input_pos.to(torch.long), k_slice)
|
176
|
+
v = cache.v_cache.index_copy(1, input_pos.to(torch.long), v_slice)
|
177
177
|
updated_cache = KVCacheEntry(k, v)
|
178
178
|
return updated_cache
|
179
179
|
|
@@ -189,7 +189,7 @@ def _update_kv_hlfb_impl(
|
|
189
189
|
k_cache, v_cache, input_pos, k_slice, v_slice = builder.mark_inputs(
|
190
190
|
cache.k_cache, cache.v_cache, input_pos, k_slice, v_slice
|
191
191
|
)
|
192
|
-
k = k_cache.index_copy(1, input_pos, k_slice)
|
193
|
-
v = v_cache.index_copy(1, input_pos, v_slice)
|
192
|
+
k = k_cache.index_copy(1, input_pos.to(torch.long), k_slice)
|
193
|
+
v = v_cache.index_copy(1, input_pos.to(torch.long), v_slice)
|
194
194
|
k, v = builder.mark_outputs(k, v)
|
195
195
|
return KVCacheEntry(k, v)
|