Myosotis-Researches 0.1.7__py3-none-any.whl → 0.1.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49)
  1. myosotis_researches/CcGAN/train/__init__.py +4 -0
  2. myosotis_researches/CcGAN/{train_128_output_10 → train}/train_ccgan.py +4 -4
  3. myosotis_researches/CcGAN/{train_128 → train}/train_cgan.py +1 -3
  4. myosotis_researches/CcGAN/{train_128 → train}/train_cgan_concat.py +1 -3
  5. myosotis_researches/CcGAN/utils/__init__.py +2 -1
  6. myosotis_researches/CcGAN/utils/train.py +94 -3
  7. {myosotis_researches-0.1.7.dist-info → myosotis_researches-0.1.9.dist-info}/METADATA +1 -1
  8. myosotis_researches-0.1.9.dist-info/RECORD +24 -0
  9. myosotis_researches/CcGAN/models_128/CcGAN_SAGAN.py +0 -301
  10. myosotis_researches/CcGAN/models_128/ResNet_class_eval.py +0 -141
  11. myosotis_researches/CcGAN/models_128/ResNet_embed.py +0 -188
  12. myosotis_researches/CcGAN/models_128/ResNet_regre_eval.py +0 -175
  13. myosotis_researches/CcGAN/models_128/__init__.py +0 -7
  14. myosotis_researches/CcGAN/models_128/autoencoder.py +0 -119
  15. myosotis_researches/CcGAN/models_128/cGAN_SAGAN.py +0 -276
  16. myosotis_researches/CcGAN/models_128/cGAN_concat_SAGAN.py +0 -245
  17. myosotis_researches/CcGAN/models_256/CcGAN_SAGAN.py +0 -303
  18. myosotis_researches/CcGAN/models_256/ResNet_class_eval.py +0 -142
  19. myosotis_researches/CcGAN/models_256/ResNet_embed.py +0 -188
  20. myosotis_researches/CcGAN/models_256/ResNet_regre_eval.py +0 -178
  21. myosotis_researches/CcGAN/models_256/__init__.py +0 -7
  22. myosotis_researches/CcGAN/models_256/autoencoder.py +0 -133
  23. myosotis_researches/CcGAN/models_256/cGAN_SAGAN.py +0 -280
  24. myosotis_researches/CcGAN/models_256/cGAN_concat_SAGAN.py +0 -249
  25. myosotis_researches/CcGAN/train_128/DiffAugment_pytorch.py +0 -76
  26. myosotis_researches/CcGAN/train_128/__init__.py +0 -0
  27. myosotis_researches/CcGAN/train_128/eval_metrics.py +0 -205
  28. myosotis_researches/CcGAN/train_128/opts.py +0 -87
  29. myosotis_researches/CcGAN/train_128/pretrain_AE.py +0 -268
  30. myosotis_researches/CcGAN/train_128/pretrain_CNN_class.py +0 -251
  31. myosotis_researches/CcGAN/train_128/pretrain_CNN_regre.py +0 -255
  32. myosotis_researches/CcGAN/train_128/train_ccgan.py +0 -303
  33. myosotis_researches/CcGAN/train_128/utils.py +0 -120
  34. myosotis_researches/CcGAN/train_128_output_10/DiffAugment_pytorch.py +0 -76
  35. myosotis_researches/CcGAN/train_128_output_10/__init__.py +0 -0
  36. myosotis_researches/CcGAN/train_128_output_10/eval_metrics.py +0 -205
  37. myosotis_researches/CcGAN/train_128_output_10/opts.py +0 -87
  38. myosotis_researches/CcGAN/train_128_output_10/pretrain_AE.py +0 -268
  39. myosotis_researches/CcGAN/train_128_output_10/pretrain_CNN_class.py +0 -251
  40. myosotis_researches/CcGAN/train_128_output_10/pretrain_CNN_regre.py +0 -255
  41. myosotis_researches/CcGAN/train_128_output_10/train_cgan.py +0 -254
  42. myosotis_researches/CcGAN/train_128_output_10/train_cgan_concat.py +0 -242
  43. myosotis_researches/CcGAN/train_128_output_10/train_net_for_label_embed.py +0 -181
  44. myosotis_researches/CcGAN/train_128_output_10/utils.py +0 -120
  45. myosotis_researches-0.1.7.dist-info/RECORD +0 -59
  46. /myosotis_researches/CcGAN/{train_128 → train}/train_net_for_label_embed.py +0 -0
  47. {myosotis_researches-0.1.7.dist-info → myosotis_researches-0.1.9.dist-info}/WHEEL +0 -0
  48. {myosotis_researches-0.1.7.dist-info → myosotis_researches-0.1.9.dist-info}/licenses/LICENSE +0 -0
  49. {myosotis_researches-0.1.7.dist-info → myosotis_researches-0.1.9.dist-info}/top_level.txt +0 -0
