spacr 0.0.20__py3-none-any.whl → 0.0.21__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- spacr/alpha.py +291 -14
- spacr/annotate_app.py +2 -2
- spacr/core.py +1301 -426
- spacr/foldseek.py +793 -0
- spacr/get_alfafold_structures.py +72 -0
- spacr/gui_mask_app.py +30 -10
- spacr/gui_utils.py +17 -2
- spacr/io.py +260 -102
- spacr/measure.py +150 -64
- spacr/plot.py +151 -12
- spacr/sim.py +666 -119
- spacr/timelapse.py +139 -9
- spacr/train.py +18 -10
- spacr/utils.py +43 -43
- {spacr-0.0.20.dist-info → spacr-0.0.21.dist-info}/METADATA +5 -2
- spacr-0.0.21.dist-info/RECORD +33 -0
- spacr-0.0.20.dist-info/RECORD +0 -31
- {spacr-0.0.20.dist-info → spacr-0.0.21.dist-info}/LICENSE +0 -0
- {spacr-0.0.20.dist-info → spacr-0.0.21.dist-info}/WHEEL +0 -0
- {spacr-0.0.20.dist-info → spacr-0.0.21.dist-info}/entry_points.txt +0 -0
- {spacr-0.0.20.dist-info → spacr-0.0.21.dist-info}/top_level.txt +0 -0
spacr/core.py
CHANGED
```diff
@@ -1,11 +1,13 @@
-import os, sqlite3, gc, torch, time, random, shutil, cv2, tarfile, datetime
+import os, sqlite3, gc, torch, time, random, shutil, cv2, tarfile, datetime, shap, string
 
 # image and array processing
 import numpy as np
 import pandas as pd
 
+from cellpose import train
 import cellpose
 from cellpose import models as cp_models
+from cellpose.models import CellposeModel
 
 import statsmodels.formula.api as smf
 import statsmodels.api as sm
@@ -27,9 +29,17 @@ matplotlib.use('Agg')
 
 import torchvision.transforms as transforms
 from sklearn.model_selection import train_test_split
-from sklearn.ensemble import IsolationForest
+from sklearn.ensemble import IsolationForest, RandomForestClassifier, HistGradientBoostingClassifier
 from .logger import log_function_call
 
+from sklearn.linear_model import LogisticRegression
+from sklearn.inspection import permutation_importance
+from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, classification_report
+from xgboost import XGBClassifier
+
+from scipy.spatial.distance import cosine, euclidean, mahalanobis, cityblock, minkowski, chebyshev, hamming, jaccard, braycurtis
+from sklearn.preprocessing import StandardScaler
+import shap
 
 def analyze_plaques(folder):
     summary_data = []
```
```diff
@@ -67,74 +77,6 @@ def analyze_plaques(folder):
 
     print(f"Analysis completed and saved to database '{db_name}'.")
 
-def compare_masks(dir1, dir2, dir3, verbose=False):
-
-    from .io import _read_mask
-    from .plot import visualize_masks, plot_comparison_results
-    from .utils import extract_boundaries, boundary_f1_score, compute_segmentation_ap, jaccard_index, dice_coefficient
-
-    filenames = os.listdir(dir1)
-    results = []
-    cond_1 = os.path.basename(dir1)
-    cond_2 = os.path.basename(dir2)
-    cond_3 = os.path.basename(dir3)
-    for index, filename in enumerate(filenames):
-        print(f'Processing image:{index+1}', end='\r', flush=True)
-        path1, path2, path3 = os.path.join(dir1, filename), os.path.join(dir2, filename), os.path.join(dir3, filename)
-        if os.path.exists(path2) and os.path.exists(path3):
-
-            mask1, mask2, mask3 = _read_mask(path1), _read_mask(path2), _read_mask(path3)
-            boundary_true1, boundary_true2, boundary_true3 = extract_boundaries(mask1), extract_boundaries(mask2), extract_boundaries(mask3)
-
-
-            true_masks, pred_masks = [mask1], [mask2, mask3] # Assuming mask1 is the ground truth for simplicity
-            true_labels, pred_labels_1, pred_labels_2 = label(mask1), label(mask2), label(mask3)
-            average_precision_0, average_precision_1 = compute_segmentation_ap(mask1, mask2), compute_segmentation_ap(mask1, mask3)
-            ap_scores = [average_precision_0, average_precision_1]
-
-            if verbose:
-                unique_values1, unique_values2, unique_values3 = np.unique(mask1), np.unique(mask2), np.unique(mask3)
-                print(f"Unique values in mask 1: {unique_values1}, mask 2: {unique_values2}, mask 3: {unique_values3}")
-                visualize_masks(boundary_true1, boundary_true2, boundary_true3, title=f"Boundaries - {filename}")
-
-            boundary_f1_12, boundary_f1_13, boundary_f1_23 = boundary_f1_score(mask1, mask2), boundary_f1_score(mask1, mask3), boundary_f1_score(mask2, mask3)
-
-            if (np.unique(mask1).size == 1 and np.unique(mask1)[0] == 0) and \
-               (np.unique(mask2).size == 1 and np.unique(mask2)[0] == 0) and \
-               (np.unique(mask3).size == 1 and np.unique(mask3)[0] == 0):
-                continue
-
-            if verbose:
-                unique_values4, unique_values5, unique_values6 = np.unique(boundary_f1_12), np.unique(boundary_f1_13), np.unique(boundary_f1_23)
-                print(f"Unique values in boundary mask 1: {unique_values4}, mask 2: {unique_values5}, mask 3: {unique_values6}")
-                visualize_masks(mask1, mask2, mask3, title=filename)
-
-            jaccard12 = jaccard_index(mask1, mask2)
-            dice12 = dice_coefficient(mask1, mask2)
-            jaccard13 = jaccard_index(mask1, mask3)
-            dice13 = dice_coefficient(mask1, mask3)
-            jaccard23 = jaccard_index(mask2, mask3)
-            dice23 = dice_coefficient(mask2, mask3)
-
-            results.append({
-                f'filename': filename,
-                f'jaccard_{cond_1}_{cond_2}': jaccard12,
-                f'dice_{cond_1}_{cond_2}': dice12,
-                f'jaccard_{cond_1}_{cond_3}': jaccard13,
-                f'dice_{cond_1}_{cond_3}': dice13,
-                f'jaccard_{cond_2}_{cond_3}': jaccard23,
-                f'dice_{cond_2}_{cond_3}': dice23,
-                f'boundary_f1_{cond_1}_{cond_2}': boundary_f1_12,
-                f'boundary_f1_{cond_1}_{cond_3}': boundary_f1_13,
-                f'boundary_f1_{cond_2}_{cond_3}': boundary_f1_23,
-                f'average_precision_{cond_1}_{cond_2}': ap_scores[0],
-                f'average_precision_{cond_1}_{cond_3}': ap_scores[1]
-            })
-        else:
-            print(f'Cannot find {path1} or {path2} or {path3}')
-    fig = plot_comparison_results(results)
-    return results, fig
-
 def generate_cp_masks(settings):
 
     src = settings['src']
```
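The removed `compare_masks` helper scored agreement between three mask directories with Jaccard, Dice, boundary F1, and average precision. For reference, the two overlap metrics it relied on can be computed as below; this is a generic sketch using the standard formulas, not the exact `spacr.utils` implementations of `jaccard_index` and `dice_coefficient`.

```python
import numpy as np

def jaccard_index(mask1, mask2):
    """Intersection over union of the foreground pixels of two masks."""
    a, b = mask1 > 0, mask2 > 0
    union = np.logical_or(a, b).sum()
    return np.logical_and(a, b).sum() / union if union else 1.0

def dice_coefficient(mask1, mask2):
    """Dice similarity: 2 * intersection / (|A| + |B|)."""
    a, b = mask1 > 0, mask2 > 0
    denom = a.sum() + b.sum()
    return 2 * np.logical_and(a, b).sum() / denom if denom else 1.0
```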
```diff
@@ -177,8 +119,146 @@ def train_cellpose(settings):
     from .utils import resize_images_and_labels
 
     img_src = settings['img_src']
-    mask_src=
-
+    mask_src = os.path.join(img_src, 'mask')
+
+    model_name = settings['model_name']
+    model_type = settings['model_type']
+    learning_rate = settings['learning_rate']
+    weight_decay = settings['weight_decay']
+    batch_size = settings['batch_size']
+    n_epochs = settings['n_epochs']
+    from_scratch = settings['from_scratch']
+    diameter = settings['diameter']
+    verbose = settings['verbose']
+
+    channels = [0,0]
+    signal_thresholds = 1000
+    normalize = True
+    percentiles = [2,98]
+    circular = False
+    invert = False
+    resize = False
+    settings['width_height'] = [1000,1000]
+    target_height = settings['width_height'][1]
+    target_width = settings['width_height'][0]
+    rescale = False
+    grayscale = True
+    test = False
+
+    if test:
+        test_img_src = os.path.join(os.path.dirname(img_src), 'test')
+        test_mask_src = os.path.join(test_img_src, 'mask')
+
+    test_images, test_masks, test_image_names, test_mask_names = None,None,None,None,
+    print(settings)
+
+    if from_scratch:
+        model_name=f'scratch_{model_name}_{model_type}_e{n_epochs}_X{target_width}_Y{target_height}.CP_model'
+    else:
+        if resize:
+            model_name=f'{model_name}_{model_type}_e{n_epochs}_X{target_width}_Y{target_height}.CP_model'
+        else:
+            model_name=f'{model_name}_{model_type}_e{n_epochs}.CP_model'
+
+    model_save_path = os.path.join(mask_src, 'models', 'cellpose_model')
+    print(model_save_path)
+    os.makedirs(model_save_path, exist_ok=True)
+
+    settings_df = pd.DataFrame(list(settings.items()), columns=['Key', 'Value'])
+    settings_csv = os.path.join(model_save_path,f'{model_name}_settings.csv')
+    settings_df.to_csv(settings_csv, index=False)
+
+    if from_scratch:
+        model = cp_models.CellposeModel(gpu=True, model_type=model_type, diam_mean=diameter, pretrained_model=None)
+    else:
+        model = cp_models.CellposeModel(gpu=True, model_type=model_type)
+
+    if normalize:
+
+        image_files = [os.path.join(img_src, f) for f in os.listdir(img_src) if f.endswith('.tif')]
+        label_files = [os.path.join(mask_src, f) for f in os.listdir(mask_src) if f.endswith('.tif')]
+        images, masks, image_names, mask_names = _load_normalized_images_and_labels(image_files, label_files, signal_thresholds, channels=channels, percentiles=percentiles, circular=circular, invert=invert, visualize=verbose)
+        images = [np.squeeze(img) if img.shape[-1] == 1 else img for img in images]
+
+        if test:
+            test_image_files = [os.path.join(test_img_src, f) for f in os.listdir(test_img_src) if f.endswith('.tif')]
+            test_label_files = [os.path.join(test_mask_src, f) for f in os.listdir(test_mask_src) if f.endswith('.tif')]
+            test_images, test_masks, test_image_names, test_mask_names = _load_normalized_images_and_labels(image_files=test_image_files, label_files=test_label_files, signal_thresholds=signal_thresholds, channels=channels, percentiles=percentiles, circular=circular, invert=invert, visualize=verbose)
+            test_images = [np.squeeze(img) if img.shape[-1] == 1 else img for img in test_images]
+
+
+    else:
+        images, masks, image_names, mask_names = _load_images_and_labels(img_src, mask_src, circular, invert)
+        images = [np.squeeze(img) if img.shape[-1] == 1 else img for img in images]
+
+        if test:
+            test_images, test_masks, test_image_names, test_mask_names = _load_images_and_labels(img_src=test_img_src, mask_src=test_mask_src, circular=circular, invert=circular)
+            test_images = [np.squeeze(img) if img.shape[-1] == 1 else img for img in test_images]
+
+    if resize:
+        images, masks = resize_images_and_labels(images, masks, target_height, target_width, show_example=True)
+
+    if model_type == 'cyto':
+        cp_channels = [0,1]
+    if model_type == 'cyto2':
+        cp_channels = [0,2]
+    if model_type == 'nucleus':
+        cp_channels = [0,0]
+    if grayscale:
+        cp_channels = [0,0]
+        images = [np.squeeze(img) if img.ndim == 3 and 1 in img.shape else img for img in images]
+
+    masks = [np.squeeze(mask) if mask.ndim == 3 and 1 in mask.shape else mask for mask in masks]
+
+    print(f'image shape: {images[0].shape}, image type: images[0].shape mask shape: {masks[0].shape}, image type: masks[0].shape')
+    save_every = int(n_epochs/10)
+    if save_every < 10:
+        save_every = n_epochs
+
+    train.train_seg(model.net,
+                    train_data=images,
+                    train_labels=masks,
+                    train_files=image_names,
+                    train_labels_files=mask_names,
+                    train_probs=None,
+                    test_data=test_images,
+                    test_labels=test_masks,
+                    test_files=test_image_names,
+                    test_labels_files=test_mask_names,
+                    test_probs=None,
+                    load_files=True,
+                    batch_size=batch_size,
+                    learning_rate=learning_rate,
+                    n_epochs=n_epochs,
+                    weight_decay=weight_decay,
+                    momentum=0.9,
+                    SGD=False,
+                    channels=cp_channels,
+                    channel_axis=None,
+                    #rgb=False,
+                    normalize=False,
+                    compute_flows=False,
+                    save_path=model_save_path,
+                    save_every=save_every,
+                    nimg_per_epoch=None,
+                    nimg_test_per_epoch=None,
+                    rescale=rescale,
+                    #scale_range=None,
+                    #bsize=224,
+                    min_train_masks=1,
+                    model_name=model_name)
+
+    return print(f"Model saved at: {model_save_path}/{model_name}")
+
+def train_cellpose_v1(settings):
+
+    from .io import _load_normalized_images_and_labels, _load_images_and_labels
+    from .utils import resize_images_and_labels
+
+    img_src = settings['img_src']
+
+    mask_src = os.path.join(img_src, 'mask')
+
     model_name = settings['model_name']
     model_type = settings['model_type']
     learning_rate = settings['learning_rate']
@@ -186,7 +266,9 @@ def train_cellpose(settings):
     batch_size = settings['batch_size']
     n_epochs = settings['n_epochs']
     verbose = settings['verbose']
-
+
+    signal_thresholds = 100 #settings['signal_thresholds']
+
     channels = settings['channels']
     from_scratch = settings['from_scratch']
     diameter = settings['diameter']
@@ -199,7 +281,17 @@ def train_cellpose(settings):
     invert = settings['invert']
     percentiles = settings['percentiles']
     grayscale = settings['grayscale']
-
+
+    if model_type == 'cyto':
+        settings['diameter'] = 30
+        diameter = settings['diameter']
+        print(f'Cyto model must have diamiter 30. Diameter set the 30')
+
+    if model_type == 'nuclei':
+        settings['diameter'] = 17
+        diameter = settings['diameter']
+        print(f'Nuclei model must have diamiter 17. Diameter set the 17')
+
     print(settings)
 
     if from_scratch:
@@ -208,24 +300,24 @@ def train_cellpose(settings):
         model_name=f'{model_name}_{model_type}_e{n_epochs}_X{target_width}_Y{target_height}.CP_model'
 
     model_save_path = os.path.join(mask_src, 'models', 'cellpose_model')
-
+    print(model_save_path)
+    os.makedirs(model_save_path, exist_ok=True)
 
     settings_df = pd.DataFrame(list(settings.items()), columns=['Key', 'Value'])
     settings_csv = os.path.join(model_save_path,f'{model_name}_settings.csv')
     settings_df.to_csv(settings_csv, index=False)
 
-    if
-    if not from_scratch:
-        model = cp_models.CellposeModel(gpu=True, model_type=model_type)
-    else:
-        model = cp_models.CellposeModel(gpu=True, model_type=model_type, net_avg=False, diam_mean=diameter, pretrained_model=None)
-    if model_type !='cyto':
+    if not from_scratch:
         model = cp_models.CellposeModel(gpu=True, model_type=model_type)
-
-
-
-
-
+
+    else:
+        model = cp_models.CellposeModel(gpu=True, model_type=model_type, pretrained_model=None)
+
+    if normalize:
+        image_files = [os.path.join(img_src, f) for f in os.listdir(img_src) if f.endswith('.tif')]
+        label_files = [os.path.join(mask_src, f) for f in os.listdir(mask_src) if f.endswith('.tif')]
+
+        images, masks, image_names, mask_names = _load_normalized_images_and_labels(image_files, label_files, signal_thresholds, channels=channels, percentiles=percentiles, circular=circular, invert=invert, visualize=verbose)
         images = [np.squeeze(img) if img.shape[-1] == 1 else img for img in images]
     else:
         images, masks, image_names, mask_names = _load_images_and_labels(img_src, mask_src, circular, invert)
@@ -248,25 +340,86 @@ def train_cellpose(settings):
 
     print(f'image shape: {images[0].shape}, image type: images[0].shape mask shape: {masks[0].shape}, image type: masks[0].shape')
     save_every = int(n_epochs/10)
-
-
+    if save_every < 10:
+        save_every = n_epochs
+
+
+    #print('cellpose image input dtype', images[0].dtype)
+    #print('cellpose mask input dtype', masks[0].dtype)
+
     # Train the model
-    model.train(train_data=images, #(list of arrays (2D or 3D)) – images for training
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    #model.train(train_data=images, #(list of arrays (2D or 3D)) – images for training
+
+    #model.train(train_data=images, #(list of arrays (2D or 3D)) – images for training
+    #            train_labels=masks, #(list of arrays (2D or 3D)) – labels for train_data, where 0=no masks; 1,2,…=mask labels can include flows as additional images
+    #            train_files=image_names, #(list of strings) – file names for images in train_data (to save flows for future runs)
+    #            channels=cp_channels, #(list of ints (default, None)) – channels to use for training
+    #            normalize=False, #(bool (default, True)) – normalize data so 0.0=1st percentile and 1.0=99th percentile of image intensities in each channel
+    #            save_path=model_save_path, #(string (default, None)) – where to save trained model, if None it is not saved
+    #            save_every=save_every, #(int (default, 100)) – save network every [save_every] epochs
+    #            learning_rate=learning_rate, #(float or list/np.ndarray (default, 0.2)) – learning rate for training, if list, must be same length as n_epochs
+    #            n_epochs=n_epochs, #(int (default, 500)) – how many times to go through whole training set during training
+    #            weight_decay=weight_decay, #(float (default, 0.00001)) –
+    #            SGD=True, #(bool (default, True)) – use SGD as optimization instead of RAdam
+    #            batch_size=batch_size, #(int (optional, default 8)) – number of 224x224 patches to run simultaneously on the GPU (can make smaller or bigger depending on GPU memory usage)
+    #            nimg_per_epoch=None, #(int (optional, default None)) – minimum number of images to train on per epoch, with a small training set (< 8 images) it may help to set to 8
+    #            rescale=rescale, #(bool (default, True)) – whether or not to rescale images to diam_mean during training, if True it assumes you will fit a size model after training or resize your images accordingly, if False it will try to train the model to be scale-invariant (works worse)
+    #            min_train_masks=1, #(int (default, 5)) – minimum number of masks an image must have to use in training set
+    #            model_name=model_name) #(str (default, None)) – name of network, otherwise saved with name as params + training start time
+
+
+    train.train_seg(model.net,
+                    train_data=images,
+                    train_labels=masks,
+                    train_files=image_names,
+                    train_labels_files=None,
+                    train_probs=None,
+                    test_data=None,
+                    test_labels=None,
+                    test_files=None,
+                    test_labels_files=None,
+                    test_probs=None,
+                    load_files=True,
+                    batch_size=batch_size,
+                    learning_rate=learning_rate,
+                    n_epochs=n_epochs,
+                    weight_decay=weight_decay,
+                    momentum=0.9,
+                    SGD=False,
+                    channels=cp_channels,
+                    channel_axis=None,
+                    #rgb=False,
+                    normalize=False,
+                    compute_flows=False,
+                    save_path=model_save_path,
+                    save_every=save_every,
+                    nimg_per_epoch=None,
+                    nimg_test_per_epoch=None,
+                    rescale=rescale,
+                    #scale_range=None,
+                    #bsize=224,
+                    min_train_masks=1,
+                    model_name=model_name)
+
+    #model_save_path = train.train_seg(model.net,
+    #                                  train_data=images,
+    #                                  train_files=image_names,
+    #                                  train_labels=masks,
+    #                                  channels=cp_channels,
+    #                                  normalize=False,
+    #                                  save_every=save_every,
+    #                                  learning_rate=learning_rate,
+    #                                  n_epochs=n_epochs,
+    #                                  #test_data=test_images,
+    #                                  #test_labels=test_labels,
+    #                                  weight_decay=weight_decay,
+    #                                  SGD=True,
+    #                                  batch_size=batch_size,
+    #                                  nimg_per_epoch=None,
+    #                                  rescale=rescale,
+    #                                  min_train_masks=1,
+    #                                  model_name=model_name)
 
     return print(f"Model saved at: {model_save_path}/{model_name}")
 
```
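The rewritten `train_cellpose` now reads everything from a single `settings` dict, derives `mask_src` as `<img_src>/mask`, and trains via `cellpose.train.train_seg` instead of the old `model.train` API. A minimal usage sketch follows; the path and values are illustrative, but the keys match those the function reads in this version.

```python
from spacr.core import train_cellpose

settings = {
    'img_src': '/data/plate1/train',  # hypothetical; .tif images, masks in img_src/mask
    'model_name': 'toxo_cells',
    'model_type': 'cyto',
    'learning_rate': 0.2,
    'weight_decay': 1e-5,
    'batch_size': 8,
    'n_epochs': 100,
    'from_scratch': False,
    'diameter': 30,
    'verbose': True,
}
train_cellpose(settings)  # saves the model under <img_src>/mask/models/cellpose_model
```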
```diff
@@ -926,30 +1079,38 @@ def annotate_results(pred_loc):
     display(df)
     return df
 
-def generate_dataset(src,
+def generate_dataset(src, file_metadata=None, experiment='TSG101_screen', sample=None):
 
-    from .utils import
-
-    db_path = os.path.join(src, 'measurements','measurements.db')
+    from .utils import initiate_counter, add_images_to_tar
+
+    db_path = os.path.join(src, 'measurements', 'measurements.db')
     dst = os.path.join(src, 'datasets')
-
-    global total_images
     all_paths = []
-
+
     # Connect to the database and retrieve the image paths
     print(f'Reading DataBase: {db_path}')
-
-
-
-
-
-
-
-
-
-
-
-
+    try:
+        with sqlite3.connect(db_path) as conn:
+            cursor = conn.cursor()
+            if file_metadata:
+                if isinstance(file_metadata, str):
+                    cursor.execute("SELECT png_path FROM png_list WHERE png_path LIKE ?", (f"%{file_metadata}%",))
+            else:
+                cursor.execute("SELECT png_path FROM png_list")
+
+            while True:
+                rows = cursor.fetchmany(1000)
+                if not rows:
+                    break
+                all_paths.extend([row[0] for row in rows])
+
+    except sqlite3.Error as e:
+        print(f"Database error: {e}")
+        return
+    except Exception as e:
+        print(f"Error: {e}")
+        return
+
     if isinstance(sample, int):
         selected_paths = random.sample(all_paths, sample)
         print(f'Random selection of {len(selected_paths)} paths')
@@ -957,23 +1118,18 @@ def generate_dataset(src, file_type=None, experiment='TSG101_screen', sample=Non
         selected_paths = all_paths
         random.shuffle(selected_paths)
         print(f'All paths: {len(selected_paths)} paths')
-
+
     total_images = len(selected_paths)
-    print(f'
-
+    print(f'Found {total_images} images')
+
     # Create a temp folder in dst
     temp_dir = os.path.join(dst, "temp_tars")
     os.makedirs(temp_dir, exist_ok=True)
 
     # Chunking the data
-
-
-
-        remainder = len(selected_paths) % num_procs
-    else:
-        num_procs = 2
-        chunk_size = len(selected_paths) // 2
-        remainder = 0
+    num_procs = max(2, cpu_count() - 2)
+    chunk_size = len(selected_paths) // num_procs
+    remainder = len(selected_paths) % num_procs
 
     paths_chunks = []
     start = 0
@@ -983,45 +1139,43 @@ def generate_dataset(src, file_type=None, experiment='TSG101_screen', sample=Non
     start = end
 
     temp_tar_files = [os.path.join(temp_dir, f'temp_{i}.tar') for i in range(num_procs)]
-
-    # Initialize the shared objects
-    counter_ = Value('i', 0)
-    lock_ = Lock()
 
-    ctx = multiprocessing.get_context('spawn')
-
     print(f'Generating temporary tar files in {dst}')
-
+
+    # Initialize shared counter and lock
+    counter = Value('i', 0)
+    lock = Lock()
+
+    with Pool(processes=num_procs, initializer=initiate_counter, initargs=(counter, lock)) as pool:
+        pool.starmap(add_images_to_tar, [(paths_chunks[i], temp_tar_files[i], total_images) for i in range(num_procs)])
+
     # Combine the temporary tar files into a final tar
     date_name = datetime.date.today().strftime('%y%m%d')
-
+    if not file_metadata is None:
+        tar_name = f'{date_name}_{experiment}_{file_metadata}.tar'
+    else:
+        tar_name = f'{date_name}_{experiment}.tar'
+    tar_name = os.path.join(dst, tar_name)
     if os.path.exists(tar_name):
         number = random.randint(1, 100)
-        tar_name_2 = f'{date_name}_{experiment}_{
-        print(f'Warning: {os.path.basename(tar_name)} exists saving as {os.path.basename(tar_name_2)} ')
-        tar_name = tar_name_2
-
-    # Add the counter and lock to the arguments for pool.map
+        tar_name_2 = f'{date_name}_{experiment}_{file_metadata}_{number}.tar'
+        print(f'Warning: {os.path.basename(tar_name)} exists, saving as {os.path.basename(tar_name_2)} ')
+        tar_name = os.path.join(dst, tar_name_2)
+
     print(f'Merging temporary files')
-    #with Pool(processes=num_procs, initializer=init_globals, initargs=(counter_, lock_)) as pool:
-    #    results = pool.map(add_images_to_tar, zip(paths_chunks, temp_tar_files))
 
-    with
-
-
-
-
-
-
-                t.extract(member, path=dst)
-                final_tar.add(os.path.join(dst, member.name), arcname=member.name)
-                os.remove(os.path.join(dst, member.name))
-        os.remove(tar_path)
+    with tarfile.open(tar_name, 'w') as final_tar:
+        for temp_tar_path in temp_tar_files:
+            with tarfile.open(temp_tar_path, 'r') as temp_tar:
+                for member in temp_tar.getmembers():
+                    file_obj = temp_tar.extractfile(member)
+                    final_tar.addfile(member, file_obj)
+            os.remove(temp_tar_path)
 
     # Delete the temp folder
     shutil.rmtree(temp_dir)
-    print(f"\nSaved {total_images} images to {
-
+    print(f"\nSaved {total_images} images to {tar_name}")
+
 def apply_model_to_tar(tar_path, model_path, file_type='cell_png', image_size=224, batch_size=64, normalize=True, preload='images', num_workers=10, verbose=False):
 
     from .io import TarImageDataset, DataLoader
```
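`generate_dataset` now takes a `file_metadata` substring filter instead of `file_type`, streams paths from the database in 1000-row batches, shards the tar writing across `max(2, cpu_count() - 2)` worker processes, and merges the temporary tars member by member with `addfile` rather than extracting to disk first. A usage sketch with hypothetical paths:

```python
from spacr.core import generate_dataset

# Reads png paths from <src>/measurements/measurements.db (table png_list)
# and writes <src>/datasets/<date>_<experiment>[_<file_metadata>].tar
generate_dataset(src='/data/plate1',        # hypothetical
                 file_metadata='cell_png',  # substring filter on png_path; None selects all
                 experiment='TSG101_screen',
                 sample=10000)              # int to subsample; None keeps every path
```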
```diff
@@ -1257,7 +1411,14 @@ def generate_training_dataset(src, mode='annotation', annotation_column='test',
 
     db_path = os.path.join(src, 'measurements','measurements.db')
     dst = os.path.join(src, 'datasets', 'training')
-
+
+    if os.path.exists(dst):
+        for i in range(1, 1000):
+            dst = os.path.join(src, 'datasets', f'training_{i}')
+            if not os.path.exists(dst):
+                print(f'Creating new directory for training: {dst}')
+                break
+
     if mode == 'annotation':
         class_paths_ls_2 = []
         class_paths_ls = training_dataset_from_annotation(db_path, dst, annotation_column, annotated_classes=annotated_classes)
@@ -1268,6 +1429,7 @@ def generate_training_dataset(src, mode='annotation', annotation_column='test',
 
     elif mode == 'metadata':
         class_paths_ls = []
+        class_len_ls = []
         [df] = _read_db(db_loc=db_path, tables=['png_list'])
         df['metadata_based_class'] = pd.NA
         for i, class_ in enumerate(classes):
@@ -1275,7 +1437,18 @@ def generate_training_dataset(src, mode='annotation', annotation_column='test',
             df.loc[df[metadata_type_by].isin(ls), 'metadata_based_class'] = class_
 
         for class_ in classes:
+            if size == None:
+                c_s = []
+                for c in classes:
+                    c_s_t_df = df[df['metadata_based_class'] == c]
+                    c_s.append(len(c_s_t_df))
+                    print(f'Found {len(c_s_t_df)} images for class {c}')
+                size = min(c_s)
+                print(f'Using the smallest class size: {size}')
+
             class_temp_df = df[df['metadata_based_class'] == class_]
+            class_len_ls.append(len(class_temp_df))
+            print(f'Found {len(class_temp_df)} images for class {class_}')
             class_paths_temp = random.sample(class_temp_df['png_path'].tolist(), size)
             class_paths_ls.append(class_paths_temp)
 
@@ -1332,7 +1505,7 @@ def generate_training_dataset(src, mode='annotation', annotation_column='test',
 
     return
 
-def
+def generate_loaders_v1(src, train_mode='erm', mode='train', image_size=224, batch_size=32, classes=['nc','pc'], num_workers=None, validation_split=0.0, max_show=2, pin_memory=False, normalize=False, verbose=False):
     """
     Generate data loaders for training and validation/test datasets.
 
```
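In the metadata branch above, `size=None` now triggers automatic class balancing: per-class image counts are collected and the smallest count becomes the sample size. A standalone sketch of that balancing step, with column names mirroring the diff:

```python
import random
import pandas as pd

def balanced_sample(df, classes, size=None):
    """Sample an equal number of png paths per class, capped at the smallest class."""
    if size is None:
        size = min((df['metadata_based_class'] == c).sum() for c in classes)
    sampled = []
    for c in classes:
        paths = df.loc[df['metadata_based_class'] == c, 'png_path'].tolist()
        sampled.append(random.sample(paths, size))
    return sampled
```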
```diff
@@ -1463,56 +1636,223 @@ def generate_loaders(src, train_mode='erm', mode='train', image_size=224, batch_
 
     return train_loaders, val_loaders, plate_names
 
-def
+def generate_loaders(src, train_mode='erm', mode='train', image_size=224, batch_size=32, classes=['nc','pc'], num_workers=None, validation_split=0.0, max_show=2, pin_memory=False, normalize=False, channels=[1, 2, 3], verbose=False):
+
     """
-
+    Generate data loaders for training and validation/test datasets.
 
     Parameters:
-    src (str): The source
-
-
+    - src (str): The source directory containing the data.
+    - train_mode (str): The training mode. Options are 'erm' (Empirical Risk Minimization) or 'irm' (Invariant Risk Minimization).
+    - mode (str): The mode of operation. Options are 'train' or 'test'.
+    - image_size (int): The size of the input images.
+    - batch_size (int): The batch size for the data loaders.
+    - classes (list): The list of classes to consider.
+    - num_workers (int): The number of worker threads for data loading.
+    - validation_split (float): The fraction of data to use for validation when train_mode is 'erm'.
+    - max_show (int): The maximum number of images to show when verbose is True.
+    - pin_memory (bool): Whether to pin memory for faster data transfer.
+    - normalize (bool): Whether to normalize the input images.
+    - verbose (bool): Whether to print additional information and show images.
+    - channels (list): The list of channels to retain. Options are [1, 2, 3] for all channels, [1, 2] for blue and green, etc.
 
     Returns:
-
+    - train_loaders (list): List of data loaders for training datasets.
+    - val_loaders (list): List of data loaders for validation datasets.
+    - plate_names (list): List of plate names (only applicable when train_mode is 'irm').
     """
-
-    from .io import _read_and_merge_data, _results_to_csv
-    from .plot import plot_merged, _plot_controls, _plot_recruitment
-    from .utils import _object_filter, annotate_conditions, _calculate_recruitment, _group_by_well
-
-    settings_dict = {**metadata_settings, **advanced_settings}
-    settings_df = pd.DataFrame(list(settings_dict.items()), columns=['Key', 'Value'])
-    settings_csv = os.path.join(src,'settings','analyze_settings.csv')
-    os.makedirs(os.path.join(src,'settings'), exist_ok=True)
-    settings_df.to_csv(settings_csv, index=False)
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    from .io import MyDataset
+    from .plot import _imshow
+    from torchvision import transforms
+    from torch.utils.data import DataLoader, random_split
+    from collections import defaultdict
+    import os
+    import random
+    from PIL import Image
+    from torchvision.transforms import ToTensor
+
+    chans = []
+
+    if 'r' in channels:
+        chans.append(1)
+    if 'g' in channels:
+        chans.append(2)
+    if 'b' in channels:
+        chans.append(3)
+
+    channels = chans
+
+    if verbose:
+        print(f'Training a network on channels: {channels}')
+        print(f'Channel 1: Red, Channel 2: Green, Channel 3: Blue')
+
+    class SelectChannels:
+        def __init__(self, channels):
+            self.channels = channels
+
+        def __call__(self, img):
+            img = img.clone()
+            if 1 not in self.channels:
+                img[0, :, :] = 0  # Zero out the red channel
+            if 2 not in self.channels:
+                img[1, :, :] = 0  # Zero out the green channel
+            if 3 not in self.channels:
+                img[2, :, :] = 0  # Zero out the blue channel
+            return img
+
+    plate_to_filenames = defaultdict(list)
+    plate_to_labels = defaultdict(list)
+    train_loaders = []
+    val_loaders = []
+    plate_names = []
+
+    if normalize:
+        transform = transforms.Compose([
+            transforms.ToTensor(),
+            transforms.CenterCrop(size=(image_size, image_size)),
+            SelectChannels(channels),
+            transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])
+    else:
+        transform = transforms.Compose([
+            transforms.ToTensor(),
+            transforms.CenterCrop(size=(image_size, image_size)),
+            SelectChannels(channels)])
+
+    if mode == 'train':
+        data_dir = os.path.join(src, 'train')
+        shuffle = True
+        print('Generating Train and validation datasets')
+    elif mode == 'test':
+        data_dir = os.path.join(src, 'test')
+        val_loaders = []
+        validation_split = 0.0
+        shuffle = True
+        print('Generating test dataset')
+    else:
+        print(f'mode:{mode} is not valid, use mode = train or test')
+        return
+
+    if train_mode == 'erm':
+        data = MyDataset(data_dir, classes, transform=transform, shuffle=shuffle, pin_memory=pin_memory)
+        if validation_split > 0:
+            train_size = int((1 - validation_split) * len(data))
+            val_size = len(data) - train_size
+
+            print(f'Train data:{train_size}, Validation data:{val_size}')
+
+            train_dataset, val_dataset = random_split(data, [train_size, val_size])
+
+            train_loaders = DataLoader(train_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers if num_workers is not None else 0, pin_memory=pin_memory)
+            val_loaders = DataLoader(val_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers if num_workers is not None else 0, pin_memory=pin_memory)
+        else:
+            train_loaders = DataLoader(data, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers if num_workers is not None else 0, pin_memory=pin_memory)
+
+    elif train_mode == 'irm':
+        data = MyDataset(data_dir, classes, transform=transform, shuffle=shuffle, pin_memory=pin_memory)
+
+        for filename, label in zip(data.filenames, data.labels):
+            plate = data.get_plate(filename)
+            plate_to_filenames[plate].append(filename)
+            plate_to_labels[plate].append(label)
+
+        for plate, filenames in plate_to_filenames.items():
+            labels = plate_to_labels[plate]
+            plate_data = MyDataset(data_dir, classes, specific_files=filenames, specific_labels=labels, transform=transform, shuffle=False, pin_memory=pin_memory)
+            plate_names.append(plate)
+
+            if validation_split > 0:
+                train_size = int((1 - validation_split) * len(plate_data))
+                val_size = len(plate_data) - train_size
+
+                print(f'Train data:{train_size}, Validation data:{val_size}')
+
+                train_dataset, val_dataset = random_split(plate_data, [train_size, val_size])
+
+                train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers if num_workers is not None else 0, pin_memory=pin_memory)
+                val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers if num_workers is not None else 0, pin_memory=pin_memory)
+
+                train_loaders.append(train_loader)
+                val_loaders.append(val_loader)
+            else:
+                train_loader = DataLoader(plate_data, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers if num_workers is not None else 0, pin_memory=pin_memory)
+                train_loaders.append(train_loader)
+                val_loaders.append(None)
+
+    else:
+        print(f'train_mode:{train_mode} is not valid, use: train_mode = irm or erm')
+        return
+
+    if verbose:
+        if train_mode == 'erm':
+            for idx, (images, labels, filenames) in enumerate(train_loaders):
+                if idx >= max_show:
+                    break
+                images = images.cpu()
+                label_strings = [str(label.item()) for label in labels]
+                _imshow(images, label_strings, nrow=20, fontsize=12)
+        elif train_mode == 'irm':
+            for plate_name, train_loader in zip(plate_names, train_loaders):
+                print(f'Plate: {plate_name} with {len(train_loader.dataset)} images')
+                for idx, (images, labels, filenames) in enumerate(train_loader):
+                    if idx >= max_show:
+                        break
+                    images = images.cpu()
+                    label_strings = [str(label.item()) for label in labels]
+                    _imshow(images, label_strings, nrow=20, fontsize=12)
+
+    return train_loaders, val_loaders, plate_names
+
+def analyze_recruitment(src, metadata_settings, advanced_settings):
+    """
+    Analyze recruitment data by grouping the DataFrame by well coordinates and plotting controls and recruitment data.
+
+    Parameters:
+    src (str): The source of the recruitment data.
+    metadata_settings (dict): The settings for metadata.
+    advanced_settings (dict): The advanced settings for recruitment analysis.
+
+    Returns:
+    None
+    """
+
+    from .io import _read_and_merge_data, _results_to_csv
+    from .plot import plot_merged, _plot_controls, _plot_recruitment
+    from .utils import _object_filter, annotate_conditions, _calculate_recruitment, _group_by_well
+
+    settings_dict = {**metadata_settings, **advanced_settings}
+    settings_df = pd.DataFrame(list(settings_dict.items()), columns=['Key', 'Value'])
+    settings_csv = os.path.join(src,'settings','analyze_settings.csv')
+    os.makedirs(os.path.join(src,'settings'), exist_ok=True)
+    settings_df.to_csv(settings_csv, index=False)
+
+    # metadata settings
+    target = metadata_settings['target']
+    cell_types = metadata_settings['cell_types']
+    cell_plate_metadata = metadata_settings['cell_plate_metadata']
+    pathogen_types = metadata_settings['pathogen_types']
+    pathogen_plate_metadata = metadata_settings['pathogen_plate_metadata']
+    treatments = metadata_settings['treatments']
+    treatment_plate_metadata = metadata_settings['treatment_plate_metadata']
+    metadata_types = metadata_settings['metadata_types']
+    channel_dims = metadata_settings['channel_dims']
+    cell_chann_dim = metadata_settings['cell_chann_dim']
+    cell_mask_dim = metadata_settings['cell_mask_dim']
+    nucleus_chann_dim = metadata_settings['nucleus_chann_dim']
+    nucleus_mask_dim = metadata_settings['nucleus_mask_dim']
+    pathogen_chann_dim = metadata_settings['pathogen_chann_dim']
+    pathogen_mask_dim = metadata_settings['pathogen_mask_dim']
+    channel_of_interest = metadata_settings['channel_of_interest']
+
+    # Advanced settings
+    plot = advanced_settings['plot']
+    plot_nr = advanced_settings['plot_nr']
+    plot_control = advanced_settings['plot_control']
+    figuresize = advanced_settings['figuresize']
+    remove_background = advanced_settings['remove_background']
+    backgrounds = advanced_settings['backgrounds']
+    include_noninfected = advanced_settings['include_noninfected']
+    include_multiinfected = advanced_settings['include_multiinfected']
     include_multinucleated = advanced_settings['include_multinucleated']
     cells_per_well = advanced_settings['cells_per_well']
     pathogen_size_range = advanced_settings['pathogen_size_range']
```
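The new `generate_loaders` adds a `channels` argument and a `SelectChannels` transform that zeroes out unselected channels rather than dropping them, so tensors stay `(3, H, W)` and 3-channel backbones keep working. Note that the remapping block tests for the letters `'r'`, `'g'`, `'b'`, so passing letters is the form this code handles, even though the signature defaults to `[1, 2, 3]`. A usage sketch with a hypothetical dataset layout (`<src>/train/<class>/` and `<src>/test/<class>/`):

```python
from spacr.core import generate_loaders

train_loaders, val_loaders, plate_names = generate_loaders(
    src='/data/plate1/datasets/training',  # hypothetical
    train_mode='erm',          # 'erm' pools everything; 'irm' builds one loader per plate
    mode='train',
    image_size=224,
    batch_size=32,
    classes=['nc', 'pc'],
    validation_split=0.1,
    channels=['r', 'g', 'b'],  # remapped to [1, 2, 3]; omitted letters are zeroed out
    verbose=False)
```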
```diff
@@ -1569,15 +1909,30 @@ def analyze_recruitment(src, metadata_settings, advanced_settings):
     df = df.dropna(subset=['condition'])
     print(f'After dropping non-annotated wells: {len(df)} rows')
     files = df['file_name'].tolist()
+    print(f'found: {len(files)} files')
     files = [item + '.npy' for item in files]
     random.shuffle(files)
-
+
+    _max = 10**100
+
+    if cell_size_range is None and nucleus_size_range is None and pathogen_size_range is None:
+        filter_min_max = None
+    else:
+        if cell_size_range is None:
+            cell_size_range = [0,_max]
+        if nucleus_size_range is None:
+            nucleus_size_range = [0,_max]
+        if pathogen_size_range is None:
+            pathogen_size_range = [0,_max]
+
+        filter_min_max = [[cell_size_range[0],cell_size_range[1]],[nucleus_size_range[0],nucleus_size_range[1]],[pathogen_size_range[0],pathogen_size_range[1]]]
+
     if plot:
         plot_settings = {'include_noninfected':include_noninfected,
                          'include_multiinfected':include_multiinfected,
                          'include_multinucleated':include_multinucleated,
                          'remove_background':remove_background,
-                         'filter_min_max':
+                         'filter_min_max':filter_min_max,
                          'channel_dims':channel_dims,
                          'backgrounds':backgrounds,
                          'cell_mask_dim':mask_dims[0],
```
```diff
@@ -1640,6 +1995,7 @@ def preprocess_generate_masks(src, settings={}):
     from .plot import plot_merged, plot_arrays
     from .utils import _pivot_counts_table
 
+    settings['plot'] = False
     settings['fps'] = 2
     settings['remove_background'] = True
     settings['lower_quantile'] = 0.02
@@ -1655,6 +2011,15 @@ def preprocess_generate_masks(src, settings={}):
     settings['upscale'] = False
     settings['upscale_factor'] = 2.0
 
+    settings['randomize'] = True
+    settings['timelapse'] = False
+    settings['timelapse_displacement'] = None
+    settings['timelapse_memory'] = 3
+    settings['timelapse_frame_limits'] = None
+    settings['timelapse_remove_transient'] = False
+    settings['timelapse_mode'] = 'trackpy'
+    settings['timelapse_objects'] = ['cells']
+
     settings_df = pd.DataFrame(list(settings.items()), columns=['Key', 'Value'])
     settings_csv = os.path.join(src,'settings','preprocess_generate_masks_settings.csv')
     os.makedirs(os.path.join(src,'settings'), exist_ok=True)
@@ -1723,7 +2088,6 @@ def preprocess_generate_masks(src, settings={}):
                          'cell_mask_dim':cell_mask_dim,
                          'nucleus_mask_dim':nucleus_mask_dim,
                          'pathogen_mask_dim':pathogen_mask_dim,
-                         'overlay_chans':[0,2,3],
                          'outline_thickness':3,
                          'outline_color':'gbr',
                          'overlay_chans':overlay_channels,
@@ -1735,6 +2099,10 @@ def preprocess_generate_masks(src, settings={}):
                          'figuresize':20,
                          'cmap':'inferno',
                          'verbose':False}
+
+        if settings['test_mode'] == True:
+            plot_settings['nr'] = len(os.path.join(src,'merged'))
+
         try:
             fig = plot_merged(src=os.path.join(src,'merged'), settings=plot_settings)
         except Exception as e:
```
```diff
@@ -1747,26 +2115,61 @@ def preprocess_generate_masks(src, settings={}):
     print("Successfully completed run")
     return
 
-def identify_masks_finetune(
+def identify_masks_finetune(settings):
 
     from .plot import print_mask_and_flows
     from .utils import get_files_from_dir, resize_images_and_labels
     from .io import _load_normalized_images_and_labels, _load_images_and_labels
 
+    src=settings['src']
+    dst=settings['dst']
+    model_name=settings['model_name']
+    diameter=settings['diameter']
+    batch_size=settings['batch_size']
+    flow_threshold=settings['flow_threshold']
+    cellprob_threshold=settings['cellprob_threshold']
+
+    verbose=settings['verbose']
+    plot=settings['plot']
+    save=settings['save']
+    custom_model=settings['custom_model']
+    overlay=settings['overlay']
+
+    figuresize=25
+    cmap='inferno'
+    channels = [0,0]
+    signal_thresholds = 1000
+    normalize = True
+    percentiles = [2,98]
+    circular = False
+    invert = False
+    resize = False
+    settings['width_height'] = [1000,1000]
+    target_height = settings['width_height'][1]
+    target_width = settings['width_height'][0]
+    rescale = False
+    resample = False
+    grayscale = True
+    test = False
+
+    os.makedirs(dst, exist_ok=True)
+
+    if not custom_model is None:
+        if not os.path.exists(custom_model):
+            print(f'Custom model not found: {custom_model}')
+            return
+
     if not torch.cuda.is_available():
         print(f'Torch CUDA is not available, using CPU')
 
     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
     if custom_model == None:
-
-
-
-
-
-    if custom_model != None:
-        model = cp_models.CellposeModel(gpu=torch.cuda.is_available(), model_type=None, pretrained_model=custom_model, diam_mean=diameter, device=device, net_avg=False) #Assuming diameter is defined elsewhere
-        print(f'loaded custom model:{custom_model}')
+        model = cp_models.CellposeModel(gpu=True, model_type=model_name, device=device)
+        print(f'Loaded model: {model_name}')
+    else:
+        model = cp_models.CellposeModel(gpu=torch.cuda.is_available(), model_type=None, pretrained_model=custom_model, diam_mean=diameter, device=device)
+        print("Pretrained Model Loaded:", model.pretrained_model)
 
     chans = [2, 1] if model_name == 'cyto2' else [0,0] if model_name == 'nucleus' else [1,0] if model_name == 'cyto' else [2, 0]
 
```
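`identify_masks_finetune` likewise moved to a single `settings` dict, validates the `custom_model` path up front, and falls back to a named Cellpose model on the selected device when no custom model is given. A sketch with illustrative values; the keys are the ones read from `settings` in this version:

```python
from spacr.core import identify_masks_finetune

settings = {
    'src': '/data/plate1/imgs',   # hypothetical folder of .tif images
    'dst': '/data/plate1/masks',
    'model_name': 'cyto',         # used when custom_model is None
    'custom_model': None,         # or a path to a finetuned .CP_model file
    'diameter': 30,
    'batch_size': 8,
    'flow_threshold': 0.4,
    'cellprob_threshold': 0.0,
    'verbose': True,
    'plot': False,
    'save': True,
    'overlay': False,
}
identify_masks_finetune(settings)
```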
```diff
@@ -1778,14 +2181,16 @@ def identify_masks_finetune(src, dst, model_name, channels, diameter, batch_size
     if verbose == True:
         print(f'Cellpose settings: Model: {model_name}, channels: {channels}, cellpose_chans: {chans}, diameter:{diameter}, flow_threshold:{flow_threshold}, cellprob_threshold:{cellprob_threshold}')
 
-    all_image_files =
+    all_image_files = [os.path.join(src, f) for f in os.listdir(src) if f.endswith('.tif')]
+
     random.shuffle(all_image_files)
 
     time_ls = []
     for i in range(0, len(all_image_files), batch_size):
         image_files = all_image_files[i:i+batch_size]
+
         if normalize:
-            images, _, image_names, _ = _load_normalized_images_and_labels(image_files=image_files, label_files=None, signal_thresholds=signal_thresholds, channels=channels, percentiles=percentiles, circular=circular, invert=invert, visualize=
+            images, _, image_names, _ = _load_normalized_images_and_labels(image_files=image_files, label_files=None, signal_thresholds=signal_thresholds, channels=channels, percentiles=percentiles, circular=circular, invert=invert, visualize=plot)
             images = [np.squeeze(img) if img.shape[-1] == 1 else img for img in images]
             orig_dims = [(image.shape[0], image.shape[1]) for image in images]
         else:
@@ -1806,8 +2211,7 @@ def identify_masks_finetune(src, dst, model_name, channels, diameter, batch_size
                        cellprob_threshold=cellprob_threshold,
                        rescale=rescale,
                        resample=resample,
-
-                       progress=False)
+                       progress=True)
 
         if len(output) == 4:
             mask, flows, _, _ = output
@@ -1882,7 +2286,6 @@ def identify_masks(src, object_type, model_name, batch_size, channels, diameter,
 
     #Note add logic that handles batches of size 1 as these will break the code batches must all be > 2 images
     gc.collect()
-    #print('========== generating masks ==========')
 
     if not torch.cuda.is_available():
         print(f'Torch CUDA is not available, using CPU')
```
```diff
@@ -2047,9 +2450,9 @@ def all_elements_match(list1, list2):
     # Check if all elements in list1 are in list2
     return all(element in list2 for element in list1)
 
-def
+def generate_cellpose_masks(src, settings, object_type):
 
-    from .utils import _masks_to_masks_stack, _filter_cp_masks, _get_cellpose_batch_size, _get_object_settings, _get_cellpose_channels, mask_object_count
+    from .utils import _masks_to_masks_stack, _filter_cp_masks, _get_cellpose_batch_size, _get_object_settings, _get_cellpose_channels, _choose_model, mask_object_count
     from .io import _create_database, _save_object_counts_to_database, _check_masks, _get_avg_object_size
     from .timelapse import _npz_to_movie, _btrack_track_cells, _trackpy_track_cells
     from .plot import plot_masks
@@ -2079,15 +2482,12 @@ def generate_cellpose_masks_v1(src, settings, object_type):
     cellpose_channels = _get_cellpose_channels(src, settings['nucleus_channel'], settings['pathogen_channel'], settings['cell_channel'])
     if settings['verbose']:
         print(cellpose_channels)
+
     channels = cellpose_channels[object_type]
     cellpose_batch_size = _get_cellpose_batch_size()
-
     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-    model =
-    #dn = denoise.CellposeDenoiseModel(model_type=f"denoise_{model_name}", gpu=True, device=device)
-
+    model = _choose_model(model_name, device, object_type='cell', restore_type=None)
     chans = [2, 1] if model_name == 'cyto2' else [0,0] if model_name == 'nucleus' else [2,0] if model_name == 'cyto' else [2, 0] if model_name == 'cyto3' else [2, 0]
-
     paths = [os.path.join(src, file) for file in os.listdir(src) if file.endswith('.npz')]
 
     count_loc = os.path.dirname(src)+'/measurements/measurements.db'
```
```diff
@@ -2096,7 +2496,6 @@ def generate_cellpose_masks_v1(src, settings, object_type):
 
     average_sizes = []
     time_ls = []
-
     for file_index, path in enumerate(paths):
         name = os.path.basename(path)
         name, ext = os.path.splitext(name)
@@ -2210,23 +2609,45 @@ def generate_cellpose_masks_v1(src, settings, object_type):
                                                   mode=timelapse_mode)
             else:
                 mask_stack = _masks_to_masks_stack(masks)
-
         else:
             _save_object_counts_to_database(masks, object_type, batch_filenames, count_loc, added_string='_before_filtration')
-
-
-
-
-
-
-
-
-
-
-
-
+            if object_settings['merge'] and not settings['filter']:
+                mask_stack = _filter_cp_masks(masks=masks,
+                                              flows=flows,
+                                              filter_size=False,
+                                              filter_intensity=False,
+                                              minimum_size=object_settings['minimum_size'],
+                                              maximum_size=object_settings['maximum_size'],
+                                              remove_border_objects=False,
+                                              merge=object_settings['merge'],
+                                              batch=batch,
+                                              plot=settings['plot'],
+                                              figuresize=figuresize)
+
+            if settings['filter']:
+                mask_stack = _filter_cp_masks(masks=masks,
+                                              flows=flows,
+                                              filter_size=object_settings['filter_size'],
+                                              filter_intensity=object_settings['filter_intensity'],
+                                              minimum_size=object_settings['minimum_size'],
+                                              maximum_size=object_settings['maximum_size'],
+                                              remove_border_objects=object_settings['remove_border_objects'],
+                                              merge=object_settings['merge'],
+                                              batch=batch,
+                                              plot=settings['plot'],
+                                              figuresize=figuresize)
+
+                _save_object_counts_to_database(mask_stack, object_type, batch_filenames, count_loc, added_string='_after_filtration')
+            else:
                 mask_stack = _masks_to_masks_stack(masks)
 
+            if settings['plot']:
+                for idx, (mask, flow, image) in enumerate(zip(masks, flows[0], batch)):
+                    if idx == 0:
+                        num_objects = mask_object_count(mask)
+                        print(f'Number of objects, : {num_objects}')
+                        plot_masks(batch=image, masks=mask, flows=flow, cmap='inferno', figuresize=figuresize, nr=1, file_type='.npz', print_object_number=True)
+
     if not np.any(mask_stack):
         average_obj_size = 0
     else:
```
@@ -2255,207 +2676,661 @@ def generate_cellpose_masks_v1(src, settings, object_type):
|
|
2255
2676
|
torch.cuda.empty_cache()
|
2256
2677
|
return
|
2257
2678
|
|
2258
|
-
def
|
2259
|
-
|
2260
|
-
from .utils import
|
2261
|
-
from .
|
2262
|
-
|
2263
|
-
|
2264
|
-
|
2265
|
-
gc.collect()
|
2266
|
-
if not torch.cuda.is_available():
|
2267
|
-
print(f'Torch CUDA is not available, using CPU')
|
2268
|
-
|
2269
|
-
figuresize=25
|
2270
|
-
timelapse = settings['timelapse']
|
2271
|
-
|
2272
|
-
if timelapse:
|
2273
|
-
timelapse_displacement = settings['timelapse_displacement']
|
2274
|
-
timelapse_frame_limits = settings['timelapse_frame_limits']
|
2275
|
-
timelapse_memory = settings['timelapse_memory']
|
2276
|
-
timelapse_remove_transient = settings['timelapse_remove_transient']
|
2277
|
-
timelapse_mode = settings['timelapse_mode']
|
2278
|
-
timelapse_objects = settings['timelapse_objects']
+def generate_masks_from_imgs(src, model, model_name, batch_size, diameter, cellprob_threshold, grayscale, save, normalize, channels, percentiles, circular, invert, plot, resize, target_height, target_width, verbose):
+
+    from .io import _load_images_and_labels, _load_normalized_images_and_labels
+    from .utils import resize_images_and_labels, resizescikit
+    from .plot import print_mask_and_flows
+
+    dst = os.path.join(src, model_name)
+    os.makedirs(dst, exist_ok=True)
 
-    batch_size = settings['batch_size']
-    cellprob_threshold = settings[f'{object_type}_CP_prob']
     flow_threshold = 30
-
-    object_settings = _get_object_settings(object_type, settings)
-    model_name = object_settings['model_name']
-
-    cellpose_channels = _get_cellpose_channels(src, settings['nucleus_channel'], settings['pathogen_channel'], settings['cell_channel'])
-    if settings['verbose']:
-        print(cellpose_channels)
+    chans = [2, 1] if model_name == 'cyto2' else [0,0] if model_name == 'nucleus' else [1,0] if model_name == 'cyto' else [2, 0]
 
-
-
-    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-    model = _choose_model(model_name, device, object_type='cell', restore_type=None)
-    chans = [2, 1] if model_name == 'cyto2' else [0,0] if model_name == 'nucleus' else [2,0] if model_name == 'cyto' else [2, 0] if model_name == 'cyto3' else [2, 0]
-    paths = [os.path.join(src, file) for file in os.listdir(src) if file.endswith('.npz')]
+    if grayscale:
+        chans=[0, 0]
 
- [3 lines removed; content not preserved in this view]
+    all_image_files = [os.path.join(src, f) for f in os.listdir(src) if f.endswith('.tif')]
+    random.shuffle(all_image_files)
+
+
+    if verbose == True:
+        print(f'Cellpose settings: Model: {model_name}, channels: {channels}, cellpose_chans: {chans}, diameter:{diameter}, flow_threshold:{flow_threshold}, cellprob_threshold:{cellprob_threshold}')
 
-    average_sizes = []
     time_ls = []
-    for …
-
-        name, ext = os.path.splitext(name)
-        output_folder = os.path.join(os.path.dirname(path), object_type+'_mask_stack')
-        os.makedirs(output_folder, exist_ok=True)
-        overall_average_size = 0
-        with np.load(path) as data:
-            stack = data['data']
-            filenames = data['filenames']
-        if settings['timelapse']:
-
-            trackable_objects = ['cell','nucleus','pathogen']
-            if not all_elements_match(settings['timelapse_objects'], trackable_objects):
-                print(f'timelapse_objects {settings["timelapse_objects"]} must be a subset of {trackable_objects}')
-                return
+    for i in range(0, len(all_image_files), batch_size):
+        image_files = all_image_files[i:i+batch_size]
 
- [10 lines removed; content not preserved in this view]
+        if normalize:
+            images, _, image_names, _ = _load_normalized_images_and_labels(image_files=image_files, label_files=None, signal_thresholds=100, channels=channels, percentiles=percentiles, circular=circular, invert=invert, visualize=plot)
+            images = [np.squeeze(img) if img.shape[-1] == 1 else img for img in images]
+            orig_dims = [(image.shape[0], image.shape[1]) for image in images]
+        else:
+            images, _, image_names, _ = _load_images_and_labels(image_files=image_files, label_files=None, circular=circular, invert=invert)
+            images = [np.squeeze(img) if img.shape[-1] == 1 else img for img in images]
+            orig_dims = [(image.shape[0], image.shape[1]) for image in images]
+        if resize:
+            images, _ = resize_images_and_labels(images, None, target_height, target_width, True)
 
-        for …
-            mask_stack = []
+        for file_index, stack in enumerate(images):
             start = time.time()
+            output = model.eval(x=stack,
+                                normalize=False,
+                                channels=chans,
+                                channel_axis=3,
+                                diameter=diameter,
+                                flow_threshold=flow_threshold,
+                                cellprob_threshold=cellprob_threshold,
+                                rescale=False,
+                                resample=False,
+                                progress=True)
 
-            if …
-
+            if len(output) == 4:
+                mask, flows, _, _ = output
+            elif len(output) == 3:
+                mask, flows, _ = output
             else:
-
+                raise ValueError("Unexpected number of return values from model.eval()")
 
-
+            if resize:
+                dims = orig_dims[file_index]
+                mask = resizescikit(mask, dims, order=0, preserve_range=True, anti_aliasing=False).astype(mask.dtype)
 
- [5 lines removed; content not preserved in this view]
-            if …
-
+            stop = time.time()
+            duration = (stop - start)
+            time_ls.append(duration)
+            average_time = np.mean(time_ls) if len(time_ls) > 0 else 0
+            print(f'Processing {file_index+1}/{len(images)} images : Time/image {average_time:.3f} sec', end='\r', flush=True)
+            if plot:
+                if resize:
+                    stack = resizescikit(stack, dims, preserve_range=True, anti_aliasing=False).astype(stack.dtype)
+                print_mask_and_flows(stack, mask, flows, overlay=True)
+            if save:
+                output_filename = os.path.join(dst, image_names[file_index])
+                cv2.imwrite(output_filename, mask)
 
-        if timelapse:
-            stitch_threshold=100.0
-            movie_path = os.path.join(os.path.dirname(src), 'movies')
-            os.makedirs(movie_path, exist_ok=True)
-            save_path = os.path.join(movie_path, f'timelapse_{object_type}_{name}.mp4')
-            _npz_to_movie(batch, batch_filenames, save_path, fps=2)
-        else:
-            stitch_threshold=0.0
 
- [12 lines removed; content not preserved in this view]
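Cellpose's `eval` has returned either three values (masks, flows, styles) or four (masks, flows, styles, diams) depending on the release and wrapper class, which is why the new loop above unpacks by tuple length. A standalone sketch of that version-tolerant unpacking (the tuple layouts in the comments are the commonly documented ones and are stated here as an assumption):

def unpack_eval(output):
    # 4-tuple: masks, flows, styles, diams (e.g. the Cellpose wrapper class)
    if len(output) == 4:
        mask, flows, _, _ = output
    # 3-tuple: masks, flows, styles (e.g. CellposeModel.eval)
    elif len(output) == 3:
        mask, flows, _ = output
    else:
        raise ValueError('Unexpected number of return values from model.eval()')
    return mask, flows

assert unpack_eval(('m', 'f', 's')) == ('m', 'f')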
+def check_cellpose_models(settings):
+
+    src = settings['src']
+    batch_size = settings['batch_size']
+    cellprob_threshold = settings['cellprob_threshold']
+    save = settings['save']
+    normalize = settings['normalize']
+    channels = settings['channels']
+    percentiles = settings['percentiles']
+    circular = settings['circular']
+    invert = settings['invert']
+    plot = settings['plot']
+    diameter = settings['diameter']
+    resize = settings['resize']
+    grayscale = settings['grayscale']
+    verbose = settings['verbose']
+    target_height = settings['width_height'][0]
+    target_width = settings['width_height'][1]
+
+    cellpose_models = ['cyto', 'nuclei', 'cyto2', 'cyto3']
+    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+    for model_name in cellpose_models:
+
+        model = cp_models.CellposeModel(gpu=True, model_type=model_name, device=device)
+        print(f'Using {model_name}')
+        generate_masks_from_imgs(src, model, model_name, batch_size, diameter, cellprob_threshold, grayscale, save, normalize, channels, percentiles, circular, invert, plot, resize, target_height, target_width, verbose)
+
+    return
+
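For orientation, a hypothetical invocation of the new `check_cellpose_models` entry point. The key names are exactly those read from `settings` in the function body above; the path and values are placeholders, not defaults shipped with spacr:

settings = {
    'src': '/data/cellpose_eval',   # folder of .tif images (placeholder path)
    'batch_size': 8,
    'cellprob_threshold': 0.0,
    'save': True,
    'normalize': True,
    'channels': [0, 0],
    'percentiles': None,
    'circular': False,
    'invert': False,
    'plot': False,
    'diameter': 30,
    'resize': False,
    'grayscale': True,
    'verbose': True,
    'width_height': [512, 512],     # consumed as (target_height, target_width)
}
# from spacr.core import check_cellpose_models
# check_cellpose_models(settings)

Each of the four pretrained models then writes its masks into a subfolder of `src` named after the model, which is the layout the comparison functions below expect.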
+def compare_masks_v1(dir1, dir2, dir3, verbose=False):
+
+    from .io import _read_mask
+    from .plot import visualize_masks, plot_comparison_results
+    from .utils import extract_boundaries, boundary_f1_score, compute_segmentation_ap, jaccard_index, dice_coefficient
+
+    filenames = os.listdir(dir1)
+    results = []
+    cond_1 = os.path.basename(dir1)
+    cond_2 = os.path.basename(dir2)
+    cond_3 = os.path.basename(dir3)
+
+    for index, filename in enumerate(filenames):
+        print(f'Processing image:{index+1}', end='\r', flush=True)
+        path1, path2, path3 = os.path.join(dir1, filename), os.path.join(dir2, filename), os.path.join(dir3, filename)
+
+        print(path1)
+        print(path2)
+        print(path3)
+
+        if os.path.exists(path2) and os.path.exists(path3):
 
-
+            mask1, mask2, mask3 = _read_mask(path1), _read_mask(path2), _read_mask(path3)
+            boundary_true1, boundary_true2, boundary_true3 = extract_boundaries(mask1), extract_boundaries(mask2), extract_boundaries(mask3)
+
+
+            true_masks, pred_masks = [mask1], [mask2, mask3] # Assuming mask1 is the ground truth for simplicity
+            true_labels, pred_labels_1, pred_labels_2 = label(mask1), label(mask2), label(mask3)
+            average_precision_0, average_precision_1 = compute_segmentation_ap(mask1, mask2), compute_segmentation_ap(mask1, mask3)
+            ap_scores = [average_precision_0, average_precision_1]
 
- [6 lines removed; content not preserved in this view]
+            if verbose:
+                #unique_values1, unique_values2, unique_values3 = np.unique(mask1), np.unique(mask2), np.unique(mask3)
+                #print(f"Unique values in mask 1: {unique_values1}, mask 2: {unique_values2}, mask 3: {unique_values3}")
+                visualize_masks(boundary_true1, boundary_true2, boundary_true3, title=f"Boundaries - {filename}")
+
+            boundary_f1_12, boundary_f1_13, boundary_f1_23 = boundary_f1_score(mask1, mask2), boundary_f1_score(mask1, mask3), boundary_f1_score(mask2, mask3)
 
- [7 lines removed; content not preserved in this view]
+            if (np.unique(mask1).size == 1 and np.unique(mask1)[0] == 0) and \
+               (np.unique(mask2).size == 1 and np.unique(mask2)[0] == 0) and \
+               (np.unique(mask3).size == 1 and np.unique(mask3)[0] == 0):
+                continue
+
+            if verbose:
+                #unique_values4, unique_values5, unique_values6 = np.unique(boundary_f1_12), np.unique(boundary_f1_13), np.unique(boundary_f1_23)
+                #print(f"Unique values in boundary mask 1: {unique_values4}, mask 2: {unique_values5}, mask 3: {unique_values6}")
+                visualize_masks(mask1, mask2, mask3, title=filename)
+
+            jaccard12 = jaccard_index(mask1, mask2)
+            dice12 = dice_coefficient(mask1, mask2)
+
+            jaccard13 = jaccard_index(mask1, mask3)
+            dice13 = dice_coefficient(mask1, mask3)
+
+            jaccard23 = jaccard_index(mask2, mask3)
+            dice23 = dice_coefficient(mask2, mask3)
 
- [3 lines removed; content not preserved in this view]
+            results.append({
+                f'filename': filename,
+                f'jaccard_{cond_1}_{cond_2}': jaccard12,
+                f'dice_{cond_1}_{cond_2}': dice12,
+                f'jaccard_{cond_1}_{cond_3}': jaccard13,
+                f'dice_{cond_1}_{cond_3}': dice13,
+                f'jaccard_{cond_2}_{cond_3}': jaccard23,
+                f'dice_{cond_2}_{cond_3}': dice23,
+                f'boundary_f1_{cond_1}_{cond_2}': boundary_f1_12,
+                f'boundary_f1_{cond_1}_{cond_3}': boundary_f1_13,
+                f'boundary_f1_{cond_2}_{cond_3}': boundary_f1_23,
+                f'average_precision_{cond_1}_{cond_2}': ap_scores[0],
+                f'average_precision_{cond_1}_{cond_3}': ap_scores[1]
+            })
+        else:
+            print(f'Cannot find {path1} or {path2} or {path3}')
+    fig = plot_comparison_results(results)
+    return results, fig
 
- [4 lines removed; content not preserved in this view]
-                                           plot=settings['plot'],
-                                           save=settings['save'],
-                                           masks_3D=masks,
-                                           mode=timelapse_mode,
-                                           timelapse_remove_transient=timelapse_remove_transient,
-                                           radius=radius,
-                                           workers=workers)
-            if timelapse_mode == 'trackpy':
-                mask_stack = _trackpy_track_cells(src=src,
-                                                  name=name,
-                                                  batch_filenames=batch_filenames,
-                                                  object_type=object_type,
-                                                  masks=masks,
-                                                  timelapse_displacement=timelapse_displacement,
-                                                  timelapse_memory=timelapse_memory,
-                                                  timelapse_remove_transient=timelapse_remove_transient,
-                                                  plot=settings['plot'],
-                                                  save=settings['save'],
-                                                  mode=timelapse_mode)
-            else:
-                mask_stack = _masks_to_masks_stack(masks)
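For reference, assuming spacr's `jaccard_index` and `dice_coefficient` implement the standard set-overlap definitions, the two scores recorded per mask pair above are monotonically related. For binary masks $A$ and $B$:

$$J(A,B) = \frac{|A \cap B|}{|A \cup B|}, \qquad D(A,B) = \frac{2\,|A \cap B|}{|A| + |B|}, \qquad D = \frac{2J}{1 + J}.$$

So Dice never ranks a pair of masks differently from Jaccard; it only compresses the scale toward 1.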
+def compare_cellpose_masks_v1(src, verbose=False):
+    from .io import _read_mask
+    from .plot import visualize_masks, plot_comparison_results, visualize_cellpose_masks
+    from .utils import extract_boundaries, boundary_f1_score, compute_segmentation_ap, jaccard_index
 
- [3 lines removed; content not preserved in this view]
-                                           flows=flows,
-                                           filter_size=object_settings['filter_size'],
-                                           filter_intensity=object_settings['filter_intensity'],
-                                           minimum_size=object_settings['minimum_size'],
-                                           maximum_size=object_settings['maximum_size'],
-                                           remove_border_objects=object_settings['remove_border_objects'],
-                                           merge=False,
-                                           batch=batch,
-                                           plot=settings['plot'],
-                                           figuresize=figuresize)
-
-            _save_object_counts_to_database(mask_stack, object_type, batch_filenames, count_loc, added_string='_after_filtration')
+    import os
+    import numpy as np
+    from skimage.measure import label
 
-
-        else:
-            average_obj_size = _get_avg_object_size(mask_stack)
+    # Collect all subdirectories in src
+    dirs = [os.path.join(src, d) for d in os.listdir(src) if os.path.isdir(os.path.join(src, d))]
 
-
-    overall_average_size = np.mean(average_sizes) if len(average_sizes) > 0 else 0
+    dirs.sort()  # Optional: sort directories if needed
 
- [19 lines removed; content not preserved in this view]
+    # Get common files in all directories
+    common_files = set(os.listdir(dirs[0]))
+    for d in dirs[1:]:
+        common_files.intersection_update(os.listdir(d))
+    common_files = list(common_files)
+
+    results = []
+    conditions = [os.path.basename(d) for d in dirs]
+
+    for index, filename in enumerate(common_files):
+        print(f'Processing image {index+1}/{len(common_files)}', end='\r', flush=True)
+        paths = [os.path.join(d, filename) for d in dirs]
+
+        # Check if file exists in all directories
+        if not all(os.path.exists(path) for path in paths):
+            print(f'Skipping {filename} as it is not present in all directories.')
+            continue
+
+        masks = [_read_mask(path) for path in paths]
+        boundaries = [extract_boundaries(mask) for mask in masks]
+
+        if verbose:
+            visualize_cellpose_masks(masks, titles=conditions, comparison_title=f"Masks Comparison for {filename}")
+
+        # Initialize data structure for results
+        file_results = {'filename': filename}
+
+        # Compare each mask with each other
+        for i in range(len(masks)):
+            for j in range(i + 1, len(masks)):
+                condition_i = conditions[i]
+                condition_j = conditions[j]
+                mask_i = masks[i]
+                mask_j = masks[j]
+
+                # Compute metrics
+                boundary_f1 = boundary_f1_score(mask_i, mask_j)
+                jaccard = jaccard_index(mask_i, mask_j)
+                average_precision = compute_segmentation_ap(mask_i, mask_j)
+
+                # Store results
+                file_results[f'jaccard_{condition_i}_{condition_j}'] = jaccard
+                file_results[f'boundary_f1_{condition_i}_{condition_j}'] = boundary_f1
+                file_results[f'average_precision_{condition_i}_{condition_j}'] = average_precision
+
+        results.append(file_results)
+
+    fig = plot_comparison_results(results)
+    return results, fig
+
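The nested `i < j` loop above scores every unordered pair of conditions exactly once, i.e. n(n-1)/2 comparisons for n mask directories. A quick check of that count (the condition names here are hypothetical):

from itertools import combinations

conditions = ['cyto', 'cyto2', 'cyto3']
pairs = list(combinations(conditions, 2))
assert len(pairs) == len(conditions) * (len(conditions) - 1) // 2  # 3 pairs for 3 conditions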
+def compare_mask(args):
+    src, filename, dirs, conditions = args
+    paths = [os.path.join(d, filename) for d in dirs]
+
+    if not all(os.path.exists(path) for path in paths):
+        return None
+
+    from .io import _read_mask  # Import here to avoid issues in multiprocessing
+    from .utils import extract_boundaries, boundary_f1_score, compute_segmentation_ap, jaccard_index
+    from .plot import plot_comparison_results
+
+    masks = [_read_mask(path) for path in paths]
+    file_results = {'filename': filename}
+
+    for i in range(len(masks)):
+        for j in range(i + 1, len(masks)):
+            mask_i, mask_j = masks[i], masks[j]
+            f1_score = boundary_f1_score(mask_i, mask_j)
+            jac_index = jaccard_index(mask_i, mask_j)
+            ap_score = compute_segmentation_ap(mask_i, mask_j)
+
+            file_results.update({
+                f'jaccard_{conditions[i]}_{conditions[j]}': jac_index,
+                f'boundary_f1_{conditions[i]}_{conditions[j]}': f1_score,
+                f'ap_{conditions[i]}_{conditions[j]}': ap_score
+            })
+
+    return file_results
+
+def compare_cellpose_masks(src, verbose=False, processes=None):
+    from .plot import visualize_cellpose_masks, plot_comparison_results
+    from .io import _read_mask
+    dirs = [os.path.join(src, d) for d in os.listdir(src) if os.path.isdir(os.path.join(src, d))]
+    dirs.sort()  # Optional: sort directories if needed
+    conditions = [os.path.basename(d) for d in dirs]
+
+    # Get common files in all directories
+    common_files = set(os.listdir(dirs[0]))
+    for d in dirs[1:]:
+        common_files.intersection_update(os.listdir(d))
+    common_files = list(common_files)
+
+    # Create a pool of workers
+    with Pool(processes=processes) as pool:
+        args = [(src, filename, dirs, conditions) for filename in common_files]
+        results = pool.map(compare_mask, args)
+
+    # Filter out None results (from skipped files)
+    results = [res for res in results if res is not None]
+
+    if verbose:
+        for result in results:
+            filename = result['filename']
+            masks = [_read_mask(os.path.join(d, filename)) for d in dirs]
+            visualize_cellpose_masks(masks, titles=conditions, comparison_title=f"Masks Comparison for {filename}")
+
+    fig = plot_comparison_results(results)
+    return results, fig
+
+
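`compare_mask` is defined at module level, with its spacr imports deferred into the function body, so that `multiprocessing.Pool` can pickle it and ship it to worker processes; `Pool` itself is presumably imported elsewhere in core.py (e.g. `from multiprocessing import Pool`), since no such import appears in this section of the diff. A hypothetical call, assuming one subdirectory of mask images per condition:

# Expected layout (placeholder paths):
#   /data/masks_eval/cyto/img_001.tif
#   /data/masks_eval/cyto2/img_001.tif
#   /data/masks_eval/cyto3/img_001.tif
# from spacr.core import compare_cellpose_masks
# results, fig = compare_cellpose_masks('/data/masks_eval', verbose=False, processes=4)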
+def _calculate_similarity(df, features, col_to_compare, val1, val2):
+    """
+    Calculate similarity scores of each well to the positive and negative controls using various metrics.
+
+    Args:
+        df (pandas.DataFrame): DataFrame containing the data.
+        features (list): List of feature columns to use for similarity calculation.
+        col_to_compare (str): Column name to use for comparing groups.
+        val1, val2 (str): Values in col_to_compare to create subsets for comparison.
+
+    Returns:
+        pandas.DataFrame: DataFrame with similarity scores.
+    """
+    # Separate positive and negative control wells
+    pos_control = df[df[col_to_compare] == val1][features].mean()
+    neg_control = df[df[col_to_compare] == val2][features].mean()
+
+    # Standardize features for Mahalanobis distance
+    scaler = StandardScaler()
+    scaled_features = scaler.fit_transform(df[features])
+
+    # Regularize the covariance matrix to avoid singularity
+    cov_matrix = np.cov(scaled_features, rowvar=False)
+    inv_cov_matrix = None
+    try:
+        inv_cov_matrix = np.linalg.inv(cov_matrix)
+    except np.linalg.LinAlgError:
+        # Add a small value to the diagonal elements for regularization
+        epsilon = 1e-5
+        inv_cov_matrix = np.linalg.inv(cov_matrix + np.eye(cov_matrix.shape[0]) * epsilon)
+
+    # Calculate similarity scores
+    df['similarity_to_pos_euclidean'] = df[features].apply(lambda row: euclidean(row, pos_control), axis=1)
+    df['similarity_to_neg_euclidean'] = df[features].apply(lambda row: euclidean(row, neg_control), axis=1)
+    df['similarity_to_pos_cosine'] = df[features].apply(lambda row: cosine(row, pos_control), axis=1)
+    df['similarity_to_neg_cosine'] = df[features].apply(lambda row: cosine(row, neg_control), axis=1)
+    df['similarity_to_pos_mahalanobis'] = df[features].apply(lambda row: mahalanobis(row, pos_control, inv_cov_matrix), axis=1)
+    df['similarity_to_neg_mahalanobis'] = df[features].apply(lambda row: mahalanobis(row, neg_control, inv_cov_matrix), axis=1)
+    df['similarity_to_pos_manhattan'] = df[features].apply(lambda row: cityblock(row, pos_control), axis=1)
+    df['similarity_to_neg_manhattan'] = df[features].apply(lambda row: cityblock(row, neg_control), axis=1)
+    df['similarity_to_pos_minkowski'] = df[features].apply(lambda row: minkowski(row, pos_control, p=3), axis=1)
+    df['similarity_to_neg_minkowski'] = df[features].apply(lambda row: minkowski(row, neg_control, p=3), axis=1)
+    df['similarity_to_pos_chebyshev'] = df[features].apply(lambda row: chebyshev(row, pos_control), axis=1)
+    df['similarity_to_neg_chebyshev'] = df[features].apply(lambda row: chebyshev(row, neg_control), axis=1)
+    df['similarity_to_pos_hamming'] = df[features].apply(lambda row: hamming(row, pos_control), axis=1)
+    df['similarity_to_neg_hamming'] = df[features].apply(lambda row: hamming(row, neg_control), axis=1)
+    df['similarity_to_pos_jaccard'] = df[features].apply(lambda row: jaccard(row, pos_control), axis=1)
+    df['similarity_to_neg_jaccard'] = df[features].apply(lambda row: jaccard(row, neg_control), axis=1)
+    df['similarity_to_pos_braycurtis'] = df[features].apply(lambda row: braycurtis(row, pos_control), axis=1)
+    df['similarity_to_neg_braycurtis'] = df[features].apply(lambda row: braycurtis(row, neg_control), axis=1)
+
+    return df
+
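For reference, the Mahalanobis branch above computes, for a row vector $x$ and a control centroid $\mu$,

$$d_M(x,\mu) = \sqrt{(x-\mu)^{\top}\,\Sigma^{-1}\,(x-\mu)},$$

where $\Sigma$ is the feature covariance matrix; when $\Sigma$ is singular, the fallback inverts the ridge-regularized matrix $\Sigma + \varepsilon I$ with $\varepsilon = 10^{-5}$. Note that, as written, $\Sigma$ is estimated from standardized features while the rows and centroids passed to `mahalanobis` (and to the other scipy distances) are unstandardized values.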
+def _permutation_importance(df, feature_string='channel_3', col_to_compare='col', pos='c1', neg='c2', exclude=None, n_repeats=10, clean=True, nr_to_plot=30, n_estimators=100, test_size=0.2, random_state=42, model_type='xgboost', n_jobs=-1):
+
+    """
+    Calculates permutation importance for the numerical features in the dataframe,
+    comparing groups based on the specified column values, and then uses the fitted
+    model to predict the class of every other row in the dataframe.
+
+    Args:
+        df (pandas.DataFrame): The DataFrame containing the data.
+        feature_string (str): String to filter features that contain this substring.
+        col_to_compare (str): Column name to use for comparing groups.
+        pos, neg (str): Values in col_to_compare to create subsets for comparison.
+        exclude (list or str, optional): Columns to exclude from features.
+        n_repeats (int): Number of repeats for permutation importance.
+        clean (bool): Whether to remove columns with a single value.
+        nr_to_plot (int): Number of top features to plot based on permutation importance.
+        n_estimators (int): Number of trees in the random forest, gradient boosting, or XGBoost model.
+        test_size (float): Proportion of the dataset to include in the test split.
+        random_state (int): Random seed for reproducibility.
+        model_type (str): Type of model to use ('random_forest', 'logistic_regression', 'gradient_boosting', 'xgboost').
+        n_jobs (int): Number of jobs to run in parallel for applicable models.
+
+    Returns:
+        pandas.DataFrame: The original dataframe with added prediction and data usage columns.
+        pandas.DataFrame: DataFrame containing the importances and standard deviations.
+    """
+
+    if 'cells_per_well' in df.columns:
+        df = df.drop(columns=['cells_per_well'])
+
+    # Subset the dataframe based on specified column values
+    df1 = df[df[col_to_compare] == pos].copy()
+    df2 = df[df[col_to_compare] == neg].copy()
+
+    # Create target variable
+    df1['target'] = 0
+    df2['target'] = 1
+
+    # Combine the subsets for analysis
+    combined_df = pd.concat([df1, df2])
+
+    # Automatically select numerical features
+    features = combined_df.select_dtypes(include=[np.number]).columns.tolist()
+    features.remove('target')
+
+    if clean:
+        combined_df = combined_df.loc[:, combined_df.nunique() > 1]
+        features = [feature for feature in features if feature in combined_df.columns]
+
+    if feature_string is not None:
+        feature_list = ['channel_0', 'channel_1', 'channel_2', 'channel_3']
+
+        # Remove feature_string from the list if it exists
+        if feature_string in feature_list:
+            feature_list.remove(feature_string)
+
+        features = [feature for feature in features if feature_string in feature]
+
+        # Iterate through the list and remove columns from df
+        for feature_ in feature_list:
+            features = [feature for feature in features if feature_ not in feature]
+            print(f'After removing {feature_} features: {len(features)}')
+
+    if exclude:
+        if isinstance(exclude, list):
+            features = [feature for feature in features if feature not in exclude]
+        else:
+            features.remove(exclude)
+
+    X = combined_df[features]
+    y = combined_df['target']
+
+    # Split the data into training and testing sets
+    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=test_size, random_state=random_state)
+
+    # Label the data in the original dataframe
+    combined_df['data_usage'] = 'train'
+    combined_df.loc[X_test.index, 'data_usage'] = 'test'
+
+    # Initialize the model based on model_type
+    if model_type == 'random_forest':
+        model = RandomForestClassifier(n_estimators=n_estimators, random_state=random_state, n_jobs=n_jobs)
+    elif model_type == 'logistic_regression':
+        model = LogisticRegression(max_iter=1000, random_state=random_state, n_jobs=n_jobs)
+    elif model_type == 'gradient_boosting':
+        model = HistGradientBoostingClassifier(max_iter=n_estimators, random_state=random_state)  # Supports n_jobs internally
+    elif model_type == 'xgboost':
+        model = XGBClassifier(n_estimators=n_estimators, random_state=random_state, nthread=n_jobs, use_label_encoder=False, eval_metric='logloss')
+    else:
+        raise ValueError(f"Unsupported model_type: {model_type}")
+
+    model.fit(X_train, y_train)
+
+    perm_importance = permutation_importance(model, X_train, y_train, n_repeats=n_repeats, random_state=random_state, n_jobs=n_jobs)
+
+    # Create a DataFrame for permutation importances
+    permutation_df = pd.DataFrame({
+        'feature': [features[i] for i in perm_importance.importances_mean.argsort()],
+        'importance_mean': perm_importance.importances_mean[perm_importance.importances_mean.argsort()],
+        'importance_std': perm_importance.importances_std[perm_importance.importances_mean.argsort()]
+    }).tail(nr_to_plot)
+
+    # Plotting
+    fig, ax = plt.subplots()
+    ax.barh(permutation_df['feature'], permutation_df['importance_mean'], xerr=permutation_df['importance_std'], color="teal", align="center", alpha=0.6)
+    ax.set_xlabel('Permutation Importance')
+    plt.tight_layout()
+    plt.show()
+
+    # Feature importance for models that support it
+    if model_type in ['random_forest', 'xgboost', 'gradient_boosting']:
+        feature_importances = model.feature_importances_
+        feature_importance_df = pd.DataFrame({
+            'feature': features,
+            'importance': feature_importances
+        }).sort_values(by='importance', ascending=False).head(nr_to_plot)
+
+        # Plotting feature importance
+        fig, ax = plt.subplots()
+        ax.barh(feature_importance_df['feature'], feature_importance_df['importance'], color="blue", align="center", alpha=0.6)
+        ax.set_xlabel('Feature Importance')
+        plt.tight_layout()
+        plt.show()
+    else:
+        feature_importance_df = pd.DataFrame()
+
+    # Predicting the target variable for the test set
+    predictions_test = model.predict(X_test)
+    combined_df.loc[X_test.index, 'predictions'] = predictions_test
+
+    # Predicting the target variable for the training set
+    predictions_train = model.predict(X_train)
+    combined_df.loc[X_train.index, 'predictions'] = predictions_train
+
+    # Predicting the target variable for all other rows in the dataframe
+    X_all = df[features]
+    all_predictions = model.predict(X_all)
+    df['predictions'] = all_predictions
+
+    # Combine data usage labels back to the original dataframe
+    combined_data_usage = pd.concat([combined_df[['data_usage']], df[['predictions']]], axis=0)
+    df = df.join(combined_data_usage, how='left', rsuffix='_model')
+
+    # Calculating and printing the accuracy metrics
+    accuracy = accuracy_score(y_test, predictions_test)
+    precision = precision_score(y_test, predictions_test)
+    recall = recall_score(y_test, predictions_test)
+    f1 = f1_score(y_test, predictions_test)
+    print(f"Accuracy: {accuracy}")
+    print(f"Precision: {precision}")
+    print(f"Recall: {recall}")
+    print(f"F1 Score: {f1}")
+
+    # Printing class-specific accuracy metrics
+    print("\nClassification Report:")
+    print(classification_report(y_test, predictions_test))
+
+    df = _calculate_similarity(df, features, col_to_compare, pos, neg)
+
+    return [df, permutation_df, feature_importance_df, model, X_train, X_test, y_train, y_test]
+
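A toy illustration of the inputs `_permutation_importance` expects (column names hypothetical, shaped like spacr measurement columns): a grouping column holding the `pos`/`neg` values and numeric features whose names contain `feature_string`. The call itself is left commented out because it fits a model and opens plots:

import numpy as np
import pandas as pd

rng = np.random.default_rng(42)
df = pd.DataFrame({
    'col': ['c1'] * 50 + ['c2'] * 50,
    'pathogen_channel_3_mean_intensity': np.r_[rng.normal(1, 0.1, 50), rng.normal(2, 0.1, 50)],
    'cytoplasm_channel_3_mean_intensity': rng.normal(1, 0.1, 100),
})
# output = _permutation_importance(df, feature_string='channel_3', col_to_compare='col',
#                                  pos='c1', neg='c2', model_type='xgboost', n_jobs=2)
# df_out, permutation_df, feature_importance_df, model, X_train, X_test, y_train, y_test = output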
+def _shap_analysis(model, X_train, X_test):
+
+    """
+    Performs SHAP analysis on the given model and data.
+
+    Args:
+        model: The trained model.
+        X_train (pandas.DataFrame): Training feature set.
+        X_test (pandas.DataFrame): Testing feature set.
+    """
+
+    explainer = shap.Explainer(model, X_train)
+    shap_values = explainer(X_test)
+
+    # Summary plot
+    shap.summary_plot(shap_values, X_test)
+
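`shap.Explainer` dispatches on the model type (for the tree ensembles used above it typically resolves to a tree explainer) and returns a callable whose output, an Explanation object, `shap.summary_plot` renders as a beeswarm of per-feature impact on the prediction. That dispatch is the documented shap default, not something spacr configures:

# Equivalent inline usage (assumes a fitted tree model and feature DataFrames):
# shap.summary_plot(shap.Explainer(model, X_train)(X_test), X_test)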
+def plate_heatmap(src, model_type='xgboost', variable='predictions', grouping='mean', min_max='allq', cmap='viridis', channel_of_interest=3, min_count=25, n_estimators=100, col_to_compare='col', pos='c1', neg='c2', exclude=None, n_repeats=10, clean=True, nr_to_plot=20, verbose=False, n_jobs=-1):
+    from .io import _read_and_merge_data
+    from .plot import _plot_plates
+
+    db_loc = [src+'/measurements/measurements.db']
+    tables = ['cell', 'nucleus', 'pathogen','cytoplasm']
+    include_multinucleated, include_multiinfected, include_noninfected = True, 2.0, True
+
+    df, _ = _read_and_merge_data(db_loc,
+                                 tables,
+                                 verbose=verbose,
+                                 include_multinucleated=include_multinucleated,
+                                 include_multiinfected=include_multiinfected,
+                                 include_noninfected=include_noninfected)
+
+    if not channel_of_interest is None:
+        df['recruitment'] = df[f'pathogen_channel_{channel_of_interest}_mean_intensity']/df[f'cytoplasm_channel_{channel_of_interest}_mean_intensity']
+        feature_string = f'channel_{channel_of_interest}'
+    else:
+        feature_string = None
+
+    output = _permutation_importance(df, feature_string, col_to_compare, pos, neg, exclude, n_repeats, clean, nr_to_plot, n_estimators=n_estimators, random_state=42, model_type=model_type, n_jobs=n_jobs)
+
+    _shap_analysis(output[3], output[4], output[5])
+
+    features = output[0].select_dtypes(include=[np.number]).columns.tolist()
+
+    if not variable in features:
+        raise ValueError(f"Variable {variable} not found in the dataframe. Please choose one of the following: {features}")
+
+    plate_heatmap = _plot_plates(output[0], variable, grouping, min_max, cmap, min_count)
+    return [output, plate_heatmap]
+
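A hypothetical end-to-end call; `src` must be a spacr analysis folder containing `measurements/measurements.db`. The parameter names come from the signature above; the path and control values are placeholders:

# from spacr.core import plate_heatmap
# output, fig = plate_heatmap('/data/screen_plate1',
#                             model_type='xgboost',
#                             variable='predictions',
#                             channel_of_interest=3,
#                             col_to_compare='col', pos='c1', neg='c2',
#                             min_count=25, verbose=False)

The returned `output` list is the full `_permutation_importance` result (dataframe, importance tables, model, and train/test splits); `fig` is the plate-layout heatmap.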
+def join_measurments_and_annotation(src, tables = ['cell', 'nucleus', 'pathogen','cytoplasm']):
+
+    from .io import _read_and_merge_data, _read_db
+
+    db_loc = [src+'/measurements/measurements.db']
+    loc = src+'/measurements/measurements.db'
+    df, _ = _read_and_merge_data(db_loc,
+                                 tables,
+                                 verbose=True,
+                                 include_multinucleated=True,
+                                 include_multiinfected=True,
+                                 include_noninfected=True)
+
+    paths_df = _read_db(loc, tables=['png_list'])
+
+    merged_df = pd.merge(df, paths_df[0], on='prcfo', how='left')
+
+    return merged_df
+
+def jitterplot_by_annotation(src, x_column, y_column, plot_title='Jitter Plot', output_path=None, filter_column=None, filter_values=None):
+    """
+    Creates a jitter plot of one column grouped by another, using measurements joined with image annotations.
+
+    Args:
+        src (str): Path to the source data.
+        x_column (str): Name of the column to be used for the x-axis.
+        y_column (str): Name of the column to be used for the y-axis.
+        plot_title (str): Title of the plot. Default is 'Jitter Plot'.
+        output_path (str): Path to save the plot image. If None, the plot will be displayed. Default is None.
+        filter_column (str or list): Column(s) used to filter rows before plotting. Default is None.
+        filter_values (list): Value(s) to retain for each filter column. Default is None.
+
+    Returns:
+        pd.DataFrame: The filtered and balanced DataFrame.
+    """
+    # Join measurements and annotations into a DataFrame
+    df = join_measurments_and_annotation(src, tables=['cell', 'nucleus', 'pathogen', 'cytoplasm'])
+
+    # Print column names for debugging
+    print(f"Generated dataframe with: {df.shape[1]} columns and {df.shape[0]} rows")
+    #print("Columns in DataFrame:", df.columns.tolist())
+
+    # Replace NaN values with a specific label in x_column
+    df[x_column] = df[x_column].fillna('NaN')
+
+    # Filter the DataFrame if filter_column and filter_values are provided
+    if not filter_column is None:
+        if isinstance(filter_column, str):
+            df = df[df[filter_column].isin(filter_values)]
+        if isinstance(filter_column, list):
+            for i,val in enumerate(filter_column):
+                print(f'Rows before filtering on {val}: {len(df)}')
+                df = df[df[val].isin(filter_values[i])]
+
+    # Use the correct column names based on your DataFrame
+    required_columns = ['plate_x', 'row_x', 'col_x']
+    if not all(column in df.columns for column in required_columns):
+        raise KeyError(f"DataFrame does not contain the necessary columns: {required_columns}")
+
+    # Filter to retain rows with non-NaN values in x_column and with matching plate, row, col values
+    non_nan_df = df[df[x_column] != 'NaN']
+    retained_rows = df[df[['plate_x', 'row_x', 'col_x']].apply(tuple, axis=1).isin(non_nan_df[['plate_x', 'row_x', 'col_x']].apply(tuple, axis=1))]
+
+    # Determine the minimum count of examples across all groups in x_column
+    min_count = retained_rows[x_column].value_counts().min()
+    print(f'Found {min_count} annotated images')
+
+    # Randomly sample min_count examples from each group in x_column
+    balanced_df = retained_rows.groupby(x_column).apply(lambda x: x.sample(min_count, random_state=42)).reset_index(drop=True)
+
+    # Create the jitter plot
+    plt.figure(figsize=(10, 6))
+    jitter_plot = sns.stripplot(data=balanced_df, x=x_column, y=y_column, hue=x_column, jitter=True, palette='viridis', dodge=False)
+    plt.title(plot_title)
+    plt.xlabel(x_column)
+    plt.ylabel(y_column)
+
+    # Customize the x-axis labels
+    plt.xticks(rotation=45, ha='right')
+
+    # Adjust the position of the x-axis labels to be centered below the data
+    ax = plt.gca()
+    ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha='center')
+
+    # Save the plot to a file or display it
+    if output_path:
+        plt.savefig(output_path, bbox_inches='tight')
+        print(f"Jitter plot saved to {output_path}")
+    else:
+        plt.show()
+
+    return balanced_df
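Finally, a hypothetical use of `jitterplot_by_annotation`; the annotation column names depend on what was stored in the png_list table, so `x_column` and `y_column` below are placeholders:

# from spacr.core import jitterplot_by_annotation
# balanced = jitterplot_by_annotation('/data/screen_plate1',
#                                     x_column='test',
#                                     y_column='cell_area',
#                                     plot_title='Cell area by annotation',
#                                     output_path='jitter.png')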