spacr 0.0.2__py3-none-any.whl → 0.0.6__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.
- spacr/__init__.py +2 -2
- spacr/__main__.py +0 -2
- spacr/alpha.py +803 -14
- spacr/annotate_app.py +118 -120
- spacr/chris.py +50 -0
- spacr/core.py +1544 -533
- spacr/deep_spacr.py +696 -0
- spacr/foldseek.py +779 -0
- spacr/get_alfafold_structures.py +72 -0
- spacr/graph_learning.py +297 -253
- spacr/gui.py +145 -0
- spacr/gui_2.py +90 -0
- spacr/gui_classify_app.py +70 -80
- spacr/gui_mask_app.py +114 -91
- spacr/gui_measure_app.py +109 -88
- spacr/gui_utils.py +376 -32
- spacr/io.py +441 -438
- spacr/mask_app.py +116 -9
- spacr/measure.py +169 -69
- spacr/models/cp/toxo_pv_lumen.CP_model +0 -0
- spacr/old_code.py +70 -2
- spacr/plot.py +173 -17
- spacr/sequencing.py +1130 -0
- spacr/sim.py +630 -125
- spacr/timelapse.py +139 -10
- spacr/train.py +188 -21
- spacr/umap.py +0 -689
- spacr/utils.py +1360 -119
- {spacr-0.0.2.dist-info → spacr-0.0.6.dist-info}/METADATA +17 -29
- spacr-0.0.6.dist-info/RECORD +39 -0
- {spacr-0.0.2.dist-info → spacr-0.0.6.dist-info}/WHEEL +1 -1
- spacr-0.0.6.dist-info/entry_points.txt +9 -0
- spacr-0.0.2.dist-info/RECORD +0 -31
- spacr-0.0.2.dist-info/entry_points.txt +0 -7
- {spacr-0.0.2.dist-info → spacr-0.0.6.dist-info}/LICENSE +0 -0
- {spacr-0.0.2.dist-info → spacr-0.0.6.dist-info}/top_level.txt +0 -0
spacr/timelapse.py
CHANGED
@@ -3,7 +3,6 @@ import numpy as np
 import pandas as pd
 from collections import defaultdict
 import matplotlib.pyplot as plt
-from matplotlib.animation import FuncAnimation
 from IPython.display import display
 from IPython.display import Image as ipyimage
 import trackpy as tp
@@ -11,6 +10,7 @@ from btrack import datasets as btrack_datasets
 from skimage.measure import regionprops
 from scipy.signal import find_peaks
 from scipy.optimize import curve_fit
+from scipy.integrate import trapz
 import matplotlib.pyplot as plt

 from .logger import log_function_call
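Context for the new import: scipy.integrate.trapz is a thin wrapper around the trapezoidal rule and has since been removed in SciPy 1.14 in favour of scipy.integrate.trapezoid. A minimal sketch of the AUC computation this module now performs (illustrative arrays, not spacr data):

import numpy as np
from scipy.integrate import trapezoid  # drop-in replacement for trapz on SciPy >= 1.14

time = np.array([0.0, 1.0, 2.0, 3.0, 4.0])    # frame timestamps
delta = np.array([0.0, 0.2, -0.1, 0.5, 0.1])  # per-frame intensity change

auc = trapezoid(y=delta, x=time)                             # signed area, as in 'AUC'
auc_positive = trapezoid(y=np.clip(delta, 0, None), x=time)  # as in 'AUC_positive'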
@@ -590,6 +590,80 @@ def infected_vs_noninfected(result_df, measurement):
     plt.tight_layout()
     plt.show()

+def save_figure(fig, src, figure_number):
+    source = os.path.dirname(src)
+    results_fldr = os.path.join(source,'results')
+    os.makedirs(results_fldr, exist_ok=True)
+    fig_loc = os.path.join(results_fldr, f'figure_{figure_number}.pdf')
+    fig.savefig(fig_loc)
+    print(f'Saved figure:{fig_loc}')
+
+def save_results_dataframe(df, src, results_name):
+    source = os.path.dirname(src)
+    results_fldr = os.path.join(source,'results')
+    os.makedirs(results_fldr, exist_ok=True)
+    csv_loc = os.path.join(results_fldr, f'{results_name}.csv')
+    df.to_csv(csv_loc, index=True)
+    print(f'Saved results:{csv_loc}')
+
+def summarize_per_well(peak_details_df):
+    # Step 1: Split the 'ID' column
+    split_columns = peak_details_df['ID'].str.split('_', expand=True)
+    peak_details_df[['plate', 'row', 'column', 'field', 'object_number']] = split_columns
+
+    # Step 2: Create 'well_ID' by combining 'row' and 'column'
+    peak_details_df['well_ID'] = peak_details_df['row'] + '_' + peak_details_df['column']
+
+    # Filter entries where 'amplitude' is not null
+    filtered_df = peak_details_df[peak_details_df['amplitude'].notna()]
+
+    # Preparation for Step 3: Identify numeric columns for averaging from the filtered dataframe
+    numeric_cols = filtered_df.select_dtypes(include=['number']).columns
+
+    # Step 3: Calculate summary statistics
+    summary_df = filtered_df.groupby('well_ID').agg(
+        peaks_per_well=('ID', 'size'),
+        unique_IDs_with_amplitude=('ID', 'nunique'),  # Count unique IDs per well with non-null amplitude
+        **{col: (col, 'mean') for col in numeric_cols}  # exclude 'amplitude' from averaging if it's numeric
+    ).reset_index()
+
+    # Step 4: Count unique cells per well across all entries
+    summary_df_2 = peak_details_df.groupby('well_ID').agg(
+        cells_per_well=('object_number', 'nunique'),
+    ).reset_index()
+
+    summary_df['cells_per_well'] = summary_df_2['cells_per_well']
+    summary_df['peaks_per_cell'] = summary_df['peaks_per_well'] / summary_df['cells_per_well']
+
+    return summary_df
+
+def summarize_per_well_inf_non_inf(peak_details_df):
+    # Step 1: Split the 'ID' column
+    split_columns = peak_details_df['ID'].str.split('_', expand=True)
+    peak_details_df[['plate', 'row', 'column', 'field', 'object_number']] = split_columns
+
+    # Step 2: Create 'well_ID' by combining 'row' and 'column'
+    peak_details_df['well_ID'] = peak_details_df['row'] + '_' + peak_details_df['column']
+
+    # Assume a cell is infected if 'infected' > 0 and
+    # add an 'infected_status' column to classify cells
+    peak_details_df['infected_status'] = peak_details_df['infected'].apply(lambda x: 'infected' if x > 0 else 'non_infected')
+
+    # Preparation for Step 3: Identify numeric columns for averaging
+    numeric_cols = peak_details_df.select_dtypes(include=['number']).columns
+
+    # Step 3: Calculate summary statistics
+    summary_df = peak_details_df.groupby(['well_ID', 'infected_status']).agg(
+        cells_per_well=('object_number', 'nunique'),
+        peaks_per_well=('ID', 'size'),
+        **{col: (col, 'mean') for col in numeric_cols}
+    ).reset_index()
+
+    # Calculate peaks per cell
+    summary_df['peaks_per_cell'] = summary_df['peaks_per_well'] / summary_df['cells_per_well']
+
+    return summary_df
+
 def analyze_calcium_oscillations(db_loc, measurement='cell_channel_1_mean_intensity', size_filter='cell_area', fluctuation_threshold=0.25, num_lines=None, peak_height=0.01, pathogen=None, cytoplasm=None, remove_transient=True, verbose=False, transience_threshold=0.9):
     # Load data
     conn = sqlite3.connect(db_loc)
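The two summarize_per_well helpers pivot the per-peak table into per-well statistics. A toy run under the plate_row_column_field_object ID scheme used elsewhere in this module (values illustrative):

import numpy as np
import pandas as pd
from spacr.timelapse import summarize_per_well

peaks = pd.DataFrame({
    'ID': ['p1_r1_c1_f1_o1', 'p1_r1_c1_f1_o1', 'p1_r1_c1_f2_o3'],
    'amplitude': [0.4, 0.6, np.nan],  # NaN: a tracked cell that produced no peak
})
print(summarize_per_well(peaks))  # one row per well: peak counts, means, peaks_per_cell

Note that cells_per_well is joined back positionally, which assumes every well in the peak table also has at least one non-null amplitude.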
@@ -626,7 +700,7 @@ def analyze_calcium_oscillations(db_loc, measurement='cell_channel_1_mean_intens
     cell_df['plate_row_column_field_object'] = cell_df['plate'].astype(str) + '_' + cell_df['row'].astype(str) + '_' + cell_df['column'].astype(str) + '_' + cell_df['field'].astype(str) + '_' + cell_df['object_label'].astype(str)

     df = cell_df.copy()
-
+
     # Fit exponential decay model to all scaled fluorescence data
     try:
         params, _ = curve_fit(exponential_decay, df['time'], df[measurement], p0=[max(df[measurement]), 0.01, min(df[measurement])], maxfev=10000)
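exponential_decay is defined elsewhere in spacr and is not part of this diff; given the initial guess p0=[max, 0.01, min], it presumably has the conventional form a*exp(-b*t) + c. A self-contained sketch of the bleaching fit under that assumption:

import numpy as np
from scipy.optimize import curve_fit

def exponential_decay(t, a, b, c):  # assumed form, matching p0=[max, 0.01, min]
    return a * np.exp(-b * t) + c

t = np.linspace(0, 100, 101)
y = 5.0 * np.exp(-0.03 * t) + 1.0 + np.random.normal(0, 0.05, t.size)  # synthetic trace
params, _ = curve_fit(exponential_decay, t, y, p0=[y.max(), 0.01, y.min()], maxfev=10000)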
@@ -653,11 +727,14 @@ def analyze_calcium_oscillations(db_loc, measurement='cell_channel_1_mean_intens
         if verbose:
             print(f'Group length: {len(group)} Timelapse length: {total_timepoints}, threshold:{threshold}')

-        if
+        if len(group) <= threshold:
             transience_removed += 1
+            if verbose:
+                print(f'removed group {unique_id} due to transience')
             continue

         size_diff = group[size_filter].std() / group[size_filter].mean()
+
         if size_diff <= fluctuation_threshold:
             group['delta_' + measurement] = group['corrected_' + measurement].diff().fillna(0)
             corrected_dfs.append(group)
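The size filter above is a coefficient of variation (std/mean) on the size_filter column, dropping tracks whose size fluctuates by more than fluctuation_threshold (default 0.25). A toy illustration:

import pandas as pd

group = pd.DataFrame({'cell_area': [100.0, 104.0, 98.0, 102.0, 101.0]})
size_diff = group['cell_area'].std() / group['cell_area'].mean()  # ~0.022
keep = size_diff <= 0.25  # True: a stable track, kept for peak detection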
@@ -665,12 +742,50 @@ def analyze_calcium_oscillations(db_loc, measurement='cell_channel_1_mean_intens
             # Detect peaks
             peaks, properties = find_peaks(group['delta_' + measurement], height=peak_height)

+            # Set values < 0 to 0
+            group_filtered = group.copy()
+            group_filtered['delta_' + measurement] = group['delta_' + measurement].clip(lower=0)
+            above_zero_auc = trapz(y=group_filtered['delta_' + measurement], x=group_filtered['time'])
+            auc = trapz(y=group['delta_' + measurement], x=group_filtered['time'])
+            is_infected = (group['parasite_count'] > 0).any()
+
+            if is_infected:
+                is_infected = 1
+            else:
+                is_infected = 0
+
+            if len(peaks) == 0:
+                peak_details_list.append({
+                    'ID': unique_id,
+                    'plate': group['plate'].iloc[0],
+                    'row': group['row'].iloc[0],
+                    'column': group['column'].iloc[0],
+                    'field': group['field'].iloc[0],
+                    'object_number': group['object_number'].iloc[0],
+                    'time': np.nan,  # The time of the peak
+                    'amplitude': np.nan,
+                    'delta': np.nan,
+                    'AUC': auc,
+                    'AUC_positive': above_zero_auc,
+                    'AUC_peak': np.nan,
+                    'infected': is_infected
+                })
+
             # Inside the for loop where peaks are detected
             for i, peak in enumerate(peaks):
-
-
-
+
+                amplitude = properties['peak_heights'][i]
+                peak_time = group['time'].iloc[peak]
                 pathogen_count_at_peak = group['parasite_count'].iloc[peak]
+
+                start_idx = max(peak - 1, 0)
+                end_idx = min(peak + 1, len(group) - 1)
+
+                # Using indices to slice for AUC calculation
+                peak_segment_y = group['delta_' + measurement].iloc[start_idx:end_idx + 1]
+                peak_segment_x = group['time'].iloc[start_idx:end_idx + 1]
+                peak_auc = trapz(y=peak_segment_y, x=peak_segment_x)
+
                 peak_details_list.append({
                     'ID': unique_id,
                     'plate': group['plate'].iloc[0],
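A minimal sketch of the per-peak AUC introduced above: each detected peak is integrated over a window of one frame on either side (illustrative arrays; trapezoid replaces trapz on SciPy >= 1.14):

import numpy as np
from scipy.signal import find_peaks
from scipy.integrate import trapezoid

time = np.arange(10, dtype=float)
delta = np.array([0.0, 0.1, 0.9, 0.2, 0.0, 0.0, 0.7, 0.1, 0.0, 0.0])

peaks, properties = find_peaks(delta, height=0.01)  # same call as the diff
for i, peak in enumerate(peaks):
    start_idx = max(peak - 1, 0)
    end_idx = min(peak + 1, len(delta) - 1)
    peak_auc = trapezoid(y=delta[start_idx:end_idx + 1], x=time[start_idx:end_idx + 1])
    print(peak, properties['peak_heights'][i], peak_auc)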
@@ -681,6 +796,9 @@ def analyze_calcium_oscillations(db_loc, measurement='cell_channel_1_mean_intens
                     'time': peak_time,  # The time of the peak
                     'amplitude': amplitude,
                     'delta': group['delta_' + measurement].iloc[peak],
+                    'AUC': auc,
+                    'AUC_positive': above_zero_auc,
+                    'AUC_peak': peak_auc,
                     'infected': pathogen_count_at_peak
                 })
     else:
@@ -697,7 +815,14 @@ def analyze_calcium_oscillations(db_loc, measurement='cell_channel_1_mean_intens
         return

     peak_details_df = pd.DataFrame(peak_details_list)
-
+    summary_df = summarize_per_well(peak_details_df)
+    summary_df_inf_non_inf = summarize_per_well_inf_non_inf(peak_details_df)
+
+    save_results_dataframe(df=peak_details_df, src=db_loc, results_name='peak_details')
+    save_results_dataframe(df=result_df, src=db_loc, results_name='results')
+    save_results_dataframe(df=summary_df, src=db_loc, results_name='well_results')
+    save_results_dataframe(df=summary_df_inf_non_inf, src=db_loc, results_name='well_results_inf_non_inf')
+
     # Plotting
     fig, ax = plt.subplots(figsize=(10, 8))
     sampled_groups = result_df['plate_row_column_field_object'].unique()
@@ -714,12 +839,16 @@ def analyze_calcium_oscillations(db_loc, measurement='cell_channel_1_mean_intens
     ax.set_xlabel('Time')
     ax.set_ylabel('Normalized Delta ' + measurement)
     plt.tight_layout()
+
     plt.show()
+
+    save_figure(fig, src=db_loc, figure_number=1)

     if pathogen:
         infected_vs_noninfected(result_df, measurement)
+        save_figure(fig, src=db_loc, figure_number=2)

-        #
+        # Identify cells with and without pathogens
         infected_cells = result_df[result_df.groupby('plate_row_column_field_object')['parasite_count'].transform('max') > 0]['plate_row_column_field_object'].unique()
         noninfected_cells = result_df[result_df.groupby('plate_row_column_field_object')['parasite_count'].transform('max') == 0]['plate_row_column_field_object'].unique()

@@ -733,5 +862,5 @@ def analyze_calcium_oscillations(db_loc, measurement='cell_channel_1_mean_intens

     print(f'Average number of peaks per infected cell: {avg_inf_peaks_per_cell:.2f}')
     print(f'Average number of peaks per non-infected cell: {avg_non_inf_peaks_per_cell:.2f}')
-
-    return result_df, peak_details_df
+    print(f'done')
+    return result_df, peak_details_df, fig
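Taken together, the timelapse.py changes alter the public API: analyze_calcium_oscillations now writes its tables and figure to a results folder next to the database and returns the figure as a third value. A usage sketch (the database path is hypothetical):

from spacr.timelapse import analyze_calcium_oscillations

result_df, peak_details_df, fig = analyze_calcium_oscillations(
    db_loc='/path/to/measurements.db',  # hypothetical path
    measurement='cell_channel_1_mean_intensity',
    pathogen='pathogen',  # any truthy value enables the infected/non-infected analysis
    verbose=True,
)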
spacr/train.py
CHANGED
@@ -6,6 +6,7 @@ from torch.autograd import grad
 from torch.optim.lr_scheduler import StepLR
 import torch.nn.functional as F
 from IPython.display import display, clear_output
+import difflib

 from .logger import log_function_call

@@ -194,8 +195,8 @@ def test_model_performance(loaders, model, loader_name_list, epoch, train_mode,

 def train_test_model(src, settings, custom_model=False, custom_model_path=None):

-    from .io import
-    from .utils import pick_best_model
+    from .io import _save_settings, _copy_missclassified
+    from .utils import pick_best_model
     from .core import generate_loaders

     settings['src'] = src
@@ -208,7 +209,7 @@ def train_test_model(src, settings, custom_model=False, custom_model_path=None):
         model = torch.load(custom_model_path)

     if settings['train']:
-
+        _save_settings(settings, src)
         torch.cuda.empty_cache()
         torch.cuda.memory.empty_cache()
         gc.collect()
@@ -227,20 +228,23 @@ def train_test_model(src, settings, custom_model=False, custom_model_path=None):
                                               validation_split=settings['val_split'],
                                               pin_memory=settings['pin_memory'],
                                               normalize=settings['normalize'],
-
+                                              channels=settings['channels'],
+                                              verbose=settings['verbose'])
+

     if settings['test']:
         test, _, plate_names_test = generate_loaders(src,
-
-
-
-
-
-
-
-
-
-
+                                                     train_mode=settings['train_mode'],
+                                                     mode='test',
+                                                     image_size=settings['image_size'],
+                                                     batch_size=settings['batch_size'],
+                                                     classes=settings['classes'],
+                                                     num_workers=settings['num_workers'],
+                                                     validation_split=0.0,
+                                                     pin_memory=settings['pin_memory'],
+                                                     normalize=settings['normalize'],
+                                                     channels=settings['channels'],
+                                                     verbose=settings['verbose'])
         if model == None:
             model_path = pick_best_model(src+'/model')
             print(f'Best model: {model_path}')
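The expanded generate_loaders calls document which keys train_test_model expects in settings. A sketch of such a dict, inferred from the calls above (values are illustrative, not spacr defaults):

settings = {
    'train': True, 'test': True,
    'train_mode': 'erm',      # the trainer also branches on 'irm'
    'image_size': 224,
    'batch_size': 32,
    'classes': ['nc', 'pc'],  # hypothetical class names
    'num_workers': 4,
    'val_split': 0.1,
    'pin_memory': True,
    'normalize': True,
    'channels': [1, 2, 3],
    'verbose': True,
}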
@@ -330,8 +334,8 @@ def train_model(dst, model_type, train_loaders, train_loader_names, train_mode='
         None
     """

-    from .io import
-    from .utils import
+    from .io import _save_model, _save_progress
+    from .utils import compute_irm_penalty, calculate_loss, choose_model

     print(f'Train batches:{len(train_loaders)}, Validation batches:{len(val_loaders)}')

@@ -347,6 +351,11 @@ def train_model(dst, model_type, train_loaders, train_loader_names, train_mode='
             break

     model = choose_model(model_type, device, init_weights, dropout_rate, use_checkpoint)
+
+    if model is None:
+        print(f'Model {model_type} not found')
+        return
+
     model.to(device)

     if optimizer_type == 'adamw':
@@ -421,10 +430,10 @@ def train_model(dst, model_type, train_loaders, train_loader_names, train_mode='
             if schedule == 'step_lr':
                 scheduler.step()

-
+        _save_progress(dst, results_df, train_metrics_df)
         clear_output(wait=True)
         display(results_df)
-
+        _save_model(model, model_type, results_df, dst, epoch, epochs, intermedeate_save=[0.99,0.98,0.95,0.94])

     if train_mode == 'irm':
         dummy_w = torch.nn.Parameter(torch.Tensor([1.0])).to(device)
@@ -494,7 +503,165 @@ def train_model(dst, model_type, train_loaders, train_loader_names, train_mode='

         clear_output(wait=True)
         display(results_df)
-
-
+        _save_progress(dst, results_df, train_metrics_df)
+        _save_model(model, model_type, results_df, dst, epoch, epochs, intermedeate_save=[0.99,0.98,0.95,0.94])
     print(f'Saved model: {dst}')
-    return
+    return
+
+def get_submodules(model, prefix=''):
+    submodules = []
+    for name, module in model.named_children():
+        full_name = prefix + ('.' if prefix else '') + name
+        submodules.append(full_name)
+        submodules.extend(get_submodules(module, full_name))
+    return submodules
+
+def visualize_model_attention_v2(src, model_type='maxvit', model_path='', image_size=224, channels=[1,2,3], normalize=True, class_names=None, save_saliency=False, save_dir='saliency_maps'):
+    import torch
+    import os
+    from spacr.utils import SaliencyMapGenerator, preprocess_image
+    import matplotlib.pyplot as plt
+    import numpy as np
+    from PIL import Image
+
+    use_cuda = torch.cuda.is_available()
+    device = torch.device("cuda" if use_cuda else "cpu")
+
+    # Load the entire model object
+    model = torch.load(model_path)
+    model.to(device)
+
+    # Create directory for saving saliency maps if it does not exist
+    if save_saliency and not os.path.exists(save_dir):
+        os.makedirs(save_dir)
+
+    # Collect all images and their tensors
+    images = []
+    input_tensors = []
+    filenames = []
+    for file in os.listdir(src):
+        image_path = os.path.join(src, file)
+        image, input_tensor = preprocess_image(image_path, normalize=normalize, image_size=image_size, channels=channels)
+        images.append(image)
+        input_tensors.append(input_tensor)
+        filenames.append(file)
+
+    input_tensors = torch.cat(input_tensors).to(device)
+    class_labels = torch.zeros(input_tensors.size(0), dtype=torch.long).to(device)  # Replace with actual class labels if available
+
+    # Generate saliency maps
+    cam_generator = SaliencyMapGenerator(model)
+    saliency_maps = cam_generator.compute_saliency_maps(input_tensors, class_labels)
+
+    # Plot images, saliency maps, and overlays
+    saliency_maps = saliency_maps.cpu().numpy()
+    N = len(images)
+
+    dst = os.path.join(src, 'saliency_maps')
+    os.makedirs(dst, exist_ok=True)
+
+    for i in range(N):
+        fig, axes = plt.subplots(1, 3, figsize=(15, 5))
+
+        # Original image
+        axes[0].imshow(images[i])
+        axes[0].axis('off')
+        if class_names:
+            axes[0].set_title(class_names[class_labels[i].item()])
+
+        # Saliency map
+        axes[1].imshow(saliency_maps[i], cmap=plt.cm.hot)
+        axes[1].axis('off')
+
+        # Overlay
+        overlay = np.array(images[i])
+        axes[2].imshow(overlay)
+        axes[2].imshow(saliency_maps[i], cmap='jet', alpha=0.5)
+        axes[2].axis('off')
+
+        plt.tight_layout()
+        plt.show()
+
+        # Save the saliency map if required
+        if save_saliency:
+            saliency_image = Image.fromarray((saliency_maps[i] * 255).astype(np.uint8))
+            saliency_image.save(os.path.join(dst, f'saliency_{filenames[i]}'))
+
+def visualize_model_attention(src, model_type='maxvit', model_path='', image_size=224, channels=[1,2,3], normalize=True, class_names=None, save_saliency=False, save_dir='saliency_maps'):
+    import torch
+    import os
+    from spacr.utils import SaliencyMapGenerator, preprocess_image
+    import matplotlib.pyplot as plt
+    import numpy as np
+    from PIL import Image
+
+    use_cuda = torch.cuda.is_available()
+    device = torch.device("cuda" if use_cuda else "cpu")
+
+    # Load the entire model object
+    model = torch.load(model_path)
+    model.to(device)
+
+    # Create directory for saving saliency maps if it does not exist
+    if save_saliency and not os.path.exists(save_dir):
+        os.makedirs(save_dir)
+
+    # Collect all images and their tensors
+    images = []
+    input_tensors = []
+    filenames = []
+    for file in os.listdir(src):
+        if not file.endswith('.png'):
+            continue
+        image_path = os.path.join(src, file)
+        image, input_tensor = preprocess_image(image_path, normalize=normalize, image_size=image_size, channels=channels)
+        images.append(image)
+        input_tensors.append(input_tensor)
+        filenames.append(file)
+
+    input_tensors = torch.cat(input_tensors).to(device)
+    class_labels = torch.zeros(input_tensors.size(0), dtype=torch.long).to(device)  # Replace with actual class labels if available
+
+    # Generate saliency maps
+    cam_generator = SaliencyMapGenerator(model)
+    saliency_maps = cam_generator.compute_saliency_maps(input_tensors, class_labels)
+
+    # Convert saliency maps to numpy arrays
+    saliency_maps = saliency_maps.cpu().numpy()
+
+    N = len(images)
+
+    dst = os.path.join(src, 'saliency_maps')
+
+    for i in range(N):
+        fig, axes = plt.subplots(1, 3, figsize=(20, 5))
+
+        # Original image
+        axes[0].imshow(images[i])
+        axes[0].axis('off')
+        if class_names:
+            axes[0].set_title(f"Class: {class_names[class_labels[i].item()]}")
+
+        # Saliency Map
+        axes[1].imshow(saliency_maps[i, 0], cmap='hot')
+        axes[1].axis('off')
+        axes[1].set_title("Saliency Map")
+
+        # Overlay
+        overlay = np.array(images[i])
+        overlay = overlay / overlay.max()
+        saliency_map_rgb = np.stack([saliency_maps[i, 0]] * 3, axis=-1)  # Convert saliency map to RGB
+        overlay = (overlay * 0.5 + saliency_map_rgb * 0.5).clip(0, 1)
+        axes[2].imshow(overlay)
+        axes[2].axis('off')
+        axes[2].set_title("Overlay")
+
+        plt.tight_layout()
+        plt.show()
+
+        # Save the saliency map if required
+        if save_saliency:
+            os.makedirs(dst, exist_ok=True)
+            saliency_image = Image.fromarray((saliency_maps[i, 0] * 255).astype(np.uint8))
+            saliency_image.save(os.path.join(dst, f'saliency_{filenames[i]}'))
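The new get_submodules helper walks named_children() recursively and returns dotted module paths. A quick self-contained check (the spacr.train import path is assumed from this file):

import torch.nn as nn
from spacr.train import get_submodules

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.Sequential(nn.ReLU(), nn.Linear(8, 2)))
print(get_submodules(model))  # ['0', '1', '1.0', '1.1']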