Myosotis-Researches 0.1.8__py3-none-any.whl → 0.1.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (31) hide show
  1. myosotis_researches/CcGAN/train/__init__.py +4 -0
  2. myosotis_researches/CcGAN/{train_128_output_10 → train}/train_ccgan.py +4 -4
  3. myosotis_researches/CcGAN/{train_128_output_10 → train}/train_cgan.py +1 -3
  4. myosotis_researches/CcGAN/{train_128_output_10 → train}/train_cgan_concat.py +1 -3
  5. {myosotis_researches-0.1.8.dist-info → myosotis_researches-0.1.10.dist-info}/METADATA +1 -1
  6. myosotis_researches-0.1.10.dist-info/RECORD +40 -0
  7. myosotis_researches/CcGAN/train_128/DiffAugment_pytorch.py +0 -76
  8. myosotis_researches/CcGAN/train_128/__init__.py +0 -0
  9. myosotis_researches/CcGAN/train_128/eval_metrics.py +0 -205
  10. myosotis_researches/CcGAN/train_128/opts.py +0 -87
  11. myosotis_researches/CcGAN/train_128/pretrain_AE.py +0 -268
  12. myosotis_researches/CcGAN/train_128/pretrain_CNN_class.py +0 -251
  13. myosotis_researches/CcGAN/train_128/pretrain_CNN_regre.py +0 -255
  14. myosotis_researches/CcGAN/train_128/train_ccgan.py +0 -303
  15. myosotis_researches/CcGAN/train_128/train_cgan.py +0 -254
  16. myosotis_researches/CcGAN/train_128/train_cgan_concat.py +0 -242
  17. myosotis_researches/CcGAN/train_128/utils.py +0 -120
  18. myosotis_researches/CcGAN/train_128_output_10/DiffAugment_pytorch.py +0 -76
  19. myosotis_researches/CcGAN/train_128_output_10/__init__.py +0 -0
  20. myosotis_researches/CcGAN/train_128_output_10/eval_metrics.py +0 -205
  21. myosotis_researches/CcGAN/train_128_output_10/opts.py +0 -87
  22. myosotis_researches/CcGAN/train_128_output_10/pretrain_AE.py +0 -268
  23. myosotis_researches/CcGAN/train_128_output_10/pretrain_CNN_class.py +0 -251
  24. myosotis_researches/CcGAN/train_128_output_10/pretrain_CNN_regre.py +0 -255
  25. myosotis_researches/CcGAN/train_128_output_10/train_net_for_label_embed.py +0 -181
  26. myosotis_researches/CcGAN/train_128_output_10/utils.py +0 -120
  27. myosotis_researches-0.1.8.dist-info/RECORD +0 -59
  28. /myosotis_researches/CcGAN/{train_128 → train}/train_net_for_label_embed.py +0 -0
  29. {myosotis_researches-0.1.8.dist-info → myosotis_researches-0.1.10.dist-info}/WHEEL +0 -0
  30. {myosotis_researches-0.1.8.dist-info → myosotis_researches-0.1.10.dist-info}/licenses/LICENSE +0 -0
  31. {myosotis_researches-0.1.8.dist-info → myosotis_researches-0.1.10.dist-info}/top_level.txt +0 -0
@@ -1,254 +0,0 @@
1
-
2
- import torch
3
- import torch.nn as nn
4
- from torchvision.utils import save_image
5
- import numpy as np
6
- import os
7
- import timeit
8
-
9
- from .utils import IMGs_dataset, SimpleProgressBar
10
- from .opts import parse_opts
11
- from .DiffAugment_pytorch import DiffAugment
12
-
13
# Settings
args = parse_opts()

# --- optimisation schedule (all values come from the parsed command line) ---
loss_type = args.loss_type_gan
niters, resume_niters = args.niters_gan, args.resume_niters_gan
dim_gan = args.dim_gan
lr_g, lr_d = args.lr_g_gan, args.lr_d_gan
save_niters_freq = args.save_niters_freq
# one batch size is used for both generator and discriminator steps,
# so take the smaller of the two configured sizes
batch_size = min(args.batch_size_disc, args.batch_size_gene)

# --- model description ---
num_classes = args.cGAN_num_classes
gan_arch = args.GAN_arch
num_D_steps = args.num_D_steps

# --- bookkeeping / data loading ---
visualize_freq = args.visualize_freq
num_workers = args.num_workers

# --- image geometry ---
NC, IMG_SIZE = args.num_channels, args.img_size

# --- DiffAugment ---
use_DiffAugment, policy = args.gan_DiffAugment, args.gan_DiffAugment_policy
38
-
39
-
40
def _cycle_batches(dataloader):
    '''Yield batches from ``dataloader`` forever, starting a new epoch whenever one ends.'''
    while True:
        for batch in dataloader:
            yield batch


