ai-edge-torch-nightly 0.1.dev202405131930__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_torch/__init__.py +30 -0
- ai_edge_torch/convert/__init__.py +14 -0
- ai_edge_torch/convert/conversion.py +117 -0
- ai_edge_torch/convert/conversion_utils.py +330 -0
- ai_edge_torch/convert/converter.py +171 -0
- ai_edge_torch/convert/fx_passes/__init__.py +59 -0
- ai_edge_torch/convert/fx_passes/_pass_base.py +49 -0
- ai_edge_torch/convert/fx_passes/build_aten_composite_pass.py +192 -0
- ai_edge_torch/convert/fx_passes/build_upsample_bilinear2d_composite_pass.py +84 -0
- ai_edge_torch/convert/fx_passes/canonicalize_pass.py +37 -0
- ai_edge_torch/convert/fx_passes/inject_mlir_debuginfo_pass.py +73 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/__init__.py +16 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_check.py +215 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_mark.py +48 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/__init__.py +17 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py +59 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py +196 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +400 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/op_func_registry.py +30 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/pass_body.py +286 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/utils.py +62 -0
- ai_edge_torch/convert/test/__init__.py +14 -0
- ai_edge_torch/convert/test/test_convert.py +273 -0
- ai_edge_torch/convert/test/test_convert_composites.py +171 -0
- ai_edge_torch/convert/test/test_convert_multisig.py +139 -0
- ai_edge_torch/debug/__init__.py +16 -0
- ai_edge_torch/debug/culprit.py +423 -0
- ai_edge_torch/debug/test/__init__.py +14 -0
- ai_edge_torch/debug/test/test_culprit.py +133 -0
- ai_edge_torch/debug/utils.py +48 -0
- ai_edge_torch/experimental/__init__.py +14 -0
- ai_edge_torch/generative/__init__.py +14 -0
- ai_edge_torch/generative/examples/__init__.py +14 -0
- ai_edge_torch/generative/examples/gemma/__init__.py +14 -0
- ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +66 -0
- ai_edge_torch/generative/examples/gemma/gemma.py +174 -0
- ai_edge_torch/generative/examples/phi2/__init__.py +14 -0
- ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +64 -0
- ai_edge_torch/generative/examples/phi2/phi2.py +164 -0
- ai_edge_torch/generative/examples/t5/__init__.py +14 -0
- ai_edge_torch/generative/examples/t5/convert_to_tflite.py +135 -0
- ai_edge_torch/generative/examples/t5/t5.py +608 -0
- ai_edge_torch/generative/examples/t5/t5_attention.py +255 -0
- ai_edge_torch/generative/examples/test_models/__init__.py +14 -0
- ai_edge_torch/generative/examples/test_models/toy_model.py +119 -0
- ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +143 -0
- ai_edge_torch/generative/examples/tiny_llama/__init__.py +0 -0
- ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +66 -0
- ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +164 -0
- ai_edge_torch/generative/layers/__init__.py +14 -0
- ai_edge_torch/generative/layers/attention.py +288 -0
- ai_edge_torch/generative/layers/attention_utils.py +169 -0
- ai_edge_torch/generative/layers/builder.py +103 -0
- ai_edge_torch/generative/layers/feed_forward.py +95 -0
- ai_edge_torch/generative/layers/kv_cache.py +83 -0
- ai_edge_torch/generative/layers/model_config.py +135 -0
- ai_edge_torch/generative/layers/normalization.py +62 -0
- ai_edge_torch/generative/layers/rotary_position_embedding.py +36 -0
- ai_edge_torch/generative/quantize/__init__.py +14 -0
- ai_edge_torch/generative/quantize/example.py +45 -0
- ai_edge_torch/generative/quantize/quant_attrs.py +66 -0
- ai_edge_torch/generative/quantize/quant_recipe.py +106 -0
- ai_edge_torch/generative/quantize/quant_recipe_utils.py +51 -0
- ai_edge_torch/generative/quantize/quant_recipes.py +48 -0
- ai_edge_torch/generative/quantize/supported_schemes.py +31 -0
- ai_edge_torch/generative/test/__init__.py +14 -0
- ai_edge_torch/generative/test/test_model_conversion.py +201 -0
- ai_edge_torch/generative/test/test_quantize.py +109 -0
- ai_edge_torch/generative/utilities/__init__.py +15 -0
- ai_edge_torch/generative/utilities/loader.py +290 -0
- ai_edge_torch/generative/utilities/t5_loader.py +467 -0
- ai_edge_torch/hlfb/__init__.py +16 -0
- ai_edge_torch/hlfb/mark_pattern/__init__.py +139 -0
- ai_edge_torch/hlfb/mark_pattern/passes.py +42 -0
- ai_edge_torch/hlfb/mark_pattern/pattern.py +260 -0
- ai_edge_torch/hlfb/test/__init__.py +14 -0
- ai_edge_torch/hlfb/test/test_mark_pattern.py +133 -0
- ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py +270 -0
- ai_edge_torch/model.py +134 -0
- ai_edge_torch/quantize/__init__.py +16 -0
- ai_edge_torch/quantize/pt2e_quantizer.py +438 -0
- ai_edge_torch/quantize/pt2e_quantizer_utils.py +1041 -0
- ai_edge_torch/quantize/quant_config.py +85 -0
- ai_edge_torch/testing/__init__.py +14 -0
- ai_edge_torch/testing/model_coverage/__init__.py +16 -0
- ai_edge_torch/testing/model_coverage/model_coverage.py +126 -0
- ai_edge_torch_nightly-0.1.dev202405131930.dist-info/LICENSE +202 -0
- ai_edge_torch_nightly-0.1.dev202405131930.dist-info/METADATA +38 -0
- ai_edge_torch_nightly-0.1.dev202405131930.dist-info/RECORD +91 -0
- ai_edge_torch_nightly-0.1.dev202405131930.dist-info/WHEEL +5 -0
- ai_edge_torch_nightly-0.1.dev202405131930.dist-info/top_level.txt +1 -0
ai_edge_torch/generative/examples/gemma/convert_to_tflite.py
@@ -0,0 +1,66 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+import os
+from pathlib import Path
+
+import torch
+
+import ai_edge_torch
+from ai_edge_torch.generative.examples.gemma import gemma
+from ai_edge_torch.generative.quantize import quant_recipes
+
+
+def convert_gemma_to_tflite(
+    checkpoint_path: str,
+    prefill_seq_len: int = 512,
+    kv_cache_max_len: int = 1024,
+    quantize: bool = True,
+):
+  """An example method for converting a Gemma 2B model to a multi-signature
+  tflite model.
+
+  Args:
+    checkpoint_path (str): The filepath to the model checkpoint, or directory holding the checkpoint.
+    prefill_seq_len (int, optional): The maximum size of prefill input tensor.
+      Defaults to 512.
+    kv_cache_max_len (int, optional): The maximum size of KV cache buffer,
+      including both prefill and decode. Defaults to 1024.
+    quantize (bool, optional): Whether the model should be quantized.
+      Defaults to True.
+  """
+  pytorch_model = gemma.build_2b_model(
+      checkpoint_path, kv_cache_max_len=kv_cache_max_len
+  )
+  # Tensors used to trace the model graph during conversion.
+  prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.long)
+  prefill_input_pos = torch.arange(0, prefill_seq_len)
+  decode_token = torch.tensor([[0]], dtype=torch.long)
+  decode_input_pos = torch.tensor([0], dtype=torch.int64)
+
+  quant_config = quant_recipes.full_linear_int8_dynamic_recipe() if quantize else None
+  edge_model = (
+      ai_edge_torch.signature(
+          'prefill', pytorch_model, (prefill_tokens, prefill_input_pos)
+      )
+      .signature('decode', pytorch_model, (decode_token, decode_input_pos))
+      .convert(quant_config=quant_config)
+  )
+  edge_model.export(f'/tmp/gemma_seq{prefill_seq_len}_kv{kv_cache_max_len}.tflite')
+
+
+if __name__ == '__main__':
+  checkpoint_path = os.path.join(Path.home(), 'Downloads/llm_data/gemma-2b')
+  convert_gemma_to_tflite(checkpoint_path)
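Each convert script in this release traces the same PyTorch module twice and packs both entry points into one .tflite file. As a rough sketch (not part of this package) of consuming that artifact, TensorFlow Lite's Python interpreter exposes one runner per signature; the keyword argument names below ('tokens', 'input_pos') are assumptions, so read the real names from get_input_details():

import numpy as np
import tensorflow as tf

# Load the file written by convert_gemma_to_tflite() above.
interpreter = tf.lite.Interpreter(model_path='/tmp/gemma_seq512_kv1024.tflite')
prefill = interpreter.get_signature_runner('prefill')
decode = interpreter.get_signature_runner('decode')

# The converter chooses the input names; inspect them before calling.
print(prefill.get_input_details())

# Hypothetical calls, assuming the traced argument names survived conversion.
prefill(tokens=np.zeros((1, 512), dtype=np.int64),
        input_pos=np.arange(512, dtype=np.int64))
logits = decode(tokens=np.array([[0]], dtype=np.int64),
                input_pos=np.array([0], dtype=np.int64))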
ai_edge_torch/generative/examples/gemma/gemma.py
@@ -0,0 +1,174 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+# Example of building a Gemma model.
+
+import os
+from pathlib import Path
+
+import numpy as np
+import torch
+import torch.nn as nn
+
+from ai_edge_torch.generative.layers.attention import TransformerBlock
+import ai_edge_torch.generative.layers.attention_utils as attn_utils
+import ai_edge_torch.generative.layers.builder as builder
+import ai_edge_torch.generative.layers.model_config as cfg
+import ai_edge_torch.generative.utilities.loader as loading_utils
+
+TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
+    ff_up_proj="model.layers.{}.mlp.up_proj",
+    ff_down_proj="model.layers.{}.mlp.down_proj",
+    ff_gate_proj="model.layers.{}.mlp.gate_proj",
+    attn_query_proj="model.layers.{}.self_attn.q_proj",
+    attn_key_proj="model.layers.{}.self_attn.k_proj",
+    attn_value_proj="model.layers.{}.self_attn.v_proj",
+    attn_output_proj="model.layers.{}.self_attn.o_proj",
+    pre_attn_norm="model.layers.{}.input_layernorm",
+    pre_ff_norm="model.layers.{}.post_attention_layernorm",
+    embedding="model.embed_tokens",
+    final_norm="model.norm",
+    lm_head=None,
+)
+
+
+class Gemma(nn.Module):
+
+  def __init__(self, config: cfg.ModelConfig):
+    super().__init__()
+
+    self.config = config
+    # Construct model layers.
+    self.tok_embedding = nn.Embedding(
+        config.vocab_size, config.embedding_dim, padding_idx=0
+    )
+    self.lm_head = nn.Linear(
+        config.embedding_dim,
+        config.vocab_size,
+        bias=config.lm_head_use_bias,
+    )
+    # Gemma re-uses the embedding as the head projection layer.
+    self.lm_head.weight.data = self.tok_embedding.weight.data
+    self.transformer_blocks = nn.ModuleList(
+        TransformerBlock(config) for _ in range(config.num_layers)
+    )
+    self.final_norm = builder.build_norm(
+        config.embedding_dim,
+        config.final_norm_config,
+    )
+    self.rope_cache = attn_utils.build_rope_cache(
+        size=config.kv_cache_max,
+        dim=int(config.attn_config.rotary_percentage * config.head_dim),
+        base=10_000,
+        condense_ratio=1,
+        dtype=torch.float32,
+        device=torch.device("cpu"),
+    )
+    self.mask_cache = attn_utils.build_causal_mask_cache(
+        size=config.kv_cache_max, dtype=torch.float32, device=torch.device("cpu")
+    )
+    self.config = config
+
+  # The model's forward function takes in additional k/v cache tensors
+  # and returns the updated k/v cache tensors to the caller.
+  # This can be eliminated if we handle k/v cache updates inside the model itself.
+  @torch.inference_mode
+  def forward(self, idx: torch.Tensor, input_pos: torch.Tensor) -> torch.Tensor:
+    B, T = idx.size()
+    assert (
+        self.config.max_seq_len >= T
+    ), f"Cannot forward sequence of length {T}, max seq length is only {self.config.max_seq_len}"
+
+    cos, sin = self.rope_cache
+    cos = cos.index_select(0, input_pos)
+    sin = sin.index_select(0, input_pos)
+    mask = self.mask_cache.index_select(2, input_pos)
+    mask = mask[:, :, :, : self.config.kv_cache_max]
+
+    # token embeddings of shape (b, t, n_embd)
+    x = self.tok_embedding(idx)
+    x = x * (self.config.embedding_dim**0.5)
+
+    for i, block in enumerate(self.transformer_blocks):
+      x = block(x, (cos, sin), mask, input_pos)
+
+    x = self.final_norm(x)
+    res = self.lm_head(x)  # (b, t, vocab_size)
+    return res
+
+
+def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
+  attn_config = cfg.AttentionConfig(
+      num_heads=8,
+      num_query_groups=1,
+      rotary_percentage=1.0,
+  )
+  ff_config = cfg.FeedForwardConfig(
+      type=cfg.FeedForwardType.GATED,
+      activation=cfg.ActivationType.GELU_TANH,
+      intermediate_size=16384,
+  )
+  norm_config = cfg.NormalizationConfig(
+      type=cfg.NormalizationType.RMS_NORM,
+      epsilon=1e-6,
+      zero_centered=True,
+  )
+  config = cfg.ModelConfig(
+      vocab_size=256000,
+      num_layers=18,
+      max_seq_len=8192,
+      embedding_dim=2048,
+      kv_cache_max_len=kv_cache_max_len,
+      attn_config=attn_config,
+      ff_config=ff_config,
+      pre_attention_norm_config=norm_config,
+      pre_ff_norm_config=norm_config,
+      final_norm_config=norm_config,
+      parallel_residual=False,
+      lm_head_use_bias=False,
+      enable_hlfb=True,
+  )
+  return config
+
+
+def get_fake_model_config_2b_for_test() -> cfg.ModelConfig:
+  config = get_model_config_2b()
+  config.num_layers = 2
+  return config
+
+
+def build_2b_model(checkpoint_path, **kwargs) -> nn.Module:
+  config = get_model_config_2b(**kwargs)
+  model = Gemma(config)
+  loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
+  # Since the embedding and lm_head share the same weight, we need to set
+  # strict to False.
+  loader.load(model, strict=False)
+  return model
+
+
+def define_and_run_2b() -> None:
+  kv_cache_max_len = 1024
+  checkpoint_path = os.path.join(Path.home(), "Downloads/llm_data/gemma-2b")
+  model = build_2b_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
+  idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
+  tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.long, device="cpu")
+  tokens[0, :4] = idx
+  input_pos = torch.arange(0, kv_cache_max_len)
+  print("running an inference")
+  print(model.forward(tokens, input_pos))
+
+
+if __name__ == "__main__":
+  define_and_run_2b()
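Two details in gemma.py above are easy to miss: the forward pass scales the token embeddings by embedding_dim**0.5, and the output projection is tied to the token embedding, which is why build_2b_model() loads the checkpoint with strict=False (TENSOR_NAMES maps lm_head to None, so no head weights exist on disk). A minimal standalone sketch of that weight tying, using the 2B dimensions from get_model_config_2b():

import torch
import torch.nn as nn

vocab_size, embedding_dim = 256000, 2048
tok_embedding = nn.Embedding(vocab_size, embedding_dim)
lm_head = nn.Linear(embedding_dim, vocab_size, bias=False)

# Point the head at the embedding's storage, as Gemma.__init__ does above.
lm_head.weight.data = tok_embedding.weight.data

# Both modules now share one tensor: loading the embedding from a checkpoint
# also populates the head, so the loader must not require an lm_head entry.
assert lm_head.weight.data_ptr() == tok_embedding.weight.data_ptr()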
ai_edge_torch/generative/examples/phi2/__init__.py
@@ -0,0 +1,14 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
ai_edge_torch/generative/examples/phi2/convert_to_tflite.py
@@ -0,0 +1,64 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+import os
+from pathlib import Path
+
+import torch
+
+import ai_edge_torch
+from ai_edge_torch.generative.examples.phi2 import phi2
+from ai_edge_torch.generative.quantize import quant_recipes
+
+
+def convert_phi2_to_tflite(
+    checkpoint_path: str,
+    prefill_seq_len: int = 512,
+    kv_cache_max_len: int = 1024,
+    quantize: bool = True,
+):
+  """An example method for converting a Phi-2 model to a multi-signature
+  tflite model.
+
+  Args:
+    checkpoint_path (str): The filepath to the model checkpoint, or directory holding the checkpoint.
+    prefill_seq_len (int, optional): The maximum size of prefill input tensor.
+      Defaults to 512.
+    kv_cache_max_len (int, optional): The maximum size of KV cache buffer,
+      including both prefill and decode. Defaults to 1024.
+    quantize (bool, optional): Whether the model should be quantized.
+      Defaults to True.
+  """
+  pytorch_model = phi2.build_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
+  # Tensors used to trace the model graph during conversion.
+  prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.long)
+  prefill_input_pos = torch.arange(0, prefill_seq_len)
+  decode_token = torch.tensor([[0]], dtype=torch.long)
+  decode_input_pos = torch.tensor([0], dtype=torch.int64)
+
+  quant_config = quant_recipes.full_linear_int8_dynamic_recipe() if quantize else None
+  edge_model = (
+      ai_edge_torch.signature(
+          'prefill', pytorch_model, (prefill_tokens, prefill_input_pos)
+      )
+      .signature('decode', pytorch_model, (decode_token, decode_input_pos))
+      .convert(quant_config=quant_config)
+  )
+  edge_model.export(f'/tmp/phi2_seq{prefill_seq_len}_kv{kv_cache_max_len}.tflite')
+
+
+if __name__ == '__main__':
+  checkpoint_path = os.path.join(Path.home(), 'Downloads/llm_data/phi2')
+  convert_phi2_to_tflite(checkpoint_path)
ai_edge_torch/generative/examples/phi2/phi2.py
@@ -0,0 +1,164 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+# Example of building a phi-2 model from the Edge Generative API layers.
+
+
+import os
+from pathlib import Path
+
+import numpy as np
+import torch
+import torch.nn as nn
+
+from ai_edge_torch.generative.layers.attention import TransformerBlock
+import ai_edge_torch.generative.layers.attention_utils as attn_utils
+import ai_edge_torch.generative.layers.builder as builder
+import ai_edge_torch.generative.layers.model_config as cfg
+import ai_edge_torch.generative.utilities.loader as loading_utils
+
+TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
+    ff_up_proj="model.layers.{}.mlp.fc1",
+    ff_down_proj="model.layers.{}.mlp.fc2",
+    attn_query_proj="model.layers.{}.self_attn.q_proj",
+    attn_key_proj="model.layers.{}.self_attn.k_proj",
+    attn_value_proj="model.layers.{}.self_attn.v_proj",
+    attn_output_proj="model.layers.{}.self_attn.dense",
+    pre_attn_norm="model.layers.{}.input_layernorm",
+    embedding="model.embed_tokens",
+    final_norm="model.final_layernorm",
+    lm_head="lm_head",
+)
+
+
+class Phi2(nn.Module):
+
+  def __init__(self, config: cfg.ModelConfig):
+    super().__init__()
+
+    self.config = config
+    # Construct model layers.
+    self.lm_head = nn.Linear(
+        config.embedding_dim, config.vocab_size, bias=config.lm_head_use_bias
+    )
+    self.tok_embedding = nn.Embedding(
+        config.vocab_size, config.embedding_dim, padding_idx=0
+    )
+    self.transformer_blocks = nn.ModuleList(
+        TransformerBlock(config) for _ in range(config.num_layers)
+    )
+    self.final_norm = builder.build_norm(
+        config.embedding_dim,
+        config.final_norm_config,
+    )
+    self.rope_cache = attn_utils.build_rope_cache(
+        size=config.kv_cache_max,
+        dim=int(config.attn_config.rotary_percentage * config.head_dim),
+        base=10_000,
+        condense_ratio=1,
+        dtype=torch.float32,
+        device=torch.device("cpu"),
+    )
+    self.mask_cache = attn_utils.build_causal_mask_cache(
+        size=config.kv_cache_max, dtype=torch.float32, device=torch.device("cpu")
+    )
+    self.config = config
+
+  # The model's forward function takes in additional k/v cache tensors
+  # and returns the updated k/v cache tensors to the caller.
+  # This can be eliminated if we handle k/v cache updates inside the model itself.
+  @torch.inference_mode
+  def forward(self, idx: torch.Tensor, input_pos: torch.Tensor) -> torch.Tensor:
+    B, T = idx.size()
+    assert (
+        self.config.max_seq_len >= T
+    ), f"Cannot forward sequence of length {T}, max seq length is only {self.config.max_seq_len}"
+
+    cos, sin = self.rope_cache
+    cos = cos.index_select(0, input_pos)
+    sin = sin.index_select(0, input_pos)
+    mask = self.mask_cache.index_select(2, input_pos)
+    mask = mask[:, :, :, : self.config.kv_cache_max]
+
+    # forward the model itself
+    x = self.tok_embedding(idx)  # token embeddings of shape (b, t, n_embd)
+
+    for i, block in enumerate(self.transformer_blocks):
+      x = block(x, (cos, sin), mask, input_pos)
+
+    x = self.final_norm(x)
+    res = self.lm_head(x)  # (b, t, vocab_size)
+    return res
+
+
+def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
+  attn_config = cfg.AttentionConfig(
+      num_heads=32,
+      num_query_groups=32,
+      rotary_percentage=0.4,
+      qkv_use_bias=True,
+      output_proj_use_bias=True,
+  )
+  ff_config = cfg.FeedForwardConfig(
+      type=cfg.FeedForwardType.SEQUENTIAL,
+      activation=cfg.ActivationType.GELU_TANH,
+      intermediate_size=10240,
+      use_bias=True,
+  )
+  norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.LAYER_NORM)
+  config = cfg.ModelConfig(
+      vocab_size=51200,
+      num_layers=32,
+      max_seq_len=2048,
+      kv_cache_max_len=kv_cache_max_len,
+      embedding_dim=2560,
+      attn_config=attn_config,
+      ff_config=ff_config,
+      pre_attention_norm_config=norm_config,
+      final_norm_config=norm_config,
+      parallel_residual=True,
+      lm_head_use_bias=True,
+      enable_hlfb=True,
+  )
+  return config
+
+
+def get_fake_model_config_for_test() -> cfg.ModelConfig:
+  config = get_model_config()
+  config.num_layers = 2
+  return config
+
+
+def build_model(checkpoint_path, **kwargs) -> nn.Module:
+  config = get_model_config(**kwargs)
+  model = Phi2(config)
+  loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
+  loader.load(model)
+  return model
+
+
+def define_and_run() -> None:
+  kv_cache_max_len = 1024
+  checkpoint_path = os.path.join(Path.home(), "Downloads/llm_data/phi2")
+  model = build_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
+  idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
+  tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.long, device="cpu")
+  tokens[0, :4] = idx
+  input_pos = torch.arange(0, kv_cache_max_len)
+  print("running an inference")
+  print(model.forward(tokens, input_pos))
+
+
+if __name__ == "__main__":
+  define_and_run()
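Note the partial rotary embedding in get_model_config() above: unlike Gemma (rotary_percentage=1.0), Phi-2 rotates only 40% of each head. A quick worked example with the configured dimensions, assuming head_dim resolves to embedding_dim // num_heads:

embedding_dim = 2560
num_heads = 32
rotary_percentage = 0.4

head_dim = embedding_dim // num_heads           # 80 dims per head
rotary_dim = int(rotary_percentage * head_dim)  # 32 dims get RoPE

# build_rope_cache above is therefore built with dim=32; the remaining
# 48 dimensions of each head pass through without rotation.
print(head_dim, rotary_dim)  # 80 32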
ai_edge_torch/generative/examples/t5/__init__.py
@@ -0,0 +1,14 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
ai_edge_torch/generative/examples/t5/convert_to_tflite.py
@@ -0,0 +1,135 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+import os
+from pathlib import Path
+
+import numpy as np
+import torch
+
+import ai_edge_torch
+from ai_edge_torch.generative.examples.t5 import t5
+from ai_edge_torch.generative.quantize import quant_recipes
+
+
+# TODO(haoliang): clean this up until the 2-sig model is validated e2e.
+def convert_t5_to_tflite_singlesig(checkpoint_path: str):
+  pytorch_model = t5.build_t5_model(checkpoint_path)
+
+  # encoder
+  seq_len = 512
+  prefill_e_tokens = torch.full((1, seq_len), 0, dtype=torch.long)
+  prompt_e_token = [1, 2, 3, 4, 5, 6]
+  prefill_e_tokens[0, : len(prompt_e_token)] = torch.tensor(
+      prompt_e_token, dtype=torch.long
+  )
+  prefill_e_input_pos = torch.arange(0, seq_len)
+  prefill_d_tokens = torch.full((1, seq_len), 0, dtype=torch.long)
+  prompt_d_token = [1, 2, 3, 4, 5, 6]
+  prefill_d_tokens[0, : len(prompt_d_token)] = torch.tensor(
+      prompt_d_token, dtype=torch.long
+  )
+  prefill_d_input_pos = torch.arange(0, seq_len)
+
+  # decoder
+  decode_token = torch.tensor([[1]], dtype=torch.long)
+  decode_input_pos = torch.tensor([0], dtype=torch.int64)
+  decode_d_token = torch.tensor([[1]], dtype=torch.long)
+  decode_d_input_pos = torch.tensor([0], dtype=torch.int64)
+
+  # Pad mask for self attention only on "real" tokens.
+  # Pad with `-inf` for any token indices that aren't desired.
+  pad_mask = torch.zeros([seq_len], dtype=torch.float32)
+
+  edge_model = ai_edge_torch.signature(
+      'decode',
+      pytorch_model,
+      (
+          prefill_e_tokens,
+          prefill_e_input_pos,
+          decode_d_token,
+          decode_d_input_pos,
+          pad_mask,
+      ),
+  ).convert()
+
+  edge_model.export('/tmp/t5_encode_decode.tflite')
+
+
+def convert_t5_to_tflite_multisig(checkpoint_path: str):
+  config = t5.get_model_config_t5()
+  embedding_layer = torch.nn.Embedding(
+      config.vocab_size, config.embedding_dim, padding_idx=0
+  )
+  t5_encoder_model = t5.build_t5_encoder_model(config, embedding_layer, checkpoint_path)
+  t5_decoder_model = t5.build_t5_decoder_model(config, embedding_layer, checkpoint_path)
+
+  # encoder
+  seq_len = 512
+  prefill_e_tokens = torch.full((1, seq_len), 0, dtype=torch.long)
+  prompt_e_token = [1, 2, 3, 4, 5, 6]
+  prefill_e_tokens[0, : len(prompt_e_token)] = torch.tensor(
+      prompt_e_token, dtype=torch.long
+  )
+  prefill_e_input_pos = torch.arange(0, seq_len)
+  prefill_d_tokens = torch.full((1, seq_len), 0, dtype=torch.long)
+  prompt_d_token = [1, 2, 3, 4, 5, 6]
+  prefill_d_tokens[0, : len(prompt_d_token)] = torch.tensor(
+      prompt_d_token, dtype=torch.long
+  )
+  prefill_d_input_pos = torch.arange(0, seq_len)
+
+  # decoder
+  decode_token = torch.tensor([[1]], dtype=torch.long)
+  decode_input_pos = torch.tensor([0], dtype=torch.int64)
+  decode_d_token = torch.tensor([[1]], dtype=torch.long)
+  decode_d_input_pos = torch.tensor([0], dtype=torch.int64)
+
+  # Pad mask for self attention only on "real" tokens.
+  # Pad with `-inf` for any token indices that aren't desired.
+  pad_mask = torch.zeros([seq_len], dtype=torch.float32)
+  hidden_states = torch.zeros((1, 512, 768), dtype=torch.float32)
+  quant_config = quant_recipes.full_linear_int8_dynamic_recipe()
+
+  edge_model = (
+      ai_edge_torch.signature(
+          'encode',
+          t5_encoder_model,
+          (
+              prefill_e_tokens,
+              prefill_e_input_pos,
+              pad_mask,
+          ),
+      )
+      .signature(
+          'decode',
+          t5_decoder_model,
+          (
+              hidden_states,
+              decode_d_token,
+              decode_d_input_pos,
+              pad_mask,
+          ),
+      )
+      .convert(quant_config=quant_config)
+  )
+
+  edge_model.export('/tmp/t5_encode_decode_2_sigs.tflite')
+
+
+if __name__ == '__main__':
+  checkpoint_path = os.path.join(Path.home(), 'Downloads/llm_data/t5')
+  # convert_t5_to_tflite_singlesig(checkpoint_path)
+  convert_t5_to_tflite_multisig(checkpoint_path)
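The tracing inputs above use an all-zeros pad_mask, so every one of the 512 positions is attended to during conversion; the comments describe the intended runtime form, where positions beyond the real prompt carry an additive -inf. A small illustrative sketch (not part of the package) of building that mask for the 6-token example prompt used in this script:

import torch

seq_len = 512
prompt_len = 6  # len(prompt_e_token) above

# 0.0 at positions holding real tokens, -inf at padded positions, so the
# attention softmax assigns zero weight to padding.
pad_mask = torch.full((seq_len,), float('-inf'), dtype=torch.float32)
pad_mask[:prompt_len] = 0.0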