Myosotis-Researches 0.0.18__py3-none-any.whl → 0.0.19__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as published to a supported public registry. It is provided for informational purposes only.
- myosotis_researches/CcGAN/train_128_output_10/DiffAugment_pytorch.py +76 -0
- myosotis_researches/CcGAN/train_128_output_10/__init__.py +0 -0
- myosotis_researches/CcGAN/train_128_output_10/eval_metrics.py +205 -0
- myosotis_researches/CcGAN/train_128_output_10/opts.py +87 -0
- myosotis_researches/CcGAN/train_128_output_10/pretrain_AE.py +268 -0
- myosotis_researches/CcGAN/train_128_output_10/pretrain_CNN_class.py +251 -0
- myosotis_researches/CcGAN/train_128_output_10/pretrain_CNN_regre.py +255 -0
- myosotis_researches/CcGAN/train_128_output_10/train_ccgan.py +302 -0
- myosotis_researches/CcGAN/train_128_output_10/train_cgan.py +254 -0
- myosotis_researches/CcGAN/train_128_output_10/train_cgan_concat.py +242 -0
- myosotis_researches/CcGAN/train_128_output_10/train_net_for_label_embed.py +181 -0
- myosotis_researches/CcGAN/train_128_output_10/utils.py +120 -0
- {myosotis_researches-0.0.18.dist-info → myosotis_researches-0.0.19.dist-info}/METADATA +1 -1
- {myosotis_researches-0.0.18.dist-info → myosotis_researches-0.0.19.dist-info}/RECORD +17 -5
- {myosotis_researches-0.0.18.dist-info → myosotis_researches-0.0.19.dist-info}/WHEEL +0 -0
- {myosotis_researches-0.0.18.dist-info → myosotis_researches-0.0.19.dist-info}/licenses/LICENSE +0 -0
- {myosotis_researches-0.0.18.dist-info → myosotis_researches-0.0.19.dist-info}/top_level.txt +0 -0
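
Both new pretraining scripts shown below (`pretrain_CNN_class.py` and `pretrain_CNN_regre.py`) build their data loaders from the sibling `utils.IMGs_dataset` added in this release; that module is not reproduced in the hunks shown here. As orientation only, a minimal dataset wrapper with the same call signature might look like the sketch below — the actual implementation in `train_128_output_10/utils.py` may differ, and the `(N, C, H, W)` uint8 layout and the `[-1, 1]` normalization range are assumptions.

```python
import numpy as np
import torch
from torch.utils.data import Dataset


class IMGsDatasetSketch(Dataset):
    """Hypothetical stand-in for utils.IMGs_dataset (same call signature)."""

    def __init__(self, images, labels, normalize=True):
        self.images = images        # assumed (N, C, H, W) uint8 array
        self.labels = labels        # class indices or float regression targets
        self.normalize = normalize

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img = self.images[idx].astype(np.float32)
        if self.normalize:
            img = img / 255.0 * 2.0 - 1.0   # rescale uint8 pixels to [-1, 1] (assumption)
        return torch.from_numpy(img), self.labels[idx]
```

Wrapped this way, a `torch.utils.data.DataLoader` over the dataset yields the `(batch_train_images, batch_train_labels)` pairs consumed by the training loops in the hunks below.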
myosotis_researches/CcGAN/train_128_output_10/pretrain_CNN_class.py
@@ -0,0 +1,251 @@
+"""
+
+Pre-train a CNN on the whole dataset for evaluation purpose
+
+"""
+import os
+import argparse
+import shutil
+import timeit
+
+import torch
+import torchvision
+import torchvision.transforms as transforms
+import numpy as np
+import torch.nn as nn
+import torch.backends.cudnn as cudnn
+import random
+import matplotlib.pyplot as plt
+import matplotlib as mpl
+from torch import autograd
+from torchvision.utils import save_image
+import csv
+from tqdm import tqdm
+import gc
+import h5py
+
+from models import *
+from utils import IMGs_dataset
+
+
+#############################
+# Settings
+#############################
+
+parser = argparse.ArgumentParser(description='Pre-train CNNs')
+parser.add_argument('--root_path', type=str, default='')
+parser.add_argument('--data_path', type=str, default='')
+parser.add_argument('--num_workers', type=int, default=0)
+parser.add_argument('--CNN', type=str, default='ResNet34_class',
+                    help='CNN for training; ResNetXX')
+parser.add_argument('--epochs', type=int, default=200, metavar='N',
+                    help='number of epochs to train CNNs (default: 200)')
+parser.add_argument('--batch_size_train', type=int, default=128, metavar='N',
+                    help='input batch size for training')
+parser.add_argument('--batch_size_valid', type=int, default=10, metavar='N',
+                    help='input batch size for testing')
+parser.add_argument('--base_lr', type=float, default=0.01,
+                    help='learning rate, default=0.1')
+parser.add_argument('--weight_dacay', type=float, default=1e-4,
+                    help='Weigth decay, default=1e-4')
+parser.add_argument('--seed', type=int, default=2020, metavar='S',
+                    help='random seed (default: 1)')
+parser.add_argument('--CVMode', action='store_true', default=False,
+                    help='CV mode?')
+parser.add_argument('--valid_proport', type=float, default=0.1,
+                    help='Proportion of validation samples')
+parser.add_argument('--img_size', type=int, default=128, metavar='N')
+parser.add_argument('--min_label', type=float, default=0.0)
+parser.add_argument('--max_label', type=float, default=90.0)
+args = parser.parse_args()
+
+
+wd = args.root_path
+os.chdir(wd)
+from ..models_128 import *
+from .utils import IMGs_dataset
+
+# cuda
+device = torch.device("cuda")
+ngpu = torch.cuda.device_count() # number of gpus
+
+# random seed
+random.seed(args.seed)
+torch.manual_seed(args.seed)
+torch.backends.cudnn.deterministic = True
+np.random.seed(args.seed)
+
+# directories for checkpoint, images and log files
+save_models_folder = wd + '/output/eval_models/'
+os.makedirs(save_models_folder, exist_ok=True)
+
+
+# data loader
+data_filename = args.data_path + '/Ra_' + str(args.img_size) + 'x' + str(args.img_size) + '.h5'
+hf = h5py.File(data_filename, 'r')
+angles = hf['labels'][:]
+angles = angles.astype(float)
+labels = hf['types'][:]
+images = hf['images'][:]
+hf.close()
+num_classes = len(set(labels))
+assert max(labels)==num_classes-1
+
+q1 = args.min_label
+q2 = args.max_label
+indx = np.where((angles>q1)*(angles<q2)==True)[0]
+angles = angles[indx]
+labels = labels[indx]
+images = images[indx]
+assert len(labels)==len(images)
+assert len(angles)==len(images)
+
+
+# define training (and validaiton) set
+if args.CVMode:
+    for i in range(num_classes):
+        indx_i = np.where(labels==i)[0] #i-th class
+        np.random.shuffle(indx_i)
+        num_imgs_all_i = len(indx_i)
+        num_imgs_valid_i = int(num_imgs_all_i*args.valid_proport)
+        if i == 0:
+            indx_valid = indx_i[0:num_imgs_valid_i]
+            indx_train = indx_i[num_imgs_valid_i:]
+        else:
+            indx_valid = np.concatenate((indx_valid, indx_i[0:num_imgs_valid_i]))
+            indx_train = np.concatenate((indx_train, indx_i[num_imgs_valid_i:]))
+    #end for i
+    trainset = IMGs_dataset(images[indx_train], labels[indx_train], normalize=True)
+    trainloader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size_train, shuffle=True, num_workers=args.num_workers)
+    validset = IMGs_dataset(images[indx_valid], labels[indx_valid], normalize=True)
+    validloader = torch.utils.data.DataLoader(validset, batch_size=args.batch_size_valid, shuffle=False, num_workers=args.num_workers)
+else:
+    trainset = IMGs_dataset(images, labels, normalize=True)
+    trainloader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size_train, shuffle=True, num_workers=args.num_workers)
+
+
+
+###########################################################################################################
+# Necessary functions
+###########################################################################################################
+
+#initialize CNNs
+def net_initialization(Pretrained_CNN_Name, ngpu = 1, num_classes=num_classes):
+    if Pretrained_CNN_Name == "ResNet18_class":
+        net = ResNet18_class_eval(num_classes=num_classes, ngpu = ngpu)
+    elif Pretrained_CNN_Name == "ResNet34_class":
+        net = ResNet34_class_eval(num_classes=num_classes, ngpu = ngpu)
+    elif Pretrained_CNN_Name == "ResNet50_class":
+        net = ResNet50_class_eval(num_classes=num_classes, ngpu = ngpu)
+    elif Pretrained_CNN_Name == "ResNet101_class":
+        net = ResNet101_class_eval(num_classes=num_classes, ngpu = ngpu)
+
+    net_name = 'PreCNNForEvalGANs_' + Pretrained_CNN_Name #get the net's name
+    net = net.to(device)
+
+    return net, net_name
+
+#adjust CNN learning rate
+def adjust_learning_rate(optimizer, epoch, BASE_LR_CNN):
+    lr = BASE_LR_CNN
+    # if epoch >= 35:
+    #     lr /= 10
+    # if epoch >= 70:
+    #     lr /= 10
+    if epoch >= 50:
+        lr /= 10
+    if epoch >= 120:
+        lr /= 10
+    for param_group in optimizer.param_groups:
+        param_group['lr'] = lr
+
+
+def train_CNN():
+
+    start_tmp = timeit.default_timer()
+    for epoch in range(args.epochs):
+        net.train()
+        train_loss = 0
+        adjust_learning_rate(optimizer, epoch, args.base_lr)
+        for batch_idx, (batch_train_images, batch_train_labels) in enumerate(trainloader):
+
+            # batch_train_images = nn.functional.interpolate(batch_train_images, size = (299,299), scale_factor=None, mode='bilinear', align_corners=False)
+
+            batch_train_images = batch_train_images.type(torch.float).cuda()
+            batch_train_labels = batch_train_labels.type(torch.long).cuda()
+
+            #Forward pass
+            outputs,_ = net(batch_train_images)
+            loss = criterion(outputs, batch_train_labels)
+
+            #backward pass
+            optimizer.zero_grad()
+            loss.backward()
+            optimizer.step()
+
+            train_loss += loss.cpu().item()
+        #end for batch_idx
+        train_loss = train_loss / len(trainloader)
+
+        if args.CVMode:
+            valid_acc = valid_CNN(verbose=False)
+            print('CNN: [epoch %d/%d] train_loss:%f valid_acc:%f Time: %.4f' % (epoch+1, args.epochs, train_loss, valid_acc, timeit.default_timer()-start_tmp))
+        else:
+            print('CNN: [epoch %d/%d] train_loss:%f Time: %.4f' % (epoch+1, args.epochs, train_loss, timeit.default_timer()-start_tmp))
+    #end for epoch
+
+    return net, optimizer
+
+
+if args.CVMode:
+    def valid_CNN(verbose=True):
+        net.eval() # eval mode (batchnorm uses moving mean/variance instead of mini-batch mean/variance)
+        with torch.no_grad():
+            correct = 0
+            total = 0
+            for batch_idx, (images, labels) in enumerate(validloader):
+                images = images.type(torch.float).cuda()
+                labels = labels.type(torch.long).cuda()
+                outputs,_ = net(images)
+                _, predicted = torch.max(outputs.data, 1)
+                total += labels.size(0)
+                correct += (predicted == labels).sum().item()
+            if verbose:
+                print('Valid Accuracy of the model on the validation set: {} %'.format(100.0 * correct / total))
+            return 100.0 * correct / total
+
+
+
+###########################################################################################################
+# Training and Testing
+###########################################################################################################
+# model initialization
+net, net_name = net_initialization(args.CNN, ngpu = ngpu, num_classes = num_classes)
+criterion = nn.CrossEntropyLoss()
+optimizer = torch.optim.SGD(net.parameters(), lr = args.base_lr, momentum= 0.9, weight_decay=args.weight_dacay)
+
+filename_ckpt = save_models_folder + '/ckpt_{}_epoch_{}_seed_{}_classify_{}_chair_types_CVMode_{}.pth'.format(net_name, args.epochs, args.seed, num_classes, args.CVMode)
+
+# training
+if not os.path.isfile(filename_ckpt):
+    # TRAIN CNN
+    print("\n Begin training CNN: ")
+    start = timeit.default_timer()
+    net, optimizer = train_CNN()
+    stop = timeit.default_timer()
+    print("Time elapses: {}s".format(stop - start))
+    # save model
+    torch.save({
+        'net_state_dict': net.state_dict(),
+    }, filename_ckpt)
+else:
+    print("\n Ckpt already exists")
+    print("\n Loading...")
+    checkpoint = torch.load(filename_ckpt)
+    net.load_state_dict(checkpoint['net_state_dict'])
+torch.cuda.empty_cache()#release GPU mem which is not references
+
+if args.CVMode:
+    #validation
+    _ = valid_CNN(True)
+    torch.cuda.empty_cache()
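
The script above persists only the classifier weights, under a single `'net_state_dict'` key, and skips retraining whenever the checkpoint file already exists. A self-contained sketch of that save/load round trip, with a toy module standing in for the package's `ResNet*_class_eval` models (those live in `models_128` and are not part of this diff):

```python
import os
import torch
import torch.nn as nn

# Toy stand-in; only the checkpoint layout mirrors pretrain_CNN_class.py.
net = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1),
                    nn.AdaptiveAvgPool2d(1),
                    nn.Flatten(),
                    nn.Linear(8, 10))

filename_ckpt = "ckpt_demo.pth"
if not os.path.isfile(filename_ckpt):
    # train, then save weights only
    torch.save({'net_state_dict': net.state_dict()}, filename_ckpt)
else:
    # reload the weights for later evaluation use
    checkpoint = torch.load(filename_ckpt, map_location="cpu")
    net.load_state_dict(checkpoint['net_state_dict'])
net.eval()
```
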
myosotis_researches/CcGAN/train_128_output_10/pretrain_CNN_regre.py
@@ -0,0 +1,255 @@
+"""
+
+Pre-train a CNN on the whole dataset for evaluation purpose
+
+"""
+import os
+import argparse
+import shutil
+import timeit
+import torch
+import torchvision
+import torchvision.transforms as transforms
+import numpy as np
+import torch.nn as nn
+import torch.backends.cudnn as cudnn
+import random
+import matplotlib.pyplot as plt
+import matplotlib as mpl
+from torch import autograd
+from torchvision.utils import save_image
+import csv
+from tqdm import tqdm
+import gc
+import h5py
+
+
+#############################
+# Settings
+#############################
+
+parser = argparse.ArgumentParser(description='Pre-train CNNs')
+parser.add_argument('--root_path', type=str, default='')
+parser.add_argument('--data_path', type=str, default='')
+parser.add_argument('--num_workers', type=int, default=0)
+parser.add_argument('--CNN', type=str, default='ResNet34_regre',
+                    help='CNN for training; ResNetXX')
+parser.add_argument('--epochs', type=int, default=200, metavar='N',
+                    help='number of epochs to train CNNs (default: 200)')
+parser.add_argument('--batch_size_train', type=int, default=256, metavar='N',
+                    help='input batch size for training')
+parser.add_argument('--batch_size_valid', type=int, default=10, metavar='N',
+                    help='input batch size for testing')
+parser.add_argument('--base_lr', type=float, default=0.01,
+                    help='learning rate, default=0.1')
+parser.add_argument('--weight_dacay', type=float, default=1e-4,
+                    help='Weigth decay, default=1e-4')
+parser.add_argument('--seed', type=int, default=2020, metavar='S',
+                    help='random seed (default: 1)')
+parser.add_argument('--CVMode', action='store_true', default=False,
+                    help='CV mode?')
+parser.add_argument('--img_size', type=int, default=128, metavar='N')
+parser.add_argument('--min_label', type=float, default=0.0)
+parser.add_argument('--max_label', type=float, default=90.0)
+args = parser.parse_args()
+
+
+wd = args.root_path
+os.chdir(wd)
+from ..models_128 import *
+from .utils import IMGs_dataset
+
+# cuda
+device = torch.device("cuda")
+ngpu = torch.cuda.device_count() # number of gpus
+
+# random seed
+random.seed(args.seed)
+torch.manual_seed(args.seed)
+torch.backends.cudnn.deterministic = True
+np.random.seed(args.seed)
+
+# directories for checkpoint, images and log files
+save_models_folder = wd + '/output/eval_models/'
+os.makedirs(save_models_folder, exist_ok=True)
+
+
+###########################################################################################################
+# Data
+###########################################################################################################
+# data loader
+data_filename = args.data_path + '/Ra_' + str(args.img_size) + 'x' + str(args.img_size) + '.h5'
+hf = h5py.File(data_filename, 'r')
+labels = hf['labels'][:]
+labels = labels.astype(float)
+images = hf['images'][:]
+hf.close()
+N_all = len(images)
+assert len(images) == len(labels)
+
+q1 = args.min_label
+q2 = args.max_label
+indx = np.where((labels>q1)*(labels<q2)==True)[0]
+labels = labels[indx]
+images = images[indx]
+assert len(labels)==len(images)
+
+
+# normalize to [0, 1]
+#min_label = np.min(labels)
+#labels += np.abs(min_label)
+#max_label = np.max(labels)
+#labels /= max_label
+labels /= args.max_label
+
+
+# define training and validation sets
+if args.CVMode:
+    #80% Training; 20% valdation
+    indx_all = np.arange(len(labels))
+    np.random.shuffle(indx_all)
+    indx_valid = indx_all[0:int(0.2*len(labels))]
+    indx_train = indx_all[int(0.2*len(labels)):]
+
+    trainset = IMGs_dataset(images[indx_train], labels[indx_train], normalize=True)
+    trainloader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size_train, shuffle=True, num_workers=args.num_workers)
+    validset = IMGs_dataset(images[indx_valid], labels[indx_valid], normalize=True)
+    validloader = torch.utils.data.DataLoader(validset, batch_size=args.batch_size_valid, shuffle=False, num_workers=args.num_workers)
+
+else:
+    trainset = IMGs_dataset(images, labels, normalize=True)
+    trainloader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size_train, shuffle=True, num_workers=args.num_workers)
+
+
+
+
+###########################################################################################################
+# Necessary functions
+###########################################################################################################
+
+#initialize CNNs
+def net_initialization(Pretrained_CNN_Name, ngpu = 1):
+    if Pretrained_CNN_Name == "ResNet18_regre":
+        net = ResNet18_regre_eval(ngpu = ngpu)
+    elif Pretrained_CNN_Name == "ResNet34_regre":
+        net = ResNet34_regre_eval(ngpu = ngpu)
+    elif Pretrained_CNN_Name == "ResNet50_regre":
+        net = ResNet50_regre_eval(ngpu = ngpu)
+    elif Pretrained_CNN_Name == "ResNet101_regre":
+        net = ResNet101_regre_eval(ngpu = ngpu)
+
+    net_name = 'PreCNNForEvalGANs_' + Pretrained_CNN_Name #get the net's name
+    net = net.to(device)
+
+    return net, net_name
+
+#adjust CNN learning rate
+def adjust_learning_rate(optimizer, epoch, BASE_LR_CNN):
+    lr = BASE_LR_CNN
+    # if epoch >= 35:
+    #     lr /= 10
+    # if epoch >= 70:
+    #     lr /= 10
+    if epoch >= 50:
+        lr /= 10
+    if epoch >= 120:
+        lr /= 10
+    for param_group in optimizer.param_groups:
+        param_group['lr'] = lr
+
+
+def train_CNN():
+
+    start_tmp = timeit.default_timer()
+    for epoch in range(args.epochs):
+        net.train()
+        train_loss = 0
+        adjust_learning_rate(optimizer, epoch, args.base_lr)
+        for batch_idx, (batch_train_images, batch_train_labels) in enumerate(trainloader):
+
+            # batch_train_images = nn.functional.interpolate(batch_train_images, size = (299,299), scale_factor=None, mode='bilinear', align_corners=False)
+
+            batch_train_images = batch_train_images.type(torch.float).cuda()
+            batch_train_labels = batch_train_labels.type(torch.float).view(-1,1).cuda()
+
+            #Forward pass
+            outputs,_ = net(batch_train_images)
+            loss = criterion(outputs, batch_train_labels)
+
+            #backward pass
+            optimizer.zero_grad()
+            loss.backward()
+            optimizer.step()
+
+            train_loss += loss.cpu().item()
+        #end for batch_idx
+        train_loss = train_loss / len(trainloader)
+
+        if args.CVMode:
+            valid_loss = valid_CNN(verbose=False)
+            print('CNN: [epoch %d/%d] train_loss:%f valid_loss (avg_abs):%f Time: %.4f' % (epoch+1, args.epochs, train_loss, valid_loss, timeit.default_timer()-start_tmp))
+        else:
+            print('CNN: [epoch %d/%d] train_loss:%f Time:%.4f' % (epoch+1, args.epochs, train_loss, timeit.default_timer()-start_tmp))
+    #end for epoch
+
+    return net, optimizer
+
+if args.CVMode:
+    def valid_CNN(verbose=True):
+        net.eval() # eval mode (batchnorm uses moving mean/variance instead of mini-batch mean/variance)
+        with torch.no_grad():
+            abs_diff_avg = 0
+            total = 0
+            for batch_idx, (images, labels) in enumerate(validloader):
+                images = images.type(torch.float).cuda()
+                labels = labels.type(torch.float).view(-1).cpu().numpy()
+                outputs,_ = net(images)
+                outputs = outputs.view(-1).cpu().numpy()
+                labels = labels * args.max_label
+                outputs = outputs * args.max_label
+                abs_diff_avg += np.sum(np.abs(labels-outputs))
+                total += len(labels)
+
+            if verbose:
+                print('Validation Average Absolute Difference: {}'.format(abs_diff_avg/total))
+            return abs_diff_avg/total
+
+
+
+###########################################################################################################
+# Training and validation
+###########################################################################################################
+
+
+# model initialization
+net, net_name = net_initialization(args.CNN, ngpu = ngpu)
+criterion = nn.MSELoss()
+optimizer = torch.optim.SGD(net.parameters(), lr = args.base_lr, momentum= 0.9, weight_decay=args.weight_dacay)
+
+filename_ckpt = save_models_folder + '/ckpt_{}_epoch_{}_seed_{}_CVMode_{}.pth'.format(net_name, args.epochs, args.seed, args.CVMode)
+
+
+# training
+if not os.path.isfile(filename_ckpt):
+    # TRAIN CNN
+    print("\n Begin training CNN: ")
+    start = timeit.default_timer()
+    net, optimizer = train_CNN()
+    stop = timeit.default_timer()
+    print("Time elapses: {}s".format(stop - start))
+    # save model
+    torch.save({
+        'net_state_dict': net.state_dict(),
+    }, filename_ckpt)
+else:
+    print("\n Ckpt already exists")
+    print("\n Loading...")
+    checkpoint = torch.load(filename_ckpt)
+    net.load_state_dict(checkpoint['net_state_dict'])
+torch.cuda.empty_cache()#release GPU mem which is not references
+
+
+if args.CVMode:
+    #validation
+    _ = valid_CNN(True)
+    torch.cuda.empty_cache()
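
`pretrain_CNN_regre.py` trains on labels divided by `--max_label` and, in CV mode, multiplies both predictions and labels back by the same factor before reporting the average absolute difference. A small numeric sketch of that scaling; the label and prediction values below are made up:

```python
import numpy as np

max_label = 90.0                            # --max_label default in the script above
labels = np.array([10.0, 45.0, 80.0])       # hypothetical raw labels
targets = labels / max_label                # what the regression CNN is trained on

preds = np.array([0.12, 0.49, 0.88])        # hypothetical network outputs
abs_diff_avg = np.sum(np.abs(preds * max_label - targets * max_label))
print('Validation Average Absolute Difference: {}'.format(abs_diff_avg / len(labels)))
```

The same `max_label` value also bounds which samples are kept, since rows with labels outside the open interval (`min_label`, `max_label`) are filtered out before training.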