def _cgan_checkpoint_path(save_models_folder, niters_done):
    '''Return the path of the in-training checkpoint written after ``niters_done`` iterations.'''
    ckpt_dir = "cGAN_{}_nDsteps_{}_checkpoint_intrain".format(gan_arch, num_D_steps)
    return os.path.join(save_models_folder, ckpt_dir, "cGAN_checkpoint_niters_{}.pth".format(niters_done))


def train_cgan(images, labels, netG, netD, save_images_folder, save_models_folder = None):
    '''
    Train a class-conditional GAN (generator ``netG``, discriminator ``netD``) on CUDA.

    images, labels: training images and their class labels (cast to ``torch.long``);
        wrapped in IMGs_dataset with normalize=True, i.e. pixels x -> (x/255 - 0.5)/0.5.
    save_images_folder: every ``visualize_freq`` iterations a grid generated from a
        fixed noise/label set is saved here as ``<iteration>.png``.
    save_models_folder: if not None, a checkpoint is written every ``save_niters_freq``
        iterations and after the last one; when ``resume_niters`` > 0 training resumes
        from the checkpoint saved at that iteration.

    All hyper-parameters (loss_type, niters, lr_g, lr_d, batch_size, num_D_steps, ...)
    are the module-level settings produced by ``parse_opts``.

    Returns (netG, netD).
    Raises ValueError for an unsupported ``loss_type`` or when fewer than ``batch_size``
    training images are given.
    '''
    if loss_type not in ("vanilla", "hinge"):
        raise ValueError("unsupported loss_type_gan '{}'; expected 'vanilla' or 'hinge'".format(loss_type))

    netG = netG.cuda()
    netD = netD.cuda()

    optimizerG = torch.optim.Adam(netG.parameters(), lr=lr_g, betas=(0.5, 0.999))
    optimizerD = torch.optim.Adam(netD.parameters(), lr=lr_d, betas=(0.5, 0.999))

    trainset = IMGs_dataset(images, labels, normalize=True)
    # drop_last=True guarantees every batch has exactly batch_size samples, matching
    # the fixed-size noise drawn for the generator
    train_dataloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=num_workers, drop_last=True)
    if len(train_dataloader) == 0:
        raise ValueError("need at least batch_size={} training images, got {}".format(batch_size, len(trainset)))
    batches = _cycle_batches(train_dataloader)

    if save_models_folder is not None and resume_niters>0:
        checkpoint = torch.load(_cgan_checkpoint_path(save_models_folder, resume_niters))
        netG.load_state_dict(checkpoint['netG_state_dict'])
        netD.load_state_dict(checkpoint['netD_state_dict'])
        optimizerG.load_state_dict(checkpoint['optimizerG_state_dict'])
        optimizerD.load_state_dict(checkpoint['optimizerD_state_dict'])
        torch.set_rng_state(checkpoint['rng_state'])

    # Fixed noise/labels for the visualisation grid: row i holds n_row samples of the
    # i-th selected class, classes being picked at a regular stride through the sorted
    # unique labels.
    n_row = 10
    z_fixed = torch.randn(n_row**2, dim_gan, dtype=torch.float).cuda()
    unique_labels = np.unique(np.asarray(labels)).astype(int)  # np.int was removed in NumPy 1.24
    indx_step_size = len(unique_labels)//n_row
    selected_labels = unique_labels[np.arange(n_row) * indx_step_size]
    y_fixed = torch.from_numpy(np.repeat(selected_labels, n_row)).type(torch.long).cuda()

    start_time = timeit.default_timer()
    for niter in range(resume_niters, niters):

        # ---- Train Generator: maximize log(D(G(z))) ----
        netG.train()

        # only the labels of this batch are used: they condition the generator
        _, batch_gen_labels = next(batches)
        batch_gen_labels = batch_gen_labels.type(torch.long).cuda()

        z = torch.randn(batch_size, dim_gan, dtype=torch.float).cuda()
        batch_fake_images = netG(z, batch_gen_labels)

        # loss measures the generator's ability to fool the discriminator
        if use_DiffAugment:
            dis_out = netD(DiffAugment(batch_fake_images, policy=policy), batch_gen_labels)
        else:
            dis_out = netD(batch_fake_images, batch_gen_labels)

        if loss_type == "vanilla":
            dis_out = torch.nn.Sigmoid()(dis_out)
            g_loss = - torch.mean(torch.log(dis_out+1e-20))
        else:  # "hinge" (validated above)
            g_loss = - dis_out.mean()

        optimizerG.zero_grad()
        g_loss.backward()
        optimizerG.step()

        # ---- Train Discriminator: maximize log(D(x)) + log(1 - D(G(z))) ----
        # the same detached fake batch is reused for every discriminator step
        fake_images_for_D = batch_fake_images.detach()
        for _ in range(num_D_steps):

            batch_train_images, batch_train_labels = next(batches)
            batch_train_images = batch_train_images.type(torch.float).cuda()
            batch_train_labels = batch_train_labels.type(torch.long).cuda()

            # fakes are scored with the labels they were generated from, not with the
            # labels of the current real batch
            if use_DiffAugment:
                real_dis_out = netD(DiffAugment(batch_train_images, policy=policy), batch_train_labels)
                fake_dis_out = netD(DiffAugment(fake_images_for_D, policy=policy), batch_gen_labels)
            else:
                real_dis_out = netD(batch_train_images, batch_train_labels)
                fake_dis_out = netD(fake_images_for_D, batch_gen_labels)

            if loss_type == "vanilla":
                real_dis_out = torch.nn.Sigmoid()(real_dis_out)
                fake_dis_out = torch.nn.Sigmoid()(fake_dis_out)
                d_loss_real = - torch.log(real_dis_out+1e-20)
                d_loss_fake = - torch.log(1-fake_dis_out+1e-20)
            else:  # "hinge"
                d_loss_real = torch.nn.ReLU()(1.0 - real_dis_out)
                d_loss_fake = torch.nn.ReLU()(1.0 + fake_dis_out)
            d_loss = (d_loss_real + d_loss_fake).mean()

            optimizerD.zero_grad()
            d_loss.backward()
            optimizerD.step()

        if (niter+1)%20 == 0:
            print ("cGAN-%s: [Iter %d/%d] [D loss: %.4f] [G loss: %.4f] [D out real:%.4f] [D out fake:%.4f] [Time: %.4f]" % (gan_arch, niter+1, niters, d_loss.item(), g_loss.item(), real_dis_out.mean().item(),fake_dis_out.mean().item(), timeit.default_timer()-start_time))

        if (niter+1) % visualize_freq == 0:
            netG.eval()
            with torch.no_grad():
                gen_imgs = netG(z_fixed, y_fixed)
            save_image(gen_imgs, save_images_folder +'/{}.png'.format(niter+1), nrow=n_row, normalize=True)

        if save_models_folder is not None and ((niter+1) % save_niters_freq == 0 or (niter+1) == niters):
            save_file = _cgan_checkpoint_path(save_models_folder, niter+1)
            os.makedirs(os.path.dirname(save_file), exist_ok=True)
            torch.save({
                    'netG_state_dict': netG.state_dict(),
                    'netD_state_dict': netD.state_dict(),
                    'optimizerG_state_dict': optimizerG.state_dict(),
                    'optimizerD_state_dict': optimizerD.state_dict(),
                    'rng_state': torch.get_rng_state()
            }, save_file)

    return netG, netD
