spacr 0.0.20__py3-none-any.whl → 0.0.21__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- spacr/alpha.py +291 -14
- spacr/annotate_app.py +2 -2
- spacr/core.py +1301 -426
- spacr/foldseek.py +793 -0
- spacr/get_alfafold_structures.py +72 -0
- spacr/gui_mask_app.py +30 -10
- spacr/gui_utils.py +17 -2
- spacr/io.py +260 -102
- spacr/measure.py +150 -64
- spacr/plot.py +151 -12
- spacr/sim.py +666 -119
- spacr/timelapse.py +139 -9
- spacr/train.py +18 -10
- spacr/utils.py +43 -43
- {spacr-0.0.20.dist-info → spacr-0.0.21.dist-info}/METADATA +5 -2
- spacr-0.0.21.dist-info/RECORD +33 -0
- spacr-0.0.20.dist-info/RECORD +0 -31
- {spacr-0.0.20.dist-info → spacr-0.0.21.dist-info}/LICENSE +0 -0
- {spacr-0.0.20.dist-info → spacr-0.0.21.dist-info}/WHEEL +0 -0
- {spacr-0.0.20.dist-info → spacr-0.0.21.dist-info}/entry_points.txt +0 -0
- {spacr-0.0.20.dist-info → spacr-0.0.21.dist-info}/top_level.txt +0 -0
spacr/timelapse.py
CHANGED
@@ -11,6 +11,7 @@ from btrack import datasets as btrack_datasets
|
|
11
11
|
from skimage.measure import regionprops
|
12
12
|
from scipy.signal import find_peaks
|
13
13
|
from scipy.optimize import curve_fit
|
14
|
+
from scipy.integrate import trapz
|
14
15
|
import matplotlib.pyplot as plt
|
15
16
|
|
16
17
|
from .logger import log_function_call
|
@@ -590,6 +591,80 @@ def infected_vs_noninfected(result_df, measurement):
|
|
590
591
|
plt.tight_layout()
|
591
592
|
plt.show()
|
592
593
|
|
594
|
+
def save_figure(fig, src, figure_number):
|
595
|
+
source = os.path.dirname(src)
|
596
|
+
results_fldr = os.path.join(source,'results')
|
597
|
+
os.makedirs(results_fldr, exist_ok=True)
|
598
|
+
fig_loc = os.path.join(results_fldr, f'figure_{figure_number}.pdf')
|
599
|
+
fig.savefig(fig_loc)
|
600
|
+
print(f'Saved figure:{fig_loc}')
|
601
|
+
|
602
|
+
def save_results_dataframe(df, src, results_name):
|
603
|
+
source = os.path.dirname(src)
|
604
|
+
results_fldr = os.path.join(source,'results')
|
605
|
+
os.makedirs(results_fldr, exist_ok=True)
|
606
|
+
csv_loc = os.path.join(results_fldr, f'{results_name}.csv')
|
607
|
+
df.to_csv(csv_loc, index=True)
|
608
|
+
print(f'Saved results:{csv_loc}')
|
609
|
+
|
610
|
+
def summarize_per_well(peak_details_df):
|
611
|
+
# Step 1: Split the 'ID' column
|
612
|
+
split_columns = peak_details_df['ID'].str.split('_', expand=True)
|
613
|
+
peak_details_df[['plate', 'row', 'column', 'field', 'object_number']] = split_columns
|
614
|
+
|
615
|
+
# Step 2: Create 'well_ID' by combining 'row' and 'column'
|
616
|
+
peak_details_df['well_ID'] = peak_details_df['row'] + '_' + peak_details_df['column']
|
617
|
+
|
618
|
+
# Filter entries where 'amplitude' is not null
|
619
|
+
filtered_df = peak_details_df[peak_details_df['amplitude'].notna()]
|
620
|
+
|
621
|
+
# Preparation for Step 3: Identify numeric columns for averaging from the filtered dataframe
|
622
|
+
numeric_cols = filtered_df.select_dtypes(include=['number']).columns
|
623
|
+
|
624
|
+
# Step 3: Calculate summary statistics
|
625
|
+
summary_df = filtered_df.groupby('well_ID').agg(
|
626
|
+
peaks_per_well=('ID', 'size'),
|
627
|
+
unique_IDs_with_amplitude=('ID', 'nunique'), # Count unique IDs per well with non-null amplitude
|
628
|
+
**{col: (col, 'mean') for col in numeric_cols} # exclude 'amplitude' from averaging if it's numeric
|
629
|
+
).reset_index()
|
630
|
+
|
631
|
+
# Step 3: Calculate summary statistics
|
632
|
+
summary_df_2 = peak_details_df.groupby('well_ID').agg(
|
633
|
+
cells_per_well=('object_number', 'nunique'),
|
634
|
+
).reset_index()
|
635
|
+
|
636
|
+
summary_df['cells_per_well'] = summary_df_2['cells_per_well']
|
637
|
+
summary_df['peaks_per_cell'] = summary_df['peaks_per_well'] / summary_df['cells_per_well']
|
638
|
+
|
639
|
+
return summary_df
|
640
|
+
|
641
|
+
def summarize_per_well_inf_non_inf(peak_details_df):
|
642
|
+
# Step 1: Split the 'ID' column
|
643
|
+
split_columns = peak_details_df['ID'].str.split('_', expand=True)
|
644
|
+
peak_details_df[['plate', 'row', 'column', 'field', 'object_number']] = split_columns
|
645
|
+
|
646
|
+
# Step 2: Create 'well_ID' by combining 'row' and 'column'
|
647
|
+
peak_details_df['well_ID'] = peak_details_df['row'] + '_' + peak_details_df['column']
|
648
|
+
|
649
|
+
# Assume 'pathogen_count' indicates infection if > 0
|
650
|
+
# Add an 'infected_status' column to classify cells
|
651
|
+
peak_details_df['infected_status'] = peak_details_df['infected'].apply(lambda x: 'infected' if x > 0 else 'non_infected')
|
652
|
+
|
653
|
+
# Preparation for Step 3: Identify numeric columns for averaging
|
654
|
+
numeric_cols = peak_details_df.select_dtypes(include=['number']).columns
|
655
|
+
|
656
|
+
# Step 3: Calculate summary statistics
|
657
|
+
summary_df = peak_details_df.groupby(['well_ID', 'infected_status']).agg(
|
658
|
+
cells_per_well=('object_number', 'nunique'),
|
659
|
+
peaks_per_well=('ID', 'size'),
|
660
|
+
**{col: (col, 'mean') for col in numeric_cols}
|
661
|
+
).reset_index()
|
662
|
+
|
663
|
+
# Calculate peaks per cell
|
664
|
+
summary_df['peaks_per_cell'] = summary_df['peaks_per_well'] / summary_df['cells_per_well']
|
665
|
+
|
666
|
+
return summary_df
|
667
|
+
|
593
668
|
def analyze_calcium_oscillations(db_loc, measurement='cell_channel_1_mean_intensity', size_filter='cell_area', fluctuation_threshold=0.25, num_lines=None, peak_height=0.01, pathogen=None, cytoplasm=None, remove_transient=True, verbose=False, transience_threshold=0.9):
|
594
669
|
# Load data
|
595
670
|
conn = sqlite3.connect(db_loc)
|
@@ -626,7 +701,7 @@ def analyze_calcium_oscillations(db_loc, measurement='cell_channel_1_mean_intens
|
|
626
701
|
cell_df['plate_row_column_field_object'] = cell_df['plate'].astype(str) + '_' + cell_df['row'].astype(str) + '_' + cell_df['column'].astype(str) + '_' + cell_df['field'].astype(str) + '_' + cell_df['object_label'].astype(str)
|
627
702
|
|
628
703
|
df = cell_df.copy()
|
629
|
-
|
704
|
+
|
630
705
|
# Fit exponential decay model to all scaled fluorescence data
|
631
706
|
try:
|
632
707
|
params, _ = curve_fit(exponential_decay, df['time'], df[measurement], p0=[max(df[measurement]), 0.01, min(df[measurement])], maxfev=10000)
|
@@ -653,11 +728,14 @@ def analyze_calcium_oscillations(db_loc, measurement='cell_channel_1_mean_intens
|
|
653
728
|
if verbose:
|
654
729
|
print(f'Group length: {len(group)} Timelapse length: {total_timepoints}, threshold:{threshold}')
|
655
730
|
|
656
|
-
if
|
731
|
+
if len(group) <= threshold:
|
657
732
|
transience_removed += 1
|
733
|
+
if verbose:
|
734
|
+
print(f'removed group {unique_id} due to transience')
|
658
735
|
continue
|
659
736
|
|
660
737
|
size_diff = group[size_filter].std() / group[size_filter].mean()
|
738
|
+
|
661
739
|
if size_diff <= fluctuation_threshold:
|
662
740
|
group['delta_' + measurement] = group['corrected_' + measurement].diff().fillna(0)
|
663
741
|
corrected_dfs.append(group)
|
@@ -665,12 +743,50 @@ def analyze_calcium_oscillations(db_loc, measurement='cell_channel_1_mean_intens
|
|
665
743
|
# Detect peaks
|
666
744
|
peaks, properties = find_peaks(group['delta_' + measurement], height=peak_height)
|
667
745
|
|
746
|
+
# Set values < 0 to 0
|
747
|
+
group_filtered = group.copy()
|
748
|
+
group_filtered['delta_' + measurement] = group['delta_' + measurement].clip(lower=0)
|
749
|
+
above_zero_auc = trapz(y=group_filtered['delta_' + measurement], x=group_filtered['time'])
|
750
|
+
auc = trapz(y=group['delta_' + measurement], x=group_filtered['time'])
|
751
|
+
is_infected = (group['parasite_count'] > 0).any()
|
752
|
+
|
753
|
+
if is_infected:
|
754
|
+
is_infected = 1
|
755
|
+
else:
|
756
|
+
is_infected = 0
|
757
|
+
|
758
|
+
if len(peaks) == 0:
|
759
|
+
peak_details_list.append({
|
760
|
+
'ID': unique_id,
|
761
|
+
'plate': group['plate'].iloc[0],
|
762
|
+
'row': group['row'].iloc[0],
|
763
|
+
'column': group['column'].iloc[0],
|
764
|
+
'field': group['field'].iloc[0],
|
765
|
+
'object_number': group['object_number'].iloc[0],
|
766
|
+
'time': np.nan, # The time of the peak
|
767
|
+
'amplitude': np.nan,
|
768
|
+
'delta': np.nan,
|
769
|
+
'AUC': auc,
|
770
|
+
'AUC_positive': above_zero_auc,
|
771
|
+
'AUC_peak': np.nan,
|
772
|
+
'infected': is_infected
|
773
|
+
})
|
774
|
+
|
668
775
|
# Inside the for loop where peaks are detected
|
669
776
|
for i, peak in enumerate(peaks):
|
670
|
-
|
671
|
-
|
672
|
-
|
777
|
+
|
778
|
+
amplitude = properties['peak_heights'][i]
|
779
|
+
peak_time = group['time'].iloc[peak]
|
673
780
|
pathogen_count_at_peak = group['parasite_count'].iloc[peak]
|
781
|
+
|
782
|
+
start_idx = max(peak - 1, 0)
|
783
|
+
end_idx = min(peak + 1, len(group) - 1)
|
784
|
+
|
785
|
+
# Using indices to slice for AUC calculation
|
786
|
+
peak_segment_y = group['delta_' + measurement].iloc[start_idx:end_idx + 1]
|
787
|
+
peak_segment_x = group['time'].iloc[start_idx:end_idx + 1]
|
788
|
+
peak_auc = trapz(y=peak_segment_y, x=peak_segment_x)
|
789
|
+
|
674
790
|
peak_details_list.append({
|
675
791
|
'ID': unique_id,
|
676
792
|
'plate': group['plate'].iloc[0],
|
@@ -681,6 +797,9 @@ def analyze_calcium_oscillations(db_loc, measurement='cell_channel_1_mean_intens
|
|
681
797
|
'time': peak_time, # The time of the peak
|
682
798
|
'amplitude': amplitude,
|
683
799
|
'delta': group['delta_' + measurement].iloc[peak],
|
800
|
+
'AUC': auc,
|
801
|
+
'AUC_positive': above_zero_auc,
|
802
|
+
'AUC_peak': peak_auc,
|
684
803
|
'infected': pathogen_count_at_peak
|
685
804
|
})
|
686
805
|
else:
|
@@ -697,7 +816,14 @@ def analyze_calcium_oscillations(db_loc, measurement='cell_channel_1_mean_intens
|
|
697
816
|
return
|
698
817
|
|
699
818
|
peak_details_df = pd.DataFrame(peak_details_list)
|
700
|
-
|
819
|
+
summary_df = summarize_per_well(peak_details_df)
|
820
|
+
summary_df_inf_non_inf = summarize_per_well_inf_non_inf(peak_details_df)
|
821
|
+
|
822
|
+
save_results_dataframe(df=peak_details_df, src=db_loc, results_name='peak_details')
|
823
|
+
save_results_dataframe(df=result_df, src=db_loc, results_name='results')
|
824
|
+
save_results_dataframe(df=summary_df, src=db_loc, results_name='well_results')
|
825
|
+
save_results_dataframe(df=summary_df_inf_non_inf, src=db_loc, results_name='well_results_inf_non_inf')
|
826
|
+
|
701
827
|
# Plotting
|
702
828
|
fig, ax = plt.subplots(figsize=(10, 8))
|
703
829
|
sampled_groups = result_df['plate_row_column_field_object'].unique()
|
@@ -714,12 +840,16 @@ def analyze_calcium_oscillations(db_loc, measurement='cell_channel_1_mean_intens
|
|
714
840
|
ax.set_xlabel('Time')
|
715
841
|
ax.set_ylabel('Normalized Delta ' + measurement)
|
716
842
|
plt.tight_layout()
|
843
|
+
|
717
844
|
plt.show()
|
845
|
+
|
846
|
+
save_figure(fig, src=db_loc, figure_number=1)
|
718
847
|
|
719
848
|
if pathogen:
|
720
849
|
infected_vs_noninfected(result_df, measurement)
|
850
|
+
save_figure(fig, src=db_loc, figure_number=2)
|
721
851
|
|
722
|
-
#
|
852
|
+
# Identify cells with and without pathogens
|
723
853
|
infected_cells = result_df[result_df.groupby('plate_row_column_field_object')['parasite_count'].transform('max') > 0]['plate_row_column_field_object'].unique()
|
724
854
|
noninfected_cells = result_df[result_df.groupby('plate_row_column_field_object')['parasite_count'].transform('max') == 0]['plate_row_column_field_object'].unique()
|
725
855
|
|
@@ -733,5 +863,5 @@ def analyze_calcium_oscillations(db_loc, measurement='cell_channel_1_mean_intens
|
|
733
863
|
|
734
864
|
print(f'Average number of peaks per infected cell: {avg_inf_peaks_per_cell:.2f}')
|
735
865
|
print(f'Average number of peaks per non-infected cell: {avg_non_inf_peaks_per_cell:.2f}')
|
736
|
-
|
737
|
-
return result_df, peak_details_df
|
866
|
+
print(f'done')
|
867
|
+
return result_df, peak_details_df, fig
|
spacr/train.py
CHANGED
@@ -194,8 +194,8 @@ def test_model_performance(loaders, model, loader_name_list, epoch, train_mode,
|
|
194
194
|
|
195
195
|
def train_test_model(src, settings, custom_model=False, custom_model_path=None):
|
196
196
|
|
197
|
-
from .io import
|
198
|
-
from .utils import pick_best_model
|
197
|
+
from .io import _save_settings, _copy_missclassified
|
198
|
+
from .utils import pick_best_model
|
199
199
|
from .core import generate_loaders
|
200
200
|
|
201
201
|
settings['src'] = src
|
@@ -208,7 +208,7 @@ def train_test_model(src, settings, custom_model=False, custom_model_path=None):
|
|
208
208
|
model = torch.load(custom_model_path)
|
209
209
|
|
210
210
|
if settings['train']:
|
211
|
-
|
211
|
+
_save_settings(settings, src)
|
212
212
|
torch.cuda.empty_cache()
|
213
213
|
torch.cuda.memory.empty_cache()
|
214
214
|
gc.collect()
|
@@ -227,7 +227,9 @@ def train_test_model(src, settings, custom_model=False, custom_model_path=None):
|
|
227
227
|
validation_split=settings['val_split'],
|
228
228
|
pin_memory=settings['pin_memory'],
|
229
229
|
normalize=settings['normalize'],
|
230
|
-
|
230
|
+
channels=settings['channels'],
|
231
|
+
verbose=settings['verbose'])
|
232
|
+
|
231
233
|
|
232
234
|
if settings['test']:
|
233
235
|
test, _, plate_names_test = generate_loaders(src,
|
@@ -240,6 +242,7 @@ def train_test_model(src, settings, custom_model=False, custom_model_path=None):
|
|
240
242
|
validation_split=0.0,
|
241
243
|
pin_memory=settings['pin_memory'],
|
242
244
|
normalize=settings['normalize'],
|
245
|
+
channels=settings['channels'],
|
243
246
|
verbose=settings['verbose'])
|
244
247
|
if model == None:
|
245
248
|
model_path = pick_best_model(src+'/model')
|
@@ -330,8 +333,8 @@ def train_model(dst, model_type, train_loaders, train_loader_names, train_mode='
|
|
330
333
|
None
|
331
334
|
"""
|
332
335
|
|
333
|
-
from .io import
|
334
|
-
from .utils import
|
336
|
+
from .io import _save_model, _save_progress
|
337
|
+
from .utils import compute_irm_penalty, calculate_loss, choose_model #evaluate_model_performance,
|
335
338
|
|
336
339
|
print(f'Train batches:{len(train_loaders)}, Validation batches:{len(val_loaders)}')
|
337
340
|
|
@@ -347,6 +350,11 @@ def train_model(dst, model_type, train_loaders, train_loader_names, train_mode='
|
|
347
350
|
break
|
348
351
|
|
349
352
|
model = choose_model(model_type, device, init_weights, dropout_rate, use_checkpoint)
|
353
|
+
|
354
|
+
if model is None:
|
355
|
+
print(f'Model {model_type} not found')
|
356
|
+
return
|
357
|
+
|
350
358
|
model.to(device)
|
351
359
|
|
352
360
|
if optimizer_type == 'adamw':
|
@@ -421,10 +429,10 @@ def train_model(dst, model_type, train_loaders, train_loader_names, train_mode='
|
|
421
429
|
if schedule == 'step_lr':
|
422
430
|
scheduler.step()
|
423
431
|
|
424
|
-
|
432
|
+
_save_progress(dst, results_df, train_metrics_df)
|
425
433
|
clear_output(wait=True)
|
426
434
|
display(results_df)
|
427
|
-
|
435
|
+
_save_model(model, model_type, results_df, dst, epoch, epochs, intermedeate_save=[0.99,0.98,0.95,0.94])
|
428
436
|
|
429
437
|
if train_mode == 'irm':
|
430
438
|
dummy_w = torch.nn.Parameter(torch.Tensor([1.0])).to(device)
|
@@ -494,7 +502,7 @@ def train_model(dst, model_type, train_loaders, train_loader_names, train_mode='
|
|
494
502
|
|
495
503
|
clear_output(wait=True)
|
496
504
|
display(results_df)
|
497
|
-
|
498
|
-
|
505
|
+
_save_progress(dst, results_df, train_metrics_df)
|
506
|
+
_save_model(model, model_type, results_df, dst, epoch, epochs, intermedeate_save=[0.99,0.98,0.95,0.94])
|
499
507
|
print(f'Saved model: {dst}')
|
500
508
|
return
|
spacr/utils.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1
|
-
import os, re, sqlite3, gc, torch, torchvision, time, random, string, shutil, cv2, tarfile, glob
|
1
|
+
import sys, os, re, sqlite3, gc, torch, torchvision, time, random, string, shutil, cv2, tarfile, glob
|
2
2
|
|
3
3
|
import numpy as np
|
4
4
|
from cellpose import models as cp_models
|
@@ -71,7 +71,8 @@ def _outline_and_overlay(image, rgb_image, mask_dims, outline_colors, outline_th
|
|
71
71
|
outline = np.zeros_like(mask, dtype=np.uint8) # Use uint8 for contour detection efficiency
|
72
72
|
|
73
73
|
# Find and draw contours
|
74
|
-
for j in np.unique(mask)
|
74
|
+
for j in np.unique(mask):
|
75
|
+
#for j in np.unique(mask)[1:]:
|
75
76
|
contours = find_contours(mask == j, 0.5)
|
76
77
|
# Convert contours for OpenCV format and draw directly to optimize
|
77
78
|
cv_contours = [np.flip(contour.astype(int), axis=1) for contour in contours]
|
@@ -781,8 +782,8 @@ def _get_object_settings(object_type, settings):
|
|
781
782
|
object_settings = {}
|
782
783
|
|
783
784
|
object_settings['diameter'] = _get_diam(settings['magnification'], obj=object_type)
|
784
|
-
object_settings['minimum_size'] = (object_settings['diameter']**2)/
|
785
|
-
object_settings['maximum_size'] = (object_settings['diameter']**2)*
|
785
|
+
object_settings['minimum_size'] = (object_settings['diameter']**2)/4
|
786
|
+
object_settings['maximum_size'] = (object_settings['diameter']**2)*10
|
786
787
|
object_settings['merge'] = False
|
787
788
|
object_settings['resample'] = True
|
788
789
|
object_settings['remove_border_objects'] = False
|
@@ -793,21 +794,22 @@ def _get_object_settings(object_type, settings):
|
|
793
794
|
object_settings['model_name'] = 'cyto'
|
794
795
|
else:
|
795
796
|
object_settings['model_name'] = 'cyto2'
|
796
|
-
object_settings['filter_size'] =
|
797
|
-
object_settings['filter_intensity'] =
|
797
|
+
object_settings['filter_size'] = False
|
798
|
+
object_settings['filter_intensity'] = False
|
798
799
|
object_settings['restore_type'] = settings.get('cell_restore_type', None)
|
799
800
|
|
800
801
|
elif object_type == 'nucleus':
|
801
802
|
object_settings['model_name'] = 'nuclei'
|
802
|
-
object_settings['filter_size'] =
|
803
|
-
object_settings['filter_intensity'] =
|
803
|
+
object_settings['filter_size'] = False
|
804
|
+
object_settings['filter_intensity'] = False
|
804
805
|
object_settings['restore_type'] = settings.get('nucleus_restore_type', None)
|
805
806
|
|
806
807
|
elif object_type == 'pathogen':
|
807
808
|
object_settings['model_name'] = 'cyto'
|
808
809
|
object_settings['filter_size'] = True
|
809
|
-
object_settings['filter_intensity'] =
|
810
|
+
object_settings['filter_intensity'] = False
|
810
811
|
object_settings['restore_type'] = settings.get('pathogen_restore_type', None)
|
812
|
+
object_settings['merge'] = settings['merge_pathogens']
|
811
813
|
|
812
814
|
else:
|
813
815
|
print(f'Object type: {object_type} not supported. Supported object types are : cell, nucleus and pathogen')
|
@@ -884,17 +886,15 @@ def _get_cellpose_channels(src, nucleus_channel, pathogen_channel, cell_channel)
|
|
884
886
|
|
885
887
|
if not pathogen_channel is None:
|
886
888
|
if not nucleus_channel is None:
|
887
|
-
|
889
|
+
if not pathogen_channel is None:
|
890
|
+
cellpose_channels['pathogen'] = [0,2]
|
891
|
+
else:
|
892
|
+
cellpose_channels['pathogen'] = [0,1]
|
888
893
|
else:
|
889
894
|
cellpose_channels['pathogen'] = [0,0]
|
890
895
|
|
891
896
|
if not cell_channel is None:
|
892
897
|
if not nucleus_channel is None:
|
893
|
-
if not pathogen_channel is None:
|
894
|
-
cellpose_channels['cell'] = [0,2]
|
895
|
-
else:
|
896
|
-
cellpose_channels['cell'] = [0,1]
|
897
|
-
elif not pathogen_channel is None:
|
898
898
|
cellpose_channels['cell'] = [0,1]
|
899
899
|
else:
|
900
900
|
cellpose_channels['cell'] = [0,0]
|
@@ -1069,7 +1069,7 @@ class Cache:
|
|
1069
1069
|
cache (OrderedDict): The cache data structure.
|
1070
1070
|
"""
|
1071
1071
|
|
1072
|
-
def
|
1072
|
+
def __init__(self, max_size):
|
1073
1073
|
self.cache = OrderedDict()
|
1074
1074
|
self.max_size = max_size
|
1075
1075
|
|
@@ -1100,7 +1100,7 @@ class ScaledDotProductAttention(nn.Module):
|
|
1100
1100
|
|
1101
1101
|
"""
|
1102
1102
|
|
1103
|
-
def
|
1103
|
+
def __init__(self, d_k):
|
1104
1104
|
super(ScaledDotProductAttention, self).__init__()
|
1105
1105
|
self.d_k = d_k
|
1106
1106
|
|
@@ -1131,7 +1131,7 @@ class SelfAttention(nn.Module):
|
|
1131
1131
|
d_k (int): Dimensionality of the key and query vectors.
|
1132
1132
|
"""
|
1133
1133
|
|
1134
|
-
def
|
1134
|
+
def __init__(self, in_channels, d_k):
|
1135
1135
|
super(SelfAttention, self).__init__()
|
1136
1136
|
self.W_q = nn.Linear(in_channels, d_k)
|
1137
1137
|
self.W_k = nn.Linear(in_channels, d_k)
|
@@ -1155,7 +1155,7 @@ class SelfAttention(nn.Module):
|
|
1155
1155
|
return output
|
1156
1156
|
|
1157
1157
|
class ScaledDotProductAttention(nn.Module):
|
1158
|
-
def
|
1158
|
+
def __init__(self, d_k):
|
1159
1159
|
"""
|
1160
1160
|
Initializes the ScaledDotProductAttention module.
|
1161
1161
|
|
@@ -1192,7 +1192,7 @@ class SelfAttention(nn.Module):
|
|
1192
1192
|
in_channels (int): Number of input channels.
|
1193
1193
|
d_k (int): Dimensionality of the key and query vectors.
|
1194
1194
|
"""
|
1195
|
-
def
|
1195
|
+
def __init__(self, in_channels, d_k):
|
1196
1196
|
super(SelfAttention, self).__init__()
|
1197
1197
|
self.W_q = nn.Linear(in_channels, d_k)
|
1198
1198
|
self.W_k = nn.Linear(in_channels, d_k)
|
@@ -1223,7 +1223,7 @@ class EarlyFusion(nn.Module):
|
|
1223
1223
|
Args:
|
1224
1224
|
in_channels (int): Number of input channels.
|
1225
1225
|
"""
|
1226
|
-
def
|
1226
|
+
def __init__(self, in_channels):
|
1227
1227
|
super(EarlyFusion, self).__init__()
|
1228
1228
|
self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=1, stride=1)
|
1229
1229
|
|
@@ -1242,7 +1242,7 @@ class EarlyFusion(nn.Module):
|
|
1242
1242
|
|
1243
1243
|
# Spatial Attention Mechanism
|
1244
1244
|
class SpatialAttention(nn.Module):
|
1245
|
-
def
|
1245
|
+
def __init__(self, kernel_size=7):
|
1246
1246
|
"""
|
1247
1247
|
Initializes the SpatialAttention module.
|
1248
1248
|
|
@@ -1287,7 +1287,7 @@ class MultiScaleBlockWithAttention(nn.Module):
|
|
1287
1287
|
forward: Forward method for the module.
|
1288
1288
|
"""
|
1289
1289
|
|
1290
|
-
def
|
1290
|
+
def __init__(self, in_channels, out_channels):
|
1291
1291
|
super(MultiScaleBlockWithAttention, self).__init__()
|
1292
1292
|
self.dilated_conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, dilation=1, padding=1)
|
1293
1293
|
self.spatial_attention = nn.Conv2d(out_channels, out_channels, kernel_size=1)
|
@@ -1320,7 +1320,7 @@ class MultiScaleBlockWithAttention(nn.Module):
|
|
1320
1320
|
|
1321
1321
|
# Final Classifier
|
1322
1322
|
class CustomCellClassifier(nn.Module):
|
1323
|
-
def
|
1323
|
+
def __init__(self, num_classes, pathogen_channel, use_attention, use_checkpoint, dropout_rate):
|
1324
1324
|
super(CustomCellClassifier, self).__init__()
|
1325
1325
|
self.early_fusion = EarlyFusion(in_channels=3)
|
1326
1326
|
|
@@ -1349,7 +1349,7 @@ class CustomCellClassifier(nn.Module):
|
|
1349
1349
|
|
1350
1350
|
#CNN and Transformer class, pick any Torch model.
|
1351
1351
|
class TorchModel(nn.Module):
|
1352
|
-
def
|
1352
|
+
def __init__(self, model_name='resnet50', pretrained=True, dropout_rate=None, use_checkpoint=False):
|
1353
1353
|
super(TorchModel, self).__init__()
|
1354
1354
|
self.model_name = model_name
|
1355
1355
|
self.use_checkpoint = use_checkpoint
|
@@ -1423,7 +1423,7 @@ class TorchModel(nn.Module):
|
|
1423
1423
|
return logits
|
1424
1424
|
|
1425
1425
|
class FocalLossWithLogits(nn.Module):
|
1426
|
-
def
|
1426
|
+
def __init__(self, alpha=1, gamma=2):
|
1427
1427
|
super(FocalLossWithLogits, self).__init__()
|
1428
1428
|
self.alpha = alpha
|
1429
1429
|
self.gamma = gamma
|
@@ -1435,7 +1435,7 @@ class FocalLossWithLogits(nn.Module):
|
|
1435
1435
|
return focal_loss.mean()
|
1436
1436
|
|
1437
1437
|
class ResNet(nn.Module):
|
1438
|
-
def
|
1438
|
+
def __init__(self, resnet_type='resnet50', dropout_rate=None, use_checkpoint=False, init_weights='imagenet'):
|
1439
1439
|
super(ResNet, self).__init__()
|
1440
1440
|
|
1441
1441
|
resnet_map = {
|
@@ -1788,25 +1788,24 @@ def annotate_predictions(csv_loc):
|
|
1788
1788
|
df['cond'] = df.apply(assign_condition, axis=1)
|
1789
1789
|
return df
|
1790
1790
|
|
1791
|
-
def
|
1791
|
+
def initiate_counter(counter_, lock_):
|
1792
1792
|
global counter, lock
|
1793
1793
|
counter = counter_
|
1794
1794
|
lock = lock_
|
1795
1795
|
|
1796
|
-
def add_images_to_tar(
|
1797
|
-
global counter, lock, total_images
|
1798
|
-
paths_chunk, tar_path = args
|
1796
|
+
def add_images_to_tar(paths_chunk, tar_path, total_images):
|
1799
1797
|
with tarfile.open(tar_path, 'w') as tar:
|
1800
|
-
for img_path in paths_chunk:
|
1798
|
+
for i, img_path in enumerate(paths_chunk):
|
1801
1799
|
arcname = os.path.basename(img_path)
|
1802
1800
|
try:
|
1803
1801
|
tar.add(img_path, arcname=arcname)
|
1804
1802
|
with lock:
|
1805
1803
|
counter.value += 1
|
1806
|
-
|
1804
|
+
if counter.value % 100 == 0: # Print every 100 updates
|
1805
|
+
progress = (counter.value / total_images) * 100
|
1806
|
+
print(f"Progress: {counter.value}/{total_images} ({progress:.2f}%)", end='\r', file=sys.stdout, flush=True)
|
1807
1807
|
except FileNotFoundError:
|
1808
1808
|
print(f"File not found: {img_path}")
|
1809
|
-
return tar_path
|
1810
1809
|
|
1811
1810
|
def generate_fraction_map(df, gene_column, min_frequency=0.0):
|
1812
1811
|
df['fraction'] = df['count']/df['well_read_sum']
|
@@ -2255,8 +2254,8 @@ def dice_coefficient(mask1, mask2):
|
|
2255
2254
|
def extract_boundaries(mask, dilation_radius=1):
|
2256
2255
|
binary_mask = (mask > 0).astype(np.uint8)
|
2257
2256
|
struct_elem = np.ones((dilation_radius*2+1, dilation_radius*2+1))
|
2258
|
-
dilated = binary_dilation(binary_mask, footprint=struct_elem)
|
2259
|
-
eroded = binary_erosion(binary_mask, footprint=struct_elem)
|
2257
|
+
dilated = morphology.binary_dilation(binary_mask, footprint=struct_elem)
|
2258
|
+
eroded = morphology.binary_erosion(binary_mask, footprint=struct_elem)
|
2260
2259
|
boundary = dilated ^ eroded
|
2261
2260
|
return boundary
|
2262
2261
|
|
@@ -2669,6 +2668,13 @@ def _filter_cp_masks(masks, flows, filter_size, filter_intensity, minimum_size,
|
|
2669
2668
|
print(f'Number of objects before filtration: {num_objects}')
|
2670
2669
|
plot_masks(batch=image, masks=mask, flows=flow, cmap='inferno', figuresize=figuresize, nr=1, file_type='.npz', print_object_number=True)
|
2671
2670
|
|
2671
|
+
if merge:
|
2672
|
+
mask = merge_touching_objects(mask, threshold=0.66)
|
2673
|
+
if plot and idx == 0:
|
2674
|
+
num_objects = mask_object_count(mask)
|
2675
|
+
print(f'Number of objects after merging adjacent objects, : {num_objects}')
|
2676
|
+
plot_masks(batch=image, masks=mask, flows=flow, cmap='inferno', figuresize=figuresize, nr=1, file_type='.npz', print_object_number=True)
|
2677
|
+
|
2672
2678
|
if filter_size:
|
2673
2679
|
props = measure.regionprops_table(mask, properties=['label', 'area'])
|
2674
2680
|
valid_labels = props['label'][np.logical_and(props['area'] > minimum_size, props['area'] < maximum_size)]
|
@@ -2714,13 +2720,6 @@ def _filter_cp_masks(masks, flows, filter_size, filter_intensity, minimum_size,
|
|
2714
2720
|
print(f'Number of objects after removing border objects, : {num_objects}')
|
2715
2721
|
plot_masks(batch=image, masks=mask, flows=flow, cmap='inferno', figuresize=figuresize, nr=1, file_type='.npz', print_object_number=True)
|
2716
2722
|
|
2717
|
-
if merge:
|
2718
|
-
mask = merge_touching_objects(mask, threshold=0.25)
|
2719
|
-
if plot and idx == 0:
|
2720
|
-
num_objects = mask_object_count(mask)
|
2721
|
-
print(f'Number of objects after merging adjacent objects, : {num_objects}')
|
2722
|
-
plot_masks(batch=image, masks=mask, flows=flow, cmap='inferno', figuresize=figuresize, nr=1, file_type='.npz', print_object_number=True)
|
2723
|
-
|
2724
2723
|
mask_stack.append(mask)
|
2725
2724
|
|
2726
2725
|
return mask_stack
|
@@ -2789,6 +2788,7 @@ def _run_test_mode(src, regex, timelapse=False):
|
|
2789
2788
|
|
2790
2789
|
# Prepare for random selection
|
2791
2790
|
set_identifiers = list(images_by_set.keys())
|
2791
|
+
random.seed(42)
|
2792
2792
|
random.shuffle(set_identifiers) # Randomize the order
|
2793
2793
|
|
2794
2794
|
# Select a subset based on the test_images count
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: spacr
|
3
|
-
Version: 0.0.
|
3
|
+
Version: 0.0.21
|
4
4
|
Summary: Spatial phenotype analysis of crisp screens (SpaCr)
|
5
5
|
Home-page: https://github.com/EinarOlafsson/spacr
|
6
6
|
Author: Einar Birnir Olafsson
|
@@ -19,6 +19,7 @@ Requires-Dist: scikit-image >=0.22.0
|
|
19
19
|
Requires-Dist: scikit-learn >=1.4.1
|
20
20
|
Requires-Dist: seaborn >=0.13.2
|
21
21
|
Requires-Dist: matplotlib >=3.8.3
|
22
|
+
Requires-Dist: shap >=0.45.0
|
22
23
|
Requires-Dist: pillow >=10.2.0
|
23
24
|
Requires-Dist: imageio >=2.34.0
|
24
25
|
Requires-Dist: scipy >=1.12.0
|
@@ -31,6 +32,8 @@ Requires-Dist: IPython >=8.18.1
|
|
31
32
|
Requires-Dist: opencv-python-headless >=4.9.0.80
|
32
33
|
Requires-Dist: umap >=0.1.1
|
33
34
|
Requires-Dist: ttkthemes >=3.2.2
|
35
|
+
Requires-Dist: xgboost >=2.0.3
|
36
|
+
Requires-Dist: PyWavelets >=1.6.0
|
34
37
|
Requires-Dist: lxml >=5.1.0
|
35
38
|
Provides-Extra: dev
|
36
39
|
Requires-Dist: pytest >=3.9 ; extra == 'dev'
|
@@ -49,7 +52,7 @@ Requires-Dist: opencv-python-headless ; extra == 'headless'
|
|
49
52
|
<tr>
|
50
53
|
<td>
|
51
54
|
|
52
|
-
Spatial phenotype analysis of
|
55
|
+
Spatial phenotype analysis of CRISPR-Cas9 screens (SpaCr). The spatial organization of organelles and proteins within cells constitutes a key level of functional regulation. In the context of infectious disease, the spatial relationships between host cell structures and intracellular pathogens are critical to understand host clearance mechanisms and how pathogens evade them. Spacr is a Python-based software package for generating single cell image data for deep-learning sub-cellular/cellular phenotypic classification from pooled genetic CRISPR-Cas9 screens. Spacr provides a flexible toolset to extract single cell images and measurements from high content cell painting experiments, train deep-learning models to classify cellular/ subcellular phenotypes, simulate and analyze pooled CRISPR-Cas9 imaging screens.
|
53
56
|
|
54
57
|
</td>
|
55
58
|
<td>
|
@@ -0,0 +1,33 @@
|
|
1
|
+
spacr/__init__.py,sha256=mDi-Qu5r1vZnqIbUBV1JAoSq-mxmMEOmni1JSG2e4Wo,879
|
2
|
+
spacr/__main__.py,sha256=_qRkhbFrH_cXr7AZs6KHL8Hh4VApqNdpNCtiKn2ePTo,285
|
3
|
+
spacr/alpha.py,sha256=1LUtTaeVHqcTMxoCMA7mlsNYyaR0KEaXglwXsUQLwKo,13776
|
4
|
+
spacr/annotate_app.py,sha256=jtm1HgWngv9BTsyVzdmBqg1flbmbXs9JkP6CB4wmOH8,19951
|
5
|
+
spacr/cli.py,sha256=507jfOOEV8BoL4eeUcblvH-iiDHdBrEVJLu1ghAAPSc,1800
|
6
|
+
spacr/core.py,sha256=_R8gXNnjf680yrnbCi2piWQUz7PDbqWYn7SL5MACLfo,156457
|
7
|
+
spacr/foldseek.py,sha256=cWtLzvFF2O_mq5I71UMiuU9DTvDCp7wl6aaWAZRrBZc,33970
|
8
|
+
spacr/get_alfafold_structures.py,sha256=n0g8gne-oyAV3Uo6qxZoJq5X1cUUyD8u0pOC_W2PX40,3541
|
9
|
+
spacr/graph_learning.py,sha256=sD4eOC7Q16rr7WO20mCi_E16_LqioGUUgPamAHIIeNI,12568
|
10
|
+
spacr/graph_learning_lap.py,sha256=MyNRLb63gsjBlui-ByZ0anHugYulL6M-OsGm8rnGBmE,3385
|
11
|
+
spacr/gui_classify_app.py,sha256=-I06tVoA3U0jaAoTs32H1Y5ACMz6QBaEM1NEfg5w-9c,7965
|
12
|
+
spacr/gui_mask_app.py,sha256=wb1w_-U1RTJFfRCGfdrkku4FVULSA_2gFZyVM9oQlB0,8752
|
13
|
+
spacr/gui_measure_app.py,sha256=9mAw3Tiuq61uKTzMVslr0MgD8m1Lv5PNI0K4-gQiuXE,8061
|
14
|
+
spacr/gui_sim_app.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
15
|
+
spacr/gui_utils.py,sha256=NlgGFt62yqYnusv4yT1yjPcjd-v29N4IYukkJkEdL28,30493
|
16
|
+
spacr/io.py,sha256=gcQJ-WOzwDzRZsIzzIoMKIbhFN5AP9lYy2HCv9fG8CU,108683
|
17
|
+
spacr/logger.py,sha256=7Zqr3TuuOQLWT32gYr2q1qvv7x0a2JhLANmZcnBXAW8,670
|
18
|
+
spacr/mask_app.py,sha256=B6-zYXVFg-cc58gLcz-Ry6LClO2jxLitL6B2ACb0HTw,39278
|
19
|
+
spacr/measure.py,sha256=_f6UDugw75rILEg0uo2-QsUGUvc4AQdrdHl-BPZk74I,54686
|
20
|
+
spacr/old_code.py,sha256=KxljHpKNsV5EfX9ifN2xJTnUeqAhyabZyfDWd5THOOc,11226
|
21
|
+
spacr/plot.py,sha256=VtDKTJ_zo8CAVC3ILuIN_wUP6197vq089wNZuom7T8g,61655
|
22
|
+
spacr/sim.py,sha256=2NR5hm--HVcYQnj1SCHoUCVbh_b2XUjjjfoAUIXFwnQ,72997
|
23
|
+
spacr/timelapse.py,sha256=plPjR8nZ7_Q50VAvMvHK2TUE4F-vh7R23JnI6tSW02g,39661
|
24
|
+
spacr/train.py,sha256=lp66dWYkiMMlgdYlMjAsJnkIZFWLizKB-xwyVnKgFBs,25904
|
25
|
+
spacr/umap.py,sha256=4QSrQ16Og-Ijq-SwguMQT2f20UWz1LE5HQeSLmzSl8c,29370
|
26
|
+
spacr/utils.py,sha256=r0BjycEXqr48m0xEW9o14bLPj62JNpmEUBaBwR817Bw,121205
|
27
|
+
spacr/version.py,sha256=axH5tnGwtgSnJHb5IDhiu4Zjk5GhLyAEDRe-rnaoFOA,409
|
28
|
+
spacr-0.0.21.dist-info/LICENSE,sha256=SR-2MeGc6SCM1UORJYyarSWY_A-JaOMFDj7ReSs9tRM,1083
|
29
|
+
spacr-0.0.21.dist-info/METADATA,sha256=u93xu8dbff0h_8p25GP5yi49sUOh7GS-axPRbik4qRk,4973
|
30
|
+
spacr-0.0.21.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
|
31
|
+
spacr-0.0.21.dist-info/entry_points.txt,sha256=5uyJaAxWCbjWYwP15InAKU1yFxTwyuvCGtIGceso1es,290
|
32
|
+
spacr-0.0.21.dist-info/top_level.txt,sha256=GJPU8FgwRXGzKeut6JopsSRY2R8T3i9lDgya42tLInY,6
|
33
|
+
spacr-0.0.21.dist-info/RECORD,,
|
spacr-0.0.20.dist-info/RECORD
DELETED
@@ -1,31 +0,0 @@
|
|
1
|
-
spacr/__init__.py,sha256=mDi-Qu5r1vZnqIbUBV1JAoSq-mxmMEOmni1JSG2e4Wo,879
|
2
|
-
spacr/__main__.py,sha256=_qRkhbFrH_cXr7AZs6KHL8Hh4VApqNdpNCtiKn2ePTo,285
|
3
|
-
spacr/alpha.py,sha256=Q1vnqO0hvU1G7QP26amFwJY2RjZ68zIc3jYoqQSBMrw,462
|
4
|
-
spacr/annotate_app.py,sha256=IPgZfS4TrSqbJr81P1FWUNOgCPPcS6EdQjUsXRwY-4E,19932
|
5
|
-
spacr/cli.py,sha256=507jfOOEV8BoL4eeUcblvH-iiDHdBrEVJLu1ghAAPSc,1800
|
6
|
-
spacr/core.py,sha256=VTk81S80PKq5pGeUEdduubjbsAyPWKDyRu11EWDO2ms,120307
|
7
|
-
spacr/graph_learning.py,sha256=sD4eOC7Q16rr7WO20mCi_E16_LqioGUUgPamAHIIeNI,12568
|
8
|
-
spacr/graph_learning_lap.py,sha256=MyNRLb63gsjBlui-ByZ0anHugYulL6M-OsGm8rnGBmE,3385
|
9
|
-
spacr/gui_classify_app.py,sha256=-I06tVoA3U0jaAoTs32H1Y5ACMz6QBaEM1NEfg5w-9c,7965
|
10
|
-
spacr/gui_mask_app.py,sha256=xpQ_kh-8lTb9xnyKNrZGtf9JnKqutBFu2-LdYzQShh0,8079
|
11
|
-
spacr/gui_measure_app.py,sha256=9mAw3Tiuq61uKTzMVslr0MgD8m1Lv5PNI0K4-gQiuXE,8061
|
12
|
-
spacr/gui_sim_app.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
13
|
-
spacr/gui_utils.py,sha256=dpbrsDVebjm8ZmkYYOmIZCbkQYyc6JyMIPA7C0r4Xxw,29631
|
14
|
-
spacr/io.py,sha256=bCe40kli7jx7hKQyDNuZYCpZ44-Brj4TZVl9r02cFHo,102042
|
15
|
-
spacr/logger.py,sha256=7Zqr3TuuOQLWT32gYr2q1qvv7x0a2JhLANmZcnBXAW8,670
|
16
|
-
spacr/mask_app.py,sha256=B6-zYXVFg-cc58gLcz-Ry6LClO2jxLitL6B2ACb0HTw,39278
|
17
|
-
spacr/measure.py,sha256=LvF6D-TydhXPwKkeAKuIvzHn14qVlqsn6h7ENnoKn5s,50795
|
18
|
-
spacr/old_code.py,sha256=KxljHpKNsV5EfX9ifN2xJTnUeqAhyabZyfDWd5THOOc,11226
|
19
|
-
spacr/plot.py,sha256=qSM0NzVQYqMYioRc_BPCkUipLRZH7_GaHKGpLXGL8oI,55040
|
20
|
-
spacr/sim.py,sha256=tl40lgTMeeJSyBq_c-Rn54C9Ri0FJ2zLkLLLPLSjz3o,51534
|
21
|
-
spacr/timelapse.py,sha256=wAEMv7oPyusLph3RPmF4F6UGmLfMZmrupYfsuaeJ9vI,34003
|
22
|
-
spacr/train.py,sha256=r77zLvLFMzx6MJxXG3JjynD8qTWYM9pNgrChEXYQhtY,25631
|
23
|
-
spacr/umap.py,sha256=4QSrQ16Og-Ijq-SwguMQT2f20UWz1LE5HQeSLmzSl8c,29370
|
24
|
-
spacr/utils.py,sha256=kEQxucklUdogxjOSQxKdA1R_NU5qYj6dJPiQKIB8un4,120992
|
25
|
-
spacr/version.py,sha256=axH5tnGwtgSnJHb5IDhiu4Zjk5GhLyAEDRe-rnaoFOA,409
|
26
|
-
spacr-0.0.20.dist-info/LICENSE,sha256=SR-2MeGc6SCM1UORJYyarSWY_A-JaOMFDj7ReSs9tRM,1083
|
27
|
-
spacr-0.0.20.dist-info/METADATA,sha256=S8bOdo2ZqIj-ScjQRBxQnBQNXXFV-HKfidMs7MRIsIc,4381
|
28
|
-
spacr-0.0.20.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
|
29
|
-
spacr-0.0.20.dist-info/entry_points.txt,sha256=5uyJaAxWCbjWYwP15InAKU1yFxTwyuvCGtIGceso1es,290
|
30
|
-
spacr-0.0.20.dist-info/top_level.txt,sha256=GJPU8FgwRXGzKeut6JopsSRY2R8T3i9lDgya42tLInY,6
|
31
|
-
spacr-0.0.20.dist-info/RECORD,,
|
File without changes
|