Myosotis-Researches 0.1.8__py3-none-any.whl → 0.1.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- myosotis_researches/CcGAN/train/__init__.py +4 -0
- myosotis_researches/CcGAN/{train_128_output_10 → train}/train_ccgan.py +4 -4
- myosotis_researches/CcGAN/{train_128_output_10 → train}/train_cgan.py +1 -3
- myosotis_researches/CcGAN/{train_128_output_10 → train}/train_cgan_concat.py +1 -3
- {myosotis_researches-0.1.8.dist-info → myosotis_researches-0.1.10.dist-info}/METADATA +1 -1
- myosotis_researches-0.1.10.dist-info/RECORD +40 -0
- myosotis_researches/CcGAN/train_128/DiffAugment_pytorch.py +0 -76
- myosotis_researches/CcGAN/train_128/__init__.py +0 -0
- myosotis_researches/CcGAN/train_128/eval_metrics.py +0 -205
- myosotis_researches/CcGAN/train_128/opts.py +0 -87
- myosotis_researches/CcGAN/train_128/pretrain_AE.py +0 -268
- myosotis_researches/CcGAN/train_128/pretrain_CNN_class.py +0 -251
- myosotis_researches/CcGAN/train_128/pretrain_CNN_regre.py +0 -255
- myosotis_researches/CcGAN/train_128/train_ccgan.py +0 -303
- myosotis_researches/CcGAN/train_128/train_cgan.py +0 -254
- myosotis_researches/CcGAN/train_128/train_cgan_concat.py +0 -242
- myosotis_researches/CcGAN/train_128/utils.py +0 -120
- myosotis_researches/CcGAN/train_128_output_10/DiffAugment_pytorch.py +0 -76
- myosotis_researches/CcGAN/train_128_output_10/__init__.py +0 -0
- myosotis_researches/CcGAN/train_128_output_10/eval_metrics.py +0 -205
- myosotis_researches/CcGAN/train_128_output_10/opts.py +0 -87
- myosotis_researches/CcGAN/train_128_output_10/pretrain_AE.py +0 -268
- myosotis_researches/CcGAN/train_128_output_10/pretrain_CNN_class.py +0 -251
- myosotis_researches/CcGAN/train_128_output_10/pretrain_CNN_regre.py +0 -255
- myosotis_researches/CcGAN/train_128_output_10/train_net_for_label_embed.py +0 -181
- myosotis_researches/CcGAN/train_128_output_10/utils.py +0 -120
- myosotis_researches-0.1.8.dist-info/RECORD +0 -59
- /myosotis_researches/CcGAN/{train_128 → train}/train_net_for_label_embed.py +0 -0
- {myosotis_researches-0.1.8.dist-info → myosotis_researches-0.1.10.dist-info}/WHEEL +0 -0
- {myosotis_researches-0.1.8.dist-info → myosotis_researches-0.1.10.dist-info}/licenses/LICENSE +0 -0
- {myosotis_researches-0.1.8.dist-info → myosotis_researches-0.1.10.dist-info}/top_level.txt +0 -0
@@ -1,251 +0,0 @@
|
|
1
|
-
"""
|
2
|
-
|
3
|
-
Pre-train a CNN on the whole dataset for evaluation purpose
|
4
|
-
|
5
|
-
"""
|
6
|
-
import os
|
7
|
-
import argparse
|
8
|
-
import shutil
|
9
|
-
import timeit
|
10
|
-
|
11
|
-
import torch
|
12
|
-
import torchvision
|
13
|
-
import torchvision.transforms as transforms
|
14
|
-
import numpy as np
|
15
|
-
import torch.nn as nn
|
16
|
-
import torch.backends.cudnn as cudnn
|
17
|
-
import random
|
18
|
-
import matplotlib.pyplot as plt
|
19
|
-
import matplotlib as mpl
|
20
|
-
from torch import autograd
|
21
|
-
from torchvision.utils import save_image
|
22
|
-
import csv
|
23
|
-
from tqdm import tqdm
|
24
|
-
import gc
|
25
|
-
import h5py
|
26
|
-
|
27
|
-
from models import *
|
28
|
-
from utils import IMGs_dataset
|
29
|
-
|
30
|
-
|
31
|
-
#############################
|
32
|
-
# Settings
|
33
|
-
#############################
|
34
|
-
|
35
|
-
# Command-line options of the classification pre-training script.
# Help strings below state the defaults that are actually configured.
parser = argparse.ArgumentParser(description='Pre-train CNNs')
parser.add_argument('--root_path', type=str, default='')
parser.add_argument('--data_path', type=str, default='')
parser.add_argument('--num_workers', type=int, default=0)
parser.add_argument('--CNN', type=str, default='ResNet34_class',
                    help='CNN for training; one of ResNet18_class, ResNet34_class, ResNet50_class, ResNet101_class')
parser.add_argument('--epochs', type=int, default=200, metavar='N',
                    help='number of epochs to train CNNs (default: 200)')
