spacr 0.3.2__py3-none-any.whl → 0.3.22__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- spacr/core.py +105 -1
- spacr/deep_spacr.py +171 -25
- spacr/io.py +80 -121
- spacr/ml.py +153 -66
- spacr/plot.py +429 -7
- spacr/settings.py +6 -5
- spacr/submodules.py +7 -6
- spacr/toxo.py +9 -4
- spacr/utils.py +152 -13
- {spacr-0.3.2.dist-info → spacr-0.3.22.dist-info}/METADATA +28 -25
- {spacr-0.3.2.dist-info → spacr-0.3.22.dist-info}/RECORD +15 -15
- {spacr-0.3.2.dist-info → spacr-0.3.22.dist-info}/LICENSE +0 -0
- {spacr-0.3.2.dist-info → spacr-0.3.22.dist-info}/WHEEL +0 -0
- {spacr-0.3.2.dist-info → spacr-0.3.22.dist-info}/entry_points.txt +0 -0
- {spacr-0.3.2.dist-info → spacr-0.3.22.dist-info}/top_level.txt +0 -0
spacr/core.py
CHANGED
@@ -844,4 +844,108 @@ def generate_mediar_masks(src, settings, object_type):
     gc.collect()
     torch.cuda.empty_cache()
 
-    print("Mask generation completed.")
+    print("Mask generation completed.")
+
+def generate_screen_graphs(settings):
+    """
+    Generate screen graphs for different measurements in a given source directory.
+
+    Args:
+        src (str or list): Path(s) to the source directory or directories.
+        tables (list): List of tables to include in the analysis (default: ['cell', 'nucleus', 'pathogen', 'cytoplasm']).
+        graph_type (str): Type of graph to generate (default: 'bar').
+        summary_func (str or function): Function to summarize data (default: 'mean').
+        y_axis_start (float): Starting value for the y-axis (default: 0).
+        error_bar_type (str): Type of error bar to use ('std' or 'sem') (default: 'std').
+        theme (str): Theme for the graph (default: 'pastel').
+        representation (str): Representation for grouping (default: 'well').
+
+    Returns:
+        figs (list): List of generated figures.
+        results (list): List of corresponding result DataFrames.
+    """
+
+    from .plot import spacrGraph
+    from .io import _read_and_merge_data
+    from .utils import annotate_conditions
+
+    if isinstance(settings['src'], str):
+        srcs = [settings['src']]
+    else:
+        srcs = settings['src']
+
+    all_df = pd.DataFrame()
+    figs = []
+    results = []
+
+    for src in srcs:
+        db_loc = [os.path.join(src, 'measurements', 'measurements.db')]
+
+        # Read and merge data from the database
+        df, _ = _read_and_merge_data(db_loc, settings['tables'], verbose=True, nuclei_limit=settings['nuclei_limit'], pathogen_limit=settings['pathogen_limit'], uninfected=settings['uninfected'])
+
+        # Annotate the data
+        df = annotate_conditions(df, cells=settings['cells'], cell_loc=None, pathogens=settings['controls'], pathogen_loc=settings['controls_loc'], treatments=None, treatment_loc=None)
+
+        # Calculate recruitment metric
+        df['recruitment'] = df['pathogen_channel_1_mean_intensity'] / df['cytoplasm_channel_1_mean_intensity']
+
+        # Combine with the overall DataFrame
+        all_df = pd.concat([all_df, df], ignore_index=True)
+
+        # Generate individual plot
+        plotter = spacrGraph(df,
+                             grouping_column='pathogen',
+                             data_column='recruitment',
+                             graph_type=settings['graph_type'],
+                             summary_func=settings['summary_func'],
+                             y_axis_start=settings['y_axis_start'],
+                             error_bar_type=settings['error_bar_type'],
+                             theme=settings['theme'],
+                             representation=settings['representation'])
+
+        plotter.create_plot()
+        fig = plotter.get_figure()
+        results_df = plotter.get_results()
+
+        # Append to the lists
+        figs.append(fig)
+        results.append(results_df)
+
+    # Generate plot for the combined data (all_df)
+    plotter = spacrGraph(all_df,
+                         grouping_column='pathogen',
+                         data_column='recruitment',
+                         graph_type=settings['graph_type'],
+                         summary_func=settings['summary_func'],
+                         y_axis_start=settings['y_axis_start'],
+                         error_bar_type=settings['error_bar_type'],
+                         theme=settings['theme'],
+                         representation=settings['representation'])
+
+    plotter.create_plot()
+    fig = plotter.get_figure()
+    results_df = plotter.get_results()
+
+    figs.append(fig)
+    results.append(results_df)
+
+    # Save figures and results
+    for i, fig in enumerate(figs):
+        res = results[i]
+
+        if i < len(srcs):
+            source = srcs[i]
+        else:
+            source = srcs[0]
+
+        # Ensure the destination folder exists
+        dst = os.path.join(source, 'results')
+        print(f"Saving results to {dst}")
+        os.makedirs(dst, exist_ok=True)
+
+        # Save the figure and results DataFrame
+        fig.savefig(os.path.join(dst, f"figure_controls_{i}_{settings['representation']}_{settings['summary_func']}_{settings['graph_type']}.pdf"), format='pdf')
+        res.to_csv(os.path.join(dst, f"results_controls_{i}_{settings['representation']}_{settings['summary_func']}_{settings['graph_type']}.csv"), index=False)
+
+    return figs, results
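For orientation, a minimal usage sketch of the new generate_screen_graphs entry point. The key names mirror the settings the function reads above; every value below is an illustrative assumption, not a documented default:

    # Usage sketch for generate_screen_graphs; all values are assumptions.
    from spacr.core import generate_screen_graphs

    settings = {
        'src': '/data/screen1',                 # one path or a list of paths
        'tables': ['cell', 'nucleus', 'pathogen', 'cytoplasm'],
        'nuclei_limit': 10,
        'pathogen_limit': 10,
        'uninfected': True,
        'cells': None,
        'controls': ['control_pathogen', 'test_pathogen'],
        'controls_loc': None,
        'graph_type': 'bar',
        'summary_func': 'mean',
        'y_axis_start': 0,
        'error_bar_type': 'std',
        'theme': 'pastel',
        'representation': 'well',
    }

    figs, results = generate_screen_graphs(settings)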
spacr/deep_spacr.py
CHANGED
@@ -1,4 +1,4 @@
-import os, torch, time, gc, datetime
+import os, torch, time, gc, datetime, cv2
 torch.backends.cudnn.benchmark = True
 
 import numpy as np
@@ -10,6 +10,8 @@ import torch.nn.functional as F
 import matplotlib.pyplot as plt
 from PIL import Image
 from sklearn.metrics import auc, precision_recall_curve
+from IPython.display import display
+from multiprocessing import cpu_count
 
 from torchvision import transforms
 from torch.utils.data import DataLoader
@@ -73,6 +75,12 @@ def apply_model_to_tar(settings={}):
 
     from .io import TarImageDataset
     from .utils import process_vision_results, print_progress
+
+    if os.path.exists(settings['dataset']):
+        tar_path = settings['dataset']
+    else:
+        tar_path = os.path.join(settings['src'], 'datasets', settings['dataset'])
+    model_path = settings['model_path']
 
     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
     if settings['normalize']:
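The added lines implement a small path-resolution idiom: settings['dataset'] is used verbatim when it already points at an existing file, and is otherwise resolved under <src>/datasets/. The same idiom as a standalone sketch (the function name is ours, not part of spacr):

    import os

    def resolve_dataset_path(dataset, src):
        # Use 'dataset' as-is when it is an existing path; otherwise fall back
        # to the conventional <src>/datasets/<name> location.
        if os.path.exists(dataset):
            return dataset
        return os.path.join(src, 'datasets', dataset)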
@@ -86,18 +94,18 @@ def apply_model_to_tar(settings={}):
         transforms.CenterCrop(size=(settings['image_size'], settings['image_size']))])
 
     if settings['verbose']:
-        print(f"Loading model from {settings['model_path']}")
-        print(f"Loading dataset from {settings['dataset']}")
+        print(f"Loading model from {model_path}")
+        print(f"Loading dataset from {tar_path}")
 
     model = torch.load(settings['model_path'])
 
-    dataset = TarImageDataset(settings['dataset'], transform=transform)
+    dataset = TarImageDataset(tar_path, transform=transform)
     data_loader = DataLoader(dataset, batch_size=settings['batch_size'], shuffle=True, num_workers=settings['n_jobs'], pin_memory=True)
 
-    model_name = os.path.splitext(os.path.basename(settings['model_path']))[0]
-    dataset_name = os.path.splitext(os.path.basename(settings['dataset']))[0]
+    model_name = os.path.splitext(os.path.basename(model_path))[0]
+    dataset_name = os.path.splitext(os.path.basename(settings['dataset']))[0]
     date_name = datetime.date.today().strftime('%y%m%d')
-    dst = os.path.dirname(settings['dataset'])
+    dst = os.path.dirname(tar_path)
     result_loc = f'{dst}/{date_name}_{dataset_name}_{model_name}_result.csv'
 
     model.eval()
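Taken together, the two hunks above let apply_model_to_tar accept either a full path to a tar archive or a bare dataset name. A hedged call sketch; the keys follow the settings read above, and all values are illustrative:

    from spacr.deep_spacr import apply_model_to_tar

    settings = {
        'dataset': 'test_images.tar',            # a name resolved under <src>/datasets/, or a full path
        'src': '/data/screen1',
        'model_path': '/models/maxvit_best.pth',
        'normalize': True,
        'image_size': 224,
        'batch_size': 64,
        'n_jobs': 4,
        'verbose': True,
    }

    # Writes <yymmdd>_<dataset>_<model>_result.csv next to the tar file.
    apply_model_to_tar(settings)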
@@ -240,7 +248,7 @@ def evaluate_model_performance(model, loader, epoch, loss_type):
 
     loss /= len(loader)
     data_dict = classification_metrics(all_labels, prediction_pos_probs)
-    data_dict['loss'] = loss
+    data_dict['loss'] = loss.item()
     data_dict['epoch'] = epoch
     data_dict['Accuracy'] = acc
 
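The .item() call matters because the accumulated loss is a zero-dimensional tensor; storing the raw tensor in data_dict drags the autograd graph along and serializes poorly. A generic PyTorch illustration (not spacr-specific):

    import torch

    loss = torch.tensor(0.25, requires_grad=True) * 2  # 0-dim tensor attached to a graph
    print(type(loss))          # <class 'torch.Tensor'>
    print(type(loss.item()))   # <class 'float'>: a detached plain number, safe to log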
@@ -323,8 +331,8 @@ def test_model_performance(loaders, model, loader_name_list, epoch, loss_type):
 
 def train_test_model(settings):
 
-    from .io import
-    from .utils import pick_best_model
+    from .io import _copy_missclassified
+    from .utils import pick_best_model, save_settings
     from .io import generate_loaders
     from .settings import get_train_test_model_settings
 
@@ -346,7 +354,12 @@ def train_test_model(settings):
         model = torch.load(settings['custom_model_path'])
 
     if settings['train']:
-
+        if settings['train'] and settings['test']:
+            save_settings(settings, name=f"train_test_{settings['model_type']}_{settings['epochs']}", show=True)
+        elif settings['train'] is True:
+            save_settings(settings, name=f"train_{settings['model_type']}_{settings['epochs']}", show=True)
+        elif settings['test'] is True:
+            save_settings(settings, name=f"test_{settings['model_type']}_{settings['epochs']}", show=True)
 
     if settings['train']:
         train, val, train_fig = generate_loaders(src,
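The inserted block snapshots the run settings before training starts, with the file name keyed to the active mode. The naming rule, condensed into a hypothetical helper for illustration:

    # Hypothetical helper mirroring the snapshot-name logic above.
    def settings_name(settings):
        if settings['train'] and settings['test']:
            prefix = 'train_test'
        elif settings['train']:
            prefix = 'train'
        else:
            prefix = 'test'
        return f"{prefix}_{settings['model_type']}_{settings['epochs']}"

    print(settings_name({'train': True, 'test': False, 'model_type': 'maxvit', 'epochs': 100}))
    # -> train_maxvit_100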
@@ -574,19 +587,21 @@ def train_model(dst, model_type, train_loaders, epochs=100, learning_rate=0.0001
         if schedule == 'step_lr':
             scheduler.step()
 
-        if
-
-
-
-
-
-
-
-
-
-
-
-
+        if accumulated_train_dicts and accumulated_val_dicts:
+            train_df = pd.DataFrame(accumulated_train_dicts)
+            validation_df = pd.DataFrame(accumulated_val_dicts)
+            _save_progress(dst, train_df, validation_df)
+            accumulated_train_dicts, accumulated_val_dicts = [], []
+
+        elif accumulated_train_dicts:
+            train_df = pd.DataFrame(accumulated_train_dicts)
+            _save_progress(dst, train_df, None)
+            accumulated_train_dicts = []
+        elif accumulated_test_dicts:
+            test_df = pd.DataFrame(accumulated_test_dicts)
+            _save_progress(dst, test_df, None)
+            accumulated_test_dicts = []
+
         batch_size = len(train_loaders)
         duration = time.time() - start_time
         time_ls.append(duration)
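The rewritten block follows an accumulate-then-flush pattern: per-epoch metric dicts are buffered in lists and written out as DataFrames once the relevant buffers are populated, then the buffers are cleared. A self-contained sketch of the same pattern (the writer below stands in for spacr's _save_progress):

    import pandas as pd

    def save_progress(dst, train_df, val_df):
        # Stand-in writer; spacr's _save_progress has its own format.
        train_df.to_csv(f"{dst}/train_progress.csv", index=False)
        if val_df is not None:
            val_df.to_csv(f"{dst}/validation_progress.csv", index=False)

    accumulated_train, accumulated_val = [], []
    for epoch in range(3):
        accumulated_train.append({'epoch': epoch, 'loss': 0.5 / (epoch + 1)})
        accumulated_val.append({'epoch': epoch, 'loss': 0.6 / (epoch + 1)})
        if accumulated_train and accumulated_val:        # flush once both buffers have data
            save_progress('.', pd.DataFrame(accumulated_train), pd.DataFrame(accumulated_val))
            accumulated_train, accumulated_val = [], []  # reset the buffers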
@@ -595,7 +610,138 @@ def train_model(dst, model_type, train_loaders, epochs=100, learning_rate=0.0001
 
     return model, model_path
 
-def visualize_saliency_map(src, model_type='maxvit', model_path='', image_size=224, channels=[1,2,3], normalize=True, class_names=None, save_saliency=False, save_dir='saliency_maps'):
+def visualize_saliency_map(settings):
+    from spacr.utils import SaliencyMapGenerator, print_progress
+    from spacr.io import TarImageDataset  # Assuming you have a dataset class
+    from torchvision.utils import make_grid
+
+    use_cuda = torch.cuda.is_available()
+    device = torch.device("cuda" if use_cuda else "cpu")
+
+    # Set number of jobs for loading
+    if settings['n_jobs'] is None:
+        n_jobs = max(1, cpu_count() - 4)
+    else:
+        n_jobs = settings['n_jobs']
+
+    # Set transforms for images
+    if settings['normalize']:
+        transform = transforms.Compose([
+            transforms.ToTensor(),
+            transforms.CenterCrop(size=(settings['image_size'], settings['image_size'])),
+            transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])
+    else:
+        transform = transforms.Compose([
+            transforms.ToTensor(),
+            transforms.CenterCrop(size=(settings['image_size'], settings['image_size']))])
+
+    # Handle dataset path
+    if os.path.exists(settings['dataset']):
+        tar_path = settings['dataset']
+    else:
+        print(f"Dataset not found at {settings['dataset']}")
+        return
+
+    if settings.get('save', False):
+        if settings['dtype'] not in ['uint8', 'uint16']:
+            print("Invalid dtype in settings. Please use 'uint8' or 'uint16'.")
+            return
+
+    # Load the model
+    model = torch.load(settings['model_path'])
+    model.to(device)
+    model.eval()  # Ensure the model is in evaluation mode
+
+    # Create directory for saving saliency maps if it does not exist
+    if settings.get('save', False):
+        dataset_dir = os.path.dirname(tar_path)
+        dataset_name = os.path.splitext(os.path.basename(tar_path))[0]
+        save_dir = os.path.join(dataset_dir, dataset_name, 'saliency_maps')
+        os.makedirs(save_dir, exist_ok=True)
+        print(f"Saliency maps will be saved in: {save_dir}")
+
+    # Load dataset
+    dataset = TarImageDataset(tar_path, transform=transform)
+    data_loader = DataLoader(dataset, batch_size=settings['batch_size'], shuffle=True, num_workers=n_jobs, pin_memory=True)
+
+    # Initialize SaliencyMapGenerator
+    cam_generator = SaliencyMapGenerator(model)
+    time_ls = []
+
+    for batch_idx, (inputs, filenames) in enumerate(data_loader):
+        start = time.time()
+        inputs = inputs.to(device)
+
+        saliency_maps, predicted_classes = cam_generator.compute_saliency_and_predictions(inputs)
+
+        if settings['saliency_mode'] not in ['mean', 'sum']:
+            print("To generate channel average or sum saliency maps set saliency_mode to 'mean' or 'sum', respectively.")
+
+        if settings['saliency_mode'] == 'mean':
+            saliency_maps = saliency_maps.mean(dim=1, keepdim=True)
+
+        elif settings['saliency_mode'] == 'sum':
+            saliency_maps = saliency_maps.sum(dim=1, keepdim=True)
+
+        # Example usage with the class
+        if settings.get('plot', False):
+            if settings['plot_mode'] not in ['mean', 'channel', '3-channel']:
+                print("Invalid plot_mode in settings. Please use 'mean', 'channel', or '3-channel'.")
+                return
+            else:
+                cam_generator.plot_saliency_grid(inputs, saliency_maps, predicted_classes, mode=settings['plot_mode'])
+
+        if settings.get('save', False):
+            for i in range(inputs.size(0)):
+                saliency_map = saliency_maps[i].detach().cpu().numpy()
+
+                # Check dtype in settings and normalize accordingly
+                if settings['dtype'] == 'uint16':
+                    saliency_map = np.clip(saliency_map, 0, 1) * 65535
+                    saliency_map = saliency_map.astype(np.uint16)
+                    mode = 'I;16'
+                elif settings['dtype'] == 'uint8':
+                    saliency_map = np.clip(saliency_map, 0, 1) * 255
+                    saliency_map = saliency_map.astype(np.uint8)
+                    mode = 'L'  # Grayscale mode for uint8
+
+                # Get the class prediction (0 or 1)
+                class_pred = predicted_classes[i].item()
+
+                save_class_dir = os.path.join(save_dir, f'class_{class_pred}')
+                os.makedirs(save_class_dir, exist_ok=True)
+                save_path = os.path.join(save_class_dir, filenames[i])
+
+                # Handle different cases based on saliency_map dimensions
+                if saliency_map.ndim == 3:  # Multi-channel case (C, H, W)
+                    if saliency_map.shape[0] == 3:  # RGB-like saliency map
+                        saliency_image = Image.fromarray(np.moveaxis(saliency_map, 0, -1), mode="RGB")  # Convert (C, H, W) to (H, W, C)
+                    elif saliency_map.shape[0] == 1:  # Single-channel case (1, H, W)
+                        saliency_map = np.squeeze(saliency_map)  # Remove the extra channel dimension
+                        saliency_image = Image.fromarray(saliency_map, mode=mode)  # Use grayscale mode for single-channel
+                    else:
+                        raise ValueError(f"Unexpected number of channels: {saliency_map.shape[0]}")
+
+                elif saliency_map.ndim == 2:  # Single-channel case (H, W)
+                    saliency_image = Image.fromarray(saliency_map, mode=mode)  # Keep single channel (H, W)
+
+                else:
+                    raise ValueError(f"Unexpected number of dimensions: {saliency_map.ndim}")
+
+                # Save the image
+                saliency_image.save(save_path)
+
+
+        stop = time.time()
+        duration = stop - start
+        time_ls.append(duration)
+        files_processed = batch_idx * settings['batch_size']
+        files_to_process = len(data_loader)
+        print_progress(files_processed, files_to_process, n_jobs=n_jobs, time_ls=time_ls, batch_size=settings['batch_size'], operation_type="Generating Saliency Maps")
+
+    print("Saliency map generation complete.")
+
+def visualize_saliency_map_v1(src, model_type='maxvit', model_path='', image_size=224, channels=[1,2,3], normalize=True, class_names=None, save_saliency=False, save_dir='saliency_maps'):
 
     from spacr.utils import SaliencyMapGenerator, preprocess_image
 
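Finally, a usage sketch for the settings-driven visualize_saliency_map. The keys are the ones the new function reads above; the values are illustrative assumptions:

    from spacr.deep_spacr import visualize_saliency_map

    settings = {
        'dataset': '/data/screen1/datasets/test_images.tar',  # must be an existing path
        'model_path': '/models/maxvit_best.pth',
        'normalize': True,
        'image_size': 224,
        'batch_size': 64,
        'n_jobs': None,               # None -> max(1, cpu_count() - 4)
        'saliency_mode': 'mean',      # or 'sum'
        'plot': True,
        'plot_mode': 'mean',          # 'mean', 'channel', or '3-channel'
        'save': True,
        'dtype': 'uint8',             # saved maps are written as 'uint8' or 'uint16'
    }

    visualize_saliency_map(settings)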