Myosotis-Researches 0.0.18__py3-none-any.whl → 0.0.19__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- myosotis_researches/CcGAN/train_128_output_10/DiffAugment_pytorch.py +76 -0
- myosotis_researches/CcGAN/train_128_output_10/__init__.py +0 -0
- myosotis_researches/CcGAN/train_128_output_10/eval_metrics.py +205 -0
- myosotis_researches/CcGAN/train_128_output_10/opts.py +87 -0
- myosotis_researches/CcGAN/train_128_output_10/pretrain_AE.py +268 -0
- myosotis_researches/CcGAN/train_128_output_10/pretrain_CNN_class.py +251 -0
- myosotis_researches/CcGAN/train_128_output_10/pretrain_CNN_regre.py +255 -0
- myosotis_researches/CcGAN/train_128_output_10/train_ccgan.py +302 -0
- myosotis_researches/CcGAN/train_128_output_10/train_cgan.py +254 -0
- myosotis_researches/CcGAN/train_128_output_10/train_cgan_concat.py +242 -0
- myosotis_researches/CcGAN/train_128_output_10/train_net_for_label_embed.py +181 -0
- myosotis_researches/CcGAN/train_128_output_10/utils.py +120 -0
- {myosotis_researches-0.0.18.dist-info → myosotis_researches-0.0.19.dist-info}/METADATA +1 -1
- {myosotis_researches-0.0.18.dist-info → myosotis_researches-0.0.19.dist-info}/RECORD +17 -5
- {myosotis_researches-0.0.18.dist-info → myosotis_researches-0.0.19.dist-info}/WHEEL +0 -0
- {myosotis_researches-0.0.18.dist-info → myosotis_researches-0.0.19.dist-info}/licenses/LICENSE +0 -0
- {myosotis_researches-0.0.18.dist-info → myosotis_researches-0.0.19.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,242 @@
|
|
1
|
+
|
2
|
+
import torch
|
3
|
+
import torch.nn as nn
|
4
|
+
from torchvision.utils import save_image
|
5
|
+
import numpy as np
|
6
|
+
import os
|
7
|
+
import timeit
|
8
|
+
|
9
|
+
from .utils import IMGs_dataset, SimpleProgressBar
|
10
|
+
from .opts import parse_opts
|
11
|
+
from .DiffAugment_pytorch import DiffAugment
|
12
|
+
|
13
|
+
# ----------------------------------------------------------------------------
# Settings
# Every hyper-parameter is parsed from the command line once, when this module
# is imported, and exposed as a module-level name used by the functions below.
# ----------------------------------------------------------------------------
args = parse_opts()

# optimisation
loss_type, niters, resume_niters = args.loss_type_gan, args.niters_gan, args.resume_niters_gan
dim_gan = args.dim_gan
lr_g, lr_d = args.lr_g_gan, args.lr_d_gan
num_D_steps = args.num_D_steps
# a single dataloader feeds both networks, so use the smaller of the two sizes
batch_size = min(args.batch_size_disc, args.batch_size_gene)

# model / data description
num_classes = args.cGAN_num_classes
gan_arch = args.GAN_arch
NC = args.num_channels
IMG_SIZE = args.img_size
max_label = args.max_label

# bookkeeping
save_niters_freq = args.save_niters_freq
visualize_freq = args.visualize_freq
num_workers = args.num_workers

# DiffAugment
use_DiffAugment = args.gan_DiffAugment
policy = args.gan_DiffAugment_policy
|
39
|
+
|
40
|
+
|
41
|
+
def _next_batch(dataloader, dataloader_iter, batch_idx):
    """Draw the next mini-batch, restarting the pass just before its last batch.

    The iterator is re-created once ``batch_idx + 1`` reaches
    ``len(dataloader)``, so the trailing (possibly incomplete, since the
    loader has no ``drop_last``) batch of a pass is not drawn. ``>=`` rather
    than ``==`` keeps a single-batch loader from running off the end of its
    iterator.

    Returns:
        ``(batch, dataloader_iter, batch_idx)`` with iterator and index updated.
    """
    if batch_idx + 1 >= len(dataloader):
        dataloader_iter = iter(dataloader)
        batch_idx = 0
    # Built-in next(): newer PyTorch DataLoader iterators dropped the
    # Python-2-era `.next()` alias.
    batch = next(dataloader_iter)
    return batch, dataloader_iter, batch_idx + 1


def _generator_loss(dis_out):
    """Generator loss for the module-level ``loss_type`` ("vanilla" or "hinge")."""
    if loss_type == "vanilla":
        dis_out = torch.sigmoid(dis_out)
        return -torch.mean(torch.log(dis_out + 1e-20))
    if loss_type == "hinge":
        return -dis_out.mean()
    raise ValueError("Unsupported loss_type_gan: {}".format(loss_type))


def _discriminator_loss(real_dis_out, fake_dis_out):
    """Discriminator loss for the module-level ``loss_type``.

    Returns ``(d_loss, real_out, fake_out)``. For the vanilla loss the returned
    outputs are sigmoid-activated; for the hinge loss they are unchanged.
    These returned outputs are what gets logged.
    """
    if loss_type == "vanilla":
        real_dis_out = torch.sigmoid(real_dis_out)
        fake_dis_out = torch.sigmoid(fake_dis_out)
        d_loss_real = -torch.log(real_dis_out + 1e-20)
        d_loss_fake = -torch.log(1 - fake_dis_out + 1e-20)
    elif loss_type == "hinge":
        d_loss_real = torch.relu(1.0 - real_dis_out)
        d_loss_fake = torch.relu(1.0 + fake_dis_out)
    else:
        raise ValueError("Unsupported loss_type_gan: {}".format(loss_type))
    d_loss = (d_loss_real + d_loss_fake).mean()
    return d_loss, real_dis_out, fake_dis_out


def train_cgan_concat(images, labels, netG, netD, save_images_folder, save_models_folder = None):
    """Train a conditional GAN whose continuous label is concatenated to its inputs.

    Hyper-parameters come from the module-level settings read via
    ``parse_opts()``. These include the loss type, iteration counts, learning
    rates, batch size, DiffAugment, and the visualisation and checkpoint
    frequencies.

    Args:
        images: training images aligned with ``labels``. They are wrapped in
            ``IMGs_dataset(..., normalize=True)``, which maps pixel values via
            ``(x / 255 - 0.5) / 0.5``.
        labels: training labels aligned with ``images``. They are fed to both
            networks as float tensors.
        netG: generator, called as ``netG(z, y)``.
        netD: discriminator, called as ``netD(x, y)``.
        save_images_folder: directory receiving a sample grid every
            ``visualize_freq`` iterations.
        save_models_folder: optional checkpoint directory. It is also used to
            resume from iteration ``resume_niters`` when that is > 0.

    Returns:
        ``(netG, netD)`` after training (moved to CUDA).

    Raises:
        ValueError: if the configured loss type is neither "vanilla" nor "hinge".
    """
    if loss_type not in ("vanilla", "hinge"):
        raise ValueError("Unsupported loss_type_gan: {}".format(loss_type))

    netG = netG.cuda()
    netD = netD.cuda()

    optimizerG = torch.optim.Adam(netG.parameters(), lr=lr_g, betas=(0.5, 0.999))
    optimizerD = torch.optim.Adam(netD.parameters(), lr=lr_d, betas=(0.5, 0.999))

    trainset = IMGs_dataset(images, labels, normalize=True)
    train_dataloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=num_workers)

    ckpt_dir_name = "cGAN_{}_nDsteps_{}_checkpoint_intrain".format(gan_arch, num_D_steps)

    if save_models_folder is not None and resume_niters > 0:
        save_file = os.path.join(save_models_folder, ckpt_dir_name, "cGAN_checkpoint_niters_{}.pth".format(resume_niters))
        checkpoint = torch.load(save_file)
        netG.load_state_dict(checkpoint['netG_state_dict'])
        netD.load_state_dict(checkpoint['netD_state_dict'])
        optimizerG.load_state_dict(checkpoint['optimizerG_state_dict'])
        optimizerD.load_state_dict(checkpoint['optimizerD_state_dict'])
        torch.set_rng_state(checkpoint['rng_state'])

    # Fixed noise and labels for the progress grid. There are n_row label
    # values, evenly spaced between the 5% and 95% quantiles of the training
    # labels, and each one is repeated n_col times (one grid row per label).
    n_row = 10
    n_col = n_row
    z_fixed = torch.randn(n_row * n_col, dim_gan, dtype=torch.float).cuda()
    start_label = np.quantile(labels, 0.05)
    end_label = np.quantile(labels, 0.95)
    selected_labels = np.linspace(start_label, end_label, num=n_row)
    y_fixed = np.repeat(selected_labels, n_col)
    print(y_fixed)
    y_fixed = torch.from_numpy(y_fixed).type(torch.float).view(-1, 1).cuda()

    batch_idx = 0
    dataloader_iter = iter(train_dataloader)

    start_time = timeit.default_timer()
    for niter in range(resume_niters, niters):

        # ---- Train generator: maximize log(D(G(z))) ----
        netG.train()

        # Only the labels of a real batch are used; they condition the fakes.
        (_, batch_fake_labels), dataloader_iter, batch_idx = _next_batch(train_dataloader, dataloader_iter, batch_idx)
        assert batch_size == batch_fake_labels.shape[0]
        # Labels are continuous: keep them as floats. Casting to long would
        # truncate them (e.g. every normalized label in [0, 1) becomes 0) and
        # would disagree with the float labels used for y_fixed and sampling.
        batch_fake_labels = batch_fake_labels.type(torch.float).cuda()

        z = torch.randn(batch_size, dim_gan, dtype=torch.float).cuda()
        batch_fake_images = netG(z, batch_fake_labels)

        # Loss measures the generator's ability to fool the discriminator.
        if use_DiffAugment:
            dis_out = netD(DiffAugment(batch_fake_images, policy=policy), batch_fake_labels)
        else:
            dis_out = netD(batch_fake_images, batch_fake_labels)
        g_loss = _generator_loss(dis_out)

        optimizerG.zero_grad()
        g_loss.backward()
        optimizerG.step()

        # ---- Train discriminator: maximize log(D(x)) + log(1 - D(G(z))) ----
        for _ in range(num_D_steps):
            (batch_train_images, batch_train_labels), dataloader_iter, batch_idx = _next_batch(train_dataloader, dataloader_iter, batch_idx)
            assert batch_size == batch_train_images.shape[0]
            batch_train_images = batch_train_images.type(torch.float).cuda()
            batch_train_labels = batch_train_labels.type(torch.float).cuda()

            # Real images are scored with their own labels. Fakes are scored
            # with the labels they were generated from, not with the labels of
            # this new real batch.
            if use_DiffAugment:
                real_dis_out = netD(DiffAugment(batch_train_images, policy=policy), batch_train_labels)
                fake_dis_out = netD(DiffAugment(batch_fake_images.detach(), policy=policy), batch_fake_labels)
            else:
                real_dis_out = netD(batch_train_images, batch_train_labels)
                fake_dis_out = netD(batch_fake_images.detach(), batch_fake_labels)

            d_loss, real_dis_out, fake_dis_out = _discriminator_loss(real_dis_out, fake_dis_out)

            optimizerD.zero_grad()
            d_loss.backward()
            optimizerD.step()

        if (niter + 1) % 20 == 0:
            print ("cGAN(concat)-%s: [Iter %d/%d] [D loss: %.4f] [G loss: %.4f] [D out real:%.4f] [D out fake:%.4f] [Time: %.4f]" % (gan_arch, niter+1, niters, d_loss.item(), g_loss.item(), real_dis_out.mean().item(),fake_dis_out.mean().item(), timeit.default_timer()-start_time))

        if (niter + 1) % visualize_freq == 0:
            netG.eval()
            with torch.no_grad():
                gen_imgs = netG(z_fixed, y_fixed).detach()
            save_image(gen_imgs.data, os.path.join(save_images_folder, '{}.png'.format(niter + 1)), nrow=n_row, normalize=True)

        if save_models_folder is not None and ((niter + 1) % save_niters_freq == 0 or (niter + 1) == niters):
            save_file = os.path.join(save_models_folder, ckpt_dir_name, "cGAN_checkpoint_niters_{}.pth".format(niter + 1))
            os.makedirs(os.path.dirname(save_file), exist_ok=True)
            torch.save({
                    'netG_state_dict': netG.state_dict(),
                    'netD_state_dict': netD.state_dict(),
                    'optimizerG_state_dict': optimizerG.state_dict(),
                    'optimizerD_state_dict': optimizerD.state_dict(),
                    'rng_state': torch.get_rng_state()
            }, save_file)
    #end for niter

    return netG, netD
