Myosotis-Researches 0.0.14__py3-none-any.whl → 0.0.15__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- myosotis_researches/CcGAN/train_128/DiffAugment_pytorch.py +76 -0
- myosotis_researches/CcGAN/train_128/__init__.py +0 -0
- myosotis_researches/CcGAN/train_128/eval_metrics.py +205 -0
- myosotis_researches/CcGAN/train_128/opts.py +87 -0
- myosotis_researches/CcGAN/train_128/pretrain_AE.py +268 -0
- myosotis_researches/CcGAN/train_128/pretrain_CNN_class.py +251 -0
- myosotis_researches/CcGAN/train_128/pretrain_CNN_regre.py +255 -0
- myosotis_researches/CcGAN/train_128/train_ccgan.py +303 -0
- myosotis_researches/CcGAN/train_128/train_cgan.py +254 -0
- myosotis_researches/CcGAN/train_128/train_cgan_concat.py +242 -0
- myosotis_researches/CcGAN/train_128/train_net_for_label_embed.py +181 -0
- myosotis_researches/CcGAN/train_128/utils.py +120 -0
- {myosotis_researches-0.0.14.dist-info → myosotis_researches-0.0.15.dist-info}/METADATA +1 -1
- {myosotis_researches-0.0.14.dist-info → myosotis_researches-0.0.15.dist-info}/RECORD +17 -5
- {myosotis_researches-0.0.14.dist-info → myosotis_researches-0.0.15.dist-info}/WHEEL +0 -0
- {myosotis_researches-0.0.14.dist-info → myosotis_researches-0.0.15.dist-info}/licenses/LICENSE +0 -0
- {myosotis_researches-0.0.14.dist-info → myosotis_researches-0.0.15.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,303 @@
|
|
1
|
+
import torch
|
2
|
+
import numpy as np
|
3
|
+
import os
|
4
|
+
import timeit
|
5
|
+
from PIL import Image
|
6
|
+
from torchvision.utils import save_image
|
7
|
+
import torch.cuda as cutorch
|
8
|
+
|
9
|
+
from utils import SimpleProgressBar, IMGs_dataset
|
10
|
+
from opts import parse_opts
|
11
|
+
from DiffAugment_pytorch import DiffAugment
|
12
|
+
|
13
|
+
# ----------------------------------------------------------------------
# Settings: module-level copies of the command-line options used below.
# ----------------------------------------------------------------------
args = parse_opts()

# GAN architecture, loss and optimisation schedule
gan_arch = args.GAN_arch
loss_type = args.loss_type_gan
niters = args.niters_gan
resume_niters = args.resume_niters_gan
dim_gan = args.dim_gan
lr_g = args.lr_g_gan
lr_d = args.lr_d_gan
save_niters_freq = args.save_niters_freq

# batch sizes and number of discriminator updates per generator update
batch_size_disc = args.batch_size_disc
batch_size_gene = args.batch_size_gene
num_D_steps = args.num_D_steps

visualize_freq = args.visualize_freq
num_workers = args.num_workers

# vicinity definition: "hard" (fixed radius) or soft weighting with a
# minimum non-zero weight
threshold_type = args.threshold_type
nonzero_soft_weight_threshold = args.nonzero_soft_weight_threshold

# image geometry and label range
num_channels = args.num_channels
img_size = args.img_size
max_label = args.max_label

# DiffAugment switch and augmentation policy
use_DiffAugment = args.gan_DiffAugment
policy = args.gan_DiffAugment_policy
|
43
|
+
|
44
|
+
|
45
|
+
def normalize_images(batch_images):
    '''
    Map raw pixel values to the [-1, 1] range expected by the networks.

    The input is divided by 255 and then shifted/scaled by (x - 0.5) / 0.5,
    so 0 maps to -1 and 255 maps to 1. Any object supporting elementwise
    arithmetic (e.g. a numpy array) is accepted; the result is a new object.
    '''
    scaled = batch_images / 255.0
    return (scaled - 0.5) / 0.5
|
50
|
+
|
51
|
+
|
52
|
+
def _ccgan_checkpoint_path(save_models_folder, niter):
    '''Return the path of the in-training CcGAN checkpoint written after `niter` iterations.'''
    return save_models_folder + "/CcGAN_{}_{}_nDsteps_{}_checkpoint_intrain/CcGAN_checkpoint_niters_{}.pth".format(gan_arch, threshold_type, num_D_steps, niter)


def _draw_target_labels(unique_train_labels, size, kernel_sigma, clip_label):
    '''
    Draw `size` conditioning labels: dataset labels plus Gaussian noise.

    Labels are sampled with replacement from `unique_train_labels` and
    perturbed by N(0, kernel_sigma^2) noise; with `clip_label` the result is
    clipped to [0, 1].

    Returns (labels_in_dataset, target_labels), both numpy arrays of length `size`.
    '''
    labels_in_dataset = np.random.choice(unique_train_labels, size=size, replace=True)
    target_labels = labels_in_dataset + np.random.normal(0, kernel_sigma, size)
    if clip_label:
        target_labels = np.clip(target_labels, 0.0, 1.0)
    return labels_in_dataset, target_labels


def _real_indices_in_vicinity(train_labels, target_label, kappa):
    '''Return indices of training labels inside the (hard or soft) vicinity of `target_label`.'''
    if threshold_type == "hard":
        return np.where(np.abs(train_labels - target_label) <= kappa)[0]
    # soft vicinity: invert the weight exp(-kappa*d^2) >= nonzero_soft_weight_threshold
    return np.where((train_labels - target_label) ** 2 <= -np.log(nonzero_soft_weight_threshold) / kappa)[0]


def _fake_label_bounds(target_label, kappa):
    '''Return (lb, ub): the vicinity of `target_label` intersected with [0, 1].'''
    if threshold_type == "hard":
        radius = kappa
    else:
        radius = np.sqrt(-np.log(nonzero_soft_weight_threshold) / kappa)
    return max(0.0, target_label - radius), min(target_label + radius, 1.0)


def train_ccgan(kernel_sigma, kappa, train_images, train_labels, netG, netD, net_y2h, save_images_folder, save_models_folder = None, clip_label=False):
    '''
    Train a continuous conditional GAN (CcGAN) with vicinal discriminator loss.

    Note that train_images are not normalized to [-1,1]; they are normalized
    per batch with normalize_images (presumably raw uint8 pixels -- confirm).

    kernel_sigma: std of the Gaussian noise added to sampled training labels.
    kappa: "hard" threshold -> half-width of the label vicinity;
           "soft" threshold -> decay rate of the weight exp(-kappa * d^2).
    train_labels: numpy array of labels, expected in [0, 1] (fake-label
           bounds are clamped to [0, 1] and asserted).
    netG, netD: generator and discriminator; both are optimised with Adam.
    net_y2h: label-embedding network; kept in eval mode and not optimised here.
    save_images_folder: directory receiving visualisation PNGs.
    save_models_folder: checkpoint directory; None disables resuming/saving.
    clip_label: clip every sampled target label (discriminator and generator
           steps) to [0, 1].

    Iteration counts, learning rates, batch sizes, loss type and the vicinity
    type are taken from the module-level settings.

    Returns (netG, netD).
    '''

    netG = netG.cuda()
    netD = netD.cuda()
    net_y2h = net_y2h.cuda()
    net_y2h.eval()

    optimizerG = torch.optim.Adam(netG.parameters(), lr=lr_g, betas=(0.5, 0.999))
    optimizerD = torch.optim.Adam(netD.parameters(), lr=lr_d, betas=(0.5, 0.999))

    if save_models_folder is not None and resume_niters > 0:
        checkpoint = torch.load(_ccgan_checkpoint_path(save_models_folder, resume_niters))
        netG.load_state_dict(checkpoint['netG_state_dict'])
        netD.load_state_dict(checkpoint['netD_state_dict'])
        optimizerG.load_state_dict(checkpoint['optimizerG_state_dict'])
        optimizerD.load_state_dict(checkpoint['optimizerD_state_dict'])
        torch.set_rng_state(checkpoint['rng_state'])

    unique_train_labels = np.unique(train_labels)

    # Fixed noise/labels for visualisation: n_row labels evenly spaced between
    # the 5-th and 95-th quantiles of the training labels, each repeated n_col times.
    n_row = 10
    n_col = n_row
    z_fixed = torch.randn(n_row*n_col, dim_gan, dtype=torch.float).cuda()
    start_label = np.quantile(train_labels, 0.05)
    end_label = np.quantile(train_labels, 0.95)
    selected_labels = np.linspace(start_label, end_label, num=n_row)
    y_fixed = np.repeat(selected_labels, n_col)
    print(y_fixed)
    y_fixed = torch.from_numpy(y_fixed).type(torch.float).view(-1,1).cuda()

    start_time = timeit.default_timer()
    for niter in range(resume_niters, niters):

        # The visualisation step below leaves netG in eval mode; switch back
        # before the discriminator step so its fakes come from train mode too.
        netG.train()

        ''' Train Discriminator '''
        for _ in range(num_D_steps):

            batch_target_labels_in_dataset, batch_target_labels = _draw_target_labels(unique_train_labels, batch_size_disc, kernel_sigma, clip_label)

            # indices of real images whose labels lie in the vicinity of each
            # target label, and fake-generation labels drawn from that vicinity
            batch_real_indx = np.zeros(batch_size_disc, dtype=int)
            batch_fake_labels = np.zeros(batch_size_disc)

            for j in range(batch_size_disc):
                indx_real_in_vicinity = _real_indices_in_vicinity(train_labels, batch_target_labels[j], kappa)

                # A large gap between consecutive unique labels can leave the
                # vicinity empty: redraw the noise for this label only.
                while len(indx_real_in_vicinity) < 1:
                    new_target = batch_target_labels_in_dataset[j] + np.random.normal(0, kernel_sigma, 1)[0]
                    if clip_label:
                        new_target = min(max(new_target, 0.0), 1.0)
                    batch_target_labels[j] = new_target
                    indx_real_in_vicinity = _real_indices_in_vicinity(train_labels, batch_target_labels[j], kappa)

                batch_real_indx[j] = np.random.choice(indx_real_in_vicinity, size=1)[0]

                lb, ub = _fake_label_bounds(batch_target_labels[j], kappa)
                assert 0.0 <= lb <= ub <= 1.0
                batch_fake_labels[j] = np.random.uniform(lb, ub, size=1)[0]

            ## draw real image/label batch from the training set
            batch_real_images = torch.from_numpy(normalize_images(train_images[batch_real_indx])).type(torch.float).cuda()
            batch_real_labels = torch.from_numpy(train_labels[batch_real_indx]).type(torch.float).cuda()

            ## generate the fake image batch (only used detached, so no graph is needed)
            batch_fake_labels = torch.from_numpy(batch_fake_labels).type(torch.float).cuda()
            z = torch.randn(batch_size_disc, dim_gan, dtype=torch.float).cuda()
            with torch.no_grad():
                batch_fake_images = netG(z, net_y2h(batch_fake_labels))

            ## target labels on gpu; the discriminator is conditioned on them for real and fake
            batch_target_labels = torch.from_numpy(batch_target_labels).type(torch.float).cuda()
            target_embeddings = net_y2h(batch_target_labels)

            ## vicinal weights
            if threshold_type == "soft":
                real_weights = torch.exp(-kappa*(batch_real_labels-batch_target_labels)**2)
                fake_weights = torch.exp(-kappa*(batch_fake_labels-batch_target_labels)**2)
            else:
                real_weights = torch.ones(batch_size_disc, dtype=torch.float).cuda()
                fake_weights = torch.ones(batch_size_disc, dtype=torch.float).cuda()

            # forward pass
            if use_DiffAugment:
                real_dis_out = netD(DiffAugment(batch_real_images, policy=policy), target_embeddings)
                fake_dis_out = netD(DiffAugment(batch_fake_images, policy=policy), target_embeddings)
            else:
                real_dis_out = netD(batch_real_images, target_embeddings)
                fake_dis_out = netD(batch_fake_images, target_embeddings)

            if loss_type == "vanilla":
                real_dis_out = torch.nn.Sigmoid()(real_dis_out)
                fake_dis_out = torch.nn.Sigmoid()(fake_dis_out)
                d_loss_real = - torch.log(real_dis_out+1e-20)
                d_loss_fake = - torch.log(1-fake_dis_out+1e-20)
            elif loss_type == "hinge":
                d_loss_real = torch.nn.ReLU()(1.0 - real_dis_out)
                d_loss_fake = torch.nn.ReLU()(1.0 + fake_dis_out)
            else:
                raise ValueError('Not supported loss type!!!')

            d_loss = torch.mean(real_weights.view(-1) * d_loss_real.view(-1)) + torch.mean(fake_weights.view(-1) * d_loss_fake.view(-1))

            optimizerD.zero_grad()
            d_loss.backward()
            optimizerD.step()

        ''' Train Generator '''
        _, batch_target_labels = _draw_target_labels(unique_train_labels, batch_size_gene, kernel_sigma, clip_label)
        batch_target_labels = torch.from_numpy(batch_target_labels).type(torch.float).cuda()
        target_embeddings = net_y2h(batch_target_labels)

        z = torch.randn(batch_size_gene, dim_gan, dtype=torch.float).cuda()
        batch_fake_images = netG(z, target_embeddings)

        # loss
        if use_DiffAugment:
            dis_out = netD(DiffAugment(batch_fake_images, policy=policy), target_embeddings)
        else:
            dis_out = netD(batch_fake_images, target_embeddings)
        if loss_type == "vanilla":
            dis_out = torch.nn.Sigmoid()(dis_out)
            g_loss = - torch.mean(torch.log(dis_out+1e-20))
        elif loss_type == "hinge":
            g_loss = - dis_out.mean()
        else:
            raise ValueError('Not supported loss type!!!')

        optimizerG.zero_grad()
        g_loss.backward()
        optimizerG.step()

        # print loss
        if (niter+1) % 20 == 0:
            print ("CcGAN,%s: [Iter %d/%d] [D loss: %.4e] [G loss: %.4e] [real prob: %.3f] [fake prob: %.3f] [Time: %.4f]" % (gan_arch, niter+1, niters, d_loss.item(), g_loss.item(), real_dis_out.mean().item(), fake_dis_out.mean().item(), timeit.default_timer()-start_time))

        if (niter+1) % visualize_freq == 0:
            netG.eval()
            with torch.no_grad():
                gen_imgs = netG(z_fixed, net_y2h(y_fixed))
                gen_imgs = gen_imgs.detach().cpu()
                save_image(gen_imgs.data, save_images_folder + '/{}.png'.format(niter+1), nrow=n_row, normalize=True)

        if save_models_folder is not None and ((niter+1) % save_niters_freq == 0 or (niter+1) == niters):
            save_file = _ccgan_checkpoint_path(save_models_folder, niter+1)
            os.makedirs(os.path.dirname(save_file), exist_ok=True)
            torch.save({
                    'netG_state_dict': netG.state_dict(),
                    'netD_state_dict': netD.state_dict(),
                    'optimizerG_state_dict': optimizerG.state_dict(),
                    'optimizerD_state_dict': optimizerD.state_dict(),
                    'rng_state': torch.get_rng_state()
            }, save_file)
    #end for niter
    return netG, netD
|
257
|
+
|
258
|
+
|
259
|
+
def sample_ccgan_given_labels(netG, net_y2h, labels, batch_size = 500, to_numpy=True, denorm=True, verbose=True):
    '''
    Generate one fake image per requested label with a trained CcGAN generator.

    netG: pretrained generator network, called as netG(z, net_y2h(y))
    net_y2h: label-embedding network whose output conditions netG
    labels: float. normalized labels.
    batch_size: images generated per forward pass (capped at len(labels))
    to_numpy: return the images as a numpy array instead of a CPU tensor
    denorm: map generator outputs from [-1, 1] to uint8 values in [0, 255]
    verbose: display a SimpleProgressBar while sampling

    Returns (fake_images, fake_labels), both truncated to len(labels).
    '''
    nfake = len(labels)
    batch_size = min(batch_size, nfake)

    # Append the first batch_size labels so every batch, including the
    # last one, is full; the surplus is cut off before returning.
    padded_labels = np.concatenate((labels, labels[0:batch_size]))

    netG = netG.cuda()
    netG.eval()
    net_y2h = net_y2h.cuda()
    net_y2h.eval()

    chunks = []
    with torch.no_grad():
        progress = SimpleProgressBar() if verbose else None
        done = 0
        while done < nfake:
            z = torch.randn(batch_size, dim_gan, dtype=torch.float).cuda()
            y = torch.from_numpy(padded_labels[done:(done + batch_size)]).type(torch.float).view(-1, 1).cuda()
            images = netG(z, net_y2h(y))
            if denorm:  # denorm imgs to save memory
                assert images.max().item() <= 1.0 and images.min().item() >= -1.0
                images = ((images * 0.5 + 0.5) * 255.0).type(torch.uint8)
            chunks.append(images.cpu())
            done += batch_size
            if progress is not None:
                progress.update(min(float(done) / nfake, 1) * 100)

    fake_images = torch.cat(chunks, dim=0)[0:nfake]
    if to_numpy:
        fake_images = fake_images.numpy()

    return fake_images, padded_labels[0:nfake]
|
@@ -0,0 +1,254 @@
|
|
1
|
+
|
2
|
+
import torch
|
3
|
+
import torch.nn as nn
|
4
|
+
from torchvision.utils import save_image
|
5
|
+
import numpy as np
|
6
|
+
import os
|
7
|
+
import timeit
|
8
|
+
|
9
|
+
from utils import IMGs_dataset, SimpleProgressBar
|
10
|
+
from opts import parse_opts
|
11
|
+
from DiffAugment_pytorch import DiffAugment
|
12
|
+
|
13
|
+
# ----------------------------------------------------------------------
# Settings: module-level copies of the command-line options used below.
# ----------------------------------------------------------------------
args = parse_opts()

# GAN architecture, loss and optimisation schedule
gan_arch = args.GAN_arch
loss_type = args.loss_type_gan
niters = args.niters_gan
resume_niters = args.resume_niters_gan
dim_gan = args.dim_gan
lr_g = args.lr_g_gan
lr_d = args.lr_d_gan
save_niters_freq = args.save_niters_freq
num_D_steps = args.num_D_steps

# one batch size is shared by the generator and discriminator updates
batch_size = min(args.batch_size_disc, args.batch_size_gene)
num_classes = args.cGAN_num_classes

visualize_freq = args.visualize_freq
num_workers = args.num_workers

# image geometry
NC = args.num_channels
IMG_SIZE = args.img_size

# DiffAugment switch and augmentation policy
use_DiffAugment = args.gan_DiffAugment
policy = args.gan_DiffAugment_policy
|
38
|
+
|
39
|
+
|
40
|
+
def _cgan_checkpoint_path(save_models_folder, niter):
    '''Return the path of the in-training cGAN checkpoint written after `niter` iterations.'''
    return save_models_folder + "/cGAN_{}_nDsteps_{}_checkpoint_intrain/cGAN_checkpoint_niters_{}.pth".format(gan_arch, num_D_steps, niter)


def _infinite_batches(dataloader):
    '''
    Yield batches from `dataloader` forever, starting a fresh (reshuffled)
    pass whenever one is exhausted.

    The caller must ensure `dataloader` produces at least one batch per pass;
    otherwise this generator would loop without yielding.
    '''
    while True:
        for batch in dataloader:
            yield batch


def train_cgan(images, labels, netG, netD, save_images_folder, save_models_folder = None):
    '''
    Train a class-conditional GAN (cGAN).

    images, labels: training data wrapped in IMGs_dataset(normalize=True);
        labels are used as integer class labels (cast to torch.long).
    netG: generator called as netG(z, class_labels).
    netD: discriminator called as netD(images, class_labels).
    save_images_folder: directory receiving visualisation PNGs.
    save_models_folder: checkpoint directory; None disables resuming/saving.

    Iteration counts, learning rates, batch size and loss type come from the
    module-level settings.

    Raises ValueError if there are fewer training images than one batch, or if
    the configured loss type is not "vanilla" or "hinge".

    Returns (netG, netD).
    '''

    netG = netG.cuda()
    netD = netD.cuda()

    optimizerG = torch.optim.Adam(netG.parameters(), lr=lr_g, betas=(0.5, 0.999))
    optimizerD = torch.optim.Adam(netD.parameters(), lr=lr_d, betas=(0.5, 0.999))

    trainset = IMGs_dataset(images, labels, normalize=True)
    # drop_last guarantees every batch holds exactly batch_size samples
    train_dataloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=num_workers, drop_last=True)
    if len(train_dataloader) == 0:
        raise ValueError("cGAN training needs at least batch_size={} images, got {}".format(batch_size, len(trainset)))
    # np.int was removed in NumPy 1.24; the builtin int is the replacement
    unique_labels = np.unique(labels).astype(int)

    if save_models_folder is not None and resume_niters > 0:
        checkpoint = torch.load(_cgan_checkpoint_path(save_models_folder, resume_niters))
        netG.load_state_dict(checkpoint['netG_state_dict'])
        netD.load_state_dict(checkpoint['netD_state_dict'])
        optimizerG.load_state_dict(checkpoint['optimizerG_state_dict'])
        optimizerD.load_state_dict(checkpoint['optimizerD_state_dict'])
        torch.set_rng_state(checkpoint['rng_state'])

    # Fixed noise/labels for visualisation: n_row classes spread over the
    # sorted unique labels, each repeated across a row of n_row images.
    # The step is at least 1 so fewer than n_row classes still show distinct rows.
    n_row = 10
    z_fixed = torch.randn(n_row**2, dim_gan, dtype=torch.float).cuda()
    indx_step_size = max(1, len(unique_labels)//n_row)
    selected_labels = np.array([unique_labels[min(i*indx_step_size, len(unique_labels)-1)] for i in range(n_row)])
    y_fixed = torch.from_numpy(np.repeat(selected_labels, n_row)).type(torch.long).cuda()

    batches = _infinite_batches(train_dataloader)

    start_time = timeit.default_timer()
    for niter in range(resume_niters, niters):

        # Train Generator: maximize log(D(G(z)))
        netG.train()

        # only the labels of this batch are used: they condition the fakes
        _, batch_fake_labels = next(batches)
        batch_fake_labels = batch_fake_labels.type(torch.long).cuda()

        # Sample noise as generator input and generate fake images
        z = torch.randn(batch_size, dim_gan, dtype=torch.float).cuda()
        batch_fake_images = netG(z, batch_fake_labels)

        # Loss measures generator's ability to fool the discriminator
        if use_DiffAugment:
            dis_out = netD(DiffAugment(batch_fake_images, policy=policy), batch_fake_labels)
        else:
            dis_out = netD(batch_fake_images, batch_fake_labels)

        if loss_type == "vanilla":
            dis_out = torch.nn.Sigmoid()(dis_out)
            g_loss = - torch.mean(torch.log(dis_out+1e-20))
        elif loss_type == "hinge":
            g_loss = - dis_out.mean()
        else:
            raise ValueError('Not supported loss type!!!')

        optimizerG.zero_grad()
        g_loss.backward()
        optimizerG.step()

        # Train Discriminator: maximize log(D(x)) + log(1 - D(G(z)))
        for _ in range(num_D_steps):

            batch_train_images, batch_train_labels = next(batches)
            batch_train_images = batch_train_images.type(torch.float).cuda()
            batch_train_labels = batch_train_labels.type(torch.long).cuda()

            # Fake images are scored with the labels they were generated from,
            # not with the labels of this (unrelated) real batch.
            if use_DiffAugment:
                real_dis_out = netD(DiffAugment(batch_train_images, policy=policy), batch_train_labels)
                fake_dis_out = netD(DiffAugment(batch_fake_images.detach(), policy=policy), batch_fake_labels)
            else:
                real_dis_out = netD(batch_train_images, batch_train_labels)
                fake_dis_out = netD(batch_fake_images.detach(), batch_fake_labels)

            if loss_type == "vanilla":
                real_dis_out = torch.nn.Sigmoid()(real_dis_out)
                fake_dis_out = torch.nn.Sigmoid()(fake_dis_out)
                d_loss_real = - torch.log(real_dis_out+1e-20)
                d_loss_fake = - torch.log(1-fake_dis_out+1e-20)
            elif loss_type == "hinge":
                d_loss_real = torch.nn.ReLU()(1.0 - real_dis_out)
                d_loss_fake = torch.nn.ReLU()(1.0 + fake_dis_out)
            else:
                raise ValueError('Not supported loss type!!!')
            d_loss = (d_loss_real + d_loss_fake).mean()

            optimizerD.zero_grad()
            d_loss.backward()
            optimizerD.step()

        if (niter+1)%20 == 0:
            print ("cGAN-%s: [Iter %d/%d] [D loss: %.4f] [G loss: %.4f] [D out real:%.4f] [D out fake:%.4f] [Time: %.4f]" % (gan_arch, niter+1, niters, d_loss.item(), g_loss.item(), real_dis_out.mean().item(),fake_dis_out.mean().item(), timeit.default_timer()-start_time))

        if (niter+1) % visualize_freq == 0:
            netG.eval()
            with torch.no_grad():
                gen_imgs = netG(z_fixed, y_fixed)
                gen_imgs = gen_imgs.detach()
            save_image(gen_imgs.data, save_images_folder +'/{}.png'.format(niter+1), nrow=n_row, normalize=True)

        if save_models_folder is not None and ((niter+1) % save_niters_freq == 0 or (niter+1) == niters):
            save_file = _cgan_checkpoint_path(save_models_folder, niter+1)
            os.makedirs(os.path.dirname(save_file), exist_ok=True)
            torch.save({
                    'netG_state_dict': netG.state_dict(),
                    'netD_state_dict': netD.state_dict(),
                    'optimizerG_state_dict': optimizerG.state_dict(),
                    'optimizerD_state_dict': optimizerD.state_dict(),
                    'rng_state': torch.get_rng_state()
            }, save_file)
    #end for niter

    return netG, netD
|
193
|
+
|
194
|
+
|
195
|
+
def _labels_to_class_labels(given_labels, class_cutoff_points):
    '''
    Map raw labels to class indices defined by ascending cutoff points c[0..K].

    Class k covers [c[k], c[k+1]) for k < K-1; the last class covers the
    closed interval [c[K-1], c[K]].

    Raises ValueError when fewer than two cutoff points are given, or when any
    label lies outside [c[0], c[K]] (NaN included).
    '''
    cutoffs = np.asarray(class_cutoff_points, dtype=float)
    if cutoffs.ndim != 1 or len(cutoffs) < 2:
        raise ValueError("class_cutoff_points must be a 1-D sequence of at least two values")
    num_classes = len(cutoffs) - 1

    labels_arr = np.asarray(given_labels, dtype=float)
    # written as a positive test so NaN labels are rejected as well
    in_range = (labels_arr >= cutoffs[0]) & (labels_arr <= cutoffs[-1])
    if not np.all(in_range):
        raise ValueError("given labels outside the cutoff range [{}, {}]: {}".format(cutoffs[0], cutoffs[-1], labels_arr[~in_range][:5]))

    class_labels = np.searchsorted(cutoffs, labels_arr, side="right") - 1
    # a label equal to the last cutoff belongs to the last (closed) interval
    return np.minimum(class_labels, num_classes - 1)


def sample_cgan_given_labels(netG, given_labels, class_cutoff_points, batch_size = 500, denorm=True, to_numpy=True, verbose=True):
    '''
    Generate one fake image per raw label with a trained class-conditional GAN.

    given_labels: a numpy array; raw label without any normalization; not class label
    class_cutoff_points: ascending cutoff points that determine the class
        membership of a given label (class k covers [c[k], c[k+1]), the last
        class also includes its right end point)
    batch_size: images generated per forward pass (capped at len(given_labels))
    denorm: map generator outputs from [-1, 1] to uint8 values in [0, 255]
    to_numpy: return the images as a numpy array instead of a CPU tensor
    verbose: display a SimpleProgressBar while sampling

    Raises ValueError for empty input or labels outside the cutoff range.

    Returns (fake_images, given_labels).
    '''

    given_class_labels = _labels_to_class_labels(given_labels, class_cutoff_points)
    nfake = len(given_class_labels)
    if nfake == 0:
        raise ValueError("given_labels must contain at least one label")

    batch_size = min(batch_size, nfake)
    # Pad with the first batch_size labels so the final batch is always full;
    # the surplus images are removed below.
    given_class_labels = np.concatenate((given_class_labels, given_class_labels[0:batch_size]))

    fake_images = []
    netG = netG.cuda()
    netG.eval()
    with torch.no_grad():
        if verbose:
            pb = SimpleProgressBar()
        tmp = 0
        while tmp < nfake:
            z = torch.randn(batch_size, dim_gan, dtype=torch.float).cuda()
            labels = torch.from_numpy(given_class_labels[tmp:(tmp+batch_size)]).type(torch.long).cuda()
            batch_fake_images = netG(z, labels)
            if denorm: #denorm imgs to save memory
                assert batch_fake_images.max().item()<=1.0 and batch_fake_images.min().item()>=-1.0
                batch_fake_images = batch_fake_images*0.5+0.5
                batch_fake_images = batch_fake_images*255.0
                batch_fake_images = batch_fake_images.type(torch.uint8)
            fake_images.append(batch_fake_images.detach().cpu())
            tmp += batch_size
            if verbose:
                pb.update(min(float(tmp)/nfake, 1)*100)

    fake_images = torch.cat(fake_images, dim=0)
    #remove extra entries
    fake_images = fake_images[0:nfake]

    if to_numpy:
        fake_images = fake_images.numpy()

    return fake_images, given_labels
|