Myosotis-Researches 0.1.7__py3-none-any.whl → 0.1.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- myosotis_researches/CcGAN/train/__init__.py +4 -0
- myosotis_researches/CcGAN/{train_128_output_10 → train}/train_ccgan.py +4 -4
- myosotis_researches/CcGAN/{train_128 → train}/train_cgan.py +1 -3
- myosotis_researches/CcGAN/{train_128 → train}/train_cgan_concat.py +1 -3
- myosotis_researches/CcGAN/utils/__init__.py +2 -1
- myosotis_researches/CcGAN/utils/train.py +94 -3
- {myosotis_researches-0.1.7.dist-info → myosotis_researches-0.1.9.dist-info}/METADATA +1 -1
- myosotis_researches-0.1.9.dist-info/RECORD +24 -0
- myosotis_researches/CcGAN/models_128/CcGAN_SAGAN.py +0 -301
- myosotis_researches/CcGAN/models_128/ResNet_class_eval.py +0 -141
- myosotis_researches/CcGAN/models_128/ResNet_embed.py +0 -188
- myosotis_researches/CcGAN/models_128/ResNet_regre_eval.py +0 -175
- myosotis_researches/CcGAN/models_128/__init__.py +0 -7
- myosotis_researches/CcGAN/models_128/autoencoder.py +0 -119
- myosotis_researches/CcGAN/models_128/cGAN_SAGAN.py +0 -276
- myosotis_researches/CcGAN/models_128/cGAN_concat_SAGAN.py +0 -245
- myosotis_researches/CcGAN/models_256/CcGAN_SAGAN.py +0 -303
- myosotis_researches/CcGAN/models_256/ResNet_class_eval.py +0 -142
- myosotis_researches/CcGAN/models_256/ResNet_embed.py +0 -188
- myosotis_researches/CcGAN/models_256/ResNet_regre_eval.py +0 -178
- myosotis_researches/CcGAN/models_256/__init__.py +0 -7
- myosotis_researches/CcGAN/models_256/autoencoder.py +0 -133
- myosotis_researches/CcGAN/models_256/cGAN_SAGAN.py +0 -280
- myosotis_researches/CcGAN/models_256/cGAN_concat_SAGAN.py +0 -249
- myosotis_researches/CcGAN/train_128/DiffAugment_pytorch.py +0 -76
- myosotis_researches/CcGAN/train_128/__init__.py +0 -0
- myosotis_researches/CcGAN/train_128/eval_metrics.py +0 -205
- myosotis_researches/CcGAN/train_128/opts.py +0 -87
- myosotis_researches/CcGAN/train_128/pretrain_AE.py +0 -268
- myosotis_researches/CcGAN/train_128/pretrain_CNN_class.py +0 -251
- myosotis_researches/CcGAN/train_128/pretrain_CNN_regre.py +0 -255
- myosotis_researches/CcGAN/train_128/train_ccgan.py +0 -303
- myosotis_researches/CcGAN/train_128/utils.py +0 -120
- myosotis_researches/CcGAN/train_128_output_10/DiffAugment_pytorch.py +0 -76
- myosotis_researches/CcGAN/train_128_output_10/__init__.py +0 -0
- myosotis_researches/CcGAN/train_128_output_10/eval_metrics.py +0 -205
- myosotis_researches/CcGAN/train_128_output_10/opts.py +0 -87
- myosotis_researches/CcGAN/train_128_output_10/pretrain_AE.py +0 -268
- myosotis_researches/CcGAN/train_128_output_10/pretrain_CNN_class.py +0 -251
- myosotis_researches/CcGAN/train_128_output_10/pretrain_CNN_regre.py +0 -255
- myosotis_researches/CcGAN/train_128_output_10/train_cgan.py +0 -254
- myosotis_researches/CcGAN/train_128_output_10/train_cgan_concat.py +0 -242
- myosotis_researches/CcGAN/train_128_output_10/train_net_for_label_embed.py +0 -181
- myosotis_researches/CcGAN/train_128_output_10/utils.py +0 -120
- myosotis_researches-0.1.7.dist-info/RECORD +0 -59
- /myosotis_researches/CcGAN/{train_128 → train}/train_net_for_label_embed.py +0 -0
- {myosotis_researches-0.1.7.dist-info → myosotis_researches-0.1.9.dist-info}/WHEEL +0 -0
- {myosotis_researches-0.1.7.dist-info → myosotis_researches-0.1.9.dist-info}/licenses/LICENSE +0 -0
- {myosotis_researches-0.1.7.dist-info → myosotis_researches-0.1.9.dist-info}/top_level.txt +0 -0
@@ -1,255 +0,0 @@
|
|
1
|
-
"""
|
2
|
-
|
3
|
-
Pre-train a CNN on the whole dataset for evaluation purpose
|
4
|
-
|
5
|
-
"""
|
6
|
-
import os
|
7
|
-
import argparse
|
8
|
-
import shutil
|
9
|
-
import timeit
|
10
|
-
import torch
|
11
|
-
import torchvision
|
12
|
-
import torchvision.transforms as transforms
|
13
|
-
import numpy as np
|
14
|
-
import torch.nn as nn
|
15
|
-
import torch.backends.cudnn as cudnn
|
16
|
-
import random
|
17
|
-
import matplotlib.pyplot as plt
|
18
|
-
import matplotlib as mpl
|
19
|
-
from torch import autograd
|
20
|
-
from torchvision.utils import save_image
|
21
|
-
import csv
|
22
|
-
from tqdm import tqdm
|
23
|
-
import gc
|
24
|
-
import h5py
|
25
|
-
|
26
|
-
|
27
|
-
#############################
|
28
|
-
# Settings
|
29
|
-
#############################
|
30
|
-
|
31
|
-
parser = argparse.ArgumentParser(description='Pre-train CNNs')
parser.add_argument('--root_path', type=str, default='',
                    help='working directory; checkpoints go to <root_path>/output/eval_models/')