|
193
|
+
|
194
|
+
|
195
|
+
def sample_cgan_concat_given_labels(netG, given_labels, batch_size = 100, denorm=True, to_numpy=True, verbose=True):
    """Generate one fake image per entry of ``given_labels``.

    Args:
        netG: pretrained generator, called as ``netG(z, c)``.
        given_labels: unnormalized float labels (array-like supporting ``/``
            and slicing). They are divided by the module-level ``max_label``
            before being fed to the generator.
        batch_size: images generated per forward pass, capped at
            ``len(given_labels)``.
        denorm: map outputs from [-1, 1] to uint8 values in [0, 255]. It
            asserts that the outputs lie in [-1, 1].
        to_numpy: return a numpy array instead of a torch tensor.
        verbose: display a progress bar.

    Returns:
        ``(fake_images, given_labels)``: the images, in label order, and the
        unmodified input labels.
    """
    n_fake = len(given_labels)
    norm_labels = given_labels / max_label
    batch_size = min(batch_size, n_fake)

    # Pad with a copy of the first batch so the final slice is always full.
    norm_labels = np.concatenate((norm_labels, norm_labels[0:batch_size]), axis=0)

    netG = netG.cuda()
    netG.eval()
    chunks = []
    with torch.no_grad():
        progress = SimpleProgressBar() if verbose else None
        done = 0
        while done < n_fake:
            z = torch.randn(batch_size, dim_gan, dtype=torch.float).cuda()
            c = torch.from_numpy(norm_labels[done:(done + batch_size)]).type(torch.float).cuda()
            imgs = netG(z, c)
            if denorm:
                # convert to uint8 right away to keep memory low
                assert imgs.max().item() <= 1.0 and imgs.min().item() >= -1.0
                imgs = ((imgs * 0.5 + 0.5) * 255.0).type(torch.uint8)
            chunks.append(imgs.detach().cpu())
            done += batch_size
            if progress is not None:
                progress.update(min(float(done) / n_fake, 1) * 100)

    # drop the images generated for the padding labels
    fake_images = torch.cat(chunks, dim=0)[0:n_fake]
    if to_numpy:
        fake_images = fake_images.numpy()
    return fake_images, given_labels