parser.add_argument('--batch_size_train', type=int, default=128, metavar='N',
                    help='input batch size for training')
parser.add_argument('--batch_size_valid', type=int, default=10, metavar='N',
                    help='input batch size for testing')
parser.add_argument('--base_lr', type=float, default=0.01,
                    help='learning rate, default=0.01')
# NOTE: the flag name is misspelled ("dacay") but is kept unchanged because
# existing command lines and `args.weight_dacay` depend on it.
parser.add_argument('--weight_dacay', type=float, default=1e-4,
                    help='Weight decay, default=1e-4')
parser.add_argument('--seed', type=int, default=2020, metavar='S',
                    help='random seed (default: 2020)')
parser.add_argument('--CVMode', action='store_true', default=False,
                    help='CV mode?')
parser.add_argument('--valid_proport', type=float, default=0.1,
                    help='Proportion of validation samples')
parser.add_argument('--img_size', type=int, default=128, metavar='N')
parser.add_argument('--min_label', type=float, default=0.0)
parser.add_argument('--max_label', type=float, default=90.0)
args = parser.parse_args()
|
61
|
-
|
62
|
-
|
63
|
-
# Switch the process working directory to the user-supplied project root.
# NOTE(review): --root_path defaults to '' — confirm os.chdir accepts an empty
# path on the target platform, otherwise the flag is effectively required.
wd = args.root_path
os.chdir(wd)
from ..models_128 import *
from .utils import IMGs_dataset

# cuda: the script always targets a CUDA device (there is no CPU fallback here)
device = torch.device("cuda")
ngpu = torch.cuda.device_count() # number of gpus

# random seed: seed Python, PyTorch and NumPy generators from the same value
random.seed(args.seed)
torch.manual_seed(args.seed)
torch.backends.cudnn.deterministic = True
np.random.seed(args.seed)

# directory for checkpoints of the evaluation models (created if missing)
save_models_folder = wd + '/output/eval_models/'
os.makedirs(save_models_folder, exist_ok=True)
|
81
|
-
|
82
|
-
|
83
|
-
# data loader
|
84
|
-
# Read the angle values ('labels'), the class ids ('types') and the images
# from the HDF5 file Ra_<img_size>x<img_size>.h5 under --data_path.
data_filename = args.data_path + '/Ra_' + str(args.img_size) + 'x' + str(args.img_size) + '.h5'
# The context manager closes the file even when a dataset is missing or a read fails.
with h5py.File(data_filename, 'r') as hf:
    angles = hf['labels'][:].astype(float)
    labels = hf['types'][:]
    images = hf['images'][:]
# Class ids are expected to be exactly 0..num_classes-1.
# NOTE(review): num_classes is computed before the angle filter below, so a class
# that is filtered out entirely is still counted — confirm this is intended.
num_classes = len(set(labels))
assert max(labels)==num_classes-1

# Keep only samples whose angle lies strictly inside (min_label, max_label).
q1 = args.min_label
q2 = args.max_label
indx = np.where((angles > q1) & (angles < q2))[0]
angles = angles[indx]
labels = labels[indx]
images = images[indx]
assert len(labels)==len(images)
assert len(angles)==len(images)
|
102
|
-
|
103
|
-
|
104
|
-
# define training (and validation) set
|
105
|
-
if args.CVMode:
    # Stratified split: shuffle the indices of every class separately and move the
    # first `valid_proport` share of each class to the validation set.
    valid_parts = []
    train_parts = []
    for i in range(num_classes):
        indx_i = np.where(labels == i)[0]  # samples of the i-th class
        np.random.shuffle(indx_i)
        num_imgs_valid_i = int(len(indx_i) * args.valid_proport)
        valid_parts.append(indx_i[:num_imgs_valid_i])
        train_parts.append(indx_i[num_imgs_valid_i:])
    indx_valid = np.concatenate(valid_parts)
    indx_train = np.concatenate(train_parts)

    trainset = IMGs_dataset(images[indx_train], labels[indx_train], normalize=True)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size_train, shuffle=True, num_workers=args.num_workers)
    validset = IMGs_dataset(images[indx_valid], labels[indx_valid], normalize=True)
    validloader = torch.utils.data.DataLoader(validset, batch_size=args.batch_size_valid, shuffle=False, num_workers=args.num_workers)
else:
    # No validation: train on every sample that survived the label filter.
    trainset = IMGs_dataset(images, labels, normalize=True)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size_train, shuffle=True, num_workers=args.num_workers)
