spacr 0.0.66__py3-none-any.whl → 0.0.71__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
spacr/train.py DELETED
@@ -1,667 +0,0 @@
1
- import os, torch, time, gc, datetime
2
- import pandas as pd
3
- from torch.optim import Adagrad
4
- from torch.optim import AdamW
5
- from torch.autograd import grad
6
- from torch.optim.lr_scheduler import StepLR
7
- import torch.nn.functional as F
8
- from IPython.display import display, clear_output
9
- import difflib
10
-
11
- from .logger import log_function_call
12
-
13
def evaluate_model_core(model, loader, loader_name, epoch, loss_type):
    """
    Evaluate a model on a single data loader.

    Args:
        model (torch.nn.Module): The model to evaluate. It is moved to CUDA
            when available and switched to eval mode.
        loader: Iterable of ``(data, target, extra)`` batches that also
            supports ``len()``. The third item of every batch is ignored.
        loader_name (str): Label forwarded to ``classification_metrics``.
        epoch (int): Current epoch number (used for logging and metrics).
        loss_type (str): Loss identifier forwarded to ``calculate_loss``.

    Returns:
        tuple: ``(data_df, prediction_pos_probs, all_labels)`` where
        ``data_df`` is whatever ``classification_metrics`` returns,
        ``prediction_pos_probs`` holds ``sigmoid(output)`` for every sample
        and ``all_labels`` holds the matching targets.
    """
    from .utils import calculate_loss, classification_metrics

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    model.eval()

    # Running sum of per-batch loss values. The previous implementation
    # rebound the accumulator to the current batch loss on every iteration,
    # so only the last batch contributed to the reported loss.
    total_loss = 0.0
    correct = 0
    total_samples = 0
    prediction_pos_probs = []
    all_labels = []

    with torch.no_grad():
        for batch_idx, (data, target, _) in enumerate(loader, start=1):
            start_time = time.time()
            data, target = data.to(device), target.to(device).float()
            output = model(data)

            batch_loss = calculate_loss(output, target, loss_type=loss_type)
            total_loss += batch_loss.item()
            total_samples += data.size(0)

            # NOTE(review): the threshold is applied to the raw model output,
            # while the probabilities below are sigmoid(output). If outputs
            # are logits, 0.5 here corresponds to p ~= 0.62 - confirm intent.
            pred = (output >= 0.5).float()
            correct += pred.eq(target.view_as(pred)).sum().item()

            prediction_pos_probs.extend(torch.sigmoid(output).cpu().numpy().tolist())
            all_labels.extend(target.cpu().numpy().tolist())

            # Running mean of the per-batch losses seen so far.
            mean_loss = total_loss / batch_idx
            acc = correct / total_samples
            test_time = time.time() - start_time  # duration of this batch only
            print(f'\rTest: epoch: {epoch} Accuracy: {acc:.5f} batch: {batch_idx}/{len(loader)} loss: {mean_loss:.5f} time {test_time:.5f}', end='\r', flush=True)

    # Average loss per batch; guard against an empty loader.
    n_batches = len(loader)
    loss = total_loss / n_batches if n_batches else 0.0
    data_df = classification_metrics(all_labels, prediction_pos_probs, loader_name, loss, epoch)
    return data_df, prediction_pos_probs, all_labels
65
-
66
def evaluate_model_performance(loaders, model, loader_name_list, epoch, train_mode, loss_type):
    """
    Evaluate the performance of a model on given data loaders.

    Args:
        loaders: A single data loader (``train_mode='erm'``) or a list of
            data loaders (``train_mode='irm'``).
        model: The model to evaluate.
        loader_name_list: Name of the loader (``'erm'``) or a list of names,
            one per loader (``'irm'``).
        epoch (int): The current epoch.
        train_mode (str): The training mode ('erm' or 'irm').
        loss_type: The type of loss function.

    Returns:
        tuple: ``(result, test_time)`` - the metrics DataFrame and the time
        in seconds spent evaluating. In 'irm' mode a summary row indexed
        ``'<epoch>_mean'`` with the column means is appended.

    Raises:
        ValueError: If ``train_mode`` is neither 'erm' nor 'irm'.
    """
    start_time = time.time()

    if train_mode == 'erm':
        result, _, _ = evaluate_model_core(model, loaders, loader_name_list, epoch, loss_type)
    elif train_mode == 'irm':
        df_list = []
        for loader_index, loader in enumerate(loaders):
            loader_name = loader_name_list[loader_index]
            data_df, _, _ = evaluate_model_core(model, loader, loader_name, epoch, loss_type)
            torch.cuda.empty_cache()
            df_list.append(data_df)
        result = pd.concat(df_list)

        # Append one row holding the mean of every metric across loaders.
        metric_columns = ['accuracy', 'neg_accuracy', 'pos_accuracy', 'loss', 'prauc']
        data_mean = {column: result[column].mean(skipna=True) for column in metric_columns}
        result = pd.concat([pd.DataFrame(result), pd.DataFrame(data_mean, index=[str(epoch)+'_mean'])])
    else:
        # Previously an unknown mode surfaced as UnboundLocalError on `result`.
        raise ValueError(f"Unsupported train_mode {train_mode!r}; expected 'erm' or 'irm'.")

    test_time = time.time() - start_time
    return result, test_time
