Myosotis-Researches 0.1.8__py3-none-any.whl → 0.1.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- myosotis_researches/CcGAN/train/__init__.py +4 -0
- myosotis_researches/CcGAN/{train_128_output_10 → train}/train_ccgan.py +4 -4
- myosotis_researches/CcGAN/{train_128_output_10 → train}/train_cgan.py +1 -3
- myosotis_researches/CcGAN/{train_128_output_10 → train}/train_cgan_concat.py +1 -3
- {myosotis_researches-0.1.8.dist-info → myosotis_researches-0.1.10.dist-info}/METADATA +1 -1
- myosotis_researches-0.1.10.dist-info/RECORD +40 -0
- myosotis_researches/CcGAN/train_128/DiffAugment_pytorch.py +0 -76
- myosotis_researches/CcGAN/train_128/__init__.py +0 -0
- myosotis_researches/CcGAN/train_128/eval_metrics.py +0 -205
- myosotis_researches/CcGAN/train_128/opts.py +0 -87
- myosotis_researches/CcGAN/train_128/pretrain_AE.py +0 -268
- myosotis_researches/CcGAN/train_128/pretrain_CNN_class.py +0 -251
- myosotis_researches/CcGAN/train_128/pretrain_CNN_regre.py +0 -255
- myosotis_researches/CcGAN/train_128/train_ccgan.py +0 -303
- myosotis_researches/CcGAN/train_128/train_cgan.py +0 -254
- myosotis_researches/CcGAN/train_128/train_cgan_concat.py +0 -242
- myosotis_researches/CcGAN/train_128/utils.py +0 -120
- myosotis_researches/CcGAN/train_128_output_10/DiffAugment_pytorch.py +0 -76
- myosotis_researches/CcGAN/train_128_output_10/__init__.py +0 -0
- myosotis_researches/CcGAN/train_128_output_10/eval_metrics.py +0 -205
- myosotis_researches/CcGAN/train_128_output_10/opts.py +0 -87
- myosotis_researches/CcGAN/train_128_output_10/pretrain_AE.py +0 -268
- myosotis_researches/CcGAN/train_128_output_10/pretrain_CNN_class.py +0 -251
- myosotis_researches/CcGAN/train_128_output_10/pretrain_CNN_regre.py +0 -255
- myosotis_researches/CcGAN/train_128_output_10/train_net_for_label_embed.py +0 -181
- myosotis_researches/CcGAN/train_128_output_10/utils.py +0 -120
- myosotis_researches-0.1.8.dist-info/RECORD +0 -59
- /myosotis_researches/CcGAN/{train_128 → train}/train_net_for_label_embed.py +0 -0
- {myosotis_researches-0.1.8.dist-info → myosotis_researches-0.1.10.dist-info}/WHEEL +0 -0
- {myosotis_researches-0.1.8.dist-info → myosotis_researches-0.1.10.dist-info}/licenses/LICENSE +0 -0
- {myosotis_researches-0.1.8.dist-info → myosotis_researches-0.1.10.dist-info}/top_level.txt +0 -0
@@ -1,254 +0,0 @@
|
|
1
|
-
|
2
|
-
import torch
|
3
|
-
import torch.nn as nn
|
4
|
-
from torchvision.utils import save_image
|
5
|
-
import numpy as np
|
6
|
-
import os
|
7
|
-
import timeit
|
8
|
-
|
9
|
-
from .utils import IMGs_dataset, SimpleProgressBar
|
10
|
-
from .opts import parse_opts
|
11
|
-
from .DiffAugment_pytorch import DiffAugment
|
12
|
-
|
13
|
-
# ----------------------------------------------------------------------
# Settings
# ----------------------------------------------------------------------
# NOTE: parse_opts() is evaluated when this module is imported, so every
# name below is fixed at import time.
args = parse_opts()

# GAN optimisation hyper-parameters taken from the parsed options
loss_type = args.loss_type_gan
niters = args.niters_gan
resume_niters = args.resume_niters_gan
dim_gan = args.dim_gan
lr_g = args.lr_g_gan
lr_d = args.lr_d_gan
save_niters_freq = args.save_niters_freq
# a single batch size is shared by generator and discriminator updates
batch_size = min(args.batch_size_disc, args.batch_size_gene)
num_classes = args.cGAN_num_classes
gan_arch = args.GAN_arch
num_D_steps = args.num_D_steps

# how often (in iterations) a grid of sample images is written
visualize_freq = args.visualize_freq

# worker processes used by the training DataLoader
num_workers = args.num_workers

# image geometry
NC = args.num_channels
IMG_SIZE = args.img_size

