ai-edge-torch-nightly 0.2.0.dev20240604__py3-none-any.whl → 0.2.0.dev20240605__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
- ai_edge_torch/convert/fx_passes/build_aten_composite_pass.py +30 -0
- ai_edge_torch/convert/test/test_convert_composites.py +18 -0
- ai_edge_torch/generative/examples/stable_diffusion/clip.py +83 -49
- ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +7 -5
- ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +0 -260
- ai_edge_torch/generative/examples/t5/t5_attention.py +2 -2
- ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +161 -0
- ai_edge_torch/generative/layers/attention.py +27 -114
- ai_edge_torch/generative/layers/builder.py +4 -0
- ai_edge_torch/generative/layers/model_config.py +5 -1
- ai_edge_torch/generative/layers/scaled_dot_product_attention.py +117 -0
- ai_edge_torch/generative/utilities/loader.py +56 -27
- {ai_edge_torch_nightly-0.2.0.dev20240604.dist-info → ai_edge_torch_nightly-0.2.0.dev20240605.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.2.0.dev20240604.dist-info → ai_edge_torch_nightly-0.2.0.dev20240605.dist-info}/RECORD +17 -15
- {ai_edge_torch_nightly-0.2.0.dev20240604.dist-info → ai_edge_torch_nightly-0.2.0.dev20240605.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.2.0.dev20240604.dist-info → ai_edge_torch_nightly-0.2.0.dev20240605.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.2.0.dev20240604.dist-info → ai_edge_torch_nightly-0.2.0.dev20240605.dist-info}/top_level.txt +0 -0

ai_edge_torch/convert/fx_passes/build_aten_composite_pass.py
@@ -99,6 +99,36 @@ def _aten_hardswish(gm: GraphModule, node: Node):
   node.target = hardswish


+@_register_composite_builder(torch.ops.aten.gelu.default)
+def _aten_gelu(gm: GraphModule, node: Node):
+  op = node.target
+  args_mapper = TorchOpArgumentsMapper(op)
+
+  def gelu(*args, **kwargs):
+    nonlocal op, args_mapper
+
+    full_kwargs = args_mapper.get_full_kwargs(args, kwargs)
+
+    # TFLite supports exact and tanh approximate.
+    if full_kwargs["approximate"] != "none" and full_kwargs["approximate"] != "tanh":
+      return op(*args, **kwargs)
+
+    builder = StableHLOCompositeBuilder(
+        "aten.gelu.default",
+        attr=_tree_map_to_composite_attr_values(
+            {
+                "approximate": full_kwargs["approximate"],
+            }
+        ),
+    )
+    full_kwargs["self"] = builder.mark_inputs(full_kwargs["self"])
+    output = op(full_kwargs["self"])
+    output = builder.mark_outputs(output)
+    return output
+
+  node.target = gelu
+
+
 @_register_composite_builder(torch.ops.aten.avg_pool2d.default)
 def _aten_avg_pool2d(gm: GraphModule, node: Node):
   op = node.target

ai_edge_torch/convert/test/test_convert_composites.py
@@ -169,6 +169,24 @@ class TestConvertComposites(unittest.TestCase):
         model_coverage.compare_tflite_torch(edge_model, torch_module, tracing_args)
     )

+  def test_convert_gelu(self):
+    """Tests conversion of a GELU module."""
+
+    args = (torch.randn((5, 10)),)
+    torch_module = torch.nn.GELU().eval()
+    edge_model = ai_edge_torch.convert(torch_module, args)
+
+    self.assertTrue(model_coverage.compare_tflite_torch(edge_model, torch_module, args))
+
+  def test_convert_gelu_approximate(self):
+    """Tests conversion of an Approximate GELU module."""
+
+    args = (torch.randn((5, 10)),)
+    torch_module = torch.nn.GELU('tanh').eval()
+    edge_model = ai_edge_torch.convert(torch_module, args)
+
+    self.assertTrue(model_coverage.compare_tflite_torch(edge_model, torch_module, args))
+

 if __name__ == '__main__':
   unittest.main()
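
The two tests above are the end-to-end check for the new GELU composite. A minimal standalone sketch of the same flow, assuming an environment where ai_edge_torch and its TFLite converter dependencies are installed (the output path is illustrative, not from the package):

import torch
import ai_edge_torch

# Convert a tanh-approximate GELU module; exact GELU ('none') is lowered
# to the same aten.gelu.default composite by the new pass.
sample_args = (torch.randn((5, 10)),)
gelu_module = torch.nn.GELU('tanh').eval()
edge_model = ai_edge_torch.convert(gelu_module, sample_args)
edge_model.export('/tmp/gelu.tflite')  # hypothetical output path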

ai_edge_torch/generative/examples/stable_diffusion/clip.py
@@ -15,65 +15,99 @@

 import torch
 from torch import nn
-from torch._prims_common import mask_tensor
-from torch._prims_common.wrappers import out_wrapper

-from ai_edge_torch.generative.
+from ai_edge_torch.generative.layers.attention import TransformerBlock
+import ai_edge_torch.generative.layers.attention_utils as attention_utils
+import ai_edge_torch.generative.layers.builder as builder
+import ai_edge_torch.generative.layers.model_config as cfg
+import ai_edge_torch.generative.utilities.loader as loading_utils
+
+TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
+    ff_up_proj="layers.{}.linear_1",
+    ff_down_proj="layers.{}.linear_2",
+    ff_gate_proj="layers.{}.linear_1",
+    attn_fused_qkv_proj="layers.{}.attention.in_proj",
+    attn_output_proj="layers.{}.attention.out_proj",
+    pre_attn_norm="layers.{}.layernorm_1",
+    pre_ff_norm="layers.{}.layernorm_2",
+    embedding="embedding.token_embedding",
+    embedding_position="embedding.position_value",
+    final_norm="layernorm",
+    lm_head=None,
+)


-class
-
-
-
-    self.token_embedding = nn.Embedding(n_vocab, n_embd)
-    self.position_value = nn.Parameter(torch.zeros((n_token, n_embd)))
-
-  def forward(self, tokens):
-    x = self.token_embedding(tokens)
-    x += self.position_value
-    return x
-
-
-class CLIPLayer(nn.Module):
+class CLIP(nn.Module):
+  """CLIP text encoder
+  For details, see https://arxiv.org/abs/2103.00020
+  """

-  def __init__(self,
+  def __init__(self, config: cfg.ModelConfig):
     super().__init__()
-    self.
-    self.
-
-
-    self.linear_2 = nn.Linear(4 * n_embd, n_embd)
-
-  def forward(self, x):
-    residue = x
-    x = self.layernorm_1(x)
-    x = self.attention(x, causal_mask=True)
-    x += residue
+    self.tok_embedding = nn.Embedding(config.vocab_size, config.embedding_dim)
+    self.tok_embedding_position = nn.Parameter(
+        torch.zeros((config.max_seq_len, config.embedding_dim))
+    )

-
-
-
-
-
-    x += residue
+    self.config = config
+    self.transformer_blocks = nn.ModuleList(
+        TransformerBlock(config) for _ in range(config.num_layers)
+    )
+    self.final_norm = builder.build_norm(config.embedding_dim, config.final_norm_config)

-
-
-
-class CLIP(nn.Module):
-
-  def __init__(self):
-    super().__init__()
-    self.embedding = CLIPEmbedding(49408, 768, 77)
-    self.layers = nn.ModuleList([CLIPLayer(12, 768) for i in range(12)])
-    self.layernorm = nn.LayerNorm(768)
+    self.mask_cache = attention_utils.build_causal_mask_cache(
+        size=config.max_seq_len, dtype=torch.float32
+    )

   @torch.inference_mode
   def forward(self, tokens: torch.LongTensor) -> torch.FloatTensor:
     tokens = tokens.type(torch.long)

-    state = self.
-    for layer in self.
-      state = layer(state)
-    output = self.
+    state = self.tok_embedding(tokens) + self.tok_embedding_position
+    for layer in self.transformer_blocks:
+      state = layer(state, mask=self.mask_cache)
+    output = self.final_norm(state)
     return output
+
+
+def get_model_config() -> cfg.ModelConfig:
+  max_seq_len = 77
+  vocab_size = 49408
+  num_layers = 12
+  num_heads = 12
+  num_query_groups = 12
+  embedding_dim = 768
+
+  attn_config = cfg.AttentionConfig(
+      num_heads=num_heads,
+      num_query_groups=num_query_groups,
+      rotary_percentage=0.0,
+      qkv_use_bias=True,
+      qkv_transpose_before_split=True,
+      output_proj_use_bias=True,
+      enable_kv_cache=False,
+  )
+
+  ff_config = cfg.FeedForwardConfig(
+      type=cfg.FeedForwardType.SEQUENTIAL,
+      activation=cfg.ActivationType.GELU_QUICK,
+      intermediate_size=embedding_dim * 4,
+      use_bias=True,
+  )
+
+  norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.LAYER_NORM)
+
+  config = cfg.ModelConfig(
+      vocab_size=vocab_size,
+      num_layers=num_layers,
+      max_seq_len=max_seq_len,
+      embedding_dim=embedding_dim,
+      attn_config=attn_config,
+      ff_config=ff_config,
+      pre_attention_norm_config=norm_config,
+      pre_ff_norm_config=norm_config,
+      final_norm_config=norm_config,
+      enable_hlfb=True,
+  )
+
+  return config

ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py
@@ -19,11 +19,12 @@ from pathlib import Path
 import torch

 import ai_edge_torch
-
+import ai_edge_torch.generative.examples.stable_diffusion.clip as clip
 from ai_edge_torch.generative.examples.stable_diffusion.decoder import Decoder
 from ai_edge_torch.generative.examples.stable_diffusion.diffusion import Diffusion  # NOQA
 from ai_edge_torch.generative.examples.stable_diffusion.encoder import Encoder
 import ai_edge_torch.generative.examples.stable_diffusion.util as util
+import ai_edge_torch.generative.utilities.loader as loading_utils


 @torch.inference_mode
@@ -36,8 +37,9 @@ def convert_stable_diffusion_to_tflite(
     image_width: int = 512,
 ):

-
-
+  clip_model = clip.CLIP(clip.get_model_config())
+  loader = loading_utils.ModelLoader(clip_ckpt_path, clip.TENSOR_NAMES)
+  loader.load(clip_model, strict=False)

   encoder = Encoder()
   encoder.load_state_dict(torch.load(encoder_ckpt_path))
@@ -59,13 +61,13 @@ def convert_stable_diffusion_to_tflite(
   )

   input_latents = encoder(input_image, noise)
-  context_cond =
+  context_cond = clip_model(prompt_tokens)
   context_uncond = torch.zeros_like(context_cond)
   context = torch.cat([context_cond, context_uncond], axis=0)
   time_embedding = util.get_time_embedding(timestamp)

   # CLIP text encoder
-  ai_edge_torch.signature('encode',
+  ai_edge_torch.signature('encode', clip_model, (prompt_tokens,)).convert().export(
       '/tmp/stable_diffusion/clip.tflite'
   )


ai_edge_torch/generative/examples/stable_diffusion/diffusion.py
@@ -202,11 +202,6 @@ class UNet(nn.Module):

     x = self.bottleneck(x, context, time)

-    # print('x shape:')
-    # print(list(x.shape))
-    # print('time shape:')
-    # print(list(time.shape))
-
     for layers in self.decoders:
       x = torch.cat((x, skip_connections.pop()), dim=1)
       x = layers(x, context, time)
@@ -214,199 +209,6 @@ class UNet(nn.Module):
     return x


-# The encoder component.
-class UNetEncoder(nn.Module):
-
-  def __init__(self):
-    super().__init__()
-    self.time_embedding = TimeEmbedding(320)
-    self.encoders = nn.ModuleList(
-        [
-            SwitchSequential(nn.Conv2d(4, 320, kernel_size=3, padding=1)),
-            SwitchSequential(ResidualBlock(320, 320), AttentionBlock(8, 40)),
-            SwitchSequential(ResidualBlock(320, 320), AttentionBlock(8, 40)),
-            SwitchSequential(nn.Conv2d(320, 320, kernel_size=3, stride=2, padding=1)),
-            SwitchSequential(ResidualBlock(320, 640), AttentionBlock(8, 80)),
-            SwitchSequential(ResidualBlock(640, 640), AttentionBlock(8, 80)),
-            SwitchSequential(nn.Conv2d(640, 640, kernel_size=3, stride=2, padding=1)),
-            SwitchSequential(ResidualBlock(640, 1280), AttentionBlock(8, 160)),
-            SwitchSequential(ResidualBlock(1280, 1280), AttentionBlock(8, 160)),
-            SwitchSequential(nn.Conv2d(1280, 1280, kernel_size=3, stride=2, padding=1)),
-            SwitchSequential(ResidualBlock(1280, 1280)),
-            SwitchSequential(ResidualBlock(1280, 1280)),
-        ]
-    )
-
-  def forward(self, x, context, time):
-    time_embedding = self.time_embedding(time)
-    skip_connections = []
-    for layers in self.encoders:
-      x = layers(x, context, time_embedding)
-      skip_connections.append(x)
-
-    return x, skip_connections, time_embedding
-
-
-class UNetBottleNeck(nn.Module):
-
-  def __init__(self):
-    super().__init__()
-    self.bottleneck = SwitchSequential(
-        ResidualBlock(1280, 1280),
-        AttentionBlock(8, 160),
-        ResidualBlock(1280, 1280),
-    )
-
-  def forward(self, x, context, time):
-    x = self.bottleneck(x, context, time)
-    # print('shape')
-    # print(list(x.shape))
-    return x
-
-
-# Unet decoder.
-class UNetDecoder1(nn.Module):
-
-  def __init__(self):
-    super().__init__()
-    self.decoders = nn.ModuleList(
-        [
-            SwitchSequential(ResidualBlock(2560, 1280)),
-            SwitchSequential(ResidualBlock(2560, 1280)),
-            SwitchSequential(ResidualBlock(2560, 1280), Upsample(1280)),
-            SwitchSequential(ResidualBlock(2560, 1280), AttentionBlock(8, 160)),
-        ]
-    )
-
-  def forward(self, x, context, time, s9, s10, s11, s12):
-    x = torch.cat((x, s12), dim=1)
-    x = self.decoders[0](x, context, time)
-    x = torch.cat((x, s11), dim=1)
-    x = self.decoders[1](x, context, time)
-    x = torch.cat((x, s10), dim=1)
-    x = self.decoders[2](x, context, time)
-    x = torch.cat((x, s9), dim=1)
-    x = self.decoders[3](x, context, time)
-
-    return x
-
-
-class UNetDecoder2(nn.Module):
-
-  def __init__(self):
-    super().__init__()
-    self.decoders = nn.ModuleList(
-        [
-            SwitchSequential(ResidualBlock(2560, 1280), AttentionBlock(8, 160)),
-            SwitchSequential(
-                ResidualBlock(1920, 1280), AttentionBlock(8, 160), Upsample(1280)
-            ),
-            SwitchSequential(ResidualBlock(1920, 640), AttentionBlock(8, 80)),
-            SwitchSequential(ResidualBlock(1280, 640), AttentionBlock(8, 80)),
-        ]
-    )
-
-  def forward(self, x, context, time, s5, s6, s7, s8):
-    x = torch.cat((x, s8), dim=1)
-    x = self.decoders[0](x, context, time)
-    x = torch.cat((x, s7), dim=1)
-    x = self.decoders[1](x, context, time)
-    x = torch.cat((x, s6), dim=1)
-    x = self.decoders[2](x, context, time)
-    x = torch.cat((x, s5), dim=1)
-    x = self.decoders[3](x, context, time)
-    return x
-
-
-class UNetDecoder3(nn.Module):
-
-  def __init__(self):
-    super().__init__()
-    self.decoders = nn.ModuleList(
-        [
-            SwitchSequential(
-                ResidualBlock(960, 640), AttentionBlock(8, 80), Upsample(640)
-            ),
-            SwitchSequential(ResidualBlock(960, 320), AttentionBlock(8, 40)),
-            SwitchSequential(ResidualBlock(640, 320), AttentionBlock(8, 40)),
-            SwitchSequential(ResidualBlock(640, 320), AttentionBlock(8, 40)),
-        ]
-    )
-    self.final = FinalLayer(320, 4)
-
-  def forward(self, x, context, time, s1, s2, s3, s4):
-    x = torch.cat((x, s4), dim=1)
-    x = self.decoders[0](x, context, time)
-    x = torch.cat((x, s3), dim=1)
-    x = self.decoders[1](x, context, time)
-    x = torch.cat((x, s2), dim=1)
-    x = self.decoders[2](x, context, time)
-    x = torch.cat((x, s1), dim=1)
-    x = self.decoders[3](x, context, time)
-
-    x = self.final(x)
-    return x
-
-
-class UNetDecoder(nn.Module):
-
-  def __init__(self):
-    super().__init__()
-    self.decoders = nn.ModuleList(
-        [
-            SwitchSequential(ResidualBlock(2560, 1280)),
-            SwitchSequential(ResidualBlock(2560, 1280)),
-            SwitchSequential(ResidualBlock(2560, 1280), Upsample(1280)),
-            SwitchSequential(ResidualBlock(2560, 1280), AttentionBlock(8, 160)),
-            SwitchSequential(ResidualBlock(2560, 1280), AttentionBlock(8, 160)),
-            SwitchSequential(
-                ResidualBlock(1920, 1280), AttentionBlock(8, 160), Upsample(1280)
-            ),
-            SwitchSequential(ResidualBlock(1920, 640), AttentionBlock(8, 80)),
-            SwitchSequential(ResidualBlock(1280, 640), AttentionBlock(8, 80)),
-            SwitchSequential(
-                ResidualBlock(960, 640), AttentionBlock(8, 80), Upsample(640)
-            ),
-            SwitchSequential(ResidualBlock(960, 320), AttentionBlock(8, 40)),
-            SwitchSequential(ResidualBlock(640, 320), AttentionBlock(8, 40)),
-            SwitchSequential(ResidualBlock(640, 320), AttentionBlock(8, 40)),
-        ]
-    )
-    self.final = FinalLayer(320, 4)
-
-  def forward(
-      self, x, context, time, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12
-  ):
-    x = torch.cat((x, s12), dim=1)
-    x = self.decoders[0](x, context, time)
-    x = torch.cat((x, s11), dim=1)
-    x = self.decoders[1](x, context, time)
-    x = torch.cat((x, s10), dim=1)
-    x = self.decoders[2](x, context, time)
-    x = torch.cat((x, s9), dim=1)
-    x = self.decoders[3](x, context, time)
-    x = torch.cat((x, s8), dim=1)
-    x = self.decoders[4](x, context, time)
-    x = torch.cat((x, s7), dim=1)
-    x = self.decoders[5](x, context, time)
-    x = torch.cat((x, s6), dim=1)
-    x = self.decoders[6](x, context, time)
-    x = torch.cat((x, s5), dim=1)
-    x = self.decoders[7](x, context, time)
-    x = torch.cat((x, s4), dim=1)
-    x = self.decoders[0](x, context, time)
-    x = torch.cat((x, s3), dim=1)
-    x = self.decoders[1](x, context, time)
-    x = torch.cat((x, s2), dim=1)
-    x = self.decoders[2](x, context, time)
-    x = torch.cat((x, s1), dim=1)
-    x = self.decoders[3](x, context, time)
-
-    x = self.final(x)
-
-    return x
-
-
 class FinalLayer(nn.Module):

   def __init__(self, in_channels, out_channels):
@@ -432,68 +234,6 @@ class Diffusion(nn.Module):
   @torch.inference_mode
   def forward(self, latent, context, time):
     time = self.time_embedding(time)
-    # print('time:')
-    # print(list(time.shape))
     output = self.unet(latent, context, time)
     output = self.final(output)
     return output
-
-
-# Calling code as if Diffusion is splitted into two parts.
-class DiffusionSplitted(nn.Module):
-
-  def __init__(self):
-    super().__init__()
-    self.unet_encoder = UNetEncoder()
-    self.bottleneck = UNetBottleNeck()
-    self.unet_decoder1 = UNetDecoder1()
-    self.unet_decoder2 = UNetDecoder2()
-    self.unet_decoder3 = UNetDecoder3()
-
-  def get_skip_connections(self, latent, context, time):
-    _, skip_connections, _ = self.unet_encoder(latent, context, time)
-    return skip_connections
-
-  def forward(self, latent, context, time):
-    output, skip_connections, time = self.unet_encoder(latent, context, time)
-    # print("output shape of unet encoder...")
-    # print(list(output.shape))
-    # print("output shape of time...")
-    # print(list(time.shape))
-    output = self.bottleneck(output, context, time)
-    # print("output shape of bn")
-    # print(list(output.shape))
-    output = self.unet_decoder1(
-        output,
-        context,
-        time,
-        skip_connections[8],
-        skip_connections[9],
-        skip_connections[10],
-        skip_connections[11],
-    )
-    # print("output shape of d1:")
-    # print(list(output.shape))
-
-    output = self.unet_decoder2(
-        output,
-        context,
-        time,
-        skip_connections[4],
-        skip_connections[5],
-        skip_connections[6],
-        skip_connections[7],
-    )
-
-    # print("output shape of d2:")
-    # print(list(output.shape))
-    output = self.unet_decoder3(
-        output,
-        context,
-        time,
-        skip_connections[0],
-        skip_connections[1],
-        skip_connections[2],
-        skip_connections[3],
-    )
-    return output

ai_edge_torch/generative/examples/t5/t5_attention.py
@@ -20,11 +20,11 @@ import torch
 from torch import nn
 import torch.nn.functional as F

-from ai_edge_torch.generative.layers.attention import scaled_dot_product_attention  # NOQA
-from ai_edge_torch.generative.layers.attention import scaled_dot_product_attention_with_hlfb  # NOQA
 import ai_edge_torch.generative.layers.builder as builder
 from ai_edge_torch.generative.layers.kv_cache import KVCache
 import ai_edge_torch.generative.layers.model_config as cfg
+from ai_edge_torch.generative.layers.scaled_dot_product_attention import scaled_dot_product_attention  # NOQA
+from ai_edge_torch.generative.layers.scaled_dot_product_attention import scaled_dot_product_attention_with_hlfb  # NOQA


 class EncoderDecoderBlock(nn.Module):

ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py (new file)
@@ -0,0 +1,161 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+# A toy example which has basic transformer block (w/ externalized KV-Cache).
+
+from typing import List, Tuple
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch_xla
+
+import ai_edge_torch
+import ai_edge_torch.generative.layers.attention_utils as attn_utils
+import ai_edge_torch.generative.layers.builder as builder
+from ai_edge_torch.generative.layers.experimental.attention import TransformerBlock  # NOQA
+import ai_edge_torch.generative.layers.model_config as cfg
+
+RoPECache = Tuple[torch.Tensor, torch.Tensor]
+
+
+class ToyModelWithExternalKV(torch.nn.Module):
+
+  def __init__(self, config: cfg.ModelConfig) -> None:
+    super().__init__()
+    self.lm_head = nn.Linear(
+        config.embedding_dim, config.vocab_size, bias=config.lm_head_use_bias
+    )
+    self.tok_embedding = nn.Embedding(config.vocab_size, config.embedding_dim)
+    self.transformer_blocks = nn.ModuleList(
+        TransformerBlock(config) for _ in range(config.num_layers)
+    )
+    self.final_norm = builder.build_norm(
+        config.embedding_dim,
+        config.final_norm_config,
+    )
+    self.rope_cache = attn_utils.build_rope_cache(
+        size=config.max_seq_len,
+        dim=int(config.attn_config.rotary_percentage * config.head_dim),
+        base=10_000,
+        condense_ratio=1,
+        dtype=torch.float32,
+        device=torch.device('cpu'),
+    )
+    self.mask_cache = attn_utils.build_causal_mask_cache(
+        size=config.max_seq_len, dtype=torch.float32, device=torch.device('cpu')
+    )
+    self.config = config
+
+  def forward(
+      self,
+      idx: torch.Tensor,
+      input_pos: torch.Tensor,
+      k_caches: torch.Tensor,
+      v_caches: torch.Tensor,
+  ) -> (torch.Tensor, torch.Tensor, torch.Tensor):
+    x = self.tok_embedding(idx)
+    cos, sin = self.rope_cache
+    cos = cos.index_select(0, input_pos)
+    sin = sin.index_select(0, input_pos)
+    mask = self.mask_cache.index_select(2, input_pos)
+    mask = mask[:, :, :, : self.config.max_seq_len]
+
+    for i, block in enumerate(self.transformer_blocks):
+      input_k, input_v = k_caches[i], v_caches[i]
+      x, (updated_k, updated_v) = block(
+          x, (cos, sin), mask, input_pos, (input_k, input_v)
+      )
+      k_caches[i], v_caches[i] = updated_k, updated_v
+
+    x = self.final_norm(x)
+    return self.lm_head(x), k_caches, v_caches
+
+
+def _export_stablehlo_mlir(model, args):
+  ep = torch.export.export(model, args)
+  stablehlo_gm = torch_xla.stablehlo.exported_program_to_stablehlo(ep)
+  return stablehlo_gm.get_stablehlo_text()
+
+
+def get_model_config() -> cfg.ModelConfig:
+  attn_config = cfg.AttentionConfig(
+      num_heads=32, num_query_groups=4, rotary_percentage=1.0
+  )
+  ff_config = cfg.FeedForwardConfig(
+      type=cfg.FeedForwardType.GATED,
+      activation=cfg.ActivationType.SILU,
+      intermediate_size=256,
+  )
+  norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.RMS_NORM)
+  config = cfg.ModelConfig(
+      vocab_size=150,
+      num_layers=2,
+      max_seq_len=100,
+      embedding_dim=128,
+      attn_config=attn_config,
+      ff_config=ff_config,
+      pre_attention_norm_config=norm_config,
+      pre_ff_norm_config=norm_config,
+      final_norm_config=norm_config,
+      enable_hlfb=True,
+  )
+  return config
+
+
+def get_sample_prefill_inputs() -> Tuple[torch.Tensor, torch.Tensor]:
+  idx = torch.unsqueeze(torch.arange(0, 100), 0)
+  input_pos = torch.arange(0, 100)
+  return idx, input_pos
+
+
+def get_sample_decode_inputs() -> Tuple[torch.Tensor, torch.Tensor]:
+  idx = torch.tensor([[1]], dtype=torch.long)
+  input_pos = torch.tensor([10])
+  return idx, input_pos
+
+
+def define_and_run() -> None:
+  dump_mlir = False
+
+  config = get_model_config()
+  model = ToyModelWithExternalKV(config)
+  print('running an inference')
+  k_caches = torch.zeros((2, 1, 100, 4, 4), dtype=torch.float32)
+  v_caches = torch.zeros((2, 1, 100, 4, 4), dtype=torch.float32)
+
+  idx, input_pos = get_sample_prefill_inputs()
+  decode_idx, decode_input_pos = get_sample_decode_inputs()
+  print(model.forward(idx, input_pos, k_caches, v_caches))
+
+  if dump_mlir:
+    mlir_text = _export_stablehlo_mlir(model, (idx, input_pos, k_caches, v_caches))
+    with open('/tmp/toy_model_with_external_kv.stablehlo.mlir', 'w') as f:
+      f.write(mlir_text)
+
+  # Convert model to tflite with 2 signatures (prefill + decode).
+  # TODO(b/344014416): currently conversion will fail, because we generate int64 index
+  # in dynamic update slice op.
+  print('converting toy model to tflite with 2 signatures (prefill + decode)')
+  edge_model = (
+      ai_edge_torch.signature('prefill', model, (idx, input_pos, k_caches, v_caches))
+      .signature('decode', model, (decode_idx, decode_input_pos, k_caches, v_caches))
+      .convert()
+  )
+  edge_model.export('/tmp/toy_external_kv_cache.tflite')
+
+
+if __name__ == '__main__':
+  with torch.inference_mode():
+    define_and_run()

ai_edge_torch/generative/layers/attention.py
@@ -14,7 +14,6 @@
 # ==============================================================================
 # Common building blocks for Attention layer.

-import math
 from typing import Optional, Tuple

 import torch
@@ -25,101 +24,8 @@ import ai_edge_torch.generative.layers.builder as builder
 from ai_edge_torch.generative.layers.kv_cache import KVCache
 import ai_edge_torch.generative.layers.model_config as cfg
 import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb
-from ai_edge_torch.
-
-
-def scaled_dot_product_attention(
-    q: torch.Tensor,
-    k: torch.Tensor,
-    v: torch.Tensor,
-    head_size: int,
-    mask: Optional[torch.Tensor] = None,
-    scale: Optional[float] = None,
-):
-  """Scaled dot product attention.
-
-  Args:
-    q (torch.Tensor): Query tensor, with shape [B, T, N, H].
-    k (torch.Tensor): Key tensor, with shape [B, T, KV_LEN, H].
-    v (torch.Tensor): Value tensor, with shape [B, T, KV_LEN, H].
-    head_size (int): head dimension.
-    mask (torch.Tensor): the optional mask tensor.
-
-  Returns:
-    The output tensor of scaled_dot_product_attention.
-  """
-
-  if scale is None:
-    scale = 1.0 / math.sqrt(head_size)
-
-  q = q.transpose(1, 2)
-  k = k.transpose(1, 2)
-  v = v.transpose(1, 2)
-  if q.size() != k.size():
-    # Handle the GQA case, where q.shape[1] % k.shape[1] == 0.
-    k = k.repeat_interleave(q.shape[1] // k.shape[1], dim=1)
-    v = v.repeat_interleave(q.shape[1] // v.shape[1], dim=1)
-  y = F.scaled_dot_product_attention(
-      q,
-      k,
-      v,
-      attn_mask=mask,
-      dropout_p=0.0,
-      is_causal=mask is None,
-      scale=scale,
-  )
-  return y.transpose(1, 2)
-
-
-def scaled_dot_product_attention_with_hlfb(
-    q: torch.Tensor,
-    k: torch.Tensor,
-    v: torch.Tensor,
-    head_size: int,
-    mask: Optional[torch.Tensor] = None,
-    scale: Optional[float] = None,
-):
-  """Scaled dot product attention with high-level function boundary enabled.
-
-  Args:
-    q (torch.Tensor): Query tensor, with shape [B, T, N, H].
-    k (torch.Tensor): Key tensor, with shape [B, T, KV_LEN, H].
-    v (torch.Tensor): Value tensor, with shape [B, T, KV_LEN, H].
-    head_size (int): head dimension.
-    mask (torch.Tensor): the optional mask tensor.
-
-  Returns:
-    The output tensor of scaled_dot_product_attention.
-  """
-
-  if scale is None:
-    scale = 1.0 / math.sqrt(head_size)
-
-  builder = StableHLOCompositeBuilder(
-      name="odml.scaled_dot_product_attention", attr={"scale": scale}
-  )
-  q, k, v, mask = builder.mark_inputs(q, k, v, mask)
-
-  q = q.transpose(1, 2)
-  k = k.transpose(1, 2)
-  v = v.transpose(1, 2)
-  if q.size() != k.size():
-    # Handle the GQA case, where q.shape[1] % k.shape[1] == 0.
-    k = k.repeat_interleave(q.shape[1] // k.shape[1], dim=1)
-    v = v.repeat_interleave(q.shape[1] // v.shape[1], dim=1)
-  y = F.scaled_dot_product_attention(
-      q,
-      k,
-      v,
-      attn_mask=mask,
-      dropout_p=0.0,
-      is_causal=mask is None,
-      scale=scale,
-  )
-
-  result = y.transpose(1, 2)
-  result = builder.mark_outputs(result)
-  return result
+from ai_edge_torch.generative.layers.scaled_dot_product_attention import scaled_dot_product_attention  # NOQA
+from ai_edge_torch.generative.layers.scaled_dot_product_attention import scaled_dot_product_attention_with_hlfb  # NOQA


 class TransformerBlock(nn.Module):
@@ -151,7 +57,7 @@ class TransformerBlock(nn.Module):
   def forward(
       self,
       x: torch.Tensor,
-      rope: Tuple[torch.Tensor, torch.Tensor],
+      rope: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
       mask: Optional[torch.Tensor] = None,
      input_pos: Optional[torch.Tensor] = None,
   ) -> torch.Tensor:
@@ -182,7 +88,6 @@ class TransformerBlock(nn.Module):
     return output


-# CausalSelfAttention which can support MHQ, MQA or GQA.
 class CausalSelfAttention(nn.Module):

   def __init__(
@@ -229,11 +134,12 @@ class CausalSelfAttention(nn.Module):
   def forward(
       self,
       x: torch.Tensor,
-      rope: Tuple[torch.Tensor, torch.Tensor],
+      rope: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
       mask: Optional[torch.Tensor] = None,
       input_pos: Optional[torch.Tensor] = None,
   ) -> torch.Tensor:
-    """Forward function of the CausalSelfAttention layer
+    """Forward function of the CausalSelfAttention layer, which can support
+    MQA, GQA and MHA.

     Args:
       x (torch.Tensor): the input tensor.
@@ -253,28 +159,35 @@ class CausalSelfAttention(nn.Module):
     # Assemble into a number of query groups to support MHA, MQA and GQA.
     q_per_kv = self.config.num_heads // self.config.num_query_groups
     total_qkv = q_per_kv + 2  # Each group has >=1 queries, 1 key, and 1 value.
-
-
-
+    if self.config.qkv_transpose_before_split:
+      qkv = qkv.view(
+          B, T, total_qkv, self.config.num_query_groups, self.head_dim
+      )  # (B, T, total_qkv, num_query_groups, head_dim)
+      qkv_axis = -3
+    else:
+      qkv = qkv.view(
+          B, T, self.config.num_query_groups, total_qkv, self.head_dim
+      )  # (B, T, num_query_groups, total_qkv, head_dim)
+      qkv_axis = -2

     # Split batched computation into three.
-    q, k, v = qkv.split((q_per_kv, 1, 1), dim
-
+    q, k, v = qkv.split((q_per_kv, 1, 1), dim=qkv_axis)
     q = q.reshape(B, T, -1, self.head_dim)
     k = k.reshape(B, T, -1, self.head_dim)
     v = v.reshape(B, T, -1, self.head_dim)

     # Compute rotary positional embedding for query and key.
     n_elem = int(self.config.rotary_percentage * self.head_dim)
-
-
-
-
-
-
-
-
-
+    if n_elem > 0:
+      cos, sin = rope
+      q_roped = rotary_pos_emb.apply_rope(
+          q[..., :n_elem], cos.repeat(1, 2), sin.repeat(1, 2)
+      )
+      k_roped = rotary_pos_emb.apply_rope(
+          k[..., :n_elem], cos.repeat(1, 2), sin.repeat(1, 2)
+      )
+      q = torch.cat((q_roped, q[..., n_elem:]), dim=-1)
+      k = torch.cat((k_roped, k[..., n_elem:]), dim=-1)

     if self.kv_cache is not None:
       # TODO(haoliang): Handle when execeeding max sequence length.

ai_edge_torch/generative/layers/builder.py
@@ -97,6 +97,10 @@ def _get_activation(type_: cfg.ActivationType):
     return F.gelu
   elif type_ == cfg.ActivationType.GELU_TANH:
     return lambda x: F.gelu(x, approximate="tanh")
+  elif type_ == cfg.ActivationType.GELU_QUICK:
+    # GELU approximation that is fast but somewhat inaccurate.
+    # See: https://github.com/hendrycks/GELUs
+    return lambda x: x * F.sigmoid(1.702 * x)
   elif type_ == cfg.ActivationType.RELU:
     return F.relu
   else:
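
The GELU_QUICK branch added above uses the sigmoid-based approximation x * sigmoid(1.702 * x). A quick numeric check (not part of the package) of how far it drifts from exact GELU:

import torch
import torch.nn.functional as F

x = torch.linspace(-3.0, 3.0, steps=7)
quick = x * torch.sigmoid(1.702 * x)  # same formula as the GELU_QUICK lambda
exact = F.gelu(x)
print((quick - exact).abs().max())    # small but nonzero, as the comment notes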

ai_edge_torch/generative/layers/model_config.py
@@ -27,6 +27,7 @@ class ActivationType(enum.Enum):
   SILU = enum.auto()
   GELU = enum.auto()
   GELU_TANH = enum.auto()
+  GELU_QUICK = enum.auto()
   RELU = enum.auto()


@@ -46,7 +47,7 @@ class FeedForwardType(enum.Enum):

   # `output = linear(act(linear(x)))`.
   SEQUENTIAL = enum.auto()
-  # `output =
+  # `output = linear_2(act(linear_1(x)) * lienar_3(x))`.
   GATED = enum.auto()


@@ -60,6 +61,9 @@ class AttentionConfig:
   num_query_groups: Optional[int]
   # Percentage of Rotary Positional Embedding added Q and K projections.
   rotary_percentage: Optional[float] = None
+  # Whether to transpose the query groups of qkv bundled tensor before
+  # splitting into separated tensors.
+  qkv_transpose_before_split: bool = False
   # Whether to use bias with Query, Key, and Value projection.
   qkv_use_bias: bool = False
   # Whether to use bias with attention output projection.

ai_edge_torch/generative/layers/scaled_dot_product_attention.py (new file)
@@ -0,0 +1,117 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+# Implements scaled dot product attention.
+
+import math
+from typing import Optional
+
+import torch
+import torch.nn.functional as F
+
+from ai_edge_torch.hlfb import StableHLOCompositeBuilder
+
+
+def scaled_dot_product_attention(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    head_size: int,
+    mask: Optional[torch.Tensor] = None,
+    scale: Optional[float] = None,
+):
+  """Scaled dot product attention.
+
+  Args:
+    q (torch.Tensor): Query tensor, with shape [B, T, N, H].
+    k (torch.Tensor): Key tensor, with shape [B, T, KV_LEN, H].
+    v (torch.Tensor): Value tensor, with shape [B, T, KV_LEN, H].
+    head_size (int): head dimension.
+    mask (torch.Tensor): the optional mask tensor.
+
+  Returns:
+    The output tensor of scaled_dot_product_attention.
+  """
+
+  if scale is None:
+    scale = 1.0 / math.sqrt(head_size)
+
+  q = q.transpose(1, 2)
+  k = k.transpose(1, 2)
+  v = v.transpose(1, 2)
+  if q.size() != k.size():
+    # Handle the GQA case, where q.shape[1] % k.shape[1] == 0.
+    k = k.repeat_interleave(q.shape[1] // k.shape[1], dim=1)
+    v = v.repeat_interleave(q.shape[1] // v.shape[1], dim=1)
+  y = F.scaled_dot_product_attention(
+      q,
+      k,
+      v,
+      attn_mask=mask,
+      dropout_p=0.0,
+      is_causal=mask is None,
+      scale=scale,
+  )
+  return y.transpose(1, 2)
+
+
+def scaled_dot_product_attention_with_hlfb(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    head_size: int,
+    mask: Optional[torch.Tensor] = None,
+    scale: Optional[float] = None,
+):
+  """Scaled dot product attention with high-level function boundary enabled.
+
+  Args:
+    q (torch.Tensor): Query tensor, with shape [B, T, N, H].
+    k (torch.Tensor): Key tensor, with shape [B, T, KV_LEN, H].
+    v (torch.Tensor): Value tensor, with shape [B, T, KV_LEN, H].
+    head_size (int): head dimension.
+    mask (torch.Tensor): the optional mask tensor.
+
+  Returns:
+    The output tensor of scaled_dot_product_attention.
+  """
+
+  if scale is None:
+    scale = 1.0 / math.sqrt(head_size)
+
+  builder = StableHLOCompositeBuilder(
+      name="odml.scaled_dot_product_attention", attr={"scale": scale}
+  )
+  q, k, v, mask = builder.mark_inputs(q, k, v, mask)
+
+  q = q.transpose(1, 2)
+  k = k.transpose(1, 2)
+  v = v.transpose(1, 2)
+  if q.size() != k.size():
+    # Handle the GQA case, where q.shape[1] % k.shape[1] == 0.
+    k = k.repeat_interleave(q.shape[1] // k.shape[1], dim=1)
+    v = v.repeat_interleave(q.shape[1] // v.shape[1], dim=1)
+  y = F.scaled_dot_product_attention(
+      q,
+      k,
+      v,
+      attn_mask=mask,
+      dropout_p=0.0,
+      is_causal=mask is None,
+      scale=scale,
+  )
+
+  result = y.transpose(1, 2)
+  result = builder.mark_outputs(result)
+  return result
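
The new module keeps the [B, T, N, H] layout described in its docstrings and folds grouped-query attention into repeat_interleave. A minimal usage sketch (toy shapes chosen here to exercise the GQA branch; not taken from the package):

import torch
from ai_edge_torch.generative.layers.scaled_dot_product_attention import scaled_dot_product_attention

B, T, N, H = 1, 8, 8, 16
q = torch.randn(B, T, N, H)
k = torch.randn(B, T, 2, H)  # 2 KV heads, so q/k sizes differ and the GQA path runs
v = torch.randn(B, T, 2, H)
out = scaled_dot_product_attention(q, k, v, head_size=H)  # causal, since no mask is passed
print(out.shape)  # torch.Size([1, 8, 8, 16])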

ai_edge_torch/generative/utilities/loader.py
@@ -69,10 +69,16 @@ def load_pytorch_statedict(full_path: str):
   Raises:
     ValueError: If no tensors are loaded from the provided directory or file.
   """
-  pattern = os.path.join(full_path, "*.bin") if os.path.isdir(full_path) else full_path
   files = []
-
-
+  patterns = []
+  if os.path.isdir(full_path):
+    patterns.append(os.path.join(full_path, "*.bin"))
+    patterns.append(os.path.join(full_path, "*.pt"))
+  else:
+    patterns.append(full_path)
+  for pattern in patterns:
+    for file in glob.glob(pattern):
+      files.append(file)

   tensors = {}
   for file in files:
@@ -93,18 +99,20 @@ class ModelLoader:

   @dataclass
   class TensorNames:
-    attn_query_proj: str
-    attn_key_proj: str
-    attn_value_proj: str
-
-
-
-
+    attn_query_proj: str = None
+    attn_key_proj: str = None
+    attn_value_proj: str = None
+    attn_fused_qkv_proj: str = None
+    attn_output_proj: str = None
+
+    ff_up_proj: str = None
+    ff_down_proj: str = None
     ff_gate_proj: str = None

     pre_attn_norm: str = None
     pre_ff_norm: str = None
     embedding: str = None
+    embedding_position: str = None
     final_norm: str = None
     lm_head: str = None

@@ -129,6 +137,10 @@ class ModelLoader:
       strict (bool, optional): Whether the converted keys are strictly
         matched. Defaults to True.

+    Returns:
+      missing_keys (List[str]): a list of str containing the missing keys
+      unexpected_keys (List[str]): a list of str containing the unexpected keys
+
     Raises:
       ValueError: If conversion results in unmapped tensors and strict mode is
         enabled.
@@ -139,6 +151,10 @@ class ModelLoader:
      converted_state["tok_embedding.weight"] = state.pop(
          f"{self._names.embedding}.weight"
      )
+    if self._names.embedding_position is not None:
+      converted_state["tok_embedding_position"] = state.pop(
+          f"{self._names.embedding_position}"
+      )
     if self._names.lm_head is not None:
       converted_state["lm_head.weight"] = state.pop(f"{self._names.lm_head}.weight")
       if model.config.lm_head_use_bias:
@@ -158,7 +174,7 @@ class ModelLoader:
      raise ValueError(
          f"Failed to map all tensor. Remaing tensor are: {list(state.keys())}"
      )
-    model.load_state_dict(converted_state, strict=strict)
+    return model.load_state_dict(converted_state, strict=strict)

   def _get_loader(self) -> Callable[[str], Dict[str, torch.Tensor]]:
     """A best effort method for finding appropriate state loader.
@@ -172,13 +188,15 @@ class ModelLoader:
     if os.path.isdir(self._file_name):
       if glob.glob(os.path.join(self._file_name, "*.safetensors")):
         return load_safetensors
-      if glob.glob(os.path.join(self._file_name, "*.bin"))
+      if glob.glob(os.path.join(self._file_name, "*.bin")) or glob.glob(
+          os.path.join(self._file_name, "*.pt")
+      ):
        return load_pytorch_statedict

     if self._file_name.endswith(".safetensors"):
       return load_safetensors

-    if self._file_name.endswith(".bin"):
+    if self._file_name.endswith(".bin") or self._file_name.endswith(".pt"):
       return load_pytorch_statedict

     raise ValueError(f"File format not supported.")
@@ -225,22 +243,33 @@ class ModelLoader:
       converted_state: Dict[str, torch.Tensor],
   ):
     prefix = f"transformer_blocks.{idx}"
-
-
-
-
-
-
-
-
-
-
-    converted_state[f"{prefix}.atten_func.qkv_projection.bias"] = self._fuse_qkv(
+    if self._names.attn_fused_qkv_proj:
+      fused_qkv_name = self._names.attn_fused_qkv_proj.format(idx)
+      converted_state[f"{prefix}.atten_func.qkv_projection.weight"] = state.pop(
+          f"{fused_qkv_name}.weight"
+      )
+    else:
+      q_name = self._names.attn_query_proj.format(idx)
+      k_name = self._names.attn_key_proj.format(idx)
+      v_name = self._names.attn_value_proj.format(idx)
+      converted_state[f"{prefix}.atten_func.qkv_projection.weight"] = self._fuse_qkv(
          config,
-          state.pop(f"{q_name}.
-          state.pop(f"{k_name}.
-          state.pop(f"{v_name}.
+          state.pop(f"{q_name}.weight"),
+          state.pop(f"{k_name}.weight"),
+          state.pop(f"{v_name}.weight"),
      )
+    if config.attn_config.qkv_use_bias:
+      if self._names.attn_fused_qkv_proj:
+        converted_state[f"{prefix}.atten_func.qkv_projection.bias"] = state.pop(
+            f"{fused_qkv_name}.bias"
+        )
+      else:
+        converted_state[f"{prefix}.atten_func.qkv_projection.bias"] = self._fuse_qkv(
+            config,
+            state.pop(f"{q_name}.bias"),
+            state.pop(f"{k_name}.bias"),
+            state.pop(f"{v_name}.bias"),
+        )

     o_name = self._names.attn_output_proj.format(idx)
     converted_state[f"{prefix}.atten_func.output_projection.weight"] = state.pop(
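
Taken together, the loader changes add *.pt checkpoint discovery, fused-QKV and positional-embedding tensor names, and a return value from load(). A minimal sketch of how the CLIP example above would use them; the checkpoint path is hypothetical and the unpacking assumes the usual load_state_dict return value:

import ai_edge_torch.generative.examples.stable_diffusion.clip as clip
import ai_edge_torch.generative.utilities.loader as loading_utils

clip_model = clip.CLIP(clip.get_model_config())
loader = loading_utils.ModelLoader('/tmp/clip_weights.pt', clip.TENSOR_NAMES)  # *.pt now accepted
missing, unexpected = loader.load(clip_model, strict=False)  # load() now returns the key lists
print(missing, unexpected)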

{ai_edge_torch_nightly-0.2.0.dev20240604.dist-info → ai_edge_torch_nightly-0.2.0.dev20240605.dist-info}/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: ai-edge-torch-nightly
-Version: 0.2.0.
+Version: 0.2.0.dev20240605
 Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
 Home-page: https://github.com/google-ai-edge/ai-edge-torch
 Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI

{ai_edge_torch_nightly-0.2.0.dev20240604.dist-info → ai_edge_torch_nightly-0.2.0.dev20240605.dist-info}/RECORD
@@ -6,7 +6,7 @@ ai_edge_torch/convert/conversion_utils.py,sha256=NpVm3Ms81_cIW5IYgGsr0BVganJJgBK
 ai_edge_torch/convert/converter.py,sha256=bjj5TV5_g4sGyuSh8ThEDydlNMqhkGSY4SzXK6vwhqI,6927
 ai_edge_torch/convert/fx_passes/__init__.py,sha256=Ll2nNwufjcV5nSruQPXiloq7F1E7pWJ2T5clXmy1lk8,2825
 ai_edge_torch/convert/fx_passes/_pass_base.py,sha256=ijVyDclPnd6a0DWWUJkwR4igj6f82S-cE1-83QGPvgw,1652
-ai_edge_torch/convert/fx_passes/build_aten_composite_pass.py,sha256=
+ai_edge_torch/convert/fx_passes/build_aten_composite_pass.py,sha256=2yqUwJJ2R233_X9FNMOP9oYRTTzH34TR_BIUj-wfnKw,7080
 ai_edge_torch/convert/fx_passes/build_upsample_bilinear2d_composite_pass.py,sha256=76XYoIlFDgrzp5QemoaEalPFcEbfszkEH_PLvO1ASCk,2607
 ai_edge_torch/convert/fx_passes/canonicalize_pass.py,sha256=UX6dJsxCqSkftXXvNBV-i7Bjk6H7qTyqzUnE640Itfg,1673
 ai_edge_torch/convert/fx_passes/inject_mlir_debuginfo_pass.py,sha256=aRT8hTS3n9ie28lgu6mygtFO6Ypwu0qjNb0c81v9HLs,2448
@@ -22,7 +22,7 @@ ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partition
 ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py,sha256=FlNKt2EhIKnlVEeUWTiv5sz446YKU6Yy1H0Gd6VRgkU,6432
 ai_edge_torch/convert/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/convert/test/test_convert.py,sha256=2qPmmGqnfV_o1gfsSdjGq3-JR1b323ligiy5MdAv9NA,8021
-ai_edge_torch/convert/test/test_convert_composites.py,sha256=
+ai_edge_torch/convert/test/test_convert_composites.py,sha256=_Ojc-H6GOS5s8ek3_8eRBL_AiCs-k3srziPJ2R4Ulrg,7255
 ai_edge_torch/convert/test/test_convert_multisig.py,sha256=kMaGnHe9ylfyU68qCifYcaGwJqyejKz--QQt9jS2oUA,4537
 ai_edge_torch/debug/__init__.py,sha256=TKvmnjVk3asvYcVh6C-LPr6srgAF_nppSAupWEXqwPY,707
 ai_edge_torch/debug/culprit.py,sha256=vklaxBUfINdo44OsH7csILK70N41gEThCGchGEfbTZw,12789
@@ -40,10 +40,10 @@ ai_edge_torch/generative/examples/phi2/convert_to_tflite.py,sha256=6nOuwx9q3AUlY
 ai_edge_torch/generative/examples/phi2/phi2.py,sha256=VvigzPQ_LJHeADTsMliwFwPe2BcnOhFgKDqr_WZ2JQ8,5540
 ai_edge_torch/generative/examples/stable_diffusion/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/stable_diffusion/attention.py,sha256=Lo4Dq7a3Kg-lyH56iqGtqCo5UaClQHRCTDdNagXGTo8,3535
-ai_edge_torch/generative/examples/stable_diffusion/clip.py,sha256=
-ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py,sha256=
+ai_edge_torch/generative/examples/stable_diffusion/clip.py,sha256=yUCJemEh4n8ez-yLgVU0HZAki-PZ9nY04DFjgpx9PUc,3698
+ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py,sha256=MI73RjOeD4Kh7AL0j5_QXiZq-rl_qCdibSE6eCQCyeY,3804
 ai_edge_torch/generative/examples/stable_diffusion/decoder.py,sha256=AgVAdUbSkHXONVUjAyBQEXhIUUlinf9kNljcBpWnj3A,3276
-ai_edge_torch/generative/examples/stable_diffusion/diffusion.py,sha256=
+ai_edge_torch/generative/examples/stable_diffusion/diffusion.py,sha256=TfbfsmuKoGsBENF9fYIAN_SMEQNhj-kjNdqQXFJGxpg,7784
 ai_edge_torch/generative/examples/stable_diffusion/encoder.py,sha256=L6hLaMQGb8-_BwSvTLIuDnZwfTqn0K4swBUjfPnYWZo,2341
 ai_edge_torch/generative/examples/stable_diffusion/pipeline.py,sha256=FCbnwlkpYYb-tF7KscbSYjNEdg7XnuLju1cDuIRoQv8,8277
 ai_edge_torch/generative/examples/stable_diffusion/tokenizer.py,sha256=r9RqbyNvuvXOGu3ojtl7ZmbC7o4Pt8aUKAhN1yCdtEc,3397
@@ -56,22 +56,24 @@ ai_edge_torch/generative/examples/stable_diffusion/samplers/sampler.py,sha256=5i
 ai_edge_torch/generative/examples/t5/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/t5/convert_to_tflite.py,sha256=bWtwtUacvJOEDUpuYvLTgkP7oTkXKJA-Tf4FPxlD1Cw,4536
 ai_edge_torch/generative/examples/t5/t5.py,sha256=q2gG5RRo7RgNzvHXYC0Juh6Tgt5d_RTMSWFaYvOKiZU,21065
-ai_edge_torch/generative/examples/t5/t5_attention.py,sha256=
+ai_edge_torch/generative/examples/t5/t5_attention.py,sha256=rRgwCEdVtzcJEaGbbBjw8HxCxrCX3pXA5nelawdYiME,9036
 ai_edge_torch/generative/examples/test_models/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/test_models/toy_model.py,sha256=EV07_MEG3fv9g0ZGu9gbBd5BjjrGkxCT1pv7dvhz4TI,3791
+ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py,sha256=rzL5h7Z5DIEgfpc1pWgYHdKt2aR8ha_CUqTKQBSPBaU,5521
 ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py,sha256=MUr6fSj2hBuYSlNbZtrBBpzqB_0WY-l_xYcd_TFFUjY,4831
 ai_edge_torch/generative/examples/tiny_llama/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py,sha256=E4I5OlC4zyl5cxiiu7uTED-zcwYRu210lP1zuT3xLBE,2566
 ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py,sha256=hVGpuI8gpj4Rn9k4otsRE22MSLFHBDlUOgioY6Ru6VI,5629
 ai_edge_torch/generative/layers/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/layers/attention.py,sha256=
+ai_edge_torch/generative/layers/attention.py,sha256=zNIBXxCOA5Mz_F_dfBbKpIovhtcB6q5a-i8oAxls1d0,7071
 ai_edge_torch/generative/layers/attention_utils.py,sha256=hXhuyKblPPxKIRzlAf1YNlwHgpbj-6nReRLhRHELx5k,6350
-ai_edge_torch/generative/layers/builder.py,sha256=
+ai_edge_torch/generative/layers/builder.py,sha256=WLTeDId9t3Xwt0h1zxzqoYyFvfrNzPKLskcl39q8Aqw,3403
 ai_edge_torch/generative/layers/feed_forward.py,sha256=4j2QaSCw59Jkk_ixKDpKEj7FLRauzuExTiSNRzAjAhE,2820
 ai_edge_torch/generative/layers/kv_cache.py,sha256=4uiZLO3om5G3--kT04Jt0esEYznbkJ7QLzSHfb8mjc4,3090
-ai_edge_torch/generative/layers/model_config.py,sha256=
+ai_edge_torch/generative/layers/model_config.py,sha256=2zT9nyoyuuyk5ziiww0VSJ6_JO7pDf7uOYbO9O3OQc4,4249
 ai_edge_torch/generative/layers/normalization.py,sha256=M27eW3TcNK20oaXClXtfnu0lLWrAGrSKSsbegRWnj3c,1867
 ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=12SsCuoRuLNCwnFGe_pHDOZEBwBcqXs87Aj0PaWWw4E,1383
+ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=dYafGC205QE5CLIbBTCI-7eVvEGZEHzs1toPEhemeDs,3391
 ai_edge_torch/generative/quantize/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/quantize/example.py,sha256=t-YwyKSPAG-OZC1DfH-0vfie2RHHpTSQjxUY-tmhu5g,1543
 ai_edge_torch/generative/quantize/quant_attrs.py,sha256=ffBALrrbrfiG_mrOr-f3B1Gc6PlAma9gtvVnfP7SDzI,1862
@@ -84,7 +86,7 @@ ai_edge_torch/generative/test/loader_test.py,sha256=N88CbrLW7Q2x1EyurwdXQ6YjsA-y
 ai_edge_torch/generative/test/test_model_conversion.py,sha256=1NfZxKo9Gx6CmVfd86K1FkmsNQnjzIV1ojBS85UGvT0,6500
 ai_edge_torch/generative/test/test_quantize.py,sha256=f70sH1ZFzdCwYj0MG-eg54WOC4LasR0D8CTUYpjxZYM,3728
 ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
-ai_edge_torch/generative/utilities/loader.py,sha256=
+ai_edge_torch/generative/utilities/loader.py,sha256=r-_hSanSjLZ_YXFpZUb0Up94u5F8JHp70Vf2nlONPSg,11269
 ai_edge_torch/generative/utilities/t5_loader.py,sha256=guDTv-12UUvJGl4eDvvZX3t4rRKewfXO8SpcYXM6gbc,16156
 ai_edge_torch/hlfb/__init__.py,sha256=rrje8a2iuKboBoV96bVq7nlS9HsnuEMbHE5JiWmCxFA,752
 ai_edge_torch/hlfb/mark_pattern/__init__.py,sha256=2VXnHcGf23VOuP-1GriGIpuL98leBB8twp_qaScMnmc,4799
@@ -100,8 +102,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=ExThdTXqnWmGC3-F6sdXbXr8nYzkEe_qCz
 ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
 ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=EIyKz-HY70DguWuSrJal8LpYXQ5ZSEUf3ZrVl7jikFM,4286
-ai_edge_torch_nightly-0.2.0.
-ai_edge_torch_nightly-0.2.0.
-ai_edge_torch_nightly-0.2.0.
-ai_edge_torch_nightly-0.2.0.
-ai_edge_torch_nightly-0.2.0.
+ai_edge_torch_nightly-0.2.0.dev20240605.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
+ai_edge_torch_nightly-0.2.0.dev20240605.dist-info/METADATA,sha256=GJzwmKkM4T0H-vTvMyoxiD80WfppEpE_sd2Ip4aSbgM,1748
+ai_edge_torch_nightly-0.2.0.dev20240605.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
+ai_edge_torch_nightly-0.2.0.dev20240605.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
+ai_edge_torch_nightly-0.2.0.dev20240605.dist-info/RECORD,,

Files without changes: the dist-info LICENSE, WHEEL, and top_level.txt.