spacr 0.0.18__py3-none-any.whl → 0.0.21__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- spacr/alpha.py +291 -14
- spacr/annotate_app.py +2 -2
- spacr/core.py +1377 -296
- spacr/foldseek.py +793 -0
- spacr/get_alfafold_structures.py +72 -0
- spacr/graph_learning.py +259 -65
- spacr/graph_learning_lap.py +73 -71
- spacr/gui_classify_app.py +5 -21
- spacr/gui_mask_app.py +36 -30
- spacr/gui_measure_app.py +10 -24
- spacr/gui_utils.py +82 -54
- spacr/io.py +505 -205
- spacr/measure.py +160 -80
- spacr/old_code.py +155 -1
- spacr/plot.py +243 -99
- spacr/sim.py +666 -119
- spacr/timelapse.py +343 -52
- spacr/train.py +18 -10
- spacr/utils.py +252 -151
- {spacr-0.0.18.dist-info → spacr-0.0.21.dist-info}/METADATA +32 -27
- spacr-0.0.21.dist-info/RECORD +33 -0
- {spacr-0.0.18.dist-info → spacr-0.0.21.dist-info}/WHEEL +1 -1
- spacr/gui_temp.py +0 -212
- spacr/test_annotate_app.py +0 -58
- spacr/test_plot.py +0 -43
- spacr/test_train.py +0 -39
- spacr/test_utils.py +0 -33
- spacr-0.0.18.dist-info/RECORD +0 -36
- {spacr-0.0.18.dist-info → spacr-0.0.21.dist-info}/LICENSE +0 -0
- {spacr-0.0.18.dist-info → spacr-0.0.21.dist-info}/entry_points.txt +0 -0
- {spacr-0.0.18.dist-info → spacr-0.0.21.dist-info}/top_level.txt +0 -0
spacr/utils.py
CHANGED
(Removed lines that the diff viewer truncated are marked with "…" below.)
@@ -1,6 +1,8 @@
-import os, re, sqlite3, gc, torch, torchvision, time, random, string, shutil, cv2, tarfile, glob
+import sys, os, re, sqlite3, gc, torch, torchvision, time, random, string, shutil, cv2, tarfile, glob
 import numpy as np
+from cellpose import models as cp_models
+from cellpose import denoise
 from skimage import morphology
 from skimage.measure import label, regionprops_table, regionprops
 import skimage.measure as measure
@@ -18,6 +20,8 @@ from functools import reduce
 from IPython.display import display, clear_output
 from multiprocessing import Pool, cpu_count
 from skimage.transform import resize as resizescikit
+from skimage.morphology import dilation, square
+from skimage.measure import find_contours
 import torch.nn as nn
 import torch.nn.functional as F
 #from torchsummary import summary
@@ -29,6 +33,7 @@ from skimage.segmentation import clear_border
 import seaborn as sns
 import matplotlib.pyplot as plt
 import scipy.ndimage as ndi
+from scipy.spatial import distance
 from scipy.stats import fisher_exact
 from scipy.ndimage import binary_erosion, binary_dilation
 from skimage.exposure import rescale_intensity
@@ -36,6 +41,7 @@ from sklearn.metrics import auc, precision_recall_curve
 from sklearn.model_selection import train_test_split
 from sklearn.linear_model import Lasso, Ridge
 from sklearn.preprocessing import OneHotEncoder
+from sklearn.cluster import KMeans
 from torchvision.models.resnet import ResNet18_Weights, ResNet34_Weights, ResNet50_Weights, ResNet101_Weights, ResNet152_Weights

 from .logger import log_function_call
@@ -45,6 +51,54 @@ from .logger import log_function_call
 #from .plot import _plot_images_on_grid, plot_masks, _plot_histograms_and_stats, plot_resize, _plot_plates, _reg_v_plot, plot_masks
 #from .core import identify_masks

+
+def _gen_rgb_image(image, cahnnels):
+    rgb_image = np.take(image, cahnnels, axis=-1)
+    rgb_image = rgb_image.astype(float)
+    rgb_image -= rgb_image.min()
+    rgb_image /= rgb_image.max()
+    return rgb_image
+
+def _outline_and_overlay(image, rgb_image, mask_dims, outline_colors, outline_thickness):
+    from concurrent.futures import ThreadPoolExecutor
+    import cv2
+
+    outlines = []
+    overlayed_image = rgb_image.copy()
+
+    def process_dim(mask_dim):
+        mask = np.take(image, mask_dim, axis=-1)
+        outline = np.zeros_like(mask, dtype=np.uint8)  # Use uint8 for contour detection efficiency
+
+        # Find and draw contours
+        for j in np.unique(mask):
+        #for j in np.unique(mask)[1:]:
+            contours = find_contours(mask == j, 0.5)
+            # Convert contours for OpenCV format and draw directly to optimize
+            cv_contours = [np.flip(contour.astype(int), axis=1) for contour in contours]
+            cv2.drawContours(outline, cv_contours, -1, color=int(j), thickness=outline_thickness)
+
+        return dilation(outline, square(outline_thickness))
+
+    # Parallel processing
+    with ThreadPoolExecutor() as executor:
+        outlines = list(executor.map(process_dim, mask_dims))
+
+    # Overlay outlines onto the RGB image in a batch/vectorized manner if possible
+    for i, outline in enumerate(outlines):
+        # This part may need to be adapted to your specific use case and available functions
+        # The goal is to overlay each outline with its respective color more efficiently
+        color = outline_colors[i % len(outline_colors)]
+        for j in np.unique(outline)[1:]:
+            mask = outline == j
+            overlayed_image[mask] = color  # Direct assignment with broadcasting
+
+    # Remove mask_dims from image
+    channels_to_keep = [i for i in range(image.shape[-1]) if i not in mask_dims]
+    image = np.take(image, channels_to_keep, axis=-1)
+
+    return overlayed_image, outlines, image
+
 def _convert_cq1_well_id(well_id):
     """
     Converts a well ID to the CQ1 well format.
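The two helpers added at the top of this hunk compose into a quick QC overlay: _gen_rgb_image min-max scales selected channels into an RGB composite, and _outline_and_overlay draws colored object outlines over it while stripping the mask channels from the stack. Below is a minimal usage sketch, assuming an (H, W, C) array whose trailing channels hold integer label masks; the shapes, channel indices, and colors are illustrative, and note that the release spells the second parameter `cahnnels`:

import numpy as np
from spacr.utils import _gen_rgb_image, _outline_and_overlay

# Hypothetical 5-channel field of view: channels 0-2 are intensity images,
# channels 3-4 are integer label masks (e.g. cell and nucleus).
image = np.zeros((256, 256, 5), dtype=np.float32)
image[..., :3] = np.random.rand(256, 256, 3)
image[100:140, 100:140, 3] = 1           # one fake cell label
image[110:120, 110:120, 4] = 1           # one fake nucleus label

rgb = _gen_rgb_image(image, [0, 1, 2])   # min-max scaled RGB composite
overlay, outlines, image_wo_masks = _outline_and_overlay(
    image, rgb,
    mask_dims=[3, 4],                    # channels that hold the masks
    outline_colors=[(1, 0, 0), (0, 1, 0)],
    outline_thickness=3,
)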
@@ -114,7 +168,7 @@ def _extract_filename_metadata(filenames, src, images_by_key, regular_expression
         if metadata_type =='cq1':
             orig_wellID = wellID
             wellID = _convert_cq1_well_id(wellID)
-            clear_output(wait=True)
+            #clear_output(wait=True)
             print(f'Converted Well ID: {orig_wellID} to {wellID}', end='\r', flush=True)

         if pick_slice:
@@ -673,9 +727,6 @@ def _crop_center(img, cell_mask, new_width, new_height, normalize=(2,98)):
     img = img[start_y:end_y, start_x:end_x, :]
     return img

-
-
-
 def _masks_to_masks_stack(masks):
     """
     Convert a list of masks into a stack of masks.
@@ -692,53 +743,50 @@ def _masks_to_masks_stack(masks):
     return mask_stack

 def _get_diam(mag, obj):
-    …  (28 removed lines, truncated in the source diff view)
+
+    if mag == 20:
+        if obj == 'cell':
+            diamiter = 120
+        elif obj == 'nucleus':
+            diamiter = 60
+        elif obj == 'pathogen':
+            diamiter = 30
+        else:
+            raise ValueError("Invalid magnification: Use 20, 40 or 60")
+
+    elif mag == 40:
+        if obj == 'cell':
+            diamiter = 160
+        elif obj == 'nucleus':
+            diamiter = 80
+        elif obj == 'pathogen':
+            diamiter = 40
+        else:
+            raise ValueError("Invalid magnification: Use 20, 40 or 60")
+
+    elif mag == 60:
+        if obj == 'cell':
+            diamiter = 200
+        if obj == 'nucleus':
+            diamiter = 90
+        if obj == 'pathogen':
+            diamiter = 75
+        else:
+            raise ValueError("Invalid magnification: Use 20, 40 or 60")
     else:
-        raise ValueError("Invalid…
-
+        raise ValueError("Invalid magnification: Use 20, 40 or 60")
+
     return diamiter

 def _get_object_settings(object_type, settings):
-
     object_settings = {}
-
-    object_settings['filter_size'] = False
-    object_settings['filter_dimm'] = False
-    print(object_type)
+
     object_settings['diameter'] = _get_diam(settings['magnification'], obj=object_type)
-    object_settings['…  (truncated in the source diff view)
-    object_settings['…  (truncated in the source diff view)
-    object_settings['maximum_size'] = object_settings['minimum_size']*50
+    object_settings['minimum_size'] = (object_settings['diameter']**2)/4
+    object_settings['maximum_size'] = (object_settings['diameter']**2)*10
     object_settings['merge'] = False
-    object_settings['net_avg'] = True
     object_settings['resample'] = True
+    object_settings['remove_border_objects'] = False
     object_settings['model_name'] = 'cyto'

     if object_type == 'cell':
@@ -746,20 +794,28 @@ def _get_object_settings(object_type, settings):
             object_settings['model_name'] = 'cyto'
         else:
             object_settings['model_name'] = 'cyto2'
-
+        object_settings['filter_size'] = False
+        object_settings['filter_intensity'] = False
+        object_settings['restore_type'] = settings.get('cell_restore_type', None)
+
     elif object_type == 'nucleus':
         object_settings['model_name'] = 'nuclei'
+        object_settings['filter_size'] = False
+        object_settings['filter_intensity'] = False
+        object_settings['restore_type'] = settings.get('nucleus_restore_type', None)

     elif object_type == 'pathogen':
-        object_settings['model_name'] = 'cyto3'
-
-    elif object_type == 'pathogen_nucleus':
-        object_settings['filter_size'] = True
         object_settings['model_name'] = 'cyto'
+        object_settings['filter_size'] = True
+        object_settings['filter_intensity'] = False
+        object_settings['restore_type'] = settings.get('pathogen_restore_type', None)
+        object_settings['merge'] = settings['merge_pathogens']

     else:
         print(f'Object type: {object_type} not supported. Supported object types are : cell, nucleus and pathogen')
-
+
+    if settings['verbose']:
+        print(object_settings)

     return object_settings
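The rewritten settings builder derives its size gates from the Cellpose diameter instead of hard-coding a minimum. A worked example with the _get_diam values above (a circle of diameter d covers πd²/4 ≈ 0.785·d² pixels, so the floor of d²/4 admits objects down to roughly a third of the nominal area, while the ceiling allows up to 10·d²):

# Worked example of the new size bounds (values taken from the diff above)
diameter = 120                   # _get_diam(20, 'cell')
minimum_size = diameter**2 / 4   # 3600 px^2
maximum_size = diameter**2 * 10  # 144000 px^2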
@@ -786,6 +842,7 @@ def _pivot_counts_table(db_path):
         return df

     def _pivot_dataframe(df):
+
         """
         Pivot the DataFrame.

@@ -812,61 +869,32 @@ def _pivot_counts_table(db_path):
     pivoted_df.to_sql('pivoted_counts', conn, if_exists='replace', index=False)
     conn.close()

-def …  (truncated in the source diff view)
-    cellpose_channels = {}
-    if nucleus_chann_dim in mask_channels:
-        cellpose_channels['nucleus'] = [0, mask_channels.index(nucleus_chann_dim)]
-    if pathogen_chann_dim in mask_channels:
-        cellpose_channels['pathogen'] = [0, mask_channels.index(pathogen_chann_dim)]
-    if cell_chann_dim in mask_channels:
-        cellpose_channels['cell'] = [0, mask_channels.index(cell_chann_dim)]
-    return cellpose_channels
+def _get_cellpose_channels(src, nucleus_channel, pathogen_channel, cell_channel):

-
-
-
+    cell_mask_path = os.path.join(src, 'norm_channel_stack', 'cell_mask_stack')
+    nucleus_mask_path = os.path.join(src, 'norm_channel_stack', 'nucleus_mask_stack')
+    pathogen_mask_path = os.path.join(src, 'norm_channel_stack', 'pathogen_mask_stack')

-    # Initialize a list to keep track of the channels in their new order
-    new_channel_order = []
-
-    # Add each channel to the new order list if it is not None
-    if cell_channel is not None:
-        new_channel_order.append(('cell', cell_channel))
-    if nucleus_channel is not None:
-        new_channel_order.append(('nucleus', nucleus_channel))
-    if pathogen_channel is not None:
-        new_channel_order.append(('pathogen', pathogen_channel))
-
-    # Sort the list based on the original channel indices to maintain the original order
-    new_channel_order.sort(key=lambda x: x[1])
-    print(new_channel_order)
-    # Assign new indices based on the sorted order
-    for new_index, (channel_name, _) in enumerate(new_channel_order):
-        cellpose_channels[channel_name] = [new_index, 0]
-
-    if cell_channel is not None and nucleus_channel is not None:
-        cellpose_channels['cell'][1] = cellpose_channels['nucleus'][0]
-
-    return cellpose_channels

-
+    if os.path.exists(cell_mask_path) or os.path.exists(nucleus_mask_path) or os.path.exists(pathogen_mask_path):
+        if nucleus_channel is None or nucleus_channel is None or nucleus_channel is None:
+            print('Warning: Cellpose masks already exist. Unexpected behaviour when setting any object dimention to None when the object masks have been created.')
+
     cellpose_channels = {}
     if not nucleus_channel is None:
         cellpose_channels['nucleus'] = [0,0]

     if not pathogen_channel is None:
         if not nucleus_channel is None:
-            …  (truncated in the source diff view)
+            if not pathogen_channel is None:
+                cellpose_channels['pathogen'] = [0,2]
+            else:
+                cellpose_channels['pathogen'] = [0,1]
         else:
             cellpose_channels['pathogen'] = [0,0]

     if not cell_channel is None:
         if not nucleus_channel is None:
-            if not pathogen_channel is None:
-                cellpose_channels['cell'] = [0,2]
-            else:
-                cellpose_channels['cell'] = [0,1]
-        elif not pathogen_channel is None:
             cellpose_channels['cell'] = [0,1]
         else:
             cellpose_channels['cell'] = [0,0]
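Note that the rewritten _get_cellpose_channels no longer inspects the actual channel indices; it only checks which objects are requested (not None), warns if mask stacks already exist on disk, and hands Cellpose fixed channel pairs. An illustration of the mapping implied by the logic above (the src path is hypothetical):

from spacr.utils import _get_cellpose_channels

# All three objects requested:
_get_cellpose_channels('/data/plate1', nucleus_channel=0,
                       pathogen_channel=2, cell_channel=3)
# -> {'nucleus': [0, 0], 'pathogen': [0, 2], 'cell': [0, 1]}

# Nucleus only:
_get_cellpose_channels('/data/plate1', nucleus_channel=0,
                       pathogen_channel=None, cell_channel=None)
# -> {'nucleus': [0, 0]}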
@@ -1027,9 +1055,6 @@ def _group_by_well(df):
     # Apply mean function to numeric columns and first to non-numeric
     df_grouped = df.groupby(['plate', 'row', 'col']).agg({**{col: np.mean for col in numeric_cols}, **{col: 'first' for col in non_numeric_cols}})
     return df_grouped
-
-
-

 ###################################################
 # Classify
@@ -1044,7 +1069,7 @@ class Cache:
         cache (OrderedDict): The cache data structure.
     """

-    def …
+    def __init__(self, max_size):
         self.cache = OrderedDict()
         self.max_size = max_size

@@ -1075,7 +1100,7 @@ class ScaledDotProductAttention(nn.Module):

     """

-    def …
+    def __init__(self, d_k):
         super(ScaledDotProductAttention, self).__init__()
         self.d_k = d_k

@@ -1106,7 +1131,7 @@ class SelfAttention(nn.Module):
         d_k (int): Dimensionality of the key and query vectors.
     """

-    def …
+    def __init__(self, in_channels, d_k):
         super(SelfAttention, self).__init__()
         self.W_q = nn.Linear(in_channels, d_k)
         self.W_k = nn.Linear(in_channels, d_k)
@@ -1130,7 +1155,7 @@ class SelfAttention(nn.Module):
         return output

 class ScaledDotProductAttention(nn.Module):
-    def …
+    def __init__(self, d_k):
         """
         Initializes the ScaledDotProductAttention module.

@@ -1167,7 +1192,7 @@ class SelfAttention(nn.Module):
         in_channels (int): Number of input channels.
         d_k (int): Dimensionality of the key and query vectors.
         """
-    def …
+    def __init__(self, in_channels, d_k):
         super(SelfAttention, self).__init__()
         self.W_q = nn.Linear(in_channels, d_k)
         self.W_k = nn.Linear(in_channels, d_k)
@@ -1198,7 +1223,7 @@ class EarlyFusion(nn.Module):
     Args:
         in_channels (int): Number of input channels.
     """
-    def …
+    def __init__(self, in_channels):
         super(EarlyFusion, self).__init__()
         self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=1, stride=1)

@@ -1217,7 +1242,7 @@ class EarlyFusion(nn.Module):

 # Spatial Attention Mechanism
 class SpatialAttention(nn.Module):
-    def …
+    def __init__(self, kernel_size=7):
         """
         Initializes the SpatialAttention module.

@@ -1262,7 +1287,7 @@ class MultiScaleBlockWithAttention(nn.Module):
         forward: Forward method for the module.
     """

-    def …
+    def __init__(self, in_channels, out_channels):
         super(MultiScaleBlockWithAttention, self).__init__()
         self.dilated_conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, dilation=1, padding=1)
         self.spatial_attention = nn.Conv2d(out_channels, out_channels, kernel_size=1)
@@ -1295,7 +1320,7 @@ class MultiScaleBlockWithAttention(nn.Module):

 # Final Classifier
 class CustomCellClassifier(nn.Module):
-    def …
+    def __init__(self, num_classes, pathogen_channel, use_attention, use_checkpoint, dropout_rate):
         super(CustomCellClassifier, self).__init__()
         self.early_fusion = EarlyFusion(in_channels=3)

@@ -1324,7 +1349,7 @@ class CustomCellClassifier(nn.Module):

 #CNN and Transformer class, pick any Torch model.
 class TorchModel(nn.Module):
-    def …
+    def __init__(self, model_name='resnet50', pretrained=True, dropout_rate=None, use_checkpoint=False):
         super(TorchModel, self).__init__()
         self.model_name = model_name
         self.use_checkpoint = use_checkpoint
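The series of hunks above restores __init__ signatures that the source diff view had collapsed to a bare `def`. As one example, TorchModel wraps a torchvision backbone behind the signature shown; a minimal instantiation sketch, assuming the usual nn.Module call convention and an ImageNet-sized dummy batch:

import torch
from spacr.utils import TorchModel

model = TorchModel(model_name='resnet50', pretrained=True,
                   dropout_rate=0.2, use_checkpoint=False)
logits = model(torch.randn(4, 3, 224, 224))  # forward pass on a dummy batch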
@@ -1398,7 +1423,7 @@ class TorchModel(nn.Module):
         return logits

 class FocalLossWithLogits(nn.Module):
-    def …
+    def __init__(self, alpha=1, gamma=2):
         super(FocalLossWithLogits, self).__init__()
         self.alpha = alpha
         self.gamma = gamma
@@ -1410,7 +1435,7 @@ class FocalLossWithLogits(nn.Module):
         return focal_loss.mean()

 class ResNet(nn.Module):
-    def …
+    def __init__(self, resnet_type='resnet50', dropout_rate=None, use_checkpoint=False, init_weights='imagenet'):
         super(ResNet, self).__init__()

         resnet_map = {
@@ -1763,25 +1788,24 @@ def annotate_predictions(csv_loc):
     df['cond'] = df.apply(assign_condition, axis=1)
     return df

-def …
+def initiate_counter(counter_, lock_):
     global counter, lock
     counter = counter_
     lock = lock_

-def add_images_to_tar(…
-    global counter, lock, total_images
-    paths_chunk, tar_path = args
+def add_images_to_tar(paths_chunk, tar_path, total_images):
     with tarfile.open(tar_path, 'w') as tar:
-        for img_path in paths_chunk:
+        for i, img_path in enumerate(paths_chunk):
             arcname = os.path.basename(img_path)
             try:
                 tar.add(img_path, arcname=arcname)
                 with lock:
                     counter.value += 1
-
+                    if counter.value % 100 == 0:  # Print every 100 updates
+                        progress = (counter.value / total_images) * 100
+                        print(f"Progress: {counter.value}/{total_images} ({progress:.2f}%)", end='\r', file=sys.stdout, flush=True)
             except FileNotFoundError:
                 print(f"File not found: {img_path}")
-    return tar_path

 def generate_fraction_map(df, gene_column, min_frequency=0.0):
     df['fraction'] = df['count']/df['well_read_sum']
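add_images_to_tar now takes its inputs as explicit arguments rather than unpacking an args tuple, and reports progress through the shared counter installed by initiate_counter. A minimal sketch of the intended wiring with a worker pool; the paths, chunking, and tar names are illustrative:

from multiprocessing import Pool, Value, Lock
from spacr.utils import initiate_counter, add_images_to_tar

paths = [f'/data/images/img_{i:03d}.png' for i in range(200)]  # hypothetical
chunks = [paths[:100], paths[100:]]
counter, lock = Value('i', 0), Lock()

with Pool(2, initializer=initiate_counter, initargs=(counter, lock)) as pool:
    pool.starmap(add_images_to_tar,
                 [(chunk, f'/tmp/part{i}.tar', len(paths))
                  for i, chunk in enumerate(chunks)])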
@@ -2230,8 +2254,8 @@ def dice_coefficient(mask1, mask2):
 def extract_boundaries(mask, dilation_radius=1):
     binary_mask = (mask > 0).astype(np.uint8)
     struct_elem = np.ones((dilation_radius*2+1, dilation_radius*2+1))
-    dilated = binary_dilation(binary_mask, footprint=struct_elem)
-    eroded = binary_erosion(binary_mask, footprint=struct_elem)
+    dilated = morphology.binary_dilation(binary_mask, footprint=struct_elem)
+    eroded = morphology.binary_erosion(binary_mask, footprint=struct_elem)
     boundary = dilated ^ eroded
     return boundary
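The morphology. prefix is the substance of this fix: the bare binary_dilation/binary_erosion imported at the top of the file come from scipy.ndimage, which takes a structure= argument, so passing footprint= raises a TypeError; the skimage.morphology versions do accept footprint=. A standalone sketch of what the repaired function computes:

import numpy as np
from skimage import morphology

mask = np.zeros((9, 9), dtype=np.uint8)
mask[3:6, 3:6] = 1                  # a single 3x3 object
struct_elem = np.ones((3, 3))       # dilation_radius=1

dilated = morphology.binary_dilation(mask, footprint=struct_elem)
eroded = morphology.binary_erosion(mask, footprint=struct_elem)
boundary = dilated ^ eroded         # ring of pixels along the object edge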
@@ -2612,24 +2636,21 @@ def _filter_object(mask, min_value):
     mask[np.isin(mask, to_remove)] = 0
     return mask

-def _filter_cp_masks(masks, flows, filter_size, minimum_size, maximum_size, remove_border_objects, merge,…
+def _filter_cp_masks(masks, flows, filter_size, filter_intensity, minimum_size, maximum_size, remove_border_objects, merge, batch, plot, figuresize):
+
     """
     Filter the masks based on various criteria such as size, border objects, merging, and intensity.

     Args:
         masks (list): List of masks.
         flows (list): List of flows.
-        refine_masks (bool): Flag indicating whether to refine masks.
         filter_size (bool): Flag indicating whether to filter based on size.
+        filter_intensity (bool): Flag indicating whether to filter based on intensity.
         minimum_size (int): Minimum size of objects to keep.
         maximum_size (int): Maximum size of objects to keep.
         remove_border_objects (bool): Flag indicating whether to remove border objects.
         merge (bool): Flag indicating whether to merge adjacent objects.
-        filter_dimm (bool): Flag indicating whether to filter based on intensity.
         batch (ndarray): Batch of images.
-        moving_avg_q1 (float): Moving average of the first quartile of object intensities.
-        moving_avg_q3 (float): Moving average of the third quartile of object intensities.
-        moving_count (int): Count of moving averages.
         plot (bool): Flag indicating whether to plot the masks.
         figuresize (tuple): Size of the figure.
@@ -2641,51 +2662,66 @@ def _filter_cp_masks(masks, flows, filter_size, minimum_size, maximum_size, remo

     mask_stack = []
     for idx, (mask, flow, image) in enumerate(zip(masks, flows[0], batch)):
+
         if plot and idx == 0:
             num_objects = mask_object_count(mask)
             print(f'Number of objects before filtration: {num_objects}')
             plot_masks(batch=image, masks=mask, flows=flow, cmap='inferno', figuresize=figuresize, nr=1, file_type='.npz', print_object_number=True)

-        if …  (truncated in the source diff view)
-            …
-            valid_labels = props['label'][np.logical_and(props['area'] > minimum_size, props['area'] < maximum_size)] # Select labels of valid size.
-            masks[idx] = np.isin(mask, valid_labels) * mask # Keep only valid objects.
+        if merge:
+            mask = merge_touching_objects(mask, threshold=0.66)
             if plot and idx == 0:
                 num_objects = mask_object_count(mask)
-                print(f'Number of objects after …
+                print(f'Number of objects after merging adjacent objects, : {num_objects}')
                 plot_masks(batch=image, masks=mask, flows=flow, cmap='inferno', figuresize=figuresize, nr=1, file_type='.npz', print_object_number=True)
-
-
+
+        if filter_size:
+            props = measure.regionprops_table(mask, properties=['label', 'area'])
+            valid_labels = props['label'][np.logical_and(props['area'] > minimum_size, props['area'] < maximum_size)]
+            mask = np.isin(mask, valid_labels) * mask
             if plot and idx == 0:
                 num_objects = mask_object_count(mask)
-                print(f'Number of objects after …
+                print(f'Number of objects after size filtration >{minimum_size} and <{maximum_size} : {num_objects}')
                 plot_masks(batch=image, masks=mask, flows=flow, cmap='inferno', figuresize=figuresize, nr=1, file_type='.npz', print_object_number=True)
-
-
+
+        if filter_intensity:
+            intensity_image = image[:, :, 1]
+            props = measure.regionprops_table(mask, intensity_image=intensity_image, properties=['label', 'mean_intensity'])
+            mean_intensities = np.array(props['mean_intensity']).reshape(-1, 1)
+
+            if mean_intensities.shape[0] >= 2:
+                kmeans = KMeans(n_clusters=2, random_state=0).fit(mean_intensities)
+                centroids = kmeans.cluster_centers_
+
+                # Calculate the Euclidean distance between the two centroids
+                dist_between_centroids = distance.euclidean(centroids[0], centroids[1])
+
+                # Set a threshold for the minimum distance to consider clusters distinct
+                distance_threshold = 0.25
+
+                if dist_between_centroids > distance_threshold:
+                    high_intensity_cluster = np.argmax(centroids)
+                    valid_labels = np.array(props['label'])[kmeans.labels_ == high_intensity_cluster]
+                    mask = np.isin(mask, valid_labels) * mask
+
             if plot and idx == 0:
                 num_objects = mask_object_count(mask)
-
+                props_after = measure.regionprops_table(mask, intensity_image=intensity_image, properties=['label', 'mean_intensity'])
+                mean_intensities_after = np.mean(np.array(props_after['mean_intensity']))
+                average_intensity_before = np.mean(mean_intensities)
+                print(f'Number of objects after potential intensity clustering: {num_objects}. Mean intensity before:{average_intensity_before:.4f}. After:{mean_intensities_after:.4f}.')
                 plot_masks(batch=image, masks=mask, flows=flow, cmap='inferno', figuresize=figuresize, nr=1, file_type='.npz', print_object_number=True)
-
-
-
-
-        object_intensities = [np.mean(batch[idx, :, :, 1][mask == label]) for label in unique_labels if label != 0]
-        object_q1s = [np.percentile(intensities, 25) for intensities in object_intensities if intensities.size > 0]
-        object_q3s = [np.percentile(intensities, 75) for intensities in object_intensities if intensities.size > 0]
-        if object_q1s:
-            object_q1_mean = np.mean(object_q1s)
-            object_q3_mean = np.mean(object_q3s)
-            moving_avg_q1 = (moving_avg_q1 * moving_count + object_q1_mean) / (moving_count + 1)
-            moving_avg_q3 = (moving_avg_q3 * moving_count + object_q3_mean) / (moving_count + 1)
-            moving_count += 1
-        mask = remove_intensity_objects(batch[idx, :, :, 1], mask, intensity_threshold=moving_avg_q1, mode='low')
-        mask = remove_intensity_objects(batch[idx, :, :, 1], mask, intensity_threshold=moving_avg_q3, mode='high')
+
+
+        if remove_border_objects:
+            mask = clear_border(mask)
             if plot and idx == 0:
                 num_objects = mask_object_count(mask)
-                print(f'…
+                print(f'Number of objects after removing border objects, : {num_objects}')
                 plot_masks(batch=image, masks=mask, flows=flow, cmap='inferno', figuresize=figuresize, nr=1, file_type='.npz', print_object_number=True)
+
         mask_stack.append(mask)
+
     return mask_stack

 def _object_filter(df, object_type, size_range, intensity_range, mask_chans, mask_chan):
@@ -2721,6 +2757,71 @@ def _object_filter(df, object_type, size_range, intensity_range, mask_chans, mas
     print(f'After {object_type} maximum mean intensity filter: {len(df)}')
     return df

-
-
-
+def _run_test_mode(src, regex, timelapse=False):
+    if timelapse:
+        test_images = 1 # Use only 1 set for timelapse to ensure full sequence inclusion
+    else:
+        test_images = 10 # Use 10 sets for non-timelapse scenarios
+
+    test_folder_path = os.path.join(src, 'test')
+    os.makedirs(test_folder_path, exist_ok=True)
+    regular_expression = re.compile(regex)
+
+    all_filenames = [filename for filename in os.listdir(src) if regular_expression.match(filename)]
+    print(f'Found {len(all_filenames)} files')
+    images_by_set = defaultdict(list)
+
+    for filename in all_filenames:
+        match = regular_expression.match(filename)
+        if match:
+            plate = match.group('plateID') if 'plateID' in match.groupdict() else os.path.basename(src)
+            well = match.group('wellID')
+            field = match.group('fieldID')
+            # For timelapse experiments, group images by plate, well, and field only
+            if timelapse:
+                set_identifier = (plate, well, field)
+            else:
+                # For non-timelapse, you might want to distinguish sets more granularly
+                # Here, assuming you're grouping by plate, well, and field for simplicity
+                set_identifier = (plate, well, field)
+            images_by_set[set_identifier].append(filename)
+
+    # Prepare for random selection
+    set_identifiers = list(images_by_set.keys())
+    random.seed(42)
+    random.shuffle(set_identifiers) # Randomize the order
+
+    # Select a subset based on the test_images count
+    selected_sets = set_identifiers[:test_images]
+
+    # Print information about the number of sets used
+    print(f'Using {test_images} random image set(s) for test model')
+
+    # Copy files for selected sets to the test folder
+    for set_identifier in selected_sets:
+        for filename in images_by_set[set_identifier]:
+            shutil.copy(os.path.join(src, filename), test_folder_path)
+
+    return test_folder_path
+
+def _choose_model(model_name, device, object_type='cell', restore_type=None):
+    restore_list = ['denoise', 'deblur', 'upsample', None]
+    if restore_type not in restore_list:
+        print(f"Invalid restore type. Choose from {restore_list} defaulting to None")
+        restore_type = None
+
+    if restore_type == None:
+        model = cp_models.Cellpose(gpu=True, model_type=model_name, device=device)
+    else:
+        if object_type == 'nucleus':
+            restore = f'{type}_nuclei'
+            model = denoise.CellposeDenoiseModel(gpu=True, model_type="nuclei", restore_type=restore, chan2_restore=False, device=device)
+        else:
+            restore = f'{type}_cyto3'
+            if model_name =='cyto2':
+                chan2_restore = True
+            if model_name =='cyto':
+                chan2_restore = False
+            model = denoise.CellposeDenoiseModel(gpu=True, model_type="cyto3", restore_type=restore, chan2_restore=chan2_restore, device=device)
+
+    return model