|
125
|
-
|
126
|
-
|
127
|
-
|
128
|
-
###########################################################################################################
|
129
|
-
# Necessary functions
|
130
|
-
###########################################################################################################
|
131
|
-
|
132
|
-
#initialize CNNs
|
133
|
-
def net_initialization(Pretrained_CNN_Name, ngpu = 1, num_classes=num_classes):
    """Build the requested classification network and move it to ``device``.

    Args:
        Pretrained_CNN_Name: one of "ResNet18_class", "ResNet34_class",
            "ResNet50_class" or "ResNet101_class".
        ngpu: forwarded unchanged to the network constructor.
        num_classes: forwarded unchanged to the network constructor; defaults to
            the module-level ``num_classes`` at the time this function is defined.

    Returns:
        ``(net, net_name)`` where ``net_name`` is
        ``'PreCNNForEvalGANs_' + Pretrained_CNN_Name``.

    Raises:
        ValueError: if ``Pretrained_CNN_Name`` is not a supported name. (Previously
            this surfaced as an UnboundLocalError on ``net``.)
    """
    if Pretrained_CNN_Name == "ResNet18_class":
        net = ResNet18_class_eval(num_classes=num_classes, ngpu = ngpu)
    elif Pretrained_CNN_Name == "ResNet34_class":
        net = ResNet34_class_eval(num_classes=num_classes, ngpu = ngpu)
    elif Pretrained_CNN_Name == "ResNet50_class":
        net = ResNet50_class_eval(num_classes=num_classes, ngpu = ngpu)
    elif Pretrained_CNN_Name == "ResNet101_class":
        net = ResNet101_class_eval(num_classes=num_classes, ngpu = ngpu)
    else:
        raise ValueError(
            "Unsupported CNN name {!r}; expected one of ResNet18_class, "
            "ResNet34_class, ResNet50_class, ResNet101_class".format(Pretrained_CNN_Name)
        )

    net_name = 'PreCNNForEvalGANs_' + Pretrained_CNN_Name #get the net's name
    net = net.to(device)

    return net, net_name
|
147
|
-
|
148
|
-
#adjust CNN learning rate
|
149
|
-
def adjust_learning_rate(optimizer, epoch, BASE_LR_CNN, decay_epochs=(50, 120), decay_divisor=10):
    """Apply a step-decay learning rate to every parameter group of ``optimizer``.

    The rate starts at ``BASE_LR_CNN`` and is divided by ``decay_divisor`` once
    for every milestone in ``decay_epochs`` that ``epoch`` has reached.

    Args:
        optimizer: object exposing ``param_groups``, a list of dicts with an 'lr' key.
        epoch: current (0-based) epoch index.
        BASE_LR_CNN: learning rate before any decay.
        decay_epochs: epochs at which the rate is divided; defaults to (50, 120).
        decay_divisor: divisor applied at each milestone; defaults to 10.

    Returns:
        The learning rate that was written into the parameter groups.
    """
    lr = BASE_LR_CNN
    for milestone in decay_epochs:
        if epoch >= milestone:
            lr /= decay_divisor
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr
|
161
|
-
|
162
|
-
|
163
|
-
def train_CNN():
    """Train the module-level ``net`` for ``args.epochs`` epochs.

    Relies on the module-level ``trainloader``, ``criterion`` and ``optimizer``.
    The learning rate is re-derived from ``args.base_lr`` at the start of every
    epoch. When ``args.CVMode`` is set, ``valid_CNN`` runs after each epoch and
    its result is printed next to the mean training loss.

    Returns:
        The ``(net, optimizer)`` pair after training.
    """
    t_begin = timeit.default_timer()
    n_epochs = args.epochs
    for epoch in range(n_epochs):
        net.train()
        adjust_learning_rate(optimizer, epoch, args.base_lr)

        loss_sum = 0
        for imgs, targets in trainloader:
            imgs = imgs.type(torch.float).cuda()
            targets = targets.type(torch.long).cuda()

            # forward pass: the network returns a pair, only the first item is scored
            logits, _ = net(imgs)
            batch_loss = criterion(logits, targets)

            # backward pass and parameter update
            optimizer.zero_grad()
            batch_loss.backward()
            optimizer.step()

            loss_sum += batch_loss.cpu().item()
        # mean loss over the batches of this epoch
        epoch_loss = loss_sum / len(trainloader)

        if not args.CVMode:
            print('CNN: [epoch %d/%d] train_loss:%f Time: %.4f' % (epoch+1, n_epochs, epoch_loss, timeit.default_timer()-t_begin))
        else:
            valid_acc = valid_CNN(verbose=False)
            print('CNN: [epoch %d/%d] train_loss:%f valid_acc:%f Time: %.4f' % (epoch+1, n_epochs, epoch_loss, valid_acc, timeit.default_timer()-t_begin))

    return net, optimizer
|
198
|
-
|
199
|
-
|
200
|
-
if args.CVMode:
    def valid_CNN(verbose=True):
        """Return the accuracy (in percent) of ``net`` on ``validloader``.

        Only defined when ``args.CVMode`` is set. Puts ``net`` in eval mode and
        leaves it there; prints the accuracy when ``verbose`` is true.
        """
        net.eval()  # batchnorm uses moving mean/variance instead of mini-batch statistics
        n_correct = 0
        n_seen = 0
        with torch.no_grad():
            for batch_images, batch_labels in validloader:
                batch_images = batch_images.type(torch.float).cuda()
                batch_labels = batch_labels.type(torch.long).cuda()
                outputs, _ = net(batch_images)
                # index of the largest score along dim 1 is the predicted class
                predicted = torch.max(outputs.data, 1)[1]
                n_seen += batch_labels.size(0)
                n_correct += (predicted == batch_labels).sum().item()
        if verbose:
            print('Valid Accuracy of the model on the validation set: {} %'.format(100.0 * n_correct / n_seen))
        return 100.0 * n_correct / n_seen
