ai-edge-torch-nightly 0.3.0.dev20241002__py3-none-any.whl → 0.3.0.dev20241005__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_torch/generative/examples/gemma/gemma1.py +10 -93
- ai_edge_torch/generative/examples/gemma/gemma2.py +0 -1
- ai_edge_torch/generative/examples/llama/convert_to_tflite.py +13 -2
- ai_edge_torch/generative/examples/llama/llama.py +19 -24
- ai_edge_torch/generative/examples/llama/verify.py +18 -3
- ai_edge_torch/generative/examples/openelm/openelm.py +9 -90
- ai_edge_torch/generative/examples/phi/phi2.py +10 -86
- ai_edge_torch/generative/examples/phi/phi3.py +9 -69
- ai_edge_torch/generative/examples/qwen/qwen.py +26 -36
- ai_edge_torch/generative/examples/smollm/smollm.py +10 -30
- ai_edge_torch/generative/examples/stable_diffusion/clip.py +1 -3
- ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +40 -32
- ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +11 -101
- ai_edge_torch/generative/layers/model_config.py +6 -0
- ai_edge_torch/generative/test/test_loader.py +2 -1
- ai_edge_torch/generative/test/test_model_conversion.py +39 -17
- ai_edge_torch/generative/test/test_model_conversion_large.py +6 -5
- ai_edge_torch/generative/utilities/model_builder.py +141 -0
- ai_edge_torch/lowertools/translate_recipe.py +2 -2
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20241002.dist-info → ai_edge_torch_nightly-0.3.0.dev20241005.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20241002.dist-info → ai_edge_torch_nightly-0.3.0.dev20241005.dist-info}/RECORD +25 -26
- ai_edge_torch/generative/examples/llama/convert_3b_to_tflite.py +0 -68
- ai_edge_torch/generative/examples/llama/verify_3b.py +0 -73
- {ai_edge_torch_nightly-0.3.0.dev20241002.dist-info → ai_edge_torch_nightly-0.3.0.dev20241005.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20241002.dist-info → ai_edge_torch_nightly-0.3.0.dev20241005.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20241002.dist-info → ai_edge_torch_nightly-0.3.0.dev20241005.dist-info}/top_level.txt +0 -0
@@ -15,28 +15,10 @@
|
|
15
15
|
|
16
16
|
"""Example of building Qwen 2.5 models."""
|
17
17
|
|
18
|
-
import copy
|
19
|
-
|
20
|
-
from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
|
21
18
|
import ai_edge_torch.generative.layers.model_config as cfg
|
22
|
-
|
23
|
-
from torch import nn
|
24
|
-
|
25
|
-
TENSOR_NAMES = copy.copy(tiny_llama.TENSOR_NAMES)
|
26
|
-
# Qwen re-uses the embedding as the head projection layer.
|
27
|
-
TENSOR_NAMES.lm_head = None
|
28
|
-
|
29
|
-
|
30
|
-
class Qwen(tiny_llama.TinyLlama):
|
31
|
-
"""A Qwen model built from the Edge Generative API layers.
|
32
|
-
|
33
|
-
Qwen 2.5 shares the same architecture as TinyLlama.
|
34
|
-
"""
|
19
|
+
from ai_edge_torch.generative.utilities import model_builder
|
35
20
|
|
36
|
-
|
37
|
-
super().__init__(config)
|
38
|
-
# Qwen re-uses the embedding as the head projection layer.
|
39
|
-
self.lm_head.weight.data = self.tok_embedding.weight.data
|
21
|
+
TENSOR_NAMES = model_builder.TENSOR_NAMES
|
40
22
|
|
41
23
|
|
42
24
|
def get_3b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
|
@@ -119,23 +101,31 @@ def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
|
|
119
101
|
return config
|
120
102
|
|
121
103
|
|
122
|
-
def
|
123
|
-
|
124
|
-
|
125
|
-
|
126
|
-
|
127
|
-
|
128
|
-
|
129
|
-
|
130
|
-
|
131
|
-
|
132
|
-
def build_3b_model(checkpoint_path: str, **kwargs) -> nn.Module:
|
133
|
-
return _build_model(checkpoint_path, get_3b_model_config(**kwargs))
|
104
|
+
def build_3b_model(
|
105
|
+
checkpoint_path: str, **kwargs
|
106
|
+
) -> model_builder.DecoderOnlyModel:
|
107
|
+
return model_builder.build_decoder_only_model(
|
108
|
+
checkpoint_path=checkpoint_path,
|
109
|
+
config=get_3b_model_config(**kwargs),
|
110
|
+
tensor_names=TENSOR_NAMES,
|
111
|
+
)
|
134
112
|
|
135
113
|
|
136
|
-
def build_1_5b_model(
|
137
|
-
|
114
|
+
def build_1_5b_model(
|
115
|
+
checkpoint_path: str, **kwargs
|
116
|
+
) -> model_builder.DecoderOnlyModel:
|
117
|
+
return model_builder.build_decoder_only_model(
|
118
|
+
checkpoint_path=checkpoint_path,
|
119
|
+
config=get_1_5b_model_config(**kwargs),
|
120
|
+
tensor_names=TENSOR_NAMES,
|
121
|
+
)
|
138
122
|
|
139
123
|
|
140
|
-
def build_0_5b_model(
|
141
|
-
|
124
|
+
def build_0_5b_model(
|
125
|
+
checkpoint_path: str, **kwargs
|
126
|
+
) -> model_builder.DecoderOnlyModel:
|
127
|
+
return model_builder.build_decoder_only_model(
|
128
|
+
checkpoint_path=checkpoint_path,
|
129
|
+
config=get_0_5b_model_config(**kwargs),
|
130
|
+
tensor_names=TENSOR_NAMES,
|
131
|
+
)
|
@@ -15,29 +15,10 @@
|
|
15
15
|
|
16
16
|
"""Example of building a SmolLM model."""
|
17
17
|
|
18
|
-
import copy
|
19
|
-
|
20
|
-
from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
|
21
18
|
import ai_edge_torch.generative.layers.model_config as cfg
|
22
|
-
|
23
|
-
from torch import nn
|
24
|
-
|
25
|
-
TENSOR_NAMES = copy.copy(tiny_llama.TENSOR_NAMES)
|
26
|
-
# SmolLM re-uses the embedding as the head projection layer.
|
27
|
-
TENSOR_NAMES.lm_head = None
|
28
|
-
|
29
|
-
|
30
|
-
class SmolLM(tiny_llama.TinyLlama):
|
31
|
-
"""A SmolLM model built from the Edge Generative API layers.
|
19
|
+
from ai_edge_torch.generative.utilities import model_builder
|
32
20
|
|
33
|
-
|
34
|
-
sizes.
|
35
|
-
"""
|
36
|
-
|
37
|
-
def __init__(self, config: cfg.ModelConfig):
|
38
|
-
super().__init__(config)
|
39
|
-
# SmolLM re-uses the embedding as the head projection layer.
|
40
|
-
self.lm_head.weight.data = self.tok_embedding.weight.data
|
21
|
+
TENSOR_NAMES = model_builder.TENSOR_NAMES
|
41
22
|
|
42
23
|
|
43
24
|
def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
|
@@ -91,12 +72,11 @@ def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
|
|
91
72
|
return config
|
92
73
|
|
93
74
|
|
94
|
-
def build_model(
|
95
|
-
|
96
|
-
|
97
|
-
|
98
|
-
|
99
|
-
|
100
|
-
|
101
|
-
|
102
|
-
return model
|
75
|
+
def build_model(
|
76
|
+
checkpoint_path: str, **kwargs
|
77
|
+
) -> model_builder.DecoderOnlyModel:
|
78
|
+
return model_builder.build_decoder_only_model(
|
79
|
+
checkpoint_path=checkpoint_path,
|
80
|
+
config=get_model_config(**kwargs),
|
81
|
+
tensor_names=TENSOR_NAMES,
|
82
|
+
)
|
@@ -75,9 +75,7 @@ class CLIP(nn.Module):
|
|
75
75
|
)
|
76
76
|
|
77
77
|
@torch.inference_mode
|
78
|
-
def forward(self, tokens: torch.
|
79
|
-
tokens = tokens.type(torch.int)
|
80
|
-
|
78
|
+
def forward(self, tokens: torch.IntTensor) -> torch.FloatTensor:
|
81
79
|
state = self.tok_embedding(tokens) + self.tok_embedding_position
|
82
80
|
for layer in self.transformer_blocks:
|
83
81
|
state = layer(state, mask=self.mask_cache)
|
@@ -13,47 +13,54 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
15
|
|
16
|
-
import argparse
|
17
16
|
import os
|
18
|
-
|
19
|
-
from typing import Optional
|
17
|
+
import pathlib
|
20
18
|
|
19
|
+
from absl import app
|
20
|
+
from absl import flags
|
21
21
|
import ai_edge_torch
|
22
|
-
|
23
|
-
|
24
|
-
|
25
|
-
from ai_edge_torch.generative.examples.stable_diffusion
|
26
|
-
import ai_edge_torch.generative.examples.stable_diffusion.util as util
|
22
|
+
from ai_edge_torch.generative.examples.stable_diffusion import clip
|
23
|
+
from ai_edge_torch.generative.examples.stable_diffusion import decoder
|
24
|
+
from ai_edge_torch.generative.examples.stable_diffusion import diffusion
|
25
|
+
from ai_edge_torch.generative.examples.stable_diffusion import util
|
27
26
|
from ai_edge_torch.generative.quantize import quant_recipes
|
28
|
-
|
27
|
+
from ai_edge_torch.generative.utilities import stable_diffusion_loader
|
29
28
|
import torch
|
30
29
|
|
31
|
-
|
32
|
-
|
33
|
-
|
34
|
-
type=str,
|
30
|
+
_CLIP_CKPT = flags.DEFINE_string(
|
31
|
+
'clip_ckpt',
|
32
|
+
None,
|
35
33
|
help='Path to source CLIP model checkpoint',
|
36
34
|
required=True,
|
37
35
|
)
|
38
|
-
|
39
|
-
|
40
|
-
|
36
|
+
|
37
|
+
_DIFFUSION_CKPT = flags.DEFINE_string(
|
38
|
+
'diffusion_ckpt',
|
39
|
+
None,
|
41
40
|
help='Path to source diffusion model checkpoint',
|
42
41
|
required=True,
|
43
42
|
)
|
44
|
-
|
45
|
-
|
46
|
-
|
43
|
+
|
44
|
+
_DECODER_CKPT = flags.DEFINE_string(
|
45
|
+
'decoder_ckpt',
|
46
|
+
None,
|
47
47
|
help='Path to source image decoder model checkpoint',
|
48
48
|
required=True,
|
49
49
|
)
|
50
|
-
|
51
|
-
|
52
|
-
|
50
|
+
|
51
|
+
_OUTPUT_DIR = flags.DEFINE_string(
|
52
|
+
'output_dir',
|
53
|
+
None,
|
53
54
|
help='Path to the converted TF Lite directory.',
|
54
55
|
required=True,
|
55
56
|
)
|
56
57
|
|
58
|
+
_QUANTIZE = flags.DEFINE_bool(
|
59
|
+
'quantize',
|
60
|
+
help='Whether to quantize the model during conversion.',
|
61
|
+
default=True,
|
62
|
+
)
|
63
|
+
|
57
64
|
|
58
65
|
@torch.inference_mode
|
59
66
|
def convert_stable_diffusion_to_tflite(
|
@@ -111,7 +118,7 @@ def convert_stable_diffusion_to_tflite(
|
|
111
118
|
time_embedding = util.get_time_embedding(timestamp)
|
112
119
|
|
113
120
|
if not os.path.exists(output_dir):
|
114
|
-
Path(output_dir).mkdir(parents=True, exist_ok=True)
|
121
|
+
pathlib.Path(output_dir).mkdir(parents=True, exist_ok=True)
|
115
122
|
|
116
123
|
quant_config = (
|
117
124
|
quant_recipes.full_int8_weight_only_recipe() if quantize else None
|
@@ -142,14 +149,15 @@ def convert_stable_diffusion_to_tflite(
|
|
142
149
|
).export(f'{output_dir}/decoder.tflite')
|
143
150
|
|
144
151
|
|
145
|
-
|
146
|
-
args = arg_parser.parse_args()
|
152
|
+
def main(_):
|
147
153
|
convert_stable_diffusion_to_tflite(
|
148
|
-
output_dir=
|
149
|
-
clip_ckpt_path=
|
150
|
-
diffusion_ckpt_path=
|
151
|
-
decoder_ckpt_path=
|
152
|
-
|
153
|
-
image_width=512,
|
154
|
-
quantize=True,
|
154
|
+
output_dir=_OUTPUT_DIR.value,
|
155
|
+
clip_ckpt_path=_CLIP_CKPT.value,
|
156
|
+
diffusion_ckpt_path=_DIFFUSION_CKPT.value,
|
157
|
+
decoder_ckpt_path=_DECODER_CKPT.value,
|
158
|
+
quantize=_QUANTIZE.value,
|
155
159
|
)
|
160
|
+
|
161
|
+
|
162
|
+
if __name__ == '__main__':
|
163
|
+
app.run(main)
|
@@ -15,102 +15,10 @@
|
|
15
15
|
|
16
16
|
"""Example of building a TinyLlama model."""
|
17
17
|
|
18
|
-
from ai_edge_torch.generative.layers import attention
|
19
|
-
from ai_edge_torch.generative.layers import builder
|
20
|
-
from ai_edge_torch.generative.layers import kv_cache as kv_utils
|
21
|
-
import ai_edge_torch.generative.layers.attention_utils as attn_utils
|
22
18
|
import ai_edge_torch.generative.layers.model_config as cfg
|
23
|
-
|
24
|
-
import torch
|
25
|
-
from torch import nn
|
19
|
+
from ai_edge_torch.generative.utilities import model_builder
|
26
20
|
|
27
|
-
TENSOR_NAMES =
|
28
|
-
ff_up_proj="model.layers.{}.mlp.up_proj",
|
29
|
-
ff_down_proj="model.layers.{}.mlp.down_proj",
|
30
|
-
ff_gate_proj="model.layers.{}.mlp.gate_proj",
|
31
|
-
attn_query_proj="model.layers.{}.self_attn.q_proj",
|
32
|
-
attn_key_proj="model.layers.{}.self_attn.k_proj",
|
33
|
-
attn_value_proj="model.layers.{}.self_attn.v_proj",
|
34
|
-
attn_output_proj="model.layers.{}.self_attn.o_proj",
|
35
|
-
pre_attn_norm="model.layers.{}.input_layernorm",
|
36
|
-
post_attn_norm="model.layers.{}.post_attention_layernorm",
|
37
|
-
embedding="model.embed_tokens",
|
38
|
-
final_norm="model.norm",
|
39
|
-
lm_head="lm_head",
|
40
|
-
)
|
41
|
-
|
42
|
-
|
43
|
-
class TinyLlama(nn.Module):
|
44
|
-
"""A TinyLlama model built from the Edge Generative API layers."""
|
45
|
-
|
46
|
-
def __init__(self, config: cfg.ModelConfig):
|
47
|
-
super().__init__()
|
48
|
-
|
49
|
-
# Construct model layers.
|
50
|
-
self.lm_head = nn.Linear(
|
51
|
-
config.embedding_dim, config.vocab_size, bias=config.lm_head_use_bias
|
52
|
-
)
|
53
|
-
self.tok_embedding = nn.Embedding(
|
54
|
-
config.vocab_size, config.embedding_dim, padding_idx=0
|
55
|
-
)
|
56
|
-
# TinyLlama has only one block config.
|
57
|
-
block_config = config.block_config(0)
|
58
|
-
self.transformer_blocks = nn.ModuleList(
|
59
|
-
attention.TransformerBlock(block_config, config)
|
60
|
-
for _ in range(config.num_layers)
|
61
|
-
)
|
62
|
-
self.final_norm = builder.build_norm(
|
63
|
-
config.embedding_dim,
|
64
|
-
config.final_norm_config,
|
65
|
-
)
|
66
|
-
attn_config = block_config.attn_config
|
67
|
-
self.rope_cache = attn_utils.build_rope_cache(
|
68
|
-
size=config.kv_cache_max,
|
69
|
-
dim=int(attn_config.rotary_percentage * attn_config.head_dim),
|
70
|
-
base=attn_config.rotary_base,
|
71
|
-
)
|
72
|
-
self.mask_cache = attn_utils.build_causal_mask_cache(
|
73
|
-
size=config.kv_cache_max,
|
74
|
-
)
|
75
|
-
self.config = config
|
76
|
-
|
77
|
-
@torch.inference_mode
|
78
|
-
def forward(
|
79
|
-
self,
|
80
|
-
tokens: torch.Tensor,
|
81
|
-
input_pos: torch.Tensor,
|
82
|
-
kv_cache: kv_utils.KVCache,
|
83
|
-
) -> dict[torch.Tensor, kv_utils.KVCache]:
|
84
|
-
_, seq_len = tokens.size()
|
85
|
-
assert self.config.max_seq_len >= seq_len, (
|
86
|
-
f"Cannot forward sequence of length {seq_len}, max seq length is only"
|
87
|
-
f" {self.config.max_seq_len}"
|
88
|
-
)
|
89
|
-
assert len(self.transformer_blocks) == len(kv_cache.caches), (
|
90
|
-
"The number of transformer blocks and the number of KV cache entries"
|
91
|
-
" must be the same."
|
92
|
-
)
|
93
|
-
|
94
|
-
cos, sin = self.rope_cache
|
95
|
-
cos = cos.index_select(0, input_pos)
|
96
|
-
sin = sin.index_select(0, input_pos)
|
97
|
-
mask = self.mask_cache.index_select(2, input_pos)
|
98
|
-
mask = mask[:, :, :, : self.config.kv_cache_max]
|
99
|
-
|
100
|
-
# token embeddings of shape (b, t, n_embd)
|
101
|
-
x = self.tok_embedding(tokens)
|
102
|
-
|
103
|
-
updated_kv_entires = []
|
104
|
-
for i, block in enumerate(self.transformer_blocks):
|
105
|
-
kv_entry = kv_cache.caches[i] if kv_cache else None
|
106
|
-
x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
|
107
|
-
if kv_entry:
|
108
|
-
updated_kv_entires.append(kv_entry)
|
109
|
-
updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
|
110
|
-
|
111
|
-
x = self.final_norm(x)
|
112
|
-
logits = self.lm_head(x) # (b, t, vocab_size)
|
113
|
-
return {"logits": logits, "kv_cache": updated_kv_cache}
|
21
|
+
TENSOR_NAMES = model_builder.TENSOR_NAMES_WITH_SEPARATE_LM_HEAD
|
114
22
|
|
115
23
|
|
116
24
|
def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
|
@@ -150,6 +58,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
|
|
150
58
|
kv_cache_max_len=kv_cache_max_len,
|
151
59
|
block_configs=block_config,
|
152
60
|
final_norm_config=norm_config,
|
61
|
+
lm_head_share_weight_with_embedding=False,
|
153
62
|
enable_hlfb=True,
|
154
63
|
)
|
155
64
|
return config
|
@@ -164,10 +73,11 @@ def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
|
|
164
73
|
return config
|
165
74
|
|
166
75
|
|
167
|
-
def build_model(
|
168
|
-
|
169
|
-
|
170
|
-
|
171
|
-
|
172
|
-
|
173
|
-
|
76
|
+
def build_model(
|
77
|
+
checkpoint_path: str, **kwargs
|
78
|
+
) -> model_builder.DecoderOnlyModel:
|
79
|
+
return model_builder.build_decoder_only_model(
|
80
|
+
checkpoint_path=checkpoint_path,
|
81
|
+
config=get_model_config(**kwargs),
|
82
|
+
tensor_names=TENSOR_NAMES,
|
83
|
+
)
|
@@ -184,8 +184,14 @@ class ModelConfig:
|
|
184
184
|
default_factory=NormalizationConfig
|
185
185
|
)
|
186
186
|
|
187
|
+
# Scale factor of the embedding.
|
188
|
+
embedding_scale: Optional[float] = None
|
189
|
+
|
187
190
|
# Use bias term within LLM's HEAD.
|
188
191
|
lm_head_use_bias: bool = False
|
192
|
+
# Whether LLM's HEAD shares the weight of the embedding.
|
193
|
+
lm_head_share_weight_with_embedding: bool = True
|
194
|
+
|
189
195
|
# Whether to turn on high-level function boundary.
|
190
196
|
enable_hlfb: bool = False
|
191
197
|
|
@@ -19,6 +19,7 @@ import tempfile
|
|
19
19
|
|
20
20
|
from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
|
21
21
|
from ai_edge_torch.generative.utilities import loader as loading_utils
|
22
|
+
from ai_edge_torch.generative.utilities import model_builder
|
22
23
|
import safetensors.torch
|
23
24
|
import torch
|
24
25
|
|
@@ -71,7 +72,7 @@ class TestLoader(googletest.TestCase):
|
|
71
72
|
safetensors.torch.save_file(test_weights, file_path)
|
72
73
|
cfg = tiny_llama.get_model_config()
|
73
74
|
cfg.num_layers = 1
|
74
|
-
model =
|
75
|
+
model = model_builder.DecoderOnlyModel(cfg)
|
75
76
|
|
76
77
|
loader = loading_utils.ModelLoader(file_path, tiny_llama.TENSOR_NAMES)
|
77
78
|
# if returns successfully, it means all the tensors were initiallized.
|
@@ -21,6 +21,7 @@ from ai_edge_torch.generative.examples.test_models import toy_model_with_kv_cach
|
|
21
21
|
from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
|
22
22
|
from ai_edge_torch.generative.layers import kv_cache
|
23
23
|
from ai_edge_torch.generative.test import utils as test_utils
|
24
|
+
from ai_edge_torch.generative.utilities import model_builder
|
24
25
|
import numpy as np
|
25
26
|
import torch
|
26
27
|
|
@@ -42,31 +43,40 @@ class TestModelConversion(googletest.TestCase):
|
|
42
43
|
)
|
43
44
|
)
|
44
45
|
|
45
|
-
def
|
46
|
+
def _get_params(self, enable_hlfb: bool):
|
47
|
+
"""Returns a model, edge model and the kwargs to use for testing."""
|
48
|
+
config = toy_model_with_kv_cache.get_model_config()
|
49
|
+
config.enable_hlfb = enable_hlfb
|
50
|
+
pytorch_model = toy_model_with_kv_cache.ToyModelWithKVCache(config).eval()
|
46
51
|
tokens, input_pos = torch.tensor([[1]], dtype=torch.int), torch.tensor(
|
47
52
|
[10], dtype=torch.int
|
48
53
|
)
|
49
54
|
kv = kv_cache.KVCache.from_model_config(config)
|
55
|
+
kwargs = {
|
56
|
+
"tokens": tokens,
|
57
|
+
"input_pos": input_pos,
|
58
|
+
"kv_cache": kv,
|
59
|
+
}
|
50
60
|
|
51
61
|
edge_model = ai_edge_torch.convert(
|
52
62
|
pytorch_model,
|
53
|
-
sample_kwargs=
|
54
|
-
"tokens": tokens,
|
55
|
-
"input_pos": input_pos,
|
56
|
-
"kv_cache": kv,
|
57
|
-
},
|
63
|
+
sample_kwargs=kwargs,
|
58
64
|
)
|
59
65
|
edge_model.set_interpreter_builder(
|
60
66
|
self._interpreter_builder(edge_model.tflite_model())
|
61
67
|
)
|
68
|
+
return pytorch_model, edge_model, kwargs
|
69
|
+
|
70
|
+
def _test_model_with_kv_cache(self, enable_hlfb: bool):
|
71
|
+
pytorch_model, edge_model, kwargs = self._get_params(enable_hlfb)
|
62
72
|
|
63
73
|
self.assertTrue(
|
64
74
|
test_utils.compare_tflite_torch(
|
65
75
|
edge_model,
|
66
76
|
pytorch_model,
|
67
|
-
tokens,
|
68
|
-
input_pos,
|
69
|
-
|
77
|
+
kwargs["tokens"],
|
78
|
+
kwargs["input_pos"],
|
79
|
+
kwargs["kv_cache"],
|
70
80
|
signature_name="serving_default",
|
71
81
|
atol=1e-5,
|
72
82
|
rtol=1e-5,
|
@@ -78,19 +88,31 @@ class TestModelConversion(googletest.TestCase):
|
|
78
88
|
reason="tests with custom ops are not supported on oss",
|
79
89
|
)
|
80
90
|
def test_toy_model_with_kv_cache(self):
|
81
|
-
|
82
|
-
pytorch_model = toy_model_with_kv_cache.ToyModelWithKVCache(config).eval()
|
83
|
-
self._test_model_with_kv_cache(config, pytorch_model)
|
91
|
+
self._test_model_with_kv_cache(enable_hlfb=False)
|
84
92
|
|
85
93
|
@googletest.skipIf(
|
86
94
|
ai_edge_config.Config.use_torch_xla,
|
87
95
|
reason="tests with custom ops are not supported on oss",
|
88
96
|
)
|
89
97
|
def test_toy_model_with_kv_cache_with_hlfb(self):
|
90
|
-
|
91
|
-
|
92
|
-
|
93
|
-
|
98
|
+
self._test_model_with_kv_cache(enable_hlfb=True)
|
99
|
+
|
100
|
+
@googletest.skipIf(
|
101
|
+
ai_edge_config.Config.use_torch_xla,
|
102
|
+
reason="tests with custom ops are not supported on oss",
|
103
|
+
)
|
104
|
+
def test_toy_model_has_ekv_op(self):
|
105
|
+
"""Tests that the model has the external kv cache op."""
|
106
|
+
_, edge_model, _ = self._get_params(enable_hlfb=True)
|
107
|
+
interpreter_ = interpreter.InterpreterWithCustomOps(
|
108
|
+
custom_op_registerers=["GenAIOpsRegisterer"],
|
109
|
+
model_content=edge_model.tflite_model(),
|
110
|
+
experimental_default_delegate_latest_features=True,
|
111
|
+
)
|
112
|
+
|
113
|
+
# pylint: disable=protected-access
|
114
|
+
op_names = [op["op_name"] for op in interpreter_._get_ops_details()]
|
115
|
+
self.assertIn("odml.update_external_kv_cache", op_names)
|
94
116
|
|
95
117
|
def _test_multisig_model(self, config, pytorch_model, atol, rtol):
|
96
118
|
# prefill
|
@@ -163,7 +185,7 @@ class TestModelConversion(googletest.TestCase):
|
|
163
185
|
)
|
164
186
|
def test_tiny_llama_multisig(self):
|
165
187
|
config = tiny_llama.get_fake_model_config()
|
166
|
-
pytorch_model =
|
188
|
+
pytorch_model = model_builder.DecoderOnlyModel(config).eval()
|
167
189
|
self._test_multisig_model(config, pytorch_model, atol=1e-5, rtol=1e-5)
|
168
190
|
|
169
191
|
|
@@ -29,6 +29,7 @@ from ai_edge_torch.generative.examples.stable_diffusion import clip as sd_clip
|
|
29
29
|
from ai_edge_torch.generative.examples.stable_diffusion import decoder as sd_decoder
|
30
30
|
from ai_edge_torch.generative.examples.stable_diffusion import diffusion as sd_diffusion
|
31
31
|
from ai_edge_torch.generative.layers import kv_cache
|
32
|
+
from ai_edge_torch.generative.utilities import model_builder
|
32
33
|
from ai_edge_torch.generative.test import utils as test_utils
|
33
34
|
import numpy as np
|
34
35
|
import torch
|
@@ -90,7 +91,7 @@ class TestModelConversion(googletest.TestCase):
|
|
90
91
|
)
|
91
92
|
def test_gemma1(self):
|
92
93
|
config = gemma1.get_fake_model_config()
|
93
|
-
pytorch_model =
|
94
|
+
pytorch_model = model_builder.DecoderOnlyModel(config).eval()
|
94
95
|
self._test_model(
|
95
96
|
config, pytorch_model, "serving_default", atol=1e-2, rtol=1e-5
|
96
97
|
)
|
@@ -119,7 +120,7 @@ class TestModelConversion(googletest.TestCase):
|
|
119
120
|
)
|
120
121
|
def test_phi2(self):
|
121
122
|
config = phi2.get_fake_model_config()
|
122
|
-
pytorch_model =
|
123
|
+
pytorch_model = model_builder.DecoderOnlyModel(config).eval()
|
123
124
|
self._test_model(
|
124
125
|
config, pytorch_model, "serving_default", atol=1e-3, rtol=1e-3
|
125
126
|
)
|
@@ -139,7 +140,7 @@ class TestModelConversion(googletest.TestCase):
|
|
139
140
|
)
|
140
141
|
def test_smollm(self):
|
141
142
|
config = smollm.get_fake_model_config()
|
142
|
-
pytorch_model =
|
143
|
+
pytorch_model = model_builder.DecoderOnlyModel(config).eval()
|
143
144
|
self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5)
|
144
145
|
|
145
146
|
@googletest.skipIf(
|
@@ -148,7 +149,7 @@ class TestModelConversion(googletest.TestCase):
|
|
148
149
|
)
|
149
150
|
def test_openelm(self):
|
150
151
|
config = openelm.get_fake_model_config()
|
151
|
-
pytorch_model =
|
152
|
+
pytorch_model = model_builder.DecoderOnlyModel(config).eval()
|
152
153
|
self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5)
|
153
154
|
|
154
155
|
@googletest.skipIf(
|
@@ -157,7 +158,7 @@ class TestModelConversion(googletest.TestCase):
|
|
157
158
|
)
|
158
159
|
def test_qwen(self):
|
159
160
|
config = qwen.get_fake_model_config()
|
160
|
-
pytorch_model =
|
161
|
+
pytorch_model = model_builder.DecoderOnlyModel(config).eval()
|
161
162
|
self._test_model(config, pytorch_model, "prefill", atol=1e-3, rtol=1e-5)
|
162
163
|
|
163
164
|
@googletest.skipIf(
|