193
-
194
-
195
def _labels_to_class_indices(given_labels, class_cutoff_points):
    '''
    Map raw labels to class indices defined by ascending ``class_cutoff_points``.

    With cutoffs c_0 < c_1 < ... < c_K there are K classes; class k covers [c_k, c_{k+1})
    and the last class also includes c_K. Raises ValueError for fewer than two cutoffs
    or for labels outside [c_0, c_K] (including NaN), which belong to no class.
    '''
    cutoffs = np.asarray(class_cutoff_points)
    if len(cutoffs) < 2:
        raise ValueError("class_cutoff_points needs at least two values, got {}".format(len(cutoffs)))
    num_classes = len(cutoffs) - 1
    flat_labels = np.ravel(np.asarray(given_labels))
    in_range = (flat_labels >= cutoffs[0]) & (flat_labels <= cutoffs[-1])
    if not np.all(in_range):
        raise ValueError("labels outside the cutoff range [{}, {}]: {}".format(cutoffs[0], cutoffs[-1], flat_labels[~in_range][:10]))
    # side='right' puts a label equal to c_k into class k; the clip folds the right
    # endpoint c_K into the last class
    class_idx = np.searchsorted(cutoffs, flat_labels, side='right') - 1
    return np.clip(class_idx, 0, num_classes - 1)


def sample_cgan_given_labels(netG, given_labels, class_cutoff_points, batch_size = 500, denorm=True, to_numpy=True, verbose=True):
    '''
    Generate one fake image per entry of ``given_labels`` with a class-conditional generator.

    given_labels: a numpy array of raw labels (no normalization; not class labels).
    class_cutoff_points: ascending cutoff points deciding the class of a given label;
        class k covers [c_k, c_{k+1}) and the last class also includes the final cutoff.
    batch_size: images generated per forward pass (capped at len(given_labels)).
    denorm: if True, map images from [-1, 1] to uint8 values in [0, 255].
    to_numpy: if True return a numpy array, otherwise a CPU tensor.

    Returns (fake_images, given_labels) with exactly len(given_labels) images.
    Raises ValueError for empty ``given_labels``, ``batch_size`` < 1, or labels outside
    the cutoff range.
    '''
    nfake = len(given_labels)
    if nfake == 0:
        raise ValueError("given_labels must contain at least one label")
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer, got {}".format(batch_size))
    batch_size = min(batch_size, nfake)

    given_class_labels = _labels_to_class_indices(given_labels, class_cutoff_points)
    # pad with the first batch so the last (possibly partial) batch can still be
    # generated at full size; the surplus images are trimmed below
    given_class_labels = np.concatenate((given_class_labels, given_class_labels[:batch_size]))

    fake_images = []
    netG = netG.cuda()
    netG.eval()
    with torch.no_grad():
        if verbose:
            pb = SimpleProgressBar()
        tmp = 0
        while tmp < nfake:
            z = torch.randn(batch_size, dim_gan, dtype=torch.float).cuda()
            labels = torch.from_numpy(given_class_labels[tmp:(tmp+batch_size)]).type(torch.long).cuda()
            batch_fake_images = netG(z, labels)
            if denorm:  # denorm imgs to uint8 to save memory
                assert batch_fake_images.max().item()<=1.0 and batch_fake_images.min().item()>=-1.0
                batch_fake_images = ((batch_fake_images*0.5+0.5)*255.0).type(torch.uint8)
            fake_images.append(batch_fake_images.detach().cpu())
            tmp += batch_size
            if verbose:
                pb.update(min(float(tmp)/nfake, 1)*100)

    # drop the images generated for the padding labels
    fake_images = torch.cat(fake_images, dim=0)[:nfake]
    if to_numpy:
        fake_images = fake_images.numpy()
    return fake_images, given_labels