|
216
|
-
|
217
|
-
|
218
|
-
|
219
|
-
###########################################################################################################
|
220
|
-
# Training and Testing
|
221
|
-
###########################################################################################################
|
222
|
-
# model initialization
|
223
|
-
# Build the network, the classification loss and an SGD optimizer (momentum 0.9).
net, net_name = net_initialization(args.CNN, ngpu = ngpu, num_classes = num_classes)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(net.parameters(), lr = args.base_lr, momentum= 0.9, weight_decay=args.weight_dacay)

# The checkpoint name encodes network, epoch count, seed, class count and CV mode,
# so a run with different settings does not pick up a stale checkpoint.
filename_ckpt = save_models_folder + '/ckpt_{}_epoch_{}_seed_{}_classify_{}_chair_types_CVMode_{}.pth'.format(net_name, args.epochs, args.seed, num_classes, args.CVMode)

# training: train from scratch only when no checkpoint with this name exists
if not os.path.isfile(filename_ckpt):
    # TRAIN CNN
    print("\n Begin training CNN: ")
    start = timeit.default_timer()
    net, optimizer = train_CNN()
    stop = timeit.default_timer()
    print("Time elapses: {}s".format(stop - start))
    # save model (only the network weights are stored, not the optimizer state)
    torch.save({
        'net_state_dict': net.state_dict(),
    }, filename_ckpt)
else:
    # reuse the weights of a previous run
    print("\n Ckpt already exists")
    print("\n Loading...")
    checkpoint = torch.load(filename_ckpt)
    net.load_state_dict(checkpoint['net_state_dict'])
torch.cuda.empty_cache()#release GPU mem which is not references

if args.CVMode:
    #validation: valid_CNN is only defined when CVMode is set
    _ = valid_CNN(True)
    torch.cuda.empty_cache()
|
@@ -1,255 +0,0 @@
|
|
1
|
-
"""
|
2
|
-
|
3
|
-
Pre-train a CNN on the whole dataset for evaluation purpose
|
4
|
-
|
5
|
-
"""
|
6
|
-
import os
|
7
|
-
import argparse
|
8
|
-
import shutil
|
9
|
-
import timeit
|
10
|
-
import torch
|
11
|
-
import torchvision
|
12
|
-
import torchvision.transforms as transforms
|
13
|
-
import numpy as np
|
14
|
-
import torch.nn as nn
|
15
|
-
import torch.backends.cudnn as cudnn
|
16
|
-
import random
|
17
|
-
import matplotlib.pyplot as plt
|
18
|
-
import matplotlib as mpl
|
19
|
-
from torch import autograd
|
20
|
-
from torchvision.utils import save_image
|
21
|
-
import csv
|
22
|
-
from tqdm import tqdm
|
23
|
-
import gc
|
24
|
-
import h5py
|
25
|
-
|
26
|
-
|
27
|
-
#############################
|
28
|
-
# Settings
|
29
|
-
#############################
|
30
|
-
|
31
|
-
# Command-line options of the regression pre-training script.
# Help strings below state the defaults that are actually configured.
parser = argparse.ArgumentParser(description='Pre-train CNNs')
parser.add_argument('--root_path', type=str, default='')
parser.add_argument('--data_path', type=str, default='')
parser.add_argument('--num_workers', type=int, default=0)
parser.add_argument('--CNN', type=str, default='ResNet34_regre',
                    help='CNN for training; one of ResNet18_regre, ResNet34_regre, ResNet50_regre, ResNet101_regre')
parser.add_argument('--epochs', type=int, default=200, metavar='N',
                    help='number of epochs to train CNNs (default: 200)')
parser.add_argument('--batch_size_train', type=int, default=256, metavar='N',
                    help='input batch size for training')
parser.add_argument('--batch_size_valid', type=int, default=10, metavar='N',
                    help='input batch size for testing')
parser.add_argument('--base_lr', type=float, default=0.01,
                    help='learning rate, default=0.01')
# NOTE: the flag name is misspelled ("dacay") but is kept unchanged because
# existing command lines and `args.weight_dacay` depend on it.
parser.add_argument('--weight_dacay', type=float, default=1e-4,
                    help='Weight decay, default=1e-4')
parser.add_argument('--seed', type=int, default=2020, metavar='S',
                    help='random seed (default: 2020)')
parser.add_argument('--CVMode', action='store_true', default=False,
                    help='CV mode?')
