spacr 0.0.1__py3-none-any.whl → 0.0.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
spacr/train.py CHANGED
@@ -6,6 +6,7 @@ from torch.autograd import grad
6
6
  from torch.optim.lr_scheduler import StepLR
7
7
  import torch.nn.functional as F
8
8
  from IPython.display import display, clear_output
9
+ import difflib
9
10
 
10
11
  from .logger import log_function_call
11
12
 
@@ -194,15 +195,21 @@ def test_model_performance(loaders, model, loader_name_list, epoch, train_mode,
194
195
 
195
196
  def train_test_model(src, settings, custom_model=False, custom_model_path=None):
196
197
 
197
- from .io import save_settings, _copy_missclassified
198
- from .utils import pick_best_model, test_model_performance
198
+ from .io import _save_settings, _copy_missclassified
199
+ from .utils import pick_best_model
199
200
  from .core import generate_loaders
200
201
 
202
+ settings['src'] = src
203
+ settings_df = pd.DataFrame(list(settings.items()), columns=['Key', 'Value'])
204
+ settings_csv = os.path.join(src,'settings','train_test_model_settings.csv')
205
+ os.makedirs(os.path.join(src,'settings'), exist_ok=True)
206
+ settings_df.to_csv(settings_csv, index=False)
207
+
201
208
  if custom_model:
202
- model = torch.load(custom_model_path) #if using a custom trained model
209
+ model = torch.load(custom_model_path)
203
210
 
204
211
  if settings['train']:
205
- save_settings(settings, src)
212
+ _save_settings(settings, src)
206
213
  torch.cuda.empty_cache()
207
214
  torch.cuda.memory.empty_cache()
208
215
  gc.collect()
@@ -221,20 +228,23 @@ def train_test_model(src, settings, custom_model=False, custom_model_path=None):
221
228
  validation_split=settings['val_split'],
222
229
  pin_memory=settings['pin_memory'],
223
230
  normalize=settings['normalize'],
224
- verbose=settings['verbose'])
231
+ channels=settings['channels'],
232
+ verbose=settings['verbose'])
233
+
225
234
 
226
235
  if settings['test']:
227
236
  test, _, plate_names_test = generate_loaders(src,
228
- train_mode=settings['train_mode'],
229
- mode='test',
230
- image_size=settings['image_size'],
231
- batch_size=settings['batch_size'],
232
- classes=settings['classes'],
233
- num_workers=settings['num_workers'],
234
- validation_split=0.0,
235
- pin_memory=settings['pin_memory'],
236
- normalize=settings['normalize'],
237
- verbose=settings['verbose'])
237
+ train_mode=settings['train_mode'],
238
+ mode='test',
239
+ image_size=settings['image_size'],
240
+ batch_size=settings['batch_size'],
241
+ classes=settings['classes'],
242
+ num_workers=settings['num_workers'],
243
+ validation_split=0.0,
244
+ pin_memory=settings['pin_memory'],
245
+ normalize=settings['normalize'],
246
+ channels=settings['channels'],
247
+ verbose=settings['verbose'])
238
248
  if model == None:
239
249
  model_path = pick_best_model(src+'/model')
240
250
  print(f'Best model: {model_path}')
