Myosotis-Researches 0.1.7__py3-none-any.whl → 0.1.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- myosotis_researches/CcGAN/train/__init__.py +4 -0
- myosotis_researches/CcGAN/{train_128_output_10 → train}/train_ccgan.py +4 -4
- myosotis_researches/CcGAN/{train_128 → train}/train_cgan.py +1 -3
- myosotis_researches/CcGAN/{train_128 → train}/train_cgan_concat.py +1 -3
- myosotis_researches/CcGAN/utils/__init__.py +2 -1
- myosotis_researches/CcGAN/utils/train.py +94 -3
- {myosotis_researches-0.1.7.dist-info → myosotis_researches-0.1.9.dist-info}/METADATA +1 -1
- myosotis_researches-0.1.9.dist-info/RECORD +24 -0
- myosotis_researches/CcGAN/models_128/CcGAN_SAGAN.py +0 -301
- myosotis_researches/CcGAN/models_128/ResNet_class_eval.py +0 -141
- myosotis_researches/CcGAN/models_128/ResNet_embed.py +0 -188
- myosotis_researches/CcGAN/models_128/ResNet_regre_eval.py +0 -175
- myosotis_researches/CcGAN/models_128/__init__.py +0 -7
- myosotis_researches/CcGAN/models_128/autoencoder.py +0 -119
- myosotis_researches/CcGAN/models_128/cGAN_SAGAN.py +0 -276
- myosotis_researches/CcGAN/models_128/cGAN_concat_SAGAN.py +0 -245
- myosotis_researches/CcGAN/models_256/CcGAN_SAGAN.py +0 -303
- myosotis_researches/CcGAN/models_256/ResNet_class_eval.py +0 -142
- myosotis_researches/CcGAN/models_256/ResNet_embed.py +0 -188
- myosotis_researches/CcGAN/models_256/ResNet_regre_eval.py +0 -178
- myosotis_researches/CcGAN/models_256/__init__.py +0 -7
- myosotis_researches/CcGAN/models_256/autoencoder.py +0 -133
- myosotis_researches/CcGAN/models_256/cGAN_SAGAN.py +0 -280
- myosotis_researches/CcGAN/models_256/cGAN_concat_SAGAN.py +0 -249
- myosotis_researches/CcGAN/train_128/DiffAugment_pytorch.py +0 -76
- myosotis_researches/CcGAN/train_128/__init__.py +0 -0
- myosotis_researches/CcGAN/train_128/eval_metrics.py +0 -205
- myosotis_researches/CcGAN/train_128/opts.py +0 -87
- myosotis_researches/CcGAN/train_128/pretrain_AE.py +0 -268
- myosotis_researches/CcGAN/train_128/pretrain_CNN_class.py +0 -251
- myosotis_researches/CcGAN/train_128/pretrain_CNN_regre.py +0 -255
- myosotis_researches/CcGAN/train_128/train_ccgan.py +0 -303
- myosotis_researches/CcGAN/train_128/utils.py +0 -120
- myosotis_researches/CcGAN/train_128_output_10/DiffAugment_pytorch.py +0 -76
- myosotis_researches/CcGAN/train_128_output_10/__init__.py +0 -0
- myosotis_researches/CcGAN/train_128_output_10/eval_metrics.py +0 -205
- myosotis_researches/CcGAN/train_128_output_10/opts.py +0 -87
- myosotis_researches/CcGAN/train_128_output_10/pretrain_AE.py +0 -268
- myosotis_researches/CcGAN/train_128_output_10/pretrain_CNN_class.py +0 -251
- myosotis_researches/CcGAN/train_128_output_10/pretrain_CNN_regre.py +0 -255
- myosotis_researches/CcGAN/train_128_output_10/train_cgan.py +0 -254
- myosotis_researches/CcGAN/train_128_output_10/train_cgan_concat.py +0 -242
- myosotis_researches/CcGAN/train_128_output_10/train_net_for_label_embed.py +0 -181
- myosotis_researches/CcGAN/train_128_output_10/utils.py +0 -120
- myosotis_researches-0.1.7.dist-info/RECORD +0 -59
- /myosotis_researches/CcGAN/{train_128 → train}/train_net_for_label_embed.py +0 -0
- {myosotis_researches-0.1.7.dist-info → myosotis_researches-0.1.9.dist-info}/WHEEL +0 -0
- {myosotis_researches-0.1.7.dist-info → myosotis_researches-0.1.9.dist-info}/licenses/LICENSE +0 -0
- {myosotis_researches-0.1.7.dist-info → myosotis_researches-0.1.9.dist-info}/top_level.txt +0 -0
@@ -1,268 +0,0 @@
|
|
1
|
-
|
2
|
-
import os
|
3
|
-
import argparse
|
4
|
-
import shutil
|
5
|
-
import timeit
|
6
|
-
import torch
|
7
|
-
import torchvision
|
8
|
-
import torchvision.transforms as transforms
|
9
|
-
import numpy as np
|
10
|
-
import torch.nn as nn
|
11
|
-
import torch.backends.cudnn as cudnn
|
12
|
-
import random
|
13
|
-
import matplotlib.pyplot as plt
|
14
|
-
import matplotlib as mpl
|
15
|
-
from torch import autograd
|
16
|
-
from torchvision.utils import save_image
|
17
|
-
import csv
|
18
|
-
from tqdm import tqdm
|
19
|
-
import gc
|
20
|
-
import h5py
|
21
|
-
|
22
|
-
|
23
|
-
#############################
|
24
|
-
# Settings
|
25
|
-
#############################
|
26
|
-
|
27
|
-
# Command-line settings for pre-training the autoencoder used when computing FID.
parser = argparse.ArgumentParser(description='Pre-train AE for computing FID')
parser.add_argument('--root_path', type=str, default='')
parser.add_argument('--data_path', type=str, default='')
parser.add_argument('--num_workers', type=int, default=0)
parser.add_argument('--dim_bottleneck', type=int, default=512)
parser.add_argument('--epochs', type=int, default=200, metavar='N',
                    help='number of epochs to train the AE (default: 200)')
