Myosotis-Researches 0.0.13__py3-none-any.whl → 0.0.15__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
  Files changed (34)
  1. myosotis_researches/CcGAN/models_128/CcGAN_SAGAN.py +301 -0
  2. myosotis_researches/CcGAN/models_128/ResNet_class_eval.py +141 -0
  3. myosotis_researches/CcGAN/models_128/ResNet_embed.py +188 -0
  4. myosotis_researches/CcGAN/models_128/ResNet_regre_eval.py +175 -0
  5. myosotis_researches/CcGAN/models_128/__init__.py +8 -0
  6. myosotis_researches/CcGAN/models_128/autoencoder.py +119 -0
  7. myosotis_researches/CcGAN/models_128/cGAN_SAGAN.py +276 -0
  8. myosotis_researches/CcGAN/models_128/cGAN_concat_SAGAN.py +245 -0
  9. myosotis_researches/CcGAN/models_256/CcGAN_SAGAN.py +303 -0
  10. myosotis_researches/CcGAN/models_256/ResNet_class_eval.py +142 -0
  11. myosotis_researches/CcGAN/models_256/ResNet_embed.py +188 -0
  12. myosotis_researches/CcGAN/models_256/ResNet_regre_eval.py +178 -0
  13. myosotis_researches/CcGAN/models_256/__init__.py +8 -0
  14. myosotis_researches/CcGAN/models_256/autoencoder.py +133 -0
  15. myosotis_researches/CcGAN/models_256/cGAN_SAGAN.py +280 -0
  16. myosotis_researches/CcGAN/models_256/cGAN_concat_SAGAN.py +249 -0
  17. myosotis_researches/CcGAN/train_128/DiffAugment_pytorch.py +76 -0
  18. myosotis_researches/CcGAN/train_128/__init__.py +0 -0
  19. myosotis_researches/CcGAN/train_128/eval_metrics.py +205 -0
  20. myosotis_researches/CcGAN/train_128/opts.py +87 -0
  21. myosotis_researches/CcGAN/train_128/pretrain_AE.py +268 -0
  22. myosotis_researches/CcGAN/train_128/pretrain_CNN_class.py +251 -0
  23. myosotis_researches/CcGAN/train_128/pretrain_CNN_regre.py +255 -0
  24. myosotis_researches/CcGAN/train_128/train_ccgan.py +303 -0
  25. myosotis_researches/CcGAN/train_128/train_cgan.py +254 -0
  26. myosotis_researches/CcGAN/train_128/train_cgan_concat.py +242 -0
  27. myosotis_researches/CcGAN/train_128/train_net_for_label_embed.py +181 -0
  28. myosotis_researches/CcGAN/train_128/utils.py +120 -0
  29. {myosotis_researches-0.0.13.dist-info → myosotis_researches-0.0.15.dist-info}/METADATA +1 -1
  30. myosotis_researches-0.0.15.dist-info/RECORD +40 -0
  31. myosotis_researches-0.0.13.dist-info/RECORD +0 -12
  32. {myosotis_researches-0.0.13.dist-info → myosotis_researches-0.0.15.dist-info}/WHEEL +0 -0
  33. {myosotis_researches-0.0.13.dist-info → myosotis_researches-0.0.15.dist-info}/licenses/LICENSE +0 -0
  34. {myosotis_researches-0.0.13.dist-info → myosotis_researches-0.0.15.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,251 @@
1
+ """
2
+
3
+ Pre-train a CNN on the whole dataset for evaluation purpose
4
+
5
+ """
6
+ import os
7
+ import argparse
8
+ import shutil
9
+ import timeit
10
+
11
+ import torch
12
+ import torchvision
13
+ import torchvision.transforms as transforms
14
+ import numpy as np
15
+ import torch.nn as nn
16
+ import torch.backends.cudnn as cudnn
17
+ import random
18
+ import matplotlib.pyplot as plt
19
+ import matplotlib as mpl
20
+ from torch import autograd
21
+ from torchvision.utils import save_image
22
+ import csv
23
+ from tqdm import tqdm
24
+ import gc
25
+ import h5py
26
+
27
+ from models import *
28
+ from utils import IMGs_dataset
29
+
30
+
31
#############################
# Settings
#############################

# Command-line options for pre-training the classification CNN that is later
# used to evaluate generated images. Help strings state the real defaults.
parser = argparse.ArgumentParser(description='Pre-train CNNs')
parser.add_argument('--root_path', type=str, default='')
parser.add_argument('--data_path', type=str, default='',
                    help='directory containing RC-49_<img_size>x<img_size>.h5')
parser.add_argument('--num_workers', type=int, default=0)
parser.add_argument('--CNN', type=str, default='ResNet34_class',
                    help='CNN for training; ResNetXX_class (XX in 18, 34, 50, 101)')
parser.add_argument('--epochs', type=int, default=200, metavar='N',
                    help='number of epochs to train CNNs (default: 200)')
parser.add_argument('--batch_size_train', type=int, default=128, metavar='N',
                    help='input batch size for training')
parser.add_argument('--batch_size_valid', type=int, default=10, metavar='N',
                    help='input batch size for validation')
parser.add_argument('--base_lr', type=float, default=0.01,
                    help='learning rate, default=0.01')
# NOTE: the flag keeps its historical spelling ("dacay") for compatibility.
parser.add_argument('--weight_dacay', type=float, default=1e-4,
                    help='Weight decay, default=1e-4')
parser.add_argument('--seed', type=int, default=2020, metavar='S',
                    help='random seed (default: 2020)')
parser.add_argument('--CVMode', action='store_true', default=False,
                    help='CV mode: hold out a per-class validation split')
parser.add_argument('--valid_proport', type=float, default=0.1,
                    help='Proportion of validation samples per class (CV mode only)')
parser.add_argument('--img_size', type=int, default=128, metavar='N')
parser.add_argument('--min_label', type=float, default=0.0,
                    help='exclusive lower bound on angle labels that are kept')
parser.add_argument('--max_label', type=float, default=90.0,
                    help='exclusive upper bound on angle labels that are kept')
args = parser.parse_args()
61
+
62
+
63
# Working directory: the script changes into it and writes outputs beneath it.
wd = args.root_path
# os.chdir('') raises FileNotFoundError, and '' is the default root_path, so
# only change directory when a root path was actually supplied.
if wd:
    os.chdir(wd)
# NOTE(review): this is a package-relative import; it only resolves when the
# file is executed as a module of its package (e.g. `python -m ...`). It also
# overlaps the `from models import *` near the top of the file — confirm
# which import is intended.
from ..models_128 import *
from utils import IMGs_dataset

# cuda
device = torch.device("cuda")
ngpu = torch.cuda.device_count()  # number of visible GPUs

# Seed every RNG used by this script (Python, torch, numpy) and request
# deterministic cuDNN kernels for reproducibility.
random.seed(args.seed)
torch.manual_seed(args.seed)
torch.backends.cudnn.deterministic = True
np.random.seed(args.seed)

# Checkpoints go to <root_path>/output/eval_models/. os.path.join keeps the
# path relative when root_path is empty instead of pointing at /output.
save_models_folder = os.path.join(wd, 'output', 'eval_models', '')
os.makedirs(save_models_folder, exist_ok=True)
81
+
82
+
83
# data loader
data_filename = args.data_path + '/RC-49_' + str(args.img_size) + 'x' + str(args.img_size) + '.h5'
# Read all datasets into memory. The context manager closes the HDF5 file
# even if one of the reads fails.
with h5py.File(data_filename, 'r') as hf:
    angles = hf['labels'][:].astype(float)  # per-image angle values
    labels = hf['types'][:]                 # per-image class ids (classification targets)
    images = hf['images'][:]
num_classes = len(set(labels))
# Sanity check on the class ids. It only checks the maximum; it assumes the
# ids are non-negative integers so that they cover 0..num_classes-1.
assert max(labels) == num_classes - 1

# Keep only samples whose angle lies strictly inside (min_label, max_label).
q1 = args.min_label
q2 = args.max_label
indx = np.flatnonzero((angles > q1) & (angles < q2))
angles = angles[indx]
labels = labels[indx]
images = images[indx]
assert len(labels) == len(images)
assert len(angles) == len(images)
102
+
103
+
104
# define training (and validation) set
if args.CVMode:
    # Stratified hold-out: each class is shuffled on its own. Its first
    # int(n_class * valid_proport) indices go to validation and the rest go
    # to training. Classes are appended in order 0..num_classes-1.
    valid_chunks = []
    train_chunks = []
    for cls in range(num_classes):
        cls_indx = np.where(labels == cls)[0]
        np.random.shuffle(cls_indx)
        n_valid_cls = int(len(cls_indx) * args.valid_proport)
        valid_chunks.append(cls_indx[0:n_valid_cls])
        train_chunks.append(cls_indx[n_valid_cls:])
    indx_valid = np.concatenate(valid_chunks)
    indx_train = np.concatenate(train_chunks)

    trainset = IMGs_dataset(images[indx_train], labels[indx_train], normalize=True)
    trainloader = torch.utils.data.DataLoader(
        trainset,
        batch_size=args.batch_size_train,
        shuffle=True,
        num_workers=args.num_workers,
    )
    validset = IMGs_dataset(images[indx_valid], labels[indx_valid], normalize=True)
    validloader = torch.utils.data.DataLoader(
        validset,
        batch_size=args.batch_size_valid,
        shuffle=False,
        num_workers=args.num_workers,
    )
else:
    # Train on every filtered sample; no validation loader is built.
    trainset = IMGs_dataset(images, labels, normalize=True)
    trainloader = torch.utils.data.DataLoader(
        trainset,
        batch_size=args.batch_size_train,
        shuffle=True,
        num_workers=args.num_workers,
    )
125
+
126
+
127
+
128
+ ###########################################################################################################
129
+ # Necessary functions
130
+ ###########################################################################################################
131
+
132
#initialize CNNs
def net_initialization(Pretrained_CNN_Name, ngpu = 1, num_classes=num_classes):
    """Build the evaluation classifier named by ``Pretrained_CNN_Name``.

    Args:
        Pretrained_CNN_Name: one of "ResNet18_class", "ResNet34_class",
            "ResNet50_class", "ResNet101_class".
        ngpu: forwarded to the model constructor.
        num_classes: number of output classes; defaults to the module-level
            value computed from the dataset.

    Returns:
        ``(net, net_name)``: the model moved to ``device``, and the name
        ``'PreCNNForEvalGANs_' + Pretrained_CNN_Name`` used for checkpoints.

    Raises:
        ValueError: if ``Pretrained_CNN_Name`` is not one of the names above.
            Previously an unknown name surfaced as an UnboundLocalError.
    """
    if Pretrained_CNN_Name == "ResNet18_class":
        net = ResNet18_class_eval(num_classes=num_classes, ngpu = ngpu)
    elif Pretrained_CNN_Name == "ResNet34_class":
        net = ResNet34_class_eval(num_classes=num_classes, ngpu = ngpu)
    elif Pretrained_CNN_Name == "ResNet50_class":
        net = ResNet50_class_eval(num_classes=num_classes, ngpu = ngpu)
    elif Pretrained_CNN_Name == "ResNet101_class":
        net = ResNet101_class_eval(num_classes=num_classes, ngpu = ngpu)
    else:
        raise ValueError(
            "Unknown CNN '{}'; expected one of ResNet18_class, ResNet34_class, "
            "ResNet50_class, ResNet101_class".format(Pretrained_CNN_Name))

    net_name = 'PreCNNForEvalGANs_' + Pretrained_CNN_Name #get the net's name
    net = net.to(device)

    return net, net_name
147
+
148
#adjust CNN learning rate
def adjust_learning_rate(optimizer, epoch, BASE_LR_CNN, milestones=(50, 120), factor=10):
    """Apply a step learning-rate schedule to every parameter group.

    The rate starts at ``BASE_LR_CNN`` and is divided by ``factor`` once for
    every milestone that ``epoch`` (0-based) has reached. The defaults
    reproduce the original schedule: /10 from epoch 50 and /100 from epoch 120.

    Args:
        optimizer: object exposing ``param_groups`` (a list of dicts with an
            ``'lr'`` key), e.g. a ``torch.optim`` optimizer.
        epoch: current epoch index.
        BASE_LR_CNN: initial learning rate.
        milestones: epochs at which the rate is divided by ``factor``.
        factor: divisor applied at each reached milestone.

    Returns:
        The learning rate that was written into every parameter group.
    """
    lr = BASE_LR_CNN
    for milestone in milestones:
        if epoch >= milestone:
            lr /= factor
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr
161
+
162
+
163
def train_CNN():
    """Train the module-level ``net`` on ``trainloader`` for ``args.epochs`` epochs.

    Each epoch first resets the learning rate via ``adjust_learning_rate``. It
    then runs one pass of SGD updates with the module-level ``criterion`` and
    ``optimizer``. Finally it prints the mean batch loss, plus the validation
    accuracy in CV mode, together with the elapsed time since the start.

    Returns:
        ``(net, optimizer)`` after training.
    """
    t_begin = timeit.default_timer()
    for ep in range(args.epochs):
        net.train()
        adjust_learning_rate(optimizer, ep, args.base_lr)
        loss_sum = 0
        for imgs, targets in trainloader:
            imgs = imgs.type(torch.float).cuda()
            targets = targets.type(torch.long).cuda()

            # forward: the network returns (logits, features); only logits are used
            logits, _ = net(imgs)
            batch_loss = criterion(logits, targets)

            # backward + parameter update
            optimizer.zero_grad()
            batch_loss.backward()
            optimizer.step()

            loss_sum += batch_loss.cpu().item()
        mean_loss = loss_sum / len(trainloader)

        if not args.CVMode:
            print('CNN: [epoch %d/%d] train_loss:%f Time: %.4f' % (ep+1, args.epochs, mean_loss, timeit.default_timer()-t_begin))
            continue
        acc = valid_CNN(verbose=False)
        print('CNN: [epoch %d/%d] train_loss:%f valid_acc:%f Time: %.4f' % (ep+1, args.epochs, mean_loss, acc, timeit.default_timer()-t_begin))

    return net, optimizer
198
+
199
+
200
if args.CVMode:
    def valid_CNN(verbose=True):
        """Return the top-1 accuracy (in percent) of ``net`` on ``validloader``.

        Args:
            verbose: if True, also print the accuracy.

        Returns:
            ``100 * correct / total``. If the validation split is empty
            (possible when ``valid_proport`` is small), returns ``nan``
            instead of raising ZeroDivisionError.
        """
        net.eval() # eval mode (batchnorm uses moving mean/variance instead of mini-batch mean/variance)
        correct = 0
        total = 0
        with torch.no_grad():
            for batch_images, batch_labels in validloader:
                batch_images = batch_images.type(torch.float).cuda()
                batch_labels = batch_labels.type(torch.long).cuda()
                outputs, _ = net(batch_images)
                predicted = outputs.argmax(dim=1)
                total += batch_labels.size(0)
                correct += (predicted == batch_labels).sum().item()
        accuracy = 100.0 * correct / total if total > 0 else float('nan')
        if verbose:
            print('Valid Accuracy of the model on the validation set: {} %'.format(accuracy))
        return accuracy
216
+
217
+
218
+
219
###########################################################################################################
# Training and Testing
###########################################################################################################
# Build the model, loss and optimizer. The module-level names net, criterion
# and optimizer are read by train_CNN / valid_CNN.
net, net_name = net_initialization(args.CNN, ngpu=ngpu, num_classes=num_classes)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(
    net.parameters(),
    lr=args.base_lr,
    momentum=0.9,
    weight_decay=args.weight_dacay,
)

filename_ckpt = save_models_folder + '/ckpt_{}_epoch_{}_seed_{}_classify_{}_chair_types_CVMode_{}.pth'.format(net_name, args.epochs, args.seed, num_classes, args.CVMode)

if os.path.isfile(filename_ckpt):
    # A checkpoint with identical settings exists: reuse its weights.
    print("\n Ckpt already exists")
    print("\n Loading...")
    checkpoint = torch.load(filename_ckpt)
    net.load_state_dict(checkpoint['net_state_dict'])
else:
    # No checkpoint yet: train from scratch, then save the weights.
    print("\n Begin training CNN: ")
    start = timeit.default_timer()
    net, optimizer = train_CNN()
    stop = timeit.default_timer()
    print("Time elapses: {}s".format(stop - start))
    torch.save({'net_state_dict': net.state_dict()}, filename_ckpt)
torch.cuda.empty_cache()  # release cached, unreferenced GPU memory

if args.CVMode:
    # final validation report
    valid_CNN(True)
    torch.cuda.empty_cache()
@@ -0,0 +1,255 @@
1
+ """
2
+
3
+ Pre-train a CNN on the whole dataset for evaluation purpose
4
+
5
+ """
6
+ import os
7
+ import argparse
8
+ import shutil
9
+ import timeit
10
+ import torch
11
+ import torchvision
12
+ import torchvision.transforms as transforms
13
+ import numpy as np
14
+ import torch.nn as nn
15
+ import torch.backends.cudnn as cudnn
16
+ import random
17
+ import matplotlib.pyplot as plt
18
+ import matplotlib as mpl
19
+ from torch import autograd
20
+ from torchvision.utils import save_image
21
+ import csv
22
+ from tqdm import tqdm
23
+ import gc
24
+ import h5py
25
+
26
+
27
#############################
# Settings
#############################

# Command-line options for pre-training the regression CNN that is later used
# to evaluate generated images. Help strings state the real defaults.
parser = argparse.ArgumentParser(description='Pre-train CNNs')
parser.add_argument('--root_path', type=str, default='')
parser.add_argument('--data_path', type=str, default='',
                    help='directory containing RC-49_<img_size>x<img_size>.h5')
parser.add_argument('--num_workers', type=int, default=0)
parser.add_argument('--CNN', type=str, default='ResNet34_regre',
                    help='CNN for training; ResNetXX_regre (XX in 18, 34, 50, 101)')
parser.add_argument('--epochs', type=int, default=200, metavar='N',
                    help='number of epochs to train CNNs (default: 200)')
parser.add_argument('--batch_size_train', type=int, default=256, metavar='N',
                    help='input batch size for training')
parser.add_argument('--batch_size_valid', type=int, default=10, metavar='N',
                    help='input batch size for validation')
parser.add_argument('--base_lr', type=float, default=0.01,
                    help='learning rate, default=0.01')
# NOTE: the flag keeps its historical spelling ("dacay") for compatibility.
parser.add_argument('--weight_dacay', type=float, default=1e-4,
                    help='Weight decay, default=1e-4')
parser.add_argument('--seed', type=int, default=2020, metavar='S',
                    help='random seed (default: 2020)')
parser.add_argument('--CVMode', action='store_true', default=False,
                    help='CV mode: hold out 20%% of samples for validation')
parser.add_argument('--img_size', type=int, default=128, metavar='N')
parser.add_argument('--min_label', type=float, default=0.0,
                    help='exclusive lower bound on labels that are kept')
parser.add_argument('--max_label', type=float, default=90.0,
                    help='exclusive upper bound on labels; also the label scale factor')
args = parser.parse_args()
55
+
56
+
57
# Working directory: the script changes into it and writes outputs beneath it.
wd = args.root_path
# os.chdir('') raises FileNotFoundError, and '' is the default root_path, so
# only change directory when a root path was actually supplied.
if wd:
    os.chdir(wd)
# NOTE(review): these imports run after the chdir; confirm that `models` and
# `utils` are importable from the interpreter's sys.path in this setup.
from models import *
from utils import IMGs_dataset

# cuda
device = torch.device("cuda")
ngpu = torch.cuda.device_count()  # number of visible GPUs

# Seed every RNG used by this script (Python, torch, numpy) and request
# deterministic cuDNN kernels for reproducibility.
random.seed(args.seed)
torch.manual_seed(args.seed)
torch.backends.cudnn.deterministic = True
np.random.seed(args.seed)

# Checkpoints go to <root_path>/output/eval_models/. os.path.join keeps the
# path relative when root_path is empty instead of pointing at /output.
save_models_folder = os.path.join(wd, 'output', 'eval_models', '')
os.makedirs(save_models_folder, exist_ok=True)
75
+
76
+
77
###########################################################################################################
# Data
###########################################################################################################
# data loader
data_filename = args.data_path + '/RC-49_' + str(args.img_size) + 'x' + str(args.img_size) + '.h5'
# Read labels and images into memory. The context manager closes the HDF5
# file even if one of the reads fails.
with h5py.File(data_filename, 'r') as hf:
    labels = hf['labels'][:].astype(float)
    images = hf['images'][:]
N_all = len(images)  # number of samples before filtering
assert len(images) == len(labels)

# Keep only samples whose label lies strictly inside (min_label, max_label).
q1 = args.min_label
q2 = args.max_label
indx = np.flatnonzero((labels > q1) & (labels < q2))
labels = labels[indx]
images = images[indx]
assert len(labels) == len(images)

# Scale labels by max_label. valid_CNN multiplies values back by max_label
# before computing the error. With the default bounds (0, 90) the scaled
# labels lie in (0, 1).
labels /= args.max_label
104
+
105
+
106
# define training and validation sets
if args.CVMode:
    # One random permutation of all samples: the first 20% become the
    # validation split and the remaining 80% the training split.
    n_samples = len(labels)
    n_holdout = int(0.2*n_samples)
    perm = np.arange(n_samples)
    np.random.shuffle(perm)
    indx_valid, indx_train = perm[0:n_holdout], perm[n_holdout:]

    trainset = IMGs_dataset(images[indx_train], labels[indx_train], normalize=True)
    trainloader = torch.utils.data.DataLoader(
        trainset,
        batch_size=args.batch_size_train,
        shuffle=True,
        num_workers=args.num_workers,
    )
    validset = IMGs_dataset(images[indx_valid], labels[indx_valid], normalize=True)
    validloader = torch.utils.data.DataLoader(
        validset,
        batch_size=args.batch_size_valid,
        shuffle=False,
        num_workers=args.num_workers,
    )
else:
    # Train on every filtered sample; no validation loader is built.
    trainset = IMGs_dataset(images, labels, normalize=True)
    trainloader = torch.utils.data.DataLoader(
        trainset,
        batch_size=args.batch_size_train,
        shuffle=True,
        num_workers=args.num_workers,
    )
122
+
123
+
124
+
125
+
126
+ ###########################################################################################################
127
+ # Necessary functions
128
+ ###########################################################################################################
129
+
130
#initialize CNNs
def net_initialization(Pretrained_CNN_Name, ngpu = 1):
    """Build the evaluation regressor named by ``Pretrained_CNN_Name``.

    Args:
        Pretrained_CNN_Name: one of "ResNet18_regre", "ResNet34_regre",
            "ResNet50_regre", "ResNet101_regre".
        ngpu: forwarded to the model constructor.

    Returns:
        ``(net, net_name)``: the model moved to ``device``, and the name
        ``'PreCNNForEvalGANs_' + Pretrained_CNN_Name`` used for checkpoints.

    Raises:
        ValueError: if ``Pretrained_CNN_Name`` is not one of the names above.
            Previously an unknown name surfaced as an UnboundLocalError.
    """
    if Pretrained_CNN_Name == "ResNet18_regre":
        net = ResNet18_regre_eval(ngpu = ngpu)
    elif Pretrained_CNN_Name == "ResNet34_regre":
        net = ResNet34_regre_eval(ngpu = ngpu)
    elif Pretrained_CNN_Name == "ResNet50_regre":
        net = ResNet50_regre_eval(ngpu = ngpu)
    elif Pretrained_CNN_Name == "ResNet101_regre":
        net = ResNet101_regre_eval(ngpu = ngpu)
    else:
        raise ValueError(
            "Unknown CNN '{}'; expected one of ResNet18_regre, ResNet34_regre, "
            "ResNet50_regre, ResNet101_regre".format(Pretrained_CNN_Name))

    net_name = 'PreCNNForEvalGANs_' + Pretrained_CNN_Name #get the net's name
    net = net.to(device)

    return net, net_name
145
+
146
#adjust CNN learning rate
def adjust_learning_rate(optimizer, epoch, BASE_LR_CNN, milestones=(50, 120), factor=10):
    """Apply a step learning-rate schedule to every parameter group.

    The rate starts at ``BASE_LR_CNN`` and is divided by ``factor`` once for
    every milestone that ``epoch`` (0-based) has reached. The defaults
    reproduce the original schedule: /10 from epoch 50 and /100 from epoch 120.

    Args:
        optimizer: object exposing ``param_groups`` (a list of dicts with an
            ``'lr'`` key), e.g. a ``torch.optim`` optimizer.
        epoch: current epoch index.
        BASE_LR_CNN: initial learning rate.
        milestones: epochs at which the rate is divided by ``factor``.
        factor: divisor applied at each reached milestone.

    Returns:
        The learning rate that was written into every parameter group.
    """
    lr = BASE_LR_CNN
    for milestone in milestones:
        if epoch >= milestone:
            lr /= factor
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr
159
+
160
+
161
def train_CNN():
    """Train the module-level ``net`` on ``trainloader`` for ``args.epochs`` epochs.

    Each epoch first resets the learning rate via ``adjust_learning_rate``. It
    then runs one pass of SGD updates against scalar targets reshaped to
    ``(batch, 1)``, using the module-level ``criterion`` and ``optimizer``.
    Finally it prints the mean batch loss, plus the validation mean absolute
    error in CV mode, together with the elapsed time since the start.

    Returns:
        ``(net, optimizer)`` after training.
    """
    t_begin = timeit.default_timer()
    for ep in range(args.epochs):
        net.train()
        adjust_learning_rate(optimizer, ep, args.base_lr)
        loss_sum = 0
        for imgs, targets in trainloader:
            imgs = imgs.type(torch.float).cuda()
            targets = targets.type(torch.float).view(-1, 1).cuda()

            # forward: the network returns (predictions, features); only predictions are used
            preds, _ = net(imgs)
            batch_loss = criterion(preds, targets)

            # backward + parameter update
            optimizer.zero_grad()
            batch_loss.backward()
            optimizer.step()

            loss_sum += batch_loss.cpu().item()
        mean_loss = loss_sum / len(trainloader)

        if not args.CVMode:
            print('CNN: [epoch %d/%d] train_loss:%f Time:%.4f' % (ep+1, args.epochs, mean_loss, timeit.default_timer()-t_begin))
            continue
        mae = valid_CNN(verbose=False)
        print('CNN: [epoch %d/%d] train_loss:%f valid_loss (avg_abs):%f Time: %.4f' % (ep+1, args.epochs, mean_loss, mae, timeit.default_timer()-t_begin))

    return net, optimizer
196
+
197
if args.CVMode:
    def valid_CNN(verbose=True):
        """Return the mean absolute error of ``net`` on ``validloader``.

        Both targets and predictions are multiplied by ``args.max_label``
        before the absolute differences are taken, so the error is reported
        on that scale rather than the scaled training scale.

        Args:
            verbose: if True, also print the error.

        Returns:
            Sum of absolute errors divided by the number of validation
            samples. Returns ``nan`` for an empty validation split instead
            of raising ZeroDivisionError.
        """
        net.eval() # eval mode (batchnorm uses moving mean/variance instead of mini-batch mean/variance)
        abs_err_sum = 0
        total = 0
        with torch.no_grad():
            for batch_images, batch_labels in validloader:
                batch_images = batch_images.type(torch.float).cuda()
                targets = batch_labels.type(torch.float).view(-1).cpu().numpy() * args.max_label
                outputs, _ = net(batch_images)
                preds = outputs.view(-1).cpu().numpy() * args.max_label
                abs_err_sum += np.sum(np.abs(targets - preds))
                total += len(targets)
        mean_abs_err = abs_err_sum / total if total > 0 else float('nan')
        if verbose:
            print('Validation Average Absolute Difference: {}'.format(mean_abs_err))
        return mean_abs_err
216
+
217
+
218
+
219
###########################################################################################################
# Training and validation
###########################################################################################################

# Build the model, loss and optimizer. The module-level names net, criterion
# and optimizer are read by train_CNN / valid_CNN.
net, net_name = net_initialization(args.CNN, ngpu=ngpu)
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(
    net.parameters(),
    lr=args.base_lr,
    momentum=0.9,
    weight_decay=args.weight_dacay,
)

filename_ckpt = save_models_folder + '/ckpt_{}_epoch_{}_seed_{}_CVMode_{}.pth'.format(net_name, args.epochs, args.seed, args.CVMode)

if os.path.isfile(filename_ckpt):
    # A checkpoint with identical settings exists: reuse its weights.
    print("\n Ckpt already exists")
    print("\n Loading...")
    checkpoint = torch.load(filename_ckpt)
    net.load_state_dict(checkpoint['net_state_dict'])
else:
    # No checkpoint yet: train from scratch, then save the weights.
    print("\n Begin training CNN: ")
    start = timeit.default_timer()
    net, optimizer = train_CNN()
    stop = timeit.default_timer()
    print("Time elapses: {}s".format(stop - start))
    torch.save({'net_state_dict': net.state_dict()}, filename_ckpt)
torch.cuda.empty_cache()  # release cached, unreferenced GPU memory

if args.CVMode:
    # final validation report
    valid_CNN(True)
    torch.cuda.empty_cache()