@@ -1,249 +0,0 @@
1
- import numpy as np
2
- import torch
3
- import torch.nn as nn
4
- import torch.nn.functional as F
5
-
6
- from torch.nn.utils import spectral_norm
7
- from torch.nn.init import xavier_uniform_
8
-
9
-
10
def init_weights(m):
    """Xavier-initialise Linear/Conv2d weights and zero their biases.

    Meant to be passed to ``nn.Module.apply``. Modules of any other type are
    left untouched. Layers created with ``bias=False`` have ``m.bias is None``;
    they are skipped instead of raising ``AttributeError``.
    """
    if isinstance(m, (nn.Linear, nn.Conv2d)):
        xavier_uniform_(m.weight)
        if m.bias is not None:
            m.bias.data.fill_(0.)
14
-
15
-
16
def snconv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
    """Return an ``nn.Conv2d`` wrapped with spectral normalisation.

    Every argument is forwarded unchanged to ``nn.Conv2d``.
    """
    conv = nn.Conv2d(
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation,
        groups=groups,
        bias=bias,
    )
    return spectral_norm(conv)
19
-
20
-
21
def snlinear(in_features, out_features):
    """Return ``nn.Linear(in_features, out_features)`` with spectral normalisation applied."""
    linear = nn.Linear(in_features=in_features, out_features=out_features)
    return spectral_norm(linear)
23
-
24
-
25
def sn_embedding(num_embeddings, embedding_dim):
    """Return ``nn.Embedding(num_embeddings, embedding_dim)`` with spectral normalisation applied."""
    table = nn.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim)
    return spectral_norm(table)
