Myosotis-Researches 0.0.14__py3-none-any.whl → 0.0.15__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,303 @@
1
+ import torch
2
+ import numpy as np
3
+ import os
4
+ import timeit
5
+ from PIL import Image
6
+ from torchvision.utils import save_image
7
+ import torch.cuda as cutorch
8
+
9
+ from utils import SimpleProgressBar, IMGs_dataset
10
+ from opts import parse_opts
11
+ from DiffAugment_pytorch import DiffAugment
12
+
13
# ---------------------------------------------------------------------------
# Settings: every hyper-parameter below is read once from parse_opts()
# ---------------------------------------------------------------------------
args = parse_opts()

# architecture, loss and optimisation schedule
gan_arch, loss_type = args.GAN_arch, args.loss_type_gan
niters, resume_niters = args.niters_gan, args.resume_niters_gan
dim_gan = args.dim_gan
lr_g, lr_d = args.lr_g_gan, args.lr_d_gan
save_niters_freq = args.save_niters_freq

# batch sizes for the discriminator / generator updates, and D updates per iteration
batch_size_disc, batch_size_gene = args.batch_size_disc, args.batch_size_gene
num_D_steps = args.num_D_steps

# logging and data loading
visualize_freq = args.visualize_freq
num_workers = args.num_workers

# vicinity parameters ("hard" or soft weighting)
threshold_type = args.threshold_type
nonzero_soft_weight_threshold = args.nonzero_soft_weight_threshold

# image / label geometry
num_channels, img_size = args.num_channels, args.img_size
max_label = args.max_label

# optional DiffAugment
use_DiffAugment, policy = args.gan_DiffAugment, args.gan_DiffAugment_policy
43
+
44
+
45
+ ## normalize images
46
def normalize_images(batch_images):
    '''
    Rescale raw pixel intensities from [0, 255] to [-1, 1].

    Only arithmetic operators are used, so numpy arrays (as used by train_ccgan)
    and scalars are both accepted.
    '''
    unit_range = batch_images / 255.0      # [0, 255] -> [0, 1]
    return (unit_range - 0.5) / 0.5        # [0, 1]   -> [-1, 1]
50
+
51
+
52
def _vicinity_indices(train_labels, target_label, kappa):
    '''
    Return the indices of training samples whose labels lie in the vicinity of target_label.

    hard threshold: |y - target| <= kappa
    soft threshold (SVDL): exp(-kappa*(y - target)^2) >= nonzero_soft_weight_threshold,
        i.e. (y - target)^2 <= -log(nonzero_soft_weight_threshold)/kappa
    '''
    if threshold_type == "hard":
        return np.where(np.abs(train_labels - target_label) <= kappa)[0]
    # reverse the soft weight function to obtain a squared-distance bound
    return np.where((train_labels - target_label)**2 <= -np.log(nonzero_soft_weight_threshold)/kappa)[0]


def _vicinity_half_width(kappa):
    '''Half-width of the label window around a target label from which fake labels are drawn.'''
    if threshold_type == "hard":
        return kappa
    return np.sqrt(-np.log(nonzero_soft_weight_threshold)/kappa)