parser.add_argument('--data_path', type=str, default='',
                    help='folder containing Ra_<img_size>x<img_size>.h5')
parser.add_argument('--num_workers', type=int, default=0)
parser.add_argument('--CNN', type=str, default='ResNet34_regre',
                    help='CNN for training; one of ResNet18_regre, ResNet34_regre, '
                         'ResNet50_regre, ResNet101_regre')
parser.add_argument('--epochs', type=int, default=200, metavar='N',
                    help='number of epochs to train CNNs (default: 200)')
parser.add_argument('--batch_size_train', type=int, default=256, metavar='N',
                    help='input batch size for training')
parser.add_argument('--batch_size_valid', type=int, default=10, metavar='N',
                    help='input batch size for validation')
parser.add_argument('--base_lr', type=float, default=0.01,
                    help='learning rate (default: 0.01)')
# The original option was misspelled '--weight_dacay'. It is kept as an alias so
# existing command lines keep working. The value is still stored in
# args.weight_dacay, which is what the SGD optimizer reads.
parser.add_argument('--weight_decay', '--weight_dacay', dest='weight_dacay',
                    type=float, default=1e-4,
                    help='weight decay (default: 1e-4)')
parser.add_argument('--seed', type=int, default=2020, metavar='S',
                    help='random seed (default: 2020)')
parser.add_argument('--CVMode', action='store_true', default=False,
                    help='hold out 20%% of the data for validation')
parser.add_argument('--img_size', type=int, default=128, metavar='N')
parser.add_argument('--min_label', type=float, default=0.0,
                    help='keep samples with label > min_label')
parser.add_argument('--max_label', type=float, default=90.0,
                    help='keep samples with label < max_label; labels are divided by it')
args = parser.parse_args()
|
55
|
-
|
56
|
-
|
57
|
-
# Everything below is resolved relative to the user-supplied root directory.
wd = args.root_path
os.chdir(wd)

# Project-local networks (ResNet*_regre_eval) and the numpy-backed dataset.
from ..models_128 import *
from .utils import IMGs_dataset

# Hardware: training always targets CUDA; the GPU count is forwarded to the
# network constructors.
device = torch.device("cuda")
ngpu = torch.cuda.device_count()  # number of gpus

# Seed every RNG in use and ask cuDNN for deterministic kernels.
random.seed(args.seed)
torch.manual_seed(args.seed)
torch.backends.cudnn.deterministic = True
np.random.seed(args.seed)

# Destination folder for the trained evaluation CNN checkpoints.
save_models_folder = f"{wd}/output/eval_models/"
os.makedirs(save_models_folder, exist_ok=True)
|
75
|
-
|
76
|
-
|
77
|
-
###########################################################################################################
|
78
|
-
# Data
|
79
|
-
###########################################################################################################
|
80
|
-
# data loader
|
81
|
-
data_filename = args.data_path + '/Ra_' + str(args.img_size) + 'x' + str(args.img_size) + '.h5'
# Open the HDF5 file in a context manager so the handle is released even if
# reading one of the datasets raises.
with h5py.File(data_filename, 'r') as hf:
    labels = hf['labels'][:].astype(float)
    images = hf['images'][:]
N_all = len(images)
assert len(images) == len(labels)

# Keep only samples whose label lies strictly inside (min_label, max_label).
q1 = args.min_label
q2 = args.max_label
indx = np.where((labels > q1) & (labels < q2))[0]
labels = labels[indx]
images = images[indx]
assert len(labels) == len(images)

# Scale labels by max_label. With the default min_label=0 the kept labels
# land in (0, 1). valid_CNN multiplies by max_label again to report errors on
# the original scale.
labels /= args.max_label