parser.add_argument('--resume_epoch', type=int, default=0)
parser.add_argument('--batch_size_train', type=int, default=128, metavar='N',
                    help='input batch size for training')
parser.add_argument('--batch_size_valid', type=int, default=10, metavar='N',
                    help='input batch size for testing')
parser.add_argument('--base_lr', type=float, default=1e-3,
                    help='learning rate, default=1e-3')
# decay the learning rate every `lr_decay_epochs` epochs
parser.add_argument('--lr_decay_epochs', type=int, default=50)
parser.add_argument('--lr_decay_factor', type=float, default=0.1)
parser.add_argument('--lambda_sparsity', type=float, default=1e-4, help='penalty for sparsity')
# NOTE: the flag keeps its historical spelling so existing command lines keep working.
parser.add_argument('--weight_dacay', type=float, default=1e-4,
                    help='Weight decay, default=1e-4')
parser.add_argument('--seed', type=int, default=2020, metavar='S',
                    help='random seed (default: 2020)')
parser.add_argument('--CVMode', action='store_true', default=False,
                    help='CV mode?')
# Same flag as the CNN pre-training script; the default keeps the previous 10% split.
parser.add_argument('--valid_proport', type=float, default=0.1,
                    help='Proportion of validation samples (used only in CV mode)')
parser.add_argument('--img_size', type=int, default=128, metavar='N')
parser.add_argument('--min_label', type=float, default=0.0)
parser.add_argument('--max_label', type=float, default=90.0)
args = parser.parse_args()

wd = args.root_path
os.chdir(wd)
from ..models_128 import *
from .utils import IMGs_dataset, SimpleProgressBar