@@ -324,8 +334,8 @@ def train_model(dst, model_type, train_loaders, train_loader_names, train_mode='
324
334
  None
325
335
  """
326
336
 
327
- from .io import save_model, save_progress
328
- from .utils import evaluate_model_performance, compute_irm_penalty, calculate_loss, choose_model
337
+ from .io import _save_model, _save_progress
338
+ from .utils import compute_irm_penalty, calculate_loss, choose_model
329
339
 
330
340
  print(f'Train batches:{len(train_loaders)}, Validation batches:{len(val_loaders)}')
331
341
 
@@ -341,6 +351,11 @@ def train_model(dst, model_type, train_loaders, train_loader_names, train_mode='
341
351
  break
342
352
 
343
353
  model = choose_model(model_type, device, init_weights, dropout_rate, use_checkpoint)
354
+
355
+ if model is None:
356
+ print(f'Model {model_type} not found')
357
+ return
358
+
344
359
  model.to(device)
345
360
 
346
361
  if optimizer_type == 'adamw':
@@ -415,10 +430,10 @@ def train_model(dst, model_type, train_loaders, train_loader_names, train_mode='
415
430
  if schedule == 'step_lr':
416
431
  scheduler.step()
417
432
 
418
- save_progress(dst, results_df, train_metrics_df)
433
+ _save_progress(dst, results_df, train_metrics_df)
419
434
  clear_output(wait=True)
420
435
  display(results_df)
421
- save_model(model, model_type, results_df, dst, epoch, epochs, intermedeate_save=[0.99,0.98,0.95,0.94])
436
+ _save_model(model, model_type, results_df, dst, epoch, epochs, intermedeate_save=[0.99,0.98,0.95,0.94])
422
437
 
423
438
  if train_mode == 'irm':
424
439
  dummy_w = torch.nn.Parameter(torch.Tensor([1.0])).to(device)
@@ -488,7 +503,165 @@ def train_model(dst, model_type, train_loaders, train_loader_names, train_mode='
488
503
 
489
504
  clear_output(wait=True)
490
505
  display(results_df)
491
- save_progress(dst, results_df, train_metrics_df)
492
- save_model(model, model_type, results_df, dst, epoch, epochs, intermedeate_save=[0.99,0.98,0.95,0.94])
506
+ _save_progress(dst, results_df, train_metrics_df)
507
+ _save_model(model, model_type, results_df, dst, epoch, epochs, intermedeate_save=[0.99,0.98,0.95,0.94])
493
508
  print(f'Saved model: {dst}')
494
- return
509
+ return
510
+
511
def get_submodules(model, prefix=''):
    """Return the dotted names of all modules nested below ``model``.

    The tree is walked depth-first in pre-order: each child name is emitted
    before the names of its own descendants. Children are visited in the
    order ``model.named_children()`` yields them.

    Args:
        model: A module exposing ``named_children()`` (e.g. ``torch.nn.Module``).
        prefix: Optional name prepended (with a ``.`` separator) to every
            returned path; an empty prefix yields bare child names.

    Returns:
        list[str]: Dotted paths such as ``['0', '1', '1.0', '1.1']``.
    """
    def _walk(parent, parent_path):
        for child_name, child in parent.named_children():
            # Only insert the separator when there is something to join onto.
            child_path = f"{parent_path}.{child_name}" if parent_path else child_name
            yield child_path
            yield from _walk(child, child_path)

    return list(_walk(model, prefix))
518
+
519
def visualize_model_attention_v2(src, model_type='maxvit', model_path='', image_size=224, channels=[1,2,3], normalize=True, class_names=None, save_saliency=False, save_dir='saliency_maps'):
    """Show image / saliency map / overlay panels for every image file in ``src``.

    Every regular, non-hidden file directly inside ``src`` is preprocessed
    with ``spacr.utils.preprocess_image`` and batched through a saved model.
    A saliency map is computed for each image with
    ``spacr.utils.SaliencyMapGenerator``. One three-panel figure (image,
    saliency map, overlay) is shown per image.

    Args:
        src: Directory containing the input images. Sub-directories (such as
            a previous ``saliency_maps`` output folder) and dotfiles are skipped.
        model_type: Not used by this function; kept for API compatibility.
        model_path: Path to a whole model object saved with ``torch.save``.
        image_size: Forwarded to ``preprocess_image``.
        channels: Forwarded to ``preprocess_image``.
        normalize: Forwarded to ``preprocess_image``.
        class_names: Optional sequence; if given, the first panel is titled
            with the name of the class the saliency map was computed for.
        save_saliency: If True, also write each saliency map as an 8-bit image
            named ``saliency_<original filename>``.
        save_dir: Output directory for saved maps. A relative path is resolved
            against ``src``, so the default writes to ``<src>/saliency_maps``.
            An absolute path is used as-is.

    Returns:
        None. Prints a message and returns early if ``src`` has no files.
    """
    import os
    import torch

    # Only regular files can be images: skipping directories avoids feeding a
    # previous output folder to preprocess_image. Sorting fixes the plot order.
    filenames = sorted(
        f for f in os.listdir(src)
        if not f.startswith('.') and os.path.isfile(os.path.join(src, f))
    )
    if not filenames:
        # torch.cat on an empty list would raise; bail out before loading the model.
        print(f'No image files found in {src}')
        return

    from spacr.utils import SaliencyMapGenerator, preprocess_image
    import matplotlib.pyplot as plt
    import numpy as np
    from PIL import Image

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # torch.load unpickles the full model object: only load trusted checkpoints.
    # map_location lets a model saved on GPU be loaded on a CPU-only machine.
    model = torch.load(model_path, map_location=device)
    model.to(device)
    # The checkpoint may have been saved while in training mode; use inference
    # behaviour (no dropout, frozen batch-norm statistics) for stable maps.
    model.eval()

    images = []
    input_tensors = []
    for file in filenames:
        image, input_tensor = preprocess_image(os.path.join(src, file), normalize=normalize, image_size=image_size, channels=channels)
        images.append(image)
        input_tensors.append(input_tensor)

    input_tensors = torch.cat(input_tensors).to(device)
    # Placeholder targets: every saliency map is computed for class index 0.
    # Replace with actual class labels if available.
    class_labels = torch.zeros(input_tensors.size(0), dtype=torch.long).to(device)

    cam_generator = SaliencyMapGenerator(model)
    saliency_maps = cam_generator.compute_saliency_maps(input_tensors, class_labels)
    saliency_maps = saliency_maps.cpu().numpy()

    # os.path.join returns save_dir unchanged when it is absolute.
    dst = os.path.join(src, save_dir)
    if save_saliency:
        os.makedirs(dst, exist_ok=True)

    for i, (image, filename) in enumerate(zip(images, filenames)):
        # Drop singleton axes (e.g. a 1-channel axis) so imshow receives a 2-D map.
        saliency = np.squeeze(saliency_maps[i])

        fig, axes = plt.subplots(1, 3, figsize=(15, 5))

        # Original image
        axes[0].imshow(image)
        axes[0].axis('off')
        if class_names:
            axes[0].set_title(class_names[class_labels[i].item()])

        # Saliency map
        axes[1].imshow(saliency, cmap=plt.cm.hot)
        axes[1].axis('off')

        # Overlay: image with a semi-transparent saliency map on top
        axes[2].imshow(np.array(image))
        axes[2].imshow(saliency, cmap='jet', alpha=0.5)
        axes[2].axis('off')

        plt.tight_layout()
        plt.show()
        # Release the figure; otherwise figures accumulate across iterations.
        plt.close(fig)

        if save_saliency:
            # Clip before the uint8 cast so out-of-range values don't wrap around.
            saliency_image = Image.fromarray((np.clip(saliency, 0, 1) * 255).astype(np.uint8))
            saliency_image.save(os.path.join(dst, f'saliency_{filename}'))
589
+
590
def visualize_model_attention(src, model_type='maxvit', model_path='', image_size=224, channels=[1,2,3], normalize=True, class_names=None, save_saliency=False, save_dir='saliency_maps'):
    """Show image / saliency map / blended overlay panels for each ``.png`` in ``src``.

    Every ``.png`` file directly inside ``src`` is preprocessed with
    ``spacr.utils.preprocess_image`` and batched through a saved model.
    A saliency map is computed for each image with
    ``spacr.utils.SaliencyMapGenerator``. One three-panel figure is shown per
    image: the image, the saliency map, and a 50/50 blend of the
    max-normalised image with the saliency map.

    Args:
        src: Directory containing the ``.png`` input images.
        model_type: Not used by this function; kept for API compatibility.
        model_path: Path to a whole model object saved with ``torch.save``.
        image_size: Forwarded to ``preprocess_image``.
        channels: Forwarded to ``preprocess_image``.
        normalize: Forwarded to ``preprocess_image``.
        class_names: Optional sequence; if given, the first panel is titled
            with the name of the class the saliency map was computed for.
        save_saliency: If True, also write each saliency map as an 8-bit image
            named ``saliency_<original filename>``.
        save_dir: Output directory for saved maps. A relative path is resolved
            against ``src``, so the default writes to ``<src>/saliency_maps``.
            An absolute path is used as-is.

    Returns:
        None. Prints a message and returns early if ``src`` has no ``.png`` files.
    """
    import os
    import torch

    # Sorting makes the plot order deterministic across platforms.
    filenames = sorted(
        f for f in os.listdir(src)
        if f.endswith('.png') and os.path.isfile(os.path.join(src, f))
    )
    if not filenames:
        # torch.cat on an empty list would raise; bail out before loading the model.
        print(f'No .png files found in {src}')
        return

    from spacr.utils import SaliencyMapGenerator, preprocess_image
    import matplotlib.pyplot as plt
    import numpy as np
    from PIL import Image

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # torch.load unpickles the full model object: only load trusted checkpoints.
    # map_location lets a model saved on GPU be loaded on a CPU-only machine.
    model = torch.load(model_path, map_location=device)
    model.to(device)
    # The checkpoint may have been saved while in training mode; use inference
    # behaviour (no dropout, frozen batch-norm statistics) for stable maps.
    model.eval()

    images = []
    input_tensors = []
    for file in filenames:
        image, input_tensor = preprocess_image(os.path.join(src, file), normalize=normalize, image_size=image_size, channels=channels)
        images.append(image)
        input_tensors.append(input_tensor)

    input_tensors = torch.cat(input_tensors).to(device)
    # Placeholder targets: every saliency map is computed for class index 0.
    # Replace with actual class labels if available.
    class_labels = torch.zeros(input_tensors.size(0), dtype=torch.long).to(device)

    cam_generator = SaliencyMapGenerator(model)
    saliency_maps = cam_generator.compute_saliency_maps(input_tensors, class_labels)
    saliency_maps = saliency_maps.cpu().numpy()

    # os.path.join returns save_dir unchanged when it is absolute.
    dst = os.path.join(src, save_dir)
    if save_saliency:
        os.makedirs(dst, exist_ok=True)

    for i, (image, filename) in enumerate(zip(images, filenames)):
        # First channel of each map; assumes maps are (N, C, H, W) — TODO confirm
        # against SaliencyMapGenerator.compute_saliency_maps.
        saliency = saliency_maps[i, 0]

        fig, axes = plt.subplots(1, 3, figsize=(20, 5))

        # Original image
        axes[0].imshow(image)
        axes[0].axis('off')
        if class_names:
            axes[0].set_title(f"Class: {class_names[class_labels[i].item()]}")

        # Saliency Map
        axes[1].imshow(saliency, cmap='hot')
        axes[1].axis('off')
        axes[1].set_title("Saliency Map")

        # Overlay: 50/50 blend of the max-normalised image and the saliency map
        overlay = np.array(image, dtype=float)
        if overlay.ndim == 2:
            # Grayscale image: replicate to 3 channels so it matches the RGB saliency stack.
            overlay = np.stack([overlay] * 3, axis=-1)
        peak = overlay.max()
        if peak > 0:
            # Guard against division by zero on an all-black image.
            overlay = overlay / peak
        saliency_rgb = np.stack([saliency] * 3, axis=-1)
        overlay = (overlay * 0.5 + saliency_rgb * 0.5).clip(0, 1)
        axes[2].imshow(overlay)
        axes[2].axis('off')
        axes[2].set_title("Overlay")

        plt.tight_layout()
        plt.show()
        # Release the figure; otherwise figures accumulate across iterations.
        plt.close(fig)

        if save_saliency:
            # Clip before the uint8 cast so out-of-range values don't wrap around.
            saliency_image = Image.fromarray((np.clip(saliency, 0, 1) * 255).astype(np.uint8))
            saliency_image.save(os.path.join(dst, f'saliency_{filename}'))
667
+