# define training and validation sets
if args.CVMode:
    # 80% training / 20% validation, split after a shuffle driven by np.random.
    indx_all = np.arange(len(labels))
    np.random.shuffle(indx_all)
    n_valid = int(0.2 * len(labels))
    indx_valid = indx_all[0:n_valid]
    indx_train = indx_all[n_valid:]

    trainset = IMGs_dataset(images[indx_train], labels[indx_train], normalize=True)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size_train, shuffle=True, num_workers=args.num_workers)
    validset = IMGs_dataset(images[indx_valid], labels[indx_valid], normalize=True)
    validloader = torch.utils.data.DataLoader(validset, batch_size=args.batch_size_valid, shuffle=False, num_workers=args.num_workers)
else:
    trainset = IMGs_dataset(images, labels, normalize=True)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size_train, shuffle=True, num_workers=args.num_workers)
|
122
|
-
|
123
|
-
|
124
|
-
|
125
|
-
|
126
|
-
###########################################################################################################
|
127
|
-
# Necessary functions
|
128
|
-
###########################################################################################################
|
129
|
-
|
130
|
-
#initialize CNNs
|
131
|
-
def net_initialization(Pretrained_CNN_Name, ngpu = 1):
    """Build the regression CNN called *Pretrained_CNN_Name* and move it to ``device``.

    Args:
        Pretrained_CNN_Name: one of "ResNet18_regre", "ResNet34_regre",
            "ResNet50_regre", "ResNet101_regre".
        ngpu: forwarded unchanged to the network constructor.

    Returns:
        (net, net_name). net_name is "PreCNNForEvalGANs_" + Pretrained_CNN_Name
        and is used to build the checkpoint filename.

    Raises:
        ValueError: if the architecture name is not recognised. Previously
            this surfaced as an UnboundLocalError on ``net``.
    """
    # Lambdas defer the constructor lookup until the chosen entry is called.
    builders = {
        "ResNet18_regre": lambda: ResNet18_regre_eval(ngpu=ngpu),
        "ResNet34_regre": lambda: ResNet34_regre_eval(ngpu=ngpu),
        "ResNet50_regre": lambda: ResNet50_regre_eval(ngpu=ngpu),
        "ResNet101_regre": lambda: ResNet101_regre_eval(ngpu=ngpu),
    }
    if Pretrained_CNN_Name not in builders:
        raise ValueError("Unknown CNN '{}'; expected one of {}".format(
            Pretrained_CNN_Name, sorted(builders)))
    net = builders[Pretrained_CNN_Name]()

    net_name = 'PreCNNForEvalGANs_' + Pretrained_CNN_Name  # get the net's name
    net = net.to(device)

    return net, net_name
|
145
|
-
|
146
|
-
#adjust CNN learning rate
|
147
|
-
def adjust_learning_rate(optimizer, epoch, BASE_LR_CNN):
    """Apply a step-decay schedule to every parameter group of *optimizer*.

    The rate starts at BASE_LR_CNN and is divided by 10 once *epoch* reaches
    50, and by 10 again once it reaches 120.
    """
    new_lr = BASE_LR_CNN
    for milestone in (50, 120):
        if epoch >= milestone:
            new_lr /= 10
    for group in optimizer.param_groups:
        group['lr'] = new_lr
|
159
|
-
|
160
|
-
|
161
|
-
def train_CNN():
    """Train the module-level ``net`` on ``trainloader`` for ``args.epochs`` epochs.

    Uses the module-level ``optimizer`` (step-decayed each epoch through
    ``adjust_learning_rate``) and ``criterion``. After every epoch it prints
    the mean training loss, plus the validation error when ``args.CVMode``
    is set.

    Returns:
        (net, optimizer) after training.
    """
    t0 = timeit.default_timer()
    for epoch in range(args.epochs):
        net.train()
        adjust_learning_rate(optimizer, epoch, args.base_lr)
        running_loss = 0

        for x_batch, y_batch in trainloader:
            x_batch = x_batch.type(torch.float).cuda()
            y_batch = y_batch.type(torch.float).view(-1, 1).cuda()

            # forward: the network returns a pair; only the prediction is used
            preds, _ = net(x_batch)
            loss = criterion(preds, y_batch)

            # backward + update
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.cpu().item()

        train_loss = running_loss / len(trainloader)

        if not args.CVMode:
            print('CNN: [epoch %d/%d] train_loss:%f Time:%.4f' % (epoch+1, args.epochs, train_loss, timeit.default_timer()-t0))
        else:
            valid_loss = valid_CNN(verbose=False)
            print('CNN: [epoch %d/%d] train_loss:%f valid_loss (avg_abs):%f Time: %.4f' % (epoch+1, args.epochs, train_loss, valid_loss, timeit.default_timer()-t0))

    return net, optimizer
