spacr 0.0.1__py3-none-any.whl

spacr/train.py ADDED
@@ -0,0 +1,494 @@
+ import os, torch, time, gc, datetime
+ import pandas as pd
+ from torch.optim import Adagrad
+ from torch.optim import AdamW
+ from torch.autograd import grad
+ from torch.optim.lr_scheduler import StepLR
+ import torch.nn.functional as F
+ from IPython.display import display, clear_output
+
+ from .logger import log_function_call
+
+ def evaluate_model_core(model, loader, loader_name, epoch, loss_type):
+     """
+     Evaluates the performance of a model on a given data loader.
+
+     Args:
+         model (torch.nn.Module): The model to evaluate.
+         loader (torch.utils.data.DataLoader): The data loader to evaluate the model on.
+         loader_name (str): The name of the data loader.
+         epoch (int): The current epoch number.
+         loss_type (str): The type of loss function to use.
+
+     Returns:
+         data_df (pandas.DataFrame): The classification metrics data as a DataFrame.
+         prediction_pos_probs (list): The positive class probabilities for each prediction.
+         all_labels (list): The true labels for each prediction.
+     """
+
+     from .utils import calculate_loss, classification_metrics
+
+     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+     model.eval()
+     loss = 0.0
+     correct = 0
+     total_samples = 0
+     prediction_pos_probs = []
+     all_labels = []
+     model = model.to(device)
+     with torch.no_grad():
+         for batch_idx, (data, target, _) in enumerate(loader, start=1):
+             start_time = time.time()
+             data, target = data.to(device), target.to(device).float()
+             output = model(data)
+             batch_loss = calculate_loss(output, target, loss_type=loss_type)
+             loss += batch_loss.item()
+             total_samples += data.size(0)
+             # The model outputs logits, so threshold the sigmoid probabilities at 0.5.
+             pred = (torch.sigmoid(output) >= 0.5).float()
+             correct += pred.eq(target.view_as(pred)).sum().item()
+             batch_prediction_pos_prob = torch.sigmoid(output).cpu().numpy()
+             prediction_pos_probs.extend(batch_prediction_pos_prob.tolist())
+             all_labels.extend(target.cpu().numpy().tolist())
+             mean_loss = loss / total_samples
+             acc = correct / total_samples
+             test_time = time.time() - start_time
+             print(f'\rTest: epoch: {epoch} Accuracy: {acc:.5f} batch: {batch_idx}/{len(loader)} loss: {mean_loss:.5f} time {test_time:.5f}', end='\r', flush=True)
+     loss /= len(loader)
+     data_df = classification_metrics(all_labels, prediction_pos_probs, loader_name, loss, epoch)
+     return data_df, prediction_pos_probs, all_labels
+
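+ # --- Editor's sketch (not part of the original package) ----------------------
+ # `calculate_loss` is imported from `.utils`, which is outside this diff. The
+ # helper below is an assumed, minimal reconstruction for illustration only; the
+ # real implementation may differ and may support more `loss_type` values than
+ # the two shown here.
+ def _sketch_calculate_loss(output, target, loss_type='binary_cross_entropy_with_logits'):
+     if loss_type == 'binary_cross_entropy_with_logits':
+         return F.binary_cross_entropy_with_logits(output, target)
+     if loss_type == 'focal_loss':
+         # Focal loss down-weights easy examples (gamma = 2 assumed here).
+         bce = F.binary_cross_entropy_with_logits(output, target, reduction='none')
+         p_t = torch.exp(-bce)  # probability assigned to the true class
+         return ((1.0 - p_t) ** 2 * bce).mean()
+     raise ValueError(f'Unknown loss_type: {loss_type}')
+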
+ def evaluate_model_performance(loaders, model, loader_name_list, epoch, train_mode, loss_type):
+     """
+     Evaluate the performance of a model on the given data loaders.
+
+     Args:
+         loaders (list): List of data loaders.
+         model: The model to evaluate.
+         loader_name_list (list): List of names for the data loaders.
+         epoch (int): The current epoch.
+         train_mode (str): The training mode ('erm' or 'irm').
+         loss_type: The type of loss function.
+
+     Returns:
+         tuple: A tuple containing the evaluation result and the time taken for evaluation.
+     """
+     start_time = time.time()
+     df_list = []
+     if train_mode == 'erm':
+         result, _, _ = evaluate_model_core(model, loaders, loader_name_list, epoch, loss_type)
+     elif train_mode == 'irm':
+         for loader_index in range(len(loaders)):
+             loader = loaders[loader_index]
+             loader_name = loader_name_list[loader_index]
+             data_df, _, _ = evaluate_model_core(model, loader, loader_name, epoch, loss_type)
+             torch.cuda.empty_cache()
+             df_list.append(data_df)
+         result = pd.concat(df_list)
+         # Append a summary row averaging the per-loader metrics.
+         nc_mean = result['neg_accuracy'].mean(skipna=True)
+         pc_mean = result['pos_accuracy'].mean(skipna=True)
+         tot_mean = result['accuracy'].mean(skipna=True)
+         loss_mean = result['loss'].mean(skipna=True)
+         prauc_mean = result['prauc'].mean(skipna=True)
+         data_mean = {'accuracy': tot_mean, 'neg_accuracy': nc_mean, 'pos_accuracy': pc_mean, 'loss': loss_mean, 'prauc': prauc_mean}
+         result = pd.concat([pd.DataFrame(result), pd.DataFrame(data_mean, index=[str(epoch) + '_mean'])])
+     else:
+         raise ValueError(f'Unknown train_mode: {train_mode}')
+     test_time = time.time() - start_time
+     return result, test_time
+
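+ # --- Editor's sketch (not part of the original package) ----------------------
+ # Hypothetical helper showing the 'irm' calling convention: one loader and one
+ # name per environment (plate). The per-loader metrics come back concatenated,
+ # with an extra '<epoch>_mean' summary row.
+ def _sketch_evaluate_per_plate(model, plate_loaders, plate_names, loss_type):
+     result, eval_time = evaluate_model_performance(plate_loaders, model, plate_names,
+                                                    epoch=1, train_mode='irm', loss_type=loss_type)
+     return result
+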
+ def test_model_core(model, loader, loader_name, epoch, loss_type):
+     """
+     Evaluates the model like `evaluate_model_core`, but additionally collects
+     per-image filenames, labels, and predictions into a results DataFrame.
+     """
+     from .utils import calculate_loss, classification_metrics
+
+     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+     model.eval()
+     loss = 0.0
+     correct = 0
+     total_samples = 0
+     prediction_pos_probs = []
+     all_labels = []
+     filenames = []
+     true_targets = []
+     predicted_outputs = []
+
+     model = model.to(device)
+     with torch.no_grad():
+         for batch_idx, (data, target, filename) in enumerate(loader, start=1):  # loader yields filenames
+             start_time = time.time()
+             data, target = data.to(device), target.to(device).float()
+             output = model(data)
+             batch_loss = calculate_loss(output, target, loss_type=loss_type)
+             loss += batch_loss.item()
+             total_samples += data.size(0)
+             # The model outputs logits, so threshold the sigmoid probabilities at 0.5.
+             pred = (torch.sigmoid(output) >= 0.5).float()
+             correct += pred.eq(target.view_as(pred)).sum().item()
+             batch_prediction_pos_prob = torch.sigmoid(output).cpu().numpy()
+             prediction_pos_probs.extend(batch_prediction_pos_prob.tolist())
+             all_labels.extend(target.cpu().numpy().tolist())
+
+             # Storing intermediate per-image results in lists
+             true_targets.extend(target.cpu().numpy().tolist())
+             predicted_outputs.extend(pred.cpu().numpy().tolist())
+             filenames.extend(filename)
+
+             mean_loss = loss / total_samples
+             acc = correct / total_samples
+             test_time = time.time() - start_time
+             print(f'\rTest: epoch: {epoch} Accuracy: {acc:.5f} batch: {batch_idx}/{len(loader)} loss: {mean_loss:.5f} time {test_time:.5f}', end='\r', flush=True)
+
+     # Constructing the per-image DataFrame
+     results_df = pd.DataFrame({
+         'filename': filenames,
+         'true_label': true_targets,
+         'predicted_label': predicted_outputs,
+         'class_1_probability': prediction_pos_probs})
+
+     loss /= len(loader)
+     data_df = classification_metrics(all_labels, prediction_pos_probs, loader_name, loss, epoch)
+     return data_df, prediction_pos_probs, all_labels, results_df
+
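+ # --- Editor's sketch (not part of the original package) ----------------------
+ # `classification_metrics` is imported from `.utils`, which is outside this
+ # diff. Judging by the columns consumed downstream ('accuracy', 'neg_accuracy',
+ # 'pos_accuracy', 'loss', 'prauc'), it plausibly resembles the assumed
+ # reconstruction below (scikit-learn used for the PR-AUC); the real helper may
+ # compute more.
+ def _sketch_classification_metrics(all_labels, pos_probs, loader_name, loss, epoch):
+     import numpy as np
+     from sklearn.metrics import average_precision_score
+     labels = np.array(all_labels).ravel()
+     probs = np.array(pos_probs).ravel()
+     preds = (probs >= 0.5).astype(float)
+     neg, pos = labels == 0, labels == 1
+     return pd.DataFrame({'accuracy': (preds == labels).mean(),
+                          'neg_accuracy': (preds[neg] == 0).mean() if neg.any() else float('nan'),
+                          'pos_accuracy': (preds[pos] == 1).mean() if pos.any() else float('nan'),
+                          'loss': loss,
+                          'prauc': average_precision_score(labels, probs)},
+                         index=[f'{epoch}_{loader_name}'])
+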
+ def test_model_performance(loaders, model, loader_name_list, epoch, train_mode, loss_type):
+     """
+     Test the performance of a model on the given data loaders.
+
+     Args:
+         loaders (list): List of data loaders.
+         model: The model to be tested.
+         loader_name_list (list): List of names for the data loaders.
+         epoch (int): The current epoch.
+         train_mode (str): The training mode ('erm' or 'irm').
+         loss_type: The type of loss function.
+
+     Returns:
+         tuple: A tuple containing the aggregated metrics and the per-image results DataFrame.
+     """
+     start_time = time.time()
+     df_list = []
+     if train_mode == 'erm':
+         result, prediction_pos_probs, all_labels, results_df = test_model_core(model, loaders, loader_name_list, epoch, loss_type)
+     elif train_mode == 'irm':
+         for loader_index in range(len(loaders)):
+             loader = loaders[loader_index]
+             loader_name = loader_name_list[loader_index]
+             data_df, prediction_pos_probs, all_labels, results_df = test_model_core(model, loader, loader_name, epoch, loss_type)
+             torch.cuda.empty_cache()
+             df_list.append(data_df)
+         result = pd.concat(df_list)
+         # Append a summary row averaging the per-loader metrics.
+         nc_mean = result['neg_accuracy'].mean(skipna=True)
+         pc_mean = result['pos_accuracy'].mean(skipna=True)
+         tot_mean = result['accuracy'].mean(skipna=True)
+         loss_mean = result['loss'].mean(skipna=True)
+         prauc_mean = result['prauc'].mean(skipna=True)
+         data_mean = {'accuracy': tot_mean, 'neg_accuracy': nc_mean, 'pos_accuracy': pc_mean, 'loss': loss_mean, 'prauc': prauc_mean}
+         result = pd.concat([pd.DataFrame(result), pd.DataFrame(data_mean, index=[str(epoch) + '_mean'])])
+     else:
+         raise ValueError(f'Unknown train_mode: {train_mode}')
+     return result, results_df
+
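+ # --- Editor's sketch (not part of the original package) ----------------------
+ # Hypothetical helper showing how the per-image DataFrame returned above can be
+ # used to pull out misclassified files, in the spirit of `_copy_missclassified`
+ # from `.io` (which is outside this diff); the name is a placeholder.
+ def _sketch_misclassified_filenames(results_df):
+     wrong = results_df[results_df['true_label'] != results_df['predicted_label']]
+     return wrong.sort_values('class_1_probability')['filename'].tolist()
+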
+ def train_test_model(src, settings, custom_model=False, custom_model_path=None):
+
+     from .io import save_settings, _copy_missclassified
+     from .utils import pick_best_model
+     from .core import generate_loaders
+
+     model = None
+     if custom_model:
+         model = torch.load(custom_model_path)  # if using a custom trained model
+
+     if settings['train']:
+         save_settings(settings, src)
+     torch.cuda.empty_cache()
+     gc.collect()
+     dst = os.path.join(src, 'model')
+     os.makedirs(dst, exist_ok=True)
+     settings['src'] = src
+     settings['dst'] = dst
+     if settings['train']:
+         train, val, plate_names = generate_loaders(src,
+                                                    train_mode=settings['train_mode'],
+                                                    mode='train',
+                                                    image_size=settings['image_size'],
+                                                    batch_size=settings['batch_size'],
+                                                    classes=settings['classes'],
+                                                    num_workers=settings['num_workers'],
+                                                    validation_split=settings['val_split'],
+                                                    pin_memory=settings['pin_memory'],
+                                                    normalize=settings['normalize'],
+                                                    verbose=settings['verbose'])
+
+     if settings['test']:
+         test, _, plate_names_test = generate_loaders(src,
+                                                      train_mode=settings['train_mode'],
+                                                      mode='test',
+                                                      image_size=settings['image_size'],
+                                                      batch_size=settings['batch_size'],
+                                                      classes=settings['classes'],
+                                                      num_workers=settings['num_workers'],
+                                                      validation_split=0.0,
+                                                      pin_memory=settings['pin_memory'],
+                                                      normalize=settings['normalize'],
+                                                      verbose=settings['verbose'])
+         if model is None:
+             model_path = pick_best_model(src + '/model')
+             print(f'Best model: {model_path}')
+             model = torch.load(model_path, map_location=lambda storage, loc: storage)
+
+         model_type = settings['model_type']
+         device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+         print(type(model))
+         print(model)
+
+         model_fldr = os.path.join(src, 'model')
+         time_now = datetime.date.today().strftime('%y%m%d')
+         result_loc = f'{model_fldr}/{model_type}_time_{time_now}_result.csv'
+         acc_loc = f'{model_fldr}/{model_type}_time_{time_now}_acc.csv'
+         print(f'Results will be saved in: {result_loc}')
+
+         result, accuracy = test_model_performance(loaders=test,
+                                                   model=model,
+                                                   loader_name_list='test',
+                                                   epoch=1,
+                                                   train_mode=settings['train_mode'],
+                                                   loss_type=settings['loss_type'])
+
+         result.to_csv(result_loc, index=True, header=True, mode='w')
+         accuracy.to_csv(acc_loc, index=True, header=True, mode='w')
+         _copy_missclassified(accuracy)
+     else:
+         test = None
+
+     if settings['train']:
+         train_model(dst=settings['dst'],
+                     model_type=settings['model_type'],
+                     train_loaders=train,
+                     train_loader_names=plate_names,
+                     train_mode=settings['train_mode'],
+                     epochs=settings['epochs'],
+                     learning_rate=settings['learning_rate'],
+                     init_weights=settings['init_weights'],
+                     weight_decay=settings['weight_decay'],
+                     amsgrad=settings['amsgrad'],
+                     optimizer_type=settings['optimizer_type'],
+                     use_checkpoint=settings['use_checkpoint'],
+                     dropout_rate=settings['dropout_rate'],
+                     num_workers=settings['num_workers'],
+                     val_loaders=val,
+                     test_loaders=test,
+                     intermedeate_save=settings['intermedeate_save'],
+                     schedule=settings['schedule'],
+                     loss_type=settings['loss_type'],
+                     gradient_accumulation=settings['gradient_accumulation'],
+                     gradient_accumulation_steps=settings['gradient_accumulation_steps'])
+
+     torch.cuda.empty_cache()
+     gc.collect()
+
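+ # --- Editor's sketch (not part of the original package) ----------------------
+ # `train_test_model` reads many keys from `settings`. The dictionary below is
+ # an assumed, minimal example inferred from the keys accessed in this module;
+ # the values are placeholders, not recommended defaults.
+ _sketch_settings = {
+     'train': True, 'test': False, 'train_mode': 'erm',
+     'model_type': 'resnet50', 'image_size': 224, 'batch_size': 64,
+     'classes': ['nc', 'pc'], 'num_workers': 8, 'val_split': 0.1,
+     'pin_memory': True, 'normalize': True, 'verbose': False,
+     'epochs': 100, 'learning_rate': 1e-4, 'init_weights': 'imagenet',
+     'weight_decay': 0.05, 'amsgrad': False, 'optimizer_type': 'adamw',
+     'use_checkpoint': False, 'dropout_rate': 0, 'intermedeate_save': None,
+     'schedule': None, 'loss_type': 'binary_cross_entropy_with_logits',
+     'gradient_accumulation': False, 'gradient_accumulation_steps': 4}
+ # Example call: train_test_model('/path/to/src', _sketch_settings)
+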
+ def train_model(dst, model_type, train_loaders, train_loader_names, train_mode='erm', epochs=100, learning_rate=0.0001, weight_decay=0.05, amsgrad=False, optimizer_type='adamw', use_checkpoint=False, dropout_rate=0, num_workers=20, val_loaders=None, test_loaders=None, init_weights='imagenet', intermedeate_save=None, chan_dict=None, schedule=None, loss_type='binary_cross_entropy_with_logits', gradient_accumulation=False, gradient_accumulation_steps=4):
+     """
+     Trains a model using the specified parameters.
+
+     Args:
+         dst (str): The destination path to save the model and results.
+         model_type (str): The type of model to train.
+         train_loaders (list): A list of training data loaders.
+         train_loader_names (list): A list of names for the training data loaders.
+         train_mode (str, optional): The training mode. Defaults to 'erm'.
+         epochs (int, optional): The number of training epochs. Defaults to 100.
+         learning_rate (float, optional): The learning rate for the optimizer. Defaults to 0.0001.
+         weight_decay (float, optional): The weight decay for the optimizer. Defaults to 0.05.
+         amsgrad (bool, optional): Whether to use AMSGrad for the optimizer. Defaults to False.
+         optimizer_type (str, optional): The type of optimizer to use. Defaults to 'adamw'.
+         use_checkpoint (bool, optional): Whether to use checkpointing during training. Defaults to False.
+         dropout_rate (float, optional): The dropout rate for the model. Defaults to 0.
+         num_workers (int, optional): The number of workers for data loading. Defaults to 20.
+         val_loaders (list, optional): A list of validation data loaders. Defaults to None.
+         test_loaders (list, optional): A list of test data loaders. Defaults to None.
+         init_weights (str, optional): The initialization weights for the model. Defaults to 'imagenet'.
+         intermedeate_save (list, optional): The intermediate save thresholds. Defaults to None.
+         chan_dict (dict, optional): The channel dictionary. Defaults to None.
+         schedule (str, optional): The learning rate schedule. Defaults to None.
+         loss_type (str, optional): The loss function type. Defaults to 'binary_cross_entropy_with_logits'.
+         gradient_accumulation (bool, optional): Whether to use gradient accumulation. Defaults to False.
+         gradient_accumulation_steps (int, optional): The number of steps for gradient accumulation. Defaults to 4.
+
+     Returns:
+         None
+     """
+
+     from .io import save_model, save_progress
+     from .utils import compute_irm_penalty, calculate_loss, choose_model
+
+     print(f'Train batches: {len(train_loaders)}, Validation batches: {len(val_loaders) if val_loaders is not None else 0}')
+
+     if test_loaders is not None:
+         print(f'Test batches: {len(test_loaders)}')
+
+     use_cuda = torch.cuda.is_available()
+     device = torch.device("cuda" if use_cuda else "cpu")
+     kwargs = {'num_workers': num_workers, 'pin_memory': True} if use_cuda else {}
+
+     # Probe one batch for the input shape.
+     for idx, (images, labels, filenames) in enumerate(train_loaders):
+         batch, channels, height, width = images.shape
+         break
+
+     model = choose_model(model_type, device, init_weights, dropout_rate, use_checkpoint)
+     model.to(device)
+
+     if optimizer_type == 'adamw':
+         optimizer = AdamW(model.parameters(), lr=learning_rate, betas=(0.9, 0.999), weight_decay=weight_decay, amsgrad=amsgrad)
+     elif optimizer_type == 'adagrad':
+         optimizer = Adagrad(model.parameters(), lr=learning_rate, eps=1e-8, weight_decay=weight_decay)
+     else:
+         raise ValueError(f'Unknown optimizer_type: {optimizer_type}')
+
+     if schedule == 'step_lr':
+         StepLR_step_size = int(epochs / 5)
+         StepLR_gamma = 0.75
+         scheduler = StepLR(optimizer, step_size=StepLR_step_size, gamma=StepLR_gamma)
+     elif schedule == 'reduce_lr_on_plateau':
+         scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=10, verbose=True)
+     else:
+         scheduler = None
+
+     if train_mode == 'erm':
+         for epoch in range(1, epochs + 1):
+             model.train()
+             start_time = time.time()
+             running_loss = 0.0
+
+             # Initialize gradients if using gradient accumulation
+             if gradient_accumulation:
+                 optimizer.zero_grad()
+
+             for batch_idx, (data, target, filenames) in enumerate(train_loaders, start=1):
+                 data, target = data.to(device), target.to(device).float()
+                 output = model(data)
+                 loss = calculate_loss(output, target, loss_type=loss_type)
+                 # Normalize the loss if using gradient accumulation
+                 if gradient_accumulation:
+                     loss /= gradient_accumulation_steps
+                     running_loss += loss.item() * gradient_accumulation_steps  # track the unscaled loss
+                 else:
+                     running_loss += loss.item()
+                 loss.backward()
+
+                 # Step the optimizer if not accumulating, or every gradient_accumulation_steps batches
+                 if not gradient_accumulation or (batch_idx % gradient_accumulation_steps == 0):
+                     optimizer.step()
+                     optimizer.zero_grad()
+
+                 avg_loss = running_loss / batch_idx
+                 print(f'\rTrain: epoch: {epoch} batch: {batch_idx}/{len(train_loaders)} avg_loss: {avg_loss:.5f} time: {(time.time() - start_time):.5f}', end='\r', flush=True)
+
+             train_time = time.time() - start_time
+             train_metrics = {'epoch': epoch, 'loss': avg_loss, 'train_time': train_time}
+             train_metrics_df = pd.DataFrame(train_metrics, index=[epoch])
+             train_names = 'train'
+             results_df, train_test_time = evaluate_model_performance(train_loaders, model, train_names, epoch, train_mode='erm', loss_type=loss_type)
+             train_metrics_df['train_test_time'] = train_test_time
+             if val_loaders is not None:
+                 val_names = 'val'
+                 result, val_time = evaluate_model_performance(val_loaders, model, val_names, epoch, train_mode='erm', loss_type=loss_type)
+
+                 if schedule == 'reduce_lr_on_plateau':
+                     val_loss = result['loss'].mean()  # scalar for ReduceLROnPlateau
+
+                 results_df = pd.concat([results_df, result])
+                 train_metrics_df['val_time'] = val_time
+             if test_loaders is not None:
+                 test_names = 'test'
+                 result, test_test_time = evaluate_model_performance(test_loaders, model, test_names, epoch, train_mode='erm', loss_type=loss_type)
+                 results_df = pd.concat([results_df, result])
+                 train_metrics_df['test_time'] = test_test_time
+
+             if scheduler:
+                 if schedule == 'reduce_lr_on_plateau':
+                     scheduler.step(val_loss)
+                 if schedule == 'step_lr':
+                     scheduler.step()
+
+             save_progress(dst, results_df, train_metrics_df)
+             clear_output(wait=True)
+             display(results_df)
+             save_model(model, model_type, results_df, dst, epoch, epochs, intermedeate_save=intermedeate_save or [0.99, 0.98, 0.95, 0.94])
+
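+     # Editor's note on the accumulation scheme above: scaling each micro-batch
+     # loss by 1/gradient_accumulation_steps before backward() means that, when
+     # optimizer.step() finally runs, the applied gradient is the mean over the
+     # accumulated micro-batches, i.e. the effective batch size becomes
+     # batch_size * gradient_accumulation_steps without the extra memory cost.
+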
+     if train_mode == 'irm':
+         dummy_w = torch.nn.Parameter(torch.Tensor([1.0])).to(device)
+         phi = torch.nn.Parameter(torch.ones(4, 1))
+         for epoch in range(1, epochs + 1):
+             model.train()
+             penalty_factor = epoch * 1e-5
+             epoch_names = [str(epoch) + '_' + item for item in train_loader_names]
+             loader_erm_loss_list = []
+             total_erm_loss_mean = 0
+             for loader_index in range(len(train_loaders)):
+                 start_time = time.time()
+                 loader = train_loaders[loader_index]
+                 batch_erm_loss_list = []
+                 for batch_idx, (data, target, filenames) in enumerate(loader, start=1):
+                     data, target = data.to(device), target.to(device).float()
+
+                     output = model(data)
+                     erm_loss = F.binary_cross_entropy_with_logits(output * dummy_w, target, reduction='none')
+
+                     batch_erm_loss_list.append(erm_loss.mean())
+                     print(f'\repoch: {epoch} loader: {loader_index} batch: {batch_idx}/{len(loader)}', end='\r', flush=True)
+                 loader_erm_loss_mean = torch.stack(batch_erm_loss_list).mean()
+                 loader_erm_loss_list.append(loader_erm_loss_mean)
+             total_erm_loss_mean = torch.stack(loader_erm_loss_list).mean()
+             irm_loss = compute_irm_penalty(loader_erm_loss_list, dummy_w, device)
+
+             # One update per epoch: mean risk across environments plus the IRM penalty.
+             optimizer.zero_grad()
+             (total_erm_loss_mean + penalty_factor * irm_loss).backward()
+             optimizer.step()
+
+             train_time = time.time() - start_time
+
+             train_metrics = {'epoch': epoch, 'irm_loss': irm_loss.cpu().item(), 'erm_loss': total_erm_loss_mean.cpu().item(), 'penalty_factor': penalty_factor, 'train_time': train_time}
+             train_metrics_df = pd.DataFrame(train_metrics, index=[epoch])
+             print(f'\rTrain: epoch: {epoch} loader: {loader_index} batch: {batch_idx}/{len(loader)} irm_loss: {irm_loss:.5f} mean_erm_loss: {total_erm_loss_mean:.5f} train time {train_time:.5f}', end='\r', flush=True)
+
+             train_names = [item + '_train' for item in train_loader_names]
+             results_df, train_test_time = evaluate_model_performance(train_loaders, model, train_names, epoch, train_mode='irm', loss_type=loss_type)
+             train_metrics_df['train_test_time'] = train_test_time
+
+             if val_loaders is not None:
+                 val_names = [item + '_val' for item in train_loader_names]
+                 result, val_time = evaluate_model_performance(val_loaders, model, val_names, epoch, train_mode='irm', loss_type=loss_type)
+
+                 if schedule == 'reduce_lr_on_plateau':
+                     val_loss = result['loss'].mean()  # scalar for ReduceLROnPlateau
+
+                 results_df = pd.concat([results_df, result])
+                 train_metrics_df['val_time'] = val_time
+
+             if test_loaders is not None:
+                 test_names = [item + '_test' for item in train_loader_names]  # test_loader_names?
+                 result, test_test_time = evaluate_model_performance(test_loaders, model, test_names, epoch, train_mode='irm', loss_type=loss_type)
+                 results_df = pd.concat([results_df, result])
+                 train_metrics_df['test_test_time'] = test_test_time
+
+             if scheduler:
+                 if schedule == 'reduce_lr_on_plateau':
+                     scheduler.step(val_loss)
+                 if schedule == 'step_lr':
+                     scheduler.step()
+
+             clear_output(wait=True)
+             display(results_df)
+             save_progress(dst, results_df, train_metrics_df)
+             save_model(model, model_type, results_df, dst, epoch, epochs, intermedeate_save=intermedeate_save or [0.99, 0.98, 0.95, 0.94])
+             print(f'Saved model: {dst}')
+     return
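+
+ # --- Editor's sketch (not part of the original package) ----------------------
+ # `compute_irm_penalty` is imported from `.utils`, which is outside this diff.
+ # Given the `grad` import at the top of this file and the IRMv1 formulation
+ # (Arjovsky et al., 2019), a plausible reconstruction is the squared norm of
+ # the gradient of each environment's risk with respect to the dummy classifier
+ # weight, averaged over environments; the real helper may differ.
+ def _sketch_compute_irm_penalty(env_losses, dummy_w, device):
+     penalty = torch.tensor(0.0, device=device)
+     for env_loss in env_losses:
+         # create_graph=True keeps the penalty differentiable for backward().
+         g = grad(env_loss, [dummy_w], create_graph=True)[0]
+         penalty = penalty + (g ** 2).sum()
+     return penalty / len(env_losses)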