ai-edge-torch-nightly 0.3.0.dev20240923__py3-none-any.whl → 0.3.0.dev20240925__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_torch/generative/examples/openelm/openelm.py +1 -3
- ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py +68 -0
- ai_edge_torch/generative/examples/phi/phi3.py +286 -0
- ai_edge_torch/generative/examples/phi/verify.py +1 -1
- ai_edge_torch/generative/examples/phi/verify_phi3.py +68 -0
- ai_edge_torch/generative/examples/stable_diffusion/clip.py +52 -1
- ai_edge_torch/generative/examples/stable_diffusion/decoder.py +56 -0
- ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +69 -1
- ai_edge_torch/generative/examples/test_models/convert_toy_model.py +105 -0
- ai_edge_torch/generative/examples/test_models/toy_model.py +2 -31
- ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +2 -56
- ai_edge_torch/generative/layers/builder.py +25 -24
- ai_edge_torch/generative/layers/model_config.py +3 -3
- ai_edge_torch/generative/layers/normalization.py +14 -3
- ai_edge_torch/generative/layers/unet/blocks_2d.py +2 -2
- ai_edge_torch/generative/test/test_model_conversion_large.py +119 -0
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20240923.dist-info → ai_edge_torch_nightly-0.3.0.dev20240925.dist-info}/METADATA +2 -1
- {ai_edge_torch_nightly-0.3.0.dev20240923.dist-info → ai_edge_torch_nightly-0.3.0.dev20240925.dist-info}/RECORD +22 -18
- {ai_edge_torch_nightly-0.3.0.dev20240923.dist-info → ai_edge_torch_nightly-0.3.0.dev20240925.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240923.dist-info → ai_edge_torch_nightly-0.3.0.dev20240925.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240923.dist-info → ai_edge_torch_nightly-0.3.0.dev20240925.dist-info}/top_level.txt +0 -0
@@ -161,9 +161,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
|
|
161
161
|
),
|
162
162
|
ff_config=cfg.FeedForwardConfig(
|
163
163
|
type=cfg.FeedForwardType.SEQUENTIAL,
|
164
|
-
activation=cfg.ActivationConfig(
|
165
|
-
cfg.ActivationType.SILU_GLU, gate_is_front=True
|
166
|
-
),
|
164
|
+
activation=cfg.ActivationConfig(cfg.ActivationType.SILU_GLU),
|
167
165
|
intermediate_size=get_intermediate_size(idx),
|
168
166
|
pre_ff_norm_config=norm_config,
|
169
167
|
),
|
@@ -0,0 +1,68 @@
|
|
1
|
+
# Copyright 2024 The AI Edge Torch Authors.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
"""Example of converting a Phi-3.5 model to multi-signature tflite model."""
|
17
|
+
|
18
|
+
import os
|
19
|
+
import pathlib
|
20
|
+
|
21
|
+
from absl import app
|
22
|
+
from absl import flags
|
23
|
+
from ai_edge_torch.generative.examples.phi import phi3
|
24
|
+
from ai_edge_torch.generative.utilities import converter
|
25
|
+
|
26
|
+
_CHECKPOINT_PATH = flags.DEFINE_string(
|
27
|
+
'checkpoint_path',
|
28
|
+
os.path.join(pathlib.Path.home(), 'Downloads/llm_data/phi3'),
|
29
|
+
'The path to the model checkpoint, or directory holding the checkpoint.',
|
30
|
+
)
|
31
|
+
_TFLITE_PATH = flags.DEFINE_string(
|
32
|
+
'tflite_path',
|
33
|
+
'/tmp/',
|
34
|
+
'The tflite file path to export.',
|
35
|
+
)
|
36
|
+
_PREFILL_SEQ_LEN = flags.DEFINE_integer(
|
37
|
+
'prefill_seq_len',
|
38
|
+
1024,
|
39
|
+
'The maximum size of prefill input tensor.',
|
40
|
+
)
|
41
|
+
_KV_CACHE_MAX_LEN = flags.DEFINE_integer(
|
42
|
+
'kv_cache_max_len',
|
43
|
+
1280,
|
44
|
+
'The maximum size of KV cache buffer, including both prefill and decode.',
|
45
|
+
)
|
46
|
+
_QUANTIZE = flags.DEFINE_bool(
|
47
|
+
'quantize',
|
48
|
+
True,
|
49
|
+
'Whether the model should be quantized.',
|
50
|
+
)
|
51
|
+
|
52
|
+
|
53
|
+
def main(_):
|
54
|
+
pytorch_model = phi3.build_model(
|
55
|
+
_CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
|
56
|
+
)
|
57
|
+
quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
|
58
|
+
output_filename = f'phi3_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
|
59
|
+
converter.convert_to_tflite(
|
60
|
+
pytorch_model,
|
61
|
+
tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
|
62
|
+
prefill_seq_len=_PREFILL_SEQ_LEN.value,
|
63
|
+
quantize=_QUANTIZE.value,
|
64
|
+
)
|
65
|
+
|
66
|
+
|
67
|
+
if __name__ == '__main__':
|
68
|
+
app.run(main)
|
@@ -0,0 +1,286 @@
|
|
1
|
+
# Copyright 2024 The AI Edge Torch Authors.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
"""Example of building a Phi-3.5 model up to 4K tokens, not to 128K tokens."""
|
17
|
+
|
18
|
+
import math
|
19
|
+
from typing import Tuple
|
20
|
+
|
21
|
+
from ai_edge_torch.generative.layers import attention
|
22
|
+
from ai_edge_torch.generative.layers import builder
|
23
|
+
from ai_edge_torch.generative.layers import kv_cache as kv_utils
|
24
|
+
import ai_edge_torch.generative.layers.attention_utils as attn_utils
|
25
|
+
import ai_edge_torch.generative.layers.model_config as cfg
|
26
|
+
import ai_edge_torch.generative.utilities.loader as loading_utils
|
27
|
+
import torch
|
28
|
+
from torch import nn
|
29
|
+
|
30
|
+
TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
|
31
|
+
ff_up_proj="model.layers.{}.mlp.gate_up_proj",
|
32
|
+
ff_down_proj="model.layers.{}.mlp.down_proj",
|
33
|
+
attn_fused_qkv_proj="model.layers.{}.self_attn.qkv_proj",
|
34
|
+
attn_output_proj="model.layers.{}.self_attn.o_proj",
|
35
|
+
pre_attn_norm="model.layers.{}.input_layernorm",
|
36
|
+
post_attn_norm="model.layers.{}.post_attention_layernorm",
|
37
|
+
embedding="model.embed_tokens",
|
38
|
+
final_norm="model.norm",
|
39
|
+
lm_head="lm_head",
|
40
|
+
)
|
41
|
+
|
42
|
+
# max_position_embeddings / original_max_position_embeddings in Phi-3.5 config.
|
43
|
+
ROPE_SCALE_FACTOR = 32
|
44
|
+
|
45
|
+
# ROPE short factor in Phi-3.5 config. According to LOPE paper and its code in
|
46
|
+
# https://github.com/microsoft/LongRoPE, these values had been searched with
|
47
|
+
# min=1.0, step-0.01 to optimize the errors of sample dataset.
|
48
|
+
ROPE_SHORT_FACTOR = [
|
49
|
+
1.0,
|
50
|
+
1.0199999809265137,
|
51
|
+
1.0299999713897705,
|
52
|
+
1.0299999713897705,
|
53
|
+
1.0499999523162842,
|
54
|
+
1.0499999523162842,
|
55
|
+
1.0499999523162842,
|
56
|
+
1.0499999523162842,
|
57
|
+
1.0499999523162842,
|
58
|
+
1.0699999332427979,
|
59
|
+
1.0999999046325684,
|
60
|
+
1.1099998950958252,
|
61
|
+
1.1599998474121094,
|
62
|
+
1.1599998474121094,
|
63
|
+
1.1699998378753662,
|
64
|
+
1.2899998426437378,
|
65
|
+
1.339999794960022,
|
66
|
+
1.679999828338623,
|
67
|
+
1.7899998426437378,
|
68
|
+
1.8199998140335083,
|
69
|
+
1.8499997854232788,
|
70
|
+
1.8799997568130493,
|
71
|
+
1.9099997282028198,
|
72
|
+
1.9399996995925903,
|
73
|
+
1.9899996519088745,
|
74
|
+
2.0199997425079346,
|
75
|
+
2.0199997425079346,
|
76
|
+
2.0199997425079346,
|
77
|
+
2.0199997425079346,
|
78
|
+
2.0199997425079346,
|
79
|
+
2.0199997425079346,
|
80
|
+
2.0299997329711914,
|
81
|
+
2.0299997329711914,
|
82
|
+
2.0299997329711914,
|
83
|
+
2.0299997329711914,
|
84
|
+
2.0299997329711914,
|
85
|
+
2.0299997329711914,
|
86
|
+
2.0299997329711914,
|
87
|
+
2.0299997329711914,
|
88
|
+
2.0299997329711914,
|
89
|
+
2.0799996852874756,
|
90
|
+
2.0899996757507324,
|
91
|
+
2.189999580383301,
|
92
|
+
2.2199995517730713,
|
93
|
+
2.5899994373321533,
|
94
|
+
2.729999542236328,
|
95
|
+
2.749999523162842,
|
96
|
+
2.8399994373321533,
|
97
|
+
]
|
98
|
+
|
99
|
+
|
100
|
+
def build_rope_cache(
|
101
|
+
size: int,
|
102
|
+
dim: int,
|
103
|
+
base: int = 10000,
|
104
|
+
condense_ratio: int = 1,
|
105
|
+
dtype: torch.dtype = torch.float32,
|
106
|
+
device: torch.device = None,
|
107
|
+
theta_factors: torch.Tensor = None,
|
108
|
+
scale: float = 1.0,
|
109
|
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
110
|
+
"""Precomputes Rotary Positional Embeddings for Phi-3.5 model.
|
111
|
+
|
112
|
+
It's a modified version of attn_utils.build_rope_cache with additional
|
113
|
+
arguments for Phi-3.5 model. It precompute Rotary Positional Embedding Sin and
|
114
|
+
Cos values with scaling factors for quick lookup during the inference.
|
115
|
+
|
116
|
+
Args:
|
117
|
+
size (int): The size of the built cache.
|
118
|
+
dim (int): Each sequence's dimmension.
|
119
|
+
base (int, optional): Rope base value. Defaults to 10000.
|
120
|
+
condense_ratio (int, optional): The ratio by which sequence indicies are
|
121
|
+
condensed. Defaults to 1.
|
122
|
+
dtype (torch.dtype, optional): Output tensor's data type. Defaults to
|
123
|
+
torch.float32.
|
124
|
+
device (torch.device, optional): Output tensor's data type. Defaults to
|
125
|
+
None in which case "cpu" is used.
|
126
|
+
theta_factors (torch.Tensor, optional): A tensor of shape (dim,) used to
|
127
|
+
scale the theta values. Defaults to None.
|
128
|
+
scale (float, optional): A float used to scale the rope values. Defaults
|
129
|
+
to 1.0.
|
130
|
+
|
131
|
+
Returns:
|
132
|
+
Tuple[torch.Tensor, torch.Tensor]: Rope's Cosine and Sine waves.
|
133
|
+
"""
|
134
|
+
if device is None:
|
135
|
+
device = torch.device('cpu')
|
136
|
+
theta = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
|
137
|
+
if theta_factors is not None:
|
138
|
+
theta = theta / theta_factors
|
139
|
+
seq_idx = torch.arange(size) / condense_ratio
|
140
|
+
idx_theta = torch.outer(seq_idx, theta)
|
141
|
+
cos = torch.cos(idx_theta).to(dtype=dtype, device=device) * scale
|
142
|
+
sin = torch.sin(idx_theta).to(dtype=dtype, device=device) * scale
|
143
|
+
return cos, sin
|
144
|
+
|
145
|
+
|
146
|
+
class Phi3_5Mini(nn.Module):
|
147
|
+
"""A Phi-3.5 model built from the Edge Generative API layers."""
|
148
|
+
|
149
|
+
def __init__(self, config: cfg.ModelConfig):
|
150
|
+
super().__init__()
|
151
|
+
|
152
|
+
# Construct model layers.
|
153
|
+
self.lm_head = nn.Linear(
|
154
|
+
config.embedding_dim, config.vocab_size, bias=config.lm_head_use_bias
|
155
|
+
)
|
156
|
+
self.tok_embedding = nn.Embedding(
|
157
|
+
config.vocab_size, config.embedding_dim, padding_idx=0
|
158
|
+
)
|
159
|
+
# Phi-3.5 has only one block config.
|
160
|
+
block_config = config.block_config(0)
|
161
|
+
self.transformer_blocks = nn.ModuleList(
|
162
|
+
attention.TransformerBlock(block_config, config)
|
163
|
+
for _ in range(config.num_layers)
|
164
|
+
)
|
165
|
+
self.final_norm = builder.build_norm(
|
166
|
+
config.embedding_dim,
|
167
|
+
config.final_norm_config,
|
168
|
+
)
|
169
|
+
attn_config = block_config.attn_config
|
170
|
+
self.rope_cache = build_rope_cache(
|
171
|
+
size=config.kv_cache_max,
|
172
|
+
dim=int(attn_config.rotary_percentage * attn_config.head_dim),
|
173
|
+
base=10_000,
|
174
|
+
condense_ratio=1,
|
175
|
+
dtype=torch.float32,
|
176
|
+
device=torch.device("cpu"),
|
177
|
+
theta_factors=torch.tensor(ROPE_SHORT_FACTOR),
|
178
|
+
scale=math.sqrt(
|
179
|
+
1 + math.log(ROPE_SCALE_FACTOR) / math.log(config.max_seq_len)
|
180
|
+
),
|
181
|
+
)
|
182
|
+
self.mask_cache = attn_utils.build_causal_mask_cache(
|
183
|
+
size=config.kv_cache_max,
|
184
|
+
dtype=torch.float32,
|
185
|
+
device=torch.device("cpu"),
|
186
|
+
)
|
187
|
+
self.config = config
|
188
|
+
|
189
|
+
@torch.inference_mode
|
190
|
+
def forward(
|
191
|
+
self,
|
192
|
+
tokens: torch.Tensor,
|
193
|
+
input_pos: torch.Tensor,
|
194
|
+
kv_cache: kv_utils.KVCache,
|
195
|
+
) -> dict[torch.Tensor, kv_utils.KVCache]:
|
196
|
+
_, seq_len = tokens.size()
|
197
|
+
assert self.config.max_seq_len >= seq_len, (
|
198
|
+
f"Cannot forward sequence of length {seq_len}, max seq length is only"
|
199
|
+
f" {self.config.max_seq_len}"
|
200
|
+
)
|
201
|
+
assert len(self.transformer_blocks) == len(kv_cache.caches), (
|
202
|
+
"The number of transformer blocks and the number of KV cache entries"
|
203
|
+
" must be the same."
|
204
|
+
)
|
205
|
+
|
206
|
+
cos, sin = self.rope_cache
|
207
|
+
cos = cos.index_select(0, input_pos)
|
208
|
+
sin = sin.index_select(0, input_pos)
|
209
|
+
mask = self.mask_cache.index_select(2, input_pos)
|
210
|
+
mask = mask[:, :, :, : self.config.kv_cache_max]
|
211
|
+
|
212
|
+
x = self.tok_embedding(tokens)
|
213
|
+
|
214
|
+
updated_kv_entires = []
|
215
|
+
for i, block in enumerate(self.transformer_blocks):
|
216
|
+
kv_entry = kv_cache.caches[i] if kv_cache else None
|
217
|
+
x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
|
218
|
+
if kv_entry:
|
219
|
+
updated_kv_entires.append(kv_entry)
|
220
|
+
updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
|
221
|
+
|
222
|
+
x = self.final_norm(x)
|
223
|
+
logits = self.lm_head(x) # (b, t, vocab_size)
|
224
|
+
return {"logits": logits, "kv_cache": updated_kv_cache}
|
225
|
+
|
226
|
+
|
227
|
+
def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
|
228
|
+
"""Returns the model config for a Phi-3.5 model.
|
229
|
+
|
230
|
+
Args:
|
231
|
+
kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
|
232
|
+
is 1024.
|
233
|
+
|
234
|
+
Returns:
|
235
|
+
The model config for a Phi-2 model.
|
236
|
+
"""
|
237
|
+
attn_config = cfg.AttentionConfig(
|
238
|
+
num_heads=32,
|
239
|
+
head_dim=96,
|
240
|
+
num_query_groups=32,
|
241
|
+
rotary_percentage=1.0,
|
242
|
+
qkv_transpose_before_split=True,
|
243
|
+
)
|
244
|
+
ff_config = cfg.FeedForwardConfig(
|
245
|
+
type=cfg.FeedForwardType.SEQUENTIAL,
|
246
|
+
activation=cfg.ActivationConfig(cfg.ActivationType.SILU_GLU),
|
247
|
+
intermediate_size=8192,
|
248
|
+
)
|
249
|
+
norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.RMS_NORM)
|
250
|
+
block_config = cfg.TransformerBlockConfig(
|
251
|
+
attn_config=attn_config,
|
252
|
+
ff_config=ff_config,
|
253
|
+
pre_attention_norm_config=norm_config,
|
254
|
+
post_attention_norm_config=norm_config,
|
255
|
+
)
|
256
|
+
config = cfg.ModelConfig(
|
257
|
+
vocab_size=32064,
|
258
|
+
num_layers=32,
|
259
|
+
max_seq_len=4096,
|
260
|
+
kv_cache_max_len=kv_cache_max_len,
|
261
|
+
embedding_dim=3072,
|
262
|
+
block_configs=block_config,
|
263
|
+
final_norm_config=norm_config,
|
264
|
+
enable_hlfb=True,
|
265
|
+
)
|
266
|
+
return config
|
267
|
+
|
268
|
+
|
269
|
+
def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
|
270
|
+
config = get_model_config(kv_cache_max_len)
|
271
|
+
config.vocab_size = 128
|
272
|
+
config.num_layers = 2
|
273
|
+
config.max_seq_len = 2 * kv_cache_max_len
|
274
|
+
# Phi-3.5 has only one block config.
|
275
|
+
config.block_config(0).ff_config.intermediate_size = 128
|
276
|
+
return config
|
277
|
+
|
278
|
+
|
279
|
+
def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
|
280
|
+
"""Instantiates the model instance and load checkpoint if provided."""
|
281
|
+
config = get_model_config(**kwargs)
|
282
|
+
model = Phi3_5Mini(config)
|
283
|
+
loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
|
284
|
+
loader.load(model)
|
285
|
+
model.eval()
|
286
|
+
return model
|
@@ -27,13 +27,13 @@ _PROMPTS = flags.DEFINE_multi_string(
|
|
27
27
|
"Instruct: Write an email about the weather Output:",
|
28
28
|
"The input prompts to generate answers.",
|
29
29
|
)
|
30
|
-
|
31
30
|
_MAX_NEW_TOKENS = flags.DEFINE_integer(
|
32
31
|
"max_new_tokens",
|
33
32
|
30,
|
34
33
|
"The maximum size of the generated tokens.",
|
35
34
|
)
|
36
35
|
|
36
|
+
|
37
37
|
def main(_):
|
38
38
|
checkpoint = kagglehub.model_download("Microsoft/phi/transformers/2")
|
39
39
|
verifier.log_msg("Loading the original model from", checkpoint)
|
@@ -0,0 +1,68 @@
|
|
1
|
+
# Copyright 2024 The AI Edge Torch Authors.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
"""Verifies the reauthored Phi-3.5 model."""
|
17
|
+
|
18
|
+
import pathlib
|
19
|
+
|
20
|
+
from absl import app
|
21
|
+
from absl import flags
|
22
|
+
from ai_edge_torch.generative.examples.phi import phi3
|
23
|
+
from ai_edge_torch.generative.utilities import verifier
|
24
|
+
import transformers
|
25
|
+
|
26
|
+
_PROMPTS = flags.DEFINE_multi_string(
|
27
|
+
"prompts",
|
28
|
+
"Instruct: Write an email about the weather Output:",
|
29
|
+
"The input prompts to generate answers.",
|
30
|
+
)
|
31
|
+
_MAX_NEW_TOKENS = flags.DEFINE_integer(
|
32
|
+
"max_new_tokens",
|
33
|
+
30,
|
34
|
+
"The maximum size of the generated tokens.",
|
35
|
+
)
|
36
|
+
|
37
|
+
|
38
|
+
def main(_):
|
39
|
+
checkpoint = "microsoft/Phi-3.5-mini-instruct"
|
40
|
+
verifier.log_msg("Loading the original model from", checkpoint)
|
41
|
+
generation_config = transformers.GenerationConfig.from_pretrained(checkpoint)
|
42
|
+
generation_config.max_new_tokens = _MAX_NEW_TOKENS.value
|
43
|
+
wrapper_model = verifier.ModelWrapper(
|
44
|
+
model=transformers.AutoModelForCausalLM.from_pretrained(checkpoint),
|
45
|
+
hf_generation_config=generation_config,
|
46
|
+
)
|
47
|
+
|
48
|
+
# Locate the cached dir.
|
49
|
+
cached_config_file = transformers.utils.cached_file(
|
50
|
+
checkpoint, transformers.utils.CONFIG_NAME
|
51
|
+
)
|
52
|
+
reauthored_checkpoint = pathlib.Path(cached_config_file).parent
|
53
|
+
verifier.log_msg("Building the reauthored model from", reauthored_checkpoint)
|
54
|
+
reauthored_model = phi3.build_model(reauthored_checkpoint)
|
55
|
+
|
56
|
+
verifier.log_msg("Loading the tokenizer from", checkpoint)
|
57
|
+
tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
|
58
|
+
|
59
|
+
verifier.verify_reauthored_model(
|
60
|
+
original_model=wrapper_model,
|
61
|
+
reauthored_model=reauthored_model,
|
62
|
+
tokenizer=tokenizer,
|
63
|
+
generate_prompts=_PROMPTS.value,
|
64
|
+
)
|
65
|
+
|
66
|
+
|
67
|
+
if __name__ == "__main__":
|
68
|
+
app.run(main)
|
@@ -48,7 +48,7 @@ TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
|
|
48
48
|
|
49
49
|
|
50
50
|
class CLIP(nn.Module):
|
51
|
-
"""CLIP text encoder
|
51
|
+
"""CLIP text encoder.
|
52
52
|
|
53
53
|
For details, see https://arxiv.org/abs/2103.00020
|
54
54
|
"""
|
@@ -86,6 +86,7 @@ class CLIP(nn.Module):
|
|
86
86
|
|
87
87
|
|
88
88
|
def get_model_config() -> cfg.ModelConfig:
|
89
|
+
"""Get configs for the CLIP of Stable Diffusion v1.5."""
|
89
90
|
max_seq_len = 77
|
90
91
|
vocab_size = 49408
|
91
92
|
num_layers = 12
|
@@ -132,3 +133,53 @@ def get_model_config() -> cfg.ModelConfig:
|
|
132
133
|
)
|
133
134
|
|
134
135
|
return config
|
136
|
+
|
137
|
+
|
138
|
+
def get_fake_model_config() -> cfg.ModelConfig:
|
139
|
+
"""Get fake configs for the CLIP of Stable Diffusion v1.5 for testing."""
|
140
|
+
max_seq_len = 6
|
141
|
+
vocab_size = 100
|
142
|
+
num_layers = 2
|
143
|
+
num_heads = 12
|
144
|
+
num_query_groups = 12
|
145
|
+
embedding_dim = 24
|
146
|
+
|
147
|
+
attn_config = cfg.AttentionConfig(
|
148
|
+
num_heads=num_heads,
|
149
|
+
head_dim=embedding_dim // num_heads,
|
150
|
+
num_query_groups=num_query_groups,
|
151
|
+
rotary_percentage=0.0,
|
152
|
+
qkv_use_bias=True,
|
153
|
+
qkv_transpose_before_split=True,
|
154
|
+
qkv_fused_interleaved=False,
|
155
|
+
output_proj_use_bias=True,
|
156
|
+
enable_kv_cache=False,
|
157
|
+
)
|
158
|
+
|
159
|
+
ff_config = cfg.FeedForwardConfig(
|
160
|
+
type=cfg.FeedForwardType.SEQUENTIAL,
|
161
|
+
activation=cfg.ActivationConfig(cfg.ActivationType.GELU_QUICK),
|
162
|
+
intermediate_size=embedding_dim * 4,
|
163
|
+
use_bias=True,
|
164
|
+
)
|
165
|
+
|
166
|
+
norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.LAYER_NORM)
|
167
|
+
|
168
|
+
block_config = cfg.TransformerBlockConfig(
|
169
|
+
attn_config=attn_config,
|
170
|
+
ff_config=ff_config,
|
171
|
+
pre_attention_norm_config=norm_config,
|
172
|
+
post_attention_norm_config=norm_config,
|
173
|
+
)
|
174
|
+
|
175
|
+
config = cfg.ModelConfig(
|
176
|
+
vocab_size=vocab_size,
|
177
|
+
num_layers=num_layers,
|
178
|
+
max_seq_len=max_seq_len,
|
179
|
+
embedding_dim=embedding_dim,
|
180
|
+
block_configs=block_config,
|
181
|
+
final_norm_config=norm_config,
|
182
|
+
enable_hlfb=True,
|
183
|
+
)
|
184
|
+
|
185
|
+
return config
|
@@ -324,3 +324,59 @@ def get_model_config() -> unet_cfg.AutoEncoderConfig:
|
|
324
324
|
mid_block_config=mid_block_config,
|
325
325
|
)
|
326
326
|
return config
|
327
|
+
|
328
|
+
|
329
|
+
def get_fake_model_config() -> unet_cfg.AutoEncoderConfig:
|
330
|
+
"""Get fake configs for the Decoder of Stable Diffusion v1.5 for testing."""
|
331
|
+
in_channels = 3
|
332
|
+
latent_channels = 4
|
333
|
+
out_channels = 3
|
334
|
+
block_out_channels = [2, 4]
|
335
|
+
scaling_factor = 0.18215
|
336
|
+
layers_per_block = 2
|
337
|
+
|
338
|
+
norm_config = layers_cfg.NormalizationConfig(
|
339
|
+
layers_cfg.NormalizationType.GROUP_NORM, group_num=2
|
340
|
+
)
|
341
|
+
|
342
|
+
att_config = unet_cfg.AttentionBlock2DConfig(
|
343
|
+
dim=block_out_channels[-1],
|
344
|
+
normalization_config=norm_config,
|
345
|
+
attention_config=layers_cfg.AttentionConfig(
|
346
|
+
num_heads=1,
|
347
|
+
head_dim=block_out_channels[-1],
|
348
|
+
num_query_groups=1,
|
349
|
+
qkv_use_bias=True,
|
350
|
+
output_proj_use_bias=True,
|
351
|
+
enable_kv_cache=False,
|
352
|
+
qkv_transpose_before_split=True,
|
353
|
+
qkv_fused_interleaved=False,
|
354
|
+
rotary_percentage=0.0,
|
355
|
+
),
|
356
|
+
enable_hlfb=False,
|
357
|
+
)
|
358
|
+
|
359
|
+
mid_block_config = unet_cfg.MidBlock2DConfig(
|
360
|
+
in_channels=block_out_channels[-1],
|
361
|
+
normalization_config=norm_config,
|
362
|
+
activation_config=layers_cfg.ActivationConfig(
|
363
|
+
layers_cfg.ActivationType.SILU
|
364
|
+
),
|
365
|
+
num_layers=1,
|
366
|
+
attention_block_config=att_config,
|
367
|
+
)
|
368
|
+
|
369
|
+
config = unet_cfg.AutoEncoderConfig(
|
370
|
+
in_channels=in_channels,
|
371
|
+
latent_channels=latent_channels,
|
372
|
+
out_channels=out_channels,
|
373
|
+
activation_config=layers_cfg.ActivationConfig(
|
374
|
+
layers_cfg.ActivationType.SILU
|
375
|
+
),
|
376
|
+
block_out_channels=block_out_channels,
|
377
|
+
scaling_factor=scaling_factor,
|
378
|
+
layers_per_block=layers_per_block,
|
379
|
+
normalization_config=norm_config,
|
380
|
+
mid_block_config=mid_block_config,
|
381
|
+
)
|
382
|
+
return config
|
@@ -603,7 +603,7 @@ def get_model_config(batch_size: int) -> unet_cfg.DiffusionModelConfig:
|
|
603
603
|
# Transformer configs.
|
604
604
|
transformer_num_attention_heads = 8
|
605
605
|
transformer_batch_size = batch_size
|
606
|
-
transformer_cross_attention_dim = 768 # Embedding
|
606
|
+
transformer_cross_attention_dim = 768 # Embedding from CLIP model
|
607
607
|
transformer_pre_conv_norm_config = layers_cfg.NormalizationConfig(
|
608
608
|
layers_cfg.NormalizationType.GROUP_NORM, epsilon=1e-6, group_num=32
|
609
609
|
)
|
@@ -645,3 +645,71 @@ def get_model_config(batch_size: int) -> unet_cfg.DiffusionModelConfig:
|
|
645
645
|
final_norm_config=final_norm_config,
|
646
646
|
final_activation_type=final_activation_type,
|
647
647
|
)
|
648
|
+
|
649
|
+
|
650
|
+
def get_fake_model_config(batch_size: int) -> unet_cfg.DiffusionModelConfig:
|
651
|
+
"""Get fake configs for the Diffusion model of Stable Diffusion v1.5 for testing.
|
652
|
+
|
653
|
+
Args:
|
654
|
+
batch_size (int): the batch size of input.
|
655
|
+
|
656
|
+
Retruns:
|
657
|
+
The configuration of diffusion model of Stable Diffusion v1.5.
|
658
|
+
"""
|
659
|
+
in_channels = 4
|
660
|
+
out_channels = 4
|
661
|
+
block_out_channels = [2, 4, 8, 8]
|
662
|
+
layers_per_block = 1
|
663
|
+
downsample_padding = 1
|
664
|
+
|
665
|
+
# Residual configs.
|
666
|
+
residual_norm_config = layers_cfg.NormalizationConfig(
|
667
|
+
layers_cfg.NormalizationType.GROUP_NORM, group_num=2
|
668
|
+
)
|
669
|
+
residual_activation_type = layers_cfg.ActivationType.SILU
|
670
|
+
|
671
|
+
# Transformer configs.
|
672
|
+
transformer_num_attention_heads = 1
|
673
|
+
transformer_batch_size = batch_size
|
674
|
+
transformer_cross_attention_dim = 4 # Embedding from CLIP model
|
675
|
+
transformer_pre_conv_norm_config = layers_cfg.NormalizationConfig(
|
676
|
+
layers_cfg.NormalizationType.GROUP_NORM, epsilon=1e-6, group_num=2
|
677
|
+
)
|
678
|
+
transformer_norm_config = layers_cfg.NormalizationConfig(
|
679
|
+
layers_cfg.NormalizationType.LAYER_NORM
|
680
|
+
)
|
681
|
+
transformer_ff_activation_type = layers_cfg.ActivationType.GE_GLU
|
682
|
+
|
683
|
+
# Time embedding configs.
|
684
|
+
time_embedding_dim = 2
|
685
|
+
time_embedding_blocks_dim = 4
|
686
|
+
|
687
|
+
# Mid block configs.
|
688
|
+
mid_block_layers = 1
|
689
|
+
|
690
|
+
# Finaly layer configs.
|
691
|
+
final_norm_config = layers_cfg.NormalizationConfig(
|
692
|
+
layers_cfg.NormalizationType.GROUP_NORM, group_num=2
|
693
|
+
)
|
694
|
+
final_activation_type = layers_cfg.ActivationType.SILU
|
695
|
+
|
696
|
+
return unet_cfg.DiffusionModelConfig(
|
697
|
+
in_channels=in_channels,
|
698
|
+
out_channels=out_channels,
|
699
|
+
block_out_channels=block_out_channels,
|
700
|
+
layers_per_block=layers_per_block,
|
701
|
+
downsample_padding=downsample_padding,
|
702
|
+
residual_norm_config=residual_norm_config,
|
703
|
+
residual_activation_type=residual_activation_type,
|
704
|
+
transformer_batch_size=transformer_batch_size,
|
705
|
+
transformer_num_attention_heads=transformer_num_attention_heads,
|
706
|
+
transformer_cross_attention_dim=transformer_cross_attention_dim,
|
707
|
+
transformer_pre_conv_norm_config=transformer_pre_conv_norm_config,
|
708
|
+
transformer_norm_config=transformer_norm_config,
|
709
|
+
transformer_ff_activation_type=transformer_ff_activation_type,
|
710
|
+
mid_block_layers=mid_block_layers,
|
711
|
+
time_embedding_dim=time_embedding_dim,
|
712
|
+
time_embedding_blocks_dim=time_embedding_blocks_dim,
|
713
|
+
final_norm_config=final_norm_config,
|
714
|
+
final_activation_type=final_activation_type,
|
715
|
+
)
|
@@ -0,0 +1,105 @@
|
|
1
|
+
# Copyright 2024 The AI Edge Torch Authors.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
# A toy example which has a single-layer transformer block.
|
16
|
+
from absl import app
|
17
|
+
import ai_edge_torch
|
18
|
+
from ai_edge_torch import lowertools
|
19
|
+
from ai_edge_torch.generative.examples.test_models import toy_model
|
20
|
+
from ai_edge_torch.generative.examples.test_models import toy_model_with_kv_cache
|
21
|
+
from ai_edge_torch.generative.layers import kv_cache as kv_utils
|
22
|
+
import torch
|
23
|
+
|
24
|
+
KV_CACHE_MAX_LEN = 100
|
25
|
+
|
26
|
+
|
27
|
+
def convert_toy_model(_) -> None:
|
28
|
+
"""Converts a toy model to tflite."""
|
29
|
+
model = toy_model.ToySingleLayerModel(toy_model.get_model_config())
|
30
|
+
idx = torch.unsqueeze(torch.arange(0, KV_CACHE_MAX_LEN), 0)
|
31
|
+
input_pos = torch.arange(0, KV_CACHE_MAX_LEN)
|
32
|
+
print('running an inference')
|
33
|
+
print(
|
34
|
+
model.forward(
|
35
|
+
idx,
|
36
|
+
input_pos,
|
37
|
+
)
|
38
|
+
)
|
39
|
+
|
40
|
+
# Convert model to tflite.
|
41
|
+
print('converting model to tflite')
|
42
|
+
edge_model = ai_edge_torch.convert(
|
43
|
+
model,
|
44
|
+
(
|
45
|
+
idx,
|
46
|
+
input_pos,
|
47
|
+
),
|
48
|
+
)
|
49
|
+
edge_model.export('/tmp/toy_model.tflite')
|
50
|
+
|
51
|
+
|
52
|
+
def _export_stablehlo_mlir(model, args):
|
53
|
+
ep = torch.export.export(model, args)
|
54
|
+
return lowertools.exported_program_to_mlir_text(ep)
|
55
|
+
|
56
|
+
|
57
|
+
def convert_toy_model_with_kv_cache(_) -> None:
|
58
|
+
"""Converts a toy model with kv cache to tflite."""
|
59
|
+
dump_mlir = False
|
60
|
+
|
61
|
+
config = toy_model_with_kv_cache.get_model_config()
|
62
|
+
model = toy_model_with_kv_cache.ToyModelWithKVCache(config)
|
63
|
+
model.eval()
|
64
|
+
print('running an inference')
|
65
|
+
kv = kv_utils.KVCache.from_model_config(config)
|
66
|
+
|
67
|
+
tokens, input_pos = toy_model_with_kv_cache.get_sample_prefill_inputs()
|
68
|
+
decode_token, decode_input_pos = (
|
69
|
+
toy_model_with_kv_cache.get_sample_decode_inputs()
|
70
|
+
)
|
71
|
+
print(model.forward(tokens, input_pos, kv))
|
72
|
+
|
73
|
+
if dump_mlir:
|
74
|
+
mlir_text = _export_stablehlo_mlir(model, (tokens, input_pos, kv))
|
75
|
+
with open('/tmp/toy_model_with_external_kv.stablehlo.mlir', 'w') as f:
|
76
|
+
f.write(mlir_text)
|
77
|
+
|
78
|
+
# Convert model to tflite with 2 signatures (prefill + decode).
|
79
|
+
print('converting toy model to tflite with 2 signatures (prefill + decode)')
|
80
|
+
edge_model = (
|
81
|
+
ai_edge_torch.signature(
|
82
|
+
'prefill',
|
83
|
+
model,
|
84
|
+
sample_kwargs={
|
85
|
+
'tokens': tokens,
|
86
|
+
'input_pos': input_pos,
|
87
|
+
'kv_cache': kv,
|
88
|
+
},
|
89
|
+
)
|
90
|
+
.signature(
|
91
|
+
'decode',
|
92
|
+
model,
|
93
|
+
sample_kwargs={
|
94
|
+
'tokens': decode_token,
|
95
|
+
'input_pos': decode_input_pos,
|
96
|
+
'kv_cache': kv,
|
97
|
+
},
|
98
|
+
)
|
99
|
+
.convert()
|
100
|
+
)
|
101
|
+
edge_model.export('/tmp/toy_external_kv_cache.tflite')
|
102
|
+
|
103
|
+
|
104
|
+
if __name__ == '__main__':
|
105
|
+
app.run(convert_toy_model)
|
@@ -15,13 +15,12 @@
|
|
15
15
|
# A toy example which has a single-layer transformer block.
|
16
16
|
from typing import Tuple
|
17
17
|
|
18
|
-
import
|
18
|
+
from ai_edge_torch.generative.layers import builder
|
19
19
|
from ai_edge_torch.generative.layers.attention import TransformerBlock
|
20
20
|
import ai_edge_torch.generative.layers.attention_utils as attn_utils
|
21
|
-
import ai_edge_torch.generative.layers.builder as builder
|
22
21
|
import ai_edge_torch.generative.layers.model_config as cfg
|
23
22
|
import torch
|
24
|
-
|
23
|
+
from torch import nn
|
25
24
|
|
26
25
|
RoPECache = Tuple[torch.Tensor, torch.Tensor]
|
27
26
|
KV_CACHE_MAX_LEN = 100
|
@@ -149,31 +148,3 @@ def get_model_config() -> cfg.ModelConfig:
|
|
149
148
|
final_norm_config=norm_config,
|
150
149
|
)
|
151
150
|
return config
|
152
|
-
|
153
|
-
|
154
|
-
def define_and_run() -> None:
|
155
|
-
model = ToySingleLayerModel(get_model_config())
|
156
|
-
idx = torch.unsqueeze(torch.arange(0, KV_CACHE_MAX_LEN), 0)
|
157
|
-
input_pos = torch.arange(0, KV_CACHE_MAX_LEN)
|
158
|
-
print('running an inference')
|
159
|
-
print(
|
160
|
-
model.forward(
|
161
|
-
idx,
|
162
|
-
input_pos,
|
163
|
-
)
|
164
|
-
)
|
165
|
-
|
166
|
-
# Convert model to tflite.
|
167
|
-
print('converting model to tflite')
|
168
|
-
edge_model = ai_edge_torch.convert(
|
169
|
-
model,
|
170
|
-
(
|
171
|
-
idx,
|
172
|
-
input_pos,
|
173
|
-
),
|
174
|
-
)
|
175
|
-
edge_model.export('/tmp/toy_model.tflite')
|
176
|
-
|
177
|
-
|
178
|
-
if __name__ == '__main__':
|
179
|
-
define_and_run()
|
@@ -17,15 +17,14 @@
|
|
17
17
|
|
18
18
|
from typing import Tuple
|
19
19
|
|
20
|
-
import
|
21
|
-
from ai_edge_torch import lowertools
|
20
|
+
from absl import app
|
22
21
|
from ai_edge_torch.generative.layers import attention
|
23
22
|
from ai_edge_torch.generative.layers import builder
|
24
23
|
from ai_edge_torch.generative.layers import kv_cache as kv_utils
|
25
24
|
import ai_edge_torch.generative.layers.attention_utils as attn_utils
|
26
25
|
import ai_edge_torch.generative.layers.model_config as cfg
|
27
26
|
import torch
|
28
|
-
|
27
|
+
from torch import nn
|
29
28
|
|
30
29
|
RoPECache = Tuple[torch.Tensor, torch.Tensor]
|
31
30
|
|
@@ -87,11 +86,6 @@ class ToyModelWithKVCache(torch.nn.Module):
|
|
87
86
|
return {'logits': self.lm_head(x), 'kv_cache': updated_kv_cache}
|
88
87
|
|
89
88
|
|
90
|
-
def _export_stablehlo_mlir(model, args):
|
91
|
-
ep = torch.export.export(model, args)
|
92
|
-
return lowertools.exported_program_to_mlir_text(ep)
|
93
|
-
|
94
|
-
|
95
89
|
def get_model_config() -> cfg.ModelConfig:
|
96
90
|
attn_config = cfg.AttentionConfig(
|
97
91
|
num_heads=32,
|
@@ -133,51 +127,3 @@ def get_sample_decode_inputs() -> Tuple[torch.Tensor, torch.Tensor]:
|
|
133
127
|
tokens = torch.tensor([[1]], dtype=torch.int)
|
134
128
|
input_pos = torch.tensor([10])
|
135
129
|
return tokens, input_pos
|
136
|
-
|
137
|
-
|
138
|
-
def define_and_run() -> None:
|
139
|
-
dump_mlir = False
|
140
|
-
|
141
|
-
config = get_model_config()
|
142
|
-
model = ToyModelWithExternalKV(config)
|
143
|
-
model.eval()
|
144
|
-
print('running an inference')
|
145
|
-
kv = kv_utils.KVCache.from_model_config(config)
|
146
|
-
|
147
|
-
tokens, input_pos = get_sample_prefill_inputs()
|
148
|
-
decode_token, decode_input_pos = get_sample_decode_inputs()
|
149
|
-
print(model.forward(tokens, input_pos, kv))
|
150
|
-
|
151
|
-
if dump_mlir:
|
152
|
-
mlir_text = _export_stablehlo_mlir(model, (tokens, input_pos, kv))
|
153
|
-
with open('/tmp/toy_model_with_external_kv.stablehlo.mlir', 'w') as f:
|
154
|
-
f.write(mlir_text)
|
155
|
-
|
156
|
-
# Convert model to tflite with 2 signatures (prefill + decode).
|
157
|
-
print('converting toy model to tflite with 2 signatures (prefill + decode)')
|
158
|
-
edge_model = (
|
159
|
-
ai_edge_torch.signature(
|
160
|
-
'prefill',
|
161
|
-
model,
|
162
|
-
sample_kwargs={
|
163
|
-
'tokens': tokens,
|
164
|
-
'input_pos': input_pos,
|
165
|
-
'kv_cache': kv,
|
166
|
-
},
|
167
|
-
)
|
168
|
-
.signature(
|
169
|
-
'decode',
|
170
|
-
model,
|
171
|
-
sample_kwargs={
|
172
|
-
'tokens': decode_token,
|
173
|
-
'input_pos': decode_input_pos,
|
174
|
-
'kv_cache': kv,
|
175
|
-
},
|
176
|
-
)
|
177
|
-
.convert()
|
178
|
-
)
|
179
|
-
edge_model.export('/tmp/toy_external_kv_cache.tflite')
|
180
|
-
|
181
|
-
|
182
|
-
if __name__ == '__main__':
|
183
|
-
define_and_run()
|
@@ -23,34 +23,35 @@ from torch import nn
|
|
23
23
|
import torch.nn.functional as F
|
24
24
|
|
25
25
|
|
26
|
-
|
27
|
-
|
28
|
-
) -> Callable[[torch.Tensor], torch.Tensor]:
|
29
|
-
"""Builds an activation function with GLU (Gated Linear Unit).
|
26
|
+
class GeGLU(nn.Module):
|
27
|
+
"""GeGLU is an activation function which is a variant of GELU.
|
30
28
|
|
31
|
-
|
32
|
-
|
33
|
-
|
34
|
-
f(x) = x * act(y),
|
35
|
-
where x is the first half of the input and y is the second half of the input.
|
29
|
+
GeGLU(x) = (xW+b) * GELU(xV+c)
|
30
|
+
See: https://arxiv.org/abs/2002.05202v1
|
31
|
+
"""
|
36
32
|
|
37
|
-
|
38
|
-
|
39
|
-
|
40
|
-
gate_is_front: whether the gate is in front half of the input. Other part is
|
41
|
-
the output in GLU.
|
33
|
+
def __init__(self, d_in: int, d_out: int):
|
34
|
+
super().__init__()
|
35
|
+
self.proj = nn.Linear(d_in, d_out * 2)
|
42
36
|
|
43
|
-
|
44
|
-
|
37
|
+
def forward(self, x: torch.Tensor):
|
38
|
+
x, gate = self.proj(x).chunk(2, dim=-1)
|
39
|
+
return x * F.gelu(gate)
|
40
|
+
|
41
|
+
|
42
|
+
class SwiGLU(nn.Module):
|
43
|
+
"""SwiGLU is an activation function which is a variant of GLU.
|
44
|
+
|
45
|
+
SwiGLU is same as SiLU_GLU, because The SiLU function is also known as the
|
46
|
+
swish function.
|
47
|
+
|
48
|
+
SwiGLU(x) = Swish(xW+b) * (xV+c)
|
49
|
+
See: https://paperswithcode.com/method/swiglu
|
45
50
|
"""
|
46
51
|
|
47
|
-
def
|
52
|
+
def forward(self, x: torch.Tensor):
|
48
53
|
x, y = x.chunk(2, dim=-1)
|
49
|
-
|
50
|
-
return act(x) * y
|
51
|
-
return x * act(y)
|
52
|
-
|
53
|
-
return _glu
|
54
|
+
return F.silu(x) * y
|
54
55
|
|
55
56
|
|
56
57
|
def build_norm(dim: int, config: cfg.NormalizationConfig):
|
@@ -151,10 +152,10 @@ def get_activation(config: cfg.ActivationConfig):
|
|
151
152
|
# See: https://github.com/hendrycks/GELUs
|
152
153
|
return lambda x: x * F.sigmoid(1.702 * x)
|
153
154
|
elif config.type == cfg.ActivationType.GE_GLU:
|
154
|
-
return
|
155
|
+
return GeGLU(config.dim_in, config.dim_out)
|
155
156
|
elif config.type == cfg.ActivationType.RELU:
|
156
157
|
return F.relu
|
157
158
|
elif config.type == cfg.ActivationType.SILU_GLU:
|
158
|
-
return
|
159
|
+
return SwiGLU()
|
159
160
|
else:
|
160
161
|
raise ValueError("Unsupported activation type.")
|
@@ -118,9 +118,9 @@ class AttentionConfig:
|
|
118
118
|
@dataclass
|
119
119
|
class ActivationConfig:
|
120
120
|
type: ActivationType = ActivationType.LINEAR
|
121
|
-
#
|
122
|
-
|
123
|
-
|
121
|
+
# Dimension of input and output, used in GeGLU.
|
122
|
+
dim_in: Optional[int] = None
|
123
|
+
dim_out: Optional[int] = None
|
124
124
|
|
125
125
|
|
126
126
|
@dataclass
|
@@ -183,8 +183,16 @@ def group_norm_with_hlfb(
|
|
183
183
|
"""
|
184
184
|
x = torch.permute(x, (0, 2, 3, 1))
|
185
185
|
|
186
|
+
# TODO: b/366544750 - Change "reduction_axes" field as an array, rather than
|
187
|
+
# int32 when the bug is fixed.
|
186
188
|
builder = StableHLOCompositeBuilder(
|
187
|
-
name="odml.group_norm",
|
189
|
+
name="odml.group_norm",
|
190
|
+
attr={
|
191
|
+
"num_groups": num_groups,
|
192
|
+
"epsilon": eps,
|
193
|
+
"reduction_axes": 3,
|
194
|
+
"channel_axis": 3,
|
195
|
+
},
|
188
196
|
)
|
189
197
|
x, w, b = builder.mark_inputs(x, w, b)
|
190
198
|
x = torch.permute(x, (0, 3, 1, 2))
|
@@ -206,7 +214,7 @@ def layer_norm_with_hlfb(
|
|
206
214
|
"""Layer Normalization with high-level function boundary enabled.
|
207
215
|
|
208
216
|
Args:
|
209
|
-
x (torch.Tensor): Input tensor for Layer Normalization.
|
217
|
+
x (torch.Tensor): Input tensor for Layer Normalization, with BCHW shape.
|
210
218
|
w (torch.Tensor): The weight tensor for the normalization.
|
211
219
|
b (torch.Tensor): The bias tensor for the normalization.
|
212
220
|
eps (float): A small float value to ensure numerical stability.
|
@@ -216,7 +224,10 @@ def layer_norm_with_hlfb(
|
|
216
224
|
Returns:
|
217
225
|
The output tensor of Layer Normalization.
|
218
226
|
"""
|
219
|
-
builder = StableHLOCompositeBuilder(
|
227
|
+
builder = StableHLOCompositeBuilder(
|
228
|
+
name="odml.group_norm",
|
229
|
+
attr={"num_groups": 1, "epsilon": eps, "channel_axis": 1},
|
230
|
+
)
|
220
231
|
x, w, b = builder.mark_inputs(x, w, b)
|
221
232
|
if use_input_shape:
|
222
233
|
normalized_shape = x.shape
|
@@ -13,7 +13,7 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
15
|
|
16
|
-
from typing import List, Optional, Tuple
|
16
|
+
from typing import List, Optional, Tuple, Union
|
17
17
|
|
18
18
|
from ai_edge_torch.generative.layers.attention import CrossAttention
|
19
19
|
from ai_edge_torch.generative.layers.attention import SelfAttention
|
@@ -416,7 +416,7 @@ class DownEncoderBlock2D(nn.Module):
|
|
416
416
|
time_emb: Optional[torch.Tensor] = None,
|
417
417
|
context_tensor: Optional[torch.Tensor] = None,
|
418
418
|
output_hidden_states: bool = False,
|
419
|
-
) -> torch.Tensor
|
419
|
+
) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
|
420
420
|
"""Forward function of the DownEncoderBlock2D.
|
421
421
|
|
422
422
|
Args:
|
@@ -21,7 +21,11 @@ from ai_edge_torch.generative.examples.gemma import gemma1
|
|
21
21
|
from ai_edge_torch.generative.examples.gemma import gemma2
|
22
22
|
from ai_edge_torch.generative.examples.openelm import openelm
|
23
23
|
from ai_edge_torch.generative.examples.phi import phi2
|
24
|
+
from ai_edge_torch.generative.examples.phi import phi3
|
24
25
|
from ai_edge_torch.generative.examples.smollm import smollm
|
26
|
+
from ai_edge_torch.generative.examples.stable_diffusion import clip as sd_clip
|
27
|
+
from ai_edge_torch.generative.examples.stable_diffusion import decoder as sd_decoder
|
28
|
+
from ai_edge_torch.generative.examples.stable_diffusion import diffusion as sd_diffusion
|
25
29
|
from ai_edge_torch.generative.layers import kv_cache
|
26
30
|
from ai_edge_torch.generative.test import utils as test_utils
|
27
31
|
import numpy as np
|
@@ -109,6 +113,17 @@ class TestModelConversion(googletest.TestCase):
|
|
109
113
|
config, pytorch_model, "serving_default", atol=1e-3, rtol=1e-3
|
110
114
|
)
|
111
115
|
|
116
|
+
@googletest.skipIf(
|
117
|
+
ai_edge_config.Config.use_torch_xla,
|
118
|
+
reason="tests with custom ops are not supported on oss",
|
119
|
+
)
|
120
|
+
def test_phi3(self):
|
121
|
+
config = phi3.get_fake_model_config()
|
122
|
+
pytorch_model = phi3.Phi3_5Mini(config).eval()
|
123
|
+
self._test_model(
|
124
|
+
config, pytorch_model, "prefill", atol=1e-5, rtol=1e-5
|
125
|
+
)
|
126
|
+
|
112
127
|
@googletest.skipIf(
|
113
128
|
ai_edge_config.Config.use_torch_xla,
|
114
129
|
reason="tests with custom ops are not supported on oss",
|
@@ -127,6 +142,110 @@ class TestModelConversion(googletest.TestCase):
|
|
127
142
|
pytorch_model = openelm.OpenELM(config).eval()
|
128
143
|
self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5)
|
129
144
|
|
145
|
+
@googletest.skipIf(
|
146
|
+
ai_edge_config.Config.use_torch_xla,
|
147
|
+
reason="tests with custom ops are not supported on oss",
|
148
|
+
)
|
149
|
+
def test_stable_diffusion_clip(self):
|
150
|
+
config = sd_clip.get_fake_model_config()
|
151
|
+
prompt_tokens = torch.from_numpy(
|
152
|
+
np.array([[1, 2, 3, 4, 5, 6]], dtype=np.int32)
|
153
|
+
)
|
154
|
+
|
155
|
+
pytorch_model = sd_clip.CLIP(config).eval()
|
156
|
+
torch_output = pytorch_model(prompt_tokens)
|
157
|
+
|
158
|
+
edge_model = ai_edge_torch.signature(
|
159
|
+
"encode", pytorch_model, (prompt_tokens,)
|
160
|
+
).convert()
|
161
|
+
edge_model.set_interpreter_builder(
|
162
|
+
self._interpreter_builder(edge_model.tflite_model())
|
163
|
+
)
|
164
|
+
edge_output = edge_model(
|
165
|
+
prompt_tokens.numpy(),
|
166
|
+
signature_name="encode",
|
167
|
+
)
|
168
|
+
self.assertTrue(
|
169
|
+
np.allclose(
|
170
|
+
edge_output,
|
171
|
+
torch_output.detach().numpy(),
|
172
|
+
atol=1e-4,
|
173
|
+
rtol=1e-5,
|
174
|
+
)
|
175
|
+
)
|
176
|
+
|
177
|
+
@googletest.skipIf(
|
178
|
+
ai_edge_config.Config.use_torch_xla,
|
179
|
+
reason="tests with custom ops are not supported on oss",
|
180
|
+
)
|
181
|
+
def test_stable_diffusion_diffusion(self):
|
182
|
+
config = sd_diffusion.get_fake_model_config(2)
|
183
|
+
latents = torch.from_numpy(
|
184
|
+
np.random.normal(size=(2, 4, 8, 8)).astype(np.float32)
|
185
|
+
)
|
186
|
+
context = torch.from_numpy(
|
187
|
+
np.random.normal(size=(2, 4, 4)).astype(np.float32)
|
188
|
+
)
|
189
|
+
time_embedding = torch.from_numpy(
|
190
|
+
np.random.normal(size=(2, 2)).astype(np.float32)
|
191
|
+
)
|
192
|
+
|
193
|
+
pytorch_model = sd_diffusion.Diffusion(config).eval()
|
194
|
+
torch_output = pytorch_model(latents, context, time_embedding)
|
195
|
+
|
196
|
+
edge_model = ai_edge_torch.signature(
|
197
|
+
"diffusion", pytorch_model, (latents, context, time_embedding)
|
198
|
+
).convert()
|
199
|
+
edge_model.set_interpreter_builder(
|
200
|
+
self._interpreter_builder(edge_model.tflite_model())
|
201
|
+
)
|
202
|
+
edge_output = edge_model(
|
203
|
+
latents.numpy(),
|
204
|
+
context.numpy(),
|
205
|
+
time_embedding.numpy(),
|
206
|
+
signature_name="diffusion",
|
207
|
+
)
|
208
|
+
self.assertTrue(
|
209
|
+
np.allclose(
|
210
|
+
edge_output,
|
211
|
+
torch_output.detach().numpy(),
|
212
|
+
atol=1e-4,
|
213
|
+
rtol=1e-5,
|
214
|
+
)
|
215
|
+
)
|
216
|
+
|
217
|
+
@googletest.skipIf(
|
218
|
+
ai_edge_config.Config.use_torch_xla,
|
219
|
+
reason="tests with custom ops are not supported on oss",
|
220
|
+
)
|
221
|
+
def test_stable_diffusion_decoder(self):
|
222
|
+
config = sd_decoder.get_fake_model_config()
|
223
|
+
latents = torch.from_numpy(
|
224
|
+
np.random.normal(size=(1, 4, 64, 64)).astype(np.float32)
|
225
|
+
)
|
226
|
+
|
227
|
+
pytorch_model = sd_decoder.Decoder(config).eval()
|
228
|
+
torch_output = pytorch_model(latents)
|
229
|
+
|
230
|
+
edge_model = ai_edge_torch.signature(
|
231
|
+
"decode", pytorch_model, (latents,)
|
232
|
+
).convert()
|
233
|
+
edge_model.set_interpreter_builder(
|
234
|
+
self._interpreter_builder(edge_model.tflite_model())
|
235
|
+
)
|
236
|
+
edge_output = edge_model(
|
237
|
+
latents.numpy(),
|
238
|
+
signature_name="decode",
|
239
|
+
)
|
240
|
+
self.assertTrue(
|
241
|
+
np.allclose(
|
242
|
+
edge_output,
|
243
|
+
torch_output.detach().numpy(),
|
244
|
+
atol=1e-4,
|
245
|
+
rtol=1e-5,
|
246
|
+
)
|
247
|
+
)
|
248
|
+
|
130
249
|
|
131
250
|
if __name__ == "__main__":
|
132
251
|
googletest.main()
|
ai_edge_torch/version.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: ai-edge-torch-nightly
|
3
|
-
Version: 0.3.0.
|
3
|
+
Version: 0.3.0.dev20240925
|
4
4
|
Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
|
5
5
|
Home-page: https://github.com/google-ai-edge/ai-edge-torch
|
6
6
|
Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
|
@@ -30,6 +30,7 @@ Requires-Dist: tabulate
|
|
30
30
|
Requires-Dist: torch>=2.4.0
|
31
31
|
Requires-Dist: torch-xla>=2.4.0
|
32
32
|
Requires-Dist: tf-nightly>=2.18.0.dev20240722
|
33
|
+
Requires-Dist: ai-edge-litert-nightly
|
33
34
|
Requires-Dist: ai-edge-quantizer-nightly
|
34
35
|
|
35
36
|
Library that supports converting PyTorch models into a .tflite format, which can
|
@@ -3,7 +3,7 @@ ai_edge_torch/config.py,sha256=FMWeCH2b7HYILBvaI1iZNnYCO4WAhDOwBZBmIE-xrF0,909
|
|
3
3
|
ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
|
4
4
|
ai_edge_torch/fx_pass_base.py,sha256=D86Gw3pIRcpnTebUPKlnPbPGJae1S6Fw4DZZ3ZkD0zw,3730
|
5
5
|
ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
|
6
|
-
ai_edge_torch/version.py,sha256=
|
6
|
+
ai_edge_torch/version.py,sha256=UXj1-90S3RDoHwYSmy9VdMC0Sm3EHt9ESLZbi3hnWus,706
|
7
7
|
ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
8
8
|
ai_edge_torch/_convert/conversion.py,sha256=5uPwHhmc6kwiIz-CqaiHDejf2SOWMHrb-rYEHm69wKc,3801
|
9
9
|
ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
|
@@ -48,22 +48,25 @@ ai_edge_torch/generative/examples/gemma/verify_gemma2.py,sha256=kSzn1ITJXqrtNQax
|
|
48
48
|
ai_edge_torch/generative/examples/gemma/verify_util.py,sha256=HBK2d8FcWFoxVDF5zk9sLSbKZEtwZQhX-K_zm4AvQtQ,5160
|
49
49
|
ai_edge_torch/generative/examples/openelm/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
50
50
|
ai_edge_torch/generative/examples/openelm/convert_to_tflite.py,sha256=85FVEt6cKFP2UzCLC78tAkbwGlGhAArtG7Wa75NxJik,2185
|
51
|
-
ai_edge_torch/generative/examples/openelm/openelm.py,sha256=
|
51
|
+
ai_edge_torch/generative/examples/openelm/openelm.py,sha256=VcU8A0B9nQR-FTPHXqNHSHZzeIZZ_As4yvKZMnoU2P4,7482
|
52
52
|
ai_edge_torch/generative/examples/openelm/verify.py,sha256=QdFKymQSCYFJcYVvA63u5uIsn1YxJ0JZD5UqN6gxraI,2112
|
53
53
|
ai_edge_torch/generative/examples/phi/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
54
|
+
ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py,sha256=rkbTtMaqSVG48cm-NTxR_LDgZmXAEBqayTm9O49oMXc,2171
|
54
55
|
ai_edge_torch/generative/examples/phi/convert_to_tflite.py,sha256=3go690yX6PFeXMdpY7y4JZorAwxX0HT_b_pKZieauvk,2169
|
55
56
|
ai_edge_torch/generative/examples/phi/phi2.py,sha256=YwAszA53aOjvaMJ5wua2-5rP79N21Un_Y5yBCfFSYNU,6189
|
56
|
-
ai_edge_torch/generative/examples/phi/
|
57
|
+
ai_edge_torch/generative/examples/phi/phi3.py,sha256=DIDzpG8DZkWDcWsAVkcxzxIC3U3352uVI3zMoYZD16U,9554
|
58
|
+
ai_edge_torch/generative/examples/phi/verify.py,sha256=5pQ0Bt8vGl8uTpkgXvOx8G7_rju0Gi8mIEr5NtRSAbs,2145
|
59
|
+
ai_edge_torch/generative/examples/phi/verify_phi3.py,sha256=o1UTqpimkeX3MDjgdG1QTQkoZHvCEnGClA0J0WB3wJ4,2328
|
57
60
|
ai_edge_torch/generative/examples/smollm/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
58
61
|
ai_edge_torch/generative/examples/smollm/convert_to_tflite.py,sha256=zPrDTDeRVWFi9DS32uNi-RLpzOStFOk5MhNla4ixeew,2179
|
59
62
|
ai_edge_torch/generative/examples/smollm/smollm.py,sha256=hyhMk-b5762Q2xmjdD47g85dcbBSNJXNPIsifm1DRto,3239
|
60
63
|
ai_edge_torch/generative/examples/smollm/verify.py,sha256=G2dAcl-VhAbx1E1PEqM6hpzPF24HqFZaz7UBEpJSQ3w,2022
|
61
64
|
ai_edge_torch/generative/examples/stable_diffusion/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
62
65
|
ai_edge_torch/generative/examples/stable_diffusion/attention.py,sha256=kDWG6MlIGa89zC5KSRcJlw2c4ITuw8KcchtfmF55f4g,3545
|
63
|
-
ai_edge_torch/generative/examples/stable_diffusion/clip.py,sha256=
|
66
|
+
ai_edge_torch/generative/examples/stable_diffusion/clip.py,sha256=2RMi5UmfMT4Ep68ZLJsqF-fMvEumNVkIwqtsRli9HhA,6068
|
64
67
|
ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py,sha256=vfMGI03UL_gfB561t2kzIHuScwnsUmqaPWxgvq_1T5A,5043
|
65
|
-
ai_edge_torch/generative/examples/stable_diffusion/decoder.py,sha256=
|
66
|
-
ai_edge_torch/generative/examples/stable_diffusion/diffusion.py,sha256=
|
68
|
+
ai_edge_torch/generative/examples/stable_diffusion/decoder.py,sha256=ZTRD56e8MsdGPJr7vpLa4Ju_BFw_b-FUgXgd-SO5MBw,15665
|
69
|
+
ai_edge_torch/generative/examples/stable_diffusion/diffusion.py,sha256=6FAnevL8ZfCK2YCSPivarUH0Z8wGKSmnPpJNC0OI5A8,33680
|
67
70
|
ai_edge_torch/generative/examples/stable_diffusion/encoder.py,sha256=CAPsW84A8f00nS6fLFeh_XUjCPsDCA5UxHOUsMrLfSU,3450
|
68
71
|
ai_edge_torch/generative/examples/stable_diffusion/pipeline.py,sha256=x9lEEENGNbpx6VTf_LTVudd9d6bs9tLvFUKTl252zEY,8623
|
69
72
|
ai_edge_torch/generative/examples/stable_diffusion/tokenizer.py,sha256=xychak9hdLd6ieXBYEwrK2BkF8NRZWZSSCijIsESpBA,3420
|
@@ -78,8 +81,9 @@ ai_edge_torch/generative/examples/t5/convert_to_tflite.py,sha256=HHtZTtUh3QgE4F7
|
|
78
81
|
ai_edge_torch/generative/examples/t5/t5.py,sha256=OZ67knK-UB1dBjxydG-Jwkp0Z3FzOCqGPTdg5aBFu4w,21328
|
79
82
|
ai_edge_torch/generative/examples/t5/t5_attention.py,sha256=l01oYyJo77INzRwN4xqXquaFQPvCFBFF5zOnmGVb3Hg,8731
|
80
83
|
ai_edge_torch/generative/examples/test_models/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
81
|
-
ai_edge_torch/generative/examples/test_models/
|
82
|
-
ai_edge_torch/generative/examples/test_models/
|
84
|
+
ai_edge_torch/generative/examples/test_models/convert_toy_model.py,sha256=6-WaNHckq_LlXMVTh8x90MGWeWq2bu_T_XQd3w9FnGg,3261
|
85
|
+
ai_edge_torch/generative/examples/test_models/toy_model.py,sha256=LTuzres5DHmrMT6U9rCrGf6vmR9SmopmB8sO6Cd2NxQ,5255
|
86
|
+
ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py,sha256=xDYTh4m3vBEb6r3_ERhmj5qILW7YdVDAnZ-fitgYONg,4450
|
83
87
|
ai_edge_torch/generative/examples/tiny_llama/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
84
88
|
ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py,sha256=ekxd8efjMgEvauUu3PidWOC-DszPHn5sqU753F7sJIM,2201
|
85
89
|
ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py,sha256=tlWpa7Aun3u3w5b-9EBtW7olhmSf8W-tn5bKUIwC-ys,6044
|
@@ -89,15 +93,15 @@ ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py,sha256=hhxSQvkD
|
|
89
93
|
ai_edge_torch/generative/layers/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
90
94
|
ai_edge_torch/generative/layers/attention.py,sha256=Z0Y_G8IG0LmvLX2u9D8__Fkr22szB-az6wMNnZpzhkA,13233
|
91
95
|
ai_edge_torch/generative/layers/attention_utils.py,sha256=68GXGR2HSWBFViTxX7cHifzVG-kcLS2IL2tQJPIpupg,7344
|
92
|
-
ai_edge_torch/generative/layers/builder.py,sha256=
|
96
|
+
ai_edge_torch/generative/layers/builder.py,sha256=oE8DdqLA-oWkBC2zySSCh8JNAJg_hk8-W_UoMSrgDVk,5088
|
93
97
|
ai_edge_torch/generative/layers/feed_forward.py,sha256=hdICat-8gW7-vxDAevJQ8NQ-mynllPiqLdXQMF6JMnc,4189
|
94
98
|
ai_edge_torch/generative/layers/kv_cache.py,sha256=2El7kZYnQRCRcVc63xgiAdBh9oVOksDu35p9XggvaGE,6148
|
95
|
-
ai_edge_torch/generative/layers/model_config.py,sha256=
|
96
|
-
ai_edge_torch/generative/layers/normalization.py,sha256=
|
99
|
+
ai_edge_torch/generative/layers/model_config.py,sha256=l5Rb3h3GK2pux-Lg3BONTD6b7klxXqUbDDtYs_bGKLk,6879
|
100
|
+
ai_edge_torch/generative/layers/normalization.py,sha256=cpo88JUXbF9j3sJTU4JuwOap9ryGV05C1QkPij-YQwU,6999
|
97
101
|
ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=CZqOoibLcHvUgrgaIIWAlmk3XgE2inzx340MN-npLoU,1347
|
98
102
|
ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=gXxh3papKy4FBpGEX7VyZ7rZ1Js6aHK70Q6DKrVSckY,4154
|
99
103
|
ai_edge_torch/generative/layers/unet/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
100
|
-
ai_edge_torch/generative/layers/unet/blocks_2d.py,sha256=
|
104
|
+
ai_edge_torch/generative/layers/unet/blocks_2d.py,sha256=JwndhL3Z31TvkdGlAoTL5PQzmKfHdRWaaE1EbaMI4Gs,27540
|
101
105
|
ai_edge_torch/generative/layers/unet/builder.py,sha256=zAqWXdimmMrQRhmE_t9XkS68mh6PSrzwb-2NZZXrR5I,1901
|
102
106
|
ai_edge_torch/generative/layers/unet/model_config.py,sha256=8ze9kVWMuyZVQcgK7hWYw9TM1W9lXD-2j0iMHlxoGX4,9267
|
103
107
|
ai_edge_torch/generative/quantize/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
@@ -111,7 +115,7 @@ ai_edge_torch/generative/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudj
|
|
111
115
|
ai_edge_torch/generative/test/test_kv_cache.py,sha256=W6Bh0gYDzmwb0j9HdD5_D7Z7FPToP2HSyFrmwIXuFqo,3793
|
112
116
|
ai_edge_torch/generative/test/test_loader.py,sha256=8y74ChO3CZCfEi1eCf3-w47kRgAI4qPYCXpi8rTQXMA,3378
|
113
117
|
ai_edge_torch/generative/test/test_model_conversion.py,sha256=s-EVLOQGjIeVtgNI8Ggs37pkRdErAliT6NhrrFigPOE,5459
|
114
|
-
ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=
|
118
|
+
ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=IzW2HjXS2-zePZM-qEuXL4zclnGvYsNw-6tuDSeNna4,8163
|
115
119
|
ai_edge_torch/generative/test/test_quantize.py,sha256=8geJhKwYBU20m0mdGPD1BUFwQ0lZKNtCB04SOLO18y4,5980
|
116
120
|
ai_edge_torch/generative/test/utils.py,sha256=YvEhO2HIj1LkBs5du1UxY-cGRW9HMyAYsOUhgsTrTpA,1796
|
117
121
|
ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
|
@@ -166,8 +170,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
|
|
166
170
|
ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
167
171
|
ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
|
168
172
|
ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
|
169
|
-
ai_edge_torch_nightly-0.3.0.
|
170
|
-
ai_edge_torch_nightly-0.3.0.
|
171
|
-
ai_edge_torch_nightly-0.3.0.
|
172
|
-
ai_edge_torch_nightly-0.3.0.
|
173
|
-
ai_edge_torch_nightly-0.3.0.
|
173
|
+
ai_edge_torch_nightly-0.3.0.dev20240925.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
|
174
|
+
ai_edge_torch_nightly-0.3.0.dev20240925.dist-info/METADATA,sha256=5KsshdZ4-3X193HkoO2ukceyDEdWGvb8ZEMcw88qt7k,1897
|
175
|
+
ai_edge_torch_nightly-0.3.0.dev20240925.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
|
176
|
+
ai_edge_torch_nightly-0.3.0.dev20240925.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
|
177
|
+
ai_edge_torch_nightly-0.3.0.dev20240925.dist-info/RECORD,,
|
File without changes
|
File without changes
|