|
196
|
-
|
197
|
-
if args.CVMode:
    def valid_CNN(verbose=True):
        """Return the mean absolute error of ``net`` on ``validloader``.

        Predictions and targets are multiplied by ``args.max_label`` before
        comparison, undoing the label scaling applied at load time. The
        result is printed when *verbose* is true.
        """
        net.eval()  # BatchNorm switches to running statistics
        with torch.no_grad():
            abs_err_sum = 0
            n_seen = 0
            for x_batch, y_batch in validloader:
                x_batch = x_batch.type(torch.float).cuda()
                y_true = y_batch.type(torch.float).view(-1).cpu().numpy()
                y_pred, _ = net(x_batch)
                y_pred = y_pred.view(-1).cpu().numpy()
                y_true = y_true * args.max_label
                y_pred = y_pred * args.max_label
                abs_err_sum += np.sum(np.abs(y_true - y_pred))
                n_seen += len(y_true)

            mae = abs_err_sum / n_seen
            if verbose:
                print('Validation Average Absolute Difference: {}'.format(mae))
            return mae
|
216
|
-
|
217
|
-
|
218
|
-
|
219
|
-
###########################################################################################################
|
220
|
-
# Training and validation
|
221
|
-
###########################################################################################################
|
222
|
-
|
223
|
-
|
224
|
-
# model initialization
|
225
|
-
# Network, MSE regression loss and SGD-with-momentum optimizer.
net, net_name = net_initialization(args.CNN, ngpu=ngpu)
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(net.parameters(), lr=args.base_lr, momentum=0.9,
                            weight_decay=args.weight_dacay)

# The checkpoint name encodes architecture, epoch count, seed and CV mode.
# Other hyper-parameters (lr, batch size, ...) are not part of the name.
filename_ckpt = save_models_folder + '/ckpt_{}_epoch_{}_seed_{}_CVMode_{}.pth'.format(
    net_name, args.epochs, args.seed, args.CVMode)

if os.path.isfile(filename_ckpt):
    # Reuse previously trained weights instead of retraining.
    print("\n Ckpt already exists")
    print("\n Loading...")
    checkpoint = torch.load(filename_ckpt)
    net.load_state_dict(checkpoint['net_state_dict'])
else:
    print("\n Begin training CNN: ")
    start = timeit.default_timer()
    net, optimizer = train_CNN()
    stop = timeit.default_timer()
    print("Time elapses: {}s".format(stop - start))
    torch.save({'net_state_dict': net.state_dict()}, filename_ckpt)
torch.cuda.empty_cache()  # release cached GPU memory that is no longer referenced

if args.CVMode:
    # Final validation pass; the error is printed by valid_CNN.
    valid_CNN(True)
    torch.cuda.empty_cache()
|
@@ -1,303 +0,0 @@
|
|
1
|
-
import torch
|
2
|
-
import numpy as np
|
3
|
-
import os
|
4
|
-
import timeit
|
5
|
-
from PIL import Image
|
6
|
-
from torchvision.utils import save_image
|
7
|
-
import torch.cuda as cutorch
|
8
|
-
|
9
|
-
from .utils import SimpleProgressBar, IMGs_dataset
|
10
|
-
from .opts import parse_opts
|
11
|
-
from .DiffAugment_pytorch import DiffAugment
|
12
|
-
|
13
|
-
''' Settings '''
args = parse_opts()

# Frequently used options, copied to module-level names. train_ccgan and
# sample_ccgan_given_labels read these globals directly.
# -- GAN architecture / objective
gan_arch, loss_type, dim_gan = args.GAN_arch, args.loss_type_gan, args.dim_gan
# -- optimisation schedule
niters, resume_niters = args.niters_gan, args.resume_niters_gan
lr_g, lr_d = args.lr_g_gan, args.lr_d_gan
batch_size_disc, batch_size_gene = args.batch_size_disc, args.batch_size_gene
num_D_steps = args.num_D_steps
# -- checkpoint / visualisation frequency
save_niters_freq, visualize_freq = args.save_niters_freq, args.visualize_freq
num_workers = args.num_workers
# -- label vicinity
threshold_type = args.threshold_type
nonzero_soft_weight_threshold = args.nonzero_soft_weight_threshold
# -- data
num_channels, img_size, max_label = args.num_channels, args.img_size, args.max_label
# -- DiffAugment
use_DiffAugment, policy = args.gan_DiffAugment, args.gan_DiffAugment_policy
|
43
|
-
|
44
|
-
|
45
|
-
## normalize images
|
46
|
-
def normalize_images(batch_images):
    """Map pixel values from [0, 255] to [-1, 1] via x/255, then (x - 0.5)/0.5."""
    unit_scaled = batch_images / 255.0
    return (unit_scaled - 0.5) / 0.5
|
50
|
-
|
51
|
-
|
52
|
-
def _real_indices_in_vicinity(train_labels, target_label, kappa):
    '''
    Return the indices of training samples whose labels lie in the vicinity
    of target_label.

    "hard" threshold: |y - target| <= kappa.
    otherwise (soft vicinity): the soft weight exp(-kappa*(y - target)**2) is
    at least nonzero_soft_weight_threshold, which is equivalent to
    (y - target)**2 <= -log(nonzero_soft_weight_threshold)/kappa.
    '''
    if threshold_type == "hard":
        return np.where(np.abs(train_labels-target_label) <= kappa)[0]
    return np.where((train_labels-target_label)**2 <= -np.log(nonzero_soft_weight_threshold)/kappa)[0]


