spacr 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- spacr/__init__.py +37 -0
- spacr/__main__.py +15 -0
- spacr/annotate_app.py +495 -0
- spacr/cli.py +203 -0
- spacr/core.py +2250 -0
- spacr/gui_mask_app.py +247 -0
- spacr/gui_measure_app.py +214 -0
- spacr/gui_utils.py +488 -0
- spacr/io.py +2271 -0
- spacr/logger.py +20 -0
- spacr/mask_app.py +818 -0
- spacr/measure.py +1014 -0
- spacr/old_code.py +104 -0
- spacr/plot.py +1273 -0
- spacr/sim.py +1187 -0
- spacr/timelapse.py +576 -0
- spacr/train.py +494 -0
- spacr/umap.py +689 -0
- spacr/utils.py +2726 -0
- spacr/version.py +19 -0
- spacr-0.0.1.dist-info/LICENSE +21 -0
- spacr-0.0.1.dist-info/METADATA +64 -0
- spacr-0.0.1.dist-info/RECORD +26 -0
- spacr-0.0.1.dist-info/WHEEL +5 -0
- spacr-0.0.1.dist-info/entry_points.txt +5 -0
- spacr-0.0.1.dist-info/top_level.txt +1 -0
spacr/train.py
ADDED
@@ -0,0 +1,494 @@
import os, torch, time, gc, datetime
import pandas as pd
from torch.optim import Adagrad, AdamW
from torch.autograd import grad
from torch.optim.lr_scheduler import StepLR
import torch.nn.functional as F
from IPython.display import display, clear_output

from .logger import log_function_call

def evaluate_model_core(model, loader, loader_name, epoch, loss_type):
    """
    Evaluates the performance of a model on a given data loader.

    Args:
        model (torch.nn.Module): The model to evaluate.
        loader (torch.utils.data.DataLoader): The data loader to evaluate the model on.
        loader_name (str): The name of the data loader.
        epoch (int): The current epoch number.
        loss_type (str): The type of loss function to use.

    Returns:
        data_df (pandas.DataFrame): The classification metrics data as a DataFrame.
        prediction_pos_probs (list): The positive class probabilities for each prediction.
        all_labels (list): The true labels for each prediction.
    """

    from .utils import calculate_loss, classification_metrics

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model.eval()
    loss = 0.0
    correct = 0
    total_samples = 0
    prediction_pos_probs = []
    all_labels = []
    model = model.to(device)
    with torch.no_grad():
        for batch_idx, (data, target, _) in enumerate(loader, start=1):
            start_time = time.time()
            data, target = data.to(device), target.to(device).float()
            output = model(data)
            # Accumulate the scalar loss; the original overwrote the running
            # total with the batch-loss tensor and then added it to itself.
            batch_loss = calculate_loss(output, target, loss_type=loss_type)
            loss += batch_loss.item()
            total_samples += data.size(0)
            # Threshold the sigmoid probabilities, not the raw logits, at 0.5.
            pred = (torch.sigmoid(output) >= 0.5).float()
            correct += pred.eq(target.view_as(pred)).sum().item()
            batch_prediction_pos_prob = torch.sigmoid(output).cpu().numpy()
            prediction_pos_probs.extend(batch_prediction_pos_prob.tolist())
            all_labels.extend(target.cpu().numpy().tolist())
            mean_loss = loss / total_samples
            acc = correct / total_samples
            test_time = time.time() - start_time
            print(f'\rTest: epoch: {epoch} Accuracy: {acc:.5f} batch: {batch_idx}/{len(loader)} loss: {mean_loss:.5f} time {test_time:.5f}', end='\r', flush=True)
    loss /= len(loader)
    data_df = classification_metrics(all_labels, prediction_pos_probs, loader_name, loss, epoch)
    return data_df, prediction_pos_probs, all_labels

def evaluate_model_performance(loaders, model, loader_name_list, epoch, train_mode, loss_type):
    """
    Evaluate the performance of a model on the given data loaders.

    Args:
        loaders (list): List of data loaders.
        model: The model to evaluate.
        loader_name_list (list): List of names for the data loaders.
        epoch (int): The current epoch.
        train_mode (str): The training mode ('erm' or 'irm').
        loss_type: The type of loss function.

    Returns:
        tuple: A tuple containing the evaluation result and the time taken for evaluation.
    """
    start_time = time.time()
    df_list = []
    if train_mode == 'erm':
        result, _, _ = evaluate_model_core(model, loaders, loader_name_list, epoch, loss_type)
    elif train_mode == 'irm':
        for loader_index in range(len(loaders)):
            loader = loaders[loader_index]
            loader_name = loader_name_list[loader_index]
            data_df, _, _ = evaluate_model_core(model, loader, loader_name, epoch, loss_type)
            torch.cuda.empty_cache()
            df_list.append(data_df)
        result = pd.concat(df_list)
        nc_mean = result['neg_accuracy'].mean(skipna=True)
        pc_mean = result['pos_accuracy'].mean(skipna=True)
        tot_mean = result['accuracy'].mean(skipna=True)
        loss_mean = result['loss'].mean(skipna=True)
        prauc_mean = result['prauc'].mean(skipna=True)
        data_mean = {'accuracy': tot_mean, 'neg_accuracy': nc_mean, 'pos_accuracy': pc_mean, 'loss': loss_mean, 'prauc': prauc_mean}
        result = pd.concat([pd.DataFrame(result), pd.DataFrame(data_mean, index=[str(epoch) + '_mean'])])
    test_time = time.time() - start_time
    return result, test_time

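# Illustrative usage sketch (the loader names here are placeholders, not spacr
# objects): in 'erm' mode `loaders` is a single DataLoader and
# `loader_name_list` a single name; in 'irm' mode both are parallel lists,
# one entry per environment/plate.
#
#   result, t = evaluate_model_performance(train_loader, model, 'train',
#                                          epoch=1, train_mode='erm',
#                                          loss_type='binary_cross_entropy_with_logits')
#   result, t = evaluate_model_performance(plate_loaders, model,
#                                          ['plate1', 'plate2'], epoch=1,
#                                          train_mode='irm',
#                                          loss_type='binary_cross_entropy_with_logits')
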
def test_model_core(model, loader, loader_name, epoch, loss_type):
    """
    Like evaluate_model_core, but additionally collects per-image filenames,
    true labels, and predictions into a results DataFrame.
    """
    from .utils import calculate_loss, classification_metrics

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model.eval()
    loss = 0.0
    correct = 0
    total_samples = 0
    prediction_pos_probs = []
    all_labels = []
    filenames = []
    true_targets = []
    predicted_outputs = []

    model = model.to(device)
    with torch.no_grad():
        for batch_idx, (data, target, filename) in enumerate(loader, start=1):  # Assuming loader provides filenames
            start_time = time.time()
            data, target = data.to(device), target.to(device).float()
            output = model(data)
            # Accumulate the scalar loss (same fix as in evaluate_model_core).
            batch_loss = calculate_loss(output, target, loss_type=loss_type)
            loss += batch_loss.item()
            total_samples += data.size(0)
            # Threshold the sigmoid probabilities, not the raw logits, at 0.5.
            pred = (torch.sigmoid(output) >= 0.5).float()
            correct += pred.eq(target.view_as(pred)).sum().item()
            batch_prediction_pos_prob = torch.sigmoid(output).cpu().numpy()
            prediction_pos_probs.extend(batch_prediction_pos_prob.tolist())
            all_labels.extend(target.cpu().numpy().tolist())

            # Storing intermediate results in lists
            true_targets.extend(target.cpu().numpy().tolist())
            predicted_outputs.extend(pred.cpu().numpy().tolist())
            filenames.extend(filename)

            mean_loss = loss / total_samples
            acc = correct / total_samples
            test_time = time.time() - start_time
            print(f'\rTest: epoch: {epoch} Accuracy: {acc:.5f} batch: {batch_idx}/{len(loader)} loss: {mean_loss:.5f} time {test_time:.5f}', end='\r', flush=True)

    # Constructing the per-image results DataFrame
    results_df = pd.DataFrame({
        'filename': filenames,
        'true_label': true_targets,
        'predicted_label': predicted_outputs,
        'class_1_probability': prediction_pos_probs})

    loss /= len(loader)
    data_df = classification_metrics(all_labels, prediction_pos_probs, loader_name, loss, epoch)
    return data_df, prediction_pos_probs, all_labels, results_df

def test_model_performance(loaders, model, loader_name_list, epoch, train_mode, loss_type):
    """
    Test the performance of a model on the given data loaders.

    Args:
        loaders (list): List of data loaders.
        model: The model to be tested.
        loader_name_list (list): List of names for the data loaders.
        epoch (int): The current epoch.
        train_mode (str): The training mode ('erm' or 'irm').
        loss_type: The type of loss function.

    Returns:
        tuple: A tuple containing the aggregated metrics and the per-image results DataFrame.
    """
    df_list = []
    if train_mode == 'erm':
        result, prediction_pos_probs, all_labels, results_df = test_model_core(model, loaders, loader_name_list, epoch, loss_type)
    elif train_mode == 'irm':
        for loader_index in range(len(loaders)):
            loader = loaders[loader_index]
            loader_name = loader_name_list[loader_index]
            data_df, prediction_pos_probs, all_labels, results_df = test_model_core(model, loader, loader_name, epoch, loss_type)
            torch.cuda.empty_cache()
            df_list.append(data_df)
        result = pd.concat(df_list)
        nc_mean = result['neg_accuracy'].mean(skipna=True)
        pc_mean = result['pos_accuracy'].mean(skipna=True)
        tot_mean = result['accuracy'].mean(skipna=True)
        loss_mean = result['loss'].mean(skipna=True)
        prauc_mean = result['prauc'].mean(skipna=True)
        data_mean = {'accuracy': tot_mean, 'neg_accuracy': nc_mean, 'pos_accuracy': pc_mean, 'loss': loss_mean, 'prauc': prauc_mean}
        result = pd.concat([pd.DataFrame(result), pd.DataFrame(data_mean, index=[str(epoch) + '_mean'])])
    return result, results_df

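# Hypothetical follow-up (sketch only): the per-image DataFrame returned above
# can be filtered for rows where 'true_label' and 'predicted_label' disagree,
# e.g. to inspect misclassified images:
#
#   errors = results_df[results_df['true_label'] != results_df['predicted_label']]
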
def train_test_model(src, settings, custom_model=False, custom_model_path=None):
    """
    Train and/or test a model on the data under `src`, driven by the boolean
    'train' and 'test' flags in `settings`.
    """
    from .io import save_settings, _copy_missclassified
    from .utils import pick_best_model
    from .core import generate_loaders
    # test_model_performance and train_model are defined in this module.

    model = None
    if custom_model:
        model = torch.load(custom_model_path)  # if using a custom trained model

    if settings['train']:
        save_settings(settings, src)
    torch.cuda.empty_cache()
    torch.cuda.memory.empty_cache()
    gc.collect()
    dst = os.path.join(src, 'model')
    os.makedirs(dst, exist_ok=True)
    settings['src'] = src
    settings['dst'] = dst
    if settings['train']:
        train, val, plate_names = generate_loaders(src,
                                                   train_mode=settings['train_mode'],
                                                   mode='train',
                                                   image_size=settings['image_size'],
                                                   batch_size=settings['batch_size'],
                                                   classes=settings['classes'],
                                                   num_workers=settings['num_workers'],
                                                   validation_split=settings['val_split'],
                                                   pin_memory=settings['pin_memory'],
                                                   normalize=settings['normalize'],
                                                   verbose=settings['verbose'])

    if settings['test']:
        test, _, plate_names_test = generate_loaders(src,
                                                     train_mode=settings['train_mode'],
                                                     mode='test',
                                                     image_size=settings['image_size'],
                                                     batch_size=settings['batch_size'],
                                                     classes=settings['classes'],
                                                     num_workers=settings['num_workers'],
                                                     validation_split=0.0,
                                                     pin_memory=settings['pin_memory'],
                                                     normalize=settings['normalize'],
                                                     verbose=settings['verbose'])
        if model is None:
            model_path = pick_best_model(src + '/model')
            print(f'Best model: {model_path}')
            model = torch.load(model_path, map_location=lambda storage, loc: storage)

        model_type = settings['model_type']
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        print(type(model))
        print(model)

        model_fldr = os.path.join(src, 'model')
        time_now = datetime.date.today().strftime('%y%m%d')
        result_loc = f'{model_fldr}/{model_type}_time_{time_now}_result.csv'
        acc_loc = f'{model_fldr}/{model_type}_time_{time_now}_acc.csv'
        print(f'Results will be saved in: {result_loc}')

        result, accuracy = test_model_performance(loaders=test,
                                                  model=model,
                                                  loader_name_list='test',
                                                  epoch=1,
                                                  train_mode=settings['train_mode'],
                                                  loss_type=settings['loss_type'])

        result.to_csv(result_loc, index=True, header=True, mode='w')
        accuracy.to_csv(acc_loc, index=True, header=True, mode='w')
        _copy_missclassified(accuracy)
    else:
        test = None

    if settings['train']:
        train_model(dst=settings['dst'],
                    model_type=settings['model_type'],
                    train_loaders=train,
                    train_loader_names=plate_names,
                    train_mode=settings['train_mode'],
                    epochs=settings['epochs'],
                    learning_rate=settings['learning_rate'],
                    init_weights=settings['init_weights'],
                    weight_decay=settings['weight_decay'],
                    amsgrad=settings['amsgrad'],
                    optimizer_type=settings['optimizer_type'],
                    use_checkpoint=settings['use_checkpoint'],
                    dropout_rate=settings['dropout_rate'],
                    num_workers=settings['num_workers'],
                    val_loaders=val,
                    test_loaders=test,
                    intermedeate_save=settings['intermedeate_save'],
                    schedule=settings['schedule'],
                    loss_type=settings['loss_type'],
                    gradient_accumulation=settings['gradient_accumulation'],
                    gradient_accumulation_steps=settings['gradient_accumulation_steps'])

    torch.cuda.empty_cache()
    torch.cuda.memory.empty_cache()
    gc.collect()

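# Keys that train_test_model reads from `settings` (values shown are
# illustrative assumptions, not package defaults; elided values left as ...):
#
#   settings = {
#       'train': True, 'test': True, 'train_mode': 'erm', 'model_type': ...,
#       'image_size': ..., 'batch_size': ..., 'classes': ..., 'num_workers': ...,
#       'val_split': ..., 'pin_memory': True, 'normalize': True, 'verbose': False,
#       'epochs': ..., 'learning_rate': ..., 'init_weights': 'imagenet',
#       'weight_decay': ..., 'amsgrad': False, 'optimizer_type': 'adamw',
#       'use_checkpoint': False, 'dropout_rate': 0, 'intermedeate_save': None,
#       'schedule': 'step_lr', 'loss_type': 'binary_cross_entropy_with_logits',
#       'gradient_accumulation': False, 'gradient_accumulation_steps': 4,
#   }
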
def train_model(dst, model_type, train_loaders, train_loader_names, train_mode='erm', epochs=100, learning_rate=0.0001, weight_decay=0.05, amsgrad=False, optimizer_type='adamw', use_checkpoint=False, dropout_rate=0, num_workers=20, val_loaders=None, test_loaders=None, init_weights='imagenet', intermedeate_save=None, chan_dict=None, schedule=None, loss_type='binary_cross_entropy_with_logits', gradient_accumulation=False, gradient_accumulation_steps=4):
    """
    Trains a model using the specified parameters.

    Args:
        dst (str): The destination path to save the model and results.
        model_type (str): The type of model to train.
        train_loaders (list): A list of training data loaders.
        train_loader_names (list): A list of names for the training data loaders.
        train_mode (str, optional): The training mode. Defaults to 'erm'.
        epochs (int, optional): The number of training epochs. Defaults to 100.
        learning_rate (float, optional): The learning rate for the optimizer. Defaults to 0.0001.
        weight_decay (float, optional): The weight decay for the optimizer. Defaults to 0.05.
        amsgrad (bool, optional): Whether to use AMSGrad for the optimizer. Defaults to False.
        optimizer_type (str, optional): The type of optimizer to use. Defaults to 'adamw'.
        use_checkpoint (bool, optional): Whether to use checkpointing during training. Defaults to False.
        dropout_rate (float, optional): The dropout rate for the model. Defaults to 0.
        num_workers (int, optional): The number of workers for data loading. Defaults to 20.
        val_loaders (list, optional): A list of validation data loaders. Defaults to None.
        test_loaders (list, optional): A list of test data loaders. Defaults to None.
        init_weights (str, optional): The initialization weights for the model. Defaults to 'imagenet'.
        intermedeate_save (list, optional): The intermediate save thresholds. Defaults to None.
        chan_dict (dict, optional): The channel dictionary. Defaults to None.
        schedule (str, optional): The learning rate schedule. Defaults to None.
        loss_type (str, optional): The loss function type. Defaults to 'binary_cross_entropy_with_logits'.
        gradient_accumulation (bool, optional): Whether to use gradient accumulation. Defaults to False.
        gradient_accumulation_steps (int, optional): The number of steps for gradient accumulation. Defaults to 4.

    Returns:
        None
    """

    from .io import save_model, save_progress
    from .utils import compute_irm_penalty, calculate_loss, choose_model
    # evaluate_model_performance is defined in this module.

    print(f'Train batches:{len(train_loaders)}, Validation batches:{len(val_loaders)}')

    if test_loaders is not None:
        print(f'Test batches:{len(test_loaders)}')

    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    kwargs = {'num_workers': num_workers, 'pin_memory': True} if use_cuda else {}

    # Peek at one batch to read the input dimensions.
    for idx, (images, labels, filenames) in enumerate(train_loaders):
        batch, channels, height, width = images.shape
        break

    model = choose_model(model_type, device, init_weights, dropout_rate, use_checkpoint)
    model.to(device)

    if optimizer_type == 'adamw':
        optimizer = AdamW(model.parameters(), lr=learning_rate, betas=(0.9, 0.999), weight_decay=weight_decay, amsgrad=amsgrad)

    if optimizer_type == 'adagrad':
        optimizer = Adagrad(model.parameters(), lr=learning_rate, eps=1e-8, weight_decay=weight_decay)

    if schedule == 'step_lr':
        StepLR_step_size = int(epochs / 5)
        StepLR_gamma = 0.75
        scheduler = StepLR(optimizer, step_size=StepLR_step_size, gamma=StepLR_gamma)
    elif schedule == 'reduce_lr_on_plateau':
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=10, verbose=True)
    else:
        scheduler = None

    if train_mode == 'erm':
        for epoch in range(1, epochs + 1):
            model.train()
            start_time = time.time()
            running_loss = 0.0

            # Initialize gradients if using gradient accumulation
            if gradient_accumulation:
                optimizer.zero_grad()

            for batch_idx, (data, target, filenames) in enumerate(train_loaders, start=1):
                data, target = data.to(device), target.to(device).float()
                output = model(data)
                loss = calculate_loss(output, target, loss_type=loss_type)
                # Track the un-normalized loss, then scale the loss used for
                # backprop so accumulated gradients match one large batch.
                running_loss += loss.item()
                if gradient_accumulation:
                    loss /= gradient_accumulation_steps
                loss.backward()

                # Step the optimizer every batch, or every gradient_accumulation_steps batches
                if not gradient_accumulation or (batch_idx % gradient_accumulation_steps == 0):
                    optimizer.step()
                    optimizer.zero_grad()

                avg_loss = running_loss / batch_idx
                print(f'\rTrain: epoch: {epoch} batch: {batch_idx}/{len(train_loaders)} avg_loss: {avg_loss:.5f} time: {(time.time()-start_time):.5f}', end='\r', flush=True)

            train_time = time.time() - start_time
            train_metrics = {'epoch': epoch, 'loss': loss.cpu().item(), 'train_time': train_time}
            train_metrics_df = pd.DataFrame(train_metrics, index=[epoch])
            train_names = 'train'
            results_df, train_test_time = evaluate_model_performance(train_loaders, model, train_names, epoch, train_mode='erm', loss_type=loss_type)
            train_metrics_df['train_test_time'] = train_test_time
            if val_loaders is not None:
                val_names = 'val'
                result, val_time = evaluate_model_performance(val_loaders, model, val_names, epoch, train_mode='erm', loss_type=loss_type)

                if schedule == 'reduce_lr_on_plateau':
                    val_loss = result['loss']

                results_df = pd.concat([results_df, result])
                train_metrics_df['val_time'] = val_time
            if test_loaders is not None:
                test_names = 'test'
                result, test_test_time = evaluate_model_performance(test_loaders, model, test_names, epoch, train_mode='erm', loss_type=loss_type)
                results_df = pd.concat([results_df, result])
                test_time = (train_test_time + val_time + test_test_time) / 3
                train_metrics_df['test_time'] = test_time

            if scheduler:
                if schedule == 'reduce_lr_on_plateau':
                    scheduler.step(val_loss)
                if schedule == 'step_lr':
                    scheduler.step()

            save_progress(dst, results_df, train_metrics_df)
            clear_output(wait=True)
            display(results_df)
            save_model(model, model_type, results_df, dst, epoch, epochs, intermedeate_save=[0.99, 0.98, 0.95, 0.94])

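    # Note on the 'irm' branch below: invariant risk minimization (IRMv1,
    # Arjovsky et al., 2019) treats each training loader as an "environment"
    # and adds a penalty of the form
    #
    #     penalty = sum_e || grad_w R_e(w * model) |_{w=1.0} ||^2
    #
    # where w is the frozen scalar classifier weight ('dummy_w'). The pooled
    # ERM loss plus penalty_factor * penalty is minimized. compute_irm_penalty
    # in .utils is assumed to implement this gradient term from the per-loader
    # losses and dummy_w.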
    if train_mode == 'irm':
        dummy_w = torch.nn.Parameter(torch.Tensor([1.0])).to(device)
        phi = torch.nn.Parameter(torch.ones(4, 1))
        for epoch in range(1, epochs + 1):
            model.train()
            penalty_factor = epoch * 1e-5
            epoch_names = [str(epoch) + '_' + item for item in train_loader_names]
            loader_erm_loss_list = []
            total_erm_loss_mean = 0
            for loader_index in range(len(train_loaders)):
                start_time = time.time()
                loader = train_loaders[loader_index]
                loader_erm_loss_mean = 0
                batch_count = 0
                batch_erm_loss_list = []
                for batch_idx, (data, target, filenames) in enumerate(loader, start=1):
                    optimizer.zero_grad()
                    data, target = data.to(device), target.to(device).float()

                    output = model(data)
                    erm_loss = F.binary_cross_entropy_with_logits(output * dummy_w, target, reduction='none')

                    batch_erm_loss_list.append(erm_loss.mean())
                    print(f'\repoch: {epoch} loader: {loader_index} batch: {batch_idx}/{len(loader)}', end='\r', flush=True)
                loader_erm_loss_mean = torch.stack(batch_erm_loss_list).mean()
                loader_erm_loss_list.append(loader_erm_loss_mean)
            total_erm_loss_mean = torch.stack(loader_erm_loss_list).mean()
            irm_loss = compute_irm_penalty(loader_erm_loss_list, dummy_w, device)

            (total_erm_loss_mean + penalty_factor * irm_loss).backward()
            optimizer.step()

            train_time = time.time() - start_time

            # Store plain Python scalars, not tensors, in the metrics row.
            train_metrics = {'epoch': epoch, 'irm_loss': irm_loss.cpu().item(), 'erm_loss': total_erm_loss_mean.cpu().item(), 'penalty_factor': penalty_factor, 'train_time': train_time}
            train_metrics_df = pd.DataFrame(train_metrics, index=[epoch])
            print(f'\rTrain: epoch: {epoch} loader: {loader_index} batch: {batch_idx}/{len(loader)} irm_loss: {irm_loss:.5f} mean_erm_loss: {total_erm_loss_mean:.5f} train time {train_time:.5f}', end='\r', flush=True)

            train_names = [item + '_train' for item in train_loader_names]
            results_df, train_test_time = evaluate_model_performance(train_loaders, model, train_names, epoch, train_mode='irm', loss_type=loss_type)
            train_metrics_df['train_test_time'] = train_test_time

            if val_loaders is not None:
                val_names = [item + '_val' for item in train_loader_names]
                result, val_time = evaluate_model_performance(val_loaders, model, val_names, epoch, train_mode='irm', loss_type=loss_type)

                if schedule == 'reduce_lr_on_plateau':
                    val_loss = result['loss']

                results_df = pd.concat([results_df, result])
                train_metrics_df['val_time'] = val_time

            if test_loaders is not None:
                test_names = [item + '_test' for item in train_loader_names]  # test_loader_names?
                result, test_test_time = evaluate_model_performance(test_loaders, model, test_names, epoch, train_mode='irm', loss_type=loss_type)
                results_df = pd.concat([results_df, result])
                train_metrics_df['test_test_time'] = test_test_time

            if scheduler:
                if schedule == 'reduce_lr_on_plateau':
                    scheduler.step(val_loss)
                if schedule == 'step_lr':
                    scheduler.step()

            clear_output(wait=True)
            display(results_df)
            save_progress(dst, results_df, train_metrics_df)
            save_model(model, model_type, results_df, dst, epoch, epochs, intermedeate_save=[0.99, 0.98, 0.95, 0.94])
            print(f'Saved model: {dst}')
    return
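
# Illustrative sketch (not part of the spacr API and never called by it): the
# gradient-accumulation pattern from train_model's 'erm' branch, condensed
# into a self-contained toy so the normalize-then-step logic is easy to read.
def _gradient_accumulation_sketch(steps=4):
    # Toy model and data standing in for choose_model and the spacr loaders.
    model = torch.nn.Linear(4, 1)
    optimizer = AdamW(model.parameters(), lr=1e-3)
    batches = [(torch.randn(8, 4), torch.randint(0, 2, (8, 1)).float()) for _ in range(8)]
    optimizer.zero_grad()
    for batch_idx, (x, y) in enumerate(batches, start=1):
        loss = F.binary_cross_entropy_with_logits(model(x), y)
        # Divide so the accumulated gradient equals that of one large batch.
        (loss / steps).backward()
        # Apply and reset the accumulated gradients every `steps` batches.
        if batch_idx % steps == 0:
            optimizer.step()
            optimizer.zero_grad()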