Myosotis-Researches 0.0.12__py3-none-any.whl → 0.0.14__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
- myosotis_researches/CcGAN/models_128/CcGAN_SAGAN.py +301 -0
- myosotis_researches/CcGAN/models_128/ResNet_class_eval.py +141 -0
- myosotis_researches/CcGAN/models_128/ResNet_embed.py +188 -0
- myosotis_researches/CcGAN/models_128/ResNet_regre_eval.py +175 -0
- myosotis_researches/CcGAN/models_128/__init__.py +8 -0
- myosotis_researches/CcGAN/models_128/autoencoder.py +119 -0
- myosotis_researches/CcGAN/models_128/cGAN_SAGAN.py +276 -0
- myosotis_researches/CcGAN/models_128/cGAN_concat_SAGAN.py +245 -0
- myosotis_researches/CcGAN/models_256/CcGAN_SAGAN.py +303 -0
- myosotis_researches/CcGAN/models_256/ResNet_class_eval.py +142 -0
- myosotis_researches/CcGAN/models_256/ResNet_embed.py +188 -0
- myosotis_researches/CcGAN/models_256/ResNet_regre_eval.py +178 -0
- myosotis_researches/CcGAN/models_256/__init__.py +8 -0
- myosotis_researches/CcGAN/models_256/autoencoder.py +133 -0
- myosotis_researches/CcGAN/models_256/cGAN_SAGAN.py +280 -0
- myosotis_researches/CcGAN/models_256/cGAN_concat_SAGAN.py +249 -0
- myosotis_researches/CcGAN/utils/make_h5.py +13 -9
- {myosotis_researches-0.0.12.dist-info → myosotis_researches-0.0.14.dist-info}/METADATA +1 -1
- myosotis_researches-0.0.14.dist-info/RECORD +28 -0
- myosotis_researches-0.0.12.dist-info/RECORD +0 -12
- {myosotis_researches-0.0.12.dist-info → myosotis_researches-0.0.14.dist-info}/WHEEL +0 -0
- {myosotis_researches-0.0.12.dist-info → myosotis_researches-0.0.14.dist-info}/licenses/LICENSE +0 -0
- {myosotis_researches-0.0.12.dist-info → myosotis_researches-0.0.14.dist-info}/top_level.txt +0 -0
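Most of the growth between 0.0.12 and 0.0.14 is the new models_128 and models_256 subpackages. The diff below reproduces the two models_256 SAGAN modules (cGAN_SAGAN.py and cGAN_concat_SAGAN.py), the make_h5 rewrite, and the RECORD changes. For orientation, a minimal import sketch using the module paths listed in the 0.0.14 RECORD (the snippet is illustrative, not code shipped in the wheel):

    from myosotis_researches.CcGAN.models_256.cGAN_SAGAN import (
        cGAN_SAGAN_Generator,
        cGAN_SAGAN_Discriminator,
    )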
myosotis_researches/CcGAN/models_256/cGAN_SAGAN.py ADDED
@@ -0,0 +1,280 @@
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from torch.nn.utils import spectral_norm
+from torch.nn.init import xavier_uniform_
+
+
+def init_weights(m):
+    if type(m) == nn.Linear or type(m) == nn.Conv2d:
+        xavier_uniform_(m.weight)
+        m.bias.data.fill_(0.)
+
+
+def snconv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
+    return spectral_norm(nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
+                                   stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias))
+
+
+def snlinear(in_features, out_features):
+    return spectral_norm(nn.Linear(in_features=in_features, out_features=out_features))
+
+
+def sn_embedding(num_embeddings, embedding_dim):
+    return spectral_norm(nn.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim))
+
+
+class Self_Attn(nn.Module):
+    """ Self attention Layer"""
+
+    def __init__(self, in_channels):
+        super(Self_Attn, self).__init__()
+        self.in_channels = in_channels
+        self.snconv1x1_theta = snconv2d(in_channels=in_channels, out_channels=in_channels//8, kernel_size=1, stride=1, padding=0)
+        self.snconv1x1_phi = snconv2d(in_channels=in_channels, out_channels=in_channels//8, kernel_size=1, stride=1, padding=0)
+        self.snconv1x1_g = snconv2d(in_channels=in_channels, out_channels=in_channels//2, kernel_size=1, stride=1, padding=0)
+        self.snconv1x1_attn = snconv2d(in_channels=in_channels//2, out_channels=in_channels, kernel_size=1, stride=1, padding=0)
+        self.maxpool = nn.MaxPool2d(2, stride=2, padding=0)
+        self.softmax = nn.Softmax(dim=-1)
+        self.sigma = nn.Parameter(torch.zeros(1))
+
+    def forward(self, x):
+        """
+        inputs :
+            x : input feature maps (B X C X W X H)
+        returns :
+            out : self attention value + input feature
+            attention: B X N X N (N is Width*Height)
+        """
+        _, ch, h, w = x.size()
+        # Theta path
+        theta = self.snconv1x1_theta(x)
+        theta = theta.view(-1, ch//8, h*w)
+        # Phi path
+        phi = self.snconv1x1_phi(x)
+        phi = self.maxpool(phi)
+        phi = phi.view(-1, ch//8, h*w//4)
+        # Attn map
+        attn = torch.bmm(theta.permute(0, 2, 1), phi)
+        attn = self.softmax(attn)
+        # g path
+        g = self.snconv1x1_g(x)
+        g = self.maxpool(g)
+        g = g.view(-1, ch//2, h*w//4)
+        # Attn_g
+        attn_g = torch.bmm(g, attn.permute(0, 2, 1))
+        attn_g = attn_g.view(-1, ch//2, h, w)
+        attn_g = self.snconv1x1_attn(attn_g)
+        # Out
+        out = x + self.sigma*attn_g
+        return out
+
+
+class ConditionalBatchNorm2d(nn.Module):
+    # https://github.com/pytorch/pytorch/issues/8985#issuecomment-405080775
+    def __init__(self, num_features, num_classes):
+        super().__init__()
+        self.num_features = num_features
+        self.bn = nn.BatchNorm2d(num_features, momentum=0.001, affine=False)
+        self.embed = nn.Embedding(num_classes, num_features * 2)
+        # self.embed.weight.data[:, :num_features].normal_(1, 0.02)  # Initialise scale at N(1, 0.02)
+        self.embed.weight.data[:, :num_features].fill_(1.)  # Initialize scale to 1
+        self.embed.weight.data[:, num_features:].zero_()    # Initialize bias at 0
+
+    def forward(self, x, y):
+        out = self.bn(x)
+        gamma, beta = self.embed(y).chunk(2, 1)
+        out = gamma.view(-1, self.num_features, 1, 1) * out + beta.view(-1, self.num_features, 1, 1)
+        return out
+
+
+class GenBlock(nn.Module):
+    def __init__(self, in_channels, out_channels, num_classes):
+        super(GenBlock, self).__init__()
+        self.cond_bn1 = ConditionalBatchNorm2d(in_channels, num_classes)
+        self.relu = nn.ReLU(inplace=True)
+        self.snconv2d1 = snconv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1)
+        self.cond_bn2 = ConditionalBatchNorm2d(out_channels, num_classes)
+        self.snconv2d2 = snconv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1)
+        self.snconv2d0 = snconv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0)
+
+    def forward(self, x, labels):
+        x0 = x
+
+        x = self.cond_bn1(x, labels)
+        x = self.relu(x)
+        x = F.interpolate(x, scale_factor=2, mode='nearest')  # upsample
+        x = self.snconv2d1(x)
+        x = self.cond_bn2(x, labels)
+        x = self.relu(x)
+        x = self.snconv2d2(x)
+
+        x0 = F.interpolate(x0, scale_factor=2, mode='nearest')  # upsample
+        x0 = self.snconv2d0(x0)
+
+        out = x + x0
+        return out
+
+
+class cGAN_SAGAN_Generator(nn.Module):
+    """Generator."""
+
+    def __init__(self, z_dim, num_classes, g_conv_dim=64):
+        super(cGAN_SAGAN_Generator, self).__init__()
+
+        self.z_dim = z_dim
+        self.g_conv_dim = g_conv_dim
+        self.snlinear0 = snlinear(in_features=z_dim, out_features=g_conv_dim*16*4*4)
+        self.block1 = GenBlock(g_conv_dim*16, g_conv_dim*16, num_classes)
+        self.block2 = GenBlock(g_conv_dim*16, g_conv_dim*8, num_classes)
+        self.block3 = GenBlock(g_conv_dim*8, g_conv_dim*4, num_classes)
+        self.block4 = GenBlock(g_conv_dim*4, g_conv_dim*2, num_classes)
+        self.self_attn = Self_Attn(g_conv_dim*2)
+        self.block5 = GenBlock(g_conv_dim*2, g_conv_dim*2, num_classes)
+        self.block6 = GenBlock(g_conv_dim*2, g_conv_dim, num_classes)
+        self.bn = nn.BatchNorm2d(g_conv_dim, eps=1e-5, momentum=0.0001, affine=True)
+        self.relu = nn.ReLU(inplace=True)
+        self.snconv2d1 = snconv2d(in_channels=g_conv_dim, out_channels=3, kernel_size=3, stride=1, padding=1)
+        self.tanh = nn.Tanh()
+
+        # Weight init
+        self.apply(init_weights)
+
+    def forward(self, z, labels):
+        # n x z_dim
+        act0 = self.snlinear0(z)                        # n x g_conv_dim*16*4*4
+        act0 = act0.view(-1, self.g_conv_dim*16, 4, 4)  # n x g_conv_dim*16 x 4 x 4
+        act1 = self.block1(act0, labels)                # n x g_conv_dim*16 x 8 x 8
+        act2 = self.block2(act1, labels)                # n x g_conv_dim*8 x 16 x 16
+        act3 = self.block3(act2, labels)                # n x g_conv_dim*4 x 32 x 32
+        act4 = self.block4(act3, labels)                # n x g_conv_dim*2 x 64 x 64
+        act4 = self.self_attn(act4)                     # n x g_conv_dim*2 x 64 x 64
+        act5 = self.block5(act4, labels)                # n x g_conv_dim*2 x 128 x 128
+        act6 = self.block6(act5, labels)                # n x g_conv_dim x 256 x 256
+        act6 = self.bn(act6)                            # n x g_conv_dim x 256 x 256
+        act6 = self.relu(act6)                          # n x g_conv_dim x 256 x 256
+        act7 = self.snconv2d1(act6)                     # n x 3 x 256 x 256
+        act7 = self.tanh(act7)                          # n x 3 x 256 x 256
+        return act7
+
+
+class DiscOptBlock(nn.Module):
+    def __init__(self, in_channels, out_channels):
+        super(DiscOptBlock, self).__init__()
+        self.snconv2d1 = snconv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1)
+        self.relu = nn.ReLU(inplace=True)
+        self.snconv2d2 = snconv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1)
+        self.downsample = nn.AvgPool2d(2)
+        self.snconv2d0 = snconv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0)
+
+    def forward(self, x):
+        x0 = x
+
+        x = self.snconv2d1(x)
+        x = self.relu(x)
+        x = self.snconv2d2(x)
+        x = self.downsample(x)
+
+        x0 = self.downsample(x0)
+        x0 = self.snconv2d0(x0)
+
+        out = x + x0
+        return out
+
+
+class DiscBlock(nn.Module):
+    def __init__(self, in_channels, out_channels):
+        super(DiscBlock, self).__init__()
+        self.relu = nn.ReLU(inplace=True)
+        self.snconv2d1 = snconv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1)
+        self.snconv2d2 = snconv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1)
+        self.downsample = nn.AvgPool2d(2)
+        self.ch_mismatch = False
+        if in_channels != out_channels:
+            self.ch_mismatch = True
+        self.snconv2d0 = snconv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0)
+
+    def forward(self, x, downsample=True):
+        x0 = x
+
+        x = self.relu(x)
+        x = self.snconv2d1(x)
+        x = self.relu(x)
+        x = self.snconv2d2(x)
+        if downsample:
+            x = self.downsample(x)
+
+        if downsample or self.ch_mismatch:
+            x0 = self.snconv2d0(x0)
+            if downsample:
+                x0 = self.downsample(x0)
+
+        out = x + x0
+        return out
+
+
+class cGAN_SAGAN_Discriminator(nn.Module):
+    """Discriminator."""
+
+    def __init__(self, num_classes, d_conv_dim=64):
+        super(cGAN_SAGAN_Discriminator, self).__init__()
+        self.d_conv_dim = d_conv_dim
+        self.opt_block1 = DiscOptBlock(3, d_conv_dim)
+        self.block1 = DiscBlock(d_conv_dim, d_conv_dim*2)
+        self.self_attn = Self_Attn(d_conv_dim*2)
+        self.block2 = DiscBlock(d_conv_dim*2, d_conv_dim*4)
+        self.block3 = DiscBlock(d_conv_dim*4, d_conv_dim*6)
+        self.block4 = DiscBlock(d_conv_dim*6, d_conv_dim*12)
+        self.block5 = DiscBlock(d_conv_dim*12, d_conv_dim*12)
+        self.block6 = DiscBlock(d_conv_dim*12, d_conv_dim*16)
+        self.relu = nn.ReLU(inplace=True)
+        self.snlinear1 = snlinear(in_features=d_conv_dim*16, out_features=1)
+        self.sn_embedding1 = sn_embedding(num_classes, d_conv_dim*16)
+
+        # Weight init
+        self.apply(init_weights)
+        xavier_uniform_(self.sn_embedding1.weight)
+
+    def forward(self, x, labels):
+        # n x 3 x 256 x 256
+        h0 = self.opt_block1(x)                 # n x d_conv_dim x 128 x 128
+        h1 = self.block1(h0)                    # n x d_conv_dim*2 x 64 x 64
+        h1 = self.self_attn(h1)                 # n x d_conv_dim*2 x 64 x 64
+        h2 = self.block2(h1)                    # n x d_conv_dim*4 x 32 x 32
+        h3 = self.block3(h2)                    # n x d_conv_dim*6 x 16 x 16
+        h4 = self.block4(h3)                    # n x d_conv_dim*12 x 8 x 8
+        h5 = self.block5(h4)                    # n x d_conv_dim*12 x 4 x 4
+        h6 = self.block6(h5, downsample=False)  # n x d_conv_dim*16 x 4 x 4
+        h6 = self.relu(h6)                      # n x d_conv_dim*16 x 4 x 4
+        h6 = torch.sum(h6, dim=[2,3])           # n x d_conv_dim*16
+        output1 = torch.squeeze(self.snlinear1(h6))  # n
+        # Projection
+        h_labels = self.sn_embedding1(labels)   # n x d_conv_dim*16
+        proj = torch.mul(h6, h_labels)          # n x d_conv_dim*16
+        output2 = torch.sum(proj, dim=[1])      # n
+        # Out
+        output = output1 + output2              # n
+        return output
+
+
+
+if __name__ == "__main__":
+
+    num_classes = 10
+
+    netG = cGAN_SAGAN_Generator(z_dim=128, num_classes=num_classes, g_conv_dim=128).cuda()
+    netD = cGAN_SAGAN_Discriminator(num_classes=num_classes, d_conv_dim=128).cuda()
+
+    n = 4
+    # target = torch.randint(high=num_classes, size=(1,n))  # set size (2,10) for MHE
+    # y = torch.zeros(n, num_classes)
+    # y[range(y.shape[0]), target]=1
+    # y = y.type(torch.long).cuda()
+    y = torch.randint(high=num_classes, size=(n,)).type(torch.long).cuda()
+    z = torch.randn(n, 128).cuda()
+    x = netG(z,y)
+    o = netD(x,y)
+    print(x.size())
+    print(o.size())
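This is the 256x256 class-conditional SAGAN: the generator injects the label through ConditionalBatchNorm2d in each GenBlock, and the discriminator combines an unconditional logit (output1) with a projection term, the inner product of the pooled features and a spectral-normalized class embedding (output2). The __main__ demo above assumes a CUDA device; a CPU-only shape check with the classes above in scope might look like this (the small conv dims are illustrative choices, not package defaults):

    import torch

    # CPU-only smoke test; small dims keep it fast, they are not shipped values.
    netG = cGAN_SAGAN_Generator(z_dim=128, num_classes=10, g_conv_dim=16)
    netD = cGAN_SAGAN_Discriminator(num_classes=10, d_conv_dim=16)
    z = torch.randn(2, 128)
    y = torch.randint(high=10, size=(2,))   # class indices for the embeddings
    x = netG(z, y)                          # -> torch.Size([2, 3, 256, 256])
    print(x.size(), netD(x, y).size())      # logits -> torch.Size([2])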
myosotis_researches/CcGAN/models_256/cGAN_concat_SAGAN.py ADDED
@@ -0,0 +1,249 @@
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from torch.nn.utils import spectral_norm
+from torch.nn.init import xavier_uniform_
+
+
+def init_weights(m):
+    if type(m) == nn.Linear or type(m) == nn.Conv2d:
+        xavier_uniform_(m.weight)
+        m.bias.data.fill_(0.)
+
+
+def snconv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
+    return spectral_norm(nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
+                                   stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias))
+
+
+def snlinear(in_features, out_features):
+    return spectral_norm(nn.Linear(in_features=in_features, out_features=out_features))
+
+
+def sn_embedding(num_embeddings, embedding_dim):
+    return spectral_norm(nn.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim))
+
+
+class Self_Attn(nn.Module):
+    """ Self attention Layer"""
+
+    def __init__(self, in_channels):
+        super(Self_Attn, self).__init__()
+        self.in_channels = in_channels
+        self.snconv1x1_theta = snconv2d(in_channels=in_channels, out_channels=in_channels//8, kernel_size=1, stride=1, padding=0)
+        self.snconv1x1_phi = snconv2d(in_channels=in_channels, out_channels=in_channels//8, kernel_size=1, stride=1, padding=0)
+        self.snconv1x1_g = snconv2d(in_channels=in_channels, out_channels=in_channels//2, kernel_size=1, stride=1, padding=0)
+        self.snconv1x1_attn = snconv2d(in_channels=in_channels//2, out_channels=in_channels, kernel_size=1, stride=1, padding=0)
+        self.maxpool = nn.MaxPool2d(2, stride=2, padding=0)
+        self.softmax = nn.Softmax(dim=-1)
+        self.sigma = nn.Parameter(torch.zeros(1))
+
+    def forward(self, x):
+        """
+        inputs :
+            x : input feature maps (B X C X W X H)
+        returns :
+            out : self attention value + input feature
+            attention: B X N X N (N is Width*Height)
+        """
+        _, ch, h, w = x.size()
+        # Theta path
+        theta = self.snconv1x1_theta(x)
+        theta = theta.view(-1, ch//8, h*w)
+        # Phi path
+        phi = self.snconv1x1_phi(x)
+        phi = self.maxpool(phi)
+        phi = phi.view(-1, ch//8, h*w//4)
+        # Attn map
+        attn = torch.bmm(theta.permute(0, 2, 1), phi)
+        attn = self.softmax(attn)
+        # g path
+        g = self.snconv1x1_g(x)
+        g = self.maxpool(g)
+        g = g.view(-1, ch//2, h*w//4)
+        # Attn_g
+        attn_g = torch.bmm(g, attn.permute(0, 2, 1))
+        attn_g = attn_g.view(-1, ch//2, h, w)
+        attn_g = self.snconv1x1_attn(attn_g)
+        # Out
+        out = x + self.sigma*attn_g
+        return out
+
+
+class GenBlock(nn.Module):
+    def __init__(self, in_channels, out_channels):
+        super(GenBlock, self).__init__()
+        self.cond_bn1 = nn.BatchNorm2d(in_channels)
+        self.relu = nn.ReLU(inplace=True)
+        self.snconv2d1 = snconv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1)
+        self.cond_bn2 = nn.BatchNorm2d(out_channels)
+        self.snconv2d2 = snconv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1)
+        self.snconv2d0 = snconv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0)
+
+    def forward(self, x):
+        x0 = x
+
+        x = self.cond_bn1(x)
+        x = self.relu(x)
+        x = F.interpolate(x, scale_factor=2, mode='nearest')  # upsample
+        x = self.snconv2d1(x)
+        x = self.cond_bn2(x)
+        x = self.relu(x)
+        x = self.snconv2d2(x)
+
+        x0 = F.interpolate(x0, scale_factor=2, mode='nearest')  # upsample
+        x0 = self.snconv2d0(x0)
+
+        out = x + x0
+        return out
+
+
+class cGAN_concat_SAGAN_Generator(nn.Module):
+    """Generator."""
+
+    def __init__(self, z_dim, dim_c=1, g_conv_dim=64):
+        super(cGAN_concat_SAGAN_Generator, self).__init__()
+
+        self.z_dim = z_dim
+        self.dim_c = dim_c
+        self.g_conv_dim = g_conv_dim
+        self.snlinear0 = snlinear(in_features=z_dim+dim_c, out_features=g_conv_dim*16*4*4)
+        self.block1 = GenBlock(g_conv_dim*16, g_conv_dim*16)
+        self.block2 = GenBlock(g_conv_dim*16, g_conv_dim*8)
+        self.block3 = GenBlock(g_conv_dim*8, g_conv_dim*4)
+        self.block4 = GenBlock(g_conv_dim*4, g_conv_dim*2)
+        self.self_attn = Self_Attn(g_conv_dim*2)
+        self.block5 = GenBlock(g_conv_dim*2, g_conv_dim*2)
+        self.block6 = GenBlock(g_conv_dim*2, g_conv_dim)
+        self.bn = nn.BatchNorm2d(g_conv_dim, eps=1e-5, momentum=0.0001, affine=True)
+        self.relu = nn.ReLU()
+        self.snconv2d1 = snconv2d(in_channels=g_conv_dim, out_channels=3, kernel_size=3, stride=1, padding=1)
+        self.tanh = nn.Tanh()
+
+        # Weight init
+        self.apply(init_weights)
+
+    def forward(self, z, labels):
+        # n x z_dim
+        act0 = self.snlinear0(torch.cat((z, labels.view(-1,1)),dim=1))  # n x g_conv_dim*16*4*4
+        act0 = act0.view(-1, self.g_conv_dim*16, 4, 4)  # n x g_conv_dim*16 x 4 x 4
+        act1 = self.block1(act0)                        # n x g_conv_dim*16 x 8 x 8
+        act2 = self.block2(act1)                        # n x g_conv_dim*8 x 16 x 16
+        act3 = self.block3(act2)                        # n x g_conv_dim*4 x 32 x 32
+        act4 = self.block4(act3)                        # n x g_conv_dim*2 x 64 x 64
+        act4 = self.self_attn(act4)                     # n x g_conv_dim*2 x 64 x 64
+        act5 = self.block5(act4)                        # n x g_conv_dim*2 x 128 x 128
+        act6 = self.block6(act5)                        # n x g_conv_dim x 256 x 256
+        act6 = self.bn(act6)                            # n x g_conv_dim x 256 x 256
+        act6 = self.relu(act6)                          # n x g_conv_dim x 256 x 256
+        act7 = self.snconv2d1(act6)                     # n x 3 x 256 x 256
+        act7 = self.tanh(act7)                          # n x 3 x 256 x 256
+        return act7
+
+
+class DiscOptBlock(nn.Module):
+    def __init__(self, in_channels, out_channels):
+        super(DiscOptBlock, self).__init__()
+        self.snconv2d1 = snconv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1)
+        self.relu = nn.ReLU(inplace=True)
+        self.snconv2d2 = snconv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1)
+        self.downsample = nn.AvgPool2d(2)
+        self.snconv2d0 = snconv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0)
+
+    def forward(self, x):
+        x0 = x
+
+        x = self.snconv2d1(x)
+        x = self.relu(x)
+        x = self.snconv2d2(x)
+        x = self.downsample(x)
+
+        x0 = self.downsample(x0)
+        x0 = self.snconv2d0(x0)
+
+        out = x + x0
+        return out
+
+
+class DiscBlock(nn.Module):
+    def __init__(self, in_channels, out_channels):
+        super(DiscBlock, self).__init__()
+        self.relu = nn.ReLU(inplace=True)
+        self.snconv2d1 = snconv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1)
+        self.snconv2d2 = snconv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1)
+        self.downsample = nn.AvgPool2d(2)
+        self.ch_mismatch = False
+        if in_channels != out_channels:
+            self.ch_mismatch = True
+        self.snconv2d0 = snconv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0)
+
+    def forward(self, x, downsample=True):
+        x0 = x
+
+        x = self.relu(x)
+        x = self.snconv2d1(x)
+        x = self.relu(x)
+        x = self.snconv2d2(x)
+        if downsample:
+            x = self.downsample(x)
+
+        if downsample or self.ch_mismatch:
+            x0 = self.snconv2d0(x0)
+            if downsample:
+                x0 = self.downsample(x0)
+
+        out = x + x0
+        return out
+
+
+class cGAN_concat_SAGAN_Discriminator(nn.Module):
+    """Discriminator."""
+
+    def __init__(self, dim_c=1, d_conv_dim=64):
+        super(cGAN_concat_SAGAN_Discriminator, self).__init__()
+        self.d_conv_dim = d_conv_dim
+        self.dim_c = dim_c
+        self.opt_block1 = DiscOptBlock(3, d_conv_dim)
+        self.block1 = DiscBlock(d_conv_dim, d_conv_dim*2)
+        self.self_attn = Self_Attn(d_conv_dim*2)
+        self.block2 = DiscBlock(d_conv_dim*2, d_conv_dim*4)
+        self.block3 = DiscBlock(d_conv_dim*4, d_conv_dim*6)
+        self.block4 = DiscBlock(d_conv_dim*6, d_conv_dim*12)
+        self.block5 = DiscBlock(d_conv_dim*12, d_conv_dim*12)
+        self.block6 = DiscBlock(d_conv_dim*12, d_conv_dim*16)
+        self.relu = nn.ReLU()
+        self.snlinear1 = snlinear(in_features=d_conv_dim*16*4*4+dim_c, out_features=1)
+
+    def forward(self, x, labels):
+        labels = labels.view(-1,1)
+        # n x 3 x 256 x 256
+        h0 = self.opt_block1(x)                 # n x d_conv_dim x 128 x 128
+        h1 = self.block1(h0)                    # n x d_conv_dim*2 x 64 x 64
+        h1 = self.self_attn(h1)                 # n x d_conv_dim*2 x 64 x 64
+        h2 = self.block2(h1)                    # n x d_conv_dim*4 x 32 x 32
+        h3 = self.block3(h2)                    # n x d_conv_dim*6 x 16 x 16
+        h4 = self.block4(h3)                    # n x d_conv_dim*12 x 8 x 8
+        h5 = self.block5(h4)                    # n x d_conv_dim*12 x 4 x 4
+        h6 = self.block6(h5, downsample=False)  # n x d_conv_dim*16 x 4 x 4
+        out = self.relu(h6)                     # n x d_conv_dim*16 x 4 x 4
+        out = out.view(-1,self.d_conv_dim*16*4*4)
+        out = torch.cat((out, labels),dim=1)
+        out = self.snlinear1(out)
+        return out
+
+
+
+if __name__ == "__main__":
+
+    netG = cGAN_concat_SAGAN_Generator(z_dim=128).cuda()
+    netD = cGAN_concat_SAGAN_Discriminator().cuda()
+
+    n = 4
+    y = torch.randn(n, 1).cuda()
+    z = torch.randn(n, 128).cuda()
+    x = netG(z,y)
+    print(x.size())
+    o = netD(x,y)
+    print(o.size())
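Unlike the projection variant above, the concat variant conditions on a continuous label: the generator concatenates the label to z before the first linear layer, and the discriminator concatenates it to the flattened 4x4 features before the final spectral-normalized linear layer, so labels are float scalars rather than class indices. A CPU-only sketch (conv dims again illustrative, not package defaults):

    import torch

    # CPU-only smoke test of the concat-conditioned 256x256 networks above.
    netG = cGAN_concat_SAGAN_Generator(z_dim=128, dim_c=1, g_conv_dim=16)
    netD = cGAN_concat_SAGAN_Discriminator(dim_c=1, d_conv_dim=16)
    z = torch.randn(2, 128)
    y = torch.rand(2, 1)          # one continuous label per sample
    x = netG(z, y)                # -> torch.Size([2, 3, 256, 256])
    print(netD(x, y).size())      # -> torch.Size([2, 1])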
myosotis_researches/CcGAN/utils/make_h5.py CHANGED
@@ -5,10 +5,20 @@ from PIL import Image
 from print_hdf5_structure import print_hdf5_structure
 
 # Make all images to a HDF5 file
-def make_h5(image_dir: str, h5_path: str, image_names: list[str], image_labels, image_types):
+def make_h5(image_dir, h5_path, image_names = [], indx_train = None, indx_valid = None, image_labels = None, image_types = None):
 
     N = len(image_names)
 
+    # Process None defaults
+    if indx_train is None:
+        indx_train = np.array(range(1, N, 2), dtype=np.int32)
+    if indx_valid is None:
+        indx_valid = np.array(range(0, N, 2), dtype=np.int32)
+    if image_labels is None:
+        image_labels = np.zeros(N)
+    if image_types is None:
+        image_types = np.zeros(N)
+
     # Get image data
     image_datas = []
     for i in range(N):
@@ -19,17 +29,11 @@ def make_h5(image_dir: str, h5_path: str, image_names: list[str], image_labels,
         image_datas.append(rgb_array)
     image_datas = np.array(image_datas, dtype=np.uint8)
 
-    # Set train_idx = 1, 3, 5, ...
-    train_idx = np.array(range(1, N, 2), dtype=np.int32)
-
-    # Set val_idx = 0, 2, 4, ...
-    val_idx = np.array(range(0, N, 2), dtype=np.int32)
-
     # Create a new HDF5 file
     with h5py.File(h5_path, "w") as f:
         f.create_dataset("images", data=image_datas)
-        f.create_dataset("indx_train", data=train_idx)
-        f.create_dataset("indx_valid", data=val_idx)
+        f.create_dataset("indx_train", data=indx_train)
+        f.create_dataset("indx_valid", data=indx_valid)
         f.create_dataset("labels", data=image_labels)
         f.create_dataset("types", data=image_types)
 
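The new signature makes the split indices and per-image labels/types optional: when omitted, odd positions (1, 3, 5, ...) become indx_train, even positions (0, 2, 4, ...) become indx_valid, and labels and types default to zeros. A hedged usage sketch (the directory and file names here are hypothetical):

    import numpy as np

    # Hypothetical inputs; only the defaults visible in the diff are relied on.
    names = ["img_%03d.png" % i for i in range(10)]
    make_h5("images/", "dataset.h5", image_names=names,
            image_labels=np.linspace(0.0, 1.0, 10))
    # With indx_train/indx_valid omitted, images 1, 3, 5, 7, 9 form the train
    # split and 0, 2, 4, 6, 8 the validation split.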
myosotis_researches-0.0.14.dist-info/RECORD ADDED
@@ -0,0 +1,28 @@
+myosotis_researches/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+myosotis_researches/CcGAN/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+myosotis_researches/CcGAN/models_128/CcGAN_SAGAN.py,sha256=uYDngtHoB7frPg2Vs7YCFXeUh7Y7MjaAXbRWHXO_xvw,10629
+myosotis_researches/CcGAN/models_128/ResNet_class_eval.py,sha256=wa5CPkYzrS0X6kZ6pGHM-GxcGNkSpBdTTqgy5dKVKkU,5131
+myosotis_researches/CcGAN/models_128/ResNet_embed.py,sha256=HKSY-5WWa9jGniOgRoR1WOTfWhR1Dcj6cq2sgznZEbE,6344
+myosotis_researches/CcGAN/models_128/ResNet_regre_eval.py,sha256=VJYJiiwrjf9DvfZrlwOMJJAPu3PlwgFgIddDaRlGsac,6190
+myosotis_researches/CcGAN/models_128/__init__.py,sha256=X1eTESRCLfGO74qlLWV4hF3_wcmPhTpjhrVArYZ3_rU,451
+myosotis_researches/CcGAN/models_128/autoencoder.py,sha256=ugOwBNoSNP4-WiATVkhC4-igRjj6yEY91qU0egpX744,3827
+myosotis_researches/CcGAN/models_128/cGAN_SAGAN.py,sha256=JDr0Ss5osf9m-u34bVN_PvMsvMXkmi2jwPOAnls6EOA,11240
+myosotis_researches/CcGAN/models_128/cGAN_concat_SAGAN.py,sha256=GHAmrNjORXKu-8UqAdP-A5WG-_3BdQUmWsrWD1NX5-w,9634
+myosotis_researches/CcGAN/models_256/CcGAN_SAGAN.py,sha256=ju1dBYhqxl722_eeUGc2mKwf1AV_qsv1PlBL3tyOu48,10861
+myosotis_researches/CcGAN/models_256/ResNet_class_eval.py,sha256=tS5YxIpiFS9tDCNe2IDv1hTZNn40_JBD_nn97MfQJNI,5178
+myosotis_researches/CcGAN/models_256/ResNet_embed.py,sha256=9OcMQ-8nuWEbEbWc9tGaWQtfV1hdnkl0PrTphoGX77c,6295
+myosotis_researches/CcGAN/models_256/ResNet_regre_eval.py,sha256=tHAbRNM9XodyfPsu00ac5KMjcgRH8qdx8AtCN9QGXKc,6269
+myosotis_researches/CcGAN/models_256/__init__.py,sha256=X1eTESRCLfGO74qlLWV4hF3_wcmPhTpjhrVArYZ3_rU,451
+myosotis_researches/CcGAN/models_256/autoencoder.py,sha256=Nv3eSWJVrWaOufoVGe04sZ_KiXFLtu3Y0asZcAdyyj0,4382
+myosotis_researches/CcGAN/models_256/cGAN_SAGAN.py,sha256=wTHVkUcAp07n3lgweKFo6cqd91E_rEqgJrBDbBe6qrg,11510
+myosotis_researches/CcGAN/models_256/cGAN_concat_SAGAN.py,sha256=ZmGEpprDDlFR3dG32LT3NH5yiA1WR8Hg26rcbz42aCQ,9807
+myosotis_researches/CcGAN/utils/__init__.py,sha256=Pu9COV4zcXHGXuczhObersyeshVChmlEtwqp8VLUDxw,300
+myosotis_researches/CcGAN/utils/concat_image_horizontal.py,sha256=e6WsfO9IiSoP8zkZNz7IGimPUASr9VvyJUJdF-d40iw,954
+myosotis_researches/CcGAN/utils/concat_image_vertical.py,sha256=97-SuE8ZWpaeBm_ed6MAEaUOvtpzlYq_X3yWt4OEUTY,951
+myosotis_researches/CcGAN/utils/make_h5.py,sha256=bZGNx_SgWOaG8h3FpRk4WffC70hF6f9iIwj8z6dCKHI,1421
+myosotis_researches/CcGAN/utils/print_hdf5_structure.py,sha256=leaR8H3GhlX6EuIXDMh36xG2zBdV-XlJkaXBuoorl6I,320
+myosotis_researches-0.0.14.dist-info/licenses/LICENSE,sha256=OXLcl0T2SZ8Pmy2_dmlvKuetivmyPd5m1q-Gyd-zaYY,35149
+myosotis_researches-0.0.14.dist-info/METADATA,sha256=pyKRlBWw4SJu10CHAu-Gm9hnqU714aFE4DtcHMO32sE,765
+myosotis_researches-0.0.14.dist-info/WHEEL,sha256=pxyMxgL8-pra_rKaQ4drOZAegBVuX-G_4nRHjjgWbmo,91
+myosotis_researches-0.0.14.dist-info/top_level.txt,sha256=zxAiMn5eyZNJM28MewTAkgi_RZJMbfWbzVR-KF0LdZE,20
+myosotis_researches-0.0.14.dist-info/RECORD,,
myosotis_researches-0.0.12.dist-info/RECORD DELETED
@@ -1,12 +0,0 @@
-myosotis_researches/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-myosotis_researches/CcGAN/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-myosotis_researches/CcGAN/utils/__init__.py,sha256=Pu9COV4zcXHGXuczhObersyeshVChmlEtwqp8VLUDxw,300
-myosotis_researches/CcGAN/utils/concat_image_horizontal.py,sha256=e6WsfO9IiSoP8zkZNz7IGimPUASr9VvyJUJdF-d40iw,954
-myosotis_researches/CcGAN/utils/concat_image_vertical.py,sha256=97-SuE8ZWpaeBm_ed6MAEaUOvtpzlYq_X3yWt4OEUTY,951
-myosotis_researches/CcGAN/utils/make_h5.py,sha256=baY4lElNUzoCkEcoYrE1bulEYFHr1l6vKAof2WtbMQI,1239
-myosotis_researches/CcGAN/utils/print_hdf5_structure.py,sha256=leaR8H3GhlX6EuIXDMh36xG2zBdV-XlJkaXBuoorl6I,320
-myosotis_researches-0.0.12.dist-info/licenses/LICENSE,sha256=OXLcl0T2SZ8Pmy2_dmlvKuetivmyPd5m1q-Gyd-zaYY,35149
-myosotis_researches-0.0.12.dist-info/METADATA,sha256=_awkE9siKuYkbBqvK9sc9p69z7KN6NJh8Pxf3Qqp994,765
-myosotis_researches-0.0.12.dist-info/WHEEL,sha256=pxyMxgL8-pra_rKaQ4drOZAegBVuX-G_4nRHjjgWbmo,91
-myosotis_researches-0.0.12.dist-info/top_level.txt,sha256=zxAiMn5eyZNJM28MewTAkgi_RZJMbfWbzVR-KF0LdZE,20
-myosotis_researches-0.0.12.dist-info/RECORD,,
{myosotis_researches-0.0.12.dist-info → myosotis_researches-0.0.14.dist-info}/WHEEL RENAMED (file without changes)

{myosotis_researches-0.0.12.dist-info → myosotis_researches-0.0.14.dist-info}/licenses/LICENSE RENAMED (file without changes)

{myosotis_researches-0.0.12.dist-info → myosotis_researches-0.0.14.dist-info}/top_level.txt RENAMED (file without changes)