parser.add_argument('--img_size', type=int, default=128, metavar='N')
parser.add_argument('--min_label', type=float, default=0.0)
parser.add_argument('--max_label', type=float, default=90.0)
args = parser.parse_args()
|
55
|
-
|
56
|
-
|
57
|
-
# Switch the process working directory to the user-supplied project root.
# NOTE(review): --root_path defaults to '' — confirm os.chdir accepts an empty
# path on the target platform, otherwise the flag is effectively required.
wd = args.root_path
os.chdir(wd)
from ..models_128 import *
from .utils import IMGs_dataset

# cuda: the script always targets a CUDA device (there is no CPU fallback here)
device = torch.device("cuda")
ngpu = torch.cuda.device_count() # number of gpus

# random seed: seed Python, PyTorch and NumPy generators from the same value
random.seed(args.seed)
torch.manual_seed(args.seed)
torch.backends.cudnn.deterministic = True
np.random.seed(args.seed)

# directory for checkpoints of the evaluation models (created if missing)
save_models_folder = wd + '/output/eval_models/'
os.makedirs(save_models_folder, exist_ok=True)
|
75
|
-
|
76
|
-
|
77
|
-
###########################################################################################################
|
78
|
-
# Data
|
79
|
-
###########################################################################################################
|
80
|
-
# data loader
|
81
|
-
# Read the regression labels and the images from the HDF5 file
# Ra_<img_size>x<img_size>.h5 under --data_path.
data_filename = args.data_path + '/Ra_' + str(args.img_size) + 'x' + str(args.img_size) + '.h5'
# The context manager closes the file even when a dataset is missing or a read fails.
with h5py.File(data_filename, 'r') as hf:
    labels = hf['labels'][:].astype(float)
    images = hf['images'][:]
N_all = len(images)
assert len(images) == len(labels)

# Keep only samples whose label lies strictly inside (min_label, max_label).
q1 = args.min_label
q2 = args.max_label
indx = np.where((labels > q1) & (labels < q2))[0]
labels = labels[indx]
images = images[indx]
assert len(labels)==len(images)

# Rescale labels by the configured maximum rather than the observed one; with the
# strict filter above this yields values in (0, 1) when min_label >= 0 and max_label > 0.
labels /= args.max_label
|
104
|
-
|
105
|
-
|
106
|
-
# define training and validation sets
|
107
|
-
if args.CVMode:
    # Hold out the first 20% of the shuffled sample indices for validation and
    # train on the remaining 80%.
    indx_all = np.arange(len(labels))
    np.random.shuffle(indx_all)
    num_valid = int(0.2*len(labels))
    indx_valid, indx_train = indx_all[:num_valid], indx_all[num_valid:]

    trainset = IMGs_dataset(images[indx_train], labels[indx_train], normalize=True)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size_train, shuffle=True, num_workers=args.num_workers)
    validset = IMGs_dataset(images[indx_valid], labels[indx_valid], normalize=True)
    validloader = torch.utils.data.DataLoader(validset, batch_size=args.batch_size_valid, shuffle=False, num_workers=args.num_workers)
else:
    # No validation: train on every sample that survived the label filter.
    trainset = IMGs_dataset(images, labels, normalize=True)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size_train, shuffle=True, num_workers=args.num_workers)
|
122
|
-
|
123
|
-
|
124
|
-
|
125
|
-
|
126
|
-
###########################################################################################################
|
127
|
-
# Necessary functions
|
128
|
-
###########################################################################################################
|
129
|
-
|
130
|
-
#initialize CNNs
|
131
|
-
def net_initialization(Pretrained_CNN_Name, ngpu = 1):
    """Build the requested regression network and move it to ``device``.

    Args:
        Pretrained_CNN_Name: one of "ResNet18_regre", "ResNet34_regre",
            "ResNet50_regre" or "ResNet101_regre".
        ngpu: forwarded unchanged to the network constructor.

    Returns:
        ``(net, net_name)`` where ``net_name`` is
        ``'PreCNNForEvalGANs_' + Pretrained_CNN_Name``.

    Raises:
        ValueError: if ``Pretrained_CNN_Name`` is not a supported name. (Previously
            this surfaced as an UnboundLocalError on ``net``.)
    """
    if Pretrained_CNN_Name == "ResNet18_regre":
        net = ResNet18_regre_eval(ngpu = ngpu)
    elif Pretrained_CNN_Name == "ResNet34_regre":
        net = ResNet34_regre_eval(ngpu = ngpu)
    elif Pretrained_CNN_Name == "ResNet50_regre":
        net = ResNet50_regre_eval(ngpu = ngpu)
    elif Pretrained_CNN_Name == "ResNet101_regre":
        net = ResNet101_regre_eval(ngpu = ngpu)
    else:
        raise ValueError(
            "Unsupported CNN name {!r}; expected one of ResNet18_regre, "
            "ResNet34_regre, ResNet50_regre, ResNet101_regre".format(Pretrained_CNN_Name)
        )

    net_name = 'PreCNNForEvalGANs_' + Pretrained_CNN_Name #get the net's name
    net = net.to(device)

    return net, net_name
