Myosotis-Researches 0.1.8__py3-none-any.whl → 0.1.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (31) hide show
  1. myosotis_researches/CcGAN/train/__init__.py +4 -0
  2. myosotis_researches/CcGAN/{train_128_output_10 → train}/train_ccgan.py +4 -4
  3. myosotis_researches/CcGAN/{train_128_output_10 → train}/train_cgan.py +1 -3
  4. myosotis_researches/CcGAN/{train_128_output_10 → train}/train_cgan_concat.py +1 -3
  5. {myosotis_researches-0.1.8.dist-info → myosotis_researches-0.1.10.dist-info}/METADATA +1 -1
  6. myosotis_researches-0.1.10.dist-info/RECORD +40 -0
  7. myosotis_researches/CcGAN/train_128/DiffAugment_pytorch.py +0 -76
  8. myosotis_researches/CcGAN/train_128/__init__.py +0 -0
  9. myosotis_researches/CcGAN/train_128/eval_metrics.py +0 -205
  10. myosotis_researches/CcGAN/train_128/opts.py +0 -87
  11. myosotis_researches/CcGAN/train_128/pretrain_AE.py +0 -268
  12. myosotis_researches/CcGAN/train_128/pretrain_CNN_class.py +0 -251
  13. myosotis_researches/CcGAN/train_128/pretrain_CNN_regre.py +0 -255
  14. myosotis_researches/CcGAN/train_128/train_ccgan.py +0 -303
  15. myosotis_researches/CcGAN/train_128/train_cgan.py +0 -254
  16. myosotis_researches/CcGAN/train_128/train_cgan_concat.py +0 -242
  17. myosotis_researches/CcGAN/train_128/utils.py +0 -120
  18. myosotis_researches/CcGAN/train_128_output_10/DiffAugment_pytorch.py +0 -76
  19. myosotis_researches/CcGAN/train_128_output_10/__init__.py +0 -0
  20. myosotis_researches/CcGAN/train_128_output_10/eval_metrics.py +0 -205
  21. myosotis_researches/CcGAN/train_128_output_10/opts.py +0 -87
  22. myosotis_researches/CcGAN/train_128_output_10/pretrain_AE.py +0 -268
  23. myosotis_researches/CcGAN/train_128_output_10/pretrain_CNN_class.py +0 -251
  24. myosotis_researches/CcGAN/train_128_output_10/pretrain_CNN_regre.py +0 -255
  25. myosotis_researches/CcGAN/train_128_output_10/train_net_for_label_embed.py +0 -181
  26. myosotis_researches/CcGAN/train_128_output_10/utils.py +0 -120
  27. myosotis_researches-0.1.8.dist-info/RECORD +0 -59
  28. /myosotis_researches/CcGAN/{train_128 → train}/train_net_for_label_embed.py +0 -0
  29. {myosotis_researches-0.1.8.dist-info → myosotis_researches-0.1.10.dist-info}/WHEEL +0 -0
  30. {myosotis_researches-0.1.8.dist-info → myosotis_researches-0.1.10.dist-info}/licenses/LICENSE +0 -0
  31. {myosotis_researches-0.1.8.dist-info → myosotis_researches-0.1.10.dist-info}/top_level.txt +0 -0
@@ -1,251 +0,0 @@
1
- """
2
-
3
- Pre-train a CNN on the whole dataset for evaluation purpose
4
-
5
- """
6
- import os
7
- import argparse
8
- import shutil
9
- import timeit
10
-
11
- import torch
12
- import torchvision
13
- import torchvision.transforms as transforms
14
- import numpy as np
15
- import torch.nn as nn
16
- import torch.backends.cudnn as cudnn
17
- import random
18
- import matplotlib.pyplot as plt
19
- import matplotlib as mpl
20
- from torch import autograd
21
- from torchvision.utils import save_image
22
- import csv
23
- from tqdm import tqdm
24
- import gc
25
- import h5py
26
-
27
- from models import *
28
- from utils import IMGs_dataset
29
-
30
-
31
- #############################
32
- # Settings
33
- #############################
34
-
35
# Command-line options for the classification pre-training script.
# Help texts now state the defaults that are actually configured.
parser = argparse.ArgumentParser(description='Pre-train CNNs')
parser.add_argument('--root_path', type=str, default='')
parser.add_argument('--data_path', type=str, default='')
parser.add_argument('--num_workers', type=int, default=0)
parser.add_argument('--CNN', type=str, default='ResNet34_class',
                    help='CNN for training; ResNetXX')
