ai-edge-torch-nightly 0.2.0.dev20240604__py3-none-any.whl → 0.2.0.dev20240606__py3-none-any.whl
This diff shows the content of publicly available package versions as released to a supported public registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in that registry.
Potentially problematic release.
This version of ai-edge-torch-nightly has been flagged as potentially problematic.
- ai_edge_torch/convert/fx_passes/build_aten_composite_pass.py +30 -0
- ai_edge_torch/convert/test/test_convert_composites.py +18 -0
- ai_edge_torch/generative/examples/stable_diffusion/clip.py +83 -49
- ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +7 -5
- ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +0 -260
- ai_edge_torch/generative/examples/t5/t5_attention.py +2 -2
- ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +161 -0
- ai_edge_torch/generative/layers/attention.py +27 -114
- ai_edge_torch/generative/layers/builder.py +4 -0
- ai_edge_torch/generative/layers/model_config.py +5 -1
- ai_edge_torch/generative/layers/scaled_dot_product_attention.py +117 -0
- ai_edge_torch/generative/test/test_model_conversion.py +90 -80
- ai_edge_torch/generative/utilities/loader.py +56 -27
- {ai_edge_torch_nightly-0.2.0.dev20240604.dist-info → ai_edge_torch_nightly-0.2.0.dev20240606.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.2.0.dev20240604.dist-info → ai_edge_torch_nightly-0.2.0.dev20240606.dist-info}/RECORD +18 -16
- {ai_edge_torch_nightly-0.2.0.dev20240604.dist-info → ai_edge_torch_nightly-0.2.0.dev20240606.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.2.0.dev20240604.dist-info → ai_edge_torch_nightly-0.2.0.dev20240606.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.2.0.dev20240604.dist-info → ai_edge_torch_nightly-0.2.0.dev20240606.dist-info}/top_level.txt +0 -0
ai_edge_torch/convert/fx_passes/build_aten_composite_pass.py

@@ -99,6 +99,36 @@ def _aten_hardswish(gm: GraphModule, node: Node):
   node.target = hardswish


+@_register_composite_builder(torch.ops.aten.gelu.default)
+def _aten_gelu(gm: GraphModule, node: Node):
+  op = node.target
+  args_mapper = TorchOpArgumentsMapper(op)
+
+  def gelu(*args, **kwargs):
+    nonlocal op, args_mapper
+
+    full_kwargs = args_mapper.get_full_kwargs(args, kwargs)
+
+    # TFLite supports exact and tanh approximate.
+    if full_kwargs["approximate"] != "none" and full_kwargs["approximate"] != "tanh":
+      return op(*args, **kwargs)
+
+    builder = StableHLOCompositeBuilder(
+        "aten.gelu.default",
+        attr=_tree_map_to_composite_attr_values(
+            {
+                "approximate": full_kwargs["approximate"],
+            }
+        ),
+    )
+    full_kwargs["self"] = builder.mark_inputs(full_kwargs["self"])
+    output = op(full_kwargs["self"])
+    output = builder.mark_outputs(output)
+    return output
+
+  node.target = gelu
+
+
 @_register_composite_builder(torch.ops.aten.avg_pool2d.default)
 def _aten_avg_pool2d(gm: GraphModule, node: Node):
   op = node.target
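For context, the new pass is exercised end to end through the public converter; a minimal sketch, mirroring the tests added in the next hunk (the export path is a placeholder):

import torch
import ai_edge_torch

# GELU with tanh approximation; during conversion the pass above wraps
# aten.gelu.default in a StableHLO composite.
torch_module = torch.nn.GELU('tanh').eval()
args = (torch.randn((5, 10)),)
edge_model = ai_edge_torch.convert(torch_module, args)
edge_model.export('/tmp/gelu.tflite')  # placeholder path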
ai_edge_torch/convert/test/test_convert_composites.py

@@ -169,6 +169,24 @@ class TestConvertComposites(unittest.TestCase):
         model_coverage.compare_tflite_torch(edge_model, torch_module, tracing_args)
     )

+  def test_convert_gelu(self):
+    """Tests conversion of a GELU module."""
+
+    args = (torch.randn((5, 10)),)
+    torch_module = torch.nn.GELU().eval()
+    edge_model = ai_edge_torch.convert(torch_module, args)
+
+    self.assertTrue(model_coverage.compare_tflite_torch(edge_model, torch_module, args))
+
+  def test_convert_gelu_approximate(self):
+    """Tests conversion of an Approximate GELU module."""
+
+    args = (torch.randn((5, 10)),)
+    torch_module = torch.nn.GELU('tanh').eval()
+    edge_model = ai_edge_torch.convert(torch_module, args)
+
+    self.assertTrue(model_coverage.compare_tflite_torch(edge_model, torch_module, args))
+

 if __name__ == '__main__':
   unittest.main()
ai_edge_torch/generative/examples/stable_diffusion/clip.py

@@ -15,65 +15,99 @@

 import torch
 from torch import nn
-from torch._prims_common import mask_tensor
-from torch._prims_common.wrappers import out_wrapper

-from ai_edge_torch.generative.examples.stable_diffusion.attention import SelfAttention
+from ai_edge_torch.generative.layers.attention import TransformerBlock
+import ai_edge_torch.generative.layers.attention_utils as attention_utils
+import ai_edge_torch.generative.layers.builder as builder
+import ai_edge_torch.generative.layers.model_config as cfg
+import ai_edge_torch.generative.utilities.loader as loading_utils
+
+TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
+    ff_up_proj="layers.{}.linear_1",
+    ff_down_proj="layers.{}.linear_2",
+    ff_gate_proj="layers.{}.linear_1",
+    attn_fused_qkv_proj="layers.{}.attention.in_proj",
+    attn_output_proj="layers.{}.attention.out_proj",
+    pre_attn_norm="layers.{}.layernorm_1",
+    pre_ff_norm="layers.{}.layernorm_2",
+    embedding="embedding.token_embedding",
+    embedding_position="embedding.position_value",
+    final_norm="layernorm",
+    lm_head=None,
+)


-class CLIPEmbedding(nn.Module):
-
-  def __init__(self, n_vocab: int, n_embd: int, n_token: int):
-    super().__init__()
-    self.token_embedding = nn.Embedding(n_vocab, n_embd)
-    self.position_value = nn.Parameter(torch.zeros((n_token, n_embd)))
-
-  def forward(self, tokens):
-    x = self.token_embedding(tokens)
-    x += self.position_value
-    return x
-
-
-class CLIPLayer(nn.Module):
+class CLIP(nn.Module):
+  """CLIP text encoder
+  For details, see https://arxiv.org/abs/2103.00020
+  """

-  def __init__(self, n_head: int, n_embd: int):
+  def __init__(self, config: cfg.ModelConfig):
     super().__init__()
-    self.layernorm_1 = nn.LayerNorm(n_embd)
-    self.attention = SelfAttention(n_head, n_embd)
-    self.layernorm_2 = nn.LayerNorm(n_embd)
-    self.linear_1 = nn.Linear(n_embd, 4 * n_embd)
-    self.linear_2 = nn.Linear(4 * n_embd, n_embd)
-
-  def forward(self, x):
-    residue = x
-    x = self.layernorm_1(x)
-    x = self.attention(x, causal_mask=True)
-    x += residue
+    self.tok_embedding = nn.Embedding(config.vocab_size, config.embedding_dim)
+    self.tok_embedding_position = nn.Parameter(
+        torch.zeros((config.max_seq_len, config.embedding_dim))
+    )

-    residue = x
-    x = self.layernorm_2(x)
-    x = self.linear_1(x)
-    x = x * torch.sigmoid(1.702 * x)  # QuickGELU activation function
-    x = self.linear_2(x)
-    x += residue
+    self.config = config
+    self.transformer_blocks = nn.ModuleList(
+        TransformerBlock(config) for _ in range(config.num_layers)
+    )
+    self.final_norm = builder.build_norm(config.embedding_dim, config.final_norm_config)

-    return x
-
-
-class CLIP(nn.Module):
-
-  def __init__(self):
-    super().__init__()
-    self.embedding = CLIPEmbedding(49408, 768, 77)
-    self.layers = nn.ModuleList([CLIPLayer(12, 768) for i in range(12)])
-    self.layernorm = nn.LayerNorm(768)
+    self.mask_cache = attention_utils.build_causal_mask_cache(
+        size=config.max_seq_len, dtype=torch.float32
+    )

   @torch.inference_mode
   def forward(self, tokens: torch.LongTensor) -> torch.FloatTensor:
     tokens = tokens.type(torch.long)

-    state = self.embedding(tokens)
-    for layer in self.layers:
-      state = layer(state)
-    output = self.layernorm(state)
+    state = self.tok_embedding(tokens) + self.tok_embedding_position
+    for layer in self.transformer_blocks:
+      state = layer(state, mask=self.mask_cache)
+    output = self.final_norm(state)
     return output
+
+
+def get_model_config() -> cfg.ModelConfig:
+  max_seq_len = 77
+  vocab_size = 49408
+  num_layers = 12
+  num_heads = 12
+  num_query_groups = 12
+  embedding_dim = 768
+
+  attn_config = cfg.AttentionConfig(
+      num_heads=num_heads,
+      num_query_groups=num_query_groups,
+      rotary_percentage=0.0,
+      qkv_use_bias=True,
+      qkv_transpose_before_split=True,
+      output_proj_use_bias=True,
+      enable_kv_cache=False,
+  )
+
+  ff_config = cfg.FeedForwardConfig(
+      type=cfg.FeedForwardType.SEQUENTIAL,
+      activation=cfg.ActivationType.GELU_QUICK,
+      intermediate_size=embedding_dim * 4,
+      use_bias=True,
+  )
+
+  norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.LAYER_NORM)
+
+  config = cfg.ModelConfig(
+      vocab_size=vocab_size,
+      num_layers=num_layers,
+      max_seq_len=max_seq_len,
+      embedding_dim=embedding_dim,
+      attn_config=attn_config,
+      ff_config=ff_config,
+      pre_attention_norm_config=norm_config,
+      pre_ff_norm_config=norm_config,
+      final_norm_config=norm_config,
+      enable_hlfb=True,
+  )
+
+  return config
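To see how the refactored, config-driven CLIP is meant to be used, here is a minimal sketch; the checkpoint path is a placeholder, and the loader call mirrors the convert_to_tflite.py change below:

import torch
import ai_edge_torch.generative.examples.stable_diffusion.clip as clip
import ai_edge_torch.generative.utilities.loader as loading_utils

model = clip.CLIP(clip.get_model_config())
loader = loading_utils.ModelLoader('/path/to/clip_ckpt.pt', clip.TENSOR_NAMES)  # placeholder path
loader.load(model, strict=False)  # lm_head is None, so loading is non-strict

tokens = torch.zeros((1, 77), dtype=torch.long)  # max_seq_len = 77 per get_model_config()
embeddings = model(tokens)  # text embeddings of shape (1, 77, 768)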
ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py

@@ -19,11 +19,12 @@ from pathlib import Path
 import torch

 import ai_edge_torch
-from ai_edge_torch.generative.examples.stable_diffusion.clip import CLIP
+import ai_edge_torch.generative.examples.stable_diffusion.clip as clip
 from ai_edge_torch.generative.examples.stable_diffusion.decoder import Decoder
 from ai_edge_torch.generative.examples.stable_diffusion.diffusion import Diffusion  # NOQA
 from ai_edge_torch.generative.examples.stable_diffusion.encoder import Encoder
 import ai_edge_torch.generative.examples.stable_diffusion.util as util
+import ai_edge_torch.generative.utilities.loader as loading_utils


 @torch.inference_mode
@@ -36,8 +37,9 @@ def convert_stable_diffusion_to_tflite(
     image_width: int = 512,
 ):

-  clip = CLIP()
-  clip.load_state_dict(torch.load(clip_ckpt_path))
+  clip_model = clip.CLIP(clip.get_model_config())
+  loader = loading_utils.ModelLoader(clip_ckpt_path, clip.TENSOR_NAMES)
+  loader.load(clip_model, strict=False)

   encoder = Encoder()
   encoder.load_state_dict(torch.load(encoder_ckpt_path))
@@ -59,13 +61,13 @@ def convert_stable_diffusion_to_tflite(
   )

   input_latents = encoder(input_image, noise)
-  context_cond = clip(prompt_tokens)
+  context_cond = clip_model(prompt_tokens)
   context_uncond = torch.zeros_like(context_cond)
   context = torch.cat([context_cond, context_uncond], axis=0)
   time_embedding = util.get_time_embedding(timestamp)

   # CLIP text encoder
-  ai_edge_torch.signature('encode', clip, (prompt_tokens,)).convert().export(
+  ai_edge_torch.signature('encode', clip_model, (prompt_tokens,)).convert().export(
       '/tmp/stable_diffusion/clip.tflite'
   )

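For readers new to the converter API, the signature/convert/export pattern used above generalizes to any module; a minimal sketch with a placeholder model, signature name, and output path:

import torch
import ai_edge_torch

model = torch.nn.Linear(4, 4).eval()  # placeholder module
sample_args = (torch.randn(1, 4),)
# Name the entry point, convert, and write out a .tflite flatbuffer.
ai_edge_torch.signature('forward', model, sample_args).convert().export('/tmp/model.tflite')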
ai_edge_torch/generative/examples/stable_diffusion/diffusion.py

@@ -202,11 +202,6 @@ class UNet(nn.Module):

     x = self.bottleneck(x, context, time)

-    # print('x shape:')
-    # print(list(x.shape))
-    # print('time shape:')
-    # print(list(time.shape))
-
     for layers in self.decoders:
       x = torch.cat((x, skip_connections.pop()), dim=1)
       x = layers(x, context, time)
@@ -214,199 +209,6 @@ class UNet(nn.Module):
     return x


-# The encoder component.
-class UNetEncoder(nn.Module):
-
-  def __init__(self):
-    super().__init__()
-    self.time_embedding = TimeEmbedding(320)
-    self.encoders = nn.ModuleList(
-        [
-            SwitchSequential(nn.Conv2d(4, 320, kernel_size=3, padding=1)),
-            SwitchSequential(ResidualBlock(320, 320), AttentionBlock(8, 40)),
-            SwitchSequential(ResidualBlock(320, 320), AttentionBlock(8, 40)),
-            SwitchSequential(nn.Conv2d(320, 320, kernel_size=3, stride=2, padding=1)),
-            SwitchSequential(ResidualBlock(320, 640), AttentionBlock(8, 80)),
-            SwitchSequential(ResidualBlock(640, 640), AttentionBlock(8, 80)),
-            SwitchSequential(nn.Conv2d(640, 640, kernel_size=3, stride=2, padding=1)),
-            SwitchSequential(ResidualBlock(640, 1280), AttentionBlock(8, 160)),
-            SwitchSequential(ResidualBlock(1280, 1280), AttentionBlock(8, 160)),
-            SwitchSequential(nn.Conv2d(1280, 1280, kernel_size=3, stride=2, padding=1)),
-            SwitchSequential(ResidualBlock(1280, 1280)),
-            SwitchSequential(ResidualBlock(1280, 1280)),
-        ]
-    )
-
-  def forward(self, x, context, time):
-    time_embedding = self.time_embedding(time)
-    skip_connections = []
-    for layers in self.encoders:
-      x = layers(x, context, time_embedding)
-      skip_connections.append(x)
-
-    return x, skip_connections, time_embedding
-
-
-class UNetBottleNeck(nn.Module):
-
-  def __init__(self):
-    super().__init__()
-    self.bottleneck = SwitchSequential(
-        ResidualBlock(1280, 1280),
-        AttentionBlock(8, 160),
-        ResidualBlock(1280, 1280),
-    )
-
-  def forward(self, x, context, time):
-    x = self.bottleneck(x, context, time)
-    # print('shape')
-    # print(list(x.shape))
-    return x
-
-
-# Unet decoder.
-class UNetDecoder1(nn.Module):
-
-  def __init__(self):
-    super().__init__()
-    self.decoders = nn.ModuleList(
-        [
-            SwitchSequential(ResidualBlock(2560, 1280)),
-            SwitchSequential(ResidualBlock(2560, 1280)),
-            SwitchSequential(ResidualBlock(2560, 1280), Upsample(1280)),
-            SwitchSequential(ResidualBlock(2560, 1280), AttentionBlock(8, 160)),
-        ]
-    )
-
-  def forward(self, x, context, time, s9, s10, s11, s12):
-    x = torch.cat((x, s12), dim=1)
-    x = self.decoders[0](x, context, time)
-    x = torch.cat((x, s11), dim=1)
-    x = self.decoders[1](x, context, time)
-    x = torch.cat((x, s10), dim=1)
-    x = self.decoders[2](x, context, time)
-    x = torch.cat((x, s9), dim=1)
-    x = self.decoders[3](x, context, time)
-
-    return x
-
-
-class UNetDecoder2(nn.Module):
-
-  def __init__(self):
-    super().__init__()
-    self.decoders = nn.ModuleList(
-        [
-            SwitchSequential(ResidualBlock(2560, 1280), AttentionBlock(8, 160)),
-            SwitchSequential(
-                ResidualBlock(1920, 1280), AttentionBlock(8, 160), Upsample(1280)
-            ),
-            SwitchSequential(ResidualBlock(1920, 640), AttentionBlock(8, 80)),
-            SwitchSequential(ResidualBlock(1280, 640), AttentionBlock(8, 80)),
-        ]
-    )
-
-  def forward(self, x, context, time, s5, s6, s7, s8):
-    x = torch.cat((x, s8), dim=1)
-    x = self.decoders[0](x, context, time)
-    x = torch.cat((x, s7), dim=1)
-    x = self.decoders[1](x, context, time)
-    x = torch.cat((x, s6), dim=1)
-    x = self.decoders[2](x, context, time)
-    x = torch.cat((x, s5), dim=1)
-    x = self.decoders[3](x, context, time)
-    return x
-
-
-class UNetDecoder3(nn.Module):
-
-  def __init__(self):
-    super().__init__()
-    self.decoders = nn.ModuleList(
-        [
-            SwitchSequential(
-                ResidualBlock(960, 640), AttentionBlock(8, 80), Upsample(640)
-            ),
-            SwitchSequential(ResidualBlock(960, 320), AttentionBlock(8, 40)),
-            SwitchSequential(ResidualBlock(640, 320), AttentionBlock(8, 40)),
-            SwitchSequential(ResidualBlock(640, 320), AttentionBlock(8, 40)),
-        ]
-    )
-    self.final = FinalLayer(320, 4)
-
-  def forward(self, x, context, time, s1, s2, s3, s4):
-    x = torch.cat((x, s4), dim=1)
-    x = self.decoders[0](x, context, time)
-    x = torch.cat((x, s3), dim=1)
-    x = self.decoders[1](x, context, time)
-    x = torch.cat((x, s2), dim=1)
-    x = self.decoders[2](x, context, time)
-    x = torch.cat((x, s1), dim=1)
-    x = self.decoders[3](x, context, time)
-
-    x = self.final(x)
-    return x
-
-
-class UNetDecoder(nn.Module):
-
-  def __init__(self):
-    super().__init__()
-    self.decoders = nn.ModuleList(
-        [
-            SwitchSequential(ResidualBlock(2560, 1280)),
-            SwitchSequential(ResidualBlock(2560, 1280)),
-            SwitchSequential(ResidualBlock(2560, 1280), Upsample(1280)),
-            SwitchSequential(ResidualBlock(2560, 1280), AttentionBlock(8, 160)),
-            SwitchSequential(ResidualBlock(2560, 1280), AttentionBlock(8, 160)),
-            SwitchSequential(
-                ResidualBlock(1920, 1280), AttentionBlock(8, 160), Upsample(1280)
-            ),
-            SwitchSequential(ResidualBlock(1920, 640), AttentionBlock(8, 80)),
-            SwitchSequential(ResidualBlock(1280, 640), AttentionBlock(8, 80)),
-            SwitchSequential(
-                ResidualBlock(960, 640), AttentionBlock(8, 80), Upsample(640)
-            ),
-            SwitchSequential(ResidualBlock(960, 320), AttentionBlock(8, 40)),
-            SwitchSequential(ResidualBlock(640, 320), AttentionBlock(8, 40)),
-            SwitchSequential(ResidualBlock(640, 320), AttentionBlock(8, 40)),
-        ]
-    )
-    self.final = FinalLayer(320, 4)
-
-  def forward(
-      self, x, context, time, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12
-  ):
-    x = torch.cat((x, s12), dim=1)
-    x = self.decoders[0](x, context, time)
-    x = torch.cat((x, s11), dim=1)
-    x = self.decoders[1](x, context, time)
-    x = torch.cat((x, s10), dim=1)
-    x = self.decoders[2](x, context, time)
-    x = torch.cat((x, s9), dim=1)
-    x = self.decoders[3](x, context, time)
-    x = torch.cat((x, s8), dim=1)
-    x = self.decoders[4](x, context, time)
-    x = torch.cat((x, s7), dim=1)
-    x = self.decoders[5](x, context, time)
-    x = torch.cat((x, s6), dim=1)
-    x = self.decoders[6](x, context, time)
-    x = torch.cat((x, s5), dim=1)
-    x = self.decoders[7](x, context, time)
-    x = torch.cat((x, s4), dim=1)
-    x = self.decoders[0](x, context, time)
-    x = torch.cat((x, s3), dim=1)
-    x = self.decoders[1](x, context, time)
-    x = torch.cat((x, s2), dim=1)
-    x = self.decoders[2](x, context, time)
-    x = torch.cat((x, s1), dim=1)
-    x = self.decoders[3](x, context, time)
-
-    x = self.final(x)
-
-    return x
-
-
 class FinalLayer(nn.Module):

   def __init__(self, in_channels, out_channels):
@@ -432,68 +234,6 @@ class Diffusion(nn.Module):
   @torch.inference_mode
   def forward(self, latent, context, time):
     time = self.time_embedding(time)
-    # print('time:')
-    # print(list(time.shape))
     output = self.unet(latent, context, time)
     output = self.final(output)
     return output
-
-
-# Calling code as if Diffusion is splitted into two parts.
-class DiffusionSplitted(nn.Module):
-
-  def __init__(self):
-    super().__init__()
-    self.unet_encoder = UNetEncoder()
-    self.bottleneck = UNetBottleNeck()
-    self.unet_decoder1 = UNetDecoder1()
-    self.unet_decoder2 = UNetDecoder2()
-    self.unet_decoder3 = UNetDecoder3()
-
-  def get_skip_connections(self, latent, context, time):
-    _, skip_connections, _ = self.unet_encoder(latent, context, time)
-    return skip_connections
-
-  def forward(self, latent, context, time):
-    output, skip_connections, time = self.unet_encoder(latent, context, time)
-    # print("output shape of unet encoder...")
-    # print(list(output.shape))
-    # print("output shape of time...")
-    # print(list(time.shape))
-    output = self.bottleneck(output, context, time)
-    # print("output shape of bn")
-    # print(list(output.shape))
-    output = self.unet_decoder1(
-        output,
-        context,
-        time,
-        skip_connections[8],
-        skip_connections[9],
-        skip_connections[10],
-        skip_connections[11],
-    )
-    # print("output shape of d1:")
-    # print(list(output.shape))
-
-    output = self.unet_decoder2(
-        output,
-        context,
-        time,
-        skip_connections[4],
-        skip_connections[5],
-        skip_connections[6],
-        skip_connections[7],
-    )
-
-    # print("output shape of d2:")
-    # print(list(output.shape))
-    output = self.unet_decoder3(
-        output,
-        context,
-        time,
-        skip_connections[0],
-        skip_connections[1],
-        skip_connections[2],
-        skip_connections[3],
-    )
-    return output
ai_edge_torch/generative/examples/t5/t5_attention.py

@@ -20,11 +20,11 @@ import torch
 from torch import nn
 import torch.nn.functional as F

-from ai_edge_torch.generative.layers.attention import scaled_dot_product_attention  # NOQA
-from ai_edge_torch.generative.layers.attention import scaled_dot_product_attention_with_hlfb  # NOQA
 import ai_edge_torch.generative.layers.builder as builder
 from ai_edge_torch.generative.layers.kv_cache import KVCache
 import ai_edge_torch.generative.layers.model_config as cfg
+from ai_edge_torch.generative.layers.scaled_dot_product_attention import scaled_dot_product_attention  # NOQA
+from ai_edge_torch.generative.layers.scaled_dot_product_attention import scaled_dot_product_attention_with_hlfb  # NOQA


 class EncoderDecoderBlock(nn.Module):
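This relocation reflects the new ai_edge_torch/generative/layers/scaled_dot_product_attention.py module added in this release (+117 lines in the file list); only the import path changes, not the function names. A sketch of the updated import style, taken directly from the hunk above:

from ai_edge_torch.generative.layers.scaled_dot_product_attention import (
    scaled_dot_product_attention,
    scaled_dot_product_attention_with_hlfb,  # variant that marks a high-level function boundary (HLFB)
)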