@@ -1,242 +0,0 @@
1
-
2
- import torch
3
- import torch.nn as nn
4
- from torchvision.utils import save_image
5
- import numpy as np
6
- import os
7
- import timeit
8
-
9
- from .utils import IMGs_dataset, SimpleProgressBar
10
- from .opts import parse_opts
11
- from .DiffAugment_pytorch import DiffAugment
12
-
13
# Settings
args = parse_opts()

# --- optimisation schedule (all values come from the parsed command line) ---
loss_type = args.loss_type_gan
niters, resume_niters = args.niters_gan, args.resume_niters_gan
dim_gan = args.dim_gan
lr_g, lr_d = args.lr_g_gan, args.lr_d_gan
save_niters_freq = args.save_niters_freq
# one batch size is used for both generator and discriminator steps,
# so take the smaller of the two configured sizes
batch_size = min(args.batch_size_disc, args.batch_size_gene)

# --- model description ---
num_classes = args.cGAN_num_classes
gan_arch = args.GAN_arch
num_D_steps = args.num_D_steps

# --- bookkeeping / data loading ---
visualize_freq = args.visualize_freq
num_workers = args.num_workers

# --- image geometry and label scale (labels are divided by max_label when sampling) ---
NC, IMG_SIZE = args.num_channels, args.img_size
max_label = args.max_label

# --- DiffAugment ---
use_DiffAugment, policy = args.gan_DiffAugment, args.gan_DiffAugment_policy
39
-
40
-
41
def _cycle_batches(dataloader):
    '''Yield batches from ``dataloader`` forever, starting a new epoch whenever one ends.'''
    while True:
        for batch in dataloader:
            yield batch


def _cgan_checkpoint_path(save_models_folder, niters_done):
    '''Return the path of the in-training checkpoint written after ``niters_done`` iterations.'''
    # NOTE(review): same naming scheme as the class-conditional cGAN trainer, so runs that
    # share save_models_folder, GAN_arch and num_D_steps overwrite each other; kept as-is
    # so existing checkpoints can still be resumed.
    ckpt_dir = "cGAN_{}_nDsteps_{}_checkpoint_intrain".format(gan_arch, num_D_steps)
    return os.path.join(save_models_folder, ckpt_dir, "cGAN_checkpoint_niters_{}.pth".format(niters_done))