def _fake_label_bounds(target_label, kappa):
    '''
    Return (lb, ub), the interval clipped to [0, 1] from which the label of a
    fake image paired with target_label is drawn. The radius is the same
    vicinity radius used by _real_indices_in_vicinity.
    '''
    if threshold_type == "hard":
        radius = kappa
    else:
        radius = np.sqrt(-np.log(nonzero_soft_weight_threshold)/kappa)
    lb = max(0.0, target_label - radius)
    ub = min(target_label + radius, 1.0)
    return lb, ub


def train_ccgan(kernel_sigma, kappa, train_images, train_labels, netG, netD, net_y2h, save_images_folder, save_models_folder = None, clip_label=False):
    '''
    Train a CcGAN generator/discriminator pair with vicinal label sampling.

    Note that train_images are not normalized to [-1,1]; each real batch goes
    through normalize_images before reaching the discriminator.

    kernel_sigma: std of the Gaussian noise added to labels drawn from the
        unique training labels.
    kappa: vicinity parameter. With threshold_type == "hard" it is the
        vicinity radius; otherwise it is the decay rate of the soft weight
        exp(-kappa*(y - target)**2).
    train_images: numpy array of training images indexed by sample.
    train_labels: numpy array of normalized training labels. Fake labels are
        drawn from intervals clipped to [0, 1], so labels are expected there.
    netG, netD: generator and discriminator to train.
    net_y2h: label-embedding network; put in eval mode and not optimized here.
    save_images_folder: folder receiving the periodic sample grids.
    save_models_folder: if given, checkpoints are written under it, and
        training resumes from it when resume_niters > 0.
    clip_label: clip re-drawn target labels to [0, 1].

    Returns (netG, netD).
    Raises ValueError for an unsupported loss type or empty train_labels.
    '''
    if loss_type not in ("vanilla", "hinge"):
        raise ValueError('Not supported loss type!!!')
    if len(train_labels) == 0:
        # The vicinity re-draw loop below would otherwise never terminate.
        raise ValueError('train_labels must not be empty')

    netG = netG.cuda()
    netD = netD.cuda()
    net_y2h = net_y2h.cuda()
    net_y2h.eval()

    optimizerG = torch.optim.Adam(netG.parameters(), lr=lr_g, betas=(0.5, 0.999))
    optimizerD = torch.optim.Adam(netD.parameters(), lr=lr_d, betas=(0.5, 0.999))

    def ckpt_path(n_iters):
        return save_models_folder + "/CcGAN_{}_{}_nDsteps_{}_checkpoint_intrain/CcGAN_checkpoint_niters_{}.pth".format(gan_arch, threshold_type, num_D_steps, n_iters)

    if save_models_folder is not None and resume_niters>0:
        checkpoint = torch.load(ckpt_path(resume_niters))
        netG.load_state_dict(checkpoint['netG_state_dict'])
        netD.load_state_dict(checkpoint['netD_state_dict'])
        optimizerG.load_state_dict(checkpoint['optimizerG_state_dict'])
        optimizerD.load_state_dict(checkpoint['optimizerD_state_dict'])
        torch.set_rng_state(checkpoint['rng_state'])

    unique_train_labels = np.sort(np.array(list(set(train_labels))))

    # Fixed inputs for the progress grid: n_row labels evenly spaced between
    # the 5th and 95th percentiles of the training labels, n_col samples each.
    n_row = 10
    n_col = n_row
    z_fixed = torch.randn(n_row*n_col, dim_gan, dtype=torch.float).cuda()
    start_label = np.quantile(train_labels, 0.05)
    end_label = np.quantile(train_labels, 0.95)
    selected_labels = np.linspace(start_label, end_label, num=n_row)
    y_fixed = np.repeat(selected_labels, n_col)
    print(y_fixed)
    y_fixed = torch.from_numpy(y_fixed).type(torch.float).view(-1,1).cuda()

    start_time = timeit.default_timer()
    for niter in range(resume_niters, niters):
        # The visualisation step below leaves netG in eval mode. Restore
        # training mode so D-step fakes are generated as in every other iteration.
        netG.train()

        ''' Train Discriminator '''
        for _ in range(num_D_steps):

            ## randomly draw batch_size_disc y's from unique_train_labels
            batch_target_labels_in_dataset = np.random.choice(unique_train_labels, size=batch_size_disc, replace=True)
            ## add Gaussian noise; we estimate image distribution conditional on these labels
            batch_epsilons = np.random.normal(0, kernel_sigma, batch_size_disc)
            batch_target_labels = batch_target_labels_in_dataset + batch_epsilons

            ## indices of real images whose labels are in the vicinity of each target,
            ## and labels for fake images drawn from the same vicinity
            batch_real_indx = np.zeros(batch_size_disc, dtype=int)
            batch_fake_labels = np.zeros(batch_size_disc)

            for j in range(batch_size_disc):
                indx_real_in_vicinity = _real_indices_in_vicinity(train_labels, batch_target_labels[j], kappa)

                ## If the gap between consecutive unique labels is large, the
                ## perturbed target may have no real neighbour; re-draw its noise.
                while len(indx_real_in_vicinity)<1:
                    # Draw a scalar: assigning a size-1 array into a single
                    # element is deprecated since NumPy 1.25.
                    batch_target_labels[j] = batch_target_labels_in_dataset[j] + np.random.normal(0, kernel_sigma)
                    if clip_label:
                        batch_target_labels = np.clip(batch_target_labels, 0.0, 1.0)
                    indx_real_in_vicinity = _real_indices_in_vicinity(train_labels, batch_target_labels[j], kappa)

                batch_real_indx[j] = np.random.choice(indx_real_in_vicinity, size=1)[0]

                ## label for fake image generation
                lb, ub = _fake_label_bounds(batch_target_labels[j], kappa)
                assert lb<=ub
                assert lb>=0 and ub>=0
                assert lb<=1 and ub<=1
                batch_fake_labels[j] = np.random.uniform(lb, ub, size=1)[0]

            ## draw real image/label batch from the training set
            batch_real_images = torch.from_numpy(normalize_images(train_images[batch_real_indx]))
            batch_real_images = batch_real_images.type(torch.float).cuda()
            batch_real_labels = train_labels[batch_real_indx]
            batch_real_labels = torch.from_numpy(batch_real_labels).type(torch.float).cuda()

            ## generate the fake image batch. The discriminator step never
            ## back-propagates into netG, so skip building its autograd graph.
            batch_fake_labels = torch.from_numpy(batch_fake_labels).type(torch.float).cuda()
            z = torch.randn(batch_size_disc, dim_gan, dtype=torch.float).cuda()
            with torch.no_grad():
                batch_fake_images = netG(z, net_y2h(batch_fake_labels))

            ## target labels on gpu
            batch_target_labels = torch.from_numpy(batch_target_labels).type(torch.float).cuda()

            ## weight vector: soft vicinity weights, or uniform weights for the hard vicinity
            if threshold_type == "soft":
                real_weights = torch.exp(-kappa*(batch_real_labels-batch_target_labels)**2).cuda()
                fake_weights = torch.exp(-kappa*(batch_fake_labels-batch_target_labels)**2).cuda()
            else:
                real_weights = torch.ones(batch_size_disc, dtype=torch.float).cuda()
                fake_weights = torch.ones(batch_size_disc, dtype=torch.float).cuda()

            # forward pass
            if use_DiffAugment:
                real_dis_out = netD(DiffAugment(batch_real_images, policy=policy), net_y2h(batch_target_labels))
                fake_dis_out = netD(DiffAugment(batch_fake_images, policy=policy), net_y2h(batch_target_labels))
            else:
                real_dis_out = netD(batch_real_images, net_y2h(batch_target_labels))
                fake_dis_out = netD(batch_fake_images, net_y2h(batch_target_labels))

            if loss_type == "vanilla":
                real_dis_out = torch.nn.Sigmoid()(real_dis_out)
                fake_dis_out = torch.nn.Sigmoid()(fake_dis_out)
                d_loss_real = - torch.log(real_dis_out+1e-20)
                d_loss_fake = - torch.log(1-fake_dis_out+1e-20)
            else:  # "hinge" (validated at entry)
                d_loss_real = torch.nn.ReLU()(1.0 - real_dis_out)
                d_loss_fake = torch.nn.ReLU()(1.0 + fake_dis_out)

            d_loss = torch.mean(real_weights.view(-1) * d_loss_real.view(-1)) + torch.mean(fake_weights.view(-1) * d_loss_fake.view(-1))

            optimizerD.zero_grad()
            d_loss.backward()
            optimizerD.step()

        ''' Train Generator '''
        ## randomly draw batch_size_gene y's from unique_train_labels and perturb them
        batch_target_labels_in_dataset = np.random.choice(unique_train_labels, size=batch_size_gene, replace=True)
        batch_epsilons = np.random.normal(0, kernel_sigma, batch_size_gene)
        batch_target_labels = batch_target_labels_in_dataset + batch_epsilons
        batch_target_labels = torch.from_numpy(batch_target_labels).type(torch.float).cuda()

        z = torch.randn(batch_size_gene, dim_gan, dtype=torch.float).cuda()
        batch_fake_images = netG(z, net_y2h(batch_target_labels))

        # loss
        if use_DiffAugment:
            dis_out = netD(DiffAugment(batch_fake_images, policy=policy), net_y2h(batch_target_labels))
        else:
            dis_out = netD(batch_fake_images, net_y2h(batch_target_labels))
        if loss_type == "vanilla":
            dis_out = torch.nn.Sigmoid()(dis_out)
            g_loss = - torch.mean(torch.log(dis_out+1e-20))
        else:  # "hinge"
            g_loss = - dis_out.mean()

        # backward
        optimizerG.zero_grad()
        g_loss.backward()
        optimizerG.step()

        # print loss
        if (niter+1) % 20 == 0:
            print ("CcGAN,%s: [Iter %d/%d] [D loss: %.4e] [G loss: %.4e] [real prob: %.3f] [fake prob: %.3f] [Time: %.4f]" % (gan_arch, niter+1, niters, d_loss.item(), g_loss.item(), real_dis_out.mean().item(), fake_dis_out.mean().item(), timeit.default_timer()-start_time))

        if (niter+1) % visualize_freq == 0:
            netG.eval()
            with torch.no_grad():
                gen_imgs = netG(z_fixed, net_y2h(y_fixed))
                gen_imgs = gen_imgs.detach().cpu()
                save_image(gen_imgs.data, save_images_folder + '/{}.png'.format(niter+1), nrow=n_row, normalize=True)

        if save_models_folder is not None and ((niter+1) % save_niters_freq == 0 or (niter+1) == niters):
            save_file = ckpt_path(niter+1)
            os.makedirs(os.path.dirname(save_file), exist_ok=True)
            torch.save({
                'netG_state_dict': netG.state_dict(),
                'netD_state_dict': netD.state_dict(),
                'optimizerG_state_dict': optimizerG.state_dict(),
                'optimizerD_state_dict': optimizerD.state_dict(),
                'rng_state': torch.get_rng_state()
            }, save_file)
    #end for niter
    return netG, netD