103
-
104
def test_model_core(model, loader, loader_name, epoch, loss_type):
    """
    Run a model over a test loader and collect per-file predictions.

    Args:
        model (torch.nn.Module): The model to test. It is moved to CUDA when
            available and switched to eval mode.
        loader: Iterable of ``(data, target, filename)`` batches that also
            supports ``len()``.
        loader_name (str): Label forwarded to ``classification_metrics``.
        epoch (int): Epoch number (used for logging and metrics).
        loss_type (str): Loss identifier forwarded to ``calculate_loss``.

    Returns:
        tuple: ``(data_df, prediction_pos_probs, all_labels, results_df)``.
        ``results_df`` has the columns ``filename``, ``true_label``,
        ``predicted_label`` and ``class_1_probability``.
    """
    from .utils import calculate_loss, classification_metrics

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    model.eval()

    # Running sum of per-batch loss values. The previous implementation
    # rebound the accumulator to the current batch loss on every iteration,
    # so only the last batch contributed to the reported loss.
    total_loss = 0.0
    correct = 0
    total_samples = 0
    prediction_pos_probs = []
    all_labels = []
    filenames = []
    predicted_outputs = []

    with torch.no_grad():
        for batch_idx, (data, target, filename) in enumerate(loader, start=1):
            start_time = time.time()
            data, target = data.to(device), target.to(device).float()
            output = model(data)

            batch_loss = calculate_loss(output, target, loss_type=loss_type)
            total_loss += batch_loss.item()
            total_samples += data.size(0)

            # NOTE(review): the threshold is applied to the raw model output,
            # while the probabilities below are sigmoid(output). If outputs
            # are logits, 0.5 here corresponds to p ~= 0.62 - confirm intent.
            pred = (output >= 0.5).float()
            correct += pred.eq(target.view_as(pred)).sum().item()

            prediction_pos_probs.extend(torch.sigmoid(output).cpu().numpy().tolist())
            all_labels.extend(target.cpu().numpy().tolist())
            predicted_outputs.extend(pred.cpu().numpy().tolist())
            filenames.extend(filename)

            # Running mean of the per-batch losses seen so far.
            mean_loss = total_loss / batch_idx
            acc = correct / total_samples
            test_time = time.time() - start_time  # duration of this batch only
            print(f'\rTest: epoch: {epoch} Accuracy: {acc:.5f} batch: {batch_idx}/{len(loader)} loss: {mean_loss:.5f} time {test_time:.5f}', end='\r', flush=True)

    # One row per sample; true labels are the same values as all_labels.
    results_df = pd.DataFrame({
        'filename': filenames,
        'true_label': list(all_labels),
        'predicted_label': predicted_outputs,
        'class_1_probability': prediction_pos_probs})

    # Average loss per batch; guard against an empty loader.
    n_batches = len(loader)
    loss = total_loss / n_batches if n_batches else 0.0
    data_df = classification_metrics(all_labels, prediction_pos_probs, loader_name, loss, epoch)
    return data_df, prediction_pos_probs, all_labels, results_df