|
145
|
-
|
146
|
-
#adjust CNN learning rate
|
147
|
-
def adjust_learning_rate(optimizer, epoch, BASE_LR_CNN, decay_epochs=(50, 120), decay_divisor=10):
    """Apply a step-decay learning rate to every parameter group of ``optimizer``.

    The rate starts at ``BASE_LR_CNN`` and is divided by ``decay_divisor`` once
    for every milestone in ``decay_epochs`` that ``epoch`` has reached.

    Args:
        optimizer: object exposing ``param_groups``, a list of dicts with an 'lr' key.
        epoch: current (0-based) epoch index.
        BASE_LR_CNN: learning rate before any decay.
        decay_epochs: epochs at which the rate is divided; defaults to (50, 120).
        decay_divisor: divisor applied at each milestone; defaults to 10.

    Returns:
        The learning rate that was written into the parameter groups.
    """
    lr = BASE_LR_CNN
    for milestone in decay_epochs:
        if epoch >= milestone:
            lr /= decay_divisor
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr
|
159
|
-
|
160
|
-
|
161
|
-
def train_CNN():
    """Train the module-level regression ``net`` for ``args.epochs`` epochs.

    Relies on the module-level ``trainloader``, ``criterion`` and ``optimizer``.
    The learning rate is re-derived from ``args.base_lr`` at the start of every
    epoch. When ``args.CVMode`` is set, ``valid_CNN`` runs after each epoch and
    its result is printed next to the mean training loss.

    Returns:
        The ``(net, optimizer)`` pair after training.
    """
    t_begin = timeit.default_timer()
    n_epochs = args.epochs
    for epoch in range(n_epochs):
        net.train()
        adjust_learning_rate(optimizer, epoch, args.base_lr)

        loss_sum = 0
        for imgs, targets in trainloader:
            imgs = imgs.type(torch.float).cuda()
            # reshape targets to one column so they line up with the network output
            targets = targets.type(torch.float).view(-1,1).cuda()

            # forward pass: the network returns a pair, only the first item is scored
            preds, _ = net(imgs)
            batch_loss = criterion(preds, targets)

            # backward pass and parameter update
            optimizer.zero_grad()
            batch_loss.backward()
            optimizer.step()

            loss_sum += batch_loss.cpu().item()
        # mean loss over the batches of this epoch
        epoch_loss = loss_sum / len(trainloader)

        if not args.CVMode:
            print('CNN: [epoch %d/%d] train_loss:%f Time:%.4f' % (epoch+1, n_epochs, epoch_loss, timeit.default_timer()-t_begin))
        else:
            valid_loss = valid_CNN(verbose=False)
            print('CNN: [epoch %d/%d] train_loss:%f valid_loss (avg_abs):%f Time: %.4f' % (epoch+1, n_epochs, epoch_loss, valid_loss, timeit.default_timer()-t_begin))

    return net, optimizer
|
196
|
-
|
197
|
-
if args.CVMode:
    def valid_CNN(verbose=True):
        """Return the mean absolute error of ``net`` on ``validloader``.

        Only defined when ``args.CVMode`` is set. Labels and predictions are both
        multiplied by ``args.max_label`` before the error is taken, undoing the
        division applied when the data was loaded. Puts ``net`` in eval mode and
        leaves it there; prints the error when ``verbose`` is true.
        """
        net.eval()  # batchnorm uses moving mean/variance instead of mini-batch statistics
        abs_err_sum = 0
        n_seen = 0
        with torch.no_grad():
            for batch_images, batch_labels in validloader:
                batch_images = batch_images.type(torch.float).cuda()
                batch_labels = batch_labels.type(torch.float).view(-1).cpu().numpy()
                preds, _ = net(batch_images)
                preds = preds.view(-1).cpu().numpy()
                # back to the original label scale
                batch_labels = batch_labels * args.max_label
                preds = preds * args.max_label
                abs_err_sum += np.sum(np.abs(batch_labels-preds))
                n_seen += len(batch_labels)

        if verbose:
            print('Validation Average Absolute Difference: {}'.format(abs_err_sum/n_seen))
        return abs_err_sum/n_seen
|
216
|
-
|
217
|
-
|
218
|
-
|
219
|
-
###########################################################################################################
|
220
|
-
# Training and validation
|
221
|
-
###########################################################################################################
|
222
|
-
|
223
|
-
|
224
|
-
# model initialization
|
225
|
-
# Build the network, the regression loss (MSE) and an SGD optimizer (momentum 0.9).
net, net_name = net_initialization(args.CNN, ngpu = ngpu)
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(net.parameters(), lr = args.base_lr, momentum= 0.9, weight_decay=args.weight_dacay)

# The checkpoint name encodes network, epoch count, seed and CV mode,
# so a run with different settings does not pick up a stale checkpoint.
filename_ckpt = save_models_folder + '/ckpt_{}_epoch_{}_seed_{}_CVMode_{}.pth'.format(net_name, args.epochs, args.seed, args.CVMode)