parser.add_argument('--epochs', type=int, default=200, metavar='N',
                    help='number of epochs to train CNNs (default: 200)')
parser.add_argument('--batch_size_train', type=int, default=128, metavar='N',
                    help='input batch size for training')
parser.add_argument('--batch_size_valid', type=int, default=10, metavar='N',
                    help='input batch size for testing')
parser.add_argument('--base_lr', type=float, default=0.01,
                    help='learning rate, default=0.01')
# NOTE: the flag keeps its historical spelling because the rest of the
# script reads ``args.weight_dacay``.
parser.add_argument('--weight_dacay', type=float, default=1e-4,
                    help='Weight decay, default=1e-4')
parser.add_argument('--seed', type=int, default=2020, metavar='S',
                    help='random seed (default: 2020)')
parser.add_argument('--CVMode', action='store_true', default=False,
                    help='CV mode?')
parser.add_argument('--valid_proport', type=float, default=0.1,
                    help='Proportion of validation samples')
parser.add_argument('--img_size', type=int, default=128, metavar='N')
parser.add_argument('--min_label', type=float, default=0.0)
parser.add_argument('--max_label', type=float, default=90.0)
args = parser.parse_args()
61
-
62
-
63
# Working directory. ``--root_path`` defaults to the empty string and
# ``os.chdir('')`` raises FileNotFoundError, so fall back to the current
# directory when no root path was supplied.
wd = args.root_path or os.getcwd()
os.chdir(wd)
from ..models_128 import *
from .utils import IMGs_dataset

# cuda
device = torch.device("cuda")
ngpu = torch.cuda.device_count()  # number of gpus

# random seed: seed every generator the script draws from
random.seed(args.seed)
torch.manual_seed(args.seed)
torch.backends.cudnn.deterministic = True
np.random.seed(args.seed)

# directory that receives the trained checkpoint
save_models_folder = wd + '/output/eval_models/'
os.makedirs(save_models_folder, exist_ok=True)
81
-
82
-
83
- # data loader
84
# Read the whole dataset into memory. The context manager closes the HDF5
# file even when one of the reads raises.
data_filename = args.data_path + '/Ra_' + str(args.img_size) + 'x' + str(args.img_size) + '.h5'
with h5py.File(data_filename, 'r') as hf:
    angles = hf['labels'][:].astype(float)
    labels = hf['types'][:]
    images = hf['images'][:]

# The class count is taken BEFORE the angle filter below, so it reflects every
# type present in the file; labels are expected to be 0..num_classes-1.
num_classes = len(set(labels))
assert max(labels) == num_classes - 1

# Keep only samples whose angle lies strictly between min_label and max_label.
q1 = args.min_label
q2 = args.max_label
indx = np.where((angles > q1) & (angles < q2))[0]
angles = angles[indx]
labels = labels[indx]
images = images[indx]
assert len(labels) == len(images)
assert len(angles) == len(images)


# define training (and validation) set
if args.CVMode:
    # Per-class split: shuffle each class's indices and send the first
    # ``valid_proport`` fraction to validation, the remainder to training.
    valid_parts = []
    train_parts = []
    for i in range(num_classes):
        indx_i = np.where(labels == i)[0]  # i-th class
        np.random.shuffle(indx_i)
        num_imgs_valid_i = int(len(indx_i) * args.valid_proport)
        valid_parts.append(indx_i[:num_imgs_valid_i])
        train_parts.append(indx_i[num_imgs_valid_i:])
    indx_valid = np.concatenate(valid_parts)
    indx_train = np.concatenate(train_parts)

    trainset = IMGs_dataset(images[indx_train], labels[indx_train], normalize=True)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size_train, shuffle=True, num_workers=args.num_workers)
    validset = IMGs_dataset(images[indx_valid], labels[indx_valid], normalize=True)
    validloader = torch.utils.data.DataLoader(validset, batch_size=args.batch_size_valid, shuffle=False, num_workers=args.num_workers)
else:
    # No hold-out: train on everything that survived the angle filter.
    trainset = IMGs_dataset(images, labels, normalize=True)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size_train, shuffle=True, num_workers=args.num_workers)