27
-
28
-
29
class Self_Attn(nn.Module):
    """SAGAN-style self-attention layer with a learnable residual gate ``sigma``.

    Queries (theta) use C/8 channels at full resolution. Keys (phi, C/8
    channels) and values (g, C/2 channels) are 2x2 max-pooled. The attended
    values are projected back to C channels and added to the input, scaled by
    ``sigma`` (initialised to 0).
    """

    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels
        # NOTE: registration order is kept as-is so parameter order (and thus
        # optimizer state / seeded initialisation) is unchanged.
        self.snconv1x1_theta = snconv2d(in_channels=in_channels, out_channels=in_channels//8, kernel_size=1, stride=1, padding=0)
        self.snconv1x1_phi = snconv2d(in_channels=in_channels, out_channels=in_channels//8, kernel_size=1, stride=1, padding=0)
        self.snconv1x1_g = snconv2d(in_channels=in_channels, out_channels=in_channels//2, kernel_size=1, stride=1, padding=0)
        self.snconv1x1_attn = snconv2d(in_channels=in_channels//2, out_channels=in_channels, kernel_size=1, stride=1, padding=0)
        self.maxpool = nn.MaxPool2d(2, stride=2, padding=0)
        self.softmax = nn.Softmax(dim=-1)
        self.sigma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        """Return ``x + sigma * attention(x)``.

        Args:
            x: feature map (B, C, H, W). H and W are assumed even, since the
               pooled paths are reshaped to H*W//4 positions.
        Returns:
            Tensor with the same shape as ``x``.
        """
        _, channels, height, width = x.size()
        n_query = height * width
        n_key = n_query // 4

        # Queries: (B, C/8, N)
        query = self.snconv1x1_theta(x).view(-1, channels // 8, n_query)
        # Keys: (B, C/8, N/4) after 2x2 max-pooling
        key = self.maxpool(self.snconv1x1_phi(x)).view(-1, channels // 8, n_key)
        # Attention over key positions: (B, N, N/4), softmax along the last dim
        weights = self.softmax(torch.bmm(query.permute(0, 2, 1), key))

        # Values: (B, C/2, N/4) after the same pooling as the keys
        value = self.maxpool(self.snconv1x1_g(x)).view(-1, channels // 2, n_key)
        attended = torch.bmm(value, weights.permute(0, 2, 1)).view(-1, channels // 2, height, width)
        attended = self.snconv1x1_attn(attended)

        return x + self.sigma * attended
73
-
74
-
75
class GenBlock(nn.Module):
    """Residual generator block that doubles spatial resolution (nearest upsampling).

    Main path: BN -> ReLU -> upsample -> 3x3 conv -> BN -> ReLU -> 3x3 conv.
    Shortcut: upsample -> 1x1 conv. The two paths are summed.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.cond_bn1 = nn.BatchNorm2d(in_channels)
        self.relu = nn.ReLU(inplace=True)
        self.snconv2d1 = snconv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1)
        self.cond_bn2 = nn.BatchNorm2d(out_channels)
        self.snconv2d2 = snconv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1)
        self.snconv2d0 = snconv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        skip = x

        # BN returns a new tensor, so the in-place ReLU below never touches `skip`.
        h = self.relu(self.cond_bn1(x))
        h = F.interpolate(h, scale_factor=2, mode='nearest')  # upsample
        h = self.snconv2d1(h)
        h = self.relu(self.cond_bn2(h))
        h = self.snconv2d2(h)

        skip = self.snconv2d0(F.interpolate(skip, scale_factor=2, mode='nearest'))  # upsample

        return h + skip
101
-
102
-
103
class cGAN_concat_SAGAN_Generator(nn.Module):
    """SAGAN generator conditioned by concatenating labels to the noise vector.

    Produces 3 x 256 x 256 images: a 4x4 seed is upsampled by six GenBlocks,
    with self-attention applied at 64x64.

    Args:
        z_dim: dimension of the noise vector ``z``.
        dim_c: number of label values per sample concatenated to ``z``.
        g_conv_dim: base channel width; block widths are multiples of it.
    """

    def __init__(self, z_dim, dim_c=1, g_conv_dim=64):
        super(cGAN_concat_SAGAN_Generator, self).__init__()

        self.z_dim = z_dim
        self.dim_c = dim_c
        self.g_conv_dim = g_conv_dim
        self.snlinear0 = snlinear(in_features=z_dim+dim_c, out_features=g_conv_dim*16*4*4)
        self.block1 = GenBlock(g_conv_dim*16, g_conv_dim*16)
        self.block2 = GenBlock(g_conv_dim*16, g_conv_dim*8)
        self.block3 = GenBlock(g_conv_dim*8, g_conv_dim*4)
        self.block4 = GenBlock(g_conv_dim*4, g_conv_dim*2)
        self.self_attn = Self_Attn(g_conv_dim*2)
        self.block5 = GenBlock(g_conv_dim*2, g_conv_dim*2)
        self.block6 = GenBlock(g_conv_dim*2, g_conv_dim)
        self.bn = nn.BatchNorm2d(g_conv_dim, eps=1e-5, momentum=0.0001, affine=True)
        self.relu = nn.ReLU()
        self.snconv2d1 = snconv2d(in_channels=g_conv_dim, out_channels=3, kernel_size=3, stride=1, padding=1)
        self.tanh = nn.Tanh()

        # Weight init
        self.apply(init_weights)

    def forward(self, z, labels):
        """Generate images.

        Args:
            z: noise, shape (n, z_dim).
            labels: n * dim_c label values; reshaped to (n, dim_c).
        Returns:
            Images of shape (n, 3, 256, 256), passed through tanh.
        """
        # Reshape by dim_c (previously a hard-coded 1, which broke torch.cat
        # for dim_c > 1); identical behaviour for the default dim_c == 1.
        labels = labels.view(-1, self.dim_c)
        act0 = self.snlinear0(torch.cat((z, labels), dim=1))  # n x g_conv_dim*16*4*4
        act0 = act0.view(-1, self.g_conv_dim*16, 4, 4)  # n x g_conv_dim*16 x 4 x 4
        act1 = self.block1(act0)  # n x g_conv_dim*16 x 8 x 8
        act2 = self.block2(act1)  # n x g_conv_dim*8 x 16 x 16
        act3 = self.block3(act2)  # n x g_conv_dim*4 x 32 x 32
        act4 = self.block4(act3)  # n x g_conv_dim*2 x 64 x 64
        act4 = self.self_attn(act4)  # n x g_conv_dim*2 x 64 x 64
        act5 = self.block5(act4)  # n x g_conv_dim*2 x 128 x 128
        act6 = self.block6(act5)  # n x g_conv_dim x 256 x 256
        act6 = self.bn(act6)  # n x g_conv_dim x 256 x 256
        act6 = self.relu(act6)  # n x g_conv_dim x 256 x 256
        act7 = self.snconv2d1(act6)  # n x 3 x 256 x 256
        act7 = self.tanh(act7)  # n x 3 x 256 x 256
        return act7
144
-
145
-
146
class DiscOptBlock(nn.Module):
    """First discriminator block: halves spatial resolution with 2x2 average pooling.

    Main path: 3x3 conv -> ReLU -> 3x3 conv -> avg-pool.
    Shortcut: avg-pool -> 1x1 conv. The two paths are summed.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.snconv2d1 = snconv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1)
        self.relu = nn.ReLU(inplace=True)
        self.snconv2d2 = snconv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1)
        self.downsample = nn.AvgPool2d(2)
        self.snconv2d0 = snconv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        # The in-place ReLU acts on a fresh conv output, so `x` itself is untouched.
        main = self.downsample(self.snconv2d2(self.relu(self.snconv2d1(x))))
        shortcut = self.snconv2d0(self.downsample(x))
        return main + shortcut
168
-
169
-
170
class DiscBlock(nn.Module):
    """Pre-activation residual discriminator block, optionally downsampling by 2.

    Main path: ReLU -> 3x3 conv -> ReLU -> 3x3 conv [-> avg-pool].
    Shortcut: identity, or 1x1 conv [-> avg-pool] when downsampling or when
    the channel counts differ.
    """

    def __init__(self, in_channels, out_channels):
        super(DiscBlock, self).__init__()
        self.relu = nn.ReLU(inplace=True)
        self.snconv2d1 = snconv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1)
        self.snconv2d2 = snconv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1)
        self.downsample = nn.AvgPool2d(2)
        self.ch_mismatch = in_channels != out_channels
        # Always create the 1x1 shortcut conv. forward() uses it whenever
        # downsample=True, even when channel counts match (e.g. a 12->12 block).
        # Creating it only on a channel mismatch makes such calls raise AttributeError.
        self.snconv2d0 = snconv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x, downsample=True):
        x0 = x

        # NOTE(review): self.relu is in-place, so this also overwrites x0, and
        # the shortcut sees relu(x) rather than x. Reference BigGAN/SAGAN
        # blocks use an out-of-place ReLU here. This is left unchanged so existing
        # checkpoints keep producing identical outputs.
        x = self.relu(x)
        x = self.snconv2d1(x)
        x = self.relu(x)
        x = self.snconv2d2(x)
        if downsample:
            x = self.downsample(x)

        if downsample or self.ch_mismatch:
            x0 = self.snconv2d0(x0)
            if downsample:
                x0 = self.downsample(x0)

        out = x + x0
        return out
199
-
200
-
201
class cGAN_concat_SAGAN_Discriminator(nn.Module):
    """SAGAN discriminator conditioned by concatenating labels to the flattened features.

    Takes 3 x 256 x 256 images. Features are reduced to d_conv_dim*16 x 4 x 4,
    flattened, concatenated with the labels, and mapped to one score per sample.

    Args:
        dim_c: number of label values per sample.
        d_conv_dim: base channel width; block widths are multiples of it.
    """

    def __init__(self, dim_c=1, d_conv_dim=64):
        super(cGAN_concat_SAGAN_Discriminator, self).__init__()
        self.d_conv_dim = d_conv_dim
        self.dim_c = dim_c
        self.opt_block1 = DiscOptBlock(3, d_conv_dim)
        self.block1 = DiscBlock(d_conv_dim, d_conv_dim*2)
        self.self_attn = Self_Attn(d_conv_dim*2)
        self.block2 = DiscBlock(d_conv_dim*2, d_conv_dim*4)
        self.block3 = DiscBlock(d_conv_dim*4, d_conv_dim*6)
        self.block4 = DiscBlock(d_conv_dim*6, d_conv_dim*12)
        self.block5 = DiscBlock(d_conv_dim*12, d_conv_dim*12)
        self.block6 = DiscBlock(d_conv_dim*12, d_conv_dim*16)
        self.relu = nn.ReLU()
        self.snlinear1 = snlinear(in_features=d_conv_dim*16*4*4+dim_c, out_features=1)

    def forward(self, x, labels):
        """Score images given their labels.

        Args:
            x: images, shape (n, 3, 256, 256).
            labels: n * dim_c label values; reshaped to (n, dim_c).
        Returns:
            Scores of shape (n, 1).
        """
        # Reshape by dim_c (previously a hard-coded 1, which mismatched
        # snlinear1's input width for dim_c > 1); identical for dim_c == 1.
        labels = labels.view(-1, self.dim_c)
        # n x 3 x 256 x 256
        h0 = self.opt_block1(x)  # n x d_conv_dim x 128 x 128
        h1 = self.block1(h0)  # n x d_conv_dim*2 x 64 x 64
        h1 = self.self_attn(h1)  # n x d_conv_dim*2 x 64 x 64
        h2 = self.block2(h1)  # n x d_conv_dim*4 x 32 x 32
        h3 = self.block3(h2)  # n x d_conv_dim*6 x 16 x 16
        h4 = self.block4(h3)  # n x d_conv_dim*12 x 8 x 8
        h5 = self.block5(h4)  # n x d_conv_dim*12 x 4 x 4
        h6 = self.block6(h5, downsample=False)  # n x d_conv_dim*16 x 4 x 4
        out = self.relu(h6)  # n x d_conv_dim*16 x 4 x 4
        out = out.view(-1, self.d_conv_dim*16*4*4)
        out = torch.cat((out, labels), dim=1)
        out = self.snlinear1(out)
        return out
235
-
236
-
237
-
238
if __name__ == "__main__":
    # Smoke test on the GPU: build G and D, push a tiny random batch through
    # both, and print the resulting shapes.
    netG = cGAN_concat_SAGAN_Generator(z_dim=128).cuda()
    netD = cGAN_concat_SAGAN_Discriminator().cuda()

    n = 4
    y = torch.randn(n, 1).cuda()
    z = torch.randn(n, 128).cuda()

    fake = netG(z, y)
    print(fake.size())
    print(netD(fake, y).size())
- print(o.size())
@@ -1,76 +0,0 @@
1
- # Differentiable Augmentation for Data-Efficient GAN Training
2
- # Shengyu Zhao, Zhijian Liu, Ji Lin, Jun-Yan Zhu, and Song Han
3
- # https://arxiv.org/pdf/2006.10738
4
-
5
- import torch
6
- import torch.nn.functional as F
7
-
8
-
9
def DiffAugment(x, policy='', channels_first=True):
    """Apply the augmentations named in the comma-separated ``policy``.

    Args:
        x: image batch, NCHW if ``channels_first`` else NHWC.
        policy: comma-separated keys of ``AUGMENT_FNS``; an empty policy
            returns ``x`` unchanged (same object).
        channels_first: layout of ``x``; NHWC input is permuted to NCHW for
            the augmentations and back afterwards.
    Returns:
        The augmented batch (contiguous) in the same layout as the input.
    """
    if not policy:
        return x
    if not channels_first:
        x = x.permute(0, 3, 1, 2)  # NHWC -> NCHW
    for name in policy.split(','):
        for augment in AUGMENT_FNS[name]:
            x = augment(x)
    if not channels_first:
        x = x.permute(0, 2, 3, 1)  # NCHW -> NHWC
    return x.contiguous()
20
-
21
-
22
def rand_brightness(x):
    """Add a per-sample offset drawn uniformly from [-0.5, 0.5) to every pixel."""
    offset = torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) - 0.5
    return x + offset
25
-
26
-
27
def rand_saturation(x):
    """Scale each pixel's deviation from its channel mean by a per-sample factor in [0, 2)."""
    channel_mean = x.mean(dim=1, keepdim=True)
    factor = torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) * 2
    return (x - channel_mean) * factor + channel_mean
31
-
32
-
33
def rand_contrast(x):
    """Scale each sample's deviation from its global mean by a factor in [0.5, 1.5)."""
    sample_mean = x.mean(dim=[1, 2, 3], keepdim=True)
    factor = torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) + 0.5
    return (x - sample_mean) * factor + sample_mean
