Myosotis-Researches 0.1.7__py3-none-any.whl → 0.1.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49)
  1. myosotis_researches/CcGAN/train/__init__.py +4 -0
  2. myosotis_researches/CcGAN/{train_128_output_10 → train}/train_ccgan.py +4 -4
  3. myosotis_researches/CcGAN/{train_128 → train}/train_cgan.py +1 -3
  4. myosotis_researches/CcGAN/{train_128 → train}/train_cgan_concat.py +1 -3
  5. myosotis_researches/CcGAN/utils/__init__.py +2 -1
  6. myosotis_researches/CcGAN/utils/train.py +94 -3
  7. {myosotis_researches-0.1.7.dist-info → myosotis_researches-0.1.9.dist-info}/METADATA +1 -1
  8. myosotis_researches-0.1.9.dist-info/RECORD +24 -0
  9. myosotis_researches/CcGAN/models_128/CcGAN_SAGAN.py +0 -301
  10. myosotis_researches/CcGAN/models_128/ResNet_class_eval.py +0 -141
  11. myosotis_researches/CcGAN/models_128/ResNet_embed.py +0 -188
  12. myosotis_researches/CcGAN/models_128/ResNet_regre_eval.py +0 -175
  13. myosotis_researches/CcGAN/models_128/__init__.py +0 -7
  14. myosotis_researches/CcGAN/models_128/autoencoder.py +0 -119
  15. myosotis_researches/CcGAN/models_128/cGAN_SAGAN.py +0 -276
  16. myosotis_researches/CcGAN/models_128/cGAN_concat_SAGAN.py +0 -245
  17. myosotis_researches/CcGAN/models_256/CcGAN_SAGAN.py +0 -303
  18. myosotis_researches/CcGAN/models_256/ResNet_class_eval.py +0 -142
  19. myosotis_researches/CcGAN/models_256/ResNet_embed.py +0 -188
  20. myosotis_researches/CcGAN/models_256/ResNet_regre_eval.py +0 -178
  21. myosotis_researches/CcGAN/models_256/__init__.py +0 -7
  22. myosotis_researches/CcGAN/models_256/autoencoder.py +0 -133
  23. myosotis_researches/CcGAN/models_256/cGAN_SAGAN.py +0 -280
  24. myosotis_researches/CcGAN/models_256/cGAN_concat_SAGAN.py +0 -249
  25. myosotis_researches/CcGAN/train_128/DiffAugment_pytorch.py +0 -76
  26. myosotis_researches/CcGAN/train_128/__init__.py +0 -0
  27. myosotis_researches/CcGAN/train_128/eval_metrics.py +0 -205
  28. myosotis_researches/CcGAN/train_128/opts.py +0 -87
  29. myosotis_researches/CcGAN/train_128/pretrain_AE.py +0 -268
  30. myosotis_researches/CcGAN/train_128/pretrain_CNN_class.py +0 -251
  31. myosotis_researches/CcGAN/train_128/pretrain_CNN_regre.py +0 -255
  32. myosotis_researches/CcGAN/train_128/train_ccgan.py +0 -303
  33. myosotis_researches/CcGAN/train_128/utils.py +0 -120
  34. myosotis_researches/CcGAN/train_128_output_10/DiffAugment_pytorch.py +0 -76
  35. myosotis_researches/CcGAN/train_128_output_10/__init__.py +0 -0
  36. myosotis_researches/CcGAN/train_128_output_10/eval_metrics.py +0 -205
  37. myosotis_researches/CcGAN/train_128_output_10/opts.py +0 -87
  38. myosotis_researches/CcGAN/train_128_output_10/pretrain_AE.py +0 -268
  39. myosotis_researches/CcGAN/train_128_output_10/pretrain_CNN_class.py +0 -251
  40. myosotis_researches/CcGAN/train_128_output_10/pretrain_CNN_regre.py +0 -255
  41. myosotis_researches/CcGAN/train_128_output_10/train_cgan.py +0 -254
  42. myosotis_researches/CcGAN/train_128_output_10/train_cgan_concat.py +0 -242
  43. myosotis_researches/CcGAN/train_128_output_10/train_net_for_label_embed.py +0 -181
  44. myosotis_researches/CcGAN/train_128_output_10/utils.py +0 -120
  45. myosotis_researches-0.1.7.dist-info/RECORD +0 -59
  46. myosotis_researches/CcGAN/{train_128 → train}/train_net_for_label_embed.py +0 -0
  47. {myosotis_researches-0.1.7.dist-info → myosotis_researches-0.1.9.dist-info}/WHEEL +0 -0
  48. {myosotis_researches-0.1.7.dist-info → myosotis_researches-0.1.9.dist-info}/licenses/LICENSE +0 -0
  49. {myosotis_researches-0.1.7.dist-info → myosotis_researches-0.1.9.dist-info}/top_level.txt +0 -0
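
Taken together, the list shows 0.1.9 consolidating the duplicated `train_128` and `train_128_output_10` packages into a single `myosotis_researches.CcGAN.train` package (items 1–4 and 46) and dropping the bundled `models_128`/`models_256` modules. Below is a minimal sketch of imports against the new layout; the module paths come from the rename entries above, while the function names are taken from the deleted `train_128_output_10` copies shown further down and are assumed to survive the move unchanged:

```python
# Hedged sketch, not taken from the 0.1.9 wheel itself: likely import paths
# after the train_128 / train_128_output_10 -> train consolidation.
from myosotis_researches.CcGAN.train.train_cgan import train_cgan
from myosotis_researches.CcGAN.train.train_cgan_concat import train_cgan_concat
from myosotis_researches.CcGAN.train.train_net_for_label_embed import train_net_embed, train_net_y2h
```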
myosotis_researches/CcGAN/train_128_output_10/train_cgan.py +0 -254
@@ -1,254 +0,0 @@
-
-import torch
-import torch.nn as nn
-from torchvision.utils import save_image
-import numpy as np
-import os
-import timeit
-
-from .utils import IMGs_dataset, SimpleProgressBar
-from .opts import parse_opts
-from .DiffAugment_pytorch import DiffAugment
-
-''' Settings '''
-args = parse_opts()
-
-# some parameters in opts
-loss_type = args.loss_type_gan
-niters = args.niters_gan
-resume_niters = args.resume_niters_gan
-dim_gan = args.dim_gan
-lr_g = args.lr_g_gan
-lr_d = args.lr_d_gan
-save_niters_freq = args.save_niters_freq
-batch_size = min(args.batch_size_disc, args.batch_size_gene)
-num_classes = args.cGAN_num_classes
-gan_arch = args.GAN_arch
-num_D_steps = args.num_D_steps
-
-visualize_freq = args.visualize_freq
-
-num_workers = args.num_workers
-
-NC = args.num_channels
-IMG_SIZE = args.img_size
-
-use_DiffAugment = args.gan_DiffAugment
-policy = args.gan_DiffAugment_policy
-
-
-def train_cgan(images, labels, netG, netD, save_images_folder, save_models_folder = None):
-
-    netG = netG.cuda()
-    netD = netD.cuda()
-
-    optimizerG = torch.optim.Adam(netG.parameters(), lr=lr_g, betas=(0.5, 0.999))
-    optimizerD = torch.optim.Adam(netD.parameters(), lr=lr_d, betas=(0.5, 0.999))
-
-    trainset = IMGs_dataset(images, labels, normalize=True)
-    train_dataloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
-    unique_labels = np.sort(np.array(list(set(labels)))).astype(np.int)
-
-    if save_models_folder is not None and resume_niters>0:
-        save_file = save_models_folder + "/cGAN_{}_nDsteps_{}_checkpoint_intrain/cGAN_checkpoint_niters_{}.pth".format(gan_arch, num_D_steps, resume_niters)
-        checkpoint = torch.load(save_file)
-        netG.load_state_dict(checkpoint['netG_state_dict'])
-        netD.load_state_dict(checkpoint['netD_state_dict'])
-        optimizerG.load_state_dict(checkpoint['optimizerG_state_dict'])
-        optimizerD.load_state_dict(checkpoint['optimizerD_state_dict'])
-        torch.set_rng_state(checkpoint['rng_state'])
-    #end if
-
-    n_row=10
-    z_fixed = torch.randn(n_row**2, dim_gan, dtype=torch.float).cuda()
-    unique_labels = np.sort(unique_labels)
-    selected_labels = np.zeros(n_row)
-    indx_step_size = len(unique_labels)//n_row
-    for i in range(n_row):
-        indx = i*indx_step_size
-        selected_labels[i] = unique_labels[indx]
-    y_fixed = np.zeros(n_row**2)
-    for i in range(n_row):
-        curr_label = selected_labels[i]
-        for j in range(n_row):
-            y_fixed[i*n_row+j] = curr_label
-    y_fixed = torch.from_numpy(y_fixed).type(torch.long).cuda()
-
-
-    batch_idx = 0
-    dataloader_iter = iter(train_dataloader)
-
-    start_time = timeit.default_timer()
-    for niter in range(resume_niters, niters):
-
-        if batch_idx+1 == len(train_dataloader):
-            dataloader_iter = iter(train_dataloader)
-            batch_idx = 0
-
-        '''
-
-        Train Generator: maximize log(D(G(z)))
-
-        '''
-
-        netG.train()
-
-        # get training images
-        _, batch_train_labels = dataloader_iter.next()
-        assert batch_size == batch_train_labels.shape[0]
-        batch_train_labels = batch_train_labels.type(torch.long).cuda()
-        batch_idx+=1
-
-        # Sample noise and labels as generator input
-        z = torch.randn(batch_size, dim_gan, dtype=torch.float).cuda()
-
-        #generate fake images
-        batch_fake_images = netG(z, batch_train_labels)
-
-        # Loss measures generator's ability to fool the discriminator
-        if use_DiffAugment:
-            dis_out = netD(DiffAugment(batch_fake_images, policy=policy), batch_train_labels)
-        else:
-            dis_out = netD(batch_fake_images, batch_train_labels)
-
-        if loss_type == "vanilla":
-            dis_out = torch.nn.Sigmoid()(dis_out)
-            g_loss = - torch.mean(torch.log(dis_out+1e-20))
-        elif loss_type == "hinge":
-            g_loss = - dis_out.mean()
-
-        optimizerG.zero_grad()
-        g_loss.backward()
-        optimizerG.step()
-
-        '''
-
-        Train Discriminator: maximize log(D(x)) + log(1 - D(G(z)))
-
-        '''
-
-        for _ in range(num_D_steps):
-
-            if batch_idx+1 == len(train_dataloader):
-                dataloader_iter = iter(train_dataloader)
-                batch_idx = 0
-
-            # get training images
-            batch_train_images, batch_train_labels = dataloader_iter.next()
-            assert batch_size == batch_train_images.shape[0]
-            batch_train_images = batch_train_images.type(torch.float).cuda()
-            batch_train_labels = batch_train_labels.type(torch.long).cuda()
-            batch_idx+=1
-
-            # Measure discriminator's ability to classify real from generated samples
-            if use_DiffAugment:
-                real_dis_out = netD(DiffAugment(batch_train_images, policy=policy), batch_train_labels)
-                fake_dis_out = netD(DiffAugment(batch_fake_images.detach(), policy=policy), batch_train_labels.detach())
-            else:
-                real_dis_out = netD(batch_train_images, batch_train_labels)
-                fake_dis_out = netD(batch_fake_images.detach(), batch_train_labels.detach())
-
-
-            if loss_type == "vanilla":
-                real_dis_out = torch.nn.Sigmoid()(real_dis_out)
-                fake_dis_out = torch.nn.Sigmoid()(fake_dis_out)
-                d_loss_real = - torch.log(real_dis_out+1e-20)
-                d_loss_fake = - torch.log(1-fake_dis_out+1e-20)
-            elif loss_type == "hinge":
-                d_loss_real = torch.nn.ReLU()(1.0 - real_dis_out)
-                d_loss_fake = torch.nn.ReLU()(1.0 + fake_dis_out)
-            d_loss = (d_loss_real + d_loss_fake).mean()
-
-            optimizerD.zero_grad()
-            d_loss.backward()
-            optimizerD.step()
-
-
-
-        if (niter+1)%20 == 0:
-            print ("cGAN-%s: [Iter %d/%d] [D loss: %.4f] [G loss: %.4f] [D out real:%.4f] [D out fake:%.4f] [Time: %.4f]" % (gan_arch, niter+1, niters, d_loss.item(), g_loss.item(), real_dis_out.mean().item(),fake_dis_out.mean().item(), timeit.default_timer()-start_time))
-
-
-        if (niter+1) % visualize_freq == 0:
-            netG.eval()
-            with torch.no_grad():
-                gen_imgs = netG(z_fixed, y_fixed)
-                gen_imgs = gen_imgs.detach()
-            save_image(gen_imgs.data, save_images_folder +'/{}.png'.format(niter+1), nrow=n_row, normalize=True)
-
-        if save_models_folder is not None and ((niter+1) % save_niters_freq == 0 or (niter+1) == niters):
-            save_file = save_models_folder + "/cGAN_{}_nDsteps_{}_checkpoint_intrain/cGAN_checkpoint_niters_{}.pth".format(gan_arch, num_D_steps, niter+1)
-            os.makedirs(os.path.dirname(save_file), exist_ok=True)
-            torch.save({
-                    'netG_state_dict': netG.state_dict(),
-                    'netD_state_dict': netD.state_dict(),
-                    'optimizerG_state_dict': optimizerG.state_dict(),
-                    'optimizerD_state_dict': optimizerD.state_dict(),
-                    'rng_state': torch.get_rng_state()
-            }, save_file)
-    #end for niter
-
-
-    return netG, netD
-
-
-def sample_cgan_given_labels(netG, given_labels, class_cutoff_points, batch_size = 500, denorm=True, to_numpy=True, verbose=True):
-    '''
-    given_labels: a numpy array; raw label without any normalization; not class label
-    class_cutoff_points: the cutoff points to determine the membership of a give label
-    '''
-
-    class_cutoff_points = np.array(class_cutoff_points)
-    num_classes = len(class_cutoff_points)-1
-
-    nfake = len(given_labels)
-    given_class_labels = np.zeros(nfake)
-    for i in range(nfake):
-        curr_given_label = given_labels[i]
-        diff_tmp = class_cutoff_points - curr_given_label
-        indx_nonneg = np.where(diff_tmp>=0)[0]
-        if len(indx_nonneg)==1: #the last element of diff_tmp is non-negative
-            curr_given_class_label = num_classes-1
-            assert indx_nonneg[0] == num_classes
-        elif len(indx_nonneg)>1:
-            if diff_tmp[indx_nonneg[0]]>0:
-                curr_given_class_label = indx_nonneg[0] - 1
-            else:
-                curr_given_class_label = indx_nonneg[0]
-        given_class_labels[i] = curr_given_class_label
-    given_class_labels = np.concatenate((given_class_labels, given_class_labels[0:batch_size]))
-
-    if batch_size>nfake:
-        batch_size = nfake
-    fake_images = []
-    netG=netG.cuda()
-    netG.eval()
-    with torch.no_grad():
-        if verbose:
-            pb = SimpleProgressBar()
-        tmp = 0
-        while tmp < nfake:
-            z = torch.randn(batch_size, dim_gan, dtype=torch.float).cuda()
-            labels = torch.from_numpy(given_class_labels[tmp:(tmp+batch_size)]).type(torch.long).cuda()
-            if labels.max().item()>num_classes:
-                print("Error: max label {}".format(labels.max().item()))
-            batch_fake_images = netG(z, labels)
-            if denorm: #denorm imgs to save memory
-                assert batch_fake_images.max().item()<=1.0 and batch_fake_images.min().item()>=-1.0
-                batch_fake_images = batch_fake_images*0.5+0.5
-                batch_fake_images = batch_fake_images*255.0
-                batch_fake_images = batch_fake_images.type(torch.uint8)
-                # assert batch_fake_images.max().item()>1
-            fake_images.append(batch_fake_images.detach().cpu())
-            tmp += batch_size
-            if verbose:
-                pb.update(min(float(tmp)/nfake, 1)*100)
-
-    fake_images = torch.cat(fake_images, dim=0)
-    #remove extra entries
-    fake_images = fake_images[0:nfake]
-
-    if to_numpy:
-        fake_images = fake_images.numpy()
-
-    return fake_images, given_labels
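
The deleted `sample_cgan_given_labels` above buckets each raw (continuous) label into a class index by locating the first cutoff point greater than or equal to it. Here is a small self-contained sketch of just that bucketing rule, with hypothetical cutoff points (the real function additionally batches generation on the GPU and denormalizes the images):

```python
import numpy as np

# Hypothetical cutoff points: three classes covering raw labels in [0, 30].
class_cutoff_points = np.array([0, 10, 20, 30])
num_classes = len(class_cutoff_points) - 1  # 3

def label_to_class(raw_label):
    # Same rule as the deleted sample_cgan_given_labels: find the first cutoff
    # point >= raw_label, then step back one class unless it is an exact hit.
    diff_tmp = class_cutoff_points - raw_label
    indx_nonneg = np.where(diff_tmp >= 0)[0]
    if len(indx_nonneg) == 1:           # only the last cutoff is >= raw_label
        return num_classes - 1
    if diff_tmp[indx_nonneg[0]] > 0:    # strictly inside a bucket
        return int(indx_nonneg[0]) - 1
    return int(indx_nonneg[0])          # raw_label sits exactly on a cutoff

print([label_to_class(x) for x in (0.0, 9.9, 10.0, 15.0, 29.9)])  # [0, 0, 1, 1, 2]
```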
myosotis_researches/CcGAN/train_128_output_10/train_cgan_concat.py +0 -242
@@ -1,242 +0,0 @@
-
-import torch
-import torch.nn as nn
-from torchvision.utils import save_image
-import numpy as np
-import os
-import timeit
-
-from .utils import IMGs_dataset, SimpleProgressBar
-from .opts import parse_opts
-from .DiffAugment_pytorch import DiffAugment
-
-''' Settings '''
-args = parse_opts()
-
-# some parameters in opts
-loss_type = args.loss_type_gan
-niters = args.niters_gan
-resume_niters = args.resume_niters_gan
-dim_gan = args.dim_gan
-lr_g = args.lr_g_gan
-lr_d = args.lr_d_gan
-save_niters_freq = args.save_niters_freq
-batch_size = min(args.batch_size_disc, args.batch_size_gene)
-num_classes = args.cGAN_num_classes
-gan_arch = args.GAN_arch
-num_D_steps = args.num_D_steps
-
-visualize_freq = args.visualize_freq
-
-num_workers = args.num_workers
-
-NC = args.num_channels
-IMG_SIZE = args.img_size
-max_label = args.max_label
-
-use_DiffAugment = args.gan_DiffAugment
-policy = args.gan_DiffAugment_policy
-
-
-def train_cgan_concat(images, labels, netG, netD, save_images_folder, save_models_folder = None):
-
-    netG = netG.cuda()
-    netD = netD.cuda()
-
-    optimizerG = torch.optim.Adam(netG.parameters(), lr=lr_g, betas=(0.5, 0.999))
-    optimizerD = torch.optim.Adam(netD.parameters(), lr=lr_d, betas=(0.5, 0.999))
-
-    trainset = IMGs_dataset(images, labels, normalize=True)
-    train_dataloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
-    unique_labels = np.sort(np.array(list(set(labels)))).astype(np.int)
-
-    if save_models_folder is not None and resume_niters>0:
-        save_file = save_models_folder + "/cGAN_{}_nDsteps_{}_checkpoint_intrain/cGAN_checkpoint_niters_{}.pth".format(gan_arch, num_D_steps, resume_niters)
-        checkpoint = torch.load(save_file)
-        netG.load_state_dict(checkpoint['netG_state_dict'])
-        netD.load_state_dict(checkpoint['netD_state_dict'])
-        optimizerG.load_state_dict(checkpoint['optimizerG_state_dict'])
-        optimizerD.load_state_dict(checkpoint['optimizerD_state_dict'])
-        torch.set_rng_state(checkpoint['rng_state'])
-    #end if
-
-    # printed images with labels between the 5-th quantile and 95-th quantile of training labels
-    n_row=10; n_col = n_row
-    z_fixed = torch.randn(n_row*n_col, dim_gan, dtype=torch.float).cuda()
-    start_label = np.quantile(labels, 0.05)
-    end_label = np.quantile(labels, 0.95)
-    selected_labels = np.linspace(start_label, end_label, num=n_row)
-    y_fixed = np.zeros(n_row*n_col)
-    for i in range(n_row):
-        curr_label = selected_labels[i]
-        for j in range(n_col):
-            y_fixed[i*n_col+j] = curr_label
-    print(y_fixed)
-    y_fixed = torch.from_numpy(y_fixed).type(torch.float).view(-1,1).cuda()
-
-
-    batch_idx = 0
-    dataloader_iter = iter(train_dataloader)
-
-    start_time = timeit.default_timer()
-    for niter in range(resume_niters, niters):
-
-        if batch_idx+1 == len(train_dataloader):
-            dataloader_iter = iter(train_dataloader)
-            batch_idx = 0
-
-        '''
-
-        Train Generator: maximize log(D(G(z)))
-
-        '''
-
-        netG.train()
-
-        # get training images
-        _, batch_train_labels = dataloader_iter.next()
-        assert batch_size == batch_train_labels.shape[0]
-        batch_train_labels = batch_train_labels.type(torch.long).cuda()
-        batch_idx+=1
-
-        # Sample noise and labels as generator input
-        z = torch.randn(batch_size, dim_gan, dtype=torch.float).cuda()
-
-        #generate fake images
-        batch_fake_images = netG(z, batch_train_labels)
-
-        # Loss measures generator's ability to fool the discriminator
-        if use_DiffAugment:
-            dis_out = netD(DiffAugment(batch_fake_images, policy=policy), batch_train_labels)
-        else:
-            dis_out = netD(batch_fake_images, batch_train_labels)
-
-        if loss_type == "vanilla":
-            dis_out = torch.nn.Sigmoid()(dis_out)
-            g_loss = - torch.mean(torch.log(dis_out+1e-20))
-        elif loss_type == "hinge":
-            g_loss = - dis_out.mean()
-
-        optimizerG.zero_grad()
-        g_loss.backward()
-        optimizerG.step()
-
-        '''
-
-        Train Discriminator: maximize log(D(x)) + log(1 - D(G(z)))
-
-        '''
-
-        for _ in range(num_D_steps):
-
-            if batch_idx+1 == len(train_dataloader):
-                dataloader_iter = iter(train_dataloader)
-                batch_idx = 0
-
-            # get training images
-            batch_train_images, batch_train_labels = dataloader_iter.next()
-            assert batch_size == batch_train_images.shape[0]
-            batch_train_images = batch_train_images.type(torch.float).cuda()
-            batch_train_labels = batch_train_labels.type(torch.long).cuda()
-            batch_idx+=1
-
-            # Measure discriminator's ability to classify real from generated samples
-            if use_DiffAugment:
-                real_dis_out = netD(DiffAugment(batch_train_images, policy=policy), batch_train_labels)
-                fake_dis_out = netD(DiffAugment(batch_fake_images.detach(), policy=policy), batch_train_labels.detach())
-            else:
-                real_dis_out = netD(batch_train_images, batch_train_labels)
-                fake_dis_out = netD(batch_fake_images.detach(), batch_train_labels.detach())
-
-
-            if loss_type == "vanilla":
-                real_dis_out = torch.nn.Sigmoid()(real_dis_out)
-                fake_dis_out = torch.nn.Sigmoid()(fake_dis_out)
-                d_loss_real = - torch.log(real_dis_out+1e-20)
-                d_loss_fake = - torch.log(1-fake_dis_out+1e-20)
-            elif loss_type == "hinge":
-                d_loss_real = torch.nn.ReLU()(1.0 - real_dis_out)
-                d_loss_fake = torch.nn.ReLU()(1.0 + fake_dis_out)
-            d_loss = (d_loss_real + d_loss_fake).mean()
-
-            optimizerD.zero_grad()
-            d_loss.backward()
-            optimizerD.step()
-
-
-
-        if (niter+1)%20 == 0:
-            print ("cGAN(concat)-%s: [Iter %d/%d] [D loss: %.4f] [G loss: %.4f] [D out real:%.4f] [D out fake:%.4f] [Time: %.4f]" % (gan_arch, niter+1, niters, d_loss.item(), g_loss.item(), real_dis_out.mean().item(),fake_dis_out.mean().item(), timeit.default_timer()-start_time))
-
-
-        if (niter+1) % visualize_freq == 0:
-            netG.eval()
-            with torch.no_grad():
-                gen_imgs = netG(z_fixed, y_fixed)
-                gen_imgs = gen_imgs.detach()
-            save_image(gen_imgs.data, save_images_folder +'/{}.png'.format(niter+1), nrow=n_row, normalize=True)
-
-        if save_models_folder is not None and ((niter+1) % save_niters_freq == 0 or (niter+1) == niters):
-            save_file = save_models_folder + "/cGAN_{}_nDsteps_{}_checkpoint_intrain/cGAN_checkpoint_niters_{}.pth".format(gan_arch, num_D_steps, niter+1)
-            os.makedirs(os.path.dirname(save_file), exist_ok=True)
-            torch.save({
-                    'netG_state_dict': netG.state_dict(),
-                    'netD_state_dict': netD.state_dict(),
-                    'optimizerG_state_dict': optimizerG.state_dict(),
-                    'optimizerD_state_dict': optimizerD.state_dict(),
-                    'rng_state': torch.get_rng_state()
-            }, save_file)
-    #end for niter
-
-
-    return netG, netD
-
-
-def sample_cgan_concat_given_labels(netG, given_labels, batch_size = 100, denorm=True, to_numpy=True, verbose=True):
-    '''
-    netG: pretrained generator network
-    given_labels: float. unnormalized labels. we need to convert them to values in [-1,1].
-    '''
-
-    ## num of fake images will be generated
-    nfake = len(given_labels)
-
-    ## normalize regression
-    labels = given_labels/max_label
-
-    ## generate images
-    if batch_size>nfake:
-        batch_size = nfake
-
-    fake_images = []
-    ## concat to avoid out of index errors
-    labels = np.concatenate((labels, labels[0:batch_size]), axis=0)
-
-    netG=netG.cuda()
-    netG.eval()
-    with torch.no_grad():
-        if verbose:
-            pb = SimpleProgressBar()
-        tmp = 0
-        while tmp < nfake:
-            z = torch.randn(batch_size, dim_gan, dtype=torch.float).cuda()
-            c = torch.from_numpy(labels[tmp:(tmp+batch_size)]).type(torch.float).cuda()
-            batch_fake_images = netG(z, c)
-            if denorm: #denorm imgs to save memory
-                assert batch_fake_images.max().item()<=1.0 and batch_fake_images.min().item()>=-1.0
-                batch_fake_images = batch_fake_images*0.5+0.5
-                batch_fake_images = batch_fake_images*255.0
-                batch_fake_images = batch_fake_images.type(torch.uint8)
-            fake_images.append(batch_fake_images.detach().cpu())
-            tmp += batch_size
-            if verbose:
-                pb.update(min(float(tmp)/nfake, 1)*100)
-
-    fake_images = torch.cat(fake_images, dim=0)
-    #remove extra entries
-    fake_images = fake_images[0:nfake]
-
-    if to_numpy:
-        fake_images = fake_images.numpy()
-
-    return fake_images, given_labels
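
In the concat variant above, labels stay continuous: the fixed visualization grid spans the 5th–95th quantile range of the training labels, and `sample_cgan_concat_given_labels` simply rescales raw labels by `args.max_label` before feeding them to the generator. A short sketch of both steps with made-up numbers (the label range and the `max_label` value are assumptions, not values from the package):

```python
import numpy as np

# Hypothetical training labels; n_row x n_col grid as in train_cgan_concat above.
labels = np.random.uniform(0.0, 90.0, size=1000)
max_label = 90.0          # assumed stand-in for args.max_label
n_row = n_col = 10

# Fixed labels for the visualization grid: n_row values spanning the
# 5th-95th quantile range, each repeated across a row of n_col images.
selected_labels = np.linspace(np.quantile(labels, 0.05),
                              np.quantile(labels, 0.95), num=n_row)
y_fixed = np.repeat(selected_labels, n_col)          # shape (100,), row-major

# Conditioning used by sample_cgan_concat_given_labels: raw labels are
# rescaled by max_label before being passed to the generator.
given_labels = np.array([10.0, 45.0, 80.0])
conditioning = given_labels / max_label              # approx. [0.11, 0.50, 0.89]
print(y_fixed.shape, conditioning)
```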
myosotis_researches/CcGAN/train_128_output_10/train_net_for_label_embed.py +0 -181
@@ -1,181 +0,0 @@
-
-import torch
-import torch.nn as nn
-from torchvision.utils import save_image
-import numpy as np
-import os
-import timeit
-from PIL import Image
-
-
-
-
-#-------------------------------------------------------------
-def train_net_embed(net, net_name, trainloader, testloader, epochs=200, resume_epoch = 0, lr_base=0.01, lr_decay_factor=0.1, lr_decay_epochs=[80, 140], weight_decay=1e-4, path_to_ckpt = None):
-
-    ''' learning rate decay '''
-    def adjust_learning_rate_1(optimizer, epoch):
-        """decrease the learning rate """
-        lr = lr_base
-
-        num_decays = len(lr_decay_epochs)
-        for decay_i in range(num_decays):
-            if epoch >= lr_decay_epochs[decay_i]:
-                lr = lr * lr_decay_factor
-            #end if epoch
-        #end for decay_i
-        for param_group in optimizer.param_groups:
-            param_group['lr'] = lr
-
-    net = net.cuda()
-    criterion = nn.MSELoss()
-    optimizer = torch.optim.SGD(net.parameters(), lr = lr_base, momentum= 0.9, weight_decay=weight_decay)
-
-    # resume training; load checkpoint
-    if path_to_ckpt is not None and resume_epoch>0:
-        save_file = path_to_ckpt + "/embed_x2y_ckpt_in_train/embed_x2y_checkpoint_epoch_{}.pth".format(resume_epoch)
-        checkpoint = torch.load(save_file)
-        net.load_state_dict(checkpoint['net_state_dict'])
-        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
-        torch.set_rng_state(checkpoint['rng_state'])
-    #end if
-
-    start_tmp = timeit.default_timer()
-    for epoch in range(resume_epoch, epochs):
-        net.train()
-        train_loss = 0
-        adjust_learning_rate_1(optimizer, epoch)
-        for _, (batch_train_images, batch_train_labels) in enumerate(trainloader):
-
-            # batch_train_images = nn.functional.interpolate(batch_train_images, size = (299,299), scale_factor=None, mode='bilinear', align_corners=False)
-
-            batch_train_images = batch_train_images.type(torch.float).cuda()
-            batch_train_labels = batch_train_labels.type(torch.float).view(-1,1).cuda()
-
-            #Forward pass
-            outputs, _ = net(batch_train_images)
-            loss = criterion(outputs, batch_train_labels)
-
-            #backward pass
-            optimizer.zero_grad()
-            loss.backward()
-            optimizer.step()
-
-            train_loss += loss.cpu().item()
-        #end for batch_idx
-        train_loss = train_loss / len(trainloader)
-
-        if testloader is None:
-            print('Train net_x2y for embedding: [epoch %d/%d] train_loss:%f Time:%.4f' % (epoch+1, epochs, train_loss, timeit.default_timer()-start_tmp))
-        else:
-            net.eval() # eval mode (batchnorm uses moving mean/variance instead of mini-batch mean/variance)
-            with torch.no_grad():
-                test_loss = 0
-                for batch_test_images, batch_test_labels in testloader:
-                    batch_test_images = batch_test_images.type(torch.float).cuda()
-                    batch_test_labels = batch_test_labels.type(torch.float).view(-1,1).cuda()
-                    outputs,_ = net(batch_test_images)
-                    loss = criterion(outputs, batch_test_labels)
-                    test_loss += loss.cpu().item()
-                test_loss = test_loss/len(testloader)
-
-            print('Train net_x2y for label embedding: [epoch %d/%d] train_loss:%f test_loss:%f Time:%.4f' % (epoch+1, epochs, train_loss, test_loss, timeit.default_timer()-start_tmp))
-
-        #save checkpoint
-        if path_to_ckpt is not None and (((epoch+1) % 50 == 0) or (epoch+1==epochs)):
-            save_file = path_to_ckpt + "/embed_x2y_ckpt_in_train/embed_x2y_checkpoint_epoch_{}.pth".format(epoch+1)
-            os.makedirs(os.path.dirname(save_file), exist_ok=True)
-            torch.save({
-                    'epoch': epoch,
-                    'net_state_dict': net.state_dict(),
-                    'optimizer_state_dict': optimizer.state_dict(),
-                    'rng_state': torch.get_rng_state()
-            }, save_file)
-    #end for epoch
-
-    return net
-
-
-
-
-###################################################################################
-class label_dataset(torch.utils.data.Dataset):
-    def __init__(self, labels):
-        super(label_dataset, self).__init__()
-
-        self.labels = labels
-        self.n_samples = len(self.labels)
-
-    def __getitem__(self, index):
-
-        y = self.labels[index]
-        return y
-
-    def __len__(self):
-        return self.n_samples
-
-
-def train_net_y2h(unique_labels_norm, net_y2h, net_embed, epochs=500, lr_base=0.01, lr_decay_factor=0.1, lr_decay_epochs=[150, 250, 350], weight_decay=1e-4, batch_size=128):
-    '''
-    unique_labels_norm: an array of normalized unique labels
-    '''
-
-    ''' learning rate decay '''
-    def adjust_learning_rate_2(optimizer, epoch):
-        """decrease the learning rate """
-        lr = lr_base
-
-        num_decays = len(lr_decay_epochs)
-        for decay_i in range(num_decays):
-            if epoch >= lr_decay_epochs[decay_i]:
-                lr = lr * lr_decay_factor
-            #end if epoch
-        #end for decay_i
-        for param_group in optimizer.param_groups:
-            param_group['lr'] = lr
-
-
-    assert np.max(unique_labels_norm)<=1 and np.min(unique_labels_norm)>=0
-    trainset = label_dataset(unique_labels_norm)
-    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True)
-
-    net_embed.eval()
-    net_h2y=net_embed.module.h2y #convert embedding labels to original labels
-    optimizer_y2h = torch.optim.SGD(net_y2h.parameters(), lr = lr_base, momentum= 0.9, weight_decay=weight_decay)
-
-    start_tmp = timeit.default_timer()
-    for epoch in range(epochs):
-        net_y2h.train()
-        train_loss = 0
-        adjust_learning_rate_2(optimizer_y2h, epoch)
-        for _, batch_labels in enumerate(trainloader):
-
-            batch_labels = batch_labels.type(torch.float).view(-1,1).cuda()
-
-            # generate noises which will be added to labels
-            batch_size_curr = len(batch_labels)
-            batch_gamma = np.random.normal(0, 0.2, batch_size_curr)
-            batch_gamma = torch.from_numpy(batch_gamma).view(-1,1).type(torch.float).cuda()
-
-            # add noise to labels
-            batch_labels_noise = torch.clamp(batch_labels+batch_gamma, 0.0, 1.0)
-
-            #Forward pass
-            batch_hiddens_noise = net_y2h(batch_labels_noise)
-            batch_rec_labels_noise = net_h2y(batch_hiddens_noise)
-
-            loss = nn.MSELoss()(batch_rec_labels_noise, batch_labels_noise)
-
-            #backward pass
-            optimizer_y2h.zero_grad()
-            loss.backward()
-            optimizer_y2h.step()
-
-            train_loss += loss.cpu().item()
-        #end for batch_idx
-        train_loss = train_loss / len(trainloader)
-
-        print('\n Train net_y2h: [epoch %d/%d] train_loss:%f Time:%.4f' % (epoch+1, epochs, train_loss, timeit.default_timer()-start_tmp))
-    #end for epoch
-
-    return net_y2h
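
The deleted `train_net_y2h` above fits the label-to-hidden map by perturbing each normalized label with Gaussian noise (sigma 0.2), encoding it with `net_y2h`, decoding it through the frozen `h2y` head of the pretrained embedding network, and minimizing the MSE against the noisy label itself. A toy sketch of one such step on CPU; the layer sizes and architectures are invented stand-ins, not the package's real `net_y2h`/`h2y` modules:

```python
import torch
import torch.nn as nn

# Toy stand-ins with assumed dimensions: y2h lifts a scalar label into a
# 128-d code, h2y is the (frozen) head that maps the code back to a label.
net_y2h = nn.Sequential(nn.Linear(1, 128), nn.ReLU(), nn.Linear(128, 128))
net_h2y = nn.Sequential(nn.Linear(128, 1)).eval()

batch_labels = torch.rand(64, 1)                      # normalized labels in [0, 1]
batch_gamma = torch.randn(64, 1) * 0.2                # same noise scale as above
batch_labels_noise = torch.clamp(batch_labels + batch_gamma, 0.0, 1.0)

# One reconstruction step of the noisy-label objective used by train_net_y2h.
batch_rec = net_h2y(net_y2h(batch_labels_noise))
loss = nn.MSELoss()(batch_rec, batch_labels_noise)
loss.backward()                                       # only net_y2h would be stepped
print(float(loss))
```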