|
257
|
-
|
258
|
-
|
259
|
-
def sample_ccgan_given_labels(netG, net_y2h, labels, batch_size = 500, to_numpy=True, denorm=True, verbose=True):
    '''
    Generate one fake image per entry of labels with a trained CcGAN generator.

    netG: pretrained generator network
    net_y2h: label-embedding network whose output conditions netG
    labels: float. normalized labels.
    batch_size: images generated per forward pass (capped at len(labels)).
    to_numpy: return the images as a numpy array instead of a torch tensor.
    denorm: map generator outputs from [-1, 1] to uint8 values in [0, 255].
    verbose: show a progress bar.

    Returns (fake_images, fake_labels), where fake_labels[i] is the label
    used to generate fake_images[i].

    Raises ValueError if labels is empty or batch_size is not positive.
    Previously the first case crashed inside torch.cat on an empty list and
    the second looped forever.
    '''
    nfake = len(labels)
    if nfake == 0:
        raise ValueError("labels must contain at least one label")
    if batch_size <= 0:
        raise ValueError("batch_size must be positive, got {}".format(batch_size))
    batch_size = min(batch_size, nfake)

    fake_images = []
    # Pad with the first batch_size labels so every slice below is a full
    # batch; the surplus images/labels are trimmed after the loop.
    fake_labels = np.concatenate((labels, labels[0:batch_size]))
    netG=netG.cuda()
    netG.eval()
    net_y2h = net_y2h.cuda()
    net_y2h.eval()
    with torch.no_grad():
        if verbose:
            pb = SimpleProgressBar()
        n_img_got = 0
        while n_img_got < nfake:
            z = torch.randn(batch_size, dim_gan, dtype=torch.float).cuda()
            y = torch.from_numpy(fake_labels[n_img_got:(n_img_got+batch_size)]).type(torch.float).view(-1,1).cuda()
            batch_fake_images = netG(z, net_y2h(y))
            if denorm: #denorm imgs to save memory
                assert batch_fake_images.max().item()<=1.0 and batch_fake_images.min().item()>=-1.0
                batch_fake_images = batch_fake_images*0.5+0.5
                batch_fake_images = batch_fake_images*255.0
                batch_fake_images = batch_fake_images.type(torch.uint8)
            fake_images.append(batch_fake_images.cpu())
            n_img_got += batch_size
            if verbose:
                pb.update(min(float(n_img_got)/nfake, 1)*100)

    fake_images = torch.cat(fake_images, dim=0)
    #remove extra entries
    fake_images = fake_images[0:nfake]
    fake_labels = fake_labels[0:nfake]

    if to_numpy:
        fake_images = fake_images.numpy()

    return fake_images, fake_labels