157
-
158
- def test_model_performance(loaders, model, loader_name_list, epoch, train_mode, loss_type):
159
- """
160
- Test the performance of a model on given data loaders.
161
-
162
- Args:
163
- loaders (list): List of data loaders.
164
- model: The model to be tested.
165
- loader_name_list (list): List of names for the data loaders.
166
- epoch (int): The current epoch.
167
- train_mode (str): The training mode ('erm' or 'irm').
168
- loss_type: The type of loss function.
169
-
170
- Returns:
171
- tuple: A tuple containing the test results and the results dataframe.
172
- """
173
- start_time = time.time()
174
- df_list = []
175
- if train_mode == 'erm':
176
- result, prediction_pos_probs, all_labels, results_df = test_model_core(model, loaders, loader_name_list, epoch, loss_type)
177
- if train_mode == 'irm':
178
- for loader_index in range(0, len(loaders)):
179
- loader = loaders[loader_index]
180
- loader_name = loader_name_list[loader_index]
181
- data_df, prediction_pos_probs, all_labels, results_df = test_model_core(model, loader, loader_name, epoch, loss_type)
182
- torch.cuda.empty_cache()
183
- df_list.append(data_df)
184
- result = pd.concat(df_list)
185
- nc_mean = result['neg_accuracy'].mean(skipna=True)
186
- pc_mean = result['pos_accuracy'].mean(skipna=True)
187
- tot_mean = result['accuracy'].mean(skipna=True)
188
- loss_mean = result['loss'].mean(skipna=True)
189
- prauc_mean = result['prauc'].mean(skipna=True)
190
- data_mean = {'accuracy': tot_mean, 'neg_accuracy': nc_mean, 'pos_accuracy': pc_mean, 'loss': loss_mean, 'prauc': prauc_mean}
191
- result = pd.concat([pd.DataFrame(result), pd.DataFrame(data_mean, index=[str(epoch)+'_mean'])])
192
- end_time = time.time()
193
- test_time = end_time - start_time
194
- return result, results_df
195
-
196
def train_test_model(src, settings, custom_model=False, custom_model_path=None):
    """
    Optionally test and/or train a classification model on data under ``src``.

    The settings are written to ``<src>/settings/train_test_model_settings.csv``.
    When ``settings['test']`` is truthy, a model (the custom model, or the one
    chosen by ``pick_best_model`` from ``<src>/model``) is tested first and the
    results are written as CSV files into ``<src>/model``. When
    ``settings['train']`` is truthy, ``train_model`` is then called.

    Args:
        src (str): Root folder of the dataset; also stored in ``settings['src']``.
        settings (dict): Configuration. Keys read here include ``train``,
            ``test``, ``train_mode``, ``image_size``, ``batch_size``,
            ``classes``, ``num_workers``, ``val_split``, ``pin_memory``,
            ``normalize``, ``channels``, ``verbose``, ``model_type``,
            ``loss_type`` and the training hyper-parameters forwarded to
            ``train_model``. The dict is modified in place (``src``, ``dst``).
        custom_model (bool): Load the model from ``custom_model_path``.
        custom_model_path (str, optional): Path passed to ``torch.load``.

    Returns:
        None
    """
    from .io import _save_settings, _copy_missclassified
    from .utils import pick_best_model
    from .core import generate_loaders

    settings['src'] = src
    settings_df = pd.DataFrame(list(settings.items()), columns=['Key', 'Value'])
    settings_csv = os.path.join(src,'settings','train_test_model_settings.csv')
    os.makedirs(os.path.join(src,'settings'), exist_ok=True)
    settings_df.to_csv(settings_csv, index=False)

    # `model` must always be bound: the test branch checks it for None, which
    # used to raise UnboundLocalError whenever custom_model was False.
    model = None
    if custom_model:
        # NOTE(review): torch.load unpickles arbitrary objects; only load
        # model files from trusted sources.
        model = torch.load(custom_model_path)

    if settings['train']:
        _save_settings(settings, src)

    torch.cuda.empty_cache()
    gc.collect()

    dst = os.path.join(src,'model')
    os.makedirs(dst, exist_ok=True)
    settings['src'] = src
    settings['dst'] = dst

    if settings['train']:
        train, val, plate_names = generate_loaders(src,
                                                   train_mode=settings['train_mode'],
                                                   mode='train',
                                                   image_size=settings['image_size'],
                                                   batch_size=settings['batch_size'],
                                                   classes=settings['classes'],
                                                   num_workers=settings['num_workers'],
                                                   validation_split=settings['val_split'],
                                                   pin_memory=settings['pin_memory'],
                                                   normalize=settings['normalize'],
                                                   channels=settings['channels'],
                                                   verbose=settings['verbose'])

    if settings['test']:
        test, _, plate_names_test = generate_loaders(src,
                                                     train_mode=settings['train_mode'],
                                                     mode='test',
                                                     image_size=settings['image_size'],
                                                     batch_size=settings['batch_size'],
                                                     classes=settings['classes'],
                                                     num_workers=settings['num_workers'],
                                                     validation_split=0.0,
                                                     pin_memory=settings['pin_memory'],
                                                     normalize=settings['normalize'],
                                                     channels=settings['channels'],
                                                     verbose=settings['verbose'])
        if model is None:
            model_path = pick_best_model(src+'/model')
            print(f'Best model: {model_path}')
            # Load onto CPU storage; the test routine moves the model to the
            # appropriate device itself.
            model = torch.load(model_path, map_location=lambda storage, loc: storage)

        model_type = settings['model_type']
        print(type(model))
        print(model)

        model_fldr = os.path.join(src,'model')
        time_now = datetime.date.today().strftime('%y%m%d')
        result_loc = f'{model_fldr}/{model_type}_time_{time_now}_result.csv'
        acc_loc = f'{model_fldr}/{model_type}_time_{time_now}_acc.csv'
        print(f'Results wil be saved in: {result_loc}')

        result, accuracy = test_model_performance(loaders=test,
                                                  model=model,
                                                  loader_name_list='test',
                                                  epoch=1,
                                                  train_mode=settings['train_mode'],
                                                  loss_type=settings['loss_type'])

        result.to_csv(result_loc, index=True, header=True, mode='w')
        accuracy.to_csv(acc_loc, index=True, header=True, mode='w')
        _copy_missclassified(accuracy)
    else:
        test = None

    if settings['train']:
        train_model(dst = settings['dst'],
                    model_type=settings['model_type'],
                    train_loaders = train,
                    train_loader_names = plate_names,
                    train_mode = settings['train_mode'],
                    epochs = settings['epochs'],
                    learning_rate = settings['learning_rate'],
                    init_weights = settings['init_weights'],
                    weight_decay = settings['weight_decay'],
                    amsgrad = settings['amsgrad'],
                    optimizer_type = settings['optimizer_type'],
                    use_checkpoint = settings['use_checkpoint'],
                    dropout_rate = settings['dropout_rate'],
                    num_workers = settings['num_workers'],
                    val_loaders = val,
                    test_loaders = test,
                    intermedeate_save = settings['intermedeate_save'],
                    schedule = settings['schedule'],
                    loss_type=settings['loss_type'],
                    gradient_accumulation=settings['gradient_accumulation'],
                    gradient_accumulation_steps=settings['gradient_accumulation_steps'])

    torch.cuda.empty_cache()
    gc.collect()