def train_cgan_concat(images, labels, netG, netD, save_images_folder, save_models_folder = None):
    '''
    Train a cGAN that is conditioned on continuous labels (label concatenation) on CUDA.

    images, labels: training images and their continuous labels; labels are fed to the
        networks as float tensors (presumably already normalised like the ones built by
        sample_cgan_concat_given_labels - TODO confirm). Images are wrapped in
        IMGs_dataset with normalize=True, i.e. pixels x -> (x/255 - 0.5)/0.5.
    save_images_folder: every ``visualize_freq`` iterations a grid generated from a
        fixed noise/label set is saved here as ``<iteration>.png``.
    save_models_folder: if not None, a checkpoint is written every ``save_niters_freq``
        iterations and after the last one; when ``resume_niters`` > 0 training resumes
        from the checkpoint saved at that iteration.

    All hyper-parameters come from the module-level settings produced by ``parse_opts``.

    Returns (netG, netD).
    Raises ValueError for an unsupported ``loss_type`` or when fewer than ``batch_size``
    training images are given.
    '''
    if loss_type not in ("vanilla", "hinge"):
        raise ValueError("unsupported loss_type_gan '{}'; expected 'vanilla' or 'hinge'".format(loss_type))

    netG = netG.cuda()
    netD = netD.cuda()

    optimizerG = torch.optim.Adam(netG.parameters(), lr=lr_g, betas=(0.5, 0.999))
    optimizerD = torch.optim.Adam(netD.parameters(), lr=lr_d, betas=(0.5, 0.999))

    trainset = IMGs_dataset(images, labels, normalize=True)
    # drop_last=True guarantees every batch has exactly batch_size samples, matching
    # the fixed-size noise drawn for the generator
    train_dataloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=num_workers, drop_last=True)
    if len(train_dataloader) == 0:
        raise ValueError("need at least batch_size={} training images, got {}".format(batch_size, len(trainset)))
    batches = _cycle_batches(train_dataloader)

    if save_models_folder is not None and resume_niters>0:
        checkpoint = torch.load(_cgan_checkpoint_path(save_models_folder, resume_niters))
        netG.load_state_dict(checkpoint['netG_state_dict'])
        netD.load_state_dict(checkpoint['netD_state_dict'])
        optimizerG.load_state_dict(checkpoint['optimizerG_state_dict'])
        optimizerD.load_state_dict(checkpoint['optimizerD_state_dict'])
        torch.set_rng_state(checkpoint['rng_state'])

    # printed images with labels between the 5-th quantile and 95-th quantile of training
    # labels: row i holds n_col samples of the i-th evenly spaced label value
    n_row = 10
    n_col = n_row
    z_fixed = torch.randn(n_row*n_col, dim_gan, dtype=torch.float).cuda()
    start_label = np.quantile(labels, 0.05)
    end_label = np.quantile(labels, 0.95)
    selected_labels = np.linspace(start_label, end_label, num=n_row)
    y_fixed = np.repeat(selected_labels, n_col)
    print(y_fixed)
    y_fixed = torch.from_numpy(y_fixed).type(torch.float).view(-1,1).cuda()

    start_time = timeit.default_timer()
    for niter in range(resume_niters, niters):

        # ---- Train Generator: maximize log(D(G(z))) ----
        netG.train()

        # only the labels of this batch are used: they condition the generator.
        # They are continuous, so keep them as floats - casting to long would truncate them.
        _, batch_gen_labels = next(batches)
        batch_gen_labels = batch_gen_labels.type(torch.float).cuda()

        z = torch.randn(batch_size, dim_gan, dtype=torch.float).cuda()
        batch_fake_images = netG(z, batch_gen_labels)

        # loss measures the generator's ability to fool the discriminator
        if use_DiffAugment:
            dis_out = netD(DiffAugment(batch_fake_images, policy=policy), batch_gen_labels)
        else:
            dis_out = netD(batch_fake_images, batch_gen_labels)

        if loss_type == "vanilla":
            dis_out = torch.nn.Sigmoid()(dis_out)
            g_loss = - torch.mean(torch.log(dis_out+1e-20))
        else:  # "hinge" (validated above)
            g_loss = - dis_out.mean()

        optimizerG.zero_grad()
        g_loss.backward()
        optimizerG.step()

        # ---- Train Discriminator: maximize log(D(x)) + log(1 - D(G(z))) ----
        # the same detached fake batch is reused for every discriminator step
        fake_images_for_D = batch_fake_images.detach()
        for _ in range(num_D_steps):

            batch_train_images, batch_train_labels = next(batches)
            batch_train_images = batch_train_images.type(torch.float).cuda()
            batch_train_labels = batch_train_labels.type(torch.float).cuda()

            # fakes are scored with the labels they were generated from, not with the
            # labels of the current real batch
            if use_DiffAugment:
                real_dis_out = netD(DiffAugment(batch_train_images, policy=policy), batch_train_labels)
                fake_dis_out = netD(DiffAugment(fake_images_for_D, policy=policy), batch_gen_labels)
            else:
                real_dis_out = netD(batch_train_images, batch_train_labels)
                fake_dis_out = netD(fake_images_for_D, batch_gen_labels)

            if loss_type == "vanilla":
                real_dis_out = torch.nn.Sigmoid()(real_dis_out)
                fake_dis_out = torch.nn.Sigmoid()(fake_dis_out)
                d_loss_real = - torch.log(real_dis_out+1e-20)
                d_loss_fake = - torch.log(1-fake_dis_out+1e-20)
            else:  # "hinge"
                d_loss_real = torch.nn.ReLU()(1.0 - real_dis_out)
                d_loss_fake = torch.nn.ReLU()(1.0 + fake_dis_out)
            d_loss = (d_loss_real + d_loss_fake).mean()

            optimizerD.zero_grad()
            d_loss.backward()
            optimizerD.step()

        if (niter+1)%20 == 0:
            print ("cGAN(concat)-%s: [Iter %d/%d] [D loss: %.4f] [G loss: %.4f] [D out real:%.4f] [D out fake:%.4f] [Time: %.4f]" % (gan_arch, niter+1, niters, d_loss.item(), g_loss.item(), real_dis_out.mean().item(),fake_dis_out.mean().item(), timeit.default_timer()-start_time))

        if (niter+1) % visualize_freq == 0:
            netG.eval()
            with torch.no_grad():
                gen_imgs = netG(z_fixed, y_fixed)
            save_image(gen_imgs, save_images_folder +'/{}.png'.format(niter+1), nrow=n_row, normalize=True)

        if save_models_folder is not None and ((niter+1) % save_niters_freq == 0 or (niter+1) == niters):
            save_file = _cgan_checkpoint_path(save_models_folder, niter+1)
            os.makedirs(os.path.dirname(save_file), exist_ok=True)
            torch.save({
                    'netG_state_dict': netG.state_dict(),
                    'netD_state_dict': netD.state_dict(),
                    'optimizerG_state_dict': optimizerG.state_dict(),
                    'optimizerD_state_dict': optimizerD.state_dict(),
                    'rng_state': torch.get_rng_state()
            }, save_file)

    return netG, netD