|
@@ -0,0 +1,181 @@
|
|
1
|
+
|
2
|
+
import torch
|
3
|
+
import torch.nn as nn
|
4
|
+
from torchvision.utils import save_image
|
5
|
+
import numpy as np
|
6
|
+
import os
|
7
|
+
import timeit
|
8
|
+
from PIL import Image
|
9
|
+
|
10
|
+
|
11
|
+
|
12
|
+
|
13
|
+
#-------------------------------------------------------------
|
14
|
+
def train_net_embed(net, net_name, trainloader, testloader, epochs=200, resume_epoch = 0, lr_base=0.01, lr_decay_factor=0.1, lr_decay_epochs=(80, 140), weight_decay=1e-4, path_to_ckpt = None):
    """Train the image-to-label regression network (x -> y) used for label embedding.

    Training uses an MSE loss and SGD with momentum 0.9 under a step-decay
    learning-rate schedule. Checkpoints can optionally be loaded from and
    saved to ``<path_to_ckpt>/embed_x2y_ckpt_in_train/``.

    Args:
        net: model whose forward returns a tuple ``(prediction, _)``. Only the
            first element is used. The model is moved to CUDA.
        net_name: unused; kept for interface compatibility.
        trainloader: yields ``(images, labels)`` batches.
        testloader: same format, or None to skip evaluation.
        epochs: total number of epochs, including resumed ones.
        resume_epoch: if > 0 and ``path_to_ckpt`` is set, load the checkpoint
            saved after this many epochs and continue from there.
        lr_base: initial learning rate.
        lr_decay_factor: factor applied once for every milestone reached.
        lr_decay_epochs: epoch milestones for the decay. Any sequence works;
            the default is a tuple so it cannot be mutated across calls.
        weight_decay: SGD weight decay.
        path_to_ckpt: checkpoint directory. Checkpoints are saved every 50
            epochs and at the last epoch; None disables checkpointing.

    Returns:
        The trained ``net``.
    """

    def adjust_learning_rate_1(optimizer, epoch):
        """Set every param group's lr to lr_base decayed once per milestone reached."""
        lr = lr_base
        for milestone in lr_decay_epochs:
            if epoch >= milestone:
                lr = lr * lr_decay_factor
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

    net = net.cuda()
    criterion = nn.MSELoss()
    optimizer = torch.optim.SGD(net.parameters(), lr = lr_base, momentum= 0.9, weight_decay=weight_decay)

    # resume training; load checkpoint
    if path_to_ckpt is not None and resume_epoch > 0:
        save_file = os.path.join(path_to_ckpt, "embed_x2y_ckpt_in_train", "embed_x2y_checkpoint_epoch_{}.pth".format(resume_epoch))
        checkpoint = torch.load(save_file)
        net.load_state_dict(checkpoint['net_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        torch.set_rng_state(checkpoint['rng_state'])

    start_tmp = timeit.default_timer()
    for epoch in range(resume_epoch, epochs):
        net.train()
        train_loss = 0
        adjust_learning_rate_1(optimizer, epoch)
        for batch_train_images, batch_train_labels in trainloader:
            batch_train_images = batch_train_images.type(torch.float).cuda()
            batch_train_labels = batch_train_labels.type(torch.float).view(-1,1).cuda()

            # forward pass
            outputs, _ = net(batch_train_images)
            loss = criterion(outputs, batch_train_labels)

            # backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_loss += loss.cpu().item()
        # mean of per-batch losses; max(..., 1) guards an empty loader
        train_loss = train_loss / max(len(trainloader), 1)

        if testloader is None:
            print('Train net_x2y for embedding: [epoch %d/%d] train_loss:%f Time:%.4f' % (epoch+1, epochs, train_loss, timeit.default_timer()-start_tmp))
        else:
            net.eval()  # batchnorm uses running statistics instead of batch statistics
            with torch.no_grad():
                test_loss = 0
                for batch_test_images, batch_test_labels in testloader:
                    batch_test_images = batch_test_images.type(torch.float).cuda()
                    batch_test_labels = batch_test_labels.type(torch.float).view(-1,1).cuda()
                    outputs, _ = net(batch_test_images)
                    loss = criterion(outputs, batch_test_labels)
                    test_loss += loss.cpu().item()
                test_loss = test_loss / max(len(testloader), 1)

            print('Train net_x2y for label embedding: [epoch %d/%d] train_loss:%f test_loss:%f Time:%.4f' % (epoch+1, epochs, train_loss, test_loss, timeit.default_timer()-start_tmp))

        # save checkpoint
        if path_to_ckpt is not None and (((epoch + 1) % 50 == 0) or (epoch + 1 == epochs)):
            save_file = os.path.join(path_to_ckpt, "embed_x2y_ckpt_in_train", "embed_x2y_checkpoint_epoch_{}.pth".format(epoch + 1))
            os.makedirs(os.path.dirname(save_file), exist_ok=True)
            torch.save({
                    'epoch': epoch,
                    'net_state_dict': net.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'rng_state': torch.get_rng_state()
            }, save_file)
    #end for epoch

    return net
|
97
|
+
|
98
|
+
|
99
|
+
|
100
|
+
|
101
|
+
###################################################################################
|
102
|
+
class label_dataset(torch.utils.data.Dataset):
    """Map-style dataset yielding the entries of ``labels`` one at a time."""

    def __init__(self, labels):
        super().__init__()
        self.labels = labels
        self.n_samples = len(labels)  # cached length of the label sequence

    def __len__(self):
        return self.n_samples

    def __getitem__(self, index):
        return self.labels[index]
|
116
|
+
|
117
|
+
|
118
|
+
def train_net_y2h(unique_labels_norm, net_y2h, net_embed, epochs=500, lr_base=0.01, lr_decay_factor=0.1, lr_decay_epochs=(150, 250, 350), weight_decay=1e-4, batch_size=128):
    """Train ``net_y2h`` (label -> hidden embedding) against the ``h2y`` head of ``net_embed``.

    Each step works as follows:
      1. Perturb the normalized labels with Gaussian noise (std 0.2) and clamp
         them to [0, 1].
      2. Map the noisy labels through ``net_y2h``.
      3. Decode the result back to a label with ``h2y``.
      4. Minimise the MSE against the noisy label.

    Only ``net_y2h``'s parameters are optimised.

    Args:
        unique_labels_norm: array of normalized unique labels, all in [0, 1].
        net_y2h: network to train. Batches are moved to CUDA, so it is assumed
            to already be on the GPU.
        net_embed: trained embedding network exposing ``h2y`` either directly
            or via ``.module`` (e.g. an ``nn.DataParallel`` wrapper).
        epochs: number of training epochs.
        lr_base: initial learning rate.
        lr_decay_factor: factor applied once per milestone reached.
        lr_decay_epochs: epoch milestones for the decay. Any sequence works;
            the default is a tuple so it cannot be mutated across calls.
        weight_decay: SGD weight decay.
        batch_size: labels per mini-batch.

    Returns:
        The trained ``net_y2h``.
    """

    def adjust_learning_rate_2(optimizer, epoch):
        """Set every param group's lr to lr_base decayed once per milestone reached."""
        lr = lr_base
        for milestone in lr_decay_epochs:
            if epoch >= milestone:
                lr = lr * lr_decay_factor
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

    assert np.max(unique_labels_norm)<=1 and np.min(unique_labels_norm)>=0
    trainset = label_dataset(unique_labels_norm)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True)

    net_embed.eval()
    # h2y converts embedded labels back to labels. Accept both a wrapped
    # (``.module``) and a bare embedding network.
    net_h2y = getattr(net_embed, 'module', net_embed).h2y
    optimizer_y2h = torch.optim.SGD(net_y2h.parameters(), lr = lr_base, momentum= 0.9, weight_decay=weight_decay)
    criterion = nn.MSELoss()

    start_tmp = timeit.default_timer()
    for epoch in range(epochs):
        net_y2h.train()
        train_loss = 0
        adjust_learning_rate_2(optimizer_y2h, epoch)
        for batch_labels in trainloader:

            batch_labels = batch_labels.type(torch.float).view(-1,1).cuda()

            # Gaussian noise added to the labels
            batch_size_curr = len(batch_labels)
            batch_gamma = np.random.normal(0, 0.2, batch_size_curr)
            batch_gamma = torch.from_numpy(batch_gamma).view(-1,1).type(torch.float).cuda()
            batch_labels_noise = torch.clamp(batch_labels + batch_gamma, 0.0, 1.0)

            # forward pass: label -> embedding -> reconstructed label
            batch_hiddens_noise = net_y2h(batch_labels_noise)
            batch_rec_labels_noise = net_h2y(batch_hiddens_noise)
            loss = criterion(batch_rec_labels_noise, batch_labels_noise)

            # Backward pass. Gradients also flow into h2y's parameters, but
            # only net_y2h is stepped by this optimizer.
            optimizer_y2h.zero_grad()
            loss.backward()
            optimizer_y2h.step()

            train_loss += loss.cpu().item()
        train_loss = train_loss / len(trainloader)

        print('\n Train net_y2h: [epoch %d/%d] train_loss:%f Time:%.4f' % (epoch+1, epochs, train_loss, timeit.default_timer()-start_tmp))
    #end for epoch

    return net_y2h
|
@@ -0,0 +1,120 @@
|
|
1
|
+
"""
|
2
|
+
Some helpful functions
|
3
|
+
|
4
|
+
"""
|
5
|
+
import numpy as np
|
6
|
+
import torch
|
7
|
+
import torch.nn as nn
|
8
|
+
import torchvision
|
9
|
+
import matplotlib.pyplot as plt
|
10
|
+
import matplotlib as mpl
|
11
|
+
from torch.nn import functional as F
|
12
|
+
import sys
|
13
|
+
import PIL
|
14
|
+
from PIL import Image
|
15
|
+
|
16
|
+
|
17
|
+
|
18
|
+
# ################################################################################
|
19
|
+
# Progress Bar
|
20
|
+
class SimpleProgressBar():
    """Single-line text progress bar redrawn in place on stdout."""

    def __init__(self, width=50):
        self.width = width
        self.last_x = -1  # last integer percentage drawn; -1 means nothing drawn yet

    def update(self, x):
        """Redraw the bar for progress ``x``, given in percent (0 to 100).

        The redraw is skipped when the integer percentage is unchanged. A
        newline is printed when ``x`` equals exactly 100.
        """
        assert 0 <= x <= 100  # `x`: progress in percent ( between 0 and 100)
        percent = int(x)
        if percent == self.last_x:
            return
        self.last_x = percent
        filled = int(self.width * (x / 100.0))
        bar = '#' * filled + '.' * (self.width - filled)
        sys.stdout.write('\r%d%% [%s]' % (percent, bar))
        sys.stdout.flush()
        if x == 100:
            print('')
|
34
|
+
|
35
|
+
|
36
|
+
|
37
|
+
################################################################################
|
38
|
+
# torch dataset from numpy array
|
39
|
+
class IMGs_dataset(torch.utils.data.Dataset):
    """Torch dataset over in-memory images, optionally paired with labels.

    Args:
        images: indexable collection of images.
        labels: optional indexable labels, same length as ``images``.
        normalize: if True, map each image via ``(x / 255 - 0.5) / 0.5``,
            i.e. [0, 255] -> [-1, 1].

    Raises:
        Exception: if ``labels`` is given and its length differs from ``images``.
    """

    def __init__(self, images, labels=None, normalize=False):
        super().__init__()
        self.images = images
        self.n_images = len(images)
        self.labels = labels
        if labels is not None and len(images) != len(labels):
            raise Exception(f'images ({len(images)}) and labels ({len(labels)}) do not have the same length!!!')
        self.normalize = normalize

    def __getitem__(self, index):
        """Return ``image`` or ``(image, label)`` depending on whether labels were given."""
        image = self.images[index]
        if self.normalize:
            image = (image / 255.0 - 0.5) / 0.5
        if self.labels is None:
            return image
        return (image, self.labels[index])

    def __len__(self):
        return self.n_images
|
68
|
+
|
69
|
+
|
70
|
+
def PlotLoss(loss, filename):
    """Plot a per-epoch training-loss curve and save it to ``filename``.

    Args:
        loss: sequence of loss values, one per epoch (the x axis starts at 1).
        filename: output path passed to ``savefig``.
    """
    x_axis = np.arange(start = 1, stop = len(loss)+1)
    plt.switch_backend('agg')
    # The bare 'seaborn' style was renamed 'seaborn-v0_8' in matplotlib 3.6
    # and removed in 3.8, so use whichever name this installation provides.
    if 'seaborn' in mpl.style.available:
        mpl.style.use('seaborn')
    elif 'seaborn-v0_8' in mpl.style.available:
        mpl.style.use('seaborn-v0_8')
    fig = plt.figure()
    ax = fig.add_subplot(111)
    ax.plot(x_axis, np.array(loss))
    ax.set_xlabel("epoch")
    ax.set_ylabel("training loss")
    fig.savefig(filename)
    # Close the figure; otherwise every call leaves one figure open.
    plt.close(fig)
|
83
|
+
|
84
|
+
|
85
|
+
# compute entropy of class labels; labels is a numpy array
|
86
|
+
def compute_entropy(labels, base=None):
    """Return the Shannon entropy of the empirical distribution of ``labels``.

    Args:
        labels: array-like of discrete labels.
        base: logarithm base; None means the natural log (nats).
    """
    _, counts = np.unique(labels, return_counts=True)
    probs = counts / counts.sum()
    log_base = np.log(np.e if base is None else base)
    return -(probs * np.log(probs) / log_base).sum()
|
91
|
+
|
92
|
+
def predict_class_labels(net, images, batch_size=500, verbose=False, num_workers=0):
    """Predict class indices for ``images`` with the classifier ``net``.

    ``net(batch)`` must return a tuple whose first element holds per-class
    scores. The arg-max over dim 1 is the predicted class. Images are fed as
    float tensors without normalisation.

    Args:
        net: classifier; moved to CUDA and put in eval mode.
        images: indexable collection of images.
        batch_size: images per forward pass, capped at ``len(images)``.
        verbose: display a progress bar.
        num_workers: DataLoader worker processes.

    Returns:
        1-D float numpy array of length ``len(images)`` with the predicted
        class indices. It is empty for empty input.
    """
    n = len(images)
    if n == 0:
        # DataLoader rejects batch_size=0 and the progress update would divide
        # by zero; there is nothing to predict anyway.
        return np.zeros(0)

    net = net.cuda()
    net.eval()

    batch_size = min(batch_size, n)
    dataset_pred = IMGs_dataset(images, normalize=False)
    dataloader_pred = torch.utils.data.DataLoader(dataset_pred, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    # The unshuffled loader yields exactly n images in total, so n slots suffice.
    class_labels_pred = np.zeros(n)
    with torch.no_grad():
        nimgs_got = 0
        pb = SimpleProgressBar() if verbose else None
        for batch_images in dataloader_pred:
            batch_images = batch_images.type(torch.float).cuda()
            batch_size_curr = len(batch_images)

            outputs, _ = net(batch_images)
            _, batch_class_labels_pred = torch.max(outputs.data, 1)
            class_labels_pred[nimgs_got:(nimgs_got + batch_size_curr)] = batch_class_labels_pred.detach().cpu().numpy().reshape(-1)

            nimgs_got += batch_size_curr
            if pb is not None:
                pb.update((float(nimgs_got) / n) * 100)
    return class_labels_pred
|
@@ -28,13 +28,25 @@ myosotis_researches/CcGAN/train_128/train_cgan.py,sha256=bYJbBskTpESfCG2uj52RW9z
|
|
28
28
|
myosotis_researches/CcGAN/train_128/train_cgan_concat.py,sha256=PYctY3IZiHGh4TshXx3mUZBf9su_8NuV_D8InkxKQZ4,8940
|
29
29
|
myosotis_researches/CcGAN/train_128/train_net_for_label_embed.py,sha256=4j6r4_o4rXgAN4MdUQL-TXqZJpbhH7d9gWQR8YzBlXw,6976
|
30
30
|
myosotis_researches/CcGAN/train_128/utils.py,sha256=B-V6ct4WDisVVCOLO0W7VIBL8StPVNJJTZZ2b2NkMFU,3766
|
31
|
+
myosotis_researches/CcGAN/train_128_output_10/DiffAugment_pytorch.py,sha256=HxMZdMpE4KvwY3AsNgci8VNEFV3cNALg3obTyELlCaY,3025
|
32
|
+
myosotis_researches/CcGAN/train_128_output_10/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
33
|
+
myosotis_researches/CcGAN/train_128_output_10/eval_metrics.py,sha256=nqDh0xhumSmpMSk2HElCR6LiUtydaFLRy6rGdt39sSg,7169
|
34
|
+
myosotis_researches/CcGAN/train_128_output_10/opts.py,sha256=oIScD7A6GdcWI_ptB-k3Df5WWoWglf8bp32v3pNlerY,5374
|
35
|
+
myosotis_researches/CcGAN/train_128_output_10/pretrain_AE.py,sha256=VAbe5kSfvTl2k0aV6eV3XnMMV28KrIzB2EglahXEXiU,10746
|
36
|
+
myosotis_researches/CcGAN/train_128_output_10/pretrain_CNN_class.py,sha256=tnGH5rKeJyWY29esBXlFJx9Qr30uB6W5cMw1Wge-Leg,9247
|
37
|
+
myosotis_researches/CcGAN/train_128_output_10/pretrain_CNN_regre.py,sha256=RSmHQ5Z3Jbq2TKU2D2p-HJriXxKDw4YG-B8gstwefcI,8953
|
38
|
+
myosotis_researches/CcGAN/train_128_output_10/train_ccgan.py,sha256=ZykhfbmRFCUpP9JAFjIaO4B3nSLl6-KeQTqCa1WJltY,13335
|
39
|
+
myosotis_researches/CcGAN/train_128_output_10/train_cgan.py,sha256=bYJbBskTpESfCG2uj52RW9zLh3Zod4e8Uop7rim3dmE,9698
|
40
|
+
myosotis_researches/CcGAN/train_128_output_10/train_cgan_concat.py,sha256=PYctY3IZiHGh4TshXx3mUZBf9su_8NuV_D8InkxKQZ4,8940
|
41
|
+
myosotis_researches/CcGAN/train_128_output_10/train_net_for_label_embed.py,sha256=4j6r4_o4rXgAN4MdUQL-TXqZJpbhH7d9gWQR8YzBlXw,6976
|
42
|
+
myosotis_researches/CcGAN/train_128_output_10/utils.py,sha256=B-V6ct4WDisVVCOLO0W7VIBL8StPVNJJTZZ2b2NkMFU,3766
|
31
43
|
myosotis_researches/CcGAN/utils/__init__.py,sha256=Pu9COV4zcXHGXuczhObersyeshVChmlEtwqp8VLUDxw,300
|
32
44
|
myosotis_researches/CcGAN/utils/concat_image_horizontal.py,sha256=e6WsfO9IiSoP8zkZNz7IGimPUASr9VvyJUJdF-d40iw,954
|
33
45
|
myosotis_researches/CcGAN/utils/concat_image_vertical.py,sha256=97-SuE8ZWpaeBm_ed6MAEaUOvtpzlYq_X3yWt4OEUTY,951
|
34
46
|
myosotis_researches/CcGAN/utils/make_h5.py,sha256=Q5OW1JA35ormmsrlAJp6XdC6x0uJBRNjsE31wM3zBiI,1422
|
35
47
|
myosotis_researches/CcGAN/utils/print_hdf5_structure.py,sha256=leaR8H3GhlX6EuIXDMh36xG2zBdV-XlJkaXBuoorl6I,320
|
36
|
-
myosotis_researches-0.0.
|
37
|
-
myosotis_researches-0.0.
|
38
|
-
myosotis_researches-0.0.
|
39
|
-
myosotis_researches-0.0.
|
40
|
-
myosotis_researches-0.0.
|
48
|
+
myosotis_researches-0.0.19.dist-info/licenses/LICENSE,sha256=OXLcl0T2SZ8Pmy2_dmlvKuetivmyPd5m1q-Gyd-zaYY,35149
|
49
|
+
myosotis_researches-0.0.19.dist-info/METADATA,sha256=o-O5s4Ir9KGNpduL-hYtSFywxpc_j9RfB5i6MLVe2uw,765
|
50
|
+
myosotis_researches-0.0.19.dist-info/WHEEL,sha256=pxyMxgL8-pra_rKaQ4drOZAegBVuX-G_4nRHjjgWbmo,91
|
51
|
+
myosotis_researches-0.0.19.dist-info/top_level.txt,sha256=zxAiMn5eyZNJM28MewTAkgi_RZJMbfWbzVR-KF0LdZE,20
|
52
|
+
myosotis_researches-0.0.19.dist-info/RECORD,,
|
File without changes
|
{myosotis_researches-0.0.18.dist-info → myosotis_researches-0.0.19.dist-info}/licenses/LICENSE
RENAMED
File without changes
|
File without changes
|