# some parameters in the opts
dim_bottleneck = args.dim_bottleneck
epochs = args.epochs
base_lr = args.base_lr
lr_decay_epochs = args.lr_decay_epochs
lr_decay_factor = args.lr_decay_factor
resume_epoch = args.resume_epoch
lambda_sparsity = args.lambda_sparsity

# random seeds (python, torch, numpy) and deterministic cuDNN
random.seed(args.seed)
torch.manual_seed(args.seed)
torch.backends.cudnn.deterministic = True
cudnn.benchmark = False
np.random.seed(args.seed)

# directories for checkpoints and reconstruction previews
save_models_folder = wd + '/output/eval_models'
os.makedirs(save_models_folder, exist_ok=True)
save_AE_images_in_train_folder = save_models_folder + '/AE_lambda_{}_images_in_train'.format(lambda_sparsity)
os.makedirs(save_AE_images_in_train_folder, exist_ok=True)
save_AE_images_in_valid_folder = save_models_folder + '/AE_lambda_{}_images_in_valid'.format(lambda_sparsity)
os.makedirs(save_AE_images_in_valid_folder, exist_ok=True)


###########################################################################################################
# Data
###########################################################################################################
# data loader
data_filename = args.data_path + '/Ra_' + str(args.img_size) + 'x' + str(args.img_size) + '.h5'
# the context manager closes the HDF5 file even if reading a dataset fails
with h5py.File(data_filename, 'r') as hf:
    labels = hf['labels'][:].astype(float)
    images = hf['images'][:]
N_all = len(images)
assert len(images) == len(labels)

# keep only samples whose label lies strictly inside (min_label, max_label)
q1 = args.min_label
q2 = args.max_label
indx = np.where((labels > q1) & (labels < q2))[0]
labels = labels[indx]
images = images[indx]
assert len(labels) == len(images)

# define training and validation sets
if args.CVMode:
    # random split: a `valid_proport` fraction for validation, the rest for training
    valid_prop = args.valid_proport
    indx_all = np.arange(len(images))
    np.random.shuffle(indx_all)
    num_valid = int(valid_prop * len(images))
    indx_valid = indx_all[:num_valid]
    indx_train = indx_all[num_valid:]

    trainset = IMGs_dataset(images[indx_train], labels=None, normalize=True)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size_train, shuffle=True, num_workers=args.num_workers)
    validset = IMGs_dataset(images[indx_valid], labels=None, normalize=True)
    validloader = torch.utils.data.DataLoader(validset, batch_size=args.batch_size_valid, shuffle=False, num_workers=args.num_workers)

else:
    trainset = IMGs_dataset(images, labels=None, normalize=True)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size_train, shuffle=True, num_workers=args.num_workers)
