ai-edge-torch-nightly 0.3.0.dev20240913__py3-none-any.whl → 0.3.0.dev20240915__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_torch/_convert/conversion.py +2 -1
- ai_edge_torch/_convert/fx_passes/__init__.py +5 -41
- ai_edge_torch/_convert/fx_passes/build_aten_composite_pass.py +3 -4
- ai_edge_torch/_convert/fx_passes/build_interpolate_composite_pass.py +3 -4
- ai_edge_torch/_convert/fx_passes/inject_mlir_debuginfo_pass.py +3 -4
- ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/pass_body.py +4 -5
- ai_edge_torch/config.py +4 -1
- ai_edge_torch/fx_pass_base.py +101 -0
- ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +4 -4
- ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +4 -4
- ai_edge_torch/generative/examples/gemma/gemma.py +2 -2
- ai_edge_torch/generative/examples/gemma/gemma2.py +2 -2
- ai_edge_torch/generative/examples/openelm/convert_to_tflite.py +86 -0
- ai_edge_torch/generative/examples/openelm/openelm.py +237 -0
- ai_edge_torch/generative/examples/phi/convert_to_tflite.py +4 -4
- ai_edge_torch/generative/examples/phi/phi2.py +2 -2
- ai_edge_torch/generative/examples/smollm/__init__.py +14 -0
- ai_edge_torch/generative/examples/{smallm → smollm}/convert_to_tflite.py +12 -12
- ai_edge_torch/generative/examples/{smallm/smallm.py → smollm/smollm.py} +24 -15
- ai_edge_torch/generative/examples/stable_diffusion/clip.py +1 -1
- ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +1 -1
- ai_edge_torch/generative/examples/t5/convert_to_tflite.py +20 -20
- ai_edge_torch/generative/examples/t5/t5.py +8 -8
- ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +3 -3
- ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +4 -4
- ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +2 -2
- ai_edge_torch/generative/fx_passes/__init__.py +4 -4
- ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py +3 -4
- ai_edge_torch/generative/layers/attention.py +7 -0
- ai_edge_torch/generative/layers/builder.py +33 -11
- ai_edge_torch/generative/layers/feed_forward.py +26 -8
- ai_edge_torch/generative/layers/kv_cache.py +4 -4
- ai_edge_torch/generative/layers/model_config.py +24 -15
- ai_edge_torch/generative/quantize/example.py +2 -2
- ai_edge_torch/generative/test/test_model_conversion.py +28 -51
- ai_edge_torch/generative/test/test_model_conversion_large.py +43 -78
- ai_edge_torch/generative/test/test_quantize.py +5 -5
- ai_edge_torch/generative/utilities/loader.py +13 -0
- ai_edge_torch/odml_torch/export.py +40 -0
- ai_edge_torch/odml_torch/lowerings/_basic.py +44 -0
- ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +0 -1
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20240913.dist-info → ai_edge_torch_nightly-0.3.0.dev20240915.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20240913.dist-info → ai_edge_torch_nightly-0.3.0.dev20240915.dist-info}/RECORD +48 -46
- ai_edge_torch/_convert/fx_passes/_pass_base.py +0 -53
- ai_edge_torch/_convert/fx_passes/canonicalize_pass.py +0 -35
- /ai_edge_torch/generative/examples/{smallm → openelm}/__init__.py +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240913.dist-info → ai_edge_torch_nightly-0.3.0.dev20240915.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240913.dist-info → ai_edge_torch_nightly-0.3.0.dev20240915.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240913.dist-info → ai_edge_torch_nightly-0.3.0.dev20240915.dist-info}/top_level.txt +0 -0
@@ -47,10 +47,10 @@ def convert_phi2_to_tflite(
|
|
47
47
|
checkpoint_path, kv_cache_max_len=kv_cache_max_len
|
48
48
|
)
|
49
49
|
# Tensors used to trace the model graph during conversion.
|
50
|
-
prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.
|
51
|
-
prefill_input_pos = torch.arange(0, prefill_seq_len)
|
52
|
-
decode_token = torch.tensor([[0]], dtype=torch.
|
53
|
-
decode_input_pos = torch.tensor([0], dtype=torch.
|
50
|
+
prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.int)
|
51
|
+
prefill_input_pos = torch.arange(0, prefill_seq_len, dtype=torch.int)
|
52
|
+
decode_token = torch.tensor([[0]], dtype=torch.int)
|
53
|
+
decode_input_pos = torch.tensor([0], dtype=torch.int)
|
54
54
|
kv = kv_cache.KVCache.from_model_config(pytorch_model.config)
|
55
55
|
|
56
56
|
quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
|
@@ -192,9 +192,9 @@ def define_and_run(checkpoint_path: str) -> None:
|
|
192
192
|
kv_cache_max_len = 1024
|
193
193
|
model = build_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
|
194
194
|
idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
|
195
|
-
tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.
|
195
|
+
tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.int, device="cpu")
|
196
196
|
tokens[0, :4] = idx
|
197
|
-
input_pos = torch.arange(0, kv_cache_max_len)
|
197
|
+
input_pos = torch.arange(0, kv_cache_max_len, dtype=torch.int)
|
198
198
|
kv = kv_utils.KVCache.from_model_config(model.config)
|
199
199
|
output = model.forward(tokens, input_pos, kv)
|
200
200
|
print("comparing with goldens..")
|
@@ -0,0 +1,14 @@
|
|
1
|
+
# Copyright 2024 The AI Edge Torch Authors.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
@@ -13,25 +13,25 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
15
|
|
16
|
-
"""Example of converting
|
16
|
+
"""Example of converting SmolLM model to multi-signature tflite model."""
|
17
17
|
|
18
18
|
import os
|
19
19
|
import pathlib
|
20
20
|
|
21
21
|
import ai_edge_torch
|
22
|
-
from ai_edge_torch.generative.examples.
|
22
|
+
from ai_edge_torch.generative.examples.smollm import smollm
|
23
23
|
from ai_edge_torch.generative.layers import kv_cache as kv_utils
|
24
24
|
from ai_edge_torch.generative.quantize import quant_recipes
|
25
25
|
import torch
|
26
26
|
|
27
27
|
|
28
|
-
def
|
28
|
+
def convert_smollm_to_tflite(
|
29
29
|
checkpoint_path: str,
|
30
30
|
prefill_seq_len: int = 512,
|
31
31
|
kv_cache_max_len: int = 1024,
|
32
32
|
quantize: bool = True,
|
33
33
|
):
|
34
|
-
"""Converts
|
34
|
+
"""Converts SmolLM model to multi-signature tflite model.
|
35
35
|
|
36
36
|
Args:
|
37
37
|
checkpoint_path (str): The filepath to the model checkpoint, or directory
|
@@ -43,14 +43,14 @@ def convert_smallm_to_tflite(
|
|
43
43
|
quantize (bool, optional): Whether the model should be quanized. Defaults
|
44
44
|
to True.
|
45
45
|
"""
|
46
|
-
pytorch_model =
|
46
|
+
pytorch_model = smollm.build_model(
|
47
47
|
checkpoint_path, kv_cache_max_len=kv_cache_max_len
|
48
48
|
)
|
49
49
|
# Tensors used to trace the model graph during conversion.
|
50
|
-
prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.
|
51
|
-
prefill_input_pos = torch.arange(0, prefill_seq_len)
|
52
|
-
decode_token = torch.tensor([[0]], dtype=torch.
|
53
|
-
decode_input_pos = torch.tensor([0], dtype=torch.
|
50
|
+
prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.int)
|
51
|
+
prefill_input_pos = torch.arange(0, prefill_seq_len, dtype=torch.int)
|
52
|
+
decode_token = torch.tensor([[0]], dtype=torch.int)
|
53
|
+
decode_input_pos = torch.tensor([0], dtype=torch.int)
|
54
54
|
kv = kv_utils.KVCache.from_model_config(pytorch_model.config)
|
55
55
|
|
56
56
|
quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
|
@@ -77,10 +77,10 @@ def convert_smallm_to_tflite(
|
|
77
77
|
)
|
78
78
|
quant_suffix = 'q8' if quantize else 'f32'
|
79
79
|
edge_model.export(
|
80
|
-
f'/tmp/
|
80
|
+
f'/tmp/smollm_{quant_suffix}_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
|
81
81
|
)
|
82
82
|
|
83
83
|
|
84
84
|
if __name__ == '__main__':
|
85
|
-
path = os.path.join(pathlib.Path.home(), 'Downloads/llm_data/
|
86
|
-
|
85
|
+
path = os.path.join(pathlib.Path.home(), 'Downloads/llm_data/smollm')
|
86
|
+
convert_smollm_to_tflite(path)
|
@@ -13,7 +13,7 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
15
|
|
16
|
-
"""Example of building a
|
16
|
+
"""Example of building a SmolLM model."""
|
17
17
|
|
18
18
|
import copy
|
19
19
|
import os
|
@@ -28,32 +28,32 @@ import torch
|
|
28
28
|
from torch import nn
|
29
29
|
|
30
30
|
TENSOR_NAMES = copy.copy(tiny_llama.TENSOR_NAMES)
|
31
|
-
#
|
31
|
+
# SmolLM re-uses the embedding as the head projection layer.
|
32
32
|
TENSOR_NAMES.lm_head = None
|
33
33
|
|
34
34
|
|
35
|
-
class
|
36
|
-
"""A
|
35
|
+
class SmolLM(tiny_llama.TinyLlama):
|
36
|
+
"""A SmolLM model built from the Edge Generative API layers.
|
37
37
|
|
38
|
-
|
38
|
+
SmolLM shares the same architecture as TinyLlama, but with different model
|
39
39
|
sizes.
|
40
40
|
"""
|
41
41
|
|
42
42
|
def __init__(self, config: cfg.ModelConfig):
|
43
43
|
super().__init__(config)
|
44
|
-
#
|
44
|
+
# SmolLM re-uses the embedding as the head projection layer.
|
45
45
|
self.lm_head.weight.data = self.tok_embedding.weight.data
|
46
46
|
|
47
47
|
|
48
48
|
def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
|
49
|
-
"""Returns the model config for a
|
49
|
+
"""Returns the model config for a SmolLM 135M model.
|
50
50
|
|
51
51
|
Args:
|
52
52
|
kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
|
53
53
|
is 1024.
|
54
54
|
|
55
55
|
Returns:
|
56
|
-
The model config for a
|
56
|
+
The model config for a SmolLM model.
|
57
57
|
"""
|
58
58
|
attn_config = cfg.AttentionConfig(
|
59
59
|
num_heads=9,
|
@@ -86,9 +86,18 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
|
|
86
86
|
return config
|
87
87
|
|
88
88
|
|
89
|
+
def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
|
90
|
+
config = get_model_config(**kwargs)
|
91
|
+
config.vocab_size = 128
|
92
|
+
config.num_layers = 2
|
93
|
+
# SmolLM has only one block config.
|
94
|
+
config.block_config(0).ff_config.intermediate_size = 64
|
95
|
+
return config
|
96
|
+
|
97
|
+
|
89
98
|
def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
|
90
99
|
config = get_model_config(**kwargs)
|
91
|
-
model =
|
100
|
+
model = SmolLM(config)
|
92
101
|
loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
|
93
102
|
# Since embedding and lm-head use the same weight, we need to set strict
|
94
103
|
# to False.
|
@@ -98,25 +107,25 @@ def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
|
|
98
107
|
|
99
108
|
|
100
109
|
def define_and_run(checkpoint_path: str) -> None:
|
101
|
-
"""Instantiates and runs a
|
110
|
+
"""Instantiates and runs a SmolLM model."""
|
102
111
|
|
103
112
|
current_dir = pathlib.Path(__file__).parent.resolve()
|
104
|
-
|
113
|
+
smollm_goldens = torch.load(current_dir / "smollm_lm_logits.pt")
|
105
114
|
kv_cache_max_len = 1024
|
106
115
|
model = build_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
|
107
116
|
idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
|
108
|
-
tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.
|
117
|
+
tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.int, device="cpu")
|
109
118
|
tokens[0, :4] = idx
|
110
|
-
input_pos = torch.arange(0, kv_cache_max_len)
|
119
|
+
input_pos = torch.arange(0, kv_cache_max_len, dtype=torch.int)
|
111
120
|
kv = kv_utils.KVCache.from_model_config(model.config)
|
112
121
|
output = model.forward(tokens, input_pos, kv)
|
113
122
|
assert torch.allclose(
|
114
|
-
|
123
|
+
smollm_goldens, output["logits"][0, idx.shape[1] - 1, :], atol=1e-05
|
115
124
|
)
|
116
125
|
|
117
126
|
|
118
127
|
if __name__ == "__main__":
|
119
128
|
input_checkpoint_path = os.path.join(
|
120
|
-
pathlib.Path.home(), "Downloads/llm_data/
|
129
|
+
pathlib.Path.home(), "Downloads/llm_data/smollm"
|
121
130
|
)
|
122
131
|
define_and_run(input_checkpoint_path)
|
@@ -76,7 +76,7 @@ class CLIP(nn.Module):
|
|
76
76
|
|
77
77
|
@torch.inference_mode
|
78
78
|
def forward(self, tokens: torch.LongTensor) -> torch.FloatTensor:
|
79
|
-
tokens = tokens.type(torch.
|
79
|
+
tokens = tokens.type(torch.int)
|
80
80
|
|
81
81
|
state = self.tok_embedding(tokens) + self.tok_embedding_position
|
82
82
|
for layer in self.transformer_blocks:
|
@@ -94,7 +94,7 @@ def convert_stable_diffusion_to_tflite(
|
|
94
94
|
n_tokens = 77
|
95
95
|
timestamp = 0
|
96
96
|
len_prompt = 1
|
97
|
-
prompt_tokens = torch.full((1, n_tokens), 0, dtype=torch.
|
97
|
+
prompt_tokens = torch.full((1, n_tokens), 0, dtype=torch.int)
|
98
98
|
input_image = torch.full(
|
99
99
|
(1, 3, image_height, image_width), 0, dtype=torch.float32
|
100
100
|
)
|
@@ -29,24 +29,24 @@ def convert_t5_to_tflite_singlesig(checkpoint_path: str):
|
|
29
29
|
|
30
30
|
# encoder
|
31
31
|
seq_len = 512
|
32
|
-
prefill_e_tokens = torch.full((1, seq_len), 0, dtype=torch.
|
32
|
+
prefill_e_tokens = torch.full((1, seq_len), 0, dtype=torch.int)
|
33
33
|
prompt_e_token = [1, 2, 3, 4, 5, 6]
|
34
34
|
prefill_e_tokens[0, : len(prompt_e_token)] = torch.tensor(
|
35
|
-
prompt_e_token, dtype=torch.
|
35
|
+
prompt_e_token, dtype=torch.int
|
36
36
|
)
|
37
|
-
prefill_e_input_pos = torch.arange(0, seq_len)
|
38
|
-
prefill_d_tokens = torch.full((1, seq_len), 0, dtype=torch.
|
37
|
+
prefill_e_input_pos = torch.arange(0, seq_len, dtype=torch.int)
|
38
|
+
prefill_d_tokens = torch.full((1, seq_len), 0, dtype=torch.int)
|
39
39
|
prompt_d_token = [1, 2, 3, 4, 5, 6]
|
40
40
|
prefill_d_tokens[0, : len(prompt_d_token)] = torch.tensor(
|
41
|
-
prompt_d_token, dtype=torch.
|
41
|
+
prompt_d_token, dtype=torch.int
|
42
42
|
)
|
43
|
-
prefill_d_input_pos = torch.arange(0, seq_len)
|
43
|
+
prefill_d_input_pos = torch.arange(0, seq_len, dtype=torch.int)
|
44
44
|
|
45
45
|
# decoder
|
46
|
-
decode_token = torch.tensor([[1]], dtype=torch.
|
47
|
-
decode_input_pos = torch.tensor([0], dtype=torch.
|
48
|
-
decode_d_token = torch.tensor([[1]], dtype=torch.
|
49
|
-
decode_d_input_pos = torch.tensor([0], dtype=torch.
|
46
|
+
decode_token = torch.tensor([[1]], dtype=torch.int)
|
47
|
+
decode_input_pos = torch.tensor([0], dtype=torch.int)
|
48
|
+
decode_d_token = torch.tensor([[1]], dtype=torch.int)
|
49
|
+
decode_d_input_pos = torch.tensor([0], dtype=torch.int)
|
50
50
|
|
51
51
|
# Pad mask for self attention only on "real" tokens.
|
52
52
|
# Pad with `-inf` for any tokens indices that aren't desired.
|
@@ -81,24 +81,24 @@ def convert_t5_to_tflite_multisig(checkpoint_path: str):
|
|
81
81
|
|
82
82
|
# encoder
|
83
83
|
seq_len = 512
|
84
|
-
prefill_e_tokens = torch.full((1, seq_len), 0, dtype=torch.
|
84
|
+
prefill_e_tokens = torch.full((1, seq_len), 0, dtype=torch.int)
|
85
85
|
prompt_e_token = [1, 2, 3, 4, 5, 6]
|
86
86
|
prefill_e_tokens[0, : len(prompt_e_token)] = torch.tensor(
|
87
|
-
prompt_e_token, dtype=torch.
|
87
|
+
prompt_e_token, dtype=torch.int
|
88
88
|
)
|
89
|
-
prefill_e_input_pos = torch.arange(0, seq_len)
|
90
|
-
prefill_d_tokens = torch.full((1, seq_len), 0, dtype=torch.
|
89
|
+
prefill_e_input_pos = torch.arange(0, seq_len, dtype=torch.int)
|
90
|
+
prefill_d_tokens = torch.full((1, seq_len), 0, dtype=torch.int)
|
91
91
|
prompt_d_token = [1, 2, 3, 4, 5, 6]
|
92
92
|
prefill_d_tokens[0, : len(prompt_d_token)] = torch.tensor(
|
93
|
-
prompt_d_token, dtype=torch.
|
93
|
+
prompt_d_token, dtype=torch.int
|
94
94
|
)
|
95
|
-
prefill_d_input_pos = torch.arange(0, seq_len)
|
95
|
+
prefill_d_input_pos = torch.arange(0, seq_len, dtype=torch.int)
|
96
96
|
|
97
97
|
# decoder
|
98
|
-
decode_token = torch.tensor([[1]], dtype=torch.
|
99
|
-
decode_input_pos = torch.tensor([0], dtype=torch.
|
100
|
-
decode_d_token = torch.tensor([[1]], dtype=torch.
|
101
|
-
decode_d_input_pos = torch.tensor([0], dtype=torch.
|
98
|
+
decode_token = torch.tensor([[1]], dtype=torch.int)
|
99
|
+
decode_input_pos = torch.tensor([0], dtype=torch.int)
|
100
|
+
decode_d_token = torch.tensor([[1]], dtype=torch.int)
|
101
|
+
decode_d_input_pos = torch.tensor([0], dtype=torch.int)
|
102
102
|
|
103
103
|
# Pad mask for self attention only on "real" tokens.
|
104
104
|
# Pad with `-inf` for any tokens indices that aren't desired.
|
@@ -601,12 +601,12 @@ def define_and_run_t5(checkpoint_path: str) -> None:
|
|
601
601
|
model = build_t5_model(checkpoint_path)
|
602
602
|
|
603
603
|
idx = get_sample_encoder_input_ids()
|
604
|
-
tokens = torch.full((1, 512), 0, dtype=torch.
|
604
|
+
tokens = torch.full((1, 512), 0, dtype=torch.int, device="cpu")
|
605
605
|
tokens[0, :77] = idx
|
606
|
-
input_pos = torch.arange(0, 512)
|
606
|
+
input_pos = torch.arange(0, 512, dtype=torch.int)
|
607
607
|
|
608
|
-
decode_d_token = torch.tensor([[0]], dtype=torch.
|
609
|
-
decode_d_input_pos = torch.tensor([0], dtype=torch.
|
608
|
+
decode_d_token = torch.tensor([[0]], dtype=torch.int)
|
609
|
+
decode_d_input_pos = torch.tensor([0], dtype=torch.int)
|
610
610
|
pad_mask = torch.zeros([model.config.kv_cache_max], dtype=torch.float32)
|
611
611
|
pad_mask[77:] = float("-inf")
|
612
612
|
lm_logits = model.forward(
|
@@ -633,12 +633,12 @@ def define_and_run_t5_split(checkpoint_path: str) -> None:
|
|
633
633
|
)
|
634
634
|
idx = get_sample_encoder_input_ids()
|
635
635
|
|
636
|
-
tokens = torch.full((1, 512), 0, dtype=torch.
|
636
|
+
tokens = torch.full((1, 512), 0, dtype=torch.int, device="cpu")
|
637
637
|
tokens[0, :77] = idx
|
638
|
-
input_pos = torch.arange(0, 512)
|
638
|
+
input_pos = torch.arange(0, 512, dtype=torch.int)
|
639
639
|
|
640
|
-
decode_d_token = torch.tensor([[0]], dtype=torch.
|
641
|
-
decode_d_input_pos = torch.tensor([0], dtype=torch.
|
640
|
+
decode_d_token = torch.tensor([[0]], dtype=torch.int)
|
641
|
+
decode_d_input_pos = torch.tensor([0], dtype=torch.int)
|
642
642
|
pad_mask = torch.zeros(
|
643
643
|
[t5_encoder_model.config.kv_cache_max], dtype=torch.float32
|
644
644
|
)
|
@@ -124,13 +124,13 @@ def get_model_config() -> cfg.ModelConfig:
|
|
124
124
|
|
125
125
|
|
126
126
|
def get_sample_prefill_inputs() -> Tuple[torch.Tensor, torch.Tensor]:
|
127
|
-
tokens = torch.unsqueeze(torch.arange(0, 100), 0)
|
128
|
-
input_pos = torch.arange(0, 100)
|
127
|
+
tokens = torch.unsqueeze(torch.arange(0, 100, dtype=torch.int), 0)
|
128
|
+
input_pos = torch.arange(0, 100, dtype=torch.int)
|
129
129
|
return tokens, input_pos
|
130
130
|
|
131
131
|
|
132
132
|
def get_sample_decode_inputs() -> Tuple[torch.Tensor, torch.Tensor]:
|
133
|
-
tokens = torch.tensor([[1]], dtype=torch.
|
133
|
+
tokens = torch.tensor([[1]], dtype=torch.int)
|
134
134
|
input_pos = torch.tensor([10])
|
135
135
|
return tokens, input_pos
|
136
136
|
|
@@ -47,10 +47,10 @@ def convert_tiny_llama_to_tflite(
|
|
47
47
|
checkpoint_path, kv_cache_max_len=kv_cache_max_len
|
48
48
|
)
|
49
49
|
# Tensors used to trace the model graph during conversion.
|
50
|
-
prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.
|
51
|
-
prefill_input_pos = torch.arange(0, prefill_seq_len)
|
52
|
-
decode_token = torch.tensor([[0]], dtype=torch.
|
53
|
-
decode_input_pos = torch.tensor([0], dtype=torch.
|
50
|
+
prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.int)
|
51
|
+
prefill_input_pos = torch.arange(0, prefill_seq_len, dtype=torch.int)
|
52
|
+
decode_token = torch.tensor([[0]], dtype=torch.int)
|
53
|
+
decode_input_pos = torch.tensor([0], dtype=torch.int)
|
54
54
|
kv = kv_utils.KVCache.from_model_config(pytorch_model.config)
|
55
55
|
|
56
56
|
quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
|
@@ -189,9 +189,9 @@ def define_and_run(checkpoint_path: str) -> None:
|
|
189
189
|
kv_cache_max_len = 1024
|
190
190
|
model = build_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
|
191
191
|
idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
|
192
|
-
tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.
|
192
|
+
tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.int, device="cpu")
|
193
193
|
tokens[0, :4] = idx
|
194
|
-
input_pos = torch.arange(0, kv_cache_max_len)
|
194
|
+
input_pos = torch.arange(0, kv_cache_max_len, dtype=torch.int)
|
195
195
|
kv = kv_utils.KVCache.from_model_config(model.config)
|
196
196
|
output = model.forward(tokens, input_pos, kv)
|
197
197
|
assert torch.allclose(
|
@@ -12,16 +12,16 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
|
-
from ai_edge_torch
|
16
|
-
from ai_edge_torch.
|
17
|
-
from ai_edge_torch.generative.fx_passes.remove_sdpa_zero_mask_pass import RemoveSDPACompositeZeroMaskPass
|
15
|
+
from ai_edge_torch import fx_pass_base
|
16
|
+
from ai_edge_torch.fx_pass_base import CanonicalizePass
|
17
|
+
from ai_edge_torch.generative.fx_passes.remove_sdpa_zero_mask_pass import RemoveSDPACompositeZeroMaskPass
|
18
18
|
import torch
|
19
19
|
|
20
20
|
|
21
21
|
def run_generative_passes(
|
22
22
|
exported_program: torch.export.ExportedProgram,
|
23
23
|
) -> torch.export.ExportedProgram:
|
24
|
-
return run_passes(
|
24
|
+
return fx_pass_base.run_passes(
|
25
25
|
exported_program,
|
26
26
|
[
|
27
27
|
RemoveSDPACompositeZeroMaskPass(),
|
@@ -12,13 +12,12 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
|
+
from ai_edge_torch import fx_pass_base
|
15
16
|
from ai_edge_torch import lowertools
|
16
|
-
from ai_edge_torch._convert.fx_passes._pass_base import ExportedProgramPassBase
|
17
|
-
from ai_edge_torch._convert.fx_passes._pass_base import ExportedProgramPassResult
|
18
17
|
import torch
|
19
18
|
|
20
19
|
|
21
|
-
class RemoveSDPACompositeZeroMaskPass(ExportedProgramPassBase):
|
20
|
+
class RemoveSDPACompositeZeroMaskPass(fx_pass_base.ExportedProgramPassBase):
|
22
21
|
|
23
22
|
def is_zero_tensor_node(self, node: torch.fx.Node):
|
24
23
|
return node.target == torch.ops.aten.zeros.default
|
@@ -48,4 +47,4 @@ class RemoveSDPACompositeZeroMaskPass(ExportedProgramPassBase):
|
|
48
47
|
|
49
48
|
exported_program.graph_module.graph.lint()
|
50
49
|
exported_program.graph_module.recompile()
|
51
|
-
return ExportedProgramPassResult(exported_program, True)
|
50
|
+
return fx_pass_base.ExportedProgramPassResult(exported_program, True)
|
@@ -160,6 +160,10 @@ class CausalSelfAttention(nn.Module):
|
|
160
160
|
self.output_projection = nn.Linear(
|
161
161
|
output_shape, dim, bias=config.output_proj_use_bias
|
162
162
|
)
|
163
|
+
self.query_norm = builder.build_norm(
|
164
|
+
config.head_dim, config.query_norm_config
|
165
|
+
)
|
166
|
+
self.key_norm = builder.build_norm(config.head_dim, config.key_norm_config)
|
163
167
|
self.config = config
|
164
168
|
self.enable_hlfb = enable_hlfb
|
165
169
|
self.sdpa_func = (
|
@@ -224,6 +228,9 @@ class CausalSelfAttention(nn.Module):
|
|
224
228
|
dim=-1,
|
225
229
|
)
|
226
230
|
|
231
|
+
q = self.query_norm(q)
|
232
|
+
k = self.key_norm(k)
|
233
|
+
|
227
234
|
q = q.reshape(B, T, -1, self.config.head_dim)
|
228
235
|
k = k.reshape(B, T, -1, self.config.head_dim)
|
229
236
|
v = v.reshape(B, T, -1, self.config.head_dim)
|
@@ -13,6 +13,8 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
15
|
# Builder class for individual components.
|
16
|
+
from typing import Callable
|
17
|
+
|
16
18
|
import ai_edge_torch.generative.layers.feed_forward as feed_forward
|
17
19
|
import ai_edge_torch.generative.layers.model_config as cfg
|
18
20
|
import ai_edge_torch.generative.layers.normalization as normalization
|
@@ -21,20 +23,34 @@ from torch import nn
|
|
21
23
|
import torch.nn.functional as F
|
22
24
|
|
23
25
|
|
24
|
-
|
25
|
-
|
26
|
+
def build_glu(
|
27
|
+
act: Callable[[torch.Tensor], torch.Tensor], gate_is_front: bool = False
|
28
|
+
) -> Callable[[torch.Tensor], torch.Tensor]:
|
29
|
+
"""Builds an activation function with GLU (Gated Linear Unit).
|
30
|
+
|
31
|
+
If gate_is_front is True,
|
32
|
+
f(x) = act(x) * y
|
33
|
+
otherwise,
|
34
|
+
f(x) = x * act(y),
|
35
|
+
where x is the first half of the input and y is the second half of the input.
|
26
36
|
|
27
|
-
|
28
|
-
|
37
|
+
Args:
|
38
|
+
act (Callable[[torch.Tensor], torch.Tensor]): activation function to apply
|
39
|
+
to the gate.
|
40
|
+
gate_is_front: whether the gate is in front half of the input. Other part is
|
41
|
+
the output in GLU.
|
42
|
+
|
43
|
+
Returns:
|
44
|
+
A callable activation function with GLU.
|
29
45
|
"""
|
30
46
|
|
31
|
-
def
|
32
|
-
|
33
|
-
|
47
|
+
def _glu(x):
|
48
|
+
x, y = x.chunk(2, dim=-1)
|
49
|
+
if gate_is_front:
|
50
|
+
return act(x) * y
|
51
|
+
return x * act(y)
|
34
52
|
|
35
|
-
|
36
|
-
x, gate = self.proj(x).chunk(2, dim=-1)
|
37
|
-
return x * F.gelu(gate)
|
53
|
+
return _glu
|
38
54
|
|
39
55
|
|
40
56
|
def build_norm(dim: int, config: cfg.NormalizationConfig):
|
@@ -99,6 +115,10 @@ def build_ff(dim: int, config: cfg.FeedForwardConfig):
|
|
99
115
|
hidden_dim=config.intermediate_size,
|
100
116
|
activation=activation,
|
101
117
|
use_bias=config.use_bias,
|
118
|
+
use_glu=(
|
119
|
+
config.activation.type == cfg.ActivationType.GE_GLU
|
120
|
+
or config.activation.type == cfg.ActivationType.SILU_GLU
|
121
|
+
),
|
102
122
|
pre_ff_norm=pre_ff_norm,
|
103
123
|
post_ff_norm=post_ff_norm,
|
104
124
|
)
|
@@ -129,8 +149,10 @@ def get_activation(config: cfg.ActivationConfig):
|
|
129
149
|
# See: https://github.com/hendrycks/GELUs
|
130
150
|
return lambda x: x * F.sigmoid(1.702 * x)
|
131
151
|
elif config.type == cfg.ActivationType.GE_GLU:
|
132
|
-
return
|
152
|
+
return build_glu(F.gelu, config.gate_is_front)
|
133
153
|
elif config.type == cfg.ActivationType.RELU:
|
134
154
|
return F.relu
|
155
|
+
elif config.type == cfg.ActivationType.SILU_GLU:
|
156
|
+
return build_glu(F.silu, config.gate_is_front)
|
135
157
|
else:
|
136
158
|
raise ValueError("Unsupported activation type.")
|
@@ -30,18 +30,27 @@ class SequentialFeedForward(nn.Module):
|
|
30
30
|
hidden_dim: int,
|
31
31
|
activation: Callable[[torch.Tensor], torch.Tensor],
|
32
32
|
use_bias=False,
|
33
|
+
use_glu=False,
|
33
34
|
pre_ff_norm: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
|
34
35
|
post_ff_norm: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
|
35
36
|
):
|
36
37
|
"""Init function for feedforward layer.
|
37
38
|
|
38
|
-
Args:
|
39
|
-
|
40
|
-
|
39
|
+
Args:
|
40
|
+
dim (int): embedding size.
|
41
|
+
hidden_dim (int): hidden dim size of the feedforward layer.
|
42
|
+
activation (Callable): activation function used in this block.
|
43
|
+
use_bias (Boolean): whether to use bias. Default is false.
|
44
|
+
use_glu (Boolean): whether to use glu in activation. Default is false.
|
45
|
+
pre_ff_norm (Callable): pre feedforward norm. Default is None.
|
46
|
+
post_ff_norm (Callable): post feedforward norm. Default is None.
|
41
47
|
"""
|
42
48
|
super().__init__()
|
43
49
|
self.act = activation
|
44
|
-
|
50
|
+
if use_glu:
|
51
|
+
self.w1 = nn.Linear(dim, hidden_dim * 2, bias=use_bias)
|
52
|
+
else:
|
53
|
+
self.w1 = nn.Linear(dim, hidden_dim, bias=use_bias)
|
45
54
|
self.w2 = nn.Linear(hidden_dim, dim, bias=use_bias)
|
46
55
|
self.pre_ff_norm = pre_ff_norm if pre_ff_norm else lambda x: x
|
47
56
|
self.post_ff_norm = post_ff_norm if post_ff_norm else lambda x: x
|
@@ -72,18 +81,27 @@ class GatedFeedForward(nn.Module):
|
|
72
81
|
hidden_dim: int,
|
73
82
|
activation: Callable[[torch.Tensor], torch.Tensor],
|
74
83
|
use_bias=False,
|
84
|
+
use_glu=False,
|
75
85
|
pre_ff_norm: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
|
76
86
|
post_ff_norm: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
|
77
87
|
):
|
78
88
|
"""Init function for feedforward layer.
|
79
89
|
|
80
|
-
Args:
|
81
|
-
|
82
|
-
|
90
|
+
Args:
|
91
|
+
dim (int): embedding size.
|
92
|
+
hidden_dim (int): hidden dim size of the feedforward layer.
|
93
|
+
activation (Callable): activation function used in this block.
|
94
|
+
use_bias (Boolean): whether to use bias. Default is false.
|
95
|
+
use_glu (Boolean): whether to use glu in activation. Default is false.
|
96
|
+
pre_ff_norm (Callable): pre feedforward norm. Default is None.
|
97
|
+
post_ff_norm (Callable): post feedforward norm. Default is None.
|
83
98
|
"""
|
84
99
|
super().__init__()
|
85
100
|
self.act = activation
|
86
|
-
|
101
|
+
if use_glu:
|
102
|
+
self.w1 = nn.Linear(dim, hidden_dim * 2, bias=use_bias)
|
103
|
+
else:
|
104
|
+
self.w1 = nn.Linear(dim, hidden_dim, bias=use_bias)
|
87
105
|
self.w2 = nn.Linear(hidden_dim, dim, bias=use_bias)
|
88
106
|
self.w3 = nn.Linear(dim, hidden_dim, bias=use_bias)
|
89
107
|
self.pre_ff_norm = pre_ff_norm if pre_ff_norm else lambda x: x
|
@@ -172,8 +172,8 @@ def _update_kv_base_impl(
|
|
172
172
|
v_slice: torch.Tensor,
|
173
173
|
) -> KVCacheEntry:
|
174
174
|
"""Update the cache buffer without High Level Function Boundary annotation."""
|
175
|
-
k = cache.k_cache.index_copy(1, input_pos, k_slice)
|
176
|
-
v = cache.v_cache.index_copy(1, input_pos, v_slice)
|
175
|
+
k = cache.k_cache.index_copy(1, input_pos.to(torch.long), k_slice)
|
176
|
+
v = cache.v_cache.index_copy(1, input_pos.to(torch.long), v_slice)
|
177
177
|
updated_cache = KVCacheEntry(k, v)
|
178
178
|
return updated_cache
|
179
179
|
|
@@ -189,7 +189,7 @@ def _update_kv_hlfb_impl(
|
|
189
189
|
k_cache, v_cache, input_pos, k_slice, v_slice = builder.mark_inputs(
|
190
190
|
cache.k_cache, cache.v_cache, input_pos, k_slice, v_slice
|
191
191
|
)
|
192
|
-
k = k_cache.index_copy(1, input_pos, k_slice)
|
193
|
-
v = v_cache.index_copy(1, input_pos, v_slice)
|
192
|
+
k = k_cache.index_copy(1, input_pos.to(torch.long), k_slice)
|
193
|
+
v = v_cache.index_copy(1, input_pos.to(torch.long), v_slice)
|
194
194
|
k, v = builder.mark_outputs(k, v)
|
195
195
|
return KVCacheEntry(k, v)
|