spacr 0.0.18__py3-none-any.whl → 0.0.21__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- spacr/alpha.py +291 -14
- spacr/annotate_app.py +2 -2
- spacr/core.py +1377 -296
- spacr/foldseek.py +793 -0
- spacr/get_alfafold_structures.py +72 -0
- spacr/graph_learning.py +259 -65
- spacr/graph_learning_lap.py +73 -71
- spacr/gui_classify_app.py +5 -21
- spacr/gui_mask_app.py +36 -30
- spacr/gui_measure_app.py +10 -24
- spacr/gui_utils.py +82 -54
- spacr/io.py +505 -205
- spacr/measure.py +160 -80
- spacr/old_code.py +155 -1
- spacr/plot.py +243 -99
- spacr/sim.py +666 -119
- spacr/timelapse.py +343 -52
- spacr/train.py +18 -10
- spacr/utils.py +252 -151
- {spacr-0.0.18.dist-info → spacr-0.0.21.dist-info}/METADATA +32 -27
- spacr-0.0.21.dist-info/RECORD +33 -0
- {spacr-0.0.18.dist-info → spacr-0.0.21.dist-info}/WHEEL +1 -1
- spacr/gui_temp.py +0 -212
- spacr/test_annotate_app.py +0 -58
- spacr/test_plot.py +0 -43
- spacr/test_train.py +0 -39
- spacr/test_utils.py +0 -33
- spacr-0.0.18.dist-info/RECORD +0 -36
- {spacr-0.0.18.dist-info → spacr-0.0.21.dist-info}/LICENSE +0 -0
- {spacr-0.0.18.dist-info → spacr-0.0.21.dist-info}/entry_points.txt +0 -0
- {spacr-0.0.18.dist-info → spacr-0.0.21.dist-info}/top_level.txt +0 -0
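
Note: the per-file diff below can be reproduced locally from the two wheel archives. A minimal sketch using only the Python standard library (the wheel file names are assumed to follow PyPI's naming and are not part of this page):

    # Hypothetical helper, not part of spacr: regenerate the spacr/core.py diff
    # by reading the module out of each wheel (a wheel is a zip archive).
    import difflib, zipfile

    def diff_wheel_member(old_wheel, new_wheel, member='spacr/core.py'):
        with zipfile.ZipFile(old_wheel) as old, zipfile.ZipFile(new_wheel) as new:
            old_lines = old.read(member).decode('utf-8').splitlines(keepends=True)
            new_lines = new.read(member).decode('utf-8').splitlines(keepends=True)
        return ''.join(difflib.unified_diff(old_lines, new_lines,
                                            fromfile=f'0.0.18/{member}',
                                            tofile=f'0.0.21/{member}'))

    # print(diff_wheel_member('spacr-0.0.18-py3-none-any.whl', 'spacr-0.0.21-py3-none-any.whl'))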
spacr/core.py
CHANGED
@@ -1,12 +1,13 @@
-import os, sqlite3, gc, torch, time, random, shutil, cv2, tarfile, datetime
+import os, sqlite3, gc, torch, time, random, shutil, cv2, tarfile, datetime, shap, string

 # image and array processing
 import numpy as np
 import pandas as pd

+from cellpose import train
 import cellpose
 from cellpose import models as cp_models
-from cellpose import
+from cellpose.models import CellposeModel

 import statsmodels.formula.api as smf
 import statsmodels.api as sm
@@ -28,18 +29,18 @@ matplotlib.use('Agg')

 import torchvision.transforms as transforms
 from sklearn.model_selection import train_test_split
-from sklearn.ensemble import IsolationForest
-
+from sklearn.ensemble import IsolationForest, RandomForestClassifier, HistGradientBoostingClassifier
 from .logger import log_function_call

-
-
-
-
-
-
+from sklearn.linear_model import LogisticRegression
+from sklearn.inspection import permutation_importance
+from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, classification_report
+from xgboost import XGBClassifier
+
+from scipy.spatial.distance import cosine, euclidean, mahalanobis, cityblock, minkowski, chebyshev, hamming, jaccard, braycurtis
+from sklearn.preprocessing import StandardScaler
+import shap

-@log_function_call
 def analyze_plaques(folder):
     summary_data = []
     details_data = []
@@ -76,75 +77,6 @@ def analyze_plaques(folder):

     print(f"Analysis completed and saved to database '{db_name}'.")

-@log_function_call
-def compare_masks(dir1, dir2, dir3, verbose=False):
-
-    from .io import _read_mask
-    from .plot import visualize_masks, plot_comparison_results
-    from .utils import extract_boundaries, boundary_f1_score, compute_segmentation_ap, jaccard_index, dice_coefficient
-
-    filenames = os.listdir(dir1)
-    results = []
-    cond_1 = os.path.basename(dir1)
-    cond_2 = os.path.basename(dir2)
-    cond_3 = os.path.basename(dir3)
-    for index, filename in enumerate(filenames):
-        print(f'Processing image:{index+1}', end='\r', flush=True)
-        path1, path2, path3 = os.path.join(dir1, filename), os.path.join(dir2, filename), os.path.join(dir3, filename)
-        if os.path.exists(path2) and os.path.exists(path3):
-
-            mask1, mask2, mask3 = _read_mask(path1), _read_mask(path2), _read_mask(path3)
-            boundary_true1, boundary_true2, boundary_true3 = extract_boundaries(mask1), extract_boundaries(mask2), extract_boundaries(mask3)
-
-
-            true_masks, pred_masks = [mask1], [mask2, mask3] # Assuming mask1 is the ground truth for simplicity
-            true_labels, pred_labels_1, pred_labels_2 = label(mask1), label(mask2), label(mask3)
-            average_precision_0, average_precision_1 = compute_segmentation_ap(mask1, mask2), compute_segmentation_ap(mask1, mask3)
-            ap_scores = [average_precision_0, average_precision_1]
-
-            if verbose:
-                unique_values1, unique_values2, unique_values3 = np.unique(mask1), np.unique(mask2), np.unique(mask3)
-                print(f"Unique values in mask 1: {unique_values1}, mask 2: {unique_values2}, mask 3: {unique_values3}")
-                visualize_masks(boundary_true1, boundary_true2, boundary_true3, title=f"Boundaries - {filename}")
-
-            boundary_f1_12, boundary_f1_13, boundary_f1_23 = boundary_f1_score(mask1, mask2), boundary_f1_score(mask1, mask3), boundary_f1_score(mask2, mask3)
-
-            if (np.unique(mask1).size == 1 and np.unique(mask1)[0] == 0) and \
-               (np.unique(mask2).size == 1 and np.unique(mask2)[0] == 0) and \
-               (np.unique(mask3).size == 1 and np.unique(mask3)[0] == 0):
-                continue
-
-            if verbose:
-                unique_values4, unique_values5, unique_values6 = np.unique(boundary_f1_12), np.unique(boundary_f1_13), np.unique(boundary_f1_23)
-                print(f"Unique values in boundary mask 1: {unique_values4}, mask 2: {unique_values5}, mask 3: {unique_values6}")
-                visualize_masks(mask1, mask2, mask3, title=filename)
-
-            jaccard12 = jaccard_index(mask1, mask2)
-            dice12 = dice_coefficient(mask1, mask2)
-            jaccard13 = jaccard_index(mask1, mask3)
-            dice13 = dice_coefficient(mask1, mask3)
-            jaccard23 = jaccard_index(mask2, mask3)
-            dice23 = dice_coefficient(mask2, mask3)
-
-            results.append({
-                f'filename': filename,
-                f'jaccard_{cond_1}_{cond_2}': jaccard12,
-                f'dice_{cond_1}_{cond_2}': dice12,
-                f'jaccard_{cond_1}_{cond_3}': jaccard13,
-                f'dice_{cond_1}_{cond_3}': dice13,
-                f'jaccard_{cond_2}_{cond_3}': jaccard23,
-                f'dice_{cond_2}_{cond_3}': dice23,
-                f'boundary_f1_{cond_1}_{cond_2}': boundary_f1_12,
-                f'boundary_f1_{cond_1}_{cond_3}': boundary_f1_13,
-                f'boundary_f1_{cond_2}_{cond_3}': boundary_f1_23,
-                f'average_precision_{cond_1}_{cond_2}': ap_scores[0],
-                f'average_precision_{cond_1}_{cond_3}': ap_scores[1]
-            })
-        else:
-            print(f'Cannot find {path1} or {path2} or {path3}')
-    fig = plot_comparison_results(results)
-    return results, fig
-
 def generate_cp_masks(settings):

     src = settings['src']
@@ -178,18 +110,155 @@ def generate_cp_masks(settings):

     dst = os.path.join(src,'masks')
     os.makedirs(dst, exist_ok=True)
-
+
     identify_masks(src, dst, model_name, channels, diameter, batch_size, flow_threshold, cellprob_threshold, figuresize, cmap, verbose, plot, save, custom_model, signal_thresholds, normalize, resize, target_height, target_width, rescale, resample, net_avg, invert, circular, percentiles, overlay, grayscale)

-@log_function_call
 def train_cellpose(settings):

     from .io import _load_normalized_images_and_labels, _load_images_and_labels
     from .utils import resize_images_and_labels

     img_src = settings['img_src']
-    mask_src=
-
+    mask_src = os.path.join(img_src, 'mask')
+
+    model_name = settings['model_name']
+    model_type = settings['model_type']
+    learning_rate = settings['learning_rate']
+    weight_decay = settings['weight_decay']
+    batch_size = settings['batch_size']
+    n_epochs = settings['n_epochs']
+    from_scratch = settings['from_scratch']
+    diameter = settings['diameter']
+    verbose = settings['verbose']
+
+    channels = [0,0]
+    signal_thresholds = 1000
+    normalize = True
+    percentiles = [2,98]
+    circular = False
+    invert = False
+    resize = False
+    settings['width_height'] = [1000,1000]
+    target_height = settings['width_height'][1]
+    target_width = settings['width_height'][0]
+    rescale = False
+    grayscale = True
+    test = False
+
+    if test:
+        test_img_src = os.path.join(os.path.dirname(img_src), 'test')
+        test_mask_src = os.path.join(test_img_src, 'mask')
+
+    test_images, test_masks, test_image_names, test_mask_names = None,None,None,None,
+    print(settings)
+
+    if from_scratch:
+        model_name=f'scratch_{model_name}_{model_type}_e{n_epochs}_X{target_width}_Y{target_height}.CP_model'
+    else:
+        if resize:
+            model_name=f'{model_name}_{model_type}_e{n_epochs}_X{target_width}_Y{target_height}.CP_model'
+        else:
+            model_name=f'{model_name}_{model_type}_e{n_epochs}.CP_model'
+
+    model_save_path = os.path.join(mask_src, 'models', 'cellpose_model')
+    print(model_save_path)
+    os.makedirs(model_save_path, exist_ok=True)
+
+    settings_df = pd.DataFrame(list(settings.items()), columns=['Key', 'Value'])
+    settings_csv = os.path.join(model_save_path,f'{model_name}_settings.csv')
+    settings_df.to_csv(settings_csv, index=False)
+
+    if from_scratch:
+        model = cp_models.CellposeModel(gpu=True, model_type=model_type, diam_mean=diameter, pretrained_model=None)
+    else:
+        model = cp_models.CellposeModel(gpu=True, model_type=model_type)
+
+    if normalize:
+
+        image_files = [os.path.join(img_src, f) for f in os.listdir(img_src) if f.endswith('.tif')]
+        label_files = [os.path.join(mask_src, f) for f in os.listdir(mask_src) if f.endswith('.tif')]
+        images, masks, image_names, mask_names = _load_normalized_images_and_labels(image_files, label_files, signal_thresholds, channels=channels, percentiles=percentiles, circular=circular, invert=invert, visualize=verbose)
+        images = [np.squeeze(img) if img.shape[-1] == 1 else img for img in images]
+
+        if test:
+            test_image_files = [os.path.join(test_img_src, f) for f in os.listdir(test_img_src) if f.endswith('.tif')]
+            test_label_files = [os.path.join(test_mask_src, f) for f in os.listdir(test_mask_src) if f.endswith('.tif')]
+            test_images, test_masks, test_image_names, test_mask_names = _load_normalized_images_and_labels(image_files=test_image_files, label_files=test_label_files, signal_thresholds=signal_thresholds, channels=channels, percentiles=percentiles, circular=circular, invert=invert, visualize=verbose)
+            test_images = [np.squeeze(img) if img.shape[-1] == 1 else img for img in test_images]
+
+
+    else:
+        images, masks, image_names, mask_names = _load_images_and_labels(img_src, mask_src, circular, invert)
+        images = [np.squeeze(img) if img.shape[-1] == 1 else img for img in images]
+
+        if test:
+            test_images, test_masks, test_image_names, test_mask_names = _load_images_and_labels(img_src=test_img_src, mask_src=test_mask_src, circular=circular, invert=circular)
+            test_images = [np.squeeze(img) if img.shape[-1] == 1 else img for img in test_images]
+
+    if resize:
+        images, masks = resize_images_and_labels(images, masks, target_height, target_width, show_example=True)
+
+    if model_type == 'cyto':
+        cp_channels = [0,1]
+    if model_type == 'cyto2':
+        cp_channels = [0,2]
+    if model_type == 'nucleus':
+        cp_channels = [0,0]
+    if grayscale:
+        cp_channels = [0,0]
+        images = [np.squeeze(img) if img.ndim == 3 and 1 in img.shape else img for img in images]
+
+    masks = [np.squeeze(mask) if mask.ndim == 3 and 1 in mask.shape else mask for mask in masks]
+
+    print(f'image shape: {images[0].shape}, image type: images[0].shape mask shape: {masks[0].shape}, image type: masks[0].shape')
+    save_every = int(n_epochs/10)
+    if save_every < 10:
+        save_every = n_epochs
+
+    train.train_seg(model.net,
+                    train_data=images,
+                    train_labels=masks,
+                    train_files=image_names,
+                    train_labels_files=mask_names,
+                    train_probs=None,
+                    test_data=test_images,
+                    test_labels=test_masks,
+                    test_files=test_image_names,
+                    test_labels_files=test_mask_names,
+                    test_probs=None,
+                    load_files=True,
+                    batch_size=batch_size,
+                    learning_rate=learning_rate,
+                    n_epochs=n_epochs,
+                    weight_decay=weight_decay,
+                    momentum=0.9,
+                    SGD=False,
+                    channels=cp_channels,
+                    channel_axis=None,
+                    #rgb=False,
+                    normalize=False,
+                    compute_flows=False,
+                    save_path=model_save_path,
+                    save_every=save_every,
+                    nimg_per_epoch=None,
+                    nimg_test_per_epoch=None,
+                    rescale=rescale,
+                    #scale_range=None,
+                    #bsize=224,
+                    min_train_masks=1,
+                    model_name=model_name)
+
+    return print(f"Model saved at: {model_save_path}/{model_name}")
+
+def train_cellpose_v1(settings):
+
+    from .io import _load_normalized_images_and_labels, _load_images_and_labels
+    from .utils import resize_images_and_labels
+
+    img_src = settings['img_src']
+
+    mask_src = os.path.join(img_src, 'mask')
+
     model_name = settings['model_name']
     model_type = settings['model_type']
     learning_rate = settings['learning_rate']
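
Note on the hunk above: train_cellpose now derives mask_src from img_src, hard-codes the normalization options, and drives cellpose's train.train_seg directly. A minimal call might look like the sketch below; the keys are the ones the function reads in this version, but the values are illustrative, not package defaults:

    # Illustrative settings only; keys mirror what train_cellpose reads above.
    settings = {
        'img_src': '/data/plate1/train',   # expects a 'mask' subfolder with matching .tif labels
        'model_name': 'my_model',
        'model_type': 'cyto',
        'learning_rate': 0.2,
        'weight_decay': 1e-5,
        'batch_size': 8,
        'n_epochs': 100,
        'from_scratch': False,
        'diameter': 30,
        'verbose': False,
    }

    # from spacr.core import train_cellpose
    # train_cellpose(settings)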
@@ -197,7 +266,9 @@ def train_cellpose(settings):
     batch_size = settings['batch_size']
     n_epochs = settings['n_epochs']
     verbose = settings['verbose']
-
+
+    signal_thresholds = 100 #settings['signal_thresholds']
+
     channels = settings['channels']
     from_scratch = settings['from_scratch']
     diameter = settings['diameter']
@@ -210,7 +281,17 @@ def train_cellpose(settings):
     invert = settings['invert']
     percentiles = settings['percentiles']
     grayscale = settings['grayscale']
-
+
+    if model_type == 'cyto':
+        settings['diameter'] = 30
+        diameter = settings['diameter']
+        print(f'Cyto model must have diamiter 30. Diameter set the 30')
+
+    if model_type == 'nuclei':
+        settings['diameter'] = 17
+        diameter = settings['diameter']
+        print(f'Nuclei model must have diamiter 17. Diameter set the 17')
+
     print(settings)

     if from_scratch:
@@ -219,24 +300,24 @@ def train_cellpose(settings):
     model_name=f'{model_name}_{model_type}_e{n_epochs}_X{target_width}_Y{target_height}.CP_model'

     model_save_path = os.path.join(mask_src, 'models', 'cellpose_model')
-
+    print(model_save_path)
+    os.makedirs(model_save_path, exist_ok=True)

     settings_df = pd.DataFrame(list(settings.items()), columns=['Key', 'Value'])
     settings_csv = os.path.join(model_save_path,f'{model_name}_settings.csv')
     settings_df.to_csv(settings_csv, index=False)

-    if
-    if not from_scratch:
-        model = cp_models.CellposeModel(gpu=True, model_type=model_type)
-    else:
-        model = cp_models.CellposeModel(gpu=True, model_type=model_type, net_avg=False, diam_mean=diameter, pretrained_model=None)
-    if model_type !='cyto':
+    if not from_scratch:
         model = cp_models.CellposeModel(gpu=True, model_type=model_type)
-
-
-
-
-
+
+    else:
+        model = cp_models.CellposeModel(gpu=True, model_type=model_type, pretrained_model=None)
+
+    if normalize:
+        image_files = [os.path.join(img_src, f) for f in os.listdir(img_src) if f.endswith('.tif')]
+        label_files = [os.path.join(mask_src, f) for f in os.listdir(mask_src) if f.endswith('.tif')]
+
+        images, masks, image_names, mask_names = _load_normalized_images_and_labels(image_files, label_files, signal_thresholds, channels=channels, percentiles=percentiles, circular=circular, invert=invert, visualize=verbose)
         images = [np.squeeze(img) if img.shape[-1] == 1 else img for img in images]
     else:
         images, masks, image_names, mask_names = _load_images_and_labels(img_src, mask_src, circular, invert)
|
|
259
340
|
|
260
341
|
print(f'image shape: {images[0].shape}, image type: images[0].shape mask shape: {masks[0].shape}, image type: masks[0].shape')
|
261
342
|
save_every = int(n_epochs/10)
|
262
|
-
|
263
|
-
|
343
|
+
if save_every < 10:
|
344
|
+
save_every = n_epochs
|
345
|
+
|
346
|
+
|
347
|
+
#print('cellpose image input dtype', images[0].dtype)
|
348
|
+
#print('cellpose mask input dtype', masks[0].dtype)
|
349
|
+
|
264
350
|
# Train the model
|
265
|
-
model.train(train_data=images, #(list of arrays (2D or 3D)) – images for training
|
266
|
-
|
267
|
-
|
268
|
-
|
269
|
-
|
270
|
-
|
271
|
-
|
272
|
-
|
273
|
-
|
274
|
-
|
275
|
-
|
276
|
-
|
277
|
-
|
278
|
-
|
279
|
-
|
280
|
-
|
351
|
+
#model.train(train_data=images, #(list of arrays (2D or 3D)) – images for training
|
352
|
+
|
353
|
+
#model.train(train_data=images, #(list of arrays (2D or 3D)) – images for training
|
354
|
+
# train_labels=masks, #(list of arrays (2D or 3D)) – labels for train_data, where 0=no masks; 1,2,…=mask labels can include flows as additional images
|
355
|
+
# train_files=image_names, #(list of strings) – file names for images in train_data (to save flows for future runs)
|
356
|
+
# channels=cp_channels, #(list of ints (default, None)) – channels to use for training
|
357
|
+
# normalize=False, #(bool (default, True)) – normalize data so 0.0=1st percentile and 1.0=99th percentile of image intensities in each channel
|
358
|
+
# save_path=model_save_path, #(string (default, None)) – where to save trained model, if None it is not saved
|
359
|
+
# save_every=save_every, #(int (default, 100)) – save network every [save_every] epochs
|
360
|
+
# learning_rate=learning_rate, #(float or list/np.ndarray (default, 0.2)) – learning rate for training, if list, must be same length as n_epochs
|
361
|
+
# n_epochs=n_epochs, #(int (default, 500)) – how many times to go through whole training set during training
|
362
|
+
# weight_decay=weight_decay, #(float (default, 0.00001)) –
|
363
|
+
# SGD=True, #(bool (default, True)) – use SGD as optimization instead of RAdam
|
364
|
+
# batch_size=batch_size, #(int (optional, default 8)) – number of 224x224 patches to run simultaneously on the GPU (can make smaller or bigger depending on GPU memory usage)
|
365
|
+
# nimg_per_epoch=None, #(int (optional, default None)) – minimum number of images to train on per epoch, with a small training set (< 8 images) it may help to set to 8
|
366
|
+
# rescale=rescale, #(bool (default, True)) – whether or not to rescale images to diam_mean during training, if True it assumes you will fit a size model after training or resize your images accordingly, if False it will try to train the model to be scale-invariant (works worse)
|
367
|
+
# min_train_masks=1, #(int (default, 5)) – minimum number of masks an image must have to use in training set
|
368
|
+
# model_name=model_name) #(str (default, None)) – name of network, otherwise saved with name as params + training start time
|
369
|
+
|
370
|
+
|
371
|
+
train.train_seg(model.net,
|
372
|
+
train_data=images,
|
373
|
+
train_labels=masks,
|
374
|
+
train_files=image_names,
|
375
|
+
train_labels_files=None,
|
376
|
+
train_probs=None,
|
377
|
+
test_data=None,
|
378
|
+
test_labels=None,
|
379
|
+
test_files=None,
|
380
|
+
test_labels_files=None,
|
381
|
+
test_probs=None,
|
382
|
+
load_files=True,
|
383
|
+
batch_size=batch_size,
|
384
|
+
learning_rate=learning_rate,
|
385
|
+
n_epochs=n_epochs,
|
386
|
+
weight_decay=weight_decay,
|
387
|
+
momentum=0.9,
|
388
|
+
SGD=False,
|
389
|
+
channels=cp_channels,
|
390
|
+
channel_axis=None,
|
391
|
+
#rgb=False,
|
392
|
+
normalize=False,
|
393
|
+
compute_flows=False,
|
394
|
+
save_path=model_save_path,
|
395
|
+
save_every=save_every,
|
396
|
+
nimg_per_epoch=None,
|
397
|
+
nimg_test_per_epoch=None,
|
398
|
+
rescale=rescale,
|
399
|
+
#scale_range=None,
|
400
|
+
#bsize=224,
|
401
|
+
min_train_masks=1,
|
402
|
+
model_name=model_name)
|
403
|
+
|
404
|
+
#model_save_path = train.train_seg(model.net,
|
405
|
+
# train_data=images,
|
406
|
+
# train_files=image_names,
|
407
|
+
# train_labels=masks,
|
408
|
+
# channels=cp_channels,
|
409
|
+
# normalize=False,
|
410
|
+
# save_every=save_every,
|
411
|
+
# learning_rate=learning_rate,
|
412
|
+
# n_epochs=n_epochs,
|
413
|
+
# #test_data=test_images,
|
414
|
+
# #test_labels=test_labels,
|
415
|
+
# weight_decay=weight_decay,
|
416
|
+
# SGD=True,
|
417
|
+
# batch_size=batch_size,
|
418
|
+
# nimg_per_epoch=None,
|
419
|
+
# rescale=rescale,
|
420
|
+
# min_train_masks=1,
|
421
|
+
# model_name=model_name)
|
422
|
+
|
281
423
|
|
282
424
|
return print(f"Model saved at: {model_save_path}/{model_name}")
|
283
425
|
|
284
|
-
@log_function_call
|
285
426
|
def analyze_data_reg(sequencing_loc, dv_loc, agg_type = 'mean', dv_col='pred', transform=None, min_cell_count=50, min_reads=100, min_wells=2, max_wells=1000, min_frequency=0.0,remove_outlier_genes=False, refine_model=False,by_plate=False, regression_type='mlr', alpha_value=0.01, fishers=False, fisher_threshold=0.9):
|
286
427
|
|
287
428
|
from .plot import _reg_v_plot
|
@@ -430,7 +571,6 @@ def analyze_data_reg(sequencing_loc, dv_loc, agg_type = 'mean', dv_col='pred', t
|
|
430
571
|
|
431
572
|
return result
|
432
573
|
|
433
|
-
@log_function_call
|
434
574
|
def analyze_data_reg(sequencing_loc, dv_loc, agg_type = 'mean', min_cell_count=50, min_reads=100, min_wells=2, max_wells=1000, remove_outlier_genes=False, refine_model=False, by_plate=False, threshold=0.5, fishers=False):
|
435
575
|
|
436
576
|
from .plot import _reg_v_plot
|
@@ -609,7 +749,6 @@ def analyze_data_reg(sequencing_loc, dv_loc, agg_type = 'mean', min_cell_count=5
|
|
609
749
|
|
610
750
|
return max_effects, max_effects_pvalues, model, df
|
611
751
|
|
612
|
-
@log_function_call
|
613
752
|
def regression_analasys(dv_df,sequencing_loc, min_reads=75, min_wells=2, max_wells=0, model_type = 'mlr', min_cells=100, transform='logit', min_frequency=0.05, gene_column='gene', effect_size_threshold=0.25, fishers=True, clean_regression=False, VIF_threshold=10):
|
614
753
|
|
615
754
|
from .utils import generate_fraction_map, fishers_odds, model_metrics, check_multicollinearity
|
@@ -777,7 +916,6 @@ def regression_analasys(dv_df,sequencing_loc, min_reads=75, min_wells=2, max_wel
|
|
777
916
|
|
778
917
|
return
|
779
918
|
|
780
|
-
@log_function_call
|
781
919
|
def merge_pred_mes(src,
|
782
920
|
pred_loc,
|
783
921
|
target='protein of interest',
|
@@ -941,30 +1079,38 @@ def annotate_results(pred_loc):
|
|
941
1079
|
display(df)
|
942
1080
|
return df
|
943
1081
|
|
944
|
-
def generate_dataset(src,
|
1082
|
+
def generate_dataset(src, file_metadata=None, experiment='TSG101_screen', sample=None):
|
945
1083
|
|
946
|
-
from .utils import
|
947
|
-
|
948
|
-
db_path = os.path.join(src, 'measurements','measurements.db')
|
1084
|
+
from .utils import initiate_counter, add_images_to_tar
|
1085
|
+
|
1086
|
+
db_path = os.path.join(src, 'measurements', 'measurements.db')
|
949
1087
|
dst = os.path.join(src, 'datasets')
|
950
|
-
|
951
|
-
global total_images
|
952
1088
|
all_paths = []
|
953
|
-
|
1089
|
+
|
954
1090
|
# Connect to the database and retrieve the image paths
|
955
1091
|
print(f'Reading DataBase: {db_path}')
|
956
|
-
|
957
|
-
|
958
|
-
|
959
|
-
|
960
|
-
|
961
|
-
|
962
|
-
|
963
|
-
|
964
|
-
|
965
|
-
|
966
|
-
|
967
|
-
|
1092
|
+
try:
|
1093
|
+
with sqlite3.connect(db_path) as conn:
|
1094
|
+
cursor = conn.cursor()
|
1095
|
+
if file_metadata:
|
1096
|
+
if isinstance(file_metadata, str):
|
1097
|
+
cursor.execute("SELECT png_path FROM png_list WHERE png_path LIKE ?", (f"%{file_metadata}%",))
|
1098
|
+
else:
|
1099
|
+
cursor.execute("SELECT png_path FROM png_list")
|
1100
|
+
|
1101
|
+
while True:
|
1102
|
+
rows = cursor.fetchmany(1000)
|
1103
|
+
if not rows:
|
1104
|
+
break
|
1105
|
+
all_paths.extend([row[0] for row in rows])
|
1106
|
+
|
1107
|
+
except sqlite3.Error as e:
|
1108
|
+
print(f"Database error: {e}")
|
1109
|
+
return
|
1110
|
+
except Exception as e:
|
1111
|
+
print(f"Error: {e}")
|
1112
|
+
return
|
1113
|
+
|
968
1114
|
if isinstance(sample, int):
|
969
1115
|
selected_paths = random.sample(all_paths, sample)
|
970
1116
|
print(f'Random selection of {len(selected_paths)} paths')
|
@@ -972,23 +1118,18 @@ def generate_dataset(src, file_type=None, experiment='TSG101_screen', sample=Non
|
|
972
1118
|
selected_paths = all_paths
|
973
1119
|
random.shuffle(selected_paths)
|
974
1120
|
print(f'All paths: {len(selected_paths)} paths')
|
975
|
-
|
1121
|
+
|
976
1122
|
total_images = len(selected_paths)
|
977
|
-
print(f'
|
978
|
-
|
1123
|
+
print(f'Found {total_images} images')
|
1124
|
+
|
979
1125
|
# Create a temp folder in dst
|
980
1126
|
temp_dir = os.path.join(dst, "temp_tars")
|
981
1127
|
os.makedirs(temp_dir, exist_ok=True)
|
982
1128
|
|
983
1129
|
# Chunking the data
|
984
|
-
|
985
|
-
|
986
|
-
|
987
|
-
remainder = len(selected_paths) % num_procs
|
988
|
-
else:
|
989
|
-
num_procs = 2
|
990
|
-
chunk_size = len(selected_paths) // 2
|
991
|
-
remainder = 0
|
1130
|
+
num_procs = max(2, cpu_count() - 2)
|
1131
|
+
chunk_size = len(selected_paths) // num_procs
|
1132
|
+
remainder = len(selected_paths) % num_procs
|
992
1133
|
|
993
1134
|
paths_chunks = []
|
994
1135
|
start = 0
|
@@ -998,45 +1139,43 @@ def generate_dataset(src, file_type=None, experiment='TSG101_screen', sample=Non
|
|
998
1139
|
start = end
|
999
1140
|
|
1000
1141
|
temp_tar_files = [os.path.join(temp_dir, f'temp_{i}.tar') for i in range(num_procs)]
|
1001
|
-
|
1002
|
-
# Initialize the shared objects
|
1003
|
-
counter_ = Value('i', 0)
|
1004
|
-
lock_ = Lock()
|
1005
1142
|
|
1006
|
-
ctx = multiprocessing.get_context('spawn')
|
1007
|
-
|
1008
1143
|
print(f'Generating temporary tar files in {dst}')
|
1009
|
-
|
1144
|
+
|
1145
|
+
# Initialize shared counter and lock
|
1146
|
+
counter = Value('i', 0)
|
1147
|
+
lock = Lock()
|
1148
|
+
|
1149
|
+
with Pool(processes=num_procs, initializer=initiate_counter, initargs=(counter, lock)) as pool:
|
1150
|
+
pool.starmap(add_images_to_tar, [(paths_chunks[i], temp_tar_files[i], total_images) for i in range(num_procs)])
|
1151
|
+
|
1010
1152
|
# Combine the temporary tar files into a final tar
|
1011
1153
|
date_name = datetime.date.today().strftime('%y%m%d')
|
1012
|
-
|
1154
|
+
if not file_metadata is None:
|
1155
|
+
tar_name = f'{date_name}_{experiment}_{file_metadata}.tar'
|
1156
|
+
else:
|
1157
|
+
tar_name = f'{date_name}_{experiment}.tar'
|
1158
|
+
tar_name = os.path.join(dst, tar_name)
|
1013
1159
|
if os.path.exists(tar_name):
|
1014
1160
|
number = random.randint(1, 100)
|
1015
|
-
tar_name_2 = f'{date_name}_{experiment}_{
|
1016
|
-
print(f'Warning: {os.path.basename(tar_name)} exists saving as {os.path.basename(tar_name_2)} ')
|
1017
|
-
tar_name = tar_name_2
|
1018
|
-
|
1019
|
-
# Add the counter and lock to the arguments for pool.map
|
1161
|
+
tar_name_2 = f'{date_name}_{experiment}_{file_metadata}_{number}.tar'
|
1162
|
+
print(f'Warning: {os.path.basename(tar_name)} exists, saving as {os.path.basename(tar_name_2)} ')
|
1163
|
+
tar_name = os.path.join(dst, tar_name_2)
|
1164
|
+
|
1020
1165
|
print(f'Merging temporary files')
|
1021
|
-
#with Pool(processes=num_procs, initializer=init_globals, initargs=(counter_, lock_)) as pool:
|
1022
|
-
# results = pool.map(add_images_to_tar, zip(paths_chunks, temp_tar_files))
|
1023
1166
|
|
1024
|
-
with
|
1025
|
-
|
1026
|
-
|
1027
|
-
|
1028
|
-
|
1029
|
-
|
1030
|
-
|
1031
|
-
t.extract(member, path=dst)
|
1032
|
-
final_tar.add(os.path.join(dst, member.name), arcname=member.name)
|
1033
|
-
os.remove(os.path.join(dst, member.name))
|
1034
|
-
os.remove(tar_path)
|
1167
|
+
with tarfile.open(tar_name, 'w') as final_tar:
|
1168
|
+
for temp_tar_path in temp_tar_files:
|
1169
|
+
with tarfile.open(temp_tar_path, 'r') as temp_tar:
|
1170
|
+
for member in temp_tar.getmembers():
|
1171
|
+
file_obj = temp_tar.extractfile(member)
|
1172
|
+
final_tar.addfile(member, file_obj)
|
1173
|
+
os.remove(temp_tar_path)
|
1035
1174
|
|
1036
1175
|
# Delete the temp folder
|
1037
1176
|
shutil.rmtree(temp_dir)
|
1038
|
-
print(f"\nSaved {total_images} images to {
|
1039
|
-
|
1177
|
+
print(f"\nSaved {total_images} images to {tar_name}")
|
1178
|
+
|
1040
1179
|
def apply_model_to_tar(tar_path, model_path, file_type='cell_png', image_size=224, batch_size=64, normalize=True, preload='images', num_workers=10, verbose=False):
|
1041
1180
|
|
1042
1181
|
from .io import TarImageDataset, DataLoader
|
@@ -1272,7 +1411,14 @@ def generate_training_dataset(src, mode='annotation', annotation_column='test',
|
|
1272
1411
|
|
1273
1412
|
db_path = os.path.join(src, 'measurements','measurements.db')
|
1274
1413
|
dst = os.path.join(src, 'datasets', 'training')
|
1275
|
-
|
1414
|
+
|
1415
|
+
if os.path.exists(dst):
|
1416
|
+
for i in range(1, 1000):
|
1417
|
+
dst = os.path.join(src, 'datasets', f'training_{i}')
|
1418
|
+
if not os.path.exists(dst):
|
1419
|
+
print(f'Creating new directory for training: {dst}')
|
1420
|
+
break
|
1421
|
+
|
1276
1422
|
if mode == 'annotation':
|
1277
1423
|
class_paths_ls_2 = []
|
1278
1424
|
class_paths_ls = training_dataset_from_annotation(db_path, dst, annotation_column, annotated_classes=annotated_classes)
|
@@ -1283,6 +1429,7 @@ def generate_training_dataset(src, mode='annotation', annotation_column='test',
|
|
1283
1429
|
|
1284
1430
|
elif mode == 'metadata':
|
1285
1431
|
class_paths_ls = []
|
1432
|
+
class_len_ls = []
|
1286
1433
|
[df] = _read_db(db_loc=db_path, tables=['png_list'])
|
1287
1434
|
df['metadata_based_class'] = pd.NA
|
1288
1435
|
for i, class_ in enumerate(classes):
|
@@ -1290,7 +1437,18 @@ def generate_training_dataset(src, mode='annotation', annotation_column='test',
|
|
1290
1437
|
df.loc[df[metadata_type_by].isin(ls), 'metadata_based_class'] = class_
|
1291
1438
|
|
1292
1439
|
for class_ in classes:
|
1440
|
+
if size == None:
|
1441
|
+
c_s = []
|
1442
|
+
for c in classes:
|
1443
|
+
c_s_t_df = df[df['metadata_based_class'] == c]
|
1444
|
+
c_s.append(len(c_s_t_df))
|
1445
|
+
print(f'Found {len(c_s_t_df)} images for class {c}')
|
1446
|
+
size = min(c_s)
|
1447
|
+
print(f'Using the smallest class size: {size}')
|
1448
|
+
|
1293
1449
|
class_temp_df = df[df['metadata_based_class'] == class_]
|
1450
|
+
class_len_ls.append(len(class_temp_df))
|
1451
|
+
print(f'Found {len(class_temp_df)} images for class {class_}')
|
1294
1452
|
class_paths_temp = random.sample(class_temp_df['png_path'].tolist(), size)
|
1295
1453
|
class_paths_ls.append(class_paths_temp)
|
1296
1454
|
|
@@ -1347,7 +1505,7 @@ def generate_training_dataset(src, mode='annotation', annotation_column='test',
|
|
1347
1505
|
|
1348
1506
|
return
|
1349
1507
|
|
1350
|
-
def
|
1508
|
+
def generate_loaders_v1(src, train_mode='erm', mode='train', image_size=224, batch_size=32, classes=['nc','pc'], num_workers=None, validation_split=0.0, max_show=2, pin_memory=False, normalize=False, verbose=False):
|
1351
1509
|
"""
|
1352
1510
|
Generate data loaders for training and validation/test datasets.
|
1353
1511
|
|
@@ -1478,55 +1636,222 @@ def generate_loaders(src, train_mode='erm', mode='train', image_size=224, batch_
|
|
1478
1636
|
|
1479
1637
|
return train_loaders, val_loaders, plate_names
|
1480
1638
|
|
1481
|
-
def
|
1639
|
+
def generate_loaders(src, train_mode='erm', mode='train', image_size=224, batch_size=32, classes=['nc','pc'], num_workers=None, validation_split=0.0, max_show=2, pin_memory=False, normalize=False, channels=[1, 2, 3], verbose=False):
|
1640
|
+
|
1482
1641
|
"""
|
1483
|
-
|
1642
|
+
Generate data loaders for training and validation/test datasets.
|
1484
1643
|
|
1485
1644
|
Parameters:
|
1486
|
-
src (str): The source
|
1487
|
-
|
1488
|
-
|
1645
|
+
- src (str): The source directory containing the data.
|
1646
|
+
- train_mode (str): The training mode. Options are 'erm' (Empirical Risk Minimization) or 'irm' (Invariant Risk Minimization).
|
1647
|
+
- mode (str): The mode of operation. Options are 'train' or 'test'.
|
1648
|
+
- image_size (int): The size of the input images.
|
1649
|
+
- batch_size (int): The batch size for the data loaders.
|
1650
|
+
- classes (list): The list of classes to consider.
|
1651
|
+
- num_workers (int): The number of worker threads for data loading.
|
1652
|
+
- validation_split (float): The fraction of data to use for validation when train_mode is 'erm'.
|
1653
|
+
- max_show (int): The maximum number of images to show when verbose is True.
|
1654
|
+
- pin_memory (bool): Whether to pin memory for faster data transfer.
|
1655
|
+
- normalize (bool): Whether to normalize the input images.
|
1656
|
+
- verbose (bool): Whether to print additional information and show images.
|
1657
|
+
- channels (list): The list of channels to retain. Options are [1, 2, 3] for all channels, [1, 2] for blue and green, etc.
|
1489
1658
|
|
1490
1659
|
Returns:
|
1491
|
-
|
1660
|
+
- train_loaders (list): List of data loaders for training datasets.
|
1661
|
+
- val_loaders (list): List of data loaders for validation datasets.
|
1662
|
+
- plate_names (list): List of plate names (only applicable when train_mode is 'irm').
|
1492
1663
|
"""
|
1493
|
-
|
1494
|
-
from .io import _read_and_merge_data, _results_to_csv
|
1495
|
-
from .plot import plot_merged, _plot_controls, _plot_recruitment
|
1496
|
-
from .utils import _object_filter, annotate_conditions, _calculate_recruitment, _group_by_well
|
1497
|
-
|
1498
|
-
settings_dict = {**metadata_settings, **advanced_settings}
|
1499
|
-
settings_df = pd.DataFrame(list(settings_dict.items()), columns=['Key', 'Value'])
|
1500
|
-
settings_csv = os.path.join(src,'settings','analyze_settings.csv')
|
1501
|
-
os.makedirs(os.path.join(src,'settings'), exist_ok=True)
|
1502
|
-
settings_df.to_csv(settings_csv, index=False)
|
1503
1664
|
|
1504
|
-
|
1505
|
-
|
1506
|
-
|
1507
|
-
|
1508
|
-
|
1509
|
-
|
1510
|
-
|
1511
|
-
|
1512
|
-
|
1513
|
-
|
1514
|
-
|
1515
|
-
|
1516
|
-
|
1517
|
-
|
1518
|
-
|
1519
|
-
|
1520
|
-
|
1521
|
-
|
1522
|
-
|
1523
|
-
|
1524
|
-
|
1525
|
-
|
1526
|
-
|
1527
|
-
|
1528
|
-
|
1529
|
-
|
1665
|
+
from .io import MyDataset
|
1666
|
+
from .plot import _imshow
|
1667
|
+
from torchvision import transforms
|
1668
|
+
from torch.utils.data import DataLoader, random_split
|
1669
|
+
from collections import defaultdict
|
1670
|
+
import os
|
1671
|
+
import random
|
1672
|
+
from PIL import Image
|
1673
|
+
from torchvision.transforms import ToTensor
|
1674
|
+
|
1675
|
+
chans = []
|
1676
|
+
|
1677
|
+
if 'r' in channels:
|
1678
|
+
chans.append(1)
|
1679
|
+
if 'g' in channels:
|
1680
|
+
chans.append(2)
|
1681
|
+
if 'b' in channels:
|
1682
|
+
chans.append(3)
|
1683
|
+
|
1684
|
+
channels = chans
|
1685
|
+
|
1686
|
+
if verbose:
|
1687
|
+
print(f'Training a network on channels: {channels}')
|
1688
|
+
print(f'Channel 1: Red, Channel 2: Green, Channel 3: Blue')
|
1689
|
+
|
1690
|
+
class SelectChannels:
|
1691
|
+
def __init__(self, channels):
|
1692
|
+
self.channels = channels
|
1693
|
+
|
1694
|
+
def __call__(self, img):
|
1695
|
+
img = img.clone()
|
1696
|
+
if 1 not in self.channels:
|
1697
|
+
img[0, :, :] = 0 # Zero out the red channel
|
1698
|
+
if 2 not in self.channels:
|
1699
|
+
img[1, :, :] = 0 # Zero out the green channel
|
1700
|
+
if 3 not in self.channels:
|
1701
|
+
img[2, :, :] = 0 # Zero out the blue channel
|
1702
|
+
return img
|
1703
|
+
|
1704
|
+
plate_to_filenames = defaultdict(list)
|
1705
|
+
plate_to_labels = defaultdict(list)
|
1706
|
+
train_loaders = []
|
1707
|
+
val_loaders = []
|
1708
|
+
plate_names = []
|
1709
|
+
|
1710
|
+
if normalize:
|
1711
|
+
transform = transforms.Compose([
|
1712
|
+
transforms.ToTensor(),
|
1713
|
+
transforms.CenterCrop(size=(image_size, image_size)),
|
1714
|
+
SelectChannels(channels),
|
1715
|
+
transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])
|
1716
|
+
else:
|
1717
|
+
transform = transforms.Compose([
|
1718
|
+
transforms.ToTensor(),
|
1719
|
+
transforms.CenterCrop(size=(image_size, image_size)),
|
1720
|
+
SelectChannels(channels)])
|
1721
|
+
|
1722
|
+
if mode == 'train':
|
1723
|
+
data_dir = os.path.join(src, 'train')
|
1724
|
+
shuffle = True
|
1725
|
+
print('Generating Train and validation datasets')
|
1726
|
+
elif mode == 'test':
|
1727
|
+
data_dir = os.path.join(src, 'test')
|
1728
|
+
val_loaders = []
|
1729
|
+
validation_split = 0.0
|
1730
|
+
shuffle = True
|
1731
|
+
print('Generating test dataset')
|
1732
|
+
else:
|
1733
|
+
print(f'mode:{mode} is not valid, use mode = train or test')
|
1734
|
+
return
|
1735
|
+
|
1736
|
+
if train_mode == 'erm':
|
1737
|
+
data = MyDataset(data_dir, classes, transform=transform, shuffle=shuffle, pin_memory=pin_memory)
|
1738
|
+
if validation_split > 0:
|
1739
|
+
train_size = int((1 - validation_split) * len(data))
|
1740
|
+
val_size = len(data) - train_size
|
1741
|
+
|
1742
|
+
print(f'Train data:{train_size}, Validation data:{val_size}')
|
1743
|
+
|
1744
|
+
train_dataset, val_dataset = random_split(data, [train_size, val_size])
|
1745
|
+
|
1746
|
+
train_loaders = DataLoader(train_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers if num_workers is not None else 0, pin_memory=pin_memory)
|
1747
|
+
val_loaders = DataLoader(val_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers if num_workers is not None else 0, pin_memory=pin_memory)
|
1748
|
+
else:
|
1749
|
+
train_loaders = DataLoader(data, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers if num_workers is not None else 0, pin_memory=pin_memory)
|
1750
|
+
|
1751
|
+
elif train_mode == 'irm':
|
1752
|
+
data = MyDataset(data_dir, classes, transform=transform, shuffle=shuffle, pin_memory=pin_memory)
|
1753
|
+
|
1754
|
+
for filename, label in zip(data.filenames, data.labels):
|
1755
|
+
plate = data.get_plate(filename)
|
1756
|
+
plate_to_filenames[plate].append(filename)
|
1757
|
+
plate_to_labels[plate].append(label)
|
1758
|
+
|
1759
|
+
for plate, filenames in plate_to_filenames.items():
|
1760
|
+
labels = plate_to_labels[plate]
|
1761
|
+
plate_data = MyDataset(data_dir, classes, specific_files=filenames, specific_labels=labels, transform=transform, shuffle=False, pin_memory=pin_memory)
|
1762
|
+
plate_names.append(plate)
|
1763
|
+
|
1764
|
+
if validation_split > 0:
|
1765
|
+
train_size = int((1 - validation_split) * len(plate_data))
|
1766
|
+
val_size = len(plate_data) - train_size
|
1767
|
+
|
1768
|
+
print(f'Train data:{train_size}, Validation data:{val_size}')
|
1769
|
+
|
1770
|
+
train_dataset, val_dataset = random_split(plate_data, [train_size, val_size])
|
1771
|
+
|
1772
|
+
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers if num_workers is not None else 0, pin_memory=pin_memory)
|
1773
|
+
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers if num_workers is not None else 0, pin_memory=pin_memory)
|
1774
|
+
|
1775
|
+
train_loaders.append(train_loader)
|
1776
|
+
val_loaders.append(val_loader)
|
1777
|
+
else:
|
1778
|
+
train_loader = DataLoader(plate_data, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers if num_workers is not None else 0, pin_memory=pin_memory)
|
1779
|
+
train_loaders.append(train_loader)
|
1780
|
+
val_loaders.append(None)
|
1781
|
+
|
1782
|
+
else:
|
1783
|
+
print(f'train_mode:{train_mode} is not valid, use: train_mode = irm or erm')
|
1784
|
+
return
|
1785
|
+
|
1786
|
+
if verbose:
|
1787
|
+
if train_mode == 'erm':
|
1788
|
+
for idx, (images, labels, filenames) in enumerate(train_loaders):
|
1789
|
+
if idx >= max_show:
|
1790
|
+
break
|
1791
|
+
images = images.cpu()
|
1792
|
+
label_strings = [str(label.item()) for label in labels]
|
1793
|
+
_imshow(images, label_strings, nrow=20, fontsize=12)
|
1794
|
+
elif train_mode == 'irm':
|
1795
|
+
for plate_name, train_loader in zip(plate_names, train_loaders):
|
1796
|
+
print(f'Plate: {plate_name} with {len(train_loader.dataset)} images')
|
1797
|
+
for idx, (images, labels, filenames) in enumerate(train_loader):
|
1798
|
+
if idx >= max_show:
|
1799
|
+
break
|
1800
|
+
images = images.cpu()
|
1801
|
+
label_strings = [str(label.item()) for label in labels]
|
1802
|
+
_imshow(images, label_strings, nrow=20, fontsize=12)
|
1803
|
+
|
1804
|
+
return train_loaders, val_loaders, plate_names
|
1805
|
+
|
1806
|
+
def analyze_recruitment(src, metadata_settings, advanced_settings):
|
1807
|
+
"""
|
1808
|
+
Analyze recruitment data by grouping the DataFrame by well coordinates and plotting controls and recruitment data.
|
1809
|
+
|
1810
|
+
Parameters:
|
1811
|
+
src (str): The source of the recruitment data.
|
1812
|
+
metadata_settings (dict): The settings for metadata.
|
1813
|
+
advanced_settings (dict): The advanced settings for recruitment analysis.
|
1814
|
+
|
1815
|
+
Returns:
|
1816
|
+
None
|
1817
|
+
"""
|
1818
|
+
|
1819
|
+
from .io import _read_and_merge_data, _results_to_csv
|
1820
|
+
from .plot import plot_merged, _plot_controls, _plot_recruitment
|
1821
|
+
from .utils import _object_filter, annotate_conditions, _calculate_recruitment, _group_by_well
|
1822
|
+
|
1823
|
+
settings_dict = {**metadata_settings, **advanced_settings}
|
1824
|
+
settings_df = pd.DataFrame(list(settings_dict.items()), columns=['Key', 'Value'])
|
1825
|
+
settings_csv = os.path.join(src,'settings','analyze_settings.csv')
|
1826
|
+
os.makedirs(os.path.join(src,'settings'), exist_ok=True)
|
1827
|
+
settings_df.to_csv(settings_csv, index=False)
|
1828
|
+
|
1829
|
+
# metadata settings
|
1830
|
+
target = metadata_settings['target']
|
1831
|
+
cell_types = metadata_settings['cell_types']
|
1832
|
+
cell_plate_metadata = metadata_settings['cell_plate_metadata']
|
1833
|
+
pathogen_types = metadata_settings['pathogen_types']
|
1834
|
+
pathogen_plate_metadata = metadata_settings['pathogen_plate_metadata']
|
1835
|
+
treatments = metadata_settings['treatments']
|
1836
|
+
treatment_plate_metadata = metadata_settings['treatment_plate_metadata']
|
1837
|
+
metadata_types = metadata_settings['metadata_types']
|
1838
|
+
channel_dims = metadata_settings['channel_dims']
|
1839
|
+
cell_chann_dim = metadata_settings['cell_chann_dim']
|
1840
|
+
cell_mask_dim = metadata_settings['cell_mask_dim']
|
1841
|
+
nucleus_chann_dim = metadata_settings['nucleus_chann_dim']
|
1842
|
+
nucleus_mask_dim = metadata_settings['nucleus_mask_dim']
|
1843
|
+
pathogen_chann_dim = metadata_settings['pathogen_chann_dim']
|
1844
|
+
pathogen_mask_dim = metadata_settings['pathogen_mask_dim']
|
1845
|
+
channel_of_interest = metadata_settings['channel_of_interest']
|
1846
|
+
|
1847
|
+
# Advanced settings
|
1848
|
+
plot = advanced_settings['plot']
|
1849
|
+
plot_nr = advanced_settings['plot_nr']
|
1850
|
+
plot_control = advanced_settings['plot_control']
|
1851
|
+
figuresize = advanced_settings['figuresize']
|
1852
|
+
remove_background = advanced_settings['remove_background']
|
1853
|
+
backgrounds = advanced_settings['backgrounds']
|
1854
|
+
include_noninfected = advanced_settings['include_noninfected']
|
1530
1855
|
include_multiinfected = advanced_settings['include_multiinfected']
|
1531
1856
|
include_multinucleated = advanced_settings['include_multinucleated']
|
1532
1857
|
cells_per_well = advanced_settings['cells_per_well']
|
@@ -1584,15 +1909,30 @@ def analyze_recruitment(src, metadata_settings, advanced_settings):
|
|
1584
1909
|
df = df.dropna(subset=['condition'])
|
1585
1910
|
print(f'After dropping non-annotated wells: {len(df)} rows')
|
1586
1911
|
files = df['file_name'].tolist()
|
1912
|
+
print(f'found: {len(files)} files')
|
1587
1913
|
files = [item + '.npy' for item in files]
|
1588
1914
|
random.shuffle(files)
|
1589
|
-
|
1915
|
+
|
1916
|
+
_max = 10**100
|
1917
|
+
|
1918
|
+
if cell_size_range is None and nucleus_size_range is None and pathogen_size_range is None:
|
1919
|
+
filter_min_max = None
|
1920
|
+
else:
|
1921
|
+
if cell_size_range is None:
|
1922
|
+
cell_size_range = [0,_max]
|
1923
|
+
if nucleus_size_range is None:
|
1924
|
+
nucleus_size_range = [0,_max]
|
1925
|
+
if pathogen_size_range is None:
|
1926
|
+
pathogen_size_range = [0,_max]
|
1927
|
+
|
1928
|
+
filter_min_max = [[cell_size_range[0],cell_size_range[1]],[nucleus_size_range[0],nucleus_size_range[1]],[pathogen_size_range[0],pathogen_size_range[1]]]
|
1929
|
+
|
1590
1930
|
if plot:
|
1591
1931
|
plot_settings = {'include_noninfected':include_noninfected,
|
1592
1932
|
'include_multiinfected':include_multiinfected,
|
1593
1933
|
'include_multinucleated':include_multinucleated,
|
1594
1934
|
'remove_background':remove_background,
|
1595
|
-
'filter_min_max':
|
1935
|
+
'filter_min_max':filter_min_max,
|
1596
1936
|
'channel_dims':channel_dims,
|
1597
1937
|
'backgrounds':backgrounds,
|
1598
1938
|
'cell_mask_dim':mask_dims[0],
|
@@ -1649,15 +1989,37 @@ def analyze_recruitment(src, metadata_settings, advanced_settings):
|
|
1649
1989
|
cells,wells = _results_to_csv(src, df, df_well)
|
1650
1990
|
return [cells,wells]
|
1651
1991
|
|
1652
|
-
|
1653
|
-
def preprocess_generate_masks(src, settings={},advanced_settings={}):
|
1992
|
+
def preprocess_generate_masks(src, settings={}):
|
1654
1993
|
|
1655
1994
|
from .io import preprocess_img_data, _load_and_concatenate_arrays
|
1656
1995
|
from .plot import plot_merged, plot_arrays
|
1657
1996
|
from .utils import _pivot_counts_table
|
1658
|
-
|
1659
|
-
settings =
|
1997
|
+
|
1998
|
+
settings['plot'] = False
|
1999
|
+
settings['fps'] = 2
|
2000
|
+
settings['remove_background'] = True
|
2001
|
+
settings['lower_quantile'] = 0.02
|
2002
|
+
settings['merge'] = False
|
2003
|
+
settings['normalize_plots'] = True
|
2004
|
+
settings['all_to_mip'] = False
|
2005
|
+
settings['pick_slice'] = False
|
2006
|
+
settings['skip_mode'] = src
|
2007
|
+
settings['workers'] = os.cpu_count()-4
|
2008
|
+
settings['verbose'] = True
|
2009
|
+
settings['examples_to_plot'] = 1
|
1660
2010
|
settings['src'] = src
|
2011
|
+
settings['upscale'] = False
|
2012
|
+
settings['upscale_factor'] = 2.0
|
2013
|
+
|
2014
|
+
settings['randomize'] = True
|
2015
|
+
settings['timelapse'] = False
|
2016
|
+
settings['timelapse_displacement'] = None
|
2017
|
+
settings['timelapse_memory'] = 3
|
2018
|
+
settings['timelapse_frame_limits'] = None
|
2019
|
+
settings['timelapse_remove_transient'] = False
|
2020
|
+
settings['timelapse_mode'] = 'trackpy'
|
2021
|
+
settings['timelapse_objects'] = ['cells']
|
2022
|
+
|
1661
2023
|
settings_df = pd.DataFrame(list(settings.items()), columns=['Key', 'Value'])
|
1662
2024
|
settings_csv = os.path.join(src,'settings','preprocess_generate_masks_settings.csv')
|
1663
2025
|
os.makedirs(os.path.join(src,'settings'), exist_ok=True)
|
@@ -1676,7 +2038,7 @@ def preprocess_generate_masks(src, settings={},advanced_settings={}):
|
|
1676
2038
|
settings['save'] = [settings['save']]*3
|
1677
2039
|
|
1678
2040
|
if settings['preprocess']:
|
1679
|
-
preprocess_img_data(settings)
|
2041
|
+
settings, src = preprocess_img_data(settings)
|
1680
2042
|
|
1681
2043
|
if settings['masks']:
|
1682
2044
|
mask_src = os.path.join(src, 'norm_channel_stack')
|
@@ -1726,7 +2088,6 @@ def preprocess_generate_masks(src, settings={},advanced_settings={}):
|
|
1726
2088
|
'cell_mask_dim':cell_mask_dim,
|
1727
2089
|
'nucleus_mask_dim':nucleus_mask_dim,
|
1728
2090
|
'pathogen_mask_dim':pathogen_mask_dim,
|
1729
|
-
'overlay_chans':[0,2,3],
|
1730
2091
|
'outline_thickness':3,
|
1731
2092
|
'outline_color':'gbr',
|
1732
2093
|
'overlay_chans':overlay_channels,
|
@@ -1738,6 +2099,10 @@ def preprocess_generate_masks(src, settings={},advanced_settings={}):
|
|
1738
2099
|
'figuresize':20,
|
1739
2100
|
'cmap':'inferno',
|
1740
2101
|
'verbose':False}
|
2102
|
+
|
2103
|
+
if settings['test_mode'] == True:
|
2104
|
+
plot_settings['nr'] = len(os.path.join(src,'merged'))
|
2105
|
+
|
1741
2106
|
try:
|
1742
2107
|
fig = plot_merged(src=os.path.join(src,'merged'), settings=plot_settings)
|
1743
2108
|
except Exception as e:
|
@@ -1750,26 +2115,61 @@ def preprocess_generate_masks(src, settings={},advanced_settings={}):
|
|
1750
2115
|
print("Successfully completed run")
|
1751
2116
|
return
|
1752
2117
|
|
1753
|
-
def identify_masks_finetune(
|
2118
|
+
def identify_masks_finetune(settings):
|
1754
2119
|
|
1755
2120
|
from .plot import print_mask_and_flows
|
1756
2121
|
from .utils import get_files_from_dir, resize_images_and_labels
|
1757
2122
|
from .io import _load_normalized_images_and_labels, _load_images_and_labels
|
1758
2123
|
|
2124
|
+
src=settings['src']
|
2125
|
+
dst=settings['dst']
|
2126
|
+
model_name=settings['model_name']
|
2127
|
+
diameter=settings['diameter']
|
2128
|
+
batch_size=settings['batch_size']
|
2129
|
+
flow_threshold=settings['flow_threshold']
|
2130
|
+
cellprob_threshold=settings['cellprob_threshold']
|
2131
|
+
|
2132
|
+
verbose=settings['verbose']
|
2133
|
+
plot=settings['plot']
|
2134
|
+
save=settings['save']
|
2135
|
+
custom_model=settings['custom_model']
|
2136
|
+
overlay=settings['overlay']
|
2137
|
+
|
2138
|
+
figuresize=25
|
2139
|
+
cmap='inferno'
|
2140
|
+
channels = [0,0]
|
2141
|
+
signal_thresholds = 1000
|
2142
|
+
normalize = True
|
2143
|
+
percentiles = [2,98]
|
2144
|
+
circular = False
|
2145
|
+
invert = False
|
2146
|
+
resize = False
|
2147
|
+
settings['width_height'] = [1000,1000]
|
2148
|
+
target_height = settings['width_height'][1]
|
2149
|
+
target_width = settings['width_height'][0]
|
2150
|
+
rescale = False
|
2151
|
+
resample = False
|
2152
|
+
grayscale = True
|
2153
|
+
test = False
|
2154
|
+
|
2155
|
+
os.makedirs(dst, exist_ok=True)
|
2156
|
+
|
2157
|
+
if not custom_model is None:
|
2158
|
+
if not os.path.exists(custom_model):
|
2159
|
+
print(f'Custom model not found: {custom_model}')
|
2160
|
+
return
|
2161
|
+
|
1759
2162
|
if not torch.cuda.is_available():
|
1760
2163
|
print(f'Torch CUDA is not available, using CPU')
|
1761
2164
|
|
1762
2165
|
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
1763
2166
|
|
1764
2167
|
if custom_model == None:
|
1765
|
-
|
1766
|
-
|
1767
|
-
|
1768
|
-
|
1769
|
-
|
1770
|
-
if custom_model != None:
|
1771
|
-
model = cp_models.CellposeModel(gpu=torch.cuda.is_available(), model_type=None, pretrained_model=custom_model, diam_mean=diameter, device=device, net_avg=False) #Assuming diameter is defined elsewhere
|
1772
|
-
print(f'loaded custom model:{custom_model}')
|
2168
|
+
model = cp_models.CellposeModel(gpu=True, model_type=model_name, device=device)
|
2169
|
+
print(f'Loaded model: {model_name}')
|
2170
|
+
else:
|
2171
|
+
model = cp_models.CellposeModel(gpu=torch.cuda.is_available(), model_type=None, pretrained_model=custom_model, diam_mean=diameter, device=device)
|
2172
|
+
print("Pretrained Model Loaded:", model.pretrained_model)
|
1773
2173
|
|
1774
2174
|
chans = [2, 1] if model_name == 'cyto2' else [0,0] if model_name == 'nucleus' else [1,0] if model_name == 'cyto' else [2, 0]
|
1775
2175
|
|
@@ -1781,14 +2181,16 @@ def identify_masks_finetune(src, dst, model_name, channels, diameter, batch_size
|
|
1781
2181
|
if verbose == True:
|
1782
2182
|
print(f'Cellpose settings: Model: {model_name}, channels: {channels}, cellpose_chans: {chans}, diameter:{diameter}, flow_threshold:{flow_threshold}, cellprob_threshold:{cellprob_threshold}')
|
1783
2183
|
|
1784
|
-
all_image_files =
|
2184
|
+
all_image_files = [os.path.join(src, f) for f in os.listdir(src) if f.endswith('.tif')]
|
2185
|
+
|
1785
2186
|
random.shuffle(all_image_files)
|
1786
2187
|
|
1787
2188
|
time_ls = []
|
1788
2189
|
for i in range(0, len(all_image_files), batch_size):
|
1789
2190
|
image_files = all_image_files[i:i+batch_size]
|
2191
|
+
|
1790
2192
|
if normalize:
|
1791
|
-
images, _, image_names, _ = _load_normalized_images_and_labels(image_files=image_files, label_files=None, signal_thresholds=signal_thresholds, channels=channels, percentiles=percentiles, circular=circular, invert=invert, visualize=
|
2193
|
+
images, _, image_names, _ = _load_normalized_images_and_labels(image_files=image_files, label_files=None, signal_thresholds=signal_thresholds, channels=channels, percentiles=percentiles, circular=circular, invert=invert, visualize=plot)
|
1792
2194
|
images = [np.squeeze(img) if img.shape[-1] == 1 else img for img in images]
|
1793
2195
|
orig_dims = [(image.shape[0], image.shape[1]) for image in images]
|
1794
2196
|
else:
|
@@ -1809,8 +2211,7 @@ def identify_masks_finetune(src, dst, model_name, channels, diameter, batch_size
|
|
1809
2211
|
cellprob_threshold=cellprob_threshold,
|
1810
2212
|
rescale=rescale,
|
1811
2213
|
resample=resample,
|
1812
|
-
|
1813
|
-
progress=False)
|
2214
|
+
progress=True)
|
1814
2215
|
|
1815
2216
|
if len(output) == 4:
|
1816
2217
|
mask, flows, _, _ = output
|
@@ -1837,8 +2238,7 @@ def identify_masks_finetune(src, dst, model_name, channels, diameter, batch_size
|
|
1837
2238
|
cv2.imwrite(output_filename, mask)
|
1838
2239
|
return
|
1839
2240
|
|
1840
|
-
|
1841
|
-
def identify_masks(src, object_type, model_name, batch_size, channels, diameter, minimum_size, maximum_size, flow_threshold=30, cellprob_threshold=1, figuresize=25, cmap='inferno', refine_masks=True, filter_size=True, filter_dimm=True, remove_border_objects=False, verbose=False, plot=False, merge=False, save=True, start_at=0, file_type='.npz', net_avg=True, resample=True, timelapse=False, timelapse_displacement=None, timelapse_frame_limits=None, timelapse_memory=3, timelapse_remove_transient=False, timelapse_mode='btrack', timelapse_objects='cell'):
|
2241
|
+
def identify_masks(src, object_type, model_name, batch_size, channels, diameter, minimum_size, maximum_size, filter_intensity, flow_threshold=30, cellprob_threshold=1, figuresize=25, cmap='inferno', refine_masks=True, filter_size=True, filter_dimm=True, remove_border_objects=False, verbose=False, plot=False, merge=False, save=True, start_at=0, file_type='.npz', net_avg=True, resample=True, timelapse=False, timelapse_displacement=None, timelapse_frame_limits=None, timelapse_memory=3, timelapse_remove_transient=False, timelapse_mode='btrack', timelapse_objects='cell'):
|
1842
2242
|
"""
|
1843
2243
|
Identify masks from the source images.
|
1844
2244
|
|
@@ -1886,13 +2286,13 @@ def identify_masks(src, object_type, model_name, batch_size, channels, diameter,
|
|
1886
2286
|
|
1887
2287
|
#Note add logic that handles batches of size 1 as these will break the code batches must all be > 2 images
|
1888
2288
|
gc.collect()
|
1889
|
-
#print('========== generating masks ==========')
|
1890
2289
|
|
1891
2290
|
if not torch.cuda.is_available():
|
1892
2291
|
print(f'Torch CUDA is not available, using CPU')
|
1893
2292
|
|
1894
2293
|
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
1895
|
-
model = cp_models.Cellpose(gpu=True, model_type=model_name, device=device)
|
2294
|
+
model = cp_models.Cellpose(gpu=True, model_type=model_name, device=device)
|
2295
|
+
|
1896
2296
|
if file_type == '.npz':
|
1897
2297
|
paths = [os.path.join(src, file) for file in os.listdir(src) if file.endswith('.npz')]
|
1898
2298
|
else:
|
@@ -1919,9 +2319,6 @@ def identify_masks(src, object_type, model_name, batch_size, channels, diameter,
|
|
1919
2319
|
|
1920
2320
|
average_sizes = []
|
1921
2321
|
time_ls = []
|
1922
|
-
moving_avg_q1 = 0
|
1923
|
-
moving_avg_q3 = 0
|
1924
|
-
moving_count = 0
|
1925
2322
|
for file_index, path in enumerate(paths):
|
1926
2323
|
|
1927
2324
|
name = os.path.basename(path)
|
@@ -1979,7 +2376,7 @@ def identify_masks(src, object_type, model_name, batch_size, channels, diameter,
 
                 cellpose_batch_size = _get_cellpose_batch_size()
 
-                model = cellpose.denoise.DenoiseModel(model_type=f"denoise_{model_name}", gpu=True)
+                #model = cellpose.denoise.DenoiseModel(model_type=f"denoise_{model_name}", gpu=True)
 
                 masks, flows, _, _ = model.eval(x=batch,
                                                 batch_size=cellpose_batch_size,
@@ -1991,9 +2388,9 @@ def identify_masks(src, object_type, model_name, batch_size, channels, diameter,
                                                 cellprob_threshold=cellprob_threshold,
                                                 rescale=None,
                                                 resample=resample,
-                                                #net_avg=net_avg,
                                                 stitch_threshold=stitch_threshold,
                                                 progress=None)
+
                 print('Masks shape',masks.shape)
                 if timelapse:
                     _save_object_counts_to_database(masks, object_type, batch_filenames, count_loc, added_string='_timelapse')
@@ -2017,7 +2414,7 @@ def identify_masks(src, object_type, model_name, batch_size, channels, diameter,
 
                 else:
                     _save_object_counts_to_database(masks, object_type, batch_filenames, count_loc, added_string='_before_filtration')
-                    mask_stack = _filter_cp_masks(masks, flows,
+                    mask_stack = _filter_cp_masks(masks, flows, filter_size, filter_intensity, minimum_size, maximum_size, remove_border_objects, merge, batch, plot, figuresize)
                     _save_object_counts_to_database(mask_stack, object_type, batch_filenames, count_loc, added_string='_after_filtration')
 
                 if not np.any(mask_stack):
@@ -2049,10 +2446,13 @@ def identify_masks(src, object_type, model_name, batch_size, channels, diameter,
     gc.collect()
     return
 
-
+def all_elements_match(list1, list2):
+    # Check if all elements in list1 are in list2
+    return all(element in list2 for element in list1)
+
 def generate_cellpose_masks(src, settings, object_type):
 
-    from .utils import _masks_to_masks_stack, _filter_cp_masks, _get_cellpose_batch_size, _get_object_settings, _get_cellpose_channels
+    from .utils import _masks_to_masks_stack, _filter_cp_masks, _get_cellpose_batch_size, _get_object_settings, _get_cellpose_channels, _choose_model, mask_object_count
     from .io import _create_database, _save_object_counts_to_database, _check_masks, _get_avg_object_size
     from .timelapse import _npz_to_movie, _btrack_track_cells, _trackpy_track_cells
     from .plot import plot_masks
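The new all_elements_match helper is a plain subset test; generate_cellpose_masks uses it further down to validate settings['timelapse_objects'] against the trackable object types. A quick illustrative check (list values taken from that call site):

>>> trackable_objects = ['cell', 'nucleus', 'pathogen']
>>> all_elements_match(['cell'], trackable_objects)
True
>>> all_elements_match(['cell', 'cytoplasm'], trackable_objects)
False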
@@ -2079,16 +2479,15 @@ def generate_cellpose_masks(src, settings, object_type):
     object_settings = _get_object_settings(object_type, settings)
     model_name = object_settings['model_name']
 
-    cellpose_channels = _get_cellpose_channels(settings['nucleus_channel'], settings['pathogen_channel'], settings['cell_channel'])
+    cellpose_channels = _get_cellpose_channels(src, settings['nucleus_channel'], settings['pathogen_channel'], settings['cell_channel'])
+    if settings['verbose']:
+        print(cellpose_channels)
+
     channels = cellpose_channels[object_type]
     cellpose_batch_size = _get_cellpose_batch_size()
-
     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-    model =
-    #dn = denoise.CellposeDenoiseModel(model_type=f"denoise_{model_name}", gpu=True, device=device)
-
+    model = _choose_model(model_name, device, object_type='cell', restore_type=None)
     chans = [2, 1] if model_name == 'cyto2' else [0,0] if model_name == 'nucleus' else [2,0] if model_name == 'cyto' else [2, 0] if model_name == 'cyto3' else [2, 0]
-
     paths = [os.path.join(src, file) for file in os.listdir(src) if file.endswith('.npz')]
 
     count_loc = os.path.dirname(src)+'/measurements/measurements.db'
@@ -2097,10 +2496,6 @@ def generate_cellpose_masks(src, settings, object_type):
 
     average_sizes = []
     time_ls = []
-    moving_avg_q1 = 0
-    moving_avg_q3 = 0
-    moving_count = 0
-
     for file_index, path in enumerate(paths):
         name = os.path.basename(path)
         name, ext = os.path.splitext(name)
@@ -2111,16 +2506,22 @@ def generate_cellpose_masks(src, settings, object_type):
         stack = data['data']
         filenames = data['filenames']
         if settings['timelapse']:
+
+            trackable_objects = ['cell','nucleus','pathogen']
+            if not all_elements_match(settings['timelapse_objects'], trackable_objects):
+                print(f'timelapse_objects {settings["timelapse_objects"]} must be a subset of {trackable_objects}')
+                return
+
             if len(stack) != batch_size:
                 print(f'Changed batch_size:{batch_size} to {len(stack)}, data length:{len(stack)}')
-                settings['
+                settings['timelapse_batch_size'] = len(stack)
                 batch_size = len(stack)
                 if isinstance(timelapse_frame_limits, list):
                     if len(timelapse_frame_limits) >= 2:
                         stack = stack[timelapse_frame_limits[0]: timelapse_frame_limits[1], :, :, :].astype(stack.dtype)
                         filenames = filenames[timelapse_frame_limits[0]: timelapse_frame_limits[1]]
                         batch_size = len(stack)
-                        print(f'Cut batch
+                        print(f'Cut batch at indecies: {timelapse_frame_limits}, New batch_size: {batch_size} ')
 
         for i in range(0, stack.shape[0], batch_size):
             mask_stack = []
@@ -2136,8 +2537,7 @@ def generate_cellpose_masks(src, settings, object_type):
             if not settings['plot']:
                 batch, batch_filenames = _check_masks(batch, batch_filenames, output_folder)
             if batch.size == 0:
-                print(f'Processing {file_index}/{len(paths)}: Images/
-                #print(f'Processing {file_index}/{len(paths)}: Images/N100pz {batch.shape[0]}', end='\r', flush=True)
+                print(f'Processing {file_index}/{len(paths)}: Images/npz {batch.shape[0]}')
                 continue
             if batch.max() > 1:
                 batch = batch / batch.max()
@@ -2150,10 +2550,8 @@ def generate_cellpose_masks(src, settings, object_type):
                 _npz_to_movie(batch, batch_filenames, save_path, fps=2)
             else:
                 stitch_threshold=0.0
-
-
-            #batch = np.stack((batch, batch), axis=-1)
-            #print(f'object: {object_type} chans : {chans} channels : {channels} model: {model_name}')
+
+            print('batch.shape',batch.shape)
             masks, flows, _, _ = model.eval(x=batch,
                                             batch_size=cellpose_batch_size,
                                             normalize=False,
@@ -2165,9 +2563,15 @@ def generate_cellpose_masks(src, settings, object_type):
                                             rescale=None,
                                             resample=object_settings['resample'],
                                             stitch_threshold=stitch_threshold)
-
-
+
             if timelapse:
+                if settings['plot']:
+                    for idx, (mask, flow, image) in enumerate(zip(masks, flows[0], batch)):
+                        if idx == 0:
+                            num_objects = mask_object_count(mask)
+                            print(f'Number of objects: {num_objects}')
+                            plot_masks(batch=image, masks=mask, flows=flow, cmap='inferno', figuresize=figuresize, nr=1, file_type='.npz', print_object_number=True)
+
                 _save_object_counts_to_database(masks, object_type, batch_filenames, count_loc, added_string='_timelapse')
                 if object_type in timelapse_objects:
                     if timelapse_mode == 'btrack':
@@ -2196,35 +2600,54 @@ def generate_cellpose_masks(src, settings, object_type):
                                                 name=name,
                                                 batch_filenames=batch_filenames,
                                                 object_type=object_type,
-
+                                                masks=masks,
                                                 timelapse_displacement=timelapse_displacement,
                                                 timelapse_memory=timelapse_memory,
                                                 timelapse_remove_transient=timelapse_remove_transient,
                                                 plot=settings['plot'],
                                                 save=settings['save'],
-
+                                                mode=timelapse_mode)
                 else:
                     mask_stack = _masks_to_masks_stack(masks)
-
             else:
                 _save_object_counts_to_database(masks, object_type, batch_filenames, count_loc, added_string='_before_filtration')
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+                if object_settings['merge'] and not settings['filter']:
+                    mask_stack = _filter_cp_masks(masks=masks,
+                                                  flows=flows,
+                                                  filter_size=False,
+                                                  filter_intensity=False,
+                                                  minimum_size=object_settings['minimum_size'],
+                                                  maximum_size=object_settings['maximum_size'],
+                                                  remove_border_objects=False,
+                                                  merge=object_settings['merge'],
+                                                  batch=batch,
+                                                  plot=settings['plot'],
+                                                  figuresize=figuresize)
+
+                if settings['filter']:
+                    mask_stack = _filter_cp_masks(masks=masks,
+                                                  flows=flows,
+                                                  filter_size=object_settings['filter_size'],
+                                                  filter_intensity=object_settings['filter_intensity'],
+                                                  minimum_size=object_settings['minimum_size'],
+                                                  maximum_size=object_settings['maximum_size'],
+                                                  remove_border_objects=object_settings['remove_border_objects'],
+                                                  merge=object_settings['merge'],
+                                                  batch=batch,
+                                                  plot=settings['plot'],
+                                                  figuresize=figuresize)
+
+                    _save_object_counts_to_database(mask_stack, object_type, batch_filenames, count_loc, added_string='_after_filtration')
+                else:
+                    mask_stack = _masks_to_masks_stack(masks)
 
+            if settings['plot']:
+                for idx, (mask, flow, image) in enumerate(zip(masks, flows[0], batch)):
+                    if idx == 0:
+                        num_objects = mask_object_count(mask)
+                        print(f'Number of objects, : {num_objects}')
+                        plot_masks(batch=image, masks=mask, flows=flow, cmap='inferno', figuresize=figuresize, nr=1, file_type='.npz', print_object_number=True)
+
             if not np.any(mask_stack):
                 average_obj_size = 0
             else:
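The hunk above replaces a removed block with a three-way filtration branch. A condensed, runnable sketch of just that control flow (stub values, not package code) shows which path a given settings combination takes; note the merge-only pass and the trailing else are two separate if statements, so the raw-stack branch still runs whenever filtering is off:

# Illustrative control-flow stub mirroring the new branch in generate_cellpose_masks.
def filtration_path(merge, filter_on):
    steps = []
    if merge and not filter_on:
        steps.append('merge-only _filter_cp_masks (size/intensity/border filters off)')
    if filter_on:
        steps.append('full _filter_cp_masks + counts saved as _after_filtration')
    else:
        steps.append('_masks_to_masks_stack (raw Cellpose masks)')
    return steps

print(filtration_path(merge=True, filter_on=False))
# ['merge-only _filter_cp_masks (size/intensity/border filters off)', '_masks_to_masks_stack (raw Cellpose masks)']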
@@ -2240,7 +2663,6 @@ def generate_cellpose_masks(src, settings, object_type):
             time_in_min = average_time/60
             time_per_mask = average_time/batch_size
             print(f'Processing {len(paths)} files with {batch_size} imgs: {(file_index+1)*(batch_size+1)}/{(len(paths))*(batch_size+1)}: Time/batch {time_in_min:.3f} min: Time/mask {time_per_mask:.3f}sec: {object_type} size: {overall_average_size:.3f} px2')
-            #print(f'Processing {len(paths)} files with {batch_size} imgs: {(file_index+1)*(batch_size+1)}/{(len(paths))*(batch_size+1)}: Time/batch {time_in_min:.3f} min: Time/mask {time_per_mask:.3f}sec: {object_type} size: {overall_average_size:.3f} px2', end='\r', flush=True)
             if not timelapse:
                 if settings['plot']:
                     plot_masks(batch, mask_stack, flows, figuresize=figuresize, cmap='inferno', nr=batch_size)
@@ -2252,4 +2674,663 @@ def generate_cellpose_masks(src, settings, object_type):
         batch_filenames = []
     gc.collect()
     torch.cuda.empty_cache()
-    return
+    return
+
+def generate_masks_from_imgs(src, model, model_name, batch_size, diameter, cellprob_threshold, grayscale, save, normalize, channels, percentiles, circular, invert, plot, resize, target_height, target_width, verbose):
+    from .io import _load_images_and_labels, _load_normalized_images_and_labels
+    from .utils import resize_images_and_labels, resizescikit
+    from .plot import print_mask_and_flows
+
+    dst = os.path.join(src, model_name)
+    os.makedirs(dst, exist_ok=True)
+
+    flow_threshold = 30
+    chans = [2, 1] if model_name == 'cyto2' else [0,0] if model_name == 'nucleus' else [1,0] if model_name == 'cyto' else [2, 0]
+
+    if grayscale:
+        chans=[0, 0]
+
+    all_image_files = [os.path.join(src, f) for f in os.listdir(src) if f.endswith('.tif')]
+    random.shuffle(all_image_files)
+
+
+    if verbose == True:
+        print(f'Cellpose settings: Model: {model_name}, channels: {channels}, cellpose_chans: {chans}, diameter:{diameter}, flow_threshold:{flow_threshold}, cellprob_threshold:{cellprob_threshold}')
+
+    time_ls = []
+    for i in range(0, len(all_image_files), batch_size):
+        image_files = all_image_files[i:i+batch_size]
+
+        if normalize:
+            images, _, image_names, _ = _load_normalized_images_and_labels(image_files=image_files, label_files=None, signal_thresholds=100, channels=channels, percentiles=percentiles, circular=circular, invert=invert, visualize=plot)
+            images = [np.squeeze(img) if img.shape[-1] == 1 else img for img in images]
+            orig_dims = [(image.shape[0], image.shape[1]) for image in images]
+        else:
+            images, _, image_names, _ = _load_images_and_labels(image_files=image_files, label_files=None, circular=circular, invert=invert)
+            images = [np.squeeze(img) if img.shape[-1] == 1 else img for img in images]
+            orig_dims = [(image.shape[0], image.shape[1]) for image in images]
+        if resize:
+            images, _ = resize_images_and_labels(images, None, target_height, target_width, True)
+
+        for file_index, stack in enumerate(images):
+            start = time.time()
+            output = model.eval(x=stack,
+                                normalize=False,
+                                channels=chans,
+                                channel_axis=3,
+                                diameter=diameter,
+                                flow_threshold=flow_threshold,
+                                cellprob_threshold=cellprob_threshold,
+                                rescale=False,
+                                resample=False,
+                                progress=True)
+
+            if len(output) == 4:
+                mask, flows, _, _ = output
+            elif len(output) == 3:
+                mask, flows, _ = output
+            else:
+                raise ValueError("Unexpected number of return values from model.eval()")
+
+            if resize:
+                dims = orig_dims[file_index]
+                mask = resizescikit(mask, dims, order=0, preserve_range=True, anti_aliasing=False).astype(mask.dtype)
+
+            stop = time.time()
+            duration = (stop - start)
+            time_ls.append(duration)
+            average_time = np.mean(time_ls) if len(time_ls) > 0 else 0
+            print(f'Processing {file_index+1}/{len(images)} images : Time/image {average_time:.3f} sec', end='\r', flush=True)
+            if plot:
+                if resize:
+                    stack = resizescikit(stack, dims, preserve_range=True, anti_aliasing=False).astype(stack.dtype)
+                print_mask_and_flows(stack, mask, flows, overlay=True)
+            if save:
+                output_filename = os.path.join(dst, image_names[file_index])
+                cv2.imwrite(output_filename, mask)
+
+
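A minimal usage sketch for the new generate_masks_from_imgs; the folder path and parameter values are assumptions for illustration, and the model is constructed the same way check_cellpose_models does below:

import torch
from cellpose import models as cp_models
from spacr.core import generate_masks_from_imgs

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = cp_models.CellposeModel(gpu=True, model_type='cyto', device=device)

generate_masks_from_imgs(src='/data/imgs',  # hypothetical folder of .tif images
                         model=model, model_name='cyto', batch_size=8,
                         diameter=30, cellprob_threshold=0, grayscale=True,
                         save=True, normalize=True, channels=[0], percentiles=None,
                         circular=False, invert=False, plot=False, resize=False,
                         target_height=None, target_width=None, verbose=True)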
+def check_cellpose_models(settings):
+
+    src = settings['src']
+    batch_size = settings['batch_size']
+    cellprob_threshold = settings['cellprob_threshold']
+    save = settings['save']
+    normalize = settings['normalize']
+    channels = settings['channels']
+    percentiles = settings['percentiles']
+    circular = settings['circular']
+    invert = settings['invert']
+    plot = settings['plot']
+    diameter = settings['diameter']
+    resize = settings['resize']
+    grayscale = settings['grayscale']
+    verbose = settings['verbose']
+    target_height = settings['width_height'][0]
+    target_width = settings['width_height'][1]
+
+    cellpose_models = ['cyto', 'nuclei', 'cyto2', 'cyto3']
+    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+    for model_name in cellpose_models:
+
+        model = cp_models.CellposeModel(gpu=True, model_type=model_name, device=device)
+        print(f'Using {model_name}')
+        generate_masks_from_imgs(src, model, model_name, batch_size, diameter, cellprob_threshold, grayscale, save, normalize, channels, percentiles, circular, invert, plot, resize, target_height, target_width, verbose)
+
+    return
+
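check_cellpose_models reads everything from one settings dict; a sketch of the expected keys (the values shown are illustrative assumptions):

settings = {
    'src': '/data/imgs',           # hypothetical image folder
    'batch_size': 8,
    'cellprob_threshold': 0,
    'save': True,
    'normalize': True,
    'channels': [0],
    'percentiles': None,
    'circular': False,
    'invert': False,
    'plot': False,
    'diameter': 30,
    'resize': False,
    'grayscale': True,
    'verbose': True,
    'width_height': (1024, 1024),  # indexed as [0] -> target_height, [1] -> target_width
}
check_cellpose_models(settings)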
+def compare_masks_v1(dir1, dir2, dir3, verbose=False):
+
+    from .io import _read_mask
+    from .plot import visualize_masks, plot_comparison_results
+    from .utils import extract_boundaries, boundary_f1_score, compute_segmentation_ap, jaccard_index, dice_coefficient
+
+    filenames = os.listdir(dir1)
+    results = []
+    cond_1 = os.path.basename(dir1)
+    cond_2 = os.path.basename(dir2)
+    cond_3 = os.path.basename(dir3)
+
+    for index, filename in enumerate(filenames):
+        print(f'Processing image:{index+1}', end='\r', flush=True)
+        path1, path2, path3 = os.path.join(dir1, filename), os.path.join(dir2, filename), os.path.join(dir3, filename)
+
+        print(path1)
+        print(path2)
+        print(path3)
+
+        if os.path.exists(path2) and os.path.exists(path3):
+
+            mask1, mask2, mask3 = _read_mask(path1), _read_mask(path2), _read_mask(path3)
+            boundary_true1, boundary_true2, boundary_true3 = extract_boundaries(mask1), extract_boundaries(mask2), extract_boundaries(mask3)
+
+
+            true_masks, pred_masks = [mask1], [mask2, mask3] # Assuming mask1 is the ground truth for simplicity
+            true_labels, pred_labels_1, pred_labels_2 = label(mask1), label(mask2), label(mask3)
+            average_precision_0, average_precision_1 = compute_segmentation_ap(mask1, mask2), compute_segmentation_ap(mask1, mask3)
+            ap_scores = [average_precision_0, average_precision_1]
+
+            if verbose:
+                #unique_values1, unique_values2, unique_values3 = np.unique(mask1), np.unique(mask2), np.unique(mask3)
+                #print(f"Unique values in mask 1: {unique_values1}, mask 2: {unique_values2}, mask 3: {unique_values3}")
+                visualize_masks(boundary_true1, boundary_true2, boundary_true3, title=f"Boundaries - {filename}")
+
+            boundary_f1_12, boundary_f1_13, boundary_f1_23 = boundary_f1_score(mask1, mask2), boundary_f1_score(mask1, mask3), boundary_f1_score(mask2, mask3)
+
+            if (np.unique(mask1).size == 1 and np.unique(mask1)[0] == 0) and \
+               (np.unique(mask2).size == 1 and np.unique(mask2)[0] == 0) and \
+               (np.unique(mask3).size == 1 and np.unique(mask3)[0] == 0):
+                continue
+
+            if verbose:
+                #unique_values4, unique_values5, unique_values6 = np.unique(boundary_f1_12), np.unique(boundary_f1_13), np.unique(boundary_f1_23)
+                #print(f"Unique values in boundary mask 1: {unique_values4}, mask 2: {unique_values5}, mask 3: {unique_values6}")
+                visualize_masks(mask1, mask2, mask3, title=filename)
+
+            jaccard12 = jaccard_index(mask1, mask2)
+            dice12 = dice_coefficient(mask1, mask2)
+
+            jaccard13 = jaccard_index(mask1, mask3)
+            dice13 = dice_coefficient(mask1, mask3)
+
+            jaccard23 = jaccard_index(mask2, mask3)
+            dice23 = dice_coefficient(mask2, mask3)
+
+            results.append({
+                f'filename': filename,
+                f'jaccard_{cond_1}_{cond_2}': jaccard12,
+                f'dice_{cond_1}_{cond_2}': dice12,
+                f'jaccard_{cond_1}_{cond_3}': jaccard13,
+                f'dice_{cond_1}_{cond_3}': dice13,
+                f'jaccard_{cond_2}_{cond_3}': jaccard23,
+                f'dice_{cond_2}_{cond_3}': dice23,
+                f'boundary_f1_{cond_1}_{cond_2}': boundary_f1_12,
+                f'boundary_f1_{cond_1}_{cond_3}': boundary_f1_13,
+                f'boundary_f1_{cond_2}_{cond_3}': boundary_f1_23,
+                f'average_precision_{cond_1}_{cond_2}': ap_scores[0],
+                f'average_precision_{cond_1}_{cond_3}': ap_scores[1]
+            })
+        else:
+            print(f'Cannot find {path1} or {path2} or {path3}')
+    fig = plot_comparison_results(results)
+    return results, fig
+
+def compare_cellpose_masks_v1(src, verbose=False):
+    from .io import _read_mask
+    from .plot import visualize_masks, plot_comparison_results, visualize_cellpose_masks
+    from .utils import extract_boundaries, boundary_f1_score, compute_segmentation_ap, jaccard_index
+
+    import os
+    import numpy as np
+    from skimage.measure import label
+
+    # Collect all subdirectories in src
+    dirs = [os.path.join(src, d) for d in os.listdir(src) if os.path.isdir(os.path.join(src, d))]
+
+    dirs.sort() # Optional: sort directories if needed
+
+    # Get common files in all directories
+    common_files = set(os.listdir(dirs[0]))
+    for d in dirs[1:]:
+        common_files.intersection_update(os.listdir(d))
+    common_files = list(common_files)
+
+    results = []
+    conditions = [os.path.basename(d) for d in dirs]
+
+    for index, filename in enumerate(common_files):
+        print(f'Processing image {index+1}/{len(common_files)}', end='\r', flush=True)
+        paths = [os.path.join(d, filename) for d in dirs]
+
+        # Check if file exists in all directories
+        if not all(os.path.exists(path) for path in paths):
+            print(f'Skipping {filename} as it is not present in all directories.')
+            continue
+
+        masks = [_read_mask(path) for path in paths]
+        boundaries = [extract_boundaries(mask) for mask in masks]
+
+        if verbose:
+            visualize_cellpose_masks(masks, titles=conditions, comparison_title=f"Masks Comparison for {filename}")
+
+        # Initialize data structure for results
+        file_results = {'filename': filename}
+
+        # Compare each mask with each other
+        for i in range(len(masks)):
+            for j in range(i + 1, len(masks)):
+                condition_i = conditions[i]
+                condition_j = conditions[j]
+                mask_i = masks[i]
+                mask_j = masks[j]
+
+                # Compute metrics
+                boundary_f1 = boundary_f1_score(mask_i, mask_j)
+                jaccard = jaccard_index(mask_i, mask_j)
+                average_precision = compute_segmentation_ap(mask_i, mask_j)
+
+                # Store results
+                file_results[f'jaccard_{condition_i}_{condition_j}'] = jaccard
+                file_results[f'boundary_f1_{condition_i}_{condition_j}'] = boundary_f1
+                file_results[f'average_precision_{condition_i}_{condition_j}'] = average_precision
+
+        results.append(file_results)
+
+    fig = plot_comparison_results(results)
+    return results, fig
+
+def compare_mask(args):
+    src, filename, dirs, conditions = args
+    paths = [os.path.join(d, filename) for d in dirs]
+
+    if not all(os.path.exists(path) for path in paths):
+        return None
+
+    from .io import _read_mask # Import here to avoid issues in multiprocessing
+    from .utils import extract_boundaries, boundary_f1_score, compute_segmentation_ap, jaccard_index
+    from .plot import plot_comparison_results
+
+    masks = [_read_mask(path) for path in paths]
+    file_results = {'filename': filename}
+
+    for i in range(len(masks)):
+        for j in range(i + 1, len(masks)):
+            mask_i, mask_j = masks[i], masks[j]
+            f1_score = boundary_f1_score(mask_i, mask_j)
+            jac_index = jaccard_index(mask_i, mask_j)
+            ap_score = compute_segmentation_ap(mask_i, mask_j)
+
+            file_results.update({
+                f'jaccard_{conditions[i]}_{conditions[j]}': jac_index,
+                f'boundary_f1_{conditions[i]}_{conditions[j]}': f1_score,
+                f'ap_{conditions[i]}_{conditions[j]}': ap_score
+            })
+
+    return file_results
+
+def compare_cellpose_masks(src, verbose=False, processes=None):
+    from .plot import visualize_cellpose_masks, plot_comparison_results
+    from .io import _read_mask
+    dirs = [os.path.join(src, d) for d in os.listdir(src) if os.path.isdir(os.path.join(src, d))]
+    dirs.sort() # Optional: sort directories if needed
+    conditions = [os.path.basename(d) for d in dirs]
+
+    # Get common files in all directories
+    common_files = set(os.listdir(dirs[0]))
+    for d in dirs[1:]:
+        common_files.intersection_update(os.listdir(d))
+    common_files = list(common_files)
+
+    # Create a pool of workers
+    with Pool(processes=processes) as pool:
+        args = [(src, filename, dirs, conditions) for filename in common_files]
+        results = pool.map(compare_mask, args)
+
+    # Filter out None results (from skipped files)
+    results = [res for res in results if res is not None]
+
+    if verbose:
+        for result in results:
+            filename = result['filename']
+            masks = [_read_mask(os.path.join(d, filename)) for d in dirs]
+            visualize_cellpose_masks(masks, titles=conditions, comparison_title=f"Masks Comparison for {filename}")
+
+    fig = plot_comparison_results(results)
+    return results, fig
+
+
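compare_mask is defined at module level so that multiprocessing.Pool can pickle it, and the heavy per-file metrics run in parallel while the optional visualization pass re-reads masks serially afterwards. Usage sketch (the path is an assumption; each subdirectory of src holds one condition's masks with matching filenames):

results, fig = compare_cellpose_masks('/data/mask_comparison', verbose=False, processes=4)
fig.savefig('mask_comparison.pdf')  # plot_comparison_results returns a matplotlib figure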
+def _calculate_similarity(df, features, col_to_compare, val1, val2):
+    """
+    Calculate similarity scores of each well to the positive and negative controls using various metrics.
+
+    Args:
+        df (pandas.DataFrame): DataFrame containing the data.
+        features (list): List of feature columns to use for similarity calculation.
+        col_to_compare (str): Column name to use for comparing groups.
+        val1, val2 (str): Values in col_to_compare to create subsets for comparison.
+
+    Returns:
+        pandas.DataFrame: DataFrame with similarity scores.
+    """
+    # Separate positive and negative control wells
+    pos_control = df[df[col_to_compare] == val1][features].mean()
+    neg_control = df[df[col_to_compare] == val2][features].mean()
+
+    # Standardize features for Mahalanobis distance
+    scaler = StandardScaler()
+    scaled_features = scaler.fit_transform(df[features])
+
+    # Regularize the covariance matrix to avoid singularity
+    cov_matrix = np.cov(scaled_features, rowvar=False)
+    inv_cov_matrix = None
+    try:
+        inv_cov_matrix = np.linalg.inv(cov_matrix)
+    except np.linalg.LinAlgError:
+        # Add a small value to the diagonal elements for regularization
+        epsilon = 1e-5
+        inv_cov_matrix = np.linalg.inv(cov_matrix + np.eye(cov_matrix.shape[0]) * epsilon)
+
+    # Calculate similarity scores
+    df['similarity_to_pos_euclidean'] = df[features].apply(lambda row: euclidean(row, pos_control), axis=1)
+    df['similarity_to_neg_euclidean'] = df[features].apply(lambda row: euclidean(row, neg_control), axis=1)
+    df['similarity_to_pos_cosine'] = df[features].apply(lambda row: cosine(row, pos_control), axis=1)
+    df['similarity_to_neg_cosine'] = df[features].apply(lambda row: cosine(row, neg_control), axis=1)
+    df['similarity_to_pos_mahalanobis'] = df[features].apply(lambda row: mahalanobis(row, pos_control, inv_cov_matrix), axis=1)
+    df['similarity_to_neg_mahalanobis'] = df[features].apply(lambda row: mahalanobis(row, neg_control, inv_cov_matrix), axis=1)
+    df['similarity_to_pos_manhattan'] = df[features].apply(lambda row: cityblock(row, pos_control), axis=1)
+    df['similarity_to_neg_manhattan'] = df[features].apply(lambda row: cityblock(row, neg_control), axis=1)
+    df['similarity_to_pos_minkowski'] = df[features].apply(lambda row: minkowski(row, pos_control, p=3), axis=1)
+    df['similarity_to_neg_minkowski'] = df[features].apply(lambda row: minkowski(row, neg_control, p=3), axis=1)
+    df['similarity_to_pos_chebyshev'] = df[features].apply(lambda row: chebyshev(row, pos_control), axis=1)
+    df['similarity_to_neg_chebyshev'] = df[features].apply(lambda row: chebyshev(row, neg_control), axis=1)
+    df['similarity_to_pos_hamming'] = df[features].apply(lambda row: hamming(row, pos_control), axis=1)
+    df['similarity_to_neg_hamming'] = df[features].apply(lambda row: hamming(row, neg_control), axis=1)
+    df['similarity_to_pos_jaccard'] = df[features].apply(lambda row: jaccard(row, pos_control), axis=1)
+    df['similarity_to_neg_jaccard'] = df[features].apply(lambda row: jaccard(row, neg_control), axis=1)
+    df['similarity_to_pos_braycurtis'] = df[features].apply(lambda row: braycurtis(row, pos_control), axis=1)
+    df['similarity_to_neg_braycurtis'] = df[features].apply(lambda row: braycurtis(row, neg_control), axis=1)
+
+    return df
+
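The regularized inverse covariance in _calculate_similarity is what keeps the Mahalanobis distance defined when features are collinear. A self-contained sketch of that pattern on synthetic data (the epsilon value mirrors the function above):

import numpy as np
from scipy.spatial.distance import mahalanobis

rng = np.random.default_rng(0)
X = rng.normal(size=(100, 5))              # synthetic feature matrix
cov = np.cov(X, rowvar=False)
try:
    inv_cov = np.linalg.inv(cov)
except np.linalg.LinAlgError:
    # ridge on the diagonal to lift a singular covariance matrix
    inv_cov = np.linalg.inv(cov + np.eye(cov.shape[0]) * 1e-5)

d = mahalanobis(X[0], X.mean(axis=0), inv_cov)  # distance of one row to the mean profile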
+def _permutation_importance(df, feature_string='channel_3', col_to_compare='col', pos='c1', neg='c2', exclude=None, n_repeats=10, clean=True, nr_to_plot=30, n_estimators=100, test_size=0.2, random_state=42, model_type='xgboost', n_jobs=-1):
+
+    """
+    Calculates permutation importance for numerical features in the dataframe,
+    comparing groups based on specified column values and uses the model to predict
+    the class for all other rows in the dataframe.
+
+    Args:
+        df (pandas.DataFrame): The DataFrame containing the data.
+        feature_string (str): String to filter features that contain this substring.
+        col_to_compare (str): Column name to use for comparing groups.
+        pos, neg (str): Values in col_to_compare to create subsets for comparison.
+        exclude (list or str, optional): Columns to exclude from features.
+        n_repeats (int): Number of repeats for permutation importance.
+        clean (bool): Whether to remove columns with a single value.
+        nr_to_plot (int): Number of top features to plot based on permutation importance.
+        n_estimators (int): Number of trees in the random forest, gradient boosting, or XGBoost model.
+        test_size (float): Proportion of the dataset to include in the test split.
+        random_state (int): Random seed for reproducibility.
+        model_type (str): Type of model to use ('random_forest', 'logistic_regression', 'gradient_boosting', 'xgboost').
+        n_jobs (int): Number of jobs to run in parallel for applicable models.
+
+    Returns:
+        pandas.DataFrame: The original dataframe with added prediction and data usage columns.
+        pandas.DataFrame: DataFrame containing the importances and standard deviations.
+    """
+
+    if 'cells_per_well' in df.columns:
+        df = df.drop(columns=['cells_per_well'])
+
+    # Subset the dataframe based on specified column values
+    df1 = df[df[col_to_compare] == pos].copy()
+    df2 = df[df[col_to_compare] == neg].copy()
+
+    # Create target variable
+    df1['target'] = 0
+    df2['target'] = 1
+
+    # Combine the subsets for analysis
+    combined_df = pd.concat([df1, df2])
+
+    # Automatically select numerical features
+    features = combined_df.select_dtypes(include=[np.number]).columns.tolist()
+    features.remove('target')
+
+    if clean:
+        combined_df = combined_df.loc[:, combined_df.nunique() > 1]
+        features = [feature for feature in features if feature in combined_df.columns]
+
+    if feature_string is not None:
+        feature_list = ['channel_0', 'channel_1', 'channel_2', 'channel_3']
+
+        # Remove feature_string from the list if it exists
+        if feature_string in feature_list:
+            feature_list.remove(feature_string)
+
+        features = [feature for feature in features if feature_string in feature]
+
+        # Iterate through the list and remove columns from df
+        for feature_ in feature_list:
+            features = [feature for feature in features if feature_ not in feature]
+            print(f'After removing {feature_} features: {len(features)}')
+
+    if exclude:
+        if isinstance(exclude, list):
+            features = [feature for feature in features if feature not in exclude]
+        else:
+            features.remove(exclude)
+
+    X = combined_df[features]
+    y = combined_df['target']
+
+    # Split the data into training and testing sets
+    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=test_size, random_state=random_state)
+
+    # Label the data in the original dataframe
+    combined_df['data_usage'] = 'train'
+    combined_df.loc[X_test.index, 'data_usage'] = 'test'
+
+    # Initialize the model based on model_type
+    if model_type == 'random_forest':
+        model = RandomForestClassifier(n_estimators=n_estimators, random_state=random_state, n_jobs=n_jobs)
+    elif model_type == 'logistic_regression':
+        model = LogisticRegression(max_iter=1000, random_state=random_state, n_jobs=n_jobs)
+    elif model_type == 'gradient_boosting':
+        model = HistGradientBoostingClassifier(max_iter=n_estimators, random_state=random_state) # Supports n_jobs internally
+    elif model_type == 'xgboost':
+        model = XGBClassifier(n_estimators=n_estimators, random_state=random_state, nthread=n_jobs, use_label_encoder=False, eval_metric='logloss')
+    else:
+        raise ValueError(f"Unsupported model_type: {model_type}")
+
+    model.fit(X_train, y_train)
+
+    perm_importance = permutation_importance(model, X_train, y_train, n_repeats=n_repeats, random_state=random_state, n_jobs=n_jobs)
+
+    # Create a DataFrame for permutation importances
+    permutation_df = pd.DataFrame({
+        'feature': [features[i] for i in perm_importance.importances_mean.argsort()],
+        'importance_mean': perm_importance.importances_mean[perm_importance.importances_mean.argsort()],
+        'importance_std': perm_importance.importances_std[perm_importance.importances_mean.argsort()]
+    }).tail(nr_to_plot)
+
+    # Plotting
+    fig, ax = plt.subplots()
+    ax.barh(permutation_df['feature'], permutation_df['importance_mean'], xerr=permutation_df['importance_std'], color="teal", align="center", alpha=0.6)
+    ax.set_xlabel('Permutation Importance')
+    plt.tight_layout()
+    plt.show()
+
+    # Feature importance for models that support it
+    if model_type in ['random_forest', 'xgboost', 'gradient_boosting']:
+        feature_importances = model.feature_importances_
+        feature_importance_df = pd.DataFrame({
+            'feature': features,
+            'importance': feature_importances
+        }).sort_values(by='importance', ascending=False).head(nr_to_plot)
+
+        # Plotting feature importance
+        fig, ax = plt.subplots()
+        ax.barh(feature_importance_df['feature'], feature_importance_df['importance'], color="blue", align="center", alpha=0.6)
+        ax.set_xlabel('Feature Importance')
+        plt.tight_layout()
+        plt.show()
+    else:
+        feature_importance_df = pd.DataFrame()
+
+    # Predicting the target variable for the test set
+    predictions_test = model.predict(X_test)
+    combined_df.loc[X_test.index, 'predictions'] = predictions_test
+
+    # Predicting the target variable for the training set
+    predictions_train = model.predict(X_train)
+    combined_df.loc[X_train.index, 'predictions'] = predictions_train
+
+    # Predicting the target variable for all other rows in the dataframe
+    X_all = df[features]
+    all_predictions = model.predict(X_all)
+    df['predictions'] = all_predictions
+
+    # Combine data usage labels back to the original dataframe
+    combined_data_usage = pd.concat([combined_df[['data_usage']], df[['predictions']]], axis=0)
+    df = df.join(combined_data_usage, how='left', rsuffix='_model')
+
+    # Calculating and printing the accuracy metrics
+    accuracy = accuracy_score(y_test, predictions_test)
+    precision = precision_score(y_test, predictions_test)
+    recall = recall_score(y_test, predictions_test)
+    f1 = f1_score(y_test, predictions_test)
+    print(f"Accuracy: {accuracy}")
+    print(f"Precision: {precision}")
+    print(f"Recall: {recall}")
+    print(f"F1 Score: {f1}")
+
+    # Printing class-specific accuracy metrics
+    print("\nClassification Report:")
+    print(classification_report(y_test, predictions_test))
+
+    df = _calculate_similarity(df, features, col_to_compare, pos, neg)
+
+    return [df, permutation_df, feature_importance_df, model, X_train, X_test, y_train, y_test]
+
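A usage sketch for _permutation_importance, unpacking the eight-element list it returns (the dataframe and the column values are assumptions about the measurement schema):

out = _permutation_importance(df, feature_string='channel_3', col_to_compare='col',
                              pos='c1', neg='c2', n_repeats=10, model_type='xgboost')
df_pred, permutation_df, feature_importance_df, model, X_train, X_test, y_train, y_test = out
print(permutation_df.tail())  # highest-ranked features by permutation importance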
+def _shap_analysis(model, X_train, X_test):
+
+    """
+    Performs SHAP analysis on the given model and data.
+
+    Args:
+        model: The trained model.
+        X_train (pandas.DataFrame): Training feature set.
+        X_test (pandas.DataFrame): Testing feature set.
+    """
+
+    explainer = shap.Explainer(model, X_train)
+    shap_values = explainer(X_test)
+
+    # Summary plot
+    shap.summary_plot(shap_values, X_test)
+
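shap.Explainer dispatches to a suitable explainer for the fitted model (e.g. a tree explainer for the tree ensembles above), which is why _shap_analysis works for every model_type that _permutation_importance can build. plate_heatmap below wires the two together exactly like this:

output = _permutation_importance(df, 'channel_3', 'col', 'c1', 'c2')
_shap_analysis(output[3], output[4], output[5])  # model, X_train, X_test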
+def plate_heatmap(src, model_type='xgboost', variable='predictions', grouping='mean', min_max='allq', cmap='viridis', channel_of_interest=3, min_count=25, n_estimators=100, col_to_compare='col', pos='c1', neg='c2', exclude=None, n_repeats=10, clean=True, nr_to_plot=20, verbose=False, n_jobs=-1):
+    from .io import _read_and_merge_data
+    from .plot import _plot_plates
+
+    db_loc = [src+'/measurements/measurements.db']
+    tables = ['cell', 'nucleus', 'pathogen','cytoplasm']
+    include_multinucleated, include_multiinfected, include_noninfected = True, 2.0, True
+
+    df, _ = _read_and_merge_data(db_loc,
+                                 tables,
+                                 verbose=verbose,
+                                 include_multinucleated=include_multinucleated,
+                                 include_multiinfected=include_multiinfected,
+                                 include_noninfected=include_noninfected)
+
+    if not channel_of_interest is None:
+        df['recruitment'] = df[f'pathogen_channel_{channel_of_interest}_mean_intensity']/df[f'cytoplasm_channel_{channel_of_interest}_mean_intensity']
+        feature_string = f'channel_{channel_of_interest}'
+    else:
+        feature_string = None
+
+    output = _permutation_importance(df, feature_string, col_to_compare, pos, neg, exclude, n_repeats, clean, nr_to_plot, n_estimators=n_estimators, random_state=42, model_type=model_type, n_jobs=n_jobs)
+
+    _shap_analysis(output[3], output[4], output[5])
+
+    features = output[0].select_dtypes(include=[np.number]).columns.tolist()
+
+    if not variable in features:
+        raise ValueError(f"Variable {variable} not found in the dataframe. Please choose one of the following: {features}")
+
+    plate_heatmap = _plot_plates(output[0], variable, grouping, min_max, cmap, min_count)
+    return [output, plate_heatmap]
+
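Usage sketch for plate_heatmap; the src path is an assumption and must contain measurements/measurements.db as produced by the measurement step:

output, fig = plate_heatmap(src='/data/plate1', model_type='xgboost',
                            variable='predictions', channel_of_interest=3,
                            pos='c1', neg='c2', verbose=False)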
+def join_measurments_and_annotation(src, tables = ['cell', 'nucleus', 'pathogen','cytoplasm']):
+
+    from .io import _read_and_merge_data, _read_db
+
+    db_loc = [src+'/measurements/measurements.db']
+    loc = src+'/measurements/measurements.db'
+    df, _ = _read_and_merge_data(db_loc,
+                                 tables,
+                                 verbose=True,
+                                 include_multinucleated=True,
+                                 include_multiinfected=True,
+                                 include_noninfected=True)
+
+    paths_df = _read_db(loc, tables=['png_list'])
+
+    merged_df = pd.merge(df, paths_df[0], on='prcfo', how='left')
+
+    return merged_df
+
+def jitterplot_by_annotation(src, x_column, y_column, plot_title='Jitter Plot', output_path=None, filter_column=None, filter_values=None):
+    """
+    Reads a CSV file and creates a jitter plot of one column grouped by another column.
+
+    Args:
+        src (str): Path to the source data.
+        x_column (str): Name of the column to be used for the x-axis.
+        y_column (str): Name of the column to be used for the y-axis.
+        plot_title (str): Title of the plot. Default is 'Jitter Plot'.
+        output_path (str): Path to save the plot image. If None, the plot will be displayed. Default is None.
+
+    Returns:
+        pd.DataFrame: The filtered and balanced DataFrame.
+    """
+    # Read the CSV file into a DataFrame
+    df = join_measurments_and_annotation(src, tables=['cell', 'nucleus', 'pathogen', 'cytoplasm'])
+
+    # Print column names for debugging
+    print(f"Generated dataframe with: {df.shape[1]} columns and {df.shape[0]} rows")
+    #print("Columns in DataFrame:", df.columns.tolist())
+
+    # Replace NaN values with a specific label in x_column
+    df[x_column] = df[x_column].fillna('NaN')
+
+    # Filter the DataFrame if filter_column and filter_values are provided
+    if not filter_column is None:
+        if isinstance(filter_column, str):
+            df = df[df[filter_column].isin(filter_values)]
+        if isinstance(filter_column, list):
+            for i,val in enumerate(filter_column):
+                print(f'hello {len(df)}')
+                df = df[df[val].isin(filter_values[i])]
+
+    # Use the correct column names based on your DataFrame
+    required_columns = ['plate_x', 'row_x', 'col_x']
+    if not all(column in df.columns for column in required_columns):
+        raise KeyError(f"DataFrame does not contain the necessary columns: {required_columns}")
+
+    # Filter to retain rows with non-NaN values in x_column and with matching plate, row, col values
+    non_nan_df = df[df[x_column] != 'NaN']
+    retained_rows = df[df[['plate_x', 'row_x', 'col_x']].apply(tuple, axis=1).isin(non_nan_df[['plate_x', 'row_x', 'col_x']].apply(tuple, axis=1))]
+
+    # Determine the minimum count of examples across all groups in x_column
+    min_count = retained_rows[x_column].value_counts().min()
+    print(f'Found {min_count} annotated images')
+
+    # Randomly sample min_count examples from each group in x_column
+    balanced_df = retained_rows.groupby(x_column).apply(lambda x: x.sample(min_count, random_state=42)).reset_index(drop=True)
+
+    # Create the jitter plot
+    plt.figure(figsize=(10, 6))
+    jitter_plot = sns.stripplot(data=balanced_df, x=x_column, y=y_column, hue=x_column, jitter=True, palette='viridis', dodge=False)
+    plt.title(plot_title)
+    plt.xlabel(x_column)
+    plt.ylabel(y_column)
+
+    # Customize the x-axis labels
+    plt.xticks(rotation=45, ha='right')
+
+    # Adjust the position of the x-axis labels to be centered below the data
+    ax = plt.gca()
+    ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha='center')
+
+    # Save the plot to a file or display it
+    if output_path:
+        plt.savefig(output_path, bbox_inches='tight')
+        print(f"Jitter plot saved to {output_path}")
+    else:
+        plt.show()
+
+    return balanced_df
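A closing usage sketch for jitterplot_by_annotation; the annotation column, measurement column, and filter values are assumptions about the schema:

balanced_df = jitterplot_by_annotation(src='/data/plate1',
                                       x_column='annotation',   # hypothetical annotation column
                                       y_column='cell_area',    # hypothetical measurement column
                                       plot_title='Measurement by annotation',
                                       output_path='jitter.pdf',
                                       filter_column='col_x',
                                       filter_values=['c1', 'c2'])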