Myosotis-Researches 0.1.7__py3-none-any.whl → 0.1.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- myosotis_researches/CcGAN/train/__init__.py +4 -0
- myosotis_researches/CcGAN/{train_128_output_10 → train}/train_ccgan.py +4 -4
- myosotis_researches/CcGAN/{train_128 → train}/train_cgan.py +1 -3
- myosotis_researches/CcGAN/{train_128 → train}/train_cgan_concat.py +1 -3
- myosotis_researches/CcGAN/utils/__init__.py +2 -1
- myosotis_researches/CcGAN/utils/train.py +94 -3
- {myosotis_researches-0.1.7.dist-info → myosotis_researches-0.1.9.dist-info}/METADATA +1 -1
- myosotis_researches-0.1.9.dist-info/RECORD +24 -0
- myosotis_researches/CcGAN/models_128/CcGAN_SAGAN.py +0 -301
- myosotis_researches/CcGAN/models_128/ResNet_class_eval.py +0 -141
- myosotis_researches/CcGAN/models_128/ResNet_embed.py +0 -188
- myosotis_researches/CcGAN/models_128/ResNet_regre_eval.py +0 -175
- myosotis_researches/CcGAN/models_128/__init__.py +0 -7
- myosotis_researches/CcGAN/models_128/autoencoder.py +0 -119
- myosotis_researches/CcGAN/models_128/cGAN_SAGAN.py +0 -276
- myosotis_researches/CcGAN/models_128/cGAN_concat_SAGAN.py +0 -245
- myosotis_researches/CcGAN/models_256/CcGAN_SAGAN.py +0 -303
- myosotis_researches/CcGAN/models_256/ResNet_class_eval.py +0 -142
- myosotis_researches/CcGAN/models_256/ResNet_embed.py +0 -188
- myosotis_researches/CcGAN/models_256/ResNet_regre_eval.py +0 -178
- myosotis_researches/CcGAN/models_256/__init__.py +0 -7
- myosotis_researches/CcGAN/models_256/autoencoder.py +0 -133
- myosotis_researches/CcGAN/models_256/cGAN_SAGAN.py +0 -280
- myosotis_researches/CcGAN/models_256/cGAN_concat_SAGAN.py +0 -249
- myosotis_researches/CcGAN/train_128/DiffAugment_pytorch.py +0 -76
- myosotis_researches/CcGAN/train_128/__init__.py +0 -0
- myosotis_researches/CcGAN/train_128/eval_metrics.py +0 -205
- myosotis_researches/CcGAN/train_128/opts.py +0 -87
- myosotis_researches/CcGAN/train_128/pretrain_AE.py +0 -268
- myosotis_researches/CcGAN/train_128/pretrain_CNN_class.py +0 -251
- myosotis_researches/CcGAN/train_128/pretrain_CNN_regre.py +0 -255
- myosotis_researches/CcGAN/train_128/train_ccgan.py +0 -303
- myosotis_researches/CcGAN/train_128/utils.py +0 -120
- myosotis_researches/CcGAN/train_128_output_10/DiffAugment_pytorch.py +0 -76
- myosotis_researches/CcGAN/train_128_output_10/__init__.py +0 -0
- myosotis_researches/CcGAN/train_128_output_10/eval_metrics.py +0 -205
- myosotis_researches/CcGAN/train_128_output_10/opts.py +0 -87
- myosotis_researches/CcGAN/train_128_output_10/pretrain_AE.py +0 -268
- myosotis_researches/CcGAN/train_128_output_10/pretrain_CNN_class.py +0 -251
- myosotis_researches/CcGAN/train_128_output_10/pretrain_CNN_regre.py +0 -255
- myosotis_researches/CcGAN/train_128_output_10/train_cgan.py +0 -254
- myosotis_researches/CcGAN/train_128_output_10/train_cgan_concat.py +0 -242
- myosotis_researches/CcGAN/train_128_output_10/train_net_for_label_embed.py +0 -181
- myosotis_researches/CcGAN/train_128_output_10/utils.py +0 -120
- myosotis_researches-0.1.7.dist-info/RECORD +0 -59
- /myosotis_researches/CcGAN/{train_128 → train}/train_net_for_label_embed.py +0 -0
- {myosotis_researches-0.1.7.dist-info → myosotis_researches-0.1.9.dist-info}/WHEEL +0 -0
- {myosotis_researches-0.1.7.dist-info → myosotis_researches-0.1.9.dist-info}/licenses/LICENSE +0 -0
- {myosotis_researches-0.1.7.dist-info → myosotis_researches-0.1.9.dist-info}/top_level.txt +0 -0
@@ -1,249 +0,0 @@
|
|
1
|
-
import numpy as np
|
2
|
-
import torch
|
3
|
-
import torch.nn as nn
|
4
|
-
import torch.nn.functional as F
|
5
|
-
|
6
|
-
from torch.nn.utils import spectral_norm
|
7
|
-
from torch.nn.init import xavier_uniform_
|
8
|
-
|
9
|
-
|
10
|
-
def init_weights(m):
    """Xavier-initialise the weight and zero the bias of Linear/Conv2d layers.

    Written to be passed to ``nn.Module.apply``; modules of any other type
    are left untouched.

    Args:
        m: the module currently visited by ``apply``.
    """
    if isinstance(m, (nn.Linear, nn.Conv2d)):
        xavier_uniform_(m.weight)
        # Layers created with ``bias=False`` carry ``bias is None``; filling
        # unconditionally raised AttributeError on them.
        if m.bias is not None:
            m.bias.data.fill_(0.)
