spacr 0.0.2__py3-none-any.whl → 0.0.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- spacr/__init__.py +2 -2
- spacr/__main__.py +0 -2
- spacr/alpha.py +803 -14
- spacr/annotate_app.py +118 -120
- spacr/chris.py +50 -0
- spacr/core.py +1544 -533
- spacr/deep_spacr.py +696 -0
- spacr/foldseek.py +779 -0
- spacr/get_alfafold_structures.py +72 -0
- spacr/graph_learning.py +297 -253
- spacr/gui.py +145 -0
- spacr/gui_2.py +90 -0
- spacr/gui_classify_app.py +70 -80
- spacr/gui_mask_app.py +114 -91
- spacr/gui_measure_app.py +109 -88
- spacr/gui_utils.py +376 -32
- spacr/io.py +441 -438
- spacr/mask_app.py +116 -9
- spacr/measure.py +169 -69
- spacr/models/cp/toxo_pv_lumen.CP_model +0 -0
- spacr/old_code.py +70 -2
- spacr/plot.py +173 -17
- spacr/sequencing.py +1130 -0
- spacr/sim.py +630 -125
- spacr/timelapse.py +139 -10
- spacr/train.py +188 -21
- spacr/umap.py +0 -689
- spacr/utils.py +1360 -119
- {spacr-0.0.2.dist-info → spacr-0.0.6.dist-info}/METADATA +17 -29
- spacr-0.0.6.dist-info/RECORD +39 -0
- {spacr-0.0.2.dist-info → spacr-0.0.6.dist-info}/WHEEL +1 -1
- spacr-0.0.6.dist-info/entry_points.txt +9 -0
- spacr-0.0.2.dist-info/RECORD +0 -31
- spacr-0.0.2.dist-info/entry_points.txt +0 -7
- {spacr-0.0.2.dist-info → spacr-0.0.6.dist-info}/LICENSE +0 -0
- {spacr-0.0.2.dist-info → spacr-0.0.6.dist-info}/top_level.txt +0 -0
spacr/core.py
CHANGED
@@ -1,10 +1,9 @@
|
|
1
|
-
import os, sqlite3, gc, torch, time, random, shutil, cv2, tarfile, datetime
|
1
|
+
import os, sqlite3, gc, torch, time, random, shutil, cv2, tarfile, datetime, shap
|
2
2
|
|
3
|
-
# image and array processing
|
4
3
|
import numpy as np
|
5
4
|
import pandas as pd
|
6
5
|
|
7
|
-
import
|
6
|
+
from cellpose import train
|
8
7
|
from cellpose import models as cp_models
|
9
8
|
|
10
9
|
import statsmodels.formula.api as smf
|
@@ -14,23 +13,37 @@ from IPython.display import display
|
|
14
13
|
from multiprocessing import Pool, cpu_count, Value, Lock
|
15
14
|
|
16
15
|
import seaborn as sns
|
17
|
-
|
16
|
+
|
18
17
|
from skimage.measure import regionprops, label
|
19
|
-
|
18
|
+
from skimage.morphology import square
|
20
19
|
from skimage.transform import resize as resizescikit
|
21
|
-
from sklearn.model_selection import train_test_split
|
22
20
|
from collections import defaultdict
|
23
|
-
import multiprocessing
|
24
21
|
from torch.utils.data import DataLoader, random_split
|
22
|
+
from sklearn.cluster import KMeans
|
23
|
+
from sklearn.decomposition import PCA
|
24
|
+
|
25
|
+
from skimage import measure
|
26
|
+
from sklearn.model_selection import train_test_split
|
27
|
+
from sklearn.ensemble import IsolationForest, RandomForestClassifier, HistGradientBoostingClassifier
|
28
|
+
from sklearn.linear_model import LogisticRegression
|
29
|
+
from sklearn.inspection import permutation_importance
|
30
|
+
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, classification_report
|
31
|
+
from sklearn.preprocessing import StandardScaler
|
32
|
+
|
33
|
+
from scipy.ndimage import binary_dilation
|
34
|
+
from scipy.spatial.distance import cosine, euclidean, mahalanobis, cityblock, minkowski, chebyshev, hamming, jaccard, braycurtis
|
35
|
+
|
36
|
+
import torchvision.transforms as transforms
|
37
|
+
from xgboost import XGBClassifier
|
38
|
+
import shap
|
39
|
+
|
40
|
+
import matplotlib.pyplot as plt
|
25
41
|
import matplotlib
|
26
42
|
matplotlib.use('Agg')
|
43
|
+
#import matplotlib.pyplot as plt
|
27
44
|
|
28
|
-
import torchvision.transforms as transforms
|
29
|
-
from sklearn.model_selection import train_test_split
|
30
|
-
from sklearn.ensemble import IsolationForest
|
31
45
|
from .logger import log_function_call
|
32
46
|
|
33
|
-
|
34
47
|
def analyze_plaques(folder):
|
35
48
|
summary_data = []
|
36
49
|
details_data = []
|
@@ -67,169 +80,95 @@ def analyze_plaques(folder):
|
|
67
80
|
|
68
81
|
print(f"Analysis completed and saved to database '{db_name}'.")
|
69
82
|
|
70
|
-
def compare_masks(dir1, dir2, dir3, verbose=False):
|
71
|
-
|
72
|
-
from .io import _read_mask
|
73
|
-
from .plot import visualize_masks, plot_comparison_results
|
74
|
-
from .utils import extract_boundaries, boundary_f1_score, compute_segmentation_ap, jaccard_index, dice_coefficient
|
75
|
-
|
76
|
-
filenames = os.listdir(dir1)
|
77
|
-
results = []
|
78
|
-
cond_1 = os.path.basename(dir1)
|
79
|
-
cond_2 = os.path.basename(dir2)
|
80
|
-
cond_3 = os.path.basename(dir3)
|
81
|
-
for index, filename in enumerate(filenames):
|
82
|
-
print(f'Processing image:{index+1}', end='\r', flush=True)
|
83
|
-
path1, path2, path3 = os.path.join(dir1, filename), os.path.join(dir2, filename), os.path.join(dir3, filename)
|
84
|
-
if os.path.exists(path2) and os.path.exists(path3):
|
85
|
-
|
86
|
-
mask1, mask2, mask3 = _read_mask(path1), _read_mask(path2), _read_mask(path3)
|
87
|
-
boundary_true1, boundary_true2, boundary_true3 = extract_boundaries(mask1), extract_boundaries(mask2), extract_boundaries(mask3)
|
88
|
-
|
89
|
-
|
90
|
-
true_masks, pred_masks = [mask1], [mask2, mask3] # Assuming mask1 is the ground truth for simplicity
|
91
|
-
true_labels, pred_labels_1, pred_labels_2 = label(mask1), label(mask2), label(mask3)
|
92
|
-
average_precision_0, average_precision_1 = compute_segmentation_ap(mask1, mask2), compute_segmentation_ap(mask1, mask3)
|
93
|
-
ap_scores = [average_precision_0, average_precision_1]
|
94
|
-
|
95
|
-
if verbose:
|
96
|
-
unique_values1, unique_values2, unique_values3 = np.unique(mask1), np.unique(mask2), np.unique(mask3)
|
97
|
-
print(f"Unique values in mask 1: {unique_values1}, mask 2: {unique_values2}, mask 3: {unique_values3}")
|
98
|
-
visualize_masks(boundary_true1, boundary_true2, boundary_true3, title=f"Boundaries - {filename}")
|
99
|
-
|
100
|
-
boundary_f1_12, boundary_f1_13, boundary_f1_23 = boundary_f1_score(mask1, mask2), boundary_f1_score(mask1, mask3), boundary_f1_score(mask2, mask3)
|
101
|
-
|
102
|
-
if (np.unique(mask1).size == 1 and np.unique(mask1)[0] == 0) and \
|
103
|
-
(np.unique(mask2).size == 1 and np.unique(mask2)[0] == 0) and \
|
104
|
-
(np.unique(mask3).size == 1 and np.unique(mask3)[0] == 0):
|
105
|
-
continue
|
106
|
-
|
107
|
-
if verbose:
|
108
|
-
unique_values4, unique_values5, unique_values6 = np.unique(boundary_f1_12), np.unique(boundary_f1_13), np.unique(boundary_f1_23)
|
109
|
-
print(f"Unique values in boundary mask 1: {unique_values4}, mask 2: {unique_values5}, mask 3: {unique_values6}")
|
110
|
-
visualize_masks(mask1, mask2, mask3, title=filename)
|
111
|
-
|
112
|
-
jaccard12 = jaccard_index(mask1, mask2)
|
113
|
-
dice12 = dice_coefficient(mask1, mask2)
|
114
|
-
jaccard13 = jaccard_index(mask1, mask3)
|
115
|
-
dice13 = dice_coefficient(mask1, mask3)
|
116
|
-
jaccard23 = jaccard_index(mask2, mask3)
|
117
|
-
dice23 = dice_coefficient(mask2, mask3)
|
118
|
-
|
119
|
-
results.append({
|
120
|
-
f'filename': filename,
|
121
|
-
f'jaccard_{cond_1}_{cond_2}': jaccard12,
|
122
|
-
f'dice_{cond_1}_{cond_2}': dice12,
|
123
|
-
f'jaccard_{cond_1}_{cond_3}': jaccard13,
|
124
|
-
f'dice_{cond_1}_{cond_3}': dice13,
|
125
|
-
f'jaccard_{cond_2}_{cond_3}': jaccard23,
|
126
|
-
f'dice_{cond_2}_{cond_3}': dice23,
|
127
|
-
f'boundary_f1_{cond_1}_{cond_2}': boundary_f1_12,
|
128
|
-
f'boundary_f1_{cond_1}_{cond_3}': boundary_f1_13,
|
129
|
-
f'boundary_f1_{cond_2}_{cond_3}': boundary_f1_23,
|
130
|
-
f'average_precision_{cond_1}_{cond_2}': ap_scores[0],
|
131
|
-
f'average_precision_{cond_1}_{cond_3}': ap_scores[1]
|
132
|
-
})
|
133
|
-
else:
|
134
|
-
print(f'Cannot find {path1} or {path2} or {path3}')
|
135
|
-
fig = plot_comparison_results(results)
|
136
|
-
return results, fig
|
137
|
-
|
138
|
-
def generate_cp_masks(settings):
|
139
|
-
|
140
|
-
src = settings['src']
|
141
|
-
model_name = settings['model_name']
|
142
|
-
channels = settings['channels']
|
143
|
-
diameter = settings['diameter']
|
144
|
-
regex = '.tif'
|
145
|
-
#flow_threshold = 30
|
146
|
-
cellprob_threshold = settings['cellprob_threshold']
|
147
|
-
figuresize = 25
|
148
|
-
cmap = 'inferno'
|
149
|
-
verbose = settings['verbose']
|
150
|
-
plot = settings['plot']
|
151
|
-
save = settings['save']
|
152
|
-
custom_model = settings['custom_model']
|
153
|
-
signal_thresholds = 1000
|
154
|
-
normalize = settings['normalize']
|
155
|
-
resize = settings['resize']
|
156
|
-
target_height = settings['width_height'][1]
|
157
|
-
target_width = settings['width_height'][0]
|
158
|
-
rescale = settings['rescale']
|
159
|
-
resample = settings['resample']
|
160
|
-
net_avg = settings['net_avg']
|
161
|
-
invert = settings['invert']
|
162
|
-
circular = settings['circular']
|
163
|
-
percentiles = settings['percentiles']
|
164
|
-
overlay = settings['overlay']
|
165
|
-
grayscale = settings['grayscale']
|
166
|
-
flow_threshold = settings['flow_threshold']
|
167
|
-
batch_size = settings['batch_size']
|
168
|
-
|
169
|
-
dst = os.path.join(src,'masks')
|
170
|
-
os.makedirs(dst, exist_ok=True)
|
171
|
-
|
172
|
-
identify_masks(src, dst, model_name, channels, diameter, batch_size, flow_threshold, cellprob_threshold, figuresize, cmap, verbose, plot, save, custom_model, signal_thresholds, normalize, resize, target_height, target_width, rescale, resample, net_avg, invert, circular, percentiles, overlay, grayscale)
|
173
|
-
|
174
83
|
def train_cellpose(settings):
|
175
84
|
|
176
85
|
from .io import _load_normalized_images_and_labels, _load_images_and_labels
|
177
86
|
from .utils import resize_images_and_labels
|
178
87
|
|
179
88
|
img_src = settings['img_src']
|
180
|
-
mask_src=
|
181
|
-
|
182
|
-
model_name = settings
|
183
|
-
|
184
|
-
|
185
|
-
|
186
|
-
|
187
|
-
|
188
|
-
|
189
|
-
|
190
|
-
|
191
|
-
from_scratch = settings
|
192
|
-
diameter = settings
|
193
|
-
|
194
|
-
|
195
|
-
|
196
|
-
|
197
|
-
|
198
|
-
|
199
|
-
invert = settings['invert']
|
200
|
-
percentiles = settings['percentiles']
|
201
|
-
grayscale = settings['grayscale']
|
89
|
+
mask_src = os.path.join(img_src, 'masks')
|
90
|
+
|
91
|
+
model_name = settings.setdefault( 'model_name', '')
|
92
|
+
|
93
|
+
model_name = settings.setdefault('model_name', 'model_name')
|
94
|
+
|
95
|
+
model_type = settings.setdefault( 'model_type', 'cyto')
|
96
|
+
learning_rate = settings.setdefault( 'learning_rate', 0.01)
|
97
|
+
weight_decay = settings.setdefault( 'weight_decay', 1e-05)
|
98
|
+
batch_size = settings.setdefault( 'batch_size', 50)
|
99
|
+
n_epochs = settings.setdefault( 'n_epochs', 100)
|
100
|
+
from_scratch = settings.setdefault( 'from_scratch', False)
|
101
|
+
diameter = settings.setdefault( 'diameter', 40)
|
102
|
+
|
103
|
+
remove_background = settings.setdefault( 'remove_background', False)
|
104
|
+
background = settings.setdefault( 'background', 100)
|
105
|
+
Signal_to_noise = settings.setdefault( 'Signal_to_noise', 10)
|
106
|
+
verbose = settings.setdefault( 'verbose', False)
|
107
|
+
|
202
108
|
|
109
|
+
channels = settings.setdefault( 'channels', [0,0])
|
110
|
+
normalize = settings.setdefault( 'normalize', True)
|
111
|
+
percentiles = settings.setdefault( 'percentiles', None)
|
112
|
+
circular = settings.setdefault( 'circular', False)
|
113
|
+
invert = settings.setdefault( 'invert', False)
|
114
|
+
resize = settings.setdefault( 'resize', False)
|
115
|
+
|
116
|
+
if resize:
|
117
|
+
target_height = settings['width_height'][1]
|
118
|
+
target_width = settings['width_height'][0]
|
119
|
+
|
120
|
+
grayscale = settings.setdefault( 'grayscale', True)
|
121
|
+
rescale = settings.setdefault( 'channels', False)
|
122
|
+
test = settings.setdefault( 'test', False)
|
123
|
+
|
124
|
+
if test:
|
125
|
+
test_img_src = os.path.join(os.path.dirname(img_src), 'test')
|
126
|
+
test_mask_src = os.path.join(test_img_src, 'mask')
|
127
|
+
|
128
|
+
test_images, test_masks, test_image_names, test_mask_names = None,None,None,None,
|
203
129
|
print(settings)
|
204
130
|
|
205
131
|
if from_scratch:
|
206
132
|
model_name=f'scratch_{model_name}_{model_type}_e{n_epochs}_X{target_width}_Y{target_height}.CP_model'
|
207
133
|
else:
|
208
|
-
|
134
|
+
if resize:
|
135
|
+
model_name=f'{model_name}_{model_type}_e{n_epochs}_X{target_width}_Y{target_height}.CP_model'
|
136
|
+
else:
|
137
|
+
model_name=f'{model_name}_{model_type}_e{n_epochs}.CP_model'
|
209
138
|
|
210
139
|
model_save_path = os.path.join(mask_src, 'models', 'cellpose_model')
|
211
|
-
|
140
|
+
print(model_save_path)
|
141
|
+
os.makedirs(model_save_path, exist_ok=True)
|
212
142
|
|
213
143
|
settings_df = pd.DataFrame(list(settings.items()), columns=['Key', 'Value'])
|
214
144
|
settings_csv = os.path.join(model_save_path,f'{model_name}_settings.csv')
|
215
145
|
settings_df.to_csv(settings_csv, index=False)
|
216
146
|
|
217
|
-
if
|
218
|
-
|
219
|
-
|
220
|
-
else:
|
221
|
-
model = cp_models.CellposeModel(gpu=True, model_type=model_type, net_avg=False, diam_mean=diameter, pretrained_model=None)
|
222
|
-
if model_type !='cyto':
|
147
|
+
if from_scratch:
|
148
|
+
model = cp_models.CellposeModel(gpu=True, model_type=model_type, diam_mean=diameter, pretrained_model=None)
|
149
|
+
else:
|
223
150
|
model = cp_models.CellposeModel(gpu=True, model_type=model_type)
|
224
151
|
|
225
|
-
|
226
|
-
|
227
|
-
|
228
|
-
|
152
|
+
if normalize:
|
153
|
+
|
154
|
+
image_files = [os.path.join(img_src, f) for f in os.listdir(img_src) if f.endswith('.tif')]
|
155
|
+
label_files = [os.path.join(mask_src, f) for f in os.listdir(mask_src) if f.endswith('.tif')]
|
156
|
+
images, masks, image_names, mask_names = _load_normalized_images_and_labels(image_files, label_files, channels, percentiles, circular, invert, verbose, remove_background, background, Signal_to_noise)
|
229
157
|
images = [np.squeeze(img) if img.shape[-1] == 1 else img for img in images]
|
158
|
+
|
159
|
+
if test:
|
160
|
+
test_image_files = [os.path.join(test_img_src, f) for f in os.listdir(test_img_src) if f.endswith('.tif')]
|
161
|
+
test_label_files = [os.path.join(test_mask_src, f) for f in os.listdir(test_mask_src) if f.endswith('.tif')]
|
162
|
+
test_images, test_masks, test_image_names, test_mask_names = _load_normalized_images_and_labels(test_image_files, test_label_files, channels, percentiles, circular, invert, verbose, remove_background, background, Signal_to_noise)
|
163
|
+
test_images = [np.squeeze(img) if img.shape[-1] == 1 else img for img in test_images]
|
164
|
+
|
230
165
|
else:
|
231
166
|
images, masks, image_names, mask_names = _load_images_and_labels(img_src, mask_src, circular, invert)
|
232
167
|
images = [np.squeeze(img) if img.shape[-1] == 1 else img for img in images]
|
168
|
+
|
169
|
+
if test:
|
170
|
+
test_images, test_masks, test_image_names, test_mask_names = _load_images_and_labels(img_src=test_img_src, mask_src=test_mask_src, circular=circular, invert=invert)
|
171
|
+
test_images = [np.squeeze(img) if img.shape[-1] == 1 else img for img in test_images]
|
233
172
|
|
234
173
|
if resize:
|
235
174
|
images, masks = resize_images_and_labels(images, masks, target_height, target_width, show_example=True)
|
@@ -248,25 +187,41 @@ def train_cellpose(settings):
|
|
248
187
|
|
249
188
|
print(f'image shape: {images[0].shape}, image type: images[0].shape mask shape: {masks[0].shape}, image type: masks[0].shape')
|
250
189
|
save_every = int(n_epochs/10)
|
251
|
-
|
252
|
-
|
253
|
-
|
254
|
-
model.
|
255
|
-
|
256
|
-
|
257
|
-
|
258
|
-
|
259
|
-
|
260
|
-
|
261
|
-
|
262
|
-
|
263
|
-
|
264
|
-
|
265
|
-
|
266
|
-
|
267
|
-
|
268
|
-
|
269
|
-
|
190
|
+
if save_every < 10:
|
191
|
+
save_every = n_epochs
|
192
|
+
|
193
|
+
train.train_seg(model.net,
|
194
|
+
train_data=images,
|
195
|
+
train_labels=masks,
|
196
|
+
train_files=image_names,
|
197
|
+
train_labels_files=mask_names,
|
198
|
+
train_probs=None,
|
199
|
+
test_data=test_images,
|
200
|
+
test_labels=test_masks,
|
201
|
+
test_files=test_image_names,
|
202
|
+
test_labels_files=test_mask_names,
|
203
|
+
test_probs=None,
|
204
|
+
load_files=True,
|
205
|
+
batch_size=batch_size,
|
206
|
+
learning_rate=learning_rate,
|
207
|
+
n_epochs=n_epochs,
|
208
|
+
weight_decay=weight_decay,
|
209
|
+
momentum=0.9,
|
210
|
+
SGD=False,
|
211
|
+
channels=cp_channels,
|
212
|
+
channel_axis=None,
|
213
|
+
#rgb=False,
|
214
|
+
normalize=False,
|
215
|
+
compute_flows=False,
|
216
|
+
save_path=model_save_path,
|
217
|
+
save_every=save_every,
|
218
|
+
nimg_per_epoch=None,
|
219
|
+
nimg_test_per_epoch=None,
|
220
|
+
rescale=rescale,
|
221
|
+
#scale_range=None,
|
222
|
+
#bsize=224,
|
223
|
+
min_train_masks=1,
|
224
|
+
model_name=model_name)
|
270
225
|
|
271
226
|
return print(f"Model saved at: {model_save_path}/{model_name}")
|
272
227
|
|
@@ -831,15 +786,6 @@ def merge_pred_mes(src,
|
|
831
786
|
|
832
787
|
if verbose:
|
833
788
|
_plot_histograms_and_stats(df=joined_df)
|
834
|
-
|
835
|
-
#dv = joined_df.copy()
|
836
|
-
#if 'prc' not in dv.columns:
|
837
|
-
#dv['prc'] = dv['plate'] + '_' + dv['row'] + '_' + dv['col']
|
838
|
-
#dv = dv[['pred']].groupby('prc').mean()
|
839
|
-
#dv.set_index('prc', inplace=True)
|
840
|
-
|
841
|
-
#loc = '/mnt/data/CellVoyager/20x/tsg101/crispr_screen/all/measurements/dv.csv'
|
842
|
-
#dv.to_csv(loc, index=True, header=True, mode='w')
|
843
789
|
|
844
790
|
return joined_df
|
845
791
|
|
@@ -926,30 +872,38 @@ def annotate_results(pred_loc):
|
|
926
872
|
display(df)
|
927
873
|
return df
|
928
874
|
|
929
|
-
def generate_dataset(src,
|
875
|
+
def generate_dataset(src, file_metadata=None, experiment='TSG101_screen', sample=None):
|
930
876
|
|
931
|
-
from .utils import
|
932
|
-
|
933
|
-
db_path = os.path.join(src, 'measurements','measurements.db')
|
877
|
+
from .utils import initiate_counter, add_images_to_tar
|
878
|
+
|
879
|
+
db_path = os.path.join(src, 'measurements', 'measurements.db')
|
934
880
|
dst = os.path.join(src, 'datasets')
|
935
|
-
|
936
|
-
global total_images
|
937
881
|
all_paths = []
|
938
|
-
|
882
|
+
|
939
883
|
# Connect to the database and retrieve the image paths
|
940
884
|
print(f'Reading DataBase: {db_path}')
|
941
|
-
|
942
|
-
|
943
|
-
|
944
|
-
|
945
|
-
|
946
|
-
|
947
|
-
|
948
|
-
|
949
|
-
|
950
|
-
|
951
|
-
|
952
|
-
|
885
|
+
try:
|
886
|
+
with sqlite3.connect(db_path) as conn:
|
887
|
+
cursor = conn.cursor()
|
888
|
+
if file_metadata:
|
889
|
+
if isinstance(file_metadata, str):
|
890
|
+
cursor.execute("SELECT png_path FROM png_list WHERE png_path LIKE ?", (f"%{file_metadata}%",))
|
891
|
+
else:
|
892
|
+
cursor.execute("SELECT png_path FROM png_list")
|
893
|
+
|
894
|
+
while True:
|
895
|
+
rows = cursor.fetchmany(1000)
|
896
|
+
if not rows:
|
897
|
+
break
|
898
|
+
all_paths.extend([row[0] for row in rows])
|
899
|
+
|
900
|
+
except sqlite3.Error as e:
|
901
|
+
print(f"Database error: {e}")
|
902
|
+
return
|
903
|
+
except Exception as e:
|
904
|
+
print(f"Error: {e}")
|
905
|
+
return
|
906
|
+
|
953
907
|
if isinstance(sample, int):
|
954
908
|
selected_paths = random.sample(all_paths, sample)
|
955
909
|
print(f'Random selection of {len(selected_paths)} paths')
|
@@ -957,23 +911,18 @@ def generate_dataset(src, file_type=None, experiment='TSG101_screen', sample=Non
|
|
957
911
|
selected_paths = all_paths
|
958
912
|
random.shuffle(selected_paths)
|
959
913
|
print(f'All paths: {len(selected_paths)} paths')
|
960
|
-
|
914
|
+
|
961
915
|
total_images = len(selected_paths)
|
962
|
-
print(f'
|
963
|
-
|
916
|
+
print(f'Found {total_images} images')
|
917
|
+
|
964
918
|
# Create a temp folder in dst
|
965
919
|
temp_dir = os.path.join(dst, "temp_tars")
|
966
920
|
os.makedirs(temp_dir, exist_ok=True)
|
967
921
|
|
968
922
|
# Chunking the data
|
969
|
-
|
970
|
-
|
971
|
-
|
972
|
-
remainder = len(selected_paths) % num_procs
|
973
|
-
else:
|
974
|
-
num_procs = 2
|
975
|
-
chunk_size = len(selected_paths) // 2
|
976
|
-
remainder = 0
|
923
|
+
num_procs = max(2, cpu_count() - 2)
|
924
|
+
chunk_size = len(selected_paths) // num_procs
|
925
|
+
remainder = len(selected_paths) % num_procs
|
977
926
|
|
978
927
|
paths_chunks = []
|
979
928
|
start = 0
|
@@ -983,45 +932,43 @@ def generate_dataset(src, file_type=None, experiment='TSG101_screen', sample=Non
|
|
983
932
|
start = end
|
984
933
|
|
985
934
|
temp_tar_files = [os.path.join(temp_dir, f'temp_{i}.tar') for i in range(num_procs)]
|
986
|
-
|
987
|
-
# Initialize the shared objects
|
988
|
-
counter_ = Value('i', 0)
|
989
|
-
lock_ = Lock()
|
990
935
|
|
991
|
-
ctx = multiprocessing.get_context('spawn')
|
992
|
-
|
993
936
|
print(f'Generating temporary tar files in {dst}')
|
994
|
-
|
937
|
+
|
938
|
+
# Initialize shared counter and lock
|
939
|
+
counter = Value('i', 0)
|
940
|
+
lock = Lock()
|
941
|
+
|
942
|
+
with Pool(processes=num_procs, initializer=initiate_counter, initargs=(counter, lock)) as pool:
|
943
|
+
pool.starmap(add_images_to_tar, [(paths_chunks[i], temp_tar_files[i], total_images) for i in range(num_procs)])
|
944
|
+
|
995
945
|
# Combine the temporary tar files into a final tar
|
996
946
|
date_name = datetime.date.today().strftime('%y%m%d')
|
997
|
-
|
947
|
+
if not file_metadata is None:
|
948
|
+
tar_name = f'{date_name}_{experiment}_{file_metadata}.tar'
|
949
|
+
else:
|
950
|
+
tar_name = f'{date_name}_{experiment}.tar'
|
951
|
+
tar_name = os.path.join(dst, tar_name)
|
998
952
|
if os.path.exists(tar_name):
|
999
953
|
number = random.randint(1, 100)
|
1000
|
-
tar_name_2 = f'{date_name}_{experiment}_{
|
1001
|
-
print(f'Warning: {os.path.basename(tar_name)} exists saving as {os.path.basename(tar_name_2)} ')
|
1002
|
-
tar_name = tar_name_2
|
1003
|
-
|
1004
|
-
# Add the counter and lock to the arguments for pool.map
|
954
|
+
tar_name_2 = f'{date_name}_{experiment}_{file_metadata}_{number}.tar'
|
955
|
+
print(f'Warning: {os.path.basename(tar_name)} exists, saving as {os.path.basename(tar_name_2)} ')
|
956
|
+
tar_name = os.path.join(dst, tar_name_2)
|
957
|
+
|
1005
958
|
print(f'Merging temporary files')
|
1006
|
-
#with Pool(processes=num_procs, initializer=init_globals, initargs=(counter_, lock_)) as pool:
|
1007
|
-
# results = pool.map(add_images_to_tar, zip(paths_chunks, temp_tar_files))
|
1008
959
|
|
1009
|
-
with
|
1010
|
-
|
1011
|
-
|
1012
|
-
|
1013
|
-
|
1014
|
-
|
1015
|
-
|
1016
|
-
t.extract(member, path=dst)
|
1017
|
-
final_tar.add(os.path.join(dst, member.name), arcname=member.name)
|
1018
|
-
os.remove(os.path.join(dst, member.name))
|
1019
|
-
os.remove(tar_path)
|
960
|
+
with tarfile.open(tar_name, 'w') as final_tar:
|
961
|
+
for temp_tar_path in temp_tar_files:
|
962
|
+
with tarfile.open(temp_tar_path, 'r') as temp_tar:
|
963
|
+
for member in temp_tar.getmembers():
|
964
|
+
file_obj = temp_tar.extractfile(member)
|
965
|
+
final_tar.addfile(member, file_obj)
|
966
|
+
os.remove(temp_tar_path)
|
1020
967
|
|
1021
968
|
# Delete the temp folder
|
1022
969
|
shutil.rmtree(temp_dir)
|
1023
|
-
print(f"\nSaved {total_images} images to {
|
1024
|
-
|
970
|
+
print(f"\nSaved {total_images} images to {tar_name}")
|
971
|
+
|
1025
972
|
def apply_model_to_tar(tar_path, model_path, file_type='cell_png', image_size=224, batch_size=64, normalize=True, preload='images', num_workers=10, verbose=False):
|
1026
973
|
|
1027
974
|
from .io import TarImageDataset, DataLoader
|
@@ -1128,7 +1075,6 @@ def apply_model(src, model_path, image_size=224, batch_size=64, normalize=True,
|
|
1128
1075
|
torch.cuda.memory.empty_cache()
|
1129
1076
|
return df
|
1130
1077
|
|
1131
|
-
|
1132
1078
|
def generate_training_data_file_list(src,
|
1133
1079
|
target='protein of interest',
|
1134
1080
|
cell_dim=4,
|
@@ -1257,7 +1203,14 @@ def generate_training_dataset(src, mode='annotation', annotation_column='test',
|
|
1257
1203
|
|
1258
1204
|
db_path = os.path.join(src, 'measurements','measurements.db')
|
1259
1205
|
dst = os.path.join(src, 'datasets', 'training')
|
1260
|
-
|
1206
|
+
|
1207
|
+
if os.path.exists(dst):
|
1208
|
+
for i in range(1, 1000):
|
1209
|
+
dst = os.path.join(src, 'datasets', f'training_{i}')
|
1210
|
+
if not os.path.exists(dst):
|
1211
|
+
print(f'Creating new directory for training: {dst}')
|
1212
|
+
break
|
1213
|
+
|
1261
1214
|
if mode == 'annotation':
|
1262
1215
|
class_paths_ls_2 = []
|
1263
1216
|
class_paths_ls = training_dataset_from_annotation(db_path, dst, annotation_column, annotated_classes=annotated_classes)
|
@@ -1268,6 +1221,7 @@ def generate_training_dataset(src, mode='annotation', annotation_column='test',
|
|
1268
1221
|
|
1269
1222
|
elif mode == 'metadata':
|
1270
1223
|
class_paths_ls = []
|
1224
|
+
class_len_ls = []
|
1271
1225
|
[df] = _read_db(db_loc=db_path, tables=['png_list'])
|
1272
1226
|
df['metadata_based_class'] = pd.NA
|
1273
1227
|
for i, class_ in enumerate(classes):
|
@@ -1275,7 +1229,18 @@ def generate_training_dataset(src, mode='annotation', annotation_column='test',
|
|
1275
1229
|
df.loc[df[metadata_type_by].isin(ls), 'metadata_based_class'] = class_
|
1276
1230
|
|
1277
1231
|
for class_ in classes:
|
1232
|
+
if size == None:
|
1233
|
+
c_s = []
|
1234
|
+
for c in classes:
|
1235
|
+
c_s_t_df = df[df['metadata_based_class'] == c]
|
1236
|
+
c_s.append(len(c_s_t_df))
|
1237
|
+
print(f'Found {len(c_s_t_df)} images for class {c}')
|
1238
|
+
size = min(c_s)
|
1239
|
+
print(f'Using the smallest class size: {size}')
|
1240
|
+
|
1278
1241
|
class_temp_df = df[df['metadata_based_class'] == class_]
|
1242
|
+
class_len_ls.append(len(class_temp_df))
|
1243
|
+
print(f'Found {len(class_temp_df)} images for class {class_}')
|
1279
1244
|
class_paths_temp = random.sample(class_temp_df['png_path'].tolist(), size)
|
1280
1245
|
class_paths_ls.append(class_paths_temp)
|
1281
1246
|
|
@@ -1332,7 +1297,8 @@ def generate_training_dataset(src, mode='annotation', annotation_column='test',
|
|
1332
1297
|
|
1333
1298
|
return
|
1334
1299
|
|
1335
|
-
def generate_loaders(src, train_mode='erm', mode='train', image_size=224, batch_size=32, classes=['nc','pc'], num_workers=None, validation_split=0.0, max_show=2, pin_memory=False, normalize=False, verbose=False):
|
1300
|
+
def generate_loaders(src, train_mode='erm', mode='train', image_size=224, batch_size=32, classes=['nc','pc'], num_workers=None, validation_split=0.0, max_show=2, pin_memory=False, normalize=False, channels=[1, 2, 3], verbose=False):
|
1301
|
+
|
1336
1302
|
"""
|
1337
1303
|
Generate data loaders for training and validation/test datasets.
|
1338
1304
|
|
@@ -1349,16 +1315,40 @@ def generate_loaders(src, train_mode='erm', mode='train', image_size=224, batch_
|
|
1349
1315
|
- pin_memory (bool): Whether to pin memory for faster data transfer.
|
1350
1316
|
- normalize (bool): Whether to normalize the input images.
|
1351
1317
|
- verbose (bool): Whether to print additional information and show images.
|
1318
|
+
- channels (list): The list of channels to retain. Options are [1, 2, 3] for all channels, [1, 2] for blue and green, etc.
|
1352
1319
|
|
1353
1320
|
Returns:
|
1354
1321
|
- train_loaders (list): List of data loaders for training datasets.
|
1355
1322
|
- val_loaders (list): List of data loaders for validation datasets.
|
1356
1323
|
- plate_names (list): List of plate names (only applicable when train_mode is 'irm').
|
1357
1324
|
"""
|
1358
|
-
|
1325
|
+
|
1359
1326
|
from .io import MyDataset
|
1360
1327
|
from .plot import _imshow
|
1361
|
-
|
1328
|
+
from torchvision import transforms
|
1329
|
+
from torch.utils.data import DataLoader, random_split
|
1330
|
+
from collections import defaultdict
|
1331
|
+
import os
|
1332
|
+
import random
|
1333
|
+
from PIL import Image
|
1334
|
+
from torchvision.transforms import ToTensor
|
1335
|
+
from .utils import SelectChannels
|
1336
|
+
|
1337
|
+
chans = []
|
1338
|
+
|
1339
|
+
if 'r' in channels:
|
1340
|
+
chans.append(1)
|
1341
|
+
if 'g' in channels:
|
1342
|
+
chans.append(2)
|
1343
|
+
if 'b' in channels:
|
1344
|
+
chans.append(3)
|
1345
|
+
|
1346
|
+
channels = chans
|
1347
|
+
|
1348
|
+
if verbose:
|
1349
|
+
print(f'Training a network on channels: {channels}')
|
1350
|
+
print(f'Channel 1: Red, Channel 2: Green, Channel 3: Blue')
|
1351
|
+
|
1362
1352
|
plate_to_filenames = defaultdict(list)
|
1363
1353
|
plate_to_labels = defaultdict(list)
|
1364
1354
|
train_loaders = []
|
@@ -1369,31 +1359,30 @@ def generate_loaders(src, train_mode='erm', mode='train', image_size=224, batch_
|
|
1369
1359
|
transform = transforms.Compose([
|
1370
1360
|
transforms.ToTensor(),
|
1371
1361
|
transforms.CenterCrop(size=(image_size, image_size)),
|
1362
|
+
SelectChannels(channels),
|
1372
1363
|
transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])
|
1373
1364
|
else:
|
1374
1365
|
transform = transforms.Compose([
|
1375
1366
|
transforms.ToTensor(),
|
1376
|
-
transforms.CenterCrop(size=(image_size, image_size))
|
1377
|
-
|
1367
|
+
transforms.CenterCrop(size=(image_size, image_size)),
|
1368
|
+
SelectChannels(channels)])
|
1369
|
+
|
1378
1370
|
if mode == 'train':
|
1379
1371
|
data_dir = os.path.join(src, 'train')
|
1380
1372
|
shuffle = True
|
1381
|
-
print(
|
1382
|
-
|
1373
|
+
print('Generating Train and validation datasets')
|
1383
1374
|
elif mode == 'test':
|
1384
1375
|
data_dir = os.path.join(src, 'test')
|
1385
1376
|
val_loaders = []
|
1386
|
-
validation_split=0.0
|
1377
|
+
validation_split = 0.0
|
1387
1378
|
shuffle = True
|
1388
|
-
print(
|
1389
|
-
|
1379
|
+
print('Generating test dataset')
|
1390
1380
|
else:
|
1391
1381
|
print(f'mode:{mode} is not valid, use mode = train or test')
|
1392
1382
|
return
|
1393
|
-
|
1383
|
+
|
1394
1384
|
if train_mode == 'erm':
|
1395
1385
|
data = MyDataset(data_dir, classes, transform=transform, shuffle=shuffle, pin_memory=pin_memory)
|
1396
|
-
#train_loaders = DataLoader(data, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers if num_workers is not None else 0, pin_memory=pin_memory)
|
1397
1386
|
if validation_split > 0:
|
1398
1387
|
train_size = int((1 - validation_split) * len(data))
|
1399
1388
|
val_size = len(data) - train_size
|
@@ -1450,7 +1439,6 @@ def generate_loaders(src, train_mode='erm', mode='train', image_size=224, batch_
|
|
1450
1439
|
images = images.cpu()
|
1451
1440
|
label_strings = [str(label.item()) for label in labels]
|
1452
1441
|
_imshow(images, label_strings, nrow=20, fontsize=12)
|
1453
|
-
|
1454
1442
|
elif train_mode == 'irm':
|
1455
1443
|
for plate_name, train_loader in zip(plate_names, train_loaders):
|
1456
1444
|
print(f'Plate: {plate_name} with {len(train_loader.dataset)} images')
|
@@ -1569,15 +1557,30 @@ def analyze_recruitment(src, metadata_settings, advanced_settings):
|
|
1569
1557
|
df = df.dropna(subset=['condition'])
|
1570
1558
|
print(f'After dropping non-annotated wells: {len(df)} rows')
|
1571
1559
|
files = df['file_name'].tolist()
|
1560
|
+
print(f'found: {len(files)} files')
|
1572
1561
|
files = [item + '.npy' for item in files]
|
1573
1562
|
random.shuffle(files)
|
1574
|
-
|
1563
|
+
|
1564
|
+
_max = 10**100
|
1565
|
+
|
1566
|
+
if cell_size_range is None and nucleus_size_range is None and pathogen_size_range is None:
|
1567
|
+
filter_min_max = None
|
1568
|
+
else:
|
1569
|
+
if cell_size_range is None:
|
1570
|
+
cell_size_range = [0,_max]
|
1571
|
+
if nucleus_size_range is None:
|
1572
|
+
nucleus_size_range = [0,_max]
|
1573
|
+
if pathogen_size_range is None:
|
1574
|
+
pathogen_size_range = [0,_max]
|
1575
|
+
|
1576
|
+
filter_min_max = [[cell_size_range[0],cell_size_range[1]],[nucleus_size_range[0],nucleus_size_range[1]],[pathogen_size_range[0],pathogen_size_range[1]]]
|
1577
|
+
|
1575
1578
|
if plot:
|
1576
1579
|
plot_settings = {'include_noninfected':include_noninfected,
|
1577
1580
|
'include_multiinfected':include_multiinfected,
|
1578
1581
|
'include_multinucleated':include_multinucleated,
|
1579
1582
|
'remove_background':remove_background,
|
1580
|
-
'filter_min_max':
|
1583
|
+
'filter_min_max':filter_min_max,
|
1581
1584
|
'channel_dims':channel_dims,
|
1582
1585
|
'backgrounds':backgrounds,
|
1583
1586
|
'cell_mask_dim':mask_dims[0],
|
@@ -1634,31 +1637,225 @@ def analyze_recruitment(src, metadata_settings, advanced_settings):
|
|
1634
1637
|
cells,wells = _results_to_csv(src, df, df_well)
|
1635
1638
|
return [cells,wells]
|
1636
1639
|
|
1640
|
+
def _merge_cells_based_on_parasite_overlap(parasite_mask, cell_mask, nuclei_mask, overlap_threshold=5, perimeter_threshold=30):
    """
    Merge over-segmented cells in ``cell_mask`` and return a relabeled mask.

    Three passes are applied, in order:

    1. Parasite overlap: when one parasite object overlaps more than one cell and
       at least one of those cells covers more than ``overlap_threshold`` percent
       of the parasite's area, all cells overlapping that parasite are merged.
    2. Nucleus overlap: when one nucleus object overlaps more than one cell and
       every one of those cells covers more than ``overlap_threshold`` percent of
       the nucleus' area, those cells are merged.
    3. Shared perimeter: a cell containing no nucleus pixel is merged with the
       neighbouring cell it shares the largest border with, if that border exceeds
       ``perimeter_threshold`` percent of the cell's perimeter.

    All merges are done in one consistent label space (the output of
    ``label(cell_mask)``), so a merged label can never collide with an unrelated
    cell. The input arrays are not modified.

    Args:
        parasite_mask (ndarray): Mask of parasites.
        cell_mask (ndarray): Mask of cells.
        nuclei_mask (ndarray): Mask of nuclei.
        overlap_threshold (float): Percentage threshold for merging cells based on parasite/nucleus overlap.
        perimeter_threshold (float): Percentage threshold for merging cells based on shared perimeter.

    Returns:
        ndarray: The merged cell mask, relabeled with ``label``.
    """
    labeled_cells = label(cell_mask)
    labeled_parasites = label(parasite_mask)
    labeled_nuclei = label(nuclei_mask)

    # Working copy: every lookup and assignment in passes 1-2 happens here, so the
    # caller's cell_mask is untouched and labels stay in a single label space.
    merged = labeled_cells.copy()

    def _merge_by_object_overlap(labeled_objects, require_all):
        """Merge the cells of ``merged`` that jointly overlap one labeled object."""
        for object_id in range(1, int(np.max(labeled_objects)) + 1):
            object_pixels = labeled_objects == object_id
            object_area = np.count_nonzero(object_pixels)
            if object_area == 0:
                continue
            cells_under = merged[object_pixels]
            cell_ids, counts = np.unique(cells_under[cells_under != 0], return_counts=True)
            if len(cell_ids) < 2:
                continue
            # Percentage of the object's area covered by each overlapping cell.
            above = counts / object_area * 100 > overlap_threshold
            should_merge = above.all() if require_all else above.any()
            if should_merge:
                merged[np.isin(merged, cell_ids[1:])] = cell_ids[0]

    # Pass 1: a parasite spanning several cells -> merge if any overlap is large enough.
    _merge_by_object_overlap(labeled_parasites, require_all=False)
    # Pass 2: a nucleus spanning several cells -> merge only if all overlaps are large enough.
    _merge_by_object_overlap(labeled_nuclei, require_all=True)

    # Pass 3: merge nucleus-free cells into the neighbour sharing the largest border.
    labeled_cells = label(merged)  # re-label after the overlap-based merges
    n_labels = int(np.max(labeled_cells))

    # Union-find over cell labels so chains (A->B->C) and mutual merges (A<->B)
    # collapse into one cell instead of swapping or orphaning labels.
    parent = list(range(n_labels + 1))

    def _find(x):
        while parent[x] != x:
            parent[x] = parent[parent[x]]
            x = parent[x]
        return x

    for region in regionprops(labeled_cells):
        cell_label = region.label
        cell_pixels = labeled_cells == cell_label
        if np.any(nuclei_mask[cell_pixels] != 0):
            continue  # cell contains (part of) a nucleus: keep as is

        # Dilate the cell by one pixel to find touching neighbours.
        dilated_cell = binary_dilation(cell_pixels, structure=square(3))
        neighbor_cells = np.unique(labeled_cells[dilated_cell])
        neighbor_cells = neighbor_cells[(neighbor_cells != 0) & (neighbor_cells != cell_label)]
        if len(neighbor_cells) == 0:
            continue

        # Shared border length = neighbour pixels inside the dilated cell.
        shared_borders = np.array([
            np.count_nonzero((labeled_cells == neighbor_label) & dilated_cell)
            for neighbor_label in neighbor_cells
        ])
        best = int(np.argmax(shared_borders))
        perimeter = region.perimeter
        # A zero perimeter (e.g. single-pixel fragment) always qualifies, without a divide-by-zero.
        shared_percentage = shared_borders[best] / perimeter * 100 if perimeter > 0 else np.inf
        if shared_percentage > perimeter_threshold:
            parent[_find(cell_label)] = _find(int(neighbor_cells[best]))

    # Map every label to its merge-set representative (background 0 maps to itself).
    lookup = np.array([_find(i) for i in range(n_labels + 1)], dtype=labeled_cells.dtype)
    return label(lookup[labeled_cells])
|
1730
|
+
|
1731
|
+
def adjust_cell_masks(parasite_folder, cell_folder, nuclei_folder, overlap_threshold=5, perimeter_threshold=30):
    """
    Merge and relabel the cell masks of a folder, overwriting them in place.

    Every ``.npy`` file in ``parasite_folder`` is paired by file name with the
    same-named files in ``cell_folder`` and ``nuclei_folder``. Each pairing is
    passed to ``_merge_cells_based_on_parasite_overlap``, and the result replaces
    the cell mask file.

    All pairings are validated before any file is written. A missing counterpart
    therefore cannot leave the cell folder half-rewritten.

    Args:
        parasite_folder (str): Path to the folder containing parasite masks.
        cell_folder (str): Path to the folder containing cell masks (overwritten).
        nuclei_folder (str): Path to the folder containing nuclei masks.
        overlap_threshold (float): The percentage threshold for merging cells based on parasite overlap.
        perimeter_threshold (float): The percentage threshold for merging cells based on shared perimeter.

    Raises:
        ValueError: If the folders hold different numbers of ``.npy`` files, or a
            parasite mask has no same-named cell or nuclei mask.
    """
    parasite_files = sorted(f for f in os.listdir(parasite_folder) if f.endswith('.npy'))
    cell_files = sorted(f for f in os.listdir(cell_folder) if f.endswith('.npy'))
    nuclei_files = sorted(f for f in os.listdir(nuclei_folder) if f.endswith('.npy'))

    # Ensure there are matching files in all folders
    if not (len(parasite_files) == len(cell_files) == len(nuclei_files)):
        raise ValueError("The number of files in the folders do not match.")

    # Validate every pairing up front, before any cell mask is overwritten.
    available_cells = set(cell_files)
    available_nuclei = set(nuclei_files)
    for file_name in parasite_files:
        if file_name not in available_cells or file_name not in available_nuclei:
            raise ValueError(f"Corresponding cell or nuclei mask file for {file_name} not found.")

    for file_name in parasite_files:
        cell_path = os.path.join(cell_folder, file_name)
        parasite_mask = np.load(os.path.join(parasite_folder, file_name))
        cell_mask = np.load(cell_path)
        nuclei_mask = np.load(os.path.join(nuclei_folder, file_name))

        merged_cell_mask = _merge_cells_based_on_parasite_overlap(parasite_mask, cell_mask, nuclei_mask, overlap_threshold, perimeter_threshold)
        # Overwrite the original cell mask file with the merged result
        np.save(cell_path, merged_cell_mask)
|
1768
|
+
|
1769
|
+
def process_masks(mask_folder, image_folder, channel, batch_size=50, n_clusters=2, plot=False):
    """
    Cluster segmented objects by shape/intensity and keep only the largest cluster.

    Each ``.npy`` mask in ``mask_folder`` is paired by file name with the ``.npy``
    image of the same name in ``image_folder``. For every labeled object the
    following are measured:

    - area
    - mean intensity in ``image[:, :, channel]``
    - perimeter
    - eccentricity

    All objects from all files are clustered together with KMeans. Objects outside
    the most populated cluster are then removed, and the mask files are overwritten
    in place.

    Args:
        mask_folder (str): Folder with labeled ``.npy`` masks (overwritten).
        image_folder (str): Folder with same-named ``.npy`` images.
        channel (int): Image channel used for the intensity measurement.
        batch_size (int): Number of files listed/processed per batch.
        n_clusters (int): Number of KMeans clusters.
        plot (bool): If True, show a 2-D PCA projection of the clustered objects.
    """

    def read_files_in_batches(folder, batch_size=50):
        files = sorted(f for f in os.listdir(folder) if f.endswith('.npy'))  # sorted: same order in both passes
        for i in range(0, len(files), batch_size):
            yield files[i:i + batch_size]

    def measure_morphology_and_intensity(mask, image):
        properties = measure.regionprops(mask, intensity_image=image)
        return [{'area': p.area, 'mean_intensity': p.mean_intensity, 'perimeter': p.perimeter, 'eccentricity': p.eccentricity} for p in properties]

    def feature_matrix(properties):
        return np.array([[p['area'], p['mean_intensity'], p['perimeter'], p['eccentricity']] for p in properties])

    def cluster_objects(properties, n_clusters=2):
        return KMeans(n_clusters=n_clusters, random_state=0).fit(feature_matrix(properties))

    def plot_clusters(properties, labels):
        data_2d = PCA(n_components=2).fit_transform(feature_matrix(properties))
        plt.scatter(data_2d[:, 0], data_2d[:, 1], c=labels, cmap='viridis')
        plt.xlabel('PCA Component 1')
        plt.ylabel('PCA Component 2')
        plt.title('Object Clustering')
        plt.show()

    # Step 1: Accumulate properties over all files
    all_properties = []
    for batch in read_files_in_batches(mask_folder, batch_size):
        for file in batch:
            mask = np.load(os.path.join(mask_folder, file))
            image = np.load(os.path.join(image_folder, file))[:, :, channel]
            all_properties.extend(measure_morphology_and_intensity(mask, image))

    if not all_properties:
        # Nothing to cluster; KMeans would raise on an empty feature matrix.
        print(f'No objects found in {mask_folder}; masks left unchanged.')
        return

    # Step 2: Perform clustering on accumulated properties
    labels = cluster_objects(all_properties, n_clusters).labels_

    if plot:
        # Step 3: Plot clusters using PCA
        plot_clusters(all_properties, labels)

    # The cluster to keep is the most populated one over ALL objects, so the same
    # cluster is kept in every image.
    largest_cluster_label = np.bincount(labels).argmax()

    # Step 4: Remove objects not in the largest cluster and overwrite files in batches
    label_index = 0
    for batch in read_files_in_batches(mask_folder, batch_size):
        for file in batch:
            mask_path = os.path.join(mask_folder, file)
            mask = np.load(mask_path)
            # regionprops yields objects in ascending label order, the same order used
            # in step 1, so the i-th region pairs with the i-th cluster label even when
            # label values are not consecutive.
            regions = measure.regionprops(mask)
            object_clusters = labels[label_index:label_index + len(regions)]
            label_index += len(regions)
            keep = [region.label for region, cluster in zip(regions, object_clusters) if cluster == largest_cluster_label]
            cleaned_mask = np.where(np.isin(mask, keep), mask, 0).astype(mask.dtype, copy=False)
            np.save(mask_path, cleaned_mask)
|
1841
|
+
|
1637
1842
|
def preprocess_generate_masks(src, settings={}):
|
1638
1843
|
|
1639
1844
|
from .io import preprocess_img_data, _load_and_concatenate_arrays
|
1640
1845
|
from .plot import plot_merged, plot_arrays
|
1641
|
-
from .utils import _pivot_counts_table
|
1642
|
-
|
1643
|
-
settings
|
1644
|
-
settings['remove_background'] = True
|
1645
|
-
settings['lower_quantile'] = 0.02
|
1646
|
-
settings['merge'] = False
|
1647
|
-
settings['normalize_plots'] = True
|
1648
|
-
settings['all_to_mip'] = False
|
1649
|
-
settings['pick_slice'] = False
|
1650
|
-
settings['skip_mode'] = src
|
1651
|
-
settings['workers'] = os.cpu_count()-4
|
1652
|
-
settings['verbose'] = True
|
1653
|
-
settings['examples_to_plot'] = 1
|
1654
|
-
settings['src'] = src
|
1655
|
-
settings['upscale'] = False
|
1656
|
-
settings['upscale_factor'] = 2.0
|
1846
|
+
from .utils import _pivot_counts_table, set_default_settings_preprocess_generate_masks, set_default_plot_merge_settings, check_mask_folder
|
1847
|
+
|
1848
|
+
settings = set_default_settings_preprocess_generate_masks(src, settings)
|
1657
1849
|
|
1658
1850
|
settings_df = pd.DataFrame(list(settings.items()), columns=['Key', 'Value'])
|
1659
1851
|
settings_csv = os.path.join(src,'settings','preprocess_generate_masks_settings.csv')
|
1660
1852
|
os.makedirs(os.path.join(src,'settings'), exist_ok=True)
|
1661
1853
|
settings_df.to_csv(settings_csv, index=False)
|
1854
|
+
|
1855
|
+
if not settings['pathogen_channel'] is None:
|
1856
|
+
custom_model_ls = ['toxo_pv_lumen','toxo_cyto']
|
1857
|
+
if settings['pathogen_model'] not in custom_model_ls:
|
1858
|
+
ValueError(f'Pathogen model must be {custom_model_ls} or None')
|
1662
1859
|
|
1663
1860
|
if settings['timelapse']:
|
1664
1861
|
settings['randomize'] = False
|
@@ -1667,24 +1864,50 @@ def preprocess_generate_masks(src, settings={}):
|
|
1667
1864
|
if not settings['masks']:
|
1668
1865
|
print(f'WARNING: channels for mask generation are defined when preprocess = True')
|
1669
1866
|
|
1670
|
-
if isinstance(settings['merge'], bool):
|
1671
|
-
settings['merge'] = [settings['merge']]*3
|
1672
1867
|
if isinstance(settings['save'], bool):
|
1673
1868
|
settings['save'] = [settings['save']]*3
|
1674
1869
|
|
1870
|
+
if settings['verbose']:
|
1871
|
+
settings_df = pd.DataFrame(list(settings.items()), columns=['setting_key', 'setting_value'])
|
1872
|
+
settings_df['setting_value'] = settings_df['setting_value'].apply(str)
|
1873
|
+
display(settings_df)
|
1874
|
+
|
1875
|
+
if settings['test_mode']:
|
1876
|
+
print(f'Starting Test mode ...')
|
1877
|
+
|
1675
1878
|
if settings['preprocess']:
|
1676
1879
|
settings, src = preprocess_img_data(settings)
|
1677
1880
|
|
1678
1881
|
if settings['masks']:
|
1679
1882
|
mask_src = os.path.join(src, 'norm_channel_stack')
|
1680
1883
|
if settings['cell_channel'] != None:
|
1681
|
-
|
1884
|
+
if check_mask_folder(src, 'cell_mask_stack'):
|
1885
|
+
generate_cellpose_masks(mask_src, settings, 'cell')
|
1682
1886
|
|
1683
1887
|
if settings['nucleus_channel'] != None:
|
1684
|
-
|
1888
|
+
if check_mask_folder(src, 'nucleus_mask_stack'):
|
1889
|
+
generate_cellpose_masks(mask_src, settings, 'nucleus')
|
1685
1890
|
|
1686
1891
|
if settings['pathogen_channel'] != None:
|
1687
|
-
|
1892
|
+
if check_mask_folder(src, 'pathogen_mask_stack'):
|
1893
|
+
generate_cellpose_masks(mask_src, settings, 'pathogen')
|
1894
|
+
|
1895
|
+
if settings['adjust_cells']:
|
1896
|
+
if settings['pathogen_channel'] != None and settings['cell_channel'] != None and settings['nucleus_channel'] != None:
|
1897
|
+
|
1898
|
+
start = time.time()
|
1899
|
+
cell_folder = os.path.join(mask_src, 'cell_mask_stack')
|
1900
|
+
nuclei_folder = os.path.join(mask_src, 'nucleus_mask_stack')
|
1901
|
+
parasite_folder = os.path.join(mask_src, 'pathogen_mask_stack')
|
1902
|
+
#image_folder = os.path.join(src, 'stack')
|
1903
|
+
|
1904
|
+
#process_masks(cell_folder, image_folder, settings['cell_channel'], settings['batch_size'], n_clusters=2, plot=settings['plot'])
|
1905
|
+
#process_masks(nuclei_folder, image_folder, settings['nucleus_channel'], settings['batch_size'], n_clusters=2, plot=settings['plot'])
|
1906
|
+
#process_masks(parasite_folder, image_folder, settings['pathogen_channel'], settings['batch_size'], n_clusters=2, plot=settings['plot'])
|
1907
|
+
|
1908
|
+
adjust_cell_masks(parasite_folder, cell_folder, nuclei_folder, overlap_threshold=5, perimeter_threshold=30)
|
1909
|
+
stop = time.time()
|
1910
|
+
print(f'Cell mask adjustment: {stop-start} seconds')
|
1688
1911
|
|
1689
1912
|
if os.path.exists(os.path.join(src,'measurements')):
|
1690
1913
|
_pivot_counts_table(db_path=os.path.join(src,'measurements', 'measurements.db'))
|
@@ -1713,60 +1936,110 @@ def preprocess_generate_masks(src, settings={}):
|
|
1713
1936
|
overlay_channels = [settings['nucleus_channel'], settings['pathogen_channel'], settings['cell_channel']]
|
1714
1937
|
overlay_channels = [element for element in overlay_channels if element is not None]
|
1715
1938
|
|
1716
|
-
plot_settings =
|
1717
|
-
|
1718
|
-
|
1719
|
-
|
1720
|
-
|
1721
|
-
|
1722
|
-
|
1723
|
-
|
1724
|
-
|
1725
|
-
|
1726
|
-
|
1727
|
-
'outline_thickness':3,
|
1728
|
-
'outline_color':'gbr',
|
1729
|
-
'overlay_chans':overlay_channels,
|
1730
|
-
'overlay':True,
|
1731
|
-
'normalization_percentiles':[1,99],
|
1732
|
-
'normalize':True,
|
1733
|
-
'print_object_number':True,
|
1734
|
-
'nr':settings['examples_to_plot'],
|
1735
|
-
'figuresize':20,
|
1736
|
-
'cmap':'inferno',
|
1737
|
-
'verbose':False}
|
1939
|
+
plot_settings = set_default_plot_merge_settings()
|
1940
|
+
plot_settings['channel_dims'] = settings['channels']
|
1941
|
+
plot_settings['cell_mask_dim'] = cell_mask_dim
|
1942
|
+
plot_settings['nucleus_mask_dim'] = nucleus_mask_dim
|
1943
|
+
plot_settings['pathogen_mask_dim'] = pathogen_mask_dim
|
1944
|
+
plot_settings['overlay_chans'] = overlay_channels
|
1945
|
+
plot_settings['nr'] = settings['examples_to_plot']
|
1946
|
+
|
1947
|
+
if settings['test_mode'] == True:
|
1948
|
+
plot_settings['nr'] = len(os.path.join(src,'merged'))
|
1949
|
+
|
1738
1950
|
try:
|
1739
1951
|
fig = plot_merged(src=os.path.join(src,'merged'), settings=plot_settings)
|
1740
1952
|
except Exception as e:
|
1741
1953
|
print(f'Failed to plot image mask overly. Error: {e}')
|
1742
1954
|
else:
|
1743
|
-
plot_arrays(src=os.path.join(src,'merged'), figuresize=
|
1955
|
+
plot_arrays(src=os.path.join(src,'merged'), figuresize=settings['figuresize'], cmap=settings['cmap'], nr=settings['examples_to_plot'], normalize=settings['normalize'], q1=1, q2=99)
|
1744
1956
|
|
1745
1957
|
torch.cuda.empty_cache()
|
1746
1958
|
gc.collect()
|
1747
1959
|
print("Successfully completed run")
|
1748
1960
|
return
|
1749
1961
|
|
1750
|
-
def identify_masks_finetune(
|
1962
|
+
def identify_masks_finetune(settings):
|
1751
1963
|
|
1752
1964
|
from .plot import print_mask_and_flows
|
1753
1965
|
from .utils import get_files_from_dir, resize_images_and_labels
|
1754
1966
|
from .io import _load_normalized_images_and_labels, _load_images_and_labels
|
1755
1967
|
|
1968
|
+
#User defined settings
|
1969
|
+
src=settings['src']
|
1970
|
+
dst=settings['dst']
|
1971
|
+
|
1972
|
+
|
1973
|
+
settings.setdefault('model_name', 'cyto')
|
1974
|
+
settings.setdefault('custom_model', None)
|
1975
|
+
settings.setdefault('channels', [0,0])
|
1976
|
+
settings.setdefault('background', 100)
|
1977
|
+
settings.setdefault('remove_background', False)
|
1978
|
+
settings.setdefault('Signal_to_noise', 10)
|
1979
|
+
settings.setdefault('CP_prob', 0)
|
1980
|
+
settings.setdefault('diameter', 30)
|
1981
|
+
settings.setdefault('batch_size', 50)
|
1982
|
+
settings.setdefault('flow_threshold', 0.4)
|
1983
|
+
settings.setdefault('save', False)
|
1984
|
+
settings.setdefault('verbose', False)
|
1985
|
+
settings.setdefault('normalize', True)
|
1986
|
+
settings.setdefault('percentiles', None)
|
1987
|
+
settings.setdefault('circular', False)
|
1988
|
+
settings.setdefault('invert', False)
|
1989
|
+
settings.setdefault('resize', False)
|
1990
|
+
settings.setdefault('target_height', None)
|
1991
|
+
settings.setdefault('target_width', None)
|
1992
|
+
settings.setdefault('rescale', False)
|
1993
|
+
settings.setdefault('resample', False)
|
1994
|
+
settings.setdefault('grayscale', True)
|
1995
|
+
|
1996
|
+
|
1997
|
+
model_name=settings['model_name']
|
1998
|
+
custom_model=settings['custom_model']
|
1999
|
+
channels = settings['channels']
|
2000
|
+
background = settings['background']
|
2001
|
+
remove_background=settings['remove_background']
|
2002
|
+
Signal_to_noise = settings['Signal_to_noise']
|
2003
|
+
CP_prob = settings['CP_prob']
|
2004
|
+
diameter=settings['diameter']
|
2005
|
+
batch_size=settings['batch_size']
|
2006
|
+
flow_threshold=settings['flow_threshold']
|
2007
|
+
save=settings['save']
|
2008
|
+
verbose=settings['verbose']
|
2009
|
+
|
2010
|
+
# static settings
|
2011
|
+
normalize = settings['normalize']
|
2012
|
+
percentiles = settings['percentiles']
|
2013
|
+
circular = settings['circular']
|
2014
|
+
invert = settings['invert']
|
2015
|
+
resize = settings['resize']
|
2016
|
+
|
2017
|
+
if resize:
|
2018
|
+
target_height = settings['target_height']
|
2019
|
+
target_width = settings['target_width']
|
2020
|
+
|
2021
|
+
rescale = settings['rescale']
|
2022
|
+
resample = settings['resample']
|
2023
|
+
grayscale = settings['grayscale']
|
2024
|
+
|
2025
|
+
os.makedirs(dst, exist_ok=True)
|
2026
|
+
|
2027
|
+
if not custom_model is None:
|
2028
|
+
if not os.path.exists(custom_model):
|
2029
|
+
print(f'Custom model not found: {custom_model}')
|
2030
|
+
return
|
2031
|
+
|
1756
2032
|
if not torch.cuda.is_available():
|
1757
2033
|
print(f'Torch CUDA is not available, using CPU')
|
1758
2034
|
|
1759
2035
|
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
1760
2036
|
|
1761
2037
|
if custom_model == None:
|
1762
|
-
|
1763
|
-
|
1764
|
-
|
1765
|
-
|
1766
|
-
|
1767
|
-
if custom_model != None:
|
1768
|
-
model = cp_models.CellposeModel(gpu=torch.cuda.is_available(), model_type=None, pretrained_model=custom_model, diam_mean=diameter, device=device, net_avg=False) #Assuming diameter is defined elsewhere
|
1769
|
-
print(f'loaded custom model:{custom_model}')
|
2038
|
+
model = cp_models.CellposeModel(gpu=True, model_type=model_name, device=device)
|
2039
|
+
print(f'Loaded model: {model_name}')
|
2040
|
+
else:
|
2041
|
+
model = cp_models.CellposeModel(gpu=torch.cuda.is_available(), model_type=None, pretrained_model=custom_model, diam_mean=diameter, device=device)
|
2042
|
+
print("Pretrained Model Loaded:", model.pretrained_model)
|
1770
2043
|
|
1771
2044
|
chans = [2, 1] if model_name == 'cyto2' else [0,0] if model_name == 'nucleus' else [1,0] if model_name == 'cyto' else [2, 0]
|
1772
2045
|
|
@@ -1776,16 +2049,18 @@ def identify_masks_finetune(src, dst, model_name, channels, diameter, batch_size
|
|
1776
2049
|
print(f'Using channels: {chans} for model of type {model_name}')
|
1777
2050
|
|
1778
2051
|
if verbose == True:
|
1779
|
-
print(f'Cellpose settings: Model: {model_name}, channels: {channels}, cellpose_chans: {chans}, diameter:{diameter}, flow_threshold:{flow_threshold}, cellprob_threshold:{
|
2052
|
+
print(f'Cellpose settings: Model: {model_name}, channels: {channels}, cellpose_chans: {chans}, diameter:{diameter}, flow_threshold:{flow_threshold}, cellprob_threshold:{CP_prob}')
|
1780
2053
|
|
1781
|
-
all_image_files =
|
2054
|
+
all_image_files = [os.path.join(src, f) for f in os.listdir(src) if f.endswith('.tif')]
|
2055
|
+
|
1782
2056
|
random.shuffle(all_image_files)
|
1783
2057
|
|
1784
2058
|
time_ls = []
|
1785
2059
|
for i in range(0, len(all_image_files), batch_size):
|
1786
2060
|
image_files = all_image_files[i:i+batch_size]
|
2061
|
+
|
1787
2062
|
if normalize:
|
1788
|
-
images, _, image_names, _ = _load_normalized_images_and_labels(image_files=image_files, label_files=None,
|
2063
|
+
images, _, image_names, _ = _load_normalized_images_and_labels(image_files=image_files, label_files=None, channels=channels, percentiles=percentiles, circular=circular, invert=invert, visualize=verbose, remove_background=remove_background, background=background, Signal_to_noise=Signal_to_noise)
|
1789
2064
|
images = [np.squeeze(img) if img.shape[-1] == 1 else img for img in images]
|
1790
2065
|
orig_dims = [(image.shape[0], image.shape[1]) for image in images]
|
1791
2066
|
else:
|
@@ -1803,11 +2078,10 @@ def identify_masks_finetune(src, dst, model_name, channels, diameter, batch_size
|
|
1803
2078
|
channel_axis=3,
|
1804
2079
|
diameter=diameter,
|
1805
2080
|
flow_threshold=flow_threshold,
|
1806
|
-
cellprob_threshold=
|
2081
|
+
cellprob_threshold=CP_prob,
|
1807
2082
|
rescale=rescale,
|
1808
2083
|
resample=resample,
|
1809
|
-
|
1810
|
-
progress=False)
|
2084
|
+
progress=True)
|
1811
2085
|
|
1812
2086
|
if len(output) == 4:
|
1813
2087
|
mask, flows, _, _ = output
|
@@ -1825,11 +2099,12 @@ def identify_masks_finetune(src, dst, model_name, channels, diameter, batch_size
|
|
1825
2099
|
time_ls.append(duration)
|
1826
2100
|
average_time = np.mean(time_ls) if len(time_ls) > 0 else 0
|
1827
2101
|
print(f'Processing {file_index+1}/{len(images)} images : Time/image {average_time:.3f} sec', end='\r', flush=True)
|
1828
|
-
if
|
2102
|
+
if verbose:
|
1829
2103
|
if resize:
|
1830
2104
|
stack = resizescikit(stack, dims, preserve_range=True, anti_aliasing=False).astype(stack.dtype)
|
1831
|
-
print_mask_and_flows(stack, mask, flows, overlay=
|
2105
|
+
print_mask_and_flows(stack, mask, flows, overlay=True)
|
1832
2106
|
if save:
|
2107
|
+
os.makedirs(dst, exist_ok=True)
|
1833
2108
|
output_filename = os.path.join(dst, image_names[file_index])
|
1834
2109
|
cv2.imwrite(output_filename, mask)
|
1835
2110
|
return
|
@@ -1882,7 +2157,6 @@ def identify_masks(src, object_type, model_name, batch_size, channels, diameter,
|
|
1882
2157
|
|
1883
2158
|
#Note add logic that handles batches of size 1 as these will break the code batches must all be > 2 images
|
1884
2159
|
gc.collect()
|
1885
|
-
#print('========== generating masks ==========')
|
1886
2160
|
|
1887
2161
|
if not torch.cuda.is_available():
|
1888
2162
|
print(f'Torch CUDA is not available, using CPU')
|
@@ -1972,8 +2246,6 @@ def identify_masks(src, object_type, model_name, batch_size, channels, diameter,
|
|
1972
2246
|
stitch_threshold=0.0
|
1973
2247
|
|
1974
2248
|
cellpose_batch_size = _get_cellpose_batch_size()
|
1975
|
-
|
1976
|
-
#model = cellpose.denoise.DenoiseModel(model_type=f"denoise_{model_name}", gpu=True)
|
1977
2249
|
|
1978
2250
|
masks, flows, _, _ = model.eval(x=batch,
|
1979
2251
|
batch_size=cellpose_batch_size,
|
@@ -2047,9 +2319,21 @@ def all_elements_match(list1, list2):
|
|
2047
2319
|
# Check if all elements in list1 are in list2
|
2048
2320
|
return all(element in list2 for element in list1)
|
2049
2321
|
|
2050
|
-
def
|
2322
|
+
def prepare_batch_for_cellpose(batch):
    """
    Return a float32 copy of ``batch`` with every bright image scaled by its maximum.

    Each image ``batch[i]`` whose maximum exceeds 1 is divided by that maximum.
    Images whose maximum is <= 1 (or is NaN) are left as they are, apart from the
    dtype conversion.

    The input is never modified in place. ``batch`` may be a view into a larger
    stack that the caller still uses.

    Args:
        batch (ndarray): Array whose first axis indexes images.

    Returns:
        ndarray: float32 array with the same shape as ``batch``.
    """
    # astype always copies, so the caller's array is never mutated.
    batch = np.asarray(batch).astype(np.float32)
    if batch.ndim == 0 or batch.size == 0:
        return batch

    per_image_max = batch.reshape(batch.shape[0], -1).max(axis=1)
    needs_scaling = per_image_max > 1
    if needs_scaling.any():
        # Broadcast one maximum per image over that image's remaining axes.
        scale_shape = (-1,) + (1,) * (batch.ndim - 1)
        batch[needs_scaling] /= per_image_max[needs_scaling].reshape(scale_shape)

    return batch
|
2333
|
+
|
2334
|
+
def generate_cellpose_masks(src, settings, object_type):
|
2051
2335
|
|
2052
|
-
from .utils import _masks_to_masks_stack, _filter_cp_masks, _get_cellpose_batch_size, _get_object_settings, _get_cellpose_channels, mask_object_count
|
2336
|
+
from .utils import _masks_to_masks_stack, _filter_cp_masks, _get_cellpose_batch_size, _get_object_settings, _get_cellpose_channels, _choose_model, mask_object_count, set_default_settings_preprocess_generate_masks
|
2053
2337
|
from .io import _create_database, _save_object_counts_to_database, _check_masks, _get_avg_object_size
|
2054
2338
|
from .timelapse import _npz_to_movie, _btrack_track_cells, _trackpy_track_cells
|
2055
2339
|
from .plot import plot_masks
|
@@ -2057,6 +2341,13 @@ def generate_cellpose_masks_v1(src, settings, object_type):
|
|
2057
2341
|
gc.collect()
|
2058
2342
|
if not torch.cuda.is_available():
|
2059
2343
|
print(f'Torch CUDA is not available, using CPU')
|
2344
|
+
|
2345
|
+
settings = set_default_settings_preprocess_generate_masks(src, settings)
|
2346
|
+
|
2347
|
+
if settings['verbose']:
|
2348
|
+
settings_df = pd.DataFrame(list(settings.items()), columns=['setting_key', 'setting_value'])
|
2349
|
+
settings_df['setting_value'] = settings_df['setting_value'].apply(str)
|
2350
|
+
display(settings_df)
|
2060
2351
|
|
2061
2352
|
figuresize=25
|
2062
2353
|
timelapse = settings['timelapse']
|
@@ -2071,23 +2362,26 @@ def generate_cellpose_masks_v1(src, settings, object_type):
|
|
2071
2362
|
|
2072
2363
|
batch_size = settings['batch_size']
|
2073
2364
|
cellprob_threshold = settings[f'{object_type}_CP_prob']
|
2074
|
-
|
2075
|
-
|
2365
|
+
|
2366
|
+
flow_threshold = settings[f'{object_type}_FT']
|
2367
|
+
|
2076
2368
|
object_settings = _get_object_settings(object_type, settings)
|
2077
2369
|
model_name = object_settings['model_name']
|
2078
2370
|
|
2079
2371
|
cellpose_channels = _get_cellpose_channels(src, settings['nucleus_channel'], settings['pathogen_channel'], settings['cell_channel'])
|
2080
2372
|
if settings['verbose']:
|
2081
2373
|
print(cellpose_channels)
|
2374
|
+
|
2082
2375
|
channels = cellpose_channels[object_type]
|
2083
2376
|
cellpose_batch_size = _get_cellpose_batch_size()
|
2084
|
-
|
2085
2377
|
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
2086
|
-
model = cp_models.Cellpose(gpu=True, model_type=model_name, device=device) #net_avg=net_avg
|
2087
|
-
#dn = denoise.CellposeDenoiseModel(model_type=f"denoise_{model_name}", gpu=True, device=device)
|
2088
2378
|
|
2379
|
+
if object_type == 'pathogen' and not settings['pathogen_model'] is None:
|
2380
|
+
model_name = settings['pathogen_model']
|
2381
|
+
|
2382
|
+
model = _choose_model(model_name, device, object_type=object_type, restore_type=None, object_settings=object_settings)
|
2383
|
+
|
2089
2384
|
chans = [2, 1] if model_name == 'cyto2' else [0,0] if model_name == 'nucleus' else [2,0] if model_name == 'cyto' else [2, 0] if model_name == 'cyto3' else [2, 0]
|
2090
|
-
|
2091
2385
|
paths = [os.path.join(src, file) for file in os.listdir(src) if file.endswith('.npz')]
|
2092
2386
|
|
2093
2387
|
count_loc = os.path.dirname(src)+'/measurements/measurements.db'
|
@@ -2096,7 +2390,6 @@ def generate_cellpose_masks_v1(src, settings, object_type):
|
|
2096
2390
|
|
2097
2391
|
average_sizes = []
|
2098
2392
|
time_ls = []
|
2099
|
-
|
2100
2393
|
for file_index, path in enumerate(paths):
|
2101
2394
|
name = os.path.basename(path)
|
2102
2395
|
name, ext = os.path.splitext(name)
|
@@ -2106,6 +2399,14 @@ def generate_cellpose_masks_v1(src, settings, object_type):
|
|
2106
2399
|
with np.load(path) as data:
|
2107
2400
|
stack = data['data']
|
2108
2401
|
filenames = data['filenames']
|
2402
|
+
|
2403
|
+
for i, filename in enumerate(filenames):
|
2404
|
+
output_path = os.path.join(output_folder, filename)
|
2405
|
+
|
2406
|
+
if os.path.exists(output_path):
|
2407
|
+
print(f"File {filename} already exists in the output folder. Skipping...")
|
2408
|
+
continue
|
2409
|
+
|
2109
2410
|
if settings['timelapse']:
|
2110
2411
|
|
2111
2412
|
trackable_objects = ['cell','nucleus','pathogen']
|
@@ -2140,31 +2441,43 @@ def generate_cellpose_masks_v1(src, settings, object_type):
|
|
2140
2441
|
if batch.size == 0:
|
2141
2442
|
print(f'Processing {file_index}/{len(paths)}: Images/npz {batch.shape[0]}')
|
2142
2443
|
continue
|
2143
|
-
|
2144
|
-
|
2444
|
+
|
2445
|
+
batch = prepare_batch_for_cellpose(batch)
|
2145
2446
|
|
2146
2447
|
if timelapse:
|
2147
|
-
stitch_threshold=100.0
|
2148
2448
|
movie_path = os.path.join(os.path.dirname(src), 'movies')
|
2149
2449
|
os.makedirs(movie_path, exist_ok=True)
|
2150
2450
|
save_path = os.path.join(movie_path, f'timelapse_{object_type}_{name}.mp4')
|
2151
2451
|
_npz_to_movie(batch, batch_filenames, save_path, fps=2)
|
2152
|
-
else:
|
2153
|
-
stitch_threshold=0.0
|
2154
|
-
|
2155
|
-
print('batch.shape',batch.shape)
|
2156
|
-
masks, flows, _, _ = model.eval(x=batch,
|
2157
|
-
batch_size=cellpose_batch_size,
|
2158
|
-
normalize=False,
|
2159
|
-
channels=chans,
|
2160
|
-
channel_axis=3,
|
2161
|
-
diameter=object_settings['diameter'],
|
2162
|
-
flow_threshold=flow_threshold,
|
2163
|
-
cellprob_threshold=cellprob_threshold,
|
2164
|
-
rescale=None,
|
2165
|
-
resample=object_settings['resample'],
|
2166
|
-
stitch_threshold=stitch_threshold)
|
2167
2452
|
|
2453
|
+
if settings['verbose']:
|
2454
|
+
print(f'Processing {file_index}/{len(paths)}: Images/npz {batch.shape[0]}')
|
2455
|
+
|
2456
|
+
#cellpose_normalize_dict = {'lowhigh':[0.0,1.0], #pass in normalization values for 0.0 and 1.0 as list [low, high] if None all other keys ignored
|
2457
|
+
# 'sharpen':object_settings['diameter']/4, #recommended to be 1/4-1/8 diameter of cells in pixels
|
2458
|
+
# 'normalize':True, #(if False, all following parameters ignored)
|
2459
|
+
# 'percentile':[2,98], #[perc_low, perc_high]
|
2460
|
+
# 'tile_norm':224, #normalize by tile set to e.g. 100 for normailize window to be 100 px
|
2461
|
+
# 'norm3D':True} #compute normalization across entire z-stack rather than plane-by-plane in stitching mode.
|
2462
|
+
|
2463
|
+
output = model.eval(x=batch,
|
2464
|
+
batch_size=cellpose_batch_size,
|
2465
|
+
normalize=False,
|
2466
|
+
channels=chans,
|
2467
|
+
channel_axis=3,
|
2468
|
+
diameter=object_settings['diameter'],
|
2469
|
+
flow_threshold=flow_threshold,
|
2470
|
+
cellprob_threshold=cellprob_threshold,
|
2471
|
+
rescale=None,
|
2472
|
+
resample=object_settings['resample'])
|
2473
|
+
|
2474
|
+
if len(output) == 4:
|
2475
|
+
masks, flows, _, _ = output
|
2476
|
+
elif len(output) == 3:
|
2477
|
+
masks, flows, _ = output
|
2478
|
+
else:
|
2479
|
+
raise ValueError(f"Unexpected number of return values from model.eval(). Expected 3 or 4, got {len(output)}")
|
2480
|
+
|
2168
2481
|
if timelapse:
|
2169
2482
|
if settings['plot']:
|
2170
2483
|
for idx, (mask, flow, image) in enumerate(zip(masks, flows[0], batch)):
|
@@ -2210,23 +2523,45 @@ def generate_cellpose_masks_v1(src, settings, object_type):
|
|
2210
2523
|
mode=timelapse_mode)
|
2211
2524
|
else:
|
2212
2525
|
mask_stack = _masks_to_masks_stack(masks)
|
2213
|
-
|
2214
2526
|
else:
|
2215
2527
|
_save_object_counts_to_database(masks, object_type, batch_filenames, count_loc, added_string='_before_filtration')
|
2216
|
-
|
2217
|
-
|
2218
|
-
|
2219
|
-
|
2220
|
-
|
2221
|
-
|
2222
|
-
|
2223
|
-
|
2224
|
-
|
2225
|
-
|
2226
|
-
|
2227
|
-
|
2228
|
-
|
2528
|
+
if object_settings['merge'] and not settings['filter']:
|
2529
|
+
mask_stack = _filter_cp_masks(masks=masks,
|
2530
|
+
flows=flows,
|
2531
|
+
filter_size=False,
|
2532
|
+
filter_intensity=False,
|
2533
|
+
minimum_size=object_settings['minimum_size'],
|
2534
|
+
maximum_size=object_settings['maximum_size'],
|
2535
|
+
remove_border_objects=False,
|
2536
|
+
merge=object_settings['merge'],
|
2537
|
+
batch=batch,
|
2538
|
+
plot=settings['plot'],
|
2539
|
+
figuresize=figuresize)
|
2540
|
+
|
2541
|
+
if settings['filter']:
|
2542
|
+
mask_stack = _filter_cp_masks(masks=masks,
|
2543
|
+
flows=flows,
|
2544
|
+
filter_size=object_settings['filter_size'],
|
2545
|
+
filter_intensity=object_settings['filter_intensity'],
|
2546
|
+
minimum_size=object_settings['minimum_size'],
|
2547
|
+
maximum_size=object_settings['maximum_size'],
|
2548
|
+
remove_border_objects=object_settings['remove_border_objects'],
|
2549
|
+
merge=object_settings['merge'],
|
2550
|
+
batch=batch,
|
2551
|
+
plot=settings['plot'],
|
2552
|
+
figuresize=figuresize)
|
2553
|
+
|
2554
|
+
_save_object_counts_to_database(mask_stack, object_type, batch_filenames, count_loc, added_string='_after_filtration')
|
2555
|
+
else:
|
2556
|
+
mask_stack = _masks_to_masks_stack(masks)
|
2229
2557
|
|
2558
|
+
if settings['plot']:
|
2559
|
+
for idx, (mask, flow, image) in enumerate(zip(masks, flows[0], batch)):
|
2560
|
+
if idx == 0:
|
2561
|
+
num_objects = mask_object_count(mask)
|
2562
|
+
print(f'Number of objects, : {num_objects}')
|
2563
|
+
plot_masks(batch=image, masks=mask, flows=flow, cmap='inferno', figuresize=figuresize, nr=1, file_type='.npz', print_object_number=True)
|
2564
|
+
|
2230
2565
|
if not np.any(mask_stack):
|
2231
2566
|
average_obj_size = 0
|
2232
2567
|
else:
|
@@ -2255,207 +2590,883 @@ def generate_cellpose_masks_v1(src, settings, object_type):
|
|
2255
2590
|
torch.cuda.empty_cache()
|
2256
2591
|
return
|
2257
2592
|
|
2258
|
-
def
|
2593
|
+
def generate_masks_from_imgs(src, model, model_name, batch_size, diameter, cellprob_threshold, flow_threshold, grayscale, save, normalize, channels, percentiles, circular, invert, plot, resize, target_height, target_width, remove_background, background, Signal_to_noise, verbose):
|
2259
2594
|
|
2260
|
-
from .
|
2261
|
-
from .
|
2262
|
-
from .
|
2263
|
-
|
2595
|
+
from .io import _load_images_and_labels, _load_normalized_images_and_labels
|
2596
|
+
from .utils import resize_images_and_labels, resizescikit
|
2597
|
+
from .plot import print_mask_and_flows
|
2598
|
+
|
2599
|
+
dst = os.path.join(src, model_name)
|
2600
|
+
os.makedirs(dst, exist_ok=True)
|
2601
|
+
|
2602
|
+
chans = [2, 1] if model_name == 'cyto2' else [0,0] if model_name == 'nucleus' else [1,0] if model_name == 'cyto' else [2, 0]
|
2603
|
+
|
2604
|
+
if grayscale:
|
2605
|
+
chans=[0, 0]
|
2264
2606
|
|
2265
|
-
|
2266
|
-
|
2267
|
-
print(f'Torch CUDA is not available, using CPU')
|
2607
|
+
all_image_files = [os.path.join(src, f) for f in os.listdir(src) if f.endswith('.tif')]
|
2608
|
+
random.shuffle(all_image_files)
|
2268
2609
|
|
2269
|
-
|
2270
|
-
|
2271
|
-
|
2272
|
-
if timelapse:
|
2273
|
-
timelapse_displacement = settings['timelapse_displacement']
|
2274
|
-
timelapse_frame_limits = settings['timelapse_frame_limits']
|
2275
|
-
timelapse_memory = settings['timelapse_memory']
|
2276
|
-
timelapse_remove_transient = settings['timelapse_remove_transient']
|
2277
|
-
timelapse_mode = settings['timelapse_mode']
|
2278
|
-
timelapse_objects = settings['timelapse_objects']
|
2279
|
-
|
2280
|
-
batch_size = settings['batch_size']
|
2281
|
-
cellprob_threshold = settings[f'{object_type}_CP_prob']
|
2282
|
-
flow_threshold = 30
|
2610
|
+
if verbose == True:
|
2611
|
+
print(f'Cellpose settings: Model: {model_name}, channels: {channels}, cellpose_chans: {chans}, diameter:{diameter}, flow_threshold:{flow_threshold}, cellprob_threshold:{cellprob_threshold}')
|
2283
2612
|
|
2284
|
-
|
2285
|
-
|
2613
|
+
time_ls = []
|
2614
|
+
for i in range(0, len(all_image_files), batch_size):
|
2615
|
+
image_files = all_image_files[i:i+batch_size]
|
2616
|
+
|
2617
|
+
if normalize:
|
2618
|
+
images, _, image_names, _ = _load_normalized_images_and_labels(image_files, None, channels, percentiles, circular, invert, plot, remove_background, background, Signal_to_noise)
|
2619
|
+
images = [np.squeeze(img) if img.shape[-1] == 1 else img for img in images]
|
2620
|
+
orig_dims = [(image.shape[0], image.shape[1]) for image in images]
|
2621
|
+
else:
|
2622
|
+
images, _, image_names, _ = _load_images_and_labels(image_files, None, circular, invert)
|
2623
|
+
images = [np.squeeze(img) if img.shape[-1] == 1 else img for img in images]
|
2624
|
+
orig_dims = [(image.shape[0], image.shape[1]) for image in images]
|
2625
|
+
if resize:
|
2626
|
+
images, _ = resize_images_and_labels(images, None, target_height, target_width, True)
|
2627
|
+
|
2628
|
+
for file_index, stack in enumerate(images):
|
2629
|
+
start = time.time()
|
2630
|
+
output = model.eval(x=stack,
|
2631
|
+
normalize=False,
|
2632
|
+
channels=chans,
|
2633
|
+
channel_axis=3,
|
2634
|
+
diameter=diameter,
|
2635
|
+
flow_threshold=flow_threshold,
|
2636
|
+
cellprob_threshold=cellprob_threshold,
|
2637
|
+
rescale=False,
|
2638
|
+
resample=False,
|
2639
|
+
progress=False)
|
2640
|
+
|
2641
|
+
if len(output) == 4:
|
2642
|
+
mask, flows, _, _ = output
|
2643
|
+
elif len(output) == 3:
|
2644
|
+
mask, flows, _ = output
|
2645
|
+
else:
|
2646
|
+
raise ValueError("Unexpected number of return values from model.eval()")
|
2647
|
+
|
2648
|
+
if resize:
|
2649
|
+
dims = orig_dims[file_index]
|
2650
|
+
mask = resizescikit(mask, dims, order=0, preserve_range=True, anti_aliasing=False).astype(mask.dtype)
|
2651
|
+
|
2652
|
+
stop = time.time()
|
2653
|
+
duration = (stop - start)
|
2654
|
+
time_ls.append(duration)
|
2655
|
+
average_time = np.mean(time_ls) if len(time_ls) > 0 else 0
|
2656
|
+
print(f'Processing {file_index+1}/{len(images)} images : Time/image {average_time:.3f} sec', end='\r', flush=True)
|
2657
|
+
if plot:
|
2658
|
+
if resize:
|
2659
|
+
stack = resizescikit(stack, dims, preserve_range=True, anti_aliasing=False).astype(stack.dtype)
|
2660
|
+
print_mask_and_flows(stack, mask, flows, overlay=True)
|
2661
|
+
if save:
|
2662
|
+
output_filename = os.path.join(dst, image_names[file_index])
|
2663
|
+
cv2.imwrite(output_filename, mask)
|
2664
|
+
|
2665
|
+
|
2666
|
+
def check_cellpose_models(settings):
|
2286
2667
|
|
2287
|
-
|
2668
|
+
src = settings['src']
|
2669
|
+
settings.setdefault('batch_size', 10)
|
2670
|
+
settings.setdefault('CP_prob', 0)
|
2671
|
+
settings.setdefault('flow_threshold', 0.4)
|
2672
|
+
settings.setdefault('save', True)
|
2673
|
+
settings.setdefault('normalize', True)
|
2674
|
+
settings.setdefault('channels', [0,0])
|
2675
|
+
settings.setdefault('percentiles', None)
|
2676
|
+
settings.setdefault('circular', False)
|
2677
|
+
settings.setdefault('invert', False)
|
2678
|
+
settings.setdefault('plot', True)
|
2679
|
+
settings.setdefault('diameter', 40)
|
2680
|
+
settings.setdefault('grayscale', True)
|
2681
|
+
settings.setdefault('remove_background', False)
|
2682
|
+
settings.setdefault('background', 100)
|
2683
|
+
settings.setdefault('Signal_to_noise', 5)
|
2684
|
+
settings.setdefault('verbose', False)
|
2685
|
+
settings.setdefault('resize', False)
|
2686
|
+
settings.setdefault('target_height', None)
|
2687
|
+
settings.setdefault('target_width', None)
|
2688
|
+
|
2288
2689
|
if settings['verbose']:
|
2289
|
-
|
2690
|
+
settings_df = pd.DataFrame(list(settings.items()), columns=['setting_key', 'setting_value'])
|
2691
|
+
settings_df['setting_value'] = settings_df['setting_value'].apply(str)
|
2692
|
+
display(settings_df)
|
2290
2693
|
|
2291
|
-
|
2292
|
-
cellpose_batch_size = _get_cellpose_batch_size()
|
2694
|
+
cellpose_models = ['cyto', 'nuclei', 'cyto2', 'cyto3']
|
2293
2695
|
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
2294
|
-
model = _choose_model(model_name, device, object_type='cell', restore_type=None)
|
2295
|
-
chans = [2, 1] if model_name == 'cyto2' else [0,0] if model_name == 'nucleus' else [2,0] if model_name == 'cyto' else [2, 0] if model_name == 'cyto3' else [2, 0]
|
2296
|
-
paths = [os.path.join(src, file) for file in os.listdir(src) if file.endswith('.npz')]
|
2297
2696
|
|
2298
|
-
|
2299
|
-
|
2300
|
-
|
2697
|
+
for model_name in cellpose_models:
|
2698
|
+
|
2699
|
+
model = cp_models.CellposeModel(gpu=True, model_type=model_name, device=device)
|
2700
|
+
print(f'Using {model_name}')
|
2701
|
+
generate_masks_from_imgs(src, model, model_name, settings['batch_size'], settings['diameter'], settings['CP_prob'], settings['flow_threshold'], settings['grayscale'], settings['save'], settings['normalize'], settings['channels'], settings['percentiles'], settings['circular'], settings['invert'], settings['plot'], settings['resize'], settings['target_height'], settings['target_width'], settings['remove_background'], settings['background'], settings['Signal_to_noise'], settings['verbose'])
|
2702
|
+
|
2703
|
+
return
|
2704
|
+
|
2705
|
+
def save_results_and_figure(src, fig, results):
|
2706
|
+
|
2707
|
+
if not isinstance(results, pd.DataFrame):
|
2708
|
+
results = pd.DataFrame(results)
|
2709
|
+
|
2710
|
+
results_dir = os.path.join(src, 'results')
|
2711
|
+
os.makedirs(results_dir, exist_ok=True)
|
2712
|
+
results_path = os.path.join(results_dir,f'results.csv')
|
2713
|
+
fig_path = os.path.join(results_dir, f'model_comparison_plot.pdf')
|
2714
|
+
results.to_csv(results_path, index=False)
|
2715
|
+
fig.savefig(fig_path, format='pdf')
|
2716
|
+
print(f'Saved figure to {fig_path} and results to {results_path}')
|
2717
|
+
|
2718
|
+
def compare_mask(args):
|
2719
|
+
src, filename, dirs, conditions = args
|
2720
|
+
paths = [os.path.join(d, filename) for d in dirs]
|
2721
|
+
|
2722
|
+
if not all(os.path.exists(path) for path in paths):
|
2723
|
+
return None
|
2724
|
+
|
2725
|
+
from .io import _read_mask # Import here to avoid issues in multiprocessing
|
2726
|
+
from .utils import extract_boundaries, boundary_f1_score, compute_segmentation_ap, jaccard_index
|
2727
|
+
from .plot import plot_comparison_results
|
2728
|
+
|
2729
|
+
masks = [_read_mask(path) for path in paths]
|
2730
|
+
file_results = {'filename': filename}
|
2731
|
+
|
2732
|
+
for i in range(len(masks)):
|
2733
|
+
for j in range(i + 1, len(masks)):
|
2734
|
+
mask_i, mask_j = masks[i], masks[j]
|
2735
|
+
f1_score = boundary_f1_score(mask_i, mask_j)
|
2736
|
+
jac_index = jaccard_index(mask_i, mask_j)
|
2737
|
+
ap_score = compute_segmentation_ap(mask_i, mask_j)
|
2738
|
+
|
2739
|
+
file_results.update({
|
2740
|
+
f'jaccard_{conditions[i]}_{conditions[j]}': jac_index,
|
2741
|
+
f'boundary_f1_{conditions[i]}_{conditions[j]}': f1_score,
|
2742
|
+
f'ap_{conditions[i]}_{conditions[j]}': ap_score
|
2743
|
+
})
|
2301
2744
|
|
2302
|
-
|
2303
|
-
time_ls = []
|
2304
|
-
for file_index, path in enumerate(paths):
|
2305
|
-
name = os.path.basename(path)
|
2306
|
-
name, ext = os.path.splitext(name)
|
2307
|
-
output_folder = os.path.join(os.path.dirname(path), object_type+'_mask_stack')
|
2308
|
-
os.makedirs(output_folder, exist_ok=True)
|
2309
|
-
overall_average_size = 0
|
2310
|
-
with np.load(path) as data:
|
2311
|
-
stack = data['data']
|
2312
|
-
filenames = data['filenames']
|
2313
|
-
if settings['timelapse']:
|
2745
|
+
return file_results
|
2314
2746
|
|
2315
|
-
|
2316
|
-
|
2317
|
-
|
2318
|
-
return
|
2747
|
+
def compare_cellpose_masks(src, verbose=False, processes=None, save=True):
|
2748
|
+
from .plot import visualize_cellpose_masks, plot_comparison_results
|
2749
|
+
from .io import _read_mask
|
2319
2750
|
|
2320
|
-
|
2321
|
-
|
2322
|
-
|
2323
|
-
batch_size = len(stack)
|
2324
|
-
if isinstance(timelapse_frame_limits, list):
|
2325
|
-
if len(timelapse_frame_limits) >= 2:
|
2326
|
-
stack = stack[timelapse_frame_limits[0]: timelapse_frame_limits[1], :, :, :].astype(stack.dtype)
|
2327
|
-
filenames = filenames[timelapse_frame_limits[0]: timelapse_frame_limits[1]]
|
2328
|
-
batch_size = len(stack)
|
2329
|
-
print(f'Cut batch at indecies: {timelapse_frame_limits}, New batch_size: {batch_size} ')
|
2751
|
+
dirs = [os.path.join(src, d) for d in os.listdir(src) if os.path.isdir(os.path.join(src, d)) and d != 'results']
|
2752
|
+
dirs.sort() # Optional: sort directories if needed
|
2753
|
+
conditions = [os.path.basename(d) for d in dirs]
|
2330
2754
|
|
2331
|
-
|
2332
|
-
|
2333
|
-
|
2755
|
+
# Get common files in all directories
|
2756
|
+
common_files = set(os.listdir(dirs[0]))
|
2757
|
+
for d in dirs[1:]:
|
2758
|
+
common_files.intersection_update(os.listdir(d))
|
2759
|
+
common_files = list(common_files)
|
2334
2760
|
|
2335
|
-
|
2336
|
-
|
2337
|
-
|
2338
|
-
|
2761
|
+
# Create a pool of workers
|
2762
|
+
with Pool(processes=processes) as pool:
|
2763
|
+
args = [(src, filename, dirs, conditions) for filename in common_files]
|
2764
|
+
results = pool.map(compare_mask, args)
|
2339
2765
|
|
2340
|
-
|
2766
|
+
# Filter out None results (from skipped files)
|
2767
|
+
results = [res for res in results if res is not None]
|
2768
|
+
#print(results)
|
2769
|
+
if verbose:
|
2770
|
+
for result in results:
|
2771
|
+
filename = result['filename']
|
2772
|
+
masks = [_read_mask(os.path.join(d, filename)) for d in dirs]
|
2773
|
+
visualize_cellpose_masks(masks, titles=conditions, filename=filename, save=save, src=src)
|
2341
2774
|
|
2342
|
-
|
2343
|
-
|
2344
|
-
|
2345
|
-
print(f'Processing {file_index}/{len(paths)}: Images/npz {batch.shape[0]}')
|
2346
|
-
continue
|
2347
|
-
if batch.max() > 1:
|
2348
|
-
batch = batch / batch.max()
|
2775
|
+
fig = plot_comparison_results(results)
|
2776
|
+
save_results_and_figure(src, fig, results)
|
2777
|
+
return
|
2349
2778
|
|
2350
|
-
|
2351
|
-
|
2352
|
-
|
2353
|
-
|
2354
|
-
|
2355
|
-
|
2356
|
-
|
2357
|
-
|
2358
|
-
|
2359
|
-
print('batch.shape',batch.shape)
|
2360
|
-
masks, flows, _, _ = model.eval(x=batch,
|
2361
|
-
batch_size=cellpose_batch_size,
|
2362
|
-
normalize=False,
|
2363
|
-
channels=chans,
|
2364
|
-
channel_axis=3,
|
2365
|
-
diameter=object_settings['diameter'],
|
2366
|
-
flow_threshold=flow_threshold,
|
2367
|
-
cellprob_threshold=cellprob_threshold,
|
2368
|
-
rescale=None,
|
2369
|
-
resample=object_settings['resample'],
|
2370
|
-
stitch_threshold=stitch_threshold)
|
2371
|
-
|
2372
|
-
if timelapse:
|
2779
|
+
def _calculate_similarity(df, features, col_to_compare, val1, val2):
|
2780
|
+
"""
|
2781
|
+
Calculate similarity scores of each well to the positive and negative controls using various metrics.
|
2782
|
+
|
2783
|
+
Args:
|
2784
|
+
df (pandas.DataFrame): DataFrame containing the data.
|
2785
|
+
features (list): List of feature columns to use for similarity calculation.
|
2786
|
+
col_to_compare (str): Column name to use for comparing groups.
|
2787
|
+
val1, val2 (str): Values in col_to_compare to create subsets for comparison.
|
2373
2788
|
|
2374
|
-
|
2375
|
-
|
2376
|
-
|
2377
|
-
|
2378
|
-
|
2379
|
-
|
2789
|
+
Returns:
|
2790
|
+
pandas.DataFrame: DataFrame with similarity scores.
|
2791
|
+
"""
|
2792
|
+
# Separate positive and negative control wells
|
2793
|
+
pos_control = df[df[col_to_compare] == val1][features].mean()
|
2794
|
+
neg_control = df[df[col_to_compare] == val2][features].mean()
|
2795
|
+
|
2796
|
+
# Standardize features for Mahalanobis distance
|
2797
|
+
scaler = StandardScaler()
|
2798
|
+
scaled_features = scaler.fit_transform(df[features])
|
2799
|
+
|
2800
|
+
# Regularize the covariance matrix to avoid singularity
|
2801
|
+
cov_matrix = np.cov(scaled_features, rowvar=False)
|
2802
|
+
inv_cov_matrix = None
|
2803
|
+
try:
|
2804
|
+
inv_cov_matrix = np.linalg.inv(cov_matrix)
|
2805
|
+
except np.linalg.LinAlgError:
|
2806
|
+
# Add a small value to the diagonal elements for regularization
|
2807
|
+
epsilon = 1e-5
|
2808
|
+
inv_cov_matrix = np.linalg.inv(cov_matrix + np.eye(cov_matrix.shape[0]) * epsilon)
|
2809
|
+
|
2810
|
+
# Calculate similarity scores
|
2811
|
+
df['similarity_to_pos_euclidean'] = df[features].apply(lambda row: euclidean(row, pos_control), axis=1)
|
2812
|
+
df['similarity_to_neg_euclidean'] = df[features].apply(lambda row: euclidean(row, neg_control), axis=1)
|
2813
|
+
df['similarity_to_pos_cosine'] = df[features].apply(lambda row: cosine(row, pos_control), axis=1)
|
2814
|
+
df['similarity_to_neg_cosine'] = df[features].apply(lambda row: cosine(row, neg_control), axis=1)
|
2815
|
+
df['similarity_to_pos_mahalanobis'] = df[features].apply(lambda row: mahalanobis(row, pos_control, inv_cov_matrix), axis=1)
|
2816
|
+
df['similarity_to_neg_mahalanobis'] = df[features].apply(lambda row: mahalanobis(row, neg_control, inv_cov_matrix), axis=1)
|
2817
|
+
df['similarity_to_pos_manhattan'] = df[features].apply(lambda row: cityblock(row, pos_control), axis=1)
|
2818
|
+
df['similarity_to_neg_manhattan'] = df[features].apply(lambda row: cityblock(row, neg_control), axis=1)
|
2819
|
+
df['similarity_to_pos_minkowski'] = df[features].apply(lambda row: minkowski(row, pos_control, p=3), axis=1)
|
2820
|
+
df['similarity_to_neg_minkowski'] = df[features].apply(lambda row: minkowski(row, neg_control, p=3), axis=1)
|
2821
|
+
df['similarity_to_pos_chebyshev'] = df[features].apply(lambda row: chebyshev(row, pos_control), axis=1)
|
2822
|
+
df['similarity_to_neg_chebyshev'] = df[features].apply(lambda row: chebyshev(row, neg_control), axis=1)
|
2823
|
+
df['similarity_to_pos_hamming'] = df[features].apply(lambda row: hamming(row, pos_control), axis=1)
|
2824
|
+
df['similarity_to_neg_hamming'] = df[features].apply(lambda row: hamming(row, neg_control), axis=1)
|
2825
|
+
df['similarity_to_pos_jaccard'] = df[features].apply(lambda row: jaccard(row, pos_control), axis=1)
|
2826
|
+
df['similarity_to_neg_jaccard'] = df[features].apply(lambda row: jaccard(row, neg_control), axis=1)
|
2827
|
+
df['similarity_to_pos_braycurtis'] = df[features].apply(lambda row: braycurtis(row, pos_control), axis=1)
|
2828
|
+
df['similarity_to_neg_braycurtis'] = df[features].apply(lambda row: braycurtis(row, neg_control), axis=1)
|
2829
|
+
|
2830
|
+
return df
|
2380
2831
|
|
2381
|
-
|
2382
|
-
|
2383
|
-
|
2384
|
-
|
2385
|
-
|
2386
|
-
|
2387
|
-
radius = 100
|
2832
|
+
def _permutation_importance(df, feature_string='channel_3', col_to_compare='col', pos='c1', neg='c2', exclude=None, n_repeats=10, clean=True, nr_to_plot=30, n_estimators=100, test_size=0.2, random_state=42, model_type='xgboost', n_jobs=-1):
|
2833
|
+
|
2834
|
+
"""
|
2835
|
+
Calculates permutation importance for numerical features in the dataframe,
|
2836
|
+
comparing groups based on specified column values and uses the model to predict
|
2837
|
+
the class for all other rows in the dataframe.
|
2388
2838
|
|
2389
|
-
|
2390
|
-
|
2391
|
-
|
2839
|
+
Args:
|
2840
|
+
df (pandas.DataFrame): The DataFrame containing the data.
|
2841
|
+
feature_string (str): String to filter features that contain this substring.
|
2842
|
+
col_to_compare (str): Column name to use for comparing groups.
|
2843
|
+
pos, neg (str): Values in col_to_compare to create subsets for comparison.
|
2844
|
+
exclude (list or str, optional): Columns to exclude from features.
|
2845
|
+
n_repeats (int): Number of repeats for permutation importance.
|
2846
|
+
clean (bool): Whether to remove columns with a single value.
|
2847
|
+
nr_to_plot (int): Number of top features to plot based on permutation importance.
|
2848
|
+
n_estimators (int): Number of trees in the random forest, gradient boosting, or XGBoost model.
|
2849
|
+
test_size (float): Proportion of the dataset to include in the test split.
|
2850
|
+
random_state (int): Random seed for reproducibility.
|
2851
|
+
model_type (str): Type of model to use ('random_forest', 'logistic_regression', 'gradient_boosting', 'xgboost').
|
2852
|
+
n_jobs (int): Number of jobs to run in parallel for applicable models.
|
2392
2853
|
|
2393
|
-
|
2394
|
-
|
2395
|
-
|
2396
|
-
|
2397
|
-
|
2398
|
-
|
2399
|
-
|
2400
|
-
|
2401
|
-
|
2402
|
-
|
2403
|
-
|
2404
|
-
|
2405
|
-
|
2406
|
-
|
2407
|
-
|
2408
|
-
|
2409
|
-
|
2410
|
-
|
2411
|
-
|
2412
|
-
|
2413
|
-
|
2414
|
-
|
2415
|
-
|
2416
|
-
|
2417
|
-
|
2854
|
+
Returns:
|
2855
|
+
pandas.DataFrame: The original dataframe with added prediction and data usage columns.
|
2856
|
+
pandas.DataFrame: DataFrame containing the importances and standard deviations.
|
2857
|
+
"""
|
2858
|
+
|
2859
|
+
from .utils import filter_dataframe_features
|
2860
|
+
|
2861
|
+
if 'cells_per_well' in df.columns:
|
2862
|
+
df = df.drop(columns=['cells_per_well'])
|
2863
|
+
|
2864
|
+
# Subset the dataframe based on specified column values
|
2865
|
+
df1 = df[df[col_to_compare] == pos].copy()
|
2866
|
+
df2 = df[df[col_to_compare] == neg].copy()
|
2867
|
+
|
2868
|
+
# Create target variable
|
2869
|
+
df1['target'] = 0
|
2870
|
+
df2['target'] = 1
|
2871
|
+
|
2872
|
+
# Combine the subsets for analysis
|
2873
|
+
combined_df = pd.concat([df1, df2])
|
2874
|
+
|
2875
|
+
if feature_string in ['channel_0', 'channel_1', 'channel_2', 'channel_3']:
|
2876
|
+
channel_of_interest = int(feature_string.split('_')[-1])
|
2877
|
+
elif not feature_string is 'morphology':
|
2878
|
+
channel_of_interest = 'morphology'
|
2879
|
+
|
2880
|
+
_, features = filter_dataframe_features(combined_df, channel_of_interest, exclude)
|
2881
|
+
|
2882
|
+
X = combined_df[features]
|
2883
|
+
y = combined_df['target']
|
2884
|
+
|
2885
|
+
# Split the data into training and testing sets
|
2886
|
+
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=test_size, random_state=random_state)
|
2887
|
+
|
2888
|
+
# Label the data in the original dataframe
|
2889
|
+
combined_df['data_usage'] = 'train'
|
2890
|
+
combined_df.loc[X_test.index, 'data_usage'] = 'test'
|
2891
|
+
|
2892
|
+
# Initialize the model based on model_type
|
2893
|
+
if model_type == 'random_forest':
|
2894
|
+
model = RandomForestClassifier(n_estimators=n_estimators, random_state=random_state, n_jobs=n_jobs)
|
2895
|
+
elif model_type == 'logistic_regression':
|
2896
|
+
model = LogisticRegression(max_iter=1000, random_state=random_state, n_jobs=n_jobs)
|
2897
|
+
elif model_type == 'gradient_boosting':
|
2898
|
+
model = HistGradientBoostingClassifier(max_iter=n_estimators, random_state=random_state) # Supports n_jobs internally
|
2899
|
+
elif model_type == 'xgboost':
|
2900
|
+
model = XGBClassifier(n_estimators=n_estimators, random_state=random_state, nthread=n_jobs, use_label_encoder=False, eval_metric='logloss')
|
2901
|
+
else:
|
2902
|
+
raise ValueError(f"Unsupported model_type: {model_type}")
|
2903
|
+
|
2904
|
+
model.fit(X_train, y_train)
|
2905
|
+
|
2906
|
+
perm_importance = permutation_importance(model, X_train, y_train, n_repeats=n_repeats, random_state=random_state, n_jobs=n_jobs)
|
2907
|
+
|
2908
|
+
# Create a DataFrame for permutation importances
|
2909
|
+
permutation_df = pd.DataFrame({
|
2910
|
+
'feature': [features[i] for i in perm_importance.importances_mean.argsort()],
|
2911
|
+
'importance_mean': perm_importance.importances_mean[perm_importance.importances_mean.argsort()],
|
2912
|
+
'importance_std': perm_importance.importances_std[perm_importance.importances_mean.argsort()]
|
2913
|
+
}).tail(nr_to_plot)
|
2914
|
+
|
2915
|
+
# Plotting
|
2916
|
+
fig, ax = plt.subplots()
|
2917
|
+
ax.barh(permutation_df['feature'], permutation_df['importance_mean'], xerr=permutation_df['importance_std'], color="teal", align="center", alpha=0.6)
|
2918
|
+
ax.set_xlabel('Permutation Importance')
|
2919
|
+
plt.tight_layout()
|
2920
|
+
plt.show()
|
2921
|
+
|
2922
|
+
# Feature importance for models that support it
|
2923
|
+
if model_type in ['random_forest', 'xgboost', 'gradient_boosting']:
|
2924
|
+
feature_importances = model.feature_importances_
|
2925
|
+
feature_importance_df = pd.DataFrame({
|
2926
|
+
'feature': features,
|
2927
|
+
'importance': feature_importances
|
2928
|
+
}).sort_values(by='importance', ascending=False).head(nr_to_plot)
|
2929
|
+
|
2930
|
+
# Plotting feature importance
|
2931
|
+
fig, ax = plt.subplots()
|
2932
|
+
ax.barh(feature_importance_df['feature'], feature_importance_df['importance'], color="blue", align="center", alpha=0.6)
|
2933
|
+
ax.set_xlabel('Feature Importance')
|
2934
|
+
plt.tight_layout()
|
2935
|
+
plt.show()
|
2936
|
+
else:
|
2937
|
+
feature_importance_df = pd.DataFrame()
|
2938
|
+
|
2939
|
+
# Predicting the target variable for the test set
|
2940
|
+
predictions_test = model.predict(X_test)
|
2941
|
+
combined_df.loc[X_test.index, 'predictions'] = predictions_test
|
2942
|
+
|
2943
|
+
# Predicting the target variable for the training set
|
2944
|
+
predictions_train = model.predict(X_train)
|
2945
|
+
combined_df.loc[X_train.index, 'predictions'] = predictions_train
|
2946
|
+
|
2947
|
+
# Predicting the target variable for all other rows in the dataframe
|
2948
|
+
X_all = df[features]
|
2949
|
+
all_predictions = model.predict(X_all)
|
2950
|
+
df['predictions'] = all_predictions
|
2951
|
+
|
2952
|
+
# Combine data usage labels back to the original dataframe
|
2953
|
+
combined_data_usage = pd.concat([combined_df[['data_usage']], df[['predictions']]], axis=0)
|
2954
|
+
df = df.join(combined_data_usage, how='left', rsuffix='_model')
|
2955
|
+
|
2956
|
+
# Calculating and printing the accuracy metrics
|
2957
|
+
accuracy = accuracy_score(y_test, predictions_test)
|
2958
|
+
precision = precision_score(y_test, predictions_test)
|
2959
|
+
recall = recall_score(y_test, predictions_test)
|
2960
|
+
f1 = f1_score(y_test, predictions_test)
|
2961
|
+
print(f"Accuracy: {accuracy}")
|
2962
|
+
print(f"Precision: {precision}")
|
2963
|
+
print(f"Recall: {recall}")
|
2964
|
+
print(f"F1 Score: {f1}")
|
2965
|
+
|
2966
|
+
# Printing class-specific accuracy metrics
|
2967
|
+
print("\nClassification Report:")
|
2968
|
+
print(classification_report(y_test, predictions_test))
|
2969
|
+
|
2970
|
+
df = _calculate_similarity(df, features, col_to_compare, pos, neg)
|
2971
|
+
|
2972
|
+
return [df, permutation_df, feature_importance_df, model, X_train, X_test, y_train, y_test]
|
2973
|
+
|
2974
|
+
def _shap_analysis(model, X_train, X_test):
|
2975
|
+
|
2976
|
+
"""
|
2977
|
+
Performs SHAP analysis on the given model and data.
|
2978
|
+
|
2979
|
+
Args:
|
2980
|
+
model: The trained model.
|
2981
|
+
X_train (pandas.DataFrame): Training feature set.
|
2982
|
+
X_test (pandas.DataFrame): Testing feature set.
|
2983
|
+
"""
|
2984
|
+
|
2985
|
+
explainer = shap.Explainer(model, X_train)
|
2986
|
+
shap_values = explainer(X_test)
|
2987
|
+
|
2988
|
+
# Summary plot
|
2989
|
+
shap.summary_plot(shap_values, X_test)
|
2990
|
+
|
2991
|
+
def plate_heatmap(src, model_type='xgboost', variable='predictions', grouping='mean', min_max='allq', cmap='viridis', channel_of_interest=3, min_count=25, n_estimators=100, col_to_compare='col', pos='c1', neg='c2', exclude=None, n_repeats=10, clean=True, nr_to_plot=20, verbose=False, n_jobs=-1):
|
2992
|
+
from .io import _read_and_merge_data
|
2993
|
+
from .plot import _plot_plates
|
2994
|
+
|
2995
|
+
db_loc = [src+'/measurements/measurements.db']
|
2996
|
+
tables = ['cell', 'nucleus', 'pathogen','cytoplasm']
|
2997
|
+
include_multinucleated, include_multiinfected, include_noninfected = True, 2.0, True
|
2998
|
+
|
2999
|
+
df, _ = _read_and_merge_data(db_loc,
|
3000
|
+
tables,
|
3001
|
+
verbose=verbose,
|
3002
|
+
include_multinucleated=include_multinucleated,
|
3003
|
+
include_multiinfected=include_multiinfected,
|
3004
|
+
include_noninfected=include_noninfected)
|
3005
|
+
|
3006
|
+
if not channel_of_interest is None:
|
3007
|
+
df['recruitment'] = df[f'pathogen_channel_{channel_of_interest}_mean_intensity']/df[f'cytoplasm_channel_{channel_of_interest}_mean_intensity']
|
3008
|
+
feature_string = f'channel_{channel_of_interest}'
|
3009
|
+
else:
|
3010
|
+
feature_string = None
|
3011
|
+
|
3012
|
+
output = _permutation_importance(df, feature_string, col_to_compare, pos, neg, exclude, n_repeats, clean, nr_to_plot, n_estimators=n_estimators, random_state=42, model_type=model_type, n_jobs=n_jobs)
|
3013
|
+
|
3014
|
+
_shap_analysis(output[3], output[4], output[5])
|
3015
|
+
|
3016
|
+
features = output[0].select_dtypes(include=[np.number]).columns.tolist()
|
3017
|
+
|
3018
|
+
if not variable in features:
|
3019
|
+
raise ValueError(f"Variable {variable} not found in the dataframe. Please choose one of the following: {features}")
|
3020
|
+
|
3021
|
+
plate_heatmap = _plot_plates(output[0], variable, grouping, min_max, cmap, min_count)
|
3022
|
+
return [output, plate_heatmap]
|
3023
|
+
|
3024
|
+
def join_measurments_and_annotation(src, tables = ['cell', 'nucleus', 'pathogen','cytoplasm']):
|
3025
|
+
|
3026
|
+
from .io import _read_and_merge_data, _read_db
|
3027
|
+
|
3028
|
+
db_loc = [src+'/measurements/measurements.db']
|
3029
|
+
loc = src+'/measurements/measurements.db'
|
3030
|
+
df, _ = _read_and_merge_data(db_loc,
|
3031
|
+
tables,
|
3032
|
+
verbose=True,
|
3033
|
+
include_multinucleated=True,
|
3034
|
+
include_multiinfected=True,
|
3035
|
+
include_noninfected=True)
|
3036
|
+
|
3037
|
+
paths_df = _read_db(loc, tables=['png_list'])
|
3038
|
+
|
3039
|
+
merged_df = pd.merge(df, paths_df[0], on='prcfo', how='left')
|
3040
|
+
|
3041
|
+
return merged_df
|
3042
|
+
|
3043
|
+
def jitterplot_by_annotation(src, x_column, y_column, plot_title='Jitter Plot', output_path=None, filter_column=None, filter_values=None):
    """
    Create a class-balanced jitter (strip) plot of one column grouped by another.

    Measurements are loaded from the measurement database under ``src`` via
    ``join_measurments_and_annotation``. Only wells (plate_x, row_x, col_x) that contain
    at least one row with a non-missing ``x_column`` value are kept. Each group of
    ``x_column`` is then randomly down-sampled to the size of the smallest group.

    Args:
        src (str): Path to the source data.
        x_column (str): Column used for grouping / the x-axis. Missing values are
            replaced by the string label 'NaN'.
        y_column (str): Column plotted on the y-axis.
        plot_title (str): Title of the plot. Default is 'Jitter Plot'.
        output_path (str): Path to save the plot image. If None, the plot is displayed.
            Default is None.
        filter_column (str or list, optional): Column name, or list of column names,
            used to filter rows before plotting.
        filter_values (list, optional): Values to keep. When ``filter_column`` is a list,
            ``filter_values[i]`` holds the values kept for ``filter_column[i]``.

    Returns:
        pd.DataFrame: The filtered and balanced DataFrame.

    Raises:
        KeyError: If the merged data lacks the 'plate_x', 'row_x' or 'col_x' columns.
        ValueError: If no row has a non-missing ``x_column`` value after filtering.
    """
    df = join_measurments_and_annotation(src, tables=['cell', 'nucleus', 'pathogen', 'cytoplasm'])
    print(f"Generated dataframe with: {df.shape[1]} columns and {df.shape[0]} rows")

    # Give missing annotations an explicit label so they form their own group.
    df[x_column] = df[x_column].fillna('NaN')

    if filter_column is not None:
        if isinstance(filter_column, str):
            df = df[df[filter_column].isin(filter_values)]
        elif isinstance(filter_column, list):
            for i, column in enumerate(filter_column):
                df = df[df[column].isin(filter_values[i])]

    required_columns = ['plate_x', 'row_x', 'col_x']
    if not all(column in df.columns for column in required_columns):
        raise KeyError(f"DataFrame does not contain the necessary columns: {required_columns}")

    # Keep every row from wells that hold at least one annotated (non-'NaN') row.
    non_nan_df = df[df[x_column] != 'NaN']
    well_keys = df[required_columns].apply(tuple, axis=1)
    annotated_wells = non_nan_df[required_columns].apply(tuple, axis=1)
    retained_rows = df[well_keys.isin(annotated_wells)]

    if retained_rows.empty:
        raise ValueError(f"No rows with a non-missing '{x_column}' value were found; nothing to plot.")

    # Balance the groups: sample the size of the smallest group from every group.
    min_count = retained_rows[x_column].value_counts().min()
    print(f'Found {min_count} annotated images')
    balanced_df = retained_rows.groupby(x_column).apply(lambda x: x.sample(min_count, random_state=42)).reset_index(drop=True)

    plt.figure(figsize=(10, 6))
    sns.stripplot(data=balanced_df, x=x_column, y=y_column, hue=x_column, jitter=True, palette='viridis', dodge=False)
    plt.title(plot_title)
    plt.xlabel(x_column)
    plt.ylabel(y_column)

    # Rotate the category labels and keep them centred beneath their groups.
    plt.xticks(rotation=45, ha='center')

    if output_path:
        plt.savefig(output_path, bbox_inches='tight')
        print(f"Jitter plot saved to {output_path}")
    else:
        plt.show()

    return balanced_df
|
3114
|
+
|
3115
|
+
def generate_image_umap(settings=None):
    """
    Generate a UMAP or tSNE embedding of object measurements and visualize it with clustering.

    Parameters:
        settings (dict, optional): Settings completed by ``get_umap_image_settings``.
            Keys read by this function:
            src (str or list): Source folder(s); a single string is wrapped into a list.
                Settings and results are written under the first folder.
            tables (list): Table names to read from each database (png_list is added).
            neg, pos, mix: Identifiers mapped to 'neg'/'pos'/'mix' by ``map_condition``
                on the 'col' column. ``neg``/``pos`` are also the control values of
                ``col_to_compare`` when ``embedding_by_controls`` is set.
            exclude_conditions (str or list): Conditions ('neg', 'pos', 'mix', 'screen')
                whose rows are dropped.
            row_limit (int or None): Randomly sample at most this many rows.
            embedding_by_controls (bool): Fit the reducer on control rows only, then
                apply it to all rows.
            col_to_compare (str): Column holding the control values.
            filter_by, remove_highly_correlated, log_data, exclude: Passed to
                ``preprocess_data``.
            resnet_features (bool): Embed image features via
                ``generate_umap_from_images`` instead of measurements.
            n_neighbors, min_dist, metric, eps, min_samples, clustering,
            reduction_method, n_jobs: Reduction/clustering parameters.
            remove_cluster_noise (bool): Remove noise points via ``remove_noise``.
            color_by (str or None): Column used to colour points instead of clusters;
                disables noise removal, outlines and line smoothing.
            image_nr, img_zoom, plot_by_cluster, plot_outlines, plot_points,
            plot_images, smooth_lines, black_background, figuresize, dot_size,
            remove_image_canvas: Plotting options for ``plot_embedding``.
            plot_cluster_grids (bool): Also plot an image grid per cluster (needs plot_images).
            save_figure (bool): Save the figures as PDF under ``<src>/results``.
            analyze_clusters (bool): Run ``cluster_feature_analysis`` and save its output.
            verbose (bool): Print progress information.

    Returns:
        pd.DataFrame: The loaded data with an added 'cluster' column. It holds the
        cluster labels, or the ``color_by`` values when ``color_by`` is set.
    """
    from .io import _read_and_join_tables
    from .utils import get_db_paths, preprocess_data, reduction_and_clustering, remove_noise, generate_colors, correct_paths, plot_embedding, plot_clusters_grid, get_umap_image_settings
    from .alpha import cluster_feature_analysis, generate_umap_from_images

    # Fresh dict per call: settings is mutated below, and a shared `{}` default
    # could leak state from one call into the next.
    if settings is None:
        settings = {}
    settings = get_umap_image_settings(settings)

    if isinstance(settings['src'], str):
        settings['src'] = [settings['src']]

    if settings['plot_images'] is False:
        settings['black_background'] = False

    # Colouring by a metadata column replaces cluster colouring, so cluster-specific
    # post-processing is switched off.
    if settings['color_by']:
        settings['remove_cluster_noise'] = False
        settings['plot_outlines'] = False
        settings['smooth_lines'] = False

    # Persist the effective settings next to the first source folder.
    settings_df = pd.DataFrame(list(settings.items()), columns=['Key', 'Value'])
    settings_dir = os.path.join(settings['src'][0], 'settings')
    settings_csv = os.path.join(settings_dir, 'embedding_settings.csv')
    os.makedirs(settings_dir, exist_ok=True)
    settings_df.to_csv(settings_csv, index=False)
    display(settings_df)

    db_paths = get_db_paths(settings['src'])

    tables = settings['tables'] + ['png_list']
    all_df = pd.DataFrame()
    for i, db_path in enumerate(db_paths):
        df = _read_and_join_tables(db_path, table_names=tables)
        df, _ = correct_paths(df, settings['src'][i])
        all_df = pd.concat([all_df, df], axis=0)

    all_df['cond'] = all_df['col'].apply(map_condition, neg=settings['neg'], pos=settings['pos'], mix=settings['mix'])

    if settings['exclude_conditions']:
        if isinstance(settings['exclude_conditions'], str):
            settings['exclude_conditions'] = [settings['exclude_conditions']]
        row_count_before = len(all_df)
        all_df = all_df[~all_df['cond'].isin(settings['exclude_conditions'])]
        if settings['verbose']:
            print(f'Excluded {row_count_before - len(all_df)} rows after excluding: {settings["exclude_conditions"]}, rows left: {len(all_df)}')

    if settings['row_limit'] is not None:
        # Clamp so a limit larger than the data does not make `sample` raise.
        all_df = all_df.sample(n=min(settings['row_limit'], len(all_df)), random_state=42)

    image_paths = all_df['png_path'].to_list()

    if settings['embedding_by_controls']:
        compare_col = settings['col_to_compare']
        col_to_compare = all_df[compare_col].reset_index(drop=True)

        # Preprocess once; the result is used both to select controls and to embed all rows.
        full_numeric_data = preprocess_data(all_df, settings['filter_by'], settings['remove_highly_correlated'], settings['log_data'], settings['exclude'])

        # Attach the comparison column (positionally aligned) to pick out the control rows.
        numeric_data_df = pd.DataFrame(full_numeric_data).reset_index(drop=True)
        numeric_data_df[compare_col] = col_to_compare

        positive_control_df = numeric_data_df[numeric_data_df[compare_col] == settings['pos']]
        negative_control_df = numeric_data_df[numeric_data_df[compare_col] == settings['neg']]
        control_numeric_data = pd.concat([positive_control_df, negative_control_df]).drop(columns=[compare_col]).values

        # Train the reducer on control data only ...
        _, _, reducer = reduction_and_clustering(control_numeric_data, settings['n_neighbors'], settings['min_dist'], settings['metric'], settings['eps'], settings['min_samples'], settings['clustering'], settings['reduction_method'], settings['verbose'], n_jobs=settings['n_jobs'], mode='fit', model=False)

        # ... then apply it to the entire dataset.
        embedding, labels, _ = reduction_and_clustering(full_numeric_data, settings['n_neighbors'], settings['min_dist'], settings['metric'], settings['eps'], settings['min_samples'], settings['clustering'], settings['reduction_method'], settings['verbose'], n_jobs=settings['n_jobs'], mode=None, model=reducer)

    elif settings['resnet_features']:
        _, embedding, labels = generate_umap_from_images(image_paths, settings['n_neighbors'], settings['min_dist'], settings['metric'], settings['clustering'], settings['eps'], settings['min_samples'], settings['n_jobs'], settings['verbose'])

    else:
        numeric_data = preprocess_data(all_df, settings['filter_by'], settings['remove_highly_correlated'], settings['log_data'], settings['exclude'])
        embedding, labels, _ = reduction_and_clustering(numeric_data, settings['n_neighbors'], settings['min_dist'], settings['metric'], settings['eps'], settings['min_samples'], settings['clustering'], settings['reduction_method'], settings['verbose'], n_jobs=settings['n_jobs'])

    if settings['remove_cluster_noise']:
        # NOTE(review): remove_noise presumably drops noise points (label -1) from embedding
        # and labels. If any are dropped, image_paths and all_df are no longer aligned with
        # them and the 'cluster' assignment below may fail — confirm.
        embedding, labels = remove_noise(embedding, labels)

    if settings['color_by']:
        labels = all_df[settings['color_by']]

    colors = generate_colors(len(np.unique(labels)), settings['black_background'])

    umap_plt = plot_embedding(embedding, image_paths, labels, settings['image_nr'], settings['img_zoom'], colors, settings['plot_by_cluster'], settings['plot_outlines'], settings['plot_points'], settings['plot_images'], settings['smooth_lines'], settings['black_background'], settings['figuresize'], settings['dot_size'], settings['remove_image_canvas'], settings['verbose'])

    plot_grid = settings['plot_cluster_grids'] and settings['plot_images']
    if plot_grid:
        grid_plt = plot_clusters_grid(embedding, labels, settings['image_nr'], image_paths, colors, settings['figuresize'], settings['black_background'], settings['verbose'])

    results_dir = os.path.join(settings['src'][0], 'results')

    if settings['save_figure']:
        os.makedirs(results_dir, exist_ok=True)
        reduction_method = settings['reduction_method'].upper()
        embedding_path = os.path.join(results_dir, f'{reduction_method}_embedding.pdf')
        umap_plt.savefig(embedding_path, format='pdf')
        print(f'Saved {reduction_method} embedding to {embedding_path}')
        if plot_grid:
            grid_path = os.path.join(results_dir, f'{reduction_method}_grid.pdf')
            grid_plt.savefig(grid_path, format='pdf')
            print(f'Saved {reduction_method} grid to {grid_path}')

    all_df['cluster'] = labels

    results_csv = os.path.join(results_dir, 'embedding_results.csv')
    os.makedirs(results_dir, exist_ok=True)
    all_df.to_csv(results_csv, index=False)
    print(f'Results saved to {results_csv}')

    if settings['analyze_clusters']:
        combined_results = cluster_feature_analysis(all_df)
        cluster_results_csv = os.path.join(results_dir, 'cluster_results.csv')
        combined_results.to_csv(cluster_results_csv, index=False)
        print(f'Cluster results saved to {cluster_results_csv}')

    return all_df
|
3296
|
+
|
3297
|
+
# Translate a plate-column identifier into an experimental condition label.
def map_condition(col_value, neg='c1', pos='c2', mix='c3'):
    """
    Classify a column identifier as negative control, positive control, mix or screen.

    Args:
        col_value: The value to classify (applied to the 'col' column elsewhere in this module).
        neg: Identifier of the negative-control column. Default 'c1'.
        pos: Identifier of the positive-control column. Default 'c2'.
        mix: Identifier of the mixed column. Default 'c3'.

    Returns:
        str: 'neg', 'pos' or 'mix' for the first identifier equal to ``col_value``
        (checked in that order), otherwise 'screen'.
    """
    for reference, condition in ((neg, 'neg'), (pos, 'pos'), (mix, 'mix')):
        if col_value == reference:
            return condition
    return 'screen'
|
3307
|
+
|
3308
|
+
def reducer_hyperparameter_search(settings=None, reduction_params=None, dbscan_params=None, kmeans_params=None, save=False):
    """
    Grid-search UMAP or tSNE hyperparameters combined with clustering hyperparameters.

    Draws one subplot per (reduction parameter set, clustering parameter set) pair.

    Parameters:
        settings (dict, optional): Settings completed by ``get_umap_image_settings``.
            Keys read here include:
            src (str or list): Source folder(s); results go under the first one.
            tables (list): Table names to read from each database.
            neg, pos, mix, exclude_conditions: Condition mapping / exclusion, as in
                ``generate_image_umap``.
            row_limit (int or None): Randomly sample at most this many rows.
            filter_by, remove_highly_correlated, log_data, exclude: Passed to ``preprocess_data``.
            reduction_method (str): 'umap' or 'tsne'; overridden when the parameter
                keys imply the other method.
            metric, n_jobs, verbose: Passed to ``search_reduction_and_clustering``.
            color_by (str or None): Column used to colour points instead of cluster labels.
            dot_size (int): Size of the points in the scatter plots.
        reduction_params (dict or list of dict): Reduction hyperparameters to test.
            A dict containing 'n_neighbors' (and optionally 'min_dist') selects UMAP.
            A dict containing 'perplexity' selects tSNE. Float values of
            'n_neighbors'/'perplexity' are treated as fractions of the number of rows.
        dbscan_params (dict or list of dict, optional): DBSCAN settings ('eps', 'min_samples').
        kmeans_params (dict or list of dict, optional): KMeans settings to test.
        save (bool): Save the figure to ``<src>/results/hyperparameter_search.pdf``
            instead of showing it.

    Returns:
        None

    Raises:
        ValueError: If no reduction or clustering parameters are given, if
            'n_neighbors' and 'perplexity' are mixed, or if the reduction method
            is unsupported.
    """
    from .io import _read_and_join_tables
    from .utils import get_db_paths, preprocess_data, search_reduction_and_clustering, generate_colors, get_umap_image_settings

    # Fresh dict per call instead of a shared mutable default.
    if settings is None:
        settings = {}
    settings = get_umap_image_settings(settings)
    pointsize = settings['dot_size']

    if isinstance(dbscan_params, dict):
        dbscan_params = [dbscan_params]
    if isinstance(kmeans_params, dict):
        kmeans_params = [kmeans_params]
    if isinstance(reduction_params, dict):
        reduction_params = [reduction_params]

    if not reduction_params:
        raise ValueError("reduction_params must be a dict or a non-empty list of dicts.")

    # Infer the reduction method from the parameter keys; mixing both kinds is ambiguous.
    # Without either key, keep the method configured in settings (defaults are used below).
    has_umap_params = any('n_neighbors' in param for param in reduction_params)
    has_tsne_params = any('perplexity' in param for param in reduction_params)
    if has_umap_params and has_tsne_params:
        raise ValueError("Reduction parameters must include 'n_neighbors' for UMAP or 'perplexity' for tSNE, not both.")
    if has_umap_params:
        reduction_method = 'umap'
    elif has_tsne_params:
        reduction_method = 'tsne'
    else:
        reduction_method = settings['reduction_method'].lower()

    if settings['reduction_method'].lower() != reduction_method:
        settings['reduction_method'] = reduction_method
        print(f'Changed reduction method to {reduction_method} based on the provided parameters.')

    if settings['verbose']:
        display(pd.DataFrame(list(settings.items()), columns=['Key', 'Value']))

    db_paths = get_db_paths(settings['src'])

    tables = settings['tables']
    all_df = pd.DataFrame()
    for db_path in db_paths:
        df = _read_and_join_tables(db_path, table_names=tables)
        all_df = pd.concat([all_df, df], axis=0)

    all_df['cond'] = all_df['col'].apply(map_condition, neg=settings['neg'], pos=settings['pos'], mix=settings['mix'])

    if settings['exclude_conditions']:
        if isinstance(settings['exclude_conditions'], str):
            settings['exclude_conditions'] = [settings['exclude_conditions']]
        row_count_before = len(all_df)
        all_df = all_df[~all_df['cond'].isin(settings['exclude_conditions'])]
        if settings['verbose']:
            print(f'Excluded {row_count_before - len(all_df)} rows after excluding: {settings["exclude_conditions"]}, rows left: {len(all_df)}')

    if settings['row_limit'] is not None:
        # Clamp so a limit larger than the data does not make `sample` raise.
        all_df = all_df.sample(n=min(settings['row_limit'], len(all_df)), random_state=42)

    numeric_data = preprocess_data(all_df, settings['filter_by'], settings['remove_highly_correlated'], settings['log_data'], settings['exclude'])

    # Tag each clustering configuration with its method. Copy each dict so the caller's
    # dicts are not mutated.
    clustering_params = [{**param, 'method': 'dbscan'} for param in (dbscan_params or [])]
    clustering_params += [{**param, 'method': 'kmeans'} for param in (kmeans_params or [])]

    if not clustering_params:
        raise ValueError("At least one of dbscan_params or kmeans_params must be provided.")

    print('Testing paramiters:', reduction_params)
    print('Testing clustering paramiters:', clustering_params)

    # One row per reduction setting, one column per clustering setting.
    grid_rows = len(reduction_params)
    grid_cols = len(clustering_params)
    fig, axs = plt.subplots(grid_rows, grid_cols, figsize=(grid_cols * 10, grid_rows * 10), squeeze=False)
    single_line_grid = grid_rows <= 1 or grid_cols <= 1

    for i, reduction_param in enumerate(reduction_params):
        for j, clustering_param in enumerate(clustering_params):
            ax = axs[i, j]
            if single_line_grid:
                ax.axis('off')

            method = settings['reduction_method'].lower()
            if method == 'umap':
                n_neighbors = reduction_param.get('n_neighbors', 15)
                # A float is interpreted as a fraction of the number of rows.
                if isinstance(n_neighbors, float):
                    n_neighbors = int(n_neighbors * len(numeric_data))
                min_dist = reduction_param.get('min_dist', 0.1)
                embedding, labels = search_reduction_and_clustering(numeric_data, n_neighbors, min_dist, settings['metric'],
                                                                    clustering_param.get('eps', 0.5), clustering_param.get('min_samples', 5),
                                                                    clustering_param['method'], settings['reduction_method'], settings['verbose'], reduction_param, n_jobs=settings['n_jobs'])
            elif method == 'tsne':
                perplexity = reduction_param.get('perplexity', 30)
                # A float is interpreted as a fraction of the number of rows.
                if isinstance(perplexity, float):
                    perplexity = int(perplexity * len(numeric_data))
                embedding, labels = search_reduction_and_clustering(numeric_data, perplexity, 0.1, settings['metric'],
                                                                    clustering_param.get('eps', 0.5), clustering_param.get('min_samples', 5),
                                                                    clustering_param['method'], settings['reduction_method'], settings['verbose'], reduction_param, n_jobs=settings['n_jobs'])
            else:
                raise ValueError(f"Unsupported reduction method: {settings['reduction_method']}. Supported methods are 'UMAP' and 'tSNE'")

            if settings['color_by']:
                unique_groups = all_df[settings['color_by']].unique()
                colors = generate_colors(len(unique_groups), False)
                for group, color in zip(unique_groups, colors):
                    indices = (all_df[settings['color_by']] == group).to_numpy()
                    ax.scatter(embedding[indices, 0], embedding[indices, 1], s=pointsize, label=f"{group}", color=color)
            else:
                unique_labels = np.unique(labels)
                colors = generate_colors(len(unique_labels), False)
                for label, color in zip(unique_labels, colors):
                    ax.scatter(embedding[labels == label, 0], embedding[labels == label, 1], s=pointsize, label=f"Cluster {label}", color=color)

            ax.set_title(f"{settings['reduction_method']} {reduction_param}\n{clustering_param['method']} {clustering_param}")
            ax.legend()

    plt.tight_layout()
    if save:
        # src may be a single folder or a list of folders; results go to the first one.
        src = settings['src']
        base_dir = src[0] if isinstance(src, (list, tuple)) else src
        results_dir = os.path.join(base_dir, 'results')
        os.makedirs(results_dir, exist_ok=True)
        plt.savefig(os.path.join(results_dir, 'hyperparameter_search.pdf'))
    else:
        plt.show()

    return
|