# training: train from scratch only when no checkpoint with this name exists
if not os.path.isfile(filename_ckpt):
    # TRAIN CNN
    print("\n Begin training CNN: ")
    start = timeit.default_timer()
    net, optimizer = train_CNN()
    stop = timeit.default_timer()
    print("Time elapses: {}s".format(stop - start))
    # save model (only the network weights are stored, not the optimizer state)
    torch.save({
        'net_state_dict': net.state_dict(),
    }, filename_ckpt)
else:
    # reuse the weights of a previous run
    print("\n Ckpt already exists")
    print("\n Loading...")
    checkpoint = torch.load(filename_ckpt)
    net.load_state_dict(checkpoint['net_state_dict'])
torch.cuda.empty_cache()#release GPU mem which is not references


if args.CVMode:
    #validation: valid_CNN is only defined when CVMode is set
    _ = valid_CNN(True)
    torch.cuda.empty_cache()
|
@@ -1,181 +0,0 @@
|
|
1
|
-
|
2
|
-
import torch
|
3
|
-
import torch.nn as nn
|
4
|
-
from torchvision.utils import save_image
|
5
|
-
import numpy as np
|
6
|
-
import os
|
7
|
-
import timeit
|
8
|
-
from PIL import Image
|
9
|
-
|
10
|
-
|
11
|
-
|
12
|
-
|
13
|
-
#-------------------------------------------------------------
|
14
|
-
def train_net_embed(net, net_name, trainloader, testloader, epochs=200, resume_epoch = 0, lr_base=0.01, lr_decay_factor=0.1, lr_decay_epochs=(80, 140), weight_decay=1e-4, path_to_ckpt = None):
    """Train ``net`` with an MSE loss and SGD, optionally resuming and checkpointing.

    Args:
        net: network whose forward call returns a pair; the first item is compared
            against the labels. It is moved to CUDA before training.
        net_name: accepted for interface compatibility; not used in this function.
        trainloader: iterable of ``(images, labels)`` training batches.
        testloader: iterable of ``(images, labels)`` batches evaluated after every
            epoch, or ``None`` to skip evaluation.
        epochs: total number of epochs.
        resume_epoch: epoch to resume from; when > 0 and ``path_to_ckpt`` is given,
            the checkpoint saved for that epoch is loaded first.
        lr_base: initial learning rate.
        lr_decay_factor: multiplier applied once per reached decay epoch.
        lr_decay_epochs: epochs at which the learning rate decays (any sequence).
        weight_decay: SGD weight decay.
        path_to_ckpt: directory for checkpoints, or ``None`` to disable them.

    Returns:
        The trained ``net``.
    """

    def adjust_learning_rate_1(optimizer, epoch):
        """Set the step-decayed learning rate for ``epoch`` on all parameter groups."""
        lr = lr_base
        for decay_epoch in lr_decay_epochs:
            if epoch >= decay_epoch:
                lr = lr * lr_decay_factor
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

    net = net.cuda()
    criterion = nn.MSELoss()
    optimizer = torch.optim.SGD(net.parameters(), lr = lr_base, momentum= 0.9, weight_decay=weight_decay)

    # resume training; load checkpoint (weights, optimizer state and CPU RNG state)
    if path_to_ckpt is not None and resume_epoch>0:
        save_file = path_to_ckpt + "/embed_x2y_ckpt_in_train/embed_x2y_checkpoint_epoch_{}.pth".format(resume_epoch)
        checkpoint = torch.load(save_file)
        net.load_state_dict(checkpoint['net_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        torch.set_rng_state(checkpoint['rng_state'])

    start_tmp = timeit.default_timer()
    for epoch in range(resume_epoch, epochs):
        net.train()
        train_loss = 0
        adjust_learning_rate_1(optimizer, epoch)
        for batch_train_images, batch_train_labels in trainloader:
            batch_train_images = batch_train_images.type(torch.float).cuda()
            # one column per sample so the labels line up with the network output
            batch_train_labels = batch_train_labels.type(torch.float).view(-1,1).cuda()

            # forward pass
            outputs, _ = net(batch_train_images)
            loss = criterion(outputs, batch_train_labels)

            # backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_loss += loss.cpu().item()
        train_loss = train_loss / len(trainloader)

        if testloader is None:
            print('Train net_x2y for embedding: [epoch %d/%d] train_loss:%f Time:%.4f' % (epoch+1, epochs, train_loss, timeit.default_timer()-start_tmp))
        else:
            net.eval() # eval mode (batchnorm uses moving mean/variance instead of mini-batch mean/variance)
            with torch.no_grad():
                test_loss = 0
                for batch_test_images, batch_test_labels in testloader:
                    batch_test_images = batch_test_images.type(torch.float).cuda()
                    batch_test_labels = batch_test_labels.type(torch.float).view(-1,1).cuda()
                    outputs,_ = net(batch_test_images)
                    loss = criterion(outputs, batch_test_labels)
                    test_loss += loss.cpu().item()
                test_loss = test_loss/len(testloader)

                print('Train net_x2y for label embedding: [epoch %d/%d] train_loss:%f test_loss:%f Time:%.4f' % (epoch+1, epochs, train_loss, test_loss, timeit.default_timer()-start_tmp))

        # save a checkpoint every 50 epochs and after the final epoch; the file name
        # uses the 1-based epoch while the stored 'epoch' entry is 0-based
        if path_to_ckpt is not None and (((epoch+1) % 50 == 0) or (epoch+1==epochs)):
            save_file = path_to_ckpt + "/embed_x2y_ckpt_in_train/embed_x2y_checkpoint_epoch_{}.pth".format(epoch+1)
            os.makedirs(os.path.dirname(save_file), exist_ok=True)
            torch.save({
                    'epoch': epoch,
                    'net_state_dict': net.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'rng_state': torch.get_rng_state()
            }, save_file)

    return net
|
97
|
-
|
98
|
-
|
99
|
-
|
100
|
-
|
101
|
-
###################################################################################
|
102
|
-
class label_dataset(torch.utils.data.Dataset):
    """Dataset that serves one label per item from an indexable collection."""

    def __init__(self, labels):
        super().__init__()
        self.labels = labels
        # cached once so __len__ does not re-measure the collection
        self.n_samples = len(labels)

    def __getitem__(self, index):
        """Return the label stored at ``index``."""
        return self.labels[index]

    def __len__(self):
        """Return the number of labels captured at construction time."""
        return self.n_samples
|
116
|
-
|
117
|
-
|
118
|
-
def train_net_y2h(unique_labels_norm, net_y2h, net_embed, epochs=500, lr_base=0.01, lr_decay_factor=0.1, lr_decay_epochs=(150, 250, 350), weight_decay=1e-4, batch_size=128):
    """Train ``net_y2h`` so that ``net_embed``'s ``h2y`` head maps its output back to the label.

    Each batch of labels is perturbed with Gaussian noise (std 0.2), clamped to
    [0, 1], passed through ``net_y2h`` and then through ``h2y``; the MSE between
    the reconstruction and the noisy labels is minimised. Only the parameters of
    ``net_y2h`` are given to the optimizer.

    Args:
        unique_labels_norm: an array of normalized unique labels, all within [0, 1].
        net_y2h: network being trained (label -> hidden representation).
        net_embed: network providing ``h2y``, either directly or through a
            ``.module`` wrapper attribute; it is switched to eval mode.
        epochs: number of training epochs.
        lr_base: initial learning rate.
        lr_decay_factor: multiplier applied once per reached decay epoch.
        lr_decay_epochs: epochs at which the learning rate decays (any sequence).
        weight_decay: SGD weight decay.
        batch_size: number of labels per batch.

    Returns:
        The trained ``net_y2h``.
    """

    def adjust_learning_rate_2(optimizer, epoch):
        """Set the step-decayed learning rate for ``epoch`` on all parameter groups."""
        lr = lr_base
        for decay_epoch in lr_decay_epochs:
            if epoch >= decay_epoch:
                lr = lr * lr_decay_factor
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

    assert np.max(unique_labels_norm)<=1 and np.min(unique_labels_norm)>=0
    trainset = label_dataset(unique_labels_norm)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True)

    net_embed.eval()
    # h2y converts embedded labels back to original labels; take it from the
    # wrapped `.module` when present, otherwise from the network itself.
    net_h2y = getattr(net_embed, "module", net_embed).h2y
    optimizer_y2h = torch.optim.SGD(net_y2h.parameters(), lr = lr_base, momentum= 0.9, weight_decay=weight_decay)
    criterion = nn.MSELoss()  # built once instead of once per batch

    start_tmp = timeit.default_timer()
    for epoch in range(epochs):
        net_y2h.train()
        train_loss = 0
        adjust_learning_rate_2(optimizer_y2h, epoch)
        for batch_labels in trainloader:

            batch_labels = batch_labels.type(torch.float).view(-1,1).cuda()

            # generate noises which will be added to labels
            batch_size_curr = len(batch_labels)
            batch_gamma = np.random.normal(0, 0.2, batch_size_curr)
            batch_gamma = torch.from_numpy(batch_gamma).view(-1,1).type(torch.float).cuda()

            # add noise to labels, keeping them inside the normalized range
            batch_labels_noise = torch.clamp(batch_labels+batch_gamma, 0.0, 1.0)

            # forward pass: label -> hidden -> reconstructed label
            batch_hiddens_noise = net_y2h(batch_labels_noise)
            batch_rec_labels_noise = net_h2y(batch_hiddens_noise)

            loss = criterion(batch_rec_labels_noise, batch_labels_noise)

            # backward pass
            optimizer_y2h.zero_grad()
            loss.backward()
            optimizer_y2h.step()

            train_loss += loss.cpu().item()
        train_loss = train_loss / len(trainloader)

        print('\n Train net_y2h: [epoch %d/%d] train_loss:%f Time:%.4f' % (epoch+1, epochs, train_loss, timeit.default_timer()-start_tmp))

    return net_y2h
|