125
-
126
-
127
-
128
- ###########################################################################################################
129
- # Necessary functions
130
- ###########################################################################################################
131
-
132
- #initialize CNNs
133
def net_initialization(Pretrained_CNN_Name, ngpu = 1, num_classes=num_classes):
    """Build the requested classification ResNet and move it to ``device``.

    Args:
        Pretrained_CNN_Name: one of "ResNet18_class", "ResNet34_class",
            "ResNet50_class" or "ResNet101_class".
        ngpu: forwarded to the model constructor.
        num_classes: forwarded to the model constructor.

    Returns:
        ``(net, net_name)`` where ``net_name`` is
        ``'PreCNNForEvalGANs_' + Pretrained_CNN_Name``.

    Raises:
        ValueError: if ``Pretrained_CNN_Name`` is not a supported name
            (previously this surfaced as an UnboundLocalError on ``net``).
    """
    if Pretrained_CNN_Name == "ResNet18_class":
        net = ResNet18_class_eval(num_classes=num_classes, ngpu = ngpu)
    elif Pretrained_CNN_Name == "ResNet34_class":
        net = ResNet34_class_eval(num_classes=num_classes, ngpu = ngpu)
    elif Pretrained_CNN_Name == "ResNet50_class":
        net = ResNet50_class_eval(num_classes=num_classes, ngpu = ngpu)
    elif Pretrained_CNN_Name == "ResNet101_class":
        net = ResNet101_class_eval(num_classes=num_classes, ngpu = ngpu)
    else:
        raise ValueError(
            "Unsupported CNN '{}'; expected one of ResNet18_class, "
            "ResNet34_class, ResNet50_class, ResNet101_class".format(Pretrained_CNN_Name)
        )

    net_name = 'PreCNNForEvalGANs_' + Pretrained_CNN_Name  # get the net's name
    net = net.to(device)

    return net, net_name
147
-
148
- #adjust CNN learning rate
149
def adjust_learning_rate(optimizer, epoch, BASE_LR_CNN):
    """Apply the step schedule to every parameter group of ``optimizer``.

    The rate starts at ``BASE_LR_CNN`` and is divided by 10 once ``epoch``
    reaches 50 and again once it reaches 120.
    """
    current_lr = BASE_LR_CNN
    for milestone in (50, 120):
        if epoch >= milestone:
            current_lr /= 10
    for group in optimizer.param_groups:
        group['lr'] = current_lr
161
-
162
-
163
def train_CNN():
    """Train the module-level ``net`` for ``args.epochs`` epochs.

    Uses the module-level ``optimizer``, ``criterion`` and ``trainloader``.
    After each epoch the mean batch loss is printed; in CV mode the
    validation accuracy from ``valid_CNN`` is printed as well.

    Returns:
        The ``(net, optimizer)`` pair after training.
    """
    t_begin = timeit.default_timer()
    for epoch in range(args.epochs):
        net.train()
        adjust_learning_rate(optimizer, epoch, args.base_lr)

        loss_sum = 0
        for imgs, targets in trainloader:
            imgs = imgs.type(torch.float).cuda()
            targets = targets.type(torch.long).cuda()

            # forward pass: the net returns a tuple, only the first item is used
            logits, _ = net(imgs)
            loss = criterion(logits, targets)

            # backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            loss_sum += loss.cpu().item()
        train_loss = loss_sum / len(trainloader)

        if not args.CVMode:
            print('CNN: [epoch %d/%d] train_loss:%f Time: %.4f' % (epoch+1, args.epochs, train_loss, timeit.default_timer()-t_begin))
        else:
            valid_acc = valid_CNN(verbose=False)
            print('CNN: [epoch %d/%d] train_loss:%f valid_acc:%f Time: %.4f' % (epoch+1, args.epochs, train_loss, valid_acc, timeit.default_timer()-t_begin))

    return net, optimizer
