Myosotis-Researches 0.1.7__py3-none-any.whl → 0.1.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49)
  1. myosotis_researches/CcGAN/train/__init__.py +4 -0
  2. myosotis_researches/CcGAN/{train_128_output_10 → train}/train_ccgan.py +4 -4
  3. myosotis_researches/CcGAN/{train_128 → train}/train_cgan.py +1 -3
  4. myosotis_researches/CcGAN/{train_128 → train}/train_cgan_concat.py +1 -3
  5. myosotis_researches/CcGAN/utils/__init__.py +2 -1
  6. myosotis_researches/CcGAN/utils/train.py +94 -3
  7. {myosotis_researches-0.1.7.dist-info → myosotis_researches-0.1.9.dist-info}/METADATA +1 -1
  8. myosotis_researches-0.1.9.dist-info/RECORD +24 -0
  9. myosotis_researches/CcGAN/models_128/CcGAN_SAGAN.py +0 -301
  10. myosotis_researches/CcGAN/models_128/ResNet_class_eval.py +0 -141
  11. myosotis_researches/CcGAN/models_128/ResNet_embed.py +0 -188
  12. myosotis_researches/CcGAN/models_128/ResNet_regre_eval.py +0 -175
  13. myosotis_researches/CcGAN/models_128/__init__.py +0 -7
  14. myosotis_researches/CcGAN/models_128/autoencoder.py +0 -119
  15. myosotis_researches/CcGAN/models_128/cGAN_SAGAN.py +0 -276
  16. myosotis_researches/CcGAN/models_128/cGAN_concat_SAGAN.py +0 -245
  17. myosotis_researches/CcGAN/models_256/CcGAN_SAGAN.py +0 -303
  18. myosotis_researches/CcGAN/models_256/ResNet_class_eval.py +0 -142
  19. myosotis_researches/CcGAN/models_256/ResNet_embed.py +0 -188
  20. myosotis_researches/CcGAN/models_256/ResNet_regre_eval.py +0 -178
  21. myosotis_researches/CcGAN/models_256/__init__.py +0 -7
  22. myosotis_researches/CcGAN/models_256/autoencoder.py +0 -133
  23. myosotis_researches/CcGAN/models_256/cGAN_SAGAN.py +0 -280
  24. myosotis_researches/CcGAN/models_256/cGAN_concat_SAGAN.py +0 -249
  25. myosotis_researches/CcGAN/train_128/DiffAugment_pytorch.py +0 -76
  26. myosotis_researches/CcGAN/train_128/__init__.py +0 -0
  27. myosotis_researches/CcGAN/train_128/eval_metrics.py +0 -205
  28. myosotis_researches/CcGAN/train_128/opts.py +0 -87
  29. myosotis_researches/CcGAN/train_128/pretrain_AE.py +0 -268
  30. myosotis_researches/CcGAN/train_128/pretrain_CNN_class.py +0 -251
  31. myosotis_researches/CcGAN/train_128/pretrain_CNN_regre.py +0 -255
  32. myosotis_researches/CcGAN/train_128/train_ccgan.py +0 -303
  33. myosotis_researches/CcGAN/train_128/utils.py +0 -120
  34. myosotis_researches/CcGAN/train_128_output_10/DiffAugment_pytorch.py +0 -76
  35. myosotis_researches/CcGAN/train_128_output_10/__init__.py +0 -0
  36. myosotis_researches/CcGAN/train_128_output_10/eval_metrics.py +0 -205
  37. myosotis_researches/CcGAN/train_128_output_10/opts.py +0 -87
  38. myosotis_researches/CcGAN/train_128_output_10/pretrain_AE.py +0 -268
  39. myosotis_researches/CcGAN/train_128_output_10/pretrain_CNN_class.py +0 -251
  40. myosotis_researches/CcGAN/train_128_output_10/pretrain_CNN_regre.py +0 -255
  41. myosotis_researches/CcGAN/train_128_output_10/train_cgan.py +0 -254
  42. myosotis_researches/CcGAN/train_128_output_10/train_cgan_concat.py +0 -242
  43. myosotis_researches/CcGAN/train_128_output_10/train_net_for_label_embed.py +0 -181
  44. myosotis_researches/CcGAN/train_128_output_10/utils.py +0 -120
  45. myosotis_researches-0.1.7.dist-info/RECORD +0 -59
  46. /myosotis_researches/CcGAN/{train_128 → train}/train_net_for_label_embed.py +0 -0
  47. {myosotis_researches-0.1.7.dist-info → myosotis_researches-0.1.9.dist-info}/WHEEL +0 -0
  48. {myosotis_researches-0.1.7.dist-info → myosotis_researches-0.1.9.dist-info}/licenses/LICENSE +0 -0
  49. {myosotis_researches-0.1.7.dist-info → myosotis_researches-0.1.9.dist-info}/top_level.txt +0 -0
