spacr 0.0.20__py3-none-any.whl → 0.0.35__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- spacr/alpha.py +291 -14
- spacr/annotate_app.py +7 -5
- spacr/chris.py +50 -0
- spacr/core.py +1301 -426
- spacr/foldseek.py +793 -0
- spacr/get_alfafold_structures.py +72 -0
- spacr/gui.py +144 -0
- spacr/gui_classify_app.py +65 -74
- spacr/gui_mask_app.py +110 -87
- spacr/gui_measure_app.py +104 -81
- spacr/gui_utils.py +276 -31
- spacr/io.py +261 -102
- spacr/mask_app.py +6 -3
- spacr/measure.py +150 -64
- spacr/plot.py +151 -12
- spacr/sim.py +666 -119
- spacr/timelapse.py +139 -9
- spacr/train.py +18 -10
- spacr/utils.py +43 -49
- {spacr-0.0.20.dist-info → spacr-0.0.35.dist-info}/METADATA +5 -2
- spacr-0.0.35.dist-info/RECORD +35 -0
- spacr-0.0.35.dist-info/entry_points.txt +8 -0
- spacr-0.0.20.dist-info/RECORD +0 -31
- spacr-0.0.20.dist-info/entry_points.txt +0 -7
- {spacr-0.0.20.dist-info → spacr-0.0.35.dist-info}/LICENSE +0 -0
- {spacr-0.0.20.dist-info → spacr-0.0.35.dist-info}/WHEEL +0 -0
- {spacr-0.0.20.dist-info → spacr-0.0.35.dist-info}/top_level.txt +0 -0
spacr/timelapse.py
CHANGED
@@ -11,6 +11,7 @@ from btrack import datasets as btrack_datasets
|
|
11
11
|
from skimage.measure import regionprops
|
12
12
|
from scipy.signal import find_peaks
|
13
13
|
from scipy.optimize import curve_fit
|
14
|
+
from scipy.integrate import trapz
|
14
15
|
import matplotlib.pyplot as plt
|
15
16
|
|
16
17
|
from .logger import log_function_call
|
@@ -590,6 +591,80 @@ def infected_vs_noninfected(result_df, measurement):
|
|
590
591
|
plt.tight_layout()
|
591
592
|
plt.show()
|
592
593
|
|
594
|
+
def save_figure(fig, src, figure_number):
|
595
|
+
source = os.path.dirname(src)
|
596
|
+
results_fldr = os.path.join(source,'results')
|
597
|
+
os.makedirs(results_fldr, exist_ok=True)
|
598
|
+
fig_loc = os.path.join(results_fldr, f'figure_{figure_number}.pdf')
|
599
|
+
fig.savefig(fig_loc)
|
600
|
+
print(f'Saved figure:{fig_loc}')
|
601
|
+
|
602
|
+
def save_results_dataframe(df, src, results_name):
|
603
|
+
source = os.path.dirname(src)
|
604
|
+
results_fldr = os.path.join(source,'results')
|
605
|
+
os.makedirs(results_fldr, exist_ok=True)
|
606
|
+
csv_loc = os.path.join(results_fldr, f'{results_name}.csv')
|
607
|
+
df.to_csv(csv_loc, index=True)
|
608
|
+
print(f'Saved results:{csv_loc}')
|
609
|
+
|
610
|
+
def summarize_per_well(peak_details_df):
|
611
|
+
# Step 1: Split the 'ID' column
|
612
|
+
split_columns = peak_details_df['ID'].str.split('_', expand=True)
|
613
|
+
peak_details_df[['plate', 'row', 'column', 'field', 'object_number']] = split_columns
|
614
|
+
|
615
|
+
# Step 2: Create 'well_ID' by combining 'row' and 'column'
|
616
|
+
peak_details_df['well_ID'] = peak_details_df['row'] + '_' + peak_details_df['column']
|
617
|
+
|
618
|
+
# Filter entries where 'amplitude' is not null
|
619
|
+
filtered_df = peak_details_df[peak_details_df['amplitude'].notna()]
|
620
|
+
|
621
|
+
# Preparation for Step 3: Identify numeric columns for averaging from the filtered dataframe
|
622
|
+
numeric_cols = filtered_df.select_dtypes(include=['number']).columns
|
623
|
+
|
624
|
+
# Step 3: Calculate summary statistics
|
625
|
+
summary_df = filtered_df.groupby('well_ID').agg(
|
626
|
+
peaks_per_well=('ID', 'size'),
|
627
|
+
unique_IDs_with_amplitude=('ID', 'nunique'), # Count unique IDs per well with non-null amplitude
|
628
|
+
**{col: (col, 'mean') for col in numeric_cols} # exclude 'amplitude' from averaging if it's numeric
|
629
|
+
).reset_index()
|
630
|
+
|
631
|
+
# Step 3: Calculate summary statistics
|
632
|
+
summary_df_2 = peak_details_df.groupby('well_ID').agg(
|
633
|
+
cells_per_well=('object_number', 'nunique'),
|
634
|
+
).reset_index()
|
635
|
+
|
636
|
+
summary_df['cells_per_well'] = summary_df_2['cells_per_well']
|
637
|
+
summary_df['peaks_per_cell'] = summary_df['peaks_per_well'] / summary_df['cells_per_well']
|
638
|
+
|
639
|
+
return summary_df
|
640
|
+
|
641
|
+
def summarize_per_well_inf_non_inf(peak_details_df):
|
642
|
+
# Step 1: Split the 'ID' column
|
643
|
+
split_columns = peak_details_df['ID'].str.split('_', expand=True)
|
644
|
+
peak_details_df[['plate', 'row', 'column', 'field', 'object_number']] = split_columns
|
645
|
+
|
646
|
+
# Step 2: Create 'well_ID' by combining 'row' and 'column'
|
647
|
+
peak_details_df['well_ID'] = peak_details_df['row'] + '_' + peak_details_df['column']
|
648
|
+
|
649
|
+
# Assume 'pathogen_count' indicates infection if > 0
|
650
|
+
# Add an 'infected_status' column to classify cells
|
651
|
+
peak_details_df['infected_status'] = peak_details_df['infected'].apply(lambda x: 'infected' if x > 0 else 'non_infected')
|
652
|
+
|
653
|
+
# Preparation for Step 3: Identify numeric columns for averaging
|
654
|
+
numeric_cols = peak_details_df.select_dtypes(include=['number']).columns
|
655
|
+
|
656
|
+
# Step 3: Calculate summary statistics
|
657
|
+
summary_df = peak_details_df.groupby(['well_ID', 'infected_status']).agg(
|
658
|
+
cells_per_well=('object_number', 'nunique'),
|
659
|
+
peaks_per_well=('ID', 'size'),
|
660
|
+
**{col: (col, 'mean') for col in numeric_cols}
|
661
|
+
).reset_index()
|
662
|
+
|
663
|
+
# Calculate peaks per cell
|
664
|
+
summary_df['peaks_per_cell'] = summary_df['peaks_per_well'] / summary_df['cells_per_well']
|
665
|
+
|
666
|
+
return summary_df
|
667
|
+
|
593
668
|
def analyze_calcium_oscillations(db_loc, measurement='cell_channel_1_mean_intensity', size_filter='cell_area', fluctuation_threshold=0.25, num_lines=None, peak_height=0.01, pathogen=None, cytoplasm=None, remove_transient=True, verbose=False, transience_threshold=0.9):
|
594
669
|
# Load data
|
595
670
|
conn = sqlite3.connect(db_loc)
|
@@ -626,7 +701,7 @@ def analyze_calcium_oscillations(db_loc, measurement='cell_channel_1_mean_intens
|
|
626
701
|
cell_df['plate_row_column_field_object'] = cell_df['plate'].astype(str) + '_' + cell_df['row'].astype(str) + '_' + cell_df['column'].astype(str) + '_' + cell_df['field'].astype(str) + '_' + cell_df['object_label'].astype(str)
|
627
702
|
|
628
703
|
df = cell_df.copy()
|
629
|
-
|
704
|
+
|
630
705
|
# Fit exponential decay model to all scaled fluorescence data
|
631
706
|
try:
|
632
707
|
params, _ = curve_fit(exponential_decay, df['time'], df[measurement], p0=[max(df[measurement]), 0.01, min(df[measurement])], maxfev=10000)
|
@@ -653,11 +728,14 @@ def analyze_calcium_oscillations(db_loc, measurement='cell_channel_1_mean_intens
|
|
653
728
|
if verbose:
|
654
729
|
print(f'Group length: {len(group)} Timelapse length: {total_timepoints}, threshold:{threshold}')
|
655
730
|
|
656
|
-
if
|
731
|
+
if len(group) <= threshold:
|
657
732
|
transience_removed += 1
|
733
|
+
if verbose:
|
734
|
+
print(f'removed group {unique_id} due to transience')
|
658
735
|
continue
|
659
736
|
|
660
737
|
size_diff = group[size_filter].std() / group[size_filter].mean()
|
738
|
+
|
661
739
|
if size_diff <= fluctuation_threshold:
|
662
740
|
group['delta_' + measurement] = group['corrected_' + measurement].diff().fillna(0)
|
663
741
|
corrected_dfs.append(group)
|
@@ -665,12 +743,50 @@ def analyze_calcium_oscillations(db_loc, measurement='cell_channel_1_mean_intens
|
|
665
743
|
# Detect peaks
|
666
744
|
peaks, properties = find_peaks(group['delta_' + measurement], height=peak_height)
|
667
745
|
|
746
|
+
# Set values < 0 to 0
|
747
|
+
group_filtered = group.copy()
|
748
|
+
group_filtered['delta_' + measurement] = group['delta_' + measurement].clip(lower=0)
|
749
|
+
above_zero_auc = trapz(y=group_filtered['delta_' + measurement], x=group_filtered['time'])
|
750
|
+
auc = trapz(y=group['delta_' + measurement], x=group_filtered['time'])
|
751
|
+
is_infected = (group['parasite_count'] > 0).any()
|
752
|
+
|
753
|
+
if is_infected:
|
754
|
+
is_infected = 1
|
755
|
+
else:
|
756
|
+
is_infected = 0
|
757
|
+
|
758
|
+
if len(peaks) == 0:
|
759
|
+
peak_details_list.append({
|
760
|
+
'ID': unique_id,
|
761
|
+
'plate': group['plate'].iloc[0],
|
762
|
+
'row': group['row'].iloc[0],
|
763
|
+
'column': group['column'].iloc[0],
|
764
|
+
'field': group['field'].iloc[0],
|
765
|
+
'object_number': group['object_number'].iloc[0],
|
766
|
+
'time': np.nan, # The time of the peak
|
767
|
+
'amplitude': np.nan,
|
768
|
+
'delta': np.nan,
|
769
|
+
'AUC': auc,
|
770
|
+
'AUC_positive': above_zero_auc,
|
771
|
+
'AUC_peak': np.nan,
|
772
|
+
'infected': is_infected
|
773
|
+
})
|
774
|
+
|
668
775
|
# Inside the for loop where peaks are detected
|
669
776
|
for i, peak in enumerate(peaks):
|
670
|
-
|
671
|
-
|
672
|
-
|
777
|
+
|
778
|
+
amplitude = properties['peak_heights'][i]
|
779
|
+
peak_time = group['time'].iloc[peak]
|
673
780
|
pathogen_count_at_peak = group['parasite_count'].iloc[peak]
|
781
|
+
|
782
|
+
start_idx = max(peak - 1, 0)
|
783
|
+
end_idx = min(peak + 1, len(group) - 1)
|
784
|
+
|
785
|
+
# Using indices to slice for AUC calculation
|
786
|
+
peak_segment_y = group['delta_' + measurement].iloc[start_idx:end_idx + 1]
|
787
|
+
peak_segment_x = group['time'].iloc[start_idx:end_idx + 1]
|
788
|
+
peak_auc = trapz(y=peak_segment_y, x=peak_segment_x)
|
789
|
+
|
674
790
|
peak_details_list.append({
|
675
791
|
'ID': unique_id,
|
676
792
|
'plate': group['plate'].iloc[0],
|
@@ -681,6 +797,9 @@ def analyze_calcium_oscillations(db_loc, measurement='cell_channel_1_mean_intens
|
|
681
797
|
'time': peak_time, # The time of the peak
|
682
798
|
'amplitude': amplitude,
|
683
799
|
'delta': group['delta_' + measurement].iloc[peak],
|
800
|
+
'AUC': auc,
|
801
|
+
'AUC_positive': above_zero_auc,
|
802
|
+
'AUC_peak': peak_auc,
|
684
803
|
'infected': pathogen_count_at_peak
|
685
804
|
})
|
686
805
|
else:
|
@@ -697,7 +816,14 @@ def analyze_calcium_oscillations(db_loc, measurement='cell_channel_1_mean_intens
|
|
697
816
|
return
|
698
817
|
|
699
818
|
peak_details_df = pd.DataFrame(peak_details_list)
|
700
|
-
|
819
|
+
summary_df = summarize_per_well(peak_details_df)
|
820
|
+
summary_df_inf_non_inf = summarize_per_well_inf_non_inf(peak_details_df)
|
821
|
+
|
822
|
+
save_results_dataframe(df=peak_details_df, src=db_loc, results_name='peak_details')
|
823
|
+
save_results_dataframe(df=result_df, src=db_loc, results_name='results')
|
824
|
+
save_results_dataframe(df=summary_df, src=db_loc, results_name='well_results')
|
825
|
+
save_results_dataframe(df=summary_df_inf_non_inf, src=db_loc, results_name='well_results_inf_non_inf')
|
826
|
+
|
701
827
|
# Plotting
|
702
828
|
fig, ax = plt.subplots(figsize=(10, 8))
|
703
829
|
sampled_groups = result_df['plate_row_column_field_object'].unique()
|
@@ -714,12 +840,16 @@ def analyze_calcium_oscillations(db_loc, measurement='cell_channel_1_mean_intens
|
|
714
840
|
ax.set_xlabel('Time')
|
715
841
|
ax.set_ylabel('Normalized Delta ' + measurement)
|
716
842
|
plt.tight_layout()
|
843
|
+
|
717
844
|
plt.show()
|
845
|
+
|
846
|
+
save_figure(fig, src=db_loc, figure_number=1)
|
718
847
|
|
719
848
|
if pathogen:
|
720
849
|
infected_vs_noninfected(result_df, measurement)
|
850
|
+
save_figure(fig, src=db_loc, figure_number=2)
|
721
851
|
|
722
|
-
#
|
852
|
+
# Identify cells with and without pathogens
|
723
853
|
infected_cells = result_df[result_df.groupby('plate_row_column_field_object')['parasite_count'].transform('max') > 0]['plate_row_column_field_object'].unique()
|
724
854
|
noninfected_cells = result_df[result_df.groupby('plate_row_column_field_object')['parasite_count'].transform('max') == 0]['plate_row_column_field_object'].unique()
|
725
855
|
|
@@ -733,5 +863,5 @@ def analyze_calcium_oscillations(db_loc, measurement='cell_channel_1_mean_intens
|
|
733
863
|
|
734
864
|
print(f'Average number of peaks per infected cell: {avg_inf_peaks_per_cell:.2f}')
|
735
865
|
print(f'Average number of peaks per non-infected cell: {avg_non_inf_peaks_per_cell:.2f}')
|
736
|
-
|
737
|
-
return result_df, peak_details_df
|
866
|
+
print(f'done')
|
867
|
+
return result_df, peak_details_df, fig
|
spacr/train.py
CHANGED
@@ -194,8 +194,8 @@ def test_model_performance(loaders, model, loader_name_list, epoch, train_mode,
|
|
194
194
|
|
195
195
|
def train_test_model(src, settings, custom_model=False, custom_model_path=None):
|
196
196
|
|
197
|
-
from .io import
|
198
|
-
from .utils import pick_best_model
|
197
|
+
from .io import _save_settings, _copy_missclassified
|
198
|
+
from .utils import pick_best_model
|
199
199
|
from .core import generate_loaders
|
200
200
|
|
201
201
|
settings['src'] = src
|
@@ -208,7 +208,7 @@ def train_test_model(src, settings, custom_model=False, custom_model_path=None):
|
|
208
208
|
model = torch.load(custom_model_path)
|
209
209
|
|
210
210
|
if settings['train']:
|
211
|
-
|
211
|
+
_save_settings(settings, src)
|
212
212
|
torch.cuda.empty_cache()
|
213
213
|
torch.cuda.memory.empty_cache()
|
214
214
|
gc.collect()
|
@@ -227,7 +227,9 @@ def train_test_model(src, settings, custom_model=False, custom_model_path=None):
|
|
227
227
|
validation_split=settings['val_split'],
|
228
228
|
pin_memory=settings['pin_memory'],
|
229
229
|
normalize=settings['normalize'],
|
230
|
-
|
230
|
+
channels=settings['channels'],
|
231
|
+
verbose=settings['verbose'])
|
232
|
+
|
231
233
|
|
232
234
|
if settings['test']:
|
233
235
|
test, _, plate_names_test = generate_loaders(src,
|
@@ -240,6 +242,7 @@ def train_test_model(src, settings, custom_model=False, custom_model_path=None):
|
|
240
242
|
validation_split=0.0,
|
241
243
|
pin_memory=settings['pin_memory'],
|
242
244
|
normalize=settings['normalize'],
|
245
|
+
channels=settings['channels'],
|
243
246
|
verbose=settings['verbose'])
|
244
247
|
if model == None:
|
245
248
|
model_path = pick_best_model(src+'/model')
|
@@ -330,8 +333,8 @@ def train_model(dst, model_type, train_loaders, train_loader_names, train_mode='
|
|
330
333
|
None
|
331
334
|
"""
|
332
335
|
|
333
|
-
from .io import
|
334
|
-
from .utils import
|
336
|
+
from .io import _save_model, _save_progress
|
337
|
+
from .utils import compute_irm_penalty, calculate_loss, choose_model #evaluate_model_performance,
|
335
338
|
|
336
339
|
print(f'Train batches:{len(train_loaders)}, Validation batches:{len(val_loaders)}')
|
337
340
|
|
@@ -347,6 +350,11 @@ def train_model(dst, model_type, train_loaders, train_loader_names, train_mode='
|
|
347
350
|
break
|
348
351
|
|
349
352
|
model = choose_model(model_type, device, init_weights, dropout_rate, use_checkpoint)
|
353
|
+
|
354
|
+
if model is None:
|
355
|
+
print(f'Model {model_type} not found')
|
356
|
+
return
|
357
|
+
|
350
358
|
model.to(device)
|
351
359
|
|
352
360
|
if optimizer_type == 'adamw':
|
@@ -421,10 +429,10 @@ def train_model(dst, model_type, train_loaders, train_loader_names, train_mode='
|
|
421
429
|
if schedule == 'step_lr':
|
422
430
|
scheduler.step()
|
423
431
|
|
424
|
-
|
432
|
+
_save_progress(dst, results_df, train_metrics_df)
|
425
433
|
clear_output(wait=True)
|
426
434
|
display(results_df)
|
427
|
-
|
435
|
+
_save_model(model, model_type, results_df, dst, epoch, epochs, intermedeate_save=[0.99,0.98,0.95,0.94])
|
428
436
|
|
429
437
|
if train_mode == 'irm':
|
430
438
|
dummy_w = torch.nn.Parameter(torch.Tensor([1.0])).to(device)
|
@@ -494,7 +502,7 @@ def train_model(dst, model_type, train_loaders, train_loader_names, train_mode='
|
|
494
502
|
|
495
503
|
clear_output(wait=True)
|
496
504
|
display(results_df)
|
497
|
-
|
498
|
-
|
505
|
+
_save_progress(dst, results_df, train_metrics_df)
|
506
|
+
_save_model(model, model_type, results_df, dst, epoch, epochs, intermedeate_save=[0.99,0.98,0.95,0.94])
|
499
507
|
print(f'Saved model: {dst}')
|
500
508
|
return
|
spacr/utils.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1
|
-
import os, re, sqlite3, gc, torch, torchvision, time, random, string, shutil, cv2, tarfile, glob
|
1
|
+
import sys, os, re, sqlite3, gc, torch, torchvision, time, random, string, shutil, cv2, tarfile, glob
|
2
2
|
|
3
3
|
import numpy as np
|
4
4
|
from cellpose import models as cp_models
|
@@ -46,12 +46,6 @@ from torchvision.models.resnet import ResNet18_Weights, ResNet34_Weights, ResNet
|
|
46
46
|
|
47
47
|
from .logger import log_function_call
|
48
48
|
|
49
|
-
#from .io import _read_and_join_tables, _save_figure
|
50
|
-
#from .timelapse import _btrack_track_cells, _trackpy_track_cells
|
51
|
-
#from .plot import _plot_images_on_grid, plot_masks, _plot_histograms_and_stats, plot_resize, _plot_plates, _reg_v_plot, plot_masks
|
52
|
-
#from .core import identify_masks
|
53
|
-
|
54
|
-
|
55
49
|
def _gen_rgb_image(image, cahnnels):
|
56
50
|
rgb_image = np.take(image, cahnnels, axis=-1)
|
57
51
|
rgb_image = rgb_image.astype(float)
|
@@ -71,7 +65,8 @@ def _outline_and_overlay(image, rgb_image, mask_dims, outline_colors, outline_th
|
|
71
65
|
outline = np.zeros_like(mask, dtype=np.uint8) # Use uint8 for contour detection efficiency
|
72
66
|
|
73
67
|
# Find and draw contours
|
74
|
-
for j in np.unique(mask)
|
68
|
+
for j in np.unique(mask):
|
69
|
+
#for j in np.unique(mask)[1:]:
|
75
70
|
contours = find_contours(mask == j, 0.5)
|
76
71
|
# Convert contours for OpenCV format and draw directly to optimize
|
77
72
|
cv_contours = [np.flip(contour.astype(int), axis=1) for contour in contours]
|
@@ -781,8 +776,8 @@ def _get_object_settings(object_type, settings):
|
|
781
776
|
object_settings = {}
|
782
777
|
|
783
778
|
object_settings['diameter'] = _get_diam(settings['magnification'], obj=object_type)
|
784
|
-
object_settings['minimum_size'] = (object_settings['diameter']**2)/
|
785
|
-
object_settings['maximum_size'] = (object_settings['diameter']**2)*
|
779
|
+
object_settings['minimum_size'] = (object_settings['diameter']**2)/4
|
780
|
+
object_settings['maximum_size'] = (object_settings['diameter']**2)*10
|
786
781
|
object_settings['merge'] = False
|
787
782
|
object_settings['resample'] = True
|
788
783
|
object_settings['remove_border_objects'] = False
|
@@ -793,21 +788,22 @@ def _get_object_settings(object_type, settings):
|
|
793
788
|
object_settings['model_name'] = 'cyto'
|
794
789
|
else:
|
795
790
|
object_settings['model_name'] = 'cyto2'
|
796
|
-
object_settings['filter_size'] =
|
797
|
-
object_settings['filter_intensity'] =
|
791
|
+
object_settings['filter_size'] = False
|
792
|
+
object_settings['filter_intensity'] = False
|
798
793
|
object_settings['restore_type'] = settings.get('cell_restore_type', None)
|
799
794
|
|
800
795
|
elif object_type == 'nucleus':
|
801
796
|
object_settings['model_name'] = 'nuclei'
|
802
|
-
object_settings['filter_size'] =
|
803
|
-
object_settings['filter_intensity'] =
|
797
|
+
object_settings['filter_size'] = False
|
798
|
+
object_settings['filter_intensity'] = False
|
804
799
|
object_settings['restore_type'] = settings.get('nucleus_restore_type', None)
|
805
800
|
|
806
801
|
elif object_type == 'pathogen':
|
807
802
|
object_settings['model_name'] = 'cyto'
|
808
803
|
object_settings['filter_size'] = True
|
809
|
-
object_settings['filter_intensity'] =
|
804
|
+
object_settings['filter_intensity'] = False
|
810
805
|
object_settings['restore_type'] = settings.get('pathogen_restore_type', None)
|
806
|
+
object_settings['merge'] = settings['merge_pathogens']
|
811
807
|
|
812
808
|
else:
|
813
809
|
print(f'Object type: {object_type} not supported. Supported object types are : cell, nucleus and pathogen')
|
@@ -884,17 +880,15 @@ def _get_cellpose_channels(src, nucleus_channel, pathogen_channel, cell_channel)
|
|
884
880
|
|
885
881
|
if not pathogen_channel is None:
|
886
882
|
if not nucleus_channel is None:
|
887
|
-
|
883
|
+
if not pathogen_channel is None:
|
884
|
+
cellpose_channels['pathogen'] = [0,2]
|
885
|
+
else:
|
886
|
+
cellpose_channels['pathogen'] = [0,1]
|
888
887
|
else:
|
889
888
|
cellpose_channels['pathogen'] = [0,0]
|
890
889
|
|
891
890
|
if not cell_channel is None:
|
892
891
|
if not nucleus_channel is None:
|
893
|
-
if not pathogen_channel is None:
|
894
|
-
cellpose_channels['cell'] = [0,2]
|
895
|
-
else:
|
896
|
-
cellpose_channels['cell'] = [0,1]
|
897
|
-
elif not pathogen_channel is None:
|
898
892
|
cellpose_channels['cell'] = [0,1]
|
899
893
|
else:
|
900
894
|
cellpose_channels['cell'] = [0,0]
|
@@ -1069,7 +1063,7 @@ class Cache:
|
|
1069
1063
|
cache (OrderedDict): The cache data structure.
|
1070
1064
|
"""
|
1071
1065
|
|
1072
|
-
def
|
1066
|
+
def __init__(self, max_size):
|
1073
1067
|
self.cache = OrderedDict()
|
1074
1068
|
self.max_size = max_size
|
1075
1069
|
|
@@ -1100,7 +1094,7 @@ class ScaledDotProductAttention(nn.Module):
|
|
1100
1094
|
|
1101
1095
|
"""
|
1102
1096
|
|
1103
|
-
def
|
1097
|
+
def __init__(self, d_k):
|
1104
1098
|
super(ScaledDotProductAttention, self).__init__()
|
1105
1099
|
self.d_k = d_k
|
1106
1100
|
|
@@ -1131,7 +1125,7 @@ class SelfAttention(nn.Module):
|
|
1131
1125
|
d_k (int): Dimensionality of the key and query vectors.
|
1132
1126
|
"""
|
1133
1127
|
|
1134
|
-
def
|
1128
|
+
def __init__(self, in_channels, d_k):
|
1135
1129
|
super(SelfAttention, self).__init__()
|
1136
1130
|
self.W_q = nn.Linear(in_channels, d_k)
|
1137
1131
|
self.W_k = nn.Linear(in_channels, d_k)
|
@@ -1155,7 +1149,7 @@ class SelfAttention(nn.Module):
|
|
1155
1149
|
return output
|
1156
1150
|
|
1157
1151
|
class ScaledDotProductAttention(nn.Module):
|
1158
|
-
def
|
1152
|
+
def __init__(self, d_k):
|
1159
1153
|
"""
|
1160
1154
|
Initializes the ScaledDotProductAttention module.
|
1161
1155
|
|
@@ -1192,7 +1186,7 @@ class SelfAttention(nn.Module):
|
|
1192
1186
|
in_channels (int): Number of input channels.
|
1193
1187
|
d_k (int): Dimensionality of the key and query vectors.
|
1194
1188
|
"""
|
1195
|
-
def
|
1189
|
+
def __init__(self, in_channels, d_k):
|
1196
1190
|
super(SelfAttention, self).__init__()
|
1197
1191
|
self.W_q = nn.Linear(in_channels, d_k)
|
1198
1192
|
self.W_k = nn.Linear(in_channels, d_k)
|
@@ -1223,7 +1217,7 @@ class EarlyFusion(nn.Module):
|
|
1223
1217
|
Args:
|
1224
1218
|
in_channels (int): Number of input channels.
|
1225
1219
|
"""
|
1226
|
-
def
|
1220
|
+
def __init__(self, in_channels):
|
1227
1221
|
super(EarlyFusion, self).__init__()
|
1228
1222
|
self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=1, stride=1)
|
1229
1223
|
|
@@ -1242,7 +1236,7 @@ class EarlyFusion(nn.Module):
|
|
1242
1236
|
|
1243
1237
|
# Spatial Attention Mechanism
|
1244
1238
|
class SpatialAttention(nn.Module):
|
1245
|
-
def
|
1239
|
+
def __init__(self, kernel_size=7):
|
1246
1240
|
"""
|
1247
1241
|
Initializes the SpatialAttention module.
|
1248
1242
|
|
@@ -1287,7 +1281,7 @@ class MultiScaleBlockWithAttention(nn.Module):
|
|
1287
1281
|
forward: Forward method for the module.
|
1288
1282
|
"""
|
1289
1283
|
|
1290
|
-
def
|
1284
|
+
def __init__(self, in_channels, out_channels):
|
1291
1285
|
super(MultiScaleBlockWithAttention, self).__init__()
|
1292
1286
|
self.dilated_conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, dilation=1, padding=1)
|
1293
1287
|
self.spatial_attention = nn.Conv2d(out_channels, out_channels, kernel_size=1)
|
@@ -1320,7 +1314,7 @@ class MultiScaleBlockWithAttention(nn.Module):
|
|
1320
1314
|
|
1321
1315
|
# Final Classifier
|
1322
1316
|
class CustomCellClassifier(nn.Module):
|
1323
|
-
def
|
1317
|
+
def __init__(self, num_classes, pathogen_channel, use_attention, use_checkpoint, dropout_rate):
|
1324
1318
|
super(CustomCellClassifier, self).__init__()
|
1325
1319
|
self.early_fusion = EarlyFusion(in_channels=3)
|
1326
1320
|
|
@@ -1349,7 +1343,7 @@ class CustomCellClassifier(nn.Module):
|
|
1349
1343
|
|
1350
1344
|
#CNN and Transformer class, pick any Torch model.
|
1351
1345
|
class TorchModel(nn.Module):
|
1352
|
-
def
|
1346
|
+
def __init__(self, model_name='resnet50', pretrained=True, dropout_rate=None, use_checkpoint=False):
|
1353
1347
|
super(TorchModel, self).__init__()
|
1354
1348
|
self.model_name = model_name
|
1355
1349
|
self.use_checkpoint = use_checkpoint
|
@@ -1423,7 +1417,7 @@ class TorchModel(nn.Module):
|
|
1423
1417
|
return logits
|
1424
1418
|
|
1425
1419
|
class FocalLossWithLogits(nn.Module):
|
1426
|
-
def
|
1420
|
+
def __init__(self, alpha=1, gamma=2):
|
1427
1421
|
super(FocalLossWithLogits, self).__init__()
|
1428
1422
|
self.alpha = alpha
|
1429
1423
|
self.gamma = gamma
|
@@ -1435,7 +1429,7 @@ class FocalLossWithLogits(nn.Module):
|
|
1435
1429
|
return focal_loss.mean()
|
1436
1430
|
|
1437
1431
|
class ResNet(nn.Module):
|
1438
|
-
def
|
1432
|
+
def __init__(self, resnet_type='resnet50', dropout_rate=None, use_checkpoint=False, init_weights='imagenet'):
|
1439
1433
|
super(ResNet, self).__init__()
|
1440
1434
|
|
1441
1435
|
resnet_map = {
|
@@ -1788,25 +1782,24 @@ def annotate_predictions(csv_loc):
|
|
1788
1782
|
df['cond'] = df.apply(assign_condition, axis=1)
|
1789
1783
|
return df
|
1790
1784
|
|
1791
|
-
def
|
1785
|
+
def initiate_counter(counter_, lock_):
|
1792
1786
|
global counter, lock
|
1793
1787
|
counter = counter_
|
1794
1788
|
lock = lock_
|
1795
1789
|
|
1796
|
-
def add_images_to_tar(
|
1797
|
-
global counter, lock, total_images
|
1798
|
-
paths_chunk, tar_path = args
|
1790
|
+
def add_images_to_tar(paths_chunk, tar_path, total_images):
|
1799
1791
|
with tarfile.open(tar_path, 'w') as tar:
|
1800
|
-
for img_path in paths_chunk:
|
1792
|
+
for i, img_path in enumerate(paths_chunk):
|
1801
1793
|
arcname = os.path.basename(img_path)
|
1802
1794
|
try:
|
1803
1795
|
tar.add(img_path, arcname=arcname)
|
1804
1796
|
with lock:
|
1805
1797
|
counter.value += 1
|
1806
|
-
|
1798
|
+
if counter.value % 100 == 0: # Print every 100 updates
|
1799
|
+
progress = (counter.value / total_images) * 100
|
1800
|
+
print(f"Progress: {counter.value}/{total_images} ({progress:.2f}%)", end='\r', file=sys.stdout, flush=True)
|
1807
1801
|
except FileNotFoundError:
|
1808
1802
|
print(f"File not found: {img_path}")
|
1809
|
-
return tar_path
|
1810
1803
|
|
1811
1804
|
def generate_fraction_map(df, gene_column, min_frequency=0.0):
|
1812
1805
|
df['fraction'] = df['count']/df['well_read_sum']
|
@@ -2255,8 +2248,8 @@ def dice_coefficient(mask1, mask2):
|
|
2255
2248
|
def extract_boundaries(mask, dilation_radius=1):
|
2256
2249
|
binary_mask = (mask > 0).astype(np.uint8)
|
2257
2250
|
struct_elem = np.ones((dilation_radius*2+1, dilation_radius*2+1))
|
2258
|
-
dilated = binary_dilation(binary_mask, footprint=struct_elem)
|
2259
|
-
eroded = binary_erosion(binary_mask, footprint=struct_elem)
|
2251
|
+
dilated = morphology.binary_dilation(binary_mask, footprint=struct_elem)
|
2252
|
+
eroded = morphology.binary_erosion(binary_mask, footprint=struct_elem)
|
2260
2253
|
boundary = dilated ^ eroded
|
2261
2254
|
return boundary
|
2262
2255
|
|
@@ -2669,6 +2662,13 @@ def _filter_cp_masks(masks, flows, filter_size, filter_intensity, minimum_size,
|
|
2669
2662
|
print(f'Number of objects before filtration: {num_objects}')
|
2670
2663
|
plot_masks(batch=image, masks=mask, flows=flow, cmap='inferno', figuresize=figuresize, nr=1, file_type='.npz', print_object_number=True)
|
2671
2664
|
|
2665
|
+
if merge:
|
2666
|
+
mask = merge_touching_objects(mask, threshold=0.66)
|
2667
|
+
if plot and idx == 0:
|
2668
|
+
num_objects = mask_object_count(mask)
|
2669
|
+
print(f'Number of objects after merging adjacent objects, : {num_objects}')
|
2670
|
+
plot_masks(batch=image, masks=mask, flows=flow, cmap='inferno', figuresize=figuresize, nr=1, file_type='.npz', print_object_number=True)
|
2671
|
+
|
2672
2672
|
if filter_size:
|
2673
2673
|
props = measure.regionprops_table(mask, properties=['label', 'area'])
|
2674
2674
|
valid_labels = props['label'][np.logical_and(props['area'] > minimum_size, props['area'] < maximum_size)]
|
@@ -2714,13 +2714,6 @@ def _filter_cp_masks(masks, flows, filter_size, filter_intensity, minimum_size,
|
|
2714
2714
|
print(f'Number of objects after removing border objects, : {num_objects}')
|
2715
2715
|
plot_masks(batch=image, masks=mask, flows=flow, cmap='inferno', figuresize=figuresize, nr=1, file_type='.npz', print_object_number=True)
|
2716
2716
|
|
2717
|
-
if merge:
|
2718
|
-
mask = merge_touching_objects(mask, threshold=0.25)
|
2719
|
-
if plot and idx == 0:
|
2720
|
-
num_objects = mask_object_count(mask)
|
2721
|
-
print(f'Number of objects after merging adjacent objects, : {num_objects}')
|
2722
|
-
plot_masks(batch=image, masks=mask, flows=flow, cmap='inferno', figuresize=figuresize, nr=1, file_type='.npz', print_object_number=True)
|
2723
|
-
|
2724
2717
|
mask_stack.append(mask)
|
2725
2718
|
|
2726
2719
|
return mask_stack
|
@@ -2789,6 +2782,7 @@ def _run_test_mode(src, regex, timelapse=False):
|
|
2789
2782
|
|
2790
2783
|
# Prepare for random selection
|
2791
2784
|
set_identifiers = list(images_by_set.keys())
|
2785
|
+
random.seed(42)
|
2792
2786
|
random.shuffle(set_identifiers) # Randomize the order
|
2793
2787
|
|
2794
2788
|
# Select a subset based on the test_images count
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: spacr
|
3
|
-
Version: 0.0.
|
3
|
+
Version: 0.0.35
|
4
4
|
Summary: Spatial phenotype analysis of crisp screens (SpaCr)
|
5
5
|
Home-page: https://github.com/EinarOlafsson/spacr
|
6
6
|
Author: Einar Birnir Olafsson
|
@@ -19,6 +19,7 @@ Requires-Dist: scikit-image >=0.22.0
|
|
19
19
|
Requires-Dist: scikit-learn >=1.4.1
|
20
20
|
Requires-Dist: seaborn >=0.13.2
|
21
21
|
Requires-Dist: matplotlib >=3.8.3
|
22
|
+
Requires-Dist: shap >=0.45.0
|
22
23
|
Requires-Dist: pillow >=10.2.0
|
23
24
|
Requires-Dist: imageio >=2.34.0
|
24
25
|
Requires-Dist: scipy >=1.12.0
|
@@ -31,6 +32,8 @@ Requires-Dist: IPython >=8.18.1
|
|
31
32
|
Requires-Dist: opencv-python-headless >=4.9.0.80
|
32
33
|
Requires-Dist: umap >=0.1.1
|
33
34
|
Requires-Dist: ttkthemes >=3.2.2
|
35
|
+
Requires-Dist: xgboost >=2.0.3
|
36
|
+
Requires-Dist: PyWavelets >=1.6.0
|
34
37
|
Requires-Dist: lxml >=5.1.0
|
35
38
|
Provides-Extra: dev
|
36
39
|
Requires-Dist: pytest >=3.9 ; extra == 'dev'
|
@@ -49,7 +52,7 @@ Requires-Dist: opencv-python-headless ; extra == 'headless'
|
|
49
52
|
<tr>
|
50
53
|
<td>
|
51
54
|
|
52
|
-
Spatial phenotype analysis of
|
55
|
+
Spatial phenotype analysis of CRISPR-Cas9 screens (SpaCr). The spatial organization of organelles and proteins within cells constitutes a key level of functional regulation. In the context of infectious disease, the spatial relationships between host cell structures and intracellular pathogens are critical to understand host clearance mechanisms and how pathogens evade them. Spacr is a Python-based software package for generating single cell image data for deep-learning sub-cellular/cellular phenotypic classification from pooled genetic CRISPR-Cas9 screens. Spacr provides a flexible toolset to extract single cell images and measurements from high content cell painting experiments, train deep-learning models to classify cellular/ subcellular phenotypes, simulate and analyze pooled CRISPR-Cas9 imaging screens.
|
53
56
|
|
54
57
|
</td>
|
55
58
|
<td>
|
@@ -0,0 +1,35 @@
|
|
1
|
+
spacr/__init__.py,sha256=mDi-Qu5r1vZnqIbUBV1JAoSq-mxmMEOmni1JSG2e4Wo,879
|
2
|
+
spacr/__main__.py,sha256=_qRkhbFrH_cXr7AZs6KHL8Hh4VApqNdpNCtiKn2ePTo,285
|
3
|
+
spacr/alpha.py,sha256=1LUtTaeVHqcTMxoCMA7mlsNYyaR0KEaXglwXsUQLwKo,13776
|
4
|
+
spacr/annotate_app.py,sha256=8ziG6HZ6Kvany2yYDR15jtW84OnPFe8SZXykIrrNfX0,20031
|
5
|
+
spacr/chris.py,sha256=YlBjSgeZaY8HPy6jkrT_ISAnCMAKVfvCxF0I9eAZLFM,2418
|
6
|
+
spacr/cli.py,sha256=507jfOOEV8BoL4eeUcblvH-iiDHdBrEVJLu1ghAAPSc,1800
|
7
|
+
spacr/core.py,sha256=_R8gXNnjf680yrnbCi2piWQUz7PDbqWYn7SL5MACLfo,156457
|
8
|
+
spacr/foldseek.py,sha256=cWtLzvFF2O_mq5I71UMiuU9DTvDCp7wl6aaWAZRrBZc,33970
|
9
|
+
spacr/get_alfafold_structures.py,sha256=n0g8gne-oyAV3Uo6qxZoJq5X1cUUyD8u0pOC_W2PX40,3541
|
10
|
+
spacr/graph_learning.py,sha256=sD4eOC7Q16rr7WO20mCi_E16_LqioGUUgPamAHIIeNI,12568
|
11
|
+
spacr/graph_learning_lap.py,sha256=MyNRLb63gsjBlui-ByZ0anHugYulL6M-OsGm8rnGBmE,3385
|
12
|
+
spacr/gui.py,sha256=2d2JHYVWhEFUkB3u_2OarCV_V07eLtJKUMKVKuJ7nAo,6430
|
13
|
+
spacr/gui_classify_app.py,sha256=RqVC5Ac0GiFhKg1qUyU-xfbSVMwSTjySkReWCvmsZ1U,7917
|
14
|
+
spacr/gui_mask_app.py,sha256=c0NwgzzMSYx7xE60sp6zLE1h7ct3pnT7n7gz1SqmWbs,9750
|
15
|
+
spacr/gui_measure_app.py,sha256=I6OXPJZZ0sG3OqRO4l7Fr9Sb_5jr6X1Y2LXLgPW_Q_4,9675
|
16
|
+
spacr/gui_sim_app.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
17
|
+
spacr/gui_utils.py,sha256=URyEUfDn7RSE8iWBDwmC1jXWNA2Vl1oA2GsSjbSbuFU,40715
|
18
|
+
spacr/io.py,sha256=-Ho1Rw46s5DfutcTAtp0RhEs3I-GPgXIphOpjJHutPU,108688
|
19
|
+
spacr/logger.py,sha256=7Zqr3TuuOQLWT32gYr2q1qvv7x0a2JhLANmZcnBXAW8,670
|
20
|
+
spacr/mask_app.py,sha256=p9oA0JH0Rcly2Fbsrg-Vye_iThRCZZF9axU6hkE3SAI,39376
|
21
|
+
spacr/measure.py,sha256=_f6UDugw75rILEg0uo2-QsUGUvc4AQdrdHl-BPZk74I,54686
|
22
|
+
spacr/old_code.py,sha256=KxljHpKNsV5EfX9ifN2xJTnUeqAhyabZyfDWd5THOOc,11226
|
23
|
+
spacr/plot.py,sha256=VtDKTJ_zo8CAVC3ILuIN_wUP6197vq089wNZuom7T8g,61655
|
24
|
+
spacr/sim.py,sha256=2NR5hm--HVcYQnj1SCHoUCVbh_b2XUjjjfoAUIXFwnQ,72997
|
25
|
+
spacr/timelapse.py,sha256=plPjR8nZ7_Q50VAvMvHK2TUE4F-vh7R23JnI6tSW02g,39661
|
26
|
+
spacr/train.py,sha256=lp66dWYkiMMlgdYlMjAsJnkIZFWLizKB-xwyVnKgFBs,25904
|
27
|
+
spacr/umap.py,sha256=4QSrQ16Og-Ijq-SwguMQT2f20UWz1LE5HQeSLmzSl8c,29370
|
28
|
+
spacr/utils.py,sha256=D3WRf_0w0T6dZHh3BfwScGBQjorljgWW6CQUdM0ToN8,120918
|
29
|
+
spacr/version.py,sha256=axH5tnGwtgSnJHb5IDhiu4Zjk5GhLyAEDRe-rnaoFOA,409
|
30
|
+
spacr-0.0.35.dist-info/LICENSE,sha256=SR-2MeGc6SCM1UORJYyarSWY_A-JaOMFDj7ReSs9tRM,1083
|
31
|
+
spacr-0.0.35.dist-info/METADATA,sha256=0iEqhFIza7SaHVeYbl0Rc8WocPgGEVtiwMoWsIBBZzQ,4973
|
32
|
+
spacr-0.0.35.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
|
33
|
+
spacr-0.0.35.dist-info/entry_points.txt,sha256=_khj_UcegrI5N29QcgxECsxVsfpVQzG3U5OUoKdtTcw,288
|
34
|
+
spacr-0.0.35.dist-info/top_level.txt,sha256=GJPU8FgwRXGzKeut6JopsSRY2R8T3i9lDgya42tLInY,6
|
35
|
+
spacr-0.0.35.dist-info/RECORD,,
|
@@ -0,0 +1,8 @@
|
|
1
|
+
[console_scripts]
|
2
|
+
annotate = spacr.annotate_app:gui_annotation
|
3
|
+
classify = spacr.gui_classify_app:gui_classify
|
4
|
+
gui = spacr.gui:gui_app
|
5
|
+
make_masks = spacr.mask_app:gui_make_masks
|
6
|
+
mask = spacr.gui_mask_app:gui_mask
|
7
|
+
measure = spacr.gui_measure_app:gui_measure
|
8
|
+
sim = spacr.gui_sim_app:gui_sim
|