304
-
305
def train_model(dst, model_type, train_loaders, train_loader_names, train_mode='erm', epochs=100, learning_rate=0.0001, weight_decay=0.05, amsgrad=False, optimizer_type='adamw', use_checkpoint=False, dropout_rate=0, num_workers=20, val_loaders=None, test_loaders=None, init_weights='imagenet', intermedeate_save=None, chan_dict=None, schedule = None, loss_type='binary_cross_entropy_with_logits', gradient_accumulation=False, gradient_accumulation_steps=4):
    """
    Trains a model using the specified parameters.

    Args:
        dst (str): The destination path to save the model and results.
        model_type (str): The type of model to train (passed to ``choose_model``).
        train_loaders: A data loader ('erm') or a list of data loaders ('irm').
        train_loader_names (list): Names for the training data loaders ('irm').
        train_mode (str, optional): 'erm' or 'irm'. Defaults to 'erm'.
        epochs (int, optional): The number of training epochs. Defaults to 100.
        learning_rate (float, optional): Optimizer learning rate. Defaults to 0.0001.
        weight_decay (float, optional): Optimizer weight decay. Defaults to 0.05.
        amsgrad (bool, optional): Whether AdamW uses AMSGrad. Defaults to False.
        optimizer_type (str, optional): 'adamw' or 'adagrad'. Defaults to 'adamw'.
        use_checkpoint (bool, optional): Forwarded to ``choose_model``. Defaults to False.
        dropout_rate (float, optional): Forwarded to ``choose_model``. Defaults to 0.
        num_workers (int, optional): Currently unused by this function. Defaults to 20.
        val_loaders (optional): Validation loader(s). Defaults to None.
        test_loaders (optional): Test loader(s). Defaults to None.
        init_weights (str, optional): Forwarded to ``choose_model``. Defaults to 'imagenet'.
        intermedeate_save (list, optional): Currently unused; fixed thresholds
            are passed to ``_save_model``. Defaults to None.
        chan_dict (dict, optional): Currently unused. Defaults to None.
        schedule (str, optional): 'step_lr', 'reduce_lr_on_plateau' or None.
        loss_type (str, optional): Loss identifier forwarded to ``calculate_loss``.
        gradient_accumulation (bool, optional): Accumulate gradients over
            several batches ('erm' only). Defaults to False.
        gradient_accumulation_steps (int, optional): Batches per optimizer
            step when accumulating. Defaults to 4.

    Returns:
        None

    Raises:
        ValueError: If ``train_mode`` or ``optimizer_type`` is not supported.
    """
    from .io import _save_model, _save_progress
    from .utils import compute_irm_penalty, calculate_loss, choose_model

    if train_mode not in ('erm', 'irm'):
        raise ValueError(f"Unsupported train_mode {train_mode!r}; expected 'erm' or 'irm'.")

    # val_loaders is optional, so do not call len() on None.
    val_batches = len(val_loaders) if val_loaders is not None else 0
    print(f'Train batches:{len(train_loaders)}, Validation batches:{val_batches}')

    if test_loaders is not None:
        print(f'Test batches:{len(test_loaders)}')

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # (A former "peek at the first batch" loop was removed: its values were
    # never used and it failed in 'irm' mode, where train_loaders is a list
    # of loaders rather than a loader of batches.)

    model = choose_model(model_type, device, init_weights, dropout_rate, use_checkpoint)

    if model is None:
        print(f'Model {model_type} not found')
        return

    model.to(device)

    if optimizer_type == 'adamw':
        optimizer = AdamW(model.parameters(), lr=learning_rate, betas=(0.9, 0.999), weight_decay=weight_decay, amsgrad=amsgrad)
    elif optimizer_type == 'adagrad':
        optimizer = Adagrad(model.parameters(), lr=learning_rate, eps=1e-8, weight_decay=weight_decay)
    else:
        raise ValueError(f"Unsupported optimizer_type {optimizer_type!r}; expected 'adamw' or 'adagrad'.")

    if schedule == 'step_lr':
        # A step size of 0 (epochs < 5) is invalid for StepLR.
        StepLR_step_size = max(1, int(epochs/5))
        StepLR_gamma = 0.75
        scheduler = StepLR(optimizer, step_size=StepLR_step_size, gamma=StepLR_gamma)
    elif schedule == 'reduce_lr_on_plateau':
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=10, verbose=True)
    else:
        scheduler = None

    # NOTE(review): the `intermedeate_save` argument is not forwarded; these
    # fixed thresholds are what _save_model has always received. Confirm the
    # expected format before wiring the argument through.
    save_thresholds = [0.99,0.98,0.95,0.94]

    if train_mode == 'erm':
        for epoch in range(1, epochs+1):
            model.train()
            start_time = time.time()
            running_loss = 0.0
            last_batch_loss = float('nan')
            val_loss = None
            pending_grads = False
            optimizer.zero_grad()

            for batch_idx, (data, target, filenames) in enumerate(train_loaders, start=1):
                data, target = data.to(device), target.to(device).float()
                output = model(data)
                loss = calculate_loss(output, target, loss_type=loss_type)

                # Track the unscaled loss so the reported average is the same
                # whether or not gradient accumulation is enabled.
                last_batch_loss = loss.item()
                running_loss += last_batch_loss

                # Scale so the accumulated gradient matches one large batch.
                if gradient_accumulation:
                    loss = loss / gradient_accumulation_steps
                loss.backward()
                pending_grads = True

                # Step every batch, or every gradient_accumulation_steps batches.
                if not gradient_accumulation or (batch_idx % gradient_accumulation_steps == 0):
                    optimizer.step()
                    optimizer.zero_grad()
                    pending_grads = False

                avg_loss = running_loss / batch_idx
                print(f'\rTrain: epoch: {epoch} batch: {batch_idx}/{len(train_loaders)} avg_loss: {avg_loss:.5f} time: {(time.time()-start_time):.5f}', end='\r', flush=True)

            # Apply gradients left over from an incomplete accumulation window.
            if pending_grads:
                optimizer.step()
                optimizer.zero_grad()

            train_time = time.time() - start_time
            train_metrics = {'epoch':epoch,'loss':last_batch_loss, 'train_time':train_time}
            train_metrics_df = pd.DataFrame(train_metrics, index=[epoch])

            results_df, train_test_time = evaluate_model_performance(train_loaders, model, 'train', epoch, train_mode='erm', loss_type=loss_type)
            train_metrics_df['train_test_time'] = train_test_time
            eval_times = [train_test_time]

            if val_loaders is not None:
                result, val_time = evaluate_model_performance(val_loaders, model, 'val', epoch, train_mode='erm', loss_type=loss_type)
                # The scheduler needs a scalar, not a pandas Series.
                val_loss = float(result['loss'].mean())
                results_df = pd.concat([results_df, result])
                train_metrics_df['val_time'] = val_time
                eval_times.append(val_time)

            if test_loaders is not None:
                result, test_test_time = evaluate_model_performance(test_loaders, model, 'test', epoch, train_mode='erm', loss_type=loss_type)
                results_df = pd.concat([results_df, result])
                eval_times.append(test_test_time)
                # Mean evaluation time over the loaders that were evaluated.
                train_metrics_df['test_time'] = sum(eval_times) / len(eval_times)

            if scheduler is not None:
                if schedule == 'reduce_lr_on_plateau':
                    # Without validation data there is no metric to monitor.
                    if val_loss is not None:
                        scheduler.step(val_loss)
                elif schedule == 'step_lr':
                    scheduler.step()

            _save_progress(dst, results_df, train_metrics_df)
            clear_output(wait=True)
            display(results_df)
            _save_model(model, model_type, results_df, dst, epoch, epochs, intermedeate_save=save_thresholds)

    if train_mode == 'irm':
        dummy_w = torch.nn.Parameter(torch.Tensor([1.0])).to(device)

        # Run the same number of epochs as 'erm' (previously stopped one short).
        for epoch in range(1, epochs+1):
            model.train()
            start_time = time.time()
            penalty_factor = epoch * 1e-5
            val_loss = None
            loader_erm_loss_list = []

            for loader_index, loader in enumerate(train_loaders):
                batch_erm_loss_list = []
                for batch_idx, (data, target, filenames) in enumerate(loader, start=1):
                    data, target = data.to(device), target.to(device).float()
                    output = model(data)
                    erm_loss = F.binary_cross_entropy_with_logits(output * dummy_w, target, reduction='none')
                    batch_erm_loss_list.append(erm_loss.mean())
                    print(f'\repoch: {epoch} loader: {loader_index} batch: {batch_idx}/{len(loader)}', end='\r', flush=True)
                loader_erm_loss_list.append(torch.stack(batch_erm_loss_list).mean())

            total_erm_loss_mean = torch.stack(loader_erm_loss_list).mean()
            irm_loss = compute_irm_penalty(loader_erm_loss_list, dummy_w, device)

            # One optimizer step per epoch on the combined objective.
            optimizer.zero_grad()
            (total_erm_loss_mean + penalty_factor * irm_loss).backward()
            optimizer.step()

            train_time = time.time() - start_time

            # Store plain floats so the metrics frame holds no live tensors.
            irm_loss_value = float(irm_loss)
            erm_loss_value = float(total_erm_loss_mean)
            train_metrics = {'epoch': epoch, 'irm_loss': irm_loss_value, 'erm_loss': erm_loss_value, 'penalty_factor': penalty_factor, 'train_time': train_time}
            train_metrics_df = pd.DataFrame(train_metrics, index=[epoch])
            print(f'\rTrain: epoch: {epoch} loader: {loader_index} batch: {batch_idx}/{len(loader)} irm_loss: {irm_loss_value:.5f} mean_erm_loss: {erm_loss_value:.5f} train time {train_time:.5f}', end='\r', flush=True)

            train_names = [item + '_train' for item in train_loader_names]
            results_df, train_test_time = evaluate_model_performance(train_loaders, model, train_names, epoch, train_mode='irm', loss_type=loss_type)
            train_metrics_df['train_test_time'] = train_test_time

            if val_loaders is not None:
                val_names = [item + '_val' for item in train_loader_names]
                result, val_time = evaluate_model_performance(val_loaders, model, val_names, epoch, train_mode='irm', loss_type=loss_type)
                # The scheduler needs a scalar, not a pandas Series.
                val_loss = float(result['loss'].mean())
                results_df = pd.concat([results_df, result])
                train_metrics_df['val_time'] = val_time

            if test_loaders is not None:
                test_names = [item + '_test' for item in train_loader_names] #test_loader_names?
                result, test_test_time = evaluate_model_performance(test_loaders, model, test_names, epoch, train_mode='irm', loss_type=loss_type)
                results_df = pd.concat([results_df, result])
                train_metrics_df['test_test_time'] = test_test_time

            if scheduler is not None:
                if schedule == 'reduce_lr_on_plateau':
                    # Without validation data there is no metric to monitor.
                    if val_loss is not None:
                        scheduler.step(val_loss)
                elif schedule == 'step_lr':
                    scheduler.step()

            clear_output(wait=True)
            display(results_df)
            _save_progress(dst, results_df, train_metrics_df)
            _save_model(model, model_type, results_df, dst, epoch, epochs, intermedeate_save=save_thresholds)

    print(f'Saved model: {dst}')
    return