198
-
199
-
200
if args.CVMode:
    def valid_CNN(verbose=True):
        """Return the accuracy (in percent) of ``net`` on ``validloader``.

        Only defined when the script runs in CV mode. When ``verbose`` is
        true the accuracy is also printed.
        """
        net.eval()  # batchnorm uses running statistics instead of batch statistics
        n_correct = 0
        n_seen = 0
        with torch.no_grad():
            for batch_images, batch_labels in validloader:
                batch_images = batch_images.type(torch.float).cuda()
                batch_labels = batch_labels.type(torch.long).cuda()
                logits, _ = net(batch_images)
                _, predicted = torch.max(logits.data, 1)
                n_seen += batch_labels.size(0)
                n_correct += (predicted == batch_labels).sum().item()
        if verbose:
            print('Valid Accuracy of the model on the validation set: {} %'.format(100.0 * n_correct / n_seen))
        return 100.0 * n_correct / n_seen
216
-
217
-
218
-
219
- ###########################################################################################################
220
- # Training and Testing
221
- ###########################################################################################################
222
- # model initialization
223
# Build the classifier together with its loss and optimizer.
net, net_name = net_initialization(args.CNN, ngpu=ngpu, num_classes=num_classes)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(
    net.parameters(),
    lr=args.base_lr,
    momentum=0.9,
    weight_decay=args.weight_dacay,
)

filename_ckpt = save_models_folder + '/ckpt_{}_epoch_{}_seed_{}_classify_{}_chair_types_CVMode_{}.pth'.format(
    net_name, args.epochs, args.seed, num_classes, args.CVMode
)

# Reuse an existing checkpoint when there is one; otherwise train and save.
if os.path.isfile(filename_ckpt):
    print("\n Ckpt already exists")
    print("\n Loading...")
    checkpoint = torch.load(filename_ckpt)
    net.load_state_dict(checkpoint['net_state_dict'])
else:
    print("\n Begin training CNN: ")
    start = timeit.default_timer()
    net, optimizer = train_CNN()
    stop = timeit.default_timer()
    print("Time elapses: {}s".format(stop - start))
    # only the network weights are stored
    torch.save({'net_state_dict': net.state_dict()}, filename_ckpt)
torch.cuda.empty_cache()  # release GPU memory that is no longer referenced

if args.CVMode:
    # final validation pass
    _ = valid_CNN(True)
    torch.cuda.empty_cache()
@@ -1,255 +0,0 @@
1
- """
2
-
3
- Pre-train a CNN on the whole dataset for evaluation purpose
4
-
5
- """
6
- import os
7
- import argparse
8
- import shutil
9
- import timeit
10
- import torch
11
- import torchvision
12
- import torchvision.transforms as transforms
13
- import numpy as np
14
- import torch.nn as nn
15
- import torch.backends.cudnn as cudnn
16
- import random
17
- import matplotlib.pyplot as plt
18
- import matplotlib as mpl
19
- from torch import autograd
20
- from torchvision.utils import save_image
21
- import csv
22
- from tqdm import tqdm
23
- import gc
24
- import h5py
25
-
26
-
27
- #############################
28
- # Settings
29
- #############################
30
-
31
# Command-line options for the regression pre-training script.
# Help texts now state the defaults that are actually configured.
parser = argparse.ArgumentParser(description='Pre-train CNNs')
parser.add_argument('--root_path', type=str, default='')
parser.add_argument('--data_path', type=str, default='')
parser.add_argument('--num_workers', type=int, default=0)
parser.add_argument('--CNN', type=str, default='ResNet34_regre',
                    help='CNN for training; ResNetXX')
parser.add_argument('--epochs', type=int, default=200, metavar='N',
                    help='number of epochs to train CNNs (default: 200)')
parser.add_argument('--batch_size_train', type=int, default=256, metavar='N',
                    help='input batch size for training')
parser.add_argument('--batch_size_valid', type=int, default=10, metavar='N',
                    help='input batch size for testing')
parser.add_argument('--base_lr', type=float, default=0.01,
                    help='learning rate, default=0.01')
# NOTE: the flag keeps its historical spelling because the rest of the
# script reads ``args.weight_dacay``.
parser.add_argument('--weight_dacay', type=float, default=1e-4,
                    help='Weight decay, default=1e-4')
parser.add_argument('--seed', type=int, default=2020, metavar='S',
                    help='random seed (default: 2020)')
parser.add_argument('--CVMode', action='store_true', default=False,
                    help='CV mode?')