37
-
38
-
39
def rand_translation(x, ratio=0.125):
    """Randomly translate each NCHW image by up to ``ratio`` of its height/width.

    Each sample gets an integer shift in [-shift, shift] along H and along W,
    where shift = round(size * ratio). The input is zero-padded by one pixel
    on every side, and shifted indices are clamped onto that border, so
    uncovered pixels become zero.
    """
    shift_x, shift_y = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5)
    translation_x = torch.randint(-shift_x, shift_x + 1, size=[x.size(0), 1, 1], device=x.device)
    translation_y = torch.randint(-shift_y, shift_y + 1, size=[x.size(0), 1, 1], device=x.device)
    # indexing='ij' is the historical default; passing it explicitly avoids the
    # torch>=1.10 deprecation warning without changing the result.
    grid_batch, grid_x, grid_y = torch.meshgrid(
        torch.arange(x.size(0), dtype=torch.long, device=x.device),
        torch.arange(x.size(2), dtype=torch.long, device=x.device),
        torch.arange(x.size(3), dtype=torch.long, device=x.device),
        indexing='ij',
    )
    # +1 accounts for the one-pixel pad; clamp onto the zero border.
    grid_x = torch.clamp(grid_x + translation_x + 1, 0, x.size(2) + 1)
    grid_y = torch.clamp(grid_y + translation_y + 1, 0, x.size(3) + 1)
    x_pad = F.pad(x, [1, 1, 1, 1, 0, 0, 0, 0])
    x = x_pad.permute(0, 2, 3, 1).contiguous()[grid_batch, grid_x, grid_y].permute(0, 3, 1, 2)
    return x
