spacr 0.3.22__py3-none-any.whl → 0.3.30__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.
spacr/deep_spacr.py CHANGED
@@ -610,264 +610,168 @@ def train_model(dst, model_type, train_loaders, epochs=100, learning_rate=0.0001
 
  return model, model_path
 
- def visualize_saliency_map(settings):
- from spacr.utils import SaliencyMapGenerator, print_progress
- from spacr.io import TarImageDataset # Assuming you have a dataset class
- from torchvision.utils import make_grid
-
+ def generate_activation_map(settings):
+
+ from .utils import SaliencyMapGenerator, GradCAMGenerator, SelectChannels, activation_maps_to_database, activation_correlations_to_database
+ from .utils import print_progress, save_settings, calculate_activation_correlations
+ from .io import TarImageDataset
+ from .settings import get_default_generate_activation_map_settings
+
+ torch.cuda.empty_cache()
+ gc.collect()
+
+ plt.clf()
  use_cuda = torch.cuda.is_available()
  device = torch.device("cuda" if use_cuda else "cpu")
-
+
+ source_folder = os.path.dirname(os.path.dirname(settings['dataset']))
+ settings['src'] = source_folder
+ settings = get_default_generate_activation_map_settings(settings)
+ save_settings(settings, name=f"{settings['cam_type']}_settings", show=False)
+
+ if settings['model_type'] == 'maxvit' and settings['target_layer'] == None:
+ settings['target_layer'] = 'base_model.blocks.3.layers.1.layers.MBconv.layers.conv_b'
+ if settings['cam_type'] in ['saliency_image', 'saliency_channel']:
+ settings['target_layer'] = None
+
  # Set number of jobs for loading
- if settings['n_jobs'] is None:
+ n_jobs = settings['n_jobs']
+ if n_jobs is None:
  n_jobs = max(1, cpu_count() - 4)
- else:
- n_jobs = settings['n_jobs']
 
  # Set transforms for images
- if settings['normalize']:
- transform = transforms.Compose([
- transforms.ToTensor(),
- transforms.CenterCrop(size=(settings['image_size'], settings['image_size'])),
- transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])
- else:
- transform = transforms.Compose([
- transforms.ToTensor(),
- transforms.CenterCrop(size=(settings['image_size'], settings['image_size']))])
+ transform = transforms.Compose([
+ transforms.ToTensor(),
+ transforms.CenterCrop(size=(settings['image_size'], settings['image_size'])),
+ transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)) if settings['normalize_input'] else None,
+ SelectChannels(settings['channels'])
+ ])
 
  # Handle dataset path
- if os.path.exists(settings['dataset']):
- tar_path = settings['dataset']
- else:
+ if not os.path.exists(settings['dataset']):
  print(f"Dataset not found at {settings['dataset']}")
  return
-
- if settings.get('save', False):
- if settings['dtype'] not in ['uint8', 'uint16']:
- print("Invalid dtype in settings. Please use 'uint8' or 'uint16'.")
- return
 
  # Load the model
  model = torch.load(settings['model_path'])
  model.to(device)
- model.eval() # Ensure the model is in evaluation mode
+ model.eval()
 
- # Create directory for saving saliency maps if it does not exist
- if settings.get('save', False):
- dataset_dir = os.path.dirname(tar_path)
- dataset_name = os.path.splitext(os.path.basename(tar_path))[0]
- save_dir = os.path.join(dataset_dir, dataset_name, 'saliency_maps')
+ # Create directory for saving activation maps if it does not exist
+ dataset_dir = os.path.dirname(settings['dataset'])
+ dataset_name = os.path.splitext(os.path.basename(settings['dataset']))[0]
+ save_dir = os.path.join(dataset_dir, dataset_name, settings['cam_type'])
+ batch_grid_fldr = os.path.join(save_dir, 'batch_grids')
+
+ if settings['save']:
  os.makedirs(save_dir, exist_ok=True)
- print(f"Saliency maps will be saved in: {save_dir}")
-
+ print(f"Activation maps will be saved in: {save_dir}")
+
+ if settings['plot']:
+ os.makedirs(batch_grid_fldr, exist_ok=True)
+ print(f"Batch grid maps will be saved in: {batch_grid_fldr}")
+
  # Load dataset
- dataset = TarImageDataset(tar_path, transform=transform)
- data_loader = DataLoader(dataset, batch_size=settings['batch_size'], shuffle=True, num_workers=n_jobs, pin_memory=True)
-
- # Initialize SaliencyMapGenerator
- cam_generator = SaliencyMapGenerator(model)
+ dataset = TarImageDataset(settings['dataset'], transform=transform)
+ data_loader = DataLoader(dataset, batch_size=settings['batch_size'], shuffle=settings['shuffle'], num_workers=n_jobs, pin_memory=True)
+
+ # Initialize generator based on cam_type
+ if settings['cam_type'] in ['gradcam', 'gradcam_pp']:
+ cam_generator = GradCAMGenerator(model, target_layer=settings['target_layer'], cam_type=settings['cam_type'])
+ elif settings['cam_type'] in ['saliency_image', 'saliency_channel']:
+ cam_generator = SaliencyMapGenerator(model)
+
  time_ls = []
-
  for batch_idx, (inputs, filenames) in enumerate(data_loader):
  start = time.time()
+ img_paths = []
  inputs = inputs.to(device)
-
- saliency_maps, predicted_classes = cam_generator.compute_saliency_and_predictions(inputs)
-
- if settings['saliency_mode'] not in ['mean', 'sum']:
- print("To generate channel average or sum saliency maps set saliency_mode to 'mean' or 'sum', respectively.")
-
- if settings['saliency_mode'] == 'mean':
- saliency_maps = saliency_maps.mean(dim=1, keepdim=True)
-
- elif settings['saliency_mode'] == 'sum':
- saliency_maps = saliency_maps.sum(dim=1, keepdim=True)
-
- # Example usage with the class
- if settings.get('plot', False):
- if settings['plot_mode'] not in ['mean', 'channel', '3-channel']:
- print("Invalid plot_mode in settings. Please use 'mean', 'channel', or '3-channel'.")
- return
- else:
- cam_generator.plot_saliency_grid(inputs, saliency_maps, predicted_classes, mode=settings['plot_mode'])
-
- if settings.get('save', False):
- for i in range(inputs.size(0)):
- saliency_map = saliency_maps[i].detach().cpu().numpy()
-
- # Check dtype in settings and normalize accordingly
- if settings['dtype'] == 'uint16':
- saliency_map = np.clip(saliency_map, 0, 1) * 65535
- saliency_map = saliency_map.astype(np.uint16)
- mode = 'I;16'
- elif settings['dtype'] == 'uint8':
- saliency_map = np.clip(saliency_map, 0, 1) * 255
- saliency_map = saliency_map.astype(np.uint8)
- mode = 'L' # Grayscale mode for uint8
-
- # Get the class prediction (0 or 1)
- class_pred = predicted_classes[i].item()
-
- save_class_dir = os.path.join(save_dir, f'class_{class_pred}')
- os.makedirs(save_class_dir, exist_ok=True)
- save_path = os.path.join(save_class_dir, filenames[i])
-
- # Handle different cases based on saliency_map dimensions
- if saliency_map.ndim == 3: # Multi-channel case (C, H, W)
- if saliency_map.shape[0] == 3: # RGB-like saliency map
- saliency_image = Image.fromarray(np.moveaxis(saliency_map, 0, -1), mode="RGB") # Convert (C, H, W) to (H, W, C)
- elif saliency_map.shape[0] == 1: # Single-channel case (1, H, W)
- saliency_map = np.squeeze(saliency_map) # Remove the extra channel dimension
- saliency_image = Image.fromarray(saliency_map, mode=mode) # Use grayscale mode for single-channel
- else:
- raise ValueError(f"Unexpected number of channels: {saliency_map.shape[0]}")
-
- elif saliency_map.ndim == 2: # Single-channel case (H, W)
- saliency_image = Image.fromarray(saliency_map, mode=mode) # Keep single channel (H, W)
-
- else:
- raise ValueError(f"Unexpected number of dimensions: {saliency_map.ndim}")
-
- # Save the image
- saliency_image.save(save_path)
 
+ # Compute activation maps and predictions
+ if settings['cam_type'] in ['gradcam', 'gradcam_pp']:
+ activation_maps, predicted_classes = cam_generator.compute_gradcam_and_predictions(inputs)
+ elif settings['cam_type'] in ['saliency_image', 'saliency_channel']:
+ activation_maps, predicted_classes = cam_generator.compute_saliency_and_predictions(inputs)
+
+ # Move activation maps to CPU
+ activation_maps = activation_maps.cpu()
+
+ # Sum saliency maps for 'saliency_image' type
+ if settings['cam_type'] == 'saliency_image':
+ summed_activation_maps = []
+ for i in range(activation_maps.size(0)):
+ activation_map = activation_maps[i]
+ #print(f"1: {activation_map.shape}")
+ activation_map_sum = activation_map.sum(dim=0, keepdim=False)
+ #print(f"2: {activation_map.shape}")
+ activation_map_sum = np.squeeze(activation_map_sum, axis=0)
+ #print(f"3: {activation_map_sum.shape}")
+ summed_activation_maps.append(activation_map_sum)
+ activation_maps = torch.stack(summed_activation_maps)
+
+ # For plotting
+ if settings['plot']:
+ fig = cam_generator.plot_activation_grid(inputs, activation_maps, predicted_classes, overlay=settings['overlay'], normalize=settings['normalize'])
+ pdf_save_path = os.path.join(batch_grid_fldr,f"batch_{batch_idx}_grid.pdf")
+ fig.savefig(pdf_save_path, format='pdf')
+ print(f"Saved batch grid to {pdf_save_path}")
+ #plt.show()
+ display(fig)
+
+ for i in range(inputs.size(0)):
+ activation_map = activation_maps[i].detach().numpy()
+
+ if settings['cam_type'] in ['saliency_image', 'gradcam', 'gradcam_pp']:
+ #activation_map = activation_map.sum(axis=0)
+ activation_map = (activation_map - activation_map.min()) / (activation_map.max() - activation_map.min())
+ activation_map = (activation_map * 255).astype(np.uint8)
+ activation_image = Image.fromarray(activation_map, mode='L')
+
+ elif settings['cam_type'] == 'saliency_channel':
+ # Handle each channel separately and save as RGB
+ rgb_activation_map = np.zeros((activation_map.shape[1], activation_map.shape[2], 3), dtype=np.uint8)
+ for c in range(min(activation_map.shape[0], 3)): # Limit to 3 channels for RGB
+ channel_map = activation_map[c]
+ channel_map = (channel_map - channel_map.min()) / (channel_map.max() - channel_map.min())
+ rgb_activation_map[:, :, c] = (channel_map * 255).astype(np.uint8)
+ activation_image = Image.fromarray(rgb_activation_map, mode='RGB')
+
+ # Save activation maps
+ class_pred = predicted_classes[i].item()
+ parts = filenames[i].split('_')
+ plate = parts[0]
+ well = parts[1]
+ save_class_dir = os.path.join(save_dir, f'class_{class_pred}', str(plate), str(well))
+ os.makedirs(save_class_dir, exist_ok=True)
+ save_path = os.path.join(save_class_dir, f'{filenames[i]}')
+ if settings['save']:
+ activation_image.save(save_path)
+ img_paths.append(save_path)
+
+ if settings['save']:
+ activation_maps_to_database(img_paths, source_folder, settings)
+
+ if settings['correlation']:
+ df = calculate_activation_correlations(inputs, activation_maps, filenames, manders_thresholds=settings['manders_thresholds'])
+ if settings['plot']:
+ display(df)
+ if settings['save']:
+ activation_correlations_to_database(df, img_paths, source_folder, settings)
 
  stop = time.time()
  duration = stop - start
  time_ls.append(duration)
  files_processed = batch_idx * settings['batch_size']
- files_to_process = len(data_loader)
- print_progress(files_processed, files_to_process, n_jobs=n_jobs, time_ls=time_ls, batch_size=settings['batch_size'], operation_type="Generating Saliency Maps")
-
- print("Saliency map generation complete.")
-
- def visualize_saliency_map_v1(src, model_type='maxvit', model_path='', image_size=224, channels=[1,2,3], normalize=True, class_names=None, save_saliency=False, save_dir='saliency_maps'):
+ files_to_process = len(data_loader) * settings['batch_size']
+ print_progress(files_processed, files_to_process, n_jobs=n_jobs, time_ls=time_ls, batch_size=settings['batch_size'], operation_type="Generating Activation Maps")
 
- from spacr.utils import SaliencyMapGenerator, preprocess_image
-
- use_cuda = torch.cuda.is_available()
- device = torch.device("cuda" if use_cuda else "cpu")
-
- # Load the entire model object
- model = torch.load(model_path)
- model.to(device)
-
- # Create directory for saving saliency maps if it does not exist
- if save_saliency and not os.path.exists(save_dir):
- os.makedirs(save_dir)
-
- # Collect all images and their tensors
- images = []
- input_tensors = []
- filenames = []
- for file in os.listdir(src):
- if not file.endswith('.png'):
- continue
- image_path = os.path.join(src, file)
- image, input_tensor = preprocess_image(image_path, normalize=normalize, image_size=image_size, channels=channels)
- images.append(image)
- input_tensors.append(input_tensor)
- filenames.append(file)
-
- input_tensors = torch.cat(input_tensors).to(device)
- class_labels = torch.zeros(input_tensors.size(0), dtype=torch.long).to(device) # Replace with actual class labels if available
-
- # Generate saliency maps
- cam_generator = SaliencyMapGenerator(model)
- saliency_maps = cam_generator.compute_saliency_maps(input_tensors, class_labels)
-
- # Convert saliency maps to numpy arrays
- saliency_maps = saliency_maps.cpu().numpy()
-
- N = len(images)
-
- dst = os.path.join(src, 'saliency_maps')
-
- for i in range(N):
- fig, axes = plt.subplots(1, 3, figsize=(20, 5))
-
- # Original image
- axes[0].imshow(images[i])
- axes[0].axis('off')
- if class_names:
- axes[0].set_title(f"Class: {class_names[class_labels[i].item()]}")
-
- # Saliency Map
- axes[1].imshow(saliency_maps[i, 0], cmap='hot')
- axes[1].axis('off')
- axes[1].set_title("Saliency Map")
-
- # Overlay
- overlay = np.array(images[i])
- overlay = overlay / overlay.max()
- saliency_map_rgb = np.stack([saliency_maps[i, 0]] * 3, axis=-1) # Convert saliency map to RGB
- overlay = (overlay * 0.5 + saliency_map_rgb * 0.5).clip(0, 1)
- axes[2].imshow(overlay)
- axes[2].axis('off')
- axes[2].set_title("Overlay")
-
- plt.tight_layout()
- plt.show()
-
- # Save the saliency map if required
- if save_saliency:
- os.makedirs(dst, exist_ok=True)
- saliency_image = Image.fromarray((saliency_maps[i, 0] * 255).astype(np.uint8))
- saliency_image.save(os.path.join(dst, f'saliency_{filenames[i]}'))
-
- def visualize_grad_cam(src, model_path, target_layers=None, image_size=224, channels=[1, 2, 3], normalize=True, class_names=None, save_cam=False, save_dir='grad_cam'):
-
- from spacr.utils import GradCAM, preprocess_image, show_cam_on_image, recommend_target_layers
-
- use_cuda = torch.cuda.is_available()
- device = torch.device("cuda" if use_cuda else "cpu")
-
- model = torch.load(model_path)
- model.to(device)
-
- # If no target layers provided, recommend a target layer
- if target_layers is None:
- target_layers, all_layers = recommend_target_layers(model)
- print(f"No target layer provided. Using recommended layer: {target_layers[0]}")
- print("All possible target layers:")
- for layer in all_layers:
- print(layer)
-
- grad_cam = GradCAM(model=model, target_layers=target_layers, use_cuda=use_cuda)
-
- if save_cam and not os.path.exists(save_dir):
- os.makedirs(save_dir)
-
- images = []
- filenames = []
- for file in os.listdir(src):
- if not file.endswith('.png'):
- continue
- image_path = os.path.join(src, file)
- image, input_tensor = preprocess_image(image_path, normalize=normalize, image_size=image_size, channels=channels)
- images.append(image)
- filenames.append(file)
-
- input_tensor = input_tensor.to(device)
- cam = grad_cam(input_tensor)
- cam_image = show_cam_on_image(np.array(image) / 255.0, cam)
-
- fig, ax = plt.subplots(1, 2, figsize=(10, 5))
- ax[0].imshow(image)
- ax[0].axis('off')
- ax[0].set_title("Original Image")
- ax[1].imshow(cam_image)
- ax[1].axis('off')
- ax[1].set_title("Grad-CAM")
- plt.show()
-
- if save_cam:
- cam_pil = Image.fromarray(cam_image)
- cam_pil.save(os.path.join(save_dir, f'grad_cam_{file}'))
+ torch.cuda.empty_cache()
+ gc.collect()
+ print("Activation map generation complete.")
 
  def visualize_classes(model, dtype, class_names, **kwargs):
 
- from spacr.utils import class_visualization
+ from .utils import class_visualization
 
  for target_y in range(2): # Assuming binary classification
  print(f"Visualizing class: {class_names[target_y]}")
spacr/gui.py CHANGED
@@ -57,6 +57,7 @@ class MainApp(tk.Tk):
  "Map Barcodes": (lambda frame: initiate_root(self, 'map_barcodes'), "Map barcodes to data."),
  "Regression": (lambda frame: initiate_root(self, 'regression'), "Perform regression analysis."),
  "Recruitment": (lambda frame: initiate_root(self, 'recruitment'), "Analyze recruitment data."),
+ "Activation": (lambda frame: initiate_root(self, 'activation'), "Generate activation maps of computer vision models and measure channel-activation correlation."),
  "Plaque": (lambda frame: initiate_root(self, 'analyze_plaques'), "Analyze plaque data.")
  }
 
spacr/gui_core.py CHANGED
@@ -379,10 +379,13 @@ def set_globals(thread_control_var, q_var, console_output_var, parent_frame_var,
  index_control = index_control_var
 
  def import_settings(settings_type='mask'):
- from .gui_utils import convert_settings_dict_for_gui, hide_all_settings
  global vars_dict, scrollable_frame, button_scrollable_frame
- from .settings import generate_fields, set_default_settings_preprocess_generate_masks, get_measure_crop_settings, set_default_train_test_model, set_default_generate_barecode_mapping, set_default_umap_image_settings, get_analyze_recruitment_default_settings
 
+ from .gui_utils import convert_settings_dict_for_gui, hide_all_settings
+ from .settings import generate_fields, set_default_settings_preprocess_generate_masks, get_measure_crop_settings, set_default_train_test_model
+ from .settings import set_default_generate_barecode_mapping, set_default_umap_image_settings, get_analyze_recruitment_default_settings
+ from .settings import get_default_generate_activation_map_settings
+ #activation
  def read_settings_from_csv(csv_file_path):
  settings = {}
  with open(csv_file_path, newline='') as csvfile:
@@ -422,6 +425,8 @@ def import_settings(settings_type='mask'):
  settings = set_default_umap_image_settings(settings={})
  elif settings_type == 'recruitment':
  settings = get_analyze_recruitment_default_settings(settings={})
+ elif settings_type == 'activation':
+ settings = get_default_generate_activation_map_settings(settings={})
  elif settings_type == 'analyze_plaques':
  settings = {}
  elif settings_type == 'convert':
@@ -436,8 +441,10 @@ def import_settings(settings_type='mask'):
 
  def setup_settings_panel(vertical_container, settings_type='mask'):
  global vars_dict, scrollable_frame
- from .settings import get_identify_masks_finetune_default_settings, set_default_analyze_screen, set_default_settings_preprocess_generate_masks, get_measure_crop_settings, deep_spacr_defaults, set_default_generate_barecode_mapping, set_default_umap_image_settings
- from .settings import get_map_barcodes_default_settings, get_analyze_recruitment_default_settings, get_check_cellpose_models_default_settings, generate_fields, get_perform_regression_default_settings, get_train_cellpose_default_settings
+ from .settings import get_identify_masks_finetune_default_settings, set_default_analyze_screen, set_default_settings_preprocess_generate_masks
+ from .settings import get_measure_crop_settings, deep_spacr_defaults, set_default_generate_barecode_mapping, set_default_umap_image_settings
+ from .settings import get_map_barcodes_default_settings, get_analyze_recruitment_default_settings, get_check_cellpose_models_default_settings
+ from .settings import generate_fields, get_perform_regression_default_settings, get_train_cellpose_default_settings, get_default_generate_activation_map_settings
  from .gui_utils import convert_settings_dict_for_gui
  from .gui_elements import set_element_size
 
@@ -480,6 +487,8 @@ def setup_settings_panel(vertical_container, settings_type='mask'):
  settings = get_perform_regression_default_settings(settings={})
  elif settings_type == 'recruitment':
  settings = get_analyze_recruitment_default_settings(settings={})
+ elif settings_type == 'activation':
+ settings = get_default_generate_activation_map_settings(settings={})
  elif settings_type == 'analyze_plaques':
  settings = {'src':'path to images'}
  elif settings_type == 'convert':
spacr/gui_utils.py CHANGED
@@ -77,7 +77,7 @@ def load_app(root, app_name, app_func):
  else:
  proceed_with_app(root, app_name, app_func)
 
- def parse_list(value):
+ def parse_list_v1(value):
  """
  Parses a string representation of a list and returns the parsed list.
 
@@ -98,6 +98,34 @@ def parse_list(value):
  return parsed_value
  elif all(isinstance(item, str) for item in parsed_value):
  return parsed_value
+ elif all(isinstance(item, float) for item in parsed_value):
+ return parsed_value
+ else:
+ raise ValueError("List contains mixed types or unsupported types")
+ else:
+ raise ValueError(f"Expected a list but got {type(parsed_value).__name__}")
+ except (ValueError, SyntaxError) as e:
+ raise ValueError(f"Invalid format for list: {value}. Error: {e}")
+
+ def parse_list(value):
+ """
+ Parses a string representation of a list and returns the parsed list.
+
+ Args:
+ value (str): The string representation of the list.
+
+ Returns:
+ list: The parsed list, which can contain integers, floats, or strings.
+
+ Raises:
+ ValueError: If the input value is not a valid list format or contains mixed types or unsupported types.
+ """
+ try:
+ parsed_value = ast.literal_eval(value)
+ if isinstance(parsed_value, list):
+ # Check if all elements are homogeneous (either all int, float, or str)
+ if all(isinstance(item, (int, float, str)) for item in parsed_value):
+ return parsed_value
  else:
  raise ValueError("List contains mixed types or unsupported types")
  else:
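
The rewritten parse_list collapses the per-type branches of parse_list_v1 into one combined isinstance check. A behavioral sketch; note that the combined check also accepts lists mixing int, float, and str, even though the docstring still promises a ValueError for mixed types:

    parse_list("[0.2, 0.2]")   # [0.2, 0.2] -- float lists now parse
    parse_list("[1, 2, 3]")    # [1, 2, 3]
    parse_list("['a', 1]")     # ['a', 1] -- mixed list passes the combined check
    parse_list("(1, 2)")       # raises ValueError (a tuple, not a list)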
spacr/io.py CHANGED
@@ -2861,10 +2861,10 @@ def generate_dataset(settings={}):
  date_name = datetime.date.today().strftime('%y%m%d')
  if len(settings['src']) > 1:
  date_name = f"{date_name}_combined"
- if not settings['file_metadata'] is None:
- tar_name = f"{date_name}_{settings['experiment']}_{settings['file_metadata']}.tar"
- else:
- tar_name = f"{date_name}_{settings['experiment']}.tar"
+ #if not settings['file_metadata'] is None:
+ # tar_name = f"{date_name}_{settings['experiment']}_{settings['file_metadata']}.tar"
+ #else:
+ tar_name = f"{date_name}_{settings['experiment']}.tar"
  tar_name = os.path.join(dst, tar_name)
  if os.path.exists(tar_name):
  number = random.randint(1, 100)
spacr/measure.py CHANGED
@@ -652,43 +652,6 @@ def img_list_to_grid(grid, titles=None):
  plt.tight_layout(pad=0.1)
  return fig
 
- def filepaths_to_database(img_paths, settings, source_folder, crop_mode):
- from. utils import _map_wells_png
- png_df = pd.DataFrame(img_paths, columns=['png_path'])
-
- png_df['file_name'] = png_df['png_path'].apply(lambda x: os.path.basename(x))
-
- parts = png_df['file_name'].apply(lambda x: pd.Series(_map_wells_png(x, timelapse=settings['timelapse'])))
-
- columns = ['plate', 'row', 'col', 'field']
-
- if settings['timelapse']:
- columns = columns + ['time_id']
-
- columns = columns + ['prcfo']
-
- if crop_mode == 'cell':
- columns = columns + ['cell_id']
-
- if crop_mode == 'nucleus':
- columns = columns + ['nucleus_id']
-
- if crop_mode == 'pathogen':
- columns = columns + ['pathogen_id']
-
- if crop_mode == 'cytoplasm':
- columns = columns + ['cytoplasm_id']
-
- png_df[columns] = parts
-
- try:
- conn = sqlite3.connect(f'{source_folder}/measurements/measurements.db', timeout=5)
- png_df.to_sql('png_list', conn, if_exists='append', index=False)
- conn.commit()
- except sqlite3.OperationalError as e:
- print(f"SQLite error: {e}", flush=True)
- traceback.print_exc()
-
  #@log_function_call
  def _measure_crop_core(index, time_ls, file, settings):
 
@@ -711,7 +674,7 @@ def _measure_crop_core(index, time_ls, file, settings):
  """
 
  from .plot import _plot_cropped_arrays
- from .utils import _merge_overlapping_objects, _filter_object, _relabel_parent_with_child_labels, _exclude_objects, normalize_to_dtype
+ from .utils import _merge_overlapping_objects, _filter_object, _relabel_parent_with_child_labels, _exclude_objects, normalize_to_dtype, filepaths_to_database
  from .utils import _merge_and_save_to_database, _crop_center, _find_bounding_box, _generate_names, _get_percentiles
 
  figs = {}
spacr/settings.py CHANGED
@@ -246,7 +246,7 @@ def get_measure_crop_settings(settings={}):
  settings.setdefault('normalize_by','png')
  settings.setdefault('crop_mode',['cell'])
  settings.setdefault('dialate_pngs', False)
- settings.setdefault('dialate_png_ratios', [0.2, 0,2])
+ settings.setdefault('dialate_png_ratios', [0.2,0.2])
 
  # Timelapsed settings
  settings.setdefault('timelapse', False)
@@ -859,7 +859,7 @@ expected_types = {
  'dataset':str,
  'score_threshold':float,
  'sample':None,
- 'file_metadata':None,
+ 'file_metadata':(str, type(None), list),
  'apply_model_to_dataset':False,
  "train":bool,
  "test":bool,
@@ -880,6 +880,11 @@ expected_types = {
  "generate_training_dataset":bool,
  "segmentation_mode":str,
  "train_DL_model":bool,
+ "normalize":bool,
+ "overlay":bool,
+ "correlate":bool,
+ "target_layer":str,
+ "normalize_input":bool,
  }
 
  categories = {"Paths":[ "src", "grna", "barcodes", "custom_model_path", "dataset","model_path","grna_csv","row_csv","column_csv"],
@@ -889,18 +894,19 @@ categories = {"Paths":[ "src", "grna", "barcodes", "custom_model_path", "dataset
  "Nucleus": ["nucleus_intensity_range", "nucleus_size_range", "nucleus_chann_dim", "nucleus_channel", "nucleus_background", "nucleus_Signal_to_noise", "nucleus_CP_prob", "nucleus_FT", "remove_background_nucleus", "nucleus_min_size", "nucleus_mask_dim", "nucleus_loc"],
  "Pathogen": ["pathogen_intensity_range", "pathogen_size_range", "pathogen_chann_dim", "pathogen_channel", "pathogen_background", "pathogen_Signal_to_noise", "pathogen_CP_prob", "pathogen_FT", "pathogen_model", "remove_background_pathogen", "pathogen_min_size", "pathogen_mask_dim", "pathogens", "pathogen_loc", "pathogen_types", "pathogen_plate_metadata", ],
  "Measurements": ["remove_image_canvas", "remove_highly_correlated", "homogeneity", "homogeneity_distances", "radial_dist", "calculate_correlation", "manders_thresholds", "save_measurements", "tables", "image_nr", "dot_size", "filter_by", "remove_highly_correlated_features", "remove_low_variance_features", "channel_of_interest"],
- "Object Image": ["save_png", "dialate_pngs", "dialate_png_ratios", "png_size", "png_dims", "save_arrays", "normalize_by", "dialate_png_ratios", "crop_mode", "dialate_pngs", "normalize", "use_bounding_box"],
+ "Object Image": ["save_png", "dialate_pngs", "dialate_png_ratios", "png_size", "png_dims", "save_arrays", "normalize_by", "crop_mode", "dialate_pngs", "normalize", "use_bounding_box"],
  "Sequencing": ["signal_direction","mode","comp_level","comp_type","save_h5","expected_end","offset","target_sequence","regex", "highlight"],
  "Generate Dataset":["file_metadata","class_metadata", "annotation_column","annotated_classes", "dataset_mode", "metadata_type_by","custom_measurement", "sample", "size"],
  "Hyperparamiters (Training)": ["png_type", "score_threshold","file_type", "train_channels", "epochs", "loss_type", "optimizer_type","image_size","val_split","learning_rate","weight_decay","dropout_rate", "init_weights", "train", "classes", "augment", "amsgrad","use_checkpoint","gradient_accumulation","gradient_accumulation_steps","intermedeate_save","pin_memory"],
  "Hyperparamiters (Embedding)": ["visualize","n_neighbors","min_dist","metric","resnet_features","reduction_method","embedding_by_controls","col_to_compare","log_data"],
  "Hyperparamiters (Clustering)": ["eps","min_samples","analyze_clusters","clustering","remove_cluster_noise"],
  "Hyperparamiters (Regression)":["cov_type", "class_1_threshold", "plate", "other", "fraction_threshold", "alpha", "random_row_column_effects", "regression_type", "min_cell_count", "agg_type", "transform", "dependent_variable"],
+ "Hyperparamiters (Activation)":["cam_type", "normalize", "overlay", "correlation", "target_layer", "normalize_input"],
  "Annotation": ["nc_loc", "pc_loc", "nc", "pc", "cell_plate_metadata","treatment_plate_metadata", "metadata_types", "cell_types", "target","positive_control","negative_control", "location_column", "treatment_loc", "channel_of_interest", "measurement", "treatments", "um_per_pixel", "nr_imgs", "exclude", "exclude_conditions", "mix", "pos", "neg"],
  "Plot": ["plot", "plot_control", "plot_nr", "examples_to_plot", "normalize_plots", "cmap", "figuresize", "plot_cluster_grids", "img_zoom", "row_limit", "color_by", "plot_images", "smooth_lines", "plot_points", "plot_outlines", "black_background", "plot_by_cluster", "heatmap_feature","grouping","min_max","cmap","save_figure"],
  "Test": ["test_mode", "test_images", "random_test", "test_nr", "test", "test_split"],
  "Timelapse": ["timelapse", "fps", "timelapse_displacement", "timelapse_memory", "timelapse_frame_limits", "timelapse_remove_transient", "timelapse_mode", "timelapse_objects", "compartments"],
- "Advanced": ["target_intensity_min", "cells_per_well", "nuclei_limit", "pathogen_limit", "uninfected", "backgrounds", "schedule", "test_size","exclude","n_repeats","top_features", "model_type_ml", "model_type","minimum_cell_count","n_estimators","preprocess", "remove_background", "normalize", "lower_percentile", "merge_pathogens", "batch_size", "filter", "save", "masks", "verbose", "randomize", "n_jobs"],
+ "Advanced": ["shuffle", "target_intensity_min", "cells_per_well", "nuclei_limit", "pathogen_limit", "uninfected", "backgrounds", "schedule", "test_size","exclude","n_repeats","top_features", "model_type_ml", "model_type","minimum_cell_count","n_estimators","preprocess", "remove_background", "normalize", "lower_percentile", "merge_pathogens", "batch_size", "filter", "save", "masks", "verbose", "randomize", "n_jobs"],
  "Miscellaneous": ["all_to_mip", "pick_slice", "skip_mode", "upscale", "upscale_factor"]
  }
 
@@ -949,6 +955,14 @@ def check_settings(vars_dict, expected_types, q=None):
  settings[key] = float(value) if '.' in value else int(value)
  elif expected_type == (str, type(None)):
  settings[key] = str(value) if value else None
+ elif expected_type == (str, type(None), list):
+ if isinstance(value, list):
+ settings[key] = parse_list(value) if value else None
+ elif isinstance(value, str):
+ settings[key] = str(value)
+ else:
+ settings[key] = None
+
  elif expected_type == dict:
  try:
  # Ensure that the value is a string that can be converted to a dictionary
@@ -1206,7 +1220,7 @@ def generate_fields(variables, scrollable_frame):
  "dataset": "str - file name of the tar file with image dataset",
  "score_threshold": "float - threshold for classification",
  "sample": "str - number of images to sample for tar dataset (including both classes). Default: None",
- "file_metadata": "str - string that must be present in image path to be included in the dataset",
+ "file_metadata": "str or list of strings - string(s) that must be present in image path to be included in the dataset",
  "apply_model_to_dataset": "bool - whether to apply model to the dataset",
  "train_channels": "list - channels to use for training",
  "dataset_mode": "str - How to generate train/test dataset.",
@@ -1247,6 +1261,13 @@
  "mode": "(str) - Mode to use for sequence analysis (either single for R1 or R2 fastq files or paired for the combination of R1 and R2).",
  "signal_direction": "(str) - Direction of fastq file (R1 or R2). only relevent when mode is single.",
  "custom_model_path": "(str) - Path to the custom model to finetune.",
+ "cam_type": "(str) - Choose between: gradcam, gradcam_pp, saliency_image, saliency_channel to generate activation maps of DL models",
+ "target_layer": "(str) - Only used for gradcam and gradcam_pp. The layer to use for the activation map.",
+ "normalize": "(bool) - Normalize images before overlaying the activation maps.",
+ "overlay": "(bool) - Overlay activation maps on the images.",
+ "shuffle": "(bool) - Shuffle the dataset before generating the activation maps",
+ "correlation": "(bool) - Calculate correlation between image channels and activation maps. Data is saved to .db.",
+ "normalize_input": "(bool) - Normalize the input images before passing them to the model.",
  }
 
  for key, (var_type, options, default_value) in variables.items():
@@ -1282,6 +1303,8 @@ descriptions = {
 
  'regression': "Perform regression analysis on your data. Function: regression_tools from spacr.analysis.\n\nKey Features:\n- Statistical Analysis: Conduct various types of regression analysis to identify relationships within your data.\n- Flexible Options: Supports multiple regression models and configurations.\n- Data Insight: Gain deeper insights into your dataset through advanced regression techniques.",
 
+ 'activation': "",
+
  'recruitment': "Analyze recruitment data to understand sample recruitment dynamics. Function: recruitment_analysis_tools from spacr.analysis.\n\nKey Features:\n- Recruitment Analysis: Investigate and analyze the recruitment of samples over time or conditions.\n- Visualization: Generate visualizations to represent recruitment trends and patterns.\n- Integration: Utilize data from various sources for a comprehensive recruitment analysis."
  }
 
@@ -1314,4 +1337,25 @@ def set_default_generate_barecode_mapping(settings={}):
  settings.setdefault('mode', 'paired')
  settings.setdefault('single_direction', 'R1')
  settings.setdefault('test', False)
+ return settings
+
+ def get_default_generate_activation_map_settings(settings):
+ settings.setdefault('dataset', 'path')
+ settings.setdefault('model_type', 'maxvit')
+ settings.setdefault('model_path', 'path')
+ settings.setdefault('image_size', 224)
+ settings.setdefault('batch_size', 64)
+ settings.setdefault('normalize', True)
+ settings.setdefault('cam_type', 'gradcam')
+ settings.setdefault('target_layer', None)
+ settings.setdefault('plot', False)
+ settings.setdefault('save', True)
+ settings.setdefault('normalize_input', True)
+ settings.setdefault('channels', [1,2,3])
+ settings.setdefault('overlay', True)
+ settings.setdefault('shuffle', True)
+ settings.setdefault('correlation', True)
+ settings.setdefault('manders_thresholds', [15,50, 75])
+ settings.setdefault('n_jobs', None)
+
  return settings
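
The widened file_metadata type works together with the generate_path_list_from_db change in the utils.py diff below, where a list is expanded into OR-ed LIKE clauses. A sketch of the query built for a hypothetical two-item list:

    file_metadata = ['plate1', 'plate2']  # hypothetical substrings
    query = "SELECT png_path FROM png_list WHERE " + " OR ".join(
        ["png_path LIKE ?" for _ in file_metadata])
    params = [f"%{meta}%" for meta in file_metadata]
    # query  -> "SELECT png_path FROM png_list WHERE png_path LIKE ? OR png_path LIKE ?"
    # params -> ['%plate1%', '%plate2%']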
spacr/utils.py CHANGED
@@ -1,4 +1,4 @@
- import os, re, sqlite3, torch, torchvision, random, string, shutil, cv2, tarfile, glob, psutil, platform, gzip, subprocess, time, requests, ast
+ import os, re, sqlite3, torch, torchvision, random, string, shutil, cv2, tarfile, glob, psutil, platform, gzip, subprocess, time, requests, ast, traceback
 
  import numpy as np
  import pandas as pd
@@ -12,6 +12,7 @@ from skimage.transform import resize as resizescikit
  from skimage.morphology import dilation, square
  from skimage.measure import find_contours
  from skimage.segmentation import clear_border
+ from scipy.stats import pearsonr
 
  from collections import defaultdict, OrderedDict
  from PIL import Image
@@ -67,6 +68,192 @@ from huggingface_hub import list_repo_files
  import umap.umap_ as umap
  #import umap
 
+ def filepaths_to_database(img_paths, settings, source_folder, crop_mode):
+
+ png_df = pd.DataFrame(img_paths, columns=['png_path'])
+
+ png_df['file_name'] = png_df['png_path'].apply(lambda x: os.path.basename(x))
+
+ parts = png_df['file_name'].apply(lambda x: pd.Series(_map_wells_png(x, timelapse=settings['timelapse'])))
+
+ columns = ['plate', 'row', 'col', 'field']
+
+ if settings['timelapse']:
+ columns = columns + ['time_id']
+
+ columns = columns + ['prcfo']
+
+ if crop_mode == 'cell':
+ columns = columns + ['cell_id']
+
+ if crop_mode == 'nucleus':
+ columns = columns + ['nucleus_id']
+
+ if crop_mode == 'pathogen':
+ columns = columns + ['pathogen_id']
+
+ if crop_mode == 'cytoplasm':
+ columns = columns + ['cytoplasm_id']
+
+ png_df[columns] = parts
+
+ try:
+ conn = sqlite3.connect(f'{source_folder}/measurements/measurements.db', timeout=5)
+ png_df.to_sql('png_list', conn, if_exists='append', index=False)
+ conn.commit()
+ except sqlite3.OperationalError as e:
+ print(f"SQLite error: {e}", flush=True)
+ traceback.print_exc()
+
+ def activation_maps_to_database(img_paths, source_folder, settings):
+ from .io import _create_database
+
+ png_df = pd.DataFrame(img_paths, columns=['png_path'])
+ png_df['file_name'] = png_df['png_path'].apply(lambda x: os.path.basename(x))
+ parts = png_df['file_name'].apply(lambda x: pd.Series(_map_wells_png(x, timelapse=False)))
+ columns = ['plate', 'row', 'col', 'field', 'prcfo', 'object']
+ png_df[columns] = parts
+
+ dataset_name = os.path.splitext(os.path.basename(settings['dataset']))[0]
+ database_name = f"{source_folder}/measurements/{dataset_name}.db"
+
+ if not os.path.exists(database_name):
+ _create_database(database_name)
+
+ try:
+ conn = sqlite3.connect(database_name, timeout=5)
+ png_df.to_sql(f"{settings['cam_type']}_list", conn, if_exists='append', index=False)
+ conn.commit()
+ except sqlite3.OperationalError as e:
+ print(f"SQLite error: {e}", flush=True)
+ traceback.print_exc()
+
+ def activation_correlations_to_database(df, img_paths, source_folder, settings):
+ from .io import _create_database
+
+ png_df = pd.DataFrame(img_paths, columns=['png_path'])
+ png_df['file_name'] = png_df['png_path'].apply(lambda x: os.path.basename(x))
+ parts = png_df['file_name'].apply(lambda x: pd.Series(_map_wells_png(x, timelapse=False)))
+ columns = ['plate', 'row', 'col', 'field', 'prcfo', 'object']
+ png_df[columns] = parts
+
+ # Align both DataFrames by file_name
+ png_df.set_index('file_name', inplace=True)
+ df.set_index('file_name', inplace=True)
+
+ merged_df = pd.concat([png_df, df], axis=1)
+ merged_df.reset_index(inplace=True)
+
+ dataset_name = os.path.splitext(os.path.basename(settings['dataset']))[0]
+ database_name = f"{source_folder}/measurements/{dataset_name}.db"
+
+ if not os.path.exists(database_name):
+ _create_database(database_name)
+
+ try:
+ conn = sqlite3.connect(database_name, timeout=5)
+ merged_df.to_sql(f"{settings['cam_type']}_correlations", conn, if_exists='append', index=False)
+ conn.commit()
+ except sqlite3.OperationalError as e:
+ print(f"SQLite error: {e}", flush=True)
+ traceback.print_exc()
+
+ def calculate_activation_correlations(inputs, activation_maps, file_names, manders_thresholds=[15, 50, 75]):
+ """
+ Calculates Pearson and Manders correlations between input image channels and activation map channels.
+
+ Args:
+ inputs: A batch of input images, Tensor of shape (batch_size, channels, height, width)
+ activation_maps: A batch of activation maps, Tensor of shape (batch_size, channels, height, width)
+ file_names: List of file names corresponding to each image in the batch.
+ manders_thresholds: List of intensity percentiles to calculate Manders correlation.
+
+ Returns:
+ df_correlations: A DataFrame with columns for pairwise correlations (Pearson and Manders)
+ between input channels and activation map channels.
+ """
+
+ # Ensure tensors are detached and moved to CPU before converting to numpy
+ inputs = inputs.detach().cpu()
+ activation_maps = activation_maps.detach().cpu()
+
+ batch_size, in_channels, height, width = inputs.shape
+
+ if activation_maps.dim() == 3:
+ # If activation maps have no channels, add a dummy channel dimension
+ activation_maps = activation_maps.unsqueeze(1) # Now shape is (batch_size, 1, height, width)
+
+ _, act_channels, act_height, act_width = activation_maps.shape
+
+ # Ensure that the inputs and activation maps are the same size
+ if (height != act_height) or (width != act_width):
+ activation_maps = torch.nn.functional.interpolate(activation_maps, size=(height, width), mode='bilinear')
+
+ # Dictionary to collect correlation results
+ correlations_dict = {'file_name': []}
+
+ # Initialize correlation columns based on input channels and activation map channels
+ for in_c in range(in_channels):
+ for act_c in range(act_channels):
+ correlations_dict[f'channel_{in_c}_activation_{act_c}_pearsons'] = []
+ for threshold in manders_thresholds:
+ correlations_dict[f'channel_{in_c}_activation_{act_c}_{threshold}_M1'] = []
+ correlations_dict[f'channel_{in_c}_activation_{act_c}_{threshold}_M2'] = []
+
+ # Loop over the batch
+ for b in range(batch_size):
+ input_img = inputs[b] # Input image channels (C, H, W)
+ activation_map = activation_maps[b] # Activation map channels (C, H, W)
+
+ # Add the file name to the current row
+ correlations_dict['file_name'].append(file_names[b])
+
+ # Calculate correlations for each channel pair
+ for in_c in range(in_channels):
+ input_channel = input_img[in_c].flatten().numpy() # Flatten the input image channel
+ input_channel = input_channel[np.isfinite(input_channel)] # Remove NaN or inf values
+
+ for act_c in range(act_channels):
+ activation_channel = activation_map[act_c].flatten().numpy() # Flatten the activation map channel
+ activation_channel = activation_channel[np.isfinite(activation_channel)] # Remove NaN or inf values
+
+ # Check if there are valid (non-empty) arrays left to calculate the Pearson correlation
+ if input_channel.size > 0 and activation_channel.size > 0:
+ pearson_corr, _ = pearsonr(input_channel, activation_channel)
+ else:
+ pearson_corr = np.nan # Assign NaN if there are no valid data points
+ correlations_dict[f'channel_{in_c}_activation_{act_c}_pearsons'].append(pearson_corr)
+
+ # Compute Manders correlations for each threshold
+ for threshold in manders_thresholds:
+ # Get the top percentile pixels based on intensity in both channels
+ if input_channel.size > 0 and activation_channel.size > 0:
+ input_threshold = np.percentile(input_channel, threshold)
+ activation_threshold = np.percentile(activation_channel, threshold)
+
+ # Mask the pixels above the threshold
+ mask = (input_channel >= input_threshold) & (activation_channel >= activation_threshold)
+
+ # If we have enough pixels, calculate Manders correlation
+ if np.sum(mask) > 0:
+ manders_corr_M1 = np.sum(input_channel[mask] * activation_channel[mask]) / np.sum(input_channel[mask] ** 2)
+ manders_corr_M2 = np.sum(activation_channel[mask] * input_channel[mask]) / np.sum(activation_channel[mask] ** 2)
+ else:
+ manders_corr_M1 = np.nan
+ manders_corr_M2 = np.nan
+ else:
+ manders_corr_M1 = np.nan
+ manders_corr_M2 = np.nan
+
+ # Store the Manders correlation for this threshold
+ correlations_dict[f'channel_{in_c}_activation_{act_c}_{threshold}_M1'].append(manders_corr_M1)
+ correlations_dict[f'channel_{in_c}_activation_{act_c}_{threshold}_M2'].append(manders_corr_M2)
+
+ # Convert the dictionary to a DataFrame
+ df_correlations = pd.DataFrame(correlations_dict)
+
+ return df_correlations
+
  def load_settings(csv_file_path, show=False, setting_key='setting_key', setting_value='setting_value'):
  """
  Convert a CSV file with 'settings_key' and 'settings_value' columns into a dictionary.
@@ -892,7 +1079,7 @@ def _map_wells_png(file_name, timelapse=False):
  print(f"Error: {e}")
  plate, row, column, field, object_id, prcfo = 'error', 'error', 'error', 'error', 'error', 'error'
  if timelapse:
- return plate, row, column, field, timeid, prcfo, object_id,
+ return plate, row, column, field, timeid, prcfo, object_id
  else:
  return plate, row, column, field, prcfo, object_id
 
@@ -3097,46 +3284,176 @@ class SaliencyMapGenerator:
  saliency = X.grad.abs()
 
  return saliency, predictions
-
- def plot_saliency_grid(self, X, saliency, predictions, mode='mean'):
+
+ def plot_activation_grid(self, X, saliency, predictions, overlay=True, normalize=False):
  N = X.shape[0]
- rows = (N + 7) // 8 # Ensure we can handle batches of different sizes
+ rows = (N + 7) // 8
  fig, axs = plt.subplots(rows, 8, figsize=(16, rows * 2))
 
  for i in range(N):
  ax = axs[i // 8, i % 8]
+ saliency_map = saliency[i].cpu().numpy() # Move to CPU and convert to numpy
 
- if mode == 'mean':
- saliency_map = saliency[i].mean(dim=0).cpu().numpy() # Mean saliency over channels
- ax.imshow(X[i].permute(1, 2, 0).detach().cpu().numpy()) # Added .detach() here
+ if saliency_map.shape[0] == 3: # Channels first, reshape to (H, W, 3)
+ saliency_map = np.transpose(saliency_map, (1, 2, 0))
+
+ # Normalize image channels to 2nd and 98th percentiles
+ if overlay:
+ img_np = X[i].permute(1, 2, 0).detach().cpu().numpy()
+ if normalize:
+ img_np = self.percentile_normalize(img_np)
+ ax.imshow(img_np)
  ax.imshow(saliency_map, cmap='jet', alpha=0.5)
 
- elif mode == 'channel':
- # Plot individual channels in a loop if the image has multiple channels
- for j in range(X.shape[1]):
- saliency_map = saliency[i, j].cpu().numpy()
- ax.imshow(saliency_map, cmap='jet')
- ax.axis('off')
+ # Add class label in the top-left corner
+ ax.text(5, 25, str(predictions[i].item()), fontsize=12, color='white', weight='bold',
+ bbox=dict(facecolor='black', alpha=0.7, boxstyle='round,pad=0.2'))
+ ax.axis('off')
+
+ plt.tight_layout(pad=0)
+ return fig
+
+ def percentile_normalize(self, img, lower_percentile=2, upper_percentile=98):
+ """
+ Normalize each channel of the image to the given percentiles.
+ Args:
+ img: Input image as numpy array with shape (H, W, C)
+ lower_percentile: Lower percentile for normalization (default 2)
+ upper_percentile: Upper percentile for normalization (default 98)
+ Returns:
+ img: Normalized image
+ """
+ img_normalized = np.zeros_like(img)
 
- elif mode == '3-channel' and X.shape[1] == 3:
- saliency_map = saliency[i].cpu().numpy().transpose(1, 2, 0)
- ax.imshow(saliency_map)
-
- elif mode == '2-channel' and X.shape[1] == 2:
- saliency_map = saliency[i].cpu().numpy().transpose(1, 2, 0)
- ax.imshow(saliency_map)
+ for c in range(img.shape[2]): # Iterate over each channel
+ low = np.percentile(img[:, :, c], lower_percentile)
+ high = np.percentile(img[:, :, c], upper_percentile)
+ img_normalized[:, :, c] = np.clip((img[:, :, c] - low) / (high - low), 0, 1)
+
+ return img_normalized
+
+
+ class GradCAMGenerator:
+ def __init__(self, model, target_layer, cam_type='gradcam'):
+ self.model = model
+ self.model.eval()
+ self.target_layer = target_layer
+ self.cam_type = cam_type
+ self.gradients = None
+ self.activations = None
+
+ # Hook the target layer
+ self.target_layer_module = self.get_layer(self.model, self.target_layer)
+ self.hook_layers()
+
+ def hook_layers(self):
+ # Forward hook to get activations
+ def forward_hook(module, input, output):
+ self.activations = output
+
+ # Backward hook to get gradients
+ def backward_hook(module, grad_input, grad_output):
+ self.gradients = grad_output[0]
+
+ self.target_layer_module.register_forward_hook(forward_hook)
+ self.target_layer_module.register_backward_hook(backward_hook)
+
+ def get_layer(self, model, target_layer):
+ # Recursively find the layer specified in target_layer
+ modules = target_layer.split('.')
+ layer = model
+ for module in modules:
+ layer = getattr(layer, module)
+ return layer
+
+ def compute_gradcam_maps(self, X, y):
+ X.requires_grad_()
 
- # Add class label in top-left corner
+ # Forward pass
+ scores = self.model(X).squeeze()
+
+ # Perform backward pass
+ target_scores = scores * (2 * y - 1)
+ self.model.zero_grad()
+ target_scores.backward(torch.ones_like(target_scores))
+
+ # Compute GradCAM
+ pooled_gradients = torch.mean(self.gradients, dim=[0, 2, 3])
+ for i in range(self.activations.size(1)):
+ self.activations[:, i, :, :] *= pooled_gradients[i]
+
+ gradcam = torch.mean(self.activations, dim=1).squeeze()
+ gradcam = F.relu(gradcam)
+ gradcam = F.interpolate(gradcam.unsqueeze(0).unsqueeze(0), size=X.shape[2:], mode='bilinear')
+ gradcam = gradcam.squeeze().cpu().detach().numpy()
+ gradcam = (gradcam - gradcam.min()) / (gradcam.max() - gradcam.min())
+
+ return gradcam
+
+ def compute_gradcam_and_predictions(self, X):
+ self.model.eval()
+ X.requires_grad_()
+
+ # Forward pass to get predictions (logits)
+ scores = self.model(X).squeeze()
+
+ # Get predicted class (0 or 1 for binary classification)
+ predictions = (scores > 0).long()
+
+ # Compute gradcam maps
+ gradcam_maps = []
+ for i in range(X.size(0)):
+ gradcam_map = self.compute_gradcam_maps(X[i].unsqueeze(0), predictions[i])
+ gradcam_maps.append(gradcam_map)
+
+ return torch.tensor(gradcam_maps), predictions
+
+ def plot_activation_grid(self, X, gradcam, predictions, overlay=True, normalize=False):
+ N = X.shape[0]
+ rows = (N + 7) // 8
+ fig, axs = plt.subplots(rows, 8, figsize=(16, rows * 2))
+
+ for i in range(N):
+ ax = axs[i // 8, i % 8]
+ gradcam_map = gradcam[i].cpu().numpy()
+
+ # Normalize image channels to 2nd and 98th percentiles
+ if overlay:
+ img_np = X[i].permute(1, 2, 0).detach().cpu().numpy()
+ if normalize:
+ img_np = self.percentile_normalize(img_np)
+ ax.imshow(img_np)
+ ax.imshow(gradcam_map, cmap='jet', alpha=0.5)
+
+ #ax.imshow(X[i].permute(1, 2, 0).detach().cpu().numpy()) # Original image
+ #ax.imshow(gradcam_map, cmap='jet', alpha=0.5) # Overlay the gradcam map
+
+ # Add class label in the top-left corner
  ax.text(5, 25, str(predictions[i].item()), fontsize=12, color='white', weight='bold',
  bbox=dict(facecolor='black', alpha=0.7, boxstyle='round,pad=0.2'))
  ax.axis('off')
 
- # Turn off unused axes
- for j in range(N, rows * 8):
- fig.delaxes(axs[j // 8, j % 8])
-
  plt.tight_layout(pad=0)
- plt.show()
+ return fig
+
+ def percentile_normalize(self, img, lower_percentile=2, upper_percentile=98):
+ """
+ Normalize each channel of the image to the given percentiles.
+ Args:
+ img: Input image as numpy array with shape (H, W, C)
+ lower_percentile: Lower percentile for normalization (default 2)
+ upper_percentile: Upper percentile for normalization (default 98)
+ Returns:
+ img: Normalized image
+ """
+ img_normalized = np.zeros_like(img)
+
+ for c in range(img.shape[2]): # Iterate over each channel
+ low = np.percentile(img[:, :, c], lower_percentile)
+ high = np.percentile(img[:, :, c], upper_percentile)
+ img_normalized[:, :, c] = np.clip((img[:, :, c] - low) / (high - low), 0, 1)
+
+ return img_normalized
 
  def preprocess_image(image_path, normalize=True, image_size=224, channels=[1,2,3]):
  preprocess = transforms.Compose([
@@ -3677,8 +3994,37 @@ def plot_grid(cluster_images, colors, figuresize, black_background, verbose):
  plt.show()
  return grid_fig
 
- def generate_path_list_from_db(db_path, file_metadata):
+ def generate_path_list_from_db_v1(db_path, file_metadata):
+
+ all_paths = []
+
+ # Connect to the database and retrieve the image paths
+ print(f"Reading DataBase: {db_path}")
+ try:
+ with sqlite3.connect(db_path) as conn:
+ cursor = conn.cursor()
+ if file_metadata:
+ if isinstance(file_metadata, str):
+ cursor.execute("SELECT png_path FROM png_list WHERE png_path LIKE ?", (f"%{file_metadata}%",))
+ else:
+ cursor.execute("SELECT png_path FROM png_list")
 
+ while True:
+ rows = cursor.fetchmany(1000)
+ if not rows:
+ break
+ all_paths.extend([row[0] for row in rows])
+
+ except sqlite3.Error as e:
+ print(f"Database error: {e}")
+ return
+ except Exception as e:
+ print(f"Error: {e}")
+ return
+
+ return all_paths
+
+ def generate_path_list_from_db(db_path, file_metadata):
  all_paths = []
 
  # Connect to the database and retrieve the image paths
@@ -3686,10 +4032,19 @@ def generate_path_list_from_db(db_path, file_metadata):
  try:
  with sqlite3.connect(db_path) as conn:
  cursor = conn.cursor()
+
  if file_metadata:
  if isinstance(file_metadata, str):
+ # If file_metadata is a single string
  cursor.execute("SELECT png_path FROM png_list WHERE png_path LIKE ?", (f"%{file_metadata}%",))
+ elif isinstance(file_metadata, list):
+ # If file_metadata is a list of strings
+ query = "SELECT png_path FROM png_list WHERE " + " OR ".join(
+ ["png_path LIKE ?" for _ in file_metadata])
+ params = [f"%{meta}%" for meta in file_metadata]
+ cursor.execute(query, params)
  else:
+ # If file_metadata is None or empty
  cursor.execute("SELECT png_path FROM png_list")
 
  while True:
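
A note on the Manders coefficients in calculate_activation_correlations: M1 and M2 are computed only over pixels above the given percentile in both channels, share the numerator sum(input * activation), and differ only in the normalizer, so unlike the classical Manders split coefficients they are not bounded by 1. A toy check of the formulas, using only numpy:

    import numpy as np

    input_channel = np.array([1.0, 2.0, 3.0, 4.0])
    activation_channel = np.array([0.5, 0.5, 2.0, 4.0])

    threshold = 50  # percentile, as in manders_thresholds
    mask = (input_channel >= np.percentile(input_channel, threshold)) & \
           (activation_channel >= np.percentile(activation_channel, threshold))
    # mask selects pixels bright in both channels: here the last two

    M1 = np.sum(input_channel[mask] * activation_channel[mask]) / np.sum(input_channel[mask] ** 2)
    M2 = np.sum(activation_channel[mask] * input_channel[mask]) / np.sum(activation_channel[mask] ** 2)
    # M1 = (3*2 + 4*4) / (3**2 + 4**2) = 22/25 = 0.88
    # M2 = 22 / (2**2 + 4**2)          = 22/20 = 1.10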
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: spacr
- Version: 0.3.22
+ Version: 0.3.30
  Summary: Spatial phenotype analysis of crisp screens (SpaCr)
  Home-page: https://github.com/EinarOlafsson/spacr
  Author: Einar Birnir Olafsson
@@ -9,25 +9,25 @@ spacr/app_sequencing.py,sha256=DjG26jy4cpddnV8WOOAIiExtOe9MleVMY4MFa5uTo5w,157
  spacr/app_umap.py,sha256=ZWAmf_OsIKbYvolYuWPMYhdlVe-n2CADoJulAizMiEo,153
  spacr/cellpose.py,sha256=zv4BzhaP2O-mtQ-pUfYvpOyxgn1ke_bDWgdHD5UWm9I,13942
  spacr/core.py,sha256=G_x-w7FRIHNfSOoPaIZPSf_A7mVj7PA7o9HQZ4nIu5o,48231
- spacr/deep_spacr.py,sha256=iPJhwhNQKF0_PmQ3RXi_gK7BKvIb5m54DeYZRlXBqMU,46081
- spacr/gui.py,sha256=ndmWP4F0QO5j6DM6MNzoGtzv_7Yj4LTW2SLi9URBZIQ,8055
- spacr/gui_core.py,sha256=OJQxzpehIyDzjSjIsvxSHat4NIjkqjX0VZAUQTnzEzg,40921
+ spacr/deep_spacr.py,sha256=HdOcNU8cHcE_19nP7_5uTz-ih3E169ffr2Hm--NvMvA,43255
+ spacr/gui.py,sha256=ARyn9Q_g8HoP-cXh1nzMLVFCKqthY4v2u9yORyaQqQE,8230
+ spacr/gui_core.py,sha256=LV_HX5zreu3Bye6sQFDbOuk8Dfj4StMoohy6hsrDEXA,41363
  spacr/gui_elements.py,sha256=3ru8FPZtXCZSj7167GJj18-Zo6TVebhAzkit-mmqmTI,135342
- spacr/gui_utils.py,sha256=76utRICvY0k_6X8CA1P_TmYBJARp4b87OkI9t39tldA,45822
- spacr/io.py,sha256=Xy1Drm5NPhxvwE1nyJVd2SQu3yTynlnvSrjlRuFXwBw,143371
+ spacr/gui_utils.py,sha256=hY7JC8HMlyKa9d7tDjkgXgRILgBYTw85jAkRsexO0P0,46960
+ spacr/io.py,sha256=AARmqn1fMmTgVDwWy8bEYK6SjH-6DZIulgCSPdBTyf0,143370
  spacr/logger.py,sha256=lJhTqt-_wfAunCPl93xE65Wr9Y1oIHJWaZMjunHUeIw,1538
- spacr/measure.py,sha256=8MRjQdB-2n8JVLjEpF3cxvfT-Udug27uJ2ErJJ5t1ic,56000
+ spacr/measure.py,sha256=BThn_sALgKrwGKnLOGpT4FyoJeRVoTZoP9SXbCtCMRw,54857
  spacr/mediar.py,sha256=FwLvbLQW5LQzPgvJZG8Lw7GniA2vbZx6Jv6vIKu7I5c,14743
  spacr/ml.py,sha256=3XiQUfhhseCz9cZXhaVkCCv_qfqoZCdXGnO_p3ulwo4,47131
  spacr/openai.py,sha256=5vBZ3Jl2llYcW3oaTEXgdyCB2aJujMUIO5K038z7w_A,1246
  spacr/plot.py,sha256=eZcs-CQrDTENXVeMY8y8N8VZnmPePO-kAWdoaweFmW8,105540
  spacr/sequencing.py,sha256=t18mgpK6rhWuB1LtFOsPxqgpFXxuUmrD06ecsaVQ0Gw,19655
- spacr/settings.py,sha256=YExChD7DWY_cJyaPGKDTpFajsXXi5ZQ8P0XR9ZQf8CE,73560
+ spacr/settings.py,sha256=BUQv8mSQLaw3yT08cKB0x5Y5gl0-S7AxmV6TABoPQlk,75773
  spacr/sim.py,sha256=1xKhXimNU3ukzIw-3l9cF3Znc_brW8h20yv8fSTzvss,71173
  spacr/submodules.py,sha256=AB7s6-cULsaqz-haAaCtXfGEIi8uPZGT4xoCslUJC3Y,18391
  spacr/timelapse.py,sha256=FSYpUtAVy6xc3lwprRYgyDTT9ysUhfRQ4zrP9_h2mvg,39465
  spacr/toxo.py,sha256=us3pQyULtMTyfTq0MWPn4QJTTmQ6BwAJKChNf75jo3I,10082
- spacr/utils.py,sha256=brlNXsDcsKyHjJ2IodB0KyMQkEpQfMBp5QZCCb0vdz8,198459
+ spacr/utils.py,sha256=w4Cht32Mhep7jfXKm5CSpyFLB3lOxiBCQI6PnaYcI3Q,213360
  spacr/version.py,sha256=axH5tnGwtgSnJHb5IDhiu4Zjk5GhLyAEDRe-rnaoFOA,409
  spacr/resources/MEDIAR/.gitignore,sha256=Ff1q9Nme14JUd-4Q3jZ65aeQ5X4uttptssVDgBVHYo8,152
  spacr/resources/MEDIAR/LICENSE,sha256=yEj_TRDLUfDpHDNM0StALXIt6mLqSgaV2hcCwa6_TcY,1065
@@ -150,9 +150,9 @@ spacr/resources/icons/umap.png,sha256=dOLF3DeLYy9k0nkUybiZMe1wzHQwLJFRmgccppw-8b
  spacr/resources/images/plate1_E01_T0001F001L01A01Z01C02.tif,sha256=Tl0ZUfZ_AYAbu0up_nO0tPRtF1BxXhWQ3T3pURBCCRo,7958528
  spacr/resources/images/plate1_E01_T0001F001L01A02Z01C01.tif,sha256=m8N-V71rA1TT4dFlENNg8s0Q0YEXXs8slIn7yObmZJQ,7958528
  spacr/resources/images/plate1_E01_T0001F001L01A03Z01C03.tif,sha256=Pbhk7xn-KUP6RSIhJsxQcrHFImBm3GEpLkzx7WOc-5M,7958528
- spacr-0.3.22.dist-info/LICENSE,sha256=SR-2MeGc6SCM1UORJYyarSWY_A-JaOMFDj7ReSs9tRM,1083
- spacr-0.3.22.dist-info/METADATA,sha256=wHP5zD5dSSsxLHNjlr3-OALEmhzL7gexG_uqX6M0OWc,5949
- spacr-0.3.22.dist-info/WHEEL,sha256=HiCZjzuy6Dw0hdX5R3LCFPDmFS4BWl8H-8W39XfmgX4,91
- spacr-0.3.22.dist-info/entry_points.txt,sha256=BMC0ql9aNNpv8lUZ8sgDLQMsqaVnX5L535gEhKUP5ho,296
- spacr-0.3.22.dist-info/top_level.txt,sha256=GJPU8FgwRXGzKeut6JopsSRY2R8T3i9lDgya42tLInY,6
- spacr-0.3.22.dist-info/RECORD,,
+ spacr-0.3.30.dist-info/LICENSE,sha256=SR-2MeGc6SCM1UORJYyarSWY_A-JaOMFDj7ReSs9tRM,1083
+ spacr-0.3.30.dist-info/METADATA,sha256=pi7DGlwhEsgk3YMKEDCiATIXZC1ku_PuEo_CKGjDUE4,5949
+ spacr-0.3.30.dist-info/WHEEL,sha256=HiCZjzuy6Dw0hdX5R3LCFPDmFS4BWl8H-8W39XfmgX4,91
+ spacr-0.3.30.dist-info/entry_points.txt,sha256=BMC0ql9aNNpv8lUZ8sgDLQMsqaVnX5L535gEhKUP5ho,296
+ spacr-0.3.30.dist-info/top_level.txt,sha256=GJPU8FgwRXGzKeut6JopsSRY2R8T3i9lDgya42tLInY,6
+ spacr-0.3.30.dist-info/RECORD,,