parser.add_argument('--img_size', type=int, default=128, metavar='N')
parser.add_argument('--min_label', type=float, default=0.0)
parser.add_argument('--max_label', type=float, default=90.0)
args = parser.parse_args()
55
-
56
-
57
# Working directory. ``--root_path`` defaults to the empty string and
# ``os.chdir('')`` raises FileNotFoundError, so fall back to the current
# directory when no root path was supplied.
wd = args.root_path or os.getcwd()
os.chdir(wd)
from ..models_128 import *
from .utils import IMGs_dataset

# cuda
device = torch.device("cuda")
ngpu = torch.cuda.device_count()  # number of gpus

# random seed: seed every generator the script draws from
random.seed(args.seed)
torch.manual_seed(args.seed)
torch.backends.cudnn.deterministic = True
np.random.seed(args.seed)

# directory that receives the trained checkpoint
save_models_folder = wd + '/output/eval_models/'
os.makedirs(save_models_folder, exist_ok=True)
75
-
76
-
77
- ###########################################################################################################
78
- # Data
79
- ###########################################################################################################
80
- # data loader
81
# Read the whole dataset into memory. The context manager closes the HDF5
# file even when one of the reads raises.
data_filename = args.data_path + '/Ra_' + str(args.img_size) + 'x' + str(args.img_size) + '.h5'
with h5py.File(data_filename, 'r') as hf:
    labels = hf['labels'][:].astype(float)
    images = hf['images'][:]
N_all = len(images)
assert len(images) == len(labels)

# Keep only samples whose label lies strictly between min_label and max_label.
q1 = args.min_label
q2 = args.max_label
indx = np.where((labels > q1) & (labels < q2))[0]
labels = labels[indx]
images = images[indx]
assert len(labels) == len(images)

# Scale the regression targets by max_label; valid_CNN multiplies by the same
# value to report errors in the original units.
labels /= args.max_label


# define training and validation sets
if args.CVMode:
    # 80% training; 20% validation, drawn from one shuffled index array
    indx_all = np.arange(len(labels))
    np.random.shuffle(indx_all)
    num_valid = int(0.2 * len(labels))
    indx_valid = indx_all[:num_valid]
    indx_train = indx_all[num_valid:]

    trainset = IMGs_dataset(images[indx_train], labels[indx_train], normalize=True)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size_train, shuffle=True, num_workers=args.num_workers)
    validset = IMGs_dataset(images[indx_valid], labels[indx_valid], normalize=True)
    validloader = torch.utils.data.DataLoader(validset, batch_size=args.batch_size_valid, shuffle=False, num_workers=args.num_workers)
else:
    # No hold-out: train on everything that survived the label filter.
    trainset = IMGs_dataset(images, labels, normalize=True)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size_train, shuffle=True, num_workers=args.num_workers)
122
-
123
-
124
-
125
-
126
- ###########################################################################################################
127
- # Necessary functions
128
- ###########################################################################################################
129
-
130
- #initialize CNNs
131
def net_initialization(Pretrained_CNN_Name, ngpu = 1):
    """Build the requested regression ResNet and move it to ``device``.

    Args:
        Pretrained_CNN_Name: one of "ResNet18_regre", "ResNet34_regre",
            "ResNet50_regre" or "ResNet101_regre".
        ngpu: forwarded to the model constructor.

    Returns:
        ``(net, net_name)`` where ``net_name`` is
        ``'PreCNNForEvalGANs_' + Pretrained_CNN_Name``.

    Raises:
        ValueError: if ``Pretrained_CNN_Name`` is not a supported name
            (previously this surfaced as an UnboundLocalError on ``net``).
    """
    if Pretrained_CNN_Name == "ResNet18_regre":
        net = ResNet18_regre_eval(ngpu = ngpu)
    elif Pretrained_CNN_Name == "ResNet34_regre":
        net = ResNet34_regre_eval(ngpu = ngpu)
    elif Pretrained_CNN_Name == "ResNet50_regre":
        net = ResNet50_regre_eval(ngpu = ngpu)
    elif Pretrained_CNN_Name == "ResNet101_regre":
        net = ResNet101_regre_eval(ngpu = ngpu)
    else:
        raise ValueError(
            "Unsupported CNN '{}'; expected one of ResNet18_regre, "
            "ResNet34_regre, ResNet50_regre, ResNet101_regre".format(Pretrained_CNN_Name)
        )

    net_name = 'PreCNNForEvalGANs_' + Pretrained_CNN_Name  # get the net's name
    net = net.to(device)

    return net, net_name