510
-
511
def get_submodules(model, prefix=''):
    """
    Recursively list the dotted names of all submodules of ``model``.

    Args:
        model: Object exposing ``named_children()`` (e.g. a torch module).
        prefix (str, optional): Dotted path prepended to every name.

    Returns:
        list: Names in depth-first order; each child is followed directly by
        its own descendants.
    """
    collected = []
    for child_name, child in model.named_children():
        # Only insert the separator when there is a parent path.
        qualified = f'{prefix}.{child_name}' if prefix else prefix + child_name
        collected.append(qualified)
        collected += get_submodules(child, qualified)
    return collected
518
-
519
def visualize_model_attention_v2(src, model_type='maxvit', model_path='', image_size=224, channels=[1,2,3], normalize=True, class_names=None, save_saliency=False, save_dir='saliency_maps'):
    """
    Plot saliency maps for every file in ``src`` using a saved model.

    For each image a figure with the original image, the saliency map and an
    overlay is shown. When ``save_saliency`` is true the saliency maps are
    also written to ``<src>/saliency_maps``.

    Args:
        src (str): Folder containing the input images.
        model_type (str, optional): Currently unused by this function.
        model_path (str): Path to a model saved as a whole object.
        image_size (int, optional): Forwarded to ``preprocess_image``.
        channels (list, optional): Forwarded to ``preprocess_image``.
        normalize (bool, optional): Forwarded to ``preprocess_image``.
        class_names (list, optional): Titles looked up by class label; all
            labels are 0 here, so only ``class_names[0]`` is ever used.
        save_saliency (bool, optional): Save saliency maps as images.
        save_dir (str, optional): Created when ``save_saliency`` is true.
            NOTE(review): files are written to ``<src>/saliency_maps``, not
            to this folder - confirm which location is intended.

    Returns:
        None
    """
    import torch
    import os
    from spacr.utils import SaliencyMapGenerator, preprocess_image
    import matplotlib.pyplot as plt
    import numpy as np
    from PIL import Image

    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    # Load the entire model object.
    # NOTE(review): torch.load unpickles arbitrary objects; only load model
    # files from trusted sources.
    model = torch.load(model_path)
    model.to(device)

    # Create directory for saving saliency maps if it does not exist
    if save_saliency and not os.path.exists(save_dir):
        os.makedirs(save_dir)

    # Only regular files are candidates: `src` also holds the
    # 'saliency_maps' output folder created below, which previously broke a
    # second run on the same folder. Sorting keeps the order deterministic.
    candidates = sorted(
        entry for entry in os.listdir(src)
        if os.path.isfile(os.path.join(src, entry))
    )
    if not candidates:
        print(f'No image files found in {src}')
        return

    # Collect all images and their tensors
    images = []
    input_tensors = []
    filenames = []
    for file in candidates:
        image_path = os.path.join(src, file)
        # Pass a copy so the shared default list can never be mutated.
        image, input_tensor = preprocess_image(image_path, normalize=normalize, image_size=image_size, channels=list(channels))
        images.append(image)
        input_tensors.append(input_tensor)
        filenames.append(file)

    input_tensors = torch.cat(input_tensors).to(device)
    class_labels = torch.zeros(input_tensors.size(0), dtype=torch.long).to(device)  # Replace with actual class labels if available

    # Generate saliency maps
    cam_generator = SaliencyMapGenerator(model)
    saliency_maps = cam_generator.compute_saliency_maps(input_tensors, class_labels)
    saliency_maps = saliency_maps.cpu().numpy()

    dst = os.path.join(src, 'saliency_maps')
    os.makedirs(dst, exist_ok=True)

    # Plot images, saliency maps, and overlays
    for i, image in enumerate(images):
        fig, axes = plt.subplots(1, 3, figsize=(15, 5))

        # Original image
        axes[0].imshow(image)
        axes[0].axis('off')
        if class_names:
            axes[0].set_title(class_names[class_labels[i].item()])

        # Saliency map
        axes[1].imshow(saliency_maps[i], cmap=plt.cm.hot)
        axes[1].axis('off')

        # Overlay
        overlay = np.array(image)
        axes[2].imshow(overlay)
        axes[2].imshow(saliency_maps[i], cmap='jet', alpha=0.5)
        axes[2].axis('off')

        plt.tight_layout()
        plt.show()
        # Release the figure; otherwise one figure per image stays in memory.
        plt.close(fig)

        # Save the saliency map if required
        if save_saliency:
            saliency_image = Image.fromarray((saliency_maps[i] * 255).astype(np.uint8))
            saliency_image.save(os.path.join(dst, f'saliency_{filenames[i]}'))
