spacr 0.3.2__py3-none-any.whl → 0.3.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- spacr/core.py +105 -1
- spacr/deep_spacr.py +191 -141
- spacr/gui.py +1 -0
- spacr/gui_core.py +13 -4
- spacr/gui_utils.py +29 -1
- spacr/io.py +84 -125
- spacr/measure.py +1 -38
- spacr/ml.py +153 -66
- spacr/plot.py +429 -7
- spacr/settings.py +55 -10
- spacr/submodules.py +7 -6
- spacr/toxo.py +9 -4
- spacr/utils.py +510 -16
- {spacr-0.3.2.dist-info → spacr-0.3.3.dist-info}/METADATA +28 -25
- {spacr-0.3.2.dist-info → spacr-0.3.3.dist-info}/RECORD +19 -19
- {spacr-0.3.2.dist-info → spacr-0.3.3.dist-info}/LICENSE +0 -0
- {spacr-0.3.2.dist-info → spacr-0.3.3.dist-info}/WHEEL +0 -0
- {spacr-0.3.2.dist-info → spacr-0.3.3.dist-info}/entry_points.txt +0 -0
- {spacr-0.3.2.dist-info → spacr-0.3.3.dist-info}/top_level.txt +0 -0
spacr/core.py
CHANGED
@@ -844,4 +844,108 @@ def generate_mediar_masks(src, settings, object_type):
|
|
844
844
|
gc.collect()
|
845
845
|
torch.cuda.empty_cache()
|
846
846
|
|
847
|
-
print("Mask generation completed.")
|
847
|
+
print("Mask generation completed.")
|
848
|
+
|
849
|
+
def generate_screen_graphs(settings):
|
850
|
+
"""
|
851
|
+
Generate screen graphs for different measurements in a given source directory.
|
852
|
+
|
853
|
+
Args:
|
854
|
+
src (str or list): Path(s) to the source directory or directories.
|
855
|
+
tables (list): List of tables to include in the analysis (default: ['cell', 'nucleus', 'pathogen', 'cytoplasm']).
|
856
|
+
graph_type (str): Type of graph to generate (default: 'bar').
|
857
|
+
summary_func (str or function): Function to summarize data (default: 'mean').
|
858
|
+
y_axis_start (float): Starting value for the y-axis (default: 0).
|
859
|
+
error_bar_type (str): Type of error bar to use ('std' or 'sem') (default: 'std').
|
860
|
+
theme (str): Theme for the graph (default: 'pastel').
|
861
|
+
representation (str): Representation for grouping (default: 'well').
|
862
|
+
|
863
|
+
Returns:
|
864
|
+
figs (list): List of generated figures.
|
865
|
+
results (list): List of corresponding result DataFrames.
|
866
|
+
"""
|
867
|
+
|
868
|
+
from .plot import spacrGraph
|
869
|
+
from .io import _read_and_merge_data
|
870
|
+
from.utils import annotate_conditions
|
871
|
+
|
872
|
+
if isinstance(settings['src'], str):
|
873
|
+
srcs = [settings['src']]
|
874
|
+
else:
|
875
|
+
srcs = settings['src']
|
876
|
+
|
877
|
+
all_df = pd.DataFrame()
|
878
|
+
figs = []
|
879
|
+
results = []
|
880
|
+
|
881
|
+
for src in srcs:
|
882
|
+
db_loc = [os.path.join(src, 'measurements', 'measurements.db')]
|
883
|
+
|
884
|
+
# Read and merge data from the database
|
885
|
+
df, _ = _read_and_merge_data(db_loc, settings['tables'], verbose=True, nuclei_limit=settings['nuclei_limit'], pathogen_limit=settings['pathogen_limit'], uninfected=settings['uninfected'])
|
886
|
+
|
887
|
+
# Annotate the data
|
888
|
+
df = annotate_conditions(df, cells=settings['cells'], cell_loc=None, pathogens=settings['controls'], pathogen_loc=settings['controls_loc'], treatments=None, treatment_loc=None)
|
889
|
+
|
890
|
+
# Calculate recruitment metric
|
891
|
+
df['recruitment'] = df['pathogen_channel_1_mean_intensity'] / df['cytoplasm_channel_1_mean_intensity']
|
892
|
+
|
893
|
+
# Combine with the overall DataFrame
|
894
|
+
all_df = pd.concat([all_df, df], ignore_index=True)
|
895
|
+
|
896
|
+
# Generate individual plot
|
897
|
+
plotter = spacrGraph(df,
|
898
|
+
grouping_column='pathogen',
|
899
|
+
data_column='recruitment',
|
900
|
+
graph_type=settings['graph_type'],
|
901
|
+
summary_func=settings['summary_func'],
|
902
|
+
y_axis_start=settings['y_axis_start'],
|
903
|
+
error_bar_type=settings['error_bar_type'],
|
904
|
+
theme=settings['theme'],
|
905
|
+
representation=settings['representation'])
|
906
|
+
|
907
|
+
plotter.create_plot()
|
908
|
+
fig = plotter.get_figure()
|
909
|
+
results_df = plotter.get_results()
|
910
|
+
|
911
|
+
# Append to the lists
|
912
|
+
figs.append(fig)
|
913
|
+
results.append(results_df)
|
914
|
+
|
915
|
+
# Generate plot for the combined data (all_df)
|
916
|
+
plotter = spacrGraph(all_df,
|
917
|
+
grouping_column='pathogen',
|
918
|
+
data_column='recruitment',
|
919
|
+
graph_type=settings['graph_type'],
|
920
|
+
summary_func=settings['summary_func'],
|
921
|
+
y_axis_start=settings['y_axis_start'],
|
922
|
+
error_bar_type=settings['error_bar_type'],
|
923
|
+
theme=settings['theme'],
|
924
|
+
representation=settings['representation'])
|
925
|
+
|
926
|
+
plotter.create_plot()
|
927
|
+
fig = plotter.get_figure()
|
928
|
+
results_df = plotter.get_results()
|
929
|
+
|
930
|
+
figs.append(fig)
|
931
|
+
results.append(results_df)
|
932
|
+
|
933
|
+
# Save figures and results
|
934
|
+
for i, fig in enumerate(figs):
|
935
|
+
res = results[i]
|
936
|
+
|
937
|
+
if i < len(srcs):
|
938
|
+
source = srcs[i]
|
939
|
+
else:
|
940
|
+
source = srcs[0]
|
941
|
+
|
942
|
+
# Ensure the destination folder exists
|
943
|
+
dst = os.path.join(source, 'results')
|
944
|
+
print(f"Savings results to {dst}")
|
945
|
+
os.makedirs(dst, exist_ok=True)
|
946
|
+
|
947
|
+
# Save the figure and results DataFrame
|
948
|
+
fig.savefig(os.path.join(dst, f"figure_controls_{i}_{settings['representation']}_{settings['summary_func']}_{settings['graph_type']}.pdf"), format='pdf')
|
949
|
+
res.to_csv(os.path.join(dst, f"results_controls_{i}_{settings['representation']}_{settings['summary_func']}_{settings['graph_type']}.csv"), index=False)
|
950
|
+
|
951
|
+
return
|
spacr/deep_spacr.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1
|
-
import os, torch, time, gc, datetime
|
1
|
+
import os, torch, time, gc, datetime, cv2
|
2
2
|
torch.backends.cudnn.benchmark = True
|
3
3
|
|
4
4
|
import numpy as np
|
@@ -10,6 +10,8 @@ import torch.nn.functional as F
|
|
10
10
|
import matplotlib.pyplot as plt
|
11
11
|
from PIL import Image
|
12
12
|
from sklearn.metrics import auc, precision_recall_curve
|
13
|
+
from IPython.display import display
|
14
|
+
from multiprocessing import cpu_count
|
13
15
|
|
14
16
|
from torchvision import transforms
|
15
17
|
from torch.utils.data import DataLoader
|
@@ -73,6 +75,12 @@ def apply_model_to_tar(settings={}):
|
|
73
75
|
|
74
76
|
from .io import TarImageDataset
|
75
77
|
from .utils import process_vision_results, print_progress
|
78
|
+
|
79
|
+
if os.path.exists(settings['dataset']):
|
80
|
+
tar_path = settings['dataset']
|
81
|
+
else:
|
82
|
+
tar_path = os.path.join(settings['src'], 'datasets', settings['dataset'])
|
83
|
+
model_path = settings['model_path']
|
76
84
|
|
77
85
|
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
78
86
|
if settings['normalize']:
|
@@ -86,18 +94,18 @@ def apply_model_to_tar(settings={}):
|
|
86
94
|
transforms.CenterCrop(size=(settings['image_size'], settings['image_size']))])
|
87
95
|
|
88
96
|
if settings['verbose']:
|
89
|
-
print(f"Loading model from {
|
90
|
-
print(f"Loading dataset from {
|
97
|
+
print(f"Loading model from {model_path}")
|
98
|
+
print(f"Loading dataset from {tar_path}")
|
91
99
|
|
92
100
|
model = torch.load(settings['model_path'])
|
93
101
|
|
94
|
-
dataset = TarImageDataset(
|
102
|
+
dataset = TarImageDataset(tar_path, transform=transform)
|
95
103
|
data_loader = DataLoader(dataset, batch_size=settings['batch_size'], shuffle=True, num_workers=settings['n_jobs'], pin_memory=True)
|
96
104
|
|
97
|
-
model_name = os.path.splitext(os.path.basename(
|
98
|
-
dataset_name = os.path.splitext(os.path.basename(settings['
|
105
|
+
model_name = os.path.splitext(os.path.basename(model_path))[0]
|
106
|
+
dataset_name = os.path.splitext(os.path.basename(settings['dataset']))[0]
|
99
107
|
date_name = datetime.date.today().strftime('%y%m%d')
|
100
|
-
dst = os.path.dirname(
|
108
|
+
dst = os.path.dirname(tar_path)
|
101
109
|
result_loc = f'{dst}/{date_name}_{dataset_name}_{model_name}_result.csv'
|
102
110
|
|
103
111
|
model.eval()
|
@@ -240,7 +248,7 @@ def evaluate_model_performance(model, loader, epoch, loss_type):
|
|
240
248
|
|
241
249
|
loss /= len(loader)
|
242
250
|
data_dict = classification_metrics(all_labels, prediction_pos_probs)
|
243
|
-
data_dict['loss'] = loss
|
251
|
+
data_dict['loss'] = loss.item()
|
244
252
|
data_dict['epoch'] = epoch
|
245
253
|
data_dict['Accuracy'] = acc
|
246
254
|
|
@@ -323,8 +331,8 @@ def test_model_performance(loaders, model, loader_name_list, epoch, loss_type):
|
|
323
331
|
|
324
332
|
def train_test_model(settings):
|
325
333
|
|
326
|
-
from .io import
|
327
|
-
from .utils import pick_best_model
|
334
|
+
from .io import _copy_missclassified
|
335
|
+
from .utils import pick_best_model, save_settings
|
328
336
|
from .io import generate_loaders
|
329
337
|
from .settings import get_train_test_model_settings
|
330
338
|
|
@@ -346,7 +354,12 @@ def train_test_model(settings):
|
|
346
354
|
model = torch.load(settings['custom_model_path'])
|
347
355
|
|
348
356
|
if settings['train']:
|
349
|
-
|
357
|
+
if settings['train'] and settings['test']:
|
358
|
+
save_settings(settings, name=f"train_test_{settings['model_type']}_{settings['epochs']}", show=True)
|
359
|
+
elif settings['train'] is True:
|
360
|
+
save_settings(settings, name=f"train_{settings['model_type']}_{settings['epochs']}", show=True)
|
361
|
+
elif settings['test'] is True:
|
362
|
+
save_settings(settings, name=f"test_{settings['model_type']}_{settings['epochs']}", show=True)
|
350
363
|
|
351
364
|
if settings['train']:
|
352
365
|
train, val, train_fig = generate_loaders(src,
|
@@ -574,19 +587,21 @@ def train_model(dst, model_type, train_loaders, epochs=100, learning_rate=0.0001
|
|
574
587
|
if schedule == 'step_lr':
|
575
588
|
scheduler.step()
|
576
589
|
|
577
|
-
if
|
578
|
-
|
579
|
-
|
580
|
-
|
581
|
-
|
582
|
-
|
583
|
-
|
584
|
-
|
585
|
-
|
586
|
-
|
587
|
-
|
588
|
-
|
589
|
-
|
590
|
+
if accumulated_train_dicts and accumulated_val_dicts:
|
591
|
+
train_df = pd.DataFrame(accumulated_train_dicts)
|
592
|
+
validation_df = pd.DataFrame(accumulated_val_dicts)
|
593
|
+
_save_progress(dst, train_df, validation_df)
|
594
|
+
accumulated_train_dicts, accumulated_val_dicts = [], []
|
595
|
+
|
596
|
+
elif accumulated_train_dicts:
|
597
|
+
train_df = pd.DataFrame(accumulated_train_dicts)
|
598
|
+
_save_progress(dst, train_df, None)
|
599
|
+
accumulated_train_dicts = []
|
600
|
+
elif accumulated_test_dicts:
|
601
|
+
test_df = pd.DataFrame(accumulated_test_dicts)
|
602
|
+
_save_progress(dst, test_df, None)
|
603
|
+
accumulated_test_dicts = []
|
604
|
+
|
590
605
|
batch_size = len(train_loaders)
|
591
606
|
duration = time.time() - start_time
|
592
607
|
time_ls.append(duration)
|
@@ -595,133 +610,168 @@ def train_model(dst, model_type, train_loaders, epochs=100, learning_rate=0.0001
|
|
595
610
|
|
596
611
|
return model, model_path
|
597
612
|
|
598
|
-
def
|
599
|
-
|
600
|
-
from
|
601
|
-
|
613
|
+
def generate_activation_map(settings):
|
614
|
+
|
615
|
+
from .utils import SaliencyMapGenerator, GradCAMGenerator, SelectChannels, activation_maps_to_database, activation_correlations_to_database
|
616
|
+
from .utils import print_progress, save_settings, calculate_activation_correlations
|
617
|
+
from .io import TarImageDataset
|
618
|
+
from .settings import get_default_generate_activation_map_settings
|
619
|
+
|
620
|
+
torch.cuda.empty_cache()
|
621
|
+
gc.collect()
|
622
|
+
|
623
|
+
plt.clf()
|
602
624
|
use_cuda = torch.cuda.is_available()
|
603
625
|
device = torch.device("cuda" if use_cuda else "cpu")
|
626
|
+
|
627
|
+
source_folder = os.path.dirname(os.path.dirname(settings['dataset']))
|
628
|
+
settings['src'] = source_folder
|
629
|
+
settings = get_default_generate_activation_map_settings(settings)
|
630
|
+
save_settings(settings, name=f"{settings['cam_type']}_settings", show=False)
|
631
|
+
|
632
|
+
if settings['model_type'] == 'maxvit' and settings['target_layer'] == None:
|
633
|
+
settings['target_layer'] = 'base_model.blocks.3.layers.1.layers.MBconv.layers.conv_b'
|
634
|
+
if settings['cam_type'] in ['saliency_image', 'saliency_channel']:
|
635
|
+
settings['target_layer'] = None
|
636
|
+
|
637
|
+
# Set number of jobs for loading
|
638
|
+
n_jobs = settings['n_jobs']
|
639
|
+
if n_jobs is None:
|
640
|
+
n_jobs = max(1, cpu_count() - 4)
|
641
|
+
|
642
|
+
# Set transforms for images
|
643
|
+
transform = transforms.Compose([
|
644
|
+
transforms.ToTensor(),
|
645
|
+
transforms.CenterCrop(size=(settings['image_size'], settings['image_size'])),
|
646
|
+
transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)) if settings['normalize_input'] else None,
|
647
|
+
SelectChannels(settings['channels'])
|
648
|
+
])
|
649
|
+
|
650
|
+
# Handle dataset path
|
651
|
+
if not os.path.exists(settings['dataset']):
|
652
|
+
print(f"Dataset not found at {settings['dataset']}")
|
653
|
+
return
|
604
654
|
|
605
|
-
# Load the
|
606
|
-
model = torch.load(model_path)
|
655
|
+
# Load the model
|
656
|
+
model = torch.load(settings['model_path'])
|
607
657
|
model.to(device)
|
658
|
+
model.eval()
|
608
659
|
|
609
|
-
# Create directory for saving
|
610
|
-
|
611
|
-
|
612
|
-
|
613
|
-
|
614
|
-
images = []
|
615
|
-
input_tensors = []
|
616
|
-
filenames = []
|
617
|
-
for file in os.listdir(src):
|
618
|
-
if not file.endswith('.png'):
|
619
|
-
continue
|
620
|
-
image_path = os.path.join(src, file)
|
621
|
-
image, input_tensor = preprocess_image(image_path, normalize=normalize, image_size=image_size, channels=channels)
|
622
|
-
images.append(image)
|
623
|
-
input_tensors.append(input_tensor)
|
624
|
-
filenames.append(file)
|
625
|
-
|
626
|
-
input_tensors = torch.cat(input_tensors).to(device)
|
627
|
-
class_labels = torch.zeros(input_tensors.size(0), dtype=torch.long).to(device) # Replace with actual class labels if available
|
628
|
-
|
629
|
-
# Generate saliency maps
|
630
|
-
cam_generator = SaliencyMapGenerator(model)
|
631
|
-
saliency_maps = cam_generator.compute_saliency_maps(input_tensors, class_labels)
|
632
|
-
|
633
|
-
# Convert saliency maps to numpy arrays
|
634
|
-
saliency_maps = saliency_maps.cpu().numpy()
|
635
|
-
|
636
|
-
N = len(images)
|
637
|
-
|
638
|
-
dst = os.path.join(src, 'saliency_maps')
|
639
|
-
|
640
|
-
for i in range(N):
|
641
|
-
fig, axes = plt.subplots(1, 3, figsize=(20, 5))
|
642
|
-
|
643
|
-
# Original image
|
644
|
-
axes[0].imshow(images[i])
|
645
|
-
axes[0].axis('off')
|
646
|
-
if class_names:
|
647
|
-
axes[0].set_title(f"Class: {class_names[class_labels[i].item()]}")
|
648
|
-
|
649
|
-
# Saliency Map
|
650
|
-
axes[1].imshow(saliency_maps[i, 0], cmap='hot')
|
651
|
-
axes[1].axis('off')
|
652
|
-
axes[1].set_title("Saliency Map")
|
653
|
-
|
654
|
-
# Overlay
|
655
|
-
overlay = np.array(images[i])
|
656
|
-
overlay = overlay / overlay.max()
|
657
|
-
saliency_map_rgb = np.stack([saliency_maps[i, 0]] * 3, axis=-1) # Convert saliency map to RGB
|
658
|
-
overlay = (overlay * 0.5 + saliency_map_rgb * 0.5).clip(0, 1)
|
659
|
-
axes[2].imshow(overlay)
|
660
|
-
axes[2].axis('off')
|
661
|
-
axes[2].set_title("Overlay")
|
662
|
-
|
663
|
-
plt.tight_layout()
|
664
|
-
plt.show()
|
665
|
-
|
666
|
-
# Save the saliency map if required
|
667
|
-
if save_saliency:
|
668
|
-
os.makedirs(dst, exist_ok=True)
|
669
|
-
saliency_image = Image.fromarray((saliency_maps[i, 0] * 255).astype(np.uint8))
|
670
|
-
saliency_image.save(os.path.join(dst, f'saliency_{filenames[i]}'))
|
671
|
-
|
672
|
-
def visualize_grad_cam(src, model_path, target_layers=None, image_size=224, channels=[1, 2, 3], normalize=True, class_names=None, save_cam=False, save_dir='grad_cam'):
|
673
|
-
|
674
|
-
from spacr.utils import GradCAM, preprocess_image, show_cam_on_image, recommend_target_layers
|
675
|
-
|
676
|
-
use_cuda = torch.cuda.is_available()
|
677
|
-
device = torch.device("cuda" if use_cuda else "cpu")
|
678
|
-
|
679
|
-
model = torch.load(model_path)
|
680
|
-
model.to(device)
|
681
|
-
|
682
|
-
# If no target layers provided, recommend a target layer
|
683
|
-
if target_layers is None:
|
684
|
-
target_layers, all_layers = recommend_target_layers(model)
|
685
|
-
print(f"No target layer provided. Using recommended layer: {target_layers[0]}")
|
686
|
-
print("All possible target layers:")
|
687
|
-
for layer in all_layers:
|
688
|
-
print(layer)
|
660
|
+
# Create directory for saving activation maps if it does not exist
|
661
|
+
dataset_dir = os.path.dirname(settings['dataset'])
|
662
|
+
dataset_name = os.path.splitext(os.path.basename(settings['dataset']))[0]
|
663
|
+
save_dir = os.path.join(dataset_dir, dataset_name, settings['cam_type'])
|
664
|
+
batch_grid_fldr = os.path.join(save_dir, 'batch_grids')
|
689
665
|
|
690
|
-
|
691
|
-
|
692
|
-
|
693
|
-
|
694
|
-
|
695
|
-
|
696
|
-
|
697
|
-
|
698
|
-
|
699
|
-
|
700
|
-
|
701
|
-
|
702
|
-
|
703
|
-
|
704
|
-
|
705
|
-
|
706
|
-
|
707
|
-
|
708
|
-
|
709
|
-
|
710
|
-
|
711
|
-
|
712
|
-
|
713
|
-
|
714
|
-
|
715
|
-
|
716
|
-
|
666
|
+
if settings['save']:
|
667
|
+
os.makedirs(save_dir, exist_ok=True)
|
668
|
+
print(f"Activation maps will be saved in: {save_dir}")
|
669
|
+
|
670
|
+
if settings['plot']:
|
671
|
+
os.makedirs(batch_grid_fldr, exist_ok=True)
|
672
|
+
print(f"Batch grid maps will be saved in: {batch_grid_fldr}")
|
673
|
+
|
674
|
+
# Load dataset
|
675
|
+
dataset = TarImageDataset(settings['dataset'], transform=transform)
|
676
|
+
data_loader = DataLoader(dataset, batch_size=settings['batch_size'], shuffle=settings['shuffle'], num_workers=n_jobs, pin_memory=True)
|
677
|
+
|
678
|
+
# Initialize generator based on cam_type
|
679
|
+
if settings['cam_type'] in ['gradcam', 'gradcam_pp']:
|
680
|
+
cam_generator = GradCAMGenerator(model, target_layer=settings['target_layer'], cam_type=settings['cam_type'])
|
681
|
+
elif settings['cam_type'] in ['saliency_image', 'saliency_channel']:
|
682
|
+
cam_generator = SaliencyMapGenerator(model)
|
683
|
+
|
684
|
+
time_ls = []
|
685
|
+
for batch_idx, (inputs, filenames) in enumerate(data_loader):
|
686
|
+
start = time.time()
|
687
|
+
img_paths = []
|
688
|
+
inputs = inputs.to(device)
|
689
|
+
|
690
|
+
# Compute activation maps and predictions
|
691
|
+
if settings['cam_type'] in ['gradcam', 'gradcam_pp']:
|
692
|
+
activation_maps, predicted_classes = cam_generator.compute_gradcam_and_predictions(inputs)
|
693
|
+
elif settings['cam_type'] in ['saliency_image', 'saliency_channel']:
|
694
|
+
activation_maps, predicted_classes = cam_generator.compute_saliency_and_predictions(inputs)
|
695
|
+
|
696
|
+
# Move activation maps to CPU
|
697
|
+
activation_maps = activation_maps.cpu()
|
698
|
+
|
699
|
+
# Sum saliency maps for 'saliency_image' type
|
700
|
+
if settings['cam_type'] == 'saliency_image':
|
701
|
+
summed_activation_maps = []
|
702
|
+
for i in range(activation_maps.size(0)):
|
703
|
+
activation_map = activation_maps[i]
|
704
|
+
#print(f"1: {activation_map.shape}")
|
705
|
+
activation_map_sum = activation_map.sum(dim=0, keepdim=False)
|
706
|
+
#print(f"2: {activation_map.shape}")
|
707
|
+
activation_map_sum = np.squeeze(activation_map_sum, axis=0)
|
708
|
+
#print(f"3: {activation_map_sum.shape}")
|
709
|
+
summed_activation_maps.append(activation_map_sum)
|
710
|
+
activation_maps = torch.stack(summed_activation_maps)
|
711
|
+
|
712
|
+
# For plotting
|
713
|
+
if settings['plot']:
|
714
|
+
fig = cam_generator.plot_activation_grid(inputs, activation_maps, predicted_classes, overlay=settings['overlay'], normalize=settings['normalize'])
|
715
|
+
pdf_save_path = os.path.join(batch_grid_fldr,f"batch_{batch_idx}_grid.pdf")
|
716
|
+
fig.savefig(pdf_save_path, format='pdf')
|
717
|
+
print(f"Saved batch grid to {pdf_save_path}")
|
718
|
+
#plt.show()
|
719
|
+
display(fig)
|
720
|
+
|
721
|
+
for i in range(inputs.size(0)):
|
722
|
+
activation_map = activation_maps[i].detach().numpy()
|
723
|
+
|
724
|
+
if settings['cam_type'] in ['saliency_image', 'gradcam', 'gradcam_pp']:
|
725
|
+
#activation_map = activation_map.sum(axis=0)
|
726
|
+
activation_map = (activation_map - activation_map.min()) / (activation_map.max() - activation_map.min())
|
727
|
+
activation_map = (activation_map * 255).astype(np.uint8)
|
728
|
+
activation_image = Image.fromarray(activation_map, mode='L')
|
729
|
+
|
730
|
+
elif settings['cam_type'] == 'saliency_channel':
|
731
|
+
# Handle each channel separately and save as RGB
|
732
|
+
rgb_activation_map = np.zeros((activation_map.shape[1], activation_map.shape[2], 3), dtype=np.uint8)
|
733
|
+
for c in range(min(activation_map.shape[0], 3)): # Limit to 3 channels for RGB
|
734
|
+
channel_map = activation_map[c]
|
735
|
+
channel_map = (channel_map - channel_map.min()) / (channel_map.max() - channel_map.min())
|
736
|
+
rgb_activation_map[:, :, c] = (channel_map * 255).astype(np.uint8)
|
737
|
+
activation_image = Image.fromarray(rgb_activation_map, mode='RGB')
|
738
|
+
|
739
|
+
# Save activation maps
|
740
|
+
class_pred = predicted_classes[i].item()
|
741
|
+
parts = filenames[i].split('_')
|
742
|
+
plate = parts[0]
|
743
|
+
well = parts[1]
|
744
|
+
save_class_dir = os.path.join(save_dir, f'class_{class_pred}', str(plate), str(well))
|
745
|
+
os.makedirs(save_class_dir, exist_ok=True)
|
746
|
+
save_path = os.path.join(save_class_dir, f'{filenames[i]}')
|
747
|
+
if settings['save']:
|
748
|
+
activation_image.save(save_path)
|
749
|
+
img_paths.append(save_path)
|
750
|
+
|
751
|
+
if settings['save']:
|
752
|
+
activation_maps_to_database(img_paths, source_folder, settings)
|
753
|
+
|
754
|
+
if settings['correlation']:
|
755
|
+
df = calculate_activation_correlations(inputs, activation_maps, filenames, manders_thresholds=settings['manders_thresholds'])
|
756
|
+
if settings['plot']:
|
757
|
+
display(df)
|
758
|
+
if settings['save']:
|
759
|
+
activation_correlations_to_database(df, img_paths, source_folder, settings)
|
760
|
+
|
761
|
+
stop = time.time()
|
762
|
+
duration = stop - start
|
763
|
+
time_ls.append(duration)
|
764
|
+
files_processed = batch_idx * settings['batch_size']
|
765
|
+
files_to_process = len(data_loader) * settings['batch_size']
|
766
|
+
print_progress(files_processed, files_to_process, n_jobs=n_jobs, time_ls=time_ls, batch_size=settings['batch_size'], operation_type="Generating Activation Maps")
|
717
767
|
|
718
|
-
|
719
|
-
|
720
|
-
|
768
|
+
torch.cuda.empty_cache()
|
769
|
+
gc.collect()
|
770
|
+
print("Activation map generation complete.")
|
721
771
|
|
722
772
|
def visualize_classes(model, dtype, class_names, **kwargs):
|
723
773
|
|
724
|
-
from
|
774
|
+
from .utils import class_visualization
|
725
775
|
|
726
776
|
for target_y in range(2): # Assuming binary classification
|
727
777
|
print(f"Visualizing class: {class_names[target_y]}")
|
spacr/gui.py
CHANGED
@@ -57,6 +57,7 @@ class MainApp(tk.Tk):
|
|
57
57
|
"Map Barcodes": (lambda frame: initiate_root(self, 'map_barcodes'), "Map barcodes to data."),
|
58
58
|
"Regression": (lambda frame: initiate_root(self, 'regression'), "Perform regression analysis."),
|
59
59
|
"Recruitment": (lambda frame: initiate_root(self, 'recruitment'), "Analyze recruitment data."),
|
60
|
+
"Activation": (lambda frame: initiate_root(self, 'activation'), "Generate activation maps of computer vision models and measure channel-activation correlation."),
|
60
61
|
"Plaque": (lambda frame: initiate_root(self, 'analyze_plaques'), "Analyze plaque data.")
|
61
62
|
}
|
62
63
|
|
spacr/gui_core.py
CHANGED
@@ -379,10 +379,13 @@ def set_globals(thread_control_var, q_var, console_output_var, parent_frame_var,
|
|
379
379
|
index_control = index_control_var
|
380
380
|
|
381
381
|
def import_settings(settings_type='mask'):
|
382
|
-
from .gui_utils import convert_settings_dict_for_gui, hide_all_settings
|
383
382
|
global vars_dict, scrollable_frame, button_scrollable_frame
|
384
|
-
from .settings import generate_fields, set_default_settings_preprocess_generate_masks, get_measure_crop_settings, set_default_train_test_model, set_default_generate_barecode_mapping, set_default_umap_image_settings, get_analyze_recruitment_default_settings
|
385
383
|
|
384
|
+
from .gui_utils import convert_settings_dict_for_gui, hide_all_settings
|
385
|
+
from .settings import generate_fields, set_default_settings_preprocess_generate_masks, get_measure_crop_settings, set_default_train_test_model
|
386
|
+
from .settings import set_default_generate_barecode_mapping, set_default_umap_image_settings, get_analyze_recruitment_default_settings
|
387
|
+
from .settings import get_default_generate_activation_map_settings
|
388
|
+
#activation
|
386
389
|
def read_settings_from_csv(csv_file_path):
|
387
390
|
settings = {}
|
388
391
|
with open(csv_file_path, newline='') as csvfile:
|
@@ -422,6 +425,8 @@ def import_settings(settings_type='mask'):
|
|
422
425
|
settings = set_default_umap_image_settings(settings={})
|
423
426
|
elif settings_type == 'recruitment':
|
424
427
|
settings = get_analyze_recruitment_default_settings(settings={})
|
428
|
+
elif settings_type == 'activation':
|
429
|
+
settings = get_default_generate_activation_map_settings(settings={})
|
425
430
|
elif settings_type == 'analyze_plaques':
|
426
431
|
settings = {}
|
427
432
|
elif settings_type == 'convert':
|
@@ -436,8 +441,10 @@ def import_settings(settings_type='mask'):
|
|
436
441
|
|
437
442
|
def setup_settings_panel(vertical_container, settings_type='mask'):
|
438
443
|
global vars_dict, scrollable_frame
|
439
|
-
from .settings import get_identify_masks_finetune_default_settings, set_default_analyze_screen, set_default_settings_preprocess_generate_masks
|
440
|
-
from .settings import
|
444
|
+
from .settings import get_identify_masks_finetune_default_settings, set_default_analyze_screen, set_default_settings_preprocess_generate_masks
|
445
|
+
from .settings import get_measure_crop_settings, deep_spacr_defaults, set_default_generate_barecode_mapping, set_default_umap_image_settings
|
446
|
+
from .settings import get_map_barcodes_default_settings, get_analyze_recruitment_default_settings, get_check_cellpose_models_default_settings
|
447
|
+
from .settings import generate_fields, get_perform_regression_default_settings, get_train_cellpose_default_settings, get_default_generate_activation_map_settings
|
441
448
|
from .gui_utils import convert_settings_dict_for_gui
|
442
449
|
from .gui_elements import set_element_size
|
443
450
|
|
@@ -480,6 +487,8 @@ def setup_settings_panel(vertical_container, settings_type='mask'):
|
|
480
487
|
settings = get_perform_regression_default_settings(settings={})
|
481
488
|
elif settings_type == 'recruitment':
|
482
489
|
settings = get_analyze_recruitment_default_settings(settings={})
|
490
|
+
elif settings_type == 'activation':
|
491
|
+
settings = get_default_generate_activation_map_settings(settings={})
|
483
492
|
elif settings_type == 'analyze_plaques':
|
484
493
|
settings = {'src':'path to images'}
|
485
494
|
elif settings_type == 'convert':
|
spacr/gui_utils.py
CHANGED
@@ -77,7 +77,7 @@ def load_app(root, app_name, app_func):
|
|
77
77
|
else:
|
78
78
|
proceed_with_app(root, app_name, app_func)
|
79
79
|
|
80
|
-
def
|
80
|
+
def parse_list_v1(value):
|
81
81
|
"""
|
82
82
|
Parses a string representation of a list and returns the parsed list.
|
83
83
|
|
@@ -98,6 +98,34 @@ def parse_list(value):
|
|
98
98
|
return parsed_value
|
99
99
|
elif all(isinstance(item, str) for item in parsed_value):
|
100
100
|
return parsed_value
|
101
|
+
elif all(isinstance(item, float) for item in parsed_value):
|
102
|
+
return parsed_value
|
103
|
+
else:
|
104
|
+
raise ValueError("List contains mixed types or unsupported types")
|
105
|
+
else:
|
106
|
+
raise ValueError(f"Expected a list but got {type(parsed_value).__name__}")
|
107
|
+
except (ValueError, SyntaxError) as e:
|
108
|
+
raise ValueError(f"Invalid format for list: {value}. Error: {e}")
|
109
|
+
|
110
|
+
def parse_list(value):
|
111
|
+
"""
|
112
|
+
Parses a string representation of a list and returns the parsed list.
|
113
|
+
|
114
|
+
Args:
|
115
|
+
value (str): The string representation of the list.
|
116
|
+
|
117
|
+
Returns:
|
118
|
+
list: The parsed list, which can contain integers, floats, or strings.
|
119
|
+
|
120
|
+
Raises:
|
121
|
+
ValueError: If the input value is not a valid list format or contains mixed types or unsupported types.
|
122
|
+
"""
|
123
|
+
try:
|
124
|
+
parsed_value = ast.literal_eval(value)
|
125
|
+
if isinstance(parsed_value, list):
|
126
|
+
# Check if all elements are homogeneous (either all int, float, or str)
|
127
|
+
if all(isinstance(item, (int, float, str)) for item in parsed_value):
|
128
|
+
return parsed_value
|
101
129
|
else:
|
102
130
|
raise ValueError("List contains mixed types or unsupported types")
|
103
131
|
else:
|