# differentiable augmentation switch and policy string
use_DiffAugment = args.gan_DiffAugment
policy = args.gan_DiffAugment_policy
|
38
|
-
|
39
|
-
|
40
|
-
def train_cgan(images, labels, netG, netD, save_images_folder, save_models_folder = None):
    """Train a class-conditional GAN and return the trained networks.

    Hyper-parameters (learning rates, batch size, number of iterations,
    loss type, DiffAugment policy, ...) come from the module-level settings.

    Args:
        images: training images, wrapped in ``IMGs_dataset(..., normalize=True)``.
        labels: class labels aligned with ``images``; cast to ``torch.long``.
        netG: generator, called as ``netG(z, labels)``.
        netD: discriminator, called as ``netD(images, labels)``.
        save_images_folder: directory that receives ``<iteration>.png`` sample
            grids every ``visualize_freq`` iterations.
        save_models_folder: if not None, checkpoints are written below this
            directory (and training resumes from the ``resume_niters``
            checkpoint when ``resume_niters > 0``).

    Returns:
        The tuple ``(netG, netD)`` after training.

    Raises:
        ValueError: if the configured loss type is neither "vanilla" nor "hinge".
    """

    netG = netG.cuda()
    netD = netD.cuda()

    optimizerG = torch.optim.Adam(netG.parameters(), lr=lr_g, betas=(0.5, 0.999))
    optimizerD = torch.optim.Adam(netD.parameters(), lr=lr_d, betas=(0.5, 0.999))

    trainset = IMGs_dataset(images, labels, normalize=True)
    train_dataloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    # `np.int` was only an alias of the builtin `int` and was removed in NumPy 1.24.
    unique_labels = np.sort(np.array(list(set(labels)))).astype(int)

    if save_models_folder is not None and resume_niters>0:
        save_file = save_models_folder + "/cGAN_{}_nDsteps_{}_checkpoint_intrain/cGAN_checkpoint_niters_{}.pth".format(gan_arch, num_D_steps, resume_niters)
        checkpoint = torch.load(save_file)
        netG.load_state_dict(checkpoint['netG_state_dict'])
        netD.load_state_dict(checkpoint['netD_state_dict'])
        optimizerG.load_state_dict(checkpoint['optimizerG_state_dict'])
        optimizerD.load_state_dict(checkpoint['optimizerD_state_dict'])
        torch.set_rng_state(checkpoint['rng_state'])
    #end if

    # Fixed noise/labels for the periodic n_row x n_row visualisation grid:
    # n_row evenly strided classes, each repeated n_row times (one row per class).
    n_row=10
    z_fixed = torch.randn(n_row**2, dim_gan, dtype=torch.float).cuda()
    indx_step_size = len(unique_labels)//n_row
    selected_labels = unique_labels[np.arange(n_row)*indx_step_size]
    y_fixed = np.repeat(selected_labels, n_row)
    y_fixed = torch.from_numpy(y_fixed).type(torch.long).cuda()

    batch_idx = 0
    dataloader_iter = iter(train_dataloader)

    start_time = timeit.default_timer()
    for niter in range(resume_niters, niters):

        # Restart the iterator one batch before it would be exhausted, so the
        # final batch of an epoch is never drawn (presumably because it may be
        # shorter than batch_size, which the asserts below reject).
        if batch_idx+1 == len(train_dataloader):
            dataloader_iter = iter(train_dataloader)
            batch_idx = 0

        '''

        Train Generator: maximize log(D(G(z)))

        '''

        netG.train()

        # only the labels of this batch are used to condition the generator
        # (`iterator.next()` does not exist on Python 3 iterators; use next())
        _, batch_train_labels = next(dataloader_iter)
        assert batch_size == batch_train_labels.shape[0]
        batch_train_labels = batch_train_labels.type(torch.long).cuda()
        batch_idx+=1

        # Sample noise and labels as generator input
        z = torch.randn(batch_size, dim_gan, dtype=torch.float).cuda()

        #generate fake images
        batch_fake_images = netG(z, batch_train_labels)

        # Loss measures generator's ability to fool the discriminator
        if use_DiffAugment:
            dis_out = netD(DiffAugment(batch_fake_images, policy=policy), batch_train_labels)
        else:
            dis_out = netD(batch_fake_images, batch_train_labels)

        if loss_type == "vanilla":
            dis_out = torch.nn.Sigmoid()(dis_out)
            g_loss = - torch.mean(torch.log(dis_out+1e-20))
        elif loss_type == "hinge":
            g_loss = - dis_out.mean()
        else:
            raise ValueError("Unsupported GAN loss type: {!r} (expected 'vanilla' or 'hinge')".format(loss_type))

        optimizerG.zero_grad()
        g_loss.backward()
        optimizerG.step()

        '''

        Train Discriminator: maximize log(D(x)) + log(1 - D(G(z)))

        '''

        for _ in range(num_D_steps):

            if batch_idx+1 == len(train_dataloader):
                dataloader_iter = iter(train_dataloader)
                batch_idx = 0

            # get training images
            batch_train_images, batch_train_labels = next(dataloader_iter)
            assert batch_size == batch_train_images.shape[0]
            batch_train_images = batch_train_images.type(torch.float).cuda()
            batch_train_labels = batch_train_labels.type(torch.long).cuda()
            batch_idx+=1

            # Measure discriminator's ability to classify real from generated samples
            # NOTE(review): the fake images were generated with the labels of the
            # generator batch, but are scored here with the labels of this real
            # batch -- confirm this pairing is intended.
            if use_DiffAugment:
                real_dis_out = netD(DiffAugment(batch_train_images, policy=policy), batch_train_labels)
                fake_dis_out = netD(DiffAugment(batch_fake_images.detach(), policy=policy), batch_train_labels.detach())
            else:
                real_dis_out = netD(batch_train_images, batch_train_labels)
                fake_dis_out = netD(batch_fake_images.detach(), batch_train_labels.detach())

            if loss_type == "vanilla":
                real_dis_out = torch.nn.Sigmoid()(real_dis_out)
                fake_dis_out = torch.nn.Sigmoid()(fake_dis_out)
                d_loss_real = - torch.log(real_dis_out+1e-20)
                d_loss_fake = - torch.log(1-fake_dis_out+1e-20)
            elif loss_type == "hinge":
                d_loss_real = torch.nn.ReLU()(1.0 - real_dis_out)
                d_loss_fake = torch.nn.ReLU()(1.0 + fake_dis_out)
            else:
                raise ValueError("Unsupported GAN loss type: {!r} (expected 'vanilla' or 'hinge')".format(loss_type))
            d_loss = (d_loss_real + d_loss_fake).mean()

            optimizerD.zero_grad()
            d_loss.backward()
            optimizerD.step()

        if (niter+1)%20 == 0:
            print ("cGAN-%s: [Iter %d/%d] [D loss: %.4f] [G loss: %.4f] [D out real:%.4f] [D out fake:%.4f] [Time: %.4f]" % (gan_arch, niter+1, niters, d_loss.item(), g_loss.item(), real_dis_out.mean().item(),fake_dis_out.mean().item(), timeit.default_timer()-start_time))

        if (niter+1) % visualize_freq == 0:
            netG.eval()
            with torch.no_grad():
                gen_imgs = netG(z_fixed, y_fixed)
                gen_imgs = gen_imgs.detach()
            save_image(gen_imgs.data, save_images_folder +'/{}.png'.format(niter+1), nrow=n_row, normalize=True)

        # checkpoint periodically and always at the last iteration
        if save_models_folder is not None and ((niter+1) % save_niters_freq == 0 or (niter+1) == niters):
            save_file = save_models_folder + "/cGAN_{}_nDsteps_{}_checkpoint_intrain/cGAN_checkpoint_niters_{}.pth".format(gan_arch, num_D_steps, niter+1)
            os.makedirs(os.path.dirname(save_file), exist_ok=True)
            torch.save({
                    'netG_state_dict': netG.state_dict(),
                    'netD_state_dict': netD.state_dict(),
                    'optimizerG_state_dict': optimizerG.state_dict(),
                    'optimizerD_state_dict': optimizerD.state_dict(),
                    'rng_state': torch.get_rng_state()
            }, save_file)
    #end for niter

    return netG, netD
