ai-edge-torch-nightly 0.2.0.dev20240719__py3-none-any.whl → 0.2.0.dev20240721__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_torch/generative/examples/experimental/__init__.py +14 -0
- ai_edge_torch/generative/examples/experimental/gemma/__init__.py +14 -0
- ai_edge_torch/generative/examples/experimental/gemma/convert_to_tflite.py +87 -0
- ai_edge_torch/generative/examples/experimental/gemma/gemma.py +195 -0
- ai_edge_torch/generative/examples/experimental/phi/__init__.py +14 -0
- ai_edge_torch/generative/examples/experimental/phi/convert_to_tflite.py +84 -0
- ai_edge_torch/generative/examples/experimental/phi/phi2.py +184 -0
- ai_edge_torch/generative/examples/experimental/tiny_llama/__init__.py +14 -0
- ai_edge_torch/generative/examples/experimental/tiny_llama/convert_to_tflite.py +89 -0
- ai_edge_torch/generative/examples/experimental/tiny_llama/tiny_llama.py +185 -0
- ai_edge_torch/generative/examples/gemma/gemma.py +6 -2
- ai_edge_torch/generative/examples/phi2/phi2.py +5 -2
- ai_edge_torch/generative/examples/t5/t5.py +5 -2
- ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +42 -27
- ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +6 -2
- ai_edge_torch/generative/test/test_experimental_ekv.py +122 -0
- {ai_edge_torch_nightly-0.2.0.dev20240719.dist-info → ai_edge_torch_nightly-0.2.0.dev20240721.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.2.0.dev20240719.dist-info → ai_edge_torch_nightly-0.2.0.dev20240721.dist-info}/RECORD +21 -10
- {ai_edge_torch_nightly-0.2.0.dev20240719.dist-info → ai_edge_torch_nightly-0.2.0.dev20240721.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.2.0.dev20240719.dist-info → ai_edge_torch_nightly-0.2.0.dev20240721.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.2.0.dev20240719.dist-info → ai_edge_torch_nightly-0.2.0.dev20240721.dist-info}/top_level.txt +0 -0

ai_edge_torch/generative/examples/experimental/__init__.py
@@ -0,0 +1,14 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================

ai_edge_torch/generative/examples/experimental/gemma/__init__.py
@@ -0,0 +1,14 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================

ai_edge_torch/generative/examples/experimental/gemma/convert_to_tflite.py
@@ -0,0 +1,87 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+#
+# Note: This is an experimental version of Gemma with external KV cache.
+# Please use with caution.
+
+
+import os
+from pathlib import Path
+
+import torch
+
+import ai_edge_torch
+from ai_edge_torch.generative.examples.experimental.gemma import gemma
+from ai_edge_torch.generative.layers.experimental import ekv_cache as kv_utils
+from ai_edge_torch.generative.quantize import quant_recipes
+
+
+def convert_gemma_to_tflite(
+    checkpoint_path: str,
+    prefill_seq_len: int = 512,
+    kv_cache_max_len: int = 1024,
+    quantize: bool = True,
+):
+  """An example method for converting a Gemma 2B model to multi-signature
+  tflite model.
+
+  Args:
+    checkpoint_path (str): The filepath to the model checkpoint, or directory
+      holding the checkpoint.
+    prefill_seq_len (int, optional): The maximum size of prefill input tensor.
+      Defaults to 512.
+    kv_cache_max_len (int, optional): The maximum size of KV cache buffer,
+      including both prefill and decode. Defaults to 1024.
+    quantize (bool, optional): Whether the model should be quantized.
+      Defaults to True.
+  """
+  pytorch_model = gemma.build_2b_model(
+      checkpoint_path, kv_cache_max_len=kv_cache_max_len
+  )
+  # Tensors used to trace the model graph during conversion.
+  prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.long)
+  prefill_input_pos = torch.arange(0, prefill_seq_len)
+  decode_token = torch.tensor([[0]], dtype=torch.long)
+  decode_input_pos = torch.tensor([0], dtype=torch.int64)
+  kv = kv_utils.EKVCache.from_model_config(pytorch_model.config)
+
+  quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
+  edge_model = (
+      ai_edge_torch.signature(
+          'prefill',
+          pytorch_model,
+          sample_kwargs={
+              'tokens': prefill_tokens,
+              'input_pos': prefill_input_pos,
+              'kv_cache': kv,
+          },
+      )
+      .signature(
+          'decode',
+          pytorch_model,
+          sample_kwargs={
+              'tokens': decode_token,
+              'input_pos': decode_input_pos,
+              'kv_cache': kv,
+          },
+      )
+      .convert(quant_config=quant_config)
+  )
+  edge_model.export(f'/tmp/gemma_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite')
+
+
+if __name__ == '__main__':
+  checkpoint_path = os.path.join(Path.home(), 'Downloads/llm_data/gemma-2b')
+  convert_gemma_to_tflite(checkpoint_path)
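
Note: beyond the __main__ entry point, the new converter can be driven as a library call. A minimal sketch (not part of the diff; the checkpoint path is a placeholder) with non-default sizes:

  from ai_edge_torch.generative.examples.experimental.gemma import convert_to_tflite

  # Writes /tmp/gemma_seq256_ekv512.tflite, per the export path above.
  convert_to_tflite.convert_gemma_to_tflite(
      '/path/to/gemma-2b', prefill_seq_len=256, kv_cache_max_len=512, quantize=False
  )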

ai_edge_torch/generative/examples/experimental/gemma/gemma.py
@@ -0,0 +1,195 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+# Example of building a Gemma model.
+#
+# Note: This is an experimental version of Gemma with external KV cache.
+# Please use with caution.
+
+import os
+from pathlib import Path
+from typing import Tuple
+
+import numpy as np
+import torch
+import torch.nn as nn
+
+import ai_edge_torch.generative.layers.attention_utils as attn_utils
+import ai_edge_torch.generative.layers.builder as builder
+from ai_edge_torch.generative.layers.experimental import ekv_cache as kv_utils
+from ai_edge_torch.generative.layers.experimental.attention import TransformerBlock  # NOQA
+import ai_edge_torch.generative.layers.model_config as cfg
+import ai_edge_torch.generative.utilities.loader as loading_utils
+
+TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
+    ff_up_proj="model.layers.{}.mlp.up_proj",
+    ff_down_proj="model.layers.{}.mlp.down_proj",
+    ff_gate_proj="model.layers.{}.mlp.gate_proj",
+    attn_query_proj="model.layers.{}.self_attn.q_proj",
+    attn_key_proj="model.layers.{}.self_attn.k_proj",
+    attn_value_proj="model.layers.{}.self_attn.v_proj",
+    attn_output_proj="model.layers.{}.self_attn.o_proj",
+    pre_attn_norm="model.layers.{}.input_layernorm",
+    pre_ff_norm="model.layers.{}.post_attention_layernorm",
+    embedding="model.embed_tokens",
+    final_norm="model.norm",
+    lm_head=None,
+)
+
+
+class Gemma(nn.Module):
+
+  def __init__(self, config: cfg.ModelConfig):
+    super().__init__()
+
+    self.config = config
+    # Construct model layers.
+    self.tok_embedding = nn.Embedding(
+        config.vocab_size, config.embedding_dim, padding_idx=0
+    )
+    self.lm_head = nn.Linear(
+        config.embedding_dim,
+        config.vocab_size,
+        bias=config.lm_head_use_bias,
+    )
+    # Gemma re-uses the embedding as the head projection layer.
+    self.lm_head.weight.data = self.tok_embedding.weight.data
+    self.transformer_blocks = nn.ModuleList(
+        TransformerBlock(config) for _ in range(config.num_layers)
+    )
+    self.final_norm = builder.build_norm(
+        config.embedding_dim,
+        config.final_norm_config,
+    )
+    self.rope_cache = attn_utils.build_rope_cache(
+        size=config.kv_cache_max,
+        dim=int(config.attn_config.rotary_percentage * config.head_dim),
+        base=10_000,
+        condense_ratio=1,
+        dtype=torch.float32,
+        device=torch.device("cpu"),
+    )
+    self.mask_cache = attn_utils.build_causal_mask_cache(
+        size=config.kv_cache_max, dtype=torch.float32, device=torch.device("cpu")
+    )
+    self.config = config
+
+  @torch.inference_mode
+  def forward(
+      self,
+      tokens: torch.Tensor,
+      input_pos: torch.Tensor,
+      kv_cache: kv_utils.EKVCache,
+  ) -> Tuple[torch.Tensor, kv_utils.EKVCache]:
+    B, T = tokens.size()
+    assert (
+        self.config.max_seq_len >= T
+    ), f"Cannot forward sequence of length {T}, max seq length is only {self.config.max_seq_len}"
+
+    cos, sin = self.rope_cache
+    cos = cos.index_select(0, input_pos)
+    sin = sin.index_select(0, input_pos)
+    mask = self.mask_cache.index_select(2, input_pos)
+    mask = mask[:, :, :, : self.config.kv_cache_max]
+
+    # token embeddings of shape (b, t, n_embd)
+    x = self.tok_embedding(tokens)
+    x = x * (self.config.embedding_dim**0.5)
+
+    updated_kv_entires = []
+    for i, block in enumerate(self.transformer_blocks):
+      kv_entry = kv_cache.caches[i] if kv_cache else None
+      x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
+      if kv_entry:
+        updated_kv_entires.append(kv_entry)
+    updated_kv_cache = kv_utils.EKVCache(tuple(updated_kv_entires))
+
+    x = self.final_norm(x)
+    res = self.lm_head(x)  # (b, t, vocab_size)
+    return res, updated_kv_cache
+
+
+def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
+  attn_config = cfg.AttentionConfig(
+      num_heads=8,
+      num_query_groups=1,
+      rotary_percentage=1.0,
+  )
+  ff_config = cfg.FeedForwardConfig(
+      type=cfg.FeedForwardType.GATED,
+      activation=cfg.ActivationConfig(cfg.ActivationType.GELU_TANH),
+      intermediate_size=16384,
+  )
+  norm_config = cfg.NormalizationConfig(
+      type=cfg.NormalizationType.RMS_NORM,
+      epsilon=1e-6,
+      zero_centered=True,
+  )
+  config = cfg.ModelConfig(
+      vocab_size=256000,
+      num_layers=18,
+      max_seq_len=8192,
+      embedding_dim=2048,
+      kv_cache_max_len=kv_cache_max_len,
+      attn_config=attn_config,
+      ff_config=ff_config,
+      pre_attention_norm_config=norm_config,
+      pre_ff_norm_config=norm_config,
+      final_norm_config=norm_config,
+      parallel_residual=False,
+      lm_head_use_bias=False,
+      enable_hlfb=True,
+  )
+  return config
+
+
+def get_fake_model_config_2b_for_test(**kwargs) -> cfg.ModelConfig:
+  config = get_model_config_2b(**kwargs)
+  config.num_layers = 2
+  return config
+
+
+def build_2b_model(checkpoint_path, test_model=False, **kwargs) -> nn.Module:
+  config = (
+      get_fake_model_config_2b_for_test(**kwargs)
+      if test_model
+      else get_model_config_2b(**kwargs)
+  )
+  model = Gemma(config)
+  if checkpoint_path is not None:
+    loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
+    # since embedding and lm-head use the same weight, we need to set strict
+    # to False.
+    loader.load(model, strict=False)
+  model.eval()
+  return model
+
+
+def define_and_run_2b(checkpoint_path, test_model=False) -> None:
+  kv_cache_max_len = 1024
+  model = build_2b_model(
+      checkpoint_path, test_model=test_model, kv_cache_max_len=kv_cache_max_len
+  )
+  idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
+  tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.long, device="cpu")
+  tokens[0, :4] = idx
+  input_pos = torch.arange(0, kv_cache_max_len)
+  kv = kv_utils.EKVCache.from_model_config(model.config)
+  print("running an inference")
+  print(model.forward(tokens, input_pos, kv))
+
+
+if __name__ == "__main__":
+  checkpoint_path = os.path.join(Path.home(), "Downloads/gemma-2b")
+  define_and_run_2b(checkpoint_path)
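
Note: the forward signature above is the external-KV protocol in a nutshell: the caller owns the cache and re-feeds the cache returned by each call. A minimal decode-loop sketch (not part of the diff; greedy argmax sampling is purely illustrative), using the test-sized config so no checkpoint is required:

  import torch

  from ai_edge_torch.generative.examples.experimental.gemma import gemma
  from ai_edge_torch.generative.layers.experimental import ekv_cache as kv_utils

  model = gemma.build_2b_model(None, test_model=True, kv_cache_max_len=1024)
  kv = kv_utils.EKVCache.from_model_config(model.config)

  # Prefill the prompt, then decode token by token, threading the updated
  # cache back into every forward call.
  prompt = torch.tensor([[1, 2, 3, 4]], dtype=torch.long)
  logits, kv = model.forward(prompt, torch.arange(0, prompt.shape[1]), kv)
  pos = prompt.shape[1]
  for _ in range(4):
    next_token = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)
    logits, kv = model.forward(next_token, torch.tensor([pos]), kv)
    pos += 1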

ai_edge_torch/generative/examples/experimental/phi/__init__.py
@@ -0,0 +1,14 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================

ai_edge_torch/generative/examples/experimental/phi/convert_to_tflite.py
@@ -0,0 +1,84 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+#
+# Note: This is an experimental version of phi2 with external KV cache.
+# Please use with caution.
+
+import os
+from pathlib import Path
+
+import torch
+
+import ai_edge_torch
+from ai_edge_torch.generative.examples.experimental.phi import phi2
+from ai_edge_torch.generative.layers.experimental import ekv_cache
+from ai_edge_torch.generative.quantize import quant_recipes
+
+
+def convert_phi2_to_tflite(
+    checkpoint_path: str,
+    prefill_seq_len: int = 512,
+    kv_cache_max_len: int = 1024,
+    quantize: bool = True,
+):
+  """An example method for converting a Phi-2 model to multi-signature
+  tflite model.
+
+  Args:
+    checkpoint_path (str): The filepath to the model checkpoint, or
+      directory holding the checkpoint.
+    prefill_seq_len (int, optional): The maximum size of prefill input tensor.
+      Defaults to 512.
+    kv_cache_max_len (int, optional): The maximum size of KV cache buffer,
+      including both prefill and decode. Defaults to 1024.
+    quantize (bool, optional): Whether the model should be quantized.
+      Defaults to True.
+  """
+  pytorch_model = phi2.build_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
+  # Tensors used to trace the model graph during conversion.
+  prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.long)
+  prefill_input_pos = torch.arange(0, prefill_seq_len)
+  decode_token = torch.tensor([[0]], dtype=torch.long)
+  decode_input_pos = torch.tensor([0], dtype=torch.int64)
+  kv = ekv_cache.EKVCache.from_model_config(pytorch_model.config)
+
+  quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
+  edge_model = (
+      ai_edge_torch.signature(
+          'prefill',
+          pytorch_model,
+          sample_kwargs={
+              'tokens': prefill_tokens,
+              'input_pos': prefill_input_pos,
+              'kv_cache': kv,
+          },
+      )
+      .signature(
+          'decode',
+          pytorch_model,
+          sample_kwargs={
+              'tokens': decode_token,
+              'input_pos': decode_input_pos,
+              'kv_cache': kv,
+          },
+      )
+      .convert(quant_config=quant_config)
+  )
+  edge_model.export(f'/tmp/phi2_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite')
+
+
+if __name__ == '__main__':
+  checkpoint_path = os.path.join(Path.home(), 'Downloads/llm_data/phi2')
+  convert_phi2_to_tflite(checkpoint_path)

ai_edge_torch/generative/examples/experimental/phi/phi2.py
@@ -0,0 +1,184 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+# Example of building phi-2 model from the Edge Generative API layers.
+#
+# Note: This is an experimental version of phi2 with external KV cache.
+# Please use with caution.
+
+
+import os
+from pathlib import Path
+from typing import Tuple
+
+import numpy as np
+import torch
+import torch.nn as nn
+
+import ai_edge_torch.generative.layers.attention_utils as attn_utils
+import ai_edge_torch.generative.layers.builder as builder
+from ai_edge_torch.generative.layers.experimental import ekv_cache as kv_utils
+from ai_edge_torch.generative.layers.experimental.attention import TransformerBlock  # NOQA
+import ai_edge_torch.generative.layers.model_config as cfg
+import ai_edge_torch.generative.utilities.loader as loading_utils
+
+TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
+    ff_up_proj="model.layers.{}.mlp.fc1",
+    ff_down_proj="model.layers.{}.mlp.fc2",
+    attn_query_proj="model.layers.{}.self_attn.q_proj",
+    attn_key_proj="model.layers.{}.self_attn.k_proj",
+    attn_value_proj="model.layers.{}.self_attn.v_proj",
+    attn_output_proj="model.layers.{}.self_attn.dense",
+    pre_attn_norm="model.layers.{}.input_layernorm",
+    embedding="model.embed_tokens",
+    final_norm="model.final_layernorm",
+    lm_head="lm_head",
+)
+
+
+class Phi2(nn.Module):
+
+  def __init__(self, config: cfg.ModelConfig):
+    super().__init__()
+
+    self.config = config
+    # Construct model layers.
+    self.lm_head = nn.Linear(
+        config.embedding_dim, config.vocab_size, bias=config.lm_head_use_bias
+    )
+    self.tok_embedding = nn.Embedding(
+        config.vocab_size, config.embedding_dim, padding_idx=0
+    )
+    self.transformer_blocks = nn.ModuleList(
+        TransformerBlock(config) for _ in range(config.num_layers)
+    )
+    self.final_norm = builder.build_norm(
+        config.embedding_dim,
+        config.final_norm_config,
+    )
+    self.rope_cache = attn_utils.build_rope_cache(
+        size=config.kv_cache_max,
+        dim=int(config.attn_config.rotary_percentage * config.head_dim),
+        base=10_000,
+        condense_ratio=1,
+        dtype=torch.float32,
+        device=torch.device("cpu"),
+    )
+    self.mask_cache = attn_utils.build_causal_mask_cache(
+        size=config.kv_cache_max, dtype=torch.float32, device=torch.device("cpu")
+    )
+    self.config = config
+
+  @torch.inference_mode
+  def forward(
+      self,
+      tokens: torch.Tensor,
+      input_pos: torch.Tensor,
+      kv_cache: kv_utils.EKVCache,
+  ) -> Tuple[torch.Tensor, kv_utils.EKVCache]:
+    B, T = tokens.size()
+    assert (
+        self.config.max_seq_len >= T
+    ), f"Cannot forward sequence of length {T}, max seq length is only {self.config.max_seq_len}"
+
+    cos, sin = self.rope_cache
+    cos = cos.index_select(0, input_pos)
+    sin = sin.index_select(0, input_pos)
+    mask = self.mask_cache.index_select(2, input_pos)
+    mask = mask[:, :, :, : self.config.kv_cache_max]
+
+    x = self.tok_embedding(tokens)
+
+    updated_kv_entires = []
+    for i, block in enumerate(self.transformer_blocks):
+      kv_entry = kv_cache.caches[i] if kv_cache else None
+      x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
+      if kv_entry:
+        updated_kv_entires.append(kv_entry)
+    updated_kv_cache = kv_utils.EKVCache(tuple(updated_kv_entires))
+
+    x = self.final_norm(x)
+    res = self.lm_head(x)  # (b, t, vocab_size)
+    return res, updated_kv_cache
+
+
+def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
+  attn_config = cfg.AttentionConfig(
+      num_heads=32,
+      num_query_groups=32,
+      rotary_percentage=0.4,
+      qkv_use_bias=True,
+      output_proj_use_bias=True,
+  )
+  ff_config = cfg.FeedForwardConfig(
+      type=cfg.FeedForwardType.SEQUENTIAL,
+      activation=cfg.ActivationConfig(cfg.ActivationType.GELU_TANH),
+      intermediate_size=10240,
+      use_bias=True,
+  )
+  norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.LAYER_NORM)
+  config = cfg.ModelConfig(
+      vocab_size=51200,
+      num_layers=32,
+      max_seq_len=2048,
+      kv_cache_max_len=kv_cache_max_len,
+      embedding_dim=2560,
+      attn_config=attn_config,
+      ff_config=ff_config,
+      pre_attention_norm_config=norm_config,
+      final_norm_config=norm_config,
+      parallel_residual=True,
+      lm_head_use_bias=True,
+      enable_hlfb=True,
+  )
+  return config
+
+
+def get_fake_model_config_for_test(**kwargs) -> cfg.ModelConfig:
+  config = get_model_config(**kwargs)
+  config.num_layers = 2
+  return config
+
+
+def build_model(checkpoint_path, test_model=False, **kwargs) -> nn.Module:
+  config = (
+      get_fake_model_config_for_test(**kwargs)
+      if test_model
+      else get_model_config(**kwargs)
+  )
+  model = Phi2(config)
+  if checkpoint_path is not None:
+    loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
+    loader.load(model)
+  model.eval()
+  return model
+
+
+def define_and_run(checkpoint_path, test_model=False) -> None:
+  kv_cache_max_len = 1024
+  model = build_model(
+      checkpoint_path, test_model=test_model, kv_cache_max_len=kv_cache_max_len
+  )
+  idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
+  tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.long, device="cpu")
+  tokens[0, :4] = idx
+  input_pos = torch.arange(0, kv_cache_max_len)
+  kv = kv_utils.EKVCache.from_model_config(model.config)
+  print("running an inference")
+  print(model.forward(tokens, input_pos, kv))
+
+
+if __name__ == "__main__":
+  checkpoint_path = os.path.join(Path.home(), "Downloads/phi2")
+  define_and_run(checkpoint_path)

ai_edge_torch/generative/examples/experimental/tiny_llama/__init__.py
@@ -0,0 +1,14 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================

ai_edge_torch/generative/examples/experimental/tiny_llama/convert_to_tflite.py
@@ -0,0 +1,89 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+#
+# Note: This is an experimental version of TinyLlama with external KV cache.
+# Please use with caution.
+
+
+import os
+from pathlib import Path
+
+import torch
+
+import ai_edge_torch
+from ai_edge_torch.generative.examples.experimental.tiny_llama import tiny_llama  # NOQA
+from ai_edge_torch.generative.layers.experimental import ekv_cache as kv_utils
+from ai_edge_torch.generative.quantize import quant_recipes
+
+
+def convert_tiny_llama_to_tflite(
+    checkpoint_path: str,
+    prefill_seq_len: int = 512,
+    kv_cache_max_len: int = 1024,
+    quantize: bool = True,
+):
+  """An example method for converting a TinyLlama model to multi-signature
+  tflite model.
+
+  Args:
+    checkpoint_path (str): The filepath to the model checkpoint, or directory
+      holding the checkpoint.
+    prefill_seq_len (int, optional): The maximum size of prefill input tensor.
+      Defaults to 512.
+    kv_cache_max_len (int, optional): The maximum size of KV cache buffer,
+      including both prefill and decode. Defaults to 1024.
+    quantize (bool, optional): Whether the model should be quantized.
+      Defaults to True.
+  """
+  pytorch_model = tiny_llama.build_model(
+      checkpoint_path, kv_cache_max_len=kv_cache_max_len
+  )
+  # Tensors used to trace the model graph during conversion.
+  prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.long)
+  prefill_input_pos = torch.arange(0, prefill_seq_len)
+  decode_token = torch.tensor([[0]], dtype=torch.long)
+  decode_input_pos = torch.tensor([0], dtype=torch.int64)
+  kv = kv_utils.EKVCache.from_model_config(pytorch_model.config)
+
+  quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
+  edge_model = (
+      ai_edge_torch.signature(
+          'prefill',
+          pytorch_model,
+          sample_kwargs={
+              'tokens': prefill_tokens,
+              'input_pos': prefill_input_pos,
+              'kv_cache': kv,
+          },
+      )
+      .signature(
+          'decode',
+          pytorch_model,
+          sample_kwargs={
+              'tokens': decode_token,
+              'input_pos': decode_input_pos,
+              'kv_cache': kv,
+          },
+      )
+      .convert(quant_config=quant_config)
+  )
+  edge_model.export(
+      f'/tmp/tiny_llama_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
+  )
+
+
+if __name__ == '__main__':
+  checkpoint_path = os.path.join(Path.home(), 'Downloads/llm_data/tiny_llama')
+  convert_tiny_llama_to_tflite(checkpoint_path)

ai_edge_torch/generative/examples/experimental/tiny_llama/tiny_llama.py
@@ -0,0 +1,185 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+# Example of building a TinyLlama model from the Edge Generative API layers.
+#
+# Note: This is an experimental version of TinyLlama with external KV cache.
+# Please use with caution.
+
+
+import os
+from pathlib import Path
+from typing import Tuple
+
+import numpy as np
+import torch
+import torch.nn as nn
+
+import ai_edge_torch.generative.layers.attention_utils as attn_utils
+import ai_edge_torch.generative.layers.builder as builder
+from ai_edge_torch.generative.layers.experimental import ekv_cache as kv_utils
+from ai_edge_torch.generative.layers.experimental.attention import TransformerBlock  # NOQA
+import ai_edge_torch.generative.layers.model_config as cfg
+import ai_edge_torch.generative.utilities.loader as loading_utils
+
+TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
+    ff_up_proj="model.layers.{}.mlp.up_proj",
+    ff_down_proj="model.layers.{}.mlp.down_proj",
+    ff_gate_proj="model.layers.{}.mlp.gate_proj",
+    attn_query_proj="model.layers.{}.self_attn.q_proj",
+    attn_key_proj="model.layers.{}.self_attn.k_proj",
+    attn_value_proj="model.layers.{}.self_attn.v_proj",
+    attn_output_proj="model.layers.{}.self_attn.o_proj",
+    pre_attn_norm="model.layers.{}.input_layernorm",
+    pre_ff_norm="model.layers.{}.post_attention_layernorm",
+    embedding="model.embed_tokens",
+    final_norm="model.norm",
+    lm_head="lm_head",
+)
+
+
+class TinyLLamma(nn.Module):
+
+  def __init__(self, config: cfg.ModelConfig):
+    super().__init__()
+
+    self.config = config
+    # Construct model layers.
+    self.lm_head = nn.Linear(
+        config.embedding_dim, config.vocab_size, bias=config.lm_head_use_bias
+    )
+    self.tok_embedding = nn.Embedding(
+        config.vocab_size, config.embedding_dim, padding_idx=0
+    )
+    self.transformer_blocks = nn.ModuleList(
+        TransformerBlock(config) for _ in range(config.num_layers)
+    )
+    self.final_norm = builder.build_norm(
+        config.embedding_dim,
+        config.final_norm_config,
+    )
+    self.rope_cache = attn_utils.build_rope_cache(
+        size=config.kv_cache_max,
+        dim=int(config.attn_config.rotary_percentage * config.head_dim),
+        base=10_000,
+        condense_ratio=1,
+        dtype=torch.float32,
+        device=torch.device("cpu"),
+    )
+    self.mask_cache = attn_utils.build_causal_mask_cache(
+        size=config.kv_cache_max, dtype=torch.float32, device=torch.device("cpu")
+    )
+    self.config = config
+
+  @torch.inference_mode
+  def forward(
+      self,
+      tokens: torch.Tensor,
+      input_pos: torch.Tensor,
+      kv_cache: kv_utils.EKVCache,
+  ) -> Tuple[torch.Tensor, kv_utils.EKVCache]:
+    B, T = tokens.size()
+    assert (
+        self.config.max_seq_len >= T
+    ), f"Cannot forward sequence of length {T}, max seq length is only {self.config.max_seq_len}"
+
+    cos, sin = self.rope_cache
+    cos = cos.index_select(0, input_pos)
+    sin = sin.index_select(0, input_pos)
+    mask = self.mask_cache.index_select(2, input_pos)
+    mask = mask[:, :, :, : self.config.kv_cache_max]
+
+    # token embeddings of shape (b, t, n_embd)
+    x = self.tok_embedding(tokens)
+
+    updated_kv_entires = []
+    for i, block in enumerate(self.transformer_blocks):
+      kv_entry = kv_cache.caches[i] if kv_cache else None
+      x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
+      if kv_entry:
+        updated_kv_entires.append(kv_entry)
+    updated_kv_cache = kv_utils.EKVCache(tuple(updated_kv_entires))
+
+    x = self.final_norm(x)
+    res = self.lm_head(x)  # (b, t, vocab_size)
+    return res, updated_kv_cache
+
+
+def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
+  attn_config = cfg.AttentionConfig(
+      num_heads=32,
+      num_query_groups=4,
+      rotary_percentage=1.0,
+  )
+  ff_config = cfg.FeedForwardConfig(
+      type=cfg.FeedForwardType.GATED,
+      activation=cfg.ActivationConfig(cfg.ActivationType.SILU),
+      intermediate_size=5632,
+  )
+  norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.RMS_NORM)
+  config = cfg.ModelConfig(
+      vocab_size=32000,
+      num_layers=22,
+      max_seq_len=2048,
+      embedding_dim=2048,
+      kv_cache_max_len=kv_cache_max_len,
+      attn_config=attn_config,
+      ff_config=ff_config,
+      pre_attention_norm_config=norm_config,
+      pre_ff_norm_config=norm_config,
+      final_norm_config=norm_config,
+      enable_hlfb=True,
+  )
+  return config
+
+
+def get_fake_model_config_for_test(**kwargs) -> cfg.ModelConfig:
+  config = get_model_config(**kwargs)
+  config.vocab_size = 128
+  config.num_layers = 2
+  config.ff_config.intermediate_size = 256
+  return config
+
+
+def build_model(checkpoint_path, test_model=False, **kwargs) -> nn.Module:
+  config = (
+      get_fake_model_config_for_test(**kwargs)
+      if test_model
+      else get_model_config(**kwargs)
+  )
+  model = TinyLLamma(config)
+  if checkpoint_path is not None:
+    loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
+    loader.load(model)
+  model.eval()
+  return model
+
+
+def define_and_run(checkpoint_path, test_model=False) -> None:
+  kv_cache_max_len = 1024
+  model = build_model(
+      checkpoint_path, test_model=test_model, kv_cache_max_len=kv_cache_max_len
+  )
+  idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
+  tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.long, device="cpu")
+  tokens[0, :4] = idx
+  input_pos = torch.arange(0, kv_cache_max_len)
+  kv = kv_utils.EKVCache.from_model_config(model.config)
+  print("running an inference")
+  print(model.forward(tokens, input_pos, kv))
+
+
+if __name__ == "__main__":
+  checkpoint_path = os.path.join(Path.home(), "Downloads/tiny_llama")
+  define_and_run(checkpoint_path)

ai_edge_torch/generative/examples/gemma/gemma.py
@@ -159,6 +159,9 @@ def build_2b_model(checkpoint_path, **kwargs) -> nn.Module:
 
 
 def define_and_run_2b() -> None:
+  current_dir = Path(__file__).parent.resolve()
+  gemma_goldens = torch.load(current_dir / "gemma_lm_logits.pt")
+
   kv_cache_max_len = 1024
   checkpoint_path = os.path.join(Path.home(), "Downloads/llm_data/gemma-2b")
   model = build_2b_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
@@ -166,8 +169,9 @@ def define_and_run_2b() -> None:
   tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.long, device="cpu")
   tokens[0, :4] = idx
   input_pos = torch.arange(0, kv_cache_max_len)
-
-  print(model.forward(tokens, input_pos))
+  lm_logits = model.forward(tokens, input_pos)
+  print("comparing with goldens..")
+  assert torch.allclose(gemma_goldens, lm_logits[0, idx.shape[1] - 1, :], atol=1e-05)
 
 
 if __name__ == "__main__":
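
Note: "gemma_lm_logits.pt" is a goldens file checked in next to the example; the diff does not show how it was generated. A plausible sketch (an assumption, not from this release) is saving the last-prompt-token logits from a known-good run, using the same indexing the assert above uses:

  lm_logits = model.forward(tokens, input_pos)
  torch.save(lm_logits[0, idx.shape[1] - 1, :], "gemma_lm_logits.pt")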

ai_edge_torch/generative/examples/phi2/phi2.py
@@ -149,6 +149,8 @@ def build_model(checkpoint_path, **kwargs) -> nn.Module:
 
 
 def define_and_run() -> None:
+  current_dir = Path(__file__).parent.resolve()
+  phi2_goldens = torch.load(current_dir / "phi2_lm_logits.pt")
   kv_cache_max_len = 1024
   checkpoint_path = os.path.join(Path.home(), "Downloads/llm_data/phi2")
   model = build_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
@@ -156,8 +158,9 @@ def define_and_run() -> None:
   tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.long, device="cpu")
   tokens[0, :4] = idx
   input_pos = torch.arange(0, kv_cache_max_len)
-
-  print(model.forward(tokens, input_pos))
+  lm_logits = model.forward(tokens, input_pos)
+  print("comparing with goldens..")
+  assert torch.allclose(phi2_goldens, lm_logits[0, idx.shape[1] - 1, :], atol=1e-05)
 
 
 if __name__ == "__main__":

ai_edge_torch/generative/examples/t5/t5.py
@@ -557,7 +557,8 @@ def get_sample_encoder_input_ids() -> torch.Tensor:
 
 
 def define_and_run_t5(checkpoint_path: str) -> None:
-
+  current_dir = Path(__file__).parent.resolve()
+  t5_goldens = torch.load(current_dir / "t5_lm_logits.pt")
 
   model = build_t5_model(checkpoint_path)
 
@@ -579,7 +580,9 @@ def define_and_run_t5(checkpoint_path: str) -> None:
 
 # TODO(haoliang): Move those tests.
 def define_and_run_t5_split(checkpoint_path: str) -> None:
-
+  current_dir = Path(__file__).parent.resolve()
+  t5_goldens = torch.load(current_dir / "t5_lm_logits.pt")
+
   config = get_model_config_t5()
   embedding_layer = nn.Embedding(config.vocab_size, config.embedding_dim, padding_idx=0)
   t5_encoder_model = build_t5_encoder_model(config, embedding_layer, checkpoint_path)

ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py
@@ -14,9 +14,8 @@
 # ==============================================================================
 # A toy example which has basic transformer block (w/ externalized KV-Cache).
 
-from typing import
+from typing import Tuple
 
-import numpy as np
 import torch
 import torch.nn as nn
 import torch_xla
@@ -24,6 +23,7 @@ import torch_xla
 import ai_edge_torch
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
 import ai_edge_torch.generative.layers.builder as builder
+from ai_edge_torch.generative.layers.experimental import ekv_cache as kv_utils
 from ai_edge_torch.generative.layers.experimental.attention import TransformerBlock  # NOQA
 import ai_edge_torch.generative.layers.model_config as cfg
 
@@ -60,27 +60,27 @@ class ToyModelWithExternalKV(torch.nn.Module):
 
   def forward(
       self,
-      idx: torch.Tensor,
+      tokens: torch.Tensor,
       input_pos: torch.Tensor,
-
-
-
-    x = self.tok_embedding(idx)
+      kv_cache: kv_utils.EKVCache,
+  ) -> Tuple[torch.Tensor, kv_utils.EKVCache]:
+    x = self.tok_embedding(tokens)
     cos, sin = self.rope_cache
     cos = cos.index_select(0, input_pos)
     sin = sin.index_select(0, input_pos)
     mask = self.mask_cache.index_select(2, input_pos)
     mask = mask[:, :, :, : self.config.max_seq_len]
 
+    updated_kv_entires = []
     for i, block in enumerate(self.transformer_blocks):
-
-      x, (
-
-      k_caches[i], v_caches[i] = updated_k, updated_v
+      kv_entry = kv_cache.caches[i] if kv_cache else None
+      x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
+      if kv_entry:
+        updated_kv_entires.append(kv_entry)
 
     x = self.final_norm(x)
-
+    updated_kv_cache = kv_utils.EKVCache(tuple(updated_kv_entires))
+    return self.lm_head(x), updated_kv_cache
 
 
 def _export_stablehlo_mlir(model, args):
@@ -115,15 +115,15 @@ def get_model_config() -> cfg.ModelConfig:
 
 
 def get_sample_prefill_inputs() -> Tuple[torch.Tensor, torch.Tensor]:
-
+  tokens = torch.unsqueeze(torch.arange(0, 100), 0)
   input_pos = torch.arange(0, 100)
-  return
+  return tokens, input_pos
 
 
 def get_sample_decode_inputs() -> Tuple[torch.Tensor, torch.Tensor]:
-
+  tokens = torch.tensor([[1]], dtype=torch.long)
   input_pos = torch.tensor([10])
-  return
+  return tokens, input_pos
 
 
 def define_and_run() -> None:
@@ -131,16 +131,16 @@ def define_and_run() -> None:
 
   config = get_model_config()
   model = ToyModelWithExternalKV(config)
+  model.eval()
   print('running an inference')
-  k_caches = torch.zeros((2, 1, 100, 4, 4), dtype=torch.float32)
-  v_caches = torch.zeros((2, 1, 100, 4, 4), dtype=torch.float32)
+  kv = kv_utils.EKVCache.from_model_config(config)
 
-
-
-  print(model.forward(
+  tokens, input_pos = get_sample_prefill_inputs()
+  decode_token, decode_input_pos = get_sample_decode_inputs()
+  print(model.forward(tokens, input_pos, kv))
 
   if dump_mlir:
-    mlir_text = _export_stablehlo_mlir(model, (
+    mlir_text = _export_stablehlo_mlir(model, (tokens, input_pos, kv))
     with open('/tmp/toy_model_with_external_kv.stablehlo.mlir', 'w') as f:
       f.write(mlir_text)
 
@@ -149,13 +149,28 @@ def define_and_run() -> None:
   # in dynamic update slice op.
   print('converting toy model to tflite with 2 signatures (prefill + decode)')
   edge_model = (
-      ai_edge_torch.signature(
-
+      ai_edge_torch.signature(
+          'prefill',
+          model,
+          sample_kwargs={
+              'tokens': tokens,
+              'input_pos': input_pos,
+              'kv_cache': kv,
+          },
+      )
+      .signature(
+          'decode',
+          model,
+          sample_kwargs={
+              'tokens': decode_token,
+              'input_pos': decode_input_pos,
+              'kv_cache': kv,
+          },
+      )
       .convert()
   )
   edge_model.export('/tmp/toy_external_kv_cache.tflite')
 
 
 if __name__ == '__main__':
-
-  define_and_run()
+  define_and_run()
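
Note: a quick way to confirm the exported model really carries both signatures is to load it with the TFLite interpreter. A hypothetical check (assumes a TensorFlow installation; not part of the diff):

  import tensorflow as tf

  interpreter = tf.lite.Interpreter(model_path='/tmp/toy_external_kv_cache.tflite')
  print(interpreter.get_signature_list())  # expect 'prefill' and 'decode' entries
  prefill_runner = interpreter.get_signature_runner('prefill')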

ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py
@@ -149,6 +149,8 @@ def build_model(checkpoint_path, **kwargs) -> nn.Module:
 
 
 def define_and_run() -> None:
+  current_dir = Path(__file__).parent.resolve()
+  tiny_llama_goldens = torch.load(current_dir / "tiny_llama_lm_logits.pt")
   kv_cache_max_len = 1024
   checkpoint_path = os.path.join(Path.home(), "Downloads/llm_data/tiny_llama")
   model = build_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
@@ -156,8 +158,10 @@ def define_and_run() -> None:
   tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.long, device="cpu")
   tokens[0, :4] = idx
   input_pos = torch.arange(0, kv_cache_max_len)
-
-  print(model.forward(tokens, input_pos))
+  lm_logits = model.forward(tokens, input_pos)
+  assert torch.allclose(
+      tiny_llama_goldens, lm_logits[0, idx.shape[1] - 1, :], atol=1e-05
+  )
 
 
 if __name__ == "__main__":

ai_edge_torch/generative/test/test_experimental_ekv.py
@@ -0,0 +1,122 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+# A suite of tests to validate experimental external KV Cache layers and models.
+
+import unittest
+
+import numpy as np
+import torch
+
+from ai_edge_torch.generative.examples.experimental.gemma import gemma
+from ai_edge_torch.generative.examples.experimental.phi import phi2
+from ai_edge_torch.generative.examples.experimental.tiny_llama import tiny_llama  # NOQA
+from ai_edge_torch.generative.layers.experimental import ekv_cache as kv_utils
+import ai_edge_torch.generative.layers.model_config as cfg
+
+
+class TestExternalKVLayers(unittest.TestCase):
+
+  def _get_test_config(self, num_layers, head_dim, num_query_groups, kv_cache_max_len):
+    attn_config = cfg.AttentionConfig(num_heads=1, num_query_groups=num_query_groups)
+    config = cfg.ModelConfig(
+        kv_cache_max_len=kv_cache_max_len,
+        embedding_dim=head_dim,
+        attn_config=attn_config,
+        num_layers=num_layers,
+        max_seq_len=None,
+        vocab_size=None,
+        ff_config=None,
+    )
+    return config
+
+  def test_cache_udpate(self):
+    N = 1
+    HEAD_DIM = 2
+    NUM_QG = 1
+    KV_LEN = 4
+    config = self._get_test_config(
+        num_layers=N,
+        head_dim=HEAD_DIM,
+        num_query_groups=NUM_QG,
+        kv_cache_max_len=KV_LEN,
+    )
+    kv = kv_utils.EKVCache.from_model_config(config)
+    entry = kv.caches[0]
+    # single-slice update
+    input_pos = torch.tensor([1])
+    k_slice = v_slice = torch.full((1, 1, NUM_QG, HEAD_DIM), 5, dtype=torch.float)
+    updated_entry = kv_utils.update(entry, input_pos, k_slice, v_slice)
+    self.assertEqual(
+        updated_entry.k_cache.numpy().flatten().tolist(), [0, 0, 5, 5, 0, 0, 0, 0]
+    )
+    self.assertEqual(
+        updated_entry.v_cache.numpy().flatten().tolist(), [0, 0, 5, 5, 0, 0, 0, 0]
+    )
+    # multi-slice update
+    input_pos = torch.tensor([0, 3])
+    k_slice = v_slice = torch.full((1, 2, NUM_QG, HEAD_DIM), 7, dtype=torch.float)
+    updated_entry = kv_utils.update(entry, input_pos, k_slice, v_slice)
+    self.assertEqual(
+        updated_entry.k_cache.numpy().flatten().tolist(), [7, 7, 0, 0, 0, 0, 7, 7]
+    )
+    self.assertEqual(
+        updated_entry.v_cache.numpy().flatten().tolist(), [7, 7, 0, 0, 0, 0, 7, 7]
+    )
+
+  def test_serialization(self):
+    class TestModel(torch.nn.Module):
+
+      def forward(self, kv: kv_utils.EKVCache) -> kv_utils.EKVCache:
+        updated_kv_entries = [
+            kv_utils.KVCacheEntry(
+                torch.zeros_like(entry.k_cache), torch.zeros_like(entry.v_cache)
+            )
+            for entry in kv.caches
+        ]
+        return kv_utils.EKVCache(updated_kv_entries)
+
+    N = 1
+    HEAD_DIM = 2
+    NUM_QG = 1
+    KV_LEN = 4
+    config = self._get_test_config(
+        num_layers=N,
+        head_dim=HEAD_DIM,
+        num_query_groups=NUM_QG,
+        kv_cache_max_len=KV_LEN,
+    )
+    kv = kv_utils.EKVCache.from_model_config(config)
+    model = TestModel()
+    exported_program = torch.export.export(model, (kv,))
+    input_specs = exported_program.graph_signature.input_specs
+    self.assertEqual(len(input_specs), 2)
+    self.assertEqual(input_specs[0].arg.name, "kv_k_0")
+    self.assertEqual(input_specs[1].arg.name, "kv_v_0")
+
+
+class TestExternalKVModels(unittest.TestCase):
+
+  def test_can_build_gemma(self):
+    gemma.define_and_run_2b(checkpoint_path=None, test_model=True)
+
+  def test_can_build_phi2(self):
+    phi2.define_and_run(checkpoint_path=None, test_model=True)
+
+  def test_can_build_tinyllama(self):
+    tiny_llama.define_and_run(checkpoint_path=None, test_model=True)
+
+
+if __name__ == "__main__":
+  unittest.main()
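
Note: the new test module runs under the standard unittest runner in an environment where the package is installed:

  python -m unittest ai_edge_torch.generative.test.test_experimental_ekv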

{ai_edge_torch_nightly-0.2.0.dev20240719.dist-info → ai_edge_torch_nightly-0.2.0.dev20240721.dist-info}/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: ai-edge-torch-nightly
-Version: 0.2.0.dev20240719
+Version: 0.2.0.dev20240721
 Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
 Home-page: https://github.com/google-ai-edge/ai-edge-torch
 Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI

{ai_edge_torch_nightly-0.2.0.dev20240719.dist-info → ai_edge_torch_nightly-0.2.0.dev20240721.dist-info}/RECORD
@@ -35,12 +35,22 @@ ai_edge_torch/debug/test/test_search_model.py,sha256=0guAEon5cvwBpPXk6J0wVOKj7TX
 ai_edge_torch/experimental/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
+ai_edge_torch/generative/examples/experimental/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
+ai_edge_torch/generative/examples/experimental/gemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
+ai_edge_torch/generative/examples/experimental/gemma/convert_to_tflite.py,sha256=bW_KOj_3fcZAggfST3zHWcMcNJs70b0pld-vvauAOgo,3076
+ai_edge_torch/generative/examples/experimental/gemma/gemma.py,sha256=u4DNsZRnN7whDoK8yQet9Yahb01ToVqTuFQmWV1__1g,6606
+ai_edge_torch/generative/examples/experimental/phi/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
+ai_edge_torch/generative/examples/experimental/phi/convert_to_tflite.py,sha256=sLU35tpQ-PEbhZbLfC1vSqM-HamKREVBpIoywWh9O3M,3036
+ai_edge_torch/generative/examples/experimental/phi/phi2.py,sha256=zgxB2JSFAevyS28C6-wIBaQeeKTUejUJY4dnR4BqRBI,6150
+ai_edge_torch/generative/examples/experimental/tiny_llama/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
+ai_edge_torch/generative/examples/experimental/tiny_llama/convert_to_tflite.py,sha256=PChEhBotZ8k6GZiq9e_AYnn3RyhNIVm_U96QhVjx3jY,3126
+ai_edge_torch/generative/examples/experimental/tiny_llama/tiny_llama.py,sha256=1vL0u6Pkd8SV8uei9BGzSAIokclT_RaE3K0IczoPfeI,6291
 ai_edge_torch/generative/examples/gemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/gemma/convert_to_tflite.py,sha256=UMEZGDGhFvAX4eT5KHAE1Xbxw-qtQWEMxgvB8cSH6wY,2531
-ai_edge_torch/generative/examples/gemma/gemma.py,sha256=
+ai_edge_torch/generative/examples/gemma/gemma.py,sha256=YyGGsgEByIg_tIysMBqaBztf_csthZIjah8mmH5o7UA,6144
 ai_edge_torch/generative/examples/phi2/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/phi2/convert_to_tflite.py,sha256=uF1A2EX8xYie30-T2Z7s1WZCtFhp5CEwRV8SCd7Umrc,2505
-ai_edge_torch/generative/examples/phi2/phi2.py,sha256=
+ai_edge_torch/generative/examples/phi2/phi2.py,sha256=KjfTrD2OBzOfq83-XvJ6ZhmXLuP_VqugSOwyj-M5YY4,5767
 ai_edge_torch/generative/examples/stable_diffusion/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/stable_diffusion/attention.py,sha256=Lo4Dq7a3Kg-lyH56iqGtqCo5UaClQHRCTDdNagXGTo8,3535
 ai_edge_torch/generative/examples/stable_diffusion/clip.py,sha256=P-cUUQaQKGKV2p-7hvLJ--RpCIA7gk8WCDRgg0pNtd0,4331
@@ -58,15 +68,15 @@ ai_edge_torch/generative/examples/stable_diffusion/samplers/k_lms.py,sha256=iPYX
 ai_edge_torch/generative/examples/stable_diffusion/samplers/sampler.py,sha256=5iRfU5MO6GR6K3WrdddIU_9U7ZZGEEb7zGKVY1WFl-8,1340
 ai_edge_torch/generative/examples/t5/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/t5/convert_to_tflite.py,sha256=7RwaZQaKhFt3zKAUbFjq95CSYhL1nd9BVSbSRNJp4-4,4529
-ai_edge_torch/generative/examples/t5/t5.py,sha256=
+ai_edge_torch/generative/examples/t5/t5.py,sha256=fVtJ0S8v2bMtvEuDqD6Orw7CTyXqnRIqZfKcz7DBeJc,21212
 ai_edge_torch/generative/examples/t5/t5_attention.py,sha256=KaGzCAViNOpJIQbRF-ItouuVPqI9nroWRRGN-KFYKZs,8357
 ai_edge_torch/generative/examples/test_models/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/test_models/toy_model.py,sha256=Sf3ZMYv-iuMRKAKLow47qth8vTF1zl6i8TxJ9uT_StU,3885
-ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py,sha256=
+ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py,sha256=jmucKpWY_nHEOAh7G62IxpReNmrKWo4PxfELul_h9xQ,5796
 ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py,sha256=lfYUiem_Pbn3vGgPx84BeI8n7rN3-1fImwCLm8Eo2U8,4853
 ai_edge_torch/generative/examples/tiny_llama/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py,sha256=nT7Fh-f5ZdwaK3dPoCvZflpJ4fRHjLdFMjk1_uw3-b8,2559
-ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py,sha256=
+ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py,sha256=to9IlF-X_uIJvO-roZOW1ZMUhmkYbvFjc-tUVaQr6TE,5848
 ai_edge_torch/generative/fx_passes/__init__.py,sha256=aXvYiaHDvETIrh0Q9DDZA_ZBiazGk80DT6nt7lLtC1o,1172
 ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py,sha256=BCAcc_OcEjvbaXQSbc8vlKeMad7E3gCA4BNsUdWRwBI,1966
 ai_edge_torch/generative/layers/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
@@ -94,6 +104,7 @@ ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/__init__.py,sha256=47DE
 ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/translate_recipe.py,sha256=iTNPrlubmq9ia7C3zHl50J2YEMsc4o33GwL5tr5VkkE,5229
 ai_edge_torch/generative/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/test/loader_test.py,sha256=N88CbrLW7Q2x1EyurwdXQ6YjsA-ySQcPxpZH3QOGp-M,3317
+ai_edge_torch/generative/test/test_experimental_ekv.py,sha256=qMR0r7Pr_t2bn-cyeA7Qw_Rl94H1NmFcqM2ua8gpDDw,4230
 ai_edge_torch/generative/test/test_model_conversion.py,sha256=LsPTrLC1I4JW2GowTS3V9Eu257vLHr2Yj5f_qaFUX84,7589
 ai_edge_torch/generative/test/test_quantize.py,sha256=QbF7LC9olJFGXqlAVGciac7xXc4rDtCSr71tTIYuqPk,5230
 ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
@@ -114,8 +125,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=eO9Ra160ITjQSyRBEGy6nNIVH3gYacSWDd
 ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
 ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=kzIulTldq8R9E-lAZsvfSTvLu3FYEX7b9DyYM3qisXM,4485
-ai_edge_torch_nightly-0.2.0.dev20240719.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
-ai_edge_torch_nightly-0.2.0.dev20240719.dist-info/METADATA,sha256=
-ai_edge_torch_nightly-0.2.0.dev20240719.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
-ai_edge_torch_nightly-0.2.0.dev20240719.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
-ai_edge_torch_nightly-0.2.0.dev20240719.dist-info/RECORD,,
+ai_edge_torch_nightly-0.2.0.dev20240721.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
+ai_edge_torch_nightly-0.2.0.dev20240721.dist-info/METADATA,sha256=TJYFNAxXQkRwt9I_0OqpUOS3opWBU5i-ioMwsicD7cY,1745
+ai_edge_torch_nightly-0.2.0.dev20240721.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
+ai_edge_torch_nightly-0.2.0.dev20240721.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
+ai_edge_torch_nightly-0.2.0.dev20240721.dist-info/RECORD,,

Files without changes: LICENSE, WHEEL, top_level.txt (only the dist-info directory name changes with the new version).