|
@@ -1,120 +0,0 @@
|
|
1
|
-
"""
|
2
|
-
Some helpful functions
|
3
|
-
|
4
|
-
"""
|
5
|
-
import numpy as np
|
6
|
-
import torch
|
7
|
-
import torch.nn as nn
|
8
|
-
import torchvision
|
9
|
-
import matplotlib.pyplot as plt
|
10
|
-
import matplotlib as mpl
|
11
|
-
from torch.nn import functional as F
|
12
|
-
import sys
|
13
|
-
import PIL
|
14
|
-
from PIL import Image
|
15
|
-
|
16
|
-
|
17
|
-
|
18
|
-
# ################################################################################
|
19
|
-
# Progress Bar
|
20
|
-
class SimpleProgressBar():
    """Single-line textual progress bar redrawn in place on stdout."""

    def __init__(self, width=50):
        self.last_x = -1    # integer percentage drawn last; -1 = nothing yet
        self.width = width  # number of characters in the bar body

    def update(self, x):
        """Redraw the bar at *x* percent (0-100).

        Nothing is written if the integer part of *x* equals the last value
        drawn. A newline is emitted once *x* reaches exactly 100.
        """
        assert 0 <= x <= 100  # `x`: progress in percent ( between 0 and 100)
        percent = int(x)
        if percent == self.last_x:
            return
        self.last_x = percent
        filled = int(self.width * (x / 100.0))
        bar = '#' * filled + '.' * (self.width - filled)
        sys.stdout.write('\r%d%% [%s]' % (percent, bar))
        sys.stdout.flush()
        if x == 100:
            print('')