|
193
|
-
|
194
|
-
|
195
|
-
def sample_cgan_given_labels(netG, given_labels, class_cutoff_points, batch_size = 500, denorm=True, to_numpy=True, verbose=True):
    '''
    Generate one fake image per entry of `given_labels` with a class-conditional generator.

    netG: generator, called as netG(z, class_labels)
    given_labels: a numpy array; raw label without any normalization; not class label
    class_cutoff_points: the cutoff points to determine the membership of a given label;
        must be sorted in ascending order. K+1 cutoff points define K classes: a label y
        with cutoff[j] <= y < cutoff[j+1] belongs to class j, and y == cutoff[K] belongs
        to class K-1. Labels outside [cutoff[0], cutoff[K]] are assigned to the nearest class.
    batch_size: number of images generated per forward pass (capped at len(given_labels))
    denorm: if True, images are mapped from [-1,1] to uint8 values in [0,255]
    to_numpy: if True, return the images as a numpy array instead of a torch tensor
    verbose: if True, show a progress bar

    Returns (fake_images, given_labels); given_labels is returned unchanged.
    Raises ValueError if fewer than two cutoff points are given.
    '''

    class_cutoff_points = np.array(class_cutoff_points)
    num_classes = len(class_cutoff_points)-1
    if num_classes < 1:
        raise ValueError("class_cutoff_points must contain at least two cutoff points")

    nfake = len(given_labels)

    # Bin every raw label in one vectorised pass. side='right' puts a label that
    # equals an interior cutoff into the class starting at that cutoff; clipping
    # maps the top cutoff to the last class and keeps out-of-range labels from
    # producing an invalid (negative or too large) class index.
    raw_labels = np.asarray(given_labels).reshape(-1)
    given_class_labels = np.searchsorted(class_cutoff_points, raw_labels, side='right') - 1
    given_class_labels = np.clip(given_class_labels, 0, num_classes-1)

    # pad with a wrapped-around copy so the last (possibly short) slice is still full-sized
    given_class_labels = np.concatenate((given_class_labels, given_class_labels[0:batch_size]))

    if batch_size>nfake:
        batch_size = nfake
    fake_images = []
    netG=netG.cuda()
    netG.eval()
    with torch.no_grad():
        if verbose:
            pb = SimpleProgressBar()
        tmp = 0
        while tmp < nfake:
            z = torch.randn(batch_size, dim_gan, dtype=torch.float).cuda()
            labels = torch.from_numpy(given_class_labels[tmp:(tmp+batch_size)]).type(torch.long).cuda()
            batch_fake_images = netG(z, labels)
            if denorm: #denorm imgs to save memory
                assert batch_fake_images.max().item()<=1.0 and batch_fake_images.min().item()>=-1.0
                batch_fake_images = batch_fake_images*0.5+0.5
                batch_fake_images = batch_fake_images*255.0
                batch_fake_images = batch_fake_images.type(torch.uint8)
            fake_images.append(batch_fake_images.detach().cpu())
            tmp += batch_size
            if verbose:
                pb.update(min(float(tmp)/nfake, 1)*100)

    fake_images = torch.cat(fake_images, dim=0)
    #remove extra entries
    fake_images = fake_images[0:nfake]

    if to_numpy:
        fake_images = fake_images.numpy()

    return fake_images, given_labels