@@ -1,268 +0,0 @@
# Pre-train a sparse autoencoder whose encoder features are used for computing FID.

import os
import argparse
import shutil
import timeit
import torch
import torchvision
import torchvision.transforms as transforms
import numpy as np
import torch.nn as nn
import torch.backends.cudnn as cudnn
import random
import matplotlib.pyplot as plt
import matplotlib as mpl
from torch import autograd
from torchvision.utils import save_image
import csv
from tqdm import tqdm
import gc
import h5py


#############################
# Settings
#############################

parser = argparse.ArgumentParser(description='Pre-train AE for computing FID')
parser.add_argument('--root_path', type=str, default='')
parser.add_argument('--data_path', type=str, default='')
parser.add_argument('--num_workers', type=int, default=0)
parser.add_argument('--dim_bottleneck', type=int, default=512)
parser.add_argument('--epochs', type=int, default=200, metavar='N',
                    help='number of epochs to train CNNs (default: 200)')
parser.add_argument('--resume_epoch', type=int, default=0)
parser.add_argument('--batch_size_train', type=int, default=128, metavar='N',
                    help='input batch size for training')
parser.add_argument('--batch_size_valid', type=int, default=10, metavar='N',
                    help='input batch size for testing')
parser.add_argument('--base_lr', type=float, default=1e-3,
                    help='learning rate, default=1e-3')
parser.add_argument('--lr_decay_epochs', type=int, default=50) #decay lr rate every lr_decay_epochs epochs
parser.add_argument('--lr_decay_factor', type=float, default=0.1)
parser.add_argument('--lambda_sparsity', type=float, default=1e-4, help='penalty for sparsity')
parser.add_argument('--weight_dacay', type=float, default=1e-4,
                    help='Weigth decay, default=1e-4')
parser.add_argument('--seed', type=int, default=2020, metavar='S',
                    help='random seed (default: 1)')
parser.add_argument('--CVMode', action='store_true', default=False,
                    help='CV mode?')
parser.add_argument('--img_size', type=int, default=128, metavar='N')
parser.add_argument('--min_label', type=float, default=0.0)
parser.add_argument('--max_label', type=float, default=90.0)
args = parser.parse_args()

wd = args.root_path
os.chdir(wd)
# NOTE(review): these relative imports only resolve when the module is executed
# as part of the package (python -m ...), not as a standalone script — confirm.
from ..models_128 import *
from .utils import IMGs_dataset, SimpleProgressBar

# some parameters in the opts
dim_bottleneck = args.dim_bottleneck
epochs = args.epochs
base_lr = args.base_lr
lr_decay_epochs = args.lr_decay_epochs
lr_decay_factor = args.lr_decay_factor
resume_epoch = args.resume_epoch
lambda_sparsity = args.lambda_sparsity


# random seed: fix every RNG used below for reproducibility
random.seed(args.seed)
torch.manual_seed(args.seed)
torch.backends.cudnn.deterministic = True
cudnn.benchmark = False
np.random.seed(args.seed)

# directories for checkpoint, images and log files
save_models_folder = wd + '/output/eval_models'
os.makedirs(save_models_folder, exist_ok=True)
save_AE_images_in_train_folder = save_models_folder + '/AE_lambda_{}_images_in_train'.format(lambda_sparsity)
os.makedirs(save_AE_images_in_train_folder, exist_ok=True)
save_AE_images_in_valid_folder = save_models_folder + '/AE_lambda_{}_images_in_valid'.format(lambda_sparsity)
os.makedirs(save_AE_images_in_valid_folder, exist_ok=True)


