spacr 0.0.36__py3-none-any.whl → 0.0.61__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
spacr/deep_spacr.py ADDED
@@ -0,0 +1,696 @@
+ import os, torch, time, gc, datetime
+ import numpy as np
+ import pandas as pd
+ from torch.optim import Adagrad, AdamW
+ from torch.autograd import grad
+ from torch.optim.lr_scheduler import StepLR
+ import torch.nn.functional as F
+ from IPython.display import display, clear_output
+
+ import matplotlib.pyplot as plt
+ from PIL import Image
+
+ from .logger import log_function_call
+
+ def evaluate_model_core(model, loader, loader_name, epoch, loss_type):
+     """
+     Evaluates the performance of a model on a given data loader.
+
+     Args:
+         model (torch.nn.Module): The model to evaluate.
+         loader (torch.utils.data.DataLoader): The data loader to evaluate the model on.
+         loader_name (str): The name of the data loader.
+         epoch (int): The current epoch number.
+         loss_type (str): The type of loss function to use.
+
+     Returns:
+         data_df (pandas.DataFrame): The classification metrics data as a DataFrame.
+         prediction_pos_probs (list): The positive class probabilities for each prediction.
+         all_labels (list): The true labels for each prediction.
+     """
+
+     from .utils import calculate_loss, classification_metrics
+
+     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+     model.eval()
+     loss = 0
+     correct = 0
+     total_samples = 0
+     prediction_pos_probs = []
+     all_labels = []
+     model = model.to(device)
+     with torch.no_grad():
+         for batch_idx, (data, target, _) in enumerate(loader, start=1):
+             start_time = time.time()
+             data, target = data.to(device), target.to(device).float()
+             #data, target = data.to(torch.float).to(device), target.to(device).float()
+             output = model(data)
+             # Accumulate the configured loss once per batch.
+             batch_loss = calculate_loss(output, target, loss_type=loss_type)
+             loss += batch_loss.item()
+             total_samples += data.size(0)
+             pred = torch.where(output >= 0.5,
+                                torch.Tensor([1.0]).to(device).float(),
+                                torch.Tensor([0.0]).to(device).float())
+             correct += pred.eq(target.view_as(pred)).sum().item()
+             batch_prediction_pos_prob = torch.sigmoid(output).cpu().numpy()
+             prediction_pos_probs.extend(batch_prediction_pos_prob.tolist())
+             all_labels.extend(target.cpu().numpy().tolist())
+             mean_loss = loss / total_samples
+             acc = correct / total_samples
+             end_time = time.time()
+             test_time = end_time - start_time
+             print(f'\rTest: epoch: {epoch} Accuracy: {acc:.5f} batch: {batch_idx}/{len(loader)} loss: {mean_loss:.5f} time {test_time:.5f}', end='\r', flush=True)
+     loss /= len(loader)
+     data_df = classification_metrics(all_labels, prediction_pos_probs, loader_name, loss, epoch)
+     return data_df, prediction_pos_probs, all_labels
+
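As a rough usage sketch (not part of the package), evaluate_model_core expects a loader that yields (image, label, filename) triples and a model that returns one logit per image; the toy dataset and network below are hypothetical stand-ins for spacr's own loaders and models, and assume spacr's calculate_loss/classification_metrics accept binary logits of shape (N,).

import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset

class ToyImageSet(Dataset):  # hypothetical stand-in for a spacr dataset
    def __init__(self, n=16, channels=3, size=64):
        self.x = torch.randn(n, channels, size, size)
        self.y = torch.randint(0, 2, (n,)).float()
    def __len__(self):
        return len(self.x)
    def __getitem__(self, i):
        return self.x[i], self.y[i], f"img_{i}.png"

class ToyNet(nn.Module):  # hypothetical binary classifier emitting one logit per image
    def __init__(self, channels=3, size=64):
        super().__init__()
        self.fc = nn.Linear(channels * size * size, 1)
    def forward(self, x):
        return self.fc(x.flatten(1)).squeeze(1)

loader = DataLoader(ToyImageSet(), batch_size=4)
metrics_df, pos_probs, labels = evaluate_model_core(
    ToyNet(), loader, loader_name='val', epoch=1,
    loss_type='binary_cross_entropy_with_logits')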
+ def evaluate_model_performance(loaders, model, loader_name_list, epoch, train_mode, loss_type):
+     """
+     Evaluate the performance of a model on given data loaders.
+
+     Args:
+         loaders (list): List of data loaders.
+         model: The model to evaluate.
+         loader_name_list (list): List of names for the data loaders.
+         epoch (int): The current epoch.
+         train_mode (str): The training mode ('erm' or 'irm').
+         loss_type: The type of loss function.
+
+     Returns:
+         tuple: A tuple containing the evaluation result and the time taken for evaluation.
+     """
+     start_time = time.time()
+     df_list = []
+     if train_mode == 'erm':
+         result, _, _ = evaluate_model_core(model, loaders, loader_name_list, epoch, loss_type)
+     if train_mode == 'irm':
+         for loader_index in range(0, len(loaders)):
+             loader = loaders[loader_index]
+             loader_name = loader_name_list[loader_index]
+             data_df, _, _ = evaluate_model_core(model, loader, loader_name, epoch, loss_type)
+             torch.cuda.empty_cache()
+             df_list.append(data_df)
+         result = pd.concat(df_list)
+         nc_mean = result['neg_accuracy'].mean(skipna=True)
+         pc_mean = result['pos_accuracy'].mean(skipna=True)
+         tot_mean = result['accuracy'].mean(skipna=True)
+         loss_mean = result['loss'].mean(skipna=True)
+         prauc_mean = result['prauc'].mean(skipna=True)
+         data_mean = {'accuracy': tot_mean, 'neg_accuracy': nc_mean, 'pos_accuracy': pc_mean, 'loss': loss_mean, 'prauc': prauc_mean}
+         result = pd.concat([pd.DataFrame(result), pd.DataFrame(data_mean, index=[str(epoch)+'_mean'])])
+     end_time = time.time()
+     test_time = end_time - start_time
+     return result, test_time
+
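In 'irm' mode the per-loader metric frames are stacked and an `<epoch>_mean` summary row is appended. The same pandas pattern in isolation, with made-up numbers and the column names used above, looks like this:

import pandas as pd

per_loader = [
    pd.DataFrame({'accuracy': [0.91], 'neg_accuracy': [0.93], 'pos_accuracy': [0.89],
                  'loss': [0.31], 'prauc': [0.90]}, index=['plate1']),
    pd.DataFrame({'accuracy': [0.87], 'neg_accuracy': [0.90], 'pos_accuracy': [0.84],
                  'loss': [0.38], 'prauc': [0.86]}, index=['plate2']),
]
result = pd.concat(per_loader)
mean_row = pd.DataFrame(result.mean(skipna=True).to_dict(), index=['1_mean'])
print(pd.concat([result, mean_row]))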
+ def test_model_core(model, loader, loader_name, epoch, loss_type):
+     """
+     Evaluates the model on a loader and additionally returns a per-image DataFrame
+     with filenames, true labels, predicted labels and positive-class probabilities.
+     """
+     from .utils import calculate_loss, classification_metrics
+     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+     model.eval()
+     loss = 0
+     correct = 0
+     total_samples = 0
+     prediction_pos_probs = []
+     all_labels = []
+     filenames = []
+     true_targets = []
+     predicted_outputs = []
+
+     model = model.to(device)
+     with torch.no_grad():
+         for batch_idx, (data, target, filename) in enumerate(loader, start=1):  # loader also provides filenames
+             start_time = time.time()
+             data, target = data.to(device), target.to(device).float()
+             output = model(data)
+             # Accumulate the configured loss once per batch.
+             batch_loss = calculate_loss(output, target, loss_type=loss_type)
+             loss += batch_loss.item()
+             total_samples += data.size(0)
+             pred = torch.where(output >= 0.5,
+                                torch.Tensor([1.0]).to(device).float(),
+                                torch.Tensor([0.0]).to(device).float())
+             correct += pred.eq(target.view_as(pred)).sum().item()
+             batch_prediction_pos_prob = torch.sigmoid(output).cpu().numpy()
+             prediction_pos_probs.extend(batch_prediction_pos_prob.tolist())
+             all_labels.extend(target.cpu().numpy().tolist())
+
+             # Store intermediate results in lists
+             true_targets.extend(target.cpu().numpy().tolist())
+             predicted_outputs.extend(pred.cpu().numpy().tolist())
+             filenames.extend(filename)
+
+             mean_loss = loss / total_samples
+             acc = correct / total_samples
+             end_time = time.time()
+             test_time = end_time - start_time
+             print(f'\rTest: epoch: {epoch} Accuracy: {acc:.5f} batch: {batch_idx}/{len(loader)} loss: {mean_loss:.5f} time {test_time:.5f}', end='\r', flush=True)
+
+     # Construct the per-image results DataFrame
+     results_df = pd.DataFrame({
+         'filename': filenames,
+         'true_label': true_targets,
+         'predicted_label': predicted_outputs,
+         'class_1_probability': prediction_pos_probs})
+
+     loss /= len(loader)
+     data_df = classification_metrics(all_labels, prediction_pos_probs, loader_name, loss, epoch)
+     return data_df, prediction_pos_probs, all_labels, results_df
+
+ def test_model_performance(loaders, model, loader_name_list, epoch, train_mode, loss_type):
+     """
+     Test the performance of a model on given data loaders.
+
+     Args:
+         loaders (list): List of data loaders.
+         model: The model to be tested.
+         loader_name_list (list): List of names for the data loaders.
+         epoch (int): The current epoch.
+         train_mode (str): The training mode ('erm' or 'irm').
+         loss_type: The type of loss function.
+
+     Returns:
+         tuple: A tuple containing the test results and the results dataframe.
+     """
+     start_time = time.time()
+     df_list = []
+     if train_mode == 'erm':
+         result, prediction_pos_probs, all_labels, results_df = test_model_core(model, loaders, loader_name_list, epoch, loss_type)
+     if train_mode == 'irm':
+         for loader_index in range(0, len(loaders)):
+             loader = loaders[loader_index]
+             loader_name = loader_name_list[loader_index]
+             data_df, prediction_pos_probs, all_labels, results_df = test_model_core(model, loader, loader_name, epoch, loss_type)
+             torch.cuda.empty_cache()
+             df_list.append(data_df)
+         result = pd.concat(df_list)
+         nc_mean = result['neg_accuracy'].mean(skipna=True)
+         pc_mean = result['pos_accuracy'].mean(skipna=True)
+         tot_mean = result['accuracy'].mean(skipna=True)
+         loss_mean = result['loss'].mean(skipna=True)
+         prauc_mean = result['prauc'].mean(skipna=True)
+         data_mean = {'accuracy': tot_mean, 'neg_accuracy': nc_mean, 'pos_accuracy': pc_mean, 'loss': loss_mean, 'prauc': prauc_mean}
+         result = pd.concat([pd.DataFrame(result), pd.DataFrame(data_mean, index=[str(epoch)+'_mean'])])
+     end_time = time.time()
+     test_time = end_time - start_time
+     return result, results_df
+
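The second return value is the per-image frame built in test_model_core, so misclassified images can be listed directly once a test loader and model are available (a sketch; `test_loader` and `model` are assumed stand-ins like the toy objects shown earlier):

_, results_df = test_model_performance(loaders=test_loader, model=model,
                                       loader_name_list='test', epoch=1,
                                       train_mode='erm',
                                       loss_type='binary_cross_entropy_with_logits')
misclassified = results_df[results_df['true_label'] != results_df['predicted_label']]
print(misclassified[['filename', 'class_1_probability']])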
+ def train_test_model(src, settings, custom_model=False, custom_model_path=None):
+
+     from .io import _save_settings, _copy_missclassified
+     from .utils import pick_best_model
+     from .core import generate_loaders
+
+     settings['src'] = src
+     settings_df = pd.DataFrame(list(settings.items()), columns=['Key', 'Value'])
+     settings_csv = os.path.join(src, 'settings', 'train_test_model_settings.csv')
+     os.makedirs(os.path.join(src, 'settings'), exist_ok=True)
+     settings_df.to_csv(settings_csv, index=False)
+
+     model = None  # loaded below unless a custom model is supplied
+     if custom_model:
+         model = torch.load(custom_model_path)
+
+     if settings['train']:
+         _save_settings(settings, src)
+     torch.cuda.empty_cache()
+     torch.cuda.memory.empty_cache()
+     gc.collect()
+     dst = os.path.join(src, 'model')
+     os.makedirs(dst, exist_ok=True)
+     settings['src'] = src
+     settings['dst'] = dst
+     if settings['train']:
+         train, val, plate_names = generate_loaders(src,
+                                                    train_mode=settings['train_mode'],
+                                                    mode='train',
+                                                    image_size=settings['image_size'],
+                                                    batch_size=settings['batch_size'],
+                                                    classes=settings['classes'],
+                                                    num_workers=settings['num_workers'],
+                                                    validation_split=settings['val_split'],
+                                                    pin_memory=settings['pin_memory'],
+                                                    normalize=settings['normalize'],
+                                                    channels=settings['channels'],
+                                                    verbose=settings['verbose'])
+
+     if settings['test']:
+         test, _, plate_names_test = generate_loaders(src,
+                                                      train_mode=settings['train_mode'],
+                                                      mode='test',
+                                                      image_size=settings['image_size'],
+                                                      batch_size=settings['batch_size'],
+                                                      classes=settings['classes'],
+                                                      num_workers=settings['num_workers'],
+                                                      validation_split=0.0,
+                                                      pin_memory=settings['pin_memory'],
+                                                      normalize=settings['normalize'],
+                                                      channels=settings['channels'],
+                                                      verbose=settings['verbose'])
+         if model is None:
+             model_path = pick_best_model(src + '/model')
+             print(f'Best model: {model_path}')
+             model = torch.load(model_path, map_location=lambda storage, loc: storage)
+
+         model_type = settings['model_type']
+         device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+         print(type(model))
+         print(model)
+
+         model_fldr = os.path.join(src, 'model')
+         time_now = datetime.date.today().strftime('%y%m%d')
+         result_loc = f'{model_fldr}/{model_type}_time_{time_now}_result.csv'
+         acc_loc = f'{model_fldr}/{model_type}_time_{time_now}_acc.csv'
+         print(f'Results will be saved in: {result_loc}')
+
+         result, accuracy = test_model_performance(loaders=test,
+                                                   model=model,
+                                                   loader_name_list='test',
+                                                   epoch=1,
+                                                   train_mode=settings['train_mode'],
+                                                   loss_type=settings['loss_type'])
+
+         result.to_csv(result_loc, index=True, header=True, mode='w')
+         accuracy.to_csv(acc_loc, index=True, header=True, mode='w')
+         _copy_missclassified(accuracy)
+     else:
+         test = None
+
+     if settings['train']:
+         train_model(dst=settings['dst'],
+                     model_type=settings['model_type'],
+                     train_loaders=train,
+                     train_loader_names=plate_names,
+                     train_mode=settings['train_mode'],
+                     epochs=settings['epochs'],
+                     learning_rate=settings['learning_rate'],
+                     init_weights=settings['init_weights'],
+                     weight_decay=settings['weight_decay'],
+                     amsgrad=settings['amsgrad'],
+                     optimizer_type=settings['optimizer_type'],
+                     use_checkpoint=settings['use_checkpoint'],
+                     dropout_rate=settings['dropout_rate'],
+                     num_workers=settings['num_workers'],
+                     val_loaders=val,
+                     test_loaders=test,
+                     intermedeate_save=settings['intermedeate_save'],
+                     schedule=settings['schedule'],
+                     loss_type=settings['loss_type'],
+                     gradient_accumulation=settings['gradient_accumulation'],
+                     gradient_accumulation_steps=settings['gradient_accumulation_steps'])
+
+     torch.cuda.empty_cache()
+     torch.cuda.memory.empty_cache()
+     gc.collect()
+
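A rough sketch of how train_test_model might be called; the keys below are the ones read by this function and by train_model, but the values (and the path) are illustrative placeholders rather than recommended settings.

settings = {
    'train': True, 'test': False, 'train_mode': 'erm',
    'model_type': 'maxvit', 'image_size': 224, 'batch_size': 64,
    'classes': ['nc', 'pc'], 'channels': [1, 2, 3], 'normalize': True,
    'num_workers': 8, 'val_split': 0.1, 'pin_memory': True, 'verbose': False,
    'epochs': 100, 'learning_rate': 1e-4, 'weight_decay': 0.05, 'amsgrad': False,
    'optimizer_type': 'adamw', 'use_checkpoint': False, 'dropout_rate': 0,
    'init_weights': 'imagenet', 'intermedeate_save': None, 'schedule': None,
    'loss_type': 'binary_cross_entropy_with_logits',
    'gradient_accumulation': False, 'gradient_accumulation_steps': 4,
}
train_test_model('/path/to/experiment_folder', settings)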
+ def train_model(dst, model_type, train_loaders, train_loader_names, train_mode='erm', epochs=100, learning_rate=0.0001, weight_decay=0.05, amsgrad=False, optimizer_type='adamw', use_checkpoint=False, dropout_rate=0, num_workers=20, val_loaders=None, test_loaders=None, init_weights='imagenet', intermedeate_save=None, chan_dict=None, schedule=None, loss_type='binary_cross_entropy_with_logits', gradient_accumulation=False, gradient_accumulation_steps=4):
+     """
+     Trains a model using the specified parameters.
+
+     Args:
+         dst (str): The destination path to save the model and results.
+         model_type (str): The type of model to train.
+         train_loaders (list): A list of training data loaders.
+         train_loader_names (list): A list of names for the training data loaders.
+         train_mode (str, optional): The training mode. Defaults to 'erm'.
+         epochs (int, optional): The number of training epochs. Defaults to 100.
+         learning_rate (float, optional): The learning rate for the optimizer. Defaults to 0.0001.
+         weight_decay (float, optional): The weight decay for the optimizer. Defaults to 0.05.
+         amsgrad (bool, optional): Whether to use AMSGrad for the optimizer. Defaults to False.
+         optimizer_type (str, optional): The type of optimizer to use. Defaults to 'adamw'.
+         use_checkpoint (bool, optional): Whether to use checkpointing during training. Defaults to False.
+         dropout_rate (float, optional): The dropout rate for the model. Defaults to 0.
+         num_workers (int, optional): The number of workers for data loading. Defaults to 20.
+         val_loaders (list, optional): A list of validation data loaders. Defaults to None.
+         test_loaders (list, optional): A list of test data loaders. Defaults to None.
+         init_weights (str, optional): The initialization weights for the model. Defaults to 'imagenet'.
+         intermedeate_save (list, optional): The intermediate save thresholds. Defaults to None.
+         chan_dict (dict, optional): The channel dictionary. Defaults to None.
+         schedule (str, optional): The learning rate schedule. Defaults to None.
+         loss_type (str, optional): The loss function type. Defaults to 'binary_cross_entropy_with_logits'.
+         gradient_accumulation (bool, optional): Whether to use gradient accumulation. Defaults to False.
+         gradient_accumulation_steps (int, optional): The number of steps for gradient accumulation. Defaults to 4.
+
+     Returns:
+         None
+     """
+
+     from .io import _save_model, _save_progress
+     from .utils import compute_irm_penalty, calculate_loss, choose_model
+
+     print(f'Train batches:{len(train_loaders)}, Validation batches:{len(val_loaders)}')
+
+     if test_loaders is not None:
+         print(f'Test batches:{len(test_loaders)}')
+
+     use_cuda = torch.cuda.is_available()
+     device = torch.device("cuda" if use_cuda else "cpu")
+     kwargs = {'num_workers': num_workers, 'pin_memory': True} if use_cuda else {}
+
+     for idx, (images, labels, filenames) in enumerate(train_loaders):
+         batch, channels, height, width = images.shape
+         break
+
+     model = choose_model(model_type, device, init_weights, dropout_rate, use_checkpoint)
+
+     if model is None:
+         print(f'Model {model_type} not found')
+         return
+
+     model.to(device)
+
+     if optimizer_type == 'adamw':
+         optimizer = AdamW(model.parameters(), lr=learning_rate, betas=(0.9, 0.999), weight_decay=weight_decay, amsgrad=amsgrad)
+
+     if optimizer_type == 'adagrad':
+         optimizer = Adagrad(model.parameters(), lr=learning_rate, eps=1e-8, weight_decay=weight_decay)
+
+     if schedule == 'step_lr':
+         StepLR_step_size = int(epochs / 5)
+         StepLR_gamma = 0.75
+         scheduler = StepLR(optimizer, step_size=StepLR_step_size, gamma=StepLR_gamma)
+     elif schedule == 'reduce_lr_on_plateau':
+         scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=10, verbose=True)
+     else:
+         scheduler = None
+
+     if train_mode == 'erm':
+         for epoch in range(1, epochs + 1):
+             model.train()
+             start_time = time.time()
+             running_loss = 0.0
+
+             # Initialize gradients if using gradient accumulation
+             if gradient_accumulation:
+                 optimizer.zero_grad()
+
+             for batch_idx, (data, target, filenames) in enumerate(train_loaders, start=1):
+                 data, target = data.to(device), target.to(device).float()
+                 output = model(data)
+                 loss = calculate_loss(output, target, loss_type=loss_type)
+                 # Normalize the loss if using gradient accumulation
+                 if gradient_accumulation:
+                     loss /= gradient_accumulation_steps
+                     running_loss += loss.item() * gradient_accumulation_steps  # undo the scaling for logging
+                 else:
+                     running_loss += loss.item()
+                 loss.backward()
+
+                 # Step the optimizer every batch, or every gradient_accumulation_steps batches
+                 if not gradient_accumulation or (batch_idx % gradient_accumulation_steps == 0):
+                     optimizer.step()
+                     optimizer.zero_grad()
+
+                 avg_loss = running_loss / batch_idx
+                 print(f'\rTrain: epoch: {epoch} batch: {batch_idx}/{len(train_loaders)} avg_loss: {avg_loss:.5f} time: {(time.time()-start_time):.5f}', end='\r', flush=True)
+
+             end_time = time.time()
+             train_time = end_time - start_time
+             train_metrics = {'epoch': epoch, 'loss': loss.cpu().item(), 'train_time': train_time}
+             train_metrics_df = pd.DataFrame(train_metrics, index=[epoch])
+             train_names = 'train'
+             results_df, train_test_time = evaluate_model_performance(train_loaders, model, train_names, epoch, train_mode='erm', loss_type=loss_type)
+             train_metrics_df['train_test_time'] = train_test_time
+             if val_loaders is not None:
+                 val_names = 'val'
+                 result, val_time = evaluate_model_performance(val_loaders, model, val_names, epoch, train_mode='erm', loss_type=loss_type)
+
+                 if schedule == 'reduce_lr_on_plateau':
+                     val_loss = result['loss']
+
+                 results_df = pd.concat([results_df, result])
+                 train_metrics_df['val_time'] = val_time
+             if test_loaders is not None:
+                 test_names = 'test'
+                 result, test_test_time = evaluate_model_performance(test_loaders, model, test_names, epoch, train_mode='erm', loss_type=loss_type)
+                 results_df = pd.concat([results_df, result])
+                 test_time = (train_test_time + val_time + test_test_time) / 3
+                 train_metrics_df['test_time'] = test_time
+
+             if scheduler:
+                 if schedule == 'reduce_lr_on_plateau':
+                     scheduler.step(val_loss)
+                 if schedule == 'step_lr':
+                     scheduler.step()
+
+             _save_progress(dst, results_df, train_metrics_df)
+             clear_output(wait=True)
+             display(results_df)
+             _save_model(model, model_type, results_df, dst, epoch, epochs, intermedeate_save=[0.99, 0.98, 0.95, 0.94])
+
+     if train_mode == 'irm':
+         dummy_w = torch.nn.Parameter(torch.Tensor([1.0])).to(device)
+         phi = torch.nn.Parameter(torch.ones(4, 1))
+         for epoch in range(1, epochs):
+             model.train()
+             penalty_factor = epoch * 1e-5
+             epoch_names = [str(epoch) + '_' + item for item in train_loader_names]
+             loader_erm_loss_list = []
+             total_erm_loss_mean = 0
+             for loader_index in range(0, len(train_loaders)):
+                 start_time = time.time()
+                 loader = train_loaders[loader_index]
+                 loader_erm_loss_mean = 0
+                 batch_count = 0
+                 batch_erm_loss_list = []
+                 for batch_idx, (data, target, filenames) in enumerate(loader, start=1):
+                     optimizer.zero_grad()
+                     data, target = data.to(device), target.to(device).float()
+
+                     output = model(data)
+                     erm_loss = F.binary_cross_entropy_with_logits(output * dummy_w, target, reduction='none')
+
+                     batch_erm_loss_list.append(erm_loss.mean())
+                     print(f'\repoch: {epoch} loader: {loader_index} batch: {batch_idx}/{len(loader)}', end='\r', flush=True)
+                 loader_erm_loss_mean = torch.stack(batch_erm_loss_list).mean()
+                 loader_erm_loss_list.append(loader_erm_loss_mean)
+             total_erm_loss_mean = torch.stack(loader_erm_loss_list).mean()
+             irm_loss = compute_irm_penalty(loader_erm_loss_list, dummy_w, device)
+
+             (total_erm_loss_mean + penalty_factor * irm_loss).backward()
+             optimizer.step()
+
+             end_time = time.time()
+             train_time = end_time - start_time
+
+             train_metrics = {'epoch': epoch, 'irm_loss': irm_loss, 'erm_loss': total_erm_loss_mean, 'penalty_factor': penalty_factor, 'train_time': train_time}
+             #train_metrics = {'epoch':epoch,'irm_loss':irm_loss.cpu().item(),'erm_loss':total_erm_loss_mean.cpu().item(),'penalty_factor':penalty_factor, 'train_time':train_time}
+             train_metrics_df = pd.DataFrame(train_metrics, index=[epoch])
+             print(f'\rTrain: epoch: {epoch} loader: {loader_index} batch: {batch_idx}/{len(loader)} irm_loss: {irm_loss:.5f} mean_erm_loss: {total_erm_loss_mean:.5f} train time {train_time:.5f}', end='\r', flush=True)
+
+             train_names = [item + '_train' for item in train_loader_names]
+             results_df, train_test_time = evaluate_model_performance(train_loaders, model, train_names, epoch, train_mode='irm', loss_type=loss_type)
+             train_metrics_df['train_test_time'] = train_test_time
+
+             if val_loaders is not None:
+                 val_names = [item + '_val' for item in train_loader_names]
+                 result, val_time = evaluate_model_performance(val_loaders, model, val_names, epoch, train_mode='irm', loss_type=loss_type)
+
+                 if schedule == 'reduce_lr_on_plateau':
+                     val_loss = result['loss']
+
+                 results_df = pd.concat([results_df, result])
+                 train_metrics_df['val_time'] = val_time
+
+             if test_loaders is not None:
+                 test_names = [item + '_test' for item in train_loader_names]  # test_loader_names?
+                 result, test_test_time = evaluate_model_performance(test_loaders, model, test_names, epoch, train_mode='irm', loss_type=loss_type)
+                 results_df = pd.concat([results_df, result])
+                 train_metrics_df['test_test_time'] = test_test_time
+
+             if scheduler:
+                 if schedule == 'reduce_lr_on_plateau':
+                     scheduler.step(val_loss)
+                 if schedule == 'step_lr':
+                     scheduler.step()
+
+             clear_output(wait=True)
+             display(results_df)
+             _save_progress(dst, results_df, train_metrics_df)
+             _save_model(model, model_type, results_df, dst, epoch, epochs, intermedeate_save=[0.99, 0.98, 0.95, 0.94])
+             print(f'Saved model: {dst}')
+     return
+
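The ERM branch scales each batch loss by 1/gradient_accumulation_steps and only steps the optimizer every gradient_accumulation_steps batches, which mimics a larger effective batch size. The same bookkeeping in isolation, on a toy model:

import torch
from torch import nn

model = nn.Linear(10, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
accum_steps = 4

optimizer.zero_grad()
for batch_idx in range(1, 13):
    x, y = torch.randn(8, 10), torch.randint(0, 2, (8, 1)).float()
    loss = nn.functional.binary_cross_entropy_with_logits(model(x), y)
    (loss / accum_steps).backward()      # gradients sum over accum_steps batches
    if batch_idx % accum_steps == 0:     # one optimizer step per "virtual" large batch
        optimizer.step()
        optimizer.zero_grad()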
+ def visualize_saliency_map(src, model_type='maxvit', model_path='', image_size=224, channels=[1,2,3], normalize=True, class_names=None, save_saliency=False, save_dir='saliency_maps'):
+
+     from spacr.utils import SaliencyMapGenerator, preprocess_image
+
+     use_cuda = torch.cuda.is_available()
+     device = torch.device("cuda" if use_cuda else "cpu")
+
+     # Load the entire model object
+     model = torch.load(model_path)
+     model.to(device)
+
+     # Create directory for saving saliency maps if it does not exist
+     if save_saliency and not os.path.exists(save_dir):
+         os.makedirs(save_dir)
+
+     # Collect all images and their tensors
+     images = []
+     input_tensors = []
+     filenames = []
+     for file in os.listdir(src):
+         if not file.endswith('.png'):
+             continue
+         image_path = os.path.join(src, file)
+         image, input_tensor = preprocess_image(image_path, normalize=normalize, image_size=image_size, channels=channels)
+         images.append(image)
+         input_tensors.append(input_tensor)
+         filenames.append(file)
+
+     input_tensors = torch.cat(input_tensors).to(device)
+     class_labels = torch.zeros(input_tensors.size(0), dtype=torch.long).to(device)  # Replace with actual class labels if available
+
+     # Generate saliency maps
+     cam_generator = SaliencyMapGenerator(model)
+     saliency_maps = cam_generator.compute_saliency_maps(input_tensors, class_labels)
+
+     # Convert saliency maps to numpy arrays
+     saliency_maps = saliency_maps.cpu().numpy()
+
+     N = len(images)
+
+     dst = os.path.join(src, 'saliency_maps')
+
+     for i in range(N):
+         fig, axes = plt.subplots(1, 3, figsize=(20, 5))
+
+         # Original image
+         axes[0].imshow(images[i])
+         axes[0].axis('off')
+         if class_names:
+             axes[0].set_title(f"Class: {class_names[class_labels[i].item()]}")
+
+         # Saliency Map
+         axes[1].imshow(saliency_maps[i, 0], cmap='hot')
+         axes[1].axis('off')
+         axes[1].set_title("Saliency Map")
+
+         # Overlay
+         overlay = np.array(images[i])
+         overlay = overlay / overlay.max()
+         saliency_map_rgb = np.stack([saliency_maps[i, 0]] * 3, axis=-1)  # Convert saliency map to RGB
+         overlay = (overlay * 0.5 + saliency_map_rgb * 0.5).clip(0, 1)
+         axes[2].imshow(overlay)
+         axes[2].axis('off')
+         axes[2].set_title("Overlay")
+
+         plt.tight_layout()
+         plt.show()
+
+         # Save the saliency map if required
+         if save_saliency:
+             os.makedirs(dst, exist_ok=True)
+             saliency_image = Image.fromarray((saliency_maps[i, 0] * 255).astype(np.uint8))
+             saliency_image.save(os.path.join(dst, f'saliency_{filenames[i]}'))
+
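SaliencyMapGenerator lives in spacr.utils and is not shown in this diff. For orientation, a plain vanilla-gradient saliency map of the kind visualized here can be computed as follows (a generic sketch, not the package's implementation):

import torch

def vanilla_saliency(model, images, labels):
    """Absolute input gradient of the target score, reduced over channels."""
    model.eval()
    images = images.clone().requires_grad_(True)
    scores = model(images)
    if scores.dim() > 1 and scores.size(1) > 1:
        scores = scores.gather(1, labels.view(-1, 1)).squeeze(1)
    scores.sum().backward()
    return images.grad.abs().amax(dim=1)  # one (H, W) map per image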
+ def visualize_grad_cam(src, model_path, target_layers=None, image_size=224, channels=[1, 2, 3], normalize=True, class_names=None, save_cam=False, save_dir='grad_cam'):
+
+     from spacr.utils import GradCAM, preprocess_image, show_cam_on_image, recommend_target_layers
+
+     use_cuda = torch.cuda.is_available()
+     device = torch.device("cuda" if use_cuda else "cpu")
+
+     model = torch.load(model_path)
+     model.to(device)
+
+     # If no target layers provided, recommend a target layer
+     if target_layers is None:
+         target_layers, all_layers = recommend_target_layers(model)
+         print(f"No target layer provided. Using recommended layer: {target_layers[0]}")
+         print("All possible target layers:")
+         for layer in all_layers:
+             print(layer)
+
+     grad_cam = GradCAM(model=model, target_layers=target_layers, use_cuda=use_cuda)
+
+     if save_cam and not os.path.exists(save_dir):
+         os.makedirs(save_dir)
+
+     images = []
+     filenames = []
+     for file in os.listdir(src):
+         if not file.endswith('.png'):
+             continue
+         image_path = os.path.join(src, file)
+         image, input_tensor = preprocess_image(image_path, normalize=normalize, image_size=image_size, channels=channels)
+         images.append(image)
+         filenames.append(file)
+
+         input_tensor = input_tensor.to(device)
+         cam = grad_cam(input_tensor)
+         cam_image = show_cam_on_image(np.array(image) / 255.0, cam)
+
+         fig, ax = plt.subplots(1, 2, figsize=(10, 5))
+         ax[0].imshow(image)
+         ax[0].axis('off')
+         ax[0].set_title("Original Image")
+         ax[1].imshow(cam_image)
+         ax[1].axis('off')
+         ax[1].set_title("Grad-CAM")
+         plt.show()
+
+         if save_cam:
+             cam_pil = Image.fromarray(cam_image)
+             cam_pil.save(os.path.join(save_dir, f'grad_cam_{file}'))
+
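An illustrative call (the paths are placeholders); leaving target_layers as None exercises the recommend_target_layers fallback shown above:

visualize_grad_cam(
    src='/path/to/png_folder',
    model_path='/path/to/model.pth',
    target_layers=None,      # let recommend_target_layers pick one
    image_size=224,
    channels=[1, 2, 3],
    normalize=True,
    class_names=None,
    save_cam=True,
    save_dir='grad_cam')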
+ def visualize_classes(model, dtype, class_names, **kwargs):
+
+     from spacr.utils import class_visualization
+
+     for target_y in range(2):  # Assuming binary classification
+         print(f"Visualizing class: {class_names[target_y]}")
+         visualization = class_visualization(target_y, model, dtype, **kwargs)
+         plt.imshow(visualization)
+         plt.title(f"Class {class_names[target_y]} Visualization")
+         plt.axis('off')
+         plt.show()
+
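class_visualization is imported from spacr.utils and its body is not part of this diff. Conceptually, class visualization runs gradient ascent on an input image to maximize the target class score; a bare-bones version of that idea (assuming the model emits one logit per class) looks like:

import torch

def naive_class_visualization(model, target_y, steps=100, lr=0.1, image_size=224):
    """Toy gradient ascent on a random image to maximize the score of class target_y."""
    img = torch.randn(1, 3, image_size, image_size, requires_grad=True)
    for _ in range(steps):
        model.zero_grad()
        score = model(img)[0, target_y]
        score.backward()
        with torch.no_grad():
            img += lr * img.grad / (img.grad.norm() + 1e-8)  # normalized ascent step
            img.grad.zero_()
    return img.detach().squeeze(0).clamp(0, 1).permute(1, 2, 0).numpy()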
+ def visualize_integrated_gradients(src, model_path, target_label_idx=0, image_size=224, channels=[1,2,3], normalize=True, save_integrated_grads=False, save_dir='integrated_grads'):
+
+     from .utils import IntegratedGradients, preprocess_image
+
+     use_cuda = torch.cuda.is_available()
+     device = torch.device("cuda" if use_cuda else "cpu")
+
+     model = torch.load(model_path)
+     model.to(device)
+     integrated_gradients = IntegratedGradients(model)
+
+     if save_integrated_grads and not os.path.exists(save_dir):
+         os.makedirs(save_dir)
+
+     images = []
+     filenames = []
+     for file in os.listdir(src):
+         if not file.endswith('.png'):
+             continue
+         image_path = os.path.join(src, file)
+         image, input_tensor = preprocess_image(image_path, normalize=normalize, image_size=image_size, channels=channels)
+         images.append(image)
+         filenames.append(file)
+
+         input_tensor = input_tensor.to(device)
+         integrated_grads = integrated_gradients.generate_integrated_gradients(input_tensor, target_label_idx)
+         integrated_grads = np.mean(integrated_grads, axis=1).squeeze()
+
+         fig, ax = plt.subplots(1, 3, figsize=(20, 5))
+         ax[0].imshow(image)
+         ax[0].axis('off')
+         ax[0].set_title("Original Image")
+         ax[1].imshow(integrated_grads, cmap='hot')
+         ax[1].axis('off')
+         ax[1].set_title("Integrated Gradients")
+         overlay = np.array(image)
+         overlay = overlay / overlay.max()
+         integrated_grads_rgb = np.stack([integrated_grads] * 3, axis=-1)  # Convert the attribution map to RGB
+         overlay = (overlay * 0.5 + integrated_grads_rgb * 0.5).clip(0, 1)
+         ax[2].imshow(overlay)
+         ax[2].axis('off')
+         ax[2].set_title("Overlay")
+         plt.show()
+
+         if save_integrated_grads:
+             os.makedirs(save_dir, exist_ok=True)
+             integrated_grads_image = Image.fromarray((integrated_grads * 255).astype(np.uint8))
+             integrated_grads_image.save(os.path.join(save_dir, f'integrated_grads_{file}'))
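Here IntegratedGradients comes from spacr.utils. For reference, the standard formulation averages input gradients along a straight path from a baseline to the input and scales by (input - baseline); a minimal generic version (not the package's implementation) is:

import torch

def integrated_gradients(model, x, target_idx, steps=50, baseline=None):
    """Textbook integrated gradients for a batch of inputs (sketch)."""
    baseline = torch.zeros_like(x) if baseline is None else baseline
    grads = []
    for alpha in torch.linspace(0.0, 1.0, steps):
        # Point on the straight-line path from baseline to x
        point = (baseline + alpha * (x - baseline)).requires_grad_(True)
        scores = model(point)
        score = scores[:, target_idx].sum() if scores.dim() > 1 else scores.sum()
        grads.append(torch.autograd.grad(score, point)[0])
    avg_grad = torch.stack(grads).mean(dim=0)
    return (x - baseline) * avg_grad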