|
@@ -1,242 +0,0 @@
|
|
1
|
-
|
2
|
-
import torch
|
3
|
-
import torch.nn as nn
|
4
|
-
from torchvision.utils import save_image
|
5
|
-
import numpy as np
|
6
|
-
import os
|
7
|
-
import timeit
|
8
|
-
|
9
|
-
from .utils import IMGs_dataset, SimpleProgressBar
|
10
|
-
from .opts import parse_opts
|
11
|
-
from .DiffAugment_pytorch import DiffAugment
|
12
|
-
|
13
|
-
# ----------------------------------------------------------------------
# Settings
# ----------------------------------------------------------------------
# NOTE: parse_opts() is evaluated when this module is imported, so every
# name below is fixed at import time.
args = parse_opts()

# GAN optimisation hyper-parameters taken from the parsed options
loss_type = args.loss_type_gan
niters = args.niters_gan
resume_niters = args.resume_niters_gan
dim_gan = args.dim_gan
lr_g = args.lr_g_gan
lr_d = args.lr_d_gan
save_niters_freq = args.save_niters_freq
# a single batch size is shared by generator and discriminator updates
batch_size = min(args.batch_size_disc, args.batch_size_gene)
num_classes = args.cGAN_num_classes
gan_arch = args.GAN_arch
num_D_steps = args.num_D_steps

# how often (in iterations) a grid of sample images is written
visualize_freq = args.visualize_freq

# worker processes used by the training DataLoader
num_workers = args.num_workers

# image geometry and the divisor used to normalise raw regression labels
NC = args.num_channels
IMG_SIZE = args.img_size
max_label = args.max_label