###########################################################################################################
# Data
###########################################################################################################
# data loader: h5 file holding 'images' and 'labels' datasets
data_filename = args.data_path + '/Ra_' + str(args.img_size) + 'x' + str(args.img_size) + '.h5'
hf = h5py.File(data_filename, 'r')
labels = hf['labels'][:]
labels = labels.astype(float)
images = hf['images'][:]
hf.close()
N_all = len(images)
assert len(images) == len(labels)

# keep only samples whose label is strictly inside (min_label, max_label)
q1 = args.min_label
q2 = args.max_label
indx = np.where((labels>q1)*(labels<q2)==True)[0]
labels = labels[indx]
images = images[indx]
assert len(labels)==len(images)

# define training and validation sets
if args.CVMode:
    #90% Training; 10% validation
    valid_prop = 0.1 #proportion of the validation samples
    indx_all = np.arange(len(images))
    np.random.shuffle(indx_all)
    indx_valid = indx_all[0:int(valid_prop*len(images))]
    indx_train = indx_all[int(valid_prop*len(images)):]

    trainset = IMGs_dataset(images[indx_train], labels=None, normalize=True)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size_train, shuffle=True, num_workers=args.num_workers)
    validset = IMGs_dataset(images[indx_valid], labels=None, normalize=True)
    validloader = torch.utils.data.DataLoader(validset, batch_size=args.batch_size_valid, shuffle=False, num_workers=args.num_workers)

else:
    trainset = IMGs_dataset(images, labels=None, normalize=True)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size_train, shuffle=True, num_workers=args.num_workers)