def train_ccgan(kernel_sigma, kappa, train_images, train_labels, netG, netD, net_y2h, save_images_folder, save_models_folder = None, clip_label=False):
    '''
    Train a CcGAN generator/discriminator pair with vicinal (hard or soft) label sampling.

    Note that train_images are not normalized to [-1,1]; normalize_images() is applied to
    each real batch (it divides by 255, so raw [0, 255] pixel values are presumably expected).

    Args:
        kernel_sigma: std of the Gaussian noise added to labels drawn from the training set.
        kappa: hard threshold -> half-width of the label vicinity;
               soft threshold -> decay rate of the weight exp(-kappa*(y - target)^2).
        train_images: array of training images, indexed in the same order as train_labels.
        train_labels: 1-D numpy array of labels. The fake-label window is clamped to [0, 1],
                      so labels are presumably normalized to [0, 1].
        netG, netD: generator and discriminator; both are moved to the GPU.
        net_y2h: label-embedding network that conditions netG and netD; it is put in eval
                 mode and its parameters are not passed to either optimizer.
        save_images_folder: directory where '<iter>.png' visualization grids are written.
        save_models_folder: if given, checkpoints are written there every save_niters_freq
                            iterations (and read from there when resume_niters > 0).
        clip_label: if True, target labels that have to be re-drawn are clipped to [0, 1].

    Returns:
        (netG, netD) after training.
    '''

    netG = netG.cuda()
    netD = netD.cuda()
    net_y2h = net_y2h.cuda()
    net_y2h.eval()

    optimizerG = torch.optim.Adam(netG.parameters(), lr=lr_g, betas=(0.5, 0.999))
    optimizerD = torch.optim.Adam(netD.parameters(), lr=lr_d, betas=(0.5, 0.999))

    if save_models_folder is not None and resume_niters>0:
        save_file = save_models_folder + "/CcGAN_{}_{}_nDsteps_{}_checkpoint_intrain/CcGAN_checkpoint_niters_{}.pth".format(gan_arch, threshold_type, num_D_steps, resume_niters)
        checkpoint = torch.load(save_file)
        netG.load_state_dict(checkpoint['netG_state_dict'])
        netD.load_state_dict(checkpoint['netD_state_dict'])
        optimizerG.load_state_dict(checkpoint['optimizerG_state_dict'])
        optimizerD.load_state_dict(checkpoint['optimizerD_state_dict'])
        torch.set_rng_state(checkpoint['rng_state'])
    #end if

    #################
    unique_train_labels = np.unique(train_labels)

    # printed images with labels between the 5-th quantile and 95-th quantile of training labels;
    # row i of the grid uses selected_labels[i] for all n_col images
    n_row=10; n_col = n_row
    z_fixed = torch.randn(n_row*n_col, dim_gan, dtype=torch.float).cuda()
    start_label = np.quantile(train_labels, 0.05)
    end_label = np.quantile(train_labels, 0.95)
    selected_labels = np.linspace(start_label, end_label, num=n_row)
    y_fixed = np.repeat(selected_labels, n_col)
    print(y_fixed)
    y_fixed = torch.from_numpy(y_fixed).type(torch.float).view(-1,1).cuda()

    # width of the window from which fake labels are drawn; loop invariant
    half_width = _vicinity_half_width(kappa)

    start_time = timeit.default_timer()
    for niter in range(resume_niters, niters):

        # the visualization step below switches netG to eval mode; restore train mode so
        # the D steps of this iteration do not generate fakes with eval-mode layers
        netG.train()

        ''' Train Discriminator '''
        for _ in range(num_D_steps):

            ## randomly draw batch_size_disc y's from unique_train_labels
            batch_target_labels_in_dataset = np.random.choice(unique_train_labels, size=batch_size_disc, replace=True)
            ## add Gaussian noise; we estimate image distribution conditional on these labels
            batch_epsilons = np.random.normal(0, kernel_sigma, batch_size_disc)
            batch_target_labels = batch_target_labels_in_dataset + batch_epsilons

            ## find index of real images with labels in the vicinity of batch_target_labels
            ## generate labels for fake image generation; these labels are also in the vicinity of batch_target_labels
            batch_real_indx = np.zeros(batch_size_disc, dtype=int) # index of real images whose labels are in the vicinity
            batch_fake_labels = np.zeros(batch_size_disc)

            for j in range(batch_size_disc):
                ## index for real images
                indx_real_in_vicinity = _vicinity_indices(train_labels, batch_target_labels[j], kappa)

                ## if the max gap between two consecutive ordered unique labels is large, it is possible that len(indx_real_in_vicinity)<1
                while len(indx_real_in_vicinity)<1:
                    batch_target_labels[j] = batch_target_labels_in_dataset[j] + np.random.normal(0, kernel_sigma)
                    if clip_label:
                        # clip only the re-drawn label; earlier entries already have
                        # their real index and fake label sampled from their own value
                        batch_target_labels[j] = np.clip(batch_target_labels[j], 0.0, 1.0)
                    indx_real_in_vicinity = _vicinity_indices(train_labels, batch_target_labels[j], kappa)
                #end while len(indx_real_in_vicinity)<1

                batch_real_indx[j] = np.random.choice(indx_real_in_vicinity, size=1)[0]

                ## labels for fake images generation, clamped to [0, 1]
                lb = max(0.0, batch_target_labels[j] - half_width)
                ub = min(batch_target_labels[j] + half_width, 1.0)
                assert 0.0 <= lb <= ub <= 1.0
                batch_fake_labels[j] = np.random.uniform(lb, ub, size=1)[0]
            #end for j

            ## draw real image/label batch from the training set
            batch_real_images = torch.from_numpy(normalize_images(train_images[batch_real_indx]))
            batch_real_images = batch_real_images.type(torch.float).cuda()
            batch_real_labels = train_labels[batch_real_indx]
            batch_real_labels = torch.from_numpy(batch_real_labels).type(torch.float).cuda()

            ## generate the fake image batch
            batch_fake_labels = torch.from_numpy(batch_fake_labels).type(torch.float).cuda()
            z = torch.randn(batch_size_disc, dim_gan, dtype=torch.float).cuda()
            batch_fake_images = netG(z, net_y2h(batch_fake_labels))

            ## target labels on gpu
            batch_target_labels = torch.from_numpy(batch_target_labels).type(torch.float).cuda()

            ## weight vector
            if threshold_type == "soft":
                real_weights = torch.exp(-kappa*(batch_real_labels-batch_target_labels)**2).cuda()
                fake_weights = torch.exp(-kappa*(batch_fake_labels-batch_target_labels)**2).cuda()
            else:
                real_weights = torch.ones(batch_size_disc, dtype=torch.float).cuda()
                fake_weights = torch.ones(batch_size_disc, dtype=torch.float).cuda()
            #end if threshold type

            # forward pass
            if use_DiffAugment:
                real_dis_out = netD(DiffAugment(batch_real_images, policy=policy), net_y2h(batch_target_labels))
                fake_dis_out = netD(DiffAugment(batch_fake_images.detach(), policy=policy), net_y2h(batch_target_labels))
            else:
                real_dis_out = netD(batch_real_images, net_y2h(batch_target_labels))
                fake_dis_out = netD(batch_fake_images.detach(), net_y2h(batch_target_labels))

            if loss_type == "vanilla":
                real_dis_out = torch.nn.Sigmoid()(real_dis_out)
                fake_dis_out = torch.nn.Sigmoid()(fake_dis_out)
                d_loss_real = - torch.log(real_dis_out+1e-20)
                d_loss_fake = - torch.log(1-fake_dis_out+1e-20)
            elif loss_type == "hinge":
                d_loss_real = torch.nn.ReLU()(1.0 - real_dis_out)
                d_loss_fake = torch.nn.ReLU()(1.0 + fake_dis_out)
            else:
                raise ValueError('Not supported loss type!!!')

            d_loss = torch.mean(real_weights.view(-1) * d_loss_real.view(-1)) + torch.mean(fake_weights.view(-1) * d_loss_fake.view(-1))

            optimizerD.zero_grad()
            d_loss.backward()
            optimizerD.step()
        #end for step_D_index

        ''' Train Generator '''
        # generate fake images
        ## randomly draw batch_size_gene y's from unique_train_labels
        batch_target_labels_in_dataset = np.random.choice(unique_train_labels, size=batch_size_gene, replace=True)
        ## add Gaussian noise; we estimate image distribution conditional on these labels
        batch_epsilons = np.random.normal(0, kernel_sigma, batch_size_gene)
        batch_target_labels = batch_target_labels_in_dataset + batch_epsilons
        batch_target_labels = torch.from_numpy(batch_target_labels).type(torch.float).cuda()

        z = torch.randn(batch_size_gene, dim_gan, dtype=torch.float).cuda()
        batch_fake_images = netG(z, net_y2h(batch_target_labels))

        # loss
        if use_DiffAugment:
            dis_out = netD(DiffAugment(batch_fake_images, policy=policy), net_y2h(batch_target_labels))
        else:
            dis_out = netD(batch_fake_images, net_y2h(batch_target_labels))
        if loss_type == "vanilla":
            dis_out = torch.nn.Sigmoid()(dis_out)
            g_loss = - torch.mean(torch.log(dis_out+1e-20))
        elif loss_type == "hinge":
            g_loss = - dis_out.mean()
        else:
            raise ValueError('Not supported loss type!!!')

        # backward
        optimizerG.zero_grad()
        g_loss.backward()
        optimizerG.step()

        # print loss (real/fake outputs are sigmoid probabilities only for the vanilla loss)
        if (niter+1) % 20 == 0:
            print ("CcGAN,%s: [Iter %d/%d] [D loss: %.4e] [G loss: %.4e] [real prob: %.3f] [fake prob: %.3f] [Time: %.4f]" % (gan_arch, niter+1, niters, d_loss.item(), g_loss.item(), real_dis_out.mean().item(), fake_dis_out.mean().item(), timeit.default_timer()-start_time))

        if (niter+1) % visualize_freq == 0:
            netG.eval()
            with torch.no_grad():
                gen_imgs = netG(z_fixed, net_y2h(y_fixed))
                gen_imgs = gen_imgs.detach().cpu()
                save_image(gen_imgs.data, save_images_folder + '/{}.png'.format(niter+1), nrow=n_row, normalize=True)

        if save_models_folder is not None and ((niter+1) % save_niters_freq == 0 or (niter+1) == niters):
            save_file = save_models_folder + "/CcGAN_{}_{}_nDsteps_{}_checkpoint_intrain/CcGAN_checkpoint_niters_{}.pth".format(gan_arch, threshold_type, num_D_steps, niter+1)
            os.makedirs(os.path.dirname(save_file), exist_ok=True)
            torch.save({
                    'netG_state_dict': netG.state_dict(),
                    'netD_state_dict': netD.state_dict(),
                    'optimizerG_state_dict': optimizerG.state_dict(),
                    'optimizerD_state_dict': optimizerD.state_dict(),
                    'rng_state': torch.get_rng_state()
            }, save_file)
    #end for niter
    return netG, netD