# differentiable augmentation switch and policy string
use_DiffAugment = args.gan_DiffAugment
policy = args.gan_DiffAugment_policy
|
39
|
-
|
40
|
-
|
41
|
-
def train_cgan_concat(images, labels, netG, netD, save_images_folder, save_models_folder = None):
    """Train a conditional GAN whose networks take the label as a direct input.

    Hyper-parameters (learning rates, batch size, number of iterations,
    loss type, DiffAugment policy, ...) come from the module-level settings.

    Args:
        images: training images, wrapped in ``IMGs_dataset(..., normalize=True)``.
        labels: labels aligned with ``images``.
        netG: generator, called as ``netG(z, labels)``.
        netD: discriminator, called as ``netD(images, labels)``.
        save_images_folder: directory that receives ``<iteration>.png`` sample
            grids every ``visualize_freq`` iterations.
        save_models_folder: if not None, checkpoints are written below this
            directory (and training resumes from the ``resume_niters``
            checkpoint when ``resume_niters > 0``).

    Returns:
        The tuple ``(netG, netD)`` after training.

    Raises:
        ValueError: if the configured loss type is neither "vanilla" nor "hinge".
    """

    netG = netG.cuda()
    netD = netD.cuda()

    optimizerG = torch.optim.Adam(netG.parameters(), lr=lr_g, betas=(0.5, 0.999))
    optimizerD = torch.optim.Adam(netD.parameters(), lr=lr_d, betas=(0.5, 0.999))

    trainset = IMGs_dataset(images, labels, normalize=True)
    train_dataloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    # (the former `unique_labels` computation was never used and relied on
    # `np.int`, which NumPy 1.24 removed, so it has been dropped)

    if save_models_folder is not None and resume_niters>0:
        save_file = save_models_folder + "/cGAN_{}_nDsteps_{}_checkpoint_intrain/cGAN_checkpoint_niters_{}.pth".format(gan_arch, num_D_steps, resume_niters)
        checkpoint = torch.load(save_file)
        netG.load_state_dict(checkpoint['netG_state_dict'])
        netD.load_state_dict(checkpoint['netD_state_dict'])
        optimizerG.load_state_dict(checkpoint['optimizerG_state_dict'])
        optimizerD.load_state_dict(checkpoint['optimizerD_state_dict'])
        torch.set_rng_state(checkpoint['rng_state'])
    #end if

    # printed images with labels between the 5-th quantile and 95-th quantile of training labels
    n_row=10; n_col = n_row
    z_fixed = torch.randn(n_row*n_col, dim_gan, dtype=torch.float).cuda()
    start_label = np.quantile(labels, 0.05)
    end_label = np.quantile(labels, 0.95)
    selected_labels = np.linspace(start_label, end_label, num=n_row)
    # one grid row per selected label: each label repeated n_col times
    y_fixed = np.repeat(selected_labels, n_col)
    print(y_fixed)
    y_fixed = torch.from_numpy(y_fixed).type(torch.float).view(-1,1).cuda()

    batch_idx = 0
    dataloader_iter = iter(train_dataloader)

    start_time = timeit.default_timer()
    for niter in range(resume_niters, niters):

        # Restart the iterator one batch before it would be exhausted, so the
        # final batch of an epoch is never drawn (presumably because it may be
        # shorter than batch_size, which the asserts below reject).
        if batch_idx+1 == len(train_dataloader):
            dataloader_iter = iter(train_dataloader)
            batch_idx = 0

        '''

        Train Generator: maximize log(D(G(z)))

        '''

        netG.train()

        # only the labels of this batch are used to condition the generator
        # (`iterator.next()` does not exist on Python 3 iterators; use next())
        _, batch_train_labels = next(dataloader_iter)
        assert batch_size == batch_train_labels.shape[0]
        # NOTE(review): labels are cast to torch.long here, while the fixed
        # visualisation labels and sample_cgan_concat_given_labels feed float
        # labels to netG -- confirm the integer cast is intended for training.
        batch_train_labels = batch_train_labels.type(torch.long).cuda()
        batch_idx+=1

        # Sample noise and labels as generator input
        z = torch.randn(batch_size, dim_gan, dtype=torch.float).cuda()

        #generate fake images
        batch_fake_images = netG(z, batch_train_labels)

        # Loss measures generator's ability to fool the discriminator
        if use_DiffAugment:
            dis_out = netD(DiffAugment(batch_fake_images, policy=policy), batch_train_labels)
        else:
            dis_out = netD(batch_fake_images, batch_train_labels)

        if loss_type == "vanilla":
            dis_out = torch.nn.Sigmoid()(dis_out)
            g_loss = - torch.mean(torch.log(dis_out+1e-20))
        elif loss_type == "hinge":
            g_loss = - dis_out.mean()
        else:
            raise ValueError("Unsupported GAN loss type: {!r} (expected 'vanilla' or 'hinge')".format(loss_type))

        optimizerG.zero_grad()
        g_loss.backward()
        optimizerG.step()

        '''

        Train Discriminator: maximize log(D(x)) + log(1 - D(G(z)))

        '''

        for _ in range(num_D_steps):

            if batch_idx+1 == len(train_dataloader):
                dataloader_iter = iter(train_dataloader)
                batch_idx = 0

            # get training images
            batch_train_images, batch_train_labels = next(dataloader_iter)
            assert batch_size == batch_train_images.shape[0]
            batch_train_images = batch_train_images.type(torch.float).cuda()
            batch_train_labels = batch_train_labels.type(torch.long).cuda()
            batch_idx+=1

            # Measure discriminator's ability to classify real from generated samples
            if use_DiffAugment:
                real_dis_out = netD(DiffAugment(batch_train_images, policy=policy), batch_train_labels)
                fake_dis_out = netD(DiffAugment(batch_fake_images.detach(), policy=policy), batch_train_labels.detach())
            else:
                real_dis_out = netD(batch_train_images, batch_train_labels)
                fake_dis_out = netD(batch_fake_images.detach(), batch_train_labels.detach())

            if loss_type == "vanilla":
                real_dis_out = torch.nn.Sigmoid()(real_dis_out)
                fake_dis_out = torch.nn.Sigmoid()(fake_dis_out)
                d_loss_real = - torch.log(real_dis_out+1e-20)
                d_loss_fake = - torch.log(1-fake_dis_out+1e-20)
            elif loss_type == "hinge":
                d_loss_real = torch.nn.ReLU()(1.0 - real_dis_out)
                d_loss_fake = torch.nn.ReLU()(1.0 + fake_dis_out)
            else:
                raise ValueError("Unsupported GAN loss type: {!r} (expected 'vanilla' or 'hinge')".format(loss_type))
            d_loss = (d_loss_real + d_loss_fake).mean()

            optimizerD.zero_grad()
            d_loss.backward()
            optimizerD.step()

        if (niter+1)%20 == 0:
            print ("cGAN(concat)-%s: [Iter %d/%d] [D loss: %.4f] [G loss: %.4f] [D out real:%.4f] [D out fake:%.4f] [Time: %.4f]" % (gan_arch, niter+1, niters, d_loss.item(), g_loss.item(), real_dis_out.mean().item(),fake_dis_out.mean().item(), timeit.default_timer()-start_time))

        if (niter+1) % visualize_freq == 0:
            netG.eval()
            with torch.no_grad():
                gen_imgs = netG(z_fixed, y_fixed)
                gen_imgs = gen_imgs.detach()
            save_image(gen_imgs.data, save_images_folder +'/{}.png'.format(niter+1), nrow=n_row, normalize=True)

        # checkpoint periodically and always at the last iteration
        if save_models_folder is not None and ((niter+1) % save_niters_freq == 0 or (niter+1) == niters):
            save_file = save_models_folder + "/cGAN_{}_nDsteps_{}_checkpoint_intrain/cGAN_checkpoint_niters_{}.pth".format(gan_arch, num_D_steps, niter+1)
            os.makedirs(os.path.dirname(save_file), exist_ok=True)
            torch.save({
                    'netG_state_dict': netG.state_dict(),
                    'netD_state_dict': netD.state_dict(),
                    'optimizerG_state_dict': optimizerG.state_dict(),
                    'optimizerD_state_dict': optimizerD.state_dict(),
                    'rng_state': torch.get_rng_state()
            }, save_file)
    #end for niter

    return netG, netD
