ai-edge-torch-nightly 0.1.dev202405131930__py3-none-any.whl → 0.2.0.dev20240527__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of ai-edge-torch-nightly might be problematic.
- ai_edge_torch/generative/examples/stable_diffusion/__init__.py +14 -0
- ai_edge_torch/generative/examples/stable_diffusion/attention.py +106 -0
- ai_edge_torch/generative/examples/stable_diffusion/clip.py +78 -0
- ai_edge_torch/generative/examples/stable_diffusion/decoder.py +111 -0
- ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +498 -0
- ai_edge_torch/generative/examples/stable_diffusion/encoder.py +65 -0
- ai_edge_torch/generative/examples/stable_diffusion/util.py +62 -0
- ai_edge_torch/generative/test/loader_test.py +80 -0
- ai_edge_torch/generative/utilities/loader.py +8 -4
- {ai_edge_torch_nightly-0.1.dev202405131930.dist-info → ai_edge_torch_nightly-0.2.0.dev20240527.dist-info}/METADATA +2 -2
- {ai_edge_torch_nightly-0.1.dev202405131930.dist-info → ai_edge_torch_nightly-0.2.0.dev20240527.dist-info}/RECORD +14 -6
- {ai_edge_torch_nightly-0.1.dev202405131930.dist-info → ai_edge_torch_nightly-0.2.0.dev20240527.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.1.dev202405131930.dist-info → ai_edge_torch_nightly-0.2.0.dev20240527.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.1.dev202405131930.dist-info → ai_edge_torch_nightly-0.2.0.dev20240527.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/examples/stable_diffusion/__init__.py
@@ -0,0 +1,14 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
ai_edge_torch/generative/examples/stable_diffusion/attention.py
@@ -0,0 +1,106 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+import math
+
+import torch
+from torch import _decomp
+from torch import nn
+from torch._prims_common import mask_tensor
+from torch._prims_common.wrappers import out_wrapper
+from torch.nn import functional as F
+
+
+def triu(a):
+    h, w = a.shape[-2:]
+    mask = (
+        torch.arange(w, device=a.device).unsqueeze(-2)
+        - torch.arange(h, device=a.device).unsqueeze(-1)
+    ) >= 1
+    mask = torch.broadcast_to(mask, a.shape)
+    return torch.ops.aten.logical_and(a, mask).contiguous()
+
+
+# _decomp.decomposition_table[torch.ops.aten.triu.default] = triu
+
+
+class SelfAttention(nn.Module):
+
+    def __init__(self, n_heads, d_embed, in_proj_bias=True, out_proj_bias=True):
+        super().__init__()
+        self.in_proj = nn.Linear(d_embed, 3 * d_embed, bias=in_proj_bias)
+        self.out_proj = nn.Linear(d_embed, d_embed, bias=out_proj_bias)
+        self.n_heads = n_heads
+        self.d_head = d_embed // n_heads
+
+    def forward(self, x, causal_mask=False):
+        input_shape = x.shape
+        batch_size, sequence_length, d_embed = input_shape
+        interim_shape = (batch_size, sequence_length, self.n_heads, self.d_head)
+
+        q, k, v = self.in_proj(x).chunk(3, dim=-1)
+
+        q = q.view(interim_shape).transpose(1, 2)
+        k = k.view(interim_shape).transpose(1, 2)
+        v = v.view(interim_shape).transpose(1, 2)
+
+        weight = q @ k.transpose(-1, -2)
+        if causal_mask:
+            # mask = torch.ones_like(weight, dtype=torch.bool).triu(1)
+            mask = triu(torch.ones_like(weight, dtype=torch.bool))
+            weight.masked_fill_(mask, -torch.inf)
+        weight /= math.sqrt(self.d_head)
+        weight = F.softmax(weight, dim=-1)
+
+        output = weight @ v
+        output = output.transpose(1, 2)
+        output = output.reshape(input_shape)
+        output = self.out_proj(output)
+        return output
+
+
+class CrossAttention(nn.Module):
+
+    def __init__(self, n_heads, d_embed, d_cross, in_proj_bias=True, out_proj_bias=True):
+        super().__init__()
+        self.q_proj = nn.Linear(d_embed, d_embed, bias=in_proj_bias)
+        self.k_proj = nn.Linear(d_cross, d_embed, bias=in_proj_bias)
+        self.v_proj = nn.Linear(d_cross, d_embed, bias=in_proj_bias)
+        self.out_proj = nn.Linear(d_embed, d_embed, bias=out_proj_bias)
+        self.n_heads = n_heads
+        self.d_head = d_embed // n_heads
+
+    def forward(self, x, y):
+        input_shape = x.shape
+        batch_size, sequence_length, d_embed = input_shape
+        interim_shape = (batch_size, -1, self.n_heads, self.d_head)
+
+        q = self.q_proj(x)
+        k = self.k_proj(y)
+        v = self.v_proj(y)
+
+        q = q.view(interim_shape).transpose(1, 2)
+        k = k.view(interim_shape).transpose(1, 2)
+        v = v.view(interim_shape).transpose(1, 2)
+
+        weight = q @ k.transpose(-1, -2)
+        weight /= math.sqrt(self.d_head)
+        weight = F.softmax(weight, dim=-1)
+
+        output = weight @ v
+        output = output.transpose(1, 2).contiguous()
+        output = output.view(input_shape)
+        output = self.out_proj(output)
+        return output
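Note: a minimal shape check for the two modules above, runnable once the wheel is installed; the tensor sizes are illustrative and not taken from the release. The hand-rolled triu builds a strict-upper-triangular causal mask without torch.Tensor.triu, presumably to keep the graph exportable (the commented-out decomposition registration points the same way).

import torch

from ai_edge_torch.generative.examples.stable_diffusion.attention import CrossAttention
from ai_edge_torch.generative.examples.stable_diffusion.attention import SelfAttention

# SelfAttention fuses Q/K/V into a single in_proj Linear; output shape == input shape.
self_attn = SelfAttention(n_heads=8, d_embed=320)
x = torch.randn(1, 64, 320)  # (batch, sequence, embedding)
assert self_attn(x, causal_mask=True).shape == (1, 64, 320)

# CrossAttention draws keys/values from a context of a different width (d_cross).
cross_attn = CrossAttention(n_heads=8, d_embed=320, d_cross=768)
context = torch.randn(1, 77, 768)  # e.g. CLIP text embeddings
assert cross_attn(x, context).shape == (1, 64, 320)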
ai_edge_torch/generative/examples/stable_diffusion/clip.py
@@ -0,0 +1,78 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+import torch
+from torch import nn
+from torch._prims_common import mask_tensor
+from torch._prims_common.wrappers import out_wrapper
+
+from ai_edge_torch.generative.examples.stable_diffusion.attention import SelfAttention  # NOQA
+
+
+class CLIPEmbedding(nn.Module):
+
+    def __init__(self, n_vocab: int, n_embd: int, n_token: int):
+        super().__init__()
+        self.token_embedding = nn.Embedding(n_vocab, n_embd)
+        self.position_value = nn.Parameter(torch.zeros((n_token, n_embd)))
+
+    def forward(self, tokens):
+        x = self.token_embedding(tokens)
+        x += self.position_value
+        return x
+
+
+class CLIPLayer(nn.Module):
+
+    def __init__(self, n_head: int, n_embd: int):
+        super().__init__()
+        self.layernorm_1 = nn.LayerNorm(n_embd)
+        self.attention = SelfAttention(n_head, n_embd)
+        self.layernorm_2 = nn.LayerNorm(n_embd)
+        self.linear_1 = nn.Linear(n_embd, 4 * n_embd)
+        self.linear_2 = nn.Linear(4 * n_embd, n_embd)
+
+    def forward(self, x):
+        residue = x
+        x = self.layernorm_1(x)
+        x = self.attention(x, causal_mask=True)
+        x += residue
+
+        residue = x
+        x = self.layernorm_2(x)
+        x = self.linear_1(x)
+        x = x * torch.sigmoid(1.702 * x)  # QuickGELU activation function
+        x = self.linear_2(x)
+        x += residue
+
+        return x
+
+
+class CLIP(nn.Module):
+
+    def __init__(self):
+        super().__init__()
+        self.embedding = CLIPEmbedding(49408, 768, 77)
+        self.layers = nn.ModuleList([CLIPLayer(12, 768) for i in range(12)])
+        self.layernorm = nn.LayerNorm(768)
+
+    def forward(self, tokens: torch.LongTensor) -> torch.FloatTensor:
+        tokens = tokens.type(torch.long)
+
+        state = self.embedding(tokens)
+        for layer in self.layers:
+            state = layer(state)
+        output = self.layernorm(state)
+        return output
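Note: a usage sketch with randomly initialized weights (the real checkpoint is loaded separately); the vocabulary size 49408, width 768, and token length 77 are hard-coded in CLIP above.

import torch

from ai_edge_torch.generative.examples.stable_diffusion.clip import CLIP

clip = CLIP().eval()
tokens = torch.randint(0, 49408, (1, 77), dtype=torch.long)  # dummy padded token ids
with torch.no_grad():
    embeddings = clip(tokens)
print(embeddings.shape)  # torch.Size([1, 77, 768])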
ai_edge_torch/generative/examples/stable_diffusion/decoder.py
@@ -0,0 +1,111 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from torch import nn
+from torch.nn import functional as F
+
+from ai_edge_torch.generative.examples.stable_diffusion.attention import SelfAttention  # NOQA
+
+
+class AttentionBlock(nn.Module):
+
+    def __init__(self, channels):
+        super().__init__()
+        self.groupnorm = nn.GroupNorm(32, channels)
+        self.attention = SelfAttention(1, channels)
+
+    def forward(self, x):
+        residue = x
+        x = self.groupnorm(x)
+
+        n, c, h, w = x.shape
+        x = x.view((n, c, h * w))
+        x = x.transpose(-1, -2)
+        x = self.attention(x)
+        x = x.transpose(-1, -2)
+        x = x.view((n, c, h, w))
+
+        x += residue
+        return x
+
+
+class ResidualBlock(nn.Module):
+
+    def __init__(self, in_channels, out_channels):
+        super().__init__()
+        self.groupnorm_1 = nn.GroupNorm(32, in_channels)
+        self.conv_1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
+
+        self.groupnorm_2 = nn.GroupNorm(32, out_channels)
+        self.conv_2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
+
+        if in_channels == out_channels:
+            self.residual_layer = nn.Identity()
+        else:
+            self.residual_layer = nn.Conv2d(
+                in_channels, out_channels, kernel_size=1, padding=0
+            )
+
+    def forward(self, x):
+        residue = x
+
+        x = self.groupnorm_1(x)
+        x = F.silu(x)
+        x = self.conv_1(x)
+
+        x = self.groupnorm_2(x)
+        x = F.silu(x)
+        x = self.conv_2(x)
+
+        return x + self.residual_layer(residue)
+
+
+class Decoder(nn.Sequential):
+
+    def __init__(self):
+        super().__init__(
+            nn.Conv2d(4, 4, kernel_size=1, padding=0),
+            nn.Conv2d(4, 512, kernel_size=3, padding=1),
+            ResidualBlock(512, 512),
+            AttentionBlock(512),
+            ResidualBlock(512, 512),
+            ResidualBlock(512, 512),
+            ResidualBlock(512, 512),
+            ResidualBlock(512, 512),
+            nn.Upsample(scale_factor=2),
+            nn.Conv2d(512, 512, kernel_size=3, padding=1),
+            ResidualBlock(512, 512),
+            ResidualBlock(512, 512),
+            ResidualBlock(512, 512),
+            nn.Upsample(scale_factor=2),
+            nn.Conv2d(512, 512, kernel_size=3, padding=1),
+            ResidualBlock(512, 256),
+            ResidualBlock(256, 256),
+            ResidualBlock(256, 256),
+            nn.Upsample(scale_factor=2),
+            nn.Conv2d(256, 256, kernel_size=3, padding=1),
+            ResidualBlock(256, 128),
+            ResidualBlock(128, 128),
+            ResidualBlock(128, 128),
+            nn.GroupNorm(32, 128),
+            nn.SiLU(),
+            nn.Conv2d(128, 3, kernel_size=3, padding=1),
+        )
+
+    def forward(self, x):
+        x = x / 0.18215
+        for module in self:
+            x = module(x)
+        return x
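Note: Decoder is the VAE decoder; the division by 0.18215 undoes the latent scaling applied on the encoder side. A sketch of the shape contract (three Upsample stages give an 8x spatial blow-up), again with random weights:

import torch

from ai_edge_torch.generative.examples.stable_diffusion.decoder import Decoder

decoder = Decoder().eval()
latent = torch.randn(1, 4, 64, 64)  # latent for a 512x512 image (8x downsampled)
with torch.no_grad():
    image = decoder(latent)
print(image.shape)  # torch.Size([1, 3, 512, 512])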
ai_edge_torch/generative/examples/stable_diffusion/diffusion.py
@@ -0,0 +1,498 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from ai_edge_torch.generative.examples.stable_diffusion.attention import CrossAttention  # NOQA
+from ai_edge_torch.generative.examples.stable_diffusion.attention import SelfAttention  # NOQA
+
+
+class TimeEmbedding(nn.Module):
+
+    def __init__(self, n_embd):
+        super().__init__()
+        self.linear_1 = nn.Linear(n_embd, 4 * n_embd)
+        self.linear_2 = nn.Linear(4 * n_embd, 4 * n_embd)
+
+    def forward(self, x):
+        x = self.linear_1(x)
+        x = F.silu(x)
+        x = self.linear_2(x)
+        return x
+
+
+class ResidualBlock(nn.Module):
+
+    def __init__(self, in_channels, out_channels, n_time=1280):
+        super().__init__()
+        self.groupnorm_feature = nn.GroupNorm(32, in_channels)
+        self.conv_feature = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
+        self.linear_time = nn.Linear(n_time, out_channels)
+
+        self.groupnorm_merged = nn.GroupNorm(32, out_channels)
+        self.conv_merged = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
+
+        if in_channels == out_channels:
+            self.residual_layer = nn.Identity()
+        else:
+            self.residual_layer = nn.Conv2d(
+                in_channels, out_channels, kernel_size=1, padding=0
+            )
+
+    def forward(self, feature, time):
+        residue = feature
+
+        feature = self.groupnorm_feature(feature)
+        feature = F.silu(feature)
+        feature = self.conv_feature(feature)
+
+        time = F.silu(time)
+        time = self.linear_time(time)
+
+        merged = feature + time.unsqueeze(-1).unsqueeze(-1)
+        merged = self.groupnorm_merged(merged)
+        merged = F.silu(merged)
+        merged = self.conv_merged(merged)
+
+        return merged + self.residual_layer(residue)
+
+
+class AttentionBlock(nn.Module):
+
+    def __init__(self, n_head: int, n_embd: int, d_context=768):
+        super().__init__()
+        channels = n_head * n_embd
+
+        self.groupnorm = nn.GroupNorm(32, channels, eps=1e-6)
+        self.conv_input = nn.Conv2d(channels, channels, kernel_size=1, padding=0)
+
+        self.layernorm_1 = nn.LayerNorm(channels)
+        self.attention_1 = SelfAttention(n_head, channels, in_proj_bias=False)
+        self.layernorm_2 = nn.LayerNorm(channels)
+        self.attention_2 = CrossAttention(n_head, channels, d_context, in_proj_bias=False)
+        self.layernorm_3 = nn.LayerNorm(channels)
+        self.linear_geglu_1 = nn.Linear(channels, 4 * channels * 2)
+        self.linear_geglu_2 = nn.Linear(4 * channels, channels)
+
+        self.conv_output = nn.Conv2d(channels, channels, kernel_size=1, padding=0)
+
+    def forward(self, x, context):
+        residue_long = x
+
+        x = self.groupnorm(x)
+        x = self.conv_input(x)
+
+        n, c, h, w = x.shape
+        x = x.view((n, c, h * w))  # (n, c, hw)
+        x = x.transpose(-1, -2)  # (n, hw, c)
+
+        residue_short = x
+        x = self.layernorm_1(x)
+        x = self.attention_1(x)
+        x += residue_short
+
+        residue_short = x
+        x = self.layernorm_2(x)
+        x = self.attention_2(x, context)
+        x += residue_short
+
+        residue_short = x
+        x = self.layernorm_3(x)
+        x, gate = self.linear_geglu_1(x).chunk(2, dim=-1)
+        x = x * F.gelu(gate)
+        x = self.linear_geglu_2(x)
+        x += residue_short
+
+        x = x.transpose(-1, -2)  # (n, c, hw)
+        x = x.view((n, c, h, w))  # (n, c, h, w)
+
+        return self.conv_output(x) + residue_long
+
+
+class Upsample(nn.Module):
+
+    def __init__(self, channels):
+        super().__init__()
+        self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
+
+    def forward(self, x):
+        x = F.interpolate(x, scale_factor=2, mode='nearest')
+        return self.conv(x)
+
+
+class SwitchSequential(nn.Sequential):
+
+    def forward(self, x, context, time):
+        for layer in self:
+            if isinstance(layer, AttentionBlock):
+                x = layer(x, context)
+            elif isinstance(layer, ResidualBlock):
+                x = layer(x, time)
+            else:
+                x = layer(x)
+        return x
+
+
+class UNet(nn.Module):
+
+    def __init__(self):
+        super().__init__()
+        self.encoders = nn.ModuleList(
+            [
+                SwitchSequential(nn.Conv2d(4, 320, kernel_size=3, padding=1)),
+                SwitchSequential(ResidualBlock(320, 320), AttentionBlock(8, 40)),
+                SwitchSequential(ResidualBlock(320, 320), AttentionBlock(8, 40)),
+                SwitchSequential(nn.Conv2d(320, 320, kernel_size=3, stride=2, padding=1)),
+                SwitchSequential(ResidualBlock(320, 640), AttentionBlock(8, 80)),
+                SwitchSequential(ResidualBlock(640, 640), AttentionBlock(8, 80)),
+                SwitchSequential(nn.Conv2d(640, 640, kernel_size=3, stride=2, padding=1)),
+                SwitchSequential(ResidualBlock(640, 1280), AttentionBlock(8, 160)),
+                SwitchSequential(ResidualBlock(1280, 1280), AttentionBlock(8, 160)),
+                SwitchSequential(nn.Conv2d(1280, 1280, kernel_size=3, stride=2, padding=1)),
+                SwitchSequential(ResidualBlock(1280, 1280)),
+                SwitchSequential(ResidualBlock(1280, 1280)),
+            ]
+        )
+        self.bottleneck = SwitchSequential(
+            ResidualBlock(1280, 1280),
+            AttentionBlock(8, 160),
+            ResidualBlock(1280, 1280),
+        )
+
+        self.decoders = nn.ModuleList(
+            [
+                SwitchSequential(ResidualBlock(2560, 1280)),
+                SwitchSequential(ResidualBlock(2560, 1280)),
+                SwitchSequential(ResidualBlock(2560, 1280), Upsample(1280)),
+                SwitchSequential(ResidualBlock(2560, 1280), AttentionBlock(8, 160)),
+                SwitchSequential(ResidualBlock(2560, 1280), AttentionBlock(8, 160)),
+                SwitchSequential(
+                    ResidualBlock(1920, 1280), AttentionBlock(8, 160), Upsample(1280)
+                ),
+                SwitchSequential(ResidualBlock(1920, 640), AttentionBlock(8, 80)),
+                SwitchSequential(ResidualBlock(1280, 640), AttentionBlock(8, 80)),
+                SwitchSequential(
+                    ResidualBlock(960, 640), AttentionBlock(8, 80), Upsample(640)
+                ),
+                SwitchSequential(ResidualBlock(960, 320), AttentionBlock(8, 40)),
+                SwitchSequential(ResidualBlock(640, 320), AttentionBlock(8, 40)),
+                SwitchSequential(ResidualBlock(640, 320), AttentionBlock(8, 40)),
+            ]
+        )
+
+    def forward(self, x, context, time):
+        skip_connections = []
+        for layers in self.encoders:
+            x = layers(x, context, time)
+            skip_connections.append(x)
+
+        x = self.bottleneck(x, context, time)
+
+        # print('x shape:')
+        # print(list(x.shape))
+        # print('time shape:')
+        # print(list(time.shape))
+
+        for layers in self.decoders:
+            x = torch.cat((x, skip_connections.pop()), dim=1)
+            x = layers(x, context, time)
+
+        return x
+
+
+# The encoder component.
+class UNetEncoder(nn.Module):
+
+    def __init__(self):
+        super().__init__()
+        self.time_embedding = TimeEmbedding(320)
+        self.encoders = nn.ModuleList(
+            [
+                SwitchSequential(nn.Conv2d(4, 320, kernel_size=3, padding=1)),
+                SwitchSequential(ResidualBlock(320, 320), AttentionBlock(8, 40)),
+                SwitchSequential(ResidualBlock(320, 320), AttentionBlock(8, 40)),
+                SwitchSequential(nn.Conv2d(320, 320, kernel_size=3, stride=2, padding=1)),
+                SwitchSequential(ResidualBlock(320, 640), AttentionBlock(8, 80)),
+                SwitchSequential(ResidualBlock(640, 640), AttentionBlock(8, 80)),
+                SwitchSequential(nn.Conv2d(640, 640, kernel_size=3, stride=2, padding=1)),
+                SwitchSequential(ResidualBlock(640, 1280), AttentionBlock(8, 160)),
+                SwitchSequential(ResidualBlock(1280, 1280), AttentionBlock(8, 160)),
+                SwitchSequential(nn.Conv2d(1280, 1280, kernel_size=3, stride=2, padding=1)),
+                SwitchSequential(ResidualBlock(1280, 1280)),
+                SwitchSequential(ResidualBlock(1280, 1280)),
+            ]
+        )
+
+    def forward(self, x, context, time):
+        time_embedding = self.time_embedding(time)
+        skip_connections = []
+        for layers in self.encoders:
+            x = layers(x, context, time_embedding)
+            skip_connections.append(x)
+
+        return x, skip_connections, time_embedding
+
+
+class UNetBottleNeck(nn.Module):
+
+    def __init__(self):
+        super().__init__()
+        self.bottleneck = SwitchSequential(
+            ResidualBlock(1280, 1280),
+            AttentionBlock(8, 160),
+            ResidualBlock(1280, 1280),
+        )
+
+    def forward(self, x, context, time):
+        x = self.bottleneck(x, context, time)
+        # print('shape')
+        # print(list(x.shape))
+        return x
+
+
+# Unet decoder.
+class UNetDecoder1(nn.Module):
+
+    def __init__(self):
+        super().__init__()
+        self.decoders = nn.ModuleList(
+            [
+                SwitchSequential(ResidualBlock(2560, 1280)),
+                SwitchSequential(ResidualBlock(2560, 1280)),
+                SwitchSequential(ResidualBlock(2560, 1280), Upsample(1280)),
+                SwitchSequential(ResidualBlock(2560, 1280), AttentionBlock(8, 160)),
+            ]
+        )
+
+    def forward(self, x, context, time, s9, s10, s11, s12):
+        x = torch.cat((x, s12), dim=1)
+        x = self.decoders[0](x, context, time)
+        x = torch.cat((x, s11), dim=1)
+        x = self.decoders[1](x, context, time)
+        x = torch.cat((x, s10), dim=1)
+        x = self.decoders[2](x, context, time)
+        x = torch.cat((x, s9), dim=1)
+        x = self.decoders[3](x, context, time)
+
+        return x
+
+
+class UNetDecoder2(nn.Module):
+
+    def __init__(self):
+        super().__init__()
+        self.decoders = nn.ModuleList(
+            [
+                SwitchSequential(ResidualBlock(2560, 1280), AttentionBlock(8, 160)),
+                SwitchSequential(
+                    ResidualBlock(1920, 1280), AttentionBlock(8, 160), Upsample(1280)
+                ),
+                SwitchSequential(ResidualBlock(1920, 640), AttentionBlock(8, 80)),
+                SwitchSequential(ResidualBlock(1280, 640), AttentionBlock(8, 80)),
+            ]
+        )
+
+    def forward(self, x, context, time, s5, s6, s7, s8):
+        x = torch.cat((x, s8), dim=1)
+        x = self.decoders[0](x, context, time)
+        x = torch.cat((x, s7), dim=1)
+        x = self.decoders[1](x, context, time)
+        x = torch.cat((x, s6), dim=1)
+        x = self.decoders[2](x, context, time)
+        x = torch.cat((x, s5), dim=1)
+        x = self.decoders[3](x, context, time)
+        return x
+
+
+class UNetDecoder3(nn.Module):
+
+    def __init__(self):
+        super().__init__()
+        self.decoders = nn.ModuleList(
+            [
+                SwitchSequential(
+                    ResidualBlock(960, 640), AttentionBlock(8, 80), Upsample(640)
+                ),
+                SwitchSequential(ResidualBlock(960, 320), AttentionBlock(8, 40)),
+                SwitchSequential(ResidualBlock(640, 320), AttentionBlock(8, 40)),
+                SwitchSequential(ResidualBlock(640, 320), AttentionBlock(8, 40)),
+            ]
+        )
+        self.final = FinalLayer(320, 4)
+
+    def forward(self, x, context, time, s1, s2, s3, s4):
+        x = torch.cat((x, s4), dim=1)
+        x = self.decoders[0](x, context, time)
+        x = torch.cat((x, s3), dim=1)
+        x = self.decoders[1](x, context, time)
+        x = torch.cat((x, s2), dim=1)
+        x = self.decoders[2](x, context, time)
+        x = torch.cat((x, s1), dim=1)
+        x = self.decoders[3](x, context, time)
+
+        x = self.final(x)
+        return x
+
+
+class UNetDecoder(nn.Module):
+
+    def __init__(self):
+        super().__init__()
+        self.decoders = nn.ModuleList(
+            [
+                SwitchSequential(ResidualBlock(2560, 1280)),
+                SwitchSequential(ResidualBlock(2560, 1280)),
+                SwitchSequential(ResidualBlock(2560, 1280), Upsample(1280)),
+                SwitchSequential(ResidualBlock(2560, 1280), AttentionBlock(8, 160)),
+                SwitchSequential(ResidualBlock(2560, 1280), AttentionBlock(8, 160)),
+                SwitchSequential(
+                    ResidualBlock(1920, 1280), AttentionBlock(8, 160), Upsample(1280)
+                ),
+                SwitchSequential(ResidualBlock(1920, 640), AttentionBlock(8, 80)),
+                SwitchSequential(ResidualBlock(1280, 640), AttentionBlock(8, 80)),
+                SwitchSequential(
+                    ResidualBlock(960, 640), AttentionBlock(8, 80), Upsample(640)
+                ),
+                SwitchSequential(ResidualBlock(960, 320), AttentionBlock(8, 40)),
+                SwitchSequential(ResidualBlock(640, 320), AttentionBlock(8, 40)),
+                SwitchSequential(ResidualBlock(640, 320), AttentionBlock(8, 40)),
+            ]
+        )
+        self.final = FinalLayer(320, 4)
+
+    def forward(
+        self, x, context, time, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12
+    ):
+        x = torch.cat((x, s12), dim=1)
+        x = self.decoders[0](x, context, time)
+        x = torch.cat((x, s11), dim=1)
+        x = self.decoders[1](x, context, time)
+        x = torch.cat((x, s10), dim=1)
+        x = self.decoders[2](x, context, time)
+        x = torch.cat((x, s9), dim=1)
+        x = self.decoders[3](x, context, time)
+        x = torch.cat((x, s8), dim=1)
+        x = self.decoders[4](x, context, time)
+        x = torch.cat((x, s7), dim=1)
+        x = self.decoders[5](x, context, time)
+        x = torch.cat((x, s6), dim=1)
+        x = self.decoders[6](x, context, time)
+        x = torch.cat((x, s5), dim=1)
+        x = self.decoders[7](x, context, time)
+        x = torch.cat((x, s4), dim=1)
+        x = self.decoders[0](x, context, time)
+        x = torch.cat((x, s3), dim=1)
+        x = self.decoders[1](x, context, time)
+        x = torch.cat((x, s2), dim=1)
+        x = self.decoders[2](x, context, time)
+        x = torch.cat((x, s1), dim=1)
+        x = self.decoders[3](x, context, time)
+
+        x = self.final(x)
+
+        return x
+
+
+class FinalLayer(nn.Module):
+
+    def __init__(self, in_channels, out_channels):
+        super().__init__()
+        self.groupnorm = nn.GroupNorm(32, in_channels)
+        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
+
+    def forward(self, x):
+        x = self.groupnorm(x)
+        x = F.silu(x)
+        x = self.conv(x)
+        return x
+
+
+class Diffusion(nn.Module):
+
+    def __init__(self):
+        super().__init__()
+        self.time_embedding = TimeEmbedding(320)
+        self.unet = UNet()
+        self.final = FinalLayer(320, 4)
+
+    def forward(self, latent, context, time):
+        time = self.time_embedding(time)
+        # print('time:')
+        # print(list(time.shape))
+        output = self.unet(latent, context, time)
+        output = self.final(output)
+        return output
+
+
+# Calling code as if Diffusion is splitted into two parts.
+class DiffusionSplitted(nn.Module):
+
+    def __init__(self):
+        super().__init__()
+        self.unet_encoder = UNetEncoder()
+        self.bottleneck = UNetBottleNeck()
+        self.unet_decoder1 = UNetDecoder1()
+        self.unet_decoder2 = UNetDecoder2()
+        self.unet_decoder3 = UNetDecoder3()
+
+    def get_skip_connections(self, latent, context, time):
+        _, skip_connections, _ = self.unet_encoder(latent, context, time)
+        return skip_connections
+
+    def forward(self, latent, context, time):
+        output, skip_connections, time = self.unet_encoder(latent, context, time)
+        # print("output shape of unet encoder...")
+        # print(list(output.shape))
+        # print("output shape of time...")
+        # print(list(time.shape))
+        output = self.bottleneck(output, context, time)
+        # print("output shape of bn")
+        # print(list(output.shape))
+        output = self.unet_decoder1(
+            output,
+            context,
+            time,
+            skip_connections[8],
+            skip_connections[9],
+            skip_connections[10],
+            skip_connections[11],
+        )
+        # print("output shape of d1:")
+        # print(list(output.shape))
+
+        output = self.unet_decoder2(
+            output,
+            context,
+            time,
+            skip_connections[4],
+            skip_connections[5],
+            skip_connections[6],
+            skip_connections[7],
+        )
+
+        # print("output shape of d2:")
+        # print(list(output.shape))
+        output = self.unet_decoder3(
+            output,
+            context,
+            time,
+            skip_connections[0],
+            skip_connections[1],
+            skip_connections[2],
+            skip_connections[3],
+        )
+        return output
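Note: UNetDecoder.forward reuses self.decoders[0] through [3] for the last four skip connections (s4 down to s1) instead of continuing with self.decoders[8] through [11], which looks like an indexing slip; the split pipeline in DiffusionSplitted sidesteps it by using UNetDecoder1/2/3. Separately, a minimal invocation sketch of the combined Diffusion module with random weights and illustrative shapes (a 512x512 image corresponds to a 64x64 latent; the 320-wide time vector matches TimeEmbedding(320) and is produced by util.get_time_embedding below):

import torch

from ai_edge_torch.generative.examples.stable_diffusion.diffusion import Diffusion

model = Diffusion().eval()
latent = torch.randn(1, 4, 64, 64)  # noisy latent
context = torch.randn(1, 77, 768)   # CLIP text embeddings
time = torch.randn(1, 320)          # sinusoidal timestep embedding (see util.py)
with torch.no_grad():
    out = model(latent, context, time)
print(out.shape)  # torch.Size([1, 4, 64, 64]), the predicted noise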
ai_edge_torch/generative/examples/stable_diffusion/encoder.py
@@ -0,0 +1,65 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from ai_edge_torch.generative.examples.stable_diffusion.decoder import AttentionBlock  # NOQA
+from ai_edge_torch.generative.examples.stable_diffusion.decoder import ResidualBlock  # NOQA
+
+
+class Encoder(nn.Sequential):
+
+    def __init__(self):
+        super().__init__(
+            nn.Conv2d(3, 128, kernel_size=3, padding=1),
+            ResidualBlock(128, 128),
+            ResidualBlock(128, 128),
+            nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=0),
+            ResidualBlock(128, 256),
+            ResidualBlock(256, 256),
+            nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=0),
+            ResidualBlock(256, 512),
+            ResidualBlock(512, 512),
+            nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=0),
+            ResidualBlock(512, 512),
+            ResidualBlock(512, 512),
+            ResidualBlock(512, 512),
+            AttentionBlock(512),
+            ResidualBlock(512, 512),
+            nn.GroupNorm(32, 512),
+            nn.SiLU(),
+            nn.Conv2d(512, 8, kernel_size=3, padding=1),
+            nn.Conv2d(8, 8, kernel_size=1, padding=0),
+        )
+
+    def forward(self, x, noise):
+        for module in self:
+            if getattr(module, 'stride', None) == (
+                2,
+                2,
+            ):  # Padding at downsampling should be asymmetric (see #8)
+                x = F.pad(x, (0, 1, 0, 1))
+            x = module(x)
+
+        mean, log_variance = torch.chunk(x, 2, dim=1)
+        log_variance = torch.clamp(log_variance, -30, 20)
+        variance = log_variance.exp()
+        stdev = variance.sqrt()
+        x = mean + stdev * noise
+
+        x *= 0.18215
+        return x
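Note: the forward pass implements the VAE reparameterization, mean + stdev * noise, with the caller supplying the Gaussian noise explicitly (convenient for export, since the graph itself stays deterministic). A shape sketch with random weights:

import torch

from ai_edge_torch.generative.examples.stable_diffusion.encoder import Encoder

encoder = Encoder().eval()
image = torch.randn(1, 3, 512, 512)
noise = torch.randn(1, 4, 64, 64)  # epsilon for the reparameterization trick
with torch.no_grad():
    latent = encoder(image, noise)
print(latent.shape)  # torch.Size([1, 4, 64, 64]), scaled by 0.18215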
ai_edge_torch/generative/examples/stable_diffusion/util.py
@@ -0,0 +1,62 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+import os
+
+import numpy as np
+import torch
+
+
+def get_time_embedding(timestep):
+    freqs = torch.pow(10000, -torch.arange(start=0, end=160, dtype=torch.float32) / 160)
+    x = torch.tensor([timestep], dtype=torch.float32)[:, None] * freqs[None]
+    return torch.cat([torch.cos(x), torch.sin(x)], dim=-1)
+
+
+def get_alphas_cumprod(beta_start=0.00085, beta_end=0.0120, n_training_steps=1000):
+    betas = (
+        np.linspace(beta_start**0.5, beta_end**0.5, n_training_steps, dtype=np.float32)
+        ** 2
+    )
+    alphas = 1.0 - betas
+    alphas_cumprod = np.cumprod(alphas, axis=0)
+    return alphas_cumprod
+
+
+def get_file_path(filename, url=None):
+    module_location = os.path.dirname(os.path.abspath(__file__))
+    parent_location = os.path.dirname(module_location)
+    file_location = os.path.join(parent_location, "data", filename)
+    return file_location
+
+
+def move_channel(image, to):
+    if to == "first":
+        return image.permute(0, 3, 1, 2)  # (N, H, W, C) -> (N, C, H, W)
+    elif to == "last":
+        return image.permute(0, 2, 3, 1)  # (N, C, H, W) -> (N, H, W, C)
+    else:
+        raise ValueError("to must be one of the following: first, last")
+
+
+def rescale(x, old_range, new_range, clamp=False):
+    old_min, old_max = old_range
+    new_min, new_max = new_range
+    x -= old_min
+    x *= (new_max - new_min) / (old_max - old_min)
+    x += new_min
+    if clamp:
+        x = x.clamp(new_min, new_max)
+    return x
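Note: how these helpers compose, with illustrative values. get_time_embedding yields the (1, 320) vector that TimeEmbedding in diffusion.py expands to 1280 channels; get_alphas_cumprod is the scaled-linear DDPM schedule; rescale mutates tensor inputs in place via the augmented assignments.

import torch

from ai_edge_torch.generative.examples.stable_diffusion import util

t_emb = util.get_time_embedding(50)  # torch.Size([1, 320])
alphas_cumprod = util.get_alphas_cumprod()  # numpy array of shape (1000,)
img = util.rescale(torch.rand(1, 3, 8, 8), (0, 1), (-1, 1), clamp=True)  # [0,1] -> [-1,1]
print(t_emb.shape, alphas_cumprod.shape, float(img.min()) >= -1.0)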
ai_edge_torch/generative/test/loader_test.py
@@ -0,0 +1,80 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+# Testing weight loader utilities.
+
+import os
+import tempfile
+import unittest
+
+import safetensors.torch
+import torch
+
+from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
+from ai_edge_torch.generative.utilities import loader as loading_utils
+
+
+class TestLoader(unittest.TestCase):
+    """Unit tests that check weight loader."""
+
+    def test_load_safetensors(self):
+        with tempfile.TemporaryDirectory() as temp_dir:
+            file_path = os.path.join(temp_dir, "test.safetensors")
+            test_data = {"weight": torch.randn(20, 10), "bias": torch.randn(20)}
+            safetensors.torch.save_file(test_data, file_path)
+
+            loaded_tensors = loading_utils.load_safetensors(file_path)
+            self.assertIn("weight", loaded_tensors)
+            self.assertIn("bias", loaded_tensors)
+
+    def test_load_statedict(self):
+        with tempfile.TemporaryDirectory() as temp_dir:
+            file_path = os.path.join(temp_dir, "test.pt")
+            model = torch.nn.Linear(10, 5)
+            state_dict = model.state_dict()
+            torch.save(state_dict, file_path)
+
+            loaded_tensors = loading_utils.load_pytorch_statedict(file_path)
+            self.assertIn("weight", loaded_tensors)
+            self.assertIn("bias", loaded_tensors)
+
+    def test_model_loader(self):
+        with tempfile.TemporaryDirectory() as temp_dir:
+            file_path = os.path.join(temp_dir, "test.safetensors")
+            test_weights = {
+                "lm_head.weight": torch.randn((32000, 2048)),
+                "model.embed_tokens.weight": torch.randn((32000, 2048)),
+                "model.layers.0.input_layernorm.weight": torch.randn((2048,)),
+                "model.layers.0.mlp.down_proj.weight": torch.randn((2048, 5632)),
+                "model.layers.0.mlp.gate_proj.weight": torch.randn((5632, 2048)),
+                "model.layers.0.mlp.up_proj.weight": torch.randn((5632, 2048)),
+                "model.layers.0.post_attention_layernorm.weight": torch.randn((2048,)),
+                "model.layers.0.self_attn.k_proj.weight": torch.randn((256, 2048)),
+                "model.layers.0.self_attn.o_proj.weight": torch.randn((2048, 2048)),
+                "model.layers.0.self_attn.q_proj.weight": torch.randn((2048, 2048)),
+                "model.layers.0.self_attn.v_proj.weight": torch.randn((256, 2048)),
+                "model.norm.weight": torch.randn((2048,)),
+            }
+            safetensors.torch.save_file(test_weights, file_path)
+            cfg = tiny_llama.get_model_config()
+            cfg.num_layers = 1
+            model = tiny_llama.TinyLLamma(cfg)
+
+            loader = loading_utils.ModelLoader(file_path, tiny_llama.TENSOR_NAMES)
+            # if returns successfully, it means all the tensors were initiallized.
+            loader.load(model, strict=True)
+
+
+if __name__ == "__main__":
+    unittest.main()
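Note: judging by the trailing comment, load(model, strict=True) is expected to raise if any checkpoint tensor is left unmapped, so mere completion doubles as a completeness check. A likely invocation for just this file, assuming the package and safetensors are installed (the command is an assumption, not taken from the release):

python -m unittest ai_edge_torch.generative.test.loader_test -v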
ai_edge_torch/generative/utilities/loader.py
@@ -228,14 +228,14 @@ class ModelLoader:
         q_name = self._names.attn_query_proj.format(idx)
         k_name = self._names.attn_key_proj.format(idx)
         v_name = self._names.attn_value_proj.format(idx)
-        converted_state[f"{prefix}.atten_func.
+        converted_state[f"{prefix}.atten_func.qkv_projection.weight"] = self._fuse_qkv(
             config,
             state.pop(f"{q_name}.weight"),
             state.pop(f"{k_name}.weight"),
             state.pop(f"{v_name}.weight"),
         )
         if config.attn_config.qkv_use_bias:
-            converted_state[f"{prefix}.atten_func.
+            converted_state[f"{prefix}.atten_func.qkv_projection.bias"] = self._fuse_qkv(
                 config,
                 state.pop(f"{q_name}.bias"),
                 state.pop(f"{k_name}.bias"),
@@ -243,9 +243,13 @@ class ModelLoader:
             )

         o_name = self._names.attn_output_proj.format(idx)
-        converted_state[f"{prefix}.atten_func.
+        converted_state[f"{prefix}.atten_func.output_projection.weight"] = state.pop(
+            f"{o_name}.weight"
+        )
         if config.attn_config.output_proj_use_bias:
-            converted_state[f"{prefix}.atten_func.
+            converted_state[f"{prefix}.atten_func.output_projection.bias"] = state.pop(
+                f"{o_name}.bias"
+            )

     def _map_norm(
         self,
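Note: the renames target the attention layer's fused projections (qkv_projection, output_projection); the removed lines are truncated in the registry view, so only the new names are certain. A standalone sketch of the fusion idea, assuming plain concatenation along the output dimension; the real _fuse_qkv takes config so it can also handle grouped-query layouts such as the 2048-row q_proj versus 256-row k_proj/v_proj in loader_test.py:

import torch

def fuse_qkv(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    # Stack the (out_features, in_features) matrices so one Linear produces
    # Q, K and V in a single matmul, matching a fused qkv_projection layer.
    return torch.cat([q, k, v], dim=0)

w = fuse_qkv(torch.randn(2048, 2048), torch.randn(256, 2048), torch.randn(256, 2048))
print(w.shape)  # torch.Size([2560, 2048])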
{ai_edge_torch_nightly-0.1.dev202405131930.dist-info → ai_edge_torch_nightly-0.2.0.dev20240527.dist-info}/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: ai-edge-torch-nightly
-Version: 0.
+Version: 0.2.0.dev20240527
 Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
 Home-page: https://github.com/google-ai-edge/ai-edge-torch
 Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
@@ -24,8 +24,8 @@ Requires-Python: >=3.9, <3.12
 Description-Content-Type: text/markdown
 License-File: LICENSE
 Requires-Dist: numpy
-Requires-Dist: safetensors
 Requires-Dist: scipy
+Requires-Dist: safetensors
 Requires-Dist: tabulate
 Requires-Dist: torch ==2.4.*

{ai_edge_torch_nightly-0.1.dev202405131930.dist-info → ai_edge_torch_nightly-0.2.0.dev20240527.dist-info}/RECORD
@@ -38,6 +38,13 @@ ai_edge_torch/generative/examples/gemma/gemma.py,sha256=YF4Ua-1lnL3qhQnh1sY5-HlY
 ai_edge_torch/generative/examples/phi2/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/phi2/convert_to_tflite.py,sha256=6nOuwx9q3AUlYcjXRRXSr_3M2JKqdJ-vUf-uE3VFYHE,2512
 ai_edge_torch/generative/examples/phi2/phi2.py,sha256=VvigzPQ_LJHeADTsMliwFwPe2BcnOhFgKDqr_WZ2JQ8,5540
+ai_edge_torch/generative/examples/stable_diffusion/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
+ai_edge_torch/generative/examples/stable_diffusion/attention.py,sha256=Lo4Dq7a3Kg-lyH56iqGtqCo5UaClQHRCTDdNagXGTo8,3535
+ai_edge_torch/generative/examples/stable_diffusion/clip.py,sha256=8W4X9PKdnMsWGxXbBfm5OX6mX4XhvaMZ2gZw8yCTScY,2410
+ai_edge_torch/generative/examples/stable_diffusion/decoder.py,sha256=beLCtogA32oYT2nlATpyT-1xzkyPF8zi4v3kfHpw6Mc,3239
+ai_edge_torch/generative/examples/stable_diffusion/diffusion.py,sha256=nnsfgjSeL16U3TVdjTkRycaoWA2ChFeitx2RjGLpwyA,16200
+ai_edge_torch/generative/examples/stable_diffusion/encoder.py,sha256=X6ekByU19KNHNh5OaztZEROv-QwcCwVm1xiJjm2SCoo,2251
+ai_edge_torch/generative/examples/stable_diffusion/util.py,sha256=pG_dsV4xIaB7B8MgoRgSXBvLCVqDlF6bNunPN3GIm-s,2046
 ai_edge_torch/generative/examples/t5/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/t5/convert_to_tflite.py,sha256=bWtwtUacvJOEDUpuYvLTgkP7oTkXKJA-Tf4FPxlD1Cw,4536
 ai_edge_torch/generative/examples/t5/t5.py,sha256=q2gG5RRo7RgNzvHXYC0Juh6Tgt5d_RTMSWFaYvOKiZU,21065
@@ -65,10 +72,11 @@ ai_edge_torch/generative/quantize/quant_recipe_utils.py,sha256=9ktL7fT8C5j1dnY_7
 ai_edge_torch/generative/quantize/quant_recipes.py,sha256=CRA2ENevS-3usHqidWDe2wrf_epILE_7Hx-XfZQ9buk,1798
 ai_edge_torch/generative/quantize/supported_schemes.py,sha256=OQ4ghQXknA1PPjuY-xBgAmOpaIBgYFM8F2YAIot06hE,1345
 ai_edge_torch/generative/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
+ai_edge_torch/generative/test/loader_test.py,sha256=N88CbrLW7Q2x1EyurwdXQ6YjsA-ySQcPxpZH3QOGp-M,3317
 ai_edge_torch/generative/test/test_model_conversion.py,sha256=1NfZxKo9Gx6CmVfd86K1FkmsNQnjzIV1ojBS85UGvT0,6500
 ai_edge_torch/generative/test/test_quantize.py,sha256=f70sH1ZFzdCwYj0MG-eg54WOC4LasR0D8CTUYpjxZYM,3728
 ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
-ai_edge_torch/generative/utilities/loader.py,sha256=
+ai_edge_torch/generative/utilities/loader.py,sha256=c-ZOIDBVnat_5l2W5sWU7HQm7CL-wducS8poSu5PlUg,10107
 ai_edge_torch/generative/utilities/t5_loader.py,sha256=guDTv-12UUvJGl4eDvvZX3t4rRKewfXO8SpcYXM6gbc,16156
 ai_edge_torch/hlfb/__init__.py,sha256=rrje8a2iuKboBoV96bVq7nlS9HsnuEMbHE5JiWmCxFA,752
 ai_edge_torch/hlfb/mark_pattern/__init__.py,sha256=2VXnHcGf23VOuP-1GriGIpuL98leBB8twp_qaScMnmc,4799
@@ -84,8 +92,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=ExThdTXqnWmGC3-F6sdXbXr8nYzkEe_qCz
 ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
 ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=EIyKz-HY70DguWuSrJal8LpYXQ5ZSEUf3ZrVl7jikFM,4286
-ai_edge_torch_nightly-0.
-ai_edge_torch_nightly-0.
-ai_edge_torch_nightly-0.
-ai_edge_torch_nightly-0.
-ai_edge_torch_nightly-0.
+ai_edge_torch_nightly-0.2.0.dev20240527.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
+ai_edge_torch_nightly-0.2.0.dev20240527.dist-info/METADATA,sha256=nbZoIm0s6CWdrMkaffTrpz-XooKzTR1q0SQ17rs-AKU,1748
+ai_edge_torch_nightly-0.2.0.dev20240527.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
+ai_edge_torch_nightly-0.2.0.dev20240527.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
+ai_edge_torch_nightly-0.2.0.dev20240527.dist-info/RECORD,,
File without changes
File without changes