145
-
146
- #adjust CNN learning rate
147
def adjust_learning_rate(optimizer, epoch, BASE_LR_CNN):
    """Write the scheduled learning rate into all of ``optimizer``'s groups.

    The base rate is divided by 10 from epoch 50 onward and by a further 10
    from epoch 120 onward.
    """
    lr = BASE_LR_CNN
    if epoch >= 50:
        lr /= 10
        # the second decay can only apply once the first one has
        if epoch >= 120:
            lr /= 10
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
159
-
160
-
161
def train_CNN():
    """Train the module-level regression ``net`` for ``args.epochs`` epochs.

    Uses the module-level ``optimizer``, ``criterion`` and ``trainloader``.
    After each epoch the mean batch loss is printed; in CV mode the value
    returned by ``valid_CNN`` is printed as well.

    Returns:
        The ``(net, optimizer)`` pair after training.
    """
    t_begin = timeit.default_timer()
    for epoch in range(args.epochs):
        net.train()
        adjust_learning_rate(optimizer, epoch, args.base_lr)

        loss_sum = 0
        for imgs, targets in trainloader:
            imgs = imgs.type(torch.float).cuda()
            # targets become a column vector before being compared to the output
            targets = targets.type(torch.float).view(-1,1).cuda()

            # forward pass: the net returns a tuple, only the first item is used
            preds, _ = net(imgs)
            loss = criterion(preds, targets)

            # backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            loss_sum += loss.cpu().item()
        train_loss = loss_sum / len(trainloader)

        if not args.CVMode:
            print('CNN: [epoch %d/%d] train_loss:%f Time:%.4f' % (epoch+1, args.epochs, train_loss, timeit.default_timer()-t_begin))
        else:
            valid_loss = valid_CNN(verbose=False)
            print('CNN: [epoch %d/%d] train_loss:%f valid_loss (avg_abs):%f Time: %.4f' % (epoch+1, args.epochs, train_loss, valid_loss, timeit.default_timer()-t_begin))

    return net, optimizer
196
-
197
if args.CVMode:
    def valid_CNN(verbose=True):
        """Return the mean absolute error of ``net`` on ``validloader``.

        Labels and predictions are both multiplied by ``args.max_label``
        before the difference is taken. Only defined in CV mode; when
        ``verbose`` is true the result is also printed.
        """
        net.eval()  # batchnorm uses running statistics instead of batch statistics
        abs_err_sum = 0
        n_seen = 0
        with torch.no_grad():
            for batch_images, batch_labels in validloader:
                batch_images = batch_images.type(torch.float).cuda()
                batch_labels = batch_labels.type(torch.float).view(-1).cpu().numpy()
                preds, _ = net(batch_images)
                preds = preds.view(-1).cpu().numpy()
                batch_labels = batch_labels * args.max_label
                preds = preds * args.max_label
                abs_err_sum += np.sum(np.abs(batch_labels-preds))
                n_seen += len(batch_labels)

        if verbose:
            print('Validation Average Absolute Difference: {}'.format(abs_err_sum/n_seen))
        return abs_err_sum/n_seen
216
-
217
-
218
-
219
- ###########################################################################################################
220
- # Training and validation
221
- ###########################################################################################################
222
-
223
-
224
- # model initialization
225
# Build the regressor together with its loss and optimizer.
net, net_name = net_initialization(args.CNN, ngpu=ngpu)
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(
    net.parameters(),
    lr=args.base_lr,
    momentum=0.9,
    weight_decay=args.weight_dacay,
)

filename_ckpt = save_models_folder + '/ckpt_{}_epoch_{}_seed_{}_CVMode_{}.pth'.format(
    net_name, args.epochs, args.seed, args.CVMode
)

# Reuse an existing checkpoint when there is one; otherwise train and save.
if os.path.isfile(filename_ckpt):
    print("\n Ckpt already exists")
    print("\n Loading...")
    checkpoint = torch.load(filename_ckpt)
    net.load_state_dict(checkpoint['net_state_dict'])
else:
    print("\n Begin training CNN: ")
    start = timeit.default_timer()
    net, optimizer = train_CNN()
    stop = timeit.default_timer()
    print("Time elapses: {}s".format(stop - start))
    # only the network weights are stored
    torch.save({'net_state_dict': net.state_dict()}, filename_ckpt)