|
14
|
-
|
15
|
-
|
16
|
-
def snconv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
    """Create an ``nn.Conv2d`` and wrap it with spectral normalisation.

    Every argument is forwarded unchanged to ``nn.Conv2d``.

    Returns:
        The convolution module as returned by ``spectral_norm``.
    """
    conv = nn.Conv2d(in_channels, out_channels, kernel_size,
                     stride=stride, padding=padding, dilation=dilation,
                     groups=groups, bias=bias)
    return spectral_norm(conv)
|
19
|
-
|
20
|
-
|
21
|
-
def snlinear(in_features, out_features):
    """Create an ``nn.Linear`` layer and wrap it with spectral normalisation."""
    linear = nn.Linear(in_features, out_features)
    return spectral_norm(linear)
|
23
|
-
|
24
|
-
|
25
|
-
def sn_embedding(num_embeddings, embedding_dim):
    """Create an ``nn.Embedding`` table and wrap it with spectral normalisation."""
    table = nn.Embedding(num_embeddings, embedding_dim)
    return spectral_norm(table)
|
27
|
-
|
28
|
-
|
29
|
-
class Self_Attn(nn.Module):
    """Self-attention layer built from spectrally normalised 1x1 convolutions."""

    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        def conv1x1(c_in, c_out):
            # All four projections are 1x1 convolutions with stride 1 and no padding.
            return snconv2d(in_channels=c_in, out_channels=c_out, kernel_size=1, stride=1, padding=0)

        self.snconv1x1_theta = conv1x1(in_channels, in_channels // 8)
        self.snconv1x1_phi = conv1x1(in_channels, in_channels // 8)
        self.snconv1x1_g = conv1x1(in_channels, in_channels // 2)
        self.snconv1x1_attn = conv1x1(in_channels // 2, in_channels)
        self.maxpool = nn.MaxPool2d(2, stride=2, padding=0)
        self.softmax = nn.Softmax(dim=-1)
        # Learnable gate on the attention branch; starts at zero, so the
        # freshly constructed layer adds nothing to its input.
        self.sigma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        """
            inputs :
                x : 4-D input feature maps (batch, channels, then two spatial axes)
            returns :
                out : input feature + sigma * self attention value (same shape as x)

            NOTE(review): the pooled branches are reshaped to h*w//4 positions,
            which presumably requires both spatial sizes to be even - confirm.
        """
        _, channels, height, width = x.size()
        n_pos = height * width

        # Theta path: one column per spatial position.
        query = self.snconv1x1_theta(x).view(-1, channels // 8, n_pos)
        # Phi path: 2x2 max-pooled, so a quarter of the positions.
        key = self.maxpool(self.snconv1x1_phi(x)).view(-1, channels // 8, n_pos // 4)
        # Attention map, normalised over the pooled positions.
        weights = self.softmax(torch.bmm(query.permute(0, 2, 1), key))
        # g path: pooled like phi.
        value = self.maxpool(self.snconv1x1_g(x)).view(-1, channels // 2, n_pos // 4)
        # Weighted sum of g, folded back to the input's spatial layout.
        attended = torch.bmm(value, weights.permute(0, 2, 1)).view(-1, channels // 2, height, width)
        attended = self.snconv1x1_attn(attended)
        return x + self.sigma * attended
|
73
|
-
|
74
|
-
|
75
|
-
class GenBlock(nn.Module):
    """Residual generator block that doubles the spatial resolution.

    Main path: BN -> ReLU -> nearest x2 upsample -> 3x3 conv -> BN -> ReLU -> 3x3 conv.
    Shortcut:  nearest x2 upsample -> 1x1 conv.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.cond_bn1 = nn.BatchNorm2d(in_channels)
        self.relu = nn.ReLU(inplace=True)
        self.snconv2d1 = snconv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.cond_bn2 = nn.BatchNorm2d(out_channels)
        self.snconv2d2 = snconv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.snconv2d0 = snconv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        """Return main path + shortcut, each at twice the input resolution."""
        # The in-place ReLU acts on the BatchNorm output, not on ``x`` itself,
        # so ``x`` is still intact for the shortcut below.
        residual = self.relu(self.cond_bn1(x))
        residual = F.interpolate(residual, scale_factor=2, mode='nearest')  # upsample
        residual = self.snconv2d1(residual)
        residual = self.relu(self.cond_bn2(residual))
        residual = self.snconv2d2(residual)

        shortcut = F.interpolate(x, scale_factor=2, mode='nearest')  # upsample
        shortcut = self.snconv2d0(shortcut)

        return residual + shortcut
|
101
|
-
|
102
|
-
|
103
|
-
class cGAN_concat_SAGAN_Generator(nn.Module):
    """Generator conditioned by concatenating the label to the noise vector.

    Args:
        z_dim: length of the latent noise vector.
        dim_c: number of label values per sample that are concatenated to ``z``.
        g_conv_dim: base channel width; the blocks use multiples of it.
    """

    def __init__(self, z_dim, dim_c=1, g_conv_dim=64):
        super(cGAN_concat_SAGAN_Generator, self).__init__()

        self.z_dim = z_dim
        self.dim_c = dim_c
        self.g_conv_dim = g_conv_dim
        self.snlinear0 = snlinear(in_features=z_dim+dim_c, out_features=g_conv_dim*16*4*4)
        self.block1 = GenBlock(g_conv_dim*16, g_conv_dim*16)
        self.block2 = GenBlock(g_conv_dim*16, g_conv_dim*8)
        self.block3 = GenBlock(g_conv_dim*8, g_conv_dim*4)
        self.block4 = GenBlock(g_conv_dim*4, g_conv_dim*2)
        self.self_attn = Self_Attn(g_conv_dim*2)
        self.block5 = GenBlock(g_conv_dim*2, g_conv_dim*2)
        self.block6 = GenBlock(g_conv_dim*2, g_conv_dim)
        self.bn = nn.BatchNorm2d(g_conv_dim, eps=1e-5, momentum=0.0001, affine=True)
        self.relu = nn.ReLU()
        self.snconv2d1 = snconv2d(in_channels=g_conv_dim, out_channels=3, kernel_size=3, stride=1, padding=1)
        self.tanh = nn.Tanh()

        # Weight init
        self.apply(init_weights)

    def forward(self, z, labels):
        """Generate images from noise ``z`` (n x z_dim) and ``labels`` (n*dim_c values).

        Returns:
            Tensor of shape n x 3 x 256 x 256 with values in (-1, 1) (tanh output).
        """
        # Reshape the labels to n x dim_c. The previous ``view(-1, 1)``
        # silently assumed dim_c == 1 and made every other dim_c unusable.
        labels = labels.view(-1, self.dim_c)
        act0 = self.snlinear0(torch.cat((z, labels), dim=1))  # n x g_conv_dim*16*4*4
        act0 = act0.view(-1, self.g_conv_dim*16, 4, 4)  # n x g_conv_dim*16 x 4 x 4
        act1 = self.block1(act0)    # n x g_conv_dim*16 x 8 x 8
        act2 = self.block2(act1)    # n x g_conv_dim*8 x 16 x 16
        act3 = self.block3(act2)    # n x g_conv_dim*4 x 32 x 32
        act4 = self.block4(act3)    # n x g_conv_dim*2 x 64 x 64
        act4 = self.self_attn(act4) # n x g_conv_dim*2 x 64 x 64
        act5 = self.block5(act4)    # n x g_conv_dim*2 x 128 x 128
        act6 = self.block6(act5)    # n x g_conv_dim x 256 x 256
        act6 = self.bn(act6)        # n x g_conv_dim x 256 x 256
        act6 = self.relu(act6)      # n x g_conv_dim x 256 x 256
        act7 = self.snconv2d1(act6) # n x 3 x 256 x 256
        act7 = self.tanh(act7)      # n x 3 x 256 x 256
        return act7
|
144
|
-
|
145
|
-
|
146
|
-
class DiscOptBlock(nn.Module):
    """First discriminator block: residual, halves the spatial resolution.

    Main path: 3x3 conv -> ReLU -> 3x3 conv -> 2x2 average pool.
    Shortcut:  2x2 average pool -> 1x1 conv.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.snconv2d1 = snconv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.relu = nn.ReLU(inplace=True)
        self.snconv2d2 = snconv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.downsample = nn.AvgPool2d(2)
        self.snconv2d0 = snconv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        """Return the downsampled main path plus the downsampled shortcut."""
        # The in-place ReLU acts on the first conv's output, so ``x`` is left
        # untouched for the shortcut.
        residual = self.relu(self.snconv2d1(x))
        residual = self.snconv2d2(residual)
        residual = self.downsample(residual)

        shortcut = self.downsample(x)
        shortcut = self.snconv2d0(shortcut)

        return residual + shortcut
|
168
|
-
|
169
|
-
|
170
|
-
class DiscBlock(nn.Module):
    """Residual discriminator block with optional 2x downsampling.

    Main path: ReLU -> 3x3 conv -> ReLU -> 3x3 conv [-> 2x2 average pool].
    Shortcut:  1x1 conv [-> 2x2 average pool], used when downsampling or when
    the channel counts differ; otherwise the identity.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.relu = nn.ReLU(inplace=True)
        self.snconv2d1 = snconv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.snconv2d2 = snconv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.downsample = nn.AvgPool2d(2)
        self.ch_mismatch = in_channels != out_channels
        # Created unconditionally: forward() also needs it when downsampling
        # a block whose input and output widths are equal.
        self.snconv2d0 = snconv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x, downsample=True):
        """Return main path + shortcut; ``downsample`` toggles the average pooling."""
        # NOTE: the first ReLU is in-place, so it rectifies ``x`` itself.
        # The shortcut below therefore also sees the rectified input.
        residual = self.relu(x)
        residual = self.snconv2d1(residual)
        residual = self.relu(residual)
        residual = self.snconv2d2(residual)
        if downsample:
            residual = self.downsample(residual)

        shortcut = x
        if downsample or self.ch_mismatch:
            shortcut = self.snconv2d0(shortcut)
            if downsample:
                shortcut = self.downsample(shortcut)

        return residual + shortcut
|
199
|
-
|
200
|
-
|
201
|
-
class cGAN_concat_SAGAN_Discriminator(nn.Module):
    """Discriminator conditioned by concatenating the label to the final features.

    Args:
        dim_c: number of label values per sample appended before the last linear layer.
        d_conv_dim: base channel width; the blocks use multiples of it.
    """

    def __init__(self, dim_c=1, d_conv_dim=64):
        super(cGAN_concat_SAGAN_Discriminator, self).__init__()
        self.d_conv_dim = d_conv_dim
        self.dim_c = dim_c
        self.opt_block1 = DiscOptBlock(3, d_conv_dim)
        self.block1 = DiscBlock(d_conv_dim, d_conv_dim*2)
        self.self_attn = Self_Attn(d_conv_dim*2)
        self.block2 = DiscBlock(d_conv_dim*2, d_conv_dim*4)
        self.block3 = DiscBlock(d_conv_dim*4, d_conv_dim*6)
        self.block4 = DiscBlock(d_conv_dim*6, d_conv_dim*12)
        self.block5 = DiscBlock(d_conv_dim*12, d_conv_dim*12)
        self.block6 = DiscBlock(d_conv_dim*12, d_conv_dim*16)
        self.relu = nn.ReLU()
        self.snlinear1 = snlinear(in_features=d_conv_dim*16*4*4+dim_c, out_features=1)

    def forward(self, x, labels):
        """Score images ``x`` (n x 3 x 256 x 256) given ``labels`` (n*dim_c values).

        Returns:
            Tensor of shape n x 1 with the raw (unbounded) discriminator output.
        """
        # Reshape the labels to n x dim_c. The previous ``view(-1, 1)``
        # silently assumed dim_c == 1 and made every other dim_c unusable.
        labels = labels.view(-1, self.dim_c)
        # n x 3 x 256 x 256
        h0 = self.opt_block1(x)     # n x d_conv_dim x 128 x 128
        h1 = self.block1(h0)        # n x d_conv_dim*2 x 64 x 64
        h1 = self.self_attn(h1)     # n x d_conv_dim*2 x 64 x 64
        h2 = self.block2(h1)        # n x d_conv_dim*4 x 32 x 32
        h3 = self.block3(h2)        # n x d_conv_dim*6 x 16 x 16
        h4 = self.block4(h3)        # n x d_conv_dim*12 x 8 x 8
        h5 = self.block5(h4)        # n x d_conv_dim*12 x 4 x 4
        h6 = self.block6(h5, downsample=False)  # n x d_conv_dim*16 x 4 x 4
        out = self.relu(h6)         # n x d_conv_dim*16 x 4 x 4
        out = out.view(-1, self.d_conv_dim*16*4*4)
        out = torch.cat((out, labels), dim=1)
        out = self.snlinear1(out)
        return out
|
235
|
-
|
236
|
-
|
237
|
-
|
238
|
-
if __name__ == "__main__":
    # Smoke test: build both networks, run one forward pass through each and
    # print the output shapes. Uses the GPU when one is available and falls
    # back to the CPU otherwise, instead of failing on the hard-coded .cuda().
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    netG = cGAN_concat_SAGAN_Generator(z_dim=128).to(device)
    netD = cGAN_concat_SAGAN_Discriminator().to(device)

    n = 4
    y = torch.randn(n, 1).to(device)
    z = torch.randn(n, 128).to(device)
    x = netG(z, y)
    print(x.size())
    o = netD(x, y)
    print(o.size())
|
@@ -1,76 +0,0 @@
|
|
1
|
-
# Differentiable Augmentation for Data-Efficient GAN Training
|
2
|
-
# Shengyu Zhao, Zhijian Liu, Ji Lin, Jun-Yan Zhu, and Song Han
|
3
|
-
# https://arxiv.org/pdf/2006.10738
|
4
|
-
|
5
|
-
import torch
|
6
|
-
import torch.nn.functional as F
|
7
|
-
|
8
|
-
|
9
|
-
def DiffAugment(x, policy='', channels_first=True):
    """Apply the augmentations named in ``policy`` to the image batch ``x``.

    Args:
        x: 4-D image batch.
        policy: comma-separated keys of ``AUGMENT_FNS``; an empty string
            returns ``x`` untouched.
        channels_first: when False, ``x`` is permuted to channels-first for
            the augmentations and permuted back afterwards.

    Returns:
        The augmented batch (made contiguous), or ``x`` itself when no policy is given.
    """
    if not policy:
        return x

    if not channels_first:
        x = x.permute(0, 3, 1, 2)
    for policy_name in policy.split(','):
        for augment in AUGMENT_FNS[policy_name]:
            x = augment(x)
    if not channels_first:
        x = x.permute(0, 2, 3, 1)
    return x.contiguous()
|
20
|
-
|
21
|
-
|
22
|
-
def rand_brightness(x):
    """Add one random offset in [-0.5, 0.5) to every value of each sample."""
    offset = torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) - 0.5
    return x + offset
|
25
|
-
|
26
|
-
|
27
|
-
def rand_saturation(x):
    """Scale each sample's deviation from its per-pixel channel mean by a random factor in [0, 2)."""
    channel_mean = x.mean(dim=1, keepdim=True)
    factor = torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) * 2
    return (x - channel_mean) * factor + channel_mean
|
31
|
-
|
32
|
-
|
33
|
-
def rand_contrast(x):
    """Scale each sample's deviation from its overall mean by a random factor in [0.5, 1.5)."""
    sample_mean = x.mean(dim=[1, 2, 3], keepdim=True)
    factor = torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) + 0.5
    return (x - sample_mean) * factor + sample_mean
|
37
|
-
|
38
|
-
|
39
|
-
def rand_translation(x, ratio=0.125):
    """Shift each image by a random integer offset, filling exposed pixels with zeros.

    Args:
        x: channels-first 4-D image batch.
        ratio: maximum shift along each spatial axis as a fraction of that
            axis' size (rounded to the nearest integer).

    Returns:
        A tensor with the same shape as ``x``; every sample gets its own shift.
    """
    batch, height, width = x.size(0), x.size(2), x.size(3)
    shift_x, shift_y = int(height * ratio + 0.5), int(width * ratio + 0.5)
    translation_x = torch.randint(-shift_x, shift_x + 1, size=[batch, 1, 1], device=x.device)
    translation_y = torch.randint(-shift_y, shift_y + 1, size=[batch, 1, 1], device=x.device)

    # Broadcastable index vectors instead of torch.meshgrid: same gather,
    # without depending on meshgrid's deprecated implicit indexing mode.
    batch_idx = torch.arange(batch, dtype=torch.long, device=x.device)[:, None, None]
    rows = torch.arange(height, dtype=torch.long, device=x.device)[None, :, None]
    cols = torch.arange(width, dtype=torch.long, device=x.device)[None, None, :]

    # The +1 compensates for the one-pixel zero border added below; clamping
    # sends every out-of-range read to that border.
    grid_x = torch.clamp(rows + translation_x + 1, 0, height + 1)
    grid_y = torch.clamp(cols + translation_y + 1, 0, width + 1)
    x_pad = F.pad(x, [1, 1, 1, 1, 0, 0, 0, 0])
    x = x_pad.permute(0, 2, 3, 1).contiguous()[batch_idx, grid_x, grid_y].permute(0, 3, 1, 2)
    return x
|
53
|
-
|
54
|
-
|
55
|
-
def rand_cutout(x, ratio=0.5):
    """Zero out one randomly placed rectangle in each image.

    Args:
        x: channels-first 4-D image batch.
        ratio: rectangle size along each spatial axis as a fraction of that
            axis' size (rounded to the nearest integer).

    Returns:
        A tensor with the same shape as ``x``; the same rectangle is zeroed in
        every channel of a sample. Rectangles that reach past the image edge
        are clamped to it, so the visible hole can be smaller.
    """
    batch, height, width = x.size(0), x.size(2), x.size(3)
    cutout_size = int(height * ratio + 0.5), int(width * ratio + 0.5)
    offset_x = torch.randint(0, height + (1 - cutout_size[0] % 2), size=[batch, 1, 1], device=x.device)
    offset_y = torch.randint(0, width + (1 - cutout_size[1] % 2), size=[batch, 1, 1], device=x.device)

    # Broadcastable index vectors instead of torch.meshgrid: same scatter,
    # without depending on meshgrid's deprecated implicit indexing mode.
    batch_idx = torch.arange(batch, dtype=torch.long, device=x.device)[:, None, None]
    rows = torch.arange(cutout_size[0], dtype=torch.long, device=x.device)[None, :, None]
    cols = torch.arange(cutout_size[1], dtype=torch.long, device=x.device)[None, None, :]

    # Centre the rectangle on the random offset, then clamp to valid pixels.
    grid_x = torch.clamp(rows + offset_x - cutout_size[0] // 2, min=0, max=height - 1)
    grid_y = torch.clamp(cols + offset_y - cutout_size[1] // 2, min=0, max=width - 1)
    mask = torch.ones(batch, height, width, dtype=x.dtype, device=x.device)
    mask[batch_idx, grid_x, grid_y] = 0
    x = x * mask.unsqueeze(1)
    return x
|
70
|
-
|
71
|
-
|
72
|
-
# Maps each policy name accepted by DiffAugment to the augmentation functions
# that are applied, in list order, when that policy is requested.
AUGMENT_FNS = {
    'color': [rand_brightness, rand_saturation, rand_contrast],
    'translation': [rand_translation],
    'cutout': [rand_cutout],
}
|
File without changes
|
@@ -1,205 +0,0 @@
|
|
1
|
-
"""
|
2
|
-
Compute
|
3
|
-
Inception Score (IS),
|
4
|
-
Frechet Inception Distance (FID), ref "https://github.com/mseitzer/pytorch-fid/blob/master/fid_score.py"
|
5
|
-
Maximum Mean Discrepancy (MMD)
|
6
|
-
for a set of fake images
|
7
|
-
|
8
|
-
use numpy array
|
9
|
-
Xr: high-level features for real images; nr by d array
|
10
|
-
Yr: labels for real images
|
11
|
-
Xg: high-level features for fake images; ng by d array
|
12
|
-
Yg: labels for fake images
|
13
|
-
IMGSr: real images
|
14
|
-
IMGSg: fake images
|
15
|
-
|
16
|
-
"""
|
17
|
-
|
18
|
-
import os
|
19
|
-
import gc
|
20
|
-
import numpy as np
|
21
|
-
# from numpy import linalg as LA
|
22
|
-
from scipy import linalg
|
23
|
-
import torch
|
24
|
-
import torch.nn as nn
|
25
|
-
from scipy.stats import entropy
|
26
|
-
from torch.nn import functional as F
|
27
|
-
from torchvision.utils import save_image
|
28
|
-
|
29
|
-
from .utils import SimpleProgressBar, IMGs_dataset
|
30
|
-
|
31
|
-
|
32
|
-
def normalize_images(batch_images):
    """Rescale pixel values linearly so that 0 maps to -1 and 255 maps to 1."""
    unit_range = batch_images / 255.0
    return (unit_range - 0.5) / 0.5
|
36
|
-
|
37
|
-
##############################################################################
|
38
|
-
# FID scores
|
39
|
-
##############################################################################
|
40
|
-
# compute FID based on extracted features
|
41
|
-
def FID(Xr, Xg, eps=1e-10):
    '''
    The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
    and X_2 ~ N(mu_2, C_2) is
        d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).

    Xr:  features of the real images; nr by d array
    Xg:  features of the fake images; ng by d array
    eps: added to the diagonal of both covariance estimates when the matrix
         square root of their product is not finite
    Returns the (real-valued) FID score.
    '''
    # sample mean
    MUr = np.mean(Xr, axis=0)
    MUg = np.mean(Xg, axis=0)
    mean_diff = MUr - MUg
    # sample covariance; atleast_2d keeps a single-feature input (d == 1),
    # for which np.cov returns a 0-d array, usable as a 1x1 matrix
    SIGMAr = np.atleast_2d(np.cov(Xr.transpose()))
    SIGMAg = np.atleast_2d(np.cov(Xg.transpose()))

    # Product might be almost singular. sqrtm is called without the `disp`
    # keyword, which newer SciPy releases deprecate; it then returns only the
    # matrix. Numerical error can leave a small imaginary part, so only the
    # real part is kept.
    covmean = np.real(linalg.sqrtm(SIGMAr.dot(SIGMAg)))
    if not np.isfinite(covmean).all():
        msg = ('fid calculation produces singular product; '
               'adding %s to diagonal of cov estimates') % eps
        print(msg)
        offset = np.eye(SIGMAr.shape[0]) * eps
        # The fallback result can be complex as well; previously it was used
        # as-is, which could turn the returned score into a complex number.
        covmean = np.real(linalg.sqrtm((SIGMAr + offset).dot(SIGMAg + offset)))

    # fid score
    fid_score = mean_diff.dot(mean_diff) + np.trace(SIGMAr + SIGMAg - 2*covmean)

    return fid_score
|
69
|
-
|
70
|
-
##test
|
71
|
-
#Xr = np.random.rand(10000,1000)
|
72
|
-
#Xg = np.random.rand(10000,1000)
|
73
|
-
#print(FID(Xr, Xg))
|
74
|
-
|
75
|
-
# compute FID from raw images
|
76
|
-
def _fid_extract_features(net, images, batch_size, resize, norm_img, device):
    """Run ``net`` over ``images`` in batches and return the outputs as one (n, d) array."""
    n = images.shape[0]
    features = None
    pb = SimpleProgressBar()
    with torch.no_grad():
        # Stepping by batch_size up to n also covers the final, possibly
        # smaller, batch.
        for start in range(0, n, batch_size):
            batch = torch.from_numpy(images[start:(start+batch_size)]).type(torch.float).to(device)
            if resize is not None:
                batch = nn.functional.interpolate(batch, size = resize, scale_factor=None, mode='bilinear', align_corners=False)
            if norm_img:
                batch = normalize_images(batch)
            batch_features = net(batch).detach().cpu().numpy()
            if features is None:
                # The feature length is only known after the first forward pass.
                features = np.zeros((n, batch_features.shape[1]))
            stop = start + batch_features.shape[0]
            features[start:stop] = batch_features
            pb.update(min(stop/n*100, 100))
            del batch, batch_features
    gc.collect()
    torch.cuda.empty_cache()
    return features


def cal_FID(PreNetFID, IMGSr, IMGSg, batch_size = 500, resize = None, norm_img = False):
    '''
    Compute FID from raw images.

    PreNetFID: feature extractor; called on a float image batch and expected
               to return an n by d feature tensor
    IMGSr:     real images; numpy array, nr x NC x IMG_SIZE x IMG_SIZE
    IMGSg:     fake images; numpy array, ng x NC x IMG_SIZE x IMG_SIZE
    batch_size: images per forward pass; reduced to min(nr, ng) if larger
    resize:    if None, do not resize; if resize = (H,W), resize images to 3 x H x W
    norm_img:  if True, apply normalize_images to every batch
    Returns the FID score between the two sets of extracted features.

    Every image is used. Earlier, only nr//batch_size (ng//batch_size) full
    batches were processed, so the feature rows of the remaining images stayed
    all-zero and biased the mean/covariance estimates.
    '''
    PreNetFID.eval()

    nr = IMGSr.shape[0]
    ng = IMGSg.shape[0]
    if nr == 0 or ng == 0:
        raise ValueError("cal_FID needs at least one real and one fake image")

    if batch_size > min(nr, ng):
        batch_size = min(nr, ng)

    # Feed the batches to whichever device the network lives on; a network
    # without parameters keeps the previous behaviour (CUDA).
    first_param = next(PreNetFID.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cuda")

    Xr = _fid_extract_features(PreNetFID, IMGSr, batch_size, resize, norm_img, device)
    Xg = _fid_extract_features(PreNetFID, IMGSg, batch_size, resize, norm_img, device)

    fid_score = FID(Xr, Xg, eps=1e-6)

    return fid_score
|
145
|
-
|
146
|
-
|
147
|
-
|
148
|
-
|
149
|
-
|
150
|
-
|
151
|
-
##############################################################################
|
152
|
-
# label_score
|
153
|
-
# difference between assigned label and predicted label
|
154
|
-
##############################################################################
|
155
|
-
def cal_labelscore(PreNet, images, labels_assi, min_label_before_shift, max_label_after_shift, batch_size = 500, resize = None, norm_img = False, num_workers=0):
    '''
    Label score: difference between assigned labels and predicted labels.

    PreNet: pre-trained CNN; its forward pass returns a pair whose first
            element holds the predicted labels
    images: fake images; n x nc x img_size x img_size
    labels_assi: assigned labels (flattened internally)
    min_label_before_shift, max_label_after_shift: used to map the labels back
            to their original scale: label*max_label_after_shift - |min_label_before_shift|
    batch_size: images per forward pass
    resize: accepted for interface compatibility; NOTE(review): this function
            never applies it - confirm whether resizing is expected here
    norm_img: if True, apply normalize_images to every batch
    num_workers: forwarded to the DataLoader
    Returns (mean, std) of the absolute difference |predicted - assigned|.
    '''

    PreNet.eval()

    # assume images are nxncximg_sizeximg_size
    n = images.shape[0]
    labels_assi = labels_assi.reshape(-1)

    eval_trainset = IMGs_dataset(images, labels_assi, normalize=False)
    eval_dataloader = torch.utils.data.DataLoader(eval_trainset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    # Feed the batches to whichever device the network lives on; a network
    # without parameters keeps the previous behaviour (CUDA).
    first_param = next(PreNet.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cuda")

    labels_pred = np.zeros(n)

    nimgs_got = 0
    pb = SimpleProgressBar()
    # Pure inference: no_grad avoids building autograd graphs for every batch.
    with torch.no_grad():
        for batch_images, _ in eval_dataloader:
            batch_images = batch_images.type(torch.float).to(device)
            batch_size_curr = batch_images.shape[0]

            if norm_img:
                batch_images = normalize_images(batch_images)

            batch_labels_pred, _ = PreNet(batch_images)
            labels_pred[nimgs_got:(nimgs_got+batch_size_curr)] = batch_labels_pred.detach().cpu().numpy().reshape(-1)

            nimgs_got += batch_size_curr
            pb.update((float(nimgs_got)/n)*100)

            del batch_images, batch_labels_pred
    #end for batch
    gc.collect()
    torch.cuda.empty_cache()

    # back to the original label scale
    labels_pred = (labels_pred*max_label_after_shift)-np.abs(min_label_before_shift)
    labels_assi = (labels_assi*max_label_after_shift)-np.abs(min_label_before_shift)

    abs_err = np.abs(labels_pred-labels_assi)
    ls_mean = np.mean(abs_err)
    ls_std = np.std(abs_err)

    return ls_mean, ls_std
|
@@ -1,87 +0,0 @@
|
|
1
|
-
import argparse
|
2
|
-
|
3
|
-
def parse_opts():
    """Build the command-line parser and parse ``sys.argv``.

    Returns:
        argparse.Namespace holding every option defined below.
    """
    parser = argparse.ArgumentParser()

    # ---- Overall Settings ----
    parser.add_argument('--root_path', type=str, default='')
    parser.add_argument('--data_path', type=str, default='')
    parser.add_argument('--eval_ckpt_path', type=str, default='')
    # The help text used to claim "default: 2020" while the default is 2021;
    # %(default)s lets argparse print the real value.
    parser.add_argument('--seed', type=int, default=2021, metavar='S', help='random seed (default: %(default)s)')
    parser.add_argument('--num_workers', type=int, default=0)

    # ---- Dataset ----
    ## Data split: RC-49 is split into a train set (the last decimal of the degree is odd) and a test set (the last decimal of the degree is even); the unique labels in two sets do not overlap.
    parser.add_argument('--data_split', type=str, default='train',
                        choices=['all', 'train'])
    parser.add_argument('--min_label', type=float, default=0.0)
    parser.add_argument('--max_label', type=float, default=90.0)
    parser.add_argument('--num_channels', type=int, default=3, metavar='N')
    parser.add_argument('--img_size', type=int, default=128, metavar='N')
    parser.add_argument('--max_num_img_per_label', type=int, default=50, metavar='N')
    parser.add_argument('--max_num_img_per_label_after_replica', type=int, default=0, metavar='N')
    parser.add_argument('--show_real_imgs', action='store_true', default=False)
    parser.add_argument('--visualize_fake_images', action='store_true', default=False)

    # ---- GAN settings ----
    parser.add_argument('--GAN', type=str, default='CcGAN', choices=['cGAN', 'cGAN-concat', 'CcGAN'])
    parser.add_argument('--GAN_arch', type=str, default='SAGAN', choices=['SAGAN'])

    # label embedding setting
    parser.add_argument('--net_embed', type=str, default='ResNet34_embed')  # ResNetXX_embed
    parser.add_argument('--epoch_cnn_embed', type=int, default=200)  # epoch of cnn training for label embedding
    parser.add_argument('--resumeepoch_cnn_embed', type=int, default=0)  # presumably the epoch to resume that cnn training from - confirm
    parser.add_argument('--epoch_net_y2h', type=int, default=500)
    parser.add_argument('--dim_embed', type=int, default=128)  # dimension of the embedding space
    parser.add_argument('--batch_size_embed', type=int, default=256, metavar='N')

    parser.add_argument('--loss_type_gan', type=str, default='hinge')
    parser.add_argument('--niters_gan', type=int, default=10000, help='number of iterations')
    parser.add_argument('--resume_niters_gan', type=int, default=0)
    parser.add_argument('--save_niters_freq', type=int, default=2000, help='frequency of saving checkpoints')
    parser.add_argument('--lr_g_gan', type=float, default=1e-4, help='learning rate for generator')
    parser.add_argument('--lr_d_gan', type=float, default=1e-4, help='learning rate for discriminator')
    parser.add_argument('--dim_gan', type=int, default=128, help='Latent dimension of GAN')
    parser.add_argument('--batch_size_disc', type=int, default=64)
    parser.add_argument('--batch_size_gene', type=int, default=64)
    parser.add_argument('--num_D_steps', type=int, default=4, help='number of Ds updates in one iteration')
    parser.add_argument('--cGAN_num_classes', type=int, default=20, metavar='N')  # bin label into cGAN_num_classes
    parser.add_argument('--visualize_freq', type=int, default=2000, help='frequency of visualization')

    parser.add_argument('--kernel_sigma', type=float, default=-1.0,
                        help='If kernel_sigma<0, then use rule-of-thumb formula to compute the sigma.')
    parser.add_argument('--threshold_type', type=str, default='hard', choices=['soft', 'hard'])
    parser.add_argument('--kappa', type=float, default=-1)
    parser.add_argument('--nonzero_soft_weight_threshold', type=float, default=1e-3,
                        help='threshold for determining nonzero weights for SVDL; we neglect images with too small weights')

    # DiffAugment setting
    parser.add_argument('--gan_DiffAugment', action='store_true', default=False)
    parser.add_argument('--gan_DiffAugment_policy', type=str, default='color,translation,cutout')

    # evaluation setting
    # Four evaluation modes:
    #   Mode 1: eval on unique labels used for GAN training;
    #   Mode 2. eval on all unique labels in the dataset and when computing FID use all real images in the dataset;
    #   Mode 3. eval on all unique labels in the dataset and when computing FID only use real images for GAN training in the dataset (to test SFID's effectiveness on unseen labels);
    #   Mode 4. eval on a interval [min_label, max_label] with num_eval_labels labels.
    parser.add_argument('--eval_mode', type=int, default=2)
    parser.add_argument('--num_eval_labels', type=int, default=-1)
    parser.add_argument('--samp_batch_size', type=int, default=200)
    parser.add_argument('--nfake_per_label', type=int, default=200)
    parser.add_argument('--nreal_per_label', type=int, default=-1)
    parser.add_argument('--comp_FID', action='store_true', default=False)
    parser.add_argument('--epoch_FID_CNN', type=int, default=200)
    parser.add_argument('--FID_radius', type=float, default=0)
    parser.add_argument('--FID_num_centers', type=int, default=-1)
    parser.add_argument('--dump_fake_for_NIQE', action='store_true', default=False,
                        help='Dump fake images for computing NIQE')

    args = parser.parse_args()

    return args
|