###########################################################################################################
# Necessary functions
###########################################################################################################
129
def adjust_learning_rate(epoch, epochs, optimizer, base_lr, lr_decay_epochs, lr_decay_factor):
    """Apply a step-decay learning-rate schedule to *optimizer* in place.

    The learning rate starts at ``base_lr`` and is multiplied by
    ``lr_decay_factor`` once per completed ``lr_decay_epochs`` interval.
    The number of decays is capped at ``epochs // lr_decay_epochs``, matching
    the original loop's behavior when ``epoch`` exceeds ``epochs``.

    Args:
        epoch: current (0-based) epoch.
        epochs: total number of training epochs.
        optimizer: torch-style optimizer exposing ``param_groups``.
        base_lr: initial learning rate.
        lr_decay_epochs: decay interval in epochs.
        lr_decay_factor: multiplicative decay per interval.
    """
    # Closed form replaces the original O(epochs) loop: count completed
    # decay intervals, capped at the number of intervals inside `epochs`.
    num_decays = min(epoch // lr_decay_epochs, epochs // lr_decay_epochs)
    lr = base_lr * (lr_decay_factor ** num_decays)

    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
138
-
139
def train_AE():
    """Train the sparse autoencoder (encoder + decoder) on the training set.

    Relies on module-level globals: ``net_encoder``, ``net_decoder``,
    ``trainloader``, ``resume_epoch``, ``epochs``, ``base_lr``,
    ``lr_decay_epochs``, ``lr_decay_factor``, ``lambda_sparsity``,
    ``save_models_folder``, ``save_AE_images_in_train_folder`` and ``args``.

    Returns:
        (net_encoder, net_decoder): the trained networks.
    """

    # define optimizer over both sub-networks jointly
    params = list(net_encoder.parameters()) + list(net_decoder.parameters())
    # FIX: honor the parsed --weight_dacay flag instead of a hard-coded 1e-4
    # (the flag's default is 1e-4, so default behavior is unchanged)
    optimizer = torch.optim.Adam(params, lr = base_lr, betas=(0.5, 0.999), weight_decay=args.weight_dacay)

    # reconstruction criterion
    criterion = nn.MSELoss()

    if resume_epoch>0:
        print("Loading ckpt to resume training AE >>>")
        ckpt_fullpath = save_models_folder + "/AE_checkpoint_intrain/AE_checkpoint_epoch_{}_lambda_{}.pth".format(resume_epoch, lambda_sparsity)
        checkpoint = torch.load(ckpt_fullpath)
        net_encoder.load_state_dict(checkpoint['net_encoder_state_dict'])
        net_decoder.load_state_dict(checkpoint['net_decoder_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        torch.set_rng_state(checkpoint['rng_state'])
        gen_iterations = checkpoint['gen_iterations']
    else:
        gen_iterations = 0

    start_time = timeit.default_timer()
    for epoch in range(resume_epoch, epochs):

        adjust_learning_rate(epoch, epochs, optimizer, base_lr, lr_decay_epochs, lr_decay_factor)

        train_loss = 0

        for batch_idx, batch_real_images in enumerate(trainloader):

            net_encoder.train()
            net_decoder.train()

            batch_size_curr = batch_real_images.shape[0]

            batch_real_images = batch_real_images.type(torch.float).cuda()

            # forward: encode to the bottleneck, then reconstruct
            batch_features = net_encoder(batch_real_images)
            batch_recons_images = net_decoder(batch_features)

            '''
            based on https://debuggercafe.com/sparse-autoencoders-using-l1-regularization-with-pytorch/
            '''
            # MSE reconstruction loss plus a sparsity penalty on the mean
            # bottleneck activation
            loss = criterion(batch_recons_images, batch_real_images) + lambda_sparsity * batch_features.mean()

            #backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_loss += loss.cpu().item()

            gen_iterations += 1

            # periodically dump a grid of reconstructions for visual inspection
            if gen_iterations % 100 == 0:
                n_row=min(10, int(np.sqrt(batch_size_curr)))
                with torch.no_grad():
                    batch_recons_images = net_decoder(net_encoder(batch_real_images[0:n_row**2]))
                    batch_recons_images = batch_recons_images.detach().cpu()
                save_image(batch_recons_images.data, save_AE_images_in_train_folder + '/{}.png'.format(gen_iterations), nrow=n_row, normalize=True)

            if gen_iterations % 20 == 0:
                print("AE+lambda{}: [step {}] [epoch {}/{}] [train loss {}] [Time {}]".format(lambda_sparsity, gen_iterations, epoch+1, epochs, train_loss/(batch_idx+1), timeit.default_timer()-start_time) )
        # end for batch_idx

        # checkpoint every 50 epochs so training can be resumed with --resume_epoch
        if (epoch+1) % 50 == 0:
            save_file = save_models_folder + "/AE_checkpoint_intrain/AE_checkpoint_epoch_{}_lambda_{}.pth".format(epoch+1, lambda_sparsity)
            os.makedirs(os.path.dirname(save_file), exist_ok=True)
            torch.save({
                    'gen_iterations': gen_iterations,
                    'net_encoder_state_dict': net_encoder.state_dict(),
                    'net_decoder_state_dict': net_decoder.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'rng_state': torch.get_rng_state()
            }, save_file)
    #end for epoch

    return net_encoder, net_decoder
218
-
219
-
220
if args.CVMode:
    def valid_AE():
        """Reconstruct every validation batch and save real/reconstructed grids to disk."""
        net_encoder.eval()
        net_decoder.eval()
        with torch.no_grad():
            for batch_idx, real_batch in enumerate(validloader):
                real_batch = real_batch.type(torch.float).cuda()
                # encode then decode in one shot
                reconstructed = net_decoder(net_encoder(real_batch))
                save_image(reconstructed.data, save_AE_images_in_valid_folder + '/{}_recons.png'.format(batch_idx), nrow=10, normalize=True)
                save_image(real_batch.data, save_AE_images_in_valid_folder + '/{}_real.png'.format(batch_idx), nrow=10, normalize=True)
        return None
232
-
233
-
234
-
235
###########################################################################################################
# Training and validation
###########################################################################################################

# model initialization: encoder/decoder come from models_128 (star import above)
net_encoder = encoder(dim_bottleneck=args.dim_bottleneck).cuda()
net_decoder = decoder(dim_bottleneck=args.dim_bottleneck).cuda()
net_encoder = nn.DataParallel(net_encoder)
net_decoder = nn.DataParallel(net_decoder)

filename_ckpt = save_models_folder + '/ckpt_AE_epoch_{}_seed_{}_CVMode_{}.pth'.format(args.epochs, args.seed, args.CVMode)

# training: train from scratch unless the final checkpoint already exists
if not os.path.isfile(filename_ckpt):
    print("\n Begin training AE: ")
    start = timeit.default_timer()
    net_encoder, net_decoder = train_AE()
    stop = timeit.default_timer()
    print("Time elapses: {}s".format(stop - start))
    # save model
    torch.save({
        'net_encoder_state_dict': net_encoder.state_dict(),
        'net_decoder_state_dict': net_decoder.state_dict(),
    }, filename_ckpt)
else:
    print("\n Ckpt already exists")
    print("\n Loading...")
    checkpoint = torch.load(filename_ckpt)
    net_encoder.load_state_dict(checkpoint['net_encoder_state_dict'])
    net_decoder.load_state_dict(checkpoint['net_decoder_state_dict'])

if args.CVMode:
    #validation: save real vs. reconstructed image grids for inspection
    _ = valid_AE()
@@ -1,251 +0,0 @@
1
- """
2
-
3
- Pre-train a CNN on the whole dataset for evaluation purpose
4
-
5
- """
6
- import os
7
- import argparse
8
- import shutil
9
- import timeit
10
-
11
- import torch
12
- import torchvision
13
- import torchvision.transforms as transforms
14
- import numpy as np
15
- import torch.nn as nn
16
- import torch.backends.cudnn as cudnn
17
- import random
18
- import matplotlib.pyplot as plt
19
- import matplotlib as mpl
20
- from torch import autograd
21
- from torchvision.utils import save_image
22
- import csv
23
- from tqdm import tqdm
24
- import gc
25
- import h5py
26
-
27
- from models import *
28
- from utils import IMGs_dataset
29
-
30
-
31
- #############################
32
- # Settings
33
- #############################
34
-
35
- parser = argparse.ArgumentParser(description='Pre-train CNNs')
36
- parser.add_argument('--root_path', type=str, default='')
37
- parser.add_argument('--data_path', type=str, default='')
38
- parser.add_argument('--num_workers', type=int, default=0)
39
- parser.add_argument('--CNN', type=str, default='ResNet34_class',
40
- help='CNN for training; ResNetXX')
41
- parser.add_argument('--epochs', type=int, default=200, metavar='N',
42
- help='number of epochs to train CNNs (default: 200)')
43
- parser.add_argument('--batch_size_train', type=int, default=128, metavar='N',
44
- help='input batch size for training')
45
- parser.add_argument('--batch_size_valid', type=int, default=10, metavar='N',
46
- help='input batch size for testing')
47
- parser.add_argument('--base_lr', type=float, default=0.01,
48
- help='learning rate, default=0.1')
49
- parser.add_argument('--weight_dacay', type=float, default=1e-4,
50
- help='Weigth decay, default=1e-4')
51
- parser.add_argument('--seed', type=int, default=2020, metavar='S',
52
- help='random seed (default: 1)')
53
- parser.add_argument('--CVMode', action='store_true', default=False,
54
- help='CV mode?')
55
- parser.add_argument('--valid_proport', type=float, default=0.1,
56
- help='Proportion of validation samples')
57
- parser.add_argument('--img_size', type=int, default=128, metavar='N')
58
- parser.add_argument('--min_label', type=float, default=0.0)
59
- parser.add_argument('--max_label', type=float, default=90.0)
60
- args = parser.parse_args()
61
-
62
-
63
- wd = args.root_path
64
- os.chdir(wd)
65
- from ..models_128 import *
66
- from .utils import IMGs_dataset
67
-
68
- # cuda
69
- device = torch.device("cuda")
70
- ngpu = torch.cuda.device_count() # number of gpus
71
-
72
- # random seed
73
- random.seed(args.seed)
74
- torch.manual_seed(args.seed)
75
- torch.backends.cudnn.deterministic = True
76
- np.random.seed(args.seed)
77
-
78
- # directories for checkpoint, images and log files
79
- save_models_folder = wd + '/output/eval_models/'
80
- os.makedirs(save_models_folder, exist_ok=True)
81
-
82
-
83
- # data loader
84
- data_filename = args.data_path + '/Ra_' + str(args.img_size) + 'x' + str(args.img_size) + '.h5'
85
- hf = h5py.File(data_filename, 'r')
86
- angles = hf['labels'][:]
87
- angles = angles.astype(float)
88
- labels = hf['types'][:]
89
- images = hf['images'][:]
90
- hf.close()
91
- num_classes = len(set(labels))
92
- assert max(labels)==num_classes-1
93
-
94
- q1 = args.min_label
95
- q2 = args.max_label
96
- indx = np.where((angles>q1)*(angles<q2)==True)[0]
97
- angles = angles[indx]
98
- labels = labels[indx]
99
- images = images[indx]
100
- assert len(labels)==len(images)
101
- assert len(angles)==len(images)
102
-
103
-
104
- # define training (and validaiton) set
105
- if args.CVMode:
106
- for i in range(num_classes):
107
- indx_i = np.where(labels==i)[0] #i-th class
108
- np.random.shuffle(indx_i)
109
- num_imgs_all_i = len(indx_i)
110
- num_imgs_valid_i = int(num_imgs_all_i*args.valid_proport)
111
- if i == 0:
112
- indx_valid = indx_i[0:num_imgs_valid_i]
113
- indx_train = indx_i[num_imgs_valid_i:]
114
- else:
115
- indx_valid = np.concatenate((indx_valid, indx_i[0:num_imgs_valid_i]))
116
- indx_train = np.concatenate((indx_train, indx_i[num_imgs_valid_i:]))
117
- #end for i
118
- trainset = IMGs_dataset(images[indx_train], labels[indx_train], normalize=True)
119
- trainloader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size_train, shuffle=True, num_workers=args.num_workers)
120
- validset = IMGs_dataset(images[indx_valid], labels[indx_valid], normalize=True)
121
- validloader = torch.utils.data.DataLoader(validset, batch_size=args.batch_size_valid, shuffle=False, num_workers=args.num_workers)
122
- else:
123
- trainset = IMGs_dataset(images, labels, normalize=True)
124
- trainloader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size_train, shuffle=True, num_workers=args.num_workers)
125
-
126
-
127
-
128
- ###########################################################################################################
129
- # Necessary functions
130
- ###########################################################################################################
131
-
132
#initialize CNNs
def net_initialization(Pretrained_CNN_Name, ngpu = 1, num_classes=num_classes):
    """Build the evaluation classifier selected by *Pretrained_CNN_Name*.

    Args:
        Pretrained_CNN_Name: one of 'ResNet18_class', 'ResNet34_class',
            'ResNet50_class', 'ResNet101_class'.
        ngpu: number of GPUs the model wrapper should use.
        num_classes: number of output classes (defaults to the module-level
            value computed from the dataset).

    Returns:
        (net, net_name): the model moved to `device`, and its display name.

    Raises:
        ValueError: if Pretrained_CNN_Name is not a supported architecture.
    """
    if Pretrained_CNN_Name == "ResNet18_class":
        net = ResNet18_class_eval(num_classes=num_classes, ngpu = ngpu)
    elif Pretrained_CNN_Name == "ResNet34_class":
        net = ResNet34_class_eval(num_classes=num_classes, ngpu = ngpu)
    elif Pretrained_CNN_Name == "ResNet50_class":
        net = ResNet50_class_eval(num_classes=num_classes, ngpu = ngpu)
    elif Pretrained_CNN_Name == "ResNet101_class":
        net = ResNet101_class_eval(num_classes=num_classes, ngpu = ngpu)
    else:
        # FIX: an unknown name previously fell through and crashed later with
        # an opaque UnboundLocalError on `net`; fail fast with a clear message.
        raise ValueError("Unsupported CNN name: {}".format(Pretrained_CNN_Name))

    net_name = 'PreCNNForEvalGANs_' + Pretrained_CNN_Name #get the net's name
    net = net.to(device)

    return net, net_name
147
-
148
#adjust CNN learning rate
def adjust_learning_rate(optimizer, epoch, BASE_LR_CNN):
    """Step-decay schedule: divide the learning rate by 10 at epochs 50 and 120."""
    # if epoch >= 35:
    #     lr /= 10
    # if epoch >= 70:
    #     lr /= 10
    lr = BASE_LR_CNN
    for milestone in (50, 120):
        if epoch >= milestone:
            lr /= 10
    for group in optimizer.param_groups:
        group['lr'] = lr
161
-
162
-
163
def train_CNN():
    """Train the classifier for args.epochs epochs with SGD + step-decayed LR.

    Relies on module-level globals: ``net``, ``optimizer``, ``criterion``,
    ``trainloader``, ``args`` and (in CV mode) ``valid_CNN``.

    Returns:
        (net, optimizer): the trained network and its optimizer.
    """

    start_tmp = timeit.default_timer()
    for epoch in range(args.epochs):
        net.train()
        train_loss = 0
        adjust_learning_rate(optimizer, epoch, args.base_lr)
        for batch_idx, (batch_train_images, batch_train_labels) in enumerate(trainloader):

            # batch_train_images = nn.functional.interpolate(batch_train_images, size = (299,299), scale_factor=None, mode='bilinear', align_corners=False)

            batch_train_images = batch_train_images.type(torch.float).cuda()
            batch_train_labels = batch_train_labels.type(torch.long).cuda()

            #Forward pass: the net returns (logits, features); only logits are used here
            outputs,_ = net(batch_train_images)
            loss = criterion(outputs, batch_train_labels)

            #backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_loss += loss.cpu().item()
        #end for batch_idx
        train_loss = train_loss / len(trainloader)

        # report per-epoch mean loss (plus validation accuracy in CV mode)
        if args.CVMode:
            valid_acc = valid_CNN(verbose=False)
            print('CNN: [epoch %d/%d] train_loss:%f valid_acc:%f Time: %.4f' % (epoch+1, args.epochs, train_loss, valid_acc, timeit.default_timer()-start_tmp))
        else:
            print('CNN: [epoch %d/%d] train_loss:%f Time: %.4f' % (epoch+1, args.epochs, train_loss, timeit.default_timer()-start_tmp))
    #end for epoch

    return net, optimizer
198
-
199
-
200
if args.CVMode:
    def valid_CNN(verbose=True):
        """Return top-1 accuracy (in percent) of `net` on the validation loader."""
        net.eval() # eval mode (batchnorm uses moving mean/variance instead of mini-batch mean/variance)
        n_correct = 0
        n_total = 0
        with torch.no_grad():
            for batch_imgs, batch_labels in validloader:
                batch_imgs = batch_imgs.type(torch.float).cuda()
                batch_labels = batch_labels.type(torch.long).cuda()
                logits, _ = net(batch_imgs)
                _, batch_preds = torch.max(logits.data, 1)
                n_total += batch_labels.size(0)
                n_correct += (batch_preds == batch_labels).sum().item()
        if verbose:
            print('Valid Accuracy of the model on the validation set: {} %'.format(100.0 * n_correct / n_total))
        return 100.0 * n_correct / n_total
216
-
217
-
218
-
219
###########################################################################################################
# Training and Testing
###########################################################################################################
# model initialization
net, net_name = net_initialization(args.CNN, ngpu = ngpu, num_classes = num_classes)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(net.parameters(), lr = args.base_lr, momentum= 0.9, weight_decay=args.weight_dacay)

filename_ckpt = save_models_folder + '/ckpt_{}_epoch_{}_seed_{}_classify_{}_chair_types_CVMode_{}.pth'.format(net_name, args.epochs, args.seed, num_classes, args.CVMode)

# training: train from scratch unless the final checkpoint already exists
if not os.path.isfile(filename_ckpt):
    # TRAIN CNN
    print("\n Begin training CNN: ")
    start = timeit.default_timer()
    net, optimizer = train_CNN()
    stop = timeit.default_timer()
    print("Time elapses: {}s".format(stop - start))
    # save model
    torch.save({
        'net_state_dict': net.state_dict(),
    }, filename_ckpt)
else:
    print("\n Ckpt already exists")
    print("\n Loading...")
    checkpoint = torch.load(filename_ckpt)
    net.load_state_dict(checkpoint['net_state_dict'])
torch.cuda.empty_cache() # release GPU mem that is no longer referenced

if args.CVMode:
    #validation
    _ = valid_CNN(True)
    torch.cuda.empty_cache()