589
-
590
def visualize_model_attention(src, model_type='maxvit', model_path='', image_size=224, channels=[1,2,3], normalize=True, class_names=None, save_saliency=False, save_dir='saliency_maps'):
    """
    Plot saliency maps for every ``.png`` file in ``src`` using a saved model.

    For each image a figure with the original image, the first channel of the
    saliency map and a 50/50 blended overlay is shown. When ``save_saliency``
    is true the saliency maps are also written to ``<src>/saliency_maps``.

    Args:
        src (str): Folder containing the input ``.png`` images.
        model_type (str, optional): Currently unused by this function.
        model_path (str): Path to a model saved as a whole object.
        image_size (int, optional): Forwarded to ``preprocess_image``.
        channels (list, optional): Forwarded to ``preprocess_image``.
        normalize (bool, optional): Forwarded to ``preprocess_image``.
        class_names (list, optional): Titles looked up by class label; all
            labels are 0 here, so only ``class_names[0]`` is ever used.
        save_saliency (bool, optional): Save saliency maps as images.
        save_dir (str, optional): Created when ``save_saliency`` is true.
            NOTE(review): files are written to ``<src>/saliency_maps``, not
            to this folder - confirm which location is intended.

    Returns:
        None
    """
    import torch
    import os
    from spacr.utils import SaliencyMapGenerator, preprocess_image
    import matplotlib.pyplot as plt
    import numpy as np
    from PIL import Image

    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    # Load the entire model object.
    # NOTE(review): torch.load unpickles arbitrary objects; only load model
    # files from trusted sources.
    model = torch.load(model_path)
    model.to(device)

    # Create directory for saving saliency maps if it does not exist
    if save_saliency and not os.path.exists(save_dir):
        os.makedirs(save_dir)

    # Sorted for a deterministic processing order.
    png_files = sorted(entry for entry in os.listdir(src) if entry.endswith('.png'))
    if not png_files:
        # torch.cat on an empty list would fail with an opaque error.
        print(f'No .png files found in {src}')
        return

    # Collect all images and their tensors
    images = []
    input_tensors = []
    filenames = []
    for file in png_files:
        image_path = os.path.join(src, file)
        # Pass a copy so the shared default list can never be mutated.
        image, input_tensor = preprocess_image(image_path, normalize=normalize, image_size=image_size, channels=list(channels))
        images.append(image)
        input_tensors.append(input_tensor)
        filenames.append(file)

    input_tensors = torch.cat(input_tensors).to(device)
    class_labels = torch.zeros(input_tensors.size(0), dtype=torch.long).to(device)  # Replace with actual class labels if available

    # Generate saliency maps
    cam_generator = SaliencyMapGenerator(model)
    saliency_maps = cam_generator.compute_saliency_maps(input_tensors, class_labels)

    # Convert saliency maps to numpy arrays
    saliency_maps = saliency_maps.cpu().numpy()

    dst = os.path.join(src, 'saliency_maps')

    for i, image in enumerate(images):
        fig, axes = plt.subplots(1, 3, figsize=(20, 5))

        # Original image
        axes[0].imshow(image)
        axes[0].axis('off')
        if class_names:
            axes[0].set_title(f"Class: {class_names[class_labels[i].item()]}")

        # Saliency Map
        axes[1].imshow(saliency_maps[i, 0], cmap='hot')
        axes[1].axis('off')
        axes[1].set_title("Saliency Map")

        # Overlay: scale the image to [0, 1] and blend 50/50 with the map.
        overlay = np.array(image)
        peak = overlay.max()
        if peak > 0:
            # An all-zero image would otherwise divide by zero and yield NaNs.
            overlay = overlay / peak
        saliency_map_rgb = np.stack([saliency_maps[i, 0]] * 3, axis=-1)  # Convert saliency map to RGB
        overlay = (overlay * 0.5 + saliency_map_rgb * 0.5).clip(0, 1)
        axes[2].imshow(overlay)
        axes[2].axis('off')
        axes[2].set_title("Overlay")

        plt.tight_layout()
        plt.show()
        # Release the figure; otherwise one figure per image stays in memory.
        plt.close(fig)

        # Save the saliency map if required
        if save_saliency:
            os.makedirs(dst, exist_ok=True)
            saliency_image = Image.fromarray((saliency_maps[i, 0] * 255).astype(np.uint8))
            saliency_image.save(os.path.join(dst, f'saliency_{filenames[i]}'))
667
-
spacr/umap.py DELETED
File without changes
File without changes