|
193
|
-
|
194
|
-
|
195
|
-
def sample_cgan_concat_given_labels(netG, given_labels, batch_size = 100, denorm=True, to_numpy=True, verbose=True):
    '''
    Generate one fake image per entry of `given_labels`.

    netG: pretrained generator network, called as netG(z, labels)
    given_labels: unnormalized labels (numpy array or any sequence numpy can convert);
        they are divided by the module-level `max_label` before being fed to netG.
    batch_size: number of images generated per forward pass (capped at len(given_labels))
    denorm: if True, images are mapped from [-1,1] to uint8 values in [0,255]
    to_numpy: if True, return the images as a numpy array instead of a torch tensor
    verbose: if True, show a progress bar

    Returns (fake_images, given_labels); given_labels is returned unchanged.
    '''

    ## num of fake images will be generated
    nfake = len(given_labels)

    ## normalize regression labels; asarray lets plain lists/tuples work as well
    labels = np.asarray(given_labels)/max_label

    ## generate images
    if batch_size>nfake:
        batch_size = nfake

    fake_images = []
    ## concat to avoid out of index errors
    labels = np.concatenate((labels, labels[0:batch_size]), axis=0)

    netG=netG.cuda()
    netG.eval()
    with torch.no_grad():
        if verbose:
            pb = SimpleProgressBar()
        tmp = 0
        while tmp < nfake:
            z = torch.randn(batch_size, dim_gan, dtype=torch.float).cuda()
            c = torch.from_numpy(labels[tmp:(tmp+batch_size)]).type(torch.float).cuda()
            batch_fake_images = netG(z, c)
            if denorm: #denorm imgs to save memory
                assert batch_fake_images.max().item()<=1.0 and batch_fake_images.min().item()>=-1.0
                batch_fake_images = batch_fake_images*0.5+0.5
                batch_fake_images = batch_fake_images*255.0
                batch_fake_images = batch_fake_images.type(torch.uint8)
            fake_images.append(batch_fake_images.detach().cpu())
            tmp += batch_size
            if verbose:
                pb.update(min(float(tmp)/nfake, 1)*100)

    fake_images = torch.cat(fake_images, dim=0)
    #remove extra entries
    fake_images = fake_images[0:nfake]

    if to_numpy:
        fake_images = fake_images.numpy()

    return fake_images, given_labels
|
@@ -1,120 +0,0 @@
|
|
1
|
-
"""
|
2
|
-
Some helpful functions
|
3
|
-
|
4
|
-
"""
|
5
|
-
import numpy as np
|
6
|
-
import torch
|
7
|
-
import torch.nn as nn
|
8
|
-
import torchvision
|
9
|
-
import matplotlib.pyplot as plt
|
10
|
-
import matplotlib as mpl
|
11
|
-
from torch.nn import functional as F
|
12
|
-
import sys
|
13
|
-
import PIL
|
14
|
-
from PIL import Image
|
15
|
-
|
16
|
-
|
17
|
-
|
18
|
-
# ################################################################################
|
19
|
-
# Progress Bar
|
20
|
-
class SimpleProgressBar():
    """Single-line text progress bar written to stdout."""

    def __init__(self, width=50):
        # number of characters between the brackets
        self.width = width
        # integer percentage drawn last; -1 means nothing has been drawn yet
        self.last_x = -1

    def update(self, x):
        """Redraw the bar at `x` percent; no-op if the integer percentage is unchanged."""
        assert 0 <= x <= 100 # `x`: progress in percent ( between 0 and 100)
        percent = int(x)
        if percent == self.last_x:
            return
        self.last_x = percent
        filled = int(self.width * (x / 100.0))
        bar = '#' * filled + '.' * (self.width - filled)
        # carriage return rewrites the same terminal line
        sys.stdout.write('\r%d%% [%s]' % (percent, bar))
        sys.stdout.flush()
        if x == 100:
            # finish the line once the bar is complete
            print('')