193
-
194
-
195
def sample_cgan_concat_given_labels(netG, given_labels, batch_size = 100, denorm=True, to_numpy=True, verbose=True):
    '''
    Generate one fake image per entry of ``given_labels`` with a label-concatenating cGAN.

    netG: pretrained generator network, called as netG(z, c) where c is the float tensor
        slice of normalised labels for the batch.
    given_labels: unnormalized (raw) labels, as a numpy array or sequence. They are
        divided by the module-level ``max_label`` before being fed to netG, so labels in
        [0, max_label] map into [0, 1].
    batch_size: images generated per forward pass (capped at len(given_labels)).
    denorm: if True, map images from [-1, 1] to uint8 values in [0, 255].
    to_numpy: if True return a numpy array, otherwise a CPU tensor.

    Returns (fake_images, given_labels) with exactly len(given_labels) images.
    Raises ValueError for empty ``given_labels`` or ``batch_size`` < 1.
    '''
    ## num of fake images will be generated
    nfake = len(given_labels)
    if nfake == 0:
        raise ValueError("given_labels must contain at least one label")
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer, got {}".format(batch_size))
    batch_size = min(batch_size, nfake)

    ## normalize regression labels; asarray also accepts plain Python sequences
    labels = np.asarray(given_labels) / max_label

    ## pad with the first batch so the last (possibly partial) batch can still be
    ## generated at full size; the surplus images are trimmed below
    labels = np.concatenate((labels, labels[0:batch_size]), axis=0)

    fake_images = []
    netG = netG.cuda()
    netG.eval()
    with torch.no_grad():
        if verbose:
            pb = SimpleProgressBar()
        tmp = 0
        while tmp < nfake:
            z = torch.randn(batch_size, dim_gan, dtype=torch.float).cuda()
            c = torch.from_numpy(labels[tmp:(tmp+batch_size)]).type(torch.float).cuda()
            batch_fake_images = netG(z, c)
            if denorm:  # denorm imgs to uint8 to save memory
                assert batch_fake_images.max().item()<=1.0 and batch_fake_images.min().item()>=-1.0
                batch_fake_images = ((batch_fake_images*0.5+0.5)*255.0).type(torch.uint8)
            fake_images.append(batch_fake_images.detach().cpu())
            tmp += batch_size
            if verbose:
                pb.update(min(float(tmp)/nfake, 1)*100)

    # drop the images generated for the padding labels
    fake_images = torch.cat(fake_images, dim=0)[:nfake]
    if to_numpy:
        fake_images = fake_images.numpy()
    return fake_images, given_labels
@@ -1,120 +0,0 @@
1
- """
2
- Some helpful functions
3
-
4
- """
5
- import numpy as np
6
- import torch
7
- import torch.nn as nn
8
- import torchvision
9
- import matplotlib.pyplot as plt
10
- import matplotlib as mpl
11
- from torch.nn import functional as F
12
- import sys
13
- import PIL
14
- from PIL import Image
15
-
16
-
17
-
18
- # ################################################################################
19
- # Progress Bar
20
class SimpleProgressBar:
    '''
    Single-line text progress bar drawn on stdout.

    ``update(x)`` takes the progress ``x`` in percent (0-100) and redraws the bar only
    when the integer part of ``x`` differs from the last drawn value; a newline is
    printed once ``x`` reaches exactly 100.
    '''

    def __init__(self, width=50):
        self.last_x = -1    # integer percent drawn last; -1 means nothing drawn yet
        self.width = width  # number of characters between the brackets

    def update(self, x):
        assert 0 <= x <= 100 # `x`: progress in percent ( between 0 and 100)
        percent = int(x)
        if percent != self.last_x:
            self.last_x = percent
            filled = int(self.width * (x / 100.0))
            bar = '#' * filled + '.' * (self.width - filled)
            sys.stdout.write('\r%d%% [%s]' % (percent, bar))
            sys.stdout.flush()
            if x == 100:
                print('')
