spacr 0.0.36__py3-none-any.whl → 0.0.62__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- spacr/__init__.py +11 -4
- spacr/__main__.py +0 -2
- spacr/alpha.py +514 -2
- spacr/annotate_app.py +112 -116
- spacr/core.py +864 -728
- spacr/deep_spacr.py +696 -0
- spacr/foldseek.py +2 -16
- spacr/graph_learning.py +297 -253
- spacr/gui.py +9 -8
- spacr/gui_2.py +90 -0
- spacr/gui_classify_app.py +3 -4
- spacr/gui_mask_app.py +9 -9
- spacr/gui_measure_app.py +3 -5
- spacr/gui_utils.py +132 -33
- spacr/io.py +308 -464
- spacr/mask_app.py +109 -5
- spacr/measure.py +15 -1
- spacr/models/cp/toxo_pv_lumen.CP_model +0 -0
- spacr/old_code.py +69 -1
- spacr/plot.py +23 -6
- spacr/sequencing.py +1130 -0
- spacr/sim.py +0 -42
- spacr/timelapse.py +0 -1
- spacr/train.py +172 -13
- spacr/umap.py +0 -689
- spacr/utils.py +1322 -75
- {spacr-0.0.36.dist-info → spacr-0.0.62.dist-info}/METADATA +14 -29
- spacr-0.0.62.dist-info/RECORD +39 -0
- {spacr-0.0.36.dist-info → spacr-0.0.62.dist-info}/entry_points.txt +1 -0
- spacr-0.0.36.dist-info/RECORD +0 -35
- {spacr-0.0.36.dist-info → spacr-0.0.62.dist-info}/LICENSE +0 -0
- {spacr-0.0.36.dist-info → spacr-0.0.62.dist-info}/WHEEL +0 -0
- {spacr-0.0.36.dist-info → spacr-0.0.62.dist-info}/top_level.txt +0 -0
spacr/core.py
CHANGED
@@ -1,13 +1,10 @@
-import os, sqlite3, gc, torch, time, random, shutil, cv2, tarfile, datetime, shap
+import os, sqlite3, gc, torch, time, random, shutil, cv2, tarfile, datetime, shap
 
-# image and array processing
 import numpy as np
 import pandas as pd
 
 from cellpose import train
-import cellpose
 from cellpose import models as cp_models
-from cellpose.models import CellposeModel
 
 import statsmodels.formula.api as smf
 import statsmodels.api as sm
@@ -16,31 +13,37 @@ from IPython.display import display
 from multiprocessing import Pool, cpu_count, Value, Lock
 
 import seaborn as sns
-
+
 from skimage.measure import regionprops, label
-
+from skimage.morphology import square
 from skimage.transform import resize as resizescikit
-from sklearn.model_selection import train_test_split
 from collections import defaultdict
-import multiprocessing
 from torch.utils.data import DataLoader, random_split
-import
-
+from sklearn.cluster import KMeans
+from sklearn.decomposition import PCA
 
-
+from skimage import measure
 from sklearn.model_selection import train_test_split
 from sklearn.ensemble import IsolationForest, RandomForestClassifier, HistGradientBoostingClassifier
-from .logger import log_function_call
-
 from sklearn.linear_model import LogisticRegression
 from sklearn.inspection import permutation_importance
 from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, classification_report
-from
+from sklearn.preprocessing import StandardScaler
 
+from scipy.ndimage import binary_dilation
 from scipy.spatial.distance import cosine, euclidean, mahalanobis, cityblock, minkowski, chebyshev, hamming, jaccard, braycurtis
-
+
+import torchvision.transforms as transforms
+from xgboost import XGBClassifier
 import shap
 
+import matplotlib.pyplot as plt
+import matplotlib
+matplotlib.use('Agg')
+#import matplotlib.pyplot as plt
+
+from .logger import log_function_call
+
 def analyze_plaques(folder):
     summary_data = []
     details_data = []
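Note: the new import block switches Matplotlib to the non-interactive Agg backend so figures can be rendered headlessly. As a point of general Matplotlib behavior rather than anything specific to spacr: the conventional ordering puts matplotlib.use('Agg') before the pyplot import, although recent Matplotlib versions also allow switching after pyplot is loaded, which is what the diff does:

    import matplotlib
    matplotlib.use('Agg')            # select the headless backend first
    import matplotlib.pyplot as plt  # pyplot then binds to Agg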
@@ -77,73 +80,46 @@ def analyze_plaques(folder):
 
     print(f"Analysis completed and saved to database '{db_name}'.")
 
-def generate_cp_masks(settings):
-
-    src = settings['src']
-    model_name = settings['model_name']
-    channels = settings['channels']
-    diameter = settings['diameter']
-    regex = '.tif'
-    #flow_threshold = 30
-    cellprob_threshold = settings['cellprob_threshold']
-    figuresize = 25
-    cmap = 'inferno'
-    verbose = settings['verbose']
-    plot = settings['plot']
-    save = settings['save']
-    custom_model = settings['custom_model']
-    signal_thresholds = 1000
-    normalize = settings['normalize']
-    resize = settings['resize']
-    target_height = settings['width_height'][1]
-    target_width = settings['width_height'][0]
-    rescale = settings['rescale']
-    resample = settings['resample']
-    net_avg = settings['net_avg']
-    invert = settings['invert']
-    circular = settings['circular']
-    percentiles = settings['percentiles']
-    overlay = settings['overlay']
-    grayscale = settings['grayscale']
-    flow_threshold = settings['flow_threshold']
-    batch_size = settings['batch_size']
-
-    dst = os.path.join(src,'masks')
-    os.makedirs(dst, exist_ok=True)
-
-    identify_masks(src, dst, model_name, channels, diameter, batch_size, flow_threshold, cellprob_threshold, figuresize, cmap, verbose, plot, save, custom_model, signal_thresholds, normalize, resize, target_height, target_width, rescale, resample, net_avg, invert, circular, percentiles, overlay, grayscale)
-
 def train_cellpose(settings):
 
     from .io import _load_normalized_images_and_labels, _load_images_and_labels
     from .utils import resize_images_and_labels
 
     img_src = settings['img_src']
-    mask_src = os.path.join(img_src, '
+    mask_src = os.path.join(img_src, 'masks')
 
-    model_name = settings
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    model_name = settings.setdefault( 'model_name', '')
+
+    model_name = settings.setdefault('model_name', 'model_name')
+
+    model_type = settings.setdefault( 'model_type', 'cyto')
+    learning_rate = settings.setdefault( 'learning_rate', 0.01)
+    weight_decay = settings.setdefault( 'weight_decay', 1e-05)
+    batch_size = settings.setdefault( 'batch_size', 50)
+    n_epochs = settings.setdefault( 'n_epochs', 100)
+    from_scratch = settings.setdefault( 'from_scratch', False)
+    diameter = settings.setdefault( 'diameter', 40)
+
+    remove_background = settings.setdefault( 'remove_background', False)
+    background = settings.setdefault( 'background', 100)
+    Signal_to_noise = settings.setdefault( 'Signal_to_noise', 10)
+    verbose = settings.setdefault( 'verbose', False)
+
+
+    channels = settings.setdefault( 'channels', [0,0])
+    normalize = settings.setdefault( 'normalize', True)
+    percentiles = settings.setdefault( 'percentiles', None)
+    circular = settings.setdefault( 'circular', False)
+    invert = settings.setdefault( 'invert', False)
+    resize = settings.setdefault( 'resize', False)
+
+    if resize:
+        target_height = settings['width_height'][1]
+        target_width = settings['width_height'][0]
+
+    grayscale = settings.setdefault( 'grayscale', True)
+    rescale = settings.setdefault( 'channels', False)
+    test = settings.setdefault( 'test', False)
 
     if test:
         test_img_src = os.path.join(os.path.dirname(img_src), 'test')
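Note: train_cellpose now resolves its configuration through dict.setdefault, which returns the stored value when the key already exists and otherwise inserts the default, so the settings dict written to CSV later reflects the values actually used. A minimal sketch of the pattern, with made-up keys:

    settings = {'batch_size': 8}
    batch_size = settings.setdefault('batch_size', 50)  # key present -> 8, dict unchanged
    n_epochs = settings.setdefault('n_epochs', 100)     # key missing -> 100, now stored
    assert settings == {'batch_size': 8, 'n_epochs': 100}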
@@ -177,22 +153,21 @@ def train_cellpose(settings):
 
         image_files = [os.path.join(img_src, f) for f in os.listdir(img_src) if f.endswith('.tif')]
         label_files = [os.path.join(mask_src, f) for f in os.listdir(mask_src) if f.endswith('.tif')]
-        images, masks, image_names, mask_names = _load_normalized_images_and_labels(image_files, label_files,
+        images, masks, image_names, mask_names = _load_normalized_images_and_labels(image_files, label_files, channels, percentiles, circular, invert, verbose, remove_background, background, Signal_to_noise)
         images = [np.squeeze(img) if img.shape[-1] == 1 else img for img in images]
 
         if test:
            test_image_files = [os.path.join(test_img_src, f) for f in os.listdir(test_img_src) if f.endswith('.tif')]
            test_label_files = [os.path.join(test_mask_src, f) for f in os.listdir(test_mask_src) if f.endswith('.tif')]
-           test_images, test_masks, test_image_names, test_mask_names = _load_normalized_images_and_labels(
+           test_images, test_masks, test_image_names, test_mask_names = _load_normalized_images_and_labels(test_image_files, test_label_files, channels, percentiles, circular, invert, verbose, remove_background, background, Signal_to_noise)
            test_images = [np.squeeze(img) if img.shape[-1] == 1 else img for img in test_images]
 
-
     else:
         images, masks, image_names, mask_names = _load_images_and_labels(img_src, mask_src, circular, invert)
         images = [np.squeeze(img) if img.shape[-1] == 1 else img for img in images]
 
         if test:
-            test_images, test_masks, test_image_names, test_mask_names = _load_images_and_labels(img_src=test_img_src, mask_src=test_mask_src, circular=circular, invert=
+            test_images, test_masks, test_image_names, test_mask_names = _load_images_and_labels(img_src=test_img_src, mask_src=test_mask_src, circular=circular, invert=invert)
            test_images = [np.squeeze(img) if img.shape[-1] == 1 else img for img in test_images]
 
     if resize:
@@ -250,179 +225,6 @@ def train_cellpose(settings):
 
     return print(f"Model saved at: {model_save_path}/{model_name}")
 
-def train_cellpose_v1(settings):
-
-    from .io import _load_normalized_images_and_labels, _load_images_and_labels
-    from .utils import resize_images_and_labels
-
-    img_src = settings['img_src']
-
-    mask_src = os.path.join(img_src, 'mask')
-
-    model_name = settings['model_name']
-    model_type = settings['model_type']
-    learning_rate = settings['learning_rate']
-    weight_decay = settings['weight_decay']
-    batch_size = settings['batch_size']
-    n_epochs = settings['n_epochs']
-    verbose = settings['verbose']
-
-    signal_thresholds = 100 #settings['signal_thresholds']
-
-    channels = settings['channels']
-    from_scratch = settings['from_scratch']
-    diameter = settings['diameter']
-    resize = settings['resize']
-    rescale = settings['rescale']
-    normalize = settings['normalize']
-    target_height = settings['width_height'][1]
-    target_width = settings['width_height'][0]
-    circular = settings['circular']
-    invert = settings['invert']
-    percentiles = settings['percentiles']
-    grayscale = settings['grayscale']
-
-    if model_type == 'cyto':
-        settings['diameter'] = 30
-        diameter = settings['diameter']
-        print(f'Cyto model must have diamiter 30. Diameter set the 30')
-
-    if model_type == 'nuclei':
-        settings['diameter'] = 17
-        diameter = settings['diameter']
-        print(f'Nuclei model must have diamiter 17. Diameter set the 17')
-
-    print(settings)
-
-    if from_scratch:
-        model_name=f'scratch_{model_name}_{model_type}_e{n_epochs}_X{target_width}_Y{target_height}.CP_model'
-    else:
-        model_name=f'{model_name}_{model_type}_e{n_epochs}_X{target_width}_Y{target_height}.CP_model'
-
-    model_save_path = os.path.join(mask_src, 'models', 'cellpose_model')
-    print(model_save_path)
-    os.makedirs(model_save_path, exist_ok=True)
-
-    settings_df = pd.DataFrame(list(settings.items()), columns=['Key', 'Value'])
-    settings_csv = os.path.join(model_save_path,f'{model_name}_settings.csv')
-    settings_df.to_csv(settings_csv, index=False)
-
-    if not from_scratch:
-        model = cp_models.CellposeModel(gpu=True, model_type=model_type)
-
-    else:
-        model = cp_models.CellposeModel(gpu=True, model_type=model_type, pretrained_model=None)
-
-    if normalize:
-        image_files = [os.path.join(img_src, f) for f in os.listdir(img_src) if f.endswith('.tif')]
-        label_files = [os.path.join(mask_src, f) for f in os.listdir(mask_src) if f.endswith('.tif')]
-
-        images, masks, image_names, mask_names = _load_normalized_images_and_labels(image_files, label_files, signal_thresholds, channels=channels, percentiles=percentiles, circular=circular, invert=invert, visualize=verbose)
-        images = [np.squeeze(img) if img.shape[-1] == 1 else img for img in images]
-    else:
-        images, masks, image_names, mask_names = _load_images_and_labels(img_src, mask_src, circular, invert)
-        images = [np.squeeze(img) if img.shape[-1] == 1 else img for img in images]
-
-    if resize:
-        images, masks = resize_images_and_labels(images, masks, target_height, target_width, show_example=True)
-
-    if model_type == 'cyto':
-        cp_channels = [0,1]
-    if model_type == 'cyto2':
-        cp_channels = [0,2]
-    if model_type == 'nucleus':
-        cp_channels = [0,0]
-    if grayscale:
-        cp_channels = [0,0]
-        images = [np.squeeze(img) if img.ndim == 3 and 1 in img.shape else img for img in images]
-
-    masks = [np.squeeze(mask) if mask.ndim == 3 and 1 in mask.shape else mask for mask in masks]
-
-    print(f'image shape: {images[0].shape}, image type: images[0].shape mask shape: {masks[0].shape}, image type: masks[0].shape')
-    save_every = int(n_epochs/10)
-    if save_every < 10:
-        save_every = n_epochs
-
-
-    #print('cellpose image input dtype', images[0].dtype)
-    #print('cellpose mask input dtype', masks[0].dtype)
-
-    # Train the model
-    #model.train(train_data=images, #(list of arrays (2D or 3D)) – images for training
-
-    #model.train(train_data=images, #(list of arrays (2D or 3D)) – images for training
-    #            train_labels=masks, #(list of arrays (2D or 3D)) – labels for train_data, where 0=no masks; 1,2,…=mask labels can include flows as additional images
-    #            train_files=image_names, #(list of strings) – file names for images in train_data (to save flows for future runs)
-    #            channels=cp_channels, #(list of ints (default, None)) – channels to use for training
-    #            normalize=False, #(bool (default, True)) – normalize data so 0.0=1st percentile and 1.0=99th percentile of image intensities in each channel
-    #            save_path=model_save_path, #(string (default, None)) – where to save trained model, if None it is not saved
-    #            save_every=save_every, #(int (default, 100)) – save network every [save_every] epochs
-    #            learning_rate=learning_rate, #(float or list/np.ndarray (default, 0.2)) – learning rate for training, if list, must be same length as n_epochs
-    #            n_epochs=n_epochs, #(int (default, 500)) – how many times to go through whole training set during training
-    #            weight_decay=weight_decay, #(float (default, 0.00001)) –
-    #            SGD=True, #(bool (default, True)) – use SGD as optimization instead of RAdam
-    #            batch_size=batch_size, #(int (optional, default 8)) – number of 224x224 patches to run simultaneously on the GPU (can make smaller or bigger depending on GPU memory usage)
-    #            nimg_per_epoch=None, #(int (optional, default None)) – minimum number of images to train on per epoch, with a small training set (< 8 images) it may help to set to 8
-    #            rescale=rescale, #(bool (default, True)) – whether or not to rescale images to diam_mean during training, if True it assumes you will fit a size model after training or resize your images accordingly, if False it will try to train the model to be scale-invariant (works worse)
-    #            min_train_masks=1, #(int (default, 5)) – minimum number of masks an image must have to use in training set
-    #            model_name=model_name) #(str (default, None)) – name of network, otherwise saved with name as params + training start time
-
-
-    train.train_seg(model.net,
-                    train_data=images,
-                    train_labels=masks,
-                    train_files=image_names,
-                    train_labels_files=None,
-                    train_probs=None,
-                    test_data=None,
-                    test_labels=None,
-                    test_files=None,
-                    test_labels_files=None,
-                    test_probs=None,
-                    load_files=True,
-                    batch_size=batch_size,
-                    learning_rate=learning_rate,
-                    n_epochs=n_epochs,
-                    weight_decay=weight_decay,
-                    momentum=0.9,
-                    SGD=False,
-                    channels=cp_channels,
-                    channel_axis=None,
-                    #rgb=False,
-                    normalize=False,
-                    compute_flows=False,
-                    save_path=model_save_path,
-                    save_every=save_every,
-                    nimg_per_epoch=None,
-                    nimg_test_per_epoch=None,
-                    rescale=rescale,
-                    #scale_range=None,
-                    #bsize=224,
-                    min_train_masks=1,
-                    model_name=model_name)
-
-    #model_save_path = train.train_seg(model.net,
-    #                                  train_data=images,
-    #                                  train_files=image_names,
-    #                                  train_labels=masks,
-    #                                  channels=cp_channels,
-    #                                  normalize=False,
-    #                                  save_every=save_every,
-    #                                  learning_rate=learning_rate,
-    #                                  n_epochs=n_epochs,
-    #                                  #test_data=test_images,
-    #                                  #test_labels=test_labels,
-    #                                  weight_decay=weight_decay,
-    #                                  SGD=True,
-    #                                  batch_size=batch_size,
-    #                                  nimg_per_epoch=None,
-    #                                  rescale=rescale,
-    #                                  min_train_masks=1,
-    #                                  model_name=model_name)
-
-
-    return print(f"Model saved at: {model_save_path}/{model_name}")
-
 def analyze_data_reg(sequencing_loc, dv_loc, agg_type = 'mean', dv_col='pred', transform=None, min_cell_count=50, min_reads=100, min_wells=2, max_wells=1000, min_frequency=0.0,remove_outlier_genes=False, refine_model=False,by_plate=False, regression_type='mlr', alpha_value=0.01, fishers=False, fisher_threshold=0.9):
 
     from .plot import _reg_v_plot
@@ -984,15 +786,6 @@ def merge_pred_mes(src,
 
     if verbose:
         _plot_histograms_and_stats(df=joined_df)
-
-    #dv = joined_df.copy()
-    #if 'prc' not in dv.columns:
-        #dv['prc'] = dv['plate'] + '_' + dv['row'] + '_' + dv['col']
-    #dv = dv[['pred']].groupby('prc').mean()
-    #dv.set_index('prc', inplace=True)
-
-    #loc = '/mnt/data/CellVoyager/20x/tsg101/crispr_screen/all/measurements/dv.csv'
-    #dv.to_csv(loc, index=True, header=True, mode='w')
 
     return joined_df
 
@@ -1282,7 +1075,6 @@ def apply_model(src, model_path, image_size=224, batch_size=64, normalize=True,
     torch.cuda.memory.empty_cache()
     return df
 
-
 def generate_training_data_file_list(src,
                                      target='protein of interest',
                                      cell_dim=4,
@@ -1483,158 +1275,27 @@ def generate_training_dataset(src, mode='annotation', annotation_column='test',
     if len(custom_measurement) == 1:
         print(f'Classes will be defined by the Q1 and Q3 quantiles of recruitment ({custom_measurement[0]})')
         df['recruitment'] = df[f'{custom_measurement[0]}']
-    else:
-        print(f'Classes will be defined by the Q1 and Q3 quantiles of recruitment (pathogen/cytoplasm for channel {channel_of_interest})')
-        df['recruitment'] = df[f'pathogen_channel_{channel_of_interest}_mean_intensity']/df[f'cytoplasm_channel_{channel_of_interest}_mean_intensity']
-
-    q25 = df['recruitment'].quantile(0.25)
-    q75 = df['recruitment'].quantile(0.75)
-    df_lower = df[df['recruitment'] <= q25]
-    df_upper = df[df['recruitment'] >= q75]
-
-    class_paths_lower = get_paths_from_db(df=df_lower, png_df=png_list_df, image_type=png_type)
-
-    class_paths_lower = random.sample(class_paths_lower['png_path'].tolist(), size)
-    class_paths_ls.append(class_paths_lower)
-
-    class_paths_upper = get_paths_from_db(df=df_upper, png_df=png_list_df, image_type=png_type)
-    class_paths_upper = random.sample(class_paths_upper['png_path'].tolist(), size)
-    class_paths_ls.append(class_paths_upper)
-
-    generate_dataset_from_lists(dst, class_data=class_paths_ls, classes=classes, test_split=0.1)
-
-    return
-
-def generate_loaders_v1(src, train_mode='erm', mode='train', image_size=224, batch_size=32, classes=['nc','pc'], num_workers=None, validation_split=0.0, max_show=2, pin_memory=False, normalize=False, verbose=False):
-    """
-    Generate data loaders for training and validation/test datasets.
-
-    Parameters:
-    - src (str): The source directory containing the data.
-    - train_mode (str): The training mode. Options are 'erm' (Empirical Risk Minimization) or 'irm' (Invariant Risk Minimization).
-    - mode (str): The mode of operation. Options are 'train' or 'test'.
-    - image_size (int): The size of the input images.
-    - batch_size (int): The batch size for the data loaders.
-    - classes (list): The list of classes to consider.
-    - num_workers (int): The number of worker threads for data loading.
-    - validation_split (float): The fraction of data to use for validation when train_mode is 'erm'.
-    - max_show (int): The maximum number of images to show when verbose is True.
-    - pin_memory (bool): Whether to pin memory for faster data transfer.
-    - normalize (bool): Whether to normalize the input images.
-    - verbose (bool): Whether to print additional information and show images.
-
-    Returns:
-    - train_loaders (list): List of data loaders for training datasets.
-    - val_loaders (list): List of data loaders for validation datasets.
-    - plate_names (list): List of plate names (only applicable when train_mode is 'irm').
-    """
-
-    from .io import MyDataset
-    from .plot import _imshow
-
-    plate_to_filenames = defaultdict(list)
-    plate_to_labels = defaultdict(list)
-    train_loaders = []
-    val_loaders = []
-    plate_names = []
-
-    if normalize:
-        transform = transforms.Compose([
-            transforms.ToTensor(),
-            transforms.CenterCrop(size=(image_size, image_size)),
-            transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])
-    else:
-        transform = transforms.Compose([
-            transforms.ToTensor(),
-            transforms.CenterCrop(size=(image_size, image_size))])
-
-    if mode == 'train':
-        data_dir = os.path.join(src, 'train')
-        shuffle = True
-        print(f'Generating Train and validation datasets')
-
-    elif mode == 'test':
-        data_dir = os.path.join(src, 'test')
-        val_loaders = []
-        validation_split=0.0
-        shuffle = True
-        print(f'Generating test dataset')
-
-    else:
-        print(f'mode:{mode} is not valid, use mode = train or test')
-        return
-
-    if train_mode == 'erm':
-        data = MyDataset(data_dir, classes, transform=transform, shuffle=shuffle, pin_memory=pin_memory)
-        #train_loaders = DataLoader(data, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers if num_workers is not None else 0, pin_memory=pin_memory)
-        if validation_split > 0:
-            train_size = int((1 - validation_split) * len(data))
-            val_size = len(data) - train_size
-
-            print(f'Train data:{train_size}, Validation data:{val_size}')
-
-            train_dataset, val_dataset = random_split(data, [train_size, val_size])
-
-            train_loaders = DataLoader(train_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers if num_workers is not None else 0, pin_memory=pin_memory)
-            val_loaders = DataLoader(val_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers if num_workers is not None else 0, pin_memory=pin_memory)
-        else:
-            train_loaders = DataLoader(data, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers if num_workers is not None else 0, pin_memory=pin_memory)
-
-    elif train_mode == 'irm':
-        data = MyDataset(data_dir, classes, transform=transform, shuffle=shuffle, pin_memory=pin_memory)
-
-        for filename, label in zip(data.filenames, data.labels):
-            plate = data.get_plate(filename)
-            plate_to_filenames[plate].append(filename)
-            plate_to_labels[plate].append(label)
-
-        for plate, filenames in plate_to_filenames.items():
-            labels = plate_to_labels[plate]
-            plate_data = MyDataset(data_dir, classes, specific_files=filenames, specific_labels=labels, transform=transform, shuffle=False, pin_memory=pin_memory)
-            plate_names.append(plate)
-
-            if validation_split > 0:
-                train_size = int((1 - validation_split) * len(plate_data))
-                val_size = len(plate_data) - train_size
-
-                print(f'Train data:{train_size}, Validation data:{val_size}')
-
-                train_dataset, val_dataset = random_split(plate_data, [train_size, val_size])
-
-                train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers if num_workers is not None else 0, pin_memory=pin_memory)
-                val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers if num_workers is not None else 0, pin_memory=pin_memory)
-
-                train_loaders.append(train_loader)
-                val_loaders.append(val_loader)
-            else:
-                train_loader = DataLoader(plate_data, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers if num_workers is not None else 0, pin_memory=pin_memory)
-                train_loaders.append(train_loader)
-                val_loaders.append(None)
-
-    else:
-        print(f'train_mode:{train_mode} is not valid, use: train_mode = irm or erm')
-        return
-
-    if verbose:
-        if train_mode == 'erm':
-            for idx, (images, labels, filenames) in enumerate(train_loaders):
-                if idx >= max_show:
-                    break
-                images = images.cpu()
-                label_strings = [str(label.item()) for label in labels]
-                _imshow(images, label_strings, nrow=20, fontsize=12)
-
-        elif train_mode == 'irm':
-            for plate_name, train_loader in zip(plate_names, train_loaders):
-                print(f'Plate: {plate_name} with {len(train_loader.dataset)} images')
-                for idx, (images, labels, filenames) in enumerate(train_loader):
-                    if idx >= max_show:
-                        break
-                    images = images.cpu()
-                    label_strings = [str(label.item()) for label in labels]
-                    _imshow(images, label_strings, nrow=20, fontsize=12)
+    else:
+        print(f'Classes will be defined by the Q1 and Q3 quantiles of recruitment (pathogen/cytoplasm for channel {channel_of_interest})')
+        df['recruitment'] = df[f'pathogen_channel_{channel_of_interest}_mean_intensity']/df[f'cytoplasm_channel_{channel_of_interest}_mean_intensity']
+
+    q25 = df['recruitment'].quantile(0.25)
+    q75 = df['recruitment'].quantile(0.75)
+    df_lower = df[df['recruitment'] <= q25]
+    df_upper = df[df['recruitment'] >= q75]
+
+    class_paths_lower = get_paths_from_db(df=df_lower, png_df=png_list_df, image_type=png_type)
+
+    class_paths_lower = random.sample(class_paths_lower['png_path'].tolist(), size)
+    class_paths_ls.append(class_paths_lower)
+
+    class_paths_upper = get_paths_from_db(df=df_upper, png_df=png_list_df, image_type=png_type)
+    class_paths_upper = random.sample(class_paths_upper['png_path'].tolist(), size)
+    class_paths_ls.append(class_paths_upper)
 
-
+    generate_dataset_from_lists(dst, class_data=class_paths_ls, classes=classes, test_split=0.1)
+
+    return
 
 def generate_loaders(src, train_mode='erm', mode='train', image_size=224, batch_size=32, classes=['nc','pc'], num_workers=None, validation_split=0.0, max_show=2, pin_memory=False, normalize=False, channels=[1, 2, 3], verbose=False):
 
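Note: the surviving training-set logic labels classes from the tails of the recruitment distribution and discards the interquartile middle. A self-contained pandas illustration of the same split (toy data, hypothetical column name):

    import numpy as np
    import pandas as pd

    df = pd.DataFrame({'recruitment': np.random.rand(1000)})
    q25, q75 = df['recruitment'].quantile([0.25, 0.75])
    df_lower = df[df['recruitment'] <= q25]  # one class: low recruitment
    df_upper = df[df['recruitment'] >= q75]  # other class: high recruitment
    # the middle 50% is dropped so the two classes stay well separated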
@@ -1671,6 +1332,7 @@ def generate_loaders(src, train_mode='erm', mode='train', image_size=224, batch_
     import random
     from PIL import Image
     from torchvision.transforms import ToTensor
+    from .utils import SelectChannels
 
     chans = []
 
@@ -1687,20 +1349,6 @@ def generate_loaders(src, train_mode='erm', mode='train', image_size=224, batch_
     print(f'Training a network on channels: {channels}')
     print(f'Channel 1: Red, Channel 2: Green, Channel 3: Blue')
 
-    class SelectChannels:
-        def __init__(self, channels):
-            self.channels = channels
-
-        def __call__(self, img):
-            img = img.clone()
-            if 1 not in self.channels:
-                img[0, :, :] = 0 # Zero out the red channel
-            if 2 not in self.channels:
-                img[1, :, :] = 0 # Zero out the green channel
-            if 3 not in self.channels:
-                img[2, :, :] = 0 # Zero out the blue channel
-            return img
-
     plate_to_filenames = defaultdict(list)
     plate_to_labels = defaultdict(list)
     train_loaders = []
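Note: the inline SelectChannels class is deleted here in favor of the copy imported from .utils earlier in the file. Per the removed body, it zeroes out any RGB channel not in the requested list. A self-contained sketch of the same transform in a torchvision pipeline (the class reproduces the deleted code; the Compose usage is hypothetical):

    import torchvision.transforms as transforms

    class SelectChannels:
        """Zero out RGB channels not listed in `channels` (1=R, 2=G, 3=B)."""
        def __init__(self, channels):
            self.channels = channels

        def __call__(self, img):
            img = img.clone()
            if 1 not in self.channels:
                img[0, :, :] = 0  # zero the red channel
            if 2 not in self.channels:
                img[1, :, :] = 0  # zero the green channel
            if 3 not in self.channels:
                img[2, :, :] = 0  # zero the blue channel
            return img

    transform = transforms.Compose([
        transforms.ToTensor(),            # HWC image -> CHW float tensor
        SelectChannels(channels=[1, 3]),  # keep red and blue, blank green
        transforms.CenterCrop(size=(224, 224)),
    ])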
@@ -1989,41 +1637,225 @@ def analyze_recruitment(src, metadata_settings, advanced_settings):
     cells,wells = _results_to_csv(src, df, df_well)
     return [cells,wells]
 
+def _merge_cells_based_on_parasite_overlap(parasite_mask, cell_mask, nuclei_mask, overlap_threshold=5, perimeter_threshold=30):
+    """
+    Merge cells in cell_mask if a parasite in parasite_mask overlaps with more than one cell,
+    and if cells share more than a specified perimeter percentage.
+
+    Args:
+        parasite_mask (ndarray): Mask of parasites.
+        cell_mask (ndarray): Mask of cells.
+        nuclei_mask (ndarray): Mask of nuclei.
+        overlap_threshold (float): The percentage threshold for merging cells based on parasite overlap.
+        perimeter_threshold (float): The percentage threshold for merging cells based on shared perimeter.
+
+    Returns:
+        ndarray: The modified cell mask (cell_mask) with unique labels.
+    """
+    labeled_cells = label(cell_mask)
+    labeled_parasites = label(parasite_mask)
+    labeled_nuclei = label(nuclei_mask)
+    num_parasites = np.max(labeled_parasites)
+    num_cells = np.max(labeled_cells)
+    num_nuclei = np.max(labeled_nuclei)
+
+    # Merge cells based on parasite overlap
+    for parasite_id in range(1, num_parasites + 1):
+        current_parasite_mask = labeled_parasites == parasite_id
+        overlapping_cell_labels = np.unique(labeled_cells[current_parasite_mask])
+        overlapping_cell_labels = overlapping_cell_labels[overlapping_cell_labels != 0]
+        if len(overlapping_cell_labels) > 1:
+            # Calculate the overlap percentages
+            overlap_percentages = [
+                np.sum(current_parasite_mask & (labeled_cells == cell_label)) / np.sum(current_parasite_mask) * 100
+                for cell_label in overlapping_cell_labels
+            ]
+            # Merge cells if overlap percentage is above the threshold
+            for cell_label, overlap_percentage in zip(overlapping_cell_labels, overlap_percentages):
+                if overlap_percentage > overlap_threshold:
+                    first_label = overlapping_cell_labels[0]
+                    for other_label in overlapping_cell_labels[1:]:
+                        if other_label != first_label:
+                            cell_mask[cell_mask == other_label] = first_label
+
+    # Merge cells based on nucleus overlap
+    for nucleus_id in range(1, num_nuclei + 1):
+        current_nucleus_mask = labeled_nuclei == nucleus_id
+        overlapping_cell_labels = np.unique(labeled_cells[current_nucleus_mask])
+        overlapping_cell_labels = overlapping_cell_labels[overlapping_cell_labels != 0]
+        if len(overlapping_cell_labels) > 1:
+            # Calculate the overlap percentages
+            overlap_percentages = [
+                np.sum(current_nucleus_mask & (labeled_cells == cell_label)) / np.sum(current_nucleus_mask) * 100
+                for cell_label in overlapping_cell_labels
+            ]
+            # Merge cells if overlap percentage is above the threshold for each cell
+            if all(overlap_percentage > overlap_threshold for overlap_percentage in overlap_percentages):
+                first_label = overlapping_cell_labels[0]
+                for other_label in overlapping_cell_labels[1:]:
+                    if other_label != first_label:
+                        cell_mask[cell_mask == other_label] = first_label
+
+    # Check for cells without nuclei and merge based on shared perimeter
+    labeled_cells = label(cell_mask) # Re-label after merging based on overlap
+    cell_regions = regionprops(labeled_cells)
+    for region in cell_regions:
+        cell_label = region.label
+        cell_mask_binary = labeled_cells == cell_label
+        overlapping_nuclei = np.unique(nuclei_mask[cell_mask_binary])
+        overlapping_nuclei = overlapping_nuclei[overlapping_nuclei != 0]
+
+        if len(overlapping_nuclei) == 0:
+            # Cell does not overlap with any nucleus
+            perimeter = region.perimeter
+            # Dilate the cell to find neighbors
+            dilated_cell = binary_dilation(cell_mask_binary, structure=square(3))
+            neighbor_cells = np.unique(labeled_cells[dilated_cell])
+            neighbor_cells = neighbor_cells[(neighbor_cells != 0) & (neighbor_cells != cell_label)]
+            # Calculate shared border length with neighboring cells
+            shared_borders = [
+                np.sum((labeled_cells == neighbor_label) & dilated_cell) for neighbor_label in neighbor_cells
+            ]
+            shared_border_percentages = [shared_border / perimeter * 100 for shared_border in shared_borders]
+            # Merge with the neighbor cell with the largest shared border percentage above the threshold
+            if shared_borders:
+                max_shared_border_index = np.argmax(shared_border_percentages)
+                max_shared_border_percentage = shared_border_percentages[max_shared_border_index]
+                if max_shared_border_percentage > perimeter_threshold:
+                    cell_mask[labeled_cells == cell_label] = neighbor_cells[max_shared_border_index]
+
+    # Relabel the merged cell mask
+    relabeled_cell_mask, _ = label(cell_mask, return_num=True)
+    return relabeled_cell_mask
+
+def adjust_cell_masks(parasite_folder, cell_folder, nuclei_folder, overlap_threshold=5, perimeter_threshold=30):
+    """
+    Process all npy files in the given folders. Merge and relabel cells in cell masks
+    based on parasite overlap and cell perimeter sharing conditions.
+
+    Args:
+        parasite_folder (str): Path to the folder containing parasite masks.
+        cell_folder (str): Path to the folder containing cell masks.
+        nuclei_folder (str): Path to the folder containing nuclei masks.
+        overlap_threshold (float): The percentage threshold for merging cells based on parasite overlap.
+        perimeter_threshold (float): The percentage threshold for merging cells based on shared perimeter.
+    """
+
+    parasite_files = sorted([f for f in os.listdir(parasite_folder) if f.endswith('.npy')])
+    cell_files = sorted([f for f in os.listdir(cell_folder) if f.endswith('.npy')])
+    nuclei_files = sorted([f for f in os.listdir(nuclei_folder) if f.endswith('.npy')])
+
+    # Ensure there are matching files in all folders
+    if not (len(parasite_files) == len(cell_files) == len(nuclei_files)):
+        raise ValueError("The number of files in the folders do not match.")
+
+    # Match files by name
+    for file_name in parasite_files:
+        parasite_path = os.path.join(parasite_folder, file_name)
+        cell_path = os.path.join(cell_folder, file_name)
+        nuclei_path = os.path.join(nuclei_folder, file_name)
+        # Check if the corresponding cell and nuclei mask files exist
+        if not (os.path.exists(cell_path) and os.path.exists(nuclei_path)):
+            raise ValueError(f"Corresponding cell or nuclei mask file for {file_name} not found.")
+        # Load the masks
+        parasite_mask = np.load(parasite_path)
+        cell_mask = np.load(cell_path)
+        nuclei_mask = np.load(nuclei_path)
+        # Merge and relabel cells
+        merged_cell_mask = _merge_cells_based_on_parasite_overlap(parasite_mask, cell_mask, nuclei_mask, overlap_threshold, perimeter_threshold)
+        # Overwrite the original cell mask file with the merged result
+        np.save(cell_path, merged_cell_mask)
+
+def process_masks(mask_folder, image_folder, channel, batch_size=50, n_clusters=2, plot=False):
+
+    def read_files_in_batches(folder, batch_size=50):
+        files = [f for f in os.listdir(folder) if f.endswith('.npy')]
+        files.sort() # Sort to ensure matching order
+        for i in range(0, len(files), batch_size):
+            yield files[i:i + batch_size]
+
+    def measure_morphology_and_intensity(mask, image):
+        properties = measure.regionprops(mask, intensity_image=image)
+        properties_list = [{'area': p.area, 'mean_intensity': p.mean_intensity, 'perimeter': p.perimeter, 'eccentricity': p.eccentricity} for p in properties]
+        return properties_list
+
+    def cluster_objects(properties, n_clusters=2):
+        data = np.array([[p['area'], p['mean_intensity'], p['perimeter'], p['eccentricity']] for p in properties])
+        kmeans = KMeans(n_clusters=n_clusters, random_state=0).fit(data)
+        return kmeans
+
+    def remove_objects_not_in_largest_cluster(mask, labels, largest_cluster_label):
+        cleaned_mask = np.zeros_like(mask)
+        for region in measure.regionprops(mask):
+            if labels[region.label - 1] == largest_cluster_label:
+                cleaned_mask[mask == region.label] = region.label
+        return cleaned_mask
+
+    def plot_clusters(properties, labels):
+        data = np.array([[p['area'], p['mean_intensity'], p['perimeter'], p['eccentricity']] for p in properties])
+        pca = PCA(n_components=2)
+        data_2d = pca.fit_transform(data)
+        plt.scatter(data_2d[:, 0], data_2d[:, 1], c=labels, cmap='viridis')
+        plt.xlabel('PCA Component 1')
+        plt.ylabel('PCA Component 2')
+        plt.title('Object Clustering')
+        plt.show()
+
+    all_properties = []
+
+    # Step 1: Accumulate properties over all files
+    for batch in read_files_in_batches(mask_folder, batch_size):
+        mask_files = [os.path.join(mask_folder, file) for file in batch]
+        image_files = [os.path.join(image_folder, file) for file in batch]
+
+        masks = [np.load(file) for file in mask_files]
+        images = [np.load(file)[:, :, channel] for file in image_files]
+
+        for i, mask in enumerate(masks):
+            image = images[i]
+            # Measure morphology and intensity
+            properties = measure_morphology_and_intensity(mask, image)
+            all_properties.extend(properties)
+
+    # Step 2: Perform clustering on accumulated properties
+    kmeans = cluster_objects(all_properties, n_clusters)
+    labels = kmeans.labels_
+
+    if plot:
+        # Step 3: Plot clusters using PCA
+        plot_clusters(all_properties, labels)
+
+    # Step 4: Remove objects not in the largest cluster and overwrite files in batches
+    label_index = 0
+    for batch in read_files_in_batches(mask_folder, batch_size):
+        mask_files = [os.path.join(mask_folder, file) for file in batch]
+        masks = [np.load(file) for file in mask_files]
+
+        for i, mask in enumerate(masks):
+            batch_properties = measure_morphology_and_intensity(mask, mask)
+            batch_labels = labels[label_index:label_index + len(batch_properties)]
+            largest_cluster_label = np.bincount(batch_labels).argmax()
+            cleaned_mask = remove_objects_not_in_largest_cluster(mask, batch_labels, largest_cluster_label)
+            np.save(mask_files[i], cleaned_mask)
+            label_index += len(batch_properties)
+
 def preprocess_generate_masks(src, settings={}):
 
     from .io import preprocess_img_data, _load_and_concatenate_arrays
     from .plot import plot_merged, plot_arrays
-    from .utils import _pivot_counts_table
-
-    settings
-    settings['fps'] = 2
-    settings['remove_background'] = True
-    settings['lower_quantile'] = 0.02
-    settings['merge'] = False
-    settings['normalize_plots'] = True
-    settings['all_to_mip'] = False
-    settings['pick_slice'] = False
-    settings['skip_mode'] = src
-    settings['workers'] = os.cpu_count()-4
-    settings['verbose'] = True
-    settings['examples_to_plot'] = 1
-    settings['src'] = src
-    settings['upscale'] = False
-    settings['upscale_factor'] = 2.0
-
-    settings['randomize'] = True
-    settings['timelapse'] = False
-    settings['timelapse_displacement'] = None
-    settings['timelapse_memory'] = 3
-    settings['timelapse_frame_limits'] = None
-    settings['timelapse_remove_transient'] = False
-    settings['timelapse_mode'] = 'trackpy'
-    settings['timelapse_objects'] = ['cells']
+    from .utils import _pivot_counts_table, set_default_settings_preprocess_generate_masks, set_default_plot_merge_settings, check_mask_folder
+
+    settings = set_default_settings_preprocess_generate_masks(src, settings)
 
     settings_df = pd.DataFrame(list(settings.items()), columns=['Key', 'Value'])
     settings_csv = os.path.join(src,'settings','preprocess_generate_masks_settings.csv')
     os.makedirs(os.path.join(src,'settings'), exist_ok=True)
     settings_df.to_csv(settings_csv, index=False)
+
+    if not settings['pathogen_channel'] is None:
+        custom_model_ls = ['toxo_pv_lumen','toxo_cyto']
+        if settings['pathogen_model'] not in custom_model_ls:
+            ValueError(f'Pathogen model must be {custom_model_ls} or None')
 
     if settings['timelapse']:
         settings['randomize'] = False
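Note: the heart of the new _merge_cells_based_on_parasite_overlap is the per-label overlap percentage: the share of a parasite's pixels that fall inside each candidate cell. A toy check of that computation (4x4 masks, default thresholds):

    import numpy as np

    parasite = np.array([[0, 1, 1, 0]] * 4)  # one parasite straddling two cells
    cells = np.array([[1, 1, 2, 2]] * 4)     # two adjacent cells

    p = parasite == 1
    for cell_label in np.unique(cells[p]):
        overlap = np.sum(p & (cells == cell_label)) / np.sum(p) * 100
        print(cell_label, overlap)  # 1 -> 50.0, 2 -> 50.0; both exceed the 5% default, so the cells merge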
@@ -2032,24 +1864,50 @@ def preprocess_generate_masks(src, settings={}):
     if not settings['masks']:
         print(f'WARNING: channels for mask generation are defined when preprocess = True')
 
-    if isinstance(settings['merge'], bool):
-        settings['merge'] = [settings['merge']]*3
     if isinstance(settings['save'], bool):
         settings['save'] = [settings['save']]*3
 
+    if settings['verbose']:
+        settings_df = pd.DataFrame(list(settings.items()), columns=['setting_key', 'setting_value'])
+        settings_df['setting_value'] = settings_df['setting_value'].apply(str)
+        display(settings_df)
+
+    if settings['test_mode']:
+        print(f'Starting Test mode ...')
+
     if settings['preprocess']:
         settings, src = preprocess_img_data(settings)
 
     if settings['masks']:
         mask_src = os.path.join(src, 'norm_channel_stack')
         if settings['cell_channel'] != None:
-
+            if check_mask_folder(src, 'cell_mask_stack'):
+                generate_cellpose_masks(mask_src, settings, 'cell')
 
         if settings['nucleus_channel'] != None:
-
+            if check_mask_folder(src, 'nucleus_mask_stack'):
+                generate_cellpose_masks(mask_src, settings, 'nucleus')
 
         if settings['pathogen_channel'] != None:
-
+            if check_mask_folder(src, 'pathogen_mask_stack'):
+                generate_cellpose_masks(mask_src, settings, 'pathogen')
+
+        if settings['adjust_cells']:
+            if settings['pathogen_channel'] != None and settings['cell_channel'] != None and settings['nucleus_channel'] != None:
+
+                start = time.time()
+                cell_folder = os.path.join(mask_src, 'cell_mask_stack')
+                nuclei_folder = os.path.join(mask_src, 'nucleus_mask_stack')
+                parasite_folder = os.path.join(mask_src, 'pathogen_mask_stack')
+                #image_folder = os.path.join(src, 'stack')
+
+                #process_masks(cell_folder, image_folder, settings['cell_channel'], settings['batch_size'], n_clusters=2, plot=settings['plot'])
+                #process_masks(nuclei_folder, image_folder, settings['nucleus_channel'], settings['batch_size'], n_clusters=2, plot=settings['plot'])
+                #process_masks(parasite_folder, image_folder, settings['pathogen_channel'], settings['batch_size'], n_clusters=2, plot=settings['plot'])
+
+                adjust_cell_masks(parasite_folder, cell_folder, nuclei_folder, overlap_threshold=5, perimeter_threshold=30)
+                stop = time.time()
+                print(f'Cell mask adjustment: {stop-start} seconds')
 
     if os.path.exists(os.path.join(src,'measurements')):
         _pivot_counts_table(db_path=os.path.join(src,'measurements', 'measurements.db'))
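Note: mask generation is now gated on check_mask_folder, which makes preprocess_generate_masks resumable across runs. The helper lives in spacr.utils and its body is not part of this diff; the sketch below is only an assumed reading of its contract (True means masks still need to be generated):

    import os

    def check_mask_folder(src, mask_folder_name):
        # Assumed semantics, not the actual spacr.utils implementation:
        # regenerate only when the mask folder does not exist yet.
        return not os.path.isdir(os.path.join(src, 'norm_channel_stack', mask_folder_name))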
@@ -2078,28 +1936,14 @@ def preprocess_generate_masks(src, settings={}):
         overlay_channels = [settings['nucleus_channel'], settings['pathogen_channel'], settings['cell_channel']]
         overlay_channels = [element for element in overlay_channels if element is not None]
 
-        plot_settings =
-
-
-
-
-
-
-
-                         'nucleus_mask_dim':nucleus_mask_dim,
-                         'pathogen_mask_dim':pathogen_mask_dim,
-                         'outline_thickness':3,
-                         'outline_color':'gbr',
-                         'overlay_chans':overlay_channels,
-                         'overlay':True,
-                         'normalization_percentiles':[1,99],
-                         'normalize':True,
-                         'print_object_number':True,
-                         'nr':settings['examples_to_plot'],
-                         'figuresize':20,
-                         'cmap':'inferno',
-                         'verbose':False}
-
+        plot_settings = set_default_plot_merge_settings()
+        plot_settings['channel_dims'] = settings['channels']
+        plot_settings['cell_mask_dim'] = cell_mask_dim
+        plot_settings['nucleus_mask_dim'] = nucleus_mask_dim
+        plot_settings['pathogen_mask_dim'] = pathogen_mask_dim
+        plot_settings['overlay_chans'] = overlay_channels
+        plot_settings['nr'] = settings['examples_to_plot']
+
         if settings['test_mode'] == True:
             plot_settings['nr'] = len(os.path.join(src,'merged'))
 
@@ -2108,7 +1952,7 @@ def preprocess_generate_masks(src, settings={}):
         except Exception as e:
             print(f'Failed to plot image mask overly. Error: {e}')
     else:
-        plot_arrays(src=os.path.join(src,'merged'), figuresize=
+        plot_arrays(src=os.path.join(src,'merged'), figuresize=settings['figuresize'], cmap=settings['cmap'], nr=settings['examples_to_plot'], normalize=settings['normalize'], q1=1, q2=99)
 
     torch.cuda.empty_cache()
     gc.collect()
@@ -2121,36 +1965,62 @@ def identify_masks_finetune(settings):
     from .utils import get_files_from_dir, resize_images_and_labels
     from .io import _load_normalized_images_and_labels, _load_images_and_labels
 
+    #User defined settings
     src=settings['src']
     dst=settings['dst']
+
+
+    settings.setdefault('model_name', 'cyto')
+    settings.setdefault('custom_model', None)
+    settings.setdefault('channels', [0,0])
+    settings.setdefault('background', 100)
+    settings.setdefault('remove_background', False)
+    settings.setdefault('Signal_to_noise', 10)
+    settings.setdefault('CP_prob', 0)
+    settings.setdefault('diameter', 30)
+    settings.setdefault('batch_size', 50)
+    settings.setdefault('flow_threshold', 0.4)
+    settings.setdefault('save', False)
+    settings.setdefault('verbose', False)
+    settings.setdefault('normalize', True)
+    settings.setdefault('percentiles', None)
+    settings.setdefault('circular', False)
+    settings.setdefault('invert', False)
+    settings.setdefault('resize', False)
+    settings.setdefault('target_height', None)
+    settings.setdefault('target_width', None)
+    settings.setdefault('rescale', False)
+    settings.setdefault('resample', False)
+    settings.setdefault('grayscale', True)
+
+
     model_name=settings['model_name']
+    custom_model=settings['custom_model']
+    channels = settings['channels']
+    background = settings['background']
+    remove_background=settings['remove_background']
+    Signal_to_noise = settings['Signal_to_noise']
+    CP_prob = settings['CP_prob']
     diameter=settings['diameter']
     batch_size=settings['batch_size']
     flow_threshold=settings['flow_threshold']
-    cellprob_threshold=settings['cellprob_threshold']
-
-    verbose=settings['verbose']
-    plot=settings['plot']
     save=settings['save']
-
-    overlay=settings['overlay']
+    verbose=settings['verbose']
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-    grayscale = True
-    test = False
+    # static settings
+    normalize = settings['normalize']
+    percentiles = settings['percentiles']
+    circular = settings['circular']
+    invert = settings['invert']
+    resize = settings['resize']
+
+    if resize:
+        target_height = settings['target_height']
+        target_width = settings['target_width']
+
+    rescale = settings['rescale']
+    resample = settings['resample']
+    grayscale = settings['grayscale']
 
     os.makedirs(dst, exist_ok=True)
 
@@ -2179,7 +2049,7 @@ def identify_masks_finetune(settings):
     print(f'Using channels: {chans} for model of type {model_name}')
 
    if verbose == True:
-        print(f'Cellpose settings: Model: {model_name}, channels: {channels}, cellpose_chans: {chans}, diameter:{diameter}, flow_threshold:{flow_threshold}, cellprob_threshold:{
+        print(f'Cellpose settings: Model: {model_name}, channels: {channels}, cellpose_chans: {chans}, diameter:{diameter}, flow_threshold:{flow_threshold}, cellprob_threshold:{CP_prob}')
 
     all_image_files = [os.path.join(src, f) for f in os.listdir(src) if f.endswith('.tif')]
 
@@ -2188,9 +2058,9 @@ def identify_masks_finetune(settings):
     time_ls = []
     for i in range(0, len(all_image_files), batch_size):
         image_files = all_image_files[i:i+batch_size]
-
+
         if normalize:
-            images, _, image_names, _ = _load_normalized_images_and_labels(image_files=image_files, label_files=None,
+            images, _, image_names, _ = _load_normalized_images_and_labels(image_files=image_files, label_files=None, channels=channels, percentiles=percentiles, circular=circular, invert=invert, visualize=verbose, remove_background=remove_background, background=background, Signal_to_noise=Signal_to_noise)
             images = [np.squeeze(img) if img.shape[-1] == 1 else img for img in images]
             orig_dims = [(image.shape[0], image.shape[1]) for image in images]
         else:
@@ -2208,7 +2078,7 @@ def identify_masks_finetune(settings):
                              channel_axis=3,
                              diameter=diameter,
                              flow_threshold=flow_threshold,
-                             cellprob_threshold=
+                             cellprob_threshold=CP_prob,
                              rescale=rescale,
                              resample=resample,
                              progress=True)
@@ -2229,11 +2099,12 @@ def identify_masks_finetune(settings):
             time_ls.append(duration)
             average_time = np.mean(time_ls) if len(time_ls) > 0 else 0
             print(f'Processing {file_index+1}/{len(images)} images : Time/image {average_time:.3f} sec', end='\r', flush=True)
-            if
+            if verbose:
                 if resize:
                     stack = resizescikit(stack, dims, preserve_range=True, anti_aliasing=False).astype(stack.dtype)
-                print_mask_and_flows(stack, mask, flows, overlay=
+                print_mask_and_flows(stack, mask, flows, overlay=True)
             if save:
+                os.makedirs(dst, exist_ok=True)
                 output_filename = os.path.join(dst, image_names[file_index])
                 cv2.imwrite(output_filename, mask)
     return
@@ -2375,8 +2246,6 @@ def identify_masks(src, object_type, model_name, batch_size, channels, diameter,
     stitch_threshold=0.0
 
     cellpose_batch_size = _get_cellpose_batch_size()
-
-    #model = cellpose.denoise.DenoiseModel(model_type=f"denoise_{model_name}", gpu=True)
 
     masks, flows, _, _ = model.eval(x=batch,
                                     batch_size=cellpose_batch_size,
@@ -2450,9 +2319,21 @@ def all_elements_match(list1, list2):
# Check if all elements in list1 are in list2
return all(element in list2 for element in list1)

+def prepare_batch_for_cellpose(batch):
+# Ensure the batch is of dtype float32
+if batch.dtype != np.float32:
+batch = batch.astype(np.float32)
+
+# Normalize each image in the batch
+for i in range(batch.shape[0]):
+if batch[i].max() > 1:
+batch[i] = batch[i] / batch[i].max()
+
+return batch
+
def generate_cellpose_masks(src, settings, object_type):

-from .utils import _masks_to_masks_stack, _filter_cp_masks, _get_cellpose_batch_size, _get_object_settings, _get_cellpose_channels, _choose_model, mask_object_count
+from .utils import _masks_to_masks_stack, _filter_cp_masks, _get_cellpose_batch_size, _get_object_settings, _get_cellpose_channels, _choose_model, mask_object_count, set_default_settings_preprocess_generate_masks
from .io import _create_database, _save_object_counts_to_database, _check_masks, _get_avg_object_size
from .timelapse import _npz_to_movie, _btrack_track_cells, _trackpy_track_cells
from .plot import plot_masks
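The new `prepare_batch_for_cellpose` helper casts a batch to float32 and rescales each image by its own maximum, which matters because the `model.eval` calls below pass `normalize=False`. A minimal, self-contained sketch of the same behaviour on synthetic data (the 16-bit batch below is hypothetical, not part of spacr):

```python
import numpy as np

def prepare_batch_for_cellpose(batch):
    # Cast once so the per-image division below produces floats
    if batch.dtype != np.float32:
        batch = batch.astype(np.float32)
    # Rescale each image independently to [0, 1] by its own maximum
    for i in range(batch.shape[0]):
        if batch[i].max() > 1:
            batch[i] = batch[i] / batch[i].max()
    return batch

# Hypothetical 16-bit batch: four 64x64 single-channel images
batch = np.random.randint(0, 65535, size=(4, 64, 64, 1), dtype=np.uint16)
prepared = prepare_batch_for_cellpose(batch)
assert prepared.dtype == np.float32 and prepared.max() <= 1.0
```

Per-image scaling keeps a dim image from being washed out by a bright neighbour in the same batch, which a single per-batch maximum would cause.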
@@ -2460,6 +2341,13 @@ def generate_cellpose_masks(src, settings, object_type):
gc.collect()
if not torch.cuda.is_available():
print(f'Torch CUDA is not available, using CPU')
+
+settings = set_default_settings_preprocess_generate_masks(src, settings)
+
+if settings['verbose']:
+settings_df = pd.DataFrame(list(settings.items()), columns=['setting_key', 'setting_value'])
+settings_df['setting_value'] = settings_df['setting_value'].apply(str)
+display(settings_df)

figuresize=25
timelapse = settings['timelapse']
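`generate_cellpose_masks` now routes its configuration through `set_default_settings_preprocess_generate_masks` and, when verbose, renders the merged settings as a two-column DataFrame. That helper lives in `spacr.utils` and its actual keys and defaults are not part of this diff; the sketch below only illustrates the merge-defaults-then-display pattern with made-up keys:

```python
import pandas as pd
from IPython.display import display

def set_defaults(settings):
    # Illustrative stand-in for spacr.utils.set_default_settings_preprocess_generate_masks;
    # the real helper's keys and default values are not shown in this diff.
    settings.setdefault('verbose', True)
    settings.setdefault('timelapse', False)
    settings.setdefault('batch_size', 10)
    return settings

settings = set_defaults({'batch_size': 8})
if settings['verbose']:
    settings_df = pd.DataFrame(list(settings.items()), columns=['setting_key', 'setting_value'])
    settings_df['setting_value'] = settings_df['setting_value'].apply(str)  # stringify lists/None
    display(settings_df)
```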
@@ -2474,8 +2362,9 @@ def generate_cellpose_masks(src, settings, object_type):

batch_size = settings['batch_size']
cellprob_threshold = settings[f'{object_type}_CP_prob']
-
-
+
+flow_threshold = settings[f'{object_type}_FT']
+
object_settings = _get_object_settings(object_type, settings)
model_name = object_settings['model_name']

@@ -2486,7 +2375,12 @@ def generate_cellpose_masks(src, settings, object_type):
channels = cellpose_channels[object_type]
cellpose_batch_size = _get_cellpose_batch_size()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-
+
+if object_type == 'pathogen' and not settings['pathogen_model'] is None:
+model_name = settings['pathogen_model']
+
+model = _choose_model(model_name, device, object_type=object_type, restore_type=None, object_settings=object_settings)
+
chans = [2, 1] if model_name == 'cyto2' else [0,0] if model_name == 'nucleus' else [2,0] if model_name == 'cyto' else [2, 0] if model_name == 'cyto3' else [2, 0]
paths = [os.path.join(src, file) for file in os.listdir(src) if file.endswith('.npz')]

@@ -2505,6 +2399,14 @@ def generate_cellpose_masks(src, settings, object_type):
with np.load(path) as data:
stack = data['data']
filenames = data['filenames']
+
+for i, filename in enumerate(filenames):
+output_path = os.path.join(output_folder, filename)
+
+if os.path.exists(output_path):
+print(f"File {filename} already exists in the output folder. Skipping...")
+continue
+
if settings['timelapse']:

trackable_objects = ['cell','nucleus','pathogen']
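The loop added here makes mask generation resumable: any npz entry whose output file already exists is skipped rather than re-segmented. The pattern in isolation (`output_folder` and the filenames below are hypothetical):

```python
import os

output_folder = '/data/masks'  # hypothetical output directory
filenames = ['plate1_A01.tif', 'plate1_A02.tif']

for filename in filenames:
    output_path = os.path.join(output_folder, filename)
    if os.path.exists(output_path):
        print(f"File {filename} already exists in the output folder. Skipping...")
        continue
    # ...segment the image and write the mask to output_path...
```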
@@ -2539,31 +2441,43 @@ def generate_cellpose_masks(src, settings, object_type):
if batch.size == 0:
print(f'Processing {file_index}/{len(paths)}: Images/npz {batch.shape[0]}')
continue
-
-
+
+batch = prepare_batch_for_cellpose(batch)

if timelapse:
-stitch_threshold=100.0
movie_path = os.path.join(os.path.dirname(src), 'movies')
os.makedirs(movie_path, exist_ok=True)
save_path = os.path.join(movie_path, f'timelapse_{object_type}_{name}.mp4')
_npz_to_movie(batch, batch_filenames, save_path, fps=2)
-else:
-stitch_threshold=0.0
-
-print('batch.shape',batch.shape)
-masks, flows, _, _ = model.eval(x=batch,
-batch_size=cellpose_batch_size,
-normalize=False,
-channels=chans,
-channel_axis=3,
-diameter=object_settings['diameter'],
-flow_threshold=flow_threshold,
-cellprob_threshold=cellprob_threshold,
-rescale=None,
-resample=object_settings['resample'],
-stitch_threshold=stitch_threshold)

+if settings['verbose']:
+print(f'Processing {file_index}/{len(paths)}: Images/npz {batch.shape[0]}')
+
+#cellpose_normalize_dict = {'lowhigh':[0.0,1.0], #pass in normalization values for 0.0 and 1.0 as list [low, high] if None all other keys ignored
+# 'sharpen':object_settings['diameter']/4, #recommended to be 1/4-1/8 diameter of cells in pixels
+# 'normalize':True, #(if False, all following parameters ignored)
+# 'percentile':[2,98], #[perc_low, perc_high]
+# 'tile_norm':224, #normalize by tile set to e.g. 100 for normalize window to be 100 px
+# 'norm3D':True} #compute normalization across entire z-stack rather than plane-by-plane in stitching mode.
+
+output = model.eval(x=batch,
+batch_size=cellpose_batch_size,
+normalize=False,
+channels=chans,
+channel_axis=3,
+diameter=object_settings['diameter'],
+flow_threshold=flow_threshold,
+cellprob_threshold=cellprob_threshold,
+rescale=None,
+resample=object_settings['resample'])
+
+if len(output) == 4:
+masks, flows, _, _ = output
+elif len(output) == 3:
+masks, flows, _ = output
+else:
+raise ValueError(f"Unexpected number of return values from model.eval(). Expected 3 or 4, got {len(output)}")
+
if timelapse:
if settings['plot']:
for idx, (mask, flow, image) in enumerate(zip(masks, flows[0], batch)):
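The refactored call stores `model.eval(...)` in `output` and unpacks by length rather than assuming a 4-tuple, because Cellpose's `eval` returns either three values (`masks, flows, styles`, e.g. from `CellposeModel`) or four (`masks, flows, styles, diams`, e.g. from the `Cellpose` wrapper) depending on model class and version. A standalone sketch of the same guard:

```python
def unpack_eval(output):
    # Cellpose eval() returns (masks, flows, styles) or (masks, flows, styles, diams)
    if len(output) == 4:
        masks, flows, _, _ = output
    elif len(output) == 3:
        masks, flows, _ = output
    else:
        raise ValueError(f"Expected 3 or 4 return values, got {len(output)}")
    return masks, flows

masks, flows = unpack_eval(('masks', 'flows', 'styles'))        # 3-tuple case
masks, flows = unpack_eval(('masks', 'flows', 'styles', 30.0))  # 4-tuple case
```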
@@ -2676,15 +2590,15 @@ def generate_cellpose_masks(src, settings, object_type):
torch.cuda.empty_cache()
return

-def generate_masks_from_imgs(src, model, model_name, batch_size, diameter, cellprob_threshold, grayscale, save, normalize, channels, percentiles, circular, invert, plot, resize, target_height, target_width, verbose):
+def generate_masks_from_imgs(src, model, model_name, batch_size, diameter, cellprob_threshold, flow_threshold, grayscale, save, normalize, channels, percentiles, circular, invert, plot, resize, target_height, target_width, remove_background, background, Signal_to_noise, verbose):
+
from .io import _load_images_and_labels, _load_normalized_images_and_labels
from .utils import resize_images_and_labels, resizescikit
from .plot import print_mask_and_flows

dst = os.path.join(src, model_name)
os.makedirs(dst, exist_ok=True)
-
-flow_threshold = 30
+
chans = [2, 1] if model_name == 'cyto2' else [0,0] if model_name == 'nucleus' else [1,0] if model_name == 'cyto' else [2, 0]

if grayscale:
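Two things change in this hunk: `flow_threshold` is now threaded in as a parameter instead of being hardcoded to 30 (an unusually permissive value, given that Cellpose's documented default is 0.4), and the model-to-channel mapping remains a one-line conditional. An equivalent, easier-to-scan formulation of that conditional (illustrative only, not spacr's API):

```python
# Dictionary equivalent of the chans one-liner in this hunk
CHANNEL_PAIRS = {
    'cyto2': [2, 1],
    'nucleus': [0, 0],
    'cyto': [1, 0],
}

def pick_chans(model_name):
    return CHANNEL_PAIRS.get(model_name, [2, 0])  # final else-branch of the conditional

assert pick_chans('cyto2') == [2, 1]
assert pick_chans('cyto3') == [2, 0]
```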
@@ -2692,7 +2606,6 @@ def generate_masks_from_imgs(src, model, model_name, batch_size, diameter, cellp

all_image_files = [os.path.join(src, f) for f in os.listdir(src) if f.endswith('.tif')]
random.shuffle(all_image_files)
-

if verbose == True:
print(f'Cellpose settings: Model: {model_name}, channels: {channels}, cellpose_chans: {chans}, diameter:{diameter}, flow_threshold:{flow_threshold}, cellprob_threshold:{cellprob_threshold}')
@@ -2702,11 +2615,11 @@ def generate_masks_from_imgs(src, model, model_name, batch_size, diameter, cellp
image_files = all_image_files[i:i+batch_size]

if normalize:
-images, _, image_names, _ = _load_normalized_images_and_labels(image_files
+images, _, image_names, _ = _load_normalized_images_and_labels(image_files, None, channels, percentiles, circular, invert, plot, remove_background, background, Signal_to_noise)
images = [np.squeeze(img) if img.shape[-1] == 1 else img for img in images]
orig_dims = [(image.shape[0], image.shape[1]) for image in images]
else:
-images, _, image_names, _ = _load_images_and_labels(image_files
+images, _, image_names, _ = _load_images_and_labels(image_files, None, circular, invert)
images = [np.squeeze(img) if img.shape[-1] == 1 else img for img in images]
orig_dims = [(image.shape[0], image.shape[1]) for image in images]
if resize:
@@ -2723,7 +2636,7 @@ def generate_masks_from_imgs(src, model, model_name, batch_size, diameter, cellp
cellprob_threshold=cellprob_threshold,
rescale=False,
resample=False,
-progress=
+progress=False)

if len(output) == 4:
mask, flows, _, _ = output
@@ -2753,22 +2666,31 @@
def check_cellpose_models(settings):

src = settings['src']
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+settings.setdefault('batch_size', 10)
+settings.setdefault('CP_prob', 0)
+settings.setdefault('flow_threshold', 0.4)
+settings.setdefault('save', True)
+settings.setdefault('normalize', True)
+settings.setdefault('channels', [0,0])
+settings.setdefault('percentiles', None)
+settings.setdefault('circular', False)
+settings.setdefault('invert', False)
+settings.setdefault('plot', True)
+settings.setdefault('diameter', 40)
+settings.setdefault('grayscale', True)
+settings.setdefault('remove_background', False)
+settings.setdefault('background', 100)
+settings.setdefault('Signal_to_noise', 5)
+settings.setdefault('verbose', False)
+settings.setdefault('resize', False)
+settings.setdefault('target_height', None)
+settings.setdefault('target_width', None)
+
+if settings['verbose']:
+settings_df = pd.DataFrame(list(settings.items()), columns=['setting_key', 'setting_value'])
+settings_df['setting_value'] = settings_df['setting_value'].apply(str)
+display(settings_df)
+
cellpose_models = ['cyto', 'nuclei', 'cyto2', 'cyto3']
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

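With these `setdefault` calls, `check_cellpose_models` only strictly requires `src`; every other key falls back to the defaults above. A hedged usage sketch (the path is hypothetical):

```python
from spacr.core import check_cellpose_models

settings = {
    'src': '/data/tif_images',  # hypothetical folder of .tif images
    'diameter': 60,             # override one default; the rest come from setdefault
    'verbose': True,
}
# Runs 'cyto', 'nuclei', 'cyto2' and 'cyto3' on the same images; masks are written
# to one subfolder per model name under src (dst = os.path.join(src, model_name)).
check_cellpose_models(settings)
```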
@@ -2776,149 +2698,22 @@ def check_cellpose_models(settings):

model = cp_models.CellposeModel(gpu=True, model_type=model_name, device=device)
print(f'Using {model_name}')
-generate_masks_from_imgs(src, model, model_name, batch_size, diameter,
-
-return
-
-def compare_masks_v1(dir1, dir2, dir3, verbose=False):
-
-from .io import _read_mask
-from .plot import visualize_masks, plot_comparison_results
-from .utils import extract_boundaries, boundary_f1_score, compute_segmentation_ap, jaccard_index, dice_coefficient
-
-filenames = os.listdir(dir1)
-results = []
-cond_1 = os.path.basename(dir1)
-cond_2 = os.path.basename(dir2)
-cond_3 = os.path.basename(dir3)
-
-for index, filename in enumerate(filenames):
-print(f'Processing image:{index+1}', end='\r', flush=True)
-path1, path2, path3 = os.path.join(dir1, filename), os.path.join(dir2, filename), os.path.join(dir3, filename)
-
-print(path1)
-print(path2)
-print(path3)
-
-if os.path.exists(path2) and os.path.exists(path3):
-
-mask1, mask2, mask3 = _read_mask(path1), _read_mask(path2), _read_mask(path3)
-boundary_true1, boundary_true2, boundary_true3 = extract_boundaries(mask1), extract_boundaries(mask2), extract_boundaries(mask3)
-
-
-true_masks, pred_masks = [mask1], [mask2, mask3] # Assuming mask1 is the ground truth for simplicity
-true_labels, pred_labels_1, pred_labels_2 = label(mask1), label(mask2), label(mask3)
-average_precision_0, average_precision_1 = compute_segmentation_ap(mask1, mask2), compute_segmentation_ap(mask1, mask3)
-ap_scores = [average_precision_0, average_precision_1]
-
-if verbose:
-#unique_values1, unique_values2, unique_values3 = np.unique(mask1), np.unique(mask2), np.unique(mask3)
-#print(f"Unique values in mask 1: {unique_values1}, mask 2: {unique_values2}, mask 3: {unique_values3}")
-visualize_masks(boundary_true1, boundary_true2, boundary_true3, title=f"Boundaries - {filename}")
-
-boundary_f1_12, boundary_f1_13, boundary_f1_23 = boundary_f1_score(mask1, mask2), boundary_f1_score(mask1, mask3), boundary_f1_score(mask2, mask3)
-
-if (np.unique(mask1).size == 1 and np.unique(mask1)[0] == 0) and \
-(np.unique(mask2).size == 1 and np.unique(mask2)[0] == 0) and \
-(np.unique(mask3).size == 1 and np.unique(mask3)[0] == 0):
-continue
-
-if verbose:
-#unique_values4, unique_values5, unique_values6 = np.unique(boundary_f1_12), np.unique(boundary_f1_13), np.unique(boundary_f1_23)
-#print(f"Unique values in boundary mask 1: {unique_values4}, mask 2: {unique_values5}, mask 3: {unique_values6}")
-visualize_masks(mask1, mask2, mask3, title=filename)
-
-jaccard12 = jaccard_index(mask1, mask2)
-dice12 = dice_coefficient(mask1, mask2)
-
-jaccard13 = jaccard_index(mask1, mask3)
-dice13 = dice_coefficient(mask1, mask3)
-
-jaccard23 = jaccard_index(mask2, mask3)
-dice23 = dice_coefficient(mask2, mask3)
-
-results.append({
-f'filename': filename,
-f'jaccard_{cond_1}_{cond_2}': jaccard12,
-f'dice_{cond_1}_{cond_2}': dice12,
-f'jaccard_{cond_1}_{cond_3}': jaccard13,
-f'dice_{cond_1}_{cond_3}': dice13,
-f'jaccard_{cond_2}_{cond_3}': jaccard23,
-f'dice_{cond_2}_{cond_3}': dice23,
-f'boundary_f1_{cond_1}_{cond_2}': boundary_f1_12,
-f'boundary_f1_{cond_1}_{cond_3}': boundary_f1_13,
-f'boundary_f1_{cond_2}_{cond_3}': boundary_f1_23,
-f'average_precision_{cond_1}_{cond_2}': ap_scores[0],
-f'average_precision_{cond_1}_{cond_3}': ap_scores[1]
-})
-else:
-print(f'Cannot find {path1} or {path2} or {path3}')
-fig = plot_comparison_results(results)
-return results, fig
-
-def compare_cellpose_masks_v1(src, verbose=False):
-from .io import _read_mask
-from .plot import visualize_masks, plot_comparison_results, visualize_cellpose_masks
-from .utils import extract_boundaries, boundary_f1_score, compute_segmentation_ap, jaccard_index
-
-import os
-import numpy as np
-from skimage.measure import label
-
-# Collect all subdirectories in src
-dirs = [os.path.join(src, d) for d in os.listdir(src) if os.path.isdir(os.path.join(src, d))]
-
-dirs.sort() # Optional: sort directories if needed
-
-# Get common files in all directories
-common_files = set(os.listdir(dirs[0]))
-for d in dirs[1:]:
-common_files.intersection_update(os.listdir(d))
-common_files = list(common_files)
-
-results = []
-conditions = [os.path.basename(d) for d in dirs]
-
-for index, filename in enumerate(common_files):
-print(f'Processing image {index+1}/{len(common_files)}', end='\r', flush=True)
-paths = [os.path.join(d, filename) for d in dirs]
+generate_masks_from_imgs(src, model, model_name, settings['batch_size'], settings['diameter'], settings['CP_prob'], settings['flow_threshold'], settings['grayscale'], settings['save'], settings['normalize'], settings['channels'], settings['percentiles'], settings['circular'], settings['invert'], settings['plot'], settings['resize'], settings['target_height'], settings['target_width'], settings['remove_background'], settings['background'], settings['Signal_to_noise'], settings['verbose'])

-
-if not all(os.path.exists(path) for path in paths):
-print(f'Skipping {filename} as it is not present in all directories.')
-continue
-
-masks = [_read_mask(path) for path in paths]
-boundaries = [extract_boundaries(mask) for mask in masks]
-
-if verbose:
-visualize_cellpose_masks(masks, titles=conditions, comparison_title=f"Masks Comparison for {filename}")
-
-# Initialize data structure for results
-file_results = {'filename': filename}
-
-# Compare each mask with each other
-for i in range(len(masks)):
-for j in range(i + 1, len(masks)):
-condition_i = conditions[i]
-condition_j = conditions[j]
-mask_i = masks[i]
-mask_j = masks[j]
-
-# Compute metrics
-boundary_f1 = boundary_f1_score(mask_i, mask_j)
-jaccard = jaccard_index(mask_i, mask_j)
-average_precision = compute_segmentation_ap(mask_i, mask_j)
+return

-
-file_results[f'jaccard_{condition_i}_{condition_j}'] = jaccard
-file_results[f'boundary_f1_{condition_i}_{condition_j}'] = boundary_f1
-file_results[f'average_precision_{condition_i}_{condition_j}'] = average_precision
+def save_results_and_figure(src, fig, results):

-
+if not isinstance(results, pd.DataFrame):
+results = pd.DataFrame(results)

-
-
+results_dir = os.path.join(src, 'results')
+os.makedirs(results_dir, exist_ok=True)
+results_path = os.path.join(results_dir,f'results.csv')
+fig_path = os.path.join(results_dir, f'model_comparison_plot.pdf')
+results.to_csv(results_path, index=False)
+fig.savefig(fig_path, format='pdf')
+print(f'Saved figure to {fig_path} and results to {results_path}')

def compare_mask(args):
src, filename, dirs, conditions = args
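This hunk removes the superseded `compare_masks_v1` and `compare_cellpose_masks_v1` functions and adds `save_results_and_figure`, which coerces a list of per-file metric dicts into a DataFrame and writes both a CSV and a PDF under `src/results`. A usage sketch (metrics and paths hypothetical):

```python
import matplotlib.pyplot as plt
from spacr.core import save_results_and_figure

results = [{'filename': 'img1.tif', 'jaccard_cyto_cyto2': 0.91}]  # hypothetical metrics
fig, ax = plt.subplots()
ax.bar(['jaccard_cyto_cyto2'], [0.91])
save_results_and_figure('/data/mask_comparison', fig, results)  # hypothetical src
# -> /data/mask_comparison/results/results.csv and .../model_comparison_plot.pdf
```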
@@ -2949,10 +2744,11 @@ def compare_mask(args):

return file_results

-def compare_cellpose_masks(src, verbose=False, processes=None):
+def compare_cellpose_masks(src, verbose=False, processes=None, save=True):
from .plot import visualize_cellpose_masks, plot_comparison_results
from .io import _read_mask
-
+
+dirs = [os.path.join(src, d) for d in os.listdir(src) if os.path.isdir(os.path.join(src, d)) and d != 'results']
dirs.sort() # Optional: sort directories if needed
conditions = [os.path.basename(d) for d in dirs]

@@ -2969,16 +2765,16 @@ def compare_cellpose_masks(src, verbose=False, processes=None):

# Filter out None results (from skipped files)
results = [res for res in results if res is not None]
-
+#print(results)
if verbose:
for result in results:
filename = result['filename']
masks = [_read_mask(os.path.join(d, filename)) for d in dirs]
-visualize_cellpose_masks(masks, titles=conditions,
+visualize_cellpose_masks(masks, titles=conditions, filename=filename, save=save, src=src)

fig = plot_comparison_results(results)
-
-
+save_results_and_figure(src, fig, results)
+return

def _calculate_similarity(df, features, col_to_compare, val1, val2):
"""
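`compare_cellpose_masks` now skips the `results` directory when collecting condition folders, so a re-run does not treat its own output as a condition, and it persists its figure and metrics through the new `save_results_and_figure`. A usage sketch (the directory layout is hypothetical):

```python
from spacr.core import compare_cellpose_masks

# src holds one subfolder of masks per condition, e.g. cyto/, cyto2/, cyto3/;
# the auto-created results/ subfolder is excluded from the comparison.
compare_cellpose_masks('/data/mask_comparison', verbose=False, processes=4, save=True)
```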
@@ -3060,6 +2856,8 @@ def _permutation_importance(df, feature_string='channel_3', col_to_compare='col'
pandas.DataFrame: DataFrame containing the importances and standard deviations.
"""

+from .utils import filter_dataframe_features
+
if 'cells_per_well' in df.columns:
df = df.drop(columns=['cells_per_well'])

@@ -3074,33 +2872,12 @@ def _permutation_importance(df, feature_string='channel_3', col_to_compare='col'
# Combine the subsets for analysis
combined_df = pd.concat([df1, df2])

-
-
-
-
-if clean:
-combined_df = combined_df.loc[:, combined_df.nunique() > 1]
-features = [feature for feature in features if feature in combined_df.columns]
-
-if feature_string is not None:
-feature_list = ['channel_0', 'channel_1', 'channel_2', 'channel_3']
-
-# Remove feature_string from the list if it exists
-if feature_string in feature_list:
-feature_list.remove(feature_string)
-
-features = [feature for feature in features if feature_string in feature]
+if feature_string in ['channel_0', 'channel_1', 'channel_2', 'channel_3']:
+channel_of_interest = int(feature_string.split('_')[-1])
+elif not feature_string is 'morphology':
+channel_of_interest = 'morphology'

-
-for feature_ in feature_list:
-features = [feature for feature in features if feature_ not in feature]
-print(f'After removing {feature_} features: {len(features)}')
-
-if exclude:
-if isinstance(exclude, list):
-features = [feature for feature in features if feature not in exclude]
-else:
-features.remove(exclude)
+_, features = filter_dataframe_features(combined_df, channel_of_interest, exclude)

X = combined_df[features]
y = combined_df['target']
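The refactor collapses roughly twenty lines of in-function feature filtering into one call to `spacr.utils.filter_dataframe_features`, whose implementation is not part of this diff. (Note also that `not feature_string is 'morphology'` tests object identity rather than equality; `feature_string != 'morphology'` would be the conventional spelling.) An illustrative stand-in for the channel selection the old code performed:

```python
# Illustrative stand-in only; the real filter_dataframe_features lives in spacr.utils
# and its actual behaviour and signature are not shown in this diff.
def filter_channel_features(df, channel_of_interest, exclude=None):
    if channel_of_interest == 'morphology':
        features = [c for c in df.columns
                    if not any(f'channel_{i}' in c for i in range(4))]
    else:
        features = [c for c in df.columns if f'channel_{channel_of_interest}' in c]
    if exclude:
        exclude = exclude if isinstance(exclude, list) else [exclude]
        features = [c for c in features if c not in exclude]
    return df, features
```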
@@ -3333,4 +3110,363 @@ def jitterplot_by_annotation(src, x_column, y_column, plot_title='Jitter Plot',
else:
plt.show()

-return balanced_df
+return balanced_df
+
+def generate_image_umap(settings={}):
+"""
+Generate UMAP or tSNE embedding and visualize the data with clustering.
+
+Parameters:
+settings (dict): Dictionary containing the following keys:
+src (str): Source directory containing the data.
+row_limit (int): Limit the number of rows to process.
+tables (list): List of table names to read from the database.
+visualize (str): Visualization type.
+image_nr (int): Number of images to display.
+dot_size (int): Size of dots in the scatter plot.
+n_neighbors (int): Number of neighbors for UMAP.
+figuresize (int): Size of the figure.
+black_background (bool): Whether to use a black background.
+remove_image_canvas (bool): Whether to remove the image canvas.
+plot_outlines (bool): Whether to plot outlines.
+plot_points (bool): Whether to plot points.
+smooth_lines (bool): Whether to smooth lines.
+verbose (bool): Whether to print verbose output.
+embedding_by_controls (bool): Whether to use embedding from controls.
+col_to_compare (str): Column to compare for control-based embedding.
+pos (str): Positive control value.
+neg (str): Negative control value.
+clustering (str): Clustering method ('DBSCAN' or 'KMeans').
+exclude (list): List of columns to exclude from the analysis.
+plot_images (bool): Whether to plot images.
+reduction_method (str): Dimensionality reduction method ('UMAP' or 'tSNE').
+save_figure (bool): Whether to save the figure as a PDF.
+
+Returns:
+pd.DataFrame: DataFrame with the original data and an additional column 'cluster' containing the cluster identity.
+"""
+
+from .io import _read_and_join_tables
+from .utils import get_db_paths, preprocess_data, reduction_and_clustering, remove_noise, generate_colors, correct_paths, plot_embedding, plot_clusters_grid, get_umap_image_settings
+from .alpha import cluster_feature_analysis, generate_umap_from_images
+
+settings = get_umap_image_settings(settings)
+
+if isinstance(settings['src'], str):
+settings['src'] = [settings['src']]
+
+if settings['plot_images'] is False:
+settings['black_background'] = False
+
+if settings['color_by']:
+settings['remove_cluster_noise'] = False
+settings['plot_outlines'] = False
+settings['smooth_lines'] = False
+
+settings_df = pd.DataFrame(list(settings.items()), columns=['Key', 'Value'])
+settings_dir = os.path.join(settings['src'][0],'settings')
+settings_csv = os.path.join(settings_dir,'embedding_settings.csv')
+os.makedirs(settings_dir, exist_ok=True)
+settings_df.to_csv(settings_csv, index=False)
+display(settings_df)
+
+db_paths = get_db_paths(settings['src'])
+
+tables = settings['tables'] + ['png_list']
+all_df = pd.DataFrame()
+#image_paths = []
+
+for i,db_path in enumerate(db_paths):
+df = _read_and_join_tables(db_path, table_names=tables)
+df, image_paths_tmp = correct_paths(df, settings['src'][i])
+all_df = pd.concat([all_df, df], axis=0)
+#image_paths.extend(image_paths_tmp)
+
+all_df['cond'] = all_df['col'].apply(map_condition, neg=settings['neg'], pos=settings['pos'], mix=settings['mix'])
+
+if settings['exclude_conditions']:
+if isinstance(settings['exclude_conditions'], str):
+settings['exclude_conditions'] = [settings['exclude_conditions']]
+row_count_before = len(all_df)
+all_df = all_df[~all_df['cond'].isin(settings['exclude_conditions'])]
+if settings['verbose']:
+print(f'Excluded {row_count_before - len(all_df)} rows after excluding: {settings["exclude_conditions"]}, rows left: {len(all_df)}')
+
+if settings['row_limit'] is not None:
+all_df = all_df.sample(n=settings['row_limit'], random_state=42)
+
+image_paths = all_df['png_path'].to_list()
+
+if settings['embedding_by_controls']:
+
+# Extract and reset the index for the column to compare
+col_to_compare = all_df[settings['col_to_compare']].reset_index(drop=True)
+
+# Preprocess the data to obtain numeric data
+numeric_data = preprocess_data(all_df, settings['filter_by'], settings['remove_highly_correlated'], settings['log_data'], settings['exclude'])
+
+# Convert numeric_data back to a DataFrame to align with col_to_compare
+numeric_data_df = pd.DataFrame(numeric_data)
+
+# Ensure numeric_data_df and col_to_compare are properly aligned
+numeric_data_df = numeric_data_df.reset_index(drop=True)
+
+# Assign the column back to numeric_data_df
+numeric_data_df[settings['col_to_compare']] = col_to_compare
+
+# Subset the dataframe based on specified column values for controls
+positive_control_df = numeric_data_df[numeric_data_df[settings['col_to_compare']] == settings['pos']].copy()
+negative_control_df = numeric_data_df[numeric_data_df[settings['col_to_compare']] == settings['neg']].copy()
+control_numeric_data_df = pd.concat([positive_control_df, negative_control_df])
+
+# Drop the comparison column from numeric_data_df and control_numeric_data_df
+numeric_data_df = numeric_data_df.drop(columns=[settings['col_to_compare']])
+control_numeric_data_df = control_numeric_data_df.drop(columns=[settings['col_to_compare']])
+
+# Convert numeric_data_df and control_numeric_data_df back to numpy arrays
+numeric_data = numeric_data_df.values
+control_numeric_data = control_numeric_data_df.values
+
+# Train the reducer on control data
+_, _, reducer = reduction_and_clustering(control_numeric_data, settings['n_neighbors'], settings['min_dist'], settings['metric'], settings['eps'], settings['min_samples'], settings['clustering'], settings['reduction_method'], settings['verbose'], n_jobs=settings['n_jobs'], mode='fit', model=False)
+
+# Apply the trained reducer to the entire dataset
+numeric_data = preprocess_data(all_df, settings['filter_by'], settings['remove_highly_correlated'], settings['log_data'], settings['exclude'])
+embedding, labels, _ = reduction_and_clustering(numeric_data, settings['n_neighbors'], settings['min_dist'], settings['metric'], settings['eps'], settings['min_samples'], settings['clustering'], settings['reduction_method'], settings['verbose'], n_jobs=settings['n_jobs'], mode=None, model=reducer)
+
+else:
+if settings['resnet_features']:
+numeric_data, embedding, labels = generate_umap_from_images(image_paths, settings['n_neighbors'], settings['min_dist'], settings['metric'], settings['clustering'], settings['eps'], settings['min_samples'], settings['n_jobs'], settings['verbose'])
+else:
+# Apply the trained reducer to the entire dataset
+numeric_data = preprocess_data(all_df, settings['filter_by'], settings['remove_highly_correlated'], settings['log_data'], settings['exclude'])
+embedding, labels, _ = reduction_and_clustering(numeric_data, settings['n_neighbors'], settings['min_dist'], settings['metric'], settings['eps'], settings['min_samples'], settings['clustering'], settings['reduction_method'], settings['verbose'], n_jobs=settings['n_jobs'])
+
+if settings['remove_cluster_noise']:
+# Remove noise from the clusters (removes -1 labels from DBSCAN)
+embedding, labels = remove_noise(embedding, labels)
+
+# Plot the results
+if settings['color_by']:
+if settings['embedding_by_controls']:
+labels = all_df[settings['color_by']]
+else:
+labels = all_df[settings['color_by']]
+
+# Generate colors for the clusters
+colors = generate_colors(len(np.unique(labels)), settings['black_background'])
+
+# Plot the embedding
+umap_plt = plot_embedding(embedding, image_paths, labels, settings['image_nr'], settings['img_zoom'], colors, settings['plot_by_cluster'], settings['plot_outlines'], settings['plot_points'], settings['plot_images'], settings['smooth_lines'], settings['black_background'], settings['figuresize'], settings['dot_size'], settings['remove_image_canvas'], settings['verbose'])
+if settings['plot_cluster_grids'] and settings['plot_images']:
+grid_plt = plot_clusters_grid(embedding, labels, settings['image_nr'], image_paths, colors, settings['figuresize'], settings['black_background'], settings['verbose'])
+
+# Save figure as PDF if required
+if settings['save_figure']:
+results_dir = os.path.join(settings['src'][0], 'results')
+os.makedirs(results_dir, exist_ok=True)
+reduction_method = settings['reduction_method'].upper()
+embedding_path = os.path.join(results_dir, f'{reduction_method}_embedding.pdf')
+umap_plt.savefig(embedding_path, format='pdf')
+print(f'Saved {reduction_method} embedding to {embedding_path} and grid to {embedding_path}')
+if settings['plot_cluster_grids'] and settings['plot_images']:
+grid_path = os.path.join(results_dir, f'{reduction_method}_grid.pdf')
+grid_plt.savefig(grid_path, format='pdf')
+print(f'Saved {reduction_method} embedding to {embedding_path} and grid to {grid_path}')
+
+# Add cluster labels to the dataframe
+all_df['cluster'] = labels
+
+# Save the results to a CSV file
+results_dir = os.path.join(settings['src'][0], 'results')
+results_csv = os.path.join(results_dir,'embedding_results.csv')
+os.makedirs(results_dir, exist_ok=True)
+all_df.to_csv(results_csv, index=False)
+print(f'Results saved to {results_csv}')
+
+if settings['analyze_clusters']:
+combined_results = cluster_feature_analysis(all_df)
+results_dir = os.path.join(settings['src'][0], 'results')
+cluster_results_csv = os.path.join(results_dir,'cluster_results.csv')
+os.makedirs(results_dir, exist_ok=True)
+combined_results.to_csv(cluster_results_csv, index=False)
+print(f'Cluster results saved to {cluster_results_csv}')
+
+return all_df
+
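A minimal call to the new `generate_image_umap`; most keys are filled in by `get_umap_image_settings`, whose defaults are not shown in this diff, so only a few are overridden here (paths and table names hypothetical):

```python
from spacr.core import generate_image_umap

settings = {
    'src': '/data/screen_plate1',   # hypothetical; a list of folders also works
    'tables': ['cell', 'nucleus'],  # hypothetical table names; 'png_list' is appended internally
    'row_limit': 5000,
    'reduction_method': 'UMAP',
    'clustering': 'DBSCAN',
    'verbose': True,
}
all_df = generate_image_umap(settings)  # joined data plus a 'cluster' column
```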
+# Define the mapping function
+def map_condition(col_value, neg='c1', pos='c2', mix='c3'):
+if col_value == neg:
+return 'neg'
+elif col_value == pos:
+return 'pos'
+elif col_value == mix:
+return 'mix'
+else:
+return 'screen'
+
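`map_condition` labels each well by its plate column: the `neg`, `pos` and `mix` control columns map to fixed labels, and anything else is treated as a screen well. A worked example:

```python
assert map_condition('c1') == 'neg'      # default negative-control column
assert map_condition('c2') == 'pos'      # default positive-control column
assert map_condition('c3') == 'mix'
assert map_condition('c17') == 'screen'  # any other column is a screen well

# Custom control layout
assert map_condition('c24', neg='c24', pos='c23') == 'neg'
```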
+def reducer_hyperparameter_search(settings={}, reduction_params=None, dbscan_params=None, kmeans_params=None, save=False):
+"""
+Perform a hyperparameter search for UMAP or tSNE on the given data.
+
+Parameters:
+settings (dict): Dictionary containing the following keys:
+src (str): Source directory containing the data.
+row_limit (int): Limit the number of rows to process.
+tables (list): List of table names to read from the database.
+filter_by (str): Column to filter the data.
+sample_size (int): Number of samples to use for the hyperparameter search.
+remove_highly_correlated (bool): Whether to remove highly correlated columns.
+log_data (bool): Whether to log transform the data.
+verbose (bool): Whether to print verbose output.
+reduction_method (str): Dimensionality reduction method ('UMAP' or 'tSNE').
+reduction_params (list): List of dictionaries containing hyperparameters to test for the reduction method.
+dbscan_params (list): List of dictionaries containing DBSCAN hyperparameters to test.
+kmeans_params (list): List of dictionaries containing KMeans hyperparameters to test.
+pointsize (int): Size of the points in the scatter plot.
+save (bool): Whether to save the resulting plot as a file.
+
+Returns:
+None
+"""
+
+from .io import _read_and_join_tables
+from .utils import get_db_paths, preprocess_data, search_reduction_and_clustering, generate_colors, get_umap_image_settings
+
+settings = get_umap_image_settings(settings)
+pointsize = settings['dot_size']
+if isinstance(dbscan_params, dict):
+dbscan_params = [dbscan_params]
+
+if isinstance(kmeans_params, dict):
+kmeans_params = [kmeans_params]
+
+if isinstance(reduction_params, dict):
+reduction_params = [reduction_params]
+
+# Determine reduction method based on the keys in reduction_param
+if any('n_neighbors' in param for param in reduction_params):
+reduction_method = 'umap'
+elif any('perplexity' in param for param in reduction_params):
+reduction_method = 'tsne'
+elif any('perplexity' in param for param in reduction_params) and any('n_neighbors' in param for param in reduction_params):
+raise ValueError("Reduction parameters must include 'n_neighbors' for UMAP or 'perplexity' for tSNE, not both.")
+
+if settings['reduction_method'].lower() != reduction_method:
+settings['reduction_method'] = reduction_method
+print(f'Changed reduction method to {reduction_method} based on the provided parameters.')
+
+if settings['verbose']:
+display(pd.DataFrame(list(settings.items()), columns=['Key', 'Value']))
+
+db_paths = get_db_paths(settings['src'])
+
+tables = settings['tables']
+all_df = pd.DataFrame()
+for db_path in db_paths:
+df = _read_and_join_tables(db_path, table_names=tables)
+all_df = pd.concat([all_df, df], axis=0)
+
+all_df['cond'] = all_df['col'].apply(map_condition, neg=settings['neg'], pos=settings['pos'], mix=settings['mix'])
+
+if settings['exclude_conditions']:
+if isinstance(settings['exclude_conditions'], str):
+settings['exclude_conditions'] = [settings['exclude_conditions']]
+row_count_before = len(all_df)
+all_df = all_df[~all_df['cond'].isin(settings['exclude_conditions'])]
+if settings['verbose']:
+print(f'Excluded {row_count_before - len(all_df)} rows after excluding: {settings["exclude_conditions"]}, rows left: {len(all_df)}')
+
+if settings['row_limit'] is not None:
+all_df = all_df.sample(n=settings['row_limit'], random_state=42)
+
+numeric_data = preprocess_data(all_df, settings['filter_by'], settings['remove_highly_correlated'], settings['log_data'], settings['exclude'])
+
+# Combine DBSCAN and KMeans parameters
+clustering_params = []
+if dbscan_params:
+for param in dbscan_params:
+param['method'] = 'dbscan'
+clustering_params.append(param)
+if kmeans_params:
+for param in kmeans_params:
+param['method'] = 'kmeans'
+clustering_params.append(param)
+
+print('Testing parameters:', reduction_params)
+print('Testing clustering parameters:', clustering_params)
+
+# Calculate the grid size
+grid_rows = len(reduction_params)
+grid_cols = len(clustering_params)
+
+fig_width = grid_cols*10
+fig_height = grid_rows*10
+
+fig, axs = plt.subplots(grid_rows, grid_cols, figsize=(fig_width, fig_height))
+
+# Make sure axs is always an array of axes
+axs = np.atleast_1d(axs)
+
+# Iterate through the Cartesian product of reduction and clustering hyperparameters
+for i, reduction_param in enumerate(reduction_params):
+for j, clustering_param in enumerate(clustering_params):
+if len(clustering_params) <= 1:
+axs[i].axis('off')
+ax = axs[i]
+elif len(reduction_params) <= 1:
+axs[j].axis('off')
+ax = axs[j]
+else:
+ax = axs[i, j]
+
+# Perform dimensionality reduction and clustering
+if settings['reduction_method'].lower() == 'umap':
+n_neighbors = reduction_param.get('n_neighbors', 15)
+
+if isinstance(n_neighbors, float):
+n_neighbors = int(n_neighbors * len(numeric_data))
+
+min_dist = reduction_param.get('min_dist', 0.1)
+embedding, labels = search_reduction_and_clustering(numeric_data, n_neighbors, min_dist, settings['metric'],
+clustering_param.get('eps', 0.5), clustering_param.get('min_samples', 5),
+clustering_param['method'], settings['reduction_method'], settings['verbose'], reduction_param, n_jobs=settings['n_jobs'])
+
+elif settings['reduction_method'].lower() == 'tsne':
+perplexity = reduction_param.get('perplexity', 30)
+
+if isinstance(perplexity, float):
+perplexity = int(perplexity * len(numeric_data))
+
+embedding, labels = search_reduction_and_clustering(numeric_data, perplexity, 0.1, settings['metric'],
+clustering_param.get('eps', 0.5), clustering_param.get('min_samples', 5),
+clustering_param['method'], settings['reduction_method'], settings['verbose'], reduction_param, n_jobs=settings['n_jobs'])
+
+else:
+raise ValueError(f"Unsupported reduction method: {settings['reduction_method']}. Supported methods are 'UMAP' and 'tSNE'")
+
+# Plot the results
+if settings['color_by']:
+unique_groups = all_df[settings['color_by']].unique()
+colors = generate_colors(len(unique_groups), False)
+for group, color in zip(unique_groups, colors):
+indices = all_df[settings['color_by']] == group
+ax.scatter(embedding[indices, 0], embedding[indices, 1], s=pointsize, label=f"{group}", color=color)
+else:
+unique_labels = np.unique(labels)
+colors = generate_colors(len(unique_labels), False)
+for label, color in zip(unique_labels, colors):
+ax.scatter(embedding[labels == label, 0], embedding[labels == label, 1], s=pointsize, label=f"Cluster {label}", color=color)
+
+ax.set_title(f"{settings['reduction_method']} {reduction_param}\n{clustering_param['method']} {clustering_param}")
+ax.legend()
+
+plt.tight_layout()
+if save:
+results_dir = os.path.join(settings['src'], 'results')
+os.makedirs(results_dir, exist_ok=True)
+plt.savefig(os.path.join(results_dir, 'hyperparameter_search.pdf'))
+else:
+plt.show()
+
+return