Myosotis-Researches 0.1.7__py3-none-any.whl → 0.1.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49) hide show
  1. myosotis_researches/CcGAN/train/__init__.py +4 -0
  2. myosotis_researches/CcGAN/{train_128_output_10 → train}/train_ccgan.py +4 -4
  3. myosotis_researches/CcGAN/{train_128 → train}/train_cgan.py +1 -3
  4. myosotis_researches/CcGAN/{train_128 → train}/train_cgan_concat.py +1 -3
  5. myosotis_researches/CcGAN/utils/__init__.py +2 -1
  6. myosotis_researches/CcGAN/utils/train.py +94 -3
  7. {myosotis_researches-0.1.7.dist-info → myosotis_researches-0.1.9.dist-info}/METADATA +1 -1
  8. myosotis_researches-0.1.9.dist-info/RECORD +24 -0
  9. myosotis_researches/CcGAN/models_128/CcGAN_SAGAN.py +0 -301
  10. myosotis_researches/CcGAN/models_128/ResNet_class_eval.py +0 -141
  11. myosotis_researches/CcGAN/models_128/ResNet_embed.py +0 -188
  12. myosotis_researches/CcGAN/models_128/ResNet_regre_eval.py +0 -175
  13. myosotis_researches/CcGAN/models_128/__init__.py +0 -7
  14. myosotis_researches/CcGAN/models_128/autoencoder.py +0 -119
  15. myosotis_researches/CcGAN/models_128/cGAN_SAGAN.py +0 -276
  16. myosotis_researches/CcGAN/models_128/cGAN_concat_SAGAN.py +0 -245
  17. myosotis_researches/CcGAN/models_256/CcGAN_SAGAN.py +0 -303
  18. myosotis_researches/CcGAN/models_256/ResNet_class_eval.py +0 -142
  19. myosotis_researches/CcGAN/models_256/ResNet_embed.py +0 -188
  20. myosotis_researches/CcGAN/models_256/ResNet_regre_eval.py +0 -178
  21. myosotis_researches/CcGAN/models_256/__init__.py +0 -7
  22. myosotis_researches/CcGAN/models_256/autoencoder.py +0 -133
  23. myosotis_researches/CcGAN/models_256/cGAN_SAGAN.py +0 -280
  24. myosotis_researches/CcGAN/models_256/cGAN_concat_SAGAN.py +0 -249
  25. myosotis_researches/CcGAN/train_128/DiffAugment_pytorch.py +0 -76
  26. myosotis_researches/CcGAN/train_128/__init__.py +0 -0
  27. myosotis_researches/CcGAN/train_128/eval_metrics.py +0 -205
  28. myosotis_researches/CcGAN/train_128/opts.py +0 -87
  29. myosotis_researches/CcGAN/train_128/pretrain_AE.py +0 -268
  30. myosotis_researches/CcGAN/train_128/pretrain_CNN_class.py +0 -251
  31. myosotis_researches/CcGAN/train_128/pretrain_CNN_regre.py +0 -255
  32. myosotis_researches/CcGAN/train_128/train_ccgan.py +0 -303
  33. myosotis_researches/CcGAN/train_128/utils.py +0 -120
  34. myosotis_researches/CcGAN/train_128_output_10/DiffAugment_pytorch.py +0 -76
  35. myosotis_researches/CcGAN/train_128_output_10/__init__.py +0 -0
  36. myosotis_researches/CcGAN/train_128_output_10/eval_metrics.py +0 -205
  37. myosotis_researches/CcGAN/train_128_output_10/opts.py +0 -87
  38. myosotis_researches/CcGAN/train_128_output_10/pretrain_AE.py +0 -268
  39. myosotis_researches/CcGAN/train_128_output_10/pretrain_CNN_class.py +0 -251
  40. myosotis_researches/CcGAN/train_128_output_10/pretrain_CNN_regre.py +0 -255
  41. myosotis_researches/CcGAN/train_128_output_10/train_cgan.py +0 -254
  42. myosotis_researches/CcGAN/train_128_output_10/train_cgan_concat.py +0 -242
  43. myosotis_researches/CcGAN/train_128_output_10/train_net_for_label_embed.py +0 -181
  44. myosotis_researches/CcGAN/train_128_output_10/utils.py +0 -120
  45. myosotis_researches-0.1.7.dist-info/RECORD +0 -59
  46. /myosotis_researches/CcGAN/{train_128 → train}/train_net_for_label_embed.py +0 -0
  47. {myosotis_researches-0.1.7.dist-info → myosotis_researches-0.1.9.dist-info}/WHEEL +0 -0
  48. {myosotis_researches-0.1.7.dist-info → myosotis_researches-0.1.9.dist-info}/licenses/LICENSE +0 -0
  49. {myosotis_researches-0.1.7.dist-info → myosotis_researches-0.1.9.dist-info}/top_level.txt +0 -0