34
-
35
-
36
-
37
- ################################################################################
38
- # torch dataset from numpy array
39
class IMGs_dataset(torch.utils.data.Dataset):
    '''
    Map-style torch dataset over in-memory images with optional labels.

    images: indexable collection of images (e.g. a numpy array). Items are returned as
        stored, or rescaled via x/255 then (x-0.5)/0.5 (so 0 -> -1 and 255 -> 1) when
        ``normalize`` is True.
    labels: optional indexable collection with one label per image; when given,
        items are (image, label) tuples.
    Raises ValueError if images and labels have different lengths.
    '''
    def __init__(self, images, labels=None, normalize=False):
        super().__init__()

        self.images = images
        self.n_images = len(self.images)
        self.labels = labels
        if labels is not None and len(self.images) != len(self.labels):
            # ValueError subclasses Exception, so callers catching Exception still work
            raise ValueError('images (' + str(len(self.images)) +') and labels ('+str(len(self.labels))+') do not have the same length!!!')
        self.normalize = normalize

    def __getitem__(self, index):
        image = self.images[index]

        if self.normalize:
            # [0, 255] -> [0, 1] -> [-1, 1]
            image = image/255.0
            image = (image-0.5)/0.5

        if self.labels is not None:
            return (image, self.labels[index])
        return image

    def __len__(self):
        return self.n_images
68
-
69
-
70
def PlotLoss(loss, filename):
    '''
    Plot a training-loss curve (one value per epoch) and save it to ``filename``.

    loss: sequence of loss values, plotted at x = 1..len(loss).
    filename: output path passed to ``savefig``.
    Side effects: switches matplotlib to the non-interactive 'agg' backend and sets the
    global seaborn style when one is available.
    '''
    x_axis = np.arange(start = 1, stop = len(loss)+1)
    plt.switch_backend('agg')
    # The 'seaborn' style was renamed 'seaborn-v0_8' in matplotlib 3.6 and the old name
    # was removed in 3.8; use whichever name this installation provides.
    for style_name in ('seaborn-v0_8', 'seaborn'):
        if style_name in plt.style.available:
            mpl.style.use(style_name)
            break
    fig = plt.figure()
    ax = fig.add_subplot(111)
    ax.plot(x_axis, np.array(loss))
    ax.set_xlabel("epoch")
    ax.set_ylabel("training loss")
    # no legend: the single curve carries no label, so legend() would only warn
    fig.savefig(filename)
    # release the figure so repeated calls do not accumulate open figures
    plt.close(fig)
83
-
84
-
85
def compute_entropy(labels, base=None):
    '''
    Return the entropy of the empirical distribution of ``labels``.

    labels: numpy array (or array-like) of class labels.
    base: logarithm base; the natural logarithm is used when None.
    '''
    _, frequencies = np.unique(labels, return_counts=True)
    probabilities = frequencies / frequencies.sum()
    log_base = np.log(np.e if base is None else base)
    return -np.sum(probabilities * np.log(probabilities) / log_base)
91
-
92
def predict_class_labels(net, images, batch_size=500, verbose=False, num_workers=0):
    '''
    Predict a class index for every image with a classifier network.

    net: classifier moved to CUDA and put in eval mode; its forward returns a 2-tuple
        whose first element's arg-max over dim 1 is taken as the predicted class.
    images: images fed through IMGs_dataset with normalize=False, i.e. not rescaled here
        (presumably already in the range the classifier expects - TODO confirm).
    batch_size: images per forward pass (capped at len(images)).
    verbose: show a SimpleProgressBar.
    num_workers: DataLoader worker processes.

    Returns a float numpy array of length len(images) with the predicted class indices
    (empty for empty input).
    '''
    net = net.cuda()
    net.eval()

    n = len(images)
    class_labels_pred = np.zeros(n)
    if n == 0:
        # nothing to predict; DataLoader would also reject batch_size=0
        return class_labels_pred
    batch_size = min(batch_size, n)
    dataset_pred = IMGs_dataset(images, normalize=False)
    dataloader_pred = torch.utils.data.DataLoader(dataset_pred, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    with torch.no_grad():
        nimgs_got = 0
        if verbose:
            pb = SimpleProgressBar()
        for batch_images in dataloader_pred:
            batch_images = batch_images.type(torch.float).cuda()
            batch_size_curr = len(batch_images)

            outputs, _ = net(batch_images)
            _, batch_class_labels_pred = torch.max(outputs, 1)
            class_labels_pred[nimgs_got:(nimgs_got+batch_size_curr)] = batch_class_labels_pred.cpu().numpy().reshape(-1)

            nimgs_got += batch_size_curr
            if verbose:
                pb.update((float(nimgs_got)/n)*100)
    return class_labels_pred
@@ -1,76 +0,0 @@
1
- # Differentiable Augmentation for Data-Efficient GAN Training
2
- # Shengyu Zhao, Zhijian Liu, Ji Lin, Jun-Yan Zhu, and Song Han
3
- # https://arxiv.org/pdf/2006.10738
4
-
5
- import torch
6
- import torch.nn.functional as F
7
-
8
-
9
def DiffAugment(x, policy='', channels_first=True):
    """Apply a comma-separated differentiable augmentation ``policy`` to image batch ``x``.

    Every name in ``policy`` is looked up in AUGMENT_FNS and its functions are applied
    in order. With ``channels_first=False`` the batch is treated as NHWC and is
    temporarily permuted to NCHW for the augmentations. An empty policy returns ``x``
    itself, untouched; otherwise the result is made contiguous.
    """
    if not policy:
        return x
    if not channels_first:
        x = x.permute(0, 3, 1, 2)
    for name in policy.split(','):
        for augment in AUGMENT_FNS[name]:
            x = augment(x)
    if not channels_first:
        x = x.permute(0, 2, 3, 1)
    return x.contiguous()
20
-
21
-
22
def rand_brightness(x):
    """Add a per-image random offset drawn uniformly from [-0.5, 0.5) to every pixel."""
    offset = torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) - 0.5
    return x + offset