torch.cuda.empty_cache()  # release GPU memory that is no longer referenced

if args.CVMode:
    # final validation pass
    _ = valid_CNN(True)
    torch.cuda.empty_cache()
@@ -1,181 +0,0 @@
1
-
2
- import torch
3
- import torch.nn as nn
4
- from torchvision.utils import save_image
5
- import numpy as np
6
- import os
7
- import timeit
8
- from PIL import Image
9
-
10
-
11
-
12
-
13
- #-------------------------------------------------------------
14
def train_net_embed(net, net_name, trainloader, testloader, epochs=200, resume_epoch = 0, lr_base=0.01, lr_decay_factor=0.1, lr_decay_epochs=[80, 140], weight_decay=1e-4, path_to_ckpt = None):
    """Train ``net`` with an MSE loss and SGD (momentum 0.9) on CUDA.

    Args:
        net: model whose call returns a tuple; the first item is compared
            with the labels.
        net_name: not used inside this function.
        trainloader: iterable of ``(images, labels)`` batches.
        testloader: optional iterable of ``(images, labels)`` batches; when
            given, a mean test loss is printed after every epoch.
        epochs: index of the last epoch (exclusive).
        resume_epoch: first epoch to run; if positive and ``path_to_ckpt``
            is set, the checkpoint of that epoch is loaded first.
        lr_base, lr_decay_factor, lr_decay_epochs: the rate starts at
            ``lr_base`` and is multiplied by ``lr_decay_factor`` once for
            every entry of ``lr_decay_epochs`` the epoch has reached.
        weight_decay: forwarded to SGD.
        path_to_ckpt: directory for checkpoints; ``None`` disables both
            loading and saving.

    Returns:
        The trained ``net``.
    """
    ckpt_pattern = "/embed_x2y_ckpt_in_train/embed_x2y_checkpoint_epoch_{}.pth"

    def adjust_learning_rate_1(optimizer, epoch):
        """Set the step-decayed learning rate on every parameter group."""
        lr = lr_base
        for milestone in lr_decay_epochs:
            if epoch >= milestone:
                lr = lr * lr_decay_factor
        for group in optimizer.param_groups:
            group['lr'] = lr

    net = net.cuda()
    criterion = nn.MSELoss()
    optimizer = torch.optim.SGD(net.parameters(), lr=lr_base, momentum=0.9, weight_decay=weight_decay)

    # resume training: restore weights, optimizer state and the torch RNG state
    if path_to_ckpt is not None and resume_epoch > 0:
        save_file = path_to_ckpt + ckpt_pattern.format(resume_epoch)
        state = torch.load(save_file)
        net.load_state_dict(state['net_state_dict'])
        optimizer.load_state_dict(state['optimizer_state_dict'])
        torch.set_rng_state(state['rng_state'])

    t_begin = timeit.default_timer()
    for epoch in range(resume_epoch, epochs):
        net.train()
        adjust_learning_rate_1(optimizer, epoch)

        loss_sum = 0
        for imgs, targets in trainloader:
            imgs = imgs.type(torch.float).cuda()
            targets = targets.type(torch.float).view(-1,1).cuda()

            # forward pass
            preds, _ = net(imgs)
            loss = criterion(preds, targets)

            # backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            loss_sum += loss.cpu().item()
        train_loss = loss_sum / len(trainloader)

        if testloader is not None:
            net.eval()  # batchnorm uses running statistics instead of batch statistics
            with torch.no_grad():
                test_loss = 0
                for imgs, targets in testloader:
                    imgs = imgs.type(torch.float).cuda()
                    targets = targets.type(torch.float).view(-1,1).cuda()
                    preds, _ = net(imgs)
                    test_loss += criterion(preds, targets).cpu().item()
                test_loss = test_loss/len(testloader)

            print('Train net_x2y for label embedding: [epoch %d/%d] train_loss:%f test_loss:%f Time:%.4f' % (epoch+1, epochs, train_loss, test_loss, timeit.default_timer()-t_begin))
        else:
            print('Train net_x2y for embedding: [epoch %d/%d] train_loss:%f Time:%.4f' % (epoch+1, epochs, train_loss, timeit.default_timer()-t_begin))

        # save a checkpoint every 50 epochs and after the final epoch
        at_save_point = ((epoch+1) % 50 == 0) or (epoch+1 == epochs)
        if path_to_ckpt is not None and at_save_point:
            save_file = path_to_ckpt + ckpt_pattern.format(epoch+1)
            os.makedirs(os.path.dirname(save_file), exist_ok=True)
            torch.save({
                'epoch': epoch,
                'net_state_dict': net.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'rng_state': torch.get_rng_state()
            }, save_file)

    return net
