Myosotis-Researches 0.0.14__tar.gz → 0.0.16__tar.gz
This diff shows the changes between two publicly released versions of this package, as they appear in their respective public registries. It is provided for informational purposes only.
- {myosotis_researches-0.0.14 → myosotis_researches-0.0.16/Myosotis_Researches.egg-info}/PKG-INFO +1 -1
- {myosotis_researches-0.0.14 → myosotis_researches-0.0.16}/Myosotis_Researches.egg-info/SOURCES.txt +12 -0
- {myosotis_researches-0.0.14/Myosotis_Researches.egg-info → myosotis_researches-0.0.16}/PKG-INFO +1 -1
- myosotis_researches-0.0.16/myosotis_researches/CcGAN/train_128/DiffAugment_pytorch.py +76 -0
- myosotis_researches-0.0.16/myosotis_researches/CcGAN/train_128/eval_metrics.py +205 -0
- myosotis_researches-0.0.16/myosotis_researches/CcGAN/train_128/opts.py +87 -0
- myosotis_researches-0.0.16/myosotis_researches/CcGAN/train_128/pretrain_AE.py +268 -0
- myosotis_researches-0.0.16/myosotis_researches/CcGAN/train_128/pretrain_CNN_class.py +251 -0
- myosotis_researches-0.0.16/myosotis_researches/CcGAN/train_128/pretrain_CNN_regre.py +255 -0
- myosotis_researches-0.0.16/myosotis_researches/CcGAN/train_128/train_ccgan.py +303 -0
- myosotis_researches-0.0.16/myosotis_researches/CcGAN/train_128/train_cgan.py +254 -0
- myosotis_researches-0.0.16/myosotis_researches/CcGAN/train_128/train_cgan_concat.py +242 -0
- myosotis_researches-0.0.16/myosotis_researches/CcGAN/train_128/train_net_for_label_embed.py +181 -0
- myosotis_researches-0.0.16/myosotis_researches/CcGAN/train_128/utils.py +120 -0
- myosotis_researches-0.0.16/myosotis_researches/__init__.py +0 -0
- {myosotis_researches-0.0.14 → myosotis_researches-0.0.16}/setup.py +1 -1
- {myosotis_researches-0.0.14 → myosotis_researches-0.0.16}/LICENSE +0 -0
- {myosotis_researches-0.0.14 → myosotis_researches-0.0.16}/Myosotis_Researches.egg-info/dependency_links.txt +0 -0
- {myosotis_researches-0.0.14 → myosotis_researches-0.0.16}/Myosotis_Researches.egg-info/top_level.txt +0 -0
- {myosotis_researches-0.0.14 → myosotis_researches-0.0.16}/README.md +0 -0
- {myosotis_researches-0.0.14 → myosotis_researches-0.0.16}/myosotis_researches/CcGAN/__init__.py +0 -0
- {myosotis_researches-0.0.14 → myosotis_researches-0.0.16}/myosotis_researches/CcGAN/models_128/CcGAN_SAGAN.py +0 -0
- {myosotis_researches-0.0.14 → myosotis_researches-0.0.16}/myosotis_researches/CcGAN/models_128/ResNet_class_eval.py +0 -0
- {myosotis_researches-0.0.14 → myosotis_researches-0.0.16}/myosotis_researches/CcGAN/models_128/ResNet_embed.py +0 -0
- {myosotis_researches-0.0.14 → myosotis_researches-0.0.16}/myosotis_researches/CcGAN/models_128/ResNet_regre_eval.py +0 -0
- {myosotis_researches-0.0.14 → myosotis_researches-0.0.16}/myosotis_researches/CcGAN/models_128/__init__.py +0 -0
- {myosotis_researches-0.0.14 → myosotis_researches-0.0.16}/myosotis_researches/CcGAN/models_128/autoencoder.py +0 -0
- {myosotis_researches-0.0.14 → myosotis_researches-0.0.16}/myosotis_researches/CcGAN/models_128/cGAN_SAGAN.py +0 -0
- {myosotis_researches-0.0.14 → myosotis_researches-0.0.16}/myosotis_researches/CcGAN/models_128/cGAN_concat_SAGAN.py +0 -0
- {myosotis_researches-0.0.14 → myosotis_researches-0.0.16}/myosotis_researches/CcGAN/models_256/CcGAN_SAGAN.py +0 -0
- {myosotis_researches-0.0.14 → myosotis_researches-0.0.16}/myosotis_researches/CcGAN/models_256/ResNet_class_eval.py +0 -0
- {myosotis_researches-0.0.14 → myosotis_researches-0.0.16}/myosotis_researches/CcGAN/models_256/ResNet_embed.py +0 -0
- {myosotis_researches-0.0.14 → myosotis_researches-0.0.16}/myosotis_researches/CcGAN/models_256/ResNet_regre_eval.py +0 -0
- {myosotis_researches-0.0.14 → myosotis_researches-0.0.16}/myosotis_researches/CcGAN/models_256/__init__.py +0 -0
- {myosotis_researches-0.0.14 → myosotis_researches-0.0.16}/myosotis_researches/CcGAN/models_256/autoencoder.py +0 -0
- {myosotis_researches-0.0.14 → myosotis_researches-0.0.16}/myosotis_researches/CcGAN/models_256/cGAN_SAGAN.py +0 -0
- {myosotis_researches-0.0.14 → myosotis_researches-0.0.16}/myosotis_researches/CcGAN/models_256/cGAN_concat_SAGAN.py +0 -0
- {myosotis_researches-0.0.14/myosotis_researches → myosotis_researches-0.0.16/myosotis_researches/CcGAN/train_128}/__init__.py +0 -0
- {myosotis_researches-0.0.14 → myosotis_researches-0.0.16}/myosotis_researches/CcGAN/utils/__init__.py +0 -0
- {myosotis_researches-0.0.14 → myosotis_researches-0.0.16}/myosotis_researches/CcGAN/utils/concat_image_horizontal.py +0 -0
- {myosotis_researches-0.0.14 → myosotis_researches-0.0.16}/myosotis_researches/CcGAN/utils/concat_image_vertical.py +0 -0
- {myosotis_researches-0.0.14 → myosotis_researches-0.0.16}/myosotis_researches/CcGAN/utils/make_h5.py +0 -0
- {myosotis_researches-0.0.14 → myosotis_researches-0.0.16}/myosotis_researches/CcGAN/utils/print_hdf5_structure.py +0 -0
- {myosotis_researches-0.0.14 → myosotis_researches-0.0.16}/pyproject.toml +0 -0
- {myosotis_researches-0.0.14 → myosotis_researches-0.0.16}/setup.cfg +0 -0
{myosotis_researches-0.0.14 → myosotis_researches-0.0.16}/Myosotis_Researches.egg-info/SOURCES.txt
RENAMED
@@ -24,6 +24,18 @@ myosotis_researches/CcGAN/models_256/__init__.py
 myosotis_researches/CcGAN/models_256/autoencoder.py
 myosotis_researches/CcGAN/models_256/cGAN_SAGAN.py
 myosotis_researches/CcGAN/models_256/cGAN_concat_SAGAN.py
+myosotis_researches/CcGAN/train_128/DiffAugment_pytorch.py
+myosotis_researches/CcGAN/train_128/__init__.py
+myosotis_researches/CcGAN/train_128/eval_metrics.py
+myosotis_researches/CcGAN/train_128/opts.py
+myosotis_researches/CcGAN/train_128/pretrain_AE.py
+myosotis_researches/CcGAN/train_128/pretrain_CNN_class.py
+myosotis_researches/CcGAN/train_128/pretrain_CNN_regre.py
+myosotis_researches/CcGAN/train_128/train_ccgan.py
+myosotis_researches/CcGAN/train_128/train_cgan.py
+myosotis_researches/CcGAN/train_128/train_cgan_concat.py
+myosotis_researches/CcGAN/train_128/train_net_for_label_embed.py
+myosotis_researches/CcGAN/train_128/utils.py
 myosotis_researches/CcGAN/utils/__init__.py
 myosotis_researches/CcGAN/utils/concat_image_horizontal.py
 myosotis_researches/CcGAN/utils/concat_image_vertical.py
myosotis_researches-0.0.16/myosotis_researches/CcGAN/train_128/DiffAugment_pytorch.py
ADDED
@@ -0,0 +1,76 @@
+# Differentiable Augmentation for Data-Efficient GAN Training
+# Shengyu Zhao, Zhijian Liu, Ji Lin, Jun-Yan Zhu, and Song Han
+# https://arxiv.org/pdf/2006.10738
+
+import torch
+import torch.nn.functional as F
+
+
+def DiffAugment(x, policy='', channels_first=True):
+    if policy:
+        if not channels_first:
+            x = x.permute(0, 3, 1, 2)
+        for p in policy.split(','):
+            for f in AUGMENT_FNS[p]:
+                x = f(x)
+        if not channels_first:
+            x = x.permute(0, 2, 3, 1)
+        x = x.contiguous()
+    return x
+
+
+def rand_brightness(x):
+    x = x + (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) - 0.5)
+    return x
+
+
+def rand_saturation(x):
+    x_mean = x.mean(dim=1, keepdim=True)
+    x = (x - x_mean) * (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) * 2) + x_mean
+    return x
+
+
+def rand_contrast(x):
+    x_mean = x.mean(dim=[1, 2, 3], keepdim=True)
+    x = (x - x_mean) * (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) + 0.5) + x_mean
+    return x
+
+
+def rand_translation(x, ratio=0.125):
+    shift_x, shift_y = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5)
+    translation_x = torch.randint(-shift_x, shift_x + 1, size=[x.size(0), 1, 1], device=x.device)
+    translation_y = torch.randint(-shift_y, shift_y + 1, size=[x.size(0), 1, 1], device=x.device)
+    grid_batch, grid_x, grid_y = torch.meshgrid(
+        torch.arange(x.size(0), dtype=torch.long, device=x.device),
+        torch.arange(x.size(2), dtype=torch.long, device=x.device),
+        torch.arange(x.size(3), dtype=torch.long, device=x.device),
+    )
+    grid_x = torch.clamp(grid_x + translation_x + 1, 0, x.size(2) + 1)
+    grid_y = torch.clamp(grid_y + translation_y + 1, 0, x.size(3) + 1)
+    x_pad = F.pad(x, [1, 1, 1, 1, 0, 0, 0, 0])
+    x = x_pad.permute(0, 2, 3, 1).contiguous()[grid_batch, grid_x, grid_y].permute(0, 3, 1, 2)
+    return x
+
+
+def rand_cutout(x, ratio=0.5):
+    cutout_size = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5)
+    offset_x = torch.randint(0, x.size(2) + (1 - cutout_size[0] % 2), size=[x.size(0), 1, 1], device=x.device)
+    offset_y = torch.randint(0, x.size(3) + (1 - cutout_size[1] % 2), size=[x.size(0), 1, 1], device=x.device)
+    grid_batch, grid_x, grid_y = torch.meshgrid(
+        torch.arange(x.size(0), dtype=torch.long, device=x.device),
+        torch.arange(cutout_size[0], dtype=torch.long, device=x.device),
+        torch.arange(cutout_size[1], dtype=torch.long, device=x.device),
+    )
+    grid_x = torch.clamp(grid_x + offset_x - cutout_size[0] // 2, min=0, max=x.size(2) - 1)
+    grid_y = torch.clamp(grid_y + offset_y - cutout_size[1] // 2, min=0, max=x.size(3) - 1)
+    mask = torch.ones(x.size(0), x.size(2), x.size(3), dtype=x.dtype, device=x.device)
+    mask[grid_batch, grid_x, grid_y] = 0
+    x = x * mask.unsqueeze(1)
+    return x
+
+
+AUGMENT_FNS = {
+    'color': [rand_brightness, rand_saturation, rand_contrast],
+    'translation': [rand_translation],
+    'cutout': [rand_cutout],
+}
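The file above is the reference DiffAugment implementation (cited in its header): every policy is shape-preserving and differentiable, so the same augmentation can be applied to real and fake batches inside the discriminator update without blocking generator gradients. A minimal usage sketch; the flat import and the netD call in the comments are illustrative, not part of this package:

import torch
from DiffAugment_pytorch import DiffAugment  # hypothetical flat import for this sketch

policy = 'color,translation,cutout'   # mirrors the --gan_DiffAugment_policy default in opts.py

x = torch.rand(4, 3, 128, 128)        # NCHW batch; channels_first=True is the default
x_aug = DiffAugment(x, policy=policy)
assert x_aug.shape == x.shape         # every augmentation preserves the batch shape

# Inside a discriminator step one would augment real and fake batches the same way, e.g.
#   d_real = netD(DiffAugment(x_real, policy=policy), labels)
#   d_fake = netD(DiffAugment(x_fake, policy=policy), labels)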
myosotis_researches-0.0.16/myosotis_researches/CcGAN/train_128/eval_metrics.py
ADDED
@@ -0,0 +1,205 @@
+"""
+Compute
+Inception Score (IS),
+Frechet Inception Distance (FID), ref "https://github.com/mseitzer/pytorch-fid/blob/master/fid_score.py"
+Maximum Mean Discrepancy (MMD)
+for a set of fake images
+
+use numpy array
+Xr: high-level features for real images; nr by d array
+Yr: labels for real images
+Xg: high-level features for fake images; ng by d array
+Yg: labels for fake images
+IMGSr: real images
+IMGSg: fake images
+
+"""
+
+import os
+import gc
+import numpy as np
+# from numpy import linalg as LA
+from scipy import linalg
+import torch
+import torch.nn as nn
+from scipy.stats import entropy
+from torch.nn import functional as F
+from torchvision.utils import save_image
+
+from .utils import SimpleProgressBar, IMGs_dataset
+
+
+def normalize_images(batch_images):
+    batch_images = batch_images/255.0
+    batch_images = (batch_images - 0.5)/0.5
+    return batch_images
+
+##############################################################################
+# FID scores
+##############################################################################
+# compute FID based on extracted features
+def FID(Xr, Xg, eps=1e-10):
+    '''
+    The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
+    and X_2 ~ N(mu_2, C_2) is
+        d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
+    '''
+    # sample mean
+    MUr = np.mean(Xr, axis=0)
+    MUg = np.mean(Xg, axis=0)
+    mean_diff = MUr - MUg
+    # sample covariance
+    SIGMAr = np.cov(Xr.transpose())
+    SIGMAg = np.cov(Xg.transpose())
+
+    # product might be almost singular
+    covmean, _ = linalg.sqrtm(SIGMAr.dot(SIGMAg), disp=False)  # square root of a matrix
+    covmean = covmean.real
+    if not np.isfinite(covmean).all():
+        msg = ('fid calculation produces singular product; '
+               'adding %s to diagonal of cov estimates') % eps
+        print(msg)
+        offset = np.eye(SIGMAr.shape[0]) * eps
+        covmean = linalg.sqrtm((SIGMAr + offset).dot(SIGMAg + offset))
+
+    # fid score
+    fid_score = mean_diff.dot(mean_diff) + np.trace(SIGMAr + SIGMAg - 2*covmean)
+
+    return fid_score
+
+## test
+# Xr = np.random.rand(10000,1000)
+# Xg = np.random.rand(10000,1000)
+# print(FID(Xr, Xg))
+
+# compute FID from raw images
+def cal_FID(PreNetFID, IMGSr, IMGSg, batch_size=500, resize=None, norm_img=False):
+    # resize: if None, do not resize; if resize = (H,W), resize images to 3 x H x W
+
+    PreNetFID.eval()
+
+    nr = IMGSr.shape[0]
+    ng = IMGSg.shape[0]
+
+    nc = IMGSr.shape[1]  # IMGSr is nr x nc x img_size x img_size
+    img_size = IMGSr.shape[2]
+
+    if batch_size > min(nr, ng):
+        batch_size = min(nr, ng)
+        # print("FID: reduce batch size to {}".format(batch_size))
+
+    # compute the length of extracted features
+    with torch.no_grad():
+        test_img = torch.from_numpy(IMGSr[0].reshape((1,nc,img_size,img_size))).type(torch.float).cuda()
+        if resize is not None:
+            test_img = nn.functional.interpolate(test_img, size=resize, scale_factor=None, mode='bilinear', align_corners=False)
+        if norm_img:
+            test_img = normalize_images(test_img)
+        # _, test_features = PreNetFID(test_img)
+        test_features = PreNetFID(test_img)
+        d = test_features.shape[1]  # length of extracted features
+
+    Xr = np.zeros((nr, d))
+    Xg = np.zeros((ng, d))
+
+    # batch_size = 500
+    with torch.no_grad():
+        tmp = 0
+        pb1 = SimpleProgressBar()
+        for i in range(nr//batch_size):
+            imgr_tensor = torch.from_numpy(IMGSr[tmp:(tmp+batch_size)]).type(torch.float).cuda()
+            if resize is not None:
+                imgr_tensor = nn.functional.interpolate(imgr_tensor, size=resize, scale_factor=None, mode='bilinear', align_corners=False)
+            if norm_img:
+                imgr_tensor = normalize_images(imgr_tensor)
+            # _, Xr_tmp = PreNetFID(imgr_tensor)
+            Xr_tmp = PreNetFID(imgr_tensor)
+            Xr[tmp:(tmp+batch_size)] = Xr_tmp.detach().cpu().numpy()
+            tmp += batch_size
+            # pb1.update(min(float(i)*100/(nr//batch_size), 100))
+            pb1.update(min(max(tmp/nr*100, 0), 100))  # clamp progress to [0, 100]
+            del Xr_tmp, imgr_tensor; gc.collect()
+            torch.cuda.empty_cache()
+
+        tmp = 0
+        pb2 = SimpleProgressBar()
+        for j in range(ng//batch_size):
+            imgg_tensor = torch.from_numpy(IMGSg[tmp:(tmp+batch_size)]).type(torch.float).cuda()
+            if resize is not None:
+                imgg_tensor = nn.functional.interpolate(imgg_tensor, size=resize, scale_factor=None, mode='bilinear', align_corners=False)
+            if norm_img:
+                imgg_tensor = normalize_images(imgg_tensor)
+            # _, Xg_tmp = PreNetFID(imgg_tensor)
+            Xg_tmp = PreNetFID(imgg_tensor)
+            Xg[tmp:(tmp+batch_size)] = Xg_tmp.detach().cpu().numpy()
+            tmp += batch_size
+            # pb2.update(min(float(j)*100/(ng//batch_size), 100))
+            pb2.update(min(max(tmp/ng*100, 0), 100))  # clamp progress to [0, 100]
+            del Xg_tmp, imgg_tensor; gc.collect()
+            torch.cuda.empty_cache()
+
+
+    fid_score = FID(Xr, Xg, eps=1e-6)
+
+    return fid_score
+
+
+
+
+
+
+##############################################################################
+# label_score
+# difference between assigned label and predicted label
+##############################################################################
+def cal_labelscore(PreNet, images, labels_assi, min_label_before_shift, max_label_after_shift, batch_size=500, resize=None, norm_img=False, num_workers=0):
+    '''
+    PreNet: pre-trained CNN
+    images: fake images
+    labels_assi: assigned labels
+    resize: if None, do not resize; if resize = (H,W), resize images to 3 x H x W
+    '''
+
+    PreNet.eval()
+
+    # assume images are n x nc x img_size x img_size
+    n = images.shape[0]
+    nc = images.shape[1]  # number of channels
+    img_size = images.shape[2]
+    labels_assi = labels_assi.reshape(-1)
+
+    eval_trainset = IMGs_dataset(images, labels_assi, normalize=False)
+    eval_dataloader = torch.utils.data.DataLoader(eval_trainset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
+
+    labels_pred = np.zeros(n+batch_size)
+
+    nimgs_got = 0
+    pb = SimpleProgressBar()
+    for batch_idx, (batch_images, batch_labels) in enumerate(eval_dataloader):
+        batch_images = batch_images.type(torch.float).cuda()
+        batch_labels = batch_labels.type(torch.float).cuda()
+        batch_size_curr = len(batch_labels)
+
+        if norm_img:
+            batch_images = normalize_images(batch_images)
+
+        batch_labels_pred, _ = PreNet(batch_images)
+        labels_pred[nimgs_got:(nimgs_got+batch_size_curr)] = batch_labels_pred.detach().cpu().numpy().reshape(-1)
+
+        nimgs_got += batch_size_curr
+        pb.update((float(nimgs_got)/n)*100)
+
+        del batch_images; gc.collect()
+        torch.cuda.empty_cache()
+    # end for batch_idx
+
+    labels_pred = labels_pred[0:n]
+
+    # shift predictions and assigned labels back to the original label scale
+    labels_pred = (labels_pred*max_label_after_shift)-np.abs(min_label_before_shift)
+    labels_assi = (labels_assi*max_label_after_shift)-np.abs(min_label_before_shift)
+
+    ls_mean = np.mean(np.abs(labels_pred-labels_assi))
+    ls_std = np.std(np.abs(labels_pred-labels_assi))
+
+    return ls_mean, ls_std
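As a sanity check on the FID function above: two large feature samples drawn from the same Gaussian should score near zero, while shifting one sample's mean by 1 in every dimension should add roughly ||mu_1 - mu_2||^2 = d to the score. A minimal sketch, assuming FID has been made importable (the flat import below is hypothetical; in the package it lives under myosotis_researches.CcGAN.train_128):

import numpy as np
from eval_metrics import FID  # hypothetical flat import for this sketch

rng = np.random.default_rng(0)
Xr = rng.normal(size=(10000, 64))   # stand-in "real" features, d = 64
Xg = rng.normal(size=(10000, 64))   # stand-in "fake" features, same distribution
print(FID(Xr, Xg))                  # close to 0

Xg_shifted = Xg + 1.0               # shift the mean by 1 in every dimension
print(FID(Xr, Xg_shifted))          # roughly 64, the squared mean distance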
myosotis_researches-0.0.16/myosotis_researches/CcGAN/train_128/opts.py
ADDED
@@ -0,0 +1,87 @@
+import argparse
+
+def parse_opts():
+    parser = argparse.ArgumentParser()
+
+    ''' Overall Settings '''
+    parser.add_argument('--root_path', type=str, default='')
+    parser.add_argument('--data_path', type=str, default='')
+    parser.add_argument('--eval_ckpt_path', type=str, default='')
+    parser.add_argument('--seed', type=int, default=2021, metavar='S', help='random seed (default: 2021)')
+    parser.add_argument('--num_workers', type=int, default=0)
+
+
+    ''' Dataset '''
+    ## Data split: RC-49 is split into a train set (the last decimal of the degree is odd) and a test set (the last decimal of the degree is even); the unique labels in the two sets do not overlap.
+    parser.add_argument('--data_split', type=str, default='train',
+                        choices=['all', 'train'])
+    parser.add_argument('--min_label', type=float, default=0.0)
+    parser.add_argument('--max_label', type=float, default=90.0)
+    parser.add_argument('--num_channels', type=int, default=3, metavar='N')
+    parser.add_argument('--img_size', type=int, default=128, metavar='N')
+    parser.add_argument('--max_num_img_per_label', type=int, default=50, metavar='N')
+    parser.add_argument('--max_num_img_per_label_after_replica', type=int, default=0, metavar='N')
+    parser.add_argument('--show_real_imgs', action='store_true', default=False)
+    parser.add_argument('--visualize_fake_images', action='store_true', default=False)
+
+
+    ''' GAN settings '''
+    parser.add_argument('--GAN', type=str, default='CcGAN', choices=['cGAN', 'cGAN-concat', 'CcGAN'])
+    parser.add_argument('--GAN_arch', type=str, default='SAGAN', choices=['SAGAN'])
+
+    # label embedding setting
+    parser.add_argument('--net_embed', type=str, default='ResNet34_embed')  # ResNetXX_embed
+    parser.add_argument('--epoch_cnn_embed', type=int, default=200)  # number of epochs for training the embedding CNN
+    parser.add_argument('--resumeepoch_cnn_embed', type=int, default=0)  # epoch to resume the embedding CNN training from
+    parser.add_argument('--epoch_net_y2h', type=int, default=500)
+    parser.add_argument('--dim_embed', type=int, default=128)  # dimension of the embedding space
+    parser.add_argument('--batch_size_embed', type=int, default=256, metavar='N')
+
+    parser.add_argument('--loss_type_gan', type=str, default='hinge')
+    parser.add_argument('--niters_gan', type=int, default=10000, help='number of iterations')
+    parser.add_argument('--resume_niters_gan', type=int, default=0)
+    parser.add_argument('--save_niters_freq', type=int, default=2000, help='frequency of saving checkpoints')
+    parser.add_argument('--lr_g_gan', type=float, default=1e-4, help='learning rate for generator')
+    parser.add_argument('--lr_d_gan', type=float, default=1e-4, help='learning rate for discriminator')
+    parser.add_argument('--dim_gan', type=int, default=128, help='latent dimension of GAN')
+    parser.add_argument('--batch_size_disc', type=int, default=64)
+    parser.add_argument('--batch_size_gene', type=int, default=64)
+    parser.add_argument('--num_D_steps', type=int, default=4, help='number of D updates in one iteration')
+    parser.add_argument('--cGAN_num_classes', type=int, default=20, metavar='N')  # bin labels into cGAN_num_classes classes
+    parser.add_argument('--visualize_freq', type=int, default=2000, help='frequency of visualization')
+
+    parser.add_argument('--kernel_sigma', type=float, default=-1.0,
+                        help='If kernel_sigma<0, use the rule-of-thumb formula to compute sigma.')
+    parser.add_argument('--threshold_type', type=str, default='hard', choices=['soft', 'hard'])
+    parser.add_argument('--kappa', type=float, default=-1)
+    parser.add_argument('--nonzero_soft_weight_threshold', type=float, default=1e-3,
+                        help='threshold for determining nonzero weights for SVDL; images with too-small weights are neglected')
+
+    # DiffAugment setting
+    parser.add_argument('--gan_DiffAugment', action='store_true', default=False)
+    parser.add_argument('--gan_DiffAugment_policy', type=str, default='color,translation,cutout')
+
+
+    # evaluation setting
+    '''
+    Four evaluation modes:
+    Mode 1: eval on the unique labels used for GAN training;
+    Mode 2: eval on all unique labels in the dataset, and use all real images in the dataset when computing FID;
+    Mode 3: eval on all unique labels in the dataset, but only use the real images used for GAN training when computing FID (to test SFID's effectiveness on unseen labels);
+    Mode 4: eval on an interval [min_label, max_label] with num_eval_labels labels.
+    '''
+    parser.add_argument('--eval_mode', type=int, default=2)
+    parser.add_argument('--num_eval_labels', type=int, default=-1)
+    parser.add_argument('--samp_batch_size', type=int, default=200)
+    parser.add_argument('--nfake_per_label', type=int, default=200)
+    parser.add_argument('--nreal_per_label', type=int, default=-1)
+    parser.add_argument('--comp_FID', action='store_true', default=False)
+    parser.add_argument('--epoch_FID_CNN', type=int, default=200)
+    parser.add_argument('--FID_radius', type=float, default=0)
+    parser.add_argument('--FID_num_centers', type=int, default=-1)
+    parser.add_argument('--dump_fake_for_NIQE', action='store_true', default=False,
+                        help='dump fake images for computing NIQE')
+
+    args = parser.parse_args()
+
+    return args
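parse_opts() reads its configuration from the command line via argparse, so any script built on it is driven entirely by flags. A small sketch of programmatic use; the script name, flat import, and flag values below are illustrative:

import sys
from opts import parse_opts  # hypothetical flat import for this sketch

# argparse consumes sys.argv, so simulate a command line before calling parse_opts().
sys.argv = ['train.py',
            '--GAN', 'CcGAN',
            '--threshold_type', 'soft',
            '--kernel_sigma', '-1.0',   # a negative value triggers the rule-of-thumb sigma
            '--gan_DiffAugment']
args = parse_opts()
print(args.GAN, args.threshold_type, args.gan_DiffAugment)  # CcGAN soft True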
myosotis_researches-0.0.16/myosotis_researches/CcGAN/train_128/pretrain_AE.py
ADDED
@@ -0,0 +1,268 @@
+
+import os
+import argparse
+import shutil
+import timeit
+import torch
+import torchvision
+import torchvision.transforms as transforms
+import numpy as np
+import torch.nn as nn
+import torch.backends.cudnn as cudnn
+import random
+import matplotlib.pyplot as plt
+import matplotlib as mpl
+from torch import autograd
+from torchvision.utils import save_image
+import csv
+from tqdm import tqdm
+import gc
+import h5py
+
+
+#############################
+# Settings
+#############################
+
+parser = argparse.ArgumentParser(description='Pre-train AE for computing FID')
+parser.add_argument('--root_path', type=str, default='')
+parser.add_argument('--data_path', type=str, default='')
+parser.add_argument('--num_workers', type=int, default=0)
+parser.add_argument('--dim_bottleneck', type=int, default=512)
+parser.add_argument('--epochs', type=int, default=200, metavar='N',
+                    help='number of epochs to train CNNs (default: 200)')
+parser.add_argument('--resume_epoch', type=int, default=0)
+parser.add_argument('--batch_size_train', type=int, default=128, metavar='N',
+                    help='input batch size for training')
+parser.add_argument('--batch_size_valid', type=int, default=10, metavar='N',
+                    help='input batch size for testing')
+parser.add_argument('--base_lr', type=float, default=1e-3,
+                    help='learning rate, default=1e-3')
+parser.add_argument('--lr_decay_epochs', type=int, default=50)  # decay lr every lr_decay_epochs epochs
+parser.add_argument('--lr_decay_factor', type=float, default=0.1)
+parser.add_argument('--lambda_sparsity', type=float, default=1e-4, help='penalty for sparsity')
+parser.add_argument('--weight_decay', type=float, default=1e-4,
+                    help='weight decay, default=1e-4')
+parser.add_argument('--seed', type=int, default=2020, metavar='S',
+                    help='random seed (default: 2020)')
+parser.add_argument('--CVMode', action='store_true', default=False,
+                    help='CV mode?')
+parser.add_argument('--img_size', type=int, default=128, metavar='N')
+parser.add_argument('--min_label', type=float, default=0.0)
+parser.add_argument('--max_label', type=float, default=90.0)
+args = parser.parse_args()
+
+wd = args.root_path
+os.chdir(wd)
+from ..models_128 import *
+from .utils import IMGs_dataset, SimpleProgressBar
+
+# some parameters in the opts
+dim_bottleneck = args.dim_bottleneck
+epochs = args.epochs
+base_lr = args.base_lr
+lr_decay_epochs = args.lr_decay_epochs
+lr_decay_factor = args.lr_decay_factor
+resume_epoch = args.resume_epoch
+lambda_sparsity = args.lambda_sparsity
+
+
+# random seed
+random.seed(args.seed)
+torch.manual_seed(args.seed)
+torch.backends.cudnn.deterministic = True
+cudnn.benchmark = False
+np.random.seed(args.seed)
+
+# directories for checkpoint, images and log files
+save_models_folder = wd + '/output/eval_models'
+os.makedirs(save_models_folder, exist_ok=True)
+save_AE_images_in_train_folder = save_models_folder + '/AE_lambda_{}_images_in_train'.format(lambda_sparsity)
+os.makedirs(save_AE_images_in_train_folder, exist_ok=True)
+save_AE_images_in_valid_folder = save_models_folder + '/AE_lambda_{}_images_in_valid'.format(lambda_sparsity)
+os.makedirs(save_AE_images_in_valid_folder, exist_ok=True)
+
+
+###########################################################################################################
+# Data
+###########################################################################################################
+# data loader
+data_filename = args.data_path + '/Ra_' + str(args.img_size) + 'x' + str(args.img_size) + '.h5'
+hf = h5py.File(data_filename, 'r')
+labels = hf['labels'][:]
+labels = labels.astype(float)
+images = hf['images'][:]
+hf.close()
+N_all = len(images)
+assert len(images) == len(labels)
+
+q1 = args.min_label
+q2 = args.max_label
+indx = np.where((labels>q1)*(labels<q2)==True)[0]
+labels = labels[indx]
+images = images[indx]
+assert len(labels)==len(images)
+
+# define training and validation sets
+if args.CVMode:
+    # 90% training; 10% validation
+    valid_prop = 0.1  # proportion of the validation samples
+    indx_all = np.arange(len(images))
+    np.random.shuffle(indx_all)
+    indx_valid = indx_all[0:int(valid_prop*len(images))]
+    indx_train = indx_all[int(valid_prop*len(images)):]
+
+    trainset = IMGs_dataset(images[indx_train], labels=None, normalize=True)
+    trainloader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size_train, shuffle=True, num_workers=args.num_workers)
+    validset = IMGs_dataset(images[indx_valid], labels=None, normalize=True)
+    validloader = torch.utils.data.DataLoader(validset, batch_size=args.batch_size_valid, shuffle=False, num_workers=args.num_workers)
+
+else:
+    trainset = IMGs_dataset(images, labels=None, normalize=True)
+    trainloader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size_train, shuffle=True, num_workers=args.num_workers)
+
+
+###########################################################################################################
+# Necessary functions
+###########################################################################################################
+
+def adjust_learning_rate(epoch, epochs, optimizer, base_lr, lr_decay_epochs, lr_decay_factor):
+    lr = base_lr
+
+    for i in range(epochs//lr_decay_epochs):
+        if epoch >= (i+1)*lr_decay_epochs:
+            lr *= lr_decay_factor
+
+    for param_group in optimizer.param_groups:
+        param_group['lr'] = lr
+
+def train_AE():
+
+    # define optimizer
+    params = list(net_encoder.parameters()) + list(net_decoder.parameters())
+    optimizer = torch.optim.Adam(params, lr=base_lr, betas=(0.5, 0.999), weight_decay=1e-4)
+
+    # criterion
+    criterion = nn.MSELoss()
+
+    if resume_epoch > 0:
+        print("Loading ckpt to resume training AE >>>")
+        ckpt_fullpath = save_models_folder + "/AE_checkpoint_intrain/AE_checkpoint_epoch_{}_lambda_{}.pth".format(resume_epoch, lambda_sparsity)
+        checkpoint = torch.load(ckpt_fullpath)
+        net_encoder.load_state_dict(checkpoint['net_encoder_state_dict'])
+        net_decoder.load_state_dict(checkpoint['net_decoder_state_dict'])
+        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
+        torch.set_rng_state(checkpoint['rng_state'])
+        gen_iterations = checkpoint['gen_iterations']
+    else:
+        gen_iterations = 0
+
+    start_time = timeit.default_timer()
+    for epoch in range(resume_epoch, epochs):
+
+        adjust_learning_rate(epoch, epochs, optimizer, base_lr, lr_decay_epochs, lr_decay_factor)
+
+        train_loss = 0
+
+        for batch_idx, batch_real_images in enumerate(trainloader):
+
+            net_encoder.train()
+            net_decoder.train()
+
+            batch_size_curr = batch_real_images.shape[0]
+
+            batch_real_images = batch_real_images.type(torch.float).cuda()
+
+
+            batch_features = net_encoder(batch_real_images)
+            batch_recons_images = net_decoder(batch_features)
+
+            '''
+            sparsity penalty based on https://debuggercafe.com/sparse-autoencoders-using-l1-regularization-with-pytorch/
+            '''
+            loss = criterion(batch_recons_images, batch_real_images) + lambda_sparsity * batch_features.mean()
+
+            # backward pass
+            optimizer.zero_grad()
+            loss.backward()
+            optimizer.step()
+
+            train_loss += loss.cpu().item()
+
+            gen_iterations += 1
+
+            if gen_iterations % 100 == 0:
+                n_row = min(10, int(np.sqrt(batch_size_curr)))
+                with torch.no_grad():
+                    batch_recons_images = net_decoder(net_encoder(batch_real_images[0:n_row**2]))
+                    batch_recons_images = batch_recons_images.detach().cpu()
+                save_image(batch_recons_images.data, save_AE_images_in_train_folder + '/{}.png'.format(gen_iterations), nrow=n_row, normalize=True)
+
+            if gen_iterations % 20 == 0:
+                print("AE+lambda{}: [step {}] [epoch {}/{}] [train loss {}] [Time {}]".format(lambda_sparsity, gen_iterations, epoch+1, epochs, train_loss/(batch_idx+1), timeit.default_timer()-start_time))
+        # end for batch_idx
+
+        if (epoch+1) % 50 == 0:
+            save_file = save_models_folder + "/AE_checkpoint_intrain/AE_checkpoint_epoch_{}_lambda_{}.pth".format(epoch+1, lambda_sparsity)
+            os.makedirs(os.path.dirname(save_file), exist_ok=True)
+            torch.save({
+                    'gen_iterations': gen_iterations,
+                    'net_encoder_state_dict': net_encoder.state_dict(),
+                    'net_decoder_state_dict': net_decoder.state_dict(),
+                    'optimizer_state_dict': optimizer.state_dict(),
+                    'rng_state': torch.get_rng_state()
+            }, save_file)
+    # end for epoch
+
+    return net_encoder, net_decoder
+
+
+if args.CVMode:
+    def valid_AE():
+        net_encoder.eval()
+        net_decoder.eval()
+        with torch.no_grad():
+            for batch_idx, images in enumerate(validloader):
+                images = images.type(torch.float).cuda()
+                features = net_encoder(images)
+                recons_images = net_decoder(features)
+                save_image(recons_images.data, save_AE_images_in_valid_folder + '/{}_recons.png'.format(batch_idx), nrow=10, normalize=True)
+                save_image(images.data, save_AE_images_in_valid_folder + '/{}_real.png'.format(batch_idx), nrow=10, normalize=True)
+        return None
+
+
+
+###########################################################################################################
+# Training and validation
+###########################################################################################################
+
+# model initialization
+net_encoder = encoder(dim_bottleneck=args.dim_bottleneck).cuda()
+net_decoder = decoder(dim_bottleneck=args.dim_bottleneck).cuda()
+net_encoder = nn.DataParallel(net_encoder)
+net_decoder = nn.DataParallel(net_decoder)
+
+filename_ckpt = save_models_folder + '/ckpt_AE_epoch_{}_seed_{}_CVMode_{}.pth'.format(args.epochs, args.seed, args.CVMode)
+
+# training
+if not os.path.isfile(filename_ckpt):
+    print("\n Begin training AE: ")
+    start = timeit.default_timer()
+    net_encoder, net_decoder = train_AE()
+    stop = timeit.default_timer()
+    print("Time elapses: {}s".format(stop - start))
+    # save model
+    torch.save({
+        'net_encoder_state_dict': net_encoder.state_dict(),
+        'net_decoder_state_dict': net_decoder.state_dict(),
+    }, filename_ckpt)
+else:
+    print("\n Ckpt already exists")
+    print("\n Loading...")
+    checkpoint = torch.load(filename_ckpt)
+    net_encoder.load_state_dict(checkpoint['net_encoder_state_dict'])
+    net_decoder.load_state_dict(checkpoint['net_decoder_state_dict'])
+
+if args.CVMode:
+    # validation
+    _ = valid_AE()
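The adjust_learning_rate helper above is a plain step-decay schedule: the rate is multiplied by lr_decay_factor once for every completed block of lr_decay_epochs epochs. A standalone sketch of the same arithmetic under the script's defaults (base_lr=1e-3, lr_decay_epochs=50, lr_decay_factor=0.1); the function name here is illustrative:

def stepped_lr(epoch, epochs=200, base_lr=1e-3, lr_decay_epochs=50, lr_decay_factor=0.1):
    # same arithmetic as adjust_learning_rate, minus the optimizer bookkeeping
    lr = base_lr
    for i in range(epochs // lr_decay_epochs):
        if epoch >= (i + 1) * lr_decay_epochs:
            lr *= lr_decay_factor
    return lr

for epoch in (0, 49, 50, 100, 150, 199):
    print(epoch, stepped_lr(epoch))
# epochs 0 and 49 stay at 1e-3; 50 drops to 1e-4; 100 to 1e-5; 150 and 199 sit at 1e-6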