@@ -1,255 +0,0 @@
1
- """
2
-
3
- Pre-train a CNN on the whole dataset for evaluation purpose
4
-
5
- """
6
- import os
7
- import argparse
8
- import shutil
9
- import timeit
10
- import torch
11
- import torchvision
12
- import torchvision.transforms as transforms
13
- import numpy as np
14
- import torch.nn as nn
15
- import torch.backends.cudnn as cudnn
16
- import random
17
- import matplotlib.pyplot as plt
18
- import matplotlib as mpl
19
- from torch import autograd
20
- from torchvision.utils import save_image
21
- import csv
22
- from tqdm import tqdm
23
- import gc
24
- import h5py
25
-
26
-
27
- #############################
28
- # Settings
29
- #############################
30
-
31
- parser = argparse.ArgumentParser(description='Pre-train CNNs')
32
- parser.add_argument('--root_path', type=str, default='')
33
- parser.add_argument('--data_path', type=str, default='')
34
- parser.add_argument('--num_workers', type=int, default=0)
35
- parser.add_argument('--CNN', type=str, default='ResNet34_regre',
36
- help='CNN for training; ResNetXX')
37
- parser.add_argument('--epochs', type=int, default=200, metavar='N',
38
- help='number of epochs to train CNNs (default: 200)')
39
- parser.add_argument('--batch_size_train', type=int, default=256, metavar='N',
40
- help='input batch size for training')
41
- parser.add_argument('--batch_size_valid', type=int, default=10, metavar='N',
42
- help='input batch size for testing')
43
- parser.add_argument('--base_lr', type=float, default=0.01,
44
- help='learning rate, default=0.1')
45
- parser.add_argument('--weight_dacay', type=float, default=1e-4,
46
- help='Weigth decay, default=1e-4')
47
- parser.add_argument('--seed', type=int, default=2020, metavar='S',
48
- help='random seed (default: 1)')
49
- parser.add_argument('--CVMode', action='store_true', default=False,
50
- help='CV mode?')
51
- parser.add_argument('--img_size', type=int, default=128, metavar='N')
52
- parser.add_argument('--min_label', type=float, default=0.0)
53
- parser.add_argument('--max_label', type=float, default=90.0)
54
- args = parser.parse_args()
55
-
56
-
57
- wd = args.root_path
58
- os.chdir(wd)
59
- from ..models_128 import *
60
- from .utils import IMGs_dataset
61
-
62
- # cuda
63
- device = torch.device("cuda")
64
- ngpu = torch.cuda.device_count() # number of gpus
65
-
66
- # random seed
67
- random.seed(args.seed)
68
- torch.manual_seed(args.seed)
69
- torch.backends.cudnn.deterministic = True
70
- np.random.seed(args.seed)
71
-
72
- # directories for checkpoint, images and log files
73
- save_models_folder = wd + '/output/eval_models/'
74
- os.makedirs(save_models_folder, exist_ok=True)
75
-
76
-
77
- ###########################################################################################################
78
- # Data
79
- ###########################################################################################################
80
- # data loader
81
- data_filename = args.data_path + '/Ra_' + str(args.img_size) + 'x' + str(args.img_size) + '.h5'
82
- hf = h5py.File(data_filename, 'r')
83
- labels = hf['labels'][:]
84
- labels = labels.astype(float)
85
- images = hf['images'][:]
86
- hf.close()
87
- N_all = len(images)
88
- assert len(images) == len(labels)
89
-
90
- q1 = args.min_label
91
- q2 = args.max_label
92
- indx = np.where((labels>q1)*(labels<q2)==True)[0]
93
- labels = labels[indx]
94
- images = images[indx]
95
- assert len(labels)==len(images)
96
-
97
-
98
- # normalize to [0, 1]
99
- #min_label = np.min(labels)
100
- #labels += np.abs(min_label)
101
- #max_label = np.max(labels)
102
- #labels /= max_label
103
- labels /= args.max_label
104
-
105
-
106
- # define training and validation sets
107
- if args.CVMode:
108
- #80% Training; 20% valdation
109
- indx_all = np.arange(len(labels))
110
- np.random.shuffle(indx_all)
111
- indx_valid = indx_all[0:int(0.2*len(labels))]
112
- indx_train = indx_all[int(0.2*len(labels)):]
113
-
114
- trainset = IMGs_dataset(images[indx_train], labels[indx_train], normalize=True)
115
- trainloader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size_train, shuffle=True, num_workers=args.num_workers)
116
- validset = IMGs_dataset(images[indx_valid], labels[indx_valid], normalize=True)
117
- validloader = torch.utils.data.DataLoader(validset, batch_size=args.batch_size_valid, shuffle=False, num_workers=args.num_workers)
118
-
119
- else:
120
- trainset = IMGs_dataset(images, labels, normalize=True)
121
- trainloader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size_train, shuffle=True, num_workers=args.num_workers)
122
-
123
-
124
-
125
-
126
- ###########################################################################################################
127
- # Necessary functions
128
- ###########################################################################################################
129
-
130
- #initialize CNNs
131
- def net_initialization(Pretrained_CNN_Name, ngpu = 1):
132
- if Pretrained_CNN_Name == "ResNet18_regre":
133
- net = ResNet18_regre_eval(ngpu = ngpu)
134
- elif Pretrained_CNN_Name == "ResNet34_regre":
135
- net = ResNet34_regre_eval(ngpu = ngpu)
136
- elif Pretrained_CNN_Name == "ResNet50_regre":
137
- net = ResNet50_regre_eval(ngpu = ngpu)
138
- elif Pretrained_CNN_Name == "ResNet101_regre":
139
- net = ResNet101_regre_eval(ngpu = ngpu)
140
-
141
- net_name = 'PreCNNForEvalGANs_' + Pretrained_CNN_Name #get the net's name
142
- net = net.to(device)
143
-
144
- return net, net_name
145
-
146
- #adjust CNN learning rate
147
- def adjust_learning_rate(optimizer, epoch, BASE_LR_CNN):
148
- lr = BASE_LR_CNN
149
- # if epoch >= 35:
150
- # lr /= 10
151
- # if epoch >= 70:
152
- # lr /= 10
153
- if epoch >= 50:
154
- lr /= 10
155
- if epoch >= 120:
156
- lr /= 10
157
- for param_group in optimizer.param_groups:
158
- param_group['lr'] = lr
159
-
160
-
161
- def train_CNN():
162
-
163
- start_tmp = timeit.default_timer()
164
- for epoch in range(args.epochs):
165
- net.train()
166
- train_loss = 0
167
- adjust_learning_rate(optimizer, epoch, args.base_lr)
168
- for batch_idx, (batch_train_images, batch_train_labels) in enumerate(trainloader):
169
-
170
- # batch_train_images = nn.functional.interpolate(batch_train_images, size = (299,299), scale_factor=None, mode='bilinear', align_corners=False)
171
-
172
- batch_train_images = batch_train_images.type(torch.float).cuda()
173
- batch_train_labels = batch_train_labels.type(torch.float).view(-1,1).cuda()
174
-
175
- #Forward pass
176
- outputs,_ = net(batch_train_images)
177
- loss = criterion(outputs, batch_train_labels)
178
-
179
- #backward pass
180
- optimizer.zero_grad()
181
- loss.backward()
182
- optimizer.step()
183
-
184
- train_loss += loss.cpu().item()
185
- #end for batch_idx
186
- train_loss = train_loss / len(trainloader)
187
-
188
- if args.CVMode:
189
- valid_loss = valid_CNN(verbose=False)
190
- print('CNN: [epoch %d/%d] train_loss:%f valid_loss (avg_abs):%f Time: %.4f' % (epoch+1, args.epochs, train_loss, valid_loss, timeit.default_timer()-start_tmp))
191
- else:
192
- print('CNN: [epoch %d/%d] train_loss:%f Time:%.4f' % (epoch+1, args.epochs, train_loss, timeit.default_timer()-start_tmp))
193
- #end for epoch
194
-
195
- return net, optimizer
196
-
197
- if args.CVMode:
198
- def valid_CNN(verbose=True):
199
- net.eval() # eval mode (batchnorm uses moving mean/variance instead of mini-batch mean/variance)
200
- with torch.no_grad():
201
- abs_diff_avg = 0
202
- total = 0
203
- for batch_idx, (images, labels) in enumerate(validloader):
204
- images = images.type(torch.float).cuda()
205
- labels = labels.type(torch.float).view(-1).cpu().numpy()
206
- outputs,_ = net(images)
207
- outputs = outputs.view(-1).cpu().numpy()
208
- labels = labels * args.max_label
209
- outputs = outputs * args.max_label
210
- abs_diff_avg += np.sum(np.abs(labels-outputs))
211
- total += len(labels)
212
-
213
- if verbose:
214
- print('Validation Average Absolute Difference: {}'.format(abs_diff_avg/total))
215
- return abs_diff_avg/total
216
-
217
-
218
-
219
- ###########################################################################################################
220
- # Training and validation
221
- ###########################################################################################################
222
-
223
-
224
- # model initialization
225
- net, net_name = net_initialization(args.CNN, ngpu = ngpu)
226
- criterion = nn.MSELoss()
227
- optimizer = torch.optim.SGD(net.parameters(), lr = args.base_lr, momentum= 0.9, weight_decay=args.weight_dacay)
228
-
229
- filename_ckpt = save_models_folder + '/ckpt_{}_epoch_{}_seed_{}_CVMode_{}.pth'.format(net_name, args.epochs, args.seed, args.CVMode)
230
-
231
-
232
- # training
233
- if not os.path.isfile(filename_ckpt):
234
- # TRAIN CNN
235
- print("\n Begin training CNN: ")
236
- start = timeit.default_timer()
237
- net, optimizer = train_CNN()
238
- stop = timeit.default_timer()
239
- print("Time elapses: {}s".format(stop - start))
240
- # save model
241
- torch.save({
242
- 'net_state_dict': net.state_dict(),
243
- }, filename_ckpt)
244
- else:
245
- print("\n Ckpt already exists")
246
- print("\n Loading...")
247
- checkpoint = torch.load(filename_ckpt)
248
- net.load_state_dict(checkpoint['net_state_dict'])
249
- torch.cuda.empty_cache()#release GPU mem which is not references
250
-
251
-
252
- if args.CVMode:
253
- #validation
254
- _ = valid_CNN(True)
255
- torch.cuda.empty_cache()
@@ -1,303 +0,0 @@
1
- import torch
2
- import numpy as np
3
- import os
4
- import timeit
5
- from PIL import Image
6
- from torchvision.utils import save_image
7
- import torch.cuda as cutorch
8
-
9
- from .utils import SimpleProgressBar, IMGs_dataset
10
- from .opts import parse_opts
11
- from .DiffAugment_pytorch import DiffAugment
12
-
13
- ''' Settings '''
14
- args = parse_opts()
15
-
16
- # some parameters in opts
17
- gan_arch = args.GAN_arch
18
- loss_type = args.loss_type_gan
19
- niters = args.niters_gan
20
- resume_niters = args.resume_niters_gan
21
- dim_gan = args.dim_gan
22
- lr_g = args.lr_g_gan
23
- lr_d = args.lr_d_gan
24
- save_niters_freq = args.save_niters_freq
25
- batch_size_disc = args.batch_size_disc
26
- batch_size_gene = args.batch_size_gene
27
- # batch_size_max = max(batch_size_disc, batch_size_gene)
28
- num_D_steps = args.num_D_steps
29
-
30
- visualize_freq = args.visualize_freq
31
-
32
- num_workers = args.num_workers
33
-
34
- threshold_type = args.threshold_type
35
- nonzero_soft_weight_threshold = args.nonzero_soft_weight_threshold
36
-
37
- num_channels = args.num_channels
38
- img_size = args.img_size
39
- max_label = args.max_label
40
-
41
- use_DiffAugment = args.gan_DiffAugment
42
- policy = args.gan_DiffAugment_policy
43
-
44
-
45
- ## normalize images
46
- def normalize_images(batch_images):
47
- batch_images = batch_images/255.0
48
- batch_images = (batch_images - 0.5)/0.5
49
- return batch_images
50
-
51
-
52
- def train_ccgan(kernel_sigma, kappa, train_images, train_labels, netG, netD, net_y2h, save_images_folder, save_models_folder = None, clip_label=False):
53
-
54
- '''
55
- Note that train_images are not normalized to [-1,1]
56
- '''
57
-
58
- netG = netG.cuda()
59
- netD = netD.cuda()
60
- net_y2h = net_y2h.cuda()
61
- net_y2h.eval()
62
-
63
- optimizerG = torch.optim.Adam(netG.parameters(), lr=lr_g, betas=(0.5, 0.999))
64
- optimizerD = torch.optim.Adam(netD.parameters(), lr=lr_d, betas=(0.5, 0.999))
65
-
66
- if save_models_folder is not None and resume_niters>0:
67
- save_file = save_models_folder + "/CcGAN_{}_{}_nDsteps_{}_checkpoint_intrain/CcGAN_checkpoint_niters_{}.pth".format(gan_arch, threshold_type, num_D_steps, resume_niters)
68
- checkpoint = torch.load(save_file)
69
- netG.load_state_dict(checkpoint['netG_state_dict'])
70
- netD.load_state_dict(checkpoint['netD_state_dict'])
71
- optimizerG.load_state_dict(checkpoint['optimizerG_state_dict'])
72
- optimizerD.load_state_dict(checkpoint['optimizerD_state_dict'])
73
- torch.set_rng_state(checkpoint['rng_state'])
74
- #end if
75
-
76
- #################
77
- unique_train_labels = np.sort(np.array(list(set(train_labels))))
78
-
79
- # printed images with labels between the 5-th quantile and 95-th quantile of training labels
80
- n_row=10; n_col = n_row
81
- z_fixed = torch.randn(n_row*n_col, dim_gan, dtype=torch.float).cuda()
82
- start_label = np.quantile(train_labels, 0.05)
83
- end_label = np.quantile(train_labels, 0.95)
84
- selected_labels = np.linspace(start_label, end_label, num=n_row)
85
- y_fixed = np.zeros(n_row*n_col)
86
- for i in range(n_row):
87
- curr_label = selected_labels[i]
88
- for j in range(n_col):
89
- y_fixed[i*n_col+j] = curr_label
90
- print(y_fixed)
91
- y_fixed = torch.from_numpy(y_fixed).type(torch.float).view(-1,1).cuda()
92
-
93
-
94
- start_time = timeit.default_timer()
95
- for niter in range(resume_niters, niters):
96
-
97
- ''' Train Discriminator '''
98
- for _ in range(num_D_steps):
99
-
100
- ## randomly draw batch_size_disc y's from unique_train_labels
101
- batch_target_labels_in_dataset = np.random.choice(unique_train_labels, size=batch_size_disc, replace=True)
102
- ## add Gaussian noise; we estimate image distribution conditional on these labels
103
- batch_epsilons = np.random.normal(0, kernel_sigma, batch_size_disc)
104
- batch_target_labels = batch_target_labels_in_dataset + batch_epsilons
105
-
106
- ## find index of real images with labels in the vicinity of batch_target_labels
107
- ## generate labels for fake image generation; these labels are also in the vicinity of batch_target_labels
108
- batch_real_indx = np.zeros(batch_size_disc, dtype=int) #index of images in the datata; the labels of these images are in the vicinity
109
- batch_fake_labels = np.zeros(batch_size_disc)
110
-
111
- for j in range(batch_size_disc):
112
- ## index for real images
113
- if threshold_type == "hard":
114
- indx_real_in_vicinity = np.where(np.abs(train_labels-batch_target_labels[j])<= kappa)[0]
115
- else:
116
- # reverse the weight function for SVDL
117
- indx_real_in_vicinity = np.where((train_labels-batch_target_labels[j])**2 <= -np.log(nonzero_soft_weight_threshold)/kappa)[0]
118
-
119
- ## if the max gap between two consecutive ordered unique labels is large, it is possible that len(indx_real_in_vicinity)<1
120
- while len(indx_real_in_vicinity)<1:
121
- batch_epsilons_j = np.random.normal(0, kernel_sigma, 1)
122
- batch_target_labels[j] = batch_target_labels_in_dataset[j] + batch_epsilons_j
123
- if clip_label:
124
- batch_target_labels = np.clip(batch_target_labels, 0.0, 1.0)
125
- ## index for real images
126
- if threshold_type == "hard":
127
- indx_real_in_vicinity = np.where(np.abs(train_labels-batch_target_labels[j])<= kappa)[0]
128
- else:
129
- # reverse the weight function for SVDL
130
- indx_real_in_vicinity = np.where((train_labels-batch_target_labels[j])**2 <= -np.log(nonzero_soft_weight_threshold)/kappa)[0]
131
- #end while len(indx_real_in_vicinity)<1
132
-
133
- assert len(indx_real_in_vicinity)>=1
134
-
135
- batch_real_indx[j] = np.random.choice(indx_real_in_vicinity, size=1)[0]
136
-
137
- ## labels for fake images generation
138
- if threshold_type == "hard":
139
- lb = batch_target_labels[j] - kappa
140
- ub = batch_target_labels[j] + kappa
141
- else:
142
- lb = batch_target_labels[j] - np.sqrt(-np.log(nonzero_soft_weight_threshold)/kappa)
143
- ub = batch_target_labels[j] + np.sqrt(-np.log(nonzero_soft_weight_threshold)/kappa)
144
- lb = max(0.0, lb); ub = min(ub, 1.0)
145
- assert lb<=ub
146
- assert lb>=0 and ub>=0
147
- assert lb<=1 and ub<=1
148
- batch_fake_labels[j] = np.random.uniform(lb, ub, size=1)[0]
149
- #end for j
150
-
151
- ## draw real image/label batch from the training set
152
- batch_real_images = torch.from_numpy(normalize_images(train_images[batch_real_indx]))
153
- batch_real_images = batch_real_images.type(torch.float).cuda()
154
- batch_real_labels = train_labels[batch_real_indx]
155
- batch_real_labels = torch.from_numpy(batch_real_labels).type(torch.float).cuda()
156
-
157
-
158
- ## generate the fake image batch
159
- batch_fake_labels = torch.from_numpy(batch_fake_labels).type(torch.float).cuda()
160
- z = torch.randn(batch_size_disc, dim_gan, dtype=torch.float).cuda()
161
- batch_fake_images = netG(z, net_y2h(batch_fake_labels))
162
-
163
- ## target labels on gpu
164
- batch_target_labels = torch.from_numpy(batch_target_labels).type(torch.float).cuda()
165
-
166
- ## weight vector
167
- if threshold_type == "soft":
168
- real_weights = torch.exp(-kappa*(batch_real_labels-batch_target_labels)**2).cuda()
169
- fake_weights = torch.exp(-kappa*(batch_fake_labels-batch_target_labels)**2).cuda()
170
- else:
171
- real_weights = torch.ones(batch_size_disc, dtype=torch.float).cuda()
172
- fake_weights = torch.ones(batch_size_disc, dtype=torch.float).cuda()
173
- #end if threshold type
174
-
175
- # forward pass
176
- if use_DiffAugment:
177
- real_dis_out = netD(DiffAugment(batch_real_images, policy=policy), net_y2h(batch_target_labels))
178
- fake_dis_out = netD(DiffAugment(batch_fake_images.detach(), policy=policy), net_y2h(batch_target_labels))
179
- else:
180
- real_dis_out = netD(batch_real_images, net_y2h(batch_target_labels))
181
- fake_dis_out = netD(batch_fake_images.detach(), net_y2h(batch_target_labels))
182
-
183
- if loss_type == "vanilla":
184
- real_dis_out = torch.nn.Sigmoid()(real_dis_out)
185
- fake_dis_out = torch.nn.Sigmoid()(fake_dis_out)
186
- d_loss_real = - torch.log(real_dis_out+1e-20)
187
- d_loss_fake = - torch.log(1-fake_dis_out+1e-20)
188
- elif loss_type == "hinge":
189
- d_loss_real = torch.nn.ReLU()(1.0 - real_dis_out)
190
- d_loss_fake = torch.nn.ReLU()(1.0 + fake_dis_out)
191
- else:
192
- raise ValueError('Not supported loss type!!!')
193
-
194
- d_loss = torch.mean(real_weights.view(-1) * d_loss_real.view(-1)) + torch.mean(fake_weights.view(-1) * d_loss_fake.view(-1))
195
-
196
- optimizerD.zero_grad()
197
- d_loss.backward()
198
- optimizerD.step()
199
-
200
- #end for step_D_index
201
-
202
-
203
-
204
- ''' Train Generator '''
205
- netG.train()
206
-
207
- # generate fake images
208
- ## randomly draw batch_size_gene y's from unique_train_labels
209
- batch_target_labels_in_dataset = np.random.choice(unique_train_labels, size=batch_size_gene, replace=True)
210
- ## add Gaussian noise; we estimate image distribution conditional on these labels
211
- batch_epsilons = np.random.normal(0, kernel_sigma, batch_size_gene)
212
- batch_target_labels = batch_target_labels_in_dataset + batch_epsilons
213
- batch_target_labels = torch.from_numpy(batch_target_labels).type(torch.float).cuda()
214
-
215
- z = torch.randn(batch_size_gene, dim_gan, dtype=torch.float).cuda()
216
- batch_fake_images = netG(z, net_y2h(batch_target_labels))
217
-
218
- # loss
219
- if use_DiffAugment:
220
- dis_out = netD(DiffAugment(batch_fake_images, policy=policy), net_y2h(batch_target_labels))
221
- else:
222
- dis_out = netD(batch_fake_images, net_y2h(batch_target_labels))
223
- if loss_type == "vanilla":
224
- dis_out = torch.nn.Sigmoid()(dis_out)
225
- g_loss = - torch.mean(torch.log(dis_out+1e-20))
226
- elif loss_type == "hinge":
227
- g_loss = - dis_out.mean()
228
-
229
- # backward
230
- optimizerG.zero_grad()
231
- g_loss.backward()
232
- optimizerG.step()
233
-
234
- # print loss
235
- if (niter+1) % 20 == 0:
236
- print ("CcGAN,%s: [Iter %d/%d] [D loss: %.4e] [G loss: %.4e] [real prob: %.3f] [fake prob: %.3f] [Time: %.4f]" % (gan_arch, niter+1, niters, d_loss.item(), g_loss.item(), real_dis_out.mean().item(), fake_dis_out.mean().item(), timeit.default_timer()-start_time))
237
-
238
- if (niter+1) % visualize_freq == 0:
239
- netG.eval()
240
- with torch.no_grad():
241
- gen_imgs = netG(z_fixed, net_y2h(y_fixed))
242
- gen_imgs = gen_imgs.detach().cpu()
243
- save_image(gen_imgs.data, save_images_folder + '/{}.png'.format(niter+1), nrow=n_row, normalize=True)
244
-
245
- if save_models_folder is not None and ((niter+1) % save_niters_freq == 0 or (niter+1) == niters):
246
- save_file = save_models_folder + "/CcGAN_{}_{}_nDsteps_{}_checkpoint_intrain/CcGAN_checkpoint_niters_{}.pth".format(gan_arch, threshold_type, num_D_steps, niter+1)
247
- os.makedirs(os.path.dirname(save_file), exist_ok=True)
248
- torch.save({
249
- 'netG_state_dict': netG.state_dict(),
250
- 'netD_state_dict': netD.state_dict(),
251
- 'optimizerG_state_dict': optimizerG.state_dict(),
252
- 'optimizerD_state_dict': optimizerD.state_dict(),
253
- 'rng_state': torch.get_rng_state()
254
- }, save_file)
255
- #end for niter
256
- return netG, netD
257
-
258
-
259
- def sample_ccgan_given_labels(netG, net_y2h, labels, batch_size = 500, to_numpy=True, denorm=True, verbose=True):
260
- '''
261
- netG: pretrained generator network
262
- labels: float. normalized labels.
263
- '''
264
-
265
- nfake = len(labels)
266
- if batch_size>nfake:
267
- batch_size=nfake
268
-
269
- fake_images = []
270
- fake_labels = np.concatenate((labels, labels[0:batch_size]))
271
- netG=netG.cuda()
272
- netG.eval()
273
- net_y2h = net_y2h.cuda()
274
- net_y2h.eval()
275
- with torch.no_grad():
276
- if verbose:
277
- pb = SimpleProgressBar()
278
- n_img_got = 0
279
- while n_img_got < nfake:
280
- z = torch.randn(batch_size, dim_gan, dtype=torch.float).cuda()
281
- y = torch.from_numpy(fake_labels[n_img_got:(n_img_got+batch_size)]).type(torch.float).view(-1,1).cuda()
282
- batch_fake_images = netG(z, net_y2h(y))
283
- if denorm: #denorm imgs to save memory
284
- assert batch_fake_images.max().item()<=1.0 and batch_fake_images.min().item()>=-1.0
285
- batch_fake_images = batch_fake_images*0.5+0.5
286
- batch_fake_images = batch_fake_images*255.0
287
- batch_fake_images = batch_fake_images.type(torch.uint8)
288
- # assert batch_fake_images.max().item()>1
289
- fake_images.append(batch_fake_images.cpu())
290
- n_img_got += batch_size
291
- if verbose:
292
- pb.update(min(float(n_img_got)/nfake, 1)*100)
293
- ##end while
294
-
295
- fake_images = torch.cat(fake_images, dim=0)
296
- #remove extra entries
297
- fake_images = fake_images[0:nfake]
298
- fake_labels = fake_labels[0:nfake]
299
-
300
- if to_numpy:
301
- fake_images = fake_images.numpy()
302
-
303
- return fake_images, fake_labels
@@ -1,120 +0,0 @@
1
- """
2
- Some helpful functions
3
-
4
- """
5
- import numpy as np
6
- import torch
7
- import torch.nn as nn
8
- import torchvision
9
- import matplotlib.pyplot as plt
10
- import matplotlib as mpl
11
- from torch.nn import functional as F
12
- import sys
13
- import PIL
14
- from PIL import Image
15
-
16
-
17
-
18
- # ################################################################################
19
- # Progress Bar
20
- class SimpleProgressBar():
21
- def __init__(self, width=50):
22
- self.last_x = -1
23
- self.width = width
24
-
25
- def update(self, x):
26
- assert 0 <= x <= 100 # `x`: progress in percent ( between 0 and 100)
27
- if self.last_x == int(x): return
28
- self.last_x = int(x)
29
- pointer = int(self.width * (x / 100.0))
30
- sys.stdout.write( '\r%d%% [%s]' % (int(x), '#' * pointer + '.' * (self.width - pointer)))
31
- sys.stdout.flush()
32
- if x == 100:
33
- print('')
34
-
35
-
36
-
37
- ################################################################################
38
- # torch dataset from numpy array
39
- class IMGs_dataset(torch.utils.data.Dataset):
40
- def __init__(self, images, labels=None, normalize=False):
41
- super(IMGs_dataset, self).__init__()
42
-
43
- self.images = images
44
- self.n_images = len(self.images)
45
- self.labels = labels
46
- if labels is not None:
47
- if len(self.images) != len(self.labels):
48
- raise Exception('images (' + str(len(self.images)) +') and labels ('+str(len(self.labels))+') do not have the same length!!!')
49
- self.normalize = normalize
50
-
51
-
52
- def __getitem__(self, index):
53
-
54
- image = self.images[index]
55
-
56
- if self.normalize:
57
- image = image/255.0
58
- image = (image-0.5)/0.5
59
-
60
- if self.labels is not None:
61
- label = self.labels[index]
62
- return (image, label)
63
- else:
64
- return image
65
-
66
- def __len__(self):
67
- return self.n_images
68
-
69
-
70
- def PlotLoss(loss, filename):
71
- x_axis = np.arange(start = 1, stop = len(loss)+1)
72
- plt.switch_backend('agg')
73
- mpl.style.use('seaborn')
74
- fig = plt.figure()
75
- ax = plt.subplot(111)
76
- ax.plot(x_axis, np.array(loss))
77
- plt.xlabel("epoch")
78
- plt.ylabel("training loss")
79
- plt.legend()
80
- #ax.legend(loc='upper center', bbox_to_anchor=(0.5, 1.15), shadow=True, ncol=3)
81
- #plt.title('Training Loss')
82
- plt.savefig(filename)
83
-
84
-
85
- # compute entropy of class labels; labels is a numpy array
86
- def compute_entropy(labels, base=None):
87
- value,counts = np.unique(labels, return_counts=True)
88
- norm_counts = counts / counts.sum()
89
- base = np.e if base is None else base
90
- return -(norm_counts * np.log(norm_counts)/np.log(base)).sum()
91
-
92
- def predict_class_labels(net, images, batch_size=500, verbose=False, num_workers=0):
93
- net = net.cuda()
94
- net.eval()
95
-
96
- n = len(images)
97
- if batch_size>n:
98
- batch_size=n
99
- dataset_pred = IMGs_dataset(images, normalize=False)
100
- dataloader_pred = torch.utils.data.DataLoader(dataset_pred, batch_size=batch_size, shuffle=False, num_workers=num_workers)
101
-
102
- class_labels_pred = np.zeros(n+batch_size)
103
- with torch.no_grad():
104
- nimgs_got = 0
105
- if verbose:
106
- pb = SimpleProgressBar()
107
- for batch_idx, batch_images in enumerate(dataloader_pred):
108
- batch_images = batch_images.type(torch.float).cuda()
109
- batch_size_curr = len(batch_images)
110
-
111
- outputs,_ = net(batch_images)
112
- _, batch_class_labels_pred = torch.max(outputs.data, 1)
113
- class_labels_pred[nimgs_got:(nimgs_got+batch_size_curr)] = batch_class_labels_pred.detach().cpu().numpy().reshape(-1)
114
-
115
- nimgs_got += batch_size_curr
116
- if verbose:
117
- pb.update((float(nimgs_got)/n)*100)
118
- #end for batch_idx
119
- class_labels_pred = class_labels_pred[0:n]
120
- return class_labels_pred