257
+
258
+
259
def sample_ccgan_given_labels(netG, net_y2h, labels, batch_size = 500, to_numpy=True, denorm=True, verbose=True):
    '''
    Draw one fake image per requested label from a trained CcGAN generator.

    netG: pretrained generator network, called as netG(z, net_y2h(y)).
    net_y2h: label-embedding network applied to each label batch before it reaches netG.
    labels: float. normalized labels; one image is generated per entry.
    batch_size: images per forward pass (capped at len(labels)).
    to_numpy: return the images as a numpy array rather than a CPU tensor.
    denorm: convert generator outputs from [-1, 1] to uint8 values in [0, 255] (truncating cast).
    verbose: display a SimpleProgressBar.

    Returns (fake_images, fake_labels), where fake_labels[i] is the label used for fake_images[i].
    '''
    nfake = len(labels)
    batch_size = min(batch_size, nfake)

    # pad with the first batch_size labels so the last slice is always a full batch
    padded_labels = np.concatenate((labels, labels[0:batch_size]))

    netG = netG.cuda()
    netG.eval()
    net_y2h = net_y2h.cuda()
    net_y2h.eval()

    chunks = []
    with torch.no_grad():
        progress = SimpleProgressBar() if verbose else None
        done = 0
        while done < nfake:
            noise = torch.randn(batch_size, dim_gan, dtype=torch.float).cuda()
            y = padded_labels[done:(done + batch_size)]
            y = torch.from_numpy(y).type(torch.float).view(-1, 1).cuda()
            images = netG(noise, net_y2h(y))
            if denorm:
                # denormalize immediately so only compact uint8 images are kept in memory
                assert images.max().item() <= 1.0 and images.min().item() >= -1.0
                images = ((images * 0.5 + 0.5) * 255.0).type(torch.uint8)
            chunks.append(images.cpu())
            done += batch_size
            if verbose:
                progress.update(min(float(done) / nfake, 1) * 100)

    # discard the padded extras
    fake_images = torch.cat(chunks, dim=0)[0:nfake]
    fake_labels = padded_labels[0:nfake]

    if to_numpy:
        fake_images = fake_images.numpy()
    return fake_images, fake_labels