97
-
98
-
99
-
100
-
101
- ###################################################################################
102
class label_dataset(torch.utils.data.Dataset):
    """Map-style dataset that yields the entries of ``labels`` one by one."""

    def __init__(self, labels):
        super().__init__()
        self.labels = labels
        # the length is captured once, at construction time
        self.n_samples = len(labels)

    def __len__(self):
        return self.n_samples

    def __getitem__(self, index):
        return self.labels[index]
116
-
117
-
118
def train_net_y2h(unique_labels_norm, net_y2h, net_embed, epochs=500, lr_base=0.01, lr_decay_factor=0.1, lr_decay_epochs=[150, 250, 350], weight_decay=1e-4, batch_size=128):
    """Train ``net_y2h`` so that ``h2y(net_y2h(y))`` reconstructs noisy labels.

    Args:
        unique_labels_norm: an array of normalized unique labels; every
            value must lie in [0, 1].
        net_y2h: network being trained; only its parameters are optimized.
        net_embed: embedding network that provides the ``h2y`` sub-network.
            It may be wrapped (exposing the model as ``.module``, as the
            original code required) or passed unwrapped.
        epochs: number of training epochs.
        lr_base, lr_decay_factor, lr_decay_epochs: the rate starts at
            ``lr_base`` and is multiplied by ``lr_decay_factor`` once for
            every entry of ``lr_decay_epochs`` the epoch has reached.
        weight_decay: forwarded to SGD.
        batch_size: number of labels per batch.

    Returns:
        The trained ``net_y2h``.

    Raises:
        AssertionError: if a label falls outside [0, 1].
    """
    def adjust_learning_rate_2(optimizer, epoch):
        """Set the step-decayed learning rate on every parameter group."""
        lr = lr_base
        for milestone in lr_decay_epochs:
            if epoch >= milestone:
                lr = lr * lr_decay_factor
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

    assert np.max(unique_labels_norm)<=1 and np.min(unique_labels_norm)>=0
    trainset = label_dataset(unique_labels_norm)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True)

    net_embed.eval()
    # Accept both a wrapped model (``.module``) and a bare one; previously a
    # bare model failed with AttributeError.
    embed_core = net_embed.module if hasattr(net_embed, 'module') else net_embed
    net_h2y = embed_core.h2y  # convert embedding labels to original labels
    optimizer_y2h = torch.optim.SGD(net_y2h.parameters(), lr = lr_base, momentum= 0.9, weight_decay=weight_decay)
    criterion = nn.MSELoss()  # built once instead of once per batch

    start_tmp = timeit.default_timer()
    for epoch in range(epochs):
        net_y2h.train()
        train_loss = 0
        adjust_learning_rate_2(optimizer_y2h, epoch)
        for batch_labels in trainloader:

            batch_labels = batch_labels.type(torch.float).view(-1,1).cuda()

            # generate noises which will be added to labels
            batch_size_curr = len(batch_labels)
            batch_gamma = np.random.normal(0, 0.2, batch_size_curr)
            batch_gamma = torch.from_numpy(batch_gamma).view(-1,1).type(torch.float).cuda()

            # add noise to labels, keeping them inside [0, 1]
            batch_labels_noise = torch.clamp(batch_labels+batch_gamma, 0.0, 1.0)

            # forward pass: label -> hidden -> reconstructed label
            batch_hiddens_noise = net_y2h(batch_labels_noise)
            batch_rec_labels_noise = net_h2y(batch_hiddens_noise)

            loss = criterion(batch_rec_labels_noise, batch_labels_noise)

            # backward pass
            optimizer_y2h.zero_grad()
            loss.backward()
            optimizer_y2h.step()

            train_loss += loss.cpu().item()
        train_loss = train_loss / len(trainloader)

        print('\n Train net_y2h: [epoch %d/%d] train_loss:%f Time:%.4f' % (epoch+1, epochs, train_loss, timeit.default_timer()-start_tmp))

    return net_y2h