53
-
54
-
55
def rand_cutout(x, ratio=0.5):
    """Zero out one random rectangle per NCHW image, sized ``ratio`` of H and W.

    The rectangle centre is sampled per sample. Parts falling outside the
    image are clamped to the border, so the zeroed area may be smaller near
    edges. The same mask is applied to every channel.
    """
    cutout_size = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5)
    offset_x = torch.randint(0, x.size(2) + (1 - cutout_size[0] % 2), size=[x.size(0), 1, 1], device=x.device)
    offset_y = torch.randint(0, x.size(3) + (1 - cutout_size[1] % 2), size=[x.size(0), 1, 1], device=x.device)
    # indexing='ij' is the historical default; passing it explicitly avoids the
    # torch>=1.10 deprecation warning without changing the result.
    grid_batch, grid_x, grid_y = torch.meshgrid(
        torch.arange(x.size(0), dtype=torch.long, device=x.device),
        torch.arange(cutout_size[0], dtype=torch.long, device=x.device),
        torch.arange(cutout_size[1], dtype=torch.long, device=x.device),
        indexing='ij',
    )
    grid_x = torch.clamp(grid_x + offset_x - cutout_size[0] // 2, min=0, max=x.size(2) - 1)
    grid_y = torch.clamp(grid_y + offset_y - cutout_size[1] // 2, min=0, max=x.size(3) - 1)
    mask = torch.ones(x.size(0), x.size(2), x.size(3), dtype=x.dtype, device=x.device)
    mask[grid_batch, grid_x, grid_y] = 0
    x = x * mask.unsqueeze(1)
    return x
70
-
71
-
72
# Maps each DiffAugment policy name to the augmentation functions it applies, in order.
AUGMENT_FNS = dict(
    color=[rand_brightness, rand_saturation, rand_contrast],
    translation=[rand_translation],
    cutout=[rand_cutout],
)
File without changes
@@ -1,205 +0,0 @@
1
- """
2
- Compute
3
- Inception Score (IS),
4
- Frechet Inception Distance (FID), ref "https://github.com/mseitzer/pytorch-fid/blob/master/fid_score.py"
5
- Maximum Mean Discrepancy (MMD)
6
- for a set of fake images
7
-
8
- use numpy array
9
- Xr: high-level features for real images; nr by d array
10
- Yr: labels for real images
11
- Xg: high-level features for fake images; ng by d array
12
- Yg: labels for fake images
13
- IMGSr: real images
14
- IMGSg: fake images
15
-
16
- """
17
-
18
- import os
19
- import gc
20
- import numpy as np
21
- # from numpy import linalg as LA
22
- from scipy import linalg
23
- import torch
24
- import torch.nn as nn
25
- from scipy.stats import entropy
26
- from torch.nn import functional as F
27
- from torchvision.utils import save_image
28
-
29
- from .utils import SimpleProgressBar, IMGs_dataset
30
-
31
-
32
def normalize_images(batch_images):
    """Map pixel values from [0, 255] to [-1, 1]."""
    unit_range = batch_images / 255.0  # [0, 255] -> [0, 1]
    return (unit_range - 0.5) / 0.5    # [0, 1]   -> [-1, 1]
36
-
37
- ##############################################################################
38
- # FID scores
39
- ##############################################################################
40
- # compute FID based on extracted features
41
def FID(Xr, Xg, eps=1e-10):
    '''
    Frechet distance between Gaussians fitted to two feature sets.

    For X_1 ~ N(mu_1, C_1) and X_2 ~ N(mu_2, C_2):
        d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).

    Args:
        Xr: (nr, d) array of features of real images (rows are samples).
        Xg: (ng, d) array of features of fake images.
        eps: diagonal offset added to both covariances when the matrix square
            root of their product is not finite.
    Returns:
        The FID as a real scalar.
    '''
    # sample means
    MUr = np.mean(Xr, axis=0)
    MUg = np.mean(Xg, axis=0)
    mean_diff = MUr - MUg
    # sample covariances; atleast_2d keeps d == 1 as a 1x1 matrix so sqrtm works
    SIGMAr = np.atleast_2d(np.cov(Xr, rowvar=False))
    SIGMAg = np.atleast_2d(np.cov(Xg, rowvar=False))

    # Product might be almost singular
    covmean, _ = linalg.sqrtm(SIGMAr.dot(SIGMAg), disp=False)  # square root of a matrix
    covmean = covmean.real
    if not np.isfinite(covmean).all():
        msg = ('fid calculation produces singular product; '
               'adding %s to diagonal of cov estimates') % eps
        print(msg)
        offset = np.eye(SIGMAr.shape[0]) * eps
        covmean = linalg.sqrtm((SIGMAr + offset).dot(SIGMAg + offset))
        # sqrtm can return tiny imaginary parts from round-off; discard them
        # here too, otherwise the score becomes complex.
        covmean = np.real(covmean)

    # fid score
    fid_score = mean_diff.dot(mean_diff) + np.trace(SIGMAr + SIGMAg - 2*covmean)

    return fid_score
69
-
70
- ##test
71
- #Xr = np.random.rand(10000,1000)
72
- #Xg = np.random.rand(10000,1000)
73
- #print(FID(Xr, Xg))
74
-
75
- # compute FID from raw images
76
def _extract_fid_features(PreNetFID, IMGS, batch_size, resize, norm_img):
    '''Run PreNetFID over every image in IMGS in batches; return an (n, d) float64 array.

    The final partial batch is included, so every image contributes one row.
    '''
    n = IMGS.shape[0]
    chunks = []
    pb = SimpleProgressBar()
    for start in range(0, n, batch_size):
        end = min(start + batch_size, n)
        img_tensor = torch.from_numpy(IMGS[start:end]).type(torch.float).cuda()
        if resize is not None:
            img_tensor = nn.functional.interpolate(img_tensor, size = resize, scale_factor=None, mode='bilinear', align_corners=False)
        if norm_img:
            img_tensor = normalize_images(img_tensor)
        feats = PreNetFID(img_tensor)
        chunks.append(feats.detach().cpu().numpy())
        pb.update(min(end / n * 100, 100))  # progress in percent
        del feats, img_tensor; gc.collect()
        torch.cuda.empty_cache()
    # float64 matches the dtype of the zero-initialised arrays used previously.
    return np.concatenate(chunks, axis=0).astype(np.float64)


def cal_FID(PreNetFID, IMGSr, IMGSg, batch_size = 500, resize = None, norm_img = False):
    '''
    Compute the FID between real and fake images using PreNetFID features.

    Args:
        PreNetFID: feature extractor. It is called on float CUDA batches and
            its output is treated as an (n, d) feature tensor.
        IMGSr, IMGSg: real / fake images as numpy arrays, images along axis 0.
        batch_size: images per forward pass; capped at min(nr, ng).
        resize: if None, do not resize; if resize = (H,W), resize images to H x W (bilinear).
        norm_img: if True, map pixel values from [0, 255] to [-1, 1] first.
    Returns:
        FID score computed by ``FID`` with eps=1e-6.

    Previously only nr//batch_size full batches were processed. The remaining
    images left zero rows in the feature arrays, which biased the mean and
    covariance. All images are now used.
    '''
    PreNetFID.eval()

    nr = IMGSr.shape[0]
    ng = IMGSg.shape[0]
    batch_size = min(batch_size, nr, ng)

    with torch.no_grad():
        Xr = _extract_fid_features(PreNetFID, IMGSr, batch_size, resize, norm_img)
        Xg = _extract_fid_features(PreNetFID, IMGSg, batch_size, resize, norm_img)

    fid_score = FID(Xr, Xg, eps=1e-6)

    return fid_score
145
-
146
-
147
-
148
-
149
-
150
-
151
- ##############################################################################
152
- # label_score
153
- # difference between assigned label and predicted label
154
- ##############################################################################
155
def cal_labelscore(PreNet, images, labels_assi, min_label_before_shift, max_label_after_shift, batch_size = 500, resize = None, norm_img = False, num_workers=0):
    '''
    Label score: mean and std of |predicted label - assigned label| over fake images.

    PreNet: pre-trained CNN; called on float CUDA batches and unpacked as
        (predicted_labels, _).
    images: fake images (numpy array, images along axis 0)
    labels_assi: assigned labels, one per image (flattened)
    min_label_before_shift, max_label_after_shift: both predicted and assigned
        labels are mapped back as label*max_label_after_shift - |min_label_before_shift|
        before comparison
    resize: if None, do not resize; if resize = (H,W), resize images to H x W (bilinear)
    norm_img: if True, map pixel values from [0, 255] to [-1, 1] before prediction
    Returns: (ls_mean, ls_std)
    '''

    PreNet.eval()

    n = images.shape[0]
    labels_assi = labels_assi.reshape(-1)

    eval_trainset = IMGs_dataset(images, labels_assi, normalize=False)
    eval_dataloader = torch.utils.data.DataLoader(eval_trainset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    labels_pred = np.zeros(n)

    nimgs_got = 0
    pb = SimpleProgressBar()
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        for batch_images, _ in eval_dataloader:
            batch_images = batch_images.type(torch.float).cuda()
            batch_size_curr = batch_images.shape[0]

            # `resize` was documented but previously ignored; apply it the same
            # way cal_FID does (before normalisation).
            if resize is not None:
                batch_images = nn.functional.interpolate(batch_images, size = resize, scale_factor=None, mode='bilinear', align_corners=False)
            if norm_img:
                batch_images = normalize_images(batch_images)

            batch_labels_pred, _ = PreNet(batch_images)
            labels_pred[nimgs_got:(nimgs_got+batch_size_curr)] = batch_labels_pred.detach().cpu().numpy().reshape(-1)

            nimgs_got += batch_size_curr
            pb.update(min((float(nimgs_got)/n)*100, 100))

            del batch_images; gc.collect()
            torch.cuda.empty_cache()

    # map normalised labels back to the original label scale
    labels_pred = (labels_pred*max_label_after_shift)-np.abs(min_label_before_shift)
    labels_assi = (labels_assi*max_label_after_shift)-np.abs(min_label_before_shift)

    abs_err = np.abs(labels_pred-labels_assi)
    ls_mean = np.mean(abs_err)
    ls_std = np.std(abs_err)

    return ls_mean, ls_std
@@ -1,87 +0,0 @@
1
- import argparse
2
-
3
def parse_opts(argv=None):
    """Parse command-line options for CcGAN training and evaluation.

    Args:
        argv: list of argument strings to parse. ``None`` (the default) parses
            ``sys.argv[1:]`` exactly as before; passing a list allows calling
            this from code or tests.
    Returns:
        argparse.Namespace with every option defined below.
    """
    parser = argparse.ArgumentParser()

    # ---- Overall settings ----
    parser.add_argument('--root_path', type=str, default='')
    parser.add_argument('--data_path', type=str, default='')
    parser.add_argument('--eval_ckpt_path', type=str, default='')
    # Help text previously claimed "default: 2020" although the default is 2021.
    parser.add_argument('--seed', type=int, default=2021, metavar='S', help='random seed (default: %(default)s)')
    parser.add_argument('--num_workers', type=int, default=0)

    # ---- Dataset ----
    # Data split: RC-49 is split into a train set (the last decimal of the degree is odd)
    # and a test set (the last decimal of the degree is even); the unique labels in the
    # two sets do not overlap.
    parser.add_argument('--data_split', type=str, default='train',
                        choices=['all', 'train'])
    parser.add_argument('--min_label', type=float, default=0.0)
    parser.add_argument('--max_label', type=float, default=90.0)
    parser.add_argument('--num_channels', type=int, default=3, metavar='N')
    parser.add_argument('--img_size', type=int, default=128, metavar='N')
    parser.add_argument('--max_num_img_per_label', type=int, default=50, metavar='N')
    parser.add_argument('--max_num_img_per_label_after_replica', type=int, default=0, metavar='N')
    parser.add_argument('--show_real_imgs', action='store_true', default=False)
    parser.add_argument('--visualize_fake_images', action='store_true', default=False)

    # ---- GAN settings ----
    parser.add_argument('--GAN', type=str, default='CcGAN', choices=['cGAN', 'cGAN-concat', 'CcGAN'])
    parser.add_argument('--GAN_arch', type=str, default='SAGAN', choices=['SAGAN'])

    # label embedding settings
    parser.add_argument('--net_embed', type=str, default='ResNet34_embed')  # ResNetXX_embed
    parser.add_argument('--epoch_cnn_embed', type=int, default=200)  # epochs of CNN training for label embedding
    parser.add_argument('--resumeepoch_cnn_embed', type=int, default=0)  # resume epoch for the label-embedding CNN
    parser.add_argument('--epoch_net_y2h', type=int, default=500)
    parser.add_argument('--dim_embed', type=int, default=128)  # dimension of the embedding space
    parser.add_argument('--batch_size_embed', type=int, default=256, metavar='N')

    parser.add_argument('--loss_type_gan', type=str, default='hinge')
    parser.add_argument('--niters_gan', type=int, default=10000, help='number of iterations')
    parser.add_argument('--resume_niters_gan', type=int, default=0)
    parser.add_argument('--save_niters_freq', type=int, default=2000, help='frequency of saving checkpoints')
    parser.add_argument('--lr_g_gan', type=float, default=1e-4, help='learning rate for generator')
    parser.add_argument('--lr_d_gan', type=float, default=1e-4, help='learning rate for discriminator')
    parser.add_argument('--dim_gan', type=int, default=128, help='Latent dimension of GAN')
    parser.add_argument('--batch_size_disc', type=int, default=64)
    parser.add_argument('--batch_size_gene', type=int, default=64)
    parser.add_argument('--num_D_steps', type=int, default=4, help='number of Ds updates in one iteration')
    parser.add_argument('--cGAN_num_classes', type=int, default=20, metavar='N')  # bin label into cGAN_num_classes
    parser.add_argument('--visualize_freq', type=int, default=2000, help='frequency of visualization')

    parser.add_argument('--kernel_sigma', type=float, default=-1.0,
                        help='If kernel_sigma<0, then use rule-of-thumb formula to compute the sigma.')
    parser.add_argument('--threshold_type', type=str, default='hard', choices=['soft', 'hard'])
    parser.add_argument('--kappa', type=float, default=-1)
    parser.add_argument('--nonzero_soft_weight_threshold', type=float, default=1e-3,
                        help='threshold for determining nonzero weights for SVDL; we neglect images with too small weights')

    # DiffAugment settings
    parser.add_argument('--gan_DiffAugment', action='store_true', default=False)
    parser.add_argument('--gan_DiffAugment_policy', type=str, default='color,translation,cutout')

    # ---- Evaluation settings ----
    # Four evaluation modes:
    #   Mode 1: eval on unique labels used for GAN training;
    #   Mode 2: eval on all unique labels in the dataset and, when computing FID,
    #           use all real images in the dataset;
    #   Mode 3: eval on all unique labels in the dataset and, when computing FID,
    #           only use real images used for GAN training (to test SFID's
    #           effectiveness on unseen labels);
    #   Mode 4: eval on an interval [min_label, max_label] with num_eval_labels labels.
    parser.add_argument('--eval_mode', type=int, default=2)
    parser.add_argument('--num_eval_labels', type=int, default=-1)
    parser.add_argument('--samp_batch_size', type=int, default=200)
    parser.add_argument('--nfake_per_label', type=int, default=200)
    parser.add_argument('--nreal_per_label', type=int, default=-1)
    parser.add_argument('--comp_FID', action='store_true', default=False)
    parser.add_argument('--epoch_FID_CNN', type=int, default=200)
    parser.add_argument('--FID_radius', type=float, default=0)
    parser.add_argument('--FID_num_centers', type=int, default=-1)
    parser.add_argument('--dump_fake_for_NIQE', action='store_true', default=False,
                        help='Dump fake images for computing NIQE')

    args = parser.parse_args(argv)

    return args