@@ -0,0 +1,254 @@
1
+
2
+ import torch
3
+ import torch.nn as nn
4
+ from torchvision.utils import save_image
5
+ import numpy as np
6
+ import os
7
+ import timeit
8
+
9
+ from utils import IMGs_dataset, SimpleProgressBar
10
+ from opts import parse_opts
11
+ from DiffAugment_pytorch import DiffAugment
12
+
13
# ---------------------------------------------------------------------------
# Settings: every hyper-parameter below is read once from parse_opts()
# ---------------------------------------------------------------------------
args = parse_opts()

# loss and optimisation schedule
loss_type = args.loss_type_gan
niters, resume_niters = args.niters_gan, args.resume_niters_gan
dim_gan = args.dim_gan
lr_g, lr_d = args.lr_g_gan, args.lr_d_gan
save_niters_freq = args.save_niters_freq
# one batch size shared by G and D updates: the smaller of the two configured sizes
batch_size = min(args.batch_size_disc, args.batch_size_gene)
num_classes = args.cGAN_num_classes
gan_arch = args.GAN_arch
num_D_steps = args.num_D_steps

# logging and data loading
visualize_freq = args.visualize_freq
num_workers = args.num_workers

# image geometry
NC, IMG_SIZE = args.num_channels, args.img_size