|
34
|
-
|
35
|
-
|
36
|
-
|
37
|
-
################################################################################
|
38
|
-
# torch dataset from numpy array
|
39
|
-
class IMGs_dataset(torch.utils.data.Dataset):
    """Map-style dataset over in-memory images and optional labels.

    images: indexable collection of images (e.g. a numpy array).
    labels: optional collection with one entry per image.
    normalize: if true, each image is mapped via x/255 then (x - 0.5)/0.5,
        i.e. [0, 255] -> [-1, 1].
    """

    def __init__(self, images, labels=None, normalize=False):
        super(IMGs_dataset, self).__init__()

        self.images = images
        self.n_images = len(self.images)
        self.labels = labels
        if labels is not None and len(self.images) != len(self.labels):
            raise Exception(f'images ({len(self.images)}) and labels ({len(self.labels)}) do not have the same length!!!')
        self.normalize = normalize

    def __getitem__(self, index):
        """Return ``image`` or ``(image, label)`` depending on whether labels were given."""
        image = self.images[index]
        if self.normalize:
            image = (image / 255.0 - 0.5) / 0.5
        if self.labels is None:
            return image
        return (image, self.labels[index])

    def __len__(self):
        return self.n_images
|
68
|
-
|
69
|
-
|
70
|
-
def PlotLoss(loss, filename):
    """Plot a per-epoch loss curve and save it to *filename*.

    loss: sequence of loss values; element i is drawn at epoch i+1.
    filename: output image path passed to ``savefig``.
    """
    x_axis = np.arange(start=1, stop=len(loss) + 1)
    plt.switch_backend('agg')
    # matplotlib >= 3.6 renamed the bundled seaborn style to 'seaborn-v0_8'.
    # The bare 'seaborn' name raises there, so try the new name first and
    # fall back for older releases.
    for style_name in ('seaborn-v0_8', 'seaborn'):
        if style_name in plt.style.available:
            mpl.style.use(style_name)
            break
    fig = plt.figure()
    ax = fig.add_subplot(111)
    ax.plot(x_axis, np.array(loss))
    ax.set_xlabel("epoch")
    ax.set_ylabel("training loss")
    fig.savefig(filename)
    # Close the figure so repeated calls do not accumulate open figures.
    plt.close(fig)
|
83
|
-
|
84
|
-
|
85
|
-
# compute entropy of class labels; labels is a numpy array
|
86
|
-
def compute_entropy(labels, base=None):
    """Return the Shannon entropy of the empirical distribution of *labels*.

    labels: array-like of discrete labels.
    base: logarithm base; natural log when None.
    """
    _, counts = np.unique(labels, return_counts=True)
    probs = counts / counts.sum()
    log_base = np.log(np.e if base is None else base)
    return -(probs * np.log(probs) / log_base).sum()
|
91
|
-
|
92
|
-
def predict_class_labels(net, images, batch_size=500, verbose=False, num_workers=0):
    """Return the arg-max class index predicted by *net* for every image.

    net: classifier whose forward returns a 2-tuple; the first element holds
        per-class scores along dim 1.
    images: images fed through ``IMGs_dataset`` without normalization.
    batch_size: evaluation batch size, capped at ``len(images)``.

    Returns a float numpy array of length ``len(images)``.
    """
    net = net.cuda()
    net.eval()

    n = len(images)
    batch_size = min(batch_size, n)
    loader = torch.utils.data.DataLoader(IMGs_dataset(images, normalize=False), batch_size=batch_size, shuffle=False, num_workers=num_workers)

    # One extra batch of slack so a full-size write never overruns the buffer.
    preds = np.zeros(n + batch_size)
    with torch.no_grad():
        done = 0
        if verbose:
            pb = SimpleProgressBar()
        for batch_images in loader:
            batch_images = batch_images.type(torch.float).cuda()
            n_batch = len(batch_images)

            scores, _ = net(batch_images)
            _, batch_pred = torch.max(scores.data, 1)
            preds[done:(done + n_batch)] = batch_pred.detach().cpu().numpy().reshape(-1)

            done += n_batch
            if verbose:
                pb.update((float(done) / n) * 100)
    return preds[0:n]
|