spacr 0.3.2__py3-none-any.whl → 0.3.22__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
spacr/core.py CHANGED
@@ -844,4 +844,108 @@ def generate_mediar_masks(src, settings, object_type):
844
844
  gc.collect()
845
845
  torch.cuda.empty_cache()
846
846
 
847
- print("Mask generation completed.")
847
+ print("Mask generation completed.")
848
+
849
def _screen_graph_tag(settings):
    """Return the ``<representation>_<summary>_<graph_type>`` tag used in output file names.

    ``settings['summary_func']`` may be a string or a callable. A callable is
    represented by its ``__name__`` so the file name never embeds an object repr
    such as ``<function mean at 0x...>`` (which contains characters that are
    invalid in Windows file names).
    """
    summary = settings['summary_func']
    if callable(summary):
        summary = getattr(summary, '__name__', type(summary).__name__)
    return f"{settings['representation']}_{summary}_{settings['graph_type']}"


def generate_screen_graphs(settings):
    """
    Generate recruitment graphs for one or more screen directories.

    For each source directory, ``<src>/measurements/measurements.db`` is read and
    merged, annotated with cell/pathogen conditions, and a ``recruitment`` metric is
    computed as ``pathogen_channel_1_mean_intensity / cytoplasm_channel_1_mean_intensity``.
    One graph (grouped by ``pathogen``) is drawn per source, followed by one graph of
    all sources combined.

    Args:
        settings (dict): Configuration with the keys:
            src (str, path-like or list): Source directory or directories.
            tables (list): Tables passed to ``_read_and_merge_data``.
            nuclei_limit, pathogen_limit, uninfected: Forwarded to ``_read_and_merge_data``.
            cells, controls, controls_loc: Forwarded to ``annotate_conditions`` as
                ``cells``, ``pathogens`` and ``pathogen_loc``.
            graph_type, summary_func, y_axis_start, error_bar_type, theme,
            representation: Forwarded to ``spacrGraph``.

    Returns:
        tuple: ``(figs, results)`` -- the generated figures and their result
        DataFrames, one per source directory followed by one for the combined data.

    Raises:
        ValueError: If ``settings['src']`` contains no source directory.

    Side effects:
        Writes ``figure_controls_<i>_<tag>.pdf`` and ``results_controls_<i>_<tag>.csv``
        into ``<src>/results`` for each source. The combined outputs are written
        to the first source directory.
    """
    src_setting = settings['src']
    if isinstance(src_setting, (str, os.PathLike)):
        srcs = [src_setting]
    else:
        srcs = list(src_setting)
    if not srcs:
        raise ValueError("settings['src'] must contain at least one source directory.")

    from .plot import spacrGraph
    from .io import _read_and_merge_data
    from .utils import annotate_conditions

    def _plot_recruitment(df):
        # Draw one recruitment graph grouped by pathogen; return (figure, results).
        plotter = spacrGraph(df,
                             grouping_column='pathogen',
                             data_column='recruitment',
                             graph_type=settings['graph_type'],
                             summary_func=settings['summary_func'],
                             y_axis_start=settings['y_axis_start'],
                             error_bar_type=settings['error_bar_type'],
                             theme=settings['theme'],
                             representation=settings['representation'])
        plotter.create_plot()
        return plotter.get_figure(), plotter.get_results()

    frames = []
    figs = []
    results = []

    for src in srcs:
        db_loc = [os.path.join(src, 'measurements', 'measurements.db')]

        # Read and merge data from the database
        df, _ = _read_and_merge_data(db_loc, settings['tables'], verbose=True,
                                     nuclei_limit=settings['nuclei_limit'],
                                     pathogen_limit=settings['pathogen_limit'],
                                     uninfected=settings['uninfected'])

        # Annotate the data with cell types and control pathogens
        df = annotate_conditions(df, cells=settings['cells'], cell_loc=None,
                                 pathogens=settings['controls'],
                                 pathogen_loc=settings['controls_loc'],
                                 treatments=None, treatment_loc=None)

        # Recruitment: channel-1 intensity in the pathogen compartment relative to the cytoplasm
        df['recruitment'] = df['pathogen_channel_1_mean_intensity'] / df['cytoplasm_channel_1_mean_intensity']
        frames.append(df)

        fig, res = _plot_recruitment(df)
        figs.append(fig)
        results.append(res)

    # Concatenate once instead of growing a DataFrame inside the loop (quadratic copying).
    all_df = pd.concat(frames, ignore_index=True)
    fig, res = _plot_recruitment(all_df)
    figs.append(fig)
    results.append(res)

    # The combined figure (last entry) has no source of its own; store it with the first source.
    destinations = srcs + [srcs[0]]
    tag = _screen_graph_tag(settings)
    for i, (fig, res, source) in enumerate(zip(figs, results, destinations)):
        dst = os.path.join(source, 'results')
        print(f"Saving results to {dst}")
        os.makedirs(dst, exist_ok=True)
        fig.savefig(os.path.join(dst, f"figure_controls_{i}_{tag}.pdf"), format='pdf')
        res.to_csv(os.path.join(dst, f"results_controls_{i}_{tag}.csv"), index=False)

    return figs, results