|
34
|
-
|
35
|
-
|
36
|
-
|
37
|
-
################################################################################
|
38
|
-
# torch dataset from numpy array
|
39
|
-
class IMGs_dataset(torch.utils.data.Dataset):
    """Dataset over an indexable collection of images with optional labels.

    With `normalize=True` each returned image is divided by 255 and then
    mapped through (v - 0.5) / 0.5. Items are `(image, label)` pairs when
    labels were given, otherwise just the image.
    """

    def __init__(self, images, labels=None, normalize=False):
        super().__init__()

        self.images = images
        self.n_images = len(images)
        self.labels = labels
        # images and labels must line up one-to-one
        if labels is not None and len(images) != len(labels):
            raise Exception(f"images ({len(images)}) and labels ({len(labels)}) do not have the same length!!!")
        self.normalize = normalize

    def __getitem__(self, index):
        image = self.images[index]

        if self.normalize:
            image = image/255.0
            image = (image-0.5)/0.5

        if self.labels is None:
            return image
        return (image, self.labels[index])

    def __len__(self):
        return self.n_images
|
68
|
-
|
69
|
-
|
70
|
-
def PlotLoss(loss, filename):
    """Plot a sequence of training-loss values against epoch number and save it.

    Args:
        loss: sequence of loss values; entry i is plotted at x = i + 1.
        filename: path the figure is written to with `plt.savefig`.
    """
    x_axis = np.arange(start = 1, stop = len(loss)+1)
    plt.switch_backend('agg')
    # The style sheet 'seaborn' was renamed 'seaborn-v0_8' in Matplotlib 3.6 and
    # the old name was removed in 3.8; use whichever this installation provides.
    for style_name in ('seaborn-v0_8', 'seaborn'):
        if style_name in mpl.style.available:
            mpl.style.use(style_name)
            break
    fig = plt.figure()
    try:
        ax = plt.subplot(111)
        ax.plot(x_axis, np.array(loss))
        plt.xlabel("epoch")
        plt.ylabel("training loss")
        # No legend: the single curve carries no label, so plt.legend() only
        # produced a "no artists with labels" warning.
        #ax.legend(loc='upper center', bbox_to_anchor=(0.5, 1.15), shadow=True, ncol=3)
        #plt.title('Training Loss')
        plt.savefig(filename)
    finally:
        # release the figure so repeated calls do not accumulate open figures
        plt.close(fig)
|
83
|
-
|
84
|
-
|
85
|
-
# compute entropy of class labels; labels is a numpy array
|
86
|
-
def compute_entropy(labels, base=None):
    """Return the entropy of the empirical label distribution.

    `labels` is a numpy array of class labels; `base` is the logarithm base
    (natural logarithm when None).
    """
    _, counts = np.unique(labels, return_counts=True)
    probs = counts / counts.sum()
    if base is None:
        base = np.e
    return -(probs * np.log(probs)/np.log(base)).sum()
|
91
|
-
|
92
|
-
def predict_class_labels(net, images, batch_size=500, verbose=False, num_workers=0):
    """Predict a class index for every image with a classifier on the GPU.

    `net` is called as `net(batch)` and must return a pair whose first element
    holds per-class scores; the index of the largest score along dim 1 is taken
    as the prediction. Returns a numpy array with one entry per image.
    """
    net = net.cuda()
    net.eval()

    n_images = len(images)
    batch_size = min(batch_size, n_images)
    loader = torch.utils.data.DataLoader(
        IMGs_dataset(images, normalize=False),
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
    )

    # over-allocated by one batch; trimmed to n_images before returning
    predictions = np.zeros(n_images+batch_size)
    with torch.no_grad():
        done = 0
        if verbose:
            pb = SimpleProgressBar()
        for batch in loader:
            batch = batch.type(torch.float).cuda()
            n_batch = len(batch)

            scores, _ = net(batch)
            _, batch_pred = torch.max(scores.data, 1)
            predictions[done:(done+n_batch)] = batch_pred.detach().cpu().numpy().reshape(-1)

            done += n_batch
            if verbose:
                pb.update((float(done)/n_images)*100)
        #end for batch
    return predictions[0:n_images]
|
@@ -1,76 +0,0 @@
|
|
1
|
-
# Differentiable Augmentation for Data-Efficient GAN Training
|
2
|
-
# Shengyu Zhao, Zhijian Liu, Ji Lin, Jun-Yan Zhu, and Song Han
|
3
|
-
# https://arxiv.org/pdf/2006.10738
|
4
|
-
|
5
|
-
import torch
|
6
|
-
import torch.nn.functional as F
|
7
|
-
|
8
|
-
|
9
|
-
def DiffAugment(x, policy='', channels_first=True):
    """Apply the augmentations named in `policy` to the batch `x`.

    `policy` is a comma-separated list of keys of AUGMENT_FNS; an empty policy
    returns `x` untouched. With `channels_first=False` the input is permuted
    with (0, 3, 1, 2) before augmenting and permuted back afterwards.
    """
    if not policy:
        return x
    if not channels_first:
        x = x.permute(0, 3, 1, 2)
    for name in policy.split(','):
        for augment in AUGMENT_FNS[name]:
            x = augment(x)
    if not channels_first:
        x = x.permute(0, 2, 3, 1)
    return x.contiguous()
|
20
|
-
|
21
|
-
|
22
|
-
def rand_brightness(x):
    """Add one random offset in [-0.5, 0.5) to every sample of the batch."""
    offset = torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) - 0.5
    return x + offset