|
123
|
-
|
124
|
-
|
125
|
-
###########################################################################################################
|
126
|
-
# Necessary functions
|
127
|
-
###########################################################################################################
|
128
|
-
|
129
|
-
def adjust_learning_rate(epoch, epochs, optimizer, base_lr, lr_decay_epochs, lr_decay_factor):
    """Set a step-decayed learning rate on every parameter group of `optimizer`.

    The rate starts at `base_lr` and is multiplied by `lr_decay_factor` once for
    every completed block of `lr_decay_epochs` epochs, with at most
    `epochs // lr_decay_epochs` decays in total.

    Args:
        epoch: current (0-based) epoch index.
        epochs: total number of training epochs; caps the number of decays.
        optimizer: object exposing `param_groups` (e.g. a torch optimizer).
        base_lr: initial learning rate.
        lr_decay_epochs: decay period in epochs; a value <= 0 disables decay
            (0 used to raise ZeroDivisionError, negative values already meant
            "no decay").
        lr_decay_factor: multiplicative factor applied at each decay step.
    """
    lr = base_lr
    if lr_decay_epochs > 0:
        num_decays = min(epoch // lr_decay_epochs, epochs // lr_decay_epochs)
        # repeated multiplication (rather than `factor ** n`) reproduces the
        # exact floating-point values of the original step-by-step loop
        for _ in range(num_decays):
            lr *= lr_decay_factor

    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
138
|
-
|
139
|
-
def train_AE():
    """Train the module-level encoder/decoder pair as a sparse autoencoder.

    Uses the module-level `net_encoder`, `net_decoder`, `trainloader` and the
    hyper-parameters parsed from the command line. When `resume_epoch > 0` it
    first restores networks, optimizer, torch RNG state and iteration counter
    from the matching in-training checkpoint. A reconstruction preview is saved
    every 100 iterations, progress is printed every 20 iterations, and an
    in-training checkpoint is written every 50 epochs.

    Returns:
        (net_encoder, net_decoder): the trained networks.
    """
    # define optimizer; weight decay now honors --weight_dacay (default 1e-4, as before)
    params = list(net_encoder.parameters()) + list(net_decoder.parameters())
    optimizer = torch.optim.Adam(params, lr=base_lr, betas=(0.5, 0.999), weight_decay=args.weight_dacay)

    # reconstruction criterion
    criterion = nn.MSELoss()

    if resume_epoch > 0:
        print("Loading ckpt to resume training AE >>>")
        ckpt_fullpath = save_models_folder + "/AE_checkpoint_intrain/AE_checkpoint_epoch_{}_lambda_{}.pth".format(resume_epoch, lambda_sparsity)
        checkpoint = torch.load(ckpt_fullpath)
        net_encoder.load_state_dict(checkpoint['net_encoder_state_dict'])
        net_decoder.load_state_dict(checkpoint['net_decoder_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        torch.set_rng_state(checkpoint['rng_state'])
        gen_iterations = checkpoint['gen_iterations']
    else:
        gen_iterations = 0

    start_time = timeit.default_timer()
    for epoch in range(resume_epoch, epochs):

        adjust_learning_rate(epoch, epochs, optimizer, base_lr, lr_decay_epochs, lr_decay_factor)

        # nothing inside the epoch switches to eval mode, so once per epoch is enough
        net_encoder.train()
        net_decoder.train()

        train_loss = 0

        for batch_idx, batch_real_images in enumerate(trainloader):

            batch_size_curr = batch_real_images.shape[0]
            batch_real_images = batch_real_images.type(torch.float).cuda()

            batch_features = net_encoder(batch_real_images)
            batch_recons_images = net_decoder(batch_features)

            # Reconstruction loss plus an L1 sparsity penalty on the bottleneck features,
            # based on https://debuggercafe.com/sparse-autoencoders-using-l1-regularization-with-pytorch/
            # abs() makes this a true L1 penalty: without it, negative activations would
            # decrease the loss instead of being pushed towards zero. It is identical to
            # the old term whenever the encoder output is non-negative.
            loss = criterion(batch_recons_images, batch_real_images) + lambda_sparsity * batch_features.abs().mean()

            # backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_loss += loss.item()

            gen_iterations += 1

            if gen_iterations % 100 == 0:
                # preview grid of up to 10x10 reconstructions of the current batch
                n_row = min(10, int(np.sqrt(batch_size_curr)))
                with torch.no_grad():
                    batch_recons_images = net_decoder(net_encoder(batch_real_images[0:n_row**2]))
                    batch_recons_images = batch_recons_images.detach().cpu()
                save_image(batch_recons_images.data, save_AE_images_in_train_folder + '/{}.png'.format(gen_iterations), nrow=n_row, normalize=True)

            if gen_iterations % 20 == 0:
                print("AE+lambda{}: [step {}] [epoch {}/{}] [train loss {}] [Time {}]".format(lambda_sparsity, gen_iterations, epoch+1, epochs, train_loss/(batch_idx+1), timeit.default_timer()-start_time) )
        # end for batch_idx

        if (epoch+1) % 50 == 0:
            save_file = save_models_folder + "/AE_checkpoint_intrain/AE_checkpoint_epoch_{}_lambda_{}.pth".format(epoch+1, lambda_sparsity)
            os.makedirs(os.path.dirname(save_file), exist_ok=True)
            torch.save({
                'gen_iterations': gen_iterations,
                'net_encoder_state_dict': net_encoder.state_dict(),
                'net_decoder_state_dict': net_decoder.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'rng_state': torch.get_rng_state()
            }, save_file)
    # end for epoch

    return net_encoder, net_decoder
|
218
|
-
|
219
|
-
|
220
|
-
if args.CVMode:
    def valid_AE():
        """Reconstruct each validation batch and save real/reconstructed grids.

        The module-level encoder and decoder are put in eval mode and run without
        gradients; for batch k, `k_recons.png` and `k_real.png` are written to
        `save_AE_images_in_valid_folder`. Returns None.
        """
        net_encoder.eval()
        net_decoder.eval()
        with torch.no_grad():
            for k, batch in enumerate(validloader):
                batch = batch.type(torch.float).cuda()
                recons = net_decoder(net_encoder(batch))
                save_image(recons.data, save_AE_images_in_valid_folder + '/{}_recons.png'.format(k), nrow=10, normalize=True)
                save_image(batch.data, save_AE_images_in_valid_folder + '/{}_real.png'.format(k), nrow=10, normalize=True)
        return None
|
232
|
-
|
233
|
-
|
234
|
-
|
235
|
-
###########################################################################################################
|
236
|
-
# Training and validation
|
237
|
-
###########################################################################################################
|
238
|
-
|
239
|
-
# model initialization
|
240
|
-
# Build both halves of the autoencoder on the GPU and wrap them for multi-GPU use.
net_encoder = nn.DataParallel(encoder(dim_bottleneck=args.dim_bottleneck).cuda())
net_decoder = nn.DataParallel(decoder(dim_bottleneck=args.dim_bottleneck).cuda())

filename_ckpt = save_models_folder + '/ckpt_AE_epoch_{}_seed_{}_CVMode_{}.pth'.format(args.epochs, args.seed, args.CVMode)

# Reuse a finished checkpoint when present; otherwise train from scratch (or resume) and save.
if os.path.isfile(filename_ckpt):
    print("\n Ckpt already exists")
    print("\n Loading...")
    checkpoint = torch.load(filename_ckpt)
    net_encoder.load_state_dict(checkpoint['net_encoder_state_dict'])
    net_decoder.load_state_dict(checkpoint['net_decoder_state_dict'])
else:
    print("\n Begin training AE: ")
    start = timeit.default_timer()
    net_encoder, net_decoder = train_AE()
    stop = timeit.default_timer()
    print("Time elapses: {}s".format(stop - start))
    final_state = {
        'net_encoder_state_dict': net_encoder.state_dict(),
        'net_decoder_state_dict': net_decoder.state_dict(),
    }
    torch.save(final_state, filename_ckpt)

# In CV mode, dump reconstructions of the held-out images.
if args.CVMode:
    _ = valid_AE()
|
@@ -1,251 +0,0 @@
|
|
1
|
-
"""
|
2
|
-
|
3
|
-
Pre-train a CNN on the whole dataset for evaluation purposes
|
4
|
-
|
5
|
-
"""
|
6
|
-
import os
|
7
|
-
import argparse
|
8
|
-
import shutil
|
9
|
-
import timeit
|
10
|
-
|
11
|
-
import torch
|
12
|
-
import torchvision
|
13
|
-
import torchvision.transforms as transforms
|
14
|
-
import numpy as np
|
15
|
-
import torch.nn as nn
|
16
|
-
import torch.backends.cudnn as cudnn
|
17
|
-
import random
|
18
|
-
import matplotlib.pyplot as plt
|
19
|
-
import matplotlib as mpl
|
20
|
-
from torch import autograd
|
21
|
-
from torchvision.utils import save_image
|
22
|
-
import csv
|
23
|
-
from tqdm import tqdm
|
24
|
-
import gc
|
25
|
-
import h5py
|
26
|
-
|
27
|
-
from models import *
|
28
|
-
from utils import IMGs_dataset
|
29
|
-
|
30
|
-
|
31
|
-
#############################
|
32
|
-
# Settings
|
33
|
-
#############################
|
34
|
-
|
35
|
-
# Command-line settings for pre-training the classification CNN used in evaluation.
parser = argparse.ArgumentParser(description='Pre-train CNNs')
parser.add_argument('--root_path', type=str, default='')
parser.add_argument('--data_path', type=str, default='')
parser.add_argument('--num_workers', type=int, default=0)
parser.add_argument('--CNN', type=str, default='ResNet34_class',
                    help='CNN for training; ResNetXX')
parser.add_argument('--epochs', type=int, default=200, metavar='N',
                    help='number of epochs to train CNNs (default: 200)')
parser.add_argument('--batch_size_train', type=int, default=128, metavar='N',
                    help='input batch size for training')
parser.add_argument('--batch_size_valid', type=int, default=10, metavar='N',
                    help='input batch size for testing')
parser.add_argument('--base_lr', type=float, default=0.01,
                    help='learning rate, default=0.01')
# NOTE: the flag keeps its historical spelling so existing command lines keep working.
parser.add_argument('--weight_dacay', type=float, default=1e-4,
                    help='Weight decay, default=1e-4')
parser.add_argument('--seed', type=int, default=2020, metavar='S',
                    help='random seed (default: 2020)')
parser.add_argument('--CVMode', action='store_true', default=False,
                    help='CV mode?')
parser.add_argument('--valid_proport', type=float, default=0.1,
                    help='Proportion of validation samples')
parser.add_argument('--img_size', type=int, default=128, metavar='N')
parser.add_argument('--min_label', type=float, default=0.0)
parser.add_argument('--max_label', type=float, default=90.0)
args = parser.parse_args()


wd = args.root_path
os.chdir(wd)
from ..models_128 import *
from .utils import IMGs_dataset

# cuda
device = torch.device("cuda")
ngpu = torch.cuda.device_count()  # number of gpus

# random seeds (python, torch, numpy) and deterministic cuDNN
random.seed(args.seed)
torch.manual_seed(args.seed)
torch.backends.cudnn.deterministic = True
np.random.seed(args.seed)

# directory for checkpoints
save_models_folder = wd + '/output/eval_models/'
os.makedirs(save_models_folder, exist_ok=True)


# data loader: 'labels' holds the continuous angles, 'types' the class labels
data_filename = args.data_path + '/Ra_' + str(args.img_size) + 'x' + str(args.img_size) + '.h5'
# the context manager closes the HDF5 file even if reading a dataset fails
with h5py.File(data_filename, 'r') as hf:
    angles = hf['labels'][:].astype(float)
    labels = hf['types'][:]
    images = hf['images'][:]
num_classes = len(set(labels))
# class labels are expected to be 0..num_classes-1
assert max(labels) == num_classes - 1

# keep only samples whose angle lies strictly inside (min_label, max_label)
q1 = args.min_label
q2 = args.max_label
indx = np.where((angles > q1) & (angles < q2))[0]
angles = angles[indx]
labels = labels[indx]
images = images[indx]
assert len(labels) == len(images)
assert len(angles) == len(images)


# define training (and validation) set
if args.CVMode:
    # stratified split: every class contributes `valid_proport` of its samples to validation.
    # Collect per-class index arrays and concatenate once (same order and result as
    # growing the arrays class by class, without the repeated copies).
    valid_parts = []
    train_parts = []
    for i in range(num_classes):
        indx_i = np.where(labels == i)[0]  # i-th class
        np.random.shuffle(indx_i)
        num_imgs_valid_i = int(len(indx_i) * args.valid_proport)
        valid_parts.append(indx_i[:num_imgs_valid_i])
        train_parts.append(indx_i[num_imgs_valid_i:])
    indx_valid = np.concatenate(valid_parts)
    indx_train = np.concatenate(train_parts)

    trainset = IMGs_dataset(images[indx_train], labels[indx_train], normalize=True)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size_train, shuffle=True, num_workers=args.num_workers)
    validset = IMGs_dataset(images[indx_valid], labels[indx_valid], normalize=True)
    validloader = torch.utils.data.DataLoader(validset, batch_size=args.batch_size_valid, shuffle=False, num_workers=args.num_workers)
else:
    trainset = IMGs_dataset(images, labels, normalize=True)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size_train, shuffle=True, num_workers=args.num_workers)
|
125
|
-
|
126
|
-
|
127
|
-
|
128
|
-
###########################################################################################################
|
129
|
-
# Necessary functions
|
130
|
-
###########################################################################################################
|
131
|
-
|
132
|
-
#initialize CNNs
|
133
|
-
def net_initialization(Pretrained_CNN_Name, ngpu=1, num_classes=num_classes):
    """Build the classification CNN named `Pretrained_CNN_Name` and move it to `device`.

    Args:
        Pretrained_CNN_Name: one of "ResNet18_class", "ResNet34_class",
            "ResNet50_class" or "ResNet101_class".
        ngpu: number of GPUs, forwarded to the model constructor.
        num_classes: number of output classes; defaults to the module-level
            `num_classes` as it was when this function was defined.

    Returns:
        (net, net_name): the model on `device` and the name
        'PreCNNForEvalGANs_' + Pretrained_CNN_Name used for checkpoint files.

    Raises:
        ValueError: if `Pretrained_CNN_Name` is not one of the supported names.
    """
    if Pretrained_CNN_Name == "ResNet18_class":
        net = ResNet18_class_eval(num_classes=num_classes, ngpu=ngpu)
    elif Pretrained_CNN_Name == "ResNet34_class":
        net = ResNet34_class_eval(num_classes=num_classes, ngpu=ngpu)
    elif Pretrained_CNN_Name == "ResNet50_class":
        net = ResNet50_class_eval(num_classes=num_classes, ngpu=ngpu)
    elif Pretrained_CNN_Name == "ResNet101_class":
        net = ResNet101_class_eval(num_classes=num_classes, ngpu=ngpu)
    else:
        # fail fast with a clear message instead of an UnboundLocalError on `net`
        raise ValueError(
            "Unsupported CNN '{}'; expected one of ResNet18_class, ResNet34_class, "
            "ResNet50_class, ResNet101_class".format(Pretrained_CNN_Name)
        )

    net_name = 'PreCNNForEvalGANs_' + Pretrained_CNN_Name  # get the net's name
    net = net.to(device)

    return net, net_name
|
147
|
-
|
148
|
-
#adjust CNN learning rate
|
149
|
-
def adjust_learning_rate(optimizer, epoch, BASE_LR_CNN, milestones=(50, 120), lr_divisor=10):
    """Set a step-decayed learning rate on every parameter group of `optimizer`.

    The rate starts at `BASE_LR_CNN` and is divided by `lr_divisor` once for each
    milestone epoch that `epoch` has reached. With the defaults this reproduces
    the original schedule: /10 from epoch 50 and /10 again from epoch 120.

    Args:
        optimizer: object exposing `param_groups` (e.g. a torch optimizer).
        epoch: current (0-based) epoch index.
        BASE_LR_CNN: initial learning rate.
        milestones: epochs at which the learning rate is divided again.
        lr_divisor: divisor applied at every reached milestone.
    """
    lr = BASE_LR_CNN
    for milestone in milestones:
        if epoch >= milestone:
            # division (not multiplication by 0.1) keeps the exact original float values
            lr /= lr_divisor
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
|
161
|
-
|
162
|
-
|
163
|
-
def train_CNN():
    """Train the module-level classifier `net` for `args.epochs` epochs.

    Every epoch applies the step learning-rate schedule, minimises `criterion`
    with `optimizer` over `trainloader`, and prints the mean training loss
    (plus validation accuracy in CV mode) together with the elapsed time.

    Returns:
        (net, optimizer): the trained network and its optimizer.
    """
    t_begin = timeit.default_timer()
    for epoch in range(args.epochs):
        net.train()
        epoch_loss = 0
        adjust_learning_rate(optimizer, epoch, args.base_lr)

        for imgs, targets in trainloader:
            imgs = imgs.type(torch.float).cuda()
            targets = targets.type(torch.long).cuda()

            # forward pass; the second network output is not needed here
            logits, _ = net(imgs)
            loss = criterion(logits, targets)

            # backward pass and parameter update
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            epoch_loss += loss.cpu().item()

        epoch_loss = epoch_loss / len(trainloader)

        if not args.CVMode:
            print('CNN: [epoch %d/%d] train_loss:%f Time: %.4f' % (epoch+1, args.epochs, epoch_loss, timeit.default_timer()-t_begin))
        else:
            valid_acc = valid_CNN(verbose=False)
            print('CNN: [epoch %d/%d] train_loss:%f valid_acc:%f Time: %.4f' % (epoch+1, args.epochs, epoch_loss, valid_acc, timeit.default_timer()-t_begin))

    return net, optimizer
|
198
|
-
|
199
|
-
|
200
|
-
if args.CVMode:
    def valid_CNN(verbose=True):
        """Return the classification accuracy (in percent) of `net` on `validloader`.

        Args:
            verbose: if True, also print the accuracy.

        Returns:
            Accuracy in percent, or NaN when the validation set is empty (e.g.
            when `--valid_proport` rounds every class's share down to zero),
            instead of raising ZeroDivisionError.
        """
        net.eval()  # eval mode (batchnorm uses moving mean/variance instead of mini-batch mean/variance)
        correct = 0
        total = 0
        with torch.no_grad():
            for batch_images, batch_labels in validloader:
                batch_images = batch_images.type(torch.float).cuda()
                batch_labels = batch_labels.type(torch.long).cuda()
                outputs, _ = net(batch_images)
                _, predicted = torch.max(outputs.data, 1)
                total += batch_labels.size(0)
                correct += (predicted == batch_labels).sum().item()
        acc = 100.0 * correct / total if total > 0 else float('nan')
        if verbose:
            print('Valid Accuracy of the model on the validation set: {} %'.format(acc))
        return acc
|
216
|
-
|
217
|
-
|
218
|
-
|
219
|
-
###########################################################################################################
|
220
|
-
# Training and Testing
|
221
|
-
###########################################################################################################
|
222
|
-
# model initialization
|
223
|
-
# Model, loss and optimizer.
net, net_name = net_initialization(args.CNN, ngpu=ngpu, num_classes=num_classes)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(net.parameters(), lr=args.base_lr, momentum=0.9, weight_decay=args.weight_dacay)

filename_ckpt = save_models_folder + '/ckpt_{}_epoch_{}_seed_{}_classify_{}_chair_types_CVMode_{}.pth'.format(net_name, args.epochs, args.seed, num_classes, args.CVMode)

# Reuse a finished checkpoint when present; otherwise train and save one.
if os.path.isfile(filename_ckpt):
    print("\n Ckpt already exists")
    print("\n Loading...")
    checkpoint = torch.load(filename_ckpt)
    net.load_state_dict(checkpoint['net_state_dict'])
else:
    print("\n Begin training CNN: ")
    start = timeit.default_timer()
    net, optimizer = train_CNN()
    stop = timeit.default_timer()
    print("Time elapses: {}s".format(stop - start))
    torch.save({'net_state_dict': net.state_dict()}, filename_ckpt)
torch.cuda.empty_cache()  # release cached GPU memory that is no longer referenced

# In CV mode, report accuracy on the held-out split.
if args.CVMode:
    _ = valid_CNN(True)
    torch.cuda.empty_cache()
|