spacr/deep_spacr.py CHANGED
@@ -1,4 +1,4 @@
1
- import os, torch, time, gc, datetime
1
+ import os, torch, time, gc, datetime, cv2
2
2
  torch.backends.cudnn.benchmark = True
3
3
 
4
4
  import numpy as np
@@ -10,6 +10,8 @@ import torch.nn.functional as F
10
10
  import matplotlib.pyplot as plt
11
11
  from PIL import Image
12
12
  from sklearn.metrics import auc, precision_recall_curve
13
+ from IPython.display import display
14
+ from multiprocessing import cpu_count
13
15
 
14
16
  from torchvision import transforms
15
17
  from torch.utils.data import DataLoader
@@ -73,6 +75,12 @@ def apply_model_to_tar(settings={}):
73
75
 
74
76
  from .io import TarImageDataset
75
77
  from .utils import process_vision_results, print_progress
78
+
79
+ if os.path.exists(settings['dataset']):
80
+ tar_path = settings['dataset']
81
+ else:
82
+ tar_path = os.path.join(settings['src'], 'datasets', settings['dataset'])
83
+ model_path = settings['model_path']
76
84
 
77
85
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
78
86
  if settings['normalize']:
@@ -86,18 +94,18 @@ def apply_model_to_tar(settings={}):
86
94
  transforms.CenterCrop(size=(settings['image_size'], settings['image_size']))])
87
95
 
88
96
  if settings['verbose']:
89
- print(f"Loading model from {settings['model_path']}")
90
- print(f"Loading dataset from {settings['tar_path']}")
97
+ print(f"Loading model from {model_path}")
98
+ print(f"Loading dataset from {tar_path}")
91
99
 
92
100
  model = torch.load(settings['model_path'])
93
101
 
94
- dataset = TarImageDataset(settings['tar_path'], transform=transform)
102
+ dataset = TarImageDataset(tar_path, transform=transform)
95
103
  data_loader = DataLoader(dataset, batch_size=settings['batch_size'], shuffle=True, num_workers=settings['n_jobs'], pin_memory=True)
96
104
 
97
- model_name = os.path.splitext(os.path.basename(settings['model_path']))[0]
98
- dataset_name = os.path.splitext(os.path.basename(settings['tar_path']))[0]
105
+ model_name = os.path.splitext(os.path.basename(model_path))[0]
106
+ dataset_name = os.path.splitext(os.path.basename(settings['dataset']))[0]
99
107
  date_name = datetime.date.today().strftime('%y%m%d')