25
-
26
-
27
def rand_saturation(x):
    """Scale each pixel's deviation from its mean over channels by a per-image factor in [0, 2)."""
    channel_mean = x.mean(dim=1, keepdim=True)
    factor = torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) * 2
    return (x - channel_mean) * factor + channel_mean
31
-
32
-
33
def rand_contrast(x):
    """Scale each image's deviation from its overall mean by a random factor in [0.5, 1.5)."""
    image_mean = x.mean(dim=[1, 2, 3], keepdim=True)
    factor = torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) + 0.5
    return (x - image_mean) * factor + image_mean
37
-
38
-
39
def rand_translation(x, ratio=0.125):
    """Randomly shift each NCHW image by up to ``ratio`` of its height/width, filling with zeros.

    Each image gets its own integer shift in [-round(H*ratio), round(H*ratio)] along
    dim 2 and likewise along dim 3; pixels shifted in from outside the image are 0.
    """
    shift_x, shift_y = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5)
    translation_x = torch.randint(-shift_x, shift_x + 1, size=[x.size(0), 1, 1], device=x.device)
    translation_y = torch.randint(-shift_y, shift_y + 1, size=[x.size(0), 1, 1], device=x.device)
    # indexing='ij' is the historical default; passing it explicitly silences the
    # torch.meshgrid warning about the argument becoming mandatory
    grid_batch, grid_x, grid_y = torch.meshgrid(
        torch.arange(x.size(0), dtype=torch.long, device=x.device),
        torch.arange(x.size(2), dtype=torch.long, device=x.device),
        torch.arange(x.size(3), dtype=torch.long, device=x.device),
        indexing='ij',
    )
    # +1 accounts for the one-pixel zero border added below; clamping onto that border
    # makes out-of-image source positions read zeros
    grid_x = torch.clamp(grid_x + translation_x + 1, 0, x.size(2) + 1)
    grid_y = torch.clamp(grid_y + translation_y + 1, 0, x.size(3) + 1)
    x_pad = F.pad(x, [1, 1, 1, 1, 0, 0, 0, 0])
    x = x_pad.permute(0, 2, 3, 1).contiguous()[grid_batch, grid_x, grid_y].permute(0, 3, 1, 2)
    return x
53
-
54
-
55
def rand_cutout(x, ratio=0.5):
    """Zero out one random rectangle per NCHW image (same rectangle for all channels).

    The rectangle is about ``ratio`` of the image height by ``ratio`` of its width,
    centred at a random position and clipped to the image borders.
    """
    cutout_size = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5)
    offset_x = torch.randint(0, x.size(2) + (1 - cutout_size[0] % 2), size=[x.size(0), 1, 1], device=x.device)
    offset_y = torch.randint(0, x.size(3) + (1 - cutout_size[1] % 2), size=[x.size(0), 1, 1], device=x.device)
    # indexing='ij' is the historical default; passing it explicitly silences the
    # torch.meshgrid warning about the argument becoming mandatory
    grid_batch, grid_x, grid_y = torch.meshgrid(
        torch.arange(x.size(0), dtype=torch.long, device=x.device),
        torch.arange(cutout_size[0], dtype=torch.long, device=x.device),
        torch.arange(cutout_size[1], dtype=torch.long, device=x.device),
        indexing='ij',
    )
    grid_x = torch.clamp(grid_x + offset_x - cutout_size[0] // 2, min=0, max=x.size(2) - 1)
    grid_y = torch.clamp(grid_y + offset_y - cutout_size[1] // 2, min=0, max=x.size(3) - 1)
    # one (N, H, W) mask shared by all channels via unsqueeze(1)
    mask = torch.ones(x.size(0), x.size(2), x.size(3), dtype=x.dtype, device=x.device)
    mask[grid_batch, grid_x, grid_y] = 0
    x = x * mask.unsqueeze(1)
    return x
70
-
71
-
72
# Policy name -> augmentation functions that DiffAugment applies, in this order.
AUGMENT_FNS = dict(
    color=[rand_brightness, rand_saturation, rand_contrast],
    translation=[rand_translation],
    cutout=[rand_cutout],
)
File without changes