|
25
|
-
|
26
|
-
|
27
|
-
def rand_saturation(x):
    """Scale each sample's deviation from its mean over dim 1 by a random factor in [0, 2)."""
    channel_mean = x.mean(dim=1, keepdim=True)
    factor = torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) * 2
    return (x - channel_mean) * factor + channel_mean
|
31
|
-
|
32
|
-
|
33
|
-
def rand_contrast(x):
    """Scale each sample's deviation from its mean over dims 1-3 by a random factor in [0.5, 1.5)."""
    sample_mean = x.mean(dim=[1, 2, 3], keepdim=True)
    factor = torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) + 0.5
    return (x - sample_mean) * factor + sample_mean
|
37
|
-
|
38
|
-
|
39
|
-
def rand_translation(x, ratio=0.125):
    """Shift every sample of a 4-D batch by a random integer offset, zero-filling the border.

    Args:
        x: 4-D tensor; dims 2 and 3 are the translated axes.
        ratio: maximum shift as a fraction of the size of dims 2 and 3
            (rounded to the nearest integer).

    Returns:
        Tensor of the same shape as `x`; positions shifted in from outside are 0.
    """
    batch, size_x, size_y = x.size(0), x.size(2), x.size(3)
    shift_x, shift_y = int(size_x * ratio + 0.5), int(size_y * ratio + 0.5)
    translation_x = torch.randint(-shift_x, shift_x + 1, size=[batch, 1, 1], device=x.device)
    translation_y = torch.randint(-shift_y, shift_y + 1, size=[batch, 1, 1], device=x.device)
    # Broadcastable index tensors replace torch.meshgrid: the result is the same,
    # but nothing depends on meshgrid's `indexing` default (which warns when
    # omitted) and no full-size index grids are materialised up front.
    grid_batch = torch.arange(batch, dtype=torch.long, device=x.device).view(-1, 1, 1)
    grid_x = torch.arange(size_x, dtype=torch.long, device=x.device).view(1, -1, 1)
    grid_y = torch.arange(size_y, dtype=torch.long, device=x.device).view(1, 1, -1)
    # +1 accounts for the one-pixel zero border added below; clamping sends every
    # out-of-range source position onto that border.
    grid_x = torch.clamp(grid_x + translation_x + 1, 0, size_x + 1)
    grid_y = torch.clamp(grid_y + translation_y + 1, 0, size_y + 1)
    x_pad = F.pad(x, [1, 1, 1, 1, 0, 0, 0, 0])
    x = x_pad.permute(0, 2, 3, 1).contiguous()[grid_batch, grid_x, grid_y].permute(0, 3, 1, 2)
    return x
|
53
|
-
|
54
|
-
|
55
|
-
def rand_cutout(x, ratio=0.5):
    """Zero out one randomly placed rectangle in every sample of a 4-D batch.

    Args:
        x: 4-D tensor; dims 2 and 3 are the axes the rectangle is cut from.
        ratio: rectangle side length as a fraction of the size of dims 2 and 3
            (rounded to the nearest integer).

    Returns:
        Tensor of the same shape as `x` with the rectangle set to 0 at every
        index of dim 1. The rectangle is clamped to the bounds, so it can be
        smaller than requested near the border.
    """
    batch, size_x, size_y = x.size(0), x.size(2), x.size(3)
    cutout_size = int(size_x * ratio + 0.5), int(size_y * ratio + 0.5)
    offset_x = torch.randint(0, size_x + (1 - cutout_size[0] % 2), size=[batch, 1, 1], device=x.device)
    offset_y = torch.randint(0, size_y + (1 - cutout_size[1] % 2), size=[batch, 1, 1], device=x.device)
    # Broadcastable index tensors replace torch.meshgrid: the result is the same,
    # but nothing depends on meshgrid's `indexing` default (which warns when omitted).
    grid_batch = torch.arange(batch, dtype=torch.long, device=x.device).view(-1, 1, 1)
    grid_x = torch.arange(cutout_size[0], dtype=torch.long, device=x.device).view(1, -1, 1)
    grid_y = torch.arange(cutout_size[1], dtype=torch.long, device=x.device).view(1, 1, -1)
    # centre the rectangle on the sampled offset and keep all indices in range
    grid_x = torch.clamp(grid_x + offset_x - cutout_size[0] // 2, min=0, max=size_x - 1)
    grid_y = torch.clamp(grid_y + offset_y - cutout_size[1] // 2, min=0, max=size_y - 1)
    mask = torch.ones(batch, size_x, size_y, dtype=x.dtype, device=x.device)
    mask[grid_batch, grid_x, grid_y] = 0
    x = x * mask.unsqueeze(1)
    return x
|
70
|
-
|
71
|
-
|
72
|
-
# Maps each policy name accepted by DiffAugment to the augmentation
# functions applied, in order, for that policy.
AUGMENT_FNS = dict(
    color=[rand_brightness, rand_saturation, rand_contrast],
    translation=[rand_translation],
    cutout=[rand_cutout],
)
|
File without changes
|