100
- dst = os.path.dirname(settings['tar_path'])
108
+ dst = os.path.dirname(tar_path)
101
109
  result_loc = f'{dst}/{date_name}_{dataset_name}_{model_name}_result.csv'
102
110
 
103
111
  model.eval()
@@ -240,7 +248,7 @@ def evaluate_model_performance(model, loader, epoch, loss_type):
240
248
 
241
249
  loss /= len(loader)
242
250
  data_dict = classification_metrics(all_labels, prediction_pos_probs)
243
- data_dict['loss'] = loss
251
+ data_dict['loss'] = loss.item()
244
252
  data_dict['epoch'] = epoch
245
253
  data_dict['Accuracy'] = acc
246
254
 
@@ -323,8 +331,8 @@ def test_model_performance(loaders, model, loader_name_list, epoch, loss_type):
323
331
 
324
332
  def train_test_model(settings):
325
333
 
326
- from .io import _save_settings, _copy_missclassified
327
- from .utils import pick_best_model
334
+ from .io import _copy_missclassified
335
+ from .utils import pick_best_model, save_settings
328
336
  from .io import generate_loaders
329
337
  from .settings import get_train_test_model_settings
330
338
 
@@ -346,7 +354,12 @@ def train_test_model(settings):
346
354
  model = torch.load(settings['custom_model_path'])
347
355
 
348
356
  if settings['train']:
349
- _save_settings(settings, src)
357
+ if settings['train'] and settings['test']:
358
+ save_settings(settings, name=f"train_test_{settings['model_type']}_{settings['epochs']}", show=True)
359
+ elif settings['train'] is True:
360
+ save_settings(settings, name=f"train_{settings['model_type']}_{settings['epochs']}", show=True)
361
+ elif settings['test'] is True:
362
+ save_settings(settings, name=f"test_{settings['model_type']}_{settings['epochs']}", show=True)
350
363
 
351
364
  if settings['train']:
352
365
  train, val, train_fig = generate_loaders(src,
@@ -574,19 +587,21 @@ def train_model(dst, model_type, train_loaders, epochs=100, learning_rate=0.0001
574
587
  if schedule == 'step_lr':
575
588
  scheduler.step()
576
589
 
577
- if epoch % 10 == 0 or epoch == epochs:
578
- if accumulated_train_dicts:
579
- train_df = pd.DataFrame(accumulated_train_dicts)
580
- _save_progress(dst, train_df, result_type='train')
581
-
582
- if accumulated_val_dicts:
583
- val_df = pd.DataFrame(accumulated_val_dicts)
584
- _save_progress(dst, val_df,result_type='validation')
585
-
586
- if accumulated_test_dicts:
587
- val_df = pd.DataFrame(accumulated_test_dicts)
588
- _save_progress(dst, val_df, result_type='test')
589
-
590
+ if accumulated_train_dicts and accumulated_val_dicts:
591
+ train_df = pd.DataFrame(accumulated_train_dicts)
592
+ validation_df = pd.DataFrame(accumulated_val_dicts)
593
+ _save_progress(dst, train_df, validation_df)
594
+ accumulated_train_dicts, accumulated_val_dicts = [], []
595
+
596
+ elif accumulated_train_dicts:
597
+ train_df = pd.DataFrame(accumulated_train_dicts)
598
+ _save_progress(dst, train_df, None)
599
+ accumulated_train_dicts = []
600
+ elif accumulated_test_dicts:
601
+ test_df = pd.DataFrame(accumulated_test_dicts)
602
+ _save_progress(dst, test_df, None)
603
+ accumulated_test_dicts = []
604
+
590
605
  batch_size = len(train_loaders)
591
606
  duration = time.time() - start_time
592
607
  time_ls.append(duration)
@@ -595,7 +610,138 @@ def train_model(dst, model_type, train_loaders, epochs=100, learning_rate=0.0001
595
610
 
596
611
  return model, model_path
597
612
 
598
- def visualize_saliency_map(src, model_type='maxvit', model_path='', image_size=224, channels=[1,2,3], normalize=True, class_names=None, save_saliency=False, save_dir='saliency_maps'):
613
def _saliency_to_dtype(saliency_map, dtype):
    """Scale a saliency array to an integer image array.

    Values are clipped to [0, 1] and multiplied by the dtype's maximum, then
    truncated by ``astype``.

    Args:
        saliency_map (np.ndarray): Saliency values (any shape).
        dtype (str): 'uint8' or 'uint16'.

    Returns:
        tuple: ``(scaled array, Pillow mode for single-channel images)`` where the
        mode is 'L' for uint8 and 'I;16' for uint16.

    Raises:
        ValueError: If ``dtype`` is not 'uint8' or 'uint16'.
    """
    if dtype == 'uint16':
        return (np.clip(saliency_map, 0, 1) * 65535).astype(np.uint16), 'I;16'
    if dtype == 'uint8':
        return (np.clip(saliency_map, 0, 1) * 255).astype(np.uint8), 'L'
    raise ValueError(f"Unsupported dtype: {dtype}")


def _saliency_to_image(saliency_map, mode):
    """Convert a (C, H, W) or (H, W) integer saliency array into a PIL image.

    Raises:
        ValueError: For a channel count other than 1 or 3, or a rank other than 2 or 3.
    """
    if saliency_map.ndim == 3:
        if saliency_map.shape[0] == 3:
            # (C, H, W) -> (H, W, C).
            # NOTE(review): Pillow 'RGB' expects uint8 data; 3-channel uint16 maps are not
            # represented faithfully -- confirm whether this combination should be supported.
            return Image.fromarray(np.moveaxis(saliency_map, 0, -1), mode="RGB")
        if saliency_map.shape[0] == 1:
            # Index the channel axis rather than np.squeeze, which would also drop
            # spatial dimensions of size 1.
            return Image.fromarray(saliency_map[0], mode=mode)
        raise ValueError(f"Unexpected number of channels: {saliency_map.shape[0]}")
    if saliency_map.ndim == 2:
        return Image.fromarray(saliency_map, mode=mode)
    raise ValueError(f"Unexpected number of dimensions: {saliency_map.ndim}")


def _save_saliency_batch(saliency_maps, predicted_classes, filenames, save_dir, dtype):
    """Write each saliency map of a batch to ``save_dir/class_<predicted class>/<filename>``."""
    # Move the whole batch to host memory once instead of once per sample.
    maps = saliency_maps.detach().cpu().numpy()
    for saliency_map, predicted, filename in zip(maps, predicted_classes, filenames):
        scaled, mode = _saliency_to_dtype(saliency_map, dtype)
        class_dir = os.path.join(save_dir, f'class_{predicted.item()}')
        os.makedirs(class_dir, exist_ok=True)
        _saliency_to_image(scaled, mode).save(os.path.join(class_dir, filename))


def visualize_saliency_map(settings):
    """
    Compute saliency maps for every image in a tar dataset, optionally plotting and saving them.

    Args:
        settings (dict): Configuration with the keys:
            dataset (str): Path to the tar dataset. The function prints a message and
                returns if the path does not exist.
            model_path (str): Path of the model loaded with ``torch.load``.
            n_jobs (int or None): DataLoader workers; ``None`` means ``max(1, cpu_count() - 4)``.
            normalize (bool): If true, normalize images with mean/std 0.5 per channel.
            image_size (int): Side length of the center crop.
            batch_size (int): DataLoader batch size.
            saliency_mode (str): 'mean' or 'sum' collapses the channel dimension. Any
                other value keeps per-channel maps, and a notice is printed.
            plot (bool, optional): If true, plot a saliency grid for every batch.
            plot_mode (str): 'mean', 'channel' or '3-channel'; required when ``plot`` is true.
            save (bool, optional): If true, save every map as an image.
            dtype (str): 'uint8' or 'uint16'; required when ``save`` is true.

    Returns:
        None. When ``save`` is true, images are written to
        ``<dataset dir>/<dataset name>/saliency_maps/class_<predicted class>/<filename>``.
    """
    from spacr.utils import SaliencyMapGenerator, print_progress
    from spacr.io import TarImageDataset

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_jobs = max(1, cpu_count() - 4) if settings['n_jobs'] is None else settings['n_jobs']

    save = settings.get('save', False)
    plot = settings.get('plot', False)

    # Validate all settings up front so that bad input fails before any expensive work.
    tar_path = settings['dataset']
    if not os.path.exists(tar_path):
        print(f"Dataset not found at {settings['dataset']}")
        return

    if save and settings['dtype'] not in ['uint8', 'uint16']:
        print("Invalid dtype in settings. Please use 'uint8' or 'uint16'.")
        return

    if plot and settings['plot_mode'] not in ['mean', 'channel', '3-channel']:
        print("Invalid plot_mode in settings. Please use 'mean', 'channel', or '3-channel'.")
        return

    if settings['saliency_mode'] not in ['mean', 'sum']:
        print("To generate channel average or sum saliency maps set saliency_mode to 'mean' or 'sum', respectively.")

    # Image transforms: tensor conversion, center crop, optional normalization.
    steps = [transforms.ToTensor(),
             transforms.CenterCrop(size=(settings['image_size'], settings['image_size']))]
    if settings['normalize']:
        steps.append(transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)))
    transform = transforms.Compose(steps)

    # map_location lets a model saved on a GPU be loaded on a CPU-only machine.
    model = torch.load(settings['model_path'], map_location=device)
    model.to(device)
    model.eval()

    save_dir = None
    if save:
        dataset_dir = os.path.dirname(tar_path)
        dataset_name = os.path.splitext(os.path.basename(tar_path))[0]
        save_dir = os.path.join(dataset_dir, dataset_name, 'saliency_maps')
        os.makedirs(save_dir, exist_ok=True)
        print(f"Saliency maps will be saved in: {save_dir}")

    dataset = TarImageDataset(tar_path, transform=transform)
    data_loader = DataLoader(dataset, batch_size=settings['batch_size'], shuffle=True, num_workers=n_jobs, pin_memory=True)
    total_files = len(dataset)

    cam_generator = SaliencyMapGenerator(model)
    time_ls = []

    for batch_idx, (inputs, filenames) in enumerate(data_loader):
        start = time.time()
        inputs = inputs.to(device)

        saliency_maps, predicted_classes = cam_generator.compute_saliency_and_predictions(inputs)

        if settings['saliency_mode'] == 'mean':
            saliency_maps = saliency_maps.mean(dim=1, keepdim=True)
        elif settings['saliency_mode'] == 'sum':
            saliency_maps = saliency_maps.sum(dim=1, keepdim=True)

        if plot:
            cam_generator.plot_saliency_grid(inputs, saliency_maps, predicted_classes, mode=settings['plot_mode'])

        if save:
            _save_saliency_batch(saliency_maps, predicted_classes, filenames, save_dir, settings['dtype'])

        time_ls.append(time.time() - start)
        # Count both processed and total in files (the original mixed files with batches).
        files_processed = min((batch_idx + 1) * settings['batch_size'], total_files)
        print_progress(files_processed, total_files, n_jobs=n_jobs, time_ls=time_ls, batch_size=settings['batch_size'], operation_type="Generating Saliency Maps")

    print("Saliency map generation complete.")
743
+
744
+ def visualize_saliency_map_v1(src, model_type='maxvit', model_path='', image_size=224, channels=[1,2,3], normalize=True, class_names=None, save_saliency=False, save_dir='saliency_maps'):
599
745
 
600
746
  from spacr.utils import SaliencyMapGenerator, preprocess_image
601
747