# optional DiffAugment
use_DiffAugment, policy = args.gan_DiffAugment, args.gan_DiffAugment_policy
38
+
39
+
40
def train_cgan(images, labels, netG, netD, save_images_folder, save_models_folder = None):
    '''
    Train a class-conditional GAN (cGAN).

    Each iteration performs one generator update followed by num_D_steps discriminator updates.

    Args:
        images, labels: training data, wrapped as IMGs_dataset(images, labels, normalize=True);
                        labels are cast to torch.long, so they are presumably integer class indices.
        netG, netD: generator / discriminator, both called as net(x, class_labels); moved to the GPU.
        save_images_folder: directory where '<iter>.png' visualization grids are written.
        save_models_folder: if given, checkpoints are written there every save_niters_freq
                            iterations (and read from there when resume_niters > 0).

    Returns:
        (netG, netD) after training.

    Raises:
        ValueError: if the dataset holds fewer than batch_size samples, or loss_type is unknown.
    '''

    netG = netG.cuda()
    netD = netD.cuda()

    optimizerG = torch.optim.Adam(netG.parameters(), lr=lr_g, betas=(0.5, 0.999))
    optimizerD = torch.optim.Adam(netD.parameters(), lr=lr_d, betas=(0.5, 0.999))

    trainset = IMGs_dataset(images, labels, normalize=True)
    # drop_last=True guarantees every batch holds exactly batch_size samples
    train_dataloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=num_workers, drop_last=True)
    if len(train_dataloader) == 0:
        raise ValueError("cGAN training needs at least batch_size={} samples, got {}".format(batch_size, len(trainset)))
    # np.int was removed in NumPy 1.24; the builtin int is the documented replacement
    unique_labels = np.unique(np.asarray(labels)).astype(int)

    if save_models_folder is not None and resume_niters>0:
        save_file = save_models_folder + "/cGAN_{}_nDsteps_{}_checkpoint_intrain/cGAN_checkpoint_niters_{}.pth".format(gan_arch, num_D_steps, resume_niters)
        checkpoint = torch.load(save_file)
        netG.load_state_dict(checkpoint['netG_state_dict'])
        netD.load_state_dict(checkpoint['netD_state_dict'])
        optimizerG.load_state_dict(checkpoint['optimizerG_state_dict'])
        optimizerD.load_state_dict(checkpoint['optimizerD_state_dict'])
        torch.set_rng_state(checkpoint['rng_state'])
    #end if

    # fixed noise and class labels for the visualization grid: row i shows class selected_labels[i];
    # the step is at least 1 and indices are capped so fewer than n_row classes still work
    n_row=10
    z_fixed = torch.randn(n_row**2, dim_gan, dtype=torch.float).cuda()
    indx_step_size = max(1, len(unique_labels)//n_row)
    selected_indx = np.minimum(np.arange(n_row)*indx_step_size, len(unique_labels)-1)
    selected_labels = unique_labels[selected_indx]
    y_fixed = np.repeat(selected_labels, n_row)
    y_fixed = torch.from_numpy(y_fixed).type(torch.long).cuda()

    dataloader_iter = iter(train_dataloader)

    def next_batch():
        # return the next (images, labels) batch, starting a new reshuffled epoch when exhausted
        nonlocal dataloader_iter
        try:
            return next(dataloader_iter)
        except StopIteration:
            dataloader_iter = iter(train_dataloader)
            return next(dataloader_iter)

    start_time = timeit.default_timer()
    for niter in range(resume_niters, niters):

        # ---- Train Generator: maximize log(D(G(z))) ----
        netG.train()

        # only the labels are needed to condition the generator
        _, batch_fake_labels = next_batch()
        batch_fake_labels = batch_fake_labels.type(torch.long).cuda()

        # Sample noise as generator input and generate fake images
        z = torch.randn(batch_size, dim_gan, dtype=torch.float).cuda()
        batch_fake_images = netG(z, batch_fake_labels)

        # Loss measures generator's ability to fool the discriminator
        if use_DiffAugment:
            dis_out = netD(DiffAugment(batch_fake_images, policy=policy), batch_fake_labels)
        else:
            dis_out = netD(batch_fake_images, batch_fake_labels)

        if loss_type == "vanilla":
            dis_out = torch.nn.Sigmoid()(dis_out)
            g_loss = - torch.mean(torch.log(dis_out+1e-20))
        elif loss_type == "hinge":
            g_loss = - dis_out.mean()
        else:
            raise ValueError('Not supported loss type!!!')

        optimizerG.zero_grad()
        g_loss.backward()
        optimizerG.step()

        # ---- Train Discriminator: maximize log(D(x)) + log(1 - D(G(z))) ----
        for _ in range(num_D_steps):

            # get training images
            batch_train_images, batch_train_labels = next_batch()
            batch_train_images = batch_train_images.type(torch.float).cuda()
            batch_train_labels = batch_train_labels.type(torch.long).cuda()

            # Measure discriminator's ability to classify real from generated samples.
            # Fake images are scored with the labels they were generated from.
            if use_DiffAugment:
                real_dis_out = netD(DiffAugment(batch_train_images, policy=policy), batch_train_labels)
                fake_dis_out = netD(DiffAugment(batch_fake_images.detach(), policy=policy), batch_fake_labels)
            else:
                real_dis_out = netD(batch_train_images, batch_train_labels)
                fake_dis_out = netD(batch_fake_images.detach(), batch_fake_labels)

            if loss_type == "vanilla":
                real_dis_out = torch.nn.Sigmoid()(real_dis_out)
                fake_dis_out = torch.nn.Sigmoid()(fake_dis_out)
                d_loss_real = - torch.log(real_dis_out+1e-20)
                d_loss_fake = - torch.log(1-fake_dis_out+1e-20)
            elif loss_type == "hinge":
                d_loss_real = torch.nn.ReLU()(1.0 - real_dis_out)
                d_loss_fake = torch.nn.ReLU()(1.0 + fake_dis_out)
            else:
                raise ValueError('Not supported loss type!!!')
            d_loss = (d_loss_real + d_loss_fake).mean()

            optimizerD.zero_grad()
            d_loss.backward()
            optimizerD.step()

        if (niter+1)%20 == 0:
            print ("cGAN-%s: [Iter %d/%d] [D loss: %.4f] [G loss: %.4f] [D out real:%.4f] [D out fake:%.4f] [Time: %.4f]" % (gan_arch, niter+1, niters, d_loss.item(), g_loss.item(), real_dis_out.mean().item(),fake_dis_out.mean().item(), timeit.default_timer()-start_time))

        if (niter+1) % visualize_freq == 0:
            netG.eval()
            with torch.no_grad():
                gen_imgs = netG(z_fixed, y_fixed)
                gen_imgs = gen_imgs.detach()
            save_image(gen_imgs.data, save_images_folder +'/{}.png'.format(niter+1), nrow=n_row, normalize=True)

        if save_models_folder is not None and ((niter+1) % save_niters_freq == 0 or (niter+1) == niters):
            save_file = save_models_folder + "/cGAN_{}_nDsteps_{}_checkpoint_intrain/cGAN_checkpoint_niters_{}.pth".format(gan_arch, num_D_steps, niter+1)
            os.makedirs(os.path.dirname(save_file), exist_ok=True)
            torch.save({
                    'netG_state_dict': netG.state_dict(),
                    'netD_state_dict': netD.state_dict(),
                    'optimizerG_state_dict': optimizerG.state_dict(),
                    'optimizerD_state_dict': optimizerD.state_dict(),
                    'rng_state': torch.get_rng_state()
            }, save_file)
    #end for niter

    return netG, netD
193
+
194
+
195
def sample_cgan_given_labels(netG, given_labels, class_cutoff_points, batch_size = 500, denorm=True, to_numpy=True, verbose=True):
    '''
    Generate one fake image per entry of given_labels with a trained cGAN generator.

    given_labels: a numpy array; raw label without any normalization; not class label
    class_cutoff_points: ascending cutoff points c_0 < c_1 < ... < c_K that define K classes.
        A label y belongs to class k when c_k <= y < c_{k+1}; the last class also contains c_K.
        Labels below c_0 are assigned to class 0 and labels above c_K to class K-1.
    batch_size: images per forward pass (capped at len(given_labels)).
    denorm: convert generator outputs from [-1, 1] to uint8 values in [0, 255] (truncating cast).
    to_numpy: return the images as a numpy array rather than a CPU tensor.
    verbose: display a SimpleProgressBar.

    Returns (fake_images, given_labels).
    '''

    class_cutoff_points = np.array(class_cutoff_points)
    num_classes = len(class_cutoff_points)-1

    nfake = len(given_labels)
    # searchsorted(side='right') counts the cutoffs <= y, so subtracting one gives the index k
    # of the left-closed interval [c_k, c_{k+1}) containing y. Clipping folds y == c_K and any
    # out-of-range label into the nearest valid class (the previous per-label loop produced
    # -1 below c_0 and reused a stale class above c_K).
    given_class_labels = np.searchsorted(class_cutoff_points, np.asarray(given_labels), side='right') - 1
    given_class_labels = np.clip(given_class_labels, 0, num_classes-1)
    # pad with the first batch_size class labels so the last slice is always a full batch
    given_class_labels = np.concatenate((given_class_labels, given_class_labels[0:batch_size]))

    if batch_size>nfake:
        batch_size = nfake
    fake_images = []
    netG=netG.cuda()
    netG.eval()
    with torch.no_grad():
        if verbose:
            pb = SimpleProgressBar()
        tmp = 0
        while tmp < nfake:
            z = torch.randn(batch_size, dim_gan, dtype=torch.float).cuda()
            labels = torch.from_numpy(given_class_labels[tmp:(tmp+batch_size)]).type(torch.long).cuda()
            batch_fake_images = netG(z, labels)
            if denorm: #denorm imgs to save memory
                assert batch_fake_images.max().item()<=1.0 and batch_fake_images.min().item()>=-1.0
                batch_fake_images = batch_fake_images*0.5+0.5
                batch_fake_images = batch_fake_images*255.0
                batch_fake_images = batch_fake_images.type(torch.uint8)
            fake_images.append(batch_fake_images.detach().cpu())
            tmp += batch_size
            if verbose:
                pb.update(min(float(tmp)/nfake, 1)*100)

    fake_images = torch.cat(fake_images, dim=0)
    #remove extra entries
    fake_images = fake_images[0:nfake]

    if to_numpy:
        fake_images = fake_images.numpy()

    return fake_images, given_labels